Skip to content

feat: Translator for external deps, minor improvement to AssetOut #41

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
May 30, 2025
Merged
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ from dagster import (
AssetExecutionContext,
Definitions,
)
from dagster_sqlmesh import sqlmesh_assets, SQLMeshContextConfig, SQLMeshResource
from dagster_sqlmesh import sqlmesh_assets, SQLMeshContextConfig, SQLMeshResource, SQLMeshDagsterTranslator

sqlmesh_config = SQLMeshContextConfig(path="/home/foo/sqlmesh_project", gateway="name-of-your-gateway")

@sqlmesh_assets(environment="dev", config=sqlmesh_config)
@sqlmesh_assets(environment="dev", config=sqlmesh_config, translator=SQLMeshDagsterTranslator())
def sqlmesh_project(context: AssetExecutionContext, sqlmesh: SQLMeshResource):
yield from sqlmesh.run(context)

Expand Down
5 changes: 2 additions & 3 deletions dagster_sqlmesh/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@
from dagster import AssetsDefinition, RetryPolicy, multi_asset
from sqlmesh import Context

from dagster_sqlmesh.config import SQLMeshContextConfig
from dagster_sqlmesh.controller import (
ContextCls,
ContextFactory,
DagsterSQLMeshController,
)
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator

from .config import SQLMeshContextConfig

logger = logging.getLogger(__name__)


Expand All @@ -34,7 +33,7 @@ def sqlmesh_assets(
controller = DagsterSQLMeshController.setup_with_config(config=config, context_factory=context_factory)
if not dagster_sqlmesh_translator:
dagster_sqlmesh_translator = SQLMeshDagsterTranslator()
conversion = controller.to_asset_outs(environment, dagster_sqlmesh_translator)
conversion = controller.to_asset_outs(environment, translator=dagster_sqlmesh_translator)

return multi_asset(
name=name,
Expand Down
2 changes: 1 addition & 1 deletion dagster_sqlmesh/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ class SQLMeshContextConfig(Config):
def sqlmesh_config(self) -> MeshConfig | None:
if self.config_override:
return MeshConfig.parse_obj(self.config_override)
return None
return None
6 changes: 3 additions & 3 deletions dagster_sqlmesh/controller/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@
from sqlmesh.utils.dag import DAG
from sqlmesh.utils.date import TimeLike

from ..config import SQLMeshContextConfig
from ..console import (
from dagster_sqlmesh.config import SQLMeshContextConfig
from dagster_sqlmesh.console import (
ConsoleEvent,
ConsoleEventHandler,
ConsoleException,
EventConsole,
Plan,
SnapshotCategorizer,
)
from ..events import ConsoleGenerator
from dagster_sqlmesh.events import ConsoleGenerator

logger = logging.getLogger(__name__)

Expand Down
41 changes: 24 additions & 17 deletions dagster_sqlmesh/controller/dagster.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# pyright: reportPrivateImportUsage=false
import logging
from inspect import signature

from dagster import AssetDep, AssetKey, AssetOut
from dagster._core.definitions.asset_dep import CoercibleToAssetDep

from ..translator import SQLMeshDagsterTranslator
from ..types import SQLMeshModelDep, SQLMeshMultiAssetOptions
from ..utils import sqlmesh_model_name_to_key
from .base import ContextCls, SQLMeshController
from dagster_sqlmesh.controller.base import ContextCls, SQLMeshController
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator
from dagster_sqlmesh.types import SQLMeshModelDep, SQLMeshMultiAssetOptions
from dagster_sqlmesh.utils import get_asset_key_str

logger = logging.getLogger(__name__)

Expand All @@ -16,18 +17,15 @@ class DagsterSQLMeshController(SQLMeshController[ContextCls]):
"""An extension of the sqlmesh controller specifically for dagster use"""

def to_asset_outs(
self, environment: str, translator: SQLMeshDagsterTranslator
self, environment: str, translator: SQLMeshDagsterTranslator,
) -> SQLMeshMultiAssetOptions:
with self.instance(environment, "to_asset_outs") as instance:
context = instance.context
output = SQLMeshMultiAssetOptions()
depsMap: dict[str, CoercibleToAssetDep] = {}

for model, deps in instance.non_external_models_dag():
asset_key = translator.get_asset_key_from_model(
context,
model,
)
asset_key = translator.get_asset_key(context=context, fqn=model.fqn)
model_deps = [
SQLMeshModelDep(fqn=dep, model=context.get_model(dep))
for dep in deps
Expand All @@ -38,18 +36,27 @@ def to_asset_outs(
for dep in model_deps:
if dep.model:
internal_asset_deps.add(
translator.get_asset_key_from_model(context, dep.model)
translator.get_asset_key(context, dep.model.fqn)
)
else:
table = translator.get_fqn_to_table(context, dep.fqn)
key = translator.get_asset_key_fqn(context, dep.fqn)
table = get_asset_key_str(dep.fqn)
key = translator.get_asset_key(context, dep.fqn)
internal_asset_deps.add(key)
# create an external dep
depsMap[table.name] = AssetDep(key)
model_key = sqlmesh_model_name_to_key(model.name)
output.outs[model_key] = AssetOut(
key=asset_key, tags=asset_tags, is_required=False
)
depsMap[table] = AssetDep(key)
model_key = get_asset_key_str(model.fqn)
# If current Dagster supports "kinds", add labels for Dagster UI
if "kinds" in signature(AssetOut).parameters:
output.outs[model_key] = AssetOut(
key=asset_key, tags=asset_tags, is_required=False,
group_name=translator.get_group_name(context, model),
kinds={"sqlmesh", translator._get_context_dialect(context).lower()}
)
else:
output.outs[model_key] = AssetOut(
key=asset_key, tags=asset_tags, is_required=False,
group_name=translator.get_group_name(context, model)
)
output.internal_asset_deps[model_key] = internal_asset_deps

output.deps = list(depsMap.values())
Expand Down
35 changes: 21 additions & 14 deletions dagster_sqlmesh/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,16 @@
from sqlmesh.utils.date import TimeLike
from sqlmesh.utils.errors import SQLMeshError

from dagster_sqlmesh import console
from dagster_sqlmesh.config import SQLMeshContextConfig
from dagster_sqlmesh.controller import PlanOptions, RunOptions
from dagster_sqlmesh.controller.base import (
DEFAULT_CONTEXT_FACTORY,
ContextCls,
ContextFactory,
)

from . import console
from .config import SQLMeshContextConfig
from .controller import PlanOptions, RunOptions
from .controller.dagster import DagsterSQLMeshController
from .utils import sqlmesh_model_name_to_key
from dagster_sqlmesh.controller.dagster import DagsterSQLMeshController
from dagster_sqlmesh.utils import get_asset_key_str


class MaterializationTracker:
Expand Down Expand Up @@ -147,7 +146,7 @@ def __init__(
self._prefix = prefix
self._context = context
self._logger = context.log
self._tracker = MaterializationTracker(dag.sorted[:], self._logger)
self._tracker = MaterializationTracker(sorted_dag=dag.sorted[:], logger=self._logger)
self._stage = "plan"
self._errors: list[Exception] = []
self._is_testing = is_testing
Expand All @@ -173,7 +172,8 @@ def notify_success(
# We allow selecting models. That value is mapped to models_map.
# If the model is not in models_map, we can skip any notification
if model:
output_key = sqlmesh_model_name_to_key(model.name)
# Passing model.fqn to get internal unique asset key
output_key = get_asset_key_str(model.fqn)
if not self._is_testing:
# Stupidly dagster when testing cannot use the following
# method so we must specifically skip this when testing
Expand Down Expand Up @@ -227,7 +227,7 @@ def report_event(self, event: console.ConsoleEvent) -> None:
log_context.info(
"Snapshot progress update",
{
"asset_key": sqlmesh_model_name_to_key(snapshot.model.name),
"asset_key": get_asset_key_str(snapshot.model.name),
"progress": f"{done}/{expected}",
"duration_ms": duration_ms,
},
Expand Down Expand Up @@ -327,7 +327,10 @@ def run(

logger = context.log

controller = self.get_controller(context_factory, logger)
controller = self.get_controller(
context_factory=context_factory,
log_override=logger
)

with controller.instance(environment) as mesh:
dag = mesh.models_dag()
Expand All @@ -338,7 +341,10 @@ def run(
[model.fqn for model, _ in mesh.non_external_models_dag()]
)
selected_models_set, models_map, select_models = (
self._get_selected_models_from_context(context, models)
self._get_selected_models_from_context(
context=context,
models=models
)
)

if all_available_models == selected_models_set or select_models is None:
Expand All @@ -351,7 +357,8 @@ def run(
logger.info(f"selected models: {select_models}")

event_handler = DagsterSQLMeshEventHandler(
context, models_map, dag, "sqlmesh: ", is_testing=self.is_testing
context=context, models_map=models_map, dag=dag,
prefix="sqlmesh: ", is_testing=self.is_testing
)

try:
Expand Down Expand Up @@ -397,7 +404,7 @@ def _get_selected_models_from_context(
select_models: list[str] = []
models_map = {}
for key, model in models.items():
if sqlmesh_model_name_to_key(model.name) in selected_output_names:
if get_asset_key_str(model.fqn) in selected_output_names:
models_map[key] = model
select_models.append(model.name)
return (
Expand All @@ -414,5 +421,5 @@ def get_controller(
return DagsterSQLMeshController.setup_with_config(
config=self.config,
context_factory=context_factory,
log_override=log_override,
log_override=log_override
)
2 changes: 1 addition & 1 deletion dagster_sqlmesh/test_sqlmesh_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import polars

from .testing import SQLMeshTestContext
from dagster_sqlmesh.testing import SQLMeshTestContext

logger = logging.getLogger(__name__)

Expand Down
27 changes: 14 additions & 13 deletions dagster_sqlmesh/translator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Sequence

import sqlglot
from dagster import AssetKey
from sqlglot import exp
from sqlmesh.core.context import Context
Expand All @@ -9,19 +9,20 @@
class SQLMeshDagsterTranslator:
"""Translates sqlmesh objects for dagster"""

def get_asset_key_from_model(self, context: Context, model: Model) -> AssetKey:
def get_asset_key(self, context: Context, fqn: str) -> AssetKey:
"""Given the sqlmesh context and a model return the asset key"""
return AssetKey(model.view_name)

def get_asset_key_fqn(self, context: Context, fqn: str) -> AssetKey:
"""Given the sqlmesh context and a fqn of a model return an asset key"""
table = self.get_fqn_to_table(context, fqn)
return AssetKey(table.name)

def get_fqn_to_table(self, context: Context, fqn: str) -> exp.Table:
"""Given the sqlmesh context and a fqn return the table"""
dialect = self._get_context_dialect(context)
return sqlglot.to_table(fqn, dialect=dialect)
path = self.get_asset_key_name(fqn)
return AssetKey(path)

def get_asset_key_name(self, fqn: str) -> Sequence[str]:
table = exp.to_table(fqn)
asset_key_name = [table.catalog, table.db, table.name]

return asset_key_name

def get_group_name(self, context: Context, model: Model) -> str:
path = self.get_asset_key_name(model.fqn)
return path[-2]

def _get_context_dialect(self, context: Context) -> str:
return context.engine_adapter.dialect
Expand Down
16 changes: 9 additions & 7 deletions dagster_sqlmesh/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from sqlglot import exp
from sqlmesh.core.snapshot import SnapshotId


def sqlmesh_model_name_to_key(name: str) -> str:
return name.replace(".", "_dot__")


def key_to_sqlmesh_model_name(key: str) -> str:
return key.replace("_dot__", ".")

def get_asset_key_str(fqn: str) -> str:
# This is an internal identifier used to map outputs and dependencies
# it will not affect the existing AssetKeys
# Only alphanumeric characters and underscores
table = exp.to_table(fqn)
asset_key_name = [table.catalog, table.db, table.name]

return "sqlmesh__" + "_".join(asset_key_name)

def snapshot_id_to_model_name(snapshot_id: SnapshotId) -> str:
"""Convert a SnapshotId to its model name.
Expand Down
4 changes: 2 additions & 2 deletions sample/dagster_project/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
sqlmesh_config = SQLMeshContextConfig(path=SQLMESH_PROJECT_PATH, gateway="local")


@asset
@asset(key=["db", "sources", "reset_asset"])
def reset_asset() -> MaterializeResult:
"""An asset used for testing this entire workflow. If the duckdb database is
found, this will delete it. This allows us to continously test this dag if
Expand All @@ -34,7 +34,7 @@ def reset_asset() -> MaterializeResult:
return MaterializeResult(metadata={"deleted": deleted})


@asset(deps=[reset_asset])
@asset(deps=[reset_asset], key=["db", "sources", "test_source"])
def test_source() -> pl.DataFrame:
"""Sets up the `test_source` table in duckdb that one of the sample sqlmesh
models depends on"""
Expand Down