Add repack/canonicalize in vec_pjac! to support SciMLStructs #1239
Conversation
Can you add a test case that hits this?

A typo that was causing the tests to fail was corrected.

It seems that on the new test case Enzyme works on v1.10 but segfaults on v1.11: https://github.com/SciML/SciMLSensitivity.jl/actions/runs/16355834035/job/46217780082?pr=1239#step:6:973

Interesting that EnzymeAD/Enzyme.jl#2450 doesn't catch this.
```julia
using Random, Lux
using ComponentArrays
using Enzyme

mutable struct myparam{M,P,S}
    model::M
    ps::P
    st::S
    α::Float64
    β::Float64
    γ::Float64
end

function initialize()
    # Defining the neural network
    U = Lux.Chain(Lux.Dense(3, 30, tanh), Lux.Dense(30, 30, tanh), Lux.Dense(30, 1))
    rng = Random.GLOBAL_RNG
    _para, st = Lux.setup(rng, U)
    _para = ComponentArray(_para)
    # Setting the parameters
    α = 0.5
    β = 0.1
    γ = 0.01
    return myparam(U, _para, st, α, β, γ)
end

function UDE_model!(du, u, p, t)
    o = p.model(u, p.ps, p.st)[1][1]
    du[1] = o * p.α * u[1] + p.β * u[2] + p.γ * u[3]
    du[2] = -p.α * u[1] + p.β * u[2] - p.γ * u[3]
    du[3] = p.α * u[1] - p.β * u[2] + p.γ * u[3]
    nothing
end

p = initialize()
u0 = zeros(3); du = zeros(3)
ddu = Enzyme.make_zero(du)
d_u0 = Enzyme.make_zero(u0)
dp = Enzyme.make_zero(p)
Enzyme.autodiff(Reverse, Enzyme.Const(UDE_model!), Enzyme.Duplicated(du, ddu),
    Enzyme.Duplicated(u0, d_u0), Enzyme.Duplicated(p, dp), Enzyme.Const(0.2))
```

Need to run on v1.11 to confirm it segfaults, battery running low.
Interesting, I just ran it on v1.11 but didn't recreate the segfault 😅 Maybe @wsmoses has ideas just from the stack trace.

What Enzyme version? We fixed a segfault in the last patch.
For reference, this is the package test which is segfaulting:

```julia
using OrdinaryDiffEq
using Random, Lux
using ComponentArrays
using SciMLSensitivity
import SciMLStructures as SS
using Zygote
using ADTypes
using Test

mutable struct myparam{M,P,S}
    model::M
    ps::P
    st::S
    α::Float64
    β::Float64
    γ::Float64
end

SS.isscimlstructure(::myparam) = true
SS.ismutablescimlstructure(::myparam) = true
SS.hasportion(::SS.Tunable, ::myparam) = true
function SS.canonicalize(::SS.Tunable, p::myparam)
    buffer = copy(p.ps)
    repack = let p = p
        function repack(newbuffer)
            SS.replace(SS.Tunable(), p, newbuffer)
        end
    end
    return buffer, repack, false
end
function SS.replace(::SS.Tunable, p::myparam, newbuffer)
    return myparam(p.model, newbuffer, p.st, p.α, p.β, p.γ)
end
function SS.replace!(::SS.Tunable, p::myparam, newbuffer)
    p.ps = newbuffer
    return p
end

function initialize()
    # Defining the neural network
    U = Lux.Chain(Lux.Dense(3, 30, tanh), Lux.Dense(30, 30, tanh), Lux.Dense(30, 1))
    rng = Random.GLOBAL_RNG
    _para, st = Lux.setup(rng, U)
    _para = ComponentArray(_para)
    # Setting the parameters
    α = 0.5
    β = 0.1
    γ = 0.01
    return myparam(U, _para, st, α, β, γ)
end

function UDE_model!(du, u, p, t)
    o = p.model(u, p.ps, p.st)[1][1]
    du[1] = o * p.α * u[1] + p.β * u[2] + p.γ * u[3]
    du[2] = -p.α * u[1] + p.β * u[2] - p.γ * u[3]
    du[3] = p.α * u[1] - p.β * u[2] + p.γ * u[3]
    nothing
end

p = initialize()

function run_diff(ps)
    u01 = [1.0, 0.0, 0.0]
    tspan = (0.0, 10.0)
    prob = ODEProblem(UDE_model!, u01, tspan, ps)
    sol = solve(prob, Rosenbrock23(), saveat = 0.1)
    return sol.u |> last |> sum
end
run_diff(initialize())
@test !iszero(Zygote.gradient(run_diff, initialize())[1].ps)

function run_diff(ps, sensealg)
    u01 = [1.0, 0.0, 0.0]
    tspan = (0.0, 10.0)
    prob = ODEProblem(UDE_model!, u01, tspan, ps)
    sol = solve(prob, Rosenbrock23(), saveat = 0.1, sensealg = sensealg)
    return sol.u |> last |> sum
end
run_diff(initialize())
@test !iszero(Zygote.gradient(run_diff, initialize(), GaussAdjoint(autojacvec = EnzymeVJP()))[1].ps)
```

and it's running the same AD of the […]
For reference: `[7da242da] Enzyme v0.13.61`
I tried to reproduce this error, but it turns out that it segfaults only in test mode. Running your example, @ChrisRackauckas, after having added a […]. Note: the […]

Test mode runs with error bounds checking always turned on, i.e. […]
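As an aside on why test mode changes the failure: `Pkg.test` launches the test suite with bounds checking forced on (`--check-bounds=yes`), which makes `@inbounds` annotations no-ops. The function below is an illustrative sketch of the difference in failure modes, not code from this PR:

```julia
# With bounds checks forced on (as under Pkg.test), the out-of-bounds write
# below throws a BoundsError at the faulting line. In a normal run, @inbounds
# elides the check, so the write is undefined behavior and can corrupt memory
# or segfault much later, far from the actual bug.
function unsafe_write!(v::Vector{Float64})
    @inbounds v[length(v) + 1] = 1.0  # one past the end
    return v
end
```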
Lux mutates its input? That is unexpected, @avik-pal.

If you are calling via the public API (i.e. models like `Lux.Chain`/`Lux.Dense`), it should not mutate the input.
This PR fixes #1238
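For context, the fix leans on the SciMLStructures.jl contract that the test above implements: `canonicalize(Tunable(), p)` returns a flat buffer of tunables, a `repack` closure, and an aliasing flag, so the VJP can be taken against the buffer and `repack` reconstructs the full parameter object. A minimal sketch of that pattern (the `ToyParams` struct and `loss` are illustrative, not the actual `vec_pjac!` code):

```julia
import SciMLStructures as SS

# Toy parameter struct with a tunable vector and a non-tunable constant.
mutable struct ToyParams
    ps::Vector{Float64}   # tunable portion
    α::Float64            # non-tunable constant
end
SS.isscimlstructure(::ToyParams) = true
SS.hasportion(::SS.Tunable, ::ToyParams) = true
function SS.canonicalize(::SS.Tunable, p::ToyParams)
    buffer = copy(p.ps)
    repack = newbuffer -> SS.replace(SS.Tunable(), p, newbuffer)
    return buffer, repack, false   # false: buffer does not alias p
end
SS.replace(::SS.Tunable, p::ToyParams, newbuffer) = ToyParams(newbuffer, p.α)

# Pattern used when differentiating: work on the flat buffer, repack to evaluate.
p = ToyParams([1.0, 2.0], 0.5)
tunables, repack, _ = SS.canonicalize(SS.Tunable(), p)
loss(buf) = sum(abs2, repack(buf).ps) * repack(buf).α
loss(tunables)  # a reverse-mode VJP would be taken w.r.t. `tunables` here
```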