diff --git a/README.md b/README.md index 6ad86a0..706b95e 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,11 @@ -# Multi-turn chat with Ollama Entity Workflow +# AI Agent execution using Temporal -Multi-Turn Chat using an Entity Workflow. The workflow runs forever unless explicitly ended. The workflow continues as new after a configurable number of chat turns to keep the prompt size small and the Temporal event history small. Each continued-as-new workflow receives a summary of the conversation history so far for context. +Work in progress (very early!). + +This demo shows a multi-turn conversation with an AI agent running inside a Temporal Entity Workflow. The agent's goal is to collect the information required to complete a task. There's a simple DSL input for collecting information (currently set up to search for flights). The AI will respond with clarifications and ask for any missing information (e.g., origin city, destination, travel dates). It uses a local LLM via Ollama. ## Setup -* Install [Ollama](https://ollama.com) and the Mistral model (`ollama run mistral`). +* Install [Ollama](https://ollama.com) and the Qwen 2.5 14B model (`ollama run qwen2.5:14b`). (Note: this model is more than 10GB to download.) * Install and run Temporal. Follow the instructions in the [Temporal documentation](https://learn.temporal.io/getting_started/python/dev_environment/#set-up-a-local-temporal-service-for-development-with-temporal-cli) to install and run the Temporal server. * Install the dependencies: `poetry install` @@ -12,13 +14,24 @@ Multi-Turn Chat using an Entity Workflow. The workflow runs forever unless expli 1. Run the worker: `poetry run python run_worker.py` 2. In another terminal run the client with a prompt. - Example: `poetry run python send_message.py 'What animals are marsupials?'` + Example: `poetry run python send_message.py 'I want to book a flight'` 3. View the worker's output for the response. 4. Give followup prompts by signaling the workflow. 
- Example: `poetry run python send_message.py 'Do they lay eggs?'` + Example: `poetry run python send_message.py 'From San Francisco'` 5. Get the conversation history summary by querying the workflow. Example: `poetry run python get_history.py` 6. To end the chat session, run `poetry run python end_chat.py` + +The chat session will end if it has collected enough information to complete the task or if the user explicitly ends the chat session. + +Run query get_tool_data to see the data the tool has collected so far. + +## TODO +- This is currently a good single tool workflow. It could be a child as part of a planning workflow (multiple tools). +- I should integrate another tool. Perhaps something that consumes web sites hunting for destinations to go to in the first place. +- I should make this workflow execute a Search for flights as right now it will finish without doing anything. +- I need to add a human in the loop confirmation step before it executes tools. +- I need to build a chat interface so it's not cli-controlled. Also want to show some 'behind the scenes' of the agents being used as they run. \ No newline at end of file diff --git a/activities.py b/activities.py index 46a82a5..8d1bd87 100644 --- a/activities.py +++ b/activities.py @@ -1,6 +1,8 @@ from dataclasses import dataclass from temporalio import activity from ollama import chat, ChatResponse +import json +from temporalio.exceptions import ApplicationError @dataclass @@ -12,7 +14,7 @@ class OllamaPromptInput: class OllamaActivities: @activity.defn def prompt_ollama(self, input: OllamaPromptInput) -> str: - model_name = "mistral" + model_name = "qwen2.5:14b" messages = [ { "role": "system", @@ -26,3 +28,16 @@ class OllamaActivities: response: ChatResponse = chat(model=model_name, messages=messages) return response.message.content + + @activity.defn + def parse_tool_data(self, json_str: str) -> dict: + """ + Parses a JSON string into a dictionary. + Raises a ValueError if the JSON is invalid. 
+ """ + try: + data = json.loads(json_str) + except json.JSONDecodeError as e: + raise ApplicationError(f"Invalid JSON: {e}") + + return data diff --git a/agent_prompt_generators.py b/agent_prompt_generators.py new file mode 100644 index 0000000..c2e8c6d --- /dev/null +++ b/agent_prompt_generators.py @@ -0,0 +1,130 @@ +from workflows import ToolsData + + +def generate_genai_prompt_from_tools_data( + tools_data: ToolsData, conversation_history: str +) -> str: + """ + Generates a prompt describing the tools and the instructions for the AI + assistant, using the conversation history provided. + """ + prompt_lines = [] + + prompt_lines.append( + "You are an AI assistant that must determine all required arguments" + ) + prompt_lines.append("for the tools to achieve the user's goal.\n") + prompt_lines.append("Conversation history so far:") + prompt_lines.append(conversation_history) + prompt_lines.append("") + + # List all tools and their arguments + for tool in tools_data.tools: + prompt_lines.append(f"Tool to run: {tool.name}") + prompt_lines.append(f"Description: {tool.description}") + prompt_lines.append("Arguments needed:") + for arg in tool.arguments: + prompt_lines.append(f" - {arg.name} ({arg.type}): {arg.description}") + prompt_lines.append("") + + prompt_lines.append("Instructions:") + prompt_lines.append( + "1. You need to ask the user (or confirm with them) for each argument required by the tools above." + ) + prompt_lines.append( + "2. If you do not yet have a specific argument value, ask the user for it." + ) + prompt_lines.append( + "3. 
Once you have all arguments, read them back to confirm with the user before yielding to the tool to take action.\n" + ) + prompt_lines.append( + 'Your response must be valid JSON in the format: {"response": "", "next": "", ' + + '"tool": "", "arg1": "value1", "arg2": "value2"}" where args are the arguments for the tool (or null if unknown so far)."' + ) + prompt_lines.append( + '- Your goal is to convert the AI responses into filled args in the JSON and once all args are filled, confirm with the user.".' + ) + prompt_lines.append( + '- If you still need information from the user, use "next": "question".' + ) + prompt_lines.append( + '- If you have enough information and are confirming, use "next": "confirm". This is the final step once you have filled all args.' + ) + prompt_lines.append( + '- Example of a good answer: {"response": "It seems we have all the information needed to search for flights. You will be flying from to from to . Is this correct?", "args":{"origin": "Seattle", "destination": "San Francisco", "dateFrom": "2025-01-04", "dateTo": "2025-01-08"}, "next": "confirm", "tool": "" }' + ) + prompt_lines.append("- Return valid JSON without special characters.") + prompt_lines.append("") + prompt_lines.append("Begin by prompting or confirming the necessary details.") + + return "\n".join(prompt_lines) + + +def generate_json_validation_prompt_from_tools_data( + tools_data: ToolsData, conversation_history: str, raw_json: str +) -> str: + """ + Generates a prompt instructing the AI to: + 1. Check that the given raw JSON is syntactically valid. + 2. Ensure the 'tool' matches one of the defined tools in tools_data. + 3. Confirm or correct that all required arguments are present and make sense. + 4. Return a corrected JSON if possible. + """ + prompt_lines = [] + + prompt_lines.append( + "You are an AI assistant that must validate the following JSON." 
+ ) + prompt_lines.append("It may be malformed or incomplete.") + prompt_lines.append("You also have a list of tools and their required arguments.") + prompt_lines.append( + "You must ensure the JSON is valid and matches these definitions.\n" + ) + + prompt_lines.append("== Tools Definitions ==") + for tool in tools_data.tools: + prompt_lines.append(f"Tool name: {tool.name}") + prompt_lines.append(f" Description: {tool.description}") + prompt_lines.append(" Arguments required:") + for arg in tool.arguments: + prompt_lines.append(f" - {arg.name} ({arg.type}): {arg.description}") + prompt_lines.append("") + + prompt_lines.append("== JSON to Validate ==") + prompt_lines.append(raw_json) + prompt_lines.append("") + + prompt_lines.append("Validation checks:") + prompt_lines.append("1. Is the JSON syntactically valid? If not, fix it.") + prompt_lines.append( + "2. Does the 'tool' field match one of the tools in Tools Definitions? If not, correct or note the mismatch." + ) + prompt_lines.append( + "3. Do the arguments under 'args' correspond exactly to the required arguments for that tool? Are they present and valid? If not, set them to null or correct them." + ) + prompt_lines.append( + "4. Confirm the 'response' and 'next' fields are present, if applicable, per the desired JSON structure." + ) + prompt_lines.append( + "5. If something is missing or incorrect, fix it in the final JSON output or explain what is missing." + ) + prompt_lines.append( + "6. You can and should take values from the response, parse them and insert them into JSON args where possible. Carefully parse the history and the latest response to fill in the args." + ) + prompt_lines.append("") + prompt_lines.append( + "Return your response in valid JSON. DO NOT RETURN ANYTHING EXCEPT VALID JSON IN THE CORRECT FORMAT. No editorializing or comments on the JSON." 
+ ) + prompt_lines.append("The final output must:") + prompt_lines.append( + '- Provide the corrected JSON if you can fix it, using the format {"response": "...", "next": "...", "tool": "...", "args": {...}}.' + ) + prompt_lines.append( + '- If you cannot correct it then provide a skeleton JSON structure with the original "response" value inside.\n' + ) + prompt_lines.append("Conversation history so far:") + prompt_lines.append(conversation_history) + + prompt_lines.append("Begin validating now.") + + return "\n".join(prompt_lines) diff --git a/poetry.lock b/poetry.lock index d26ba0c..7fb7540 100644 --- a/poetry.lock +++ b/poetry.lock @@ -495,6 +495,68 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = "pyyaml" +version = "6.0.2" +description = "YAML parser and emitter for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086"}, + {file = "PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed"}, + {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180"}, + {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68"}, + {file = "PyYAML-6.0.2-cp310-cp310-win32.whl", 
hash = "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99"}, + {file = "PyYAML-6.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e"}, + {file = "PyYAML-6.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774"}, + {file = "PyYAML-6.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85"}, + {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4"}, + {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e"}, + {file = "PyYAML-6.0.2-cp311-cp311-win32.whl", hash = "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5"}, + {file = "PyYAML-6.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44"}, + {file = "PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab"}, + {file = "PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476"}, + {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48"}, + {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b"}, + {file = "PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4"}, + {file = "PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8"}, + {file = "PyYAML-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba"}, + {file = "PyYAML-6.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5"}, + {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc"}, + {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = 
"sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652"}, + {file = "PyYAML-6.0.2-cp313-cp313-win32.whl", hash = "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183"}, + {file = "PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563"}, + {file = "PyYAML-6.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:24471b829b3bf607e04e88d79542a9d48bb037c2267d7927a874e6c205ca7e9a"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7fded462629cfa4b685c5416b949ebad6cec74af5e2d42905d41e257e0869f5"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d84a1718ee396f54f3a086ea0a66d8e552b2ab2017ef8b420e92edbc841c352d"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9056c1ecd25795207ad294bcf39f2db3d845767be0ea6e6a34d856f006006083"}, + {file = "PyYAML-6.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:82d09873e40955485746739bcb8b4586983670466c23382c19cffecbf1fd8706"}, + {file = "PyYAML-6.0.2-cp38-cp38-win32.whl", hash = "sha256:43fa96a3ca0d6b1812e01ced1044a003533c47f6ee8aca31724f78e93ccc089a"}, + {file = "PyYAML-6.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff"}, + {file = "PyYAML-6.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d"}, + {file = "PyYAML-6.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19"}, + {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e"}, + {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725"}, + {file = "PyYAML-6.0.2-cp39-cp39-win32.whl", hash = "sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631"}, + {file = "PyYAML-6.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8"}, + {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"}, +] + [[package]] name = "six" version = "1.17.0" @@ -608,4 +670,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "85301bd486e833b45f80659b9e767579e78a4a379766e8b8e1b68d4d93d0be6a" +content-hash = "6e712358e27083eae9c9b8b64ee258e14f589dbd0007c7d1615adffdd99b7e2a" diff --git a/pyproject.toml b/pyproject.toml index 4a2ec65..2fd2208 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ temporalio = "^1.8.0" # Standard library modules (e.g. asyncio, collections) don't need to be added # since they're built-in for Python 3.8+. 
ollama = "^0.4.5" +pyyaml = "^6.0.2" [tool.poetry.group.dev.dependencies] pytest = "^7.3" diff --git a/run_worker.py b/run_worker.py index 1d4397f..aab7ed8 100644 --- a/run_worker.py +++ b/run_worker.py @@ -20,7 +20,7 @@ async def main(): client, task_queue="ollama-task-queue", workflows=[EntityOllamaWorkflow], - activities=[activities.prompt_ollama], + activities=[activities.prompt_ollama, activities.parse_tool_data], activity_executor=activity_executor, ) await worker.run() diff --git a/send_message.py b/send_message.py index d20d67d..96d5b19 100644 --- a/send_message.py +++ b/send_message.py @@ -2,19 +2,63 @@ import asyncio import sys from temporalio.client import Client -from workflows import OllamaParams, EntityOllamaWorkflow + +# Import your dataclasses/types +from workflows import ( + OllamaParams, + EntityOllamaWorkflow, + ToolsData, + ToolDefinition, + ToolArgument, + CombinedInput, +) async def main(prompt): - # Create client connected to server at the given address + # Construct your tool definitions in code + search_flights_tool = ToolDefinition( + name="SearchFlights", + description="Search for flights from an origin to a destination within a date range", + arguments=[ + ToolArgument( + name="origin", + type="string", + description="Airport or city (infer airport code from city)", + ), + ToolArgument( + name="destination", + type="string", + description="Airport or city code for arrival (infer airport code from city)", + ), + ToolArgument( + name="dateFrom", + type="ISO8601", + description="Start of date range in human readable format", + ), + ToolArgument( + name="dateTo", + type="ISO8601", + description="End of date range in human readable format", + ), + ], + ) + + # Wrap it in ToolsData + tools_data = ToolsData(tools=[search_flights_tool]) + + combined_input = CombinedInput( + ollama_params=OllamaParams(None, None), tools_data=tools_data + ) + + # Create client connected to Temporal server client = await Client.connect("localhost:7233") workflow_id 
= "ollama-agent" - # Sends a signal to the workflow (and starts it if needed) + # Start or signal the workflow, passing OllamaParams and tools_data await client.start_workflow( EntityOllamaWorkflow.run, - OllamaParams(None, None), # or pass in custom summary/prompt_queue if desired + combined_input, # or pass custom summary/prompt_queue id=workflow_id, task_queue="ollama-task-queue", start_signal="user_prompt", diff --git a/workflows.py b/workflows.py index 0dc3427..3e984fb 100644 --- a/workflows.py +++ b/workflows.py @@ -1,3 +1,4 @@ +import yaml from collections import deque from dataclasses import dataclass from datetime import timedelta @@ -10,23 +11,59 @@ with workflow.unsafe.imports_passed_through(): from activities import OllamaActivities, OllamaPromptInput +@dataclass +class ToolArgument: + name: str + type: str + description: str + + +@dataclass +class ToolDefinition: + name: str + description: str + arguments: List[ToolArgument] + + +@dataclass +class ToolsData: + tools: List[ToolDefinition] + + @dataclass class OllamaParams: conversation_summary: Optional[str] = None prompt_queue: Optional[Deque[str]] = None +@dataclass +class CombinedInput: + ollama_params: OllamaParams + tools_data: ToolsData + + +from agent_prompt_generators import ( + generate_genai_prompt_from_tools_data, + generate_json_validation_prompt_from_tools_data, +) + + @workflow.defn class EntityOllamaWorkflow: def __init__(self) -> None: self.conversation_history: List[Tuple[str, str]] = [] self.prompt_queue: Deque[str] = deque() self.conversation_summary: Optional[str] = None - self.continue_as_new_per_turns: int = 6 + self.continue_as_new_per_turns: int = 250 self.chat_ended: bool = False + self.tool_data = None @workflow.run - async def run(self, params: OllamaParams) -> str: + async def run(self, combined_input: CombinedInput) -> str: + + params = combined_input.ollama_params + tools_data = combined_input.tools_data + if params and params.conversation_summary: 
self.conversation_history.append( ("conversation_summary", params.conversation_summary) @@ -49,15 +86,38 @@ class EntityOllamaWorkflow: self.conversation_history.append(("user", prompt)) # Build prompt + context - context_instructions, actual_prompt = self.prompt_with_history(prompt) + context_instructions = generate_genai_prompt_from_tools_data( + tools_data, self.format_history() + ) workflow.logger.info("Prompt: " + prompt) # Pass a single input object prompt_input = OllamaPromptInput( - prompt=actual_prompt, + prompt=prompt, context_instructions=context_instructions, ) + # Call activity with one argument + responsePrechecked = await workflow.execute_activity_method( + OllamaActivities.prompt_ollama, + prompt_input, + schedule_to_close_timeout=timedelta(seconds=20), + ) + + # Check if the response is valid JSON + json_validation_instructions = ( + generate_json_validation_prompt_from_tools_data( + tools_data, self.format_history(), responsePrechecked + ) + ) + workflow.logger.info("Prompt: " + prompt) + + # Pass a single input object + prompt_input = OllamaPromptInput( + prompt=responsePrechecked, + context_instructions=json_validation_instructions, + ) + # Call activity with one argument response = await workflow.execute_activity_method( OllamaActivities.prompt_ollama, @@ -68,6 +128,18 @@ class EntityOllamaWorkflow: workflow.logger.info(f"Ollama response: {response}") self.conversation_history.append(("response", response)) + # Call activity with one argument + tool_data = await workflow.execute_activity_method( + OllamaActivities.parse_tool_data, + response, + schedule_to_close_timeout=timedelta(seconds=1), + ) + + self.tool_data = tool_data + + if self.tool_data.get("next") == "confirm": + return self.tool_data + # Continue as new after X turns if len(self.conversation_history) >= self.continue_as_new_per_turns: # Summarize conversation @@ -90,9 +162,12 @@ class EntityOllamaWorkflow: workflow.continue_as_new( args=[ - OllamaParams( - 
conversation_summary=self.conversation_summary, - prompt_queue=self.prompt_queue, + CombinedInput( + ollama_params=OllamaParams( + conversation_summary=self.conversation_summary, + prompt_queue=self.prompt_queue, + ), + tools_data=tools_data, ) ] ) @@ -140,6 +215,10 @@ class EntityOllamaWorkflow: def get_summary_from_history(self) -> Optional[str]: return self.conversation_summary + @workflow.query + def get_tool_data(self) -> Optional[str]: + return self.tool_data + # Helper: generate text of the entire conversation so far def format_history(self) -> str: return " ".join(f"{text}" for _, text in self.conversation_history)