feat: add inference for elo on new listings
This commit is contained in:
@@ -14,3 +14,6 @@ DBT_TARGET=dev
|
||||
# pgAdmin
|
||||
PGADMIN_EMAIL=admin@example.com
|
||||
PGADMIN_PASSWORD=changeme
|
||||
|
||||
# Discord webhook for ELO alerts
|
||||
DISCORD_WEBHOOK_URL=https://discord.com/api/webhooks/your/webhook/url
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Machine-learning assets."""
|
||||
|
||||
from data_platform.assets.ml.discord_alerts import listing_alert
|
||||
from data_platform.assets.ml.elo_inference import elo_inference
|
||||
from data_platform.assets.ml.elo_model import elo_prediction_model
|
||||
|
||||
__all__ = ["elo_prediction_model"]
|
||||
__all__ = ["elo_inference", "elo_prediction_model", "listing_alert"]
|
||||
|
||||
121
data_platform/assets/ml/discord_alerts.py
Normal file
121
data_platform/assets/ml/discord_alerts.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""Discord notification asset for high-ELO listings."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
from dagster import (
|
||||
AssetExecutionContext,
|
||||
Config,
|
||||
MaterializeResult,
|
||||
MetadataValue,
|
||||
asset,
|
||||
)
|
||||
from sqlalchemy import text
|
||||
|
||||
from data_platform.helpers import format_euro, format_area, render_sql
|
||||
from data_platform.resources import DiscordResource, PostgresResource
|
||||
|
||||
_SQL_DIR = Path(__file__).parent / "sql"
|
||||
|
||||
|
||||
class DiscordNotificationConfig(Config):
    """Configuration for Discord ELO notifications.

    Fields are exposed as Dagster run config for the ``listing_alert`` asset.
    """

    # Minimum predicted ELO a listing must reach before an alert is sent.
    min_elo: float = 1600
|
||||
|
||||
|
||||
def _build_embed(row) -> dict:
    """Build a Discord embed for a single listing.

    ``row`` is a record exposing listing attributes (e.g. a pandas
    ``itertuples`` namedtuple): predicted_elo, current_price, city, etc.
    """
    embed_fields = [
        {"name": "Predicted ELO", "value": f"{row.predicted_elo:.0f}", "inline": True},
        {"name": "Price", "value": format_euro(row.current_price), "inline": True},
        {"name": "City", "value": row.city or "–", "inline": True},
        {"name": "Living area", "value": format_area(row.living_area), "inline": True},
        {"name": "Rooms", "value": str(row.rooms or "–"), "inline": True},
        {"name": "Energy label", "value": row.energy_label or "–", "inline": True},
    ]

    # Price per square metre is optional in the mart; append only when present.
    if row.price_per_sqm:
        extra = {"name": "€/m²", "value": format_euro(row.price_per_sqm), "inline": True}
        embed_fields.append(extra)

    embed = {
        "title": row.title or row.global_id,
        "url": row.url,
        "color": 0x00B894,  # green
        "fields": embed_fields,
    }
    return embed
|
||||
|
||||
|
||||
@asset(
    deps=["elo_inference"],
    group_name="ml",
    kinds={"python", "discord"},
    description=(
        "Send a Discord notification for newly scored listings whose "
        "predicted ELO exceeds a configurable threshold."
    ),
)
def listing_alert(
    context: AssetExecutionContext,
    config: DiscordNotificationConfig,
    postgres: PostgresResource,
    discord: DiscordResource,
) -> MaterializeResult:
    """Post Discord alerts for not-yet-notified listings above ``config.min_elo``.

    Reads top predictions from Postgres, posts them to the configured
    webhook in batches of 10 embeds (Discord's per-message limit), and
    records each alerted listing in ``elo.notified`` so re-runs do not
    notify twice.
    """
    # Fix: the original used `__import__("pandas")` inline; import it properly.
    # Function-scoped so this edit stands alone; ideally move to the top of file.
    import pandas as pd

    engine = postgres.get_engine()

    # Ensure the schema and the de-duplication table exist before querying.
    with engine.begin() as conn:
        conn.execute(text(render_sql(_SQL_DIR, "ensure_elo_schema.sql")))
        conn.execute(text(render_sql(_SQL_DIR, "ensure_notified_table.sql")))

    query = render_sql(_SQL_DIR, "select_top_predictions.sql")
    df = pd.read_sql(
        text(query),
        engine,
        params={"min_elo": config.min_elo},
    )
    context.log.info(f"Found {len(df)} listings above ELO threshold {config.min_elo}.")

    if df.empty:
        return MaterializeResult(
            metadata={
                "notified": 0,
                "status": MetadataValue.text("No listings above threshold."),
            }
        )

    # Send in batches of up to 10 embeds per message (Discord limit)
    webhook_url = discord.get_webhook_url()
    insert_notified = render_sql(_SQL_DIR, "insert_notified.sql")
    batch_size = 10
    sent = 0

    for i in range(0, len(df), batch_size):
        batch = df.iloc[i : i + batch_size]
        embeds = [_build_embed(row) for row in batch.itertuples()]
        payload = {
            "username": "ELO Scout",
            # Only the first message carries the summary line.
            "content": (
                f"**{len(embeds)} listing(s) scored above ELO {config.min_elo:.0f}**"
                if i == 0
                else None
            ),
            "embeds": embeds,
        }
        resp = requests.post(webhook_url, json=payload, timeout=15)
        resp.raise_for_status()

        # Mark this batch as notified immediately (not after the whole loop):
        # if a later batch fails, already-sent listings won't be re-alerted
        # on the next run.
        notified_rows = [{"global_id": gid} for gid in batch["global_id"]]
        with engine.begin() as conn:
            conn.execute(text(insert_notified), notified_rows)

        sent += len(embeds)

    context.log.info(f"Sent {sent} notification(s) to Discord.")

    return MaterializeResult(
        metadata={
            "notified": sent,
            "min_elo_threshold": MetadataValue.float(config.min_elo),
        }
    )
|
||||
138
data_platform/assets/ml/elo_inference.py
Normal file
138
data_platform/assets/ml/elo_inference.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Infer ELO scores for new listings using the best trained model."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import mlflow
|
||||
import pandas as pd
|
||||
from dagster import (
|
||||
AssetExecutionContext,
|
||||
AssetKey,
|
||||
Config,
|
||||
MaterializeResult,
|
||||
MetadataValue,
|
||||
asset,
|
||||
)
|
||||
from sqlalchemy import text
|
||||
|
||||
from data_platform.assets.ml.elo_model import (
|
||||
ALL_FEATURES,
|
||||
_preprocess,
|
||||
)
|
||||
from data_platform.helpers import render_sql
|
||||
from data_platform.resources import MLflowResource, PostgresResource
|
||||
|
||||
_SQL_DIR = Path(__file__).parent / "sql"
|
||||
|
||||
|
||||
class EloInferenceConfig(Config):
    """Configuration for ELO inference."""

    # Name of the MLflow experiment to search for a trained model.
    mlflow_experiment: str = "elo-rating-prediction"
    # Metric used to rank runs when selecting the "best" one.
    metric: str = "rmse"
    # True = lower metric is better (e.g. rmse); False = higher is better.
    ascending: bool = True
|
||||
|
||||
|
||||
def _best_run(experiment_name: str, metric: str, ascending: bool):
    """Return the MLflow run with the best metric value."""
    tracking_client = mlflow.tracking.MlflowClient()

    exp = tracking_client.get_experiment_by_name(experiment_name)
    if exp is None:
        raise ValueError(
            f"MLflow experiment '{experiment_name}' does not exist. "
            "Train the elo_prediction_model asset first."
        )

    # Sort direction follows `ascending`: ASC picks the lowest metric value.
    direction = "ASC" if ascending else "DESC"
    matches = tracking_client.search_runs(
        experiment_ids=[exp.experiment_id],
        order_by=[f"metrics.{metric} {direction}"],
        max_results=1,
    )
    if matches:
        return matches[0]

    raise ValueError(
        f"No runs found in experiment '{experiment_name}'. "
        "Train the elo_prediction_model asset first."
    )
|
||||
|
||||
|
||||
@asset(
    deps=["elo_prediction_model", AssetKey(["marts", "funda_listings"])],
    group_name="ml",
    kinds={"python", "mlflow"},
    tags={"manual": "true"},
    description=(
        "Load the best ELO prediction model from MLflow and infer scores "
        "for all listings that have not been scored yet."
    ),
)
def elo_inference(
    context: AssetExecutionContext,
    config: EloInferenceConfig,
    postgres: PostgresResource,
    mlflow_resource: MLflowResource,
) -> MaterializeResult:
    """Score all unscored listings with the best MLflow model.

    Pipeline: ensure target tables exist, fetch listings without a
    prediction, load the best run's LightGBM model, preprocess with the
    same helpers used in training, predict, and upsert into
    ``elo.predictions``. Returns metadata with count and score stats.
    """
    engine = postgres.get_engine()

    # Ensure target table exists
    with engine.begin() as conn:
        conn.execute(text(render_sql(_SQL_DIR, "ensure_elo_schema.sql")))
        conn.execute(text(render_sql(_SQL_DIR, "ensure_predictions_table.sql")))

    # Fetch unscored listings
    query = render_sql(_SQL_DIR, "select_unscored_listings.sql")
    df = pd.read_sql(text(query), engine)
    context.log.info(f"Found {len(df)} unscored listings.")

    if df.empty:
        return MaterializeResult(
            metadata={
                "scored": 0,
                "status": MetadataValue.text("No new listings to score."),
            }
        )

    # Load best model
    mlflow.set_tracking_uri(mlflow_resource.get_tracking_uri())
    best_run = _best_run(config.mlflow_experiment, config.metric, config.ascending)
    run_id = best_run.info.run_id
    # Artifact path must match the name used when logging the model in training.
    model_uri = f"runs:/{run_id}/elo_lgbm_model"
    context.log.info(
        f"Loading model from run {run_id} "
        f"({config.metric}={best_run.data.metrics.get(config.metric, '?')})."
    )
    model = mlflow.lightgbm.load_model(model_uri)

    # Preprocess features identically to training
    df = _preprocess(df)
    X = df[ALL_FEATURES].copy()

    # Predict normalised ELO and convert back to original scale
    # NOTE(review): assumes training normalised ELO as (elo - 1500) / 100;
    # these constants must mirror elo_model's normalisation — confirm there.
    elo_norm = model.predict(X)
    df["predicted_elo"] = elo_norm * 100 + 1500

    # Write predictions
    rows = [
        {
            "global_id": row.global_id,
            "predicted_elo": float(row.predicted_elo),
            "mlflow_run_id": run_id,
        }
        for row in df.itertuples()
    ]
    # executemany-style upsert: one statement, list of parameter dicts.
    upsert = render_sql(_SQL_DIR, "upsert_prediction.sql")
    with engine.begin() as conn:
        conn.execute(text(upsert), rows)

    context.log.info(f"Wrote {len(rows)} predictions (run {run_id}).")

    return MaterializeResult(
        metadata={
            "scored": len(rows),
            "mlflow_run_id": MetadataValue.text(run_id),
            "predicted_elo_mean": MetadataValue.float(
                float(df["predicted_elo"].mean())
            ),
            "predicted_elo_std": MetadataValue.float(float(df["predicted_elo"].std())),
        }
    )
|
||||
@@ -6,8 +6,10 @@ import mlflow
|
||||
import mlflow.lightgbm
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sqlalchemy import text
|
||||
from dagster import (
|
||||
AssetExecutionContext,
|
||||
AssetKey,
|
||||
Config,
|
||||
MaterializeResult,
|
||||
MetadataValue,
|
||||
@@ -105,7 +107,7 @@ def _preprocess(df: pd.DataFrame) -> pd.DataFrame:
|
||||
|
||||
|
||||
@asset(
|
||||
deps=["elo_ratings", "funda_listings"],
|
||||
deps=["elo_ratings", AssetKey(["marts", "funda_listings"])],
|
||||
group_name="ml",
|
||||
kinds={"python", "mlflow", "lightgbm"},
|
||||
tags={"manual": "true"},
|
||||
@@ -124,7 +126,7 @@ def elo_prediction_model(
|
||||
engine = postgres.get_engine()
|
||||
query = render_sql(_SQL_DIR, "select_training_data.sql")
|
||||
df = pd.read_sql(
|
||||
query,
|
||||
text(query),
|
||||
engine,
|
||||
params={"min_comparisons": config.min_comparisons},
|
||||
)
|
||||
|
||||
1
data_platform/assets/ml/sql/ensure_elo_schema.sql
Normal file
1
data_platform/assets/ml/sql/ensure_elo_schema.sql
Normal file
@@ -0,0 +1 @@
|
||||
-- Namespace for ELO prediction and notification tables.
create schema if not exists elo
|
||||
4
data_platform/assets/ml/sql/ensure_notified_table.sql
Normal file
4
data_platform/assets/ml/sql/ensure_notified_table.sql
Normal file
@@ -0,0 +1,4 @@
|
||||
-- Tracks which listings have already triggered a Discord alert,
-- so repeated materialisations do not notify twice.
create table if not exists elo.notified (
    global_id text primary key,
    notified_at timestamp with time zone default now()
)
|
||||
6
data_platform/assets/ml/sql/ensure_predictions_table.sql
Normal file
6
data_platform/assets/ml/sql/ensure_predictions_table.sql
Normal file
@@ -0,0 +1,6 @@
|
||||
-- One row per listing: latest predicted ELO and the MLflow run that
-- produced it. Upserted by the elo_inference asset.
create table if not exists elo.predictions (
    global_id text primary key,
    predicted_elo double precision not null,
    mlflow_run_id text not null,
    scored_at timestamp with time zone default now()
)
|
||||
3
data_platform/assets/ml/sql/insert_notified.sql
Normal file
3
data_platform/assets/ml/sql/insert_notified.sql
Normal file
@@ -0,0 +1,3 @@
|
||||
-- Record a listing as alerted; idempotent under re-runs.
-- Fix: SQLAlchemy text() bind parameters must be written :name with no
-- space — ": global_id" is not parsed as a parameter.
insert into elo.notified (global_id)
values (:global_id)
on conflict (global_id) do nothing
|
||||
20
data_platform/assets/ml/sql/select_top_predictions.sql
Normal file
20
data_platform/assets/ml/sql/select_top_predictions.sql
Normal file
@@ -0,0 +1,20 @@
|
||||
-- Scored listings above the :min_elo threshold that have not been
-- notified yet, enriched with mart attributes, best first.
-- Fix: bind parameter must be ":min_elo" (no space after the colon) for
-- SQLAlchemy text() to recognise it.
select
    ep.global_id,
    ep.predicted_elo,
    fl.title,
    fl.city,
    fl.url,
    fl.current_price,
    fl.living_area,
    fl.bedrooms,
    fl.rooms,
    fl.energy_label,
    fl.price_per_sqm,
    ep.scored_at
from elo.predictions as ep
inner join marts.funda_listings as fl on ep.global_id = fl.global_id
left join elo.notified as en on ep.global_id = en.global_id
where
    ep.predicted_elo >= :min_elo
    and en.global_id is null
order by ep.predicted_elo desc
|
||||
28
data_platform/assets/ml/sql/select_unscored_listings.sql
Normal file
28
data_platform/assets/ml/sql/select_unscored_listings.sql
Normal file
@@ -0,0 +1,28 @@
|
||||
-- All mart listings that do not yet have a row in elo.predictions.
-- Anti-join expressed with NOT EXISTS (equivalent to the left-join /
-- IS NULL formulation).
select
    fl.global_id,
    fl.url,
    fl.title,
    fl.city,
    fl.current_price,
    fl.living_area,
    fl.plot_area,
    fl.bedrooms,
    fl.rooms,
    fl.construction_year,
    fl.latitude,
    fl.longitude,
    fl.energy_label,
    fl.has_garden,
    fl.has_balcony,
    fl.has_solar_panels,
    fl.has_heat_pump,
    fl.has_roof_terrace,
    fl.is_energy_efficient,
    fl.is_monument,
    fl.photo_count,
    fl.views,
    fl.saves,
    fl.price_per_sqm
from marts.funda_listings as fl
where not exists (
    select 1
    from elo.predictions as ep
    where ep.global_id = fl.global_id
)
|
||||
7
data_platform/assets/ml/sql/upsert_prediction.sql
Normal file
7
data_platform/assets/ml/sql/upsert_prediction.sql
Normal file
@@ -0,0 +1,7 @@
|
||||
-- Insert or refresh a listing's prediction; re-scoring bumps scored_at.
-- Fix: bind parameters must be ":name" with no space after the colon —
-- "(: global_id,: predicted_elo,: mlflow_run_id)" is not valid for
-- SQLAlchemy text() parameter binding.
insert into elo.predictions (global_id, predicted_elo, mlflow_run_id)
values (:global_id, :predicted_elo, :mlflow_run_id)
on conflict (global_id) do update
set
    predicted_elo = excluded.predicted_elo,
    mlflow_run_id = excluded.mlflow_run_id,
    scored_at = now()
|
||||
@@ -12,14 +12,19 @@ from data_platform.assets.ingestion.funda import (
|
||||
raw_funda_price_history,
|
||||
raw_funda_search_results,
|
||||
)
|
||||
from data_platform.assets.ml import elo_prediction_model
|
||||
from data_platform.assets.ml import elo_inference, elo_prediction_model, listing_alert
|
||||
from data_platform.helpers import apply_automation
|
||||
from data_platform.jobs import (
|
||||
elementary_refresh_job,
|
||||
funda_ingestion_job,
|
||||
funda_raw_quality_job,
|
||||
)
|
||||
from data_platform.resources import FundaResource, MLflowResource, PostgresResource
|
||||
from data_platform.resources import (
|
||||
DiscordResource,
|
||||
FundaResource,
|
||||
MLflowResource,
|
||||
PostgresResource,
|
||||
)
|
||||
from data_platform.schedules import (
|
||||
elementary_refresh_schedule,
|
||||
funda_ingestion_schedule,
|
||||
@@ -36,6 +41,8 @@ defs = Definitions(
|
||||
elo_ratings,
|
||||
elo_comparisons,
|
||||
elo_prediction_model,
|
||||
elo_inference,
|
||||
listing_alert,
|
||||
]
|
||||
),
|
||||
jobs=[funda_ingestion_job, funda_raw_quality_job, elementary_refresh_job],
|
||||
@@ -56,5 +63,6 @@ defs = Definitions(
|
||||
"funda": FundaResource(),
|
||||
"postgres": PostgresResource(),
|
||||
"mlflow_resource": MLflowResource(),
|
||||
"discord": DiscordResource(),
|
||||
},
|
||||
)
|
||||
|
||||
@@ -46,3 +46,12 @@ class MLflowResource(ConfigurableResource):
|
||||
|
||||
def get_tracking_uri(self) -> str:
|
||||
return self.tracking_uri
|
||||
|
||||
|
||||
class DiscordResource(ConfigurableResource):
    """Discord webhook resource for sending notifications."""

    # Webhook endpoint; defaults to the DISCORD_WEBHOOK_URL environment
    # variable via Dagster's EnvVar (resolved when the resource is configured).
    webhook_url: str = EnvVar("DISCORD_WEBHOOK_URL")

    def get_webhook_url(self) -> str:
        """Return the configured webhook URL."""
        return self.webhook_url
|
||||
|
||||
@@ -32,4 +32,6 @@ exec mlflow server \
|
||||
--host=0.0.0.0 \
|
||||
--port=5000 \
|
||||
--backend-store-uri="postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@postgres:5432/mlflow" \
|
||||
--default-artifact-root=/mlflow/artifacts
|
||||
--default-artifact-root=/mlflow/artifacts \
|
||||
--allowed-hosts="*" \
|
||||
--cors-allowed-origins="*"
|
||||
|
||||
Reference in New Issue
Block a user