diff --git a/docs/src/examples/augmented_neural_ode.md b/docs/src/examples/augmented_neural_ode.md
index 7799243597..a1ddf433c6 100644
--- a/docs/src/examples/augmented_neural_ode.md
+++ b/docs/src/examples/augmented_neural_ode.md
@@ -61,6 +61,7 @@ function plot_contour(model, ps, st, npoints = 300)
     x = range(-4.0f0, 4.0f0; length = npoints)
     y = range(-4.0f0, 4.0f0; length = npoints)
     for x1 in x, x2 in y
+
         grid_points[:, idx] .= [x1, x2]
         idx += 1
     end
@@ -212,6 +213,7 @@ function plot_contour(model, ps, st, npoints = 300)
     x = range(-4.0f0, 4.0f0; length = npoints)
     y = range(-4.0f0, 4.0f0; length = npoints)
     for x1 in x, x2 in y
+
         grid_points[:, idx] .= [x1, x2]
         idx += 1
     end
diff --git a/docs/src/examples/multiple_shooting.md b/docs/src/examples/multiple_shooting.md
index 6d08f5eed2..3815eb4572 100644
--- a/docs/src/examples/multiple_shooting.md
+++ b/docs/src/examples/multiple_shooting.md
@@ -62,13 +62,15 @@ function loss_function(data, pred)
     return sum(abs2, data - pred)
 end
 
-l1, preds = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
+l1,
+preds = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
     Tsit5(), group_size; continuity_term)
 
 function loss_multiple_shooting(p)
     ps = ComponentArray(p, pax)
 
-    loss, currpred = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
+    loss,
+    currpred = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
         Tsit5(), group_size; continuity_term)
     global preds = currpred
     return loss
@@ -93,7 +95,8 @@ function callback(state, l; doplot = true, prob_node = prob_node)
     # plot the original data
     plt = scatter(tsteps, ode_data[1, :]; label = "Data")
     # plot the different predictions for individual shoot
-    l1, preds = multiple_shoot(
+    l1,
+    preds = multiple_shoot(
         ComponentArray(state.u, pax), ode_data, tsteps, prob_node, loss_function,
         Tsit5(), group_size; continuity_term)
     plot_multiple_shoot(plt, preds, group_size)
@@ -127,7 +130,8 @@ pd, pax = getdata(ps), getaxes(ps)
 
 function loss_single_shooting(p)
     ps = ComponentArray(p, pax)
-    loss, currpred = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
+    loss,
+    currpred = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
         Tsit5(), group_size; continuity_term)
     global preds = currpred
     return loss
diff --git a/docs/src/examples/neural_ode_weather_forecast.md b/docs/src/examples/neural_ode_weather_forecast.md
index 8f69acfe5c..be7e519e42 100644
--- a/docs/src/examples/neural_ode_weather_forecast.md
+++ b/docs/src/examples/neural_ode_weather_forecast.md
@@ -134,7 +134,8 @@ function train(t, y, obs_grid, maxiters, lr, rng, p = nothing, state = nothing;
         p === nothing && (p = p_new)
         state === nothing && (state = state_new)
 
-        p, state = train_one_round(node, p, state, y, OptimizationOptimisers.AdamW(lr),
+        p,
+        state = train_one_round(node, p, state, y, OptimizationOptimisers.AdamW(lr),
             maxiters, rng; callback = log_results(ps, losses), kwargs...)
     end
     ps, state, losses
diff --git a/test/cnf_tests.jl b/test/cnf_tests.jl
index 322b1647d8..08e8264ff3 100644
--- a/test/cnf_tests.jl
+++ b/test/cnf_tests.jl
@@ -20,24 +20,25 @@ export callback
 end
 
 @testitem "Smoke test for FFJORD" setup=[CNFTestSetup] tags=[:advancedneuralde] begin
-    nn = Chain(Dense(1, 1, tanh))
-    tspan = (0.0f0, 1.0f0)
-    ffjord_mdl = FFJORD(nn, tspan, (1,), Tsit5())
-    ps, st = Lux.setup(Xoshiro(0), ffjord_mdl)
-    ps = ComponentArray(ps)
+    nn=Chain(Dense(1, 1, tanh))
+    tspan=(0.0f0, 1.0f0)
+    ffjord_mdl=FFJORD(nn, tspan, (1,), Tsit5())
+    ps, st=Lux.setup(Xoshiro(0), ffjord_mdl)
+    ps=ComponentArray(ps)
 
-    data_dist = Beta(2.0f0, 2.0f0)
-    train_data = Float32.(rand(data_dist, 1, 100))
+    data_dist=Beta(2.0f0, 2.0f0)
+    train_data=Float32.(rand(data_dist, 1, 100))
 
     function loss(model, θ)
-        logpx, λ₁, λ₂ = model(train_data, θ)
+        logpx, λ₁, λ₂=model(train_data, θ)
         return -mean(logpx)
     end
 
     @testset "ADType: $(adtype)" for adtype in (Optimization.AutoForwardDiff(),
         Optimization.AutoReverseDiff(), Optimization.AutoTracker(),
         Optimization.AutoZygote(), Optimization.AutoFiniteDiff())
-        @testset "regularize = $(regularize) & monte_carlo = $(monte_carlo)" for regularize in (
+        @testset "regularize = $(regularize) & monte_carlo = $(monte_carlo)" for regularize in
+                                                                                 (
             true, false), monte_carlo in (true, false)
@@ -54,33 +55,33 @@ end
 
 @testitem "Smoke test for FFJORDDistribution (sampling & pdf)" setup=[CNFTestSetup] tags=[:advancedneuralde] begin
-    nn = Chain(Dense(1, 1, tanh))
-    tspan = (0.0f0, 1.0f0)
-    ffjord_mdl = FFJORD(nn, tspan, (1,), Tsit5())
-    ps, st = Lux.setup(Xoshiro(0), ffjord_mdl)
-    ps = ComponentArray(ps)
+    nn=Chain(Dense(1, 1, tanh))
+    tspan=(0.0f0, 1.0f0)
+    ffjord_mdl=FFJORD(nn, tspan, (1,), Tsit5())
+    ps, st=Lux.setup(Xoshiro(0), ffjord_mdl)
+    ps=ComponentArray(ps)
 
-    regularize = false
-    monte_carlo = false
+    regularize=false
+    monte_carlo=false
 
-    data_dist = Beta(2.0f0, 2.0f0)
-    train_data = Float32.(rand(data_dist, 1, 100))
+    data_dist=Beta(2.0f0, 2.0f0)
+    train_data=Float32.(rand(data_dist, 1, 100))
 
     function loss(model, θ)
-        logpx, λ₁, λ₂ = model(train_data, θ)
+        logpx, λ₁, λ₂=model(train_data, θ)
         return -mean(logpx)
     end
 
-    adtype = Optimization.AutoZygote()
+    adtype=Optimization.AutoZygote()
 
-    st_ = (; st..., regularize, monte_carlo)
-    model = StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)
+    st_=(; st..., regularize, monte_carlo)
+    model=StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)
 
-    optf = Optimization.OptimizationFunction((θ, _) -> loss(model, θ), adtype)
-    optprob = Optimization.OptimizationProblem(optf, ps)
-    res = Optimization.solve(optprob, Adam(0.1); callback = callback(adtype), maxiters = 10)
+    optf=Optimization.OptimizationFunction((θ, _)->loss(model, θ), adtype)
+    optprob=Optimization.OptimizationProblem(optf, ps)
+    res=Optimization.solve(optprob, Adam(0.1); callback = callback(adtype), maxiters = 10)
 
-    ffjord_d = FFJORDDistribution(ffjord_mdl, res.u, st_)
+    ffjord_d=FFJORDDistribution(ffjord_mdl, res.u, st_)
 
     @test !isnothing(pdf(ffjord_d, train_data))
     @test !isnothing(rand(ffjord_d))
@@ -88,143 +89,143 @@ end
 
 @testitem "Test for default base distribution and deterministic trace FFJORD" setup=[CNFTestSetup] tags=[:advancedneuralde] begin
-    nn = Chain(Dense(1, 1, tanh))
-    tspan = (0.0f0, 1.0f0)
-    ffjord_mdl = FFJORD(nn, tspan, (1,), Tsit5())
-    ps, st = Lux.setup(Xoshiro(0), ffjord_mdl)
-    ps = ComponentArray(ps)
+    nn=Chain(Dense(1, 1, tanh))
+    tspan=(0.0f0, 1.0f0)
+    ffjord_mdl=FFJORD(nn, tspan, (1,), Tsit5())
+    ps, st=Lux.setup(Xoshiro(0), ffjord_mdl)
+    ps=ComponentArray(ps)
 
-    regularize = false
-    monte_carlo = false
+    regularize=false
+    monte_carlo=false
 
-    data_dist = Beta(7.0f0, 7.0f0)
-    train_data = Float32.(rand(data_dist, 1, 100))
-    test_data = Float32.(rand(data_dist, 1, 100))
+    data_dist=Beta(7.0f0, 7.0f0)
+    train_data=Float32.(rand(data_dist, 1, 100))
+    test_data=Float32.(rand(data_dist, 1, 100))
 
     function loss(model, θ)
-        logpx, λ₁, λ₂ = model(train_data, θ)
+        logpx, λ₁, λ₂=model(train_data, θ)
         return -mean(logpx)
     end
 
-    adtype = Optimization.AutoZygote()
-    st_ = (; st..., regularize, monte_carlo)
-    model = StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)
+    adtype=Optimization.AutoZygote()
+    st_=(; st..., regularize, monte_carlo)
+    model=StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)
 
-    optf = Optimization.OptimizationFunction((θ, _) -> loss(model, θ), adtype)
-    optprob = Optimization.OptimizationProblem(optf, ps)
-    res = Optimization.solve(optprob, Adam(0.1); callback = callback(adtype), maxiters = 10)
+    optf=Optimization.OptimizationFunction((θ, _)->loss(model, θ), adtype)
+    optprob=Optimization.OptimizationProblem(optf, ps)
+    res=Optimization.solve(optprob, Adam(0.1); callback = callback(adtype), maxiters = 10)
 
-    actual_pdf = pdf.(data_dist, test_data)
-    learned_pdf = exp.(model(test_data, res.u)[1])
+    actual_pdf=pdf.(data_dist, test_data)
+    learned_pdf=exp.(model(test_data, res.u)[1])
 
     @test ps != res.u
     @test totalvariation(learned_pdf, actual_pdf) / size(test_data, 2) < 0.9
 end
 
 @testitem "Test for alternative base distribution and deterministic trace FFJORD" setup=[CNFTestSetup] tags=[:advancedneuralde] begin
-    nn = Chain(Dense(1, 3, tanh), Dense(3, 1, tanh))
-    tspan = (0.0f0, 1.0f0)
-    ffjord_mdl = FFJORD(
+    nn=Chain(Dense(1, 3, tanh), Dense(3, 1, tanh))
+    tspan=(0.0f0, 1.0f0)
+    ffjord_mdl=FFJORD(
         nn, tspan, (1,), Tsit5(); basedist = MvNormal([0.0f0], Diagonal([4.0f0])))
-    ps, st = Lux.setup(Xoshiro(0), ffjord_mdl)
-    ps = ComponentArray(ps)
+    ps, st=Lux.setup(Xoshiro(0), ffjord_mdl)
+    ps=ComponentArray(ps)
 
-    regularize = false
-    monte_carlo = false
+    regularize=false
+    monte_carlo=false
 
-    data_dist = Normal(6.0f0, 0.7f0)
-    train_data = Float32.(rand(data_dist, 1, 100))
-    test_data = Float32.(rand(data_dist, 1, 100))
+    data_dist=Normal(6.0f0, 0.7f0)
+    train_data=Float32.(rand(data_dist, 1, 100))
+    test_data=Float32.(rand(data_dist, 1, 100))
 
     function loss(model, θ)
-        logpx, λ₁, λ₂ = model(train_data, θ)
+        logpx, λ₁, λ₂=model(train_data, θ)
         return -mean(logpx)
     end
 
-    adtype = Optimization.AutoZygote()
-    st_ = (; st..., regularize, monte_carlo)
-    model = StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)
+    adtype=Optimization.AutoZygote()
+    st_=(; st..., regularize, monte_carlo)
+    model=StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)
 
-    optf = Optimization.OptimizationFunction((θ, _) -> loss(model, θ), adtype)
-    optprob = Optimization.OptimizationProblem(optf, ps)
-    res = Optimization.solve(optprob, Adam(0.1); callback = callback(adtype), maxiters = 30)
+    optf=Optimization.OptimizationFunction((θ, _)->loss(model, θ), adtype)
+    optprob=Optimization.OptimizationProblem(optf, ps)
+    res=Optimization.solve(optprob, Adam(0.1); callback = callback(adtype), maxiters = 30)
 
-    actual_pdf = pdf.(data_dist, test_data)
-    learned_pdf = exp.(model(test_data, res.u)[1])
+    actual_pdf=pdf.(data_dist, test_data)
+    learned_pdf=exp.(model(test_data, res.u)[1])
 
     @test ps != res.u
     @test totalvariation(learned_pdf, actual_pdf) / size(test_data, 2) < 0.25
 end
 
 @testitem "Test for multivariate distribution and deterministic trace FFJORD" setup=[CNFTestSetup] tags=[:advancedneuralde] begin
-    nn = Chain(Dense(2, 2, tanh))
-    tspan = (0.0f0, 1.0f0)
-    ffjord_mdl = FFJORD(nn, tspan, (2,), Tsit5())
-    ps, st = Lux.setup(Xoshiro(0), ffjord_mdl)
-    ps = ComponentArray(ps)
+    nn=Chain(Dense(2, 2, tanh))
+    tspan=(0.0f0, 1.0f0)
+    ffjord_mdl=FFJORD(nn, tspan, (2,), Tsit5())
+    ps, st=Lux.setup(Xoshiro(0), ffjord_mdl)
+    ps=ComponentArray(ps)
 
-    regularize = false
-    monte_carlo = false
+    regularize=false
+    monte_carlo=false
 
-    μ = ones(Float32, 2)
-    Σ = Diagonal([7.0f0, 7.0f0])
-    data_dist = MvNormal(μ, Σ)
-    train_data = Float32.(rand(data_dist, 100))
-    test_data = Float32.(rand(data_dist, 100))
+    μ=ones(Float32, 2)
+    Σ=Diagonal([7.0f0, 7.0f0])
+    data_dist=MvNormal(μ, Σ)
+    train_data=Float32.(rand(data_dist, 100))
+    test_data=Float32.(rand(data_dist, 100))
 
     function loss(model, θ)
-        logpx, λ₁, λ₂ = model(train_data, θ)
+        logpx, λ₁, λ₂=model(train_data, θ)
         return -mean(logpx)
     end
 
-    adtype = Optimization.AutoZygote()
-    st_ = (; st..., regularize, monte_carlo)
-    model = StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)
+    adtype=Optimization.AutoZygote()
+    st_=(; st..., regularize, monte_carlo)
+    model=StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)
 
-    optf = Optimization.OptimizationFunction((θ, _) -> loss(model, θ), adtype)
-    optprob = Optimization.OptimizationProblem(optf, ps)
-    res = Optimization.solve(
+    optf=Optimization.OptimizationFunction((θ, _)->loss(model, θ), adtype)
+    optprob=Optimization.OptimizationProblem(optf, ps)
+    res=Optimization.solve(
         optprob, Adam(0.01); callback = callback(adtype), maxiters = 30)
 
-    actual_pdf = pdf(data_dist, test_data)
-    learned_pdf = exp.(model(test_data, res.u)[1])
+    actual_pdf=pdf(data_dist, test_data)
+    learned_pdf=exp.(model(test_data, res.u)[1])
 
     @test ps != res.u
     @test totalvariation(learned_pdf, actual_pdf) / size(test_data, 2) < 0.25
 end
 
 @testitem "Test for multivariate distribution and FFJORD with regularizers" setup=[CNFTestSetup] tags=[:advancedneuralde] begin
-    nn = Chain(Dense(2, 2, tanh))
-    tspan = (0.0f0, 1.0f0)
-    ffjord_mdl = FFJORD(nn, tspan, (2,), Tsit5())
-    ps, st = Lux.setup(Xoshiro(0), ffjord_mdl)
-    ps = ComponentArray(ps) .* 0.001f0
+    nn=Chain(Dense(2, 2, tanh))
+    tspan=(0.0f0, 1.0f0)
+    ffjord_mdl=FFJORD(nn, tspan, (2,), Tsit5())
+    ps, st=Lux.setup(Xoshiro(0), ffjord_mdl)
+    ps=ComponentArray(ps) .* 0.001f0
 
-    regularize = true
-    monte_carlo = true
+    regularize=true
+    monte_carlo=true
 
-    μ = ones(Float32, 2)
-    Σ = Diagonal([7.0f0, 7.0f0])
-    data_dist = MvNormal(μ, Σ)
-    train_data = Float32.(rand(data_dist, 100))
-    test_data = Float32.(rand(data_dist, 100))
+    μ=ones(Float32, 2)
+    Σ=Diagonal([7.0f0, 7.0f0])
+    data_dist=MvNormal(μ, Σ)
+    train_data=Float32.(rand(data_dist, 100))
+    test_data=Float32.(rand(data_dist, 100))
 
     function loss(model, θ)
-        logpx, λ₁, λ₂ = model(train_data, θ)
+        logpx, λ₁, λ₂=model(train_data, θ)
         return -mean(logpx)
     end
 
-    adtype = Optimization.AutoZygote()
-    st_ = (; st..., regularize, monte_carlo)
-    model = StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)
+    adtype=Optimization.AutoZygote()
+    st_=(; st..., regularize, monte_carlo)
+    model=StatefulLuxLayer{true}(ffjord_mdl, nothing, st_)
 
-    optf = Optimization.OptimizationFunction((θ, _) -> loss(model, θ), adtype)
-    optprob = Optimization.OptimizationProblem(optf, ps)
-    res = Optimization.solve(
+    optf=Optimization.OptimizationFunction((θ, _)->loss(model, θ), adtype)
+    optprob=Optimization.OptimizationProblem(optf, ps)
+    res=Optimization.solve(
         optprob, Adam(0.01); callback = callback(adtype), maxiters = 30)
 
-    actual_pdf = pdf(data_dist, test_data)
-    learned_pdf = exp.(model(test_data, res.u)[1])
+    actual_pdf=pdf(data_dist, test_data)
+    learned_pdf=exp.(model(test_data, res.u)[1])
 
     @test ps != res.u
     @test totalvariation(learned_pdf, actual_pdf) / size(test_data, 2) < 0.25
diff --git a/test/multiple_shoot_tests.jl b/test/multiple_shoot_tests.jl
index 057f5ab474..f049078263 100644
--- a/test/multiple_shoot_tests.jl
+++ b/test/multiple_shoot_tests.jl
@@ -24,9 +24,11 @@
         (
             name = "Multi-D Test Config",
            u0 = Float32[2.0 0.0; 1.0 1.5; 0.5 -1.0],
-            ode_func = (du, u, p, t) -> (du .= ((u .^ 3).*[-0.01 0.02; -0.02 -0.01; 0.01 -0.05])),
+            ode_func = (
+                du, u, p, t) -> (du .= ((u .^ 3) .* [-0.01 0.02; -0.02 -0.01; 0.01 -0.05])),
             nn = Chain(x -> x .^ 3, Dense(3 => 3, tanh)),
-            u0s_ensemble = [Float32[2.0 0.0; 1.0 1.5; 0.5 -1.0], Float32[3.0 1.0; 2.0 0.5; 1.5 -0.5]]
+            u0s_ensemble = [
+                Float32[2.0 0.0; 1.0 1.5; 0.5 -1.0], Float32[3.0 1.0; 2.0 0.5; 1.5 -0.5]]
         )
     ]
@@ -158,7 +160,8 @@ group_size = 3
 continuity_term = 200
 
 function loss_multiple_shooting_ens(p)
-    return multiple_shoot(p, ode_data_ensemble, tsteps, ensemble_prob, ensemble_alg,
+    return multiple_shoot(
+        p, ode_data_ensemble, tsteps, ensemble_prob, ensemble_alg,
         loss_function, Tsit5(), group_size; continuity_term,
         trajectories, abstol = 1e-8, reltol = 1e-6)[1]
 end
diff --git a/test/neural_de_tests.jl b/test/neural_de_tests.jl
index 8bbd35a23d..7407c5caa1 100644
--- a/test/neural_de_tests.jl
+++ b/test/neural_de_tests.jl
@@ -246,10 +246,10 @@ end
 
     CUDA.allowscalar(false)
 
-    rng = Xoshiro(0)
+    rng=Xoshiro(0)
 
-    const gdev = gpu_device()
-    const cdev = cpu_device()
+    const gdev=gpu_device()
+    const cdev=cpu_device()
 
     @testset "Neural DE" begin
         mp = Float32[0.1, 0.1] |> gdev