diff --git a/data_platform/ops/elementary.py b/data_platform/ops/elementary.py
index 83525b5..e1f024f 100644
--- a/data_platform/ops/elementary.py
+++ b/data_platform/ops/elementary.py
@@ -10,7 +10,20 @@ from sqlalchemy import create_engine, text
 _DBT_DIR = Path(__file__).parents[2] / "dbt"
 
 
-def _elementary_schema_exists() -> bool:
+# Rows older than this many days are purged before each report run.
+_DAYS_BACK = 3
+
+# Elementary tables trimmed by _cleanup_old_elementary_data().
+_CLEANUP_TABLES = [
+    "elementary_test_results",
+    "dbt_run_results",
+    "dbt_invocations",
+    "dbt_source_freshness_results",
+]
+
+
+def _get_engine():
+    """Build a SQLAlchemy engine from the POSTGRES_* environment variables."""
     url = "postgresql://{user}:{password}@{host}:{port}/{dbname}".format(
         user=os.environ["POSTGRES_USER"],
         password=os.environ["POSTGRES_PASSWORD"],
@@ -18,12 +31,16 @@ def _elementary_schema_exists() -> bool:
         port=os.environ.get("POSTGRES_PORT", "5432"),
         dbname=os.environ["POSTGRES_DB"],
     )
-    engine = create_engine(
+    return create_engine(
         url,
         pool_pre_ping=True,
         connect_args={"connect_timeout": 10},
     )
 
+
+def _elementary_schema_exists() -> bool:
+    engine = _get_engine()
+
     from data_platform.resources import _retry_on_operational_error
 
     def _query():
@@ -69,9 +86,49 @@ def elementary_run_models(context: OpExecutionContext) -> None:
         raise Exception(f"dbt run elementary failed with exit code {returncode}")
 
 
+def _cleanup_old_elementary_data(context: OpExecutionContext) -> None:
+    """Delete Elementary rows older than ``_DAYS_BACK`` days.
+
+    Every table in ``_CLEANUP_TABLES`` is trimmed inside one transaction so
+    report generation has less history to load, which prevents OOM.
+
+    Args:
+        context: Op context, used here only for logging.
+    """
+    engine = _get_engine()
+    total = 0
+    try:
+        # engine.begin() commits on success and rolls back on any error.
+        with engine.begin() as conn:
+            for table in _CLEANUP_TABLES:
+                result = conn.execute(
+                    text(
+                        f"DELETE FROM elementary.{table} "  # noqa: S608
+                        f"WHERE created_at < now() - interval '{_DAYS_BACK} days'"
+                    )
+                )
+                # DB-API drivers may report -1 when the count is unknown.
+                deleted = result.rowcount
+                if deleted > 0:
+                    context.log.info(
+                        f"Cleaned up {deleted} old rows from elementary.{table}"
+                    )
+                    total += deleted
+    finally:
+        # A fresh engine (and pool) is built on every call; release its
+        # connections instead of leaving them open until garbage collection.
+        engine.dispose()
+    if total:
+        context.log.info(f"Total rows cleaned: {total}")
+    else:
+        context.log.info("No old elementary data to clean up.")
+
+
 @op(ins={"after": In(Nothing)})
 def elementary_generate_report(context: OpExecutionContext) -> None:
     """Run edr report to regenerate the Elementary HTML report."""
+    _cleanup_old_elementary_data(context)
+
     report_path = (
         Path(__file__).parents[2] / "dbt" / "edr_target" / "elementary_report.html"
     )
diff --git a/tests/test_ops.py b/tests/test_ops.py
index dbb23f2..f67f918 100644
--- a/tests/test_ops.py
+++ b/tests/test_ops.py
@@ -9,6 +9,7 @@ from data_platform.ops.check_source_freshness import (
     SourceFreshnessConfig,
 )
 from data_platform.ops.elementary import (
+    _cleanup_old_elementary_data,
     _elementary_schema_exists,
     elementary_generate_report,
     elementary_run_models,
@@ -108,21 +109,63 @@ class TestElementaryRunModels:
-# elementary_generate_report
+# _cleanup_old_elementary_data
 
 
+class TestCleanupOldElementaryData:
+    @patch("data_platform.ops.elementary._get_engine")
+    def test_deletes_old_rows(self, mock_get_engine):
+        from unittest.mock import MagicMock
+
+        mock_conn = MagicMock()
+        mock_result = MagicMock()
+        mock_result.rowcount = 5
+        mock_conn.execute.return_value = mock_result
+        mock_engine = MagicMock()
+        mock_engine.begin.return_value.__enter__ = lambda _: mock_conn
+        mock_engine.begin.return_value.__exit__ = MagicMock(return_value=False)
+        mock_get_engine.return_value = mock_engine
+
+        context = build_op_context()
+        _cleanup_old_elementary_data(context)
+        assert mock_conn.execute.call_count == 4
+        mock_engine.dispose.assert_called_once()
+
+    @patch("data_platform.ops.elementary._get_engine")
+    def test_logs_when_no_rows_deleted(self, mock_get_engine):
+        from unittest.mock import MagicMock
+
+        mock_conn = MagicMock()
+        mock_result = MagicMock()
+        mock_result.rowcount = 0
+        mock_conn.execute.return_value = mock_result
+        mock_engine = MagicMock()
+        mock_engine.begin.return_value.__enter__ = lambda _: mock_conn
+        mock_engine.begin.return_value.__exit__ = MagicMock(return_value=False)
+        mock_get_engine.return_value = mock_engine
+
+        context = build_op_context()
+        _cleanup_old_elementary_data(context)
+        assert mock_conn.execute.call_count == 4
+
+
+# elementary_generate_report
+
+
+@patch("data_platform.ops.elementary._cleanup_old_elementary_data")
 class TestElementaryGenerateReport:
     @patch("data_platform.ops.elementary.subprocess.Popen")
-    def test_calls_edr_report(self, mock_popen):
+    def test_calls_edr_report(self, mock_popen, mock_cleanup):
         mock_popen.return_value = _mock_popen(
             returncode=0, stdout_lines=["report generated\n"]
         )
         context = build_op_context()
         elementary_generate_report(context)
+        mock_cleanup.assert_called_once()
         mock_popen.assert_called_once()
         args = mock_popen.call_args[0][0]
         assert "edr" in args
         assert "report" in args
 
     @patch("data_platform.ops.elementary.subprocess.Popen")
-    def test_raises_on_failure(self, mock_popen):
+    def test_raises_on_failure(self, mock_popen, mock_cleanup):
         mock_popen.return_value = _mock_popen(
             returncode=1, stdout_lines=["fatal error\n"]
         )
@@ -131,7 +174,7 @@ class TestElementaryGenerateReport:
             elementary_generate_report(context)
 
     @patch("data_platform.ops.elementary.subprocess.Popen")
-    def test_success_returns_none(self, mock_popen):
+    def test_success_returns_none(self, mock_popen, mock_cleanup):
         mock_popen.return_value = _mock_popen(returncode=0, stdout_lines=["done\n"])
         context = build_op_context()
         result = elementary_generate_report(context)