From f5cf7286a2e0d4da874bb7317b60faea655465e2 Mon Sep 17 00:00:00 2001 From: Steve Androulakis Date: Fri, 3 Jan 2025 15:05:27 -0800 Subject: [PATCH] works a lot better with 4o! --- .gitignore | 4 +- activities/tool_activities.py | 43 +++++++ api/main.py | 35 +++++- frontend/src/components/ChatWindow.jsx | 22 ++-- frontend/src/components/LLMResponse.jsx | 4 + frontend/src/pages/App.jsx | 5 +- models/tool_definitions.py | 8 +- poetry.lock | 158 +++++++++++++++++++++++- prompts/agent_prompt_generators.py | 97 ++++++--------- pyproject.toml | 2 + scripts/run_worker.py | 3 + scripts/send_message.py | 9 +- tools/find_events.py | 12 +- tools/search_flights.py | 25 ++-- tools/tool_registry.py | 14 +-- workflows/tool_workflow.py | 43 +++++-- 16 files changed, 365 insertions(+), 119 deletions(-) diff --git a/.gitignore b/.gitignore index fba0360..2d708e1 100644 --- a/.gitignore +++ b/.gitignore @@ -29,4 +29,6 @@ coverage.xml .vscode/ # PyCharm / IntelliJ settings -.idea/ \ No newline at end of file +.idea/ + +.env \ No newline at end of file diff --git a/activities/tool_activities.py b/activities/tool_activities.py index a83c654..e7fbb67 100644 --- a/activities/tool_activities.py +++ b/activities/tool_activities.py @@ -1,9 +1,12 @@ from dataclasses import dataclass from temporalio import activity from ollama import chat, ChatResponse +from openai import OpenAI import json from typing import Sequence from temporalio.common import RawValue +import os +from datetime import datetime @dataclass @@ -15,6 +18,46 @@ class ToolPromptInput: class ToolActivities: @activity.defn def prompt_llm(self, input: ToolPromptInput) -> dict: + client = OpenAI( + api_key=os.environ.get( + "OPENAI_API_KEY" + ), # This is the default and can be omitted + ) + + messages = [ + { + "role": "system", + "content": input.context_instructions + + ". The current date is " + + datetime.now().strftime("%B %d, %Y"), + }, + { + "role": "user", + "content": input.prompt, + }, + ] + + chat_completion = client.chat.completions.create( + model="gpt-4o", messages=messages # was gpt-4-0613 + ) + + response_content = chat_completion.choices[0].message.content + print(f"ChatGPT response: {response_content}") + + # Trim formatting markers if present + if response_content.startswith("```json") and response_content.endswith("```"): + response_content = response_content[7:-3].strip() + + try: + data = json.loads(response_content) + except json.JSONDecodeError as e: + print(f"Invalid JSON: {e}") + raise json.JSONDecodeError + + return data + + @activity.defn + def prompt_llm_ollama(self, input: ToolPromptInput) -> dict: model_name = "qwen2.5:14b" messages = [ { diff --git a/api/main.py b/api/main.py index e400f98..e7eafc9 100644 --- a/api/main.py +++ b/api/main.py @@ -4,7 +4,11 @@ from workflows.tool_workflow import ToolWorkflow from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams from temporalio.exceptions import TemporalError from fastapi.middleware.cors import CORSMiddleware -from tools.tool_registry import all_tools +from tools.tool_registry import ( + find_events_tool, + search_flights_tool, + create_invoice_tool, +) app = FastAPI() @@ -65,7 +69,34 @@ async def send_prompt(prompt: str): client = await Client.connect("localhost:7233") # Build the ToolsData - tools_data = ToolsData(tools=all_tools) + tools_data = ToolsData( + tools=[find_events_tool, search_flights_tool, create_invoice_tool], + description="Help the user gather args for these tools in order: " + "1. FindEvents: Find an event to travel to " + "2. SearchFlights: search for a flight around the event dates " + "3. GenerateInvoice: Create a simple invoice for the cost of that flight ", + example_conversation_history="\n ".join( + [ + "user: I'd like to travel to an event", + "agent: Sure! Let's start by finding an event you'd like to attend. Could you tell me which region and month you're interested in?", + "user: In Sao Paulo, Brazil, in February", + "agent: Great! Let's find an events in Sao Paulo, Brazil in February.", + "user_confirmed_tool_run: ", + "tool_result: { 'event_name': 'Carnival', 'event_date': '2023-02-25' }", + "agent: Found an event! There's Carnival on 2023-02-25, ending on 2023-02-28. Would you like to search for flights around these dates?", + "user: Yes, please", + "agent: Let's search for flights around these dates. Could you provide your departure city?", + "user: New York", + "agent: Thanks, searching for flights from New York to Sao Paulo around 2023-02-25 to 2023-02-28.", + "user_confirmed_tool_run: " + 'tool_result: results including {"flight_number": "CX101", "return_flight_number": "CX102", "price": 850.0}', + "agent: Found some flights! The cheapest is CX101 for $850. Would you like to generate an invoice for this flight?", + "user_confirmed_tool_run: ", + 'tool_result: { "status": "success", "invoice": { "flight_number": "CX101", "amount": 850.0 }, invoiceURL: "https://example.com/invoice" }', + "agent: Invoice generated! Here's the link: https://example.com/invoice", + ] + ), + ) # Create combined input combined_input = CombinedInput( diff --git a/frontend/src/components/ChatWindow.jsx b/frontend/src/components/ChatWindow.jsx index 26a8984..a82931b 100644 --- a/frontend/src/components/ChatWindow.jsx +++ b/frontend/src/components/ChatWindow.jsx @@ -19,17 +19,18 @@ export default function ChatWindow({ conversation, loading, onConfirm }) { } const filtered = conversation.filter((msg) => { - console.log(conversation[conversation.length - 1].actor) + const { actor, response } = msg; if (actor === "user") { return true; } - if (actor === "response") { + if (actor === "agent") { const parsed = typeof response === "string" ? safeParse(response) : response; - // Keep if next is "question", "confirm", or "confirmed". + // Keep if next is "question", "confirm", or "user_confirmed_tool_run". // Only skip if next is "done" (or something else). - return !["done"].includes(parsed.next); + // return !["done"].includes(parsed.next); + return true; } return false; }); @@ -37,13 +38,14 @@ export default function ChatWindow({ conversation, loading, onConfirm }) { return (
{filtered.map((msg, idx) => { + const { actor, response } = msg; if (actor === "user") { return ( ); - } else if (actor === "response") { + } else if (actor === "agent") { const data = typeof response === "string" ? safeParse(response) : response; return ; @@ -57,16 +59,6 @@ export default function ChatWindow({ conversation, loading, onConfirm }) {
)} - {conversation.length > 0 && conversation[conversation.length - 1].actor === "user" && ( -
- -
- )} - {conversation.length > 0 && conversation[conversation.length - 1].actor === "tool_result_to_llm" && ( -
- -
- )} ); } diff --git a/frontend/src/components/LLMResponse.jsx b/frontend/src/components/LLMResponse.jsx index 8dcb3ec..dee8fbb 100644 --- a/frontend/src/components/LLMResponse.jsx +++ b/frontend/src/components/LLMResponse.jsx @@ -14,6 +14,10 @@ export default function LLMResponse({ data, onConfirm }) { const requiresConfirm = data.next === "confirm"; + if (typeof data.response === "object") { + data.response = data.response.response; + } + let displayText = (data.response || "").trim(); if (!displayText && requiresConfirm) { displayText = `Agent is ready to run "${data.tool}". Please confirm.`; diff --git a/frontend/src/pages/App.jsx b/frontend/src/pages/App.jsx index ce7eef2..7ba5574 100644 --- a/frontend/src/pages/App.jsx +++ b/frontend/src/pages/App.jsx @@ -18,9 +18,12 @@ export default function App() { const data = await res.json(); // data is now an object like { messages: [ ... ] } - if (data.messages && data.messages.some(msg => msg.actor === "response" || msg.actor === "tool_result")) { + if (data.messages && data.messages.length > 0 && (data.messages[data.messages.length - 1].actor === "agent")) { setLoading(false); } + else { + setLoading(true); + } setConversation(data.messages || []); } } catch (err) { diff --git a/models/tool_definitions.py b/models/tool_definitions.py index d667203..d52ea39 100644 --- a/models/tool_definitions.py +++ b/models/tool_definitions.py @@ -1,5 +1,5 @@ -from dataclasses import dataclass, field -from typing import List, Dict, Any +from dataclasses import dataclass +from typing import List @dataclass @@ -19,3 +19,7 @@ class ToolDefinition: @dataclass class ToolsData: tools: List[ToolDefinition] + description: str = "Description of the tools purpose and overall goal" + example_conversation_history: str = ( + "Example conversation history to help the AI agent understand the context of the conversation" + ) diff --git a/poetry.lock b/poetry.lock index a95f94d..560bc56 100644 --- a/poetry.lock +++ b/poetry.lock @@ -115,6 +115,17 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "distro" +version = "1.9.0" +description = "Distro - an OS platform information API" +optional = false +python-versions = ">=3.6" +files = [ + {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, + {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, +] + [[package]] name = "exceptiongroup" version = "1.2.2" @@ -245,6 +256,91 @@ files = [ [package.extras] colors = ["colorama (>=0.4.6)"] +[[package]] +name = "jiter" +version = "0.8.2" +description = "Fast iterable JSON parser." +optional = false +python-versions = ">=3.8" +files = [ + {file = "jiter-0.8.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:ca8577f6a413abe29b079bc30f907894d7eb07a865c4df69475e868d73e71c7b"}, + {file = "jiter-0.8.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b25bd626bde7fb51534190c7e3cb97cee89ee76b76d7585580e22f34f5e3f393"}, + {file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5c826a221851a8dc028eb6d7d6429ba03184fa3c7e83ae01cd6d3bd1d4bd17d"}, + {file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d35c864c2dff13dfd79fb070fc4fc6235d7b9b359efe340e1261deb21b9fcb66"}, + {file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f557c55bc2b7676e74d39d19bcb8775ca295c7a028246175d6a8b431e70835e5"}, + {file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:580ccf358539153db147e40751a0b41688a5ceb275e6f3e93d91c9467f42b2e3"}, + {file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af102d3372e917cffce49b521e4c32c497515119dc7bd8a75665e90a718bbf08"}, + {file = "jiter-0.8.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:cadcc978f82397d515bb2683fc0d50103acff2a180552654bb92d6045dec2c49"}, + {file = "jiter-0.8.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:ba5bdf56969cad2019d4e8ffd3f879b5fdc792624129741d3d83fc832fef8c7d"}, + {file = "jiter-0.8.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:3b94a33a241bee9e34b8481cdcaa3d5c2116f575e0226e421bed3f7a6ea71cff"}, + {file = "jiter-0.8.2-cp310-cp310-win32.whl", hash = "sha256:6e5337bf454abddd91bd048ce0dca5134056fc99ca0205258766db35d0a2ea43"}, + {file = "jiter-0.8.2-cp310-cp310-win_amd64.whl", hash = "sha256:4a9220497ca0cb1fe94e3f334f65b9b5102a0b8147646118f020d8ce1de70105"}, + {file = "jiter-0.8.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:2dd61c5afc88a4fda7d8b2cf03ae5947c6ac7516d32b7a15bf4b49569a5c076b"}, + {file = "jiter-0.8.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a6c710d657c8d1d2adbbb5c0b0c6bfcec28fd35bd6b5f016395f9ac43e878a15"}, + {file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a9584de0cd306072635fe4b89742bf26feae858a0683b399ad0c2509011b9dc0"}, + {file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5a90a923338531b7970abb063cfc087eebae6ef8ec8139762007188f6bc69a9f"}, + {file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d21974d246ed0181558087cd9f76e84e8321091ebfb3a93d4c341479a736f099"}, + {file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:32475a42b2ea7b344069dc1e81445cfc00b9d0e3ca837f0523072432332e9f74"}, + {file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b9931fd36ee513c26b5bf08c940b0ac875de175341cbdd4fa3be109f0492586"}, + {file = "jiter-0.8.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ce0820f4a3a59ddced7fce696d86a096d5cc48d32a4183483a17671a61edfddc"}, + {file = "jiter-0.8.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:8ffc86ae5e3e6a93765d49d1ab47b6075a9c978a2b3b80f0f32628f39caa0c88"}, + {file = "jiter-0.8.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5127dc1abd809431172bc3fbe8168d6b90556a30bb10acd5ded41c3cfd6f43b6"}, + {file = "jiter-0.8.2-cp311-cp311-win32.whl", hash = "sha256:66227a2c7b575720c1871c8800d3a0122bb8ee94edb43a5685aa9aceb2782d44"}, + {file = "jiter-0.8.2-cp311-cp311-win_amd64.whl", hash = "sha256:cde031d8413842a1e7501e9129b8e676e62a657f8ec8166e18a70d94d4682855"}, + {file = "jiter-0.8.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:e6ec2be506e7d6f9527dae9ff4b7f54e68ea44a0ef6b098256ddf895218a2f8f"}, + {file = "jiter-0.8.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:76e324da7b5da060287c54f2fabd3db5f76468006c811831f051942bf68c9d44"}, + {file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:180a8aea058f7535d1c84183c0362c710f4750bef66630c05f40c93c2b152a0f"}, + {file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:025337859077b41548bdcbabe38698bcd93cfe10b06ff66617a48ff92c9aec60"}, + {file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ecff0dc14f409599bbcafa7e470c00b80f17abc14d1405d38ab02e4b42e55b57"}, + {file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ffd9fee7d0775ebaba131f7ca2e2d83839a62ad65e8e02fe2bd8fc975cedeb9e"}, + {file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14601dcac4889e0a1c75ccf6a0e4baf70dbc75041e51bcf8d0e9274519df6887"}, + {file = "jiter-0.8.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:92249669925bc1c54fcd2ec73f70f2c1d6a817928480ee1c65af5f6b81cdf12d"}, + {file = "jiter-0.8.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e725edd0929fa79f8349ab4ec7f81c714df51dc4e991539a578e5018fa4a7152"}, + {file = "jiter-0.8.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:bf55846c7b7a680eebaf9c3c48d630e1bf51bdf76c68a5f654b8524335b0ad29"}, + {file = "jiter-0.8.2-cp312-cp312-win32.whl", hash = "sha256:7efe4853ecd3d6110301665a5178b9856be7e2a9485f49d91aa4d737ad2ae49e"}, + {file = "jiter-0.8.2-cp312-cp312-win_amd64.whl", hash = "sha256:83c0efd80b29695058d0fd2fa8a556490dbce9804eac3e281f373bbc99045f6c"}, + {file = "jiter-0.8.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:ca1f08b8e43dc3bd0594c992fb1fd2f7ce87f7bf0d44358198d6da8034afdf84"}, + {file = "jiter-0.8.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5672a86d55416ccd214c778efccf3266b84f87b89063b582167d803246354be4"}, + {file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58dc9bc9767a1101f4e5e22db1b652161a225874d66f0e5cb8e2c7d1c438b587"}, + {file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:37b2998606d6dadbb5ccda959a33d6a5e853252d921fec1792fc902351bb4e2c"}, + {file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4ab9a87f3784eb0e098f84a32670cfe4a79cb6512fd8f42ae3d0709f06405d18"}, + {file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:79aec8172b9e3c6d05fd4b219d5de1ac616bd8da934107325a6c0d0e866a21b6"}, + {file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:711e408732d4e9a0208008e5892c2966b485c783cd2d9a681f3eb147cf36c7ef"}, + {file = "jiter-0.8.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:653cf462db4e8c41995e33d865965e79641ef45369d8a11f54cd30888b7e6ff1"}, + {file = "jiter-0.8.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:9c63eaef32b7bebac8ebebf4dabebdbc6769a09c127294db6babee38e9f405b9"}, + {file = "jiter-0.8.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:eb21aaa9a200d0a80dacc7a81038d2e476ffe473ffdd9c91eb745d623561de05"}, + {file = "jiter-0.8.2-cp313-cp313-win32.whl", hash = "sha256:789361ed945d8d42850f919342a8665d2dc79e7e44ca1c97cc786966a21f627a"}, + {file = "jiter-0.8.2-cp313-cp313-win_amd64.whl", hash = "sha256:ab7f43235d71e03b941c1630f4b6e3055d46b6cb8728a17663eaac9d8e83a865"}, + {file = "jiter-0.8.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b426f72cd77da3fec300ed3bc990895e2dd6b49e3bfe6c438592a3ba660e41ca"}, + {file = "jiter-0.8.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2dd880785088ff2ad21ffee205e58a8c1ddabc63612444ae41e5e4b321b39c0"}, + {file = "jiter-0.8.2-cp313-cp313t-win_amd64.whl", hash = "sha256:3ac9f578c46f22405ff7f8b1f5848fb753cc4b8377fbec8470a7dc3997ca7566"}, + {file = "jiter-0.8.2-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:9e1fa156ee9454642adb7e7234a383884452532bc9d53d5af2d18d98ada1d79c"}, + {file = "jiter-0.8.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0cf5dfa9956d96ff2efb0f8e9c7d055904012c952539a774305aaaf3abdf3d6c"}, + {file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e52bf98c7e727dd44f7c4acb980cb988448faeafed8433c867888268899b298b"}, + {file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a2ecaa3c23e7a7cf86d00eda3390c232f4d533cd9ddea4b04f5d0644faf642c5"}, + {file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:08d4c92bf480e19fc3f2717c9ce2aa31dceaa9163839a311424b6862252c943e"}, + {file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:99d9a1eded738299ba8e106c6779ce5c3893cffa0e32e4485d680588adae6db8"}, + {file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d20be8b7f606df096e08b0b1b4a3c6f0515e8dac296881fe7461dfa0fb5ec817"}, + {file = "jiter-0.8.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d33f94615fcaf872f7fd8cd98ac3b429e435c77619777e8a449d9d27e01134d1"}, + {file = "jiter-0.8.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:317b25e98a35ffec5c67efe56a4e9970852632c810d35b34ecdd70cc0e47b3b6"}, + {file = "jiter-0.8.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fc9043259ee430ecd71d178fccabd8c332a3bf1e81e50cae43cc2b28d19e4cb7"}, + {file = "jiter-0.8.2-cp38-cp38-win32.whl", hash = "sha256:fc5adda618205bd4678b146612ce44c3cbfdee9697951f2c0ffdef1f26d72b63"}, + {file = "jiter-0.8.2-cp38-cp38-win_amd64.whl", hash = "sha256:cd646c827b4f85ef4a78e4e58f4f5854fae0caf3db91b59f0d73731448a970c6"}, + {file = "jiter-0.8.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:e41e75344acef3fc59ba4765df29f107f309ca9e8eace5baacabd9217e52a5ee"}, + {file = "jiter-0.8.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7f22b16b35d5c1df9dfd58843ab2cd25e6bf15191f5a236bed177afade507bfc"}, + {file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f7200b8f7619d36aa51c803fd52020a2dfbea36ffec1b5e22cab11fd34d95a6d"}, + {file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:70bf4c43652cc294040dbb62256c83c8718370c8b93dd93d934b9a7bf6c4f53c"}, + {file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f9d471356dc16f84ed48768b8ee79f29514295c7295cb41e1133ec0b2b8d637d"}, + {file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:859e8eb3507894093d01929e12e267f83b1d5f6221099d3ec976f0c995cb6bd9"}, + {file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaa58399c01db555346647a907b4ef6d4f584b123943be6ed5588c3f2359c9f4"}, + {file = "jiter-0.8.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8f2d5ed877f089862f4c7aacf3a542627c1496f972a34d0474ce85ee7d939c27"}, + {file = "jiter-0.8.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:03c9df035d4f8d647f8c210ddc2ae0728387275340668fb30d2421e17d9a0841"}, + {file = "jiter-0.8.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8bd2a824d08d8977bb2794ea2682f898ad3d8837932e3a74937e93d62ecbb637"}, + {file = "jiter-0.8.2-cp39-cp39-win32.whl", hash = "sha256:ca29b6371ebc40e496995c94b988a101b9fbbed48a51190a4461fcb0a68b4a36"}, + {file = "jiter-0.8.2-cp39-cp39-win_amd64.whl", hash = "sha256:1c0dfbd1be3cbefc7510102370d86e35d1d53e5a93d48519688b1bf0f761160a"}, + {file = "jiter-0.8.2.tar.gz", hash = "sha256:cd73d3e740666d0e639f678adb176fad25c1bcbdae88d8d7b857e1783bb4212d"}, +] + [[package]] name = "mypy-extensions" version = "1.0.0" @@ -271,6 +367,31 @@ files = [ httpx = ">=0.27.0,<0.28.0" pydantic = ">=2.9.0,<3.0.0" +[[package]] +name = "openai" +version = "1.59.2" +description = "The official Python library for the openai API" +optional = false +python-versions = ">=3.8" +files = [ + {file = "openai-1.59.2-py3-none-any.whl", hash = "sha256:3de721df4d2ccc5e845afa7235dce496bfbdd572692a876d2ae6211e0290ff22"}, + {file = "openai-1.59.2.tar.gz", hash = "sha256:1bf2d5e8a93533f6dd3fb7b1bcf082ddd4ae61cc6d89ca1343e5957e4720651c"}, +] + +[package.dependencies] +anyio = ">=3.5.0,<5" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +jiter = ">=0.4.0,<1" +pydantic = ">=1.9.0,<3" +sniffio = "*" +tqdm = ">4" +typing-extensions = ">=4.11,<5" + +[package.extras] +datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] +realtime = ["websockets (>=13,<15)"] + [[package]] name = "packaging" version = "24.2" @@ -512,6 +633,20 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = "python-dotenv" +version = "1.0.1" +description = "Read key-value pairs from a .env file and set them as environment variables" +optional = false +python-versions = ">=3.8" +files = [ + {file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"}, + {file = "python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"}, +] + +[package.extras] +cli = ["click (>=5.0)"] + [[package]] name = "pyyaml" version = "6.0.2" @@ -680,6 +815,27 @@ files = [ {file = "tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff"}, ] +[[package]] +name = "tqdm" +version = "4.67.1" +description = "Fast, Extensible Progress Meter" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2"}, + {file = "tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[package.extras] +dev = ["nbval", "pytest (>=6)", "pytest-asyncio (>=0.24)", "pytest-cov", "pytest-timeout"] +discord = ["requests"] +notebook = ["ipywidgets (>=6)"] +slack = ["slack-sdk"] +telegram = ["requests"] + [[package]] name = "types-protobuf" version = "5.29.1.20241207" @@ -724,4 +880,4 @@ standard = ["colorama (>=0.4)", "httptools (>=0.6.3)", "python-dotenv (>=0.13)", [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "20144fbc5773604251f9b61ac475eb9f292c6c8baf38d59170bbef81b482c71e" +content-hash = "ddef05a187e2f0364c8dc045d52aa30b7c95bc30075cfc0aa051af4fd1f8545b" diff --git a/prompts/agent_prompt_generators.py b/prompts/agent_prompt_generators.py index ea8b15b..76290d6 100644 --- a/prompts/agent_prompt_generators.py +++ b/prompts/agent_prompt_generators.py @@ -7,30 +7,49 @@ def generate_genai_prompt( tools_data: ToolsData, conversation_history: str, raw_json: Optional[str] = None ) -> str: """ - Generates a concise prompt for producing or validating JSON instructions. + Generates a concise prompt for producing or validating JSON instructions + with the provided tools and conversation history. """ prompt_lines = [] # Intro / Role prompt_lines.append( - "You are an AI assistant that must produce or validate JSON instructions " - "to properly call a set of tools. Respond with valid JSON only." + "You are an AI agent that helps fill required arguments for the tools described below. " + "You must respond with valid JSON ONLY, using the schema provided in the instructions." ) - # Conversation History + # Main Conversation History prompt_lines.append("=== Conversation History ===") prompt_lines.append( - "Use this history to understand needed tools, arguments, and the user's goals:" + "This is the ongoing history to determine which tool and arguments to gather:" ) prompt_lines.append("BEGIN CONVERSATION HISTORY") prompt_lines.append(json.dumps(conversation_history, indent=2)) prompt_lines.append("END CONVERSATION HISTORY") prompt_lines.append("") + # Example Conversation History (from tools_data) + if tools_data.example_conversation_history: + prompt_lines.append("=== Example Conversation With These Tools ===") + prompt_lines.append( + "Use this example to understand how tools are invoked and arguments are gathered." + ) + prompt_lines.append("BEGIN EXAMPLE") + prompt_lines.append(tools_data.example_conversation_history) + prompt_lines.append("END EXAMPLE") + prompt_lines.append("") + # Tools Definitions prompt_lines.append("=== Tools Definitions ===") prompt_lines.append(f"There are {len(tools_data.tools)} available tools:") prompt_lines.append(", ".join([t.name for t in tools_data.tools])) + prompt_lines.append(f"Goal: {tools_data.description}") + prompt_lines.append( + "Gather the necessary information for each tool in the sequence described above." + ) + prompt_lines.append( + "Only ask for arguments listed below. Do not add extra arguments." + ) prompt_lines.append("") for tool in tools_data.tools: prompt_lines.append(f"Tool name: {tool.name}") @@ -39,8 +58,11 @@ def generate_genai_prompt( for arg in tool.arguments: prompt_lines.append(f" - {arg.name} ({arg.type}): {arg.description}") prompt_lines.append("") + prompt_lines.append( + "When all required args for a tool are known, you can propose next='confirm' to run it." + ) - # Instructions for JSON Generation + # JSON Format Instructions prompt_lines.append("=== Instructions for JSON Generation ===") prompt_lines.append( "Your JSON format must be:\n" @@ -56,14 +78,14 @@ def generate_genai_prompt( "}" ) prompt_lines.append( - "1. You may call multiple tools sequentially. Each requires specific arguments.\n" - '2. If ANY required argument is missing, use "next": "question" and prompt the user.\n' - '3. If all required arguments are known, use "next": "confirm" and set "tool" to the tool name.\n' - '4. If no further actions are needed, use "next": "done" and "tool": "null".\n' - '5. Keep "response" short and user-friendly. Do not include any metadata or editorializing.\n' + "1) If any required argument is missing, set next='question' and ask the user.\n" + "2) If all required arguments are known, set next='confirm' and specify the tool.\n" + " The user will confirm before the tool is run.\n" + "3) If no more tools are needed, set next='done' and tool=null.\n" + "4) response should be short and user-friendly." ) - # Validation Task (Only if raw_json is provided) + # Validation Task (If raw_json is provided) if raw_json is not None: prompt_lines.append("") prompt_lines.append("=== Validation Task ===") @@ -71,60 +93,17 @@ def generate_genai_prompt( prompt_lines.append(json.dumps(raw_json, indent=2)) prompt_lines.append("") prompt_lines.append( - "Check syntax, ensure 'tool' is correct or 'null', verify 'args' are valid, " - 'and set "next" appropriately based on missing or complete args.' + "Check syntax, 'tool' validity, 'args' completeness, " + "and set 'next' appropriately. Return ONLY corrected JSON." ) - prompt_lines.append("Return only the corrected JSON, no extra text.") - - # Common Reminders and Examples - prompt_lines.append("") - prompt_lines.append("=== Usage Examples ===") - prompt_lines.append( - "Example for missing args (needs user input):\n" - "{\n" - ' "response": "I need your departure city.",\n' - ' "next": "question",\n' - ' "tool": "SearchFlights",\n' - ' "args": {\n' - ' "origin": null,\n' - ' "destination": "Melbourne",\n' - ' "dateDepart": "2025-03-26",\n' - ' "dateReturn": "2025-04-20"\n' - " }\n" - "}" - ) - prompt_lines.append( - "Example for confirmed args:\n" - "{\n" - ' "response": "All arguments are set.",\n' - ' "next": "confirm",\n' - ' "tool": "SearchFlights",\n' - ' "args": {\n' - ' "origin": "Seattle",\n' - ' "destination": "Melbourne",\n' - ' "dateDepart": "2025-03-26",\n' - ' "dateReturn": "2025-04-20"\n' - " }\n" - "}" - ) - prompt_lines.append( - "Example when fully done:\n" - "{\n" - ' "response": "All tools completed successfully. Final result: ",\n' - ' "next": "done",\n' - ' "tool": "",\n' - ' "args": {}\n' - "}" - ) # Prompt Start + prompt_lines.append("") if raw_json is not None: - prompt_lines.append("") prompt_lines.append("Begin by validating the provided JSON if necessary.") else: - prompt_lines.append("") prompt_lines.append( - "Begin by producing a valid JSON response for the next step." + "Begin by producing a valid JSON response for the next tool or question." ) return "\n".join(prompt_lines) diff --git a/pyproject.toml b/pyproject.toml index c19eff9..22f0d09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,8 @@ ollama = "^0.4.5" pyyaml = "^6.0.2" fastapi = "^0.115.6" uvicorn = "^0.34.0" +python-dotenv = "^1.0.1" +openai = "^1.59.2" [tool.poetry.group.dev.dependencies] pytest = "^7.3" diff --git a/scripts/run_worker.py b/scripts/run_worker.py index 48f4149..8850c48 100644 --- a/scripts/run_worker.py +++ b/scripts/run_worker.py @@ -7,6 +7,9 @@ from temporalio.worker import Worker from activities.tool_activities import ToolActivities, dynamic_tool_activity from workflows.tool_workflow import ToolWorkflow +from dotenv import load_dotenv + +load_dotenv() async def main(): diff --git a/scripts/send_message.py b/scripts/send_message.py index 0bbbde8..100e48e 100644 --- a/scripts/send_message.py +++ b/scripts/send_message.py @@ -3,13 +3,16 @@ import sys from temporalio.client import Client from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams -from tools.tool_registry import all_tools +from tools.tool_registry import event_travel_tools from workflows.tool_workflow import ToolWorkflow async def main(prompt: str): - # 1) Build the ToolsData from imported all_tools - tools_data = ToolsData(tools=all_tools) + # Build the ToolsData + tools_data = ToolsData( + tools=event_travel_tools, + description="Helps the user find an event to travel to, search flights, and create an invoice for those flights.", + ) # 2) Create combined input combined_input = CombinedInput( diff --git a/tools/find_events.py b/tools/find_events.py index 2d9d563..b28d936 100644 --- a/tools/find_events.py +++ b/tools/find_events.py @@ -1,17 +1,17 @@ def find_events(args: dict) -> dict: # Example: continent="Oceania", month="April" - continent = args.get("continent") + region = args.get("region") month = args.get("month") - print(f"[FindEvents] Searching events in {continent} for {month} ...") + print(f"[FindEvents] Searching events in {region} for {month} ...") # Stub result return { - "eventsFound": [ + "events": [ { "city": "Melbourne", "eventName": "Melbourne International Comedy Festival", - "dates": "2025-03-26 to 2025-04-20", + "dateFrom": "2025-03-26", + "dateTo": "2025-04-20", }, - ], - "status": "found-events", + ] } diff --git a/tools/search_flights.py b/tools/search_flights.py index 4efb3a2..b765ca8 100644 --- a/tools/search_flights.py +++ b/tools/search_flights.py @@ -4,21 +4,20 @@ def search_flights(args: dict) -> dict: Currently just prints/returns the passed args, but you can add real flight search logic later. """ - date_depart = args.get("dateDepart") - date_return = args.get("dateReturn") + # date_depart = args.get("dateDepart") + # date_return = args.get("dateReturn") origin = args.get("origin") destination = args.get("destination") - print(f"Searching flights from {origin} to {destination}") - print(f"Depart: {date_depart}, Return: {date_return}") - - # Return a mock result so you can verify it - return { - "tool": "SearchFlights", - "searchResults": [ - "QF123: $1200", - "VA456: $1000", + flight_search_results = { + "origin": f"{origin}", + "destination": f"{destination}", + "currency": "USD", + "results": [ + {"flight_number": "CX101", "return_flight_number": "CX102", "price": 850.0}, + {"flight_number": "QF30", "return_flight_number": "QF29", "price": 920.0}, + {"flight_number": "MH129", "return_flight_number": "MH128", "price": 780.0}, ], - "status": "search-complete", - "args": args, } + + return flight_search_results diff --git a/tools/tool_registry.py b/tools/tool_registry.py index 3997a75..1bd4971 100644 --- a/tools/tool_registry.py +++ b/tools/tool_registry.py @@ -2,12 +2,12 @@ from models.tool_definitions import ToolDefinition, ToolArgument find_events_tool = ToolDefinition( name="FindEvents", - description="Find upcoming events given a location or region (e.g., 'Oceania') and a date or month", + description="Find upcoming events to travel to given a location or region (e.g., 'Oceania') and a date or month", arguments=[ ToolArgument( - name="continent", + name="region", type="string", - description="Which continent or region to search for events", + description="Which region to search for events", ), ToolArgument( name="month", @@ -20,7 +20,7 @@ find_events_tool = ToolDefinition( # 2) Define the SearchFlights tool search_flights_tool = ToolDefinition( name="SearchFlights", - description="Search for return flights from an origin to a destination within a date range (dateDepart, dateReturn)", + description="Search for return flights from an origin to a destination within a date range (dateDepart, dateReturn).", arguments=[ ToolArgument( name="origin", @@ -48,7 +48,7 @@ search_flights_tool = ToolDefinition( # 3) Define the CreateInvoice tool create_invoice_tool = ToolDefinition( name="CreateInvoice", - description="Generate an invoice with flight information.", + description="Generate an invoice for the items described for the amount provided", arguments=[ ToolArgument( name="amount", @@ -58,9 +58,7 @@ create_invoice_tool = ToolDefinition( ToolArgument( name="flightDetails", type="string", - description="A summary of the flights, e.g., flight number and airport codes", + description="A description of the item details to be invoiced", ), ], ) - -all_tools = [find_events_tool, search_flights_tool, create_invoice_tool] diff --git a/workflows/tool_workflow.py b/workflows/tool_workflow.py index a2aadfb..712bebd 100644 --- a/workflows/tool_workflow.py +++ b/workflows/tool_workflow.py @@ -63,8 +63,8 @@ class ToolWorkflow: confirmed_tool_data = self.tool_data.copy() - confirmed_tool_data["next"] = "confirmed" - self.add_message("userToolConfirm", confirmed_tool_data) + confirmed_tool_data["next"] = "user_confirmed_tool_run" + self.add_message("user_confirmed_tool_run", confirmed_tool_data) # Run the tool workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}") @@ -81,10 +81,8 @@ class ToolWorkflow: # Enqueue a follow-up prompt for the LLM self.prompt_queue.append( f"### The '{current_tool}' tool completed successfully with {dynamic_result}. " - "INSTRUCTIONS: Use this tool result, and the conversation history to figure out next steps, if any. " - "IMPORTANT REMINDER: Always return only JSON in the format: {'response': '', 'next': '', 'tool': '', 'args': {}} " - " Do NOT include any metadata or editorializing in the response. " - "IMPORTANT: If moving on to another tool then ensure you ask next='question' for any missing arguments." + "INSTRUCTIONS: Use this tool result, the list of tools in sequence and the conversation history to figure out next steps, if any. " + "DON'T ask any clarifying questions that are outside of the tools and args specified. " ) # Loop around again continue @@ -102,11 +100,13 @@ class ToolWorkflow: tools_data, self.conversation_history, self.tool_data ) + # tools_list = ", ".join([t.name for t in tools_data.tools]) + prompt_input = ToolPromptInput( prompt=prompt, context_instructions=context_instructions, ) - tool_data = await workflow.execute_activity_method( + tool_data = await workflow.execute_activity( ToolActivities.prompt_llm, prompt_input, schedule_to_close_timeout=timedelta(seconds=60), @@ -115,13 +115,37 @@ class ToolWorkflow: ), ) self.tool_data = tool_data - self.add_message("response", tool_data) # Check the next step from LLM next_step = self.tool_data.get("next") current_tool = self.tool_data.get("tool") if next_step == "confirm" and current_tool: + # tmp arg check + args = self.tool_data.get("args") + + # check each argument for null values + missing_args = [] + for key, value in args.items(): + if value is None: + next_step = "question" + missing_args.append(key) + + if missing_args: + # self.add_message("response_confirm_missing_args", tool_data) + + # Enqueue a follow-up prompt for the LLM + self.prompt_queue.append( + f"### INSTRUCTIONS set next='question', combine this response response='{tool_data.get('response')}' and following missing arguments for tool {current_tool}: {missing_args}. " + "Only provide a valid JSON response without any comments or metadata." + ) + + workflow.logger.info( + f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}" + ) + # Loop around again + continue + waiting_for_confirm = True self.confirm = False # Clear any stale confirm workflow.logger.info("Waiting for user confirm signal...") @@ -130,8 +154,11 @@ class ToolWorkflow: elif next_step == "done": workflow.logger.info("All steps completed. Exiting workflow.") + self.add_message("agent", tool_data) return str(self.conversation_history) + self.add_message("agent", tool_data) + # Possibly continue-as-new after many turns # todo ensure this doesn't lose critical context if (