mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-17 06:58:09 +01:00
Model Context Protocol (MCP) support with new use case (#42)
* initial mcp * food ordering with mcp * prompt eng * splitting out goals and updating docs * a diff so I can get tests from codex * a diff so I can get tests from codex * oops, missing files * tests, file formatting * readme and setup updates * setup.md link fixes * readme change * readme change * readme change * stripe food setup script * single agent mode default * prompt engineering for better multi agent performance * performance should be greatly improved * Update goals/finance.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update activities/tool_activities.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * co-pilot PR suggested this change, and now fixed it * stronger wording around json format response * formatting * moved docs to dir * moved image assets under docs * cleanup env example, stripe guidance * cleanup --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
1811e4cf59
commit
5d55a9fe80
@@ -1,13 +1,15 @@
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from typing import Sequence
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from litellm import completion
|
||||
from temporalio import activity
|
||||
from temporalio.common import RawValue
|
||||
from temporalio.exceptions import ApplicationError
|
||||
|
||||
from models.data_types import (
|
||||
EnvLookupInput,
|
||||
@@ -16,6 +18,17 @@ from models.data_types import (
|
||||
ValidationInput,
|
||||
ValidationResult,
|
||||
)
|
||||
from models.tool_definitions import MCPServerDefinition
|
||||
|
||||
# Import MCP client libraries
|
||||
try:
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
except ImportError:
|
||||
# Fallback if MCP not installed
|
||||
ClientSession = None
|
||||
StdioServerParameters = None
|
||||
stdio_client = None
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -120,10 +133,16 @@ class ToolActivities:
|
||||
response = completion(**completion_kwargs)
|
||||
|
||||
response_content = response.choices[0].message.content
|
||||
activity.logger.info(f"LLM response: {response_content}")
|
||||
activity.logger.info(f"Raw LLM response: {repr(response_content)}")
|
||||
activity.logger.info(f"LLM response content: {response_content}")
|
||||
activity.logger.info(f"LLM response type: {type(response_content)}")
|
||||
activity.logger.info(
|
||||
f"LLM response length: {len(response_content) if response_content else 'None'}"
|
||||
)
|
||||
|
||||
# Use the new sanitize function
|
||||
response_content = self.sanitize_json_response(response_content)
|
||||
activity.logger.info(f"Sanitized response: {repr(response_content)}")
|
||||
|
||||
return self.parse_json_response(response_content)
|
||||
except Exception as e:
|
||||
@@ -159,7 +178,7 @@ class ToolActivities:
|
||||
handles default/None
|
||||
"""
|
||||
output: EnvLookupOutput = EnvLookupOutput(
|
||||
show_confirm=input.show_confirm_default, multi_goal_mode=True
|
||||
show_confirm=input.show_confirm_default, multi_goal_mode=False
|
||||
)
|
||||
show_confirm_value = os.getenv(input.show_confirm_env_var_name)
|
||||
if show_confirm_value is None:
|
||||
@@ -171,17 +190,29 @@ class ToolActivities:
|
||||
|
||||
first_goal_value = os.getenv("AGENT_GOAL")
|
||||
if first_goal_value is None:
|
||||
output.multi_goal_mode = True # default if unset
|
||||
output.multi_goal_mode = False # default to single agent mode if unset
|
||||
elif (
|
||||
first_goal_value is not None
|
||||
and first_goal_value.lower() != "goal_choose_agent_type"
|
||||
and first_goal_value.lower() == "goal_choose_agent_type"
|
||||
):
|
||||
output.multi_goal_mode = False
|
||||
else:
|
||||
output.multi_goal_mode = True
|
||||
else:
|
||||
output.multi_goal_mode = False
|
||||
|
||||
return output
|
||||
|
||||
@activity.defn
|
||||
async def mcp_tool_activity(
|
||||
self, tool_name: str, tool_args: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""MCP Tool"""
|
||||
activity.logger.info(f"Executing MCP tool: {tool_name} with args: {tool_args}")
|
||||
|
||||
# Extract server definition
|
||||
server_definition = tool_args.pop("server_definition", None)
|
||||
|
||||
return await _execute_mcp_tool(tool_name, tool_args, server_definition)
|
||||
|
||||
|
||||
@activity.defn(dynamic=True)
|
||||
async def dynamic_tool_activity(args: Sequence[RawValue]) -> dict:
|
||||
@@ -191,13 +222,246 @@ async def dynamic_tool_activity(args: Sequence[RawValue]) -> dict:
|
||||
tool_args = activity.payload_converter().from_payload(args[0].payload, dict)
|
||||
activity.logger.info(f"Running dynamic tool '{tool_name}' with args: {tool_args}")
|
||||
|
||||
# Delegate to the relevant function
|
||||
handler = get_handler(tool_name)
|
||||
if inspect.iscoroutinefunction(handler):
|
||||
result = await handler(tool_args)
|
||||
else:
|
||||
result = handler(tool_args)
|
||||
# Check if this is an MCP tool call by looking for server_definition in args
|
||||
server_definition = tool_args.pop("server_definition", None)
|
||||
|
||||
# Optionally log or augment the result
|
||||
activity.logger.info(f"Tool '{tool_name}' result: {result}")
|
||||
if server_definition:
|
||||
# This is an MCP tool call - handle it directly
|
||||
activity.logger.info(f"Executing MCP tool: {tool_name}")
|
||||
return await _execute_mcp_tool(tool_name, tool_args, server_definition)
|
||||
else:
|
||||
# This is a regular tool - delegate to the relevant function
|
||||
handler = get_handler(tool_name)
|
||||
if inspect.iscoroutinefunction(handler):
|
||||
result = await handler(tool_args)
|
||||
else:
|
||||
result = handler(tool_args)
|
||||
|
||||
# Optionally log or augment the result
|
||||
activity.logger.info(f"Tool '{tool_name}' result: {result}")
|
||||
return result
|
||||
|
||||
|
||||
# MCP Client Activities
|
||||
|
||||
|
||||
def _build_connection(
|
||||
server_definition: MCPServerDefinition | Dict[str, Any] | None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build connection parameters from MCPServerDefinition or dict"""
|
||||
if server_definition is None:
|
||||
# Default to stdio connection with the main server
|
||||
return {"type": "stdio", "command": "python", "args": ["server.py"], "env": {}}
|
||||
|
||||
# Handle both MCPServerDefinition objects and dicts (from Temporal serialization)
|
||||
if isinstance(server_definition, dict):
|
||||
return {
|
||||
"type": server_definition.get("connection_type", "stdio"),
|
||||
"command": server_definition.get("command", "python"),
|
||||
"args": server_definition.get("args", ["server.py"]),
|
||||
"env": server_definition.get("env", {}) or {},
|
||||
}
|
||||
|
||||
return {
|
||||
"type": server_definition.connection_type,
|
||||
"command": server_definition.command,
|
||||
"args": server_definition.args,
|
||||
"env": server_definition.env or {},
|
||||
}
|
||||
|
||||
|
||||
def _normalize_result(result: Any) -> Any:
|
||||
"""Normalize MCP tool result for serialization"""
|
||||
if hasattr(result, "content"):
|
||||
# Handle MCP result objects
|
||||
if hasattr(result.content, "__iter__") and not isinstance(result.content, str):
|
||||
return [
|
||||
item.text if hasattr(item, "text") else str(item)
|
||||
for item in result.content
|
||||
]
|
||||
return str(result.content)
|
||||
return result
|
||||
|
||||
|
||||
def _convert_args_types(tool_args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Convert string arguments to appropriate types for MCP tools"""
|
||||
converted_args = {}
|
||||
|
||||
for key, value in tool_args.items():
|
||||
if key == "server_definition":
|
||||
# Skip server_definition - it's metadata
|
||||
continue
|
||||
|
||||
if isinstance(value, str):
|
||||
# Try to convert string values to appropriate types
|
||||
if value.isdigit():
|
||||
# Convert numeric strings to integers
|
||||
converted_args[key] = int(value)
|
||||
elif value.replace(".", "").isdigit() and value.count(".") == 1:
|
||||
# Convert decimal strings to floats
|
||||
converted_args[key] = float(value)
|
||||
elif value.lower() in ("true", "false"):
|
||||
# Convert boolean strings
|
||||
converted_args[key] = value.lower() == "true"
|
||||
else:
|
||||
# Keep as string
|
||||
converted_args[key] = value
|
||||
else:
|
||||
# Keep non-string values as-is
|
||||
converted_args[key] = value
|
||||
|
||||
return converted_args
|
||||
|
||||
|
||||
async def _execute_mcp_tool(
|
||||
tool_name: str,
|
||||
tool_args: Dict[str, Any],
|
||||
server_definition: MCPServerDefinition | Dict[str, Any] | None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute an MCP tool with the given arguments and server definition"""
|
||||
activity.logger.info(f"Executing MCP tool: {tool_name}")
|
||||
|
||||
# Convert argument types for MCP tools
|
||||
converted_args = _convert_args_types(tool_args)
|
||||
connection = _build_connection(server_definition)
|
||||
|
||||
try:
|
||||
if connection["type"] == "stdio":
|
||||
# Handle stdio connection
|
||||
async with _stdio_connection(
|
||||
command=connection.get("command", "python"),
|
||||
args=connection.get("args", ["server.py"]),
|
||||
env=connection.get("env", {}),
|
||||
) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
# Initialize the session
|
||||
activity.logger.info(f"Initializing MCP session for {tool_name}")
|
||||
await session.initialize()
|
||||
activity.logger.info(f"MCP session initialized for {tool_name}")
|
||||
|
||||
# Call the tool
|
||||
activity.logger.info(
|
||||
f"Calling MCP tool {tool_name} with args: {converted_args}"
|
||||
)
|
||||
try:
|
||||
result = await session.call_tool(
|
||||
tool_name, arguments=converted_args
|
||||
)
|
||||
activity.logger.info(
|
||||
f"MCP tool {tool_name} returned result: {result}"
|
||||
)
|
||||
except Exception as tool_exc:
|
||||
activity.logger.error(
|
||||
f"MCP tool {tool_name} call failed: {type(tool_exc).__name__}: {tool_exc}"
|
||||
)
|
||||
raise
|
||||
|
||||
normalized_result = _normalize_result(result)
|
||||
activity.logger.info(f"MCP tool {tool_name} completed successfully")
|
||||
|
||||
return {
|
||||
"tool": tool_name,
|
||||
"success": True,
|
||||
"content": normalized_result,
|
||||
}
|
||||
|
||||
elif connection["type"] == "tcp":
|
||||
# Handle TCP connection (placeholder for future implementation)
|
||||
raise ApplicationError("TCP connections not yet implemented")
|
||||
|
||||
else:
|
||||
raise ApplicationError(f"Unsupported connection type: {connection['type']}")
|
||||
|
||||
except Exception as e:
|
||||
activity.logger.error(f"MCP tool {tool_name} failed: {str(e)}")
|
||||
|
||||
# Return error information
|
||||
return {
|
||||
"tool": tool_name,
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"error_type": type(e).__name__,
|
||||
}
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _stdio_connection(command: str, args: list, env: dict):
|
||||
"""Create stdio connection to MCP server"""
|
||||
if stdio_client is None:
|
||||
raise ApplicationError("MCP client libraries not available")
|
||||
|
||||
# Create server parameters
|
||||
server_params = StdioServerParameters(command=command, args=args, env=env)
|
||||
|
||||
async with stdio_client(server_params) as (read, write):
|
||||
yield read, write
|
||||
|
||||
|
||||
@activity.defn
|
||||
async def mcp_list_tools(
|
||||
server_definition: MCPServerDefinition, include_tools: Optional[List[str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""List available MCP tools from the specified server"""
|
||||
|
||||
activity.logger.info(f"Listing MCP tools for server: {server_definition.name}")
|
||||
|
||||
connection = _build_connection(server_definition)
|
||||
|
||||
try:
|
||||
if connection["type"] == "stdio":
|
||||
async with _stdio_connection(
|
||||
command=connection.get("command", "python"),
|
||||
args=connection.get("args", ["server.py"]),
|
||||
env=connection.get("env", {}),
|
||||
) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
# Initialize the session
|
||||
await session.initialize()
|
||||
|
||||
# List available tools
|
||||
tools_response = await session.list_tools()
|
||||
|
||||
# Process tools based on include_tools filter
|
||||
tools_info = {}
|
||||
for tool in tools_response.tools:
|
||||
# If include_tools is specified, only include those tools
|
||||
if include_tools is None or tool.name in include_tools:
|
||||
tools_info[tool.name] = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"inputSchema": (
|
||||
tool.inputSchema.model_dump()
|
||||
if hasattr(tool.inputSchema, "model_dump")
|
||||
else str(tool.inputSchema)
|
||||
),
|
||||
}
|
||||
|
||||
activity.logger.info(
|
||||
f"Found {len(tools_info)} tools for server {server_definition.name}"
|
||||
)
|
||||
|
||||
return {
|
||||
"server_name": server_definition.name,
|
||||
"success": True,
|
||||
"tools": tools_info,
|
||||
"total_available": len(tools_response.tools),
|
||||
"filtered_count": len(tools_info),
|
||||
}
|
||||
|
||||
elif connection["type"] == "tcp":
|
||||
raise ApplicationError("TCP connections not yet implemented")
|
||||
|
||||
else:
|
||||
raise ApplicationError(f"Unsupported connection type: {connection['type']}")
|
||||
|
||||
except Exception as e:
|
||||
activity.logger.error(
|
||||
f"Failed to list tools for server {server_definition.name}: {str(e)}"
|
||||
)
|
||||
|
||||
return {
|
||||
"server_name": server_definition.name,
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"error_type": type(e).__name__,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user