179 lines
6.0 KiB
Python
179 lines
6.0 KiB
Python
"""Comparison endpoints – matchmaking and ELO updates."""
|
||
|
||
import random

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import text
from sqlalchemy.orm import Session

from app.config import load_sql, settings
from app.database import get_db
from app.elo import calculate_elo
from app.models import Comparison, EloRating
from app.queries import LISTING_SELECT, row_to_listing
from app.schemas import (
    CompareRequest,
    CompareResponse,
    ComparisonHistoryItem,
    MatchupResponse,
    StatsResponse,
)
|
||
|
||
router = APIRouter()

# SQL fragment appended to LISTING_SELECT: an inner join that keeps only
# listings (aliased ``l``) whose global_id also appears in
# ``<ELO_SCHEMA>.sample_listings``. ELO_SCHEMA is interpolated because it comes
# from application settings, not from request data.
SAMPLE_JOIN = (
    f" inner join {settings.ELO_SCHEMA}.sample_listings as s"
    " on l.global_id = s.global_id"
)
|
||
|
||
|
||
# Exponent of the inverse comparison-count weighting used for matchmaking;
# larger values favour rarely compared listings more strongly.
_MATCHUP_WEIGHT_EXPONENT = 1.5


def _matchup_weight(listing) -> float:
    """Return the sampling weight of *listing* (fewer comparisons => larger weight)."""
    return 1.0 / (listing.comparison_count + 1) ** _MATCHUP_WEIGHT_EXPONENT


@router.get("/matchup", response_model=MatchupResponse)
def get_matchup(
    status: str | None = None,
    db: Session = Depends(get_db),
):
    """Return a weighted-random pair of listings for comparison.

    Each listing is sampled with weight
    ``1 / (comparison_count + 1) ** _MATCHUP_WEIGHT_EXPONENT`` so rarely
    compared listings surface more often. The second listing is drawn (with
    the same weighting) from opponents that do not form a pair listed by
    ``recent_pairs.sql`` with the first; only if every opponent is part of
    such a pair is it drawn from all remaining listings.

    Args:
        status: Optional ``l.status`` filter; ``None`` or ``"all"`` disables it.
        db: Database session (injected).

    Raises:
        HTTPException: 400 if fewer than two distinct listings are available.
    """
    query = LISTING_SELECT + SAMPLE_JOIN
    params: dict = {"default_elo": settings.DEFAULT_ELO}
    if status and status != "all":
        # NOTE(review): assumes LISTING_SELECT + SAMPLE_JOIN has no WHERE /
        # GROUP BY clause of its own — confirm in app.queries.
        query += " where l.status = :status"
        params["status"] = status
    listings = [row_to_listing(row) for row in db.execute(text(query), params)]

    # Count distinct ids: duplicate rows for one listing must not be mistaken
    # for two candidates (that previously crashed with IndexError below).
    if len({listing.global_id for listing in listings}) < 2:
        raise HTTPException(
            status_code=400,
            detail="Not enough listings for comparison (need at least 2).",
        )

    recent_pairs = {
        frozenset((r.listing_a_id, r.listing_b_id))
        for r in db.execute(text(load_sql("recent_pairs.sql")))
    }

    first = random.choices(
        listings, weights=[_matchup_weight(listing) for listing in listings]
    )[0]

    remaining = [
        listing for listing in listings if listing.global_id != first.global_id
    ]
    # Weighted choice restricted to "fresh" opponents has the same distribution
    # as rejection sampling, but never gives up while a fresh opponent exists.
    fresh = [
        listing
        for listing in remaining
        if frozenset((first.global_id, listing.global_id)) not in recent_pairs
    ]
    pool = fresh or remaining
    second = random.choices(
        pool, weights=[_matchup_weight(listing) for listing in pool]
    )[0]

    return MatchupResponse(listing_a=first, listing_b=second)
|
||
|
||
|
||
def _get_or_create_rating(db: Session, global_id: str) -> EloRating:
    """Load the EloRating for *global_id*, inserting one at DEFAULT_ELO if missing."""
    existing = db.query(EloRating).filter_by(global_id=global_id).first()
    if existing:
        return existing
    created = EloRating(global_id=global_id, elo_rating=settings.DEFAULT_ELO)
    db.add(created)
    # Push the INSERT to the database inside the current transaction; presumably
    # so model-side defaults (the counters incremented by the caller) are
    # populated before use — confirm against app.models.
    db.flush()
    return created


@router.post("/compare", response_model=CompareResponse)
def submit_comparison(body: CompareRequest, db: Session = Depends(get_db)):
    """Record a comparison result and update ELO ratings.

    Creates rating rows at ``settings.DEFAULT_ELO`` for listings that lack one,
    applies ``calculate_elo`` with ``settings.K_FACTOR``, bumps comparison
    counts and win/loss tallies, stores a ``Comparison`` row (winner in slot A,
    loser in slot B), then commits.

    Raises:
        HTTPException: 400 when ``winner_id`` equals ``loser_id``.
    """
    winner_id, loser_id = body.winner_id, body.loser_id
    if winner_id == loser_id:
        raise HTTPException(status_code=400, detail="Winner and loser must differ.")

    winner = _get_or_create_rating(db, winner_id)
    loser = _get_or_create_rating(db, loser_id)

    before_w, before_l = winner.elo_rating, loser.elo_rating
    after_w, after_l = calculate_elo(before_w, before_l, settings.K_FACTOR)

    winner.elo_rating = after_w
    winner.comparison_count += 1
    winner.wins += 1

    loser.elo_rating = after_l
    loser.comparison_count += 1
    loser.losses += 1

    record = Comparison(
        listing_a_id=winner_id,
        listing_b_id=loser_id,
        winner_id=winner_id,
        elo_a_before=before_w,
        elo_b_before=before_l,
        elo_a_after=after_w,
        elo_b_after=after_l,
    )
    db.add(record)
    db.commit()

    return CompareResponse(
        winner_id=winner_id,
        loser_id=loser_id,
        elo_change=round(after_w - before_w, 1),
        new_winner_elo=round(after_w, 1),
        new_loser_elo=round(after_l, 1),
    )
|
||
|
||
|
||
def _row_to_history(r) -> ComparisonHistoryItem:
    """Convert a ``history.sql`` result row into a ComparisonHistoryItem.

    Titles, ids and the timestamp are copied through unchanged; the four
    before/after ELO values are rounded to one decimal place.
    """
    rounded_elos = {
        field: round(getattr(r, field), 1)
        for field in ("elo_a_before", "elo_b_before", "elo_a_after", "elo_b_after")
    }
    return ComparisonHistoryItem(
        id=r.id,
        listing_a_title=r.listing_a_title,
        listing_b_title=r.listing_b_title,
        winner_title=r.winner_title,
        listing_a_id=r.listing_a_id,
        listing_b_id=r.listing_b_id,
        winner_id=r.winner_id,
        created_at=r.created_at,
        **rounded_elos,
    )
|
||
|
||
|
||
# Largest page size accepted by /history, so one request cannot pull an
# unbounded number of rows.
_MAX_HISTORY_LIMIT = 1000


@router.get("/history", response_model=list[ComparisonHistoryItem])
def get_history(
    limit: int = Query(50, ge=0, le=_MAX_HISTORY_LIMIT),
    db: Session = Depends(get_db),
):
    """Return recent comparisons.

    Args:
        limit: Maximum number of rows to return (default 50). Values outside
            ``0.._MAX_HISTORY_LIMIT`` are rejected by FastAPI with a 422
            before reaching SQL, where a negative LIMIT is either an error or
            "no limit" depending on the database.
        db: Database session (injected).
    """
    rows = db.execute(text(load_sql("history.sql")), {"limit": limit})
    return [_row_to_history(r) for r in rows]
|
||
|
||
|
||
# Width of one ELO histogram bucket, used to label buckets as "start-end".
# NOTE(review): must match the bucketing in elo_distribution.sql (not visible
# here) — confirm.
_ELO_BUCKET_WIDTH = 50


def _round_or_none(value) -> float | None:
    """Return *value* as a float rounded to 1 decimal, or None for SQL NULL."""
    return None if value is None else round(float(value), 1)


@router.get("/stats", response_model=StatsResponse)
def get_stats(db: Session = Depends(get_db)):
    """Return aggregate statistics about comparisons and ratings.

    Includes comparison / rated-listing / listing counts (0 when the query
    yields NULL), average/max/min ELO (rounded to 1 decimal, or None when
    unavailable), a bucketed ELO histogram and the 10 most recent comparisons.
    """
    total_comparisons = db.execute(text(load_sql("count_comparisons.sql"))).scalar() or 0
    total_rated = db.execute(text(load_sql("count_rated.sql"))).scalar() or 0
    total_listings = db.execute(text(load_sql("count_listings.sql"))).scalar() or 0

    # Columns are positional: avg, max, min. Compare against None rather than
    # truthiness so a legitimate aggregate of 0 is not reported as missing.
    elo_agg = db.execute(text(load_sql("elo_aggregates.sql"))).first()
    if elo_agg is None:
        avg_elo = max_elo = min_elo = None
    else:
        avg_elo = _round_or_none(elo_agg[0])
        max_elo = _round_or_none(elo_agg[1])
        min_elo = _round_or_none(elo_agg[2])

    dist_rows = db.execute(text(load_sql("elo_distribution.sql")))
    elo_distribution = [
        {
            "bucket": f"{int(r.bucket)}-{int(r.bucket) + _ELO_BUCKET_WIDTH - 1}",
            # Row is a Sequence, so ``r.count`` resolves to the count() method,
            # not the column; read the column through the mapping view.
            "count": r._mapping["count"],
        }
        for r in dist_rows
    ]

    recent_rows = db.execute(text(load_sql("history.sql")), {"limit": 10})
    recent_comparisons = [_row_to_history(r) for r in recent_rows]

    return StatsResponse(
        total_comparisons=total_comparisons,
        total_rated_listings=total_rated,
        total_listings=total_listings,
        avg_elo=avg_elo,
        max_elo=max_elo,
        min_elo=min_elo,
        elo_distribution=elo_distribution,
        recent_comparisons=recent_comparisons,
    )
|