chore: organise backend

This commit is contained in:
Stijnvandenbroek
2025-08-12 14:39:44 +02:00
parent c4cfbdae62
commit 9d48a1012f

View File

@@ -1,60 +1,84 @@
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import pandas as pd
import io
import random
import json
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
# FastAPI app initialization
app = FastAPI()
@app.get("/")
async def root():
return {"status": "ok"}
app.add_middleware(
CORSMiddleware,
# Allow specific origins - both container and host URLs
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Pydantic models
class AnswerOption(BaseModel):
text: str
is_correct: bool
class AnswerSubmission(BaseModel):
session_id: str
selected_answers: list[str]
class MoveQuestionRequest(BaseModel):
session_id: str
question_index: int
class SessionResetRequest(BaseModel):
session_id: str
class QuizSettings(BaseModel):
repeat_on_mistake: bool
shuffle_answers: bool
randomise_order: bool
question_count_multiplier: int
# Dictionary to store session data
# Global storage
quiz_sessions = {}
# Utility functions
def check_answers(selected_answers: list[str], correct_answers: list[str]) -> bool:
return set(selected_answers) == set(correct_answers)
def get_correct_answers(session_id: str) -> list[str]:
session = quiz_sessions.get(session_id)
if not session:
return []
df = session["data"]
question_index = session["current_question_index"]
return [ans["text"] for ans in df.loc[question_index]["answer"] if ans["is_correct"]]
def update_stats(session_id: str, is_correct: bool) -> None:
session = quiz_sessions.get(session_id)
if not session:
return
session["current_question_index"] += 1
if is_correct:
session["correct_count"] += 1
else:
session["incorrect_count"] += 1
def validate_session(session_id: str):
session = quiz_sessions.get(session_id)
if not session:
raise HTTPException(status_code=400, detail="Invalid session ID.")
return session
# API endpoints
@app.get("/")
async def root():
return {"status": "ok"}
@app.post("/upload-csv-with-settings/")
async def upload_csv_with_settings(files: list[UploadFile] = File(...), settings: str = Form(...)):
@@ -63,7 +87,6 @@ async def upload_csv_with_settings(files: list[UploadFile] = File(...), settings
print(f"File names: {[file.filename for file in files]}")
print(f"Settings: {settings}")
# Parse JSON string to QuizSettings
try:
quiz_settings = QuizSettings.parse_raw(settings)
print(f"Parsed settings: {quiz_settings}")
@@ -79,32 +102,25 @@ async def upload_csv_with_settings(files: list[UploadFile] = File(...), settings
df = pd.read_csv(io.BytesIO(contents))
print(f"File columns: {df.columns.tolist()}")
# Ensure required columns are present
if "question" not in df.columns or "answer" not in df.columns:
return JSONResponse({"error": f"CSV file {file.filename} must have 'question' and 'answer' columns."}, status_code=400)
except Exception as e:
print(f"Error processing file {file.filename}: {str(e)}")
return JSONResponse({"error": f"Error processing {file.filename}: {str(e)}"}, status_code=400)
# Parse answer column if it's in JSON format
df["answer"] = df["answer"].apply(
lambda x: json.loads(x.replace("\\\\", "\\\\\\")) if isinstance(x, str) else x
)
# Combine data into a single dataframe
combined_df = pd.concat([combined_df, df], ignore_index=True)
# Apply question count multiplier
combined_df = pd.concat([combined_df] * quiz_settings.question_count_multiplier, ignore_index=True)
# Randomize order if specified
if quiz_settings.randomise_order:
combined_df = combined_df.sample(frac=1).reset_index(drop=True)
# Generate a unique session ID
session_id = str(random.randint(1000, 9999))
# Store the session data
quiz_sessions[session_id] = {
"data": combined_df,
"original_data": combined_df.copy(),
@@ -114,10 +130,8 @@ async def upload_csv_with_settings(files: list[UploadFile] = File(...), settings
"settings": quiz_settings.dict(),
}
# Return the session ID to the client
return {"session_id": session_id, "message": "Quiz session started!"}
@app.get("/quiz-settings/")
async def get_quiz_settings(session_id: str):
session = quiz_sessions.get(session_id)
@@ -126,8 +140,6 @@ async def get_quiz_settings(session_id: str):
return session["settings"]
# Endpoint to get the next question
@app.get("/next-question/")
async def get_next_question(session_id: str):
session = quiz_sessions.get(session_id)
@@ -137,11 +149,9 @@ async def get_next_question(session_id: str):
df = session["data"]
question_index = session["current_question_index"]
# Check if there are more questions
if session["correct_count"] >= df.shape[0]:
return {"message": "Quiz complete!", "total_questions": len(df)}
# Get the current question and possible answers (for display purposes)
question = df.loc[question_index]["question"]
possible_answers = [ans["text"] for ans in df.loc[question_index]["answer"]]
if session["settings"]["shuffle_answers"]:
@@ -155,8 +165,6 @@ async def get_next_question(session_id: str):
"multiple_choice": multiple_choice,
}
# Endpoint to submit an answer
@app.post("/submit-answer/")
async def submit_answer(submission: AnswerSubmission):
if submission.session_id not in quiz_sessions:
@@ -167,35 +175,6 @@ async def submit_answer(submission: AnswerSubmission):
update_stats(submission.session_id, is_correct)
return {"result": "Correct" if is_correct else "Incorrect", "correct_answers": correct_answers}
def check_answers(selected_answers: list[str], correct_answers: list[str]) -> bool:
return set(selected_answers) == set(correct_answers)
def get_correct_answers(session_id: str) -> list[str]:
session = quiz_sessions.get(session_id)
if not session:
return JSONResponse({"error": "Invalid session ID."}, status_code=400)
df = session["data"]
question_index = session["current_question_index"]
return [ans["text"] for ans in df.loc[question_index]["answer"] if ans["is_correct"]]
def update_stats(session_id: str, is_correct: bool) -> int:
session = quiz_sessions.get(session_id)
if not session:
return JSONResponse({"error": "Invalid session ID."}, status_code=400)
session["current_question_index"] += 1
if is_correct:
session["correct_count"] += 1
else:
session["incorrect_count"] += 1
# Endpoint to get quiz statistics
@app.get("/quiz-stats/")
async def get_quiz_stats(session_id: str):
session = quiz_sessions.get(session_id)
@@ -212,14 +191,9 @@ async def get_quiz_stats(session_id: str):
"incorrect_answers": incorrect_count,
}
# Endpoint to move a given question to the bottom of the DataFrame
@app.post("/move-question-to-bottom/")
async def move_question_to_bottom(request: MoveQuestionRequest):
session = quiz_sessions.get(request.session_id)
if not session:
raise HTTPException(status_code=400, detail="Invalid session ID.")
session = validate_session(request.session_id)
df = session["data"]
if request.question_index < 0 or request.question_index > df.index.max():
@@ -236,8 +210,6 @@ async def move_question_to_bottom(request: MoveQuestionRequest):
session["data"] = df
return {"message": "Question moved to the bottom successfully."}
# Endpoint to reset the session
@app.post("/reset-session/")
async def reset_session(session_reset_request: SessionResetRequest):
session = quiz_sessions.get(session_reset_request.session_id)
@@ -251,8 +223,7 @@ async def reset_session(session_reset_request: SessionResetRequest):
return {"message": "Session reset successfully."}
# Add this to run directly with python main.py
# Application startup
if __name__ == "__main__":
print("Starting server on 0.0.0.0:8000...")
uvicorn.run(app, host="0.0.0.0", port=8000, log_level="debug")