fix: postgres timeout issues during heavy load
This commit is contained in:
@@ -6,7 +6,7 @@ POSTGRES_PASSWORD=changeme
|
||||
POSTGRES_DB=dagster
|
||||
|
||||
# Dagster metadata storage (same postgres instance)
|
||||
DAGSTER_POSTGRES_URL=postgresql://dagster:changeme@postgres:5432/dagster?connect_timeout=10&keepalives=1&keepalives_idle=30&keepalives_interval=10&keepalives_count=5
|
||||
# dbt profile target
|
||||
DBT_TARGET=dev
|
||||
|
||||
@@ -4,8 +4,6 @@ storage:
|
||||
postgres:
|
||||
postgres_url:
|
||||
env: DAGSTER_POSTGRES_URL
|
||||
pool_size: 5
|
||||
max_overflow: 5
|
||||
|
||||
# Limit concurrent runs to avoid overwhelming the VM and database.
|
||||
concurrency:
|
||||
|
||||
@@ -18,15 +18,25 @@ def _elementary_schema_exists() -> bool:
|
||||
port=os.environ.get("POSTGRES_PORT", "5432"),
|
||||
dbname=os.environ["POSTGRES_DB"],
|
||||
)
|
||||
engine = create_engine(url)
|
||||
with engine.connect() as conn:
|
||||
return bool(
|
||||
conn.execute(
|
||||
text(
|
||||
"SELECT 1 FROM information_schema.schemata WHERE schema_name = 'elementary'"
|
||||
)
|
||||
).scalar()
|
||||
)
|
||||
engine = create_engine(
|
||||
url,
|
||||
pool_pre_ping=True,
|
||||
connect_args={"connect_timeout": 10},
|
||||
)
|
||||
|
||||
from data_platform.resources import _retry_on_operational_error
|
||||
|
||||
def _query():
|
||||
with engine.connect() as conn:
|
||||
return bool(
|
||||
conn.execute(
|
||||
text(
|
||||
"SELECT 1 FROM information_schema.schemata WHERE schema_name = 'elementary'"
|
||||
)
|
||||
).scalar()
|
||||
)
|
||||
|
||||
return _retry_on_operational_error(_query)
|
||||
|
||||
|
||||
@op
|
||||
|
||||
@@ -1,9 +1,36 @@
|
||||
"""Dagster resources."""
|
||||
|
||||
from dagster import ConfigurableResource, EnvVar
|
||||
import time
|
||||
|
||||
from dagster import ConfigurableResource, EnvVar, get_dagster_logger
|
||||
from funda import Funda
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.pool import NullPool
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
logger = get_dagster_logger()
|
||||
|
||||
# Retry policy for transient database connection failures.
_RETRY_ATTEMPTS = 5  # total attempts before the error is re-raised to the caller
_RETRY_BASE_DELAY = 1  # seconds; doubles each attempt
|
||||
|
||||
|
||||
def _retry_on_operational_error(
    fn, *, attempts=_RETRY_ATTEMPTS, base_delay=_RETRY_BASE_DELAY
):
    """Call *fn*, retrying on SQLAlchemy OperationalError with exponential back-off.

    The delay starts at *base_delay* seconds and doubles after every failed
    attempt.  On the final attempt the exception propagates to the caller.
    Returns whatever *fn* returns on the first successful call.
    """
    for attempt in range(1, attempts + 1):
        try:
            return fn()
        except OperationalError:
            # Out of attempts: let the caller see the original error.
            if attempt == attempts:
                raise
            wait = base_delay * 2 ** (attempt - 1)
            logger.warning(
                "DB connection attempt %d/%d failed, retrying in %ds …",
                attempt,
                attempts,
                wait,
            )
            time.sleep(wait)
|
||||
|
||||
|
||||
class FundaResource(ConfigurableResource):
|
||||
@@ -26,17 +53,31 @@ class PostgresResource(ConfigurableResource):
|
||||
|
||||
def get_engine(self):
    """Build a SQLAlchemy engine for this Postgres resource.

    Uses a small connection pool with ``pool_pre_ping`` so stale pooled
    connections are detected and replaced before use, and a 10 s connect
    timeout so a hung database cannot block the caller indefinitely.
    (The previous ``poolclass=NullPool`` return was leftover diff residue
    that made this configuration unreachable; it has been removed.)
    """
    url = f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.dbname}"
    return create_engine(
        url,
        pool_pre_ping=True,  # validate connections before handing them out
        pool_size=2,
        max_overflow=3,
        connect_args={"connect_timeout": 10},
    )
|
||||
|
||||
def execute(self, statement: str, params: dict | None = None):
    """Execute *statement* inside a transaction, retrying on OperationalError.

    Parameters
    ----------
    statement:
        SQL text to execute.
    params:
        Optional bind parameters; an empty mapping is used when omitted.

    The previous body ran the statement once directly and then again under
    the retry wrapper (diff residue), executing it twice; only the retried
    path is kept.
    """
    engine = self.get_engine()

    def _run():
        # engine.begin() commits on success and rolls back on error.
        with engine.begin() as conn:
            conn.execute(text(statement), params or {})

    _retry_on_operational_error(_run)
|
||||
|
||||
def execute_many(self, statement: str, rows: list[dict]):
    """Execute *statement* once per entry of *rows* in a single transaction.

    Parameters
    ----------
    statement:
        SQL text with bind parameters matching the keys of each row dict.
    rows:
        List of parameter mappings; SQLAlchemy performs an executemany.

    The previous body ran the batch once directly and then again under the
    retry wrapper (diff residue), inserting every row twice; only the
    retried path is kept.
    """
    engine = self.get_engine()

    def _run():
        # engine.begin() commits on success and rolls back on error.
        with engine.begin() as conn:
            conn.execute(text(statement), rows)

    _retry_on_operational_error(_run)
|
||||
|
||||
|
||||
class MLflowResource(ConfigurableResource):
|
||||
|
||||
@@ -2,11 +2,15 @@
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
from data_platform.resources import (
|
||||
DiscordResource,
|
||||
FundaResource,
|
||||
MLflowResource,
|
||||
PostgresResource,
|
||||
_retry_on_operational_error,
|
||||
)
|
||||
|
||||
|
||||
@@ -59,13 +63,45 @@ class TestPostgresResource:
|
||||
call_url = mock_create.call_args[0][0]
|
||||
assert call_url.startswith("postgresql://")
|
||||
|
||||
def test_engine_uses_pool_pre_ping(self):
    """The engine must be created with pool_pre_ping enabled."""
    resource = self._make_resource()
    with patch("data_platform.resources.create_engine") as fake_create:
        fake_create.return_value = MagicMock()
        resource.get_engine()
        _, kwargs = fake_create.call_args
        assert kwargs["pool_pre_ping"] is True
|
||||
|
||||
def test_engine_sets_connect_timeout(self):
    """The engine must pass a 10 s connect_timeout to the DB driver."""
    resource = self._make_resource()
    with patch("data_platform.resources.create_engine") as fake_create:
        fake_create.return_value = MagicMock()
        resource.get_engine()
        _, kwargs = fake_create.call_args
        assert kwargs["connect_args"]["connect_timeout"] == 10
|
||||
|
||||
def test_execute_retries_on_operational_error(self):
    """execute() retries after a transient OperationalError and then succeeds."""
    fake_conn = MagicMock()
    fake_conn.execute.side_effect = [
        OperationalError("conn", {}, Exception("DNS failure")),
        None,  # second attempt succeeds
    ]
    fake_engine = MagicMock()
    ctx = fake_engine.begin.return_value
    ctx.__enter__ = MagicMock(return_value=fake_conn)
    ctx.__exit__ = MagicMock(return_value=False)

    with (
        patch("data_platform.resources.create_engine", return_value=fake_engine),
        patch("data_platform.resources.time.sleep"),
    ):
        resource = self._make_resource()
        resource.execute("SELECT 1")
|
||||
|
||||
def test_execute_calls_engine_begin(self):
|
||||
mock_engine = MagicMock()
|
||||
mock_conn = MagicMock()
|
||||
mock_engine.begin.return_value.__enter__ = MagicMock(return_value=mock_conn)
|
||||
mock_engine.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
# Patch create_engine at module level so that get_engine() returns our mock
|
||||
with patch("data_platform.resources.create_engine", return_value=mock_engine):
|
||||
res = self._make_resource()
|
||||
res.execute("SELECT 1")
|
||||
@@ -86,6 +122,51 @@ class TestPostgresResource:
|
||||
mock_conn.execute.assert_called_once()
|
||||
|
||||
|
||||
class TestRetryOnOperationalError:
    """Unit tests for the exponential back-off retry helper."""

    def test_succeeds_on_first_attempt(self):
        probe = MagicMock(return_value="ok")
        assert _retry_on_operational_error(probe, attempts=3, base_delay=0) == "ok"
        probe.assert_called_once()

    @patch("data_platform.resources.time.sleep")
    def test_retries_then_succeeds(self, mock_sleep):
        probe = MagicMock(
            side_effect=[
                OperationalError("conn", {}, Exception("DNS failure")),
                "ok",
            ]
        )
        assert _retry_on_operational_error(probe, attempts=3, base_delay=1) == "ok"
        assert probe.call_count == 2
        mock_sleep.assert_called_once_with(1)

    @patch("data_platform.resources.time.sleep")
    def test_raises_after_all_attempts_exhausted(self, mock_sleep):
        probe = MagicMock(
            side_effect=OperationalError("conn", {}, Exception("DNS failure"))
        )
        with pytest.raises(OperationalError):
            _retry_on_operational_error(probe, attempts=3, base_delay=1)
        assert probe.call_count == 3

    @patch("data_platform.resources.time.sleep")
    def test_exponential_backoff(self, mock_sleep):
        boom = OperationalError("conn", {}, Exception("DNS failure"))
        probe = MagicMock(side_effect=[boom, boom, "ok"])
        _retry_on_operational_error(probe, attempts=5, base_delay=1)
        # Delays double: 1 s after the first failure, 2 s after the second.
        assert mock_sleep.call_args_list == [
            ((1,),),
            ((2,),),
        ]
|
||||
|
||||
|
||||
class TestMLflowResource:
|
||||
def test_tracking_uri(self):
|
||||
resource = MLflowResource(tracking_uri="http://mlflow:5000")
|
||||
|
||||
Reference in New Issue
Block a user