feat: small refactor
This commit is contained in:
26
README.md
26
README.md
@@ -11,7 +11,9 @@ deployed via Docker Compose.
|
||||
| -------------- | ------------------------------------------------- |
|
||||
| Orchestration | Dagster (webserver + daemon) |
|
||||
| Transformation | dbt-core + dbt-postgres |
|
||||
| ML | LightGBM, MLflow, scikit-learn |
|
||||
| Storage | PostgreSQL 16 |
|
||||
| Notifications | Discord webhooks |
|
||||
| Observability | Elementary (report served via nginx) |
|
||||
| CI | GitHub Actions (Ruff, SQLFluff, Prettier, pytest) |
|
||||
| Package / venv | uv |
|
||||
@@ -49,6 +51,25 @@ types to prevent silent schema drift.
|
||||
- **Source freshness**: a scheduled job verifies raw tables haven't gone stale.
|
||||
- **Elementary**: collects test results and generates an HTML observability report served via nginx.
|
||||
|
||||
### Machine learning
|
||||
|
||||
An **ELO rating system** lets you rank listings via pairwise comparisons. An ML pipeline then learns
|
||||
to predict ELO scores for unseen listings:
|
||||
|
||||
| Asset | Description |
|
||||
| ---------------------- | ----------------------------------------------------------------------------------- |
|
||||
| `elo_prediction_model` | Trains a LightGBM regressor on listing features → ELO rating. Logs to MLflow. |
|
||||
| `elo_inference` | Loads the best model from MLflow, scores all unscored listings, writes to Postgres. |
|
||||
| `listing_alert` | Sends a Discord notification for listings with a predicted ELO above a threshold. |
|
||||
|
||||
All three are tagged `"manual"` — they run only when triggered explicitly.
|
||||
|
||||
### Notifications
|
||||
|
||||
The `listing_alert` asset posts rich embeds to a Discord channel via webhook when newly scored
|
||||
listings exceed a configurable ELO threshold. Notifications are deduplicated using the
|
||||
`elo.notified` table.
|
||||
|
||||
## Scheduling & automation
|
||||
|
||||
Ingestion assets run on cron schedules managed by the Dagster daemon. Downstream dbt models use
|
||||
@@ -65,11 +86,13 @@ assets are still materialising.
|
||||
data_platform/ # Dagster Python package
|
||||
assets/
|
||||
dbt.py # @dbt_assets definition
|
||||
elo/ # ELO schema/table management assets
|
||||
ingestion/ # Raw ingestion assets + SQL templates
|
||||
ml/ # ML assets (training, inference, alerts)
|
||||
helpers/ # Shared utilities (SQL rendering, formatting, automation)
|
||||
jobs/ # Job definitions
|
||||
schedules/ # Schedule definitions
|
||||
resources/ # Dagster resources (API clients, Postgres)
|
||||
resources/ # Dagster resources (Postgres, MLflow, Discord, Funda)
|
||||
definitions.py # Main Definitions entry point
|
||||
dbt/ # dbt project
|
||||
models/
|
||||
@@ -134,5 +157,6 @@ make reload-code # Rebuild + restart user-code container
|
||||
| Service | URL |
|
||||
| ----------------- | --------------------- |
|
||||
| Dagster UI | http://localhost:3000 |
|
||||
| MLflow UI | http://localhost:5000 |
|
||||
| pgAdmin | http://localhost:5050 |
|
||||
| Elementary report | http://localhost:8080 |
|
||||
|
||||
@@ -26,6 +26,7 @@ def _ensure_schema(conn: object) -> None:
|
||||
@asset(
|
||||
deps=["elo_sample_listings"],
|
||||
group_name="elo",
|
||||
kinds={"python", "postgres"},
|
||||
description="Creates the ELO ratings table that stores per-listing ELO scores.",
|
||||
)
|
||||
def elo_ratings(
|
||||
@@ -53,6 +54,7 @@ def elo_ratings(
|
||||
@asset(
|
||||
deps=["elo_sample_listings"],
|
||||
group_name="elo",
|
||||
kinds={"python", "postgres"},
|
||||
description="Creates the ELO comparisons table that records pairwise match results.",
|
||||
)
|
||||
def elo_comparisons(
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import requests
|
||||
from dagster import (
|
||||
AssetExecutionContext,
|
||||
@@ -51,6 +52,7 @@ def _build_embed(row) -> dict:
|
||||
deps=["elo_inference"],
|
||||
group_name="ml",
|
||||
kinds={"python", "discord"},
|
||||
tags={"manual": "true"},
|
||||
description=(
|
||||
"Send a Discord notification for newly scored listings whose "
|
||||
"predicted ELO exceeds a configurable threshold."
|
||||
@@ -69,7 +71,7 @@ def listing_alert(
|
||||
conn.execute(text(render_sql(_SQL_DIR, "ensure_notified_table.sql")))
|
||||
|
||||
query = render_sql(_SQL_DIR, "select_top_predictions.sql")
|
||||
df = __import__("pandas").read_sql(
|
||||
df = pd.read_sql(
|
||||
text(query),
|
||||
engine,
|
||||
params={"min_elo": config.min_elo},
|
||||
|
||||
@@ -16,7 +16,7 @@ from sqlalchemy import text
|
||||
|
||||
from data_platform.assets.ml.elo_model import (
|
||||
ALL_FEATURES,
|
||||
_preprocess,
|
||||
preprocess,
|
||||
)
|
||||
from data_platform.helpers import render_sql
|
||||
from data_platform.resources import MLflowResource, PostgresResource
|
||||
@@ -104,7 +104,7 @@ def elo_inference(
|
||||
model = mlflow.lightgbm.load_model(model_uri)
|
||||
|
||||
# Preprocess features identically to training
|
||||
df = _preprocess(df)
|
||||
df = preprocess(df)
|
||||
X = df[ALL_FEATURES].copy()
|
||||
|
||||
# Predict normalised ELO and convert back to original scale
|
||||
|
||||
@@ -84,7 +84,7 @@ class EloModelConfig(Config):
|
||||
mlflow_experiment: str = "elo-rating-prediction"
|
||||
|
||||
|
||||
def _preprocess(df: pd.DataFrame) -> pd.DataFrame:
|
||||
def preprocess(df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Convert raw columns to model-ready numeric features."""
|
||||
df["energy_label_num"] = (
|
||||
df["energy_label"]
|
||||
@@ -139,7 +139,7 @@ def elo_prediction_model(
|
||||
)
|
||||
|
||||
# Preprocess and normalise ELO target
|
||||
df = _preprocess(df)
|
||||
df = preprocess(df)
|
||||
df["elo_norm"] = (df["elo_rating"] - 1500) / 100
|
||||
|
||||
X = df[ALL_FEATURES].copy()
|
||||
|
||||
64
tests/test_helpers_automation.py
Normal file
64
tests/test_helpers_automation.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Tests for data_platform.helpers.automation — apply_automation."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from dagster import AssetSpec
|
||||
|
||||
from data_platform.helpers.automation import (
|
||||
AUTOMATION_CONDITION,
|
||||
_is_manual,
|
||||
apply_automation,
|
||||
)
|
||||
|
||||
|
||||
def _make_asset_def(tags: dict | None = None):
|
||||
"""Create a minimal mock AssetsDefinition with the given tags."""
|
||||
spec = MagicMock(spec=AssetSpec)
|
||||
spec.tags = tags or {}
|
||||
|
||||
asset_def = MagicMock()
|
||||
asset_def.specs = [spec]
|
||||
|
||||
updated = MagicMock()
|
||||
asset_def.with_attributes.return_value = updated
|
||||
|
||||
return asset_def, updated
|
||||
|
||||
|
||||
class TestIsManual:
|
||||
def test_manual_when_tagged(self):
|
||||
asset_def, _ = _make_asset_def(tags={"manual": "true"})
|
||||
assert _is_manual(asset_def) is True
|
||||
|
||||
def test_not_manual_when_no_tags(self):
|
||||
asset_def, _ = _make_asset_def(tags={})
|
||||
assert _is_manual(asset_def) is False
|
||||
|
||||
def test_not_manual_when_other_tags(self):
|
||||
asset_def, _ = _make_asset_def(tags={"owner": "team-data"})
|
||||
assert _is_manual(asset_def) is False
|
||||
|
||||
|
||||
class TestApplyAutomation:
|
||||
def test_manual_asset_unchanged(self):
|
||||
asset_def, _ = _make_asset_def(tags={"manual": "true"})
|
||||
result = apply_automation([asset_def])
|
||||
assert result == [asset_def]
|
||||
asset_def.with_attributes.assert_not_called()
|
||||
|
||||
def test_non_manual_asset_gets_condition(self):
|
||||
asset_def, updated = _make_asset_def(tags={})
|
||||
result = apply_automation([asset_def])
|
||||
assert result == [updated]
|
||||
asset_def.with_attributes.assert_called_once_with(
|
||||
automation_condition=AUTOMATION_CONDITION
|
||||
)
|
||||
|
||||
def test_empty_list(self):
|
||||
assert apply_automation([]) == []
|
||||
|
||||
def test_mixed_assets(self):
|
||||
manual, _ = _make_asset_def(tags={"manual": "true"})
|
||||
auto, auto_updated = _make_asset_def(tags={})
|
||||
result = apply_automation([manual, auto])
|
||||
assert result == [manual, auto_updated]
|
||||
35
tests/test_helpers_sql.py
Normal file
35
tests/test_helpers_sql.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Tests for data_platform.helpers.sql — render_sql."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from jinja2 import TemplateNotFound
|
||||
|
||||
from data_platform.helpers import render_sql
|
||||
|
||||
|
||||
_FIXTURES = Path(__file__).parent / "fixtures" / "sql"
|
||||
|
||||
|
||||
class TestRenderSql:
|
||||
def test_renders_plain_sql(self, tmp_path):
|
||||
sql_file = tmp_path / "plain.sql"
|
||||
sql_file.write_text("select 1")
|
||||
result = render_sql(tmp_path, "plain.sql")
|
||||
assert result == "select 1"
|
||||
|
||||
def test_renders_with_variables(self, tmp_path):
|
||||
sql_file = tmp_path / "schema.sql"
|
||||
sql_file.write_text("create schema if not exists {{ schema }}")
|
||||
result = render_sql(tmp_path, "schema.sql", schema="elo")
|
||||
assert result == "create schema if not exists elo"
|
||||
|
||||
def test_missing_template_raises(self, tmp_path):
|
||||
with pytest.raises(TemplateNotFound):
|
||||
render_sql(tmp_path, "nonexistent.sql")
|
||||
|
||||
def test_multiple_variables(self, tmp_path):
|
||||
sql_file = tmp_path / "multi.sql"
|
||||
sql_file.write_text("select * from {{ schema }}.{{ table }}")
|
||||
result = render_sql(tmp_path, "multi.sql", schema="raw", table="events")
|
||||
assert result == "select * from raw.events"
|
||||
203
tests/test_ml.py
Normal file
203
tests/test_ml.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""Tests for ML helper functions — preprocess, _best_run, _build_embed."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from data_platform.assets.ml.elo_model import (
|
||||
ALL_FEATURES,
|
||||
ENERGY_LABEL_MAP,
|
||||
preprocess,
|
||||
)
|
||||
from data_platform.assets.ml.discord_alerts import _build_embed
|
||||
from data_platform.assets.ml.elo_inference import _best_run
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# preprocess
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPreprocess:
|
||||
def _make_df(self, overrides: dict | None = None) -> pd.DataFrame:
|
||||
"""Build a single-row DataFrame with sensible defaults."""
|
||||
data = {
|
||||
"energy_label": "A",
|
||||
"current_price": 350_000,
|
||||
"living_area": 80,
|
||||
"plot_area": 120,
|
||||
"bedrooms": 3,
|
||||
"rooms": 5,
|
||||
"construction_year": 2000,
|
||||
"latitude": 52.0,
|
||||
"longitude": 4.5,
|
||||
"photo_count": 10,
|
||||
"views": 100,
|
||||
"saves": 20,
|
||||
"price_per_sqm": 4375,
|
||||
"has_garden": True,
|
||||
"has_balcony": False,
|
||||
"has_solar_panels": None,
|
||||
"has_heat_pump": False,
|
||||
"has_roof_terrace": None,
|
||||
"is_energy_efficient": True,
|
||||
"is_monument": False,
|
||||
}
|
||||
if overrides:
|
||||
data.update(overrides)
|
||||
return pd.DataFrame([data])
|
||||
|
||||
def test_energy_label_mapped(self):
|
||||
df = preprocess(self._make_df({"energy_label": "A"}))
|
||||
assert df["energy_label_num"].iloc[0] == ENERGY_LABEL_MAP["A"]
|
||||
|
||||
def test_energy_label_g(self):
|
||||
df = preprocess(self._make_df({"energy_label": "G"}))
|
||||
assert df["energy_label_num"].iloc[0] == -1
|
||||
|
||||
def test_energy_label_unknown(self):
|
||||
df = preprocess(self._make_df({"energy_label": "Z"}))
|
||||
assert df["energy_label_num"].iloc[0] == -2
|
||||
|
||||
def test_energy_label_case_insensitive(self):
|
||||
df = preprocess(self._make_df({"energy_label": " a "}))
|
||||
assert df["energy_label_num"].iloc[0] == ENERGY_LABEL_MAP["A"]
|
||||
|
||||
def test_bool_none_becomes_zero(self):
|
||||
df = preprocess(self._make_df({"has_garden": None}))
|
||||
assert df["has_garden"].iloc[0] == 0
|
||||
|
||||
def test_bool_true_becomes_one(self):
|
||||
df = preprocess(self._make_df({"has_garden": True}))
|
||||
assert df["has_garden"].iloc[0] == 1
|
||||
|
||||
def test_bool_false_becomes_zero(self):
|
||||
df = preprocess(self._make_df({"has_garden": False}))
|
||||
assert df["has_garden"].iloc[0] == 0
|
||||
|
||||
def test_numeric_string_coerced(self):
|
||||
df = preprocess(self._make_df({"current_price": "350000"}))
|
||||
assert df["current_price"].iloc[0] == 350_000.0
|
||||
|
||||
def test_numeric_nan_filled_with_median(self):
|
||||
data = self._make_df()
|
||||
row2 = self._make_df({"current_price": None})
|
||||
df = pd.concat([data, row2], ignore_index=True)
|
||||
df = preprocess(df)
|
||||
# median of [350000] is 350000 (single non-null value)
|
||||
assert df["current_price"].iloc[1] == 350_000.0
|
||||
|
||||
def test_all_features_present(self):
|
||||
df = preprocess(self._make_df())
|
||||
for feat in ALL_FEATURES:
|
||||
assert feat in df.columns, f"Missing feature: {feat}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _best_run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBestRun:
|
||||
@patch("data_platform.assets.ml.elo_inference.mlflow")
|
||||
def test_no_experiment_raises(self, mock_mlflow):
|
||||
mock_mlflow.tracking.MlflowClient.return_value.get_experiment_by_name.return_value = None
|
||||
with pytest.raises(ValueError, match="does not exist"):
|
||||
_best_run("nonexistent", "rmse", ascending=True)
|
||||
|
||||
@patch("data_platform.assets.ml.elo_inference.mlflow")
|
||||
def test_no_runs_raises(self, mock_mlflow):
|
||||
client = mock_mlflow.tracking.MlflowClient.return_value
|
||||
client.get_experiment_by_name.return_value = MagicMock(experiment_id="1")
|
||||
client.search_runs.return_value = []
|
||||
with pytest.raises(ValueError, match="No runs found"):
|
||||
_best_run("experiment", "rmse", ascending=True)
|
||||
|
||||
@patch("data_platform.assets.ml.elo_inference.mlflow")
|
||||
def test_ascending_order(self, mock_mlflow):
|
||||
client = mock_mlflow.tracking.MlflowClient.return_value
|
||||
client.get_experiment_by_name.return_value = MagicMock(experiment_id="1")
|
||||
run = MagicMock()
|
||||
client.search_runs.return_value = [run]
|
||||
|
||||
result = _best_run("experiment", "rmse", ascending=True)
|
||||
assert result is run
|
||||
client.search_runs.assert_called_once_with(
|
||||
experiment_ids=["1"],
|
||||
order_by=["metrics.rmse ASC"],
|
||||
max_results=1,
|
||||
)
|
||||
|
||||
@patch("data_platform.assets.ml.elo_inference.mlflow")
|
||||
def test_descending_order(self, mock_mlflow):
|
||||
client = mock_mlflow.tracking.MlflowClient.return_value
|
||||
client.get_experiment_by_name.return_value = MagicMock(experiment_id="1")
|
||||
run = MagicMock()
|
||||
client.search_runs.return_value = [run]
|
||||
|
||||
result = _best_run("experiment", "r2", ascending=False)
|
||||
assert result is run
|
||||
client.search_runs.assert_called_once_with(
|
||||
experiment_ids=["1"],
|
||||
order_by=["metrics.r2 DESC"],
|
||||
max_results=1,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_embed
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildEmbed:
|
||||
def _make_row(self, **overrides):
|
||||
data = {
|
||||
"predicted_elo": 1650.0,
|
||||
"current_price": 350_000,
|
||||
"city": "Amsterdam",
|
||||
"living_area": 80,
|
||||
"rooms": 4,
|
||||
"energy_label": "A",
|
||||
"price_per_sqm": 4375,
|
||||
"title": "Teststraat 1",
|
||||
"global_id": "abc123",
|
||||
"url": "https://funda.nl/abc123",
|
||||
}
|
||||
data.update(overrides)
|
||||
return SimpleNamespace(**data)
|
||||
|
||||
def test_all_fields_present(self):
|
||||
embed = _build_embed(self._make_row())
|
||||
field_names = [f["name"] for f in embed["fields"]]
|
||||
assert "Predicted ELO" in field_names
|
||||
assert "Price" in field_names
|
||||
assert "City" in field_names
|
||||
assert "Living area" in field_names
|
||||
assert "Rooms" in field_names
|
||||
assert "Energy label" in field_names
|
||||
assert "€/m²" in field_names
|
||||
|
||||
def test_no_price_per_sqm_omits_field(self):
|
||||
embed = _build_embed(self._make_row(price_per_sqm=None))
|
||||
field_names = [f["name"] for f in embed["fields"]]
|
||||
assert "€/m²" not in field_names
|
||||
assert len(embed["fields"]) == 6
|
||||
|
||||
def test_missing_city_shows_dash(self):
|
||||
embed = _build_embed(self._make_row(city=None))
|
||||
city_field = next(f for f in embed["fields"] if f["name"] == "City")
|
||||
assert city_field["value"] == "–"
|
||||
|
||||
def test_title_fallback_to_global_id(self):
|
||||
embed = _build_embed(self._make_row(title=None))
|
||||
assert embed["title"] == "abc123"
|
||||
|
||||
def test_url_set(self):
|
||||
embed = _build_embed(self._make_row())
|
||||
assert embed["url"] == "https://funda.nl/abc123"
|
||||
|
||||
def test_color_is_green(self):
|
||||
embed = _build_embed(self._make_row())
|
||||
assert embed["color"] == 0x00B894
|
||||
@@ -2,7 +2,12 @@
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from data_platform.resources import FundaResource, PostgresResource
|
||||
from data_platform.resources import (
|
||||
DiscordResource,
|
||||
FundaResource,
|
||||
MLflowResource,
|
||||
PostgresResource,
|
||||
)
|
||||
|
||||
|
||||
class TestFundaResource:
|
||||
@@ -79,3 +84,15 @@ class TestPostgresResource:
|
||||
res.execute_many("INSERT INTO t VALUES (:id)", rows)
|
||||
|
||||
mock_conn.execute.assert_called_once()
|
||||
|
||||
|
||||
class TestMLflowResource:
|
||||
def test_tracking_uri(self):
|
||||
resource = MLflowResource(tracking_uri="http://mlflow:5000")
|
||||
assert resource.get_tracking_uri() == "http://mlflow:5000"
|
||||
|
||||
|
||||
class TestDiscordResource:
|
||||
def test_webhook_url(self):
|
||||
resource = DiscordResource(webhook_url="https://discord.com/api/webhooks/test")
|
||||
assert resource.get_webhook_url() == "https://discord.com/api/webhooks/test"
|
||||
|
||||
Reference in New Issue
Block a user