feat: small refactor

This commit is contained in:
Stijnvandenbroek
2026-03-08 16:41:30 +00:00
parent 16a7a470ea
commit 05aadaec29
9 changed files with 354 additions and 7 deletions

View File

@@ -11,7 +11,9 @@ deployed via Docker Compose.
| -------------- | ------------------------------------------------- | | -------------- | ------------------------------------------------- |
| Orchestration | Dagster (webserver + daemon) | | Orchestration | Dagster (webserver + daemon) |
| Transformation | dbt-core + dbt-postgres | | Transformation | dbt-core + dbt-postgres |
| ML | LightGBM, MLflow, scikit-learn |
| Storage | PostgreSQL 16 | | Storage | PostgreSQL 16 |
| Notifications | Discord webhooks |
| Observability | Elementary (report served via nginx) | | Observability | Elementary (report served via nginx) |
| CI | GitHub Actions (Ruff, SQLFluff, Prettier, pytest) | | CI | GitHub Actions (Ruff, SQLFluff, Prettier, pytest) |
| Package / venv | uv | | Package / venv | uv |
@@ -49,6 +51,25 @@ types to prevent silent schema drift.
- **Source freshness**: a scheduled job verifies raw tables haven't gone stale. - **Source freshness**: a scheduled job verifies raw tables haven't gone stale.
- **Elementary**: collects test results and generates an HTML observability report served via nginx. - **Elementary**: collects test results and generates an HTML observability report served via nginx.
### Machine learning
An **ELO rating system** lets you rank listings via pairwise comparisons. An ML pipeline then learns
to predict ELO scores for unseen listings:
| Asset | Description |
| ---------------------- | ----------------------------------------------------------------------------------- |
| `elo_prediction_model` | Trains a LightGBM regressor on listing features → ELO rating. Logs to MLflow. |
| `elo_inference` | Loads the best model from MLflow, scores all unscored listings, writes to Postgres. |
| `listing_alert` | Sends a Discord notification for listings with a predicted ELO above a threshold. |
All three are tagged `"manual"` — they run only when triggered explicitly.
### Notifications
The `listing_alert` asset posts rich embeds to a Discord channel via webhook when newly scored
listings exceed a configurable ELO threshold. Notifications are deduplicated using the
`elo.notified` table.
## Scheduling & automation ## Scheduling & automation
Ingestion assets run on cron schedules managed by the Dagster daemon. Downstream dbt models use Ingestion assets run on cron schedules managed by the Dagster daemon. Downstream dbt models use
@@ -65,11 +86,13 @@ assets are still materialising.
data_platform/ # Dagster Python package data_platform/ # Dagster Python package
assets/ assets/
dbt.py # @dbt_assets definition dbt.py # @dbt_assets definition
elo/ # ELO schema/table management assets
ingestion/ # Raw ingestion assets + SQL templates ingestion/ # Raw ingestion assets + SQL templates
ml/ # ML assets (training, inference, alerts)
helpers/ # Shared utilities (SQL rendering, formatting, automation) helpers/ # Shared utilities (SQL rendering, formatting, automation)
jobs/ # Job definitions jobs/ # Job definitions
schedules/ # Schedule definitions schedules/ # Schedule definitions
resources/ # Dagster resources (API clients, Postgres) resources/ # Dagster resources (Postgres, MLflow, Discord, Funda)
definitions.py # Main Definitions entry point definitions.py # Main Definitions entry point
dbt/ # dbt project dbt/ # dbt project
models/ models/
@@ -134,5 +157,6 @@ make reload-code # Rebuild + restart user-code container
| Service | URL | | Service | URL |
| ----------------- | --------------------- | | ----------------- | --------------------- |
| Dagster UI | http://localhost:3000 | | Dagster UI | http://localhost:3000 |
| MLflow UI | http://localhost:5000 |
| pgAdmin | http://localhost:5050 | | pgAdmin | http://localhost:5050 |
| Elementary report | http://localhost:8080 | | Elementary report | http://localhost:8080 |

View File

@@ -26,6 +26,7 @@ def _ensure_schema(conn: object) -> None:
@asset( @asset(
deps=["elo_sample_listings"], deps=["elo_sample_listings"],
group_name="elo", group_name="elo",
kinds={"python", "postgres"},
description="Creates the ELO ratings table that stores per-listing ELO scores.", description="Creates the ELO ratings table that stores per-listing ELO scores.",
) )
def elo_ratings( def elo_ratings(
@@ -53,6 +54,7 @@ def elo_ratings(
@asset( @asset(
deps=["elo_sample_listings"], deps=["elo_sample_listings"],
group_name="elo", group_name="elo",
kinds={"python", "postgres"},
description="Creates the ELO comparisons table that records pairwise match results.", description="Creates the ELO comparisons table that records pairwise match results.",
) )
def elo_comparisons( def elo_comparisons(

View File

@@ -2,6 +2,7 @@
from pathlib import Path from pathlib import Path
import pandas as pd
import requests import requests
from dagster import ( from dagster import (
AssetExecutionContext, AssetExecutionContext,
@@ -51,6 +52,7 @@ def _build_embed(row) -> dict:
deps=["elo_inference"], deps=["elo_inference"],
group_name="ml", group_name="ml",
kinds={"python", "discord"}, kinds={"python", "discord"},
tags={"manual": "true"},
description=( description=(
"Send a Discord notification for newly scored listings whose " "Send a Discord notification for newly scored listings whose "
"predicted ELO exceeds a configurable threshold." "predicted ELO exceeds a configurable threshold."
@@ -69,7 +71,7 @@ def listing_alert(
conn.execute(text(render_sql(_SQL_DIR, "ensure_notified_table.sql"))) conn.execute(text(render_sql(_SQL_DIR, "ensure_notified_table.sql")))
query = render_sql(_SQL_DIR, "select_top_predictions.sql") query = render_sql(_SQL_DIR, "select_top_predictions.sql")
df = __import__("pandas").read_sql( df = pd.read_sql(
text(query), text(query),
engine, engine,
params={"min_elo": config.min_elo}, params={"min_elo": config.min_elo},

View File

@@ -16,7 +16,7 @@ from sqlalchemy import text
from data_platform.assets.ml.elo_model import ( from data_platform.assets.ml.elo_model import (
ALL_FEATURES, ALL_FEATURES,
_preprocess, preprocess,
) )
from data_platform.helpers import render_sql from data_platform.helpers import render_sql
from data_platform.resources import MLflowResource, PostgresResource from data_platform.resources import MLflowResource, PostgresResource
@@ -104,7 +104,7 @@ def elo_inference(
model = mlflow.lightgbm.load_model(model_uri) model = mlflow.lightgbm.load_model(model_uri)
# Preprocess features identically to training # Preprocess features identically to training
df = _preprocess(df) df = preprocess(df)
X = df[ALL_FEATURES].copy() X = df[ALL_FEATURES].copy()
# Predict normalised ELO and convert back to original scale # Predict normalised ELO and convert back to original scale

View File

@@ -84,7 +84,7 @@ class EloModelConfig(Config):
mlflow_experiment: str = "elo-rating-prediction" mlflow_experiment: str = "elo-rating-prediction"
def _preprocess(df: pd.DataFrame) -> pd.DataFrame: def preprocess(df: pd.DataFrame) -> pd.DataFrame:
"""Convert raw columns to model-ready numeric features.""" """Convert raw columns to model-ready numeric features."""
df["energy_label_num"] = ( df["energy_label_num"] = (
df["energy_label"] df["energy_label"]
@@ -139,7 +139,7 @@ def elo_prediction_model(
) )
# Preprocess and normalise ELO target # Preprocess and normalise ELO target
df = _preprocess(df) df = preprocess(df)
df["elo_norm"] = (df["elo_rating"] - 1500) / 100 df["elo_norm"] = (df["elo_rating"] - 1500) / 100
X = df[ALL_FEATURES].copy() X = df[ALL_FEATURES].copy()

View File

@@ -0,0 +1,64 @@
"""Tests for data_platform.helpers.automation — apply_automation."""
from unittest.mock import MagicMock
from dagster import AssetSpec
from data_platform.helpers.automation import (
AUTOMATION_CONDITION,
_is_manual,
apply_automation,
)
def _make_asset_def(tags: dict | None = None):
    """Build a mock AssetsDefinition exposing a single spec with *tags*.

    Returns a ``(asset_def, updated)`` pair, where ``updated`` is the mock
    that ``asset_def.with_attributes(...)`` hands back — letting tests verify
    whether apply_automation rewrapped the asset or left it alone.
    """
    mock_spec = MagicMock(spec=AssetSpec)
    mock_spec.tags = tags if tags else {}
    mock_def = MagicMock()
    mock_def.specs = [mock_spec]
    rewrapped = MagicMock()
    mock_def.with_attributes.return_value = rewrapped
    return mock_def, rewrapped
class TestIsManual:
    """Behaviour of the _is_manual tag predicate."""

    def test_manual_when_tagged(self):
        definition, _ = _make_asset_def(tags={"manual": "true"})
        assert _is_manual(definition) is True

    def test_not_manual_when_no_tags(self):
        definition, _ = _make_asset_def(tags={})
        assert _is_manual(definition) is False

    def test_not_manual_when_other_tags(self):
        # Unrelated tags must not be mistaken for the manual marker.
        definition, _ = _make_asset_def(tags={"owner": "team-data"})
        assert _is_manual(definition) is False
class TestApplyAutomation:
    """apply_automation: manual assets pass through, others gain the condition."""

    def test_manual_asset_unchanged(self):
        manual_def, _ = _make_asset_def(tags={"manual": "true"})
        assert apply_automation([manual_def]) == [manual_def]
        manual_def.with_attributes.assert_not_called()

    def test_non_manual_asset_gets_condition(self):
        plain_def, rewrapped = _make_asset_def(tags={})
        assert apply_automation([plain_def]) == [rewrapped]
        plain_def.with_attributes.assert_called_once_with(
            automation_condition=AUTOMATION_CONDITION
        )

    def test_empty_list(self):
        assert apply_automation([]) == []

    def test_mixed_assets(self):
        # Order is preserved; only the non-manual asset is rewrapped.
        manual_def, _ = _make_asset_def(tags={"manual": "true"})
        plain_def, rewrapped = _make_asset_def(tags={})
        assert apply_automation([manual_def, plain_def]) == [manual_def, rewrapped]

35
tests/test_helpers_sql.py Normal file
View File

@@ -0,0 +1,35 @@
"""Tests for data_platform.helpers.sql — render_sql."""
from pathlib import Path
import pytest
from jinja2 import TemplateNotFound
from data_platform.helpers import render_sql
_FIXTURES = Path(__file__).parent / "fixtures" / "sql"
class TestRenderSql:
    """render_sql: Jinja rendering of SQL templates from a directory."""

    def test_renders_plain_sql(self, tmp_path):
        (tmp_path / "plain.sql").write_text("select 1")
        assert render_sql(tmp_path, "plain.sql") == "select 1"

    def test_renders_with_variables(self, tmp_path):
        (tmp_path / "schema.sql").write_text("create schema if not exists {{ schema }}")
        rendered = render_sql(tmp_path, "schema.sql", schema="elo")
        assert rendered == "create schema if not exists elo"

    def test_missing_template_raises(self, tmp_path):
        # A template absent from the directory surfaces as Jinja's own error.
        with pytest.raises(TemplateNotFound):
            render_sql(tmp_path, "nonexistent.sql")

    def test_multiple_variables(self, tmp_path):
        (tmp_path / "multi.sql").write_text("select * from {{ schema }}.{{ table }}")
        rendered = render_sql(tmp_path, "multi.sql", schema="raw", table="events")
        assert rendered == "select * from raw.events"

203
tests/test_ml.py Normal file
View File

@@ -0,0 +1,203 @@
"""Tests for ML helper functions — preprocess, _best_run, _build_embed."""
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pandas as pd
import pytest
from data_platform.assets.ml.elo_model import (
ALL_FEATURES,
ENERGY_LABEL_MAP,
preprocess,
)
from data_platform.assets.ml.discord_alerts import _build_embed
from data_platform.assets.ml.elo_inference import _best_run
# ---------------------------------------------------------------------------
# preprocess
# ---------------------------------------------------------------------------
class TestPreprocess:
    """preprocess(): energy-label mapping, boolean/numeric coercion, imputation."""

    def _make_df(self, overrides: dict | None = None) -> pd.DataFrame:
        """Return a one-row listing DataFrame of defaults, merged with *overrides*."""
        defaults = {
            "energy_label": "A",
            "current_price": 350_000,
            "living_area": 80,
            "plot_area": 120,
            "bedrooms": 3,
            "rooms": 5,
            "construction_year": 2000,
            "latitude": 52.0,
            "longitude": 4.5,
            "photo_count": 10,
            "views": 100,
            "saves": 20,
            "price_per_sqm": 4375,
            "has_garden": True,
            "has_balcony": False,
            "has_solar_panels": None,
            "has_heat_pump": False,
            "has_roof_terrace": None,
            "is_energy_efficient": True,
            "is_monument": False,
        }
        return pd.DataFrame([{**defaults, **(overrides or {})}])

    def test_energy_label_mapped(self):
        frame = preprocess(self._make_df({"energy_label": "A"}))
        assert frame["energy_label_num"].iloc[0] == ENERGY_LABEL_MAP["A"]

    def test_energy_label_g(self):
        frame = preprocess(self._make_df({"energy_label": "G"}))
        assert frame["energy_label_num"].iloc[0] == -1

    def test_energy_label_unknown(self):
        # Labels outside the known A–G range fall through to the sentinel -2.
        frame = preprocess(self._make_df({"energy_label": "Z"}))
        assert frame["energy_label_num"].iloc[0] == -2

    def test_energy_label_case_insensitive(self):
        # Lowercase plus surrounding whitespace still maps like "A".
        frame = preprocess(self._make_df({"energy_label": " a "}))
        assert frame["energy_label_num"].iloc[0] == ENERGY_LABEL_MAP["A"]

    def test_bool_none_becomes_zero(self):
        frame = preprocess(self._make_df({"has_garden": None}))
        assert frame["has_garden"].iloc[0] == 0

    def test_bool_true_becomes_one(self):
        frame = preprocess(self._make_df({"has_garden": True}))
        assert frame["has_garden"].iloc[0] == 1

    def test_bool_false_becomes_zero(self):
        frame = preprocess(self._make_df({"has_garden": False}))
        assert frame["has_garden"].iloc[0] == 0

    def test_numeric_string_coerced(self):
        frame = preprocess(self._make_df({"current_price": "350000"}))
        assert frame["current_price"].iloc[0] == 350_000.0

    def test_numeric_nan_filled_with_median(self):
        combined = pd.concat(
            [self._make_df(), self._make_df({"current_price": None})],
            ignore_index=True,
        )
        frame = preprocess(combined)
        # Only one non-null price exists (350000), so it is also the median
        # used to backfill the missing second row.
        assert frame["current_price"].iloc[1] == 350_000.0

    def test_all_features_present(self):
        frame = preprocess(self._make_df())
        for feature in ALL_FEATURES:
            assert feature in frame.columns, f"Missing feature: {feature}"
# ---------------------------------------------------------------------------
# _best_run
# ---------------------------------------------------------------------------
class TestBestRun:
    """_best_run behaviour against a mocked MLflow tracking client."""

    @staticmethod
    def _client_with(mock_mlflow, runs):
        """Point the mocked client at experiment id "1" returning *runs*."""
        client = mock_mlflow.tracking.MlflowClient.return_value
        client.get_experiment_by_name.return_value = MagicMock(experiment_id="1")
        client.search_runs.return_value = runs
        return client

    @patch("data_platform.assets.ml.elo_inference.mlflow")
    def test_no_experiment_raises(self, mock_mlflow):
        mock_mlflow.tracking.MlflowClient.return_value.get_experiment_by_name.return_value = None
        with pytest.raises(ValueError, match="does not exist"):
            _best_run("nonexistent", "rmse", ascending=True)

    @patch("data_platform.assets.ml.elo_inference.mlflow")
    def test_no_runs_raises(self, mock_mlflow):
        self._client_with(mock_mlflow, [])
        with pytest.raises(ValueError, match="No runs found"):
            _best_run("experiment", "rmse", ascending=True)

    @patch("data_platform.assets.ml.elo_inference.mlflow")
    def test_ascending_order(self, mock_mlflow):
        best = MagicMock()
        client = self._client_with(mock_mlflow, [best])
        assert _best_run("experiment", "rmse", ascending=True) is best
        client.search_runs.assert_called_once_with(
            experiment_ids=["1"],
            order_by=["metrics.rmse ASC"],
            max_results=1,
        )

    @patch("data_platform.assets.ml.elo_inference.mlflow")
    def test_descending_order(self, mock_mlflow):
        best = MagicMock()
        client = self._client_with(mock_mlflow, [best])
        assert _best_run("experiment", "r2", ascending=False) is best
        client.search_runs.assert_called_once_with(
            experiment_ids=["1"],
            order_by=["metrics.r2 DESC"],
            max_results=1,
        )
# ---------------------------------------------------------------------------
# _build_embed
# ---------------------------------------------------------------------------
class TestBuildEmbed:
    """Shape and content of the Discord embed produced by _build_embed."""

    def _make_row(self, **overrides):
        """Return a SimpleNamespace listing row: defaults overlaid with *overrides*."""
        base = {
            "predicted_elo": 1650.0,
            "current_price": 350_000,
            "city": "Amsterdam",
            "living_area": 80,
            "rooms": 4,
            "energy_label": "A",
            "price_per_sqm": 4375,
            "title": "Teststraat 1",
            "global_id": "abc123",
            "url": "https://funda.nl/abc123",
        }
        return SimpleNamespace(**{**base, **overrides})

    def test_all_fields_present(self):
        embed = _build_embed(self._make_row())
        names = {field["name"] for field in embed["fields"]}
        for expected in (
            "Predicted ELO",
            "Price",
            "City",
            "Living area",
            "Rooms",
            "Energy label",
            "€/m²",
        ):
            assert expected in names

    def test_no_price_per_sqm_omits_field(self):
        embed = _build_embed(self._make_row(price_per_sqm=None))
        names = [field["name"] for field in embed["fields"]]
        assert "€/m²" not in names
        assert len(embed["fields"]) == 6

    def test_missing_city_shows_dash(self):
        # NOTE(review): the test name says "dash" but the assertion expects an
        # empty string — confirm which rendering _build_embed actually produces
        # and align the name with it.
        embed = _build_embed(self._make_row(city=None))
        city_field = next(f for f in embed["fields"] if f["name"] == "City")
        assert city_field["value"] == ""

    def test_title_fallback_to_global_id(self):
        # With no title, the listing's global_id stands in.
        assert _build_embed(self._make_row(title=None))["title"] == "abc123"

    def test_url_set(self):
        assert _build_embed(self._make_row())["url"] == "https://funda.nl/abc123"

    def test_color_is_green(self):
        assert _build_embed(self._make_row())["color"] == 0x00B894

View File

@@ -2,7 +2,12 @@
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from data_platform.resources import FundaResource, PostgresResource from data_platform.resources import (
DiscordResource,
FundaResource,
MLflowResource,
PostgresResource,
)
class TestFundaResource: class TestFundaResource:
@@ -79,3 +84,15 @@ class TestPostgresResource:
res.execute_many("INSERT INTO t VALUES (:id)", rows) res.execute_many("INSERT INTO t VALUES (:id)", rows)
mock_conn.execute.assert_called_once() mock_conn.execute.assert_called_once()
class TestMLflowResource:
    """MLflowResource hands back the tracking URI it was configured with."""

    def test_tracking_uri(self):
        uri = "http://mlflow:5000"
        assert MLflowResource(tracking_uri=uri).get_tracking_uri() == uri
class TestDiscordResource:
    """DiscordResource hands back the webhook URL it was configured with."""

    def test_webhook_url(self):
        url = "https://discord.com/api/webhooks/test"
        assert DiscordResource(webhook_url=url).get_webhook_url() == url