Skip to content

Commit f34eb26

Browse files
committed
Do not fail with zero-sized arrays in dataset_to_point_list
Numpy does not support reshape(-1, ...) when size is zero
1 parent 8a436d8 commit f34eb26

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

pymc/backends/arviz.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -618,11 +618,13 @@ def dataset_to_point_list(
618618
for vn in var_names:
619619
if not isinstance(vn, str):
620620
raise ValueError(f"Variable names must be str, but dataset key {vn} is a {type(vn)}.")
621+
621622
num_sample_dims = len(sample_dims)
622623
stacked_dims = {dim_name: ds[var_names[0]][dim_name] for dim_name in sample_dims}
623624
transposed_dict = {vn: da.transpose(*sample_dims, ...) for vn, da in ds.items()}
625+
stacked_size = np.prod(transposed_dict[var_names[0]].shape[:num_sample_dims], dtype=int)
624626
stacked_dict = {
625-
vn: da.values.reshape((-1, *da.shape[num_sample_dims:]))
627+
vn: da.values.reshape((stacked_size, *da.shape[num_sample_dims:]))
626628
for vn, da in transposed_dict.items()
627629
}
628630
points = [

tests/backends/test_arviz.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -837,3 +837,14 @@ def test_dataset_to_point_list_str_key(self):
837837
ds[3] = xarray.DataArray([1, 2, 3])
838838
with pytest.raises(ValueError, match="must be str"):
839839
dataset_to_point_list(ds, sample_dims=["chain", "draw"])
840+
841+
def test_zero_size(self):
842+
ds = xarray.Dataset()
843+
ds["x"] = xarray.DataArray(
844+
np.zeros((4, 10, 0, 5)), dims=("chain", "draw", "dim_0", "dim_5")
845+
)
846+
pl, _ = dataset_to_point_list(ds, sample_dims=("chain", "draw"))
847+
assert len(pl) == 40
848+
assert tuple(pl[0]) == ("x",)
849+
assert pl[0]["x"].shape == (0, 5)
850+
assert pl[0]["x"].dtype == np.float64

0 commit comments

Comments
 (0)