diff --git a/api/main.py b/api/main.py index 9dfff06..f18e03d 100644 --- a/api/main.py +++ b/api/main.py @@ -6,17 +6,20 @@ from temporalio.api.enums.v1 import WorkflowExecutionStatus from workflows.tool_workflow import ToolWorkflow from models.data_types import CombinedInput, ToolWorkflowParams -from tools.goal_registry import goal_event_flight_invoice +from tools.goal_registry import goal_match_train_invoice from fastapi.middleware.cors import CORSMiddleware from shared.config import get_temporal_client, TEMPORAL_TASK_QUEUE + app = FastAPI() temporal_client: Optional[Client] = None + @app.on_event("startup") async def startup_event(): global temporal_client temporal_client = await get_temporal_client() + app.add_middleware( CORSMiddleware, allow_origins=["http://localhost:5173"], @@ -62,13 +65,13 @@ async def get_conversation_history(): status_names = { WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_TERMINATED: "WORKFLOW_EXECUTION_STATUS_TERMINATED", WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_CANCELED: "WORKFLOW_EXECUTION_STATUS_CANCELED", - WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED: "WORKFLOW_EXECUTION_STATUS_FAILED" + WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED: "WORKFLOW_EXECUTION_STATUS_FAILED", } failed_states = [ WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_TERMINATED, WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_CANCELED, - WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED + WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED, ] # Check workflow status first @@ -77,11 +80,11 @@ async def get_conversation_history(): status_name = status_names.get(description.status, "UNKNOWN_STATUS") print(f"Workflow is in {status_name} state. Returning empty history.") return [] - + # Only query if workflow is running conversation_history = await handle.query("get_conversation_history") return conversation_history - + except TemporalError as e: print(f"Temporal error: {e}") return [] @@ -92,7 +95,7 @@ async def send_prompt(prompt: str): # Create combined input combined_input = CombinedInput( tool_params=ToolWorkflowParams(None, None), - agent_goal=goal_event_flight_invoice, + agent_goal=goal_match_train_invoice, ) workflow_id = "agent-workflow" @@ -139,7 +142,7 @@ async def start_workflow(): # Create combined input combined_input = CombinedInput( tool_params=ToolWorkflowParams(None, None), - agent_goal=goal_event_flight_invoice, + agent_goal=goal_match_train_invoice, ) workflow_id = "agent-workflow" @@ -151,7 +154,9 @@ async def start_workflow(): id=workflow_id, task_queue=TEMPORAL_TASK_QUEUE, start_signal="user_prompt", - start_signal_args=["### " + goal_event_flight_invoice.starter_prompt], + start_signal_args=["### " + goal_match_train_invoice.starter_prompt], ) - return {"message": f"Workflow started with goal's starter prompt: {goal_event_flight_invoice.starter_prompt}."} + return { + "message": f"Workflow started with goal's starter prompt: {goal_match_train_invoice.starter_prompt}." + } diff --git a/thirdparty/train_api.py b/thirdparty/train_api.py index 153465e..891be9f 100644 --- a/thirdparty/train_api.py +++ b/thirdparty/train_api.py @@ -15,14 +15,29 @@ import string def parse_datetime(datetime_str): - # Parse YYYY-MM-DDTHH:MM format - try: - date_part, time_part = datetime_str.split('T') - year, month, day = map(int, date_part.split('-')) - hour, minute = map(int, time_part.split(':')) - return year, month, day, hour, minute - except: - return None, None, None, None, None + # Remove trailing 'Z' if present + if datetime_str.endswith("Z"): + datetime_str = datetime_str[:-1] + + formats = [ + "%Y-%m-%dT%H:%M", # e.g. "2025-04-18T09:00" + "%Y-%m-%dT%H:%M:%S", # e.g. "2025-04-18T09:00:00" + "%Y-%m-%d %H:%M:%S", # e.g. "2025-04-18 09:00:00" + ] + + for fmt in formats: + try: + parsed = time.strptime(datetime_str, fmt) + return ( + parsed.tm_year, + parsed.tm_mon, + parsed.tm_mday, + parsed.tm_hour, + parsed.tm_min, + ) + except ValueError: + continue + return None, None, None, None, None class TrainServer(BaseHTTPRequestHandler): @@ -47,7 +62,7 @@ class TrainServer(BaseHTTPRequestHandler): def write_response(self, response): try: # Python 3 - self.wfile.write(response.encode('utf-8')) + self.wfile.write(response.encode("utf-8")) except AttributeError: # Python 1.5.2 self.wfile.write(response) @@ -73,7 +88,7 @@ class TrainServer(BaseHTTPRequestHandler): # Journey takes 1-2 hours duration = 60 + random.randint(0, 60) - arr_hour = (adj_hour + (duration // 60)) + arr_hour = adj_hour + (duration // 60) arr_minute = (adj_minute + (duration % 60)) % 60 arr_day = adj_day + (arr_hour // 24) arr_hour = arr_hour % 24 @@ -83,10 +98,14 @@ class TrainServer(BaseHTTPRequestHandler): "type": "outbound", "departure": origin, "arrival": destination, - "departure_time": format_datetime(year, month, adj_day, adj_hour, adj_minute), - "arrival_time": format_datetime(year, month, arr_day, arr_hour, arr_minute), + "departure_time": format_datetime( + year, month, adj_day, adj_hour, adj_minute + ), + "arrival_time": format_datetime( + year, month, arr_day, arr_hour, arr_minute + ), "platform": str(random.randint(1, 8)), - "price": round(30 + random.random() * 50, 2) + "price": round(30 + random.random() * 50, 2), } journeys.append(journey) @@ -102,7 +121,7 @@ class TrainServer(BaseHTTPRequestHandler): adj_hour = adj_hour % 24 duration = 60 + random.randint(0, 60) - arr_hour = (adj_hour + (duration // 60)) + arr_hour = adj_hour + (duration // 60) arr_minute = (adj_minute + (duration % 60)) % 60 arr_day = adj_day + (arr_hour // 24) arr_hour = arr_hour % 24 @@ -112,10 +131,14 @@ class TrainServer(BaseHTTPRequestHandler): "type": "return", "departure": destination, "arrival": origin, - "departure_time": format_datetime(year, month, adj_day, adj_hour, adj_minute), - "arrival_time": format_datetime(year, month, arr_day, arr_hour, arr_minute), + "departure_time": format_datetime( + year, month, adj_day, adj_hour, adj_minute + ), + "arrival_time": format_datetime( + year, month, arr_day, arr_hour, arr_minute + ), "platform": str(random.randint(1, 8)), - "price": round(30 + random.random() * 50, 2) + "price": round(30 + random.random() * 50, 2), } journeys.append(journey) @@ -124,49 +147,57 @@ class TrainServer(BaseHTTPRequestHandler): def do_GET(self): parsed_url = urlparse(self.path) - if parsed_url.path == '/api/journeys': + if parsed_url.path == "/api/journeys": try: params = parse_qs(parsed_url.query) - origin = params.get('from', [''])[0] - destination = params.get('to', [''])[0] - outbound_datetime = params.get('outbound_time', [''])[0] - return_datetime = params.get('return_time', [''])[0] + origin = params.get("from", [""])[0] + destination = params.get("to", [""])[0] + outbound_datetime = params.get("outbound_time", [""])[0] + return_datetime = params.get("return_time", [""])[0] if not origin or not destination or not outbound_datetime: self.send_response(400) - self.send_header('Content-Type', 'application/json') + self.send_header("Content-Type", "application/json") self.end_headers() - self.write_response(self.format_json({ - "error": "Required parameters: 'from', 'to', and 'outbound_time'" - })) + self.write_response( + self.format_json( + { + "error": "Required parameters: 'from', 'to', and 'outbound_time'" + } + ) + ) return # Parse datetimes out_dt = parse_datetime(outbound_datetime) - ret_dt = parse_datetime(return_datetime) if return_datetime else ( - None, None, None, None, None) + ret_dt = ( + parse_datetime(return_datetime) + if return_datetime + else (None, None, None, None, None) + ) if out_dt[0] is None: self.send_response(400) - self.send_header('Content-Type', 'application/json') + self.send_header("Content-Type", "application/json") self.end_headers() - self.write_response(self.format_json({ - "error": "Invalid datetime format. Use YYYY-MM-DDTHH:MM" - })) + self.write_response( + self.format_json( + {"error": "Invalid datetime format. Use YYYY-MM-DDTHH:MM"} + ) + ) return self.send_response(200) - self.send_header('Content-Type', 'application/json') + self.send_header("Content-Type", "application/json") self.end_headers() - journeys = self.generate_journeys( - origin, destination, out_dt, ret_dt) + journeys = self.generate_journeys(origin, destination, out_dt, ret_dt) response = self.format_json({"journeys": journeys}) self.write_response(response) except Exception as e: self.send_response(500) - self.send_header('Content-Type', 'application/json') + self.send_header("Content-Type", "application/json") self.end_headers() self.write_response(self.format_json({"error": str(e)})) else: @@ -176,20 +207,23 @@ class TrainServer(BaseHTTPRequestHandler): def do_POST(self): parsed_url = urlparse(self.path) - if parsed_url.path.startswith('/api/book/'): - journey_id = parsed_url.path.split('/')[-1] + if parsed_url.path.startswith("/api/book/"): + journey_id = parsed_url.path.split("/")[-1] self.send_response(200) - self.send_header('Content-Type', 'application/json') + self.send_header("Content-Type", "application/json") self.end_headers() - booking_ref = "BR" + \ - "".join([random.choice(string.digits) for _ in range(5)]) + booking_ref = "BR" + "".join( + [random.choice(string.digits) for _ in range(5)] + ) - response = self.format_json({ - "booking_reference": booking_ref, - "journey_id": journey_id, - "status": "confirmed" - }) + response = self.format_json( + { + "booking_reference": booking_ref, + "journey_id": journey_id, + "status": "confirmed", + } + ) self.write_response(response) else: @@ -198,10 +232,10 @@ class TrainServer(BaseHTTPRequestHandler): def run_server(): - server = HTTPServer(('', 8080), TrainServer) + server = HTTPServer(("", 8080), TrainServer) print("Train booking server starting on port 8080...") server.serve_forever() -if __name__ == '__main__': +if __name__ == "__main__": run_server() diff --git a/tools/create_invoice.py b/tools/create_invoice.py index 1cf0283..19e5b92 100644 --- a/tools/create_invoice.py +++ b/tools/create_invoice.py @@ -43,7 +43,7 @@ def create_invoice(args: dict) -> dict: customer=customer_id, amount=amount_cents, currency="usd", - description=args.get("flightDetails", "Service Invoice"), + description=args.get("tripDetails", "Service Invoice"), ) # Create and finalize the invoice diff --git a/tools/goal_registry.py b/tools/goal_registry.py index 60e4e4a..d7f8060 100644 --- a/tools/goal_registry.py +++ b/tools/goal_registry.py @@ -7,8 +7,49 @@ from tools.tool_registry import ( create_invoice_tool, ) +goal_match_train_invoice = AgentGoal( + tools=[ + search_fixtures_tool, + search_trains_tool, + book_train_tool, + create_invoice_tool, + ], + description="Help the user book trains to a premier league match. The user lives in London. Gather args for these tools in order: " + "1. SearchFixtures: Search for fixtures for a team in a given month" + "2. SearchTrains: Search for trains to visit somewhere before or after the match" + "3. BookTrain: Book the train tickets" + "4. CreateInvoice: Create a simple invoice for the cost of the flights and train tickets", + starter_prompt="Welcome me, give me a description of what you can do, then ask me for the details you need to do your job", + example_conversation_history="\n ".join( + [ + "user: I'd like to travel to a premier league match", + "agent: Sure! Let's start by finding an match you'd like to attend. I know about Premier League fixtures in the UK. Could you tell me which team and month you're interested in?", + "user: Wolves in May please", + "agent: Great! Let's find a match for Wolverhampton Wanderers FC in May.", + "user_confirmed_tool_run: ", + 'tool_result: results including {"homeTeam": "Wolverhampton Wanderers FC", "awayTeam": "Manchester United", "date": "2025-05-04"}', + "agent: Found a match! There's an away game against Manchester United on May 4 2025. Would you like to plan train travel from London for around this date?", + "user_confirmed_tool_run: ", + "tool_result: results including train dates and times, origin and depature stations", + "agent: Found some trains! The best option is leaving on