Skip to content

Extend multiple_shoot loss to multidimensional NeuralODEs #974

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions src/multiple_shooting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ Arguments:
function multiple_shoot(p, ode_data, tsteps, prob::ODEProblem, loss_function::F,
continuity_loss::C, solver::SciMLBase.AbstractODEAlgorithm,
group_size::Integer; continuity_term::Real = 100, kwargs...) where {F, C}
datasize = size(ode_data, 2)
datasize = size(ode_data, ndims(ode_data))
griddims = ntuple(_ -> Colon(), ndims(ode_data) - 1)

if group_size < 2 || group_size > datasize
throw(DomainError(group_size, "group_size can't be < 2 or > number of data points"))
Expand All @@ -48,7 +49,7 @@ function multiple_shoot(p, ode_data, tsteps, prob::ODEProblem, loss_function::F,
# Multiple shooting predictions
sols = [solve(
remake(prob; p, tspan = (tsteps[first(rg)], tsteps[last(rg)]),
u0 = ode_data[:, first(rg)]),
u0 = ode_data[griddims..., first(rg)]),
solver;
saveat = tsteps[rg],
kwargs...) for rg in ranges]
Expand All @@ -61,15 +62,15 @@ function multiple_shoot(p, ode_data, tsteps, prob::ODEProblem, loss_function::F,
# Calculate multiple shooting loss
loss = 0
for (i, rg) in enumerate(ranges)
u = ode_data[:, rg]
û = group_predictions[i]
u = ode_data[griddims..., rg]
û = group_predictions[i][griddims..., :]
loss += loss_function(u, û)

if i > 1
# Ensure continuity between last state in previous prediction
# and current initial condition in ode_data
loss += continuity_term *
continuity_loss(group_predictions[i - 1][:, end], u[:, 1])
continuity_loss(group_predictions[i - 1][griddims..., end], u[griddims..., 1])
end
end

Expand Down Expand Up @@ -121,16 +122,18 @@ function multiple_shoot(p, ode_data, tsteps, ensembleprob::EnsembleProblem,
ensemblealg::SciMLBase.BasicEnsembleAlgorithm, loss_function::F,
continuity_loss::C, solver::SciMLBase.AbstractODEAlgorithm,
group_size::Integer; continuity_term::Real = 100, kwargs...) where {F, C}
datasize = size(ode_data, 2)
ntraj = size(ode_data, ndims(ode_data))
datasize = size(ode_data, ndims(ode_data)-1)
griddims = ntuple(_ -> Colon(), ndims(ode_data) - 2)
prob = ensembleprob.prob

if group_size < 2 || group_size > datasize
throw(DomainError(group_size, "group_size can't be < 2 or > number of data points"))
end

@assert ndims(ode_data)==3 "ode_data must have three dimension: `size(ode_data) = (problem_dimension,length(tsteps),trajectories)"
@assert size(ode_data, 2) == length(tsteps)
@assert size(ode_data, 3) == kwargs[:trajectories]
@assert ndims(ode_data)>=3 "ode_data must have at least three dimension: `size(ode_data) = (problem_dimension,length(tsteps),trajectories)"
@assert datasize == length(tsteps)
@assert ntraj == kwargs[:trajectories]

# Get ranges that partition data to groups of size group_size
ranges = group_ranges(datasize, group_size)
Expand All @@ -140,7 +143,7 @@ function multiple_shoot(p, ode_data, tsteps, ensembleprob::EnsembleProblem,
rg -> begin
newprob = remake(prob; p = p, tspan = (tsteps[first(rg)], tsteps[last(rg)]))
function prob_func(prob, i, repeat)
remake(prob; u0 = ode_data[:, first(rg), i])
remake(prob; u0 = ode_data[griddims..., first(rg), i])
end
newensembleprob = EnsembleProblem(
newprob, prob_func, ensembleprob.output_func, ensembleprob.reduction,
Expand All @@ -158,7 +161,7 @@ function multiple_shoot(p, ode_data, tsteps, ensembleprob::EnsembleProblem,
loss = 0
for (i, rg) in enumerate(ranges)
û = group_predictions[i]
u = ode_data[:, rg, :] # trajectories are at dims 3
u = ode_data[griddims..., rg, :] # trajectories are along the last dimension
# just summing up losses for all trajectories
# but other alternatives might be considered

Expand All @@ -168,7 +171,7 @@ function multiple_shoot(p, ode_data, tsteps, ensembleprob::EnsembleProblem,
# Ensure continuity between last state in previous prediction
# and current initial condition in ode_data
loss += continuity_term *
continuity_loss(group_predictions[i - 1][:, end, :], u[:, 1, :])
continuity_loss(group_predictions[i - 1][griddims..., end, :], u[griddims..., 1, :])
end
end

Expand Down
305 changes: 162 additions & 143 deletions test/multiple_shoot_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,148 +12,167 @@
@test_throws DomainError group_ranges(10, 1)
@test_throws DomainError group_ranges(10, 11)

## Define initial conditions and time steps
datasize = 30
u0 = Float32[2.0, 0.0]
tspan = (0.0f0, 5.0f0)
tsteps = range(tspan[1], tspan[2]; length = datasize)

# Get the data
function trueODEfunc(du, u, p, t)
true_A = [-0.1 2.0; -2.0 -0.1]
du .= ((u .^ 3)'true_A)'
# Test configurations
test_configs = [
(
name = "Vector Test Config",
u0 = Float32[2.0, 0.0],
ode_func = (du, u, p, t) -> (du .= ((u .^ 3)'*[-0.1 2.0; -2.0 -0.1])'),
nn = Chain(x -> x .^ 3, Dense(2 => 16, tanh), Dense(16 => 2)),
u0s_ensemble = [Float32[2.0, 0.0], Float32[3.0, 1.0]]
),
(
name = "Multi-D Test Config",
u0 = Float32[2.0 0.0; 1.0 1.5; 0.5 -1.0],
ode_func = (du, u, p, t) -> (du .= ((u .^ 3).*[-0.01 0.02; -0.02 -0.01; 0.01 -0.05])),
nn = Chain(x -> x .^ 3, Dense(3 => 3, tanh)),
u0s_ensemble = [Float32[2.0 0.0; 1.0 1.5; 0.5 -1.0], Float32[3.0 1.0; 2.0 0.5; 1.5 -0.5]]
)
]

for config in test_configs
@info "Running tests for: $(config.name)"

## Define initial conditions and time steps
datasize = 30
u0 = config.u0
tspan = (0.0f0, 5.0f0)
tsteps = range(tspan[1], tspan[2]; length = datasize)

# Get the data
trueODEfunc = config.ode_func
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps))

# Define the Neural Network
nn = config.nn
p_init, st = Lux.setup(rng, nn)
p_init = ComponentArray(p_init)

neuralode = NeuralODE(nn, tspan, Tsit5(); saveat = tsteps)
prob_node = ODEProblem((u, p, t) -> first(nn(u, p, st)), u0, tspan, p_init)

# Single-shooting forward pass: call `neuralode` on `u0` with parameters `p` and
# state `st`, keep the first element of what it returns, and convert it to an `Array`.
function predict_single_shooting(p)
    return Array(first(neuralode(u0, p, st)))
end

# Define loss function
# Sum of squared element-wise differences between `data` and `pred`.
function loss_function(data, pred)
    return sum(abs2, data - pred)
end

## Evaluate Single Shooting
# Single-shooting loss for parameters `p`: `loss_function` applied to `ode_data`
# and the prediction produced by `predict_single_shooting(p)`.
loss_single_shooting(p) = loss_function(ode_data, predict_single_shooting(p))

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((p, _) -> loss_single_shooting(p), adtype)
optprob = Optimization.OptimizationProblem(optf, p_init)
res_single_shooting = Optimization.solve(optprob, Adam(0.05); maxiters = 300)

loss_ss = loss_single_shooting(res_single_shooting.minimizer)
@info "Single shooting loss: $(loss_ss)"

## Test Multiple Shooting
group_size = 3
continuity_term = 200

# Multiple-shooting objective for parameters `p`; no custom continuity loss is
# passed here. Only element 1 of the value returned by `multiple_shoot` is kept.
function loss_multiple_shooting(p)
    # Tolerances are forwarded as keywords to test that solver kwargs are accepted.
    local outcome = multiple_shoot(
        p, ode_data, tsteps, prob_node, loss_function, Tsit5(), group_size;
        continuity_term = continuity_term, abstol = 1e-8, reltol = 1e-6)
    return outcome[1]
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((p, _) -> loss_multiple_shooting(p), adtype)
optprob = Optimization.OptimizationProblem(optf, p_init)
res_ms = Optimization.solve(optprob, Adam(0.05); maxiters = 300)

# Calculate single shooting loss with parameter from multiple_shoot training
loss_ms = loss_single_shooting(res_ms.minimizer)
println("Multiple shooting loss: $(loss_ms)")
@test loss_ms < 10loss_ss

# Test with custom loss function
group_size = 4
continuity_term = 50

# Custom continuity penalty: sum of squared differences between `û_end` and `u_0`.
# Uses `abs2` instead of the default `abs`.
continuity_loss_abs2(û_end, u_0) = sum(abs2, û_end - u_0)

# Multiple-shooting objective for parameters `p` that passes `continuity_loss_abs2`
# as the continuity loss. Only element 1 of the returned value is kept.
function loss_multiple_shooting_abs2(p)
    local outcome = multiple_shoot(
        p, ode_data, tsteps, prob_node, loss_function, continuity_loss_abs2,
        Tsit5(), group_size; continuity_term = continuity_term)
    return outcome[1]
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction(
(p, _) -> loss_multiple_shooting_abs2(p), adtype)
optprob = Optimization.OptimizationProblem(optf, p_init)
res_ms_abs2 = Optimization.solve(optprob, Adam(0.05); maxiters = 300)

loss_ms_abs2 = loss_single_shooting(res_ms_abs2.minimizer)
println("Multiple shooting loss with abs2: $(loss_ms_abs2)")
@test loss_ms_abs2 < loss_ss

## Test different SensitivityAlgorithm (default is InterpolatingAdjoint)
# Multiple-shooting objective for parameters `p` with `continuity_loss_abs2` as the
# continuity loss and `sensealg = ForwardDiffSensitivity()` forwarded as a keyword.
# Only element 1 of the returned value is kept.
function loss_multiple_shooting_fd(p)
    local outcome = multiple_shoot(
        p, ode_data, tsteps, prob_node, loss_function, continuity_loss_abs2,
        Tsit5(), group_size;
        continuity_term = continuity_term, sensealg = ForwardDiffSensitivity())
    return outcome[1]
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((p, _) -> loss_multiple_shooting_fd(p), adtype)
optprob = Optimization.OptimizationProblem(optf, p_init)
res_ms_fd = Optimization.solve(optprob, Adam(0.05); maxiters = 300)

# Calculate single shooting loss with parameter from multiple_shoot training
loss_ms_fd = loss_single_shooting(res_ms_fd.minimizer)
println("Multiple shooting loss with ForwardDiffSensitivity: $(loss_ms_fd)")
@test loss_ms_fd < 10loss_ss

# Integration return codes `!= :Success` should return infinite loss.
# In this case, we trigger `retcode = :MaxIters` by setting the solver option `maxiters=1`.
loss_fail = multiple_shoot(p_init, ode_data, tsteps, prob_node, loss_function,
Tsit5(), datasize; maxiters = 1, verbose = false)[1]
@test loss_fail == Inf

## Test for DomainErrors
@test_throws DomainError multiple_shoot(
p_init, ode_data, tsteps, prob_node, loss_function, Tsit5(), 1)
@test_throws DomainError multiple_shoot(
p_init, ode_data, tsteps, prob_node, loss_function, Tsit5(), datasize + 1)

## Ensembles
u0s = config.u0s_ensemble
# Per-trajectory problem builder: remake `prob` with the `i`-th entry of `u0s` as
# its initial condition. The `repeat` argument is not used.
prob_func(prob, i, repeat) = remake(prob; u0 = u0s[i])
ensemble_prob = EnsembleProblem(prob_node; prob_func = prob_func)
ensemble_prob_trueODE = EnsembleProblem(prob_trueode; prob_func = prob_func)
ensemble_alg = EnsembleThreads()
trajectories = 2
ode_data_ensemble = Array(solve(
ensemble_prob_trueODE, Tsit5(), ensemble_alg; trajectories, saveat = tsteps))

group_size = 3
continuity_term = 200
# Multiple-shooting objective for parameters `p` on the ensemble problem, evaluated
# against `ode_data_ensemble`. Only element 1 of the returned value is kept.
function loss_multiple_shooting_ens(p)
    local outcome = multiple_shoot(
        p, ode_data_ensemble, tsteps, ensemble_prob, ensemble_alg, loss_function,
        Tsit5(), group_size;
        continuity_term = continuity_term, trajectories = trajectories,
        abstol = 1e-8, reltol = 1e-6)
    return outcome[1]
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction(
(p, _) -> loss_multiple_shooting_ens(p), adtype)
optprob = Optimization.OptimizationProblem(optf, p_init)
res_ms_ensembles = Optimization.solve(optprob, Adam(0.05); maxiters = 300)

loss_ms_ensembles = loss_single_shooting(res_ms_ensembles.minimizer)

println("Multiple shooting loss with EnsembleProblem: $(loss_ms_ensembles)")

@test loss_ms_ensembles < 10loss_ss
end
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps))

# Define the Neural Network
nn = Chain(x -> x .^ 3, Dense(2 => 16, tanh), Dense(16 => 2))
p_init, st = Lux.setup(rng, nn)
p_init = ComponentArray(p_init)

neuralode = NeuralODE(nn, tspan, Tsit5(); saveat = tsteps)
prob_node = ODEProblem((u, p, t) -> first(nn(u, p, st)), u0, tspan, p_init)

# Single-shooting forward pass: call `neuralode` on `u0` with parameters `p` and
# state `st`, keep the first element of what it returns, and convert it to an `Array`.
function predict_single_shooting(p)
    return Array(first(neuralode(u0, p, st)))
end

# Define loss function
# Sum of squared element-wise differences between `data` and `pred`.
function loss_function(data, pred)
    return sum(abs2, data - pred)
end

## Evaluate Single Shooting
# Single-shooting loss for parameters `p`: `loss_function` applied to `ode_data`
# and the prediction produced by `predict_single_shooting(p)`.
loss_single_shooting(p) = loss_function(ode_data, predict_single_shooting(p))

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((p, _) -> loss_single_shooting(p), adtype)
optprob = Optimization.OptimizationProblem(optf, p_init)
res_single_shooting = Optimization.solve(optprob, Adam(0.05); maxiters = 300)

loss_ss = loss_single_shooting(res_single_shooting.minimizer)
@info "Single shooting loss: $(loss_ss)"

## Test Multiple Shooting
group_size = 3
continuity_term = 200

# Multiple-shooting objective for parameters `p`; no custom continuity loss is
# passed here. Only element 1 of the value returned by `multiple_shoot` is kept.
function loss_multiple_shooting(p)
    # Tolerances are forwarded as keywords to test that solver kwargs are accepted.
    local outcome = multiple_shoot(
        p, ode_data, tsteps, prob_node, loss_function, Tsit5(), group_size;
        continuity_term = continuity_term, abstol = 1e-8, reltol = 1e-6)
    return outcome[1]
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((p, _) -> loss_multiple_shooting(p), adtype)
optprob = Optimization.OptimizationProblem(optf, p_init)
res_ms = Optimization.solve(optprob, Adam(0.05); maxiters = 300)

# Calculate single shooting loss with parameter from multiple_shoot training
loss_ms = loss_single_shooting(res_ms.minimizer)
println("Multiple shooting loss: $(loss_ms)")
@test loss_ms < 10loss_ss

# Test with custom loss function
group_size = 4
continuity_term = 50

# Custom continuity penalty: sum of squared differences between `û_end` and `u_0`.
# Uses `abs2` instead of the default `abs`.
continuity_loss_abs2(û_end, u_0) = sum(abs2, û_end - u_0)

# Multiple-shooting objective for parameters `p` that passes `continuity_loss_abs2`
# as the continuity loss. Only element 1 of the returned value is kept.
function loss_multiple_shooting_abs2(p)
    local outcome = multiple_shoot(
        p, ode_data, tsteps, prob_node, loss_function, continuity_loss_abs2,
        Tsit5(), group_size; continuity_term = continuity_term)
    return outcome[1]
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction(
(p, _) -> loss_multiple_shooting_abs2(p), adtype)
optprob = Optimization.OptimizationProblem(optf, p_init)
res_ms_abs2 = Optimization.solve(optprob, Adam(0.05); maxiters = 300)

loss_ms_abs2 = loss_single_shooting(res_ms_abs2.minimizer)
println("Multiple shooting loss with abs2: $(loss_ms_abs2)")
@test loss_ms_abs2 < loss_ss

## Test different SensitivityAlgorithm (default is InterpolatingAdjoint)
# Multiple-shooting objective for parameters `p` with `continuity_loss_abs2` as the
# continuity loss and `sensealg = ForwardDiffSensitivity()` forwarded as a keyword.
# Only element 1 of the returned value is kept.
function loss_multiple_shooting_fd(p)
    local outcome = multiple_shoot(
        p, ode_data, tsteps, prob_node, loss_function, continuity_loss_abs2,
        Tsit5(), group_size;
        continuity_term = continuity_term, sensealg = ForwardDiffSensitivity())
    return outcome[1]
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((p, _) -> loss_multiple_shooting_fd(p), adtype)
optprob = Optimization.OptimizationProblem(optf, p_init)
res_ms_fd = Optimization.solve(optprob, Adam(0.05); maxiters = 300)

# Calculate single shooting loss with parameter from multiple_shoot training
loss_ms_fd = loss_single_shooting(res_ms_fd.minimizer)
println("Multiple shooting loss with ForwardDiffSensitivity: $(loss_ms_fd)")
@test loss_ms_fd < 10loss_ss

# Integration return codes `!= :Success` should return infinite loss.
# In this case, we trigger `retcode = :MaxIters` by setting the solver option `maxiters=1`.
loss_fail = multiple_shoot(p_init, ode_data, tsteps, prob_node, loss_function,
Tsit5(), datasize; maxiters = 1, verbose = false)[1]
@test loss_fail == Inf

## Test for DomainErrors
@test_throws DomainError multiple_shoot(
p_init, ode_data, tsteps, prob_node, loss_function, Tsit5(), 1)
@test_throws DomainError multiple_shoot(
p_init, ode_data, tsteps, prob_node, loss_function, Tsit5(), datasize + 1)

## Ensembles
u0s = [Float32[2.0, 0.0], Float32[3.0, 1.0]]
# Per-trajectory problem builder: remake `prob` with the `i`-th entry of `u0s` as
# its initial condition. The `repeat` argument is not used.
prob_func(prob, i, repeat) = remake(prob; u0 = u0s[i])
ensemble_prob = EnsembleProblem(prob_node; prob_func = prob_func)
ensemble_prob_trueODE = EnsembleProblem(prob_trueode; prob_func = prob_func)
ensemble_alg = EnsembleThreads()
trajectories = 2
ode_data_ensemble = Array(solve(
ensemble_prob_trueODE, Tsit5(), ensemble_alg; trajectories, saveat = tsteps))

group_size = 3
continuity_term = 200
# Multiple-shooting objective for parameters `p` on the ensemble problem, evaluated
# against `ode_data_ensemble`. Only element 1 of the returned value is kept.
function loss_multiple_shooting_ens(p)
    # Tolerances are forwarded as keywords to test that solver kwargs are accepted.
    local outcome = multiple_shoot(
        p, ode_data_ensemble, tsteps, ensemble_prob, ensemble_alg, loss_function,
        Tsit5(), group_size;
        continuity_term = continuity_term, trajectories = trajectories,
        abstol = 1e-8, reltol = 1e-6)
    return outcome[1]
end

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction(
(p, _) -> loss_multiple_shooting_ens(p), adtype)
optprob = Optimization.OptimizationProblem(optf, p_init)
res_ms_ensembles = Optimization.solve(optprob, Adam(0.05); maxiters = 300)

loss_ms_ensembles = loss_single_shooting(res_ms_ensembles.minimizer)

println("Multiple shooting loss with EnsembleProblem: $(loss_ms_ensembles)")

@test loss_ms_ensembles < 10loss_ss
end
Loading