diff --git a/dagster_sqlmesh/resource.py b/dagster_sqlmesh/resource.py index 99609e7..fdd3f35 100644 --- a/dagster_sqlmesh/resource.py +++ b/dagster_sqlmesh/resource.py @@ -146,7 +146,9 @@ def __init__( self._prefix = prefix self._context = context self._logger = context.log - self._tracker = MaterializationTracker(sorted_dag=dag.sorted[:], logger=self._logger) + self._tracker = MaterializationTracker( + sorted_dag=dag.sorted[:], logger=self._logger + ) self._stage = "plan" self._errors: list[Exception] = [] self._is_testing = is_testing @@ -328,8 +330,7 @@ def run( logger = context.log controller = self.get_controller( - context_factory=context_factory, - log_override=logger + context_factory=context_factory, log_override=logger ) with controller.instance(environment) as mesh: @@ -341,10 +342,7 @@ def run( [model.fqn for model, _ in mesh.non_external_models_dag()] ) selected_models_set, models_map, select_models = ( - self._get_selected_models_from_context( - context=context, - models=models - ) + self._get_selected_models_from_context(context=context, models=models) ) if all_available_models == selected_models_set or select_models is None: @@ -356,11 +354,32 @@ def run( else: logger.info(f"selected models: {select_models}") - event_handler = DagsterSQLMeshEventHandler( - context=context, models_map=models_map, dag=dag, - prefix="sqlmesh: ", is_testing=self.is_testing + event_handler = self.create_event_handler( + context=context, + models_map=models_map, + dag=dag, + prefix="sqlmesh: ", + is_testing=self.is_testing, ) + def raise_for_sqlmesh_errors( + event_handler: DagsterSQLMeshEventHandler, + additional_errors: list[Exception] | None = None, + ) -> None: + additional_errors = additional_errors or [] + errors = event_handler.errors + if len(errors) + len(additional_errors) == 0: + return + for error in errors: + logger.error( + f"sqlmesh encountered the following error during sqlmesh {event_handler.stage}: {error}" + ) + raise PlanOrRunFailedError( + event_handler.stage, + f"sqlmesh failed during {event_handler.stage} with {len(event_handler.errors) + 1} errors", + [*errors, *additional_errors], + ) + try: for event in mesh.plan_and_run( start=start, @@ -376,16 +395,30 @@ def run( event_handler.process_events(event) except SQLMeshError as e: logger.error(f"sqlmesh error: {e}") - errors = event_handler.errors - for error in errors: - logger.error(f"sqlmesh encountered the following error during sqlmesh {event_handler.stage}: {error}") - raise PlanOrRunFailedError( - event_handler.stage, - f"sqlmesh failed during {event_handler.stage} with {len(event_handler.errors) + 1} errors", - [e, *event_handler.errors], - ) + raise_for_sqlmesh_errors(event_handler, [GenericSQLMeshError(str(e))]) + # Some errors do not raise exceptions immediately, so we need to check + # the event handler for any errors that may have been collected. + raise_for_sqlmesh_errors(event_handler) + yield from event_handler.notify_success(mesh.context) + def create_event_handler( + self, + *, + context: AssetExecutionContext, + dag: DAG[str], + models_map: dict[str, Model], + prefix: str, + is_testing: bool, + ) -> DagsterSQLMeshEventHandler: + return DagsterSQLMeshEventHandler( + context=context, + dag=dag, + models_map=models_map, + prefix=prefix, + is_testing=is_testing, + ) + def _get_selected_models_from_context( self, context: AssetExecutionContext, models: MappingProxyType[str, Model] ) -> tuple[set[str], dict[str, Model], list[str] | None]: @@ -421,5 +454,5 @@ def get_controller( return DagsterSQLMeshController.setup_with_config( config=self.config, context_factory=context_factory, - log_override=log_override + log_override=log_override, ) diff --git a/dagster_sqlmesh/test_resource.py b/dagster_sqlmesh/test_resource.py index bea2ee1..38a3924 100644 --- a/dagster_sqlmesh/test_resource.py +++ b/dagster_sqlmesh/test_resource.py @@ -1,6 +1,8 @@ +import typing as t + import dagster as dg -from dagster_sqlmesh.resource import PlanOrRunFailedError +from dagster_sqlmesh.resource import DagsterSQLMeshEventHandler, PlanOrRunFailedError from dagster_sqlmesh.testing import setup_testing_sqlmesh_test_context @@ -14,10 +16,19 @@ def test_sqlmesh_resource_should_report_no_errors( variables={"enable_model_failure": False} ) test_context.initialize_test_source() - resource = test_context.create_resource() + resource = test_context.create_resource() - for result in resource.run(dg_context): - pass + success = True + try: + for result in resource.run(dg_context): + pass + except PlanOrRunFailedError as e: + success = False + print(f"Plan or run failed with errors: {e.errors}") + except Exception as e: + success = False + print(f"An unexpected error occurred: {e}") + assert success, "Expected no errors, but an error was raised during the run." def test_sqlmesh_resource_properly_reports_errors( @@ -48,3 +59,40 @@ def test_sqlmesh_resource_properly_reports_errors( assert caught_failure, "Expected an error to be raised, but it was not." + +def test_sqlmesh_resource_properly_reports_errors_not_thrown( + sample_sqlmesh_project: str, sample_sqlmesh_db_path: str +): + dg_context = dg.build_asset_context() + test_context = setup_testing_sqlmesh_test_context( + db_path=sample_sqlmesh_db_path, + project_path=sample_sqlmesh_project, + variables={"enable_model_failure": False} + ) + test_context.initialize_test_source() + resource = test_context.create_resource() + def event_handler_factory(*args: t.Any, **kwargs: t.Any) -> DagsterSQLMeshEventHandler: + """Custom event handler factory for the SQLMesh resource.""" + handler = DagsterSQLMeshEventHandler(*args, **kwargs) + # Load it with an error + handler._errors = [Exception("testerror")] + return handler + resource.set_event_handler_factory(event_handler_factory) + + caught_failure = False + try: + for result in resource.run(dg_context): + pass + except PlanOrRunFailedError as e: + caught_failure = True + + expected_error_found = False + for err in e.errors: + print(f"Found error: {err}") + if "testerror" in str(err): + expected_error_found = True + break + assert expected_error_found, "Expected error 'testerror' not found in the error list." + + assert caught_failure, "Expected an error to be raised, but it was not." + diff --git a/dagster_sqlmesh/testing/context.py b/dagster_sqlmesh/testing/context.py index 4cb1166..dce47fa 100644 --- a/dagster_sqlmesh/testing/context.py +++ b/dagster_sqlmesh/testing/context.py @@ -17,7 +17,7 @@ from dagster_sqlmesh.controller.base import PlanOptions, RunOptions from dagster_sqlmesh.controller.dagster import DagsterSQLMeshController from dagster_sqlmesh.events import ConsoleRecorder -from dagster_sqlmesh.resource import SQLMeshResource +from dagster_sqlmesh.resource import DagsterSQLMeshEventHandler, SQLMeshResource logger = logging.getLogger(__name__) @@ -49,6 +49,41 @@ def setup_testing_sqlmesh_test_context( return SQLMeshTestContext(db_path=db_path, context_config=context_config) +class TestSQLMeshResource(SQLMeshResource): + """A test SQLMesh resource that can be used in tests. + + This resource is a subclass of SQLMeshResource and is used to run SQLMesh in tests. + It allows for easy setup and teardown of the SQLMesh context. + """ + + def __init__(self, config: SQLMeshContextConfig, is_testing: bool = False): + super().__init__(config=config, is_testing=is_testing) + def default_event_handler_factory(*args: t.Any, **kwargs: t.Any) -> DagsterSQLMeshEventHandler: + """Default event handler factory for the SQLMesh resource.""" + return DagsterSQLMeshEventHandler(*args, **kwargs) + self._event_handler_factory = default_event_handler_factory + + def set_event_handler_factory(self, event_handler_factory: t.Callable[..., DagsterSQLMeshEventHandler]) -> None: + """Set the event handler for the SQLMesh resource. + + Args: + event_handler (DagsterSQLMeshEventHandler): The event handler to set. + """ + self._event_handler_factory = event_handler_factory + + def create_event_handler(self, *args: t.Any, **kwargs: t.Any) -> DagsterSQLMeshEventHandler: + """Create a new event handler for the SQLMesh resource. + + Args: + *args: Positional arguments to pass to the event handler. + **kwargs: Keyword arguments to pass to the event handler. + + Returns: + DagsterSQLMeshEventHandler: The created event handler. + """ + return self._event_handler_factory(*args, **kwargs) + + @dataclass class SQLMeshTestContext: """A test context for running SQLMesh""" @@ -61,8 +96,8 @@ def create_controller(self) -> DagsterSQLMeshController[Context]: config=self.context_config, ) - def create_resource(self) -> SQLMeshResource: - return SQLMeshResource( + def create_resource(self) -> TestSQLMeshResource: + return TestSQLMeshResource( config=self.context_config, is_testing=True, ) diff --git a/uv.lock b/uv.lock index d65bb98..5874dc4 100644 --- a/uv.lock +++ b/uv.lock @@ -287,7 +287,7 @@ wheels = [ [[package]] name = "dagster-sqlmesh" -version = "0.15.0" +version = "0.16.0" source = { editable = "." } dependencies = [ { name = "dagster" },