diff --git a/.env.example b/.env.example index 3f29a28..f3ce08d 100644 --- a/.env.example +++ b/.env.example @@ -9,6 +9,9 @@ LLM_PROVIDER=openai # default # or # LLM_PROVIDER=ollama # OLLAMA_MODEL_NAME=qwen2.5:14b +# or +# LLM_PROVIDER=google +# GOOGLE_API_KEY=your-google-api-key # uncomment and unset these environment variables to connect to the local dev server # TEMPORAL_ADDRESS=namespace.acct.tmprl.cloud:7233 diff --git a/README.md b/README.md index 72f36d2..6d05fb4 100644 --- a/README.md +++ b/README.md @@ -16,10 +16,11 @@ cp .env.example .env ### LLM Provider Configuration -The agent can use either OpenAI's GPT-4o or a local LLM via Ollama. Set the `LLM_PROVIDER` environment variable in your `.env` file to choose the desired provider: +The agent can use either OpenAI's GPT-4o, a local LLM via Ollama, or Google Gemini. Set the `LLM_PROVIDER` environment variable in your `.env` file to choose the desired provider: - `LLM_PROVIDER=openai` for OpenAI's GPT-4o - `LLM_PROVIDER=ollama` for the local LLM via Ollama (not recommended for this use case) +- `LLM_PROVIDER=google` for Google Gemini ### OpenAI Configuration @@ -37,6 +38,13 @@ To use a local LLM with Ollama: Note: The local LLM is disabled by default as ChatGPT 4o was found to be MUCH more reliable for this use case. However, you can switch to Ollama if desired. +### Google Gemini Configuration + +To use Google Gemini: + +1. Obtain a Google API key and set it in the `GOOGLE_API_KEY` environment variable in `.env`. +2. Set `LLM_PROVIDER=google` in your `.env` file. + ## Agent Tools * Requires a Rapidapi key for sky-scrapper (how we find flights). Set this in the `RAPIDAPI_KEY` environment variable in .env * It's free to sign up and get a key at [RapidAPI](https://rapidapi.com/apiheya/api/sky-scrapper) diff --git a/activities/tool_activities.py b/activities/tool_activities.py index 215c289..a78cc7d 100644 --- a/activities/tool_activities.py +++ b/activities/tool_activities.py @@ -7,6 +7,7 @@ from typing import Sequence from temporalio.common import RawValue import os from datetime import datetime +import google.generativeai as genai @dataclass @@ -22,6 +23,8 @@ class ToolActivities: if llm_provider == "ollama": return self.prompt_llm_ollama(input) + elif llm_provider == "google": + return self.prompt_llm_google(input) else: return self.prompt_llm_openai(input) @@ -95,6 +98,31 @@ class ToolActivities: return data + def prompt_llm_google(self, input: ToolPromptInput) -> dict: + api_key = os.environ.get("GOOGLE_API_KEY") + if not api_key: + raise ValueError("GOOGLE_API_KEY is not set in the environment variables.") + + genai.configure(api_key=api_key) + model = genai.GenerativeModel( + "models/gemini-1.5-flash", + system_instruction=input.context_instructions, + ) + response = model.generate_content(input.prompt) + response_content = response.text + print(f"Google Gemini response: {response_content}") + + # Use the new sanitize function + response_content = self.sanitize_json_response(response_content) + + try: + data = json.loads(response_content) + except json.JSONDecodeError as e: + print(f"Invalid JSON: {e}") + raise json.JSONDecodeError + + return data + def sanitize_json_response(self, response_content: str) -> str: """ Extracts the JSON block from the response content as a string. diff --git a/poetry.lock b/poetry.lock index 2448408..11faa82 100644 --- a/poetry.lock +++ b/poetry.lock @@ -79,6 +79,17 @@ d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"] jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] uvloop = ["uvloop (>=0.15.2)"] +[[package]] +name = "cachetools" +version = "5.5.1" +description = "Extensible memoizing collections and decorators" +optional = false +python-versions = ">=3.7" +files = [ + {file = "cachetools-5.5.1-py3-none-any.whl", hash = "sha256:b76651fdc3b24ead3c648bbdeeb940c1b04d365b38b4af66788f9ec4a81d42bb"}, + {file = "cachetools-5.5.1.tar.gz", hash = "sha256:70f238fbba50383ef62e55c6aff6d9673175fe59f7c6782c7a0b9e38f4a9df95"}, +] + [[package]] name = "certifi" version = "2024.12.14" @@ -261,6 +272,241 @@ typing-extensions = ">=4.8.0" all = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.7)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] standard = ["email-validator (>=2.0.0)", "fastapi-cli[standard] (>=0.0.5)", "httpx (>=0.23.0)", "jinja2 (>=2.11.2)", "python-multipart (>=0.0.7)", "uvicorn[standard] (>=0.12.0)"] +[[package]] +name = "google-ai-generativelanguage" +version = "0.6.15" +description = "Google Ai Generativelanguage API client library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "google_ai_generativelanguage-0.6.15-py3-none-any.whl", hash = "sha256:5a03ef86377aa184ffef3662ca28f19eeee158733e45d7947982eb953c6ebb6c"}, + {file = "google_ai_generativelanguage-0.6.15.tar.gz", hash = "sha256:8f6d9dc4c12b065fe2d0289026171acea5183ebf2d0b11cefe12f3821e159ec3"}, +] + +[package.dependencies] +google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]} +google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev" +proto-plus = [ + {version = ">=1.22.3,<2.0.0dev", markers = "python_version < \"3.13\""}, + {version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""}, +] +protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev" + +[[package]] +name = "google-api-core" +version = "2.24.0" +description = "Google API client core library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "google_api_core-2.24.0-py3-none-any.whl", hash = "sha256:10d82ac0fca69c82a25b3efdeefccf6f28e02ebb97925a8cce8edbfe379929d9"}, + {file = "google_api_core-2.24.0.tar.gz", hash = "sha256:e255640547a597a4da010876d333208ddac417d60add22b6851a0c66a831fcaf"}, +] + +[package.dependencies] +google-auth = ">=2.14.1,<3.0.dev0" +googleapis-common-protos = ">=1.56.2,<2.0.dev0" +grpcio = [ + {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""}, + {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, +] +grpcio-status = [ + {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""}, + {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, +] +proto-plus = [ + {version = ">=1.22.3,<2.0.0dev", markers = "python_version < \"3.13\""}, + {version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""}, +] +protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" +requests = ">=2.18.0,<3.0.0.dev0" + +[package.extras] +async-rest = ["google-auth[aiohttp] (>=2.35.0,<3.0.dev0)"] +grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "grpcio-status (>=1.33.2,<2.0.dev0)", "grpcio-status (>=1.49.1,<2.0.dev0)"] +grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] +grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] + +[[package]] +name = "google-api-python-client" +version = "2.159.0" +description = "Google API Client Library for Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "google_api_python_client-2.159.0-py2.py3-none-any.whl", hash = "sha256:baef0bb631a60a0bd7c0bf12a5499e3a40cd4388484de7ee55c1950bf820a0cf"}, + {file = "google_api_python_client-2.159.0.tar.gz", hash = "sha256:55197f430f25c907394b44fa078545ffef89d33fd4dca501b7db9f0d8e224bd6"}, +] + +[package.dependencies] +google-api-core = ">=1.31.5,<2.0.dev0 || >2.3.0,<3.0.0.dev0" +google-auth = ">=1.32.0,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0.dev0" +google-auth-httplib2 = ">=0.2.0,<1.0.0" +httplib2 = ">=0.19.0,<1.dev0" +uritemplate = ">=3.0.1,<5" + +[[package]] +name = "google-auth" +version = "2.38.0" +description = "Google Authentication Library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "google_auth-2.38.0-py2.py3-none-any.whl", hash = "sha256:e7dae6694313f434a2727bf2906f27ad259bae090d7aa896590d86feec3d9d4a"}, + {file = "google_auth-2.38.0.tar.gz", hash = "sha256:8285113607d3b80a3f1543b75962447ba8a09fe85783432a784fdeef6ac094c4"}, +] + +[package.dependencies] +cachetools = ">=2.0.0,<6.0" +pyasn1-modules = ">=0.2.1" +rsa = ">=3.1.4,<5" + +[package.extras] +aiohttp = ["aiohttp (>=3.6.2,<4.0.0.dev0)", "requests (>=2.20.0,<3.0.0.dev0)"] +enterprise-cert = ["cryptography", "pyopenssl"] +pyjwt = ["cryptography (>=38.0.3)", "pyjwt (>=2.0)"] +pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"] +reauth = ["pyu2f (>=0.1.5)"] +requests = ["requests (>=2.20.0,<3.0.0.dev0)"] + +[[package]] +name = "google-auth-httplib2" +version = "0.2.0" +description = "Google Authentication Library: httplib2 transport" +optional = false +python-versions = "*" +files = [ + {file = "google-auth-httplib2-0.2.0.tar.gz", hash = "sha256:38aa7badf48f974f1eb9861794e9c0cb2a0511a4ec0679b1f886d108f5640e05"}, + {file = "google_auth_httplib2-0.2.0-py2.py3-none-any.whl", hash = "sha256:b65a0a2123300dd71281a7bf6e64d65a0759287df52729bdd1ae2e47dc311a3d"}, +] + +[package.dependencies] +google-auth = "*" +httplib2 = ">=0.19.0" + +[[package]] +name = "google-generativeai" +version = "0.8.4" +description = "Google Generative AI High level API client library and tools." +optional = false +python-versions = ">=3.9" +files = [ + {file = "google_generativeai-0.8.4-py3-none-any.whl", hash = "sha256:e987b33ea6decde1e69191ddcaec6ef974458864d243de7191db50c21a7c5b82"}, +] + +[package.dependencies] +google-ai-generativelanguage = "0.6.15" +google-api-core = "*" +google-api-python-client = "*" +google-auth = ">=2.15.0" +protobuf = "*" +pydantic = "*" +tqdm = "*" +typing-extensions = "*" + +[package.extras] +dev = ["Pillow", "absl-py", "black", "ipython", "nose2", "pandas", "pytype", "pyyaml"] + +[[package]] +name = "googleapis-common-protos" +version = "1.66.0" +description = "Common protobufs used in Google APIs" +optional = false +python-versions = ">=3.7" +files = [ + {file = "googleapis_common_protos-1.66.0-py2.py3-none-any.whl", hash = "sha256:d7abcd75fabb2e0ec9f74466401f6c119a0b498e27370e9be4c94cb7e382b8ed"}, + {file = "googleapis_common_protos-1.66.0.tar.gz", hash = "sha256:c3e7b33d15fdca5374cc0a7346dd92ffa847425cc4ea941d970f13680052ec8c"}, +] + +[package.dependencies] +protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" + +[package.extras] +grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] + +[[package]] +name = "grpcio" +version = "1.70.0" +description = "HTTP/2-based RPC framework" +optional = false +python-versions = ">=3.8" +files = [ + {file = "grpcio-1.70.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:95469d1977429f45fe7df441f586521361e235982a0b39e33841549143ae2851"}, + {file = "grpcio-1.70.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:ed9718f17fbdb472e33b869c77a16d0b55e166b100ec57b016dc7de9c8d236bf"}, + {file = "grpcio-1.70.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:374d014f29f9dfdb40510b041792e0e2828a1389281eb590df066e1cc2b404e5"}, + {file = "grpcio-1.70.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2af68a6f5c8f78d56c145161544ad0febbd7479524a59c16b3e25053f39c87f"}, + {file = "grpcio-1.70.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce7df14b2dcd1102a2ec32f621cc9fab6695effef516efbc6b063ad749867295"}, + {file = "grpcio-1.70.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:c78b339869f4dbf89881e0b6fbf376313e4f845a42840a7bdf42ee6caed4b11f"}, + {file = "grpcio-1.70.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:58ad9ba575b39edef71f4798fdb5c7b6d02ad36d47949cd381d4392a5c9cbcd3"}, + {file = "grpcio-1.70.0-cp310-cp310-win32.whl", hash = "sha256:2b0d02e4b25a5c1f9b6c7745d4fa06efc9fd6a611af0fb38d3ba956786b95199"}, + {file = "grpcio-1.70.0-cp310-cp310-win_amd64.whl", hash = "sha256:0de706c0a5bb9d841e353f6343a9defc9fc35ec61d6eb6111802f3aa9fef29e1"}, + {file = "grpcio-1.70.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:17325b0be0c068f35770f944124e8839ea3185d6d54862800fc28cc2ffad205a"}, + {file = "grpcio-1.70.0-cp311-cp311-macosx_10_14_universal2.whl", hash = "sha256:dbe41ad140df911e796d4463168e33ef80a24f5d21ef4d1e310553fcd2c4a386"}, + {file = "grpcio-1.70.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:5ea67c72101d687d44d9c56068328da39c9ccba634cabb336075fae2eab0d04b"}, + {file = "grpcio-1.70.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cb5277db254ab7586769e490b7b22f4ddab3876c490da0a1a9d7c695ccf0bf77"}, + {file = "grpcio-1.70.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7831a0fc1beeeb7759f737f5acd9fdcda520e955049512d68fda03d91186eea"}, + {file = "grpcio-1.70.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:27cc75e22c5dba1fbaf5a66c778e36ca9b8ce850bf58a9db887754593080d839"}, + {file = "grpcio-1.70.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d63764963412e22f0491d0d32833d71087288f4e24cbcddbae82476bfa1d81fd"}, + {file = "grpcio-1.70.0-cp311-cp311-win32.whl", hash = "sha256:bb491125103c800ec209d84c9b51f1c60ea456038e4734688004f377cfacc113"}, + {file = "grpcio-1.70.0-cp311-cp311-win_amd64.whl", hash = "sha256:d24035d49e026353eb042bf7b058fb831db3e06d52bee75c5f2f3ab453e71aca"}, + {file = "grpcio-1.70.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:ef4c14508299b1406c32bdbb9fb7b47612ab979b04cf2b27686ea31882387cff"}, + {file = "grpcio-1.70.0-cp312-cp312-macosx_10_14_universal2.whl", hash = "sha256:aa47688a65643afd8b166928a1da6247d3f46a2784d301e48ca1cc394d2ffb40"}, + {file = "grpcio-1.70.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:880bfb43b1bb8905701b926274eafce5c70a105bc6b99e25f62e98ad59cb278e"}, + {file = "grpcio-1.70.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e654c4b17d07eab259d392e12b149c3a134ec52b11ecdc6a515b39aceeec898"}, + {file = "grpcio-1.70.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2394e3381071045a706ee2eeb6e08962dd87e8999b90ac15c55f56fa5a8c9597"}, + {file = "grpcio-1.70.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:b3c76701428d2df01964bc6479422f20e62fcbc0a37d82ebd58050b86926ef8c"}, + {file = "grpcio-1.70.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ac073fe1c4cd856ebcf49e9ed6240f4f84d7a4e6ee95baa5d66ea05d3dd0df7f"}, + {file = "grpcio-1.70.0-cp312-cp312-win32.whl", hash = "sha256:cd24d2d9d380fbbee7a5ac86afe9787813f285e684b0271599f95a51bce33528"}, + {file = "grpcio-1.70.0-cp312-cp312-win_amd64.whl", hash = "sha256:0495c86a55a04a874c7627fd33e5beaee771917d92c0e6d9d797628ac40e7655"}, + {file = "grpcio-1.70.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:aa573896aeb7d7ce10b1fa425ba263e8dddd83d71530d1322fd3a16f31257b4a"}, + {file = "grpcio-1.70.0-cp313-cp313-macosx_10_14_universal2.whl", hash = "sha256:d405b005018fd516c9ac529f4b4122342f60ec1cee181788249372524e6db429"}, + {file = "grpcio-1.70.0-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:f32090238b720eb585248654db8e3afc87b48d26ac423c8dde8334a232ff53c9"}, + {file = "grpcio-1.70.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:dfa089a734f24ee5f6880c83d043e4f46bf812fcea5181dcb3a572db1e79e01c"}, + {file = "grpcio-1.70.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f19375f0300b96c0117aca118d400e76fede6db6e91f3c34b7b035822e06c35f"}, + {file = "grpcio-1.70.0-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:7c73c42102e4a5ec76608d9b60227d917cea46dff4d11d372f64cbeb56d259d0"}, + {file = "grpcio-1.70.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:0a5c78d5198a1f0aa60006cd6eb1c912b4a1520b6a3968e677dbcba215fabb40"}, + {file = "grpcio-1.70.0-cp313-cp313-win32.whl", hash = "sha256:fe9dbd916df3b60e865258a8c72ac98f3ac9e2a9542dcb72b7a34d236242a5ce"}, + {file = "grpcio-1.70.0-cp313-cp313-win_amd64.whl", hash = "sha256:4119fed8abb7ff6c32e3d2255301e59c316c22d31ab812b3fbcbaf3d0d87cc68"}, + {file = "grpcio-1.70.0-cp38-cp38-linux_armv7l.whl", hash = "sha256:8058667a755f97407fca257c844018b80004ae8035565ebc2812cc550110718d"}, + {file = "grpcio-1.70.0-cp38-cp38-macosx_10_14_universal2.whl", hash = "sha256:879a61bf52ff8ccacbedf534665bb5478ec8e86ad483e76fe4f729aaef867cab"}, + {file = "grpcio-1.70.0-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:0ba0a173f4feacf90ee618fbc1a27956bfd21260cd31ced9bc707ef551ff7dc7"}, + {file = "grpcio-1.70.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:558c386ecb0148f4f99b1a65160f9d4b790ed3163e8610d11db47838d452512d"}, + {file = "grpcio-1.70.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:412faabcc787bbc826f51be261ae5fa996b21263de5368a55dc2cf824dc5090e"}, + {file = "grpcio-1.70.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3b0f01f6ed9994d7a0b27eeddea43ceac1b7e6f3f9d86aeec0f0064b8cf50fdb"}, + {file = "grpcio-1.70.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:7385b1cb064734005204bc8994eed7dcb801ed6c2eda283f613ad8c6c75cf873"}, + {file = "grpcio-1.70.0-cp38-cp38-win32.whl", hash = "sha256:07269ff4940f6fb6710951116a04cd70284da86d0a4368fd5a3b552744511f5a"}, + {file = "grpcio-1.70.0-cp38-cp38-win_amd64.whl", hash = "sha256:aba19419aef9b254e15011b230a180e26e0f6864c90406fdbc255f01d83bc83c"}, + {file = "grpcio-1.70.0-cp39-cp39-linux_armv7l.whl", hash = "sha256:4f1937f47c77392ccd555728f564a49128b6a197a05a5cd527b796d36f3387d0"}, + {file = "grpcio-1.70.0-cp39-cp39-macosx_10_14_universal2.whl", hash = "sha256:0cd430b9215a15c10b0e7d78f51e8a39d6cf2ea819fd635a7214fae600b1da27"}, + {file = "grpcio-1.70.0-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:e27585831aa6b57b9250abaf147003e126cd3a6c6ca0c531a01996f31709bed1"}, + {file = "grpcio-1.70.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c1af8e15b0f0fe0eac75195992a63df17579553b0c4af9f8362cc7cc99ccddf4"}, + {file = "grpcio-1.70.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cbce24409beaee911c574a3d75d12ffb8c3e3dd1b813321b1d7a96bbcac46bf4"}, + {file = "grpcio-1.70.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:ff4a8112a79464919bb21c18e956c54add43ec9a4850e3949da54f61c241a4a6"}, + {file = "grpcio-1.70.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5413549fdf0b14046c545e19cfc4eb1e37e9e1ebba0ca390a8d4e9963cab44d2"}, + {file = "grpcio-1.70.0-cp39-cp39-win32.whl", hash = "sha256:b745d2c41b27650095e81dea7091668c040457483c9bdb5d0d9de8f8eb25e59f"}, + {file = "grpcio-1.70.0-cp39-cp39-win_amd64.whl", hash = "sha256:a31d7e3b529c94e930a117b2175b2efd179d96eb3c7a21ccb0289a8ab05b645c"}, + {file = "grpcio-1.70.0.tar.gz", hash = "sha256:8d1584a68d5922330025881e63a6c1b54cc8117291d382e4fa69339b6d914c56"}, +] + +[package.extras] +protobuf = ["grpcio-tools (>=1.70.0)"] + +[[package]] +name = "grpcio-status" +version = "1.70.0" +description = "Status proto mapping for gRPC" +optional = false +python-versions = ">=3.8" +files = [ + {file = "grpcio_status-1.70.0-py3-none-any.whl", hash = "sha256:fc5a2ae2b9b1c1969cc49f3262676e6854aa2398ec69cb5bd6c47cd501904a85"}, + {file = "grpcio_status-1.70.0.tar.gz", hash = "sha256:0e7b42816512433b18b9d764285ff029bde059e9d41f8fe10a60631bd8348101"}, +] + +[package.dependencies] +googleapis-common-protos = ">=1.5.5" +grpcio = ">=1.70.0" +protobuf = ">=5.26.1,<6.0dev" + [[package]] name = "h11" version = "0.14.0" @@ -293,6 +539,20 @@ http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] trio = ["trio (>=0.22.0,<1.0)"] +[[package]] +name = "httplib2" +version = "0.22.0" +description = "A comprehensive HTTP client library." +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "httplib2-0.22.0-py3-none-any.whl", hash = "sha256:14ae0a53c1ba8f3d37e9e27cf37eabb0fb9980f435ba405d546948b009dd64dc"}, + {file = "httplib2-0.22.0.tar.gz", hash = "sha256:d7a10bc5ef5ab08322488bde8c726eeee5c8618723fdb399597ec58f3d82df81"}, +] + +[package.dependencies] +pyparsing = {version = ">=2.4.2,<3.0.0 || >3.0.0,<3.0.1 || >3.0.1,<3.0.2 || >3.0.2,<3.0.3 || >3.0.3,<4", markers = "python_version > \"3.0\""} + [[package]] name = "httpx" version = "0.27.2" @@ -546,6 +806,23 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "proto-plus" +version = "1.25.0" +description = "Beautiful, Pythonic protocol buffers." +optional = false +python-versions = ">=3.7" +files = [ + {file = "proto_plus-1.25.0-py3-none-any.whl", hash = "sha256:c91fc4a65074ade8e458e95ef8bac34d4008daa7cce4a12d6707066fca648961"}, + {file = "proto_plus-1.25.0.tar.gz", hash = "sha256:fbb17f57f7bd05a68b7707e745e26528b0b3c34e378db91eef93912c54982d91"}, +] + +[package.dependencies] +protobuf = ">=3.19.0,<6.0.0dev" + +[package.extras] +testing = ["google-api-core (>=1.31.5)"] + [[package]] name = "protobuf" version = "5.29.2" @@ -566,6 +843,31 @@ files = [ {file = "protobuf-5.29.2.tar.gz", hash = "sha256:b2cc8e8bb7c9326996f0e160137b0861f1a82162502658df2951209d0cb0309e"}, ] +[[package]] +name = "pyasn1" +version = "0.6.1" +description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"}, + {file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"}, +] + +[[package]] +name = "pyasn1-modules" +version = "0.4.1" +description = "A collection of ASN.1-based protocols modules" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd"}, + {file = "pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c"}, +] + +[package.dependencies] +pyasn1 = ">=0.4.6,<0.7.0" + [[package]] name = "pydantic" version = "2.10.4" @@ -698,6 +1000,20 @@ files = [ [package.dependencies] typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" +[[package]] +name = "pyparsing" +version = "3.2.1" +description = "pyparsing module - Classes and methods to define and execute parsing grammars" +optional = false +python-versions = ">=3.9" +files = [ + {file = "pyparsing-3.2.1-py3-none-any.whl", hash = "sha256:506ff4f4386c4cec0590ec19e6302d3aedb992fdc02c761e90416f158dacf8e1"}, + {file = "pyparsing-3.2.1.tar.gz", hash = "sha256:61980854fd66de3a90028d679a954d5f2623e83144b5afe5ee86f43d762e5f0a"}, +] + +[package.extras] +diagrams = ["jinja2", "railroad-diagrams"] + [[package]] name = "pytest" version = "7.4.4" @@ -831,6 +1147,20 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "rsa" +version = "4.9" +description = "Pure-Python RSA implementation" +optional = false +python-versions = ">=3.6,<4" +files = [ + {file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"}, + {file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"}, +] + +[package.dependencies] +pyasn1 = ">=0.1.3" + [[package]] name = "six" version = "1.17.0" @@ -995,6 +1325,17 @@ files = [ {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] +[[package]] +name = "uritemplate" +version = "4.1.1" +description = "Implementation of RFC 6570 URI Templates" +optional = false +python-versions = ">=3.6" +files = [ + {file = "uritemplate-4.1.1-py2.py3-none-any.whl", hash = "sha256:830c08b8d99bdd312ea4ead05994a38e8936266f84b9a7878232db50b044e02e"}, + {file = "uritemplate-4.1.1.tar.gz", hash = "sha256:4346edfc5c3b79f694bccd6d6099a322bbeb628dbf2cd86eea55a456ce5124f0"}, +] + [[package]] name = "urllib3" version = "2.3.0" @@ -1034,4 +1375,4 @@ standard = ["colorama (>=0.4)", "httptools (>=0.6.3)", "python-dotenv (>=0.13)", [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "7eca0acee246ab22c53515536bb18adf35ef6264e667681f9f6f1029e303389d" +content-hash = "6607357cf7851b278f44b13b3eccbe4f5309d8c038bba3bdc2114b5989832ba2" diff --git a/pyproject.toml b/pyproject.toml index 65c827e..908e4ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ uvicorn = "^0.34.0" python-dotenv = "^1.0.1" openai = "^1.59.2" stripe = "^11.4.1" +google-generativeai = "^0.8.4" [tool.poetry.group.dev.dependencies] pytest = "^7.3" diff --git a/workflows/tool_workflow.py b/workflows/tool_workflow.py index c325aeb..f80ee7a 100644 --- a/workflows/tool_workflow.py +++ b/workflows/tool_workflow.py @@ -59,7 +59,7 @@ class ToolWorkflow: self.prompt_queue.append( f"### The '{current_tool}' tool completed successfully with {dynamic_result}. " - "INSTRUCTIONS: Parse this tool result as plain text, and use the system prompt containing the list of tools in sequence and the conversation history to figure out next steps, if any. " + "INSTRUCTIONS: Parse this tool result as plain text, and use the system prompt containing the list of tools in sequence and the conversation history (and previous tool_results) to figure out next steps, if any. " '{"next": "", "tool": "", "args": {"": "", "": "}, "response": ""}' "ONLY return those json keys (next, tool, args, response), nothing else." 'Next should only be "done" if all tools have been run (use the system prompt to figure that out).'