mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 05:58:08 +01:00
basic react API
This commit is contained in:
11
README.md
11
README.md
@@ -39,10 +39,13 @@ Run query get_tool_data to see the data the tool has collected so far.
|
||||
- `poetry run uvicorn api.main:app --reload` to start the API server.
|
||||
- Access the API at `/docs` to see the available endpoints.
|
||||
|
||||
## UI
|
||||
TODO: Document /frontend react app running instructions.
|
||||
|
||||
## TODO
|
||||
- The LLM prompts move through 3 mock tools (FindEvents, SearchFlights, CreateInvoice) but I should make them contact real APIs.
|
||||
- Might need to abstract the json example in the prompt generator to be part of a ToolDefinition (prevent overfitting to the example).
|
||||
- I need to build a chat interface so it's not cli-controlled. Also want to show some 'behind the scenes' of the agents being used as they run.
|
||||
- What happens if I don't want to confirm a step, but instead want to correct it? TODO figure out
|
||||
- What happens if I am at confirmation step and want to end the chat (do I need some sort of signal router?)
|
||||
- Currently hardcoded to the Temporal dev server at localhost:7233. Need to support options incl Temporal Cloud.
|
||||
- Currently hardcoded to the Temporal dev server at localhost:7233. Need to support options incl Temporal Cloud.
|
||||
- UI: A bit ugly
|
||||
- UI: Tool Confirmed state could be better represented
|
||||
- UI: 'Start new chat' button needs to handle better
|
||||
61
api/main.py
61
api/main.py
@@ -2,10 +2,21 @@ from fastapi import FastAPI
|
||||
from temporalio.client import Client
|
||||
from workflows.tool_workflow import ToolWorkflow
|
||||
from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams
|
||||
from temporalio.exceptions import TemporalError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from tools.tool_registry import all_tools
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["http://localhost:5173"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
def root():
|
||||
@@ -16,9 +27,37 @@ def root():
|
||||
async def get_tool_data():
|
||||
"""Calls the workflow's 'get_tool_data' query."""
|
||||
client = await Client.connect("localhost:7233")
|
||||
handle = client.get_workflow_handle("agent-workflow")
|
||||
tool_data = await handle.query(ToolWorkflow.get_tool_data)
|
||||
return tool_data
|
||||
try:
|
||||
# Get workflow handle
|
||||
handle = client.get_workflow_handle("agent-workflow")
|
||||
|
||||
# Check if the workflow is completed
|
||||
workflow_status = await handle.describe()
|
||||
if workflow_status.status == 2:
|
||||
# Workflow is completed; return an empty response
|
||||
return {}
|
||||
|
||||
# Query the workflow
|
||||
tool_data = await handle.query("get_tool_data")
|
||||
return tool_data
|
||||
except TemporalError as e:
|
||||
# Workflow not found; return an empty response
|
||||
print(e)
|
||||
return {}
|
||||
|
||||
|
||||
@app.get("/get-conversation-history")
|
||||
async def get_conversation_history():
|
||||
"""Calls the workflow's 'get_conversation_history' query."""
|
||||
client = await Client.connect("localhost:7233")
|
||||
try:
|
||||
handle = client.get_workflow_handle("agent-workflow")
|
||||
conversation_history = await handle.query("get_conversation_history")
|
||||
|
||||
return conversation_history
|
||||
except TemporalError as e:
|
||||
print(e)
|
||||
return []
|
||||
|
||||
|
||||
@app.post("/send-prompt")
|
||||
@@ -57,3 +96,19 @@ async def send_confirm():
|
||||
handle = client.get_workflow_handle(workflow_id)
|
||||
await handle.signal("confirm")
|
||||
return {"message": "Confirm signal sent."}
|
||||
|
||||
|
||||
@app.post("/end-chat")
|
||||
async def end_chat():
|
||||
"""Sends a 'end_chat' signal to the workflow."""
|
||||
client = await Client.connect("localhost:7233")
|
||||
workflow_id = "agent-workflow"
|
||||
|
||||
try:
|
||||
handle = client.get_workflow_handle(workflow_id)
|
||||
await handle.signal("end_chat")
|
||||
return {"message": "End chat signal sent."}
|
||||
except TemporalError as e:
|
||||
print(e)
|
||||
# Workflow not found; return an empty response
|
||||
return {}
|
||||
|
||||
53
frontend/.gitignore
vendored
Normal file
53
frontend/.gitignore
vendored
Normal file
@@ -0,0 +1,53 @@
|
||||
# Node.js
|
||||
node_modules/
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
|
||||
# Logs
|
||||
logs
|
||||
*.log
|
||||
logs/*.log
|
||||
|
||||
# OS-specific files
|
||||
.DS_Store
|
||||
|
||||
# Build output
|
||||
dist/
|
||||
build/
|
||||
|
||||
# Dependency directories
|
||||
jspm_packages/
|
||||
|
||||
# Testing
|
||||
coverage/
|
||||
|
||||
# Next.js
|
||||
.next/
|
||||
|
||||
# Vite
|
||||
.vite/
|
||||
|
||||
# Parcel
|
||||
.cache/
|
||||
|
||||
# Environment files
|
||||
.env
|
||||
.env.*.local
|
||||
|
||||
# Editor files
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
|
||||
# Temporary files
|
||||
*.tmp
|
||||
*.temp
|
||||
*.bak
|
||||
*.orig
|
||||
|
||||
# Lock files
|
||||
*.lock
|
||||
|
||||
# Others
|
||||
public/**/*.cache
|
||||
12
frontend/index.html
Normal file
12
frontend/index.html
Normal file
@@ -0,0 +1,12 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<title>Temporal AI Agent</title>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<script type="module" src="/src/main.jsx"></script>
|
||||
</head>
|
||||
<body class="bg-gray-100 text-gray-900 dark:bg-gray-800 dark:text-gray-100">
|
||||
<div id="root"></div>
|
||||
</body>
|
||||
</html>
|
||||
2926
frontend/package-lock.json
generated
Normal file
2926
frontend/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
24
frontend/package.json
Normal file
24
frontend/package.json
Normal file
@@ -0,0 +1,24 @@
|
||||
{
|
||||
"name": "temporal-ai-agent-frontend",
|
||||
"version": "1.0.0",
|
||||
"description": "React and Tailwind",
|
||||
"license": "ISC",
|
||||
"author": "",
|
||||
"type": "commonjs",
|
||||
"main": "index.js",
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "vite build"
|
||||
},
|
||||
"dependencies": {
|
||||
"@vitejs/plugin-react": "^4.3.4",
|
||||
"react": "^19.0.0",
|
||||
"react-dom": "^19.0.0",
|
||||
"vite": "^6.0.7"
|
||||
},
|
||||
"devDependencies": {
|
||||
"autoprefixer": "^10.4.20",
|
||||
"postcss": "^8.4.49",
|
||||
"tailwindcss": "^3.4.17"
|
||||
}
|
||||
}
|
||||
6
frontend/postcss.config.js
Normal file
6
frontend/postcss.config.js
Normal file
@@ -0,0 +1,6 @@
|
||||
module.exports = {
|
||||
plugins: {
|
||||
tailwindcss: {},
|
||||
autoprefixer: {},
|
||||
},
|
||||
};
|
||||
66
frontend/src/components/ChatWindow.jsx
Normal file
66
frontend/src/components/ChatWindow.jsx
Normal file
@@ -0,0 +1,66 @@
|
||||
import React from "react";
|
||||
import LLMResponse from "./LLMResponse";
|
||||
import MessageBubble from "./MessageBubble";
|
||||
import LoadingIndicator from "./LoadingIndicator";
|
||||
|
||||
function safeParse(str) {
|
||||
try {
|
||||
return JSON.parse(str);
|
||||
} catch (err) {
|
||||
console.error("safeParse error:", err, "Original string:", str);
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
export default function ChatWindow({ conversation, loading, onConfirm }) {
|
||||
if (!Array.isArray(conversation)) {
|
||||
console.error("ChatWindow expected conversation to be an array, got:", conversation);
|
||||
return null;
|
||||
}
|
||||
|
||||
const filtered = conversation.filter((msg) => {
|
||||
const { actor, response } = msg;
|
||||
|
||||
if (actor === "user") {
|
||||
return true;
|
||||
}
|
||||
if (actor === "response") {
|
||||
const parsed = typeof response === "string" ? safeParse(response) : response;
|
||||
// Keep if next is "question", "confirm", or "confirmed".
|
||||
// Only skip if next is "done" (or something else).
|
||||
return !["done"].includes(parsed.next);
|
||||
}
|
||||
return false;
|
||||
});
|
||||
|
||||
return (
|
||||
<div className="flex-grow overflow-y-auto space-y-4">
|
||||
{filtered.map((msg, idx) => {
|
||||
const { actor, response } = msg;
|
||||
|
||||
if (actor === "user") {
|
||||
return (
|
||||
<MessageBubble key={idx} message={{ response }} isUser />
|
||||
);
|
||||
} else if (actor === "response") {
|
||||
const data =
|
||||
typeof response === "string" ? safeParse(response) : response;
|
||||
return <LLMResponse key={idx} data={data} onConfirm={onConfirm} />;
|
||||
}
|
||||
return null;
|
||||
})}
|
||||
|
||||
{/* If loading = true, show the spinner at the bottom */}
|
||||
{loading && (
|
||||
<div className="flex justify-center">
|
||||
<LoadingIndicator />
|
||||
</div>
|
||||
)}
|
||||
{conversation.length > 0 && conversation[conversation.length - 1].actor === "user" && (
|
||||
<div className="flex justify-center">
|
||||
<LoadingIndicator />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
59
frontend/src/components/ConfirmInline.jsx
Normal file
59
frontend/src/components/ConfirmInline.jsx
Normal file
@@ -0,0 +1,59 @@
|
||||
import React from "react";
|
||||
import LoadingIndicator from "./LoadingIndicator";
|
||||
|
||||
export default function ConfirmInline({ data, confirmed, onConfirm }) {
|
||||
const { args, tool } = data || {};
|
||||
|
||||
console.log("ConfirmInline rendered with confirmed:", confirmed);
|
||||
|
||||
if (confirmed) {
|
||||
// Once confirmed, show "Running..." state in the same container
|
||||
return (
|
||||
<div className="mt-2 p-2 border border-gray-400 dark:border-gray-600 rounded bg-gray-50 dark:bg-gray-800">
|
||||
<div className="text-sm text-gray-600 dark:text-gray-300">
|
||||
<div>
|
||||
<strong>Tool:</strong> {tool ?? "Unknown"}
|
||||
</div>
|
||||
{args && (
|
||||
<div className="mt-1">
|
||||
<strong>Args:</strong>
|
||||
<pre className="bg-gray-100 dark:bg-gray-700 p-1 rounded text-xs whitespace-pre-wrap">
|
||||
{JSON.stringify(args, null, 2)}
|
||||
</pre>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<div className="mt-2 text-green-600 dark:text-green-400 font-medium">
|
||||
Running {tool}...
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Not confirmed yet → show confirmation UI
|
||||
return (
|
||||
<div className="mt-2 p-2 border border-gray-400 dark:border-gray-600 rounded bg-gray-50 dark:bg-gray-800">
|
||||
<div className="text-sm text-gray-600 dark:text-gray-300">
|
||||
<div>
|
||||
<strong>Tool:</strong> {tool ?? "Unknown"}
|
||||
</div>
|
||||
{args && (
|
||||
<div className="mt-1">
|
||||
<strong>Args:</strong>
|
||||
<pre className="bg-gray-100 dark:bg-gray-700 p-1 rounded text-xs whitespace-pre-wrap">
|
||||
{JSON.stringify(args, null, 2)}
|
||||
</pre>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<div className="text-right mt-2">
|
||||
<button
|
||||
onClick={onConfirm}
|
||||
className="bg-green-600 hover:bg-green-700 text-white px-3 py-1 rounded"
|
||||
>
|
||||
Confirm
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
34
frontend/src/components/LLMResponse.jsx
Normal file
34
frontend/src/components/LLMResponse.jsx
Normal file
@@ -0,0 +1,34 @@
|
||||
import React, { useState } from "react";
|
||||
import MessageBubble from "./MessageBubble";
|
||||
import ConfirmInline from "./ConfirmInline";
|
||||
|
||||
export default function LLMResponse({ data, onConfirm }) {
|
||||
const [isConfirmed, setIsConfirmed] = useState(false);
|
||||
|
||||
const handleConfirm = async () => {
|
||||
if (onConfirm) {
|
||||
await onConfirm();
|
||||
}
|
||||
setIsConfirmed(true); // Update state after confirmation
|
||||
};
|
||||
|
||||
const requiresConfirm = data.next === "confirm";
|
||||
|
||||
let displayText = (data.response || "").trim();
|
||||
if (!displayText && requiresConfirm) {
|
||||
displayText = `Agent is ready to run "${data.tool}". Please confirm.`;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="space-y-2">
|
||||
<MessageBubble message={{ response: displayText }} />
|
||||
{requiresConfirm && (
|
||||
<ConfirmInline
|
||||
data={data}
|
||||
confirmed={isConfirmed}
|
||||
onConfirm={handleConfirm}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
11
frontend/src/components/LoadingIndicator.jsx
Normal file
11
frontend/src/components/LoadingIndicator.jsx
Normal file
@@ -0,0 +1,11 @@
|
||||
import React from "react";
|
||||
|
||||
export default function LoadingIndicator() {
|
||||
return (
|
||||
<div className="flex items-center justify-center space-x-2">
|
||||
<div className="w-2 h-2 rounded-full bg-blue-600 animate-ping"></div>
|
||||
<div className="w-2 h-2 rounded-full bg-blue-600 animate-ping delay-100"></div>
|
||||
<div className="w-2 h-2 rounded-full bg-blue-600 animate-ping delay-200"></div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
27
frontend/src/components/MessageBubble.jsx
Normal file
27
frontend/src/components/MessageBubble.jsx
Normal file
@@ -0,0 +1,27 @@
|
||||
import React from "react";
|
||||
|
||||
export default function MessageBubble({ message, fallback = "", isUser = false }) {
|
||||
// Use isUser directly instead of message.user
|
||||
const bubbleStyle = isUser
|
||||
? "bg-blue-600 text-white self-end"
|
||||
: "bg-gray-300 text-gray-900 dark:bg-gray-600 dark:text-gray-100";
|
||||
|
||||
// If message.response is empty or whitespace, use fallback text
|
||||
const displayText = message.response?.trim() ? message.response : fallback;
|
||||
|
||||
// Skip display entirely if text starts with ###
|
||||
if (displayText.startsWith("###")) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={`max-w-xs md:max-w-sm px-4 py-2 mb-1 rounded-lg ${
|
||||
isUser ? "ml-auto" : "mr-auto"
|
||||
} ${bubbleStyle}`}
|
||||
style={{ whiteSpace: "pre-wrap" }}
|
||||
>
|
||||
{displayText}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
9
frontend/src/components/NavBar.jsx
Normal file
9
frontend/src/components/NavBar.jsx
Normal file
@@ -0,0 +1,9 @@
|
||||
import React from "react";
|
||||
|
||||
export default function NavBar({ title }) {
|
||||
return (
|
||||
<div className="bg-gray-200 dark:bg-gray-700 p-4 shadow-sm">
|
||||
<h1 className="text-xl font-bold">{title}</h1>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
19
frontend/src/hooks/useLocalChatHistory.js
Normal file
19
frontend/src/hooks/useLocalChatHistory.js
Normal file
@@ -0,0 +1,19 @@
|
||||
import { useState, useEffect } from "react";
|
||||
|
||||
export default function useLocalChatHistory(key, initialValue) {
|
||||
const [state, setState] = useState(() => {
|
||||
try {
|
||||
const stored = window.localStorage.getItem(key);
|
||||
return stored ? JSON.parse(stored) : initialValue;
|
||||
} catch (err) {
|
||||
console.error("Error parsing localStorage:", err);
|
||||
return initialValue;
|
||||
}
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
window.localStorage.setItem(key, JSON.stringify(state));
|
||||
}, [key, state]);
|
||||
|
||||
return [state, setState];
|
||||
}
|
||||
3
frontend/src/index.css
Normal file
3
frontend/src/index.css
Normal file
@@ -0,0 +1,3 @@
|
||||
@tailwind base;
|
||||
@tailwind components;
|
||||
@tailwind utilities;
|
||||
9
frontend/src/main.jsx
Normal file
9
frontend/src/main.jsx
Normal file
@@ -0,0 +1,9 @@
|
||||
import React from "react";
|
||||
import { createRoot } from "react-dom/client";
|
||||
import App from "./pages/App";
|
||||
import "./index.css"; // Tailwind imports
|
||||
|
||||
const container = document.getElementById("root");
|
||||
const root = createRoot(container);
|
||||
|
||||
root.render(<App />);
|
||||
114
frontend/src/pages/App.jsx
Normal file
114
frontend/src/pages/App.jsx
Normal file
@@ -0,0 +1,114 @@
|
||||
import React, { useEffect, useState } from "react";
|
||||
import NavBar from "../components/NavBar";
|
||||
import ChatWindow from "../components/ChatWindow";
|
||||
|
||||
const POLL_INTERVAL = 500; // 0.5 seconds
|
||||
|
||||
export default function App() {
|
||||
const [conversation, setConversation] = useState([]);
|
||||
const [userInput, setUserInput] = useState("");
|
||||
const [loading, setLoading] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
// Poll /get-conversation-history once per second
|
||||
const intervalId = setInterval(async () => {
|
||||
try {
|
||||
const res = await fetch("http://127.0.0.1:8000/get-conversation-history");
|
||||
if (res.ok) {
|
||||
const data = await res.json();
|
||||
// data is now an object like { messages: [ ... ] }
|
||||
|
||||
if (data.messages && data.messages.some(msg => msg.actor === "response" || msg.actor === "tool_result")) {
|
||||
setLoading(false);
|
||||
}
|
||||
setConversation(data.messages || []);
|
||||
}
|
||||
} catch (err) {
|
||||
console.error("Error fetching conversation history:", err);
|
||||
}
|
||||
}, POLL_INTERVAL);
|
||||
|
||||
return () => clearInterval(intervalId);
|
||||
}, []);
|
||||
|
||||
const handleSendMessage = async () => {
|
||||
if (!userInput.trim()) return;
|
||||
try {
|
||||
setLoading(true); // <--- Mark as loading
|
||||
await fetch(
|
||||
`http://127.0.0.1:8000/send-prompt?prompt=${encodeURIComponent(userInput)}`,
|
||||
{ method: "POST" }
|
||||
);
|
||||
setUserInput("");
|
||||
} catch (err) {
|
||||
console.error("Error sending prompt:", err);
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleConfirm = async () => {
|
||||
try {
|
||||
setLoading(true);
|
||||
await fetch("http://127.0.0.1:8000/confirm", { method: "POST" });
|
||||
} catch (err) {
|
||||
console.error("Confirm error:", err);
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleStartNewChat = async () => {
|
||||
try {
|
||||
await fetch("http://127.0.0.1:8000/end-chat", { method: "POST" });
|
||||
// sleep for a bit to allow the server to process the end-chat request
|
||||
await new Promise((resolve) => setTimeout(resolve, 4000)); // todo make less dodgy
|
||||
await fetch(
|
||||
`http://127.0.0.1:8000/send-prompt?prompt=${encodeURIComponent("I'd like to travel to an event.")}`,
|
||||
{ method: "POST" }
|
||||
);
|
||||
setConversation([]); // clear local state
|
||||
} catch (err) {
|
||||
console.error("Error ending chat:", err);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex flex-col min-h-screen">
|
||||
<NavBar title="Temporal AI Agent" />
|
||||
<div className="flex-grow flex justify-center px-4 py-6">
|
||||
<div className="w-full max-w-lg bg-white dark:bg-gray-900 p-4 rounded shadow-md flex flex-col">
|
||||
{/* Pass down the array of messages to ChatWindow */}
|
||||
<ChatWindow
|
||||
conversation={conversation}
|
||||
loading={loading}
|
||||
onConfirm={handleConfirm}
|
||||
/>
|
||||
|
||||
<div className="flex items-center mt-4">
|
||||
<input
|
||||
type="text"
|
||||
className="flex-grow rounded-l px-3 py-2 border border-gray-300
|
||||
dark:bg-gray-700 dark:border-gray-600 focus:outline-none"
|
||||
placeholder="Type your message..."
|
||||
value={userInput}
|
||||
onChange={(e) => setUserInput(e.target.value)}
|
||||
/>
|
||||
<button
|
||||
onClick={handleSendMessage}
|
||||
className="bg-blue-600 hover:bg-blue-700 text-white px-4 py-2 rounded-r"
|
||||
>
|
||||
Send
|
||||
</button>
|
||||
</div>
|
||||
<div className="text-right mt-3">
|
||||
<button
|
||||
onClick={handleStartNewChat}
|
||||
className="text-sm underline text-gray-600 dark:text-gray-400 hover:text-gray-800 dark:hover:text-gray-200"
|
||||
>
|
||||
Start New Chat
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
12
frontend/tailwind.config.js
Normal file
12
frontend/tailwind.config.js
Normal file
@@ -0,0 +1,12 @@
|
||||
/** @type {import('tailwindcss').Config} */
|
||||
module.exports = {
|
||||
content: [
|
||||
"./public/index.html",
|
||||
"./src/**/*.{js,jsx,ts,tsx}",
|
||||
],
|
||||
darkMode: "class", // enable dark mode by toggling a .dark class
|
||||
theme: {
|
||||
extend: {},
|
||||
},
|
||||
plugins: [],
|
||||
};
|
||||
9
frontend/vite.config.js
Normal file
9
frontend/vite.config.js
Normal file
@@ -0,0 +1,9 @@
|
||||
import { defineConfig } from 'vite';
|
||||
import react from '@vitejs/plugin-react';
|
||||
|
||||
export default defineConfig({
|
||||
plugins: [react()],
|
||||
server: {
|
||||
open: true,
|
||||
},
|
||||
});
|
||||
@@ -3,7 +3,7 @@ import sys
|
||||
from temporalio.client import Client
|
||||
|
||||
from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams
|
||||
from tools.tool_registry import all_tools # <–– Import your pre-defined tools
|
||||
from tools.tool_registry import all_tools
|
||||
from workflows.tool_workflow import ToolWorkflow
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from collections import deque
|
||||
from datetime import timedelta
|
||||
from typing import Deque, List, Optional, Tuple
|
||||
from typing import Dict, Any, Union, List, Optional, Tuple, Deque
|
||||
from temporalio.common import RetryPolicy
|
||||
|
||||
from temporalio import workflow
|
||||
@@ -16,7 +16,9 @@ with workflow.unsafe.imports_passed_through():
|
||||
@workflow.defn
|
||||
class ToolWorkflow:
|
||||
def __init__(self) -> None:
|
||||
self.conversation_history: List[Tuple[str, str]] = []
|
||||
self.conversation_history: Dict[
|
||||
str, List[Dict[str, Union[str, Dict[str, Any]]]]
|
||||
] = {"messages": []}
|
||||
self.prompt_queue: Deque[str] = deque()
|
||||
self.conversation_summary: Optional[str] = None
|
||||
self.chat_ended: bool = False
|
||||
@@ -31,9 +33,7 @@ class ToolWorkflow:
|
||||
tool_data = None
|
||||
|
||||
if params and params.conversation_summary:
|
||||
self.conversation_history.append(
|
||||
("conversation_summary", params.conversation_summary)
|
||||
)
|
||||
self.add_message("conversation_summary", params.conversation_summary)
|
||||
self.conversation_summary = params.conversation_summary
|
||||
|
||||
if params and params.prompt_queue:
|
||||
@@ -51,7 +51,7 @@ class ToolWorkflow:
|
||||
# 1) If chat_ended was signaled, handle end and return
|
||||
if self.chat_ended:
|
||||
# possibly do a summary if multiple turns
|
||||
if len(self.conversation_history) > 1:
|
||||
if len(self.conversation_history["messages"]) > 1:
|
||||
summary_context, summary_prompt = self.prompt_summary_with_history()
|
||||
summary_input = ToolPromptInput(
|
||||
prompt=summary_prompt, context_instructions=summary_context
|
||||
@@ -73,6 +73,11 @@ class ToolWorkflow:
|
||||
self.confirm = False
|
||||
waiting_for_confirm = False
|
||||
|
||||
confirmed_tool_data = self.tool_data.copy()
|
||||
|
||||
confirmed_tool_data["next"] = "confirmed"
|
||||
self.add_message("userToolConfirm", confirmed_tool_data)
|
||||
|
||||
# Run the tool
|
||||
workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}")
|
||||
dynamic_result = await workflow.execute_activity(
|
||||
@@ -80,15 +85,14 @@ class ToolWorkflow:
|
||||
self.tool_data["args"],
|
||||
schedule_to_close_timeout=timedelta(seconds=20),
|
||||
)
|
||||
self.conversation_history.append(
|
||||
(f"{current_tool}_result", str(dynamic_result))
|
||||
)
|
||||
dynamic_result["tool"] = current_tool
|
||||
self.add_message(f"tool_result", dynamic_result)
|
||||
|
||||
# Enqueue a follow-up prompt for the LLM
|
||||
self.prompt_queue.append(
|
||||
f"The '{current_tool}' tool completed successfully with {dynamic_result}. "
|
||||
f"### The '{current_tool}' tool completed successfully with {dynamic_result}. "
|
||||
"INSTRUCTIONS: Use this tool result, and the conversation history to figure out next steps. "
|
||||
"If all listed tools have run, then produce a done response."
|
||||
"IMPORTANT: If all listed tools have run, you are up to the final step. Mark 'next':'done' and respond with your final confirmation."
|
||||
)
|
||||
# Loop around again
|
||||
continue
|
||||
@@ -96,7 +100,11 @@ class ToolWorkflow:
|
||||
# 3) If there's a user prompt waiting, process it (unless we're in some other skipping logic).
|
||||
if self.prompt_queue:
|
||||
prompt = self.prompt_queue.popleft()
|
||||
self.conversation_history.append(("user", prompt))
|
||||
if prompt.startswith("###"):
|
||||
# this is a custom prompt where the tool result is sent to the LLM
|
||||
self.add_message("tool_result_to_llm", prompt)
|
||||
else:
|
||||
self.add_message("user", prompt)
|
||||
|
||||
# Pass entire conversation + Tools to LLM
|
||||
context_instructions = generate_genai_prompt(
|
||||
@@ -115,7 +123,7 @@ class ToolWorkflow:
|
||||
),
|
||||
)
|
||||
self.tool_data = tool_data
|
||||
self.conversation_history.append(("response", str(tool_data)))
|
||||
self.add_message("response", tool_data)
|
||||
|
||||
# Check the next step from LLM
|
||||
next_step = self.tool_data.get("next")
|
||||
@@ -134,7 +142,10 @@ class ToolWorkflow:
|
||||
|
||||
# Possibly continue-as-new after many turns
|
||||
# todo ensure this doesn't lose critical context
|
||||
if len(self.conversation_history) >= self.max_turns_before_continue:
|
||||
if (
|
||||
len(self.conversation_history["messages"])
|
||||
>= self.max_turns_before_continue
|
||||
):
|
||||
summary_context, summary_prompt = self.prompt_summary_with_history()
|
||||
summary_input = ToolPromptInput(
|
||||
prompt=summary_prompt, context_instructions=summary_context
|
||||
@@ -175,7 +186,10 @@ class ToolWorkflow:
|
||||
self.confirm = True
|
||||
|
||||
@workflow.query
|
||||
def get_conversation_history(self) -> List[Tuple[str, str]]:
|
||||
def get_conversation_history(
|
||||
self,
|
||||
) -> Dict[str, List[Dict[str, Union[str, Dict[str, Any]]]]]:
|
||||
# Return the whole conversation as a dict
|
||||
return self.conversation_history
|
||||
|
||||
@workflow.query
|
||||
@@ -187,8 +201,11 @@ class ToolWorkflow:
|
||||
return self.tool_data
|
||||
|
||||
# Helper: generate text of the entire conversation so far
|
||||
|
||||
def format_history(self) -> str:
|
||||
return " ".join(f"{text}" for _, text in self.conversation_history)
|
||||
return " ".join(
|
||||
str(msg["response"]) for msg in self.conversation_history["messages"]
|
||||
)
|
||||
|
||||
# Return (context_instructions, prompt)
|
||||
def prompt_with_history(self, prompt: str) -> tuple[str, str]:
|
||||
@@ -210,3 +227,9 @@ class ToolWorkflow:
|
||||
'Put the summary in the format { "summary": "<plain text>" }'
|
||||
)
|
||||
return (context_instructions, actual_prompt)
|
||||
|
||||
def add_message(self, actor: str, response: Union[str, Dict[str, Any]]) -> None:
|
||||
# Append a message object to the "messages" list
|
||||
self.conversation_history["messages"].append(
|
||||
{"actor": actor, "response": response}
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user