Skip to content

Commit 63bd837

Browse files
authored
feat: allow passing in custom context (#42)
1 parent e95e718 commit 63bd837

File tree

6 files changed

+68
-27
lines changed

6 files changed

+68
-27
lines changed

dagster_sqlmesh/asset.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,13 @@
22
import typing as t
33

44
from dagster import AssetsDefinition, RetryPolicy, multi_asset
5+
from sqlmesh import Context
56

6-
from dagster_sqlmesh.controller import DagsterSQLMeshController
7+
from dagster_sqlmesh.controller import (
8+
ContextCls,
9+
ContextFactory,
10+
DagsterSQLMeshController,
11+
)
712
from dagster_sqlmesh.translator import SQLMeshDagsterTranslator
813

914
from .config import SQLMeshContextConfig
@@ -16,6 +21,7 @@ def sqlmesh_assets(
1621
*,
1722
environment: str,
1823
config: SQLMeshContextConfig,
24+
context_factory: ContextFactory[ContextCls] = lambda **kwargs: Context(**kwargs),
1925
name: str | None = None,
2026
dagster_sqlmesh_translator: SQLMeshDagsterTranslator | None = None,
2127
compute_kind: str = "sqlmesh",
@@ -25,7 +31,7 @@ def sqlmesh_assets(
2531
# For now we don't set this by default
2632
enabled_subsetting: bool = False,
2733
) -> t.Callable[[t.Callable[..., t.Any]], AssetsDefinition]:
28-
controller = DagsterSQLMeshController.setup_with_config(config)
34+
controller = DagsterSQLMeshController.setup_with_config(config=config, context_factory=context_factory)
2935
if not dagster_sqlmesh_translator:
3036
dagster_sqlmesh_translator = SQLMeshDagsterTranslator()
3137
conversion = controller.to_asset_outs(environment, dagster_sqlmesh_translator)
Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
11
# ruff: noqa: F403 F401
2-
from .base import PlanOptions, RunOptions, SQLMeshController, SQLMeshInstance
2+
from .base import (
3+
DEFAULT_CONTEXT_FACTORY,
4+
ContextCls,
5+
ContextFactory,
6+
PlanOptions,
7+
RunOptions,
8+
SQLMeshController,
9+
SQLMeshInstance,
10+
)
311
from .dagster import DagsterSQLMeshController

dagster_sqlmesh/controller/base.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from contextlib import contextmanager
55
from dataclasses import dataclass
66
from types import MappingProxyType
7-
from typing import TypeVar
87

98
from sqlmesh.core.config import CategorizerConfig
109
from sqlmesh.core.console import set_console
@@ -27,8 +26,14 @@
2726

2827
logger = logging.getLogger(__name__)
2928

30-
T = TypeVar("T", bound="SQLMeshController")
29+
T = t.TypeVar("T", bound="SQLMeshController")
30+
ContextCls = t.TypeVar("ContextCls", bound=Context)
31+
ContextFactory = t.Callable[..., ContextCls]
3132

33+
def default_context_factory(**kwargs: t.Any) -> Context:
34+
return Context(**kwargs)
35+
36+
DEFAULT_CONTEXT_FACTORY: ContextFactory[Context] = default_context_factory
3237

3338
class PlanOptions(t.TypedDict):
3439
start: t.NotRequired[TimeLike]
@@ -88,7 +93,7 @@ def parse_fqn(self) -> SQLMeshParsedFQN:
8893
return parse_fqn(self.fqn)
8994

9095

91-
class SQLMeshInstance:
96+
class SQLMeshInstance(t.Generic[ContextCls]):
9297
"""
9398
A class that manages sqlmesh operations and context within a specific
9499
environment. This class will run sqlmesh in a separate thread.
@@ -110,15 +115,15 @@ class SQLMeshInstance:
110115
config: SQLMeshContextConfig
111116
console: EventConsole
112117
logger: logging.Logger
113-
context: Context
118+
context: ContextCls
114119
environment: str
115120

116121
def __init__(
117122
self,
118123
environment: str,
119124
console: EventConsole,
120125
config: SQLMeshContextConfig,
121-
context: Context,
126+
context: ContextCls,
122127
logger: logging.Logger,
123128
):
124129
self.environment = environment
@@ -167,7 +172,7 @@ def plan(
167172
def run_sqlmesh_thread(
168173
logger: logging.Logger,
169174
context: Context,
170-
controller: SQLMeshController,
175+
controller: SQLMeshController[ContextCls],
171176
environment: str,
172177
plan_options: PlanOptions,
173178
default_catalog: str,
@@ -251,7 +256,7 @@ def run(self, **run_options: t.Unpack[RunOptions]) -> t.Iterator[ConsoleEvent]:
251256
def run_sqlmesh_thread(
252257
logger: logging.Logger,
253258
context: Context,
254-
controller: SQLMeshController,
259+
controller: SQLMeshController[ContextCls],
255260
environment: str,
256261
run_options: RunOptions,
257262
) -> None:
@@ -364,8 +369,7 @@ def non_external_models_dag(self) -> t.Iterable[tuple[Model, set[str]]]:
364369
continue
365370
yield (model, deps)
366371

367-
368-
class SQLMeshController:
372+
class SQLMeshController(t.Generic[ContextCls]):
369373
"""Allows control of sqlmesh via a python interface. It is not suggested to
370374
use the constructor of this class directly, but instead use the provided
371375
`setup` or `setup_with_config` class methods.
@@ -405,37 +409,45 @@ class SQLMeshController:
405409
def setup(
406410
cls,
407411
path: str,
412+
*,
413+
context_factory: ContextFactory[ContextCls],
408414
gateway: str = "local",
409415
log_override: logging.Logger | None = None,
410-
) -> "SQLMeshController":
416+
) -> t.Self:
411417
return cls.setup_with_config(
412418
config=SQLMeshContextConfig(path=path, gateway=gateway),
413419
log_override=log_override,
420+
context_factory=context_factory,
414421
)
415422

416423
@classmethod
417424
def setup_with_config(
418-
cls: type[T],
425+
cls,
426+
*,
419427
config: SQLMeshContextConfig,
428+
context_factory: ContextFactory[ContextCls] = DEFAULT_CONTEXT_FACTORY,
420429
log_override: logging.Logger | None = None,
421-
) -> T:
430+
) -> t.Self:
422431
console = EventConsole(log_override=log_override) # type: ignore
423432
controller = cls(
424433
console=console,
425434
config=config,
426435
log_override=log_override,
436+
context_factory=context_factory,
427437
)
428438
return controller
429439

430440
def __init__(
431441
self,
432442
config: SQLMeshContextConfig,
433443
console: EventConsole,
444+
context_factory: ContextFactory[ContextCls],
434445
log_override: logging.Logger | None = None,
435446
) -> None:
436447
self.config = config
437448
self.console = console
438449
self.logger = log_override or logger
450+
self._context_factory = context_factory
439451
self._context_open = False
440452

441453
def set_logger(self, logger: logging.Logger) -> None:
@@ -448,20 +460,20 @@ def add_event_handler(self, handler: ConsoleEventHandler) -> str:
448460
def remove_event_handler(self, handler_id: str) -> None:
449461
self.console.remove_handler(handler_id)
450462

451-
def _create_context(self) -> Context:
463+
def _create_context(self) -> ContextCls:
452464
options: dict[str, t.Any] = dict(
453465
paths=self.config.path,
454466
gateway=self.config.gateway,
455467
)
456468
if self.config.sqlmesh_config:
457469
options["config"] = self.config.sqlmesh_config
458470
set_console(self.console)
459-
return Context(**options)
471+
return self._context_factory(**options)
460472

461473
@contextmanager
462474
def instance(
463475
self, environment: str, component: str = "unknown"
464-
) -> t.Iterator[SQLMeshInstance]:
476+
) -> t.Iterator[SQLMeshInstance[ContextCls]]:
465477
self.logger.info(
466478
f"Opening sqlmesh instance for env={environment} component={component}"
467479
)

dagster_sqlmesh/controller/dagster.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
from ..translator import SQLMeshDagsterTranslator
88
from ..types import SQLMeshModelDep, SQLMeshMultiAssetOptions
99
from ..utils import sqlmesh_model_name_to_key
10-
from .base import SQLMeshController
10+
from .base import ContextCls, SQLMeshController
1111

1212
logger = logging.getLogger(__name__)
1313

1414

15-
class DagsterSQLMeshController(SQLMeshController):
15+
class DagsterSQLMeshController(SQLMeshController[ContextCls]):
1616
"""An extension of the sqlmesh controller specifically for dagster use"""
1717

1818
def to_asset_outs(

dagster_sqlmesh/resource.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@
1212
from sqlmesh.utils.dag import DAG
1313
from sqlmesh.utils.date import TimeLike
1414

15+
from dagster_sqlmesh.controller.base import (
16+
DEFAULT_CONTEXT_FACTORY,
17+
ContextCls,
18+
ContextFactory,
19+
)
20+
1521
from . import console
1622
from .config import SQLMeshContextConfig
1723
from .controller import PlanOptions, RunOptions
@@ -169,7 +175,9 @@ def report_event(self, event: console.ConsoleEvent) -> None:
169175
case console.StopPlanEvaluation:
170176
log_context.info("Plan evaluation completed")
171177
case console.StartEvaluationProgress(
172-
batched_intervals=batches, environment_naming_info=environment_naming_info, default_catalog=default_catalog
178+
batched_intervals=batches,
179+
environment_naming_info=environment_naming_info,
180+
default_catalog=default_catalog,
173181
):
174182
self.update_stage("run")
175183
log_context.info(
@@ -263,6 +271,7 @@ def run(
263271
self,
264272
context: AssetExecutionContext,
265273
*,
274+
context_factory: ContextFactory[ContextCls] = DEFAULT_CONTEXT_FACTORY,
266275
environment: str = "dev",
267276
start: TimeLike | None = None,
268277
end: TimeLike | None = None,
@@ -279,7 +288,7 @@ def run(
279288

280289
logger = context.log
281290

282-
controller = self.get_controller(logger)
291+
controller = self.get_controller(context_factory, logger)
283292

284293
with controller.instance(environment) as mesh:
285294
dag = mesh.models_dag()
@@ -325,13 +334,18 @@ def run(
325334
plan_options=plan_options,
326335
run_options=run_options,
327336
):
337+
logger.debug(f"sqlmesh event: {event}")
328338
event_handler.process_events(event)
329339

330340
yield from event_handler.notify_success(mesh.context)
331341

332342
def get_controller(
333-
self, log_override: logging.Logger | None = None
334-
) -> DagsterSQLMeshController:
343+
self,
344+
context_factory: ContextFactory[ContextCls],
345+
log_override: logging.Logger | None = None,
346+
) -> DagsterSQLMeshController[ContextCls]:
335347
return DagsterSQLMeshController.setup_with_config(
336-
self.config, log_override=log_override
348+
config=self.config,
349+
context_factory=context_factory,
350+
log_override=log_override,
337351
)

dagster_sqlmesh/testing/context.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import duckdb
66
import polars
7+
from sqlmesh import Context
78
from sqlmesh.utils.date import TimeLike
89

910
from dagster_sqlmesh.config import SQLMeshContextConfig
@@ -21,9 +22,9 @@ class SQLMeshTestContext:
2122
db_path: str
2223
context_config: SQLMeshContextConfig
2324

24-
def create_controller(self):
25+
def create_controller(self) -> DagsterSQLMeshController[Context]:
2526
return DagsterSQLMeshController.setup_with_config(
26-
self.context_config,
27+
config=self.context_config,
2728
)
2829

2930
def query(self, *args: t.Any, **kwargs: t.Any) -> list[t.Any]:

0 commit comments

Comments
 (0)