diff --git a/README.md b/README.md index 8e12f17..62942f1 100644 --- a/README.md +++ b/README.md @@ -24,11 +24,11 @@ from dagster import ( AssetExecutionContext, Definitions, ) -from dagster_sqlmesh import sqlmesh_assets, SQLMeshContextConfig, SQLMeshResource +from dagster_sqlmesh import sqlmesh_assets, SQLMeshContextConfig, SQLMeshResource, SQLMeshDagsterTranslator sqlmesh_config = SQLMeshContextConfig(path="/home/foo/sqlmesh_project", gateway="name-of-your-gateway") -@sqlmesh_assets(environment="dev", config=sqlmesh_config) +@sqlmesh_assets(environment="dev", config=sqlmesh_config, translator=SQLMeshDagsterTranslator()) def sqlmesh_project(context: AssetExecutionContext, sqlmesh: SQLMeshResource): yield from sqlmesh.run(context) diff --git a/dagster_sqlmesh/asset.py b/dagster_sqlmesh/asset.py index 992d99f..c697b1b 100644 --- a/dagster_sqlmesh/asset.py +++ b/dagster_sqlmesh/asset.py @@ -4,6 +4,7 @@ from dagster import AssetsDefinition, RetryPolicy, multi_asset from sqlmesh import Context +from dagster_sqlmesh.config import SQLMeshContextConfig from dagster_sqlmesh.controller import ( ContextCls, ContextFactory, @@ -11,8 +12,6 @@ ) from dagster_sqlmesh.translator import SQLMeshDagsterTranslator -from .config import SQLMeshContextConfig - logger = logging.getLogger(__name__) @@ -34,7 +33,7 @@ def sqlmesh_assets( controller = DagsterSQLMeshController.setup_with_config(config=config, context_factory=context_factory) if not dagster_sqlmesh_translator: dagster_sqlmesh_translator = SQLMeshDagsterTranslator() - conversion = controller.to_asset_outs(environment, dagster_sqlmesh_translator) + conversion = controller.to_asset_outs(environment, translator=dagster_sqlmesh_translator) return multi_asset( name=name, diff --git a/dagster_sqlmesh/config.py b/dagster_sqlmesh/config.py index d088ee9..61431b3 100644 --- a/dagster_sqlmesh/config.py +++ b/dagster_sqlmesh/config.py @@ -30,4 +30,4 @@ class SQLMeshContextConfig(Config): def sqlmesh_config(self) -> MeshConfig | None: if self.config_override: return MeshConfig.parse_obj(self.config_override) - return None + return None \ No newline at end of file diff --git a/dagster_sqlmesh/controller/base.py b/dagster_sqlmesh/controller/base.py index 78e7905..51e671b 100644 --- a/dagster_sqlmesh/controller/base.py +++ b/dagster_sqlmesh/controller/base.py @@ -13,8 +13,8 @@ from sqlmesh.utils.dag import DAG from sqlmesh.utils.date import TimeLike -from ..config import SQLMeshContextConfig -from ..console import ( +from dagster_sqlmesh.config import SQLMeshContextConfig +from dagster_sqlmesh.console import ( ConsoleEvent, ConsoleEventHandler, ConsoleException, @@ -22,7 +22,7 @@ Plan, SnapshotCategorizer, ) -from ..events import ConsoleGenerator +from dagster_sqlmesh.events import ConsoleGenerator logger = logging.getLogger(__name__) diff --git a/dagster_sqlmesh/controller/dagster.py b/dagster_sqlmesh/controller/dagster.py index 6978be8..bc858fc 100644 --- a/dagster_sqlmesh/controller/dagster.py +++ b/dagster_sqlmesh/controller/dagster.py @@ -1,13 +1,14 @@ # pyright: reportPrivateImportUsage=false import logging +from inspect import signature from dagster import AssetDep, AssetKey, AssetOut from dagster._core.definitions.asset_dep import CoercibleToAssetDep -from ..translator import SQLMeshDagsterTranslator -from ..types import SQLMeshModelDep, SQLMeshMultiAssetOptions -from ..utils import sqlmesh_model_name_to_key -from .base import ContextCls, SQLMeshController +from dagster_sqlmesh.controller.base import ContextCls, SQLMeshController +from dagster_sqlmesh.translator import SQLMeshDagsterTranslator +from dagster_sqlmesh.types import SQLMeshModelDep, SQLMeshMultiAssetOptions +from dagster_sqlmesh.utils import get_asset_key_str logger = logging.getLogger(__name__) @@ -16,7 +17,7 @@ class DagsterSQLMeshController(SQLMeshController[ContextCls]): """An extension of the sqlmesh controller specifically for dagster use""" def to_asset_outs( - self, environment: str, translator: SQLMeshDagsterTranslator + self, environment: str, translator: SQLMeshDagsterTranslator, ) -> SQLMeshMultiAssetOptions: with self.instance(environment, "to_asset_outs") as instance: context = instance.context @@ -24,10 +25,7 @@ def to_asset_outs( depsMap: dict[str, CoercibleToAssetDep] = {} for model, deps in instance.non_external_models_dag(): - asset_key = translator.get_asset_key_from_model( - context, - model, - ) + asset_key = translator.get_asset_key(context=context, fqn=model.fqn) model_deps = [ SQLMeshModelDep(fqn=dep, model=context.get_model(dep)) for dep in deps @@ -38,18 +36,27 @@ def to_asset_outs( for dep in model_deps: if dep.model: internal_asset_deps.add( - translator.get_asset_key_from_model(context, dep.model) + translator.get_asset_key(context, dep.model.fqn) ) else: - table = translator.get_fqn_to_table(context, dep.fqn) - key = translator.get_asset_key_fqn(context, dep.fqn) + table = get_asset_key_str(dep.fqn) + key = translator.get_asset_key(context, dep.fqn) internal_asset_deps.add(key) # create an external dep - depsMap[table.name] = AssetDep(key) - model_key = sqlmesh_model_name_to_key(model.name) - output.outs[model_key] = AssetOut( - key=asset_key, tags=asset_tags, is_required=False - ) + depsMap[table] = AssetDep(key) + model_key = get_asset_key_str(model.fqn) + # If current Dagster supports "kinds", add labels for Dagster UI + if "kinds" in signature(AssetOut).parameters: + output.outs[model_key] = AssetOut( + key=asset_key, tags=asset_tags, is_required=False, + group_name=translator.get_group_name(context, model), + kinds={"sqlmesh", translator._get_context_dialect(context).lower()} + ) + else: + output.outs[model_key] = AssetOut( + key=asset_key, tags=asset_tags, is_required=False, + group_name=translator.get_group_name(context, model) + ) output.internal_asset_deps[model_key] = internal_asset_deps output.deps = list(depsMap.values()) diff --git a/dagster_sqlmesh/resource.py b/dagster_sqlmesh/resource.py index de456ff..99609e7 100644 --- a/dagster_sqlmesh/resource.py +++ b/dagster_sqlmesh/resource.py @@ -15,17 +15,16 @@ from sqlmesh.utils.date import TimeLike from sqlmesh.utils.errors import SQLMeshError +from dagster_sqlmesh import console +from dagster_sqlmesh.config import SQLMeshContextConfig +from dagster_sqlmesh.controller import PlanOptions, RunOptions from dagster_sqlmesh.controller.base import ( DEFAULT_CONTEXT_FACTORY, ContextCls, ContextFactory, ) - -from . import console -from .config import SQLMeshContextConfig -from .controller import PlanOptions, RunOptions -from .controller.dagster import DagsterSQLMeshController -from .utils import sqlmesh_model_name_to_key +from dagster_sqlmesh.controller.dagster import DagsterSQLMeshController +from dagster_sqlmesh.utils import get_asset_key_str class MaterializationTracker: @@ -147,7 +146,7 @@ def __init__( self._prefix = prefix self._context = context self._logger = context.log - self._tracker = MaterializationTracker(dag.sorted[:], self._logger) + self._tracker = MaterializationTracker(sorted_dag=dag.sorted[:], logger=self._logger) self._stage = "plan" self._errors: list[Exception] = [] self._is_testing = is_testing @@ -173,7 +172,8 @@ def notify_success( # We allow selecting models. That value is mapped to models_map. # If the model is not in models_map, we can skip any notification if model: - output_key = sqlmesh_model_name_to_key(model.name) + # Passing model.fqn to get internal unique asset key + output_key = get_asset_key_str(model.fqn) if not self._is_testing: # Stupidly dagster when testing cannot use the following # method so we must specifically skip this when testing @@ -227,7 +227,7 @@ def report_event(self, event: console.ConsoleEvent) -> None: log_context.info( "Snapshot progress update", { - "asset_key": sqlmesh_model_name_to_key(snapshot.model.name), + "asset_key": get_asset_key_str(snapshot.model.name), "progress": f"{done}/{expected}", "duration_ms": duration_ms, }, @@ -327,7 +327,10 @@ def run( logger = context.log - controller = self.get_controller(context_factory, logger) + controller = self.get_controller( + context_factory=context_factory, + log_override=logger + ) with controller.instance(environment) as mesh: dag = mesh.models_dag() @@ -338,7 +341,10 @@ def run( [model.fqn for model, _ in mesh.non_external_models_dag()] ) selected_models_set, models_map, select_models = ( - self._get_selected_models_from_context(context, models) + self._get_selected_models_from_context( + context=context, + models=models + ) ) if all_available_models == selected_models_set or select_models is None: @@ -351,7 +357,8 @@ def run( logger.info(f"selected models: {select_models}") event_handler = DagsterSQLMeshEventHandler( - context, models_map, dag, "sqlmesh: ", is_testing=self.is_testing + context=context, models_map=models_map, dag=dag, + prefix="sqlmesh: ", is_testing=self.is_testing ) try: @@ -397,7 +404,7 @@ def _get_selected_models_from_context( select_models: list[str] = [] models_map = {} for key, model in models.items(): - if sqlmesh_model_name_to_key(model.name) in selected_output_names: + if get_asset_key_str(model.fqn) in selected_output_names: models_map[key] = model select_models.append(model.name) return ( @@ -414,5 +421,5 @@ def get_controller( return DagsterSQLMeshController.setup_with_config( config=self.config, context_factory=context_factory, - log_override=log_override, + log_override=log_override ) diff --git a/dagster_sqlmesh/test_sqlmesh_context.py b/dagster_sqlmesh/test_sqlmesh_context.py index 075a544..75858a9 100644 --- a/dagster_sqlmesh/test_sqlmesh_context.py +++ b/dagster_sqlmesh/test_sqlmesh_context.py @@ -2,7 +2,7 @@ import polars -from .testing import SQLMeshTestContext +from dagster_sqlmesh.testing import SQLMeshTestContext logger = logging.getLogger(__name__) diff --git a/dagster_sqlmesh/translator.py b/dagster_sqlmesh/translator.py index e3791ba..2da090d 100644 --- a/dagster_sqlmesh/translator.py +++ b/dagster_sqlmesh/translator.py @@ -1,5 +1,5 @@ +from collections.abc import Sequence -import sqlglot from dagster import AssetKey from sqlglot import exp from sqlmesh.core.context import Context @@ -9,19 +9,20 @@ class SQLMeshDagsterTranslator: """Translates sqlmesh objects for dagster""" - def get_asset_key_from_model(self, context: Context, model: Model) -> AssetKey: + def get_asset_key(self, context: Context, fqn: str) -> AssetKey: """Given the sqlmesh context and a model return the asset key""" - return AssetKey(model.view_name) - - def get_asset_key_fqn(self, context: Context, fqn: str) -> AssetKey: - """Given the sqlmesh context and a fqn of a model return an asset key""" - table = self.get_fqn_to_table(context, fqn) - return AssetKey(table.name) - - def get_fqn_to_table(self, context: Context, fqn: str) -> exp.Table: - """Given the sqlmesh context and a fqn return the table""" - dialect = self._get_context_dialect(context) - return sqlglot.to_table(fqn, dialect=dialect) + path = self.get_asset_key_name(fqn) + return AssetKey(path) + + def get_asset_key_name(self, fqn: str) -> Sequence[str]: + table = exp.to_table(fqn) + asset_key_name = [table.catalog, table.db, table.name] + + return asset_key_name + + def get_group_name(self, context: Context, model: Model) -> str: + path = self.get_asset_key_name(model.fqn) + return path[-2] def _get_context_dialect(self, context: Context) -> str: return context.engine_adapter.dialect diff --git a/dagster_sqlmesh/utils.py b/dagster_sqlmesh/utils.py index 8a08b2a..96b2dba 100644 --- a/dagster_sqlmesh/utils.py +++ b/dagster_sqlmesh/utils.py @@ -1,13 +1,15 @@ +from sqlglot import exp from sqlmesh.core.snapshot import SnapshotId -def sqlmesh_model_name_to_key(name: str) -> str: - return name.replace(".", "_dot__") - - -def key_to_sqlmesh_model_name(key: str) -> str: - return key.replace("_dot__", ".") - +def get_asset_key_str(fqn: str) -> str: + # This is an internal identifier used to map outputs and dependencies + # it will not affect the existing AssetKeys + # Only alphanumeric characters and underscores + table = exp.to_table(fqn) + asset_key_name = [table.catalog, table.db, table.name] + + return "sqlmesh__" + "_".join(asset_key_name) def snapshot_id_to_model_name(snapshot_id: SnapshotId) -> str: """Convert a SnapshotId to its model name. diff --git a/sample/dagster_project/definitions.py b/sample/dagster_project/definitions.py index 345de54..8719a9e 100644 --- a/sample/dagster_project/definitions.py +++ b/sample/dagster_project/definitions.py @@ -21,7 +21,7 @@ sqlmesh_config = SQLMeshContextConfig(path=SQLMESH_PROJECT_PATH, gateway="local") -@asset +@asset(key=["db", "sources", "reset_asset"]) def reset_asset() -> MaterializeResult: """An asset used for testing this entire workflow. If the duckdb database is found, this will delete it. This allows us to continously test this dag if @@ -34,7 +34,7 @@ def reset_asset() -> MaterializeResult: return MaterializeResult(metadata={"deleted": deleted}) -@asset(deps=[reset_asset]) +@asset(deps=[reset_asset], key=["db", "sources", "test_source"]) def test_source() -> pl.DataFrame: """Sets up the `test_source` table in duckdb that one of the sample sqlmesh models depends on"""