diff --git a/hw4/logger.py b/hw4/logger.py index 8df593e0e0b0da994d84552079bdef4e7ed185d3..b6f4f4134fd3ee2a83002c1dac79dfa5b7688e15 100644 --- a/hw4/logger.py +++ b/hw4/logger.py @@ -6,7 +6,7 @@ from colorlog import ColoredFormatter import pandas import numpy as np -from hw4.tabulate import tabulate +from tabulate import tabulate class LoggerClass(object): diff --git a/hw4/main.py b/hw4/main.py index f6db504f99e9c5c60c1a2b5ced9ed5f8902c5069..de4cad01d86befa54c7c74bb85284362f54f39a9 100644 --- a/hw4/main.py +++ b/hw4/main.py @@ -2,9 +2,9 @@ import os import argparse import time -from hw4.half_cheetah_env import HalfCheetahEnv -from hw4.logger import logger -from hw4.model_based_rl import ModelBasedRL +from half_cheetah_env import HalfCheetahEnv +from logger import logger +from model_based_rl import ModelBasedRL parser = argparse.ArgumentParser() parser.add_argument('question', type=str, choices=('q1, q2, q3')) diff --git a/hw4/model_based_policy.py b/hw4/model_based_policy.py index 70aaafeda3f622ad0c69815fb4eab049d5f31ff0..4fd9764d67c76b2abbdf5ac57aab4b499d502fcd 100644 --- a/hw4/model_based_policy.py +++ b/hw4/model_based_policy.py @@ -1,7 +1,7 @@ import tensorflow as tf import numpy as np -import hw4.utils as utils +import utils class ModelBasedPolicy(object): @@ -141,7 +141,17 @@ class ModelBasedPolicy(object): """ ### PROBLEM 2 ### YOUR CODE HERE - raise NotImplementedError + actions = tf.random_uniform( + shape=[self._num_random_action_selection, self._horizon, self._action_dim], + minval=self._action_space_low, maxval=self._action_space_high + ) + costs = tf.zeros(self._num_random_action_selection) + states = tf.stack([state_ph[0]] * self._num_random_action_selection) + for t in range(self._horizon): + next_states = self._dynamics_func(states, actions[:, t, :], True) + costs += self._cost_fn(states, actions[:, t, :], next_states) + states = next_states + best_action = actions[tf.argmin(costs)][0] return best_action @@ -165,7 +175,7 @@ class ModelBasedPolicy(object): ### PROBLEM 2 ### YOUR CODE HERE - best_action = None + best_action = self._setup_action_selection(state_ph) sess.run(tf.global_variables_initializer()) @@ -222,7 +232,7 @@ class ModelBasedPolicy(object): ### PROBLEM 2 ### YOUR CODE HERE - raise NotImplementedError + best_action = self._sess.run(self._best_action, feed_dict={self._state_ph: [state]}) assert np.shape(best_action) == (self._action_dim,) return best_action diff --git a/hw4/model_based_rl.py b/hw4/model_based_rl.py index 531fe9ed0dead256728e5dcae39887eae7144ca5..a7c634ad05b9833889b9c64fb379ac793ab4a139 100644 --- a/hw4/model_based_rl.py +++ b/hw4/model_based_rl.py @@ -3,10 +3,10 @@ import os import numpy as np import matplotlib.pyplot as plt -from hw4.model_based_policy import ModelBasedPolicy -import hw4.utils as utils -from hw4.logger import logger -from hw4.timer import timeit +from model_based_policy import ModelBasedPolicy +import utils +from logger import logger +from timer import timeit class ModelBasedRL(object): @@ -164,12 +164,12 @@ class ModelBasedRL(object): logger.info('Training policy....') ### PROBLEM 2 ### YOUR CODE HERE - raise NotImplementedError + self._train_policy(self._random_dataset) logger.info('Evaluating policy...') ### PROBLEM 2 ### YOUR CODE HERE - raise NotImplementedError + eval_dataset = self._gather_rollouts(self._policy, self._num_onpolicy_rollouts) logger.info('Trained policy') self._log(eval_dataset) @@ -193,16 +193,16 @@ class ModelBasedRL(object): ### PROBLEM 3 ### YOUR CODE HERE logger.info('Training policy...') - raise NotImplementedError + self._train_policy(dataset) ### PROBLEM 3 ### YOUR CODE HERE logger.info('Gathering rollouts...') - raise NotImplementedError + new_dataset = self._gather_rollouts(self._policy, self._num_onpolicy_rollouts) ### PROBLEM 3 ### YOUR CODE HERE logger.info('Appending dataset...') - raise NotImplementedError + dataset.append(new_dataset) self._log(new_dataset) diff --git a/hw4/run_all.sh b/hw4/run_all.sh old mode 100644 new mode 100755 diff --git a/hw4/utils.py b/hw4/utils.py index 13cd698cd44d1c97b74f721e3f4afaed9d01a56f..5d90585dc0b5d31f8223c9199b24dd193cc5bc84 100644 --- a/hw4/utils.py +++ b/hw4/utils.py @@ -1,7 +1,7 @@ import numpy as np import tensorflow as tf -from hw4.logger import logger +from logger import logger ############