mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 05:58:08 +01:00
Merge pull request #5 from steveandroulakis/validator-and-improvements
Validator and improvements
This commit is contained in:
10
README.md
10
README.md
@@ -68,6 +68,7 @@ Note: I found the other (hosted) LLMs to be MUCH more reliable for this use case
|
||||
* Requires a Stripe key for the `create_invoice` tool. Set this in the `STRIPE_API_KEY` environment variable in .env
|
||||
* It's free to sign up and get a key at [Stripe](https://stripe.com/)
|
||||
* If you're lazy go to `tools/create_invoice.py` and replace the `create_invoice` function with the mock `create_invoice_example` that exists in the same file.
|
||||
* Requires a key from [Football Data](https://www.football-data.org). Sign up for a free account, then see the 'My Account' page to get your API token. Set `FOOTBALL_DATA_API_KEY` to this value.
|
||||
|
||||
## Configuring Temporal Connection
|
||||
|
||||
@@ -137,3 +138,12 @@ Access the UI at `http://localhost:5173`
|
||||
- Continue-as-new shouldn't be a big consideration for this use case (as it would take many conversational turns to trigger). Regardless, I should ensure that it's able to carry the agent state over to the new workflow execution.
|
||||
- Perhaps the UI should show when the LLM response is being retried (i.e. activity retry attempt because the LLM provided bad output)
|
||||
- Tests would be nice!
|
||||
|
||||
# TODO for this branch
|
||||
## Agent
|
||||
- We'll have to figure out which matches are where. No use going to Manchester for a match that isn't there.
|
||||
- The use of `###` in prompts I want excluded from the conversation history is a bit of a hack.
|
||||
|
||||
## Validator function
|
||||
- Probably keep data types, but move the activity and workflow code for the demo
|
||||
- Probably don't need the validator function if its the result from a tool call or confirmation step
|
||||
@@ -1,4 +1,3 @@
|
||||
from dataclasses import dataclass
|
||||
from temporalio import activity
|
||||
from ollama import chat, ChatResponse
|
||||
from openai import OpenAI
|
||||
@@ -11,6 +10,7 @@ import google.generativeai as genai
|
||||
import anthropic
|
||||
import deepseek
|
||||
from dotenv import load_dotenv
|
||||
from models.data_types import ValidationInput, ValidationResult, ToolPromptInput
|
||||
|
||||
load_dotenv(override=True)
|
||||
print(
|
||||
@@ -23,13 +23,66 @@ if os.environ.get("LLM_PROVIDER") == "ollama":
|
||||
print("Using Ollama (local) model: " + os.environ.get("OLLAMA_MODEL_NAME"))
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolPromptInput:
|
||||
prompt: str
|
||||
context_instructions: str
|
||||
|
||||
|
||||
class ToolActivities:
|
||||
@activity.defn
|
||||
async def validate_llm_prompt(
|
||||
self, validation_input: ValidationInput
|
||||
) -> ValidationResult:
|
||||
"""
|
||||
Validates the prompt in the context of the conversation history and agent goal.
|
||||
Returns a ValidationResult indicating if the prompt makes sense given the context.
|
||||
"""
|
||||
# Create simple context string describing tools and goals
|
||||
tools_description = []
|
||||
for tool in validation_input.agent_goal.tools:
|
||||
tool_str = f"Tool: {tool.name}\n"
|
||||
tool_str += f"Description: {tool.description}\n"
|
||||
tool_str += "Arguments: " + ", ".join(
|
||||
[f"{arg.name} ({arg.type})" for arg in tool.arguments]
|
||||
)
|
||||
tools_description.append(tool_str)
|
||||
tools_str = "\n".join(tools_description)
|
||||
|
||||
# Convert conversation history to string
|
||||
history_str = json.dumps(validation_input.conversation_history, indent=2)
|
||||
|
||||
# Create context instructions
|
||||
context_instructions = f"""The agent goal and tools are as follows:
|
||||
Description: {validation_input.agent_goal.description}
|
||||
Available Tools:
|
||||
{tools_str}
|
||||
The conversation history to date is:
|
||||
{history_str}"""
|
||||
|
||||
# Create validation prompt
|
||||
validation_prompt = f"""The user's prompt is: "{validation_input.prompt}"
|
||||
Please validate if this prompt makes sense given the agent goal and conversation history.
|
||||
If the prompt doesn't make sense toward the goal then validationResult should be true.
|
||||
Only return false if the prompt is nonsensical given the goal, tools available, and conversation history.
|
||||
Return ONLY a JSON object with the following structure:
|
||||
"validationResult": true/false,
|
||||
"validationFailedReason": "If validationResult is false, provide a clear explanation to the user
|
||||
about why their request doesn't make sense in the context and what information they should provide instead.
|
||||
validationFailedReason should contain JSON in the format
|
||||
{{
|
||||
"next": "question",
|
||||
"response": "[your reason here and a response to get the user back on track with the agent goal]"
|
||||
}}
|
||||
If validationResult is true, return an empty dict {{}}"
|
||||
"""
|
||||
|
||||
# Call the LLM with the validation prompt
|
||||
prompt_input = ToolPromptInput(
|
||||
prompt=validation_prompt, context_instructions=context_instructions
|
||||
)
|
||||
|
||||
result = self.prompt_llm(prompt_input)
|
||||
|
||||
return ValidationResult(
|
||||
validationResult=result.get("validationResult", False),
|
||||
validationFailedReason=result.get("validationFailedReason", {}),
|
||||
)
|
||||
|
||||
@activity.defn
|
||||
def prompt_llm(self, input: ToolPromptInput) -> dict:
|
||||
llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower()
|
||||
|
||||
23
api/main.py
23
api/main.py
@@ -132,3 +132,26 @@ async def end_chat():
|
||||
print(e)
|
||||
# Workflow not found; return an empty response
|
||||
return {}
|
||||
|
||||
|
||||
@app.post("/start-workflow")
|
||||
async def start_workflow():
|
||||
# Create combined input
|
||||
combined_input = CombinedInput(
|
||||
tool_params=ToolWorkflowParams(None, None),
|
||||
agent_goal=goal_event_flight_invoice,
|
||||
)
|
||||
|
||||
workflow_id = "agent-workflow"
|
||||
|
||||
# Start the workflow with the starter prompt from the goal
|
||||
await temporal_client.start_workflow(
|
||||
ToolWorkflow.run,
|
||||
combined_input,
|
||||
id=workflow_id,
|
||||
task_queue=TEMPORAL_TASK_QUEUE,
|
||||
start_signal="user_prompt",
|
||||
start_signal_args=["### " + goal_event_flight_invoice.starter_prompt],
|
||||
)
|
||||
|
||||
return {"message": f"Workflow started with goal's starter prompt: {goal_event_flight_invoice.starter_prompt}."}
|
||||
|
||||
@@ -167,7 +167,7 @@ export default function App() {
|
||||
try {
|
||||
setError(INITIAL_ERROR_STATE);
|
||||
setLoading(true);
|
||||
await apiService.sendMessage("I'd like to travel for an event.");
|
||||
await apiService.startWorkflow();
|
||||
setConversation([]);
|
||||
setLastMessage(null);
|
||||
} catch (err) {
|
||||
|
||||
@@ -56,6 +56,26 @@ export const apiService = {
|
||||
}
|
||||
},
|
||||
|
||||
async startWorkflow() {
|
||||
try {
|
||||
const res = await fetch(
|
||||
`${API_BASE_URL}/start-workflow`,
|
||||
{
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
}
|
||||
);
|
||||
return handleResponse(res);
|
||||
} catch (error) {
|
||||
throw new ApiError(
|
||||
'Failed to start workflow',
|
||||
error.status || 500
|
||||
);
|
||||
}
|
||||
},
|
||||
|
||||
async confirm() {
|
||||
try {
|
||||
const res = await fetch(`${API_BASE_URL}/confirm`, {
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Deque
|
||||
from typing import Optional, Deque, Dict, Any, List, Union, Literal
|
||||
from models.tool_definitions import AgentGoal
|
||||
|
||||
|
||||
@@ -13,3 +13,32 @@ class ToolWorkflowParams:
|
||||
class CombinedInput:
|
||||
tool_params: ToolWorkflowParams
|
||||
agent_goal: AgentGoal
|
||||
|
||||
|
||||
Message = Dict[str, Union[str, Dict[str, Any]]]
|
||||
ConversationHistory = Dict[str, List[Message]]
|
||||
NextStep = Literal["confirm", "question", "done"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolPromptInput:
|
||||
prompt: str
|
||||
context_instructions: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationInput:
|
||||
prompt: str
|
||||
conversation_history: ConversationHistory
|
||||
agent_goal: AgentGoal
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
validationResult: bool
|
||||
validationFailedReason: dict = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Initialize empty dict if None
|
||||
if self.validationFailedReason is None:
|
||||
self.validationFailedReason = {}
|
||||
|
||||
@@ -20,6 +20,7 @@ class ToolDefinition:
|
||||
class AgentGoal:
|
||||
tools: List[ToolDefinition]
|
||||
description: str = "Description of the tools purpose and overall goal"
|
||||
starter_prompt: str = "Initial prompt to start the conversation"
|
||||
example_conversation_history: str = (
|
||||
"Example conversation history to help the AI agent understand the context of the conversation"
|
||||
)
|
||||
|
||||
@@ -24,6 +24,7 @@ async def main():
|
||||
workflows=[ToolWorkflow],
|
||||
activities=[
|
||||
activities.prompt_llm,
|
||||
activities.validate_llm_prompt,
|
||||
dynamic_tool_activity,
|
||||
],
|
||||
activity_executor=activity_executor,
|
||||
|
||||
@@ -11,6 +11,7 @@ goal_event_flight_invoice = AgentGoal(
|
||||
"1. FindFixtures: Find fixtures for a team in a given month "
|
||||
"2. SearchFlights: search for a flight around the event dates "
|
||||
"3. CreateInvoice: Create a simple invoice for the cost of that flight ",
|
||||
starter_prompt="Welcome me, give me a description of what you can do, then ask me for the details you need to do your job",
|
||||
example_conversation_history="\n ".join(
|
||||
[
|
||||
"user: I'd like to travel to a football match",
|
||||
|
||||
@@ -1,24 +1,21 @@
|
||||
from collections import deque
|
||||
from datetime import timedelta
|
||||
from typing import Dict, Any, Union, List, Optional, Deque, TypedDict, Literal
|
||||
from typing import Dict, Any, Union, List, Optional, Deque, TypedDict
|
||||
|
||||
from temporalio.common import RetryPolicy
|
||||
from temporalio import workflow
|
||||
|
||||
from models.data_types import ConversationHistory, NextStep, ValidationInput
|
||||
|
||||
with workflow.unsafe.imports_passed_through():
|
||||
from activities.tool_activities import ToolActivities, ToolPromptInput
|
||||
from activities.tool_activities import ToolActivities
|
||||
from prompts.agent_prompt_generators import generate_genai_prompt
|
||||
from models.data_types import CombinedInput, ToolWorkflowParams
|
||||
from models.data_types import CombinedInput, ToolWorkflowParams, ToolPromptInput
|
||||
|
||||
# Constants
|
||||
MAX_TURNS_BEFORE_CONTINUE = 250
|
||||
TOOL_ACTIVITY_TIMEOUT = timedelta(seconds=20)
|
||||
LLM_ACTIVITY_TIMEOUT = timedelta(seconds=60)
|
||||
|
||||
# Type definitions
|
||||
Message = Dict[str, Union[str, Dict[str, Any]]]
|
||||
ConversationHistory = Dict[str, List[Message]]
|
||||
NextStep = Literal["confirm", "question", "done"]
|
||||
LLM_ACTIVITY_TIMEOUT = timedelta(minutes=30)
|
||||
|
||||
|
||||
class ToolData(TypedDict, total=False):
|
||||
@@ -153,6 +150,26 @@ class ToolWorkflow:
|
||||
prompt = self.prompt_queue.popleft()
|
||||
if not prompt.startswith("###"):
|
||||
self.add_message("user", prompt)
|
||||
|
||||
# Validate the prompt before proceeding
|
||||
validation_input = ValidationInput(
|
||||
prompt=prompt,
|
||||
conversation_history=self.conversation_history,
|
||||
agent_goal=agent_goal,
|
||||
)
|
||||
validation_result = await workflow.execute_activity(
|
||||
ToolActivities.validate_llm_prompt,
|
||||
args=[validation_input],
|
||||
schedule_to_close_timeout=LLM_ACTIVITY_TIMEOUT,
|
||||
retry_policy=RetryPolicy(initial_interval=timedelta(seconds=5)),
|
||||
)
|
||||
|
||||
if not validation_result.validationResult:
|
||||
# Handle validation failure
|
||||
self.add_message("agent", validation_result.validationFailedReason)
|
||||
continue # Skip to the next iteration
|
||||
|
||||
# Proceed with generating the context and prompt
|
||||
|
||||
context_instructions = generate_genai_prompt(
|
||||
agent_goal, self.conversation_history, self.tool_data
|
||||
|
||||
Reference in New Issue
Block a user