mirror of https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 14:08:08 +01:00
prompt validator to prevent going off beaten track
@@ -1,4 +1,3 @@
from dataclasses import dataclass
from temporalio import activity
from ollama import chat, ChatResponse
from openai import OpenAI
@@ -11,6 +10,7 @@ import google.generativeai as genai
import anthropic
import deepseek
from dotenv import load_dotenv
from models.data_types import ValidationInput, ValidationResult, ToolPromptInput
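
# For reference, a minimal sketch of the shapes this file expects from
# models.data_types, reconstructed from how the fields are used below
# (the real definitions in the repo may differ):
#
#     @dataclass
#     class ToolArgument:
#         name: str
#         type: str
#
#     @dataclass
#     class Tool:
#         name: str
#         description: str
#         arguments: list[ToolArgument]
#
#     @dataclass
#     class AgentGoal:
#         description: str
#         tools: list[Tool]
#
#     @dataclass
#     class ValidationInput:
#         prompt: str
#         conversation_history: list[dict]
#         agent_goal: AgentGoal
#
#     @dataclass
#     class ValidationResult:
#         validationResult: bool
#         validationFailedReason: dict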

load_dotenv(override=True)

print(
@@ -23,13 +23,66 @@ if os.environ.get("LLM_PROVIDER") == "ollama":
    print("Using Ollama (local) model: " + os.environ.get("OLLAMA_MODEL_NAME"))


@dataclass
class ToolPromptInput:
    prompt: str
    context_instructions: str


class ToolActivities:
    @activity.defn
    async def validate_llm_prompt(
        self, validation_input: ValidationInput
    ) -> ValidationResult:
        """
        Validates the prompt in the context of the conversation history and agent goal.
        Returns a ValidationResult indicating if the prompt makes sense given the context.
        """
        # Create simple context string describing tools and goals
        tools_description = []
        for tool in validation_input.agent_goal.tools:
            tool_str = f"Tool: {tool.name}\n"
            tool_str += f"Description: {tool.description}\n"
            tool_str += "Arguments: " + ", ".join(
                [f"{arg.name} ({arg.type})" for arg in tool.arguments]
            )
            tools_description.append(tool_str)
        tools_str = "\n".join(tools_description)

        # Convert conversation history to string
        history_str = json.dumps(validation_input.conversation_history, indent=2)

        # Create context instructions
        context_instructions = f"""The agent goal and tools are as follows:
            Description: {validation_input.agent_goal.description}
            Available Tools:
            {tools_str}
            The conversation history to date is:
            {history_str}"""

        # Create validation prompt
        validation_prompt = f"""The user's prompt is: "{validation_input.prompt}"
            Please validate if this prompt makes sense given the agent goal and conversation history.
            If the prompt makes sense toward the goal then validationResult should be true.
            Only return false if the prompt is nonsensical given the goal, tools available, and conversation history.
            Return ONLY a JSON object with the following structure:
            "validationResult": true/false,
            "validationFailedReason": "If validationResult is false, provide a clear explanation to the user
            about why their request doesn't make sense in the context and what information they should provide instead.
            validationFailedReason should contain JSON in the format
            {{
                "next": "question",
                "response": "[your reason here and a response to get the user back on track with the agent goal]"
            }}
            If validationResult is true, return an empty dict {{}}"
        """

        # Call the LLM with the validation prompt
        prompt_input = ToolPromptInput(
            prompt=validation_prompt, context_instructions=context_instructions
        )

        result = self.prompt_llm(prompt_input)

        return ValidationResult(
            validationResult=result.get("validationResult", False),
            validationFailedReason=result.get("validationFailedReason", {}),
        )
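
    # A sketch of how a workflow might invoke this activity; the variable
    # names and timeout are assumptions, not part of this commit:
    #
    #     validation = await workflow.execute_activity_method(
    #         ToolActivities.validate_llm_prompt,
    #         ValidationInput(
    #             prompt=user_prompt,
    #             conversation_history=history,
    #             agent_goal=goal,
    #         ),
    #         start_to_close_timeout=timedelta(seconds=30),
    #     )
    #     if not validation.validationResult:
    #         # surface validation.validationFailedReason["response"] to the user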

    @activity.defn
    def prompt_llm(self, input: ToolPromptInput) -> dict:
        llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower()
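
The diff is truncated here; the rest of prompt_llm is not shown. A minimal
sketch of how a provider switch like this might continue, assuming it routes
on LLM_PROVIDER and parses the model's reply as JSON (apart from names
already visible in the diff, everything below is an assumption, not this
commit's code):

    import json
    import os

    from ollama import chat, ChatResponse
    from openai import OpenAI


    def prompt_llm_sketch(input: ToolPromptInput) -> dict:
        llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower()
        if llm_provider == "ollama":
            # Local model via the ollama client.
            response: ChatResponse = chat(
                model=os.environ.get("OLLAMA_MODEL_NAME"),
                messages=[
                    {"role": "system", "content": input.context_instructions},
                    {"role": "user", "content": input.prompt},
                ],
            )
            content = response.message.content
        else:
            # Hosted model via the OpenAI client.
            client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
            completion = client.chat.completions.create(
                model=os.environ.get("OPENAI_MODEL_NAME", "gpt-4o"),
                messages=[
                    {"role": "system", "content": input.context_instructions},
                    {"role": "user", "content": input.prompt},
                ],
            )
            content = completion.choices[0].message.content
        # The activity contract is a dict, so parse the JSON reply; a real
        # implementation would need to handle malformed output.
        return json.loads(content)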