Skip to content
Snippets Groups Projects
Commit 26ab7500 authored by tuhe
Browse files

Minor fixes to show eligibility trace for TD(Lambda)

parent 02d61ecc
Branches
No related tags found
No related merge requests found
...@@ -183,7 +183,16 @@ class GridworldEnvironment(MDP2GymEnv): ...@@ -183,7 +183,16 @@ class GridworldEnvironment(MDP2GymEnv):
else: else:
returns_sum = None returns_sum = None
self.display_pygame.displayValues(mdp=self.mdp, v=v, preferred_actions=preferred_actions, currentState=state, message=label, returns_count=returns_count, returns_sum=returns_sum) if hasattr(agent, 'e') and isinstance(agent.e, defaultdict):
eligibility_trace = defaultdict(float)
for k, vv in agent.e.items():
if vv > 0:
eligibility_trace[k] = vv
else:
eligibility_trace = None
self.display_pygame.displayValues(mdp=self.mdp, v=v, preferred_actions=preferred_actions, currentState=state, message=label, returns_count=returns_count, returns_sum=returns_sum,
eligibility_trace=eligibility_trace)
elif avail_modes[self.view_mode] == 'Q': elif avail_modes[self.view_mode] == 'Q':
......
...@@ -115,12 +115,9 @@ class GraphicsGridworldDisplay: ...@@ -115,12 +115,9 @@ class GraphicsGridworldDisplay:
# def end_frame(self): # def end_frame(self):
# self.ga.end_frame() # self.ga.end_frame()
def displayValues(self, mdp, v, preferred_actions=None, currentState=None, message='Agent Values', returns_count=None, returns_sum=None): def displayValues(self, mdp, v, preferred_actions=None, currentState=None, message='Agent Values', returns_count=None, returns_sum=None,
# if self.v_old == None: eligibility_trace=None):
# self.ga.gc.clear()
# self.v_old = {}
# else:
# pass
self.ga.draw_background() self.ga.draw_background()
m = [v[s] for s in mdp.nonterminal_states] m = [v[s] for s in mdp.nonterminal_states]
self.Q_old = None self.Q_old = None
...@@ -150,16 +147,15 @@ class GraphicsGridworldDisplay: ...@@ -150,16 +147,15 @@ class GraphicsGridworldDisplay:
returns_sum_ = returns_sum[state] if returns_sum is not None else None returns_sum_ = returns_sum[state] if returns_sum is not None else None
returns_count_ = returns_count[state] if returns_count is not None else None returns_count_ = returns_count[state] if returns_count is not None else None
# de = 8
de = eligibility_trace[state] if eligibility_trace is not None and state in eligibility_trace else None
self.drawSquare(name, x, y, value, minValue, maxValue, valString, all_actions, False, isExit, isCurrent, self.drawSquare(name, x, y, value, minValue, maxValue, valString, all_actions, False, isExit, isCurrent,
returns_sum=returns_sum_, returns_count=returns_count_) returns_sum=returns_sum_, returns_count=returns_count_,
eligibility_trace=de)
# print("Drawing...")
if isinstance(currentState, tuple): if isinstance(currentState, tuple):
# print("found pacman")
screen_x, screen_y = self.to_screen(currentState) screen_x, screen_y = self.to_screen(currentState)
self.draw_player((screen_x, screen_y), 0.12 * self.GRID_SIZE) self.draw_player((screen_x, screen_y), 0.12 * self.GRID_SIZE)
# else:
# print("no instance found??")
pos = self.to_screen(((mdp.width - 1.0) / 2.0, - 0.8)) pos = self.to_screen(((mdp.width - 1.0) / 2.0, - 0.8))
self.ga.text(f"v_text_", pos, TEXT_COLOR, message, "Courier", -32, "bold", "c") self.ga.text(f"v_text_", pos, TEXT_COLOR, message, "Courier", -32, "bold", "c")
...@@ -324,7 +320,7 @@ class GraphicsGridworldDisplay: ...@@ -324,7 +320,7 @@ class GraphicsGridworldDisplay:
def drawSquare(self, name, x, y, val, min, max, valStr, all_action, isObstacle, isTerminal, isCurrent, def drawSquare(self, name, x, y, val, min, max, valStr, all_action, isObstacle, isTerminal, isCurrent,
returns_count=None, returns_sum=None): returns_count=None, returns_sum=None, eligibility_trace=None):
square_color = getColor(val, min, max) square_color = getColor(val, min, max)
(screen_x, screen_y) = self.to_screen((x, y)) (screen_x, screen_y) = self.to_screen((x, y))
if isObstacle: if isObstacle:
...@@ -370,6 +366,16 @@ class GraphicsGridworldDisplay: ...@@ -370,6 +366,16 @@ class GraphicsGridworldDisplay:
if returns_sum is not None: if returns_sum is not None:
self.ga.text(name + "_rs", (screen_x-GRID_SIZE/3, screen_y+2*GRID_SIZE/7), RED_TEXT_COLOR, f"S(s)={returns_sum:.2f}", "Courier", -20, "bold", "w") self.ga.text(name + "_rs", (screen_x-GRID_SIZE/3, screen_y+2*GRID_SIZE/7), RED_TEXT_COLOR, f"S(s)={returns_sum:.2f}", "Courier", -20, "bold", "w")
if eligibility_trace is not None:
# if eligibility_trace is not None:
estr = f'{eligibility_trace:.2f}'
# dh = 0.95 * GRID_SIZE
# dw = 0.5 * GRID_SIZE
ECOL = RED_TEXT_COLOR if eligibility_trace != 0 else square_color
esize = -16
self.ga.text(name + "_txt1e", (screen_x, screen_y + GRID_SIZE/5), ECOL, estr, "Courier", esize, "bold", "c")
# if returns_count is not None: # if returns_count is not None:
# self.ga.text(name + "_rs", (screen_x, screen_y), text_color, valStr, "Courier", -30, "bold", "c") # self.ga.text(name + "_rs", (screen_x, screen_y), text_color, valStr, "Courier", -30, "bold", "c")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment