Files
temporal-ai-agent/activities/tool_activities.py
Steve Androulakis 861e55a8d0 Mcp enhancements (#43)
* reuses MCP connections in each worker for efficiency

* you can see your food

* you can see your food

* prompt eng around images
2025-06-16 08:37:32 -07:00

513 lines
20 KiB
Python

import inspect
import json
import os
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Any, Dict, List, Optional, Sequence
from dotenv import load_dotenv
from litellm import completion
from temporalio import activity
from temporalio.common import RawValue
from temporalio.exceptions import ApplicationError
from models.data_types import (
EnvLookupInput,
EnvLookupOutput,
ToolPromptInput,
ValidationInput,
ValidationResult,
)
from models.tool_definitions import MCPServerDefinition
from shared.mcp_client_manager import MCPClientManager
# Import MCP client libraries
try:
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
except ImportError:
# Fallback if MCP not installed
ClientSession = None
StdioServerParameters = None
stdio_client = None
load_dotenv(override=True)
class ToolActivities:
    """Temporal activities for prompt validation, LLM planning, env lookup and MCP tools.

    The LLM settings are read from the ``LLM_MODEL``, ``LLM_KEY`` and
    ``LLM_BASE_URL`` environment variables when the instance is created.
    When an MCP client manager is supplied, ``mcp_tool_activity`` routes calls
    through it; otherwise it falls back to ``_execute_mcp_tool``.
    """

    def __init__(self, mcp_client_manager: Optional[MCPClientManager] = None):
        """Initialize LLM client using LiteLLM and optional MCP client manager"""
        self.llm_model = os.environ.get("LLM_MODEL", "openai/gpt-4")
        self.llm_key = os.environ.get("LLM_KEY")
        self.llm_base_url = os.environ.get("LLM_BASE_URL")
        self.mcp_client_manager = mcp_client_manager
        print(f"Initializing ToolActivities with LLM model: {self.llm_model}")
        if self.llm_base_url:
            print(f"Using custom base URL: {self.llm_base_url}")
        if self.mcp_client_manager:
            print("MCP client manager enabled for connection pooling")

    @activity.defn
    async def agent_validatePrompt(
        self, validation_input: ValidationInput
    ) -> ValidationResult:
        """
        Validates the prompt in the context of the conversation history and agent goal.
        Returns a ValidationResult indicating if the prompt makes sense given the context.

        The check itself is delegated to the LLM via ``agent_toolPlanner``; a
        missing ``validationResult`` key in the LLM's answer is treated as False.
        """
        # Create simple context string describing tools and goals
        tools_description = []
        for tool in validation_input.agent_goal.tools:
            tool_str = f"Tool: {tool.name}\n"
            tool_str += f"Description: {tool.description}\n"
            tool_str += "Arguments: " + ", ".join(
                [f"{arg.name} ({arg.type})" for arg in tool.arguments]
            )
            tools_description.append(tool_str)
        tools_str = "\n".join(tools_description)

        # Convert conversation history to string
        history_str = json.dumps(validation_input.conversation_history, indent=2)

        # Create context instructions
        context_instructions = f"""The agent goal and tools are as follows:
Description: {validation_input.agent_goal.description}
Available Tools:
{tools_str}
The conversation history to date is:
{history_str}"""

        # Create validation prompt
        validation_prompt = f"""The user's prompt is: "{validation_input.prompt}"
Please validate if this prompt makes sense given the agent goal and conversation history.
If the prompt makes sense toward the goal then validationResult should be true.
If the prompt is wildly nonsensical or makes no sense toward the goal and current conversation history then validationResult should be false.
If the response is low content such as "yes" or "that's right" then the user is probably responding to a previous prompt.
Therefore examine it in the context of the conversation history to determine if it makes sense and return true if it makes sense.
Return ONLY a JSON object with the following structure:
"validationResult": true/false,
"validationFailedReason": "If validationResult is false, provide a clear explanation to the user in the response field
about why their request doesn't make sense in the context and what information they should provide instead.
validationFailedReason should contain JSON in the format
{{
"next": "question",
"response": "[your reason here and a response to get the user back on track with the agent goal]"
}}
If validationResult is true (the prompt makes sense), return an empty dict as its value {{}}"
"""

        # Call the LLM with the validation prompt
        prompt_input = ToolPromptInput(
            prompt=validation_prompt, context_instructions=context_instructions
        )
        result = await self.agent_toolPlanner(prompt_input)
        return ValidationResult(
            validationResult=result.get("validationResult", False),
            validationFailedReason=result.get("validationFailedReason", {}),
        )

    @activity.defn
    async def agent_toolPlanner(self, input: ToolPromptInput) -> dict:
        """Send the prompt to the LLM and return its reply parsed as JSON.

        The system message is the supplied context instructions with the
        current date appended; the user message is the prompt itself.

        Raises:
            ApplicationError: if the LLM reply carries no content at all.
            json.JSONDecodeError: if the sanitized reply is not valid JSON.
            Exception: any error raised by the LLM call is printed and re-raised.
        """
        messages = [
            {
                "role": "system",
                "content": input.context_instructions
                + ". The current date is "
                + datetime.now().strftime("%B %d, %Y"),
            },
            {
                "role": "user",
                "content": input.prompt,
            },
        ]
        try:
            completion_kwargs = {
                "model": self.llm_model,
                "messages": messages,
                "api_key": self.llm_key,
            }
            # Add base_url if configured
            if self.llm_base_url:
                completion_kwargs["base_url"] = self.llm_base_url

            response = completion(**completion_kwargs)
            response_content = response.choices[0].message.content
            activity.logger.info(f"Raw LLM response: {repr(response_content)}")
            activity.logger.info(f"LLM response content: {response_content}")
            activity.logger.info(f"LLM response type: {type(response_content)}")
            activity.logger.info(
                f"LLM response length: {len(response_content) if response_content else 'None'}"
            )

            # A reply without content used to surface as an opaque
            # AttributeError from str.replace; fail with a clear message instead.
            if response_content is None:
                raise ApplicationError(
                    "LLM returned no message content; expected a JSON object"
                )

            # Use the new sanitize function
            response_content = self.sanitize_json_response(response_content)
            activity.logger.info(f"Sanitized response: {repr(response_content)}")
            return self.parse_json_response(response_content)
        except Exception as e:
            print(f"Error in LLM completion: {str(e)}")
            raise

    def parse_json_response(self, response_content: str) -> dict:
        """
        Parses the JSON response content and returns it as a dictionary.

        Raises:
            json.JSONDecodeError: if the content is not valid JSON (the error
                is printed before being re-raised).
        """
        try:
            return json.loads(response_content)
        except json.JSONDecodeError as e:
            print(f"Invalid JSON: {e}")
            raise

    def sanitize_json_response(self, response_content: Optional[str]) -> str:
        """
        Sanitizes the response content to ensure it's valid JSON.

        Strips markdown code-fence markers and surrounding whitespace. ``None``
        is treated as an empty response and yields an empty string.
        """
        if response_content is None:
            return ""
        # Remove any markdown code block markers
        response_content = response_content.replace("```json", "").replace("```", "")
        # Remove any leading/trailing whitespace
        return response_content.strip()

    @activity.defn
    async def get_wf_env_vars(self, input: EnvLookupInput) -> EnvLookupOutput:
        """gets env vars for workflow as an activity result so it's deterministic
        handles default/None

        ``show_confirm`` keeps the caller's default when the named variable is
        unset, is False when it equals "false" (case-insensitive) and True for
        any other value. ``multi_goal_mode`` is True only when ``AGENT_GOAL``
        equals "goal_choose_agent_type" (case-insensitive).
        """
        output: EnvLookupOutput = EnvLookupOutput(
            show_confirm=input.show_confirm_default, multi_goal_mode=False
        )

        show_confirm_value = os.getenv(input.show_confirm_env_var_name)
        if show_confirm_value is not None:
            output.show_confirm = show_confirm_value.lower() != "false"

        first_goal_value = os.getenv("AGENT_GOAL")
        # Default to single agent mode if unset or set to anything else
        output.multi_goal_mode = (
            first_goal_value is not None
            and first_goal_value.lower() == "goal_choose_agent_type"
        )
        return output

    @activity.defn
    async def mcp_tool_activity(
        self, tool_name: str, tool_args: Dict[str, Any]
    ) -> Dict[str, Any]:
        """MCP Tool - now using pooled connections

        ``server_definition`` is removed from ``tool_args`` (the dict is
        mutated) and used to select the server; the remaining entries are the
        tool's arguments.
        """
        activity.logger.info(f"Executing MCP tool: {tool_name} with args: {tool_args}")

        # Extract server definition
        server_definition = tool_args.pop("server_definition", None)

        if self.mcp_client_manager:
            # Use pooled connection
            return await self._execute_mcp_tool_pooled(
                tool_name, tool_args, server_definition
            )
        # Fallback to original implementation
        return await _execute_mcp_tool(tool_name, tool_args, server_definition)

    async def _execute_mcp_tool_pooled(
        self,
        tool_name: str,
        tool_args: Dict[str, Any],
        server_definition: MCPServerDefinition | Dict[str, Any] | None,
    ) -> Dict[str, Any]:
        """Execute MCP tool using pooled client connection

        Returns a dict with ``tool`` and ``success``; on success it also holds
        ``content``, on failure ``error`` and ``error_type``. Failures are
        logged and reported in the result rather than raised.
        """
        activity.logger.info(f"Executing MCP tool with pooled connection: {tool_name}")

        # Convert argument types for MCP tools
        converted_args = _convert_args_types(tool_args)

        try:
            # Get pooled client
            client = await self.mcp_client_manager.get_client(server_definition)
            # Call the tool using existing client session
            result = await client.call_tool(tool_name, arguments=converted_args)
            normalized_result = _normalize_result(result)
            return {
                "tool": tool_name,
                "success": True,
                "content": normalized_result,
            }
        except Exception as e:
            activity.logger.error(f"MCP tool {tool_name} failed: {str(e)}")
            return {
                "tool": tool_name,
                "success": False,
                "error": str(e),
                "error_type": type(e).__name__,
            }
@activity.defn(dynamic=True)
async def dynamic_tool_activity(args: Sequence[RawValue]) -> dict:
    """Run whichever tool matches the current activity type.

    The first raw argument is decoded into a dict of tool arguments. If that
    dict carries a truthy ``server_definition`` the call is forwarded to the
    MCP executor; otherwise the handler registered for the tool name is
    invoked (awaited when it is a coroutine function) and its result returned.
    """
    from tools import get_handler

    tool_name = activity.info().activity_type  # e.g. "FindEvents"
    tool_args = activity.payload_converter().from_payload(args[0].payload, dict)
    activity.logger.info(f"Running dynamic tool '{tool_name}' with args: {tool_args}")

    # A server_definition entry marks the call as an MCP tool call
    server_definition = tool_args.pop("server_definition", None)
    if server_definition:
        activity.logger.info(f"Executing MCP tool: {tool_name}")
        return await _execute_mcp_tool(tool_name, tool_args, server_definition)

    # Regular tool: look up its handler and run it
    handler = get_handler(tool_name)
    outcome = handler(tool_args)
    if inspect.iscoroutinefunction(handler):
        outcome = await outcome

    activity.logger.info(f"Tool '{tool_name}' result: {outcome}")
    return outcome
# MCP Client Activities
def _build_connection(
server_definition: MCPServerDefinition | Dict[str, Any] | None,
) -> Dict[str, Any]:
"""Build connection parameters from MCPServerDefinition or dict"""
if server_definition is None:
# Default to stdio connection with the main server
return {"type": "stdio", "command": "python", "args": ["server.py"], "env": {}}
# Handle both MCPServerDefinition objects and dicts (from Temporal serialization)
if isinstance(server_definition, dict):
return {
"type": server_definition.get("connection_type", "stdio"),
"command": server_definition.get("command", "python"),
"args": server_definition.get("args", ["server.py"]),
"env": server_definition.get("env", {}) or {},
}
return {
"type": server_definition.connection_type,
"command": server_definition.command,
"args": server_definition.args,
"env": server_definition.env or {},
}
def _normalize_result(result: Any) -> Any:
"""Normalize MCP tool result for serialization"""
if hasattr(result, "content"):
# Handle MCP result objects
if hasattr(result.content, "__iter__") and not isinstance(result.content, str):
return [
item.text if hasattr(item, "text") else str(item)
for item in result.content
]
return str(result.content)
return result
def _convert_args_types(tool_args: Dict[str, Any]) -> Dict[str, Any]:
"""Convert string arguments to appropriate types for MCP tools"""
converted_args = {}
for key, value in tool_args.items():
if key == "server_definition":
# Skip server_definition - it's metadata
continue
if isinstance(value, str):
# Try to convert string values to appropriate types
if value.isdigit():
# Convert numeric strings to integers
converted_args[key] = int(value)
elif value.replace(".", "").isdigit() and value.count(".") == 1:
# Convert decimal strings to floats
converted_args[key] = float(value)
elif value.lower() in ("true", "false"):
# Convert boolean strings
converted_args[key] = value.lower() == "true"
else:
# Keep as string
converted_args[key] = value
else:
# Keep non-string values as-is
converted_args[key] = value
return converted_args
async def _execute_mcp_tool(
    tool_name: str,
    tool_args: Dict[str, Any],
    server_definition: MCPServerDefinition | Dict[str, Any] | None,
) -> Dict[str, Any]:
    """Run one MCP tool over a fresh connection and report the outcome as a dict.

    Only stdio connections are supported. Every failure (including an
    unsupported connection type) is logged and returned as
    ``{"tool", "success": False, "error", "error_type"}`` rather than raised;
    success returns ``{"tool", "success": True, "content"}``.
    """
    activity.logger.info(f"Executing MCP tool: {tool_name}")

    # Convert argument types for MCP tools
    converted_args = _convert_args_types(tool_args)
    connection = _build_connection(server_definition)

    try:
        # Reject connection types that cannot be served before opening anything
        if connection["type"] == "tcp":
            # Placeholder for future implementation
            raise ApplicationError("TCP connections not yet implemented")
        if connection["type"] != "stdio":
            raise ApplicationError(f"Unsupported connection type: {connection['type']}")

        stdio_streams = _stdio_connection(
            command=connection.get("command", "python"),
            args=connection.get("args", ["server.py"]),
            env=connection.get("env", {}),
        )
        async with stdio_streams as (read, write), ClientSession(
            read, write
        ) as session:
            activity.logger.info(f"Initializing MCP session for {tool_name}")
            await session.initialize()
            activity.logger.info(f"MCP session initialized for {tool_name}")

            activity.logger.info(
                f"Calling MCP tool {tool_name} with args: {converted_args}"
            )
            try:
                raw_result = await session.call_tool(
                    tool_name, arguments=converted_args
                )
                activity.logger.info(
                    f"MCP tool {tool_name} returned result: {raw_result}"
                )
            except Exception as tool_exc:
                activity.logger.error(
                    f"MCP tool {tool_name} call failed: {type(tool_exc).__name__}: {tool_exc}"
                )
                raise

            content = _normalize_result(raw_result)
            activity.logger.info(f"MCP tool {tool_name} completed successfully")
            return {
                "tool": tool_name,
                "success": True,
                "content": content,
            }
    except Exception as exc:
        activity.logger.error(f"MCP tool {tool_name} failed: {str(exc)}")
        # Return error information
        return {
            "tool": tool_name,
            "success": False,
            "error": str(exc),
            "error_type": type(exc).__name__,
        }
@asynccontextmanager
async def _stdio_connection(command: str, args: list, env: dict):
    """Open a stdio connection to an MCP server and yield its (read, write) streams.

    Raises:
        ApplicationError: if the MCP client libraries could not be imported.
    """
    if stdio_client is None:
        raise ApplicationError("MCP client libraries not available")

    launch_params = StdioServerParameters(command=command, args=args, env=env)
    async with stdio_client(launch_params) as streams:
        read, write = streams
        yield read, write
@activity.defn
async def mcp_list_tools(
    server_definition: MCPServerDefinition, include_tools: Optional[List[str]] = None
) -> Dict[str, Any]:
    """Ask an MCP server which tools it offers, optionally filtered by name.

    Only stdio connections are supported. On success the result holds
    ``server_name``, ``success: True``, ``tools`` (keyed by tool name),
    ``total_available`` and ``filtered_count``. Any failure is logged and
    returned as ``success: False`` with ``error`` and ``error_type``.
    """
    server_name = server_definition.name
    activity.logger.info(f"Listing MCP tools for server: {server_name}")
    connection = _build_connection(server_definition)

    try:
        if connection["type"] == "tcp":
            raise ApplicationError("TCP connections not yet implemented")
        if connection["type"] != "stdio":
            raise ApplicationError(f"Unsupported connection type: {connection['type']}")

        stdio_streams = _stdio_connection(
            command=connection.get("command", "python"),
            args=connection.get("args", ["server.py"]),
            env=connection.get("env", {}),
        )
        async with stdio_streams as (read, write), ClientSession(
            read, write
        ) as session:
            await session.initialize()
            listing = await session.list_tools()

            # Keep every tool when no filter is given, else only the named ones
            selected = {
                tool.name: {
                    "name": tool.name,
                    "description": tool.description,
                    "inputSchema": (
                        tool.inputSchema.model_dump()
                        if hasattr(tool.inputSchema, "model_dump")
                        else str(tool.inputSchema)
                    ),
                }
                for tool in listing.tools
                if include_tools is None or tool.name in include_tools
            }

            activity.logger.info(
                f"Found {len(selected)} tools for server {server_definition.name}"
            )
            return {
                "server_name": server_definition.name,
                "success": True,
                "tools": selected,
                "total_available": len(listing.tools),
                "filtered_count": len(selected),
            }
    except Exception as exc:
        activity.logger.error(
            f"Failed to list tools for server {server_definition.name}: {str(exc)}"
        )
        return {
            "server_name": server_definition.name,
            "success": False,
            "error": str(exc),
            "error_type": type(exc).__name__,
        }