Skip to content

Commit f5cb8a4

Browse files
fix: HuggingFace save_to_disk takes PathLike type which is defined as str, bytes or os.PathLike. imitation.util.parse_path always returned pathlib.Path which is not one of these types. This commit converts pathlib.Path to str before calling the HF fn.
1 parent a8b079c commit f5cb8a4

File tree

4 files changed

+126
-49
lines changed

4 files changed

+126
-49
lines changed

src/imitation/data/serialize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def save(path: AnyPath, trajectories: Sequence[Trajectory]) -> None:
1919
path: Trajectories are saved to this path.
2020
trajectories: The trajectories to save.
2121
"""
22-
p = util.parse_path(path)
22+
p = str(util.parse_path(path))
2323
huggingface_utils.trajectories_to_dataset(trajectories).save_to_disk(p)
2424
logging.info(f"Dumped demonstrations to {p}.")
2525

tests/data/conftest.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import gymnasium as gym
2+
import numpy as np
3+
import pytest
4+
5+
from imitation.data import types
6+
7+
SPACES = [
8+
gym.spaces.Discrete(3),
9+
gym.spaces.MultiDiscrete([3, 4]),
10+
gym.spaces.Box(-1, 1, shape=(1,)),
11+
gym.spaces.Box(-1, 1, shape=(2,)),
12+
gym.spaces.Box(-np.inf, np.inf, shape=(2,)),
13+
]
14+
DICT_SPACE = gym.spaces.Dict(
15+
{"a": gym.spaces.Discrete(3), "b": gym.spaces.Box(-1, 1, shape=(2,))},
16+
)
17+
LENGTHS = [0, 1, 2, 10]
18+
19+
20+
@pytest.fixture(params=SPACES)
21+
def act_space(request):
22+
return request.param
23+
24+
25+
@pytest.fixture(params=SPACES + [DICT_SPACE])
26+
def obs_space(request):
27+
return request.param
28+
29+
30+
@pytest.fixture(params=LENGTHS)
31+
def length(request):
32+
return request.param
33+
34+
35+
@pytest.fixture
36+
def trajectory(
37+
obs_space: gym.Space,
38+
act_space: gym.Space,
39+
length: int,
40+
) -> types.Trajectory:
41+
"""Fixture to generate trajectory of length `length` iid sampled from spaces."""
42+
if length == 0:
43+
pytest.skip()
44+
45+
raw_obs = [obs_space.sample() for _ in range(length + 1)]
46+
if isinstance(obs_space, gym.spaces.Dict):
47+
obs: types.Observation = types.DictObs.from_obs_list(raw_obs)
48+
else:
49+
obs = np.array(raw_obs)
50+
acts = np.array([act_space.sample() for _ in range(length)])
51+
infos = np.array([{f"key{i}": i} for i in range(length)])
52+
return types.Trajectory(obs=obs, acts=acts, infos=infos, terminal=True)
53+
54+
55+
@pytest.fixture
56+
def trajectory_rew(trajectory: types.Trajectory) -> types.TrajectoryWithRew:
57+
"""Like `trajectory` but with reward randomly sampled from a Gaussian."""
58+
rews = np.random.randn(len(trajectory))
59+
return types.TrajectoryWithRew(
60+
**types.dataclass_quick_asdict(trajectory),
61+
rews=rews,
62+
)

tests/data/test_serialize.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""Tests for `imitation.data.serialize`."""
2+
3+
import pathlib
4+
5+
import gymnasium as gym
6+
import numpy as np
7+
import pytest
8+
9+
from imitation.data import serialize, types
10+
from imitation.data.types import DictObs
11+
12+
13+
@pytest.fixture
14+
def data_path(tmp_path):
15+
return tmp_path / "data"
16+
17+
18+
@pytest.mark.parametrize("path_type", [str, pathlib.Path])
19+
def test_save_trajectory(data_path, trajectory, path_type):
20+
if isinstance(trajectory.obs, DictObs):
21+
pytest.skip("serialize.save does not yet support DictObs")
22+
23+
serialize.save(path_type(data_path), [trajectory])
24+
assert data_path.exists()
25+
26+
27+
@pytest.mark.parametrize("path_type", [str, pathlib.Path])
28+
def test_save_trajectory_rew(data_path, trajectory_rew, path_type):
29+
if isinstance(trajectory_rew.obs, DictObs):
30+
pytest.skip("serialize.save does not yet support DictObs")
31+
serialize.save(path_type(data_path), [trajectory_rew])
32+
assert data_path.exists()
33+
34+
35+
def test_save_load_trajectory(data_path, trajectory):
36+
if isinstance(trajectory.obs, DictObs):
37+
pytest.skip("serialize.save does not yet support DictObs")
38+
serialize.save(data_path, [trajectory])
39+
40+
reconstructed = list(serialize.load(data_path))
41+
reconstructedi = reconstructed[0]
42+
43+
assert len(reconstructed) == 1
44+
assert np.allclose(reconstructedi.obs, trajectory.obs)
45+
assert np.allclose(reconstructedi.acts, trajectory.acts)
46+
assert np.allclose(reconstructedi.terminal, trajectory.terminal)
47+
assert not hasattr(reconstructedi, "rews")
48+
49+
50+
@pytest.mark.parametrize("load_fn", [serialize.load, serialize.load_with_rewards])
51+
def test_save_load_trajectory_rew(data_path, trajectory_rew, load_fn):
52+
if isinstance(trajectory_rew.obs, DictObs):
53+
pytest.skip("serialize.save does not yet support DictObs")
54+
serialize.save(data_path, [trajectory_rew])
55+
56+
reconstructed = list(load_fn(data_path))
57+
reconstructedi = reconstructed[0]
58+
59+
assert len(reconstructed) == 1
60+
assert np.allclose(reconstructedi.obs, trajectory_rew.obs)
61+
assert np.allclose(reconstructedi.acts, trajectory_rew.acts)
62+
assert np.allclose(reconstructedi.terminal, trajectory_rew.terminal)
63+
assert np.allclose(reconstructedi.rews, trajectory_rew.rews)

tests/data/test_types.py

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -15,58 +15,13 @@
1515
from imitation.data import serialize, types
1616
from imitation.util import util
1717

18-
SPACES = [
19-
gym.spaces.Discrete(3),
20-
gym.spaces.MultiDiscrete([3, 4]),
21-
gym.spaces.Box(-1, 1, shape=(1,)),
22-
gym.spaces.Box(-1, 1, shape=(2,)),
23-
gym.spaces.Box(-np.inf, np.inf, shape=(2,)),
24-
]
25-
DICT_SPACE = gym.spaces.Dict(
26-
{"a": gym.spaces.Discrete(3), "b": gym.spaces.Box(-1, 1, shape=(2,))},
27-
)
28-
29-
OBS_SPACES = SPACES + [DICT_SPACE]
30-
ACT_SPACES = SPACES
31-
LENGTHS = [0, 1, 2, 10]
32-
3318

3419
def _check_1d_shape(fn: Callable[[np.ndarray], Any], length: int, expected_msg: str):
3520
for shape in [(), (length, 1), (length, 2), (length - 1,), (length + 1,)]:
3621
with pytest.raises(ValueError, match=expected_msg):
3722
fn(np.zeros(shape))
3823

3924

40-
@pytest.fixture
41-
def trajectory(
42-
obs_space: gym.Space,
43-
act_space: gym.Space,
44-
length: int,
45-
) -> types.Trajectory:
46-
"""Fixture to generate trajectory of length `length` iid sampled from spaces."""
47-
if length == 0:
48-
pytest.skip()
49-
50-
raw_obs = [obs_space.sample() for _ in range(length + 1)]
51-
if isinstance(obs_space, gym.spaces.Dict):
52-
obs: types.Observation = types.DictObs.from_obs_list(raw_obs)
53-
else:
54-
obs = np.array(raw_obs)
55-
acts = np.array([act_space.sample() for _ in range(length)])
56-
infos = np.array([{f"key{i}": i} for i in range(length)])
57-
return types.Trajectory(obs=obs, acts=acts, infos=infos, terminal=True)
58-
59-
60-
@pytest.fixture
61-
def trajectory_rew(trajectory: types.Trajectory) -> types.TrajectoryWithRew:
62-
"""Like `trajectory` but with reward randomly sampled from a Gaussian."""
63-
rews = np.random.randn(len(trajectory))
64-
return types.TrajectoryWithRew(
65-
**types.dataclass_quick_asdict(trajectory),
66-
rews=rews,
67-
)
68-
69-
7025
@pytest.fixture
7126
def transitions_min(
7227
obs_space: gym.Space,
@@ -134,9 +89,6 @@ def pushd(dir_path):
13489
os.chdir(orig_dir)
13590

13691

137-
@pytest.mark.parametrize("obs_space", OBS_SPACES)
138-
@pytest.mark.parametrize("act_space", ACT_SPACES)
139-
@pytest.mark.parametrize("length", LENGTHS)
14092
class TestData:
14193
"""Tests of imitation.util.data.
14294

0 commit comments

Comments
 (0)