prompt engineering, train_api date parsing changes

This commit is contained in:
Steve Androulakis
2025-02-11 09:35:40 -08:00
parent 7f6ff2397f
commit aeffe75a0a
5 changed files with 141 additions and 61 deletions

View File

@@ -6,17 +6,20 @@ from temporalio.api.enums.v1 import WorkflowExecutionStatus
from workflows.tool_workflow import ToolWorkflow from workflows.tool_workflow import ToolWorkflow
from models.data_types import CombinedInput, ToolWorkflowParams from models.data_types import CombinedInput, ToolWorkflowParams
from tools.goal_registry import goal_event_flight_invoice from tools.goal_registry import goal_match_train_invoice
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from shared.config import get_temporal_client, TEMPORAL_TASK_QUEUE from shared.config import get_temporal_client, TEMPORAL_TASK_QUEUE
app = FastAPI() app = FastAPI()
temporal_client: Optional[Client] = None temporal_client: Optional[Client] = None
@app.on_event("startup") @app.on_event("startup")
async def startup_event(): async def startup_event():
global temporal_client global temporal_client
temporal_client = await get_temporal_client() temporal_client = await get_temporal_client()
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=["http://localhost:5173"], allow_origins=["http://localhost:5173"],
@@ -62,13 +65,13 @@ async def get_conversation_history():
status_names = { status_names = {
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_TERMINATED: "WORKFLOW_EXECUTION_STATUS_TERMINATED", WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_TERMINATED: "WORKFLOW_EXECUTION_STATUS_TERMINATED",
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_CANCELED: "WORKFLOW_EXECUTION_STATUS_CANCELED", WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_CANCELED: "WORKFLOW_EXECUTION_STATUS_CANCELED",
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED: "WORKFLOW_EXECUTION_STATUS_FAILED" WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED: "WORKFLOW_EXECUTION_STATUS_FAILED",
} }
failed_states = [ failed_states = [
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_TERMINATED, WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_TERMINATED,
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_CANCELED, WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_CANCELED,
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED,
] ]
# Check workflow status first # Check workflow status first
@@ -92,7 +95,7 @@ async def send_prompt(prompt: str):
# Create combined input # Create combined input
combined_input = CombinedInput( combined_input = CombinedInput(
tool_params=ToolWorkflowParams(None, None), tool_params=ToolWorkflowParams(None, None),
agent_goal=goal_event_flight_invoice, agent_goal=goal_match_train_invoice,
) )
workflow_id = "agent-workflow" workflow_id = "agent-workflow"
@@ -139,7 +142,7 @@ async def start_workflow():
# Create combined input # Create combined input
combined_input = CombinedInput( combined_input = CombinedInput(
tool_params=ToolWorkflowParams(None, None), tool_params=ToolWorkflowParams(None, None),
agent_goal=goal_event_flight_invoice, agent_goal=goal_match_train_invoice,
) )
workflow_id = "agent-workflow" workflow_id = "agent-workflow"
@@ -151,7 +154,9 @@ async def start_workflow():
id=workflow_id, id=workflow_id,
task_queue=TEMPORAL_TASK_QUEUE, task_queue=TEMPORAL_TASK_QUEUE,
start_signal="user_prompt", start_signal="user_prompt",
start_signal_args=["### " + goal_event_flight_invoice.starter_prompt], start_signal_args=["### " + goal_match_train_invoice.starter_prompt],
) )
return {"message": f"Workflow started with goal's starter prompt: {goal_event_flight_invoice.starter_prompt}."} return {
"message": f"Workflow started with goal's starter prompt: {goal_match_train_invoice.starter_prompt}."
}

View File

@@ -15,14 +15,29 @@ import string
def parse_datetime(datetime_str): def parse_datetime(datetime_str):
# Parse YYYY-MM-DDTHH:MM format # Remove trailing 'Z' if present
try: if datetime_str.endswith("Z"):
date_part, time_part = datetime_str.split('T') datetime_str = datetime_str[:-1]
year, month, day = map(int, date_part.split('-'))
hour, minute = map(int, time_part.split(':')) formats = [
return year, month, day, hour, minute "%Y-%m-%dT%H:%M", # e.g. "2025-04-18T09:00"
except: "%Y-%m-%dT%H:%M:%S", # e.g. "2025-04-18T09:00:00"
return None, None, None, None, None "%Y-%m-%d %H:%M:%S", # e.g. "2025-04-18 09:00:00"
]
for fmt in formats:
try:
parsed = time.strptime(datetime_str, fmt)
return (
parsed.tm_year,
parsed.tm_mon,
parsed.tm_mday,
parsed.tm_hour,
parsed.tm_min,
)
except ValueError:
continue
return None, None, None, None, None
class TrainServer(BaseHTTPRequestHandler): class TrainServer(BaseHTTPRequestHandler):
@@ -47,7 +62,7 @@ class TrainServer(BaseHTTPRequestHandler):
def write_response(self, response): def write_response(self, response):
try: try:
# Python 3 # Python 3
self.wfile.write(response.encode('utf-8')) self.wfile.write(response.encode("utf-8"))
except AttributeError: except AttributeError:
# Python 1.5.2 # Python 1.5.2
self.wfile.write(response) self.wfile.write(response)
@@ -73,7 +88,7 @@ class TrainServer(BaseHTTPRequestHandler):
# Journey takes 1-2 hours # Journey takes 1-2 hours
duration = 60 + random.randint(0, 60) duration = 60 + random.randint(0, 60)
arr_hour = (adj_hour + (duration // 60)) arr_hour = adj_hour + (duration // 60)
arr_minute = (adj_minute + (duration % 60)) % 60 arr_minute = (adj_minute + (duration % 60)) % 60
arr_day = adj_day + (arr_hour // 24) arr_day = adj_day + (arr_hour // 24)
arr_hour = arr_hour % 24 arr_hour = arr_hour % 24
@@ -83,10 +98,14 @@ class TrainServer(BaseHTTPRequestHandler):
"type": "outbound", "type": "outbound",
"departure": origin, "departure": origin,
"arrival": destination, "arrival": destination,
"departure_time": format_datetime(year, month, adj_day, adj_hour, adj_minute), "departure_time": format_datetime(
"arrival_time": format_datetime(year, month, arr_day, arr_hour, arr_minute), year, month, adj_day, adj_hour, adj_minute
),
"arrival_time": format_datetime(
year, month, arr_day, arr_hour, arr_minute
),
"platform": str(random.randint(1, 8)), "platform": str(random.randint(1, 8)),
"price": round(30 + random.random() * 50, 2) "price": round(30 + random.random() * 50, 2),
} }
journeys.append(journey) journeys.append(journey)
@@ -102,7 +121,7 @@ class TrainServer(BaseHTTPRequestHandler):
adj_hour = adj_hour % 24 adj_hour = adj_hour % 24
duration = 60 + random.randint(0, 60) duration = 60 + random.randint(0, 60)
arr_hour = (adj_hour + (duration // 60)) arr_hour = adj_hour + (duration // 60)
arr_minute = (adj_minute + (duration % 60)) % 60 arr_minute = (adj_minute + (duration % 60)) % 60
arr_day = adj_day + (arr_hour // 24) arr_day = adj_day + (arr_hour // 24)
arr_hour = arr_hour % 24 arr_hour = arr_hour % 24
@@ -112,10 +131,14 @@ class TrainServer(BaseHTTPRequestHandler):
"type": "return", "type": "return",
"departure": destination, "departure": destination,
"arrival": origin, "arrival": origin,
"departure_time": format_datetime(year, month, adj_day, adj_hour, adj_minute), "departure_time": format_datetime(
"arrival_time": format_datetime(year, month, arr_day, arr_hour, arr_minute), year, month, adj_day, adj_hour, adj_minute
),
"arrival_time": format_datetime(
year, month, arr_day, arr_hour, arr_minute
),
"platform": str(random.randint(1, 8)), "platform": str(random.randint(1, 8)),
"price": round(30 + random.random() * 50, 2) "price": round(30 + random.random() * 50, 2),
} }
journeys.append(journey) journeys.append(journey)
@@ -124,49 +147,57 @@ class TrainServer(BaseHTTPRequestHandler):
def do_GET(self): def do_GET(self):
parsed_url = urlparse(self.path) parsed_url = urlparse(self.path)
if parsed_url.path == '/api/journeys': if parsed_url.path == "/api/journeys":
try: try:
params = parse_qs(parsed_url.query) params = parse_qs(parsed_url.query)
origin = params.get('from', [''])[0] origin = params.get("from", [""])[0]
destination = params.get('to', [''])[0] destination = params.get("to", [""])[0]
outbound_datetime = params.get('outbound_time', [''])[0] outbound_datetime = params.get("outbound_time", [""])[0]
return_datetime = params.get('return_time', [''])[0] return_datetime = params.get("return_time", [""])[0]
if not origin or not destination or not outbound_datetime: if not origin or not destination or not outbound_datetime:
self.send_response(400) self.send_response(400)
self.send_header('Content-Type', 'application/json') self.send_header("Content-Type", "application/json")
self.end_headers() self.end_headers()
self.write_response(self.format_json({ self.write_response(
"error": "Required parameters: 'from', 'to', and 'outbound_time'" self.format_json(
})) {
"error": "Required parameters: 'from', 'to', and 'outbound_time'"
}
)
)
return return
# Parse datetimes # Parse datetimes
out_dt = parse_datetime(outbound_datetime) out_dt = parse_datetime(outbound_datetime)
ret_dt = parse_datetime(return_datetime) if return_datetime else ( ret_dt = (
None, None, None, None, None) parse_datetime(return_datetime)
if return_datetime
else (None, None, None, None, None)
)
if out_dt[0] is None: if out_dt[0] is None:
self.send_response(400) self.send_response(400)
self.send_header('Content-Type', 'application/json') self.send_header("Content-Type", "application/json")
self.end_headers() self.end_headers()
self.write_response(self.format_json({ self.write_response(
"error": "Invalid datetime format. Use YYYY-MM-DDTHH:MM" self.format_json(
})) {"error": "Invalid datetime format. Use YYYY-MM-DDTHH:MM"}
)
)
return return
self.send_response(200) self.send_response(200)
self.send_header('Content-Type', 'application/json') self.send_header("Content-Type", "application/json")
self.end_headers() self.end_headers()
journeys = self.generate_journeys( journeys = self.generate_journeys(origin, destination, out_dt, ret_dt)
origin, destination, out_dt, ret_dt)
response = self.format_json({"journeys": journeys}) response = self.format_json({"journeys": journeys})
self.write_response(response) self.write_response(response)
except Exception as e: except Exception as e:
self.send_response(500) self.send_response(500)
self.send_header('Content-Type', 'application/json') self.send_header("Content-Type", "application/json")
self.end_headers() self.end_headers()
self.write_response(self.format_json({"error": str(e)})) self.write_response(self.format_json({"error": str(e)}))
else: else:
@@ -176,20 +207,23 @@ class TrainServer(BaseHTTPRequestHandler):
def do_POST(self): def do_POST(self):
parsed_url = urlparse(self.path) parsed_url = urlparse(self.path)
if parsed_url.path.startswith('/api/book/'): if parsed_url.path.startswith("/api/book/"):
journey_id = parsed_url.path.split('/')[-1] journey_id = parsed_url.path.split("/")[-1]
self.send_response(200) self.send_response(200)
self.send_header('Content-Type', 'application/json') self.send_header("Content-Type", "application/json")
self.end_headers() self.end_headers()
booking_ref = "BR" + \ booking_ref = "BR" + "".join(
"".join([random.choice(string.digits) for _ in range(5)]) [random.choice(string.digits) for _ in range(5)]
)
response = self.format_json({ response = self.format_json(
"booking_reference": booking_ref, {
"journey_id": journey_id, "booking_reference": booking_ref,
"status": "confirmed" "journey_id": journey_id,
}) "status": "confirmed",
}
)
self.write_response(response) self.write_response(response)
else: else:
@@ -198,10 +232,10 @@ class TrainServer(BaseHTTPRequestHandler):
def run_server(): def run_server():
server = HTTPServer(('', 8080), TrainServer) server = HTTPServer(("", 8080), TrainServer)
print("Train booking server starting on port 8080...") print("Train booking server starting on port 8080...")
server.serve_forever() server.serve_forever()
if __name__ == '__main__': if __name__ == "__main__":
run_server() run_server()

View File

@@ -43,7 +43,7 @@ def create_invoice(args: dict) -> dict:
customer=customer_id, customer=customer_id,
amount=amount_cents, amount=amount_cents,
currency="usd", currency="usd",
description=args.get("flightDetails", "Service Invoice"), description=args.get("tripDetails", "Service Invoice"),
) )
# Create and finalize the invoice # Create and finalize the invoice

View File

@@ -7,8 +7,49 @@ from tools.tool_registry import (
create_invoice_tool, create_invoice_tool,
) )
goal_match_train_invoice = AgentGoal(
tools=[
search_fixtures_tool,
search_trains_tool,
book_train_tool,
create_invoice_tool,
],
description="Help the user book trains to a premier league match. The user lives in London. Gather args for these tools in order: "
"1. SearchFixtures: Search for fixtures for a team in a given month"
"2. SearchTrains: Search for trains to visit somewhere before or after the match"
"3. BookTrain: Book the train tickets"
"4. CreateInvoice: Create a simple invoice for the cost of the flights and train tickets",
starter_prompt="Welcome me, give me a description of what you can do, then ask me for the details you need to do your job",
example_conversation_history="\n ".join(
[
"user: I'd like to travel to a premier league match",
"agent: Sure! Let's start by finding an match you'd like to attend. I know about Premier League fixtures in the UK. Could you tell me which team and month you're interested in?",
"user: Wolves in May please",
"agent: Great! Let's find a match for Wolverhampton Wanderers FC in May.",
"user_confirmed_tool_run: <user clicks confirm on SearchFixtures tool, passing the full team name as an input>",
'tool_result: results including {"homeTeam": "Wolverhampton Wanderers FC", "awayTeam": "Manchester United", "date": "2025-05-04"}',
"agent: Found a match! There's an away game against Manchester United on May 4 2025. Would you like to plan train travel from London for around this date?",
"user_confirmed_tool_run: <user clicks confirm on SearchTrains tool>",
"tool_result: results including train dates and times, origin and depature stations",
"agent: Found some trains! The best option is leaving <origin> on <date> <time> and arriving in <destination> at <date> <time>. The return trip is leaving <origin> on <date> <time> and arriving in <destination> at <date> <time>. Would you like to book this train?",
"user_confirmed_tool_run: <user clicks confirm on BookTrain tool>",
'tool_result: results including {"status": "success"}',
"agent: Train tickets booked! Now let's create an invoice for your train tickets",
"user_confirmed_tool_run: <user clicks confirm on CreateInvoice tool which includes details of the train journey, the match, and the total cost>",
"tool_result: contains an invoiceURL",
"agent: Great! I've generated your invoice for your trains to the <match>. You can view and pay your invoice at this link: https://invoice.stripe.com/i/acct_1NBOLuKVZbzw7QA5/test_YWNjdF8xTkJPTHVLVlpienc3UUE1LF9SaHlBTU9GYnFibEJ4VlpNaThkWkhrcUR6a1dwTmNULDEyOTE2MjkwNA0200CCUNvTox?s=ap",
]
),
)
# unused
goal_event_flight_invoice = AgentGoal( goal_event_flight_invoice = AgentGoal(
tools=[search_fixtures_tool, search_flights_tool, search_trains_tool, create_invoice_tool], tools=[
search_fixtures_tool,
search_flights_tool,
search_trains_tool,
create_invoice_tool,
],
description="Help the user gather args for these tools in order: " description="Help the user gather args for these tools in order: "
"1. SearchFixtures: Search for fixtures for a team in a given month" "1. SearchFixtures: Search for fixtures for a team in a given month"
"2. SearchFlights: Search for a flight around the match dates" "2. SearchFlights: Search for a flight around the match dates"

View File

@@ -76,7 +76,7 @@ create_invoice_tool = ToolDefinition(
description="The total cost to be invoiced", description="The total cost to be invoiced",
), ),
ToolArgument( ToolArgument(
name="flightDetails", name="tripDetails",
type="string", type="string",
description="A description of the item details to be invoiced", description="A description of the item details to be invoiced",
), ),