mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 14:08:08 +01:00
tool registry refactor and fastAPI
This commit is contained in:
@@ -35,8 +35,13 @@ The chat session will end if it has collected enough information to complete the
|
|||||||
|
|
||||||
Run query get_tool_data to see the data the tool has collected so far.
|
Run query get_tool_data to see the data the tool has collected so far.
|
||||||
|
|
||||||
|
## API
|
||||||
|
- `poetry run uvicorn api.main:app --reload` to start the API server.
|
||||||
|
- Access the API at `/docs` to see the available endpoints.
|
||||||
|
|
||||||
## TODO
|
## TODO
|
||||||
- The LLM prompts move through 3 mock tools (FindEvents, SearchFlights, CreateInvoice) but I should make them contact real APIs.
|
- The LLM prompts move through 3 mock tools (FindEvents, SearchFlights, CreateInvoice) but I should make them contact real APIs.
|
||||||
- Might need to abstract the json example in the prompt generator to be part of a ToolDefinition (prevent overfitting to the example).
|
- Might need to abstract the json example in the prompt generator to be part of a ToolDefinition (prevent overfitting to the example).
|
||||||
- I need to build a chat interface so it's not cli-controlled. Also want to show some 'behind the scenes' of the agents being used as they run.
|
- I need to build a chat interface so it's not cli-controlled. Also want to show some 'behind the scenes' of the agents being used as they run.
|
||||||
- What happens if I don't want to confirm a step, but instead want to correct it? TODO figure out
|
- What happens if I don't want to confirm a step, but instead want to correct it? TODO figure out
|
||||||
|
- What happens if I am at confirmation step and want to end the chat (do I need some sort of signal router?)
|
||||||
59
api/main.py
Normal file
59
api/main.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
from fastapi import FastAPI
|
||||||
|
from temporalio.client import Client
|
||||||
|
from workflows.tool_workflow import ToolWorkflow
|
||||||
|
from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams
|
||||||
|
from tools.tool_registry import all_tools
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
def root():
|
||||||
|
return {"message": "Temporal AI Agent!"}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/tool-data")
|
||||||
|
async def get_tool_data():
|
||||||
|
"""Calls the workflow's 'get_tool_data' query."""
|
||||||
|
client = await Client.connect("localhost:7233")
|
||||||
|
handle = client.get_workflow_handle("agent-workflow")
|
||||||
|
tool_data = await handle.query(ToolWorkflow.get_tool_data)
|
||||||
|
return tool_data
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/send-prompt")
|
||||||
|
async def send_prompt(prompt: str):
|
||||||
|
client = await Client.connect("localhost:7233")
|
||||||
|
|
||||||
|
# Build the ToolsData
|
||||||
|
tools_data = ToolsData(tools=all_tools)
|
||||||
|
|
||||||
|
# Create combined input
|
||||||
|
combined_input = CombinedInput(
|
||||||
|
tool_params=ToolWorkflowParams(None, None),
|
||||||
|
tools_data=tools_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
workflow_id = "agent-workflow"
|
||||||
|
|
||||||
|
# Start (or signal) the workflow
|
||||||
|
await client.start_workflow(
|
||||||
|
ToolWorkflow.run,
|
||||||
|
combined_input,
|
||||||
|
id=workflow_id,
|
||||||
|
task_queue="agent-task-queue",
|
||||||
|
start_signal="user_prompt",
|
||||||
|
start_signal_args=[prompt],
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"message": f"Prompt '{prompt}' sent to workflow {workflow_id}."}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/confirm")
|
||||||
|
async def send_confirm():
|
||||||
|
"""Sends a 'confirm' signal to the workflow."""
|
||||||
|
client = await Client.connect("localhost:7233")
|
||||||
|
workflow_id = "agent-workflow"
|
||||||
|
handle = client.get_workflow_handle(workflow_id)
|
||||||
|
await handle.signal("confirm")
|
||||||
|
return {"message": "Confirm signal sent."}
|
||||||
64
poetry.lock
generated
64
poetry.lock
generated
@@ -11,9 +11,6 @@ files = [
|
|||||||
{file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"},
|
{file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.dependencies]
|
|
||||||
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""}
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "anyio"
|
name = "anyio"
|
||||||
version = "4.5.2"
|
version = "4.5.2"
|
||||||
@@ -132,6 +129,26 @@ files = [
|
|||||||
[package.extras]
|
[package.extras]
|
||||||
test = ["pytest (>=6)"]
|
test = ["pytest (>=6)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "fastapi"
|
||||||
|
version = "0.115.6"
|
||||||
|
description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "fastapi-0.115.6-py3-none-any.whl", hash = "sha256:e9240b29e36fa8f4bb7290316988e90c381e5092e0cbe84e7818cc3713bcf305"},
|
||||||
|
{file = "fastapi-0.115.6.tar.gz", hash = "sha256:9ec46f7addc14ea472958a96aae5b5de65f39721a46aaf5705c480d9a8b76654"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0"
|
||||||
|
starlette = ">=0.40.0,<0.42.0"
|
||||||
|
typing-extensions = ">=4.8.0"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"]
|
||||||
|
standard = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "jinja2 (>=2.11.2)", "python-multipart (>=0.0.7)", "uvicorn[standard] (>=0.12.0)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "h11"
|
name = "h11"
|
||||||
version = "0.14.0"
|
version = "0.14.0"
|
||||||
@@ -579,6 +596,24 @@ files = [
|
|||||||
{file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"},
|
{file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "starlette"
|
||||||
|
version = "0.41.3"
|
||||||
|
description = "The little ASGI library that shines."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "starlette-0.41.3-py3-none-any.whl", hash = "sha256:44cedb2b7c77a9de33a8b74b2b90e9f50d11fcf25d8270ea525ad71a25374ff7"},
|
||||||
|
{file = "starlette-0.41.3.tar.gz", hash = "sha256:0e4ab3d16522a255be6b28260b938eae2482f98ce5cc934cb08dce8dc3ba5835"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
anyio = ">=3.4.0,<5"
|
||||||
|
typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""}
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "temporalio"
|
name = "temporalio"
|
||||||
version = "1.9.0"
|
version = "1.9.0"
|
||||||
@@ -667,7 +702,26 @@ files = [
|
|||||||
{file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"},
|
{file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "uvicorn"
|
||||||
|
version = "0.34.0"
|
||||||
|
description = "The lightning-fast ASGI server."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.9"
|
||||||
|
files = [
|
||||||
|
{file = "uvicorn-0.34.0-py3-none-any.whl", hash = "sha256:023dc038422502fa28a09c7a30bf2b6991512da7dcdb8fd35fe57cfc154126f4"},
|
||||||
|
{file = "uvicorn-0.34.0.tar.gz", hash = "sha256:404051050cd7e905de2c9a7e61790943440b3416f49cb409f965d9dcd0fa73e9"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
click = ">=7.0"
|
||||||
|
h11 = ">=0.8"
|
||||||
|
typing-extensions = {version = ">=4.0", markers = "python_version < \"3.11\""}
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
standard = ["colorama (>=0.4)", "httptools (>=0.6.3)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"]
|
||||||
|
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.8"
|
python-versions = ">=3.9,<4.0"
|
||||||
content-hash = "6e712358e27083eae9c9b8b64ee258e14f589dbd0007c7d1615adffdd99b7e2a"
|
content-hash = "20144fbc5773604251f9b61ac475eb9f292c6c8baf38d59170bbef81b482c71e"
|
||||||
|
|||||||
@@ -16,13 +16,15 @@ packages = [
|
|||||||
"Bug Tracker" = "https://github.com/temporalio/samples-python/issues"
|
"Bug Tracker" = "https://github.com/temporalio/samples-python/issues"
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = "^3.8"
|
python = ">=3.9,<4.0"
|
||||||
temporalio = "^1.8.0"
|
temporalio = "^1.8.0"
|
||||||
|
|
||||||
# Standard library modules (e.g. asyncio, collections) don't need to be added
|
# Standard library modules (e.g. asyncio, collections) don't need to be added
|
||||||
# since they're built-in for Python 3.8+.
|
# since they're built-in for Python 3.8+.
|
||||||
ollama = "^0.4.5"
|
ollama = "^0.4.5"
|
||||||
pyyaml = "^6.0.2"
|
pyyaml = "^6.0.2"
|
||||||
|
fastapi = "^0.115.6"
|
||||||
|
uvicorn = "^0.34.0"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
pytest = "^7.3"
|
pytest = "^7.3"
|
||||||
|
|||||||
23
scripts/get_tool_data.py
Normal file
23
scripts/get_tool_data.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
|
||||||
|
from temporalio.client import Client
|
||||||
|
from workflows.tool_workflow import ToolWorkflow
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
# Create client connected to server at the given address
|
||||||
|
client = await Client.connect("localhost:7233")
|
||||||
|
workflow_id = "agent-workflow"
|
||||||
|
|
||||||
|
handle = client.get_workflow_handle(workflow_id)
|
||||||
|
|
||||||
|
# Queries the workflow for the conversation history
|
||||||
|
tool_data = await handle.query(ToolWorkflow.get_tool_data)
|
||||||
|
|
||||||
|
# pretty print
|
||||||
|
print(json.dumps(tool_data, indent=4))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
@@ -1,106 +1,32 @@
|
|||||||
# send_message.py
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
from typing import List
|
|
||||||
from temporalio.client import Client
|
from temporalio.client import Client
|
||||||
|
|
||||||
from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams
|
from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams
|
||||||
from models.tool_definitions import ToolDefinition, ToolArgument
|
from tools.tool_registry import all_tools # <–– Import your pre-defined tools
|
||||||
from workflows.tool_workflow import ToolWorkflow
|
from workflows.tool_workflow import ToolWorkflow
|
||||||
|
|
||||||
|
|
||||||
async def main(prompt: str):
|
async def main(prompt: str):
|
||||||
# 1) Define the FindEvents tool
|
# 1) Build the ToolsData from imported all_tools
|
||||||
find_events_tool = ToolDefinition(
|
|
||||||
name="FindEvents",
|
|
||||||
description="Find upcoming events given a location or region (e.g., 'Oceania') and a date or month",
|
|
||||||
arguments=[
|
|
||||||
ToolArgument(
|
|
||||||
name="continent",
|
|
||||||
type="string",
|
|
||||||
description="Which continent or region to search for events",
|
|
||||||
),
|
|
||||||
ToolArgument(
|
|
||||||
name="month",
|
|
||||||
type="string",
|
|
||||||
description="The month or approximate date range to find events",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2) Define the SearchFlights tool
|
|
||||||
search_flights_tool = ToolDefinition(
|
|
||||||
name="SearchFlights",
|
|
||||||
description="Search for return flights from an origin to a destination within a date range (dateDepart, dateReturn)",
|
|
||||||
arguments=[
|
|
||||||
ToolArgument(
|
|
||||||
name="origin",
|
|
||||||
type="string",
|
|
||||||
description="Airport or city (infer airport code from city)",
|
|
||||||
),
|
|
||||||
ToolArgument(
|
|
||||||
name="destination",
|
|
||||||
type="string",
|
|
||||||
description="Airport or city code for arrival (infer airport code from city)",
|
|
||||||
),
|
|
||||||
ToolArgument(
|
|
||||||
name="dateDepart",
|
|
||||||
type="ISO8601",
|
|
||||||
description="Start of date range in human readable format, when you want to depart",
|
|
||||||
),
|
|
||||||
ToolArgument(
|
|
||||||
name="dateReturn",
|
|
||||||
type="ISO8601",
|
|
||||||
description="End of date range in human readable format, when you want to return",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3) Define the CreateInvoice tool
|
|
||||||
create_invoice_tool = ToolDefinition(
|
|
||||||
name="CreateInvoice",
|
|
||||||
description="Generate an invoice with flight information or other items to purchase",
|
|
||||||
arguments=[
|
|
||||||
ToolArgument(
|
|
||||||
name="amount",
|
|
||||||
type="float",
|
|
||||||
description="The total cost to be invoiced",
|
|
||||||
),
|
|
||||||
ToolArgument(
|
|
||||||
name="flightDetails",
|
|
||||||
type="string",
|
|
||||||
description="A summary of the flights, e.g., flight numbers, price breakdown",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Collect all tools in a ToolsData structure
|
|
||||||
all_tools: List[ToolDefinition] = [
|
|
||||||
find_events_tool,
|
|
||||||
search_flights_tool,
|
|
||||||
create_invoice_tool,
|
|
||||||
]
|
|
||||||
tools_data = ToolsData(tools=all_tools)
|
tools_data = ToolsData(tools=all_tools)
|
||||||
|
|
||||||
# Create the combined input (includes ToolsData + optional conversation summary or prompt queue)
|
# 2) Create combined input
|
||||||
combined_input = CombinedInput(
|
combined_input = CombinedInput(
|
||||||
tool_params=ToolWorkflowParams(None, None),
|
tool_params=ToolWorkflowParams(None, None),
|
||||||
tools_data=tools_data,
|
tools_data=tools_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4) Connect to Temporal and start or signal the workflow
|
# 3) Connect to Temporal and start or signal the workflow
|
||||||
client = await Client.connect("localhost:7233")
|
client = await Client.connect("localhost:7233")
|
||||||
|
|
||||||
workflow_id = "agent-workflow"
|
workflow_id = "agent-workflow"
|
||||||
|
|
||||||
# Note that we start the ToolWorkflow.run with 'combined_input'
|
|
||||||
# Then we immediately signal with the initial prompt
|
|
||||||
await client.start_workflow(
|
await client.start_workflow(
|
||||||
ToolWorkflow.run,
|
ToolWorkflow.run,
|
||||||
combined_input,
|
combined_input,
|
||||||
id=workflow_id,
|
id=workflow_id,
|
||||||
task_queue="agent-task-queue",
|
task_queue="agent-task-queue",
|
||||||
start_signal="user_prompt", # This will send your first prompt to the workflow
|
start_signal="user_prompt",
|
||||||
start_signal_args=[prompt],
|
start_signal_args=[prompt],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -108,9 +34,6 @@ async def main(prompt: str):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
if len(sys.argv) != 2:
|
if len(sys.argv) != 2:
|
||||||
print("Usage: python send_message.py '<prompt>'")
|
print("Usage: python send_message.py '<prompt>'")
|
||||||
print(
|
print("Example: python send_message.py 'I want an event in Oceania this March'")
|
||||||
"Example: python send_message.py 'I want an event in Oceania this March'"
|
|
||||||
" or 'Search flights from Seattle to San Francisco'"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
asyncio.run(main(sys.argv[1]))
|
asyncio.run(main(sys.argv[1]))
|
||||||
|
|||||||
66
tools/tool_registry.py
Normal file
66
tools/tool_registry.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
from models.tool_definitions import ToolDefinition, ToolArgument
|
||||||
|
|
||||||
|
find_events_tool = ToolDefinition(
|
||||||
|
name="FindEvents",
|
||||||
|
description="Find upcoming events given a location or region (e.g., 'Oceania') and a date or month",
|
||||||
|
arguments=[
|
||||||
|
ToolArgument(
|
||||||
|
name="continent",
|
||||||
|
type="string",
|
||||||
|
description="Which continent or region to search for events",
|
||||||
|
),
|
||||||
|
ToolArgument(
|
||||||
|
name="month",
|
||||||
|
type="string",
|
||||||
|
description="The month or approximate date range to find events",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2) Define the SearchFlights tool
|
||||||
|
search_flights_tool = ToolDefinition(
|
||||||
|
name="SearchFlights",
|
||||||
|
description="Search for return flights from an origin to a destination within a date range (dateDepart, dateReturn)",
|
||||||
|
arguments=[
|
||||||
|
ToolArgument(
|
||||||
|
name="origin",
|
||||||
|
type="string",
|
||||||
|
description="Airport or city (infer airport code from city)",
|
||||||
|
),
|
||||||
|
ToolArgument(
|
||||||
|
name="destination",
|
||||||
|
type="string",
|
||||||
|
description="Airport or city code for arrival (infer airport code from city)",
|
||||||
|
),
|
||||||
|
ToolArgument(
|
||||||
|
name="dateDepart",
|
||||||
|
type="ISO8601",
|
||||||
|
description="Start of date range in human readable format, when you want to depart",
|
||||||
|
),
|
||||||
|
ToolArgument(
|
||||||
|
name="dateReturn",
|
||||||
|
type="ISO8601",
|
||||||
|
description="End of date range in human readable format, when you want to return",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3) Define the CreateInvoice tool
|
||||||
|
create_invoice_tool = ToolDefinition(
|
||||||
|
name="CreateInvoice",
|
||||||
|
description="Generate an invoice with flight information or other items to purchase",
|
||||||
|
arguments=[
|
||||||
|
ToolArgument(
|
||||||
|
name="amount",
|
||||||
|
type="float",
|
||||||
|
description="The total cost to be invoiced",
|
||||||
|
),
|
||||||
|
ToolArgument(
|
||||||
|
name="flightDetails",
|
||||||
|
type="string",
|
||||||
|
description="A summary of the flights, e.g., flight numbers, price breakdown",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
all_tools = [find_events_tool, search_flights_tool, create_invoice_tool]
|
||||||
Reference in New Issue
Block a user