mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 14:08:08 +01:00
prompt validator to prevent going off beaten track
This commit is contained in:
@@ -1,24 +1,21 @@
|
||||
from collections import deque
|
||||
from datetime import timedelta
|
||||
from typing import Dict, Any, Union, List, Optional, Deque, TypedDict, Literal
|
||||
from typing import Dict, Any, Union, List, Optional, Deque, TypedDict
|
||||
|
||||
from temporalio.common import RetryPolicy
|
||||
from temporalio import workflow
|
||||
|
||||
from models.data_types import ConversationHistory, NextStep, ValidationInput
|
||||
|
||||
with workflow.unsafe.imports_passed_through():
|
||||
from activities.tool_activities import ToolActivities, ToolPromptInput
|
||||
from activities.tool_activities import ToolActivities
|
||||
from prompts.agent_prompt_generators import generate_genai_prompt
|
||||
from models.data_types import CombinedInput, ToolWorkflowParams
|
||||
from models.data_types import CombinedInput, ToolWorkflowParams, ToolPromptInput
|
||||
|
||||
# Constants
|
||||
MAX_TURNS_BEFORE_CONTINUE = 250
|
||||
TOOL_ACTIVITY_TIMEOUT = timedelta(seconds=20)
|
||||
LLM_ACTIVITY_TIMEOUT = timedelta(seconds=60)
|
||||
|
||||
# Type definitions
|
||||
Message = Dict[str, Union[str, Dict[str, Any]]]
|
||||
ConversationHistory = Dict[str, List[Message]]
|
||||
NextStep = Literal["confirm", "question", "done"]
|
||||
LLM_ACTIVITY_TIMEOUT = timedelta(minutes=30)
|
||||
|
||||
|
||||
class ToolData(TypedDict, total=False):
|
||||
@@ -154,6 +151,26 @@ class ToolWorkflow:
|
||||
if not prompt.startswith("###"):
|
||||
self.add_message("user", prompt)
|
||||
|
||||
# Validate the prompt before proceeding
|
||||
validation_input = ValidationInput(
|
||||
prompt=prompt,
|
||||
conversation_history=self.conversation_history,
|
||||
agent_goal=agent_goal,
|
||||
)
|
||||
validation_result = await workflow.execute_activity(
|
||||
ToolActivities.validate_llm_prompt,
|
||||
args=[validation_input],
|
||||
schedule_to_close_timeout=LLM_ACTIVITY_TIMEOUT,
|
||||
retry_policy=RetryPolicy(initial_interval=timedelta(seconds=5)),
|
||||
)
|
||||
|
||||
if not validation_result.validationResult:
|
||||
# Handle validation failure
|
||||
self.add_message("agent", validation_result.validationFailedReason)
|
||||
continue # Skip to the next iteration
|
||||
|
||||
# Proceed with generating the context and prompt
|
||||
|
||||
context_instructions = generate_genai_prompt(
|
||||
agent_goal, self.conversation_history, self.tool_data
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user