Files
house-elo-ranking/backend/app/routers/comparisons.py
2026-03-06 14:51:26 +00:00

173 lines
6.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Comparison endpoints: matchmaking and ELO updates."""
import random

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import text
from sqlalchemy.orm import Session

from app.config import load_sql, settings
from app.database import get_db
from app.elo import calculate_elo
from app.models import Comparison, EloRating
from app.queries import LISTING_SELECT, row_to_listing
from app.schemas import (
    CompareRequest,
    CompareResponse,
    ComparisonHistoryItem,
    MatchupResponse,
    StatsResponse,
)
router = APIRouter()

# SQL fragment appended to LISTING_SELECT by the matchmaking query: the inner
# join keeps only listings (aliased ``l``) whose global_id also appears in the
# sample_listings table of the configured ELO schema.
SAMPLE_JOIN = (
    f" inner join {settings.ELO_SCHEMA}.sample_listings as s"
    " on l.global_id = s.global_id"
)
def _selection_weight(listing) -> float:
    """Return the matchmaking weight for *listing*.

    The weight is ``1 / (comparison_count + 1) ** 1.5``, so listings that
    have been compared less often are more likely to be drawn.
    """
    return 1.0 / (listing.comparison_count + 1) ** 1.5


def _pick_opponent(first, candidates: list, recent_pairs: set) -> object:
    """Draw a weighted-random opponent for *first* from *candidates*.

    Candidates that would repeat a pair in *recent_pairs* are excluded as
    long as at least one other candidate remains. If every candidate would
    repeat a recent pairing, the whole list is used so that a matchup can
    still be returned.

    Args:
        first: The listing that has already been chosen.
        candidates: A non-empty list of listings, none sharing
            ``first.global_id``.
        recent_pairs: Unordered ``frozenset({id_a, id_b})`` pairs to avoid.

    Returns:
        One element of *candidates*.
    """
    fresh = [
        c
        for c in candidates
        if frozenset((first.global_id, c.global_id)) not in recent_pairs
    ]
    # Filtering before sampling (instead of retrying a fixed number of times)
    # guarantees a non-recent opponent whenever one exists.
    pool = fresh or candidates
    weights = [_selection_weight(c) for c in pool]
    return random.choices(pool, weights=weights, k=1)[0]


@router.get("/matchup", response_model=MatchupResponse)
def get_matchup(
    status: str | None = None,
    db: Session = Depends(get_db),
):
    """Return a weighted-random pair of listings for comparison.

    Only listings matched by the ``SAMPLE_JOIN`` inner join are eligible.
    The first listing is drawn with weights that favour rarely compared
    listings. The second is drawn the same way from the remaining listings,
    avoiding any pair returned by ``recent_pairs.sql`` whenever possible.

    Args:
        status: Optional ``l.status`` filter. ``None``, ``""`` or ``"all"``
            disables filtering.
        db: Database session injected by FastAPI.

    Raises:
        HTTPException: 400 if fewer than two distinct listings are eligible.
    """
    query = LISTING_SELECT + SAMPLE_JOIN
    params: dict = {"default_elo": settings.DEFAULT_ELO}
    if status and status != "all":
        # NOTE(review): assumes LISTING_SELECT has no WHERE clause of its own.
        query += " where l.status = :status"
        params["status"] = status

    listings = [row_to_listing(row) for row in db.execute(text(query), params)]
    # Count distinct ids rather than rows. Duplicate rows for one listing
    # would otherwise leave no opponent after the first pick.
    if len({x.global_id for x in listings}) < 2:
        raise HTTPException(
            status_code=400,
            detail="Not enough listings for comparison (need at least 2).",
        )

    recent = db.execute(text(load_sql("recent_pairs.sql")))
    recent_pairs = {frozenset((r.listing_a_id, r.listing_b_id)) for r in recent}

    first = random.choices(
        listings, weights=[_selection_weight(x) for x in listings], k=1
    )[0]
    remaining = [x for x in listings if x.global_id != first.global_id]
    second = _pick_opponent(first, remaining, recent_pairs)
    return MatchupResponse(listing_a=first, listing_b=second)
def _fetch_or_init_rating(db: Session, global_id: str) -> EloRating:
    """Load the EloRating row for *global_id*.

    If no row exists, insert one at ``settings.DEFAULT_ELO`` and flush it.
    """
    existing = db.query(EloRating).filter_by(global_id=global_id).first()
    if existing:
        return existing
    created = EloRating(global_id=global_id, elo_rating=settings.DEFAULT_ELO)
    db.add(created)
    db.flush()
    return created


@router.post("/compare", response_model=CompareResponse)
def submit_comparison(body: CompareRequest, db: Session = Depends(get_db)):
    """Record a single head-to-head result and apply the ELO update.

    A rating row is created on demand for each listing that lacks one. The
    following counters are incremented:

    - both ``comparison_count`` values;
    - the winner's ``wins``;
    - the loser's ``losses``.

    A ``Comparison`` row stores the before/after ratings, with the winner
    recorded as listing A. Everything is committed in one ``db.commit()``.

    Raises:
        HTTPException: 400 if ``winner_id`` equals ``loser_id``.
    """
    if body.winner_id == body.loser_id:
        raise HTTPException(status_code=400, detail="Winner and loser must differ.")

    winner = _fetch_or_init_rating(db, body.winner_id)
    loser = _fetch_or_init_rating(db, body.loser_id)

    before_w, before_l = winner.elo_rating, loser.elo_rating
    after_w, after_l = calculate_elo(before_w, before_l, settings.K_FACTOR)

    winner.elo_rating = after_w
    winner.comparison_count += 1
    winner.wins += 1
    loser.elo_rating = after_l
    loser.comparison_count += 1
    loser.losses += 1

    record = Comparison(
        listing_a_id=body.winner_id,
        listing_b_id=body.loser_id,
        winner_id=body.winner_id,
        elo_a_before=before_w,
        elo_b_before=before_l,
        elo_a_after=after_w,
        elo_b_after=after_l,
    )
    db.add(record)
    db.commit()

    return CompareResponse(
        winner_id=body.winner_id,
        loser_id=body.loser_id,
        elo_change=round(after_w - before_w, 1),
        new_winner_elo=round(after_w, 1),
        new_loser_elo=round(after_l, 1),
    )
def _row_to_history(r) -> ComparisonHistoryItem:
    """Convert a ``history.sql`` result row into a ``ComparisonHistoryItem``.

    The four ELO snapshot columns are rounded to one decimal place. All other
    columns are copied through unchanged.
    """
    rounded_elos = {
        field: round(getattr(r, field), 1)
        for field in ("elo_a_before", "elo_b_before", "elo_a_after", "elo_b_after")
    }
    return ComparisonHistoryItem(
        id=r.id,
        listing_a_id=r.listing_a_id,
        listing_b_id=r.listing_b_id,
        winner_id=r.winner_id,
        listing_a_title=r.listing_a_title,
        listing_b_title=r.listing_b_title,
        winner_title=r.winner_title,
        created_at=r.created_at,
        **rounded_elos,
    )
# Upper bound on ``/history?limit=`` so a client cannot request an unbounded
# dump of the comparisons table in one call.
_MAX_HISTORY_LIMIT = 1000


@router.get("/history", response_model=list[ComparisonHistoryItem])
def get_history(
    limit: int = Query(50, ge=1, le=_MAX_HISTORY_LIMIT),
    db: Session = Depends(get_db),
):
    """Return recent comparisons, as selected and ordered by ``history.sql``.

    Args:
        limit: Maximum number of rows to return (default 50). FastAPI rejects
            values outside ``1.._MAX_HISTORY_LIMIT`` with a 422 before any SQL
            runs. Previously a negative limit reached the database, and a huge
            one was unbounded.
        db: Database session injected by FastAPI.

    Returns:
        The rows mapped through ``_row_to_history``.
    """
    rows = db.execute(text(load_sql("history.sql")), {"limit": limit})
    return [_row_to_history(r) for r in rows]
def _round_or_none(value) -> float | None:
    """Return *value* as a float rounded to one decimal, or None if it is None.

    Only ``None`` maps to ``None``. A genuine aggregate of 0 (or
    ``Decimal("0")``) is preserved rather than being treated as missing.
    """
    return None if value is None else round(float(value), 1)


@router.get("/stats", response_model=StatsResponse)
def get_stats(db: Session = Depends(get_db)):
    """Return aggregate statistics about comparisons and ratings.

    Runs the following SQL files and packs the results into a
    ``StatsResponse``:

    - the ``count_*.sql`` files;
    - ``elo_aggregates.sql``;
    - ``elo_distribution.sql``;
    - ``history.sql``, limited to 10 rows.

    Missing counts are reported as 0, and missing ELO aggregates as ``None``.
    """

    def fetch_count(sql_file: str) -> int:
        # ``scalar()`` gives None for no row / NULL; report that as 0.
        return db.execute(text(load_sql(sql_file))).scalar() or 0

    total_comparisons = fetch_count("count_comparisons.sql")
    total_rated = fetch_count("count_rated.sql")
    total_listings = fetch_count("count_listings.sql")

    elo_agg = db.execute(text(load_sql("elo_aggregates.sql"))).first()
    # Columns are read positionally as (avg, max, min).
    avg_raw, max_raw, min_raw = (
        (elo_agg[0], elo_agg[1], elo_agg[2]) if elo_agg else (None, None, None)
    )

    elo_distribution = []
    for row in db.execute(text(load_sql("elo_distribution.sql"))):
        # Read columns by key: ``count`` is also a tuple method on ``Row``, and
        # on older SQLAlchemy releases ``row.count`` may return that method
        # instead of the column value.
        cols = row._mapping
        start = int(cols["bucket"])
        # NOTE(review): the "+ 49" label assumes 50-wide buckets in
        # elo_distribution.sql; keep the two in sync.
        elo_distribution.append(
            {"bucket": f"{start}-{start + 49}", "count": cols["count"]}
        )

    recent_rows = db.execute(text(load_sql("history.sql")), {"limit": 10})
    recent_comparisons = [_row_to_history(r) for r in recent_rows]

    return StatsResponse(
        total_comparisons=total_comparisons,
        total_rated_listings=total_rated,
        total_listings=total_listings,
        avg_elo=_round_or_none(avg_raw),
        max_elo=_round_or_none(max_raw),
        min_elo=_round_or_none(min_raw),
        elo_distribution=elo_distribution,
        recent_comparisons=recent_comparisons,
    )