refactor workflow file for clarity

This commit is contained in:
Steve Androulakis
2025-02-16 07:45:44 -08:00
parent 355203c8fd
commit d996927855
2 changed files with 158 additions and 136 deletions

View File

@@ -4,32 +4,24 @@ from typing import Dict, Any, Union, List, Optional, Deque, TypedDict
from temporalio.common import RetryPolicy from temporalio.common import RetryPolicy
from temporalio import workflow from temporalio import workflow
from temporalio.exceptions import ActivityError
from models.data_types import ConversationHistory, NextStep, ValidationInput from models.data_types import ConversationHistory, NextStep, ValidationInput
from workflows.workflow_helpers import LLM_ACTIVITY_START_TO_CLOSE_TIMEOUT, \
LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT
from workflows import workflow_helpers as helpers
with workflow.unsafe.imports_passed_through(): with workflow.unsafe.imports_passed_through():
from activities.tool_activities import ToolActivities from activities.tool_activities import ToolActivities
from prompts.agent_prompt_generators import ( from prompts.agent_prompt_generators import (
generate_genai_prompt, generate_genai_prompt
generate_tool_completion_prompt,
generate_missing_args_prompt,
) )
from models.data_types import ( from models.data_types import (
CombinedInput, CombinedInput,
AgentGoalWorkflowParams,
ToolPromptInput, ToolPromptInput,
) )
from shared.config import TEMPORAL_LEGACY_TASK_QUEUE
# Constants # Constants
MAX_TURNS_BEFORE_CONTINUE = 250 MAX_TURNS_BEFORE_CONTINUE = 250
TOOL_ACTIVITY_START_TO_CLOSE_TIMEOUT = timedelta(seconds=10)
TOOL_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT = timedelta(minutes=30)
LLM_ACTIVITY_START_TO_CLOSE_TIMEOUT = timedelta(seconds=10)
LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT = timedelta(minutes=30)
class ToolData(TypedDict, total=False): class ToolData(TypedDict, total=False):
next: NextStep next: NextStep
@@ -37,7 +29,6 @@ class ToolData(TypedDict, total=False):
args: Dict[str, Any] args: Dict[str, Any]
response: str response: str
@workflow.defn @workflow.defn
class AgentGoalWorkflow: class AgentGoalWorkflow:
"""Workflow that manages tool execution with user confirmation and conversation history.""" """Workflow that manages tool execution with user confirmation and conversation history."""
@@ -51,82 +42,6 @@ class AgentGoalWorkflow:
self.confirm: bool = False self.confirm: bool = False
self.tool_results: List[Dict[str, Any]] = [] self.tool_results: List[Dict[str, Any]] = []
async def _handle_tool_execution(
self, current_tool: str, tool_data: ToolData
) -> None:
"""Execute a tool after confirmation and handle its result."""
workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}")
task_queue = (
TEMPORAL_LEGACY_TASK_QUEUE
if current_tool in ["SearchTrains", "BookTrains"]
else None
)
try:
dynamic_result = await workflow.execute_activity(
current_tool,
tool_data["args"],
task_queue=task_queue,
schedule_to_close_timeout=TOOL_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT,
start_to_close_timeout=TOOL_ACTIVITY_START_TO_CLOSE_TIMEOUT,
retry_policy=RetryPolicy(
initial_interval=timedelta(seconds=5), backoff_coefficient=1
),
)
dynamic_result["tool"] = current_tool
self.tool_results.append(dynamic_result)
except ActivityError as e:
workflow.logger.error(f"Tool execution failed: {str(e)}")
dynamic_result = {"error": str(e), "tool": current_tool}
self.add_message("tool_result", dynamic_result)
self.prompt_queue.append(generate_tool_completion_prompt(current_tool, dynamic_result))
async def _handle_missing_args(
self, current_tool: str, args: Dict[str, Any], tool_data: ToolData
) -> bool:
"""Check for missing arguments and handle them if found."""
missing_args = [key for key, value in args.items() if value is None]
if missing_args:
self.prompt_queue.append(
generate_missing_args_prompt(current_tool, tool_data, missing_args)
)
workflow.logger.info(
f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}"
)
return True
return False
async def _continue_as_new_if_needed(self, agent_goal: Any) -> None:
"""Handle workflow continuation if message limit is reached."""
if len(self.conversation_history["messages"]) >= MAX_TURNS_BEFORE_CONTINUE:
summary_context, summary_prompt = self.prompt_summary_with_history()
summary_input = ToolPromptInput(
prompt=summary_prompt, context_instructions=summary_context
)
self.conversation_summary = await workflow.start_activity_method(
ToolActivities.agent_toolPlanner,
summary_input,
schedule_to_close_timeout=LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT,
)
workflow.logger.info(
f"Continuing as new after {MAX_TURNS_BEFORE_CONTINUE} turns."
)
workflow.continue_as_new(
args=[
CombinedInput(
tool_params=AgentGoalWorkflowParams(
conversation_summary=self.conversation_summary,
prompt_queue=self.prompt_queue,
),
agent_goal=agent_goal,
)
]
)
@workflow.run @workflow.run
async def run(self, combined_input: CombinedInput) -> str: async def run(self, combined_input: CombinedInput) -> str:
"""Main workflow execution method.""" """Main workflow execution method."""
@@ -160,7 +75,13 @@ class AgentGoalWorkflow:
confirmed_tool_data["next"] = "user_confirmed_tool_run" confirmed_tool_data["next"] = "user_confirmed_tool_run"
self.add_message("user_confirmed_tool_run", confirmed_tool_data) self.add_message("user_confirmed_tool_run", confirmed_tool_data)
await self._handle_tool_execution(current_tool, self.tool_data) await helpers.handle_tool_execution(
current_tool,
self.tool_data,
self.tool_results,
self.add_message,
self.prompt_queue
)
continue continue
if self.prompt_queue: if self.prompt_queue:
@@ -194,7 +115,6 @@ class AgentGoalWorkflow:
continue continue
# Proceed with generating the context and prompt # Proceed with generating the context and prompt
context_instructions = generate_genai_prompt( context_instructions = generate_genai_prompt(
agent_goal, self.conversation_history, self.tool_data agent_goal, self.conversation_history, self.tool_data
) )
@@ -220,7 +140,7 @@ class AgentGoalWorkflow:
if next_step == "confirm" and current_tool: if next_step == "confirm" and current_tool:
args = tool_data.get("args", {}) args = tool_data.get("args", {})
if await self._handle_missing_args(current_tool, args, tool_data): if await helpers.handle_missing_args(current_tool, args, tool_data, self.prompt_queue):
continue continue
waiting_for_confirm = True waiting_for_confirm = True
@@ -233,7 +153,13 @@ class AgentGoalWorkflow:
return str(self.conversation_history) return str(self.conversation_history)
self.add_message("agent", tool_data) self.add_message("agent", tool_data)
await self._continue_as_new_if_needed(agent_goal) await helpers.continue_as_new_if_needed(
self.conversation_history,
self.prompt_queue,
agent_goal,
MAX_TURNS_BEFORE_CONTINUE,
self.add_message
)
@workflow.signal @workflow.signal
async def user_prompt(self, prompt: str) -> None: async def user_prompt(self, prompt: str) -> None:
@@ -243,17 +169,17 @@ class AgentGoalWorkflow:
return return
self.prompt_queue.append(prompt) self.prompt_queue.append(prompt)
@workflow.signal
async def end_chat(self) -> None:
"""Signal handler for ending the chat session."""
self.chat_ended = True
@workflow.signal @workflow.signal
async def confirm(self) -> None: async def confirm(self) -> None:
"""Signal handler for user confirmation of tool execution.""" """Signal handler for user confirmation of tool execution."""
workflow.logger.info("Received user confirmation") workflow.logger.info("Received user confirmation")
self.confirm = True self.confirm = True
@workflow.signal
async def end_chat(self) -> None:
"""Signal handler for ending the chat session."""
self.chat_ended = True
@workflow.query @workflow.query
def get_conversation_history(self) -> ConversationHistory: def get_conversation_history(self) -> ConversationHistory:
"""Query handler to retrieve the full conversation history.""" """Query handler to retrieve the full conversation history."""
@@ -261,49 +187,15 @@ class AgentGoalWorkflow:
@workflow.query @workflow.query
def get_summary_from_history(self) -> Optional[str]: def get_summary_from_history(self) -> Optional[str]:
"""Query handler to retrieve the conversation summary if available.""" """Query handler to retrieve the conversation summary if available.
Used only for continue as new of the workflow."""
return self.conversation_summary return self.conversation_summary
@workflow.query @workflow.query
def get_tool_data(self) -> Optional[ToolData]: def get_latest_tool_data(self) -> Optional[ToolData]:
"""Query handler to retrieve the current tool data if available.""" """Query handler to retrieve the latest tool data response if available."""
return self.tool_data return self.tool_data
def format_history(self) -> str:
"""Format the conversation history into a single string."""
return " ".join(
str(msg["response"]) for msg in self.conversation_history["messages"]
)
def prompt_with_history(self, prompt: str) -> tuple[str, str]:
"""Generate a context-aware prompt with conversation history.
Returns:
tuple[str, str]: A tuple of (context_instructions, prompt)
"""
history_string = self.format_history()
context_instructions = (
f"Here is the conversation history: {history_string} "
"Please add a few sentence response in plain text sentences. "
"Don't editorialize or add metadata. "
"Keep the text a plain explanation based on the history."
)
return (context_instructions, prompt)
def prompt_summary_with_history(self) -> tuple[str, str]:
"""Generate a prompt for summarizing the conversation.
Returns:
tuple[str, str]: A tuple of (context_instructions, prompt)
"""
history_string = self.format_history()
context_instructions = f"Here is the conversation history between a user and a chatbot: {history_string}"
actual_prompt = (
"Please produce a two sentence summary of this conversation. "
'Put the summary in the format { "summary": "<plain text>" }'
)
return (context_instructions, actual_prompt)
def add_message(self, actor: str, response: Union[str, Dict[str, Any]]) -> None: def add_message(self, actor: str, response: Union[str, Dict[str, Any]]) -> None:
"""Add a message to the conversation history. """Add a message to the conversation history.

View File

@@ -0,0 +1,130 @@
from datetime import timedelta
from typing import Dict, Any, Deque
from temporalio import workflow
from temporalio.exceptions import ActivityError
from temporalio.common import RetryPolicy
from models.data_types import ConversationHistory, ToolPromptInput
from prompts.agent_prompt_generators import generate_missing_args_prompt, generate_tool_completion_prompt
from shared.config import TEMPORAL_LEGACY_TASK_QUEUE
# Constants from original file
TOOL_ACTIVITY_START_TO_CLOSE_TIMEOUT = timedelta(seconds=10)
TOOL_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT = timedelta(minutes=30)
LLM_ACTIVITY_START_TO_CLOSE_TIMEOUT = timedelta(seconds=10)
LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT = timedelta(minutes=30)
async def handle_tool_execution(
current_tool: str,
tool_data: Dict[str, Any],
tool_results: list,
add_message_callback: callable,
prompt_queue: Deque[str]
) -> None:
"""Execute a tool after confirmation and handle its result."""
workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}")
task_queue = (
TEMPORAL_LEGACY_TASK_QUEUE
if current_tool in ["SearchTrains", "BookTrains"]
else None
)
try:
dynamic_result = await workflow.execute_activity(
current_tool,
tool_data["args"],
task_queue=task_queue,
schedule_to_close_timeout=TOOL_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT,
start_to_close_timeout=TOOL_ACTIVITY_START_TO_CLOSE_TIMEOUT,
retry_policy=RetryPolicy(
initial_interval=timedelta(seconds=5), backoff_coefficient=1
),
)
dynamic_result["tool"] = current_tool
tool_results.append(dynamic_result)
except ActivityError as e:
workflow.logger.error(f"Tool execution failed: {str(e)}")
dynamic_result = {"error": str(e), "tool": current_tool}
add_message_callback("tool_result", dynamic_result)
prompt_queue.append(generate_tool_completion_prompt(current_tool, dynamic_result))
async def handle_missing_args(
current_tool: str,
args: Dict[str, Any],
tool_data: Dict[str, Any],
prompt_queue: Deque[str]
) -> bool:
"""Check for missing arguments and handle them if found."""
missing_args = [key for key, value in args.items() if value is None]
if missing_args:
prompt_queue.append(
generate_missing_args_prompt(current_tool, tool_data, missing_args)
)
workflow.logger.info(
f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}"
)
return True
return False
def format_history(conversation_history: ConversationHistory) -> str:
"""Format the conversation history into a single string."""
return " ".join(
str(msg["response"]) for msg in conversation_history["messages"]
)
def prompt_with_history(conversation_history: ConversationHistory, prompt: str) -> tuple[str, str]:
"""Generate a context-aware prompt with conversation history."""
history_string = format_history(conversation_history)
context_instructions = (
f"Here is the conversation history: {history_string} "
"Please add a few sentence response in plain text sentences. "
"Don't editorialize or add metadata. "
"Keep the text a plain explanation based on the history."
)
return (context_instructions, prompt)
async def continue_as_new_if_needed(
conversation_history: ConversationHistory,
prompt_queue: Deque[str],
agent_goal: Any,
max_turns: int,
add_message_callback: callable
) -> None:
"""Handle workflow continuation if message limit is reached."""
if len(conversation_history["messages"]) >= max_turns:
summary_context, summary_prompt = prompt_summary_with_history(conversation_history)
summary_input = ToolPromptInput(
prompt=summary_prompt, context_instructions=summary_context
)
conversation_summary = await workflow.start_activity_method(
"ToolActivities.agent_toolPlanner",
summary_input,
schedule_to_close_timeout=LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT,
)
workflow.logger.info(
f"Continuing as new after {max_turns} turns."
)
add_message_callback("conversation_summary", conversation_summary)
workflow.continue_as_new(
args=[{
"tool_params": {
"conversation_summary": conversation_summary,
"prompt_queue": prompt_queue,
},
"agent_goal": agent_goal,
}]
)
def prompt_summary_with_history(conversation_history: ConversationHistory) -> tuple[str, str]:
"""Generate a prompt for summarizing the conversation.
Used only for continue as new of the workflow."""
history_string = format_history(conversation_history)
context_instructions = f"Here is the conversation history between a user and a chatbot: {history_string}"
actual_prompt = (
"Please produce a two sentence summary of this conversation. "
'Put the summary in the format { "summary": "<plain text>" }'
)
return (context_instructions, actual_prompt)