test: expand unit tests for ml and dagster resources

Stijnvandenbroek
2026-03-11 13:52:53 +00:00
parent 6d0114aafc
commit adcc112f90
5 changed files with 749 additions and 4 deletions

tests/test_assets_elo.py (new file, 87 lines added)

@@ -0,0 +1,87 @@
"""Tests for ELO rating and comparison assets."""
from unittest.mock import MagicMock, patch
from dagster import build_asset_context
from data_platform.assets.elo.elo import _ensure_schema, elo_comparisons, elo_ratings
class TestEnsureSchema:
def test_executes_create_schema_sql(self):
conn = MagicMock()
with patch(
"data_platform.assets.elo.elo.render_sql",
return_value="CREATE SCHEMA IF NOT EXISTS elo",
):
_ensure_schema(conn)
conn.execute.assert_called_once()
class TestEloRatings:
@patch(
"data_platform.assets.elo.elo.render_sql",
return_value="CREATE TABLE IF NOT EXISTS elo.ratings ()",
)
def test_creates_table_and_returns_metadata(self, mock_render):
postgres = MagicMock()
engine = MagicMock()
conn = MagicMock()
engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
engine.begin.return_value.__exit__ = MagicMock(return_value=False)
postgres.get_engine.return_value = engine
context = build_asset_context()
result = elo_ratings(context, postgres)
assert result.metadata["schema"].value == "elo"
assert result.metadata["table"].value == "ratings"
# Two calls: _ensure_schema + create table
assert conn.execute.call_count == 2
@patch("data_platform.assets.elo.elo.render_sql", return_value="SQL")
def test_calls_get_engine(self, mock_render):
postgres = MagicMock()
engine = MagicMock()
conn = MagicMock()
engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
engine.begin.return_value.__exit__ = MagicMock(return_value=False)
postgres.get_engine.return_value = engine
context = build_asset_context()
elo_ratings(context, postgres)
postgres.get_engine.assert_called_once()
class TestEloComparisons:
@patch(
"data_platform.assets.elo.elo.render_sql",
return_value="CREATE TABLE IF NOT EXISTS elo.comparisons ()",
)
def test_creates_table_and_returns_metadata(self, mock_render):
postgres = MagicMock()
engine = MagicMock()
conn = MagicMock()
engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
engine.begin.return_value.__exit__ = MagicMock(return_value=False)
postgres.get_engine.return_value = engine
context = build_asset_context()
result = elo_comparisons(context, postgres)
assert result.metadata["schema"].value == "elo"
assert result.metadata["table"].value == "comparisons"
assert conn.execute.call_count == 2
@patch("data_platform.assets.elo.elo.render_sql", return_value="SQL")
def test_calls_get_engine(self, mock_render):
postgres = MagicMock()
engine = MagicMock()
conn = MagicMock()
engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
engine.begin.return_value.__exit__ = MagicMock(return_value=False)
postgres.get_engine.return_value = engine
context = build_asset_context()
elo_comparisons(context, postgres)
postgres.get_engine.assert_called_once()
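# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this commit): the engine.begin() mocking
# above recurs in every test. A shared helper along these lines would remove
# the duplication; the name make_mock_postgres is hypothetical. MagicMock is
# already imported at the top of this module.
# ---------------------------------------------------------------------------
def make_mock_postgres():
    """Return (postgres_resource, conn) whose engine.begin() yields the mocked conn."""
    conn = MagicMock()
    engine = MagicMock()
    # engine.begin() is used as a context manager; make it hand back the mocked connection.
    engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
    engine.begin.return_value.__exit__ = MagicMock(return_value=False)
    postgres = MagicMock()
    postgres.get_engine.return_value = engine
    return postgres, conn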

tests/test_dagster.py (new file, 121 lines added)

@@ -0,0 +1,121 @@
"""Generic tests for Dagster definitions — schedules, jobs, and the Definitions object."""
import pytest
from dagster import DefaultScheduleStatus
from data_platform.jobs import (
elementary_refresh_job,
funda_ingestion_job,
funda_raw_quality_job,
)
from data_platform.schedules import (
elementary_refresh_schedule,
funda_ingestion_schedule,
funda_raw_quality_schedule,
)
ALL_SCHEDULES = [
elementary_refresh_schedule,
funda_ingestion_schedule,
funda_raw_quality_schedule,
]
ALL_JOBS = [
elementary_refresh_job,
funda_ingestion_job,
funda_raw_quality_job,
]
# ---------------------------------------------------------------------------
# Generic schedule tests
# ---------------------------------------------------------------------------
class TestSchedulesGeneric:
"""Property tests that apply to every schedule."""
@pytest.mark.parametrize("schedule", ALL_SCHEDULES, ids=lambda s: s.name)
def test_has_name(self, schedule):
assert schedule.name
@pytest.mark.parametrize("schedule", ALL_SCHEDULES, ids=lambda s: s.name)
def test_has_valid_cron(self, schedule):
parts = schedule.cron_schedule.split()
assert len(parts) == 5, f"Expected 5-part cron, got {schedule.cron_schedule}"
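# A stricter variant could parse the expression instead of counting fields,
# e.g. with the croniter package (not a dependency declared by this commit):
#     from croniter import croniter
#     assert croniter.is_valid(schedule.cron_schedule)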
@pytest.mark.parametrize("schedule", ALL_SCHEDULES, ids=lambda s: s.name)
def test_has_job(self, schedule):
assert schedule.job is not None or schedule.job_name is not None
@pytest.mark.parametrize("schedule", ALL_SCHEDULES, ids=lambda s: s.name)
def test_default_status_running(self, schedule):
assert schedule.default_status == DefaultScheduleStatus.RUNNING
# ---------------------------------------------------------------------------
# Generic job tests
# ---------------------------------------------------------------------------
class TestJobsGeneric:
"""Property tests that apply to every job."""
@pytest.mark.parametrize("job", ALL_JOBS, ids=lambda j: j.name)
def test_has_name(self, job):
assert job.name
@pytest.mark.parametrize("job", ALL_JOBS, ids=lambda j: j.name)
def test_has_description(self, job):
assert job.description
# ---------------------------------------------------------------------------
# Schedule-specific tests
# ---------------------------------------------------------------------------
class TestScheduleSpecific:
def test_elementary_schedule_daily(self):
assert elementary_refresh_schedule.cron_schedule == "0 9 * * *"
def test_funda_ingestion_every_4_hours(self):
assert funda_ingestion_schedule.cron_schedule == "0 */4 * * *"
def test_funda_quality_daily(self):
assert funda_raw_quality_schedule.cron_schedule == "0 8 * * *"
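# Reading the crons asserted above: "0 9 * * *" fires daily at 09:00,
# "0 */4 * * *" at minute 0 of every fourth hour, and "0 8 * * *" daily at
# 08:00, in the schedule's execution timezone (UTC unless one is set).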
def test_funda_ingestion_schedule_has_run_config_fn(self):
assert funda_ingestion_schedule._run_config_fn is not None
# ---------------------------------------------------------------------------
# Definitions integration test
# ---------------------------------------------------------------------------
class TestDefinitions:
def test_definitions_loads(self):
from data_platform.definitions import defs
assert defs is not None
def test_definitions_has_assets(self):
from data_platform.definitions import defs
repo = defs.get_repository_def()
asset_keys = repo.asset_graph.get_all_asset_keys()
assert len(asset_keys) > 0
def test_definitions_has_jobs(self):
from data_platform.definitions import defs
job = defs.resolve_job_def("funda_ingestion")
assert job is not None
def test_definitions_has_resources(self):
from data_platform.definitions import defs
repo = defs.get_repository_def()
# Resources are bound inside Definitions; building the repository def is enough to confirm it loads with them
assert repo is not None

(modified test file; name not shown)

@@ -1,18 +1,24 @@
"""Tests for ML helper functions — preprocess, _best_run, _build_embed."""
"""Tests for ML helper functions and asset functions."""
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import numpy as np
import pandas as pd
import pytest
from dagster import build_asset_context
from data_platform.assets.ml.elo_model import (
ALL_FEATURES,
ENERGY_LABEL_MAP,
_cleanup_old_runs,
preprocess,
)
from data_platform.assets.ml.discord_alerts import _build_embed
from data_platform.assets.ml.elo_inference import _best_run
from data_platform.assets.ml.discord_alerts import (
DiscordNotificationConfig,
_build_embed,
)
from data_platform.assets.ml.elo_inference import EloInferenceConfig, _best_run
# ---------------------------------------------------------------------------
@@ -201,3 +207,381 @@ class TestBuildEmbed:
def test_color_is_green(self):
embed = _build_embed(self._make_row())
assert embed["color"] == 0x00B894
# ---------------------------------------------------------------------------
# _cleanup_old_runs
# ---------------------------------------------------------------------------
class TestCleanupOldRuns:
@patch("data_platform.assets.ml.elo_model.mlflow")
def test_no_experiment_noop(self, mock_mlflow):
client = mock_mlflow.tracking.MlflowClient.return_value
client.get_experiment_by_name.return_value = None
context = MagicMock()
_cleanup_old_runs("nonexistent", context, keep=3)
client.delete_run.assert_not_called()
@patch("data_platform.assets.ml.elo_model.mlflow")
def test_fewer_than_keep_noop(self, mock_mlflow):
client = mock_mlflow.tracking.MlflowClient.return_value
client.get_experiment_by_name.return_value = MagicMock(experiment_id="1")
client.search_runs.return_value = [MagicMock(), MagicMock()]
context = MagicMock()
_cleanup_old_runs("exp", context, keep=3)
client.delete_run.assert_not_called()
@patch("data_platform.assets.ml.elo_model.mlflow")
def test_deletes_old_runs(self, mock_mlflow):
client = mock_mlflow.tracking.MlflowClient.return_value
client.get_experiment_by_name.return_value = MagicMock(experiment_id="1")
runs = [MagicMock() for _ in range(5)]
for i, run in enumerate(runs):
run.info.run_id = f"run_{i}"
client.search_runs.return_value = runs
context = MagicMock()
_cleanup_old_runs("exp", context, keep=2)
assert client.delete_run.call_count == 3
client.delete_run.assert_any_call("run_2")
client.delete_run.assert_any_call("run_3")
client.delete_run.assert_any_call("run_4")
@patch("data_platform.assets.ml.elo_model.mlflow")
def test_exactly_keep_noop(self, mock_mlflow):
client = mock_mlflow.tracking.MlflowClient.return_value
client.get_experiment_by_name.return_value = MagicMock(experiment_id="1")
client.search_runs.return_value = [MagicMock(), MagicMock(), MagicMock()]
context = MagicMock()
_cleanup_old_runs("exp", context, keep=3)
client.delete_run.assert_not_called()
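# ---------------------------------------------------------------------------
# Illustrative sketch (not the code under test): the behaviour the four tests
# above imply for _cleanup_old_runs, i.e. keep the newest `keep` runs and
# delete the rest. Name and body here are assumptions, not the real helper.
# ---------------------------------------------------------------------------
def _cleanup_old_runs_sketch(experiment_name, context, keep=3):
    import mlflow  # assumed available in this repo's environment

    client = mlflow.tracking.MlflowClient()
    experiment = client.get_experiment_by_name(experiment_name)
    if experiment is None:
        return  # no experiment yet: nothing to delete
    runs = client.search_runs(
        experiment_ids=[experiment.experiment_id],
        order_by=["attributes.start_time DESC"],  # newest first
    )
    for run in runs[keep:]:
        context.log.info(f"Deleting old MLflow run {run.info.run_id}")
        client.delete_run(run.info.run_id)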
# ---------------------------------------------------------------------------
# elo_prediction_model asset
# ---------------------------------------------------------------------------
def _make_training_df(n=20):
"""Build a small training DataFrame with all required columns."""
rng = np.random.default_rng(42)
data = {
"energy_label": rng.choice(["A", "B", "C", "D"], size=n),
"current_price": rng.integers(200_000, 500_000, size=n).astype(float),
"living_area": rng.integers(40, 150, size=n).astype(float),
"plot_area": rng.integers(0, 300, size=n).astype(float),
"bedrooms": rng.integers(1, 5, size=n).astype(float),
"rooms": rng.integers(2, 8, size=n).astype(float),
"construction_year": rng.integers(1900, 2025, size=n).astype(float),
"latitude": rng.uniform(51, 53, size=n),
"longitude": rng.uniform(3, 6, size=n),
"photo_count": rng.integers(1, 30, size=n).astype(float),
"views": rng.integers(10, 500, size=n).astype(float),
"saves": rng.integers(0, 50, size=n).astype(float),
"price_per_sqm": rng.integers(2000, 6000, size=n).astype(float),
"has_garden": rng.choice([True, False, None], size=n),
"has_balcony": rng.choice([True, False], size=n),
"has_solar_panels": rng.choice([True, False, None], size=n),
"has_heat_pump": rng.choice([True, False], size=n),
"has_roof_terrace": rng.choice([True, False, None], size=n),
"is_energy_efficient": rng.choice([True, False], size=n),
"is_monument": rng.choice([True, False], size=n),
"elo_rating": rng.uniform(1200, 1800, size=n),
}
return pd.DataFrame(data)
class TestEloPredictionModelAsset:
@patch("data_platform.assets.ml.elo_model._cleanup_old_runs")
@patch("data_platform.assets.ml.elo_model.mlflow")
def test_trains_and_returns_metadata(self, mock_mlflow, mock_cleanup):
# Setup MLflow mock
mock_run = MagicMock()
mock_run.info.run_id = "test-run-123"
mock_mlflow.start_run.return_value.__enter__ = MagicMock(return_value=mock_run)
mock_mlflow.start_run.return_value.__exit__ = MagicMock(return_value=False)
# Setup Postgres mock
postgres = MagicMock()
engine = MagicMock()
postgres.get_engine.return_value = engine
mlflow_resource = MagicMock()
mlflow_resource.get_tracking_uri.return_value = "http://mlflow:5000"
df = _make_training_df(20)
context = build_asset_context()
from data_platform.assets.ml.elo_model import (
EloModelConfig,
elo_prediction_model,
)
config = EloModelConfig(n_estimators=10)
with patch("data_platform.assets.ml.elo_model.pd.read_sql", return_value=df):
result = elo_prediction_model(context, config, postgres, mlflow_resource)
assert result.metadata["mlflow_run_id"].value == "test-run-123"
assert "rmse" in result.metadata
assert "mae" in result.metadata
assert "r2" in result.metadata
assert result.metadata["train_rows"] > 0
assert result.metadata["test_rows"] > 0
mock_cleanup.assert_called_once()
@patch("data_platform.assets.ml.elo_model.mlflow")
def test_raises_with_too_few_rows(self, mock_mlflow):
postgres = MagicMock()
engine = MagicMock()
postgres.get_engine.return_value = engine
mlflow_resource = MagicMock()
mlflow_resource.get_tracking_uri.return_value = "http://mlflow:5000"
df = _make_training_df(5) # fewer than 10
context = build_asset_context()
from data_platform.assets.ml.elo_model import (
EloModelConfig,
elo_prediction_model,
)
config = EloModelConfig()
with (
patch("data_platform.assets.ml.elo_model.pd.read_sql", return_value=df),
pytest.raises(ValueError, match="Not enough rated listings"),
):
elo_prediction_model(context, config, postgres, mlflow_resource)
# ---------------------------------------------------------------------------
# elo_inference asset
# ---------------------------------------------------------------------------
class TestEloInferenceAsset:
def _make_unscored_df(self, n=5):
"""Build a DataFrame of unscored listings."""
rng = np.random.default_rng(42)
data = {
"global_id": [f"listing_{i}" for i in range(n)],
"energy_label": rng.choice(["A", "B", "C"], size=n),
"current_price": rng.integers(200_000, 500_000, size=n).astype(float),
"living_area": rng.integers(40, 150, size=n).astype(float),
"plot_area": rng.integers(0, 300, size=n).astype(float),
"bedrooms": rng.integers(1, 5, size=n).astype(float),
"rooms": rng.integers(2, 8, size=n).astype(float),
"construction_year": rng.integers(1900, 2025, size=n).astype(float),
"latitude": rng.uniform(51, 53, size=n),
"longitude": rng.uniform(3, 6, size=n),
"photo_count": rng.integers(1, 30, size=n).astype(float),
"views": rng.integers(10, 500, size=n).astype(float),
"saves": rng.integers(0, 50, size=n).astype(float),
"price_per_sqm": rng.integers(2000, 6000, size=n).astype(float),
"has_garden": rng.choice([True, False], size=n),
"has_balcony": rng.choice([True, False], size=n),
"has_solar_panels": rng.choice([True, False], size=n),
"has_heat_pump": rng.choice([True, False], size=n),
"has_roof_terrace": rng.choice([True, False], size=n),
"is_energy_efficient": rng.choice([True, False], size=n),
"is_monument": rng.choice([True, False], size=n),
}
return pd.DataFrame(data)
@patch("data_platform.assets.ml.elo_inference.mlflow")
@patch("data_platform.assets.ml.elo_inference.pd.read_sql")
@patch("data_platform.assets.ml.elo_inference.render_sql", return_value="SQL")
def test_scores_listings(self, mock_render, mock_read_sql, mock_mlflow):
from data_platform.assets.ml.elo_inference import elo_inference
df = self._make_unscored_df(5)
mock_read_sql.return_value = df
# Mock MLflow best run
best_run = MagicMock()
best_run.info.run_id = "run-abc"
best_run.data.metrics = {"rmse": 0.5}
client = mock_mlflow.tracking.MlflowClient.return_value
client.get_experiment_by_name.return_value = MagicMock(experiment_id="1")
client.search_runs.return_value = [best_run]
# Mock model
mock_model = MagicMock()
mock_model.predict.return_value = np.random.default_rng(42).uniform(
-3, 3, size=5
)
mock_mlflow.lightgbm.load_model.return_value = mock_model
# Mock resources
postgres = MagicMock()
engine = MagicMock()
conn = MagicMock()
engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
engine.begin.return_value.__exit__ = MagicMock(return_value=False)
postgres.get_engine.return_value = engine
mlflow_resource = MagicMock()
mlflow_resource.get_tracking_uri.return_value = "http://mlflow:5000"
context = build_asset_context()
config = EloInferenceConfig()
result = elo_inference(context, config, postgres, mlflow_resource)
assert result.metadata["scored"] == 5
assert result.metadata["mlflow_run_id"].value == "run-abc"
assert "predicted_elo_mean" in result.metadata
mock_model.predict.assert_called_once()
@patch("data_platform.assets.ml.elo_inference.pd.read_sql")
@patch("data_platform.assets.ml.elo_inference.render_sql", return_value="SQL")
def test_empty_returns_zero_scored(self, mock_render, mock_read_sql):
from data_platform.assets.ml.elo_inference import elo_inference
mock_read_sql.return_value = pd.DataFrame()
postgres = MagicMock()
engine = MagicMock()
conn = MagicMock()
engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
engine.begin.return_value.__exit__ = MagicMock(return_value=False)
postgres.get_engine.return_value = engine
mlflow_resource = MagicMock()
context = build_asset_context()
config = EloInferenceConfig()
result = elo_inference(context, config, postgres, mlflow_resource)
assert result.metadata["scored"] == 0
assert result.metadata["status"].value == "No new listings to score."
# ---------------------------------------------------------------------------
# listing_alert asset
# ---------------------------------------------------------------------------
class TestListingAlertAsset:
def _make_predictions_df(self, n=3):
"""Build a DataFrame of high-ELO predictions."""
data = {
"global_id": [f"listing_{i}" for i in range(n)],
"predicted_elo": [1650.0 + i * 10 for i in range(n)],
"current_price": [350_000 + i * 10_000 for i in range(n)],
"city": ["Amsterdam", "Rotterdam", "Utrecht"][:n],
"living_area": [80 + i * 5 for i in range(n)],
"rooms": [4, 5, 3][:n],
"energy_label": ["A", "B", "C"][:n],
"price_per_sqm": [4375, 4000, 3800][:n],
"title": [f"Teststraat {i}" for i in range(n)],
"url": [f"https://funda.nl/{i}" for i in range(n)],
}
return pd.DataFrame(data)
@patch("data_platform.assets.ml.discord_alerts.requests.post")
@patch("data_platform.assets.ml.discord_alerts.pd.read_sql")
@patch("data_platform.assets.ml.discord_alerts.render_sql", return_value="SQL")
def test_sends_notifications(self, mock_render, mock_read_sql, mock_post):
from data_platform.assets.ml.discord_alerts import listing_alert
df = self._make_predictions_df(3)
mock_read_sql.return_value = df
mock_post.return_value = MagicMock(status_code=200)
postgres = MagicMock()
engine = MagicMock()
conn = MagicMock()
engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
engine.begin.return_value.__exit__ = MagicMock(return_value=False)
postgres.get_engine.return_value = engine
discord = MagicMock()
discord.get_webhook_url.return_value = "https://discord.com/webhook/test"
context = build_asset_context()
config = DiscordNotificationConfig()
result = listing_alert(context, config, postgres, discord)
assert result.metadata["notified"] == 3
mock_post.assert_called_once()
@patch("data_platform.assets.ml.discord_alerts.pd.read_sql")
@patch("data_platform.assets.ml.discord_alerts.render_sql", return_value="SQL")
def test_empty_returns_zero_notified(self, mock_render, mock_read_sql):
from data_platform.assets.ml.discord_alerts import listing_alert
mock_read_sql.return_value = pd.DataFrame()
postgres = MagicMock()
engine = MagicMock()
conn = MagicMock()
engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
engine.begin.return_value.__exit__ = MagicMock(return_value=False)
postgres.get_engine.return_value = engine
discord = MagicMock()
context = build_asset_context()
config = DiscordNotificationConfig()
result = listing_alert(context, config, postgres, discord)
assert result.metadata["notified"] == 0
@patch("data_platform.assets.ml.discord_alerts.requests.post")
@patch("data_platform.assets.ml.discord_alerts.pd.read_sql")
@patch("data_platform.assets.ml.discord_alerts.render_sql", return_value="SQL")
def test_batches_large_result_set(self, mock_render, mock_read_sql, mock_post):
from data_platform.assets.ml.discord_alerts import listing_alert
# 15 listings → should produce 2 batches (10 + 5)
data = {
"global_id": [f"listing_{i}" for i in range(15)],
"predicted_elo": [1700.0] * 15,
"current_price": [400_000] * 15,
"city": ["Amsterdam"] * 15,
"living_area": [90] * 15,
"rooms": [4] * 15,
"energy_label": ["A"] * 15,
"price_per_sqm": [4444] * 15,
"title": [f"Street {i}" for i in range(15)],
"url": [f"https://funda.nl/{i}" for i in range(15)],
}
mock_read_sql.return_value = pd.DataFrame(data)
mock_post.return_value = MagicMock(status_code=200)
postgres = MagicMock()
engine = MagicMock()
conn = MagicMock()
engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
engine.begin.return_value.__exit__ = MagicMock(return_value=False)
postgres.get_engine.return_value = engine
discord = MagicMock()
discord.get_webhook_url.return_value = "https://discord.com/webhook/test"
context = build_asset_context()
config = DiscordNotificationConfig()
result = listing_alert(context, config, postgres, discord)
assert result.metadata["notified"] == 15
assert mock_post.call_count == 2
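# ---------------------------------------------------------------------------
# Illustrative (not the code under test): the 2-post assertion above matches
# Discord's limit of 10 embeds per webhook payload, i.e. batching like this:
# ---------------------------------------------------------------------------
def _chunked_sketch(items, size=10):
    """Yield successive batches of at most `size` items; 15 items yield 10 + 5."""
    for start in range(0, len(items), size):
        yield items[start:start + size]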
# ---------------------------------------------------------------------------
# Config defaults
# ---------------------------------------------------------------------------
class TestMLConfigs:
def test_elo_inference_config_defaults(self):
config = EloInferenceConfig()
assert config.mlflow_experiment == "elo-rating-prediction"
assert config.metric == "rmse"
assert config.ascending is True
def test_discord_notification_config_defaults(self):
config = DiscordNotificationConfig()
assert config.min_elo == 1600

tests/test_ops.py (new file, 153 lines added)

@@ -0,0 +1,153 @@
"""Tests for Dagster ops — elementary and source freshness."""
import subprocess
from unittest.mock import patch
import pytest
from dagster import build_op_context
from data_platform.ops.check_source_freshness import (
SourceFreshnessConfig,
)
from data_platform.ops.elementary import (
_elementary_schema_exists,
elementary_generate_report,
elementary_run_models,
)
# ---------------------------------------------------------------------------
# SourceFreshnessConfig
# ---------------------------------------------------------------------------
class TestSourceFreshnessConfig:
def test_accepts_source_name(self):
cfg = SourceFreshnessConfig(source_name="raw_funda")
assert cfg.source_name == "raw_funda"
# ---------------------------------------------------------------------------
# _elementary_schema_exists
# ---------------------------------------------------------------------------
class TestElementarySchemaExists:
@patch("data_platform.resources._retry_on_operational_error")
@patch("data_platform.ops.elementary.create_engine")
def test_returns_true_when_schema_present(self, mock_create_engine, mock_retry):
mock_retry.return_value = True
with patch.dict(
"os.environ",
{
"POSTGRES_USER": "u",
"POSTGRES_PASSWORD": "p",
"POSTGRES_HOST": "localhost",
"POSTGRES_PORT": "5432",
"POSTGRES_DB": "db",
},
):
result = _elementary_schema_exists()
assert result is True
mock_create_engine.assert_called_once()
@patch("data_platform.resources._retry_on_operational_error")
@patch("data_platform.ops.elementary.create_engine")
def test_returns_false_when_schema_absent(self, mock_create_engine, mock_retry):
mock_retry.return_value = False
with patch.dict(
"os.environ",
{
"POSTGRES_USER": "u",
"POSTGRES_PASSWORD": "p",
"POSTGRES_HOST": "localhost",
"POSTGRES_PORT": "5432",
"POSTGRES_DB": "db",
},
):
result = _elementary_schema_exists()
assert result is False
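# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this commit): both tests above patch the
# same five POSTGRES_* variables; a shared fixture would remove the
# duplication. The name postgres_env is hypothetical; pytest and patch are
# already imported at the top of this module.
# ---------------------------------------------------------------------------
@pytest.fixture
def postgres_env():
    env = {
        "POSTGRES_USER": "u",
        "POSTGRES_PASSWORD": "p",
        "POSTGRES_HOST": "localhost",
        "POSTGRES_PORT": "5432",
        "POSTGRES_DB": "db",
    }
    with patch.dict("os.environ", env):
        yield env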
# ---------------------------------------------------------------------------
# elementary_run_models
# ---------------------------------------------------------------------------
class TestElementaryRunModels:
@patch("data_platform.ops.elementary._elementary_schema_exists", return_value=True)
def test_skips_when_schema_exists(self, mock_exists):
context = build_op_context()
elementary_run_models(context)
mock_exists.assert_called_once()
@patch(
"data_platform.ops.elementary.subprocess.run",
return_value=subprocess.CompletedProcess(
args=[], returncode=0, stdout="ok", stderr=""
),
)
@patch("data_platform.ops.elementary._elementary_schema_exists", return_value=False)
def test_runs_dbt_when_schema_missing(self, mock_exists, mock_run):
context = build_op_context()
elementary_run_models(context)
mock_run.assert_called_once()
args = mock_run.call_args[0][0]
assert "dbt" in args
assert "run" in args
assert "elementary" in args
@patch(
"data_platform.ops.elementary.subprocess.run",
return_value=subprocess.CompletedProcess(
args=[], returncode=1, stdout="", stderr="error"
),
)
@patch("data_platform.ops.elementary._elementary_schema_exists", return_value=False)
def test_raises_on_dbt_failure(self, mock_exists, mock_run):
context = build_op_context()
with pytest.raises(Exception, match="dbt run elementary failed"):
elementary_run_models(context)
# ---------------------------------------------------------------------------
# elementary_generate_report
# ---------------------------------------------------------------------------
class TestElementaryGenerateReport:
@patch(
"data_platform.ops.elementary.subprocess.run",
return_value=subprocess.CompletedProcess(
args=[], returncode=0, stdout="report generated", stderr=""
),
)
def test_calls_edr_report(self, mock_run):
context = build_op_context()
elementary_generate_report(context)
mock_run.assert_called_once()
args = mock_run.call_args[0][0]
assert "edr" in args
assert "report" in args
@patch(
"data_platform.ops.elementary.subprocess.run",
return_value=subprocess.CompletedProcess(
args=[], returncode=1, stdout="", stderr="fatal error"
),
)
def test_raises_on_failure(self, mock_run):
context = build_op_context()
with pytest.raises(Exception, match="edr report failed"):
elementary_generate_report(context)
@patch(
"data_platform.ops.elementary.subprocess.run",
return_value=subprocess.CompletedProcess(
args=[], returncode=0, stdout="done", stderr=""
),
)
def test_success_returns_none(self, mock_run):
context = build_op_context()
result = elementary_generate_report(context)
assert result is None