mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 14:08:08 +01:00
refactor workflow file for clarity
This commit is contained in:
@@ -4,32 +4,24 @@ from typing import Dict, Any, Union, List, Optional, Deque, TypedDict
|
|||||||
|
|
||||||
from temporalio.common import RetryPolicy
|
from temporalio.common import RetryPolicy
|
||||||
from temporalio import workflow
|
from temporalio import workflow
|
||||||
from temporalio.exceptions import ActivityError
|
|
||||||
|
|
||||||
from models.data_types import ConversationHistory, NextStep, ValidationInput
|
from models.data_types import ConversationHistory, NextStep, ValidationInput
|
||||||
|
from workflows.workflow_helpers import LLM_ACTIVITY_START_TO_CLOSE_TIMEOUT, \
|
||||||
|
LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT
|
||||||
|
from workflows import workflow_helpers as helpers
|
||||||
|
|
||||||
with workflow.unsafe.imports_passed_through():
|
with workflow.unsafe.imports_passed_through():
|
||||||
from activities.tool_activities import ToolActivities
|
from activities.tool_activities import ToolActivities
|
||||||
from prompts.agent_prompt_generators import (
|
from prompts.agent_prompt_generators import (
|
||||||
generate_genai_prompt,
|
generate_genai_prompt
|
||||||
generate_tool_completion_prompt,
|
|
||||||
generate_missing_args_prompt,
|
|
||||||
)
|
)
|
||||||
from models.data_types import (
|
from models.data_types import (
|
||||||
CombinedInput,
|
CombinedInput,
|
||||||
AgentGoalWorkflowParams,
|
|
||||||
ToolPromptInput,
|
ToolPromptInput,
|
||||||
)
|
)
|
||||||
|
|
||||||
from shared.config import TEMPORAL_LEGACY_TASK_QUEUE
|
|
||||||
|
|
||||||
# Constants
|
# Constants
|
||||||
MAX_TURNS_BEFORE_CONTINUE = 250
|
MAX_TURNS_BEFORE_CONTINUE = 250
|
||||||
TOOL_ACTIVITY_START_TO_CLOSE_TIMEOUT = timedelta(seconds=10)
|
|
||||||
TOOL_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT = timedelta(minutes=30)
|
|
||||||
LLM_ACTIVITY_START_TO_CLOSE_TIMEOUT = timedelta(seconds=10)
|
|
||||||
LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT = timedelta(minutes=30)
|
|
||||||
|
|
||||||
|
|
||||||
class ToolData(TypedDict, total=False):
|
class ToolData(TypedDict, total=False):
|
||||||
next: NextStep
|
next: NextStep
|
||||||
@@ -37,7 +29,6 @@ class ToolData(TypedDict, total=False):
|
|||||||
args: Dict[str, Any]
|
args: Dict[str, Any]
|
||||||
response: str
|
response: str
|
||||||
|
|
||||||
|
|
||||||
@workflow.defn
|
@workflow.defn
|
||||||
class AgentGoalWorkflow:
|
class AgentGoalWorkflow:
|
||||||
"""Workflow that manages tool execution with user confirmation and conversation history."""
|
"""Workflow that manages tool execution with user confirmation and conversation history."""
|
||||||
@@ -51,82 +42,6 @@ class AgentGoalWorkflow:
|
|||||||
self.confirm: bool = False
|
self.confirm: bool = False
|
||||||
self.tool_results: List[Dict[str, Any]] = []
|
self.tool_results: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
async def _handle_tool_execution(
|
|
||||||
self, current_tool: str, tool_data: ToolData
|
|
||||||
) -> None:
|
|
||||||
"""Execute a tool after confirmation and handle its result."""
|
|
||||||
workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}")
|
|
||||||
|
|
||||||
task_queue = (
|
|
||||||
TEMPORAL_LEGACY_TASK_QUEUE
|
|
||||||
if current_tool in ["SearchTrains", "BookTrains"]
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
dynamic_result = await workflow.execute_activity(
|
|
||||||
current_tool,
|
|
||||||
tool_data["args"],
|
|
||||||
task_queue=task_queue,
|
|
||||||
schedule_to_close_timeout=TOOL_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT,
|
|
||||||
start_to_close_timeout=TOOL_ACTIVITY_START_TO_CLOSE_TIMEOUT,
|
|
||||||
retry_policy=RetryPolicy(
|
|
||||||
initial_interval=timedelta(seconds=5), backoff_coefficient=1
|
|
||||||
),
|
|
||||||
)
|
|
||||||
dynamic_result["tool"] = current_tool
|
|
||||||
self.tool_results.append(dynamic_result)
|
|
||||||
except ActivityError as e:
|
|
||||||
workflow.logger.error(f"Tool execution failed: {str(e)}")
|
|
||||||
dynamic_result = {"error": str(e), "tool": current_tool}
|
|
||||||
|
|
||||||
self.add_message("tool_result", dynamic_result)
|
|
||||||
|
|
||||||
self.prompt_queue.append(generate_tool_completion_prompt(current_tool, dynamic_result))
|
|
||||||
|
|
||||||
async def _handle_missing_args(
|
|
||||||
self, current_tool: str, args: Dict[str, Any], tool_data: ToolData
|
|
||||||
) -> bool:
|
|
||||||
"""Check for missing arguments and handle them if found."""
|
|
||||||
missing_args = [key for key, value in args.items() if value is None]
|
|
||||||
|
|
||||||
if missing_args:
|
|
||||||
self.prompt_queue.append(
|
|
||||||
generate_missing_args_prompt(current_tool, tool_data, missing_args)
|
|
||||||
)
|
|
||||||
workflow.logger.info(
|
|
||||||
f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}"
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def _continue_as_new_if_needed(self, agent_goal: Any) -> None:
|
|
||||||
"""Handle workflow continuation if message limit is reached."""
|
|
||||||
if len(self.conversation_history["messages"]) >= MAX_TURNS_BEFORE_CONTINUE:
|
|
||||||
summary_context, summary_prompt = self.prompt_summary_with_history()
|
|
||||||
summary_input = ToolPromptInput(
|
|
||||||
prompt=summary_prompt, context_instructions=summary_context
|
|
||||||
)
|
|
||||||
self.conversation_summary = await workflow.start_activity_method(
|
|
||||||
ToolActivities.agent_toolPlanner,
|
|
||||||
summary_input,
|
|
||||||
schedule_to_close_timeout=LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT,
|
|
||||||
)
|
|
||||||
workflow.logger.info(
|
|
||||||
f"Continuing as new after {MAX_TURNS_BEFORE_CONTINUE} turns."
|
|
||||||
)
|
|
||||||
workflow.continue_as_new(
|
|
||||||
args=[
|
|
||||||
CombinedInput(
|
|
||||||
tool_params=AgentGoalWorkflowParams(
|
|
||||||
conversation_summary=self.conversation_summary,
|
|
||||||
prompt_queue=self.prompt_queue,
|
|
||||||
),
|
|
||||||
agent_goal=agent_goal,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
@workflow.run
|
@workflow.run
|
||||||
async def run(self, combined_input: CombinedInput) -> str:
|
async def run(self, combined_input: CombinedInput) -> str:
|
||||||
"""Main workflow execution method."""
|
"""Main workflow execution method."""
|
||||||
@@ -160,7 +75,13 @@ class AgentGoalWorkflow:
|
|||||||
confirmed_tool_data["next"] = "user_confirmed_tool_run"
|
confirmed_tool_data["next"] = "user_confirmed_tool_run"
|
||||||
self.add_message("user_confirmed_tool_run", confirmed_tool_data)
|
self.add_message("user_confirmed_tool_run", confirmed_tool_data)
|
||||||
|
|
||||||
await self._handle_tool_execution(current_tool, self.tool_data)
|
await helpers.handle_tool_execution(
|
||||||
|
current_tool,
|
||||||
|
self.tool_data,
|
||||||
|
self.tool_results,
|
||||||
|
self.add_message,
|
||||||
|
self.prompt_queue
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if self.prompt_queue:
|
if self.prompt_queue:
|
||||||
@@ -194,7 +115,6 @@ class AgentGoalWorkflow:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# Proceed with generating the context and prompt
|
# Proceed with generating the context and prompt
|
||||||
|
|
||||||
context_instructions = generate_genai_prompt(
|
context_instructions = generate_genai_prompt(
|
||||||
agent_goal, self.conversation_history, self.tool_data
|
agent_goal, self.conversation_history, self.tool_data
|
||||||
)
|
)
|
||||||
@@ -220,7 +140,7 @@ class AgentGoalWorkflow:
|
|||||||
|
|
||||||
if next_step == "confirm" and current_tool:
|
if next_step == "confirm" and current_tool:
|
||||||
args = tool_data.get("args", {})
|
args = tool_data.get("args", {})
|
||||||
if await self._handle_missing_args(current_tool, args, tool_data):
|
if await helpers.handle_missing_args(current_tool, args, tool_data, self.prompt_queue):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
waiting_for_confirm = True
|
waiting_for_confirm = True
|
||||||
@@ -233,7 +153,13 @@ class AgentGoalWorkflow:
|
|||||||
return str(self.conversation_history)
|
return str(self.conversation_history)
|
||||||
|
|
||||||
self.add_message("agent", tool_data)
|
self.add_message("agent", tool_data)
|
||||||
await self._continue_as_new_if_needed(agent_goal)
|
await helpers.continue_as_new_if_needed(
|
||||||
|
self.conversation_history,
|
||||||
|
self.prompt_queue,
|
||||||
|
agent_goal,
|
||||||
|
MAX_TURNS_BEFORE_CONTINUE,
|
||||||
|
self.add_message
|
||||||
|
)
|
||||||
|
|
||||||
@workflow.signal
|
@workflow.signal
|
||||||
async def user_prompt(self, prompt: str) -> None:
|
async def user_prompt(self, prompt: str) -> None:
|
||||||
@@ -243,17 +169,17 @@ class AgentGoalWorkflow:
|
|||||||
return
|
return
|
||||||
self.prompt_queue.append(prompt)
|
self.prompt_queue.append(prompt)
|
||||||
|
|
||||||
@workflow.signal
|
|
||||||
async def end_chat(self) -> None:
|
|
||||||
"""Signal handler for ending the chat session."""
|
|
||||||
self.chat_ended = True
|
|
||||||
|
|
||||||
@workflow.signal
|
@workflow.signal
|
||||||
async def confirm(self) -> None:
|
async def confirm(self) -> None:
|
||||||
"""Signal handler for user confirmation of tool execution."""
|
"""Signal handler for user confirmation of tool execution."""
|
||||||
workflow.logger.info("Received user confirmation")
|
workflow.logger.info("Received user confirmation")
|
||||||
self.confirm = True
|
self.confirm = True
|
||||||
|
|
||||||
|
@workflow.signal
|
||||||
|
async def end_chat(self) -> None:
|
||||||
|
"""Signal handler for ending the chat session."""
|
||||||
|
self.chat_ended = True
|
||||||
|
|
||||||
@workflow.query
|
@workflow.query
|
||||||
def get_conversation_history(self) -> ConversationHistory:
|
def get_conversation_history(self) -> ConversationHistory:
|
||||||
"""Query handler to retrieve the full conversation history."""
|
"""Query handler to retrieve the full conversation history."""
|
||||||
@@ -261,49 +187,15 @@ class AgentGoalWorkflow:
|
|||||||
|
|
||||||
@workflow.query
|
@workflow.query
|
||||||
def get_summary_from_history(self) -> Optional[str]:
|
def get_summary_from_history(self) -> Optional[str]:
|
||||||
"""Query handler to retrieve the conversation summary if available."""
|
"""Query handler to retrieve the conversation summary if available.
|
||||||
|
Used only for continue as new of the workflow."""
|
||||||
return self.conversation_summary
|
return self.conversation_summary
|
||||||
|
|
||||||
@workflow.query
|
@workflow.query
|
||||||
def get_tool_data(self) -> Optional[ToolData]:
|
def get_latest_tool_data(self) -> Optional[ToolData]:
|
||||||
"""Query handler to retrieve the current tool data if available."""
|
"""Query handler to retrieve the latest tool data response if available."""
|
||||||
return self.tool_data
|
return self.tool_data
|
||||||
|
|
||||||
def format_history(self) -> str:
|
|
||||||
"""Format the conversation history into a single string."""
|
|
||||||
return " ".join(
|
|
||||||
str(msg["response"]) for msg in self.conversation_history["messages"]
|
|
||||||
)
|
|
||||||
|
|
||||||
def prompt_with_history(self, prompt: str) -> tuple[str, str]:
|
|
||||||
"""Generate a context-aware prompt with conversation history.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[str, str]: A tuple of (context_instructions, prompt)
|
|
||||||
"""
|
|
||||||
history_string = self.format_history()
|
|
||||||
context_instructions = (
|
|
||||||
f"Here is the conversation history: {history_string} "
|
|
||||||
"Please add a few sentence response in plain text sentences. "
|
|
||||||
"Don't editorialize or add metadata. "
|
|
||||||
"Keep the text a plain explanation based on the history."
|
|
||||||
)
|
|
||||||
return (context_instructions, prompt)
|
|
||||||
|
|
||||||
def prompt_summary_with_history(self) -> tuple[str, str]:
|
|
||||||
"""Generate a prompt for summarizing the conversation.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[str, str]: A tuple of (context_instructions, prompt)
|
|
||||||
"""
|
|
||||||
history_string = self.format_history()
|
|
||||||
context_instructions = f"Here is the conversation history between a user and a chatbot: {history_string}"
|
|
||||||
actual_prompt = (
|
|
||||||
"Please produce a two sentence summary of this conversation. "
|
|
||||||
'Put the summary in the format { "summary": "<plain text>" }'
|
|
||||||
)
|
|
||||||
return (context_instructions, actual_prompt)
|
|
||||||
|
|
||||||
def add_message(self, actor: str, response: Union[str, Dict[str, Any]]) -> None:
|
def add_message(self, actor: str, response: Union[str, Dict[str, Any]]) -> None:
|
||||||
"""Add a message to the conversation history.
|
"""Add a message to the conversation history.
|
||||||
|
|
||||||
|
|||||||
130
workflows/workflow_helpers.py
Normal file
130
workflows/workflow_helpers.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
from datetime import timedelta
|
||||||
|
from typing import Dict, Any, Deque
|
||||||
|
from temporalio import workflow
|
||||||
|
from temporalio.exceptions import ActivityError
|
||||||
|
from temporalio.common import RetryPolicy
|
||||||
|
|
||||||
|
from models.data_types import ConversationHistory, ToolPromptInput
|
||||||
|
from prompts.agent_prompt_generators import generate_missing_args_prompt, generate_tool_completion_prompt
|
||||||
|
from shared.config import TEMPORAL_LEGACY_TASK_QUEUE
|
||||||
|
|
||||||
|
# Constants from original file
|
||||||
|
TOOL_ACTIVITY_START_TO_CLOSE_TIMEOUT = timedelta(seconds=10)
|
||||||
|
TOOL_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT = timedelta(minutes=30)
|
||||||
|
LLM_ACTIVITY_START_TO_CLOSE_TIMEOUT = timedelta(seconds=10)
|
||||||
|
LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT = timedelta(minutes=30)
|
||||||
|
|
||||||
|
async def handle_tool_execution(
|
||||||
|
current_tool: str,
|
||||||
|
tool_data: Dict[str, Any],
|
||||||
|
tool_results: list,
|
||||||
|
add_message_callback: callable,
|
||||||
|
prompt_queue: Deque[str]
|
||||||
|
) -> None:
|
||||||
|
"""Execute a tool after confirmation and handle its result."""
|
||||||
|
workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}")
|
||||||
|
|
||||||
|
task_queue = (
|
||||||
|
TEMPORAL_LEGACY_TASK_QUEUE
|
||||||
|
if current_tool in ["SearchTrains", "BookTrains"]
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
dynamic_result = await workflow.execute_activity(
|
||||||
|
current_tool,
|
||||||
|
tool_data["args"],
|
||||||
|
task_queue=task_queue,
|
||||||
|
schedule_to_close_timeout=TOOL_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT,
|
||||||
|
start_to_close_timeout=TOOL_ACTIVITY_START_TO_CLOSE_TIMEOUT,
|
||||||
|
retry_policy=RetryPolicy(
|
||||||
|
initial_interval=timedelta(seconds=5), backoff_coefficient=1
|
||||||
|
),
|
||||||
|
)
|
||||||
|
dynamic_result["tool"] = current_tool
|
||||||
|
tool_results.append(dynamic_result)
|
||||||
|
except ActivityError as e:
|
||||||
|
workflow.logger.error(f"Tool execution failed: {str(e)}")
|
||||||
|
dynamic_result = {"error": str(e), "tool": current_tool}
|
||||||
|
|
||||||
|
add_message_callback("tool_result", dynamic_result)
|
||||||
|
prompt_queue.append(generate_tool_completion_prompt(current_tool, dynamic_result))
|
||||||
|
|
||||||
|
async def handle_missing_args(
|
||||||
|
current_tool: str,
|
||||||
|
args: Dict[str, Any],
|
||||||
|
tool_data: Dict[str, Any],
|
||||||
|
prompt_queue: Deque[str]
|
||||||
|
) -> bool:
|
||||||
|
"""Check for missing arguments and handle them if found."""
|
||||||
|
missing_args = [key for key, value in args.items() if value is None]
|
||||||
|
|
||||||
|
if missing_args:
|
||||||
|
prompt_queue.append(
|
||||||
|
generate_missing_args_prompt(current_tool, tool_data, missing_args)
|
||||||
|
)
|
||||||
|
workflow.logger.info(
|
||||||
|
f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}"
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def format_history(conversation_history: ConversationHistory) -> str:
|
||||||
|
"""Format the conversation history into a single string."""
|
||||||
|
return " ".join(
|
||||||
|
str(msg["response"]) for msg in conversation_history["messages"]
|
||||||
|
)
|
||||||
|
|
||||||
|
def prompt_with_history(conversation_history: ConversationHistory, prompt: str) -> tuple[str, str]:
|
||||||
|
"""Generate a context-aware prompt with conversation history."""
|
||||||
|
history_string = format_history(conversation_history)
|
||||||
|
context_instructions = (
|
||||||
|
f"Here is the conversation history: {history_string} "
|
||||||
|
"Please add a few sentence response in plain text sentences. "
|
||||||
|
"Don't editorialize or add metadata. "
|
||||||
|
"Keep the text a plain explanation based on the history."
|
||||||
|
)
|
||||||
|
return (context_instructions, prompt)
|
||||||
|
|
||||||
|
async def continue_as_new_if_needed(
|
||||||
|
conversation_history: ConversationHistory,
|
||||||
|
prompt_queue: Deque[str],
|
||||||
|
agent_goal: Any,
|
||||||
|
max_turns: int,
|
||||||
|
add_message_callback: callable
|
||||||
|
) -> None:
|
||||||
|
"""Handle workflow continuation if message limit is reached."""
|
||||||
|
if len(conversation_history["messages"]) >= max_turns:
|
||||||
|
summary_context, summary_prompt = prompt_summary_with_history(conversation_history)
|
||||||
|
summary_input = ToolPromptInput(
|
||||||
|
prompt=summary_prompt, context_instructions=summary_context
|
||||||
|
)
|
||||||
|
conversation_summary = await workflow.start_activity_method(
|
||||||
|
"ToolActivities.agent_toolPlanner",
|
||||||
|
summary_input,
|
||||||
|
schedule_to_close_timeout=LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT,
|
||||||
|
)
|
||||||
|
workflow.logger.info(
|
||||||
|
f"Continuing as new after {max_turns} turns."
|
||||||
|
)
|
||||||
|
add_message_callback("conversation_summary", conversation_summary)
|
||||||
|
workflow.continue_as_new(
|
||||||
|
args=[{
|
||||||
|
"tool_params": {
|
||||||
|
"conversation_summary": conversation_summary,
|
||||||
|
"prompt_queue": prompt_queue,
|
||||||
|
},
|
||||||
|
"agent_goal": agent_goal,
|
||||||
|
}]
|
||||||
|
)
|
||||||
|
|
||||||
|
def prompt_summary_with_history(conversation_history: ConversationHistory) -> tuple[str, str]:
|
||||||
|
"""Generate a prompt for summarizing the conversation.
|
||||||
|
Used only for continue as new of the workflow."""
|
||||||
|
history_string = format_history(conversation_history)
|
||||||
|
context_instructions = f"Here is the conversation history between a user and a chatbot: {history_string}"
|
||||||
|
actual_prompt = (
|
||||||
|
"Please produce a two sentence summary of this conversation. "
|
||||||
|
'Put the summary in the format { "summary": "<plain text>" }'
|
||||||
|
)
|
||||||
|
return (context_instructions, actual_prompt)
|
||||||
Reference in New Issue
Block a user