Skip to content

Commit a0cc64b

Browse files
authored
fix: fix selecting all (#28)
* fix: selecting all should set `select_models` to `None`
* chore: bump package
1 parent 2be6d82 commit a0cc64b

File tree

14 files changed

+196
-199
lines changed

14 files changed

+196
-199
lines changed

dagster_sqlmesh/asset.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
11
import logging
22
import typing as t
33

4-
from dagster import (
5-
AssetsDefinition,
6-
RetryPolicy,
7-
multi_asset,
8-
)
4+
from dagster import AssetsDefinition, RetryPolicy, multi_asset
95

106
from dagster_sqlmesh.controller import DagsterSQLMeshController
117
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator

dagster_sqlmesh/conftest.py

Lines changed: 1 addition & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,17 @@
44
import sys
55
import tempfile
66
import typing as t
7-
from dataclasses import dataclass
87

9-
import duckdb
10-
import polars
118
import pytest
129
from sqlmesh.core.config import (
1310
Config as SQLMeshConfig,
1411
DuckDBConnectionConfig,
1512
GatewayConfig,
1613
ModelDefaultsConfig,
1714
)
18-
from sqlmesh.core.console import get_console
19-
from sqlmesh.utils.date import TimeLike
2015

2116
from dagster_sqlmesh.config import SQLMeshContextConfig
22-
from dagster_sqlmesh.console import ConsoleEvent
23-
from dagster_sqlmesh.controller.base import PlanOptions, RunOptions
24-
from dagster_sqlmesh.controller.dagster import DagsterSQLMeshController
25-
from dagster_sqlmesh.events import ConsoleRecorder
17+
from dagster_sqlmesh.testing import SQLMeshTestContext
2618

2719
logger = logging.getLogger(__name__)
2820

@@ -50,110 +42,6 @@ def sample_sqlmesh_project() -> t.Iterator[str]:
5042
yield str(project_dir)
5143

5244

53-
@dataclass
54-
class SQLMeshTestContext:
55-
"""A test context for running SQLMesh"""
56-
57-
db_path: str
58-
context_config: SQLMeshContextConfig
59-
60-
def create_controller(
61-
self, enable_debug_console: bool = False
62-
) -> DagsterSQLMeshController:
63-
console = None
64-
if enable_debug_console:
65-
console = get_console()
66-
return DagsterSQLMeshController.setup_with_config(
67-
self.context_config, debug_console=console
68-
)
69-
70-
def query(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
71-
conn = duckdb.connect(self.db_path)
72-
return conn.sql(*args, **kwargs).fetchall()
73-
74-
def initialize_test_source(self) -> None:
75-
conn = duckdb.connect(self.db_path)
76-
conn.sql(
77-
"""
78-
CREATE SCHEMA sources;
79-
"""
80-
)
81-
conn.sql(
82-
"""
83-
CREATE TABLE sources.test_source (id INTEGER, name VARCHAR);
84-
"""
85-
)
86-
conn.sql(
87-
"""
88-
INSERT INTO sources.test_source (id, name)
89-
VALUES (1, 'abc'), (2, 'def');
90-
"""
91-
)
92-
conn.close()
93-
94-
def append_to_test_source(self, df: polars.DataFrame):
95-
logger.debug("appending data to the test source")
96-
conn = duckdb.connect(self.db_path)
97-
conn.sql(
98-
"""
99-
INSERT INTO sources.test_source
100-
SELECT * FROM df
101-
"""
102-
)
103-
104-
def plan_and_run(
105-
self,
106-
*,
107-
environment: str,
108-
execution_time: TimeLike | None = None,
109-
enable_debug_console: bool = False,
110-
start: TimeLike | None = None,
111-
end: TimeLike | None = None,
112-
select_models: list[str] | None = None,
113-
restate_selected: bool = False,
114-
skip_run: bool = False,
115-
) -> t.Iterator[ConsoleEvent] | None:
116-
"""Runs plan and run on SQLMesh with the given configuration and record all of the generated events.
117-
118-
Args:
119-
environment (str): The environment to run SQLMesh in.
120-
execution_time (TimeLike, optional): The execution timestamp for the run. Defaults to None.
121-
enable_debug_console (bool, optional): Flag to enable debug console. Defaults to False.
122-
start (TimeLike, optional): Start time for the run interval. Defaults to None.
123-
end (TimeLike, optional): End time for the run interval. Defaults to None.
124-
restate_models (List[str], optional): List of models to restate. Defaults to None.
125-
126-
Returns:
127-
None: The function records events to a debug console but doesn't return anything.
128-
129-
Note:
130-
TimeLike can be any time-like object that SQLMesh accepts (datetime, str, etc.).
131-
The function creates a controller and recorder to capture all SQLMesh events during execution.
132-
"""
133-
controller = self.create_controller(enable_debug_console=enable_debug_console)
134-
recorder = ConsoleRecorder()
135-
# controller.add_event_handler(ConsoleRecorder())
136-
plan_options = PlanOptions(
137-
enable_preview=True,
138-
)
139-
run_options = RunOptions()
140-
if execution_time:
141-
plan_options["execution_time"] = execution_time
142-
run_options["execution_time"] = execution_time
143-
144-
for event in controller.plan_and_run(
145-
environment,
146-
start=start,
147-
end=end,
148-
select_models=select_models,
149-
restate_selected=restate_selected,
150-
plan_options=plan_options,
151-
run_options=run_options,
152-
skip_run=skip_run,
153-
):
154-
recorder(event)
155-
156-
15745
@pytest.fixture
15846
def sample_sqlmesh_test_context(
15947
sample_sqlmesh_project: str,

dagster_sqlmesh/console.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,7 @@
1212
from sqlmesh.core.linter.rule import RuleViolation
1313
from sqlmesh.core.model import Model
1414
from sqlmesh.core.plan import EvaluatablePlan, PlanBuilder
15-
from sqlmesh.core.snapshot import (
16-
Snapshot,
17-
SnapshotChangeCategory,
18-
SnapshotInfoLike,
19-
)
15+
from sqlmesh.core.snapshot import Snapshot, SnapshotChangeCategory, SnapshotInfoLike
2016
from sqlmesh.core.table_diff import RowDiff, SchemaDiff, TableDiff
2117
from sqlmesh.utils.concurrency import NodeExecutionFailedError
2218

dagster_sqlmesh/controller/base.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,7 @@ def __init__(
128128
self.logger = logger
129129

130130
@contextmanager
131-
def console_context(
132-
self, handler: ConsoleEventHandler
133-
) -> t.Iterator[None]:
131+
def console_context(self, handler: ConsoleEventHandler) -> t.Iterator[None]:
134132
id = self.console.add_handler(handler)
135133
yield
136134
self.console.remove_handler(id)
@@ -224,9 +222,7 @@ def run_sqlmesh_thread(
224222

225223
thread.join()
226224

227-
def run(
228-
self, **run_options: t.Unpack[RunOptions]
229-
) -> t.Iterator[ConsoleEvent]:
225+
def run(self, **run_options: t.Unpack[RunOptions]) -> t.Iterator[ConsoleEvent]:
230226
"""Executes sqlmesh run in a separate thread with console output.
231227
232228
This method executes SQLMesh operations in a dedicated thread while
@@ -295,7 +291,7 @@ def plan_and_run(
295291
end: TimeLike | None = None,
296292
categorizer: SnapshotCategorizer | None = None,
297293
default_catalog: str | None = None,
298-
plan_options: PlanOptions | None= None,
294+
plan_options: PlanOptions | None = None,
299295
run_options: RunOptions | None = None,
300296
skip_run: bool = False,
301297
) -> t.Iterator[ConsoleEvent]:
@@ -309,11 +305,11 @@ def plan_and_run(
309305

310306
if plan_options.get("select_models") or run_options.get("select_models"):
311307
raise ValueError(
312-
"select_models should not be set in plan_options or run_options use the `select_models` or `select_models_func` arguments instead"
308+
"select_models should not be set in plan_options or run_options use the `select_models` option instead"
313309
)
314310
if plan_options.get("restate_models"):
315311
raise ValueError(
316-
"restate_models should not be set in plan_options use the `restate_selected` argument with `select_models` or `select_models_func` instead"
312+
"restate_models should not be set in plan_options use the `restate_selected` argument with `select_models` instead"
317313
)
318314
select_models = select_models or []
319315

@@ -352,6 +348,16 @@ def models(self) -> MappingProxyType[str, Model]:
352348
def models_dag(self) -> DAG[str]:
353349
return self.context.dag
354350

351+
def non_external_models_dag(self) -> t.Iterable[tuple[Model, set[str]]]:
352+
dag = self.context.dag
353+
354+
for model_fqn, deps in dag.graph.items():
355+
logger.debug(f"model found: {model_fqn}")
356+
model = self.context.get_model(model_fqn)
357+
if not model:
358+
continue
359+
yield (model, deps)
360+
355361

356362
class SQLMeshController:
357363
"""Allows control of sqlmesh via a python interface. It is not suggested to

dagster_sqlmesh/controller/dagster.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
import logging
22

3-
from dagster import (
4-
AssetDep,
5-
AssetKey,
6-
AssetOut,
7-
)
3+
from dagster import AssetDep, AssetKey, AssetOut
84
from dagster._core.definitions.asset_dep import CoercibleToAssetDep
95

106
from ..translator import SQLMeshDagsterTranslator
@@ -23,16 +19,10 @@ def to_asset_outs(
2319
) -> SQLMeshMultiAssetOptions:
2420
with self.instance(environment, "to_asset_outs") as instance:
2521
context = instance.context
26-
dag = context.dag
2722
output = SQLMeshMultiAssetOptions()
2823
depsMap: dict[str, CoercibleToAssetDep] = {}
2924

30-
for model_fqn, deps in dag.graph.items():
31-
logger.debug(f"model found: {model_fqn}")
32-
model = context.get_model(model_fqn)
33-
if not model:
34-
# If no model is returned this seems to be an asset dependency
35-
continue
25+
for model, deps in instance.non_external_models_dag():
3626
asset_key = translator.get_asset_key_from_model(
3727
context,
3828
model,

dagster_sqlmesh/resource.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
11
import logging
22
import typing as t
33

4-
from dagster import (
5-
AssetExecutionContext,
6-
ConfigurableResource,
7-
MaterializeResult,
8-
)
4+
from dagster import AssetExecutionContext, ConfigurableResource, MaterializeResult
95
from sqlmesh import Model
106
from sqlmesh.core.context import Context as SQLMeshContext
117
from sqlmesh.core.snapshot import Snapshot
@@ -14,7 +10,8 @@
1410

1511
from . import console
1612
from .config import SQLMeshContextConfig
17-
from .controller import PlanOptions, RunOptions, SQLMeshController
13+
from .controller import PlanOptions, RunOptions
14+
from .controller.dagster import DagsterSQLMeshController
1815
from .utils import sqlmesh_model_name_to_key
1916

2017

@@ -250,13 +247,17 @@ def run(
250247
logger = context.log
251248

252249
controller = self.get_controller(logger)
250+
253251
with controller.instance(environment) as mesh:
254252
dag = mesh.models_dag()
255253

256254
select_models = []
257255

258256
models = mesh.models()
259257
models_map = models.copy()
258+
all_available_models = set(
259+
[model.name for model, _ in mesh.non_external_models_dag()]
260+
)
260261
if context.selected_output_names:
261262
models_map = {}
262263
for key, model in models.items():
@@ -268,6 +269,14 @@ def run(
268269

269270
models_map[key] = model
270271
select_models.append(model.name)
272+
selected_models_set = set(models_map.keys())
273+
274+
if all_available_models == selected_models_set:
275+
logger.info("all models selected")
276+
277+
# Setting this to none to allow sqlmesh to select all models and
278+
# also remove any models
279+
select_models = None
271280

272281
event_handler = DagsterSQLMeshEventHandler(
273282
context, models_map, dag, "sqlmesh: "
@@ -285,7 +294,7 @@ def run(
285294

286295
def get_controller(
287296
self, log_override: logging.Logger | None = None
288-
) -> SQLMeshController:
289-
return SQLMeshController.setup_with_config(
297+
) -> DagsterSQLMeshController:
298+
return DagsterSQLMeshController.setup_with_config(
290299
self.config, log_override=log_override
291300
)

dagster_sqlmesh/test_asset.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
from dagster_sqlmesh.asset import (
2-
SQLMeshDagsterTranslator,
3-
)
1+
from dagster_sqlmesh.asset import SQLMeshDagsterTranslator
42
from dagster_sqlmesh.conftest import SQLMeshTestContext
53

64

dagster_sqlmesh/test_sqlmesh_context.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import polars
44

5-
from .conftest import SQLMeshTestContext
5+
from .testing import SQLMeshTestContext
66

77
logger = logging.getLogger(__name__)
88

@@ -184,12 +184,12 @@ def test_restating_models(sample_sqlmesh_test_context: SQLMeshTestContext):
184184
"""
185185
)
186186

187-
assert feb_sum_query_restate[0][0] == feb_sum_query[0][0], (
188-
"February sum should not change"
189-
)
190-
assert march_sum_query_restate[0][0] != march_sum_query[0][0], (
191-
"March sum should change"
192-
)
193-
assert intermediate_2_query_restate[0][0] == intermediate_2_query[0][0], (
194-
"Intermediate model should not change during restate"
195-
)
187+
assert (
188+
feb_sum_query_restate[0][0] == feb_sum_query[0][0]
189+
), "February sum should not change"
190+
assert (
191+
march_sum_query_restate[0][0] != march_sum_query[0][0]
192+
), "March sum should change"
193+
assert (
194+
intermediate_2_query_restate[0][0] == intermediate_2_query[0][0]
195+
), "Intermediate model should not change during restate"

dagster_sqlmesh/testing/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# ruff: noqa: F403 F401
2+
from .context import *

0 commit comments

Comments
 (0)