mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 14:08:08 +01:00
tool registry refactor and fastAPI
This commit is contained in:
23
scripts/get_tool_data.py
Normal file
23
scripts/get_tool_data.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from temporalio.client import Client
|
||||
from workflows.tool_workflow import ToolWorkflow
|
||||
|
||||
|
||||
async def main():
|
||||
# Create client connected to server at the given address
|
||||
client = await Client.connect("localhost:7233")
|
||||
workflow_id = "agent-workflow"
|
||||
|
||||
handle = client.get_workflow_handle(workflow_id)
|
||||
|
||||
# Queries the workflow for the conversation history
|
||||
tool_data = await handle.query(ToolWorkflow.get_tool_data)
|
||||
|
||||
# pretty print
|
||||
print(json.dumps(tool_data, indent=4))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,106 +1,32 @@
|
||||
# send_message.py
|
||||
import asyncio
|
||||
import sys
|
||||
from typing import List
|
||||
from temporalio.client import Client
|
||||
|
||||
from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams
|
||||
from models.tool_definitions import ToolDefinition, ToolArgument
|
||||
from tools.tool_registry import all_tools # <–– Import your pre-defined tools
|
||||
from workflows.tool_workflow import ToolWorkflow
|
||||
|
||||
|
||||
async def main(prompt: str):
|
||||
# 1) Define the FindEvents tool
|
||||
find_events_tool = ToolDefinition(
|
||||
name="FindEvents",
|
||||
description="Find upcoming events given a location or region (e.g., 'Oceania') and a date or month",
|
||||
arguments=[
|
||||
ToolArgument(
|
||||
name="continent",
|
||||
type="string",
|
||||
description="Which continent or region to search for events",
|
||||
),
|
||||
ToolArgument(
|
||||
name="month",
|
||||
type="string",
|
||||
description="The month or approximate date range to find events",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# 2) Define the SearchFlights tool
|
||||
search_flights_tool = ToolDefinition(
|
||||
name="SearchFlights",
|
||||
description="Search for return flights from an origin to a destination within a date range (dateDepart, dateReturn)",
|
||||
arguments=[
|
||||
ToolArgument(
|
||||
name="origin",
|
||||
type="string",
|
||||
description="Airport or city (infer airport code from city)",
|
||||
),
|
||||
ToolArgument(
|
||||
name="destination",
|
||||
type="string",
|
||||
description="Airport or city code for arrival (infer airport code from city)",
|
||||
),
|
||||
ToolArgument(
|
||||
name="dateDepart",
|
||||
type="ISO8601",
|
||||
description="Start of date range in human readable format, when you want to depart",
|
||||
),
|
||||
ToolArgument(
|
||||
name="dateReturn",
|
||||
type="ISO8601",
|
||||
description="End of date range in human readable format, when you want to return",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# 3) Define the CreateInvoice tool
|
||||
create_invoice_tool = ToolDefinition(
|
||||
name="CreateInvoice",
|
||||
description="Generate an invoice with flight information or other items to purchase",
|
||||
arguments=[
|
||||
ToolArgument(
|
||||
name="amount",
|
||||
type="float",
|
||||
description="The total cost to be invoiced",
|
||||
),
|
||||
ToolArgument(
|
||||
name="flightDetails",
|
||||
type="string",
|
||||
description="A summary of the flights, e.g., flight numbers, price breakdown",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# Collect all tools in a ToolsData structure
|
||||
all_tools: List[ToolDefinition] = [
|
||||
find_events_tool,
|
||||
search_flights_tool,
|
||||
create_invoice_tool,
|
||||
]
|
||||
# 1) Build the ToolsData from imported all_tools
|
||||
tools_data = ToolsData(tools=all_tools)
|
||||
|
||||
# Create the combined input (includes ToolsData + optional conversation summary or prompt queue)
|
||||
# 2) Create combined input
|
||||
combined_input = CombinedInput(
|
||||
tool_params=ToolWorkflowParams(None, None),
|
||||
tools_data=tools_data,
|
||||
)
|
||||
|
||||
# 4) Connect to Temporal and start or signal the workflow
|
||||
# 3) Connect to Temporal and start or signal the workflow
|
||||
client = await Client.connect("localhost:7233")
|
||||
|
||||
workflow_id = "agent-workflow"
|
||||
|
||||
# Note that we start the ToolWorkflow.run with 'combined_input'
|
||||
# Then we immediately signal with the initial prompt
|
||||
await client.start_workflow(
|
||||
ToolWorkflow.run,
|
||||
combined_input,
|
||||
id=workflow_id,
|
||||
task_queue="agent-task-queue",
|
||||
start_signal="user_prompt", # This will send your first prompt to the workflow
|
||||
start_signal="user_prompt",
|
||||
start_signal_args=[prompt],
|
||||
)
|
||||
|
||||
@@ -108,9 +34,6 @@ async def main(prompt: str):
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 2:
|
||||
print("Usage: python send_message.py '<prompt>'")
|
||||
print(
|
||||
"Example: python send_message.py 'I want an event in Oceania this March'"
|
||||
" or 'Search flights from Seattle to San Francisco'"
|
||||
)
|
||||
print("Example: python send_message.py 'I want an event in Oceania this March'")
|
||||
else:
|
||||
asyncio.run(main(sys.argv[1]))
|
||||
|
||||
Reference in New Issue
Block a user