Skip to content

[WIP] Commutative operations and negative exponent match in rules #752

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 25 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
a07bdbc
first version, really chaotic, and doesn't work with defslot powers
Bumblebee00 Jun 14, 2025
12843da
second version, really chaotic, but works with defslot powers
Bumblebee00 Jun 14, 2025
d81145e
fix typo
Bumblebee00 Jun 18, 2025
79118fc
operation + and * are always commutative now
Bumblebee00 Jun 18, 2025
cd0cc33
added some tests of commutative operations
Bumblebee00 Jun 18, 2025
bd06d79
fixed bug on defslot functionality
Bumblebee00 Jun 19, 2025
a1da82d
added defslot on operations with multiple arguments
Bumblebee00 Jun 19, 2025
7849e7a
moved the commutativity checks to only the acrule macro
Bumblebee00 Jun 19, 2025
a7d57e9
negative exponent feature is done in a different way, more clean
Bumblebee00 Jun 20, 2025
50b5e50
fixed failing ci tests
Bumblebee00 Jun 20, 2025
3bd1282
added tests with defslot in operation call with more than two arguments
Bumblebee00 Jun 20, 2025
6825df3
now rationals can be used in rules
Bumblebee00 Jun 21, 2025
e6bce15
created smrule (sum multiplication rule) macro
Bumblebee00 Jun 22, 2025
f8c8841
enhance commutative term matcher to validate operation type
Bumblebee00 Jun 22, 2025
e742a84
fixed bug in defslot code and improved performance
Bumblebee00 Jun 22, 2025
9e4596d
improved negative exponent pattern matching. now it matches also for…
Bumblebee00 Jun 22, 2025
bdce8c4
changed order of checks in pow term matcher
Bumblebee00 Jun 24, 2025
8c8a207
added match for exp and sqrt calls
Bumblebee00 Jun 27, 2025
08e9993
removed smrule macro and added commutativity checks to the rule macro
Bumblebee00 Jun 30, 2025
80cabb1
added commutativity checks also for segment matcher
Bumblebee00 Jul 7, 2025
2dbff77
fixed predicates with defslots
Bumblebee00 Jul 7, 2025
734d1b9
now the pattern ~x^~m matches 1/x with m=-1
Bumblebee00 Aug 3, 2025
4a49b19
added tests for power match with sqrt and exp functions
Bumblebee00 Aug 6, 2025
05a5af2
refactor
Bumblebee00 Aug 6, 2025
36034e0
now ...^(1//2) matches in the rule with sqrt, and ℯ^... matches in th…
Bumblebee00 Aug 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 174 additions & 79 deletions src/matchers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@
# 3. Callback: takes arguments Dictionary × Number of elements matched
#

function matcher(val::Any)
function matcher(val::Any, acSets)
# if val is a call (like an operation) creates a term matcher or term matcher with defslot
if iscall(val)
# if has two arguments and one of them is a DefSlot, create a term matcher with defslot
if length(arguments(val)) == 2 && any(x -> isa(x, DefSlot), arguments(val))
return defslot_term_matcher_constructor(val)
# else return a normal term matcher
else
return term_matcher_constructor(val)
# just two arguments bc defslot is only supported with operations with two args: *, ^, +
if any(x -> isa(x, DefSlot), arguments(val))
return defslot_term_matcher_constructor(val, acSets)
end
# else return a normal term matcher
return term_matcher_constructor(val, acSets)
end

function literal_matcher(next, data, bindings)
Expand All @@ -24,7 +24,8 @@ function matcher(val::Any)
end
end

function matcher(slot::Slot)
# acSets is not used but needs to be there in case matcher(::Slot) is directly called from the macro
function matcher(slot::Slot, acSets)
function slot_matcher(next, data, bindings)
!islist(data) && return nothing
val = get(bindings, slot.name, nothing)
Expand All @@ -43,8 +44,8 @@ end
# Matcher for a DefSlot.
# This is called only when defslot_term_matcher finds the operation and tries
# to match it, so no default value is used. The same function as slot_matcher
# can therefore be used: the DefSlot is converted to a plain Slot with the same
# name and predicate.
function matcher(defslot::DefSlot)
matcher(Slot(defslot.name, defslot.predicate))
function matcher(defslot::DefSlot, acSets)
matcher(Slot(defslot.name, defslot.predicate), nothing) # the slot matcher doesn't use acSets, so `nothing` is passed
end

# returns n == offset, 0 if failed
Expand Down Expand Up @@ -75,7 +76,7 @@ function trymatchexpr(data, value, n)
end
end

function matcher(segment::Segment)
function matcher(segment::Segment, acSets)
function segment_matcher(success, data, bindings)
val = get(bindings, segment.name, nothing)

Expand All @@ -90,98 +91,192 @@ function matcher(segment::Segment)
for i=length(data):-1:0
subexpr = take_n(data, i)

if segment.predicate(subexpr)
res = success(assoc(bindings, segment.name, subexpr), i)
if res !== nothing
break
end
end
!segment.predicate(subexpr) && continue
res = success(assoc(bindings, segment.name, subexpr), i)
res !== nothing && break
end

return res
end
end
end

# Builds a matcher closure for the call pattern `term`: one matcher is created for the
# operation and one for each argument, and `loop` applies them in sequence.
# The (term, acSets) form returns a different closure depending on operation(term):
#   ^       -> pow_term_matcher: also tries data rewritten from (1/x)^n, 1/x^n, 1/x, exp(x), sqrt(x)
#   + or *  -> commutative_term_matcher (only when acSets !== nothing): tries reordered arguments
#   sqrt    -> sqrt_matcher: also tries data of the form x^(1//2)
#   exp     -> exp_matcher: also tries data of the form ℯ^x
#   else    -> term_matcher: matches car(data) as it is
# Every closure takes (success, data, bindings) and returns success(result, 1) on a match,
# `nothing` otherwise.
# NOTE(review): this span appears to be a rendered diff, so lines of the one-argument
# version and of the two-argument version are interleaved below.
function term_matcher_constructor(term)
matchers = (matcher(operation(term)), map(matcher, arguments(term))...,)
function term_matcher_constructor(term, acSets)
matchers = (matcher(operation(term), acSets), map(x->matcher(x,acSets), arguments(term))...,)

function loop(term, bindings′, matchers′) # Get it to compile faster
if !islist(matchers′)
if !islist(term)
return bindings′
end
return nothing
end
car(matchers′)(term, bindings′) do b, n
loop(drop_n(term, n), b, cdr(matchers′))
end
# explanation of the above 3 lines:
# car(matchers′)(b,n -> loop(drop_n(term, n), b, cdr(matchers′)), term, bindings′)
# <------ next(b,n) ---------------------------->
# car = first element of list, cdr = rest of the list, drop_n = drop first n elements of list
# Calls the first matcher, with the "next" function being loop again but with n terms dropped from term
# Term is a linked list (a list and an index). drop_n advances the index. When the index surpasses
# the length of the list, it is considered empty
end

function term_matcher(success, data, bindings)
!islist(data) && return nothing # if data is not a list, return nothing
!iscall(car(data)) && return nothing # if first element is not a call, return nothing
# if the operation is a pow, we have to match also 1/(...)^(...) with negative exponent
if operation(term) === ^
function pow_term_matcher(success, data, bindings)
!islist(data) && return nothing # if data is not a list, return nothing
data = car(data) # from (..., ) to ...
!iscall(data) && return nothing # if first element is not a call, return nothing

# first try to match the data as it is
result = loop(data, bindings, matchers)
result !== nothing && return success(result, 1)

# `frankestein` holds the data rewritten as a ^ term, if one of the alternative forms below applies
frankestein = nothing
if (operation(data) === ^) && iscall(arguments(data)[1]) && (operation(arguments(data)[1]) === /) && isequal(arguments(arguments(data)[1])[1], 1)
# if data is of the alternative form (1/...)^(...)
one_over_smth = arguments(data)[1]
T = symtype(one_over_smth)
frankestein = Term{T}(^, [arguments(one_over_smth)[2], -arguments(data)[2]])
elseif (operation(data) === /) && isequal(arguments(data)[1], 1) && iscall(arguments(data)[2]) && (operation(arguments(data)[2]) === ^)
# if data is of the alternative form 1/(...)^(...)
denominator = arguments(data)[2]
T = symtype(denominator)
frankestein = Term{T}(^, [arguments(denominator)[1], -arguments(denominator)[2]])
elseif (operation(data) === /) && isequal(arguments(data)[1], 1)
# if data is of the alternative form 1/(...), it might match with exponent = -1
denominator = arguments(data)[2]
T = symtype(denominator)
frankestein = Term{T}(^, [denominator, -1])
elseif operation(data)===exp
# if data is an exp call, it might match with base ℯ
T = symtype(arguments(data)[1])
frankestein = Term{T}(^,[ℯ, arguments(data)[1]])
elseif operation(data)===sqrt
# if data is a sqrt call, it might match with exponent 1//2
T = symtype(arguments(data)[1])
frankestein = Term{T}(^,[arguments(data)[1], 1//2])
end

# second attempt, on the rewritten term
if frankestein !==nothing
result = loop(frankestein, bindings, matchers)
result !== nothing && return success(result, 1)
end

return nothing
end
return pow_term_matcher
# if we want to do commutative checks, i.e. call the matchers with a different order of the arguments
elseif acSets!==nothing && operation(term) in [+, *]
function commutative_term_matcher(success, data, bindings)
!islist(data) && return nothing # if data is not a list, return nothing
!iscall(car(data)) && return nothing # if first element is not a call, return nothing
operation(term) !== operation(car(data)) && return nothing # if the operation of data is not the correct one, don't even try

T = symtype(car(data))
if T <: Number
f = operation(car(data))
data_args = arguments(car(data))

# each `inds` produced by acSets indexes data_args to build one candidate term
for inds in acSets(eachindex(data_args), length(data_args))
candidate = Term{T}(f, @views data_args[inds])

function loop(term, bindings′, matchers′) # Get it to compile faster
if !islist(matchers′)
if !islist(term)
return success(bindings′, 1)
result = loop(candidate, bindings, matchers)
result !== nothing && return success(result,1)
end
return nothing
# if the symtype of car(data) is not a subtype of Number, it might not be commutative
else
# call the normal matcher
result = loop(car(data), bindings, matchers)
result !== nothing && return success(result, 1)
end
car(matchers′)(term, bindings′) do b, n
loop(drop_n(term, n), b, cdr(matchers′))
return nothing
end
return commutative_term_matcher
# if the operation is sqrt, we have to match also ^(1//2)
elseif operation(term)==sqrt
function sqrt_matcher(success, data, bindings)
!islist(data) && return nothing # if data is not a list, return nothing
data = car(data)
!iscall(data) && return nothing # if first element is not a call, return nothing

# do the normal matcher
result = loop(data, bindings, matchers)
result !== nothing && return success(result, 1)

# data of the form (...)^(1//2) is rewritten as sqrt(...) and tried again
if (operation(data) === ^) && (arguments(data)[2] === 1//2)
T = symtype(arguments(data)[1])
frankestein = Term{T}(sqrt,[arguments(data)[1]])
result = loop(frankestein, bindings, matchers)
result !== nothing && return success(result, 1)
end
# explanation of the above 3 lines:
# car(matchers′)(b,n -> loop(drop_n(term, n), b, cdr(matchers′)), term, bindings′)
# <------ next(b,n) ---------------------------->
# car = first element of list, cdr = rest of the list, drop_n = drop first n elements of list
# Calls the first matcher, with the "next" function being loop again but with n terms dropped from term
# Term is a linked list (a list and an index). drop_n advances the index. When the index surpasses
# the length of the list, it is considered empty
return nothing
end
return sqrt_matcher
# if the operation is exp, we have to match also ℯ^
elseif operation(term)==exp
function exp_matcher(success, data, bindings)
!islist(data) && return nothing # if data is not a list, return nothing
data = car(data)
!iscall(data) && return nothing # if first element is not a call, return nothing

# do the normal matcher
result = loop(data, bindings, matchers)
result !== nothing && return success(result, 1)

loop(car(data), bindings, matchers) # Try to eat exactly one term
# data of the form ℯ^(...) is rewritten as exp(...) and tried again
if (operation(data) === ^) && (arguments(data)[1] === ℯ)
T = symtype(arguments(data)[2])
frankestein = Term{T}(exp,[arguments(data)[2]])
result = loop(frankestein, bindings, matchers)
result !== nothing && return success(result, 1)
end
return nothing
end
return exp_matcher
else
function term_matcher(success, data, bindings)
!islist(data) && return nothing # if data is not a list, return nothing
!iscall(car(data)) && return nothing # if first element is not a call, return nothing

result = loop(car(data), bindings, matchers)
result !== nothing && return success(result, 1)
return nothing
end
return term_matcher
end
end

# creates a matcher for a term containing a defslot, such as:
# (~x + ...complicated pattern...) * ~!y
# normal part (can be a tree) operation defslot part

# defslot_term_matcher works like this:
# checks whether data starts with the default operation.
# if yes (1): continues like term_matcher
# if no checks whether data matches the normal part
# if no returns nothing, rule is not applied
# if yes (2): adds the pair (default value name, default value) to the found bindings and
# calls the success function like term_matcher would do
# NOTE(review): the steps above describe the variant below that compares the operation
# name with `defslot.operation` first. The variant built on `normal_matcher` instead
# tries a full match first and falls back to the default value only if that fails.
# This span appears to be a rendered diff, so lines of both variants are interleaved.

function defslot_term_matcher_constructor(term)
a = arguments(term) # length two because defslot term matcher is allowed only with +, * and ^, which accept two arguments
matchers = (matcher(operation(term)), map(matcher, a)...) # create matchers for the operation and the two arguments of the term

function defslot_term_matcher_constructor(term, acSets)
a = arguments(term)
defslot_index = findfirst(x -> isa(x, DefSlot), a) # find the defslot in the term
defslot = a[defslot_index]
# build the matcher for everything except the defslot:
# with two arguments it is the matcher of the other argument, otherwise a term matcher
# over the same operation applied to the remaining arguments
if length(a) == 2
other_part_matcher = matcher(a[defslot_index == 1 ? 2 : 1], acSets)
else
others = [a[i] for i in eachindex(a) if i != defslot_index]
T = symtype(term)
f = operation(term)
other_part_matcher = term_matcher_constructor(Term{T}(f, others), acSets)
end

function defslot_term_matcher(success, data, bindings)
# if data is not a list, return nothing
!islist(data) && return nothing
# if data (is not a tree and is just a symbol) or (is a tree not starting with the default operation)
if !iscall(car(data)) || (iscall(car(data)) && nameof(operation(car(data))) != defslot.operation)
other_part_matcher = matchers[defslot_index==2 ? 2 : 3] # find the matcher of the normal part

# checks whether it matches the normal part
# <-----------------(2)------------------------------->
bindings = other_part_matcher((b,n) -> assoc(b, defslot.name, defslot.defaultValue), data, bindings)

if bindings === nothing
return nothing
end
return success(bindings, 1)
end

# (1)
function loop(term, bindings′, matchers′) # Get it to compile faster
if !islist(matchers′)
if !islist(term)
return success(bindings′, 1)
end
return nothing
end
car(matchers′)(term, bindings′) do b, n
loop(drop_n(term, n), b, cdr(matchers′))
end
end
# matcher for the whole term, used for the attempt that does not rely on the default value
normal_matcher = term_matcher_constructor(term, acSets)

loop(car(data), bindings, matchers) # Try to eat exactly one term
function defslot_term_matcher(success, data, bindings)
!islist(data) && return nothing # if data is not a list, return nothing
# call the normal matcher, with success function foo1 that simply returns the bindings
# <--foo1-->
result = normal_matcher((b,n) -> b, data, bindings)
result !== nothing && return success(result, 1)
# if no match, try to match with a defslot.
# checks whether it matches the normal part; if yes, executes foo2
# foo2: adds the pair (default value name, default value) to the found bindings
# <-------------------foo2---------------------------->
result = other_part_matcher((b,n) -> assoc(b, defslot.name, defslot.defaultValue), data, bindings)
result !== nothing && return success(result, 1)
nothing
end
end
Loading
Loading