works a lot better with 4o!

This commit is contained in:
Steve Androulakis
2025-01-03 15:05:27 -08:00
parent 20d375b4ea
commit f5cf7286a2
16 changed files with 365 additions and 119 deletions

2
.gitignore vendored
View File

@@ -30,3 +30,5 @@ coverage.xml
# PyCharm / IntelliJ settings # PyCharm / IntelliJ settings
.idea/ .idea/
.env

View File

@@ -1,9 +1,12 @@
from dataclasses import dataclass from dataclasses import dataclass
from temporalio import activity from temporalio import activity
from ollama import chat, ChatResponse from ollama import chat, ChatResponse
from openai import OpenAI
import json import json
from typing import Sequence from typing import Sequence
from temporalio.common import RawValue from temporalio.common import RawValue
import os
from datetime import datetime
@dataclass @dataclass
@@ -15,6 +18,46 @@ class ToolPromptInput:
class ToolActivities: class ToolActivities:
@activity.defn @activity.defn
def prompt_llm(self, input: ToolPromptInput) -> dict: def prompt_llm(self, input: ToolPromptInput) -> dict:
client = OpenAI(
api_key=os.environ.get(
"OPENAI_API_KEY"
), # This is the default and can be omitted
)
messages = [
{
"role": "system",
"content": input.context_instructions
+ ". The current date is "
+ datetime.now().strftime("%B %d, %Y"),
},
{
"role": "user",
"content": input.prompt,
},
]
chat_completion = client.chat.completions.create(
model="gpt-4o", messages=messages # was gpt-4-0613
)
response_content = chat_completion.choices[0].message.content
print(f"ChatGPT response: {response_content}")
# Trim formatting markers if present
if response_content.startswith("```json") and response_content.endswith("```"):
response_content = response_content[7:-3].strip()
try:
data = json.loads(response_content)
except json.JSONDecodeError as e:
print(f"Invalid JSON: {e}")
raise json.JSONDecodeError
return data
@activity.defn
def prompt_llm_ollama(self, input: ToolPromptInput) -> dict:
model_name = "qwen2.5:14b" model_name = "qwen2.5:14b"
messages = [ messages = [
{ {

View File

@@ -4,7 +4,11 @@ from workflows.tool_workflow import ToolWorkflow
from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams
from temporalio.exceptions import TemporalError from temporalio.exceptions import TemporalError
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from tools.tool_registry import all_tools from tools.tool_registry import (
find_events_tool,
search_flights_tool,
create_invoice_tool,
)
app = FastAPI() app = FastAPI()
@@ -65,7 +69,34 @@ async def send_prompt(prompt: str):
client = await Client.connect("localhost:7233") client = await Client.connect("localhost:7233")
# Build the ToolsData # Build the ToolsData
tools_data = ToolsData(tools=all_tools) tools_data = ToolsData(
tools=[find_events_tool, search_flights_tool, create_invoice_tool],
description="Help the user gather args for these tools in order: "
"1. FindEvents: Find an event to travel to "
"2. SearchFlights: search for a flight around the event dates "
"3. GenerateInvoice: Create a simple invoice for the cost of that flight ",
example_conversation_history="\n ".join(
[
"user: I'd like to travel to an event",
"agent: Sure! Let's start by finding an event you'd like to attend. Could you tell me which region and month you're interested in?",
"user: In Sao Paulo, Brazil, in February",
"agent: Great! Let's find an events in Sao Paulo, Brazil in February.",
"user_confirmed_tool_run: <user clicks confirm on FindEvents tool>",
"tool_result: { 'event_name': 'Carnival', 'event_date': '2023-02-25' }",
"agent: Found an event! There's Carnival on 2023-02-25, ending on 2023-02-28. Would you like to search for flights around these dates?",
"user: Yes, please",
"agent: Let's search for flights around these dates. Could you provide your departure city?",
"user: New York",
"agent: Thanks, searching for flights from New York to Sao Paulo around 2023-02-25 to 2023-02-28.",
"user_confirmed_tool_run: <user clicks confirm on SearchFlights tool>"
'tool_result: results including {"flight_number": "CX101", "return_flight_number": "CX102", "price": 850.0}',
"agent: Found some flights! The cheapest is CX101 for $850. Would you like to generate an invoice for this flight?",
"user_confirmed_tool_run: <user clicks confirm on CreateInvoice tool>",
'tool_result: { "status": "success", "invoice": { "flight_number": "CX101", "amount": 850.0 }, invoiceURL: "https://example.com/invoice" }',
"agent: Invoice generated! Here's the link: https://example.com/invoice",
]
),
)
# Create combined input # Create combined input
combined_input = CombinedInput( combined_input = CombinedInput(

View File

@@ -19,17 +19,18 @@ export default function ChatWindow({ conversation, loading, onConfirm }) {
} }
const filtered = conversation.filter((msg) => { const filtered = conversation.filter((msg) => {
console.log(conversation[conversation.length - 1].actor)
const { actor, response } = msg; const { actor, response } = msg;
if (actor === "user") { if (actor === "user") {
return true; return true;
} }
if (actor === "response") { if (actor === "agent") {
const parsed = typeof response === "string" ? safeParse(response) : response; const parsed = typeof response === "string" ? safeParse(response) : response;
// Keep if next is "question", "confirm", or "confirmed". // Keep if next is "question", "confirm", or "user_confirmed_tool_run".
// Only skip if next is "done" (or something else). // Only skip if next is "done" (or something else).
return !["done"].includes(parsed.next); // return !["done"].includes(parsed.next);
return true;
} }
return false; return false;
}); });
@@ -37,13 +38,14 @@ export default function ChatWindow({ conversation, loading, onConfirm }) {
return ( return (
<div className="flex-grow overflow-y-auto space-y-4"> <div className="flex-grow overflow-y-auto space-y-4">
{filtered.map((msg, idx) => { {filtered.map((msg, idx) => {
const { actor, response } = msg; const { actor, response } = msg;
if (actor === "user") { if (actor === "user") {
return ( return (
<MessageBubble key={idx} message={{ response }} isUser /> <MessageBubble key={idx} message={{ response }} isUser />
); );
} else if (actor === "response") { } else if (actor === "agent") {
const data = const data =
typeof response === "string" ? safeParse(response) : response; typeof response === "string" ? safeParse(response) : response;
return <LLMResponse key={idx} data={data} onConfirm={onConfirm} />; return <LLMResponse key={idx} data={data} onConfirm={onConfirm} />;
@@ -57,16 +59,6 @@ export default function ChatWindow({ conversation, loading, onConfirm }) {
<LoadingIndicator /> <LoadingIndicator />
</div> </div>
)} )}
{conversation.length > 0 && conversation[conversation.length - 1].actor === "user" && (
<div className="flex justify-center">
<LoadingIndicator />
</div>
)}
{conversation.length > 0 && conversation[conversation.length - 1].actor === "tool_result_to_llm" && (
<div className="flex justify-center">
<LoadingIndicator />
</div>
)}
</div> </div>
); );
} }

View File

@@ -14,6 +14,10 @@ export default function LLMResponse({ data, onConfirm }) {
const requiresConfirm = data.next === "confirm"; const requiresConfirm = data.next === "confirm";
if (typeof data.response === "object") {
data.response = data.response.response;
}
let displayText = (data.response || "").trim(); let displayText = (data.response || "").trim();
if (!displayText && requiresConfirm) { if (!displayText && requiresConfirm) {
displayText = `Agent is ready to run "${data.tool}". Please confirm.`; displayText = `Agent is ready to run "${data.tool}". Please confirm.`;

View File

@@ -18,9 +18,12 @@ export default function App() {
const data = await res.json(); const data = await res.json();
// data is now an object like { messages: [ ... ] } // data is now an object like { messages: [ ... ] }
if (data.messages && data.messages.some(msg => msg.actor === "response" || msg.actor === "tool_result")) { if (data.messages && data.messages.length > 0 && (data.messages[data.messages.length - 1].actor === "agent")) {
setLoading(false); setLoading(false);
} }
else {
setLoading(true);
}
setConversation(data.messages || []); setConversation(data.messages || []);
} }
} catch (err) { } catch (err) {

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field from dataclasses import dataclass
from typing import List, Dict, Any from typing import List
@dataclass @dataclass
@@ -19,3 +19,7 @@ class ToolDefinition:
@dataclass @dataclass
class ToolsData: class ToolsData:
tools: List[ToolDefinition] tools: List[ToolDefinition]
description: str = "Description of the tools purpose and overall goal"
example_conversation_history: str = (
"Example conversation history to help the AI agent understand the context of the conversation"
)

158
poetry.lock generated
View File

@@ -115,6 +115,17 @@ files = [
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
] ]
[[package]]
name = "distro"
version = "1.9.0"
description = "Distro - an OS platform information API"
optional = false
python-versions = ">=3.6"
files = [
{file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"},
{file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"},
]
[[package]] [[package]]
name = "exceptiongroup" name = "exceptiongroup"
version = "1.2.2" version = "1.2.2"
@@ -245,6 +256,91 @@ files = [
[package.extras] [package.extras]
colors = ["colorama (>=0.4.6)"] colors = ["colorama (>=0.4.6)"]
[[package]]
name = "jiter"
version = "0.8.2"
description = "Fast iterable JSON parser."
optional = false
python-versions = ">=3.8"
files = [
{file = "jiter-0.8.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:ca8577f6a413abe29b079bc30f907894d7eb07a865c4df69475e868d73e71c7b"},
{file = "jiter-0.8.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b25bd626bde7fb51534190c7e3cb97cee89ee76b76d7585580e22f34f5e3f393"},
{file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5c826a221851a8dc028eb6d7d6429ba03184fa3c7e83ae01cd6d3bd1d4bd17d"},
{file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d35c864c2dff13dfd79fb070fc4fc6235d7b9b359efe340e1261deb21b9fcb66"},
{file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f557c55bc2b7676e74d39d19bcb8775ca295c7a028246175d6a8b431e70835e5"},
{file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:580ccf358539153db147e40751a0b41688a5ceb275e6f3e93d91c9467f42b2e3"},
{file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af102d3372e917cffce49b521e4c32c497515119dc7bd8a75665e90a718bbf08"},
{file = "jiter-0.8.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:cadcc978f82397d515bb2683fc0d50103acff2a180552654bb92d6045dec2c49"},
{file = "jiter-0.8.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:ba5bdf56969cad2019d4e8ffd3f879b5fdc792624129741d3d83fc832fef8c7d"},
{file = "jiter-0.8.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:3b94a33a241bee9e34b8481cdcaa3d5c2116f575e0226e421bed3f7a6ea71cff"},
{file = "jiter-0.8.2-cp310-cp310-win32.whl", hash = "sha256:6e5337bf454abddd91bd048ce0dca5134056fc99ca0205258766db35d0a2ea43"},
{file = "jiter-0.8.2-cp310-cp310-win_amd64.whl", hash = "sha256:4a9220497ca0cb1fe94e3f334f65b9b5102a0b8147646118f020d8ce1de70105"},
{file = "jiter-0.8.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:2dd61c5afc88a4fda7d8b2cf03ae5947c6ac7516d32b7a15bf4b49569a5c076b"},
{file = "jiter-0.8.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a6c710d657c8d1d2adbbb5c0b0c6bfcec28fd35bd6b5f016395f9ac43e878a15"},
{file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a9584de0cd306072635fe4b89742bf26feae858a0683b399ad0c2509011b9dc0"},
{file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5a90a923338531b7970abb063cfc087eebae6ef8ec8139762007188f6bc69a9f"},
{file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d21974d246ed0181558087cd9f76e84e8321091ebfb3a93d4c341479a736f099"},
{file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:32475a42b2ea7b344069dc1e81445cfc00b9d0e3ca837f0523072432332e9f74"},
{file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b9931fd36ee513c26b5bf08c940b0ac875de175341cbdd4fa3be109f0492586"},
{file = "jiter-0.8.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ce0820f4a3a59ddced7fce696d86a096d5cc48d32a4183483a17671a61edfddc"},
{file = "jiter-0.8.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:8ffc86ae5e3e6a93765d49d1ab47b6075a9c978a2b3b80f0f32628f39caa0c88"},
{file = "jiter-0.8.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5127dc1abd809431172bc3fbe8168d6b90556a30bb10acd5ded41c3cfd6f43b6"},
{file = "jiter-0.8.2-cp311-cp311-win32.whl", hash = "sha256:66227a2c7b575720c1871c8800d3a0122bb8ee94edb43a5685aa9aceb2782d44"},
{file = "jiter-0.8.2-cp311-cp311-win_amd64.whl", hash = "sha256:cde031d8413842a1e7501e9129b8e676e62a657f8ec8166e18a70d94d4682855"},
{file = "jiter-0.8.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:e6ec2be506e7d6f9527dae9ff4b7f54e68ea44a0ef6b098256ddf895218a2f8f"},
{file = "jiter-0.8.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:76e324da7b5da060287c54f2fabd3db5f76468006c811831f051942bf68c9d44"},
{file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:180a8aea058f7535d1c84183c0362c710f4750bef66630c05f40c93c2b152a0f"},
{file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:025337859077b41548bdcbabe38698bcd93cfe10b06ff66617a48ff92c9aec60"},
{file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ecff0dc14f409599bbcafa7e470c00b80f17abc14d1405d38ab02e4b42e55b57"},
{file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ffd9fee7d0775ebaba131f7ca2e2d83839a62ad65e8e02fe2bd8fc975cedeb9e"},
{file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14601dcac4889e0a1c75ccf6a0e4baf70dbc75041e51bcf8d0e9274519df6887"},
{file = "jiter-0.8.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:92249669925bc1c54fcd2ec73f70f2c1d6a817928480ee1c65af5f6b81cdf12d"},
{file = "jiter-0.8.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e725edd0929fa79f8349ab4ec7f81c714df51dc4e991539a578e5018fa4a7152"},
{file = "jiter-0.8.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:bf55846c7b7a680eebaf9c3c48d630e1bf51bdf76c68a5f654b8524335b0ad29"},
{file = "jiter-0.8.2-cp312-cp312-win32.whl", hash = "sha256:7efe4853ecd3d6110301665a5178b9856be7e2a9485f49d91aa4d737ad2ae49e"},
{file = "jiter-0.8.2-cp312-cp312-win_amd64.whl", hash = "sha256:83c0efd80b29695058d0fd2fa8a556490dbce9804eac3e281f373bbc99045f6c"},
{file = "jiter-0.8.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:ca1f08b8e43dc3bd0594c992fb1fd2f7ce87f7bf0d44358198d6da8034afdf84"},
{file = "jiter-0.8.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5672a86d55416ccd214c778efccf3266b84f87b89063b582167d803246354be4"},
{file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58dc9bc9767a1101f4e5e22db1b652161a225874d66f0e5cb8e2c7d1c438b587"},
{file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:37b2998606d6dadbb5ccda959a33d6a5e853252d921fec1792fc902351bb4e2c"},
{file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4ab9a87f3784eb0e098f84a32670cfe4a79cb6512fd8f42ae3d0709f06405d18"},
{file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:79aec8172b9e3c6d05fd4b219d5de1ac616bd8da934107325a6c0d0e866a21b6"},
{file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:711e408732d4e9a0208008e5892c2966b485c783cd2d9a681f3eb147cf36c7ef"},
{file = "jiter-0.8.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:653cf462db4e8c41995e33d865965e79641ef45369d8a11f54cd30888b7e6ff1"},
{file = "jiter-0.8.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:9c63eaef32b7bebac8ebebf4dabebdbc6769a09c127294db6babee38e9f405b9"},
{file = "jiter-0.8.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:eb21aaa9a200d0a80dacc7a81038d2e476ffe473ffdd9c91eb745d623561de05"},
{file = "jiter-0.8.2-cp313-cp313-win32.whl", hash = "sha256:789361ed945d8d42850f919342a8665d2dc79e7e44ca1c97cc786966a21f627a"},
{file = "jiter-0.8.2-cp313-cp313-win_amd64.whl", hash = "sha256:ab7f43235d71e03b941c1630f4b6e3055d46b6cb8728a17663eaac9d8e83a865"},
{file = "jiter-0.8.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b426f72cd77da3fec300ed3bc990895e2dd6b49e3bfe6c438592a3ba660e41ca"},
{file = "jiter-0.8.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2dd880785088ff2ad21ffee205e58a8c1ddabc63612444ae41e5e4b321b39c0"},
{file = "jiter-0.8.2-cp313-cp313t-win_amd64.whl", hash = "sha256:3ac9f578c46f22405ff7f8b1f5848fb753cc4b8377fbec8470a7dc3997ca7566"},
{file = "jiter-0.8.2-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:9e1fa156ee9454642adb7e7234a383884452532bc9d53d5af2d18d98ada1d79c"},
{file = "jiter-0.8.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0cf5dfa9956d96ff2efb0f8e9c7d055904012c952539a774305aaaf3abdf3d6c"},
{file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e52bf98c7e727dd44f7c4acb980cb988448faeafed8433c867888268899b298b"},
{file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a2ecaa3c23e7a7cf86d00eda3390c232f4d533cd9ddea4b04f5d0644faf642c5"},
{file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:08d4c92bf480e19fc3f2717c9ce2aa31dceaa9163839a311424b6862252c943e"},
{file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:99d9a1eded738299ba8e106c6779ce5c3893cffa0e32e4485d680588adae6db8"},
{file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d20be8b7f606df096e08b0b1b4a3c6f0515e8dac296881fe7461dfa0fb5ec817"},
{file = "jiter-0.8.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d33f94615fcaf872f7fd8cd98ac3b429e435c77619777e8a449d9d27e01134d1"},
{file = "jiter-0.8.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:317b25e98a35ffec5c67efe56a4e9970852632c810d35b34ecdd70cc0e47b3b6"},
{file = "jiter-0.8.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fc9043259ee430ecd71d178fccabd8c332a3bf1e81e50cae43cc2b28d19e4cb7"},
{file = "jiter-0.8.2-cp38-cp38-win32.whl", hash = "sha256:fc5adda618205bd4678b146612ce44c3cbfdee9697951f2c0ffdef1f26d72b63"},
{file = "jiter-0.8.2-cp38-cp38-win_amd64.whl", hash = "sha256:cd646c827b4f85ef4a78e4e58f4f5854fae0caf3db91b59f0d73731448a970c6"},
{file = "jiter-0.8.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:e41e75344acef3fc59ba4765df29f107f309ca9e8eace5baacabd9217e52a5ee"},
{file = "jiter-0.8.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7f22b16b35d5c1df9dfd58843ab2cd25e6bf15191f5a236bed177afade507bfc"},
{file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f7200b8f7619d36aa51c803fd52020a2dfbea36ffec1b5e22cab11fd34d95a6d"},
{file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:70bf4c43652cc294040dbb62256c83c8718370c8b93dd93d934b9a7bf6c4f53c"},
{file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f9d471356dc16f84ed48768b8ee79f29514295c7295cb41e1133ec0b2b8d637d"},
{file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:859e8eb3507894093d01929e12e267f83b1d5f6221099d3ec976f0c995cb6bd9"},
{file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaa58399c01db555346647a907b4ef6d4f584b123943be6ed5588c3f2359c9f4"},
{file = "jiter-0.8.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8f2d5ed877f089862f4c7aacf3a542627c1496f972a34d0474ce85ee7d939c27"},
{file = "jiter-0.8.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:03c9df035d4f8d647f8c210ddc2ae0728387275340668fb30d2421e17d9a0841"},
{file = "jiter-0.8.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8bd2a824d08d8977bb2794ea2682f898ad3d8837932e3a74937e93d62ecbb637"},
{file = "jiter-0.8.2-cp39-cp39-win32.whl", hash = "sha256:ca29b6371ebc40e496995c94b988a101b9fbbed48a51190a4461fcb0a68b4a36"},
{file = "jiter-0.8.2-cp39-cp39-win_amd64.whl", hash = "sha256:1c0dfbd1be3cbefc7510102370d86e35d1d53e5a93d48519688b1bf0f761160a"},
{file = "jiter-0.8.2.tar.gz", hash = "sha256:cd73d3e740666d0e639f678adb176fad25c1bcbdae88d8d7b857e1783bb4212d"},
]
[[package]] [[package]]
name = "mypy-extensions" name = "mypy-extensions"
version = "1.0.0" version = "1.0.0"
@@ -271,6 +367,31 @@ files = [
httpx = ">=0.27.0,<0.28.0" httpx = ">=0.27.0,<0.28.0"
pydantic = ">=2.9.0,<3.0.0" pydantic = ">=2.9.0,<3.0.0"
[[package]]
name = "openai"
version = "1.59.2"
description = "The official Python library for the openai API"
optional = false
python-versions = ">=3.8"
files = [
{file = "openai-1.59.2-py3-none-any.whl", hash = "sha256:3de721df4d2ccc5e845afa7235dce496bfbdd572692a876d2ae6211e0290ff22"},
{file = "openai-1.59.2.tar.gz", hash = "sha256:1bf2d5e8a93533f6dd3fb7b1bcf082ddd4ae61cc6d89ca1343e5957e4720651c"},
]
[package.dependencies]
anyio = ">=3.5.0,<5"
distro = ">=1.7.0,<2"
httpx = ">=0.23.0,<1"
jiter = ">=0.4.0,<1"
pydantic = ">=1.9.0,<3"
sniffio = "*"
tqdm = ">4"
typing-extensions = ">=4.11,<5"
[package.extras]
datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
realtime = ["websockets (>=13,<15)"]
[[package]] [[package]]
name = "packaging" name = "packaging"
version = "24.2" version = "24.2"
@@ -512,6 +633,20 @@ files = [
[package.dependencies] [package.dependencies]
six = ">=1.5" six = ">=1.5"
[[package]]
name = "python-dotenv"
version = "1.0.1"
description = "Read key-value pairs from a .env file and set them as environment variables"
optional = false
python-versions = ">=3.8"
files = [
{file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"},
{file = "python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"},
]
[package.extras]
cli = ["click (>=5.0)"]
[[package]] [[package]]
name = "pyyaml" name = "pyyaml"
version = "6.0.2" version = "6.0.2"
@@ -680,6 +815,27 @@ files = [
{file = "tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff"}, {file = "tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff"},
] ]
[[package]]
name = "tqdm"
version = "4.67.1"
description = "Fast, Extensible Progress Meter"
optional = false
python-versions = ">=3.7"
files = [
{file = "tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2"},
{file = "tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2"},
]
[package.dependencies]
colorama = {version = "*", markers = "platform_system == \"Windows\""}
[package.extras]
dev = ["nbval", "pytest (>=6)", "pytest-asyncio (>=0.24)", "pytest-cov", "pytest-timeout"]
discord = ["requests"]
notebook = ["ipywidgets (>=6)"]
slack = ["slack-sdk"]
telegram = ["requests"]
[[package]] [[package]]
name = "types-protobuf" name = "types-protobuf"
version = "5.29.1.20241207" version = "5.29.1.20241207"
@@ -724,4 +880,4 @@ standard = ["colorama (>=0.4)", "httptools (>=0.6.3)", "python-dotenv (>=0.13)",
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.9,<4.0" python-versions = ">=3.9,<4.0"
content-hash = "20144fbc5773604251f9b61ac475eb9f292c6c8baf38d59170bbef81b482c71e" content-hash = "ddef05a187e2f0364c8dc045d52aa30b7c95bc30075cfc0aa051af4fd1f8545b"

View File

@@ -7,30 +7,49 @@ def generate_genai_prompt(
tools_data: ToolsData, conversation_history: str, raw_json: Optional[str] = None tools_data: ToolsData, conversation_history: str, raw_json: Optional[str] = None
) -> str: ) -> str:
""" """
Generates a concise prompt for producing or validating JSON instructions. Generates a concise prompt for producing or validating JSON instructions
with the provided tools and conversation history.
""" """
prompt_lines = [] prompt_lines = []
# Intro / Role # Intro / Role
prompt_lines.append( prompt_lines.append(
"You are an AI assistant that must produce or validate JSON instructions " "You are an AI agent that helps fill required arguments for the tools described below. "
"to properly call a set of tools. Respond with valid JSON only." "You must respond with valid JSON ONLY, using the schema provided in the instructions."
) )
# Conversation History # Main Conversation History
prompt_lines.append("=== Conversation History ===") prompt_lines.append("=== Conversation History ===")
prompt_lines.append( prompt_lines.append(
"Use this history to understand needed tools, arguments, and the user's goals:" "This is the ongoing history to determine which tool and arguments to gather:"
) )
prompt_lines.append("BEGIN CONVERSATION HISTORY") prompt_lines.append("BEGIN CONVERSATION HISTORY")
prompt_lines.append(json.dumps(conversation_history, indent=2)) prompt_lines.append(json.dumps(conversation_history, indent=2))
prompt_lines.append("END CONVERSATION HISTORY") prompt_lines.append("END CONVERSATION HISTORY")
prompt_lines.append("") prompt_lines.append("")
# Example Conversation History (from tools_data)
if tools_data.example_conversation_history:
prompt_lines.append("=== Example Conversation With These Tools ===")
prompt_lines.append(
"Use this example to understand how tools are invoked and arguments are gathered."
)
prompt_lines.append("BEGIN EXAMPLE")
prompt_lines.append(tools_data.example_conversation_history)
prompt_lines.append("END EXAMPLE")
prompt_lines.append("")
# Tools Definitions # Tools Definitions
prompt_lines.append("=== Tools Definitions ===") prompt_lines.append("=== Tools Definitions ===")
prompt_lines.append(f"There are {len(tools_data.tools)} available tools:") prompt_lines.append(f"There are {len(tools_data.tools)} available tools:")
prompt_lines.append(", ".join([t.name for t in tools_data.tools])) prompt_lines.append(", ".join([t.name for t in tools_data.tools]))
prompt_lines.append(f"Goal: {tools_data.description}")
prompt_lines.append(
"Gather the necessary information for each tool in the sequence described above."
)
prompt_lines.append(
"Only ask for arguments listed below. Do not add extra arguments."
)
prompt_lines.append("") prompt_lines.append("")
for tool in tools_data.tools: for tool in tools_data.tools:
prompt_lines.append(f"Tool name: {tool.name}") prompt_lines.append(f"Tool name: {tool.name}")
@@ -39,8 +58,11 @@ def generate_genai_prompt(
for arg in tool.arguments: for arg in tool.arguments:
prompt_lines.append(f" - {arg.name} ({arg.type}): {arg.description}") prompt_lines.append(f" - {arg.name} ({arg.type}): {arg.description}")
prompt_lines.append("") prompt_lines.append("")
prompt_lines.append(
"When all required args for a tool are known, you can propose next='confirm' to run it."
)
# Instructions for JSON Generation # JSON Format Instructions
prompt_lines.append("=== Instructions for JSON Generation ===") prompt_lines.append("=== Instructions for JSON Generation ===")
prompt_lines.append( prompt_lines.append(
"Your JSON format must be:\n" "Your JSON format must be:\n"
@@ -56,14 +78,14 @@ def generate_genai_prompt(
"}" "}"
) )
prompt_lines.append( prompt_lines.append(
"1. You may call multiple tools sequentially. Each requires specific arguments.\n" "1) If any required argument is missing, set next='question' and ask the user.\n"
'2. If ANY required argument is missing, use "next": "question" and prompt the user.\n' "2) If all required arguments are known, set next='confirm' and specify the tool.\n"
'3. If all required arguments are known, use "next": "confirm" and set "tool" to the tool name.\n' " The user will confirm before the tool is run.\n"
'4. If no further actions are needed, use "next": "done" and "tool": "null".\n' "3) If no more tools are needed, set next='done' and tool=null.\n"
'5. Keep "response" short and user-friendly. Do not include any metadata or editorializing.\n' "4) response should be short and user-friendly."
) )
# Validation Task (Only if raw_json is provided) # Validation Task (If raw_json is provided)
if raw_json is not None: if raw_json is not None:
prompt_lines.append("") prompt_lines.append("")
prompt_lines.append("=== Validation Task ===") prompt_lines.append("=== Validation Task ===")
@@ -71,60 +93,17 @@ def generate_genai_prompt(
prompt_lines.append(json.dumps(raw_json, indent=2)) prompt_lines.append(json.dumps(raw_json, indent=2))
prompt_lines.append("") prompt_lines.append("")
prompt_lines.append( prompt_lines.append(
"Check syntax, ensure 'tool' is correct or 'null', verify 'args' are valid, " "Check syntax, 'tool' validity, 'args' completeness, "
'and set "next" appropriately based on missing or complete args.' "and set 'next' appropriately. Return ONLY corrected JSON."
)
prompt_lines.append("Return only the corrected JSON, no extra text.")
# Common Reminders and Examples
prompt_lines.append("")
prompt_lines.append("=== Usage Examples ===")
prompt_lines.append(
"Example for missing args (needs user input):\n"
"{\n"
' "response": "I need your departure city.",\n'
' "next": "question",\n'
' "tool": "SearchFlights",\n'
' "args": {\n'
' "origin": null,\n'
' "destination": "Melbourne",\n'
' "dateDepart": "2025-03-26",\n'
' "dateReturn": "2025-04-20"\n'
" }\n"
"}"
)
prompt_lines.append(
"Example for confirmed args:\n"
"{\n"
' "response": "All arguments are set.",\n'
' "next": "confirm",\n'
' "tool": "SearchFlights",\n'
' "args": {\n'
' "origin": "Seattle",\n'
' "destination": "Melbourne",\n'
' "dateDepart": "2025-03-26",\n'
' "dateReturn": "2025-04-20"\n'
" }\n"
"}"
)
prompt_lines.append(
"Example when fully done:\n"
"{\n"
' "response": "All tools completed successfully. Final result: <insert result here>",\n'
' "next": "done",\n'
' "tool": "",\n'
' "args": {}\n'
"}"
) )
# Prompt Start # Prompt Start
if raw_json is not None:
prompt_lines.append("") prompt_lines.append("")
if raw_json is not None:
prompt_lines.append("Begin by validating the provided JSON if necessary.") prompt_lines.append("Begin by validating the provided JSON if necessary.")
else: else:
prompt_lines.append("")
prompt_lines.append( prompt_lines.append(
"Begin by producing a valid JSON response for the next step." "Begin by producing a valid JSON response for the next tool or question."
) )
return "\n".join(prompt_lines) return "\n".join(prompt_lines)

View File

@@ -25,6 +25,8 @@ ollama = "^0.4.5"
pyyaml = "^6.0.2" pyyaml = "^6.0.2"
fastapi = "^0.115.6" fastapi = "^0.115.6"
uvicorn = "^0.34.0" uvicorn = "^0.34.0"
python-dotenv = "^1.0.1"
openai = "^1.59.2"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
pytest = "^7.3" pytest = "^7.3"

View File

@@ -7,6 +7,9 @@ from temporalio.worker import Worker
from activities.tool_activities import ToolActivities, dynamic_tool_activity from activities.tool_activities import ToolActivities, dynamic_tool_activity
from workflows.tool_workflow import ToolWorkflow from workflows.tool_workflow import ToolWorkflow
from dotenv import load_dotenv
load_dotenv()
async def main(): async def main():

View File

@@ -3,13 +3,16 @@ import sys
from temporalio.client import Client from temporalio.client import Client
from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams
from tools.tool_registry import all_tools from tools.tool_registry import event_travel_tools
from workflows.tool_workflow import ToolWorkflow from workflows.tool_workflow import ToolWorkflow
async def main(prompt: str): async def main(prompt: str):
# 1) Build the ToolsData from imported all_tools # Build the ToolsData
tools_data = ToolsData(tools=all_tools) tools_data = ToolsData(
tools=event_travel_tools,
description="Helps the user find an event to travel to, search flights, and create an invoice for those flights.",
)
# 2) Create combined input # 2) Create combined input
combined_input = CombinedInput( combined_input = CombinedInput(

View File

@@ -1,17 +1,17 @@
def find_events(args: dict) -> dict: def find_events(args: dict) -> dict:
# Example: continent="Oceania", month="April" # Example: continent="Oceania", month="April"
continent = args.get("continent") region = args.get("region")
month = args.get("month") month = args.get("month")
print(f"[FindEvents] Searching events in {continent} for {month} ...") print(f"[FindEvents] Searching events in {region} for {month} ...")
# Stub result # Stub result
return { return {
"eventsFound": [ "events": [
{ {
"city": "Melbourne", "city": "Melbourne",
"eventName": "Melbourne International Comedy Festival", "eventName": "Melbourne International Comedy Festival",
"dates": "2025-03-26 to 2025-04-20", "dateFrom": "2025-03-26",
"dateTo": "2025-04-20",
}, },
], ]
"status": "found-events",
} }

View File

@@ -4,21 +4,20 @@ def search_flights(args: dict) -> dict:
Currently just prints/returns the passed args, Currently just prints/returns the passed args,
but you can add real flight search logic later. but you can add real flight search logic later.
""" """
date_depart = args.get("dateDepart") # date_depart = args.get("dateDepart")
date_return = args.get("dateReturn") # date_return = args.get("dateReturn")
origin = args.get("origin") origin = args.get("origin")
destination = args.get("destination") destination = args.get("destination")
print(f"Searching flights from {origin} to {destination}") flight_search_results = {
print(f"Depart: {date_depart}, Return: {date_return}") "origin": f"{origin}",
"destination": f"{destination}",
# Return a mock result so you can verify it "currency": "USD",
return { "results": [
"tool": "SearchFlights", {"flight_number": "CX101", "return_flight_number": "CX102", "price": 850.0},
"searchResults": [ {"flight_number": "QF30", "return_flight_number": "QF29", "price": 920.0},
"QF123: $1200", {"flight_number": "MH129", "return_flight_number": "MH128", "price": 780.0},
"VA456: $1000",
], ],
"status": "search-complete",
"args": args,
} }
return flight_search_results

View File

@@ -2,12 +2,12 @@ from models.tool_definitions import ToolDefinition, ToolArgument
find_events_tool = ToolDefinition( find_events_tool = ToolDefinition(
name="FindEvents", name="FindEvents",
description="Find upcoming events given a location or region (e.g., 'Oceania') and a date or month", description="Find upcoming events to travel to given a location or region (e.g., 'Oceania') and a date or month",
arguments=[ arguments=[
ToolArgument( ToolArgument(
name="continent", name="region",
type="string", type="string",
description="Which continent or region to search for events", description="Which region to search for events",
), ),
ToolArgument( ToolArgument(
name="month", name="month",
@@ -20,7 +20,7 @@ find_events_tool = ToolDefinition(
# 2) Define the SearchFlights tool # 2) Define the SearchFlights tool
search_flights_tool = ToolDefinition( search_flights_tool = ToolDefinition(
name="SearchFlights", name="SearchFlights",
description="Search for return flights from an origin to a destination within a date range (dateDepart, dateReturn)", description="Search for return flights from an origin to a destination within a date range (dateDepart, dateReturn).",
arguments=[ arguments=[
ToolArgument( ToolArgument(
name="origin", name="origin",
@@ -48,7 +48,7 @@ search_flights_tool = ToolDefinition(
# 3) Define the CreateInvoice tool # 3) Define the CreateInvoice tool
create_invoice_tool = ToolDefinition( create_invoice_tool = ToolDefinition(
name="CreateInvoice", name="CreateInvoice",
description="Generate an invoice with flight information.", description="Generate an invoice for the items described for the amount provided",
arguments=[ arguments=[
ToolArgument( ToolArgument(
name="amount", name="amount",
@@ -58,9 +58,7 @@ create_invoice_tool = ToolDefinition(
ToolArgument( ToolArgument(
name="flightDetails", name="flightDetails",
type="string", type="string",
description="A summary of the flights, e.g., flight number and airport codes", description="A description of the item details to be invoiced",
), ),
], ],
) )
all_tools = [find_events_tool, search_flights_tool, create_invoice_tool]

View File

@@ -63,8 +63,8 @@ class ToolWorkflow:
confirmed_tool_data = self.tool_data.copy() confirmed_tool_data = self.tool_data.copy()
confirmed_tool_data["next"] = "confirmed" confirmed_tool_data["next"] = "user_confirmed_tool_run"
self.add_message("userToolConfirm", confirmed_tool_data) self.add_message("user_confirmed_tool_run", confirmed_tool_data)
# Run the tool # Run the tool
workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}") workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}")
@@ -81,10 +81,8 @@ class ToolWorkflow:
# Enqueue a follow-up prompt for the LLM # Enqueue a follow-up prompt for the LLM
self.prompt_queue.append( self.prompt_queue.append(
f"### The '{current_tool}' tool completed successfully with {dynamic_result}. " f"### The '{current_tool}' tool completed successfully with {dynamic_result}. "
"INSTRUCTIONS: Use this tool result, and the conversation history to figure out next steps, if any. " "INSTRUCTIONS: Use this tool result, the list of tools in sequence and the conversation history to figure out next steps, if any. "
"IMPORTANT REMINDER: Always return only JSON in the format: {'response': '', 'next': '', 'tool': '', 'args': {}} " "DON'T ask any clarifying questions that are outside of the tools and args specified. "
" Do NOT include any metadata or editorializing in the response. "
"IMPORTANT: If moving on to another tool then ensure you ask next='question' for any missing arguments."
) )
# Loop around again # Loop around again
continue continue
@@ -102,11 +100,13 @@ class ToolWorkflow:
tools_data, self.conversation_history, self.tool_data tools_data, self.conversation_history, self.tool_data
) )
# tools_list = ", ".join([t.name for t in tools_data.tools])
prompt_input = ToolPromptInput( prompt_input = ToolPromptInput(
prompt=prompt, prompt=prompt,
context_instructions=context_instructions, context_instructions=context_instructions,
) )
tool_data = await workflow.execute_activity_method( tool_data = await workflow.execute_activity(
ToolActivities.prompt_llm, ToolActivities.prompt_llm,
prompt_input, prompt_input,
schedule_to_close_timeout=timedelta(seconds=60), schedule_to_close_timeout=timedelta(seconds=60),
@@ -115,13 +115,37 @@ class ToolWorkflow:
), ),
) )
self.tool_data = tool_data self.tool_data = tool_data
self.add_message("response", tool_data)
# Check the next step from LLM # Check the next step from LLM
next_step = self.tool_data.get("next") next_step = self.tool_data.get("next")
current_tool = self.tool_data.get("tool") current_tool = self.tool_data.get("tool")
if next_step == "confirm" and current_tool: if next_step == "confirm" and current_tool:
# tmp arg check
args = self.tool_data.get("args")
# check each argument for null values
missing_args = []
for key, value in args.items():
if value is None:
next_step = "question"
missing_args.append(key)
if missing_args:
# self.add_message("response_confirm_missing_args", tool_data)
# Enqueue a follow-up prompt for the LLM
self.prompt_queue.append(
f"### INSTRUCTIONS set next='question', combine this response response='{tool_data.get('response')}' and following missing arguments for tool {current_tool}: {missing_args}. "
"Only provide a valid JSON response without any comments or metadata."
)
workflow.logger.info(
f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}"
)
# Loop around again
continue
waiting_for_confirm = True waiting_for_confirm = True
self.confirm = False # Clear any stale confirm self.confirm = False # Clear any stale confirm
workflow.logger.info("Waiting for user confirm signal...") workflow.logger.info("Waiting for user confirm signal...")
@@ -130,8 +154,11 @@ class ToolWorkflow:
elif next_step == "done": elif next_step == "done":
workflow.logger.info("All steps completed. Exiting workflow.") workflow.logger.info("All steps completed. Exiting workflow.")
self.add_message("agent", tool_data)
return str(self.conversation_history) return str(self.conversation_history)
self.add_message("agent", tool_data)
# Possibly continue-as-new after many turns # Possibly continue-as-new after many turns
# todo ensure this doesn't lose critical context # todo ensure this doesn't lose critical context
if ( if (