Merge branch 'development' of https://github.com/joshmsmith/temporal-ai-agent into development

This commit is contained in:
Joshua Smith
2025-03-12 13:38:25 -04:00
6 changed files with 63 additions and 54 deletions

View File

@@ -37,3 +37,6 @@ OPENAI_API_KEY=sk-proj-...
# Agent Goal Configuration # Agent Goal Configuration
# AGENT_GOAL=goal_event_flight_invoice # (default) or goal_match_train_invoice # AGENT_GOAL=goal_event_flight_invoice # (default) or goal_match_train_invoice
# Set if the UI should force a user confirmation step or not
SHOW_CONFIRM=True

View File

@@ -1,3 +1,4 @@
import os
from fastapi import FastAPI from fastapi import FastAPI
from typing import Optional from typing import Optional
from temporalio.client import Client from temporalio.client import Client
@@ -11,7 +12,7 @@ from workflows.agent_goal_workflow import AgentGoalWorkflow
from models.data_types import CombinedInput, AgentGoalWorkflowParams from models.data_types import CombinedInput, AgentGoalWorkflowParams
from tools.goal_registry import goal_list from tools.goal_registry import goal_list
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from shared.config import get_temporal_client, TEMPORAL_TASK_QUEUE, AGENT_GOAL from shared.config import get_temporal_client, TEMPORAL_TASK_QUEUE
app = FastAPI() app = FastAPI()
temporal_client: Optional[Client] = None temporal_client: Optional[Client] = None
@@ -22,13 +23,10 @@ load_dotenv()
def get_initial_agent_goal(): def get_initial_agent_goal():
"""Get the agent goal from environment variables.""" """Get the agent goal from environment variables."""
if AGENT_GOAL is not None: env_goal = os.getenv("AGENT_GOAL", "goal_choose_agent_type") #if no goal is set in the env file, default to choosing an agent
for listed_goal in goal_list: for listed_goal in goal_list:
if listed_goal.id == AGENT_GOAL: if listed_goal.id == env_goal:
return listed_goal return listed_goal
else:
#if no goal is set in the config file, default to choosing an agent
return goal_list.get("goal_choose_agent_type")
@app.on_event("startup") @app.on_event("startup")
@@ -113,6 +111,10 @@ async def get_conversation_history():
status_code=404, detail="Workflow worker unavailable or not found." status_code=404, detail="Workflow worker unavailable or not found."
) )
if "workflow not found" in error_message:
await start_workflow()
return []
else:
# For other Temporal errors, return a 500 # For other Temporal errors, return a 500
raise HTTPException( raise HTTPException(
status_code=500, detail="Internal server error while querying workflow." status_code=500, detail="Internal server error while querying workflow."

View File

@@ -27,7 +27,7 @@ const LLMResponse = memo(({ data, onConfirm, isLastMessage, onHeightChange }) =>
: data?.response; : data?.response;
const displayText = (response || '').trim(); const displayText = (response || '').trim();
const requiresConfirm = data.next === "confirm" && isLastMessage; const requiresConfirm = data.force_confirm && data.next === "confirm" && isLastMessage;
const defaultText = requiresConfirm const defaultText = requiresConfirm
? `Agent is ready to run "${data.tool}". Please confirm.` ? `Agent is ready to run "${data.tool}". Please confirm.`
: ''; : '';

View File

@@ -16,11 +16,6 @@ TEMPORAL_TLS_CERT = os.getenv("TEMPORAL_TLS_CERT", "")
TEMPORAL_TLS_KEY = os.getenv("TEMPORAL_TLS_KEY", "") TEMPORAL_TLS_KEY = os.getenv("TEMPORAL_TLS_KEY", "")
TEMPORAL_API_KEY = os.getenv("TEMPORAL_API_KEY", "") TEMPORAL_API_KEY = os.getenv("TEMPORAL_API_KEY", "")
#Starting agent goal - 1st goal is always to help user pick a next goal
AGENT_GOAL = "goal_choose_agent_type"
#AGENT_GOAL = "goal_event_flight_invoice"
async def get_temporal_client() -> Client: async def get_temporal_client() -> Client:
""" """
Creates a Temporal client based on environment configuration. Creates a Temporal client based on environment configuration.

View File

@@ -41,3 +41,5 @@
[ ] non-retry the api key error - "Invalid API Key provided: sk_test_**J..." and "AuthenticationError" <br /> [ ] non-retry the api key error - "Invalid API Key provided: sk_test_**J..." and "AuthenticationError" <br />
[ ] make it so you can yeet yourself out of a goal and pick a new one <br /> [ ] make it so you can yeet yourself out of a goal and pick a new one <br />
[ ] add visual feedback when workflow starting

View File

@@ -1,5 +1,6 @@
from collections import deque from collections import deque
from datetime import timedelta from datetime import timedelta
import os
from typing import Dict, Any, Union, List, Optional, Deque, TypedDict from typing import Dict, Any, Union, List, Optional, Deque, TypedDict
from temporalio.common import RetryPolicy from temporalio.common import RetryPolicy
@@ -25,11 +26,19 @@ with workflow.unsafe.imports_passed_through():
# Constants # Constants
MAX_TURNS_BEFORE_CONTINUE = 250 MAX_TURNS_BEFORE_CONTINUE = 250
SHOW_CONFIRM = True
show_confirm_env = os.getenv("SHOW_CONFIRM")
if show_confirm_env is not None:
if show_confirm_env == "False":
SHOW_CONFIRM = False
#ToolData as part of the workflow is what's accessible to the UI - see LLMResponse.jsx for example
class ToolData(TypedDict, total=False): class ToolData(TypedDict, total=False):
next: NextStep next: NextStep
tool: str tool: str
args: Dict[str, Any] args: Dict[str, Any]
response: str response: str
force_confirm: bool = True
@workflow.defn @workflow.defn
class AgentGoalWorkflow: class AgentGoalWorkflow:
@@ -48,6 +57,7 @@ class AgentGoalWorkflow:
# see ../api/main.py#temporal_client.start_workflow() for how the input parameters are set # see ../api/main.py#temporal_client.start_workflow() for how the input parameters are set
@workflow.run @workflow.run
async def run(self, combined_input: CombinedInput) -> str: async def run(self, combined_input: CombinedInput) -> str:
"""Main workflow execution method.""" """Main workflow execution method."""
# setup phase, starts with blank tool_params and agent_goal prompt as defined in tools/goal_registry.py # setup phase, starts with blank tool_params and agent_goal prompt as defined in tools/goal_registry.py
params = combined_input.tool_params params = combined_input.tool_params
@@ -77,12 +87,13 @@ class AgentGoalWorkflow:
# handle chat-end signal # handle chat-end signal
if self.chat_ended: if self.chat_ended:
workflow.logger.warning(f"workflow step: chat-end signal received, ending")
workflow.logger.info("Chat ended.") workflow.logger.info("Chat ended.")
return f"{self.conversation_history}" return f"{self.conversation_history}"
# execute tool # Execute the tool
if self.confirm and waiting_for_confirm and current_tool and self.tool_data: if self.confirm and waiting_for_confirm and current_tool and self.tool_data:
workflow.logger.warning(f"workflow step: user has confirmed, executing the tool {current_tool}")
self.confirm = False self.confirm = False
waiting_for_confirm = False waiting_for_confirm = False
@@ -99,8 +110,6 @@ class AgentGoalWorkflow:
self.prompt_queue self.prompt_queue
) )
workflow.logger.warning(f"tool_results keys: {self.tool_results[-1].keys()}")
workflow.logger.warning(f"tool_results values: {self.tool_results[-1].values()}")
#set new goal if we should #set new goal if we should
if len(self.tool_results) > 0: if len(self.tool_results) > 0:
if "ChangeGoal" in self.tool_results[-1].values() and "new_goal" in self.tool_results[-1].keys(): if "ChangeGoal" in self.tool_results[-1].values() and "new_goal" in self.tool_results[-1].keys():
@@ -112,10 +121,11 @@ class AgentGoalWorkflow:
self.change_goal("goal_choose_agent_type") self.change_goal("goal_choose_agent_type")
continue continue
# push messages to UI if there are any # if we've received messages to be processed on the prompt queue...
if self.prompt_queue: if self.prompt_queue:
prompt = self.prompt_queue.popleft() prompt = self.prompt_queue.popleft()
if not prompt.startswith("###"): workflow.logger.warning(f"workflow step: processing message on the prompt queue, message is {prompt}")
if not prompt.startswith("###"): #if the message isn't from the LLM but is instead from the user
self.add_message("user", prompt) self.add_message("user", prompt)
# Validate the prompt before proceeding # Validate the prompt before proceeding
@@ -134,27 +144,17 @@ class AgentGoalWorkflow:
), ),
) )
#If validation fails, provide that feedback to the user - i.e., "your words make no sense, human" #If validation fails, provide that feedback to the user - i.e., "your words make no sense, puny human" end this iteration of processing
if not validation_result.validationResult: if not validation_result.validationResult:
workflow.logger.warning( workflow.logger.warning(f"Prompt validation failed: {validation_result.validationFailedReason}")
f"Prompt validation failed: {validation_result.validationFailedReason}" self.add_message("agent", validation_result.validationFailedReason)
)
self.add_message(
"agent", validation_result.validationFailedReason
)
continue continue
# Proceed with generating the context and prompt # If valid, proceed with generating the context and prompt
context_instructions = generate_genai_prompt( context_instructions = generate_genai_prompt(self.goal, self.conversation_history, self.tool_data)
self.goal, self.conversation_history, self.tool_data prompt_input = ToolPromptInput(prompt=prompt, context_instructions=context_instructions)
)
prompt_input = ToolPromptInput( # connect to LLM and execute to get next steps
prompt=prompt,
context_instructions=context_instructions,
)
# connect to LLM and get it to create a prompt for the user about the tool
tool_data = await workflow.execute_activity( tool_data = await workflow.execute_activity(
ToolActivities.agent_toolPlanner, ToolActivities.agent_toolPlanner,
prompt_input, prompt_input,
@@ -164,35 +164,39 @@ class AgentGoalWorkflow:
initial_interval=timedelta(seconds=5), backoff_coefficient=1 initial_interval=timedelta(seconds=5), backoff_coefficient=1
), ),
) )
tool_data["force_confirm"] = SHOW_CONFIRM
self.tool_data = tool_data self.tool_data = tool_data
# move forward in the tool chain # process the tool as dictated by the prompt response - what to do next, and with which tool
next_step = tool_data.get("next") next_step = tool_data.get("next")
current_tool = tool_data.get("tool") current_tool = tool_data.get("tool")
if "next" in self.tool_data.keys():
workflow.logger.warning(f"ran the toolplanner, next step: {next_step}")
else:
workflow.logger.warning("ran the toolplanner, next step not set!")
workflow.logger.warning(f"next_step: {next_step}, current tool is {current_tool}")
#if the next step is to confirm...
if next_step == "confirm" and current_tool: if next_step == "confirm" and current_tool:
workflow.logger.warning("next_step: confirm, ran the toolplanner, trying to confirm")
args = tool_data.get("args", {}) args = tool_data.get("args", {})
#if we're missing arguments, go back to the top of the loop
if await helpers.handle_missing_args(current_tool, args, tool_data, self.prompt_queue): if await helpers.handle_missing_args(current_tool, args, tool_data, self.prompt_queue):
continue continue
#...otherwise, if we want to force the user to confirm, set that up
waiting_for_confirm = True waiting_for_confirm = True
if SHOW_CONFIRM:
self.confirm = False self.confirm = False
workflow.logger.info("Waiting for user confirm signal...") workflow.logger.info("Waiting for user confirm signal...")
else:
#theory - set self.confirm to true bc that's the signal, so we can get around the signal??
self.confirm = True
# todo probably here we can set the next step to be change-goal # else if the next step is to pick a new goal...
elif next_step == "pick-new-goal": elif next_step == "pick-new-goal":
workflow.logger.info("All steps completed. Resetting goal.") workflow.logger.info("All steps completed. Resetting goal.")
workflow.logger.warning("next_step = pick-new-goal, setting goal to goal_choose_agent_type")
self.change_goal("goal_choose_agent_type") self.change_goal("goal_choose_agent_type")
# else if the next step is to be done - this should only happen if the user requests it via "end conversation"
elif next_step == "done": elif next_step == "done":
workflow.logger.warning("next_step = done")
self.add_message("agent", tool_data) self.add_message("agent", tool_data)
# end the workflow
return str(self.conversation_history) return str(self.conversation_history)
self.add_message("agent", tool_data) self.add_message("agent", tool_data)
@@ -208,8 +212,9 @@ class AgentGoalWorkflow:
@workflow.signal @workflow.signal
async def user_prompt(self, prompt: str) -> None: async def user_prompt(self, prompt: str) -> None:
"""Signal handler for receiving user prompts.""" """Signal handler for receiving user prompts."""
workflow.logger.warning(f"signal received: user_prompt, prompt is {prompt}")
if self.chat_ended: if self.chat_ended:
workflow.logger.warn(f"Message dropped due to chat closed: {prompt}") workflow.logger.warning(f"Message dropped due to chat closed: {prompt}")
return return
self.prompt_queue.append(prompt) self.prompt_queue.append(prompt)
@@ -218,12 +223,14 @@ class AgentGoalWorkflow:
async def confirm(self) -> None: async def confirm(self) -> None:
"""Signal handler for user confirmation of tool execution.""" """Signal handler for user confirmation of tool execution."""
workflow.logger.info("Received user confirmation") workflow.logger.info("Received user confirmation")
workflow.logger.warning(f"signal recieved: confirm")
self.confirm = True self.confirm = True
#Signal that comes from api/main.py via a post to /end-chat #Signal that comes from api/main.py via a post to /end-chat
@workflow.signal @workflow.signal
async def end_chat(self) -> None: async def end_chat(self) -> None:
"""Signal handler for ending the chat session.""" """Signal handler for ending the chat session."""
workflow.logger.warning("signal received: end_chat")
self.chat_ended = True self.chat_ended = True
@workflow.query @workflow.query