mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 14:08:08 +01:00
Mcp enhancements (#43)
* reuses MCP connections in each worker for efficiency * you can see your food * you can see your food * prompt eng around images
This commit is contained in:
committed by
GitHub
parent
49dd00ec3b
commit
861e55a8d0
@@ -19,6 +19,7 @@ from models.data_types import (
|
||||
ValidationResult,
|
||||
)
|
||||
from models.tool_definitions import MCPServerDefinition
|
||||
from shared.mcp_client_manager import MCPClientManager
|
||||
|
||||
# Import MCP client libraries
|
||||
try:
|
||||
@@ -34,14 +35,17 @@ load_dotenv(override=True)
|
||||
|
||||
|
||||
class ToolActivities:
|
||||
def __init__(self):
|
||||
"""Initialize LLM client using LiteLLM."""
|
||||
def __init__(self, mcp_client_manager: MCPClientManager = None):
|
||||
"""Initialize LLM client using LiteLLM and optional MCP client manager"""
|
||||
self.llm_model = os.environ.get("LLM_MODEL", "openai/gpt-4")
|
||||
self.llm_key = os.environ.get("LLM_KEY")
|
||||
self.llm_base_url = os.environ.get("LLM_BASE_URL")
|
||||
self.mcp_client_manager = mcp_client_manager
|
||||
print(f"Initializing ToolActivities with LLM model: {self.llm_model}")
|
||||
if self.llm_base_url:
|
||||
print(f"Using custom base URL: {self.llm_base_url}")
|
||||
if self.mcp_client_manager:
|
||||
print("MCP client manager enabled for connection pooling")
|
||||
|
||||
@activity.defn
|
||||
async def agent_validatePrompt(
|
||||
@@ -205,14 +209,55 @@ class ToolActivities:
|
||||
async def mcp_tool_activity(
|
||||
self, tool_name: str, tool_args: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""MCP Tool"""
|
||||
"""MCP Tool - now using pooled connections"""
|
||||
activity.logger.info(f"Executing MCP tool: {tool_name} with args: {tool_args}")
|
||||
|
||||
# Extract server definition
|
||||
server_definition = tool_args.pop("server_definition", None)
|
||||
|
||||
if self.mcp_client_manager:
|
||||
# Use pooled connection
|
||||
return await self._execute_mcp_tool_pooled(
|
||||
tool_name, tool_args, server_definition
|
||||
)
|
||||
else:
|
||||
# Fallback to original implementation
|
||||
return await _execute_mcp_tool(tool_name, tool_args, server_definition)
|
||||
|
||||
async def _execute_mcp_tool_pooled(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_args: Dict[str, Any],
|
||||
server_definition: MCPServerDefinition | Dict[str, Any] | None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute MCP tool using pooled client connection"""
|
||||
activity.logger.info(f"Executing MCP tool with pooled connection: {tool_name}")
|
||||
|
||||
# Convert argument types for MCP tools
|
||||
converted_args = _convert_args_types(tool_args)
|
||||
|
||||
try:
|
||||
# Get pooled client
|
||||
client = await self.mcp_client_manager.get_client(server_definition)
|
||||
|
||||
# Call the tool using existing client session
|
||||
result = await client.call_tool(tool_name, arguments=converted_args)
|
||||
normalized_result = _normalize_result(result)
|
||||
|
||||
return {
|
||||
"tool": tool_name,
|
||||
"success": True,
|
||||
"content": normalized_result,
|
||||
}
|
||||
except Exception as e:
|
||||
activity.logger.error(f"MCP tool {tool_name} failed: {str(e)}")
|
||||
return {
|
||||
"tool": tool_name,
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"error_type": type(e).__name__,
|
||||
}
|
||||
|
||||
|
||||
@activity.defn(dynamic=True)
|
||||
async def dynamic_tool_activity(args: Sequence[RawValue]) -> dict:
|
||||
|
||||
@@ -8,26 +8,54 @@ const MessageBubble = memo(({ message, fallback = "", isUser = false }) => {
|
||||
}
|
||||
|
||||
const renderTextWithLinks = (text) => {
|
||||
// First handle image markdown: 
|
||||
const imageRegex = /!\[([^\]]*)\]\(([^)]+)\)/g;
|
||||
const urlRegex = /(https?:\/\/[^\s]+)/g;
|
||||
const parts = text.split(urlRegex);
|
||||
|
||||
return parts.map((part, index) => {
|
||||
if (urlRegex.test(part)) {
|
||||
// Split by image markdown first
|
||||
const imageParts = text.split(imageRegex);
|
||||
|
||||
return imageParts.map((part, index) => {
|
||||
// Every third element (starting from index 2) is an image URL
|
||||
if (index > 0 && (index - 2) % 3 === 0) {
|
||||
const altText = imageParts[index - 1];
|
||||
const imageUrl = part;
|
||||
return (
|
||||
<img
|
||||
key={index}
|
||||
src={imageUrl}
|
||||
alt={altText}
|
||||
className="max-w-full h-auto rounded mt-2 mb-2 mx-auto block border border-gray-300 dark:border-gray-600"
|
||||
style={{ maxHeight: '200px' }}
|
||||
loading="lazy"
|
||||
/>
|
||||
);
|
||||
}
|
||||
// Skip alt text parts (every second element after first)
|
||||
if (index > 0 && (index - 1) % 3 === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Handle regular text and links
|
||||
const linkParts = part.split(urlRegex);
|
||||
return linkParts.map((linkPart, linkIndex) => {
|
||||
if (urlRegex.test(linkPart)) {
|
||||
return (
|
||||
<a
|
||||
key={index}
|
||||
href={part}
|
||||
key={`${index}-${linkIndex}`}
|
||||
href={linkPart}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="text-blue-500 hover:text-blue-600 underline"
|
||||
aria-label={`External link to ${part}`}
|
||||
aria-label={`External link to ${linkPart}`}
|
||||
>
|
||||
{part}
|
||||
{linkPart}
|
||||
</a>
|
||||
);
|
||||
}
|
||||
return part;
|
||||
return linkPart;
|
||||
});
|
||||
}).filter(Boolean);
|
||||
};
|
||||
|
||||
return (
|
||||
|
||||
@@ -23,7 +23,7 @@ goal_food_ordering = AgentGoal(
|
||||
]
|
||||
),
|
||||
description="The user wants to order food from Tony's Pizza Palace. "
|
||||
"First, help the user browse the menu by calling list_products. "
|
||||
"First, help the user browse the menu by calling list_products. You can refer to this data throughout the conversation to answer questions (without re-calling this tool). If menu items contain links to images, you can use markdown to display them e.g. "
|
||||
"When they express interest in items, get pricing using list_prices. "
|
||||
"Add items to their cart using AddToCart as they decide - the order doesn't matter, multiple items can be added. "
|
||||
"After they're done selecting items, get their customer details and create a Stripe customer. "
|
||||
@@ -35,8 +35,8 @@ goal_food_ordering = AgentGoal(
|
||||
"user: I want to order food.",
|
||||
"agent: Welcome to Tony's Pizza Palace! I can help you browse our menu, add items to your cart, and process payment. Let me show you what we have available.",
|
||||
"user_confirmed_tool_run: <user clicks confirm on list_products tool with limit=100 and product.metadata.use_case == 'food_ordering_demo'>",
|
||||
'tool_result: {"products": [{"id": "prod_pepperoni", "name": "Pepperoni Pizza", "description": "A timeless classic featuring pepperoni"}, {"id": "prod_margherita", "name": "Margherita Pizza", "description": "Fresh mozzarella and basil"}, {"id": "prod_garlic_bread", "name": "Garlic Bread", "description": "Warm crusty bread with garlic butter"}, {"id": "prod_soda", "name": "Soda", "description": "A large bottle of cola"}]}',
|
||||
"agent: We have Pepperoni Pizza, Margherita Pizza, Garlic Bread, and Soda available. What would you like to order?",
|
||||
'tool_result: {"products": [{"id": "prod_pepperoni", "name": "Pepperoni Pizza", "description": "A timeless classic featuring pepperoni", "images": ["https://files.stripe.com/links/MDB8YWNjdF8xTkJPTHVLVlpienc3UUE1fGZsX3Rlc3RfcGRHc0c4cEZYWmR2bm0zOHBOa0FWMk5t008QmCJoWr"]}, {"id": "prod_margherita", "name": "Margherita Pizza", "description": "Fresh mozzarella and basil", "images": ["https://files.stripe.com/links/MDB8YWNjdF8xTkJPTHVLVlpienc3UUE1fGZsX3Rlc3RfamdmTXBFbzY0TW9rS2N0c2g0Tml2SERL00Evl60Ttq"]}, {"id": "prod_garlic_bread", "name": "Garlic Bread", "description": "Warm crusty bread with garlic butter", "images": ["https://files.stripe.com/links/MDB8YWNjdF8xTkJPTHVLVlpienc3UUE1fGZsX3Rlc3RfWTdIZTBkUjNZNFQ1ZEhSVG9nRnduY1pS00XVgLRRZD"]}, {"id": "prod_soda", "name": "Soda", "description": "A large bottle of cola", "images": ["https://files.stripe.com/links/MDB8YWNjdF8xTkJPTHVLVlpienc3UUE1fGZsX3Rlc3RfbDJxckJKMDRnT1dDc243OHlZNWNkZkY5006Xg07kHT"]}]}',
|
||||
"agent: Welcome to our menu! Here's what we have available:\\n\\n**Pepperoni Pizza** - A timeless classic featuring pepperoni\\n\\n\\n**Margherita Pizza** - Fresh mozzarella and basil\\n\\n\\n**Garlic Bread** - Warm crusty bread with garlic butter\\n\\n\\n**Soda** - A large bottle of cola\\n\\n\\nWhat would you like to order?",
|
||||
"user: I'd like 2 pepperoni pizzas and a soda.",
|
||||
"agent: Excellent! Let me get the pricing for the pepperoni pizzas first so you know the cost.",
|
||||
"user_confirmed_tool_run: <user clicks confirm on list_prices tool with product='prod_pepperoni'>",
|
||||
|
||||
@@ -12,6 +12,7 @@ from activities.tool_activities import (
|
||||
mcp_list_tools,
|
||||
)
|
||||
from shared.config import TEMPORAL_TASK_QUEUE, get_temporal_client
|
||||
from shared.mcp_client_manager import MCPClientManager
|
||||
from workflows.agent_goal_workflow import AgentGoalWorkflow
|
||||
|
||||
|
||||
@@ -23,11 +24,14 @@ async def main():
|
||||
llm_model = os.environ.get("LLM_MODEL", "openai/gpt-4")
|
||||
print(f"Worker will use LLM model: {llm_model}")
|
||||
|
||||
# Create shared MCP client manager
|
||||
mcp_client_manager = MCPClientManager()
|
||||
|
||||
# Create the client
|
||||
client = await get_temporal_client()
|
||||
|
||||
# Initialize the activities class
|
||||
activities = ToolActivities()
|
||||
# Initialize the activities class with injected manager
|
||||
activities = ToolActivities(mcp_client_manager)
|
||||
print(f"ToolActivities initialized with LLM model: {llm_model}")
|
||||
|
||||
# If using Ollama, pre-load the model to avoid cold start latency
|
||||
@@ -54,8 +58,11 @@ async def main():
|
||||
print("Worker ready to process tasks!")
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# Run the worker
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=100) as activity_executor:
|
||||
# Run the worker with proper cleanup
|
||||
try:
|
||||
with concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=100
|
||||
) as activity_executor:
|
||||
worker = Worker(
|
||||
client,
|
||||
task_queue=TEMPORAL_TASK_QUEUE,
|
||||
@@ -73,6 +80,9 @@ async def main():
|
||||
|
||||
print(f"Starting worker, connecting to task queue: {TEMPORAL_TASK_QUEUE}")
|
||||
await worker.run()
|
||||
finally:
|
||||
# Cleanup MCP connections when worker shuts down
|
||||
await mcp_client_manager.cleanup()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
167
shared/mcp_client_manager.py
Normal file
167
shared/mcp_client_manager.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Any, Dict, Tuple
|
||||
|
||||
from temporalio import activity
|
||||
|
||||
from models.tool_definitions import MCPServerDefinition
|
||||
|
||||
# Import MCP client libraries
|
||||
if TYPE_CHECKING:
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
else:
|
||||
try:
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
except ImportError:
|
||||
# Fallback if MCP not installed
|
||||
ClientSession = None
|
||||
StdioServerParameters = None
|
||||
stdio_client = None
|
||||
|
||||
|
||||
class MCPClientManager:
|
||||
"""Manages pooled MCP client connections for reuse across tool calls"""
|
||||
|
||||
def __init__(self):
|
||||
self._clients: Dict[str, Any] = {}
|
||||
self._connections: Dict[str, Tuple[Any, Any]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def get_client(
|
||||
self, server_def: MCPServerDefinition | Dict[str, Any] | None
|
||||
) -> Any:
|
||||
"""Return existing client or create new one, keyed by server definition hash"""
|
||||
async with self._lock:
|
||||
key = self._get_server_key(server_def)
|
||||
if key not in self._clients:
|
||||
await self._create_client(server_def, key)
|
||||
activity.logger.info(
|
||||
f"Created new MCP client for {self._get_server_name(server_def)}"
|
||||
)
|
||||
else:
|
||||
activity.logger.info(
|
||||
f"Reusing existing MCP client for {self._get_server_name(server_def)}"
|
||||
)
|
||||
return self._clients[key]
|
||||
|
||||
def _get_server_key(
|
||||
self, server_def: MCPServerDefinition | Dict[str, Any] | None
|
||||
) -> str:
|
||||
"""Generate unique key for server definition"""
|
||||
if server_def is None:
|
||||
return "default:python:server.py"
|
||||
|
||||
# Handle both MCPServerDefinition objects and dicts (from Temporal serialization)
|
||||
if isinstance(server_def, dict):
|
||||
name = server_def.get("name", "default")
|
||||
command = server_def.get("command", "python")
|
||||
args = server_def.get("args", ["server.py"])
|
||||
else:
|
||||
name = server_def.name
|
||||
command = server_def.command
|
||||
args = server_def.args
|
||||
|
||||
return f"{name}:{command}:{':'.join(args)}"
|
||||
|
||||
def _get_server_name(
|
||||
self, server_def: MCPServerDefinition | Dict[str, Any] | None
|
||||
) -> str:
|
||||
"""Get server name for logging"""
|
||||
if server_def is None:
|
||||
return "default"
|
||||
|
||||
if isinstance(server_def, dict):
|
||||
return server_def.get("name", "default")
|
||||
else:
|
||||
return server_def.name
|
||||
|
||||
def _build_connection(
|
||||
self, server_def: MCPServerDefinition | Dict[str, Any] | None
|
||||
) -> Dict[str, Any]:
|
||||
"""Build connection parameters from MCPServerDefinition or dict"""
|
||||
if server_def is None:
|
||||
# Default to stdio connection with the main server
|
||||
return {
|
||||
"type": "stdio",
|
||||
"command": "python",
|
||||
"args": ["server.py"],
|
||||
"env": {},
|
||||
}
|
||||
|
||||
# Handle both MCPServerDefinition objects and dicts (from Temporal serialization)
|
||||
if isinstance(server_def, dict):
|
||||
return {
|
||||
"type": server_def.get("connection_type", "stdio"),
|
||||
"command": server_def.get("command", "python"),
|
||||
"args": server_def.get("args", ["server.py"]),
|
||||
"env": server_def.get("env", {}) or {},
|
||||
}
|
||||
|
||||
return {
|
||||
"type": server_def.connection_type,
|
||||
"command": server_def.command,
|
||||
"args": server_def.args,
|
||||
"env": server_def.env or {},
|
||||
}
|
||||
|
||||
@asynccontextmanager
|
||||
async def _stdio_connection(self, command: str, args: list, env: dict):
|
||||
"""Create stdio connection to MCP server"""
|
||||
if stdio_client is None:
|
||||
raise Exception("MCP client libraries not available")
|
||||
|
||||
# Create server parameters
|
||||
server_params = StdioServerParameters(command=command, args=args, env=env)
|
||||
|
||||
async with stdio_client(server_params) as (read, write):
|
||||
yield read, write
|
||||
|
||||
async def _create_client(
|
||||
self, server_def: MCPServerDefinition | Dict[str, Any] | None, key: str
|
||||
):
|
||||
"""Create and store new client connection"""
|
||||
connection = self._build_connection(server_def)
|
||||
|
||||
if connection["type"] == "stdio":
|
||||
# Create stdio connection
|
||||
connection_manager = self._stdio_connection(
|
||||
command=connection.get("command", "python"),
|
||||
args=connection.get("args", ["server.py"]),
|
||||
env=connection.get("env", {}),
|
||||
)
|
||||
|
||||
# Enter the connection context
|
||||
read, write = await connection_manager.__aenter__()
|
||||
|
||||
# Create and initialize client session
|
||||
session = ClientSession(read, write)
|
||||
await session.initialize()
|
||||
|
||||
# Store both the session and connection manager for cleanup
|
||||
self._clients[key] = session
|
||||
self._connections[key] = (connection_manager, read, write)
|
||||
else:
|
||||
raise Exception(f"Unsupported connection type: {connection['type']}")
|
||||
|
||||
async def cleanup(self):
|
||||
"""Close all connections gracefully"""
|
||||
async with self._lock:
|
||||
# Close all client sessions
|
||||
for session in self._clients.values():
|
||||
try:
|
||||
await session.close()
|
||||
except Exception as e:
|
||||
activity.logger.warning(f"Error closing MCP session: {e}")
|
||||
|
||||
# Exit all connection contexts
|
||||
for connection_manager, read, write in self._connections.values():
|
||||
try:
|
||||
await connection_manager.__aexit__(None, None, None)
|
||||
except Exception as e:
|
||||
activity.logger.warning(f"Error closing MCP connection: {e}")
|
||||
|
||||
self._clients.clear()
|
||||
self._connections.clear()
|
||||
activity.logger.info("All MCP connections closed")
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user