initial progress

This commit is contained in:
Steve Androulakis
2024-12-31 15:40:46 -08:00
parent 396b748b7d
commit 6355f976ad
8 changed files with 363 additions and 19 deletions

View File

@@ -1,3 +1,4 @@
import yaml
from collections import deque
from dataclasses import dataclass
from datetime import timedelta
@@ -10,23 +11,59 @@ with workflow.unsafe.imports_passed_through():
from activities import OllamaActivities, OllamaPromptInput
@dataclass
class ToolArgument:
name: str
type: str
description: str
@dataclass
class ToolDefinition:
name: str
description: str
arguments: List[ToolArgument]
@dataclass
class ToolsData:
tools: List[ToolDefinition]
@dataclass
class OllamaParams:
conversation_summary: Optional[str] = None
prompt_queue: Optional[Deque[str]] = None
@dataclass
class CombinedInput:
ollama_params: OllamaParams
tools_data: ToolsData
from agent_prompt_generators import (
generate_genai_prompt_from_tools_data,
generate_json_validation_prompt_from_tools_data,
)
@workflow.defn
class EntityOllamaWorkflow:
def __init__(self) -> None:
self.conversation_history: List[Tuple[str, str]] = []
self.prompt_queue: Deque[str] = deque()
self.conversation_summary: Optional[str] = None
self.continue_as_new_per_turns: int = 6
self.continue_as_new_per_turns: int = 250
self.chat_ended: bool = False
self.tool_data = None
@workflow.run
async def run(self, params: OllamaParams) -> str:
async def run(self, combined_input: CombinedInput) -> str:
params = combined_input.ollama_params
tools_data = combined_input.tools_data
if params and params.conversation_summary:
self.conversation_history.append(
("conversation_summary", params.conversation_summary)
@@ -49,15 +86,38 @@ class EntityOllamaWorkflow:
self.conversation_history.append(("user", prompt))
# Build prompt + context
context_instructions, actual_prompt = self.prompt_with_history(prompt)
context_instructions = generate_genai_prompt_from_tools_data(
tools_data, self.format_history()
)
workflow.logger.info("Prompt: " + prompt)
# Pass a single input object
prompt_input = OllamaPromptInput(
prompt=actual_prompt,
prompt=prompt,
context_instructions=context_instructions,
)
# Call activity with one argument
responsePrechecked = await workflow.execute_activity_method(
OllamaActivities.prompt_ollama,
prompt_input,
schedule_to_close_timeout=timedelta(seconds=20),
)
# Check if the response is valid JSON
json_validation_instructions = (
generate_json_validation_prompt_from_tools_data(
tools_data, self.format_history(), responsePrechecked
)
)
workflow.logger.info("Prompt: " + prompt)
# Pass a single input object
prompt_input = OllamaPromptInput(
prompt=responsePrechecked,
context_instructions=json_validation_instructions,
)
# Call activity with one argument
response = await workflow.execute_activity_method(
OllamaActivities.prompt_ollama,
@@ -68,6 +128,18 @@ class EntityOllamaWorkflow:
workflow.logger.info(f"Ollama response: {response}")
self.conversation_history.append(("response", response))
# Call activity with one argument
tool_data = await workflow.execute_activity_method(
OllamaActivities.parse_tool_data,
response,
schedule_to_close_timeout=timedelta(seconds=1),
)
self.tool_data = tool_data
if self.tool_data.get("next") == "confirm":
return self.tool_data
# Continue as new after X turns
if len(self.conversation_history) >= self.continue_as_new_per_turns:
# Summarize conversation
@@ -90,9 +162,12 @@ class EntityOllamaWorkflow:
workflow.continue_as_new(
args=[
OllamaParams(
conversation_summary=self.conversation_summary,
prompt_queue=self.prompt_queue,
CombinedInput(
ollama_params=OllamaParams(
conversation_summary=self.conversation_summary,
prompt_queue=self.prompt_queue,
),
tools_data=tools_data,
)
]
)
@@ -140,6 +215,10 @@ class EntityOllamaWorkflow:
def get_summary_from_history(self) -> Optional[str]:
return self.conversation_summary
@workflow.query
def get_tool_data(self) -> Optional[str]:
return self.tool_data
# Helper: generate text of the entire conversation so far
def format_history(self) -> str:
return " ".join(f"{text}" for _, text in self.conversation_history)