mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-17 06:58:09 +01:00
prompt validator to prevent going off beaten track
This commit is contained in:
@@ -68,6 +68,7 @@ Note: I found the other (hosted) LLMs to be MUCH more reliable for this use case
|
|||||||
* Requires a Stripe key for the `create_invoice` tool. Set this in the `STRIPE_API_KEY` environment variable in .env
|
* Requires a Stripe key for the `create_invoice` tool. Set this in the `STRIPE_API_KEY` environment variable in .env
|
||||||
* It's free to sign up and get a key at [Stripe](https://stripe.com/)
|
* It's free to sign up and get a key at [Stripe](https://stripe.com/)
|
||||||
* If you're lazy go to `tools/create_invoice.py` and replace the `create_invoice` function with the mock `create_invoice_example` that exists in the same file.
|
* If you're lazy go to `tools/create_invoice.py` and replace the `create_invoice` function with the mock `create_invoice_example` that exists in the same file.
|
||||||
|
* Requires a key from [Football Data](https://www.football-data.org). Sign up for a free account, then see the 'My Account' page to get your API token. Set `FOOTBALL_DATA_API_KEY` to this value.
|
||||||
|
|
||||||
## Configuring Temporal Connection
|
## Configuring Temporal Connection
|
||||||
|
|
||||||
@@ -137,3 +138,11 @@ Access the UI at `http://localhost:5173`
|
|||||||
- Continue-as-new shouldn't be a big consideration for this use case (as it would take many conversational turns to trigger). Regardless, I should ensure that it's able to carry the agent state over to the new workflow execution.
|
- Continue-as-new shouldn't be a big consideration for this use case (as it would take many conversational turns to trigger). Regardless, I should ensure that it's able to carry the agent state over to the new workflow execution.
|
||||||
- Perhaps the UI should show when the LLM response is being retried (i.e. activity retry attempt because the LLM provided bad output)
|
- Perhaps the UI should show when the LLM response is being retried (i.e. activity retry attempt because the LLM provided bad output)
|
||||||
- Tests would be nice!
|
- Tests would be nice!
|
||||||
|
|
||||||
|
# TODO for this branch
|
||||||
|
## Agent
|
||||||
|
- We'll have to figure out which matches are where. No use going to Manchester for a match that isn't there.
|
||||||
|
|
||||||
|
## Validator function
|
||||||
|
- Probably keep data types, but move the activity and workflow code for the demo
|
||||||
|
- Probably don't need the validator function if its the result from a tool call or confirmation step
|
||||||
@@ -1,4 +1,3 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
from temporalio import activity
|
from temporalio import activity
|
||||||
from ollama import chat, ChatResponse
|
from ollama import chat, ChatResponse
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
@@ -11,6 +10,7 @@ import google.generativeai as genai
|
|||||||
import anthropic
|
import anthropic
|
||||||
import deepseek
|
import deepseek
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from models.data_types import ValidationInput, ValidationResult, ToolPromptInput
|
||||||
|
|
||||||
load_dotenv(override=True)
|
load_dotenv(override=True)
|
||||||
print(
|
print(
|
||||||
@@ -23,13 +23,66 @@ if os.environ.get("LLM_PROVIDER") == "ollama":
|
|||||||
print("Using Ollama (local) model: " + os.environ.get("OLLAMA_MODEL_NAME"))
|
print("Using Ollama (local) model: " + os.environ.get("OLLAMA_MODEL_NAME"))
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ToolPromptInput:
|
|
||||||
prompt: str
|
|
||||||
context_instructions: str
|
|
||||||
|
|
||||||
|
|
||||||
class ToolActivities:
|
class ToolActivities:
|
||||||
|
@activity.defn
|
||||||
|
async def validate_llm_prompt(
|
||||||
|
self, validation_input: ValidationInput
|
||||||
|
) -> ValidationResult:
|
||||||
|
"""
|
||||||
|
Validates the prompt in the context of the conversation history and agent goal.
|
||||||
|
Returns a ValidationResult indicating if the prompt makes sense given the context.
|
||||||
|
"""
|
||||||
|
# Create simple context string describing tools and goals
|
||||||
|
tools_description = []
|
||||||
|
for tool in validation_input.agent_goal.tools:
|
||||||
|
tool_str = f"Tool: {tool.name}\n"
|
||||||
|
tool_str += f"Description: {tool.description}\n"
|
||||||
|
tool_str += "Arguments: " + ", ".join(
|
||||||
|
[f"{arg.name} ({arg.type})" for arg in tool.arguments]
|
||||||
|
)
|
||||||
|
tools_description.append(tool_str)
|
||||||
|
tools_str = "\n".join(tools_description)
|
||||||
|
|
||||||
|
# Convert conversation history to string
|
||||||
|
history_str = json.dumps(validation_input.conversation_history, indent=2)
|
||||||
|
|
||||||
|
# Create context instructions
|
||||||
|
context_instructions = f"""The agent goal and tools are as follows:
|
||||||
|
Description: {validation_input.agent_goal.description}
|
||||||
|
Available Tools:
|
||||||
|
{tools_str}
|
||||||
|
The conversation history to date is:
|
||||||
|
{history_str}"""
|
||||||
|
|
||||||
|
# Create validation prompt
|
||||||
|
validation_prompt = f"""The user's prompt is: "{validation_input.prompt}"
|
||||||
|
Please validate if this prompt makes sense given the agent goal and conversation history.
|
||||||
|
If the prompt doesn't make sense toward the goal then validationResult should be true.
|
||||||
|
Only return false if the prompt is nonsensical given the goal, tools available, and conversation history.
|
||||||
|
Return ONLY a JSON object with the following structure:
|
||||||
|
"validationResult": true/false,
|
||||||
|
"validationFailedReason": "If validationResult is false, provide a clear explanation to the user
|
||||||
|
about why their request doesn't make sense in the context and what information they should provide instead.
|
||||||
|
validationFailedReason should contain JSON in the format
|
||||||
|
{{
|
||||||
|
"next": "question",
|
||||||
|
"response": "[your reason here and a response to get the user back on track with the agent goal]"
|
||||||
|
}}
|
||||||
|
If validationResult is true, return an empty dict {{}}"
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Call the LLM with the validation prompt
|
||||||
|
prompt_input = ToolPromptInput(
|
||||||
|
prompt=validation_prompt, context_instructions=context_instructions
|
||||||
|
)
|
||||||
|
|
||||||
|
result = self.prompt_llm(prompt_input)
|
||||||
|
|
||||||
|
return ValidationResult(
|
||||||
|
validationResult=result.get("validationResult", False),
|
||||||
|
validationFailedReason=result.get("validationFailedReason", {}),
|
||||||
|
)
|
||||||
|
|
||||||
@activity.defn
|
@activity.defn
|
||||||
def prompt_llm(self, input: ToolPromptInput) -> dict:
|
def prompt_llm(self, input: ToolPromptInput) -> dict:
|
||||||
llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower()
|
llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower()
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, Deque
|
from typing import Optional, Deque, Dict, Any, List, Union, Literal
|
||||||
from models.tool_definitions import AgentGoal
|
from models.tool_definitions import AgentGoal
|
||||||
|
|
||||||
|
|
||||||
@@ -13,3 +13,32 @@ class ToolWorkflowParams:
|
|||||||
class CombinedInput:
|
class CombinedInput:
|
||||||
tool_params: ToolWorkflowParams
|
tool_params: ToolWorkflowParams
|
||||||
agent_goal: AgentGoal
|
agent_goal: AgentGoal
|
||||||
|
|
||||||
|
|
||||||
|
Message = Dict[str, Union[str, Dict[str, Any]]]
|
||||||
|
ConversationHistory = Dict[str, List[Message]]
|
||||||
|
NextStep = Literal["confirm", "question", "done"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolPromptInput:
|
||||||
|
prompt: str
|
||||||
|
context_instructions: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ValidationInput:
|
||||||
|
prompt: str
|
||||||
|
conversation_history: ConversationHistory
|
||||||
|
agent_goal: AgentGoal
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ValidationResult:
|
||||||
|
validationResult: bool
|
||||||
|
validationFailedReason: dict = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
# Initialize empty dict if None
|
||||||
|
if self.validationFailedReason is None:
|
||||||
|
self.validationFailedReason = {}
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ async def main():
|
|||||||
workflows=[ToolWorkflow],
|
workflows=[ToolWorkflow],
|
||||||
activities=[
|
activities=[
|
||||||
activities.prompt_llm,
|
activities.prompt_llm,
|
||||||
|
activities.validate_llm_prompt,
|
||||||
dynamic_tool_activity,
|
dynamic_tool_activity,
|
||||||
],
|
],
|
||||||
activity_executor=activity_executor,
|
activity_executor=activity_executor,
|
||||||
|
|||||||
@@ -1,24 +1,21 @@
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Dict, Any, Union, List, Optional, Deque, TypedDict, Literal
|
from typing import Dict, Any, Union, List, Optional, Deque, TypedDict
|
||||||
|
|
||||||
from temporalio.common import RetryPolicy
|
from temporalio.common import RetryPolicy
|
||||||
from temporalio import workflow
|
from temporalio import workflow
|
||||||
|
|
||||||
|
from models.data_types import ConversationHistory, NextStep, ValidationInput
|
||||||
|
|
||||||
with workflow.unsafe.imports_passed_through():
|
with workflow.unsafe.imports_passed_through():
|
||||||
from activities.tool_activities import ToolActivities, ToolPromptInput
|
from activities.tool_activities import ToolActivities
|
||||||
from prompts.agent_prompt_generators import generate_genai_prompt
|
from prompts.agent_prompt_generators import generate_genai_prompt
|
||||||
from models.data_types import CombinedInput, ToolWorkflowParams
|
from models.data_types import CombinedInput, ToolWorkflowParams, ToolPromptInput
|
||||||
|
|
||||||
# Constants
|
# Constants
|
||||||
MAX_TURNS_BEFORE_CONTINUE = 250
|
MAX_TURNS_BEFORE_CONTINUE = 250
|
||||||
TOOL_ACTIVITY_TIMEOUT = timedelta(seconds=20)
|
TOOL_ACTIVITY_TIMEOUT = timedelta(seconds=20)
|
||||||
LLM_ACTIVITY_TIMEOUT = timedelta(seconds=60)
|
LLM_ACTIVITY_TIMEOUT = timedelta(minutes=30)
|
||||||
|
|
||||||
# Type definitions
|
|
||||||
Message = Dict[str, Union[str, Dict[str, Any]]]
|
|
||||||
ConversationHistory = Dict[str, List[Message]]
|
|
||||||
NextStep = Literal["confirm", "question", "done"]
|
|
||||||
|
|
||||||
|
|
||||||
class ToolData(TypedDict, total=False):
|
class ToolData(TypedDict, total=False):
|
||||||
@@ -154,6 +151,26 @@ class ToolWorkflow:
|
|||||||
if not prompt.startswith("###"):
|
if not prompt.startswith("###"):
|
||||||
self.add_message("user", prompt)
|
self.add_message("user", prompt)
|
||||||
|
|
||||||
|
# Validate the prompt before proceeding
|
||||||
|
validation_input = ValidationInput(
|
||||||
|
prompt=prompt,
|
||||||
|
conversation_history=self.conversation_history,
|
||||||
|
agent_goal=agent_goal,
|
||||||
|
)
|
||||||
|
validation_result = await workflow.execute_activity(
|
||||||
|
ToolActivities.validate_llm_prompt,
|
||||||
|
args=[validation_input],
|
||||||
|
schedule_to_close_timeout=LLM_ACTIVITY_TIMEOUT,
|
||||||
|
retry_policy=RetryPolicy(initial_interval=timedelta(seconds=5)),
|
||||||
|
)
|
||||||
|
|
||||||
|
if not validation_result.validationResult:
|
||||||
|
# Handle validation failure
|
||||||
|
self.add_message("agent", validation_result.validationFailedReason)
|
||||||
|
continue # Skip to the next iteration
|
||||||
|
|
||||||
|
# Proceed with generating the context and prompt
|
||||||
|
|
||||||
context_instructions = generate_genai_prompt(
|
context_instructions = generate_genai_prompt(
|
||||||
agent_goal, self.conversation_history, self.tool_data
|
agent_goal, self.conversation_history, self.tool_data
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user