mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-15 14:08:08 +01:00
activity init function to load llm client only once
This commit is contained in:
@@ -2,7 +2,7 @@ from temporalio import activity
|
|||||||
from ollama import chat, ChatResponse
|
from ollama import chat, ChatResponse
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
import json
|
import json
|
||||||
from typing import Sequence
|
from typing import Sequence, Optional
|
||||||
from temporalio.common import RawValue
|
from temporalio.common import RawValue
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@@ -14,16 +14,67 @@ from models.data_types import ValidationInput, ValidationResult, ToolPromptInput
|
|||||||
|
|
||||||
load_dotenv(override=True)
|
load_dotenv(override=True)
|
||||||
print(
|
print(
|
||||||
"Using LLM: "
|
"Using LLM provider: "
|
||||||
+ os.environ.get("LLM_PROVIDER")
|
+ os.environ.get("LLM_PROVIDER", "openai")
|
||||||
+ " (set LLM_PROVIDER in .env to change)"
|
+ " (set LLM_PROVIDER in .env to change)"
|
||||||
)
|
)
|
||||||
|
|
||||||
if os.environ.get("LLM_PROVIDER") == "ollama":
|
if os.environ.get("LLM_PROVIDER") == "ollama":
|
||||||
print("Using Ollama (local) model: " + os.environ.get("OLLAMA_MODEL_NAME"))
|
print("Using Ollama (local) model: " + os.environ.get("OLLAMA_MODEL_NAME", "qwen2.5:14b"))
|
||||||
|
|
||||||
|
|
||||||
class ToolActivities:
|
class ToolActivities:
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize LLM clients based on environment configuration."""
|
||||||
|
self.llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower()
|
||||||
|
print(f"Initializing ToolActivities with LLM provider: {self.llm_provider}")
|
||||||
|
|
||||||
|
# Initialize client variables (all set to None initially)
|
||||||
|
self.openai_client: Optional[OpenAI] = None
|
||||||
|
self.anthropic_client: Optional[anthropic.Anthropic] = None
|
||||||
|
self.genai_configured: bool = False
|
||||||
|
self.deepseek_client: Optional[deepseek.DeepSeekAPI] = None
|
||||||
|
|
||||||
|
# Only initialize the client specified by LLM_PROVIDER
|
||||||
|
if self.llm_provider == "openai":
|
||||||
|
if os.environ.get("OPENAI_API_KEY"):
|
||||||
|
self.openai_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
|
||||||
|
print("Initialized OpenAI client")
|
||||||
|
else:
|
||||||
|
print("Warning: OPENAI_API_KEY not set but LLM_PROVIDER is 'openai'")
|
||||||
|
|
||||||
|
elif self.llm_provider == "anthropic":
|
||||||
|
if os.environ.get("ANTHROPIC_API_KEY"):
|
||||||
|
self.anthropic_client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
|
||||||
|
print("Initialized Anthropic client")
|
||||||
|
else:
|
||||||
|
print("Warning: ANTHROPIC_API_KEY not set but LLM_PROVIDER is 'anthropic'")
|
||||||
|
|
||||||
|
elif self.llm_provider == "google":
|
||||||
|
api_key = os.environ.get("GOOGLE_API_KEY")
|
||||||
|
if api_key:
|
||||||
|
genai.configure(api_key=api_key)
|
||||||
|
self.genai_configured = True
|
||||||
|
print("Configured Google Generative AI")
|
||||||
|
else:
|
||||||
|
print("Warning: GOOGLE_API_KEY not set but LLM_PROVIDER is 'google'")
|
||||||
|
|
||||||
|
elif self.llm_provider == "deepseek":
|
||||||
|
if os.environ.get("DEEPSEEK_API_KEY"):
|
||||||
|
self.deepseek_client = deepseek.DeepSeekAPI(api_key=os.environ.get("DEEPSEEK_API_KEY"))
|
||||||
|
print("Initialized DeepSeek client")
|
||||||
|
else:
|
||||||
|
print("Warning: DEEPSEEK_API_KEY not set but LLM_PROVIDER is 'deepseek'")
|
||||||
|
|
||||||
|
# Ollama is initialized on-demand since it's a local API call
|
||||||
|
elif self.llm_provider == "ollama":
|
||||||
|
if not os.environ.get("OLLAMA_MODEL_NAME"):
|
||||||
|
print("Warning: OLLAMA_MODEL_NAME not set, will use default 'qwen2.5:14b'")
|
||||||
|
else:
|
||||||
|
print(f"Using Ollama model: {os.environ.get('OLLAMA_MODEL_NAME')}")
|
||||||
|
else:
|
||||||
|
print(f"Warning: Unknown LLM_PROVIDER '{self.llm_provider}', defaulting to OpenAI")
|
||||||
|
|
||||||
@activity.defn
|
@activity.defn
|
||||||
async def agent_validatePrompt(
|
async def agent_validatePrompt(
|
||||||
self, validation_input: ValidationInput
|
self, validation_input: ValidationInput
|
||||||
@@ -87,17 +138,13 @@ class ToolActivities:
|
|||||||
|
|
||||||
@activity.defn
|
@activity.defn
|
||||||
def agent_toolPlanner(self, input: ToolPromptInput) -> dict:
|
def agent_toolPlanner(self, input: ToolPromptInput) -> dict:
|
||||||
llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower()
|
if self.llm_provider == "ollama":
|
||||||
|
|
||||||
print(f"LLM provider: {llm_provider}")
|
|
||||||
|
|
||||||
if llm_provider == "ollama":
|
|
||||||
return self.prompt_llm_ollama(input)
|
return self.prompt_llm_ollama(input)
|
||||||
elif llm_provider == "google":
|
elif self.llm_provider == "google":
|
||||||
return self.prompt_llm_google(input)
|
return self.prompt_llm_google(input)
|
||||||
elif llm_provider == "anthropic":
|
elif self.llm_provider == "anthropic":
|
||||||
return self.prompt_llm_anthropic(input)
|
return self.prompt_llm_anthropic(input)
|
||||||
elif llm_provider == "deepseek":
|
elif self.llm_provider == "deepseek":
|
||||||
return self.prompt_llm_deepseek(input)
|
return self.prompt_llm_deepseek(input)
|
||||||
else:
|
else:
|
||||||
return self.prompt_llm_openai(input)
|
return self.prompt_llm_openai(input)
|
||||||
@@ -114,9 +161,12 @@ class ToolActivities:
|
|||||||
raise json.JSONDecodeError
|
raise json.JSONDecodeError
|
||||||
|
|
||||||
def prompt_llm_openai(self, input: ToolPromptInput) -> dict:
|
def prompt_llm_openai(self, input: ToolPromptInput) -> dict:
|
||||||
client = OpenAI(
|
if not self.openai_client:
|
||||||
api_key=os.environ.get("OPENAI_API_KEY"),
|
api_key = os.environ.get("OPENAI_API_KEY")
|
||||||
)
|
if not api_key:
|
||||||
|
raise ValueError("OPENAI_API_KEY is not set in the environment variables but LLM_PROVIDER is 'openai'")
|
||||||
|
self.openai_client = OpenAI(api_key=api_key)
|
||||||
|
print("Initialized OpenAI client on demand")
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
@@ -131,7 +181,7 @@ class ToolActivities:
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
chat_completion = client.chat.completions.create(
|
chat_completion = self.openai_client.chat.completions.create(
|
||||||
model="gpt-4o", messages=messages # was gpt-4-0613
|
model="gpt-4o", messages=messages # was gpt-4-0613
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -168,11 +218,14 @@ class ToolActivities:
|
|||||||
return self.parse_json_response(response_content)
|
return self.parse_json_response(response_content)
|
||||||
|
|
||||||
def prompt_llm_google(self, input: ToolPromptInput) -> dict:
|
def prompt_llm_google(self, input: ToolPromptInput) -> dict:
|
||||||
|
if not self.genai_configured:
|
||||||
api_key = os.environ.get("GOOGLE_API_KEY")
|
api_key = os.environ.get("GOOGLE_API_KEY")
|
||||||
if not api_key:
|
if not api_key:
|
||||||
raise ValueError("GOOGLE_API_KEY is not set in the environment variables.")
|
raise ValueError("GOOGLE_API_KEY is not set in the environment variables but LLM_PROVIDER is 'google'")
|
||||||
|
|
||||||
genai.configure(api_key=api_key)
|
genai.configure(api_key=api_key)
|
||||||
|
self.genai_configured = True
|
||||||
|
print("Configured Google Generative AI on demand")
|
||||||
|
|
||||||
model = genai.GenerativeModel(
|
model = genai.GenerativeModel(
|
||||||
"models/gemini-1.5-flash",
|
"models/gemini-1.5-flash",
|
||||||
system_instruction=input.context_instructions
|
system_instruction=input.context_instructions
|
||||||
@@ -189,15 +242,14 @@ class ToolActivities:
|
|||||||
return self.parse_json_response(response_content)
|
return self.parse_json_response(response_content)
|
||||||
|
|
||||||
def prompt_llm_anthropic(self, input: ToolPromptInput) -> dict:
|
def prompt_llm_anthropic(self, input: ToolPromptInput) -> dict:
|
||||||
|
if not self.anthropic_client:
|
||||||
api_key = os.environ.get("ANTHROPIC_API_KEY")
|
api_key = os.environ.get("ANTHROPIC_API_KEY")
|
||||||
if not api_key:
|
if not api_key:
|
||||||
raise ValueError(
|
raise ValueError("ANTHROPIC_API_KEY is not set in the environment variables but LLM_PROVIDER is 'anthropic'")
|
||||||
"ANTHROPIC_API_KEY is not set in the environment variables."
|
self.anthropic_client = anthropic.Anthropic(api_key=api_key)
|
||||||
)
|
print("Initialized Anthropic client on demand")
|
||||||
|
|
||||||
client = anthropic.Anthropic(api_key=api_key)
|
response = self.anthropic_client.messages.create(
|
||||||
|
|
||||||
response = client.messages.create(
|
|
||||||
model="claude-3-5-sonnet-20241022", # todo try claude-3-7-sonnet-20250219
|
model="claude-3-5-sonnet-20241022", # todo try claude-3-7-sonnet-20250219
|
||||||
max_tokens=1024,
|
max_tokens=1024,
|
||||||
system=input.context_instructions
|
system=input.context_instructions
|
||||||
@@ -220,7 +272,12 @@ class ToolActivities:
|
|||||||
return self.parse_json_response(response_content)
|
return self.parse_json_response(response_content)
|
||||||
|
|
||||||
def prompt_llm_deepseek(self, input: ToolPromptInput) -> dict:
|
def prompt_llm_deepseek(self, input: ToolPromptInput) -> dict:
|
||||||
api_client = deepseek.DeepSeekAPI(api_key=os.environ.get("DEEPSEEK_API_KEY"))
|
if not self.deepseek_client:
|
||||||
|
api_key = os.environ.get("DEEPSEEK_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError("DEEPSEEK_API_KEY is not set in the environment variables but LLM_PROVIDER is 'deepseek'")
|
||||||
|
self.deepseek_client = deepseek.DeepSeekAPI(api_key=api_key)
|
||||||
|
print("Initialized DeepSeek client on demand")
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
@@ -235,7 +292,7 @@ class ToolActivities:
|
|||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
response = api_client.chat_completion(prompt=messages)
|
response = self.deepseek_client.chat_completion(prompt=messages)
|
||||||
response_content = response
|
response_content = response
|
||||||
print(f"DeepSeek response: {response_content}")
|
print(f"DeepSeek response: {response_content}")
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from temporalio.worker import Worker
|
from temporalio.worker import Worker
|
||||||
|
|
||||||
@@ -11,10 +12,19 @@ from shared.config import get_temporal_client, TEMPORAL_TASK_QUEUE
|
|||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
|
# Load environment variables
|
||||||
|
load_dotenv(override=True)
|
||||||
|
|
||||||
|
# Print LLM configuration info
|
||||||
|
llm_provider = os.environ.get("LLM_PROVIDER", "openai").lower()
|
||||||
|
print(f"Worker will use LLM provider: {llm_provider}")
|
||||||
|
|
||||||
# Create the client
|
# Create the client
|
||||||
client = await get_temporal_client()
|
client = await get_temporal_client()
|
||||||
|
|
||||||
|
# Initialize the activities class once with the specified LLM provider
|
||||||
activities = ToolActivities()
|
activities = ToolActivities()
|
||||||
|
print(f"ToolActivities initialized with LLM provider: {llm_provider}")
|
||||||
|
|
||||||
# Run the worker
|
# Run the worker
|
||||||
with concurrent.futures.ThreadPoolExecutor(max_workers=100) as activity_executor:
|
with concurrent.futures.ThreadPoolExecutor(max_workers=100) as activity_executor:
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from shared.config import TEMPORAL_LEGACY_TASK_QUEUE
|
|||||||
# Constants from original file
|
# Constants from original file
|
||||||
TOOL_ACTIVITY_START_TO_CLOSE_TIMEOUT = timedelta(seconds=12)
|
TOOL_ACTIVITY_START_TO_CLOSE_TIMEOUT = timedelta(seconds=12)
|
||||||
TOOL_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT = timedelta(minutes=30)
|
TOOL_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT = timedelta(minutes=30)
|
||||||
LLM_ACTIVITY_START_TO_CLOSE_TIMEOUT = timedelta(seconds=12)
|
LLM_ACTIVITY_START_TO_CLOSE_TIMEOUT = timedelta(seconds=20)
|
||||||
LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT = timedelta(minutes=30)
|
LLM_ACTIVITY_SCHEDULE_TO_CLOSE_TIMEOUT = timedelta(minutes=30)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user