fix: postgres timeout issues during heavy load
This commit is contained in:
@@ -6,7 +6,7 @@ POSTGRES_PASSWORD=changeme
|
|||||||
POSTGRES_DB=dagster
|
POSTGRES_DB=dagster
|
||||||
|
|
||||||
# Dagster metadata storage (same postgres instance)
|
# Dagster metadata storage (same postgres instance)
|
||||||
DAGSTER_POSTGRES_URL=postgresql://dagster:changeme@postgres:5432/dagster
|
DAGSTER_POSTGRES_URL=postgresql://dagster:changeme@postgres:5432/dagster?connect_timeout=10&keepalives=1&keepalives_idle=30&keepalives_interval=10&keepalives_count=5
|
||||||
|
|
||||||
# dbt profile target
|
# dbt profile target
|
||||||
DBT_TARGET=dev
|
DBT_TARGET=dev
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ storage:
|
|||||||
postgres:
|
postgres:
|
||||||
postgres_url:
|
postgres_url:
|
||||||
env: DAGSTER_POSTGRES_URL
|
env: DAGSTER_POSTGRES_URL
|
||||||
pool_size: 5
|
|
||||||
max_overflow: 5
|
|
||||||
|
|
||||||
# Limit concurrent runs to avoid overwhelming the VM and database.
|
# Limit concurrent runs to avoid overwhelming the VM and database.
|
||||||
concurrency:
|
concurrency:
|
||||||
|
|||||||
@@ -18,7 +18,15 @@ def _elementary_schema_exists() -> bool:
|
|||||||
port=os.environ.get("POSTGRES_PORT", "5432"),
|
port=os.environ.get("POSTGRES_PORT", "5432"),
|
||||||
dbname=os.environ["POSTGRES_DB"],
|
dbname=os.environ["POSTGRES_DB"],
|
||||||
)
|
)
|
||||||
engine = create_engine(url)
|
engine = create_engine(
|
||||||
|
url,
|
||||||
|
pool_pre_ping=True,
|
||||||
|
connect_args={"connect_timeout": 10},
|
||||||
|
)
|
||||||
|
|
||||||
|
from data_platform.resources import _retry_on_operational_error
|
||||||
|
|
||||||
|
def _query():
|
||||||
with engine.connect() as conn:
|
with engine.connect() as conn:
|
||||||
return bool(
|
return bool(
|
||||||
conn.execute(
|
conn.execute(
|
||||||
@@ -28,6 +36,8 @@ def _elementary_schema_exists() -> bool:
|
|||||||
).scalar()
|
).scalar()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return _retry_on_operational_error(_query)
|
||||||
|
|
||||||
|
|
||||||
@op
|
@op
|
||||||
def elementary_run_models(context: OpExecutionContext) -> None:
|
def elementary_run_models(context: OpExecutionContext) -> None:
|
||||||
|
|||||||
@@ -1,9 +1,36 @@
|
|||||||
"""Dagster resources."""
|
"""Dagster resources."""
|
||||||
|
|
||||||
from dagster import ConfigurableResource, EnvVar
|
import time
|
||||||
|
|
||||||
|
from dagster import ConfigurableResource, EnvVar, get_dagster_logger
|
||||||
from funda import Funda
|
from funda import Funda
|
||||||
from sqlalchemy import create_engine, text
|
from sqlalchemy import create_engine, text
|
||||||
from sqlalchemy.pool import NullPool
|
from sqlalchemy.exc import OperationalError
|
||||||
|
|
||||||
|
logger = get_dagster_logger()
|
||||||
|
|
||||||
|
_RETRY_ATTEMPTS = 5
|
||||||
|
_RETRY_BASE_DELAY = 1 # seconds; doubles each attempt
|
||||||
|
|
||||||
|
|
||||||
|
def _retry_on_operational_error(
|
||||||
|
fn, *, attempts=_RETRY_ATTEMPTS, base_delay=_RETRY_BASE_DELAY
|
||||||
|
):
|
||||||
|
"""Retry *fn* with exponential back-off on SQLAlchemy OperationalError."""
|
||||||
|
for attempt in range(1, attempts + 1):
|
||||||
|
try:
|
||||||
|
return fn()
|
||||||
|
except OperationalError:
|
||||||
|
if attempt == attempts:
|
||||||
|
raise
|
||||||
|
delay = base_delay * 2 ** (attempt - 1)
|
||||||
|
logger.warning(
|
||||||
|
"DB connection attempt %d/%d failed, retrying in %ds …",
|
||||||
|
attempt,
|
||||||
|
attempts,
|
||||||
|
delay,
|
||||||
|
)
|
||||||
|
time.sleep(delay)
|
||||||
|
|
||||||
|
|
||||||
class FundaResource(ConfigurableResource):
|
class FundaResource(ConfigurableResource):
|
||||||
@@ -26,18 +53,32 @@ class PostgresResource(ConfigurableResource):
|
|||||||
|
|
||||||
def get_engine(self):
|
def get_engine(self):
|
||||||
url = f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.dbname}"
|
url = f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.dbname}"
|
||||||
return create_engine(url, poolclass=NullPool)
|
return create_engine(
|
||||||
|
url,
|
||||||
|
pool_pre_ping=True,
|
||||||
|
pool_size=2,
|
||||||
|
max_overflow=3,
|
||||||
|
connect_args={"connect_timeout": 10},
|
||||||
|
)
|
||||||
|
|
||||||
def execute(self, statement: str, params: dict | None = None):
|
def execute(self, statement: str, params: dict | None = None):
|
||||||
engine = self.get_engine()
|
engine = self.get_engine()
|
||||||
|
|
||||||
|
def _run():
|
||||||
with engine.begin() as conn:
|
with engine.begin() as conn:
|
||||||
conn.execute(text(statement), params or {})
|
conn.execute(text(statement), params or {})
|
||||||
|
|
||||||
|
_retry_on_operational_error(_run)
|
||||||
|
|
||||||
def execute_many(self, statement: str, rows: list[dict]):
|
def execute_many(self, statement: str, rows: list[dict]):
|
||||||
engine = self.get_engine()
|
engine = self.get_engine()
|
||||||
|
|
||||||
|
def _run():
|
||||||
with engine.begin() as conn:
|
with engine.begin() as conn:
|
||||||
conn.execute(text(statement), rows)
|
conn.execute(text(statement), rows)
|
||||||
|
|
||||||
|
_retry_on_operational_error(_run)
|
||||||
|
|
||||||
|
|
||||||
class MLflowResource(ConfigurableResource):
|
class MLflowResource(ConfigurableResource):
|
||||||
"""MLflow experiment tracking resource."""
|
"""MLflow experiment tracking resource."""
|
||||||
|
|||||||
@@ -2,11 +2,15 @@
|
|||||||
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy.exc import OperationalError
|
||||||
|
|
||||||
from data_platform.resources import (
|
from data_platform.resources import (
|
||||||
DiscordResource,
|
DiscordResource,
|
||||||
FundaResource,
|
FundaResource,
|
||||||
MLflowResource,
|
MLflowResource,
|
||||||
PostgresResource,
|
PostgresResource,
|
||||||
|
_retry_on_operational_error,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -59,13 +63,45 @@ class TestPostgresResource:
|
|||||||
call_url = mock_create.call_args[0][0]
|
call_url = mock_create.call_args[0][0]
|
||||||
assert call_url.startswith("postgresql://")
|
assert call_url.startswith("postgresql://")
|
||||||
|
|
||||||
|
def test_engine_uses_pool_pre_ping(self):
|
||||||
|
res = self._make_resource()
|
||||||
|
with patch("data_platform.resources.create_engine") as mock_create:
|
||||||
|
mock_create.return_value = MagicMock()
|
||||||
|
res.get_engine()
|
||||||
|
kwargs = mock_create.call_args[1]
|
||||||
|
assert kwargs["pool_pre_ping"] is True
|
||||||
|
|
||||||
|
def test_engine_sets_connect_timeout(self):
|
||||||
|
res = self._make_resource()
|
||||||
|
with patch("data_platform.resources.create_engine") as mock_create:
|
||||||
|
mock_create.return_value = MagicMock()
|
||||||
|
res.get_engine()
|
||||||
|
kwargs = mock_create.call_args[1]
|
||||||
|
assert kwargs["connect_args"]["connect_timeout"] == 10
|
||||||
|
|
||||||
|
def test_execute_retries_on_operational_error(self):
|
||||||
|
mock_engine = MagicMock()
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_conn.execute.side_effect = [
|
||||||
|
OperationalError("conn", {}, Exception("DNS failure")),
|
||||||
|
None,
|
||||||
|
]
|
||||||
|
mock_engine.begin.return_value.__enter__ = MagicMock(return_value=mock_conn)
|
||||||
|
mock_engine.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("data_platform.resources.create_engine", return_value=mock_engine),
|
||||||
|
patch("data_platform.resources.time.sleep"),
|
||||||
|
):
|
||||||
|
res = self._make_resource()
|
||||||
|
res.execute("SELECT 1")
|
||||||
|
|
||||||
def test_execute_calls_engine_begin(self):
|
def test_execute_calls_engine_begin(self):
|
||||||
mock_engine = MagicMock()
|
mock_engine = MagicMock()
|
||||||
mock_conn = MagicMock()
|
mock_conn = MagicMock()
|
||||||
mock_engine.begin.return_value.__enter__ = MagicMock(return_value=mock_conn)
|
mock_engine.begin.return_value.__enter__ = MagicMock(return_value=mock_conn)
|
||||||
mock_engine.begin.return_value.__exit__ = MagicMock(return_value=False)
|
mock_engine.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||||
|
|
||||||
# Patch create_engine at module level so that get_engine() returns our mock
|
|
||||||
with patch("data_platform.resources.create_engine", return_value=mock_engine):
|
with patch("data_platform.resources.create_engine", return_value=mock_engine):
|
||||||
res = self._make_resource()
|
res = self._make_resource()
|
||||||
res.execute("SELECT 1")
|
res.execute("SELECT 1")
|
||||||
@@ -86,6 +122,51 @@ class TestPostgresResource:
|
|||||||
mock_conn.execute.assert_called_once()
|
mock_conn.execute.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
class TestRetryOnOperationalError:
|
||||||
|
def test_succeeds_on_first_attempt(self):
|
||||||
|
fn = MagicMock(return_value="ok")
|
||||||
|
result = _retry_on_operational_error(fn, attempts=3, base_delay=0)
|
||||||
|
assert result == "ok"
|
||||||
|
assert fn.call_count == 1
|
||||||
|
|
||||||
|
@patch("data_platform.resources.time.sleep")
|
||||||
|
def test_retries_then_succeeds(self, mock_sleep):
|
||||||
|
fn = MagicMock(
|
||||||
|
side_effect=[
|
||||||
|
OperationalError("conn", {}, Exception("DNS failure")),
|
||||||
|
"ok",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
result = _retry_on_operational_error(fn, attempts=3, base_delay=1)
|
||||||
|
assert result == "ok"
|
||||||
|
assert fn.call_count == 2
|
||||||
|
mock_sleep.assert_called_once_with(1)
|
||||||
|
|
||||||
|
@patch("data_platform.resources.time.sleep")
|
||||||
|
def test_raises_after_all_attempts_exhausted(self, mock_sleep):
|
||||||
|
fn = MagicMock(
|
||||||
|
side_effect=OperationalError("conn", {}, Exception("DNS failure"))
|
||||||
|
)
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
_retry_on_operational_error(fn, attempts=3, base_delay=1)
|
||||||
|
assert fn.call_count == 3
|
||||||
|
|
||||||
|
@patch("data_platform.resources.time.sleep")
|
||||||
|
def test_exponential_backoff(self, mock_sleep):
|
||||||
|
fn = MagicMock(
|
||||||
|
side_effect=[
|
||||||
|
OperationalError("conn", {}, Exception("DNS failure")),
|
||||||
|
OperationalError("conn", {}, Exception("DNS failure")),
|
||||||
|
"ok",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
_retry_on_operational_error(fn, attempts=5, base_delay=1)
|
||||||
|
assert mock_sleep.call_args_list == [
|
||||||
|
((1,),),
|
||||||
|
((2,),),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class TestMLflowResource:
|
class TestMLflowResource:
|
||||||
def test_tracking_uri(self):
|
def test_tracking_uri(self):
|
||||||
resource = MLflowResource(tracking_uri="http://mlflow:5000")
|
resource = MLflowResource(tracking_uri="http://mlflow:5000")
|
||||||
|
|||||||
Reference in New Issue
Block a user