feat: add inference for elo on new listings
This commit is contained in:
@@ -14,3 +14,6 @@ DBT_TARGET=dev
|
|||||||
# pgAdmin
|
# pgAdmin
|
||||||
PGADMIN_EMAIL=admin@example.com
|
PGADMIN_EMAIL=admin@example.com
|
||||||
PGADMIN_PASSWORD=changeme
|
PGADMIN_PASSWORD=changeme
|
||||||
|
|
||||||
|
# Discord webhook for ELO alerts
|
||||||
|
DISCORD_WEBHOOK_URL=https://discord.com/api/webhooks/your/webhook/url
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
"""Machine-learning assets."""
|
"""Machine-learning assets."""
|
||||||
|
|
||||||
|
from data_platform.assets.ml.discord_alerts import listing_alert
|
||||||
|
from data_platform.assets.ml.elo_inference import elo_inference
|
||||||
from data_platform.assets.ml.elo_model import elo_prediction_model
|
from data_platform.assets.ml.elo_model import elo_prediction_model
|
||||||
|
|
||||||
__all__ = ["elo_prediction_model"]
|
__all__ = ["elo_inference", "elo_prediction_model", "listing_alert"]
|
||||||
|
|||||||
121
data_platform/assets/ml/discord_alerts.py
Normal file
121
data_platform/assets/ml/discord_alerts.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
"""Discord notification asset for high-ELO listings."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from dagster import (
|
||||||
|
AssetExecutionContext,
|
||||||
|
Config,
|
||||||
|
MaterializeResult,
|
||||||
|
MetadataValue,
|
||||||
|
asset,
|
||||||
|
)
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from data_platform.helpers import format_euro, format_area, render_sql
|
||||||
|
from data_platform.resources import DiscordResource, PostgresResource
|
||||||
|
|
||||||
|
_SQL_DIR = Path(__file__).parent / "sql"
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordNotificationConfig(Config):
    """Run-time configuration for the Discord ELO alert asset."""

    # Listings must score at least this predicted ELO to trigger an alert.
    min_elo: float = 1600
|
||||||
|
|
||||||
|
|
||||||
|
def _build_embed(row) -> dict:
    """Render a single listing row as a Discord embed payload."""

    def field(name: str, value: str) -> dict:
        # Discord embed field; all fields render inline (side by side).
        return {"name": name, "value": value, "inline": True}

    embed_fields = [
        field("Predicted ELO", f"{row.predicted_elo:.0f}"),
        field("Price", format_euro(row.current_price)),
        field("City", row.city or "–"),
        field("Living area", format_area(row.living_area)),
        field("Rooms", str(row.rooms or "–")),
        field("Energy label", row.energy_label or "–"),
    ]
    # Price per square metre is optional in the source data.
    if row.price_per_sqm:
        embed_fields.append(field("€/m²", format_euro(row.price_per_sqm)))

    return {
        "title": row.title or row.global_id,
        "url": row.url,
        "color": 0x00B894,  # green
        "fields": embed_fields,
    }
|
||||||
|
|
||||||
|
|
||||||
|
@asset(
    deps=["elo_inference"],
    group_name="ml",
    kinds={"python", "discord"},
    description=(
        "Send a Discord notification for newly scored listings whose "
        "predicted ELO exceeds a configurable threshold."
    ),
)
def listing_alert(
    context: AssetExecutionContext,
    config: DiscordNotificationConfig,
    postgres: PostgresResource,
    discord: DiscordResource,
) -> MaterializeResult:
    """Post Discord embeds for not-yet-notified listings above the threshold.

    Reads scored listings (``select_top_predictions.sql`` joins
    ``elo.predictions`` against the marts table and anti-joins
    ``elo.notified``), posts them to the webhook in batches, and records
    each notified listing so it is never alerted twice.
    """
    import pandas as pd  # proper import instead of the __import__ hack

    engine = postgres.get_engine()

    # Ensure the schema and bookkeeping table exist before querying.
    with engine.begin() as conn:
        conn.execute(text(render_sql(_SQL_DIR, "ensure_elo_schema.sql")))
        conn.execute(text(render_sql(_SQL_DIR, "ensure_notified_table.sql")))

    query = render_sql(_SQL_DIR, "select_top_predictions.sql")
    df = pd.read_sql(text(query), engine, params={"min_elo": config.min_elo})
    context.log.info(f"Found {len(df)} listings above ELO threshold {config.min_elo}.")

    if df.empty:
        return MaterializeResult(
            metadata={
                "notified": 0,
                "status": MetadataValue.text("No listings above threshold."),
            }
        )

    # Send in batches of up to 10 embeds per message (Discord limit)
    webhook_url = discord.get_webhook_url()
    batch_size = 10
    sent = 0
    insert_notified = render_sql(_SQL_DIR, "insert_notified.sql")

    for i in range(0, len(df), batch_size):
        batch = df.iloc[i : i + batch_size]
        embeds = [_build_embed(row) for row in batch.itertuples()]
        payload = {
            "username": "ELO Scout",
            # Only the first message carries the summary headline.
            "content": (
                f"**{len(embeds)} listing(s) scored above ELO {config.min_elo:.0f}**"
                if i == 0
                else None
            ),
            "embeds": embeds,
        }
        resp = requests.post(webhook_url, json=payload, timeout=15)
        resp.raise_for_status()
        sent += len(embeds)

        # Mark this batch as notified immediately: if a later batch fails,
        # the rows already posted are recorded and won't be re-sent on the
        # next run (previously nothing was marked until all batches sent,
        # so a partial failure caused duplicate notifications).
        notified_rows = [{"global_id": gid} for gid in batch["global_id"]]
        with engine.begin() as conn:
            conn.execute(text(insert_notified), notified_rows)

    context.log.info(f"Sent {sent} notification(s) to Discord.")

    return MaterializeResult(
        metadata={
            "notified": sent,
            "min_elo_threshold": MetadataValue.float(config.min_elo),
        }
    )
|
||||||
138
data_platform/assets/ml/elo_inference.py
Normal file
138
data_platform/assets/ml/elo_inference.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
"""Infer ELO scores for new listings using the best trained model."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import mlflow
|
||||||
|
import pandas as pd
|
||||||
|
from dagster import (
|
||||||
|
AssetExecutionContext,
|
||||||
|
AssetKey,
|
||||||
|
Config,
|
||||||
|
MaterializeResult,
|
||||||
|
MetadataValue,
|
||||||
|
asset,
|
||||||
|
)
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from data_platform.assets.ml.elo_model import (
|
||||||
|
ALL_FEATURES,
|
||||||
|
_preprocess,
|
||||||
|
)
|
||||||
|
from data_platform.helpers import render_sql
|
||||||
|
from data_platform.resources import MLflowResource, PostgresResource
|
||||||
|
|
||||||
|
_SQL_DIR = Path(__file__).parent / "sql"
|
||||||
|
|
||||||
|
|
||||||
|
class EloInferenceConfig(Config):
    """Run-time configuration for ELO inference."""

    # MLflow experiment holding the trained candidate runs.
    mlflow_experiment: str = "elo-rating-prediction"
    # Metric used to select the best run.
    metric: str = "rmse"
    # True = lower metric is better (e.g. rmse); False = higher is better.
    ascending: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
def _best_run(experiment_name: str, metric: str, ascending: bool):
    """Look up the MLflow run with the best value of *metric*.

    Raises ValueError when the experiment does not exist or contains no
    runs, directing the user to train the model asset first.
    """
    client = mlflow.tracking.MlflowClient()
    experiment = client.get_experiment_by_name(experiment_name)
    if experiment is None:
        raise ValueError(
            f"MLflow experiment '{experiment_name}' does not exist. "
            "Train the elo_prediction_model asset first."
        )

    direction = "ASC" if ascending else "DESC"
    candidates = client.search_runs(
        experiment_ids=[experiment.experiment_id],
        order_by=[f"metrics.{metric} {direction}"],
        max_results=1,
    )
    if candidates:
        return candidates[0]
    raise ValueError(
        f"No runs found in experiment '{experiment_name}'. "
        "Train the elo_prediction_model asset first."
    )
|
||||||
|
|
||||||
|
|
||||||
|
@asset(
    deps=["elo_prediction_model", AssetKey(["marts", "funda_listings"])],
    group_name="ml",
    kinds={"python", "mlflow"},
    tags={"manual": "true"},
    description=(
        "Load the best ELO prediction model from MLflow and infer scores "
        "for all listings that have not been scored yet."
    ),
)
def elo_inference(
    context: AssetExecutionContext,
    config: EloInferenceConfig,
    postgres: PostgresResource,
    mlflow_resource: MLflowResource,
) -> MaterializeResult:
    """Score every not-yet-scored listing with the best trained model.

    Pulls unscored rows from the marts layer, loads the best MLflow run
    per ``config.metric``, predicts a normalised ELO, converts it back to
    the conventional scale, and upserts into ``elo.predictions``.
    """
    engine = postgres.get_engine()

    # Target schema/table must exist before querying or upserting.
    with engine.begin() as conn:
        conn.execute(text(render_sql(_SQL_DIR, "ensure_elo_schema.sql")))
        conn.execute(text(render_sql(_SQL_DIR, "ensure_predictions_table.sql")))

    # Listings present in marts but absent from elo.predictions.
    query = render_sql(_SQL_DIR, "select_unscored_listings.sql")
    df = pd.read_sql(text(query), engine)
    context.log.info(f"Found {len(df)} unscored listings.")

    if df.empty:
        return MaterializeResult(
            metadata={
                "scored": 0,
                "status": MetadataValue.text("No new listings to score."),
            }
        )

    # Pick the best run by the configured metric and load its model.
    mlflow.set_tracking_uri(mlflow_resource.get_tracking_uri())
    best_run = _best_run(config.mlflow_experiment, config.metric, config.ascending)
    run_id = best_run.info.run_id
    model_uri = f"runs:/{run_id}/elo_lgbm_model"
    context.log.info(
        f"Loading model from run {run_id} "
        f"({config.metric}={best_run.data.metrics.get(config.metric, '?')})."
    )
    model = mlflow.lightgbm.load_model(model_uri)

    # Features must be transformed exactly as during training.
    df = _preprocess(df)
    features = df[ALL_FEATURES].copy()

    # The model outputs normalised ELO; map back to the 1500-centred scale.
    # NOTE(review): assumes training normalised as (elo - 1500) / 100 —
    # keep these constants in sync with elo_model.
    df["predicted_elo"] = model.predict(features) * 100 + 1500

    # Upsert one row per listing, tagged with the source MLflow run.
    upsert = render_sql(_SQL_DIR, "upsert_prediction.sql")
    rows = [
        {
            "global_id": record.global_id,
            "predicted_elo": float(record.predicted_elo),
            "mlflow_run_id": run_id,
        }
        for record in df.itertuples()
    ]
    with engine.begin() as conn:
        conn.execute(text(upsert), rows)

    context.log.info(f"Wrote {len(rows)} predictions (run {run_id}).")

    return MaterializeResult(
        metadata={
            "scored": len(rows),
            "mlflow_run_id": MetadataValue.text(run_id),
            "predicted_elo_mean": MetadataValue.float(
                float(df["predicted_elo"].mean())
            ),
            "predicted_elo_std": MetadataValue.float(float(df["predicted_elo"].std())),
        }
    )
|
||||||
@@ -6,8 +6,10 @@ import mlflow
|
|||||||
import mlflow.lightgbm
|
import mlflow.lightgbm
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
from sqlalchemy import text
|
||||||
from dagster import (
|
from dagster import (
|
||||||
AssetExecutionContext,
|
AssetExecutionContext,
|
||||||
|
AssetKey,
|
||||||
Config,
|
Config,
|
||||||
MaterializeResult,
|
MaterializeResult,
|
||||||
MetadataValue,
|
MetadataValue,
|
||||||
@@ -105,7 +107,7 @@ def _preprocess(df: pd.DataFrame) -> pd.DataFrame:
|
|||||||
|
|
||||||
|
|
||||||
@asset(
|
@asset(
|
||||||
deps=["elo_ratings", "funda_listings"],
|
deps=["elo_ratings", AssetKey(["marts", "funda_listings"])],
|
||||||
group_name="ml",
|
group_name="ml",
|
||||||
kinds={"python", "mlflow", "lightgbm"},
|
kinds={"python", "mlflow", "lightgbm"},
|
||||||
tags={"manual": "true"},
|
tags={"manual": "true"},
|
||||||
@@ -124,7 +126,7 @@ def elo_prediction_model(
|
|||||||
engine = postgres.get_engine()
|
engine = postgres.get_engine()
|
||||||
query = render_sql(_SQL_DIR, "select_training_data.sql")
|
query = render_sql(_SQL_DIR, "select_training_data.sql")
|
||||||
df = pd.read_sql(
|
df = pd.read_sql(
|
||||||
query,
|
text(query),
|
||||||
engine,
|
engine,
|
||||||
params={"min_comparisons": config.min_comparisons},
|
params={"min_comparisons": config.min_comparisons},
|
||||||
)
|
)
|
||||||
|
|||||||
1
data_platform/assets/ml/sql/ensure_elo_schema.sql
Normal file
1
data_platform/assets/ml/sql/ensure_elo_schema.sql
Normal file
@@ -0,0 +1 @@
|
|||||||
|
-- Namespace for all ELO prediction/notification tables.
create schema if not exists elo
|
||||||
4
data_platform/assets/ml/sql/ensure_notified_table.sql
Normal file
4
data_platform/assets/ml/sql/ensure_notified_table.sql
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
-- Tracks listings already announced on Discord, so each listing is
-- alerted at most once.
create table if not exists elo.notified (
    global_id text primary key,
    notified_at timestamp with time zone default now()
)
|
||||||
6
data_platform/assets/ml/sql/ensure_predictions_table.sql
Normal file
6
data_platform/assets/ml/sql/ensure_predictions_table.sql
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
-- Stores the latest model score per listing, with MLflow provenance.
create table if not exists elo.predictions (
    global_id text primary key,
    predicted_elo double precision not null,
    mlflow_run_id text not null,
    scored_at timestamp with time zone default now()
)
|
||||||
3
data_platform/assets/ml/sql/insert_notified.sql
Normal file
3
data_platform/assets/ml/sql/insert_notified.sql
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
-- Record a listing as notified; idempotent on re-runs.
-- Fix: ": global_id" (space after colon) is not a valid SQLAlchemy named
-- bind parameter — it must be ":global_id".
insert into elo.notified (global_id)
values (:global_id)
on conflict (global_id) do nothing
|
||||||
20
data_platform/assets/ml/sql/select_top_predictions.sql
Normal file
20
data_platform/assets/ml/sql/select_top_predictions.sql
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
-- Scored listings above the ELO threshold that have not been notified yet.
-- Fix: ">=: min_elo" split the named bind parameter; SQLAlchemy requires
-- ":min_elo" with no space after the colon.
select
    ep.global_id,
    ep.predicted_elo,
    fl.title,
    fl.city,
    fl.url,
    fl.current_price,
    fl.living_area,
    fl.bedrooms,
    fl.rooms,
    fl.energy_label,
    fl.price_per_sqm,
    ep.scored_at
from elo.predictions as ep
inner join marts.funda_listings as fl on ep.global_id = fl.global_id
left join elo.notified as en on ep.global_id = en.global_id
where
    ep.predicted_elo >= :min_elo
    and en.global_id is null
order by ep.predicted_elo desc
|
||||||
28
data_platform/assets/ml/sql/select_unscored_listings.sql
Normal file
28
data_platform/assets/ml/sql/select_unscored_listings.sql
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
-- Feature columns for all marts listings that have no row yet in
-- elo.predictions (anti-join on global_id).
select
    fl.global_id,
    fl.url,
    fl.title,
    fl.city,
    fl.current_price,
    fl.living_area,
    fl.plot_area,
    fl.bedrooms,
    fl.rooms,
    fl.construction_year,
    fl.latitude,
    fl.longitude,
    fl.energy_label,
    fl.has_garden,
    fl.has_balcony,
    fl.has_solar_panels,
    fl.has_heat_pump,
    fl.has_roof_terrace,
    fl.is_energy_efficient,
    fl.is_monument,
    fl.photo_count,
    fl.views,
    fl.saves,
    fl.price_per_sqm
from marts.funda_listings as fl
left join elo.predictions as ep on fl.global_id = ep.global_id
where ep.global_id is null
|
||||||
7
data_platform/assets/ml/sql/upsert_prediction.sql
Normal file
7
data_platform/assets/ml/sql/upsert_prediction.sql
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
-- Insert or refresh a listing's prediction (keyed on global_id).
-- Fix: "(: global_id,: predicted_elo,: mlflow_run_id)" broke all three
-- SQLAlchemy named bind parameters — no space is allowed after the colon.
insert into elo.predictions (global_id, predicted_elo, mlflow_run_id)
values (:global_id, :predicted_elo, :mlflow_run_id)
on conflict (global_id) do update
set
    predicted_elo = excluded.predicted_elo,
    mlflow_run_id = excluded.mlflow_run_id,
    scored_at = now()
|
||||||
@@ -12,14 +12,19 @@ from data_platform.assets.ingestion.funda import (
|
|||||||
raw_funda_price_history,
|
raw_funda_price_history,
|
||||||
raw_funda_search_results,
|
raw_funda_search_results,
|
||||||
)
|
)
|
||||||
from data_platform.assets.ml import elo_prediction_model
|
from data_platform.assets.ml import elo_inference, elo_prediction_model, listing_alert
|
||||||
from data_platform.helpers import apply_automation
|
from data_platform.helpers import apply_automation
|
||||||
from data_platform.jobs import (
|
from data_platform.jobs import (
|
||||||
elementary_refresh_job,
|
elementary_refresh_job,
|
||||||
funda_ingestion_job,
|
funda_ingestion_job,
|
||||||
funda_raw_quality_job,
|
funda_raw_quality_job,
|
||||||
)
|
)
|
||||||
from data_platform.resources import FundaResource, MLflowResource, PostgresResource
|
from data_platform.resources import (
|
||||||
|
DiscordResource,
|
||||||
|
FundaResource,
|
||||||
|
MLflowResource,
|
||||||
|
PostgresResource,
|
||||||
|
)
|
||||||
from data_platform.schedules import (
|
from data_platform.schedules import (
|
||||||
elementary_refresh_schedule,
|
elementary_refresh_schedule,
|
||||||
funda_ingestion_schedule,
|
funda_ingestion_schedule,
|
||||||
@@ -36,6 +41,8 @@ defs = Definitions(
|
|||||||
elo_ratings,
|
elo_ratings,
|
||||||
elo_comparisons,
|
elo_comparisons,
|
||||||
elo_prediction_model,
|
elo_prediction_model,
|
||||||
|
elo_inference,
|
||||||
|
listing_alert,
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
jobs=[funda_ingestion_job, funda_raw_quality_job, elementary_refresh_job],
|
jobs=[funda_ingestion_job, funda_raw_quality_job, elementary_refresh_job],
|
||||||
@@ -56,5 +63,6 @@ defs = Definitions(
|
|||||||
"funda": FundaResource(),
|
"funda": FundaResource(),
|
||||||
"postgres": PostgresResource(),
|
"postgres": PostgresResource(),
|
||||||
"mlflow_resource": MLflowResource(),
|
"mlflow_resource": MLflowResource(),
|
||||||
|
"discord": DiscordResource(),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -46,3 +46,12 @@ class MLflowResource(ConfigurableResource):
|
|||||||
|
|
||||||
def get_tracking_uri(self) -> str:
|
def get_tracking_uri(self) -> str:
|
||||||
return self.tracking_uri
|
return self.tracking_uri
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordResource(ConfigurableResource):
    """Discord webhook resource for sending notifications."""

    # Webhook endpoint, resolved from the environment at launch time.
    webhook_url: str = EnvVar("DISCORD_WEBHOOK_URL")

    def get_webhook_url(self) -> str:
        """Return the configured webhook URL."""
        return self.webhook_url
|
||||||
|
|||||||
@@ -32,4 +32,6 @@ exec mlflow server \
|
|||||||
--host=0.0.0.0 \
|
--host=0.0.0.0 \
|
||||||
--port=5000 \
|
--port=5000 \
|
||||||
--backend-store-uri="postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@postgres:5432/mlflow" \
|
--backend-store-uri="postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@postgres:5432/mlflow" \
|
||||||
--default-artifact-root=/mlflow/artifacts
|
--default-artifact-root=/mlflow/artifacts \
|
||||||
|
--allowed-hosts="*" \
|
||||||
|
--cors-allowed-origins="*"
|
||||||
|
|||||||
Reference in New Issue
Block a user