Skip to content

Commit fa2d86a

Browse files
committed
TEST: Test propagate_map_queries pass
1 parent 931597c commit fa2d86a

File tree

8 files changed

+46
-13
lines changed

8 files changed

+46
-13
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ readme = "README.md"
77
packages = [{include = "finch", from = "src"}]
88

99
[tool.poetry.dependencies]
10-
python = "^3.10"
10+
python = "^3.11"
1111
numpy = ">=1.19"
1212

1313
[tool.poetry.group.test.dependencies]

src/finch/autoschedule/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .finch_logic import (
1+
from ..finch_logic import (
22
Aggregate,
33
Alias,
44
Deferred,
@@ -15,7 +15,7 @@
1515
Table,
1616
)
1717
from .optimize import optimize, propagate_map_queries
18-
from .rewrite_tools import PostOrderDFS, PostWalk, PreWalk
18+
from ..symbolic import PostOrderDFS, PostWalk, PreWalk
1919

2020
__all__ = [
2121
"Aggregate",

src/finch/autoschedule/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from textwrap import dedent
33
from typing import Any
44

5-
from .finch_logic import (
5+
from ..finch_logic import (
66
Alias,
77
Deferred,
88
Field,

src/finch/autoschedule/executor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .compiler import LogicCompiler
2-
from .rewrite_tools import gensym
2+
from ..symbolic import gensym
33

44

55
class LogicExecutor:

src/finch/autoschedule/optimize.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .compiler import LogicCompiler
2-
from .finch_logic import Aggregate, Alias, LogicNode, MapJoin, Plan, Produces, Query
3-
from .rewrite_tools import Chain, PostOrderDFS, PostWalk, PreWalk, Rewrite
2+
from ..finch_logic import Aggregate, Alias, LogicNode, MapJoin, Plan, Produces, Query
3+
from ..symbolic import Chain, PostOrderDFS, PostWalk, PreWalk, Rewrite
44

55

66
def optimize(prgm: LogicNode) -> LogicNode:
@@ -52,4 +52,4 @@ def __init__(self, ctx: LogicCompiler):
5252

5353
def __call__(self, prgm: LogicNode):
5454
prgm = optimize(prgm)
55-
return self.ctx(prgm)
55+
return self.ctx(prgm)

src/finch/finch_logic/nodes.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ def head(cls):
2727
return cls
2828

2929
@classmethod
30-
def make_term(cls, head, *args):
31-
"""Creates a term with the given head and arguments."""
32-
return head(*args)
30+
def make_term(cls, *args):
31+
"""Creates a term of the given cls type and with arguments."""
32+
return cls(*args)
3333

3434

3535
@dataclass(eq=True, frozen=True)
@@ -191,6 +191,10 @@ def children(self):
191191
"""Returns the children of the node."""
192192
return [self.op, *self.args]
193193

194+
@classmethod
195+
def make_term(cls, op, *args):
196+
return cls(op, args)
197+
194198

195199
@dataclass(eq=True, frozen=True)
196200
class Aggregate(LogicNode):
@@ -361,6 +365,10 @@ def children(self):
361365
"""Returns the children of the node."""
362366
return [self.lhs, self.rhs]
363367

368+
@classmethod
369+
def make_term(cls, lhs, rhs):
370+
return cls(lhs, rhs)
371+
364372

365373
@dataclass(eq=True, frozen=True)
366374
class Produces(LogicNode):
@@ -412,3 +420,7 @@ def is_stateful():
412420
def children(self):
413421
"""Returns the children of the node."""
414422
return [*self.bodies]
423+
424+
@classmethod
425+
def make_term(cls, *val):
426+
return cls(val)

src/finch/symbolic/rewriters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,13 @@ def __call__(self, x: Term) -> Term | None:
5858
if y is not None:
5959
if y.is_expr():
6060
args = y.children()
61-
return y.make_term(y.head(), *[default_rewrite(self(arg), arg) for arg in args])
61+
return y.make_term(*[default_rewrite(self(arg), arg) for arg in args])
6262
return y
6363
if x.is_expr():
6464
args = x.children()
6565
new_args = list(map(self, args))
6666
if not all(arg is None for arg in new_args):
67-
return x.make_term(x.head(), *map(lambda x1, x2: default_rewrite(x1, x2), new_args, args))
67+
return x.make_term(*map(lambda x1, x2: default_rewrite(x1, x2), new_args, args))
6868
return None
6969
return None
7070

tests/test_scheduler.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from finch.autoschedule import propagate_map_queries
2+
from finch.finch_logic import *
3+
4+
5+
def test_propagate_map_queries_simple():
6+
plan = Plan(
7+
(
8+
Query(Alias("A10"), Aggregate(Immediate("+"), Immediate(0), Immediate("[1,2,3]"), ())),
9+
Query(Alias("A11"), Alias("A10")),
10+
Produces((Alias("11"),)),
11+
)
12+
)
13+
expected = Plan(
14+
(
15+
Query(Alias("A11"), MapJoin(Immediate("+"), (Immediate(0), Immediate("[1,2,3]")))),
16+
Produces((Alias("11"),)),
17+
)
18+
)
19+
20+
result = propagate_map_queries(plan)
21+
assert result == expected

0 commit comments

Comments
 (0)