Files
temporal-ai-agent/activities/tool_activities.py
Steve Androulakis eb06cf5c8d Enhance Dev Experience and Code Quality (#41)
* Format codebase to satisfy linters

* fixing pylance and ruff-checked files

* contributing md, and type and formatting fixes

* setup file capitalization

* test fix
2025-06-01 08:54:59 -07:00

204 lines
7.7 KiB
Python

import asyncio
import inspect
import json
import os
from datetime import datetime
from typing import Sequence

from dotenv import load_dotenv
from litellm import completion
from temporalio import activity
from temporalio.common import RawValue

from models.data_types import (
    EnvLookupInput,
    EnvLookupOutput,
    ToolPromptInput,
    ValidationInput,
    ValidationResult,
)
# Load variables from a local .env file at import time; override=True makes
# values from the .env file replace variables already set in the process
# environment (e.g. LLM_MODEL / LLM_KEY / LLM_BASE_URL read by ToolActivities).
load_dotenv(override=True)
class ToolActivities:
    """Temporal activities that query the configured LLM and read env settings.

    LLM access goes through LiteLLM and is configured from the ``LLM_MODEL``
    (default ``"openai/gpt-4"``), ``LLM_KEY`` and ``LLM_BASE_URL`` environment
    variables, read once when the instance is created.
    """

    def __init__(self):
        """Initialize LLM client using LiteLLM."""
        self.llm_model = os.environ.get("LLM_MODEL", "openai/gpt-4")
        self.llm_key = os.environ.get("LLM_KEY")
        self.llm_base_url = os.environ.get("LLM_BASE_URL")
        print(f"Initializing ToolActivities with LLM model: {self.llm_model}")
        if self.llm_base_url:
            print(f"Using custom base URL: {self.llm_base_url}")

    @activity.defn
    async def agent_validatePrompt(
        self, validation_input: ValidationInput
    ) -> ValidationResult:
        """Ask the LLM whether the user's prompt makes sense for the agent goal.

        Builds a text description of the goal's tools plus the JSON-encoded
        conversation history, sends it with a validation prompt through
        ``agent_toolPlanner`` and maps the parsed JSON reply onto a
        ``ValidationResult``.

        Returns:
            ValidationResult whose ``validationResult`` defaults to ``False``
            and ``validationFailedReason`` to ``{}`` when the LLM omits them.
        """
        # One text entry per tool: name, description and "arg (type)" list.
        tools_description = []
        for tool in validation_input.agent_goal.tools:
            tool_str = f"Tool: {tool.name}\n"
            tool_str += f"Description: {tool.description}\n"
            tool_str += "Arguments: " + ", ".join(
                [f"{arg.name} ({arg.type})" for arg in tool.arguments]
            )
            tools_description.append(tool_str)
        tools_str = "\n".join(tools_description)

        history_str = json.dumps(validation_input.conversation_history, indent=2)

        # The prompt continuation lines below are deliberately flush-left so
        # that no source indentation leaks into the text sent to the LLM.
        context_instructions = f"""The agent goal and tools are as follows:
Description: {validation_input.agent_goal.description}
Available Tools:
{tools_str}
The conversation history to date is:
{history_str}"""

        validation_prompt = f"""The user's prompt is: "{validation_input.prompt}"
Please validate if this prompt makes sense given the agent goal and conversation history.
If the prompt makes sense toward the goal then validationResult should be true.
If the prompt is wildly nonsensical or makes no sense toward the goal and current conversation history then validationResult should be false.
If the response is low content such as "yes" or "that's right" then the user is probably responding to a previous prompt.
Therefore examine it in the context of the conversation history to determine if it makes sense and return true if it makes sense.
Return ONLY a JSON object with the following structure:
"validationResult": true/false,
"validationFailedReason": "If validationResult is false, provide a clear explanation to the user in the response field
about why their request doesn't make sense in the context and what information they should provide instead.
validationFailedReason should contain JSON in the format
{{
"next": "question",
"response": "[your reason here and a response to get the user back on track with the agent goal]"
}}
If validationResult is true (the prompt makes sense), return an empty dict as its value {{}}"
"""

        prompt_input = ToolPromptInput(
            prompt=validation_prompt, context_instructions=context_instructions
        )
        result = await self.agent_toolPlanner(prompt_input)
        return ValidationResult(
            validationResult=result.get("validationResult", False),
            validationFailedReason=result.get("validationFailedReason", {}),
        )

    @activity.defn
    async def agent_toolPlanner(self, input: ToolPromptInput) -> dict:
        """Send a prompt to the LLM and return its reply parsed as JSON.

        The system message is ``input.context_instructions`` followed by the
        current date; the user message is ``input.prompt``. Markdown code
        fences around the reply are stripped before parsing.

        Raises:
            json.JSONDecodeError: if the (sanitized) reply is not valid JSON.
            Exception: any error from the LiteLLM call; it is logged and
                re-raised unchanged.
        """
        messages = [
            {
                "role": "system",
                "content": input.context_instructions
                + ". The current date is "
                + datetime.now().strftime("%B %d, %Y"),
            },
            {
                "role": "user",
                "content": input.prompt,
            },
        ]
        try:
            completion_kwargs = {
                "model": self.llm_model,
                "messages": messages,
                "api_key": self.llm_key,
            }
            # Only pass base_url when a custom endpoint is configured.
            if self.llm_base_url:
                completion_kwargs["base_url"] = self.llm_base_url
            # litellm.completion is a blocking network call. Running it in a
            # worker thread keeps this async activity from stalling the
            # worker's event loop for the whole LLM round-trip.
            response = await asyncio.to_thread(completion, **completion_kwargs)
            response_content = response.choices[0].message.content
            activity.logger.info(f"LLM response: {response_content}")
            response_content = self.sanitize_json_response(response_content)
            return self.parse_json_response(response_content)
        except Exception as e:
            activity.logger.error(f"Error in LLM completion: {e}")
            raise

    def parse_json_response(self, response_content: str) -> dict:
        """Parse ``response_content`` as JSON and return the decoded value.

        Raises:
            json.JSONDecodeError: if the text is not valid JSON (the error is
                printed before being re-raised).
        """
        try:
            data = json.loads(response_content)
            return data
        except json.JSONDecodeError as e:
            print(f"Invalid JSON: {e}")
            raise

    def sanitize_json_response(self, response_content: str) -> str:
        """Strip whitespace and a surrounding Markdown code fence from a reply.

        Only a fence that opens the reply (optionally tagged ``json`` in any
        letter case) and a fence that closes it are removed, so triple
        backticks that occur inside JSON string values are left intact.
        ``None`` (a reply with no text content) is treated as ``""``.
        """
        text = (response_content or "").strip()
        if text.startswith("```"):
            text = text[3:]
            # Drop the opening fence's optional language tag.
            if text[:4].lower() == "json":
                text = text[4:]
        if text.endswith("```"):
            text = text[:-3]
        return text.strip()

    @activity.defn
    async def get_wf_env_vars(self, input: EnvLookupInput) -> EnvLookupOutput:
        """Read env-driven workflow settings as an activity result.

        Workflows obtain these values through an activity so that the
        workflow code itself stays deterministic.

        - ``show_confirm``: ``input.show_confirm_default`` when the variable
          named by ``input.show_confirm_env_var_name`` is unset, ``False`` when
          its value is "false" (case-insensitive), otherwise ``True``.
        - ``multi_goal_mode``: ``True`` when ``AGENT_GOAL`` is unset or equals
          "goal_choose_agent_type" (case-insensitive), otherwise ``False``.
        """
        show_confirm_value = os.getenv(input.show_confirm_env_var_name)
        if show_confirm_value is None:
            show_confirm = input.show_confirm_default
        else:
            show_confirm = show_confirm_value.lower() != "false"

        first_goal_value = os.getenv("AGENT_GOAL")
        multi_goal_mode = (
            first_goal_value is None
            or first_goal_value.lower() == "goal_choose_agent_type"
        )

        return EnvLookupOutput(
            show_confirm=show_confirm, multi_goal_mode=multi_goal_mode
        )
@activity.defn(dynamic=True)
async def dynamic_tool_activity(args: Sequence[RawValue]) -> dict:
    """Execute the tool named by the invoked (dynamic) activity type.

    The first raw argument is decoded into a ``dict`` of tool arguments and
    passed to the handler that ``tools.get_handler`` returns for the tool
    name. Handlers that are coroutine functions are awaited; plain callables
    are called directly. The handler's result is logged and returned.
    """
    # Imported inside the function, so `tools` is only loaded when a tool
    # activity actually runs (presumably to avoid an import cycle — verify).
    from tools import get_handler

    # For a dynamic activity the activity type is the requested tool name,
    # e.g. "FindEvents".
    tool_name = activity.info().activity_type
    converter = activity.payload_converter()
    tool_args = converter.from_payload(args[0].payload, dict)
    activity.logger.info(f"Running dynamic tool '{tool_name}' with args: {tool_args}")

    handler = get_handler(tool_name)
    is_async_handler = inspect.iscoroutinefunction(handler)
    outcome = handler(tool_args)
    if is_async_handler:
        outcome = await outcome

    activity.logger.info(f"Tool '{tool_name}' result: {outcome}")
    return outcome