From 602a2ff74fb510fd7b68bb076e6427b507e94b5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=8Dcaro=20Guerra?= Date: Thu, 10 Apr 2025 12:13:49 -0300 Subject: [PATCH 1/2] feat(resource): allow users to skip the run step --- dagster_sqlmesh/resource.py | 2 ++ pyproject.toml | 2 +- uv.lock | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/dagster_sqlmesh/resource.py b/dagster_sqlmesh/resource.py index 02665a9..10e0c83 100644 --- a/dagster_sqlmesh/resource.py +++ b/dagster_sqlmesh/resource.py @@ -237,6 +237,7 @@ def run( start: TimeLike | None = None, end: TimeLike | None = None, restate_selected: bool = False, + skip_run: bool = False, plan_options: PlanOptions | None = None, run_options: RunOptions | None = None, ) -> t.Iterable[MaterializeResult]: @@ -287,6 +288,7 @@ def run( end=end, select_models=select_models, restate_selected=restate_selected, + skip_run=skip_run, plan_options=plan_options, run_options=run_options, ): diff --git a/pyproject.toml b/pyproject.toml index b78da07..2c7dfa1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dagster-sqlmesh" -version = "0.10.0" +version = "0.11.0" description = "" authors = [ {name = "Reuven Gonzales", email = "reuven@karibalabs.co"} diff --git a/uv.lock b/uv.lock index 9b1daee..077d885 100644 --- a/uv.lock +++ b/uv.lock @@ -287,7 +287,7 @@ wheels = [ [[package]] name = "dagster-sqlmesh" -version = "0.10.0" +version = "0.11.0" source = { editable = "." } dependencies = [ { name = "dagster" }, From 82c7ba9f0a795e5b2850abf7d50767955fbbb277 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=8Dcaro=20Guerra?= Date: Thu, 10 Apr 2025 16:46:22 -0300 Subject: [PATCH 2/2] feat(resource): make success depend on promoting models instead of planning --- dagster_sqlmesh/controller/base.py | 2 +- dagster_sqlmesh/resource.py | 69 ++++++++++++++++++++++-------- 2 files changed, 51 insertions(+), 20 deletions(-) diff --git a/dagster_sqlmesh/controller/base.py b/dagster_sqlmesh/controller/base.py index 6f4cbf1..5d66e1d 100644 --- a/dagster_sqlmesh/controller/base.py +++ b/dagster_sqlmesh/controller/base.py @@ -332,8 +332,8 @@ def plan_and_run( self.logger.debug("starting sqlmesh plan") self.logger.debug(f"selected models: {select_models}") yield from self.plan(categorizer, default_catalog, **plan_options) - self.logger.debug("starting sqlmesh run") if not skip_run: + self.logger.debug("starting sqlmesh run") yield from self.run(**run_options) except Exception as e: self.logger.error(f"Error during sqlmesh plan and run: {e}") diff --git a/dagster_sqlmesh/resource.py b/dagster_sqlmesh/resource.py index 10e0c83..54ce1c8 100644 --- a/dagster_sqlmesh/resource.py +++ b/dagster_sqlmesh/resource.py @@ -1,10 +1,14 @@ import logging import typing as t -from dagster import AssetExecutionContext, ConfigurableResource, MaterializeResult +from dagster import ( + AssetExecutionContext, + ConfigurableResource, + MaterializeResult, +) from sqlmesh import Model from sqlmesh.core.context import Context as SQLMeshContext -from sqlmesh.core.snapshot import Snapshot +from sqlmesh.core.snapshot import Snapshot, SnapshotInfoLike, SnapshotTableInfo from sqlmesh.utils.dag import DAG from sqlmesh.utils.date import TimeLike @@ -27,28 +31,36 @@ def __init__(self, sorted_dag: list[str], logger: logging.Logger) -> None: self._complete_update_status: dict[str, bool] = {} self._sorted_dag = sorted_dag self._current_index = 0 + self.finished_promotion = False - def plan(self, batches: dict[Snapshot, int]) -> None: - self._batches = batches - self._count: dict[Snapshot, int] = {} - - incomplete_names = set() - for snapshot, count in self._batches.items(): - incomplete_names.add(snapshot.name) - self._count[snapshot] = 0 + def init_complete_update_status(self, snapshots: list[SnapshotTableInfo]) -> None: + planned_model_names = set() + for snapshot in snapshots: + planned_model_names.add(snapshot.name) # Anything not in the plan should be listed as completed and queued for # notification self._complete_update_status = { - name: False for name in (set(self._sorted_dag) - incomplete_names) + name: False for name in (set(self._sorted_dag) - planned_model_names) } - def update(self, snapshot: Snapshot, _batch_idx: int) -> tuple[int, int]: + def update_promotion(self, snapshot: SnapshotInfoLike, promoted: bool) -> None: + self._complete_update_status[snapshot.name] = promoted + + def stop_promotion(self) -> None: + self.finished_promotion = True + + def plan(self, batches: dict[Snapshot, int]) -> None: + self._batches = batches + self._count: dict[Snapshot, int] = {} + + for snapshot, _ in self._batches.items(): + self._count[snapshot] = 0 + + def update_plan(self, snapshot: Snapshot, _batch_idx: int) -> tuple[int, int]: self._count[snapshot] += 1 current_count = self._count[snapshot] expected_count = self._batches[snapshot] - if self._batches[snapshot] == self._count[snapshot]: - self._complete_update_status[snapshot.name] = True return (current_count, expected_count) def notify_queue_next(self) -> tuple[str, bool] | None: @@ -110,11 +122,12 @@ def __init__( self._tracker = MaterializationTracker(dag.sorted[:], self._logger) self._stage = "plan" - def process_events( - self, sqlmesh_context: SQLMeshContext, event: console.ConsoleEvent - ) -> t.Iterator[MaterializeResult]: + def process_events(self, event: console.ConsoleEvent) -> None: self.report_event(event) + def notify_success( + self, sqlmesh_context: SQLMeshContext + ) -> t.Iterator[MaterializeResult]: notify = self._tracker.notify_queue_next() while notify is not None: completed_name, update_status = notify @@ -146,6 +159,7 @@ def report_event(self, event: console.ConsoleEvent) -> None: match event: case console.StartPlanEvaluation(plan): + self._tracker.init_complete_update_status(plan.environment.snapshots) log_context.info( "Starting Plan Evaluation", { @@ -173,7 +187,7 @@ def report_event(self, event: console.ConsoleEvent) -> None: case console.UpdateSnapshotEvaluationProgress( snapshot, batch_idx, duration_ms ): - done, expected = self._tracker.update(snapshot, batch_idx) + done, expected = self._tracker.update_plan(snapshot, batch_idx) log_context.info( "Snapshot progress update", @@ -200,6 +214,21 @@ def report_event(self, event: console.ConsoleEvent) -> None: [f"{model!s}\n{model.__cause__!s}" for model in models] ) log_context.error(f"sqlmesh failed models: {failed_models}") + case console.UpdatePromotionProgress(snapshot, promoted): + log_context.info( + "Promotion progress update", + { + "snapshot": snapshot.name, + "promoted": promoted, + }, + ) + self._tracker.update_promotion(snapshot, promoted) + case console.StopPromotionProgress(success): + self._tracker.stop_promotion() + if success: + log_context.info("Promotion completed successfully") + else: + log_context.error("Promotion failed") case _: log_context.debug("Received event") @@ -292,7 +321,9 @@ def run( plan_options=plan_options, run_options=run_options, ): - yield from event_handler.process_events(mesh.context, event) + event_handler.process_events(event) + + yield from event_handler.notify_success(mesh.context) def get_controller( self, log_override: logging.Logger | None = None