mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 14:08:08 +01:00
pre-warm ollama local model on initialization
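In short: ToolActivities gains ollama_model_name and ollama_initialized state plus a warm_up_ollama() method that issues a throwaway chat() call to pull the model into memory. The worker entrypoint invokes warm_up_ollama() at startup whenever LLM_PROVIDER=ollama, so the 30-plus-second model load is paid once at initialization rather than on the first real request; prompt_llm_ollama() keeps an on-demand warm-up as a fallback and now wraps its chat call in try/except so failures surface to Temporal's activity retry machinery.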
@@ -20,7 +20,10 @@ print(
 )

 if os.environ.get("LLM_PROVIDER") == "ollama":
-    print("Using Ollama (local) model: " + os.environ.get("OLLAMA_MODEL_NAME", "qwen2.5:14b"))
+    print(
+        "Using Ollama (local) model: "
+        + os.environ.get("OLLAMA_MODEL_NAME", "qwen2.5:14b")
+    )


 class ToolActivities:
@@ -28,13 +31,15 @@ class ToolActivities:
         """Initialize LLM clients based on environment configuration."""
         self.llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower()
         print(f"Initializing ToolActivities with LLM provider: {self.llm_provider}")

         # Initialize client variables (all set to None initially)
         self.openai_client: Optional[OpenAI] = None
         self.anthropic_client: Optional[anthropic.Anthropic] = None
         self.genai_configured: bool = False
         self.deepseek_client: Optional[deepseek.DeepSeekAPI] = None
+        self.ollama_model_name: Optional[str] = None
+        self.ollama_initialized: bool = False

         # Only initialize the client specified by LLM_PROVIDER
         if self.llm_provider == "openai":
             if os.environ.get("OPENAI_API_KEY"):
@@ -42,14 +47,18 @@ class ToolActivities:
                 print("Initialized OpenAI client")
             else:
                 print("Warning: OPENAI_API_KEY not set but LLM_PROVIDER is 'openai'")

         elif self.llm_provider == "anthropic":
             if os.environ.get("ANTHROPIC_API_KEY"):
-                self.anthropic_client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
+                self.anthropic_client = anthropic.Anthropic(
+                    api_key=os.environ.get("ANTHROPIC_API_KEY")
+                )
                 print("Initialized Anthropic client")
             else:
-                print("Warning: ANTHROPIC_API_KEY not set but LLM_PROVIDER is 'anthropic'")
+                print(
+                    "Warning: ANTHROPIC_API_KEY not set but LLM_PROVIDER is 'anthropic'"
+                )

         elif self.llm_provider == "google":
             api_key = os.environ.get("GOOGLE_API_KEY")
             if api_key:
@@ -58,22 +67,62 @@ class ToolActivities:
                 print("Configured Google Generative AI")
             else:
                 print("Warning: GOOGLE_API_KEY not set but LLM_PROVIDER is 'google'")

         elif self.llm_provider == "deepseek":
             if os.environ.get("DEEPSEEK_API_KEY"):
-                self.deepseek_client = deepseek.DeepSeekAPI(api_key=os.environ.get("DEEPSEEK_API_KEY"))
+                self.deepseek_client = deepseek.DeepSeekAPI(
+                    api_key=os.environ.get("DEEPSEEK_API_KEY")
+                )
                 print("Initialized DeepSeek client")
             else:
-                print("Warning: DEEPSEEK_API_KEY not set but LLM_PROVIDER is 'deepseek'")
-        # Ollama is initialized on-demand since it's a local API call
+                print(
+                    "Warning: DEEPSEEK_API_KEY not set but LLM_PROVIDER is 'deepseek'"
+                )

+        # For Ollama, we store the model name but actual initialization happens in warm_up_ollama
         elif self.llm_provider == "ollama":
-            if not os.environ.get("OLLAMA_MODEL_NAME"):
-                print("Warning: OLLAMA_MODEL_NAME not set, will use default 'qwen2.5:14b'")
-            else:
-                print(f"Using Ollama model: {os.environ.get('OLLAMA_MODEL_NAME')}")
+            self.ollama_model_name = os.environ.get("OLLAMA_MODEL_NAME", "qwen2.5:14b")
+            print(
+                f"Using Ollama model: {self.ollama_model_name} (will be loaded on worker startup)"
+            )
         else:
-            print(f"Warning: Unknown LLM_PROVIDER '{self.llm_provider}', defaulting to OpenAI")
+            print(
+                f"Warning: Unknown LLM_PROVIDER '{self.llm_provider}', defaulting to OpenAI"
+            )

+    def warm_up_ollama(self):
+        """Pre-load the Ollama model to avoid cold start latency on first request"""
+        if self.llm_provider != "ollama" or self.ollama_initialized:
+            return False  # No need to warm up if not using Ollama or already warmed up
+
+        try:
+            print(
+                f"Pre-loading Ollama model '{self.ollama_model_name}' - this may take 30+ seconds..."
+            )
+            start_time = datetime.now()
+
+            # Make a simple request to load the model into memory
+            chat(
+                model=self.ollama_model_name,
+                messages=[
+                    {"role": "system", "content": "You are an AI assistant"},
+                    {
+                        "role": "user",
+                        "content": "Hello! This is a warm-up message to load the model.",
+                    },
+                ],
+            )
+
+            elapsed_time = (datetime.now() - start_time).total_seconds()
+            print(f"✅ Ollama model loaded successfully in {elapsed_time:.2f} seconds")
+            self.ollama_initialized = True
+            return True
+        except Exception as e:
+            print(f"❌ Error pre-loading Ollama model: {str(e)}")
+            print(
+                "The worker will continue, but the first actual request may experience a delay."
+            )
+            return False
+
     @activity.defn
     async def agent_validatePrompt(
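The warm-up trick above is self-contained enough to try outside the worker. A minimal sketch, assuming the ollama Python package and a locally pulled model; the keep_alive argument (how long the daemon keeps the model resident after a call) is this sketch's addition, not something the commit uses:

    import os
    from datetime import datetime

    from ollama import chat


    def warm_up(model: str) -> float:
        """Issue a throwaway prompt so the model is loaded into memory; return seconds taken."""
        start = datetime.now()
        chat(
            model=model,
            messages=[{"role": "user", "content": "warm-up"}],
            keep_alive="30m",  # assumption: keeps the model resident for 30 minutes
        )
        return (datetime.now() - start).total_seconds()


    if __name__ == "__main__":
        model = os.environ.get("OLLAMA_MODEL_NAME", "qwen2.5:14b")
        print(f"Model '{model}' loaded in {warm_up(model):.2f}s")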
@@ -158,13 +207,15 @@ class ToolActivities:
             return data
         except json.JSONDecodeError as e:
             print(f"Invalid JSON: {e}")
-            raise json.JSONDecodeError
+            raise

     def prompt_llm_openai(self, input: ToolPromptInput) -> dict:
         if not self.openai_client:
             api_key = os.environ.get("OPENAI_API_KEY")
             if not api_key:
-                raise ValueError("OPENAI_API_KEY is not set in the environment variables but LLM_PROVIDER is 'openai'")
+                raise ValueError(
+                    "OPENAI_API_KEY is not set in the environment variables but LLM_PROVIDER is 'openai'"
+                )
             self.openai_client = OpenAI(api_key=api_key)
             print("Initialized OpenAI client on demand")

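One substantive fix rides along in this hunk: the old `raise json.JSONDecodeError` re-raised the exception class, which would itself crash with a TypeError because JSONDecodeError's constructor requires msg, doc, and pos arguments. A bare raise re-raises the active exception with its traceback intact. Standalone illustration:

    import json

    try:
        try:
            json.loads("not json")
        except json.JSONDecodeError as e:
            print(f"Invalid JSON: {e}")
            raise  # re-raises the original exception, traceback preserved
    except json.JSONDecodeError as e:
        # The caller still sees the real error, position info included.
        print(f"caught again: {e.msg} at pos {e.pos}")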
@@ -194,7 +245,20 @@ class ToolActivities:
         return self.parse_json_response(response_content)

     def prompt_llm_ollama(self, input: ToolPromptInput) -> dict:
-        model_name = os.environ.get("OLLAMA_MODEL_NAME", "qwen2.5:14b")
+        # If not yet initialized, try to do so now (this is a backup if warm_up_ollama wasn't called or failed)
+        if not self.ollama_initialized:
+            print(
+                "Ollama model not pre-loaded. Loading now (this may take 30+ seconds)..."
+            )
+            try:
+                self.warm_up_ollama()
+            except Exception:
+                # We already logged the error in warm_up_ollama, continue with the actual request
+                pass
+
+        model_name = self.ollama_model_name or os.environ.get(
+            "OLLAMA_MODEL_NAME", "qwen2.5:14b"
+        )
         messages = [
             {
                 "role": "system",
@@ -208,20 +272,29 @@ class ToolActivities:
             },
         ]

-        response: ChatResponse = chat(model=model_name, messages=messages)
-
-        print(f"Chat response: {response.message.content}")
-
-        # Use the new sanitize function
-        response_content = self.sanitize_json_response(response.message.content)
-
-        return self.parse_json_response(response_content)
+        try:
+            response: ChatResponse = chat(model=model_name, messages=messages)
+            print(f"Chat response: {response.message.content}")
+
+            # Use the new sanitize function
+            response_content = self.sanitize_json_response(response.message.content)
+            return self.parse_json_response(response_content)
+        except (json.JSONDecodeError, ValueError) as e:
+            # Re-raise JSON-related exceptions to let Temporal retry the activity
+            print(f"JSON parsing error with Ollama response: {str(e)}")
+            raise
+        except Exception as e:
+            # Log and raise other exceptions that may need retrying
+            print(f"Error in Ollama chat: {str(e)}")
+            raise

     def prompt_llm_google(self, input: ToolPromptInput) -> dict:
         if not self.genai_configured:
             api_key = os.environ.get("GOOGLE_API_KEY")
             if not api_key:
-                raise ValueError("GOOGLE_API_KEY is not set in the environment variables but LLM_PROVIDER is 'google'")
+                raise ValueError(
+                    "GOOGLE_API_KEY is not set in the environment variables but LLM_PROVIDER is 'google'"
+                )
             genai.configure(api_key=api_key)
             self.genai_configured = True
             print("Configured Google Generative AI on demand")
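The re-raise comments above lean on Temporal's model: an activity that raises is retried per the caller's retry policy. A hedged sketch of what the workflow side can look like in the Python SDK; the workflow class and activity name are illustrative, not taken from this diff:

    from datetime import timedelta

    from temporalio import workflow
    from temporalio.common import RetryPolicy


    @workflow.defn
    class AgentWorkflow:
        @workflow.run
        async def run(self, prompt_input) -> dict:
            # If the activity raises (e.g., the JSON error re-raised above),
            # Temporal retries it according to this policy.
            return await workflow.execute_activity(
                "agent_toolPlanner",  # illustrative activity name
                prompt_input,
                start_to_close_timeout=timedelta(seconds=120),
                retry_policy=RetryPolicy(
                    initial_interval=timedelta(seconds=1),
                    backoff_coefficient=2.0,
                    maximum_attempts=5,
                ),
            )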
@@ -245,7 +318,9 @@ class ToolActivities:
         if not self.anthropic_client:
             api_key = os.environ.get("ANTHROPIC_API_KEY")
             if not api_key:
-                raise ValueError("ANTHROPIC_API_KEY is not set in the environment variables but LLM_PROVIDER is 'anthropic'")
+                raise ValueError(
+                    "ANTHROPIC_API_KEY is not set in the environment variables but LLM_PROVIDER is 'anthropic'"
+                )
             self.anthropic_client = anthropic.Anthropic(api_key=api_key)
             print("Initialized Anthropic client on demand")

@@ -275,7 +350,9 @@ class ToolActivities:
         if not self.deepseek_client:
             api_key = os.environ.get("DEEPSEEK_API_KEY")
             if not api_key:
-                raise ValueError("DEEPSEEK_API_KEY is not set in the environment variables but LLM_PROVIDER is 'deepseek'")
+                raise ValueError(
+                    "DEEPSEEK_API_KEY is not set in the environment variables but LLM_PROVIDER is 'deepseek'"
+                )
             self.deepseek_client = deepseek.DeepSeekAPI(api_key=api_key)
             print("Initialized DeepSeek client on demand")

@@ -14,11 +14,11 @@ from shared.config import get_temporal_client, TEMPORAL_TASK_QUEUE
 async def main():
     # Load environment variables
     load_dotenv(override=True)

     # Print LLM configuration info
     llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower()
     print(f"Worker will use LLM provider: {llm_provider}")

     # Create the client
     client = await get_temporal_client()

@@ -26,6 +26,29 @@ async def main():
     activities = ToolActivities()
     print(f"ToolActivities initialized with LLM provider: {llm_provider}")

+    # If using Ollama, pre-load the model to avoid cold start latency
+    if llm_provider == "ollama":
+        print("\n======== OLLAMA MODEL INITIALIZATION ========")
+        print("Ollama models need to be loaded into memory on first use.")
+        print("This may take 30+ seconds depending on your hardware and model size.")
+        print("Please wait while the model is being loaded...")
+
+        # This call will load the model and measure initialization time
+        success = activities.warm_up_ollama()
+
+        if success:
+            print("===========================================================")
+            print("✅ Ollama model successfully pre-loaded and ready for requests!")
+            print("===========================================================\n")
+        else:
+            print("===========================================================")
+            print("⚠️ Ollama model pre-loading failed. The worker will continue,")
+            print("but the first actual request may experience a delay while")
+            print("the model is loaded on-demand.")
+            print("===========================================================\n")
+
+        print("Worker ready to process tasks!")
+
     # Run the worker
     with concurrent.futures.ThreadPoolExecutor(max_workers=100) as activity_executor:
         worker = Worker(
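Before starting the worker with LLM_PROVIDER=ollama, making sure the model is present locally avoids warm_up_ollama failing on a missing model. A small pre-flight sketch, assuming the ollama Python client and a running daemon (pull only downloads layers that are missing):

    import os

    import ollama

    model = os.environ.get("OLLAMA_MODEL_NAME", "qwen2.5:14b")
    print(f"Ensuring model '{model}' is available locally...")
    ollama.pull(model)  # no-op for layers already on disk
    print("Done. Start the worker with LLM_PROVIDER=ollama to trigger the warm-up.")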