prompt generator overhaul

This commit is contained in:
Steve Androulakis
2025-01-02 10:21:34 -08:00
parent 5ddf3c7705
commit 7398d8dec6
4 changed files with 115 additions and 187 deletions

View File

@@ -3,7 +3,6 @@ from temporalio import activity
from temporalio.exceptions import ApplicationError
from ollama import chat, ChatResponse
import json
from models.tool_definitions import ToolsData
from typing import Sequence
from temporalio.common import RawValue
@@ -16,7 +15,7 @@ class ToolPromptInput:
class ToolActivities:
@activity.defn
def prompt_llm(self, input: ToolPromptInput) -> str:
def prompt_llm(self, input: ToolPromptInput) -> dict:
model_name = "qwen2.5:14b"
messages = [
{
@@ -32,62 +31,14 @@ class ToolActivities:
]
response: ChatResponse = chat(model=model_name, messages=messages)
return response.message.content
@activity.defn
def parse_tool_data(self, json_str: str) -> dict:
"""
Parses a JSON string into a dictionary.
Raises a ValueError if the JSON is invalid.
"""
try:
data = json.loads(json_str)
data = json.loads(response.message.content)
except json.JSONDecodeError as e:
raise ApplicationError(f"Invalid JSON: {e}")
return data
@activity.defn
def validate_and_parse_json(
self,
response_prechecked: str,
tools_data: ToolsData,
conversation_history: str,
) -> dict:
"""
1) Build JSON validation instructions
2) Call LLM with those instructions
3) Parse the result
4) If parsing fails, raise exception -> triggers retry
"""
# 1) Build validation instructions
# (Generate the validation prompt exactly as you do in your workflow.)
from prompts.agent_prompt_generators import (
generate_json_validation_prompt_from_tools_data,
)
validation_prompt = generate_json_validation_prompt_from_tools_data(
tools_data, conversation_history, response_prechecked
)
# 2) Call LLM
prompt_input = ToolPromptInput(
prompt=response_prechecked,
context_instructions=validation_prompt,
)
validated_response = self.prompt_llm(prompt_input)
# 3) Parse
# If parse fails, we raise ApplicationError -> triggers retry
try:
parsed = self.parse_tool_data(validated_response)
except Exception as e:
raise ApplicationError(f"Failed to parse validated JSON: {e}")
# 4) If we get here, parse succeeded
return parsed
def get_current_date_human_readable():
"""