diff --git a/activities/tool_activities.py b/activities/tool_activities.py index d128424..e833b44 100644 --- a/activities/tool_activities.py +++ b/activities/tool_activities.py @@ -20,7 +20,10 @@ print( ) if os.environ.get("LLM_PROVIDER") == "ollama": - print("Using Ollama (local) model: " + os.environ.get("OLLAMA_MODEL_NAME", "qwen2.5:14b")) + print( + "Using Ollama (local) model: " + + os.environ.get("OLLAMA_MODEL_NAME", "qwen2.5:14b") + ) class ToolActivities: @@ -28,13 +31,15 @@ class ToolActivities: """Initialize LLM clients based on environment configuration.""" self.llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower() print(f"Initializing ToolActivities with LLM provider: {self.llm_provider}") - + # Initialize client variables (all set to None initially) self.openai_client: Optional[OpenAI] = None self.anthropic_client: Optional[anthropic.Anthropic] = None self.genai_configured: bool = False self.deepseek_client: Optional[deepseek.DeepSeekAPI] = None - + self.ollama_model_name: Optional[str] = None + self.ollama_initialized: bool = False + # Only initialize the client specified by LLM_PROVIDER if self.llm_provider == "openai": if os.environ.get("OPENAI_API_KEY"): @@ -42,14 +47,18 @@ class ToolActivities: print("Initialized OpenAI client") else: print("Warning: OPENAI_API_KEY not set but LLM_PROVIDER is 'openai'") - + elif self.llm_provider == "anthropic": if os.environ.get("ANTHROPIC_API_KEY"): - self.anthropic_client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")) + self.anthropic_client = anthropic.Anthropic( + api_key=os.environ.get("ANTHROPIC_API_KEY") + ) print("Initialized Anthropic client") else: - print("Warning: ANTHROPIC_API_KEY not set but LLM_PROVIDER is 'anthropic'") - + print( + "Warning: ANTHROPIC_API_KEY not set but LLM_PROVIDER is 'anthropic'" + ) + elif self.llm_provider == "google": api_key = os.environ.get("GOOGLE_API_KEY") if api_key: @@ -58,22 +67,62 @@ class ToolActivities: print("Configured Google Generative AI") else: print("Warning: GOOGLE_API_KEY not set but LLM_PROVIDER is 'google'") - + elif self.llm_provider == "deepseek": if os.environ.get("DEEPSEEK_API_KEY"): - self.deepseek_client = deepseek.DeepSeekAPI(api_key=os.environ.get("DEEPSEEK_API_KEY")) + self.deepseek_client = deepseek.DeepSeekAPI( + api_key=os.environ.get("DEEPSEEK_API_KEY") + ) print("Initialized DeepSeek client") else: - print("Warning: DEEPSEEK_API_KEY not set but LLM_PROVIDER is 'deepseek'") - - # Ollama is initialized on-demand since it's a local API call + print( + "Warning: DEEPSEEK_API_KEY not set but LLM_PROVIDER is 'deepseek'" + ) + + # For Ollama, we store the model name but actual initialization happens in warm_up_ollama elif self.llm_provider == "ollama": - if not os.environ.get("OLLAMA_MODEL_NAME"): - print("Warning: OLLAMA_MODEL_NAME not set, will use default 'qwen2.5:14b'") - else: - print(f"Using Ollama model: {os.environ.get('OLLAMA_MODEL_NAME')}") + self.ollama_model_name = os.environ.get("OLLAMA_MODEL_NAME", "qwen2.5:14b") + print( + f"Using Ollama model: {self.ollama_model_name} (will be loaded on worker startup)" + ) else: - print(f"Warning: Unknown LLM_PROVIDER '{self.llm_provider}', defaulting to OpenAI") + print( + f"Warning: Unknown LLM_PROVIDER '{self.llm_provider}', defaulting to OpenAI" + ) + + def warm_up_ollama(self): + """Pre-load the Ollama model to avoid cold start latency on first request""" + if self.llm_provider != "ollama" or self.ollama_initialized: + return False # No need to warm up if not using Ollama or already warmed up + + try: + print( + f"Pre-loading Ollama model '{self.ollama_model_name}' - this may take 30+ seconds..." + ) + start_time = datetime.now() + + # Make a simple request to load the model into memory + chat( + model=self.ollama_model_name, + messages=[ + {"role": "system", "content": "You are an AI assistant"}, + { + "role": "user", + "content": "Hello! This is a warm-up message to load the model.", + }, + ], + ) + + elapsed_time = (datetime.now() - start_time).total_seconds() + print(f"✅ Ollama model loaded successfully in {elapsed_time:.2f} seconds") + self.ollama_initialized = True + return True + except Exception as e: + print(f"❌ Error pre-loading Ollama model: {str(e)}") + print( + "The worker will continue, but the first actual request may experience a delay." + ) + return False @activity.defn async def agent_validatePrompt( @@ -158,13 +207,15 @@ class ToolActivities: return data except json.JSONDecodeError as e: print(f"Invalid JSON: {e}") - raise json.JSONDecodeError + raise def prompt_llm_openai(self, input: ToolPromptInput) -> dict: if not self.openai_client: api_key = os.environ.get("OPENAI_API_KEY") if not api_key: - raise ValueError("OPENAI_API_KEY is not set in the environment variables but LLM_PROVIDER is 'openai'") + raise ValueError( + "OPENAI_API_KEY is not set in the environment variables but LLM_PROVIDER is 'openai'" + ) self.openai_client = OpenAI(api_key=api_key) print("Initialized OpenAI client on demand") @@ -194,7 +245,20 @@ class ToolActivities: return self.parse_json_response(response_content) def prompt_llm_ollama(self, input: ToolPromptInput) -> dict: - model_name = os.environ.get("OLLAMA_MODEL_NAME", "qwen2.5:14b") + # If not yet initialized, try to do so now (this is a backup if warm_up_ollama wasn't called or failed) + if not self.ollama_initialized: + print( + "Ollama model not pre-loaded. Loading now (this may take 30+ seconds)..." + ) + try: + self.warm_up_ollama() + except Exception: + # We already logged the error in warm_up_ollama, continue with the actual request + pass + + model_name = self.ollama_model_name or os.environ.get( + "OLLAMA_MODEL_NAME", "qwen2.5:14b" + ) messages = [ { "role": "system", @@ -208,20 +272,29 @@ class ToolActivities: }, ] - response: ChatResponse = chat(model=model_name, messages=messages) + try: + response: ChatResponse = chat(model=model_name, messages=messages) + print(f"Chat response: {response.message.content}") - print(f"Chat response: {response.message.content}") - - # Use the new sanitize function - response_content = self.sanitize_json_response(response.message.content) - - return self.parse_json_response(response_content) + # Use the new sanitize function + response_content = self.sanitize_json_response(response.message.content) + return self.parse_json_response(response_content) + except (json.JSONDecodeError, ValueError) as e: + # Re-raise JSON-related exceptions to let Temporal retry the activity + print(f"JSON parsing error with Ollama response: {str(e)}") + raise + except Exception as e: + # Log and raise other exceptions that may need retrying + print(f"Error in Ollama chat: {str(e)}") + raise def prompt_llm_google(self, input: ToolPromptInput) -> dict: if not self.genai_configured: api_key = os.environ.get("GOOGLE_API_KEY") if not api_key: - raise ValueError("GOOGLE_API_KEY is not set in the environment variables but LLM_PROVIDER is 'google'") + raise ValueError( + "GOOGLE_API_KEY is not set in the environment variables but LLM_PROVIDER is 'google'" + ) genai.configure(api_key=api_key) self.genai_configured = True print("Configured Google Generative AI on demand") @@ -245,7 +318,9 @@ class ToolActivities: if not self.anthropic_client: api_key = os.environ.get("ANTHROPIC_API_KEY") if not api_key: - raise ValueError("ANTHROPIC_API_KEY is not set in the environment variables but LLM_PROVIDER is 'anthropic'") + raise ValueError( + "ANTHROPIC_API_KEY is not set in the environment variables but LLM_PROVIDER is 'anthropic'" + ) self.anthropic_client = anthropic.Anthropic(api_key=api_key) print("Initialized Anthropic client on demand") @@ -275,7 +350,9 @@ class ToolActivities: if not self.deepseek_client: api_key = os.environ.get("DEEPSEEK_API_KEY") if not api_key: - raise ValueError("DEEPSEEK_API_KEY is not set in the environment variables but LLM_PROVIDER is 'deepseek'") + raise ValueError( + "DEEPSEEK_API_KEY is not set in the environment variables but LLM_PROVIDER is 'deepseek'" + ) self.deepseek_client = deepseek.DeepSeekAPI(api_key=api_key) print("Initialized DeepSeek client on demand") diff --git a/scripts/run_worker.py b/scripts/run_worker.py index 0f7ad62..7b9549e 100644 --- a/scripts/run_worker.py +++ b/scripts/run_worker.py @@ -14,11 +14,11 @@ from shared.config import get_temporal_client, TEMPORAL_TASK_QUEUE async def main(): # Load environment variables load_dotenv(override=True) - + # Print LLM configuration info llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower() print(f"Worker will use LLM provider: {llm_provider}") - + # Create the client client = await get_temporal_client() @@ -26,6 +26,29 @@ async def main(): activities = ToolActivities() print(f"ToolActivities initialized with LLM provider: {llm_provider}") + # If using Ollama, pre-load the model to avoid cold start latency + if llm_provider == "ollama": + print("\n======== OLLAMA MODEL INITIALIZATION ========") + print("Ollama models need to be loaded into memory on first use.") + print("This may take 30+ seconds depending on your hardware and model size.") + print("Please wait while the model is being loaded...") + + # This call will load the model and measure initialization time + success = activities.warm_up_ollama() + + if success: + print("===========================================================") + print("✅ Ollama model successfully pre-loaded and ready for requests!") + print("===========================================================\n") + else: + print("===========================================================") + print("⚠️ Ollama model pre-loading failed. The worker will continue,") + print("but the first actual request may experience a delay while") + print("the model is loaded on-demand.") + print("===========================================================\n") + + print("Worker ready to process tasks!") + # Run the worker with concurrent.futures.ThreadPoolExecutor(max_workers=100) as activity_executor: worker = Worker(