refactor, date context

This commit is contained in:
Steve Androulakis
2025-01-01 13:16:18 -08:00
parent 8115f0d2df
commit e7e8e7e658
17 changed files with 118 additions and 90 deletions

View File

@@ -11,6 +11,8 @@ This demo shows a multi-turn conversation with an AI agent running inside a Temp
## Running the example ## Running the example
From the /scripts directory:
1. Run the worker: `poetry run python run_worker.py` 1. Run the worker: `poetry run python run_worker.py`
2. In another terminal run the client with a prompt. 2. In another terminal run the client with a prompt.

0
activities/__init__.py Normal file
View File

View File

@@ -1,24 +1,26 @@
from dataclasses import dataclass from dataclasses import dataclass
from temporalio import activity from temporalio import activity
from temporalio.exceptions import ApplicationError
from ollama import chat, ChatResponse from ollama import chat, ChatResponse
import json import json
from temporalio.exceptions import ApplicationError
@dataclass @dataclass
class OllamaPromptInput: class ToolPromptInput:
prompt: str prompt: str
context_instructions: str context_instructions: str
class OllamaActivities: class ToolActivities:
@activity.defn @activity.defn
def prompt_ollama(self, input: OllamaPromptInput) -> str: def prompt_llm(self, input: ToolPromptInput) -> str:
model_name = "qwen2.5:14b" model_name = "qwen2.5:14b"
messages = [ messages = [
{ {
"role": "system", "role": "system",
"content": input.context_instructions, "content": input.context_instructions
+ ". The current date is "
+ get_current_date_human_readable(),
}, },
{ {
"role": "user", "role": "user",
@@ -41,3 +43,14 @@ class OllamaActivities:
raise ApplicationError(f"Invalid JSON: {e}") raise ApplicationError(f"Invalid JSON: {e}")
return data return data
def get_current_date_human_readable():
"""
Returns the current date in a human-readable format.
Example: Wednesday, January 1, 2025
"""
from datetime import datetime
return datetime.now().strftime("%A, %B %d, %Y")

0
models/__init__.py Normal file
View File

15
models/data_types.py Normal file
View File

@@ -0,0 +1,15 @@
from dataclasses import dataclass
from typing import Optional, Deque
from models.tool_definitions import ToolsData
@dataclass
class ToolWorkflowParams:
conversation_summary: Optional[str] = None
prompt_queue: Optional[Deque[str]] = None
@dataclass
class CombinedInput:
tool_params: ToolWorkflowParams
tools_data: ToolsData

View File

@@ -0,0 +1,21 @@
from dataclasses import dataclass
from typing import List
@dataclass
class ToolArgument:
name: str
type: str
description: str
@dataclass
class ToolDefinition:
name: str
description: str
arguments: List[ToolArgument]
@dataclass
class ToolsData:
tools: List[ToolDefinition]

0
prompts/__init__.py Normal file
View File

View File

@@ -1,4 +1,4 @@
from workflows import ToolsData from models.tool_definitions import ToolsData
def generate_genai_prompt_from_tools_data( def generate_genai_prompt_from_tools_data(
@@ -51,7 +51,7 @@ def generate_genai_prompt_from_tools_data(
'- If you have enough information and are confirming, use "next": "confirm". This is the final step once you have filled all args.' '- If you have enough information and are confirming, use "next": "confirm". This is the final step once you have filled all args.'
) )
prompt_lines.append( prompt_lines.append(
'- Example of a good answer: {"response": "It seems we have all the information needed to search for flights. You will be flying from <city> to <city> from <date> to <date>. Is this correct?", "args":{"origin": "Seattle", "destination": "San Francisco", "dateFrom": "2025-01-04", "dateTo": "2025-01-08"}, "next": "confirm", "tool": "<toolName>" }' '- Example of a good answer: {"response": "It seems we have all the information needed to search for flights. You will be flying from <city> to <city> from <date> to <date>. Is this correct?", "args":{"origin": "Seattle", "destination": "San Francisco", "dateDepart": "2025-01-04", "dateReturn": "2025-01-08"}, "next": "confirm", "tool": "<toolName>" }'
) )
prompt_lines.append("- Return valid JSON without special characters.") prompt_lines.append("- Return valid JSON without special characters.")
prompt_lines.append("") prompt_lines.append("")

View File

@@ -1,7 +1,7 @@
[tool.poetry] [tool.poetry]
name = "temporal-ollama-agent" name = "temporal-AI-agent"
version = "0.1.0" version = "0.1.0"
description = "Temporal Ollama Agent" description = "Temporal AI Agent"
license = "MIT" license = "MIT"
authors = ["Steve Androulakis <steve.androulakis@temporal.io>"] authors = ["Steve Androulakis <steve.androulakis@temporal.io>"]
readme = "README.md" readme = "README.md"

View File

@@ -1,8 +1,7 @@
import asyncio import asyncio
import sys
from temporalio.client import Client from temporalio.client import Client
from workflows import EntityOllamaWorkflow from workflows.tool_workflow import ToolWorkflow
async def main(): async def main():
@@ -11,10 +10,10 @@ async def main():
workflow_id = "ollama-agent" workflow_id = "ollama-agent"
handle = client.get_workflow_handle_for(EntityOllamaWorkflow.run, workflow_id) handle = client.get_workflow_handle_for(ToolWorkflow.run, workflow_id)
# Sends a signal to the workflow # Sends a signal to the workflow
await handle.signal(EntityOllamaWorkflow.end_chat) await handle.signal(ToolWorkflow.end_chat)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1,7 +1,7 @@
import asyncio import asyncio
from temporalio.client import Client from temporalio.client import Client
from workflows import EntityOllamaWorkflow from workflows import ToolWorkflow
async def main(): async def main():
@@ -12,7 +12,7 @@ async def main():
handle = client.get_workflow_handle(workflow_id) handle = client.get_workflow_handle(workflow_id)
# Queries the workflow for the conversation history # Queries the workflow for the conversation history
history = await handle.query(EntityOllamaWorkflow.get_conversation_history) history = await handle.query(ToolWorkflow.get_conversation_history)
print("Conversation History") print("Conversation History")
print( print(
@@ -20,7 +20,7 @@ async def main():
) )
# Queries the workflow for the conversation summary # Queries the workflow for the conversation summary
summary = await handle.query(EntityOllamaWorkflow.get_summary_from_history) summary = await handle.query(ToolWorkflow.get_summary_from_history)
if summary is not None: if summary is not None:
print("Conversation Summary:") print("Conversation Summary:")

View File

@@ -4,23 +4,24 @@ import logging
from temporalio.client import Client from temporalio.client import Client
from temporalio.worker import Worker from temporalio.worker import Worker
from workflows import EntityOllamaWorkflow
from activities import OllamaActivities from activities.tool_activities import ToolActivities
from workflows.tool_workflow import ToolWorkflow
from workflows.parent_workflow import ParentWorkflow
async def main(): async def main():
# Create client connected to server at the given address # Create client connected to server at the given address
client = await Client.connect("localhost:7233") client = await Client.connect("localhost:7233")
activities = OllamaActivities() activities = ToolActivities()
# Run the worker # Run the worker
with concurrent.futures.ThreadPoolExecutor(max_workers=100) as activity_executor: with concurrent.futures.ThreadPoolExecutor(max_workers=100) as activity_executor:
worker = Worker( worker = Worker(
client, client,
task_queue="ollama-task-queue", task_queue="ollama-task-queue",
workflows=[EntityOllamaWorkflow], workflows=[ToolWorkflow, ParentWorkflow],
activities=[activities.prompt_ollama, activities.parse_tool_data], activities=[activities.prompt_llm, activities.parse_tool_data],
activity_executor=activity_executor, activity_executor=activity_executor,
) )
await worker.run() await worker.run()

View File

@@ -3,22 +3,16 @@ import sys
from temporalio.client import Client from temporalio.client import Client
# Import your dataclasses/types from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams
from workflows import ( from models.tool_definitions import ToolDefinition, ToolArgument
OllamaParams, from workflows.tool_workflow import ToolWorkflow
EntityOllamaWorkflow,
ToolsData,
ToolDefinition,
ToolArgument,
CombinedInput,
)
async def main(prompt): async def main(prompt):
# Construct your tool definitions in code # Construct your tool definitions in code
search_flights_tool = ToolDefinition( search_flights_tool = ToolDefinition(
name="SearchFlights", name="SearchFlights",
description="Search for flights from an origin to a destination within a date range", description="Search for return flights from an origin to a destination within a date range",
arguments=[ arguments=[
ToolArgument( ToolArgument(
name="origin", name="origin",
@@ -31,12 +25,12 @@ async def main(prompt):
description="Airport or city code for arrival (infer airport code from city)", description="Airport or city code for arrival (infer airport code from city)",
), ),
ToolArgument( ToolArgument(
name="dateFrom", name="dateDepart",
type="ISO8601", type="ISO8601",
description="Start of date range in human readable format", description="Start of date range in human readable format",
), ),
ToolArgument( ToolArgument(
name="dateTo", name="dateReturn",
type="ISO8601", type="ISO8601",
description="End of date range in human readable format", description="End of date range in human readable format",
), ),
@@ -47,7 +41,7 @@ async def main(prompt):
tools_data = ToolsData(tools=[search_flights_tool]) tools_data = ToolsData(tools=[search_flights_tool])
combined_input = CombinedInput( combined_input = CombinedInput(
ollama_params=OllamaParams(None, None), tools_data=tools_data tool_params=ToolWorkflowParams(None, None), tools_data=tools_data
) )
# Create client connected to Temporal server # Create client connected to Temporal server
@@ -57,7 +51,7 @@ async def main(prompt):
# Start or signal the workflow, passing OllamaParams and tools_data # Start or signal the workflow, passing OllamaParams and tools_data
await client.start_workflow( await client.start_workflow(
EntityOllamaWorkflow.run, ToolWorkflow.run,
combined_input, # or pass custom summary/prompt_queue combined_input, # or pass custom summary/prompt_queue
id=workflow_id, id=workflow_id,
task_queue="ollama-task-queue", task_queue="ollama-task-queue",

0
workflows/__init__.py Normal file
View File

View File

@@ -0,0 +1,14 @@
from temporalio import workflow
from .tool_workflow import ToolWorkflow, CombinedInput, ToolWorkflowParams
@workflow.defn
class ParentWorkflow:
@workflow.run
async def run(self, some_input: dict) -> dict:
combined_input = CombinedInput(
tool_params=ToolWorkflowParams(None, None), tools_data=some_input
)
child = workflow.start_child_workflow(ToolWorkflow.run, combined_input)
result = await child
return result

View File

@@ -1,67 +1,36 @@
import yaml
from collections import deque from collections import deque
from dataclasses import dataclass
from datetime import timedelta from datetime import timedelta
from typing import Deque, List, Optional, Tuple from typing import Deque, List, Optional, Tuple
from temporalio import workflow from temporalio import workflow
from prompts.agent_prompt_generators import (
with workflow.unsafe.imports_passed_through():
# Import the updated OllamaActivities and the new dataclass
from activities import OllamaActivities, OllamaPromptInput
@dataclass
class ToolArgument:
name: str
type: str
description: str
@dataclass
class ToolDefinition:
name: str
description: str
arguments: List[ToolArgument]
@dataclass
class ToolsData:
tools: List[ToolDefinition]
@dataclass
class OllamaParams:
conversation_summary: Optional[str] = None
prompt_queue: Optional[Deque[str]] = None
@dataclass
class CombinedInput:
ollama_params: OllamaParams
tools_data: ToolsData
from agent_prompt_generators import (
generate_genai_prompt_from_tools_data, generate_genai_prompt_from_tools_data,
generate_json_validation_prompt_from_tools_data, generate_json_validation_prompt_from_tools_data,
) )
with workflow.unsafe.imports_passed_through():
from activities.tool_activities import ToolActivities, ToolPromptInput
from prompts.agent_prompt_generators import (
generate_genai_prompt_from_tools_data,
generate_json_validation_prompt_from_tools_data,
)
from models.data_types import CombinedInput, ToolWorkflowParams
@workflow.defn @workflow.defn
class EntityOllamaWorkflow: class ToolWorkflow:
def __init__(self) -> None: def __init__(self) -> None:
self.conversation_history: List[Tuple[str, str]] = [] self.conversation_history: List[Tuple[str, str]] = []
self.prompt_queue: Deque[str] = deque() self.prompt_queue: Deque[str] = deque()
self.conversation_summary: Optional[str] = None self.conversation_summary: Optional[str] = None
self.continue_as_new_per_turns: int = 250
self.chat_ended: bool = False self.chat_ended: bool = False
self.tool_data = None self.tool_data = None
self.max_turns_before_continue: int = 250
@workflow.run @workflow.run
async def run(self, combined_input: CombinedInput) -> str: async def run(self, combined_input: CombinedInput) -> str:
params = combined_input.ollama_params params = combined_input.tool_params
tools_data = combined_input.tools_data tools_data = combined_input.tools_data
if params and params.conversation_summary: if params and params.conversation_summary:
@@ -92,14 +61,14 @@ class EntityOllamaWorkflow:
workflow.logger.info("Prompt: " + prompt) workflow.logger.info("Prompt: " + prompt)
# Pass a single input object # Pass a single input object
prompt_input = OllamaPromptInput( prompt_input = ToolPromptInput(
prompt=prompt, prompt=prompt,
context_instructions=context_instructions, context_instructions=context_instructions,
) )
# Call activity with one argument # Call activity with one argument
responsePrechecked = await workflow.execute_activity_method( responsePrechecked = await workflow.execute_activity_method(
OllamaActivities.prompt_ollama, ToolActivities.prompt_llm,
prompt_input, prompt_input,
schedule_to_close_timeout=timedelta(seconds=20), schedule_to_close_timeout=timedelta(seconds=20),
) )
@@ -113,14 +82,14 @@ class EntityOllamaWorkflow:
workflow.logger.info("Prompt: " + prompt) workflow.logger.info("Prompt: " + prompt)
# Pass a single input object # Pass a single input object
prompt_input = OllamaPromptInput( prompt_input = ToolPromptInput(
prompt=responsePrechecked, prompt=responsePrechecked,
context_instructions=json_validation_instructions, context_instructions=json_validation_instructions,
) )
# Call activity with one argument # Call activity with one argument
response = await workflow.execute_activity_method( response = await workflow.execute_activity_method(
OllamaActivities.prompt_ollama, ToolActivities.prompt_llm,
prompt_input, prompt_input,
schedule_to_close_timeout=timedelta(seconds=20), schedule_to_close_timeout=timedelta(seconds=20),
) )
@@ -130,7 +99,7 @@ class EntityOllamaWorkflow:
# Call activity with one argument # Call activity with one argument
tool_data = await workflow.execute_activity_method( tool_data = await workflow.execute_activity_method(
OllamaActivities.parse_tool_data, ToolActivities.parse_tool_data,
response, response,
schedule_to_close_timeout=timedelta(seconds=1), schedule_to_close_timeout=timedelta(seconds=1),
) )
@@ -141,29 +110,29 @@ class EntityOllamaWorkflow:
return self.tool_data return self.tool_data
# Continue as new after X turns # Continue as new after X turns
if len(self.conversation_history) >= self.continue_as_new_per_turns: if len(self.conversation_history) >= self.max_turns_before_continue:
# Summarize conversation # Summarize conversation
summary_context, summary_prompt = self.prompt_summary_with_history() summary_context, summary_prompt = self.prompt_summary_with_history()
summary_input = OllamaPromptInput( summary_input = ToolPromptInput(
prompt=summary_prompt, prompt=summary_prompt,
context_instructions=summary_context, context_instructions=summary_context,
) )
self.conversation_summary = await workflow.start_activity_method( self.conversation_summary = await workflow.start_activity_method(
OllamaActivities.prompt_ollama, ToolActivities.prompt_llm,
summary_input, summary_input,
schedule_to_close_timeout=timedelta(seconds=20), schedule_to_close_timeout=timedelta(seconds=20),
) )
workflow.logger.info( workflow.logger.info(
"Continuing as new after %i turns." "Continuing as new after %i turns."
% self.continue_as_new_per_turns, % self.max_turns_before_continue,
) )
workflow.continue_as_new( workflow.continue_as_new(
args=[ args=[
CombinedInput( CombinedInput(
ollama_params=OllamaParams( tool_params=ToolWorkflowParams(
conversation_summary=self.conversation_summary, conversation_summary=self.conversation_summary,
prompt_queue=self.prompt_queue, prompt_queue=self.prompt_queue,
), ),
@@ -179,13 +148,13 @@ class EntityOllamaWorkflow:
if len(self.conversation_history) > 1: if len(self.conversation_history) > 1:
# Summarize conversation # Summarize conversation
summary_context, summary_prompt = self.prompt_summary_with_history() summary_context, summary_prompt = self.prompt_summary_with_history()
summary_input = OllamaPromptInput( summary_input = ToolPromptInput(
prompt=summary_prompt, prompt=summary_prompt,
context_instructions=summary_context, context_instructions=summary_context,
) )
self.conversation_summary = await workflow.start_activity_method( self.conversation_summary = await workflow.start_activity_method(
OllamaActivities.prompt_ollama, ToolActivities.prompt_llm,
summary_input, summary_input,
schedule_to_close_timeout=timedelta(seconds=20), schedule_to_close_timeout=timedelta(seconds=20),
) )