Skip to content
Snippets Groups Projects
Commit d2dedd10 authored by Kate Rakelly
Browse files

HW5c fix: critic arch should match policy

also remove debugging print statement
parent ee7e7a46
No related branches found
No related tags found
No related merge requests found
...@@ -308,7 +308,7 @@ class Agent(object): ...@@ -308,7 +308,7 @@ class Agent(object):
# PPO critic update # PPO critic update
critic_regularizer = tf.contrib.layers.l2_regularizer(1e-3) if self.l2reg else None critic_regularizer = tf.contrib.layers.l2_regularizer(1e-3) if self.l2reg else None
self.critic_prediction = tf.squeeze(build_critic(self.sy_ob_no, self.sy_hidden, 1, 'critic_network', n_layers=self.n_layers, size=self.size, gru_size=self.gru_size, regularizer=critic_regularizer)) self.critic_prediction = tf.squeeze(build_critic(self.sy_ob_no, self.sy_hidden, 1, 'critic_network', n_layers=self.n_layers, size=self.size, gru_size=self.gru_size, recurrent=self.recurrent, regularizer=critic_regularizer))
self.sy_target_n = tf.placeholder(shape=[None], name="critic_target", dtype=tf.float32) self.sy_target_n = tf.placeholder(shape=[None], name="critic_target", dtype=tf.float32)
self.critic_loss = tf.losses.mean_squared_error(self.sy_target_n, self.critic_prediction) self.critic_loss = tf.losses.mean_squared_error(self.sy_target_n, self.critic_prediction)
self.critic_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='critic_network') self.critic_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='critic_network')
...@@ -715,7 +715,6 @@ def train_PG( ...@@ -715,7 +715,6 @@ def train_PG(
log_probs = agent.sess.run(agent.sy_lp_n, log_probs = agent.sess.run(agent.sy_lp_n,
feed_dict={agent.sy_ob_no: ob_no, agent.sy_hidden: hidden, agent.sy_ac_na: ac_na}) feed_dict={agent.sy_ob_no: ob_no, agent.sy_hidden: hidden, agent.sy_ac_na: ac_na})
print('new log prob', log_probs.shape)
agent.update_parameters(ob_no, hidden, ac_na, fixed_log_probs, q_n, adv_n) agent.update_parameters(ob_no, hidden, ac_na, fixed_log_probs, q_n, adv_n)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment