mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 22:18:09 +01:00
Compare commits
9 Commits
docker-onl
...
food-order
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1b7c273e55 | ||
|
|
e35181b5ad | ||
|
|
f7ef2b1c7e | ||
|
|
71e54b9ecd | ||
|
|
a7a2002217 | ||
|
|
5a3bfbd848 | ||
|
|
7bb6688797 | ||
|
|
847f4bbaef | ||
|
|
f8e0dd3b2a |
22
.env.example
22
.env.example
@@ -1,27 +1,13 @@
|
||||
RAPIDAPI_KEY=9df2cb5...
|
||||
RAPIDAPI_HOST_FLIGHTS=sky-scrapper.p.rapidapi.com #For travel flight information tool
|
||||
RAPIDAPI_HOST_PACKAGE=trackingpackage.p.rapidapi.com #For eCommerce order status package tracking tool
|
||||
FOOTBALL_DATA_API_KEY=....
|
||||
FOOTBALL_DATA_API_KEY=
|
||||
# Leave blank to use the built-in mock fixtures generator
|
||||
|
||||
STRIPE_API_KEY=sk_test_51J...
|
||||
|
||||
LLM_PROVIDER=openai # default
|
||||
OPENAI_API_KEY=sk-proj-...
|
||||
# or
|
||||
#LLM_PROVIDER=grok
|
||||
#GROK_API_KEY=xai-your-grok-api-key
|
||||
# or
|
||||
# LLM_PROVIDER=ollama
|
||||
# OLLAMA_MODEL_NAME=qwen2.5:14b
|
||||
# or
|
||||
# LLM_PROVIDER=google
|
||||
# GOOGLE_API_KEY=your-google-api-key
|
||||
# or
|
||||
# LLM_PROVIDER=anthropic
|
||||
# ANTHROPIC_API_KEY=your-anthropic-api-key
|
||||
# or
|
||||
# LLM_PROVIDER=deepseek
|
||||
# DEEPSEEK_API_KEY=your-deepseek-api-key
|
||||
LLM_MODEL=openai/gpt-4o # default
|
||||
LLM_KEY=sk-proj-...
|
||||
|
||||
|
||||
# uncomment and unset these environment variables to connect to the local dev server
|
||||
|
||||
175
AGENTS.md
Normal file
175
AGENTS.md
Normal file
@@ -0,0 +1,175 @@
|
||||
# Temporal AI Agent Contribution Guide
|
||||
|
||||
## Repository Layout
|
||||
- `workflows/` - Temporal workflows including the main AgentGoalWorkflow for multi-turn AI conversations
|
||||
- `activities/` - Temporal activities for tool execution and LLM interactions
|
||||
- `tools/` - AI agent tools organized by category (finance, HR, ecommerce, travel, etc.)
|
||||
- `models/` - Data types and tool definitions used throughout the system
|
||||
- `prompts/` - Agent prompt generators and templates
|
||||
- `api/` - FastAPI server that exposes REST endpoints to interact with workflows
|
||||
- `frontend/` - React-based web UI for chatting with the AI agent
|
||||
- `tests/` - Comprehensive test suite for workflows and activities using Temporal's testing framework
|
||||
- `enterprise/` - .NET worker implementation for enterprise activities (train booking)
|
||||
- `scripts/` - Utility scripts for running workers and testing tools
|
||||
|
||||
## Running the Application
|
||||
|
||||
### Quick Start with Docker
|
||||
```bash
|
||||
# Start all services with development hot-reload
|
||||
docker compose up -d
|
||||
|
||||
# Quick rebuild without infrastructure
|
||||
docker compose up -d --no-deps --build api worker frontend
|
||||
```
|
||||
|
||||
Default URLs:
|
||||
- Temporal UI: http://localhost:8080
|
||||
- API: http://localhost:8000
|
||||
- Frontend: http://localhost:5173
|
||||
|
||||
### Local Development Setup
|
||||
|
||||
1. **Prerequisites:**
|
||||
```bash
|
||||
# Install Poetry for Python dependency management
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
# Start Temporal server (Mac)
|
||||
brew install temporal
|
||||
temporal server start-dev
|
||||
```
|
||||
|
||||
2. **Backend (Python):**
|
||||
```bash
|
||||
# Quick setup using Makefile
|
||||
make setup # Creates venv and installs dependencies
|
||||
make run-worker # Starts the Temporal worker
|
||||
make run-api # Starts the API server
|
||||
|
||||
# Or manually:
|
||||
poetry install
|
||||
poetry run python scripts/run_worker.py # In one terminal
|
||||
poetry run uvicorn api.main:app --reload # In another terminal
|
||||
```
|
||||
|
||||
3. **Frontend (React):**
|
||||
```bash
|
||||
make run-frontend # Using Makefile
|
||||
|
||||
# Or manually:
|
||||
cd frontend
|
||||
npm install
|
||||
npx vite
|
||||
```
|
||||
|
||||
4. **Enterprise .NET Worker (optional):**
|
||||
```bash
|
||||
make run-enterprise # Using Makefile
|
||||
|
||||
# Or manually:
|
||||
cd enterprise
|
||||
dotnet build
|
||||
dotnet run
|
||||
```
|
||||
|
||||
### Environment Configuration
|
||||
Copy `.env.example` to `.env` and configure:
|
||||
```bash
|
||||
# Required: LLM Configuration
|
||||
LLM_MODEL=openai/gpt-4o # or anthropic/claude-3-sonnet, etc.
|
||||
LLM_KEY=your-api-key-here
|
||||
|
||||
# Optional: Agent Goals and Categories
|
||||
AGENT_GOAL=goal_choose_agent_type
|
||||
GOAL_CATEGORIES=hr,travel-flights,travel-trains,fin
|
||||
|
||||
# Optional: Tool-specific APIs
|
||||
STRIPE_API_KEY=sk_test_... # For invoice creation
|
||||
FOOTBALL_DATA_API_KEY=... # For real football fixtures
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
The project includes comprehensive tests using Temporal's testing framework:
|
||||
|
||||
```bash
|
||||
# Install test dependencies
|
||||
poetry install --with dev
|
||||
|
||||
# Run all tests
|
||||
poetry run pytest
|
||||
|
||||
# Run with time-skipping for faster execution
|
||||
poetry run pytest --workflow-environment=time-skipping
|
||||
|
||||
# Run specific test categories
|
||||
poetry run pytest tests/test_tool_activities.py -v # Activity tests
|
||||
poetry run pytest tests/test_agent_goal_workflow.py -v # Workflow tests
|
||||
|
||||
# Run with coverage
|
||||
poetry run pytest --cov=workflows --cov=activities
|
||||
```
|
||||
|
||||
**Test Coverage:**
|
||||
- ✅ **Workflow Tests**: AgentGoalWorkflow signals, queries, state management
|
||||
- ✅ **Activity Tests**: ToolActivities, LLM integration (mocked), environment configuration
|
||||
- ✅ **Integration Tests**: End-to-end workflow and activity execution
|
||||
|
||||
**Documentation:**
|
||||
- **Quick Start**: [TESTING.md](TESTING.md) - Simple commands to run tests
|
||||
- **Comprehensive Guide**: [tests/README.md](tests/README.md) - Detailed testing patterns and best practices
|
||||
|
||||
## Linting and Code Quality
|
||||
|
||||
```bash
|
||||
# Using Poetry tasks
|
||||
poetry run poe format # Format code with black and isort
|
||||
poetry run poe lint # Check code style and types
|
||||
poetry run poe test # Run test suite
|
||||
|
||||
# Manual commands
|
||||
poetry run black .
|
||||
poetry run isort .
|
||||
poetry run mypy --check-untyped-defs --namespace-packages .
|
||||
```
|
||||
|
||||
## Agent Customization
|
||||
|
||||
### Adding New Tools
|
||||
1. Create tool implementation in `tools/` directory
|
||||
2. Add tool function mapping in `tools/__init__.py`
|
||||
3. Register tool definition in `tools/tool_registry.py`
|
||||
4. Associate with goals in `tools/goal_registry.py`
|
||||
|
||||
### Configuring Goals
|
||||
The agent supports multiple goal categories:
|
||||
- **Financial**: Money transfers, loan applications (`fin/`)
|
||||
- **HR**: PTO booking, payroll status (`hr/`)
|
||||
- **Travel**: Flight/train booking, event finding
|
||||
- **Ecommerce**: Order tracking, package management (`ecommerce/`)
|
||||
|
||||
See [adding-goals-and-tools.md](adding-goals-and-tools.md) for detailed customization guide.
|
||||
|
||||
## Architecture
|
||||
|
||||
This system implements "Agentic AI" with these key components:
|
||||
1. **Goals** - High-level objectives accomplished through tool sequences
|
||||
2. **Agent Loops** - LLM execution → tool calls → human input → repeat until goal completion
|
||||
3. **Tool Approval** - Human confirmation for sensitive operations
|
||||
4. **Conversation Management** - LLM-powered input validation and history summarization
|
||||
5. **Durability** - Temporal workflows ensure reliable execution across failures
|
||||
|
||||
For detailed architecture information, see [architecture.md](architecture.md).
|
||||
|
||||
## Commit Messages and Pull Requests
|
||||
- Use clear commit messages describing the change purpose
|
||||
- Reference specific files and line numbers when relevant (e.g., `workflows/agent_goal_workflow.py:125`)
|
||||
- Open PRs describing **what changed** and **why**
|
||||
- Ensure tests pass before submitting: `poetry run pytest --workflow-environment=time-skipping`
|
||||
|
||||
## Additional Resources
|
||||
- **Setup Guide**: [setup.md](setup.md) - Detailed configuration instructions
|
||||
- **Architecture Decisions**: [architecture-decisions.md](architecture-decisions.md) - Why Temporal for AI agents
|
||||
- **Demo Video**: [5-minute YouTube overview](https://www.youtube.com/watch?v=GEXllEH2XiQ)
|
||||
- **Multi-Agent Demo**: [Advanced multi-agent execution](https://www.youtube.com/watch?v=8Dc_0dC14yY)
|
||||
63
Makefile
Normal file
63
Makefile
Normal file
@@ -0,0 +1,63 @@
|
||||
.PHONY: setup install run-worker run-api run-frontend run-train-api run-legacy-worker run-enterprise setup-venv check-python run-dev
|
||||
|
||||
# Setup commands
|
||||
setup: check-python setup-venv install
|
||||
|
||||
check-python:
|
||||
@which python3 >/dev/null 2>&1 || (echo "Python 3 is required. Please install it first." && exit 1)
|
||||
@which poetry >/dev/null 2>&1 || (echo "Poetry is required. Please install it first." && exit 1)
|
||||
|
||||
setup-venv:
|
||||
python3 -m venv venv
|
||||
@echo "Virtual environment created. Don't forget to activate it with 'source venv/bin/activate'"
|
||||
|
||||
install:
|
||||
poetry install
|
||||
cd frontend && npm install
|
||||
|
||||
# Run commands
|
||||
run-worker:
|
||||
poetry run python scripts/run_worker.py
|
||||
|
||||
run-api:
|
||||
poetry run uvicorn api.main:app --reload
|
||||
|
||||
run-frontend:
|
||||
cd frontend && npx vite
|
||||
|
||||
run-train-api:
|
||||
poetry run python thirdparty/train_api.py
|
||||
|
||||
run-legacy-worker:
|
||||
poetry run python scripts/run_legacy_worker.py
|
||||
|
||||
run-enterprise:
|
||||
cd enterprise && dotnet build && dotnet run
|
||||
|
||||
# Development environment setup
|
||||
setup-temporal-mac:
|
||||
brew install temporal
|
||||
temporal server start-dev
|
||||
|
||||
# Run all development services
|
||||
run-dev:
|
||||
@echo "Starting all development services..."
|
||||
@make run-worker & \
|
||||
make run-api & \
|
||||
make run-frontend & \
|
||||
wait
|
||||
|
||||
# Help command
|
||||
help:
|
||||
@echo "Available commands:"
|
||||
@echo " make setup - Create virtual environment and install dependencies"
|
||||
@echo " make setup-venv - Create virtual environment only"
|
||||
@echo " make install - Install all dependencies"
|
||||
@echo " make run-worker - Start the Temporal worker"
|
||||
@echo " make run-api - Start the API server"
|
||||
@echo " make run-frontend - Start the frontend development server"
|
||||
@echo " make run-train-api - Start the train API server"
|
||||
@echo " make run-legacy-worker - Start the legacy worker"
|
||||
@echo " make run-enterprise - Build and run the enterprise .NET worker"
|
||||
@echo " make setup-temporal-mac - Install and start Temporal server on Mac"
|
||||
@echo " make run-dev - Start all development services (worker, API, frontend) in parallel"
|
||||
49
README.md
49
README.md
@@ -2,7 +2,13 @@
|
||||
|
||||
This demo shows a multi-turn conversation with an AI agent running inside a Temporal workflow. The purpose of the agent is to collect information towards a goal, running tools along the way. There's a simple DSL input for collecting information (currently set up to use mock functions to search for public events, search for flights around those events, then create a test Stripe invoice for the trip).
|
||||
|
||||
The AI will respond with clarifications and ask for any missing information to that goal. You can configure it to use [ChatGPT 4o](https://openai.com/index/hello-gpt-4o/), [Anthropic Claude](https://www.anthropic.com/claude), [Google Gemini](https://gemini.google.com), [Deepseek-V3](https://www.deepseek.com/), [Grok](https://docs.x.ai/docs/overview) or a local LLM of your choice using [Ollama](https://ollama.com).
|
||||
The AI will respond with clarifications and ask for any missing information to that goal. You can configure it to use any LLM supported by [LiteLLM](https://docs.litellm.ai/docs/providers), including:
|
||||
- OpenAI models (GPT-4, GPT-3.5)
|
||||
- Anthropic Claude models
|
||||
- Google Gemini models
|
||||
- Deepseek models
|
||||
- Ollama models (local)
|
||||
- And many more!
|
||||
|
||||
It's really helpful to [watch the demo (5 minute YouTube video)](https://www.youtube.com/watch?v=GEXllEH2XiQ) to understand how interaction works.
|
||||
|
||||
@@ -28,7 +34,11 @@ These are the key elements of an agentic framework:
|
||||
For a deeper dive into this, check out the [architecture guide](./architecture.md).
|
||||
|
||||
## Setup and Configuration
|
||||
See [the Setup guide](./setup.md).
|
||||
See [the Setup guide](./setup.md) for detailed instructions. The basic configuration requires just two environment variables:
|
||||
```bash
|
||||
LLM_MODEL=openai/gpt-4o # or any other model supported by LiteLLM
|
||||
LLM_KEY=your-api-key-here
|
||||
```
|
||||
|
||||
## Customizing Interaction & Tools
|
||||
See [the guide to adding goals and tools](./adding-goals-and-tools.md).
|
||||
@@ -36,11 +46,44 @@ See [the guide to adding goals and tools](./adding-goals-and-tools.md).
|
||||
## Architecture
|
||||
See [the architecture guide](./architecture.md).
|
||||
|
||||
## Testing
|
||||
|
||||
The project includes comprehensive tests for workflows and activities using Temporal's testing framework:
|
||||
|
||||
```bash
|
||||
# Install dependencies including test dependencies
|
||||
poetry install --with dev
|
||||
|
||||
# Run all tests
|
||||
poetry run pytest
|
||||
|
||||
# Run with time-skipping for faster execution
|
||||
poetry run pytest --workflow-environment=time-skipping
|
||||
```
|
||||
|
||||
**Test Coverage:**
|
||||
- ✅ **Workflow Tests**: AgentGoalWorkflow signals, queries, state management
|
||||
- ✅ **Activity Tests**: ToolActivities, LLM integration (mocked), environment configuration
|
||||
- ✅ **Integration Tests**: End-to-end workflow and activity execution
|
||||
|
||||
**Documentation:**
|
||||
- **Quick Start**: [TESTING.md](TESTING.md) - Simple commands to run tests
|
||||
- **Comprehensive Guide**: [tests/README.md](tests/README.md) - Detailed testing documentation, patterns, and best practices
|
||||
|
||||
## Development
|
||||
|
||||
Install dependencies:
|
||||
```bash
|
||||
poetry install
|
||||
```
|
||||
|
||||
Start the Temporal Server and API server, see [setup](setup.md)
|
||||
|
||||
## Productionalization & Adding Features
|
||||
- In a prod setting, I would need to ensure that payload data is stored separately (e.g. in S3 or a noSQL db - the claim-check pattern), or otherwise 'garbage collected'. Without these techniques, long conversations will fill up the workflow's conversation history, and start to breach Temporal event history payload limits.
|
||||
- A single worker can easily support many agent workflows (chats) running at the same time. Currently the workflow ID is the same each time, so it will only run one agent at a time. To run multiple agents, you can use a different workflow ID each time (e.g. by using a UUID or timestamp).
|
||||
- Perhaps the UI should show when the LLM response is being retried (i.e. activity retry attempt because the LLM provided bad output)
|
||||
- Tests would be nice! [See tests](./tests/).
|
||||
- The project now includes comprehensive tests for workflows and activities! [See testing guide](TESTING.md).
|
||||
|
||||
|
||||
See [the todo](./todo.md) for more details.
|
||||
|
||||
163
TESTING.md
Normal file
163
TESTING.md
Normal file
@@ -0,0 +1,163 @@
|
||||
# Testing the Temporal AI Agent
|
||||
|
||||
This guide provides instructions for running the comprehensive test suite for the Temporal AI Agent project.
|
||||
|
||||
## Quick Start
|
||||
|
||||
1. **Install dependencies**:
|
||||
```bash
|
||||
poetry install --with dev
|
||||
```
|
||||
|
||||
2. **Run all tests**:
|
||||
```bash
|
||||
poetry run pytest
|
||||
```
|
||||
|
||||
3. **Run with time-skipping for faster execution**:
|
||||
```bash
|
||||
poetry run pytest --workflow-environment=time-skipping
|
||||
```
|
||||
|
||||
## Test Categories
|
||||
|
||||
### Unit Tests
|
||||
- **Activity Tests**: `tests/test_tool_activities.py`
|
||||
- LLM integration (mocked)
|
||||
- Environment configuration
|
||||
- JSON processing
|
||||
- Dynamic tool execution
|
||||
|
||||
### Integration Tests
|
||||
- **Workflow Tests**: `tests/test_agent_goal_workflow.py`
|
||||
- Full workflow execution
|
||||
- Signal and query handling
|
||||
- State management
|
||||
- Error scenarios
|
||||
|
||||
## Running Specific Tests
|
||||
|
||||
```bash
|
||||
# Run only activity tests
|
||||
poetry run pytest tests/test_tool_activities.py -v
|
||||
|
||||
# Run only workflow tests
|
||||
poetry run pytest tests/test_agent_goal_workflow.py -v
|
||||
|
||||
# Run a specific test
|
||||
poetry run pytest tests/test_tool_activities.py::TestToolActivities::test_sanitize_json_response -v
|
||||
|
||||
# Run tests matching a pattern
|
||||
poetry run pytest -k "validation" -v
|
||||
```
|
||||
|
||||
## Test Environment Options
|
||||
|
||||
### Local Environment (Default)
|
||||
```bash
|
||||
poetry run pytest --workflow-environment=local
|
||||
```
|
||||
|
||||
### Time-Skipping Environment (Recommended for CI)
|
||||
```bash
|
||||
poetry run pytest --workflow-environment=time-skipping
|
||||
```
|
||||
|
||||
### External Temporal Server
|
||||
```bash
|
||||
poetry run pytest --workflow-environment=localhost:7233
|
||||
```
|
||||
|
||||
## Environment Variables
|
||||
|
||||
Tests can be configured with these environment variables:
|
||||
|
||||
- `LLM_MODEL`: Model for LLM testing (default: "openai/gpt-4")
|
||||
- `LLM_KEY`: API key for LLM service (mocked in tests)
|
||||
- `LLM_BASE_URL`: Custom LLM endpoint (optional)
|
||||
|
||||
## Test Coverage
|
||||
|
||||
The test suite covers:
|
||||
|
||||
✅ **Workflows**
|
||||
- AgentGoalWorkflow initialization and execution
|
||||
- Signal handling (user_prompt, confirm, end_chat)
|
||||
- Query methods (conversation history, agent goal, tool data)
|
||||
- State management and conversation flow
|
||||
- Validation and error handling
|
||||
|
||||
✅ **Activities**
|
||||
- ToolActivities class methods
|
||||
- LLM integration (mocked)
|
||||
- Environment variable handling
|
||||
- JSON response processing
|
||||
- Dynamic tool activity execution
|
||||
|
||||
✅ **Integration**
|
||||
- End-to-end workflow execution
|
||||
- Activity registration in workers
|
||||
- Temporal client interactions
|
||||
|
||||
## Test Output
|
||||
|
||||
Successful test run example:
|
||||
```
|
||||
============================== test session starts ==============================
|
||||
platform darwin -- Python 3.11.3, pytest-8.3.5, pluggy-1.5.0
|
||||
rootdir: /Users/steveandroulakis/Documents/Code/agentic/temporal-demo/temporal-ai-agent
|
||||
configfile: pyproject.toml
|
||||
plugins: anyio-4.5.2, asyncio-0.26.0
|
||||
collected 21 items
|
||||
|
||||
tests/test_tool_activities.py::TestToolActivities::test_sanitize_json_response PASSED
|
||||
tests/test_tool_activities.py::TestToolActivities::test_parse_json_response_success PASSED
|
||||
tests/test_tool_activities.py::TestToolActivities::test_get_wf_env_vars_default_values PASSED
|
||||
...
|
||||
|
||||
============================== 21 passed in 12.5s ==============================
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Module not found errors**: Run `poetry install --with dev`
|
||||
2. **Async warnings**: These are expected with pytest-asyncio and can be ignored
|
||||
3. **Test timeouts**: Use `--workflow-environment=time-skipping` for faster execution
|
||||
4. **Import errors**: Check that you're running tests from the project root directory
|
||||
|
||||
### Debugging Tests
|
||||
|
||||
Enable verbose logging:
|
||||
```bash
|
||||
poetry run pytest --log-cli-level=DEBUG -s
|
||||
```
|
||||
|
||||
Run with coverage:
|
||||
```bash
|
||||
poetry run pytest --cov=workflows --cov=activities
|
||||
```
|
||||
|
||||
## Continuous Integration
|
||||
|
||||
For CI environments, use:
|
||||
```bash
|
||||
poetry run pytest --workflow-environment=time-skipping --tb=short
|
||||
```
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- See `tests/README.md` for detailed testing documentation
|
||||
- Review `tests/conftest.py` for available test fixtures
|
||||
- Check individual test files for specific test scenarios
|
||||
|
||||
## Test Architecture
|
||||
|
||||
The tests use:
|
||||
- **Temporal Testing Framework**: For workflow and activity testing
|
||||
- **pytest-asyncio**: For async test support
|
||||
- **unittest.mock**: For mocking external dependencies
|
||||
- **Test Fixtures**: For consistent test data and setup
|
||||
|
||||
All external dependencies (LLM calls, file I/O) are mocked to ensure fast, reliable tests.
|
||||
@@ -1,142 +1,28 @@
|
||||
import inspect
|
||||
from temporalio import activity
|
||||
from ollama import chat, ChatResponse
|
||||
from openai import OpenAI
|
||||
import json
|
||||
from typing import Sequence, Optional
|
||||
from typing import Optional, Sequence
|
||||
from temporalio.common import RawValue
|
||||
import os
|
||||
from datetime import datetime
|
||||
import google.generativeai as genai
|
||||
import anthropic
|
||||
import deepseek
|
||||
from dotenv import load_dotenv
|
||||
from models.data_types import EnvLookupOutput, ValidationInput, ValidationResult, ToolPromptInput, EnvLookupInput
|
||||
from litellm import completion
|
||||
|
||||
load_dotenv(override=True)
|
||||
print(
|
||||
"Using LLM provider: "
|
||||
+ os.environ.get("LLM_PROVIDER", "openai")
|
||||
+ " (set LLM_PROVIDER in .env to change)"
|
||||
)
|
||||
|
||||
if os.environ.get("LLM_PROVIDER") == "ollama":
|
||||
print(
|
||||
"Using Ollama (local) model: "
|
||||
+ os.environ.get("OLLAMA_MODEL_NAME", "qwen2.5:14b")
|
||||
)
|
||||
|
||||
|
||||
class ToolActivities:
|
||||
def __init__(self):
|
||||
"""Initialize LLM clients based on environment configuration."""
|
||||
self.llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower()
|
||||
print(f"Initializing ToolActivities with LLM provider: {self.llm_provider}")
|
||||
|
||||
# Initialize client variables (all set to None initially)
|
||||
self.openai_client: Optional[OpenAI] = None
|
||||
self.grok_client: Optional[OpenAI] = None
|
||||
self.anthropic_client: Optional[anthropic.Anthropic] = None
|
||||
self.genai_configured: bool = False
|
||||
self.deepseek_client: Optional[deepseek.DeepSeekAPI] = None
|
||||
self.ollama_model_name: Optional[str] = None
|
||||
self.ollama_initialized: bool = False
|
||||
|
||||
# Only initialize the client specified by LLM_PROVIDER
|
||||
if self.llm_provider == "openai":
|
||||
if os.environ.get("OPENAI_API_KEY"):
|
||||
self.openai_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
||||
print("Initialized OpenAI client")
|
||||
else:
|
||||
print("Warning: OPENAI_API_KEY not set but LLM_PROVIDER is 'openai'")
|
||||
|
||||
elif self.llm_provider == "grok":
|
||||
if os.environ.get("GROK_API_KEY"):
|
||||
self.grok_client = OpenAI(api_key=os.environ.get("GROK_API_KEY"), base_url="https://api.x.ai/v1")
|
||||
print("Initialized grok client")
|
||||
else:
|
||||
print("Warning: GROK_API_KEY not set but LLM_PROVIDER is 'grok'")
|
||||
|
||||
elif self.llm_provider == "anthropic":
|
||||
if os.environ.get("ANTHROPIC_API_KEY"):
|
||||
self.anthropic_client = anthropic.Anthropic(
|
||||
api_key=os.environ.get("ANTHROPIC_API_KEY")
|
||||
)
|
||||
print("Initialized Anthropic client")
|
||||
else:
|
||||
print(
|
||||
"Warning: ANTHROPIC_API_KEY not set but LLM_PROVIDER is 'anthropic'"
|
||||
)
|
||||
|
||||
elif self.llm_provider == "google":
|
||||
api_key = os.environ.get("GOOGLE_API_KEY")
|
||||
if api_key:
|
||||
genai.configure(api_key=api_key)
|
||||
self.genai_configured = True
|
||||
print("Configured Google Generative AI")
|
||||
else:
|
||||
print("Warning: GOOGLE_API_KEY not set but LLM_PROVIDER is 'google'")
|
||||
|
||||
elif self.llm_provider == "deepseek":
|
||||
if os.environ.get("DEEPSEEK_API_KEY"):
|
||||
self.deepseek_client = deepseek.DeepSeekAPI(
|
||||
api_key=os.environ.get("DEEPSEEK_API_KEY")
|
||||
)
|
||||
print("Initialized DeepSeek client")
|
||||
else:
|
||||
print(
|
||||
"Warning: DEEPSEEK_API_KEY not set but LLM_PROVIDER is 'deepseek'"
|
||||
)
|
||||
|
||||
# For Ollama, we store the model name but actual initialization happens in warm_up_ollama
|
||||
elif self.llm_provider == "ollama":
|
||||
self.ollama_model_name = os.environ.get("OLLAMA_MODEL_NAME", "qwen2.5:14b")
|
||||
print(
|
||||
f"Using Ollama model: {self.ollama_model_name} (will be loaded on worker startup)"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"Warning: Unknown LLM_PROVIDER '{self.llm_provider}', defaulting to OpenAI"
|
||||
)
|
||||
|
||||
def warm_up_ollama(self):
|
||||
"""Pre-load the Ollama model to avoid cold start latency on first request"""
|
||||
if self.llm_provider != "ollama" or self.ollama_initialized:
|
||||
return False # No need to warm up if not using Ollama or already warmed up
|
||||
|
||||
try:
|
||||
print(
|
||||
f"Pre-loading Ollama model '{self.ollama_model_name}' - this may take 30+ seconds..."
|
||||
)
|
||||
start_time = datetime.now()
|
||||
|
||||
# Make a simple request to load the model into memory
|
||||
chat(
|
||||
model=self.ollama_model_name,
|
||||
messages=[
|
||||
{"role": "system", "content": "You are an AI assistant"},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello! This is a warm-up message to load the model.",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
elapsed_time = (datetime.now() - start_time).total_seconds()
|
||||
print(f"✅ Ollama model loaded successfully in {elapsed_time:.2f} seconds")
|
||||
self.ollama_initialized = True
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"❌ Error pre-loading Ollama model: {str(e)}")
|
||||
print(
|
||||
"The worker will continue, but the first actual request may experience a delay."
|
||||
)
|
||||
return False
|
||||
"""Initialize LLM client using LiteLLM."""
|
||||
self.llm_model = os.environ.get("LLM_MODEL", "openai/gpt-4")
|
||||
self.llm_key = os.environ.get("LLM_KEY")
|
||||
self.llm_base_url = os.environ.get("LLM_BASE_URL")
|
||||
print(f"Initializing ToolActivities with LLM model: {self.llm_model}")
|
||||
if self.llm_base_url:
|
||||
print(f"Using custom base URL: {self.llm_base_url}")
|
||||
|
||||
@activity.defn
|
||||
async def agent_validatePrompt(
|
||||
self, validation_input: ValidationInput
|
||||
) -> ValidationResult:
|
||||
async def agent_validatePrompt(self, validation_input: ValidationInput) -> ValidationResult:
|
||||
"""
|
||||
Validates the prompt in the context of the conversation history and agent goal.
|
||||
Returns a ValidationResult indicating if the prompt makes sense given the context.
|
||||
@@ -187,7 +73,7 @@ class ToolActivities:
|
||||
prompt=validation_prompt, context_instructions=context_instructions
|
||||
)
|
||||
|
||||
result = self.agent_toolPlanner(prompt_input)
|
||||
result = await self.agent_toolPlanner(prompt_input)
|
||||
|
||||
return ValidationResult(
|
||||
validationResult=result.get("validationResult", False),
|
||||
@@ -195,19 +81,43 @@ class ToolActivities:
|
||||
)
|
||||
|
||||
@activity.defn
|
||||
def agent_toolPlanner(self, input: ToolPromptInput) -> dict:
|
||||
if self.llm_provider == "ollama":
|
||||
return self.prompt_llm_ollama(input)
|
||||
elif self.llm_provider == "google":
|
||||
return self.prompt_llm_google(input)
|
||||
elif self.llm_provider == "anthropic":
|
||||
return self.prompt_llm_anthropic(input)
|
||||
elif self.llm_provider == "deepseek":
|
||||
return self.prompt_llm_deepseek(input)
|
||||
elif self.llm_provider == "grok":
|
||||
return self.prompt_llm_grok(input)
|
||||
else:
|
||||
return self.prompt_llm_openai(input)
|
||||
async def agent_toolPlanner(self, input: ToolPromptInput) -> dict:
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": input.context_instructions
|
||||
+ ". The current date is "
|
||||
+ datetime.now().strftime("%B %d, %Y"),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": input.prompt,
|
||||
},
|
||||
]
|
||||
|
||||
try:
|
||||
completion_kwargs = {
|
||||
"model": self.llm_model,
|
||||
"messages": messages,
|
||||
"api_key": self.llm_key
|
||||
}
|
||||
|
||||
# Add base_url if configured
|
||||
if self.llm_base_url:
|
||||
completion_kwargs["base_url"] = self.llm_base_url
|
||||
|
||||
response = completion(**completion_kwargs)
|
||||
|
||||
response_content = response.choices[0].message.content
|
||||
activity.logger.info(f"LLM response: {response_content}")
|
||||
|
||||
# Use the new sanitize function
|
||||
response_content = self.sanitize_json_response(response_content)
|
||||
|
||||
return self.parse_json_response(response_content)
|
||||
except Exception as e:
|
||||
print(f"Error in LLM completion: {str(e)}")
|
||||
raise
|
||||
|
||||
def parse_json_response(self, response_content: str) -> dict:
|
||||
"""
|
||||
@@ -220,259 +130,18 @@ class ToolActivities:
|
||||
print(f"Invalid JSON: {e}")
|
||||
raise
|
||||
|
||||
def prompt_llm_openai(self, input: ToolPromptInput) -> dict:
|
||||
if not self.openai_client:
|
||||
api_key = os.environ.get("OPENAI_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"OPENAI_API_KEY is not set in the environment variables but LLM_PROVIDER is 'openai'"
|
||||
)
|
||||
self.openai_client = OpenAI(api_key=api_key)
|
||||
print("Initialized OpenAI client on demand")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": input.context_instructions
|
||||
+ ". The current date is "
|
||||
+ datetime.now().strftime("%B %d, %Y"),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": input.prompt,
|
||||
},
|
||||
]
|
||||
|
||||
chat_completion = self.openai_client.chat.completions.create(
|
||||
model="gpt-4o", messages=messages # was gpt-4-0613
|
||||
)
|
||||
|
||||
response_content = chat_completion.choices[0].message.content
|
||||
activity.logger.info(f"ChatGPT response: {response_content}")
|
||||
|
||||
# Use the new sanitize function
|
||||
response_content = self.sanitize_json_response(response_content)
|
||||
|
||||
return self.parse_json_response(response_content)
|
||||
|
||||
def prompt_llm_grok(self, input: ToolPromptInput) -> dict:
|
||||
if not self.grok_client:
|
||||
api_key = os.environ.get("GROK_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"GROK_API_KEY is not set in the environment variables but LLM_PROVIDER is 'grok'"
|
||||
)
|
||||
self.grok_client = OpenAI(api_key=api_key, base_url="https://api.x.ai/v1")
|
||||
print("Initialized grok client on demand")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": input.context_instructions
|
||||
+ ". The current date is "
|
||||
+ datetime.now().strftime("%B %d, %Y"),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": input.prompt,
|
||||
},
|
||||
]
|
||||
|
||||
chat_completion = self.grok_client.chat.completions.create(
|
||||
model="grok-2-1212", messages=messages
|
||||
)
|
||||
|
||||
response_content = chat_completion.choices[0].message.content
|
||||
activity.logger.info(f"Grok response: {response_content}")
|
||||
|
||||
# Use the new sanitize function
|
||||
response_content = self.sanitize_json_response(response_content)
|
||||
|
||||
return self.parse_json_response(response_content)
|
||||
def prompt_llm_ollama(self, input: ToolPromptInput) -> dict:
|
||||
# If not yet initialized, try to do so now (this is a backup if warm_up_ollama wasn't called or failed)
|
||||
if not self.ollama_initialized:
|
||||
print(
|
||||
"Ollama model not pre-loaded. Loading now (this may take 30+ seconds)..."
|
||||
)
|
||||
try:
|
||||
self.warm_up_ollama()
|
||||
except Exception:
|
||||
# We already logged the error in warm_up_ollama, continue with the actual request
|
||||
pass
|
||||
|
||||
model_name = self.ollama_model_name or os.environ.get(
|
||||
"OLLAMA_MODEL_NAME", "qwen2.5:14b"
|
||||
)
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": input.context_instructions
|
||||
+ ". The current date is "
|
||||
+ get_current_date_human_readable(),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": input.prompt,
|
||||
},
|
||||
]
|
||||
|
||||
try:
|
||||
response: ChatResponse = chat(model=model_name, messages=messages)
|
||||
print(f"Chat response: {response.message.content}")
|
||||
|
||||
# Use the new sanitize function
|
||||
response_content = self.sanitize_json_response(response.message.content)
|
||||
return self.parse_json_response(response_content)
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
# Re-raise JSON-related exceptions to let Temporal retry the activity
|
||||
print(f"JSON parsing error with Ollama response: {str(e)}")
|
||||
raise
|
||||
except Exception as e:
|
||||
# Log and raise other exceptions that may need retrying
|
||||
print(f"Error in Ollama chat: {str(e)}")
|
||||
raise
|
||||
|
||||
def prompt_llm_google(self, input: ToolPromptInput) -> dict:
|
||||
if not self.genai_configured:
|
||||
api_key = os.environ.get("GOOGLE_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"GOOGLE_API_KEY is not set in the environment variables but LLM_PROVIDER is 'google'"
|
||||
)
|
||||
genai.configure(api_key=api_key)
|
||||
self.genai_configured = True
|
||||
print("Configured Google Generative AI on demand")
|
||||
|
||||
model = genai.GenerativeModel(
|
||||
"models/gemini-1.5-flash",
|
||||
system_instruction=input.context_instructions
|
||||
+ ". The current date is "
|
||||
+ datetime.now().strftime("%B %d, %Y"),
|
||||
)
|
||||
response = model.generate_content(input.prompt)
|
||||
response_content = response.text
|
||||
print(f"Google Gemini response: {response_content}")
|
||||
|
||||
# Use the new sanitize function
|
||||
response_content = self.sanitize_json_response(response_content)
|
||||
|
||||
return self.parse_json_response(response_content)
|
||||
|
||||
def prompt_llm_anthropic(self, input: ToolPromptInput) -> dict:
|
||||
if not self.anthropic_client:
|
||||
api_key = os.environ.get("ANTHROPIC_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"ANTHROPIC_API_KEY is not set in the environment variables but LLM_PROVIDER is 'anthropic'"
|
||||
)
|
||||
self.anthropic_client = anthropic.Anthropic(api_key=api_key)
|
||||
print("Initialized Anthropic client on demand")
|
||||
|
||||
response = self.anthropic_client.messages.create(
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
#model="claude-3-7-sonnet-20250219", # doesn't do as well
|
||||
max_tokens=1024,
|
||||
system=input.context_instructions
|
||||
+ ". The current date is "
|
||||
+ get_current_date_human_readable(),
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": input.prompt,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
response_content = response.content[0].text
|
||||
print(f"Anthropic response: {response_content}")
|
||||
|
||||
# Use the new sanitize function
|
||||
response_content = self.sanitize_json_response(response_content)
|
||||
|
||||
return self.parse_json_response(response_content)
|
||||
|
||||
def prompt_llm_deepseek(self, input: ToolPromptInput) -> dict:
|
||||
if not self.deepseek_client:
|
||||
api_key = os.environ.get("DEEPSEEK_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"DEEPSEEK_API_KEY is not set in the environment variables but LLM_PROVIDER is 'deepseek'"
|
||||
)
|
||||
self.deepseek_client = deepseek.DeepSeekAPI(api_key=api_key)
|
||||
print("Initialized DeepSeek client on demand")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": input.context_instructions
|
||||
+ ". The current date is "
|
||||
+ datetime.now().strftime("%B %d, %Y"),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": input.prompt,
|
||||
},
|
||||
]
|
||||
|
||||
response = self.deepseek_client.chat_completion(prompt=messages)
|
||||
response_content = response
|
||||
print(f"DeepSeek response: {response_content}")
|
||||
|
||||
# Use the new sanitize function
|
||||
response_content = self.sanitize_json_response(response_content)
|
||||
|
||||
return self.parse_json_response(response_content)
|
||||
|
||||
def sanitize_json_response(self, response_content: str) -> str:
|
||||
"""
|
||||
Extracts the JSON block from the response content as a string.
|
||||
Supports:
|
||||
- JSON surrounded by ```json and ```
|
||||
- Raw JSON input
|
||||
- JSON preceded or followed by extra text
|
||||
Rejects invalid input that doesn't contain JSON.
|
||||
Sanitizes the response content to ensure it's valid JSON.
|
||||
"""
|
||||
try:
|
||||
start_marker = "```json"
|
||||
end_marker = "```"
|
||||
# Remove any markdown code block markers
|
||||
response_content = response_content.replace("```json", "").replace("```", "")
|
||||
|
||||
# Remove any leading/trailing whitespace
|
||||
response_content = response_content.strip()
|
||||
|
||||
return response_content
|
||||
|
||||
json_str = None
|
||||
|
||||
# Case 1: JSON surrounded by markers
|
||||
if start_marker in response_content and end_marker in response_content:
|
||||
json_start = response_content.index(start_marker) + len(start_marker)
|
||||
json_end = response_content.index(end_marker, json_start)
|
||||
json_str = response_content[json_start:json_end].strip()
|
||||
|
||||
# Case 2: Text with valid JSON
|
||||
else:
|
||||
# Try to locate the JSON block by scanning for the first `{` and last `}`
|
||||
json_start = response_content.find("{")
|
||||
json_end = response_content.rfind("}")
|
||||
|
||||
if json_start != -1 and json_end != -1 and json_start < json_end:
|
||||
json_str = response_content[json_start : json_end + 1].strip()
|
||||
|
||||
# Validate and ensure the extracted JSON is valid
|
||||
if json_str:
|
||||
json.loads(json_str) # This will raise an error if the JSON is invalid
|
||||
return json_str
|
||||
|
||||
# If no valid JSON found, raise an error
|
||||
raise ValueError("Response does not contain valid JSON.")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
# Invalid JSON
|
||||
print(f"Invalid JSON detected in response: {response_content}")
|
||||
raise ValueError("Response does not contain valid JSON.")
|
||||
except Exception as e:
|
||||
# Other errors
|
||||
print(f"Error processing response: {str(e)}")
|
||||
print(f"Full response: {response_content}")
|
||||
raise
|
||||
|
||||
# get env vars for workflow
|
||||
@activity.defn
|
||||
async def get_wf_env_vars(self, input: EnvLookupInput) -> EnvLookupOutput:
|
||||
""" gets env vars for workflow as an activity result so it's deterministic
|
||||
@@ -498,18 +167,6 @@ class ToolActivities:
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def get_current_date_human_readable():
|
||||
"""
|
||||
Returns the current date in a human-readable format.
|
||||
|
||||
Example: Wednesday, January 1, 2025
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
return datetime.now().strftime("%A, %B %d, %Y")
|
||||
|
||||
|
||||
@activity.defn(dynamic=True)
|
||||
async def dynamic_tool_activity(args: Sequence[RawValue]) -> dict:
|
||||
from tools import get_handler
|
||||
|
||||
@@ -3,7 +3,7 @@ import NavBar from "../components/NavBar";
|
||||
import ChatWindow from "../components/ChatWindow";
|
||||
import { apiService } from "../services/api";
|
||||
|
||||
const POLL_INTERVAL = 500; // 0.5 seconds
|
||||
const POLL_INTERVAL = 600; // 0.6 seconds
|
||||
const INITIAL_ERROR_STATE = { visible: false, message: '' };
|
||||
const DEBOUNCE_DELAY = 300; // 300ms debounce for user input
|
||||
|
||||
|
||||
1562
poetry.lock
generated
1562
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -10,7 +10,7 @@ authors = [
|
||||
]
|
||||
readme = "README.md"
|
||||
|
||||
# By default, Poetry will find packages automatically,
|
||||
# By default, Poetry will find packages automatically,
|
||||
# but explicitly including them is fine:
|
||||
packages = [
|
||||
{ include = "**/*.py", from = "." }
|
||||
@@ -31,18 +31,14 @@ temporalio = "^1.8.0"
|
||||
|
||||
# Standard library modules (e.g. asyncio, collections) don't need to be added
|
||||
# since they're built-in for Python 3.8+.
|
||||
ollama = "^0.4.5"
|
||||
litellm = "^1.70.0"
|
||||
pyyaml = "^6.0.2"
|
||||
fastapi = "^0.115.6"
|
||||
uvicorn = "^0.34.0"
|
||||
python-dotenv = "^1.0.1"
|
||||
openai = "^1.59.2"
|
||||
stripe = "^11.4.1"
|
||||
google-generativeai = "^0.8.4"
|
||||
anthropic = "0.47.0"
|
||||
deepseek = "^1.0.0"
|
||||
requests = "^2.32.3"
|
||||
pandas = "^2.2.3"
|
||||
stripe = "^11.4.1"
|
||||
gtfs-kit = "^10.1.1"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
@@ -60,4 +56,5 @@ asyncio_mode = "auto"
|
||||
log_cli = true
|
||||
log_cli_level = "INFO"
|
||||
log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
norecursedirs = ["vibe"]
|
||||
@@ -1,23 +0,0 @@
|
||||
from ollama import chat, ChatResponse
|
||||
|
||||
|
||||
def main():
|
||||
model_name = "mistral"
|
||||
|
||||
# The messages to pass to the model
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Why is the sky blue?",
|
||||
}
|
||||
]
|
||||
|
||||
# Call ollama's chat function
|
||||
response: ChatResponse = chat(model=model_name, messages=messages)
|
||||
|
||||
# Print the full message content
|
||||
print(response.message.content)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -17,18 +17,18 @@ async def main():
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Print LLM configuration info
|
||||
llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower()
|
||||
print(f"Worker will use LLM provider: {llm_provider}")
|
||||
llm_model = os.environ.get("LLM_MODEL", "openai/gpt-4")
|
||||
print(f"Worker will use LLM model: {llm_model}")
|
||||
|
||||
# Create the client
|
||||
client = await get_temporal_client()
|
||||
|
||||
# Initialize the activities class once with the specified LLM provider
|
||||
# Initialize the activities class
|
||||
activities = ToolActivities()
|
||||
print(f"ToolActivities initialized with LLM provider: {llm_provider}")
|
||||
print(f"ToolActivities initialized with LLM model: {llm_model}")
|
||||
|
||||
# If using Ollama, pre-load the model to avoid cold start latency
|
||||
if llm_provider == "ollama":
|
||||
if llm_model.startswith("ollama"):
|
||||
print("\n======== OLLAMA MODEL INITIALIZATION ========")
|
||||
print("Ollama models need to be loaded into memory on first use.")
|
||||
print("This may take 30+ seconds depending on your hardware and model size.")
|
||||
@@ -51,8 +51,6 @@ async def main():
|
||||
print("Worker ready to process tasks!")
|
||||
logging.basicConfig(level=logging.WARN)
|
||||
|
||||
|
||||
|
||||
# Run the worker
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=100) as activity_executor:
|
||||
worker = Worker(
|
||||
|
||||
127
setup.md
127
setup.md
@@ -14,9 +14,40 @@ If you want to show confirmations/enable the debugging UI that shows tool args,
|
||||
SHOW_CONFIRM=True
|
||||
```
|
||||
|
||||
### Quick Start with Makefile
|
||||
|
||||
We've provided a Makefile to simplify the setup and running of the application. Here are the main commands:
|
||||
|
||||
```bash
|
||||
# Initial setup
|
||||
make setup # Creates virtual environment and installs dependencies
|
||||
make setup-venv # Creates virtual environment only
|
||||
make install # Installs all dependencies
|
||||
|
||||
# Running the application
|
||||
make run-worker # Starts the Temporal worker
|
||||
make run-api # Starts the API server
|
||||
make run-frontend # Starts the frontend development server
|
||||
|
||||
# Additional services
|
||||
make run-train-api # Starts the train API server
|
||||
make run-legacy-worker # Starts the legacy worker
|
||||
make run-enterprise # Builds and runs the enterprise .NET worker
|
||||
|
||||
# Development environment setup
|
||||
make setup-temporal-mac # Installs and starts Temporal server on Mac
|
||||
|
||||
# View all available commands
|
||||
make help
|
||||
```
|
||||
|
||||
### Manual Setup (Alternative to Makefile)
|
||||
|
||||
If you prefer to run commands manually, follow these steps:
|
||||
|
||||
### Agent Goal Configuration
|
||||
|
||||
The agent can be configured to pursue different goals using the `AGENT_GOAL` environment variable in your `.env` file. If unset, default is `goal_choose_agent_type`.
|
||||
The agent can be configured to pursue different goals using the `AGENT_GOAL` environment variable in your `.env` file. If unset, default is `goal_choose_agent_type`.
|
||||
|
||||
If the first goal is `goal_choose_agent_type` the agent will support multiple goals using goal categories defined by `GOAL_CATEGORIES` in your .env file. If unset, default is all. We recommend starting with `fin`.
|
||||
```bash
|
||||
@@ -25,54 +56,41 @@ GOAL_CATEGORIES=hr,travel-flights,travel-trains,fin
|
||||
|
||||
See the section Goal-Specific Tool Configuration below for tool configuration for specific goals.
|
||||
|
||||
### LLM Provider Configuration
|
||||
### LLM Configuration
|
||||
|
||||
The agent can use OpenAI's GPT-4o, Google Gemini, Anthropic Claude, or a local LLM via Ollama. Set the `LLM_PROVIDER` environment variable in your `.env` file to choose the desired provider:
|
||||
Note: We recommend using OpenAI's GPT-4o or Claude 3.5 Sonnet for the best results. There can be significant differences in performance and capabilities between models, especially for complex tasks.
|
||||
|
||||
- `LLM_PROVIDER=openai` for OpenAI's GPT-4o
|
||||
- `LLM_PROVIDER=google` for Google Gemini
|
||||
- `LLM_PROVIDER=anthropic` for Anthropic Claude
|
||||
- `LLM_PROVIDER=deepseek` for DeepSeek-V3
|
||||
- `LLM_PROVIDER=ollama` for running LLMs via [Ollama](https://ollama.ai) (not recommended for this use case)
|
||||
The agent uses LiteLLM to interact with various LLM providers. Configure the following environment variables in your `.env` file:
|
||||
|
||||
### Option 1: OpenAI
|
||||
- `LLM_MODEL`: The model to use (e.g., "openai/gpt-4o", "anthropic/claude-3-sonnet", "google/gemini-pro", etc.)
|
||||
- `LLM_KEY`: Your API key for the selected provider
|
||||
- `LLM_BASE_URL`: (Optional) Custom base URL for the LLM provider. Useful for:
|
||||
- Using Ollama with a custom endpoint
|
||||
- Using a proxy or custom API gateway
|
||||
- Testing with different API versions
|
||||
|
||||
If using OpenAI, ensure you have an OpenAI key for the GPT-4o model. Set this in the `OPENAI_API_KEY` environment variable in `.env`.
|
||||
LiteLLM will automatically detect the provider based on the model name. For example:
|
||||
- For OpenAI models: `openai/gpt-4o` or `openai/gpt-3.5-turbo`
|
||||
- For Anthropic models: `anthropic/claude-3-sonnet`
|
||||
- For Google models: `google/gemini-pro`
|
||||
- For Ollama models: `ollama/mistral` (requires `LLM_BASE_URL` set to your Ollama server)
|
||||
|
||||
### Option 2: Google Gemini
|
||||
Example configurations:
|
||||
```bash
|
||||
# For OpenAI
|
||||
LLM_MODEL=openai/gpt-4o
|
||||
LLM_KEY=your-api-key-here
|
||||
|
||||
To use Google Gemini:
|
||||
# For Anthropic
|
||||
LLM_MODEL=anthropic/claude-3-sonnet
|
||||
LLM_KEY=your-api-key-here
|
||||
|
||||
1. Obtain a Google API key and set it in the `GOOGLE_API_KEY` environment variable in `.env`.
|
||||
2. Set `LLM_PROVIDER=google` in your `.env` file.
|
||||
# For Ollama with custom URL
|
||||
LLM_MODEL=ollama/mistral
|
||||
LLM_BASE_URL=http://localhost:11434
|
||||
```
|
||||
|
||||
### Option 3: Anthropic Claude (recommended)
|
||||
|
||||
I find that Claude Sonnet 3.5 performs better than the other hosted LLMs for this use case.
|
||||
|
||||
To use Anthropic:
|
||||
|
||||
1. Obtain an Anthropic API key and set it in the `ANTHROPIC_API_KEY` environment variable in `.env`.
|
||||
2. Set `LLM_PROVIDER=anthropic` in your `.env` file.
|
||||
|
||||
### Option 4: Deepseek-V3
|
||||
|
||||
To use Deepseek-V3:
|
||||
|
||||
1. Obtain a Deepseek API key and set it in the `DEEPSEEK_API_KEY` environment variable in `.env`.
|
||||
2. Set `LLM_PROVIDER=deepseek` in your `.env` file.
|
||||
|
||||
### Option 5: Local LLM via Ollama (not recommended)
|
||||
|
||||
To use a local LLM with Ollama:
|
||||
|
||||
1. Install [Ollama](https://ollama.com) and the [Qwen2.5 14B](https://ollama.com/library/qwen2.5) model.
|
||||
- Run `ollama run <OLLAMA_MODEL_NAME>` to start the model. Note that this model is about 9GB to download.
|
||||
- Example: `ollama run qwen2.5:14b`
|
||||
|
||||
2. Set `LLM_PROVIDER=ollama` in your `.env` file and `OLLAMA_MODEL_NAME` to the name of the model you installed.
|
||||
|
||||
Note: I found the other (hosted) LLMs to be MUCH more reliable for this use case. However, you can switch to Ollama if desired, and choose a suitably large model if your computer has the resources.
|
||||
For a complete list of supported models and providers, visit the [LiteLLM documentation](https://docs.litellm.ai/docs/providers).
|
||||
|
||||
## Configuring Temporal Connection
|
||||
|
||||
@@ -149,7 +167,7 @@ npm install
|
||||
npx vite
|
||||
```
|
||||
Access the UI at `http://localhost:5173`
|
||||
|
||||
|
||||
|
||||
## Goal-Specific Tool Configuration
|
||||
Here is configuration guidance for specific goals. Travel and financial goals have configuration & setup as below.
|
||||
@@ -157,7 +175,7 @@ Here is configuration guidance for specific goals. Travel and financial goals ha
|
||||
- `AGENT_GOAL=goal_event_flight_invoice` - Helps users find events, book flights, and arrange train travel with invoice generation
|
||||
- This is the scenario in the [original video](https://www.youtube.com/watch?v=GEXllEH2XiQ)
|
||||
|
||||
#### Configuring Agent Goal: goal_event_flight_invoice
|
||||
#### Configuring Agent Goal: goal_event_flight_invoice
|
||||
* The agent uses a mock function to search for events. This has zero configuration.
|
||||
* By default the agent uses a mock function to search for flights.
|
||||
* If you want to use the real flights API, go to `tools/search_flights.py` and replace the `search_flights` function with `search_flights_real_api` that exists in the same file.
|
||||
@@ -166,16 +184,15 @@ Here is configuration guidance for specific goals. Travel and financial goals ha
|
||||
* Requires a Stripe key for the `create_invoice` tool. Set this in the `STRIPE_API_KEY` environment variable in .env
|
||||
* It's free to sign up and get a key at [Stripe](https://stripe.com/)
|
||||
* Set permissions for read-write on: `Credit Notes, Invoices, Customers and Customer Sessions`
|
||||
* If you're lazy go to `tools/create_invoice.py` and replace the `create_invoice` function with the mock `create_invoice_example` that exists in the same file.
|
||||
* If you don't have a Stripe key, comment out the STRIPE_API_KEY in the .env file, and a dummy invoice will be created rather than a Stripe invoice. The function can be found in `tools/create_invoice.py`
|
||||
|
||||
### Goal: Find a Premier League match, book train tickets to it and invoice the user for the cost (Replay 2025 Keynote)
|
||||
- `AGENT_GOAL=goal_match_train_invoice` - Focuses on Premier League match attendance with train booking and invoice generation
|
||||
- This goal was part of [Temporal's Replay 2025 conference keynote demo](https://www.youtube.com/watch?v=YDxAWrIBQNE)
|
||||
- Note, there is failure built in to this demo (the train booking step) to show how the agent can handle failures and retry. See Tool Configuration below for details.
|
||||
#### Configuring Agent Goal: goal_match_train_invoice
|
||||
#### Configuring Agent Goal: goal_match_train_invoice
|
||||
NOTE: This goal was developed for an on-stage demo and has failure (and its resolution) built in to show how the agent can handle failures and retry.
|
||||
* Finding a match requires a key from [Football Data](https://www.football-data.org). Sign up for a free account, then see the 'My Account' page to get your API token. Set `FOOTBALL_DATA_API_KEY` to this value.
|
||||
* If you're lazy go to `tools/search_fixtures.py` and replace the `search_fixtures` function with the mock `search_fixtures_example` that exists in the same file.
|
||||
* Omit `FOOTBALL_DATA_API_KEY` from .env for the `SearchFixtures` tool to automatically return mock Premier League fixtures. Finding a real match requires a key from [Football Data](https://www.football-data.org). Sign up for a free account, then see the 'My Account' page to get your API token.
|
||||
* We use a mock function to search for trains. Start the train API server to use the real API: `python thirdparty/train_api.py`
|
||||
* * The train activity is 'enterprise' so it's written in C# and requires a .NET runtime. See the [.NET backend](#net-(enterprise)-backend) section for details on running it.
|
||||
* Requires a Stripe key for the `create_invoice` tool. Set this in the `STRIPE_API_KEY` environment variable in .env
|
||||
@@ -195,15 +212,15 @@ poetry run python thirdparty/train_api.py
|
||||
|
||||
##### Python Train Legacy Worker
|
||||
> Agent Goal: goal_match_train_invoice only
|
||||
|
||||
|
||||
These are Python activities that fail (raise NotImplemented) to show how Temporal handles a failure. You can run these activities with.
|
||||
|
||||
|
||||
```bash
|
||||
poetry run python scripts/run_legacy_worker.py
|
||||
poetry run python scripts/run_legacy_worker.py
|
||||
```
|
||||
|
||||
|
||||
The activity will fail and be retried infinitely. To rescue the activity (and its corresponding workflows), kill the worker and run the .NET one in the section below.
|
||||
|
||||
|
||||
##### .NET (enterprise) Worker ;)
|
||||
We have activities written in C# to call the train APIs.
|
||||
```bash
|
||||
@@ -216,12 +233,12 @@ If you're running your train API above on a different host/port then change the
|
||||
#### Goals: FIN - Money Movement and Loan Application
|
||||
Make sure you have the mock users you want (such as yourself) in [the account mock data file](./tools/data/customer_account_data.json).
|
||||
|
||||
- `AGENT_GOAL=goal_fin_move_money` - This scenario _can_ initiate a secondary workflow to move money. Check out [this repo](https://github.com/temporal-sa/temporal-money-transfer-java) - you'll need to get the worker running and connected to the same account as the agentic worker.
|
||||
- `AGENT_GOAL=goal_fin_move_money` - This scenario _can_ initiate a secondary workflow to move money. Check out [this repo](https://github.com/temporal-sa/temporal-money-transfer-java) - you'll need to get the worker running and connected to the same account as the agentic worker.
|
||||
By default it will _not_ make a real workflow, it'll just fake it. If you get the worker running and want to start a workflow, in your [.env](./.env):
|
||||
```bash
|
||||
FIN_START_REAL_WORKFLOW=FALSE #set this to true to start a real workflow
|
||||
```
|
||||
- `AGENT_GOAL=goal_fin_loan_application` - This scenario _can_ initiate a secondary workflow to apply for a loan. Check out [this repo](https://github.com/temporal-sa/temporal-latency-optimization-scenarios) - you'll need to get the worker running and connected to the same account as the agentic worker.
|
||||
- `AGENT_GOAL=goal_fin_loan_application` - This scenario _can_ initiate a secondary workflow to apply for a loan. Check out [this repo](https://github.com/temporal-sa/temporal-latency-optimization-scenarios) - you'll need to get the worker running and connected to the same account as the agentic worker.
|
||||
By default it will _not_ make a real workflow, it'll just fake it. If you get the worker running and want to start a workflow, in your [.env](./.env):
|
||||
```bash
|
||||
FIN_START_REAL_WORKFLOW=FALSE #set this to true to start a real workflow
|
||||
@@ -252,4 +269,4 @@ For more details, check out [adding goals and tools guide](./adding-goals-and-to
|
||||
[ ] `cd frontend`, `npm install`, `npx vite` <br />
|
||||
[ ] Access the UI at `http://localhost:5173` <br />
|
||||
|
||||
And that's it! Happy AI Agent Exploring!
|
||||
And that's it! Happy AI Agent Exploring!
|
||||
|
||||
350
tests/README.md
Normal file
350
tests/README.md
Normal file
@@ -0,0 +1,350 @@
|
||||
# Temporal AI Agent - Testing Guide
|
||||
|
||||
This directory contains comprehensive tests for the Temporal AI Agent project. The tests cover workflows, activities, and integration scenarios using Temporal's testing framework.
|
||||
|
||||
## Test Structure
|
||||
|
||||
```
|
||||
tests/
|
||||
├── README.md # This file - testing documentation
|
||||
├── conftest.py # Test configuration and fixtures
|
||||
├── test_agent_goal_workflow.py # Workflow tests
|
||||
├── test_tool_activities.py # Activity tests
|
||||
└── workflowtests/ # Legacy workflow tests
|
||||
└── agent_goal_workflow_test.py
|
||||
```
|
||||
|
||||
## Test Types
|
||||
|
||||
### 1. Workflow Tests (`test_agent_goal_workflow.py`)
|
||||
|
||||
Tests the main `AgentGoalWorkflow` class covering:
|
||||
|
||||
- **Workflow Initialization**: Basic workflow startup and state management
|
||||
- **Signal Handling**: Testing user_prompt, confirm, end_chat signals
|
||||
- **Query Methods**: Testing all workflow query endpoints
|
||||
- **State Management**: Conversation history, goal changes, tool data
|
||||
- **Validation Flow**: Prompt validation and error handling
|
||||
- **Tool Execution Flow**: Confirmation and tool execution cycles
|
||||
|
||||
### 2. Activity Tests (`test_tool_activities.py`)
|
||||
|
||||
Tests the `ToolActivities` class and `dynamic_tool_activity` function:
|
||||
|
||||
- **LLM Integration**: Testing agent_toolPlanner with mocked LLM responses
|
||||
- **Validation Logic**: Testing agent_validatePrompt with various scenarios
|
||||
- **Environment Configuration**: Testing get_wf_env_vars with different env setups
|
||||
- **JSON Processing**: Testing response parsing and sanitization
|
||||
- **Dynamic Tool Execution**: Testing the dynamic activity dispatcher
|
||||
- **Integration**: End-to-end activity execution in Temporal workers
|
||||
|
||||
### 3. Configuration Tests (`conftest.py`)
|
||||
|
||||
Provides shared test fixtures and configuration:
|
||||
|
||||
- **Temporal Environment**: Local and time-skipping test environments
|
||||
- **Sample Data**: Pre-configured agent goals, conversation history, inputs
|
||||
- **Test Client**: Configured Temporal client for testing
|
||||
|
||||
## Running Tests
|
||||
|
||||
### Prerequisites
|
||||
|
||||
Ensure you have the required dependencies installed:
|
||||
|
||||
```bash
|
||||
poetry install --with dev
|
||||
```
|
||||
|
||||
### Basic Test Execution
|
||||
|
||||
Run all tests:
|
||||
```bash
|
||||
poetry run pytest
|
||||
```
|
||||
|
||||
Run specific test files:
|
||||
```bash
|
||||
# Workflow tests only
|
||||
poetry run pytest tests/test_agent_goal_workflow.py
|
||||
|
||||
# Activity tests only
|
||||
poetry run pytest tests/test_tool_activities.py
|
||||
|
||||
# Legacy tests
|
||||
poetry run pytest tests/workflowtests/
|
||||
```
|
||||
|
||||
Run with verbose output:
|
||||
```bash
|
||||
poetry run pytest -v
|
||||
```
|
||||
|
||||
### Test Environment Options
|
||||
|
||||
The tests support different Temporal environments via the `--workflow-environment` flag:
|
||||
|
||||
#### Local Environment (Default)
|
||||
Uses a local Temporal test server:
|
||||
```bash
|
||||
poetry run pytest --workflow-environment=local
|
||||
```
|
||||
|
||||
#### Time-Skipping Environment
|
||||
Uses Temporal's time-skipping test environment for faster execution:
|
||||
```bash
|
||||
poetry run pytest --workflow-environment=time-skipping
|
||||
```
|
||||
|
||||
#### External Server
|
||||
Connect to an existing Temporal server:
|
||||
```bash
|
||||
poetry run pytest --workflow-environment=localhost:7233
|
||||
```
|
||||
|
||||
#### Setup Script for AI Agent environments such as OpenAI Codex
|
||||
```bash
|
||||
export SHELL=/bin/bash
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
export PATH="$HOME/.local/bin:$PATH"
|
||||
ls
|
||||
poetry install --with dev
|
||||
cd frontend
|
||||
npm install
|
||||
cd ..
|
||||
|
||||
# Pre-download the temporal test server binary
|
||||
poetry run python3 -c "
|
||||
import asyncio
|
||||
import sys
|
||||
from temporalio.testing import WorkflowEnvironment
|
||||
|
||||
async def predownload():
|
||||
try:
|
||||
print('Starting test server download...')
|
||||
env = await WorkflowEnvironment.start_time_skipping()
|
||||
print('Test server downloaded and started successfully')
|
||||
await env.shutdown()
|
||||
print('Test server shut down successfully')
|
||||
except Exception as e:
|
||||
print(f'Error during download: {e}')
|
||||
sys.exit(1)
|
||||
|
||||
asyncio.run(predownload())
|
||||
"
|
||||
```
|
||||
|
||||
### Filtering Tests
|
||||
|
||||
Run tests by pattern:
|
||||
```bash
|
||||
# Run only validation tests
|
||||
poetry run pytest -k "validation"
|
||||
|
||||
# Run only workflow tests
|
||||
poetry run pytest -k "workflow"
|
||||
|
||||
# Run only activity tests
|
||||
poetry run pytest -k "activity"
|
||||
```
|
||||
|
||||
Run tests by marker (if you add custom markers):
|
||||
```bash
|
||||
# Run only integration tests
|
||||
poetry run pytest -m integration
|
||||
|
||||
# Skip slow tests
|
||||
poetry run pytest -m "not slow"
|
||||
```
|
||||
|
||||
## Test Configuration
|
||||
|
||||
### Test Discovery
|
||||
|
||||
The `vibe/` directory is excluded from test collection to avoid conflicts with sample tests. This is configured in `pyproject.toml`:
|
||||
|
||||
```toml
|
||||
[tool.pytest.ini_options]
|
||||
norecursedirs = ["vibe"]
|
||||
```
|
||||
|
||||
### Environment Variables
|
||||
|
||||
Tests respect the following environment variables:
|
||||
|
||||
- `LLM_MODEL`: Model to use for LLM testing (defaults to "openai/gpt-4")
|
||||
- `LLM_KEY`: API key for LLM service
|
||||
- `LLM_BASE_URL`: Custom base URL for LLM service
|
||||
- `SHOW_CONFIRM`: Whether to show confirmation dialogs
|
||||
- `AGENT_GOAL`: Default agent goal setting
|
||||
|
||||
### Mocking Strategy
|
||||
|
||||
The tests use extensive mocking to avoid external dependencies:
|
||||
|
||||
- **LLM Calls**: Mocked using `unittest.mock` to avoid actual API calls
|
||||
- **Tool Handlers**: Mocked to test workflow logic without tool execution
|
||||
- **Environment Variables**: Patched for consistent test environments
|
||||
|
||||
## Writing New Tests
|
||||
|
||||
### Test Naming Convention
|
||||
|
||||
- Test files: `test_<module_name>.py`
|
||||
- Test classes: `Test<ClassName>`
|
||||
- Test methods: `test_<functionality>_<scenario>`
|
||||
|
||||
Example:
|
||||
```python
|
||||
class TestAgentGoalWorkflow:
|
||||
async def test_user_prompt_signal_valid_input(self, client, sample_combined_input):
|
||||
# Test implementation
|
||||
pass
|
||||
```
|
||||
|
||||
### Using Fixtures
|
||||
|
||||
Leverage the provided fixtures for consistent test data:
|
||||
|
||||
```python
|
||||
async def test_my_workflow(self, client, sample_agent_goal, sample_conversation_history):
|
||||
# client: Temporal test client
|
||||
# sample_agent_goal: Pre-configured AgentGoal
|
||||
# sample_conversation_history: Sample conversation data
|
||||
pass
|
||||
```
|
||||
|
||||
### Mocking External Dependencies
|
||||
|
||||
Always mock external services:
|
||||
|
||||
```python
|
||||
@patch('activities.tool_activities.completion')
|
||||
async def test_llm_integration(self, mock_completion):
|
||||
mock_completion.return_value.choices[0].message.content = '{"test": "response"}'
|
||||
# Test implementation
|
||||
```
|
||||
|
||||
### Testing Workflow Signals and Queries
|
||||
|
||||
```python
|
||||
async def test_workflow_signal(self, client, sample_combined_input):
|
||||
# Start workflow
|
||||
handle = await client.start_workflow(
|
||||
AgentGoalWorkflow.run,
|
||||
sample_combined_input,
|
||||
id=str(uuid.uuid4()),
|
||||
task_queue=task_queue_name,
|
||||
)
|
||||
|
||||
# Send signal
|
||||
await handle.signal(AgentGoalWorkflow.user_prompt, "test message")
|
||||
|
||||
# Query state
|
||||
conversation = await handle.query(AgentGoalWorkflow.get_conversation_history)
|
||||
|
||||
# End workflow
|
||||
await handle.signal(AgentGoalWorkflow.end_chat)
|
||||
result = await handle.result()
|
||||
```
|
||||
|
||||
## Test Data and Fixtures
|
||||
|
||||
### Sample Agent Goal
|
||||
|
||||
The `sample_agent_goal` fixture provides a basic agent goal with:
|
||||
- Goal ID: "test_goal"
|
||||
- One test tool with a required string argument
|
||||
- Suitable for most workflow testing scenarios
|
||||
|
||||
### Sample Conversation History
|
||||
|
||||
The `sample_conversation_history` fixture provides:
|
||||
- Basic user and agent message exchange
|
||||
- Proper message format for testing
|
||||
|
||||
### Sample Combined Input
|
||||
|
||||
The `sample_combined_input` fixture provides:
|
||||
- Complete workflow input with agent goal and tool params
|
||||
- Conversation summary and prompt queue
|
||||
- Ready for workflow execution
|
||||
|
||||
## Debugging Tests
|
||||
|
||||
### Verbose Logging
|
||||
|
||||
Enable detailed logging:
|
||||
```bash
|
||||
poetry run pytest --log-cli-level=DEBUG -s
|
||||
```
|
||||
|
||||
### Temporal Web UI
|
||||
|
||||
When using local environment, access Temporal Web UI at http://localhost:8233 to inspect workflow executions during tests.
|
||||
|
||||
### Test Isolation
|
||||
|
||||
Each test uses unique task queue names to prevent interference:
|
||||
```python
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
```
|
||||
|
||||
## Continuous Integration
|
||||
|
||||
### GitHub Actions Example
|
||||
|
||||
```yaml
|
||||
name: Test
|
||||
on: [push, pull_request]
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- run: pip install poetry
|
||||
- run: poetry install --with dev
|
||||
- run: poetry run pytest --workflow-environment=time-skipping
|
||||
```
|
||||
|
||||
### Test Coverage
|
||||
|
||||
Generate coverage reports:
|
||||
```bash
|
||||
poetry add --group dev pytest-cov
|
||||
poetry run pytest --cov=workflows --cov=activities --cov-report=html
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Mock External Dependencies**: Always mock LLM calls, file I/O, and network requests
|
||||
2. **Use Time-Skipping**: For CI/CD, prefer time-skipping environment for speed
|
||||
3. **Unique Identifiers**: Use UUIDs for workflow IDs and task queues
|
||||
4. **Clean Shutdown**: Always end workflows properly in tests
|
||||
5. **Descriptive Names**: Use clear, descriptive test names
|
||||
6. **Test Edge Cases**: Include error scenarios and validation failures
|
||||
7. **Keep Tests Fast**: Use mocks to avoid slow external calls
|
||||
8. **Isolate Tests**: Ensure tests don't depend on each other
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Workflow Timeout**: Increase timeouts or use time-skipping environment
|
||||
2. **Mock Not Working**: Check patch decorators and import paths
|
||||
3. **Test Hanging**: Ensure workflows are properly ended with signals
|
||||
4. **Environment Issues**: Check environment variable settings
|
||||
|
||||
### Getting Help
|
||||
|
||||
- Check Temporal Python SDK documentation
|
||||
- Review existing test patterns in the codebase
|
||||
- Use `poetry run pytest --collect-only` to verify test discovery
|
||||
- Run with `-v` flag for detailed output
|
||||
|
||||
## Legacy Tests
|
||||
|
||||
The `workflowtests/` directory contains legacy tests. New tests should be added to the main `tests/` directory following the patterns established in this guide.
|
||||
@@ -41,7 +41,12 @@ def event_loop():
|
||||
async def env(request) -> AsyncGenerator[WorkflowEnvironment, None]:
|
||||
env_type = request.config.getoption("--workflow-environment")
|
||||
if env_type == "local":
|
||||
env = await WorkflowEnvironment.start_local()
|
||||
env = await WorkflowEnvironment.start_local(
|
||||
dev_server_extra_args=[
|
||||
"--dynamic-config-value",
|
||||
"frontend.enableExecuteMultiOperation=true",
|
||||
]
|
||||
)
|
||||
elif env_type == "time-skipping":
|
||||
env = await WorkflowEnvironment.start_time_skipping()
|
||||
else:
|
||||
@@ -53,3 +58,59 @@ async def env(request) -> AsyncGenerator[WorkflowEnvironment, None]:
|
||||
@pytest_asyncio.fixture
|
||||
async def client(env: WorkflowEnvironment) -> Client:
|
||||
return env.client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_agent_goal():
|
||||
"""Sample agent goal for testing."""
|
||||
from models.tool_definitions import AgentGoal, ToolDefinition, ToolArgument
|
||||
|
||||
return AgentGoal(
|
||||
id="test_goal",
|
||||
category_tag="test",
|
||||
agent_name="TestAgent",
|
||||
agent_friendly_description="A test agent for testing purposes",
|
||||
description="Test goal for agent testing",
|
||||
tools=[
|
||||
ToolDefinition(
|
||||
name="TestTool",
|
||||
description="A test tool for testing purposes",
|
||||
arguments=[
|
||||
ToolArgument(
|
||||
name="test_arg",
|
||||
type="string",
|
||||
description="A test argument"
|
||||
)
|
||||
]
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_conversation_history():
|
||||
"""Sample conversation history for testing."""
|
||||
return {
|
||||
"messages": [
|
||||
{"actor": "user", "response": "Hello, I need help with testing"},
|
||||
{"actor": "agent", "response": "I can help you with that"}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_combined_input(sample_agent_goal):
|
||||
"""Sample combined input for workflow testing."""
|
||||
from models.data_types import CombinedInput, AgentGoalWorkflowParams
|
||||
|
||||
from collections import deque
|
||||
|
||||
tool_params = AgentGoalWorkflowParams(
|
||||
conversation_summary="Test conversation summary",
|
||||
prompt_queue=deque() # Start with empty queue for most tests
|
||||
)
|
||||
|
||||
return CombinedInput(
|
||||
agent_goal=sample_agent_goal,
|
||||
tool_params=tool_params
|
||||
)
|
||||
|
||||
540
tests/test_agent_goal_workflow.py
Normal file
540
tests/test_agent_goal_workflow.py
Normal file
@@ -0,0 +1,540 @@
|
||||
import uuid
|
||||
from unittest.mock import patch, MagicMock
|
||||
import pytest
|
||||
from temporalio import activity
|
||||
from temporalio.client import Client
|
||||
from temporalio.worker import Worker
|
||||
from temporalio.testing import WorkflowEnvironment
|
||||
|
||||
from workflows.agent_goal_workflow import AgentGoalWorkflow
|
||||
from activities.tool_activities import ToolActivities
|
||||
from models.data_types import (
|
||||
CombinedInput,
|
||||
AgentGoalWorkflowParams,
|
||||
ConversationHistory,
|
||||
ValidationResult,
|
||||
ValidationInput,
|
||||
EnvLookupOutput,
|
||||
EnvLookupInput,
|
||||
ToolPromptInput
|
||||
)
|
||||
|
||||
|
||||
class TestAgentGoalWorkflow:
|
||||
"""Test cases for AgentGoalWorkflow."""
|
||||
|
||||
async def test_workflow_initialization(self, client: Client, sample_combined_input: CombinedInput):
|
||||
"""Test workflow can be initialized and started."""
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
|
||||
# Create mock activity functions with proper signatures
|
||||
@activity.defn(name="get_wf_env_vars")
|
||||
async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
|
||||
return EnvLookupOutput(
|
||||
show_confirm=True,
|
||||
multi_goal_mode=True
|
||||
)
|
||||
|
||||
async with Worker(
|
||||
client,
|
||||
task_queue=task_queue_name,
|
||||
workflows=[AgentGoalWorkflow],
|
||||
activities=[mock_get_wf_env_vars],
|
||||
):
|
||||
# Start workflow but don't wait for completion since it runs indefinitely
|
||||
handle = await client.start_workflow(
|
||||
AgentGoalWorkflow.run,
|
||||
sample_combined_input,
|
||||
id=str(uuid.uuid4()),
|
||||
task_queue=task_queue_name,
|
||||
)
|
||||
|
||||
# Verify workflow is running
|
||||
assert handle is not None
|
||||
|
||||
# Query the workflow to check initial state
|
||||
conversation_history = await handle.query(AgentGoalWorkflow.get_conversation_history)
|
||||
assert isinstance(conversation_history, dict)
|
||||
assert "messages" in conversation_history
|
||||
|
||||
# Test goal query
|
||||
agent_goal = await handle.query(AgentGoalWorkflow.get_agent_goal)
|
||||
assert agent_goal == sample_combined_input.agent_goal
|
||||
|
||||
# End the workflow
|
||||
await handle.signal(AgentGoalWorkflow.end_chat)
|
||||
result = await handle.result()
|
||||
assert isinstance(result, str)
|
||||
|
||||
async def test_user_prompt_signal(self, client: Client, sample_combined_input: CombinedInput):
|
||||
"""Test user_prompt signal handling."""
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
|
||||
# Create mock activity functions with proper signatures
|
||||
@activity.defn(name="get_wf_env_vars")
|
||||
async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
|
||||
return EnvLookupOutput(
|
||||
show_confirm=True,
|
||||
multi_goal_mode=True
|
||||
)
|
||||
|
||||
@activity.defn(name="agent_validatePrompt")
|
||||
async def mock_agent_validatePrompt(validation_input: ValidationInput) -> ValidationResult:
|
||||
return ValidationResult(
|
||||
validationResult=True,
|
||||
validationFailedReason={}
|
||||
)
|
||||
|
||||
@activity.defn(name="agent_toolPlanner")
|
||||
async def mock_agent_toolPlanner(input: ToolPromptInput) -> dict:
|
||||
return {
|
||||
"next": "done",
|
||||
"response": "Test response from LLM"
|
||||
}
|
||||
|
||||
async with Worker(
|
||||
client,
|
||||
task_queue=task_queue_name,
|
||||
workflows=[AgentGoalWorkflow],
|
||||
activities=[
|
||||
mock_get_wf_env_vars,
|
||||
mock_agent_validatePrompt,
|
||||
mock_agent_toolPlanner
|
||||
],
|
||||
):
|
||||
handle = await client.start_workflow(
|
||||
AgentGoalWorkflow.run,
|
||||
sample_combined_input,
|
||||
id=str(uuid.uuid4()),
|
||||
task_queue=task_queue_name,
|
||||
)
|
||||
|
||||
# Send user prompt
|
||||
await handle.signal(AgentGoalWorkflow.user_prompt, "Hello, this is a test message")
|
||||
|
||||
# Wait for workflow to complete (it should end due to "done" next step)
|
||||
result = await handle.result()
|
||||
assert isinstance(result, str)
|
||||
|
||||
# Verify the conversation includes our message
|
||||
import json
|
||||
try:
|
||||
conversation_history = json.loads(result.replace("'", '"'))
|
||||
except:
|
||||
# Fallback to eval if json fails
|
||||
conversation_history = eval(result)
|
||||
messages = conversation_history["messages"]
|
||||
|
||||
# Should have our user message and agent response
|
||||
user_messages = [msg for msg in messages if msg["actor"] == "user"]
|
||||
assert len(user_messages) > 0
|
||||
assert any("Hello, this is a test message" in str(msg["response"]) for msg in user_messages)
|
||||
|
||||
async def test_confirm_signal(self, client: Client, sample_combined_input: CombinedInput):
|
||||
"""Test confirm signal handling for tool execution."""
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
|
||||
# Create mock activity functions with proper signatures
|
||||
@activity.defn(name="get_wf_env_vars")
|
||||
async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
|
||||
return EnvLookupOutput(
|
||||
show_confirm=True,
|
||||
multi_goal_mode=True
|
||||
)
|
||||
|
||||
@activity.defn(name="agent_validatePrompt")
|
||||
async def mock_agent_validatePrompt(validation_input: ValidationInput) -> ValidationResult:
|
||||
return ValidationResult(
|
||||
validationResult=True,
|
||||
validationFailedReason={}
|
||||
)
|
||||
|
||||
@activity.defn(name="agent_toolPlanner")
|
||||
async def mock_agent_toolPlanner(input: ToolPromptInput) -> dict:
|
||||
return {
|
||||
"next": "confirm",
|
||||
"tool": "TestTool",
|
||||
"args": {"test_arg": "test_value"},
|
||||
"response": "Ready to execute tool"
|
||||
}
|
||||
|
||||
@activity.defn(name="TestTool")
|
||||
async def mock_test_tool(args: dict) -> dict:
|
||||
return {"result": "Test tool executed successfully"}
|
||||
|
||||
async with Worker(
|
||||
client,
|
||||
task_queue=task_queue_name,
|
||||
workflows=[AgentGoalWorkflow],
|
||||
activities=[
|
||||
mock_get_wf_env_vars,
|
||||
mock_agent_validatePrompt,
|
||||
mock_agent_toolPlanner,
|
||||
mock_test_tool
|
||||
],
|
||||
):
|
||||
handle = await client.start_workflow(
|
||||
AgentGoalWorkflow.run,
|
||||
sample_combined_input,
|
||||
id=str(uuid.uuid4()),
|
||||
task_queue=task_queue_name,
|
||||
)
|
||||
|
||||
# Send user prompt that will require confirmation
|
||||
await handle.signal(AgentGoalWorkflow.user_prompt, "Execute the test tool")
|
||||
|
||||
# Query to check tool data is set
|
||||
import asyncio
|
||||
await asyncio.sleep(0.1) # Give workflow time to process
|
||||
|
||||
tool_data = await handle.query(AgentGoalWorkflow.get_latest_tool_data)
|
||||
if tool_data:
|
||||
assert tool_data.get("tool") == "TestTool"
|
||||
assert tool_data.get("next") == "confirm"
|
||||
|
||||
# Send confirmation and end chat
|
||||
await handle.signal(AgentGoalWorkflow.confirm)
|
||||
await handle.signal(AgentGoalWorkflow.end_chat)
|
||||
|
||||
result = await handle.result()
|
||||
assert isinstance(result, str)
|
||||
|
||||
async def test_validation_failure(self, client: Client, sample_combined_input: CombinedInput):
|
||||
"""Test workflow handles validation failures correctly."""
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
|
||||
# Create mock activity functions with proper signatures
|
||||
@activity.defn(name="get_wf_env_vars")
|
||||
async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
|
||||
return EnvLookupOutput(
|
||||
show_confirm=True,
|
||||
multi_goal_mode=True
|
||||
)
|
||||
|
||||
@activity.defn(name="agent_validatePrompt")
|
||||
async def mock_agent_validatePrompt(validation_input: ValidationInput) -> ValidationResult:
|
||||
return ValidationResult(
|
||||
validationResult=False,
|
||||
validationFailedReason={
|
||||
"next": "question",
|
||||
"response": "Your request doesn't make sense in this context"
|
||||
}
|
||||
)
|
||||
|
||||
async with Worker(
|
||||
client,
|
||||
task_queue=task_queue_name,
|
||||
workflows=[AgentGoalWorkflow],
|
||||
activities=[
|
||||
mock_get_wf_env_vars,
|
||||
mock_agent_validatePrompt
|
||||
],
|
||||
):
|
||||
handle = await client.start_workflow(
|
||||
AgentGoalWorkflow.run,
|
||||
sample_combined_input,
|
||||
id=str(uuid.uuid4()),
|
||||
task_queue=task_queue_name,
|
||||
)
|
||||
|
||||
# Send invalid prompt
|
||||
await handle.signal(AgentGoalWorkflow.user_prompt, "Invalid nonsensical prompt")
|
||||
|
||||
# Give workflow time to process the prompt
|
||||
import asyncio
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# End workflow to check conversation
|
||||
await handle.signal(AgentGoalWorkflow.end_chat)
|
||||
result = await handle.result()
|
||||
|
||||
# Verify validation failure message was added
|
||||
import json
|
||||
try:
|
||||
conversation_history = json.loads(result.replace("'", '"'))
|
||||
except:
|
||||
# Fallback to eval if json fails
|
||||
conversation_history = eval(result)
|
||||
messages = conversation_history["messages"]
|
||||
|
||||
# Should have validation failure response
|
||||
agent_messages = [msg for msg in messages if msg["actor"] == "agent"]
|
||||
assert len(agent_messages) > 0
|
||||
assert any("doesn't make sense" in str(msg["response"]) for msg in agent_messages)
|
||||
|
||||
async def test_conversation_summary_initialization(self, client: Client, sample_agent_goal):
|
||||
"""Test workflow initializes with conversation summary."""
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
|
||||
# Create input with conversation summary
|
||||
from collections import deque
|
||||
tool_params = AgentGoalWorkflowParams(
|
||||
conversation_summary="Previous conversation summary",
|
||||
prompt_queue=deque()
|
||||
)
|
||||
combined_input = CombinedInput(
|
||||
agent_goal=sample_agent_goal,
|
||||
tool_params=tool_params
|
||||
)
|
||||
|
||||
# Create mock activity functions with proper signatures
|
||||
@activity.defn(name="get_wf_env_vars")
|
||||
async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
|
||||
return EnvLookupOutput(
|
||||
show_confirm=True,
|
||||
multi_goal_mode=True
|
||||
)
|
||||
|
||||
async with Worker(
|
||||
client,
|
||||
task_queue=task_queue_name,
|
||||
workflows=[AgentGoalWorkflow],
|
||||
activities=[mock_get_wf_env_vars],
|
||||
):
|
||||
handle = await client.start_workflow(
|
||||
AgentGoalWorkflow.run,
|
||||
combined_input,
|
||||
id=str(uuid.uuid4()),
|
||||
task_queue=task_queue_name,
|
||||
)
|
||||
|
||||
# Give workflow time to initialize
|
||||
import asyncio
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Query conversation summary
|
||||
summary = await handle.query(AgentGoalWorkflow.get_summary_from_history)
|
||||
assert summary == "Previous conversation summary"
|
||||
|
||||
# Query conversation history - should include summary message
|
||||
conversation_history = await handle.query(AgentGoalWorkflow.get_conversation_history)
|
||||
messages = conversation_history["messages"]
|
||||
|
||||
# Should have conversation_summary message
|
||||
summary_messages = [msg for msg in messages if msg["actor"] == "conversation_summary"]
|
||||
assert len(summary_messages) == 1
|
||||
assert summary_messages[0]["response"] == "Previous conversation summary"
|
||||
|
||||
# End workflow
|
||||
await handle.signal(AgentGoalWorkflow.end_chat)
|
||||
await handle.result()
|
||||
|
||||
async def test_workflow_queries(self, client: Client, sample_combined_input: CombinedInput):
|
||||
"""Test all workflow query methods."""
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
|
||||
# Create mock activity functions with proper signatures
|
||||
@activity.defn(name="get_wf_env_vars")
|
||||
async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
|
||||
return EnvLookupOutput(
|
||||
show_confirm=True,
|
||||
multi_goal_mode=True
|
||||
)
|
||||
|
||||
async with Worker(
|
||||
client,
|
||||
task_queue=task_queue_name,
|
||||
workflows=[AgentGoalWorkflow],
|
||||
activities=[mock_get_wf_env_vars],
|
||||
):
|
||||
handle = await client.start_workflow(
|
||||
AgentGoalWorkflow.run,
|
||||
sample_combined_input,
|
||||
id=str(uuid.uuid4()),
|
||||
task_queue=task_queue_name,
|
||||
)
|
||||
|
||||
# Give workflow time to initialize
|
||||
import asyncio
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Test get_conversation_history query
|
||||
conversation_history = await handle.query(AgentGoalWorkflow.get_conversation_history)
|
||||
assert isinstance(conversation_history, dict)
|
||||
assert "messages" in conversation_history
|
||||
|
||||
# Test get_agent_goal query
|
||||
agent_goal = await handle.query(AgentGoalWorkflow.get_agent_goal)
|
||||
assert agent_goal.id == sample_combined_input.agent_goal.id
|
||||
|
||||
# Test get_summary_from_history query
|
||||
summary = await handle.query(AgentGoalWorkflow.get_summary_from_history)
|
||||
# Summary might be None if not set, so check for that
|
||||
if sample_combined_input.tool_params.conversation_summary:
|
||||
assert summary == sample_combined_input.tool_params.conversation_summary
|
||||
else:
|
||||
assert summary is None
|
||||
|
||||
# Test get_latest_tool_data query (should be None initially)
|
||||
tool_data = await handle.query(AgentGoalWorkflow.get_latest_tool_data)
|
||||
assert tool_data is None
|
||||
|
||||
# End workflow
|
||||
await handle.signal(AgentGoalWorkflow.end_chat)
|
||||
await handle.result()
|
||||
|
||||
async def test_enable_disable_debugging_confirm_signals(self, client: Client, sample_combined_input: CombinedInput):
|
||||
"""Test debugging confirm enable/disable signals."""
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
|
||||
# Create mock activity functions with proper signatures
|
||||
@activity.defn(name="get_wf_env_vars")
|
||||
async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
|
||||
return EnvLookupOutput(
|
||||
show_confirm=True,
|
||||
multi_goal_mode=True
|
||||
)
|
||||
|
||||
async with Worker(
|
||||
client,
|
||||
task_queue=task_queue_name,
|
||||
workflows=[AgentGoalWorkflow],
|
||||
activities=[mock_get_wf_env_vars],
|
||||
):
|
||||
handle = await client.start_workflow(
|
||||
AgentGoalWorkflow.run,
|
||||
sample_combined_input,
|
||||
id=str(uuid.uuid4()),
|
||||
task_queue=task_queue_name,
|
||||
)
|
||||
|
||||
# Test enable debugging confirm signal
|
||||
await handle.signal(AgentGoalWorkflow.enable_debugging_confirm)
|
||||
|
||||
# Test disable debugging confirm signal
|
||||
await handle.signal(AgentGoalWorkflow.disable_debugging_confirm)
|
||||
|
||||
# End workflow
|
||||
await handle.signal(AgentGoalWorkflow.end_chat)
|
||||
result = await handle.result()
|
||||
assert isinstance(result, str)
|
||||
|
||||
async def test_workflow_with_empty_prompt_queue(self, client: Client, sample_agent_goal):
|
||||
"""Test workflow behavior with empty prompt queue."""
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
|
||||
# Create input with empty prompt queue
|
||||
from collections import deque
|
||||
tool_params = AgentGoalWorkflowParams(
|
||||
conversation_summary=None,
|
||||
prompt_queue=deque()
|
||||
)
|
||||
combined_input = CombinedInput(
|
||||
agent_goal=sample_agent_goal,
|
||||
tool_params=tool_params
|
||||
)
|
||||
|
||||
# Create mock activity functions with proper signatures
|
||||
@activity.defn(name="get_wf_env_vars")
|
||||
async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
|
||||
return EnvLookupOutput(
|
||||
show_confirm=True,
|
||||
multi_goal_mode=True
|
||||
)
|
||||
|
||||
async with Worker(
|
||||
client,
|
||||
task_queue=task_queue_name,
|
||||
workflows=[AgentGoalWorkflow],
|
||||
activities=[mock_get_wf_env_vars],
|
||||
):
|
||||
handle = await client.start_workflow(
|
||||
AgentGoalWorkflow.run,
|
||||
combined_input,
|
||||
id=str(uuid.uuid4()),
|
||||
task_queue=task_queue_name,
|
||||
)
|
||||
|
||||
# Give workflow time to initialize
|
||||
import asyncio
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Query initial state
|
||||
conversation_history = await handle.query(AgentGoalWorkflow.get_conversation_history)
|
||||
assert isinstance(conversation_history, dict)
|
||||
assert "messages" in conversation_history
|
||||
|
||||
# Should have no messages initially (empty prompt queue, no summary)
|
||||
messages = conversation_history["messages"]
|
||||
assert len(messages) == 0
|
||||
|
||||
# End workflow
|
||||
await handle.signal(AgentGoalWorkflow.end_chat)
|
||||
result = await handle.result()
|
||||
assert isinstance(result, str)
|
||||
|
||||
async def test_multiple_user_prompts(self, client: Client, sample_combined_input: CombinedInput):
|
||||
"""Test workflow handling multiple user prompts in sequence."""
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
|
||||
# Create mock activity functions with proper signatures
|
||||
@activity.defn(name="get_wf_env_vars")
|
||||
async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
|
||||
return EnvLookupOutput(
|
||||
show_confirm=True,
|
||||
multi_goal_mode=True
|
||||
)
|
||||
|
||||
@activity.defn(name="agent_validatePrompt")
|
||||
async def mock_agent_validatePrompt(validation_input: ValidationInput) -> ValidationResult:
|
||||
return ValidationResult(
|
||||
validationResult=True,
|
||||
validationFailedReason={}
|
||||
)
|
||||
|
||||
@activity.defn(name="agent_toolPlanner")
|
||||
async def mock_agent_toolPlanner(input: ToolPromptInput) -> dict:
|
||||
# Keep workflow running for multiple prompts
|
||||
return {
|
||||
"next": "question",
|
||||
"response": f"Processed: {input.prompt}"
|
||||
}
|
||||
|
||||
async with Worker(
|
||||
client,
|
||||
task_queue=task_queue_name,
|
||||
workflows=[AgentGoalWorkflow],
|
||||
activities=[
|
||||
mock_get_wf_env_vars,
|
||||
mock_agent_validatePrompt,
|
||||
mock_agent_toolPlanner
|
||||
],
|
||||
):
|
||||
handle = await client.start_workflow(
|
||||
AgentGoalWorkflow.run,
|
||||
sample_combined_input,
|
||||
id=str(uuid.uuid4()),
|
||||
task_queue=task_queue_name,
|
||||
)
|
||||
|
||||
# Send multiple prompts
|
||||
await handle.signal(AgentGoalWorkflow.user_prompt, "First message")
|
||||
import asyncio
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
await handle.signal(AgentGoalWorkflow.user_prompt, "Second message")
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
await handle.signal(AgentGoalWorkflow.user_prompt, "Third message")
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# End workflow
|
||||
await handle.signal(AgentGoalWorkflow.end_chat)
|
||||
result = await handle.result()
|
||||
assert isinstance(result, str)
|
||||
|
||||
# Parse result and verify multiple messages
|
||||
import json
|
||||
try:
|
||||
conversation_history = json.loads(result.replace("'", '"'))
|
||||
except:
|
||||
conversation_history = eval(result)
|
||||
messages = conversation_history["messages"]
|
||||
|
||||
# Should have at least one user message (timing dependent)
|
||||
user_messages = [msg for msg in messages if msg["actor"] == "user"]
|
||||
assert len(user_messages) >= 1
|
||||
|
||||
# Verify at least the first message was processed
|
||||
message_texts = [str(msg["response"]) for msg in user_messages]
|
||||
assert any("First message" in text for text in message_texts)
|
||||
466
tests/test_tool_activities.py
Normal file
466
tests/test_tool_activities.py
Normal file
@@ -0,0 +1,466 @@
|
||||
import os
|
||||
import uuid
|
||||
import json
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
import pytest
|
||||
from temporalio.client import Client
|
||||
from temporalio.worker import Worker
|
||||
from temporalio.testing import ActivityEnvironment
|
||||
|
||||
from activities.tool_activities import ToolActivities, dynamic_tool_activity
|
||||
from models.data_types import (
|
||||
ValidationInput,
|
||||
ValidationResult,
|
||||
ToolPromptInput,
|
||||
EnvLookupInput,
|
||||
EnvLookupOutput
|
||||
)
|
||||
|
||||
|
||||
class TestToolActivities:
|
||||
"""Test cases for ToolActivities."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test environment for each test."""
|
||||
self.tool_activities = ToolActivities()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_validatePrompt_valid_prompt(self, sample_agent_goal, sample_conversation_history):
|
||||
"""Test agent_validatePrompt with a valid prompt."""
|
||||
validation_input = ValidationInput(
|
||||
prompt="I need help with the test tool",
|
||||
conversation_history=sample_conversation_history,
|
||||
agent_goal=sample_agent_goal
|
||||
)
|
||||
|
||||
# Mock the agent_toolPlanner to return a valid response
|
||||
mock_response = {
|
||||
"validationResult": True,
|
||||
"validationFailedReason": {}
|
||||
}
|
||||
|
||||
with patch.object(self.tool_activities, 'agent_toolPlanner', new_callable=AsyncMock) as mock_planner:
|
||||
mock_planner.return_value = mock_response
|
||||
|
||||
activity_env = ActivityEnvironment()
|
||||
result = await activity_env.run(
|
||||
self.tool_activities.agent_validatePrompt,
|
||||
validation_input
|
||||
)
|
||||
|
||||
assert isinstance(result, ValidationResult)
|
||||
assert result.validationResult is True
|
||||
assert result.validationFailedReason == {}
|
||||
|
||||
# Verify the mock was called with correct parameters
|
||||
mock_planner.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_validatePrompt_invalid_prompt(self, sample_agent_goal, sample_conversation_history):
|
||||
"""Test agent_validatePrompt with an invalid prompt."""
|
||||
validation_input = ValidationInput(
|
||||
prompt="asdfghjkl nonsense",
|
||||
conversation_history=sample_conversation_history,
|
||||
agent_goal=sample_agent_goal
|
||||
)
|
||||
|
||||
# Mock the agent_toolPlanner to return an invalid response
|
||||
mock_response = {
|
||||
"validationResult": False,
|
||||
"validationFailedReason": {
|
||||
"next": "question",
|
||||
"response": "Your request doesn't make sense in this context"
|
||||
}
|
||||
}
|
||||
|
||||
with patch.object(self.tool_activities, 'agent_toolPlanner', new_callable=AsyncMock) as mock_planner:
|
||||
mock_planner.return_value = mock_response
|
||||
|
||||
activity_env = ActivityEnvironment()
|
||||
result = await activity_env.run(
|
||||
self.tool_activities.agent_validatePrompt,
|
||||
validation_input
|
||||
)
|
||||
|
||||
assert isinstance(result, ValidationResult)
|
||||
assert result.validationResult is False
|
||||
assert "doesn't make sense" in str(result.validationFailedReason)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_toolPlanner_success(self):
|
||||
"""Test agent_toolPlanner with successful LLM response."""
|
||||
prompt_input = ToolPromptInput(
|
||||
prompt="Test prompt",
|
||||
context_instructions="Test context instructions"
|
||||
)
|
||||
|
||||
# Mock the completion function
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = '{"next": "confirm", "tool": "TestTool", "response": "Test response"}'
|
||||
|
||||
with patch('activities.tool_activities.completion') as mock_completion:
|
||||
mock_completion.return_value = mock_response
|
||||
|
||||
activity_env = ActivityEnvironment()
|
||||
result = await activity_env.run(
|
||||
self.tool_activities.agent_toolPlanner,
|
||||
prompt_input
|
||||
)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["next"] == "confirm"
|
||||
assert result["tool"] == "TestTool"
|
||||
assert result["response"] == "Test response"
|
||||
|
||||
# Verify completion was called with correct parameters
|
||||
mock_completion.assert_called_once()
|
||||
call_args = mock_completion.call_args[1]
|
||||
assert call_args["model"] == self.tool_activities.llm_model
|
||||
assert len(call_args["messages"]) == 2
|
||||
assert call_args["messages"][0]["role"] == "system"
|
||||
assert call_args["messages"][1]["role"] == "user"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_toolPlanner_with_custom_base_url(self):
|
||||
"""Test agent_toolPlanner with custom base URL configuration."""
|
||||
# Set up tool activities with custom base URL
|
||||
with patch.dict(os.environ, {'LLM_BASE_URL': 'https://custom.endpoint.com'}):
|
||||
tool_activities = ToolActivities()
|
||||
|
||||
prompt_input = ToolPromptInput(
|
||||
prompt="Test prompt",
|
||||
context_instructions="Test context instructions"
|
||||
)
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = '{"next": "done", "response": "Test"}'
|
||||
|
||||
with patch('activities.tool_activities.completion') as mock_completion:
|
||||
mock_completion.return_value = mock_response
|
||||
|
||||
activity_env = ActivityEnvironment()
|
||||
await activity_env.run(
|
||||
tool_activities.agent_toolPlanner,
|
||||
prompt_input
|
||||
)
|
||||
|
||||
# Verify base_url was included in the call
|
||||
call_args = mock_completion.call_args[1]
|
||||
assert "base_url" in call_args
|
||||
assert call_args["base_url"] == "https://custom.endpoint.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_toolPlanner_json_parsing_error(self):
|
||||
"""Test agent_toolPlanner handles JSON parsing errors."""
|
||||
prompt_input = ToolPromptInput(
|
||||
prompt="Test prompt",
|
||||
context_instructions="Test context instructions"
|
||||
)
|
||||
|
||||
# Mock the completion function to return invalid JSON
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = 'Invalid JSON response'
|
||||
|
||||
with patch('activities.tool_activities.completion') as mock_completion:
|
||||
mock_completion.return_value = mock_response
|
||||
|
||||
activity_env = ActivityEnvironment()
|
||||
with pytest.raises(Exception): # Should raise JSON parsing error
|
||||
await activity_env.run(
|
||||
self.tool_activities.agent_toolPlanner,
|
||||
prompt_input
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_wf_env_vars_default_values(self):
|
||||
"""Test get_wf_env_vars with default values."""
|
||||
env_input = EnvLookupInput(
|
||||
show_confirm_env_var_name="SHOW_CONFIRM",
|
||||
show_confirm_default=True
|
||||
)
|
||||
|
||||
# Clear environment variables
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
activity_env = ActivityEnvironment()
|
||||
result = await activity_env.run(
|
||||
self.tool_activities.get_wf_env_vars,
|
||||
env_input
|
||||
)
|
||||
|
||||
assert isinstance(result, EnvLookupOutput)
|
||||
assert result.show_confirm is True # default value
|
||||
assert result.multi_goal_mode is True # default value
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_wf_env_vars_custom_values(self):
|
||||
"""Test get_wf_env_vars with custom environment values."""
|
||||
env_input = EnvLookupInput(
|
||||
show_confirm_env_var_name="SHOW_CONFIRM",
|
||||
show_confirm_default=True
|
||||
)
|
||||
|
||||
# Set environment variables
|
||||
with patch.dict(os.environ, {
|
||||
'SHOW_CONFIRM': 'false',
|
||||
'AGENT_GOAL': 'specific_goal'
|
||||
}):
|
||||
activity_env = ActivityEnvironment()
|
||||
result = await activity_env.run(
|
||||
self.tool_activities.get_wf_env_vars,
|
||||
env_input
|
||||
)
|
||||
|
||||
assert isinstance(result, EnvLookupOutput)
|
||||
assert result.show_confirm is False # from env var
|
||||
assert result.multi_goal_mode is False # from env var
|
||||
|
||||
def test_sanitize_json_response(self):
|
||||
"""Test JSON response sanitization."""
|
||||
# Test with markdown code blocks
|
||||
response_with_markdown = "```json\n{\"test\": \"value\"}\n```"
|
||||
sanitized = self.tool_activities.sanitize_json_response(response_with_markdown)
|
||||
assert sanitized == '{"test": "value"}'
|
||||
|
||||
# Test with extra whitespace
|
||||
response_with_whitespace = " \n{\"test\": \"value\"} \n"
|
||||
sanitized = self.tool_activities.sanitize_json_response(response_with_whitespace)
|
||||
assert sanitized == '{"test": "value"}'
|
||||
|
||||
def test_parse_json_response_success(self):
|
||||
"""Test successful JSON parsing."""
|
||||
json_string = '{"next": "confirm", "tool": "TestTool"}'
|
||||
result = self.tool_activities.parse_json_response(json_string)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["next"] == "confirm"
|
||||
assert result["tool"] == "TestTool"
|
||||
|
||||
def test_parse_json_response_failure(self):
|
||||
"""Test JSON parsing with invalid JSON."""
|
||||
invalid_json = "Not valid JSON"
|
||||
|
||||
with pytest.raises(Exception): # Should raise JSON parsing error
|
||||
self.tool_activities.parse_json_response(invalid_json)
|
||||
|
||||
|
||||
class TestDynamicToolActivity:
|
||||
"""Test cases for dynamic_tool_activity."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dynamic_tool_activity_sync_handler(self):
|
||||
"""Test dynamic tool activity with synchronous handler."""
|
||||
# Mock the activity info and payload converter
|
||||
mock_info = MagicMock()
|
||||
mock_info.activity_type = "TestTool"
|
||||
|
||||
mock_payload_converter = MagicMock()
|
||||
mock_payload = MagicMock()
|
||||
mock_payload.payload = b'{"test_arg": "test_value"}'
|
||||
mock_payload_converter.from_payload.return_value = {"test_arg": "test_value"}
|
||||
|
||||
# Mock the handler function
|
||||
def mock_handler(args):
|
||||
return {"result": f"Handled {args['test_arg']}"}
|
||||
|
||||
with patch('temporalio.activity.info', return_value=mock_info), \
|
||||
patch('temporalio.activity.payload_converter', return_value=mock_payload_converter), \
|
||||
patch('tools.get_handler', return_value=mock_handler):
|
||||
|
||||
activity_env = ActivityEnvironment()
|
||||
result = await activity_env.run(
|
||||
dynamic_tool_activity,
|
||||
[mock_payload]
|
||||
)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["result"] == "Handled test_value"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dynamic_tool_activity_async_handler(self):
|
||||
"""Test dynamic tool activity with asynchronous handler."""
|
||||
# Mock the activity info and payload converter
|
||||
mock_info = MagicMock()
|
||||
mock_info.activity_type = "AsyncTestTool"
|
||||
|
||||
mock_payload_converter = MagicMock()
|
||||
mock_payload = MagicMock()
|
||||
mock_payload.payload = b'{"test_arg": "async_test"}'
|
||||
mock_payload_converter.from_payload.return_value = {"test_arg": "async_test"}
|
||||
|
||||
# Mock the async handler function
|
||||
async def mock_async_handler(args):
|
||||
return {"async_result": f"Async handled {args['test_arg']}"}
|
||||
|
||||
with patch('temporalio.activity.info', return_value=mock_info), \
|
||||
patch('temporalio.activity.payload_converter', return_value=mock_payload_converter), \
|
||||
patch('tools.get_handler', return_value=mock_async_handler):
|
||||
|
||||
activity_env = ActivityEnvironment()
|
||||
result = await activity_env.run(
|
||||
dynamic_tool_activity,
|
||||
[mock_payload]
|
||||
)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["async_result"] == "Async handled async_test"
|
||||
|
||||
|
||||
class TestToolActivitiesIntegration:
|
||||
"""Integration tests for ToolActivities in a real Temporal environment."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_activities_in_worker(self, client: Client):
|
||||
"""Test activities can be registered and executed in a worker."""
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
tool_activities = ToolActivities()
|
||||
|
||||
# Test get_wf_env_vars activity using ActivityEnvironment
|
||||
env_input = EnvLookupInput(
|
||||
show_confirm_env_var_name="TEST_CONFIRM",
|
||||
show_confirm_default=False
|
||||
)
|
||||
|
||||
activity_env = ActivityEnvironment()
|
||||
result = await activity_env.run(
|
||||
tool_activities.get_wf_env_vars,
|
||||
env_input
|
||||
)
|
||||
|
||||
assert isinstance(result, EnvLookupOutput)
|
||||
assert isinstance(result.show_confirm, bool)
|
||||
assert isinstance(result.multi_goal_mode, bool)
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Test edge cases and error handling."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test environment for each test."""
|
||||
self.tool_activities = ToolActivities()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_validatePrompt_with_empty_conversation_history(self, sample_agent_goal):
|
||||
"""Test validation with empty conversation history."""
|
||||
validation_input = ValidationInput(
|
||||
prompt="Test prompt",
|
||||
conversation_history={"messages": []},
|
||||
agent_goal=sample_agent_goal
|
||||
)
|
||||
|
||||
mock_response = {
|
||||
"validationResult": True,
|
||||
"validationFailedReason": {}
|
||||
}
|
||||
|
||||
with patch.object(self.tool_activities, 'agent_toolPlanner', new_callable=AsyncMock) as mock_planner:
|
||||
mock_planner.return_value = mock_response
|
||||
|
||||
activity_env = ActivityEnvironment()
|
||||
result = await activity_env.run(
|
||||
self.tool_activities.agent_validatePrompt,
|
||||
validation_input
|
||||
)
|
||||
|
||||
assert isinstance(result, ValidationResult)
|
||||
assert result.validationResult == True
|
||||
assert result.validationFailedReason == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_toolPlanner_with_long_prompt(self):
|
||||
"""Test toolPlanner with very long prompt."""
|
||||
long_prompt = "This is a very long prompt " * 100
|
||||
tool_prompt_input = ToolPromptInput(
|
||||
prompt=long_prompt,
|
||||
context_instructions="Test context instructions"
|
||||
)
|
||||
|
||||
# Mock the completion response
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = '{"next": "done", "response": "Processed long prompt"}'
|
||||
|
||||
with patch('activities.tool_activities.completion', return_value=mock_response):
|
||||
activity_env = ActivityEnvironment()
|
||||
result = await activity_env.run(
|
||||
self.tool_activities.agent_toolPlanner,
|
||||
tool_prompt_input
|
||||
)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["next"] == "done"
|
||||
assert "Processed long prompt" in result["response"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sanitize_json_with_various_formats(self):
|
||||
"""Test JSON sanitization with various input formats."""
|
||||
# Test markdown code blocks
|
||||
markdown_json = "```json\n{\"test\": \"value\"}\n```"
|
||||
result = self.tool_activities.sanitize_json_response(markdown_json)
|
||||
assert result == '{"test": "value"}'
|
||||
|
||||
# Test with extra whitespace
|
||||
whitespace_json = " \n {\"test\": \"value\"} \n "
|
||||
result = self.tool_activities.sanitize_json_response(whitespace_json)
|
||||
assert result == '{"test": "value"}'
|
||||
|
||||
# Test already clean JSON
|
||||
clean_json = '{"test": "value"}'
|
||||
result = self.tool_activities.sanitize_json_response(clean_json)
|
||||
assert result == '{"test": "value"}'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_json_response_with_invalid_json(self):
|
||||
"""Test JSON parsing with invalid JSON."""
|
||||
with pytest.raises(json.JSONDecodeError):
|
||||
self.tool_activities.parse_json_response("Invalid JSON {test: value")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_wf_env_vars_with_various_env_values(self):
|
||||
"""Test environment variable parsing with different values."""
|
||||
# Test with "true" string
|
||||
with patch.dict(os.environ, {"TEST_CONFIRM": "true"}):
|
||||
env_input = EnvLookupInput(
|
||||
show_confirm_env_var_name="TEST_CONFIRM",
|
||||
show_confirm_default=False
|
||||
)
|
||||
|
||||
activity_env = ActivityEnvironment()
|
||||
result = await activity_env.run(
|
||||
self.tool_activities.get_wf_env_vars,
|
||||
env_input
|
||||
)
|
||||
|
||||
assert result.show_confirm == True
|
||||
|
||||
# Test with "false" string
|
||||
with patch.dict(os.environ, {"TEST_CONFIRM": "false"}):
|
||||
env_input = EnvLookupInput(
|
||||
show_confirm_env_var_name="TEST_CONFIRM",
|
||||
show_confirm_default=True
|
||||
)
|
||||
|
||||
activity_env = ActivityEnvironment()
|
||||
result = await activity_env.run(
|
||||
self.tool_activities.get_wf_env_vars,
|
||||
env_input
|
||||
)
|
||||
|
||||
assert result.show_confirm == False
|
||||
|
||||
# Test with missing env var (should use default)
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
env_input = EnvLookupInput(
|
||||
show_confirm_env_var_name="MISSING_VAR",
|
||||
show_confirm_default=True
|
||||
)
|
||||
|
||||
activity_env = ActivityEnvironment()
|
||||
result = await activity_env.run(
|
||||
self.tool_activities.get_wf_env_vars,
|
||||
env_input
|
||||
)
|
||||
|
||||
assert result.show_confirm == True
|
||||
@@ -1,9 +1,19 @@
|
||||
import uuid
|
||||
from temporalio.client import Client, WorkflowExecutionStatus
|
||||
from temporalio.worker import Worker
|
||||
from temporalio import activity
|
||||
import concurrent.futures
|
||||
from temporalio.testing import WorkflowEnvironment
|
||||
from api.main import get_initial_agent_goal
|
||||
from models.data_types import AgentGoalWorkflowParams, CombinedInput
|
||||
from models.data_types import (
|
||||
AgentGoalWorkflowParams,
|
||||
CombinedInput,
|
||||
ValidationResult,
|
||||
ValidationInput,
|
||||
EnvLookupOutput,
|
||||
EnvLookupInput,
|
||||
ToolPromptInput
|
||||
)
|
||||
from workflows.agent_goal_workflow import AgentGoalWorkflow
|
||||
from activities.tool_activities import ToolActivities, dynamic_tool_activity
|
||||
from unittest.mock import patch
|
||||
@@ -31,15 +41,41 @@ async def test_flight_booking(client: Client):
|
||||
# Create the test environment
|
||||
#env = await WorkflowEnvironment.start_local()
|
||||
#client = env.client
|
||||
task_queue_name = "agent-ai-workflow"
|
||||
workflow_id = "agent-workflow"
|
||||
task_queue_name = str(uuid.uuid4())
|
||||
workflow_id = str(uuid.uuid4())
|
||||
|
||||
# Create mock activity functions with proper signatures
|
||||
@activity.defn(name="get_wf_env_vars")
|
||||
async def mock_get_wf_env_vars(input: EnvLookupInput) -> EnvLookupOutput:
|
||||
return EnvLookupOutput(
|
||||
show_confirm=True,
|
||||
multi_goal_mode=True
|
||||
)
|
||||
|
||||
@activity.defn(name="agent_validatePrompt")
|
||||
async def mock_agent_validatePrompt(validation_input: ValidationInput) -> ValidationResult:
|
||||
return ValidationResult(
|
||||
validationResult=True,
|
||||
validationFailedReason={}
|
||||
)
|
||||
|
||||
@activity.defn(name="agent_toolPlanner")
|
||||
async def mock_agent_toolPlanner(input: ToolPromptInput) -> dict:
|
||||
return {
|
||||
"next": "done",
|
||||
"response": "Test response from LLM"
|
||||
}
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=100) as activity_executor:
|
||||
worker = Worker(
|
||||
client,
|
||||
task_queue=task_queue_name,
|
||||
workflows=[AgentGoalWorkflow],
|
||||
activities=[ToolActivities.agent_validatePrompt, ToolActivities.agent_toolPlanner, ToolActivities.get_wf_env_vars, dynamic_tool_activity],
|
||||
activities=[
|
||||
mock_get_wf_env_vars,
|
||||
mock_agent_validatePrompt,
|
||||
mock_agent_toolPlanner
|
||||
],
|
||||
activity_executor=activity_executor,
|
||||
)
|
||||
|
||||
|
||||
49
todo.md
49
todo.md
@@ -1,8 +1,33 @@
|
||||
# todo list
|
||||
[x] take steve's confirm box changes https://temporaltechnologies.slack.com/archives/D062SV8KEEM/p1745251279164319 <br />
|
||||
[ ] consider adding goal categories to goal picker
|
||||
|
||||
[ ] adding fintech goals <br />
|
||||
## General Agent Enhancements
|
||||
|
||||
[ ] MCP: There is a plan to add MCP (Model Context Protocol) to the agent. This really really really needs to be done and is scheduled to be done by @steveandroulakis some time in June 2025.
|
||||
|
||||
[ ] Google's A2A is emerging as the standard way to hand off agents to other agents. We should examine implementing this soon.
|
||||
|
||||
[ ] Custom metrics/tracing is important for AI specific aspects such as number of LLM calls, number of bad LLM responses that require retrying, number of bad chat outcomes. We should add this.
|
||||
|
||||
[ ] Evals are very important in agents. We want to be able to 'judge' the agent's performance both in dev and production (AIOps). This will help us improve our agent's performance over time in a targeted fashion.
|
||||
|
||||
[ ] Dynamically switch LLMs on persistent failures: <br />
|
||||
- detect failure in the activity using failurecount <br />
|
||||
- activity switches to secondary LLM defined in .env
|
||||
- activity reports switch to workflow
|
||||
|
||||
[ ] Collapse history/summarize chat after goal finished <br />
|
||||
|
||||
[ ] Write tests<br />
|
||||
|
||||
[ ] non-retry the api key error - "Invalid API Key provided: sk_test_**J..." and "AuthenticationError" <br />
|
||||
|
||||
[ ] add visual feedback when workflow starting <br />
|
||||
|
||||
[ ] enable user to list agents at any time - like end conversation - probably with a next step<br />
|
||||
|
||||
## Ideas for more goals and tools
|
||||
|
||||
[ ] Add fintech goals <br />
|
||||
- Fraud Detection and Prevention - The AI monitors transactions across accounts, flagging suspicious activities (e.g., unusual spending patterns or login attempts) and autonomously freezing accounts or notifying customers and compliance teams.<br />
|
||||
- Personalized Financial Advice - An AI agent analyzes a customer’s financial data (e.g., income, spending habits, savings, investments) and provides tailored advice, such as budgeting tips, investment options, or debt repayment strategies.<br />
|
||||
- Portfolio Management and Rebalancing - The AI monitors a customer’s investment portfolio, rebalancing it automatically based on market trends, risk tolerance, and financial goals (e.g., shifting assets between stocks, bonds, or crypto).<br />
|
||||
@@ -12,21 +37,3 @@
|
||||
[ ] tool is maybe a new tool asking the LLM to advise
|
||||
|
||||
[ ] for demo simulate failure - add utilities/simulated failures from pipeline demo <br />
|
||||
|
||||
[ ] LLM failure->autoswitch: <br />
|
||||
- detect failure in the activity using failurecount <br />
|
||||
- activity switches to secondary LLM defined in .env
|
||||
- activity reports switch to workflow
|
||||
|
||||
[ ] for demo simulate failure - add utilities/simulated failures from pipeline demo <br />
|
||||
|
||||
[ ] expand [tests](./tests/agent_goal_workflow_test.py)<br />
|
||||
[ ] collapse history/summarize after goal finished <br />
|
||||
[ ] add aws bedrock <br />
|
||||
|
||||
[ ] ask the ai agent how it did at the end of the conversation, was it efficient? successful? insert a search attribute to document that before return <br />
|
||||
- Insight into the agent’s performance <br />
|
||||
[ ] non-retry the api key error - "Invalid API Key provided: sk_test_**J..." and "AuthenticationError" <br />
|
||||
[ ] add visual feedback when workflow starting <br />
|
||||
[ ] enable user to list agents at any time - like end conversation - probably with a next step<br />
|
||||
- with changing "'Next should only be "pick-new-goal" if all tools have been run (use the system prompt to figure that out).'" in [prompt_generators](./prompts/agent_prompt_generators.py).
|
||||
@@ -22,6 +22,12 @@ from .ecommerce.get_order import get_order
|
||||
from .ecommerce.track_package import track_package
|
||||
from .ecommerce.list_orders import list_orders
|
||||
|
||||
from .food.get_menu import get_menu
|
||||
from .food.get_menu_item_details import get_menu_item_details
|
||||
from .food.add_to_cart import add_to_cart
|
||||
from .food.place_order import place_order
|
||||
from .food.check_order_status import check_order_status
|
||||
|
||||
from .give_hint import give_hint
|
||||
from .guess_location import guess_location
|
||||
|
||||
@@ -67,6 +73,16 @@ def get_handler(tool_name: str):
|
||||
return track_package
|
||||
if tool_name == "ListOrders":
|
||||
return list_orders
|
||||
if tool_name == "GetMenu":
|
||||
return get_menu
|
||||
if tool_name == "GetMenuItemDetails":
|
||||
return get_menu_item_details
|
||||
if tool_name == "AddToCart":
|
||||
return add_to_cart
|
||||
if tool_name == "PlaceOrder":
|
||||
return place_order
|
||||
if tool_name == "CheckOrderStatus":
|
||||
return check_order_status
|
||||
if tool_name == "GiveHint":
|
||||
return give_hint
|
||||
if tool_name == "GuessLocation":
|
||||
|
||||
@@ -27,7 +27,7 @@ def ensure_customer_exists(
|
||||
def create_invoice(args: dict) -> dict:
|
||||
"""Create and finalize a Stripe invoice."""
|
||||
# If an API key exists in the env file, find or create customer
|
||||
if stripe.api_key is not None:
|
||||
if stripe.api_key is not None and stripe.api_key != "":
|
||||
customer_id = ensure_customer_exists(
|
||||
args.get("customer_id"), args.get("email", "default@example.com")
|
||||
)
|
||||
@@ -69,15 +69,3 @@ def create_invoice(args: dict) -> dict:
|
||||
"invoiceURL": "https://pay.example.com/invoice/12345",
|
||||
"reference": "INV-12345",
|
||||
}
|
||||
|
||||
def create_invoice_example(args: dict) -> dict:
|
||||
"""
|
||||
This is an example implementation of the CreateInvoice tool
|
||||
Doesn't call any external services, just returns a dummy response
|
||||
"""
|
||||
print("[CreateInvoice] Creating invoice with:", args)
|
||||
return {
|
||||
"invoiceStatus": "generated",
|
||||
"invoiceURL": "https://pay.example.com/invoice/12345",
|
||||
"reference": "INV-12345",
|
||||
}
|
||||
|
||||
122
tools/data/food_ordering_data.json
Normal file
122
tools/data/food_ordering_data.json
Normal file
@@ -0,0 +1,122 @@
|
||||
{
|
||||
"restaurants": [
|
||||
{
|
||||
"id": "rest_001",
|
||||
"name": "Tony's Pizza Palace",
|
||||
"menu": [
|
||||
{
|
||||
"id": "item_001",
|
||||
"name": "Margherita Pizza",
|
||||
"category": "Pizza",
|
||||
"price": 14.99,
|
||||
"description": "Fresh mozzarella, tomato sauce, basil",
|
||||
"available": true
|
||||
},
|
||||
{
|
||||
"id": "item_002",
|
||||
"name": "Pepperoni Pizza",
|
||||
"category": "Pizza",
|
||||
"price": 16.99,
|
||||
"description": "Classic pepperoni with mozzarella and tomato sauce",
|
||||
"available": true
|
||||
},
|
||||
{
|
||||
"id": "item_003",
|
||||
"name": "Caesar Salad",
|
||||
"category": "Salad",
|
||||
"price": 9.99,
|
||||
"description": "Romaine lettuce, parmesan, croutons, caesar dressing",
|
||||
"available": true
|
||||
},
|
||||
{
|
||||
"id": "item_004",
|
||||
"name": "Garlic Bread",
|
||||
"category": "Sides",
|
||||
"price": 5.99,
|
||||
"description": "Fresh baked bread with garlic butter",
|
||||
"available": true
|
||||
},
|
||||
{
|
||||
"id": "item_005",
|
||||
"name": "Tiramisu",
|
||||
"category": "Dessert",
|
||||
"price": 7.99,
|
||||
"description": "Classic Italian dessert with coffee and mascarpone",
|
||||
"available": true
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"carts": {
|
||||
"steve@example.com": {
|
||||
"restaurant_id": "rest_001",
|
||||
"items": []
|
||||
}
|
||||
},
|
||||
"orders": [
|
||||
{
|
||||
"id": "order_001",
|
||||
"customer_email": "john.doe@example.com",
|
||||
"restaurant_id": "rest_001",
|
||||
"items": [
|
||||
{
|
||||
"item_id": "item_001",
|
||||
"quantity": 1,
|
||||
"price": 14.99
|
||||
},
|
||||
{
|
||||
"item_id": "item_004",
|
||||
"quantity": 2,
|
||||
"price": 5.99
|
||||
}
|
||||
],
|
||||
"total": 26.97,
|
||||
"status": "delivered",
|
||||
"order_date": "2025-05-29T18:30:00Z",
|
||||
"estimated_delivery": "2025-05-29T19:15:00Z",
|
||||
"actual_delivery": "2025-05-29T19:12:00Z"
|
||||
},
|
||||
{
|
||||
"id": "order_002",
|
||||
"customer_email": "jane.smith@example.com",
|
||||
"restaurant_id": "rest_001",
|
||||
"items": [
|
||||
{
|
||||
"item_id": "item_002",
|
||||
"quantity": 1,
|
||||
"price": 16.99
|
||||
}
|
||||
],
|
||||
"total": 16.99,
|
||||
"status": "preparing",
|
||||
"order_date": "2025-05-30T12:00:00Z",
|
||||
"estimated_delivery": "2025-05-30T12:45:00Z"
|
||||
},
|
||||
{
|
||||
"id": "order_58539a70",
|
||||
"customer_email": "steve@example.com",
|
||||
"restaurant_id": "rest_001",
|
||||
"items": [
|
||||
{
|
||||
"item_id": "item_001",
|
||||
"quantity": 1,
|
||||
"price": 14.99
|
||||
},
|
||||
{
|
||||
"item_id": "item_002",
|
||||
"quantity": 1,
|
||||
"price": 16.99
|
||||
},
|
||||
{
|
||||
"item_id": "item_004",
|
||||
"quantity": 1,
|
||||
"price": 5.99
|
||||
}
|
||||
],
|
||||
"total": 37.97,
|
||||
"status": "preparing",
|
||||
"order_date": "2025-05-30T20:28:18.444162Z",
|
||||
"estimated_delivery": "2025-05-30T20:58:18.444169Z"
|
||||
}
|
||||
]
|
||||
}
|
||||
63
tools/food/add_to_cart.py
Normal file
63
tools/food/add_to_cart.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
def add_to_cart(args: dict) -> dict:
|
||||
customer_email = args.get("customer_email")
|
||||
item_id = args.get("item_id")
|
||||
quantity = int(args.get("quantity", 1))
|
||||
restaurant_id = args.get("restaurant_id", "rest_001")
|
||||
|
||||
file_path = Path(__file__).resolve().parent.parent / "data" / "food_ordering_data.json"
|
||||
if not file_path.exists():
|
||||
return {"error": "Data file not found."}
|
||||
|
||||
with open(file_path, "r") as file:
|
||||
data = json.load(file)
|
||||
|
||||
# Find the item to get its price
|
||||
item_price = None
|
||||
item_name = None
|
||||
for restaurant in data["restaurants"]:
|
||||
if restaurant["id"] == restaurant_id:
|
||||
for item in restaurant["menu"]:
|
||||
if item["id"] == item_id:
|
||||
item_price = item["price"]
|
||||
item_name = item["name"]
|
||||
break
|
||||
|
||||
if item_price is None:
|
||||
return {"error": f"Item {item_id} not found."}
|
||||
|
||||
# Initialize cart if it doesn't exist
|
||||
if customer_email not in data["carts"]:
|
||||
data["carts"][customer_email] = {
|
||||
"restaurant_id": restaurant_id,
|
||||
"items": []
|
||||
}
|
||||
|
||||
# Check if item already in cart
|
||||
cart = data["carts"][customer_email]
|
||||
existing_item = None
|
||||
for cart_item in cart["items"]:
|
||||
if cart_item["item_id"] == item_id:
|
||||
existing_item = cart_item
|
||||
break
|
||||
|
||||
if existing_item:
|
||||
existing_item["quantity"] += quantity
|
||||
else:
|
||||
cart["items"].append({
|
||||
"item_id": item_id,
|
||||
"quantity": quantity,
|
||||
"price": item_price
|
||||
})
|
||||
|
||||
# Save back to file
|
||||
with open(file_path, "w") as file:
|
||||
json.dump(data, file, indent=2)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Added {quantity} x {item_name} to cart",
|
||||
"cart": cart
|
||||
}
|
||||
28
tools/food/check_order_status.py
Normal file
28
tools/food/check_order_status.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
def check_order_status(args: dict) -> dict:
|
||||
order_id = args.get("order_id")
|
||||
|
||||
file_path = Path(__file__).resolve().parent.parent / "data" / "food_ordering_data.json"
|
||||
if not file_path.exists():
|
||||
return {"error": "Data file not found."}
|
||||
|
||||
with open(file_path, "r") as file:
|
||||
data = json.load(file)
|
||||
|
||||
orders = data["orders"]
|
||||
|
||||
for order in orders:
|
||||
if order["id"] == order_id:
|
||||
return {
|
||||
"order_id": order["id"],
|
||||
"status": order["status"],
|
||||
"order_date": order["order_date"],
|
||||
"estimated_delivery": order["estimated_delivery"],
|
||||
"actual_delivery": order.get("actual_delivery"),
|
||||
"total": order["total"],
|
||||
"items": order["items"]
|
||||
}
|
||||
|
||||
return {"error": f"Order {order_id} not found."}
|
||||
23
tools/food/get_menu.py
Normal file
23
tools/food/get_menu.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
def get_menu(args: dict) -> dict:
|
||||
restaurant_id = args.get("restaurant_id", "rest_001")
|
||||
|
||||
file_path = Path(__file__).resolve().parent.parent / "data" / "food_ordering_data.json"
|
||||
if not file_path.exists():
|
||||
return {"error": "Data file not found."}
|
||||
|
||||
with open(file_path, "r") as file:
|
||||
data = json.load(file)
|
||||
|
||||
restaurants = data["restaurants"]
|
||||
|
||||
for restaurant in restaurants:
|
||||
if restaurant["id"] == restaurant_id:
|
||||
return {
|
||||
"restaurant_name": restaurant["name"],
|
||||
"menu": restaurant["menu"]
|
||||
}
|
||||
|
||||
return {"error": f"Restaurant {restaurant_id} not found."}
|
||||
23
tools/food/get_menu_item_details.py
Normal file
23
tools/food/get_menu_item_details.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
def get_menu_item_details(args: dict) -> dict:
|
||||
item_id = args.get("item_id")
|
||||
restaurant_id = args.get("restaurant_id", "rest_001")
|
||||
|
||||
file_path = Path(__file__).resolve().parent.parent / "data" / "food_ordering_data.json"
|
||||
if not file_path.exists():
|
||||
return {"error": "Data file not found."}
|
||||
|
||||
with open(file_path, "r") as file:
|
||||
data = json.load(file)
|
||||
|
||||
restaurants = data["restaurants"]
|
||||
|
||||
for restaurant in restaurants:
|
||||
if restaurant["id"] == restaurant_id:
|
||||
for item in restaurant["menu"]:
|
||||
if item["id"] == item_id:
|
||||
return item
|
||||
|
||||
return {"error": f"Menu item {item_id} not found."}
|
||||
57
tools/food/place_order.py
Normal file
57
tools/food/place_order.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from pathlib import Path
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
def place_order(args: dict) -> dict:
|
||||
customer_email = args.get("customer_email")
|
||||
|
||||
file_path = Path(__file__).resolve().parent.parent / "data" / "food_ordering_data.json"
|
||||
if not file_path.exists():
|
||||
return {"error": "Data file not found."}
|
||||
|
||||
with open(file_path, "r") as file:
|
||||
data = json.load(file)
|
||||
|
||||
# Check if cart exists
|
||||
if customer_email not in data["carts"] or not data["carts"][customer_email]["items"]:
|
||||
return {"error": "Cart is empty. Please add items to cart first."}
|
||||
|
||||
cart = data["carts"][customer_email]
|
||||
|
||||
# Calculate total
|
||||
total = sum(item["price"] * item["quantity"] for item in cart["items"])
|
||||
|
||||
# Create order
|
||||
order_id = f"order_{str(uuid.uuid4())[:8]}"
|
||||
order_date = datetime.now().isoformat() + "Z"
|
||||
estimated_delivery = (datetime.now() + timedelta(minutes=30)).isoformat() + "Z"
|
||||
|
||||
new_order = {
|
||||
"id": order_id,
|
||||
"customer_email": customer_email,
|
||||
"restaurant_id": cart["restaurant_id"],
|
||||
"items": cart["items"],
|
||||
"total": round(total, 2),
|
||||
"status": "preparing",
|
||||
"order_date": order_date,
|
||||
"estimated_delivery": estimated_delivery
|
||||
}
|
||||
|
||||
# Add order to data
|
||||
data["orders"].append(new_order)
|
||||
|
||||
# Clear cart
|
||||
data["carts"][customer_email] = {"restaurant_id": cart["restaurant_id"], "items": []}
|
||||
|
||||
# Save back to file
|
||||
with open(file_path, "w") as file:
|
||||
json.dump(data, file, indent=2)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"order_id": order_id,
|
||||
"total": round(total, 2),
|
||||
"estimated_delivery": estimated_delivery,
|
||||
"message": "Order placed successfully!"
|
||||
}
|
||||
@@ -114,10 +114,10 @@ goal_match_train_invoice = AgentGoal(
|
||||
],
|
||||
description="The user wants to book a trip to a city in the UK around the dates of a premier league match. "
|
||||
"Help the user find a premier league match to attend, search and book trains for that match and offers to invoice them for the cost of train tickets. "
|
||||
"The user lives in London. "
|
||||
"The user lives in London. Premier league fixtures may be mocked data, so don't worry about valid season dates and teams. "
|
||||
"Gather args for these tools in order, ensuring you move the user from one tool to the next: "
|
||||
"1. SearchFixtures: Search for fixtures for a team within a specified date range. The user might ask questions about the matches dates and locations to decide on where to go. "
|
||||
"2. SearchTrains: Search for trains to the city of the match and list them for the customer to choose from "
|
||||
"2. SearchTrains: Search for trains to the city of the match. Ensure you list them for the customer to choose from "
|
||||
"3. BookTrains: Book the train tickets, used to invoice the user for the cost of the train tickets "
|
||||
"4. CreateInvoice: Invoices the user for the cost of train tickets, with total and details inferred from the conversation history ",
|
||||
starter_prompt=starter_prompt_generic,
|
||||
@@ -454,6 +454,50 @@ goal_ecomm_list_orders = AgentGoal(
|
||||
),
|
||||
)
|
||||
|
||||
# ----- Food Ordering Goal -----
|
||||
goal_food_ordering = AgentGoal(
|
||||
id="goal_food_ordering",
|
||||
category_tag="food",
|
||||
agent_name="Food Ordering Assistant",
|
||||
agent_friendly_description="Order food from Tony's Pizza Palace. Browse menu, add items to cart, and place orders.",
|
||||
tools=[
|
||||
tool_registry.food_get_menu_tool,
|
||||
tool_registry.food_get_menu_item_details_tool,
|
||||
tool_registry.food_add_to_cart_tool,
|
||||
tool_registry.food_place_order_tool,
|
||||
tool_registry.food_check_order_status_tool,
|
||||
],
|
||||
description="The user wants to order food from Tony's Pizza Palace. Help them browse the menu, learn about menu items, add items to their cart, and place an order. To assist with that goal, help the user gather args for these tools in order: "
|
||||
"1. GetMenu: Show the restaurant menu. This tool is optional if the user already knows what they want. "
|
||||
"2. GetMenuItemDetails: Get details about specific menu items. This tool is optional and can be used multiple times. "
|
||||
"3. AddToCart: Add menu items to the customer's cart. This tool can be used multiple times to add different items. "
|
||||
"4. PlaceOrder: Place the order for items in the cart. "
|
||||
"5. CheckOrderStatus: Check the status of a placed order. This tool is optional and used after placing an order.",
|
||||
starter_prompt=starter_prompt_generic,
|
||||
example_conversation_history="\n ".join(
|
||||
[
|
||||
"user: I'd like to order some food",
|
||||
"agent: Great! I can help you order food from Tony's Pizza Palace. Would you like to see the menu first?",
|
||||
"user: Yes, please show me the menu",
|
||||
"user_confirmed_tool_run: <user clicks confirm on GetMenu tool>",
|
||||
"tool_result: {'restaurant_name': 'Tony\\'s Pizza Palace', 'menu': [{'id': 'item_001', 'name': 'Margherita Pizza', 'category': 'Pizza', 'price': 14.99, 'description': 'Fresh mozzarella, tomato sauce, basil', 'available': True}, {'id': 'item_002', 'name': 'Pepperoni Pizza', 'category': 'Pizza', 'price': 16.99, 'description': 'Classic pepperoni with mozzarella and tomato sauce', 'available': True}]}",
|
||||
"agent: Here's the menu from Tony's Pizza Palace: \n"
|
||||
"Pizza: \n"
|
||||
"- Margherita Pizza ($14.99): Fresh mozzarella, tomato sauce, basil \n"
|
||||
"- Pepperoni Pizza ($16.99): Classic pepperoni with mozzarella and tomato sauce \n"
|
||||
"What would you like to add to your cart? I'll need your email address to create your order.",
|
||||
"user: I'd like a Margherita Pizza. My email is john.doe@example.com",
|
||||
"user_confirmed_tool_run: <user clicks confirm on AddToCart tool>",
|
||||
"tool_result: {'status': 'success', 'message': 'Added 1 x Margherita Pizza to cart', 'cart': {'restaurant_id': 'rest_001', 'items': [{'item_id': 'item_001', 'quantity': 1, 'price': 14.99}]}}",
|
||||
"agent: Perfect! I've added 1 Margherita Pizza to your cart. Would you like to add anything else or are you ready to place your order?",
|
||||
"user: I'm ready to place the order",
|
||||
"user_confirmed_tool_run: <user clicks confirm on PlaceOrder tool>",
|
||||
"tool_result: {'status': 'success', 'order_id': 'order_12345678', 'total': 14.99, 'estimated_delivery': '2025-05-30T13:30:00Z', 'message': 'Order placed successfully!'}",
|
||||
"agent: Order placed successfully! Your order ID is order_12345678 and the total is $14.99. Your food should be delivered by 1:30 PM today. You can check your order status anytime using the order ID.",
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
# Add the goals to a list for more generic processing, like listing available agents
|
||||
goal_list: List[AgentGoal] = []
|
||||
goal_list.append(goal_choose_agent_type)
|
||||
@@ -468,6 +512,7 @@ goal_list.append(goal_fin_move_money)
|
||||
goal_list.append(goal_fin_loan_application)
|
||||
goal_list.append(goal_ecomm_list_orders)
|
||||
goal_list.append(goal_ecomm_order_status)
|
||||
goal_list.append(goal_food_ordering)
|
||||
|
||||
|
||||
# for multi-goal, just set list agents as the last tool
|
||||
@@ -489,6 +534,6 @@ if multi_goal_mode:
|
||||
if tool.name == "ListAgents":
|
||||
list_agents_found = True
|
||||
continue
|
||||
if list_agents_found == False:
|
||||
if list_agents_found is False:
|
||||
goal.tools.append(tool_registry.list_agents_tool)
|
||||
continue
|
||||
|
||||
@@ -1,64 +1,263 @@
|
||||
import os
|
||||
import requests
|
||||
from datetime import datetime, timedelta
|
||||
import random
|
||||
from datetime import datetime, timedelta, date
|
||||
from dotenv import load_dotenv
|
||||
|
||||
PREMIER_LEAGUE_CLUBS_DATA = [
|
||||
{"name": "Arsenal FC", "stadium": "Emirates Stadium"},
|
||||
{"name": "Aston Villa FC", "stadium": "Villa Park"},
|
||||
{"name": "AFC Bournemouth", "stadium": "Vitality Stadium"},
|
||||
{"name": "Brentford FC", "stadium": "Gtech Community Stadium"},
|
||||
{"name": "Brighton & Hove Albion FC", "stadium": "American Express Stadium"},
|
||||
{"name": "Chelsea FC", "stadium": "Stamford Bridge"},
|
||||
{"name": "Crystal Palace FC", "stadium": "Selhurst Park"},
|
||||
{"name": "Everton FC", "stadium": "Goodison Park"},
|
||||
{"name": "Fulham FC", "stadium": "Craven Cottage"},
|
||||
{"name": "Ipswich Town FC", "stadium": "Portman Road"},
|
||||
{"name": "Leicester City FC", "stadium": "King Power Stadium"},
|
||||
{"name": "Liverpool FC", "stadium": "Anfield"},
|
||||
{"name": "Manchester City FC", "stadium": "Etihad Stadium"},
|
||||
{"name": "Manchester United FC", "stadium": "Old Trafford"},
|
||||
{"name": "Newcastle United FC", "stadium": "St James' Park"},
|
||||
{"name": "Nottingham Forest FC", "stadium": "City Ground"},
|
||||
{"name": "Southampton FC", "stadium": "St Mary's Stadium"},
|
||||
{"name": "Tottenham Hotspur FC", "stadium": "Tottenham Hotspur Stadium"},
|
||||
{"name": "West Ham United FC", "stadium": "London Stadium"},
|
||||
{"name": "Wolverhampton Wanderers FC", "stadium": "Molineux Stadium"},
|
||||
]
|
||||
|
||||
|
||||
def get_future_matches(
|
||||
team_name: str,
|
||||
all_clubs_data: list,
|
||||
num_matches: int = 12,
|
||||
date_from: date = None,
|
||||
date_to: date = None,
|
||||
) -> list:
|
||||
"""Generate a set of future Premier League matches for ``team_name``.
|
||||
|
||||
This is a purely mocked schedule. It returns up to ``num_matches``
|
||||
fixtures, respecting the ``date_from`` and ``date_to`` constraints.
|
||||
Matches are typically on Saturdays or Sundays.
|
||||
"""
|
||||
matches = []
|
||||
|
||||
team_details = next((c for c in all_clubs_data if c["name"] == team_name), None)
|
||||
if not team_details:
|
||||
return []
|
||||
|
||||
opponents_pool = [c for c in all_clubs_data if c["name"] != team_name]
|
||||
if not opponents_pool:
|
||||
return []
|
||||
|
||||
# Determine the maximum number of matches we can generate based on opponents
|
||||
# and the requested num_matches
|
||||
num_actual_matches_to_generate = min(num_matches, len(opponents_pool))
|
||||
if num_actual_matches_to_generate == 0:
|
||||
return []
|
||||
|
||||
# Shuffle opponents once and pick them sequentially
|
||||
random.shuffle(opponents_pool) # Shuffle in place
|
||||
|
||||
# Determine the initial Saturday for match week consideration
|
||||
today_date = date.today()
|
||||
# Default to next Saturday
|
||||
current_match_week_saturday = today_date + timedelta(
|
||||
days=(5 - today_date.weekday() + 7) % 7
|
||||
)
|
||||
|
||||
# If today is Saturday and it's late evening, or if today is Sunday,
|
||||
# advance to the following Saturday.
|
||||
now_time = datetime.now().time()
|
||||
if (
|
||||
today_date.weekday() == 5
|
||||
and now_time > datetime.strptime("20:00", "%H:%M").time()
|
||||
) or (today_date.weekday() == 6):
|
||||
current_match_week_saturday += timedelta(days=7)
|
||||
|
||||
# If date_from is specified, ensure our starting Saturday is not before it.
|
||||
if date_from:
|
||||
if current_match_week_saturday < date_from:
|
||||
current_match_week_saturday = date_from
|
||||
# Align current_match_week_saturday to be a Saturday on or after the potentially adjusted date
|
||||
current_match_week_saturday += timedelta(
|
||||
days=(5 - current_match_week_saturday.weekday() + 7) % 7
|
||||
)
|
||||
|
||||
opponent_idx = 0
|
||||
while len(matches) < num_actual_matches_to_generate and opponent_idx < len(
|
||||
opponents_pool
|
||||
):
|
||||
# If the current week's Saturday is already past date_to, stop.
|
||||
if date_to and current_match_week_saturday > date_to:
|
||||
break
|
||||
|
||||
opponent_details = opponents_pool[opponent_idx]
|
||||
is_saturday_game = random.choice([True, True, False])
|
||||
actual_match_date = None
|
||||
kick_off_time = ""
|
||||
|
||||
if is_saturday_game:
|
||||
actual_match_date = current_match_week_saturday
|
||||
kick_off_time = random.choice(["12:30", "15:00", "17:30"])
|
||||
else: # Sunday game
|
||||
actual_match_date = current_match_week_saturday + timedelta(days=1)
|
||||
kick_off_time = random.choice(["14:00", "16:30"])
|
||||
|
||||
# Check if this specific match date is within the date_to constraint
|
||||
if date_to and actual_match_date > date_to:
|
||||
# If this game is too late, try the next week if possible.
|
||||
# (This mainly affects Sunday games if Saturday was the last valid day)
|
||||
current_match_week_saturday += timedelta(days=7)
|
||||
continue # Skip adding this match, try next week.
|
||||
|
||||
match_datetime_gmt = (
|
||||
f"{actual_match_date.strftime('%Y-%m-%d')} {kick_off_time} GMT"
|
||||
)
|
||||
is_home_match = random.choice([True, False])
|
||||
|
||||
if is_home_match:
|
||||
team1_name = team_details["name"]
|
||||
team2_name = opponent_details["name"]
|
||||
stadium_name = team_details["stadium"]
|
||||
else:
|
||||
team1_name = opponent_details["name"]
|
||||
team2_name = team_details["name"]
|
||||
stadium_name = opponent_details["stadium"]
|
||||
|
||||
matches.append(
|
||||
{
|
||||
"team1": team1_name,
|
||||
"team2": team2_name,
|
||||
"stadium": stadium_name,
|
||||
"datetime_gmt": match_datetime_gmt,
|
||||
}
|
||||
)
|
||||
opponent_idx += 1
|
||||
current_match_week_saturday += timedelta(
|
||||
days=7
|
||||
) # Advance to next week's Saturday
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
BASE_URL = "https://api.football-data.org/v4"
|
||||
|
||||
|
||||
def search_fixtures(args: dict) -> dict:
|
||||
load_dotenv(override=True)
|
||||
api_key = os.getenv("FOOTBALL_DATA_API_KEY", "YOUR_DEFAULT_KEY")
|
||||
api_key = os.getenv("FOOTBALL_DATA_API_KEY")
|
||||
|
||||
team_name = args.get("team")
|
||||
date_from_str = args.get("date_from")
|
||||
date_to_str = args.get("date_to")
|
||||
headers = {"X-Auth-Token": api_key}
|
||||
team_name = team_name.lower()
|
||||
|
||||
try:
|
||||
date_from = datetime.strptime(date_from_str, "%Y-%m-%d")
|
||||
date_to = datetime.strptime(date_to_str, "%Y-%m-%d")
|
||||
except ValueError:
|
||||
if not team_name:
|
||||
return {"error": "Team name is required."}
|
||||
|
||||
parsed_date_from = None
|
||||
if date_from_str:
|
||||
try:
|
||||
parsed_date_from = datetime.strptime(date_from_str, "%Y-%m-%d").date()
|
||||
except ValueError:
|
||||
return {
|
||||
"error": f"Invalid date_from: '{date_from_str}'. Expected format YYYY-MM-DD."
|
||||
}
|
||||
|
||||
parsed_date_to = None
|
||||
if date_to_str:
|
||||
try:
|
||||
parsed_date_to = datetime.strptime(date_to_str, "%Y-%m-%d").date()
|
||||
except ValueError:
|
||||
return {
|
||||
"error": f"Invalid date_to: '{date_to_str}'. Expected format YYYY-MM-DD."
|
||||
}
|
||||
|
||||
if parsed_date_from and parsed_date_to and parsed_date_from > parsed_date_to:
|
||||
return {"error": "date_from cannot be after date_to."}
|
||||
|
||||
# If no API key, fall back to mocked data
|
||||
if not api_key:
|
||||
# Use the parsed date objects (which can be None)
|
||||
fixtures = get_future_matches(
|
||||
team_name,
|
||||
PREMIER_LEAGUE_CLUBS_DATA,
|
||||
date_from=parsed_date_from,
|
||||
date_to=parsed_date_to,
|
||||
# num_matches can be passed explicitly if needed, otherwise defaults to 12
|
||||
)
|
||||
if not fixtures:
|
||||
# Check if the team name itself was invalid, as get_future_matches returns [] for that too
|
||||
team_details_check = next(
|
||||
(c for c in PREMIER_LEAGUE_CLUBS_DATA if c["name"] == team_name), None
|
||||
)
|
||||
if not team_details_check:
|
||||
return {"error": f"Team '{team_name}' not found in mocked data."}
|
||||
# If team is valid, an empty fixtures list means no matches fit the criteria (e.g., date range)
|
||||
return {"fixtures": fixtures}
|
||||
|
||||
# API Key is present, proceed with API logic
|
||||
# The API requires both date_from and date_to
|
||||
if not parsed_date_from or not parsed_date_to:
|
||||
return {
|
||||
"error": "Invalid date provided. Expected format YYYY-MM-DD for both date_from and date_to."
|
||||
"error": "Both date_from and date_to (YYYY-MM-DD) are required for API search."
|
||||
}
|
||||
|
||||
headers = {"X-Auth-Token": api_key}
|
||||
# For API calls, team name matching might be case-insensitive or require specific handling
|
||||
# The existing logic uses team_name.lower() for the API search path later.
|
||||
|
||||
# Fetch team ID
|
||||
teams_response = requests.get(f"{BASE_URL}/competitions/PL/teams", headers=headers)
|
||||
if teams_response.status_code != 200:
|
||||
return {"error": "Failed to fetch teams data."}
|
||||
return {
|
||||
"error": f"Failed to fetch teams data from API (status {teams_response.status_code})."
|
||||
}
|
||||
|
||||
teams_data = teams_response.json()
|
||||
team_id = None
|
||||
for team in teams_data["teams"]:
|
||||
if team_name in team["name"].lower():
|
||||
team_id = team["id"]
|
||||
# Using lower() for comparison, assuming API team names might have varied casing
|
||||
# or the input team_name might not be exact.
|
||||
# The `ToolDefinition` lists exact names, so direct match might also be an option.
|
||||
for team_api_data in teams_data.get("teams", []):
|
||||
if team_name.lower() in team_api_data.get("name", "").lower():
|
||||
team_id = team_api_data["id"]
|
||||
break
|
||||
|
||||
if not team_id:
|
||||
return {"error": "Team not found."}
|
||||
return {"error": f"Team '{team_name}' not found via API."}
|
||||
|
||||
date_from_formatted = date_from.strftime("%Y-%m-%d")
|
||||
date_to_formatted = date_to.strftime("%Y-%m-%d")
|
||||
date_from_formatted = parsed_date_from.strftime("%Y-%m-%d")
|
||||
date_to_formatted = parsed_date_to.strftime("%Y-%m-%d")
|
||||
fixtures_url = f"{BASE_URL}/teams/{team_id}/matches?dateFrom={date_from_formatted}&dateTo={date_to_formatted}"
|
||||
print(fixtures_url)
|
||||
# print(fixtures_url) # Keep for debugging if necessary
|
||||
|
||||
fixtures_response = requests.get(fixtures_url, headers=headers)
|
||||
if fixtures_response.status_code != 200:
|
||||
return {"error": "Failed to fetch fixtures data."}
|
||||
return {
|
||||
"error": f"Failed to fetch fixtures data from API (status {fixtures_response.status_code})."
|
||||
}
|
||||
|
||||
fixtures_data = fixtures_response.json()
|
||||
matching_fixtures = []
|
||||
|
||||
for match in fixtures_data.get("matches", []):
|
||||
match_datetime = datetime.strptime(match["utcDate"], "%Y-%m-%dT%H:%M:%SZ")
|
||||
if match["competition"]["code"] == "PL":
|
||||
# Ensure match datetime parsing is robust
|
||||
try:
|
||||
match_datetime_utc = datetime.strptime(
|
||||
match["utcDate"], "%Y-%m-%dT%H:%M:%SZ"
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
# Skip malformed match entries or log an error
|
||||
continue
|
||||
|
||||
if match.get("competition", {}).get("code") == "PL":
|
||||
matching_fixtures.append(
|
||||
{
|
||||
"date": match_datetime.strftime("%Y-%m-%d"),
|
||||
"homeTeam": match["homeTeam"]["name"],
|
||||
"awayTeam": match["awayTeam"]["name"],
|
||||
"date": match_datetime_utc.strftime("%Y-%m-%d"),
|
||||
"homeTeam": match.get("homeTeam", {}).get("name", "N/A"),
|
||||
"awayTeam": match.get("awayTeam", {}).get("name", "N/A"),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -82,34 +281,69 @@ def search_fixtures_example(args: dict) -> dict:
|
||||
|
||||
# Validate dates
|
||||
try:
|
||||
date_from = datetime.strptime(date_from_str, "%Y-%m-%d")
|
||||
date_to = datetime.strptime(date_to_str, "%Y-%m-%d")
|
||||
# Ensure date strings are not None before parsing
|
||||
if date_from_str is None or date_to_str is None:
|
||||
raise ValueError("Date strings cannot be None")
|
||||
date_from_obj = datetime.strptime(date_from_str, "%Y-%m-%d")
|
||||
date_to_obj = datetime.strptime(date_to_str, "%Y-%m-%d")
|
||||
except ValueError:
|
||||
return {
|
||||
"error": "Invalid date provided. Expected format YYYY-MM-DD for both date_from and date_to."
|
||||
}
|
||||
|
||||
# Calculate 3 reasonable fixture dates within the given range
|
||||
date_range = (date_to - date_from).days
|
||||
date_range = (date_to_obj - date_from_obj).days
|
||||
if date_range < 0: # date_from is after date_to
|
||||
return {"fixtures": []} # No fixtures possible
|
||||
|
||||
fixture_dates_timestamps = []
|
||||
if date_range < 21:
|
||||
# If range is less than 3 weeks, use evenly spaced fixtures
|
||||
fixture_dates = [
|
||||
date_from + timedelta(days=max(1, date_range // 3)),
|
||||
date_from + timedelta(days=max(2, date_range * 2 // 3)),
|
||||
date_to - timedelta(days=min(2, date_range // 4)),
|
||||
]
|
||||
# If range is less than 3 weeks, use evenly spaced fixtures if possible
|
||||
if date_range >= 2: # Need at least some gap for 3 fixtures
|
||||
fixture_dates_timestamps = [
|
||||
date_from_obj
|
||||
+ timedelta(days=max(0, date_range // 4)), # Closer to start
|
||||
date_from_obj + timedelta(days=max(1, date_range // 2)), # Middle
|
||||
date_to_obj - timedelta(days=max(0, date_range // 4)), # Closer to end
|
||||
]
|
||||
elif date_range == 1: # Only two days
|
||||
fixture_dates_timestamps = [date_from_obj, date_to_obj]
|
||||
elif date_range == 0: # Only one day
|
||||
fixture_dates_timestamps = [date_from_obj]
|
||||
else: # date_range is negative, handled above, or 0 (single day)
|
||||
fixture_dates_timestamps = [date_from_obj] if date_range == 0 else []
|
||||
|
||||
else:
|
||||
# Otherwise space them out by weeks
|
||||
fixture_dates = [
|
||||
date_from + timedelta(days=7),
|
||||
date_from + timedelta(days=14),
|
||||
date_to - timedelta(days=7),
|
||||
]
|
||||
# Otherwise space them out by weeks, ensuring they are within the bounds
|
||||
d1 = date_from_obj + timedelta(days=7)
|
||||
d2 = date_from_obj + timedelta(days=14)
|
||||
d3 = date_to_obj - timedelta(days=7) # Potential third game from the end
|
||||
|
||||
# Ensure we only have 3 dates
|
||||
fixture_dates = fixture_dates[:3]
|
||||
fixture_dates_timestamps.append(d1)
|
||||
if d2 <= date_to_obj and d2 > d1: # ensure d2 is valid and distinct
|
||||
fixture_dates_timestamps.append(d2)
|
||||
if (
|
||||
d3 >= date_from_obj and d3 > d2 and d3 <= date_to_obj
|
||||
): # ensure d3 is valid and distinct
|
||||
fixture_dates_timestamps.append(d3)
|
||||
elif (
|
||||
d3 < date_from_obj and len(fixture_dates_timestamps) < 3
|
||||
): # if d3 is too early, try using date_to_obj itself if distinct
|
||||
if date_to_obj not in fixture_dates_timestamps:
|
||||
fixture_dates_timestamps.append(date_to_obj)
|
||||
|
||||
# Ensure unique dates and sort, then take up to 3.
|
||||
fixture_dates_timestamps = sorted(
|
||||
list(
|
||||
set(
|
||||
f_date
|
||||
for f_date in fixture_dates_timestamps
|
||||
if date_from_obj <= f_date <= date_to_obj
|
||||
)
|
||||
)
|
||||
)
|
||||
fixture_dates_final = fixture_dates_timestamps[:3]
|
||||
|
||||
# Expanded pool of opponent teams to avoid team playing against itself
|
||||
all_opponents = [
|
||||
"Manchester United FC",
|
||||
"Leicester City FC",
|
||||
@@ -120,35 +354,35 @@ def search_fixtures_example(args: dict) -> dict:
|
||||
"Tottenham Hotspur FC",
|
||||
"West Ham United FC",
|
||||
"Everton FC",
|
||||
"Generic Opponent A",
|
||||
"Generic Opponent B",
|
||||
"Generic Opponent C", # Fallbacks
|
||||
]
|
||||
|
||||
# Select opponents that aren't the same as the requested team
|
||||
available_opponents = [
|
||||
team for team in all_opponents if team.lower() != team_name.lower()
|
||||
]
|
||||
|
||||
# Ensure we have at least 3 opponents
|
||||
if len(available_opponents) < 3:
|
||||
# Add generic opponents if needed
|
||||
additional_teams = [f"Opponent {i} FC" for i in range(1, 4)]
|
||||
available_opponents.extend(additional_teams)
|
||||
# Ensure we have enough opponents for the number of fixtures we'll generate
|
||||
if len(available_opponents) < len(fixture_dates_final):
|
||||
needed = len(fixture_dates_final) - len(available_opponents)
|
||||
for i in range(needed):
|
||||
available_opponents.append(f"Placeholder Opponent {i+1}")
|
||||
|
||||
# Take only the first 3 opponents
|
||||
opponents = available_opponents[:3]
|
||||
opponents = available_opponents[: len(fixture_dates_final)]
|
||||
|
||||
# Generate fixtures - always exactly 3
|
||||
fixtures = []
|
||||
for i, fixture_date in enumerate(fixture_dates):
|
||||
date_str = fixture_date.strftime("%Y-%m-%d")
|
||||
|
||||
# Alternate between home and away games
|
||||
if i % 2 == 0:
|
||||
fixtures.append(
|
||||
{"date": date_str, "homeTeam": opponents[i], "awayTeam": team_name}
|
||||
)
|
||||
else:
|
||||
for i, fixture_date_obj in enumerate(fixture_dates_final):
|
||||
if i >= len(opponents): # Should not happen with the logic above
|
||||
break
|
||||
date_str = fixture_date_obj.strftime("%Y-%m-%d")
|
||||
if i % 2 == 0: # Home game
|
||||
fixtures.append(
|
||||
{"date": date_str, "homeTeam": team_name, "awayTeam": opponents[i]}
|
||||
)
|
||||
else: # Away game
|
||||
fixtures.append(
|
||||
{"date": date_str, "homeTeam": opponents[i], "awayTeam": team_name}
|
||||
)
|
||||
|
||||
return {"fixtures": fixtures}
|
||||
|
||||
@@ -90,7 +90,7 @@ search_flights_tool = ToolDefinition(
|
||||
|
||||
search_trains_tool = ToolDefinition(
|
||||
name="SearchTrains",
|
||||
description="Search for trains between two English cities. Returns a list of train information for the user to choose from.",
|
||||
description="Search for trains between two English cities. Returns a list of train information for the user to choose from. Present the list to the user.",
|
||||
arguments=[
|
||||
ToolArgument(
|
||||
name="origin",
|
||||
@@ -156,7 +156,7 @@ create_invoice_tool = ToolDefinition(
|
||||
|
||||
search_fixtures_tool = ToolDefinition(
|
||||
name="SearchFixtures",
|
||||
description="Search for upcoming fixtures for a given team within a date range inferred from the user's description. Valid teams this 24/25 season are Arsenal FC, Aston Villa FC, AFC Bournemouth, Brentford FC, Brighton & Hove Albion FC, Chelsea FC, Crystal Palace FC, Everton FC, Fulham FC, Ipswich Town FC, Leicester City FC, Liverpool FC, Manchester City FC, Manchester United FC, Newcastle United FC, Nottingham Forest FC, Southampton FC, Tottenham Hotspur FC, West Ham United FC, Wolverhampton Wanderers FC",
|
||||
description="Search for upcoming fixtures for a given team within a date range inferred from the user's description. Ignore valid premier league dates. Valid teams this season are Arsenal FC, Aston Villa FC, AFC Bournemouth, Brentford FC, Brighton & Hove Albion FC, Chelsea FC, Crystal Palace FC, Everton FC, Fulham FC, Ipswich Town FC, Leicester City FC, Liverpool FC, Manchester City FC, Manchester United FC, Newcastle United FC, Nottingham Forest FC, Southampton FC, Tottenham Hotspur FC, West Ham United FC, Wolverhampton Wanderers FC",
|
||||
arguments=[
|
||||
ToolArgument(
|
||||
name="team",
|
||||
@@ -397,3 +397,89 @@ ecomm_track_package = ToolDefinition(
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# ----- Food Ordering Use Case Tools -----
|
||||
food_get_menu_tool = ToolDefinition(
|
||||
name="GetMenu",
|
||||
description="Get the menu for a restaurant. Defaults to Tony's Pizza Palace if no restaurant specified.",
|
||||
arguments=[
|
||||
ToolArgument(
|
||||
name="restaurant_id",
|
||||
type="string",
|
||||
description="ID of the restaurant (defaults to rest_001 for Tony's Pizza Palace)",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
food_get_menu_item_details_tool = ToolDefinition(
|
||||
name="GetMenuItemDetails",
|
||||
description="Get detailed information about a specific menu item.",
|
||||
arguments=[
|
||||
ToolArgument(
|
||||
name="item_id",
|
||||
type="string",
|
||||
description="ID of the menu item to get details for",
|
||||
),
|
||||
ToolArgument(
|
||||
name="restaurant_id",
|
||||
type="string",
|
||||
description="ID of the restaurant (defaults to rest_001 for Tony's Pizza Palace)",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
food_add_to_cart_tool = ToolDefinition(
|
||||
name="AddToCart",
|
||||
description="Add a menu item to the customer's cart.",
|
||||
arguments=[
|
||||
ToolArgument(
|
||||
name="customer_email",
|
||||
type="string",
|
||||
description="Email address of the customer",
|
||||
),
|
||||
ToolArgument(
|
||||
name="item_id",
|
||||
type="string",
|
||||
description="ID of the menu item to add to cart",
|
||||
),
|
||||
ToolArgument(
|
||||
name="quantity",
|
||||
type="number",
|
||||
description="Quantity of the item to add (defaults to 1)",
|
||||
),
|
||||
ToolArgument(
|
||||
name="restaurant_id",
|
||||
type="string",
|
||||
description="ID of the restaurant (defaults to rest_001 for Tony's Pizza Palace)",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
food_place_order_tool = ToolDefinition(
|
||||
name="PlaceOrder",
|
||||
description="Place an order for the items in the customer's cart.",
|
||||
arguments=[
|
||||
ToolArgument(
|
||||
name="customer_email",
|
||||
type="string",
|
||||
description="Email address of the customer",
|
||||
),
|
||||
ToolArgument(
|
||||
name="userConfirmation",
|
||||
type="string",
|
||||
description="Indication of user's desire to place the order",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
food_check_order_status_tool = ToolDefinition(
|
||||
name="CheckOrderStatus",
|
||||
description="Check the status of a food order.",
|
||||
arguments=[
|
||||
ToolArgument(
|
||||
name="order_id",
|
||||
type="string",
|
||||
description="ID of the order to check status for",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user