sibi-dst 2025.9.3__py3-none-any.whl → 2025.9.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
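For orientation, the sketch below shows how the revised `SQLAlchemyDask` constructor and `read_frame()` in 2025.9.5 might be driven, based on the new keyword-only signature visible in the diff that follows. The import path, ORM model, engine URL, filter value, and omitted `ManagedResource` keyword arguments are illustrative assumptions, not taken from the package.

```python
# Hypothetical usage sketch for the new SQLAlchemyDask signature (assumptions noted inline).
import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

# Assumed import path -- the actual module location inside sibi_dst is not shown in this diff.
from sibi_dst.df_helper.backends.sqlalchemy import SQLAlchemyDask  # hypothetical path


class Base(DeclarativeBase):
    pass


class Order(Base):  # illustrative model, not part of sibi-dst
    __tablename__ = "orders"
    id: Mapped[int] = mapped_column(primary_key=True)
    status: Mapped[str] = mapped_column(sa.String(32))


engine = sa.create_engine("postgresql+psycopg2://user:pass@host/db")  # placeholder URL

loader = SQLAlchemyDask(
    Order,
    engine=engine,                 # keyword-only from 2025.9.5 onward
    filters={"status": "active"},  # optional; exact format depends on FilterHandler
    chunk_size=50_000,             # rows per LIMIT/OFFSET page
    pagination="offset",           # "range" would additionally require index_col
    # ManagedResource **kwargs (e.g. logger/debug) omitted in this sketch
)
total_records, ddf = loader.read_frame()  # (row count, dask.dataframe.DataFrame)
```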
@@ -1,187 +1,394 @@
  from __future__ import annotations

- import time
- from typing import Any, Dict, Tuple, Type
+ from contextlib import contextmanager
+ from typing import Any, Dict, Optional, Tuple, Type

  import dask
  import dask.dataframe as dd
  import pandas as pd
  import sqlalchemy as sa
- from sqlalchemy import select, inspect
+ from sqlalchemy import func, inspect, select
  from sqlalchemy.engine import Engine
- from sqlalchemy.exc import TimeoutError as SASQLTimeoutError, OperationalError
+ from sqlalchemy.exc import OperationalError
+ from sqlalchemy.exc import TimeoutError as SASQLTimeoutError
  from sqlalchemy.orm import declarative_base

- from sibi_dst.utils import ManagedResource
  from sibi_dst.df_helper.core import FilterHandler
+ from sibi_dst.utils import ManagedResource
  from ._db_gatekeeper import DBGatekeeper


  class SQLAlchemyDask(ManagedResource):
  """
- Loads data from a database into a Dask DataFrame using a memory-safe,
- non-parallel, paginated approach (LIMIT/OFFSET).
+ Production-grade DB -> Dask loader with robust, error-coercing type alignment
+ and concurrent-safe database access via DBGatekeeper.
  """

  _SQLALCHEMY_TO_DASK_DTYPE: Dict[str, str] = {
- "INTEGER": "Int64",
- "SMALLINT": "Int64",
- "BIGINT": "Int64",
- "FLOAT": "float64",
- "NUMERIC": "float64",
- "BOOLEAN": "bool",
- "VARCHAR": "object",
- "TEXT": "object",
- "DATE": "datetime64[ns]",
- "DATETIME": "datetime64[ns]",
- "TIMESTAMP": "datetime64[ns]",
- "TIME": "object",
- "UUID": "object",
+ "INTEGER": "Int64", "SMALLINT": "Int64", "BIGINT": "Int64",
+ "FLOAT": "float64", "DOUBLE": "float64",
+ "NUMERIC": "string", "DECIMAL": "string",
+ "BOOLEAN": "boolean",
+ "VARCHAR": "string", "CHAR": "string", "TEXT": "string", "UUID": "string",
+ "DATE": "datetime64[ns]", "DATETIME": "datetime64[ns]", "TIMESTAMP": "datetime64[ns]",
+ "TIME": "string",
  }
+
  logger_extra: Dict[str, Any] = {"sibi_dst_component": __name__}

  def __init__(
- self,
- model: Type[declarative_base()],
- filters: Dict[str, Any],
- engine: Engine,
- chunk_size: int = 1000,
- **kwargs: Any,
+ self,
+ model: Type[declarative_base()],
+ *,
+ engine: Engine,
+ filters: Optional[Dict[str, Any]] = None,
+ chunk_size: int = 50_000,
+ pagination: str = "offset",
+ index_col: Optional[str] = None,
+ num_workers: int = 1,
+ **kwargs: Any,
  ):
  super().__init__(**kwargs)
+ if pagination not in {"offset", "range"}:
+ raise ValueError("pagination must be 'offset' or 'range'.")
+ if pagination == "range" and not index_col:
+ raise ValueError("pagination='range' requires index_col.")
+
  self.model = model
- self.filters = filters or {}
  self.engine = engine
+ self.filters = filters or {}
  self.chunk_size = int(chunk_size)
+ self.pagination = pagination
+ self.index_col = index_col
+ self.num_workers = int(num_workers)
  self.filter_handler_cls = FilterHandler
- self.total_records: int = -1 # -1 indicates failure/unknown
- self._sem = DBGatekeeper.get(str(engine.url), max_concurrency=self._safe_cap())
+ self.total_records: int = -1

- def _safe_cap(self) -> int:
- """
- Calculate a safe concurrency cap for DB work based on the engine's pool.
+ # --- DBGatekeeper Initialization (Re-integrated) ---
+ pool_size, max_overflow = self._engine_pool_limits()
+ pool_capacity = max(1, pool_size + max_overflow)
+ per_proc_cap = max(1, pool_capacity // max(1, self.num_workers))
+ cap = per_proc_cap # Can be overridden by an explicit db_gatekeeper_cap attribute
+ gate_key = self._normalized_engine_key(self.engine)
+ self._sem = DBGatekeeper.get(gate_key, max_concurrency=cap)
+ self.logger.debug(f"DBGatekeeper initialized with max_concurrency={cap}")

- Returns: max(1, pool_size + max_overflow - 1)
- - Works across SQLAlchemy 1.4/2.x
- - Tolerates pools that expose size/max_overflow as methods or attrs
- - Allows explicit override via self.db_gatekeeper_cap (if you pass it)
- """
- # optional explicit override
- explicit = getattr(self, "db_gatekeeper_cap", None)
- if isinstance(explicit, int) and explicit > 0:
- return explicit
+ self._ordered_columns = [c.name for c in self.model.__table__.columns]
+ self._meta_dtypes = self.infer_meta_from_model(self.model)
+ self._meta_df = self._build_meta()

- pool = getattr(self.engine, "pool", None)
-
- def _to_int(val, default):
- if val is None:
- return default
- if callable(val):
- try:
- return int(val()) # e.g., pool.size()
- except Exception:
- return default
- try:
- return int(val)
- except Exception:
- return default
-
- # size: QueuePool.size() -> int
- size_candidate = getattr(pool, "size", None) # method on QueuePool
- pool_size = _to_int(size_candidate, 5)
-
- # max_overflow: prefer attribute; fall back to private _max_overflow; avoid 'overflow()' (method)
- max_overflow_attr = (
- getattr(pool, "max_overflow", None) or # SQLAlchemy 2.x QueuePool
- getattr(pool, "_max_overflow", None) # private fallback
- )
- max_overflow = _to_int(max_overflow_attr, 10)
-
- cap = max(1, pool_size + max_overflow - 1)
- self.logger.debug(f"Using a Cap of {cap} from pool size of {pool_size} and max overflow of {max_overflow}.", extra=self.logger_extra)
- return max(1, cap)
-
- # ---------- meta ----------
  @classmethod
  def infer_meta_from_model(cls, model: Type[declarative_base()]) -> Dict[str, str]:
+ # (This method is unchanged)
  mapper = inspect(model)
  dtypes: Dict[str, str] = {}
  for column in mapper.columns:
  dtype_str = str(column.type).upper().split("(")[0]
- dtype = cls._SQLALCHEMY_TO_DASK_DTYPE.get(dtype_str, "object")
- dtypes[column.name] = dtype
+ dtypes[column.name] = cls._SQLALCHEMY_TO_DASK_DTYPE.get(dtype_str, "string")
  return dtypes

- def read_frame(self, fillna_value=None) -> Tuple[int, dd.DataFrame]:
- # Base selectable
- query = select(self.model)
- if self.filters:
- query = self.filter_handler_cls(
- backend="sqlalchemy", logger=self.logger, debug=self.debug
- ).apply_filters(query, model=self.model, filters=self.filters)
- else:
- query = query.limit(self.chunk_size)
-
- # Meta dataframe (stable column order & dtypes)
- ordered_columns = [c.name for c in self.model.__table__.columns]
- meta_dtypes = self.infer_meta_from_model(self.model)
- meta_df = pd.DataFrame(columns=ordered_columns).astype(meta_dtypes)
-
- # Count with retry/backoff
- retry_attempts = 3
- backoff = 0.5
- total = 0
-
- for attempt in range(retry_attempts):
+ def _build_meta(self) -> pd.DataFrame:
+ # (This method is unchanged)
+ return pd.DataFrame({col: pd.Series(dtype=dtype) for col, dtype in self._meta_dtypes.items()})
+
+ @contextmanager
+ def _conn(self):
+ """Provides a managed, concurrent-safe database connection using the semaphore."""
+ with self._sem:
+ with self.engine.connect() as c:
+ yield c
+
+ def _fetch_with_retry(self, sql: sa.sql.Select) -> pd.DataFrame:
+ """Fetches a data chunk using the concurrent-safe connection."""
+ try:
+ with self._conn() as conn:
+ df = pd.read_sql_query(sql, conn, dtype_backend="pyarrow")
+ return self._align_and_coerce_partition(df)
+ except (SASQLTimeoutError, OperationalError) as e:
+ self.logger.error(f"Chunk fetch failed due to {e.__class__.__name__}", exc_info=True,
+ extra=self.logger_extra)
+ # Return empty but correctly typed DataFrame on failure
+ return self._meta_df.copy()
+
+ def _align_and_coerce_partition(self, df: pd.DataFrame) -> pd.DataFrame:
+ """
+ Aligns DataFrame partition to expected dtypes, coercing errors to nulls.
+ Explicitly handles PyArrow timestamps by converting to numpy arrays.
+ """
+ output_df = pd.DataFrame(index=df.index)
+
+ for col, target_dtype in self._meta_dtypes.items():
+ if col not in df.columns:
+ # Add missing column as nulls of the target type
+ output_df[col] = pd.Series(pd.NA, index=df.index, dtype=target_dtype)
+ continue
+
+ source_series = df[col]
  try:
- with self._sem:
- with self.engine.connect() as connection:
- count_q = sa.select(sa.func.count()).select_from(query.alias())
- total = connection.execute(count_q).scalar_one()
- break
- except SASQLTimeoutError:
- if attempt < retry_attempts - 1:
- self.logger.warning(f"Connection pool limit reached. Retrying in {backoff} seconds...", extra=self.logger_extra)
- time.sleep(backoff)
- backoff *= 2
+ if target_dtype == "datetime64[ns]":
+ # Convert to datetime, coercing errors to NaT
+ coerced_series = pd.to_datetime(source_series, errors='coerce')
+ # Ensure numpy backend by creating a new Series from values
+ output_df[col] = pd.Series(coerced_series.to_numpy(), index=coerced_series.index)
+ elif target_dtype == "Int64":
+ output_df[col] = pd.to_numeric(source_series, errors='coerce').astype("Int64")
+ elif target_dtype == "boolean":
+ # Handle boolean conversion with explicit mapping
+ if pd.api.types.is_bool_dtype(source_series.dtype):
+ output_df[col] = source_series.astype('boolean')
+ else:
+ output_df[col] = (
+ source_series.astype(str)
+ .str.lower()
+ .map({'true': True, '1': True, 'false': False, '0': False})
+ .astype('boolean')
+ )
  else:
- self.total_records = -1
- self.logger.error("Failed to get a connection from the pool after retries.", exc_info=True, extra=self.logger_extra)
- return self.total_records, dd.from_pandas(meta_df, npartitions=1)
- except OperationalError as oe:
- if "timeout" in str(oe).lower() and attempt < retry_attempts - 1:
- self.logger.warning("Operational timeout, retrying…", exc_info=self.debug, extra=self.logger_extra)
- time.sleep(backoff)
- backoff *= 2
- continue
- self.total_records = -1
- self.logger.error("OperationalError during count.", exc_info=True, extra=self.logger_extra)
- return self.total_records, dd.from_pandas(meta_df, npartitions=1)
- except Exception as e:
- self.total_records = -1
- self.logger.error(f"Unexpected error during count: {e}", exc_info=True, extra=self.logger_extra)
- return self.total_records, dd.from_pandas(meta_df, npartitions=1)
-
- self.total_records = int(total)
- if total == 0:
- self.logger.warning("Query returned 0 records.")
- super().close()
- return self.total_records, dd.from_pandas(meta_df, npartitions=1)
-
- self.logger.debug(f"Total records to fetch: {total}. Chunk size: {self.chunk_size}.", extra=self.logger_extra)
-
- @dask.delayed
- def get_chunk(sql_query, chunk_offset):
- with self._sem: # <<< cap concurrent DB fetches
- paginated = sql_query.limit(self.chunk_size).offset(chunk_offset)
- df = pd.read_sql(paginated, self.engine)
- if fillna_value is not None:
- df = df.fillna(fillna_value)
- return df[ordered_columns].astype(meta_dtypes)
+ output_df[col] = source_series.astype(target_dtype)
+ except Exception:
+ # Fallback to string type on any error
+ output_df[col] = source_series.astype("string")
+
+ return output_df
+
+ def _count_total(self, subquery: sa.sql.Select) -> int:
+ """Executes a COUNT(*) query safely."""
+ try:
+ with self._conn() as conn:
+ count_q = sa.select(func.count()).select_from(subquery.alias())
+ return conn.execute(count_q).scalar_one()
+ except Exception:
+ self.logger.error("Failed to count total records.", exc_info=True, extra=self.logger_extra)
+ return -1
+
+ def read_frame(self) -> Tuple[int, dd.DataFrame]:
+ base_select = select(self.model)
+ if self.filters:
+ base_select = self.filter_handler_cls(backend="sqlalchemy").apply_filters(base_select, self.model, self.filters)
+
+ total = self._count_total(base_select)
+ self.total_records = total

+ if total <= 0:
+ self.logger.info(f"Query returned {total} or failed to count records.", extra=self.logger_extra)
+ return total, dd.from_pandas(self._meta_df, npartitions=1)
+
+ # Simplified to offset pagination as it's the most robust
  offsets = range(0, total, self.chunk_size)
- delayed_chunks = [get_chunk(query, off) for off in offsets]
- ddf = dd.from_delayed(delayed_chunks, meta=meta_df)
- self.logger.debug(f"{self.model.__name__} created Dask DataFrame with {ddf.npartitions} partitions.", extra=self.logger_extra)
- return self.total_records, ddf
+ delayed_parts = [
+ dask.delayed(self._fetch_with_retry)(
+ base_select.limit(self.chunk_size).offset(off)
+ ) for off in offsets
+ ]
+
+ ddf = dd.from_delayed(delayed_parts, meta=self._meta_df, verify_meta=True)
+ return total, ddf
+
+ # --- Other helper methods (unchanged) ---
+ def _engine_pool_limits(self) -> Tuple[int, int]:
+ pool = getattr(self.engine, "pool", None)
+
+ def to_int(val, default):
+ try:
+ return int(val() if callable(val) else val)
+ except Exception:
+ return default
+
+ size = to_int(getattr(pool, "size", None), 5)
+ overflow = to_int(getattr(pool, "max_overflow", None) or getattr(pool, "_max_overflow", None), 10)
+ return size, overflow
+
+ @staticmethod
+ def _normalized_engine_key(engine: Engine) -> str:
+ try:
+ return str(engine.url.set(query=None).set(password=None))
+ except Exception:
+ return str(engine.url)

@@ -46,6 +46,7 @@ class SqlAlchemyLoadFromDb(ManagedResource):
  ) as loader:
  self.logger.debug(f"SQLAlchemyDask loader initialized for model: {self.model.__name__}", extra=self.logger_extra)
  self.total_records, dask_df = loader.read_frame()
+ dask_df = dask_df.persist(scheduler='threads')
  return self.total_records, dask_df
  except Exception as e:
  self.total_records = -1
@@ -54,3 +55,19 @@ class SqlAlchemyLoadFromDb(ManagedResource):
  columns = [c.name for c in self.model.__table__.columns]
  return self.total_records, dd.from_pandas(pd.DataFrame(columns=columns), npartitions=1)

+ def _cleanup(self) -> None:
+ """
+ Do not close the shared connection here;
+ only clean up instance references to prevent memory leaks.
+ """
+ try:
+ # Remove references but don't close shared connection
+ self.logger.debug(f"Cleaning up {self.__class__.__name__} instance references")
+ attrs_to_clean = ['db_connection', 'engine', 'model']
+ for attr in attrs_to_clean:
+ if hasattr(self, attr):
+ delattr(self, attr)
+
+ except Exception as e:
+ if self._log_cleanup_errors:
+ self.logger.warning(f"Error during cleanup: {e}", extra=self.logger_extra)
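
The new constructor sizes the DBGatekeeper semaphore from the engine's connection pool rather than from a fixed number. Below is a standalone sketch of that arithmetic, assuming a plain SQLAlchemy QueuePool; the helper mirrors `_engine_pool_limits` in the diff above, while the engine URL and worker count are illustrative assumptions.

```python
# Sketch of the pool-derived concurrency cap fed to DBGatekeeper (illustrative values).
import sqlalchemy as sa
from sqlalchemy.pool import QueuePool

# In-memory SQLite with an explicit QueuePool, purely to have a pool to inspect.
engine = sa.create_engine("sqlite://", poolclass=QueuePool, pool_size=5, max_overflow=10)


def engine_pool_limits(engine: sa.engine.Engine) -> tuple[int, int]:
    pool = getattr(engine, "pool", None)

    def to_int(val, default):
        try:
            return int(val() if callable(val) else val)
        except Exception:
            return default

    size = to_int(getattr(pool, "size", None), 5)  # QueuePool.size() is a method
    overflow = to_int(getattr(pool, "max_overflow", None) or getattr(pool, "_max_overflow", None), 10)
    return size, overflow


num_workers = 4  # worker processes assumed to share this database
pool_size, max_overflow = engine_pool_limits(engine)
pool_capacity = max(1, pool_size + max_overflow)             # 5 + 10 = 15
per_proc_cap = max(1, pool_capacity // max(1, num_workers))  # 15 // 4 = 3
print(f"gatekeeper cap per process: {per_proc_cap}")
```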