From e7e8e7e65859770a16dab65cc1b33a7ae9115b12 Mon Sep 17 00:00:00 2001 From: Steve Androulakis Date: Wed, 1 Jan 2025 13:16:18 -0800 Subject: [PATCH] refactor, date context --- README.md | 2 + activities/__init__.py | 0 .../tool_activities.py | 23 ++++-- models/__init__.py | 0 models/data_types.py | 15 ++++ models/tool_definitions.py | 21 +++++ prompts/__init__.py | 0 .../agent_prompt_generators.py | 4 +- pyproject.toml | 4 +- end_chat.py => scripts/end_chat.py | 7 +- get_history.py => scripts/get_history.py | 6 +- run_ollama.py => scripts/run_ollama.py | 0 run_worker.py => scripts/run_worker.py | 11 +-- send_message.py => scripts/send_message.py | 22 ++---- workflows/__init__.py | 0 workflows/parent_workflow.py | 14 ++++ workflows.py => workflows/tool_workflow.py | 79 ++++++------------- 17 files changed, 118 insertions(+), 90 deletions(-) create mode 100644 activities/__init__.py rename activities.py => activities/tool_activities.py (66%) create mode 100644 models/__init__.py create mode 100644 models/data_types.py create mode 100644 models/tool_definitions.py create mode 100644 prompts/__init__.py rename agent_prompt_generators.py => prompts/agent_prompt_generators.py (97%) rename end_chat.py => scripts/end_chat.py (64%) rename get_history.py => scripts/get_history.py (76%) rename run_ollama.py => scripts/run_ollama.py (100%) rename run_worker.py => scripts/run_worker.py (69%) rename send_message.py => scripts/send_message.py (79%) create mode 100644 workflows/__init__.py create mode 100644 workflows/parent_workflow.py rename workflows.py => workflows/tool_workflow.py (82%) diff --git a/README.md b/README.md index 34b2700..c101028 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,8 @@ This demo shows a multi-turn conversation with an AI agent running inside a Temp ## Running the example +From the /scripts directory: + 1. Run the worker: `poetry run python run_worker.py` 2. In another terminal run the client with a prompt. 
diff --git a/activities/__init__.py b/activities/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/activities.py b/activities/tool_activities.py similarity index 66% rename from activities.py rename to activities/tool_activities.py index 8d1bd87..7952066 100644 --- a/activities.py +++ b/activities/tool_activities.py @@ -1,24 +1,26 @@ from dataclasses import dataclass from temporalio import activity +from temporalio.exceptions import ApplicationError from ollama import chat, ChatResponse import json -from temporalio.exceptions import ApplicationError @dataclass -class OllamaPromptInput: +class ToolPromptInput: prompt: str context_instructions: str -class OllamaActivities: +class ToolActivities: @activity.defn - def prompt_ollama(self, input: OllamaPromptInput) -> str: + def prompt_llm(self, input: ToolPromptInput) -> str: model_name = "qwen2.5:14b" messages = [ { "role": "system", - "content": input.context_instructions, + "content": input.context_instructions + + ". The current date is " + + get_current_date_human_readable(), }, { "role": "user", @@ -41,3 +43,14 @@ class OllamaActivities: raise ApplicationError(f"Invalid JSON: {e}") return data + + +def get_current_date_human_readable(): + """ + Returns the current date in a human-readable format. 
+ + Example: Wednesday, January 1, 2025 + """ + from datetime import datetime + + return datetime.now().strftime("%A, %B %d, %Y") diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/data_types.py b/models/data_types.py new file mode 100644 index 0000000..4b81577 --- /dev/null +++ b/models/data_types.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass +from typing import Optional, Deque +from models.tool_definitions import ToolsData + + +@dataclass +class ToolWorkflowParams: + conversation_summary: Optional[str] = None + prompt_queue: Optional[Deque[str]] = None + + +@dataclass +class CombinedInput: + tool_params: ToolWorkflowParams + tools_data: ToolsData diff --git a/models/tool_definitions.py b/models/tool_definitions.py new file mode 100644 index 0000000..79fbfa6 --- /dev/null +++ b/models/tool_definitions.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass +from typing import List + + +@dataclass +class ToolArgument: + name: str + type: str + description: str + + +@dataclass +class ToolDefinition: + name: str + description: str + arguments: List[ToolArgument] + + +@dataclass +class ToolsData: + tools: List[ToolDefinition] diff --git a/prompts/__init__.py b/prompts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/agent_prompt_generators.py b/prompts/agent_prompt_generators.py similarity index 97% rename from agent_prompt_generators.py rename to prompts/agent_prompt_generators.py index c2e8c6d..013cc85 100644 --- a/agent_prompt_generators.py +++ b/prompts/agent_prompt_generators.py @@ -1,4 +1,4 @@ -from workflows import ToolsData +from models.tool_definitions import ToolsData def generate_genai_prompt_from_tools_data( @@ -51,7 +51,7 @@ def generate_genai_prompt_from_tools_data( '- If you have enough information and are confirming, use "next": "confirm". This is the final step once you have filled all args.' 
) prompt_lines.append( - '- Example of a good answer: {"response": "It seems we have all the information needed to search for flights. You will be flying from to from to . Is this correct?", "args":{"origin": "Seattle", "destination": "San Francisco", "dateFrom": "2025-01-04", "dateTo": "2025-01-08"}, "next": "confirm", "tool": "" }' + '- Example of a good answer: {"response": "It seems we have all the information needed to search for flights. You will be flying from to from to . Is this correct?", "args":{"origin": "Seattle", "destination": "San Francisco", "dateDepart": "2025-01-04", "dateReturn": "2025-01-08"}, "next": "confirm", "tool": "" }' ) prompt_lines.append("- Return valid JSON without special characters.") prompt_lines.append("") diff --git a/pyproject.toml b/pyproject.toml index 2fd2208..5ccd568 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] -name = "temporal-ollama-agent" +name = "temporal-AI-agent" version = "0.1.0" -description = "Temporal Ollama Agent" +description = "Temporal AI Agent" license = "MIT" authors = ["Steve Androulakis "] readme = "README.md" diff --git a/end_chat.py b/scripts/end_chat.py similarity index 64% rename from end_chat.py rename to scripts/end_chat.py index b9554a0..8913f16 100644 --- a/end_chat.py +++ b/scripts/end_chat.py @@ -1,8 +1,7 @@ import asyncio -import sys from temporalio.client import Client -from workflows import EntityOllamaWorkflow +from workflows.tool_workflow import ToolWorkflow async def main(): @@ -11,10 +10,10 @@ async def main(): workflow_id = "ollama-agent" - handle = client.get_workflow_handle_for(EntityOllamaWorkflow.run, workflow_id) + handle = client.get_workflow_handle_for(ToolWorkflow.run, workflow_id) # Sends a signal to the workflow - await handle.signal(EntityOllamaWorkflow.end_chat) + await handle.signal(ToolWorkflow.end_chat) if __name__ == "__main__": diff --git a/get_history.py b/scripts/get_history.py similarity index 76% rename from get_history.py rename to 
scripts/get_history.py index 0fb07c6..a09a868 100644 --- a/get_history.py +++ b/scripts/get_history.py @@ -1,7 +1,7 @@ import asyncio from temporalio.client import Client -from workflows import EntityOllamaWorkflow +from workflows.tool_workflow import ToolWorkflow async def main(): @@ -12,7 +12,7 @@ async def main(): handle = client.get_workflow_handle(workflow_id) # Queries the workflow for the conversation history - history = await handle.query(EntityOllamaWorkflow.get_conversation_history) + history = await handle.query(ToolWorkflow.get_conversation_history) print("Conversation History") print( @@ -20,7 +20,7 @@ ) # Queries the workflow for the conversation summary - summary = await handle.query(EntityOllamaWorkflow.get_summary_from_history) + summary = await handle.query(ToolWorkflow.get_summary_from_history) if summary is not None: print("Conversation Summary:") diff --git a/run_ollama.py b/scripts/run_ollama.py similarity index 100% rename from run_ollama.py rename to scripts/run_ollama.py diff --git a/run_worker.py b/scripts/run_worker.py similarity index 69% rename from run_worker.py rename to scripts/run_worker.py index aab7ed8..888aa7b 100644 --- a/run_worker.py +++ b/scripts/run_worker.py @@ -4,23 +4,24 @@ import logging from temporalio.client import Client from temporalio.worker import Worker -from workflows import EntityOllamaWorkflow -from activities import OllamaActivities +from activities.tool_activities import ToolActivities +from workflows.tool_workflow import ToolWorkflow +from workflows.parent_workflow import ParentWorkflow async def main(): # Create client connected to server at the given address client = await Client.connect("localhost:7233") - activities = OllamaActivities() + activities = ToolActivities() # Run the worker with concurrent.futures.ThreadPoolExecutor(max_workers=100) as activity_executor: worker = Worker( client, task_queue="ollama-task-queue", - workflows=[EntityOllamaWorkflow], - activities=[activities.prompt_ollama, 
activities.parse_tool_data], + workflows=[ToolWorkflow, ParentWorkflow], + activities=[activities.prompt_llm, activities.parse_tool_data], activity_executor=activity_executor, ) await worker.run() diff --git a/send_message.py b/scripts/send_message.py similarity index 79% rename from send_message.py rename to scripts/send_message.py index 96d5b19..10f7905 100644 --- a/send_message.py +++ b/scripts/send_message.py @@ -3,22 +3,16 @@ import sys from temporalio.client import Client -# Import your dataclasses/types -from workflows import ( - OllamaParams, - EntityOllamaWorkflow, - ToolsData, - ToolDefinition, - ToolArgument, - CombinedInput, -) +from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams +from models.tool_definitions import ToolDefinition, ToolArgument +from workflows.tool_workflow import ToolWorkflow async def main(prompt): # Construct your tool definitions in code search_flights_tool = ToolDefinition( name="SearchFlights", - description="Search for flights from an origin to a destination within a date range", + description="Search for return flights from an origin to a destination within a date range", arguments=[ ToolArgument( name="origin", @@ -31,12 +25,12 @@ async def main(prompt): description="Airport or city code for arrival (infer airport code from city)", ), ToolArgument( - name="dateFrom", + name="dateDepart", type="ISO8601", description="Start of date range in human readable format", ), ToolArgument( - name="dateTo", + name="dateReturn", type="ISO8601", description="End of date range in human readable format", ), @@ -47,7 +41,7 @@ async def main(prompt): tools_data = ToolsData(tools=[search_flights_tool]) combined_input = CombinedInput( - ollama_params=OllamaParams(None, None), tools_data=tools_data + tool_params=ToolWorkflowParams(None, None), tools_data=tools_data ) # Create client connected to Temporal server @@ -57,7 +51,7 @@ async def main(prompt): # Start or signal the workflow, passing OllamaParams and tools_data await 
client.start_workflow( - EntityOllamaWorkflow.run, + ToolWorkflow.run, combined_input, # or pass custom summary/prompt_queue id=workflow_id, task_queue="ollama-task-queue", diff --git a/workflows/__init__.py b/workflows/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/workflows/parent_workflow.py b/workflows/parent_workflow.py new file mode 100644 index 0000000..fdddcf3 --- /dev/null +++ b/workflows/parent_workflow.py @@ -0,0 +1,14 @@ +from temporalio import workflow +from .tool_workflow import ToolWorkflow, CombinedInput, ToolWorkflowParams + + +@workflow.defn +class ParentWorkflow: + @workflow.run + async def run(self, some_input: dict) -> dict: + combined_input = CombinedInput( + tool_params=ToolWorkflowParams(None, None), tools_data=some_input + ) + child = await workflow.start_child_workflow(ToolWorkflow.run, combined_input) + result = await child + return result diff --git a/workflows.py b/workflows/tool_workflow.py similarity index 82% rename from workflows.py rename to workflows/tool_workflow.py index 3e984fb..94cd9f8 100644 --- a/workflows.py +++ b/workflows/tool_workflow.py @@ -1,67 +1,36 @@ -import yaml from collections import deque -from dataclasses import dataclass from datetime import timedelta from typing import Deque, List, Optional, Tuple from temporalio import workflow - -with workflow.unsafe.imports_passed_through(): - # Import the updated OllamaActivities and the new dataclass - from activities import OllamaActivities, OllamaPromptInput - - -@dataclass -class ToolArgument: - name: str - type: str - description: str - - -@dataclass -class ToolDefinition: - name: str - description: str - arguments: List[ToolArgument] - - -@dataclass -class ToolsData: - tools: List[ToolDefinition] - - -@dataclass -class OllamaParams: - conversation_summary: Optional[str] = None - prompt_queue: Optional[Deque[str]] = None - - -@dataclass -class CombinedInput: - ollama_params: OllamaParams - tools_data: ToolsData - - -from agent_prompt_generators import 
( +from prompts.agent_prompt_generators import ( generate_genai_prompt_from_tools_data, generate_json_validation_prompt_from_tools_data, ) +with workflow.unsafe.imports_passed_through(): + from activities.tool_activities import ToolActivities, ToolPromptInput + from prompts.agent_prompt_generators import ( + generate_genai_prompt_from_tools_data, + generate_json_validation_prompt_from_tools_data, + ) + from models.data_types import CombinedInput, ToolWorkflowParams + @workflow.defn -class EntityOllamaWorkflow: +class ToolWorkflow: def __init__(self) -> None: self.conversation_history: List[Tuple[str, str]] = [] self.prompt_queue: Deque[str] = deque() self.conversation_summary: Optional[str] = None - self.continue_as_new_per_turns: int = 250 self.chat_ended: bool = False self.tool_data = None + self.max_turns_before_continue: int = 250 @workflow.run async def run(self, combined_input: CombinedInput) -> str: - params = combined_input.ollama_params + params = combined_input.tool_params tools_data = combined_input.tools_data if params and params.conversation_summary: @@ -92,14 +61,14 @@ class EntityOllamaWorkflow: workflow.logger.info("Prompt: " + prompt) # Pass a single input object - prompt_input = OllamaPromptInput( + prompt_input = ToolPromptInput( prompt=prompt, context_instructions=context_instructions, ) # Call activity with one argument responsePrechecked = await workflow.execute_activity_method( - OllamaActivities.prompt_ollama, + ToolActivities.prompt_llm, prompt_input, schedule_to_close_timeout=timedelta(seconds=20), ) @@ -113,14 +82,14 @@ class EntityOllamaWorkflow: workflow.logger.info("Prompt: " + prompt) # Pass a single input object - prompt_input = OllamaPromptInput( + prompt_input = ToolPromptInput( prompt=responsePrechecked, context_instructions=json_validation_instructions, ) # Call activity with one argument response = await workflow.execute_activity_method( - OllamaActivities.prompt_ollama, + ToolActivities.prompt_llm, prompt_input, 
schedule_to_close_timeout=timedelta(seconds=20), ) @@ -130,7 +99,7 @@ class EntityOllamaWorkflow: # Call activity with one argument tool_data = await workflow.execute_activity_method( - OllamaActivities.parse_tool_data, + ToolActivities.parse_tool_data, response, schedule_to_close_timeout=timedelta(seconds=1), ) @@ -141,29 +110,29 @@ class EntityOllamaWorkflow: return self.tool_data # Continue as new after X turns - if len(self.conversation_history) >= self.continue_as_new_per_turns: + if len(self.conversation_history) >= self.max_turns_before_continue: # Summarize conversation summary_context, summary_prompt = self.prompt_summary_with_history() - summary_input = OllamaPromptInput( + summary_input = ToolPromptInput( prompt=summary_prompt, context_instructions=summary_context, ) self.conversation_summary = await workflow.start_activity_method( - OllamaActivities.prompt_ollama, + ToolActivities.prompt_llm, summary_input, schedule_to_close_timeout=timedelta(seconds=20), ) workflow.logger.info( "Continuing as new after %i turns." - % self.continue_as_new_per_turns, + % self.max_turns_before_continue, ) workflow.continue_as_new( args=[ CombinedInput( - ollama_params=OllamaParams( + tool_params=ToolWorkflowParams( conversation_summary=self.conversation_summary, prompt_queue=self.prompt_queue, ), @@ -179,13 +148,13 @@ class EntityOllamaWorkflow: if len(self.conversation_history) > 1: # Summarize conversation summary_context, summary_prompt = self.prompt_summary_with_history() - summary_input = OllamaPromptInput( + summary_input = ToolPromptInput( prompt=summary_prompt, context_instructions=summary_context, ) self.conversation_summary = await workflow.start_activity_method( - OllamaActivities.prompt_ollama, + ToolActivities.prompt_llm, summary_input, schedule_to_close_timeout=timedelta(seconds=20), )