neo4j-etl-lib 0.3.1__tar.gz → 0.3.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/PKG-INFO +1 -1
  2. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/__init__.py +1 -1
  3. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/core/ETLContext.py +2 -9
  4. neo4j_etl_lib-0.3.2/src/etl_lib/core/ParallelBatchProcessor.py +150 -0
  5. neo4j_etl_lib-0.3.2/src/etl_lib/core/SplittingBatchProcessor.py +391 -0
  6. neo4j_etl_lib-0.3.2/src/etl_lib/data_source/SQLBatchSource.py +84 -0
  7. neo4j_etl_lib-0.3.1/src/etl_lib/core/ParallelBatchProcessor.py +0 -180
  8. neo4j_etl_lib-0.3.1/src/etl_lib/core/SplittingBatchProcessor.py +0 -268
  9. neo4j_etl_lib-0.3.1/src/etl_lib/data_source/SQLBatchSource.py +0 -114
  10. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/LICENSE +0 -0
  11. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/README.md +0 -0
  12. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/pyproject.toml +0 -0
  13. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/cli/__init__.py +0 -0
  14. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/cli/run_tools.py +0 -0
  15. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/core/BatchProcessor.py +0 -0
  16. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/core/ClosedLoopBatchProcessor.py +0 -0
  17. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/core/ProgressReporter.py +0 -0
  18. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/core/Task.py +0 -0
  19. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/core/ValidationBatchProcessor.py +0 -0
  20. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/core/__init__.py +0 -0
  21. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/core/utils.py +0 -0
  22. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/data_sink/CSVBatchSink.py +0 -0
  23. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/data_sink/CypherBatchSink.py +0 -0
  24. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/data_sink/SQLBatchSink.py +0 -0
  25. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/data_sink/__init__.py +0 -0
  26. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/data_source/CSVBatchSource.py +0 -0
  27. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/data_source/CypherBatchSource.py +0 -0
  28. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/data_source/__init__.py +0 -0
  29. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/task/CreateReportingConstraintsTask.py +0 -0
  30. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/task/ExecuteCypherTask.py +0 -0
  31. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/task/GDSTask.py +0 -0
  32. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/task/__init__.py +0 -0
  33. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/task/data_loading/CSVLoad2Neo4jTask.py +0 -0
  34. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/task/data_loading/ParallelCSVLoad2Neo4jTask.py +0 -0
  35. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/task/data_loading/ParallelSQLLoad2Neo4jTask.py +0 -0
  36. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/task/data_loading/SQLLoad2Neo4jTask.py +0 -0
  37. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/task/data_loading/__init__.py +0 -0
  38. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/test_utils/__init__.py +0 -0
  39. {neo4j_etl_lib-0.3.1 → neo4j_etl_lib-0.3.2}/src/etl_lib/test_utils/utils.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: neo4j-etl-lib
- Version: 0.3.1
+ Version: 0.3.2
  Summary: Building blocks for ETL pipelines.
  Keywords: etl,graph,database
  Author-email: Bert Radke <bert.radke@pm.me>
@@ -1,4 +1,4 @@
  """
  Building blocks for ETL pipelines.
  """
- __version__ = "0.3.1"
+ __version__ = "0.3.2"
@@ -172,15 +172,8 @@ if sqlalchemy_available:
      database_url,
      pool_pre_ping=True,
      pool_size=pool_size,
-     max_overflow=max_overflow,
-     pool_recycle=1800,  # recycle connections older than 30m
-     connect_args={
-         # turn on TCP keepalives on the client socket:
-         "keepalives": 1,
-         "keepalives_idle": 60,  # after 60s of idle
-         "keepalives_interval": 10,  # probe every 10s
-         "keepalives_count": 5,  # give up after 5 failed probes
-     })
+     max_overflow=max_overflow
+ )


  class ETLContext:
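
For context, the 0.3.2 engine setup reduces to a plain pooled SQLAlchemy engine. A minimal standalone sketch, assuming `database_url`, `pool_size`, and `max_overflow` as defined in the surrounding ETLContext code:

    from sqlalchemy import create_engine

    engine = create_engine(
        database_url,
        pool_pre_ping=True,   # validate pooled connections before handing them out
        pool_size=pool_size,
        max_overflow=max_overflow,
    )

With the TCP keepalive settings removed, `pool_pre_ping` appears to remain the guard against stale pooled connections.
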
@@ -0,0 +1,150 @@
+ import queue
+ import threading
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from typing import Any, Callable, Generator, List
+
+ from etl_lib.core.BatchProcessor import BatchProcessor, BatchResults
+ from etl_lib.core.utils import merge_summery
+
+
+ class ParallelBatchResult(BatchResults):
+     """
+     Represents one *wave* produced by the splitter.
+
+     `chunk` is a list of bucket-batches. Each sub-list is processed by one worker instance.
+     """
+     pass
+
+
+ class ParallelBatchProcessor(BatchProcessor):
+     """
+     BatchProcessor that runs a worker over the bucket-batches of each ParallelBatchResult
+     in parallel threads, while prefetching the next ParallelBatchResult from its predecessor.
+
+     Note:
+         - The predecessor must emit `ParallelBatchResult` instances (waves).
+         - This processor collects the BatchResults from all workers for one wave and merges them
+           into one BatchResults.
+         - The returned BatchResults will not obey the max_batch_size from get_batch() because
+           it represents the full wave.
+
+     Args:
+         context: ETL context.
+         worker_factory: A zero-arg callable that returns a new BatchProcessor each time it's called.
+         task: optional Task for reporting.
+         predecessor: upstream BatchProcessor that must emit ParallelBatchResult. See
+             :class:`~etl_lib.core.SplittingBatchProcessor.SplittingBatchProcessor`.
+         max_workers: number of parallel threads for bucket processing.
+         prefetch: number of waves to prefetch.
+
+     Behavior:
+         - For every wave, spins up `max_workers` threads.
+         - Each thread processes one bucket-batch using a fresh worker from `worker_factory()`.
+         - Collects and merges worker results in a fail-fast manner.
+     """
+
+     def __init__(
+         self,
+         context,
+         worker_factory: Callable[[], BatchProcessor],
+         task=None,
+         predecessor=None,
+         max_workers: int = 4,
+         prefetch: int = 4,
+     ):
+         super().__init__(context, task, predecessor)
+         self.worker_factory = worker_factory
+         self.max_workers = max_workers
+         self.prefetch = prefetch
+
+     def _process_wave(self, wave: ParallelBatchResult) -> BatchResults:
+         """
+         Process one wave: run one worker per bucket-batch and merge their BatchResults.
+
+         Statistics:
+             `wave.statistics` is used as the initial merged stats, then merged with each worker's stats.
+         """
+         merged_stats = dict(wave.statistics or {})
+         merged_chunk = []
+         total = 0
+
+         self.logger.debug(f"Processing wave with {len(wave.chunk)} buckets")
+
+         with ThreadPoolExecutor(max_workers=self.max_workers, thread_name_prefix="PBP_worker_") as pool:
+             futures = [pool.submit(self._process_bucket_batch, bucket_batch) for bucket_batch in wave.chunk]
+             try:
+                 for f in as_completed(futures):
+                     out = f.result()
+                     merged_stats = merge_summery(merged_stats, out.statistics or {})
+                     total += out.batch_size
+                     merged_chunk.extend(out.chunk if isinstance(out.chunk, list) else [out.chunk])
+             except Exception as e:
+                 for g in futures:
+                     g.cancel()
+                 pool.shutdown(cancel_futures=True)
+                 raise RuntimeError("bucket processing failed") from e
+
+         self.logger.debug(f"Finished wave with stats={merged_stats}")
+         return BatchResults(chunk=merged_chunk, statistics=merged_stats, batch_size=total)
+
+     def get_batch(self, max_batch_size: int) -> Generator[BatchResults, None, None]:
+         """
+         Pull waves from the predecessor (prefetching up to `prefetch` ahead), process each wave's
+         buckets in parallel, and yield one flattened BatchResults per wave.
+         """
+         wave_queue: queue.Queue[ParallelBatchResult | object] = queue.Queue(self.prefetch)
+         SENTINEL = object()
+         exc: BaseException | None = None
+
+         def producer():
+             nonlocal exc
+             try:
+                 for wave in self.predecessor.get_batch(max_batch_size):
+                     self.logger.debug(
+                         f"adding wave stats={wave.statistics} buckets={len(wave.chunk)} to queue size={wave_queue.qsize()}"
+                     )
+                     wave_queue.put(wave)
+             except BaseException as e:
+                 exc = e
+             finally:
+                 wave_queue.put(SENTINEL)
+
+         threading.Thread(target=producer, daemon=True, name="prefetcher").start()
+
+         while True:
+             wave = wave_queue.get()
+             if wave is SENTINEL:
+                 if exc is not None:
+                     self.logger.error("Upstream producer failed", exc_info=True)
+                     raise exc
+                 break
+             yield self._process_wave(wave)
+
+     class SingleBatchWrapper(BatchProcessor):
+         """
+         Simple BatchProcessor that returns exactly one batch (the bucket-batch passed in via init).
+         Used as predecessor for the per-bucket worker.
+         """
+
+         def __init__(self, context, batch: List[Any]):
+             super().__init__(context=context, predecessor=None)
+             self._batch = batch
+
+         def get_batch(self, max_size: int) -> Generator[BatchResults, None, None]:
+             yield BatchResults(
+                 chunk=self._batch,
+                 statistics={},
+                 batch_size=len(self._batch),
+             )
+
+     def _process_bucket_batch(self, bucket_batch):
+         """
+         Process one bucket-batch by running a fresh worker over it.
+         """
+         self.logger.debug(f"Processing batch w/ size {len(bucket_batch)}")
+         wrapper = self.SingleBatchWrapper(self.context, bucket_batch)
+         worker = self.worker_factory()
+         worker.predecessor = wrapper
+         result = next(worker.get_batch(len(bucket_batch)))
+         self.logger.debug(f"Finished bucket batch stats={result.statistics}")
+         return result
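
The splitter/worker pairing this class expects can be sketched as follows. Illustrative wiring only, not part of the diff; `ctx` (an ETLContext), `source` (any upstream BatchProcessor), and `make_worker` are assumed names:

    from etl_lib.core.ParallelBatchProcessor import ParallelBatchProcessor
    from etl_lib.core.SplittingBatchProcessor import SplittingBatchProcessor, dict_id_extractor

    extractor = dict_id_extractor(table_size=10)
    splitter = SplittingBatchProcessor(ctx, table_size=10, id_extractor=extractor,
                                       predecessor=source)  # emits ParallelBatchResult-style waves
    parallel = ParallelBatchProcessor(ctx, worker_factory=make_worker,
                                      predecessor=splitter, max_workers=4, prefetch=4)

    for result in parallel.get_batch(max_batch_size=1000):
        print(result.batch_size, result.statistics)  # one merged BatchResults per wave
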
@@ -0,0 +1,391 @@
+ import logging
+ from typing import Any, Callable, Dict, Generator, List, Tuple
+
+ from etl_lib.core.BatchProcessor import BatchProcessor, BatchResults
+ from etl_lib.core.utils import merge_summery
+ from tabulate import tabulate
+
+
+ def tuple_id_extractor(table_size: int = 10) -> Callable[[Tuple[str | int, str | int]], Tuple[int, int]]:
+     """
+     Create an ID extractor function for tuple items, using the last decimal digit of each element.
+     The output is a `(row, col)` tuple within a `table_size x table_size` grid (default 10x10).
+
+     Args:
+         table_size: The dimension of the grid (number of rows/cols). Defaults to 10.
+
+     Returns:
+         A callable that maps a tuple `(a, b)` to a tuple `(row, col)` using the last digit of `a` and `b`.
+
+     Notes:
+         This extractor does not validate the returned indices. Range validation is performed by
+         :class:`~etl_lib.core.SplittingBatchProcessor.SplittingBatchProcessor`.
+     """
+
+     def extractor(item: Tuple[Any, Any]) -> Tuple[int, int]:
+         a, b = item
+         try:
+             row = int(str(a)[-1])
+             col = int(str(b)[-1])
+         except Exception as e:
+             raise ValueError(f"Failed to extract ID from item {item}: {e}")
+         return row, col
+
+     extractor.table_size = table_size
+     return extractor
+
+
+ def dict_id_extractor(
+     table_size: int = 10,
+     start_key: str = "start",
+     end_key: str = "end",
+ ) -> Callable[[Dict[str, Any]], Tuple[int, int]]:
+     """
+     Build an ID extractor for dict rows. The extractor reads two fields (configurable via
+     `start_key` and `end_key`) and returns (row, col) based on the last decimal digit of each.
+
+     Args:
+         table_size: Informational hint carried on the extractor.
+         start_key: Field name for the start node identifier.
+         end_key: Field name for the end node identifier.
+
+     Returns:
+         Callable that maps {start_key, end_key} → (row, col).
+     """
+
+     def extractor(item: Dict[str, Any]) -> Tuple[int, int]:
+         missing = [k for k in (start_key, end_key) if k not in item]
+         if missing:
+             raise KeyError(f"Item missing required keys: {', '.join(missing)}")
+         try:
+             row = int(str(item[start_key])[-1])
+             col = int(str(item[end_key])[-1])
+         except Exception as e:
+             raise ValueError(f"Failed to extract ID from item {item}: {e}")
+         return row, col
+
+     extractor.table_size = table_size
+     return extractor
+
+
+ def canonical_integer_id_extractor(
+     table_size: int = 10,
+     start_key: str = "start",
+     end_key: str = "end",
+ ) -> Callable[[Dict[str, Any]], Tuple[int, int]]:
+     """
+     ID extractor for integer IDs with canonical folding.
+
+     - Uses Knuth's multiplicative hashing to scatter sequential integers across the grid.
+     - Canonical folding ensures (A,B) and (B,A) land in the same bucket by folding the lower
+       triangle into the upper triangle (row <= col).
+
+     The extractor marks itself as mono-partite by setting `extractor.monopartite = True`.
+     """
+     MAGIC = 2654435761
+
+     def extractor(item: Dict[str, Any]) -> Tuple[int, int]:
+         try:
+             s_val = item[start_key]
+             e_val = item[end_key]
+
+             s_hash = (s_val * MAGIC) & 0xffffffff
+             e_hash = (e_val * MAGIC) & 0xffffffff
+
+             row = s_hash % table_size
+             col = e_hash % table_size
+
+             if row > col:
+                 row, col = col, row
+
+             return row, col
+         except KeyError:
+             raise KeyError(f"Item missing keys: {start_key} or {end_key}")
+         except Exception as e:
+             raise ValueError(f"Failed to extract ID: {e}")
+
+     extractor.table_size = table_size
+     extractor.monopartite = True
+     return extractor
+
+
+ class SplittingBatchProcessor(BatchProcessor):
+     """
+     Streaming wave scheduler for mix-and-batch style loading.
+
+     Incoming rows are assigned to buckets via an `id_extractor(item) -> (row, col)` inside a
+     `table_size x table_size` grid. The processor emits waves; each wave contains bucket-batches
+     that are safe to process concurrently under the configured non-overlap rule.
+
+     Non-overlap rules
+     -----------------
+     - Bi-partite (default): within a wave, no two buckets share a row index and no two buckets share a col index.
+     - Mono-partite: within a wave, no node index is used more than once (row/col indices are the same domain).
+       Enable by setting `id_extractor.monopartite = True` (as done by `canonical_integer_id_extractor`).
+
+     Emission strategy
+     -----------------
+     - During streaming: emit a wave when at least one bucket is full (>= max_batch_size).
+       The wave is then filled with additional non-overlapping buckets that are near-full to
+       keep parallelism high without producing tiny batches.
+     - If a bucket backlog grows beyond a burst threshold, emit a burst wave to bound memory.
+     - After source exhaustion: flush leftovers in capped waves (max_batch_size per bucket).
+
+     Statistics policy
+     -----------------
+     - Every emission except the last carries {}.
+     - The last emission carries the accumulated upstream statistics (unfiltered).
+     """
+
+     def __init__(
+         self,
+         context,
+         table_size: int,
+         id_extractor: Callable[[Any], Tuple[int, int]],
+         task=None,
+         predecessor=None,
+         near_full_ratio: float = 0.85,
+         burst_multiplier: int = 25,
+     ):
+         super().__init__(context, task, predecessor)
+
+         if hasattr(id_extractor, "table_size"):
+             expected_size = id_extractor.table_size
+             if table_size is None:
+                 table_size = expected_size
+             elif table_size != expected_size:
+                 raise ValueError(
+                     f"Mismatch between provided table_size ({table_size}) and id_extractor table_size ({expected_size})."
+                 )
+         elif table_size is None:
+             raise ValueError("table_size must be specified if id_extractor has no defined table_size")
+
+         if not (0 < near_full_ratio <= 1.0):
+             raise ValueError(f"near_full_ratio must be in (0, 1], got {near_full_ratio}")
+         if burst_multiplier < 1:
+             raise ValueError(f"burst_multiplier must be >= 1, got {burst_multiplier}")
+
+         self.table_size = table_size
+         self._id_extractor = id_extractor
+         self._monopartite = bool(getattr(id_extractor, "monopartite", False))
+
+         self.near_full_ratio = float(near_full_ratio)
+         self.burst_multiplier = int(burst_multiplier)
+
+         self.buffer: Dict[int, Dict[int, List[Any]]] = {
+             r: {c: [] for c in range(self.table_size)}
+             for r in range(self.table_size)
+         }
+         self.logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")
+
+     def _bucket_claims(self, row: int, col: int) -> Tuple[Any, ...]:
+         """
+         Return the resource claims a bucket consumes within a wave.
+
+         - Bi-partite: claims (row-slot, col-slot)
+         - Mono-partite: claims node indices touched by the bucket
+         """
+         if self._monopartite:
+             return (row,) if row == col else (row, col)
+         return ("row", row), ("col", col)
+
+     def _all_bucket_sizes(self) -> List[Tuple[int, int, int]]:
+         """
+         Return all non-empty buckets as (size, row, col).
+         """
+         out: List[Tuple[int, int, int]] = []
+         for r in range(self.table_size):
+             for c in range(self.table_size):
+                 n = len(self.buffer[r][c])
+                 if n > 0:
+                     out.append((n, r, c))
+         return out
+
+     def _select_wave(self, *, min_bucket_len: int, seed: List[Tuple[int, int]] | None = None) -> List[Tuple[int, int]]:
+         """
+         Greedy wave scheduler: pick a non-overlapping set of buckets with len >= min_bucket_len.
+
+         If `seed` is provided, it is taken as fixed and the wave is extended greedily.
+         """
+         candidates: List[Tuple[int, int, int]] = []
+         for r in range(self.table_size):
+             for c in range(self.table_size):
+                 n = len(self.buffer[r][c])
+                 if n >= min_bucket_len:
+                     candidates.append((n, r, c))
+
+         if not candidates and not seed:
+             return []
+
+         candidates.sort(key=lambda x: (-x[0], x[1], x[2]))
+
+         used: set[Any] = set()
+         wave: List[Tuple[int, int]] = []
+
+         if seed:
+             for r, c in seed:
+                 claims = self._bucket_claims(r, c)
+                 used.update(claims)
+                 wave.append((r, c))
+
+         for _, r, c in candidates:
+             if (r, c) in wave:
+                 continue
+             claims = self._bucket_claims(r, c)
+             if any(claim in used for claim in claims):
+                 continue
+             wave.append((r, c))
+             used.update(claims)
+             if len(wave) >= self.table_size:
+                 break
+
+         return wave
+
+     def _find_hottest_bucket(self, *, threshold: int) -> Tuple[int, int, int] | None:
+         """
+         Find the single hottest bucket (largest backlog) whose size is >= threshold.
+         Returns (row, col, size) or None.
+         """
+         best: Tuple[int, int, int] | None = None
+         for r in range(self.table_size):
+             for c in range(self.table_size):
+                 n = len(self.buffer[r][c])
+                 if n < threshold:
+                     continue
+                 if best is None or n > best[2]:
+                     best = (r, c, n)
+         return best
+
+     def _flush_wave(
+         self,
+         wave: List[Tuple[int, int]],
+         max_batch_size: int,
+         statistics: Dict[str, Any] | None = None,
+     ) -> BatchResults:
+         """
+         Extract up to `max_batch_size` items from each bucket in `wave`, remove them from the buffer,
+         and return a BatchResults whose chunk is a list of per-bucket lists (aligned with `wave`).
+         """
+         self._log_buffer_matrix(wave=wave)
+
+         bucket_batches: List[List[Any]] = []
+         for r, c in wave:
+             q = self.buffer[r][c]
+             take = min(max_batch_size, len(q))
+             bucket_batches.append(q[:take])
+             self.buffer[r][c] = q[take:]
+
+         emitted = sum(len(b) for b in bucket_batches)
+
+         return BatchResults(
+             chunk=bucket_batches,
+             statistics=statistics or {},
+             batch_size=emitted,
+         )
+
+     def _log_buffer_matrix(self, *, wave: List[Tuple[int, int]]) -> None:
+         """
+         Dumps a compact 2D matrix of per-bucket sizes (len of each buffer) when DEBUG is enabled.
+         """
+         if not self.logger.isEnabledFor(logging.DEBUG):
+             return
+
+         counts = [
+             [len(self.buffer[r][c]) for c in range(self.table_size)]
+             for r in range(self.table_size)
+         ]
+         marks = set(wave)
+
+         pad = max(2, len(str(self.table_size - 1)))
+         col_headers = [f"c{c:0{pad}d}" for c in range(self.table_size)]
+
+         rows = []
+         for r in range(self.table_size):
+             row_label = f"r{r:0{pad}d}"
+             row_vals = [f"[{v}]" if (r, c) in marks else f"{v}" for c, v in enumerate(counts[r])]
+             rows.append([row_label, *row_vals])
+
+         table = tabulate(
+             rows,
+             headers=["", *col_headers],
+             tablefmt="psql",
+             stralign="right",
+             disable_numparse=True,
+         )
+         self.logger.debug("buffer matrix:\n%s", table)
+
+     def get_batch(self, max_batch_size: int) -> Generator[BatchResults, None, None]:
+         """
+         Consume upstream batches, bucket incoming rows, and emit waves of non-overlapping buckets.
+
+         Streaming behavior:
+         - If at least one bucket is full (>= max_batch_size), emit a wave seeded with full buckets
+           and extended with near-full buckets to keep parallelism high.
+         - If a bucket exceeds a burst threshold, emit a burst wave (seeded by the hottest bucket)
+           and extended with near-full buckets.
+
+         End-of-stream behavior:
+         - Flush leftovers in capped waves (max_batch_size per bucket).
+
+         Statistics policy:
+         * Every emission except the last carries {}.
+         * The last emission carries the accumulated upstream stats (unfiltered).
+         """
+         if self.predecessor is None:
+             return
+
+         accum_stats: Dict[str, Any] = {}
+         pending: BatchResults | None = None
+
+         near_full_threshold = max(1, int(max_batch_size * self.near_full_ratio))
+         burst_threshold = self.burst_multiplier * max_batch_size
+
+         for upstream in self.predecessor.get_batch(max_batch_size):
+             if upstream.statistics:
+                 accum_stats = merge_summery(accum_stats, upstream.statistics)
+
+             for item in upstream.chunk:
+                 r, c = self._id_extractor(item)
+                 if self._monopartite and r > c:
+                     r, c = c, r
+                 if not (0 <= r < self.table_size and 0 <= c < self.table_size):
+                     raise ValueError(f"bucket id out of range: {(r, c)} for table_size={self.table_size}")
+                 self.buffer[r][c].append(item)
+
+             while True:
+                 full_seed = self._select_wave(min_bucket_len=max_batch_size)
+                 if not full_seed:
+                     break
+                 wave = self._select_wave(min_bucket_len=near_full_threshold, seed=full_seed)
+                 br = self._flush_wave(wave, max_batch_size, statistics={})
+                 if pending is not None:
+                     yield pending
+                 pending = br
+
+             while True:
+                 hot = self._find_hottest_bucket(threshold=burst_threshold)
+                 if hot is None:
+                     break
+                 hot_r, hot_c, hot_n = hot
+                 wave = self._select_wave(min_bucket_len=near_full_threshold, seed=[(hot_r, hot_c)])
+                 self.logger.debug(
+                     "burst flush: hottest_bucket=(%d,%d len=%d) threshold=%d near_full_threshold=%d wave_size=%d",
+                     hot_r, hot_c, hot_n, burst_threshold, near_full_threshold, len(wave)
+                 )
+                 br = self._flush_wave(wave, max_batch_size, statistics={})
+                 if pending is not None:
+                     yield pending
+                 pending = br
+
+         self.logger.debug("start flushing leftovers")
+         while True:
+             wave = self._select_wave(min_bucket_len=1)
+             if not wave:
+                 break
+             br = self._flush_wave(wave, max_batch_size, statistics={})
+             if pending is not None:
+                 yield pending
+             pending = br
+
+         if pending is not None:
+             yield BatchResults(chunk=pending.chunk, statistics=accum_stats, batch_size=pending.batch_size)
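
A quick check of the canonical-folding property documented above (illustrative only; the concrete bucket values depend on the hash):

    from etl_lib.core.SplittingBatchProcessor import canonical_integer_id_extractor

    ext = canonical_integer_id_extractor(table_size=10)
    # (A,B) and (B,A) fold into the same upper-triangle bucket:
    assert ext({"start": 11, "end": 42}) == ext({"start": 42, "end": 11})
    row, col = ext({"start": 11, "end": 42})
    assert row <= col
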
@@ -0,0 +1,84 @@
+ import logging
+ from typing import Generator, Callable, Optional
+
+ from sqlalchemy import text
+ from sqlalchemy.exc import OperationalError as SAOperationalError, DBAPIError
+
+ # Conditional import for psycopg2 to avoid crashing if not installed
+ try:
+     from psycopg2 import OperationalError as PsycopgOperationalError
+ except ImportError:
+     class PsycopgOperationalError(Exception):
+         pass
+
+ from etl_lib.core.BatchProcessor import BatchResults, BatchProcessor
+ from etl_lib.core.ETLContext import ETLContext
+ from etl_lib.core.Task import Task
+
+
+ class SQLBatchSource(BatchProcessor):
+     def __init__(
+         self,
+         context: ETLContext,
+         task: Task,
+         query: str,
+         record_transformer: Optional[Callable[[dict], dict]] = None,
+         **kwargs
+     ):
+         """
+         Constructs a new SQLBatchSource that streams results instead of paging them.
+         """
+         super().__init__(context, task)
+         # Remove any trailing semicolons to prevent SQL syntax errors
+         self.query = query.strip().rstrip(";")
+         self.record_transformer = record_transformer
+         self.kwargs = kwargs
+         self.logger = logging.getLogger(__name__)
+
+     def get_batch(self, max_batch_size: int) -> Generator[BatchResults, None, None]:
+         """
+         Yield successive batches using a server-side cursor (streaming).
+
+         This avoids 'LIMIT/OFFSET' pagination, which causes performance degradation
+         on large tables. Instead, it holds a cursor open and fetches rows incrementally.
+         """
+
+         with self.context.sql.engine.connect() as conn:
+
+             conn = conn.execution_options(stream_results=True)
+
+             try:
+                 self.logger.info("Starting SQL Result Stream...")
+
+                 result_proxy = conn.execute(text(self.query), self.kwargs)
+
+                 chunk = []
+                 count = 0
+
+                 for row in result_proxy.mappings():
+                     item = self.record_transformer(dict(row)) if self.record_transformer else dict(row)
+                     chunk.append(item)
+                     count += 1
+
+                     # Yield when batch is full
+                     if len(chunk) >= max_batch_size:
+                         yield BatchResults(
+                             chunk=chunk,
+                             statistics={"sql_rows_read": len(chunk)},
+                             batch_size=len(chunk),
+                         )
+                         chunk = []  # Clear memory
+
+                 # Yield any remaining rows
+                 if chunk:
+                     yield BatchResults(
+                         chunk=chunk,
+                         statistics={"sql_rows_read": len(chunk)},
+                         batch_size=len(chunk),
+                     )
+
+                 self.logger.info(f"SQL Stream finished. Total rows read: {count}")
+
+             except (PsycopgOperationalError, SAOperationalError, DBAPIError) as err:
+                 self.logger.error(f"Stream failed: {err}")
+                 raise
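
A usage sketch for the streaming source (hypothetical `ctx`/`task` objects; keyword arguments are passed through as bound query parameters):

    from etl_lib.data_source.SQLBatchSource import SQLBatchSource

    source = SQLBatchSource(ctx, task,
                            "SELECT id, name FROM customers WHERE region = :region",
                            region="EMEA")
    for batch in source.get_batch(max_batch_size=5000):
        ...  # batch.chunk is a list of dicts, fetched from one open server-side cursor

Because `stream_results=True` keeps a single cursor open, the result set is read in one pass; there is no OFFSET re-scan per page as in the 0.3.1 implementation below.
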
@@ -1,180 +0,0 @@
- import queue
- import threading
- from concurrent.futures import ThreadPoolExecutor, as_completed
- from typing import Any, Callable, Generator, List
-
- from etl_lib.core.BatchProcessor import BatchProcessor, BatchResults
- from etl_lib.core.utils import merge_summery
-
-
- class ParallelBatchResult(BatchResults):
-     """
-     Represents a batch split into parallelizable partitions.
-
-     `chunk` is a list of lists, each sub-list is a partition.
-     """
-     pass
-
-
- class ParallelBatchProcessor(BatchProcessor):
-     """
-     BatchProcessor that runs worker threads over partitions of batches.
-
-     Receives a special BatchResult (:py:class:`ParallelBatchResult`) from the predecessor.
-     All chunks in a ParallelBatchResult it receives can be processed in parallel.
-     See :py:class:`etl_lib.core.SplittingBatchProcessor` on how to produce them.
-     Prefetches the next ParallelBatchResults from its predecessor.
-     The actual processing of the batches is deferred to the configured worker.
-
-     Note:
-         - The predecessor must emit `ParallelBatchResult` instances.
-
-     Args:
-         context: ETL context.
-         worker_factory: A zero-arg callable that returns a new BatchProcessor
-             each time it's called. This processor is responsible for the processing pf the batches.
-         task: optional Task for reporting.
-         predecessor: upstream BatchProcessor that must emit ParallelBatchResult.
-         max_workers: number of parallel threads for partitions.
-         prefetch: number of ParallelBatchResults to prefetch from the predecessor.
-
-     Behavior:
-         - For every ParallelBatchResult, spins up `max_workers` threads.
-         - Each thread calls its own worker from `worker_factory()`, with its
-           partition wrapped by `SingleBatchWrapper`.
-         - Collects and merges their BatchResults in a fail-fast manner: on first
-           exception, logs the error, cancels remaining threads, and raises an exception.
-     """
-
-     def __init__(
-         self,
-         context,
-         worker_factory: Callable[[], BatchProcessor],
-         task=None,
-         predecessor=None,
-         max_workers: int = 4,
-         prefetch: int = 4,
-     ):
-         super().__init__(context, task, predecessor)
-         self.worker_factory = worker_factory
-         self.max_workers = max_workers
-         self.prefetch = prefetch
-         self._batches_done = 0
-
-     def _process_parallel(self, pbr: ParallelBatchResult) -> BatchResults:
-         """
-         Run one worker per partition in `pbr.chunk`, merge their outputs, and include upstream
-         statistics from `pbr.statistics` so counters (e.g., valid/invalid rows from validation)
-         are preserved through the parallel stage.
-
-         Progress reporting:
-             - After each partition completes, report batch count only
-         """
-         merged_stats = dict(pbr.statistics or {})
-         merged_chunk = []
-         total = 0
-
-         parts_total = len(pbr.chunk)
-         partitions_done = 0
-
-         self.logger.debug(f"Processing pbr of len {parts_total}")
-         with ThreadPoolExecutor(max_workers=self.max_workers, thread_name_prefix='PBP_worker_') as pool:
-             futures = [pool.submit(self._process_partition, part) for part in pbr.chunk]
-             try:
-                 for f in as_completed(futures):
-                     out = f.result()
-
-                     # Merge into this PBR's cumulative result (returned downstream)
-                     merged_stats = merge_summery(merged_stats, out.statistics or {})
-                     total += out.batch_size
-                     merged_chunk.extend(out.chunk if isinstance(out.chunk, list) else [out.chunk])
-
-                     partitions_done += 1
-                     self.context.reporter.report_progress(
-                         task=self.task,
-                         batches=self._batches_done,
-                         expected_batches=None,
-                         stats={},
-                     )
-
-             except Exception as e:
-                 for g in futures:
-                     g.cancel()
-                 pool.shutdown(cancel_futures=True)
-                 raise RuntimeError("partition processing failed") from e
-
-         self.logger.debug(f"Finished processing pbr with {merged_stats}")
-         return BatchResults(chunk=merged_chunk, statistics=merged_stats, batch_size=total)
-
-     def get_batch(self, max_batch_size: int) -> Generator[BatchResults, None, None]:
-         """
-         Pulls ParallelBatchResult batches from the predecessor, prefetching
-         up to `prefetch` ahead, processes each batch's partitions in
-         parallel threads, and yields a flattened BatchResults. The predecessor
-         can run ahead while the current batch is processed.
-         """
-         pbr_queue: queue.Queue[ParallelBatchResult | object] = queue.Queue(self.prefetch)
-         SENTINEL = object()
-         exc: BaseException | None = None
-
-         def producer():
-             nonlocal exc
-             try:
-                 for pbr in self.predecessor.get_batch(max_batch_size):
-                     self.logger.debug(
-                         f"adding pgr {pbr.statistics} / {len(pbr.chunk)} to queue of size {pbr_queue.qsize()}"
-                     )
-                     pbr_queue.put(pbr)
-             except BaseException as e:
-                 exc = e
-             finally:
-                 pbr_queue.put(SENTINEL)
-
-         threading.Thread(target=producer, daemon=True, name='prefetcher').start()
-
-         while True:
-             pbr = pbr_queue.get()
-             if pbr is SENTINEL:
-                 if exc is not None:
-                     self.logger.error("Upstream producer failed", exc_info=True)
-                     raise exc
-                 break
-             result = self._process_parallel(pbr)
-             yield result
-
-     class SingleBatchWrapper(BatchProcessor):
-         """
-         Simple BatchProcessor that returns the batch it receives via init.
-         Will be used as predecessor for the worker
-         """
-
-         def __init__(self, context, batch: List[Any]):
-             super().__init__(context=context, predecessor=None)
-             self._batch = batch
-
-         def get_batch(self, max_batch__size: int) -> Generator[BatchResults, None, None]:
-             # Ignores max_size; yields exactly one BatchResults containing the whole batch
-             yield BatchResults(
-                 chunk=self._batch,
-                 statistics={},
-                 batch_size=len(self._batch)
-             )
-
-     def _process_partition(self, partition):
-         """
-         Processes one partition of items by:
-         1. Wrapping it in SingleBatchWrapper
-         2. Instantiating a fresh worker via worker_factory()
-         3. Setting the worker's predecessor to the wrapper
-         4. Running exactly one batch and returning its BatchResults
-
-         Raises whatever exception the worker raises, allowing _process_parallel
-         to handle fail-fast behavior.
-         """
-         self.logger.debug("Processing partition")
-         wrapper = self.SingleBatchWrapper(self.context, partition)
-         worker = self.worker_factory()
-         worker.predecessor = wrapper
-         result = next(worker.get_batch(len(partition)))
-         self.logger.debug(f"finished processing partition with {result.statistics}")
-         return result
@@ -1,268 +0,0 @@
- import logging
- from typing import Any, Callable, Dict, Generator, List, Tuple
-
- from etl_lib.core.BatchProcessor import BatchProcessor, BatchResults
- from etl_lib.core.utils import merge_summery
- from tabulate import tabulate
-
-
- def tuple_id_extractor(table_size: int = 10) -> Callable[[Tuple[str | int, str | int]], Tuple[int, int]]:
-     """
-     Create an ID extractor function for tuple items, using the last decimal digit of each element.
-     The output is a `(row, col)` tuple within a `table_size x table_size` grid (default 10x10).
-
-     Args:
-         table_size: The dimension of the grid (number of rows/cols). Defaults to 10.
-
-     Returns:
-         A callable that maps a tuple `(a, b)` to a tuple `(row, col)` using the last digit of `a` and `b`.
-     """
-
-     def extractor(item: Tuple[Any, Any]) -> Tuple[int, int]:
-         a, b = item
-         try:
-             row = int(str(a)[-1])
-             col = int(str(b)[-1])
-         except Exception as e:
-             raise ValueError(f"Failed to extract ID from item {item}: {e}")
-         return row, col
-
-     extractor.table_size = table_size
-     return extractor
-
-
- def dict_id_extractor(
-     table_size: int = 10,
-     start_key: str = "start",
-     end_key: str = "end",
- ) -> Callable[[Dict[str, Any]], Tuple[int, int]]:
-     """
-     Build an ID extractor for dict rows. The extractor reads two fields (configurable via
-     `start_key` and `end_key`) and returns (row, col) based on the last decimal digit of each.
-     Range validation remains the responsibility of the SplittingBatchProcessor.
-
-     Args:
-         table_size: Informational hint carried on the extractor; used by callers to sanity-check.
-         start_key: Field name for the start node identifier.
-         end_key: Field name for the end node identifier.
-
-     Returns:
-         Callable[[Mapping[str, Any]], tuple[int, int]]: Maps {start_key, end_key} → (row, col).
-     """
-
-     def extractor(item: Dict[str, Any]) -> Tuple[int, int]:
-         missing = [k for k in (start_key, end_key) if k not in item]
-         if missing:
-             raise KeyError(f"Item missing required keys: {', '.join(missing)}")
-         try:
-             row = int(str(item[start_key])[-1])
-             col = int(str(item[end_key])[-1])
-         except Exception as e:
-             raise ValueError(f"Failed to extract ID from item {item}: {e}")
-         return row, col
-
-     extractor.table_size = table_size
-     return extractor
-
-
- class SplittingBatchProcessor(BatchProcessor):
-     """
-     BatchProcessor that splits incoming BatchResults into non-overlapping partitions based
-     on row/col indices extracted by the id_extractor, and emits full or remaining batches
-     using the mix-and-batch algorithm from https://neo4j.com/blog/developer/mix-and-batch-relationship-load/
-     Each emitted batch is a list of per-cell lists (array of arrays), so callers
-     can process each partition (other name for a cell) in parallel.
-
-     A batch for a schedule group is emitted when all cells in that group have at least `batch_size` items.
-     In addition, when a cell/partition reaches 3x the configured max_batch_size, the group is emitted to avoid
-     overflowing the buffer when the distribution per cell is uneven.
-     Leftovers are flushed after source exhaustion.
-     Emitted batches never exceed the configured max_batch_size.
-     """
-
-     def __init__(self, context, table_size: int, id_extractor: Callable[[Any], Tuple[int, int]],
-                  task=None, predecessor=None):
-         super().__init__(context, task, predecessor)
-
-         # If the extractor carries an expected table size, use or validate it
-         if hasattr(id_extractor, "table_size"):
-             expected_size = id_extractor.table_size
-             if table_size is None:
-                 table_size = expected_size  # determine table size from extractor if not provided
-             elif table_size != expected_size:
-                 raise ValueError(f"Mismatch between provided table_size ({table_size}) and "
-                                  f"id_extractor table_size ({expected_size}).")
-         elif table_size is None:
-             raise ValueError("table_size must be specified if id_extractor has no defined table_size")
-         self.table_size = table_size
-         self._id_extractor = id_extractor
-
-         # Initialize 2D buffer for partitions
-         self.buffer: Dict[int, Dict[int, List[Any]]] = {
-             i: {j: [] for j in range(self.table_size)} for i in range(self.table_size)
-         }
-         self.logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")
-
-     def _generate_batch_schedule(self) -> List[List[Tuple[int, int]]]:
-         """
-         Create diagonal stripes (row, col) partitions to ensure no overlapping IDs
-         across emitted batches.
-         Example grid:
-                  ||  0  |  1  |  2
-             =====++=====+=====+=====
-               0  ||  0  |  1  |  2
-             -----++-----+-----+-----
-               1  ||  2  |  0  |  1
-             -----++-----+-----+-----
-               2  ||  1  |  2  |  0
-
-         would return [[(0, 0), (1, 1), (2, 2)], [(0, 1), (1, 2), (2, 0)], [(0, 2), (1, 0), (2, 1)]]
-         """
-         schedule: List[List[Tuple[int, int]]] = []
-         for shift in range(self.table_size):
-             partition: List[Tuple[int, int]] = [
-                 (i, (i + shift) % self.table_size)
-                 for i in range(self.table_size)
-             ]
-             schedule.append(partition)
-         return schedule
-
-     def _flush_group(
-         self,
-         partitions: List[Tuple[int, int]],
-         batch_size: int,
-         statistics: Dict[str, Any] | None = None,
-     ) -> Generator[BatchResults, None, None]:
-         """
-         Extract up to `batch_size` items from each cell in `partitions`, remove them from the buffer,
-         and yield a BatchResults whose chunks is a list of per-cell lists from these partitions.
-
-         Args:
-             partitions: Cell coordinates forming a diagonal group from the schedule.
-             batch_size: Max number of items to take from each cell.
-             statistics: Stats dict to attach to this emission (use {} for interim waves).
-                 The "final" emission will pass the accumulated stats here.
-
-         Notes:
-             - Debug-only: prints a 2D matrix of cell sizes when logger is in DEBUG.
-             - batch_size in the returned BatchResults equals the number of emitted items.
-         """
-         self._log_buffer_matrix(partition=partitions)
-
-         cell_chunks: List[List[Any]] = []
-         for row, col in partitions:
-             q = self.buffer[row][col]
-             take = min(batch_size, len(q))
-             part = q[:take]
-             cell_chunks.append(part)
-             # remove flushed items
-             self.buffer[row][col] = q[take:]
-
-         emitted = sum(len(c) for c in cell_chunks)
-
-         result = BatchResults(
-             chunk=cell_chunks,
-             statistics=statistics or {},
-             batch_size=emitted,
-         )
-         yield result
-
-     def get_batch(self, max_batch__size: int) -> Generator[BatchResults, None, None]:
-         """
-         Consume upstream batches, split data across cells, and emit diagonal partitions:
-         - During streaming: emit a full partition when all its cells have >= max_batch__size.
-         - Also during streaming: if any cell in a partition builds up beyond a 'burst' threshold
-           (3 * of max_batch__size), emit that partition early, taking up to max_batch__size
-           from each cell.
-         - After source exhaustion: flush leftovers in waves capped at max_batch__size per cell.
-
-         Statistics policy:
-         * Every emission except the last carries {}.
-         * The last emission carries the accumulated upstream stats (unfiltered).
-         """
-         schedule = self._generate_batch_schedule()
-
-         accum_stats: Dict[str, Any] = {}
-         pending: BatchResults | None = None  # hold back the most recent emission so we know what's final
-
-         burst_threshold = 3 * max_batch__size
-
-         for upstream in self.predecessor.get_batch(max_batch__size):
-             # accumulate upstream stats (unfiltered)
-             if upstream.statistics:
-                 accum_stats = merge_summery(accum_stats, upstream.statistics)
-
-             # add data to cells
-             for item in upstream.chunk:
-                 r, c = self._id_extractor(item)
-                 if not (0 <= r < self.table_size and 0 <= c < self.table_size):
-                     raise ValueError(f"partition id out of range: {(r, c)} for table_size={self.table_size}")
-                 self.buffer[r][c].append(item)
-
-             # process partitions
-             for partition in schedule:
-                 # normal full flush when all cells are ready
-                 if all(len(self.buffer[r][c]) >= max_batch__size for r, c in partition):
-                     br = next(self._flush_group(partition, max_batch__size, statistics={}))
-                     if pending is not None:
-                         yield pending
-                     pending = br
-                     continue
-
-                 # burst flush if any cell backlog explodes
-                 hot_cells = [(r, c, len(self.buffer[r][c])) for r, c in partition if
-                              len(self.buffer[r][c]) >= burst_threshold]
-                 if hot_cells:
-                     top_r, top_c, top_len = max(hot_cells, key=lambda x: x[2])
-                     self.logger.debug(
-                         "burst flush: partition=%s threshold=%d top_cell=(%d,%d len=%d)",
-                         partition, burst_threshold, top_r, top_c, top_len
-                     )
-                     br = next(self._flush_group(partition, max_batch__size, statistics={}))
-                     if pending is not None:
-                         yield pending
-                     pending = br
-
-         # source exhausted: drain leftovers in capped waves (respecting batch size)
-         self.logger.debug("start flushing leftovers")
-         for partition in (p for p in schedule if any(self.buffer[r][c] for r, c in p)):
-             while any(self.buffer[r][c] for r, c in partition):
-                 br = next(self._flush_group(partition, max_batch__size, statistics={}))
-                 if pending is not None:
-                     yield pending
-                 pending = br
-
-         # final emission carries accumulated stats once
-         if pending is not None:
-             yield BatchResults(chunk=pending.chunk, statistics=accum_stats, batch_size=pending.batch_size)
-
-     def _log_buffer_matrix(self, *, partition: List[Tuple[int, int]]) -> None:
-         """
-         Dumps a compact 2D matrix of per-cell sizes (len of each buffer) when DEBUG is enabled.
-         """
-         if not self.logger.isEnabledFor(logging.DEBUG):
-             return
-
-         counts = [
-             [len(self.buffer[r][c]) for c in range(self.table_size)]
-             for r in range(self.table_size)
-         ]
-         marks = set(partition)
-
-         pad = max(2, len(str(self.table_size - 1)))
-         col_headers = [f"c{c:0{pad}d}" for c in range(self.table_size)]
-
-         rows = []
-         for r in range(self.table_size):
-             row_label = f"r{r:0{pad}d}"
-             row_vals = [f"[{v}]" if (r, c) in marks else f"{v}" for c, v in enumerate(counts[r])]
-             rows.append([row_label, *row_vals])
-
-         table = tabulate(
-             rows,
-             headers=["", *col_headers],
-             tablefmt="psql",
-             stralign="right",
-             disable_numparse=True,
-         )
-         self.logger.debug("buffer matrix:\n%s", table)
@@ -1,114 +0,0 @@
- import time
- from typing import Generator, Callable, Optional, List, Dict
-
- from psycopg2 import OperationalError as PsycopgOperationalError
- from sqlalchemy import text
- from sqlalchemy.exc import OperationalError as SAOperationalError, DBAPIError
-
- from etl_lib.core.BatchProcessor import BatchResults, BatchProcessor
- from etl_lib.core.ETLContext import ETLContext
- from etl_lib.core.Task import Task
-
-
- class SQLBatchSource(BatchProcessor):
-     def __init__(
-         self,
-         context: ETLContext,
-         task: Task,
-         query: str,
-         record_transformer: Optional[Callable[[dict], dict]] = None,
-         **kwargs
-     ):
-         """
-         Constructs a new SQLBatchSource.
-
-         Args:
-             context: :class:`etl_lib.core.ETLContext.ETLContext` instance.
-             task: :class:`etl_lib.core.Task.Task` instance owning this batchProcessor.
-             query: SQL query to execute.
-             record_transformer: Optional function to transform each row (dict format).
-             kwargs: Arguments passed as parameters with the query.
-         """
-         super().__init__(context, task)
-         self.query = query.strip().rstrip(";")
-         self.record_transformer = record_transformer
-         self.kwargs = kwargs
-
-     def _fetch_page(self, limit: int, offset: int) -> Optional[List[Dict]]:
-         """
-         Fetch a single batch of rows using LIMIT/OFFSET, with retry/backoff.
-
-         Each page is executed in its own transaction. On transient
-         disconnects or DB errors, it retries up to 3 times with exponential backoff.
-
-         Args:
-             limit: maximum number of rows to return.
-             offset: number of rows to skip before starting this page.
-
-         Returns:
-             A list of row dicts (after applying record_transformer, if any),
-             or None if no rows are returned.
-
-         Raises:
-             Exception: re-raises the last caught error on final failure.
-         """
-         paged_sql = f"{self.query} LIMIT :limit OFFSET :offset"
-         params = {**self.kwargs, "limit": limit, "offset": offset}
-         max_retries = 5
-         backoff = 2.0
-
-         for attempt in range(1, max_retries + 1):
-             try:
-                 with self.context.sql.engine.connect() as conn:
-                     with conn.begin():
-                         rows = conn.execute(text(paged_sql), params).mappings().all()
-                         result = [
-                             self.record_transformer(dict(r)) if self.record_transformer else dict(r)
-                             for r in rows
-                         ]
-                         return result if result else None
-
-             except (PsycopgOperationalError, SAOperationalError, DBAPIError) as err:
-
-                 if attempt == max_retries:
-                     self.logger.error(
-                         f"Page fetch failed after {max_retries} attempts "
-                         f"(limit={limit}, offset={offset}): {err}"
-                     )
-                     raise
-
-                 self.logger.warning(
-                     f"Transient DB error on page fetch {attempt}/{max_retries}: {err!r}, "
-                     f"retrying in {backoff:.1f}s"
-                 )
-                 time.sleep(backoff)
-                 backoff *= 2
-
-         return None
-
-     def get_batch(self, max_batch_size: int) -> Generator[BatchResults, None, None]:
-         """
-         Yield successive batches until the query is exhausted.
-
-         Calls _fetch_page() repeatedly, advancing the offset by the
-         number of rows returned. Stops when no more rows are returned.
-
-         Args:
-             max_batch_size: upper limit on rows per batch.
-
-         Yields:
-             BatchResults for each non-empty page.
-         """
-         offset = 0
-         while True:
-             chunk = self._fetch_page(max_batch_size, offset)
-             if not chunk:
-                 break
-
-             yield BatchResults(
-                 chunk=chunk,
-                 statistics={"sql_rows_read": len(chunk)},
-                 batch_size=len(chunk),
-             )
-
-             offset += len(chunk)