cursor refactor of workflow code

This commit is contained in:
Steve Androulakis
2025-01-07 13:27:53 -08:00
parent 1b8b9c9906
commit 2f22af500f

View File

@@ -1,37 +1,107 @@
from collections import deque
from datetime import timedelta
from typing import Dict, Any, Union, List, Optional, Deque
from temporalio.common import RetryPolicy
from typing import Dict, Any, Union, List, Optional, Deque, TypedDict, Literal
from temporalio.common import RetryPolicy
from temporalio import workflow
with workflow.unsafe.imports_passed_through():
from activities.tool_activities import ToolActivities, ToolPromptInput
from prompts.agent_prompt_generators import (
generate_genai_prompt,
)
from prompts.agent_prompt_generators import generate_genai_prompt
from models.data_types import CombinedInput, ToolWorkflowParams
# Constants
MAX_TURNS_BEFORE_CONTINUE = 250
TOOL_ACTIVITY_TIMEOUT = timedelta(seconds=20)
LLM_ACTIVITY_TIMEOUT = timedelta(seconds=60)
# Type definitions
Message = Dict[str, Union[str, Dict[str, Any]]]
ConversationHistory = Dict[str, List[Message]]
NextStep = Literal["confirm", "question", "done"]
class ToolData(TypedDict, total=False):
next: NextStep
tool: str
args: Dict[str, Any]
response: str
@workflow.defn
class ToolWorkflow:
"""Workflow that manages tool execution with user confirmation and conversation history."""
def __init__(self) -> None:
self.conversation_history: Dict[
str, List[Dict[str, Union[str, Dict[str, Any]]]]
] = {"messages": []}
self.conversation_history: ConversationHistory = {"messages": []}
self.prompt_queue: Deque[str] = deque()
self.conversation_summary: Optional[str] = None
self.chat_ended: bool = False
self.tool_data = None
self.max_turns_before_continue: int = 250
self.confirm = False
self.tool_data: Optional[ToolData] = None
self.confirm: bool = False
self.tool_results: List[Dict[str, Any]] = []
async def _handle_tool_execution(self, current_tool: str, tool_data: ToolData) -> None:
"""Execute a tool after confirmation and handle its result."""
workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}")
dynamic_result = await workflow.execute_activity(
current_tool,
tool_data["args"],
schedule_to_close_timeout=TOOL_ACTIVITY_TIMEOUT,
)
dynamic_result["tool"] = current_tool
self.add_message("tool_result", {"tool": current_tool, "result": dynamic_result})
self.prompt_queue.append(
f"### The '{current_tool}' tool completed successfully with {dynamic_result}. "
"INSTRUCTIONS: Use this tool result, the list of tools in sequence and the conversation history to figure out next steps, if any. "
"DON'T ask any clarifying questions that are outside of the tools and args specified. "
)
async def _handle_missing_args(self, current_tool: str, args: Dict[str, Any], tool_data: ToolData) -> bool:
"""Check for missing arguments and handle them if found."""
missing_args = [key for key, value in args.items() if value is None]
if missing_args:
self.prompt_queue.append(
f"### INSTRUCTIONS set next='question', combine this response response='{tool_data.get('response')}' "
f"and following missing arguments for tool {current_tool}: {missing_args}. "
"Only provide a valid JSON response without any comments or metadata."
)
workflow.logger.info(f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}")
return True
return False
async def _continue_as_new_if_needed(self, agent_goal: Any) -> None:
"""Handle workflow continuation if message limit is reached."""
if len(self.conversation_history["messages"]) >= MAX_TURNS_BEFORE_CONTINUE:
summary_context, summary_prompt = self.prompt_summary_with_history()
summary_input = ToolPromptInput(
prompt=summary_prompt,
context_instructions=summary_context
)
self.conversation_summary = await workflow.start_activity_method(
ToolActivities.prompt_llm,
summary_input,
schedule_to_close_timeout=TOOL_ACTIVITY_TIMEOUT,
)
workflow.logger.info(f"Continuing as new after {MAX_TURNS_BEFORE_CONTINUE} turns.")
workflow.continue_as_new(
args=[
CombinedInput(
tool_params=ToolWorkflowParams(
conversation_summary=self.conversation_summary,
prompt_queue=self.prompt_queue,
),
agent_goal=agent_goal,
)
]
)
@workflow.run
async def run(self, combined_input: CombinedInput) -> str:
"""Main workflow execution method."""
params = combined_input.tool_params
agent_goal = combined_input.agent_goal
tool_data = None
if params and params.conversation_summary:
self.add_message("conversation_summary", params.conversation_summary)
@@ -44,111 +114,61 @@ class ToolWorkflow:
current_tool = None
while True:
# Wait until *any* signal or user prompt arrives:
await workflow.wait_condition(
lambda: bool(self.prompt_queue) or self.chat_ended or self.confirm
)
# 1) If chat_ended was signaled, handle end and return
if self.chat_ended:
workflow.logger.info("Chat ended.")
return f"{self.conversation_history}"
# 2) If we received a confirm signal:
if self.confirm and waiting_for_confirm and current_tool:
# Clear the confirm flag so we don't repeatedly confirm
if self.confirm and waiting_for_confirm and current_tool and self.tool_data:
self.confirm = False
waiting_for_confirm = False
confirmed_tool_data = self.tool_data.copy()
confirmed_tool_data["next"] = "user_confirmed_tool_run"
self.add_message("user_confirmed_tool_run", confirmed_tool_data)
# Run the tool
workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}")
dynamic_result = await workflow.execute_activity(
current_tool,
self.tool_data["args"],
schedule_to_close_timeout=timedelta(seconds=20),
)
dynamic_result["tool"] = current_tool
self.add_message(
"tool_result", {"tool": current_tool, "result": dynamic_result}
)
# Enqueue a follow-up prompt for the LLM
self.prompt_queue.append(
f"### The '{current_tool}' tool completed successfully with {dynamic_result}. "
"INSTRUCTIONS: Use this tool result, the list of tools in sequence and the conversation history to figure out next steps, if any. "
"DON'T ask any clarifying questions that are outside of the tools and args specified. "
)
# Loop around again
await self._handle_tool_execution(current_tool, self.tool_data)
continue
# 3) If there's a user prompt waiting, process it (unless we're in some other skipping logic).
if self.prompt_queue:
prompt = self.prompt_queue.popleft()
if prompt.startswith("###"):
pass
else:
if not prompt.startswith("###"):
self.add_message("user", prompt)
# Pass entire conversation + Tools to LLM
context_instructions = generate_genai_prompt(
agent_goal, self.conversation_history, self.tool_data
)
# tools_list = ", ".join([t.name for t in agent_goal.tools])
prompt_input = ToolPromptInput(
prompt=prompt,
context_instructions=context_instructions,
)
tool_data = await workflow.execute_activity(
ToolActivities.prompt_llm,
prompt_input,
schedule_to_close_timeout=timedelta(seconds=60),
schedule_to_close_timeout=LLM_ACTIVITY_TIMEOUT,
retry_policy=RetryPolicy(
maximum_attempts=5, initial_interval=timedelta(seconds=12)
maximum_attempts=5,
initial_interval=timedelta(seconds=12)
),
)
self.tool_data = tool_data
# Check the next step from LLM
next_step = self.tool_data.get("next")
current_tool = self.tool_data.get("tool")
next_step = tool_data.get("next")
current_tool = tool_data.get("tool")
if next_step == "confirm" and current_tool:
# todo make this less awkward
args = self.tool_data.get("args")
# check each argument for null values
missing_args = []
for key, value in args.items():
if value is None:
next_step = "question"
missing_args.append(key)
if missing_args:
# Enqueue a follow-up prompt for the LLM
self.prompt_queue.append(
f"### INSTRUCTIONS set next='question', combine this response response='{tool_data.get('response')}' and following missing arguments for tool {current_tool}: {missing_args}. "
"Only provide a valid JSON response without any comments or metadata."
)
workflow.logger.info(
f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}"
)
# Loop around again
args = tool_data.get("args", {})
if await self._handle_missing_args(current_tool, args, tool_data):
continue
waiting_for_confirm = True
self.confirm = False # Clear any stale confirm
self.confirm = False
workflow.logger.info("Waiting for user confirm signal...")
# We do NOT do an immediate wait_condition here;
# instead, let the loop continue so we can still handle prompts/end_chat signals.
elif next_step == "done":
workflow.logger.info("All steps completed. Exiting workflow.")
@@ -156,39 +176,11 @@ class ToolWorkflow:
return str(self.conversation_history)
self.add_message("agent", tool_data)
# Possibly continue-as-new after many turns
# todo ensure this doesn't lose critical context
if (
len(self.conversation_history["messages"])
>= self.max_turns_before_continue
):
summary_context, summary_prompt = self.prompt_summary_with_history()
summary_input = ToolPromptInput(
prompt=summary_prompt, context_instructions=summary_context
)
self.conversation_summary = await workflow.start_activity_method(
ToolActivities.prompt_llm,
summary_input,
schedule_to_close_timeout=timedelta(seconds=20),
)
workflow.logger.info(
f"Continuing as new after {self.max_turns_before_continue} turns."
)
workflow.continue_as_new(
args=[
CombinedInput(
tool_params=ToolWorkflowParams(
conversation_summary=self.conversation_summary,
prompt_queue=self.prompt_queue,
),
agent_goal=agent_goal,
)
]
)
await self._continue_as_new_if_needed(agent_goal)
@workflow.signal
async def user_prompt(self, prompt: str) -> None:
"""Signal handler for receiving user prompts."""
if self.chat_ended:
workflow.logger.warn(f"Message dropped due to chat closed: {prompt}")
return
@@ -196,36 +188,41 @@ class ToolWorkflow:
@workflow.signal
async def end_chat(self) -> None:
"""Signal handler for ending the chat session."""
self.chat_ended = True
@workflow.signal
async def confirm(self) -> None:
"""Signal handler for user confirmation of tool execution."""
self.confirm = True
@workflow.query
def get_conversation_history(
self,
) -> Dict[str, List[Dict[str, Union[str, Dict[str, Any]]]]]:
# Return the whole conversation as a dict
def get_conversation_history(self) -> ConversationHistory:
"""Query handler to retrieve the full conversation history."""
return self.conversation_history
@workflow.query
def get_summary_from_history(self) -> Optional[dict]:
def get_summary_from_history(self) -> Optional[str]:
"""Query handler to retrieve the conversation summary if available."""
return self.conversation_summary
@workflow.query
def get_tool_data(self) -> Optional[dict]:
def get_tool_data(self) -> Optional[ToolData]:
"""Query handler to retrieve the current tool data if available."""
return self.tool_data
# Helper: generate text of the entire conversation so far
def format_history(self) -> str:
"""Format the conversation history into a single string."""
return " ".join(
str(msg["response"]) for msg in self.conversation_history["messages"]
)
# Return (context_instructions, prompt)
def prompt_with_history(self, prompt: str) -> tuple[str, str]:
"""Generate a context-aware prompt with conversation history.
Returns:
tuple[str, str]: A tuple of (context_instructions, prompt)
"""
history_string = self.format_history()
context_instructions = (
f"Here is the conversation history: {history_string} "
@@ -235,8 +232,12 @@ class ToolWorkflow:
)
return (context_instructions, prompt)
# Return (context_instructions, prompt) for summarizing the conversation
def prompt_summary_with_history(self) -> tuple[str, str]:
"""Generate a prompt for summarizing the conversation.
Returns:
tuple[str, str]: A tuple of (context_instructions, prompt)
"""
history_string = self.format_history()
context_instructions = f"Here is the conversation history between a user and a chatbot: {history_string}"
actual_prompt = (
@@ -246,7 +247,12 @@ class ToolWorkflow:
return (context_instructions, actual_prompt)
def add_message(self, actor: str, response: Union[str, Dict[str, Any]]) -> None:
# Append a message object to the "messages" list
"""Add a message to the conversation history.
Args:
actor: The entity that generated the message (e.g., "user", "agent")
response: The message content, either as a string or structured data
"""
self.conversation_history["messages"].append(
{"actor": actor, "response": response}
)