diff --git a/.env.example b/.env.example index 48f1c26..8d411dc 100644 --- a/.env.example +++ b/.env.example @@ -1,11 +1,10 @@ -OPENAI_API_KEY=sk-proj-... - RAPIDAPI_KEY=9df2cb5... RAPIDAPI_HOST=sky-scrapper.p.rapidapi.com STRIPE_API_KEY=sk_test_51J... LLM_PROVIDER=openai # default +OPENAI_API_KEY=sk-proj-... # or # LLM_PROVIDER=ollama # OLLAMA_MODEL_NAME=qwen2.5:14b @@ -15,6 +14,10 @@ LLM_PROVIDER=openai # default # or # LLM_PROVIDER=anthropic # ANTHROPIC_API_KEY=your-anthropic-api-key +# or +# LLM_PROVIDER=deepseek +# DEEPSEEK_API_KEY=your-deepseek-api-key + # uncomment and unset these environment variables to connect to the local dev server # TEMPORAL_ADDRESS=namespace.acct.tmprl.cloud:7233 diff --git a/activities/tool_activities.py b/activities/tool_activities.py index 375bd76..445791d 100644 --- a/activities/tool_activities.py +++ b/activities/tool_activities.py @@ -9,6 +9,18 @@ import os from datetime import datetime import google.generativeai as genai import anthropic +import deepseek +from dotenv import load_dotenv + +load_dotenv(override=True) +print( + "Using LLM: " + + os.environ.get("LLM_PROVIDER") + + " (set LLM_PROVIDER in .env to change)" +) + +if os.environ.get("LLM_PROVIDER") == "ollama": + print("Using Ollama (local) model: " + os.environ.get("OLLAMA_MODEL_NAME")) @dataclass @@ -22,20 +34,33 @@ class ToolActivities: def prompt_llm(self, input: ToolPromptInput) -> dict: llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower() + print(f"LLM provider: {llm_provider}") + if llm_provider == "ollama": return self.prompt_llm_ollama(input) elif llm_provider == "google": return self.prompt_llm_google(input) elif llm_provider == "anthropic": return self.prompt_llm_anthropic(input) + elif llm_provider == "deepseek": + return self.prompt_llm_deepseek(input) else: return self.prompt_llm_openai(input) + def parse_json_response(self, response_content: str) -> dict: + """ + Parses the JSON response content and returns it as a dictionary. + """ + try: + data = json.loads(response_content) + return data + except json.JSONDecodeError as e: + print(f"Invalid JSON: {e}") + raise json.JSONDecodeError + def prompt_llm_openai(self, input: ToolPromptInput) -> dict: client = OpenAI( - api_key=os.environ.get( - "OPENAI_API_KEY" - ), + api_key=os.environ.get("OPENAI_API_KEY"), ) messages = [ @@ -61,15 +86,8 @@ class ToolActivities: # Use the new sanitize function response_content = self.sanitize_json_response(response_content) - try: - data = json.loads(response_content) - except json.JSONDecodeError as e: - print(f"Invalid JSON: {e}") - raise json.JSONDecodeError + return self.parse_json_response(response_content) - return data - - @activity.defn def prompt_llm_ollama(self, input: ToolPromptInput) -> dict: model_name = os.environ.get("OLLAMA_MODEL_NAME", "qwen2.5:14b") messages = [ @@ -92,14 +110,7 @@ class ToolActivities: # Use the new sanitize function response_content = self.sanitize_json_response(response.message.content) - try: - data = json.loads(response_content) - except json.JSONDecodeError as e: - print(f"Invalid JSON: {e}") - print(response.message.content) - raise json.JSONDecodeError - - return data + return self.parse_json_response(response_content) def prompt_llm_google(self, input: ToolPromptInput) -> dict: api_key = os.environ.get("GOOGLE_API_KEY") @@ -118,18 +129,14 @@ class ToolActivities: # Use the new sanitize function response_content = self.sanitize_json_response(response_content) - try: - data = json.loads(response_content) - except json.JSONDecodeError as e: - print(f"Invalid JSON: {e}") - raise json.JSONDecodeError - - return data + return self.parse_json_response(response_content) def prompt_llm_anthropic(self, input: ToolPromptInput) -> dict: api_key = os.environ.get("ANTHROPIC_API_KEY") if not api_key: - raise ValueError("ANTHROPIC_API_KEY is not set in the environment variables.") + raise ValueError( + "ANTHROPIC_API_KEY is not set in the environment variables." + ) client = anthropic.Anthropic(api_key=api_key) @@ -137,7 +144,12 @@ class ToolActivities: model="claude-3-5-sonnet-20241022", max_tokens=1024, system=input.context_instructions, - messages=input.prompt + messages=[ + { + "role": "user", + "content": input.prompt, + } + ], ) response_content = response.content[0].text @@ -146,13 +158,32 @@ class ToolActivities: # Use the new sanitize function response_content = self.sanitize_json_response(response_content) - try: - data = json.loads(response_content) - except json.JSONDecodeError as e: - print(f"Invalid JSON: {e}") - raise json.JSONDecodeError + return self.parse_json_response(response_content) - return data + def prompt_llm_deepseek(self, input: ToolPromptInput) -> dict: + api_client = deepseek.DeepSeekAPI(api_key=os.environ.get("DEEPSEEK_API_KEY")) + + messages = [ + { + "role": "system", + "content": input.context_instructions + + ". The current date is " + + datetime.now().strftime("%B %d, %Y"), + }, + { + "role": "user", + "content": input.prompt, + }, + ] + + response = api_client.chat_completion(prompt=messages) + response_content = response + print(f"DeepSeek response: {response_content}") + + # Use the new sanitize function + response_content = self.sanitize_json_response(response_content) + + return self.parse_json_response(response_content) def sanitize_json_response(self, response_content: str) -> str: """ diff --git a/poetry.lock b/poetry.lock index d8a9587..33b6777 100644 --- a/poetry.lock +++ b/poetry.lock @@ -251,6 +251,19 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "deepseek" +version = "1.0.0" +description = "Deepseek API Library" +optional = false +python-versions = "*" +files = [ + {file = "deepseek-1.0.0-py3-none-any.whl", hash = "sha256:ee4175bfcb7ac1154369dbd86a4d8bc1809f6fa20e3e7baa362544567197cb3f"}, +] + +[package.dependencies] +requests = "*" + [[package]] name = "distro" version = "1.9.0" @@ -1399,4 +1412,4 @@ standard = ["colorama (>=0.4)", "httptools (>=0.6.3)", "python-dotenv (>=0.13)", [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "a6b005e6cf77139e266ddf6b3263890f31a7b8813b3b5db95831d5fdb4d20239" +content-hash = "07f540e2c348c45ac9f6f7f4c79f1c060ae269a380474795843b1a444094acb5" diff --git a/pyproject.toml b/pyproject.toml index 1714482..ad596b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ openai = "^1.59.2" stripe = "^11.4.1" google-generativeai = "^0.8.4" anthropic = "^0.45.0" +deepseek = "^1.0.0" [tool.poetry.group.dev.dependencies] pytest = "^7.3" diff --git a/shared/config.py b/shared/config.py index 9291863..c2005ec 100644 --- a/shared/config.py +++ b/shared/config.py @@ -3,7 +3,7 @@ from dotenv import load_dotenv from temporalio.client import Client from temporalio.service import TLSConfig -load_dotenv() +load_dotenv(override=True) # Temporal connection settings TEMPORAL_ADDRESS = os.getenv("TEMPORAL_ADDRESS", "localhost:7233") @@ -15,6 +15,7 @@ TEMPORAL_TLS_CERT = os.getenv("TEMPORAL_TLS_CERT", "") TEMPORAL_TLS_KEY = os.getenv("TEMPORAL_TLS_KEY", "") TEMPORAL_API_KEY = os.getenv("TEMPORAL_API_KEY", "") + async def get_temporal_client() -> Client: """ Creates a Temporal client based on environment configuration. @@ -47,10 +48,10 @@ async def get_temporal_client() -> Client: api_key=TEMPORAL_API_KEY, tls=True, # Always use TLS with API key ) - + # Use mTLS or local connection return await Client.connect( TEMPORAL_ADDRESS, namespace=TEMPORAL_NAMESPACE, tls=tls_config, - ) \ No newline at end of file + ) diff --git a/tools/create_invoice.py b/tools/create_invoice.py index e7e1ae7..1cf0283 100644 --- a/tools/create_invoice.py +++ b/tools/create_invoice.py @@ -2,7 +2,7 @@ import os import stripe from dotenv import load_dotenv -load_dotenv() # Load environment variables from a .env file +load_dotenv(override=True) # Load environment variables from a .env file stripe.api_key = os.getenv("STRIPE_API_KEY", "YOUR_DEFAULT_KEY") diff --git a/tools/search_flights.py b/tools/search_flights.py index 52009eb..f5cac85 100644 --- a/tools/search_flights.py +++ b/tools/search_flights.py @@ -9,7 +9,7 @@ def search_airport(query: str) -> list: """ Returns a list of matching airports/cities from sky-scrapper's searchAirport endpoint. """ - load_dotenv() + load_dotenv(override=True) api_key = os.getenv("RAPIDAPI_KEY", "YOUR_DEFAULT_KEY") api_host = os.getenv("RAPIDAPI_HOST", "sky-scrapper.p.rapidapi.com") @@ -67,7 +67,7 @@ def search_flights(args: dict) -> dict: # _realapi dest_entity_id = dest_params["entityId"] # e.g. "27537542" # Step 2: Call flight search with resolved codes - load_dotenv() + load_dotenv(override=True) api_key = os.getenv("RAPIDAPI_KEY", "YOUR_DEFAULT_KEY") api_host = os.getenv("RAPIDAPI_HOST", "sky-scrapper.p.rapidapi.com")