Skip to content

feat: Translator for external deps, minor improvement to AssetOut #41

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
May 30, 2025
Merged
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ from dagster import (
AssetExecutionContext,
Definitions,
)
from dagster_sqlmesh import sqlmesh_assets, SQLMeshContextConfig, SQLMeshResource
from dagster_sqlmesh import sqlmesh_assets, SQLMeshContextConfig, SQLMeshResource, SQLMeshDagsterTranslator

sqlmesh_config = SQLMeshContextConfig(path="/home/foo/sqlmesh_project", gateway="name-of-your-gateway")

@sqlmesh_assets(environment="dev", config=sqlmesh_config)
def sqlmesh_project(context: AssetExecutionContext, sqlmesh: SQLMeshResource):
yield from sqlmesh.run(context)
yield from sqlmesh.run(context, translator=SQLMeshDagsterTranslator())

defs = Definitions(
assets=[sqlmesh_project],
Expand Down
5 changes: 3 additions & 2 deletions dagster_sqlmesh/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ def sqlmesh_assets(
) -> t.Callable[[t.Callable[..., t.Any]], AssetsDefinition]:
controller = DagsterSQLMeshController.setup_with_config(config=config, context_factory=context_factory)
if not dagster_sqlmesh_translator:
dagster_sqlmesh_translator = SQLMeshDagsterTranslator()
conversion = controller.to_asset_outs(environment, dagster_sqlmesh_translator)
dagster_sqlmesh_translator = dagster_sqlmesh_translator

conversion = controller.to_asset_outs(environment, translator=dagster_sqlmesh_translator)

return multi_asset(
name=name,
Expand Down
2 changes: 1 addition & 1 deletion dagster_sqlmesh/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ class SQLMeshContextConfig(Config):
def sqlmesh_config(self) -> MeshConfig | None:
if self.config_override:
return MeshConfig.parse_obj(self.config_override)
return None
return None
10 changes: 9 additions & 1 deletion dagster_sqlmesh/controller/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
SnapshotCategorizer,
)
from ..events import ConsoleGenerator
from ..translator import SQLMeshDagsterTranslator

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -416,6 +417,7 @@ class SQLMeshController(t.Generic[ContextCls]):
config: SQLMeshContextConfig
console: EventConsole
logger: logging.Logger
translator: SQLMeshDagsterTranslator

@classmethod
def setup(
Expand All @@ -425,10 +427,12 @@ def setup(
context_factory: ContextFactory[ContextCls],
gateway: str = "local",
log_override: logging.Logger | None = None,
translator_override: SQLMeshDagsterTranslator | None = None,
) -> t.Self:
return cls.setup_with_config(
config=SQLMeshContextConfig(path=path, gateway=gateway),
log_override=log_override,
translator_override=translator_override,
context_factory=context_factory,
)

Expand All @@ -439,13 +443,15 @@ def setup_with_config(
config: SQLMeshContextConfig,
context_factory: ContextFactory[ContextCls] = DEFAULT_CONTEXT_FACTORY,
log_override: logging.Logger | None = None,
translator_override: SQLMeshDagsterTranslator | None = None,
) -> t.Self:
console = EventConsole(log_override=log_override) # type: ignore
controller = cls(
console=console,
config=config,
log_override=log_override,
context_factory=context_factory,
translator_override=translator_override
)
return controller

Expand All @@ -455,11 +461,13 @@ def __init__(
console: EventConsole,
context_factory: ContextFactory[ContextCls],
log_override: logging.Logger | None = None,
translator_override: SQLMeshDagsterTranslator | None = None,
) -> None:
self.config = config
self.console = console
self.logger = log_override or logger
self._context_factory = context_factory
self.translator = translator_override or SQLMeshDagsterTranslator()
self._context_open = False

def set_logger(self, logger: logging.Logger) -> None:
Expand All @@ -481,7 +489,7 @@ def _create_context(self) -> ContextCls:
options["config"] = self.config.sqlmesh_config
set_console(self.console)
return self._context_factory(**options)

@contextmanager
def instance(
self, environment: str, component: str = "unknown"
Expand Down
35 changes: 21 additions & 14 deletions dagster_sqlmesh/controller/dagster.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# pyright: reportPrivateImportUsage=false
import logging
from inspect import signature

from dagster import AssetDep, AssetKey, AssetOut
from dagster._core.definitions.asset_dep import CoercibleToAssetDep

from ..translator import SQLMeshDagsterTranslator
from ..types import SQLMeshModelDep, SQLMeshMultiAssetOptions
from ..utils import sqlmesh_model_name_to_key
from .base import ContextCls, SQLMeshController

logger = logging.getLogger(__name__)
Expand All @@ -16,18 +16,16 @@ class DagsterSQLMeshController(SQLMeshController[ContextCls]):
"""An extension of the sqlmesh controller specifically for dagster use"""

def to_asset_outs(
self, environment: str, translator: SQLMeshDagsterTranslator
self, environment: str, translator: SQLMeshDagsterTranslator | None = None,
) -> SQLMeshMultiAssetOptions:
with self.instance(environment, "to_asset_outs") as instance:
translator = translator or SQLMeshDagsterTranslator()
context = instance.context
output = SQLMeshMultiAssetOptions()
depsMap: dict[str, CoercibleToAssetDep] = {}

for model, deps in instance.non_external_models_dag():
asset_key = translator.get_asset_key_from_model(
context,
model,
)
asset_key = translator.get_asset_key(context=context, fqn=model.fqn)
model_deps = [
SQLMeshModelDep(fqn=dep, model=context.get_model(dep))
for dep in deps
Expand All @@ -38,18 +36,27 @@ def to_asset_outs(
for dep in model_deps:
if dep.model:
internal_asset_deps.add(
translator.get_asset_key_from_model(context, dep.model)
translator.get_asset_key(context, dep.model.fqn)
)
else:
table = translator.get_fqn_to_table(context, dep.fqn)
key = translator.get_asset_key_fqn(context, dep.fqn)
table = translator.get_asset_key_str(dep.fqn)
key = translator.get_asset_key(context, dep.fqn)
internal_asset_deps.add(key)
# create an external dep
depsMap[table.name] = AssetDep(key)
model_key = sqlmesh_model_name_to_key(model.name)
output.outs[model_key] = AssetOut(
key=asset_key, tags=asset_tags, is_required=False
)
depsMap[table] = AssetDep(key)
model_key = translator.get_asset_key_str(model.fqn)
# If current Dagster supports "kinds", add labels for Dagster UI
if "kinds" in signature(AssetOut).parameters:
output.outs[model_key] = AssetOut(
key=asset_key, tags=asset_tags, is_required=False,
group_name=translator.get_group_name(context, model),
kinds={"sqlmesh", translator._get_context_dialect(context).lower()}
)
else:
output.outs[model_key] = AssetOut(
key=asset_key, tags=asset_tags, is_required=False,
group_name=translator.get_group_name(context, model)
)
output.internal_asset_deps[model_key] = internal_asset_deps

output.deps = list(depsMap.values())
Expand Down
29 changes: 19 additions & 10 deletions dagster_sqlmesh/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,17 @@
from .config import SQLMeshContextConfig
from .controller import PlanOptions, RunOptions
from .controller.dagster import DagsterSQLMeshController
from .utils import sqlmesh_model_name_to_key
from .translator import SQLMeshDagsterTranslator


class MaterializationTracker:
"""Tracks sqlmesh materializations and notifies dagster in the correct
order. This is necessary because sqlmesh may skip some materializations that
have no changes and those will be reported as completed out of order."""

def __init__(self, sorted_dag: list[str], logger: logging.Logger) -> None:
def __init__(self, sorted_dag: list[str], logger: logging.Logger, translator: SQLMeshDagsterTranslator) -> None:
self.logger = logger
self.translator = translator
self._batches: dict[Snapshot, int] = {}
self._count: dict[Snapshot, int] = {}
self._complete_update_status: dict[str, bool] = {}
Expand Down Expand Up @@ -141,13 +142,15 @@ def __init__(
models_map: dict[str, Model],
dag: DAG[t.Any],
prefix: str,
translator: SQLMeshDagsterTranslator,
is_testing: bool = False,
) -> None:
self._models_map = models_map
self._prefix = prefix
self._context = context
self._logger = context.log
self._tracker = MaterializationTracker(dag.sorted[:], self._logger)
self.translator = translator
self._tracker = MaterializationTracker(sorted_dag=dag.sorted[:], logger=self._logger, translator=self.translator)
self._stage = "plan"
self._errors: list[Exception] = []
self._is_testing = is_testing
Expand All @@ -173,7 +176,8 @@ def notify_success(
# We allow selecting models. That value is mapped to models_map.
# If the model is not in models_map, we can skip any notification
if model:
output_key = sqlmesh_model_name_to_key(model.name)
# Passing model.fqn to translator
output_key = self.translator.get_asset_key_str(model.fqn)
if not self._is_testing:
# Dagster cannot use the following method when running under test,
# so we must explicitly skip it when testing
Expand Down Expand Up @@ -227,7 +231,7 @@ def report_event(self, event: console.ConsoleEvent) -> None:
log_context.info(
"Snapshot progress update",
{
"asset_key": sqlmesh_model_name_to_key(snapshot.model.name),
"asset_key": self.translator.get_asset_key_str(snapshot.model.name),
"progress": f"{done}/{expected}",
"duration_ms": duration_ms,
},
Expand Down Expand Up @@ -311,6 +315,7 @@ def run(
context: AssetExecutionContext,
*,
context_factory: ContextFactory[ContextCls] = DEFAULT_CONTEXT_FACTORY,
translator: SQLMeshDagsterTranslator = SQLMeshDagsterTranslator(),
environment: str = "dev",
start: TimeLike | None = None,
end: TimeLike | None = None,
Expand All @@ -327,7 +332,7 @@ def run(

logger = context.log

controller = self.get_controller(context_factory, logger)
controller = self.get_controller(context_factory, logger, translator)

with controller.instance(environment) as mesh:
dag = mesh.models_dag()
Expand All @@ -338,7 +343,7 @@ def run(
[model.fqn for model, _ in mesh.non_external_models_dag()]
)
selected_models_set, models_map, select_models = (
self._get_selected_models_from_context(context, models)
self._get_selected_models_from_context(context, models, translator)
)

if all_available_models == selected_models_set or select_models is None:
Expand All @@ -351,7 +356,8 @@ def run(
logger.info(f"selected models: {select_models}")

event_handler = DagsterSQLMeshEventHandler(
context, models_map, dag, "sqlmesh: ", is_testing=self.is_testing
context=context, models_map=models_map, dag=dag,
prefix="sqlmesh: ", translator=translator, is_testing=self.is_testing
)

try:
Expand Down Expand Up @@ -380,7 +386,8 @@ def run(
yield from event_handler.notify_success(mesh.context)

def _get_selected_models_from_context(
self, context: AssetExecutionContext, models: MappingProxyType[str, Model]
self, context: AssetExecutionContext, models: MappingProxyType[str, Model],
translator: SQLMeshDagsterTranslator
) -> tuple[set[str], dict[str, Model], list[str] | None]:
models_map = models.copy()
try:
Expand All @@ -397,7 +404,7 @@ def _get_selected_models_from_context(
select_models: list[str] = []
models_map = {}
for key, model in models.items():
if sqlmesh_model_name_to_key(model.name) in selected_output_names:
if translator.get_asset_key_str(model.fqn) in selected_output_names:
models_map[key] = model
select_models.append(model.name)
return (
Expand All @@ -410,9 +417,11 @@ def get_controller(
self,
context_factory: ContextFactory[ContextCls],
log_override: logging.Logger | None = None,
translator: SQLMeshDagsterTranslator | None = None
) -> DagsterSQLMeshController[ContextCls]:
return DagsterSQLMeshController.setup_with_config(
config=self.config,
context_factory=context_factory,
log_override=log_override,
translator_override=translator
)
4 changes: 1 addition & 3 deletions dagster_sqlmesh/test_asset.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from dagster_sqlmesh.asset import SQLMeshDagsterTranslator
from dagster_sqlmesh.conftest import SQLMeshTestContext


def test_sqlmesh_context_to_asset_outs(sample_sqlmesh_test_context: SQLMeshTestContext):
controller = sample_sqlmesh_test_context.create_controller()
translator = SQLMeshDagsterTranslator()
outs = controller.to_asset_outs("dev", translator)
outs = controller.to_asset_outs("dev")
assert len(list(outs.deps)) == 1
assert len(outs.outs) == 10
35 changes: 21 additions & 14 deletions dagster_sqlmesh/translator.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,34 @@
import re
from collections.abc import Sequence

import sqlglot
from dagster import AssetKey
from sqlglot import exp
from sqlmesh.core.context import Context
from sqlmesh.core.model import Model


class SQLMeshDagsterTranslator:
"""Translates sqlmesh objects for dagster"""

def get_asset_key_from_model(self, context: Context, model: Model) -> AssetKey:
def get_asset_key(self, context: Context, fqn: str) -> AssetKey:
"""Given the sqlmesh context and a model return the asset key"""
return AssetKey(model.view_name)

def get_asset_key_fqn(self, context: Context, fqn: str) -> AssetKey:
"""Given the sqlmesh context and a fqn of a model return an asset key"""
table = self.get_fqn_to_table(context, fqn)
return AssetKey(table.name)

def get_fqn_to_table(self, context: Context, fqn: str) -> exp.Table:
"""Given the sqlmesh context and a fqn return the table"""
dialect = self._get_context_dialect(context)
return sqlglot.to_table(fqn, dialect=dialect)
path = self.get_asset_key_name(fqn)
return AssetKey(path)

def get_asset_key_name(self, fqn: str) -> Sequence[str]:
    """Split a fully qualified name into its path components.

    Each maximal run of ASCII letters, digits, underscores or dashes in
    ``fqn`` becomes one component, in order of appearance. Every other
    character (quotes, dots, whitespace, ...) acts as a separator and is
    dropped. An input with no matching characters yields an empty list.
    """
    return [match.group(0) for match in re.finditer(r"[A-Za-z0-9_\-]+", fqn)]

def get_asset_key_str(self, fqn: str) -> str:
    """Return the internal identifier string derived from ``fqn``.

    The components returned by ``get_asset_key_name`` are joined with
    ``__dot__`` and every dash is spelled out as ``__dash__``.
    """
    # This identifier is only used internally to map outputs and
    # dependencies; it does not affect the existing AssetKeys.
    segments = self.get_asset_key_name(fqn)
    # Replacing dashes per segment is equivalent to replacing them after the
    # join, because the "__dot__" separator itself contains no dash.
    return "__dot__".join(
        segment.replace("-", "__dash__") for segment in segments
    )

def get_group_name(self, context: Context, model: Model) -> str:
    """Return the Dagster group name to use for ``model``.

    The group is the second-to-last component of ``model.fqn`` as split by
    ``get_asset_key_name``; for an fqn shaped like ``catalog.schema.table``
    that is the ``schema`` component.

    Args:
        context: The sqlmesh context (unused here; available to overrides).
        model: The sqlmesh model whose ``fqn`` is inspected.

    Returns:
        A group name containing only letters, digits and underscores, or
        ``"default"`` when the fqn has fewer than two components.
    """
    path = self.get_asset_key_name(model.fqn)
    if len(path) < 2:
        # No parent component to group by; fall back to Dagster's default
        # group name instead of raising IndexError on path[-2].
        return "default"
    # Path components may contain dashes, but Dagster group names only
    # accept [A-Za-z0-9_], so dashes are normalised to underscores.
    return path[-2].replace("-", "_")

def _get_context_dialect(self, context: Context) -> str:
    """Return the ``dialect`` attribute of the context's engine adapter."""
    adapter = context.engine_adapter
    return adapter.dialect
Expand Down
8 changes: 0 additions & 8 deletions dagster_sqlmesh/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,6 @@
from sqlmesh.core.snapshot import SnapshotId


def sqlmesh_model_name_to_key(name: str) -> str:
return name.replace(".", "_dot__")


def key_to_sqlmesh_model_name(key: str) -> str:
return key.replace("_dot__", ".")


def snapshot_id_to_model_name(snapshot_id: SnapshotId) -> str:
"""Convert a SnapshotId to its model name.

Expand Down
Loading