mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 05:58:08 +01:00
refactors and ui enhancements
This commit is contained in:
@@ -7,6 +7,7 @@ This demo shows a multi-turn conversation with an AI agent running inside a Temp
|
||||
## Setup
|
||||
* Requires an OpenAI key for the gpt-4o model. Set this in the `OPENAI_API_KEY` environment variable in .env
|
||||
* Requires a rapidapi key for sky-scrapper (how we find flights). Set this in the `RAPIDAPI_KEY` environment variable in .env
|
||||
* It's free to sign up and get a key at [RapidAPI](https://rapidapi.com/apiheya/api/sky-scrapper)
|
||||
* If you're lazy go to `tools/search_flights.py` and replace the `get_flights` function with the mock `search_flights_example` that exists in the same file.
|
||||
* See .env_example for the required environment variables.
|
||||
* Install and run Temporal. Follow the instructions in the [Temporal documentation](https://learn.temporal.io/getting_started/python/dev_environment/#set-up-a-local-temporal-service-for-development-with-temporal-cli) to install and run the Temporal server.
|
||||
|
||||
41
api/main.py
41
api/main.py
@@ -1,15 +1,10 @@
|
||||
from fastapi import FastAPI
|
||||
from temporalio.client import Client
|
||||
from workflows.tool_workflow import ToolWorkflow
|
||||
from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams
|
||||
from models.data_types import CombinedInput, ToolWorkflowParams
|
||||
from tools.goal_registry import goal_event_flight_invoice
|
||||
from temporalio.exceptions import TemporalError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from tools.tool_registry import (
|
||||
find_events_tool,
|
||||
search_flights_tool,
|
||||
create_invoice_tool,
|
||||
)
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@@ -68,40 +63,10 @@ async def get_conversation_history():
|
||||
async def send_prompt(prompt: str):
|
||||
client = await Client.connect("localhost:7233")
|
||||
|
||||
# Build the ToolsData
|
||||
tools_data = ToolsData(
|
||||
tools=[find_events_tool, search_flights_tool, create_invoice_tool],
|
||||
description="Help the user gather args for these tools in order: "
|
||||
"1. FindEvents: Find an event to travel to "
|
||||
"2. SearchFlights: search for a flight around the event dates "
|
||||
"3. GenerateInvoice: Create a simple invoice for the cost of that flight ",
|
||||
example_conversation_history="\n ".join(
|
||||
[
|
||||
"user: I'd like to travel to an event",
|
||||
"agent: Sure! Let's start by finding an event you'd like to attend. Could you tell me which city and month you're interested in?",
|
||||
"user: In Sao Paulo, Brazil, in February",
|
||||
"agent: Great! Let's find an events in Sao Paulo, Brazil in February.",
|
||||
"user_confirmed_tool_run: <user clicks confirm on FindEvents tool>",
|
||||
"tool_result: { 'event_name': 'Carnival', 'event_date': '2023-02-25' }",
|
||||
"agent: Found an event! There's Carnival on 2023-02-25, ending on 2023-02-28. Would you like to search for flights around these dates?",
|
||||
"user: Yes, please",
|
||||
"agent: Let's search for flights around these dates. Could you provide your departure city?",
|
||||
"user: New York",
|
||||
"agent: Thanks, searching for flights from New York to Sao Paulo around 2023-02-25 to 2023-02-28.",
|
||||
"user_confirmed_tool_run: <user clicks confirm on SearchFlights tool>"
|
||||
'tool_result: results including {"flight_number": "CX101", "return_flight_number": "CX102", "price": 850.0}',
|
||||
"agent: Found some flights! The cheapest is CX101 for $850. Would you like to generate an invoice for this flight?",
|
||||
"user_confirmed_tool_run: <user clicks confirm on CreateInvoice tool>",
|
||||
'tool_result: { "status": "success", "invoice": { "flight_number": "CX101", "amount": 850.0 }, invoiceURL: "https://example.com/invoice" }',
|
||||
"agent: Invoice generated! Here's the link: https://example.com/invoice",
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
# Create combined input
|
||||
combined_input = CombinedInput(
|
||||
tool_params=ToolWorkflowParams(None, None),
|
||||
tools_data=tools_data,
|
||||
agent_goal=goal_event_flight_invoice,
|
||||
)
|
||||
|
||||
workflow_id = "agent-workflow"
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import React from "react";
|
||||
import LoadingIndicator from "./LoadingIndicator";
|
||||
|
||||
export default function ConfirmInline({ data, confirmed, onConfirm }) {
|
||||
const { args, tool } = data || {};
|
||||
|
||||
@@ -32,6 +32,16 @@ export default function LLMResponse({ data, onConfirm, isLastMessage }) {
|
||||
onConfirm={handleConfirm}
|
||||
/>
|
||||
)}
|
||||
{!requiresConfirm && data.tool && data.next === "confirm" && (
|
||||
<div className="text-sm text-center text-green-600 dark:text-green-400">
|
||||
<div>
|
||||
Agent ran tool: <strong>{data.tool ?? "Unknown"}</strong>
|
||||
</div>
|
||||
{/* <div>
|
||||
{JSON.stringify(data, null, 2)}
|
||||
</div> */}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Deque
|
||||
from models.tool_definitions import ToolsData
|
||||
from models.tool_definitions import AgentGoal
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -12,4 +12,4 @@ class ToolWorkflowParams:
|
||||
@dataclass
|
||||
class CombinedInput:
|
||||
tool_params: ToolWorkflowParams
|
||||
tools_data: ToolsData
|
||||
agent_goal: AgentGoal
|
||||
|
||||
@@ -17,7 +17,7 @@ class ToolDefinition:
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolsData:
|
||||
class AgentGoal:
|
||||
tools: List[ToolDefinition]
|
||||
description: str = "Description of the tools purpose and overall goal"
|
||||
example_conversation_history: str = (
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from models.tool_definitions import ToolsData
|
||||
from models.tool_definitions import AgentGoal
|
||||
from typing import Optional
|
||||
import json
|
||||
|
||||
|
||||
def generate_genai_prompt(
|
||||
tools_data: ToolsData, conversation_history: str, raw_json: Optional[str] = None
|
||||
agent_goal: AgentGoal, conversation_history: str, raw_json: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Generates a concise prompt for producing or validating JSON instructions
|
||||
@@ -28,22 +28,22 @@ def generate_genai_prompt(
|
||||
prompt_lines.append("END CONVERSATION HISTORY")
|
||||
prompt_lines.append("")
|
||||
|
||||
# Example Conversation History (from tools_data)
|
||||
if tools_data.example_conversation_history:
|
||||
# Example Conversation History (from agent_goal)
|
||||
if agent_goal.example_conversation_history:
|
||||
prompt_lines.append("=== Example Conversation With These Tools ===")
|
||||
prompt_lines.append(
|
||||
"Use this example to understand how tools are invoked and arguments are gathered."
|
||||
)
|
||||
prompt_lines.append("BEGIN EXAMPLE")
|
||||
prompt_lines.append(tools_data.example_conversation_history)
|
||||
prompt_lines.append(agent_goal.example_conversation_history)
|
||||
prompt_lines.append("END EXAMPLE")
|
||||
prompt_lines.append("")
|
||||
|
||||
# Tools Definitions
|
||||
prompt_lines.append("=== Tools Definitions ===")
|
||||
prompt_lines.append(f"There are {len(tools_data.tools)} available tools:")
|
||||
prompt_lines.append(", ".join([t.name for t in tools_data.tools]))
|
||||
prompt_lines.append(f"Goal: {tools_data.description}")
|
||||
prompt_lines.append(f"There are {len(agent_goal.tools)} available tools:")
|
||||
prompt_lines.append(", ".join([t.name for t in agent_goal.tools]))
|
||||
prompt_lines.append(f"Goal: {agent_goal.description}")
|
||||
prompt_lines.append(
|
||||
"Gather the necessary information for each tool in the sequence described above."
|
||||
)
|
||||
@@ -51,7 +51,7 @@ def generate_genai_prompt(
|
||||
"Only ask for arguments listed below. Do not add extra arguments."
|
||||
)
|
||||
prompt_lines.append("")
|
||||
for tool in tools_data.tools:
|
||||
for tool in agent_goal.tools:
|
||||
prompt_lines.append(f"Tool name: {tool.name}")
|
||||
prompt_lines.append(f" Description: {tool.description}")
|
||||
prompt_lines.append(" Required args:")
|
||||
|
||||
@@ -2,14 +2,14 @@ import asyncio
|
||||
import sys
|
||||
from temporalio.client import Client
|
||||
|
||||
from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams
|
||||
from models.data_types import CombinedInput, AgentGoal, ToolWorkflowParams
|
||||
from tools.tool_registry import event_travel_tools
|
||||
from workflows.tool_workflow import ToolWorkflow
|
||||
|
||||
|
||||
async def main(prompt: str):
|
||||
# Build the ToolsData
|
||||
tools_data = ToolsData(
|
||||
# Build the AgentGoal
|
||||
agent_goal = AgentGoal(
|
||||
tools=event_travel_tools,
|
||||
description="Helps the user find an event to travel to, search flights, and create an invoice for those flights.",
|
||||
)
|
||||
@@ -17,7 +17,7 @@ async def main(prompt: str):
|
||||
# 2) Create combined input
|
||||
combined_input = CombinedInput(
|
||||
tool_params=ToolWorkflowParams(None, None),
|
||||
tools_data=tools_data,
|
||||
agent_goal=agent_goal,
|
||||
)
|
||||
|
||||
# 3) Connect to Temporal and start or signal the workflow
|
||||
|
||||
35
tools/goal_registry.py
Normal file
35
tools/goal_registry.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from models.tool_definitions import AgentGoal
|
||||
from tools.tool_registry import (
|
||||
find_events_tool,
|
||||
search_flights_tool,
|
||||
create_invoice_tool,
|
||||
)
|
||||
|
||||
goal_event_flight_invoice = AgentGoal(
|
||||
tools=[find_events_tool, search_flights_tool, create_invoice_tool],
|
||||
description="Help the user gather args for these tools in order: "
|
||||
"1. FindEvents: Find an event to travel to "
|
||||
"2. SearchFlights: search for a flight around the event dates "
|
||||
"3. GenerateInvoice: Create a simple invoice for the cost of that flight ",
|
||||
example_conversation_history="\n ".join(
|
||||
[
|
||||
"user: I'd like to travel to an event",
|
||||
"agent: Sure! Let's start by finding an event you'd like to attend. Could you tell me which city and month you're interested in?",
|
||||
"user: In Sao Paulo, Brazil, in February",
|
||||
"agent: Great! Let's find an events in Sao Paulo, Brazil in February.",
|
||||
"user_confirmed_tool_run: <user clicks confirm on FindEvents tool>",
|
||||
"tool_result: { 'event_name': 'Carnival', 'event_date': '2023-02-25' }",
|
||||
"agent: Found an event! There's Carnival on 2023-02-25, ending on 2023-02-28. Would you like to search for flights around these dates?",
|
||||
"user: Yes, please",
|
||||
"agent: Let's search for flights around these dates. Could you provide your departure city?",
|
||||
"user: New York",
|
||||
"agent: Thanks, searching for flights from New York to Sao Paulo around 2023-02-25 to 2023-02-28.",
|
||||
"user_confirmed_tool_run: <user clicks confirm on SearchFlights tool>"
|
||||
'tool_result: results including {"flight_number": "CX101", "return_flight_number": "CX102", "price": 850.0}',
|
||||
"agent: Found some flights! The cheapest is CX101 for $850. Would you like to generate an invoice for this flight?",
|
||||
"user_confirmed_tool_run: <user clicks confirm on CreateInvoice tool>",
|
||||
'tool_result: { "status": "success", "invoice": { "flight_number": "CX101", "amount": 850.0 }, invoiceURL: "https://example.com/invoice" }',
|
||||
"agent: Invoice generated! Here's the link: https://example.com/invoice",
|
||||
]
|
||||
),
|
||||
)
|
||||
@@ -40,7 +40,7 @@ def search_airport(query: str) -> list:
|
||||
return []
|
||||
|
||||
|
||||
def search_flights(args: dict) -> dict:
|
||||
def search_flights_realapi(args: dict) -> dict:
|
||||
"""
|
||||
1) Looks up airport/city codes via search_airport.
|
||||
2) Finds the first matching skyId/entityId for both origin & destination.
|
||||
@@ -169,26 +169,31 @@ def search_flights(args: dict) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def search_flights_example(args: dict) -> dict:
|
||||
def search_flights(args: dict) -> dict:
|
||||
"""
|
||||
Example function for searching flights.
|
||||
Currently just prints/returns the passed args,
|
||||
but you can add real flight search logic later.
|
||||
Returns example flight search results in the requested JSON format.
|
||||
"""
|
||||
# date_depart = args.get("dateDepart")
|
||||
# date_return = args.get("dateReturn")
|
||||
origin = args.get("origin")
|
||||
destination = args.get("destination")
|
||||
|
||||
flight_search_results = {
|
||||
"origin": f"{origin}",
|
||||
"destination": f"{destination}",
|
||||
return {
|
||||
"currency": "USD",
|
||||
"destination": f"{destination}",
|
||||
"origin": f"{origin}",
|
||||
"results": [
|
||||
{"flight_number": "CX101", "return_flight_number": "CX102", "price": 850.0},
|
||||
{"flight_number": "QF30", "return_flight_number": "QF29", "price": 920.0},
|
||||
{"flight_number": "MH129", "return_flight_number": "MH128", "price": 780.0},
|
||||
{
|
||||
"operating_carrier": "American Airlines",
|
||||
"outbound_flight_code": "AA203",
|
||||
"price": 1262.51,
|
||||
"return_flight_code": "AA202",
|
||||
"return_operating_carrier": "American Airlines",
|
||||
},
|
||||
{
|
||||
"operating_carrier": "Air New Zealand",
|
||||
"outbound_flight_code": "NZ488",
|
||||
"price": 1396.00,
|
||||
"return_flight_code": "NZ527",
|
||||
"return_operating_carrier": "Air New Zealand",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
return flight_search_results
|
||||
|
||||
@@ -30,7 +30,7 @@ class ToolWorkflow:
|
||||
@workflow.run
|
||||
async def run(self, combined_input: CombinedInput) -> str:
|
||||
params = combined_input.tool_params
|
||||
tools_data = combined_input.tools_data
|
||||
agent_goal = combined_input.agent_goal
|
||||
tool_data = None
|
||||
|
||||
if params and params.conversation_summary:
|
||||
@@ -97,10 +97,10 @@ class ToolWorkflow:
|
||||
|
||||
# Pass entire conversation + Tools to LLM
|
||||
context_instructions = generate_genai_prompt(
|
||||
tools_data, self.conversation_history, self.tool_data
|
||||
agent_goal, self.conversation_history, self.tool_data
|
||||
)
|
||||
|
||||
# tools_list = ", ".join([t.name for t in tools_data.tools])
|
||||
# tools_list = ", ".join([t.name for t in agent_goal.tools])
|
||||
|
||||
prompt_input = ToolPromptInput(
|
||||
prompt=prompt,
|
||||
@@ -121,7 +121,7 @@ class ToolWorkflow:
|
||||
current_tool = self.tool_data.get("tool")
|
||||
|
||||
if next_step == "confirm" and current_tool:
|
||||
# tmp arg check
|
||||
# todo make this less awkward
|
||||
args = self.tool_data.get("args")
|
||||
|
||||
# check each argument for null values
|
||||
@@ -132,8 +132,6 @@ class ToolWorkflow:
|
||||
missing_args.append(key)
|
||||
|
||||
if missing_args:
|
||||
# self.add_message("response_confirm_missing_args", tool_data)
|
||||
|
||||
# Enqueue a follow-up prompt for the LLM
|
||||
self.prompt_queue.append(
|
||||
f"### INSTRUCTIONS set next='question', combine this response response='{tool_data.get('response')}' and following missing arguments for tool {current_tool}: {missing_args}. "
|
||||
@@ -184,7 +182,7 @@ class ToolWorkflow:
|
||||
conversation_summary=self.conversation_summary,
|
||||
prompt_queue=self.prompt_queue,
|
||||
),
|
||||
tools_data=tools_data,
|
||||
agent_goal=agent_goal,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user