Commit 0d5ed41b authored by Kate Rakelly

HW5c fix (nit): make `build_rnn` more clear, remove unneeded arg

parent 5e908c25
@@ -57,11 +57,14 @@ def build_mlp(x, output_size, scope, n_layers, size, activation=tf.tanh, output_
         x = tf.layers.dense(inputs=x, units=output_size, activation=output_activation, name='fc{}'.format(i + 1), kernel_regularizer=regularizer, bias_regularizer=regularizer)
     return x
 
-def build_rnn(x, h, output_size, scope, n_layers, size, gru_size, activation=tf.tanh, output_activation=None, regularizer=None):
+def build_rnn(x, h, output_size, scope, n_layers, size, activation=tf.tanh, output_activation=None, regularizer=None):
     """
         builds a gated recurrent neural network
         inputs are first embedded by an MLP then passed to a GRU cell
+        make MLP layers with `size` number of units
+        make the GRU with `output_size` number of units
 
         arguments:
             (see `build_policy()`)
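
For orientation, here is a minimal sketch of a `build_rnn` body consistent with the new docstring, assuming the TF1 `tf.nn.rnn_cell.GRUCell` API and the `build_mlp` helper above. The function body is not shown in this diff, so this is illustrative, not the repository's implementation:

def build_rnn(x, h, output_size, scope, n_layers, size, activation=tf.tanh, output_activation=None, regularizer=None):
    # embed inputs with an MLP: n_layers layers of `size` units each
    x = build_mlp(x, size, scope, n_layers, size, activation=activation, output_activation=activation, regularizer=regularizer)
    # one GRU step with `output_size` units; for a GRU, the output equals the new hidden state
    cell = tf.nn.rnn_cell.GRUCell(output_size)
    x, h = cell(x, h)
    return x, h

Note how the GRU width comes from `output_size` alone, which is why the extra `gru_size` argument removed in this commit was redundant.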
@@ -96,7 +99,7 @@ def build_policy(x, h, output_size, scope, n_layers, size, gru_size, recurrent=T
     """
     with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
         if recurrent:
-            x, h = build_rnn(x, h, gru_size, scope, n_layers, size, gru_size, activation=activation, output_activation=output_activation)
+            x, h = build_rnn(x, h, gru_size, scope, n_layers, size, activation=activation, output_activation=output_activation)
         else:
             x = tf.reshape(x, (-1, x.get_shape()[1]*x.get_shape()[2]))
             x = build_mlp(x, gru_size, scope, n_layers + 1, size, activation=activation, output_activation=activation)
@@ -115,7 +118,7 @@ def build_critic(x, h, output_size, scope, n_layers, size, gru_size, recurrent=T
     """
     with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
         if recurrent:
-            x, h = build_rnn(x, h, gru_size, scope, n_layers, size, gru_size, activation=activation, output_activation=output_activation, regularizer=regularizer)
+            x, h = build_rnn(x, h, gru_size, scope, n_layers, size, activation=activation, output_activation=output_activation, regularizer=regularizer)
         else:
             x = tf.reshape(x, (-1, x.get_shape()[1]*x.get_shape()[2]))
             x = build_mlp(x, gru_size, scope, n_layers + 1, size, activation=activation, output_activation=activation, regularizer=regularizer)
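
A hypothetical call under the updated signature (the shape names `history_len`, `obs_dim`, and `num_actions` are illustrative, and the return values are assumed from the surrounding code, which is not shown in this diff):

x = tf.placeholder(tf.float32, shape=(None, history_len, obs_dim))  # observation history
h = tf.placeholder(tf.float32, shape=(None, gru_size))              # recurrent state
# gru_size is now passed once, as build_rnn's output_size, instead of twice
logits, h_next = build_policy(x, h, num_actions, 'policy', n_layers, size, gru_size, recurrent=True)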