mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 14:08:08 +01:00
dynamic activity to call tool
This commit is contained in:
@@ -1,18 +1,14 @@
|
||||
from collections import deque
|
||||
from datetime import timedelta
|
||||
from typing import Deque, List, Optional, Tuple
|
||||
from temporalio.common import RetryPolicy
|
||||
|
||||
from temporalio import workflow
|
||||
from prompts.agent_prompt_generators import (
|
||||
generate_genai_prompt_from_tools_data,
|
||||
generate_json_validation_prompt_from_tools_data,
|
||||
)
|
||||
|
||||
with workflow.unsafe.imports_passed_through():
|
||||
from activities.tool_activities import ToolActivities, ToolPromptInput
|
||||
from prompts.agent_prompt_generators import (
|
||||
generate_genai_prompt_from_tools_data,
|
||||
generate_json_validation_prompt_from_tools_data,
|
||||
)
|
||||
from models.data_types import CombinedInput, ToolWorkflowParams
|
||||
|
||||
@@ -43,71 +39,47 @@ class ToolWorkflow:
|
||||
self.prompt_queue.extend(params.prompt_queue)
|
||||
|
||||
while True:
|
||||
workflow.logger.info("Waiting for prompts...")
|
||||
|
||||
await workflow.wait_condition(
|
||||
lambda: bool(self.prompt_queue) or self.chat_ended
|
||||
)
|
||||
|
||||
if self.prompt_queue:
|
||||
# Get user's prompt
|
||||
# 1) Get the user prompt -> call initial LLM
|
||||
prompt = self.prompt_queue.popleft()
|
||||
self.conversation_history.append(("user", prompt))
|
||||
|
||||
# Build prompt + context
|
||||
context_instructions = generate_genai_prompt_from_tools_data(
|
||||
tools_data, self.format_history()
|
||||
)
|
||||
workflow.logger.info("Prompt: " + prompt)
|
||||
|
||||
# Pass a single input object
|
||||
prompt_input = ToolPromptInput(
|
||||
prompt=prompt,
|
||||
context_instructions=context_instructions,
|
||||
)
|
||||
|
||||
# Call activity with one argument
|
||||
responsePrechecked = await workflow.execute_activity_method(
|
||||
ToolActivities.prompt_llm,
|
||||
prompt_input,
|
||||
schedule_to_close_timeout=timedelta(seconds=20),
|
||||
)
|
||||
|
||||
# Check if the response is valid JSON
|
||||
json_validation_instructions = (
|
||||
generate_json_validation_prompt_from_tools_data(
|
||||
tools_data, self.format_history(), responsePrechecked
|
||||
)
|
||||
)
|
||||
workflow.logger.info("Prompt: " + prompt)
|
||||
|
||||
# Pass a single input object
|
||||
prompt_input = ToolPromptInput(
|
||||
prompt=responsePrechecked,
|
||||
context_instructions=json_validation_instructions,
|
||||
)
|
||||
|
||||
# Call activity with one argument
|
||||
response = await workflow.execute_activity_method(
|
||||
ToolActivities.prompt_llm,
|
||||
prompt_input,
|
||||
schedule_to_close_timeout=timedelta(seconds=20),
|
||||
)
|
||||
|
||||
workflow.logger.info(f"Ollama response: {response}")
|
||||
self.conversation_history.append(("response", response))
|
||||
|
||||
# Call activity with one argument
|
||||
# 2) Validate + parse in one shot
|
||||
tool_data = await workflow.execute_activity_method(
|
||||
ToolActivities.parse_tool_data,
|
||||
response,
|
||||
schedule_to_close_timeout=timedelta(seconds=1),
|
||||
ToolActivities.validate_and_parse_json,
|
||||
args=[responsePrechecked, tools_data, self.format_history()],
|
||||
schedule_to_close_timeout=timedelta(seconds=40),
|
||||
retry_policy=RetryPolicy(initial_interval=timedelta(seconds=10)),
|
||||
)
|
||||
|
||||
# store it
|
||||
self.tool_data = tool_data
|
||||
self.conversation_history.append(("response", str(tool_data)))
|
||||
|
||||
if self.tool_data.get("next") == "confirm":
|
||||
return self.tool_data
|
||||
dynamic_result = await workflow.execute_activity(
|
||||
self.tool_data["tool"], # dynamic activity name
|
||||
self.tool_data["args"], # single argument to pass
|
||||
schedule_to_close_timeout=timedelta(seconds=20),
|
||||
)
|
||||
|
||||
return dynamic_result
|
||||
|
||||
# Continue as new after X turns
|
||||
if len(self.conversation_history) >= self.max_turns_before_continue:
|
||||
|
||||
Reference in New Issue
Block a user