diff --git a/activities/tool_activities.py b/activities/tool_activities.py index 7952066..a015946 100644 --- a/activities/tool_activities.py +++ b/activities/tool_activities.py @@ -3,6 +3,9 @@ from temporalio import activity from temporalio.exceptions import ApplicationError from ollama import chat, ChatResponse import json +from models.tool_definitions import ToolsData +from typing import Sequence +from temporalio.common import RawValue @dataclass @@ -44,6 +47,47 @@ class ToolActivities: return data + @activity.defn + def validate_and_parse_json( + self, + response_prechecked: str, + tools_data: ToolsData, + conversation_history: str, + ) -> dict: + """ + 1) Build JSON validation instructions + 2) Call LLM with those instructions + 3) Parse the result + 4) If parsing fails, raise exception -> triggers retry + """ + + # 1) Build validation instructions + # (Generate the validation prompt exactly as you do in your workflow.) + from prompts.agent_prompt_generators import ( + generate_json_validation_prompt_from_tools_data, + ) + + validation_prompt = generate_json_validation_prompt_from_tools_data( + tools_data, conversation_history, response_prechecked + ) + + # 2) Call LLM + prompt_input = ToolPromptInput( + prompt=response_prechecked, + context_instructions=validation_prompt, + ) + validated_response = self.prompt_llm(prompt_input) + + # 3) Parse + # If parse fails, we raise ApplicationError -> triggers retry + try: + parsed = self.parse_tool_data(validated_response) + except Exception as e: + raise ApplicationError(f"Failed to parse validated JSON: {e}") + + # 4) If we get here, parse succeeded + return parsed + def get_current_date_human_readable(): """ @@ -54,3 +98,30 @@ def get_current_date_human_readable(): from datetime import datetime return datetime.now().strftime("%A, %B %d, %Y") + + +@activity.defn(dynamic=True) +def dynamic_tool_activity(args: Sequence[RawValue]) -> dict: + """Dynamic activity that is invoked via an unknown activity type.""" + tool_name = activity.info().activity_type # e.g. "SearchFlights" + + # The first payload is the dictionary of arguments + tool_args = activity.payload_converter().from_payload(args[0].payload, dict) + + # Extract fields from the arguments + date_depart = tool_args.get("dateDepart") + date_return = tool_args.get("dateReturn") + origin = tool_args.get("origin") + destination = tool_args.get("destination") + + # Print (or log) them + activity.logger.info(f"Tool: {tool_name}") + activity.logger.info(f"Depart: {date_depart}, Return: {date_return}") + activity.logger.info(f"Origin: {origin}, Destination: {destination}") + + # For now, just return them + return { + "tool": tool_name, + "args": tool_args, + "status": "OK - dynamic activity stub", + } diff --git a/scripts/run_worker.py b/scripts/run_worker.py index 888aa7b..b8732a6 100644 --- a/scripts/run_worker.py +++ b/scripts/run_worker.py @@ -5,7 +5,7 @@ import logging from temporalio.client import Client from temporalio.worker import Worker -from activities.tool_activities import ToolActivities +from activities.tool_activities import ToolActivities, dynamic_tool_activity from workflows.tool_workflow import ToolWorkflow from workflows.parent_workflow import ParentWorkflow @@ -21,7 +21,12 @@ async def main(): client, task_queue="ollama-task-queue", workflows=[ToolWorkflow, ParentWorkflow], - activities=[activities.prompt_llm, activities.parse_tool_data], + activities=[ + activities.prompt_llm, + activities.parse_tool_data, + activities.validate_and_parse_json, + dynamic_tool_activity, + ], activity_executor=activity_executor, ) await worker.run() diff --git a/scripts/send_message.py b/scripts/send_message.py index 10f7905..41252c7 100644 --- a/scripts/send_message.py +++ b/scripts/send_message.py @@ -12,7 +12,7 @@ async def main(prompt): # Construct your tool definitions in code search_flights_tool = ToolDefinition( name="SearchFlights", - description="Search for return flights from an origin to a destination within a date range", + description="Search for return flights from an origin to a destination within a date range (dateDepart, dateReturn)", arguments=[ ToolArgument( name="origin", @@ -27,12 +27,12 @@ async def main(prompt): ToolArgument( name="dateDepart", type="ISO8601", - description="Start of date range in human readable format", + description="Start of date range in human readable format, when you want to depart", ), ToolArgument( name="dateReturn", type="ISO8601", - description="End of date range in human readable format", + description="End of date range in human readable format, when you want to return", ), ], ) diff --git a/workflows/tool_workflow.py b/workflows/tool_workflow.py index 94cd9f8..1266ee2 100644 --- a/workflows/tool_workflow.py +++ b/workflows/tool_workflow.py @@ -1,18 +1,14 @@ from collections import deque from datetime import timedelta from typing import Deque, List, Optional, Tuple +from temporalio.common import RetryPolicy from temporalio import workflow -from prompts.agent_prompt_generators import ( - generate_genai_prompt_from_tools_data, - generate_json_validation_prompt_from_tools_data, -) with workflow.unsafe.imports_passed_through(): from activities.tool_activities import ToolActivities, ToolPromptInput from prompts.agent_prompt_generators import ( generate_genai_prompt_from_tools_data, - generate_json_validation_prompt_from_tools_data, ) from models.data_types import CombinedInput, ToolWorkflowParams @@ -43,71 +39,47 @@ class ToolWorkflow: self.prompt_queue.extend(params.prompt_queue) while True: - workflow.logger.info("Waiting for prompts...") - await workflow.wait_condition( lambda: bool(self.prompt_queue) or self.chat_ended ) if self.prompt_queue: - # Get user's prompt + # 1) Get the user prompt -> call initial LLM prompt = self.prompt_queue.popleft() self.conversation_history.append(("user", prompt)) - - # Build prompt + context context_instructions = generate_genai_prompt_from_tools_data( tools_data, self.format_history() ) - workflow.logger.info("Prompt: " + prompt) - - # Pass a single input object prompt_input = ToolPromptInput( prompt=prompt, context_instructions=context_instructions, ) - - # Call activity with one argument responsePrechecked = await workflow.execute_activity_method( ToolActivities.prompt_llm, prompt_input, schedule_to_close_timeout=timedelta(seconds=20), ) - # Check if the response is valid JSON - json_validation_instructions = ( - generate_json_validation_prompt_from_tools_data( - tools_data, self.format_history(), responsePrechecked - ) - ) - workflow.logger.info("Prompt: " + prompt) - - # Pass a single input object - prompt_input = ToolPromptInput( - prompt=responsePrechecked, - context_instructions=json_validation_instructions, - ) - - # Call activity with one argument - response = await workflow.execute_activity_method( - ToolActivities.prompt_llm, - prompt_input, - schedule_to_close_timeout=timedelta(seconds=20), - ) - - workflow.logger.info(f"Ollama response: {response}") - self.conversation_history.append(("response", response)) - - # Call activity with one argument + # 2) Validate + parse in one shot tool_data = await workflow.execute_activity_method( - ToolActivities.parse_tool_data, - response, - schedule_to_close_timeout=timedelta(seconds=1), + ToolActivities.validate_and_parse_json, + args=[responsePrechecked, tools_data, self.format_history()], + schedule_to_close_timeout=timedelta(seconds=40), + retry_policy=RetryPolicy(initial_interval=timedelta(seconds=10)), ) + # store it self.tool_data = tool_data + self.conversation_history.append(("response", str(tool_data))) if self.tool_data.get("next") == "confirm": - return self.tool_data + dynamic_result = await workflow.execute_activity( + self.tool_data["tool"], # dynamic activity name + self.tool_data["args"], # single argument to pass + schedule_to_close_timeout=timedelta(seconds=20), + ) + + return dynamic_result # Continue as new after X turns if len(self.conversation_history) >= self.max_turns_before_continue: