# mirror of https://github.com/temporal-community/temporal-ai-agent.git
# synced 2026-03-15 05:58:08 +01:00
# * reuses MCP connections in each worker for efficiency * you can see your food * prompt eng around images
# 513 lines, 20 KiB, Python
import inspect
|
|
import json
|
|
import os
|
|
from contextlib import asynccontextmanager
|
|
from datetime import datetime
|
|
from typing import Any, Dict, List, Optional, Sequence
|
|
|
|
from dotenv import load_dotenv
|
|
from litellm import completion
|
|
from temporalio import activity
|
|
from temporalio.common import RawValue
|
|
from temporalio.exceptions import ApplicationError
|
|
|
|
from models.data_types import (
|
|
EnvLookupInput,
|
|
EnvLookupOutput,
|
|
ToolPromptInput,
|
|
ValidationInput,
|
|
ValidationResult,
|
|
)
|
|
from models.tool_definitions import MCPServerDefinition
|
|
from shared.mcp_client_manager import MCPClientManager
|
|
|
|
# Import MCP client libraries
# The import is optional: when the `mcp` package is missing, the three names
# are bound to None instead, and _stdio_connection checks `stdio_client is None`
# before trying to open a connection.
try:
    from mcp import ClientSession, StdioServerParameters
    from mcp.client.stdio import stdio_client
except ImportError:
    # Fallback if MCP not installed
    ClientSession = None
    StdioServerParameters = None
    stdio_client = None

# Runs at import time, before ToolActivities.__init__ reads os.environ.
# NOTE(review): override=True presumably lets .env values replace variables
# that are already set in the process environment - confirm against python-dotenv.
load_dotenv(override=True)
|
|
|
|
|
|
class ToolActivities:
    """Temporal activities for LLM planning, prompt validation and MCP tool calls.

    LLM settings are read from the ``LLM_MODEL``, ``LLM_KEY`` and
    ``LLM_BASE_URL`` environment variables when an instance is created.
    """

    def __init__(self, mcp_client_manager: Optional[MCPClientManager] = None):
        """Initialize LLM client using LiteLLM and optional MCP client manager.

        Args:
            mcp_client_manager: Manager used to obtain pooled MCP clients.
                When it is falsy, ``mcp_tool_activity`` falls back to
                ``_execute_mcp_tool``, which opens a connection per call.
        """
        self.llm_model = os.environ.get("LLM_MODEL", "openai/gpt-4")
        self.llm_key = os.environ.get("LLM_KEY")
        self.llm_base_url = os.environ.get("LLM_BASE_URL")
        self.mcp_client_manager = mcp_client_manager
        print(f"Initializing ToolActivities with LLM model: {self.llm_model}")
        if self.llm_base_url:
            print(f"Using custom base URL: {self.llm_base_url}")
        if self.mcp_client_manager:
            print("MCP client manager enabled for connection pooling")

    @activity.defn
    async def agent_validatePrompt(
        self, validation_input: ValidationInput
    ) -> ValidationResult:
        """
        Validates the prompt in the context of the conversation history and agent goal.

        Builds a description of the goal's tools plus the JSON-dumped
        conversation history, asks the LLM (through ``agent_toolPlanner``)
        whether the prompt makes sense, and wraps the answer.

        Args:
            validation_input: Carries ``prompt``, ``conversation_history`` and
                ``agent_goal`` (with ``description`` and ``tools``).

        Returns:
            A ValidationResult built from the ``validationResult`` and
            ``validationFailedReason`` keys of the LLM reply; missing keys
            default to ``False`` and ``{}`` respectively.
        """
        # Create simple context string describing tools and goals
        tools_description = []
        for tool in validation_input.agent_goal.tools:
            arguments = ", ".join(
                f"{arg.name} ({arg.type})" for arg in tool.arguments
            )
            tools_description.append(
                f"Tool: {tool.name}\n"
                f"Description: {tool.description}\n"
                f"Arguments: {arguments}"
            )
        tools_str = "\n".join(tools_description)

        # Convert conversation history to string
        history_str = json.dumps(validation_input.conversation_history, indent=2)

        # Create context instructions
        context_instructions = (
            "The agent goal and tools are as follows:\n"
            f"Description: {validation_input.agent_goal.description}\n"
            "Available Tools:\n"
            f"{tools_str}\n"
            "The conversation history to date is:\n"
            f"{history_str}"
        )

        # Create validation prompt
        validation_prompt = (
            f"The user's prompt is: \"{validation_input.prompt}\"\n"
            "Please validate if this prompt makes sense given the agent goal and conversation history.\n"
            "If the prompt makes sense toward the goal then validationResult should be true.\n"
            "If the prompt is wildly nonsensical or makes no sense toward the goal and current conversation history then validationResult should be false.\n"
            "If the response is low content such as \"yes\" or \"that's right\" then the user is probably responding to a previous prompt.\n"
            "Therefore examine it in the context of the conversation history to determine if it makes sense and return true if it makes sense.\n"
            "Return ONLY a JSON object with the following structure:\n"
            '"validationResult": true/false,\n'
            '"validationFailedReason": "If validationResult is false, provide a clear explanation to the user in the response field\n'
            "about why their request doesn't make sense in the context and what information they should provide instead.\n"
            "validationFailedReason should contain JSON in the format\n"
            "{\n"
            '"next": "question",\n'
            '"response": "[your reason here and a response to get the user back on track with the agent goal]"\n'
            "}\n"
            'If validationResult is true (the prompt makes sense), return an empty dict as its value {}"\n'
        )

        # Call the LLM with the validation prompt
        prompt_input = ToolPromptInput(
            prompt=validation_prompt, context_instructions=context_instructions
        )

        result = await self.agent_toolPlanner(prompt_input)

        return ValidationResult(
            validationResult=result.get("validationResult", False),
            validationFailedReason=result.get("validationFailedReason", {}),
        )

    @activity.defn
    async def agent_toolPlanner(self, input: ToolPromptInput) -> dict:
        """Send the prompt to the configured LLM and return its reply parsed as JSON.

        The system message is ``input.context_instructions`` followed by the
        current date; the user message is ``input.prompt``.

        Args:
            input: Prompt and context instructions for the LLM call.

        Returns:
            The decoded JSON payload of the first completion choice.

        Raises:
            Exception: Any error from the LLM call or from JSON decoding is
                printed and re-raised unchanged.
        """
        # Local import keeps this block self-contained.
        import asyncio

        messages = [
            {
                "role": "system",
                "content": input.context_instructions
                + ". The current date is "
                + datetime.now().strftime("%B %d, %Y"),
            },
            {
                "role": "user",
                "content": input.prompt,
            },
        ]

        try:
            completion_kwargs = {
                "model": self.llm_model,
                "messages": messages,
                "api_key": self.llm_key,
            }

            # Add base_url if configured
            if self.llm_base_url:
                completion_kwargs["base_url"] = self.llm_base_url

            # completion() is a blocking call; run it in a worker thread so
            # this async activity does not stall the event loop while the
            # LLM request is in flight.
            response = await asyncio.to_thread(completion, **completion_kwargs)

            response_content = response.choices[0].message.content
            activity.logger.info(f"Raw LLM response: {repr(response_content)}")
            activity.logger.info(f"LLM response content: {response_content}")
            activity.logger.info(f"LLM response type: {type(response_content)}")
            activity.logger.info(
                f"LLM response length: {len(response_content) if response_content else 'None'}"
            )

            # Use the new sanitize function
            response_content = self.sanitize_json_response(response_content)
            activity.logger.info(f"Sanitized response: {repr(response_content)}")

            return self.parse_json_response(response_content)
        except Exception as e:
            print(f"Error in LLM completion: {str(e)}")
            raise

    def parse_json_response(self, response_content: str) -> dict:
        """
        Parses the JSON response content and returns it as a dictionary.

        Raises:
            json.JSONDecodeError: If ``response_content`` is not valid JSON;
                the error is printed before being re-raised.
        """
        try:
            return json.loads(response_content)
        except json.JSONDecodeError as e:
            print(f"Invalid JSON: {e}")
            raise

    def sanitize_json_response(self, response_content: Optional[str]) -> str:
        """
        Sanitizes the response content to ensure it's valid JSON.

        Removes markdown code-fence markers and surrounding whitespace.
        A ``None`` reply (no content returned by the LLM) is treated as an
        empty string, so the failure surfaces later as a JSON decode error
        rather than an ``AttributeError`` here.
        """
        if response_content is None:
            return ""

        # Remove any markdown code block markers
        response_content = response_content.replace("```json", "").replace("```", "")

        # Remove any leading/trailing whitespace
        return response_content.strip()

    @activity.defn
    async def get_wf_env_vars(self, input: EnvLookupInput) -> EnvLookupOutput:
        """gets env vars for workflow as an activity result so it's deterministic
        handles default/None

        ``show_confirm`` keeps ``input.show_confirm_default`` when the named
        variable is unset; otherwise it is False only for the value "false"
        (case-insensitive). ``multi_goal_mode`` is True only when
        ``AGENT_GOAL`` equals "goal_choose_agent_type" (case-insensitive).
        """
        output: EnvLookupOutput = EnvLookupOutput(
            show_confirm=input.show_confirm_default, multi_goal_mode=False
        )

        show_confirm_value = os.getenv(input.show_confirm_env_var_name)
        if show_confirm_value is not None:
            output.show_confirm = show_confirm_value.lower() != "false"

        # default to single agent mode if unset
        first_goal_value = os.getenv("AGENT_GOAL")
        output.multi_goal_mode = (
            first_goal_value is not None
            and first_goal_value.lower() == "goal_choose_agent_type"
        )

        return output

    @activity.defn
    async def mcp_tool_activity(
        self, tool_name: str, tool_args: Dict[str, Any]
    ) -> Dict[str, Any]:
        """MCP Tool - now using pooled connections

        Args:
            tool_name: Name of the MCP tool to call.
            tool_args: Tool arguments; an optional ``server_definition``
                entry selects the server and is not forwarded as an argument.

        Returns:
            The result dict produced by the pooled or the per-call executor.
        """
        activity.logger.info(f"Executing MCP tool: {tool_name} with args: {tool_args}")

        # Extract server definition from a shallow copy so the caller's dict
        # is left untouched.
        tool_args = dict(tool_args)
        server_definition = tool_args.pop("server_definition", None)

        if self.mcp_client_manager:
            # Use pooled connection
            return await self._execute_mcp_tool_pooled(
                tool_name, tool_args, server_definition
            )

        # Fallback to original implementation
        return await _execute_mcp_tool(tool_name, tool_args, server_definition)

    async def _execute_mcp_tool_pooled(
        self,
        tool_name: str,
        tool_args: Dict[str, Any],
        server_definition: MCPServerDefinition | Dict[str, Any] | None,
    ) -> Dict[str, Any]:
        """Execute MCP tool using pooled client connection

        Returns:
            ``{"tool", "success": True, "content"}`` on success, or
            ``{"tool", "success": False, "error", "error_type"}`` when
            obtaining the client or calling the tool raises.
        """
        activity.logger.info(f"Executing MCP tool with pooled connection: {tool_name}")

        # Convert argument types for MCP tools
        converted_args = _convert_args_types(tool_args)

        try:
            # Get pooled client
            client = await self.mcp_client_manager.get_client(server_definition)

            # Call the tool using existing client session
            result = await client.call_tool(tool_name, arguments=converted_args)
            normalized_result = _normalize_result(result)

            return {
                "tool": tool_name,
                "success": True,
                "content": normalized_result,
            }
        except Exception as e:
            # Failures are reported in the result instead of being raised.
            activity.logger.error(f"MCP tool {tool_name} failed: {str(e)}")
            return {
                "tool": tool_name,
                "success": False,
                "error": str(e),
                "error_type": type(e).__name__,
            }
|
|
|
|
|
|
@activity.defn(dynamic=True)
async def dynamic_tool_activity(args: Sequence[RawValue]) -> dict:
    """Run whichever tool matches the activity type of the current invocation.

    The first raw payload is decoded into a dict of tool arguments. If that
    dict carries a truthy ``server_definition`` the call is routed to
    ``_execute_mcp_tool``; otherwise the handler returned by
    ``tools.get_handler`` is invoked (and awaited when it is a coroutine
    function).
    """
    from tools import get_handler

    name = activity.info().activity_type  # e.g. "FindEvents"
    payload_args = activity.payload_converter().from_payload(args[0].payload, dict)
    activity.logger.info(f"Running dynamic tool '{name}' with args: {payload_args}")

    # A server_definition entry marks the call as an MCP tool call.
    mcp_server = payload_args.pop("server_definition", None)
    if mcp_server:
        # MCP tool: handled directly, bypassing the local handler registry.
        activity.logger.info(f"Executing MCP tool: {name}")
        return await _execute_mcp_tool(name, payload_args, mcp_server)

    # Regular tool: delegate to the registered handler function.
    handler = get_handler(name)
    outcome = handler(payload_args)
    if inspect.iscoroutinefunction(handler):
        outcome = await outcome

    # Optionally log or augment the result
    activity.logger.info(f"Tool '{name}' result: {outcome}")
    return outcome
|
|
|
|
|
|
# MCP Client Activities
|
|
|
|
|
|
def _build_connection(
    server_definition: MCPServerDefinition | Dict[str, Any] | None,
) -> Dict[str, Any]:
    """Translate a server definition into stdio-style connection settings.

    Args:
        server_definition: An MCPServerDefinition object, its dict form
            (as produced by Temporal serialization), or ``None``.

    Returns:
        A dict with the keys ``type``, ``command``, ``args`` and ``env``.
        ``None`` yields the default stdio server (``python server.py``);
        a dict falls back to those same defaults for missing keys; a falsy
        ``env`` is always replaced by an empty dict.
    """
    if server_definition is None:
        # No definition given: use the main server over stdio.
        return {"type": "stdio", "command": "python", "args": ["server.py"], "env": {}}

    if isinstance(server_definition, dict):
        lookup = server_definition.get
        conn_type = lookup("connection_type", "stdio")
        command = lookup("command", "python")
        arguments = lookup("args", ["server.py"])
        environment = lookup("env", {})
    else:
        conn_type = server_definition.connection_type
        command = server_definition.command
        arguments = server_definition.args
        environment = server_definition.env

    return {
        "type": conn_type,
        "command": command,
        "args": arguments,
        "env": environment or {},
    }
|
|
|
|
|
|
def _normalize_result(result: Any) -> Any:
    """Reduce an MCP tool result to plain values.

    Objects without a ``content`` attribute are returned unchanged. A
    non-string iterable ``content`` becomes a list holding each item's
    ``text`` attribute (or ``str(item)`` when it has none); any other
    ``content`` is returned as ``str(content)``.
    """
    if not hasattr(result, "content"):
        return result

    content = result.content
    if isinstance(content, str) or not hasattr(content, "__iter__"):
        return str(content)

    normalized = []
    for item in content:
        normalized.append(item.text if hasattr(item, "text") else str(item))
    return normalized
|
|
|
|
|
|
def _convert_args_types(tool_args: Dict[str, Any]) -> Dict[str, Any]:
    """Convert string arguments to appropriate types for MCP tools

    Args:
        tool_args: Raw tool arguments. A ``server_definition`` entry is
            metadata and is dropped from the result.

    Returns:
        A new dict in which string values that look like integers, decimals
        or booleans are replaced by ``int``, ``float`` or ``bool``; all other
        values are copied as-is. The input dict is not modified.
    """
    converted_args: Dict[str, Any] = {}

    for key, value in tool_args.items():
        if key == "server_definition":
            # Skip server_definition - it's metadata
            continue
        converted_args[key] = (
            _coerce_string_arg(value) if isinstance(value, str) else value
        )

    return converted_args


def _coerce_string_arg(value: str) -> Any:
    """Return *value* as int, float or bool when it is spelled as one, else unchanged."""
    # Allow one leading minus sign so negative numbers convert like positive ones.
    unsigned = value[1:] if value.startswith("-") else value

    # isdecimal() rather than isdigit(): isdigit() also accepts characters
    # such as "²" that int() and float() reject with ValueError.
    if unsigned.isdecimal():
        return int(value)
    if unsigned.count(".") == 1 and unsigned.replace(".", "").isdecimal():
        return float(value)

    lowered = value.lower()
    if lowered in ("true", "false"):
        return lowered == "true"

    # Keep as string
    return value
|
|
|
|
|
|
async def _execute_mcp_tool(
    tool_name: str,
    tool_args: Dict[str, Any],
    server_definition: MCPServerDefinition | Dict[str, Any] | None,
) -> Dict[str, Any]:
    """Run one MCP tool call over a freshly opened connection.

    Args:
        tool_name: Name of the tool to invoke on the server.
        tool_args: Tool arguments; string values are type-converted first.
        server_definition: Server to connect to (object, dict or ``None``).

    Returns:
        ``{"tool", "success": True, "content"}`` when the call succeeds, or
        ``{"tool", "success": False, "error", "error_type"}`` for any failure
        raised while connecting or calling, including unsupported
        connection types.
    """
    activity.logger.info(f"Executing MCP tool: {tool_name}")

    # Convert argument types for MCP tools
    call_args = _convert_args_types(tool_args)
    connection = _build_connection(server_definition)

    try:
        connection_type = connection["type"]
        if connection_type == "tcp":
            # Placeholder for a future implementation.
            raise ApplicationError("TCP connections not yet implemented")
        if connection_type != "stdio":
            raise ApplicationError(f"Unsupported connection type: {connection_type}")

        stdio_streams = _stdio_connection(
            command=connection.get("command", "python"),
            args=connection.get("args", ["server.py"]),
            env=connection.get("env", {}),
        )
        async with stdio_streams as (read, write), ClientSession(read, write) as session:
            # The session must be initialized before any tool call.
            activity.logger.info(f"Initializing MCP session for {tool_name}")
            await session.initialize()
            activity.logger.info(f"MCP session initialized for {tool_name}")

            activity.logger.info(
                f"Calling MCP tool {tool_name} with args: {call_args}"
            )
            try:
                result = await session.call_tool(tool_name, arguments=call_args)
                activity.logger.info(
                    f"MCP tool {tool_name} returned result: {result}"
                )
            except Exception as tool_exc:
                # Log the call-level failure, then let the outer handler
                # turn it into the error result.
                activity.logger.error(
                    f"MCP tool {tool_name} call failed: {type(tool_exc).__name__}: {tool_exc}"
                )
                raise

            content = _normalize_result(result)
            activity.logger.info(f"MCP tool {tool_name} completed successfully")

            return {
                "tool": tool_name,
                "success": True,
                "content": content,
            }

    except Exception as e:
        activity.logger.error(f"MCP tool {tool_name} failed: {str(e)}")

        # Failures are reported in the result instead of being raised.
        return {
            "tool": tool_name,
            "success": False,
            "error": str(e),
            "error_type": type(e).__name__,
        }
|
|
|
|
|
|
@asynccontextmanager
async def _stdio_connection(command: str, args: list, env: dict):
    """Open a stdio connection to an MCP server and yield ``(read, write)``.

    Args:
        command: Executable used to start the server.
        args: Arguments passed to that executable.
        env: Environment mapping passed to the server parameters.

    Raises:
        ApplicationError: If the MCP client libraries could not be imported
            (``stdio_client`` is ``None``).
    """
    if stdio_client is None:
        raise ApplicationError("MCP client libraries not available")

    params = StdioServerParameters(command=command, args=args, env=env)
    async with stdio_client(params) as streams:
        reader, writer = streams
        yield reader, writer
|
|
|
|
|
|
@activity.defn
async def mcp_list_tools(
    server_definition: MCPServerDefinition, include_tools: Optional[List[str]] = None
) -> Dict[str, Any]:
    """Query an MCP server for its tools and describe them as plain dicts.

    Args:
        server_definition: Server to query; its ``name`` is used in logs and
            in the returned dict.
        include_tools: When given, only tools whose name is in this list are
            reported; ``None`` reports every tool.

    Returns:
        On success: ``server_name``, ``success=True``, ``tools`` (name ->
        ``name``/``description``/``inputSchema``), ``total_available`` and
        ``filtered_count``. On any failure: ``server_name``,
        ``success=False``, ``error`` and ``error_type``.
    """
    activity.logger.info(f"Listing MCP tools for server: {server_definition.name}")

    connection = _build_connection(server_definition)

    try:
        connection_type = connection["type"]
        if connection_type == "tcp":
            raise ApplicationError("TCP connections not yet implemented")
        if connection_type != "stdio":
            raise ApplicationError(f"Unsupported connection type: {connection_type}")

        stdio_streams = _stdio_connection(
            command=connection.get("command", "python"),
            args=connection.get("args", ["server.py"]),
            env=connection.get("env", {}),
        )
        async with stdio_streams as (read, write), ClientSession(read, write) as session:
            # The session must be initialized before listing tools.
            await session.initialize()
            listing = await session.list_tools()

            tools_info = {}
            for tool in listing.tools:
                # Apply the optional allow-list.
                if include_tools is not None and tool.name not in include_tools:
                    continue

                schema = tool.inputSchema
                # NOTE(review): a schema without model_dump() (e.g. a plain
                # dict) is stored as its str() form - confirm callers expect
                # a string here.
                if hasattr(schema, "model_dump"):
                    schema = schema.model_dump()
                else:
                    schema = str(schema)

                tools_info[tool.name] = {
                    "name": tool.name,
                    "description": tool.description,
                    "inputSchema": schema,
                }

            activity.logger.info(
                f"Found {len(tools_info)} tools for server {server_definition.name}"
            )

            return {
                "server_name": server_definition.name,
                "success": True,
                "tools": tools_info,
                "total_available": len(listing.tools),
                "filtered_count": len(tools_info),
            }

    except Exception as e:
        activity.logger.error(
            f"Failed to list tools for server {server_definition.name}: {str(e)}"
        )

        # Failures are reported in the result instead of being raised.
        return {
            "server_name": server_definition.name,
            "success": False,
            "error": str(e),
            "error_type": type(e).__name__,
        }
|