feat: refactor project
This commit is contained in:
0
backend/tests/__init__.py
Normal file
0
backend/tests/__init__.py
Normal file
30
backend/tests/conftest.py
Normal file
30
backend/tests/conftest.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Shared test fixtures."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import get_db
|
||||
from app.main import app
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def db_session():
|
||||
"""Provide a mocked database session."""
|
||||
session = MagicMock(spec=Session)
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def client(db_session):
|
||||
"""Provide a FastAPI test client with mocked DB."""
|
||||
|
||||
def _override_get_db():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_db] = _override_get_db
|
||||
with TestClient(app) as c:
|
||||
yield c
|
||||
app.dependency_overrides.clear()
|
||||
184
backend/tests/test_api.py
Normal file
184
backend/tests/test_api.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""Tests for API endpoints."""
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from app.models import EloRating
|
||||
|
||||
|
||||
def _fake_listing(**overrides) -> dict:
|
||||
"""Return a mock DB row that looks like a joined listing+rating row."""
|
||||
defaults = dict(
|
||||
global_id="abc-123",
|
||||
tiny_id="123",
|
||||
url="https://example.com/listing",
|
||||
title="Nice House",
|
||||
city="Amsterdam",
|
||||
postcode="1012AB",
|
||||
province="Noord-Holland",
|
||||
neighbourhood="Centrum",
|
||||
municipality="Amsterdam",
|
||||
latitude=52.37,
|
||||
longitude=4.89,
|
||||
object_type="apartment",
|
||||
house_type="upstairs",
|
||||
offering_type="buy",
|
||||
construction_type="existing",
|
||||
construction_year="2000",
|
||||
energy_label="A",
|
||||
living_area=80,
|
||||
plot_area=0,
|
||||
bedrooms=2,
|
||||
rooms=4,
|
||||
has_garden=False,
|
||||
has_balcony=True,
|
||||
has_solar_panels=False,
|
||||
has_heat_pump=False,
|
||||
has_roof_terrace=False,
|
||||
is_energy_efficient=True,
|
||||
is_monument=False,
|
||||
current_price=350000,
|
||||
status="available",
|
||||
price_per_sqm=4375.0,
|
||||
publication_date="2024-01-15",
|
||||
elo_rating=1500.0,
|
||||
comparison_count=0,
|
||||
wins=0,
|
||||
losses=0,
|
||||
)
|
||||
defaults.update(overrides)
|
||||
row = MagicMock()
|
||||
for k, v in defaults.items():
|
||||
setattr(row, k, v)
|
||||
return row
|
||||
|
||||
|
||||
class TestHealthEndpoint:
|
||||
def test_health(self, client):
|
||||
resp = client.get("/api/health")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"status": "ok"}
|
||||
|
||||
|
||||
class TestListingsEndpoints:
|
||||
def test_get_listings(self, client, db_session):
|
||||
row = _fake_listing()
|
||||
db_session.execute.return_value = [row]
|
||||
resp = client.get("/api/listings")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert isinstance(data, list)
|
||||
assert len(data) == 1
|
||||
assert data[0]["global_id"] == "abc-123"
|
||||
|
||||
def test_get_listing_not_found(self, client, db_session):
|
||||
result_mock = MagicMock()
|
||||
result_mock.first.return_value = None
|
||||
db_session.execute.return_value = result_mock
|
||||
resp = client.get("/api/listings/nonexistent")
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_get_listing_found(self, client, db_session):
|
||||
row = _fake_listing(global_id="xyz-789")
|
||||
result_mock = MagicMock()
|
||||
result_mock.first.return_value = row
|
||||
db_session.execute.return_value = result_mock
|
||||
resp = client.get("/api/listings/xyz-789")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["global_id"] == "xyz-789"
|
||||
|
||||
|
||||
class TestRankingsEndpoints:
|
||||
def test_get_rankings(self, client, db_session):
|
||||
rows = [
|
||||
_fake_listing(global_id="a", elo_rating=1600.0),
|
||||
_fake_listing(global_id="b", elo_rating=1400.0),
|
||||
]
|
||||
db_session.execute.return_value = rows
|
||||
resp = client.get("/api/rankings")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 2
|
||||
assert data[0]["rank"] == 1
|
||||
assert data[1]["rank"] == 2
|
||||
|
||||
|
||||
class TestImagesEndpoint:
|
||||
def test_images_found(self, client, db_session):
|
||||
row = MagicMock()
|
||||
row.photo_urls = ["https://img.example.com/1.jpg", "https://img.example.com/2.jpg"]
|
||||
result_mock = MagicMock()
|
||||
result_mock.first.return_value = row
|
||||
db_session.execute.return_value = result_mock
|
||||
resp = client.get("/api/listings/abc-123/images")
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()["images"]) == 2
|
||||
|
||||
def test_images_not_found(self, client, db_session):
|
||||
result_mock = MagicMock()
|
||||
result_mock.first.return_value = None
|
||||
db_session.execute.return_value = result_mock
|
||||
resp = client.get("/api/listings/abc-123/images")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"images": []}
|
||||
|
||||
|
||||
class TestCompareEndpoints:
|
||||
def test_compare_same_ids_rejected(self, client):
|
||||
resp = client.post(
|
||||
"/api/compare",
|
||||
json={"winner_id": "abc", "loser_id": "abc"},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_compare_success(self, client, db_session):
|
||||
winner = EloRating(global_id="w1", elo_rating=1500.0, comparison_count=0, wins=0, losses=0)
|
||||
loser = EloRating(global_id="l1", elo_rating=1500.0, comparison_count=0, wins=0, losses=0)
|
||||
|
||||
def fake_filter_by(global_id):
|
||||
mock_result = MagicMock()
|
||||
if global_id == "w1":
|
||||
mock_result.first.return_value = winner
|
||||
elif global_id == "l1":
|
||||
mock_result.first.return_value = loser
|
||||
else:
|
||||
mock_result.first.return_value = None
|
||||
return mock_result
|
||||
|
||||
db_session.query.return_value.filter_by = fake_filter_by
|
||||
|
||||
resp = client.post(
|
||||
"/api/compare",
|
||||
json={"winner_id": "w1", "loser_id": "l1"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["winner_id"] == "w1"
|
||||
assert data["loser_id"] == "l1"
|
||||
assert data["elo_change"] > 0
|
||||
|
||||
def test_matchup_insufficient_listings(self, client, db_session):
|
||||
db_session.execute.return_value = [_fake_listing()] # only 1
|
||||
resp = client.get("/api/matchup")
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_history(self, client, db_session):
|
||||
row = MagicMock()
|
||||
row.id = 1
|
||||
row.listing_a_title = "House A"
|
||||
row.listing_b_title = "House B"
|
||||
row.winner_title = "House A"
|
||||
row.listing_a_id = "a"
|
||||
row.listing_b_id = "b"
|
||||
row.winner_id = "a"
|
||||
row.elo_a_before = 1500.0
|
||||
row.elo_b_before = 1500.0
|
||||
row.elo_a_after = 1516.0
|
||||
row.elo_b_after = 1484.0
|
||||
row.created_at = datetime(2024, 1, 1, 12, 0, 0)
|
||||
db_session.execute.return_value = [row]
|
||||
resp = client.get("/api/history")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 1
|
||||
assert data[0]["winner_id"] == "a"
|
||||
53
backend/tests/test_config.py
Normal file
53
backend/tests/test_config.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""Tests for configuration and SQL loading."""
|
||||
|
||||
from app.config import SQL_DIR, Settings, load_sql
|
||||
|
||||
|
||||
class TestSettings:
|
||||
"""Tests for the Settings class."""
|
||||
|
||||
def test_default_values(self):
|
||||
s = Settings()
|
||||
assert s.POSTGRES_HOST in ("localhost", s.POSTGRES_HOST)
|
||||
assert s.POSTGRES_PORT == int(s.POSTGRES_PORT)
|
||||
assert s.K_FACTOR > 0
|
||||
assert s.DEFAULT_ELO > 0
|
||||
|
||||
def test_database_url_format(self):
|
||||
s = Settings()
|
||||
url = s.database_url
|
||||
assert url.startswith("postgresql+psycopg2://")
|
||||
assert str(s.POSTGRES_HOST) in url
|
||||
|
||||
|
||||
class TestLoadSql:
|
||||
"""Tests for the SQL file loader."""
|
||||
|
||||
def test_sql_dir_exists(self):
|
||||
assert SQL_DIR.is_dir()
|
||||
|
||||
def test_all_sql_files_exist(self):
|
||||
expected = [
|
||||
"listing_select.sql",
|
||||
"recent_pairs.sql",
|
||||
"history.sql",
|
||||
"count_comparisons.sql",
|
||||
"count_rated.sql",
|
||||
"count_listings.sql",
|
||||
"elo_aggregates.sql",
|
||||
"elo_distribution.sql",
|
||||
"listing_images.sql",
|
||||
]
|
||||
for name in expected:
|
||||
assert (SQL_DIR / name).is_file(), f"Missing SQL file: {name}"
|
||||
|
||||
def test_load_sql_substitutes_schemas(self):
|
||||
sql = load_sql("listing_select.sql")
|
||||
# Should not contain any unresolved placeholders
|
||||
assert "{" not in sql
|
||||
assert "}" not in sql
|
||||
|
||||
def test_load_sql_returns_string(self):
|
||||
sql = load_sql("count_comparisons.sql")
|
||||
assert isinstance(sql, str)
|
||||
assert "count" in sql.lower()
|
||||
46
backend/tests/test_elo.py
Normal file
46
backend/tests/test_elo.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Tests for the ELO calculation module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.elo import calculate_elo
|
||||
|
||||
|
||||
class TestCalculateElo:
|
||||
"""Tests for calculate_elo."""
|
||||
|
||||
def test_equal_ratings(self):
|
||||
"""Equal ratings should produce symmetric changes."""
|
||||
new_w, new_l = calculate_elo(1500.0, 1500.0, k_factor=32.0)
|
||||
assert new_w > 1500.0
|
||||
assert new_l < 1500.0
|
||||
assert new_w - 1500.0 == pytest.approx(1500.0 - new_l, abs=0.01)
|
||||
|
||||
def test_higher_rated_wins(self):
|
||||
"""Higher-rated winner should get a small gain."""
|
||||
new_w, new_l = calculate_elo(1800.0, 1200.0, k_factor=32.0)
|
||||
change = new_w - 1800.0
|
||||
assert 0 < change < 16.0 # expected win → small gain
|
||||
|
||||
def test_lower_rated_wins(self):
|
||||
"""Lower-rated winner (upset) should get a large gain."""
|
||||
new_w, new_l = calculate_elo(1200.0, 1800.0, k_factor=32.0)
|
||||
change = new_w - 1200.0
|
||||
assert change > 16.0 # upset → large gain
|
||||
|
||||
def test_k_factor_scales_change(self):
|
||||
"""Higher K-factor should produce larger rating changes."""
|
||||
_, _ = calculate_elo(1500.0, 1500.0, k_factor=16.0)
|
||||
new_w_16, _ = calculate_elo(1500.0, 1500.0, k_factor=16.0)
|
||||
new_w_64, _ = calculate_elo(1500.0, 1500.0, k_factor=64.0)
|
||||
assert (new_w_64 - 1500.0) > (new_w_16 - 1500.0)
|
||||
|
||||
def test_total_elo_preserved(self):
|
||||
"""Total ELO should be preserved (zero-sum)."""
|
||||
new_w, new_l = calculate_elo(1600.0, 1400.0, k_factor=32.0)
|
||||
assert new_w + new_l == pytest.approx(1600.0 + 1400.0, abs=0.01)
|
||||
|
||||
def test_result_is_rounded(self):
|
||||
"""Results should be rounded to 2 decimal places."""
|
||||
new_w, new_l = calculate_elo(1500.0, 1500.0, k_factor=32.0)
|
||||
assert new_w == round(new_w, 2)
|
||||
assert new_l == round(new_l, 2)
|
||||
Reference in New Issue
Block a user