diff --git a/activities/tool_activities.py b/activities/tool_activities.py
index 49aaedc..1380666 100644
--- a/activities/tool_activities.py
+++ b/activities/tool_activities.py
@@ -19,6 +19,7 @@ from models.data_types import (
ValidationResult,
)
from models.tool_definitions import MCPServerDefinition
+from shared.mcp_client_manager import MCPClientManager
# Import MCP client libraries
try:
@@ -34,14 +35,17 @@ load_dotenv(override=True)
class ToolActivities:
- def __init__(self):
- """Initialize LLM client using LiteLLM."""
+ def __init__(self, mcp_client_manager: MCPClientManager = None):
+ """Initialize LLM client using LiteLLM and optional MCP client manager"""
self.llm_model = os.environ.get("LLM_MODEL", "openai/gpt-4")
self.llm_key = os.environ.get("LLM_KEY")
self.llm_base_url = os.environ.get("LLM_BASE_URL")
+ self.mcp_client_manager = mcp_client_manager
print(f"Initializing ToolActivities with LLM model: {self.llm_model}")
if self.llm_base_url:
print(f"Using custom base URL: {self.llm_base_url}")
+ if self.mcp_client_manager:
+ print("MCP client manager enabled for connection pooling")
@activity.defn
async def agent_validatePrompt(
@@ -205,13 +209,54 @@ class ToolActivities:
async def mcp_tool_activity(
self, tool_name: str, tool_args: Dict[str, Any]
) -> Dict[str, Any]:
- """MCP Tool"""
+ """MCP Tool - now using pooled connections"""
activity.logger.info(f"Executing MCP tool: {tool_name} with args: {tool_args}")
# Extract server definition
server_definition = tool_args.pop("server_definition", None)
- return await _execute_mcp_tool(tool_name, tool_args, server_definition)
+ if self.mcp_client_manager:
+ # Use pooled connection
+ return await self._execute_mcp_tool_pooled(
+ tool_name, tool_args, server_definition
+ )
+ else:
+ # Fallback to original implementation
+ return await _execute_mcp_tool(tool_name, tool_args, server_definition)
+
+ async def _execute_mcp_tool_pooled(
+ self,
+ tool_name: str,
+ tool_args: Dict[str, Any],
+ server_definition: MCPServerDefinition | Dict[str, Any] | None,
+ ) -> Dict[str, Any]:
+ """Execute MCP tool using pooled client connection"""
+ activity.logger.info(f"Executing MCP tool with pooled connection: {tool_name}")
+
+ # Convert argument types for MCP tools
+ converted_args = _convert_args_types(tool_args)
+
+ try:
+ # Get pooled client
+ client = await self.mcp_client_manager.get_client(server_definition)
+
+ # Call the tool using existing client session
+ result = await client.call_tool(tool_name, arguments=converted_args)
+ normalized_result = _normalize_result(result)
+
+ return {
+ "tool": tool_name,
+ "success": True,
+ "content": normalized_result,
+ }
+ except Exception as e:
+ activity.logger.error(f"MCP tool {tool_name} failed: {str(e)}")
+ return {
+ "tool": tool_name,
+ "success": False,
+ "error": str(e),
+ "error_type": type(e).__name__,
+ }
@activity.defn(dynamic=True)
diff --git a/frontend/src/components/MessageBubble.jsx b/frontend/src/components/MessageBubble.jsx
index f3245a1..c29eeb5 100644
--- a/frontend/src/components/MessageBubble.jsx
+++ b/frontend/src/components/MessageBubble.jsx
@@ -8,26 +8,54 @@ const MessageBubble = memo(({ message, fallback = "", isUser = false }) => {
}
const renderTextWithLinks = (text) => {
+ // First handle image markdown: 
+ const imageRegex = /!\[([^\]]*)\]\(([^)]+)\)/g;
const urlRegex = /(https?:\/\/[^\s]+)/g;
- const parts = text.split(urlRegex);
-
- return parts.map((part, index) => {
- if (urlRegex.test(part)) {
+
+ // Split by image markdown first
+ const imageParts = text.split(imageRegex);
+
+ return imageParts.map((part, index) => {
+ // Every third element (starting from index 2) is an image URL
+ if (index > 0 && (index - 2) % 3 === 0) {
+ const altText = imageParts[index - 1];
+ const imageUrl = part;
return (
-
- {part}
-
+ src={imageUrl}
+ alt={altText}
+ className="max-w-full h-auto rounded mt-2 mb-2 mx-auto block border border-gray-300 dark:border-gray-600"
+ style={{ maxHeight: '200px' }}
+ loading="lazy"
+ />
);
}
- return part;
- });
+ // Skip alt text parts (every second element after first)
+ if (index > 0 && (index - 1) % 3 === 0) {
+ return null;
+ }
+
+ // Handle regular text and links
+ const linkParts = part.split(urlRegex);
+ return linkParts.map((linkPart, linkIndex) => {
+ if (urlRegex.test(linkPart)) {
+ return (
+
+ {linkPart}
+
+ );
+ }
+ return linkPart;
+ });
+ }).filter(Boolean);
};
return (
diff --git a/goals/food.py b/goals/food.py
index f77e750..ce24ed9 100644
--- a/goals/food.py
+++ b/goals/food.py
@@ -23,7 +23,7 @@ goal_food_ordering = AgentGoal(
]
),
description="The user wants to order food from Tony's Pizza Palace. "
- "First, help the user browse the menu by calling list_products. "
+ "First, help the user browse the menu by calling list_products. You can refer to this data throughout the conversation to answer questions (without re-calling this tool). If menu items contain links to images, you can use markdown to display them e.g. "
"When they express interest in items, get pricing using list_prices. "
"Add items to their cart using AddToCart as they decide - the order doesn't matter, multiple items can be added. "
"After they're done selecting items, get their customer details and create a Stripe customer. "
@@ -35,8 +35,8 @@ goal_food_ordering = AgentGoal(
"user: I want to order food.",
"agent: Welcome to Tony's Pizza Palace! I can help you browse our menu, add items to your cart, and process payment. Let me show you what we have available.",
"user_confirmed_tool_run: ",
- 'tool_result: {"products": [{"id": "prod_pepperoni", "name": "Pepperoni Pizza", "description": "A timeless classic featuring pepperoni"}, {"id": "prod_margherita", "name": "Margherita Pizza", "description": "Fresh mozzarella and basil"}, {"id": "prod_garlic_bread", "name": "Garlic Bread", "description": "Warm crusty bread with garlic butter"}, {"id": "prod_soda", "name": "Soda", "description": "A large bottle of cola"}]}',
- "agent: We have Pepperoni Pizza, Margherita Pizza, Garlic Bread, and Soda available. What would you like to order?",
+ 'tool_result: {"products": [{"id": "prod_pepperoni", "name": "Pepperoni Pizza", "description": "A timeless classic featuring pepperoni", "images": ["https://files.stripe.com/links/MDB8YWNjdF8xTkJPTHVLVlpienc3UUE1fGZsX3Rlc3RfcGRHc0c4cEZYWmR2bm0zOHBOa0FWMk5t008QmCJoWr"]}, {"id": "prod_margherita", "name": "Margherita Pizza", "description": "Fresh mozzarella and basil", "images": ["https://files.stripe.com/links/MDB8YWNjdF8xTkJPTHVLVlpienc3UUE1fGZsX3Rlc3RfamdmTXBFbzY0TW9rS2N0c2g0Tml2SERL00Evl60Ttq"]}, {"id": "prod_garlic_bread", "name": "Garlic Bread", "description": "Warm crusty bread with garlic butter", "images": ["https://files.stripe.com/links/MDB8YWNjdF8xTkJPTHVLVlpienc3UUE1fGZsX3Rlc3RfWTdIZTBkUjNZNFQ1ZEhSVG9nRnduY1pS00XVgLRRZD"]}, {"id": "prod_soda", "name": "Soda", "description": "A large bottle of cola", "images": ["https://files.stripe.com/links/MDB8YWNjdF8xTkJPTHVLVlpienc3UUE1fGZsX3Rlc3RfbDJxckJKMDRnT1dDc243OHlZNWNkZkY5006Xg07kHT"]}]}',
+ "agent: Welcome to our menu! Here's what we have available:\\n\\n**Pepperoni Pizza** - A timeless classic featuring pepperoni\\n\\n\\n**Margherita Pizza** - Fresh mozzarella and basil\\n\\n\\n**Garlic Bread** - Warm crusty bread with garlic butter\\n\\n\\n**Soda** - A large bottle of cola\\n\\n\\nWhat would you like to order?",
"user: I'd like 2 pepperoni pizzas and a soda.",
"agent: Excellent! Let me get the pricing for the pepperoni pizzas first so you know the cost.",
"user_confirmed_tool_run: ",
diff --git a/scripts/run_worker.py b/scripts/run_worker.py
index 864a7b8..5ba1c89 100644
--- a/scripts/run_worker.py
+++ b/scripts/run_worker.py
@@ -12,6 +12,7 @@ from activities.tool_activities import (
mcp_list_tools,
)
from shared.config import TEMPORAL_TASK_QUEUE, get_temporal_client
+from shared.mcp_client_manager import MCPClientManager
from workflows.agent_goal_workflow import AgentGoalWorkflow
@@ -23,11 +24,14 @@ async def main():
llm_model = os.environ.get("LLM_MODEL", "openai/gpt-4")
print(f"Worker will use LLM model: {llm_model}")
+ # Create shared MCP client manager
+ mcp_client_manager = MCPClientManager()
+
# Create the client
client = await get_temporal_client()
- # Initialize the activities class
- activities = ToolActivities()
+ # Initialize the activities class with injected manager
+ activities = ToolActivities(mcp_client_manager)
print(f"ToolActivities initialized with LLM model: {llm_model}")
# If using Ollama, pre-load the model to avoid cold start latency
@@ -54,25 +58,31 @@ async def main():
print("Worker ready to process tasks!")
logging.basicConfig(level=logging.INFO)
- # Run the worker
- with concurrent.futures.ThreadPoolExecutor(max_workers=100) as activity_executor:
- worker = Worker(
- client,
- task_queue=TEMPORAL_TASK_QUEUE,
- workflows=[AgentGoalWorkflow],
- activities=[
- activities.agent_validatePrompt,
- activities.agent_toolPlanner,
- activities.get_wf_env_vars,
- activities.mcp_tool_activity,
- dynamic_tool_activity,
- mcp_list_tools,
- ],
- activity_executor=activity_executor,
- )
+ # Run the worker with proper cleanup
+ try:
+ with concurrent.futures.ThreadPoolExecutor(
+ max_workers=100
+ ) as activity_executor:
+ worker = Worker(
+ client,
+ task_queue=TEMPORAL_TASK_QUEUE,
+ workflows=[AgentGoalWorkflow],
+ activities=[
+ activities.agent_validatePrompt,
+ activities.agent_toolPlanner,
+ activities.get_wf_env_vars,
+ activities.mcp_tool_activity,
+ dynamic_tool_activity,
+ mcp_list_tools,
+ ],
+ activity_executor=activity_executor,
+ )
- print(f"Starting worker, connecting to task queue: {TEMPORAL_TASK_QUEUE}")
- await worker.run()
+ print(f"Starting worker, connecting to task queue: {TEMPORAL_TASK_QUEUE}")
+ await worker.run()
+ finally:
+ # Cleanup MCP connections when worker shuts down
+ await mcp_client_manager.cleanup()
if __name__ == "__main__":
diff --git a/shared/mcp_client_manager.py b/shared/mcp_client_manager.py
new file mode 100644
index 0000000..1305c49
--- /dev/null
+++ b/shared/mcp_client_manager.py
@@ -0,0 +1,167 @@
+import asyncio
+from contextlib import asynccontextmanager
+from typing import TYPE_CHECKING, Any, Dict, Tuple
+
+from temporalio import activity
+
+from models.tool_definitions import MCPServerDefinition
+
+# Import MCP client libraries
+if TYPE_CHECKING:
+ from mcp import ClientSession, StdioServerParameters
+ from mcp.client.stdio import stdio_client
+else:
+ try:
+ from mcp import ClientSession, StdioServerParameters
+ from mcp.client.stdio import stdio_client
+ except ImportError:
+ # Fallback if MCP not installed
+ ClientSession = None
+ StdioServerParameters = None
+ stdio_client = None
+
+
+class MCPClientManager:
+ """Manages pooled MCP client connections for reuse across tool calls"""
+
+ def __init__(self):
+ self._clients: Dict[str, Any] = {}
+ self._connections: Dict[str, Tuple[Any, Any]] = {}
+ self._lock = asyncio.Lock()
+
+ async def get_client(
+ self, server_def: MCPServerDefinition | Dict[str, Any] | None
+ ) -> Any:
+ """Return existing client or create new one, keyed by server definition hash"""
+ async with self._lock:
+ key = self._get_server_key(server_def)
+ if key not in self._clients:
+ await self._create_client(server_def, key)
+ activity.logger.info(
+ f"Created new MCP client for {self._get_server_name(server_def)}"
+ )
+ else:
+ activity.logger.info(
+ f"Reusing existing MCP client for {self._get_server_name(server_def)}"
+ )
+ return self._clients[key]
+
+ def _get_server_key(
+ self, server_def: MCPServerDefinition | Dict[str, Any] | None
+ ) -> str:
+ """Generate unique key for server definition"""
+ if server_def is None:
+ return "default:python:server.py"
+
+ # Handle both MCPServerDefinition objects and dicts (from Temporal serialization)
+ if isinstance(server_def, dict):
+ name = server_def.get("name", "default")
+ command = server_def.get("command", "python")
+ args = server_def.get("args", ["server.py"])
+ else:
+ name = server_def.name
+ command = server_def.command
+ args = server_def.args
+
+ return f"{name}:{command}:{':'.join(args)}"
+
+ def _get_server_name(
+ self, server_def: MCPServerDefinition | Dict[str, Any] | None
+ ) -> str:
+ """Get server name for logging"""
+ if server_def is None:
+ return "default"
+
+ if isinstance(server_def, dict):
+ return server_def.get("name", "default")
+ else:
+ return server_def.name
+
+ def _build_connection(
+ self, server_def: MCPServerDefinition | Dict[str, Any] | None
+ ) -> Dict[str, Any]:
+ """Build connection parameters from MCPServerDefinition or dict"""
+ if server_def is None:
+ # Default to stdio connection with the main server
+ return {
+ "type": "stdio",
+ "command": "python",
+ "args": ["server.py"],
+ "env": {},
+ }
+
+ # Handle both MCPServerDefinition objects and dicts (from Temporal serialization)
+ if isinstance(server_def, dict):
+ return {
+ "type": server_def.get("connection_type", "stdio"),
+ "command": server_def.get("command", "python"),
+ "args": server_def.get("args", ["server.py"]),
+ "env": server_def.get("env", {}) or {},
+ }
+
+ return {
+ "type": server_def.connection_type,
+ "command": server_def.command,
+ "args": server_def.args,
+ "env": server_def.env or {},
+ }
+
+ @asynccontextmanager
+ async def _stdio_connection(self, command: str, args: list, env: dict):
+ """Create stdio connection to MCP server"""
+ if stdio_client is None:
+ raise Exception("MCP client libraries not available")
+
+ # Create server parameters
+ server_params = StdioServerParameters(command=command, args=args, env=env)
+
+ async with stdio_client(server_params) as (read, write):
+ yield read, write
+
+ async def _create_client(
+ self, server_def: MCPServerDefinition | Dict[str, Any] | None, key: str
+ ):
+ """Create and store new client connection"""
+ connection = self._build_connection(server_def)
+
+ if connection["type"] == "stdio":
+ # Create stdio connection
+ connection_manager = self._stdio_connection(
+ command=connection.get("command", "python"),
+ args=connection.get("args", ["server.py"]),
+ env=connection.get("env", {}),
+ )
+
+ # Enter the connection context
+ read, write = await connection_manager.__aenter__()
+
+ # Create and initialize client session
+ session = ClientSession(read, write)
+ await session.initialize()
+
+ # Store both the session and connection manager for cleanup
+ self._clients[key] = session
+ self._connections[key] = (connection_manager, read, write)
+ else:
+ raise Exception(f"Unsupported connection type: {connection['type']}")
+
+ async def cleanup(self):
+ """Close all connections gracefully"""
+ async with self._lock:
+ # Close all client sessions
+ for session in self._clients.values():
+ try:
+ await session.close()
+ except Exception as e:
+ activity.logger.warning(f"Error closing MCP session: {e}")
+
+ # Exit all connection contexts
+ for connection_manager, read, write in self._connections.values():
+ try:
+ await connection_manager.__aexit__(None, None, None)
+ except Exception as e:
+ activity.logger.warning(f"Error closing MCP connection: {e}")
+
+ self._clients.clear()
+ self._connections.clear()
+ activity.logger.info("All MCP connections closed")
diff --git a/tools/food/setup/archive_food_products.py b/tools/food/setup/archive_food_products.py
index ef1ff46..2ee0287 100644
--- a/tools/food/setup/archive_food_products.py
+++ b/tools/food/setup/archive_food_products.py
@@ -1,4 +1,5 @@
import os
+
from dotenv import load_dotenv
diff --git a/tools/food/setup/create_stripe_products.py b/tools/food/setup/create_stripe_products.py
index b462bbc..91dc8f9 100644
--- a/tools/food/setup/create_stripe_products.py
+++ b/tools/food/setup/create_stripe_products.py
@@ -1,5 +1,6 @@
import json
import os
+
from dotenv import load_dotenv