mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 14:08:08 +01:00
refactor, date context
This commit is contained in:
@@ -11,6 +11,8 @@ This demo shows a multi-turn conversation with an AI agent running inside a Temp
|
||||
|
||||
## Running the example
|
||||
|
||||
From the /scripts directory:
|
||||
|
||||
1. Run the worker: `poetry run python run_worker.py`
|
||||
2. In another terminal run the client with a prompt.
|
||||
|
||||
|
||||
0
activities/__init__.py
Normal file
0
activities/__init__.py
Normal file
@@ -1,24 +1,26 @@
|
||||
from dataclasses import dataclass
|
||||
from temporalio import activity
|
||||
from temporalio.exceptions import ApplicationError
|
||||
from ollama import chat, ChatResponse
|
||||
import json
|
||||
from temporalio.exceptions import ApplicationError
|
||||
|
||||
|
||||
@dataclass
|
||||
class OllamaPromptInput:
|
||||
class ToolPromptInput:
|
||||
prompt: str
|
||||
context_instructions: str
|
||||
|
||||
|
||||
class OllamaActivities:
|
||||
class ToolActivities:
|
||||
@activity.defn
|
||||
def prompt_ollama(self, input: OllamaPromptInput) -> str:
|
||||
def prompt_llm(self, input: ToolPromptInput) -> str:
|
||||
model_name = "qwen2.5:14b"
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": input.context_instructions,
|
||||
"content": input.context_instructions
|
||||
+ ". The current date is "
|
||||
+ get_current_date_human_readable(),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
@@ -41,3 +43,14 @@ class OllamaActivities:
|
||||
raise ApplicationError(f"Invalid JSON: {e}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def get_current_date_human_readable():
|
||||
"""
|
||||
Returns the current date in a human-readable format.
|
||||
|
||||
Example: Wednesday, January 1, 2025
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
return datetime.now().strftime("%A, %B %d, %Y")
|
||||
0
models/__init__.py
Normal file
0
models/__init__.py
Normal file
15
models/data_types.py
Normal file
15
models/data_types.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Deque
|
||||
from models.tool_definitions import ToolsData
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolWorkflowParams:
|
||||
conversation_summary: Optional[str] = None
|
||||
prompt_queue: Optional[Deque[str]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CombinedInput:
|
||||
tool_params: ToolWorkflowParams
|
||||
tools_data: ToolsData
|
||||
21
models/tool_definitions.py
Normal file
21
models/tool_definitions.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolArgument:
|
||||
name: str
|
||||
type: str
|
||||
description: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolDefinition:
|
||||
name: str
|
||||
description: str
|
||||
arguments: List[ToolArgument]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolsData:
|
||||
tools: List[ToolDefinition]
|
||||
0
prompts/__init__.py
Normal file
0
prompts/__init__.py
Normal file
@@ -1,4 +1,4 @@
|
||||
from workflows import ToolsData
|
||||
from models.tool_definitions import ToolsData
|
||||
|
||||
|
||||
def generate_genai_prompt_from_tools_data(
|
||||
@@ -51,7 +51,7 @@ def generate_genai_prompt_from_tools_data(
|
||||
'- If you have enough information and are confirming, use "next": "confirm". This is the final step once you have filled all args.'
|
||||
)
|
||||
prompt_lines.append(
|
||||
'- Example of a good answer: {"response": "It seems we have all the information needed to search for flights. You will be flying from <city> to <city> from <date> to <date>. Is this correct?", "args":{"origin": "Seattle", "destination": "San Francisco", "dateFrom": "2025-01-04", "dateTo": "2025-01-08"}, "next": "confirm", "tool": "<toolName>" }'
|
||||
'- Example of a good answer: {"response": "It seems we have all the information needed to search for flights. You will be flying from <city> to <city> from <date> to <date>. Is this correct?", "args":{"origin": "Seattle", "destination": "San Francisco", "dateDepart": "2025-01-04", "dateReturn": "2025-01-08"}, "next": "confirm", "tool": "<toolName>" }'
|
||||
)
|
||||
prompt_lines.append("- Return valid JSON without special characters.")
|
||||
prompt_lines.append("")
|
||||
@@ -1,7 +1,7 @@
|
||||
[tool.poetry]
|
||||
name = "temporal-ollama-agent"
|
||||
name = "temporal-AI-agent"
|
||||
version = "0.1.0"
|
||||
description = "Temporal Ollama Agent"
|
||||
description = "Temporal AI Agent"
|
||||
license = "MIT"
|
||||
authors = ["Steve Androulakis <steve.androulakis@temporal.io>"]
|
||||
readme = "README.md"
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
from temporalio.client import Client
|
||||
from workflows import EntityOllamaWorkflow
|
||||
from workflows.tool_workflow import ToolWorkflow
|
||||
|
||||
|
||||
async def main():
|
||||
@@ -11,10 +10,10 @@ async def main():
|
||||
|
||||
workflow_id = "ollama-agent"
|
||||
|
||||
handle = client.get_workflow_handle_for(EntityOllamaWorkflow.run, workflow_id)
|
||||
handle = client.get_workflow_handle_for(ToolWorkflow.run, workflow_id)
|
||||
|
||||
# Sends a signal to the workflow
|
||||
await handle.signal(EntityOllamaWorkflow.end_chat)
|
||||
await handle.signal(ToolWorkflow.end_chat)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
|
||||
from temporalio.client import Client
|
||||
from workflows import EntityOllamaWorkflow
|
||||
from workflows import ToolWorkflow
|
||||
|
||||
|
||||
async def main():
|
||||
@@ -12,7 +12,7 @@ async def main():
|
||||
handle = client.get_workflow_handle(workflow_id)
|
||||
|
||||
# Queries the workflow for the conversation history
|
||||
history = await handle.query(EntityOllamaWorkflow.get_conversation_history)
|
||||
history = await handle.query(ToolWorkflow.get_conversation_history)
|
||||
|
||||
print("Conversation History")
|
||||
print(
|
||||
@@ -20,7 +20,7 @@ async def main():
|
||||
)
|
||||
|
||||
# Queries the workflow for the conversation summary
|
||||
summary = await handle.query(EntityOllamaWorkflow.get_summary_from_history)
|
||||
summary = await handle.query(ToolWorkflow.get_summary_from_history)
|
||||
|
||||
if summary is not None:
|
||||
print("Conversation Summary:")
|
||||
@@ -4,23 +4,24 @@ import logging
|
||||
|
||||
from temporalio.client import Client
|
||||
from temporalio.worker import Worker
|
||||
from workflows import EntityOllamaWorkflow
|
||||
|
||||
from activities import OllamaActivities
|
||||
from activities.tool_activities import ToolActivities
|
||||
from workflows.tool_workflow import ToolWorkflow
|
||||
from workflows.parent_workflow import ParentWorkflow
|
||||
|
||||
|
||||
async def main():
|
||||
# Create client connected to server at the given address
|
||||
client = await Client.connect("localhost:7233")
|
||||
activities = OllamaActivities()
|
||||
activities = ToolActivities()
|
||||
|
||||
# Run the worker
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=100) as activity_executor:
|
||||
worker = Worker(
|
||||
client,
|
||||
task_queue="ollama-task-queue",
|
||||
workflows=[EntityOllamaWorkflow],
|
||||
activities=[activities.prompt_ollama, activities.parse_tool_data],
|
||||
workflows=[ToolWorkflow, ParentWorkflow],
|
||||
activities=[activities.prompt_llm, activities.parse_tool_data],
|
||||
activity_executor=activity_executor,
|
||||
)
|
||||
await worker.run()
|
||||
@@ -3,22 +3,16 @@ import sys
|
||||
|
||||
from temporalio.client import Client
|
||||
|
||||
# Import your dataclasses/types
|
||||
from workflows import (
|
||||
OllamaParams,
|
||||
EntityOllamaWorkflow,
|
||||
ToolsData,
|
||||
ToolDefinition,
|
||||
ToolArgument,
|
||||
CombinedInput,
|
||||
)
|
||||
from models.data_types import CombinedInput, ToolsData, ToolWorkflowParams
|
||||
from models.tool_definitions import ToolDefinition, ToolArgument
|
||||
from workflows.tool_workflow import ToolWorkflow
|
||||
|
||||
|
||||
async def main(prompt):
|
||||
# Construct your tool definitions in code
|
||||
search_flights_tool = ToolDefinition(
|
||||
name="SearchFlights",
|
||||
description="Search for flights from an origin to a destination within a date range",
|
||||
description="Search for return flights from an origin to a destination within a date range",
|
||||
arguments=[
|
||||
ToolArgument(
|
||||
name="origin",
|
||||
@@ -31,12 +25,12 @@ async def main(prompt):
|
||||
description="Airport or city code for arrival (infer airport code from city)",
|
||||
),
|
||||
ToolArgument(
|
||||
name="dateFrom",
|
||||
name="dateDepart",
|
||||
type="ISO8601",
|
||||
description="Start of date range in human readable format",
|
||||
),
|
||||
ToolArgument(
|
||||
name="dateTo",
|
||||
name="dateReturn",
|
||||
type="ISO8601",
|
||||
description="End of date range in human readable format",
|
||||
),
|
||||
@@ -47,7 +41,7 @@ async def main(prompt):
|
||||
tools_data = ToolsData(tools=[search_flights_tool])
|
||||
|
||||
combined_input = CombinedInput(
|
||||
ollama_params=OllamaParams(None, None), tools_data=tools_data
|
||||
tool_params=ToolWorkflowParams(None, None), tools_data=tools_data
|
||||
)
|
||||
|
||||
# Create client connected to Temporal server
|
||||
@@ -57,7 +51,7 @@ async def main(prompt):
|
||||
|
||||
# Start or signal the workflow, passing OllamaParams and tools_data
|
||||
await client.start_workflow(
|
||||
EntityOllamaWorkflow.run,
|
||||
ToolWorkflow.run,
|
||||
combined_input, # or pass custom summary/prompt_queue
|
||||
id=workflow_id,
|
||||
task_queue="ollama-task-queue",
|
||||
0
workflows/__init__.py
Normal file
0
workflows/__init__.py
Normal file
14
workflows/parent_workflow.py
Normal file
14
workflows/parent_workflow.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from temporalio import workflow
|
||||
from .tool_workflow import ToolWorkflow, CombinedInput, ToolWorkflowParams
|
||||
|
||||
|
||||
@workflow.defn
|
||||
class ParentWorkflow:
|
||||
@workflow.run
|
||||
async def run(self, some_input: dict) -> dict:
|
||||
combined_input = CombinedInput(
|
||||
tool_params=ToolWorkflowParams(None, None), tools_data=some_input
|
||||
)
|
||||
child = workflow.start_child_workflow(ToolWorkflow.run, combined_input)
|
||||
result = await child
|
||||
return result
|
||||
@@ -1,67 +1,36 @@
|
||||
import yaml
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from typing import Deque, List, Optional, Tuple
|
||||
|
||||
from temporalio import workflow
|
||||
|
||||
with workflow.unsafe.imports_passed_through():
|
||||
# Import the updated OllamaActivities and the new dataclass
|
||||
from activities import OllamaActivities, OllamaPromptInput
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolArgument:
|
||||
name: str
|
||||
type: str
|
||||
description: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolDefinition:
|
||||
name: str
|
||||
description: str
|
||||
arguments: List[ToolArgument]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolsData:
|
||||
tools: List[ToolDefinition]
|
||||
|
||||
|
||||
@dataclass
|
||||
class OllamaParams:
|
||||
conversation_summary: Optional[str] = None
|
||||
prompt_queue: Optional[Deque[str]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CombinedInput:
|
||||
ollama_params: OllamaParams
|
||||
tools_data: ToolsData
|
||||
|
||||
|
||||
from agent_prompt_generators import (
|
||||
from prompts.agent_prompt_generators import (
|
||||
generate_genai_prompt_from_tools_data,
|
||||
generate_json_validation_prompt_from_tools_data,
|
||||
)
|
||||
|
||||
with workflow.unsafe.imports_passed_through():
|
||||
from activities.tool_activities import ToolActivities, ToolPromptInput
|
||||
from prompts.agent_prompt_generators import (
|
||||
generate_genai_prompt_from_tools_data,
|
||||
generate_json_validation_prompt_from_tools_data,
|
||||
)
|
||||
from models.data_types import CombinedInput, ToolWorkflowParams
|
||||
|
||||
|
||||
@workflow.defn
|
||||
class EntityOllamaWorkflow:
|
||||
class ToolWorkflow:
|
||||
def __init__(self) -> None:
|
||||
self.conversation_history: List[Tuple[str, str]] = []
|
||||
self.prompt_queue: Deque[str] = deque()
|
||||
self.conversation_summary: Optional[str] = None
|
||||
self.continue_as_new_per_turns: int = 250
|
||||
self.chat_ended: bool = False
|
||||
self.tool_data = None
|
||||
self.max_turns_before_continue: int = 250
|
||||
|
||||
@workflow.run
|
||||
async def run(self, combined_input: CombinedInput) -> str:
|
||||
|
||||
params = combined_input.ollama_params
|
||||
params = combined_input.tool_params
|
||||
tools_data = combined_input.tools_data
|
||||
|
||||
if params and params.conversation_summary:
|
||||
@@ -92,14 +61,14 @@ class EntityOllamaWorkflow:
|
||||
workflow.logger.info("Prompt: " + prompt)
|
||||
|
||||
# Pass a single input object
|
||||
prompt_input = OllamaPromptInput(
|
||||
prompt_input = ToolPromptInput(
|
||||
prompt=prompt,
|
||||
context_instructions=context_instructions,
|
||||
)
|
||||
|
||||
# Call activity with one argument
|
||||
responsePrechecked = await workflow.execute_activity_method(
|
||||
OllamaActivities.prompt_ollama,
|
||||
ToolActivities.prompt_llm,
|
||||
prompt_input,
|
||||
schedule_to_close_timeout=timedelta(seconds=20),
|
||||
)
|
||||
@@ -113,14 +82,14 @@ class EntityOllamaWorkflow:
|
||||
workflow.logger.info("Prompt: " + prompt)
|
||||
|
||||
# Pass a single input object
|
||||
prompt_input = OllamaPromptInput(
|
||||
prompt_input = ToolPromptInput(
|
||||
prompt=responsePrechecked,
|
||||
context_instructions=json_validation_instructions,
|
||||
)
|
||||
|
||||
# Call activity with one argument
|
||||
response = await workflow.execute_activity_method(
|
||||
OllamaActivities.prompt_ollama,
|
||||
ToolActivities.prompt_llm,
|
||||
prompt_input,
|
||||
schedule_to_close_timeout=timedelta(seconds=20),
|
||||
)
|
||||
@@ -130,7 +99,7 @@ class EntityOllamaWorkflow:
|
||||
|
||||
# Call activity with one argument
|
||||
tool_data = await workflow.execute_activity_method(
|
||||
OllamaActivities.parse_tool_data,
|
||||
ToolActivities.parse_tool_data,
|
||||
response,
|
||||
schedule_to_close_timeout=timedelta(seconds=1),
|
||||
)
|
||||
@@ -141,29 +110,29 @@ class EntityOllamaWorkflow:
|
||||
return self.tool_data
|
||||
|
||||
# Continue as new after X turns
|
||||
if len(self.conversation_history) >= self.continue_as_new_per_turns:
|
||||
if len(self.conversation_history) >= self.max_turns_before_continue:
|
||||
# Summarize conversation
|
||||
summary_context, summary_prompt = self.prompt_summary_with_history()
|
||||
summary_input = OllamaPromptInput(
|
||||
summary_input = ToolPromptInput(
|
||||
prompt=summary_prompt,
|
||||
context_instructions=summary_context,
|
||||
)
|
||||
|
||||
self.conversation_summary = await workflow.start_activity_method(
|
||||
OllamaActivities.prompt_ollama,
|
||||
ToolActivities.prompt_llm,
|
||||
summary_input,
|
||||
schedule_to_close_timeout=timedelta(seconds=20),
|
||||
)
|
||||
|
||||
workflow.logger.info(
|
||||
"Continuing as new after %i turns."
|
||||
% self.continue_as_new_per_turns,
|
||||
% self.max_turns_before_continue,
|
||||
)
|
||||
|
||||
workflow.continue_as_new(
|
||||
args=[
|
||||
CombinedInput(
|
||||
ollama_params=OllamaParams(
|
||||
tool_params=ToolWorkflowParams(
|
||||
conversation_summary=self.conversation_summary,
|
||||
prompt_queue=self.prompt_queue,
|
||||
),
|
||||
@@ -179,13 +148,13 @@ class EntityOllamaWorkflow:
|
||||
if len(self.conversation_history) > 1:
|
||||
# Summarize conversation
|
||||
summary_context, summary_prompt = self.prompt_summary_with_history()
|
||||
summary_input = OllamaPromptInput(
|
||||
summary_input = ToolPromptInput(
|
||||
prompt=summary_prompt,
|
||||
context_instructions=summary_context,
|
||||
)
|
||||
|
||||
self.conversation_summary = await workflow.start_activity_method(
|
||||
OllamaActivities.prompt_ollama,
|
||||
ToolActivities.prompt_llm,
|
||||
summary_input,
|
||||
schedule_to_close_timeout=timedelta(seconds=20),
|
||||
)
|
||||
Reference in New Issue
Block a user