Refactor: rename Ollama-specific classes/activities to generic tool/LLM names, move dataclasses into models/, and add the current date to the LLM prompt context

This commit is contained in:
Steve Androulakis
2025-01-01 13:16:18 -08:00
parent 8115f0d2df
commit e7e8e7e658
17 changed files with 118 additions and 90 deletions

View File

@@ -11,6 +11,8 @@ This demo shows a multi-turn conversation with an AI agent running inside a Temp
## Running the example
From the /scripts directory:
1. Run the worker: `poetry run python run_worker.py`
2. In another terminal, run the client with a prompt.

0
activities/__init__.py Normal file
View File

View File

@@ -1,24 +1,26 @@
from dataclasses import dataclass
from temporalio import activity
from temporalio.exceptions import ApplicationError
from ollama import chat, ChatResponse
import json
from temporalio.exceptions import ApplicationError
@dataclass
class OllamaPromptInput:
class ToolPromptInput:
prompt: str
context_instructions: str
class OllamaActivities:
class ToolActivities:
@activity.defn
def prompt_ollama(self, input: OllamaPromptInput) -> str:
def prompt_llm(self, input: ToolPromptInput) -> str:
model_name = "qwen2.5:14b"
messages = [
{
"role": "system",
"content": input.context_instructions,
"content": input.context_instructions
+ ". The current date is "
+ get_current_date_human_readable(),
},
{
"role": "user",
@@ -41,3 +43,14 @@ class OllamaActivities:
raise ApplicationError(f"Invalid JSON: {e}")
return data
def get_current_date_human_readable():
    """Return today's date as a human-readable string.

    The result looks like ``Wednesday, January 01, 2025`` (weekday name,
    full month name, zero-padded day, four-digit year).

    NOTE(review): uses the naive local time of the worker host — confirm
    that local time (rather than UTC) is the intended reference.
    """
    from datetime import datetime

    # format() delegates to datetime.__format__, which applies the same
    # strftime-style directives as datetime.strftime().
    return format(datetime.now(), "%A, %B %d, %Y")

0
models/__init__.py Normal file
View File

15
models/data_types.py Normal file
View File

@@ -0,0 +1,15 @@
from dataclasses import dataclass
from typing import Optional, Deque
from models.tool_definitions import ToolsData
@dataclass
class ToolWorkflowParams:
    """Resumable state handed to ToolWorkflow, e.g. across continue-as-new.

    Both fields default to None so a fresh conversation can be started
    with ToolWorkflowParams(None, None).
    """

    # Summary of the conversation so far; None when no turns have happened yet.
    conversation_summary: Optional[str] = None
    # Pending user prompts still to be processed; None when there is no backlog.
    prompt_queue: Optional[Deque[str]] = None
@dataclass
class CombinedInput:
    """Single input payload for ToolWorkflow.run.

    Bundles the workflow's resumable conversation state together with the
    static tool catalog so the workflow takes one argument.
    """

    # Conversation state (summary + queued prompts).
    tool_params: ToolWorkflowParams
    # Definitions of the tools the agent is allowed to invoke.
    tools_data: ToolsData

View File

@@ -0,0 +1,21 @@
from dataclasses import dataclass
from typing import List
@dataclass
class ToolArgument:
    """One named argument of a tool, as described to the LLM."""

    # Argument name as it appears in the generated JSON args object.
    name: str
    # Free-form type label shown to the LLM (e.g. "ISO8601"), not enforced here.
    type: str
    # Human-readable guidance the LLM uses to fill in the value.
    description: str
@dataclass
class ToolDefinition:
    """A single tool the agent may call: its name, purpose, and arguments."""

    # Tool identifier the LLM returns in its "tool" field.
    name: str
    # Description used in prompt generation to tell the LLM what the tool does.
    description: str
    # Ordered list of arguments the LLM must collect before confirming.
    arguments: List[ToolArgument]
@dataclass
class ToolsData:
    """Container for the full catalog of tools passed into the workflow."""

    # All tools available to the agent for this conversation.
    tools: List[ToolDefinition]

0
prompts/__init__.py Normal file
View File

View File

@@ -1,4 +1,4 @@
from workflows import ToolsData
from models.tool_definitions import ToolsData
def generate_genai_prompt_from_tools_data(
@@ -51,7 +51,7 @@ def generate_genai_prompt_from_tools_data(
'- If you have enough information and are confirming, use "next": "confirm". This is the final step once you have filled all args.'
)
prompt_lines.append(
'- Example of a good answer: {"response": "It seems we have all the information needed to search for flights. You will be flying from <city> to <city> from <date> to <date>. Is this correct?", "args":{"origin": "Seattle", "destination": "San Francisco", "dateFrom": "2025-01-04", "dateTo": "2025-01-08"}, "next": "confirm", "tool": "<toolName>" }'
'- Example of a good answer: {"response": "It seems we have all the information needed to search for flights. You will be flying from <city> to <city> from <date> to <date>. Is this correct?", "args":{"origin": "Seattle", "destination": "San Francisco", "dateDepart": "2025-01-04", "dateReturn": "2025-01-08"}, "next": "confirm", "tool": "<toolName>" }'
)
prompt_lines.append("- Return valid JSON without special characters.")
prompt_lines.append("")

View File

@@ -1,7 +1,7 @@
[tool.poetry]
name = "temporal-ollama-agent"
name = "temporal-AI-agent"
version = "0.1.0"
description = "Temporal Ollama Agent"
description = "Temporal AI Agent"
license = "MIT"
authors = ["Steve Androulakis <steve.androulakis@temporal.io>"]
readme = "README.md"

View File

@@ -1,8 +1,7 @@
import asyncio
import sys
from temporalio.client import Client
from workflows import EntityOllamaWorkflow
from workflows.tool_workflow import ToolWorkflow
async def main():
@@ -11,10 +10,10 @@ async def main():
workflow_id = "ollama-agent"
handle = client.get_workflow_handle_for(EntityOllamaWorkflow.run, workflow_id)
handle = client.get_workflow_handle_for(ToolWorkflow.run, workflow_id)
# Sends a signal to the workflow
await handle.signal(EntityOllamaWorkflow.end_chat)
await handle.signal(ToolWorkflow.end_chat)
if __name__ == "__main__":

View File

@@ -1,7 +1,7 @@
import asyncio
from temporalio.client import Client
from workflows import EntityOllamaWorkflow
from workflows import ToolWorkflow
async def main():
@@ -12,7 +12,7 @@ async def main():
handle = client.get_workflow_handle(workflow_id)
# Queries the workflow for the conversation history
history = await handle.query(EntityOllamaWorkflow.get_conversation_history)
history = await handle.query(ToolWorkflow.get_conversation_history)
print("Conversation History")
print(
@@ -20,7 +20,7 @@ async def main():
)
# Queries the workflow for the conversation summary
summary = await handle.query(EntityOllamaWorkflow.get_summary_from_history)
summary = await handle.query(ToolWorkflow.get_summary_from_history)
if summary is not None:
print("Conversation Summary:")

View File

@@ -4,23 +4,24 @@ import logging
from temporalio.client import Client
from temporalio.worker import Worker
from workflows import EntityOllamaWorkflow
from activities import OllamaActivities
from activities.tool_activities import ToolActivities
from workflows.tool_workflow import ToolWorkflow
from workflows.parent_workflow import ParentWorkflow
async def main():
# Create client connected to server at the given address
client = await Client.connect("localhost:7233")
activities = OllamaActivities()
activities = ToolActivities()
# Run the worker
with concurrent.futures.ThreadPoolExecutor(max_workers=100) as activity_executor:
worker = Worker(
client,
task_queue="ollama-task-queue",
workflows=[EntityOllamaWorkflow],
activities=[activities.prompt_ollama, activities.parse_tool_data],
workflows=[ToolWorkflow, ParentWorkflow],
activities=[activities.prompt_llm, activities.parse_tool_data],
activity_executor=activity_executor,
)
await worker.run()

View File

@@ -3,22 +3,16 @@ import sys
from temporalio.client import Client
# Import your dataclasses/types
from workflows import (
OllamaParams,
EntityOllamaWorkflow,
ToolsData,
ToolDefinition,
ToolArgument,
CombinedInput,
)
from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams
from models.tool_definitions import ToolDefinition, ToolArgument
from workflows.tool_workflow import ToolWorkflow
async def main(prompt):
# Construct your tool definitions in code
search_flights_tool = ToolDefinition(
name="SearchFlights",
description="Search for flights from an origin to a destination within a date range",
description="Search for return flights from an origin to a destination within a date range",
arguments=[
ToolArgument(
name="origin",
@@ -31,12 +25,12 @@ async def main(prompt):
description="Airport or city code for arrival (infer airport code from city)",
),
ToolArgument(
name="dateFrom",
name="dateDepart",
type="ISO8601",
description="Start of date range in human readable format",
),
ToolArgument(
name="dateTo",
name="dateReturn",
type="ISO8601",
description="End of date range in human readable format",
),
@@ -47,7 +41,7 @@ async def main(prompt):
tools_data = ToolsData(tools=[search_flights_tool])
combined_input = CombinedInput(
ollama_params=OllamaParams(None, None), tools_data=tools_data
tool_params=ToolWorkflowParams(None, None), tools_data=tools_data
)
# Create client connected to Temporal server
@@ -57,7 +51,7 @@ async def main(prompt):
# Start or signal the workflow, passing OllamaParams and tools_data
await client.start_workflow(
EntityOllamaWorkflow.run,
ToolWorkflow.run,
combined_input, # or pass custom summary/prompt_queue
id=workflow_id,
task_queue="ollama-task-queue",

0
workflows/__init__.py Normal file
View File

View File

@@ -0,0 +1,14 @@
from temporalio import workflow
from .tool_workflow import ToolWorkflow, CombinedInput, ToolWorkflowParams
@workflow.defn
class ParentWorkflow:
    """Thin wrapper workflow that delegates its input to a ToolWorkflow child.

    Wraps the raw input in a CombinedInput with empty conversation state and
    returns whatever the child workflow returns.
    """

    @workflow.run
    async def run(self, some_input: dict) -> dict:
        # NOTE(review): some_input is forwarded verbatim as tools_data, which
        # elsewhere is a ToolsData dataclass — confirm callers pass a
        # ToolsData-shaped payload here.
        combined_input = CombinedInput(
            tool_params=ToolWorkflowParams(None, None),
            tools_data=some_input,
        )
        # Bug fix: the original called workflow.start_child_workflow(...)
        # without awaiting it, then awaited the resulting coroutine — that
        # yields the child *handle*, not the child's result.
        # execute_child_workflow starts the child and awaits its result.
        return await workflow.execute_child_workflow(
            ToolWorkflow.run, combined_input
        )

View File

@@ -1,67 +1,36 @@
import yaml
from collections import deque
from dataclasses import dataclass
from datetime import timedelta
from typing import Deque, List, Optional, Tuple
from temporalio import workflow
with workflow.unsafe.imports_passed_through():
# Import the updated OllamaActivities and the new dataclass
from activities import OllamaActivities, OllamaPromptInput
@dataclass
class ToolArgument:
name: str
type: str
description: str
@dataclass
class ToolDefinition:
name: str
description: str
arguments: List[ToolArgument]
@dataclass
class ToolsData:
tools: List[ToolDefinition]
@dataclass
class OllamaParams:
conversation_summary: Optional[str] = None
prompt_queue: Optional[Deque[str]] = None
@dataclass
class CombinedInput:
ollama_params: OllamaParams
tools_data: ToolsData
from agent_prompt_generators import (
from prompts.agent_prompt_generators import (
generate_genai_prompt_from_tools_data,
generate_json_validation_prompt_from_tools_data,
)
with workflow.unsafe.imports_passed_through():
from activities.tool_activities import ToolActivities, ToolPromptInput
from prompts.agent_prompt_generators import (
generate_genai_prompt_from_tools_data,
generate_json_validation_prompt_from_tools_data,
)
from models.data_types import CombinedInput, ToolWorkflowParams
@workflow.defn
class EntityOllamaWorkflow:
class ToolWorkflow:
def __init__(self) -> None:
self.conversation_history: List[Tuple[str, str]] = []
self.prompt_queue: Deque[str] = deque()
self.conversation_summary: Optional[str] = None
self.continue_as_new_per_turns: int = 250
self.chat_ended: bool = False
self.tool_data = None
self.max_turns_before_continue: int = 250
@workflow.run
async def run(self, combined_input: CombinedInput) -> str:
params = combined_input.ollama_params
params = combined_input.tool_params
tools_data = combined_input.tools_data
if params and params.conversation_summary:
@@ -92,14 +61,14 @@ class EntityOllamaWorkflow:
workflow.logger.info("Prompt: " + prompt)
# Pass a single input object
prompt_input = OllamaPromptInput(
prompt_input = ToolPromptInput(
prompt=prompt,
context_instructions=context_instructions,
)
# Call activity with one argument
responsePrechecked = await workflow.execute_activity_method(
OllamaActivities.prompt_ollama,
ToolActivities.prompt_llm,
prompt_input,
schedule_to_close_timeout=timedelta(seconds=20),
)
@@ -113,14 +82,14 @@ class EntityOllamaWorkflow:
workflow.logger.info("Prompt: " + prompt)
# Pass a single input object
prompt_input = OllamaPromptInput(
prompt_input = ToolPromptInput(
prompt=responsePrechecked,
context_instructions=json_validation_instructions,
)
# Call activity with one argument
response = await workflow.execute_activity_method(
OllamaActivities.prompt_ollama,
ToolActivities.prompt_llm,
prompt_input,
schedule_to_close_timeout=timedelta(seconds=20),
)
@@ -130,7 +99,7 @@ class EntityOllamaWorkflow:
# Call activity with one argument
tool_data = await workflow.execute_activity_method(
OllamaActivities.parse_tool_data,
ToolActivities.parse_tool_data,
response,
schedule_to_close_timeout=timedelta(seconds=1),
)
@@ -141,29 +110,29 @@ class EntityOllamaWorkflow:
return self.tool_data
# Continue as new after X turns
if len(self.conversation_history) >= self.continue_as_new_per_turns:
if len(self.conversation_history) >= self.max_turns_before_continue:
# Summarize conversation
summary_context, summary_prompt = self.prompt_summary_with_history()
summary_input = OllamaPromptInput(
summary_input = ToolPromptInput(
prompt=summary_prompt,
context_instructions=summary_context,
)
self.conversation_summary = await workflow.start_activity_method(
OllamaActivities.prompt_ollama,
ToolActivities.prompt_llm,
summary_input,
schedule_to_close_timeout=timedelta(seconds=20),
)
workflow.logger.info(
"Continuing as new after %i turns."
% self.continue_as_new_per_turns,
% self.max_turns_before_continue,
)
workflow.continue_as_new(
args=[
CombinedInput(
ollama_params=OllamaParams(
tool_params=ToolWorkflowParams(
conversation_summary=self.conversation_summary,
prompt_queue=self.prompt_queue,
),
@@ -179,13 +148,13 @@ class EntityOllamaWorkflow:
if len(self.conversation_history) > 1:
# Summarize conversation
summary_context, summary_prompt = self.prompt_summary_with_history()
summary_input = OllamaPromptInput(
summary_input = ToolPromptInput(
prompt=summary_prompt,
context_instructions=summary_context,
)
self.conversation_summary = await workflow.start_activity_method(
OllamaActivities.prompt_ollama,
ToolActivities.prompt_llm,
summary_input,
schedule_to_close_timeout=timedelta(seconds=20),
)