Source code for bulletarm_baselines.logger.plotter

'''
.. moduleauthor:: Colin Kohler <github.com/ColinKohler>
'''

import os
import pickle
import numpy as np
import more_itertools
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()

class Plotter(object):
    '''
    Plotting utility for pickled training logs.

    Each log is expected to be a mapping that contains the key
    'training_eps_rewards' (a sequence of per-episode rewards); the plotting
    methods draw moving averages of that sequence.

    Args:
        log_filepaths: Iterable of paths to pickled log files.
        log_names: Iterable of names, paired positionally with
            log_filepaths, under which each loaded log is stored.
    '''

    def __init__(self, log_filepaths, log_names):
        self.logs = self.loadLogs(log_filepaths, log_names)

    def loadLogs(self, filepaths, names):
        '''
        Load every pickled log that exists on disk.

        Paths that do not exist are reported on stdout and skipped, so the
        returned dict may contain fewer entries than names.

        Args:
            filepaths: Iterable of paths to pickled log files.
            names: Iterable of names, paired positionally with filepaths.

        Returns:
            dict mapping each name to its unpickled log object.
        '''
        logs = dict()
        for n, fp in zip(names, filepaths):
            if os.path.exists(fp):
                with open(fp, 'rb') as f:
                    # NOTE: pickle.load can execute arbitrary code; only
                    # point this at log files from a trusted source.
                    logs[n] = pickle.load(f)
            else:
                # Report the path itself (fp), not the file handle: the
                # handle is unbound if the first path is missing and stale
                # otherwise.
                print('No log found at {}'.format(fp))

        return logs

    def _windowedAvgReward(self, eps_rewards, window):
        '''
        Compute the moving average of the episode rewards.

        Args:
            eps_rewards: Sequence of per-episode rewards.
            window: Number of consecutive episodes averaged per point.

        Returns:
            Tuple (xs, avg_reward). xs starts at window, so each point is
            placed at the last episode of the window it averages. Both
            arrays are empty when there are fewer rewards than window.
        '''
        # more_itertools.windowed pads a too-short input with None, which
        # np.mean cannot average; return nothing to plot instead of crashing.
        if len(eps_rewards) < window:
            return np.array([]), np.array([])

        avg_reward = np.mean(
            list(more_itertools.windowed(eps_rewards, window)), axis=1
        )
        xs = np.arange(window, len(avg_reward) + window)
        return xs, avg_reward

    def plotLearningCurves(self, title, filepath, window=100, max_eps=None):
        '''
        Plot multiple learning curves on a single plot.

        One curve is drawn per loaded log, labelled with the log's name.

        Args:
            title: Title of the plot.
            filepath: Path the figure is saved to.
            window: Number of episodes averaged per plotted point.
            max_eps: If truthy, only the first max_eps episodes of each log
                are plotted; otherwise all episodes are used.
        '''
        fig, ax = plt.subplots()

        ax.set_title(title)
        ax.set_xlabel('Episodes')
        ax.set_ylabel('Avg. Reward')

        for log_name, log in self.logs.items():
            eps_rewards = log['training_eps_rewards']
            if max_eps:
                eps_rewards = eps_rewards[:max_eps]
            xs, avg_reward = self._windowedAvgReward(eps_rewards, window)
            ax.plot(xs, avg_reward, label=log_name)

        ax.legend()
        plt.savefig(filepath)
        plt.close()

    def plotLearningCurve(self, name, title, filepath, window=100):
        '''
        Plot the learning curve for a single loaded log.

        Args:
            name: Name of the log to plot, as given at construction.
            title: Title of the plot.
            filepath: Path the figure is saved to.
            window: Number of episodes averaged per plotted point.

        Raises:
            KeyError: If no log was loaded under name.
        '''
        fig, ax = plt.subplots()

        ax.set_title(title)
        ax.set_xlabel('Episodes')
        ax.set_ylabel('Avg. Reward')

        eps_rewards = self.logs[name]['training_eps_rewards']
        xs, avg_reward = self._windowedAvgReward(eps_rewards, window)
        ax.plot(xs, avg_reward)

        plt.savefig(filepath)
        plt.close()