test: expand unit tests for ml and dagster resources

@@ -118,6 +118,6 @@ def listing_alert(
     return MaterializeResult(
         metadata={
             "notified": sent,
-            "min_elo_threshold": MetadataValue.float(config.min_elo),
+            "min_elo_threshold": MetadataValue.float(float(config.min_elo)),
         }
     )
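
Why the change: Dagster's MetadataValue.float type-checks its argument and raises a CheckError when handed a non-float, and min_elo appears to be an int-valued config field (its default is 1600), so the raw value has to be coerced first. A minimal sketch of the failure mode, assuming only that min_elo arrives as an int:

    from dagster import MetadataValue

    MetadataValue.float(1600)         # strict check rejects the int
    MetadataValue.float(float(1600))  # ok: FloatMetadataValue(1600.0)
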
tests/test_assets_elo.py (new file, 87 lines)
@@ -0,0 +1,87 @@
"""Tests for ELO rating and comparison assets."""

from unittest.mock import MagicMock, patch

from dagster import build_asset_context

from data_platform.assets.elo.elo import _ensure_schema, elo_comparisons, elo_ratings


class TestEnsureSchema:
    def test_executes_create_schema_sql(self):
        conn = MagicMock()
        with patch(
            "data_platform.assets.elo.elo.render_sql",
            return_value="CREATE SCHEMA IF NOT EXISTS elo",
        ):
            _ensure_schema(conn)
        conn.execute.assert_called_once()


class TestEloRatings:
    @patch(
        "data_platform.assets.elo.elo.render_sql",
        return_value="CREATE TABLE IF NOT EXISTS elo.ratings ()",
    )
    def test_creates_table_and_returns_metadata(self, mock_render):
        postgres = MagicMock()
        engine = MagicMock()
        conn = MagicMock()
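        # engine.begin() is used as a context manager, so wire __enter__ to
        # return the mock connection (and __exit__ to not swallow exceptions).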
        engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
        engine.begin.return_value.__exit__ = MagicMock(return_value=False)
        postgres.get_engine.return_value = engine

        context = build_asset_context()
        result = elo_ratings(context, postgres)

        assert result.metadata["schema"].value == "elo"
        assert result.metadata["table"].value == "ratings"
        # Two calls: _ensure_schema + create table
        assert conn.execute.call_count == 2

    @patch("data_platform.assets.elo.elo.render_sql", return_value="SQL")
    def test_calls_get_engine(self, mock_render):
        postgres = MagicMock()
        engine = MagicMock()
        conn = MagicMock()
        engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
        engine.begin.return_value.__exit__ = MagicMock(return_value=False)
        postgres.get_engine.return_value = engine

        context = build_asset_context()
        elo_ratings(context, postgres)
        postgres.get_engine.assert_called_once()


class TestEloComparisons:
    @patch(
        "data_platform.assets.elo.elo.render_sql",
        return_value="CREATE TABLE IF NOT EXISTS elo.comparisons ()",
    )
    def test_creates_table_and_returns_metadata(self, mock_render):
        postgres = MagicMock()
        engine = MagicMock()
        conn = MagicMock()
        engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
        engine.begin.return_value.__exit__ = MagicMock(return_value=False)
        postgres.get_engine.return_value = engine

        context = build_asset_context()
        result = elo_comparisons(context, postgres)

        assert result.metadata["schema"].value == "elo"
        assert result.metadata["table"].value == "comparisons"
        assert conn.execute.call_count == 2

    @patch("data_platform.assets.elo.elo.render_sql", return_value="SQL")
    def test_calls_get_engine(self, mock_render):
        postgres = MagicMock()
        engine = MagicMock()
        conn = MagicMock()
        engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
        engine.begin.return_value.__exit__ = MagicMock(return_value=False)
        postgres.get_engine.return_value = engine

        context = build_asset_context()
        elo_comparisons(context, postgres)
        postgres.get_engine.assert_called_once()
tests/test_dagster.py (new file, 121 lines)
@@ -0,0 +1,121 @@
"""Generic tests for Dagster definitions — schedules, jobs, and the Definitions object."""

import pytest
from dagster import DefaultScheduleStatus

from data_platform.jobs import (
    elementary_refresh_job,
    funda_ingestion_job,
    funda_raw_quality_job,
)
from data_platform.schedules import (
    elementary_refresh_schedule,
    funda_ingestion_schedule,
    funda_raw_quality_schedule,
)

ALL_SCHEDULES = [
    elementary_refresh_schedule,
    funda_ingestion_schedule,
    funda_raw_quality_schedule,
]

ALL_JOBS = [
    elementary_refresh_job,
    funda_ingestion_job,
    funda_raw_quality_job,
]


# ---------------------------------------------------------------------------
# Generic schedule tests
# ---------------------------------------------------------------------------


class TestSchedulesGeneric:
    """Property tests that apply to every schedule."""

    @pytest.mark.parametrize("schedule", ALL_SCHEDULES, ids=lambda s: s.name)
    def test_has_name(self, schedule):
        assert schedule.name

    @pytest.mark.parametrize("schedule", ALL_SCHEDULES, ids=lambda s: s.name)
    def test_has_valid_cron(self, schedule):
        parts = schedule.cron_schedule.split()
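        # A standard cron expression has 5 fields:
        # minute, hour, day of month, month, day of week.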
        assert len(parts) == 5, f"Expected 5-part cron, got {schedule.cron_schedule}"

    @pytest.mark.parametrize("schedule", ALL_SCHEDULES, ids=lambda s: s.name)
    def test_has_job(self, schedule):
        assert schedule.job is not None or schedule.job_name is not None

    @pytest.mark.parametrize("schedule", ALL_SCHEDULES, ids=lambda s: s.name)
    def test_default_status_running(self, schedule):
        assert schedule.default_status == DefaultScheduleStatus.RUNNING


# ---------------------------------------------------------------------------
# Generic job tests
# ---------------------------------------------------------------------------


class TestJobsGeneric:
    """Property tests that apply to every job."""

    @pytest.mark.parametrize("job", ALL_JOBS, ids=lambda j: j.name)
    def test_has_name(self, job):
        assert job.name

    @pytest.mark.parametrize("job", ALL_JOBS, ids=lambda j: j.name)
    def test_has_description(self, job):
        assert job.description


# ---------------------------------------------------------------------------
# Schedule-specific tests
# ---------------------------------------------------------------------------


class TestScheduleSpecific:
    def test_elementary_schedule_daily(self):
        assert elementary_refresh_schedule.cron_schedule == "0 9 * * *"

    def test_funda_ingestion_every_4_hours(self):
        assert funda_ingestion_schedule.cron_schedule == "0 */4 * * *"

    def test_funda_quality_daily(self):
        assert funda_raw_quality_schedule.cron_schedule == "0 8 * * *"

    def test_funda_ingestion_schedule_has_run_config_fn(self):
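        # _run_config_fn is private Dagster API; asserting it is set is a
        # pragmatic check that this schedule supplies run config.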
        assert funda_ingestion_schedule._run_config_fn is not None


# ---------------------------------------------------------------------------
# Definitions integration test
# ---------------------------------------------------------------------------


class TestDefinitions:
    def test_definitions_loads(self):
        from data_platform.definitions import defs

        assert defs is not None

    def test_definitions_has_assets(self):
        from data_platform.definitions import defs

        repo = defs.get_repository_def()
        asset_keys = repo.asset_graph.get_all_asset_keys()
        assert len(asset_keys) > 0

    def test_definitions_has_jobs(self):
        from data_platform.definitions import defs

        job = defs.resolve_job_def("funda_ingestion")
        assert job is not None

    def test_definitions_has_resources(self):
        from data_platform.definitions import defs

        repo = defs.get_repository_def()
        # Resources are attached on the Definitions object; if the repository
        # builds without errors, the configured resources resolved.
        assert repo is not None
tests/test_ml.py (390 lines changed)
@@ -1,18 +1,24 @@
-"""Tests for ML helper functions — preprocess, _best_run, _build_embed."""
+"""Tests for ML helper functions and asset functions."""
 
 from types import SimpleNamespace
 from unittest.mock import MagicMock, patch
 
 import numpy as np
 import pandas as pd
 import pytest
 from dagster import build_asset_context
 
 from data_platform.assets.ml.elo_model import (
     ALL_FEATURES,
     ENERGY_LABEL_MAP,
     _cleanup_old_runs,
     preprocess,
 )
-from data_platform.assets.ml.discord_alerts import _build_embed
-from data_platform.assets.ml.elo_inference import _best_run
+from data_platform.assets.ml.discord_alerts import (
+    DiscordNotificationConfig,
+    _build_embed,
+)
+from data_platform.assets.ml.elo_inference import EloInferenceConfig, _best_run
+
+
+# ---------------------------------------------------------------------------
@@ -201,3 +207,381 @@ class TestBuildEmbed:
    def test_color_is_green(self):
        embed = _build_embed(self._make_row())
        assert embed["color"] == 0x00B894


# ---------------------------------------------------------------------------
# _cleanup_old_runs
# ---------------------------------------------------------------------------


class TestCleanupOldRuns:
    @patch("data_platform.assets.ml.elo_model.mlflow")
    def test_no_experiment_noop(self, mock_mlflow):
        client = mock_mlflow.tracking.MlflowClient.return_value
        client.get_experiment_by_name.return_value = None
        context = MagicMock()
        _cleanup_old_runs("nonexistent", context, keep=3)
        client.delete_run.assert_not_called()

    @patch("data_platform.assets.ml.elo_model.mlflow")
    def test_fewer_than_keep_noop(self, mock_mlflow):
        client = mock_mlflow.tracking.MlflowClient.return_value
        client.get_experiment_by_name.return_value = MagicMock(experiment_id="1")
        client.search_runs.return_value = [MagicMock(), MagicMock()]
        context = MagicMock()
        _cleanup_old_runs("exp", context, keep=3)
        client.delete_run.assert_not_called()

    @patch("data_platform.assets.ml.elo_model.mlflow")
    def test_deletes_old_runs(self, mock_mlflow):
        client = mock_mlflow.tracking.MlflowClient.return_value
        client.get_experiment_by_name.return_value = MagicMock(experiment_id="1")
        runs = [MagicMock() for _ in range(5)]
        for i, run in enumerate(runs):
            run.info.run_id = f"run_{i}"
        client.search_runs.return_value = runs
        context = MagicMock()
        _cleanup_old_runs("exp", context, keep=2)
        assert client.delete_run.call_count == 3
        client.delete_run.assert_any_call("run_2")
        client.delete_run.assert_any_call("run_3")
        client.delete_run.assert_any_call("run_4")

    @patch("data_platform.assets.ml.elo_model.mlflow")
    def test_exactly_keep_noop(self, mock_mlflow):
        client = mock_mlflow.tracking.MlflowClient.return_value
        client.get_experiment_by_name.return_value = MagicMock(experiment_id="1")
        client.search_runs.return_value = [MagicMock(), MagicMock(), MagicMock()]
        context = MagicMock()
        _cleanup_old_runs("exp", context, keep=3)
        client.delete_run.assert_not_called()


# ---------------------------------------------------------------------------
# elo_prediction_model asset
# ---------------------------------------------------------------------------


def _make_training_df(n=20):
    """Build a small training DataFrame with all required columns."""
    rng = np.random.default_rng(42)
    data = {
        "energy_label": rng.choice(["A", "B", "C", "D"], size=n),
        "current_price": rng.integers(200_000, 500_000, size=n).astype(float),
        "living_area": rng.integers(40, 150, size=n).astype(float),
        "plot_area": rng.integers(0, 300, size=n).astype(float),
        "bedrooms": rng.integers(1, 5, size=n).astype(float),
        "rooms": rng.integers(2, 8, size=n).astype(float),
        "construction_year": rng.integers(1900, 2025, size=n).astype(float),
        "latitude": rng.uniform(51, 53, size=n),
        "longitude": rng.uniform(3, 6, size=n),
        "photo_count": rng.integers(1, 30, size=n).astype(float),
        "views": rng.integers(10, 500, size=n).astype(float),
        "saves": rng.integers(0, 50, size=n).astype(float),
        "price_per_sqm": rng.integers(2000, 6000, size=n).astype(float),
        "has_garden": rng.choice([True, False, None], size=n),
        "has_balcony": rng.choice([True, False], size=n),
        "has_solar_panels": rng.choice([True, False, None], size=n),
        "has_heat_pump": rng.choice([True, False], size=n),
        "has_roof_terrace": rng.choice([True, False, None], size=n),
        "is_energy_efficient": rng.choice([True, False], size=n),
        "is_monument": rng.choice([True, False], size=n),
        "elo_rating": rng.uniform(1200, 1800, size=n),
    }
    return pd.DataFrame(data)


class TestEloPredictionModelAsset:
    @patch("data_platform.assets.ml.elo_model._cleanup_old_runs")
    @patch("data_platform.assets.ml.elo_model.mlflow")
    def test_trains_and_returns_metadata(self, mock_mlflow, mock_cleanup):
        # Setup MLflow mock
        mock_run = MagicMock()
        mock_run.info.run_id = "test-run-123"
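        # mlflow.start_run() is entered as a context manager; have it yield the mock run.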
        mock_mlflow.start_run.return_value.__enter__ = MagicMock(return_value=mock_run)
        mock_mlflow.start_run.return_value.__exit__ = MagicMock(return_value=False)

        # Setup Postgres mock
        postgres = MagicMock()
        engine = MagicMock()
        postgres.get_engine.return_value = engine
        mlflow_resource = MagicMock()
        mlflow_resource.get_tracking_uri.return_value = "http://mlflow:5000"

        df = _make_training_df(20)
        context = build_asset_context()

        from data_platform.assets.ml.elo_model import (
            EloModelConfig,
            elo_prediction_model,
        )

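        # Keep n_estimators small so the training step stays fast in tests.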
        config = EloModelConfig(n_estimators=10)

        with patch("data_platform.assets.ml.elo_model.pd.read_sql", return_value=df):
            result = elo_prediction_model(context, config, postgres, mlflow_resource)

        assert result.metadata["mlflow_run_id"].value == "test-run-123"
        assert "rmse" in result.metadata
        assert "mae" in result.metadata
        assert "r2" in result.metadata
        assert result.metadata["train_rows"] > 0
        assert result.metadata["test_rows"] > 0
        mock_cleanup.assert_called_once()

    @patch("data_platform.assets.ml.elo_model.mlflow")
    def test_raises_with_too_few_rows(self, mock_mlflow):
        postgres = MagicMock()
        engine = MagicMock()
        postgres.get_engine.return_value = engine
        mlflow_resource = MagicMock()
        mlflow_resource.get_tracking_uri.return_value = "http://mlflow:5000"

        df = _make_training_df(5)  # fewer than 10
        context = build_asset_context()

        from data_platform.assets.ml.elo_model import (
            EloModelConfig,
            elo_prediction_model,
        )

        config = EloModelConfig()

        with (
            patch("data_platform.assets.ml.elo_model.pd.read_sql", return_value=df),
            pytest.raises(ValueError, match="Not enough rated listings"),
        ):
            elo_prediction_model(context, config, postgres, mlflow_resource)


# ---------------------------------------------------------------------------
# elo_inference asset
# ---------------------------------------------------------------------------


class TestEloInferenceAsset:
    def _make_unscored_df(self, n=5):
        """Build a DataFrame of unscored listings."""
        rng = np.random.default_rng(42)
        data = {
            "global_id": [f"listing_{i}" for i in range(n)],
            "energy_label": rng.choice(["A", "B", "C"], size=n),
            "current_price": rng.integers(200_000, 500_000, size=n).astype(float),
            "living_area": rng.integers(40, 150, size=n).astype(float),
            "plot_area": rng.integers(0, 300, size=n).astype(float),
            "bedrooms": rng.integers(1, 5, size=n).astype(float),
            "rooms": rng.integers(2, 8, size=n).astype(float),
            "construction_year": rng.integers(1900, 2025, size=n).astype(float),
            "latitude": rng.uniform(51, 53, size=n),
            "longitude": rng.uniform(3, 6, size=n),
            "photo_count": rng.integers(1, 30, size=n).astype(float),
            "views": rng.integers(10, 500, size=n).astype(float),
            "saves": rng.integers(0, 50, size=n).astype(float),
            "price_per_sqm": rng.integers(2000, 6000, size=n).astype(float),
            "has_garden": rng.choice([True, False], size=n),
            "has_balcony": rng.choice([True, False], size=n),
            "has_solar_panels": rng.choice([True, False], size=n),
            "has_heat_pump": rng.choice([True, False], size=n),
            "has_roof_terrace": rng.choice([True, False], size=n),
            "is_energy_efficient": rng.choice([True, False], size=n),
            "is_monument": rng.choice([True, False], size=n),
        }
        return pd.DataFrame(data)

    @patch("data_platform.assets.ml.elo_inference.mlflow")
    @patch("data_platform.assets.ml.elo_inference.pd.read_sql")
    @patch("data_platform.assets.ml.elo_inference.render_sql", return_value="SQL")
    def test_scores_listings(self, mock_render, mock_read_sql, mock_mlflow):
        from data_platform.assets.ml.elo_inference import elo_inference

        df = self._make_unscored_df(5)
        mock_read_sql.return_value = df

        # Mock MLflow best run
        best_run = MagicMock()
        best_run.info.run_id = "run-abc"
        best_run.data.metrics = {"rmse": 0.5}
        client = mock_mlflow.tracking.MlflowClient.return_value
        client.get_experiment_by_name.return_value = MagicMock(experiment_id="1")
        client.search_runs.return_value = [best_run]

        # Mock model
        mock_model = MagicMock()
        mock_model.predict.return_value = np.random.default_rng(42).uniform(
            -3, 3, size=5
        )
        mock_mlflow.lightgbm.load_model.return_value = mock_model

        # Mock resources
        postgres = MagicMock()
        engine = MagicMock()
        conn = MagicMock()
        engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
        engine.begin.return_value.__exit__ = MagicMock(return_value=False)
        postgres.get_engine.return_value = engine

        mlflow_resource = MagicMock()
        mlflow_resource.get_tracking_uri.return_value = "http://mlflow:5000"

        context = build_asset_context()
        config = EloInferenceConfig()

        result = elo_inference(context, config, postgres, mlflow_resource)

        assert result.metadata["scored"] == 5
        assert result.metadata["mlflow_run_id"].value == "run-abc"
        assert "predicted_elo_mean" in result.metadata
        mock_model.predict.assert_called_once()

    @patch("data_platform.assets.ml.elo_inference.pd.read_sql")
    @patch("data_platform.assets.ml.elo_inference.render_sql", return_value="SQL")
    def test_empty_returns_zero_scored(self, mock_render, mock_read_sql):
        from data_platform.assets.ml.elo_inference import elo_inference

        mock_read_sql.return_value = pd.DataFrame()

        postgres = MagicMock()
        engine = MagicMock()
        conn = MagicMock()
        engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
        engine.begin.return_value.__exit__ = MagicMock(return_value=False)
        postgres.get_engine.return_value = engine

        mlflow_resource = MagicMock()
        context = build_asset_context()
        config = EloInferenceConfig()

        result = elo_inference(context, config, postgres, mlflow_resource)

        assert result.metadata["scored"] == 0
        assert result.metadata["status"].value == "No new listings to score."


# ---------------------------------------------------------------------------
# listing_alert asset
# ---------------------------------------------------------------------------


class TestListingAlertAsset:
    def _make_predictions_df(self, n=3):
        """Build a DataFrame of high-ELO predictions."""
        data = {
            "global_id": [f"listing_{i}" for i in range(n)],
            "predicted_elo": [1650.0 + i * 10 for i in range(n)],
            "current_price": [350_000 + i * 10_000 for i in range(n)],
            "city": ["Amsterdam", "Rotterdam", "Utrecht"][:n],
            "living_area": [80 + i * 5 for i in range(n)],
            "rooms": [4, 5, 3][:n],
            "energy_label": ["A", "B", "C"][:n],
            "price_per_sqm": [4375, 4000, 3800][:n],
            "title": [f"Teststraat {i}" for i in range(n)],
            "url": [f"https://funda.nl/{i}" for i in range(n)],
        }
        return pd.DataFrame(data)

    @patch("data_platform.assets.ml.discord_alerts.requests.post")
    @patch("data_platform.assets.ml.discord_alerts.pd.read_sql")
    @patch("data_platform.assets.ml.discord_alerts.render_sql", return_value="SQL")
    def test_sends_notifications(self, mock_render, mock_read_sql, mock_post):
        from data_platform.assets.ml.discord_alerts import listing_alert

        df = self._make_predictions_df(3)
        mock_read_sql.return_value = df
        mock_post.return_value = MagicMock(status_code=200)

        postgres = MagicMock()
        engine = MagicMock()
        conn = MagicMock()
        engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
        engine.begin.return_value.__exit__ = MagicMock(return_value=False)
        postgres.get_engine.return_value = engine

        discord = MagicMock()
        discord.get_webhook_url.return_value = "https://discord.com/webhook/test"

        context = build_asset_context()
        config = DiscordNotificationConfig()

        result = listing_alert(context, config, postgres, discord)

        assert result.metadata["notified"] == 3
        mock_post.assert_called_once()

    @patch("data_platform.assets.ml.discord_alerts.pd.read_sql")
    @patch("data_platform.assets.ml.discord_alerts.render_sql", return_value="SQL")
    def test_empty_returns_zero_notified(self, mock_render, mock_read_sql):
        from data_platform.assets.ml.discord_alerts import listing_alert

        mock_read_sql.return_value = pd.DataFrame()

        postgres = MagicMock()
        engine = MagicMock()
        conn = MagicMock()
        engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
        engine.begin.return_value.__exit__ = MagicMock(return_value=False)
        postgres.get_engine.return_value = engine

        discord = MagicMock()
        context = build_asset_context()
        config = DiscordNotificationConfig()

        result = listing_alert(context, config, postgres, discord)

        assert result.metadata["notified"] == 0

    @patch("data_platform.assets.ml.discord_alerts.requests.post")
    @patch("data_platform.assets.ml.discord_alerts.pd.read_sql")
    @patch("data_platform.assets.ml.discord_alerts.render_sql", return_value="SQL")
    def test_batches_large_result_set(self, mock_render, mock_read_sql, mock_post):
        from data_platform.assets.ml.discord_alerts import listing_alert

        # 15 listings → should produce 2 batches (10 + 5)
        data = {
            "global_id": [f"listing_{i}" for i in range(15)],
            "predicted_elo": [1700.0] * 15,
            "current_price": [400_000] * 15,
            "city": ["Amsterdam"] * 15,
            "living_area": [90] * 15,
            "rooms": [4] * 15,
            "energy_label": ["A"] * 15,
            "price_per_sqm": [4444] * 15,
            "title": [f"Street {i}" for i in range(15)],
            "url": [f"https://funda.nl/{i}" for i in range(15)],
        }
        mock_read_sql.return_value = pd.DataFrame(data)
        mock_post.return_value = MagicMock(status_code=200)

        postgres = MagicMock()
        engine = MagicMock()
        conn = MagicMock()
        engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
        engine.begin.return_value.__exit__ = MagicMock(return_value=False)
        postgres.get_engine.return_value = engine

        discord = MagicMock()
        discord.get_webhook_url.return_value = "https://discord.com/webhook/test"

        context = build_asset_context()
        config = DiscordNotificationConfig()

        result = listing_alert(context, config, postgres, discord)

        assert result.metadata["notified"] == 15
        assert mock_post.call_count == 2


# ---------------------------------------------------------------------------
# Config defaults
# ---------------------------------------------------------------------------


class TestMLConfigs:
    def test_elo_inference_config_defaults(self):
        config = EloInferenceConfig()
        assert config.mlflow_experiment == "elo-rating-prediction"
        assert config.metric == "rmse"
        assert config.ascending is True

    def test_discord_notification_config_defaults(self):
        config = DiscordNotificationConfig()
        assert config.min_elo == 1600
tests/test_ops.py (new file, 153 lines)
@@ -0,0 +1,153 @@
"""Tests for Dagster ops — elementary and source freshness."""

import subprocess
from unittest.mock import patch

import pytest
from dagster import build_op_context

from data_platform.ops.check_source_freshness import (
    SourceFreshnessConfig,
)
from data_platform.ops.elementary import (
    _elementary_schema_exists,
    elementary_generate_report,
    elementary_run_models,
)


# ---------------------------------------------------------------------------
# SourceFreshnessConfig
# ---------------------------------------------------------------------------


class TestSourceFreshnessConfig:
    def test_accepts_source_name(self):
        cfg = SourceFreshnessConfig(source_name="raw_funda")
        assert cfg.source_name == "raw_funda"


# ---------------------------------------------------------------------------
# _elementary_schema_exists
# ---------------------------------------------------------------------------


class TestElementarySchemaExists:
    @patch("data_platform.resources._retry_on_operational_error")
    @patch("data_platform.ops.elementary.create_engine")
    def test_returns_true_when_schema_present(self, mock_create_engine, mock_retry):
        mock_retry.return_value = True
        with patch.dict(
            "os.environ",
            {
                "POSTGRES_USER": "u",
                "POSTGRES_PASSWORD": "p",
                "POSTGRES_HOST": "localhost",
                "POSTGRES_PORT": "5432",
                "POSTGRES_DB": "db",
            },
        ):
            result = _elementary_schema_exists()
        assert result is True
        mock_create_engine.assert_called_once()

    @patch("data_platform.resources._retry_on_operational_error")
    @patch("data_platform.ops.elementary.create_engine")
    def test_returns_false_when_schema_absent(self, mock_create_engine, mock_retry):
        mock_retry.return_value = False
        with patch.dict(
            "os.environ",
            {
                "POSTGRES_USER": "u",
                "POSTGRES_PASSWORD": "p",
                "POSTGRES_HOST": "localhost",
                "POSTGRES_PORT": "5432",
                "POSTGRES_DB": "db",
            },
        ):
            result = _elementary_schema_exists()
        assert result is False


# ---------------------------------------------------------------------------
# elementary_run_models
# ---------------------------------------------------------------------------


class TestElementaryRunModels:
    @patch("data_platform.ops.elementary._elementary_schema_exists", return_value=True)
    def test_skips_when_schema_exists(self, mock_exists):
        context = build_op_context()
        elementary_run_models(context)
        mock_exists.assert_called_once()

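    # Simulate a successful dbt invocation: subprocess.run returns exit code 0.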
    @patch(
        "data_platform.ops.elementary.subprocess.run",
        return_value=subprocess.CompletedProcess(
            args=[], returncode=0, stdout="ok", stderr=""
        ),
    )
    @patch("data_platform.ops.elementary._elementary_schema_exists", return_value=False)
    def test_runs_dbt_when_schema_missing(self, mock_exists, mock_run):
        context = build_op_context()
        elementary_run_models(context)
        mock_run.assert_called_once()
        args = mock_run.call_args[0][0]
        assert "dbt" in args
        assert "run" in args
        assert "elementary" in args

    @patch(
        "data_platform.ops.elementary.subprocess.run",
        return_value=subprocess.CompletedProcess(
            args=[], returncode=1, stdout="", stderr="error"
        ),
    )
    @patch("data_platform.ops.elementary._elementary_schema_exists", return_value=False)
    def test_raises_on_dbt_failure(self, mock_exists, mock_run):
        context = build_op_context()
        with pytest.raises(Exception, match="dbt run elementary failed"):
            elementary_run_models(context)


# ---------------------------------------------------------------------------
# elementary_generate_report
# ---------------------------------------------------------------------------


class TestElementaryGenerateReport:
    @patch(
        "data_platform.ops.elementary.subprocess.run",
        return_value=subprocess.CompletedProcess(
            args=[], returncode=0, stdout="report generated", stderr=""
        ),
    )
    def test_calls_edr_report(self, mock_run):
        context = build_op_context()
        elementary_generate_report(context)
        mock_run.assert_called_once()
        args = mock_run.call_args[0][0]
        assert "edr" in args
        assert "report" in args

    @patch(
        "data_platform.ops.elementary.subprocess.run",
        return_value=subprocess.CompletedProcess(
            args=[], returncode=1, stdout="", stderr="fatal error"
        ),
    )
    def test_raises_on_failure(self, mock_run):
        context = build_op_context()
        with pytest.raises(Exception, match="edr report failed"):
            elementary_generate_report(context)

    @patch(
        "data_platform.ops.elementary.subprocess.run",
        return_value=subprocess.CompletedProcess(
            args=[], returncode=0, stdout="done", stderr=""
        ),
    )
    def test_success_returns_none(self, mock_run):
        context = build_op_context()
        result = elementary_generate_report(context)
        assert result is None