diff --git a/activities/tool_activities.py b/activities/tool_activities.py
index 405b33a..d128424 100644
--- a/activities/tool_activities.py
+++ b/activities/tool_activities.py
@@ -2,7 +2,7 @@ from temporalio import activity
 from ollama import chat, ChatResponse
 from openai import OpenAI
 import json
-from typing import Sequence
+from typing import Sequence, Optional
 from temporalio.common import RawValue
 import os
 from datetime import datetime
@@ -14,16 +14,67 @@ from models.data_types import ValidationInput, ValidationResult, ToolPromptInput
 load_dotenv(override=True)
 
 print(
-    "Using LLM: "
-    + os.environ.get("LLM_PROVIDER")
+    "Using LLM provider: "
+    + os.environ.get("LLM_PROVIDER", "openai")
     + " (set LLM_PROVIDER in .env to change)"
 )
 
 if os.environ.get("LLM_PROVIDER") == "ollama":
-    print("Using Ollama (local) model: " + os.environ.get("OLLAMA_MODEL_NAME"))
+    print("Using Ollama (local) model: " + os.environ.get("OLLAMA_MODEL_NAME", "qwen2.5:14b"))
 
 
 class ToolActivities:
+    def __init__(self):
+        """Initialize LLM clients based on environment configuration."""
+        self.llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower()
+        print(f"Initializing ToolActivities with LLM provider: {self.llm_provider}")
+
+        # Initialize client variables (all set to None initially)
+        self.openai_client: Optional[OpenAI] = None
+        self.anthropic_client: Optional[anthropic.Anthropic] = None
+        self.genai_configured: bool = False
+        self.deepseek_client: Optional[deepseek.DeepSeekAPI] = None
+
+        # Only initialize the client specified by LLM_PROVIDER
+        if self.llm_provider == "openai":
+            if os.environ.get("OPENAI_API_KEY"):
+                self.openai_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
+                print("Initialized OpenAI client")
+            else:
+                print("Warning: OPENAI_API_KEY not set but LLM_PROVIDER is 'openai'")
+
+        elif self.llm_provider == "anthropic":
+            if os.environ.get("ANTHROPIC_API_KEY"):
+                self.anthropic_client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
+                print("Initialized Anthropic client")
+            else:
+                print("Warning: ANTHROPIC_API_KEY not set but LLM_PROVIDER is 'anthropic'")
+
+        elif self.llm_provider == "google":
+            api_key = os.environ.get("GOOGLE_API_KEY")
+            if api_key:
+                genai.configure(api_key=api_key)
+                self.genai_configured = True
+                print("Configured Google Generative AI")
+            else:
+                print("Warning: GOOGLE_API_KEY not set but LLM_PROVIDER is 'google'")
+
+        elif self.llm_provider == "deepseek":
+            if os.environ.get("DEEPSEEK_API_KEY"):
+                self.deepseek_client = deepseek.DeepSeekAPI(api_key=os.environ.get("DEEPSEEK_API_KEY"))
+                print("Initialized DeepSeek client")
+            else:
+                print("Warning: DEEPSEEK_API_KEY not set but LLM_PROVIDER is 'deepseek'")
+
+        # Ollama is initialized on-demand since it's a local API call
+        elif self.llm_provider == "ollama":
+            if not os.environ.get("OLLAMA_MODEL_NAME"):
+                print("Warning: OLLAMA_MODEL_NAME not set, will use default 'qwen2.5:14b'")
+            else:
+                print(f"Using Ollama model: {os.environ.get('OLLAMA_MODEL_NAME')}")
+        else:
+            print(f"Warning: Unknown LLM_PROVIDER '{self.llm_provider}', defaulting to OpenAI")
+
     @activity.defn
     async def agent_validatePrompt(
         self, validation_input: ValidationInput
@@ -87,17 +138,13 @@ class ToolActivities:
 
     @activity.defn
     def agent_toolPlanner(self, input: ToolPromptInput) -> dict:
-        llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower()
-
-        print(f"LLM provider: {llm_provider}")
-
-        if llm_provider == "ollama":
+        if self.llm_provider == "ollama":
             return self.prompt_llm_ollama(input)
-        elif llm_provider == "google":
+        elif self.llm_provider == "google":
             return self.prompt_llm_google(input)
-        elif llm_provider == "anthropic":
+        elif self.llm_provider == "anthropic":
             return self.prompt_llm_anthropic(input)
-        elif llm_provider == "deepseek":
+        elif self.llm_provider == "deepseek":
             return self.prompt_llm_deepseek(input)
         else:
             return self.prompt_llm_openai(input)
@@ -114,9 +161,12 @@ class ToolActivities:
             raise json.JSONDecodeError
 
     def prompt_llm_openai(self, input: ToolPromptInput) -> dict:
-        client = OpenAI(
-            api_key=os.environ.get("OPENAI_API_KEY"),
-        )
+        if not self.openai_client:
+            api_key = os.environ.get("OPENAI_API_KEY")
+            if not api_key:
+                raise ValueError("OPENAI_API_KEY is not set in the environment variables but LLM_PROVIDER is 'openai'")
+            self.openai_client = OpenAI(api_key=api_key)
+            print("Initialized OpenAI client on demand")
 
         messages = [
             {
@@ -131,7 +181,7 @@ class ToolActivities:
             },
         ]
 
-        chat_completion = client.chat.completions.create(
+        chat_completion = self.openai_client.chat.completions.create(
             model="gpt-4o", messages=messages  # was gpt-4-0613
         )
 
@@ -168,11 +218,14 @@ class ToolActivities:
         return self.parse_json_response(response_content)
 
     def prompt_llm_google(self, input: ToolPromptInput) -> dict:
-        api_key = os.environ.get("GOOGLE_API_KEY")
-        if not api_key:
-            raise ValueError("GOOGLE_API_KEY is not set in the environment variables.")
+        if not self.genai_configured:
+            api_key = os.environ.get("GOOGLE_API_KEY")
+            if not api_key:
+                raise ValueError("GOOGLE_API_KEY is not set in the environment variables but LLM_PROVIDER is 'google'")
+            genai.configure(api_key=api_key)
+            self.genai_configured = True
+            print("Configured Google Generative AI on demand")
 
-        genai.configure(api_key=api_key)
         model = genai.GenerativeModel(
             "models/gemini-1.5-flash",
             system_instruction=input.context_instructions
@@ -189,15 +242,14 @@ class ToolActivities:
         return self.parse_json_response(response_content)
 
     def prompt_llm_anthropic(self, input: ToolPromptInput) -> dict:
-        api_key = os.environ.get("ANTHROPIC_API_KEY")
-        if not api_key:
-            raise ValueError(
-                "ANTHROPIC_API_KEY is not set in the environment variables."
-            )
+        if not self.anthropic_client:
+            api_key = os.environ.get("ANTHROPIC_API_KEY")
+            if not api_key:
+                raise ValueError("ANTHROPIC_API_KEY is not set in the environment variables but LLM_PROVIDER is 'anthropic'")
+            self.anthropic_client = anthropic.Anthropic(api_key=api_key)
+            print("Initialized Anthropic client on demand")
 
-        client = anthropic.Anthropic(api_key=api_key)
-
-        response = client.messages.create(
+        response = self.anthropic_client.messages.create(
             model="claude-3-5-sonnet-20241022",  # todo try claude-3-7-sonnet-20250219
             max_tokens=1024,
             system=input.context_instructions
@@ -220,7 +272,12 @@ class ToolActivities:
         return self.parse_json_response(response_content)
 
     def prompt_llm_deepseek(self, input: ToolPromptInput) -> dict:
-        api_client = deepseek.DeepSeekAPI(api_key=os.environ.get("DEEPSEEK_API_KEY"))
+        if not self.deepseek_client:
+            api_key = os.environ.get("DEEPSEEK_API_KEY")
+            if not api_key:
+                raise ValueError("DEEPSEEK_API_KEY is not set in the environment variables but LLM_PROVIDER is 'deepseek'")
+            self.deepseek_client = deepseek.DeepSeekAPI(api_key=api_key)
+            print("Initialized DeepSeek client on demand")
 
         messages = [
             {
@@ -235,7 +292,7 @@ class ToolActivities:
             },
         ]
 
-        response = api_client.chat_completion(prompt=messages)
+        response = self.deepseek_client.chat_completion(prompt=messages)
         response_content = response
         print(f"DeepSeek response: {response_content}")
 
diff --git a/scripts/run_worker.py b/scripts/run_worker.py
index fb88ff1..0f7ad62 100644
--- a/scripts/run_worker.py
+++ b/scripts/run_worker.py
@@ -1,6 +1,7 @@
 import asyncio
-
 import concurrent.futures
+import os
+from dotenv import load_dotenv
 
 from temporalio.worker import Worker
 
@@ -11,10 +12,19 @@ from shared.config import get_temporal_client, TEMPORAL_TASK_QUEUE
 
 
 async def main():
+    # Load environment variables
+    load_dotenv(override=True)
+
+    # Print LLM configuration info
+    llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower()
+    print(f"Worker will use LLM provider: {llm_provider}")
+
     # Create the client
     client = await get_temporal_client()
 
+    # Initialize the activities class once with the specified LLM provider
     activities = ToolActivities()
+    print(f"ToolActivities initialized with LLM provider: {llm_provider}")
 
     # Run the worker
     with concurrent.futures.ThreadPoolExecutor(max_workers=100) as activity_executor:
diff --git a/workflows/workflow_helpers.py b/workflows/workflow_helpers.py
index 868b301..fba105c 100644
--- a/workflows/workflow_helpers.py
+++ b/workflows/workflow_helpers.py
@@ -14,7 +14,7 @@ from shared.config import TEMPORAL_LEGACY_TASK_QUEUE
 # Constants from original file
 TOOL_ACTIVITY_START_TO_CLOSE_TIMEOUT = timedelta(seconds=12)
 TOOL_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT = timedelta(minutes=30)
-LLM_ACTIVITY_START_TO_CLOSE_TIMEOUT = timedelta(seconds=12)
+LLM_ACTIVITY_START_TO_CLOSE_TIMEOUT = timedelta(seconds=20)
 LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT = timedelta(minutes=30)
 
 