google gemini support

This commit is contained in:
Steve Androulakis
2025-01-24 15:53:45 -08:00
parent 7977894f64
commit 20fdf935ab
6 changed files with 384 additions and 3 deletions

View File

@@ -7,6 +7,7 @@ from typing import Sequence
from temporalio.common import RawValue
import os
from datetime import datetime
import google.generativeai as genai
@dataclass
@@ -22,6 +23,8 @@ class ToolActivities:
if llm_provider == "ollama":
return self.prompt_llm_ollama(input)
elif llm_provider == "google":
return self.prompt_llm_google(input)
else:
return self.prompt_llm_openai(input)
@@ -95,6 +98,31 @@ class ToolActivities:
return data
def prompt_llm_google(self, input: ToolPromptInput) -> dict:
api_key = os.environ.get("GOOGLE_API_KEY")
if not api_key:
raise ValueError("GOOGLE_API_KEY is not set in the environment variables.")
genai.configure(api_key=api_key)
model = genai.GenerativeModel(
"models/gemini-1.5-flash",
system_instruction=input.context_instructions,
)
response = model.generate_content(input.prompt)
response_content = response.text
print(f"Google Gemini response: {response_content}")
# Use the new sanitize function
response_content = self.sanitize_json_response(response_content)
try:
data = json.loads(response_content)
except json.JSONDecodeError as e:
print(f"Invalid JSON: {e}")
raise json.JSONDecodeError
return data
def sanitize_json_response(self, response_content: str) -> str:
"""
Extracts the JSON block from the response content as a string.