mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-16 14:38:08 +01:00
initial progress
This commit is contained in:
93
workflows.py
93
workflows.py
@@ -1,3 +1,4 @@
|
||||
import yaml
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
@@ -10,23 +11,59 @@ with workflow.unsafe.imports_passed_through():
|
||||
from activities import OllamaActivities, OllamaPromptInput
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolArgument:
|
||||
name: str
|
||||
type: str
|
||||
description: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolDefinition:
|
||||
name: str
|
||||
description: str
|
||||
arguments: List[ToolArgument]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolsData:
|
||||
tools: List[ToolDefinition]
|
||||
|
||||
|
||||
@dataclass
|
||||
class OllamaParams:
|
||||
conversation_summary: Optional[str] = None
|
||||
prompt_queue: Optional[Deque[str]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CombinedInput:
|
||||
ollama_params: OllamaParams
|
||||
tools_data: ToolsData
|
||||
|
||||
|
||||
from agent_prompt_generators import (
|
||||
generate_genai_prompt_from_tools_data,
|
||||
generate_json_validation_prompt_from_tools_data,
|
||||
)
|
||||
|
||||
|
||||
@workflow.defn
|
||||
class EntityOllamaWorkflow:
|
||||
def __init__(self) -> None:
|
||||
self.conversation_history: List[Tuple[str, str]] = []
|
||||
self.prompt_queue: Deque[str] = deque()
|
||||
self.conversation_summary: Optional[str] = None
|
||||
self.continue_as_new_per_turns: int = 6
|
||||
self.continue_as_new_per_turns: int = 250
|
||||
self.chat_ended: bool = False
|
||||
self.tool_data = None
|
||||
|
||||
@workflow.run
|
||||
async def run(self, params: OllamaParams) -> str:
|
||||
async def run(self, combined_input: CombinedInput) -> str:
|
||||
|
||||
params = combined_input.ollama_params
|
||||
tools_data = combined_input.tools_data
|
||||
|
||||
if params and params.conversation_summary:
|
||||
self.conversation_history.append(
|
||||
("conversation_summary", params.conversation_summary)
|
||||
@@ -49,15 +86,38 @@ class EntityOllamaWorkflow:
|
||||
self.conversation_history.append(("user", prompt))
|
||||
|
||||
# Build prompt + context
|
||||
context_instructions, actual_prompt = self.prompt_with_history(prompt)
|
||||
context_instructions = generate_genai_prompt_from_tools_data(
|
||||
tools_data, self.format_history()
|
||||
)
|
||||
workflow.logger.info("Prompt: " + prompt)
|
||||
|
||||
# Pass a single input object
|
||||
prompt_input = OllamaPromptInput(
|
||||
prompt=actual_prompt,
|
||||
prompt=prompt,
|
||||
context_instructions=context_instructions,
|
||||
)
|
||||
|
||||
# Call activity with one argument
|
||||
responsePrechecked = await workflow.execute_activity_method(
|
||||
OllamaActivities.prompt_ollama,
|
||||
prompt_input,
|
||||
schedule_to_close_timeout=timedelta(seconds=20),
|
||||
)
|
||||
|
||||
# Check if the response is valid JSON
|
||||
json_validation_instructions = (
|
||||
generate_json_validation_prompt_from_tools_data(
|
||||
tools_data, self.format_history(), responsePrechecked
|
||||
)
|
||||
)
|
||||
workflow.logger.info("Prompt: " + prompt)
|
||||
|
||||
# Pass a single input object
|
||||
prompt_input = OllamaPromptInput(
|
||||
prompt=responsePrechecked,
|
||||
context_instructions=json_validation_instructions,
|
||||
)
|
||||
|
||||
# Call activity with one argument
|
||||
response = await workflow.execute_activity_method(
|
||||
OllamaActivities.prompt_ollama,
|
||||
@@ -68,6 +128,18 @@ class EntityOllamaWorkflow:
|
||||
workflow.logger.info(f"Ollama response: {response}")
|
||||
self.conversation_history.append(("response", response))
|
||||
|
||||
# Call activity with one argument
|
||||
tool_data = await workflow.execute_activity_method(
|
||||
OllamaActivities.parse_tool_data,
|
||||
response,
|
||||
schedule_to_close_timeout=timedelta(seconds=1),
|
||||
)
|
||||
|
||||
self.tool_data = tool_data
|
||||
|
||||
if self.tool_data.get("next") == "confirm":
|
||||
return self.tool_data
|
||||
|
||||
# Continue as new after X turns
|
||||
if len(self.conversation_history) >= self.continue_as_new_per_turns:
|
||||
# Summarize conversation
|
||||
@@ -90,9 +162,12 @@ class EntityOllamaWorkflow:
|
||||
|
||||
workflow.continue_as_new(
|
||||
args=[
|
||||
OllamaParams(
|
||||
conversation_summary=self.conversation_summary,
|
||||
prompt_queue=self.prompt_queue,
|
||||
CombinedInput(
|
||||
ollama_params=OllamaParams(
|
||||
conversation_summary=self.conversation_summary,
|
||||
prompt_queue=self.prompt_queue,
|
||||
),
|
||||
tools_data=tools_data,
|
||||
)
|
||||
]
|
||||
)
|
||||
@@ -140,6 +215,10 @@ class EntityOllamaWorkflow:
|
||||
def get_summary_from_history(self) -> Optional[str]:
|
||||
return self.conversation_summary
|
||||
|
||||
@workflow.query
|
||||
def get_tool_data(self) -> Optional[str]:
|
||||
return self.tool_data
|
||||
|
||||
# Helper: generate text of the entire conversation so far
|
||||
def format_history(self) -> str:
|
||||
return " ".join(f"{text}" for _, text in self.conversation_history)
|
||||
|
||||
Reference in New Issue
Block a user