deepseek api support, dotenv dev fix, other improvements

This commit is contained in:
Steve Androulakis
2025-01-26 21:18:31 -08:00
parent 0d0011d696
commit 8fbbfef6f7
7 changed files with 92 additions and 43 deletions

View File

@@ -9,6 +9,18 @@ import os
from datetime import datetime
import google.generativeai as genai
import anthropic
import deepseek
from dotenv import load_dotenv
load_dotenv(override=True)
print(
"Using LLM: "
+ os.environ.get("LLM_PROVIDER")
+ " (set LLM_PROVIDER in .env to change)"
)
if os.environ.get("LLM_PROVIDER") == "ollama":
print("Using Ollama (local) model: " + os.environ.get("OLLAMA_MODEL_NAME"))
@dataclass
@@ -22,20 +34,33 @@ class ToolActivities:
def prompt_llm(self, input: ToolPromptInput) -> dict:
llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower()
print(f"LLM provider: {llm_provider}")
if llm_provider == "ollama":
return self.prompt_llm_ollama(input)
elif llm_provider == "google":
return self.prompt_llm_google(input)
elif llm_provider == "anthropic":
return self.prompt_llm_anthropic(input)
elif llm_provider == "deepseek":
return self.prompt_llm_deepseek(input)
else:
return self.prompt_llm_openai(input)
def parse_json_response(self, response_content: str) -> dict:
"""
Parses the JSON response content and returns it as a dictionary.
"""
try:
data = json.loads(response_content)
return data
except json.JSONDecodeError as e:
print(f"Invalid JSON: {e}")
raise json.JSONDecodeError
def prompt_llm_openai(self, input: ToolPromptInput) -> dict:
client = OpenAI(
api_key=os.environ.get(
"OPENAI_API_KEY"
),
api_key=os.environ.get("OPENAI_API_KEY"),
)
messages = [
@@ -61,15 +86,8 @@ class ToolActivities:
# Use the new sanitize function
response_content = self.sanitize_json_response(response_content)
try:
data = json.loads(response_content)
except json.JSONDecodeError as e:
print(f"Invalid JSON: {e}")
raise json.JSONDecodeError
return self.parse_json_response(response_content)
return data
@activity.defn
def prompt_llm_ollama(self, input: ToolPromptInput) -> dict:
model_name = os.environ.get("OLLAMA_MODEL_NAME", "qwen2.5:14b")
messages = [
@@ -92,14 +110,7 @@ class ToolActivities:
# Use the new sanitize function
response_content = self.sanitize_json_response(response.message.content)
try:
data = json.loads(response_content)
except json.JSONDecodeError as e:
print(f"Invalid JSON: {e}")
print(response.message.content)
raise json.JSONDecodeError
return data
return self.parse_json_response(response_content)
def prompt_llm_google(self, input: ToolPromptInput) -> dict:
api_key = os.environ.get("GOOGLE_API_KEY")
@@ -118,18 +129,14 @@ class ToolActivities:
# Use the new sanitize function
response_content = self.sanitize_json_response(response_content)
try:
data = json.loads(response_content)
except json.JSONDecodeError as e:
print(f"Invalid JSON: {e}")
raise json.JSONDecodeError
return data
return self.parse_json_response(response_content)
def prompt_llm_anthropic(self, input: ToolPromptInput) -> dict:
api_key = os.environ.get("ANTHROPIC_API_KEY")
if not api_key:
raise ValueError("ANTHROPIC_API_KEY is not set in the environment variables.")
raise ValueError(
"ANTHROPIC_API_KEY is not set in the environment variables."
)
client = anthropic.Anthropic(api_key=api_key)
@@ -137,7 +144,12 @@ class ToolActivities:
model="claude-3-5-sonnet-20241022",
max_tokens=1024,
system=input.context_instructions,
messages=input.prompt
messages=[
{
"role": "user",
"content": input.prompt,
}
],
)
response_content = response.content[0].text
@@ -146,13 +158,32 @@ class ToolActivities:
# Use the new sanitize function
response_content = self.sanitize_json_response(response_content)
try:
data = json.loads(response_content)
except json.JSONDecodeError as e:
print(f"Invalid JSON: {e}")
raise json.JSONDecodeError
return self.parse_json_response(response_content)
return data
def prompt_llm_deepseek(self, input: ToolPromptInput) -> dict:
api_client = deepseek.DeepSeekAPI(api_key=os.environ.get("DEEPSEEK_API_KEY"))
messages = [
{
"role": "system",
"content": input.context_instructions
+ ". The current date is "
+ datetime.now().strftime("%B %d, %Y"),
},
{
"role": "user",
"content": input.prompt,
},
]
response = api_client.chat_completion(prompt=messages)
response_content = response
print(f"DeepSeek response: {response_content}")
# Use the new sanitize function
response_content = self.sanitize_json_response(response_content)
return self.parse_json_response(response_content)
def sanitize_json_response(self, response_content: str) -> str:
"""