prompt validator to prevent going off the beaten track

Steve Androulakis
2025-02-03 13:07:27 -08:00
parent 8cf2e891e9
commit 2a2383bb71
5 changed files with 126 additions and 17 deletions


@@ -1,4 +1,3 @@
from dataclasses import dataclass
from temporalio import activity
from ollama import chat, ChatResponse
from openai import OpenAI
@@ -11,6 +10,7 @@ import google.generativeai as genai
import anthropic
import deepseek
from dotenv import load_dotenv
from models.data_types import ValidationInput, ValidationResult, ToolPromptInput
load_dotenv(override=True)
print(
@@ -23,13 +23,66 @@ if os.environ.get("LLM_PROVIDER") == "ollama":
print("Using Ollama (local) model: " + os.environ.get("OLLAMA_MODEL_NAME"))
@dataclass
class ToolPromptInput:
prompt: str
context_instructions: str
class ToolActivities:
@activity.defn
async def validate_llm_prompt(
self, validation_input: ValidationInput
) -> ValidationResult:
"""
Validates the prompt in the context of the conversation history and agent goal.
Returns a ValidationResult indicating if the prompt makes sense given the context.
"""
# Create simple context string describing tools and goals
tools_description = []
for tool in validation_input.agent_goal.tools:
tool_str = f"Tool: {tool.name}\n"
tool_str += f"Description: {tool.description}\n"
tool_str += "Arguments: " + ", ".join(
[f"{arg.name} ({arg.type})" for arg in tool.arguments]
)
tools_description.append(tool_str)
tools_str = "\n".join(tools_description)
# Convert conversation history to string
history_str = json.dumps(validation_input.conversation_history, indent=2)
# Create context instructions
context_instructions = f"""The agent goal and tools are as follows:
Description: {validation_input.agent_goal.description}
Available Tools:
{tools_str}
The conversation history to date is:
{history_str}"""
# Create validation prompt
validation_prompt = f"""The user's prompt is: "{validation_input.prompt}"
Please validate if this prompt makes sense given the agent goal and conversation history.
If the prompt makes sense toward the goal then validationResult should be true.
Only return false if the prompt is nonsensical given the goal, tools available, and conversation history.
Return ONLY a JSON object with the following structure:
{{
"validationResult": true/false,
"validationFailedReason": "If validationResult is false, provide a clear explanation to the user
about why their request doesn't make sense in the context and what information they should provide instead.
validationFailedReason should contain JSON in the format
{{
"next": "question",
"response": "[your reason here and a response to get the user back on track with the agent goal]"
}}
If validationResult is true, return an empty dict {{}}"
}}
"""
        # Call the LLM with the validation prompt
        prompt_input = ToolPromptInput(
            prompt=validation_prompt, context_instructions=context_instructions
        )
        result = self.prompt_llm(prompt_input)

        return ValidationResult(
            validationResult=result.get("validationResult", False),
            validationFailedReason=result.get("validationFailedReason", {}),
        )

    @activity.defn
    def prompt_llm(self, input: ToolPromptInput) -> dict:
        llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower()
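
The ValidationInput, ValidationResult, and ToolPromptInput types come from models/data_types.py, one of the other files changed in this commit but not shown here. A minimal sketch of what those dataclasses might look like, inferred purely from how the fields are used above; the helper type names (ToolArgument, Tool, AgentGoal) are assumptions and the real definitions may differ:

# Hypothetical sketch of models/data_types.py, inferred from field usage
# in the diff above; ToolArgument, Tool, and AgentGoal are assumed names.
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class ToolArgument:
    name: str
    type: str


@dataclass
class Tool:
    name: str
    description: str
    arguments: List[ToolArgument]


@dataclass
class AgentGoal:
    description: str
    tools: List[Tool]


@dataclass
class ToolPromptInput:
    prompt: str
    context_instructions: str


@dataclass
class ValidationInput:
    prompt: str
    agent_goal: AgentGoal
    # Only ever passed to json.dumps above, so any JSON-serializable
    # shape would work here.
    conversation_history: Dict[str, Any]


@dataclass
class ValidationResult:
    validationResult: bool
    validationFailedReason: Dict[str, Any] = field(default_factory=dict)

A workflow that starts the validate_llm_prompt activity would then pass a ValidationInput and receive a ValidationResult back.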