mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-17 06:58:09 +01:00
fixes to issues 1 2 and 3. Plus tuning
This commit is contained in:
@@ -18,6 +18,14 @@ class ToolPromptInput:
|
||||
class ToolActivities:
|
||||
@activity.defn
|
||||
def prompt_llm(self, input: ToolPromptInput) -> dict:
|
||||
llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower()
|
||||
|
||||
if llm_provider == "ollama":
|
||||
return self.prompt_llm_ollama(input)
|
||||
else:
|
||||
return self.prompt_llm_openai(input)
|
||||
|
||||
def prompt_llm_openai(self, input: ToolPromptInput) -> dict:
|
||||
client = OpenAI(
|
||||
api_key=os.environ.get(
|
||||
"OPENAI_API_KEY"
|
||||
@@ -44,9 +52,8 @@ class ToolActivities:
|
||||
response_content = chat_completion.choices[0].message.content
|
||||
print(f"ChatGPT response: {response_content}")
|
||||
|
||||
# Trim formatting markers if present
|
||||
if response_content.startswith("```json") and response_content.endswith("```"):
|
||||
response_content = response_content[7:-3].strip()
|
||||
# Use the new sanitize function
|
||||
response_content = self.sanitize_json_response(response_content)
|
||||
|
||||
try:
|
||||
data = json.loads(response_content)
|
||||
@@ -58,7 +65,7 @@ class ToolActivities:
|
||||
|
||||
@activity.defn
|
||||
def prompt_llm_ollama(self, input: ToolPromptInput) -> dict:
|
||||
model_name = "qwen2.5:14b"
|
||||
model_name = os.environ.get("OLLAMA_MODEL_NAME", "qwen2.5:14b")
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
@@ -76,8 +83,11 @@ class ToolActivities:
|
||||
|
||||
print(f"Chat response: {response.message.content}")
|
||||
|
||||
# Use the new sanitize function
|
||||
response_content = self.sanitize_json_response(response.message.content)
|
||||
|
||||
try:
|
||||
data = json.loads(response.message.content)
|
||||
data = json.loads(response_content)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Invalid JSON: {e}")
|
||||
print(response.message.content)
|
||||
@@ -85,6 +95,54 @@ class ToolActivities:
|
||||
|
||||
return data
|
||||
|
||||
def sanitize_json_response(self, response_content: str) -> str:
|
||||
"""
|
||||
Extracts the JSON block from the response content as a string.
|
||||
Supports:
|
||||
- JSON surrounded by ```json and ```
|
||||
- Raw JSON input
|
||||
- JSON preceded or followed by extra text
|
||||
Rejects invalid input that doesn't contain JSON.
|
||||
"""
|
||||
try:
|
||||
start_marker = "```json"
|
||||
end_marker = "```"
|
||||
|
||||
json_str = None
|
||||
|
||||
# Case 1: JSON surrounded by markers
|
||||
if start_marker in response_content and end_marker in response_content:
|
||||
json_start = response_content.index(start_marker) + len(start_marker)
|
||||
json_end = response_content.index(end_marker, json_start)
|
||||
json_str = response_content[json_start:json_end].strip()
|
||||
|
||||
# Case 2: Text with valid JSON
|
||||
else:
|
||||
# Try to locate the JSON block by scanning for the first `{` and last `}`
|
||||
json_start = response_content.find("{")
|
||||
json_end = response_content.rfind("}")
|
||||
|
||||
if json_start != -1 and json_end != -1 and json_start < json_end:
|
||||
json_str = response_content[json_start : json_end + 1].strip()
|
||||
|
||||
# Validate and ensure the extracted JSON is valid
|
||||
if json_str:
|
||||
json.loads(json_str) # This will raise an error if the JSON is invalid
|
||||
return json_str
|
||||
|
||||
# If no valid JSON found, raise an error
|
||||
raise ValueError("Response does not contain valid JSON.")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
# Invalid JSON
|
||||
print(f"Invalid JSON detected in response: {response_content}")
|
||||
raise ValueError("Response does not contain valid JSON.")
|
||||
except Exception as e:
|
||||
# Other errors
|
||||
print(f"Error processing response: {str(e)}")
|
||||
print(f"Full response: {response_content}")
|
||||
raise
|
||||
|
||||
|
||||
def get_current_date_human_readable():
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user