Compatibility with DynamicPPL 0.38 + InitContext #2676
base: breaking
# Get the initial values for this component sampler.
initial_params_local = if initial_params === nothing
    nothing
else
    DynamicPPL.subset(vi, varnames)[:]
end
I was quite pleased with this discovery. Previously the initial params had to be subsetted to be the correct length for the conditioned model. That's not only a faff, but also I get a bit scared whenever there's direct VarInfo manipulation like this.
Now, if you use InitFromParams with a NamedTuple/Dict that has extra params, the extra params are just ignored. So no need to subset it at all, just pass it through directly!
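A toy illustration, in plain Julia, of why extra entries are harmless when initial values are looked up by name. `init_values` is a hypothetical helper and not part of DynamicPPL; it only assumes that each model asks for its own variables by name and never touches anything else.

```julia
# Hypothetical stand-in for name-based initialisation (not the DynamicPPL API).
# `model_vars` lists the variables the (conditioned) model actually contains;
# `params` may hold initial values for more variables than that.
function init_values(model_vars, params::NamedTuple)
    # Only the names the model asks for are read; extra entries are ignored.
    return Dict(v => params[v] for v in model_vars)
end

full_params = (x = 1.0, y = 2.0, z = 3.0)  # initial values for the full model
component_init = init_values((:x, :z), full_params)  # `y` is never looked up
```

Because the lookup is driven by the model's own variable names, the same `full_params` can be handed to every component sampler unchanged.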
# TODO(DPPL0.38/penelopeysm): This function should no longer be needed
# once InitContext is merged.
Unfortunately `set_namedtuple!` is used elsewhere in this file (though it won't appear in this diff), so we can't delete it (yet).
function DynamicPPL.tilde_assume!!(
    context::MHContext, right::Distribution, vn::VarName, vi::AbstractVarInfo
)
    # Just defer to `SampleFromPrior`.
    retval = DynamicPPL.assume(rng, SampleFromPrior(), dist, vn, vi)
    return retval
    # Allow MH to sample new variables from the prior if it's not already present in the
    # VarInfo.
    dispatch_ctx = if haskey(vi, vn)
        DynamicPPL.DefaultContext()
    else
        DynamicPPL.InitContext(context.rng, DynamicPPL.InitFromPrior())
    end
    return DynamicPPL.tilde_assume!!(dispatch_ctx, right, vn, vi)
The behaviour of `SampleFromPrior` used to be: if the key is present, don't actually sample; if it is absent, sample. This if/else replicates the old behaviour.
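The "reuse if present, sample if absent" rule can be sketched on its own with a plain `Dict` standing in for the VarInfo. Everything here (`assume_or_sample!`, the `Dict` store, passing `randn` as the prior sampler) is a hypothetical stand-in for illustration, not DynamicPPL code.

```julia
using Random

# Hypothetical stand-in for a VarInfo: a Dict from variable name to value.
# `draw` is any function that takes an RNG and returns a fresh value.
function assume_or_sample!(
    rng::AbstractRNG, store::Dict{Symbol,Float64}, name::Symbol, draw
)
    if haskey(store, name)
        # Key present: reuse the stored value, do not sample.
        return store[name]
    else
        # Key absent: draw a new value and record it.
        value = draw(rng)
        store[name] = value
        return value
    end
end

store = Dict(:x => 1.0)
kept = assume_or_sample!(MersenneTwister(1), store, :x, randn)  # reuses 1.0
new_y = assume_or_sample!(MersenneTwister(1), store, :y, randn) # samples and stores
```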
sampler::S
varinfo::V
evaluator::E
resample::Bool
For pMCMC, this Boolean field essentially replaces the del flag. Instead of calling `set_all_del` and `unset_all_del`, we construct a new `TracedModel` with this field set to true and false respectively.
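As a rough sketch of that pattern, the flag can live on an immutable wrapper and be changed by building a new wrapper rather than mutating per-variable state. `ToyTracedModel` and `with_resample` are hypothetical names used only for illustration; the real `TracedModel` has more fields, as the diff above shows.

```julia
# Hypothetical, simplified wrapper: only the model and the flag are kept.
struct ToyTracedModel{M}
    model::M
    resample::Bool
end

# Return a copy with the flag changed; the original is left untouched.
with_resample(tm::ToyTracedModel, flag::Bool) = ToyTracedModel(tm.model, flag)

tm = ToyTracedModel("some model", false)
tm_resampling = with_resample(tm, true)   # analogous to the old "set all del"
tm_replaying = with_resample(tm, false)   # analogous to the old "unset all del"
```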
@test sample(StableRNG(23), xy(), spl_xy, num_samples).value ≈
    sample(StableRNG(23), x12(), spl_x, num_samples).value
chn1 = sample(StableRNG(23), xy(), spl_xy, num_samples)
chn2 = sample(StableRNG(23), x12(), spl_x, num_samples)

@test mean(chn1[:z]) ≈ mean(chn2[:z]) atol = 0.05
@test mean(chn1[:x]) ≈ mean(chn2["x[1]"]) atol = 0.05
@test mean(chn1[:y]) ≈ mean(chn2["x[2]"]) atol = 0.05
The values are no longer exactly the same (it has something to do with initialisation behaviour which is different for the two models). But we can still check that the results are sensibly similar, which is probably also more meaningful anyway as it means that ESS not only works on both models but also consistently converges regardless of how the model is specified.
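The difference between the two kinds of check can be shown without Turing at all. The sketch below uses plain standard-normal draws from two differently seeded RNGs as a stand-in for two chains; the sample size and seeds are arbitrary choices for illustration.

```julia
using Random, Statistics

n = 100_000
a = randn(MersenneTwister(1), n)  # stand-in for chain 1
b = randn(MersenneTwister(2), n)  # stand-in for chain 2

# Element-wise equality fails as soon as the random streams differ...
exactly_equal = a == b

# ...but summary statistics still agree within a tolerance.
means_agree = isapprox(mean(a), mean(b); atol=0.05)
```

With this many draws the standard error of each mean is roughly 0.003, so a tolerance of 0.05 leaves a wide margin.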
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## breaking #2676 +/- ##
===========================================
Coverage ? 85.10%
===========================================
Files ? 22
Lines ? 1410
Branches ? 0
===========================================
Hits ? 1200
Misses ? 210
Partials ? 0
seed = if dist isa GeneralizedExtremeValue
    # GEV is prone to giving really wacky results that are quite
    # seed-dependent.
    StableRNG(469)
else
    StableRNG(468)
end
chn = sample(seed, m(), HMC(0.05, 20), n_samples)
Case in point:
julia> using Turing, StableRNGs
julia> dist = GeneralizedExtremeValue(0, 1, 0.5); @model m() = x ~ dist
m (generic function with 2 methods)
julia> mean(dist)
1.5449077018110322
julia> mean(sample(StableRNG(468), m(), HMC(0.05, 20), 10000; progress=false))
Mean
parameters mean
Symbol Float64
x 3.9024
julia> mean(sample(StableRNG(469), m(), HMC(0.05, 20), 10000; progress=false))
Mean
parameters mean
Symbol Float64
x 1.5868
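The reference value `mean(dist)` above can be checked by hand from the closed-form GEV mean, μ + σ(Γ(1 − ξ) − 1)/ξ for ξ ≠ 0 and ξ < 1, using Γ(1/2) = √π. The sketch below uses only Base Julia. Note also that the GEV variance is finite only for ξ < 1/2, so at ξ = 0.5 the variance is infinite, which may be part of why sample means are so seed-dependent here (an interpretation, not something established in this thread).

```julia
# Parameters of GeneralizedExtremeValue(0, 1, 0.5) from the session above.
μ, σ, ξ = 0.0, 1.0, 0.5

# Closed-form mean; Γ(1 - ξ) = Γ(1/2) = sqrt(π) for this choice of ξ.
gev_mean = μ + σ * (sqrt(π) - 1) / ξ

# Agrees with the value printed by `mean(dist)` up to floating-point rounding.
matches_reference = isapprox(gev_mean, 1.5449077018110322; rtol=1e-12)
```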
For the record, 11 failing CI jobs is the expected number:
There is also the failing job caused by the base Julia segfault (#2655), but that's on 1.10 so overlaps with the first category.
This PR is being tested against this DynamicPPL branch: TuringLang/DynamicPPL.jl#1057
It should be noted that due to the changes in DynamicPPL's `src/sampler.jl`, the results of running MCMC sampling on this branch will pretty much always differ from those on the main branch. Thus there is no (easy) way to test full reproducibility of MCMC results (we have to rely instead on statistics for converged chains).
TODO:
Separate PRs:
use InitStrategy for optimisation as well
Note that the three pre-existing InitStrategies can be used directly with optimisation. However, to handle constraints properly, it seems necessary to introduce a new subtype of AbstractInitStrategy. I think this should be a separate PR because it's a fair bit of work.
fix docs for that argument, wherever it is (there's probably some in AbstractMCMC but it should probably be documented on the main site)