mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 05:58:08 +01:00
deepseek api support, dotenv dev fix, other improvements
This commit is contained in:
@@ -1,11 +1,10 @@
|
||||
OPENAI_API_KEY=sk-proj-...
|
||||
|
||||
RAPIDAPI_KEY=9df2cb5...
|
||||
RAPIDAPI_HOST=sky-scrapper.p.rapidapi.com
|
||||
|
||||
STRIPE_API_KEY=sk_test_51J...
|
||||
|
||||
LLM_PROVIDER=openai # default
|
||||
OPENAI_API_KEY=sk-proj-...
|
||||
# or
|
||||
# LLM_PROVIDER=ollama
|
||||
# OLLAMA_MODEL_NAME=qwen2.5:14b
|
||||
@@ -15,6 +14,10 @@ LLM_PROVIDER=openai # default
|
||||
# or
|
||||
# LLM_PROVIDER=anthropic
|
||||
# ANTHROPIC_API_KEY=your-anthropic-api-key
|
||||
# or
|
||||
# LLM_PROVIDER=deepseek
|
||||
# DEEPSEEK_API_KEY=your-deepseek-api-key
|
||||
|
||||
|
||||
# uncomment and unset these environment variables to connect to the local dev server
|
||||
# TEMPORAL_ADDRESS=namespace.acct.tmprl.cloud:7233
|
||||
|
||||
@@ -9,6 +9,18 @@ import os
|
||||
from datetime import datetime
|
||||
import google.generativeai as genai
|
||||
import anthropic
|
||||
import deepseek
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
print(
|
||||
"Using LLM: "
|
||||
+ os.environ.get("LLM_PROVIDER")
|
||||
+ " (set LLM_PROVIDER in .env to change)"
|
||||
)
|
||||
|
||||
if os.environ.get("LLM_PROVIDER") == "ollama":
|
||||
print("Using Ollama (local) model: " + os.environ.get("OLLAMA_MODEL_NAME"))
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -22,20 +34,33 @@ class ToolActivities:
|
||||
def prompt_llm(self, input: ToolPromptInput) -> dict:
|
||||
llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower()
|
||||
|
||||
print(f"LLM provider: {llm_provider}")
|
||||
|
||||
if llm_provider == "ollama":
|
||||
return self.prompt_llm_ollama(input)
|
||||
elif llm_provider == "google":
|
||||
return self.prompt_llm_google(input)
|
||||
elif llm_provider == "anthropic":
|
||||
return self.prompt_llm_anthropic(input)
|
||||
elif llm_provider == "deepseek":
|
||||
return self.prompt_llm_deepseek(input)
|
||||
else:
|
||||
return self.prompt_llm_openai(input)
|
||||
|
||||
def parse_json_response(self, response_content: str) -> dict:
|
||||
"""
|
||||
Parses the JSON response content and returns it as a dictionary.
|
||||
"""
|
||||
try:
|
||||
data = json.loads(response_content)
|
||||
return data
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Invalid JSON: {e}")
|
||||
raise json.JSONDecodeError
|
||||
|
||||
def prompt_llm_openai(self, input: ToolPromptInput) -> dict:
|
||||
client = OpenAI(
|
||||
api_key=os.environ.get(
|
||||
"OPENAI_API_KEY"
|
||||
),
|
||||
api_key=os.environ.get("OPENAI_API_KEY"),
|
||||
)
|
||||
|
||||
messages = [
|
||||
@@ -61,15 +86,8 @@ class ToolActivities:
|
||||
# Use the new sanitize function
|
||||
response_content = self.sanitize_json_response(response_content)
|
||||
|
||||
try:
|
||||
data = json.loads(response_content)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Invalid JSON: {e}")
|
||||
raise json.JSONDecodeError
|
||||
return self.parse_json_response(response_content)
|
||||
|
||||
return data
|
||||
|
||||
@activity.defn
|
||||
def prompt_llm_ollama(self, input: ToolPromptInput) -> dict:
|
||||
model_name = os.environ.get("OLLAMA_MODEL_NAME", "qwen2.5:14b")
|
||||
messages = [
|
||||
@@ -92,14 +110,7 @@ class ToolActivities:
|
||||
# Use the new sanitize function
|
||||
response_content = self.sanitize_json_response(response.message.content)
|
||||
|
||||
try:
|
||||
data = json.loads(response_content)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Invalid JSON: {e}")
|
||||
print(response.message.content)
|
||||
raise json.JSONDecodeError
|
||||
|
||||
return data
|
||||
return self.parse_json_response(response_content)
|
||||
|
||||
def prompt_llm_google(self, input: ToolPromptInput) -> dict:
|
||||
api_key = os.environ.get("GOOGLE_API_KEY")
|
||||
@@ -118,18 +129,14 @@ class ToolActivities:
|
||||
# Use the new sanitize function
|
||||
response_content = self.sanitize_json_response(response_content)
|
||||
|
||||
try:
|
||||
data = json.loads(response_content)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Invalid JSON: {e}")
|
||||
raise json.JSONDecodeError
|
||||
|
||||
return data
|
||||
return self.parse_json_response(response_content)
|
||||
|
||||
def prompt_llm_anthropic(self, input: ToolPromptInput) -> dict:
|
||||
api_key = os.environ.get("ANTHROPIC_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError("ANTHROPIC_API_KEY is not set in the environment variables.")
|
||||
raise ValueError(
|
||||
"ANTHROPIC_API_KEY is not set in the environment variables."
|
||||
)
|
||||
|
||||
client = anthropic.Anthropic(api_key=api_key)
|
||||
|
||||
@@ -137,7 +144,12 @@ class ToolActivities:
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
max_tokens=1024,
|
||||
system=input.context_instructions,
|
||||
messages=input.prompt
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": input.prompt,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
response_content = response.content[0].text
|
||||
@@ -146,13 +158,32 @@ class ToolActivities:
|
||||
# Use the new sanitize function
|
||||
response_content = self.sanitize_json_response(response_content)
|
||||
|
||||
try:
|
||||
data = json.loads(response_content)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Invalid JSON: {e}")
|
||||
raise json.JSONDecodeError
|
||||
return self.parse_json_response(response_content)
|
||||
|
||||
return data
|
||||
def prompt_llm_deepseek(self, input: ToolPromptInput) -> dict:
|
||||
api_client = deepseek.DeepSeekAPI(api_key=os.environ.get("DEEPSEEK_API_KEY"))
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": input.context_instructions
|
||||
+ ". The current date is "
|
||||
+ datetime.now().strftime("%B %d, %Y"),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": input.prompt,
|
||||
},
|
||||
]
|
||||
|
||||
response = api_client.chat_completion(prompt=messages)
|
||||
response_content = response
|
||||
print(f"DeepSeek response: {response_content}")
|
||||
|
||||
# Use the new sanitize function
|
||||
response_content = self.sanitize_json_response(response_content)
|
||||
|
||||
return self.parse_json_response(response_content)
|
||||
|
||||
def sanitize_json_response(self, response_content: str) -> str:
|
||||
"""
|
||||
|
||||
15
poetry.lock
generated
15
poetry.lock
generated
@@ -251,6 +251,19 @@ files = [
|
||||
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deepseek"
|
||||
version = "1.0.0"
|
||||
description = "Deepseek API Library"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "deepseek-1.0.0-py3-none-any.whl", hash = "sha256:ee4175bfcb7ac1154369dbd86a4d8bc1809f6fa20e3e7baa362544567197cb3f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
requests = "*"
|
||||
|
||||
[[package]]
|
||||
name = "distro"
|
||||
version = "1.9.0"
|
||||
@@ -1399,4 +1412,4 @@ standard = ["colorama (>=0.4)", "httptools (>=0.6.3)", "python-dotenv (>=0.13)",
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.9,<4.0"
|
||||
content-hash = "a6b005e6cf77139e266ddf6b3263890f31a7b8813b3b5db95831d5fdb4d20239"
|
||||
content-hash = "07f540e2c348c45ac9f6f7f4c79f1c060ae269a380474795843b1a444094acb5"
|
||||
|
||||
@@ -30,6 +30,7 @@ openai = "^1.59.2"
|
||||
stripe = "^11.4.1"
|
||||
google-generativeai = "^0.8.4"
|
||||
anthropic = "^0.45.0"
|
||||
deepseek = "^1.0.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pytest = "^7.3"
|
||||
|
||||
@@ -3,7 +3,7 @@ from dotenv import load_dotenv
|
||||
from temporalio.client import Client
|
||||
from temporalio.service import TLSConfig
|
||||
|
||||
load_dotenv()
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Temporal connection settings
|
||||
TEMPORAL_ADDRESS = os.getenv("TEMPORAL_ADDRESS", "localhost:7233")
|
||||
@@ -15,6 +15,7 @@ TEMPORAL_TLS_CERT = os.getenv("TEMPORAL_TLS_CERT", "")
|
||||
TEMPORAL_TLS_KEY = os.getenv("TEMPORAL_TLS_KEY", "")
|
||||
TEMPORAL_API_KEY = os.getenv("TEMPORAL_API_KEY", "")
|
||||
|
||||
|
||||
async def get_temporal_client() -> Client:
|
||||
"""
|
||||
Creates a Temporal client based on environment configuration.
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
import stripe
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv() # Load environment variables from a .env file
|
||||
load_dotenv(override=True) # Load environment variables from a .env file
|
||||
|
||||
stripe.api_key = os.getenv("STRIPE_API_KEY", "YOUR_DEFAULT_KEY")
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ def search_airport(query: str) -> list:
|
||||
"""
|
||||
Returns a list of matching airports/cities from sky-scrapper's searchAirport endpoint.
|
||||
"""
|
||||
load_dotenv()
|
||||
load_dotenv(override=True)
|
||||
api_key = os.getenv("RAPIDAPI_KEY", "YOUR_DEFAULT_KEY")
|
||||
api_host = os.getenv("RAPIDAPI_HOST", "sky-scrapper.p.rapidapi.com")
|
||||
|
||||
@@ -67,7 +67,7 @@ def search_flights(args: dict) -> dict: # _realapi
|
||||
dest_entity_id = dest_params["entityId"] # e.g. "27537542"
|
||||
|
||||
# Step 2: Call flight search with resolved codes
|
||||
load_dotenv()
|
||||
load_dotenv(override=True)
|
||||
api_key = os.getenv("RAPIDAPI_KEY", "YOUR_DEFAULT_KEY")
|
||||
api_host = os.getenv("RAPIDAPI_HOST", "sky-scrapper.p.rapidapi.com")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user