mirror of
https://github.com/temporal-community/temporal-ai-agent.git
synced 2026-03-16 22:48:09 +01:00
google gemini support
This commit is contained in:
@@ -7,6 +7,7 @@ from typing import Sequence
|
||||
from temporalio.common import RawValue
|
||||
import os
|
||||
from datetime import datetime
|
||||
import google.generativeai as genai
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -22,6 +23,8 @@ class ToolActivities:
|
||||
|
||||
if llm_provider == "ollama":
|
||||
return self.prompt_llm_ollama(input)
|
||||
elif llm_provider == "google":
|
||||
return self.prompt_llm_google(input)
|
||||
else:
|
||||
return self.prompt_llm_openai(input)
|
||||
|
||||
@@ -95,6 +98,31 @@ class ToolActivities:
|
||||
|
||||
return data
|
||||
|
||||
def prompt_llm_google(self, input: ToolPromptInput) -> dict:
|
||||
api_key = os.environ.get("GOOGLE_API_KEY")
|
||||
if not api_key:
|
||||
raise ValueError("GOOGLE_API_KEY is not set in the environment variables.")
|
||||
|
||||
genai.configure(api_key=api_key)
|
||||
model = genai.GenerativeModel(
|
||||
"models/gemini-1.5-flash",
|
||||
system_instruction=input.context_instructions,
|
||||
)
|
||||
response = model.generate_content(input.prompt)
|
||||
response_content = response.text
|
||||
print(f"Google Gemini response: {response_content}")
|
||||
|
||||
# Use the new sanitize function
|
||||
response_content = self.sanitize_json_response(response_content)
|
||||
|
||||
try:
|
||||
data = json.loads(response_content)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Invalid JSON: {e}")
|
||||
raise json.JSONDecodeError
|
||||
|
||||
return data
|
||||
|
||||
def sanitize_json_response(self, response_content: str) -> str:
|
||||
"""
|
||||
Extracts the JSON block from the response content as a string.
|
||||
|
||||
Reference in New Issue
Block a user