Skip to content

Commit ff2801c

Browse files
committed
feat: move translator to config, add support for kinds label and specify group through translator
The SQLMeshDagsterTranslator has been simplified: it now exposes methods that return an AssetKey, the path (name) of that AssetKey, or an equivalent string identifier. These replace utils.sqlmesh_model_name_to_key and utils.key_to_sqlmesh_model_name, which are removed in this commit. The kinds label is now also set so that technology labels appear in the Dagster UI, but only when the installed Dagster version supports it. Groups can now be specified through the translator instead of through a fixed method.
1 parent e95e718 commit ff2801c

File tree

9 files changed

+75
-56
lines changed

9 files changed

+75
-56
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@ from dagster import (
2424
AssetExecutionContext,
2525
Definitions,
2626
)
27-
from dagster_sqlmesh import sqlmesh_assets, SQLMeshContextConfig, SQLMeshResource
27+
from dagster_sqlmesh import sqlmesh_assets, SQLMeshContextConfig, SQLMeshResource, SQLMeshDagsterTranslator
2828

2929
sqlmesh_config = SQLMeshContextConfig(path="/home/foo/sqlmesh_project", gateway="name-of-your-gateway")
3030

3131
@sqlmesh_assets(environment="dev", config=sqlmesh_config)
3232
def sqlmesh_project(context: AssetExecutionContext, sqlmesh: SQLMeshResource):
33-
yield from sqlmesh.run(context)
33+
yield from sqlmesh.run(context, translator=SQLMeshDagsterTranslator())
3434

3535
defs = Definitions(
3636
assets=[sqlmesh_project],

dagster_sqlmesh/asset.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@ def sqlmesh_assets(
2525
# For now we don't set this by default
2626
enabled_subsetting: bool = False,
2727
) -> t.Callable[[t.Callable[..., t.Any]], AssetsDefinition]:
28-
controller = DagsterSQLMeshController.setup_with_config(config)
2928
if not dagster_sqlmesh_translator:
30-
dagster_sqlmesh_translator = SQLMeshDagsterTranslator()
31-
conversion = controller.to_asset_outs(environment, dagster_sqlmesh_translator)
29+
dagster_sqlmesh_translator = dagster_sqlmesh_translator
30+
31+
controller = DagsterSQLMeshController.setup_with_config(config)
32+
conversion = controller.to_asset_outs(environment, translator=dagster_sqlmesh_translator)
3233

3334
return multi_asset(
3435
name=name,

dagster_sqlmesh/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,10 @@ class SQLMeshContextConfig(Config):
2626
gateway: str
2727
config_override: dict[str, Any] | None = Field(default_factory=lambda: None)
2828

29+
2930
@property
3031
def sqlmesh_config(self) -> MeshConfig | None:
3132
if self.config_override:
3233
return MeshConfig.parse_obj(self.config_override)
3334
return None
35+

dagster_sqlmesh/controller/base.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
SnapshotCategorizer,
2525
)
2626
from ..events import ConsoleGenerator
27+
from ..translator import SQLMeshDagsterTranslator
2728

2829
logger = logging.getLogger(__name__)
2930

@@ -400,30 +401,34 @@ class SQLMeshController:
400401
config: SQLMeshContextConfig
401402
console: EventConsole
402403
logger: logging.Logger
404+
translator: SQLMeshDagsterTranslator
403405

404406
@classmethod
405407
def setup(
406408
cls,
407409
path: str,
408410
gateway: str = "local",
409411
log_override: logging.Logger | None = None,
412+
translator_override: SQLMeshDagsterTranslator | None = None,
410413
) -> "SQLMeshController":
411414
return cls.setup_with_config(
412415
config=SQLMeshContextConfig(path=path, gateway=gateway),
413-
log_override=log_override,
416+
log_override=log_override, translator_override=translator_override
414417
)
415418

416419
@classmethod
417420
def setup_with_config(
418421
cls: type[T],
419422
config: SQLMeshContextConfig,
420423
log_override: logging.Logger | None = None,
424+
translator_override: SQLMeshDagsterTranslator | None = None,
421425
) -> T:
422426
console = EventConsole(log_override=log_override) # type: ignore
423427
controller = cls(
424428
console=console,
425429
config=config,
426430
log_override=log_override,
431+
translator_override=translator_override
427432
)
428433
return controller
429434

@@ -432,10 +437,12 @@ def __init__(
432437
config: SQLMeshContextConfig,
433438
console: EventConsole,
434439
log_override: logging.Logger | None = None,
440+
translator_override: SQLMeshDagsterTranslator | None = None,
435441
) -> None:
436442
self.config = config
437443
self.console = console
438444
self.logger = log_override or logger
445+
self.translator = translator_override or SQLMeshDagsterTranslator()
439446
self._context_open = False
440447

441448
def set_logger(self, logger: logging.Logger) -> None:
@@ -457,7 +464,7 @@ def _create_context(self) -> Context:
457464
options["config"] = self.config.sqlmesh_config
458465
set_console(self.console)
459466
return Context(**options)
460-
467+
461468
@contextmanager
462469
def instance(
463470
self, environment: str, component: str = "unknown"

dagster_sqlmesh/controller/dagster.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
# pyright: reportPrivateImportUsage=false
22
import logging
3+
from inspect import signature
34

45
from dagster import AssetDep, AssetKey, AssetOut
56
from dagster._core.definitions.asset_dep import CoercibleToAssetDep
67

78
from ..translator import SQLMeshDagsterTranslator
89
from ..types import SQLMeshModelDep, SQLMeshMultiAssetOptions
9-
from ..utils import sqlmesh_model_name_to_key
1010
from .base import SQLMeshController
1111

1212
logger = logging.getLogger(__name__)
@@ -16,18 +16,16 @@ class DagsterSQLMeshController(SQLMeshController):
1616
"""An extension of the sqlmesh controller specifically for dagster use"""
1717

1818
def to_asset_outs(
19-
self, environment: str, translator: SQLMeshDagsterTranslator
19+
self, environment: str, translator: SQLMeshDagsterTranslator | None = None,
2020
) -> SQLMeshMultiAssetOptions:
2121
with self.instance(environment, "to_asset_outs") as instance:
22+
translator = translator or SQLMeshDagsterTranslator()
2223
context = instance.context
2324
output = SQLMeshMultiAssetOptions()
2425
depsMap: dict[str, CoercibleToAssetDep] = {}
2526

2627
for model, deps in instance.non_external_models_dag():
27-
asset_key = translator.get_asset_key_from_model(
28-
context,
29-
model,
30-
)
28+
asset_key = translator.get_asset_key(context=context, fqn=model.fqn)
3129
model_deps = [
3230
SQLMeshModelDep(fqn=dep, model=context.get_model(dep))
3331
for dep in deps
@@ -38,18 +36,27 @@ def to_asset_outs(
3836
for dep in model_deps:
3937
if dep.model:
4038
internal_asset_deps.add(
41-
translator.get_asset_key_from_model(context, dep.model)
39+
translator.get_asset_key(context, dep.model.fqn)
4240
)
4341
else:
44-
table = translator.get_fqn_to_table(context, dep.fqn)
45-
key = translator.get_asset_key_fqn(context, dep.fqn)
42+
table = translator.get_asset_key_str(dep.fqn)
43+
key = translator.get_asset_key(context, dep.fqn)
4644
internal_asset_deps.add(key)
4745
# create an external dep
48-
depsMap[table.name] = AssetDep(key)
49-
model_key = sqlmesh_model_name_to_key(model.name)
50-
output.outs[model_key] = AssetOut(
51-
key=asset_key, tags=asset_tags, is_required=False
52-
)
46+
depsMap[table] = AssetDep(key)
47+
model_key = translator.get_asset_key_str(model.fqn)
48+
# If current Dagster supports "kinds", add labels for Dagster UI
49+
if "kinds" in signature(AssetOut).parameters:
50+
output.outs[model_key] = AssetOut(
51+
key=asset_key, tags=asset_tags, is_required=False,
52+
group_name=translator.get_group_name(context, model),
53+
kinds={"sqlmesh", translator._get_context_dialect(context).lower()}
54+
)
55+
else:
56+
output.outs[model_key] = AssetOut(
57+
key=asset_key, tags=asset_tags, is_required=False,
58+
group_name=translator.get_group_name(context, model)
59+
)
5360
output.internal_asset_deps[model_key] = internal_asset_deps
5461

5562
output.deps = list(depsMap.values())

dagster_sqlmesh/resource.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,17 @@
1616
from .config import SQLMeshContextConfig
1717
from .controller import PlanOptions, RunOptions
1818
from .controller.dagster import DagsterSQLMeshController
19-
from .utils import sqlmesh_model_name_to_key
19+
from .translator import SQLMeshDagsterTranslator
2020

2121

2222
class MaterializationTracker:
2323
"""Tracks sqlmesh materializations and notifies dagster in the correct
2424
order. This is necessary because sqlmesh may skip some materializations that
2525
have no changes and those will be reported as completed out of order."""
2626

27-
def __init__(self, sorted_dag: list[str], logger: logging.Logger) -> None:
27+
def __init__(self, sorted_dag: list[str], logger: logging.Logger, translator: SQLMeshDagsterTranslator) -> None:
2828
self.logger = logger
29+
self.translator = translator
2930
self._batches: dict[Snapshot, int] = {}
3031
self._count: dict[Snapshot, int] = {}
3132
self._complete_update_status: dict[str, bool] = {}
@@ -114,12 +115,14 @@ def __init__(
114115
models_map: dict[str, Model],
115116
dag: DAG[t.Any],
116117
prefix: str,
118+
translator: SQLMeshDagsterTranslator
117119
) -> None:
118120
self._models_map = models_map
119121
self._prefix = prefix
120122
self._context = context
121123
self._logger = context.log
122-
self._tracker = MaterializationTracker(dag.sorted[:], self._logger)
124+
self.translator = translator
125+
self._tracker = MaterializationTracker(sorted_dag=dag.sorted[:], logger=self._logger, translator=self.translator)
123126
self._stage = "plan"
124127

125128
def process_events(self, event: console.ConsoleEvent) -> None:
@@ -143,7 +146,7 @@ def notify_success(
143146
# We allow selecting models. That value is mapped to models_map.
144147
# If the model is not in models_map, we can skip any notification
145148
if model:
146-
output_key = sqlmesh_model_name_to_key(model.name)
149+
output_key = self.translator.get_asset_key_str(model.name)
147150
asset_key = self._context.asset_key_for_output(output_key)
148151
yield MaterializeResult(
149152
asset_key=asset_key,
@@ -192,7 +195,7 @@ def report_event(self, event: console.ConsoleEvent) -> None:
192195
log_context.info(
193196
"Snapshot progress update",
194197
{
195-
"asset_key": sqlmesh_model_name_to_key(snapshot.model.name),
198+
"asset_key": self.translator.get_asset_key_str(snapshot.model.name),
196199
"progress": f"{done}/{expected}",
197200
"duration_ms": duration_ms,
198201
},
@@ -263,6 +266,7 @@ def run(
263266
self,
264267
context: AssetExecutionContext,
265268
*,
269+
translator: SQLMeshDagsterTranslator | None = None,
266270
environment: str = "dev",
267271
start: TimeLike | None = None,
268272
end: TimeLike | None = None,
@@ -279,7 +283,7 @@ def run(
279283

280284
logger = context.log
281285

282-
controller = self.get_controller(logger)
286+
controller = self.get_controller(logger, translator)
283287

284288
with controller.instance(environment) as mesh:
285289
dag = mesh.models_dag()
@@ -295,7 +299,7 @@ def run(
295299
models_map = {}
296300
for key, model in models.items():
297301
if (
298-
sqlmesh_model_name_to_key(model.name)
302+
self.translator.get_asset_key_str(model.name)
299303
in context.selected_output_names
300304
):
301305
models_map[key] = model
@@ -312,7 +316,8 @@ def run(
312316
logger.info(f"selected models: {select_models}")
313317

314318
event_handler = DagsterSQLMeshEventHandler(
315-
context, models_map, dag, "sqlmesh: "
319+
context=context, models_map=models_map, dag=dag,
320+
prefix="sqlmesh: ", translator=self.translator
316321
)
317322

318323
for event in mesh.plan_and_run(
@@ -330,8 +335,8 @@ def run(
330335
yield from event_handler.notify_success(mesh.context)
331336

332337
def get_controller(
333-
self, log_override: logging.Logger | None = None
338+
self, log_override: logging.Logger | None = None, translator: SQLMeshDagsterTranslator | None = None
334339
) -> DagsterSQLMeshController:
335340
return DagsterSQLMeshController.setup_with_config(
336-
self.config, log_override=log_override
341+
self.config, log_override=log_override, translator_override=translator
337342
)

dagster_sqlmesh/test_asset.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
1-
from dagster_sqlmesh.asset import SQLMeshDagsterTranslator
21
from dagster_sqlmesh.conftest import SQLMeshTestContext
32

43

54
def test_sqlmesh_context_to_asset_outs(sample_sqlmesh_test_context: SQLMeshTestContext):
65
controller = sample_sqlmesh_test_context.create_controller()
7-
translator = SQLMeshDagsterTranslator()
8-
outs = controller.to_asset_outs("dev", translator)
6+
outs = controller.to_asset_outs("dev")
97
assert len(list(outs.deps)) == 1
108
assert len(outs.outs) == 9

dagster_sqlmesh/translator.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,34 @@
1+
import re
2+
from collections.abc import Sequence
13

2-
import sqlglot
34
from dagster import AssetKey
4-
from sqlglot import exp
55
from sqlmesh.core.context import Context
66
from sqlmesh.core.model import Model
77

88

99
class SQLMeshDagsterTranslator:
1010
"""Translates sqlmesh objects for dagster"""
1111

12-
def get_asset_key_from_model(self, context: Context, model: Model) -> AssetKey:
12+
def get_asset_key(self, context: Context, fqn: str) -> AssetKey:
1313
"""Given the sqlmesh context and a model return the asset key"""
14-
return AssetKey(model.view_name)
15-
16-
def get_asset_key_fqn(self, context: Context, fqn: str) -> AssetKey:
17-
"""Given the sqlmesh context and a fqn of a model return an asset key"""
18-
table = self.get_fqn_to_table(context, fqn)
19-
return AssetKey(table.name)
20-
21-
def get_fqn_to_table(self, context: Context, fqn: str) -> exp.Table:
22-
"""Given the sqlmesh context and a fqn return the table"""
23-
dialect = self._get_context_dialect(context)
24-
return sqlglot.to_table(fqn, dialect=dialect)
14+
path = self.get_asset_key_name(fqn)
15+
return AssetKey(path)
16+
17+
def get_asset_key_name(self, fqn: str) -> Sequence[str]:
18+
asset_path = re.findall(r"[A-Za-z0-9_\-]+", fqn)
19+
return asset_path
20+
21+
def get_asset_key_str(self, fqn: str) -> str:
22+
# This is an internal identifier used to map outputs and dependencies
23+
# it will not affect the existing AssetKeys
24+
# Only alphanumeric characters and underscores
25+
path = self.get_asset_key_name(fqn)
26+
27+
return "__dot__".join(path).replace("-", "__dash__")
28+
29+
def get_group_name(self, context: Context, model: Model) -> str:
30+
path = self.get_asset_key_name(model.fqn)
31+
return path[-2]
2532

2633
def _get_context_dialect(self, context: Context) -> str:
2734
return context.engine_adapter.dialect

dagster_sqlmesh/utils.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,6 @@
11
from sqlmesh.core.snapshot import SnapshotId
22

33

4-
def sqlmesh_model_name_to_key(name: str) -> str:
5-
return name.replace(".", "_dot__")
6-
7-
8-
def key_to_sqlmesh_model_name(key: str) -> str:
9-
return key.replace("_dot__", ".")
10-
11-
124
def snapshot_id_to_model_name(snapshot_id: SnapshotId) -> str:
135
"""Convert a SnapshotId to its model name.
146

0 commit comments

Comments
 (0)