mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 14:08:08 +01:00
work on tests
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -32,3 +32,4 @@ coverage.xml
|
|||||||
.idea/
|
.idea/
|
||||||
|
|
||||||
.env
|
.env
|
||||||
|
./tests/*.env
|
||||||
|
|||||||
11
README.md
11
README.md
@@ -47,3 +47,14 @@ See [the guide to adding goals and tools](./adding-goals-and-tools.md) for more
|
|||||||
|
|
||||||
## For Temporal SAs
|
## For Temporal SAs
|
||||||
Check out the [slides](https://docs.google.com/presentation/d/1wUFY4v17vrtv8llreKEBDPLRtZte3FixxBUn0uWy5NU/edit#slide=id.g3333e5deaa9_0_0) here and the enablement guide here (TODO).
|
Check out the [slides](https://docs.google.com/presentation/d/1wUFY4v17vrtv8llreKEBDPLRtZte3FixxBUn0uWy5NU/edit#slide=id.g3333e5deaa9_0_0) here and the enablement guide here (TODO).
|
||||||
|
|
||||||
|
## Tests
|
||||||
|
|
||||||
|
Running the tests requires `poe` and `pytest_asyncio` to be installed.
|
||||||
|
|
||||||
|
python -m pip install poethepoet
|
||||||
|
python -m pip install pytest_asyncio
|
||||||
|
|
||||||
|
Once you have `poe` and `pytest_asyncio` installed you can run:
|
||||||
|
|
||||||
|
poe test
|
||||||
|
|||||||
@@ -15,6 +15,12 @@ packages = [
|
|||||||
[tool.poetry.urls]
|
[tool.poetry.urls]
|
||||||
"Bug Tracker" = "https://github.com/temporalio/samples-python/issues"
|
"Bug Tracker" = "https://github.com/temporalio/samples-python/issues"
|
||||||
|
|
||||||
|
[tool.poe.tasks]
|
||||||
|
format = [{cmd = "black ."}, {cmd = "isort ."}]
|
||||||
|
lint = [{cmd = "black --check ."}, {cmd = "isort --check-only ."}, {ref = "lint-types" }]
|
||||||
|
lint-types = "mypy --check-untyped-defs --namespace-packages ."
|
||||||
|
test = "pytest"
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = ">=3.10,<4.0"
|
python = ">=3.10,<4.0"
|
||||||
temporalio = "^1.8.0"
|
temporalio = "^1.8.0"
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
from tools.search_events import find_events
|
from tools.search_flights import search_flights
|
||||||
import json
|
import json
|
||||||
|
|
||||||
# Example usage
|
# Example usage
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
search_args = {"city": "Sydney", "month": "July"}
|
search_args = {"city": "Sydney", "month": "July"}
|
||||||
results = find_events(search_args)
|
results = search_flights(search_args)
|
||||||
print(json.dumps(results, indent=2))
|
print(json.dumps(results, indent=2))
|
||||||
|
|||||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
@@ -1,53 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
|
|
||||||
from temporalio.client import Client, WorkflowExecutionStatus
|
|
||||||
from temporalio.worker import Worker
|
|
||||||
from temporalio.testing import TestWorkflowEnvironment
|
|
||||||
from api.main import get_initial_agent_goal
|
|
||||||
from models.data_types import AgentGoalWorkflowParams, CombinedInput
|
|
||||||
from workflows.agent_goal_workflow import AgentGoalWorkflow
|
|
||||||
from activities.tool_activities import ToolActivities, dynamic_tool_activity
|
|
||||||
|
|
||||||
|
|
||||||
async def asyncSetUp(self):
|
|
||||||
# Set up the test environment
|
|
||||||
self.env = await TestWorkflowEnvironment.create_local()
|
|
||||||
|
|
||||||
async def asyncTearDown(self):
|
|
||||||
# Clean up after tests
|
|
||||||
await self.env.shutdown()
|
|
||||||
|
|
||||||
async def test_flight_booking(client: Client):
|
|
||||||
|
|
||||||
task_queue_name = "agent-ai-workflow"
|
|
||||||
workflow_id = "agent-workflow"
|
|
||||||
|
|
||||||
initial_agent_goal = get_initial_agent_goal()
|
|
||||||
|
|
||||||
# Create combined input
|
|
||||||
combined_input = CombinedInput(
|
|
||||||
tool_params=AgentGoalWorkflowParams(None, None),
|
|
||||||
agent_goal=initial_agent_goal,
|
|
||||||
)
|
|
||||||
|
|
||||||
workflow_id = "agent-workflow"
|
|
||||||
async with Worker(client, task_queue=task_queue_name, workflows=[AgentGoalWorkflow], activities=[ToolActivities.agent_validatePrompt, ToolActivities.agent_toolPlanner, dynamic_tool_activity]):
|
|
||||||
|
|
||||||
# todo set goal categories for scenarios
|
|
||||||
handle = await client.start_workflow(
|
|
||||||
AgentGoalWorkflow.run, id=workflow_id, task_queue=task_queue_name
|
|
||||||
)
|
|
||||||
# todo send signals based on
|
|
||||||
await handle.signal(AgentGoalWorkflow.user_prompt, "book flights")
|
|
||||||
await handle.signal(AgentGoalWorkflow.user_prompt, "sydney in september")
|
|
||||||
assert WorkflowExecutionStatus.RUNNING == (await handle.describe()).status
|
|
||||||
|
|
||||||
|
|
||||||
#assert ["Hello, user1", "Hello, user2"] == await handle.result()
|
|
||||||
await handle.signal(AgentGoalWorkflow.user_prompt, "I'm all set, end conversation")
|
|
||||||
assert WorkflowExecutionStatus.COMPLETED == (await handle.describe()).status
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
55
tests/conftest.py
Normal file
55
tests/conftest.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
import asyncio
|
||||||
|
import multiprocessing
|
||||||
|
import sys
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from temporalio.client import Client
|
||||||
|
from temporalio.testing import WorkflowEnvironment
|
||||||
|
|
||||||
|
# Due to https://github.com/python/cpython/issues/77906, multiprocessing on
|
||||||
|
# macOS starting with Python 3.8 has changed from "fork" to "spawn". For
|
||||||
|
# pre-3.8, we are changing it for them.
|
||||||
|
if sys.version_info < (3, 8) and sys.platform.startswith("darwin"):
|
||||||
|
multiprocessing.set_start_method("spawn", True)
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_addoption(parser):
|
||||||
|
parser.addoption(
|
||||||
|
"--workflow-environment",
|
||||||
|
default="local",
|
||||||
|
help="Which workflow environment to use ('local', 'time-skipping', or target to existing server)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def event_loop():
|
||||||
|
# See https://github.com/pytest-dev/pytest-asyncio/issues/68
|
||||||
|
# See https://github.com/pytest-dev/pytest-asyncio/issues/257
|
||||||
|
# Also need ProactorEventLoop on older versions of Python with Windows so
|
||||||
|
# that asyncio subprocess works properly
|
||||||
|
if sys.version_info < (3, 8) and sys.platform == "win32":
|
||||||
|
loop = asyncio.ProactorEventLoop()
|
||||||
|
else:
|
||||||
|
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||||
|
yield loop
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="session")
|
||||||
|
async def env(request) -> AsyncGenerator[WorkflowEnvironment, None]:
|
||||||
|
env_type = request.config.getoption("--workflow-environment")
|
||||||
|
if env_type == "local":
|
||||||
|
env = await WorkflowEnvironment.start_local()
|
||||||
|
elif env_type == "time-skipping":
|
||||||
|
env = await WorkflowEnvironment.start_time_skipping()
|
||||||
|
else:
|
||||||
|
env = WorkflowEnvironment.from_client(await Client.connect(env_type))
|
||||||
|
yield env
|
||||||
|
await env.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def client(env: WorkflowEnvironment) -> Client:
|
||||||
|
return env.client
|
||||||
80
tests/workflowtests/agent_goal_workflow_test.py
Normal file
80
tests/workflowtests/agent_goal_workflow_test.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
from temporalio.client import Client, WorkflowExecutionStatus
|
||||||
|
from temporalio.worker import Worker
|
||||||
|
import concurrent.futures
|
||||||
|
from temporalio.testing import WorkflowEnvironment
|
||||||
|
from api.main import get_initial_agent_goal
|
||||||
|
from models.data_types import AgentGoalWorkflowParams, CombinedInput
|
||||||
|
from workflows.agent_goal_workflow import AgentGoalWorkflow
|
||||||
|
from activities.tool_activities import ToolActivities, dynamic_tool_activity
|
||||||
|
from unittest.mock import patch
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import os
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def my_context():
|
||||||
|
print("Setup")
|
||||||
|
yield "some_value" # Value assigned to 'as' variable
|
||||||
|
print("Cleanup")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
async def test_flight_booking(client: Client):
|
||||||
|
|
||||||
|
#load_dotenv("test_flights_single.env")
|
||||||
|
|
||||||
|
with my_context() as value:
|
||||||
|
print(f"Working with {value}")
|
||||||
|
|
||||||
|
|
||||||
|
# Create the test environment
|
||||||
|
#env = await WorkflowEnvironment.start_local()
|
||||||
|
#client = env.client
|
||||||
|
task_queue_name = "agent-ai-workflow"
|
||||||
|
workflow_id = "agent-workflow"
|
||||||
|
|
||||||
|
with concurrent.futures.ThreadPoolExecutor(max_workers=100) as activity_executor:
|
||||||
|
worker = Worker(
|
||||||
|
client,
|
||||||
|
task_queue=task_queue_name,
|
||||||
|
workflows=[AgentGoalWorkflow],
|
||||||
|
activities=[ToolActivities.agent_validatePrompt, ToolActivities.agent_toolPlanner, ToolActivities.get_wf_env_vars, dynamic_tool_activity],
|
||||||
|
activity_executor=activity_executor,
|
||||||
|
)
|
||||||
|
|
||||||
|
async with worker:
|
||||||
|
initial_agent_goal = get_initial_agent_goal()
|
||||||
|
# Create combined input
|
||||||
|
combined_input = CombinedInput(
|
||||||
|
tool_params=AgentGoalWorkflowParams(None, None),
|
||||||
|
agent_goal=initial_agent_goal,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt="Hello!"
|
||||||
|
|
||||||
|
#async with Worker(client, task_queue=task_queue_name, workflows=[AgentGoalWorkflow], activities=[ToolActivities.agent_validatePrompt, ToolActivities.agent_toolPlanner, dynamic_tool_activity]):
|
||||||
|
|
||||||
|
# todo set goal categories for scenarios
|
||||||
|
handle = await client.start_workflow(
|
||||||
|
AgentGoalWorkflow.run,
|
||||||
|
combined_input,
|
||||||
|
id=workflow_id,
|
||||||
|
task_queue=task_queue_name,
|
||||||
|
start_signal="user_prompt",
|
||||||
|
start_signal_args=[prompt],
|
||||||
|
)
|
||||||
|
# todo send signals to simulate user input
|
||||||
|
# await handle.signal(AgentGoalWorkflow.user_prompt, "book flights") # for multi-goal
|
||||||
|
await handle.signal(AgentGoalWorkflow.user_prompt, "sydney in september")
|
||||||
|
assert WorkflowExecutionStatus.RUNNING == (await handle.describe()).status
|
||||||
|
|
||||||
|
|
||||||
|
#assert ["Hello, user1", "Hello, user2"] == await handle.result()
|
||||||
|
await handle.signal(AgentGoalWorkflow.user_prompt, "I'm all set, end conversation")
|
||||||
|
|
||||||
|
#assert WorkflowExecutionStatus.COMPLETED == (await handle.describe()).status
|
||||||
|
|
||||||
|
result = await handle.result()
|
||||||
|
#todo dump workflow history for analysis optional
|
||||||
|
#todo assert result is good
|
||||||
Reference in New Issue
Block a user