From 2f22af500ff2842105b323fbb3126bd04f53280e Mon Sep 17 00:00:00 2001 From: Steve Androulakis Date: Tue, 7 Jan 2025 13:27:53 -0800 Subject: [PATCH] cursor refactor of workflow code --- workflows/tool_workflow.py | 236 +++++++++++++++++++------------------ 1 file changed, 121 insertions(+), 115 deletions(-) diff --git a/workflows/tool_workflow.py b/workflows/tool_workflow.py index 6efea91..b5463c4 100644 --- a/workflows/tool_workflow.py +++ b/workflows/tool_workflow.py @@ -1,37 +1,107 @@ from collections import deque from datetime import timedelta -from typing import Dict, Any, Union, List, Optional, Deque -from temporalio.common import RetryPolicy +from typing import Dict, Any, Union, List, Optional, Deque, TypedDict, Literal +from temporalio.common import RetryPolicy from temporalio import workflow with workflow.unsafe.imports_passed_through(): from activities.tool_activities import ToolActivities, ToolPromptInput - from prompts.agent_prompt_generators import ( - generate_genai_prompt, - ) + from prompts.agent_prompt_generators import generate_genai_prompt from models.data_types import CombinedInput, ToolWorkflowParams +# Constants +MAX_TURNS_BEFORE_CONTINUE = 250 +TOOL_ACTIVITY_TIMEOUT = timedelta(seconds=20) +LLM_ACTIVITY_TIMEOUT = timedelta(seconds=60) + +# Type definitions +Message = Dict[str, Union[str, Dict[str, Any]]] +ConversationHistory = Dict[str, List[Message]] +NextStep = Literal["confirm", "question", "done"] + +class ToolData(TypedDict, total=False): + next: NextStep + tool: str + args: Dict[str, Any] + response: str @workflow.defn class ToolWorkflow: + """Workflow that manages tool execution with user confirmation and conversation history.""" + def __init__(self) -> None: - self.conversation_history: Dict[ - str, List[Dict[str, Union[str, Dict[str, Any]]]] - ] = {"messages": []} + self.conversation_history: ConversationHistory = {"messages": []} self.prompt_queue: Deque[str] = deque() self.conversation_summary: Optional[str] = None self.chat_ended: bool = False - self.tool_data = None - self.max_turns_before_continue: int = 250 - self.confirm = False + self.tool_data: Optional[ToolData] = None + self.confirm: bool = False self.tool_results: List[Dict[str, Any]] = [] + async def _handle_tool_execution(self, current_tool: str, tool_data: ToolData) -> None: + """Execute a tool after confirmation and handle its result.""" + workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}") + + dynamic_result = await workflow.execute_activity( + current_tool, + tool_data["args"], + schedule_to_close_timeout=TOOL_ACTIVITY_TIMEOUT, + ) + dynamic_result["tool"] = current_tool + self.add_message("tool_result", {"tool": current_tool, "result": dynamic_result}) + + self.prompt_queue.append( + f"### The '{current_tool}' tool completed successfully with {dynamic_result}. " + "INSTRUCTIONS: Use this tool result, the list of tools in sequence and the conversation history to figure out next steps, if any. " + "DON'T ask any clarifying questions that are outside of the tools and args specified. " + ) + + async def _handle_missing_args(self, current_tool: str, args: Dict[str, Any], tool_data: ToolData) -> bool: + """Check for missing arguments and handle them if found.""" + missing_args = [key for key, value in args.items() if value is None] + + if missing_args: + self.prompt_queue.append( + f"### INSTRUCTIONS set next='question', combine this response response='{tool_data.get('response')}' " + f"and following missing arguments for tool {current_tool}: {missing_args}. " + "Only provide a valid JSON response without any comments or metadata." + ) + workflow.logger.info(f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}") + return True + return False + + async def _continue_as_new_if_needed(self, agent_goal: Any) -> None: + """Handle workflow continuation if message limit is reached.""" + if len(self.conversation_history["messages"]) >= MAX_TURNS_BEFORE_CONTINUE: + summary_context, summary_prompt = self.prompt_summary_with_history() + summary_input = ToolPromptInput( + prompt=summary_prompt, + context_instructions=summary_context + ) + self.conversation_summary = await workflow.start_activity_method( + ToolActivities.prompt_llm, + summary_input, + schedule_to_close_timeout=TOOL_ACTIVITY_TIMEOUT, + ) + workflow.logger.info(f"Continuing as new after {MAX_TURNS_BEFORE_CONTINUE} turns.") + workflow.continue_as_new( + args=[ + CombinedInput( + tool_params=ToolWorkflowParams( + conversation_summary=self.conversation_summary, + prompt_queue=self.prompt_queue, + ), + agent_goal=agent_goal, + ) + ] + ) + @workflow.run async def run(self, combined_input: CombinedInput) -> str: + """Main workflow execution method.""" params = combined_input.tool_params agent_goal = combined_input.agent_goal - tool_data = None if params and params.conversation_summary: self.add_message("conversation_summary", params.conversation_summary) @@ -44,111 +114,61 @@ class ToolWorkflow: current_tool = None while True: - # Wait until *any* signal or user prompt arrives: await workflow.wait_condition( lambda: bool(self.prompt_queue) or self.chat_ended or self.confirm ) - # 1) If chat_ended was signaled, handle end and return if self.chat_ended: - workflow.logger.info("Chat ended.") return f"{self.conversation_history}" - # 2) If we received a confirm signal: - if self.confirm and waiting_for_confirm and current_tool: - # Clear the confirm flag so we don't repeatedly confirm + if self.confirm and waiting_for_confirm and current_tool and self.tool_data: self.confirm = False waiting_for_confirm = False confirmed_tool_data = self.tool_data.copy() - confirmed_tool_data["next"] = "user_confirmed_tool_run" self.add_message("user_confirmed_tool_run", confirmed_tool_data) - # Run the tool - workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}") - dynamic_result = await workflow.execute_activity( - current_tool, - self.tool_data["args"], - schedule_to_close_timeout=timedelta(seconds=20), - ) - dynamic_result["tool"] = current_tool - self.add_message( - "tool_result", {"tool": current_tool, "result": dynamic_result} - ) - - # Enqueue a follow-up prompt for the LLM - self.prompt_queue.append( - f"### The '{current_tool}' tool completed successfully with {dynamic_result}. " - "INSTRUCTIONS: Use this tool result, the list of tools in sequence and the conversation history to figure out next steps, if any. " - "DON'T ask any clarifying questions that are outside of the tools and args specified. " - ) - # Loop around again + await self._handle_tool_execution(current_tool, self.tool_data) continue - # 3) If there's a user prompt waiting, process it (unless we're in some other skipping logic). if self.prompt_queue: prompt = self.prompt_queue.popleft() - if prompt.startswith("###"): - pass - else: + if not prompt.startswith("###"): self.add_message("user", prompt) - # Pass entire conversation + Tools to LLM context_instructions = generate_genai_prompt( agent_goal, self.conversation_history, self.tool_data ) - # tools_list = ", ".join([t.name for t in agent_goal.tools]) - prompt_input = ToolPromptInput( prompt=prompt, context_instructions=context_instructions, ) + tool_data = await workflow.execute_activity( ToolActivities.prompt_llm, prompt_input, - schedule_to_close_timeout=timedelta(seconds=60), + schedule_to_close_timeout=LLM_ACTIVITY_TIMEOUT, retry_policy=RetryPolicy( - maximum_attempts=5, initial_interval=timedelta(seconds=12) + maximum_attempts=5, + initial_interval=timedelta(seconds=12) ), ) self.tool_data = tool_data - # Check the next step from LLM - next_step = self.tool_data.get("next") - current_tool = self.tool_data.get("tool") + next_step = tool_data.get("next") + current_tool = tool_data.get("tool") if next_step == "confirm" and current_tool: - # todo make this less awkward - args = self.tool_data.get("args") - - # check each argument for null values - missing_args = [] - for key, value in args.items(): - if value is None: - next_step = "question" - missing_args.append(key) - - if missing_args: - # Enqueue a follow-up prompt for the LLM - self.prompt_queue.append( - f"### INSTRUCTIONS set next='question', combine this response response='{tool_data.get('response')}' and following missing arguments for tool {current_tool}: {missing_args}. " - "Only provide a valid JSON response without any comments or metadata." - ) - - workflow.logger.info( - f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}" - ) - # Loop around again + args = tool_data.get("args", {}) + if await self._handle_missing_args(current_tool, args, tool_data): continue waiting_for_confirm = True - self.confirm = False # Clear any stale confirm + self.confirm = False workflow.logger.info("Waiting for user confirm signal...") - # We do NOT do an immediate wait_condition here; - # instead, let the loop continue so we can still handle prompts/end_chat signals. elif next_step == "done": workflow.logger.info("All steps completed. Exiting workflow.") @@ -156,39 +176,11 @@ class ToolWorkflow: return str(self.conversation_history) self.add_message("agent", tool_data) - - # Possibly continue-as-new after many turns - # todo ensure this doesn't lose critical context - if ( - len(self.conversation_history["messages"]) - >= self.max_turns_before_continue - ): - summary_context, summary_prompt = self.prompt_summary_with_history() - summary_input = ToolPromptInput( - prompt=summary_prompt, context_instructions=summary_context - ) - self.conversation_summary = await workflow.start_activity_method( - ToolActivities.prompt_llm, - summary_input, - schedule_to_close_timeout=timedelta(seconds=20), - ) - workflow.logger.info( - f"Continuing as new after {self.max_turns_before_continue} turns." - ) - workflow.continue_as_new( - args=[ - CombinedInput( - tool_params=ToolWorkflowParams( - conversation_summary=self.conversation_summary, - prompt_queue=self.prompt_queue, - ), - agent_goal=agent_goal, - ) - ] - ) + await self._continue_as_new_if_needed(agent_goal) @workflow.signal async def user_prompt(self, prompt: str) -> None: + """Signal handler for receiving user prompts.""" if self.chat_ended: workflow.logger.warn(f"Message dropped due to chat closed: {prompt}") return @@ -196,36 +188,41 @@ class ToolWorkflow: @workflow.signal async def end_chat(self) -> None: + """Signal handler for ending the chat session.""" self.chat_ended = True @workflow.signal async def confirm(self) -> None: + """Signal handler for user confirmation of tool execution.""" self.confirm = True @workflow.query - def get_conversation_history( - self, - ) -> Dict[str, List[Dict[str, Union[str, Dict[str, Any]]]]]: - # Return the whole conversation as a dict + def get_conversation_history(self) -> ConversationHistory: + """Query handler to retrieve the full conversation history.""" return self.conversation_history @workflow.query - def get_summary_from_history(self) -> Optional[dict]: + def get_summary_from_history(self) -> Optional[str]: + """Query handler to retrieve the conversation summary if available.""" return self.conversation_summary @workflow.query - def get_tool_data(self) -> Optional[dict]: + def get_tool_data(self) -> Optional[ToolData]: + """Query handler to retrieve the current tool data if available.""" return self.tool_data - # Helper: generate text of the entire conversation so far - def format_history(self) -> str: + """Format the conversation history into a single string.""" return " ".join( str(msg["response"]) for msg in self.conversation_history["messages"] ) - # Return (context_instructions, prompt) def prompt_with_history(self, prompt: str) -> tuple[str, str]: + """Generate a context-aware prompt with conversation history. + + Returns: + tuple[str, str]: A tuple of (context_instructions, prompt) + """ history_string = self.format_history() context_instructions = ( f"Here is the conversation history: {history_string} " @@ -235,8 +232,12 @@ class ToolWorkflow: ) return (context_instructions, prompt) - # Return (context_instructions, prompt) for summarizing the conversation def prompt_summary_with_history(self) -> tuple[str, str]: + """Generate a prompt for summarizing the conversation. + + Returns: + tuple[str, str]: A tuple of (context_instructions, prompt) + """ history_string = self.format_history() context_instructions = f"Here is the conversation history between a user and a chatbot: {history_string}" actual_prompt = ( @@ -246,7 +247,12 @@ class ToolWorkflow: return (context_instructions, actual_prompt) def add_message(self, actor: str, response: Union[str, Dict[str, Any]]) -> None: - # Append a message object to the "messages" list + """Add a message to the conversation history. + + Args: + actor: The entity that generated the message (e.g., "user", "agent") + response: The message content, either as a string or structured data + """ self.conversation_history["messages"].append( {"actor": actor, "response": response} )