From 1b29efd649e27cad3c1a093739fb47f64e0dfda5 Mon Sep 17 00:00:00 2001 From: Stijnvandenbroek Date: Tue, 10 Mar 2026 14:47:36 +0000 Subject: [PATCH] fix: postgres timeout issues during heavy load --- .env.example | 2 +- dagster_home/dagster.yaml | 2 - data_platform/ops/elementary.py | 28 ++++++---- data_platform/resources/__init__.py | 55 ++++++++++++++++--- tests/test_resources.py | 83 ++++++++++++++++++++++++++++- 5 files changed, 150 insertions(+), 20 deletions(-) diff --git a/.env.example b/.env.example index 3d7ef3a..178d9b9 100644 --- a/.env.example +++ b/.env.example @@ -6,7 +6,7 @@ POSTGRES_PASSWORD=changeme POSTGRES_DB=dagster # Dagster metadata storage (same postgres instance) -DAGSTER_POSTGRES_URL=postgresql://dagster:changeme@postgres:5432/dagster +DAGSTER_POSTGRES_URL=postgresql://dagster:changeme@postgres:5432/dagster?connect_timeout=10&keepalives=1&keepalives_idle=30&keepalives_interval=10&keepalives_count=5 # dbt profile target DBT_TARGET=dev diff --git a/dagster_home/dagster.yaml b/dagster_home/dagster.yaml index 4c8535b..802ffbd 100644 --- a/dagster_home/dagster.yaml +++ b/dagster_home/dagster.yaml @@ -4,8 +4,6 @@ storage: postgres: postgres_url: env: DAGSTER_POSTGRES_URL - pool_size: 5 - max_overflow: 5 # Limit concurrent runs to avoid overwhelming the VM and database. concurrency: diff --git a/data_platform/ops/elementary.py b/data_platform/ops/elementary.py index 48c494d..5af9194 100644 --- a/data_platform/ops/elementary.py +++ b/data_platform/ops/elementary.py @@ -18,15 +18,25 @@ def _elementary_schema_exists() -> bool: port=os.environ.get("POSTGRES_PORT", "5432"), dbname=os.environ["POSTGRES_DB"], ) - engine = create_engine(url) - with engine.connect() as conn: - return bool( - conn.execute( - text( - "SELECT 1 FROM information_schema.schemata WHERE schema_name = 'elementary'" - ) - ).scalar() - ) + engine = create_engine( + url, + pool_pre_ping=True, + connect_args={"connect_timeout": 10}, + ) + + from data_platform.resources import _retry_on_operational_error + + def _query(): + with engine.connect() as conn: + return bool( + conn.execute( + text( + "SELECT 1 FROM information_schema.schemata WHERE schema_name = 'elementary'" + ) + ).scalar() + ) + + return _retry_on_operational_error(_query) @op diff --git a/data_platform/resources/__init__.py b/data_platform/resources/__init__.py index efd73e8..1b7a4e6 100644 --- a/data_platform/resources/__init__.py +++ b/data_platform/resources/__init__.py @@ -1,9 +1,36 @@ """Dagster resources.""" -from dagster import ConfigurableResource, EnvVar +import time + +from dagster import ConfigurableResource, EnvVar, get_dagster_logger from funda import Funda from sqlalchemy import create_engine, text -from sqlalchemy.pool import NullPool +from sqlalchemy.exc import OperationalError + +logger = get_dagster_logger() + +_RETRY_ATTEMPTS = 5 +_RETRY_BASE_DELAY = 1 # seconds; doubles each attempt + + +def _retry_on_operational_error( + fn, *, attempts=_RETRY_ATTEMPTS, base_delay=_RETRY_BASE_DELAY +): + """Retry *fn* with exponential back-off on SQLAlchemy OperationalError.""" + for attempt in range(1, attempts + 1): + try: + return fn() + except OperationalError: + if attempt == attempts: + raise + delay = base_delay * 2 ** (attempt - 1) + logger.warning( + "DB connection attempt %d/%d failed, retrying in %ds …", + attempt, + attempts, + delay, + ) + time.sleep(delay) class FundaResource(ConfigurableResource): @@ -26,17 +53,31 @@ class PostgresResource(ConfigurableResource): def get_engine(self): url = f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.dbname}" - return create_engine(url, poolclass=NullPool) + return create_engine( + url, + pool_pre_ping=True, + pool_size=2, + max_overflow=3, + connect_args={"connect_timeout": 10}, + ) def execute(self, statement: str, params: dict | None = None): engine = self.get_engine() - with engine.begin() as conn: - conn.execute(text(statement), params or {}) + + def _run(): + with engine.begin() as conn: + conn.execute(text(statement), params or {}) + + _retry_on_operational_error(_run) def execute_many(self, statement: str, rows: list[dict]): engine = self.get_engine() - with engine.begin() as conn: - conn.execute(text(statement), rows) + + def _run(): + with engine.begin() as conn: + conn.execute(text(statement), rows) + + _retry_on_operational_error(_run) class MLflowResource(ConfigurableResource): diff --git a/tests/test_resources.py b/tests/test_resources.py index dfb3817..fe0f309 100644 --- a/tests/test_resources.py +++ b/tests/test_resources.py @@ -2,11 +2,15 @@ from unittest.mock import MagicMock, patch +import pytest +from sqlalchemy.exc import OperationalError + from data_platform.resources import ( DiscordResource, FundaResource, MLflowResource, PostgresResource, + _retry_on_operational_error, ) @@ -59,13 +63,45 @@ class TestPostgresResource: call_url = mock_create.call_args[0][0] assert call_url.startswith("postgresql://") + def test_engine_uses_pool_pre_ping(self): + res = self._make_resource() + with patch("data_platform.resources.create_engine") as mock_create: + mock_create.return_value = MagicMock() + res.get_engine() + kwargs = mock_create.call_args[1] + assert kwargs["pool_pre_ping"] is True + + def test_engine_sets_connect_timeout(self): + res = self._make_resource() + with patch("data_platform.resources.create_engine") as mock_create: + mock_create.return_value = MagicMock() + res.get_engine() + kwargs = mock_create.call_args[1] + assert kwargs["connect_args"]["connect_timeout"] == 10 + + def test_execute_retries_on_operational_error(self): + mock_engine = MagicMock() + mock_conn = MagicMock() + mock_conn.execute.side_effect = [ + OperationalError("conn", {}, Exception("DNS failure")), + None, + ] + mock_engine.begin.return_value.__enter__ = MagicMock(return_value=mock_conn) + mock_engine.begin.return_value.__exit__ = MagicMock(return_value=False) + + with ( + patch("data_platform.resources.create_engine", return_value=mock_engine), + patch("data_platform.resources.time.sleep"), + ): + res = self._make_resource() + res.execute("SELECT 1") + def test_execute_calls_engine_begin(self): mock_engine = MagicMock() mock_conn = MagicMock() mock_engine.begin.return_value.__enter__ = MagicMock(return_value=mock_conn) mock_engine.begin.return_value.__exit__ = MagicMock(return_value=False) - # Patch create_engine at module level so that get_engine() returns our mock with patch("data_platform.resources.create_engine", return_value=mock_engine): res = self._make_resource() res.execute("SELECT 1") @@ -86,6 +122,51 @@ class TestPostgresResource: mock_conn.execute.assert_called_once() +class TestRetryOnOperationalError: + def test_succeeds_on_first_attempt(self): + fn = MagicMock(return_value="ok") + result = _retry_on_operational_error(fn, attempts=3, base_delay=0) + assert result == "ok" + assert fn.call_count == 1 + + @patch("data_platform.resources.time.sleep") + def test_retries_then_succeeds(self, mock_sleep): + fn = MagicMock( + side_effect=[ + OperationalError("conn", {}, Exception("DNS failure")), + "ok", + ] + ) + result = _retry_on_operational_error(fn, attempts=3, base_delay=1) + assert result == "ok" + assert fn.call_count == 2 + mock_sleep.assert_called_once_with(1) + + @patch("data_platform.resources.time.sleep") + def test_raises_after_all_attempts_exhausted(self, mock_sleep): + fn = MagicMock( + side_effect=OperationalError("conn", {}, Exception("DNS failure")) + ) + with pytest.raises(OperationalError): + _retry_on_operational_error(fn, attempts=3, base_delay=1) + assert fn.call_count == 3 + + @patch("data_platform.resources.time.sleep") + def test_exponential_backoff(self, mock_sleep): + fn = MagicMock( + side_effect=[ + OperationalError("conn", {}, Exception("DNS failure")), + OperationalError("conn", {}, Exception("DNS failure")), + "ok", + ] + ) + _retry_on_operational_error(fn, attempts=5, base_delay=1) + assert mock_sleep.call_args_list == [ + ((1,),), + ((2,),), + ] + + class TestMLflowResource: def test_tracking_uri(self): resource = MLflowResource(tracking_uri="http://mlflow:5000")