works a lot better with 4o!

This commit is contained in:
Steve Androulakis
2025-01-03 15:05:27 -08:00
parent 20d375b4ea
commit f5cf7286a2
16 changed files with 365 additions and 119 deletions

4
.gitignore vendored
View File

@@ -29,4 +29,6 @@ coverage.xml
.vscode/
# PyCharm / IntelliJ settings
.idea/
.idea/
.env

View File

@@ -1,9 +1,12 @@
from dataclasses import dataclass
from temporalio import activity
from ollama import chat, ChatResponse
from openai import OpenAI
import json
from typing import Sequence
from temporalio.common import RawValue
import os
from datetime import datetime
@dataclass
@@ -15,6 +18,46 @@ class ToolPromptInput:
class ToolActivities:
@activity.defn
def prompt_llm(self, input: ToolPromptInput) -> dict:
    """Send the prompt to OpenAI chat completions and return the reply parsed as JSON.

    Args:
        input: Carries the user prompt plus the system context instructions.

    Returns:
        The LLM response parsed into a dict.

    Raises:
        json.JSONDecodeError: if the LLM reply is not valid JSON.
    """
    client = OpenAI(
        # OPENAI_API_KEY is the client's default env var and could be omitted;
        # kept explicit so the dependency on the variable is visible here.
        api_key=os.environ.get("OPENAI_API_KEY"),
    )
    messages = [
        {
            "role": "system",
            # Inject today's date so the model can resolve relative dates.
            "content": input.context_instructions
            + ". The current date is "
            + datetime.now().strftime("%B %d, %Y"),
        },
        {
            "role": "user",
            "content": input.prompt,
        },
    ]
    chat_completion = client.chat.completions.create(
        model="gpt-4o", messages=messages  # was gpt-4-0613
    )
    response_content = chat_completion.choices[0].message.content
    print(f"ChatGPT response: {response_content}")
    # Trim Markdown code-fence markers if the model wrapped the JSON.
    if response_content.startswith("```json") and response_content.endswith("```"):
        response_content = response_content[7:-3].strip()
    try:
        data = json.loads(response_content)
    except json.JSONDecodeError as e:
        print(f"Invalid JSON: {e}")
        # BUG FIX: `raise json.JSONDecodeError` raised the bare class, which
        # fails with TypeError (JSONDecodeError requires msg/doc/pos) and loses
        # the original error; re-raise the caught exception instead.
        raise
    return data
@activity.defn
def prompt_llm_ollama(self, input: ToolPromptInput) -> dict:
model_name = "qwen2.5:14b"
messages = [
{

View File

@@ -4,7 +4,11 @@ from workflows.tool_workflow import ToolWorkflow
from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams
from temporalio.exceptions import TemporalError
from fastapi.middleware.cors import CORSMiddleware
from tools.tool_registry import all_tools
from tools.tool_registry import (
find_events_tool,
search_flights_tool,
create_invoice_tool,
)
app = FastAPI()
@@ -65,7 +69,34 @@ async def send_prompt(prompt: str):
client = await Client.connect("localhost:7233")
# Build the ToolsData
tools_data = ToolsData(tools=all_tools)
tools_data = ToolsData(
tools=[find_events_tool, search_flights_tool, create_invoice_tool],
description="Help the user gather args for these tools in order: "
"1. FindEvents: Find an event to travel to "
"2. SearchFlights: search for a flight around the event dates "
"3. GenerateInvoice: Create a simple invoice for the cost of that flight ",
example_conversation_history="\n ".join(
[
"user: I'd like to travel to an event",
"agent: Sure! Let's start by finding an event you'd like to attend. Could you tell me which region and month you're interested in?",
"user: In Sao Paulo, Brazil, in February",
"agent: Great! Let's find an events in Sao Paulo, Brazil in February.",
"user_confirmed_tool_run: <user clicks confirm on FindEvents tool>",
"tool_result: { 'event_name': 'Carnival', 'event_date': '2023-02-25' }",
"agent: Found an event! There's Carnival on 2023-02-25, ending on 2023-02-28. Would you like to search for flights around these dates?",
"user: Yes, please",
"agent: Let's search for flights around these dates. Could you provide your departure city?",
"user: New York",
"agent: Thanks, searching for flights from New York to Sao Paulo around 2023-02-25 to 2023-02-28.",
"user_confirmed_tool_run: <user clicks confirm on SearchFlights tool>"
'tool_result: results including {"flight_number": "CX101", "return_flight_number": "CX102", "price": 850.0}',
"agent: Found some flights! The cheapest is CX101 for $850. Would you like to generate an invoice for this flight?",
"user_confirmed_tool_run: <user clicks confirm on CreateInvoice tool>",
'tool_result: { "status": "success", "invoice": { "flight_number": "CX101", "amount": 850.0 }, invoiceURL: "https://example.com/invoice" }',
"agent: Invoice generated! Here's the link: https://example.com/invoice",
]
),
)
# Create combined input
combined_input = CombinedInput(

View File

@@ -19,17 +19,18 @@ export default function ChatWindow({ conversation, loading, onConfirm }) {
}
const filtered = conversation.filter((msg) => {
console.log(conversation[conversation.length - 1].actor)
const { actor, response } = msg;
if (actor === "user") {
return true;
}
if (actor === "response") {
if (actor === "agent") {
const parsed = typeof response === "string" ? safeParse(response) : response;
// Keep if next is "question", "confirm", or "confirmed".
// Keep if next is "question", "confirm", or "user_confirmed_tool_run".
// Only skip if next is "done" (or something else).
return !["done"].includes(parsed.next);
// return !["done"].includes(parsed.next);
return true;
}
return false;
});
@@ -37,13 +38,14 @@ export default function ChatWindow({ conversation, loading, onConfirm }) {
return (
<div className="flex-grow overflow-y-auto space-y-4">
{filtered.map((msg, idx) => {
const { actor, response } = msg;
if (actor === "user") {
return (
<MessageBubble key={idx} message={{ response }} isUser />
);
} else if (actor === "response") {
} else if (actor === "agent") {
const data =
typeof response === "string" ? safeParse(response) : response;
return <LLMResponse key={idx} data={data} onConfirm={onConfirm} />;
@@ -57,16 +59,6 @@ export default function ChatWindow({ conversation, loading, onConfirm }) {
<LoadingIndicator />
</div>
)}
{conversation.length > 0 && conversation[conversation.length - 1].actor === "user" && (
<div className="flex justify-center">
<LoadingIndicator />
</div>
)}
{conversation.length > 0 && conversation[conversation.length - 1].actor === "tool_result_to_llm" && (
<div className="flex justify-center">
<LoadingIndicator />
</div>
)}
</div>
);
}

View File

@@ -14,6 +14,10 @@ export default function LLMResponse({ data, onConfirm }) {
const requiresConfirm = data.next === "confirm";
if (typeof data.response === "object") {
data.response = data.response.response;
}
let displayText = (data.response || "").trim();
if (!displayText && requiresConfirm) {
displayText = `Agent is ready to run "${data.tool}". Please confirm.`;

View File

@@ -18,9 +18,12 @@ export default function App() {
const data = await res.json();
// data is now an object like { messages: [ ... ] }
if (data.messages && data.messages.some(msg => msg.actor === "response" || msg.actor === "tool_result")) {
if (data.messages && data.messages.length > 0 && (data.messages[data.messages.length - 1].actor === "agent")) {
setLoading(false);
}
else {
setLoading(true);
}
setConversation(data.messages || []);
}
} catch (err) {

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import List, Dict, Any
from dataclasses import dataclass
from typing import List
@dataclass
@@ -19,3 +19,7 @@ class ToolDefinition:
@dataclass
class ToolsData:
    """Bundle of tool definitions plus prompting context for the agent workflow."""

    # Ordered list of tools the agent may invoke.
    tools: List[ToolDefinition]
    # Natural-language summary of the tools' overall goal; injected into the LLM prompt.
    description: str = "Description of the tools purpose and overall goal"
    # Sample dialogue showing how tool arguments are gathered; also injected into the prompt.
    example_conversation_history: str = (
        "Example conversation history to help the AI agent understand the context of the conversation"
    )

158
poetry.lock generated
View File

@@ -115,6 +115,17 @@ files = [
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
]
[[package]]
name = "distro"
version = "1.9.0"
description = "Distro - an OS platform information API"
optional = false
python-versions = ">=3.6"
files = [
{file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"},
{file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"},
]
[[package]]
name = "exceptiongroup"
version = "1.2.2"
@@ -245,6 +256,91 @@ files = [
[package.extras]
colors = ["colorama (>=0.4.6)"]
[[package]]
name = "jiter"
version = "0.8.2"
description = "Fast iterable JSON parser."
optional = false
python-versions = ">=3.8"
files = [
{file = "jiter-0.8.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:ca8577f6a413abe29b079bc30f907894d7eb07a865c4df69475e868d73e71c7b"},
{file = "jiter-0.8.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b25bd626bde7fb51534190c7e3cb97cee89ee76b76d7585580e22f34f5e3f393"},
{file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5c826a221851a8dc028eb6d7d6429ba03184fa3c7e83ae01cd6d3bd1d4bd17d"},
{file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d35c864c2dff13dfd79fb070fc4fc6235d7b9b359efe340e1261deb21b9fcb66"},
{file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f557c55bc2b7676e74d39d19bcb8775ca295c7a028246175d6a8b431e70835e5"},
{file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:580ccf358539153db147e40751a0b41688a5ceb275e6f3e93d91c9467f42b2e3"},
{file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af102d3372e917cffce49b521e4c32c497515119dc7bd8a75665e90a718bbf08"},
{file = "jiter-0.8.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:cadcc978f82397d515bb2683fc0d50103acff2a180552654bb92d6045dec2c49"},
{file = "jiter-0.8.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:ba5bdf56969cad2019d4e8ffd3f879b5fdc792624129741d3d83fc832fef8c7d"},
{file = "jiter-0.8.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:3b94a33a241bee9e34b8481cdcaa3d5c2116f575e0226e421bed3f7a6ea71cff"},
{file = "jiter-0.8.2-cp310-cp310-win32.whl", hash = "sha256:6e5337bf454abddd91bd048ce0dca5134056fc99ca0205258766db35d0a2ea43"},
{file = "jiter-0.8.2-cp310-cp310-win_amd64.whl", hash = "sha256:4a9220497ca0cb1fe94e3f334f65b9b5102a0b8147646118f020d8ce1de70105"},
{file = "jiter-0.8.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:2dd61c5afc88a4fda7d8b2cf03ae5947c6ac7516d32b7a15bf4b49569a5c076b"},
{file = "jiter-0.8.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a6c710d657c8d1d2adbbb5c0b0c6bfcec28fd35bd6b5f016395f9ac43e878a15"},
{file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a9584de0cd306072635fe4b89742bf26feae858a0683b399ad0c2509011b9dc0"},
{file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5a90a923338531b7970abb063cfc087eebae6ef8ec8139762007188f6bc69a9f"},
{file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d21974d246ed0181558087cd9f76e84e8321091ebfb3a93d4c341479a736f099"},
{file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:32475a42b2ea7b344069dc1e81445cfc00b9d0e3ca837f0523072432332e9f74"},
{file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b9931fd36ee513c26b5bf08c940b0ac875de175341cbdd4fa3be109f0492586"},
{file = "jiter-0.8.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ce0820f4a3a59ddced7fce696d86a096d5cc48d32a4183483a17671a61edfddc"},
{file = "jiter-0.8.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:8ffc86ae5e3e6a93765d49d1ab47b6075a9c978a2b3b80f0f32628f39caa0c88"},
{file = "jiter-0.8.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5127dc1abd809431172bc3fbe8168d6b90556a30bb10acd5ded41c3cfd6f43b6"},
{file = "jiter-0.8.2-cp311-cp311-win32.whl", hash = "sha256:66227a2c7b575720c1871c8800d3a0122bb8ee94edb43a5685aa9aceb2782d44"},
{file = "jiter-0.8.2-cp311-cp311-win_amd64.whl", hash = "sha256:cde031d8413842a1e7501e9129b8e676e62a657f8ec8166e18a70d94d4682855"},
{file = "jiter-0.8.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:e6ec2be506e7d6f9527dae9ff4b7f54e68ea44a0ef6b098256ddf895218a2f8f"},
{file = "jiter-0.8.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:76e324da7b5da060287c54f2fabd3db5f76468006c811831f051942bf68c9d44"},
{file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:180a8aea058f7535d1c84183c0362c710f4750bef66630c05f40c93c2b152a0f"},
{file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:025337859077b41548bdcbabe38698bcd93cfe10b06ff66617a48ff92c9aec60"},
{file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ecff0dc14f409599bbcafa7e470c00b80f17abc14d1405d38ab02e4b42e55b57"},
{file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ffd9fee7d0775ebaba131f7ca2e2d83839a62ad65e8e02fe2bd8fc975cedeb9e"},
{file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14601dcac4889e0a1c75ccf6a0e4baf70dbc75041e51bcf8d0e9274519df6887"},
{file = "jiter-0.8.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:92249669925bc1c54fcd2ec73f70f2c1d6a817928480ee1c65af5f6b81cdf12d"},
{file = "jiter-0.8.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e725edd0929fa79f8349ab4ec7f81c714df51dc4e991539a578e5018fa4a7152"},
{file = "jiter-0.8.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:bf55846c7b7a680eebaf9c3c48d630e1bf51bdf76c68a5f654b8524335b0ad29"},
{file = "jiter-0.8.2-cp312-cp312-win32.whl", hash = "sha256:7efe4853ecd3d6110301665a5178b9856be7e2a9485f49d91aa4d737ad2ae49e"},
{file = "jiter-0.8.2-cp312-cp312-win_amd64.whl", hash = "sha256:83c0efd80b29695058d0fd2fa8a556490dbce9804eac3e281f373bbc99045f6c"},
{file = "jiter-0.8.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:ca1f08b8e43dc3bd0594c992fb1fd2f7ce87f7bf0d44358198d6da8034afdf84"},
{file = "jiter-0.8.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5672a86d55416ccd214c778efccf3266b84f87b89063b582167d803246354be4"},
{file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58dc9bc9767a1101f4e5e22db1b652161a225874d66f0e5cb8e2c7d1c438b587"},
{file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:37b2998606d6dadbb5ccda959a33d6a5e853252d921fec1792fc902351bb4e2c"},
{file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4ab9a87f3784eb0e098f84a32670cfe4a79cb6512fd8f42ae3d0709f06405d18"},
{file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:79aec8172b9e3c6d05fd4b219d5de1ac616bd8da934107325a6c0d0e866a21b6"},
{file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:711e408732d4e9a0208008e5892c2966b485c783cd2d9a681f3eb147cf36c7ef"},
{file = "jiter-0.8.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:653cf462db4e8c41995e33d865965e79641ef45369d8a11f54cd30888b7e6ff1"},
{file = "jiter-0.8.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:9c63eaef32b7bebac8ebebf4dabebdbc6769a09c127294db6babee38e9f405b9"},
{file = "jiter-0.8.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:eb21aaa9a200d0a80dacc7a81038d2e476ffe473ffdd9c91eb745d623561de05"},
{file = "jiter-0.8.2-cp313-cp313-win32.whl", hash = "sha256:789361ed945d8d42850f919342a8665d2dc79e7e44ca1c97cc786966a21f627a"},
{file = "jiter-0.8.2-cp313-cp313-win_amd64.whl", hash = "sha256:ab7f43235d71e03b941c1630f4b6e3055d46b6cb8728a17663eaac9d8e83a865"},
{file = "jiter-0.8.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b426f72cd77da3fec300ed3bc990895e2dd6b49e3bfe6c438592a3ba660e41ca"},
{file = "jiter-0.8.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2dd880785088ff2ad21ffee205e58a8c1ddabc63612444ae41e5e4b321b39c0"},
{file = "jiter-0.8.2-cp313-cp313t-win_amd64.whl", hash = "sha256:3ac9f578c46f22405ff7f8b1f5848fb753cc4b8377fbec8470a7dc3997ca7566"},
{file = "jiter-0.8.2-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:9e1fa156ee9454642adb7e7234a383884452532bc9d53d5af2d18d98ada1d79c"},
{file = "jiter-0.8.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0cf5dfa9956d96ff2efb0f8e9c7d055904012c952539a774305aaaf3abdf3d6c"},
{file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e52bf98c7e727dd44f7c4acb980cb988448faeafed8433c867888268899b298b"},
{file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a2ecaa3c23e7a7cf86d00eda3390c232f4d533cd9ddea4b04f5d0644faf642c5"},
{file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:08d4c92bf480e19fc3f2717c9ce2aa31dceaa9163839a311424b6862252c943e"},
{file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:99d9a1eded738299ba8e106c6779ce5c3893cffa0e32e4485d680588adae6db8"},
{file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d20be8b7f606df096e08b0b1b4a3c6f0515e8dac296881fe7461dfa0fb5ec817"},
{file = "jiter-0.8.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d33f94615fcaf872f7fd8cd98ac3b429e435c77619777e8a449d9d27e01134d1"},
{file = "jiter-0.8.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:317b25e98a35ffec5c67efe56a4e9970852632c810d35b34ecdd70cc0e47b3b6"},
{file = "jiter-0.8.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fc9043259ee430ecd71d178fccabd8c332a3bf1e81e50cae43cc2b28d19e4cb7"},
{file = "jiter-0.8.2-cp38-cp38-win32.whl", hash = "sha256:fc5adda618205bd4678b146612ce44c3cbfdee9697951f2c0ffdef1f26d72b63"},
{file = "jiter-0.8.2-cp38-cp38-win_amd64.whl", hash = "sha256:cd646c827b4f85ef4a78e4e58f4f5854fae0caf3db91b59f0d73731448a970c6"},
{file = "jiter-0.8.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:e41e75344acef3fc59ba4765df29f107f309ca9e8eace5baacabd9217e52a5ee"},
{file = "jiter-0.8.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7f22b16b35d5c1df9dfd58843ab2cd25e6bf15191f5a236bed177afade507bfc"},
{file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f7200b8f7619d36aa51c803fd52020a2dfbea36ffec1b5e22cab11fd34d95a6d"},
{file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:70bf4c43652cc294040dbb62256c83c8718370c8b93dd93d934b9a7bf6c4f53c"},
{file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f9d471356dc16f84ed48768b8ee79f29514295c7295cb41e1133ec0b2b8d637d"},
{file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:859e8eb3507894093d01929e12e267f83b1d5f6221099d3ec976f0c995cb6bd9"},
{file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaa58399c01db555346647a907b4ef6d4f584b123943be6ed5588c3f2359c9f4"},
{file = "jiter-0.8.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8f2d5ed877f089862f4c7aacf3a542627c1496f972a34d0474ce85ee7d939c27"},
{file = "jiter-0.8.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:03c9df035d4f8d647f8c210ddc2ae0728387275340668fb30d2421e17d9a0841"},
{file = "jiter-0.8.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8bd2a824d08d8977bb2794ea2682f898ad3d8837932e3a74937e93d62ecbb637"},
{file = "jiter-0.8.2-cp39-cp39-win32.whl", hash = "sha256:ca29b6371ebc40e496995c94b988a101b9fbbed48a51190a4461fcb0a68b4a36"},
{file = "jiter-0.8.2-cp39-cp39-win_amd64.whl", hash = "sha256:1c0dfbd1be3cbefc7510102370d86e35d1d53e5a93d48519688b1bf0f761160a"},
{file = "jiter-0.8.2.tar.gz", hash = "sha256:cd73d3e740666d0e639f678adb176fad25c1bcbdae88d8d7b857e1783bb4212d"},
]
[[package]]
name = "mypy-extensions"
version = "1.0.0"
@@ -271,6 +367,31 @@ files = [
httpx = ">=0.27.0,<0.28.0"
pydantic = ">=2.9.0,<3.0.0"
[[package]]
name = "openai"
version = "1.59.2"
description = "The official Python library for the openai API"
optional = false
python-versions = ">=3.8"
files = [
{file = "openai-1.59.2-py3-none-any.whl", hash = "sha256:3de721df4d2ccc5e845afa7235dce496bfbdd572692a876d2ae6211e0290ff22"},
{file = "openai-1.59.2.tar.gz", hash = "sha256:1bf2d5e8a93533f6dd3fb7b1bcf082ddd4ae61cc6d89ca1343e5957e4720651c"},
]
[package.dependencies]
anyio = ">=3.5.0,<5"
distro = ">=1.7.0,<2"
httpx = ">=0.23.0,<1"
jiter = ">=0.4.0,<1"
pydantic = ">=1.9.0,<3"
sniffio = "*"
tqdm = ">4"
typing-extensions = ">=4.11,<5"
[package.extras]
datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
realtime = ["websockets (>=13,<15)"]
[[package]]
name = "packaging"
version = "24.2"
@@ -512,6 +633,20 @@ files = [
[package.dependencies]
six = ">=1.5"
[[package]]
name = "python-dotenv"
version = "1.0.1"
description = "Read key-value pairs from a .env file and set them as environment variables"
optional = false
python-versions = ">=3.8"
files = [
{file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"},
{file = "python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"},
]
[package.extras]
cli = ["click (>=5.0)"]
[[package]]
name = "pyyaml"
version = "6.0.2"
@@ -680,6 +815,27 @@ files = [
{file = "tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff"},
]
[[package]]
name = "tqdm"
version = "4.67.1"
description = "Fast, Extensible Progress Meter"
optional = false
python-versions = ">=3.7"
files = [
{file = "tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2"},
{file = "tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2"},
]
[package.dependencies]
colorama = {version = "*", markers = "platform_system == \"Windows\""}
[package.extras]
dev = ["nbval", "pytest (>=6)", "pytest-asyncio (>=0.24)", "pytest-cov", "pytest-timeout"]
discord = ["requests"]
notebook = ["ipywidgets (>=6)"]
slack = ["slack-sdk"]
telegram = ["requests"]
[[package]]
name = "types-protobuf"
version = "5.29.1.20241207"
@@ -724,4 +880,4 @@ standard = ["colorama (>=0.4)", "httptools (>=0.6.3)", "python-dotenv (>=0.13)",
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<4.0"
content-hash = "20144fbc5773604251f9b61ac475eb9f292c6c8baf38d59170bbef81b482c71e"
content-hash = "ddef05a187e2f0364c8dc045d52aa30b7c95bc30075cfc0aa051af4fd1f8545b"

View File

@@ -7,30 +7,49 @@ def generate_genai_prompt(
tools_data: ToolsData, conversation_history: str, raw_json: Optional[str] = None
) -> str:
"""
Generates a concise prompt for producing or validating JSON instructions.
Generates a concise prompt for producing or validating JSON instructions
with the provided tools and conversation history.
"""
prompt_lines = []
# Intro / Role
prompt_lines.append(
"You are an AI assistant that must produce or validate JSON instructions "
"to properly call a set of tools. Respond with valid JSON only."
"You are an AI agent that helps fill required arguments for the tools described below. "
"You must respond with valid JSON ONLY, using the schema provided in the instructions."
)
# Conversation History
# Main Conversation History
prompt_lines.append("=== Conversation History ===")
prompt_lines.append(
"Use this history to understand needed tools, arguments, and the user's goals:"
"This is the ongoing history to determine which tool and arguments to gather:"
)
prompt_lines.append("BEGIN CONVERSATION HISTORY")
prompt_lines.append(json.dumps(conversation_history, indent=2))
prompt_lines.append("END CONVERSATION HISTORY")
prompt_lines.append("")
# Example Conversation History (from tools_data)
if tools_data.example_conversation_history:
prompt_lines.append("=== Example Conversation With These Tools ===")
prompt_lines.append(
"Use this example to understand how tools are invoked and arguments are gathered."
)
prompt_lines.append("BEGIN EXAMPLE")
prompt_lines.append(tools_data.example_conversation_history)
prompt_lines.append("END EXAMPLE")
prompt_lines.append("")
# Tools Definitions
prompt_lines.append("=== Tools Definitions ===")
prompt_lines.append(f"There are {len(tools_data.tools)} available tools:")
prompt_lines.append(", ".join([t.name for t in tools_data.tools]))
prompt_lines.append(f"Goal: {tools_data.description}")
prompt_lines.append(
"Gather the necessary information for each tool in the sequence described above."
)
prompt_lines.append(
"Only ask for arguments listed below. Do not add extra arguments."
)
prompt_lines.append("")
for tool in tools_data.tools:
prompt_lines.append(f"Tool name: {tool.name}")
@@ -39,8 +58,11 @@ def generate_genai_prompt(
for arg in tool.arguments:
prompt_lines.append(f" - {arg.name} ({arg.type}): {arg.description}")
prompt_lines.append("")
prompt_lines.append(
"When all required args for a tool are known, you can propose next='confirm' to run it."
)
# Instructions for JSON Generation
# JSON Format Instructions
prompt_lines.append("=== Instructions for JSON Generation ===")
prompt_lines.append(
"Your JSON format must be:\n"
@@ -56,14 +78,14 @@ def generate_genai_prompt(
"}"
)
prompt_lines.append(
"1. You may call multiple tools sequentially. Each requires specific arguments.\n"
'2. If ANY required argument is missing, use "next": "question" and prompt the user.\n'
'3. If all required arguments are known, use "next": "confirm" and set "tool" to the tool name.\n'
'4. If no further actions are needed, use "next": "done" and "tool": "null".\n'
'5. Keep "response" short and user-friendly. Do not include any metadata or editorializing.\n'
"1) If any required argument is missing, set next='question' and ask the user.\n"
"2) If all required arguments are known, set next='confirm' and specify the tool.\n"
" The user will confirm before the tool is run.\n"
"3) If no more tools are needed, set next='done' and tool=null.\n"
"4) response should be short and user-friendly."
)
# Validation Task (Only if raw_json is provided)
# Validation Task (If raw_json is provided)
if raw_json is not None:
prompt_lines.append("")
prompt_lines.append("=== Validation Task ===")
@@ -71,60 +93,17 @@ def generate_genai_prompt(
prompt_lines.append(json.dumps(raw_json, indent=2))
prompt_lines.append("")
prompt_lines.append(
"Check syntax, ensure 'tool' is correct or 'null', verify 'args' are valid, "
'and set "next" appropriately based on missing or complete args.'
"Check syntax, 'tool' validity, 'args' completeness, "
"and set 'next' appropriately. Return ONLY corrected JSON."
)
prompt_lines.append("Return only the corrected JSON, no extra text.")
# Common Reminders and Examples
prompt_lines.append("")
prompt_lines.append("=== Usage Examples ===")
prompt_lines.append(
"Example for missing args (needs user input):\n"
"{\n"
' "response": "I need your departure city.",\n'
' "next": "question",\n'
' "tool": "SearchFlights",\n'
' "args": {\n'
' "origin": null,\n'
' "destination": "Melbourne",\n'
' "dateDepart": "2025-03-26",\n'
' "dateReturn": "2025-04-20"\n'
" }\n"
"}"
)
prompt_lines.append(
"Example for confirmed args:\n"
"{\n"
' "response": "All arguments are set.",\n'
' "next": "confirm",\n'
' "tool": "SearchFlights",\n'
' "args": {\n'
' "origin": "Seattle",\n'
' "destination": "Melbourne",\n'
' "dateDepart": "2025-03-26",\n'
' "dateReturn": "2025-04-20"\n'
" }\n"
"}"
)
prompt_lines.append(
"Example when fully done:\n"
"{\n"
' "response": "All tools completed successfully. Final result: <insert result here>",\n'
' "next": "done",\n'
' "tool": "",\n'
' "args": {}\n'
"}"
)
# Prompt Start
prompt_lines.append("")
if raw_json is not None:
prompt_lines.append("")
prompt_lines.append("Begin by validating the provided JSON if necessary.")
else:
prompt_lines.append("")
prompt_lines.append(
"Begin by producing a valid JSON response for the next step."
"Begin by producing a valid JSON response for the next tool or question."
)
return "\n".join(prompt_lines)

View File

@@ -25,6 +25,8 @@ ollama = "^0.4.5"
pyyaml = "^6.0.2"
fastapi = "^0.115.6"
uvicorn = "^0.34.0"
python-dotenv = "^1.0.1"
openai = "^1.59.2"
[tool.poetry.group.dev.dependencies]
pytest = "^7.3"

View File

@@ -7,6 +7,9 @@ from temporalio.worker import Worker
from activities.tool_activities import ToolActivities, dynamic_tool_activity
from workflows.tool_workflow import ToolWorkflow
from dotenv import load_dotenv
load_dotenv()
async def main():

View File

@@ -3,13 +3,16 @@ import sys
from temporalio.client import Client
from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams
from tools.tool_registry import all_tools
from tools.tool_registry import event_travel_tools
from workflows.tool_workflow import ToolWorkflow
async def main(prompt: str):
# 1) Build the ToolsData from imported all_tools
tools_data = ToolsData(tools=all_tools)
# Build the ToolsData
tools_data = ToolsData(
tools=event_travel_tools,
description="Helps the user find an event to travel to, search flights, and create an invoice for those flights.",
)
# 2) Create combined input
combined_input = CombinedInput(

View File

@@ -1,17 +1,17 @@
def find_events(args: dict) -> dict:
    """Return stubbed event-search results for a region and month.

    Args:
        args: Expected keys "region" (e.g. "Oceania") and "month" (e.g. "April").
              Missing keys fall back to None rather than raising.

    Returns:
        A dict with an "events" list; each event carries city, eventName,
        dateFrom and dateTo. Currently a hard-coded stub — no real search yet.
    """
    region = args.get("region")
    month = args.get("month")
    print(f"[FindEvents] Searching events in {region} for {month} ...")
    # Stub result
    return {
        "events": [
            {
                "city": "Melbourne",
                "eventName": "Melbourne International Comedy Festival",
                "dateFrom": "2025-03-26",
                "dateTo": "2025-04-20",
            },
        ]
    }

View File

def search_flights(args: dict) -> dict:
    """Return stubbed flight-search results between an origin and destination.

    Currently echoes the passed origin/destination alongside mock flights;
    real flight-search logic can be added later. dateDepart/dateReturn may be
    present in args but are ignored for now.

    Args:
        args: Expected keys "origin" and "destination" (city names).
              Missing keys fall back to None rather than raising.

    Returns:
        A dict with origin, destination, currency and a "results" list of
        mock flights (flight_number, return_flight_number, price).
    """
    origin = args.get("origin")
    destination = args.get("destination")
    print(f"Searching flights from {origin} to {destination}")
    # Return a mock result so callers can verify the pipeline end to end.
    flight_search_results = {
        "origin": f"{origin}",
        "destination": f"{destination}",
        "currency": "USD",
        "results": [
            {"flight_number": "CX101", "return_flight_number": "CX102", "price": 850.0},
            {"flight_number": "QF30", "return_flight_number": "QF29", "price": 920.0},
            {"flight_number": "MH129", "return_flight_number": "MH128", "price": 780.0},
        ],
    }
    return flight_search_results

View File

@@ -2,12 +2,12 @@ from models.tool_definitions import ToolDefinition, ToolArgument
find_events_tool = ToolDefinition(
name="FindEvents",
description="Find upcoming events given a location or region (e.g., 'Oceania') and a date or month",
description="Find upcoming events to travel to given a location or region (e.g., 'Oceania') and a date or month",
arguments=[
ToolArgument(
name="continent",
name="region",
type="string",
description="Which continent or region to search for events",
description="Which region to search for events",
),
ToolArgument(
name="month",
@@ -20,7 +20,7 @@ find_events_tool = ToolDefinition(
# 2) Define the SearchFlights tool
search_flights_tool = ToolDefinition(
name="SearchFlights",
description="Search for return flights from an origin to a destination within a date range (dateDepart, dateReturn)",
description="Search for return flights from an origin to a destination within a date range (dateDepart, dateReturn).",
arguments=[
ToolArgument(
name="origin",
@@ -48,7 +48,7 @@ search_flights_tool = ToolDefinition(
# 3) Define the CreateInvoice tool
create_invoice_tool = ToolDefinition(
name="CreateInvoice",
description="Generate an invoice with flight information.",
description="Generate an invoice for the items described for the amount provided",
arguments=[
ToolArgument(
name="amount",
@@ -58,9 +58,7 @@ create_invoice_tool = ToolDefinition(
ToolArgument(
name="flightDetails",
type="string",
description="A summary of the flights, e.g., flight number and airport codes",
description="A description of the item details to be invoiced",
),
],
)
all_tools = [find_events_tool, search_flights_tool, create_invoice_tool]

View File

@@ -63,8 +63,8 @@ class ToolWorkflow:
confirmed_tool_data = self.tool_data.copy()
confirmed_tool_data["next"] = "confirmed"
self.add_message("userToolConfirm", confirmed_tool_data)
confirmed_tool_data["next"] = "user_confirmed_tool_run"
self.add_message("user_confirmed_tool_run", confirmed_tool_data)
# Run the tool
workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}")
@@ -81,10 +81,8 @@ class ToolWorkflow:
# Enqueue a follow-up prompt for the LLM
self.prompt_queue.append(
f"### The '{current_tool}' tool completed successfully with {dynamic_result}. "
"INSTRUCTIONS: Use this tool result, and the conversation history to figure out next steps, if any. "
"IMPORTANT REMINDER: Always return only JSON in the format: {'response': '', 'next': '', 'tool': '', 'args': {}} "
" Do NOT include any metadata or editorializing in the response. "
"IMPORTANT: If moving on to another tool then ensure you ask next='question' for any missing arguments."
"INSTRUCTIONS: Use this tool result, the list of tools in sequence and the conversation history to figure out next steps, if any. "
"DON'T ask any clarifying questions that are outside of the tools and args specified. "
)
# Loop around again
continue
@@ -102,11 +100,13 @@ class ToolWorkflow:
tools_data, self.conversation_history, self.tool_data
)
# tools_list = ", ".join([t.name for t in tools_data.tools])
prompt_input = ToolPromptInput(
prompt=prompt,
context_instructions=context_instructions,
)
tool_data = await workflow.execute_activity_method(
tool_data = await workflow.execute_activity(
ToolActivities.prompt_llm,
prompt_input,
schedule_to_close_timeout=timedelta(seconds=60),
@@ -115,13 +115,37 @@ class ToolWorkflow:
),
)
self.tool_data = tool_data
self.add_message("response", tool_data)
# Check the next step from LLM
next_step = self.tool_data.get("next")
current_tool = self.tool_data.get("tool")
if next_step == "confirm" and current_tool:
# tmp arg check
args = self.tool_data.get("args")
# check each argument for null values
missing_args = []
for key, value in args.items():
if value is None:
next_step = "question"
missing_args.append(key)
if missing_args:
# self.add_message("response_confirm_missing_args", tool_data)
# Enqueue a follow-up prompt for the LLM
self.prompt_queue.append(
f"### INSTRUCTIONS set next='question', combine this response response='{tool_data.get('response')}' and following missing arguments for tool {current_tool}: {missing_args}. "
"Only provide a valid JSON response without any comments or metadata."
)
workflow.logger.info(
f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}"
)
# Loop around again
continue
waiting_for_confirm = True
self.confirm = False # Clear any stale confirm
workflow.logger.info("Waiting for user confirm signal...")
@@ -130,8 +154,11 @@ class ToolWorkflow:
elif next_step == "done":
workflow.logger.info("All steps completed. Exiting workflow.")
self.add_message("agent", tool_data)
return str(self.conversation_history)
self.add_message("agent", tool_data)
# Possibly continue-as-new after many turns
# todo ensure this doesn't lose critical context
if (