From 99c12505f7c270d2425a879019b36844edd5a93b Mon Sep 17 00:00:00 2001 From: "Reuven V. Gonzales" Date: Thu, 13 Mar 2025 00:39:34 +0000 Subject: [PATCH] feat: support tags in sqlmesh models --- dagster_sqlmesh/controller/dagster.py | 8 +++++-- dagster_sqlmesh/translator.py | 22 ++++++++----------- .../models/marts/full_model.sql | 4 ++++ .../models/staging/staging_model_1.sql | 6 ++++- 4 files changed, 24 insertions(+), 16 deletions(-) diff --git a/dagster_sqlmesh/controller/dagster.py b/dagster_sqlmesh/controller/dagster.py index 9c221a4..b6b1a42 100644 --- a/dagster_sqlmesh/controller/dagster.py +++ b/dagster_sqlmesh/controller/dagster.py @@ -34,7 +34,7 @@ def to_asset_outs( if not model: # If no model is returned this seems to be an asset dependency continue - asset_out = translator.get_asset_key_from_model( + asset_key = translator.get_asset_key_from_model( context, model, ) @@ -43,6 +43,8 @@ def to_asset_outs( for dep in deps ] internal_asset_deps: t.Set[AssetKey] = set() + asset_tags = translator.get_tags(context, model) + for dep in model_deps: if dep.model: internal_asset_deps.add( @@ -55,7 +57,9 @@ def to_asset_outs( # create an external dep depsMap[table.name] = AssetDep(key) model_key = sqlmesh_model_name_to_key(model.name) - output.outs[model_key] = AssetOut(key=asset_out, is_required=False) + output.outs[model_key] = AssetOut( + key=asset_key, tags=asset_tags, is_required=False + ) output.internal_asset_deps[model_key] = internal_asset_deps output.deps = list(depsMap.values()) diff --git a/dagster_sqlmesh/translator.py b/dagster_sqlmesh/translator.py index 2375b7f..b631d04 100644 --- a/dagster_sqlmesh/translator.py +++ b/dagster_sqlmesh/translator.py @@ -1,3 +1,4 @@ +import typing as t import sqlglot from sqlglot import exp from sqlmesh.core.context import Context @@ -9,27 +10,22 @@ class SQLMeshDagsterTranslator: """Translates sqlmesh objects for dagster""" def get_asset_key_from_model(self, context: Context, model: Model) -> AssetKey: + """Given the sqlmesh context and a model return the asset key""" return AssetKey(model.view_name) def get_asset_key_fqn(self, context: Context, fqn: str) -> AssetKey: + """Given the sqlmesh context and a fqn of a model return an asset key""" table = self.get_fqn_to_table(context, fqn) return AssetKey(table.name) def get_fqn_to_table(self, context: Context, fqn: str) -> exp.Table: - dialect = self.get_context_dialect(context) + """Given the sqlmesh context and a fqn return the table""" + dialect = self._get_context_dialect(context) return sqlglot.to_table(fqn, dialect=dialect) - def get_context_dialect(self, context: Context) -> str: + def _get_context_dialect(self, context: Context) -> str: return context.engine_adapter.dialect - # def get_asset_deps( - # self, context: Context, model: Model, deps: List[SQLMeshModelDep] - # ) -> List[AssetKey]: - # asset_keys: List[AssetKey] = [] - # for dep in deps: - # if dep.model: - # asset_keys.append(AssetKey(dep.model.view_name)) - # else: - # parsed_fqn = dep.parse_fqn() - # asset_keys.append(AssetKey([parsed_fqn.view_name])) - # return asset_keys + def get_tags(self, context: Context, model: Model) -> t.Dict[str, str]: + """Given the sqlmesh context and a model return the tags for that model""" + return {k: "true" for k in model.tags} diff --git a/sample/sqlmesh_project/models/marts/full_model.sql b/sample/sqlmesh_project/models/marts/full_model.sql index 02a82a7..4a70afd 100644 --- a/sample/sqlmesh_project/models/marts/full_model.sql +++ b/sample/sqlmesh_project/models/marts/full_model.sql @@ -4,6 +4,10 @@ MODEL ( cron '@daily', grain item_id, audits (assert_positive_order_ids), + tags ( + "mart", + "full", + ) ); SELECT diff --git a/sample/sqlmesh_project/models/staging/staging_model_1.sql b/sample/sqlmesh_project/models/staging/staging_model_1.sql index 373be31..3232526 100644 --- a/sample/sqlmesh_project/models/staging/staging_model_1.sql +++ b/sample/sqlmesh_project/models/staging/staging_model_1.sql @@ -5,7 +5,11 @@ MODEL ( ), start '2020-01-01', cron '@daily', - grain (id, event_date) + grain (id, event_date), + tags ( + "staging", + "incremental" + ) ); SELECT