From 50c89843f5485855a47df2e975b3d96e5932d992 Mon Sep 17 00:00:00 2001
From: Florian Gawrilowicz <floicz@gmail.com>
Date: Sat, 8 Jun 2019 22:12:40 +0200
Subject: [PATCH] final push

---
 hw4/logger.py             |  2 +-
 hw4/main.py               |  6 +++---
 hw4/model_based_policy.py | 18 ++++++++++++++----
 hw4/model_based_rl.py     | 18 +++++++++---------
 hw4/run_all.sh            |  0
 hw4/utils.py              |  2 +-
 6 files changed, 28 insertions(+), 18 deletions(-)
 mode change 100644 => 100755 hw4/run_all.sh

diff --git a/hw4/logger.py b/hw4/logger.py
index 8df593e..b6f4f41 100644
--- a/hw4/logger.py
+++ b/hw4/logger.py
@@ -6,7 +6,7 @@ from colorlog import ColoredFormatter
 import pandas
 import numpy as np
 
-from hw4.tabulate import tabulate
+from tabulate import tabulate
 
 
 class LoggerClass(object):
diff --git a/hw4/main.py b/hw4/main.py
index f6db504..de4cad0 100644
--- a/hw4/main.py
+++ b/hw4/main.py
@@ -2,9 +2,9 @@ import os
 import argparse
 import time
 
-from hw4.half_cheetah_env import HalfCheetahEnv
-from hw4.logger import logger
-from hw4.model_based_rl import ModelBasedRL
+from half_cheetah_env import HalfCheetahEnv
+from logger import logger
+from model_based_rl import ModelBasedRL
 
 parser = argparse.ArgumentParser()
 parser.add_argument('question', type=str, choices=('q1, q2, q3'))
diff --git a/hw4/model_based_policy.py b/hw4/model_based_policy.py
index 70aaafe..4fd9764 100644
--- a/hw4/model_based_policy.py
+++ b/hw4/model_based_policy.py
@@ -1,7 +1,7 @@
 import tensorflow as tf
 import numpy as np
 
-import hw4.utils as utils
+import utils
 
 
 class ModelBasedPolicy(object):
@@ -141,7 +141,17 @@ class ModelBasedPolicy(object):
         """
         ### PROBLEM 2
         ### YOUR CODE HERE
-        raise NotImplementedError
+        actions = tf.random_uniform(
+            shape=[self._num_random_action_selection, self._horizon, self._action_dim],
+            minval=self._action_space_low, maxval=self._action_space_high
+        )
+        costs = tf.zeros(self._num_random_action_selection)
+        states = tf.stack([state_ph[0]] * self._num_random_action_selection)
+        for t in range(self._horizon):
+            next_states = self._dynamics_func(states, actions[:, t, :], True)
+            costs += self._cost_fn(states, actions[:, t, :], next_states)
+            states = next_states
+        best_action = actions[tf.argmin(costs)][0]
 
         return best_action
 
@@ -165,7 +175,7 @@
 
         ### PROBLEM 2
         ### YOUR CODE HERE
-        best_action = None
+        best_action = self._setup_action_selection(state_ph)
 
         sess.run(tf.global_variables_initializer())
 
@@ -222,7 +232,7 @@
 
         ### PROBLEM 2
         ### YOUR CODE HERE
-        raise NotImplementedError
+        best_action = self._sess.run(self._best_action, feed_dict={self._state_ph: [state]})
 
         assert np.shape(best_action) == (self._action_dim,)
         return best_action
diff --git a/hw4/model_based_rl.py b/hw4/model_based_rl.py
index 531fe9e..a7c634a 100644
--- a/hw4/model_based_rl.py
+++ b/hw4/model_based_rl.py
@@ -3,10 +3,10 @@ import os
 import numpy as np
 import matplotlib.pyplot as plt
 
-from hw4.model_based_policy import ModelBasedPolicy
-import hw4.utils as utils
-from hw4.logger import logger
-from hw4.timer import timeit
+from model_based_policy import ModelBasedPolicy
+import utils
+from logger import logger
+from timer import timeit
 
 
 class ModelBasedRL(object):
@@ -164,12 +164,12 @@
         logger.info('Training policy....')
         ### PROBLEM 2
         ### YOUR CODE HERE
-        raise NotImplementedError
+        self._train_policy(self._random_dataset)
 
         logger.info('Evaluating policy...')
         ### PROBLEM 2
         ### YOUR CODE HERE
-        raise NotImplementedError
+        eval_dataset = self._gather_rollouts(self._policy, self._num_onpolicy_rollouts)
 
         logger.info('Trained policy')
         self._log(eval_dataset)
@@ -193,16 +193,16 @@
             ### PROBLEM 3
             ### YOUR CODE HERE
             logger.info('Training policy...')
-            raise NotImplementedError
+            self._train_policy(dataset)
 
             ### PROBLEM 3
             ### YOUR CODE HERE
             logger.info('Gathering rollouts...')
-            raise NotImplementedError
+            new_dataset = self._gather_rollouts(self._policy, self._num_onpolicy_rollouts)
 
             ### PROBLEM 3
             ### YOUR CODE HERE
             logger.info('Appending dataset...')
-            raise NotImplementedError
+            dataset.append(new_dataset)
 
             self._log(new_dataset)
diff --git a/hw4/run_all.sh b/hw4/run_all.sh
old mode 100644
new mode 100755
diff --git a/hw4/utils.py b/hw4/utils.py
index 13cd698..5d90585 100644
--- a/hw4/utils.py
+++ b/hw4/utils.py
@@ -1,7 +1,7 @@
 import numpy as np
 import tensorflow as tf
 
-from hw4.logger import logger
+from logger import logger
 
 
 ############
--
GitLab
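
Note on the _setup_action_selection hunk above: it implements random-shooting
MPC. Sample num_random_action_selection action sequences uniformly at random,
roll each one through the learned dynamics model for horizon steps while
accumulating the cost, then execute only the first action of the cheapest
sequence. Below is a minimal NumPy sketch of the same idea, not code from this
repo: random_shooting_action, dynamics_fn, and cost_fn are hypothetical
stand-ins for the policy method, the learned model, and the task cost.

import numpy as np

def random_shooting_action(state, dynamics_fn, cost_fn,
                           action_low, action_high,
                           horizon=10, num_sequences=4096):
    # Sample num_sequences random action sequences of length horizon.
    action_dim = len(action_low)
    actions = np.random.uniform(action_low, action_high,
                                size=(num_sequences, horizon, action_dim))
    # Roll every sequence forward from the same start state, summing costs.
    states = np.tile(state, (num_sequences, 1))
    costs = np.zeros(num_sequences)
    for t in range(horizon):
        next_states = dynamics_fn(states, actions[:, t, :])
        costs += cost_fn(states, actions[:, t, :], next_states)
        states = next_states
    # MPC: keep only the first action of the lowest-cost sequence.
    return actions[np.argmin(costs), 0, :]

# Toy usage with linear dynamics and a quadratic state cost:
dynamics_fn = lambda s, a: s + 0.1 * a
cost_fn = lambda s, a, ns: np.sum(ns ** 2, axis=1)
first_action = random_shooting_action(np.ones(2), dynamics_fn, cost_fn,
                                      action_low=-np.ones(2),
                                      action_high=np.ones(2))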