Skip to content

Commit 46b8d49

Browse files
committed
Further Improvements
1 parent 111ab2b commit 46b8d49

File tree

1 file changed

+163
-133
lines changed

1 file changed

+163
-133
lines changed

dagster_sqlmesh/controller/base.py

Lines changed: 163 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -80,89 +80,41 @@ def parse_fqn(self):
8080
return parse_fqn(self.fqn)
8181

8282

83-
class SQLMeshController:
84-
"""Allows control of sqlmesh via a python interface. It is not suggested to
85-
use the constructor of this class directly, but instead use the provided
86-
`setup` class method"""
87-
83+
class SQLMeshInstance:
8884
config: SQLMeshContextConfig
8985
console: EventConsole
9086
logger: logging.Logger
87+
context: Context
88+
environment: str
9189

92-
@classmethod
93-
def setup(
94-
cls,
90+
def __init__(
91+
self,
92+
environment: str,
93+
console: EventConsole,
9594
config: SQLMeshContextConfig,
96-
debug_console: t.Optional[Console] = None,
97-
log_override: t.Optional[logging.Logger] = None,
95+
context: Context,
96+
logger: logging.Logger,
9897
):
99-
console = EventConsole(log_override=log_override)
100-
if debug_console:
101-
console = DebugEventConsole(debug_console)
102-
controller = cls(
103-
console=console,
104-
config=config,
105-
)
106-
return controller
107-
108-
def __init__(self, config: SQLMeshContextConfig, console: EventConsole):
109-
self.config = config
98+
self.environment = environment
11099
self.console = console
100+
self.config = config
101+
self.context = context
111102
self.logger = logger
112103

113-
def set_logger(self, logger: logging.Logger):
114-
self.logger = logger
115-
116-
def add_event_handler(self, handler: ConsoleEventHandler):
117-
return self.console.add_handler(handler)
118-
119-
def remove_event_handler(self, handler_id: str):
120-
return self.console.remove_handler(handler_id)
121-
122-
def models_dag(self):
123-
with self.context() as context:
124-
return context.dag
125-
126-
def models(self):
127-
with self.context() as context:
128-
return context.models
129-
130-
def _create_context(self):
131-
options: t.Dict[str, t.Any] = dict(
132-
paths=self.config.path,
133-
gateway=self.config.gateway,
134-
console=self.console,
135-
)
136-
if self.config.sqlmesh_config:
137-
options["config"] = self.config.sqlmesh_config
138-
return Context(**options)
139-
140104
@contextmanager
141-
def context(self):
142-
context = self._create_context()
143-
yield context
144-
context.close()
105+
def console_context(self, handler: ConsoleEventHandler):
106+
id = self.console.add_handler(handler)
107+
yield
108+
self.console.remove_handler(id)
145109

146110
def plan(
147111
self,
148-
environment: str,
149112
categorizer: t.Optional[SnapshotCategorizer] = None,
150113
default_catalog: t.Optional[str] = None,
151114
**plan_options: t.Unpack[PlanOptions],
152115
):
153-
with self.context() as context:
154-
return self._plan(
155-
context, environment, categorizer, default_catalog, plan_options
156-
)
116+
context = self.context
157117

158-
def _plan(
159-
self,
160-
context: Context,
161-
environment: str,
162-
categorizer: t.Optional[SnapshotCategorizer],
163-
default_catalog: t.Optional[str],
164-
plan_options: PlanOptions,
165-
):
166118
# Runs things in thread
167119
def run_sqlmesh_thread(
168120
logger: logging.Logger,
@@ -190,46 +142,30 @@ def run_sqlmesh_thread(
190142
controller.console.exception(e)
191143

192144
generator = ConsoleGenerator(self.logger)
193-
event_id = self.add_event_handler(generator)
194-
195-
thread = threading.Thread(
196-
target=run_sqlmesh_thread,
197-
args=(
198-
self.logger,
199-
context,
200-
self,
201-
environment,
202-
plan_options,
203-
default_catalog,
204-
),
205-
)
206-
thread.start()
207-
208-
for event in generator.events(thread):
209-
match event:
210-
case ConsoleException(e):
211-
raise e
212-
case _:
213-
yield (context, event)
214-
215-
thread.join()
145+
with self.console_context(generator):
146+
thread = threading.Thread(
147+
target=run_sqlmesh_thread,
148+
args=(
149+
self.logger,
150+
context,
151+
self,
152+
self.environment,
153+
plan_options,
154+
default_catalog,
155+
),
156+
)
157+
thread.start()
216158

217-
self.remove_event_handler(event_id)
159+
for event in generator.events(thread):
160+
match event:
161+
case ConsoleException(e):
162+
raise e
163+
case _:
164+
yield event
218165

219-
def run(
220-
self,
221-
environment: str,
222-
**run_options: t.Unpack[RunOptions],
223-
):
224-
with self.context() as context:
225-
return self._run(context, environment, run_options=run_options)
166+
thread.join()
226167

227-
def _run(
228-
self,
229-
context: Context,
230-
environment: str,
231-
run_options: RunOptions,
232-
):
168+
def run(self, **run_options: t.Unpack[RunOptions]):
233169
# Runs things in thread
234170
def run_sqlmesh_thread(
235171
logger: logging.Logger,
@@ -245,29 +181,130 @@ def run_sqlmesh_thread(
245181
controller.console.exception(e)
246182

247183
generator = ConsoleGenerator(self.logger)
248-
event_id = self.add_event_handler(generator)
249-
250-
thread = threading.Thread(
251-
target=run_sqlmesh_thread,
252-
args=(
253-
self.logger,
254-
context,
255-
self,
256-
environment,
257-
run_options or {},
258-
),
184+
with self.console_context(generator):
185+
thread = threading.Thread(
186+
target=run_sqlmesh_thread,
187+
args=(
188+
self.logger,
189+
self.context,
190+
self,
191+
self.environment,
192+
run_options or {},
193+
),
194+
)
195+
thread.start()
196+
197+
for event in generator.events(thread):
198+
match event:
199+
case ConsoleException(e):
200+
raise e
201+
case _:
202+
yield event
203+
204+
thread.join()
205+
206+
def plan_and_run(
207+
self,
208+
categorizer: t.Optional[SnapshotCategorizer] = None,
209+
default_catalog: t.Optional[str] = None,
210+
plan_options: t.Optional[PlanOptions] = None,
211+
run_options: t.Optional[RunOptions] = None,
212+
):
213+
run_options = run_options or {}
214+
plan_options = plan_options or {}
215+
216+
yield from self.plan(categorizer, default_catalog, **plan_options)
217+
yield from self.run(**run_options)
218+
219+
def models(self):
220+
return self.context.models
221+
222+
def models_dag(self):
223+
return self.context.dag
224+
225+
226+
class SQLMeshController:
227+
"""Allows control of sqlmesh via a python interface. It is not suggested to
228+
use the constructor of this class directly, but instead use the provided
229+
`setup` class method"""
230+
231+
config: SQLMeshContextConfig
232+
console: EventConsole
233+
logger: logging.Logger
234+
235+
@classmethod
236+
def setup(
237+
cls,
238+
config: SQLMeshContextConfig,
239+
debug_console: t.Optional[Console] = None,
240+
log_override: t.Optional[logging.Logger] = None,
241+
):
242+
console = EventConsole(log_override=log_override)
243+
if debug_console:
244+
console = DebugEventConsole(debug_console)
245+
controller = cls(
246+
console=console,
247+
config=config,
259248
)
260-
thread.start()
249+
return controller
261250

262-
for event in generator.events(thread):
263-
match event:
264-
case ConsoleException(e):
265-
raise e
266-
case _:
267-
yield (context, event)
251+
def __init__(self, config: SQLMeshContextConfig, console: EventConsole):
252+
self.config = config
253+
self.console = console
254+
self.logger = logger
255+
self._context_open = False
268256

269-
thread.join()
270-
self.remove_event_handler(event_id)
257+
def set_logger(self, logger: logging.Logger):
258+
self.logger = logger
259+
260+
def add_event_handler(self, handler: ConsoleEventHandler):
261+
return self.console.add_handler(handler)
262+
263+
def remove_event_handler(self, handler_id: str):
264+
return self.console.remove_handler(handler_id)
265+
266+
def _create_context(self):
267+
options: t.Dict[str, t.Any] = dict(
268+
paths=self.config.path,
269+
gateway=self.config.gateway,
270+
console=self.console,
271+
)
272+
if self.config.sqlmesh_config:
273+
options["config"] = self.config.sqlmesh_config
274+
return Context(**options)
275+
276+
@contextmanager
277+
def instance(self, environment: str):
278+
if self._context_open:
279+
raise Exception("Only one sqlmesh instance at a time")
280+
281+
context = self._create_context()
282+
self._context_open = True
283+
try:
284+
yield SQLMeshInstance(
285+
environment, self.console, self.config, context, self.logger
286+
)
287+
finally:
288+
self._context_open = False
289+
context.close()
290+
291+
def run(
292+
self,
293+
environment: str,
294+
**run_options: t.Unpack[RunOptions],
295+
):
296+
with self.instance(environment) as mesh:
297+
yield from mesh.run(**run_options)
298+
299+
def plan(
300+
self,
301+
environment: str,
302+
categorizer: t.Optional[SnapshotCategorizer],
303+
default_catalog: t.Optional[str],
304+
plan_options: PlanOptions,
305+
):
306+
with self.instance(environment) as mesh:
307+
yield from mesh.plan(categorizer, default_catalog, **plan_options)
271308

272309
def plan_and_run(
273310
self,
@@ -277,17 +314,10 @@ def plan_and_run(
277314
plan_options: t.Optional[PlanOptions] = None,
278315
run_options: t.Optional[RunOptions] = None,
279316
):
280-
with self.context() as context:
281-
yield from self._plan(
282-
context,
283-
environment=environment,
317+
with self.instance(environment) as mesh:
318+
yield from mesh.plan_and_run(
284319
categorizer=categorizer,
285320
default_catalog=default_catalog,
286-
plan_options=plan_options or {},
287-
)
288-
289-
yield from self._run(
290-
context,
291-
environment=environment,
292-
run_options=run_options or {},
321+
plan_options=plan_options,
322+
run_options=run_options,
293323
)

0 commit comments

Comments
 (0)