refactors and ui enhancements

This commit is contained in:
Steve Androulakis
2025-01-04 11:27:59 -08:00
parent 43904650dd
commit 010518c16e
11 changed files with 90 additions and 77 deletions

View File

@@ -7,6 +7,7 @@ This demo shows a multi-turn conversation with an AI agent running inside a Temp
## Setup
* Requires an OpenAI key for the gpt-4o model. Set this in the `OPENAI_API_KEY` environment variable in .env
* Requires a rapidapi key for sky-scrapper (how we find flights). Set this in the `RAPIDAPI_KEY` environment variable in .env
* It's free to sign up and get a key at [RapidAPI](https://rapidapi.com/apiheya/api/sky-scrapper)
* If you're lazy go to `tools/search_flights.py` and replace the `get_flights` function with the mock `search_flights_example` that exists in the same file.
* See .env_example for the required environment variables.
* Install and run Temporal. Follow the instructions in the [Temporal documentation](https://learn.temporal.io/getting_started/python/dev_environment/#set-up-a-local-temporal-service-for-development-with-temporal-cli) to install and run the Temporal server.

View File

@@ -1,15 +1,10 @@
from fastapi import FastAPI
from temporalio.client import Client
from workflows.tool_workflow import ToolWorkflow
from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams
from models.data_types import CombinedInput, ToolWorkflowParams
from tools.goal_registry import goal_event_flight_invoice
from temporalio.exceptions import TemporalError
from fastapi.middleware.cors import CORSMiddleware
from tools.tool_registry import (
find_events_tool,
search_flights_tool,
create_invoice_tool,
)
app = FastAPI()
@@ -68,40 +63,10 @@ async def get_conversation_history():
async def send_prompt(prompt: str):
client = await Client.connect("localhost:7233")
# Build the ToolsData
tools_data = ToolsData(
tools=[find_events_tool, search_flights_tool, create_invoice_tool],
description="Help the user gather args for these tools in order: "
"1. FindEvents: Find an event to travel to "
"2. SearchFlights: search for a flight around the event dates "
"3. GenerateInvoice: Create a simple invoice for the cost of that flight ",
example_conversation_history="\n ".join(
[
"user: I'd like to travel to an event",
"agent: Sure! Let's start by finding an event you'd like to attend. Could you tell me which city and month you're interested in?",
"user: In Sao Paulo, Brazil, in February",
"agent: Great! Let's find an events in Sao Paulo, Brazil in February.",
"user_confirmed_tool_run: <user clicks confirm on FindEvents tool>",
"tool_result: { 'event_name': 'Carnival', 'event_date': '2023-02-25' }",
"agent: Found an event! There's Carnival on 2023-02-25, ending on 2023-02-28. Would you like to search for flights around these dates?",
"user: Yes, please",
"agent: Let's search for flights around these dates. Could you provide your departure city?",
"user: New York",
"agent: Thanks, searching for flights from New York to Sao Paulo around 2023-02-25 to 2023-02-28.",
"user_confirmed_tool_run: <user clicks confirm on SearchFlights tool>"
'tool_result: results including {"flight_number": "CX101", "return_flight_number": "CX102", "price": 850.0}',
"agent: Found some flights! The cheapest is CX101 for $850. Would you like to generate an invoice for this flight?",
"user_confirmed_tool_run: <user clicks confirm on CreateInvoice tool>",
'tool_result: { "status": "success", "invoice": { "flight_number": "CX101", "amount": 850.0 }, invoiceURL: "https://example.com/invoice" }',
"agent: Invoice generated! Here's the link: https://example.com/invoice",
]
),
)
# Create combined input
combined_input = CombinedInput(
tool_params=ToolWorkflowParams(None, None),
tools_data=tools_data,
agent_goal=goal_event_flight_invoice,
)
workflow_id = "agent-workflow"

View File

@@ -1,5 +1,4 @@
import React from "react";
import LoadingIndicator from "./LoadingIndicator";
export default function ConfirmInline({ data, confirmed, onConfirm }) {
const { args, tool } = data || {};

View File

@@ -32,6 +32,16 @@ export default function LLMResponse({ data, onConfirm, isLastMessage }) {
onConfirm={handleConfirm}
/>
)}
{!requiresConfirm && data.tool && data.next === "confirm" && (
<div className="text-sm text-center text-green-600 dark:text-green-400">
<div>
Agent ran tool: <strong>{data.tool ?? "Unknown"}</strong>
</div>
{/* <div>
{JSON.stringify(data, null, 2)}
</div> */}
</div>
)}
</div>
);
}

View File

@@ -1,6 +1,6 @@
from dataclasses import dataclass
from typing import Optional, Deque
from models.tool_definitions import ToolsData
from models.tool_definitions import AgentGoal
@dataclass
@@ -12,4 +12,4 @@ class ToolWorkflowParams:
@dataclass
class CombinedInput:
tool_params: ToolWorkflowParams
tools_data: ToolsData
agent_goal: AgentGoal

View File

@@ -17,7 +17,7 @@ class ToolDefinition:
@dataclass
class ToolsData:
class AgentGoal:
tools: List[ToolDefinition]
description: str = "Description of the tools purpose and overall goal"
example_conversation_history: str = (

View File

@@ -1,10 +1,10 @@
from models.tool_definitions import ToolsData
from models.tool_definitions import AgentGoal
from typing import Optional
import json
def generate_genai_prompt(
tools_data: ToolsData, conversation_history: str, raw_json: Optional[str] = None
agent_goal: AgentGoal, conversation_history: str, raw_json: Optional[str] = None
) -> str:
"""
Generates a concise prompt for producing or validating JSON instructions
@@ -28,22 +28,22 @@ def generate_genai_prompt(
prompt_lines.append("END CONVERSATION HISTORY")
prompt_lines.append("")
# Example Conversation History (from tools_data)
if tools_data.example_conversation_history:
# Example Conversation History (from agent_goal)
if agent_goal.example_conversation_history:
prompt_lines.append("=== Example Conversation With These Tools ===")
prompt_lines.append(
"Use this example to understand how tools are invoked and arguments are gathered."
)
prompt_lines.append("BEGIN EXAMPLE")
prompt_lines.append(tools_data.example_conversation_history)
prompt_lines.append(agent_goal.example_conversation_history)
prompt_lines.append("END EXAMPLE")
prompt_lines.append("")
# Tools Definitions
prompt_lines.append("=== Tools Definitions ===")
prompt_lines.append(f"There are {len(tools_data.tools)} available tools:")
prompt_lines.append(", ".join([t.name for t in tools_data.tools]))
prompt_lines.append(f"Goal: {tools_data.description}")
prompt_lines.append(f"There are {len(agent_goal.tools)} available tools:")
prompt_lines.append(", ".join([t.name for t in agent_goal.tools]))
prompt_lines.append(f"Goal: {agent_goal.description}")
prompt_lines.append(
"Gather the necessary information for each tool in the sequence described above."
)
@@ -51,7 +51,7 @@ def generate_genai_prompt(
"Only ask for arguments listed below. Do not add extra arguments."
)
prompt_lines.append("")
for tool in tools_data.tools:
for tool in agent_goal.tools:
prompt_lines.append(f"Tool name: {tool.name}")
prompt_lines.append(f" Description: {tool.description}")
prompt_lines.append(" Required args:")

View File

@@ -2,14 +2,14 @@ import asyncio
import sys
from temporalio.client import Client
from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams
from models.data_types import CombinedInput, AgentGoal, ToolWorkflowParams
from tools.tool_registry import event_travel_tools
from workflows.tool_workflow import ToolWorkflow
async def main(prompt: str):
# Build the ToolsData
tools_data = ToolsData(
# Build the AgentGoal
agent_goal = AgentGoal(
tools=event_travel_tools,
description="Helps the user find an event to travel to, search flights, and create an invoice for those flights.",
)
@@ -17,7 +17,7 @@ async def main(prompt: str):
# 2) Create combined input
combined_input = CombinedInput(
tool_params=ToolWorkflowParams(None, None),
tools_data=tools_data,
agent_goal=agent_goal,
)
# 3) Connect to Temporal and start or signal the workflow

35
tools/goal_registry.py Normal file
View File

@@ -0,0 +1,35 @@
from models.tool_definitions import AgentGoal
from tools.tool_registry import (
find_events_tool,
search_flights_tool,
create_invoice_tool,
)
goal_event_flight_invoice = AgentGoal(
tools=[find_events_tool, search_flights_tool, create_invoice_tool],
description="Help the user gather args for these tools in order: "
"1. FindEvents: Find an event to travel to "
"2. SearchFlights: search for a flight around the event dates "
"3. GenerateInvoice: Create a simple invoice for the cost of that flight ",
example_conversation_history="\n ".join(
[
"user: I'd like to travel to an event",
"agent: Sure! Let's start by finding an event you'd like to attend. Could you tell me which city and month you're interested in?",
"user: In Sao Paulo, Brazil, in February",
"agent: Great! Let's find an events in Sao Paulo, Brazil in February.",
"user_confirmed_tool_run: <user clicks confirm on FindEvents tool>",
"tool_result: { 'event_name': 'Carnival', 'event_date': '2023-02-25' }",
"agent: Found an event! There's Carnival on 2023-02-25, ending on 2023-02-28. Would you like to search for flights around these dates?",
"user: Yes, please",
"agent: Let's search for flights around these dates. Could you provide your departure city?",
"user: New York",
"agent: Thanks, searching for flights from New York to Sao Paulo around 2023-02-25 to 2023-02-28.",
"user_confirmed_tool_run: <user clicks confirm on SearchFlights tool>"
'tool_result: results including {"flight_number": "CX101", "return_flight_number": "CX102", "price": 850.0}',
"agent: Found some flights! The cheapest is CX101 for $850. Would you like to generate an invoice for this flight?",
"user_confirmed_tool_run: <user clicks confirm on CreateInvoice tool>",
'tool_result: { "status": "success", "invoice": { "flight_number": "CX101", "amount": 850.0 }, invoiceURL: "https://example.com/invoice" }',
"agent: Invoice generated! Here's the link: https://example.com/invoice",
]
),
)

View File

@@ -40,7 +40,7 @@ def search_airport(query: str) -> list:
return []
def search_flights(args: dict) -> dict:
def search_flights_realapi(args: dict) -> dict:
"""
1) Looks up airport/city codes via search_airport.
2) Finds the first matching skyId/entityId for both origin & destination.
@@ -169,26 +169,31 @@ def search_flights(args: dict) -> dict:
}
def search_flights_example(args: dict) -> dict:
def search_flights(args: dict) -> dict:
"""
Example function for searching flights.
Currently just prints/returns the passed args,
but you can add real flight search logic later.
Returns example flight search results in the requested JSON format.
"""
# date_depart = args.get("dateDepart")
# date_return = args.get("dateReturn")
origin = args.get("origin")
destination = args.get("destination")
flight_search_results = {
"origin": f"{origin}",
"destination": f"{destination}",
return {
"currency": "USD",
"destination": f"{destination}",
"origin": f"{origin}",
"results": [
{"flight_number": "CX101", "return_flight_number": "CX102", "price": 850.0},
{"flight_number": "QF30", "return_flight_number": "QF29", "price": 920.0},
{"flight_number": "MH129", "return_flight_number": "MH128", "price": 780.0},
{
"operating_carrier": "American Airlines",
"outbound_flight_code": "AA203",
"price": 1262.51,
"return_flight_code": "AA202",
"return_operating_carrier": "American Airlines",
},
{
"operating_carrier": "Air New Zealand",
"outbound_flight_code": "NZ488",
"price": 1396.00,
"return_flight_code": "NZ527",
"return_operating_carrier": "Air New Zealand",
},
],
}
return flight_search_results

View File

@@ -30,7 +30,7 @@ class ToolWorkflow:
@workflow.run
async def run(self, combined_input: CombinedInput) -> str:
params = combined_input.tool_params
tools_data = combined_input.tools_data
agent_goal = combined_input.agent_goal
tool_data = None
if params and params.conversation_summary:
@@ -97,10 +97,10 @@ class ToolWorkflow:
# Pass entire conversation + Tools to LLM
context_instructions = generate_genai_prompt(
tools_data, self.conversation_history, self.tool_data
agent_goal, self.conversation_history, self.tool_data
)
# tools_list = ", ".join([t.name for t in tools_data.tools])
# tools_list = ", ".join([t.name for t in agent_goal.tools])
prompt_input = ToolPromptInput(
prompt=prompt,
@@ -121,7 +121,7 @@ class ToolWorkflow:
current_tool = self.tool_data.get("tool")
if next_step == "confirm" and current_tool:
# tmp arg check
# todo make this less awkward
args = self.tool_data.get("args")
# check each argument for null values
@@ -132,8 +132,6 @@ class ToolWorkflow:
missing_args.append(key)
if missing_args:
# self.add_message("response_confirm_missing_args", tool_data)
# Enqueue a follow-up prompt for the LLM
self.prompt_queue.append(
f"### INSTRUCTIONS set next='question', combine this response response='{tool_data.get('response')}' and following missing arguments for tool {current_tool}: {missing_args}. "
@@ -184,7 +182,7 @@ class ToolWorkflow:
conversation_summary=self.conversation_summary,
prompt_queue=self.prompt_queue,
),
tools_data=tools_data,
agent_goal=agent_goal,
)
]
)