
Commit 93cfd09

feat(debug): implement debug state saving and directory management
- Added functionality to save the current SQLMesh project state to a timestamped directory within the 'debug_runs' folder.
- Enhanced the `sample_project_root` fixture to copy the project into a debug directory after test runs for easier debugging.
- Updated `DagsterTestContext` to hold separate paths for the Dagster and SQLMesh projects, improving clarity and organization.
- Modified test cases to use the new debug state saving feature, making it easier to track project state during tests.
- Adjusted paths in `definitions.py` so the DuckDB database file is located relative to the SQLMesh project.
1 parent d3202d3 commit 93cfd09

File tree

.gitignore
dagster_sqlmesh/conftest.py
dagster_sqlmesh/controller/tests_plan_and_run/test_model_code_change.py
sample/dagster_project/definitions.py

4 files changed: +178 -38 lines changed
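
The directory naming introduced by this commit follows the format strings visible in the diffs below; a minimal sketch of the resulting layout (the timestamp value is illustrative):

import datetime as dt

timestamp = dt.datetime.now().strftime("%Y%m%d_%H%M%S")  # e.g. "20240101_120000"

# Written by the sample_project_root fixture after every test run:
fixture_save = f"debug_runs/run_{timestamp}"

# Written by SQLMeshTestContext.save_sqlmesh_debug_state(name_suffix="manual_save"):
manual_save = f"debug_runs/sqlmesh_state_{timestamp}_manual_save"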

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -68,3 +68,5 @@ sample/dagster_project/logs/
 sample/dagster_project/history/
 sample/dagster_project/schedules/
 tests/temp/
+
+debug_runs/

dagster_sqlmesh/conftest.py

Lines changed: 119 additions & 34 deletions
@@ -1,3 +1,4 @@
+import datetime as dt
 import json
 import logging
 import os
@@ -6,7 +7,7 @@
 import sys
 import tempfile
 import typing as t
-from dataclasses import dataclass
+from dataclasses import dataclass, field

 import duckdb
 import polars
@@ -41,8 +42,32 @@ def sample_project_root() -> t.Iterator[str]:
     """Creates a temporary project directory containing both SQLMesh and Dagster projects"""
     with tempfile.TemporaryDirectory() as tmp_dir:
         project_dir = shutil.copytree("sample", tmp_dir, dirs_exist_ok=True)
+
         yield project_dir

+        # Create debug directory with timestamp AFTER test run
+        debug_dir = os.path.join(
+            os.path.dirname(os.path.dirname(__file__)), "debug_runs"
+        )
+        os.makedirs(debug_dir, exist_ok=True)
+        timestamp = dt.datetime.now().strftime("%Y%m%d_%H%M%S")
+        run_debug_dir = os.path.join(debug_dir, f"run_{timestamp}")
+
+        # Copy contents to debug directory
+        try:
+            shutil.copytree(tmp_dir, run_debug_dir, dirs_exist_ok=True)
+            logger.info(
+                f"Copied final test project contents to {run_debug_dir} for debugging"
+            )
+        except FileNotFoundError:
+            logger.warning(
+                f"Temporary directory {tmp_dir} not found during cleanup copy."
+            )
+        except Exception as e:
+            logger.error(
+                f"Error copying temporary directory {tmp_dir} to {run_debug_dir}: {e}"
+            )
+

 @pytest.fixture
 def sample_sqlmesh_project(sample_project_root: str) -> t.Iterator[str]:
@@ -75,6 +100,9 @@ class SQLMeshTestContext:
     context_config: SQLMeshContextConfig
     project_path: str

+    # Internal state for backup/restore
+    _backed_up_files: set[str] = field(default_factory=set, init=False)
+
     def create_controller(
         self, enable_debug_console: bool = False
     ) -> DagsterSQLMeshController:
@@ -160,6 +188,39 @@ def cleanup_modified_files(self) -> None:
             self.restore_model_file(model_name)
         self._backed_up_files.clear()

+    def save_sqlmesh_debug_state(self, name_suffix: str = "manual_save") -> str:
+        """Saves the current state of the SQLMesh project to the debug directory.
+
+        Copies the contents of the SQLMesh project directory (self.project_path)
+        to a timestamped sub-directory within the 'debug_runs' folder.
+
+        Args:
+            name_suffix: An optional suffix to append to the debug directory name
+                to distinguish this save point (e.g., 'before_change',
+                'after_plan'). Defaults to 'manual_save'.
+
+        Returns:
+            The path to the created debug state directory.
+        """
+        debug_dir_base = os.path.join(
+            os.path.dirname(self.project_path), "..", "debug_runs"
+        )
+        os.makedirs(debug_dir_base, exist_ok=True)
+        timestamp = dt.datetime.now().strftime("%Y%m%d_%H%M%S")
+        run_debug_dir = os.path.join(
+            debug_dir_base, f"sqlmesh_state_{timestamp}_{name_suffix}"
+        )
+
+        try:
+            shutil.copytree(self.project_path, run_debug_dir, dirs_exist_ok=True)
+            logger.info(f"Saved SQLMesh project debug state to {run_debug_dir}")
+            return run_debug_dir
+        except Exception as e:
+            logger.error(
+                f"Error saving SQLMesh project debug state to {run_debug_dir}: {e}"
+            )
+            raise
+
     def query(self, *args: t.Any, return_df: bool = False, **kwargs: t.Any) -> t.Any:
         """Execute a query against the test database.

@@ -477,7 +538,8 @@ def model_change_test_context(
 class DagsterTestContext:
     """A test context for running Dagster"""

-    project_path: str
+    dagster_project_path: str
+    sqlmesh_project_path: str

     def _run_command(self, cmd: list[str]) -> None:
         """Execute a command and stream its output in real-time.
@@ -488,76 +550,93 @@ def _run_command(self, cmd: list[str]) -> None:
         Raises:
             subprocess.CalledProcessError: If the command returns non-zero exit code
         """
+        import io
         import queue
         import threading
         import typing as t

         def stream_output(
-            pipe: t.IO[str], output_queue: queue.Queue[str | None]
+            pipe: t.IO[str], output_queue: queue.Queue[tuple[str, str | None]]
         ) -> None:
-            """Stream output from a pipe to a queue."""
+            """Stream output from a pipe to a queue.
+
+            Args:
+                pipe: The pipe to read from (stdout or stderr)
+                output_queue: Queue to write output to, as (stream_type, line) tuples
+            """
+            # Use a StringIO buffer to accumulate characters into lines
+            buffer = io.StringIO()
+            stream_type = "stdout" if pipe is process.stdout else "stderr"
+
             try:
                 while True:
                     char = pipe.read(1)
                     if not char:
+                        # Flush any remaining content in buffer
+                        remaining = buffer.getvalue()
+                        if remaining:
+                            output_queue.put((stream_type, remaining))
                         break
-                    output_queue.put(char)
+
+                    buffer.write(char)
+
+                    # If we hit a newline, flush the buffer
+                    if char == "\n":
+                        output_queue.put((stream_type, buffer.getvalue()))
+                        buffer = io.StringIO()
             finally:
-                output_queue.put(None)  # Signal EOF
+                buffer.close()
+                output_queue.put((stream_type, None))  # Signal EOF

         print(f"Running command: {' '.join(cmd)}")
+        print(f"Current working directory: {os.getcwd()}")
+        print(f"Changing to directory: {self.dagster_project_path}")
+
+        # Change to the dagster project directory before running the command
+        os.chdir(self.dagster_project_path)
+
         process = subprocess.Popen(
             cmd,
             stdout=subprocess.PIPE,
             stderr=subprocess.PIPE,
             text=True,
             universal_newlines=True,
+            encoding="utf-8",
+            errors="replace",
         )

         if not process.stdout or not process.stderr:
             raise RuntimeError("Failed to open subprocess pipes")

-        # Create queues for stdout and stderr
-        stdout_queue: queue.Queue[str | None] = queue.Queue()
-        stderr_queue: queue.Queue[str | None] = queue.Queue()
+        # Create a single queue for all output
+        output_queue: queue.Queue[tuple[str, str | None]] = queue.Queue()

         # Start threads to read from pipes
         stdout_thread = threading.Thread(
-            target=stream_output, args=(process.stdout, stdout_queue)
+            target=stream_output, args=(process.stdout, output_queue)
         )
         stderr_thread = threading.Thread(
-            target=stream_output, args=(process.stderr, stderr_queue)
+            target=stream_output, args=(process.stderr, output_queue)
         )

         stdout_thread.daemon = True
         stderr_thread.daemon = True
         stdout_thread.start()
         stderr_thread.start()

-        # Read from queues and print output
-        stdout_done = False
-        stderr_done = False
+        # Track which streams are still active
+        active_streams = {"stdout", "stderr"}

-        while not (stdout_done and stderr_done):
-            # Handle stdout
+        # Read from queue and print output
+        while active_streams:
             try:
-                char = stdout_queue.get_nowait()
-                if char is None:
-                    stdout_done = True
+                stream_type, content = output_queue.get(timeout=0.1)
+                if content is None:
+                    active_streams.remove(stream_type)
                 else:
-                    print(char, end="", flush=True)
+                    print(content, end="", flush=True)
             except queue.Empty:
-                pass
-
-            # Handle stderr
-            try:
-                char = stderr_queue.get_nowait()
-                if char is None:
-                    stderr_done = True
-                else:
-                    print(char, end="", flush=True)
-            except queue.Empty:
-                pass
+                continue

         stdout_thread.join()
         stderr_thread.join()
@@ -583,7 +662,10 @@ def asset_materialisation(
             "resources": {
                 "sqlmesh": {
                     "config": {
-                        "config": {"gateway": "local", "path": self.project_path}
+                        "config": {
+                            "gateway": "local",
+                            "path": self.sqlmesh_project_path,
+                        }
                     }
                 }
             }
@@ -610,7 +692,7 @@ def asset_materialisation(
                 "asset",
                 "materialize",
                 "-f",
-                os.path.join(self.project_path, "definitions.py"),
+                os.path.join(self.dagster_project_path, "definitions.py"),
                 "--select",
                 ",".join(assets),
                 "--config-json",
@@ -633,7 +715,10 @@ def sample_dagster_test_context(
     sample_dagster_project: str,
 ) -> t.Iterator[DagsterTestContext]:
     test_context = DagsterTestContext(
-        project_path=os.path.join(sample_dagster_project),
+        dagster_project_path=os.path.join(sample_dagster_project),
+        sqlmesh_project_path=os.path.join(
+            sample_dagster_project.replace("dagster_project", "sqlmesh_project")
+        ),
     )
     yield test_context
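
For context, a minimal sketch of how the new helper and the split project paths are meant to be used from a test. The fixture, method, and asset names are taken from this diff; the test body itself is illustrative:

from dagster_sqlmesh.conftest import DagsterTestContext, SQLMeshTestContext


def test_debug_state_roundtrip(
    sample_sqlmesh_test_context: SQLMeshTestContext,
    sample_dagster_test_context: DagsterTestContext,
) -> None:
    # Snapshot the SQLMesh project before touching anything.
    before_dir = sample_sqlmesh_test_context.save_sqlmesh_debug_state("before_change")

    # Materialise an asset through the Dagster CLI wrapper.
    sample_dagster_test_context.asset_materialisation(assets=["seed_model_1"])

    # Snapshot again afterwards; both copies land under debug_runs/.
    after_dir = sample_sqlmesh_test_context.save_sqlmesh_debug_state("after_materialise")
    assert before_dir != after_dir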

dagster_sqlmesh/controller/tests_plan_and_run/test_model_code_change.py

Lines changed: 56 additions & 3 deletions
@@ -3,6 +3,7 @@
 import pytest

 from dagster_sqlmesh.conftest import DagsterTestContext, SQLMeshTestContext
+from dagster_sqlmesh.controller.base import PlanOptions

 logger = logging.getLogger(__name__)

@@ -192,11 +193,63 @@ def test_given_model_chain_when_running_with_different_flags_then_behaves_as_exp

     # sample_dagster_test_context.init_test_source()

-    sample_dagster_test_context.asset_materialisation(assets=["seed_model_1"])
+    sample_dagster_test_context.asset_materialisation(
+        assets=[
+            "test_source",
+            "seed_model_1",
+            "seed_model_2",
+            "staging_model_1",
+            "staging_model_2",
+            "intermediate_model_1",
+            "full_model",
+        ],
+        plan_options=PlanOptions(
+            enable_preview=True,
+        ),
+    )
+
+    # # # Modify intermediate_model_1 sql to cause breaking change
+    # sample_sqlmesh_test_context.modify_model_file(
+    #     "intermediate_model_1.sql",
+    #     """
+    #     MODEL (
+    #         name sqlmesh_example.intermediate_model_1,
+    #         kind INCREMENTAL_BY_TIME_RANGE (
+    #             time_column event_date
+    #         ),
+    #         start '2020-01-01',
+    #         cron '@daily',
+    #         grain (id, event_date)
+    #     );
+
+    #     SELECT
+    #         main.id,
+    #         main.item_id,
+    #         main.event_date,
+    #         CONCAT('item - ', main.item_id) as item_name
+    #     FROM sqlmesh_example.staging_model_1 AS main
+    #     INNER JOIN sqlmesh_example.staging_model_2 as sub
+    #         ON main.id = sub.id
+    #     WHERE
+    #         event_date BETWEEN @start_date AND @end_date
+    #     """,
+    # )
+
+    # sample_dagster_test_context.asset_materialisation(assets=["intermediate_model_1"], plan_options=PlanOptions(skip_backfill=True, enable_preview=True, skip_tests=True))

-    # sample_dagster_test_context.asset_materialisation(assets=["test_source", "seed_model_1", "seed_model_2", "staging_model_1", "staging_model_2", "intermediate_model_1", "full_model"])
+    intermediate_model_1_df = (
+        sample_sqlmesh_test_context.query(
+            """
+            SELECT *
+            FROM sqlmesh_example.intermediate_model_1
+            """,
+            return_df=True,
+        )
+        .sort_values(by="item_id")
+        .reset_index(drop=True)
+    )

-    # sample_dagster_test_context.asset_materialisation(assets=["intermediate_model_1"])
+    print(f"intermediate_model_1_df:\n{intermediate_model_1_df}")


 if __name__ == "__main__":
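
A natural follow-up to the print above is an explicit assertion. A sketch, assuming query(return_df=True) returns a pandas-style DataFrame (as the sort_values call suggests) and taking the column names from the model definition shown in the commented-out block:

    assert not intermediate_model_1_df.empty
    assert {"id", "item_id", "event_date", "item_name"}.issubset(
        intermediate_model_1_df.columns
    )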

sample/dagster_project/definitions.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@

 CURR_DIR = os.path.dirname(__file__)
 SQLMESH_PROJECT_PATH = os.path.abspath(os.path.join(CURR_DIR, "../sqlmesh_project"))
-DUCKDB_PATH = os.path.join(CURR_DIR, "../../db.db")
+DUCKDB_PATH = os.path.join(SQLMESH_PROJECT_PATH, "db.db")

 sqlmesh_config = SQLMeshContextConfig(path=SQLMESH_PROJECT_PATH, gateway="local")
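
The effect of this one-line change is easiest to see by resolving both expressions against the sample layout (the absolute prefix is illustrative):

import os

CURR_DIR = "/path/to/repo/sample/dagster_project"  # os.path.dirname(__file__) for definitions.py
SQLMESH_PROJECT_PATH = os.path.abspath(os.path.join(CURR_DIR, "../sqlmesh_project"))
# -> /path/to/repo/sample/sqlmesh_project

old_duckdb_path = os.path.join(CURR_DIR, "../../db.db")
# -> effectively /path/to/repo/db.db, outside both sample projects

new_duckdb_path = os.path.join(SQLMESH_PROJECT_PATH, "db.db")
# -> /path/to/repo/sample/sqlmesh_project/db.db, next to the SQLMesh project files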
