diff --git a/README.md b/README.md index 948e3c4..6b885cf 100644 --- a/README.md +++ b/README.md @@ -35,8 +35,13 @@ The chat session will end if it has collected enough information to complete the Run query get_tool_data to see the data the tool has collected so far. +## API +- `poetry run uvicorn api.main:app --reload` to start the API server. +- Access the API at `/docs` to see the available endpoints. + ## TODO - The LLM prompts move through 3 mock tools (FindEvents, SearchFlights, CreateInvoice) but I should make them contact real APIs. - Might need to abstract the json example in the prompt generator to be part of a ToolDefinition (prevent overfitting to the example). - I need to build a chat interface so it's not cli-controlled. Also want to show some 'behind the scenes' of the agents being used as they run. -- What happens if I don't want to confirm a step, but instead want to correct it? TODO figure out \ No newline at end of file +- What happens if I don't want to confirm a step, but instead want to correct it? TODO figure out +- What happens if I am at confirmation step and want to end the chat (do I need some sort of signal router?) \ No newline at end of file diff --git a/api/main.py b/api/main.py new file mode 100644 index 0000000..0aa2ea9 --- /dev/null +++ b/api/main.py @@ -0,0 +1,59 @@ +from fastapi import FastAPI +from temporalio.client import Client +from workflows.tool_workflow import ToolWorkflow +from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams +from tools.tool_registry import all_tools + +app = FastAPI() + + +@app.get("/") +def root(): + return {"message": "Temporal AI Agent!"} + + +@app.get("/tool-data") +async def get_tool_data(): + """Calls the workflow's 'get_tool_data' query.""" + client = await Client.connect("localhost:7233") + handle = client.get_workflow_handle("agent-workflow") + tool_data = await handle.query(ToolWorkflow.get_tool_data) + return tool_data + + +@app.post("/send-prompt") +async def send_prompt(prompt: str): + client = await Client.connect("localhost:7233") + + # Build the ToolsData + tools_data = ToolsData(tools=all_tools) + + # Create combined input + combined_input = CombinedInput( + tool_params=ToolWorkflowParams(None, None), + tools_data=tools_data, + ) + + workflow_id = "agent-workflow" + + # Start (or signal) the workflow + await client.start_workflow( + ToolWorkflow.run, + combined_input, + id=workflow_id, + task_queue="agent-task-queue", + start_signal="user_prompt", + start_signal_args=[prompt], + ) + + return {"message": f"Prompt '{prompt}' sent to workflow {workflow_id}."} + + +@app.post("/confirm") +async def send_confirm(): + """Sends a 'confirm' signal to the workflow.""" + client = await Client.connect("localhost:7233") + workflow_id = "agent-workflow" + handle = client.get_workflow_handle(workflow_id) + await handle.signal("confirm") + return {"message": "Confirm signal sent."} diff --git a/poetry.lock b/poetry.lock index 7fb7540..a95f94d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -11,9 +11,6 @@ files = [ {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, ] -[package.dependencies] -typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""} - [[package]] name = "anyio" version = "4.5.2" @@ -132,6 +129,26 @@ files = [ [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "fastapi" +version = "0.115.6" +description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" +optional = false +python-versions = ">=3.8" +files = [ + {file = "fastapi-0.115.6-py3-none-any.whl", hash = "sha256:e9240b29e36fa8f4bb7290316988e90c381e5092e0cbe84e7818cc3713bcf305"}, + {file = "fastapi-0.115.6.tar.gz", hash = "sha256:9ec46f7addc14ea472958a96aae5b5de65f39721a46aaf5705c480d9a8b76654"}, +] + +[package.dependencies] +pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0" +starlette = ">=0.40.0,<0.42.0" +typing-extensions = ">=4.8.0" + +[package.extras] +all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] +standard = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "jinja2 (>=2.11.2)", "python-multipart (>=0.0.7)", "uvicorn[standard] (>=0.12.0)"] + [[package]] name = "h11" version = "0.14.0" @@ -579,6 +596,24 @@ files = [ {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, ] +[[package]] +name = "starlette" +version = "0.41.3" +description = "The little ASGI library that shines." +optional = false +python-versions = ">=3.8" +files = [ + {file = "starlette-0.41.3-py3-none-any.whl", hash = "sha256:44cedb2b7c77a9de33a8b74b2b90e9f50d11fcf25d8270ea525ad71a25374ff7"}, + {file = "starlette-0.41.3.tar.gz", hash = "sha256:0e4ab3d16522a255be6b28260b938eae2482f98ce5cc934cb08dce8dc3ba5835"}, +] + +[package.dependencies] +anyio = ">=3.4.0,<5" +typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""} + +[package.extras] +full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.7)", "pyyaml"] + [[package]] name = "temporalio" version = "1.9.0" @@ -667,7 +702,26 @@ files = [ {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] +[[package]] +name = "uvicorn" +version = "0.34.0" +description = "The lightning-fast ASGI server." +optional = false +python-versions = ">=3.9" +files = [ + {file = "uvicorn-0.34.0-py3-none-any.whl", hash = "sha256:023dc038422502fa28a09c7a30bf2b6991512da7dcdb8fd35fe57cfc154126f4"}, + {file = "uvicorn-0.34.0.tar.gz", hash = "sha256:404051050cd7e905de2c9a7e61790943440b3416f49cb409f965d9dcd0fa73e9"}, +] + +[package.dependencies] +click = ">=7.0" +h11 = ">=0.8" +typing-extensions = {version = ">=4.0", markers = "python_version < \"3.11\""} + +[package.extras] +standard = ["colorama (>=0.4)", "httptools (>=0.6.3)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"] + [metadata] lock-version = "2.0" -python-versions = "^3.8" -content-hash = "6e712358e27083eae9c9b8b64ee258e14f589dbd0007c7d1615adffdd99b7e2a" +python-versions = ">=3.9,<4.0" +content-hash = "20144fbc5773604251f9b61ac475eb9f292c6c8baf38d59170bbef81b482c71e" diff --git a/pyproject.toml b/pyproject.toml index 5ccd568..c19eff9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,13 +16,15 @@ packages = [ "Bug Tracker" = "https://github.com/temporalio/samples-python/issues" [tool.poetry.dependencies] -python = "^3.8" +python = ">=3.9,<4.0" temporalio = "^1.8.0" # Standard library modules (e.g. asyncio, collections) don't need to be added # since they're built-in for Python 3.8+. ollama = "^0.4.5" pyyaml = "^6.0.2" +fastapi = "^0.115.6" +uvicorn = "^0.34.0" [tool.poetry.group.dev.dependencies] pytest = "^7.3" diff --git a/scripts/get_tool_data.py b/scripts/get_tool_data.py new file mode 100644 index 0000000..e969f41 --- /dev/null +++ b/scripts/get_tool_data.py @@ -0,0 +1,23 @@ +import asyncio +import json + +from temporalio.client import Client +from workflows.tool_workflow import ToolWorkflow + + +async def main(): + # Create client connected to server at the given address + client = await Client.connect("localhost:7233") + workflow_id = "agent-workflow" + + handle = client.get_workflow_handle(workflow_id) + + # Queries the workflow for the conversation history + tool_data = await handle.query(ToolWorkflow.get_tool_data) + + # pretty print + print(json.dumps(tool_data, indent=4)) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/scripts/send_message.py b/scripts/send_message.py index 154b162..27d5b01 100644 --- a/scripts/send_message.py +++ b/scripts/send_message.py @@ -1,106 +1,32 @@ -# send_message.py import asyncio import sys -from typing import List from temporalio.client import Client from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams -from models.tool_definitions import ToolDefinition, ToolArgument +from tools.tool_registry import all_tools # <–– Import your pre-defined tools from workflows.tool_workflow import ToolWorkflow async def main(prompt: str): - # 1) Define the FindEvents tool - find_events_tool = ToolDefinition( - name="FindEvents", - description="Find upcoming events given a location or region (e.g., 'Oceania') and a date or month", - arguments=[ - ToolArgument( - name="continent", - type="string", - description="Which continent or region to search for events", - ), - ToolArgument( - name="month", - type="string", - description="The month or approximate date range to find events", - ), - ], - ) - - # 2) Define the SearchFlights tool - search_flights_tool = ToolDefinition( - name="SearchFlights", - description="Search for return flights from an origin to a destination within a date range (dateDepart, dateReturn)", - arguments=[ - ToolArgument( - name="origin", - type="string", - description="Airport or city (infer airport code from city)", - ), - ToolArgument( - name="destination", - type="string", - description="Airport or city code for arrival (infer airport code from city)", - ), - ToolArgument( - name="dateDepart", - type="ISO8601", - description="Start of date range in human readable format, when you want to depart", - ), - ToolArgument( - name="dateReturn", - type="ISO8601", - description="End of date range in human readable format, when you want to return", - ), - ], - ) - - # 3) Define the CreateInvoice tool - create_invoice_tool = ToolDefinition( - name="CreateInvoice", - description="Generate an invoice with flight information or other items to purchase", - arguments=[ - ToolArgument( - name="amount", - type="float", - description="The total cost to be invoiced", - ), - ToolArgument( - name="flightDetails", - type="string", - description="A summary of the flights, e.g., flight numbers, price breakdown", - ), - ], - ) - - # Collect all tools in a ToolsData structure - all_tools: List[ToolDefinition] = [ - find_events_tool, - search_flights_tool, - create_invoice_tool, - ] + # 1) Build the ToolsData from imported all_tools tools_data = ToolsData(tools=all_tools) - # Create the combined input (includes ToolsData + optional conversation summary or prompt queue) + # 2) Create combined input combined_input = CombinedInput( tool_params=ToolWorkflowParams(None, None), tools_data=tools_data, ) - # 4) Connect to Temporal and start or signal the workflow + # 3) Connect to Temporal and start or signal the workflow client = await Client.connect("localhost:7233") - workflow_id = "agent-workflow" - # Note that we start the ToolWorkflow.run with 'combined_input' - # Then we immediately signal with the initial prompt await client.start_workflow( ToolWorkflow.run, combined_input, id=workflow_id, task_queue="agent-task-queue", - start_signal="user_prompt", # This will send your first prompt to the workflow + start_signal="user_prompt", start_signal_args=[prompt], ) @@ -108,9 +34,6 @@ async def main(prompt: str): if __name__ == "__main__": if len(sys.argv) != 2: print("Usage: python send_message.py ''") - print( - "Example: python send_message.py 'I want an event in Oceania this March'" - " or 'Search flights from Seattle to San Francisco'" - ) + print("Example: python send_message.py 'I want an event in Oceania this March'") else: asyncio.run(main(sys.argv[1])) diff --git a/tools/tool_registry.py b/tools/tool_registry.py new file mode 100644 index 0000000..d7bca8e --- /dev/null +++ b/tools/tool_registry.py @@ -0,0 +1,66 @@ +from models.tool_definitions import ToolDefinition, ToolArgument + +find_events_tool = ToolDefinition( + name="FindEvents", + description="Find upcoming events given a location or region (e.g., 'Oceania') and a date or month", + arguments=[ + ToolArgument( + name="continent", + type="string", + description="Which continent or region to search for events", + ), + ToolArgument( + name="month", + type="string", + description="The month or approximate date range to find events", + ), + ], +) + +# 2) Define the SearchFlights tool +search_flights_tool = ToolDefinition( + name="SearchFlights", + description="Search for return flights from an origin to a destination within a date range (dateDepart, dateReturn)", + arguments=[ + ToolArgument( + name="origin", + type="string", + description="Airport or city (infer airport code from city)", + ), + ToolArgument( + name="destination", + type="string", + description="Airport or city code for arrival (infer airport code from city)", + ), + ToolArgument( + name="dateDepart", + type="ISO8601", + description="Start of date range in human readable format, when you want to depart", + ), + ToolArgument( + name="dateReturn", + type="ISO8601", + description="End of date range in human readable format, when you want to return", + ), + ], +) + +# 3) Define the CreateInvoice tool +create_invoice_tool = ToolDefinition( + name="CreateInvoice", + description="Generate an invoice with flight information or other items to purchase", + arguments=[ + ToolArgument( + name="amount", + type="float", + description="The total cost to be invoiced", + ), + ToolArgument( + name="flightDetails", + type="string", + description="A summary of the flights, e.g., flight numbers, price breakdown", + ), + ], +) + +all_tools = [find_events_tool, search_flights_tool, create_invoice_tool]