Cursor refactor of workflow code

This commit is contained in:
Steve Androulakis
2025-01-07 13:27:53 -08:00
parent 1b8b9c9906
commit 2f22af500f

View File

@@ -1,37 +1,107 @@
from collections import deque from collections import deque
from datetime import timedelta from datetime import timedelta
from typing import Dict, Any, Union, List, Optional, Deque from typing import Dict, Any, Union, List, Optional, Deque, TypedDict, Literal
from temporalio.common import RetryPolicy
from temporalio.common import RetryPolicy
from temporalio import workflow from temporalio import workflow
with workflow.unsafe.imports_passed_through(): with workflow.unsafe.imports_passed_through():
from activities.tool_activities import ToolActivities, ToolPromptInput from activities.tool_activities import ToolActivities, ToolPromptInput
from prompts.agent_prompt_generators import ( from prompts.agent_prompt_generators import generate_genai_prompt
generate_genai_prompt,
)
from models.data_types import CombinedInput, ToolWorkflowParams from models.data_types import CombinedInput, ToolWorkflowParams
# Constants
MAX_TURNS_BEFORE_CONTINUE = 250  # history length that triggers continue-as-new
TOOL_ACTIVITY_TIMEOUT = timedelta(seconds=20)
LLM_ACTIVITY_TIMEOUT = timedelta(seconds=60)

# Type definitions
Message = Dict[str, Union[str, Dict[str, Any]]]
ConversationHistory = Dict[str, List[Message]]
# The LLM's directive for what the workflow should do next.
NextStep = Literal["confirm", "question", "done"]


class ToolData(TypedDict, total=False):
    """Partial payload returned by the LLM describing the next action.

    All keys are optional (``total=False``): the LLM may include any subset
    of a next-step directive, a tool name, tool arguments, and free text.
    """

    next: NextStep
    tool: str
    args: Dict[str, Any]
    response: str
@workflow.defn
class ToolWorkflow:
    """Workflow that manages tool execution with user confirmation and conversation history."""

    def __init__(self) -> None:
        # Full message log; queries return this dict and run() stringifies it on exit.
        self.conversation_history: ConversationHistory = {"messages": []}
        # Pending prompts awaiting the LLM: user messages plus self-enqueued
        # "###"-prefixed system follow-ups.
        self.prompt_queue: Deque[str] = deque()
        # Summary produced just before continue-as-new; carried into the next run.
        self.conversation_summary: Optional[str] = None
        self.chat_ended: bool = False
        # Latest LLM decision payload (next step / tool / args).
        self.tool_data: Optional[ToolData] = None
        # NOTE(review): this attribute shadows the `confirm` signal method
        # defined on the class. Temporal resolves signals on the class so
        # delivery still works, but the name collision is fragile — consider
        # renaming the flag (e.g. `confirmed`) in a follow-up.
        self.confirm: bool = False
        self.tool_results: List[Dict[str, Any]] = []
async def _handle_tool_execution(self, current_tool: str, tool_data: ToolData) -> None:
    """Execute a confirmed tool activity and queue a follow-up LLM prompt.

    Args:
        current_tool: Name of the dynamic activity to execute.
        tool_data: LLM-provided payload; its "args" entry is forwarded to the
            activity (raises KeyError if absent — callers only reach here after
            the LLM proposed a tool, which includes args).
    """
    workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}")
    dynamic_result = await workflow.execute_activity(
        current_tool,
        tool_data["args"],
        schedule_to_close_timeout=TOOL_ACTIVITY_TIMEOUT,
    )
    # Tag the result so downstream consumers know which tool produced it.
    dynamic_result["tool"] = current_tool
    self.add_message("tool_result", {"tool": current_tool, "result": dynamic_result})
    # "###" prefix marks system-generated prompts; the run loop skips adding
    # those to the conversation history as user messages.
    self.prompt_queue.append(
        f"### The '{current_tool}' tool completed successfully with {dynamic_result}. "
        "INSTRUCTIONS: Use this tool result, the list of tools in sequence and the conversation history to figure out next steps, if any. "
        "DON'T ask any clarifying questions that are outside of the tools and args specified. "
    )
async def _handle_missing_args(self, current_tool: str, args: Dict[str, Any], tool_data: ToolData) -> bool:
    """Check the proposed tool args for None values and re-prompt if any are missing.

    Args:
        current_tool: Name of the tool the LLM proposed.
        args: Argument mapping the LLM supplied for the tool.
        tool_data: Full LLM payload (its "response" text is echoed back in the
            follow-up prompt).

    Returns:
        True when arguments were missing (a follow-up LLM prompt was queued and
        the caller should skip execution); False when args are complete.
    """
    missing_args = [key for key, value in args.items() if value is None]
    if not missing_args:
        return False
    # Ask the LLM to re-issue a question that gathers the missing values.
    self.prompt_queue.append(
        f"### INSTRUCTIONS set next='question', combine this response response='{tool_data.get('response')}' "
        f"and following missing arguments for tool {current_tool}: {missing_args}. "
        "Only provide a valid JSON response without any comments or metadata."
    )
    workflow.logger.info(f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}")
    return True
async def _continue_as_new_if_needed(self, agent_goal: Any) -> None:
    """Summarize and continue-as-new once history reaches MAX_TURNS_BEFORE_CONTINUE.

    Summarizes the conversation via the LLM activity, then restarts the
    workflow with that summary and the remaining prompt queue so event
    history does not grow unboundedly.

    Args:
        agent_goal: Opaque goal object forwarded unchanged into the next run.
    """
    if len(self.conversation_history["messages"]) < MAX_TURNS_BEFORE_CONTINUE:
        return
    summary_context, summary_prompt = self.prompt_summary_with_history()
    summary_input = ToolPromptInput(
        prompt=summary_prompt,
        context_instructions=summary_context,
    )
    # NOTE(review): this LLM call uses the 20s TOOL_ACTIVITY_TIMEOUT while
    # other prompt_llm calls use LLM_ACTIVITY_TIMEOUT (60s) — confirm the
    # shorter timeout is intentional for summarization.
    self.conversation_summary = await workflow.start_activity_method(
        ToolActivities.prompt_llm,
        summary_input,
        schedule_to_close_timeout=TOOL_ACTIVITY_TIMEOUT,
    )
    workflow.logger.info(f"Continuing as new after {MAX_TURNS_BEFORE_CONTINUE} turns.")
    workflow.continue_as_new(
        args=[
            CombinedInput(
                tool_params=ToolWorkflowParams(
                    conversation_summary=self.conversation_summary,
                    prompt_queue=self.prompt_queue,
                ),
                agent_goal=agent_goal,
            )
        ]
    )
@workflow.run @workflow.run
async def run(self, combined_input: CombinedInput) -> str: async def run(self, combined_input: CombinedInput) -> str:
"""Main workflow execution method."""
params = combined_input.tool_params params = combined_input.tool_params
agent_goal = combined_input.agent_goal agent_goal = combined_input.agent_goal
tool_data = None
if params and params.conversation_summary: if params and params.conversation_summary:
self.add_message("conversation_summary", params.conversation_summary) self.add_message("conversation_summary", params.conversation_summary)
@@ -44,111 +114,61 @@ class ToolWorkflow:
current_tool = None current_tool = None
while True: while True:
# Wait until *any* signal or user prompt arrives:
await workflow.wait_condition( await workflow.wait_condition(
lambda: bool(self.prompt_queue) or self.chat_ended or self.confirm lambda: bool(self.prompt_queue) or self.chat_ended or self.confirm
) )
# 1) If chat_ended was signaled, handle end and return
if self.chat_ended: if self.chat_ended:
workflow.logger.info("Chat ended.") workflow.logger.info("Chat ended.")
return f"{self.conversation_history}" return f"{self.conversation_history}"
# 2) If we received a confirm signal: if self.confirm and waiting_for_confirm and current_tool and self.tool_data:
if self.confirm and waiting_for_confirm and current_tool:
# Clear the confirm flag so we don't repeatedly confirm
self.confirm = False self.confirm = False
waiting_for_confirm = False waiting_for_confirm = False
confirmed_tool_data = self.tool_data.copy() confirmed_tool_data = self.tool_data.copy()
confirmed_tool_data["next"] = "user_confirmed_tool_run" confirmed_tool_data["next"] = "user_confirmed_tool_run"
self.add_message("user_confirmed_tool_run", confirmed_tool_data) self.add_message("user_confirmed_tool_run", confirmed_tool_data)
# Run the tool await self._handle_tool_execution(current_tool, self.tool_data)
workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}")
dynamic_result = await workflow.execute_activity(
current_tool,
self.tool_data["args"],
schedule_to_close_timeout=timedelta(seconds=20),
)
dynamic_result["tool"] = current_tool
self.add_message(
"tool_result", {"tool": current_tool, "result": dynamic_result}
)
# Enqueue a follow-up prompt for the LLM
self.prompt_queue.append(
f"### The '{current_tool}' tool completed successfully with {dynamic_result}. "
"INSTRUCTIONS: Use this tool result, the list of tools in sequence and the conversation history to figure out next steps, if any. "
"DON'T ask any clarifying questions that are outside of the tools and args specified. "
)
# Loop around again
continue continue
# 3) If there's a user prompt waiting, process it (unless we're in some other skipping logic).
if self.prompt_queue: if self.prompt_queue:
prompt = self.prompt_queue.popleft() prompt = self.prompt_queue.popleft()
if prompt.startswith("###"): if not prompt.startswith("###"):
pass
else:
self.add_message("user", prompt) self.add_message("user", prompt)
# Pass entire conversation + Tools to LLM
context_instructions = generate_genai_prompt( context_instructions = generate_genai_prompt(
agent_goal, self.conversation_history, self.tool_data agent_goal, self.conversation_history, self.tool_data
) )
# tools_list = ", ".join([t.name for t in agent_goal.tools])
prompt_input = ToolPromptInput( prompt_input = ToolPromptInput(
prompt=prompt, prompt=prompt,
context_instructions=context_instructions, context_instructions=context_instructions,
) )
tool_data = await workflow.execute_activity( tool_data = await workflow.execute_activity(
ToolActivities.prompt_llm, ToolActivities.prompt_llm,
prompt_input, prompt_input,
schedule_to_close_timeout=timedelta(seconds=60), schedule_to_close_timeout=LLM_ACTIVITY_TIMEOUT,
retry_policy=RetryPolicy( retry_policy=RetryPolicy(
maximum_attempts=5, initial_interval=timedelta(seconds=12) maximum_attempts=5,
initial_interval=timedelta(seconds=12)
), ),
) )
self.tool_data = tool_data self.tool_data = tool_data
# Check the next step from LLM next_step = tool_data.get("next")
next_step = self.tool_data.get("next") current_tool = tool_data.get("tool")
current_tool = self.tool_data.get("tool")
if next_step == "confirm" and current_tool: if next_step == "confirm" and current_tool:
# todo make this less awkward args = tool_data.get("args", {})
args = self.tool_data.get("args") if await self._handle_missing_args(current_tool, args, tool_data):
# check each argument for null values
missing_args = []
for key, value in args.items():
if value is None:
next_step = "question"
missing_args.append(key)
if missing_args:
# Enqueue a follow-up prompt for the LLM
self.prompt_queue.append(
f"### INSTRUCTIONS set next='question', combine this response response='{tool_data.get('response')}' and following missing arguments for tool {current_tool}: {missing_args}. "
"Only provide a valid JSON response without any comments or metadata."
)
workflow.logger.info(
f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}"
)
# Loop around again
continue continue
waiting_for_confirm = True waiting_for_confirm = True
self.confirm = False # Clear any stale confirm self.confirm = False
workflow.logger.info("Waiting for user confirm signal...") workflow.logger.info("Waiting for user confirm signal...")
# We do NOT do an immediate wait_condition here;
# instead, let the loop continue so we can still handle prompts/end_chat signals.
elif next_step == "done": elif next_step == "done":
workflow.logger.info("All steps completed. Exiting workflow.") workflow.logger.info("All steps completed. Exiting workflow.")
@@ -156,39 +176,11 @@ class ToolWorkflow:
return str(self.conversation_history) return str(self.conversation_history)
self.add_message("agent", tool_data) self.add_message("agent", tool_data)
await self._continue_as_new_if_needed(agent_goal)
# Possibly continue-as-new after many turns
# todo ensure this doesn't lose critical context
if (
len(self.conversation_history["messages"])
>= self.max_turns_before_continue
):
summary_context, summary_prompt = self.prompt_summary_with_history()
summary_input = ToolPromptInput(
prompt=summary_prompt, context_instructions=summary_context
)
self.conversation_summary = await workflow.start_activity_method(
ToolActivities.prompt_llm,
summary_input,
schedule_to_close_timeout=timedelta(seconds=20),
)
workflow.logger.info(
f"Continuing as new after {self.max_turns_before_continue} turns."
)
workflow.continue_as_new(
args=[
CombinedInput(
tool_params=ToolWorkflowParams(
conversation_summary=self.conversation_summary,
prompt_queue=self.prompt_queue,
),
agent_goal=agent_goal,
)
]
)
@workflow.signal @workflow.signal
async def user_prompt(self, prompt: str) -> None: async def user_prompt(self, prompt: str) -> None:
"""Signal handler for receiving user prompts."""
if self.chat_ended: if self.chat_ended:
workflow.logger.warn(f"Message dropped due to chat closed: {prompt}") workflow.logger.warn(f"Message dropped due to chat closed: {prompt}")
return return
@@ -196,36 +188,41 @@ class ToolWorkflow:
@workflow.signal
async def end_chat(self) -> None:
    """Signal handler for ending the chat session.

    Sets the flag the run loop's wait_condition watches; the workflow then
    returns the conversation history and completes.
    """
    self.chat_ended = True
@workflow.signal
async def confirm(self) -> None:
    """Signal handler for user confirmation of tool execution.

    NOTE(review): shares its name with the boolean instance attribute set in
    __init__. Temporal dispatches signals via the class, so this works, but
    the shadowing is fragile and worth renaming in a follow-up.
    """
    self.confirm = True
@workflow.query
def get_conversation_history(self) -> ConversationHistory:
    """Query handler to retrieve the full conversation history."""
    return self.conversation_history
@workflow.query
def get_summary_from_history(self) -> Optional[str]:
    """Query handler to retrieve the conversation summary if available.

    Returns None until a summary has been generated (which only happens just
    before continue-as-new).
    """
    return self.conversation_summary
@workflow.query
def get_tool_data(self) -> Optional[ToolData]:
    """Query handler to retrieve the current tool data if available.

    Returns None until the LLM has produced its first decision payload.
    """
    return self.tool_data
def format_history(self) -> str:
    """Format the conversation history into a single space-separated string.

    Only the "response" field of each message is included; the "actor" field
    is dropped. Non-string responses (dicts) are stringified via str().
    """
    return " ".join(
        str(msg["response"]) for msg in self.conversation_history["messages"]
    )
# Return (context_instructions, prompt)
def prompt_with_history(self, prompt: str) -> tuple[str, str]: def prompt_with_history(self, prompt: str) -> tuple[str, str]:
"""Generate a context-aware prompt with conversation history.
Returns:
tuple[str, str]: A tuple of (context_instructions, prompt)
"""
history_string = self.format_history() history_string = self.format_history()
context_instructions = ( context_instructions = (
f"Here is the conversation history: {history_string} " f"Here is the conversation history: {history_string} "
@@ -235,8 +232,12 @@ class ToolWorkflow:
) )
return (context_instructions, prompt) return (context_instructions, prompt)
# Return (context_instructions, prompt) for summarizing the conversation
def prompt_summary_with_history(self) -> tuple[str, str]: def prompt_summary_with_history(self) -> tuple[str, str]:
"""Generate a prompt for summarizing the conversation.
Returns:
tuple[str, str]: A tuple of (context_instructions, prompt)
"""
history_string = self.format_history() history_string = self.format_history()
context_instructions = f"Here is the conversation history between a user and a chatbot: {history_string}" context_instructions = f"Here is the conversation history between a user and a chatbot: {history_string}"
actual_prompt = ( actual_prompt = (
@@ -246,7 +247,12 @@ class ToolWorkflow:
return (context_instructions, actual_prompt) return (context_instructions, actual_prompt)
def add_message(self, actor: str, response: Union[str, Dict[str, Any]]) -> None:
    """Add a message to the conversation history.

    Args:
        actor: The entity that generated the message (e.g., "user", "agent",
            "tool_result", "conversation_summary").
        response: The message content, either a plain string or structured data.
    """
    self.conversation_history["messages"].append(
        {"actor": actor, "response": response}
    )