feat: refactor project

This commit is contained in:
Stijnvandenbroek
2026-03-06 14:51:26 +00:00
parent c908d96921
commit 535a09fd75
28 changed files with 1136 additions and 51 deletions

View File

30
backend/tests/conftest.py Normal file
View File

@@ -0,0 +1,30 @@
"""Shared test fixtures."""
from unittest.mock import MagicMock
import pytest
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session
from app.database import get_db
from app.main import app
@pytest.fixture()
def db_session():
"""Provide a mocked database session."""
session = MagicMock(spec=Session)
return session
@pytest.fixture()
def client(db_session):
"""Provide a FastAPI test client with mocked DB."""
def _override_get_db():
yield db_session
app.dependency_overrides[get_db] = _override_get_db
with TestClient(app) as c:
yield c
app.dependency_overrides.clear()

184
backend/tests/test_api.py Normal file
View File

@@ -0,0 +1,184 @@
"""Tests for API endpoints."""
from datetime import datetime
from unittest.mock import MagicMock
from app.models import EloRating
def _fake_listing(**overrides) -> dict:
"""Return a mock DB row that looks like a joined listing+rating row."""
defaults = dict(
global_id="abc-123",
tiny_id="123",
url="https://example.com/listing",
title="Nice House",
city="Amsterdam",
postcode="1012AB",
province="Noord-Holland",
neighbourhood="Centrum",
municipality="Amsterdam",
latitude=52.37,
longitude=4.89,
object_type="apartment",
house_type="upstairs",
offering_type="buy",
construction_type="existing",
construction_year="2000",
energy_label="A",
living_area=80,
plot_area=0,
bedrooms=2,
rooms=4,
has_garden=False,
has_balcony=True,
has_solar_panels=False,
has_heat_pump=False,
has_roof_terrace=False,
is_energy_efficient=True,
is_monument=False,
current_price=350000,
status="available",
price_per_sqm=4375.0,
publication_date="2024-01-15",
elo_rating=1500.0,
comparison_count=0,
wins=0,
losses=0,
)
defaults.update(overrides)
row = MagicMock()
for k, v in defaults.items():
setattr(row, k, v)
return row
class TestHealthEndpoint:
def test_health(self, client):
resp = client.get("/api/health")
assert resp.status_code == 200
assert resp.json() == {"status": "ok"}
class TestListingsEndpoints:
def test_get_listings(self, client, db_session):
row = _fake_listing()
db_session.execute.return_value = [row]
resp = client.get("/api/listings")
assert resp.status_code == 200
data = resp.json()
assert isinstance(data, list)
assert len(data) == 1
assert data[0]["global_id"] == "abc-123"
def test_get_listing_not_found(self, client, db_session):
result_mock = MagicMock()
result_mock.first.return_value = None
db_session.execute.return_value = result_mock
resp = client.get("/api/listings/nonexistent")
assert resp.status_code == 404
def test_get_listing_found(self, client, db_session):
row = _fake_listing(global_id="xyz-789")
result_mock = MagicMock()
result_mock.first.return_value = row
db_session.execute.return_value = result_mock
resp = client.get("/api/listings/xyz-789")
assert resp.status_code == 200
assert resp.json()["global_id"] == "xyz-789"
class TestRankingsEndpoints:
def test_get_rankings(self, client, db_session):
rows = [
_fake_listing(global_id="a", elo_rating=1600.0),
_fake_listing(global_id="b", elo_rating=1400.0),
]
db_session.execute.return_value = rows
resp = client.get("/api/rankings")
assert resp.status_code == 200
data = resp.json()
assert len(data) == 2
assert data[0]["rank"] == 1
assert data[1]["rank"] == 2
class TestImagesEndpoint:
def test_images_found(self, client, db_session):
row = MagicMock()
row.photo_urls = ["https://img.example.com/1.jpg", "https://img.example.com/2.jpg"]
result_mock = MagicMock()
result_mock.first.return_value = row
db_session.execute.return_value = result_mock
resp = client.get("/api/listings/abc-123/images")
assert resp.status_code == 200
assert len(resp.json()["images"]) == 2
def test_images_not_found(self, client, db_session):
result_mock = MagicMock()
result_mock.first.return_value = None
db_session.execute.return_value = result_mock
resp = client.get("/api/listings/abc-123/images")
assert resp.status_code == 200
assert resp.json() == {"images": []}
class TestCompareEndpoints:
def test_compare_same_ids_rejected(self, client):
resp = client.post(
"/api/compare",
json={"winner_id": "abc", "loser_id": "abc"},
)
assert resp.status_code == 400
def test_compare_success(self, client, db_session):
winner = EloRating(global_id="w1", elo_rating=1500.0, comparison_count=0, wins=0, losses=0)
loser = EloRating(global_id="l1", elo_rating=1500.0, comparison_count=0, wins=0, losses=0)
def fake_filter_by(global_id):
mock_result = MagicMock()
if global_id == "w1":
mock_result.first.return_value = winner
elif global_id == "l1":
mock_result.first.return_value = loser
else:
mock_result.first.return_value = None
return mock_result
db_session.query.return_value.filter_by = fake_filter_by
resp = client.post(
"/api/compare",
json={"winner_id": "w1", "loser_id": "l1"},
)
assert resp.status_code == 200
data = resp.json()
assert data["winner_id"] == "w1"
assert data["loser_id"] == "l1"
assert data["elo_change"] > 0
def test_matchup_insufficient_listings(self, client, db_session):
db_session.execute.return_value = [_fake_listing()] # only 1
resp = client.get("/api/matchup")
assert resp.status_code == 400
def test_history(self, client, db_session):
row = MagicMock()
row.id = 1
row.listing_a_title = "House A"
row.listing_b_title = "House B"
row.winner_title = "House A"
row.listing_a_id = "a"
row.listing_b_id = "b"
row.winner_id = "a"
row.elo_a_before = 1500.0
row.elo_b_before = 1500.0
row.elo_a_after = 1516.0
row.elo_b_after = 1484.0
row.created_at = datetime(2024, 1, 1, 12, 0, 0)
db_session.execute.return_value = [row]
resp = client.get("/api/history")
assert resp.status_code == 200
data = resp.json()
assert len(data) == 1
assert data[0]["winner_id"] == "a"

View File

@@ -0,0 +1,53 @@
"""Tests for configuration and SQL loading."""
from app.config import SQL_DIR, Settings, load_sql
class TestSettings:
"""Tests for the Settings class."""
def test_default_values(self):
s = Settings()
assert s.POSTGRES_HOST in ("localhost", s.POSTGRES_HOST)
assert s.POSTGRES_PORT == int(s.POSTGRES_PORT)
assert s.K_FACTOR > 0
assert s.DEFAULT_ELO > 0
def test_database_url_format(self):
s = Settings()
url = s.database_url
assert url.startswith("postgresql+psycopg2://")
assert str(s.POSTGRES_HOST) in url
class TestLoadSql:
"""Tests for the SQL file loader."""
def test_sql_dir_exists(self):
assert SQL_DIR.is_dir()
def test_all_sql_files_exist(self):
expected = [
"listing_select.sql",
"recent_pairs.sql",
"history.sql",
"count_comparisons.sql",
"count_rated.sql",
"count_listings.sql",
"elo_aggregates.sql",
"elo_distribution.sql",
"listing_images.sql",
]
for name in expected:
assert (SQL_DIR / name).is_file(), f"Missing SQL file: {name}"
def test_load_sql_substitutes_schemas(self):
sql = load_sql("listing_select.sql")
# Should not contain any unresolved placeholders
assert "{" not in sql
assert "}" not in sql
def test_load_sql_returns_string(self):
sql = load_sql("count_comparisons.sql")
assert isinstance(sql, str)
assert "count" in sql.lower()

46
backend/tests/test_elo.py Normal file
View File

@@ -0,0 +1,46 @@
"""Tests for the ELO calculation module."""
import pytest
from app.elo import calculate_elo
class TestCalculateElo:
"""Tests for calculate_elo."""
def test_equal_ratings(self):
"""Equal ratings should produce symmetric changes."""
new_w, new_l = calculate_elo(1500.0, 1500.0, k_factor=32.0)
assert new_w > 1500.0
assert new_l < 1500.0
assert new_w - 1500.0 == pytest.approx(1500.0 - new_l, abs=0.01)
def test_higher_rated_wins(self):
"""Higher-rated winner should get a small gain."""
new_w, new_l = calculate_elo(1800.0, 1200.0, k_factor=32.0)
change = new_w - 1800.0
assert 0 < change < 16.0 # expected win → small gain
def test_lower_rated_wins(self):
"""Lower-rated winner (upset) should get a large gain."""
new_w, new_l = calculate_elo(1200.0, 1800.0, k_factor=32.0)
change = new_w - 1200.0
assert change > 16.0 # upset → large gain
def test_k_factor_scales_change(self):
"""Higher K-factor should produce larger rating changes."""
_, _ = calculate_elo(1500.0, 1500.0, k_factor=16.0)
new_w_16, _ = calculate_elo(1500.0, 1500.0, k_factor=16.0)
new_w_64, _ = calculate_elo(1500.0, 1500.0, k_factor=64.0)
assert (new_w_64 - 1500.0) > (new_w_16 - 1500.0)
def test_total_elo_preserved(self):
"""Total ELO should be preserved (zero-sum)."""
new_w, new_l = calculate_elo(1600.0, 1400.0, k_factor=32.0)
assert new_w + new_l == pytest.approx(1600.0 + 1400.0, abs=0.01)
def test_result_is_rounded(self):
"""Results should be rounded to 2 decimal places."""
new_w, new_l = calculate_elo(1500.0, 1500.0, k_factor=32.0)
assert new_w == round(new_w, 2)
assert new_l == round(new_l, 2)