chore: organise backend

This commit is contained in:
Stijnvandenbroek
2025-08-12 14:39:44 +02:00
parent c4cfbdae62
commit 9d48a1012f

View File

@@ -1,60 +1,84 @@
from fastapi import FastAPI, UploadFile, File, Form, HTTPException from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import pandas as pd import pandas as pd
import io import io
import random import random
import json import json
from pydantic import BaseModel from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
import uvicorn import uvicorn
# FastAPI app initialization
app = FastAPI() app = FastAPI()
@app.get("/")
async def root():
return {"status": "ok"}
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
# Allow specific origins - both container and host URLs
allow_origins=["*"], allow_origins=["*"],
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
) )
# Pydantic models
class AnswerOption(BaseModel): class AnswerOption(BaseModel):
text: str text: str
is_correct: bool is_correct: bool
class AnswerSubmission(BaseModel): class AnswerSubmission(BaseModel):
session_id: str session_id: str
selected_answers: list[str] selected_answers: list[str]
class MoveQuestionRequest(BaseModel): class MoveQuestionRequest(BaseModel):
session_id: str session_id: str
question_index: int question_index: int
class SessionResetRequest(BaseModel): class SessionResetRequest(BaseModel):
session_id: str session_id: str
class QuizSettings(BaseModel): class QuizSettings(BaseModel):
repeat_on_mistake: bool repeat_on_mistake: bool
shuffle_answers: bool shuffle_answers: bool
randomise_order: bool randomise_order: bool
question_count_multiplier: int question_count_multiplier: int
# Global storage
# Dictionary to store session data
quiz_sessions = {} quiz_sessions = {}
# Utility functions
def check_answers(selected_answers: list[str], correct_answers: list[str]) -> bool:
return set(selected_answers) == set(correct_answers)
def get_correct_answers(session_id: str) -> list[str]:
session = quiz_sessions.get(session_id)
if not session:
return []
df = session["data"]
question_index = session["current_question_index"]
return [ans["text"] for ans in df.loc[question_index]["answer"] if ans["is_correct"]]
def update_stats(session_id: str, is_correct: bool) -> None:
session = quiz_sessions.get(session_id)
if not session:
return
session["current_question_index"] += 1
if is_correct:
session["correct_count"] += 1
else:
session["incorrect_count"] += 1
def validate_session(session_id: str):
session = quiz_sessions.get(session_id)
if not session:
raise HTTPException(status_code=400, detail="Invalid session ID.")
return session
# API endpoints
@app.get("/")
async def root():
return {"status": "ok"}
@app.post("/upload-csv-with-settings/") @app.post("/upload-csv-with-settings/")
async def upload_csv_with_settings(files: list[UploadFile] = File(...), settings: str = Form(...)): async def upload_csv_with_settings(files: list[UploadFile] = File(...), settings: str = Form(...)):
@@ -63,7 +87,6 @@ async def upload_csv_with_settings(files: list[UploadFile] = File(...), settings
print(f"File names: {[file.filename for file in files]}") print(f"File names: {[file.filename for file in files]}")
print(f"Settings: {settings}") print(f"Settings: {settings}")
# Parse JSON string to QuizSettings
try: try:
quiz_settings = QuizSettings.parse_raw(settings) quiz_settings = QuizSettings.parse_raw(settings)
print(f"Parsed settings: {quiz_settings}") print(f"Parsed settings: {quiz_settings}")
@@ -79,32 +102,25 @@ async def upload_csv_with_settings(files: list[UploadFile] = File(...), settings
df = pd.read_csv(io.BytesIO(contents)) df = pd.read_csv(io.BytesIO(contents))
print(f"File columns: {df.columns.tolist()}") print(f"File columns: {df.columns.tolist()}")
# Ensure required columns are present
if "question" not in df.columns or "answer" not in df.columns: if "question" not in df.columns or "answer" not in df.columns:
return JSONResponse({"error": f"CSV file {file.filename} must have 'question' and 'answer' columns."}, status_code=400) return JSONResponse({"error": f"CSV file {file.filename} must have 'question' and 'answer' columns."}, status_code=400)
except Exception as e: except Exception as e:
print(f"Error processing file {file.filename}: {str(e)}") print(f"Error processing file {file.filename}: {str(e)}")
return JSONResponse({"error": f"Error processing {file.filename}: {str(e)}"}, status_code=400) return JSONResponse({"error": f"Error processing {file.filename}: {str(e)}"}, status_code=400)
# Parse answer column if it's in JSON format
df["answer"] = df["answer"].apply( df["answer"] = df["answer"].apply(
lambda x: json.loads(x.replace("\\\\", "\\\\\\")) if isinstance(x, str) else x lambda x: json.loads(x.replace("\\\\", "\\\\\\")) if isinstance(x, str) else x
) )
# Combine data into a single dataframe
combined_df = pd.concat([combined_df, df], ignore_index=True) combined_df = pd.concat([combined_df, df], ignore_index=True)
# Apply question count multiplier
combined_df = pd.concat([combined_df] * quiz_settings.question_count_multiplier, ignore_index=True) combined_df = pd.concat([combined_df] * quiz_settings.question_count_multiplier, ignore_index=True)
# Randomize order if specified
if quiz_settings.randomise_order: if quiz_settings.randomise_order:
combined_df = combined_df.sample(frac=1).reset_index(drop=True) combined_df = combined_df.sample(frac=1).reset_index(drop=True)
# Generate a unique session ID
session_id = str(random.randint(1000, 9999)) session_id = str(random.randint(1000, 9999))
# Store the session data
quiz_sessions[session_id] = { quiz_sessions[session_id] = {
"data": combined_df, "data": combined_df,
"original_data": combined_df.copy(), "original_data": combined_df.copy(),
@@ -114,10 +130,8 @@ async def upload_csv_with_settings(files: list[UploadFile] = File(...), settings
"settings": quiz_settings.dict(), "settings": quiz_settings.dict(),
} }
# Return the session ID to the client
return {"session_id": session_id, "message": "Quiz session started!"} return {"session_id": session_id, "message": "Quiz session started!"}
@app.get("/quiz-settings/") @app.get("/quiz-settings/")
async def get_quiz_settings(session_id: str): async def get_quiz_settings(session_id: str):
session = quiz_sessions.get(session_id) session = quiz_sessions.get(session_id)
@@ -126,8 +140,6 @@ async def get_quiz_settings(session_id: str):
return session["settings"] return session["settings"]
# Endpoint to get the next question
@app.get("/next-question/") @app.get("/next-question/")
async def get_next_question(session_id: str): async def get_next_question(session_id: str):
session = quiz_sessions.get(session_id) session = quiz_sessions.get(session_id)
@@ -137,11 +149,9 @@ async def get_next_question(session_id: str):
df = session["data"] df = session["data"]
question_index = session["current_question_index"] question_index = session["current_question_index"]
# Check if there are more questions
if session["correct_count"] >= df.shape[0]: if session["correct_count"] >= df.shape[0]:
return {"message": "Quiz complete!", "total_questions": len(df)} return {"message": "Quiz complete!", "total_questions": len(df)}
# Get the current question and possible answers (for display purposes)
question = df.loc[question_index]["question"] question = df.loc[question_index]["question"]
possible_answers = [ans["text"] for ans in df.loc[question_index]["answer"]] possible_answers = [ans["text"] for ans in df.loc[question_index]["answer"]]
if session["settings"]["shuffle_answers"]: if session["settings"]["shuffle_answers"]:
@@ -155,8 +165,6 @@ async def get_next_question(session_id: str):
"multiple_choice": multiple_choice, "multiple_choice": multiple_choice,
} }
# Endpoint to submit an answer
@app.post("/submit-answer/") @app.post("/submit-answer/")
async def submit_answer(submission: AnswerSubmission): async def submit_answer(submission: AnswerSubmission):
if submission.session_id not in quiz_sessions: if submission.session_id not in quiz_sessions:
@@ -167,35 +175,6 @@ async def submit_answer(submission: AnswerSubmission):
update_stats(submission.session_id, is_correct) update_stats(submission.session_id, is_correct)
return {"result": "Correct" if is_correct else "Incorrect", "correct_answers": correct_answers} return {"result": "Correct" if is_correct else "Incorrect", "correct_answers": correct_answers}
def check_answers(selected_answers: list[str], correct_answers: list[str]) -> bool:
return set(selected_answers) == set(correct_answers)
def get_correct_answers(session_id: str) -> list[str]:
session = quiz_sessions.get(session_id)
if not session:
return JSONResponse({"error": "Invalid session ID."}, status_code=400)
df = session["data"]
question_index = session["current_question_index"]
return [ans["text"] for ans in df.loc[question_index]["answer"] if ans["is_correct"]]
def update_stats(session_id: str, is_correct: bool) -> int:
session = quiz_sessions.get(session_id)
if not session:
return JSONResponse({"error": "Invalid session ID."}, status_code=400)
session["current_question_index"] += 1
if is_correct:
session["correct_count"] += 1
else:
session["incorrect_count"] += 1
# Endpoint to get quiz statistics
@app.get("/quiz-stats/") @app.get("/quiz-stats/")
async def get_quiz_stats(session_id: str): async def get_quiz_stats(session_id: str):
session = quiz_sessions.get(session_id) session = quiz_sessions.get(session_id)
@@ -212,14 +191,9 @@ async def get_quiz_stats(session_id: str):
"incorrect_answers": incorrect_count, "incorrect_answers": incorrect_count,
} }
# Endpoint to move a given question to the bottom of the DataFrame
@app.post("/move-question-to-bottom/") @app.post("/move-question-to-bottom/")
async def move_question_to_bottom(request: MoveQuestionRequest): async def move_question_to_bottom(request: MoveQuestionRequest):
session = quiz_sessions.get(request.session_id) session = validate_session(request.session_id)
if not session:
raise HTTPException(status_code=400, detail="Invalid session ID.")
df = session["data"] df = session["data"]
if request.question_index < 0 or request.question_index > df.index.max(): if request.question_index < 0 or request.question_index > df.index.max():
@@ -236,8 +210,6 @@ async def move_question_to_bottom(request: MoveQuestionRequest):
session["data"] = df session["data"] = df
return {"message": "Question moved to the bottom successfully."} return {"message": "Question moved to the bottom successfully."}
# Endpoint to reset the session
@app.post("/reset-session/") @app.post("/reset-session/")
async def reset_session(session_reset_request: SessionResetRequest): async def reset_session(session_reset_request: SessionResetRequest):
session = quiz_sessions.get(session_reset_request.session_id) session = quiz_sessions.get(session_reset_request.session_id)
@@ -251,8 +223,7 @@ async def reset_session(session_reset_request: SessionResetRequest):
return {"message": "Session reset successfully."} return {"message": "Session reset successfully."}
# Application startup
# Add this to run directly with python main.py
if __name__ == "__main__": if __name__ == "__main__":
print("Starting server on 0.0.0.0:8000...") print("Starting server on 0.0.0.0:8000...")
uvicorn.run(app, host="0.0.0.0", port=8000, log_level="debug") uvicorn.run(app, host="0.0.0.0", port=8000, log_level="debug")