Skip to content

Commit 5a3d6f4

Browse files
author
Norbert Kozlowski
committed
Real Multiplexer
1 parent 80d7b24 commit 5a3d6f4

File tree

9 files changed

+183
-72
lines changed

9 files changed

+183
-72
lines changed

gym_multiplexer/__init__.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,60 @@
1-
from .boolean_multiplexer import BooleanMultiplexer
2-
31
from gym.envs.registration import register
42

5-
name = "boolean-multiplexer"
3+
from .boolean_multiplexer import BooleanMultiplexer
4+
from .real_multiplexer import RealMultiplexer
5+
6+
bool_mpx_name = "boolean-multiplexer"
7+
real_mpx_name = "real-multiplexer"
68
max_episode_steps = 1
79

810
# Length of a multiplexer is calculated
911
# using l = k + 2^k
1012

1113
register(
12-
id='{}-3bit-v0'.format(name),
14+
id='{}-3bit-v0'.format(bool_mpx_name),
1315
entry_point='gym_multiplexer:BooleanMultiplexer',
1416
max_episode_steps=max_episode_steps,
1517
kwargs={'control_bits': 1}
1618
)
1719

1820
register(
19-
id='{}-6bit-v0'.format(name),
21+
id='{}-6bit-v0'.format(bool_mpx_name),
2022
entry_point='gym_multiplexer:BooleanMultiplexer',
2123
max_episode_steps=max_episode_steps,
2224
kwargs={'control_bits': 2}
2325
)
2426

2527
register(
26-
id='{}-11bit-v0'.format(name),
28+
id='{}-11bit-v0'.format(bool_mpx_name),
2729
entry_point='gym_multiplexer:BooleanMultiplexer',
2830
max_episode_steps=max_episode_steps,
2931
kwargs={'control_bits': 3}
3032
)
3133

3234
register(
33-
id='{}-20bit-v0'.format(name),
35+
id='{}-20bit-v0'.format(bool_mpx_name),
3436
entry_point='gym_multiplexer:BooleanMultiplexer',
3537
max_episode_steps=max_episode_steps,
3638
kwargs={'control_bits': 4}
3739
)
3840

3941
register(
40-
id='{}-37bit-v0'.format(name),
42+
id='{}-37bit-v0'.format(bool_mpx_name),
4143
entry_point='gym_multiplexer:BooleanMultiplexer',
4244
max_episode_steps=max_episode_steps,
4345
kwargs={'control_bits': 5}
46+
)
47+
48+
register(
49+
id='{}-3bit-v0'.format(real_mpx_name),
50+
entry_point='gym_multiplexer:RealMultiplexer',
51+
max_episode_steps=max_episode_steps,
52+
kwargs={'control_bits': 1}
53+
)
54+
55+
register(
56+
id='{}-6bit-v0'.format(real_mpx_name),
57+
entry_point='gym_multiplexer:RealMultiplexer',
58+
max_episode_steps=max_episode_steps,
59+
kwargs={'control_bits': 2}
4460
)
Lines changed: 5 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,14 @@
1-
import logging
2-
import random
3-
4-
import gym
5-
from bitstring import BitArray
61
from gym.spaces import Discrete
72

3+
from .multiplexer import Multiplexer
84

9-
class BooleanMultiplexer(gym.Env):
105

11-
REWARD = 1000
6+
class BooleanMultiplexer(Multiplexer):
127

138
def __init__(self, control_bits=3) -> None:
14-
self.control_bits = control_bits
15-
self.metadata = {'render.modes': ['human']}
9+
super().__init__(control_bits)
1610
self.observation_space = Discrete(self._observation_string_length)
1711
self.action_space = Discrete(2)
1812

19-
def _reset(self):
20-
logging.debug("Resetting the environment")
21-
bits = BitArray([random.randint(0, 1) for _ in
22-
range(0, self._observation_string_length - 1)])
23-
24-
self._ctrl_bits = bits[:self.control_bits]
25-
self._data_bits = bits[self.control_bits:]
26-
self._validation_bit = False
27-
28-
return self._observation()
29-
30-
def _step(self, action):
31-
reward = 0
32-
33-
if action == self._answer:
34-
self._validation_bit = True
35-
reward = self.REWARD
36-
37-
return self._observation(), reward, None, None
38-
39-
def _render(self, mode='human', close=False):
40-
if close:
41-
return
42-
43-
if mode == 'human':
44-
return self._observation()
45-
else:
46-
super(BooleanMultiplexer, self).render(mode=mode)
47-
48-
def _observation(self) -> str:
49-
return (self._ctrl_bits + self._data_bits
50-
+ BitArray([self._validation_bit])).bin
51-
52-
@property
53-
def _observation_string_length(self):
54-
return self.control_bits + pow(2, self.control_bits) + 1
55-
56-
@property
57-
def _answer(self):
58-
return int(self._data_bits[self._ctrl_bits.uint])
13+
def _internal_state(self):
14+
return map(lambda x: round(x), self._state)

gym_multiplexer/multiplexer.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import random
2+
3+
import gym
4+
5+
from .utils import get_correct_answer
6+
7+
8+
class Multiplexer(gym.Env):
9+
10+
REWARD = 1000
11+
12+
def _internal_state(self): raise NotImplementedError
13+
14+
def __init__(self, control_bits=3) -> None:
15+
self.control_bits = control_bits
16+
self.metadata = {'render.modes': ['human']}
17+
18+
self._state = None
19+
self._validation_bit = 0
20+
21+
def _reset(self):
22+
self._state = [random.random() for _ in
23+
range(0, self._observation_string_length - 1)]
24+
self._validation_bit = 0
25+
return self._observation
26+
27+
def _step(self, action):
28+
reward = 0
29+
30+
if action == self._correct_answer:
31+
self._validation_bit = 1
32+
reward = self.REWARD
33+
34+
return self._observation, reward, None, None
35+
36+
def _render(self, mode='human', close=False):
37+
if close:
38+
return
39+
40+
if mode == 'human':
41+
return self._observation
42+
43+
return self.render(mode=mode)
44+
45+
@property
46+
def _observation(self) -> list:
47+
observation = list(self._state)
48+
observation.append(self._validation_bit)
49+
return observation
50+
51+
@property
52+
def _correct_answer(self):
53+
return get_correct_answer(list(self._internal_state()) , self.control_bits)
54+
55+
@property
56+
def _observation_string_length(self):
57+
return self.control_bits + pow(2, self.control_bits) + 1

gym_multiplexer/real_multiplexer.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from gym.spaces import Box, Discrete
2+
3+
from .multiplexer import Multiplexer
4+
5+
6+
class RealMultiplexer(Multiplexer):
7+
8+
def __init__(self, control_bits=3, threshold=.5) -> None:
9+
super().__init__(control_bits)
10+
self.threshold = threshold
11+
self.observation_space = Box(low=0, high=1, shape=(self._observation_string_length, ))
12+
self.action_space = Discrete(2)
13+
14+
def _internal_state(self):
15+
return map(lambda x: x > self.threshold, self._state)

gym_multiplexer/tests/test_multiplexer.py renamed to gym_multiplexer/tests/test_boolean_multiplexer.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
1111

1212

13-
class TestMultiplexer:
13+
class TestBooleanMultiplexer:
1414
def test_should_initialize_multiplexer(self):
1515
# when
1616
mp = gym.make('boolean-multiplexer-6bit-v0')
@@ -29,9 +29,9 @@ def test_should_return_observation_when_reset(self):
2929

3030
# then
3131
assert state is not None
32-
assert state[-1] == '0'
32+
assert state[-1] == 0
3333
assert 7 == len(state)
34-
assert type(state) is str
34+
assert type(state) is list
3535

3636
def test_should_render_state(self):
3737
# given
@@ -43,9 +43,9 @@ def test_should_render_state(self):
4343

4444
# then
4545
assert state is not None
46-
assert state[-1] == '0'
46+
assert state[-1] == 0
4747
assert 4 == len(state)
48-
assert type(state) is str
48+
assert type(state) is list
4949

5050
def test_should_execute_step(self):
5151
# given
@@ -58,8 +58,8 @@ def test_should_execute_step(self):
5858

5959
# then
6060
assert state is not None
61-
assert state[-1] in ['0', '1']
62-
assert type(state) is str
61+
assert state[-1] in [0, 1]
62+
assert type(state) is list
6363
assert reward in [0, 1000]
6464
assert done is True
6565

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import logging
2+
import random
3+
import sys
4+
5+
import gym
6+
7+
# noinspection PyUnresolvedReferences
8+
import gym_multiplexer
9+
10+
logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
11+
12+
class TestRealMultiplexer:
13+
14+
def test_should_initialize_real_mpx(self):
15+
# when
16+
mp = gym.make("real-multiplexer-6bit-v0")
17+
18+
# then
19+
assert mp is not None
20+
assert (7,) == mp.observation_space.shape
21+
assert 2 == mp.action_space.n
22+
23+
def test_should_return_observation_when_reset(self):
24+
# given
25+
mp = gym.make('real-multiplexer-6bit-v0')
26+
27+
# when
28+
state = mp.reset()
29+
30+
# then
31+
assert state is not None
32+
assert state[-1] == 0
33+
assert 7 == len(state)
34+
assert type(state) is list
35+
36+
def test_should_execute_step(self):
37+
# given
38+
mp = gym.make('real-multiplexer-6bit-v0')
39+
mp.reset()
40+
action = self._random_action()
41+
42+
# when
43+
state, reward, done, _ = mp.step(action)
44+
45+
# then
46+
assert state is not None
47+
assert state[-1] in [0, 1]
48+
assert type(state) is list
49+
assert reward in [0, 1000]
50+
assert done is True
51+
52+
def test_execute_multiple_steps_and_keep_constant_perception_length(self):
53+
# given
54+
mp = gym.make('real-multiplexer-6bit-v0')
55+
steps = 100
56+
57+
# when & then
58+
for _ in range(0, steps):
59+
p0 = mp.reset()
60+
assert 7 == len(p0)
61+
62+
action = self._random_action()
63+
p1, reward, done, _ = mp.step(action)
64+
assert 7 == len(p1)
65+
66+
@staticmethod
67+
def _random_action():
68+
return random.sample([0, 1], 1)[0]

gym_multiplexer/tests/test_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
class TestUtils:
55

66
def test_should_calculate_correct_answer_for_3bit_multiplexer(self):
7-
assert 1 == get_correct_answer('0100', 1)
8-
assert 0 == get_correct_answer('1100', 1)
7+
assert 1 == get_correct_answer([0,1,0,0], 1)
8+
assert 0 == get_correct_answer([1,1,0,0], 1)
99

1010
def test_should_calculate_correct_answer_for_6bit_multiplexer(self):
11-
assert 0 == get_correct_answer('1101000', 2)
11+
assert 0 == get_correct_answer([1,1,0,1,0,0,0], 2)
1212

1313
def test_should_calculate_correct_answer_for_11bit_multiplexer(self):
14-
assert 1 == get_correct_answer('1011011010', 3)
14+
assert 1 == get_correct_answer([1,0,1,1,0,1,1,0,1,0], 3)

gym_multiplexer/utils/utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from bitstring import BitString
22

33

4-
def get_correct_answer(bitstring: str, control_bits: int) -> int:
5-
bits = BitString(bin=bitstring)
4+
def get_correct_answer(bitstring: list, control_bits: int) -> int:
5+
bits = BitString(bitstring)
66

77
_ctrl_bits = bits[:control_bits]
88
_data_bits = bits[control_bits:]
9-
_validation_bit = bits[-1]
109

1110
return int(_data_bits[_ctrl_bits.uint])

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from setuptools import setup, find_packages
22

33
setup(name='parrotprediction-openai-envs',
4-
version='1.0.0',
4+
version='2.0.0',
55
description='Custom environments for OpenAI Gym',
66
keywords='acs lcs machine-learning reinforcement-learning openai',
77
url='https://github.com/ParrotPrediction/openai-envs',

0 commit comments

Comments
 (0)