Skip to content

Commit e4f920a

Browse files
authored
Corridor and Grid environments (#15)
1 parent 5720c95 commit e4f920a

File tree

10 files changed

+362
-4
lines changed

10 files changed

+362
-4
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@ The repository contains environments used in LCS literature that are compliant w
1212
- Hand Eye
1313
- Checkerboard
1414
- Real-valued toy problems
15+
- 1D Corridor
16+
- 2D Grid
1517

16-
For usage examples look at [examples/](examples) directory.
18+
For some usage examples look at [examples/](examples) directory.
1719

1820
## Development
1921

gym_corridor/__init__.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from gym.envs.registration import register
2+
3+
from .corridor import Corridor
4+
5+
# Step budget shared by every registered corridor variant.
max_episode_steps = 200

# Register one environment id per supported corridor length:
# corridor-20-v0, corridor-40-v0 and corridor-100-v0.
for _size in (20, 40, 100):
    register(
        id=f'corridor-{_size}-v0',
        entry_point='gym_corridor:Corridor',
        max_episode_steps=max_episode_steps,
        kwargs={'size': _size}
    )
del _size  # do not leave the loop variable behind as a module attribute

gym_corridor/corridor.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from random import randint
2+
3+
import gym
4+
from gym.spaces import Discrete
5+
6+
# Action codes compared against in Corridor.step().
MOVE_LEFT, MOVE_RIGHT = 0, 1
8+
9+
10+
class Corridor(gym.Env):
    """One-dimensional corridor with a reward at its right end.

    The agent occupies an integer position. ``reset`` draws it uniformly
    from ``[1, size - 1]``; stepping onto position ``size`` ends the episode
    and pays ``REWARD``. Position ``0`` acts as a wall: an agent that steps
    onto it is put back on position ``1``. Observations are the position
    rendered as a string.
    """

    metadata = {'render.modes': ['human', 'ansi']}

    REWARD = 1000

    def __init__(self, size=20):
        """Create a corridor whose rewarded position is ``size``."""
        self._size = size
        self._position = None  # assigned by reset()

        self.observation_space = Discrete(1)
        self.action_space = Discrete(2)

    def reset(self):
        """Place the agent on a random non-terminal cell and return it."""
        # random.randint is inclusive on both ends: 1 .. size - 1
        self._position = randint(1, self._size - 1)
        return str(self._position)

    def step(self, action):
        """Move one cell left or right.

        Returns a ``(observation, reward, done, info)`` tuple; ``info`` is
        always ``None``.

        Raises:
            ValueError: if ``action`` is neither MOVE_LEFT nor MOVE_RIGHT.
        """
        if action == MOVE_LEFT:
            self._position -= 1
        elif action == MOVE_RIGHT:
            self._position += 1
        else:
            raise ValueError("Illegal action passed")

        # Reaching the right end terminates the episode with the reward
        if self._position == self._size:
            return str(self._position), self.REWARD, True, None

        # Left wall: bounce back onto the first cell
        if self._position == 0:
            self._position = 1

        return str(self._position), 0, False, None

    def render(self, mode='human'):
        """Print the corridor ('human') or return it as a string ('ansi').

        Raises:
            ValueError: for any other ``mode``.
        """
        if mode == 'human':
            print(self._visualize())
        elif mode == 'ansi':
            return self._visualize()
        else:
            raise ValueError('Unknown visualisation mode')

    def _visualize(self):
        """Return the corridor as ``[cell.cell. ... .cell]``.

        There are ``size - 1`` cells; "X" marks the agent and "$" the last
        cell. The "$" is written last, so it hides the "X" whenever the
        agent stands on the last cell.
        """
        cells = [""] * (self._size - 1)

        # After the terminal step the position equals ``size``, which is one
        # past the last drawable cell; clamp it so rendering a finished
        # episode does not raise IndexError.
        agent_idx = min(self._position, self._size - 1) - 1

        cells[agent_idx] = "X"
        cells[self._size - 2] = "$"
        return "[" + ".".join(cells) + "]"

gym_corridor/tests/__init__.py

Whitespace-only changes.

gym_corridor/tests/test_corridor.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import logging
2+
import sys
3+
4+
import gym
5+
6+
# noinspection PyUnresolvedReferences
7+
import gym_corridor
8+
from gym_corridor.corridor import MOVE_LEFT, MOVE_RIGHT
9+
10+
logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
11+
12+
13+
class TestCorridor:
    """Behavioural tests for the 'corridor-20-v0' environment."""

    def test_should_initialize(self):
        # when
        corr = gym.make('corridor-20-v0')

        # then
        assert corr is not None
        assert 1 == corr.observation_space.n
        assert 2 == corr.action_space.n

    def test_should_visualize(self):
        # given
        corr = gym.make('corridor-20-v0')

        # when
        obs = corr.reset()

        # Position 19 shares its cell with the "$" marker, which is drawn
        # on top of the "X". Resample so the assertions below do not fail
        # randomly depending on the starting position.
        while obs == "19":
            obs = corr.reset()

        vis = corr.render(mode='ansi')

        # then
        assert 1 <= int(obs) < 20
        assert len(vis) == 22
        assert 1 == vis.count('X')
        assert 1 == vis.count('$')
        assert 18 == vis.count('.')

    def test_should_hit_left_wall(self):
        # given
        corr = gym.make('corridor-20-v0')
        reward = 0
        done = False

        # when
        obs = corr.reset()

        # The episode only ends through the registered step limit here
        while not done:
            obs, reward, done, _ = corr.step(MOVE_LEFT)

        # then
        assert obs == '1'
        assert reward == 0
        assert done is True

    def test_should_get_reward(self):
        # given
        corr = gym.make('corridor-20-v0')
        reward = 0
        done = False

        # when
        obs = corr.reset()

        while not done:
            obs, reward, done, _ = corr.step(MOVE_RIGHT)

        # then
        assert obs == '20'
        assert reward == 1000
        assert done is True

    def test_should_move_in_both_directions(self):
        # given
        corr = gym.make('corridor-20-v0')
        p0 = corr.reset()

        # Avoid starting next to the wall or next to the reward
        while p0 in ["1", "19"]:
            p0 = corr.reset()

        # when & then
        p1, _, _, _ = corr.step(MOVE_LEFT)
        assert int(p1) == int(p0) - 1

        p2, _, _, _ = corr.step(MOVE_RIGHT)
        assert int(p2) == int(p0)

gym_grid/__init__.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from gym.envs.registration import register
2+
3+
from .grid import Grid
4+
5+
# Step budget shared by every registered grid variant.
max_episode_steps = 200

# Register one environment id per supported grid edge length:
# grid-20-v0, grid-40-v0 and grid-100-v0.
for _size in (20, 40, 100):
    register(
        id=f'grid-{_size}-v0',
        entry_point='gym_grid:Grid',
        max_episode_steps=max_episode_steps,
        kwargs={'size': _size}
    )
del _size  # do not leave the loop variable behind as a module attribute

gym_grid/grid.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import gym
2+
import numpy as np
3+
from gym.spaces import Discrete
4+
5+
# Action codes compared against in Grid.step().
# They must be exactly 0..3: the environment declares Discrete(4) as its
# action space, and a Discrete(n) space contains the integers 0..n-1.
MOVE_LEFT = 0
MOVE_RIGHT = 1
MOVE_UP = 2
MOVE_DOWN = 3

# Food located in [n, n]
# Observation x,y in [1, n]
12+
13+
14+
class Grid(gym.Env):
    """Square grid world with a single reward in the top-right corner.

    Coordinates ``x`` and ``y`` both live in ``[1, size]``. The reward sits
    at ``(size, size)``; stepping onto it ends the episode and pays
    ``REWARD``. Moves that would leave the grid keep the agent on the
    border cell. Observations are ``(str(x), str(y))`` tuples.
    """

    metadata = {'render.modes': ['human', 'ansi']}

    REWARD = 1000

    def __init__(self, size=20):
        """Create a ``size`` x ``size`` grid.

        Raises:
            ValueError: if ``size`` is below 2, because then the only cell
                is the reward cell and no start position exists.
        """
        if size < 2:
            raise ValueError("Grid size must be at least 2")

        self._size = size
        self._pos_x = None  # assigned by reset()
        self._pos_y = None

        self.observation_space = Discrete(2)
        self.action_space = Discrete(4)

    @property
    def _state(self):
        """Current observation: both coordinates as strings."""
        return str(self._pos_x), str(self._pos_y)

    def reset(self):
        """Place the agent on a random cell other than the reward cell."""
        goal = (self._size, self._size)

        # Redraw until the start differs from the reward cell. Every draw is
        # the same np.random.randint call, so seeded runs stay reproducible.
        while True:
            self._pos_x, self._pos_y = (
                int(v) for v in np.random.randint(1, self._size + 1, size=2))
            if (self._pos_x, self._pos_y) != goal:
                return self._state

    def step(self, action):
        """Move one cell in the requested direction.

        Returns a ``(observation, reward, done, info)`` tuple; ``info`` is
        always ``None``.

        Raises:
            ValueError: if ``action`` is not one of the MOVE_* constants.
        """
        if action == MOVE_LEFT:
            self._pos_x -= 1
        elif action == MOVE_RIGHT:
            self._pos_x += 1
        elif action == MOVE_UP:
            self._pos_y += 1
        elif action == MOVE_DOWN:
            self._pos_y -= 1
        else:
            raise ValueError("Illegal action passed")

        # Handle reaching final state
        if self._pos_x == self._size and self._pos_y == self._size:
            return self._state, self.REWARD, True, None

        # Handle leaving grid: clip against the configured size rather than
        # a fixed 20, so the larger registered grids keep their walls too.
        self._pos_x = min(max(self._pos_x, 1), self._size)
        self._pos_y = min(max(self._pos_y, 1), self._size)

        # Return default observation
        return self._state, 0, False, None

    def render(self, mode='human'):
        """Print the grid ('human') or return it as a string ('ansi').

        Raises:
            ValueError: for any other ``mode``.
        """
        if mode == 'human':
            print(self._visualize())
        elif mode == 'ansi':
            return self._visualize()
        else:
            raise ValueError('Unknown visualisation mode')

    def _visualize(self):
        """Return the grid as text instead of printing it piecemeal.

        The output is a blank line, the current state tuple, then one text
        row per ``y`` from ``size`` down to ``0``. Column 0 holds the ``y``
        labels and row 0 the ``x`` labels. "X" marks the agent, "$" the
        reward and "_" every other cell.
        """
        lines = ["", str(self._state)]

        for y in reversed(range(self._size + 1)):
            row = []
            for x in range(self._size + 1):
                if x == 0 and y == 0:
                    cell = f"{'':^3}"
                elif x == 0:
                    cell = f"{y:>3}"
                elif y == 0:
                    cell = f"{x:^3}"
                elif x == self._pos_x and y == self._pos_y:
                    cell = f"{'X':^3}"
                elif x == self._size and y == self._size:
                    cell = f"{'$':^3}"
                else:
                    cell = f"{'_':^3}"
                row.append(cell)
            lines.append("".join(row))

        return "\n".join(lines)

gym_grid/tests/__init__.py

Whitespace-only changes.

gym_grid/tests/test_grid.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import logging
2+
import sys
3+
4+
import gym
5+
import numpy as np
6+
7+
# noinspection PyUnresolvedReferences
8+
import gym_grid
9+
from gym_grid.grid import MOVE_LEFT, MOVE_RIGHT, MOVE_UP, MOVE_DOWN
10+
11+
logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
12+
13+
14+
class TestGrid:
    """Behavioural tests for the 'grid-20-v0' environment."""

    def test_should_initialize(self):
        # when
        grid = gym.make('grid-20-v0')

        # then
        assert grid is not None
        assert 2 == grid.observation_space.n
        assert 4 == grid.action_space.n

    def test_should_handle_hitting_boundaries(self):
        # given
        grid = gym.make('grid-20-v0')

        # Each seed below is chosen so that reset() lands on the border
        # noted in the trailing comment.

        # handle hitting upper bound
        np.random.seed(42)
        grid.reset()  # (x=7, y=20)
        state, _, _, _ = grid.step(MOVE_UP)
        assert state == ("7", "20")

        # handle hitting right bound
        np.random.seed(27)
        grid.reset()  # (x=20, y=9)
        state, _, _, _ = grid.step(MOVE_RIGHT)
        assert state == ("20", "9")

        # handle hitting lower bound
        np.random.seed(50)
        grid.reset()  # (x=17, y=1)
        state, _, _, _ = grid.step(MOVE_DOWN)
        assert state == ("17", "1")

        # handle hitting left bound
        np.random.seed(48)
        grid.reset()  # (x=1, y=20)
        state, _, _, _ = grid.step(MOVE_LEFT)
        assert state == ("1", "20")

    def test_should_get_reward(self):
        # given
        grid = gym.make('grid-20-v0')

        # when
        grid.reset()

        # Walk to the right wall, then up. Stop at the first `done`: a start
        # on the top row reaches the reward while still moving right, and
        # stepping on afterwards would discard that reward.
        for action in [MOVE_RIGHT] * 20 + [MOVE_UP] * 20:
            obs, reward, done, _ = grid.step(action)
            if done:
                break

        # then
        assert obs == ('20', '20')
        assert reward == 1000
        assert done is True

setup.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
from setuptools import setup, find_packages
22

33
setup(name='parrotprediction-openai-envs',
4-
version='2.0.4',
4+
version='2.1.0',
55
description='Custom environments for OpenAI Gym',
66
keywords='acs lcs machine-learning reinforcement-learning openai',
77
url='https://github.com/ParrotPrediction/openai-envs',
88
author='Parrot Prediction Ltd.',
9-
author_email='contact@parrotprediction.com',
9+
author_email='nkozlowski@protonmail.com',
1010
license='MIT',
1111
packages=find_packages(),
1212
install_requires=[
1313
'numpy',
14-
'gym>=0.10',
14+
'gym==0.11',
1515
'networkx==2.0',
1616
'bitstring==3.1.5'
1717
],

0 commit comments

Comments
 (0)