mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-16 22:48:09 +01:00
cursor refactor of workflow code
This commit is contained in:
@@ -1,37 +1,107 @@
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Dict, Any, Union, List, Optional, Deque
|
from typing import Dict, Any, Union, List, Optional, Deque, TypedDict, Literal
|
||||||
from temporalio.common import RetryPolicy
|
|
||||||
|
|
||||||
|
from temporalio.common import RetryPolicy
|
||||||
from temporalio import workflow
|
from temporalio import workflow
|
||||||
|
|
||||||
with workflow.unsafe.imports_passed_through():
|
with workflow.unsafe.imports_passed_through():
|
||||||
from activities.tool_activities import ToolActivities, ToolPromptInput
|
from activities.tool_activities import ToolActivities, ToolPromptInput
|
||||||
from prompts.agent_prompt_generators import (
|
from prompts.agent_prompt_generators import generate_genai_prompt
|
||||||
generate_genai_prompt,
|
|
||||||
)
|
|
||||||
from models.data_types import CombinedInput, ToolWorkflowParams
|
from models.data_types import CombinedInput, ToolWorkflowParams
|
||||||
|
|
||||||
|
# Constants
|
||||||
|
MAX_TURNS_BEFORE_CONTINUE = 250
|
||||||
|
TOOL_ACTIVITY_TIMEOUT = timedelta(seconds=20)
|
||||||
|
LLM_ACTIVITY_TIMEOUT = timedelta(seconds=60)
|
||||||
|
|
||||||
|
# Type definitions
|
||||||
|
Message = Dict[str, Union[str, Dict[str, Any]]]
|
||||||
|
ConversationHistory = Dict[str, List[Message]]
|
||||||
|
NextStep = Literal["confirm", "question", "done"]
|
||||||
|
|
||||||
|
class ToolData(TypedDict, total=False):
|
||||||
|
next: NextStep
|
||||||
|
tool: str
|
||||||
|
args: Dict[str, Any]
|
||||||
|
response: str
|
||||||
|
|
||||||
@workflow.defn
|
@workflow.defn
|
||||||
class ToolWorkflow:
|
class ToolWorkflow:
|
||||||
|
"""Workflow that manages tool execution with user confirmation and conversation history."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.conversation_history: Dict[
|
self.conversation_history: ConversationHistory = {"messages": []}
|
||||||
str, List[Dict[str, Union[str, Dict[str, Any]]]]
|
|
||||||
] = {"messages": []}
|
|
||||||
self.prompt_queue: Deque[str] = deque()
|
self.prompt_queue: Deque[str] = deque()
|
||||||
self.conversation_summary: Optional[str] = None
|
self.conversation_summary: Optional[str] = None
|
||||||
self.chat_ended: bool = False
|
self.chat_ended: bool = False
|
||||||
self.tool_data = None
|
self.tool_data: Optional[ToolData] = None
|
||||||
self.max_turns_before_continue: int = 250
|
self.confirm: bool = False
|
||||||
self.confirm = False
|
|
||||||
self.tool_results: List[Dict[str, Any]] = []
|
self.tool_results: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
async def _handle_tool_execution(self, current_tool: str, tool_data: ToolData) -> None:
|
||||||
|
"""Execute a tool after confirmation and handle its result."""
|
||||||
|
workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}")
|
||||||
|
|
||||||
|
dynamic_result = await workflow.execute_activity(
|
||||||
|
current_tool,
|
||||||
|
tool_data["args"],
|
||||||
|
schedule_to_close_timeout=TOOL_ACTIVITY_TIMEOUT,
|
||||||
|
)
|
||||||
|
dynamic_result["tool"] = current_tool
|
||||||
|
self.add_message("tool_result", {"tool": current_tool, "result": dynamic_result})
|
||||||
|
|
||||||
|
self.prompt_queue.append(
|
||||||
|
f"### The '{current_tool}' tool completed successfully with {dynamic_result}. "
|
||||||
|
"INSTRUCTIONS: Use this tool result, the list of tools in sequence and the conversation history to figure out next steps, if any. "
|
||||||
|
"DON'T ask any clarifying questions that are outside of the tools and args specified. "
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _handle_missing_args(self, current_tool: str, args: Dict[str, Any], tool_data: ToolData) -> bool:
|
||||||
|
"""Check for missing arguments and handle them if found."""
|
||||||
|
missing_args = [key for key, value in args.items() if value is None]
|
||||||
|
|
||||||
|
if missing_args:
|
||||||
|
self.prompt_queue.append(
|
||||||
|
f"### INSTRUCTIONS set next='question', combine this response response='{tool_data.get('response')}' "
|
||||||
|
f"and following missing arguments for tool {current_tool}: {missing_args}. "
|
||||||
|
"Only provide a valid JSON response without any comments or metadata."
|
||||||
|
)
|
||||||
|
workflow.logger.info(f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}")
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _continue_as_new_if_needed(self, agent_goal: Any) -> None:
|
||||||
|
"""Handle workflow continuation if message limit is reached."""
|
||||||
|
if len(self.conversation_history["messages"]) >= MAX_TURNS_BEFORE_CONTINUE:
|
||||||
|
summary_context, summary_prompt = self.prompt_summary_with_history()
|
||||||
|
summary_input = ToolPromptInput(
|
||||||
|
prompt=summary_prompt,
|
||||||
|
context_instructions=summary_context
|
||||||
|
)
|
||||||
|
self.conversation_summary = await workflow.start_activity_method(
|
||||||
|
ToolActivities.prompt_llm,
|
||||||
|
summary_input,
|
||||||
|
schedule_to_close_timeout=TOOL_ACTIVITY_TIMEOUT,
|
||||||
|
)
|
||||||
|
workflow.logger.info(f"Continuing as new after {MAX_TURNS_BEFORE_CONTINUE} turns.")
|
||||||
|
workflow.continue_as_new(
|
||||||
|
args=[
|
||||||
|
CombinedInput(
|
||||||
|
tool_params=ToolWorkflowParams(
|
||||||
|
conversation_summary=self.conversation_summary,
|
||||||
|
prompt_queue=self.prompt_queue,
|
||||||
|
),
|
||||||
|
agent_goal=agent_goal,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
@workflow.run
|
@workflow.run
|
||||||
async def run(self, combined_input: CombinedInput) -> str:
|
async def run(self, combined_input: CombinedInput) -> str:
|
||||||
|
"""Main workflow execution method."""
|
||||||
params = combined_input.tool_params
|
params = combined_input.tool_params
|
||||||
agent_goal = combined_input.agent_goal
|
agent_goal = combined_input.agent_goal
|
||||||
tool_data = None
|
|
||||||
|
|
||||||
if params and params.conversation_summary:
|
if params and params.conversation_summary:
|
||||||
self.add_message("conversation_summary", params.conversation_summary)
|
self.add_message("conversation_summary", params.conversation_summary)
|
||||||
@@ -44,111 +114,61 @@ class ToolWorkflow:
|
|||||||
current_tool = None
|
current_tool = None
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
# Wait until *any* signal or user prompt arrives:
|
|
||||||
await workflow.wait_condition(
|
await workflow.wait_condition(
|
||||||
lambda: bool(self.prompt_queue) or self.chat_ended or self.confirm
|
lambda: bool(self.prompt_queue) or self.chat_ended or self.confirm
|
||||||
)
|
)
|
||||||
|
|
||||||
# 1) If chat_ended was signaled, handle end and return
|
|
||||||
if self.chat_ended:
|
if self.chat_ended:
|
||||||
|
|
||||||
workflow.logger.info("Chat ended.")
|
workflow.logger.info("Chat ended.")
|
||||||
return f"{self.conversation_history}"
|
return f"{self.conversation_history}"
|
||||||
|
|
||||||
# 2) If we received a confirm signal:
|
if self.confirm and waiting_for_confirm and current_tool and self.tool_data:
|
||||||
if self.confirm and waiting_for_confirm and current_tool:
|
|
||||||
# Clear the confirm flag so we don't repeatedly confirm
|
|
||||||
self.confirm = False
|
self.confirm = False
|
||||||
waiting_for_confirm = False
|
waiting_for_confirm = False
|
||||||
|
|
||||||
confirmed_tool_data = self.tool_data.copy()
|
confirmed_tool_data = self.tool_data.copy()
|
||||||
|
|
||||||
confirmed_tool_data["next"] = "user_confirmed_tool_run"
|
confirmed_tool_data["next"] = "user_confirmed_tool_run"
|
||||||
self.add_message("user_confirmed_tool_run", confirmed_tool_data)
|
self.add_message("user_confirmed_tool_run", confirmed_tool_data)
|
||||||
|
|
||||||
# Run the tool
|
await self._handle_tool_execution(current_tool, self.tool_data)
|
||||||
workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}")
|
|
||||||
dynamic_result = await workflow.execute_activity(
|
|
||||||
current_tool,
|
|
||||||
self.tool_data["args"],
|
|
||||||
schedule_to_close_timeout=timedelta(seconds=20),
|
|
||||||
)
|
|
||||||
dynamic_result["tool"] = current_tool
|
|
||||||
self.add_message(
|
|
||||||
"tool_result", {"tool": current_tool, "result": dynamic_result}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Enqueue a follow-up prompt for the LLM
|
|
||||||
self.prompt_queue.append(
|
|
||||||
f"### The '{current_tool}' tool completed successfully with {dynamic_result}. "
|
|
||||||
"INSTRUCTIONS: Use this tool result, the list of tools in sequence and the conversation history to figure out next steps, if any. "
|
|
||||||
"DON'T ask any clarifying questions that are outside of the tools and args specified. "
|
|
||||||
)
|
|
||||||
# Loop around again
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 3) If there's a user prompt waiting, process it (unless we're in some other skipping logic).
|
|
||||||
if self.prompt_queue:
|
if self.prompt_queue:
|
||||||
prompt = self.prompt_queue.popleft()
|
prompt = self.prompt_queue.popleft()
|
||||||
if prompt.startswith("###"):
|
if not prompt.startswith("###"):
|
||||||
pass
|
|
||||||
else:
|
|
||||||
self.add_message("user", prompt)
|
self.add_message("user", prompt)
|
||||||
|
|
||||||
# Pass entire conversation + Tools to LLM
|
|
||||||
context_instructions = generate_genai_prompt(
|
context_instructions = generate_genai_prompt(
|
||||||
agent_goal, self.conversation_history, self.tool_data
|
agent_goal, self.conversation_history, self.tool_data
|
||||||
)
|
)
|
||||||
|
|
||||||
# tools_list = ", ".join([t.name for t in agent_goal.tools])
|
|
||||||
|
|
||||||
prompt_input = ToolPromptInput(
|
prompt_input = ToolPromptInput(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
context_instructions=context_instructions,
|
context_instructions=context_instructions,
|
||||||
)
|
)
|
||||||
|
|
||||||
tool_data = await workflow.execute_activity(
|
tool_data = await workflow.execute_activity(
|
||||||
ToolActivities.prompt_llm,
|
ToolActivities.prompt_llm,
|
||||||
prompt_input,
|
prompt_input,
|
||||||
schedule_to_close_timeout=timedelta(seconds=60),
|
schedule_to_close_timeout=LLM_ACTIVITY_TIMEOUT,
|
||||||
retry_policy=RetryPolicy(
|
retry_policy=RetryPolicy(
|
||||||
maximum_attempts=5, initial_interval=timedelta(seconds=12)
|
maximum_attempts=5,
|
||||||
|
initial_interval=timedelta(seconds=12)
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.tool_data = tool_data
|
self.tool_data = tool_data
|
||||||
|
|
||||||
# Check the next step from LLM
|
next_step = tool_data.get("next")
|
||||||
next_step = self.tool_data.get("next")
|
current_tool = tool_data.get("tool")
|
||||||
current_tool = self.tool_data.get("tool")
|
|
||||||
|
|
||||||
if next_step == "confirm" and current_tool:
|
if next_step == "confirm" and current_tool:
|
||||||
# todo make this less awkward
|
args = tool_data.get("args", {})
|
||||||
args = self.tool_data.get("args")
|
if await self._handle_missing_args(current_tool, args, tool_data):
|
||||||
|
|
||||||
# check each argument for null values
|
|
||||||
missing_args = []
|
|
||||||
for key, value in args.items():
|
|
||||||
if value is None:
|
|
||||||
next_step = "question"
|
|
||||||
missing_args.append(key)
|
|
||||||
|
|
||||||
if missing_args:
|
|
||||||
# Enqueue a follow-up prompt for the LLM
|
|
||||||
self.prompt_queue.append(
|
|
||||||
f"### INSTRUCTIONS set next='question', combine this response response='{tool_data.get('response')}' and following missing arguments for tool {current_tool}: {missing_args}. "
|
|
||||||
"Only provide a valid JSON response without any comments or metadata."
|
|
||||||
)
|
|
||||||
|
|
||||||
workflow.logger.info(
|
|
||||||
f"Missing arguments for tool: {current_tool}: {' '.join(missing_args)}"
|
|
||||||
)
|
|
||||||
# Loop around again
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
waiting_for_confirm = True
|
waiting_for_confirm = True
|
||||||
self.confirm = False # Clear any stale confirm
|
self.confirm = False
|
||||||
workflow.logger.info("Waiting for user confirm signal...")
|
workflow.logger.info("Waiting for user confirm signal...")
|
||||||
# We do NOT do an immediate wait_condition here;
|
|
||||||
# instead, let the loop continue so we can still handle prompts/end_chat signals.
|
|
||||||
|
|
||||||
elif next_step == "done":
|
elif next_step == "done":
|
||||||
workflow.logger.info("All steps completed. Exiting workflow.")
|
workflow.logger.info("All steps completed. Exiting workflow.")
|
||||||
@@ -156,39 +176,11 @@ class ToolWorkflow:
|
|||||||
return str(self.conversation_history)
|
return str(self.conversation_history)
|
||||||
|
|
||||||
self.add_message("agent", tool_data)
|
self.add_message("agent", tool_data)
|
||||||
|
await self._continue_as_new_if_needed(agent_goal)
|
||||||
# Possibly continue-as-new after many turns
|
|
||||||
# todo ensure this doesn't lose critical context
|
|
||||||
if (
|
|
||||||
len(self.conversation_history["messages"])
|
|
||||||
>= self.max_turns_before_continue
|
|
||||||
):
|
|
||||||
summary_context, summary_prompt = self.prompt_summary_with_history()
|
|
||||||
summary_input = ToolPromptInput(
|
|
||||||
prompt=summary_prompt, context_instructions=summary_context
|
|
||||||
)
|
|
||||||
self.conversation_summary = await workflow.start_activity_method(
|
|
||||||
ToolActivities.prompt_llm,
|
|
||||||
summary_input,
|
|
||||||
schedule_to_close_timeout=timedelta(seconds=20),
|
|
||||||
)
|
|
||||||
workflow.logger.info(
|
|
||||||
f"Continuing as new after {self.max_turns_before_continue} turns."
|
|
||||||
)
|
|
||||||
workflow.continue_as_new(
|
|
||||||
args=[
|
|
||||||
CombinedInput(
|
|
||||||
tool_params=ToolWorkflowParams(
|
|
||||||
conversation_summary=self.conversation_summary,
|
|
||||||
prompt_queue=self.prompt_queue,
|
|
||||||
),
|
|
||||||
agent_goal=agent_goal,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
@workflow.signal
|
@workflow.signal
|
||||||
async def user_prompt(self, prompt: str) -> None:
|
async def user_prompt(self, prompt: str) -> None:
|
||||||
|
"""Signal handler for receiving user prompts."""
|
||||||
if self.chat_ended:
|
if self.chat_ended:
|
||||||
workflow.logger.warn(f"Message dropped due to chat closed: {prompt}")
|
workflow.logger.warn(f"Message dropped due to chat closed: {prompt}")
|
||||||
return
|
return
|
||||||
@@ -196,36 +188,41 @@ class ToolWorkflow:
|
|||||||
|
|
||||||
@workflow.signal
|
@workflow.signal
|
||||||
async def end_chat(self) -> None:
|
async def end_chat(self) -> None:
|
||||||
|
"""Signal handler for ending the chat session."""
|
||||||
self.chat_ended = True
|
self.chat_ended = True
|
||||||
|
|
||||||
@workflow.signal
|
@workflow.signal
|
||||||
async def confirm(self) -> None:
|
async def confirm(self) -> None:
|
||||||
|
"""Signal handler for user confirmation of tool execution."""
|
||||||
self.confirm = True
|
self.confirm = True
|
||||||
|
|
||||||
@workflow.query
|
@workflow.query
|
||||||
def get_conversation_history(
|
def get_conversation_history(self) -> ConversationHistory:
|
||||||
self,
|
"""Query handler to retrieve the full conversation history."""
|
||||||
) -> Dict[str, List[Dict[str, Union[str, Dict[str, Any]]]]]:
|
|
||||||
# Return the whole conversation as a dict
|
|
||||||
return self.conversation_history
|
return self.conversation_history
|
||||||
|
|
||||||
@workflow.query
|
@workflow.query
|
||||||
def get_summary_from_history(self) -> Optional[dict]:
|
def get_summary_from_history(self) -> Optional[str]:
|
||||||
|
"""Query handler to retrieve the conversation summary if available."""
|
||||||
return self.conversation_summary
|
return self.conversation_summary
|
||||||
|
|
||||||
@workflow.query
|
@workflow.query
|
||||||
def get_tool_data(self) -> Optional[dict]:
|
def get_tool_data(self) -> Optional[ToolData]:
|
||||||
|
"""Query handler to retrieve the current tool data if available."""
|
||||||
return self.tool_data
|
return self.tool_data
|
||||||
|
|
||||||
# Helper: generate text of the entire conversation so far
|
|
||||||
|
|
||||||
def format_history(self) -> str:
|
def format_history(self) -> str:
|
||||||
|
"""Format the conversation history into a single string."""
|
||||||
return " ".join(
|
return " ".join(
|
||||||
str(msg["response"]) for msg in self.conversation_history["messages"]
|
str(msg["response"]) for msg in self.conversation_history["messages"]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return (context_instructions, prompt)
|
|
||||||
def prompt_with_history(self, prompt: str) -> tuple[str, str]:
|
def prompt_with_history(self, prompt: str) -> tuple[str, str]:
|
||||||
|
"""Generate a context-aware prompt with conversation history.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[str, str]: A tuple of (context_instructions, prompt)
|
||||||
|
"""
|
||||||
history_string = self.format_history()
|
history_string = self.format_history()
|
||||||
context_instructions = (
|
context_instructions = (
|
||||||
f"Here is the conversation history: {history_string} "
|
f"Here is the conversation history: {history_string} "
|
||||||
@@ -235,8 +232,12 @@ class ToolWorkflow:
|
|||||||
)
|
)
|
||||||
return (context_instructions, prompt)
|
return (context_instructions, prompt)
|
||||||
|
|
||||||
# Return (context_instructions, prompt) for summarizing the conversation
|
|
||||||
def prompt_summary_with_history(self) -> tuple[str, str]:
|
def prompt_summary_with_history(self) -> tuple[str, str]:
|
||||||
|
"""Generate a prompt for summarizing the conversation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[str, str]: A tuple of (context_instructions, prompt)
|
||||||
|
"""
|
||||||
history_string = self.format_history()
|
history_string = self.format_history()
|
||||||
context_instructions = f"Here is the conversation history between a user and a chatbot: {history_string}"
|
context_instructions = f"Here is the conversation history between a user and a chatbot: {history_string}"
|
||||||
actual_prompt = (
|
actual_prompt = (
|
||||||
@@ -246,7 +247,12 @@ class ToolWorkflow:
|
|||||||
return (context_instructions, actual_prompt)
|
return (context_instructions, actual_prompt)
|
||||||
|
|
||||||
def add_message(self, actor: str, response: Union[str, Dict[str, Any]]) -> None:
|
def add_message(self, actor: str, response: Union[str, Dict[str, Any]]) -> None:
|
||||||
# Append a message object to the "messages" list
|
"""Add a message to the conversation history.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
actor: The entity that generated the message (e.g., "user", "agent")
|
||||||
|
response: The message content, either as a string or structured data
|
||||||
|
"""
|
||||||
self.conversation_history["messages"].append(
|
self.conversation_history["messages"].append(
|
||||||
{"actor": actor, "response": response}
|
{"actor": actor, "response": response}
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user