prompt validator to prevent going off the beaten track

This commit is contained in:
Steve Androulakis
2025-02-03 13:07:27 -08:00
parent 8cf2e891e9
commit 2a2383bb71
5 changed files with 126 additions and 17 deletions

View File

@@ -68,6 +68,7 @@ Note: I found the other (hosted) LLMs to be MUCH more reliable for this use case
* Requires a Stripe key for the `create_invoice` tool. Set this in the `STRIPE_API_KEY` environment variable in .env
* It's free to sign up and get a key at [Stripe](https://stripe.com/)
* If you're lazy go to `tools/create_invoice.py` and replace the `create_invoice` function with the mock `create_invoice_example` that exists in the same file.
* Requires a key from [Football Data](https://www.football-data.org). Sign up for a free account, then see the 'My Account' page to get your API token. Set `FOOTBALL_DATA_API_KEY` to this value.
## Configuring Temporal Connection
@@ -137,3 +138,11 @@ Access the UI at `http://localhost:5173`
- Continue-as-new shouldn't be a big consideration for this use case (as it would take many conversational turns to trigger). Regardless, I should ensure that it's able to carry the agent state over to the new workflow execution.
- Perhaps the UI should show when the LLM response is being retried (i.e. activity retry attempt because the LLM provided bad output)
- Tests would be nice!
# TODO for this branch
## Agent
- We'll have to figure out which matches are where. No use going to Manchester for a match that isn't there.
## Validator function
- Probably keep data types, but move the activity and workflow code for the demo
- Probably don't need the validator function if it's the result from a tool call or confirmation step

View File

@@ -1,4 +1,3 @@
from dataclasses import dataclass
from temporalio import activity
from ollama import chat, ChatResponse
from openai import OpenAI
@@ -11,6 +10,7 @@ import google.generativeai as genai
import anthropic
import deepseek
from dotenv import load_dotenv
from models.data_types import ValidationInput, ValidationResult, ToolPromptInput
load_dotenv(override=True)
print(
@@ -23,13 +23,66 @@ if os.environ.get("LLM_PROVIDER") == "ollama":
print("Using Ollama (local) model: " + os.environ.get("OLLAMA_MODEL_NAME"))
@dataclass
class ToolPromptInput:
    """Input for a raw LLM call: the prompt text plus context instructions.

    NOTE(review): this appears to duplicate the ToolPromptInput imported from
    models.data_types above — confirm which copy is authoritative.
    """

    # The user-facing prompt to send to the LLM.
    prompt: str
    # System/context instructions framing how the LLM should respond.
    context_instructions: str
class ToolActivities:
@activity.defn
async def validate_llm_prompt(
    self, validation_input: ValidationInput
) -> ValidationResult:
    """
    Validate a user prompt against the agent goal and conversation history.

    Builds a textual description of the agent goal and its available tools,
    serializes the conversation so far, and asks the LLM (via prompt_llm)
    whether the new prompt makes sense in that context.

    Args:
        validation_input: The prompt to validate, the conversation history,
            and the agent goal (description + tool definitions).

    Returns:
        ValidationResult with validationResult=True when the prompt fits the
        goal; otherwise False plus a validationFailedReason payload that
        steers the user back on track.
    """
    # Describe each tool (name, description, argument signature) so the
    # LLM knows what the agent can actually do.
    tools_description = []
    for tool in validation_input.agent_goal.tools:
        tool_str = f"Tool: {tool.name}\n"
        tool_str += f"Description: {tool.description}\n"
        tool_str += "Arguments: " + ", ".join(
            [f"{arg.name} ({arg.type})" for arg in tool.arguments]
        )
        tools_description.append(tool_str)
    tools_str = "\n".join(tools_description)

    # Serialize the conversation history for the LLM's context.
    history_str = json.dumps(validation_input.conversation_history, indent=2)

    # Context instructions handed to the LLM alongside the validation prompt.
    context_instructions = f"""The agent goal and tools are as follows:
Description: {validation_input.agent_goal.description}
Available Tools:
{tools_str}
The conversation history to date is:
{history_str}"""

    # Validation prompt. BUGFIX: the previous wording said validationResult
    # should be true when the prompt "doesn't make sense", contradicting the
    # very next sentence ("Only return false if the prompt is nonsensical")
    # and the consuming code, which treats False as a validation failure.
    # Also added the missing enclosing braces to the example JSON structure.
    validation_prompt = f"""The user's prompt is: "{validation_input.prompt}"
Please validate if this prompt makes sense given the agent goal and conversation history.
If the prompt makes sense toward the goal then validationResult should be true.
Only return false if the prompt is nonsensical given the goal, tools available, and conversation history.
Return ONLY a JSON object with the following structure:
{{
"validationResult": true/false,
"validationFailedReason": "If validationResult is false, provide a clear explanation to the user
about why their request doesn't make sense in the context and what information they should provide instead.
validationFailedReason should contain JSON in the format
{{
"next": "question",
"response": "[your reason here and a response to get the user back on track with the agent goal]"
}}
If validationResult is true, return an empty dict {{}}"
}}
"""

    # Ask the LLM; prompt_llm returns the parsed JSON response as a dict.
    prompt_input = ToolPromptInput(
        prompt=validation_prompt, context_instructions=context_instructions
    )
    result = self.prompt_llm(prompt_input)
    return ValidationResult(
        validationResult=result.get("validationResult", False),
        validationFailedReason=result.get("validationFailedReason", {}),
    )
@activity.defn
def prompt_llm(self, input: ToolPromptInput) -> dict:
llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower()

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Optional, Deque
from typing import Optional, Deque, Dict, Any, List, Union, Literal
from models.tool_definitions import AgentGoal
@@ -13,3 +13,32 @@ class ToolWorkflowParams:
class CombinedInput:
tool_params: ToolWorkflowParams
agent_goal: AgentGoal
# A single chat message: a mapping whose values may be plain strings or
# nested JSON-like payloads (e.g. structured tool output) — see the
# Union value type.
Message = Dict[str, Union[str, Dict[str, Any]]]
# Conversation history: a dict of message lists — presumably keyed by a
# single collection name such as "messages"; verify against the workflow.
ConversationHistory = Dict[str, List[Message]]
# The agent's next action after an LLM turn.
NextStep = Literal["confirm", "question", "done"]
@dataclass
class ToolPromptInput:
    """Input for a direct LLM call: prompt text plus context instructions."""

    # The user-facing prompt to send to the LLM.
    prompt: str
    # System/context instructions framing how the LLM should respond.
    context_instructions: str
@dataclass
class ValidationInput:
    """Input to prompt validation: the new prompt plus its full context."""

    # The user's latest prompt, to be validated.
    prompt: str
    # The conversation so far, giving the validator context.
    conversation_history: ConversationHistory
    # The agent goal (description and tools) the prompt must fit.
    agent_goal: AgentGoal
@dataclass
class ValidationResult:
    """Outcome of validating a user prompt against the agent goal.

    Attributes:
        validationResult: True when the prompt makes sense in context.
        validationFailedReason: Structured explanation for the user when
            validation fails; always an empty dict when validation succeeds.
    """

    validationResult: bool
    # FIX: annotated Optional — the previous annotation was a bare `dict`
    # with a None default, which is incorrect typing. Normalized to {} below.
    validationFailedReason: Optional[dict] = None

    def __post_init__(self):
        # Normalize a missing reason to an empty dict so consumers can
        # always treat the field as a mapping.
        if self.validationFailedReason is None:
            self.validationFailedReason = {}

View File

@@ -24,6 +24,7 @@ async def main():
workflows=[ToolWorkflow],
activities=[
activities.prompt_llm,
activities.validate_llm_prompt,
dynamic_tool_activity,
],
activity_executor=activity_executor,

View File

@@ -1,24 +1,21 @@
from collections import deque
from datetime import timedelta
from typing import Dict, Any, Union, List, Optional, Deque, TypedDict, Literal
from typing import Dict, Any, Union, List, Optional, Deque, TypedDict
from temporalio.common import RetryPolicy
from temporalio import workflow
from models.data_types import ConversationHistory, NextStep, ValidationInput
with workflow.unsafe.imports_passed_through():
from activities.tool_activities import ToolActivities, ToolPromptInput
from activities.tool_activities import ToolActivities
from prompts.agent_prompt_generators import generate_genai_prompt
from models.data_types import CombinedInput, ToolWorkflowParams
from models.data_types import CombinedInput, ToolWorkflowParams, ToolPromptInput
# Constants
MAX_TURNS_BEFORE_CONTINUE = 250
TOOL_ACTIVITY_TIMEOUT = timedelta(seconds=20)
LLM_ACTIVITY_TIMEOUT = timedelta(seconds=60)
# Type definitions
Message = Dict[str, Union[str, Dict[str, Any]]]
ConversationHistory = Dict[str, List[Message]]
NextStep = Literal["confirm", "question", "done"]
LLM_ACTIVITY_TIMEOUT = timedelta(minutes=30)
class ToolData(TypedDict, total=False):
@@ -154,6 +151,26 @@ class ToolWorkflow:
if not prompt.startswith("###"):
self.add_message("user", prompt)
# Validate the prompt before proceeding
validation_input = ValidationInput(
prompt=prompt,
conversation_history=self.conversation_history,
agent_goal=agent_goal,
)
validation_result = await workflow.execute_activity(
ToolActivities.validate_llm_prompt,
args=[validation_input],
schedule_to_close_timeout=LLM_ACTIVITY_TIMEOUT,
retry_policy=RetryPolicy(initial_interval=timedelta(seconds=5)),
)
if not validation_result.validationResult:
# Handle validation failure
self.add_message("agent", validation_result.validationFailedReason)
continue # Skip to the next iteration
# Proceed with generating the context and prompt
context_instructions = generate_genai_prompt(
agent_goal, self.conversation_history, self.tool_data
)