mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 14:08:08 +01:00
dynamic activity to call tool
This commit is contained in:
@@ -3,6 +3,9 @@ from temporalio import activity
|
|||||||
from temporalio.exceptions import ApplicationError
|
from temporalio.exceptions import ApplicationError
|
||||||
from ollama import chat, ChatResponse
|
from ollama import chat, ChatResponse
|
||||||
import json
|
import json
|
||||||
|
from models.tool_definitions import ToolsData
|
||||||
|
from typing import Sequence
|
||||||
|
from temporalio.common import RawValue
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -44,6 +47,47 @@ class ToolActivities:
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@activity.defn
|
||||||
|
def validate_and_parse_json(
|
||||||
|
self,
|
||||||
|
response_prechecked: str,
|
||||||
|
tools_data: ToolsData,
|
||||||
|
conversation_history: str,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
1) Build JSON validation instructions
|
||||||
|
2) Call LLM with those instructions
|
||||||
|
3) Parse the result
|
||||||
|
4) If parsing fails, raise exception -> triggers retry
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 1) Build validation instructions
|
||||||
|
# (Generate the validation prompt exactly as you do in your workflow.)
|
||||||
|
from prompts.agent_prompt_generators import (
|
||||||
|
generate_json_validation_prompt_from_tools_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
validation_prompt = generate_json_validation_prompt_from_tools_data(
|
||||||
|
tools_data, conversation_history, response_prechecked
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2) Call LLM
|
||||||
|
prompt_input = ToolPromptInput(
|
||||||
|
prompt=response_prechecked,
|
||||||
|
context_instructions=validation_prompt,
|
||||||
|
)
|
||||||
|
validated_response = self.prompt_llm(prompt_input)
|
||||||
|
|
||||||
|
# 3) Parse
|
||||||
|
# If parse fails, we raise ApplicationError -> triggers retry
|
||||||
|
try:
|
||||||
|
parsed = self.parse_tool_data(validated_response)
|
||||||
|
except Exception as e:
|
||||||
|
raise ApplicationError(f"Failed to parse validated JSON: {e}")
|
||||||
|
|
||||||
|
# 4) If we get here, parse succeeded
|
||||||
|
return parsed
|
||||||
|
|
||||||
|
|
||||||
def get_current_date_human_readable():
|
def get_current_date_human_readable():
|
||||||
"""
|
"""
|
||||||
@@ -54,3 +98,30 @@ def get_current_date_human_readable():
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
return datetime.now().strftime("%A, %B %d, %Y")
|
return datetime.now().strftime("%A, %B %d, %Y")
|
||||||
|
|
||||||
|
|
||||||
|
@activity.defn(dynamic=True)
|
||||||
|
def dynamic_tool_activity(args: Sequence[RawValue]) -> dict:
|
||||||
|
"""Dynamic activity that is invoked via an unknown activity type."""
|
||||||
|
tool_name = activity.info().activity_type # e.g. "SearchFlights"
|
||||||
|
|
||||||
|
# The first payload is the dictionary of arguments
|
||||||
|
tool_args = activity.payload_converter().from_payload(args[0].payload, dict)
|
||||||
|
|
||||||
|
# Extract fields from the arguments
|
||||||
|
date_depart = tool_args.get("dateDepart")
|
||||||
|
date_return = tool_args.get("dateReturn")
|
||||||
|
origin = tool_args.get("origin")
|
||||||
|
destination = tool_args.get("destination")
|
||||||
|
|
||||||
|
# Print (or log) them
|
||||||
|
activity.logger.info(f"Tool: {tool_name}")
|
||||||
|
activity.logger.info(f"Depart: {date_depart}, Return: {date_return}")
|
||||||
|
activity.logger.info(f"Origin: {origin}, Destination: {destination}")
|
||||||
|
|
||||||
|
# For now, just return them
|
||||||
|
return {
|
||||||
|
"tool": tool_name,
|
||||||
|
"args": tool_args,
|
||||||
|
"status": "OK - dynamic activity stub",
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import logging
|
|||||||
from temporalio.client import Client
|
from temporalio.client import Client
|
||||||
from temporalio.worker import Worker
|
from temporalio.worker import Worker
|
||||||
|
|
||||||
from activities.tool_activities import ToolActivities
|
from activities.tool_activities import ToolActivities, dynamic_tool_activity
|
||||||
from workflows.tool_workflow import ToolWorkflow
|
from workflows.tool_workflow import ToolWorkflow
|
||||||
from workflows.parent_workflow import ParentWorkflow
|
from workflows.parent_workflow import ParentWorkflow
|
||||||
|
|
||||||
@@ -21,7 +21,12 @@ async def main():
|
|||||||
client,
|
client,
|
||||||
task_queue="ollama-task-queue",
|
task_queue="ollama-task-queue",
|
||||||
workflows=[ToolWorkflow, ParentWorkflow],
|
workflows=[ToolWorkflow, ParentWorkflow],
|
||||||
activities=[activities.prompt_llm, activities.parse_tool_data],
|
activities=[
|
||||||
|
activities.prompt_llm,
|
||||||
|
activities.parse_tool_data,
|
||||||
|
activities.validate_and_parse_json,
|
||||||
|
dynamic_tool_activity,
|
||||||
|
],
|
||||||
activity_executor=activity_executor,
|
activity_executor=activity_executor,
|
||||||
)
|
)
|
||||||
await worker.run()
|
await worker.run()
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ async def main(prompt):
|
|||||||
# Construct your tool definitions in code
|
# Construct your tool definitions in code
|
||||||
search_flights_tool = ToolDefinition(
|
search_flights_tool = ToolDefinition(
|
||||||
name="SearchFlights",
|
name="SearchFlights",
|
||||||
description="Search for return flights from an origin to a destination within a date range",
|
description="Search for return flights from an origin to a destination within a date range (dateDepart, dateReturn)",
|
||||||
arguments=[
|
arguments=[
|
||||||
ToolArgument(
|
ToolArgument(
|
||||||
name="origin",
|
name="origin",
|
||||||
@@ -27,12 +27,12 @@ async def main(prompt):
|
|||||||
ToolArgument(
|
ToolArgument(
|
||||||
name="dateDepart",
|
name="dateDepart",
|
||||||
type="ISO8601",
|
type="ISO8601",
|
||||||
description="Start of date range in human readable format",
|
description="Start of date range in human readable format, when you want to depart",
|
||||||
),
|
),
|
||||||
ToolArgument(
|
ToolArgument(
|
||||||
name="dateReturn",
|
name="dateReturn",
|
||||||
type="ISO8601",
|
type="ISO8601",
|
||||||
description="End of date range in human readable format",
|
description="End of date range in human readable format, when you want to return",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,18 +1,14 @@
|
|||||||
from collections import deque
|
from collections import deque
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Deque, List, Optional, Tuple
|
from typing import Deque, List, Optional, Tuple
|
||||||
|
from temporalio.common import RetryPolicy
|
||||||
|
|
||||||
from temporalio import workflow
|
from temporalio import workflow
|
||||||
from prompts.agent_prompt_generators import (
|
|
||||||
generate_genai_prompt_from_tools_data,
|
|
||||||
generate_json_validation_prompt_from_tools_data,
|
|
||||||
)
|
|
||||||
|
|
||||||
with workflow.unsafe.imports_passed_through():
|
with workflow.unsafe.imports_passed_through():
|
||||||
from activities.tool_activities import ToolActivities, ToolPromptInput
|
from activities.tool_activities import ToolActivities, ToolPromptInput
|
||||||
from prompts.agent_prompt_generators import (
|
from prompts.agent_prompt_generators import (
|
||||||
generate_genai_prompt_from_tools_data,
|
generate_genai_prompt_from_tools_data,
|
||||||
generate_json_validation_prompt_from_tools_data,
|
|
||||||
)
|
)
|
||||||
from models.data_types import CombinedInput, ToolWorkflowParams
|
from models.data_types import CombinedInput, ToolWorkflowParams
|
||||||
|
|
||||||
@@ -43,71 +39,47 @@ class ToolWorkflow:
|
|||||||
self.prompt_queue.extend(params.prompt_queue)
|
self.prompt_queue.extend(params.prompt_queue)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
workflow.logger.info("Waiting for prompts...")
|
|
||||||
|
|
||||||
await workflow.wait_condition(
|
await workflow.wait_condition(
|
||||||
lambda: bool(self.prompt_queue) or self.chat_ended
|
lambda: bool(self.prompt_queue) or self.chat_ended
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.prompt_queue:
|
if self.prompt_queue:
|
||||||
# Get user's prompt
|
# 1) Get the user prompt -> call initial LLM
|
||||||
prompt = self.prompt_queue.popleft()
|
prompt = self.prompt_queue.popleft()
|
||||||
self.conversation_history.append(("user", prompt))
|
self.conversation_history.append(("user", prompt))
|
||||||
|
|
||||||
# Build prompt + context
|
|
||||||
context_instructions = generate_genai_prompt_from_tools_data(
|
context_instructions = generate_genai_prompt_from_tools_data(
|
||||||
tools_data, self.format_history()
|
tools_data, self.format_history()
|
||||||
)
|
)
|
||||||
workflow.logger.info("Prompt: " + prompt)
|
|
||||||
|
|
||||||
# Pass a single input object
|
|
||||||
prompt_input = ToolPromptInput(
|
prompt_input = ToolPromptInput(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
context_instructions=context_instructions,
|
context_instructions=context_instructions,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call activity with one argument
|
|
||||||
responsePrechecked = await workflow.execute_activity_method(
|
responsePrechecked = await workflow.execute_activity_method(
|
||||||
ToolActivities.prompt_llm,
|
ToolActivities.prompt_llm,
|
||||||
prompt_input,
|
prompt_input,
|
||||||
schedule_to_close_timeout=timedelta(seconds=20),
|
schedule_to_close_timeout=timedelta(seconds=20),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if the response is valid JSON
|
# 2) Validate + parse in one shot
|
||||||
json_validation_instructions = (
|
tool_data = await workflow.execute_activity_method(
|
||||||
generate_json_validation_prompt_from_tools_data(
|
ToolActivities.validate_and_parse_json,
|
||||||
tools_data, self.format_history(), responsePrechecked
|
args=[responsePrechecked, tools_data, self.format_history()],
|
||||||
)
|
schedule_to_close_timeout=timedelta(seconds=40),
|
||||||
)
|
retry_policy=RetryPolicy(initial_interval=timedelta(seconds=10)),
|
||||||
workflow.logger.info("Prompt: " + prompt)
|
|
||||||
|
|
||||||
# Pass a single input object
|
|
||||||
prompt_input = ToolPromptInput(
|
|
||||||
prompt=responsePrechecked,
|
|
||||||
context_instructions=json_validation_instructions,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Call activity with one argument
|
# store it
|
||||||
response = await workflow.execute_activity_method(
|
self.tool_data = tool_data
|
||||||
ToolActivities.prompt_llm,
|
self.conversation_history.append(("response", str(tool_data)))
|
||||||
prompt_input,
|
|
||||||
|
if self.tool_data.get("next") == "confirm":
|
||||||
|
dynamic_result = await workflow.execute_activity(
|
||||||
|
self.tool_data["tool"], # dynamic activity name
|
||||||
|
self.tool_data["args"], # single argument to pass
|
||||||
schedule_to_close_timeout=timedelta(seconds=20),
|
schedule_to_close_timeout=timedelta(seconds=20),
|
||||||
)
|
)
|
||||||
|
|
||||||
workflow.logger.info(f"Ollama response: {response}")
|
return dynamic_result
|
||||||
self.conversation_history.append(("response", response))
|
|
||||||
|
|
||||||
# Call activity with one argument
|
|
||||||
tool_data = await workflow.execute_activity_method(
|
|
||||||
ToolActivities.parse_tool_data,
|
|
||||||
response,
|
|
||||||
schedule_to_close_timeout=timedelta(seconds=1),
|
|
||||||
)
|
|
||||||
|
|
||||||
self.tool_data = tool_data
|
|
||||||
|
|
||||||
if self.tool_data.get("next") == "confirm":
|
|
||||||
return self.tool_data
|
|
||||||
|
|
||||||
# Continue as new after X turns
|
# Continue as new after X turns
|
||||||
if len(self.conversation_history) >= self.max_turns_before_continue:
|
if len(self.conversation_history) >= self.max_turns_before_continue:
|
||||||
|
|||||||
Reference in New Issue
Block a user