[WIP] Add support for graphical simulators #487


Draft: wants to merge 9 commits into dev
Conversation

@daniel-habermann (Contributor) commented May 23, 2025

WIP pull request to add support for graphical simulators. I'm going to update this message and tag some people once it has reached a state where it makes sense to read further. All discussion and feedback welcome!

Summary and Motivation

This PR introduces initial support for graphical simulators. The main idea is to represent a complex simulation program as a directed acyclic graph (DAG), where nodes represent sets of parameters and edges denote conditional dependencies.
Such a structure is a natural representation for many Bayesian models because the joint distribution of parameters $p(\theta)$ can often be expressed in the form of some factorization along a DAG $p(\theta_1, \dots, \theta_N) = \prod_{i=1}^{N} p(\theta_i | \text{Parents}(\theta_i))$.

The benefit of making these dependency structures explicit is that the converse also holds: by stating the conditional dependencies, the DAG simultaneously encodes the conditional independencies implied by the distribution, which we can exploit to automatically build efficient network architectures, for example for multilevel models.

Current Implementation

Consider a standard two-level hierarchical model:

$$ \begin{aligned} \tau,\omega &\sim \text{Normal}^{+}(0, 1)\\ \lambda_j &\sim \text{Normal}(0, \tau)\\ x_{ij} &\sim \text{Normal}(\lambda_j, \omega)\\ \end{aligned} $$
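
For this model, the general factorization above specializes to

$$ p(\tau, \omega, \{\lambda_j\}, \{x_{ij}\}) = p(\tau)\,p(\omega) \prod_{j=1}^{J} \left[ p(\lambda_j \mid \tau) \prod_{i=1}^{n_j} p(x_{ij} \mid \lambda_j, \omega) \right] $$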

Such a model can be represented by the following diagram:

[Plate diagram omitted: tau → lambda_j → x_ij and omega → x_ij, with dashed boxes around the group-level and observation-level nodes.]

where the dashed boxes denote that parameters are exchangeable. Currently, such a diagram would be implemented like this:

from bayesflow.experimental.graphical_simulator import GraphicalSimulator
import numpy as np

def sample_tau():
    tau = np.abs(np.random.normal())
    return dict(tau=tau)

def sample_omega():
    omega = np.abs(np.random.normal())
    return dict(omega=omega)

def sample_lambda_j(tau):
    lambda_j = np.abs(np.random.normal(loc=0, scale=tau))
    return dict(lambda_j=lambda_j)

def sample_x_ij(lambda_j, omega):
    x_ij = np.random.normal(loc=lambda_j, scale=omega)
    return dict(x_ij=x_ij)

simulator = GraphicalSimulator()

simulator.add_node("tau", sampling_function=sample_tau, sample_size=lambda: 1)
simulator.add_node("omega", sampling_function=sample_omega, sample_size=lambda: 1)
simulator.add_node("lambda_j", sampling_function=sample_lambda_j, sample_size=lambda: np.random.randint(5, 10))
simulator.add_node("x_ij", sampling_function=sample_x_ij, sample_size=lambda: np.random.randint(1, 10))

simulator.add_edge("tau", "lambda_j")
simulator.add_edge("lambda_j", "x_ij")
simulator.add_edge("omega", "x_ij")

Design space

There is still a long list of design choices:

How to determine how often each node is executed for each batch?

For multilevel models, we want to vary the number of groups and observations within each group for each batch. Currently, this is achieved by the sample_size function argument, which expects a callable returning an integer.
One question is whether we even need such an argument, or whether we could remove it by relying on something like the current meta_fn.

If we go the meta_fn route, do we have a separate meta_fn for each node or a single global one?

How can we represent more exotic models or non-DAG structures, like state-space models?

How do we handle which nodes return observed data?

This becomes important when talking about graph inversions. Currently, we can attach arbitrary metadata to each node, and the graph inversion algorithm searches for an "observed" keyword, but from a user perspective this should probably be improved. We might not even need to care about this at all, because the adapter defines summary_conditions or inference_conditions.
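
Purely as an illustration of the current state, marking a node as observed could look like the following (the metadata keyword and its exact form are assumptions for illustration, not the actual signature):

simulator.add_node(
    "x_ij",
    sampling_function=sample_x_ij,
    sample_size=lambda: np.random.randint(1, 10),
    metadata={"observed": True},  # hypothetical; the graph inversion searches for "observed"
)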

How is all of this represented internally?

The resulting data structure is non-rectangular because each batch might have a different number of calls for each node.
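
As a small illustration (not the internal format), two independent draws of the two-level model above can produce differently sized arrays, so they cannot be stacked into one rectangular batch without padding or subsetting:

import numpy as np

draw_1 = {"lambda_j": np.zeros(7), "x_ij": np.zeros((7, 4))}  # 7 groups, 4 obs each
draw_2 = {"lambda_j": np.zeros(9), "x_ij": np.zeros((9, 2))}  # 9 groups, 2 obs each
# np.stack([draw_1["x_ij"], draw_2["x_ij"]])  # fails: shapes (7, 4) and (9, 2) differ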

@daniel-habermann added the feature (New feature or request) and draft (Draft Pull Request, Work in Progress) labels on May 23, 2025
@daniel-habermann marked this pull request as draft on May 23, 2025 at 14:29
codecov bot commented May 23, 2025

@LarsKue (Contributor) commented May 23, 2025

I just thought about an alternative interface, would be glad for some feedback:

# we could make passing graph optional later
def sample_tau(graph):
    tau = np.abs(np.random.normal())
    return tau


def sample_omega(graph):
    omega = np.abs(np.random.normal())
    return omega


def sample_lambda(graph):
    ncol = graph.sample_int("ncol", 5, 10)
    tau = graph.sample_var("tau", shape=(ncol,))
    lamb = np.random.normal(loc=0, scale=tau)
    return lamb


def sample_x(graph):
    nrow = graph.sample_int("nrow", 1, 10)
    ncol = graph.sample_int("ncol", 5, 10)  # cached between sample_lambda and sample_x
    lamb = graph.sample_var("lambda", shape=(1, ncol))
    omega = graph.sample_var("omega", shape=(1, 1))

    x = np.random.normal(loc=lamb, scale=omega, size=(nrow, ncol))
    return x


graph = GraphSimulator()

# each node returns exactly one variable
graph.add_node("tau", sample_tau)
graph.add_node("omega", sample_omega)
graph.add_node("lambda", sample_lambda)
graph.add_node("x", sample_x)

# returns a list of dicts
samples = graph.sample(10)

I already have a working implementation for this, but it might not exactly fit the needs of the rest of the library, so I would like to discuss first.

@paul-buerkner (Contributor) commented May 24, 2025

Thank you @daniel-habermann for your PR! I will review it (specifically the interface) in the next couple of days.

@LarsKue since Daniel has already spent quite a bit of time and thought on this PR, I would like to go with his implementation for now. If we end up not liking it for some reason, we can still discuss alternatives then.

@daniel-habermann (Contributor, Author) commented May 26, 2025

> I just thought about an alternative interface, would be glad for some feedback: [...] I already have a working implementation for this, but it might not exactly fit the needs of the rest of the library, so I would like to discuss first.

I'm not against radically changing the suggested interface, but one design consideration is consistency: in the current interface, a user can return dictionaries with arbitrary keys, so it would be quite difficult to explain why this is possible when using bf.simulators.make_simulator, but not when building a graphical simulator.

The same is true for sample_var: the current interface allows passing inputs as function parameters, so it would break consistency if we suddenly had to request parameters from within a function. There are still some inconsistencies in the approach outlined in this first PR; for example, sampling_function should at least be renamed to sampling_fn to stay consistent with meta_fn in the current interface.

What problem were you trying to resolve with your suggestion? If the concern is boilerplate and adding edges to the networks, I expect we can cover almost all cases with code introspection, i.e. all a user has to provide are function definitions, as in the usual make_simulator call, and the graph can then be inferred from the function arguments, because each function is a node and the edges can be recovered by matching the dictionary outputs to input arguments.
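
To sketch the introspection idea (the helper infer_edges and the explicit output_keys mapping are hypothetical, not part of this PR; the output keys could also be obtained from a single trial simulation):

import inspect

def infer_edges(functions, output_keys):
    # every function is a node; an edge parent -> child exists whenever the
    # child's arguments overlap with the keys the parent returns
    edges = []
    for child, fn in functions.items():
        params = set(inspect.signature(fn).parameters)
        for parent, keys in output_keys.items():
            if parent != child and params & set(keys):
                edges.append((parent, child))
    return edges

# using the sampling functions from the two-level example above
functions = {"tau": sample_tau, "omega": sample_omega,
             "lambda_j": sample_lambda_j, "x_ij": sample_x_ij}
output_keys = {"tau": ["tau"], "omega": ["omega"],
               "lambda_j": ["lambda_j"], "x_ij": ["x_ij"]}

infer_edges(functions, output_keys)
# [('tau', 'lambda_j'), ('omega', 'x_ij'), ('lambda_j', 'x_ij')]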

@daniel-habermann (Contributor, Author) commented Jun 20, 2025

Just tagging everyone :) @paul-buerkner @stefanradev93 @LarsKue @elseml @arrjon

I've finally reached a stage where I'm happy with the general design and thought it would be a good moment to get your feedback.
Here is an updated example of the two-level model mentioned above:

from bayesflow.experimental.graphical_simulator import GraphicalSimulator
import numpy as np

def sample_tau():
    tau = np.abs(np.random.normal())
    return dict(tau=tau)

def sample_omega():
    omega = np.abs(np.random.normal())
    return dict(omega=omega)

def sample_lambda_j(tau):
    lambda_j = np.abs(np.random.normal(loc=0, scale=tau))
    return dict(lambda_j=lambda_j)

def sample_x_ij(lambda_j, omega):
    x_ij = np.random.normal(loc=lambda_j, scale=omega)
    return dict(x_ij=x_ij)

def meta():
    return {
        "num_groups": np.random.randint(5, 10),
        "num_obs": np.random.randint(1, 10)
    }

simulator = GraphicalSimulator(meta_fn=meta)

simulator.add_node("tau", sampling_fn=sample_tau)
simulator.add_node("omega", sampling_fn=sample_omega)
simulator.add_node("lambda_j", sampling_fn=sample_lambda_j, reps="num_groups")
simulator.add_node("x_ij", sampling_fn=sample_x_ij, reps="num_obs")

simulator.add_edge("tau", "lambda_j")
simulator.add_edge("lambda_j", "x_ij")
simulator.add_edge("omega", "x_ij")

Major changes to the previous versions are:

  • sampling_function argument renamed to sampling_fn to stay consistent with the rest of the codebase
  • sample_size argument removed in favor of reps, which is either an int or a string. For string inputs, the node is repeated according to the corresponding entry of the dictionary returned by meta_fn.
  • calling the sample method is now without side effects (previously, samples were stored in the graph directly).
  • the sample method now has a return type of dict[str, np.ndarray].

The main design goal was consistency with our other interfaces. Concretely, I wanted the output of a single-level model implemented as a GraphicalSimulator to match that of a simulator created with make_simulator. This is currently the case:

    def prior():
        beta = np.random.normal([2, 0], [3, 1])
        sigma = np.random.gamma(1, 1)

        return {"beta": beta, "sigma": sigma}

    def likelihood(beta, sigma, N):
        x = np.random.normal(0, 1, size=N)
        y = np.random.normal(beta[0] + beta[1] * x, sigma, size=N)

        return {"x": x, "y": y}

    def meta():
        N = np.random.randint(5, 15)

        return {"N": N}

    simulator = GraphicalSimulator(meta_fn=meta)

    simulator.add_node("prior", sampling_fn=prior)
    simulator.add_node("likelihood", sampling_fn=likelihood)

    simulator.add_edge("prior", "likelihood")

    sim_draws = simulator.sample(500)
    sim_draws["N"] # 13
    sim_draws["beta"].shape # (500, 2)
    sim_draws["sigma"].shape # (500, 1)
    sim_draws["x"].shape # (500, 13)
    sim_draws["y"].shape # (500, 13)

Of course, the preferred way to define the number of observations would be to remove N as a function argument to likelihood instead of using the meta_fn approach, but both work. We also need both, because we want the ability to add arbitrary metadata.

To channel feedback, here is a list of points that I believe are most important to agree on (but of course all other comments are also highly welcome):

How do we vary the number of groups and observations during training?

Our current simulators just return data sets with a varying number of observations. This is fine for online training, but doesn't work well for offline training, which was always the default for my workflows.
In Bayesflow v1, the approach that worked best for me was to create one dataset with the maximum number of groups and observations and then use the configurator to subset this dataset for each batch. Do we want to allow a similar approach here, for example by extending the adapter with a subset functionality?

An alternative would be to allow simulation of non-rectangular datasets. The internal representation of the GraphicalSimulator uses a long format, which has been shown to be the most effective representation of hierarchical data in Stan and brms. Of course, the big question would then be how to efficiently pass non-rectangular data to the networks. I don't think there is a satisfying solution to this, which is why I prefer the "amortize via random subsetting" approach.
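
To make the subsetting idea concrete, here is a rough sketch (the function name and the (batch, groups, obs, ...) array layout are assumptions, not an existing API):

import numpy as np

def subsample(batch, num_groups, num_obs, rng=None):
    # batch: dict with "lambda_j" of shape (B, G, ...) and "x_ij" of shape (B, G, O, ...),
    # simulated once with the maximum number of groups G and observations O
    rng = rng or np.random.default_rng()
    group_idx = rng.choice(batch["x_ij"].shape[1], size=num_groups, replace=False)
    obs_idx = rng.choice(batch["x_ij"].shape[2], size=num_obs, replace=False)
    out = dict(batch)
    out["lambda_j"] = batch["lambda_j"][:, group_idx]
    out["x_ij"] = batch["x_ij"][:, group_idx][:, :, obs_idx]
    return out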

What is the output of GraphicalSimulator.sample?

Our current Simulator class demands the output to be dict[str, np.ndarray], which seems to work well for non-rectangular data.
But what exactly is the output shape of a variable? Right now, it is [batch_shape, reps of non-root ancestors, node reps if != 1, variable output shape]. This results in output shapes like [batch_shape, num_groups, num_obs, num_dims], as we previously had for the two-level simulator in BayesFlow v1, and also produces consistent outputs for single-level models. However, this rule might be a bit unintuitive, and something like [batch_shape, ancestor reps, node reps, variable shape] might be easier. A drawback of the latter would be that it is not consistent with the output of single-level models; for example, the output shape for "beta" in the example above would be (500, 1, 2) instead of (500, 2). I favor the current rule, but don't have a strong stance on this.
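
For illustration, the current rule applied to the two-level example above would give shapes roughly like the following, assuming meta() returned num_groups=7 and num_obs=4 (the trailing dimension of 1 for scalar variables is inferred from the single-level example and may differ in detail):

sim_draws = simulator.sample(500)
sim_draws["tau"].shape       # (500, 1)          root node, reps == 1
sim_draws["omega"].shape     # (500, 1)
sim_draws["lambda_j"].shape  # (500, 7, 1)       [batch, num_groups, dim]
sim_draws["x_ij"].shape      # (500, 7, 4, 1)    [batch, num_groups, num_obs, dim]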

Do we want to allow repetitions of root nodes?

This is adjacent to the previous point. Repetitions of root nodes would essentially be just another batch dimension. I'm leaning towards yes, but this is currently not implemented and also depends on the outcome of the previous questions.

Happy to hear your thoughts! :)

@stefanradev93 (Contributor) commented Jun 23, 2025

Hi Daniel, fine job on the interface! Here are some more-or-less detailed thoughts regarding your questions:

How do we vary the number of groups and observations during training?

  • As you mention, this is easy for online data sets. For disk / offline data, we either go with the max_obs approach plus a subsample transform, or with padding.

What is the output of GraphicalSimulator.sample?

  • I also favor the current rule. `[batch_shape, ...varying dims..., dim]` is what works well and is consistent with standard DL input-output signatures.

Do we want to allow repetitions of root nodes?

  • Yes, this would be needed, especially for state-space models.

@paul-buerkner (Contributor) commented

I discussed with @daniel-habermann this week and we agree with @stefanradev93. I believe everything is in place now. @daniel-habermann what are the next steps for this PR?
