mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 05:58:08 +01:00
prompt engineering, train_api date parsing changes
This commit is contained in:
23
api/main.py
23
api/main.py
@@ -6,17 +6,20 @@ from temporalio.api.enums.v1 import WorkflowExecutionStatus
|
||||
|
||||
from workflows.tool_workflow import ToolWorkflow
|
||||
from models.data_types import CombinedInput, ToolWorkflowParams
|
||||
from tools.goal_registry import goal_event_flight_invoice
|
||||
from tools.goal_registry import goal_match_train_invoice
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from shared.config import get_temporal_client, TEMPORAL_TASK_QUEUE
|
||||
|
||||
app = FastAPI()
|
||||
temporal_client: Optional[Client] = None
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
global temporal_client
|
||||
temporal_client = await get_temporal_client()
|
||||
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["http://localhost:5173"],
|
||||
@@ -62,13 +65,13 @@ async def get_conversation_history():
|
||||
status_names = {
|
||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_TERMINATED: "WORKFLOW_EXECUTION_STATUS_TERMINATED",
|
||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_CANCELED: "WORKFLOW_EXECUTION_STATUS_CANCELED",
|
||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED: "WORKFLOW_EXECUTION_STATUS_FAILED"
|
||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED: "WORKFLOW_EXECUTION_STATUS_FAILED",
|
||||
}
|
||||
|
||||
failed_states = [
|
||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_TERMINATED,
|
||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_CANCELED,
|
||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED
|
||||
WorkflowExecutionStatus.WORKFLOW_EXECUTION_STATUS_FAILED,
|
||||
]
|
||||
|
||||
# Check workflow status first
|
||||
@@ -77,11 +80,11 @@ async def get_conversation_history():
|
||||
status_name = status_names.get(description.status, "UNKNOWN_STATUS")
|
||||
print(f"Workflow is in {status_name} state. Returning empty history.")
|
||||
return []
|
||||
|
||||
|
||||
# Only query if workflow is running
|
||||
conversation_history = await handle.query("get_conversation_history")
|
||||
return conversation_history
|
||||
|
||||
|
||||
except TemporalError as e:
|
||||
print(f"Temporal error: {e}")
|
||||
return []
|
||||
@@ -92,7 +95,7 @@ async def send_prompt(prompt: str):
|
||||
# Create combined input
|
||||
combined_input = CombinedInput(
|
||||
tool_params=ToolWorkflowParams(None, None),
|
||||
agent_goal=goal_event_flight_invoice,
|
||||
agent_goal=goal_match_train_invoice,
|
||||
)
|
||||
|
||||
workflow_id = "agent-workflow"
|
||||
@@ -139,7 +142,7 @@ async def start_workflow():
|
||||
# Create combined input
|
||||
combined_input = CombinedInput(
|
||||
tool_params=ToolWorkflowParams(None, None),
|
||||
agent_goal=goal_event_flight_invoice,
|
||||
agent_goal=goal_match_train_invoice,
|
||||
)
|
||||
|
||||
workflow_id = "agent-workflow"
|
||||
@@ -151,7 +154,9 @@ async def start_workflow():
|
||||
id=workflow_id,
|
||||
task_queue=TEMPORAL_TASK_QUEUE,
|
||||
start_signal="user_prompt",
|
||||
start_signal_args=["### " + goal_event_flight_invoice.starter_prompt],
|
||||
start_signal_args=["### " + goal_match_train_invoice.starter_prompt],
|
||||
)
|
||||
|
||||
return {"message": f"Workflow started with goal's starter prompt: {goal_event_flight_invoice.starter_prompt}."}
|
||||
return {
|
||||
"message": f"Workflow started with goal's starter prompt: {goal_match_train_invoice.starter_prompt}."
|
||||
}
|
||||
|
||||
130
thirdparty/train_api.py
vendored
130
thirdparty/train_api.py
vendored
@@ -15,14 +15,29 @@ import string
|
||||
|
||||
|
||||
def parse_datetime(datetime_str):
|
||||
# Parse YYYY-MM-DDTHH:MM format
|
||||
try:
|
||||
date_part, time_part = datetime_str.split('T')
|
||||
year, month, day = map(int, date_part.split('-'))
|
||||
hour, minute = map(int, time_part.split(':'))
|
||||
return year, month, day, hour, minute
|
||||
except:
|
||||
return None, None, None, None, None
|
||||
# Remove trailing 'Z' if present
|
||||
if datetime_str.endswith("Z"):
|
||||
datetime_str = datetime_str[:-1]
|
||||
|
||||
formats = [
|
||||
"%Y-%m-%dT%H:%M", # e.g. "2025-04-18T09:00"
|
||||
"%Y-%m-%dT%H:%M:%S", # e.g. "2025-04-18T09:00:00"
|
||||
"%Y-%m-%d %H:%M:%S", # e.g. "2025-04-18 09:00:00"
|
||||
]
|
||||
|
||||
for fmt in formats:
|
||||
try:
|
||||
parsed = time.strptime(datetime_str, fmt)
|
||||
return (
|
||||
parsed.tm_year,
|
||||
parsed.tm_mon,
|
||||
parsed.tm_mday,
|
||||
parsed.tm_hour,
|
||||
parsed.tm_min,
|
||||
)
|
||||
except ValueError:
|
||||
continue
|
||||
return None, None, None, None, None
|
||||
|
||||
|
||||
class TrainServer(BaseHTTPRequestHandler):
|
||||
@@ -47,7 +62,7 @@ class TrainServer(BaseHTTPRequestHandler):
|
||||
def write_response(self, response):
|
||||
try:
|
||||
# Python 3
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
self.wfile.write(response.encode("utf-8"))
|
||||
except AttributeError:
|
||||
# Python 1.5.2
|
||||
self.wfile.write(response)
|
||||
@@ -73,7 +88,7 @@ class TrainServer(BaseHTTPRequestHandler):
|
||||
|
||||
# Journey takes 1-2 hours
|
||||
duration = 60 + random.randint(0, 60)
|
||||
arr_hour = (adj_hour + (duration // 60))
|
||||
arr_hour = adj_hour + (duration // 60)
|
||||
arr_minute = (adj_minute + (duration % 60)) % 60
|
||||
arr_day = adj_day + (arr_hour // 24)
|
||||
arr_hour = arr_hour % 24
|
||||
@@ -83,10 +98,14 @@ class TrainServer(BaseHTTPRequestHandler):
|
||||
"type": "outbound",
|
||||
"departure": origin,
|
||||
"arrival": destination,
|
||||
"departure_time": format_datetime(year, month, adj_day, adj_hour, adj_minute),
|
||||
"arrival_time": format_datetime(year, month, arr_day, arr_hour, arr_minute),
|
||||
"departure_time": format_datetime(
|
||||
year, month, adj_day, adj_hour, adj_minute
|
||||
),
|
||||
"arrival_time": format_datetime(
|
||||
year, month, arr_day, arr_hour, arr_minute
|
||||
),
|
||||
"platform": str(random.randint(1, 8)),
|
||||
"price": round(30 + random.random() * 50, 2)
|
||||
"price": round(30 + random.random() * 50, 2),
|
||||
}
|
||||
journeys.append(journey)
|
||||
|
||||
@@ -102,7 +121,7 @@ class TrainServer(BaseHTTPRequestHandler):
|
||||
adj_hour = adj_hour % 24
|
||||
|
||||
duration = 60 + random.randint(0, 60)
|
||||
arr_hour = (adj_hour + (duration // 60))
|
||||
arr_hour = adj_hour + (duration // 60)
|
||||
arr_minute = (adj_minute + (duration % 60)) % 60
|
||||
arr_day = adj_day + (arr_hour // 24)
|
||||
arr_hour = arr_hour % 24
|
||||
@@ -112,10 +131,14 @@ class TrainServer(BaseHTTPRequestHandler):
|
||||
"type": "return",
|
||||
"departure": destination,
|
||||
"arrival": origin,
|
||||
"departure_time": format_datetime(year, month, adj_day, adj_hour, adj_minute),
|
||||
"arrival_time": format_datetime(year, month, arr_day, arr_hour, arr_minute),
|
||||
"departure_time": format_datetime(
|
||||
year, month, adj_day, adj_hour, adj_minute
|
||||
),
|
||||
"arrival_time": format_datetime(
|
||||
year, month, arr_day, arr_hour, arr_minute
|
||||
),
|
||||
"platform": str(random.randint(1, 8)),
|
||||
"price": round(30 + random.random() * 50, 2)
|
||||
"price": round(30 + random.random() * 50, 2),
|
||||
}
|
||||
journeys.append(journey)
|
||||
|
||||
@@ -124,49 +147,57 @@ class TrainServer(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
parsed_url = urlparse(self.path)
|
||||
|
||||
if parsed_url.path == '/api/journeys':
|
||||
if parsed_url.path == "/api/journeys":
|
||||
try:
|
||||
params = parse_qs(parsed_url.query)
|
||||
origin = params.get('from', [''])[0]
|
||||
destination = params.get('to', [''])[0]
|
||||
outbound_datetime = params.get('outbound_time', [''])[0]
|
||||
return_datetime = params.get('return_time', [''])[0]
|
||||
origin = params.get("from", [""])[0]
|
||||
destination = params.get("to", [""])[0]
|
||||
outbound_datetime = params.get("outbound_time", [""])[0]
|
||||
return_datetime = params.get("return_time", [""])[0]
|
||||
|
||||
if not origin or not destination or not outbound_datetime:
|
||||
self.send_response(400)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.send_header("Content-Type", "application/json")
|
||||
self.end_headers()
|
||||
self.write_response(self.format_json({
|
||||
"error": "Required parameters: 'from', 'to', and 'outbound_time'"
|
||||
}))
|
||||
self.write_response(
|
||||
self.format_json(
|
||||
{
|
||||
"error": "Required parameters: 'from', 'to', and 'outbound_time'"
|
||||
}
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Parse datetimes
|
||||
out_dt = parse_datetime(outbound_datetime)
|
||||
ret_dt = parse_datetime(return_datetime) if return_datetime else (
|
||||
None, None, None, None, None)
|
||||
ret_dt = (
|
||||
parse_datetime(return_datetime)
|
||||
if return_datetime
|
||||
else (None, None, None, None, None)
|
||||
)
|
||||
|
||||
if out_dt[0] is None:
|
||||
self.send_response(400)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.send_header("Content-Type", "application/json")
|
||||
self.end_headers()
|
||||
self.write_response(self.format_json({
|
||||
"error": "Invalid datetime format. Use YYYY-MM-DDTHH:MM"
|
||||
}))
|
||||
self.write_response(
|
||||
self.format_json(
|
||||
{"error": "Invalid datetime format. Use YYYY-MM-DDTHH:MM"}
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.send_header("Content-Type", "application/json")
|
||||
self.end_headers()
|
||||
|
||||
journeys = self.generate_journeys(
|
||||
origin, destination, out_dt, ret_dt)
|
||||
journeys = self.generate_journeys(origin, destination, out_dt, ret_dt)
|
||||
response = self.format_json({"journeys": journeys})
|
||||
self.write_response(response)
|
||||
|
||||
except Exception as e:
|
||||
self.send_response(500)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.send_header("Content-Type", "application/json")
|
||||
self.end_headers()
|
||||
self.write_response(self.format_json({"error": str(e)}))
|
||||
else:
|
||||
@@ -176,20 +207,23 @@ class TrainServer(BaseHTTPRequestHandler):
|
||||
def do_POST(self):
|
||||
parsed_url = urlparse(self.path)
|
||||
|
||||
if parsed_url.path.startswith('/api/book/'):
|
||||
journey_id = parsed_url.path.split('/')[-1]
|
||||
if parsed_url.path.startswith("/api/book/"):
|
||||
journey_id = parsed_url.path.split("/")[-1]
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.send_header("Content-Type", "application/json")
|
||||
self.end_headers()
|
||||
|
||||
booking_ref = "BR" + \
|
||||
"".join([random.choice(string.digits) for _ in range(5)])
|
||||
booking_ref = "BR" + "".join(
|
||||
[random.choice(string.digits) for _ in range(5)]
|
||||
)
|
||||
|
||||
response = self.format_json({
|
||||
"booking_reference": booking_ref,
|
||||
"journey_id": journey_id,
|
||||
"status": "confirmed"
|
||||
})
|
||||
response = self.format_json(
|
||||
{
|
||||
"booking_reference": booking_ref,
|
||||
"journey_id": journey_id,
|
||||
"status": "confirmed",
|
||||
}
|
||||
)
|
||||
self.write_response(response)
|
||||
|
||||
else:
|
||||
@@ -198,10 +232,10 @@ class TrainServer(BaseHTTPRequestHandler):
|
||||
|
||||
|
||||
def run_server():
|
||||
server = HTTPServer(('', 8080), TrainServer)
|
||||
server = HTTPServer(("", 8080), TrainServer)
|
||||
print("Train booking server starting on port 8080...")
|
||||
server.serve_forever()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
run_server()
|
||||
|
||||
@@ -43,7 +43,7 @@ def create_invoice(args: dict) -> dict:
|
||||
customer=customer_id,
|
||||
amount=amount_cents,
|
||||
currency="usd",
|
||||
description=args.get("flightDetails", "Service Invoice"),
|
||||
description=args.get("tripDetails", "Service Invoice"),
|
||||
)
|
||||
|
||||
# Create and finalize the invoice
|
||||
|
||||
@@ -7,8 +7,49 @@ from tools.tool_registry import (
|
||||
create_invoice_tool,
|
||||
)
|
||||
|
||||
goal_match_train_invoice = AgentGoal(
|
||||
tools=[
|
||||
search_fixtures_tool,
|
||||
search_trains_tool,
|
||||
book_train_tool,
|
||||
create_invoice_tool,
|
||||
],
|
||||
description="Help the user book trains to a premier league match. The user lives in London. Gather args for these tools in order: "
|
||||
"1. SearchFixtures: Search for fixtures for a team in a given month"
|
||||
"2. SearchTrains: Search for trains to visit somewhere before or after the match"
|
||||
"3. BookTrain: Book the train tickets"
|
||||
"4. CreateInvoice: Create a simple invoice for the cost of the flights and train tickets",
|
||||
starter_prompt="Welcome me, give me a description of what you can do, then ask me for the details you need to do your job",
|
||||
example_conversation_history="\n ".join(
|
||||
[
|
||||
"user: I'd like to travel to a premier league match",
|
||||
"agent: Sure! Let's start by finding an match you'd like to attend. I know about Premier League fixtures in the UK. Could you tell me which team and month you're interested in?",
|
||||
"user: Wolves in May please",
|
||||
"agent: Great! Let's find a match for Wolverhampton Wanderers FC in May.",
|
||||
"user_confirmed_tool_run: <user clicks confirm on SearchFixtures tool, passing the full team name as an input>",
|
||||
'tool_result: results including {"homeTeam": "Wolverhampton Wanderers FC", "awayTeam": "Manchester United", "date": "2025-05-04"}',
|
||||
"agent: Found a match! There's an away game against Manchester United on May 4 2025. Would you like to plan train travel from London for around this date?",
|
||||
"user_confirmed_tool_run: <user clicks confirm on SearchTrains tool>",
|
||||
"tool_result: results including train dates and times, origin and depature stations",
|
||||
"agent: Found some trains! The best option is leaving <origin> on <date> <time> and arriving in <destination> at <date> <time>. The return trip is leaving <origin> on <date> <time> and arriving in <destination> at <date> <time>. Would you like to book this train?",
|
||||
"user_confirmed_tool_run: <user clicks confirm on BookTrain tool>",
|
||||
'tool_result: results including {"status": "success"}',
|
||||
"agent: Train tickets booked! Now let's create an invoice for your train tickets",
|
||||
"user_confirmed_tool_run: <user clicks confirm on CreateInvoice tool which includes details of the train journey, the match, and the total cost>",
|
||||
"tool_result: contains an invoiceURL",
|
||||
"agent: Great! I've generated your invoice for your trains to the <match>. You can view and pay your invoice at this link: https://invoice.stripe.com/i/acct_1NBOLuKVZbzw7QA5/test_YWNjdF8xTkJPTHVLVlpienc3UUE1LF9SaHlBTU9GYnFibEJ4VlpNaThkWkhrcUR6a1dwTmNULDEyOTE2MjkwNA0200CCUNvTox?s=ap",
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
# unused
|
||||
goal_event_flight_invoice = AgentGoal(
|
||||
tools=[search_fixtures_tool, search_flights_tool, search_trains_tool, create_invoice_tool],
|
||||
tools=[
|
||||
search_fixtures_tool,
|
||||
search_flights_tool,
|
||||
search_trains_tool,
|
||||
create_invoice_tool,
|
||||
],
|
||||
description="Help the user gather args for these tools in order: "
|
||||
"1. SearchFixtures: Search for fixtures for a team in a given month"
|
||||
"2. SearchFlights: Search for a flight around the match dates"
|
||||
|
||||
@@ -76,7 +76,7 @@ create_invoice_tool = ToolDefinition(
|
||||
description="The total cost to be invoiced",
|
||||
),
|
||||
ToolArgument(
|
||||
name="flightDetails",
|
||||
name="tripDetails",
|
||||
type="string",
|
||||
description="A description of the item details to be invoiced",
|
||||
),
|
||||
@@ -98,4 +98,4 @@ search_fixtures_tool = ToolDefinition(
|
||||
description="The month to search for fixtures.",
|
||||
),
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user