Commit 50c89843 authored by Florian Gawrilowicz

final push

parent 49151bb2
@@ -6,7 +6,7 @@ from colorlog import ColoredFormatter
 import pandas
 import numpy as np
-from hw4.tabulate import tabulate
+from tabulate import tabulate


 class LoggerClass(object):
@@ -2,9 +2,9 @@ import os
 import argparse
 import time

-from hw4.half_cheetah_env import HalfCheetahEnv
-from hw4.logger import logger
-from hw4.model_based_rl import ModelBasedRL
+from half_cheetah_env import HalfCheetahEnv
+from logger import logger
+from model_based_rl import ModelBasedRL

 parser = argparse.ArgumentParser()
 parser.add_argument('question', type=str, choices=('q1', 'q2', 'q3'))
 import tensorflow as tf
 import numpy as np

-import hw4.utils as utils
+import utils


 class ModelBasedPolicy(object):
@@ -141,7 +141,17 @@ class ModelBasedPolicy(object):
         """
         ### PROBLEM 2
         ### YOUR CODE HERE
-        raise NotImplementedError
+        actions = tf.random_uniform(
+            shape=[self._num_random_action_selection,
+                   self._horizon, self._action_dim],
+            minval=self._action_space_low, maxval=self._action_space_high)
+        costs = tf.zeros(self._num_random_action_selection)
+        states = tf.stack([state_ph[0]] * self._num_random_action_selection)
+        for t in range(self._horizon):
+            next_states = self._dynamics_func(states, actions[:, t, :], True)
+            costs += self._cost_fn(states, actions[:, t, :], next_states)
+            states = next_states
+        best_action = actions[tf.argmin(costs)][0]

         return best_action
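
For context: the added block is the standard random-shooting planner. It samples many candidate action sequences uniformly at random, rolls each one forward through the learned dynamics model while accumulating per-step cost, and keeps the first action of the cheapest sequence. A minimal NumPy sketch of the same idea, where dynamics_fn and cost_fn are hypothetical stand-ins for the learned model and the task cost (not names from this repo):

    import numpy as np

    def random_shooting(state, dynamics_fn, cost_fn,
                        horizon=15, n_candidates=4096,
                        act_low=-1.0, act_high=1.0, act_dim=6):
        # Sample candidate action sequences uniformly over the action space.
        actions = np.random.uniform(act_low, act_high,
                                    size=(n_candidates, horizon, act_dim))
        # Roll every candidate forward in parallel, accumulating cost.
        states = np.repeat(state[None, :], n_candidates, axis=0)
        costs = np.zeros(n_candidates)
        for t in range(horizon):
            next_states = dynamics_fn(states, actions[:, t, :])
            costs += cost_fn(states, actions[:, t, :], next_states)
            states = next_states
        # MPC executes only the first action of the cheapest sequence.
        return actions[np.argmin(costs), 0]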
@@ -165,7 +175,7 @@ class ModelBasedPolicy(object):
             ### PROBLEM 2
             ### YOUR CODE HERE
-            best_action = None
+            best_action = self._setup_action_selection(state_ph)

             sess.run(tf.global_variables_initializer())
@@ -222,7 +232,7 @@ class ModelBasedPolicy(object):
         ### PROBLEM 2
         ### YOUR CODE HERE
-        raise NotImplementedError
+        best_action = self._sess.run(self._best_action, feed_dict={self._state_ph: [state]})

         assert np.shape(best_action) == (self._action_dim,)
         return best_action
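
A note on how get_action is used: since only the first action of the planned sequence is returned, the controller replans from the current state at every environment step, i.e. model-predictive control. A hedged sketch of such an evaluation rollout, assuming a Gym-style env with the old 4-tuple step API:

    def mpc_rollout(env, policy, max_steps=1000):
        # Replan at every step; execute only the first planned action.
        state = env.reset()
        total_reward, steps, done = 0.0, 0, False
        while not done and steps < max_steps:
            action = policy.get_action(state)
            state, reward, done, _ = env.step(action)
            total_reward += reward
            steps += 1
        return total_reward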
@@ -3,10 +3,10 @@ import os
 import numpy as np
 import matplotlib.pyplot as plt

-from hw4.model_based_policy import ModelBasedPolicy
-import hw4.utils as utils
-from hw4.logger import logger
-from hw4.timer import timeit
+from model_based_policy import ModelBasedPolicy
+import utils
+from logger import logger
+from timer import timeit


 class ModelBasedRL(object):
@@ -164,12 +164,12 @@ class ModelBasedRL(object):
         logger.info('Training policy....')
         ### PROBLEM 2
         ### YOUR CODE HERE
-        raise NotImplementedError
+        self._train_policy(self._random_dataset)

         logger.info('Evaluating policy...')
         ### PROBLEM 2
         ### YOUR CODE HERE
-        raise NotImplementedError
+        eval_dataset = self._gather_rollouts(self._policy, self._num_onpolicy_rollouts)

         logger.info('Trained policy')
         self._log(eval_dataset)
@@ -193,16 +193,16 @@ class ModelBasedRL(object):
             ### PROBLEM 3
             ### YOUR CODE HERE
             logger.info('Training policy...')
-            raise NotImplementedError
+            self._train_policy(dataset)

             ### PROBLEM 3
             ### YOUR CODE HERE
             logger.info('Gathering rollouts...')
-            raise NotImplementedError
+            new_dataset = self._gather_rollouts(self._policy, self._num_onpolicy_rollouts)

             ### PROBLEM 3
             ### YOUR CODE HERE
             logger.info('Appending dataset...')
-            raise NotImplementedError
+            dataset.append(new_dataset)

             self._log(new_dataset)
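
Taken together, these three edits complete the usual on-policy data-aggregation loop: refit the dynamics model on all data gathered so far, collect fresh rollouts with the resulting MPC policy, and fold them back into the training set. A minimal sketch of that outer loop, where train_model, gather_rollouts, and dataset are hypothetical stand-ins for _train_policy, _gather_rollouts, and the dataset object used here:

    def onpolicy_loop(train_model, gather_rollouts, dataset,
                      num_iters=10, num_rollouts=10):
        # Each iteration: refit dynamics on everything gathered so far,
        # then act with the MPC policy and aggregate the new rollouts.
        for _ in range(num_iters):
            train_model(dataset)
            new_data = gather_rollouts(num_rollouts)
            dataset.append(new_data)
        return dataset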
hw4/run_all.sh: file mode changed from 100644 to 100755
 import numpy as np
 import tensorflow as tf

-from hw4.logger import logger
+from logger import logger

 ############