Commit fae6e7ba authored by Kate Rakelly

HW5c fix: Problem 1 now evaluates deterministically on all tasks

parent d2dedd10
@@ -18,6 +18,7 @@ class ObservedPointEnv(Env):
     # YOUR CODE SOMEWHERE HERE
     def __init__(self, num_tasks=1):
         self.tasks = [0, 1, 2, 3][:num_tasks]
+        self.task_idx = -1
         self.reset_task()
         self.reset()
@@ -25,10 +26,15 @@ class ObservedPointEnv(Env):
         self.action_space = spaces.Box(low=-0.1, high=0.1, shape=(2,))
 
     def reset_task(self, is_evaluation=False):
-        idx = np.random.choice(len(self.tasks))
-        self._task = self.tasks[idx]
+        # for evaluation, cycle deterministically through all tasks
+        if is_evaluation:
+            self.task_idx = (self.task_idx + 1) % len(self.tasks)
+        # during training, sample tasks randomly
+        else:
+            self.task_idx = np.random.randint(len(self.tasks))
+        self._task = self.tasks[self.task_idx]
         goals = [[-1, -1], [-1, 1], [1, -1], [1, 1]]
-        self._goal = np.array(goals[idx])*10
+        self._goal = np.array(goals[self.task_idx])*10
 
     def reset(self):
         self._state = np.array([0, 0], dtype=np.float32)
...
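The net effect of the reset_task change: with is_evaluation=True, task_idx now advances cyclically modulo len(self.tasks), so repeated evaluation resets visit each of the four goal corners exactly once per pass in a fixed order, while training resets still sample a task uniformly at random. A minimal standalone Python sketch of that behavior (illustration only, not part of the commit; the module name point_mass is an assumption about where ObservedPointEnv is defined):

# Illustration only -- the import path is assumed; everything else follows the class above.
from point_mass import ObservedPointEnv

env = ObservedPointEnv(num_tasks=4)

# Evaluation resets advance task_idx cyclically, so every task (and its goal,
# one of (-10,-10), (-10,10), (10,-10), (10,10)) is visited once every 4 resets.
for _ in range(8):
    env.reset_task(is_evaluation=True)
    print(env._goal)

# Training resets still choose a task uniformly at random.
env.reset_task(is_evaluation=False)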
@@ -689,8 +689,11 @@ def train_PG(
         # sample trajectories to fill agent's replay buffer
         print("********** Iteration %i ************"%itr)
-        stats, timesteps_this_batch = agent.sample_trajectories(itr, env, min_timesteps_per_batch)
-        total_timesteps += timesteps_this_batch
+        stats = []
+        for _ in range(num_tasks):
+            s, timesteps_this_batch = agent.sample_trajectories(itr, env, min_timesteps_per_batch)
+            total_timesteps += timesteps_this_batch
+            stats += s
 
         # compute the log probs, advantages, and returns for all data in agent's buffer
         # store in ppo buffer for use in multiple ppo updates
@@ -720,7 +723,10 @@ def train_PG(
         # compute validation statistics
         print('Validating...')
-        val_stats, timesteps_this_batch = agent.sample_trajectories(itr, env, min_timesteps_per_batch // 10, is_evaluation=True)
+        val_stats = []
+        for _ in range(num_tasks):
+            vs, timesteps_this_batch = agent.sample_trajectories(itr, env, min_timesteps_per_batch // 10, is_evaluation=True)
+            val_stats += vs
 
         # save trajectories for viz
         with open("output/{}-epoch{}.pkl".format(exp_name, itr), 'wb') as f:
...
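Because reset_task(is_evaluation=True) now steps through tasks cyclically, looping the validation sampling num_tasks times covers every task exactly once per pass, provided each sample_trajectories call performs one evaluation reset (an assumption about the agent's sampling code, which this diff does not show). A toy Python sketch of that bookkeeping, where collect_batch_on_current_task is a hypothetical stand-in for agent.sample_trajectories:

# Toy illustration only: collect_batch_on_current_task is hypothetical and is not
# part of the homework code; it stands in for agent.sample_trajectories.
def validate(env, num_tasks, collect_batch_on_current_task):
    val_stats = []
    for _ in range(num_tasks):
        env.reset_task(is_evaluation=True)               # advance to the next task in the cycle
        val_stats += collect_batch_on_current_task(env)  # list of per-episode statistics
    return val_stats  # one batch of statistics per task, in a fixed task order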