mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 14:08:08 +01:00
basic react API
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from collections import deque
|
||||
from datetime import timedelta
|
||||
from typing import Deque, List, Optional, Tuple
|
||||
from typing import Dict, Any, Union, List, Optional, Tuple, Deque
|
||||
from temporalio.common import RetryPolicy
|
||||
|
||||
from temporalio import workflow
|
||||
@@ -16,7 +16,9 @@ with workflow.unsafe.imports_passed_through():
|
||||
@workflow.defn
|
||||
class ToolWorkflow:
|
||||
def __init__(self) -> None:
|
||||
self.conversation_history: List[Tuple[str, str]] = []
|
||||
self.conversation_history: Dict[
|
||||
str, List[Dict[str, Union[str, Dict[str, Any]]]]
|
||||
] = {"messages": []}
|
||||
self.prompt_queue: Deque[str] = deque()
|
||||
self.conversation_summary: Optional[str] = None
|
||||
self.chat_ended: bool = False
|
||||
@@ -31,9 +33,7 @@ class ToolWorkflow:
|
||||
tool_data = None
|
||||
|
||||
if params and params.conversation_summary:
|
||||
self.conversation_history.append(
|
||||
("conversation_summary", params.conversation_summary)
|
||||
)
|
||||
self.add_message("conversation_summary", params.conversation_summary)
|
||||
self.conversation_summary = params.conversation_summary
|
||||
|
||||
if params and params.prompt_queue:
|
||||
@@ -51,7 +51,7 @@ class ToolWorkflow:
|
||||
# 1) If chat_ended was signaled, handle end and return
|
||||
if self.chat_ended:
|
||||
# possibly do a summary if multiple turns
|
||||
if len(self.conversation_history) > 1:
|
||||
if len(self.conversation_history["messages"]) > 1:
|
||||
summary_context, summary_prompt = self.prompt_summary_with_history()
|
||||
summary_input = ToolPromptInput(
|
||||
prompt=summary_prompt, context_instructions=summary_context
|
||||
@@ -73,6 +73,11 @@ class ToolWorkflow:
|
||||
self.confirm = False
|
||||
waiting_for_confirm = False
|
||||
|
||||
confirmed_tool_data = self.tool_data.copy()
|
||||
|
||||
confirmed_tool_data["next"] = "confirmed"
|
||||
self.add_message("userToolConfirm", confirmed_tool_data)
|
||||
|
||||
# Run the tool
|
||||
workflow.logger.info(f"Confirmed. Proceeding with tool: {current_tool}")
|
||||
dynamic_result = await workflow.execute_activity(
|
||||
@@ -80,15 +85,14 @@ class ToolWorkflow:
|
||||
self.tool_data["args"],
|
||||
schedule_to_close_timeout=timedelta(seconds=20),
|
||||
)
|
||||
self.conversation_history.append(
|
||||
(f"{current_tool}_result", str(dynamic_result))
|
||||
)
|
||||
dynamic_result["tool"] = current_tool
|
||||
self.add_message(f"tool_result", dynamic_result)
|
||||
|
||||
# Enqueue a follow-up prompt for the LLM
|
||||
self.prompt_queue.append(
|
||||
f"The '{current_tool}' tool completed successfully with {dynamic_result}. "
|
||||
f"### The '{current_tool}' tool completed successfully with {dynamic_result}. "
|
||||
"INSTRUCTIONS: Use this tool result, and the conversation history to figure out next steps. "
|
||||
"If all listed tools have run, then produce a done response."
|
||||
"IMPORTANT: If all listed tools have run, you are up to the final step. Mark 'next':'done' and respond with your final confirmation."
|
||||
)
|
||||
# Loop around again
|
||||
continue
|
||||
@@ -96,7 +100,11 @@ class ToolWorkflow:
|
||||
# 3) If there's a user prompt waiting, process it (unless we're in some other skipping logic).
|
||||
if self.prompt_queue:
|
||||
prompt = self.prompt_queue.popleft()
|
||||
self.conversation_history.append(("user", prompt))
|
||||
if prompt.startswith("###"):
|
||||
# this is a custom prompt where the tool result is sent to the LLM
|
||||
self.add_message("tool_result_to_llm", prompt)
|
||||
else:
|
||||
self.add_message("user", prompt)
|
||||
|
||||
# Pass entire conversation + Tools to LLM
|
||||
context_instructions = generate_genai_prompt(
|
||||
@@ -115,7 +123,7 @@ class ToolWorkflow:
|
||||
),
|
||||
)
|
||||
self.tool_data = tool_data
|
||||
self.conversation_history.append(("response", str(tool_data)))
|
||||
self.add_message("response", tool_data)
|
||||
|
||||
# Check the next step from LLM
|
||||
next_step = self.tool_data.get("next")
|
||||
@@ -134,7 +142,10 @@ class ToolWorkflow:
|
||||
|
||||
# Possibly continue-as-new after many turns
|
||||
# todo ensure this doesn't lose critical context
|
||||
if len(self.conversation_history) >= self.max_turns_before_continue:
|
||||
if (
|
||||
len(self.conversation_history["messages"])
|
||||
>= self.max_turns_before_continue
|
||||
):
|
||||
summary_context, summary_prompt = self.prompt_summary_with_history()
|
||||
summary_input = ToolPromptInput(
|
||||
prompt=summary_prompt, context_instructions=summary_context
|
||||
@@ -175,7 +186,10 @@ class ToolWorkflow:
|
||||
self.confirm = True
|
||||
|
||||
@workflow.query
|
||||
def get_conversation_history(self) -> List[Tuple[str, str]]:
|
||||
def get_conversation_history(
|
||||
self,
|
||||
) -> Dict[str, List[Dict[str, Union[str, Dict[str, Any]]]]]:
|
||||
# Return the whole conversation as a dict
|
||||
return self.conversation_history
|
||||
|
||||
@workflow.query
|
||||
@@ -187,8 +201,11 @@ class ToolWorkflow:
|
||||
return self.tool_data
|
||||
|
||||
# Helper: generate text of the entire conversation so far
|
||||
|
||||
def format_history(self) -> str:
|
||||
return " ".join(f"{text}" for _, text in self.conversation_history)
|
||||
return " ".join(
|
||||
str(msg["response"]) for msg in self.conversation_history["messages"]
|
||||
)
|
||||
|
||||
# Return (context_instructions, prompt)
|
||||
def prompt_with_history(self, prompt: str) -> tuple[str, str]:
|
||||
@@ -210,3 +227,9 @@ class ToolWorkflow:
|
||||
'Put the summary in the format { "summary": "<plain text>" }'
|
||||
)
|
||||
return (context_instructions, actual_prompt)
|
||||
|
||||
def add_message(self, actor: str, response: Union[str, Dict[str, Any]]) -> None:
|
||||
# Append a message object to the "messages" list
|
||||
self.conversation_history["messages"].append(
|
||||
{"actor": actor, "response": response}
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user