dynamic activity to call tool

This commit is contained in:
Steve Androulakis
2025-01-01 13:54:54 -08:00
parent e7e8e7e658
commit eb92c71dfe
4 changed files with 97 additions and 49 deletions

View File

@@ -3,6 +3,9 @@ from temporalio import activity
from temporalio.exceptions import ApplicationError from temporalio.exceptions import ApplicationError
from ollama import chat, ChatResponse from ollama import chat, ChatResponse
import json import json
from models.tool_definitions import ToolsData
from typing import Sequence
from temporalio.common import RawValue
@dataclass @dataclass
@@ -44,6 +47,47 @@ class ToolActivities:
return data return data
@activity.defn
def validate_and_parse_json(
self,
response_prechecked: str,
tools_data: ToolsData,
conversation_history: str,
) -> dict:
"""
1) Build JSON validation instructions
2) Call LLM with those instructions
3) Parse the result
4) If parsing fails, raise exception -> triggers retry
"""
# 1) Build validation instructions
# (Generate the validation prompt exactly as you do in your workflow.)
from prompts.agent_prompt_generators import (
generate_json_validation_prompt_from_tools_data,
)
validation_prompt = generate_json_validation_prompt_from_tools_data(
tools_data, conversation_history, response_prechecked
)
# 2) Call LLM
prompt_input = ToolPromptInput(
prompt=response_prechecked,
context_instructions=validation_prompt,
)
validated_response = self.prompt_llm(prompt_input)
# 3) Parse
# If parse fails, we raise ApplicationError -> triggers retry
try:
parsed = self.parse_tool_data(validated_response)
except Exception as e:
raise ApplicationError(f"Failed to parse validated JSON: {e}")
# 4) If we get here, parse succeeded
return parsed
def get_current_date_human_readable(): def get_current_date_human_readable():
""" """
@@ -54,3 +98,30 @@ def get_current_date_human_readable():
from datetime import datetime from datetime import datetime
return datetime.now().strftime("%A, %B %d, %Y") return datetime.now().strftime("%A, %B %d, %Y")
@activity.defn(dynamic=True)
def dynamic_tool_activity(args: Sequence[RawValue]) -> dict:
"""Dynamic activity that is invoked via an unknown activity type."""
tool_name = activity.info().activity_type # e.g. "SearchFlights"
# The first payload is the dictionary of arguments
tool_args = activity.payload_converter().from_payload(args[0].payload, dict)
# Extract fields from the arguments
date_depart = tool_args.get("dateDepart")
date_return = tool_args.get("dateReturn")
origin = tool_args.get("origin")
destination = tool_args.get("destination")
# Print (or log) them
activity.logger.info(f"Tool: {tool_name}")
activity.logger.info(f"Depart: {date_depart}, Return: {date_return}")
activity.logger.info(f"Origin: {origin}, Destination: {destination}")
# For now, just return them
return {
"tool": tool_name,
"args": tool_args,
"status": "OK - dynamic activity stub",
}

View File

@@ -5,7 +5,7 @@ import logging
from temporalio.client import Client from temporalio.client import Client
from temporalio.worker import Worker from temporalio.worker import Worker
from activities.tool_activities import ToolActivities from activities.tool_activities import ToolActivities, dynamic_tool_activity
from workflows.tool_workflow import ToolWorkflow from workflows.tool_workflow import ToolWorkflow
from workflows.parent_workflow import ParentWorkflow from workflows.parent_workflow import ParentWorkflow
@@ -21,7 +21,12 @@ async def main():
client, client,
task_queue="ollama-task-queue", task_queue="ollama-task-queue",
workflows=[ToolWorkflow, ParentWorkflow], workflows=[ToolWorkflow, ParentWorkflow],
activities=[activities.prompt_llm, activities.parse_tool_data], activities=[
activities.prompt_llm,
activities.parse_tool_data,
activities.validate_and_parse_json,
dynamic_tool_activity,
],
activity_executor=activity_executor, activity_executor=activity_executor,
) )
await worker.run() await worker.run()

View File

@@ -12,7 +12,7 @@ async def main(prompt):
# Construct your tool definitions in code # Construct your tool definitions in code
search_flights_tool = ToolDefinition( search_flights_tool = ToolDefinition(
name="SearchFlights", name="SearchFlights",
description="Search for return flights from an origin to a destination within a date range", description="Search for return flights from an origin to a destination within a date range (dateDepart, dateReturn)",
arguments=[ arguments=[
ToolArgument( ToolArgument(
name="origin", name="origin",
@@ -27,12 +27,12 @@ async def main(prompt):
ToolArgument( ToolArgument(
name="dateDepart", name="dateDepart",
type="ISO8601", type="ISO8601",
description="Start of date range in human readable format", description="Start of date range in human readable format, when you want to depart",
), ),
ToolArgument( ToolArgument(
name="dateReturn", name="dateReturn",
type="ISO8601", type="ISO8601",
description="End of date range in human readable format", description="End of date range in human readable format, when you want to return",
), ),
], ],
) )

View File

@@ -1,18 +1,14 @@
from collections import deque from collections import deque
from datetime import timedelta from datetime import timedelta
from typing import Deque, List, Optional, Tuple from typing import Deque, List, Optional, Tuple
from temporalio.common import RetryPolicy
from temporalio import workflow from temporalio import workflow
from prompts.agent_prompt_generators import (
generate_genai_prompt_from_tools_data,
generate_json_validation_prompt_from_tools_data,
)
with workflow.unsafe.imports_passed_through(): with workflow.unsafe.imports_passed_through():
from activities.tool_activities import ToolActivities, ToolPromptInput from activities.tool_activities import ToolActivities, ToolPromptInput
from prompts.agent_prompt_generators import ( from prompts.agent_prompt_generators import (
generate_genai_prompt_from_tools_data, generate_genai_prompt_from_tools_data,
generate_json_validation_prompt_from_tools_data,
) )
from models.data_types import CombinedInput, ToolWorkflowParams from models.data_types import CombinedInput, ToolWorkflowParams
@@ -43,71 +39,47 @@ class ToolWorkflow:
self.prompt_queue.extend(params.prompt_queue) self.prompt_queue.extend(params.prompt_queue)
while True: while True:
workflow.logger.info("Waiting for prompts...")
await workflow.wait_condition( await workflow.wait_condition(
lambda: bool(self.prompt_queue) or self.chat_ended lambda: bool(self.prompt_queue) or self.chat_ended
) )
if self.prompt_queue: if self.prompt_queue:
# Get user's prompt # 1) Get the user prompt -> call initial LLM
prompt = self.prompt_queue.popleft() prompt = self.prompt_queue.popleft()
self.conversation_history.append(("user", prompt)) self.conversation_history.append(("user", prompt))
# Build prompt + context
context_instructions = generate_genai_prompt_from_tools_data( context_instructions = generate_genai_prompt_from_tools_data(
tools_data, self.format_history() tools_data, self.format_history()
) )
workflow.logger.info("Prompt: " + prompt)
# Pass a single input object
prompt_input = ToolPromptInput( prompt_input = ToolPromptInput(
prompt=prompt, prompt=prompt,
context_instructions=context_instructions, context_instructions=context_instructions,
) )
# Call activity with one argument
responsePrechecked = await workflow.execute_activity_method( responsePrechecked = await workflow.execute_activity_method(
ToolActivities.prompt_llm, ToolActivities.prompt_llm,
prompt_input, prompt_input,
schedule_to_close_timeout=timedelta(seconds=20), schedule_to_close_timeout=timedelta(seconds=20),
) )
# Check if the response is valid JSON # 2) Validate + parse in one shot
json_validation_instructions = (
generate_json_validation_prompt_from_tools_data(
tools_data, self.format_history(), responsePrechecked
)
)
workflow.logger.info("Prompt: " + prompt)
# Pass a single input object
prompt_input = ToolPromptInput(
prompt=responsePrechecked,
context_instructions=json_validation_instructions,
)
# Call activity with one argument
response = await workflow.execute_activity_method(
ToolActivities.prompt_llm,
prompt_input,
schedule_to_close_timeout=timedelta(seconds=20),
)
workflow.logger.info(f"Ollama response: {response}")
self.conversation_history.append(("response", response))
# Call activity with one argument
tool_data = await workflow.execute_activity_method( tool_data = await workflow.execute_activity_method(
ToolActivities.parse_tool_data, ToolActivities.validate_and_parse_json,
response, args=[responsePrechecked, tools_data, self.format_history()],
schedule_to_close_timeout=timedelta(seconds=1), schedule_to_close_timeout=timedelta(seconds=40),
retry_policy=RetryPolicy(initial_interval=timedelta(seconds=10)),
) )
# store it
self.tool_data = tool_data self.tool_data = tool_data
self.conversation_history.append(("response", str(tool_data)))
if self.tool_data.get("next") == "confirm": if self.tool_data.get("next") == "confirm":
return self.tool_data dynamic_result = await workflow.execute_activity(
self.tool_data["tool"], # dynamic activity name
self.tool_data["args"], # single argument to pass
schedule_to_close_timeout=timedelta(seconds=20),
)
return dynamic_result
# Continue as new after X turns # Continue as new after X turns
if len(self.conversation_history) >= self.max_turns_before_continue: if len(self.conversation_history) >= self.max_turns_before_continue: