Commit c4496029 authored by Florian Gawrilowicz

vanilla DQN - solves Pong, PongRam, LunarLander

parent 044ab278
@@ -9,10 +9,11 @@ import random
import tensorflow as tf
import tensorflow.contrib.layers as layers
from collections import namedtuple
from dqn_utils import *
from hw3.dqn_utils import *

OptimizerSpec = namedtuple("OptimizerSpec", ["constructor", "kwargs", "lr_schedule"])

class QLearner(object):

    def __init__(
@@ -160,6 +161,15 @@ class QLearner(object):
        # YOUR CODE HERE
        # Online Q-network Q(s, a; theta) evaluated on the current observations.
        self.q = q_func(obs_t_float, self.num_actions, scope="q_func", reuse=False)
        q_func_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='q_func')

        # Target network Q(s', a'; theta') and the Bellman backup
        # y = r + gamma * (1 - done) * max_a' Q_target(s', a');
        # (1 - done) zeroes the bootstrap term on terminal transitions.
        target_q = q_func(obs_tp1_float, self.num_actions, scope="target_q_func", reuse=False)
        y = self.rew_t_ph + gamma * (1.0 - self.done_mask_ph) * tf.reduce_max(target_q, axis=-1)
        target_q_func_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='target_q_func')

        # Huber loss on the TD error of the action actually taken, averaged over the batch.
        self.total_error = tf.reduce_mean(
            huber_loss(tf.squeeze(tf.batch_gather(self.q, tf.expand_dims(self.act_t_ph, axis=1))) - y))
        ######

        # construct optimization op (with gradient clipping)
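The huber_loss used in the TD error above is pulled in from dqn_utils via the star import. For reference, a minimal sketch of a standard Huber loss in TF 1.x with the usual delta = 1.0 cutoff (the dqn_utils version may differ in details):

import tensorflow as tf

def huber_loss(x, delta=1.0):
    # Quadratic penalty for small errors, linear for large ones, so a few
    # large TD errors cannot dominate the gradient.
    return tf.where(
        tf.abs(x) < delta,
        tf.square(x) * 0.5,
        delta * (tf.abs(x) - 0.5 * delta)
    )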
@@ -229,6 +239,27 @@ class QLearner(object):
        #####
        # YOUR CODE HERE
        # Reserve a slot in the replay buffer for the current frame; the index is
        # needed below to attach this step's action/reward/done to the same slot.
        self.replay_buffer_idx = self.replay_buffer.next_idx
        self.replay_buffer.store_frame(self.last_obs)

        if not self.model_initialized:
            # The Q-network cannot be evaluated before its variables are
            # initialized, so act randomly until then.
            act = self.env.action_space.sample()
        else:
            # Epsilon-greedy exploration: with probability epsilon(t) take a
            # random action, otherwise act greedily w.r.t. the current Q-values.
            if self.exploration.value(self.t) > np.random.sample():
                act = self.env.action_space.sample()
            else:
                state = self.replay_buffer.encode_recent_observation()
                values = self.session.run(self.q, {
                    self.obs_t_ph: state[np.newaxis, ...]})
                act = np.argmax(values)

        obs, reward, done, info = self.env.step(action=act)
        self.replay_buffer.store_effect(
            idx=self.replay_buffer_idx, action=act, reward=reward, done=done)
        if done:
            obs = self.env.reset()
        self.last_obs = obs
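The epsilon in the epsilon-greedy branch above comes from self.exploration.value(self.t), a schedule object handed to the QLearner constructor. The homework's dqn_utils provides schedule classes with this .value(t) interface; a minimal sketch of a linearly decaying schedule, with illustrative names and defaults, looks like this:

class LinearSchedule(object):
    """Interpolate epsilon from initial_p to final_p over schedule_timesteps
    steps, then hold it at final_p for the rest of training."""
    def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
        self.schedule_timesteps = schedule_timesteps
        self.final_p = final_p
        self.initial_p = initial_p

    def value(self, t):
        fraction = min(float(t) / self.schedule_timesteps, 1.0)
        return self.initial_p + fraction * (self.final_p - self.initial_p)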
    def update_model(self):
        ### 3. Perform experience replay and train the network.
@@ -274,6 +305,25 @@ class QLearner(object):
        #####
        # YOUR CODE HERE
        # 3.a: sample a minibatch of transitions from the replay buffer.
        obs_batch, act_batch, rew_batch, next_obs_batch, done_mask = \
            self.replay_buffer.sample(self.batch_size)

        # 3.b: lazily initialize the model (and sync the target network)
        # the first time a batch is available.
        if not self.model_initialized:
            initialize_interdependent_variables(self.session, tf.global_variables(), {
                self.obs_t_ph: obs_batch,
                self.obs_tp1_ph: next_obs_batch,
            })
            self.session.run(self.update_target_fn)
            self.model_initialized = True

        # 3.c: one gradient step on the Huber TD error, with the learning
        # rate taken from the optimizer's schedule.
        self.session.run([self.train_fn, self.total_error], {
            self.obs_t_ph: obs_batch,
            self.act_t_ph: act_batch,
            self.rew_t_ph: rew_batch,
            self.obs_tp1_ph: next_obs_batch,
            self.done_mask_ph: done_mask,
            self.learning_rate: self.optimizer_spec.lr_schedule.value(self.t)
        })

        # 3.d: periodically copy the online network into the target network.
        if (self.num_param_updates % self.target_update_freq) == 0:
            self.session.run(self.update_target_fn)
        self.num_param_updates += 1
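self.update_target_fn used above is built in the part of __init__ not touched by this commit. A sketch of the usual construction, pairing each online-network variable from q_func_vars with its counterpart in target_q_func_vars:

# Group one assign op per variable so a single session.run copies the
# whole online network into the target network.
update_target_fn = []
for var, var_target in zip(sorted(q_func_vars, key=lambda v: v.name),
                           sorted(target_q_func_vars, key=lambda v: v.name)):
    update_target_fn.append(var_target.assign(var))
update_target_fn = tf.group(*update_target_fn)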
@@ -305,6 +355,7 @@ class QLearner(object):
            with open(self.rew_file, 'wb') as f:
                pickle.dump(episode_rewards, f, pickle.HIGHEST_PROTOCOL)
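The pickled reward list can be read back later to plot a learning curve; a minimal sketch, where the filename stands in for whatever rew_file the QLearner was constructed with:

import pickle
import matplotlib.pyplot as plt

with open('pong_rewards.pkl', 'rb') as f:  # hypothetical rew_file path
    episode_rewards = pickle.load(f)

plt.plot(episode_rewards)
plt.xlabel('episode')
plt.ylabel('episode reward')
plt.show()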
def learn(*args, **kwargs):
    alg = QLearner(*args, **kwargs)
    while not alg.stopping_criterion_met():
@@ -314,4 +365,3 @@ def learn(*args, **kwargs):
        # observation
        alg.update_model()
        alg.log_progress()
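For completeness, learn is meant to be driven by a small run script that supplies the environment, a Q-network builder, and the schedules. The sketch below shows how the pieces could fit together for a low-dimensional task; the model builder, hyperparameter values, and any keyword names beyond those visible in this file (q_func, optimizer_spec, session, exploration, batch_size, gamma, target_update_freq) are illustrative, and the remaining QLearner arguments are assumed to have defaults.

import gym
import tensorflow as tf
import tensorflow.contrib.layers as layers

from hw3.dqn import learn, OptimizerSpec
from hw3.dqn_utils import LinearSchedule, ConstantSchedule  # assumed to exist in dqn_utils

def mlp_q_func(obs, num_actions, scope, reuse=False):
    # Hypothetical small MLP Q-network for vector observations (e.g. LunarLander).
    with tf.variable_scope(scope, reuse=reuse):
        out = layers.fully_connected(obs, 64, activation_fn=tf.nn.relu)
        out = layers.fully_connected(out, 64, activation_fn=tf.nn.relu)
        return layers.fully_connected(out, num_actions, activation_fn=None)

if __name__ == '__main__':
    optimizer_spec = OptimizerSpec(
        constructor=tf.train.AdamOptimizer,
        kwargs=dict(epsilon=1e-4),
        lr_schedule=ConstantSchedule(1e-3),
    )
    learn(
        env=gym.make('LunarLander-v2'),
        q_func=mlp_q_func,
        optimizer_spec=optimizer_spec,
        session=tf.Session(),
        exploration=LinearSchedule(100000, 0.02),  # epsilon decays 1.0 -> 0.02
        batch_size=32,
        gamma=0.99,
        target_update_freq=3000,
    )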