fix: postgres timeout issues during heavy load

This commit is contained in:
Stijnvandenbroek
2026-03-10 14:47:36 +00:00
parent 508da573fa
commit 1b29efd649
5 changed files with 150 additions and 20 deletions

View File

@@ -6,7 +6,7 @@ POSTGRES_PASSWORD=changeme
POSTGRES_DB=dagster POSTGRES_DB=dagster
# Dagster metadata storage (same postgres instance) # Dagster metadata storage (same postgres instance)
DAGSTER_POSTGRES_URL=postgresql://dagster:changeme@postgres:5432/dagster DAGSTER_POSTGRES_URL=postgresql://dagster:changeme@postgres:5432/dagster?connect_timeout=10&keepalives=1&keepalives_idle=30&keepalives_interval=10&keepalives_count=5
# dbt profile target # dbt profile target
DBT_TARGET=dev DBT_TARGET=dev

View File

@@ -4,8 +4,6 @@ storage:
postgres: postgres:
postgres_url: postgres_url:
env: DAGSTER_POSTGRES_URL env: DAGSTER_POSTGRES_URL
pool_size: 5
max_overflow: 5
# Limit concurrent runs to avoid overwhelming the VM and database. # Limit concurrent runs to avoid overwhelming the VM and database.
concurrency: concurrency:

View File

@@ -18,7 +18,15 @@ def _elementary_schema_exists() -> bool:
port=os.environ.get("POSTGRES_PORT", "5432"), port=os.environ.get("POSTGRES_PORT", "5432"),
dbname=os.environ["POSTGRES_DB"], dbname=os.environ["POSTGRES_DB"],
) )
engine = create_engine(url) engine = create_engine(
url,
pool_pre_ping=True,
connect_args={"connect_timeout": 10},
)
from data_platform.resources import _retry_on_operational_error
def _query():
with engine.connect() as conn: with engine.connect() as conn:
return bool( return bool(
conn.execute( conn.execute(
@@ -28,6 +36,8 @@ def _elementary_schema_exists() -> bool:
).scalar() ).scalar()
) )
return _retry_on_operational_error(_query)
@op @op
def elementary_run_models(context: OpExecutionContext) -> None: def elementary_run_models(context: OpExecutionContext) -> None:

View File

@@ -1,9 +1,36 @@
"""Dagster resources.""" """Dagster resources."""
from dagster import ConfigurableResource, EnvVar import time
from dagster import ConfigurableResource, EnvVar, get_dagster_logger
from funda import Funda from funda import Funda
from sqlalchemy import create_engine, text from sqlalchemy import create_engine, text
from sqlalchemy.pool import NullPool from sqlalchemy.exc import OperationalError
logger = get_dagster_logger()
_RETRY_ATTEMPTS = 5
_RETRY_BASE_DELAY = 1 # seconds; doubles each attempt
def _retry_on_operational_error(
fn, *, attempts=_RETRY_ATTEMPTS, base_delay=_RETRY_BASE_DELAY
):
"""Retry *fn* with exponential back-off on SQLAlchemy OperationalError."""
for attempt in range(1, attempts + 1):
try:
return fn()
except OperationalError:
if attempt == attempts:
raise
delay = base_delay * 2 ** (attempt - 1)
logger.warning(
"DB connection attempt %d/%d failed, retrying in %ds …",
attempt,
attempts,
delay,
)
time.sleep(delay)
class FundaResource(ConfigurableResource): class FundaResource(ConfigurableResource):
@@ -26,18 +53,32 @@ class PostgresResource(ConfigurableResource):
def get_engine(self): def get_engine(self):
url = f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.dbname}" url = f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.dbname}"
return create_engine(url, poolclass=NullPool) return create_engine(
url,
pool_pre_ping=True,
pool_size=2,
max_overflow=3,
connect_args={"connect_timeout": 10},
)
def execute(self, statement: str, params: dict | None = None): def execute(self, statement: str, params: dict | None = None):
engine = self.get_engine() engine = self.get_engine()
def _run():
with engine.begin() as conn: with engine.begin() as conn:
conn.execute(text(statement), params or {}) conn.execute(text(statement), params or {})
_retry_on_operational_error(_run)
def execute_many(self, statement: str, rows: list[dict]): def execute_many(self, statement: str, rows: list[dict]):
engine = self.get_engine() engine = self.get_engine()
def _run():
with engine.begin() as conn: with engine.begin() as conn:
conn.execute(text(statement), rows) conn.execute(text(statement), rows)
_retry_on_operational_error(_run)
class MLflowResource(ConfigurableResource): class MLflowResource(ConfigurableResource):
"""MLflow experiment tracking resource.""" """MLflow experiment tracking resource."""

View File

@@ -2,11 +2,15 @@
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest
from sqlalchemy.exc import OperationalError
from data_platform.resources import ( from data_platform.resources import (
DiscordResource, DiscordResource,
FundaResource, FundaResource,
MLflowResource, MLflowResource,
PostgresResource, PostgresResource,
_retry_on_operational_error,
) )
@@ -59,13 +63,45 @@ class TestPostgresResource:
call_url = mock_create.call_args[0][0] call_url = mock_create.call_args[0][0]
assert call_url.startswith("postgresql://") assert call_url.startswith("postgresql://")
def test_engine_uses_pool_pre_ping(self):
res = self._make_resource()
with patch("data_platform.resources.create_engine") as mock_create:
mock_create.return_value = MagicMock()
res.get_engine()
kwargs = mock_create.call_args[1]
assert kwargs["pool_pre_ping"] is True
def test_engine_sets_connect_timeout(self):
res = self._make_resource()
with patch("data_platform.resources.create_engine") as mock_create:
mock_create.return_value = MagicMock()
res.get_engine()
kwargs = mock_create.call_args[1]
assert kwargs["connect_args"]["connect_timeout"] == 10
def test_execute_retries_on_operational_error(self):
mock_engine = MagicMock()
mock_conn = MagicMock()
mock_conn.execute.side_effect = [
OperationalError("conn", {}, Exception("DNS failure")),
None,
]
mock_engine.begin.return_value.__enter__ = MagicMock(return_value=mock_conn)
mock_engine.begin.return_value.__exit__ = MagicMock(return_value=False)
with (
patch("data_platform.resources.create_engine", return_value=mock_engine),
patch("data_platform.resources.time.sleep"),
):
res = self._make_resource()
res.execute("SELECT 1")
def test_execute_calls_engine_begin(self): def test_execute_calls_engine_begin(self):
mock_engine = MagicMock() mock_engine = MagicMock()
mock_conn = MagicMock() mock_conn = MagicMock()
mock_engine.begin.return_value.__enter__ = MagicMock(return_value=mock_conn) mock_engine.begin.return_value.__enter__ = MagicMock(return_value=mock_conn)
mock_engine.begin.return_value.__exit__ = MagicMock(return_value=False) mock_engine.begin.return_value.__exit__ = MagicMock(return_value=False)
# Patch create_engine at module level so that get_engine() returns our mock
with patch("data_platform.resources.create_engine", return_value=mock_engine): with patch("data_platform.resources.create_engine", return_value=mock_engine):
res = self._make_resource() res = self._make_resource()
res.execute("SELECT 1") res.execute("SELECT 1")
@@ -86,6 +122,51 @@ class TestPostgresResource:
mock_conn.execute.assert_called_once() mock_conn.execute.assert_called_once()
class TestRetryOnOperationalError:
def test_succeeds_on_first_attempt(self):
fn = MagicMock(return_value="ok")
result = _retry_on_operational_error(fn, attempts=3, base_delay=0)
assert result == "ok"
assert fn.call_count == 1
@patch("data_platform.resources.time.sleep")
def test_retries_then_succeeds(self, mock_sleep):
fn = MagicMock(
side_effect=[
OperationalError("conn", {}, Exception("DNS failure")),
"ok",
]
)
result = _retry_on_operational_error(fn, attempts=3, base_delay=1)
assert result == "ok"
assert fn.call_count == 2
mock_sleep.assert_called_once_with(1)
@patch("data_platform.resources.time.sleep")
def test_raises_after_all_attempts_exhausted(self, mock_sleep):
fn = MagicMock(
side_effect=OperationalError("conn", {}, Exception("DNS failure"))
)
with pytest.raises(OperationalError):
_retry_on_operational_error(fn, attempts=3, base_delay=1)
assert fn.call_count == 3
@patch("data_platform.resources.time.sleep")
def test_exponential_backoff(self, mock_sleep):
fn = MagicMock(
side_effect=[
OperationalError("conn", {}, Exception("DNS failure")),
OperationalError("conn", {}, Exception("DNS failure")),
"ok",
]
)
_retry_on_operational_error(fn, attempts=5, base_delay=1)
assert mock_sleep.call_args_list == [
((1,),),
((2,),),
]
class TestMLflowResource: class TestMLflowResource:
def test_tracking_uri(self): def test_tracking_uri(self):
resource = MLflowResource(tracking_uri="http://mlflow:5000") resource = MLflowResource(tracking_uri="http://mlflow:5000")