|
#!/usr/bin/env python3 |
|
|
|
""" |
|
Solution to OpenAI Gym CartPole-v0 using hill climbing. |
|
https://gym.openai.com/envs/CartPole-v0 |
|
""" |
|
|
|
from __future__ import print_function |
|
import sys |
|
import time |
|
import math |
|
import random |
|
|
|
import gym |
|
from gym import wrappers |
|
|
|
# Value forwarded as env.render(mode=RENDER_MODE) on every step of the
# sampling and play loops; None skips rendering altogether.
RENDER_MODE = None
|
|
|
class EnvLims(object):
    """Value ranges and gain limits derived from an environment.

    The position range (``x_*``) and angle range (``th_*``, converted from
    radians to degrees) come from entries 0 and 2 of
    ``env.observation_space.low`` / ``.high``.  The cart-velocity (``v_*``)
    and angular-velocity (``w_*``) ranges are measured by :meth:`sample`.
    Each of ``kx_max``, ``kw_max`` and ``kv_max`` is ``1.5 * th_max`` divided
    by the matching maximum.
    """

    def __init__(self, env):
        """Store *env*, read its static limits and sample the dynamic ones.

        Note that this calls :meth:`sample`, which resets and steps *env*
        and prints the measured limits.
        """
        self.env = env
        space = env.observation_space

        # Position limits (observation index 0).
        self.x_min = space.low[0]
        self.x_max = space.high[0]
        self.x_mid = (self.x_min + self.x_max) / 2

        # Angle limits (observation index 2), radians -> degrees.
        self.th_min = space.low[2] * 180 / math.pi
        self.th_max = space.high[2] * 180 / math.pi
        self.th_mid = (self.th_min + self.th_max) / 2

        self.kx_max = 1.5 * self.th_max / self.x_max
        # w_max and v_max only exist once sample() has run.
        self.sample()
        self.kw_max = 1.5 * self.th_max / self.w_max
        self.kv_max = 1.5 * self.th_max / self.v_max

    def sample(self):
        """Measure velocity limits by repeatedly applying action 1.

        After a reset, action 1 is applied until the cart position observed
        on the previous step lies outside ``[x_min, x_max]``.  The smallest
        angular velocity is recorded only up to (and including) the first
        step whose angle is outside ``[th_min, th_max]``; the largest cart
        velocity is recorded on every step.  Both extremes are then doubled
        and mirrored to give symmetric ranges, which are printed.
        """
        infinity = float('inf')
        self.w_min, self.w_max = infinity, -infinity
        self.v_min, self.v_max = infinity, -infinity

        deg_per_rad = 180 / math.pi
        observation = self.env.reset()
        last_x = 0
        angle_in_range = True
        while self.x_min <= last_x <= self.x_max:
            if RENDER_MODE is not None:
                self.env.render(mode=RENDER_MODE)

            x, v, th, w = observation
            th = th * deg_per_rad

            if angle_in_range:
                self.w_min = min(self.w_min, w)
                if th > self.th_max or th < self.th_min:
                    # Stop tracking w from the next step onwards.
                    angle_in_range = False

            self.v_max = max(self.v_max, v)
            last_x = x
            observation, _reward, _done, _info = self.env.step(1)

        # Double the observed extremes, then mirror them around zero.
        self.w_min *= 2
        self.v_max *= 2
        self.w_max = -self.w_min
        self.v_min = -self.v_max

        print('Limits:')
        print('w_min={}, w_max={}'.format(self.w_min, self.w_max))
        print('v_min={}, v_max={}'.format(self.v_min, self.v_max))
|
|
|
|
|
def get_action(state, t, kw=0, kv=0, kx=0):
    """Choose action 0 or 1 from a linear combination of the state.

    ``state`` unpacks as ``(x, v, th, w)``.  The decision value is
    ``th + kw * w - kv * v - kx * x``: a negative value selects action 0,
    anything else selects action 1.  ``t`` is accepted but not used.
    """
    position, velocity, angle, angular_velocity = state
    decision = angle + kw * angular_velocity - kv * velocity - kx * position
    if decision < 0:
        return 0
    return 1
|
|
|
def ftup_to_str(X):
    """Render the numbers in X as ``(a, b, ...)``, each in an ``8.5g`` field."""
    fields = (format(value, '8.5g') for value in X)
    return '(' + ', '.join(fields) + ')'
|
|
|
def play(env, iteration, reps, params):
    """Run ``reps`` sub-episodes with fixed gains and return their mean length.

    Args:
        env: Environment providing ``reset()``, ``step(action)`` (unpacked
            here as a 4-tuple ``state, reward, done, info``) and
            ``render(mode=...)``.
        iteration: Label used only in the printed progress messages.
        reps: Number of sub-episodes to run and average over; must be >= 1.
        params: Gains forwarded positionally to ``get_action`` after
            ``state`` and ``t``.

    Returns:
        The mean number of steps per sub-episode (``total / reps``).

    Raises:
        ValueError: If ``reps`` is smaller than 1 (previously this surfaced
            as a ``ZeroDivisionError`` at the very end).
        KeyboardInterrupt: Re-raised after printing a partial summary.
    """
    if reps < 1:
        raise ValueError("reps must be at least 1, got {}".format(reps))

    max_steps = 100000  # hard cap on the length of one sub-episode
    total_steps = 0     # steps taken over all sub-episodes, partial ones included
    started = 0         # sub-episodes begun so far

    print("Episode {} with params {}".format(iteration, ftup_to_str(params)))
    try:
        for _ in range(reps):
            started += 1
            state = env.reset()
            done = False
            t = 0
            while not done and t < max_steps:
                if RENDER_MODE is not None:
                    env.render(mode=RENDER_MODE)
                action = get_action(state, t, *params)
                state, _reward, done, _info = env.step(action)
                t += 1
                # Counting per step keeps the total correct at any interrupt
                # point: nothing is added twice and no name is unbound.
                total_steps += 1
    except KeyboardInterrupt:
        # Average over the sub-episodes started so far, including the
        # interrupted one; max() guards an interrupt before the first start.
        print("Incomplete Episode {}: {}".format(
            iteration, total_steps / max(started, 1)))
        raise

    mean_steps = total_steps / reps
    print("Episode {}: {}".format(iteration, mean_steps))
    print()
    return mean_steps
|
|
|
def random_select(X):
    """Scale each element of X by its own uniform random factor in [-1, 1).

    One ``random.random()`` draw is made per element, in order, and the
    scaled values are returned as a tuple.  (A sign-only variant would
    instead multiply by ``random.choice((-1, 1))``.)
    """
    scaled = []
    for bound in X:
        factor = 2 * random.random() - 1
        scaled.append(factor * bound)
    return tuple(scaled)
|
|
|
def scalar_prod(k, X):
    """Return a tuple holding every element of X multiplied by the scalar k."""
    products = [k * element for element in X]
    return tuple(products)
|
|
|
def scalar_add(X1, X2):
    """Return the element-wise sums of X1 and X2 as a tuple.

    Pairs are formed with ``zip``, so the result is as long as the shorter
    of the two inputs.
    """
    sums = []
    for left, right in zip(X1, X2):
        sums.append(left + right)
    return tuple(sums)
|
|
|
def hill_climb(env, X0, dX0, reps=2, depr=0.9, max_rejects=100, flip_prob=0.04):
    """Randomised hill climbing over the controller gains.

    Starting from ``X0``, a candidate is drawn by adding a random
    perturbation (bounded element-wise by the current step sizes) to the
    current point and scored with ``play``.  A candidate is normally
    accepted when it scores higher than the current point; with probability
    ``flip_prob`` that decision is inverted.  Every rejection shrinks the
    step sizes by the factor ``depr``.

    Args:
        env: Environment handed through to ``play``.
        X0: Initial parameter tuple.
        dX0: Initial per-parameter step sizes, same length as ``X0``.
        reps: Sub-episodes averaged per evaluation.
        depr: Multiplicative step-size decay applied after each rejection.
        max_rejects: The search continues while the number of rejections
            is ``<= max_rejects``.
        flip_prob: Probability of inverting the accept/reject decision.

    Returns:
        ``(best_X, best_y)``: the highest-scoring parameters evaluated and
        their score.
    """
    X, dX = X0, dX0
    y = play(env, 0, reps, X)
    best_X, best_y = X, y

    iteration = 1
    rejects = 0
    while rejects <= max_rejects:
        X2 = scalar_add(X, random_select(dX))
        y2 = play(env, iteration, reps, X2)
        iteration += 1

        if y2 > best_y:
            best_X, best_y = X2, y2

        # The test used to be written `y2 > y != (random.random() < 0.04)`.
        # Python chains that as `y2 > y and y != <bool>`, so the random flip
        # almost never had any effect.  Compare the two booleans explicitly.
        improved = y2 > y
        flip = random.random() < flip_prob
        if improved != flip:
            X, y = X2, y2
            print("improvement" if improved else "accepted without improvement")
        else:
            rejects += 1
            dX = scalar_prod(depr, dX)

    return best_X, best_y
|
|
|
def main():
    """Tune the CartPole-v0 controller by hill climbing, then offer an upload.

    The limits are sampled on the bare environment first; the environment is
    then wrapped in ``wrappers.Monitor`` writing to ``path`` for the search.
    """
    path = '/tmp/cartpole-expt-7'
    env = gym.make('CartPole-v0')
    try:
        lims = EnvLims(env)
        env = wrappers.Monitor(env, path)

        # Start from all-zero gains with steps of a quarter of each gain limit.
        initial_steps = (lims.kw_max / 4, lims.kv_max / 4, lims.kx_max / 4)
        hill_climb(env, (0, 0, 0), initial_steps)
    finally:
        # Close even when the search is interrupted (play re-raises
        # KeyboardInterrupt), so the environment is never left open.
        env.close()

    if input("Do you wish to upload results? (y/n) ") == 'y':
        gym.upload(path, api_key='<your_api_key_here>')
|
|
|
# Run the experiment only when this file is executed as a script.
if __name__ == '__main__':
    main()