1
1
import logging
2
2
import typing as t
3
3
4
- from dagster import AssetExecutionContext , ConfigurableResource , MaterializeResult
4
+ from dagster import (
5
+ AssetExecutionContext ,
6
+ ConfigurableResource ,
7
+ MaterializeResult ,
8
+ )
5
9
from sqlmesh import Model
6
10
from sqlmesh .core .context import Context as SQLMeshContext
7
- from sqlmesh .core .snapshot import Snapshot
11
+ from sqlmesh .core .snapshot import Snapshot , SnapshotInfoLike , SnapshotTableInfo
8
12
from sqlmesh .utils .dag import DAG
9
13
from sqlmesh .utils .date import TimeLike
10
14
@@ -27,28 +31,36 @@ def __init__(self, sorted_dag: list[str], logger: logging.Logger) -> None:
27
31
self ._complete_update_status : dict [str , bool ] = {}
28
32
self ._sorted_dag = sorted_dag
29
33
self ._current_index = 0
34
+ self .finished_promotion = False
30
35
31
- def plan (self , batches : dict [Snapshot , int ]) -> None :
32
- self ._batches = batches
33
- self ._count : dict [Snapshot , int ] = {}
34
-
35
- incomplete_names = set ()
36
- for snapshot , count in self ._batches .items ():
37
- incomplete_names .add (snapshot .name )
38
- self ._count [snapshot ] = 0
36
+ def init_complete_update_status (self , snapshots : list [SnapshotTableInfo ]) -> None :
37
+ planned_model_names = set ()
38
+ for snapshot in snapshots :
39
+ planned_model_names .add (snapshot .name )
39
40
40
41
# Anything not in the plan should be listed as completed and queued for
41
42
# notification
42
43
self ._complete_update_status = {
43
- name : False for name in (set (self ._sorted_dag ) - incomplete_names )
44
+ name : False for name in (set (self ._sorted_dag ) - planned_model_names )
44
45
}
45
46
46
- def update (self , snapshot : Snapshot , _batch_idx : int ) -> tuple [int , int ]:
47
+ def update_promotion (self , snapshot : SnapshotInfoLike , promoted : bool ) -> None :
48
+ self ._complete_update_status [snapshot .name ] = promoted
49
+
50
+ def stop_promotion (self ) -> None :
51
+ self .finished_promotion = True
52
+
53
+ def plan (self , batches : dict [Snapshot , int ]) -> None :
54
+ self ._batches = batches
55
+ self ._count : dict [Snapshot , int ] = {}
56
+
57
+ for snapshot , _ in self ._batches .items ():
58
+ self ._count [snapshot ] = 0
59
+
60
+ def update_plan (self , snapshot : Snapshot , _batch_idx : int ) -> tuple [int , int ]:
47
61
self ._count [snapshot ] += 1
48
62
current_count = self ._count [snapshot ]
49
63
expected_count = self ._batches [snapshot ]
50
- if self ._batches [snapshot ] == self ._count [snapshot ]:
51
- self ._complete_update_status [snapshot .name ] = True
52
64
return (current_count , expected_count )
53
65
54
66
def notify_queue_next (self ) -> tuple [str , bool ] | None :
@@ -110,11 +122,12 @@ def __init__(
110
122
self ._tracker = MaterializationTracker (dag .sorted [:], self ._logger )
111
123
self ._stage = "plan"
112
124
113
- def process_events (
114
- self , sqlmesh_context : SQLMeshContext , event : console .ConsoleEvent
115
- ) -> t .Iterator [MaterializeResult ]:
125
+ def process_events (self , event : console .ConsoleEvent ) -> None :
116
126
self .report_event (event )
117
127
128
+ def notify_success (
129
+ self , sqlmesh_context : SQLMeshContext
130
+ ) -> t .Iterator [MaterializeResult ]:
118
131
notify = self ._tracker .notify_queue_next ()
119
132
while notify is not None :
120
133
completed_name , update_status = notify
@@ -146,6 +159,7 @@ def report_event(self, event: console.ConsoleEvent) -> None:
146
159
147
160
match event :
148
161
case console .StartPlanEvaluation (plan ):
162
+ self ._tracker .init_complete_update_status (plan .environment .snapshots )
149
163
log_context .info (
150
164
"Starting Plan Evaluation" ,
151
165
{
@@ -173,7 +187,7 @@ def report_event(self, event: console.ConsoleEvent) -> None:
173
187
case console .UpdateSnapshotEvaluationProgress (
174
188
snapshot , batch_idx , duration_ms
175
189
):
176
- done , expected = self ._tracker .update (snapshot , batch_idx )
190
+ done , expected = self ._tracker .update_plan (snapshot , batch_idx )
177
191
178
192
log_context .info (
179
193
"Snapshot progress update" ,
@@ -200,6 +214,21 @@ def report_event(self, event: console.ConsoleEvent) -> None:
200
214
[f"{ model !s} \n { model .__cause__ !s} " for model in models ]
201
215
)
202
216
log_context .error (f"sqlmesh failed models: { failed_models } " )
217
+ case console .UpdatePromotionProgress (snapshot , promoted ):
218
+ log_context .info (
219
+ "Promotion progress update" ,
220
+ {
221
+ "snapshot" : snapshot .name ,
222
+ "promoted" : promoted ,
223
+ },
224
+ )
225
+ self ._tracker .update_promotion (snapshot , promoted )
226
+ case console .StopPromotionProgress (success ):
227
+ self ._tracker .stop_promotion ()
228
+ if success :
229
+ log_context .info ("Promotion completed successfully" )
230
+ else :
231
+ log_context .error ("Promotion failed" )
203
232
case _:
204
233
log_context .debug ("Received event" )
205
234
@@ -237,6 +266,7 @@ def run(
237
266
start : TimeLike | None = None ,
238
267
end : TimeLike | None = None ,
239
268
restate_selected : bool = False ,
269
+ skip_run : bool = False ,
240
270
plan_options : PlanOptions | None = None ,
241
271
run_options : RunOptions | None = None ,
242
272
) -> t .Iterable [MaterializeResult ]:
@@ -287,10 +317,13 @@ def run(
287
317
end = end ,
288
318
select_models = select_models ,
289
319
restate_selected = restate_selected ,
320
+ skip_run = skip_run ,
290
321
plan_options = plan_options ,
291
322
run_options = run_options ,
292
323
):
293
- yield from event_handler .process_events (mesh .context , event )
324
+ event_handler .process_events (event )
325
+
326
+ yield from event_handler .notify_success (mesh .context )
294
327
295
328
def get_controller (
296
329
self , log_override : logging .Logger | None = None
0 commit comments