Skip to content

Apply JuliaFormatter to fix code formatting #980

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/src/examples/augmented_neural_ode.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ function plot_contour(model, ps, st, npoints = 300)
x = range(-4.0f0, 4.0f0; length = npoints)
y = range(-4.0f0, 4.0f0; length = npoints)
for x1 in x, x2 in y

grid_points[:, idx] .= [x1, x2]
idx += 1
end
Expand Down Expand Up @@ -212,6 +213,7 @@ function plot_contour(model, ps, st, npoints = 300)
x = range(-4.0f0, 4.0f0; length = npoints)
y = range(-4.0f0, 4.0f0; length = npoints)
for x1 in x, x2 in y

grid_points[:, idx] .= [x1, x2]
idx += 1
end
Expand Down
12 changes: 8 additions & 4 deletions docs/src/examples/multiple_shooting.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,15 @@ function loss_function(data, pred)
return sum(abs2, data - pred)
end

l1, preds = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
l1,
preds = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
Tsit5(), group_size; continuity_term)

function loss_multiple_shooting(p)
ps = ComponentArray(p, pax)

loss, currpred = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
loss,
currpred = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
Tsit5(), group_size; continuity_term)
global preds = currpred
return loss
Expand All @@ -93,7 +95,8 @@ function callback(state, l; doplot = true, prob_node = prob_node)
# plot the original data
plt = scatter(tsteps, ode_data[1, :]; label = "Data")
# plot the different predictions for individual shoot
l1, preds = multiple_shoot(
l1,
preds = multiple_shoot(
ComponentArray(state.u, pax), ode_data, tsteps, prob_node, loss_function,
Tsit5(), group_size; continuity_term)
plot_multiple_shoot(plt, preds, group_size)
Expand Down Expand Up @@ -127,7 +130,8 @@ pd, pax = getdata(ps), getaxes(ps)

function loss_single_shooting(p)
ps = ComponentArray(p, pax)
loss, currpred = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
loss,
currpred = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
Tsit5(), group_size; continuity_term)
global preds = currpred
return loss
Expand Down
3 changes: 2 additions & 1 deletion docs/src/examples/neural_ode_weather_forecast.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ function train(t, y, obs_grid, maxiters, lr, rng, p = nothing, state = nothing;
p === nothing && (p = p_new)
state === nothing && (state = state_new)

p, state = train_one_round(node, p, state, y, OptimizationOptimisers.AdamW(lr),
p,
state = train_one_round(node, p, state, y, OptimizationOptimisers.AdamW(lr),
maxiters, rng; callback = log_results(ps, losses), kwargs...)
end
ps, state, losses
Expand Down
213 changes: 107 additions & 106 deletions test/cnf_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,25 @@ export callback
end

@testitem "Smoke test for FFJORD" setup=[CNFTestSetup] tags=[:advancedneuralde] begin
nn = Chain(Dense(1, 1, tanh))
tspan = (0.0f0, 1.0f0)
ffjord_mdl = FFJORD(nn, tspan, (1,), Tsit5())
ps, st = Lux.setup(Xoshiro(0), ffjord_mdl)
ps = ComponentArray(ps)
nn=Chain(Dense(1, 1, tanh))
tspan=(0.0f0, 1.0f0)
ffjord_mdl=FFJORD(nn, tspan, (1,), Tsit5())
ps, st=Lux.setup(Xoshiro(0), ffjord_mdl)
ps=ComponentArray(ps)

data_dist = Beta(2.0f0, 2.0f0)
train_data = Float32.(rand(data_dist, 1, 100))
data_dist=Beta(2.0f0, 2.0f0)
train_data=Float32.(rand(data_dist, 1, 100))

function loss(model, θ)
logpx, λ₁, λ₂ = model(train_data, θ)
logpx, λ₁, λ₂=model(train_data, θ)
return -mean(logpx)
end

@testset "ADType: $(adtype)" for adtype in (Optimization.AutoForwardDiff(),
Optimization.AutoReverseDiff(), Optimization.AutoTracker(),
Optimization.AutoZygote(), Optimization.AutoFiniteDiff())
@testset "regularize = $(regularize) & monte_carlo = $(monte_carlo)" for regularize in (
@testset "regularize = $(regularize) & monte_carlo = $(monte_carlo)" for regularize in
(
true, false),
monte_carlo in (true, false)

Expand All @@ -54,177 +55,177 @@ end
end

@testitem "Smoke test for FFJORDDistribution (sampling & pdf)" setup=[CNFTestSetup] tags=[:advancedneuralde] begin
nn = Chain(Dense(1, 1, tanh))
tspan = (0.0f0, 1.0f0)
ffjord_mdl = FFJORD(nn, tspan, (1,), Tsit5())
ps, st = Lux.setup(Xoshiro(0), ffjord_mdl)
ps = ComponentArray(ps)
nn=Chain(Dense(1, 1, tanh))
tspan=(0.0f0, 1.0f0)
ffjord_mdl=FFJORD(nn, tspan, (1,), Tsit5())
ps, st=Lux.setup(Xoshiro(0), ffjord_mdl)
ps=ComponentArray(ps)

regularize = false
monte_carlo = false
regularize=false
monte_carlo=false

data_dist = Beta(2.0f0, 2.0f0)
train_data = Float32.(rand(data_dist, 1, 100))
data_dist=Beta(2.0f0, 2.0f0)
train_data=Float32.(rand(data_dist, 1, 100))

function loss(model, θ)
logpx, λ₁, λ₂ = model(train_data, θ)
logpx, λ₁, λ₂=model(train_data, θ)
return -mean(logpx)
end

adtype = Optimization.AutoZygote()
adtype=Optimization.AutoZygote()

st_ = (; st..., regularize, monte_carlo)
model = StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)
st_=(; st..., regularize, monte_carlo)
model=StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)

optf = Optimization.OptimizationFunction((θ, _) -> loss(model, θ), adtype)
optprob = Optimization.OptimizationProblem(optf, ps)
res = Optimization.solve(optprob, Adam(0.1); callback = callback(adtype), maxiters = 10)
optf=Optimization.OptimizationFunction((θ, _)->loss(model, θ), adtype)
optprob=Optimization.OptimizationProblem(optf, ps)
res=Optimization.solve(optprob, Adam(0.1); callback = callback(adtype), maxiters = 10)

ffjord_d = FFJORDDistribution(ffjord_mdl, res.u, st_)
ffjord_d=FFJORDDistribution(ffjord_mdl, res.u, st_)

@test !isnothing(pdf(ffjord_d, train_data))
@test !isnothing(rand(ffjord_d))
@test !isnothing(rand(ffjord_d, 10))
end

@testitem "Test for default base distribution and deterministic trace FFJORD" setup=[CNFTestSetup] tags=[:advancedneuralde] begin
nn = Chain(Dense(1, 1, tanh))
tspan = (0.0f0, 1.0f0)
ffjord_mdl = FFJORD(nn, tspan, (1,), Tsit5())
ps, st = Lux.setup(Xoshiro(0), ffjord_mdl)
ps = ComponentArray(ps)
nn=Chain(Dense(1, 1, tanh))
tspan=(0.0f0, 1.0f0)
ffjord_mdl=FFJORD(nn, tspan, (1,), Tsit5())
ps, st=Lux.setup(Xoshiro(0), ffjord_mdl)
ps=ComponentArray(ps)

regularize = false
monte_carlo = false
regularize=false
monte_carlo=false

data_dist = Beta(7.0f0, 7.0f0)
train_data = Float32.(rand(data_dist, 1, 100))
test_data = Float32.(rand(data_dist, 1, 100))
data_dist=Beta(7.0f0, 7.0f0)
train_data=Float32.(rand(data_dist, 1, 100))
test_data=Float32.(rand(data_dist, 1, 100))

function loss(model, θ)
logpx, λ₁, λ₂ = model(train_data, θ)
logpx, λ₁, λ₂=model(train_data, θ)
return -mean(logpx)
end

adtype = Optimization.AutoZygote()
st_ = (; st..., regularize, monte_carlo)
model = StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)
adtype=Optimization.AutoZygote()
st_=(; st..., regularize, monte_carlo)
model=StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)

optf = Optimization.OptimizationFunction((θ, _) -> loss(model, θ), adtype)
optprob = Optimization.OptimizationProblem(optf, ps)
res = Optimization.solve(optprob, Adam(0.1); callback = callback(adtype), maxiters = 10)
optf=Optimization.OptimizationFunction((θ, _)->loss(model, θ), adtype)
optprob=Optimization.OptimizationProblem(optf, ps)
res=Optimization.solve(optprob, Adam(0.1); callback = callback(adtype), maxiters = 10)

actual_pdf = pdf.(data_dist, test_data)
learned_pdf = exp.(model(test_data, res.u)[1])
actual_pdf=pdf.(data_dist, test_data)
learned_pdf=exp.(model(test_data, res.u)[1])

@test ps != res.u
@test totalvariation(learned_pdf, actual_pdf) / size(test_data, 2) < 0.9
end

@testitem "Test for alternative base distribution and deterministic trace FFJORD" setup=[CNFTestSetup] tags=[:advancedneuralde] begin
nn = Chain(Dense(1, 3, tanh), Dense(3, 1, tanh))
tspan = (0.0f0, 1.0f0)
ffjord_mdl = FFJORD(
nn=Chain(Dense(1, 3, tanh), Dense(3, 1, tanh))
tspan=(0.0f0, 1.0f0)
ffjord_mdl=FFJORD(
nn, tspan, (1,), Tsit5(); basedist = MvNormal([0.0f0], Diagonal([4.0f0])))
ps, st = Lux.setup(Xoshiro(0), ffjord_mdl)
ps = ComponentArray(ps)
ps, st=Lux.setup(Xoshiro(0), ffjord_mdl)
ps=ComponentArray(ps)

regularize = false
monte_carlo = false
regularize=false
monte_carlo=false

data_dist = Normal(6.0f0, 0.7f0)
train_data = Float32.(rand(data_dist, 1, 100))
test_data = Float32.(rand(data_dist, 1, 100))
data_dist=Normal(6.0f0, 0.7f0)
train_data=Float32.(rand(data_dist, 1, 100))
test_data=Float32.(rand(data_dist, 1, 100))

function loss(model, θ)
logpx, λ₁, λ₂ = model(train_data, θ)
logpx, λ₁, λ₂=model(train_data, θ)
return -mean(logpx)
end

adtype = Optimization.AutoZygote()
st_ = (; st..., regularize, monte_carlo)
model = StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)
adtype=Optimization.AutoZygote()
st_=(; st..., regularize, monte_carlo)
model=StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)

optf = Optimization.OptimizationFunction((θ, _) -> loss(model, θ), adtype)
optprob = Optimization.OptimizationProblem(optf, ps)
res = Optimization.solve(optprob, Adam(0.1); callback = callback(adtype), maxiters = 30)
optf=Optimization.OptimizationFunction((θ, _)->loss(model, θ), adtype)
optprob=Optimization.OptimizationProblem(optf, ps)
res=Optimization.solve(optprob, Adam(0.1); callback = callback(adtype), maxiters = 30)

actual_pdf = pdf.(data_dist, test_data)
learned_pdf = exp.(model(test_data, res.u)[1])
actual_pdf=pdf.(data_dist, test_data)
learned_pdf=exp.(model(test_data, res.u)[1])

@test ps != res.u
@test totalvariation(learned_pdf, actual_pdf) / size(test_data, 2) < 0.25
end

@testitem "Test for multivariate distribution and deterministic trace FFJORD" setup=[CNFTestSetup] tags=[:advancedneuralde] begin
nn = Chain(Dense(2, 2, tanh))
tspan = (0.0f0, 1.0f0)
ffjord_mdl = FFJORD(nn, tspan, (2,), Tsit5())
ps, st = Lux.setup(Xoshiro(0), ffjord_mdl)
ps = ComponentArray(ps)
nn=Chain(Dense(2, 2, tanh))
tspan=(0.0f0, 1.0f0)
ffjord_mdl=FFJORD(nn, tspan, (2,), Tsit5())
ps, st=Lux.setup(Xoshiro(0), ffjord_mdl)
ps=ComponentArray(ps)

regularize = false
monte_carlo = false
regularize=false
monte_carlo=false

μ = ones(Float32, 2)
Σ = Diagonal([7.0f0, 7.0f0])
data_dist = MvNormal(μ, Σ)
train_data = Float32.(rand(data_dist, 100))
test_data = Float32.(rand(data_dist, 100))
μ=ones(Float32, 2)
Σ=Diagonal([7.0f0, 7.0f0])
data_dist=MvNormal(μ, Σ)
train_data=Float32.(rand(data_dist, 100))
test_data=Float32.(rand(data_dist, 100))

function loss(model, θ)
logpx, λ₁, λ₂ = model(train_data, θ)
logpx, λ₁, λ₂=model(train_data, θ)
return -mean(logpx)
end

adtype = Optimization.AutoZygote()
st_ = (; st..., regularize, monte_carlo)
model = StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)
adtype=Optimization.AutoZygote()
st_=(; st..., regularize, monte_carlo)
model=StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)

optf = Optimization.OptimizationFunction((θ, _) -> loss(model, θ), adtype)
optprob = Optimization.OptimizationProblem(optf, ps)
res = Optimization.solve(
optf=Optimization.OptimizationFunction((θ, _)->loss(model, θ), adtype)
optprob=Optimization.OptimizationProblem(optf, ps)
res=Optimization.solve(
optprob, Adam(0.01); callback = callback(adtype), maxiters = 30)

actual_pdf = pdf(data_dist, test_data)
learned_pdf = exp.(model(test_data, res.u)[1])
actual_pdf=pdf(data_dist, test_data)
learned_pdf=exp.(model(test_data, res.u)[1])

@test ps != res.u
@test totalvariation(learned_pdf, actual_pdf) / size(test_data, 2) < 0.25
end

@testitem "Test for multivariate distribution and FFJORD with regularizers" setup=[CNFTestSetup] tags=[:advancedneuralde] begin
nn = Chain(Dense(2, 2, tanh))
tspan = (0.0f0, 1.0f0)
ffjord_mdl = FFJORD(nn, tspan, (2,), Tsit5())
ps, st = Lux.setup(Xoshiro(0), ffjord_mdl)
ps = ComponentArray(ps) .* 0.001f0
nn=Chain(Dense(2, 2, tanh))
tspan=(0.0f0, 1.0f0)
ffjord_mdl=FFJORD(nn, tspan, (2,), Tsit5())
ps, st=Lux.setup(Xoshiro(0), ffjord_mdl)
ps=ComponentArray(ps) .* 0.001f0

regularize = true
monte_carlo = true
regularize=true
monte_carlo=true

μ = ones(Float32, 2)
Σ = Diagonal([7.0f0, 7.0f0])
data_dist = MvNormal(μ, Σ)
train_data = Float32.(rand(data_dist, 100))
test_data = Float32.(rand(data_dist, 100))
μ=ones(Float32, 2)
Σ=Diagonal([7.0f0, 7.0f0])
data_dist=MvNormal(μ, Σ)
train_data=Float32.(rand(data_dist, 100))
test_data=Float32.(rand(data_dist, 100))

function loss(model, θ)
logpx, λ₁, λ₂ = model(train_data, θ)
logpx, λ₁, λ₂=model(train_data, θ)
return -mean(logpx)
end

adtype = Optimization.AutoZygote()
st_ = (; st..., regularize, monte_carlo)
model = StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)
adtype=Optimization.AutoZygote()
st_=(; st..., regularize, monte_carlo)
model=StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)

optf = Optimization.OptimizationFunction((θ, _) -> loss(model, θ), adtype)
optprob = Optimization.OptimizationProblem(optf, ps)
res = Optimization.solve(
optf=Optimization.OptimizationFunction((θ, _)->loss(model, θ), adtype)
optprob=Optimization.OptimizationProblem(optf, ps)
res=Optimization.solve(
optprob, Adam(0.01); callback = callback(adtype), maxiters = 30)

actual_pdf = pdf(data_dist, test_data)
learned_pdf = exp.(model(test_data, res.u)[1])
actual_pdf=pdf(data_dist, test_data)
learned_pdf=exp.(model(test_data, res.u)[1])

@test ps != res.u
@test totalvariation(learned_pdf, actual_pdf) / size(test_data, 2) < 0.25
Expand Down
9 changes: 6 additions & 3 deletions test/multiple_shoot_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@
(
name = "Multi-D Test Config",
u0 = Float32[2.0 0.0; 1.0 1.5; 0.5 -1.0],
ode_func = (du, u, p, t) -> (du .= ((u .^ 3).*[-0.01 0.02; -0.02 -0.01; 0.01 -0.05])),
ode_func = (
du, u, p, t) -> (du .= ((u .^ 3) .* [-0.01 0.02; -0.02 -0.01; 0.01 -0.05])),
nn = Chain(x -> x .^ 3, Dense(3 => 3, tanh)),
u0s_ensemble = [Float32[2.0 0.0; 1.0 1.5; 0.5 -1.0], Float32[3.0 1.0; 2.0 0.5; 1.5 -0.5]]
u0s_ensemble = [
Float32[2.0 0.0; 1.0 1.5; 0.5 -1.0], Float32[3.0 1.0; 2.0 0.5; 1.5 -0.5]]
)
]

Expand Down Expand Up @@ -158,7 +160,8 @@
group_size = 3
continuity_term = 200
function loss_multiple_shooting_ens(p)
return multiple_shoot(p, ode_data_ensemble, tsteps, ensemble_prob, ensemble_alg,
return multiple_shoot(
p, ode_data_ensemble, tsteps, ensemble_prob, ensemble_alg,
loss_function, Tsit5(), group_size; continuity_term,
trajectories, abstol = 1e-8, reltol = 1e-6)[1]
end
Expand Down
Loading
Loading