mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 14:08:08 +01:00
prompt generator overhaul
This commit is contained in:
@@ -3,7 +3,6 @@ from temporalio import activity
|
||||
from temporalio.exceptions import ApplicationError
|
||||
from ollama import chat, ChatResponse
|
||||
import json
|
||||
from models.tool_definitions import ToolsData
|
||||
from typing import Sequence
|
||||
from temporalio.common import RawValue
|
||||
|
||||
@@ -16,7 +15,7 @@ class ToolPromptInput:
|
||||
|
||||
class ToolActivities:
|
||||
@activity.defn
|
||||
def prompt_llm(self, input: ToolPromptInput) -> str:
|
||||
def prompt_llm(self, input: ToolPromptInput) -> dict:
|
||||
model_name = "qwen2.5:14b"
|
||||
messages = [
|
||||
{
|
||||
@@ -32,62 +31,14 @@ class ToolActivities:
|
||||
]
|
||||
|
||||
response: ChatResponse = chat(model=model_name, messages=messages)
|
||||
return response.message.content
|
||||
|
||||
@activity.defn
|
||||
def parse_tool_data(self, json_str: str) -> dict:
|
||||
"""
|
||||
Parses a JSON string into a dictionary.
|
||||
Raises a ValueError if the JSON is invalid.
|
||||
"""
|
||||
try:
|
||||
data = json.loads(json_str)
|
||||
data = json.loads(response.message.content)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ApplicationError(f"Invalid JSON: {e}")
|
||||
|
||||
return data
|
||||
|
||||
@activity.defn
|
||||
def validate_and_parse_json(
|
||||
self,
|
||||
response_prechecked: str,
|
||||
tools_data: ToolsData,
|
||||
conversation_history: str,
|
||||
) -> dict:
|
||||
"""
|
||||
1) Build JSON validation instructions
|
||||
2) Call LLM with those instructions
|
||||
3) Parse the result
|
||||
4) If parsing fails, raise exception -> triggers retry
|
||||
"""
|
||||
|
||||
# 1) Build validation instructions
|
||||
# (Generate the validation prompt exactly as you do in your workflow.)
|
||||
from prompts.agent_prompt_generators import (
|
||||
generate_json_validation_prompt_from_tools_data,
|
||||
)
|
||||
|
||||
validation_prompt = generate_json_validation_prompt_from_tools_data(
|
||||
tools_data, conversation_history, response_prechecked
|
||||
)
|
||||
|
||||
# 2) Call LLM
|
||||
prompt_input = ToolPromptInput(
|
||||
prompt=response_prechecked,
|
||||
context_instructions=validation_prompt,
|
||||
)
|
||||
validated_response = self.prompt_llm(prompt_input)
|
||||
|
||||
# 3) Parse
|
||||
# If parse fails, we raise ApplicationError -> triggers retry
|
||||
try:
|
||||
parsed = self.parse_tool_data(validated_response)
|
||||
except Exception as e:
|
||||
raise ApplicationError(f"Failed to parse validated JSON: {e}")
|
||||
|
||||
# 4) If we get here, parse succeeded
|
||||
return parsed
|
||||
|
||||
|
||||
def get_current_date_human_readable():
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user