diff --git a/data_platform/assets/ml/discord_alerts.py b/data_platform/assets/ml/discord_alerts.py
index 79829ad..bf4cb62 100644
--- a/data_platform/assets/ml/discord_alerts.py
+++ b/data_platform/assets/ml/discord_alerts.py
@@ -118,6 +118,6 @@ def listing_alert(
     return MaterializeResult(
         metadata={
             "notified": sent,
-            "min_elo_threshold": MetadataValue.float(config.min_elo),
+            "min_elo_threshold": MetadataValue.float(float(config.min_elo)),
         }
     )
diff --git a/tests/test_assets_elo.py b/tests/test_assets_elo.py
new file mode 100644
index 0000000..2080869
--- /dev/null
+++ b/tests/test_assets_elo.py
@@ -0,0 +1,87 @@
+"""Tests for ELO rating and comparison assets."""
+
+from unittest.mock import MagicMock, patch
+
+from dagster import build_asset_context
+
+from data_platform.assets.elo.elo import _ensure_schema, elo_comparisons, elo_ratings
+
+
+class TestEnsureSchema:
+    def test_executes_create_schema_sql(self):
+        conn = MagicMock()
+        with patch(
+            "data_platform.assets.elo.elo.render_sql",
+            return_value="CREATE SCHEMA IF NOT EXISTS elo",
+        ):
+            _ensure_schema(conn)
+        conn.execute.assert_called_once()
+
+
+class TestEloRatings:
+    @patch(
+        "data_platform.assets.elo.elo.render_sql",
+        return_value="CREATE TABLE IF NOT EXISTS elo.ratings ()",
+    )
+    def test_creates_table_and_returns_metadata(self, mock_render):
+        postgres = MagicMock()
+        engine = MagicMock()
+        conn = MagicMock()
+        engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
+        engine.begin.return_value.__exit__ = MagicMock(return_value=False)
+        postgres.get_engine.return_value = engine
+
+        context = build_asset_context()
+        result = elo_ratings(context, postgres)
+
+        assert result.metadata["schema"].value == "elo"
+        assert result.metadata["table"].value == "ratings"
+        # Two calls: _ensure_schema + create table
+        assert conn.execute.call_count == 2
+
+    @patch("data_platform.assets.elo.elo.render_sql", return_value="SQL")
+    def test_calls_get_engine(self, mock_render):
+        postgres = MagicMock()
+        engine = MagicMock()
+        conn = MagicMock()
+        engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
+        engine.begin.return_value.__exit__ = MagicMock(return_value=False)
+        postgres.get_engine.return_value = engine
+
+        context = build_asset_context()
+        elo_ratings(context, postgres)
+        postgres.get_engine.assert_called_once()
+
+
+class TestEloComparisons:
+    @patch(
+        "data_platform.assets.elo.elo.render_sql",
+        return_value="CREATE TABLE IF NOT EXISTS elo.comparisons ()",
+    )
+    def test_creates_table_and_returns_metadata(self, mock_render):
+        postgres = MagicMock()
+        engine = MagicMock()
+        conn = MagicMock()
+        engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
+        engine.begin.return_value.__exit__ = MagicMock(return_value=False)
+        postgres.get_engine.return_value = engine
+
+        context = build_asset_context()
+        result = elo_comparisons(context, postgres)
+
+        assert result.metadata["schema"].value == "elo"
+        assert result.metadata["table"].value == "comparisons"
+        assert conn.execute.call_count == 2
+
+    @patch("data_platform.assets.elo.elo.render_sql", return_value="SQL")
+    def test_calls_get_engine(self, mock_render):
+        postgres = MagicMock()
+        engine = MagicMock()
+        conn = MagicMock()
+        engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
+        engine.begin.return_value.__exit__ = MagicMock(return_value=False)
+        postgres.get_engine.return_value = engine
+
+        context = build_asset_context()
+        elo_comparisons(context, postgres)
+        postgres.get_engine.assert_called_once()
diff --git a/tests/test_dagster.py b/tests/test_dagster.py
new file mode 100644
index 0000000..2b805f2
--- /dev/null
+++ b/tests/test_dagster.py
@@ -0,0 +1,121 @@
+"""Generic tests for Dagster definitions — schedules, jobs, and the Definitions object."""
+
+import pytest
+from dagster import DefaultScheduleStatus
+
+from data_platform.jobs import (
+    elementary_refresh_job,
+    funda_ingestion_job,
+    funda_raw_quality_job,
+)
+from data_platform.schedules import (
+    elementary_refresh_schedule,
+    funda_ingestion_schedule,
+    funda_raw_quality_schedule,
+)
+
+ALL_SCHEDULES = [
+    elementary_refresh_schedule,
+    funda_ingestion_schedule,
+    funda_raw_quality_schedule,
+]
+
+ALL_JOBS = [
+    elementary_refresh_job,
+    funda_ingestion_job,
+    funda_raw_quality_job,
+]
+
+
+# ---------------------------------------------------------------------------
+# Generic schedule tests
+# ---------------------------------------------------------------------------
+
+
+class TestSchedulesGeneric:
+    """Property tests that apply to every schedule."""
+
+    @pytest.mark.parametrize("schedule", ALL_SCHEDULES, ids=lambda s: s.name)
+    def test_has_name(self, schedule):
+        assert schedule.name
+
+    @pytest.mark.parametrize("schedule", ALL_SCHEDULES, ids=lambda s: s.name)
+    def test_has_valid_cron(self, schedule):
+        parts = schedule.cron_schedule.split()
+        assert len(parts) == 5, f"Expected 5-part cron, got {schedule.cron_schedule}"
+
+    @pytest.mark.parametrize("schedule", ALL_SCHEDULES, ids=lambda s: s.name)
+    def test_has_job(self, schedule):
+        assert schedule.job is not None or schedule.job_name is not None
+
+    @pytest.mark.parametrize("schedule", ALL_SCHEDULES, ids=lambda s: s.name)
+    def test_default_status_running(self, schedule):
+        assert schedule.default_status == DefaultScheduleStatus.RUNNING
+
+
+# ---------------------------------------------------------------------------
+# Generic job tests
+# ---------------------------------------------------------------------------
+
+
+class TestJobsGeneric:
+    """Property tests that apply to every job."""
+
+    @pytest.mark.parametrize("job", ALL_JOBS, ids=lambda j: j.name)
+    def test_has_name(self, job):
+        assert job.name
+
+    @pytest.mark.parametrize("job", ALL_JOBS, ids=lambda j: j.name)
+    def test_has_description(self, job):
+        assert job.description
+
+
+# ---------------------------------------------------------------------------
+# Schedule-specific tests
+# ---------------------------------------------------------------------------
+
+
+class TestScheduleSpecific:
+    def test_elementary_schedule_daily(self):
+        assert elementary_refresh_schedule.cron_schedule == "0 9 * * *"
+
+    def test_funda_ingestion_every_4_hours(self):
+        assert funda_ingestion_schedule.cron_schedule == "0 */4 * * *"
+
+    def test_funda_quality_daily(self):
+        assert funda_raw_quality_schedule.cron_schedule == "0 8 * * *"
+
+    def test_funda_ingestion_schedule_has_run_config_fn(self):
+        assert funda_ingestion_schedule._run_config_fn is not None
+
+
+# ---------------------------------------------------------------------------
+# Definitions integration test
+# ---------------------------------------------------------------------------
+
+
+class TestDefinitions:
+    def test_definitions_loads(self):
+        from data_platform.definitions import defs
+
+        assert defs is not None
+
+    def test_definitions_has_assets(self):
+        from data_platform.definitions import defs
+
+        repo = defs.get_repository_def()
+        asset_keys = repo.asset_graph.get_all_asset_keys()
+        assert len(asset_keys) > 0
+
+    def test_definitions_has_jobs(self):
+        from data_platform.definitions import defs
+
+        job = defs.resolve_job_def("funda_ingestion")
+        assert job is not None
+
+    def test_definitions_has_resources(self):
+        from data_platform.definitions import defs
+
+        repo = defs.get_repository_def()
+        # Resources are configured but we can verify they're present
+        assert repo is not None
diff --git a/tests/test_ml.py b/tests/test_ml.py
index 7f85cc3..955b980 100644
--- a/tests/test_ml.py
+++ b/tests/test_ml.py
@@ -1,18 +1,24 @@
-"""Tests for ML helper functions — preprocess, _best_run, _build_embed."""
+"""Tests for ML helper functions and asset functions."""
 
 from types import SimpleNamespace
 from unittest.mock import MagicMock, patch
 
+import numpy as np
 import pandas as pd
 import pytest
+from dagster import build_asset_context
 
 from data_platform.assets.ml.elo_model import (
     ALL_FEATURES,
     ENERGY_LABEL_MAP,
+    _cleanup_old_runs,
     preprocess,
 )
-from data_platform.assets.ml.discord_alerts import _build_embed
-from data_platform.assets.ml.elo_inference import _best_run
+from data_platform.assets.ml.discord_alerts import (
+    DiscordNotificationConfig,
+    _build_embed,
+)
+from data_platform.assets.ml.elo_inference import EloInferenceConfig, _best_run
 
 
 # ---------------------------------------------------------------------------
@@ -201,3 +207,381 @@ class TestBuildEmbed:
     def test_color_is_green(self):
         embed = _build_embed(self._make_row())
         assert embed["color"] == 0x00B894
+
+
+# ---------------------------------------------------------------------------
+# _cleanup_old_runs
+# ---------------------------------------------------------------------------
+
+
+class TestCleanupOldRuns:
+    @patch("data_platform.assets.ml.elo_model.mlflow")
+    def test_no_experiment_noop(self, mock_mlflow):
+        client = mock_mlflow.tracking.MlflowClient.return_value
+        client.get_experiment_by_name.return_value = None
+        context = MagicMock()
+        _cleanup_old_runs("nonexistent", context, keep=3)
+        client.delete_run.assert_not_called()
+
+    @patch("data_platform.assets.ml.elo_model.mlflow")
+    def test_fewer_than_keep_noop(self, mock_mlflow):
+        client = mock_mlflow.tracking.MlflowClient.return_value
+        client.get_experiment_by_name.return_value = MagicMock(experiment_id="1")
+        client.search_runs.return_value = [MagicMock(), MagicMock()]
+        context = MagicMock()
+        _cleanup_old_runs("exp", context, keep=3)
+        client.delete_run.assert_not_called()
+
+    @patch("data_platform.assets.ml.elo_model.mlflow")
+    def test_deletes_old_runs(self, mock_mlflow):
+        client = mock_mlflow.tracking.MlflowClient.return_value
+        client.get_experiment_by_name.return_value = MagicMock(experiment_id="1")
+        runs = [MagicMock() for _ in range(5)]
+        for i, run in enumerate(runs):
+            run.info.run_id = f"run_{i}"
+        client.search_runs.return_value = runs
+        context = MagicMock()
+        _cleanup_old_runs("exp", context, keep=2)
+        assert client.delete_run.call_count == 3
+        client.delete_run.assert_any_call("run_2")
+        client.delete_run.assert_any_call("run_3")
+        client.delete_run.assert_any_call("run_4")
+
+    @patch("data_platform.assets.ml.elo_model.mlflow")
+    def test_exactly_keep_noop(self, mock_mlflow):
+        client = mock_mlflow.tracking.MlflowClient.return_value
+        client.get_experiment_by_name.return_value = MagicMock(experiment_id="1")
+        client.search_runs.return_value = [MagicMock(), MagicMock(), MagicMock()]
+        context = MagicMock()
+        _cleanup_old_runs("exp", context, keep=3)
+        client.delete_run.assert_not_called()
+
+
+# ---------------------------------------------------------------------------
+# elo_prediction_model asset
+# ---------------------------------------------------------------------------
+
+
+def _make_training_df(n=20):
+    """Build a small training DataFrame with all required columns."""
+    rng = np.random.default_rng(42)
+    data = {
+        "energy_label": rng.choice(["A", "B", "C", "D"], size=n),
+        "current_price": rng.integers(200_000, 500_000, size=n).astype(float),
+        "living_area": rng.integers(40, 150, size=n).astype(float),
+        "plot_area": rng.integers(0, 300, size=n).astype(float),
+        "bedrooms": rng.integers(1, 5, size=n).astype(float),
+        "rooms": rng.integers(2, 8, size=n).astype(float),
+        "construction_year": rng.integers(1900, 2025, size=n).astype(float),
+        "latitude": rng.uniform(51, 53, size=n),
+        "longitude": rng.uniform(3, 6, size=n),
+        "photo_count": rng.integers(1, 30, size=n).astype(float),
+        "views": rng.integers(10, 500, size=n).astype(float),
+        "saves": rng.integers(0, 50, size=n).astype(float),
+        "price_per_sqm": rng.integers(2000, 6000, size=n).astype(float),
+        "has_garden": rng.choice([True, False, None], size=n),
+        "has_balcony": rng.choice([True, False], size=n),
+        "has_solar_panels": rng.choice([True, False, None], size=n),
+        "has_heat_pump": rng.choice([True, False], size=n),
+        "has_roof_terrace": rng.choice([True, False, None], size=n),
+        "is_energy_efficient": rng.choice([True, False], size=n),
+        "is_monument": rng.choice([True, False], size=n),
+        "elo_rating": rng.uniform(1200, 1800, size=n),
+    }
+    return pd.DataFrame(data)
+
+
+class TestEloPredictionModelAsset:
+    @patch("data_platform.assets.ml.elo_model._cleanup_old_runs")
+    @patch("data_platform.assets.ml.elo_model.mlflow")
+    def test_trains_and_returns_metadata(self, mock_mlflow, mock_cleanup):
+        # Setup MLflow mock
+        mock_run = MagicMock()
+        mock_run.info.run_id = "test-run-123"
+        mock_mlflow.start_run.return_value.__enter__ = MagicMock(return_value=mock_run)
+        mock_mlflow.start_run.return_value.__exit__ = MagicMock(return_value=False)
+
+        # Setup Postgres mock
+        postgres = MagicMock()
+        engine = MagicMock()
+        postgres.get_engine.return_value = engine
+        mlflow_resource = MagicMock()
+        mlflow_resource.get_tracking_uri.return_value = "http://mlflow:5000"
+
+        df = _make_training_df(20)
+        context = build_asset_context()
+
+        from data_platform.assets.ml.elo_model import (
+            EloModelConfig,
+            elo_prediction_model,
+        )
+
+        config = EloModelConfig(n_estimators=10)
+
+        with patch("data_platform.assets.ml.elo_model.pd.read_sql", return_value=df):
+            result = elo_prediction_model(context, config, postgres, mlflow_resource)
+
+        assert result.metadata["mlflow_run_id"].value == "test-run-123"
+        assert "rmse" in result.metadata
+        assert "mae" in result.metadata
+        assert "r2" in result.metadata
+        assert result.metadata["train_rows"] > 0
+        assert result.metadata["test_rows"] > 0
+        mock_cleanup.assert_called_once()
+
+    @patch("data_platform.assets.ml.elo_model.mlflow")
+    def test_raises_with_too_few_rows(self, mock_mlflow):
+        postgres = MagicMock()
+        engine = MagicMock()
+        postgres.get_engine.return_value = engine
+        mlflow_resource = MagicMock()
+        mlflow_resource.get_tracking_uri.return_value = "http://mlflow:5000"
+
+        df = _make_training_df(5)  # fewer than 10
+        context = build_asset_context()
+
+        from data_platform.assets.ml.elo_model import (
+            EloModelConfig,
+            elo_prediction_model,
+        )
+
+        config = EloModelConfig()
+
+        with (
+            patch("data_platform.assets.ml.elo_model.pd.read_sql", return_value=df),
+            pytest.raises(ValueError, match="Not enough rated listings"),
+        ):
+            elo_prediction_model(context, config, postgres, mlflow_resource)
+
+
+# ---------------------------------------------------------------------------
+# elo_inference asset
+# ---------------------------------------------------------------------------
+
+
+class TestEloInferenceAsset:
+    def _make_unscored_df(self, n=5):
+        """Build a DataFrame of unscored listings."""
+        rng = np.random.default_rng(42)
+        data = {
+            "global_id": [f"listing_{i}" for i in range(n)],
+            "energy_label": rng.choice(["A", "B", "C"], size=n),
+            "current_price": rng.integers(200_000, 500_000, size=n).astype(float),
+            "living_area": rng.integers(40, 150, size=n).astype(float),
+            "plot_area": rng.integers(0, 300, size=n).astype(float),
+            "bedrooms": rng.integers(1, 5, size=n).astype(float),
+            "rooms": rng.integers(2, 8, size=n).astype(float),
+            "construction_year": rng.integers(1900, 2025, size=n).astype(float),
+            "latitude": rng.uniform(51, 53, size=n),
+            "longitude": rng.uniform(3, 6, size=n),
+            "photo_count": rng.integers(1, 30, size=n).astype(float),
+            "views": rng.integers(10, 500, size=n).astype(float),
+            "saves": rng.integers(0, 50, size=n).astype(float),
+            "price_per_sqm": rng.integers(2000, 6000, size=n).astype(float),
+            "has_garden": rng.choice([True, False], size=n),
+            "has_balcony": rng.choice([True, False], size=n),
+            "has_solar_panels": rng.choice([True, False], size=n),
+            "has_heat_pump": rng.choice([True, False], size=n),
+            "has_roof_terrace": rng.choice([True, False], size=n),
+            "is_energy_efficient": rng.choice([True, False], size=n),
+            "is_monument": rng.choice([True, False], size=n),
+        }
+        return pd.DataFrame(data)
+
+    @patch("data_platform.assets.ml.elo_inference.mlflow")
+    @patch("data_platform.assets.ml.elo_inference.pd.read_sql")
+    @patch("data_platform.assets.ml.elo_inference.render_sql", return_value="SQL")
+    def test_scores_listings(self, mock_render, mock_read_sql, mock_mlflow):
+        from data_platform.assets.ml.elo_inference import elo_inference
+
+        df = self._make_unscored_df(5)
+        mock_read_sql.return_value = df
+
+        # Mock MLflow best run
+        best_run = MagicMock()
+        best_run.info.run_id = "run-abc"
+        best_run.data.metrics = {"rmse": 0.5}
+        client = mock_mlflow.tracking.MlflowClient.return_value
+        client.get_experiment_by_name.return_value = MagicMock(experiment_id="1")
+        client.search_runs.return_value = [best_run]
+
+        # Mock model
+        mock_model = MagicMock()
+        mock_model.predict.return_value = np.random.default_rng(42).uniform(
+            -3, 3, size=5
+        )
+        mock_mlflow.lightgbm.load_model.return_value = mock_model
+
+        # Mock resources
+        postgres = MagicMock()
+        engine = MagicMock()
+        conn = MagicMock()
+        engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
+        engine.begin.return_value.__exit__ = MagicMock(return_value=False)
+        postgres.get_engine.return_value = engine
+
+        mlflow_resource = MagicMock()
+        mlflow_resource.get_tracking_uri.return_value = "http://mlflow:5000"
+
+        context = build_asset_context()
+        config = EloInferenceConfig()
+
+        result = elo_inference(context, config, postgres, mlflow_resource)
+
+        assert result.metadata["scored"] == 5
+        assert result.metadata["mlflow_run_id"].value == "run-abc"
+        assert "predicted_elo_mean" in result.metadata
+        mock_model.predict.assert_called_once()
+
+    @patch("data_platform.assets.ml.elo_inference.pd.read_sql")
+    @patch("data_platform.assets.ml.elo_inference.render_sql", return_value="SQL")
+    def test_empty_returns_zero_scored(self, mock_render, mock_read_sql):
+        from data_platform.assets.ml.elo_inference import elo_inference
+
+        mock_read_sql.return_value = pd.DataFrame()
+
+        postgres = MagicMock()
+        engine = MagicMock()
+        conn = MagicMock()
+        engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
+        engine.begin.return_value.__exit__ = MagicMock(return_value=False)
+        postgres.get_engine.return_value = engine
+
+        mlflow_resource = MagicMock()
+        context = build_asset_context()
+        config = EloInferenceConfig()
+
+        result = elo_inference(context, config, postgres, mlflow_resource)
+
+        assert result.metadata["scored"] == 0
+        assert result.metadata["status"].value == "No new listings to score."
+
+
+# ---------------------------------------------------------------------------
+# listing_alert asset
+# ---------------------------------------------------------------------------
+
+
+class TestListingAlertAsset:
+    def _make_predictions_df(self, n=3):
+        """Build a DataFrame of high-ELO predictions."""
+        data = {
+            "global_id": [f"listing_{i}" for i in range(n)],
+            "predicted_elo": [1650.0 + i * 10 for i in range(n)],
+            "current_price": [350_000 + i * 10_000 for i in range(n)],
+            "city": ["Amsterdam", "Rotterdam", "Utrecht"][:n],
+            "living_area": [80 + i * 5 for i in range(n)],
+            "rooms": [4, 5, 3][:n],
+            "energy_label": ["A", "B", "C"][:n],
+            "price_per_sqm": [4375, 4000, 3800][:n],
+            "title": [f"Teststraat {i}" for i in range(n)],
+            "url": [f"https://funda.nl/{i}" for i in range(n)],
+        }
+        return pd.DataFrame(data)
+
+    @patch("data_platform.assets.ml.discord_alerts.requests.post")
+    @patch("data_platform.assets.ml.discord_alerts.pd.read_sql")
+    @patch("data_platform.assets.ml.discord_alerts.render_sql", return_value="SQL")
+    def test_sends_notifications(self, mock_render, mock_read_sql, mock_post):
+        from data_platform.assets.ml.discord_alerts import listing_alert
+
+        df = self._make_predictions_df(3)
+        mock_read_sql.return_value = df
+        mock_post.return_value = MagicMock(status_code=200)
+
+        postgres = MagicMock()
+        engine = MagicMock()
+        conn = MagicMock()
+        engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
+        engine.begin.return_value.__exit__ = MagicMock(return_value=False)
+        postgres.get_engine.return_value = engine
+
+        discord = MagicMock()
+        discord.get_webhook_url.return_value = "https://discord.com/webhook/test"
+
+        context = build_asset_context()
+        config = DiscordNotificationConfig()
+
+        result = listing_alert(context, config, postgres, discord)
+
+        assert result.metadata["notified"] == 3
+        mock_post.assert_called_once()
+
+    @patch("data_platform.assets.ml.discord_alerts.pd.read_sql")
+    @patch("data_platform.assets.ml.discord_alerts.render_sql", return_value="SQL")
+    def test_empty_returns_zero_notified(self, mock_render, mock_read_sql):
+        from data_platform.assets.ml.discord_alerts import listing_alert
+
+        mock_read_sql.return_value = pd.DataFrame()
+
+        postgres = MagicMock()
+        engine = MagicMock()
+        conn = MagicMock()
+        engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
+        engine.begin.return_value.__exit__ = MagicMock(return_value=False)
+        postgres.get_engine.return_value = engine
+
+        discord = MagicMock()
+        context = build_asset_context()
+        config = DiscordNotificationConfig()
+
+        result = listing_alert(context, config, postgres, discord)
+
+        assert result.metadata["notified"] == 0
+
+    @patch("data_platform.assets.ml.discord_alerts.requests.post")
+    @patch("data_platform.assets.ml.discord_alerts.pd.read_sql")
+    @patch("data_platform.assets.ml.discord_alerts.render_sql", return_value="SQL")
+    def test_batches_large_result_set(self, mock_render, mock_read_sql, mock_post):
+        from data_platform.assets.ml.discord_alerts import listing_alert
+
+        # 15 listings → should produce 2 batches (10 + 5)
+        data = {
+            "global_id": [f"listing_{i}" for i in range(15)],
+            "predicted_elo": [1700.0] * 15,
+            "current_price": [400_000] * 15,
+            "city": ["Amsterdam"] * 15,
+            "living_area": [90] * 15,
+            "rooms": [4] * 15,
+            "energy_label": ["A"] * 15,
+            "price_per_sqm": [4444] * 15,
+            "title": [f"Street {i}" for i in range(15)],
+            "url": [f"https://funda.nl/{i}" for i in range(15)],
+        }
+        mock_read_sql.return_value = pd.DataFrame(data)
+        mock_post.return_value = MagicMock(status_code=200)
+
+        postgres = MagicMock()
+        engine = MagicMock()
+        conn = MagicMock()
+        engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
+        engine.begin.return_value.__exit__ = MagicMock(return_value=False)
+        postgres.get_engine.return_value = engine
+
+        discord = MagicMock()
+        discord.get_webhook_url.return_value = "https://discord.com/webhook/test"
+
+        context = build_asset_context()
+        config = DiscordNotificationConfig()
+
+        result = listing_alert(context, config, postgres, discord)
+
+        assert result.metadata["notified"] == 15
+        assert mock_post.call_count == 2
+
+
+# ---------------------------------------------------------------------------
+# Config defaults
+# ---------------------------------------------------------------------------
+
+
+class TestMLConfigs:
+    def test_elo_inference_config_defaults(self):
+        config = EloInferenceConfig()
+        assert config.mlflow_experiment == "elo-rating-prediction"
+        assert config.metric == "rmse"
+        assert config.ascending is True
+
+    def test_discord_notification_config_defaults(self):
+        config = DiscordNotificationConfig()
+        assert config.min_elo == 1600
diff --git a/tests/test_ops.py b/tests/test_ops.py
new file mode 100644
index 0000000..138bf4a
--- /dev/null
+++ b/tests/test_ops.py
@@ -0,0 +1,153 @@
+"""Tests for Dagster ops — elementary and source freshness."""
+
+import subprocess
+from unittest.mock import patch
+
+import pytest
+from dagster import build_op_context
+
+from data_platform.ops.check_source_freshness import (
+    SourceFreshnessConfig,
+)
+from data_platform.ops.elementary import (
+    _elementary_schema_exists,
+    elementary_generate_report,
+    elementary_run_models,
+)
+
+
+# ---------------------------------------------------------------------------
+# SourceFreshnessConfig
+# ---------------------------------------------------------------------------
+
+
+class TestSourceFreshnessConfig:
+    def test_accepts_source_name(self):
+        cfg = SourceFreshnessConfig(source_name="raw_funda")
+        assert cfg.source_name == "raw_funda"
+
+
+# ---------------------------------------------------------------------------
+# _elementary_schema_exists
+# ---------------------------------------------------------------------------
+
+
+class TestElementarySchemaExists:
+    @patch("data_platform.resources._retry_on_operational_error")
+    @patch("data_platform.ops.elementary.create_engine")
+    def test_returns_true_when_schema_present(self, mock_create_engine, mock_retry):
+        mock_retry.return_value = True
+        with patch.dict(
+            "os.environ",
+            {
+                "POSTGRES_USER": "u",
+                "POSTGRES_PASSWORD": "p",
+                "POSTGRES_HOST": "localhost",
+                "POSTGRES_PORT": "5432",
+                "POSTGRES_DB": "db",
+            },
+        ):
+            result = _elementary_schema_exists()
+        assert result is True
+        mock_create_engine.assert_called_once()
+
+    @patch("data_platform.resources._retry_on_operational_error")
+    @patch("data_platform.ops.elementary.create_engine")
+    def test_returns_false_when_schema_absent(self, mock_create_engine, mock_retry):
+        mock_retry.return_value = False
+        with patch.dict(
+            "os.environ",
+            {
+                "POSTGRES_USER": "u",
+                "POSTGRES_PASSWORD": "p",
+                "POSTGRES_HOST": "localhost",
+                "POSTGRES_PORT": "5432",
+                "POSTGRES_DB": "db",
+            },
+        ):
+            result = _elementary_schema_exists()
+        assert result is False
+
+
+# ---------------------------------------------------------------------------
+# elementary_run_models
+# ---------------------------------------------------------------------------
+
+
+class TestElementaryRunModels:
+    @patch("data_platform.ops.elementary._elementary_schema_exists", return_value=True)
+    def test_skips_when_schema_exists(self, mock_exists):
+        context = build_op_context()
+        elementary_run_models(context)
+        mock_exists.assert_called_once()
+
+    @patch(
+        "data_platform.ops.elementary.subprocess.run",
+        return_value=subprocess.CompletedProcess(
+            args=[], returncode=0, stdout="ok", stderr=""
+        ),
+    )
+    @patch("data_platform.ops.elementary._elementary_schema_exists", return_value=False)
+    def test_runs_dbt_when_schema_missing(self, mock_exists, mock_run):
+        context = build_op_context()
+        elementary_run_models(context)
+        mock_run.assert_called_once()
+        args = mock_run.call_args[0][0]
+        assert "dbt" in args
+        assert "run" in args
+        assert "elementary" in args
+
+    @patch(
+        "data_platform.ops.elementary.subprocess.run",
+        return_value=subprocess.CompletedProcess(
+            args=[], returncode=1, stdout="", stderr="error"
+        ),
+    )
+    @patch("data_platform.ops.elementary._elementary_schema_exists", return_value=False)
+    def test_raises_on_dbt_failure(self, mock_exists, mock_run):
+        context = build_op_context()
+        with pytest.raises(Exception, match="dbt run elementary failed"):
+            elementary_run_models(context)
+
+
+# ---------------------------------------------------------------------------
+# elementary_generate_report
+# ---------------------------------------------------------------------------
+
+
+class TestElementaryGenerateReport:
+    @patch(
+        "data_platform.ops.elementary.subprocess.run",
+        return_value=subprocess.CompletedProcess(
+            args=[], returncode=0, stdout="report generated", stderr=""
+        ),
+    )
+    def test_calls_edr_report(self, mock_run):
+        context = build_op_context()
+        elementary_generate_report(context)
+        mock_run.assert_called_once()
+        args = mock_run.call_args[0][0]
+        assert "edr" in args
+        assert "report" in args
+
+    @patch(
+        "data_platform.ops.elementary.subprocess.run",
+        return_value=subprocess.CompletedProcess(
+            args=[], returncode=1, stdout="", stderr="fatal error"
+        ),
+    )
+    def test_raises_on_failure(self, mock_run):
+        context = build_op_context()
+        with pytest.raises(Exception, match="edr report failed"):
+            elementary_generate_report(context)
+
+    @patch(
+        "data_platform.ops.elementary.subprocess.run",
+        return_value=subprocess.CompletedProcess(
+            args=[], returncode=0, stdout="done", stderr=""
+        ),
+    )
+    def test_success_returns_none(self, mock_run):
+        context = build_op_context()
+        result = elementary_generate_report(context)
+        assert result is None