From 010518c16e439d841581d742398f5473a789b927 Mon Sep 17 00:00:00 2001 From: Steve Androulakis Date: Sat, 4 Jan 2025 11:27:59 -0800 Subject: [PATCH] refactors and ui enhancements --- README.md | 1 + api/main.py | 41 ++--------------------- frontend/src/components/ConfirmInline.jsx | 1 - frontend/src/components/LLMResponse.jsx | 10 ++++++ models/data_types.py | 4 +-- models/tool_definitions.py | 2 +- prompts/agent_prompt_generators.py | 18 +++++----- scripts/send_message.py | 8 ++--- tools/goal_registry.py | 35 +++++++++++++++++++ tools/search_flights.py | 35 ++++++++++--------- workflows/tool_workflow.py | 12 +++---- 11 files changed, 90 insertions(+), 77 deletions(-) create mode 100644 tools/goal_registry.py diff --git a/README.md b/README.md index 5eee7b2..8b36144 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,7 @@ This demo shows a multi-turn conversation with an AI agent running inside a Temp ## Setup * Requires an OpenAI key for the gpt-4o model. Set this in the `OPENAI_API_KEY` environment variable in .env * Requires a rapidapi key for sky-scrapper (how we find flights). Set this in the `RAPIDAPI_KEY` environment variable in .env + * It's free to sign up and get a key at [RapidAPI](https://rapidapi.com/apiheya/api/sky-scrapper) * If you're lazy go to `tools/search_flights.py` and replace the `get_flights` function with the mock `search_flights_example` that exists in the same file. * See .env_example for the required environment variables. * Install and run Temporal. Follow the instructions in the [Temporal documentation](https://learn.temporal.io/getting_started/python/dev_environment/#set-up-a-local-temporal-service-for-development-with-temporal-cli) to install and run the Temporal server. diff --git a/api/main.py b/api/main.py index 671c23a..da51617 100644 --- a/api/main.py +++ b/api/main.py @@ -1,15 +1,10 @@ from fastapi import FastAPI from temporalio.client import Client from workflows.tool_workflow import ToolWorkflow -from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams +from models.data_types import CombinedInput, ToolWorkflowParams +from tools.goal_registry import goal_event_flight_invoice from temporalio.exceptions import TemporalError from fastapi.middleware.cors import CORSMiddleware -from tools.tool_registry import ( - find_events_tool, - search_flights_tool, - create_invoice_tool, -) - app = FastAPI() @@ -68,40 +63,10 @@ async def get_conversation_history(): async def send_prompt(prompt: str): client = await Client.connect("localhost:7233") - # Build the ToolsData - tools_data = ToolsData( - tools=[find_events_tool, search_flights_tool, create_invoice_tool], - description="Help the user gather args for these tools in order: " - "1. FindEvents: Find an event to travel to " - "2. SearchFlights: search for a flight around the event dates " - "3. GenerateInvoice: Create a simple invoice for the cost of that flight ", - example_conversation_history="\n ".join( - [ - "user: I'd like to travel to an event", - "agent: Sure! Let's start by finding an event you'd like to attend. Could you tell me which city and month you're interested in?", - "user: In Sao Paulo, Brazil, in February", - "agent: Great! Let's find an events in Sao Paulo, Brazil in February.", - "user_confirmed_tool_run: ", - "tool_result: { 'event_name': 'Carnival', 'event_date': '2023-02-25' }", - "agent: Found an event! There's Carnival on 2023-02-25, ending on 2023-02-28. Would you like to search for flights around these dates?", - "user: Yes, please", - "agent: Let's search for flights around these dates. Could you provide your departure city?", - "user: New York", - "agent: Thanks, searching for flights from New York to Sao Paulo around 2023-02-25 to 2023-02-28.", - "user_confirmed_tool_run: " - 'tool_result: results including {"flight_number": "CX101", "return_flight_number": "CX102", "price": 850.0}', - "agent: Found some flights! The cheapest is CX101 for $850. Would you like to generate an invoice for this flight?", - "user_confirmed_tool_run: ", - 'tool_result: { "status": "success", "invoice": { "flight_number": "CX101", "amount": 850.0 }, invoiceURL: "https://example.com/invoice" }', - "agent: Invoice generated! Here's the link: https://example.com/invoice", - ] - ), - ) - # Create combined input combined_input = CombinedInput( tool_params=ToolWorkflowParams(None, None), - tools_data=tools_data, + agent_goal=goal_event_flight_invoice, ) workflow_id = "agent-workflow" diff --git a/frontend/src/components/ConfirmInline.jsx b/frontend/src/components/ConfirmInline.jsx index 878d817..5119e9c 100644 --- a/frontend/src/components/ConfirmInline.jsx +++ b/frontend/src/components/ConfirmInline.jsx @@ -1,5 +1,4 @@ import React from "react"; -import LoadingIndicator from "./LoadingIndicator"; export default function ConfirmInline({ data, confirmed, onConfirm }) { const { args, tool } = data || {}; diff --git a/frontend/src/components/LLMResponse.jsx b/frontend/src/components/LLMResponse.jsx index 654dc21..c62b05b 100644 --- a/frontend/src/components/LLMResponse.jsx +++ b/frontend/src/components/LLMResponse.jsx @@ -32,6 +32,16 @@ export default function LLMResponse({ data, onConfirm, isLastMessage }) { onConfirm={handleConfirm} /> )} + {!requiresConfirm && data.tool && data.next === "confirm" && ( +
+
+ Agent ran tool: {data.tool ?? "Unknown"} +
+ {/*
+ {JSON.stringify(data, null, 2)} +
*/} +
+ )} ); } diff --git a/models/data_types.py b/models/data_types.py index 4b81577..a2de494 100644 --- a/models/data_types.py +++ b/models/data_types.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from typing import Optional, Deque -from models.tool_definitions import ToolsData +from models.tool_definitions import AgentGoal @dataclass @@ -12,4 +12,4 @@ class ToolWorkflowParams: @dataclass class CombinedInput: tool_params: ToolWorkflowParams - tools_data: ToolsData + agent_goal: AgentGoal diff --git a/models/tool_definitions.py b/models/tool_definitions.py index d52ea39..6e711b5 100644 --- a/models/tool_definitions.py +++ b/models/tool_definitions.py @@ -17,7 +17,7 @@ class ToolDefinition: @dataclass -class ToolsData: +class AgentGoal: tools: List[ToolDefinition] description: str = "Description of the tools purpose and overall goal" example_conversation_history: str = ( diff --git a/prompts/agent_prompt_generators.py b/prompts/agent_prompt_generators.py index 00cc634..3cbf536 100644 --- a/prompts/agent_prompt_generators.py +++ b/prompts/agent_prompt_generators.py @@ -1,10 +1,10 @@ -from models.tool_definitions import ToolsData +from models.tool_definitions import AgentGoal from typing import Optional import json def generate_genai_prompt( - tools_data: ToolsData, conversation_history: str, raw_json: Optional[str] = None + agent_goal: AgentGoal, conversation_history: str, raw_json: Optional[str] = None ) -> str: """ Generates a concise prompt for producing or validating JSON instructions @@ -28,22 +28,22 @@ def generate_genai_prompt( prompt_lines.append("END CONVERSATION HISTORY") prompt_lines.append("") - # Example Conversation History (from tools_data) - if tools_data.example_conversation_history: + # Example Conversation History (from agent_goal) + if agent_goal.example_conversation_history: prompt_lines.append("=== Example Conversation With These Tools ===") prompt_lines.append( "Use this example to understand how tools are invoked and arguments are gathered." ) prompt_lines.append("BEGIN EXAMPLE") - prompt_lines.append(tools_data.example_conversation_history) + prompt_lines.append(agent_goal.example_conversation_history) prompt_lines.append("END EXAMPLE") prompt_lines.append("") # Tools Definitions prompt_lines.append("=== Tools Definitions ===") - prompt_lines.append(f"There are {len(tools_data.tools)} available tools:") - prompt_lines.append(", ".join([t.name for t in tools_data.tools])) - prompt_lines.append(f"Goal: {tools_data.description}") + prompt_lines.append(f"There are {len(agent_goal.tools)} available tools:") + prompt_lines.append(", ".join([t.name for t in agent_goal.tools])) + prompt_lines.append(f"Goal: {agent_goal.description}") prompt_lines.append( "Gather the necessary information for each tool in the sequence described above." ) @@ -51,7 +51,7 @@ def generate_genai_prompt( "Only ask for arguments listed below. Do not add extra arguments." ) prompt_lines.append("") - for tool in tools_data.tools: + for tool in agent_goal.tools: prompt_lines.append(f"Tool name: {tool.name}") prompt_lines.append(f" Description: {tool.description}") prompt_lines.append(" Required args:") diff --git a/scripts/send_message.py b/scripts/send_message.py index 100e48e..13acd37 100644 --- a/scripts/send_message.py +++ b/scripts/send_message.py @@ -2,14 +2,14 @@ import asyncio import sys from temporalio.client import Client -from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams +from models.data_types import CombinedInput, AgentGoal, ToolWorkflowParams from tools.tool_registry import event_travel_tools from workflows.tool_workflow import ToolWorkflow async def main(prompt: str): - # Build the ToolsData - tools_data = ToolsData( + # Build the AgentGoal + agent_goal = AgentGoal( tools=event_travel_tools, description="Helps the user find an event to travel to, search flights, and create an invoice for those flights.", ) @@ -17,7 +17,7 @@ async def main(prompt: str): # 2) Create combined input combined_input = CombinedInput( tool_params=ToolWorkflowParams(None, None), - tools_data=tools_data, + agent_goal=agent_goal, ) # 3) Connect to Temporal and start or signal the workflow diff --git a/tools/goal_registry.py b/tools/goal_registry.py new file mode 100644 index 0000000..3052f09 --- /dev/null +++ b/tools/goal_registry.py @@ -0,0 +1,35 @@ +from models.tool_definitions import AgentGoal +from tools.tool_registry import ( + find_events_tool, + search_flights_tool, + create_invoice_tool, +) + +goal_event_flight_invoice = AgentGoal( + tools=[find_events_tool, search_flights_tool, create_invoice_tool], + description="Help the user gather args for these tools in order: " + "1. FindEvents: Find an event to travel to " + "2. SearchFlights: search for a flight around the event dates " + "3. GenerateInvoice: Create a simple invoice for the cost of that flight ", + example_conversation_history="\n ".join( + [ + "user: I'd like to travel to an event", + "agent: Sure! Let's start by finding an event you'd like to attend. Could you tell me which city and month you're interested in?", + "user: In Sao Paulo, Brazil, in February", + "agent: Great! Let's find an events in Sao Paulo, Brazil in February.", + "user_confirmed_tool_run: ", + "tool_result: { 'event_name': 'Carnival', 'event_date': '2023-02-25' }", + "agent: Found an event! There's Carnival on 2023-02-25, ending on 2023-02-28. Would you like to search for flights around these dates?", + "user: Yes, please", + "agent: Let's search for flights around these dates. Could you provide your departure city?", + "user: New York", + "agent: Thanks, searching for flights from New York to Sao Paulo around 2023-02-25 to 2023-02-28.", + "user_confirmed_tool_run: " + 'tool_result: results including {"flight_number": "CX101", "return_flight_number": "CX102", "price": 850.0}', + "agent: Found some flights! The cheapest is CX101 for $850. Would you like to generate an invoice for this flight?", + "user_confirmed_tool_run: ", + 'tool_result: { "status": "success", "invoice": { "flight_number": "CX101", "amount": 850.0 }, invoiceURL: "https://example.com/invoice" }', + "agent: Invoice generated! Here's the link: https://example.com/invoice", + ] + ), +) diff --git a/tools/search_flights.py b/tools/search_flights.py index 46ae148..40f4297 100644 --- a/tools/search_flights.py +++ b/tools/search_flights.py @@ -40,7 +40,7 @@ def search_airport(query: str) -> list: return [] -def search_flights(args: dict) -> dict: +def search_flights_realapi(args: dict) -> dict: """ 1) Looks up airport/city codes via search_airport. 2) Finds the first matching skyId/entityId for both origin & destination. @@ -169,26 +169,31 @@ def search_flights(args: dict) -> dict: } -def search_flights_example(args: dict) -> dict: +def search_flights(args: dict) -> dict: """ - Example function for searching flights. - Currently just prints/returns the passed args, - but you can add real flight search logic later. + Returns example flight search results in the requested JSON format. """ - # date_depart = args.get("dateDepart") - # date_return = args.get("dateReturn") origin = args.get("origin") destination = args.get("destination") - flight_search_results = { - "origin": f"{origin}", - "destination": f"{destination}", + return { "currency": "USD", + "destination": f"{destination}", + "origin": f"{origin}", "results": [ - {"flight_number": "CX101", "return_flight_number": "CX102", "price": 850.0}, - {"flight_number": "QF30", "return_flight_number": "QF29", "price": 920.0}, - {"flight_number": "MH129", "return_flight_number": "MH128", "price": 780.0}, + { + "operating_carrier": "American Airlines", + "outbound_flight_code": "AA203", + "price": 1262.51, + "return_flight_code": "AA202", + "return_operating_carrier": "American Airlines", + }, + { + "operating_carrier": "Air New Zealand", + "outbound_flight_code": "NZ488", + "price": 1396.00, + "return_flight_code": "NZ527", + "return_operating_carrier": "Air New Zealand", + }, ], } - - return flight_search_results diff --git a/workflows/tool_workflow.py b/workflows/tool_workflow.py index 712bebd..6efea91 100644 --- a/workflows/tool_workflow.py +++ b/workflows/tool_workflow.py @@ -30,7 +30,7 @@ class ToolWorkflow: @workflow.run async def run(self, combined_input: CombinedInput) -> str: params = combined_input.tool_params - tools_data = combined_input.tools_data + agent_goal = combined_input.agent_goal tool_data = None if params and params.conversation_summary: @@ -97,10 +97,10 @@ class ToolWorkflow: # Pass entire conversation + Tools to LLM context_instructions = generate_genai_prompt( - tools_data, self.conversation_history, self.tool_data + agent_goal, self.conversation_history, self.tool_data ) - # tools_list = ", ".join([t.name for t in tools_data.tools]) + # tools_list = ", ".join([t.name for t in agent_goal.tools]) prompt_input = ToolPromptInput( prompt=prompt, @@ -121,7 +121,7 @@ class ToolWorkflow: current_tool = self.tool_data.get("tool") if next_step == "confirm" and current_tool: - # tmp arg check + # todo make this less awkward args = self.tool_data.get("args") # check each argument for null values @@ -132,8 +132,6 @@ class ToolWorkflow: missing_args.append(key) if missing_args: - # self.add_message("response_confirm_missing_args", tool_data) - # Enqueue a follow-up prompt for the LLM self.prompt_queue.append( f"### INSTRUCTIONS set next='question', combine this response response='{tool_data.get('response')}' and following missing arguments for tool {current_tool}: {missing_args}. " @@ -184,7 +182,7 @@ class ToolWorkflow: conversation_summary=self.conversation_summary, prompt_queue=self.prompt_queue, ), - tools_data=tools_data, + agent_goal=agent_goal, ) ] )