sibi-dst 2025.8.7-py3-none-any.whl → 2025.8.9-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sibi_dst/df_helper/_df_helper.py +105 -89
- sibi_dst/df_helper/_parquet_artifact.py +11 -10
- sibi_dst/df_helper/_parquet_reader.py +4 -0
- sibi_dst/df_helper/backends/parquet/_parquet_options.py +504 -214
- sibi_dst/df_helper/backends/sqlalchemy/_db_connection.py +11 -10
- sibi_dst/df_helper/backends/sqlalchemy/_io_dask.py +9 -8
- sibi_dst/df_helper/backends/sqlalchemy/_load_from_db.py +4 -76
- sibi_dst/df_helper/backends/sqlalchemy/_sql_model_builder.py +0 -104
- sibi_dst/utils/boilerplate/__init__.py +6 -0
- sibi_dst/utils/boilerplate/base_data_artifact.py +110 -0
- sibi_dst/utils/boilerplate/base_data_cube.py +79 -0
- sibi_dst/utils/data_wrapper.py +22 -263
- sibi_dst/utils/iceberg_saver.py +126 -0
- sibi_dst/utils/log_utils.py +108 -529
- sibi_dst/utils/parquet_saver.py +110 -9
- sibi_dst/utils/progress/__init__.py +5 -0
- sibi_dst/utils/progress/jobs.py +82 -0
- sibi_dst/utils/progress/sse_runner.py +82 -0
- sibi_dst/utils/storage_hive.py +38 -1
- sibi_dst/utils/update_planner.py +617 -116
- {sibi_dst-2025.8.7.dist-info → sibi_dst-2025.8.9.dist-info}/METADATA +3 -2
- {sibi_dst-2025.8.7.dist-info → sibi_dst-2025.8.9.dist-info}/RECORD +23 -16
- {sibi_dst-2025.8.7.dist-info → sibi_dst-2025.8.9.dist-info}/WHEEL +0 -0
sibi_dst/utils/data_wrapper.py
CHANGED
@@ -23,6 +23,8 @@ class DataWrapper(ManagedResource):
     DEFAULT_MAX_AGE_MINUTES: int = 1440
     DEFAULT_HISTORY_DAYS_THRESHOLD: int = 30

+    logger_extra = {"sibi_dst_component": __name__}
+
     def __init__(
         self,
         dataclass: Type,
@@ -62,7 +64,7 @@ class DataWrapper(ManagedResource):

         # --- NEW: stop gate tripped during cleanup/interrupt to block further scheduling/retries
         self._stop_event = threading.Event()
-        self.
+        self.logger_extra.update({"action_module_name": "data_wrapper", "dataclass": self.dataclass.__name__})

     # ensure manifest is saved on context exit
     def __exit__(self, exc_type, exc_val, exc_tb):
@@ -114,14 +116,14 @@ class DataWrapper(ManagedResource):
                     break
                 self._execute_task_batch(priority, dates, max_retries, backoff_base, backoff_jitter, backoff_max)
         except KeyboardInterrupt:
-            self.logger.warning("KeyboardInterrupt received — stopping scheduling and shutting down.")
+            self.logger.warning("KeyboardInterrupt received — stopping scheduling and shutting down.", extra=self.logger_extra)
             self._stop_event.set()
             raise
         finally:
             total_time = time.perf_counter() - overall_start
             if self.processed_dates:
                 count = len(self.processed_dates)
-                self.logger.info(f"Processed {count} dates in {total_time:.1f}s (avg {total_time / count:.1f}s/date)")
+                self.logger.info(f"Processed {count} dates in {total_time:.1f}s (avg {total_time / count:.1f}s/date)", extra=self.logger_extra)
                 if self.update_planner.show_progress:
                     self.show_benchmark_summary()

@@ -136,7 +138,7 @@ class DataWrapper(ManagedResource):
     ):
         desc = f"Processing {self.dataclass.__name__}, priority: {priority}"
         max_thr = min(len(dates), self.max_threads)
-        self.logger.info(f"Executing {len(dates)} tasks with priority {priority} using {max_thr} threads.", extra=self.
+        self.logger.info(f"Executing {len(dates)} tasks with priority {priority} using {max_thr} threads.", extra=self.logger_extra)

         # Use explicit try/finally so we can request cancel of queued tasks on teardown
         executor = ThreadPoolExecutor(max_workers=max_thr, thread_name_prefix="datawrapper")
@@ -153,7 +155,7 @@ class DataWrapper(ManagedResource):
             except RuntimeError as e:
                 # tolerate race: executor shutting down
                 if "cannot schedule new futures after shutdown" in str(e).lower():
-                    self.logger.warning("Executor is shutting down; halting new submissions for this batch.")
+                    self.logger.warning("Executor is shutting down; halting new submissions for this batch.", extra=self.logger_extra)
                     break
                 raise
@@ -165,7 +167,7 @@ class DataWrapper(ManagedResource):
                 try:
                     future.result(timeout=self.timeout)
                 except Exception as e:
-                    self.logger.error(f"Permanent failure for {futures[future]}: {e}", extra=self.
+                    self.logger.error(f"Permanent failure for {futures[future]}: {e}", extra=self.logger_extra)
         finally:
             # Python 3.9+: cancel_futures prevents queued tasks from starting
             executor.shutdown(wait=True, cancel_futures=True)
@@ -191,18 +193,19 @@ class DataWrapper(ManagedResource):
                     base_delay = min(backoff_base ** attempt, backoff_max)
                     delay = base_delay * (1 + random.uniform(0.0, max(0.0, backoff_jitter)))
                     self.logger.warning(
-                        f"Retry {attempt + 1}/{max_retries} for {date}: {e} (sleep {delay:.2f}s)"
+                        f"Retry {attempt + 1}/{max_retries} for {date}: {e} (sleep {delay:.2f}s)",
+                        extra=self.logger_extra
                     )
                     time.sleep(delay)
                 else:
-                    self.logger.error(f"Failed processing {date} after {max_retries} attempts.", extra=self.
+                    self.logger.error(f"Failed processing {date} after {max_retries} attempts.", extra=self.logger_extra)
                     raise

     def _process_single_date(self, date: datetime.date):
         path = f"{self.data_path}{date.year}/{date.month:02d}/{date.day:02d}/"
-        self.logger.debug(f"Processing date {date.isoformat()} for {path}")
+        self.logger.debug(f"Processing date {date.isoformat()} for {path}", extra=self.logger_extra)
         if path in self.update_planner.skipped and self.update_planner.ignore_missing:
-            self.logger.debug(f"Skipping {date} as it exists in the skipped list")
+            self.logger.debug(f"Skipping {date} as it exists in the skipped list", extra=self.logger_extra)
             return
         full_path = f"{path}{self.parquet_filename}"

@@ -210,7 +213,7 @@ class DataWrapper(ManagedResource):
         try:
             load_start = time.perf_counter()
             date_filter = {f"{self.date_field}__date": {date.isoformat()}}
-            self.logger.debug(f"
+            self.logger.debug(f"{self.dataclass.__name__} is loading data for {date} with filter: {date_filter}", extra=self.logger_extra)

             local_load_params = self.load_params.copy()
             local_load_params.update(date_filter)
@@ -221,16 +224,16 @@ class DataWrapper(ManagedResource):

                 if hasattr(local_class_instance, "total_records"):
                     total_records = int(local_class_instance.total_records)
-                    self.logger.debug(f"Total records loaded: {total_records}")
+                    self.logger.debug(f"Total records loaded: {total_records}", extra=self.logger_extra)

                     if total_records == 0:
                         if self.mmanifest:
                             self.mmanifest.record(full_path=path)
-                        self.logger.info(f"No data found for {full_path}. Logged to missing manifest.")
+                        self.logger.info(f"No data found for {full_path}. Logged to missing manifest.", extra=self.logger_extra)
                         return

                     if total_records < 0:
-                        self.logger.warning(f"Negative record count ({total_records}) for {full_path}.")
+                        self.logger.warning(f"Negative record count ({total_records}) for {full_path}.", extra=self.logger_extra)
                         return

                     save_start = time.perf_counter()
@@ -258,261 +261,17 @@ class DataWrapper(ManagedResource):
             raise

     def _log_success(self, date: datetime.date, duration: float, path: str):
-        self.logger.info(f"Completed {date} in {duration:.1f}s | Saved to {path}", extra=self.
+        self.logger.info(f"Completed {date} in {duration:.1f}s | Saved to {path}", extra=self.logger_extra)
         self.processed_dates.append(date)

     def _log_failure(self, date: datetime.date, error: Exception):
-        self.logger.error(f"Failed processing {date}: {error}", extra=self.
+        self.logger.error(f"Failed processing {date}: {error}", extra=self.logger_extra)

     def show_benchmark_summary(self):
         if not self.benchmarks:
-            self.logger.info("No benchmarking data to show", extra=self.
+            self.logger.info("No benchmarking data to show", extra=self.logger_extra)
             return
         df_bench = pd.DataFrame.from_records([{"date": d, **m} for d, m in self.benchmarks.items()])
         df_bench = df_bench.set_index("date").sort_index(ascending=not self.update_planner.reverse_order)
-        self.logger.info(f"Benchmark Summary:\n {self.dataclass.__name__}\n" + df_bench.to_string(), extra=self.
-
-# import datetime
-# import threading
-# import time
-# import random
-# from concurrent.futures import ThreadPoolExecutor, as_completed
-# from typing import Type, Any, Dict, Optional, Union, List, ClassVar
-#
-# import dask.dataframe as dd
-# import pandas as pd
-# from tqdm import tqdm
-#
-# from . import ManagedResource
-# from .parquet_saver import ParquetSaver
-#
-#
-# class DataWrapper(ManagedResource):
-#     DEFAULT_PRIORITY_MAP: ClassVar[Dict[str, int]] = {
-#         "overwrite": 1,
-#         "missing_in_history": 2,
-#         "existing_but_stale": 3,
-#         "missing_outside_history": 4,
-#         "file_is_recent": 0,
-#     }
-#     DEFAULT_MAX_AGE_MINUTES: int = 1440
-#     DEFAULT_HISTORY_DAYS_THRESHOLD: int = 30
-#
-#     def __init__(
-#         self,
-#         dataclass: Type,
-#         date_field: str,
-#         data_path: str,
-#         parquet_filename: str,
-#         class_params: Optional[Dict] = None,
-#         load_params: Optional[Dict] = None,
-#         show_progress: bool = False,
-#         timeout: float = 30,
-#         max_threads: int = 3,
-#         **kwargs: Any,
-#     ):
-#         super().__init__(**kwargs)
-#         self.dataclass = dataclass
-#         self.date_field = date_field
-#         self.data_path = self._ensure_forward_slash(data_path)
-#         self.parquet_filename = parquet_filename
-#         if self.fs is None:
-#             raise ValueError("DataWrapper requires a File system (fs) to be provided.")
-#         self.show_progress = show_progress
-#         self.timeout = timeout
-#         self.max_threads = max_threads
-#         self.class_params = class_params or {
-#             "debug": self.debug,
-#             "logger": self.logger,
-#             "fs": self.fs,
-#             "verbose": self.verbose,
-#         }
-#         self.load_params = load_params or {}
-#
-#         self._lock = threading.Lock()
-#         self.processed_dates: List[datetime.date] = []
-#         self.benchmarks: Dict[datetime.date, Dict[str, float]] = {}
-#         self.mmanifest = kwargs.get("mmanifest", None)
-#         self.update_planner = kwargs.get("update_planner", None)
-#
-#     def __exit__(self, exc_type, exc_val, exc_tb):
-#         if self.mmanifest:
-#             self.mmanifest.save()
-#         super().__exit__(exc_type, exc_val, exc_tb)
-#         return False
-#
-#     @staticmethod
-#     def _convert_to_date(date: Union[datetime.date, str]) -> datetime.date:
-#         if isinstance(date, datetime.date):
-#             return date
-#         try:
-#             return pd.to_datetime(date).date()
-#         except ValueError as e:
-#             raise ValueError(f"Error converting {date} to datetime: {e}")
-#
-#     @staticmethod
-#     def _ensure_forward_slash(path: str) -> str:
-#         return path.rstrip("/") + "/"
-#
-#     def process(
-#         self,
-#         max_retries: int = 3,
-#         backoff_base: float = 2.0,
-#         backoff_jitter: float = 0.1,
-#         backoff_max: float = 60.0,
-#     ):
-#         """
-#         Execute the update plan with concurrency, retries and exponential backoff.
-#
-#         Args:
-#             max_retries: attempts per date.
-#             backoff_base: base for exponential backoff (delay = base**attempt).
-#             backoff_jitter: multiplicative jitter factor in [0, backoff_jitter].
-#             backoff_max: maximum backoff seconds per attempt (before jitter).
-#         """
-#         overall_start = time.perf_counter()
-#         tasks = list(self.update_planner.get_tasks_by_priority())
-#         if not tasks:
-#             self.logger.info("No updates required based on the current plan.")
-#             return
-#
-#         if self.update_planner.show_progress:
-#             self.update_planner.show_update_plan()
-#
-#         for priority, dates in tasks:
-#             self._execute_task_batch(priority, dates, max_retries, backoff_base, backoff_jitter, backoff_max)
-#
-#         total_time = time.perf_counter() - overall_start
-#         if self.processed_dates:
-#             count = len(self.processed_dates)
-#             self.logger.info(f"Processed {count} dates in {total_time:.1f}s (avg {total_time / count:.1f}s/date)")
-#             if self.update_planner.show_progress:
-#                 self.show_benchmark_summary()
-#
-#     def _execute_task_batch(
-#         self,
-#         priority: int,
-#         dates: List[datetime.date],
-#         max_retries: int,
-#         backoff_base: float,
-#         backoff_jitter: float,
-#         backoff_max: float,
-#     ):
-#         desc = f"Processing {self.dataclass.__name__}, priority: {priority}"
-#         max_thr = min(len(dates), self.max_threads)
-#         self.logger.info(f"Executing {len(dates)} tasks with priority {priority} using {max_thr} threads.")
-#
-#         with ThreadPoolExecutor(max_workers=max_thr) as executor:
-#             futures = {
-#                 executor.submit(
-#                     self._process_date_with_retry, date, max_retries, backoff_base, backoff_jitter, backoff_max
-#                 ): date
-#                 for date in dates
-#             }
-#             iterator = as_completed(futures)
-#             if self.show_progress:
-#                 iterator = tqdm(iterator, total=len(futures), desc=desc)
-#
-#             for future in iterator:
-#                 try:
-#                     future.result(timeout=self.timeout)
-#                 except Exception as e:
-#                     self.logger.error(f"Permanent failure for {futures[future]}: {e}")
-#
-#     def _process_date_with_retry(
-#         self,
-#         date: datetime.date,
-#         max_retries: int,
-#         backoff_base: float,
-#         backoff_jitter: float,
-#         backoff_max: float,
-#     ):
-#         for attempt in range(max_retries):
-#             try:
-#                 self._process_single_date(date)
-#                 return
-#             except Exception as e:
-#                 if attempt < max_retries - 1:
-#                     base_delay = min(backoff_base ** attempt, backoff_max)
-#                     delay = base_delay * (1 + random.uniform(0.0, max(0.0, backoff_jitter)))
-#                     self.logger.warning(
-#                         f"Retry {attempt + 1}/{max_retries} for {date}: {e} (sleep {delay:.2f}s)"
-#                     )
-#                     time.sleep(delay)
-#                 else:
-#                     self.logger.error(f"Failed processing {date} after {max_retries} attempts.")
-#
-#     def _process_single_date(self, date: datetime.date):
-#         path = f"{self.data_path}{date.year}/{date.month:02d}/{date.day:02d}/"
-#         self.logger.debug(f"Processing date {date.isoformat()} for {path}")
-#         if path in self.update_planner.skipped and self.update_planner.ignore_missing:
-#             self.logger.debug(f"Skipping {date} as it exists in the skipped list")
-#             return
-#         full_path = f"{path}{self.parquet_filename}"
-#
-#         overall_start = time.perf_counter()
-#         try:
-#             load_start = time.perf_counter()
-#             date_filter = {f"{self.date_field}__date": {date.isoformat()}}
-#             self.logger.debug(f"Loading data for {date} with filter: {date_filter}")
-#
-#             local_load_params = self.load_params.copy()
-#             local_load_params.update(date_filter)
-#
-#             with self.dataclass(**self.class_params) as local_class_instance:
-#                 df = local_class_instance.load(**local_load_params)  # expected to be Dask
-#                 load_time = time.perf_counter() - load_start
-#
-#                 if hasattr(local_class_instance, "total_records"):
-#                     total_records = int(local_class_instance.total_records)
-#                     self.logger.debug(f"Total records loaded: {total_records}")
-#
-#                     if total_records == 0:
-#                         if self.mmanifest:
-#                             self.mmanifest.record(full_path=path)
-#                         self.logger.info(f"No data found for {full_path}. Logged to missing manifest.")
-#                         return
-#
-#                     if total_records < 0:
-#                         self.logger.warning(f"Negative record count ({total_records}) for {full_path}.")
-#                         return
-#
-#                     save_start = time.perf_counter()
-#                     parquet_params = {
-#                         "df_result": df,
-#                         "parquet_storage_path": path,
-#                         "fs": self.fs,
-#                         "logger": self.logger,
-#                         "debug": self.debug,
-#                     }
-#                     with ParquetSaver(**parquet_params) as ps:
-#                         ps.save_to_parquet(self.parquet_filename, overwrite=True)
-#                     save_time = time.perf_counter() - save_start
-#
-#                     total_time = time.perf_counter() - overall_start
-#                     self.benchmarks[date] = {
-#                         "load_duration": load_time,
-#                         "save_duration": save_time,
-#                         "total_duration": total_time,
-#                     }
-#                     self._log_success(date, total_time, full_path)
-#
-#         except Exception as e:
-#             self._log_failure(date, e)
-#             raise
-#
-#     def _log_success(self, date: datetime.date, duration: float, path: str):
-#         self.logger.info(f"Completed {date} in {duration:.1f}s | Saved to {path}")
-#         self.processed_dates.append(date)
-#
-#     def _log_failure(self, date: datetime.date, error: Exception):
-#         self.logger.error(f"Failed processing {date}: {error}")
-#
-#     def show_benchmark_summary(self):
-#         if not self.benchmarks:
-#             self.logger.info("No benchmarking data to show")
-#             return
-#         df_bench = pd.DataFrame.from_records([{"date": d, **m} for d, m in self.benchmarks.items()])
-#         df_bench = df_bench.set_index("date").sort_index(ascending=not self.update_planner.reverse_order)
-#         self.logger.info(f"Benchmark Summary:\n {self.dataclass.__name__}\n" + df_bench.to_string())
-#
+        self.logger.info(f"Benchmark Summary:\n {self.dataclass.__name__}\n" + df_bench.to_string(), extra=self.logger_extra)
+
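Taken together, the data_wrapper.py changes route a per-class context dict (logger_extra, carrying sibi_dst_component plus the action_module_name and dataclass values set in __init__) through the extra= argument of every log call. Below is a minimal sketch of how a consumer of the package might surface those fields with the standard logging module; the filter, handler wiring, and placeholder values are illustrative assumptions, not part of sibi_dst.

import logging

class DefaultExtraFilter(logging.Filter):
    """Backfill the structured fields so records logged without extra= still format."""
    FIELDS = ("sibi_dst_component", "action_module_name", "dataclass")

    def filter(self, record: logging.LogRecord) -> bool:
        for field in self.FIELDS:
            if not hasattr(record, field):
                setattr(record, field, "-")
        return True

handler = logging.StreamHandler()
handler.addFilter(DefaultExtraFilter())
handler.setFormatter(logging.Formatter(
    "%(asctime)s %(levelname)s [%(sibi_dst_component)s/%(action_module_name)s/%(dataclass)s] %(message)s"
))

logger = logging.getLogger("sibi_dst")
logger.addHandler(handler)
logger.setLevel(logging.INFO)

# Mirrors the shape of the calls added in this diff:
logger.info(
    "Processed 3 dates in 12.4s",
    extra={"sibi_dst_component": "sibi_dst.utils.data_wrapper",
           "action_module_name": "data_wrapper",
           "dataclass": "MyArtifact"},
)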
sibi_dst/utils/iceberg_saver.py
ADDED
@@ -0,0 +1,126 @@
+import warnings
+import dask.dataframe as dd
+import pandas as pd
+import pyarrow as pa
+from pyiceberg.catalog import load_catalog
+from typing import Optional, Dict, Any
+from . import ManagedResource
+
+warnings.filterwarnings("ignore", message="Passing 'overwrite=True' to to_parquet is deprecated")
+
+class IcebergSaver(ManagedResource):
+    """
+    Saves a Dask DataFrame into an Apache Iceberg table using PyIceberg.
+    - Uses Arrow conversion per Dask partition.
+    - One Iceberg commit per partition (append mode), or a staged overwrite
+      (coalesce to N partitions, commit them in place of the old snapshot).
+    """
+
+    def __init__(
+        self,
+        df_result: dd.DataFrame,
+        catalog_name: str,
+        table_name: str,
+        *,
+        persist: bool = True,
+        npartitions: Optional[int] = 8,
+        arrow_schema: Optional[pa.Schema] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.df_result = df_result
+        self.catalog_name = catalog_name
+        self.table_name = table_name
+        self.persist = persist
+        self.npartitions = npartitions
+        self.arrow_schema = arrow_schema  # optional: enforce column order/types
+
+        # Iceberg writes don’t need self.fs; catalog handles IO.
+        # But we keep self.fs available in case you presign or stage files.
+
+        # Load table once
+        self.catalog = load_catalog(self.catalog_name)
+        self.table = self.catalog.load_table(self.table_name)
+
+    def save(self, *, mode: str = "append"):
+        """
+        mode:
+          - "append": append rows as new data files (one commit per partition)
+          - "overwrite": replace table data atomically (single staged commit)
+            (requires coalescing to limit number of files)
+        """
+        if mode not in ("append", "overwrite"):
+            raise ValueError("mode must be 'append' or 'overwrite'")
+
+        # Optional persist to avoid recomputation across multiple consumers
+        ddf = self.df_result.persist() if self.persist else self.df_result
+
+        if self.npartitions:
+            ddf = ddf.repartition(npartitions=self.npartitions)
+
+        if mode == "append":
+            self._append_partitions(ddf)
+        else:
+            self._overwrite_atomic(ddf)
+
+    # ---------- internals ----------
+
+    def _to_arrow_table(self, pdf: pd.DataFrame) -> pa.Table:
+        if self.arrow_schema is None:
+            return pa.Table.from_pandas(pdf, preserve_index=False)
+        # Enforce schema (column order & target types) when provided
+        at = pa.Table.from_pandas(pdf, preserve_index=False, schema=self.arrow_schema)
+        # Some Arrow versions require select to exact order if pandas added cols
+        return at.select(self.arrow_schema.names)
+
+    def _append_partitions(self, ddf: dd.DataFrame):
+        """
+        Simple path: commit each partition as a separate append.
+        Good for moderate rates; for very high throughput, consider staging or
+        increasing npartitions to get larger files.
+        """
+        def _commit(pdf: pd.DataFrame):
+            if len(pdf) == 0:
+                return pdf.iloc[0:0]
+            at = self._to_arrow_table(pdf)
+            self.table.append(at)  # one atomic Iceberg commit
+            return pdf.iloc[0:0]
+
+        ddf.map_partitions(_commit, meta=ddf._meta).compute()
+        self.logger.info(f"Appended data to Iceberg table {self.table_name} (catalog={self.catalog_name}).")
+
+    def _overwrite_atomic(self, ddf: dd.DataFrame):
+        """
+        Safer full refresh: stage N Arrow batches and replace existing snapshot.
+        Strategy:
+          1) Build a single overwrite transaction.
+          2) Add files produced from each partition to the same transaction.
+          3) Commit once (atomic snapshot replacement).
+        """
+        from pyiceberg.table.ops import RewriteFiles  # operation helper
+
+        # Materialize partitions one by one and add to a rewrite op
+        # Note: PyIceberg API offers two patterns:
+        #   - table.overwrite(at) for “overwrite by filter” in one call (simple)
+        #   - lower-level staged ops (demonstrated conceptually below)
+
+        # Easiest “full-table” overwrite via filter(True) – clears table then writes new data:
+        # If you only want to replace certain partitions, use a filter expr.
+        def _collect_partitions(pdf: pd.DataFrame):
+            if len(pdf) == 0:
+                return None
+            return self._to_arrow_table(pdf)
+
+        batches = [b for b in ddf.map_partitions(_collect_partitions, meta=object).compute() if b is not None]
+        if not batches:
+            self.logger.warning("Overwrite requested but no rows in DataFrame; leaving table unchanged.")
+            return
+
+        # Commit as a single overwrite
+        self.table.overwrite(batches[0])
+        for at in batches[1:]:
+            self.table.append(at)  # append subsequent batches into the same snapshot lineage
+
+        # If you require truly single-snapshot replacement in one call, you can
+        # also union the batches into fewer (bigger) Arrow Tables before calling overwrite.
+        self.logger.info(f"Overwrote Iceberg table {self.table_name} with {len(batches)} batch(es).")