
Add repack/canonicalize in vec_pjac! to support SciMLStructs #1239

Open · wants to merge 3 commits into master

Conversation

@albangossard

This PR fixes #1238
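
For context, here is a rough sketch of the idea named in the PR title (my paraphrase, not the actual diff): `vec_pjac!` should canonicalize `p` through the SciMLStructures interface instead of assuming it is a plain vector. The helper name below is hypothetical:

```julia
import SciMLStructures as SS

# Hypothetical illustration only: extract the tunable buffer and a `repack`
# closure before forming the parameter vjp, so that non-vector
# SciMLStructures parameters are supported.
function canonical_buffer(p)
    if SS.isscimlstructure(p)
        buffer, repack, _ = SS.canonicalize(SS.Tunable(), p)
        return buffer, repack
    else
        return p, identity
    end
end
```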

@ChrisRackauckas
Member

Can you add a test case that hits this?

@albangossard
Author

I corrected a typo that was causing the tests to fail.
This use case is already covered, since `isscimlstructure(p)` is `true` for plain vectors (see the snippet below). I added a test that exercises `GaussAdjoint` with `EnzymeVJP` on a more complex struct, and I verified that the new test fails without the patch I'm proposing, which guards against future regressions.
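
A quick check of that claim, as a standalone snippet (my addition; it relies on SciMLStructures' built-in support for numeric arrays):

```julia
import SciMLStructures as SS

SS.isscimlstructure(rand(3))  # true: plain numeric vectors already satisfy the interface
```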

@ChrisRackauckas
Member

On the new test case, Enzyme works on Julia v1.10 but segfaults on v1.11: https://github.com/SciML/SciMLSensitivity.jl/actions/runs/16355834035/job/46217780082?pr=1239#step:6:973.

@ChrisRackauckas
Member

Interesting that EnzymeAD/Enzyme.jl#2450 doesn't catch this.

@ChrisRackauckas
Member

```julia
using Random, Lux
using ComponentArrays
using Enzyme

# Parameter struct mixing a Lux model, its parameters and state, and scalar coefficients
mutable struct myparam{M,P,S}
    model::M
    ps::P
    st::S
    α::Float64
    β::Float64
    γ::Float64
end

function initialize()
    # Define the neural network
    U = Lux.Chain(Lux.Dense(3, 30, tanh), Lux.Dense(30, 30, tanh), Lux.Dense(30, 1))
    rng = Random.GLOBAL_RNG
    _para, st = Lux.setup(rng, U)
    _para = ComponentArray(_para)
    # Set the scalar parameters
    α = 0.5
    β = 0.1
    γ = 0.01
    return myparam(U, _para, st, α, β, γ)
end

function UDE_model!(du, u, p, t)
    o = p.model(u, p.ps, p.st)[1][1]
    du[1] = o * p.α * u[1] + p.β * u[2] + p.γ * u[3]
    du[2] = -p.α * u[1] + p.β * u[2] - p.γ * u[3]
    du[3] = p.α * u[1] - p.β * u[2] + p.γ * u[3]
    nothing
end

p = initialize()
u0 = zeros(3); du = zeros(3)
ddu = Enzyme.make_zero(du)
d_u0 = Enzyme.make_zero(u0)
dp = Enzyme.make_zero(p)
Enzyme.autodiff(Reverse, Enzyme.Const(UDE_model!), Enzyme.Duplicated(du, ddu),
    Enzyme.Duplicated(u0, d_u0), Enzyme.Duplicated(p, dp), Enzyme.Const(0.2))
```

Need to run on v1.11 to confirm it segfaults, battery running low.

@ChrisRackauckas
Member

Interesting, I just ran it on v1.11 but couldn't reproduce the segfault 😅 Maybe @wsmoses has ideas just from the stack trace.

@wsmoses

wsmoses commented Jul 18, 2025

What Enzyme version? We fixed a segfault in the last patch.

@ChrisRackauckas
Member

For reference, this is the package test that is segfaulting:

```julia
using OrdinaryDiffEq
using Random, Lux
using ComponentArrays
using SciMLSensitivity
import SciMLStructures as SS
using Zygote
using ADTypes
using Test

mutable struct myparam{M,P,S}
    model::M
    ps::P
    st::S
    α::Float64
    β::Float64
    γ::Float64
end

# Implement the SciMLStructures interface: only `ps` is tunable
SS.isscimlstructure(::myparam) = true
SS.ismutablescimlstructure(::myparam) = true
SS.hasportion(::SS.Tunable, ::myparam) = true
function SS.canonicalize(::SS.Tunable, p::myparam)
    buffer = copy(p.ps)
    repack = let p = p
        function repack(newbuffer)
            SS.replace(SS.Tunable(), p, newbuffer)
        end
    end
    return buffer, repack, false
end
function SS.replace(::SS.Tunable, p::myparam, newbuffer)
    return myparam(p.model, newbuffer, p.st, p.α, p.β, p.γ)
end
function SS.replace!(::SS.Tunable, p::myparam, newbuffer)
    p.ps = newbuffer
    return p
end

function initialize()
    # Define the neural network
    U = Lux.Chain(Lux.Dense(3, 30, tanh), Lux.Dense(30, 30, tanh), Lux.Dense(30, 1))
    rng = Random.GLOBAL_RNG
    _para, st = Lux.setup(rng, U)
    _para = ComponentArray(_para)
    # Set the scalar parameters
    α = 0.5
    β = 0.1
    γ = 0.01
    return myparam(U, _para, st, α, β, γ)
end

function UDE_model!(du, u, p, t)
    o = p.model(u, p.ps, p.st)[1][1]
    du[1] = o * p.α * u[1] + p.β * u[2] + p.γ * u[3]
    du[2] = -p.α * u[1] + p.β * u[2] - p.γ * u[3]
    du[3] = p.α * u[1] - p.β * u[2] + p.γ * u[3]
    nothing
end

p = initialize()

function run_diff(ps)
    u01 = [1.0, 0.0, 0.0]
    tspan = (0.0, 10.0)
    prob = ODEProblem(UDE_model!, u01, tspan, ps)
    sol = solve(prob, Rosenbrock23(), saveat = 0.1)
    return sol.u |> last |> sum
end
run_diff(initialize())
@test !iszero(Zygote.gradient(run_diff, initialize())[1].ps)

function run_diff(ps, sensealg)
    u01 = [1.0, 0.0, 0.0]
    tspan = (0.0, 10.0)
    prob = ODEProblem(UDE_model!, u01, tspan, ps)
    sol = solve(prob, Rosenbrock23(), saveat = 0.1, sensealg = sensealg)
    return sol.u |> last |> sum
end

run_diff(initialize())
@test !iszero(Zygote.gradient(run_diff, initialize(), GaussAdjoint(autojacvec = EnzymeVJP()))[1].ps)
```

and it's running the same AD of the `f` function as in the MWE above. Pinging @avik-pal as well, since it seems like it's a Lux matmul thing.

@ChrisRackauckas
Member

[7da242da] Enzyme v0.13.61

@ChrisRackauckas
Member

For reference:

```
[7701] signal 11 (1): Segmentation fault
in expression starting at /home/runner/work/SciMLSensitivity.jl/SciMLSensitivity.jl/test/scimlstructures_interface.jl:162
getindex at ./essentials.jl:916 [inlined]
getindex at ./subarray.jl:343 [inlined]
_broadcast_getindex at ./broadcast.jl:644 [inlined]
_getindex at ./broadcast.jl:675 [inlined]
_broadcast_getindex at ./broadcast.jl:650 [inlined]
macro expansion at /home/runner/.julia/packages/Enzyme/9dy6G/src/compiler/interpreter.jl:574 [inlined]
lindex_v1 at /home/runner/.julia/packages/Enzyme/9dy6G/src/compiler/interpreter.jl:551 [inlined]
macro expansion at /home/runner/.julia/packages/Enzyme/9dy6G/src/compiler/interpreter.jl:785 [inlined]
lindex_v3 at /home/runner/.julia/packages/Enzyme/9dy6G/src/compiler/interpreter.jl:722 [inlined]
override_bc_copyto! at /home/runner/.julia/packages/Enzyme/9dy6G/src/compiler/interpreter.jl:824 [inlined]
copyto! at ./broadcast.jl:925 [inlined]
materialize! at ./broadcast.jl:883 [inlined]
materialize! at ./broadcast.jl:880 [inlined]
muladd at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/usr/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:221
matmuladd at /home/runner/.julia/packages/LuxLib/3ewJc/src/impl/matmul.jl:13 [inlined]
matmuladd at /home/runner/.julia/packages/LuxLib/3ewJc/src/impl/matmul.jl:7 [inlined]
fused_dense at /home/runner/.julia/packages/LuxLib/3ewJc/src/impl/dense.jl:10 [inlined]
fused_dense_bias_activation at /home/runner/.julia/packages/LuxLib/3ewJc/src/api/dense.jl:36 [inlined]
Dense at /home/runner/.julia/packages/Lux/FMMvw/src/layers/basic.jl:363
apply at /home/runner/.julia/packages/LuxCore/q0Mrq/src/LuxCore.jl:155 [inlined]
macro expansion at /home/runner/.julia/packages/Lux/FMMvw/src/layers/containers.jl:0 [inlined]
applychain at /home/runner/.julia/packages/Lux/FMMvw/src/layers/containers.jl:511 [inlined]
Chain at /home/runner/.julia/packages/Lux/FMMvw/src/layers/containers.jl:509 [inlined]
UDE_model! at /home/runner/work/SciMLSensitivity.jl/SciMLSensitivity.jl/test/scimlstructures_interface.jl:132
ODEFunction at /home/runner/.julia/packages/SciMLBase/zk34N/src/scimlfunctions.jl:2591 [inlined]
ODEFunction at /home/runner/.julia/packages/SciMLBase/zk34N/src/scimlfunctions.jl:2591 [inlined]
Void at /home/runner/.julia/packages/SciMLBase/zk34N/src/utils.jl:486 [inlined]
Void at /home/runner/.julia/packages/SciMLBase/zk34N/src/utils.jl:0 [inlined]
diffejulia_Void_651042_inner_29wrap at /home/runner/.julia/packages/SciMLBase/zk34N/src/utils.jl:0
macro expansion at /home/runner/.julia/packages/Enzyme/9dy6G/src/compiler.jl:5610 [inlined]
enzyme_call at /home/runner/.julia/packages/Enzyme/9dy6G/src/compiler.jl:5144 [inlined]
CombinedAdjointThunk at /home/runner/.julia/packages/Enzyme/9dy6G/src/compiler.jl:5019 [inlined]
autodiff at /home/runner/.julia/packages/Enzyme/9dy6G/src/Enzyme.jl:517 [inlined]
_vecjacobian! at /home/runner/work/SciMLSensitivity.jl/SciMLSensitivity.jl/src/derivative_wrappers.jl:733
```

@albangossard
Author

I tried to reproduce this error, and it turns out that it segfaults only in test mode. Running your example @ChrisRackauckas after adding a `copy(u)` in the `p.model(...)` call works when the file is `include`d, but it segfaults in test mode. I made sure that the Manifest.toml was exactly the same in both cases.
Maybe something is handled differently in test mode? I am not familiar enough with the internals of Enzyme and Lux to debug this.

Note: the `copy` is needed because Lux appears to mutate `u`, and since `autodiff` is called with `Enzyme.Const(y)`, that mutation triggers this error.
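
For concreteness, a sketch of the workaround described above, applied to the `UDE_model!` from the MWE (my illustration):

```julia
# Pass a copy of `u` into the Lux model so the original array, which Enzyme
# may treat as Const inside the VJP, is never mutated.
function UDE_model!(du, u, p, t)
    o = p.model(copy(u), p.ps, p.st)[1][1]
    du[1] = o * p.α * u[1] + p.β * u[2] + p.γ * u[3]
    du[2] = -p.α * u[1] + p.β * u[2] - p.γ * u[3]
    du[3] = p.α * u[1] - p.β * u[2] + p.γ * u[3]
    nothing
end
```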

@ChrisRackauckas
Member

> Maybe something is handled differently in test mode? I am not familiar enough with the internals of Enzyme and Lux to debug this.

Test mode runs with bounds checking always turned on, i.e. `@inbounds` is ignored. I wonder if this affects Enzyme.
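
Concretely (a detail not stated in the thread, but one I'm fairly confident of): `Pkg.test` launches the test process with `--check-bounds=yes`, which overrides `@inbounds`, whereas a plain `include` in a default session does not. `mwe.jl` below is a hypothetical filename:

```julia
# To approximate test-mode behavior outside of Pkg.test, run:
#
#     julia --check-bounds=yes mwe.jl
#
# Inside a session, this reports whether bounds checking is forced on:
Base.JLOptions().check_bounds  # 1 when forced on, 0 in a default session
```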

@ChrisRackauckas
Member

> Note: the `copy` is needed because Lux appears to mutate `u`, and since `autodiff` is called with `Enzyme.Const(y)`, that mutation triggers this error.

Lux mutates its input? That is unexpected, @avik-pal

@avik-pal
Member

avik-pal commented Aug 1, 2025

If you are calling via the public API (i.e. models like `Lux.Chain`/`Lux.Dense`), it should not mutate the input.
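
A minimal sanity check of that claim (my own illustration, not from the thread):

```julia
using Lux, Random

# Call a Dense layer through the public API and verify the input is untouched.
d = Lux.Dense(3, 2)
ps, st = Lux.setup(Random.default_rng(), d)
x = rand(Float32, 3)
x0 = copy(x)
d(x, ps, st)       # returns (output, new_state)
@assert x == x0    # the public API should leave `x` unmodified
```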

Closes: Differentiation fails when using SciMLStructures with GaussAdjoint and EnzymeVJP (#1238)