diff --git a/.gitignore b/.gitignore
index ce64de2..11cabc8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -61,4 +61,9 @@ dbt_packages/
 
 # Python
 *.pyc
-*.db
\ No newline at end of file
+*.db
+
+sample/dagster_project/storage/
+sample/dagster_project/logs/
+sample/dagster_project/history/
+sample/dagster_project/schedules/
diff --git a/Makefile b/Makefile
index e264e0e..2987092 100644
--- a/Makefile
+++ b/Makefile
@@ -45,13 +45,26 @@ pyright:
 	pnpm pyright
 
 # Sample project commands
+clean-dagster:
+	rm -rf sample/dagster_project/storage sample/dagster_project/logs sample/dagster_project/history
+
+clean-db:
+	$(PYTHON_CMD) -c "import duckdb; conn = duckdb.connect('db.db'); [conn.execute(cmd[0]) for cmd in conn.execute(\"\"\"SELECT 'DROP TABLE ' || table_schema || '.' || table_name || ' CASCADE;' as drop_cmd FROM information_schema.tables WHERE table_schema != 'sources' AND table_schema != 'information_schema' AND table_type = 'BASE TABLE'\"\"\").fetchall()]; [conn.execute(cmd[0]) for cmd in conn.execute(\"\"\"SELECT 'DROP VIEW ' || table_schema || '.' || table_name || ' CASCADE;' as drop_cmd FROM information_schema.tables WHERE table_schema != 'sources' AND table_schema != 'information_schema' AND table_type = 'VIEW'\"\"\").fetchall()]; conn.close()"
 
 dagster-dev: clean-dagster
-	DAGSTER_HOME=$(CURDIR)/sample/dagster_project $(PYTHON_CMD) -m dagster dev -h 0.0.0.0 -w
+	@DAGSTER_HOME="$(subst \,/,$(CURDIR))/sample/dagster_project" "$(PYTHON_CMD)" -m dagster dev -f sample/dagster_project/definitions.py -h 0.0.0.0
 
 dev: dagster-dev # Alias for dagster-dev
 
 dagster-materialize:
 	$(PYTHON_CMD) -m dagster asset materialize -f sample/dagster_project/definitions.py --select '*'
 
-.PHONY: init init-python install-python-deps upgrade-python-deps clean test mypy check-pnpm install-node-deps upgrade-node-deps sample-dev dagster-dev dagster-materialize clean-dagster
\ No newline at end of file
+sqlmesh-plan:
+	cd sample/sqlmesh_project && $(SQLMESH_CMD) plan
+
+sqlmesh-run:
+	cd sample/sqlmesh_project && $(SQLMESH_CMD) run
+
+clean-dev: clean-db clean-dagster dev
+
+.PHONY: init init-python install-python-deps upgrade-python-deps clean test mypy check-pnpm install-node-deps upgrade-node-deps sample-dev dagster-dev dagster-materialize clean-dagster clean-db clean-dev
\ No newline at end of file
diff --git a/dagster_sqlmesh/asset.py b/dagster_sqlmesh/asset.py
index 586fe27..aada67e 100644
--- a/dagster_sqlmesh/asset.py
+++ b/dagster_sqlmesh/asset.py
@@ -1,9 +1,10 @@
-import typing as t
 import logging
+import typing as t
 
 from dagster import (
-    multi_asset,
+    AssetsDefinition,
     RetryPolicy,
+    multi_asset,
 )
 
 from dagster_sqlmesh.controller import DagsterSQLMeshController
@@ -19,15 +20,15 @@ def sqlmesh_assets(
     *,
     environment: str,
     config: SQLMeshContextConfig,
-    name: t.Optional[str] = None,
-    dagster_sqlmesh_translator: t.Optional[SQLMeshDagsterTranslator] = None,
+    name: str | None = None,
+    dagster_sqlmesh_translator: SQLMeshDagsterTranslator | None = None,
     compute_kind: str = "sqlmesh",
-    op_tags: t.Optional[t.Mapping[str, t.Any]] = None,
-    required_resource_keys: t.Optional[t.Set[str]] = None,
-    retry_policy: t.Optional[RetryPolicy] = None,
+    op_tags: t.Mapping[str, t.Any] | None = None,
+    required_resource_keys: set[str] | None = None,
+    retry_policy: RetryPolicy | None = None,
     # For now we don't set this by default
     enabled_subsetting: bool = False,
-):
+) -> t.Callable[[t.Callable[..., t.Any]], AssetsDefinition]:
     controller = DagsterSQLMeshController.setup_with_config(config)
     if not dagster_sqlmesh_translator:
         dagster_sqlmesh_translator = SQLMeshDagsterTranslator()
diff --git a/dagster_sqlmesh/config.py b/dagster_sqlmesh/config.py
index 76b479e..d088ee9 100644
--- a/dagster_sqlmesh/config.py
+++ b/dagster_sqlmesh/config.py
@@ -1,16 +1,16 @@
-from typing import Optional, Dict, Any
 from dataclasses import dataclass
+from typing import Any
 
 from dagster import Config
-from sqlmesh.core.config import Config as MeshConfig
 from pydantic import Field
+from sqlmesh.core.config import Config as MeshConfig
 
 
 @dataclass
 class ConfigOverride:
-    config_as_dict: Dict
+    config_as_dict: dict[str, Any]
 
-    def dict(self):
+    def dict(self) -> dict[str, Any]:
         return self.config_as_dict
 
 
@@ -24,10 +24,10 @@ class SQLMeshContextConfig(Config):
 
     path: str
     gateway: str
-    config_override: Optional[Dict[str, Any]] = Field(default_factory=lambda: None)
+    config_override: dict[str, Any] | None = Field(default_factory=lambda: None)
 
     @property
-    def sqlmesh_config(self):
+    def sqlmesh_config(self) -> MeshConfig | None:
         if self.config_override:
             return MeshConfig.parse_obj(self.config_override)
         return None
diff --git a/dagster_sqlmesh/conftest.py b/dagster_sqlmesh/conftest.py
index 621ba10..3a371a2 100644
--- a/dagster_sqlmesh/conftest.py
+++ b/dagster_sqlmesh/conftest.py
@@ -1,33 +1,34 @@
 import logging
+import os
+import shutil
 import sys
 import tempfile
-import shutil
-import os
-from dataclasses import dataclass
 import typing as t
+from dataclasses import dataclass
 
-import pytest
 import duckdb
 import polars
-from sqlmesh.utils.date import TimeLike
-from sqlmesh.core.console import get_console
+import pytest
 from sqlmesh.core.config import (
     Config as SQLMeshConfig,
-    GatewayConfig,
     DuckDBConnectionConfig,
+    GatewayConfig,
     ModelDefaultsConfig,
 )
+from sqlmesh.core.console import get_console
+from sqlmesh.utils.date import TimeLike
 
 from dagster_sqlmesh.config import SQLMeshContextConfig
-from dagster_sqlmesh.events import ConsoleRecorder
+from dagster_sqlmesh.console import ConsoleEvent
 from dagster_sqlmesh.controller.base import PlanOptions, RunOptions
 from dagster_sqlmesh.controller.dagster import DagsterSQLMeshController
+from dagster_sqlmesh.events import ConsoleRecorder
 
 logger = logging.getLogger(__name__)
 
 
 @pytest.fixture(scope="session", autouse=True)
-def setup_debug_logging_for_tests():
+def setup_debug_logging_for_tests() -> None:
     root_logger = logging.getLogger(__name__.split(".")[0])
     root_logger.setLevel(logging.DEBUG)
 
@@ -35,7 +36,7 @@ def setup_debug_logging_for_tests():
 
 
 @pytest.fixture
-def sample_sqlmesh_project():
+def sample_sqlmesh_project() -> t.Iterator[str]:
     """Creates a temporary sqlmesh project by copying the sample project"""
     with tempfile.TemporaryDirectory() as tmp_dir:
         project_dir = shutil.copytree(
@@ -56,7 +57,9 @@ class SQLMeshTestContext:
     db_path: str
     context_config: SQLMeshContextConfig
 
-    def create_controller(self, enable_debug_console: bool = False):
+    def create_controller(
+        self, enable_debug_console: bool = False
+    ) -> DagsterSQLMeshController:
         console = None
         if enable_debug_console:
             console = get_console()
@@ -64,11 +67,11 @@ def create_controller(self, enable_debug_console: bool = False):
             self.context_config, debug_console=console
         )
 
-    def query(self, *args, **kwargs):
+    def query(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
         conn = duckdb.connect(self.db_path)
         return conn.sql(*args, **kwargs).fetchall()
 
-    def initialize_test_source(self):
+    def initialize_test_source(self) -> None:
         conn = duckdb.connect(self.db_path)
         conn.sql(
             """
@@ -102,14 +105,14 @@ def plan_and_run(
         self,
         *,
         environment: str,
-        execution_time: t.Optional[TimeLike] = None,
+        execution_time: TimeLike | None = None,
         enable_debug_console: bool = False,
-        start: t.Optional[TimeLike] = None,
-        end: t.Optional[TimeLike] = None,
-        select_models: t.Optional[t.List[str]] = None,
+        start: TimeLike | None = None,
+        end: TimeLike | None = None,
+        select_models: list[str] | None = None,
         restate_selected: bool = False,
         skip_run: bool = False,
-    ):
+    ) -> t.Iterator[ConsoleEvent] | None:
         """Runs plan and run on SQLMesh with the given configuration and record all of the generated events.
 
         Args:
@@ -152,7 +155,9 @@ def plan_and_run(
 
 
 @pytest.fixture
-def sample_sqlmesh_test_context(sample_sqlmesh_project: str):
+def sample_sqlmesh_test_context(
+    sample_sqlmesh_project: str,
+) -> t.Iterator[SQLMeshTestContext]:
     db_path = os.path.join(sample_sqlmesh_project, "db.db")
     config = SQLMeshConfig(
         gateways={
diff --git a/dagster_sqlmesh/console.py b/dagster_sqlmesh/console.py
index 6749c7b..d8cdce5 100644
--- a/dagster_sqlmesh/console.py
+++ b/dagster_sqlmesh/console.py
@@ -2,15 +2,15 @@
 import typing as t
 import unittest
 import uuid
-from dataclasses import dataclass
-from typing import Callable, Dict, Union
+from collections.abc import Callable
+from dataclasses import dataclass, field
 
 from sqlglot.expressions import Alter
 from sqlmesh.core.console import Console
-from sqlmesh.core.model import Model
 from sqlmesh.core.context_diff import ContextDiff
 from sqlmesh.core.environment import EnvironmentNamingInfo
 from sqlmesh.core.linter.rule import RuleViolation
+from sqlmesh.core.model import Model
 from sqlmesh.core.plan import EvaluatablePlan, PlanBuilder
 from sqlmesh.core.snapshot import (
     Snapshot,
@@ -50,9 +50,9 @@ class StopPlanEvaluation:
 
 @dataclass
 class StartEvaluationProgress:
-    batches: Dict[Snapshot, int]
+    batches: dict[Snapshot, int]
     environment_naming_info: EnvironmentNamingInfo
-    default_catalog: t.Optional[str]
+    default_catalog: str | None
 
 
 @dataclass
@@ -64,7 +64,7 @@ class StartSnapshotEvaluationProgress:
 class UpdateSnapshotEvaluationProgress:
     snapshot: Snapshot
     batch_idx: int
-    duration_ms: t.Optional[int]
+    duration_ms: int | None
 
 
 @dataclass
@@ -76,7 +76,7 @@ class StopEvaluationProgress:
 class StartCreationProgress:
     total_tasks: int
     environment_naming_info: EnvironmentNamingInfo
-    default_catalog: t.Optional[str]
+    default_catalog: str | None
 
 
 @dataclass
@@ -108,7 +108,7 @@ class StopCleanup:
 class StartPromotionProgress:
     total_tasks: int
     environment_naming_info: EnvironmentNamingInfo
-    default_catalog: t.Optional[str]
+    default_catalog: str | None
 
 
 @dataclass
@@ -161,7 +161,7 @@ class StopEnvMigrationProgress:
 class ShowModelDifferenceSummary:
     context_diff: ContextDiff
     environment_naming_info: EnvironmentNamingInfo
-    default_catalog: t.Optional[str]
+    default_catalog: str | None
     no_diff: bool = True
 
 
@@ -169,7 +169,7 @@ class ShowModelDifferenceSummary:
 class PlanEvent:
     plan_builder: PlanBuilder
     auto_apply: bool
-    default_catalog: t.Optional[str]
+    default_catalog: str | None
     no_diff: bool = False
     no_prompts: bool = False
 
@@ -177,7 +177,7 @@ class PlanEvent:
 @dataclass
 class LogTestResults:
     result: unittest.result.TestResult
-    output: t.Optional[str]
+    output: str | None
     target_dialect: str
 
 
@@ -199,7 +199,7 @@ class LogError:
 @dataclass
 class LogWarning:
     short_message: str
-    long_message: t.Optional[str] = None
+    long_message: str | None = None
 
 
 @dataclass
@@ -209,27 +209,27 @@ class LogSuccess:
 
 @dataclass
 class LogFailedModels:
-    errors: t.List[NodeExecutionFailedError]
+    errors: list[NodeExecutionFailedError[str]]
 
 
 @dataclass
 class LogSkippedModels:
-    snapshot_names: t.Set[str]
+    snapshot_names: set[str]
 
 
 @dataclass
 class LogDestructiveChange:
     snapshot_name: str
-    dropped_column_names: t.List[str]
-    alter_expressions: t.List[Alter]
+    dropped_column_names: list[str]
+    alter_expressions: list[Alter]
     dialect: str
     error: bool = True
 
 
 @dataclass
 class LoadingStart:
-    message: t.Optional[str] = None
-    id: uuid.UUID = uuid.uuid4()
+    message: str | None = None
+    id: uuid.UUID = field(default_factory=uuid.uuid4)
 
 
 @dataclass
@@ -256,7 +256,7 @@ class ConsoleException:
 
 @dataclass
 class PrintEnvironments:
-    environments_summary: t.Dict[str, int]
+    environments_summary: dict[str, int]
 
 
 @dataclass
@@ -264,51 +264,51 @@ class ShowTableDiffSummary:
     table_diff: TableDiff
 
 
-ConsoleEvent = Union[
-    StartPlanEvaluation,
-    StopPlanEvaluation,
-    StartEvaluationProgress,
-    StartSnapshotEvaluationProgress,
-    UpdateSnapshotEvaluationProgress,
-    StopEvaluationProgress,
-    StartCreationProgress,
-    UpdateCreationProgress,
-    StopCreationProgress,
-    StartCleanup,
-    UpdateCleanupProgress,
-    StopCleanup,
-    StartPromotionProgress,
-    UpdatePromotionProgress,
-    StopPromotionProgress,
-    UpdateSnapshotMigrationProgress,
-    LogMigrationStatus,
-    StopSnapshotMigrationProgress,
-    StartEnvMigrationProgress,
-    UpdateEnvMigrationProgress,
-    StopEnvMigrationProgress,
-    ShowModelDifferenceSummary,
-    PlanEvent,
-    LogTestResults,
-    ShowSQL,
-    LogStatusUpdate,
-    LogError,
-    LogWarning,
-    LogSuccess,
-    LogFailedModels,
-    LogSkippedModels,
-    LogDestructiveChange,
-    LoadingStart,
-    LoadingStop,
-    ShowSchemaDiff,
-    ShowRowDiff,
-    StartMigrationProgress,
-    UpdateMigrationProgress,
-    StopMigrationProgress,
-    StartSnapshotMigrationProgress,
-    ConsoleException,
-    PrintEnvironments,
-    ShowTableDiffSummary,
-]
+ConsoleEvent = (
+    StartPlanEvaluation
+    | StopPlanEvaluation
+    | StartEvaluationProgress
+    | StartSnapshotEvaluationProgress
+    | UpdateSnapshotEvaluationProgress
+    | StopEvaluationProgress
+    | StartCreationProgress
+    | UpdateCreationProgress
+    | StopCreationProgress
+    | StartCleanup
+    | UpdateCleanupProgress
+    | StopCleanup
+    | StartPromotionProgress
+    | UpdatePromotionProgress
+    | StopPromotionProgress
+    | UpdateSnapshotMigrationProgress
+    | LogMigrationStatus
+    | StopSnapshotMigrationProgress
+    | StartEnvMigrationProgress
+    | UpdateEnvMigrationProgress
+    | StopEnvMigrationProgress
+    | ShowModelDifferenceSummary
+    | PlanEvent
+    | LogTestResults
+    | ShowSQL
+    | LogStatusUpdate
+    | LogError
+    | LogWarning
+    | LogSuccess
+    | LogFailedModels
+    | LogSkippedModels
+    | LogDestructiveChange
+    | LoadingStart
+    | LoadingStop
+    | ShowSchemaDiff
+    | ShowRowDiff
+    | StartMigrationProgress
+    | UpdateMigrationProgress
+    | StopMigrationProgress
+    | StartSnapshotMigrationProgress
+    | ConsoleException
+    | PrintEnvironments
+    | ShowTableDiffSummary
+)
 
 ConsoleEventHandler = Callable[[ConsoleEvent], None]
 
@@ -330,16 +330,16 @@ class EventConsole(Console):
     promotion, migration, and testing.
     """
 
-    categorizer: t.Optional[SnapshotCategorizer] = None
+    categorizer: SnapshotCategorizer | None = None
 
-    def __init__(self, log_override: t.Optional[logging.Logger] = None):
-        self._handlers: Dict[str, ConsoleEventHandler] = {}
+    def __init__(self, log_override: logging.Logger | None = None) -> None:
+        self._handlers: dict[str, ConsoleEventHandler] = {}
         self.logger = log_override or logger
         self.id = str(uuid.uuid4())
         self.logger.debug(f"EventConsole[{self.id}]: created")
        self.categorizer = None
 
-    def add_snapshot_categorizer(self, categorizer: SnapshotCategorizer):
+    def add_snapshot_categorizer(self, categorizer: SnapshotCategorizer) -> None:
         self.categorizer = categorizer
 
     def start_plan_evaluation(self, plan: EvaluatablePlan) -> None:
@@ -350,9 +350,9 @@ def stop_plan_evaluation(self) -> None:
 
     def start_evaluation_progress(
         self,
-        batches: Dict[Snapshot, int],
+        batches: dict[Snapshot, int],
         environment_naming_info: EnvironmentNamingInfo,
-        default_catalog: t.Optional[str],
+        default_catalog: str | None,
     ) -> None:
         self.publish(
             StartEvaluationProgress(batches, environment_naming_info, default_catalog)
@@ -362,7 +362,7 @@ def start_snapshot_evaluation_progress(self, snapshot: Snapshot) -> None:
         self.publish(StartSnapshotEvaluationProgress(snapshot))
 
     def update_snapshot_evaluation_progress(
-        self, snapshot: Snapshot, batch_idx: int, duration_ms: t.Optional[int]
+        self, snapshot: Snapshot, batch_idx: int, duration_ms: int | None
     ) -> None:
         self.publish(UpdateSnapshotEvaluationProgress(snapshot, batch_idx, duration_ms))
 
@@ -373,7 +373,7 @@ def start_creation_progress(
         self,
         total_tasks: int,
         environment_naming_info: EnvironmentNamingInfo,
-        default_catalog: t.Optional[str],
+        default_catalog: str | None,
     ) -> None:
         self.publish(
             StartCreationProgress(total_tasks, environment_naming_info, default_catalog)
@@ -400,7 +400,7 @@ def start_promotion_progress(
         self,
         total_tasks: int,
         environment_naming_info: EnvironmentNamingInfo,
-        default_catalog: t.Optional[str],
+        default_catalog: str | None,
     ) -> None:
         self.publish(
             StartPromotionProgress(
@@ -441,7 +441,7 @@ def show_model_difference_summary(
         self,
         context_diff: ContextDiff,
         environment_naming_info: EnvironmentNamingInfo,
-        default_catalog: t.Optional[str],
+        default_catalog: str | None,
         no_diff: bool = True,
     ) -> None:
         self.publish(
@@ -457,7 +457,7 @@ def plan(
         self,
         plan_builder: PlanBuilder,
         auto_apply: bool,
-        default_catalog: t.Optional[str],
+        default_catalog: str | None,
         no_diff: bool = False,
         no_prompts: bool = False,
     ) -> None:
@@ -477,7 +477,7 @@ def plan(
     def log_test_results(
         self,
         result: unittest.result.TestResult,
-        output: t.Optional[str],
+        output: str | None,
         target_dialect: str,
     ) -> None:
         self.publish(LogTestResults(result, output, target_dialect))
@@ -491,35 +491,33 @@ def log_status_update(self, message: str) -> None:
     def log_error(self, message: str) -> None:
         self.publish(LogError(message))
 
-    def log_warning(
-        self, short_message: str, long_message: t.Optional[str] = None
-    ) -> None:
+    def log_warning(self, short_message: str, long_message: str | None = None) -> None:
         self.publish(LogWarning(short_message, long_message))
 
     def log_success(self, message: str) -> None:
         self.publish(LogSuccess(message))
 
-    def log_failed_models(self, errors):
+    def log_failed_models(self, errors: list[NodeExecutionFailedError[str]]) -> None:
         self.publish(LogFailedModels(errors))
 
-    def log_skipped_models(self, snapshot_names):
+    def log_skipped_models(self, snapshot_names: set[str]) -> None:
         self.publish(LogSkippedModels(snapshot_names))
 
     def log_destructive_change(
         self,
-        snapshot_name,
-        dropped_column_names,
-        alter_expressions,
-        dialect,
-        error=True,
-    ):
+        snapshot_name: str,
+        dropped_column_names: list[str],
+        alter_expressions: list[Alter],
+        dialect: str,
+        error: bool = True,
+    ) -> None:
         self.publish(
             LogDestructiveChange(
                 snapshot_name, dropped_column_names, alter_expressions, dialect, error
             )
         )
 
-    def loading_start(self, message: t.Optional[str] = None) -> uuid.UUID:
+    def loading_start(self, message: str | None = None) -> uuid.UUID:
         event_id = uuid.uuid4()
         self.publish(LoadingStart(message, event_id))
         return event_id
@@ -545,27 +543,37 @@ def publish(self, event: ConsoleEvent) -> None:
         for handler in self._handlers.values():
             handler(event)
 
-    def add_handler(self, handler: ConsoleEventHandler):
+    def add_handler(self, handler: ConsoleEventHandler) -> str:
         handler_id = str(uuid.uuid4())
         self.logger.debug(f"EventConsole[{self.id}]: Adding handler {handler_id}")
         self._handlers[handler_id] = handler
         return handler_id
 
-    def remove_handler(self, handler_id: str):
+    def remove_handler(self, handler_id: str) -> None:
         del self._handlers[handler_id]
 
-    def exception(self, exc: Exception):
+    def exception(self, exc: Exception) -> None:
         self.publish(ConsoleException(exc))
 
-    def print_environments(self, environments_summary: t.Dict[str, int]) -> None:
+    def print_environments(self, environments_summary: dict[str, int]) -> None:
         self.publish(PrintEnvironments(environments_summary))
 
     def show_table_diff_summary(self, table_diff: TableDiff) -> None:
         self.publish(ShowTableDiffSummary(table_diff))
 
     def show_linter_violations(
-        self, violations: list[RuleViolation], model: Model, is_error: bool = False
+        self,
+        violations: list[RuleViolation],
+        model: Model,
+        is_error: bool = False,
     ) -> None:
+        """Show linting violations from SQLMesh.
+
+        Args:
+            violations: List of linting violations to display
+            model: The model being linted
+            is_error: Whether the violations are errors
+        """
         self.publish(LogWarning("Linting violations found", str(violations)))
 
 
@@ -586,9 +594,9 @@ def stop_plan_evaluation(self) -> None:
 
     def start_evaluation_progress(
         self,
-        batches: Dict[Snapshot, int],
+        batches: dict[Snapshot, int],
         environment_naming_info: EnvironmentNamingInfo,
-        default_catalog: t.Optional[str],
+        default_catalog: str | None,
     ) -> None:
         super().start_evaluation_progress(
             batches, environment_naming_info, default_catalog
@@ -602,7 +610,7 @@ def start_snapshot_evaluation_progress(self, snapshot: Snapshot) -> None:
         self._console.start_snapshot_evaluation_progress(snapshot)
 
     def update_snapshot_evaluation_progress(
-        self, snapshot: Snapshot, batch_idx: int, duration_ms: t.Optional[int]
+        self, snapshot: Snapshot, batch_idx: int, duration_ms: int | None
     ) -> None:
         super().update_snapshot_evaluation_progress(snapshot, batch_idx, duration_ms)
         self._console.update_snapshot_evaluation_progress(
@@ -617,7 +625,7 @@ def start_creation_progress(
         self,
         total_tasks: int,
         environment_naming_info: EnvironmentNamingInfo,
-        default_catalog: t.Optional[str],
+        default_catalog: str | None,
     ) -> None:
         super().start_creation_progress(
             total_tasks, environment_naming_info, default_catalog
@@ -642,7 +650,7 @@ def start_promotion_progress(
         self,
         total_tasks: int,
         environment_naming_info: EnvironmentNamingInfo,
-        default_catalog: t.Optional[str],
+        default_catalog: str | None,
     ) -> None:
         super().start_promotion_progress(
             total_tasks, environment_naming_info, default_catalog
@@ -665,7 +673,7 @@ def show_model_difference_summary(
         self,
         context_diff: ContextDiff,
         environment_naming_info: EnvironmentNamingInfo,
-        default_catalog: t.Optional[str],
+        default_catalog: str | None,
         no_diff: bool = True,
     ) -> None:
         super().show_model_difference_summary(
@@ -686,7 +694,7 @@ def plan(
         self,
         plan_builder: PlanBuilder,
         auto_apply: bool,
-        default_catalog: t.Optional[str],
+        default_catalog: str | None,
         no_diff: bool = False,
         no_prompts: bool = False,
     ) -> None:
@@ -698,7 +706,7 @@ def log_test_results(
         self,
         result: unittest.result.TestResult,
-        output: t.Optional[str],
+        output: str | None,
         target_dialect: str,
     ) -> None:
         super().log_test_results(result, output, target_dialect)
@@ -720,7 +728,7 @@ def log_success(self, message: str) -> None:
         super().log_success(message)
         self._console.log_success(message)
 
-    def loading_start(self, message: t.Optional[str] = None) -> uuid.UUID:
+    def loading_start(self, message: str | None = None) -> uuid.UUID:
         event_id = super().loading_start(message)
         self._console.loading_start(message)
         return event_id
@@ -741,3 +749,12 @@ def show_row_diff(
     ) -> None:
         super().show_row_diff(row_diff, show_sample)
         self._console.show_row_diff(row_diff, show_sample, skip_grain_check)
+
+    def show_linter_violations(
+        self,
+        violations: list[RuleViolation],
+        model: Model,
+        is_error: bool = False,
+    ) -> None:
+        super().show_linter_violations(violations, model, is_error)
+        self._console.show_linter_violations(violations, model, is_error)
diff --git a/dagster_sqlmesh/controller/__init__.py b/dagster_sqlmesh/controller/__init__.py
index 9e90dc3..150e8f0 100644
--- a/dagster_sqlmesh/controller/__init__.py
+++ b/dagster_sqlmesh/controller/__init__.py
@@ -1,3 +1,3 @@
 # ruff: noqa: F403 F401
-from .base import SQLMeshController, SQLMeshInstance, PlanOptions, RunOptions
+from .base import PlanOptions, RunOptions, SQLMeshController, SQLMeshInstance
 from .dagster import DagsterSQLMeshController
diff --git a/dagster_sqlmesh/controller/base.py b/dagster_sqlmesh/controller/base.py
index 1acf3a7..00137b4 100644
--- a/dagster_sqlmesh/controller/base.py
+++ b/dagster_sqlmesh/controller/base.py
@@ -1,28 +1,34 @@
-from dataclasses import dataclass
-import typing as t
 import logging
 import threading
+import typing as t
 from contextlib import contextmanager
+from dataclasses import dataclass
+from types import MappingProxyType
+from typing import TypeVar
 
-from sqlmesh.utils.date import TimeLike
-from sqlmesh.core.context import Context
-from sqlmesh.core.plan import PlanBuilder
 from sqlmesh.core.config import CategorizerConfig
 from sqlmesh.core.console import Console, set_console
+from sqlmesh.core.context import Context
 from sqlmesh.core.model import Model
+from sqlmesh.core.plan import PlanBuilder
+from sqlmesh.utils.dag import DAG
+from sqlmesh.utils.date import TimeLike
 
-from ..events import ConsoleGenerator
 from ..config import SQLMeshContextConfig
 from ..console import (
-    ConsoleException,
-    EventConsole,
+    ConsoleEvent,
     ConsoleEventHandler,
+    ConsoleException,
     DebugEventConsole,
+    EventConsole,
     SnapshotCategorizer,
 )
+from ..events import ConsoleGenerator
 
 logger = logging.getLogger(__name__)
 
+T = TypeVar("T", bound="SQLMeshController")
+
 
 class PlanOptions(t.TypedDict):
     start: t.NotRequired[TimeLike]
@@ -52,6 +58,8 @@ class RunOptions(t.TypedDict):
     skip_janitor: t.NotRequired[bool]
     ignore_cron: t.NotRequired[bool]
     select_models: t.NotRequired[t.Collection[str]]
+    exit_on_env_update: t.NotRequired[int]
+    no_auto_upstream: t.NotRequired[bool]
 
 
 @dataclass(kw_only=True)
@@ -61,7 +69,7 @@ class SQLMeshParsedFQN:
     view_name: str
 
 
-def parse_fqn(fqn: str):
+def parse_fqn(fqn: str) -> SQLMeshParsedFQN:
     split_fqn = fqn.split(".")
 
     # Remove any quotes
@@ -74,9 +82,9 @@ class SQLMeshModelDep:
     fqn: str
-    model: t.Optional[Model] = None
+    model: Model | None = None
 
-    def parse_fqn(self):
+    def parse_fqn(self) -> SQLMeshParsedFQN:
         return parse_fqn(self.fqn)
 
 
@@ -120,17 +128,19 @@ def __init__(
         self.logger = logger
 
     @contextmanager
-    def console_context(self, handler: ConsoleEventHandler):
+    def console_context(
+        self, handler: ConsoleEventHandler
+    ) -> t.Iterator[None]:
         id = self.console.add_handler(handler)
         yield
         self.console.remove_handler(id)
 
     def plan(
         self,
-        categorizer: t.Optional[SnapshotCategorizer] = None,
-        default_catalog: t.Optional[str] = None,
+        categorizer: SnapshotCategorizer | None = None,
+        default_catalog: str | None = None,
         **plan_options: t.Unpack[PlanOptions],
-    ):
+    ) -> t.Iterator[ConsoleEvent]:
         """
         Executes a sqlmesh plan operation in a separate thread and yields
         console events.
@@ -139,9 +149,9 @@ def plan(
         thread, and provides real-time console output through a generator.
 
         Args:
-            categorizer (Optional[SnapshotCategorizer]): Categorizer for
+            categorizer (SnapshotCategorizer | None): Categorizer for
                 snapshots. Defaults to None.
-            default_catalog (Optional[str]): Default catalog to use for the
+            default_catalog (str | None): Default catalog to use for the
                 plan. Defaults to None.
             **plan_options (**PlanOptions): Additional options for plan execution.
 
@@ -163,7 +173,7 @@ def run_sqlmesh_thread(
             environment: str,
             plan_options: PlanOptions,
             default_catalog: str,
-        ):
+        ) -> None:
             logger.debug("dagster-sqlmesh: thread started")
             try:
                 builder = t.cast(
@@ -214,7 +224,9 @@ def run_sqlmesh_thread(
 
         thread.join()
 
-    def run(self, **run_options: t.Unpack[RunOptions]):
+    def run(
+        self, **run_options: t.Unpack[RunOptions]
+    ) -> t.Iterator[ConsoleEvent]:
         """Executes sqlmesh run in a separate thread with console output.
 
         This method executes SQLMesh operations in a dedicated thread while
@@ -242,7 +254,7 @@ def run_sqlmesh_thread(
             controller: SQLMeshController,
             environment: str,
             run_options: RunOptions,
-        ):
+        ) -> None:
             logger.debug("dagster-sqlmesh: run")
             try:
                 context.run(environment=environment, **run_options)
@@ -281,12 +293,12 @@ def plan_and_run(
         restate_selected: bool = False,
         start: TimeLike | None = None,
         end: TimeLike | None = None,
-        categorizer: t.Optional[SnapshotCategorizer] = None,
-        default_catalog: t.Optional[str] = None,
-        plan_options: t.Optional[PlanOptions] = None,
-        run_options: t.Optional[RunOptions] = None,
+        categorizer: SnapshotCategorizer | None = None,
+        default_catalog: str | None = None,
+        plan_options: PlanOptions | None = None,
+        run_options: RunOptions | None = None,
         skip_run: bool = False,
-    ):
+    ) -> t.Iterator[ConsoleEvent]:
         """Executes a plan and run operation
 
         This is an opinionated interface for running a plan and run operation in
@@ -334,10 +346,10 @@ def plan_and_run(
             self.logger.error("Error during sqlmesh plan and run")
             raise
 
-    def models(self):
+    def models(self) -> MappingProxyType[str, Model]:
         return self.context.models
 
-    def models_dag(self):
+    def models_dag(self) -> DAG[str]:
         return self.context.dag
 
 
@@ -382,9 +394,9 @@ def setup(
         cls,
         path: str,
         gateway: str = "local",
-        debug_console: t.Optional[Console] = None,
-        log_override: t.Optional[logging.Logger] = None,
-    ):
+        debug_console: Console | None = None,
+        log_override: logging.Logger | None = None,
+    ) -> "SQLMeshController":
         return cls.setup_with_config(
             config=SQLMeshContextConfig(path=path, gateway=gateway),
             debug_console=debug_console,
@@ -393,11 +405,11 @@ def setup(
 
     @classmethod
     def setup_with_config(
-        cls,
+        cls: type[T],
         config: SQLMeshContextConfig,
-        debug_console: t.Optional[Console] = None,
-        log_override: t.Optional[logging.Logger] = None,
-    ):
+        debug_console: Console | None = None,
+        log_override: logging.Logger | None = None,
+    ) -> T:
         console = EventConsole(log_override=log_override)
         if debug_console:
             console = DebugEventConsole(debug_console)
@@ -412,24 +424,25 @@ def __init__(
         self,
         config: SQLMeshContextConfig,
         console: EventConsole,
-        log_override: t.Optional[logging.Logger] = None,
-    ):
+        log_override: logging.Logger | None = None,
+    ) -> None:
         self.config = config
         self.console = console
         self.logger = log_override or logger
         self._context_open = False
 
-    def set_logger(self, logger: logging.Logger):
+    def set_logger(self, logger: logging.Logger) -> None:
         self.logger = logger
 
-    def add_event_handler(self, handler: ConsoleEventHandler):
-        return self.console.add_handler(handler)
+    def add_event_handler(self, handler: ConsoleEventHandler) -> str:
+        handler_id: str = self.console.add_handler(handler)
+        return handler_id
 
-    def remove_event_handler(self, handler_id: str):
-        return self.console.remove_handler(handler_id)
+    def remove_event_handler(self, handler_id: str) -> None:
+        self.console.remove_handler(handler_id)
 
-    def _create_context(self):
-        options: t.Dict[str, t.Any] = dict(
+    def _create_context(self) -> Context:
+        options: dict[str, t.Any] = dict(
             paths=self.config.path,
             gateway=self.config.gateway,
         )
@@ -439,7 +452,9 @@ def _create_context(self):
         return Context(**options)
 
     @contextmanager
-    def instance(self, environment: str, component: str = "unknown"):
+    def instance(
+        self, environment: str, component: str = "unknown"
+    ) -> t.Iterator[SQLMeshInstance]:
         self.logger.info(
             f"Opening sqlmesh instance for env={environment} component={component}"
         )
@@ -463,17 +478,17 @@ def run(
         self,
         environment: str,
         **run_options: t.Unpack[RunOptions],
-    ):
+    ) -> t.Iterator[ConsoleEvent]:
         with self.instance(environment, "run") as mesh:
             yield from mesh.run(**run_options)
 
     def plan(
         self,
         environment: str,
-        categorizer: t.Optional[SnapshotCategorizer],
-        default_catalog: t.Optional[str],
+        categorizer: SnapshotCategorizer | None,
+        default_catalog: str | None,
         plan_options: PlanOptions,
-    ):
+    ) -> t.Iterator[ConsoleEvent]:
         with self.instance(environment, "plan") as mesh:
             yield from mesh.plan(categorizer, default_catalog, **plan_options)
 
@@ -481,16 +496,16 @@ def plan_and_run(
         self,
         environment: str,
         *,
-        categorizer: t.Optional[SnapshotCategorizer] = None,
+        categorizer: SnapshotCategorizer | None = None,
         select_models: list[str] | None = None,
         restate_selected: bool = False,
         start: TimeLike | None = None,
         end: TimeLike | None = None,
-        default_catalog: t.Optional[str] = None,
-        plan_options: t.Optional[PlanOptions] = None,
-        run_options: t.Optional[RunOptions] = None,
+        default_catalog: str | None = None,
+        plan_options: PlanOptions | None = None,
+        run_options: RunOptions | None = None,
         skip_run: bool = False,
-    ):
+    ) -> t.Iterator[ConsoleEvent]:
         with self.instance(environment, "plan_and_run") as mesh:
             yield from mesh.plan_and_run(
                 start=start,
diff --git a/dagster_sqlmesh/controller/dagster.py b/dagster_sqlmesh/controller/dagster.py
index b6b1a42..bf1adb4 100644
--- a/dagster_sqlmesh/controller/dagster.py
+++ b/dagster_sqlmesh/controller/dagster.py
@@ -1,17 +1,16 @@
 import logging
-import typing as t
 
-from dagster._core.definitions.asset_dep import CoercibleToAssetDep
 from dagster import (
     AssetDep,
-    AssetOut,
     AssetKey,
+    AssetOut,
 )
+from dagster._core.definitions.asset_dep import CoercibleToAssetDep
 
+from ..translator import SQLMeshDagsterTranslator
+from ..types import SQLMeshModelDep, SQLMeshMultiAssetOptions
 from ..utils import sqlmesh_model_name_to_key
 from .base import SQLMeshController
-from ..translator import SQLMeshDagsterTranslator
-from ..types import SQLMeshMultiAssetOptions, SQLMeshModelDep
 
 logger = logging.getLogger(__name__)
 
@@ -26,7 +25,7 @@ def to_asset_outs(
         context = instance.context
         dag = context.dag
         output = SQLMeshMultiAssetOptions()
-        depsMap: t.Dict[str, CoercibleToAssetDep] = {}
+        depsMap: dict[str, CoercibleToAssetDep] = {}
 
         for model_fqn, deps in dag.graph.items():
             logger.debug(f"model found: {model_fqn}")
@@ -42,7 +41,7 @@ def to_asset_outs(
                 SQLMeshModelDep(fqn=dep, model=context.get_model(dep))
                 for dep in deps
             ]
-            internal_asset_deps: t.Set[AssetKey] = set()
+            internal_asset_deps: set[AssetKey] = set()
             asset_tags = translator.get_tags(context, model)
 
             for dep in model_deps:
diff --git a/dagster_sqlmesh/events.py b/dagster_sqlmesh/events.py
index 48430e5..a9b36b6 100644
--- a/dagster_sqlmesh/events.py
+++ b/dagster_sqlmesh/events.py
@@ -1,11 +1,11 @@
-from typing import List, Optional, Set, Callable, Iterator
 import logging
 import queue
 import threading
+from collections.abc import Callable, Iterator
 
 from sqlmesh.core.model import Model
-from sqlmesh.core.snapshot import SnapshotInfoLike, SnapshotId, Snapshot
 from sqlmesh.core.plan import Plan
+from sqlmesh.core.snapshot import Snapshot, SnapshotId, SnapshotInfoLike
 
 from dagster_sqlmesh import console
 
@@ -17,8 +17,8 @@ def show_plan_summary(
     logger: logging.Logger,
     plan: Plan,
     snapshot_selector: Callable[[SnapshotInfoLike], bool],
-    ignored_snapshot_ids: Optional[Set[SnapshotId]] = None,
-):
+    ignored_snapshot_ids: set[SnapshotId] | None = None,
+) -> None:
     context_diff = plan.context_diff
     ignored_snapshot_ids = ignored_snapshot_ids or set()
     selected_snapshots = {
@@ -44,7 +44,7 @@ def show_plan_summary(
         for _, (current_snapshot, _) in context_diff.modified_snapshots.items()
         if snapshot_selector(current_snapshot)
     } - selected_ignored_snapshot_ids
-    restated_snapshots: List[SnapshotInfoLike] = [
+    restated_snapshots: list[SnapshotInfoLike] = [
         context_diff.snapshots[snap_id] for snap_id in plan.restatements.keys()
     ]
 
@@ -62,11 +62,11 @@ def show_plan_summary(
 
 
 class ConsoleGenerator:
-    def __init__(self, log_override: Optional[logging.Logger] = None):
-        self._queue = queue.Queue()
+    def __init__(self, log_override: logging.Logger | None = None):
+        self._queue: queue.Queue[console.ConsoleEvent] = queue.Queue()
         self.logger = log_override or logger
 
-    def __call__(self, event: console.ConsoleEvent):
+    def __call__(self, event: console.ConsoleEvent) -> None:
         self._queue.put(event)
 
     def events(self, thread: threading.Thread) -> Iterator[console.ConsoleEvent]:
@@ -82,16 +82,16 @@ def events(self, thread: threading.Thread) -> Iterator[console.ConsoleEvent]:
 class ConsoleRecorder:
     def __init__(
         self,
-        log_override: Optional[logging.Logger] = None,
+        log_override: logging.Logger | None = None,
         enable_unknown_event_logging: bool = True,
     ):
         self.logger = log_override or logger
-        self._planned_models: List[Model] = []
-        self._updated: List[Snapshot] = []
+        self._planned_models: list[Model] = []
+        self._updated: list[Snapshot] = []
         self._successful = False
         self._enable_unknown_event_logging = enable_unknown_event_logging
 
-    def __call__(self, event: console.ConsoleEvent):
+    def __call__(self, event: console.ConsoleEvent) -> None:
         match event:
             case console.StartPlanEvaluation(evaluatable_plan):
                 self.logger.debug("Starting plan evaluation")
@@ -131,8 +131,8 @@ def _show_summary_for(
         self,
         plan: Plan,
         snapshot_selector: Callable[[SnapshotInfoLike], bool],
-        ignored_snapshot_ids: Optional[Set[SnapshotId]] = None,
-    ):
+        ignored_snapshot_ids: set[SnapshotId] | None = None,
+    ) -> None:
         return show_plan_summary(
             self.logger, plan, snapshot_selector, ignored_snapshot_ids
         )
diff --git a/dagster_sqlmesh/resource.py b/dagster_sqlmesh/resource.py
index 5a1b797..8e351c0 100644
--- a/dagster_sqlmesh/resource.py
+++ b/dagster_sqlmesh/resource.py
@@ -1,21 +1,21 @@
-import typing as t
 import logging
+import typing as t
 
 from dagster import (
-    ConfigurableResource,
     AssetExecutionContext,
+    ConfigurableResource,
     MaterializeResult,
 )
 from sqlmesh import Model
+from sqlmesh.core.context import Context as SQLMeshContext
+from sqlmesh.core.snapshot import Snapshot
 from sqlmesh.utils.dag import DAG
 from sqlmesh.utils.date import TimeLike
-from sqlmesh.core.snapshot import Snapshot
-from sqlmesh.core.context import Context as SQLMeshContext
 
-from .utils import sqlmesh_model_name_to_key
-from .config import SQLMeshContextConfig
-from .controller import SQLMeshController, PlanOptions, RunOptions
 from . import console
+from .config import SQLMeshContextConfig
+from .controller import PlanOptions, RunOptions, SQLMeshController
+from .utils import sqlmesh_model_name_to_key
 
 
 class MaterializationTracker:
@@ -23,17 +23,17 @@ class MaterializationTracker:
     order. This is necessary because sqlmesh may skip some materializations
     that have no changes and those will be reported as completed out of order."""
 
-    def __init__(self, sorted_dag: t.List[str], logger: logging.Logger):
+    def __init__(self, sorted_dag: list[str], logger: logging.Logger) -> None:
         self.logger = logger
-        self._batches: t.Dict[Snapshot, int] = {}
-        self._count: t.Dict[Snapshot, int] = {}
-        self._complete_update_status: t.Dict[str, bool] = {}
+        self._batches: dict[Snapshot, int] = {}
+        self._count: dict[Snapshot, int] = {}
+        self._complete_update_status: dict[str, bool] = {}
         self._sorted_dag = sorted_dag
         self._current_index = 0
 
-    def plan(self, batches: t.Dict[Snapshot, int]):
+    def plan(self, batches: dict[Snapshot, int]) -> None:
         self._batches = batches
-        self._count: t.Dict[Snapshot, int] = {}
+        self._count: dict[Snapshot, int] = {}
 
         incomplete_names = set()
         for snapshot, count in self._batches.items():
@@ -46,7 +46,7 @@ def plan(self, batches: t.Dict[Snapshot, int]):
             name: False for name in (set(self._sorted_dag) - incomplete_names)
         }
 
-    def update(self, snapshot: Snapshot, _batch_idx: int):
+    def update(self, snapshot: Snapshot, _batch_idx: int) -> tuple[int, int]:
         self._count[snapshot] += 1
         current_count = self._count[snapshot]
         expected_count = self._batches[snapshot]
@@ -54,7 +54,7 @@ def update(self, snapshot: Snapshot, _batch_idx: int):
             self._complete_update_status[snapshot.name] = True
         return (current_count, expected_count)
 
-    def notify_queue_next(self) -> t.Tuple[str, bool] | None:
+    def notify_queue_next(self) -> tuple[str, bool] | None:
         if self._current_index >= len(self._sorted_dag):
             return None
         check_name = self._sorted_dag[self._current_index]
@@ -73,24 +73,24 @@ def __init__(
         self._handler = handler
         self._event = event
 
-    def ensure_standard_obj(self, obj: t.Optional[t.Dict[str, t.Any]]):
+    def ensure_standard_obj(self, obj: dict[str, t.Any] | None) -> dict[str, t.Any]:
         obj = obj or {}
         obj["_event_type"] = self.event_name
         return obj
 
-    def info(self, message: str, obj: t.Optional[t.Dict[str, t.Any]] = None):
+    def info(self, message: str, obj: dict[str, t.Any] | None = None) -> None:
         self.log("info", message, obj)
 
-    def debug(self, message: str, obj: t.Optional[t.Dict[str, t.Any]] = None):
+    def debug(self, message: str, obj: dict[str, t.Any] | None = None) -> None:
         self.log("debug", message, obj)
 
-    def warning(self, message: str, obj: t.Optional[t.Dict[str, t.Any]] = None):
+    def warning(self, message: str, obj: dict[str, t.Any] | None = None) -> None:
         self.log("warning", message, obj)
 
-    def error(self, message: str, obj: t.Optional[t.Dict[str, t.Any]] = None):
+    def error(self, message: str, obj: dict[str, t.Any] | None = None) -> None:
         self.log("error", message, obj)
 
-    def log(self, level: str | int, message: str, obj: t.Optional[t.Dict[str, t.Any]]):
+    def log(self, level: str | int, message: str, obj: dict[str, t.Any] | None) -> None:
         self._handler.log(level, message, self.ensure_standard_obj(obj))
 
     @property
@@ -102,10 +102,10 @@ class DagsterSQLMeshEventHandler:
     def __init__(
         self,
         context: AssetExecutionContext,
-        models_map: t.Dict[str, Model],
-        dag: DAG,
+        models_map: dict[str, Model],
+        dag: DAG[t.Any],
         prefix: str,
-    ):
+    ) -> None:
         self._models_map = models_map
         self._prefix = prefix
         self._context = context
@@ -115,7 +115,7 @@ def __init__(
 
     def process_events(
         self, sqlmesh_context: SQLMeshContext, event: console.ConsoleEvent
-    ):
+    ) -> t.Iterator[MaterializeResult]:
         self.report_event(event)
 
         notify = self._tracker.notify_queue_next()
@@ -144,7 +144,7 @@ def process_events(
             )
             notify = self._tracker.notify_queue_next()
 
-    def report_event(self, event: console.ConsoleEvent):
+    def report_event(self, event: console.ConsoleEvent) -> None:
         log_context = self.log_context(event)
 
         match event:
@@ -200,21 +200,21 @@ def report_event(self, event: console.ConsoleEvent):
             case console.LogFailedModels(models):
                 if len(models) != 0:
                     failed_models = "\n".join(
-                        [f"{str(model)}\n{str(model.__cause__)}" for model in models]
+                        [f"{model!s}\n{model.__cause__!s}" for model in models]
                     )
                     log_context.error(f"sqlmesh failed models: {failed_models}")
             case _:
                 log_context.debug("Received event")
 
-    def log_context(self, event: console.ConsoleEvent):
+    def log_context(self, event: console.ConsoleEvent) -> SQLMeshEventLogContext:
         return SQLMeshEventLogContext(self, event)
 
     def log(
         self,
         level: str | int,
         message: str,
-        obj: t.Optional[t.Dict[str, t.Any]] = None,
-    ):
+        obj: dict[str, t.Any] | None = None,
+    ) -> None:
         if level == "error":
             self._logger.error(message)
             return
@@ -240,8 +240,8 @@ def run(
         start: TimeLike | None = None,
         end: TimeLike | None = None,
         restate_selected: bool = False,
-        plan_options: t.Optional[PlanOptions] = None,
-        run_options: t.Optional[RunOptions] = None,
+        plan_options: PlanOptions | None = None,
+        run_options: RunOptions | None = None,
     ) -> t.Iterable[MaterializeResult]:
         """Execute SQLMesh based on the configuration given"""
         plan_options = plan_options or {}
@@ -283,7 +283,9 @@ def run(
         ):
             yield from event_handler.process_events(mesh.context, event)
 
-    def get_controller(self, log_override: t.Optional[logging.Logger] = None):
+    def get_controller(
+        self, log_override: logging.Logger | None = None
+    ) -> SQLMeshController:
         return SQLMeshController.setup_with_config(
             self.config, log_override=log_override
         )
diff --git a/dagster_sqlmesh/scheduler.py b/dagster_sqlmesh/scheduler.py
index 39f3773..89e88f1 100644
--- a/dagster_sqlmesh/scheduler.py
+++ b/dagster_sqlmesh/scheduler.py
@@ -1,17 +1,19 @@
 import typing as t
 
-from sqlmesh.core.scheduler import Scheduler, CompletionStatus
+from sqlmesh.core.scheduler import CompletionStatus, Scheduler
 
 
 class DagsterSQLMeshScheduler(Scheduler):
     """Custom Scheduler so that we can choose a set of snapshots to use with
     sqlmesh runs"""
 
-    def __init__(self, selected_snapshots: t.Optional[t.Set[str]], *args, **kwargs):
+    def __init__(
+        self, selected_snapshots: set[str] | None = None, *args: t.Any, **kwargs: t.Any
+    ) -> None:
         super().__init__(*args, **kwargs)
-        self._selected_snapshots: t.Set[str] = selected_snapshots or set()
+        self._selected_snapshots: set[str] = selected_snapshots or set()
 
-    def run(self, *args, **kwargs) -> CompletionStatus:
+    def run(self, *args: t.Any, **kwargs: t.Any) -> CompletionStatus:
         if len(self._selected_snapshots) > 0:
             kwargs["selected_snapshots"] = self._selected_snapshots
         return super().run(*args, **kwargs)
diff --git a/dagster_sqlmesh/test_sqlmesh_context.py b/dagster_sqlmesh/test_sqlmesh_context.py
index 46fcc88..efc02ad 100644
--- a/dagster_sqlmesh/test_sqlmesh_context.py
+++ b/dagster_sqlmesh/test_sqlmesh_context.py
@@ -1,6 +1,7 @@
 import logging
 
 import polars
+
 from .conftest import SQLMeshTestContext
 
 logger = logging.getLogger(__name__)
@@ -183,12 +184,12 @@ def test_restating_models(sample_sqlmesh_test_context: SQLMeshTestContext):
         """
     )
 
-    assert (
-        feb_sum_query_restate[0][0] == feb_sum_query[0][0]
-    ), "February sum should not change"
-    assert (
-        march_sum_query_restate[0][0] != march_sum_query[0][0]
-    ), "March sum should change"
-    assert (
-        intermediate_2_query_restate[0][0] == intermediate_2_query[0][0]
-    ), "Intermediate model should not change during restate"
+    assert feb_sum_query_restate[0][0] == feb_sum_query[0][0], (
+        "February sum should not change"
+    )
+    assert march_sum_query_restate[0][0] != march_sum_query[0][0], (
+        "March sum should change"
+    )
+    assert intermediate_2_query_restate[0][0] == intermediate_2_query[0][0], (
+        "Intermediate model should not change during restate"
+    )
diff --git a/dagster_sqlmesh/translator.py b/dagster_sqlmesh/translator.py
index b631d04..e3791ba 100644
--- a/dagster_sqlmesh/translator.py
+++ b/dagster_sqlmesh/translator.py
@@ -1,9 +1,9 @@
-import typing as t
+
 import sqlglot
+from dagster import AssetKey
 from sqlglot import exp
 from sqlmesh.core.context import Context
 from sqlmesh.core.model import Model
-from dagster import AssetKey
 
 
 class SQLMeshDagsterTranslator:
@@ -26,6 +26,6 @@ def get_fqn_to_table(self, context: Context, fqn: str) -> exp.Table:
     def _get_context_dialect(self, context: Context) -> str:
         return context.engine_adapter.dialect
 
-    def get_tags(self, context: Context, model: Model) -> t.Dict[str, str]:
+    def get_tags(self, context: Context, model: Model) -> dict[str, str]:
         """Given the sqlmesh context and a model return the tags for that model"""
         return {k: "true" for k in model.tags}
diff --git a/dagster_sqlmesh/types.py b/dagster_sqlmesh/types.py
index d8511d6..9a5a965 100644
--- a/dagster_sqlmesh/types.py
+++ b/dagster_sqlmesh/types.py
@@ -1,15 +1,16 @@
 import typing as t
 from dataclasses import dataclass, field
+
 from dagster import (
     AssetCheckResult,
+    AssetKey,
     AssetMaterialization,
     AssetOut,
-    AssetKey,
 )
 from dagster._core.definitions.asset_dep import CoercibleToAssetDep
 from sqlmesh.core.model import Model
 
-MultiAssetResponse = t.Iterable[t.Union[AssetCheckResult, AssetMaterialization]]
+MultiAssetResponse = t.Iterable[AssetCheckResult | AssetMaterialization]
 
 
 @dataclass(kw_only=True)
@@ -19,7 +20,7 @@ class SQLMeshParsedFQN:
     view_name: str
 
     @classmethod
-    def parse(cls, fqn: str):
+    def parse(cls, fqn: str) -> "SQLMeshParsedFQN":
         split_fqn = fqn.split(".")
 
         # Remove any quotes
@@ -30,16 +31,14 @@ def parse(cls, fqn: str):
 @dataclass(kw_only=True)
 class SQLMeshModelDep:
     fqn: str
-    model: t.Optional[Model] = None
+    model: Model | None = None
 
-    def parse_fqn(self):
+    def parse_fqn(self) -> SQLMeshParsedFQN:
         return SQLMeshParsedFQN.parse(self.fqn)
 
 
 @dataclass(kw_only=True)
 class SQLMeshMultiAssetOptions:
-    outs: t.Dict[str, AssetOut] = field(default_factory=lambda: {})
+    outs: dict[str, AssetOut] = field(default_factory=lambda: {})
     deps: t.Iterable[CoercibleToAssetDep] = field(default_factory=lambda: {})
-    internal_asset_deps: t.Dict[str, t.Set[AssetKey]] = field(
-        default_factory=lambda: {}
-    )
+    internal_asset_deps: dict[str, set[AssetKey]] = field(default_factory=lambda: {})
diff --git a/dagster_sqlmesh/utils.py b/dagster_sqlmesh/utils.py
index 2e9b57d..8a08b2a 100644
--- a/dagster_sqlmesh/utils.py
+++ b/dagster_sqlmesh/utils.py
@@ -1,6 +1,21 @@
-def sqlmesh_model_name_to_key(name: str):
+from sqlmesh.core.snapshot import SnapshotId
+
+
+def sqlmesh_model_name_to_key(name: str) -> str:
     return name.replace(".", "_dot__")
 
 
-def key_to_sqlmesh_model_name(key: str):
+def key_to_sqlmesh_model_name(key: str) -> str:
     return key.replace("_dot__", ".")
+
+
+def snapshot_id_to_model_name(snapshot_id: SnapshotId) -> str:
+    """Convert a SnapshotId to its model name.
+
+    Args:
+        snapshot_id: The SnapshotId object to extract the model name from
+
+    Returns:
+        str: The model name in the format "db"."schema"."name"
+    """
+    return str(snapshot_id).split("<")[1].split(":")[0]
diff --git a/pyproject.toml b/pyproject.toml
index aaabd9a..ef02669 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,9 +1,11 @@
 [project]
 name = "dagster-sqlmesh"
-version = "0.9.0"
+version = "0.8.0"
 description = ""
-authors = [{ name = "Reuven Gonzales", email = "reuven@karibalabs.co" }]
-license = { text = "Apache-2.0" }
+authors = [
+    {name = "Reuven Gonzales", email = "reuven@karibalabs.co"}
+]
+license = {text = "Apache-2.0"}
 readme = "README.md"
 requires-python = ">=3.11,<3.13"
 dependencies = [
@@ -21,6 +23,8 @@ dev = [
     "ruff>=0.6.2",
     "polars>=1.5.0",
     "dagster-duckdb-polars>=0.24.2",
+    "fastapi", # this is for sqlmesh ui
+    "sse-starlette", # this is for sqlmesh ui
 ]
 
 [build-system]
@@ -36,5 +40,42 @@ exclude = [
     "**/.github",
     "**/.vscode",
     "**/.idea",
-    "**/.pytest_cache",
+    "**/.pytest_cache",
 ]
+pythonVersion = "3.11"
+reportUnknownParameterType = true
+reportMissingTypeStubs = false
+reportUnusedImports = true
+reportUnnecessaryTypeIgnoreComment = true
+useLibraryCodeForTypes = true
+reportMissingReturnType = true
+reportIncompleteStub = true
+reportUntypedFunctionDecorator = false
+
+[tool.ruff.lint]
+# Ignore E402: Module level import not at top of file
+ignore = ["E402", "E712"]
+select = [
+    'I001', # isort
+    "E4",   # pycodestyle errors
+    "E7",   # pycodestyle errors
+    "E9",   # pycodestyle errors
+    "F",    # pyflakes
+    "F401", # unused imports
+    "F403", # star imports usage
+    "F405", # star imports usage
+    "F821", # undefined names
+    "UP",   # pyupgrade (modernize Python code)
+    "RUF"   # ruff-specific rules
+]
+
+[tool.ruff.format]
+quote-style = "double"
+indent-style = "space"
+line-ending = "auto"
+
+[tool.ruff.lint.isort]
+known-first-party = ["dagster_sqlmesh"]
+section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"]
+combine-as-imports = true
+split-on-trailing-comma = true
\ No newline at end of file
diff --git a/sample/dagster_project/dagster.yaml b/sample/dagster_project/dagster.yaml
new file mode 100644
index 0000000..a4b5df7
--- /dev/null
+++ b/sample/dagster_project/dagster.yaml
@@ -0,0 +1,12 @@
+run_launcher:
+  module: dagster.core.launcher
+  class: DefaultRunLauncher
+
+run_coordinator:
+  module: dagster.core.run_coordinator
+  class: QueuedRunCoordinator
+  config:
+    max_concurrent_runs: 4
+
+telemetry:
+  enabled: false
\ No newline at end of file
diff --git a/sample/dagster_project/definitions.py b/sample/dagster_project/definitions.py
index c6bc2b9..345de54 100644
--- a/sample/dagster_project/definitions.py
+++ b/sample/dagster_project/definitions.py
@@ -1,17 +1,18 @@
 import os
-
 import time
+import typing as t
+
+import polars as pl
 from dagster import (
-    MaterializeResult,
-    asset,
     AssetExecutionContext,
     Definitions,
+    MaterializeResult,
+    asset,
     define_asset_job,
 )
 from dagster_duckdb_polars import DuckDBPolarsIOManager
-import polars as pl
 
-from dagster_sqlmesh import sqlmesh_assets, SQLMeshContextConfig, SQLMeshResource
+from dagster_sqlmesh import SQLMeshContextConfig, SQLMeshResource, sqlmesh_assets
 
 CURR_DIR = os.path.dirname(__file__)
 SQLMESH_PROJECT_PATH = os.path.abspath(os.path.join(CURR_DIR, "../sqlmesh_project"))
@@ -52,7 +53,7 @@ def test_source() -> pl.DataFrame:
 
 
 @sqlmesh_assets(environment="dev", config=sqlmesh_config, enabled_subsetting=True)
-def sqlmesh_project(context: AssetExecutionContext, sqlmesh: SQLMeshResource):
+def sqlmesh_project(context: AssetExecutionContext, sqlmesh: SQLMeshResource) -> t.Iterator[MaterializeResult]:
     yield from sqlmesh.run(context)
 
 
diff --git a/sample/dagster_project/workspace.yaml b/sample/dagster_project/workspace.yaml
new file mode 100644
index 0000000..6eed5ff
--- /dev/null
+++ b/sample/dagster_project/workspace.yaml
@@ -0,0 +1,5 @@
+load_from:
+  - python_file:
+      relative_path: definitions.py
+      location_name: "dagster_sqlmesh_project"
+      working_directory: "."
\ No newline at end of file
diff --git a/sample/sqlmesh_project/models/staging/staging_model_4.py b/sample/sqlmesh_project/models/staging/staging_model_4.py
index e56fc3b..14053da 100644
--- a/sample/sqlmesh_project/models/staging/staging_model_4.py
+++ b/sample/sqlmesh_project/models/staging/staging_model_4.py
@@ -1,10 +1,10 @@
 import typing as t
 from datetime import datetime
 
+import numpy as np
 import pandas as pd
 from sqlmesh import ExecutionContext, model
 from sqlmesh.core.model import ModelKindName
-import numpy as np
 
 
 @model(
@@ -21,8 +21,8 @@ def staging_model_4(
     context: ExecutionContext,
     start: datetime,
     end: datetime,
-    **kwargs,
-) -> t.Generator[pd.DataFrame, None, None]:
+    **kwargs: t.Any,
+) -> t.Iterator[pd.DataFrame]:
     # Generates a set of random rows for the model based on the start and end dates
     date_range = pd.date_range(start=start, end=end, freq="D")
     num_days = len(date_range)
diff --git a/uv.lock b/uv.lock
index 8c0e508..e0a8ff8 100644
--- a/uv.lock
+++ b/uv.lock
@@ -300,10 +300,12 @@ dependencies = [
 dev = [
     { name = "dagster-duckdb-polars" },
     { name = "dagster-webserver" },
+    { name = "fastapi" },
     { name = "ipython" },
     { name = "pdbpp" },
     { name = "polars" },
     { name = "ruff" },
+    { name = "sse-starlette" },
 ]
 
 [package.metadata]
@@ -318,10 +320,12 @@ requires-dist = [
 dev = [
     { name = "dagster-duckdb-polars", specifier = ">=0.24.2" },
     { name = "dagster-webserver", specifier = ">=1.8.1" },
+    { name = "fastapi" },
     { name = "ipython", specifier = ">=8.26.0" },
     { name = "pdbpp", specifier = ">=0.10.3" },
     { name = "polars", specifier = ">=1.5.0" },
     { name = "ruff", specifier = ">=0.6.2" },
+    { name = "sse-starlette" },
 ]
 
 [[package]]
@@ -419,6 +423,20 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/38/ef/c08926112034d017633f693d3afc8343393a035134a29dfc12dcd71b0375/fancycompleter-0.9.1-py3-none-any.whl", hash = "sha256:dd076bca7d9d524cc7f25ec8f35ef95388ffef9ef46def4d3d25e9b044ad7080", size = 9681 },
 ]
 
+[[package]]
+name = "fastapi"
+version = "0.115.11"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "pydantic" },
+    { name = "starlette" },
+    { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/b5/28/c5d26e5860df807241909a961a37d45e10533acef95fc368066c7dd186cd/fastapi-0.115.11.tar.gz", hash = "sha256:cc81f03f688678b92600a65a5e618b93592c65005db37157147204d8924bf94f", size = 294441 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/b3/5d/4d8bbb94f0dbc22732350c06965e40740f4a92ca560e90bb566f4f73af41/fastapi-0.115.11-py3-none-any.whl", hash = "sha256:32e1541b7b74602e4ef4a0260ecaf3aadf9d4f19590bba3e1bf2ac4666aa2c64", size = 94926 },
+]
+
 [[package]]
 name = "filelock"
 version = "3.17.0"
@@ -1523,6 +1541,19 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/f2/cd/cbdbcf1d03b9483c7fb47d1775962a0b4abcbc85aed4807d8b512467eaf6/sqlmesh-0.164.0-py3-none-any.whl", hash = "sha256:36ce7546f929b86884673fa1be68955f895797f07e0629e312f5cdb1da26416b", size = 5378661 },
 ]
 
+[[package]]
+name = "sse-starlette"
+version = "2.2.1"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "anyio" },
+    { name = "starlette" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/71/a4/80d2a11af59fe75b48230846989e93979c892d3a20016b42bb44edb9e398/sse_starlette-2.2.1.tar.gz", hash = "sha256:54470d5f19274aeed6b2d473430b08b4b379ea851d953b11d7f1c4a2c118b419", size = 17376 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/d9/e0/5b8bd393f27f4a62461c5cf2479c75a2cc2ffa330976f9f00f5f6e4f50eb/sse_starlette-2.2.1-py3-none-any.whl", hash = "sha256:6410a3d3ba0c89e7675d4c273a301d64649c03a5ef1ca101f10b47f895fd0e99", size = 10120 },
+]
+
 [[package]]
 name = "stack-data"
 version = "0.6.3"