Skip to content
Snippets Groups Projects
Commit 112efef9 authored by tuhe's avatar tuhe
Browse files

Final lectures

parent 0b9db9c5
Branches
No related tags found
No related merge requests found
Showing
with 285 additions and 0 deletions
# This file may not be shared/redistributed without permission. Please read copyright notice in the git repo. If this file contains other copyright notices disregard this text.
# This file may not be shared/redistributed without permission. Please read copyright notice in the git repo. If this file contains other copyright notices disregard this text.
# from irlc.berkley.rl.feature_encoder import SimplePacmanExtractor
# from irlc.lectures.lecture_09_mc import keyboard_play
# alpha = 0.5
# gamma =
# def open_play(Agent, method_label, **args):
# env = OpenGridEnvironment()
# agent = Agent(env, gamma=0.95, epsilon=0.1, alpha=.5, **args)
# keyboard_play(env, agent, method_label=method_label)
from irlc.lectures.lec11.lecture_10_sarsa_open import open_play
from irlc.ex10.mc_agent import MCAgent
if __name__ == "__main__":
    # Run the open_play demo with the Monte-Carlo agent class.
    demo_label = "MC agent"
    open_play(MCAgent, method_label=demo_label)
#
# This file may not be shared/redistributed without permission. Please read copyright notice in the git repo. If this file contains other copyright notices disregard this text.
from irlc.ex11.semi_grad_q import LinearSemiGradQAgent
from irlc.pacman.pacman_environment import PacmanEnvironment
from irlc.ex01.agent import train
from irlc import interactive
from irlc.lectures.chapter14lectures.lecture11pacman import layout, rns
# from irlc import VideoMonitor
if __name__ == "__main__":
    # Train without animation first, then keep training in a rendered,
    # interactive environment using the same agent.
    env = PacmanEnvironment(animate_movement=False, layout=layout)
    env2 = None
    try:
        # The last entry of rns unpacks into two items; the second one is
        # called with the environment to build the agent.
        _, make_agent = rns[-1]
        agent = make_agent(env)
        train(env, agent, num_episodes=100, max_runs=20)

        env2 = PacmanEnvironment(animate_movement=True, layout=layout, render_mode='human')
        env2, agent = interactive(env2, agent)
        train(env2, agent, num_episodes=100, max_runs=20)
    finally:
        # Close both environments even when training is interrupted; the
        # original never closed the first environment.
        if env2 is not None:
            env2.close()
        env.close()
# This file may not be shared/redistributed without permission. Please read copyright notice in the git repo. If this file contains other copyright notices disregard this text.
from irlc.lectures.lec11.lecture_10_sarsa_open import open_play
from irlc.lectures.lec12.sarsa_lambda_delay import SarsaLambdaDelayAgent
if __name__ == "__main__":
    # Extra keyword arguments are passed on through open_play
    # (presumably to the agent constructor -- open_play is defined elsewhere).
    agent_options = {"lamb": 0.8}
    open_play(SarsaLambdaDelayAgent, method_label="Sarsa(Lambda)", **agent_options)
# This file may not be shared/redistributed without permission. Please read copyright notice in the git repo. If this file contains other copyright notices disregard this text.
from irlc.gridworld.gridworld_environments import OpenGridEnvironment
from irlc import train
from irlc.lectures.lec11.lecture_10_sarsa_open import open_play
from irlc.lectures.lec11.sarsa_nstep_delay import SarsaDelayNAgent
if __name__ == "__main__":
    n_steps = 8

    # Train one agent for 100 episodes on a non-rendered environment.
    warmup_env = OpenGridEnvironment()
    warmup_agent = SarsaDelayNAgent(warmup_env, n=n_steps)
    train(warmup_env, warmup_agent, num_episodes=100)

    # NOTE(review): open_play is given the agent class, not warmup_agent, so the
    # agent trained above is not the one shown -- confirm this is intended.
    open_play(SarsaDelayNAgent, method_label=f"Sarsa n={n_steps}", n=n_steps)
# This file may not be shared/redistributed without permission. Please read copyright notice in the git repo. If this file contains other copyright notices disregard this text.
from irlc import train
from irlc.gridworld.gridworld_environments import OpenGridEnvironment
from irlc.lectures.lec11.lecture_10_sarsa_open import open_play
from irlc.lectures.lec11.sarsa_nstep_delay import SarsaDelayNAgent
if __name__ == "__main__":
    # One-step Sarsa expressed through the n-step agent.
    n = 1

    env = OpenGridEnvironment()
    agent = SarsaDelayNAgent(env, n=n)
    train(env, agent, num_episodes=100)

    # Pass n explicitly (as the n-step demo script does) so the agent that is
    # played uses the same n as the one trained above instead of relying on
    # the constructor default. The label is a plain string: it has no
    # placeholders, so the f-prefix was unnecessary.
    open_play(SarsaDelayNAgent, method_label="Sarsa", n=n)
# This file may not be shared/redistributed without permission. Please read copyright notice in the git repo. If this file contains other copyright notices disregard this text.
from irlc.lectures.lec12.td_lambda import TDLambdaAgent
if __name__ == "__main__":
    from irlc.lectures.lec10.lecture_10_mc_q_estimation import keyboard_play
    from irlc.gridworld.gridworld_environments import OpenGridEnvironment

    # Hyper-parameters used both to build the agent and to render the label.
    params = dict(gamma=0.99, alpha=0.5, lamb=0.9)

    env = OpenGridEnvironment(render_mode='human', frames_per_second=30)
    agent = TDLambdaAgent(env, **params)
    method_label = f"TD(Lambda={params['lamb']}) (gamma={params['gamma']}, alpha={params['alpha']})"
    keyboard_play(env, agent, method_label=method_label)
    env.close()
# This file may not be shared/redistributed without permission. Please read copyright notice in the git repo. If this file contains other copyright notices disregard this text.
from collections import defaultdict
from irlc.ex11.q_agent import QAgent
class SarsaLambdaDelayAgent(QAgent):
    """Sarsa(Lambda) agent whose Q-update lags one step behind the environment.

    ``train`` receives no successor action, so the update for a transition is
    buffered (``s_prev``, ``a_prev``, ``r_prev``) and applied on the next call to
    ``train``, where the ``(s, a)`` passed in act as the successor state/action
    of the buffered transition. The final transition of an episode is updated
    immediately because no successor action is needed when ``done`` is true.
    """

    def __init__(self, env, gamma=0.99, epsilon=0.1, alpha=0.5, lamb=0.9):
        """Create the agent.

        :param env: Environment handed to the ``QAgent`` base class.
        :param gamma: Discount factor.
        :param epsilon: Exploration rate handed to the base class.
        :param alpha: Learning rate.
        :param lamb: Trace decay; traces are multiplied by ``gamma * lamb`` per update.
        """
        super().__init__(env, gamma=gamma, alpha=alpha, epsilon=epsilon)
        self.lamb = lamb
        self.method = 'Sarsa(Lambda)'
        # Eligibility trace per (state, action) pair; missing pairs read as 0.0.
        self.e = defaultdict(float)
        # Step index within the episode (set by pi) and the buffered transition.
        # Initialised here so train() is well defined even before the first pi().
        self.t = 0
        self.s_prev = None
        self.a_prev = None
        self.r_prev = None

    def pi(self, s, k, info=None):
        """Remember the step index ``k`` and return the action chosen by ``pi_eps``."""
        self.t = k
        return self.pi_eps(s, info=info)

    def lmb_update(self, s, a, r, sp, ap, done):
        """Apply one trace-weighted update for the transition ``(s, a, r, sp, ap)``.

        Every pair with a stored trace has its Q-value moved by
        ``alpha * delta * trace`` and its trace multiplied by ``gamma * lamb``.
        ``ap`` is ignored when ``done`` is true.
        """
        bootstrap = 0 if done else self.Q[sp, ap]
        delta = r + self.gamma * bootstrap - self.Q[s, a]
        decay = self.gamma * self.lamb
        # Distinct loop names: the original loop rebound the s/a parameters.
        # Only existing keys are reassigned, so the dict size stays fixed
        # while iterating.
        for key, trace in self.e.items():
            self.Q[key] += self.alpha * delta * trace
            self.e[key] = decay * trace

    def train(self, s, a, r, sp, done=False, info_s=None, info_sp=None):
        """Process one transition, applying the buffered update from the previous step."""
        if self.t > 0:
            # The previous transition can now be completed: (s, a) is its successor.
            self.lmb_update(self.s_prev, self.a_prev, self.r_prev, s, a, done=False)
        self.e[(s, a)] += 1
        if done:
            # Terminal transition: no successor action is needed, so update now
            # and reset the traces for the next episode.
            self.lmb_update(s, a, r, sp, ap=None, done=True)
            self.e.clear()
        self.s_prev, self.a_prev, self.r_prev = s, a, r

    def __str__(self):
        return f"SarsaLambdaDelay_{self.gamma}_{self.epsilon}_{self.alpha}_{self.lamb}"
if __name__ == "__main__":
    # NOTE(review): other scripts import keyboard_play from
    # irlc.lectures.lec10.lecture_10_mc_q_estimation and call it with (env, agent),
    # whereas this one passes the agent class -- confirm that
    # irlc.ex12.sarsa_lambda_open.keyboard_play exists and accepts a class.
    from irlc.ex12.sarsa_lambda_open import keyboard_play

    demo_label = "Sarsa(Lambda) (delayed)"
    keyboard_play(SarsaLambdaDelayAgent, method_label=demo_label)
# This file may not be shared/redistributed without permission. Please read copyright notice in the git repo. If this file contains other copyright notices disregard this text.
from irlc.ex09.rl_agent import ValueAgent
from collections import defaultdict
class TDLambdaAgent(ValueAgent):
    """TD(Lambda) state-value estimation with accumulating eligibility traces."""

    def __init__(self, env, gamma=0.99, alpha=0.5, lamb=0.9):
        """Create the agent.

        :param env: Environment handed to the ``ValueAgent`` base class.
        :param gamma: Discount factor (forwarded to the base class).
        :param alpha: Learning rate.
        :param lamb: Trace decay; traces are multiplied by ``gamma * lamb`` per step.
        """
        self.alpha = alpha
        self.lamb = lamb
        # Eligibility trace per state; unseen states read as 0.0.
        self.e = defaultdict(float)
        super().__init__(env, gamma=gamma)

    def train(self, s, a, r, sp, done=False, info_s=None, info_sp=None):
        """Update the value estimates from the transition ``(s, r, sp)``.

        The trace of ``s`` is incremented, then every state with a stored trace
        has its value moved by ``alpha * delta * trace`` and its trace decayed.
        Traces are cleared when the episode ends.
        """
        self.e[s] += 1
        delta = r + self.gamma * (0 if done else self.v[sp]) - self.v[s]
        decay = self.gamma * self.lamb
        # Distinct loop names: the original loop rebound the s parameter.
        # Only existing keys are reassigned, so the dict size stays fixed
        # while iterating.
        for state, trace in self.e.items():
            self.v[state] += self.alpha * delta * trace
            self.e[state] = decay * trace
        if done:
            self.e.clear()
        # A leftover ``raise NotImplementedError`` used to follow here, which made
        # every call fail after the update had already been applied.

    def __str__(self):
        # Restored: this return statement had lost its ``def __str__`` header and
        # was unreachable at the end of train().
        return f"TD(Lambda={self.lamb})_value_{self.gamma}_{self.alpha}"
if __name__ == "__main__":
    from irlc.lectures.lec10.lecture_10_mc_q_estimation import keyboard_play
    from irlc.gridworld.gridworld_environments import OpenGridEnvironment

    # Single source of truth for the hyper-parameters: the label previously
    # hard-coded gamma=0.99 although the agent is created with gamma=1.
    gamma, alpha, lamb = 1, .5, 0.9

    env = OpenGridEnvironment(render_mode='human', frames_per_second=30)
    agent = TDLambdaAgent(env, gamma=gamma, alpha=alpha, lamb=lamb)
    method_label = f"TD(Lambda) (gamma={gamma}, alpha={alpha})"
    keyboard_play(env, agent, method_label=method_label)
    env.close()
# if __name__ == "__main__":
#
# open_play(TDLambdaAgent, method_label="TD(Lambda)", lamb=0.8)
#
# pass
# This file may not be shared/redistributed without permission. Please read copyright notice in the git repo. If this file contains other copyright notices disregard this text.
import numpy as np
from irlc.ex01.agent import train
import gymnasium as gym
from irlc import main_plot
import matplotlib.pyplot as plt
from irlc import savepdf
from irlc.ex11.sarsa_agent import SarsaAgent
from irlc.ex11.q_agent import QAgent
from irlc.ex13.tabular_double_q import TabularDoubleQ
from irlc.ex09.rl_agent import TabularQ
from irlc.gridworld.gridworld_environments import CliffGridEnvironment
class DoubleQVizAgent(TabularDoubleQ):
    """Double-Q agent that also keeps ``self.Q`` equal to the mean of ``Q1`` and ``Q2``.

    Presumably ``self.Q`` exists so code that reads a single Q-table (e.g. the
    visualisation) can be used with this agent -- confirm against the callers.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Extra table holding the averaged estimate.
        self.Q = TabularQ(self.env)

    def train(self, s, a, r, sp, done=False, info_s=None, info_sp=None):
        """Run the parent's update, then refresh the averaged entry for ``(s, a)``."""
        super().train(s, a, r, sp, done, info_s, info_sp)
        first, second = self.Q1[s, a], self.Q2[s, a]
        self.Q[s, a] = (first + second) / 2
def train_cliff(runs=4, extension="long", save_pdf=False, alpha=0.02, num_episodes=5000):
    """Part 1: Cliffwalking.

    Trains a Q-learning, a Sarsa and a double-Q agent on the cliff gridworld.
    Fresh agents are created for each of the ``runs`` repetitions and every
    agent is trained into an experiment directory derived from ``extension``
    and ``str(agent)``.

    :param runs: Number of repetitions; must be at least 1.
    :param extension: Tag used in experiment directory names, the plot title and the pdf name.
    :param save_pdf: If true, plot the experiments, save the figure and show it.
    :param alpha: Learning rate given to all three agents.
    :param num_episodes: Episodes per call to ``train``.
    :return: ``(agents, env)`` -- the agents from the last repetition and the training environment.
    :raises ValueError: If ``runs`` is smaller than 1.
    """
    # Validate before creating the environment: with no repetition the original
    # reached the plotting/return code with ``agents`` and ``experiments`` unbound.
    if runs < 1:
        raise ValueError(f"runs must be at least 1, got {runs}")

    env = CliffGridEnvironment(zoom=1)
    epsilon = 0.1
    agents, experiments = [], []
    for _ in range(runs):
        agents = [QAgent(env, gamma=1, epsilon=epsilon, alpha=alpha),
                  SarsaAgent(env, gamma=1, epsilon=epsilon, alpha=alpha),
                  DoubleQVizAgent(env, gamma=1, epsilon=epsilon, alpha=alpha)]
        # The directory names do not depend on the repetition, so rebuilding the
        # list each time yields the same entries.
        experiments = []
        for agent in agents:
            expn = f"experiments/doubleq_cliffwalk_{extension}_{agent!s}"
            train(env, agent, expn, num_episodes=num_episodes, max_runs=1e6)
            experiments.append(expn)

    if save_pdf:
        main_plot(experiments, smoothing_window=20, resample_ticks=500)
        plt.ylim([-100, 50])
        plt.title(f"Double-Q learning on Cliffwalk ({extension})")
        savepdf(f"double_Q_learning_cliff_{extension}")
        plt.show()
    return agents, env
def grid_experiment(runs=20, extension="long", alpha=0.02, num_episodes=5000):
    """Train the cliff-walking agents, then save one pdf per trained agent.

    Calls ``train_cliff`` with ``save_pdf=True`` and, for every returned agent,
    creates a new cliff environment, wraps it with ``interactive`` and stores a
    pdf named ``doubleq_cliff_{extension}_agent_{index}``.

    :param runs: Forwarded to ``train_cliff``.
    :param extension: Tag used in the output file names.
    :param alpha: Learning rate forwarded to ``train_cliff``.
    :param num_episodes: Episodes per training call, forwarded to ``train_cliff``.
    """
    from irlc.gridworld.gridworld_environments import CliffGridEnvironment
    from irlc import interactive

    agents, env = train_cliff(runs=runs, extension=extension, save_pdf=True,
                              alpha=alpha, num_episodes=num_episodes)
    for index, trained_agent in enumerate(agents):
        # NOTE(review): the other scripts in this commit pass render_mode='human';
        # confirm that view_mode='human' is the intended keyword here.
        env2 = CliffGridEnvironment(zoom=1, view_mode='human')
        env2, _ = interactive(env2, agent=trained_agent)
        env2.savepdf(f"doubleq_cliff_{extension}_agent_{index}")
        env2.close()
    env.close()
if __name__ == "__main__":
    # Test cliffwalk in both the long and short version.
    for settings in (
        dict(extension="long", alpha=0.02, num_episodes=5000),
        dict(extension="short", alpha=0.25, num_episodes=500),
    ):
        grid_experiment(runs=1, **settings)
# This file may not be shared/redistributed without permission. Please read copyright notice in the git repo. If this file contains other copyright notices disregard this text.
from irlc.lectures.lec11.lecture_10_sarsa_open import open_play
from irlc.ex11.q_agent import QAgent
from irlc.ex13.dyna_q import DynaQ
from irlc.lectures.lec10.lecture_10_mc_q_estimation import keyboard_play
from irlc.gridworld.gridworld_environments import SuttonMazeEnvironment
def sutton_maze_play(Agent, method_label="Q-learning agent", **kwargs):
    """Run ``keyboard_play`` with a new agent on a rendered Sutton maze environment.

    :param Agent: Agent class; instantiated as
        ``Agent(env, gamma=0.98, epsilon=0.1, alpha=.5, **kwargs)``.
    :param method_label: Label forwarded to ``keyboard_play``.
    :param kwargs: Extra keyword arguments for the agent constructor.
    """
    env = SuttonMazeEnvironment(render_mode="human")
    try:
        agent = Agent(env, gamma=0.98, epsilon=0.1, alpha=.5, **kwargs)
        keyboard_play(env, agent, method_label=method_label)
    finally:
        # Close the rendered environment, as the other keyboard_play scripts do,
        # even if the session ends with an error.
        env.close()
if __name__ == "__main__":
    # DynaQ with n=0 shown under the Q-learning label
    # (presumably n=0 disables the planning steps -- confirm in DynaQ).
    agent_options = {"n": 0}
    sutton_maze_play(DynaQ, method_label="Q-learning agent", **agent_options)
# This file may not be shared/redistributed without permission. Please read copyright notice in the git repo. If this file contains other copyright notices disregard this text.
from irlc.lectures.lec11.lecture_10_sarsa_open import open_play
from irlc.ex11.q_agent import QAgent
if __name__ == "__main__":
    # Run the open_play demo with the Q-learning agent class.
    demo_label = "Q-learning agent"
    open_play(QAgent, method_label=demo_label)
# This file may not be shared/redistributed without permission. Please read copyright notice in the git repo. If this file contains other copyright notices disregard this text.
from irlc.lectures.lec11.lecture_10_sarsa_open import open_play
from irlc.ex11.q_agent import QAgent
from irlc.ex13.dyna_q import DynaQ
from irlc.lectures.lec10.lecture_10_mc_q_estimation import keyboard_play
from irlc.gridworld.gridworld_environments import SuttonMazeEnvironment
from irlc.lectures.lec13.lecture_13_Q_maze import sutton_maze_play
if __name__ == "__main__":
    # The same n is used for the agent argument and for the label text.
    n = 5
    sutton_maze_play(DynaQ, method_label=f"DynaQ (n={n})", n=n)
# This file may not be shared/redistributed without permission. Please read copyright notice in the git repo. If this file contains other copyright notices disregard this text.
from irlc.lectures.lec13.lecture_13_Q_maze import sutton_maze_play
from irlc.ex12.sarsa_lambda_agent import SarsaLambdaAgent
if __name__ == "__main__":
    # The same lambda is used for the agent argument and for the label text.
    lamb = 0.9
    sutton_maze_play(SarsaLambdaAgent, method_label=f"Sarsa(Lambda={lamb})", lamb=lamb)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment