diff --git a/activities/tool_activities.py b/activities/tool_activities.py index df5759a..c9bfa94 100644 --- a/activities/tool_activities.py +++ b/activities/tool_activities.py @@ -1,6 +1,5 @@ from dataclasses import dataclass from temporalio import activity -from temporalio.exceptions import ApplicationError from ollama import chat, ChatResponse import json from typing import Sequence diff --git a/workflows/tool_workflow.py b/workflows/tool_workflow.py index 1bf4d6a..721c6fa 100644 --- a/workflows/tool_workflow.py +++ b/workflows/tool_workflow.py @@ -22,6 +22,7 @@ class ToolWorkflow: self.chat_ended: bool = False self.tool_data = None self.max_turns_before_continue: int = 250 + self.confirm = False @workflow.run async def run(self, combined_input: CombinedInput) -> str: @@ -96,6 +97,14 @@ class ToolWorkflow: ) # e.g. "FindEvents", "SearchFlights", "CreateInvoice" if next_step == "confirm" and current_tool: + self.confirm = False + + # Wait for a 'confirm' signal + await workflow.wait_condition(lambda: self.confirm) + workflow.logger.info( + "Confirmed. Proceeding with tool execution: " + current_tool + ) + # We have enough info to call the tool dynamic_result = await workflow.execute_activity( current_tool, @@ -188,6 +197,10 @@ class ToolWorkflow: async def end_chat(self) -> None: self.chat_ended = True + @workflow.signal + async def confirm(self) -> None: + self.confirm = True + @workflow.query def get_conversation_history(self) -> List[Tuple[str, str]]: return self.conversation_history