diff --git a/activities/tool_activities.py b/activities/tool_activities.py index 49aaedc..1380666 100644 --- a/activities/tool_activities.py +++ b/activities/tool_activities.py @@ -19,6 +19,7 @@ from models.data_types import ( ValidationResult, ) from models.tool_definitions import MCPServerDefinition +from shared.mcp_client_manager import MCPClientManager # Import MCP client libraries try: @@ -34,14 +35,17 @@ load_dotenv(override=True) class ToolActivities: - def __init__(self): - """Initialize LLM client using LiteLLM.""" + def __init__(self, mcp_client_manager: MCPClientManager = None): + """Initialize LLM client using LiteLLM and optional MCP client manager""" self.llm_model = os.environ.get("LLM_MODEL", "openai/gpt-4") self.llm_key = os.environ.get("LLM_KEY") self.llm_base_url = os.environ.get("LLM_BASE_URL") + self.mcp_client_manager = mcp_client_manager print(f"Initializing ToolActivities with LLM model: {self.llm_model}") if self.llm_base_url: print(f"Using custom base URL: {self.llm_base_url}") + if self.mcp_client_manager: + print("MCP client manager enabled for connection pooling") @activity.defn async def agent_validatePrompt( @@ -205,13 +209,54 @@ class ToolActivities: async def mcp_tool_activity( self, tool_name: str, tool_args: Dict[str, Any] ) -> Dict[str, Any]: - """MCP Tool""" + """MCP Tool - now using pooled connections""" activity.logger.info(f"Executing MCP tool: {tool_name} with args: {tool_args}") # Extract server definition server_definition = tool_args.pop("server_definition", None) - return await _execute_mcp_tool(tool_name, tool_args, server_definition) + if self.mcp_client_manager: + # Use pooled connection + return await self._execute_mcp_tool_pooled( + tool_name, tool_args, server_definition + ) + else: + # Fallback to original implementation + return await _execute_mcp_tool(tool_name, tool_args, server_definition) + + async def _execute_mcp_tool_pooled( + self, + tool_name: str, + tool_args: Dict[str, Any], + server_definition: 
MCPServerDefinition | Dict[str, Any] | None, + ) -> Dict[str, Any]: + """Execute MCP tool using pooled client connection""" + activity.logger.info(f"Executing MCP tool with pooled connection: {tool_name}") + + # Convert argument types for MCP tools + converted_args = _convert_args_types(tool_args) + + try: + # Get pooled client + client = await self.mcp_client_manager.get_client(server_definition) + + # Call the tool using existing client session + result = await client.call_tool(tool_name, arguments=converted_args) + normalized_result = _normalize_result(result) + + return { + "tool": tool_name, + "success": True, + "content": normalized_result, + } + except Exception as e: + activity.logger.error(f"MCP tool {tool_name} failed: {str(e)}") + return { + "tool": tool_name, + "success": False, + "error": str(e), + "error_type": type(e).__name__, + } @activity.defn(dynamic=True) diff --git a/frontend/src/components/MessageBubble.jsx b/frontend/src/components/MessageBubble.jsx index f3245a1..c29eeb5 100644 --- a/frontend/src/components/MessageBubble.jsx +++ b/frontend/src/components/MessageBubble.jsx @@ -8,26 +8,54 @@ const MessageBubble = memo(({ message, fallback = "", isUser = false }) => { } const renderTextWithLinks = (text) => { + // First handle image markdown: ![alt text](url) + const imageRegex = /!\[([^\]]*)\]\(([^)]+)\)/g; const urlRegex = /(https?:\/\/[^\s]+)/g; - const parts = text.split(urlRegex); - - return parts.map((part, index) => { - if (urlRegex.test(part)) { + + // Split by image markdown first + const imageParts = text.split(imageRegex); + + return imageParts.map((part, index) => { + // Every third element (starting from index 2) is an image URL + if (index > 0 && (index - 2) % 3 === 0) { + const altText = imageParts[index - 1]; + const imageUrl = part; return ( - - {part} - + src={imageUrl} + alt={altText} + className="max-w-full h-auto rounded mt-2 mb-2 mx-auto block border border-gray-300 dark:border-gray-600" + style={{ maxHeight: '200px' }} + 
loading="lazy" + /> ); } - return part; - }); + // Skip alt text parts (every second element after first) + if (index > 0 && (index - 1) % 3 === 0) { + return null; + } + + // Handle regular text and links + const linkParts = part.split(urlRegex); + return linkParts.map((linkPart, linkIndex) => { + if (urlRegex.test(linkPart)) { + return ( + + {linkPart} + + ); + } + return linkPart; + }); + }).filter(Boolean); }; return ( diff --git a/goals/food.py b/goals/food.py index f77e750..ce24ed9 100644 --- a/goals/food.py +++ b/goals/food.py @@ -23,7 +23,7 @@ goal_food_ordering = AgentGoal( ] ), description="The user wants to order food from Tony's Pizza Palace. " - "First, help the user browse the menu by calling list_products. " + "First, help the user browse the menu by calling list_products. You can refer to this data throughout the conversation to answer questions (without re-calling this tool). If menu items contain links to images, you can use markdown to display them e.g. ![Pepperoni Pizza](https://...)" "When they express interest in items, get pricing using list_prices. " "Add items to their cart using AddToCart as they decide - the order doesn't matter, multiple items can be added. " "After they're done selecting items, get their customer details and create a Stripe customer. " @@ -35,8 +35,8 @@ goal_food_ordering = AgentGoal( "user: I want to order food.", "agent: Welcome to Tony's Pizza Palace! I can help you browse our menu, add items to your cart, and process payment. 
Let me show you what we have available.", "user_confirmed_tool_run: ", - 'tool_result: {"products": [{"id": "prod_pepperoni", "name": "Pepperoni Pizza", "description": "A timeless classic featuring pepperoni"}, {"id": "prod_margherita", "name": "Margherita Pizza", "description": "Fresh mozzarella and basil"}, {"id": "prod_garlic_bread", "name": "Garlic Bread", "description": "Warm crusty bread with garlic butter"}, {"id": "prod_soda", "name": "Soda", "description": "A large bottle of cola"}]}', - "agent: We have Pepperoni Pizza, Margherita Pizza, Garlic Bread, and Soda available. What would you like to order?", + 'tool_result: {"products": [{"id": "prod_pepperoni", "name": "Pepperoni Pizza", "description": "A timeless classic featuring pepperoni", "images": ["https://files.stripe.com/links/MDB8YWNjdF8xTkJPTHVLVlpienc3UUE1fGZsX3Rlc3RfcGRHc0c4cEZYWmR2bm0zOHBOa0FWMk5t008QmCJoWr"]}, {"id": "prod_margherita", "name": "Margherita Pizza", "description": "Fresh mozzarella and basil", "images": ["https://files.stripe.com/links/MDB8YWNjdF8xTkJPTHVLVlpienc3UUE1fGZsX3Rlc3RfamdmTXBFbzY0TW9rS2N0c2g0Tml2SERL00Evl60Ttq"]}, {"id": "prod_garlic_bread", "name": "Garlic Bread", "description": "Warm crusty bread with garlic butter", "images": ["https://files.stripe.com/links/MDB8YWNjdF8xTkJPTHVLVlpienc3UUE1fGZsX3Rlc3RfWTdIZTBkUjNZNFQ1ZEhSVG9nRnduY1pS00XVgLRRZD"]}, {"id": "prod_soda", "name": "Soda", "description": "A large bottle of cola", "images": ["https://files.stripe.com/links/MDB8YWNjdF8xTkJPTHVLVlpienc3UUE1fGZsX3Rlc3RfbDJxckJKMDRnT1dDc243OHlZNWNkZkY5006Xg07kHT"]}]}', + "agent: Welcome to our menu! 
Here's what we have available:\\n\\n**Pepperoni Pizza** - A timeless classic featuring pepperoni\\n![Pepperoni Pizza](https://files.stripe.com/links/MDB8YWNjdF8xTkJPTHVLVlpienc3UUE1fGZsX3Rlc3RfcGRHc0c4cEZYWmR2bm0zOHBOa0FWMk5t008QmCJoWr)\\n\\n**Margherita Pizza** - Fresh mozzarella and basil\\n![Margherita Pizza](https://files.stripe.com/links/MDB8YWNjdF8xTkJPTHVLVlpienc3UUE1fGZsX3Rlc3RfamdmTXBFbzY0TW9rS2N0c2g0Tml2SERL00Evl60Ttq)\\n\\n**Garlic Bread** - Warm crusty bread with garlic butter\\n![Garlic Bread](https://files.stripe.com/links/MDB8YWNjdF8xTkJPTHVLVlpienc3UUE1fGZsX3Rlc3RfWTdIZTBkUjNZNFQ1ZEhSVG9nRnduY1pS00XVgLRRZD)\\n\\n**Soda** - A large bottle of cola\\n![Soda](https://files.stripe.com/links/MDB8YWNjdF8xTkJPTHVLVlpienc3UUE1fGZsX3Rlc3RfbDJxckJKMDRnT1dDc243OHlZNWNkZkY5006Xg07kHT)\\n\\nWhat would you like to order?", "user: I'd like 2 pepperoni pizzas and a soda.", "agent: Excellent! Let me get the pricing for the pepperoni pizzas first so you know the cost.", "user_confirmed_tool_run: ", diff --git a/scripts/run_worker.py b/scripts/run_worker.py index 864a7b8..5ba1c89 100644 --- a/scripts/run_worker.py +++ b/scripts/run_worker.py @@ -12,6 +12,7 @@ from activities.tool_activities import ( mcp_list_tools, ) from shared.config import TEMPORAL_TASK_QUEUE, get_temporal_client +from shared.mcp_client_manager import MCPClientManager from workflows.agent_goal_workflow import AgentGoalWorkflow @@ -23,11 +24,14 @@ async def main(): llm_model = os.environ.get("LLM_MODEL", "openai/gpt-4") print(f"Worker will use LLM model: {llm_model}") + # Create shared MCP client manager + mcp_client_manager = MCPClientManager() + # Create the client client = await get_temporal_client() - # Initialize the activities class - activities = ToolActivities() + # Initialize the activities class with injected manager + activities = ToolActivities(mcp_client_manager) print(f"ToolActivities initialized with LLM model: {llm_model}") # If using Ollama, pre-load the model to avoid cold start 
latency
@@ -54,25 +58,31 @@ async def main():
     print("Worker ready to process tasks!")
     logging.basicConfig(level=logging.INFO)
 
-    # Run the worker
-    with concurrent.futures.ThreadPoolExecutor(max_workers=100) as activity_executor:
-        worker = Worker(
-            client,
-            task_queue=TEMPORAL_TASK_QUEUE,
-            workflows=[AgentGoalWorkflow],
-            activities=[
-                activities.agent_validatePrompt,
-                activities.agent_toolPlanner,
-                activities.get_wf_env_vars,
-                activities.mcp_tool_activity,
-                dynamic_tool_activity,
-                mcp_list_tools,
-            ],
-            activity_executor=activity_executor,
-        )
+    # Run the worker with proper cleanup
+    try:
+        with concurrent.futures.ThreadPoolExecutor(
+            max_workers=100
+        ) as activity_executor:
+            worker = Worker(
+                client,
+                task_queue=TEMPORAL_TASK_QUEUE,
+                workflows=[AgentGoalWorkflow],
+                activities=[
+                    activities.agent_validatePrompt,
+                    activities.agent_toolPlanner,
+                    activities.get_wf_env_vars,
+                    activities.mcp_tool_activity,
+                    dynamic_tool_activity,
+                    mcp_list_tools,
+                ],
+                activity_executor=activity_executor,
+            )
 
-        print(f"Starting worker, connecting to task queue: {TEMPORAL_TASK_QUEUE}")
-        await worker.run()
+            print(f"Starting worker, connecting to task queue: {TEMPORAL_TASK_QUEUE}")
+            await worker.run()
+    finally:
+        # Cleanup MCP connections when worker shuts down
+        await mcp_client_manager.cleanup()
 
 
 if __name__ == "__main__":
diff --git a/shared/mcp_client_manager.py b/shared/mcp_client_manager.py
new file mode 100644
index 0000000..1305c49
--- /dev/null
+++ b/shared/mcp_client_manager.py
@@ -0,0 +1,188 @@
+import asyncio
+from contextlib import AsyncExitStack, asynccontextmanager
+from typing import TYPE_CHECKING, Any, Dict, Tuple
+
+from temporalio import activity
+
+from models.tool_definitions import MCPServerDefinition
+
+# Import MCP client libraries
+if TYPE_CHECKING:
+    from mcp import ClientSession, StdioServerParameters
+    from mcp.client.stdio import stdio_client
+else:
+    try:
+        from mcp import ClientSession, StdioServerParameters
+        from mcp.client.stdio import stdio_client
+    except ImportError:
+        # Fallback if MCP not installed
+        ClientSession = None
+        StdioServerParameters = None
+        stdio_client = None
+
+
+class MCPClientManager:
+    """Manages pooled MCP client connections for reuse across tool calls.
+
+    One initialized client session is kept per distinct server key; sessions
+    are created on first use by ``get_client`` and closed by ``cleanup``.
+    """
+
+    def __init__(self):
+        # Server key -> initialized client session handed out by get_client
+        self._clients: Dict[str, Any] = {}
+        # Server key -> (exit stack owning transport and session, read, write)
+        self._connections: Dict[str, Tuple[AsyncExitStack, Any, Any]] = {}
+        # Guards both dictionaries during lookup, creation and cleanup
+        self._lock = asyncio.Lock()
+
+    async def get_client(
+        self, server_def: MCPServerDefinition | Dict[str, Any] | None
+    ) -> Any:
+        """Return the pooled client for ``server_def``, creating it on first use.
+
+        Args:
+            server_def: Server definition object, its dict form, or None for
+                the default stdio server.
+
+        Returns:
+            The initialized client session stored under the server's key.
+        """
+        async with self._lock:
+            key = self._get_server_key(server_def)
+            if key not in self._clients:
+                await self._create_client(server_def, key)
+                activity.logger.info(
+                    f"Created new MCP client for {self._get_server_name(server_def)}"
+                )
+            else:
+                activity.logger.info(
+                    f"Reusing existing MCP client for {self._get_server_name(server_def)}"
+                )
+            return self._clients[key]
+
+    def _get_server_key(
+        self, server_def: MCPServerDefinition | Dict[str, Any] | None
+    ) -> str:
+        """Build the pool key ``name:command:arg1:arg2...`` for a server definition."""
+        if server_def is None:
+            return "default:python:server.py"
+
+        # Handle both MCPServerDefinition objects and dicts (from Temporal serialization)
+        if isinstance(server_def, dict):
+            name = server_def.get("name", "default")
+            command = server_def.get("command", "python")
+            args = server_def.get("args", ["server.py"])
+        else:
+            name = server_def.name
+            command = server_def.command
+            args = server_def.args
+
+        # args may be None or hold non-string items; join must not raise on either
+        return f"{name}:{command}:{':'.join(str(arg) for arg in args or [])}"
+
+    def _get_server_name(
+        self, server_def: MCPServerDefinition | Dict[str, Any] | None
+    ) -> str:
+        """Get server name for logging"""
+        if server_def is None:
+            return "default"
+
+        if isinstance(server_def, dict):
+            return server_def.get("name", "default")
+        else:
+            return server_def.name
+
+    def _build_connection(
+        self, server_def: MCPServerDefinition | Dict[str, Any] | None
+    ) -> Dict[str, Any]:
+        """Build connection parameters from MCPServerDefinition or dict"""
+        if server_def is None:
+            # Default to stdio connection with the main server
+            return {
+                "type": "stdio",
+                "command": "python",
+                "args": ["server.py"],
+                "env": {},
+            }
+
+        # Handle both MCPServerDefinition objects and dicts (from Temporal serialization)
+        if isinstance(server_def, dict):
+            return {
+                "type": server_def.get("connection_type", "stdio"),
+                "command": server_def.get("command", "python"),
+                "args": server_def.get("args", ["server.py"]),
+                "env": server_def.get("env", {}) or {},
+            }
+
+        return {
+            "type": server_def.connection_type,
+            "command": server_def.command,
+            "args": server_def.args,
+            "env": server_def.env or {},
+        }
+
+    @asynccontextmanager
+    async def _stdio_connection(self, command: str, args: list, env: dict):
+        """Create stdio connection to MCP server"""
+        if stdio_client is None:
+            raise Exception("MCP client libraries not available")
+
+        # Create server parameters
+        server_params = StdioServerParameters(command=command, args=args, env=env)
+
+        async with stdio_client(server_params) as (read, write):
+            yield read, write
+
+    async def _create_client(
+        self, server_def: MCPServerDefinition | Dict[str, Any] | None, key: str
+    ):
+        """Open, initialize and store a client session under ``key``.
+
+        Raises:
+            Exception: If the connection type is not "stdio" or the MCP client
+                libraries are not available.
+        """
+        connection = self._build_connection(server_def)
+
+        if connection["type"] != "stdio":
+            raise Exception(f"Unsupported connection type: {connection['type']}")
+
+        # The exit stack owns the transport and the session so a failure part-way
+        # through setup, or a later cleanup, unwinds them in reverse order
+        exit_stack = AsyncExitStack()
+        try:
+            read, write = await exit_stack.enter_async_context(
+                self._stdio_connection(
+                    command=connection.get("command", "python"),
+                    args=connection.get("args", ["server.py"]),
+                    env=connection.get("env", {}),
+                )
+            )
+
+            # ClientSession must be entered as an async context manager before
+            # initialize(): entering is what starts its message receive loop
+            session = await exit_stack.enter_async_context(ClientSession(read, write))
+            await session.initialize()
+        except BaseException:
+            # Do not leak the half-open stdio connection when setup fails
+            await exit_stack.aclose()
+            raise
+
+        # Store both the session and its owning exit stack for cleanup
+        self._clients[key] = session
+        self._connections[key] = (exit_stack, read, write)
+
+    async def cleanup(self):
+        """Close all connections gracefully"""
+        async with self._lock:
+            # Closing each exit stack closes the session first, then its transport
+            for exit_stack, _read, _write in self._connections.values():
+                try:
+                    await exit_stack.aclose()
+                except Exception as e:
+                    activity.logger.warning(f"Error closing MCP connection: {e}")
+
+            self._clients.clear()
+            self._connections.clear()
+            activity.logger.info("All MCP connections closed")
diff --git a/tools/food/setup/archive_food_products.py b/tools/food/setup/archive_food_products.py
index ef1ff46..2ee0287 100644
--- a/tools/food/setup/archive_food_products.py
+++ b/tools/food/setup/archive_food_products.py
@@ -1,4 +1,5 @@
 import os
+
 from dotenv import load_dotenv
 
 
diff --git a/tools/food/setup/create_stripe_products.py b/tools/food/setup/create_stripe_products.py
index b462bbc..91dc8f9 100644
--- a/tools/food/setup/create_stripe_products.py
+++ b/tools/food/setup/create_stripe_products.py
@@ -1,5 +1,6 @@
 import json
 import os
+
 from dotenv import load_dotenv
 
 