Skip to content

Allow users to skip "sqlmesh run" and make completion status depend on Table Promotion #33

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged 2 commits on Apr 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dagster_sqlmesh/controller/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,8 @@ def plan_and_run(
self.logger.debug("starting sqlmesh plan")
self.logger.debug(f"selected models: {select_models}")
yield from self.plan(categorizer, default_catalog, **plan_options)
self.logger.debug("starting sqlmesh run")
if not skip_run:
self.logger.debug("starting sqlmesh run")
yield from self.run(**run_options)
except Exception as e:
self.logger.error(f"Error during sqlmesh plan and run: {e}")
Expand Down
71 changes: 52 additions & 19 deletions dagster_sqlmesh/resource.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import logging
import typing as t

from dagster import AssetExecutionContext, ConfigurableResource, MaterializeResult
from dagster import (
AssetExecutionContext,
ConfigurableResource,
MaterializeResult,
)
from sqlmesh import Model
from sqlmesh.core.context import Context as SQLMeshContext
from sqlmesh.core.snapshot import Snapshot
from sqlmesh.core.snapshot import Snapshot, SnapshotInfoLike, SnapshotTableInfo
from sqlmesh.utils.dag import DAG
from sqlmesh.utils.date import TimeLike

Expand All @@ -27,28 +31,36 @@ def __init__(self, sorted_dag: list[str], logger: logging.Logger) -> None:
self._complete_update_status: dict[str, bool] = {}
self._sorted_dag = sorted_dag
self._current_index = 0
self.finished_promotion = False

def plan(self, batches: dict[Snapshot, int]) -> None:
self._batches = batches
self._count: dict[Snapshot, int] = {}

incomplete_names = set()
for snapshot, count in self._batches.items():
incomplete_names.add(snapshot.name)
self._count[snapshot] = 0
def init_complete_update_status(self, snapshots: list[SnapshotTableInfo]) -> None:
planned_model_names = set()
for snapshot in snapshots:
planned_model_names.add(snapshot.name)

# Anything not in the plan should be listed as completed and queued for
# notification
self._complete_update_status = {
name: False for name in (set(self._sorted_dag) - incomplete_names)
name: False for name in (set(self._sorted_dag) - planned_model_names)
}

def update(self, snapshot: Snapshot, _batch_idx: int) -> tuple[int, int]:
def update_promotion(self, snapshot: SnapshotInfoLike, promoted: bool) -> None:
self._complete_update_status[snapshot.name] = promoted

def stop_promotion(self) -> None:
self.finished_promotion = True

def plan(self, batches: dict[Snapshot, int]) -> None:
self._batches = batches
self._count: dict[Snapshot, int] = {}

for snapshot, _ in self._batches.items():
self._count[snapshot] = 0

def update_plan(self, snapshot: Snapshot, _batch_idx: int) -> tuple[int, int]:
self._count[snapshot] += 1
current_count = self._count[snapshot]
expected_count = self._batches[snapshot]
if self._batches[snapshot] == self._count[snapshot]:
self._complete_update_status[snapshot.name] = True
return (current_count, expected_count)

def notify_queue_next(self) -> tuple[str, bool] | None:
Expand Down Expand Up @@ -110,11 +122,12 @@ def __init__(
self._tracker = MaterializationTracker(dag.sorted[:], self._logger)
self._stage = "plan"

def process_events(
self, sqlmesh_context: SQLMeshContext, event: console.ConsoleEvent
) -> t.Iterator[MaterializeResult]:
def process_events(self, event: console.ConsoleEvent) -> None:
self.report_event(event)

def notify_success(
self, sqlmesh_context: SQLMeshContext
) -> t.Iterator[MaterializeResult]:
notify = self._tracker.notify_queue_next()
while notify is not None:
completed_name, update_status = notify
Expand Down Expand Up @@ -146,6 +159,7 @@ def report_event(self, event: console.ConsoleEvent) -> None:

match event:
case console.StartPlanEvaluation(plan):
self._tracker.init_complete_update_status(plan.environment.snapshots)
log_context.info(
"Starting Plan Evaluation",
{
Expand Down Expand Up @@ -173,7 +187,7 @@ def report_event(self, event: console.ConsoleEvent) -> None:
case console.UpdateSnapshotEvaluationProgress(
snapshot, batch_idx, duration_ms
):
done, expected = self._tracker.update(snapshot, batch_idx)
done, expected = self._tracker.update_plan(snapshot, batch_idx)

log_context.info(
"Snapshot progress update",
Expand All @@ -200,6 +214,21 @@ def report_event(self, event: console.ConsoleEvent) -> None:
[f"{model!s}\n{model.__cause__!s}" for model in models]
)
log_context.error(f"sqlmesh failed models: {failed_models}")
case console.UpdatePromotionProgress(snapshot, promoted):
log_context.info(
"Promotion progress update",
{
"snapshot": snapshot.name,
"promoted": promoted,
},
)
self._tracker.update_promotion(snapshot, promoted)
case console.StopPromotionProgress(success):
self._tracker.stop_promotion()
if success:
log_context.info("Promotion completed successfully")
else:
log_context.error("Promotion failed")
case _:
log_context.debug("Received event")

Expand Down Expand Up @@ -237,6 +266,7 @@ def run(
start: TimeLike | None = None,
end: TimeLike | None = None,
restate_selected: bool = False,
skip_run: bool = False,
plan_options: PlanOptions | None = None,
run_options: RunOptions | None = None,
) -> t.Iterable[MaterializeResult]:
Expand Down Expand Up @@ -287,10 +317,13 @@ def run(
end=end,
select_models=select_models,
restate_selected=restate_selected,
skip_run=skip_run,
plan_options=plan_options,
run_options=run_options,
):
yield from event_handler.process_events(mesh.context, event)
event_handler.process_events(event)

yield from event_handler.notify_success(mesh.context)

def get_controller(
self, log_override: logging.Logger | None = None
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "dagster-sqlmesh"
version = "0.10.0"
version = "0.11.0"
description = ""
authors = [
{name = "Reuven Gonzales", email = "reuven@karibalabs.co"}
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.