From d996927855d9fc5f3b10c8112545ee800ce8fca6 Mon Sep 17 00:00:00 2001 From: Steve Androulakis Date: Sun, 16 Feb 2025 07:45:44 -0800 Subject: [PATCH] refactor workflow file for clarity --- workflows/agent_goal_workflow.py | 164 ++++++------------------------- workflows/workflow_helpers.py | 130 ++++++++++++++++++++++++ 2 files changed, 158 insertions(+), 136 deletions(-) create mode 100644 workflows/workflow_helpers.py diff --git a/workflows/agent_goal_workflow.py b/workflows/agent_goal_workflow.py index c4d940d..fdc8fc5 100644 --- a/workflows/agent_goal_workflow.py +++ b/workflows/agent_goal_workflow.py @@ -4,32 +4,24 @@ from typing import Dict, Any, Union, List, Optional, Deque, TypedDict from temporalio.common import RetryPolicy from temporalio import workflow -from temporalio.exceptions import ActivityError from models.data_types import ConversationHistory, NextStep, ValidationInput +from workflows.workflow_helpers import LLM_ACTIVITY_START_TO_CLOSE_TIMEOUT, \ + LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT +from workflows import workflow_helpers as helpers with workflow.unsafe.imports_passed_through(): from activities.tool_activities import ToolActivities from prompts.agent_prompt_generators import ( - generate_genai_prompt, - generate_tool_completion_prompt, - generate_missing_args_prompt, + generate_genai_prompt ) from models.data_types import ( CombinedInput, - AgentGoalWorkflowParams, ToolPromptInput, ) -from shared.config import TEMPORAL_LEGACY_TASK_QUEUE - # Constants MAX_TURNS_BEFORE_CONTINUE = 250 -TOOL_ACTIVITY_START_TO_CLOSE_TIMEOUT = timedelta(seconds=10) -TOOL_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT = timedelta(minutes=30) -LLM_ACTIVITY_START_TO_CLOSE_TIMEOUT = timedelta(seconds=10) -LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT = timedelta(minutes=30) - class ToolData(TypedDict, total=False): next: NextStep @@ -37,7 +29,6 @@ class ToolData(TypedDict, total=False): args: Dict[str, Any] response: str - @workflow.defn class AgentGoalWorkflow: """Workflow that manages tool execution with user confirmation and conversation history.""" @@ -51,82 +42,6 @@ class AgentGoalWorkflow: self.confirm: bool = False self.tool_results: List[Dict[str, Any]] = [] - async def _handle_tool_execution( - self, current_tool: str, tool_data: ToolData - ) -> None: - """Execute a tool after confirmation and handle its result.""" - workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}") - - task_queue = ( - TEMPORAL_LEGACY_TASK_QUEUE - if current_tool in ["SearchTrains", "BookTrains"] - else None - ) - - try: - dynamic_result = await workflow.execute_activity( - current_tool, - tool_data["args"], - task_queue=task_queue, - schedule_to_close_timeout=TOOL_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT, - start_to_close_timeout=TOOL_ACTIVITY_START_TO_CLOSE_TIMEOUT, - retry_policy=RetryPolicy( - initial_interval=timedelta(seconds=5), backoff_coefficient=1 - ), - ) - dynamic_result["tool"] = current_tool - self.tool_results.append(dynamic_result) - except ActivityError as e: - workflow.logger.error(f"Tool execution failed: {str(e)}") - dynamic_result = {"error": str(e), "tool": current_tool} - - self.add_message("tool_result", dynamic_result) - - self.prompt_queue.append(generate_tool_completion_prompt(current_tool, dynamic_result)) - - async def _handle_missing_args( - self, current_tool: str, args: Dict[str, Any], tool_data: ToolData - ) -> bool: - """Check for missing arguments and handle them if found.""" - missing_args = [key for key, value in args.items() if value is None] - - if missing_args: - self.prompt_queue.append( - generate_missing_args_prompt(current_tool, tool_data, missing_args) - ) - workflow.logger.info( - f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}" - ) - return True - return False - - async def _continue_as_new_if_needed(self, agent_goal: Any) -> None: - """Handle workflow continuation if message limit is reached.""" - if len(self.conversation_history["messages"]) >= MAX_TURNS_BEFORE_CONTINUE: - summary_context, summary_prompt = self.prompt_summary_with_history() - summary_input = ToolPromptInput( - prompt=summary_prompt, context_instructions=summary_context - ) - self.conversation_summary = await workflow.start_activity_method( - ToolActivities.agent_toolPlanner, - summary_input, - schedule_to_close_timeout=LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT, - ) - workflow.logger.info( - f"Continuing as new after {MAX_TURNS_BEFORE_CONTINUE} turns." - ) - workflow.continue_as_new( - args=[ - CombinedInput( - tool_params=AgentGoalWorkflowParams( - conversation_summary=self.conversation_summary, - prompt_queue=self.prompt_queue, - ), - agent_goal=agent_goal, - ) - ] - ) - @workflow.run async def run(self, combined_input: CombinedInput) -> str: """Main workflow execution method.""" @@ -160,7 +75,13 @@ class AgentGoalWorkflow: confirmed_tool_data["next"] = "user_confirmed_tool_run" self.add_message("user_confirmed_tool_run", confirmed_tool_data) - await self._handle_tool_execution(current_tool, self.tool_data) + await helpers.handle_tool_execution( + current_tool, + self.tool_data, + self.tool_results, + self.add_message, + self.prompt_queue + ) continue if self.prompt_queue: @@ -194,7 +115,6 @@ class AgentGoalWorkflow: continue # Proceed with generating the context and prompt - context_instructions = generate_genai_prompt( agent_goal, self.conversation_history, self.tool_data ) @@ -220,7 +140,7 @@ class AgentGoalWorkflow: if next_step == "confirm" and current_tool: args = tool_data.get("args", {}) - if await self._handle_missing_args(current_tool, args, tool_data): + if await helpers.handle_missing_args(current_tool, args, tool_data, self.prompt_queue): continue waiting_for_confirm = True @@ -233,7 +153,13 @@ class AgentGoalWorkflow: return str(self.conversation_history) self.add_message("agent", tool_data) - await self._continue_as_new_if_needed(agent_goal) + await helpers.continue_as_new_if_needed( + self.conversation_history, + self.prompt_queue, + agent_goal, + MAX_TURNS_BEFORE_CONTINUE, + self.add_message + ) @workflow.signal async def user_prompt(self, prompt: str) -> None: @@ -243,17 +169,17 @@ class AgentGoalWorkflow: return self.prompt_queue.append(prompt) - @workflow.signal - async def end_chat(self) -> None: - """Signal handler for ending the chat session.""" - self.chat_ended = True - @workflow.signal async def confirm(self) -> None: """Signal handler for user confirmation of tool execution.""" workflow.logger.info("Received user confirmation") self.confirm = True + @workflow.signal + async def end_chat(self) -> None: + """Signal handler for ending the chat session.""" + self.chat_ended = True + @workflow.query def get_conversation_history(self) -> ConversationHistory: """Query handler to retrieve the full conversation history.""" @@ -261,49 +187,15 @@ class AgentGoalWorkflow: @workflow.query def get_summary_from_history(self) -> Optional[str]: - """Query handler to retrieve the conversation summary if available.""" + """Query handler to retrieve the conversation summary if available. + Used only for continue as new of the workflow.""" return self.conversation_summary @workflow.query - def get_tool_data(self) -> Optional[ToolData]: - """Query handler to retrieve the current tool data if available.""" + def get_latest_tool_data(self) -> Optional[ToolData]: + """Query handler to retrieve the latest tool data response if available.""" return self.tool_data - def format_history(self) -> str: - """Format the conversation history into a single string.""" - return " ".join( - str(msg["response"]) for msg in self.conversation_history["messages"] - ) - - def prompt_with_history(self, prompt: str) -> tuple[str, str]: - """Generate a context-aware prompt with conversation history. - - Returns: - tuple[str, str]: A tuple of (context_instructions, prompt) - """ - history_string = self.format_history() - context_instructions = ( - f"Here is the conversation history: {history_string} " - "Please add a few sentence response in plain text sentences. " - "Don't editorialize or add metadata. " - "Keep the text a plain explanation based on the history." - ) - return (context_instructions, prompt) - - def prompt_summary_with_history(self) -> tuple[str, str]: - """Generate a prompt for summarizing the conversation. - - Returns: - tuple[str, str]: A tuple of (context_instructions, prompt) - """ - history_string = self.format_history() - context_instructions = f"Here is the conversation history between a user and a chatbot: {history_string}" - actual_prompt = ( - "Please produce a two sentence summary of this conversation. " - 'Put the summary in the format { "summary": "" }' - ) - return (context_instructions, actual_prompt) - def add_message(self, actor: str, response: Union[str, Dict[str, Any]]) -> None: """Add a message to the conversation history. diff --git a/workflows/workflow_helpers.py b/workflows/workflow_helpers.py new file mode 100644 index 0000000..5546a7f --- /dev/null +++ b/workflows/workflow_helpers.py @@ -0,0 +1,130 @@ +from datetime import timedelta +from typing import Dict, Any, Deque +from temporalio import workflow +from temporalio.exceptions import ActivityError +from temporalio.common import RetryPolicy + +from models.data_types import ConversationHistory, ToolPromptInput +from prompts.agent_prompt_generators import generate_missing_args_prompt, generate_tool_completion_prompt +from shared.config import TEMPORAL_LEGACY_TASK_QUEUE + +# Constants from original file +TOOL_ACTIVITY_START_TO_CLOSE_TIMEOUT = timedelta(seconds=10) +TOOL_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT = timedelta(minutes=30) +LLM_ACTIVITY_START_TO_CLOSE_TIMEOUT = timedelta(seconds=10) +LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT = timedelta(minutes=30) + +async def handle_tool_execution( + current_tool: str, + tool_data: Dict[str, Any], + tool_results: list, + add_message_callback: callable, + prompt_queue: Deque[str] +) -> None: + """Execute a tool after confirmation and handle its result.""" + workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}") + + task_queue = ( + TEMPORAL_LEGACY_TASK_QUEUE + if current_tool in ["SearchTrains", "BookTrains"] + else None + ) + + try: + dynamic_result = await workflow.execute_activity( + current_tool, + tool_data["args"], + task_queue=task_queue, + schedule_to_close_timeout=TOOL_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT, + start_to_close_timeout=TOOL_ACTIVITY_START_TO_CLOSE_TIMEOUT, + retry_policy=RetryPolicy( + initial_interval=timedelta(seconds=5), backoff_coefficient=1 + ), + ) + dynamic_result["tool"] = current_tool + tool_results.append(dynamic_result) + except ActivityError as e: + workflow.logger.error(f"Tool execution failed: {str(e)}") + dynamic_result = {"error": str(e), "tool": current_tool} + + add_message_callback("tool_result", dynamic_result) + prompt_queue.append(generate_tool_completion_prompt(current_tool, dynamic_result)) + +async def handle_missing_args( + current_tool: str, + args: Dict[str, Any], + tool_data: Dict[str, Any], + prompt_queue: Deque[str] +) -> bool: + """Check for missing arguments and handle them if found.""" + missing_args = [key for key, value in args.items() if value is None] + + if missing_args: + prompt_queue.append( + generate_missing_args_prompt(current_tool, tool_data, missing_args) + ) + workflow.logger.info( + f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}" + ) + return True + return False + +def format_history(conversation_history: ConversationHistory) -> str: + """Format the conversation history into a single string.""" + return " ".join( + str(msg["response"]) for msg in conversation_history["messages"] + ) + +def prompt_with_history(conversation_history: ConversationHistory, prompt: str) -> tuple[str, str]: + """Generate a context-aware prompt with conversation history.""" + history_string = format_history(conversation_history) + context_instructions = ( + f"Here is the conversation history: {history_string} " + "Please add a few sentence response in plain text sentences. " + "Don't editorialize or add metadata. " + "Keep the text a plain explanation based on the history." + ) + return (context_instructions, prompt) + +async def continue_as_new_if_needed( + conversation_history: ConversationHistory, + prompt_queue: Deque[str], + agent_goal: Any, + max_turns: int, + add_message_callback: callable +) -> None: + """Handle workflow continuation if message limit is reached.""" + if len(conversation_history["messages"]) >= max_turns: + summary_context, summary_prompt = prompt_summary_with_history(conversation_history) + summary_input = ToolPromptInput( + prompt=summary_prompt, context_instructions=summary_context + ) + conversation_summary = await workflow.start_activity_method( + "ToolActivities.agent_toolPlanner", + summary_input, + schedule_to_close_timeout=LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT, + ) + workflow.logger.info( + f"Continuing as new after {max_turns} turns." + ) + add_message_callback("conversation_summary", conversation_summary) + workflow.continue_as_new( + args=[{ + "tool_params": { + "conversation_summary": conversation_summary, + "prompt_queue": prompt_queue, + }, + "agent_goal": agent_goal, + }] + ) + +def prompt_summary_with_history(conversation_history: ConversationHistory) -> tuple[str, str]: + """Generate a prompt for summarizing the conversation. + Used only for continue as new of the workflow.""" + history_string = format_history(conversation_history) + context_instructions = f"Here is the conversation history between a user and a chatbot: {history_string}" + actual_prompt = ( + "Please produce a two sentence summary of this conversation. " + 'Put the summary in the format { "summary": "" }' + ) + return (context_instructions, actual_prompt) \ No newline at end of file