Skip to content

Add Deep Ritz to NeuralPDE.jl #857

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,15 @@ include("ode_solve.jl")
# include("rode_solve.jl")
include("dae_solve.jl")
include("transform_inf_integral.jl")
include("deep_ritz.jl")
include("discretize.jl")
include("neural_adapter.jl")
include("advancedHMC_MCMC.jl")
include("BPINN_ode.jl")
include("PDE_BPINN.jl")
include("dgm.jl")


export NNODE, NNDAE,
PhysicsInformedNN, discretize,
GridTraining, StochasticTraining, QuadratureTraining, QuasiRandomTraining,
Expand All @@ -66,6 +68,6 @@ export NNODE, NNDAE,
MiniMaxAdaptiveLoss, LogOptions,
ahmc_bayesian_pinn_ode, BNNODE, ahmc_bayesian_pinn_pde, vector_to_parameters,
BPINNsolution, BayesianPINN,
DeepGalerkin
DeepGalerkin, DeepRitz

end # module
135 changes: 135 additions & 0 deletions src/deep_ritz.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""
DeepRitz(chain,
strategy;
init_params = nothing,
phi = nothing,
param_estim = false,
additional_loss = nothing,
adaptive_loss = nothing,
logger = nothing,
log_options = LogOptions(),
iteration = nothing,
kwargs...)

A `discretize` algorithm for the ModelingToolkit PDESystem interface, which transforms a
`PDESystem` into an `OptimizationProblem` for the Deep Ritz method.

## Positional Arguments

* `chain`: a vector of Lux/Flux chains with a d-dimensional input and a
1-dimensional output corresponding to each of the dependent variables. Note that this
specification respects the order of the dependent variables as specified in the PDESystem.
Flux chains will be converted to Lux internally using `adapt(FromFluxAdaptor(false, false), chain)`.
* `strategy`: determines which training strategy will be used. See the Training Strategy
documentation for more details.

## Keyword Arguments

* `init_params`: the initial parameters of the neural networks. If `init_params` is not
given, then the neural network default parameters are used. Note that for Lux, the default
will convert to Float64.
* `phi`: a trial solution, specified as `phi(x,p)` where `x` is the coordinates vector for
the dependent variable and `p` are the weights of the phi function (generally the weights
of the neural network defining `phi`). By default, this is generated from the `chain`. This
should only be used to more directly impose functional information in the training problem,
for example imposing the boundary condition by the test function formulation.
* `adaptive_loss`: the choice for the adaptive loss function. See the
[adaptive loss page](@ref adaptive_loss) for more details. Defaults to no adaptivity.
* `additional_loss`: a function `additional_loss(phi, θ, p_)` where `phi` are the neural
network trial solutions, `θ` are the weights of the neural network(s), and `p_` are the
hyperparameters of the `OptimizationProblem`. If `param_estim = true`, then `θ` additionally
contains the parameters of the differential equation appended to the end of the vector.
* `param_estim`: whether the parameters of the differential equation should be included in
the values sent to the `additional_loss` function. Defaults to `false`.
* `logger`: ?? needs docs
* `log_options`: ?? why is this separate from the logger?
* `iteration`: used to control the iteration counter???
* `
"""
struct DeepRitz{T, P, PH, DER, PE, AL, ADA, LOG, K} <: AbstractPINN
chain::Any
strategy::T
init_params::P
phi::PH
derivative::DER
param_estim::PE
additional_loss::AL
adaptive_loss::ADA
logger::LOG
log_options::LogOptions
iteration::Vector{Int64}
self_increment::Bool
multioutput::Bool
kwargs::K
end

"""
    DeepRitz(chain, strategy; kwargs...)

Construct a `DeepRitz` discretizer by building a `PhysicsInformedNN` with the same
arguments and copying its fields over one-to-one.

# Arguments
- `chain`: neural network(s) for the trial solution (see the struct docstring).
- `strategy`: the training strategy.
- `kwargs...`: forwarded to `PhysicsInformedNN` (`init_params`, `phi`, `param_estim`,
  `additional_loss`, `adaptive_loss`, `logger`, `log_options`, `iteration`, ...).
"""
function DeepRitz(chain, strategy; kwargs...)
    # BUG FIX: the keyword arguments were previously accepted but silently dropped;
    # forward them so `init_params`, `additional_loss`, etc. actually take effect.
    pinn = NeuralPDE.PhysicsInformedNN(chain, strategy; kwargs...)

    # NOTE(review): this relies on `propertynames(pinn)` matching the positional
    # field order of `DeepRitz` exactly — fragile if either struct changes.
    return DeepRitz([getfield(pinn, k) for k in propertynames(pinn)]...)
end

"""
prob = discretize(pde_system::PDESystem, discretization::DeepRitz)

For 2nd order PDEs, transforms a symbolic description of a ModelingToolkit-defined `PDESystem`
using Deep-Ritz me and generates an `OptimizationProblem` for [Optimization.jl](https://docs.sciml.ai/Optimization/stable/)
whose solution is the solution to the PDE.
"""
function SciMLBase.discretize(pde_system::PDESystem, discretization::DeepRitz)
modify_deep_ritz!(pde_system);
pinnrep = symbolic_discretize(pde_system, discretization)
f = OptimizationFunction(pinnrep.loss_functions.full_loss_function,
Optimization.AutoZygote())
Optimization.OptimizationProblem(f, pinnrep.flat_init_params)
end


"""
modify_deep_ritz!(pde_system::PDESystem)

Performs the checks for Deep-Ritz method and modifies the pde in the `pde_system`.
"""
function modify_deep_ritz!(pde_system::PDESystem)

if length(pde_system.eqs) > 1
error("Deep Ritz solves for only one dependent variable")
end

ind_vars = pde_system.ivs
dep_var = pde_system.dvs[1]

n_vars = length(ind_vars)

expr = first(pde_system.eqs).lhs - first(pde_system.eqs).rhs

Ds = [Differential(ind_var) for ind_var in ind_vars];
D²s = [Differential(ind_var)^2 for ind_var in ind_vars];
laplacian = (sum([d²s(dep_var) for d²s in D²s]) ~ 0).lhs;

expr_new = modify_laplacian(expr, laplacian, n_vars);

rhs = - expr_new * dep_var
lhs = (sum([(ds(dep_var))^2 for ds in Ds]) ~ 0).lhs;

pde_system.eqs[1] = lhs ~ rhs
return nothing
end


# Strip a unit-coefficient Laplacian `Δ` (sum of `n_vars` second-derivative terms)
# from the residual `expr`, returning the remainder. Tries `expr - Δ` first
# (handles `+Δu`), then `expr + Δ` (handles `-Δu`); errors otherwise.
# NOTE(review): accessing `.dict` relies on SymbolicUtils `Add` internals — the
# term-count bookkeeping below assumes the Laplacian terms appear with coefficient
# exactly ±1 and cancel completely; confirm against SymbolicUtils' representation.
function modify_laplacian(expr, Δ, n_vars)
    expr_new = expr - Δ;
    # If the result is no longer a sum, or subtracting removed exactly `n_vars`
    # terms from the sum's term dictionary, `expr` contained `+Δu`.
    if (operation(expr_new)!= +) || (length(expr_new.dict) + n_vars == length(expr.dict))
        # positive coeff of laplacian
        return expr_new
    else
        # Otherwise try the opposite sign: adding Δ cancels a `-Δu` term,
        # which shows up as `n_vars` extra terms relative to `expr`.
        expr_new = expr + Δ
        if length(expr_new.dict) == n_vars + length(expr.dict)
            # negative coeff of laplacian
            return expr_new
        else
            error("Incorrect form of PDE given")
        end
    end
end
2 changes: 1 addition & 1 deletion src/discretize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem,
pde_loss_functions,
bc_loss_functions)

function get_likelihood_estimate_function(discretization::PhysicsInformedNN)
function get_likelihood_estimate_function(discretization::Union{PhysicsInformedNN, DeepRitz})
function full_loss_function(θ, p)
# the aggregation happens on cpu even if the losses are gpu, probably fine since it's only a few of them
pde_losses = [pde_loss_function(θ) for pde_loss_function in pde_loss_functions]
Expand Down
54 changes: 54 additions & 0 deletions test/deep_ritz_test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
using NeuralPDE, Test

using ModelingToolkit, Optimization, OptimizationOptimisers, Distributions, MethodOfLines,
OrdinaryDiffEq
import ModelingToolkit: Interval, infimum, supremum
using Lux #: tanh, identity

@testset "Poisson's equation" begin
@parameters x y
@variables u(..)
Dxx = Differential(x)^2
Dyy = Differential(y)^2

# 2D PDE
eq = Dxx(u(x, y)) + Dyy(u(x, y)) ~ -sin(pi * x) * sin(pi * y)

# Initial and boundary conditions
bcs = [u(0, y) ~ 0.0, u(1, y) ~ -sin(pi * 1) * sin(pi * y),
u(x, 0) ~ 0.0, u(x, 1) ~ -sin(pi * x) * sin(pi * 1)]
# Space and time domains
domains = [x ∈ Interval(0.0, 1.0), y ∈ Interval(0.0, 1.0)]

strategy = QuasiRandomTraining(256, minibatch = 32)
hid = 40
chain_ = Lux.Chain(Lux.Dense(2, hid, Lux.σ), Lux.Dense(hid, hid, Lux.σ),
Lux.Dense(hid, 1))
discretization = DeepRitz(chain_, strategy);

@named pde_system = PDESystem(eq, bcs, domains, [x, y], [u(x, y)])
prob = discretize(pde_system, discretization)

global iter = 0
callback = function (p, l)
global iter += 1
if iter % 50 == 0
println("$iter => $l")
end
return false
end

res = Optimization.solve(prob, Adam(0.01); callback = callback, maxiters = 500)
prob = remake(prob, u0 = res.u)
res = Optimization.solve(prob, Adam(0.001); callback = callback, maxiters = 200)
phi = discretization.phi

xs, ys = [infimum(d.domain):0.01:supremum(d.domain) for d in domains]
analytic_sol_func(x, y) = (sin(pi * x) * sin(pi * y)) / (2pi^2)

u_predict = reshape([first(phi([x, y], res.u)) for x in xs for y in ys],
(length(xs), length(ys)))
u_real = reshape([analytic_sol_func(x, y) for x in xs for y in ys],
(length(xs), length(ys)))
@test u_predict≈u_real atol=0.1
end
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,10 @@ end
include("dgm_test.jl")
end
end

if GROUP == "All" || GROUP == "Deep-Ritz"
@time @safetestset "Deep Ritz method" begin
include("deep_ritz_test.jl")
end
end
end