From 6eab14739c20c1150eb5e5fcb504294871eb0f81 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Sun, 1 Jun 2025 10:38:00 +0000 Subject: [PATCH 1/5] handle nesting: ConvertDType, ToArray, relax Concatenate Concatenate can be equal to rename if only one key is supplied. By not calling concatenate in that case, we can accept arbitrary inputs in the transform, as long as only one is supplied. This simplifies things e.g. in the `BasicWorkflow`, where the user passes the `summary_variables` to concatenate, which may be a single dict, which does not need to be concatenated. --- bayesflow/adapters/transforms/concatenate.py | 13 +++- .../adapters/transforms/convert_dtype.py | 5 +- bayesflow/adapters/transforms/to_array.py | 28 +++++++- bayesflow/utils/__init__.py | 3 +- bayesflow/utils/tree.py | 69 +++++++++++++++++++ tests/test_adapters/conftest.py | 4 +- tests/test_workflows/conftest.py | 9 +-- tests/test_workflows/test_basic_workflow.py | 7 +- 8 files changed, 116 insertions(+), 22 deletions(-) create mode 100644 bayesflow/utils/tree.py diff --git a/bayesflow/adapters/transforms/concatenate.py b/bayesflow/adapters/transforms/concatenate.py index ac3700616..5c1474c22 100644 --- a/bayesflow/adapters/transforms/concatenate.py +++ b/bayesflow/adapters/transforms/concatenate.py @@ -49,7 +49,7 @@ def get_config(self) -> dict: return serialize(config) def forward(self, data: dict[str, any], *, strict: bool = True, **kwargs) -> dict[str, any]: - if not strict and self.indices is None: + if not strict and self.indices is None and len(self.keys) != 1: raise ValueError("Cannot call `forward` with `strict=False` before calling `forward` with `strict=True`.") # copy to avoid side effects @@ -69,6 +69,10 @@ def forward(self, data: dict[str, any], *, strict: bool = True, **kwargs) -> dic data.pop(key) return data + elif len(required_keys) == 1: + # only a rename + data[self.into] = data.pop(self.keys[0]) + return data if self.indices is None: # remember the indices of the parts in 
the concatenated array @@ -86,7 +90,7 @@ def forward(self, data: dict[str, any], *, strict: bool = True, **kwargs) -> dic return data def inverse(self, data: dict[str, any], *, strict: bool = False, **kwargs) -> dict[str, any]: - if self.indices is None: + if self.indices is None and len(self.keys) != 1: raise RuntimeError("Cannot call `inverse` before calling `forward` at least once.") # copy to avoid side effects @@ -98,6 +102,9 @@ def inverse(self, data: dict[str, any], *, strict: bool = False, **kwargs) -> di elif self.into not in data: # nothing to do return data + elif len(self.keys) == 1: + data[self.keys[0]] = data.pop(self.into) + return data # split the concatenated array and remove the concatenated key keys = self.keys @@ -141,7 +148,7 @@ def log_det_jac( available_keys = set(log_det_jac.keys()) common_keys = available_keys & required_keys - if len(common_keys) == 0: + if len(common_keys) == 0 or len(self.keys) == 1: return log_det_jac parts = [log_det_jac.pop(key) for key in common_keys] diff --git a/bayesflow/adapters/transforms/convert_dtype.py b/bayesflow/adapters/transforms/convert_dtype.py index 8cd21b4cc..ea7b6b10f 100644 --- a/bayesflow/adapters/transforms/convert_dtype.py +++ b/bayesflow/adapters/transforms/convert_dtype.py @@ -1,4 +1,5 @@ import numpy as np +from keras.tree import map_structure from bayesflow.utils.serialization import serializable, serialize @@ -32,7 +33,7 @@ def get_config(self) -> dict: return serialize(config) def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: - return data.astype(self.to_dtype, copy=False) + return map_structure(lambda d: d.astype(self.to_dtype, copy=False), data) def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: - return data.astype(self.from_dtype, copy=False) + return map_structure(lambda d: d.astype(self.from_dtype, copy=False), data) diff --git a/bayesflow/adapters/transforms/to_array.py b/bayesflow/adapters/transforms/to_array.py index fe1b82f2d..5192d66a0 100644 --- 
a/bayesflow/adapters/transforms/to_array.py +++ b/bayesflow/adapters/transforms/to_array.py @@ -2,6 +2,7 @@ import numpy as np +from bayesflow.utils.tree import map_dict, get_value_at_path, map_dict_with_path from bayesflow.utils.serialization import serializable, serialize from .elementwise_transform import ElementwiseTransform @@ -35,13 +36,36 @@ def get_config(self) -> dict: def forward(self, data: any, **kwargs) -> np.ndarray: if self.original_type is None: - self.original_type = type(data) + if isinstance(data, dict): + self.original_type = map_dict(type, data) + else: + self.original_type = type(data) + if isinstance(self.original_type, dict): + # use self.original_type in check to preserve serializablitiy + return map_dict(np.asarray, data) return np.asarray(data) - def inverse(self, data: np.ndarray, **kwargs) -> any: + def inverse(self, data: np.ndarray | dict, **kwargs) -> any: if self.original_type is None: raise RuntimeError("Cannot call `inverse` before calling `forward` at least once.") + if isinstance(self.original_type, dict): + # use self.original_type in check to preserve serializablitiy + + def restore_original_type(path, value): + try: + original_type = get_value_at_path(self.original_type, path) + return original_type(value) + except KeyError: + pass + except TypeError: + pass + except ValueError: + # separate statements, as optree does not allow (KeyError | TypeError | ValueError) + pass + return value + + return map_dict_with_path(restore_original_type, data) if issubclass(self.original_type, Number): try: diff --git a/bayesflow/utils/__init__.py b/bayesflow/utils/__init__.py index 776e42fcd..ee5e63f23 100644 --- a/bayesflow/utils/__init__.py +++ b/bayesflow/utils/__init__.py @@ -7,6 +7,7 @@ logging, numpy_utils, serialization, + tree, ) from .callbacks import detailed_loss_callback @@ -104,4 +105,4 @@ from ._docs import _add_imports_to_all -_add_imports_to_all(include_modules=["keras_utils", "logging", "numpy_utils", "serialization"]) 
+_add_imports_to_all(include_modules=["keras_utils", "logging", "numpy_utils", "serialization", "tree"]) diff --git a/bayesflow/utils/tree.py b/bayesflow/utils/tree.py new file mode 100644 index 000000000..c60cc4afa --- /dev/null +++ b/bayesflow/utils/tree.py @@ -0,0 +1,69 @@ +import optree + + +def flatten_shape(structure): + def is_shape_tuple(x): + return isinstance(x, (list, tuple)) and all(isinstance(e, (int, type(None))) for e in x) + + leaves, _ = optree.tree_flatten( + structure, + is_leaf=is_shape_tuple, + none_is_leaf=True, + namespace="keras", + ) + return leaves + + +def map_dict(func, *structures): + def is_not_dict(x): + return not isinstance(x, dict) + + if not structures: + raise ValueError("Must provide at least one structure") + + # Add check for same structures, otherwise optree just maps to shallowest. + def func_with_check(*args): + if not all(optree.tree_is_leaf(s, is_leaf=is_not_dict, none_is_leaf=True, namespace="keras") for s in args): + raise ValueError("Structures don't have the same nested structure.") + return func(*args) + + map_func = func_with_check if len(structures) > 1 else func + + return optree.tree_map( + map_func, + *structures, + is_leaf=is_not_dict, + none_is_leaf=True, + namespace="keras", + ) + + +def map_dict_with_path(func, *structures): + def is_not_dict(x): + return not isinstance(x, dict) + + if not structures: + raise ValueError("Must provide at least one structure") + + # Add check for same structures, otherwise optree just maps to shallowest. 
+ def func_with_check(*args): + if not all(optree.tree_is_leaf(s, is_leaf=is_not_dict, none_is_leaf=True, namespace="keras") for s in args): + raise ValueError("Structures don't have the same nested structure.") + return func(*args) + + map_func = func_with_check if len(structures) > 1 else func + + return optree.tree_map_with_path( + map_func, + *structures, + is_leaf=is_not_dict, + none_is_leaf=True, + namespace="keras", + ) + + +def get_value_at_path(structure, path): + output = structure + for accessor in path: + output = output.__getitem__(accessor) + return output diff --git a/tests/test_adapters/conftest.py b/tests/test_adapters/conftest.py index 3193309ae..feccc6d77 100644 --- a/tests/test_adapters/conftest.py +++ b/tests/test_adapters/conftest.py @@ -13,7 +13,9 @@ def serializable_fn(x): return ( Adapter() + .group(["p1", "p2"], into="ps", prefix="p") .to_array() + .ungroup("ps", prefix="p") .as_set(["s1", "s2"]) .broadcast("t1", to="t2") .as_time_series(["t1", "t2"]) @@ -37,8 +39,6 @@ def serializable_fn(x): .rename("o1", "o2") .random_subsample("s3", sample_size=33, axis=0) .take("s3", indices=np.arange(0, 32), axis=0) - .group(["p1", "p2"], into="ps", prefix="p") - .ungroup("ps", prefix="p") ) diff --git a/tests/test_workflows/conftest.py b/tests/test_workflows/conftest.py index 84b3fdafb..126c39cea 100644 --- a/tests/test_workflows/conftest.py +++ b/tests/test_workflows/conftest.py @@ -81,13 +81,6 @@ def sample(self, batch_shape: Shape, num_observations: int = 4) -> dict[str, Ten x = mean[:, None] + noise - return dict(mean=mean, a=x, b=x) + return dict(mean=mean, observables=dict(a=x, b=x)) return FusionSimulator() - - -@pytest.fixture -def fusion_adapter(): - from bayesflow import Adapter - - return Adapter.create_default(["mean"]).group(["a", "b"], "summary_variables") diff --git a/tests/test_workflows/test_basic_workflow.py b/tests/test_workflows/test_basic_workflow.py index 7cca5c1c1..28b98b2ed 100644 --- 
a/tests/test_workflows/test_basic_workflow.py +++ b/tests/test_workflows/test_basic_workflow.py @@ -34,14 +34,13 @@ def test_basic_workflow(tmp_path, inference_network, summary_network): assert samples["parameters"].shape == (5, 3, 2) -def test_basic_workflow_fusion( - tmp_path, fusion_inference_network, fusion_summary_network, fusion_simulator, fusion_adapter -): +def test_basic_workflow_fusion(tmp_path, fusion_inference_network, fusion_summary_network, fusion_simulator): workflow = bf.BasicWorkflow( - adapter=fusion_adapter, inference_network=fusion_inference_network, summary_network=fusion_summary_network, simulator=fusion_simulator, + inference_variables=["mean"], + summary_variables=["observables"], checkpoint_filepath=str(tmp_path), ) From a47fb3d2012cf604d373e97523bb18a25918677f Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Thu, 5 Jun 2025 20:40:04 +0000 Subject: [PATCH 2/5] use Rename instead of Concatenate if only one key is supplied Moves the fix from the Concatenate transform to the concatenate method of the adapter. --- bayesflow/adapters/adapter.py | 3 +++ bayesflow/adapters/transforms/concatenate.py | 13 +++---------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index fa84a9b4f..6add7e8a1 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -482,6 +482,9 @@ def concatenate(self, keys: str | Sequence[str], *, into: str, axis: int = -1): axis : int, optional Along which axis to concatenate the keys. The last axis is used by default. 
""" + if isinstance(keys, Sequence) and len(keys) == 1: + # unpack string if only one key is supplied, so that Rename is used below + keys = keys[0] if isinstance(keys, str): transform = Rename(keys, to_key=into) else: diff --git a/bayesflow/adapters/transforms/concatenate.py b/bayesflow/adapters/transforms/concatenate.py index 5c1474c22..ac3700616 100644 --- a/bayesflow/adapters/transforms/concatenate.py +++ b/bayesflow/adapters/transforms/concatenate.py @@ -49,7 +49,7 @@ def get_config(self) -> dict: return serialize(config) def forward(self, data: dict[str, any], *, strict: bool = True, **kwargs) -> dict[str, any]: - if not strict and self.indices is None and len(self.keys) != 1: + if not strict and self.indices is None: raise ValueError("Cannot call `forward` with `strict=False` before calling `forward` with `strict=True`.") # copy to avoid side effects @@ -69,10 +69,6 @@ def forward(self, data: dict[str, any], *, strict: bool = True, **kwargs) -> dic data.pop(key) return data - elif len(required_keys) == 1: - # only a rename - data[self.into] = data.pop(self.keys[0]) - return data if self.indices is None: # remember the indices of the parts in the concatenated array @@ -90,7 +86,7 @@ def forward(self, data: dict[str, any], *, strict: bool = True, **kwargs) -> dic return data def inverse(self, data: dict[str, any], *, strict: bool = False, **kwargs) -> dict[str, any]: - if self.indices is None and len(self.keys) != 1: + if self.indices is None: raise RuntimeError("Cannot call `inverse` before calling `forward` at least once.") # copy to avoid side effects @@ -102,9 +98,6 @@ def inverse(self, data: dict[str, any], *, strict: bool = False, **kwargs) -> di elif self.into not in data: # nothing to do return data - elif len(self.keys) == 1: - data[self.keys[0]] = data.pop(self.into) - return data # split the concatenated array and remove the concatenated key keys = self.keys @@ -148,7 +141,7 @@ def log_det_jac( available_keys = set(log_det_jac.keys()) common_keys = 
available_keys & required_keys - if len(common_keys) == 0 or len(self.keys) == 1: + if len(common_keys) == 0: return log_det_jac parts = [log_det_jac.pop(key) for key in common_keys] From c145fe32623c870b68b2651387afc3ef1cca260d Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Fri, 6 Jun 2025 06:42:11 +0000 Subject: [PATCH 3/5] add test for concatenate to rename conversion --- tests/test_adapters/test_adapters.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_adapters/test_adapters.py b/tests/test_adapters/test_adapters.py index 23721a938..095058eea 100644 --- a/tests/test_adapters/test_adapters.py +++ b/tests/test_adapters/test_adapters.py @@ -393,3 +393,16 @@ def test_nnpe(random_data): # Both should assign noise to high-variance dimension assert std_dim[1] > 0 assert std_glob[1] > 0 + + +def test_single_concatenate_to_rename(): + # test that single-element concatenate is converted to rename + from bayesflow import Adapter + from bayesflow.adapters.transforms import Rename, Concatenate + + ad = Adapter().concatenate("a", into="b") + assert isinstance(ad[0], Rename) + ad = Adapter().concatenate(["a"], into="b") + assert isinstance(ad[0], Rename) + ad = Adapter().concatenate(["a", "b"], into="c") + assert isinstance(ad[0], Concatenate) From ebd49ea6a524ede8a510de9f70391959c9e1ac9d Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Sat, 14 Jun 2025 21:34:23 +0000 Subject: [PATCH 4/5] simplifications: remove invertibility of ToArray for dict - simplify map_dict to only a single structure, as we probably will not require the more general behavior. Add test and docstring. 
- remove tree functions that were required for restoring original types - minor cleanups to account for review comments --- bayesflow/adapters/adapter.py | 5 +- bayesflow/adapters/transforms/to_array.py | 35 ++++--------- bayesflow/utils/tree.py | 63 +++++++---------------- tests/test_utils/test_tree.py | 16 ++++++ 4 files changed, 46 insertions(+), 73 deletions(-) create mode 100644 tests/test_utils/test_tree.py diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index 6add7e8a1..458a2b136 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -482,11 +482,10 @@ def concatenate(self, keys: str | Sequence[str], *, into: str, axis: int = -1): axis : int, optional Along which axis to concatenate the keys. The last axis is used by default. """ - if isinstance(keys, Sequence) and len(keys) == 1: - # unpack string if only one key is supplied, so that Rename is used below - keys = keys[0] if isinstance(keys, str): transform = Rename(keys, to_key=into) + elif len(keys) == 1: + transform = Rename(keys[0], to_key=into) else: transform = Concatenate(keys, into=into, axis=axis) self.transforms.append(transform) diff --git a/bayesflow/adapters/transforms/to_array.py b/bayesflow/adapters/transforms/to_array.py index 5192d66a0..d6dba2aa9 100644 --- a/bayesflow/adapters/transforms/to_array.py +++ b/bayesflow/adapters/transforms/to_array.py @@ -2,7 +2,7 @@ import numpy as np -from bayesflow.utils.tree import map_dict, get_value_at_path, map_dict_with_path +from bayesflow.utils.tree import map_dict from bayesflow.utils.serialization import serializable, serialize from .elementwise_transform import ElementwiseTransform @@ -35,37 +35,22 @@ def get_config(self) -> dict: return serialize({"original_type": self.original_type}) def forward(self, data: any, **kwargs) -> np.ndarray: + if isinstance(data, dict): + # no invertibility for dict, do not store original type + return map_dict(np.asarray, data) + if self.original_type is None: - if 
isinstance(data, dict): - self.original_type = map_dict(type, data) - else: - self.original_type = type(data) + self.original_type = type(data) - if isinstance(self.original_type, dict): - # use self.original_type in check to preserve serializablitiy - return map_dict(np.asarray, data) return np.asarray(data) def inverse(self, data: np.ndarray | dict, **kwargs) -> any: + if isinstance(data, dict): + # no invertibility for dict to keep complexity low + return data + if self.original_type is None: raise RuntimeError("Cannot call `inverse` before calling `forward` at least once.") - if isinstance(self.original_type, dict): - # use self.original_type in check to preserve serializablitiy - - def restore_original_type(path, value): - try: - original_type = get_value_at_path(self.original_type, path) - return original_type(value) - except KeyError: - pass - except TypeError: - pass - except ValueError: - # separate statements, as optree does not allow (KeyError | TypeError | ValueError) - pass - return value - - return map_dict_with_path(restore_original_type, data) if issubclass(self.original_type, Number): try: diff --git a/bayesflow/utils/tree.py b/bayesflow/utils/tree.py index c60cc4afa..bdcb098d2 100644 --- a/bayesflow/utils/tree.py +++ b/bayesflow/utils/tree.py @@ -14,56 +14,29 @@ def is_shape_tuple(x): return leaves -def map_dict(func, *structures): - def is_not_dict(x): - return not isinstance(x, dict) - - if not structures: - raise ValueError("Must provide at least one structure") - - # Add check for same structures, otherwise optree just maps to shallowest. 
- def func_with_check(*args): - if not all(optree.tree_is_leaf(s, is_leaf=is_not_dict, none_is_leaf=True, namespace="keras") for s in args): - raise ValueError("Structures don't have the same nested structure.") - return func(*args) - - map_func = func_with_check if len(structures) > 1 else func - - return optree.tree_map( - map_func, - *structures, - is_leaf=is_not_dict, - none_is_leaf=True, - namespace="keras", - ) - +def map_dict(func, dictionary): + """Applies a function to all leaves of a (possibly nested) dictionary. + + Parameters + ---------- + func : Callable + The function to apply to the leaves. + dictionary : dict + The input dictionary. + + Returns + ------- + dict + A dictionary with the outputs of `func` as leaves. + """ -def map_dict_with_path(func, *structures): def is_not_dict(x): return not isinstance(x, dict) - if not structures: - raise ValueError("Must provide at least one structure") - - # Add check for same structures, otherwise optree just maps to shallowest. - def func_with_check(*args): - if not all(optree.tree_is_leaf(s, is_leaf=is_not_dict, none_is_leaf=True, namespace="keras") for s in args): - raise ValueError("Structures don't have the same nested structure.") - return func(*args) - - map_func = func_with_check if len(structures) > 1 else func - - return optree.tree_map_with_path( - map_func, - *structures, + return optree.tree_map( + func, + dictionary, is_leaf=is_not_dict, none_is_leaf=True, namespace="keras", ) - - -def get_value_at_path(structure, path): - output = structure - for accessor in path: - output = output.__getitem__(accessor) - return output diff --git a/tests/test_utils/test_tree.py b/tests/test_utils/test_tree.py new file mode 100644 index 000000000..0c3dc81f2 --- /dev/null +++ b/tests/test_utils/test_tree.py @@ -0,0 +1,16 @@ +def test_map_dict(): + from bayesflow.utils.tree import map_dict + + input = { + "a": { + "x": [0, 1, 2], + }, + "b": [0, 1], + "c": "foo", + } + output = map_dict(len, input) + for key, value 
in output.items(): + if key == "a": + assert value["x"] == len(input["a"]["x"]) + continue + assert value == len(input[key]) From 608a6f46b1f47ab8ff997058ccc5821c52d69934 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Sat, 14 Jun 2025 21:43:27 +0000 Subject: [PATCH 5/5] add and adapt type hints --- bayesflow/adapters/transforms/convert_dtype.py | 4 ++-- bayesflow/utils/tree.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/bayesflow/adapters/transforms/convert_dtype.py b/bayesflow/adapters/transforms/convert_dtype.py index ea7b6b10f..d9159487e 100644 --- a/bayesflow/adapters/transforms/convert_dtype.py +++ b/bayesflow/adapters/transforms/convert_dtype.py @@ -32,8 +32,8 @@ def get_config(self) -> dict: } return serialize(config) - def forward(self, data: np.ndarray, **kwargs) -> np.ndarray: + def forward(self, data: np.ndarray | dict, **kwargs) -> np.ndarray | dict: return map_structure(lambda d: d.astype(self.to_dtype, copy=False), data) - def inverse(self, data: np.ndarray, **kwargs) -> np.ndarray: + def inverse(self, data: np.ndarray | dict, **kwargs) -> np.ndarray | dict: return map_structure(lambda d: d.astype(self.from_dtype, copy=False), data) diff --git a/bayesflow/utils/tree.py b/bayesflow/utils/tree.py index bdcb098d2..a19d2e68e 100644 --- a/bayesflow/utils/tree.py +++ b/bayesflow/utils/tree.py @@ -1,4 +1,5 @@ import optree +from typing import Callable def flatten_shape(structure): @@ -14,7 +15,7 @@ def is_shape_tuple(x): return leaves -def map_dict(func, dictionary): +def map_dict(func: Callable, dictionary: dict) -> dict: """Applies a function to all leaves of a (possibly nested) dictionary. Parameters