Skip to content

fix: fix selecting all #28

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions dagster_sqlmesh/asset.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
import logging
import typing as t

from dagster import (
AssetsDefinition,
RetryPolicy,
multi_asset,
)
from dagster import AssetsDefinition, RetryPolicy, multi_asset

from dagster_sqlmesh.controller import DagsterSQLMeshController
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator
Expand Down
114 changes: 1 addition & 113 deletions dagster_sqlmesh/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,17 @@
import sys
import tempfile
import typing as t
from dataclasses import dataclass

import duckdb
import polars
import pytest
from sqlmesh.core.config import (
Config as SQLMeshConfig,
DuckDBConnectionConfig,
GatewayConfig,
ModelDefaultsConfig,
)
from sqlmesh.core.console import get_console
from sqlmesh.utils.date import TimeLike

from dagster_sqlmesh.config import SQLMeshContextConfig
from dagster_sqlmesh.console import ConsoleEvent
from dagster_sqlmesh.controller.base import PlanOptions, RunOptions
from dagster_sqlmesh.controller.dagster import DagsterSQLMeshController
from dagster_sqlmesh.events import ConsoleRecorder
from dagster_sqlmesh.testing import SQLMeshTestContext

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -50,110 +42,6 @@ def sample_sqlmesh_project() -> t.Iterator[str]:
yield str(project_dir)


@dataclass
class SQLMeshTestContext:
    """A test context for running SQLMesh against a DuckDB database.

    Attributes:
        db_path: Database path passed to ``duckdb.connect`` by every helper.
        context_config: Configuration used to build the controller.
    """

    db_path: str
    context_config: SQLMeshContextConfig

    def create_controller(
        self, enable_debug_console: bool = False
    ) -> DagsterSQLMeshController:
        """Build a controller from ``context_config``.

        Args:
            enable_debug_console: When true, the console returned by
                ``get_console()`` is passed as the controller's debug console;
                otherwise ``None`` is passed.
        """
        console = get_console() if enable_debug_console else None
        return DagsterSQLMeshController.setup_with_config(
            self.context_config, debug_console=console
        )

    def query(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
        """Run a query against the test database and return all rows.

        All arguments are forwarded unchanged to ``conn.sql``.
        """
        conn = duckdb.connect(self.db_path)
        try:
            # Rows are fully fetched before the connection is released.
            return conn.sql(*args, **kwargs).fetchall()
        finally:
            conn.close()

    def initialize_test_source(self) -> None:
        """Create ``sources.test_source`` and seed it with two rows."""
        conn = duckdb.connect(self.db_path)
        try:
            conn.sql(
                """
                CREATE SCHEMA sources;
                """
            )
            conn.sql(
                """
                CREATE TABLE sources.test_source (id INTEGER, name VARCHAR);
                """
            )
            conn.sql(
                """
                INSERT INTO sources.test_source (id, name)
                VALUES (1, 'abc'), (2, 'def');
                """
            )
        finally:
            # Close even when a statement fails so the database file is not
            # left held open by a dangling connection.
            conn.close()

    def append_to_test_source(self, df: polars.DataFrame) -> None:
        """Append the rows of ``df`` to ``sources.test_source``."""
        logger.debug("appending data to the test source")
        conn = duckdb.connect(self.db_path)
        try:
            # NOTE: the SQL refers to the local variable ``df`` by name, so
            # the parameter must keep this name and stay in this frame.
            conn.sql(
                """
                INSERT INTO sources.test_source
                SELECT * FROM df
                """
            )
        finally:
            conn.close()

    def plan_and_run(
        self,
        *,
        environment: str,
        execution_time: TimeLike | None = None,
        enable_debug_console: bool = False,
        start: TimeLike | None = None,
        end: TimeLike | None = None,
        select_models: list[str] | None = None,
        restate_selected: bool = False,
        skip_run: bool = False,
    ) -> t.Iterator[ConsoleEvent] | None:
        """Run SQLMesh plan and run, passing every generated event to a recorder.

        Args:
            environment (str): The environment to run SQLMesh in.
            execution_time (TimeLike, optional): The execution timestamp. When
                set, it is applied to both the plan and the run options.
                Defaults to None.
            enable_debug_console (bool, optional): Flag to enable debug console.
                Defaults to False.
            start (TimeLike, optional): Start time for the run interval.
                Defaults to None.
            end (TimeLike, optional): End time for the run interval.
                Defaults to None.
            select_models (list[str], optional): Forwarded to the controller's
                ``plan_and_run``. Defaults to None.
            restate_selected (bool, optional): Forwarded to the controller's
                ``plan_and_run``. Defaults to False.
            skip_run (bool, optional): Forwarded to the controller's
                ``plan_and_run``. Defaults to False.

        Returns:
            None: Events are passed to a ``ConsoleRecorder``; nothing is
            returned.

        Note:
            TimeLike can be any time-like object that SQLMesh accepts
            (datetime, str, etc.).
        """
        controller = self.create_controller(enable_debug_console=enable_debug_console)
        recorder = ConsoleRecorder()
        plan_options = PlanOptions(
            enable_preview=True,
        )
        run_options = RunOptions()
        if execution_time:
            plan_options["execution_time"] = execution_time
            run_options["execution_time"] = execution_time

        for event in controller.plan_and_run(
            environment,
            start=start,
            end=end,
            select_models=select_models,
            restate_selected=restate_selected,
            plan_options=plan_options,
            run_options=run_options,
            skip_run=skip_run,
        ):
            recorder(event)


@pytest.fixture
def sample_sqlmesh_test_context(
sample_sqlmesh_project: str,
Expand Down
6 changes: 1 addition & 5 deletions dagster_sqlmesh/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,7 @@
from sqlmesh.core.linter.rule import RuleViolation
from sqlmesh.core.model import Model
from sqlmesh.core.plan import EvaluatablePlan, PlanBuilder
from sqlmesh.core.snapshot import (
Snapshot,
SnapshotChangeCategory,
SnapshotInfoLike,
)
from sqlmesh.core.snapshot import Snapshot, SnapshotChangeCategory, SnapshotInfoLike
from sqlmesh.core.table_diff import RowDiff, SchemaDiff, TableDiff
from sqlmesh.utils.concurrency import NodeExecutionFailedError

Expand Down
24 changes: 15 additions & 9 deletions dagster_sqlmesh/controller/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,7 @@ def __init__(
self.logger = logger

@contextmanager
def console_context(self, handler: ConsoleEventHandler) -> t.Iterator[None]:
    """Register ``handler`` on the console for the duration of a ``with`` block.

    The handler is added via ``self.console.add_handler`` on entry and removed
    via ``self.console.remove_handler`` on exit, including when the body of the
    ``with`` block raises.

    Args:
        handler: The console event handler to attach temporarily.
    """
    handler_id = self.console.add_handler(handler)
    try:
        yield
    finally:
        # Without the finally, an exception in the with-body would leave the
        # handler registered on the console.
        self.console.remove_handler(handler_id)
Expand Down Expand Up @@ -224,9 +222,7 @@ def run_sqlmesh_thread(

thread.join()

def run(
self, **run_options: t.Unpack[RunOptions]
) -> t.Iterator[ConsoleEvent]:
def run(self, **run_options: t.Unpack[RunOptions]) -> t.Iterator[ConsoleEvent]:
"""Executes sqlmesh run in a separate thread with console output.

This method executes SQLMesh operations in a dedicated thread while
Expand Down Expand Up @@ -295,7 +291,7 @@ def plan_and_run(
end: TimeLike | None = None,
categorizer: SnapshotCategorizer | None = None,
default_catalog: str | None = None,
plan_options: PlanOptions | None= None,
plan_options: PlanOptions | None = None,
run_options: RunOptions | None = None,
skip_run: bool = False,
) -> t.Iterator[ConsoleEvent]:
Expand All @@ -309,11 +305,11 @@ def plan_and_run(

if plan_options.get("select_models") or run_options.get("select_models"):
raise ValueError(
"select_models should not be set in plan_options or run_options use the `select_models` or `select_models_func` arguments instead"
"select_models should not be set in plan_options or run_options use the `select_models` option instead"
)
if plan_options.get("restate_models"):
raise ValueError(
"restate_models should not be set in plan_options use the `restate_selected` argument with `select_models` or `select_models_func` instead"
"restate_models should not be set in plan_options use the `restate_selected` argument with `select_models` instead"
)
select_models = select_models or []

Expand Down Expand Up @@ -352,6 +348,16 @@ def models(self) -> MappingProxyType[str, Model]:
def models_dag(self) -> DAG[str]:
return self.context.dag

def non_external_models_dag(self) -> t.Iterable[tuple[Model, set[str]]]:
    """Yield ``(model, deps)`` for every DAG node the context resolves to a model.

    Iterates ``self.context.dag.graph`` and looks each node up with
    ``self.context.get_model``. Nodes for which no model is returned are
    skipped (presumably external dependencies, going by the method name).
    """
    for fqn, upstream in self.context.dag.graph.items():
        logger.debug(f"model found: {fqn}")
        resolved = self.context.get_model(fqn)
        if resolved:
            yield (resolved, upstream)


class SQLMeshController:
"""Allows control of sqlmesh via a python interface. It is not suggested to
Expand Down
14 changes: 2 additions & 12 deletions dagster_sqlmesh/controller/dagster.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import logging

from dagster import (
AssetDep,
AssetKey,
AssetOut,
)
from dagster import AssetDep, AssetKey, AssetOut
from dagster._core.definitions.asset_dep import CoercibleToAssetDep

from ..translator import SQLMeshDagsterTranslator
Expand All @@ -23,16 +19,10 @@ def to_asset_outs(
) -> SQLMeshMultiAssetOptions:
with self.instance(environment, "to_asset_outs") as instance:
context = instance.context
dag = context.dag
output = SQLMeshMultiAssetOptions()
depsMap: dict[str, CoercibleToAssetDep] = {}

for model_fqn, deps in dag.graph.items():
logger.debug(f"model found: {model_fqn}")
model = context.get_model(model_fqn)
if not model:
# If no model is returned this seems to be an asset dependency
continue
for model, deps in instance.non_external_models_dag():
asset_key = translator.get_asset_key_from_model(
context,
model,
Expand Down
25 changes: 17 additions & 8 deletions dagster_sqlmesh/resource.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
import logging
import typing as t

from dagster import (
AssetExecutionContext,
ConfigurableResource,
MaterializeResult,
)
from dagster import AssetExecutionContext, ConfigurableResource, MaterializeResult
from sqlmesh import Model
from sqlmesh.core.context import Context as SQLMeshContext
from sqlmesh.core.snapshot import Snapshot
Expand All @@ -14,7 +10,8 @@

from . import console
from .config import SQLMeshContextConfig
from .controller import PlanOptions, RunOptions, SQLMeshController
from .controller import PlanOptions, RunOptions
from .controller.dagster import DagsterSQLMeshController
from .utils import sqlmesh_model_name_to_key


Expand Down Expand Up @@ -250,13 +247,17 @@ def run(
logger = context.log

controller = self.get_controller(logger)

with controller.instance(environment) as mesh:
dag = mesh.models_dag()

select_models = []

models = mesh.models()
models_map = models.copy()
all_available_models = set(
[model.name for model, _ in mesh.non_external_models_dag()]
)
if context.selected_output_names:
models_map = {}
for key, model in models.items():
Expand All @@ -268,6 +269,14 @@ def run(

models_map[key] = model
select_models.append(model.name)
selected_models_set = set(models_map.keys())

if all_available_models == selected_models_set:
logger.info("all models selected")

# Set this to None so that sqlmesh selects all models itself and can
# also remove any models
select_models = None

event_handler = DagsterSQLMeshEventHandler(
context, models_map, dag, "sqlmesh: "
Expand All @@ -285,7 +294,7 @@ def run(

def get_controller(
self, log_override: logging.Logger | None = None
) -> SQLMeshController:
return SQLMeshController.setup_with_config(
) -> DagsterSQLMeshController:
return DagsterSQLMeshController.setup_with_config(
self.config, log_override=log_override
)
4 changes: 1 addition & 3 deletions dagster_sqlmesh/test_asset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from dagster_sqlmesh.asset import (
SQLMeshDagsterTranslator,
)
from dagster_sqlmesh.asset import SQLMeshDagsterTranslator
from dagster_sqlmesh.conftest import SQLMeshTestContext


Expand Down
20 changes: 10 additions & 10 deletions dagster_sqlmesh/test_sqlmesh_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import polars

from .conftest import SQLMeshTestContext
from .testing import SQLMeshTestContext

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -184,12 +184,12 @@ def test_restating_models(sample_sqlmesh_test_context: SQLMeshTestContext):
"""
)

assert feb_sum_query_restate[0][0] == feb_sum_query[0][0], (
"February sum should not change"
)
assert march_sum_query_restate[0][0] != march_sum_query[0][0], (
"March sum should change"
)
assert intermediate_2_query_restate[0][0] == intermediate_2_query[0][0], (
"Intermediate model should not change during restate"
)
assert (
feb_sum_query_restate[0][0] == feb_sum_query[0][0]
), "February sum should not change"
assert (
march_sum_query_restate[0][0] != march_sum_query[0][0]
), "March sum should change"
assert (
intermediate_2_query_restate[0][0] == intermediate_2_query[0][0]
), "Intermediate model should not change during restate"
2 changes: 2 additions & 0 deletions dagster_sqlmesh/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# ruff: noqa: F403 F401
from .context import *
Loading