mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 14:08:08 +01:00
works a lot better with 4o!
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -30,3 +30,5 @@ coverage.xml
|
|||||||
|
|
||||||
# PyCharm / IntelliJ settings
|
# PyCharm / IntelliJ settings
|
||||||
.idea/
|
.idea/
|
||||||
|
|
||||||
|
.env
|
||||||
@@ -1,9 +1,12 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from temporalio import activity
|
from temporalio import activity
|
||||||
from ollama import chat, ChatResponse
|
from ollama import chat, ChatResponse
|
||||||
|
from openai import OpenAI
|
||||||
import json
|
import json
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
from temporalio.common import RawValue
|
from temporalio.common import RawValue
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -15,6 +18,46 @@ class ToolPromptInput:
|
|||||||
class ToolActivities:
|
class ToolActivities:
|
||||||
@activity.defn
|
@activity.defn
|
||||||
def prompt_llm(self, input: ToolPromptInput) -> dict:
|
def prompt_llm(self, input: ToolPromptInput) -> dict:
|
||||||
|
client = OpenAI(
|
||||||
|
api_key=os.environ.get(
|
||||||
|
"OPENAI_API_KEY"
|
||||||
|
), # This is the default and can be omitted
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": input.context_instructions
|
||||||
|
+ ". The current date is "
|
||||||
|
+ datetime.now().strftime("%B %d, %Y"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": input.prompt,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
chat_completion = client.chat.completions.create(
|
||||||
|
model="gpt-4o", messages=messages # was gpt-4-0613
|
||||||
|
)
|
||||||
|
|
||||||
|
response_content = chat_completion.choices[0].message.content
|
||||||
|
print(f"ChatGPT response: {response_content}")
|
||||||
|
|
||||||
|
# Trim formatting markers if present
|
||||||
|
if response_content.startswith("```json") and response_content.endswith("```"):
|
||||||
|
response_content = response_content[7:-3].strip()
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(response_content)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
print(f"Invalid JSON: {e}")
|
||||||
|
raise json.JSONDecodeError
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
@activity.defn
|
||||||
|
def prompt_llm_ollama(self, input: ToolPromptInput) -> dict:
|
||||||
model_name = "qwen2.5:14b"
|
model_name = "qwen2.5:14b"
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
|
|||||||
35
api/main.py
35
api/main.py
@@ -4,7 +4,11 @@ from workflows.tool_workflow import ToolWorkflow
|
|||||||
from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams
|
from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams
|
||||||
from temporalio.exceptions import TemporalError
|
from temporalio.exceptions import TemporalError
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from tools.tool_registry import all_tools
|
from tools.tool_registry import (
|
||||||
|
find_events_tool,
|
||||||
|
search_flights_tool,
|
||||||
|
create_invoice_tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
@@ -65,7 +69,34 @@ async def send_prompt(prompt: str):
|
|||||||
client = await Client.connect("localhost:7233")
|
client = await Client.connect("localhost:7233")
|
||||||
|
|
||||||
# Build the ToolsData
|
# Build the ToolsData
|
||||||
tools_data = ToolsData(tools=all_tools)
|
tools_data = ToolsData(
|
||||||
|
tools=[find_events_tool, search_flights_tool, create_invoice_tool],
|
||||||
|
description="Help the user gather args for these tools in order: "
|
||||||
|
"1. FindEvents: Find an event to travel to "
|
||||||
|
"2. SearchFlights: search for a flight around the event dates "
|
||||||
|
"3. GenerateInvoice: Create a simple invoice for the cost of that flight ",
|
||||||
|
example_conversation_history="\n ".join(
|
||||||
|
[
|
||||||
|
"user: I'd like to travel to an event",
|
||||||
|
"agent: Sure! Let's start by finding an event you'd like to attend. Could you tell me which region and month you're interested in?",
|
||||||
|
"user: In Sao Paulo, Brazil, in February",
|
||||||
|
"agent: Great! Let's find an events in Sao Paulo, Brazil in February.",
|
||||||
|
"user_confirmed_tool_run: <user clicks confirm on FindEvents tool>",
|
||||||
|
"tool_result: { 'event_name': 'Carnival', 'event_date': '2023-02-25' }",
|
||||||
|
"agent: Found an event! There's Carnival on 2023-02-25, ending on 2023-02-28. Would you like to search for flights around these dates?",
|
||||||
|
"user: Yes, please",
|
||||||
|
"agent: Let's search for flights around these dates. Could you provide your departure city?",
|
||||||
|
"user: New York",
|
||||||
|
"agent: Thanks, searching for flights from New York to Sao Paulo around 2023-02-25 to 2023-02-28.",
|
||||||
|
"user_confirmed_tool_run: <user clicks confirm on SearchFlights tool>"
|
||||||
|
'tool_result: results including {"flight_number": "CX101", "return_flight_number": "CX102", "price": 850.0}',
|
||||||
|
"agent: Found some flights! The cheapest is CX101 for $850. Would you like to generate an invoice for this flight?",
|
||||||
|
"user_confirmed_tool_run: <user clicks confirm on CreateInvoice tool>",
|
||||||
|
'tool_result: { "status": "success", "invoice": { "flight_number": "CX101", "amount": 850.0 }, invoiceURL: "https://example.com/invoice" }',
|
||||||
|
"agent: Invoice generated! Here's the link: https://example.com/invoice",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
# Create combined input
|
# Create combined input
|
||||||
combined_input = CombinedInput(
|
combined_input = CombinedInput(
|
||||||
|
|||||||
@@ -19,17 +19,18 @@ export default function ChatWindow({ conversation, loading, onConfirm }) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const filtered = conversation.filter((msg) => {
|
const filtered = conversation.filter((msg) => {
|
||||||
console.log(conversation[conversation.length - 1].actor)
|
|
||||||
const { actor, response } = msg;
|
const { actor, response } = msg;
|
||||||
|
|
||||||
if (actor === "user") {
|
if (actor === "user") {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (actor === "response") {
|
if (actor === "agent") {
|
||||||
const parsed = typeof response === "string" ? safeParse(response) : response;
|
const parsed = typeof response === "string" ? safeParse(response) : response;
|
||||||
// Keep if next is "question", "confirm", or "confirmed".
|
// Keep if next is "question", "confirm", or "user_confirmed_tool_run".
|
||||||
// Only skip if next is "done" (or something else).
|
// Only skip if next is "done" (or something else).
|
||||||
return !["done"].includes(parsed.next);
|
// return !["done"].includes(parsed.next);
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
});
|
});
|
||||||
@@ -37,13 +38,14 @@ export default function ChatWindow({ conversation, loading, onConfirm }) {
|
|||||||
return (
|
return (
|
||||||
<div className="flex-grow overflow-y-auto space-y-4">
|
<div className="flex-grow overflow-y-auto space-y-4">
|
||||||
{filtered.map((msg, idx) => {
|
{filtered.map((msg, idx) => {
|
||||||
|
|
||||||
const { actor, response } = msg;
|
const { actor, response } = msg;
|
||||||
|
|
||||||
if (actor === "user") {
|
if (actor === "user") {
|
||||||
return (
|
return (
|
||||||
<MessageBubble key={idx} message={{ response }} isUser />
|
<MessageBubble key={idx} message={{ response }} isUser />
|
||||||
);
|
);
|
||||||
} else if (actor === "response") {
|
} else if (actor === "agent") {
|
||||||
const data =
|
const data =
|
||||||
typeof response === "string" ? safeParse(response) : response;
|
typeof response === "string" ? safeParse(response) : response;
|
||||||
return <LLMResponse key={idx} data={data} onConfirm={onConfirm} />;
|
return <LLMResponse key={idx} data={data} onConfirm={onConfirm} />;
|
||||||
@@ -57,16 +59,6 @@ export default function ChatWindow({ conversation, loading, onConfirm }) {
|
|||||||
<LoadingIndicator />
|
<LoadingIndicator />
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
{conversation.length > 0 && conversation[conversation.length - 1].actor === "user" && (
|
|
||||||
<div className="flex justify-center">
|
|
||||||
<LoadingIndicator />
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
{conversation.length > 0 && conversation[conversation.length - 1].actor === "tool_result_to_llm" && (
|
|
||||||
<div className="flex justify-center">
|
|
||||||
<LoadingIndicator />
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,10 @@ export default function LLMResponse({ data, onConfirm }) {
|
|||||||
|
|
||||||
const requiresConfirm = data.next === "confirm";
|
const requiresConfirm = data.next === "confirm";
|
||||||
|
|
||||||
|
if (typeof data.response === "object") {
|
||||||
|
data.response = data.response.response;
|
||||||
|
}
|
||||||
|
|
||||||
let displayText = (data.response || "").trim();
|
let displayText = (data.response || "").trim();
|
||||||
if (!displayText && requiresConfirm) {
|
if (!displayText && requiresConfirm) {
|
||||||
displayText = `Agent is ready to run "${data.tool}". Please confirm.`;
|
displayText = `Agent is ready to run "${data.tool}". Please confirm.`;
|
||||||
|
|||||||
@@ -18,9 +18,12 @@ export default function App() {
|
|||||||
const data = await res.json();
|
const data = await res.json();
|
||||||
// data is now an object like { messages: [ ... ] }
|
// data is now an object like { messages: [ ... ] }
|
||||||
|
|
||||||
if (data.messages && data.messages.some(msg => msg.actor === "response" || msg.actor === "tool_result")) {
|
if (data.messages && data.messages.length > 0 && (data.messages[data.messages.length - 1].actor === "agent")) {
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
}
|
}
|
||||||
|
else {
|
||||||
|
setLoading(true);
|
||||||
|
}
|
||||||
setConversation(data.messages || []);
|
setConversation(data.messages || []);
|
||||||
}
|
}
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from typing import List, Dict, Any
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -19,3 +19,7 @@ class ToolDefinition:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ToolsData:
|
class ToolsData:
|
||||||
tools: List[ToolDefinition]
|
tools: List[ToolDefinition]
|
||||||
|
description: str = "Description of the tools purpose and overall goal"
|
||||||
|
example_conversation_history: str = (
|
||||||
|
"Example conversation history to help the AI agent understand the context of the conversation"
|
||||||
|
)
|
||||||
|
|||||||
158
poetry.lock
generated
158
poetry.lock
generated
@@ -115,6 +115,17 @@ files = [
|
|||||||
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
|
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "distro"
|
||||||
|
version = "1.9.0"
|
||||||
|
description = "Distro - an OS platform information API"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.6"
|
||||||
|
files = [
|
||||||
|
{file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"},
|
||||||
|
{file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "exceptiongroup"
|
name = "exceptiongroup"
|
||||||
version = "1.2.2"
|
version = "1.2.2"
|
||||||
@@ -245,6 +256,91 @@ files = [
|
|||||||
[package.extras]
|
[package.extras]
|
||||||
colors = ["colorama (>=0.4.6)"]
|
colors = ["colorama (>=0.4.6)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "jiter"
|
||||||
|
version = "0.8.2"
|
||||||
|
description = "Fast iterable JSON parser."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "jiter-0.8.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:ca8577f6a413abe29b079bc30f907894d7eb07a865c4df69475e868d73e71c7b"},
|
||||||
|
{file = "jiter-0.8.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b25bd626bde7fb51534190c7e3cb97cee89ee76b76d7585580e22f34f5e3f393"},
|
||||||
|
{file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d5c826a221851a8dc028eb6d7d6429ba03184fa3c7e83ae01cd6d3bd1d4bd17d"},
|
||||||
|
{file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d35c864c2dff13dfd79fb070fc4fc6235d7b9b359efe340e1261deb21b9fcb66"},
|
||||||
|
{file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f557c55bc2b7676e74d39d19bcb8775ca295c7a028246175d6a8b431e70835e5"},
|
||||||
|
{file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:580ccf358539153db147e40751a0b41688a5ceb275e6f3e93d91c9467f42b2e3"},
|
||||||
|
{file = "jiter-0.8.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af102d3372e917cffce49b521e4c32c497515119dc7bd8a75665e90a718bbf08"},
|
||||||
|
{file = "jiter-0.8.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:cadcc978f82397d515bb2683fc0d50103acff2a180552654bb92d6045dec2c49"},
|
||||||
|
{file = "jiter-0.8.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:ba5bdf56969cad2019d4e8ffd3f879b5fdc792624129741d3d83fc832fef8c7d"},
|
||||||
|
{file = "jiter-0.8.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:3b94a33a241bee9e34b8481cdcaa3d5c2116f575e0226e421bed3f7a6ea71cff"},
|
||||||
|
{file = "jiter-0.8.2-cp310-cp310-win32.whl", hash = "sha256:6e5337bf454abddd91bd048ce0dca5134056fc99ca0205258766db35d0a2ea43"},
|
||||||
|
{file = "jiter-0.8.2-cp310-cp310-win_amd64.whl", hash = "sha256:4a9220497ca0cb1fe94e3f334f65b9b5102a0b8147646118f020d8ce1de70105"},
|
||||||
|
{file = "jiter-0.8.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:2dd61c5afc88a4fda7d8b2cf03ae5947c6ac7516d32b7a15bf4b49569a5c076b"},
|
||||||
|
{file = "jiter-0.8.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a6c710d657c8d1d2adbbb5c0b0c6bfcec28fd35bd6b5f016395f9ac43e878a15"},
|
||||||
|
{file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a9584de0cd306072635fe4b89742bf26feae858a0683b399ad0c2509011b9dc0"},
|
||||||
|
{file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5a90a923338531b7970abb063cfc087eebae6ef8ec8139762007188f6bc69a9f"},
|
||||||
|
{file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d21974d246ed0181558087cd9f76e84e8321091ebfb3a93d4c341479a736f099"},
|
||||||
|
{file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:32475a42b2ea7b344069dc1e81445cfc00b9d0e3ca837f0523072432332e9f74"},
|
||||||
|
{file = "jiter-0.8.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b9931fd36ee513c26b5bf08c940b0ac875de175341cbdd4fa3be109f0492586"},
|
||||||
|
{file = "jiter-0.8.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ce0820f4a3a59ddced7fce696d86a096d5cc48d32a4183483a17671a61edfddc"},
|
||||||
|
{file = "jiter-0.8.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:8ffc86ae5e3e6a93765d49d1ab47b6075a9c978a2b3b80f0f32628f39caa0c88"},
|
||||||
|
{file = "jiter-0.8.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5127dc1abd809431172bc3fbe8168d6b90556a30bb10acd5ded41c3cfd6f43b6"},
|
||||||
|
{file = "jiter-0.8.2-cp311-cp311-win32.whl", hash = "sha256:66227a2c7b575720c1871c8800d3a0122bb8ee94edb43a5685aa9aceb2782d44"},
|
||||||
|
{file = "jiter-0.8.2-cp311-cp311-win_amd64.whl", hash = "sha256:cde031d8413842a1e7501e9129b8e676e62a657f8ec8166e18a70d94d4682855"},
|
||||||
|
{file = "jiter-0.8.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:e6ec2be506e7d6f9527dae9ff4b7f54e68ea44a0ef6b098256ddf895218a2f8f"},
|
||||||
|
{file = "jiter-0.8.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:76e324da7b5da060287c54f2fabd3db5f76468006c811831f051942bf68c9d44"},
|
||||||
|
{file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:180a8aea058f7535d1c84183c0362c710f4750bef66630c05f40c93c2b152a0f"},
|
||||||
|
{file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:025337859077b41548bdcbabe38698bcd93cfe10b06ff66617a48ff92c9aec60"},
|
||||||
|
{file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ecff0dc14f409599bbcafa7e470c00b80f17abc14d1405d38ab02e4b42e55b57"},
|
||||||
|
{file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ffd9fee7d0775ebaba131f7ca2e2d83839a62ad65e8e02fe2bd8fc975cedeb9e"},
|
||||||
|
{file = "jiter-0.8.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14601dcac4889e0a1c75ccf6a0e4baf70dbc75041e51bcf8d0e9274519df6887"},
|
||||||
|
{file = "jiter-0.8.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:92249669925bc1c54fcd2ec73f70f2c1d6a817928480ee1c65af5f6b81cdf12d"},
|
||||||
|
{file = "jiter-0.8.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e725edd0929fa79f8349ab4ec7f81c714df51dc4e991539a578e5018fa4a7152"},
|
||||||
|
{file = "jiter-0.8.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:bf55846c7b7a680eebaf9c3c48d630e1bf51bdf76c68a5f654b8524335b0ad29"},
|
||||||
|
{file = "jiter-0.8.2-cp312-cp312-win32.whl", hash = "sha256:7efe4853ecd3d6110301665a5178b9856be7e2a9485f49d91aa4d737ad2ae49e"},
|
||||||
|
{file = "jiter-0.8.2-cp312-cp312-win_amd64.whl", hash = "sha256:83c0efd80b29695058d0fd2fa8a556490dbce9804eac3e281f373bbc99045f6c"},
|
||||||
|
{file = "jiter-0.8.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:ca1f08b8e43dc3bd0594c992fb1fd2f7ce87f7bf0d44358198d6da8034afdf84"},
|
||||||
|
{file = "jiter-0.8.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5672a86d55416ccd214c778efccf3266b84f87b89063b582167d803246354be4"},
|
||||||
|
{file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58dc9bc9767a1101f4e5e22db1b652161a225874d66f0e5cb8e2c7d1c438b587"},
|
||||||
|
{file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:37b2998606d6dadbb5ccda959a33d6a5e853252d921fec1792fc902351bb4e2c"},
|
||||||
|
{file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4ab9a87f3784eb0e098f84a32670cfe4a79cb6512fd8f42ae3d0709f06405d18"},
|
||||||
|
{file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:79aec8172b9e3c6d05fd4b219d5de1ac616bd8da934107325a6c0d0e866a21b6"},
|
||||||
|
{file = "jiter-0.8.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:711e408732d4e9a0208008e5892c2966b485c783cd2d9a681f3eb147cf36c7ef"},
|
||||||
|
{file = "jiter-0.8.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:653cf462db4e8c41995e33d865965e79641ef45369d8a11f54cd30888b7e6ff1"},
|
||||||
|
{file = "jiter-0.8.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:9c63eaef32b7bebac8ebebf4dabebdbc6769a09c127294db6babee38e9f405b9"},
|
||||||
|
{file = "jiter-0.8.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:eb21aaa9a200d0a80dacc7a81038d2e476ffe473ffdd9c91eb745d623561de05"},
|
||||||
|
{file = "jiter-0.8.2-cp313-cp313-win32.whl", hash = "sha256:789361ed945d8d42850f919342a8665d2dc79e7e44ca1c97cc786966a21f627a"},
|
||||||
|
{file = "jiter-0.8.2-cp313-cp313-win_amd64.whl", hash = "sha256:ab7f43235d71e03b941c1630f4b6e3055d46b6cb8728a17663eaac9d8e83a865"},
|
||||||
|
{file = "jiter-0.8.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b426f72cd77da3fec300ed3bc990895e2dd6b49e3bfe6c438592a3ba660e41ca"},
|
||||||
|
{file = "jiter-0.8.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2dd880785088ff2ad21ffee205e58a8c1ddabc63612444ae41e5e4b321b39c0"},
|
||||||
|
{file = "jiter-0.8.2-cp313-cp313t-win_amd64.whl", hash = "sha256:3ac9f578c46f22405ff7f8b1f5848fb753cc4b8377fbec8470a7dc3997ca7566"},
|
||||||
|
{file = "jiter-0.8.2-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:9e1fa156ee9454642adb7e7234a383884452532bc9d53d5af2d18d98ada1d79c"},
|
||||||
|
{file = "jiter-0.8.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:0cf5dfa9956d96ff2efb0f8e9c7d055904012c952539a774305aaaf3abdf3d6c"},
|
||||||
|
{file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e52bf98c7e727dd44f7c4acb980cb988448faeafed8433c867888268899b298b"},
|
||||||
|
{file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a2ecaa3c23e7a7cf86d00eda3390c232f4d533cd9ddea4b04f5d0644faf642c5"},
|
||||||
|
{file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:08d4c92bf480e19fc3f2717c9ce2aa31dceaa9163839a311424b6862252c943e"},
|
||||||
|
{file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:99d9a1eded738299ba8e106c6779ce5c3893cffa0e32e4485d680588adae6db8"},
|
||||||
|
{file = "jiter-0.8.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d20be8b7f606df096e08b0b1b4a3c6f0515e8dac296881fe7461dfa0fb5ec817"},
|
||||||
|
{file = "jiter-0.8.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d33f94615fcaf872f7fd8cd98ac3b429e435c77619777e8a449d9d27e01134d1"},
|
||||||
|
{file = "jiter-0.8.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:317b25e98a35ffec5c67efe56a4e9970852632c810d35b34ecdd70cc0e47b3b6"},
|
||||||
|
{file = "jiter-0.8.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fc9043259ee430ecd71d178fccabd8c332a3bf1e81e50cae43cc2b28d19e4cb7"},
|
||||||
|
{file = "jiter-0.8.2-cp38-cp38-win32.whl", hash = "sha256:fc5adda618205bd4678b146612ce44c3cbfdee9697951f2c0ffdef1f26d72b63"},
|
||||||
|
{file = "jiter-0.8.2-cp38-cp38-win_amd64.whl", hash = "sha256:cd646c827b4f85ef4a78e4e58f4f5854fae0caf3db91b59f0d73731448a970c6"},
|
||||||
|
{file = "jiter-0.8.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:e41e75344acef3fc59ba4765df29f107f309ca9e8eace5baacabd9217e52a5ee"},
|
||||||
|
{file = "jiter-0.8.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7f22b16b35d5c1df9dfd58843ab2cd25e6bf15191f5a236bed177afade507bfc"},
|
||||||
|
{file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f7200b8f7619d36aa51c803fd52020a2dfbea36ffec1b5e22cab11fd34d95a6d"},
|
||||||
|
{file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:70bf4c43652cc294040dbb62256c83c8718370c8b93dd93d934b9a7bf6c4f53c"},
|
||||||
|
{file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f9d471356dc16f84ed48768b8ee79f29514295c7295cb41e1133ec0b2b8d637d"},
|
||||||
|
{file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:859e8eb3507894093d01929e12e267f83b1d5f6221099d3ec976f0c995cb6bd9"},
|
||||||
|
{file = "jiter-0.8.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaa58399c01db555346647a907b4ef6d4f584b123943be6ed5588c3f2359c9f4"},
|
||||||
|
{file = "jiter-0.8.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8f2d5ed877f089862f4c7aacf3a542627c1496f972a34d0474ce85ee7d939c27"},
|
||||||
|
{file = "jiter-0.8.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:03c9df035d4f8d647f8c210ddc2ae0728387275340668fb30d2421e17d9a0841"},
|
||||||
|
{file = "jiter-0.8.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8bd2a824d08d8977bb2794ea2682f898ad3d8837932e3a74937e93d62ecbb637"},
|
||||||
|
{file = "jiter-0.8.2-cp39-cp39-win32.whl", hash = "sha256:ca29b6371ebc40e496995c94b988a101b9fbbed48a51190a4461fcb0a68b4a36"},
|
||||||
|
{file = "jiter-0.8.2-cp39-cp39-win_amd64.whl", hash = "sha256:1c0dfbd1be3cbefc7510102370d86e35d1d53e5a93d48519688b1bf0f761160a"},
|
||||||
|
{file = "jiter-0.8.2.tar.gz", hash = "sha256:cd73d3e740666d0e639f678adb176fad25c1bcbdae88d8d7b857e1783bb4212d"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "mypy-extensions"
|
name = "mypy-extensions"
|
||||||
version = "1.0.0"
|
version = "1.0.0"
|
||||||
@@ -271,6 +367,31 @@ files = [
|
|||||||
httpx = ">=0.27.0,<0.28.0"
|
httpx = ">=0.27.0,<0.28.0"
|
||||||
pydantic = ">=2.9.0,<3.0.0"
|
pydantic = ">=2.9.0,<3.0.0"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "openai"
|
||||||
|
version = "1.59.2"
|
||||||
|
description = "The official Python library for the openai API"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "openai-1.59.2-py3-none-any.whl", hash = "sha256:3de721df4d2ccc5e845afa7235dce496bfbdd572692a876d2ae6211e0290ff22"},
|
||||||
|
{file = "openai-1.59.2.tar.gz", hash = "sha256:1bf2d5e8a93533f6dd3fb7b1bcf082ddd4ae61cc6d89ca1343e5957e4720651c"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
anyio = ">=3.5.0,<5"
|
||||||
|
distro = ">=1.7.0,<2"
|
||||||
|
httpx = ">=0.23.0,<1"
|
||||||
|
jiter = ">=0.4.0,<1"
|
||||||
|
pydantic = ">=1.9.0,<3"
|
||||||
|
sniffio = "*"
|
||||||
|
tqdm = ">4"
|
||||||
|
typing-extensions = ">=4.11,<5"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
|
||||||
|
realtime = ["websockets (>=13,<15)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "packaging"
|
name = "packaging"
|
||||||
version = "24.2"
|
version = "24.2"
|
||||||
@@ -512,6 +633,20 @@ files = [
|
|||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
six = ">=1.5"
|
six = ">=1.5"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "python-dotenv"
|
||||||
|
version = "1.0.1"
|
||||||
|
description = "Read key-value pairs from a .env file and set them as environment variables"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"},
|
||||||
|
{file = "python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
cli = ["click (>=5.0)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pyyaml"
|
name = "pyyaml"
|
||||||
version = "6.0.2"
|
version = "6.0.2"
|
||||||
@@ -680,6 +815,27 @@ files = [
|
|||||||
{file = "tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff"},
|
{file = "tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tqdm"
|
||||||
|
version = "4.67.1"
|
||||||
|
description = "Fast, Extensible Progress Meter"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2"},
|
||||||
|
{file = "tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
colorama = {version = "*", markers = "platform_system == \"Windows\""}
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
dev = ["nbval", "pytest (>=6)", "pytest-asyncio (>=0.24)", "pytest-cov", "pytest-timeout"]
|
||||||
|
discord = ["requests"]
|
||||||
|
notebook = ["ipywidgets (>=6)"]
|
||||||
|
slack = ["slack-sdk"]
|
||||||
|
telegram = ["requests"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "types-protobuf"
|
name = "types-protobuf"
|
||||||
version = "5.29.1.20241207"
|
version = "5.29.1.20241207"
|
||||||
@@ -724,4 +880,4 @@ standard = ["colorama (>=0.4)", "httptools (>=0.6.3)", "python-dotenv (>=0.13)",
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.9,<4.0"
|
python-versions = ">=3.9,<4.0"
|
||||||
content-hash = "20144fbc5773604251f9b61ac475eb9f292c6c8baf38d59170bbef81b482c71e"
|
content-hash = "ddef05a187e2f0364c8dc045d52aa30b7c95bc30075cfc0aa051af4fd1f8545b"
|
||||||
|
|||||||
@@ -7,30 +7,49 @@ def generate_genai_prompt(
|
|||||||
tools_data: ToolsData, conversation_history: str, raw_json: Optional[str] = None
|
tools_data: ToolsData, conversation_history: str, raw_json: Optional[str] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Generates a concise prompt for producing or validating JSON instructions.
|
Generates a concise prompt for producing or validating JSON instructions
|
||||||
|
with the provided tools and conversation history.
|
||||||
"""
|
"""
|
||||||
prompt_lines = []
|
prompt_lines = []
|
||||||
|
|
||||||
# Intro / Role
|
# Intro / Role
|
||||||
prompt_lines.append(
|
prompt_lines.append(
|
||||||
"You are an AI assistant that must produce or validate JSON instructions "
|
"You are an AI agent that helps fill required arguments for the tools described below. "
|
||||||
"to properly call a set of tools. Respond with valid JSON only."
|
"You must respond with valid JSON ONLY, using the schema provided in the instructions."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Conversation History
|
# Main Conversation History
|
||||||
prompt_lines.append("=== Conversation History ===")
|
prompt_lines.append("=== Conversation History ===")
|
||||||
prompt_lines.append(
|
prompt_lines.append(
|
||||||
"Use this history to understand needed tools, arguments, and the user's goals:"
|
"This is the ongoing history to determine which tool and arguments to gather:"
|
||||||
)
|
)
|
||||||
prompt_lines.append("BEGIN CONVERSATION HISTORY")
|
prompt_lines.append("BEGIN CONVERSATION HISTORY")
|
||||||
prompt_lines.append(json.dumps(conversation_history, indent=2))
|
prompt_lines.append(json.dumps(conversation_history, indent=2))
|
||||||
prompt_lines.append("END CONVERSATION HISTORY")
|
prompt_lines.append("END CONVERSATION HISTORY")
|
||||||
prompt_lines.append("")
|
prompt_lines.append("")
|
||||||
|
|
||||||
|
# Example Conversation History (from tools_data)
|
||||||
|
if tools_data.example_conversation_history:
|
||||||
|
prompt_lines.append("=== Example Conversation With These Tools ===")
|
||||||
|
prompt_lines.append(
|
||||||
|
"Use this example to understand how tools are invoked and arguments are gathered."
|
||||||
|
)
|
||||||
|
prompt_lines.append("BEGIN EXAMPLE")
|
||||||
|
prompt_lines.append(tools_data.example_conversation_history)
|
||||||
|
prompt_lines.append("END EXAMPLE")
|
||||||
|
prompt_lines.append("")
|
||||||
|
|
||||||
# Tools Definitions
|
# Tools Definitions
|
||||||
prompt_lines.append("=== Tools Definitions ===")
|
prompt_lines.append("=== Tools Definitions ===")
|
||||||
prompt_lines.append(f"There are {len(tools_data.tools)} available tools:")
|
prompt_lines.append(f"There are {len(tools_data.tools)} available tools:")
|
||||||
prompt_lines.append(", ".join([t.name for t in tools_data.tools]))
|
prompt_lines.append(", ".join([t.name for t in tools_data.tools]))
|
||||||
|
prompt_lines.append(f"Goal: {tools_data.description}")
|
||||||
|
prompt_lines.append(
|
||||||
|
"Gather the necessary information for each tool in the sequence described above."
|
||||||
|
)
|
||||||
|
prompt_lines.append(
|
||||||
|
"Only ask for arguments listed below. Do not add extra arguments."
|
||||||
|
)
|
||||||
prompt_lines.append("")
|
prompt_lines.append("")
|
||||||
for tool in tools_data.tools:
|
for tool in tools_data.tools:
|
||||||
prompt_lines.append(f"Tool name: {tool.name}")
|
prompt_lines.append(f"Tool name: {tool.name}")
|
||||||
@@ -39,8 +58,11 @@ def generate_genai_prompt(
|
|||||||
for arg in tool.arguments:
|
for arg in tool.arguments:
|
||||||
prompt_lines.append(f" - {arg.name} ({arg.type}): {arg.description}")
|
prompt_lines.append(f" - {arg.name} ({arg.type}): {arg.description}")
|
||||||
prompt_lines.append("")
|
prompt_lines.append("")
|
||||||
|
prompt_lines.append(
|
||||||
|
"When all required args for a tool are known, you can propose next='confirm' to run it."
|
||||||
|
)
|
||||||
|
|
||||||
# Instructions for JSON Generation
|
# JSON Format Instructions
|
||||||
prompt_lines.append("=== Instructions for JSON Generation ===")
|
prompt_lines.append("=== Instructions for JSON Generation ===")
|
||||||
prompt_lines.append(
|
prompt_lines.append(
|
||||||
"Your JSON format must be:\n"
|
"Your JSON format must be:\n"
|
||||||
@@ -56,14 +78,14 @@ def generate_genai_prompt(
|
|||||||
"}"
|
"}"
|
||||||
)
|
)
|
||||||
prompt_lines.append(
|
prompt_lines.append(
|
||||||
"1. You may call multiple tools sequentially. Each requires specific arguments.\n"
|
"1) If any required argument is missing, set next='question' and ask the user.\n"
|
||||||
'2. If ANY required argument is missing, use "next": "question" and prompt the user.\n'
|
"2) If all required arguments are known, set next='confirm' and specify the tool.\n"
|
||||||
'3. If all required arguments are known, use "next": "confirm" and set "tool" to the tool name.\n'
|
" The user will confirm before the tool is run.\n"
|
||||||
'4. If no further actions are needed, use "next": "done" and "tool": "null".\n'
|
"3) If no more tools are needed, set next='done' and tool=null.\n"
|
||||||
'5. Keep "response" short and user-friendly. Do not include any metadata or editorializing.\n'
|
"4) response should be short and user-friendly."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validation Task (Only if raw_json is provided)
|
# Validation Task (If raw_json is provided)
|
||||||
if raw_json is not None:
|
if raw_json is not None:
|
||||||
prompt_lines.append("")
|
prompt_lines.append("")
|
||||||
prompt_lines.append("=== Validation Task ===")
|
prompt_lines.append("=== Validation Task ===")
|
||||||
@@ -71,60 +93,17 @@ def generate_genai_prompt(
|
|||||||
prompt_lines.append(json.dumps(raw_json, indent=2))
|
prompt_lines.append(json.dumps(raw_json, indent=2))
|
||||||
prompt_lines.append("")
|
prompt_lines.append("")
|
||||||
prompt_lines.append(
|
prompt_lines.append(
|
||||||
"Check syntax, ensure 'tool' is correct or 'null', verify 'args' are valid, "
|
"Check syntax, 'tool' validity, 'args' completeness, "
|
||||||
'and set "next" appropriately based on missing or complete args.'
|
"and set 'next' appropriately. Return ONLY corrected JSON."
|
||||||
)
|
|
||||||
prompt_lines.append("Return only the corrected JSON, no extra text.")
|
|
||||||
|
|
||||||
# Common Reminders and Examples
|
|
||||||
prompt_lines.append("")
|
|
||||||
prompt_lines.append("=== Usage Examples ===")
|
|
||||||
prompt_lines.append(
|
|
||||||
"Example for missing args (needs user input):\n"
|
|
||||||
"{\n"
|
|
||||||
' "response": "I need your departure city.",\n'
|
|
||||||
' "next": "question",\n'
|
|
||||||
' "tool": "SearchFlights",\n'
|
|
||||||
' "args": {\n'
|
|
||||||
' "origin": null,\n'
|
|
||||||
' "destination": "Melbourne",\n'
|
|
||||||
' "dateDepart": "2025-03-26",\n'
|
|
||||||
' "dateReturn": "2025-04-20"\n'
|
|
||||||
" }\n"
|
|
||||||
"}"
|
|
||||||
)
|
|
||||||
prompt_lines.append(
|
|
||||||
"Example for confirmed args:\n"
|
|
||||||
"{\n"
|
|
||||||
' "response": "All arguments are set.",\n'
|
|
||||||
' "next": "confirm",\n'
|
|
||||||
' "tool": "SearchFlights",\n'
|
|
||||||
' "args": {\n'
|
|
||||||
' "origin": "Seattle",\n'
|
|
||||||
' "destination": "Melbourne",\n'
|
|
||||||
' "dateDepart": "2025-03-26",\n'
|
|
||||||
' "dateReturn": "2025-04-20"\n'
|
|
||||||
" }\n"
|
|
||||||
"}"
|
|
||||||
)
|
|
||||||
prompt_lines.append(
|
|
||||||
"Example when fully done:\n"
|
|
||||||
"{\n"
|
|
||||||
' "response": "All tools completed successfully. Final result: <insert result here>",\n'
|
|
||||||
' "next": "done",\n'
|
|
||||||
' "tool": "",\n'
|
|
||||||
' "args": {}\n'
|
|
||||||
"}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prompt Start
|
# Prompt Start
|
||||||
if raw_json is not None:
|
|
||||||
prompt_lines.append("")
|
prompt_lines.append("")
|
||||||
|
if raw_json is not None:
|
||||||
prompt_lines.append("Begin by validating the provided JSON if necessary.")
|
prompt_lines.append("Begin by validating the provided JSON if necessary.")
|
||||||
else:
|
else:
|
||||||
prompt_lines.append("")
|
|
||||||
prompt_lines.append(
|
prompt_lines.append(
|
||||||
"Begin by producing a valid JSON response for the next step."
|
"Begin by producing a valid JSON response for the next tool or question."
|
||||||
)
|
)
|
||||||
|
|
||||||
return "\n".join(prompt_lines)
|
return "\n".join(prompt_lines)
|
||||||
|
|||||||
@@ -25,6 +25,8 @@ ollama = "^0.4.5"
|
|||||||
pyyaml = "^6.0.2"
|
pyyaml = "^6.0.2"
|
||||||
fastapi = "^0.115.6"
|
fastapi = "^0.115.6"
|
||||||
uvicorn = "^0.34.0"
|
uvicorn = "^0.34.0"
|
||||||
|
python-dotenv = "^1.0.1"
|
||||||
|
openai = "^1.59.2"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
pytest = "^7.3"
|
pytest = "^7.3"
|
||||||
|
|||||||
@@ -7,6 +7,9 @@ from temporalio.worker import Worker
|
|||||||
|
|
||||||
from activities.tool_activities import ToolActivities, dynamic_tool_activity
|
from activities.tool_activities import ToolActivities, dynamic_tool_activity
|
||||||
from workflows.tool_workflow import ToolWorkflow
|
from workflows.tool_workflow import ToolWorkflow
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
|
|||||||
@@ -3,13 +3,16 @@ import sys
|
|||||||
from temporalio.client import Client
|
from temporalio.client import Client
|
||||||
|
|
||||||
from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams
|
from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams
|
||||||
from tools.tool_registry import all_tools
|
from tools.tool_registry import event_travel_tools
|
||||||
from workflows.tool_workflow import ToolWorkflow
|
from workflows.tool_workflow import ToolWorkflow
|
||||||
|
|
||||||
|
|
||||||
async def main(prompt: str):
|
async def main(prompt: str):
|
||||||
# 1) Build the ToolsData from imported all_tools
|
# Build the ToolsData
|
||||||
tools_data = ToolsData(tools=all_tools)
|
tools_data = ToolsData(
|
||||||
|
tools=event_travel_tools,
|
||||||
|
description="Helps the user find an event to travel to, search flights, and create an invoice for those flights.",
|
||||||
|
)
|
||||||
|
|
||||||
# 2) Create combined input
|
# 2) Create combined input
|
||||||
combined_input = CombinedInput(
|
combined_input = CombinedInput(
|
||||||
|
|||||||
@@ -1,17 +1,17 @@
|
|||||||
def find_events(args: dict) -> dict:
|
def find_events(args: dict) -> dict:
|
||||||
# Example: continent="Oceania", month="April"
|
# Example: continent="Oceania", month="April"
|
||||||
continent = args.get("continent")
|
region = args.get("region")
|
||||||
month = args.get("month")
|
month = args.get("month")
|
||||||
print(f"[FindEvents] Searching events in {continent} for {month} ...")
|
print(f"[FindEvents] Searching events in {region} for {month} ...")
|
||||||
|
|
||||||
# Stub result
|
# Stub result
|
||||||
return {
|
return {
|
||||||
"eventsFound": [
|
"events": [
|
||||||
{
|
{
|
||||||
"city": "Melbourne",
|
"city": "Melbourne",
|
||||||
"eventName": "Melbourne International Comedy Festival",
|
"eventName": "Melbourne International Comedy Festival",
|
||||||
"dates": "2025-03-26 to 2025-04-20",
|
"dateFrom": "2025-03-26",
|
||||||
|
"dateTo": "2025-04-20",
|
||||||
},
|
},
|
||||||
],
|
]
|
||||||
"status": "found-events",
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,21 +4,20 @@ def search_flights(args: dict) -> dict:
|
|||||||
Currently just prints/returns the passed args,
|
Currently just prints/returns the passed args,
|
||||||
but you can add real flight search logic later.
|
but you can add real flight search logic later.
|
||||||
"""
|
"""
|
||||||
date_depart = args.get("dateDepart")
|
# date_depart = args.get("dateDepart")
|
||||||
date_return = args.get("dateReturn")
|
# date_return = args.get("dateReturn")
|
||||||
origin = args.get("origin")
|
origin = args.get("origin")
|
||||||
destination = args.get("destination")
|
destination = args.get("destination")
|
||||||
|
|
||||||
print(f"Searching flights from {origin} to {destination}")
|
flight_search_results = {
|
||||||
print(f"Depart: {date_depart}, Return: {date_return}")
|
"origin": f"{origin}",
|
||||||
|
"destination": f"{destination}",
|
||||||
# Return a mock result so you can verify it
|
"currency": "USD",
|
||||||
return {
|
"results": [
|
||||||
"tool": "SearchFlights",
|
{"flight_number": "CX101", "return_flight_number": "CX102", "price": 850.0},
|
||||||
"searchResults": [
|
{"flight_number": "QF30", "return_flight_number": "QF29", "price": 920.0},
|
||||||
"QF123: $1200",
|
{"flight_number": "MH129", "return_flight_number": "MH128", "price": 780.0},
|
||||||
"VA456: $1000",
|
|
||||||
],
|
],
|
||||||
"status": "search-complete",
|
|
||||||
"args": args,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return flight_search_results
|
||||||
|
|||||||
@@ -2,12 +2,12 @@ from models.tool_definitions import ToolDefinition, ToolArgument
|
|||||||
|
|
||||||
find_events_tool = ToolDefinition(
|
find_events_tool = ToolDefinition(
|
||||||
name="FindEvents",
|
name="FindEvents",
|
||||||
description="Find upcoming events given a location or region (e.g., 'Oceania') and a date or month",
|
description="Find upcoming events to travel to given a location or region (e.g., 'Oceania') and a date or month",
|
||||||
arguments=[
|
arguments=[
|
||||||
ToolArgument(
|
ToolArgument(
|
||||||
name="continent",
|
name="region",
|
||||||
type="string",
|
type="string",
|
||||||
description="Which continent or region to search for events",
|
description="Which region to search for events",
|
||||||
),
|
),
|
||||||
ToolArgument(
|
ToolArgument(
|
||||||
name="month",
|
name="month",
|
||||||
@@ -20,7 +20,7 @@ find_events_tool = ToolDefinition(
|
|||||||
# 2) Define the SearchFlights tool
|
# 2) Define the SearchFlights tool
|
||||||
search_flights_tool = ToolDefinition(
|
search_flights_tool = ToolDefinition(
|
||||||
name="SearchFlights",
|
name="SearchFlights",
|
||||||
description="Search for return flights from an origin to a destination within a date range (dateDepart, dateReturn)",
|
description="Search for return flights from an origin to a destination within a date range (dateDepart, dateReturn).",
|
||||||
arguments=[
|
arguments=[
|
||||||
ToolArgument(
|
ToolArgument(
|
||||||
name="origin",
|
name="origin",
|
||||||
@@ -48,7 +48,7 @@ search_flights_tool = ToolDefinition(
|
|||||||
# 3) Define the CreateInvoice tool
|
# 3) Define the CreateInvoice tool
|
||||||
create_invoice_tool = ToolDefinition(
|
create_invoice_tool = ToolDefinition(
|
||||||
name="CreateInvoice",
|
name="CreateInvoice",
|
||||||
description="Generate an invoice with flight information.",
|
description="Generate an invoice for the items described for the amount provided",
|
||||||
arguments=[
|
arguments=[
|
||||||
ToolArgument(
|
ToolArgument(
|
||||||
name="amount",
|
name="amount",
|
||||||
@@ -58,9 +58,7 @@ create_invoice_tool = ToolDefinition(
|
|||||||
ToolArgument(
|
ToolArgument(
|
||||||
name="flightDetails",
|
name="flightDetails",
|
||||||
type="string",
|
type="string",
|
||||||
description="A summary of the flights, e.g., flight number and airport codes",
|
description="A description of the item details to be invoiced",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
all_tools = [find_events_tool, search_flights_tool, create_invoice_tool]
|
|
||||||
|
|||||||
@@ -63,8 +63,8 @@ class ToolWorkflow:
|
|||||||
|
|
||||||
confirmed_tool_data = self.tool_data.copy()
|
confirmed_tool_data = self.tool_data.copy()
|
||||||
|
|
||||||
confirmed_tool_data["next"] = "confirmed"
|
confirmed_tool_data["next"] = "user_confirmed_tool_run"
|
||||||
self.add_message("userToolConfirm", confirmed_tool_data)
|
self.add_message("user_confirmed_tool_run", confirmed_tool_data)
|
||||||
|
|
||||||
# Run the tool
|
# Run the tool
|
||||||
workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}")
|
workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}")
|
||||||
@@ -81,10 +81,8 @@ class ToolWorkflow:
|
|||||||
# Enqueue a follow-up prompt for the LLM
|
# Enqueue a follow-up prompt for the LLM
|
||||||
self.prompt_queue.append(
|
self.prompt_queue.append(
|
||||||
f"### The '{current_tool}' tool completed successfully with {dynamic_result}. "
|
f"### The '{current_tool}' tool completed successfully with {dynamic_result}. "
|
||||||
"INSTRUCTIONS: Use this tool result, and the conversation history to figure out next steps, if any. "
|
"INSTRUCTIONS: Use this tool result, the list of tools in sequence and the conversation history to figure out next steps, if any. "
|
||||||
"IMPORTANT REMINDER: Always return only JSON in the format: {'response': '', 'next': '', 'tool': '', 'args': {}} "
|
"DON'T ask any clarifying questions that are outside of the tools and args specified. "
|
||||||
" Do NOT include any metadata or editorializing in the response. "
|
|
||||||
"IMPORTANT: If moving on to another tool then ensure you ask next='question' for any missing arguments."
|
|
||||||
)
|
)
|
||||||
# Loop around again
|
# Loop around again
|
||||||
continue
|
continue
|
||||||
@@ -102,11 +100,13 @@ class ToolWorkflow:
|
|||||||
tools_data, self.conversation_history, self.tool_data
|
tools_data, self.conversation_history, self.tool_data
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# tools_list = ", ".join([t.name for t in tools_data.tools])
|
||||||
|
|
||||||
prompt_input = ToolPromptInput(
|
prompt_input = ToolPromptInput(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
context_instructions=context_instructions,
|
context_instructions=context_instructions,
|
||||||
)
|
)
|
||||||
tool_data = await workflow.execute_activity_method(
|
tool_data = await workflow.execute_activity(
|
||||||
ToolActivities.prompt_llm,
|
ToolActivities.prompt_llm,
|
||||||
prompt_input,
|
prompt_input,
|
||||||
schedule_to_close_timeout=timedelta(seconds=60),
|
schedule_to_close_timeout=timedelta(seconds=60),
|
||||||
@@ -115,13 +115,37 @@ class ToolWorkflow:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.tool_data = tool_data
|
self.tool_data = tool_data
|
||||||
self.add_message("response", tool_data)
|
|
||||||
|
|
||||||
# Check the next step from LLM
|
# Check the next step from LLM
|
||||||
next_step = self.tool_data.get("next")
|
next_step = self.tool_data.get("next")
|
||||||
current_tool = self.tool_data.get("tool")
|
current_tool = self.tool_data.get("tool")
|
||||||
|
|
||||||
if next_step == "confirm" and current_tool:
|
if next_step == "confirm" and current_tool:
|
||||||
|
# tmp arg check
|
||||||
|
args = self.tool_data.get("args")
|
||||||
|
|
||||||
|
# check each argument for null values
|
||||||
|
missing_args = []
|
||||||
|
for key, value in args.items():
|
||||||
|
if value is None:
|
||||||
|
next_step = "question"
|
||||||
|
missing_args.append(key)
|
||||||
|
|
||||||
|
if missing_args:
|
||||||
|
# self.add_message("response_confirm_missing_args", tool_data)
|
||||||
|
|
||||||
|
# Enqueue a follow-up prompt for the LLM
|
||||||
|
self.prompt_queue.append(
|
||||||
|
f"### INSTRUCTIONS set next='question', combine this response response='{tool_data.get('response')}' and following missing arguments for tool {current_tool}: {missing_args}. "
|
||||||
|
"Only provide a valid JSON response without any comments or metadata."
|
||||||
|
)
|
||||||
|
|
||||||
|
workflow.logger.info(
|
||||||
|
f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}"
|
||||||
|
)
|
||||||
|
# Loop around again
|
||||||
|
continue
|
||||||
|
|
||||||
waiting_for_confirm = True
|
waiting_for_confirm = True
|
||||||
self.confirm = False # Clear any stale confirm
|
self.confirm = False # Clear any stale confirm
|
||||||
workflow.logger.info("Waiting for user confirm signal...")
|
workflow.logger.info("Waiting for user confirm signal...")
|
||||||
@@ -130,8 +154,11 @@ class ToolWorkflow:
|
|||||||
|
|
||||||
elif next_step == "done":
|
elif next_step == "done":
|
||||||
workflow.logger.info("All steps completed. Exiting workflow.")
|
workflow.logger.info("All steps completed. Exiting workflow.")
|
||||||
|
self.add_message("agent", tool_data)
|
||||||
return str(self.conversation_history)
|
return str(self.conversation_history)
|
||||||
|
|
||||||
|
self.add_message("agent", tool_data)
|
||||||
|
|
||||||
# Possibly continue-as-new after many turns
|
# Possibly continue-as-new after many turns
|
||||||
# todo ensure this doesn't lose critical context
|
# todo ensure this doesn't lose critical context
|
||||||
if (
|
if (
|
||||||
|
|||||||
Reference in New Issue
Block a user