test: expand unit tests for ml and dagster resources
This commit is contained in:
390
tests/test_ml.py
390
tests/test_ml.py
@@ -1,18 +1,24 @@
|
||||
"""Tests for ML helper functions — preprocess, _best_run, _build_embed."""
|
||||
"""Tests for ML helper functions and asset functions."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from dagster import build_asset_context
|
||||
|
||||
from data_platform.assets.ml.elo_model import (
|
||||
ALL_FEATURES,
|
||||
ENERGY_LABEL_MAP,
|
||||
_cleanup_old_runs,
|
||||
preprocess,
|
||||
)
|
||||
from data_platform.assets.ml.discord_alerts import _build_embed
|
||||
from data_platform.assets.ml.elo_inference import _best_run
|
||||
from data_platform.assets.ml.discord_alerts import (
|
||||
DiscordNotificationConfig,
|
||||
_build_embed,
|
||||
)
|
||||
from data_platform.assets.ml.elo_inference import EloInferenceConfig, _best_run
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -201,3 +207,381 @@ class TestBuildEmbed:
|
||||
def test_color_is_green(self):
|
||||
embed = _build_embed(self._make_row())
|
||||
assert embed["color"] == 0x00B894
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _cleanup_old_runs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCleanupOldRuns:
|
||||
@patch("data_platform.assets.ml.elo_model.mlflow")
|
||||
def test_no_experiment_noop(self, mock_mlflow):
|
||||
client = mock_mlflow.tracking.MlflowClient.return_value
|
||||
client.get_experiment_by_name.return_value = None
|
||||
context = MagicMock()
|
||||
_cleanup_old_runs("nonexistent", context, keep=3)
|
||||
client.delete_run.assert_not_called()
|
||||
|
||||
@patch("data_platform.assets.ml.elo_model.mlflow")
|
||||
def test_fewer_than_keep_noop(self, mock_mlflow):
|
||||
client = mock_mlflow.tracking.MlflowClient.return_value
|
||||
client.get_experiment_by_name.return_value = MagicMock(experiment_id="1")
|
||||
client.search_runs.return_value = [MagicMock(), MagicMock()]
|
||||
context = MagicMock()
|
||||
_cleanup_old_runs("exp", context, keep=3)
|
||||
client.delete_run.assert_not_called()
|
||||
|
||||
@patch("data_platform.assets.ml.elo_model.mlflow")
|
||||
def test_deletes_old_runs(self, mock_mlflow):
|
||||
client = mock_mlflow.tracking.MlflowClient.return_value
|
||||
client.get_experiment_by_name.return_value = MagicMock(experiment_id="1")
|
||||
runs = [MagicMock() for _ in range(5)]
|
||||
for i, run in enumerate(runs):
|
||||
run.info.run_id = f"run_{i}"
|
||||
client.search_runs.return_value = runs
|
||||
context = MagicMock()
|
||||
_cleanup_old_runs("exp", context, keep=2)
|
||||
assert client.delete_run.call_count == 3
|
||||
client.delete_run.assert_any_call("run_2")
|
||||
client.delete_run.assert_any_call("run_3")
|
||||
client.delete_run.assert_any_call("run_4")
|
||||
|
||||
@patch("data_platform.assets.ml.elo_model.mlflow")
|
||||
def test_exactly_keep_noop(self, mock_mlflow):
|
||||
client = mock_mlflow.tracking.MlflowClient.return_value
|
||||
client.get_experiment_by_name.return_value = MagicMock(experiment_id="1")
|
||||
client.search_runs.return_value = [MagicMock(), MagicMock(), MagicMock()]
|
||||
context = MagicMock()
|
||||
_cleanup_old_runs("exp", context, keep=3)
|
||||
client.delete_run.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# elo_prediction_model asset
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_training_df(n=20):
|
||||
"""Build a small training DataFrame with all required columns."""
|
||||
rng = np.random.default_rng(42)
|
||||
data = {
|
||||
"energy_label": rng.choice(["A", "B", "C", "D"], size=n),
|
||||
"current_price": rng.integers(200_000, 500_000, size=n).astype(float),
|
||||
"living_area": rng.integers(40, 150, size=n).astype(float),
|
||||
"plot_area": rng.integers(0, 300, size=n).astype(float),
|
||||
"bedrooms": rng.integers(1, 5, size=n).astype(float),
|
||||
"rooms": rng.integers(2, 8, size=n).astype(float),
|
||||
"construction_year": rng.integers(1900, 2025, size=n).astype(float),
|
||||
"latitude": rng.uniform(51, 53, size=n),
|
||||
"longitude": rng.uniform(3, 6, size=n),
|
||||
"photo_count": rng.integers(1, 30, size=n).astype(float),
|
||||
"views": rng.integers(10, 500, size=n).astype(float),
|
||||
"saves": rng.integers(0, 50, size=n).astype(float),
|
||||
"price_per_sqm": rng.integers(2000, 6000, size=n).astype(float),
|
||||
"has_garden": rng.choice([True, False, None], size=n),
|
||||
"has_balcony": rng.choice([True, False], size=n),
|
||||
"has_solar_panels": rng.choice([True, False, None], size=n),
|
||||
"has_heat_pump": rng.choice([True, False], size=n),
|
||||
"has_roof_terrace": rng.choice([True, False, None], size=n),
|
||||
"is_energy_efficient": rng.choice([True, False], size=n),
|
||||
"is_monument": rng.choice([True, False], size=n),
|
||||
"elo_rating": rng.uniform(1200, 1800, size=n),
|
||||
}
|
||||
return pd.DataFrame(data)
|
||||
|
||||
|
||||
class TestEloPredictionModelAsset:
|
||||
@patch("data_platform.assets.ml.elo_model._cleanup_old_runs")
|
||||
@patch("data_platform.assets.ml.elo_model.mlflow")
|
||||
def test_trains_and_returns_metadata(self, mock_mlflow, mock_cleanup):
|
||||
# Setup MLflow mock
|
||||
mock_run = MagicMock()
|
||||
mock_run.info.run_id = "test-run-123"
|
||||
mock_mlflow.start_run.return_value.__enter__ = MagicMock(return_value=mock_run)
|
||||
mock_mlflow.start_run.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
# Setup Postgres mock
|
||||
postgres = MagicMock()
|
||||
engine = MagicMock()
|
||||
postgres.get_engine.return_value = engine
|
||||
mlflow_resource = MagicMock()
|
||||
mlflow_resource.get_tracking_uri.return_value = "http://mlflow:5000"
|
||||
|
||||
df = _make_training_df(20)
|
||||
context = build_asset_context()
|
||||
|
||||
from data_platform.assets.ml.elo_model import (
|
||||
EloModelConfig,
|
||||
elo_prediction_model,
|
||||
)
|
||||
|
||||
config = EloModelConfig(n_estimators=10)
|
||||
|
||||
with patch("data_platform.assets.ml.elo_model.pd.read_sql", return_value=df):
|
||||
result = elo_prediction_model(context, config, postgres, mlflow_resource)
|
||||
|
||||
assert result.metadata["mlflow_run_id"].value == "test-run-123"
|
||||
assert "rmse" in result.metadata
|
||||
assert "mae" in result.metadata
|
||||
assert "r2" in result.metadata
|
||||
assert result.metadata["train_rows"] > 0
|
||||
assert result.metadata["test_rows"] > 0
|
||||
mock_cleanup.assert_called_once()
|
||||
|
||||
@patch("data_platform.assets.ml.elo_model.mlflow")
|
||||
def test_raises_with_too_few_rows(self, mock_mlflow):
|
||||
postgres = MagicMock()
|
||||
engine = MagicMock()
|
||||
postgres.get_engine.return_value = engine
|
||||
mlflow_resource = MagicMock()
|
||||
mlflow_resource.get_tracking_uri.return_value = "http://mlflow:5000"
|
||||
|
||||
df = _make_training_df(5) # fewer than 10
|
||||
context = build_asset_context()
|
||||
|
||||
from data_platform.assets.ml.elo_model import (
|
||||
EloModelConfig,
|
||||
elo_prediction_model,
|
||||
)
|
||||
|
||||
config = EloModelConfig()
|
||||
|
||||
with (
|
||||
patch("data_platform.assets.ml.elo_model.pd.read_sql", return_value=df),
|
||||
pytest.raises(ValueError, match="Not enough rated listings"),
|
||||
):
|
||||
elo_prediction_model(context, config, postgres, mlflow_resource)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# elo_inference asset
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEloInferenceAsset:
|
||||
def _make_unscored_df(self, n=5):
|
||||
"""Build a DataFrame of unscored listings."""
|
||||
rng = np.random.default_rng(42)
|
||||
data = {
|
||||
"global_id": [f"listing_{i}" for i in range(n)],
|
||||
"energy_label": rng.choice(["A", "B", "C"], size=n),
|
||||
"current_price": rng.integers(200_000, 500_000, size=n).astype(float),
|
||||
"living_area": rng.integers(40, 150, size=n).astype(float),
|
||||
"plot_area": rng.integers(0, 300, size=n).astype(float),
|
||||
"bedrooms": rng.integers(1, 5, size=n).astype(float),
|
||||
"rooms": rng.integers(2, 8, size=n).astype(float),
|
||||
"construction_year": rng.integers(1900, 2025, size=n).astype(float),
|
||||
"latitude": rng.uniform(51, 53, size=n),
|
||||
"longitude": rng.uniform(3, 6, size=n),
|
||||
"photo_count": rng.integers(1, 30, size=n).astype(float),
|
||||
"views": rng.integers(10, 500, size=n).astype(float),
|
||||
"saves": rng.integers(0, 50, size=n).astype(float),
|
||||
"price_per_sqm": rng.integers(2000, 6000, size=n).astype(float),
|
||||
"has_garden": rng.choice([True, False], size=n),
|
||||
"has_balcony": rng.choice([True, False], size=n),
|
||||
"has_solar_panels": rng.choice([True, False], size=n),
|
||||
"has_heat_pump": rng.choice([True, False], size=n),
|
||||
"has_roof_terrace": rng.choice([True, False], size=n),
|
||||
"is_energy_efficient": rng.choice([True, False], size=n),
|
||||
"is_monument": rng.choice([True, False], size=n),
|
||||
}
|
||||
return pd.DataFrame(data)
|
||||
|
||||
@patch("data_platform.assets.ml.elo_inference.mlflow")
|
||||
@patch("data_platform.assets.ml.elo_inference.pd.read_sql")
|
||||
@patch("data_platform.assets.ml.elo_inference.render_sql", return_value="SQL")
|
||||
def test_scores_listings(self, mock_render, mock_read_sql, mock_mlflow):
|
||||
from data_platform.assets.ml.elo_inference import elo_inference
|
||||
|
||||
df = self._make_unscored_df(5)
|
||||
mock_read_sql.return_value = df
|
||||
|
||||
# Mock MLflow best run
|
||||
best_run = MagicMock()
|
||||
best_run.info.run_id = "run-abc"
|
||||
best_run.data.metrics = {"rmse": 0.5}
|
||||
client = mock_mlflow.tracking.MlflowClient.return_value
|
||||
client.get_experiment_by_name.return_value = MagicMock(experiment_id="1")
|
||||
client.search_runs.return_value = [best_run]
|
||||
|
||||
# Mock model
|
||||
mock_model = MagicMock()
|
||||
mock_model.predict.return_value = np.random.default_rng(42).uniform(
|
||||
-3, 3, size=5
|
||||
)
|
||||
mock_mlflow.lightgbm.load_model.return_value = mock_model
|
||||
|
||||
# Mock resources
|
||||
postgres = MagicMock()
|
||||
engine = MagicMock()
|
||||
conn = MagicMock()
|
||||
engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
|
||||
engine.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
postgres.get_engine.return_value = engine
|
||||
|
||||
mlflow_resource = MagicMock()
|
||||
mlflow_resource.get_tracking_uri.return_value = "http://mlflow:5000"
|
||||
|
||||
context = build_asset_context()
|
||||
config = EloInferenceConfig()
|
||||
|
||||
result = elo_inference(context, config, postgres, mlflow_resource)
|
||||
|
||||
assert result.metadata["scored"] == 5
|
||||
assert result.metadata["mlflow_run_id"].value == "run-abc"
|
||||
assert "predicted_elo_mean" in result.metadata
|
||||
mock_model.predict.assert_called_once()
|
||||
|
||||
@patch("data_platform.assets.ml.elo_inference.pd.read_sql")
|
||||
@patch("data_platform.assets.ml.elo_inference.render_sql", return_value="SQL")
|
||||
def test_empty_returns_zero_scored(self, mock_render, mock_read_sql):
|
||||
from data_platform.assets.ml.elo_inference import elo_inference
|
||||
|
||||
mock_read_sql.return_value = pd.DataFrame()
|
||||
|
||||
postgres = MagicMock()
|
||||
engine = MagicMock()
|
||||
conn = MagicMock()
|
||||
engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
|
||||
engine.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
postgres.get_engine.return_value = engine
|
||||
|
||||
mlflow_resource = MagicMock()
|
||||
context = build_asset_context()
|
||||
config = EloInferenceConfig()
|
||||
|
||||
result = elo_inference(context, config, postgres, mlflow_resource)
|
||||
|
||||
assert result.metadata["scored"] == 0
|
||||
assert result.metadata["status"].value == "No new listings to score."
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# listing_alert asset
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestListingAlertAsset:
|
||||
def _make_predictions_df(self, n=3):
|
||||
"""Build a DataFrame of high-ELO predictions."""
|
||||
data = {
|
||||
"global_id": [f"listing_{i}" for i in range(n)],
|
||||
"predicted_elo": [1650.0 + i * 10 for i in range(n)],
|
||||
"current_price": [350_000 + i * 10_000 for i in range(n)],
|
||||
"city": ["Amsterdam", "Rotterdam", "Utrecht"][:n],
|
||||
"living_area": [80 + i * 5 for i in range(n)],
|
||||
"rooms": [4, 5, 3][:n],
|
||||
"energy_label": ["A", "B", "C"][:n],
|
||||
"price_per_sqm": [4375, 4000, 3800][:n],
|
||||
"title": [f"Teststraat {i}" for i in range(n)],
|
||||
"url": [f"https://funda.nl/{i}" for i in range(n)],
|
||||
}
|
||||
return pd.DataFrame(data)
|
||||
|
||||
@patch("data_platform.assets.ml.discord_alerts.requests.post")
|
||||
@patch("data_platform.assets.ml.discord_alerts.pd.read_sql")
|
||||
@patch("data_platform.assets.ml.discord_alerts.render_sql", return_value="SQL")
|
||||
def test_sends_notifications(self, mock_render, mock_read_sql, mock_post):
|
||||
from data_platform.assets.ml.discord_alerts import listing_alert
|
||||
|
||||
df = self._make_predictions_df(3)
|
||||
mock_read_sql.return_value = df
|
||||
mock_post.return_value = MagicMock(status_code=200)
|
||||
|
||||
postgres = MagicMock()
|
||||
engine = MagicMock()
|
||||
conn = MagicMock()
|
||||
engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
|
||||
engine.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
postgres.get_engine.return_value = engine
|
||||
|
||||
discord = MagicMock()
|
||||
discord.get_webhook_url.return_value = "https://discord.com/webhook/test"
|
||||
|
||||
context = build_asset_context()
|
||||
config = DiscordNotificationConfig()
|
||||
|
||||
result = listing_alert(context, config, postgres, discord)
|
||||
|
||||
assert result.metadata["notified"] == 3
|
||||
mock_post.assert_called_once()
|
||||
|
||||
@patch("data_platform.assets.ml.discord_alerts.pd.read_sql")
|
||||
@patch("data_platform.assets.ml.discord_alerts.render_sql", return_value="SQL")
|
||||
def test_empty_returns_zero_notified(self, mock_render, mock_read_sql):
|
||||
from data_platform.assets.ml.discord_alerts import listing_alert
|
||||
|
||||
mock_read_sql.return_value = pd.DataFrame()
|
||||
|
||||
postgres = MagicMock()
|
||||
engine = MagicMock()
|
||||
conn = MagicMock()
|
||||
engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
|
||||
engine.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
postgres.get_engine.return_value = engine
|
||||
|
||||
discord = MagicMock()
|
||||
context = build_asset_context()
|
||||
config = DiscordNotificationConfig()
|
||||
|
||||
result = listing_alert(context, config, postgres, discord)
|
||||
|
||||
assert result.metadata["notified"] == 0
|
||||
|
||||
@patch("data_platform.assets.ml.discord_alerts.requests.post")
|
||||
@patch("data_platform.assets.ml.discord_alerts.pd.read_sql")
|
||||
@patch("data_platform.assets.ml.discord_alerts.render_sql", return_value="SQL")
|
||||
def test_batches_large_result_set(self, mock_render, mock_read_sql, mock_post):
|
||||
from data_platform.assets.ml.discord_alerts import listing_alert
|
||||
|
||||
# 15 listings → should produce 2 batches (10 + 5)
|
||||
data = {
|
||||
"global_id": [f"listing_{i}" for i in range(15)],
|
||||
"predicted_elo": [1700.0] * 15,
|
||||
"current_price": [400_000] * 15,
|
||||
"city": ["Amsterdam"] * 15,
|
||||
"living_area": [90] * 15,
|
||||
"rooms": [4] * 15,
|
||||
"energy_label": ["A"] * 15,
|
||||
"price_per_sqm": [4444] * 15,
|
||||
"title": [f"Street {i}" for i in range(15)],
|
||||
"url": [f"https://funda.nl/{i}" for i in range(15)],
|
||||
}
|
||||
mock_read_sql.return_value = pd.DataFrame(data)
|
||||
mock_post.return_value = MagicMock(status_code=200)
|
||||
|
||||
postgres = MagicMock()
|
||||
engine = MagicMock()
|
||||
conn = MagicMock()
|
||||
engine.begin.return_value.__enter__ = MagicMock(return_value=conn)
|
||||
engine.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
postgres.get_engine.return_value = engine
|
||||
|
||||
discord = MagicMock()
|
||||
discord.get_webhook_url.return_value = "https://discord.com/webhook/test"
|
||||
|
||||
context = build_asset_context()
|
||||
config = DiscordNotificationConfig()
|
||||
|
||||
result = listing_alert(context, config, postgres, discord)
|
||||
|
||||
assert result.metadata["notified"] == 15
|
||||
assert mock_post.call_count == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config defaults
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMLConfigs:
|
||||
def test_elo_inference_config_defaults(self):
|
||||
config = EloInferenceConfig()
|
||||
assert config.mlflow_experiment == "elo-rating-prediction"
|
||||
assert config.metric == "rmse"
|
||||
assert config.ascending is True
|
||||
|
||||
def test_discord_notification_config_defaults(self):
|
||||
config = DiscordNotificationConfig()
|
||||
assert config.min_elo == 1600
|
||||
|
||||
Reference in New Issue
Block a user