mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-16 06:28:08 +01:00
prompt validator to prevent going off the beaten track
This commit is contained in:
@@ -68,6 +68,7 @@ Note: I found the other (hosted) LLMs to be MUCH more reliable for this use case
|
||||
* Requires a Stripe key for the `create_invoice` tool. Set this in the `STRIPE_API_KEY` environment variable in .env
|
||||
* It's free to sign up and get a key at [Stripe](https://stripe.com/)
|
||||
* If you're lazy, go to `tools/create_invoice.py` and replace the `create_invoice` function with the mock `create_invoice_example` that exists in the same file.
|
||||
* Requires a key from [Football Data](https://www.football-data.org). Sign up for a free account, then see the 'My Account' page to get your API token. Set `FOOTBALL_DATA_API_KEY` to this value.
|
||||
|
||||
## Configuring Temporal Connection
|
||||
|
||||
@@ -137,3 +138,11 @@ Access the UI at `http://localhost:5173`
|
||||
- Continue-as-new shouldn't be a big consideration for this use case (as it would take many conversational turns to trigger). Regardless, I should ensure that it's able to carry the agent state over to the new workflow execution.
|
||||
- Perhaps the UI should show when the LLM response is being retried (i.e. activity retry attempt because the LLM provided bad output)
|
||||
- Tests would be nice!
|
||||
|
||||
# TODO for this branch
|
||||
## Agent
|
||||
- We'll have to figure out which matches are where. No use going to Manchester for a match that isn't there.
|
||||
|
||||
## Validator function
|
||||
- Probably keep data types, but move the activity and workflow code for the demo
|
||||
- Probably don't need the validator function if it's the result from a tool call or confirmation step
|
||||
@@ -1,4 +1,3 @@
|
||||
from dataclasses import dataclass
|
||||
from temporalio import activity
|
||||
from ollama import chat, ChatResponse
|
||||
from openai import OpenAI
|
||||
@@ -11,6 +10,7 @@ import google.generativeai as genai
|
||||
import anthropic
|
||||
import deepseek
|
||||
from dotenv import load_dotenv
|
||||
from models.data_types import ValidationInput, ValidationResult, ToolPromptInput
|
||||
|
||||
load_dotenv(override=True)
|
||||
print(
|
||||
@@ -23,13 +23,66 @@ if os.environ.get("LLM_PROVIDER") == "ollama":
|
||||
print("Using Ollama (local) model: " + os.environ.get("OLLAMA_MODEL_NAME"))
|
||||
|
||||
|
||||
@dataclass
class ToolPromptInput:
    """Input for one LLM call: the prompt text plus the context/system
    instructions that frame it.

    NOTE(review): this appears to duplicate the ToolPromptInput imported
    from models.data_types above — consider keeping only one definition.
    """

    # The prompt text to send to the LLM.
    prompt: str
    # Context/system instructions accompanying the prompt.
    context_instructions: str
|
||||
|
||||
|
||||
class ToolActivities:
|
||||
@activity.defn
|
||||
async def validate_llm_prompt(
|
||||
self, validation_input: ValidationInput
|
||||
) -> ValidationResult:
|
||||
"""
|
||||
Validates the prompt in the context of the conversation history and agent goal.
|
||||
Returns a ValidationResult indicating if the prompt makes sense given the context.
|
||||
"""
|
||||
# Create simple context string describing tools and goals
|
||||
tools_description = []
|
||||
for tool in validation_input.agent_goal.tools:
|
||||
tool_str = f"Tool: {tool.name}\n"
|
||||
tool_str += f"Description: {tool.description}\n"
|
||||
tool_str += "Arguments: " + ", ".join(
|
||||
[f"{arg.name} ({arg.type})" for arg in tool.arguments]
|
||||
)
|
||||
tools_description.append(tool_str)
|
||||
tools_str = "\n".join(tools_description)
|
||||
|
||||
# Convert conversation history to string
|
||||
history_str = json.dumps(validation_input.conversation_history, indent=2)
|
||||
|
||||
# Create context instructions
|
||||
context_instructions = f"""The agent goal and tools are as follows:
|
||||
Description: {validation_input.agent_goal.description}
|
||||
Available Tools:
|
||||
{tools_str}
|
||||
The conversation history to date is:
|
||||
{history_str}"""
|
||||
|
||||
# Create validation prompt
|
||||
validation_prompt = f"""The user's prompt is: "{validation_input.prompt}"
|
||||
Please validate if this prompt makes sense given the agent goal and conversation history.
|
||||
If the prompt doesn't make sense toward the goal then validationResult should be true.
|
||||
Only return false if the prompt is nonsensical given the goal, tools available, and conversation history.
|
||||
Return ONLY a JSON object with the following structure:
|
||||
"validationResult": true/false,
|
||||
"validationFailedReason": "If validationResult is false, provide a clear explanation to the user
|
||||
about why their request doesn't make sense in the context and what information they should provide instead.
|
||||
validationFailedReason should contain JSON in the format
|
||||
{{
|
||||
"next": "question",
|
||||
"response": "[your reason here and a response to get the user back on track with the agent goal]"
|
||||
}}
|
||||
If validationResult is true, return an empty dict {{}}"
|
||||
"""
|
||||
|
||||
# Call the LLM with the validation prompt
|
||||
prompt_input = ToolPromptInput(
|
||||
prompt=validation_prompt, context_instructions=context_instructions
|
||||
)
|
||||
|
||||
result = self.prompt_llm(prompt_input)
|
||||
|
||||
return ValidationResult(
|
||||
validationResult=result.get("validationResult", False),
|
||||
validationFailedReason=result.get("validationFailedReason", {}),
|
||||
)
|
||||
|
||||
@activity.defn
|
||||
def prompt_llm(self, input: ToolPromptInput) -> dict:
|
||||
llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Deque
|
||||
from typing import Optional, Deque, Dict, Any, List, Union, Literal
|
||||
from models.tool_definitions import AgentGoal
|
||||
|
||||
|
||||
@@ -13,3 +13,32 @@ class ToolWorkflowParams:
|
||||
class CombinedInput:
    """Workflow input: the run parameters combined with the agent goal."""

    # Parameters controlling the tool workflow run.
    tool_params: ToolWorkflowParams
    # The goal (and tool set) this agent session pursues.
    agent_goal: AgentGoal
|
||||
|
||||
|
||||
# A single chat message: string fields plus optional structured payloads.
Message = Dict[str, Union[str, Dict[str, Any]]]
# Keyed lists of messages (presumably {"messages": [...]} — verify callers).
ConversationHistory = Dict[str, List[Message]]
# The agent's next action after an LLM turn.
NextStep = Literal["confirm", "question", "done"]
|
||||
|
||||
|
||||
@dataclass
class ToolPromptInput:
    """Input for one LLM call: the prompt text plus the context/system
    instructions that frame it."""

    # The prompt text to send to the LLM.
    prompt: str
    # Context/system instructions accompanying the prompt.
    context_instructions: str
|
||||
|
||||
|
||||
@dataclass
class ValidationInput:
    """Everything needed to validate a user prompt: the prompt itself,
    the conversation so far, and the agent goal it must serve."""

    # The raw user prompt to validate.
    prompt: str
    # Conversation history to date.
    conversation_history: ConversationHistory
    # The goal (and tool set) the prompt is validated against.
    agent_goal: AgentGoal
|
||||
|
||||
|
||||
@dataclass
class ValidationResult:
    """Outcome of validating a user prompt against the agent goal.

    ``validationResult`` is True when the prompt makes sense in context;
    ``validationFailedReason`` holds a structured explanation when it does
    not, and is always a dict after construction (never None).
    """

    validationResult: bool
    # FIX: annotate as Optional — `dict = None` is an invalid implicit
    # Optional (PEP 484). None is accepted and normalized to {} below.
    validationFailedReason: Optional[dict] = None

    def __post_init__(self) -> None:
        # Normalize the None sentinel so callers can always treat the
        # failure reason as a dict.
        if self.validationFailedReason is None:
            self.validationFailedReason = {}
|
||||
|
||||
@@ -24,6 +24,7 @@ async def main():
|
||||
workflows=[ToolWorkflow],
|
||||
activities=[
|
||||
activities.prompt_llm,
|
||||
activities.validate_llm_prompt,
|
||||
dynamic_tool_activity,
|
||||
],
|
||||
activity_executor=activity_executor,
|
||||
|
||||
@@ -1,24 +1,21 @@
|
||||
from collections import deque
|
||||
from datetime import timedelta
|
||||
from typing import Dict, Any, Union, List, Optional, Deque, TypedDict, Literal
|
||||
from typing import Dict, Any, Union, List, Optional, Deque, TypedDict
|
||||
|
||||
from temporalio.common import RetryPolicy
|
||||
from temporalio import workflow
|
||||
|
||||
from models.data_types import ConversationHistory, NextStep, ValidationInput
|
||||
|
||||
with workflow.unsafe.imports_passed_through():
|
||||
from activities.tool_activities import ToolActivities, ToolPromptInput
|
||||
from activities.tool_activities import ToolActivities
|
||||
from prompts.agent_prompt_generators import generate_genai_prompt
|
||||
from models.data_types import CombinedInput, ToolWorkflowParams
|
||||
from models.data_types import CombinedInput, ToolWorkflowParams, ToolPromptInput
|
||||
|
||||
# Constants
|
||||
MAX_TURNS_BEFORE_CONTINUE = 250
|
||||
TOOL_ACTIVITY_TIMEOUT = timedelta(seconds=20)
|
||||
LLM_ACTIVITY_TIMEOUT = timedelta(seconds=60)
|
||||
|
||||
# Type definitions
|
||||
Message = Dict[str, Union[str, Dict[str, Any]]]
|
||||
ConversationHistory = Dict[str, List[Message]]
|
||||
NextStep = Literal["confirm", "question", "done"]
|
||||
LLM_ACTIVITY_TIMEOUT = timedelta(minutes=30)
|
||||
|
||||
|
||||
class ToolData(TypedDict, total=False):
|
||||
@@ -154,6 +151,26 @@ class ToolWorkflow:
|
||||
if not prompt.startswith("###"):
|
||||
self.add_message("user", prompt)
|
||||
|
||||
# Validate the prompt before proceeding
|
||||
validation_input = ValidationInput(
|
||||
prompt=prompt,
|
||||
conversation_history=self.conversation_history,
|
||||
agent_goal=agent_goal,
|
||||
)
|
||||
validation_result = await workflow.execute_activity(
|
||||
ToolActivities.validate_llm_prompt,
|
||||
args=[validation_input],
|
||||
schedule_to_close_timeout=LLM_ACTIVITY_TIMEOUT,
|
||||
retry_policy=RetryPolicy(initial_interval=timedelta(seconds=5)),
|
||||
)
|
||||
|
||||
if not validation_result.validationResult:
|
||||
# Handle validation failure
|
||||
self.add_message("agent", validation_result.validationFailedReason)
|
||||
continue # Skip to the next iteration
|
||||
|
||||
# Proceed with generating the context and prompt
|
||||
|
||||
context_instructions = generate_genai_prompt(
|
||||
agent_goal, self.conversation_history, self.tool_data
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user