dynamic activity to call tool

This commit is contained in:
Steve Androulakis
2025-01-01 13:54:54 -08:00
parent e7e8e7e658
commit eb92c71dfe
4 changed files with 97 additions and 49 deletions

View File

@@ -3,6 +3,9 @@ from temporalio import activity
from temporalio.exceptions import ApplicationError
from ollama import chat, ChatResponse
import json
from models.tool_definitions import ToolsData
from typing import Sequence
from temporalio.common import RawValue
@dataclass
@@ -44,6 +47,47 @@ class ToolActivities:
return data
@activity.defn
def validate_and_parse_json(
self,
response_prechecked: str,
tools_data: ToolsData,
conversation_history: str,
) -> dict:
"""
1) Build JSON validation instructions
2) Call LLM with those instructions
3) Parse the result
4) If parsing fails, raise exception -> triggers retry
"""
# 1) Build validation instructions
# (Generate the validation prompt exactly as you do in your workflow.)
from prompts.agent_prompt_generators import (
generate_json_validation_prompt_from_tools_data,
)
validation_prompt = generate_json_validation_prompt_from_tools_data(
tools_data, conversation_history, response_prechecked
)
# 2) Call LLM
prompt_input = ToolPromptInput(
prompt=response_prechecked,
context_instructions=validation_prompt,
)
validated_response = self.prompt_llm(prompt_input)
# 3) Parse
# If parse fails, we raise ApplicationError -> triggers retry
try:
parsed = self.parse_tool_data(validated_response)
except Exception as e:
raise ApplicationError(f"Failed to parse validated JSON: {e}")
# 4) If we get here, parse succeeded
return parsed
def get_current_date_human_readable():
"""
@@ -54,3 +98,30 @@ def get_current_date_human_readable():
from datetime import datetime
return datetime.now().strftime("%A, %B %d, %Y")
@activity.defn(dynamic=True)
def dynamic_tool_activity(args: Sequence[RawValue]) -> dict:
"""Dynamic activity that is invoked via an unknown activity type."""
tool_name = activity.info().activity_type # e.g. "SearchFlights"
# The first payload is the dictionary of arguments
tool_args = activity.payload_converter().from_payload(args[0].payload, dict)
# Extract fields from the arguments
date_depart = tool_args.get("dateDepart")
date_return = tool_args.get("dateReturn")
origin = tool_args.get("origin")
destination = tool_args.get("destination")
# Print (or log) them
activity.logger.info(f"Tool: {tool_name}")
activity.logger.info(f"Depart: {date_depart}, Return: {date_return}")
activity.logger.info(f"Origin: {origin}, Destination: {destination}")
# For now, just return them
return {
"tool": tool_name,
"args": tool_args,
"status": "OK - dynamic activity stub",
}

View File

@@ -5,7 +5,7 @@ import logging
from temporalio.client import Client
from temporalio.worker import Worker
from activities.tool_activities import ToolActivities
from activities.tool_activities import ToolActivities, dynamic_tool_activity
from workflows.tool_workflow import ToolWorkflow
from workflows.parent_workflow import ParentWorkflow
@@ -21,7 +21,12 @@ async def main():
client,
task_queue="ollama-task-queue",
workflows=[ToolWorkflow, ParentWorkflow],
activities=[activities.prompt_llm, activities.parse_tool_data],
activities=[
activities.prompt_llm,
activities.parse_tool_data,
activities.validate_and_parse_json,
dynamic_tool_activity,
],
activity_executor=activity_executor,
)
await worker.run()

View File

@@ -12,7 +12,7 @@ async def main(prompt):
# Construct your tool definitions in code
search_flights_tool = ToolDefinition(
name="SearchFlights",
description="Search for return flights from an origin to a destination within a date range",
description="Search for return flights from an origin to a destination within a date range (dateDepart, dateReturn)",
arguments=[
ToolArgument(
name="origin",
@@ -27,12 +27,12 @@ async def main(prompt):
ToolArgument(
name="dateDepart",
type="ISO8601",
description="Start of date range in human readable format",
description="Start of date range in human readable format, when you want to depart",
),
ToolArgument(
name="dateReturn",
type="ISO8601",
description="End of date range in human readable format",
description="End of date range in human readable format, when you want to return",
),
],
)

View File

@@ -1,18 +1,14 @@
from collections import deque
from datetime import timedelta
from typing import Deque, List, Optional, Tuple
from temporalio.common import RetryPolicy
from temporalio import workflow
from prompts.agent_prompt_generators import (
generate_genai_prompt_from_tools_data,
generate_json_validation_prompt_from_tools_data,
)
with workflow.unsafe.imports_passed_through():
from activities.tool_activities import ToolActivities, ToolPromptInput
from prompts.agent_prompt_generators import (
generate_genai_prompt_from_tools_data,
generate_json_validation_prompt_from_tools_data,
)
from models.data_types import CombinedInput, ToolWorkflowParams
@@ -43,71 +39,47 @@ class ToolWorkflow:
self.prompt_queue.extend(params.prompt_queue)
while True:
workflow.logger.info("Waiting for prompts...")
await workflow.wait_condition(
lambda: bool(self.prompt_queue) or self.chat_ended
)
if self.prompt_queue:
# Get user's prompt
# 1) Get the user prompt -> call initial LLM
prompt = self.prompt_queue.popleft()
self.conversation_history.append(("user", prompt))
# Build prompt + context
context_instructions = generate_genai_prompt_from_tools_data(
tools_data, self.format_history()
)
workflow.logger.info("Prompt: " + prompt)
# Pass a single input object
prompt_input = ToolPromptInput(
prompt=prompt,
context_instructions=context_instructions,
)
# Call activity with one argument
responsePrechecked = await workflow.execute_activity_method(
ToolActivities.prompt_llm,
prompt_input,
schedule_to_close_timeout=timedelta(seconds=20),
)
# Check if the response is valid JSON
json_validation_instructions = (
generate_json_validation_prompt_from_tools_data(
tools_data, self.format_history(), responsePrechecked
)
)
workflow.logger.info("Prompt: " + prompt)
# Pass a single input object
prompt_input = ToolPromptInput(
prompt=responsePrechecked,
context_instructions=json_validation_instructions,
# 2) Validate + parse in one shot
tool_data = await workflow.execute_activity_method(
ToolActivities.validate_and_parse_json,
args=[responsePrechecked, tools_data, self.format_history()],
schedule_to_close_timeout=timedelta(seconds=40),
retry_policy=RetryPolicy(initial_interval=timedelta(seconds=10)),
)
# Call activity with one argument
response = await workflow.execute_activity_method(
ToolActivities.prompt_llm,
prompt_input,
# store it
self.tool_data = tool_data
self.conversation_history.append(("response", str(tool_data)))
if self.tool_data.get("next") == "confirm":
dynamic_result = await workflow.execute_activity(
self.tool_data["tool"], # dynamic activity name
self.tool_data["args"], # single argument to pass
schedule_to_close_timeout=timedelta(seconds=20),
)
workflow.logger.info(f"Ollama response: {response}")
self.conversation_history.append(("response", response))
# Call activity with one argument
tool_data = await workflow.execute_activity_method(
ToolActivities.parse_tool_data,
response,
schedule_to_close_timeout=timedelta(seconds=1),
)
self.tool_data = tool_data
if self.tool_data.get("next") == "confirm":
return self.tool_data
return dynamic_result
# Continue as new after X turns
if len(self.conversation_history) >= self.max_turns_before_continue: