mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 14:08:08 +01:00
prompt engineering, train_api date parsing changes
This commit is contained in:
19
api/main.py
19
api/main.py
@@ -6,17 +6,20 @@ from temporalio.api.enums.v1 import WorkflowExecutionStatus
|
|||||||
|
|
||||||
from workflows.tool_workflow import ToolWorkflow
|
from workflows.tool_workflow import ToolWorkflow
|
||||||
from models.data_types import CombinedInput, ToolWorkflowParams
|
from models.data_types import CombinedInput, ToolWorkflowParams
|
||||||
from tools.goal_registry import goal_event_flight_invoice
|
from tools.goal_registry import goal_match_train_invoice
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from shared.config import get_temporal_client, TEMPORAL_TASK_QUEUE
|
from shared.config import get_temporal_client, TEMPORAL_TASK_QUEUE
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
temporal_client: Optional[Client] = None
|
temporal_client: Optional[Client] = None
|
||||||
|
|
||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def startup_event():
|
async def startup_event():
|
||||||
global temporal_client
|
global temporal_client
|
||||||
temporal_client = await get_temporal_client()
|
temporal_client = await get_temporal_client()
|
||||||
|
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=["http://localhost:5173"],
|
allow_origins=["http://localhost:5173"],
|
||||||
@@ -62,13 +65,13 @@ async def get_conversation_history():
|
|||||||
status_names = {
|
status_names = {
|
||||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_TERMINATED: "WORKFLOW_EXECUTION_STATUS_TERMINATED",
|
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_TERMINATED: "WORKFLOW_EXECUTION_STATUS_TERMINATED",
|
||||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_CANCELED: "WORKFLOW_EXECUTION_STATUS_CANCELED",
|
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_CANCELED: "WORKFLOW_EXECUTION_STATUS_CANCELED",
|
||||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED: "WORKFLOW_EXECUTION_STATUS_FAILED"
|
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED: "WORKFLOW_EXECUTION_STATUS_FAILED",
|
||||||
}
|
}
|
||||||
|
|
||||||
failed_states = [
|
failed_states = [
|
||||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_TERMINATED,
|
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_TERMINATED,
|
||||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_CANCELED,
|
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_CANCELED,
|
||||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED
|
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED,
|
||||||
]
|
]
|
||||||
|
|
||||||
# Check workflow status first
|
# Check workflow status first
|
||||||
@@ -92,7 +95,7 @@ async def send_prompt(prompt: str):
|
|||||||
# Create combined input
|
# Create combined input
|
||||||
combined_input = CombinedInput(
|
combined_input = CombinedInput(
|
||||||
tool_params=ToolWorkflowParams(None, None),
|
tool_params=ToolWorkflowParams(None, None),
|
||||||
agent_goal=goal_event_flight_invoice,
|
agent_goal=goal_match_train_invoice,
|
||||||
)
|
)
|
||||||
|
|
||||||
workflow_id = "agent-workflow"
|
workflow_id = "agent-workflow"
|
||||||
@@ -139,7 +142,7 @@ async def start_workflow():
|
|||||||
# Create combined input
|
# Create combined input
|
||||||
combined_input = CombinedInput(
|
combined_input = CombinedInput(
|
||||||
tool_params=ToolWorkflowParams(None, None),
|
tool_params=ToolWorkflowParams(None, None),
|
||||||
agent_goal=goal_event_flight_invoice,
|
agent_goal=goal_match_train_invoice,
|
||||||
)
|
)
|
||||||
|
|
||||||
workflow_id = "agent-workflow"
|
workflow_id = "agent-workflow"
|
||||||
@@ -151,7 +154,9 @@ async def start_workflow():
|
|||||||
id=workflow_id,
|
id=workflow_id,
|
||||||
task_queue=TEMPORAL_TASK_QUEUE,
|
task_queue=TEMPORAL_TASK_QUEUE,
|
||||||
start_signal="user_prompt",
|
start_signal="user_prompt",
|
||||||
start_signal_args=["### " + goal_event_flight_invoice.starter_prompt],
|
start_signal_args=["### " + goal_match_train_invoice.starter_prompt],
|
||||||
)
|
)
|
||||||
|
|
||||||
return {"message": f"Workflow started with goal's starter prompt: {goal_event_flight_invoice.starter_prompt}."}
|
return {
|
||||||
|
"message": f"Workflow started with goal's starter prompt: {goal_match_train_invoice.starter_prompt}."
|
||||||
|
}
|
||||||
|
|||||||
120
thirdparty/train_api.py
vendored
120
thirdparty/train_api.py
vendored
@@ -15,13 +15,28 @@ import string
|
|||||||
|
|
||||||
|
|
||||||
def parse_datetime(datetime_str):
|
def parse_datetime(datetime_str):
|
||||||
# Parse YYYY-MM-DDTHH:MM format
|
# Remove trailing 'Z' if present
|
||||||
|
if datetime_str.endswith("Z"):
|
||||||
|
datetime_str = datetime_str[:-1]
|
||||||
|
|
||||||
|
formats = [
|
||||||
|
"%Y-%m-%dT%H:%M", # e.g. "2025-04-18T09:00"
|
||||||
|
"%Y-%m-%dT%H:%M:%S", # e.g. "2025-04-18T09:00:00"
|
||||||
|
"%Y-%m-%d %H:%M:%S", # e.g. "2025-04-18 09:00:00"
|
||||||
|
]
|
||||||
|
|
||||||
|
for fmt in formats:
|
||||||
try:
|
try:
|
||||||
date_part, time_part = datetime_str.split('T')
|
parsed = time.strptime(datetime_str, fmt)
|
||||||
year, month, day = map(int, date_part.split('-'))
|
return (
|
||||||
hour, minute = map(int, time_part.split(':'))
|
parsed.tm_year,
|
||||||
return year, month, day, hour, minute
|
parsed.tm_mon,
|
||||||
except:
|
parsed.tm_mday,
|
||||||
|
parsed.tm_hour,
|
||||||
|
parsed.tm_min,
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
return None, None, None, None, None
|
return None, None, None, None, None
|
||||||
|
|
||||||
|
|
||||||
@@ -47,7 +62,7 @@ class TrainServer(BaseHTTPRequestHandler):
|
|||||||
def write_response(self, response):
|
def write_response(self, response):
|
||||||
try:
|
try:
|
||||||
# Python 3
|
# Python 3
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.wfile.write(response.encode("utf-8"))
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
# Python 1.5.2
|
# Python 1.5.2
|
||||||
self.wfile.write(response)
|
self.wfile.write(response)
|
||||||
@@ -73,7 +88,7 @@ class TrainServer(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
# Journey takes 1-2 hours
|
# Journey takes 1-2 hours
|
||||||
duration = 60 + random.randint(0, 60)
|
duration = 60 + random.randint(0, 60)
|
||||||
arr_hour = (adj_hour + (duration // 60))
|
arr_hour = adj_hour + (duration // 60)
|
||||||
arr_minute = (adj_minute + (duration % 60)) % 60
|
arr_minute = (adj_minute + (duration % 60)) % 60
|
||||||
arr_day = adj_day + (arr_hour // 24)
|
arr_day = adj_day + (arr_hour // 24)
|
||||||
arr_hour = arr_hour % 24
|
arr_hour = arr_hour % 24
|
||||||
@@ -83,10 +98,14 @@ class TrainServer(BaseHTTPRequestHandler):
|
|||||||
"type": "outbound",
|
"type": "outbound",
|
||||||
"departure": origin,
|
"departure": origin,
|
||||||
"arrival": destination,
|
"arrival": destination,
|
||||||
"departure_time": format_datetime(year, month, adj_day, adj_hour, adj_minute),
|
"departure_time": format_datetime(
|
||||||
"arrival_time": format_datetime(year, month, arr_day, arr_hour, arr_minute),
|
year, month, adj_day, adj_hour, adj_minute
|
||||||
|
),
|
||||||
|
"arrival_time": format_datetime(
|
||||||
|
year, month, arr_day, arr_hour, arr_minute
|
||||||
|
),
|
||||||
"platform": str(random.randint(1, 8)),
|
"platform": str(random.randint(1, 8)),
|
||||||
"price": round(30 + random.random() * 50, 2)
|
"price": round(30 + random.random() * 50, 2),
|
||||||
}
|
}
|
||||||
journeys.append(journey)
|
journeys.append(journey)
|
||||||
|
|
||||||
@@ -102,7 +121,7 @@ class TrainServer(BaseHTTPRequestHandler):
|
|||||||
adj_hour = adj_hour % 24
|
adj_hour = adj_hour % 24
|
||||||
|
|
||||||
duration = 60 + random.randint(0, 60)
|
duration = 60 + random.randint(0, 60)
|
||||||
arr_hour = (adj_hour + (duration // 60))
|
arr_hour = adj_hour + (duration // 60)
|
||||||
arr_minute = (adj_minute + (duration % 60)) % 60
|
arr_minute = (adj_minute + (duration % 60)) % 60
|
||||||
arr_day = adj_day + (arr_hour // 24)
|
arr_day = adj_day + (arr_hour // 24)
|
||||||
arr_hour = arr_hour % 24
|
arr_hour = arr_hour % 24
|
||||||
@@ -112,10 +131,14 @@ class TrainServer(BaseHTTPRequestHandler):
|
|||||||
"type": "return",
|
"type": "return",
|
||||||
"departure": destination,
|
"departure": destination,
|
||||||
"arrival": origin,
|
"arrival": origin,
|
||||||
"departure_time": format_datetime(year, month, adj_day, adj_hour, adj_minute),
|
"departure_time": format_datetime(
|
||||||
"arrival_time": format_datetime(year, month, arr_day, arr_hour, arr_minute),
|
year, month, adj_day, adj_hour, adj_minute
|
||||||
|
),
|
||||||
|
"arrival_time": format_datetime(
|
||||||
|
year, month, arr_day, arr_hour, arr_minute
|
||||||
|
),
|
||||||
"platform": str(random.randint(1, 8)),
|
"platform": str(random.randint(1, 8)),
|
||||||
"price": round(30 + random.random() * 50, 2)
|
"price": round(30 + random.random() * 50, 2),
|
||||||
}
|
}
|
||||||
journeys.append(journey)
|
journeys.append(journey)
|
||||||
|
|
||||||
@@ -124,49 +147,57 @@ class TrainServer(BaseHTTPRequestHandler):
|
|||||||
def do_GET(self):
|
def do_GET(self):
|
||||||
parsed_url = urlparse(self.path)
|
parsed_url = urlparse(self.path)
|
||||||
|
|
||||||
if parsed_url.path == '/api/journeys':
|
if parsed_url.path == "/api/journeys":
|
||||||
try:
|
try:
|
||||||
params = parse_qs(parsed_url.query)
|
params = parse_qs(parsed_url.query)
|
||||||
origin = params.get('from', [''])[0]
|
origin = params.get("from", [""])[0]
|
||||||
destination = params.get('to', [''])[0]
|
destination = params.get("to", [""])[0]
|
||||||
outbound_datetime = params.get('outbound_time', [''])[0]
|
outbound_datetime = params.get("outbound_time", [""])[0]
|
||||||
return_datetime = params.get('return_time', [''])[0]
|
return_datetime = params.get("return_time", [""])[0]
|
||||||
|
|
||||||
if not origin or not destination or not outbound_datetime:
|
if not origin or not destination or not outbound_datetime:
|
||||||
self.send_response(400)
|
self.send_response(400)
|
||||||
self.send_header('Content-Type', 'application/json')
|
self.send_header("Content-Type", "application/json")
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
self.write_response(self.format_json({
|
self.write_response(
|
||||||
|
self.format_json(
|
||||||
|
{
|
||||||
"error": "Required parameters: 'from', 'to', and 'outbound_time'"
|
"error": "Required parameters: 'from', 'to', and 'outbound_time'"
|
||||||
}))
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Parse datetimes
|
# Parse datetimes
|
||||||
out_dt = parse_datetime(outbound_datetime)
|
out_dt = parse_datetime(outbound_datetime)
|
||||||
ret_dt = parse_datetime(return_datetime) if return_datetime else (
|
ret_dt = (
|
||||||
None, None, None, None, None)
|
parse_datetime(return_datetime)
|
||||||
|
if return_datetime
|
||||||
|
else (None, None, None, None, None)
|
||||||
|
)
|
||||||
|
|
||||||
if out_dt[0] is None:
|
if out_dt[0] is None:
|
||||||
self.send_response(400)
|
self.send_response(400)
|
||||||
self.send_header('Content-Type', 'application/json')
|
self.send_header("Content-Type", "application/json")
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
self.write_response(self.format_json({
|
self.write_response(
|
||||||
"error": "Invalid datetime format. Use YYYY-MM-DDTHH:MM"
|
self.format_json(
|
||||||
}))
|
{"error": "Invalid datetime format. Use YYYY-MM-DDTHH:MM"}
|
||||||
|
)
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.send_header('Content-Type', 'application/json')
|
self.send_header("Content-Type", "application/json")
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
|
|
||||||
journeys = self.generate_journeys(
|
journeys = self.generate_journeys(origin, destination, out_dt, ret_dt)
|
||||||
origin, destination, out_dt, ret_dt)
|
|
||||||
response = self.format_json({"journeys": journeys})
|
response = self.format_json({"journeys": journeys})
|
||||||
self.write_response(response)
|
self.write_response(response)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.send_response(500)
|
self.send_response(500)
|
||||||
self.send_header('Content-Type', 'application/json')
|
self.send_header("Content-Type", "application/json")
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
self.write_response(self.format_json({"error": str(e)}))
|
self.write_response(self.format_json({"error": str(e)}))
|
||||||
else:
|
else:
|
||||||
@@ -176,20 +207,23 @@ class TrainServer(BaseHTTPRequestHandler):
|
|||||||
def do_POST(self):
|
def do_POST(self):
|
||||||
parsed_url = urlparse(self.path)
|
parsed_url = urlparse(self.path)
|
||||||
|
|
||||||
if parsed_url.path.startswith('/api/book/'):
|
if parsed_url.path.startswith("/api/book/"):
|
||||||
journey_id = parsed_url.path.split('/')[-1]
|
journey_id = parsed_url.path.split("/")[-1]
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.send_header('Content-Type', 'application/json')
|
self.send_header("Content-Type", "application/json")
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
|
|
||||||
booking_ref = "BR" + \
|
booking_ref = "BR" + "".join(
|
||||||
"".join([random.choice(string.digits) for _ in range(5)])
|
[random.choice(string.digits) for _ in range(5)]
|
||||||
|
)
|
||||||
|
|
||||||
response = self.format_json({
|
response = self.format_json(
|
||||||
|
{
|
||||||
"booking_reference": booking_ref,
|
"booking_reference": booking_ref,
|
||||||
"journey_id": journey_id,
|
"journey_id": journey_id,
|
||||||
"status": "confirmed"
|
"status": "confirmed",
|
||||||
})
|
}
|
||||||
|
)
|
||||||
self.write_response(response)
|
self.write_response(response)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@@ -198,10 +232,10 @@ class TrainServer(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
|
|
||||||
def run_server():
|
def run_server():
|
||||||
server = HTTPServer(('', 8080), TrainServer)
|
server = HTTPServer(("", 8080), TrainServer)
|
||||||
print("Train booking server starting on port 8080...")
|
print("Train booking server starting on port 8080...")
|
||||||
server.serve_forever()
|
server.serve_forever()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
run_server()
|
run_server()
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ def create_invoice(args: dict) -> dict:
|
|||||||
customer=customer_id,
|
customer=customer_id,
|
||||||
amount=amount_cents,
|
amount=amount_cents,
|
||||||
currency="usd",
|
currency="usd",
|
||||||
description=args.get("flightDetails", "Service Invoice"),
|
description=args.get("tripDetails", "Service Invoice"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create and finalize the invoice
|
# Create and finalize the invoice
|
||||||
|
|||||||
@@ -7,8 +7,49 @@ from tools.tool_registry import (
|
|||||||
create_invoice_tool,
|
create_invoice_tool,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
goal_match_train_invoice = AgentGoal(
|
||||||
|
tools=[
|
||||||
|
search_fixtures_tool,
|
||||||
|
search_trains_tool,
|
||||||
|
book_train_tool,
|
||||||
|
create_invoice_tool,
|
||||||
|
],
|
||||||
|
description="Help the user book trains to a premier league match. The user lives in London. Gather args for these tools in order: "
|
||||||
|
"1. SearchFixtures: Search for fixtures for a team in a given month"
|
||||||
|
"2. SearchTrains: Search for trains to visit somewhere before or after the match"
|
||||||
|
"3. BookTrain: Book the train tickets"
|
||||||
|
"4. CreateInvoice: Create a simple invoice for the cost of the flights and train tickets",
|
||||||
|
starter_prompt="Welcome me, give me a description of what you can do, then ask me for the details you need to do your job",
|
||||||
|
example_conversation_history="\n ".join(
|
||||||
|
[
|
||||||
|
"user: I'd like to travel to a premier league match",
|
||||||
|
"agent: Sure! Let's start by finding an match you'd like to attend. I know about Premier League fixtures in the UK. Could you tell me which team and month you're interested in?",
|
||||||
|
"user: Wolves in May please",
|
||||||
|
"agent: Great! Let's find a match for Wolverhampton Wanderers FC in May.",
|
||||||
|
"user_confirmed_tool_run: <user clicks confirm on SearchFixtures tool, passing the full team name as an input>",
|
||||||
|
'tool_result: results including {"homeTeam": "Wolverhampton Wanderers FC", "awayTeam": "Manchester United", "date": "2025-05-04"}',
|
||||||
|
"agent: Found a match! There's an away game against Manchester United on May 4 2025. Would you like to plan train travel from London for around this date?",
|
||||||
|
"user_confirmed_tool_run: <user clicks confirm on SearchTrains tool>",
|
||||||
|
"tool_result: results including train dates and times, origin and depature stations",
|
||||||
|
"agent: Found some trains! The best option is leaving <origin> on <date> <time> and arriving in <destination> at <date> <time>. The return trip is leaving <origin> on <date> <time> and arriving in <destination> at <date> <time>. Would you like to book this train?",
|
||||||
|
"user_confirmed_tool_run: <user clicks confirm on BookTrain tool>",
|
||||||
|
'tool_result: results including {"status": "success"}',
|
||||||
|
"agent: Train tickets booked! Now let's create an invoice for your train tickets",
|
||||||
|
"user_confirmed_tool_run: <user clicks confirm on CreateInvoice tool which includes details of the train journey, the match, and the total cost>",
|
||||||
|
"tool_result: contains an invoiceURL",
|
||||||
|
"agent: Great! I've generated your invoice for your trains to the <match>. You can view and pay your invoice at this link: https://invoice.stripe.com/i/acct_1NBOLuKVZbzw7QA5/test_YWNjdF8xTkJPTHVLVlpienc3UUE1LF9SaHlBTU9GYnFibEJ4VlpNaThkWkhrcUR6a1dwTmNULDEyOTE2MjkwNA0200CCUNvTox?s=ap",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# unused
|
||||||
goal_event_flight_invoice = AgentGoal(
|
goal_event_flight_invoice = AgentGoal(
|
||||||
tools=[search_fixtures_tool, search_flights_tool, search_trains_tool, create_invoice_tool],
|
tools=[
|
||||||
|
search_fixtures_tool,
|
||||||
|
search_flights_tool,
|
||||||
|
search_trains_tool,
|
||||||
|
create_invoice_tool,
|
||||||
|
],
|
||||||
description="Help the user gather args for these tools in order: "
|
description="Help the user gather args for these tools in order: "
|
||||||
"1. SearchFixtures: Search for fixtures for a team in a given month"
|
"1. SearchFixtures: Search for fixtures for a team in a given month"
|
||||||
"2. SearchFlights: Search for a flight around the match dates"
|
"2. SearchFlights: Search for a flight around the match dates"
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ create_invoice_tool = ToolDefinition(
|
|||||||
description="The total cost to be invoiced",
|
description="The total cost to be invoiced",
|
||||||
),
|
),
|
||||||
ToolArgument(
|
ToolArgument(
|
||||||
name="flightDetails",
|
name="tripDetails",
|
||||||
type="string",
|
type="string",
|
||||||
description="A description of the item details to be invoiced",
|
description="A description of the item details to be invoiced",
|
||||||
),
|
),
|
||||||
|
|||||||
Reference in New Issue
Block a user