mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-16 22:48:09 +01:00
Enhance Dev Experience and Code Quality (#41)
* Format codebase to satisfy linters
* fixing pylance and ruff-checked files
* contributing md, and type and formatting fixes
* setup file capitalization
* test fix
This commit is contained in:
committed by
GitHub
parent
e35181b5ad
commit
eb06cf5c8d
@@ -63,8 +63,8 @@ async def client(env: WorkflowEnvironment) -> Client:
|
||||
@pytest.fixture
|
||||
def sample_agent_goal():
|
||||
"""Sample agent goal for testing."""
|
||||
from models.tool_definitions import AgentGoal, ToolDefinition, ToolArgument
|
||||
|
||||
from models.tool_definitions import AgentGoal, ToolArgument, ToolDefinition
|
||||
|
||||
return AgentGoal(
|
||||
id="test_goal",
|
||||
category_tag="test",
|
||||
@@ -77,13 +77,11 @@ def sample_agent_goal():
|
||||
description="A test tool for testing purposes",
|
||||
arguments=[
|
||||
ToolArgument(
|
||||
name="test_arg",
|
||||
type="string",
|
||||
description="A test argument"
|
||||
name="test_arg", type="string", description="A test argument"
|
||||
)
|
||||
]
|
||||
],
|
||||
)
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@@ -93,7 +91,7 @@ def sample_conversation_history():
|
||||
return {
|
||||
"messages": [
|
||||
{"actor": "user", "response": "Hello, I need help with testing"},
|
||||
{"actor": "agent", "response": "I can help you with that"}
|
||||
{"actor": "agent", "response": "I can help you with that"},
|
||||
]
|
||||
}
|
||||
|
||||
@@ -101,16 +99,13 @@ def sample_conversation_history():
|
||||
@pytest.fixture
|
||||
def sample_combined_input(sample_agent_goal):
|
||||
"""Sample combined input for workflow testing."""
|
||||
from models.data_types import CombinedInput, AgentGoalWorkflowParams
|
||||
|
||||
from collections import deque
|
||||
|
||||
|
||||
from models.data_types import AgentGoalWorkflowParams, CombinedInput
|
||||
|
||||
tool_params = AgentGoalWorkflowParams(
|
||||
conversation_summary="Test conversation summary",
|
||||
prompt_queue=deque() # Start with empty queue for most tests
|
||||
)
|
||||
|
||||
return CombinedInput(
|
||||
agent_goal=sample_agent_goal,
|
||||
tool_params=tool_params
|
||||
prompt_queue=deque(), # Start with empty queue for most tests
|
||||
)
|
||||
|
||||
return CombinedInput(agent_goal=sample_agent_goal, tool_params=tool_params)
|
||||
|
||||
@@ -1,40 +1,35 @@
|
||||
import uuid
|
||||
from unittest.mock import patch, MagicMock
|
||||
import pytest
|
||||
|
||||
from temporalio import activity
|
||||
from temporalio.client import Client
|
||||
from temporalio.worker import Worker
|
||||
from temporalio.testing import WorkflowEnvironment
|
||||
|
||||
from workflows.agent_goal_workflow import AgentGoalWorkflow
|
||||
from activities.tool_activities import ToolActivities
|
||||
from models.data_types import (
|
||||
CombinedInput,
|
||||
AgentGoalWorkflowParams,
|
||||
ConversationHistory,
|
||||
ValidationResult,
|
||||
ValidationInput,
|
||||
EnvLookupOutput,
|
||||
CombinedInput,
|
||||
EnvLookupInput,
|
||||
ToolPromptInput
|
||||
EnvLookupOutput,
|
||||
ToolPromptInput,
|
||||
ValidationInput,
|
||||
ValidationResult,
|
||||
)
|
||||
from workflows.agent_goal_workflow import AgentGoalWorkflow
|
||||
|
||||
|
||||
class TestAgentGoalWorkflow:
|
||||
"""Test cases for AgentGoalWorkflow."""
|
||||
|
||||
async def test_workflow_initialization(self, client: Client, sample_combined_input: CombinedInput):
|
||||
async def test_workflow_initialization(
|
||||
self, client: Client, sample_combined_input: CombinedInput
|
||||
):
|
||||
"""Test workflow can be initialized and started."""
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
|
||||
|
||||
# Create mock activity functions with proper signatures
|
||||
@activity.defn(name="get_wf_env_vars")
|
||||
async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
|
||||
return EnvLookupOutput(
|
||||
show_confirm=True,
|
||||
multi_goal_mode=True
|
||||
)
|
||||
|
||||
return EnvLookupOutput(show_confirm=True, multi_goal_mode=True)
|
||||
|
||||
async with Worker(
|
||||
client,
|
||||
task_queue=task_queue_name,
|
||||
@@ -48,120 +43,47 @@ class TestAgentGoalWorkflow:
|
||||
id=str(uuid.uuid4()),
|
||||
task_queue=task_queue_name,
|
||||
)
|
||||
|
||||
|
||||
# Verify workflow is running
|
||||
assert handle is not None
|
||||
|
||||
|
||||
# Query the workflow to check initial state
|
||||
conversation_history = await handle.query(AgentGoalWorkflow.get_conversation_history)
|
||||
conversation_history = await handle.query(
|
||||
AgentGoalWorkflow.get_conversation_history
|
||||
)
|
||||
assert isinstance(conversation_history, dict)
|
||||
assert "messages" in conversation_history
|
||||
|
||||
|
||||
# Test goal query
|
||||
agent_goal = await handle.query(AgentGoalWorkflow.get_agent_goal)
|
||||
assert agent_goal == sample_combined_input.agent_goal
|
||||
|
||||
|
||||
# End the workflow
|
||||
await handle.signal(AgentGoalWorkflow.end_chat)
|
||||
result = await handle.result()
|
||||
assert isinstance(result, str)
|
||||
|
||||
async def test_user_prompt_signal(self, client: Client, sample_combined_input: CombinedInput):
|
||||
async def test_user_prompt_signal(
|
||||
self, client: Client, sample_combined_input: CombinedInput
|
||||
):
|
||||
"""Test user_prompt signal handling."""
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
|
||||
# Create mock activity functions with proper signatures
|
||||
@activity.defn(name="get_wf_env_vars")
|
||||
async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
|
||||
return EnvLookupOutput(
|
||||
show_confirm=True,
|
||||
multi_goal_mode=True
|
||||
)
|
||||
|
||||
@activity.defn(name="agent_validatePrompt")
|
||||
async def mock_agent_validatePrompt(validation_input: ValidationInput) -> ValidationResult:
|
||||
return ValidationResult(
|
||||
validationResult=True,
|
||||
validationFailedReason={}
|
||||
)
|
||||
|
||||
@activity.defn(name="agent_toolPlanner")
|
||||
async def mock_agent_toolPlanner(input: ToolPromptInput) -> dict:
|
||||
return {
|
||||
"next": "done",
|
||||
"response": "Test response from LLM"
|
||||
}
|
||||
|
||||
async with Worker(
|
||||
client,
|
||||
task_queue=task_queue_name,
|
||||
workflows=[AgentGoalWorkflow],
|
||||
activities=[
|
||||
mock_get_wf_env_vars,
|
||||
mock_agent_validatePrompt,
|
||||
mock_agent_toolPlanner
|
||||
],
|
||||
):
|
||||
handle = await client.start_workflow(
|
||||
AgentGoalWorkflow.run,
|
||||
sample_combined_input,
|
||||
id=str(uuid.uuid4()),
|
||||
task_queue=task_queue_name,
|
||||
)
|
||||
|
||||
# Send user prompt
|
||||
await handle.signal(AgentGoalWorkflow.user_prompt, "Hello, this is a test message")
|
||||
|
||||
# Wait for workflow to complete (it should end due to "done" next step)
|
||||
result = await handle.result()
|
||||
assert isinstance(result, str)
|
||||
|
||||
# Verify the conversation includes our message
|
||||
import json
|
||||
try:
|
||||
conversation_history = json.loads(result.replace("'", '"'))
|
||||
except:
|
||||
# Fallback to eval if json fails
|
||||
conversation_history = eval(result)
|
||||
messages = conversation_history["messages"]
|
||||
|
||||
# Should have our user message and agent response
|
||||
user_messages = [msg for msg in messages if msg["actor"] == "user"]
|
||||
assert len(user_messages) > 0
|
||||
assert any("Hello, this is a test message" in str(msg["response"]) for msg in user_messages)
|
||||
|
||||
async def test_confirm_signal(self, client: Client, sample_combined_input: CombinedInput):
|
||||
"""Test confirm signal handling for tool execution."""
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
|
||||
# Create mock activity functions with proper signatures
|
||||
@activity.defn(name="get_wf_env_vars")
|
||||
async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
|
||||
return EnvLookupOutput(
|
||||
show_confirm=True,
|
||||
multi_goal_mode=True
|
||||
)
|
||||
|
||||
return EnvLookupOutput(show_confirm=True, multi_goal_mode=True)
|
||||
|
||||
@activity.defn(name="agent_validatePrompt")
|
||||
async def mock_agent_validatePrompt(validation_input: ValidationInput) -> ValidationResult:
|
||||
return ValidationResult(
|
||||
validationResult=True,
|
||||
validationFailedReason={}
|
||||
)
|
||||
|
||||
async def mock_agent_validatePrompt(
|
||||
validation_input: ValidationInput,
|
||||
) -> ValidationResult:
|
||||
return ValidationResult(validationResult=True, validationFailedReason={})
|
||||
|
||||
@activity.defn(name="agent_toolPlanner")
|
||||
async def mock_agent_toolPlanner(input: ToolPromptInput) -> dict:
|
||||
return {
|
||||
"next": "confirm",
|
||||
"tool": "TestTool",
|
||||
"args": {"test_arg": "test_value"},
|
||||
"response": "Ready to execute tool"
|
||||
}
|
||||
|
||||
@activity.defn(name="TestTool")
|
||||
async def mock_test_tool(args: dict) -> dict:
|
||||
return {"result": "Test tool executed successfully"}
|
||||
|
||||
return {"next": "done", "response": "Test response from LLM"}
|
||||
|
||||
async with Worker(
|
||||
client,
|
||||
task_queue=task_queue_name,
|
||||
@@ -170,7 +92,6 @@ class TestAgentGoalWorkflow:
|
||||
mock_get_wf_env_vars,
|
||||
mock_agent_validatePrompt,
|
||||
mock_agent_toolPlanner,
|
||||
mock_test_tool
|
||||
],
|
||||
):
|
||||
handle = await client.start_workflow(
|
||||
@@ -179,317 +100,64 @@ class TestAgentGoalWorkflow:
|
||||
id=str(uuid.uuid4()),
|
||||
task_queue=task_queue_name,
|
||||
)
|
||||
|
||||
# Send user prompt that will require confirmation
|
||||
await handle.signal(AgentGoalWorkflow.user_prompt, "Execute the test tool")
|
||||
|
||||
# Query to check tool data is set
|
||||
import asyncio
|
||||
await asyncio.sleep(0.1) # Give workflow time to process
|
||||
|
||||
tool_data = await handle.query(AgentGoalWorkflow.get_latest_tool_data)
|
||||
if tool_data:
|
||||
assert tool_data.get("tool") == "TestTool"
|
||||
assert tool_data.get("next") == "confirm"
|
||||
|
||||
# Send confirmation and end chat
|
||||
await handle.signal(AgentGoalWorkflow.confirm)
|
||||
await handle.signal(AgentGoalWorkflow.end_chat)
|
||||
|
||||
|
||||
# Send user prompt
|
||||
await handle.signal(
|
||||
AgentGoalWorkflow.user_prompt, "Hello, this is a test message"
|
||||
)
|
||||
|
||||
# Wait for workflow to complete (it should end due to "done" next step)
|
||||
result = await handle.result()
|
||||
assert isinstance(result, str)
|
||||
|
||||
async def test_validation_failure(self, client: Client, sample_combined_input: CombinedInput):
|
||||
"""Test workflow handles validation failures correctly."""
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
|
||||
# Create mock activity functions with proper signatures
|
||||
@activity.defn(name="get_wf_env_vars")
|
||||
async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
|
||||
return EnvLookupOutput(
|
||||
show_confirm=True,
|
||||
multi_goal_mode=True
|
||||
)
|
||||
|
||||
@activity.defn(name="agent_validatePrompt")
|
||||
async def mock_agent_validatePrompt(validation_input: ValidationInput) -> ValidationResult:
|
||||
return ValidationResult(
|
||||
validationResult=False,
|
||||
validationFailedReason={
|
||||
"next": "question",
|
||||
"response": "Your request doesn't make sense in this context"
|
||||
}
|
||||
)
|
||||
|
||||
async with Worker(
|
||||
client,
|
||||
task_queue=task_queue_name,
|
||||
workflows=[AgentGoalWorkflow],
|
||||
activities=[
|
||||
mock_get_wf_env_vars,
|
||||
mock_agent_validatePrompt
|
||||
],
|
||||
):
|
||||
handle = await client.start_workflow(
|
||||
AgentGoalWorkflow.run,
|
||||
sample_combined_input,
|
||||
id=str(uuid.uuid4()),
|
||||
task_queue=task_queue_name,
|
||||
)
|
||||
|
||||
# Send invalid prompt
|
||||
await handle.signal(AgentGoalWorkflow.user_prompt, "Invalid nonsensical prompt")
|
||||
|
||||
# Give workflow time to process the prompt
|
||||
import asyncio
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# End workflow to check conversation
|
||||
await handle.signal(AgentGoalWorkflow.end_chat)
|
||||
result = await handle.result()
|
||||
|
||||
# Verify validation failure message was added
|
||||
# Verify the conversation includes our message
|
||||
import json
|
||||
|
||||
try:
|
||||
conversation_history = json.loads(result.replace("'", '"'))
|
||||
except:
|
||||
except Exception:
|
||||
# Fallback to eval if json fails
|
||||
conversation_history = eval(result)
|
||||
messages = conversation_history["messages"]
|
||||
|
||||
# Should have validation failure response
|
||||
agent_messages = [msg for msg in messages if msg["actor"] == "agent"]
|
||||
assert len(agent_messages) > 0
|
||||
assert any("doesn't make sense" in str(msg["response"]) for msg in agent_messages)
|
||||
|
||||
async def test_conversation_summary_initialization(self, client: Client, sample_agent_goal):
|
||||
"""Test workflow initializes with conversation summary."""
|
||||
# Should have our user message and agent response
|
||||
user_messages = [msg for msg in messages if msg["actor"] == "user"]
|
||||
assert len(user_messages) > 0
|
||||
assert any(
|
||||
"Hello, this is a test message" in str(msg["response"])
|
||||
for msg in user_messages
|
||||
)
|
||||
|
||||
async def test_confirm_signal(
|
||||
self, client: Client, sample_combined_input: CombinedInput
|
||||
):
|
||||
"""Test confirm signal handling for tool execution."""
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
|
||||
# Create input with conversation summary
|
||||
from collections import deque
|
||||
tool_params = AgentGoalWorkflowParams(
|
||||
conversation_summary="Previous conversation summary",
|
||||
prompt_queue=deque()
|
||||
)
|
||||
combined_input = CombinedInput(
|
||||
agent_goal=sample_agent_goal,
|
||||
tool_params=tool_params
|
||||
)
|
||||
|
||||
|
||||
# Create mock activity functions with proper signatures
|
||||
@activity.defn(name="get_wf_env_vars")
|
||||
async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
|
||||
return EnvLookupOutput(
|
||||
show_confirm=True,
|
||||
multi_goal_mode=True
|
||||
)
|
||||
|
||||
async with Worker(
|
||||
client,
|
||||
task_queue=task_queue_name,
|
||||
workflows=[AgentGoalWorkflow],
|
||||
activities=[mock_get_wf_env_vars],
|
||||
):
|
||||
handle = await client.start_workflow(
|
||||
AgentGoalWorkflow.run,
|
||||
combined_input,
|
||||
id=str(uuid.uuid4()),
|
||||
task_queue=task_queue_name,
|
||||
)
|
||||
|
||||
# Give workflow time to initialize
|
||||
import asyncio
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Query conversation summary
|
||||
summary = await handle.query(AgentGoalWorkflow.get_summary_from_history)
|
||||
assert summary == "Previous conversation summary"
|
||||
|
||||
# Query conversation history - should include summary message
|
||||
conversation_history = await handle.query(AgentGoalWorkflow.get_conversation_history)
|
||||
messages = conversation_history["messages"]
|
||||
|
||||
# Should have conversation_summary message
|
||||
summary_messages = [msg for msg in messages if msg["actor"] == "conversation_summary"]
|
||||
assert len(summary_messages) == 1
|
||||
assert summary_messages[0]["response"] == "Previous conversation summary"
|
||||
|
||||
# End workflow
|
||||
await handle.signal(AgentGoalWorkflow.end_chat)
|
||||
await handle.result()
|
||||
return EnvLookupOutput(show_confirm=True, multi_goal_mode=True)
|
||||
|
||||
async def test_workflow_queries(self, client: Client, sample_combined_input: CombinedInput):
|
||||
"""Test all workflow query methods."""
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
|
||||
# Create mock activity functions with proper signatures
|
||||
@activity.defn(name="get_wf_env_vars")
|
||||
async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
|
||||
return EnvLookupOutput(
|
||||
show_confirm=True,
|
||||
multi_goal_mode=True
|
||||
)
|
||||
|
||||
async with Worker(
|
||||
client,
|
||||
task_queue=task_queue_name,
|
||||
workflows=[AgentGoalWorkflow],
|
||||
activities=[mock_get_wf_env_vars],
|
||||
):
|
||||
handle = await client.start_workflow(
|
||||
AgentGoalWorkflow.run,
|
||||
sample_combined_input,
|
||||
id=str(uuid.uuid4()),
|
||||
task_queue=task_queue_name,
|
||||
)
|
||||
|
||||
# Give workflow time to initialize
|
||||
import asyncio
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Test get_conversation_history query
|
||||
conversation_history = await handle.query(AgentGoalWorkflow.get_conversation_history)
|
||||
assert isinstance(conversation_history, dict)
|
||||
assert "messages" in conversation_history
|
||||
|
||||
# Test get_agent_goal query
|
||||
agent_goal = await handle.query(AgentGoalWorkflow.get_agent_goal)
|
||||
assert agent_goal.id == sample_combined_input.agent_goal.id
|
||||
|
||||
# Test get_summary_from_history query
|
||||
summary = await handle.query(AgentGoalWorkflow.get_summary_from_history)
|
||||
# Summary might be None if not set, so check for that
|
||||
if sample_combined_input.tool_params.conversation_summary:
|
||||
assert summary == sample_combined_input.tool_params.conversation_summary
|
||||
else:
|
||||
assert summary is None
|
||||
|
||||
# Test get_latest_tool_data query (should be None initially)
|
||||
tool_data = await handle.query(AgentGoalWorkflow.get_latest_tool_data)
|
||||
assert tool_data is None
|
||||
|
||||
# End workflow
|
||||
await handle.signal(AgentGoalWorkflow.end_chat)
|
||||
await handle.result()
|
||||
|
||||
async def test_enable_disable_debugging_confirm_signals(self, client: Client, sample_combined_input: CombinedInput):
|
||||
"""Test debugging confirm enable/disable signals."""
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
|
||||
# Create mock activity functions with proper signatures
|
||||
@activity.defn(name="get_wf_env_vars")
|
||||
async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
|
||||
return EnvLookupOutput(
|
||||
show_confirm=True,
|
||||
multi_goal_mode=True
|
||||
)
|
||||
|
||||
async with Worker(
|
||||
client,
|
||||
task_queue=task_queue_name,
|
||||
workflows=[AgentGoalWorkflow],
|
||||
activities=[mock_get_wf_env_vars],
|
||||
):
|
||||
handle = await client.start_workflow(
|
||||
AgentGoalWorkflow.run,
|
||||
sample_combined_input,
|
||||
id=str(uuid.uuid4()),
|
||||
task_queue=task_queue_name,
|
||||
)
|
||||
|
||||
# Test enable debugging confirm signal
|
||||
await handle.signal(AgentGoalWorkflow.enable_debugging_confirm)
|
||||
|
||||
# Test disable debugging confirm signal
|
||||
await handle.signal(AgentGoalWorkflow.disable_debugging_confirm)
|
||||
|
||||
# End workflow
|
||||
await handle.signal(AgentGoalWorkflow.end_chat)
|
||||
result = await handle.result()
|
||||
assert isinstance(result, str)
|
||||
|
||||
async def test_workflow_with_empty_prompt_queue(self, client: Client, sample_agent_goal):
|
||||
"""Test workflow behavior with empty prompt queue."""
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
|
||||
# Create input with empty prompt queue
|
||||
from collections import deque
|
||||
tool_params = AgentGoalWorkflowParams(
|
||||
conversation_summary=None,
|
||||
prompt_queue=deque()
|
||||
)
|
||||
combined_input = CombinedInput(
|
||||
agent_goal=sample_agent_goal,
|
||||
tool_params=tool_params
|
||||
)
|
||||
|
||||
# Create mock activity functions with proper signatures
|
||||
@activity.defn(name="get_wf_env_vars")
|
||||
async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
|
||||
return EnvLookupOutput(
|
||||
show_confirm=True,
|
||||
multi_goal_mode=True
|
||||
)
|
||||
|
||||
async with Worker(
|
||||
client,
|
||||
task_queue=task_queue_name,
|
||||
workflows=[AgentGoalWorkflow],
|
||||
activities=[mock_get_wf_env_vars],
|
||||
):
|
||||
handle = await client.start_workflow(
|
||||
AgentGoalWorkflow.run,
|
||||
combined_input,
|
||||
id=str(uuid.uuid4()),
|
||||
task_queue=task_queue_name,
|
||||
)
|
||||
|
||||
# Give workflow time to initialize
|
||||
import asyncio
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Query initial state
|
||||
conversation_history = await handle.query(AgentGoalWorkflow.get_conversation_history)
|
||||
assert isinstance(conversation_history, dict)
|
||||
assert "messages" in conversation_history
|
||||
|
||||
# Should have no messages initially (empty prompt queue, no summary)
|
||||
messages = conversation_history["messages"]
|
||||
assert len(messages) == 0
|
||||
|
||||
# End workflow
|
||||
await handle.signal(AgentGoalWorkflow.end_chat)
|
||||
result = await handle.result()
|
||||
assert isinstance(result, str)
|
||||
|
||||
async def test_multiple_user_prompts(self, client: Client, sample_combined_input: CombinedInput):
|
||||
"""Test workflow handling multiple user prompts in sequence."""
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
|
||||
# Create mock activity functions with proper signatures
|
||||
@activity.defn(name="get_wf_env_vars")
|
||||
async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
|
||||
return EnvLookupOutput(
|
||||
show_confirm=True,
|
||||
multi_goal_mode=True
|
||||
)
|
||||
|
||||
@activity.defn(name="agent_validatePrompt")
|
||||
async def mock_agent_validatePrompt(validation_input: ValidationInput) -> ValidationResult:
|
||||
return ValidationResult(
|
||||
validationResult=True,
|
||||
validationFailedReason={}
|
||||
)
|
||||
|
||||
async def mock_agent_validatePrompt(
|
||||
validation_input: ValidationInput,
|
||||
) -> ValidationResult:
|
||||
return ValidationResult(validationResult=True, validationFailedReason={})
|
||||
|
||||
@activity.defn(name="agent_toolPlanner")
|
||||
async def mock_agent_toolPlanner(input: ToolPromptInput) -> dict:
|
||||
# Keep workflow running for multiple prompts
|
||||
return {
|
||||
"next": "question",
|
||||
"response": f"Processed: {input.prompt}"
|
||||
"next": "confirm",
|
||||
"tool": "TestTool",
|
||||
"args": {"test_arg": "test_value"},
|
||||
"response": "Ready to execute tool",
|
||||
}
|
||||
|
||||
|
||||
@activity.defn(name="TestTool")
|
||||
async def mock_test_tool(args: dict) -> dict:
|
||||
return {"result": "Test tool executed successfully"}
|
||||
|
||||
async with Worker(
|
||||
client,
|
||||
task_queue=task_queue_name,
|
||||
@@ -497,7 +165,8 @@ class TestAgentGoalWorkflow:
|
||||
activities=[
|
||||
mock_get_wf_env_vars,
|
||||
mock_agent_validatePrompt,
|
||||
mock_agent_toolPlanner
|
||||
mock_agent_toolPlanner,
|
||||
mock_test_tool,
|
||||
],
|
||||
):
|
||||
handle = await client.start_workflow(
|
||||
@@ -506,35 +175,369 @@ class TestAgentGoalWorkflow:
|
||||
id=str(uuid.uuid4()),
|
||||
task_queue=task_queue_name,
|
||||
)
|
||||
|
||||
# Send multiple prompts
|
||||
await handle.signal(AgentGoalWorkflow.user_prompt, "First message")
|
||||
|
||||
# Send user prompt that will require confirmation
|
||||
await handle.signal(AgentGoalWorkflow.user_prompt, "Execute the test tool")
|
||||
|
||||
# Query to check tool data is set
|
||||
import asyncio
|
||||
|
||||
await asyncio.sleep(0.1) # Give workflow time to process
|
||||
|
||||
tool_data = await handle.query(AgentGoalWorkflow.get_latest_tool_data)
|
||||
if tool_data:
|
||||
assert tool_data.get("tool") == "TestTool"
|
||||
assert tool_data.get("next") == "confirm"
|
||||
|
||||
# Send confirmation and end chat
|
||||
await handle.signal(AgentGoalWorkflow.confirm)
|
||||
await handle.signal(AgentGoalWorkflow.end_chat)
|
||||
|
||||
result = await handle.result()
|
||||
assert isinstance(result, str)
|
||||
|
||||
async def test_validation_failure(
|
||||
self, client: Client, sample_combined_input: CombinedInput
|
||||
):
|
||||
"""Test workflow handles validation failures correctly."""
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
|
||||
# Create mock activity functions with proper signatures
|
||||
@activity.defn(name="get_wf_env_vars")
|
||||
async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
|
||||
return EnvLookupOutput(show_confirm=True, multi_goal_mode=True)
|
||||
|
||||
@activity.defn(name="agent_validatePrompt")
|
||||
async def mock_agent_validatePrompt(
|
||||
validation_input: ValidationInput,
|
||||
) -> ValidationResult:
|
||||
return ValidationResult(
|
||||
validationResult=False,
|
||||
validationFailedReason={
|
||||
"next": "question",
|
||||
"response": "Your request doesn't make sense in this context",
|
||||
},
|
||||
)
|
||||
|
||||
async with Worker(
|
||||
client,
|
||||
task_queue=task_queue_name,
|
||||
workflows=[AgentGoalWorkflow],
|
||||
activities=[mock_get_wf_env_vars, mock_agent_validatePrompt],
|
||||
):
|
||||
handle = await client.start_workflow(
|
||||
AgentGoalWorkflow.run,
|
||||
sample_combined_input,
|
||||
id=str(uuid.uuid4()),
|
||||
task_queue=task_queue_name,
|
||||
)
|
||||
|
||||
# Send invalid prompt
|
||||
await handle.signal(
|
||||
AgentGoalWorkflow.user_prompt, "Invalid nonsensical prompt"
|
||||
)
|
||||
|
||||
# Give workflow time to process the prompt
|
||||
import asyncio
|
||||
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# End workflow to check conversation
|
||||
await handle.signal(AgentGoalWorkflow.end_chat)
|
||||
result = await handle.result()
|
||||
|
||||
# Verify validation failure message was added
|
||||
import json
|
||||
|
||||
try:
|
||||
conversation_history = json.loads(result.replace("'", '"'))
|
||||
except Exception:
|
||||
# Fallback to eval if json fails
|
||||
conversation_history = eval(result)
|
||||
messages = conversation_history["messages"]
|
||||
|
||||
# Should have validation failure response
|
||||
agent_messages = [msg for msg in messages if msg["actor"] == "agent"]
|
||||
assert len(agent_messages) > 0
|
||||
assert any(
|
||||
"doesn't make sense" in str(msg["response"]) for msg in agent_messages
|
||||
)
|
||||
|
||||
async def test_conversation_summary_initialization(
|
||||
self, client: Client, sample_agent_goal
|
||||
):
|
||||
"""Test workflow initializes with conversation summary."""
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
|
||||
# Create input with conversation summary
|
||||
from collections import deque
|
||||
|
||||
tool_params = AgentGoalWorkflowParams(
|
||||
conversation_summary="Previous conversation summary", prompt_queue=deque()
|
||||
)
|
||||
combined_input = CombinedInput(
|
||||
agent_goal=sample_agent_goal, tool_params=tool_params
|
||||
)
|
||||
|
||||
# Create mock activity functions with proper signatures
|
||||
@activity.defn(name="get_wf_env_vars")
|
||||
async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
|
||||
return EnvLookupOutput(show_confirm=True, multi_goal_mode=True)
|
||||
|
||||
async with Worker(
|
||||
client,
|
||||
task_queue=task_queue_name,
|
||||
workflows=[AgentGoalWorkflow],
|
||||
activities=[mock_get_wf_env_vars],
|
||||
):
|
||||
handle = await client.start_workflow(
|
||||
AgentGoalWorkflow.run,
|
||||
combined_input,
|
||||
id=str(uuid.uuid4()),
|
||||
task_queue=task_queue_name,
|
||||
)
|
||||
|
||||
# Give workflow time to initialize
|
||||
import asyncio
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
await handle.signal(AgentGoalWorkflow.user_prompt, "Second message")
|
||||
|
||||
# Query conversation summary
|
||||
summary = await handle.query(AgentGoalWorkflow.get_summary_from_history)
|
||||
assert summary == "Previous conversation summary"
|
||||
|
||||
# Query conversation history - should include summary message
|
||||
conversation_history = await handle.query(
|
||||
AgentGoalWorkflow.get_conversation_history
|
||||
)
|
||||
messages = conversation_history["messages"]
|
||||
|
||||
# Should have conversation_summary message
|
||||
summary_messages = [
|
||||
msg for msg in messages if msg["actor"] == "conversation_summary"
|
||||
]
|
||||
assert len(summary_messages) == 1
|
||||
assert summary_messages[0]["response"] == "Previous conversation summary"
|
||||
|
||||
# End workflow
|
||||
await handle.signal(AgentGoalWorkflow.end_chat)
|
||||
await handle.result()
|
||||
|
||||
async def test_workflow_queries(
|
||||
self, client: Client, sample_combined_input: CombinedInput
|
||||
):
|
||||
"""Test all workflow query methods."""
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
|
||||
# Create mock activity functions with proper signatures
|
||||
@activity.defn(name="get_wf_env_vars")
|
||||
async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
|
||||
return EnvLookupOutput(show_confirm=True, multi_goal_mode=True)
|
||||
|
||||
async with Worker(
|
||||
client,
|
||||
task_queue=task_queue_name,
|
||||
workflows=[AgentGoalWorkflow],
|
||||
activities=[mock_get_wf_env_vars],
|
||||
):
|
||||
handle = await client.start_workflow(
|
||||
AgentGoalWorkflow.run,
|
||||
sample_combined_input,
|
||||
id=str(uuid.uuid4()),
|
||||
task_queue=task_queue_name,
|
||||
)
|
||||
|
||||
# Give workflow time to initialize
|
||||
import asyncio
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
await handle.signal(AgentGoalWorkflow.user_prompt, "Third message")
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
# Test get_conversation_history query
|
||||
conversation_history = await handle.query(
|
||||
AgentGoalWorkflow.get_conversation_history
|
||||
)
|
||||
assert isinstance(conversation_history, dict)
|
||||
assert "messages" in conversation_history
|
||||
|
||||
# Test get_agent_goal query
|
||||
agent_goal = await handle.query(AgentGoalWorkflow.get_agent_goal)
|
||||
assert agent_goal.id == sample_combined_input.agent_goal.id
|
||||
|
||||
# Test get_summary_from_history query
|
||||
summary = await handle.query(AgentGoalWorkflow.get_summary_from_history)
|
||||
# Summary might be None if not set, so check for that
|
||||
if sample_combined_input.tool_params.conversation_summary:
|
||||
assert summary == sample_combined_input.tool_params.conversation_summary
|
||||
else:
|
||||
assert summary is None
|
||||
|
||||
# Test get_latest_tool_data query (should be None initially)
|
||||
tool_data = await handle.query(AgentGoalWorkflow.get_latest_tool_data)
|
||||
assert tool_data is None
|
||||
|
||||
# End workflow
|
||||
await handle.signal(AgentGoalWorkflow.end_chat)
|
||||
await handle.result()
|
||||
|
||||
async def test_enable_disable_debugging_confirm_signals(
|
||||
self, client: Client, sample_combined_input: CombinedInput
|
||||
):
|
||||
"""Test debugging confirm enable/disable signals."""
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
|
||||
# Create mock activity functions with proper signatures
|
||||
@activity.defn(name="get_wf_env_vars")
|
||||
async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
|
||||
return EnvLookupOutput(show_confirm=True, multi_goal_mode=True)
|
||||
|
||||
async with Worker(
|
||||
client,
|
||||
task_queue=task_queue_name,
|
||||
workflows=[AgentGoalWorkflow],
|
||||
activities=[mock_get_wf_env_vars],
|
||||
):
|
||||
handle = await client.start_workflow(
|
||||
AgentGoalWorkflow.run,
|
||||
sample_combined_input,
|
||||
id=str(uuid.uuid4()),
|
||||
task_queue=task_queue_name,
|
||||
)
|
||||
|
||||
# Test enable debugging confirm signal
|
||||
await handle.signal(AgentGoalWorkflow.enable_debugging_confirm)
|
||||
|
||||
# Test disable debugging confirm signal
|
||||
await handle.signal(AgentGoalWorkflow.disable_debugging_confirm)
|
||||
|
||||
# End workflow
|
||||
await handle.signal(AgentGoalWorkflow.end_chat)
|
||||
result = await handle.result()
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
async def test_workflow_with_empty_prompt_queue(
|
||||
self, client: Client, sample_agent_goal
|
||||
):
|
||||
"""Test workflow behavior with empty prompt queue."""
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
|
||||
# Create input with empty prompt queue
|
||||
from collections import deque
|
||||
|
||||
tool_params = AgentGoalWorkflowParams(
|
||||
conversation_summary=None, prompt_queue=deque()
|
||||
)
|
||||
combined_input = CombinedInput(
|
||||
agent_goal=sample_agent_goal, tool_params=tool_params
|
||||
)
|
||||
|
||||
# Create mock activity functions with proper signatures
|
||||
@activity.defn(name="get_wf_env_vars")
|
||||
async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
|
||||
return EnvLookupOutput(show_confirm=True, multi_goal_mode=True)
|
||||
|
||||
async with Worker(
|
||||
client,
|
||||
task_queue=task_queue_name,
|
||||
workflows=[AgentGoalWorkflow],
|
||||
activities=[mock_get_wf_env_vars],
|
||||
):
|
||||
handle = await client.start_workflow(
|
||||
AgentGoalWorkflow.run,
|
||||
combined_input,
|
||||
id=str(uuid.uuid4()),
|
||||
task_queue=task_queue_name,
|
||||
)
|
||||
|
||||
# Give workflow time to initialize
|
||||
import asyncio
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Query initial state
|
||||
conversation_history = await handle.query(
|
||||
AgentGoalWorkflow.get_conversation_history
|
||||
)
|
||||
assert isinstance(conversation_history, dict)
|
||||
assert "messages" in conversation_history
|
||||
|
||||
# Should have no messages initially (empty prompt queue, no summary)
|
||||
messages = conversation_history["messages"]
|
||||
assert len(messages) == 0
|
||||
|
||||
# End workflow
|
||||
await handle.signal(AgentGoalWorkflow.end_chat)
|
||||
result = await handle.result()
|
||||
assert isinstance(result, str)
|
||||
|
||||
async def test_multiple_user_prompts(
|
||||
self, client: Client, sample_combined_input: CombinedInput
|
||||
):
|
||||
"""Test workflow handling multiple user prompts in sequence."""
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
|
||||
# Create mock activity functions with proper signatures
|
||||
@activity.defn(name="get_wf_env_vars")
|
||||
async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
|
||||
return EnvLookupOutput(show_confirm=True, multi_goal_mode=True)
|
||||
|
||||
@activity.defn(name="agent_validatePrompt")
|
||||
async def mock_agent_validatePrompt(
|
||||
validation_input: ValidationInput,
|
||||
) -> ValidationResult:
|
||||
return ValidationResult(validationResult=True, validationFailedReason={})
|
||||
|
||||
@activity.defn(name="agent_toolPlanner")
|
||||
async def mock_agent_toolPlanner(input: ToolPromptInput) -> dict:
|
||||
# Keep workflow running for multiple prompts
|
||||
return {"next": "question", "response": f"Processed: {input.prompt}"}
|
||||
|
||||
async with Worker(
|
||||
client,
|
||||
task_queue=task_queue_name,
|
||||
workflows=[AgentGoalWorkflow],
|
||||
activities=[
|
||||
mock_get_wf_env_vars,
|
||||
mock_agent_validatePrompt,
|
||||
mock_agent_toolPlanner,
|
||||
],
|
||||
):
|
||||
handle = await client.start_workflow(
|
||||
AgentGoalWorkflow.run,
|
||||
sample_combined_input,
|
||||
id=str(uuid.uuid4()),
|
||||
task_queue=task_queue_name,
|
||||
)
|
||||
|
||||
# Send multiple prompts
|
||||
await handle.signal(AgentGoalWorkflow.user_prompt, "First message")
|
||||
import asyncio
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
await handle.signal(AgentGoalWorkflow.user_prompt, "Second message")
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
await handle.signal(AgentGoalWorkflow.user_prompt, "Third message")
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# End workflow
|
||||
await handle.signal(AgentGoalWorkflow.end_chat)
|
||||
result = await handle.result()
|
||||
assert isinstance(result, str)
|
||||
|
||||
# Parse result and verify multiple messages
|
||||
import json
|
||||
|
||||
try:
|
||||
conversation_history = json.loads(result.replace("'", '"'))
|
||||
except:
|
||||
except Exception:
|
||||
conversation_history = eval(result)
|
||||
messages = conversation_history["messages"]
|
||||
|
||||
|
||||
# Should have at least one user message (timing dependent)
|
||||
user_messages = [msg for msg in messages if msg["actor"] == "user"]
|
||||
assert len(user_messages) >= 1
|
||||
|
||||
|
||||
# Verify at least the first message was processed
|
||||
message_texts = [str(msg["response"]) for msg in user_messages]
|
||||
assert any("First message" in text for text in message_texts)
|
||||
assert any("First message" in text for text in message_texts)
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
import os
|
||||
import uuid
|
||||
import json
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from temporalio.client import Client
|
||||
from temporalio.worker import Worker
|
||||
from temporalio.testing import ActivityEnvironment
|
||||
|
||||
from activities.tool_activities import ToolActivities, dynamic_tool_activity
|
||||
from models.data_types import (
|
||||
EnvLookupInput,
|
||||
EnvLookupOutput,
|
||||
ToolPromptInput,
|
||||
ValidationInput,
|
||||
ValidationResult,
|
||||
ToolPromptInput,
|
||||
EnvLookupInput,
|
||||
EnvLookupOutput
|
||||
)
|
||||
|
||||
|
||||
@@ -25,63 +24,66 @@ class TestToolActivities:
|
||||
self.tool_activities = ToolActivities()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_validatePrompt_valid_prompt(self, sample_agent_goal, sample_conversation_history):
|
||||
async def test_agent_validatePrompt_valid_prompt(
|
||||
self, sample_agent_goal, sample_conversation_history
|
||||
):
|
||||
"""Test agent_validatePrompt with a valid prompt."""
|
||||
validation_input = ValidationInput(
|
||||
prompt="I need help with the test tool",
|
||||
conversation_history=sample_conversation_history,
|
||||
agent_goal=sample_agent_goal
|
||||
agent_goal=sample_agent_goal,
|
||||
)
|
||||
|
||||
|
||||
# Mock the agent_toolPlanner to return a valid response
|
||||
mock_response = {
|
||||
"validationResult": True,
|
||||
"validationFailedReason": {}
|
||||
}
|
||||
|
||||
with patch.object(self.tool_activities, 'agent_toolPlanner', new_callable=AsyncMock) as mock_planner:
|
||||
mock_response = {"validationResult": True, "validationFailedReason": {}}
|
||||
|
||||
with patch.object(
|
||||
self.tool_activities, "agent_toolPlanner", new_callable=AsyncMock
|
||||
) as mock_planner:
|
||||
mock_planner.return_value = mock_response
|
||||
|
||||
|
||||
activity_env = ActivityEnvironment()
|
||||
result = await activity_env.run(
|
||||
self.tool_activities.agent_validatePrompt,
|
||||
validation_input
|
||||
self.tool_activities.agent_validatePrompt, validation_input
|
||||
)
|
||||
|
||||
|
||||
assert isinstance(result, ValidationResult)
|
||||
assert result.validationResult is True
|
||||
assert result.validationFailedReason == {}
|
||||
|
||||
|
||||
# Verify the mock was called with correct parameters
|
||||
mock_planner.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_validatePrompt_invalid_prompt(self, sample_agent_goal, sample_conversation_history):
|
||||
async def test_agent_validatePrompt_invalid_prompt(
|
||||
self, sample_agent_goal, sample_conversation_history
|
||||
):
|
||||
"""Test agent_validatePrompt with an invalid prompt."""
|
||||
validation_input = ValidationInput(
|
||||
prompt="asdfghjkl nonsense",
|
||||
conversation_history=sample_conversation_history,
|
||||
agent_goal=sample_agent_goal
|
||||
agent_goal=sample_agent_goal,
|
||||
)
|
||||
|
||||
|
||||
# Mock the agent_toolPlanner to return an invalid response
|
||||
mock_response = {
|
||||
"validationResult": False,
|
||||
"validationFailedReason": {
|
||||
"next": "question",
|
||||
"response": "Your request doesn't make sense in this context"
|
||||
}
|
||||
"response": "Your request doesn't make sense in this context",
|
||||
},
|
||||
}
|
||||
|
||||
with patch.object(self.tool_activities, 'agent_toolPlanner', new_callable=AsyncMock) as mock_planner:
|
||||
|
||||
with patch.object(
|
||||
self.tool_activities, "agent_toolPlanner", new_callable=AsyncMock
|
||||
) as mock_planner:
|
||||
mock_planner.return_value = mock_response
|
||||
|
||||
|
||||
activity_env = ActivityEnvironment()
|
||||
result = await activity_env.run(
|
||||
self.tool_activities.agent_validatePrompt,
|
||||
validation_input
|
||||
self.tool_activities.agent_validatePrompt, validation_input
|
||||
)
|
||||
|
||||
|
||||
assert isinstance(result, ValidationResult)
|
||||
assert result.validationResult is False
|
||||
assert "doesn't make sense" in str(result.validationFailedReason)
|
||||
@@ -90,29 +92,29 @@ class TestToolActivities:
|
||||
async def test_agent_toolPlanner_success(self):
|
||||
"""Test agent_toolPlanner with successful LLM response."""
|
||||
prompt_input = ToolPromptInput(
|
||||
prompt="Test prompt",
|
||||
context_instructions="Test context instructions"
|
||||
prompt="Test prompt", context_instructions="Test context instructions"
|
||||
)
|
||||
|
||||
|
||||
# Mock the completion function
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = '{"next": "confirm", "tool": "TestTool", "response": "Test response"}'
|
||||
|
||||
with patch('activities.tool_activities.completion') as mock_completion:
|
||||
mock_response.choices[0].message.content = (
|
||||
'{"next": "confirm", "tool": "TestTool", "response": "Test response"}'
|
||||
)
|
||||
|
||||
with patch("activities.tool_activities.completion") as mock_completion:
|
||||
mock_completion.return_value = mock_response
|
||||
|
||||
|
||||
activity_env = ActivityEnvironment()
|
||||
result = await activity_env.run(
|
||||
self.tool_activities.agent_toolPlanner,
|
||||
prompt_input
|
||||
self.tool_activities.agent_toolPlanner, prompt_input
|
||||
)
|
||||
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["next"] == "confirm"
|
||||
assert result["tool"] == "TestTool"
|
||||
assert result["response"] == "Test response"
|
||||
|
||||
|
||||
# Verify completion was called with correct parameters
|
||||
mock_completion.assert_called_once()
|
||||
call_args = mock_completion.call_args[1]
|
||||
@@ -125,27 +127,25 @@ class TestToolActivities:
|
||||
async def test_agent_toolPlanner_with_custom_base_url(self):
|
||||
"""Test agent_toolPlanner with custom base URL configuration."""
|
||||
# Set up tool activities with custom base URL
|
||||
with patch.dict(os.environ, {'LLM_BASE_URL': 'https://custom.endpoint.com'}):
|
||||
with patch.dict(os.environ, {"LLM_BASE_URL": "https://custom.endpoint.com"}):
|
||||
tool_activities = ToolActivities()
|
||||
|
||||
|
||||
prompt_input = ToolPromptInput(
|
||||
prompt="Test prompt",
|
||||
context_instructions="Test context instructions"
|
||||
prompt="Test prompt", context_instructions="Test context instructions"
|
||||
)
|
||||
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = '{"next": "done", "response": "Test"}'
|
||||
|
||||
with patch('activities.tool_activities.completion') as mock_completion:
|
||||
mock_response.choices[0].message.content = (
|
||||
'{"next": "done", "response": "Test"}'
|
||||
)
|
||||
|
||||
with patch("activities.tool_activities.completion") as mock_completion:
|
||||
mock_completion.return_value = mock_response
|
||||
|
||||
|
||||
activity_env = ActivityEnvironment()
|
||||
await activity_env.run(
|
||||
tool_activities.agent_toolPlanner,
|
||||
prompt_input
|
||||
)
|
||||
|
||||
await activity_env.run(tool_activities.agent_toolPlanner, prompt_input)
|
||||
|
||||
# Verify base_url was included in the call
|
||||
call_args = mock_completion.call_args[1]
|
||||
assert "base_url" in call_args
|
||||
@@ -155,41 +155,37 @@ class TestToolActivities:
|
||||
async def test_agent_toolPlanner_json_parsing_error(self):
|
||||
"""Test agent_toolPlanner handles JSON parsing errors."""
|
||||
prompt_input = ToolPromptInput(
|
||||
prompt="Test prompt",
|
||||
context_instructions="Test context instructions"
|
||||
prompt="Test prompt", context_instructions="Test context instructions"
|
||||
)
|
||||
|
||||
|
||||
# Mock the completion function to return invalid JSON
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = 'Invalid JSON response'
|
||||
|
||||
with patch('activities.tool_activities.completion') as mock_completion:
|
||||
mock_response.choices[0].message.content = "Invalid JSON response"
|
||||
|
||||
with patch("activities.tool_activities.completion") as mock_completion:
|
||||
mock_completion.return_value = mock_response
|
||||
|
||||
|
||||
activity_env = ActivityEnvironment()
|
||||
with pytest.raises(Exception): # Should raise JSON parsing error
|
||||
await activity_env.run(
|
||||
self.tool_activities.agent_toolPlanner,
|
||||
prompt_input
|
||||
self.tool_activities.agent_toolPlanner, prompt_input
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_wf_env_vars_default_values(self):
|
||||
"""Test get_wf_env_vars with default values."""
|
||||
env_input = EnvLookupInput(
|
||||
show_confirm_env_var_name="SHOW_CONFIRM",
|
||||
show_confirm_default=True
|
||||
show_confirm_env_var_name="SHOW_CONFIRM", show_confirm_default=True
|
||||
)
|
||||
|
||||
|
||||
# Clear environment variables
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
activity_env = ActivityEnvironment()
|
||||
result = await activity_env.run(
|
||||
self.tool_activities.get_wf_env_vars,
|
||||
env_input
|
||||
self.tool_activities.get_wf_env_vars, env_input
|
||||
)
|
||||
|
||||
|
||||
assert isinstance(result, EnvLookupOutput)
|
||||
assert result.show_confirm is True # default value
|
||||
assert result.multi_goal_mode is True # default value
|
||||
@@ -198,21 +194,18 @@ class TestToolActivities:
|
||||
async def test_get_wf_env_vars_custom_values(self):
|
||||
"""Test get_wf_env_vars with custom environment values."""
|
||||
env_input = EnvLookupInput(
|
||||
show_confirm_env_var_name="SHOW_CONFIRM",
|
||||
show_confirm_default=True
|
||||
show_confirm_env_var_name="SHOW_CONFIRM", show_confirm_default=True
|
||||
)
|
||||
|
||||
|
||||
# Set environment variables
|
||||
with patch.dict(os.environ, {
|
||||
'SHOW_CONFIRM': 'false',
|
||||
'AGENT_GOAL': 'specific_goal'
|
||||
}):
|
||||
with patch.dict(
|
||||
os.environ, {"SHOW_CONFIRM": "false", "AGENT_GOAL": "specific_goal"}
|
||||
):
|
||||
activity_env = ActivityEnvironment()
|
||||
result = await activity_env.run(
|
||||
self.tool_activities.get_wf_env_vars,
|
||||
env_input
|
||||
self.tool_activities.get_wf_env_vars, env_input
|
||||
)
|
||||
|
||||
|
||||
assert isinstance(result, EnvLookupOutput)
|
||||
assert result.show_confirm is False # from env var
|
||||
assert result.multi_goal_mode is False # from env var
|
||||
@@ -220,20 +213,22 @@ class TestToolActivities:
|
||||
def test_sanitize_json_response(self):
|
||||
"""Test JSON response sanitization."""
|
||||
# Test with markdown code blocks
|
||||
response_with_markdown = "```json\n{\"test\": \"value\"}\n```"
|
||||
response_with_markdown = '```json\n{"test": "value"}\n```'
|
||||
sanitized = self.tool_activities.sanitize_json_response(response_with_markdown)
|
||||
assert sanitized == '{"test": "value"}'
|
||||
|
||||
|
||||
# Test with extra whitespace
|
||||
response_with_whitespace = " \n{\"test\": \"value\"} \n"
|
||||
sanitized = self.tool_activities.sanitize_json_response(response_with_whitespace)
|
||||
response_with_whitespace = ' \n{"test": "value"} \n'
|
||||
sanitized = self.tool_activities.sanitize_json_response(
|
||||
response_with_whitespace
|
||||
)
|
||||
assert sanitized == '{"test": "value"}'
|
||||
|
||||
def test_parse_json_response_success(self):
|
||||
"""Test successful JSON parsing."""
|
||||
json_string = '{"next": "confirm", "tool": "TestTool"}'
|
||||
result = self.tool_activities.parse_json_response(json_string)
|
||||
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["next"] == "confirm"
|
||||
assert result["tool"] == "TestTool"
|
||||
@@ -241,7 +236,7 @@ class TestToolActivities:
|
||||
def test_parse_json_response_failure(self):
|
||||
"""Test JSON parsing with invalid JSON."""
|
||||
invalid_json = "Not valid JSON"
|
||||
|
||||
|
||||
with pytest.raises(Exception): # Should raise JSON parsing error
|
||||
self.tool_activities.parse_json_response(invalid_json)
|
||||
|
||||
@@ -255,26 +250,22 @@ class TestDynamicToolActivity:
|
||||
# Mock the activity info and payload converter
|
||||
mock_info = MagicMock()
|
||||
mock_info.activity_type = "TestTool"
|
||||
|
||||
|
||||
mock_payload_converter = MagicMock()
|
||||
mock_payload = MagicMock()
|
||||
mock_payload.payload = b'{"test_arg": "test_value"}'
|
||||
mock_payload_converter.from_payload.return_value = {"test_arg": "test_value"}
|
||||
|
||||
|
||||
# Mock the handler function
|
||||
def mock_handler(args):
|
||||
return {"result": f"Handled {args['test_arg']}"}
|
||||
|
||||
with patch('temporalio.activity.info', return_value=mock_info), \
|
||||
patch('temporalio.activity.payload_converter', return_value=mock_payload_converter), \
|
||||
patch('tools.get_handler', return_value=mock_handler):
|
||||
|
||||
|
||||
with patch("temporalio.activity.info", return_value=mock_info), patch(
|
||||
"temporalio.activity.payload_converter", return_value=mock_payload_converter
|
||||
), patch("tools.get_handler", return_value=mock_handler):
|
||||
activity_env = ActivityEnvironment()
|
||||
result = await activity_env.run(
|
||||
dynamic_tool_activity,
|
||||
[mock_payload]
|
||||
)
|
||||
|
||||
result = await activity_env.run(dynamic_tool_activity, [mock_payload])
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["result"] == "Handled test_value"
|
||||
|
||||
@@ -284,26 +275,22 @@ class TestDynamicToolActivity:
|
||||
# Mock the activity info and payload converter
|
||||
mock_info = MagicMock()
|
||||
mock_info.activity_type = "AsyncTestTool"
|
||||
|
||||
|
||||
mock_payload_converter = MagicMock()
|
||||
mock_payload = MagicMock()
|
||||
mock_payload.payload = b'{"test_arg": "async_test"}'
|
||||
mock_payload_converter.from_payload.return_value = {"test_arg": "async_test"}
|
||||
|
||||
|
||||
# Mock the async handler function
|
||||
async def mock_async_handler(args):
|
||||
return {"async_result": f"Async handled {args['test_arg']}"}
|
||||
|
||||
with patch('temporalio.activity.info', return_value=mock_info), \
|
||||
patch('temporalio.activity.payload_converter', return_value=mock_payload_converter), \
|
||||
patch('tools.get_handler', return_value=mock_async_handler):
|
||||
|
||||
|
||||
with patch("temporalio.activity.info", return_value=mock_info), patch(
|
||||
"temporalio.activity.payload_converter", return_value=mock_payload_converter
|
||||
), patch("tools.get_handler", return_value=mock_async_handler):
|
||||
activity_env = ActivityEnvironment()
|
||||
result = await activity_env.run(
|
||||
dynamic_tool_activity,
|
||||
[mock_payload]
|
||||
)
|
||||
|
||||
result = await activity_env.run(dynamic_tool_activity, [mock_payload])
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["async_result"] == "Async handled async_test"
|
||||
|
||||
@@ -314,21 +301,17 @@ class TestToolActivitiesIntegration:
|
||||
@pytest.mark.asyncio
|
||||
async def test_activities_in_worker(self, client: Client):
|
||||
"""Test activities can be registered and executed in a worker."""
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
# task_queue_name = str(uuid.uuid4())
|
||||
tool_activities = ToolActivities()
|
||||
|
||||
|
||||
# Test get_wf_env_vars activity using ActivityEnvironment
|
||||
env_input = EnvLookupInput(
|
||||
show_confirm_env_var_name="TEST_CONFIRM",
|
||||
show_confirm_default=False
|
||||
show_confirm_env_var_name="TEST_CONFIRM", show_confirm_default=False
|
||||
)
|
||||
|
||||
|
||||
activity_env = ActivityEnvironment()
|
||||
result = await activity_env.run(
|
||||
tool_activities.get_wf_env_vars,
|
||||
env_input
|
||||
)
|
||||
|
||||
result = await activity_env.run(tool_activities.get_wf_env_vars, env_input)
|
||||
|
||||
assert isinstance(result, EnvLookupOutput)
|
||||
assert isinstance(result.show_confirm, bool)
|
||||
assert isinstance(result.multi_goal_mode, bool)
|
||||
@@ -336,36 +319,36 @@ class TestToolActivitiesIntegration:
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Test edge cases and error handling."""
|
||||
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test environment for each test."""
|
||||
self.tool_activities = ToolActivities()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_validatePrompt_with_empty_conversation_history(self, sample_agent_goal):
|
||||
async def test_agent_validatePrompt_with_empty_conversation_history(
|
||||
self, sample_agent_goal
|
||||
):
|
||||
"""Test validation with empty conversation history."""
|
||||
validation_input = ValidationInput(
|
||||
prompt="Test prompt",
|
||||
conversation_history={"messages": []},
|
||||
agent_goal=sample_agent_goal
|
||||
agent_goal=sample_agent_goal,
|
||||
)
|
||||
|
||||
mock_response = {
|
||||
"validationResult": True,
|
||||
"validationFailedReason": {}
|
||||
}
|
||||
|
||||
with patch.object(self.tool_activities, 'agent_toolPlanner', new_callable=AsyncMock) as mock_planner:
|
||||
|
||||
mock_response = {"validationResult": True, "validationFailedReason": {}}
|
||||
|
||||
with patch.object(
|
||||
self.tool_activities, "agent_toolPlanner", new_callable=AsyncMock
|
||||
) as mock_planner:
|
||||
mock_planner.return_value = mock_response
|
||||
|
||||
|
||||
activity_env = ActivityEnvironment()
|
||||
result = await activity_env.run(
|
||||
self.tool_activities.agent_validatePrompt,
|
||||
validation_input
|
||||
self.tool_activities.agent_validatePrompt, validation_input
|
||||
)
|
||||
|
||||
|
||||
assert isinstance(result, ValidationResult)
|
||||
assert result.validationResult == True
|
||||
assert result.validationResult
|
||||
assert result.validationFailedReason == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -373,22 +356,22 @@ class TestEdgeCases:
|
||||
"""Test toolPlanner with very long prompt."""
|
||||
long_prompt = "This is a very long prompt " * 100
|
||||
tool_prompt_input = ToolPromptInput(
|
||||
prompt=long_prompt,
|
||||
context_instructions="Test context instructions"
|
||||
prompt=long_prompt, context_instructions="Test context instructions"
|
||||
)
|
||||
|
||||
|
||||
# Mock the completion response
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = '{"next": "done", "response": "Processed long prompt"}'
|
||||
|
||||
with patch('activities.tool_activities.completion', return_value=mock_response):
|
||||
mock_response.choices[0].message.content = (
|
||||
'{"next": "done", "response": "Processed long prompt"}'
|
||||
)
|
||||
|
||||
with patch("activities.tool_activities.completion", return_value=mock_response):
|
||||
activity_env = ActivityEnvironment()
|
||||
result = await activity_env.run(
|
||||
self.tool_activities.agent_toolPlanner,
|
||||
tool_prompt_input
|
||||
self.tool_activities.agent_toolPlanner, tool_prompt_input
|
||||
)
|
||||
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["next"] == "done"
|
||||
assert "Processed long prompt" in result["response"]
|
||||
@@ -397,15 +380,15 @@ class TestEdgeCases:
|
||||
async def test_sanitize_json_with_various_formats(self):
|
||||
"""Test JSON sanitization with various input formats."""
|
||||
# Test markdown code blocks
|
||||
markdown_json = "```json\n{\"test\": \"value\"}\n```"
|
||||
markdown_json = '```json\n{"test": "value"}\n```'
|
||||
result = self.tool_activities.sanitize_json_response(markdown_json)
|
||||
assert result == '{"test": "value"}'
|
||||
|
||||
|
||||
# Test with extra whitespace
|
||||
whitespace_json = " \n {\"test\": \"value\"} \n "
|
||||
whitespace_json = ' \n {"test": "value"} \n '
|
||||
result = self.tool_activities.sanitize_json_response(whitespace_json)
|
||||
assert result == '{"test": "value"}'
|
||||
|
||||
|
||||
# Test already clean JSON
|
||||
clean_json = '{"test": "value"}'
|
||||
result = self.tool_activities.sanitize_json_response(clean_json)
|
||||
@@ -423,44 +406,38 @@ class TestEdgeCases:
|
||||
# Test with "true" string
|
||||
with patch.dict(os.environ, {"TEST_CONFIRM": "true"}):
|
||||
env_input = EnvLookupInput(
|
||||
show_confirm_env_var_name="TEST_CONFIRM",
|
||||
show_confirm_default=False
|
||||
show_confirm_env_var_name="TEST_CONFIRM", show_confirm_default=False
|
||||
)
|
||||
|
||||
|
||||
activity_env = ActivityEnvironment()
|
||||
result = await activity_env.run(
|
||||
self.tool_activities.get_wf_env_vars,
|
||||
env_input
|
||||
self.tool_activities.get_wf_env_vars, env_input
|
||||
)
|
||||
|
||||
assert result.show_confirm == True
|
||||
|
||||
|
||||
assert result.show_confirm
|
||||
|
||||
# Test with "false" string
|
||||
with patch.dict(os.environ, {"TEST_CONFIRM": "false"}):
|
||||
env_input = EnvLookupInput(
|
||||
show_confirm_env_var_name="TEST_CONFIRM",
|
||||
show_confirm_default=True
|
||||
show_confirm_env_var_name="TEST_CONFIRM", show_confirm_default=True
|
||||
)
|
||||
|
||||
|
||||
activity_env = ActivityEnvironment()
|
||||
result = await activity_env.run(
|
||||
self.tool_activities.get_wf_env_vars,
|
||||
env_input
|
||||
self.tool_activities.get_wf_env_vars, env_input
|
||||
)
|
||||
|
||||
assert result.show_confirm == False
|
||||
|
||||
|
||||
assert not result.show_confirm
|
||||
|
||||
# Test with missing env var (should use default)
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
env_input = EnvLookupInput(
|
||||
show_confirm_env_var_name="MISSING_VAR",
|
||||
show_confirm_default=True
|
||||
show_confirm_env_var_name="MISSING_VAR", show_confirm_default=True
|
||||
)
|
||||
|
||||
|
||||
activity_env = ActivityEnvironment()
|
||||
result = await activity_env.run(
|
||||
self.tool_activities.get_wf_env_vars,
|
||||
env_input
|
||||
self.tool_activities.get_wf_env_vars, env_input
|
||||
)
|
||||
|
||||
assert result.show_confirm == True
|
||||
|
||||
assert result.show_confirm
|
||||
|
||||
@@ -1,25 +1,22 @@
|
||||
import concurrent.futures
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
|
||||
from temporalio import activity
|
||||
from temporalio.client import Client, WorkflowExecutionStatus
|
||||
from temporalio.worker import Worker
|
||||
from temporalio import activity
|
||||
import concurrent.futures
|
||||
from temporalio.testing import WorkflowEnvironment
|
||||
|
||||
from api.main import get_initial_agent_goal
|
||||
from models.data_types import (
|
||||
AgentGoalWorkflowParams,
|
||||
AgentGoalWorkflowParams,
|
||||
CombinedInput,
|
||||
ValidationResult,
|
||||
ValidationInput,
|
||||
EnvLookupOutput,
|
||||
EnvLookupInput,
|
||||
ToolPromptInput
|
||||
EnvLookupOutput,
|
||||
ToolPromptInput,
|
||||
ValidationInput,
|
||||
ValidationResult,
|
||||
)
|
||||
from workflows.agent_goal_workflow import AgentGoalWorkflow
|
||||
from activities.tool_activities import ToolActivities, dynamic_tool_activity
|
||||
from unittest.mock import patch
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
@contextmanager
|
||||
@@ -29,57 +26,49 @@ def my_context():
|
||||
print("Cleanup")
|
||||
|
||||
|
||||
|
||||
async def test_flight_booking(client: Client):
|
||||
# load_dotenv("test_flights_single.env")
|
||||
|
||||
#load_dotenv("test_flights_single.env")
|
||||
|
||||
with my_context() as value:
|
||||
print(f"Working with {value}")
|
||||
|
||||
|
||||
|
||||
# Create the test environment
|
||||
#env = await WorkflowEnvironment.start_local()
|
||||
#client = env.client
|
||||
# env = await WorkflowEnvironment.start_local()
|
||||
# client = env.client
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
workflow_id = str(uuid.uuid4())
|
||||
|
||||
# Create mock activity functions with proper signatures
|
||||
@activity.defn(name="get_wf_env_vars")
|
||||
async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
|
||||
return EnvLookupOutput(
|
||||
show_confirm=True,
|
||||
multi_goal_mode=True
|
||||
)
|
||||
|
||||
return EnvLookupOutput(show_confirm=True, multi_goal_mode=True)
|
||||
|
||||
@activity.defn(name="agent_validatePrompt")
|
||||
async def mock_agent_validatePrompt(validation_input: ValidationInput) -> ValidationResult:
|
||||
return ValidationResult(
|
||||
validationResult=True,
|
||||
validationFailedReason={}
|
||||
)
|
||||
|
||||
async def mock_agent_validatePrompt(
|
||||
validation_input: ValidationInput,
|
||||
) -> ValidationResult:
|
||||
return ValidationResult(validationResult=True, validationFailedReason={})
|
||||
|
||||
@activity.defn(name="agent_toolPlanner")
|
||||
async def mock_agent_toolPlanner(input: ToolPromptInput) -> dict:
|
||||
return {
|
||||
"next": "done",
|
||||
"response": "Test response from LLM"
|
||||
}
|
||||
return {"next": "done", "response": "Test response from LLM"}
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=100) as activity_executor:
|
||||
with concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=100
|
||||
) as activity_executor:
|
||||
worker = Worker(
|
||||
client,
|
||||
client,
|
||||
task_queue=task_queue_name,
|
||||
workflows=[AgentGoalWorkflow],
|
||||
activities=[
|
||||
mock_get_wf_env_vars,
|
||||
mock_agent_validatePrompt,
|
||||
mock_agent_toolPlanner
|
||||
mock_agent_toolPlanner,
|
||||
],
|
||||
activity_executor=activity_executor,
|
||||
)
|
||||
|
||||
async with worker:
|
||||
async with worker:
|
||||
initial_agent_goal = get_initial_agent_goal()
|
||||
# Create combined input
|
||||
combined_input = CombinedInput(
|
||||
@@ -87,30 +76,36 @@ async def test_flight_booking(client: Client):
|
||||
agent_goal=initial_agent_goal,
|
||||
)
|
||||
|
||||
prompt="Hello!"
|
||||
prompt = "Hello!"
|
||||
|
||||
#async with Worker(client, task_queue=task_queue_name, workflows=[AgentGoalWorkflow], activities=[ToolActivities.agent_validatePrompt, ToolActivities.agent_toolPlanner, dynamic_tool_activity]):
|
||||
# async with Worker(client, task_queue=task_queue_name, workflows=[AgentGoalWorkflow], activities=[ToolActivities.agent_validatePrompt, ToolActivities.agent_toolPlanner, dynamic_tool_activity]):
|
||||
|
||||
# todo set goal categories for scenarios
|
||||
handle = await client.start_workflow(
|
||||
AgentGoalWorkflow.run,
|
||||
combined_input,
|
||||
id=workflow_id,
|
||||
id=workflow_id,
|
||||
task_queue=task_queue_name,
|
||||
start_signal="user_prompt",
|
||||
start_signal_args=[prompt],
|
||||
)
|
||||
# todo send signals to simulate user input
|
||||
# await handle.signal(AgentGoalWorkflow.user_prompt, "book flights") # for multi-goal
|
||||
await handle.signal(AgentGoalWorkflow.user_prompt, "sydney in september")
|
||||
assert WorkflowExecutionStatus.RUNNING == (await handle.describe()).status
|
||||
await handle.signal(
|
||||
AgentGoalWorkflow.user_prompt, "sydney in september"
|
||||
)
|
||||
assert (
|
||||
WorkflowExecutionStatus.RUNNING == (await handle.describe()).status
|
||||
)
|
||||
|
||||
|
||||
#assert ["Hello, user1", "Hello, user2"] == await handle.result()
|
||||
await handle.signal(AgentGoalWorkflow.user_prompt, "I'm all set, end conversation")
|
||||
|
||||
#assert WorkflowExecutionStatus.COMPLETED == (await handle.describe()).status
|
||||
# assert ["Hello, user1", "Hello, user2"] == await handle.result()
|
||||
await handle.signal(
|
||||
AgentGoalWorkflow.user_prompt, "I'm all set, end conversation"
|
||||
)
|
||||
|
||||
# assert WorkflowExecutionStatus.COMPLETED == (await handle.describe()).status
|
||||
|
||||
result = await handle.result()
|
||||
#todo dump workflow history for analysis optional
|
||||
#todo assert result is good
|
||||
print(f"Workflow result: {result}")
|
||||
# todo dump workflow history for analysis optional
|
||||
# todo assert result is good
|
||||
|
||||
Reference in New Issue
Block a user