flgw / AdvRL19 / Commits

Commit c4496029
Authored Mar 27, 2019 by Florian Gawrilowicz
vanilla DQN - solves Pong, PongRam, LunarLander
Parent: 044ab278
1 changed file: hw3/dqn.py (+349 additions, −299 deletions)
@@ -9,10 +9,11 @@ import random
import tensorflow as tf
import tensorflow.contrib.layers as layers

from collections import namedtuple

from dqn_utils import *
from hw3.dqn_utils import *

OptimizerSpec = namedtuple("OptimizerSpec", ["constructor", "kwargs", "lr_schedule"])


class QLearner(object):

    def __init__(
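For context, a minimal sketch of how an OptimizerSpec like the one defined here is typically filled in by the calling script. The optimizer choice, hyperparameters, and the ConstantSchedule stand-in below are assumptions for illustration, not taken from this repository:

# Hedged sketch, not part of the commit: constructing an OptimizerSpec.
# The constructor is instantiated lazily inside QLearner, and lr_schedule
# is queried per step (see optimizer_spec.lr_schedule.value(self.t) below).
from collections import namedtuple
import tensorflow as tf

OptimizerSpec = namedtuple("OptimizerSpec", ["constructor", "kwargs", "lr_schedule"])

class ConstantSchedule(object):
    # Hypothetical stand-in for a schedule object from dqn_utils.
    def __init__(self, value):
        self._value = value
    def value(self, t):
        return self._value

optimizer_spec = OptimizerSpec(
    constructor=tf.train.AdamOptimizer,  # assumed optimizer, not from this diff
    kwargs=dict(epsilon=1e-4),           # extra kwargs forwarded to the constructor
    lr_schedule=ConstantSchedule(1e-4),  # learning rate queried as lr_schedule.value(t)
)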
@@ -160,6 +161,15 @@ class QLearner(object):
        # YOUR CODE HERE
        self.q = q_func(obs_t_float, self.num_actions, scope="q_func", reuse=False)
        q_func_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='q_func')

        target_q = q_func(obs_tp1_float, self.num_actions, scope="target_q_func", reuse=False)
        y = self.rew_t_ph + gamma * tf.reduce_max(target_q, axis=-1)
        target_q_func_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='target_q_func')

        self.total_error = tf.reduce_mean(huber_loss(
            tf.squeeze(tf.batch_gather(self.q, tf.expand_dims(self.act_t_ph, axis=1))) - y))

        ######
        # construct optimization op (with gradient clipping)
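The target y above does not use self.done_mask_ph (which is still fed to the graph in update_model below); the standard 1-step DQN target masks out the bootstrap term on terminal transitions. A minimal NumPy sketch of that masked target, assuming done_mask is 1.0 at episode ends as in the dqn_utils replay buffer convention:

# Hedged sketch, not part of the commit: the 1-step DQN target including the
# terminal mask. Array names here are illustrative only.
import numpy as np

def dqn_target(rewards, next_q_values, done_mask, gamma=0.99):
    # rewards: (B,), next_q_values: (B, num_actions), done_mask: (B,), 1.0 at episode ends
    bootstrap = next_q_values.max(axis=1)  # max_a' Q_target(s', a')
    return rewards + gamma * (1.0 - done_mask) * bootstrap

# A terminal transition contributes only its immediate reward:
print(dqn_target(np.array([1.0]), np.array([[5.0, 2.0]]), np.array([1.0])))  # -> [1.]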
@@ -229,6 +239,27 @@ class QLearner(object):
        #####
        # YOUR CODE HERE
        self.replay_buffer_idx = self.replay_buffer.next_idx
        self.replay_buffer.store_frame(self.last_obs)

        if not self.model_initialized:
            act = self.env.action_space.sample()
        else:
            if self.exploration.value(self.t) > np.random.sample():
                act = self.env.action_space.sample()
                # print(act)
            else:
                state = self.replay_buffer.encode_recent_observation()
                values = self.session.run(self.q, {self.obs_t_ph: state[np.newaxis, ...]})
                act = np.argmax(values)
                # print(values)

        obs, reward, done, info = self.env.step(action=act)
        self.replay_buffer.store_effect(idx=self.replay_buffer_idx, action=act, reward=reward, done=done)

        if done:
            # print('DONE')
            obs = self.env.reset()
        self.last_obs = obs

    def update_model(self):
        ### 3. Perform experience replay and train the network.
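The branch above implements epsilon-greedy exploration driven by a decaying schedule. A minimal standalone sketch of the same rule, with names that are illustrative rather than taken from the repository:

# Hedged sketch, not part of the commit: epsilon-greedy action selection.
# `epsilon` plays the role of self.exploration.value(self.t).
import numpy as np

def epsilon_greedy(q_values, epsilon, rng=np.random):
    """With probability epsilon take a uniform random action, else the greedy one."""
    if rng.random_sample() < epsilon:
        return rng.randint(len(q_values))
    return int(np.argmax(q_values))

# With epsilon=0 the greedy action (index 2) is always chosen:
print(epsilon_greedy(np.array([0.1, -0.3, 0.8]), epsilon=0.0))  # -> 2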
@@ -274,6 +305,25 @@ class QLearner(object):
        #####
        # YOUR CODE HERE
        obs_batch, act_batch, rew_batch, next_obs_batch, done_mask = self.replay_buffer.sample(self.batch_size)

        if not self.model_initialized:
            initialize_interdependent_variables(self.session, tf.global_variables(), {
                self.obs_t_ph: obs_batch,
                self.obs_tp1_ph: next_obs_batch,
            })
            self.session.run(self.update_target_fn)
            self.model_initialized = True

        # 3.c
        self.session.run([self.train_fn, self.total_error], {
            self.obs_t_ph: obs_batch,
            self.act_t_ph: act_batch,
            self.rew_t_ph: rew_batch,
            self.obs_tp1_ph: next_obs_batch,
            self.done_mask_ph: done_mask,
            self.learning_rate: self.optimizer_spec.lr_schedule.value(self.t)
        })

        if (self.num_param_updates % self.target_update_freq) == 0:
            self.session.run(self.update_target_fn)
        self.num_param_updates += 1
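The TD error minimized by train_fn above is wrapped in huber_loss from dqn_utils. That helper is not part of this diff, so assuming it is the standard Huber loss, a NumPy sketch of what it computes:

# Hedged sketch, not part of the commit: the standard Huber loss, quadratic
# near zero and linear beyond delta, which caps the gradient magnitude at delta.
import numpy as np

def huber(x, delta=1.0):
    return np.where(np.abs(x) <= delta,
                    0.5 * np.square(x),
                    delta * (np.abs(x) - 0.5 * delta))

print(huber(np.array([0.5, 3.0])))  # -> [0.125, 2.5]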
@@ -305,6 +355,7 @@ class QLearner(object):
        with open(self.rew_file, 'wb') as f:
            pickle.dump(episode_rewards, f, pickle.HIGHEST_PROTOCOL)

def learn(*args, **kwargs):
    alg = QLearner(*args, **kwargs)

    while not alg.stopping_criterion_met():
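Episode rewards are dumped to self.rew_file with pickle, so they can be loaded back later for plotting. A small sketch, with a made-up filename standing in for whatever rew_file was set to:

# Hedged sketch, not part of the commit: loading the pickled episode rewards.
import pickle

with open('episode_rewards.pkl', 'rb') as f:  # hypothetical filename
    episode_rewards = pickle.load(f)
print(len(episode_rewards), episode_rewards[-5:])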
@@ -314,4 +365,3 @@ def learn(*args, **kwargs):
        # observation
        alg.update_model()
        alg.log_progress()