Skip to content
Snippets Groups Projects
Commit 0cb6b38c authored by Florian Gawrilowicz's avatar Florian Gawrilowicz
Browse files

add roboschool

parent c73e1782
Branches
No related tags found
No related merge requests found
......@@ -13,30 +13,36 @@ import os
import pickle
import tensorflow as tf
import numpy as np
import tf_util
from hw1 import tf_util
import gym
import load_policy
from hw1 import load_policy
from hw1 import roboschool_agents
def main():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('expert_policy_file', type=str)
parser.add_argument('envname', type=str)
parser.add_argument('--render', action='store_true')
parser.add_argument('-r', '--render', action='store_true')
parser.add_argument('--roboschool', action='store_true')
parser.add_argument("--max_timesteps", type=int)
parser.add_argument('--num_rollouts', type=int, default=20,
help='Number of expert roll outs')
args = parser.parse_args()
print('loading and building expert policy')
policy_fn = load_policy.load_policy(args.expert_policy_file)
print('loaded and built')
env = gym.make(args.envname)
with tf.Session():
tf_util.initialize()
import gym
env = gym.make(args.envname)
print('loading and building expert policy')
if args.roboschool:
pi = roboschool_agents.load_policy(args.envname, env)
else:
policy_fn = load_policy.load_policy(args.expert_policy_file)
print('loaded and built')
max_steps = args.max_timesteps or env.spec.timestep_limit
returns = []
......@@ -49,6 +55,9 @@ def main():
totalr = 0.
steps = 0
while not done:
if args.roboschool:
action = pi.act(obs, env)
else:
action = policy_fn(obs[None, :])
observations.append(obs)
actions.append(action)
......@@ -57,9 +66,12 @@ def main():
steps += 1
if args.render:
env.render()
if steps % 100 == 0: print("%i/%i"%(steps, max_steps))
if steps % 100 == 0:
print("%i/%i" % (steps, max_steps))
if steps >= max_steps:
break
if args.render:
break
returns.append(totalr)
print('returns', returns)
......@@ -72,5 +84,6 @@ def main():
with open(os.path.join('expert_data', args.envname + '.pkl'), 'wb') as f:
pickle.dump(expert_data, f, pickle.HIGHEST_PROTOCOL)
if __name__ == '__main__':
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment