mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 14:08:08 +01:00
Model Context Protocol (MCP) support with new use case (#42)
* initial mcp * food ordering with mcp * prompt eng * splitting out goals and updating docs * a diff so I can get tests from codex * a diff so I can get tests from codex * oops, missing files * tests, file formatting * readme and setup updates * setup.md link fixes * readme change * readme change * readme change * stripe food setup script * single agent mode default * prompt engineering for better multi agent performance * performance should be greatly improved * Update goals/finance.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update activities/tool_activities.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * co-pilot PR suggested this change, and now fixed it * stronger wording around json format response * formatting * moved docs to dir * moved image assets under docs * cleanup env example, stripe guidance * cleanup --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
1811e4cf59
commit
5d55a9fe80
418
tests/test_mcp_integration.py
Normal file
418
tests/test_mcp_integration.py
Normal file
@@ -0,0 +1,418 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from collections import deque
|
||||
from typing import Sequence
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from temporalio import activity
|
||||
from temporalio.client import Client
|
||||
from temporalio.common import RawValue
|
||||
from temporalio.testing import ActivityEnvironment
|
||||
from temporalio.worker import Worker
|
||||
|
||||
from activities.tool_activities import _convert_args_types, mcp_list_tools
|
||||
from models.data_types import (
|
||||
AgentGoalWorkflowParams,
|
||||
CombinedInput,
|
||||
EnvLookupInput,
|
||||
EnvLookupOutput,
|
||||
ToolPromptInput,
|
||||
ValidationInput,
|
||||
ValidationResult,
|
||||
)
|
||||
from models.tool_definitions import AgentGoal, MCPServerDefinition, ToolDefinition
|
||||
from workflows.agent_goal_workflow import AgentGoalWorkflow
|
||||
from workflows.workflow_helpers import is_mcp_tool
|
||||
|
||||
|
||||
class DummySession:
    """Minimal async stand-in for an MCP client session used by the tests.

    It supports ``async with``, a no-op ``initialize()`` and a ``list_tools()``
    call that always reports the same two tools (``list_products`` and
    ``create_customer``).
    """

    async def __aenter__(self):
        # The session object itself is what ``async with ... as session`` binds.
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # Returning None (falsy) lets any exception from the body propagate.
        return None

    async def initialize(self):
        """Do nothing; there is no real handshake to perform."""

    async def list_tools(self):
        """Return an object whose ``tools`` attribute lists two fake tools.

        Each tool exposes ``name``, ``description`` (``"desc <name>"``) and an
        empty ``inputSchema`` dict.
        """
        # Local import keeps this block self-contained.
        from types import SimpleNamespace

        tools = [
            SimpleNamespace(name=name, description=f"desc {name}", inputSchema={})
            for name in ("list_products", "create_customer")
        ]
        return SimpleNamespace(tools=tools)
|
||||
|
||||
|
||||
def test_convert_args_types_basic():
    """String arguments are coerced to int/float/bool; other values pass through."""
    raw_args = {
        "count": "5",
        "price": "12.5",
        "flag_true": "true",
        "flag_false": "false",
        "name": "pizza",
        "already_int": 2,
    }

    converted = _convert_args_types(raw_args)

    # Numeric strings must change type, not merely compare equal.
    count = converted["count"]
    assert isinstance(count, int) and count == 5
    price = converted["price"]
    assert isinstance(price, float) and price == 12.5

    # Boolean strings become the real singletons.
    assert converted["flag_true"] is True
    assert converted["flag_false"] is False

    # Values that are not convertible, or not strings, are left alone.
    assert converted["name"] == "pizza"
    assert converted["already_int"] == 2
|
||||
|
||||
|
||||
def test_is_mcp_tool_identification():
    """``is_mcp_tool`` separates native tools from MCP tools for a goal."""
    # Fields shared by both goals; only id, tools and the MCP server differ.
    shared_fields = dict(
        category_tag="food",
        agent_name="agent",
        agent_friendly_description="",
        description="",
        starter_prompt="",
        example_conversation_history="",
    )

    mcp_goal = AgentGoal(
        id="g",
        tools=[ToolDefinition(name="AddToCart", description="", arguments=[])],
        mcp_server_definition=MCPServerDefinition(
            name="test", command="python", args=["server.py"]
        ),
        **shared_fields,
    )
    # A name that is not among the goal's native tools counts as MCP ...
    assert is_mcp_tool("list_products", mcp_goal) is True
    # ... while a declared native tool does not.
    assert is_mcp_tool("AddToCart", mcp_goal) is False

    plain_goal = AgentGoal(
        id="g2",
        tools=[],
        mcp_server_definition=None,
        **shared_fields,
    )
    # Without an MCP server definition nothing is an MCP tool.
    assert is_mcp_tool("list_products", plain_goal) is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_mcp_list_tools_success():
    """``mcp_list_tools`` reports success and honours the include filter."""
    from contextlib import asynccontextmanager

    @asynccontextmanager
    async def dummy_connection(command, args, env):
        # Stand-in for the stdio transport: yields a (read, write) pair of Nones.
        yield None, None

    mcp_server = MCPServerDefinition(name="test", command="python", args=["server.py"])

    build_patch = patch(
        "activities.tool_activities._build_connection", return_value={"type": "stdio"}
    )
    transport_patch = patch(
        "activities.tool_activities._stdio_connection", dummy_connection
    )
    session_patch = patch(
        "activities.tool_activities.ClientSession", lambda r, w: DummySession()
    )

    with build_patch, transport_patch, session_patch:
        env = ActivityEnvironment()
        outcome = await env.run(mcp_list_tools, mcp_server, ["list_products"])

        assert outcome["success"] is True
        # DummySession offers two tools; only the requested one should remain.
        assert outcome["filtered_count"] == 1
        assert "list_products" in outcome["tools"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_mcp_list_tools_failure():
    """A connection error is reported in the result instead of being raised."""
    from contextlib import asynccontextmanager

    @asynccontextmanager
    async def failing_connection(*args, **kwargs):
        raise RuntimeError("conn fail")
        # Never reached, but the ``yield`` is what makes this an async
        # generator, which ``asynccontextmanager`` requires.
        yield None, None

    mcp_server = MCPServerDefinition(name="test", command="python", args=["server.py"])

    build_patch = patch(
        "activities.tool_activities._build_connection", return_value={"type": "stdio"}
    )
    transport_patch = patch(
        "activities.tool_activities._stdio_connection", failing_connection
    )

    with build_patch, transport_patch:
        env = ActivityEnvironment()
        outcome = await env.run(mcp_list_tools, mcp_server)

        assert outcome["success"] is False
        assert "conn fail" in outcome["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_workflow_loads_mcp_tools_dynamically(client: Client):
    """Workflow should load MCP tools and add them to the goal."""
    queue_name = str(uuid.uuid4())
    mcp_server = MCPServerDefinition(name="test", command="python", args=["srv.py"])
    # The goal starts with no tools; the test checks one appears afterwards.
    starting_goal = AgentGoal(
        id="g_mcp",
        category_tag="food",
        agent_name="agent",
        agent_friendly_description="",
        description="",
        tools=[],
        starter_prompt="",
        example_conversation_history="",
        mcp_server_definition=mcp_server,
    )
    workflow_input = CombinedInput(
        agent_goal=starting_goal,
        tool_params=AgentGoalWorkflowParams(
            conversation_summary=None, prompt_queue=deque()
        ),
    )

    @activity.defn(name="get_wf_env_vars")
    async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
        return EnvLookupOutput(show_confirm=True, multi_goal_mode=True)

    @activity.defn(name="mcp_list_tools")
    async def mock_mcp_list_tools(
        server_definition: MCPServerDefinition, include_tools=None
    ):
        # Canned listing with a single tool.
        product_tool = {
            "name": "list_products",
            "description": "",
            "inputSchema": {},
        }
        return {
            "server_name": server_definition.name,
            "success": True,
            "tools": {"list_products": product_tool},
            "total_available": 1,
            "filtered_count": 1,
        }

    def has_list_products(goal) -> bool:
        return any(tool.name == "list_products" for tool in goal.tools)

    async with Worker(
        client,
        task_queue=queue_name,
        workflows=[AgentGoalWorkflow],
        activities=[mock_get_wf_env_vars, mock_mcp_list_tools],
    ):
        handle = await client.start_workflow(
            AgentGoalWorkflow.run,
            workflow_input,
            id=str(uuid.uuid4()),
            task_queue=queue_name,
        )

        # Wait until the MCP tools have been added: poll, pausing 0.1s between
        # attempts, for at most ten pauses followed by one final query.
        pauses_left = 10
        while True:
            updated_goal = await handle.query(AgentGoalWorkflow.get_agent_goal)
            if has_list_products(updated_goal) or pauses_left == 0:
                break
            pauses_left -= 1
            await asyncio.sleep(0.1)

        assert has_list_products(updated_goal)

        await handle.signal(AgentGoalWorkflow.end_chat)
        await handle.result()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_mcp_tool_execution_flow(client: Client):
    """MCP tool execution should pass server_definition to activity."""
    task_queue_name = str(uuid.uuid4())
    server_def = MCPServerDefinition(name="test", command="python", args=["srv.py"])
    goal = AgentGoal(
        id="g_mcp_exec",
        category_tag="food",
        agent_name="agent",
        agent_friendly_description="",
        description="",
        tools=[],
        starter_prompt="",
        example_conversation_history="",
        mcp_server_definition=server_def,
    )
    combined_input = CombinedInput(
        agent_goal=goal,
        tool_params=AgentGoalWorkflowParams(
            conversation_summary=None, prompt_queue=deque()
        ),
    )

    # Shared scratch space the mock activities write into.
    captured: dict = {}

    @activity.defn(name="get_wf_env_vars")
    async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
        return EnvLookupOutput(show_confirm=True, multi_goal_mode=True)

    @activity.defn(name="agent_validatePrompt")
    async def mock_validate(prompt: ValidationInput) -> ValidationResult:
        return ValidationResult(validationResult=True, validationFailedReason={})

    @activity.defn(name="agent_toolPlanner")
    async def mock_planner(input: ToolPromptInput) -> dict:
        # First call asks to run the MCP tool; every later call reports "done".
        if "planner_called" not in captured:
            captured["planner_called"] = True
            return {
                "next": "confirm",
                "tool": "list_products",
                "args": {"limit": "5"},
                "response": "Listing products",
            }
        return {"next": "done", "response": "done"}

    @activity.defn(name="mcp_list_tools")
    async def mock_mcp_list_tools(
        server_definition: MCPServerDefinition, include_tools=None
    ):
        return {
            "server_name": server_definition.name,
            "success": True,
            "tools": {
                "list_products": {
                    "name": "list_products",
                    "description": "",
                    "inputSchema": {},
                },
            },
            "total_available": 1,
            "filtered_count": 1,
        }

    @activity.defn(name="dynamic_tool_activity", dynamic=True)
    async def mock_dynamic_tool_activity(args: Sequence[RawValue]) -> dict:
        # Decode the first raw argument so the test can inspect what was sent.
        payload = activity.payload_converter().from_payload(args[0].payload, dict)
        captured["dynamic_args"] = payload
        return {"tool": "list_products", "success": True, "content": {"ok": True}}

    async with Worker(
        client,
        task_queue=task_queue_name,
        workflows=[AgentGoalWorkflow],
        activities=[
            mock_get_wf_env_vars,
            mock_validate,
            mock_planner,
            mock_mcp_list_tools,
            mock_dynamic_tool_activity,
        ],
    ):
        handle = await client.start_workflow(
            AgentGoalWorkflow.run,
            combined_input,
            id=str(uuid.uuid4()),
            task_queue=task_queue_name,
        )

        await handle.signal(AgentGoalWorkflow.user_prompt, "show menu")
        # Give the workflow time to run the planner before confirming.
        await asyncio.sleep(0.5)
        await handle.signal(AgentGoalWorkflow.confirm)
        # No extra sleep is needed here: awaiting the result already blocks
        # until the workflow has finished.
        await handle.result()

        assert "dynamic_args" in captured
        assert "server_definition" in captured["dynamic_args"]
        assert captured["dynamic_args"]["server_definition"]["name"] == server_def.name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_mcp_tool_failure_recorded(client: Client):
    """Failure of an MCP tool should be recorded in conversation history."""
    # Function-scope imports keep this block self-contained.
    import ast
    import json

    task_queue_name = str(uuid.uuid4())
    server_def = MCPServerDefinition(name="test", command="python", args=["srv.py"])
    goal = AgentGoal(
        id="g_mcp_fail",
        category_tag="food",
        agent_name="agent",
        agent_friendly_description="",
        description="",
        tools=[],
        starter_prompt="",
        example_conversation_history="",
        mcp_server_definition=server_def,
    )
    combined_input = CombinedInput(
        agent_goal=goal,
        tool_params=AgentGoalWorkflowParams(
            conversation_summary=None, prompt_queue=deque()
        ),
    )

    @activity.defn(name="get_wf_env_vars")
    async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
        return EnvLookupOutput(show_confirm=True, multi_goal_mode=True)

    @activity.defn(name="agent_validatePrompt")
    async def mock_validate(prompt: ValidationInput) -> ValidationResult:
        return ValidationResult(validationResult=True, validationFailedReason={})

    @activity.defn(name="agent_toolPlanner")
    async def mock_planner(input: ToolPromptInput) -> dict:
        # Always asks to run the MCP tool; the test ends the chat explicitly.
        return {
            "next": "confirm",
            "tool": "list_products",
            "args": {},
            "response": "Listing products",
        }

    @activity.defn(name="mcp_list_tools")
    async def mock_mcp_list_tools(
        server_definition: MCPServerDefinition, include_tools=None
    ):
        return {
            "server_name": server_definition.name,
            "success": True,
            "tools": {
                "list_products": {
                    "name": "list_products",
                    "description": "",
                    "inputSchema": {},
                },
            },
            "total_available": 1,
            "filtered_count": 1,
        }

    @activity.defn(name="dynamic_tool_activity", dynamic=True)
    async def failing_dynamic_tool(args: Sequence[RawValue]) -> dict:
        # Simulates a tool call that completes but reports failure.
        return {
            "tool": "list_products",
            "success": False,
            "error": "Connection timed out",
        }

    async with Worker(
        client,
        task_queue=task_queue_name,
        workflows=[AgentGoalWorkflow],
        activities=[
            mock_get_wf_env_vars,
            mock_validate,
            mock_planner,
            mock_mcp_list_tools,
            failing_dynamic_tool,
        ],
    ):
        handle = await client.start_workflow(
            AgentGoalWorkflow.run,
            combined_input,
            id=str(uuid.uuid4()),
            task_queue=task_queue_name,
        )

        await handle.signal(AgentGoalWorkflow.user_prompt, "show menu")
        await asyncio.sleep(0.5)
        await handle.signal(AgentGoalWorkflow.confirm)
        # Give workflow time to record the failure result
        await asyncio.sleep(0.5)
        await handle.signal(AgentGoalWorkflow.end_chat)
        result = await handle.result()

    # The result is parsed as JSON first, then as a Python literal.
    # ``ast.literal_eval`` accepts only literals, so unlike ``eval`` it cannot
    # execute code, and it needs no quote rewriting.
    try:
        history = json.loads(result)
    except ValueError:
        history = ast.literal_eval(result)

    assert any(
        msg["actor"] == "tool_result" and not msg["response"].get("success", True)
        for msg in history["messages"]
    )
|
||||
@@ -6,7 +6,11 @@ import pytest
|
||||
from temporalio.client import Client
|
||||
from temporalio.testing import ActivityEnvironment
|
||||
|
||||
from activities.tool_activities import ToolActivities, dynamic_tool_activity
|
||||
from activities.tool_activities import (
|
||||
MCPServerDefinition,
|
||||
ToolActivities,
|
||||
dynamic_tool_activity,
|
||||
)
|
||||
from models.data_types import (
|
||||
EnvLookupInput,
|
||||
EnvLookupOutput,
|
||||
@@ -190,7 +194,7 @@ class TestToolActivities:
|
||||
|
||||
assert isinstance(result, EnvLookupOutput)
|
||||
assert result.show_confirm is True # default value
|
||||
assert result.multi_goal_mode is True # default value
|
||||
assert result.multi_goal_mode is False # default value (single agent mode)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_wf_env_vars_custom_values(self):
|
||||
@@ -443,3 +447,132 @@ class TestEdgeCases:
|
||||
)
|
||||
|
||||
assert result.show_confirm
|
||||
|
||||
|
||||
class TestMCPIntegration:
    """Tests for the MCP-related helpers and activities in ``tool_activities``."""

    @pytest.mark.asyncio
    async def test_convert_args_types(self):
        """String values are coerced to int/float/bool; others are unchanged."""
        from activities.tool_activities import _convert_args_types

        args = {
            "int_val": "123",
            "float_val": "123.45",
            "bool_true": "true",
            "bool_false": "False",
            "string": "text",
            "other": 5,
        }
        converted = _convert_args_types(args)
        assert converted["int_val"] == 123
        assert converted["float_val"] == 123.45
        assert converted["bool_true"] is True
        # Mixed-case "False" must be recognised as well.
        assert converted["bool_false"] is False
        assert converted["string"] == "text"
        assert converted["other"] == 5

    @pytest.mark.asyncio
    async def test_dynamic_tool_activity_mcp_call(self):
        """``dynamic_tool_activity`` succeeds when given a server definition."""
        from contextlib import asynccontextmanager

        mcp_def = MCPServerDefinition(
            name="stripe", command="python", args=["server.py"]
        )
        payload = MagicMock()
        payload.payload = b'{"server_definition": null, "amount": "10", "flag": "true"}'
        mock_info = MagicMock()
        mock_info.activity_type = "list_products"

        @asynccontextmanager
        async def dummy_conn(*args, **kwargs):
            yield (None, None)

        class DummySession:
            async def __aenter__(self):
                return self

            async def __aexit__(self, exc_type, exc, tb):
                pass

            async def initialize(self):
                pass

            async def call_tool(self, tool_name, arguments=None):
                # Record what was requested, then return a canned response.
                self.called_tool = tool_name
                self.called_args = arguments
                return MagicMock(content="ok")

        # The patched converter supplies the decoded arguments, so the raw
        # bytes set on ``payload`` above are not what the activity sees.
        mock_payload_converter = MagicMock()
        mock_payload_converter.from_payload.return_value = {
            "server_definition": mcp_def,
            "amount": "10",
            "flag": "true",
        }

        with patch("activities.tool_activities._stdio_connection", dummy_conn), patch(
            "activities.tool_activities.ClientSession", return_value=DummySession()
        ), patch(
            "activities.tool_activities._build_connection",
            return_value={
                "type": "stdio",
                "command": "python",
                "args": ["server.py"],
                "env": {},
            },
        ), patch(
            "temporalio.activity.info", return_value=mock_info
        ), patch(
            "temporalio.activity.payload_converter", return_value=mock_payload_converter
        ):
            result = await ActivityEnvironment().run(dynamic_tool_activity, [payload])

        assert result["success"] is True
        assert result["tool"] == "list_products"

    @pytest.mark.asyncio
    async def test_mcp_tool_activity_failure(self):
        """An exception from ``call_tool`` is reported as a failed result."""
        from contextlib import asynccontextmanager

        tool_activities = ToolActivities()
        mcp_def = MCPServerDefinition(
            name="stripe", command="python", args=["server.py"]
        )

        # This must be an async context manager itself. An ``async def`` that
        # merely returns one hands back a coroutine, and ``async with`` on a
        # coroutine raises its own TypeError (Python 3.11+), which would
        # satisfy the assertion below without ``call_tool`` ever running.
        @asynccontextmanager
        async def dummy_conn(*args, **kwargs):
            yield (None, None)

        class DummySession:
            async def __aenter__(self):
                return self

            async def __aexit__(self, exc_type, exc, tb):
                pass

            async def initialize(self):
                pass

            async def call_tool(self, tool_name, arguments=None):
                raise TypeError("boom")

        with patch("activities.tool_activities._stdio_connection", dummy_conn), patch(
            "activities.tool_activities.ClientSession", return_value=DummySession()
        ), patch(
            "activities.tool_activities._build_connection",
            return_value={
                "type": "stdio",
                "command": "python",
                "args": ["server.py"],
                "env": {},
            },
        ):
            result = await ActivityEnvironment().run(
                tool_activities.mcp_tool_activity,
                "list_products",
                {"server_definition": mcp_def, "amount": "10"},
            )

        assert result["success"] is False
        assert result["error_type"] == "TypeError"
|
||||
|
||||
36
tests/test_workflow_helpers.py
Normal file
36
tests/test_workflow_helpers.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import pytest
|
||||
|
||||
from models.tool_definitions import (
|
||||
AgentGoal,
|
||||
MCPServerDefinition,
|
||||
ToolArgument,
|
||||
ToolDefinition,
|
||||
)
|
||||
from workflows.workflow_helpers import is_mcp_tool
|
||||
|
||||
|
||||
def make_goal(with_mcp: bool) -> AgentGoal:
    """Build a test goal with one native tool, ``AddToCart``.

    Args:
        with_mcp: When true, attach a "stripe" MCP server definition;
            otherwise the goal has no MCP server.
    """
    mcp_server = (
        MCPServerDefinition(name="stripe", command="python", args=["server.py"])
        if with_mcp
        else None
    )
    native_tools = [ToolDefinition(name="AddToCart", description="", arguments=[])]
    return AgentGoal(
        id="g",
        category_tag="test",
        agent_name="Test",
        agent_friendly_description="",
        tools=native_tools,
        mcp_server_definition=mcp_server,
    )
|
||||
|
||||
|
||||
def test_is_mcp_tool_recognizes_native():
    """A tool declared on the goal is not treated as an MCP tool."""
    mcp_enabled_goal = make_goal(True)
    is_mcp = is_mcp_tool("AddToCart", mcp_enabled_goal)
    assert not is_mcp
|
||||
|
||||
|
||||
def test_is_mcp_tool_recognizes_mcp():
    """A tool absent from the goal's native tools is treated as an MCP tool."""
    mcp_enabled_goal = make_goal(True)
    is_mcp = is_mcp_tool("list_products", mcp_enabled_goal)
    assert is_mcp
|
||||
@@ -1,9 +1,11 @@
|
||||
import concurrent.futures
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
from temporalio import activity
|
||||
from temporalio.client import Client, WorkflowExecutionStatus
|
||||
from temporalio.common import RawValue
|
||||
from temporalio.worker import Worker
|
||||
|
||||
from api.main import get_initial_agent_goal
|
||||
@@ -16,6 +18,7 @@ from models.data_types import (
|
||||
ValidationInput,
|
||||
ValidationResult,
|
||||
)
|
||||
from models.tool_definitions import MCPServerDefinition
|
||||
from workflows.agent_goal_workflow import AgentGoalWorkflow
|
||||
|
||||
|
||||
@@ -53,6 +56,23 @@ async def test_flight_booking(client: Client):
|
||||
async def mock_agent_toolPlanner(input: ToolPromptInput) -> dict:
|
||||
return {"next": "done", "response": "Test response from LLM"}
|
||||
|
||||
@activity.defn(name="mcp_list_tools")
async def mock_mcp_list_tools(
    server_definition: MCPServerDefinition,
    include_tools: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Mock ``mcp_list_tools`` activity: a successful listing with no tools."""
    # Both arguments are ignored; the same canned listing is always returned.
    empty_listing: Dict[str, Any] = {
        "success": True,
        "tools": {},
        "server_name": "test",
    }
    return empty_listing
|
||||
|
||||
@activity.defn(name="mcp_tool_activity")
async def mock_mcp_tool_activity(
    tool_name: str, tool_args: Dict[str, Any]
) -> Dict[str, Any]:
    """Mock ``mcp_tool_activity``: always reports a successful tool call."""
    # The tool name and arguments are ignored.
    canned_result: Dict[str, Any] = {
        "success": True,
        "result": "Mock MCP tool result",
    }
    return canned_result
|
||||
|
||||
@activity.defn(name="dynamic_tool_activity", dynamic=True)
async def mock_dynamic_tool_activity(args: Sequence[RawValue]) -> dict:
    """Mock dynamic activity: always reports a successful tool call."""
    # The raw arguments are never decoded.
    canned_result = {"success": True, "result": "Mock dynamic tool result"}
    return canned_result
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=100
|
||||
) as activity_executor:
|
||||
@@ -64,6 +84,9 @@ async def test_flight_booking(client: Client):
|
||||
mock_get_wf_env_vars,
|
||||
mock_agent_validatePrompt,
|
||||
mock_agent_toolPlanner,
|
||||
mock_mcp_list_tools,
|
||||
mock_mcp_tool_activity,
|
||||
mock_dynamic_tool_activity,
|
||||
],
|
||||
activity_executor=activity_executor,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user