Skip to content

Commit 092e46e

Browse files
authored
feat: Translator for external deps, minor improvement to AssetOut (#41)
* feat: move translator to config, add support for kinds label and specify group through translator The SQLMeshDagsterTranslator has been simplified to return an AssetKey, the name of the AssetKey or a string equivalent, which deprecates utils.sqlmesh_model_name_to_key and utils.key_to_sqlmesh_model_name. The kinds label is now also added, to show technology labels on the UI, only when Dagster's version allows for it. Groups can now be specified through the translator, instead of having a fixed method. * chore: merge source for PR * fix: rename translator parameters and set defaults * fix: remove translator from controller and use helper for internal ref * chore: remove unused refs to translator * fix: make translator required param for controller * chore: formatting * chore: update internal module refs to be absolute * fix: undo singleton for translator
1 parent 7ad32cf commit 092e46e

File tree

10 files changed

+79
-63
lines changed

10 files changed

+79
-63
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ from dagster import (
2424
AssetExecutionContext,
2525
Definitions,
2626
)
27-
from dagster_sqlmesh import sqlmesh_assets, SQLMeshContextConfig, SQLMeshResource
27+
from dagster_sqlmesh import sqlmesh_assets, SQLMeshContextConfig, SQLMeshResource, SQLMeshDagsterTranslator
2828

2929
sqlmesh_config = SQLMeshContextConfig(path="/home/foo/sqlmesh_project", gateway="name-of-your-gateway")
3030

31-
@sqlmesh_assets(environment="dev", config=sqlmesh_config)
31+
@sqlmesh_assets(environment="dev", config=sqlmesh_config, translator=SQLMeshDagsterTranslator())
3232
def sqlmesh_project(context: AssetExecutionContext, sqlmesh: SQLMeshResource):
3333
yield from sqlmesh.run(context)
3434

dagster_sqlmesh/asset.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@
44
from dagster import AssetsDefinition, RetryPolicy, multi_asset
55
from sqlmesh import Context
66

7+
from dagster_sqlmesh.config import SQLMeshContextConfig
78
from dagster_sqlmesh.controller import (
89
ContextCls,
910
ContextFactory,
1011
DagsterSQLMeshController,
1112
)
1213
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator
1314

14-
from .config import SQLMeshContextConfig
15-
1615
logger = logging.getLogger(__name__)
1716

1817

@@ -34,7 +33,7 @@ def sqlmesh_assets(
3433
controller = DagsterSQLMeshController.setup_with_config(config=config, context_factory=context_factory)
3534
if not dagster_sqlmesh_translator:
3635
dagster_sqlmesh_translator = SQLMeshDagsterTranslator()
37-
conversion = controller.to_asset_outs(environment, dagster_sqlmesh_translator)
36+
conversion = controller.to_asset_outs(environment, translator=dagster_sqlmesh_translator)
3837

3938
return multi_asset(
4039
name=name,

dagster_sqlmesh/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ class SQLMeshContextConfig(Config):
3030
def sqlmesh_config(self) -> MeshConfig | None:
3131
if self.config_override:
3232
return MeshConfig.parse_obj(self.config_override)
33-
return None
33+
return None

dagster_sqlmesh/controller/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,16 @@
1313
from sqlmesh.utils.dag import DAG
1414
from sqlmesh.utils.date import TimeLike
1515

16-
from ..config import SQLMeshContextConfig
17-
from ..console import (
16+
from dagster_sqlmesh.config import SQLMeshContextConfig
17+
from dagster_sqlmesh.console import (
1818
ConsoleEvent,
1919
ConsoleEventHandler,
2020
ConsoleException,
2121
EventConsole,
2222
Plan,
2323
SnapshotCategorizer,
2424
)
25-
from ..events import ConsoleGenerator
25+
from dagster_sqlmesh.events import ConsoleGenerator
2626

2727
logger = logging.getLogger(__name__)
2828

dagster_sqlmesh/controller/dagster.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
# pyright: reportPrivateImportUsage=false
22
import logging
3+
from inspect import signature
34

45
from dagster import AssetDep, AssetKey, AssetOut
56
from dagster._core.definitions.asset_dep import CoercibleToAssetDep
67

7-
from ..translator import SQLMeshDagsterTranslator
8-
from ..types import SQLMeshModelDep, SQLMeshMultiAssetOptions
9-
from ..utils import sqlmesh_model_name_to_key
10-
from .base import ContextCls, SQLMeshController
8+
from dagster_sqlmesh.controller.base import ContextCls, SQLMeshController
9+
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator
10+
from dagster_sqlmesh.types import SQLMeshModelDep, SQLMeshMultiAssetOptions
11+
from dagster_sqlmesh.utils import get_asset_key_str
1112

1213
logger = logging.getLogger(__name__)
1314

@@ -16,18 +17,15 @@ class DagsterSQLMeshController(SQLMeshController[ContextCls]):
1617
"""An extension of the sqlmesh controller specifically for dagster use"""
1718

1819
def to_asset_outs(
19-
self, environment: str, translator: SQLMeshDagsterTranslator
20+
self, environment: str, translator: SQLMeshDagsterTranslator,
2021
) -> SQLMeshMultiAssetOptions:
2122
with self.instance(environment, "to_asset_outs") as instance:
2223
context = instance.context
2324
output = SQLMeshMultiAssetOptions()
2425
depsMap: dict[str, CoercibleToAssetDep] = {}
2526

2627
for model, deps in instance.non_external_models_dag():
27-
asset_key = translator.get_asset_key_from_model(
28-
context,
29-
model,
30-
)
28+
asset_key = translator.get_asset_key(context=context, fqn=model.fqn)
3129
model_deps = [
3230
SQLMeshModelDep(fqn=dep, model=context.get_model(dep))
3331
for dep in deps
@@ -38,18 +36,27 @@ def to_asset_outs(
3836
for dep in model_deps:
3937
if dep.model:
4038
internal_asset_deps.add(
41-
translator.get_asset_key_from_model(context, dep.model)
39+
translator.get_asset_key(context, dep.model.fqn)
4240
)
4341
else:
44-
table = translator.get_fqn_to_table(context, dep.fqn)
45-
key = translator.get_asset_key_fqn(context, dep.fqn)
42+
table = get_asset_key_str(dep.fqn)
43+
key = translator.get_asset_key(context, dep.fqn)
4644
internal_asset_deps.add(key)
4745
# create an external dep
48-
depsMap[table.name] = AssetDep(key)
49-
model_key = sqlmesh_model_name_to_key(model.name)
50-
output.outs[model_key] = AssetOut(
51-
key=asset_key, tags=asset_tags, is_required=False
52-
)
46+
depsMap[table] = AssetDep(key)
47+
model_key = get_asset_key_str(model.fqn)
48+
# If current Dagster supports "kinds", add labels for Dagster UI
49+
if "kinds" in signature(AssetOut).parameters:
50+
output.outs[model_key] = AssetOut(
51+
key=asset_key, tags=asset_tags, is_required=False,
52+
group_name=translator.get_group_name(context, model),
53+
kinds={"sqlmesh", translator._get_context_dialect(context).lower()}
54+
)
55+
else:
56+
output.outs[model_key] = AssetOut(
57+
key=asset_key, tags=asset_tags, is_required=False,
58+
group_name=translator.get_group_name(context, model)
59+
)
5360
output.internal_asset_deps[model_key] = internal_asset_deps
5461

5562
output.deps = list(depsMap.values())

dagster_sqlmesh/resource.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,16 @@
1515
from sqlmesh.utils.date import TimeLike
1616
from sqlmesh.utils.errors import SQLMeshError
1717

18+
from dagster_sqlmesh import console
19+
from dagster_sqlmesh.config import SQLMeshContextConfig
20+
from dagster_sqlmesh.controller import PlanOptions, RunOptions
1821
from dagster_sqlmesh.controller.base import (
1922
DEFAULT_CONTEXT_FACTORY,
2023
ContextCls,
2124
ContextFactory,
2225
)
23-
24-
from . import console
25-
from .config import SQLMeshContextConfig
26-
from .controller import PlanOptions, RunOptions
27-
from .controller.dagster import DagsterSQLMeshController
28-
from .utils import sqlmesh_model_name_to_key
26+
from dagster_sqlmesh.controller.dagster import DagsterSQLMeshController
27+
from dagster_sqlmesh.utils import get_asset_key_str
2928

3029

3130
class MaterializationTracker:
@@ -147,7 +146,7 @@ def __init__(
147146
self._prefix = prefix
148147
self._context = context
149148
self._logger = context.log
150-
self._tracker = MaterializationTracker(dag.sorted[:], self._logger)
149+
self._tracker = MaterializationTracker(sorted_dag=dag.sorted[:], logger=self._logger)
151150
self._stage = "plan"
152151
self._errors: list[Exception] = []
153152
self._is_testing = is_testing
@@ -173,7 +172,8 @@ def notify_success(
173172
# We allow selecting models. That value is mapped to models_map.
174173
# If the model is not in models_map, we can skip any notification
175174
if model:
176-
output_key = sqlmesh_model_name_to_key(model.name)
175+
# Passing model.fqn to get internal unique asset key
176+
output_key = get_asset_key_str(model.fqn)
177177
if not self._is_testing:
178178
# Stupidly dagster when testing cannot use the following
179179
# method so we must specifically skip this when testing
@@ -227,7 +227,7 @@ def report_event(self, event: console.ConsoleEvent) -> None:
227227
log_context.info(
228228
"Snapshot progress update",
229229
{
230-
"asset_key": sqlmesh_model_name_to_key(snapshot.model.name),
230+
"asset_key": get_asset_key_str(snapshot.model.name),
231231
"progress": f"{done}/{expected}",
232232
"duration_ms": duration_ms,
233233
},
@@ -327,7 +327,10 @@ def run(
327327

328328
logger = context.log
329329

330-
controller = self.get_controller(context_factory, logger)
330+
controller = self.get_controller(
331+
context_factory=context_factory,
332+
log_override=logger
333+
)
331334

332335
with controller.instance(environment) as mesh:
333336
dag = mesh.models_dag()
@@ -338,7 +341,10 @@ def run(
338341
[model.fqn for model, _ in mesh.non_external_models_dag()]
339342
)
340343
selected_models_set, models_map, select_models = (
341-
self._get_selected_models_from_context(context, models)
344+
self._get_selected_models_from_context(
345+
context=context,
346+
models=models
347+
)
342348
)
343349

344350
if all_available_models == selected_models_set or select_models is None:
@@ -351,7 +357,8 @@ def run(
351357
logger.info(f"selected models: {select_models}")
352358

353359
event_handler = DagsterSQLMeshEventHandler(
354-
context, models_map, dag, "sqlmesh: ", is_testing=self.is_testing
360+
context=context, models_map=models_map, dag=dag,
361+
prefix="sqlmesh: ", is_testing=self.is_testing
355362
)
356363

357364
try:
@@ -397,7 +404,7 @@ def _get_selected_models_from_context(
397404
select_models: list[str] = []
398405
models_map = {}
399406
for key, model in models.items():
400-
if sqlmesh_model_name_to_key(model.name) in selected_output_names:
407+
if get_asset_key_str(model.fqn) in selected_output_names:
401408
models_map[key] = model
402409
select_models.append(model.name)
403410
return (
@@ -414,5 +421,5 @@ def get_controller(
414421
return DagsterSQLMeshController.setup_with_config(
415422
config=self.config,
416423
context_factory=context_factory,
417-
log_override=log_override,
424+
log_override=log_override
418425
)

dagster_sqlmesh/test_sqlmesh_context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import polars
44

5-
from .testing import SQLMeshTestContext
5+
from dagster_sqlmesh.testing import SQLMeshTestContext
66

77
logger = logging.getLogger(__name__)
88

dagster_sqlmesh/translator.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1+
from collections.abc import Sequence
12

2-
import sqlglot
33
from dagster import AssetKey
44
from sqlglot import exp
55
from sqlmesh.core.context import Context
@@ -9,19 +9,20 @@
99
class SQLMeshDagsterTranslator:
1010
"""Translates sqlmesh objects for dagster"""
1111

12-
def get_asset_key_from_model(self, context: Context, model: Model) -> AssetKey:
12+
def get_asset_key(self, context: Context, fqn: str) -> AssetKey:
1313
"""Given the sqlmesh context and a model return the asset key"""
14-
return AssetKey(model.view_name)
15-
16-
def get_asset_key_fqn(self, context: Context, fqn: str) -> AssetKey:
17-
"""Given the sqlmesh context and a fqn of a model return an asset key"""
18-
table = self.get_fqn_to_table(context, fqn)
19-
return AssetKey(table.name)
20-
21-
def get_fqn_to_table(self, context: Context, fqn: str) -> exp.Table:
22-
"""Given the sqlmesh context and a fqn return the table"""
23-
dialect = self._get_context_dialect(context)
24-
return sqlglot.to_table(fqn, dialect=dialect)
14+
path = self.get_asset_key_name(fqn)
15+
return AssetKey(path)
16+
17+
def get_asset_key_name(self, fqn: str) -> Sequence[str]:
18+
table = exp.to_table(fqn)
19+
asset_key_name = [table.catalog, table.db, table.name]
20+
21+
return asset_key_name
22+
23+
def get_group_name(self, context: Context, model: Model) -> str:
24+
path = self.get_asset_key_name(model.fqn)
25+
return path[-2]
2526

2627
def _get_context_dialect(self, context: Context) -> str:
2728
return context.engine_adapter.dialect

dagster_sqlmesh/utils.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1+
from sqlglot import exp
12
from sqlmesh.core.snapshot import SnapshotId
23

34

4-
def sqlmesh_model_name_to_key(name: str) -> str:
5-
return name.replace(".", "_dot__")
6-
7-
8-
def key_to_sqlmesh_model_name(key: str) -> str:
9-
return key.replace("_dot__", ".")
10-
5+
def get_asset_key_str(fqn: str) -> str:
6+
# This is an internal identifier used to map outputs and dependencies
7+
# it will not affect the existing AssetKeys
8+
# Only alphanumeric characters and underscores
9+
table = exp.to_table(fqn)
10+
asset_key_name = [table.catalog, table.db, table.name]
11+
12+
return "sqlmesh__" + "_".join(asset_key_name)
1113

1214
def snapshot_id_to_model_name(snapshot_id: SnapshotId) -> str:
1315
"""Convert a SnapshotId to its model name.

sample/dagster_project/definitions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
sqlmesh_config = SQLMeshContextConfig(path=SQLMESH_PROJECT_PATH, gateway="local")
2222

2323

24-
@asset
24+
@asset(key=["db", "sources", "reset_asset"])
2525
def reset_asset() -> MaterializeResult:
2626
"""An asset used for testing this entire workflow. If the duckdb database is
2727
found, this will delete it. This allows us to continously test this dag if
@@ -34,7 +34,7 @@ def reset_asset() -> MaterializeResult:
3434
return MaterializeResult(metadata={"deleted": deleted})
3535

3636

37-
@asset(deps=[reset_asset])
37+
@asset(deps=[reset_asset], key=["db", "sources", "test_source"])
3838
def test_source() -> pl.DataFrame:
3939
"""Sets up the `test_source` table in duckdb that one of the sample sqlmesh
4040
models depends on"""

0 commit comments

Comments
 (0)