tool registry refactor and fastAPI

This commit is contained in:
Steve Androulakis
2025-01-02 11:30:50 -08:00
parent 745877db69
commit a98ae439ac
7 changed files with 222 additions and 90 deletions

View File

@@ -35,8 +35,13 @@ The chat session will end if it has collected enough information to complete the
Run query get_tool_data to see the data the tool has collected so far. Run query get_tool_data to see the data the tool has collected so far.
## API
- `poetry run uvicorn api.main:app --reload` to start the API server.
- Access the API at `/docs` to see the available endpoints.
## TODO ## TODO
- The LLM prompts move through 3 mock tools (FindEvents, SearchFlights, CreateInvoice) but I should make them contact real APIs. - The LLM prompts move through 3 mock tools (FindEvents, SearchFlights, CreateInvoice) but I should make them contact real APIs.
- Might need to abstract the json example in the prompt generator to be part of a ToolDefinition (prevent overfitting to the example). - Might need to abstract the json example in the prompt generator to be part of a ToolDefinition (prevent overfitting to the example).
- I need to build a chat interface so it's not cli-controlled. Also want to show some 'behind the scenes' of the agents being used as they run. - I need to build a chat interface so it's not cli-controlled. Also want to show some 'behind the scenes' of the agents being used as they run.
- What happens if I don't want to confirm a step, but instead want to correct it? TODO figure out - What happens if I don't want to confirm a step, but instead want to correct it? TODO figure out
- What happens if I am at confirmation step and want to end the chat (do I need some sort of signal router?)

59
api/main.py Normal file
View File

@@ -0,0 +1,59 @@
from fastapi import FastAPI
from temporalio.client import Client
from workflows.tool_workflow import ToolWorkflow
from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams
from tools.tool_registry import all_tools
app = FastAPI()
@app.get("/")
def root():
return {"message": "Temporal AI Agent!"}
@app.get("/tool-data")
async def get_tool_data():
"""Calls the workflow's 'get_tool_data' query."""
client = await Client.connect("localhost:7233")
handle = client.get_workflow_handle("agent-workflow")
tool_data = await handle.query(ToolWorkflow.get_tool_data)
return tool_data
@app.post("/send-prompt")
async def send_prompt(prompt: str):
client = await Client.connect("localhost:7233")
# Build the ToolsData
tools_data = ToolsData(tools=all_tools)
# Create combined input
combined_input = CombinedInput(
tool_params=ToolWorkflowParams(None, None),
tools_data=tools_data,
)
workflow_id = "agent-workflow"
# Start (or signal) the workflow
await client.start_workflow(
ToolWorkflow.run,
combined_input,
id=workflow_id,
task_queue="agent-task-queue",
start_signal="user_prompt",
start_signal_args=[prompt],
)
return {"message": f"Prompt '{prompt}' sent to workflow {workflow_id}."}
@app.post("/confirm")
async def send_confirm():
"""Sends a 'confirm' signal to the workflow."""
client = await Client.connect("localhost:7233")
workflow_id = "agent-workflow"
handle = client.get_workflow_handle(workflow_id)
await handle.signal("confirm")
return {"message": "Confirm signal sent."}

64
poetry.lock generated
View File

@@ -11,9 +11,6 @@ files = [
{file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"},
] ]
[package.dependencies]
typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""}
[[package]] [[package]]
name = "anyio" name = "anyio"
version = "4.5.2" version = "4.5.2"
@@ -132,6 +129,26 @@ files = [
[package.extras] [package.extras]
test = ["pytest (>=6)"] test = ["pytest (>=6)"]
[[package]]
name = "fastapi"
version = "0.115.6"
description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
optional = false
python-versions = ">=3.8"
files = [
{file = "fastapi-0.115.6-py3-none-any.whl", hash = "sha256:e9240b29e36fa8f4bb7290316988e90c381e5092e0cbe84e7818cc3713bcf305"},
{file = "fastapi-0.115.6.tar.gz", hash = "sha256:9ec46f7addc14ea472958a96aae5b5de65f39721a46aaf5705c480d9a8b76654"},
]
[package.dependencies]
pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0"
starlette = ">=0.40.0,<0.42.0"
typing-extensions = ">=4.8.0"
[package.extras]
all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"]
standard = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "jinja2 (>=2.11.2)", "python-multipart (>=0.0.7)", "uvicorn[standard] (>=0.12.0)"]
[[package]] [[package]]
name = "h11" name = "h11"
version = "0.14.0" version = "0.14.0"
@@ -579,6 +596,24 @@ files = [
{file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"},
] ]
[[package]]
name = "starlette"
version = "0.41.3"
description = "The little ASGI library that shines."
optional = false
python-versions = ">=3.8"
files = [
{file = "starlette-0.41.3-py3-none-any.whl", hash = "sha256:44cedb2b7c77a9de33a8b74b2b90e9f50d11fcf25d8270ea525ad71a25374ff7"},
{file = "starlette-0.41.3.tar.gz", hash = "sha256:0e4ab3d16522a255be6b28260b938eae2482f98ce5cc934cb08dce8dc3ba5835"},
]
[package.dependencies]
anyio = ">=3.4.0,<5"
typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""}
[package.extras]
full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"]
[[package]] [[package]]
name = "temporalio" name = "temporalio"
version = "1.9.0" version = "1.9.0"
@@ -667,7 +702,26 @@ files = [
{file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"},
] ]
[[package]]
name = "uvicorn"
version = "0.34.0"
description = "The lightning-fast ASGI server."
optional = false
python-versions = ">=3.9"
files = [
{file = "uvicorn-0.34.0-py3-none-any.whl", hash = "sha256:023dc038422502fa28a09c7a30bf2b6991512da7dcdb8fd35fe57cfc154126f4"},
{file = "uvicorn-0.34.0.tar.gz", hash = "sha256:404051050cd7e905de2c9a7e61790943440b3416f49cb409f965d9dcd0fa73e9"},
]
[package.dependencies]
click = ">=7.0"
h11 = ">=0.8"
typing-extensions = {version = ">=4.0", markers = "python_version < \"3.11\""}
[package.extras]
standard = ["colorama (>=0.4)", "httptools (>=0.6.3)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.8" python-versions = ">=3.9,<4.0"
content-hash = "6e712358e27083eae9c9b8b64ee258e14f589dbd0007c7d1615adffdd99b7e2a" content-hash = "20144fbc5773604251f9b61ac475eb9f292c6c8baf38d59170bbef81b482c71e"

View File

@@ -16,13 +16,15 @@ packages = [
"Bug Tracker" = "https://github.com/temporalio/samples-python/issues" "Bug Tracker" = "https://github.com/temporalio/samples-python/issues"
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.8" python = ">=3.9,<4.0"
temporalio = "^1.8.0" temporalio = "^1.8.0"
# Standard library modules (e.g. asyncio, collections) don't need to be added # Standard library modules (e.g. asyncio, collections) don't need to be added
# since they're built-in for Python 3.8+. # since they're built-in for Python 3.8+.
ollama = "^0.4.5" ollama = "^0.4.5"
pyyaml = "^6.0.2" pyyaml = "^6.0.2"
fastapi = "^0.115.6"
uvicorn = "^0.34.0"
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
pytest = "^7.3" pytest = "^7.3"

23
scripts/get_tool_data.py Normal file
View File

@@ -0,0 +1,23 @@
import asyncio
import json
from temporalio.client import Client
from workflows.tool_workflow import ToolWorkflow
async def main():
# Create client connected to server at the given address
client = await Client.connect("localhost:7233")
workflow_id = "agent-workflow"
handle = client.get_workflow_handle(workflow_id)
# Queries the workflow for the conversation history
tool_data = await handle.query(ToolWorkflow.get_tool_data)
# pretty print
print(json.dumps(tool_data, indent=4))
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,106 +1,32 @@
# send_message.py
import asyncio import asyncio
import sys import sys
from typing import List
from temporalio.client import Client from temporalio.client import Client
from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams
from models.tool_definitions import ToolDefinition, ToolArgument from tools.tool_registry import all_tools # < Import your pre-defined tools
from workflows.tool_workflow import ToolWorkflow from workflows.tool_workflow import ToolWorkflow
async def main(prompt: str): async def main(prompt: str):
# 1) Define the FindEvents tool # 1) Build the ToolsData from imported all_tools
find_events_tool = ToolDefinition(
name="FindEvents",
description="Find upcoming events given a location or region (e.g., 'Oceania') and a date or month",
arguments=[
ToolArgument(
name="continent",
type="string",
description="Which continent or region to search for events",
),
ToolArgument(
name="month",
type="string",
description="The month or approximate date range to find events",
),
],
)
# 2) Define the SearchFlights tool
search_flights_tool = ToolDefinition(
name="SearchFlights",
description="Search for return flights from an origin to a destination within a date range (dateDepart, dateReturn)",
arguments=[
ToolArgument(
name="origin",
type="string",
description="Airport or city (infer airport code from city)",
),
ToolArgument(
name="destination",
type="string",
description="Airport or city code for arrival (infer airport code from city)",
),
ToolArgument(
name="dateDepart",
type="ISO8601",
description="Start of date range in human readable format, when you want to depart",
),
ToolArgument(
name="dateReturn",
type="ISO8601",
description="End of date range in human readable format, when you want to return",
),
],
)
# 3) Define the CreateInvoice tool
create_invoice_tool = ToolDefinition(
name="CreateInvoice",
description="Generate an invoice with flight information or other items to purchase",
arguments=[
ToolArgument(
name="amount",
type="float",
description="The total cost to be invoiced",
),
ToolArgument(
name="flightDetails",
type="string",
description="A summary of the flights, e.g., flight numbers, price breakdown",
),
],
)
# Collect all tools in a ToolsData structure
all_tools: List[ToolDefinition] = [
find_events_tool,
search_flights_tool,
create_invoice_tool,
]
tools_data = ToolsData(tools=all_tools) tools_data = ToolsData(tools=all_tools)
# Create the combined input (includes ToolsData + optional conversation summary or prompt queue) # 2) Create combined input
combined_input = CombinedInput( combined_input = CombinedInput(
tool_params=ToolWorkflowParams(None, None), tool_params=ToolWorkflowParams(None, None),
tools_data=tools_data, tools_data=tools_data,
) )
# 4) Connect to Temporal and start or signal the workflow # 3) Connect to Temporal and start or signal the workflow
client = await Client.connect("localhost:7233") client = await Client.connect("localhost:7233")
workflow_id = "agent-workflow" workflow_id = "agent-workflow"
# Note that we start the ToolWorkflow.run with 'combined_input'
# Then we immediately signal with the initial prompt
await client.start_workflow( await client.start_workflow(
ToolWorkflow.run, ToolWorkflow.run,
combined_input, combined_input,
id=workflow_id, id=workflow_id,
task_queue="agent-task-queue", task_queue="agent-task-queue",
start_signal="user_prompt", # This will send your first prompt to the workflow start_signal="user_prompt",
start_signal_args=[prompt], start_signal_args=[prompt],
) )
@@ -108,9 +34,6 @@ async def main(prompt: str):
if __name__ == "__main__": if __name__ == "__main__":
if len(sys.argv) != 2: if len(sys.argv) != 2:
print("Usage: python send_message.py '<prompt>'") print("Usage: python send_message.py '<prompt>'")
print( print("Example: python send_message.py 'I want an event in Oceania this March'")
"Example: python send_message.py 'I want an event in Oceania this March'"
" or 'Search flights from Seattle to San Francisco'"
)
else: else:
asyncio.run(main(sys.argv[1])) asyncio.run(main(sys.argv[1]))

66
tools/tool_registry.py Normal file
View File

@@ -0,0 +1,66 @@
from models.tool_definitions import ToolDefinition, ToolArgument
find_events_tool = ToolDefinition(
name="FindEvents",
description="Find upcoming events given a location or region (e.g., 'Oceania') and a date or month",
arguments=[
ToolArgument(
name="continent",
type="string",
description="Which continent or region to search for events",
),
ToolArgument(
name="month",
type="string",
description="The month or approximate date range to find events",
),
],
)
# 2) Define the SearchFlights tool
search_flights_tool = ToolDefinition(
name="SearchFlights",
description="Search for return flights from an origin to a destination within a date range (dateDepart, dateReturn)",
arguments=[
ToolArgument(
name="origin",
type="string",
description="Airport or city (infer airport code from city)",
),
ToolArgument(
name="destination",
type="string",
description="Airport or city code for arrival (infer airport code from city)",
),
ToolArgument(
name="dateDepart",
type="ISO8601",
description="Start of date range in human readable format, when you want to depart",
),
ToolArgument(
name="dateReturn",
type="ISO8601",
description="End of date range in human readable format, when you want to return",
),
],
)
# 3) Define the CreateInvoice tool
create_invoice_tool = ToolDefinition(
name="CreateInvoice",
description="Generate an invoice with flight information or other items to purchase",
arguments=[
ToolArgument(
name="amount",
type="float",
description="The total cost to be invoiced",
),
ToolArgument(
name="flightDetails",
type="string",
description="A summary of the flights, e.g., flight numbers, price breakdown",
),
],
)
all_tools = [find_events_tool, search_flights_tool, create_invoice_tool]