Skip to content

Commit c16985d

Browse files
authored
Merge pull request #726 from AayushSabharwal/as/wvd
refactor: return to using `WeakValueDict`
2 parents 4fa6e63 + 7244b4b commit c16985d

File tree

3 files changed

+19
-25
lines changed

3 files changed

+19
-25
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ TaskLocalValues = "ed4db957-447d-4319-bfb6-7fa9ae7ecf34"
2727
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
2828
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
2929
Unityper = "a7c27f48-0311-42f6-a7f8-2c11e75eb415"
30+
WeakValueDicts = "897b6980-f191-5a31-bcb0-bf3c4585e0c1"
3031

3132
[weakdeps]
3233
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
@@ -61,6 +62,7 @@ TaskLocalValues = "0.1.2"
6162
TermInterface = "2.0"
6263
TimerOutputs = "0.5"
6364
Unityper = "0.1.2"
65+
WeakValueDicts = "0.1.0"
6466
julia = "1.10"
6567

6668
[extras]

src/SymbolicUtils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import TermInterface: iscall, isexpr, head, children,
2222
import ArrayInterface
2323
import ExproniconLite as EL
2424
import TaskLocalValues: TaskLocalValue
25+
import WeakValueDicts: WeakValueDict
2526

2627
include("cache.jl")
2728
Base.@deprecate istree iscall

src/types.jl

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -24,39 +24,39 @@ const EMPTY_DICT_T = typeof(EMPTY_DICT)
2424
const ENABLE_HASHCONSING = Ref(true)
2525

2626
@compactify show_methods=false begin
27-
@abstract struct BasicSymbolic{T} <: Symbolic{T}
27+
@abstract mutable struct BasicSymbolic{T} <: Symbolic{T}
2828
metadata::Metadata = NO_METADATA
2929
end
30-
struct Sym{T} <: BasicSymbolic{T}
30+
mutable struct Sym{T} <: BasicSymbolic{T}
3131
name::Symbol = :OOF
3232
end
33-
struct Term{T} <: BasicSymbolic{T}
33+
mutable struct Term{T} <: BasicSymbolic{T}
3434
f::Any = identity # base/num if Pow; issorted if Add/Dict
3535
arguments::SmallV{Any} = EMPTY_ARGS
3636
hash::RefValue{UInt} = EMPTY_HASH
3737
hash2::RefValue{UInt} = EMPTY_HASH
3838
end
39-
struct Mul{T} <: BasicSymbolic{T}
39+
mutable struct Mul{T} <: BasicSymbolic{T}
4040
coeff::Any = 0 # exp/den if Pow
4141
dict::EMPTY_DICT_T = EMPTY_DICT
4242
hash::RefValue{UInt} = EMPTY_HASH
4343
hash2::RefValue{UInt} = EMPTY_HASH
4444
arguments::SmallV{Any} = EMPTY_ARGS
4545
end
46-
struct Add{T} <: BasicSymbolic{T}
46+
mutable struct Add{T} <: BasicSymbolic{T}
4747
coeff::Any = 0 # exp/den if Pow
4848
dict::EMPTY_DICT_T = EMPTY_DICT
4949
hash::RefValue{UInt} = EMPTY_HASH
5050
hash2::RefValue{UInt} = EMPTY_HASH
5151
arguments::SmallV{Any} = EMPTY_ARGS
5252
end
53-
struct Div{T} <: BasicSymbolic{T}
53+
mutable struct Div{T} <: BasicSymbolic{T}
5454
num::Any = 1
5555
den::Any = 1
5656
simplified::Bool = false
5757
arguments::SmallV{Any} = EMPTY_ARGS
5858
end
59-
struct Pow{T} <: BasicSymbolic{T}
59+
mutable struct Pow{T} <: BasicSymbolic{T}
6060
base::Any = 1
6161
exp::Any = 1
6262
arguments::SmallV{Any} = EMPTY_ARGS
@@ -87,15 +87,7 @@ function exprtype(x::BasicSymbolic)
8787
end
8888
end
8989

90-
mutable struct HashConsingWrapper
91-
bs::BasicSymbolic
92-
end
93-
94-
Base.hash(x::HashConsingWrapper, h::UInt) = hash2(x.bs, h)
95-
96-
Base.isequal(x::HashConsingWrapper, y::HashConsingWrapper) = isequal_with_metadata(x.bs, y.bs)
97-
98-
const wkd = TaskLocalValue{WeakKeyDict{HashConsingWrapper, Nothing}}(WeakKeyDict{HashConsingWrapper, Nothing})
90+
const wvd = TaskLocalValue{WeakValueDict{UInt, BasicSymbolic}}(WeakValueDict{UInt, BasicSymbolic})
9991

10092
# Same but different error messages
10193
@noinline error_on_type() = error("Internal error: unreachable reached!")
@@ -531,15 +523,15 @@ Implements hash consing (flyweight design pattern) for `BasicSymbolic` objects.
531523
532524
This function checks if an equivalent `BasicSymbolic` object already exists. It uses a
533525
custom hash function (`hash2`) incorporating metadata and symtypes to search for existing
534-
objects in a `WeakKeyDict` (`wkd`). Due to the possibility of hash collisions (where
526+
objects in a `WeakValueDict` (`wvd`). Due to the possibility of hash collisions (where
535527
different objects produce the same hash), a custom equality check (`isequal_with_metadata`)
536528
which includes metadata comparison, is used to confirm the equivalence of objects with
537529
matching hashes. If an equivalent object is found, the existing object is returned;
538530
otherwise, the input `s` is returned. This reduces memory usage, improves compilation time
539531
for runtime code generation, and supports built-in common subexpression elimination,
540532
particularly when working with symbolic objects with metadata.
541533
542-
Using a `WeakKeyDict` ensures that only weak references to `BasicSymbolic` objects are
534+
Using a `WeakValueDict` ensures that only weak references to `BasicSymbolic` objects are
543535
stored, allowing objects that are no longer strongly referenced to be garbage collected.
544536
Custom functions `hash2` and `isequal_with_metadata` are used instead of `Base.hash` and
545537
`Base.isequal` to accommodate metadata without disrupting existing tests reliant on the
@@ -549,14 +541,13 @@ function BasicSymbolic(s::BasicSymbolic)::BasicSymbolic
549541
if !ENABLE_HASHCONSING[]
550542
return s
551543
end
552-
cache = wkd[]
553-
hcw = HashConsingWrapper(s)
554-
k = getkey(cache, hcw, nothing)
555-
if isnothing(k)
556-
cache[hcw] = nothing
557-
return s
544+
cache = wvd[]
545+
h = hash2(s)
546+
k = get!(cache, h, s)
547+
if isequal_with_metadata(k, s)
548+
return k
558549
else
559-
return k.bs
550+
return s
560551
end
561552
end
562553

0 commit comments

Comments
 (0)