From 8fafe4b09036a18155f5b73ae5fe1cba80b96f4f Mon Sep 17 00:00:00 2001 From: Laine Date: Tue, 11 Mar 2025 09:07:25 -0400 Subject: [PATCH] Change agent goal to be an element of the workflow, including query --- api/main.py | 36 +++++++++++++++++++++++++------- shared/config.py | 1 + tools/change_goal.py | 9 +++++++- workflows/agent_goal_workflow.py | 35 ++++++++++++++++++++----------- 4 files changed, 61 insertions(+), 20 deletions(-) diff --git a/api/main.py b/api/main.py index e5c42da..d82826d 100644 --- a/api/main.py +++ b/api/main.py @@ -20,7 +20,7 @@ temporal_client: Optional[Client] = None load_dotenv() -def get_agent_goal(): +def get_initial_agent_goal(): """Get the agent goal from environment variables.""" goals = { "goal_match_train_invoice": goal_match_train_invoice, @@ -121,6 +121,27 @@ async def get_conversation_history(): raise HTTPException( status_code=500, detail="Internal server error while querying workflow." ) + +@app.get("/agent-goal") +async def get_agent_goal(): + """Calls the workflow's 'get_agent_goal' query.""" + try: + # Get workflow handle + handle = temporal_client.get_workflow_handle("agent-workflow") + + # Check if the workflow is completed + workflow_status = await handle.describe() + if workflow_status.status == 2: + # Workflow is completed; return an empty response + return {} + + # Query the workflow + agent_goal = await handle.query("get_agent_goal") + return agent_goal + except TemporalError as e: + # Workflow not found; return an empty response + print(e) + return {} @app.post("/send-prompt") @@ -128,7 +149,8 @@ async def send_prompt(prompt: str): # Create combined input with goal from environment combined_input = CombinedInput( tool_params=AgentGoalWorkflowParams(None, None), - agent_goal=get_agent_goal(), + agent_goal=get_initial_agent_goal(), + #change to get from workflow query ) workflow_id = "agent-workflow" @@ -172,13 +194,13 @@ async def end_chat(): @app.post("/start-workflow") async def start_workflow(): - # Get the configured goal - agent_goal = get_agent_goal() + # Get the initial goal as set in shared/config or env or just...always should be "pick a goal?" + initial_agent_goal = get_initial_agent_goal() # Create combined input combined_input = CombinedInput( tool_params=AgentGoalWorkflowParams(None, None), - agent_goal=agent_goal, + agent_goal=initial_agent_goal, ) workflow_id = "agent-workflow" @@ -190,9 +212,9 @@ async def start_workflow(): id=workflow_id, task_queue=TEMPORAL_TASK_QUEUE, start_signal="user_prompt", - start_signal_args=["### " + agent_goal.starter_prompt], + start_signal_args=["### " + initial_agent_goal.starter_prompt], ) return { - "message": f"Workflow started with goal's starter prompt: {agent_goal.starter_prompt}." + "message": f"Workflow started with goal's starter prompt: {initial_agent_goal.starter_prompt}." } diff --git a/shared/config.py b/shared/config.py index 0775a39..cb4b9da 100644 --- a/shared/config.py +++ b/shared/config.py @@ -18,6 +18,7 @@ TEMPORAL_API_KEY = os.getenv("TEMPORAL_API_KEY", "") #Starting agent goal - 1st goal is always to help user pick a next goal AGENT_GOAL = "goal_choose_agent_type" +#AGENT_GOAL = "goal_event_flight_invoice" async def get_temporal_client() -> Client: diff --git a/tools/change_goal.py b/tools/change_goal.py index 2458796..983e9c7 100644 --- a/tools/change_goal.py +++ b/tools/change_goal.py @@ -1,6 +1,13 @@ -# can this just call the API endpoint to set the goal, if that changes to allow a param? +# can this just call the API endpoint to set the goal, if that changes to allow a param? +# if this functions, it could work to both send a signal and also circumvent the UI -> API thing. Maybe? + # --- OR --- + # end this workflow and start a new one with the new goal + +# --- OR --- + +# send a signal to the workflow from here? import shared.config def change_goal(args: dict) -> dict: diff --git a/workflows/agent_goal_workflow.py b/workflows/agent_goal_workflow.py index 0e47b8a..61479a3 100644 --- a/workflows/agent_goal_workflow.py +++ b/workflows/agent_goal_workflow.py @@ -7,6 +7,7 @@ from temporalio.common import RetryPolicy from temporalio import workflow from models.data_types import ConversationHistory, NextStep, ValidationInput +from models.tool_definitions import AgentGoal from workflows.workflow_helpers import LLM_ACTIVITY_START_TO_CLOSE_TIMEOUT, \ LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT from workflows import workflow_helpers as helpers @@ -49,6 +50,8 @@ class AgentGoalWorkflow: self.tool_data: Optional[ToolData] = None self.confirm: bool = False self.tool_results: List[Dict[str, Any]] = [] + #set initial goal of "pick an agent" here?? + self.goal: AgentGoal = {"tools": []} # see ../api/main.py#temporal_client.start_workflow() for how these parameters are set @workflow.run @@ -56,7 +59,7 @@ class AgentGoalWorkflow: """Main workflow execution method.""" # setup phase, starts with blank tool_params and agent_goal prompt as defined in tools/goal_registry.py params = combined_input.tool_params - agent_goal = combined_input.agent_goal + self.goal = combined_input.agent_goal # add message from sample conversation provided in tools/goal_registry.py, if it exists if params and params.conversation_summary: @@ -77,15 +80,15 @@ class AgentGoalWorkflow: ) #update the goal, in case it's changed - doesn't help - goals = { - "goal_match_train_invoice": goal_match_train_invoice, - "goal_event_flight_invoice": goal_event_flight_invoice, - "goal_choose_agent_type": goal_choose_agent_type, - } + #goals = { + # "goal_match_train_invoice": goal_match_train_invoice, + # "goal_event_flight_invoice": goal_event_flight_invoice, + # "goal_choose_agent_type": goal_choose_agent_type, + #} - if shared.config.AGENT_GOAL is not None: - agent_goal = goals.get(shared.config.AGENT_GOAL) - workflow.logger.warning("AGENT_GOAL: " + shared.config.AGENT_GOAL) + #if shared.config.AGENT_GOAL is not None: + # agent_goal = goals.get(shared.config.AGENT_GOAL) + #workflow.logger.warning("AGENT_GOAL: " + shared.config.AGENT_GOAL) # workflow.logger.warning("agent_goal", agent_goal) #process signals of various kinds @@ -112,6 +115,9 @@ class AgentGoalWorkflow: self.add_message, self.prompt_queue ) + # workflow.logger.warning("last tool_data tool: ", self.tool_data[-1].tool) + #workflow.logger.warning("last tool_data args: ", self.tool_data[-1].args) + # workflow.logger.warning("last tool_results [args]: ", self.tool_results[-1]["args"]) continue if self.prompt_queue: @@ -123,7 +129,7 @@ class AgentGoalWorkflow: validation_input = ValidationInput( prompt=prompt, conversation_history=self.conversation_history, - agent_goal=agent_goal, + agent_goal=self.goal, ) validation_result = await workflow.execute_activity( ToolActivities.agent_validatePrompt, @@ -147,7 +153,7 @@ class AgentGoalWorkflow: # Proceed with generating the context and prompt context_instructions = generate_genai_prompt( - agent_goal, self.conversation_history, self.tool_data + self.goal, self.conversation_history, self.tool_data ) prompt_input = ToolPromptInput( @@ -189,7 +195,7 @@ class AgentGoalWorkflow: await helpers.continue_as_new_if_needed( self.conversation_history, self.prompt_queue, - agent_goal, + self.goal, MAX_TURNS_BEFORE_CONTINUE, self.add_message ) @@ -220,6 +226,11 @@ class AgentGoalWorkflow: def get_conversation_history(self) -> ConversationHistory: """Query handler to retrieve the full conversation history.""" return self.conversation_history + + @workflow.query + def get_agent_goal(self) -> AgentGoal: + """Query handler to retrieve the current goal of the agent.""" + return self.goal @workflow.query def get_summary_from_history(self) -> Optional[str]: