neo4j-etl-lib 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
etl_lib/__init__.py CHANGED
@@ -1,4 +1,4 @@
  """
  Building blocks for ETL pipelines.
  """
- __version__ = "0.3.1"
+ __version__ = "0.3.2"
etl_lib/core/ETLContext.py CHANGED
@@ -172,15 +172,8 @@ if sqlalchemy_available:
  database_url,
  pool_pre_ping=True,
  pool_size=pool_size,
- max_overflow=max_overflow,
- pool_recycle=1800, # recycle connections older than 30m
- connect_args={
- # turn on TCP keepalives on the client socket:
- "keepalives": 1,
- "keepalives_idle": 60, # after 60s of idle
- "keepalives_interval": 10, # probe every 10s
- "keepalives_count": 5, # give up after 5 failed probes
- })
+ max_overflow=max_overflow
+ )


  class ETLContext:
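
The connect_args removed above set libpq TCP keepalive options on the client socket. Deployments that still want client-side keepalives after this change can pass the same options to SQLAlchemy themselves; a minimal sketch, assuming a PostgreSQL URL and the psycopg2 driver (the URL and pool settings below are placeholders, not values from this package):

from sqlalchemy import create_engine

engine = create_engine(
    "postgresql+psycopg2://user:password@host/db",   # placeholder URL
    pool_pre_ping=True,
    connect_args={
        "keepalives": 1,              # enable TCP keepalives on the client socket
        "keepalives_idle": 60,        # start probing after 60s of idle
        "keepalives_interval": 10,    # probe every 10s
        "keepalives_count": 5,        # give up after 5 failed probes
    },
)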
etl_lib/core/ParallelBatchProcessor.py CHANGED
@@ -9,41 +9,38 @@ from etl_lib.core.utils import merge_summery

  class ParallelBatchResult(BatchResults):
  """
- Represents a batch split into parallelizable partitions.
+ Represents one *wave* produced by the splitter.

- `chunk` is a list of lists, each sub-list is a partition.
+ `chunk` is a list of bucket-batches. Each sub-list is processed by one worker instance.
  """
  pass


  class ParallelBatchProcessor(BatchProcessor):
  """
- BatchProcessor that runs worker threads over partitions of batches.
-
- Receives a special BatchResult (:py:class:`ParallelBatchResult`) from the predecessor.
- All chunks in a ParallelBatchResult it receives can be processed in parallel.
- See :py:class:`etl_lib.core.SplittingBatchProcessor` on how to produce them.
- Prefetches the next ParallelBatchResults from its predecessor.
- The actual processing of the batches is deferred to the configured worker.
+ BatchProcessor that runs a worker over the bucket-batches of each ParallelBatchResult
+ in parallel threads, while prefetching the next ParallelBatchResult from its predecessor.

  Note:
- - The predecessor must emit `ParallelBatchResult` instances.
+ - The predecessor must emit `ParallelBatchResult` instances (waves).
+ - This processor collects the BatchResults from all workers for one wave and merges them
+ into one BatchResults.
+ - The returned BatchResults will not obey the max_batch_size from get_batch() because
+ it represents the full wave.

  Args:
  context: ETL context.
- worker_factory: A zero-arg callable that returns a new BatchProcessor
- each time it's called. This processor is responsible for the processing pf the batches.
+ worker_factory: A zero-arg callable that returns a new BatchProcessor each time it's called.
  task: optional Task for reporting.
- predecessor: upstream BatchProcessor that must emit ParallelBatchResult.
- max_workers: number of parallel threads for partitions.
- prefetch: number of ParallelBatchResults to prefetch from the predecessor.
+ predecessor: upstream BatchProcessor that must emit ParallelBatchResult See
+ :class:`~etl_lib.core.SplittingBatchProcessor.SplittingBatchProcessor`.
+ max_workers: number of parallel threads for bucket processing.
+ prefetch: number of waves to prefetch.

  Behavior:
- - For every ParallelBatchResult, spins up `max_workers` threads.
- - Each thread calls its own worker from `worker_factory()`, with its
- partition wrapped by `SingleBatchWrapper`.
- - Collects and merges their BatchResults in a fail-fast manner: on first
- exception, logs the error, cancels remaining threads, and raises an exception.
+ - For every wave, spins up `max_workers` threads.
+ - Each thread processes one bucket-batch using a fresh worker from `worker_factory()`.
+ - Collects and merges worker results in a fail-fast manner.
  """

  def __init__(
@@ -59,122 +56,95 @@ class ParallelBatchProcessor(BatchProcessor):
  self.worker_factory = worker_factory
  self.max_workers = max_workers
  self.prefetch = prefetch
- self._batches_done = 0

- def _process_parallel(self, pbr: ParallelBatchResult) -> BatchResults:
+ def _process_wave(self, wave: ParallelBatchResult) -> BatchResults:
  """
- Run one worker per partition in `pbr.chunk`, merge their outputs, and include upstream
- statistics from `pbr.statistics` so counters (e.g., valid/invalid rows from validation)
- are preserved through the parallel stage.
+ Process one wave: run one worker per bucket-batch and merge their BatchResults.

- Progress reporting:
- - After each partition completes, report batch count only
+ Statistics:
+ `wave.statistics` is used as the initial merged stats, then merged with each worker's stats.
  """
- merged_stats = dict(pbr.statistics or {})
+ merged_stats = dict(wave.statistics or {})
  merged_chunk = []
  total = 0

- parts_total = len(pbr.chunk)
- partitions_done = 0
+ self.logger.debug(f"Processing wave with {len(wave.chunk)} buckets")

- self.logger.debug(f"Processing pbr of len {parts_total}")
- with ThreadPoolExecutor(max_workers=self.max_workers, thread_name_prefix='PBP_worker_') as pool:
- futures = [pool.submit(self._process_partition, part) for part in pbr.chunk]
+ with ThreadPoolExecutor(max_workers=self.max_workers, thread_name_prefix="PBP_worker_") as pool:
+ futures = [pool.submit(self._process_bucket_batch, bucket_batch) for bucket_batch in wave.chunk]
  try:
  for f in as_completed(futures):
  out = f.result()
-
- # Merge into this PBR's cumulative result (returned downstream)
  merged_stats = merge_summery(merged_stats, out.statistics or {})
  total += out.batch_size
  merged_chunk.extend(out.chunk if isinstance(out.chunk, list) else [out.chunk])
-
- partitions_done += 1
- self.context.reporter.report_progress(
- task=self.task,
- batches=self._batches_done,
- expected_batches=None,
- stats={},
- )
-
  except Exception as e:
  for g in futures:
  g.cancel()
  pool.shutdown(cancel_futures=True)
- raise RuntimeError("partition processing failed") from e
+ raise RuntimeError("bucket processing failed") from e

- self.logger.debug(f"Finished processing pbr with {merged_stats}")
+ self.logger.debug(f"Finished wave with stats={merged_stats}")
  return BatchResults(chunk=merged_chunk, statistics=merged_stats, batch_size=total)

  def get_batch(self, max_batch_size: int) -> Generator[BatchResults, None, None]:
  """
- Pulls ParallelBatchResult batches from the predecessor, prefetching
- up to `prefetch` ahead, processes each batch's partitions in
- parallel threads, and yields a flattened BatchResults. The predecessor
- can run ahead while the current batch is processed.
+ Pull waves from the predecessor (prefetching up to `prefetch` ahead), process each wave's
+ buckets in parallel, and yield one flattened BatchResults per wave.
  """
- pbr_queue: queue.Queue[ParallelBatchResult | object] = queue.Queue(self.prefetch)
+ wave_queue: queue.Queue[ParallelBatchResult | object] = queue.Queue(self.prefetch)
  SENTINEL = object()
  exc: BaseException | None = None

  def producer():
  nonlocal exc
  try:
- for pbr in self.predecessor.get_batch(max_batch_size):
+ for wave in self.predecessor.get_batch(max_batch_size):
  self.logger.debug(
- f"adding pgr {pbr.statistics} / {len(pbr.chunk)} to queue of size {pbr_queue.qsize()}"
+ f"adding wave stats={wave.statistics} buckets={len(wave.chunk)} to queue size={wave_queue.qsize()}"
  )
- pbr_queue.put(pbr)
+ wave_queue.put(wave)
  except BaseException as e:
  exc = e
  finally:
- pbr_queue.put(SENTINEL)
+ wave_queue.put(SENTINEL)

- threading.Thread(target=producer, daemon=True, name='prefetcher').start()
+ threading.Thread(target=producer, daemon=True, name="prefetcher").start()

  while True:
- pbr = pbr_queue.get()
- if pbr is SENTINEL:
+ wave = wave_queue.get()
+ if wave is SENTINEL:
  if exc is not None:
  self.logger.error("Upstream producer failed", exc_info=True)
  raise exc
  break
- result = self._process_parallel(pbr)
- yield result
+ yield self._process_wave(wave)

  class SingleBatchWrapper(BatchProcessor):
  """
- Simple BatchProcessor that returns the batch it receives via init.
- Will be used as predecessor for the worker
+ Simple BatchProcessor that returns exactly one batch (the bucket-batch passed in via init).
+ Used as predecessor for the per-bucket worker.
  """

  def __init__(self, context, batch: List[Any]):
  super().__init__(context=context, predecessor=None)
  self._batch = batch

- def get_batch(self, max_batch__size: int) -> Generator[BatchResults, None, None]:
- # Ignores max_size; yields exactly one BatchResults containing the whole batch
+ def get_batch(self, max_size: int) -> Generator[BatchResults, None, None]:
  yield BatchResults(
  chunk=self._batch,
  statistics={},
- batch_size=len(self._batch)
+ batch_size=len(self._batch),
  )

- def _process_partition(self, partition):
+ def _process_bucket_batch(self, bucket_batch):
  """
- Processes one partition of items by:
- 1. Wrapping it in SingleBatchWrapper
- 2. Instantiating a fresh worker via worker_factory()
- 3. Setting the worker's predecessor to the wrapper
- 4. Running exactly one batch and returning its BatchResults
-
- Raises whatever exception the worker raises, allowing _process_parallel
- to handle fail-fast behavior.
+ Process one bucket-batch by running a fresh worker over it.
  """
- self.logger.debug("Processing partition")
- wrapper = self.SingleBatchWrapper(self.context, partition)
+ self.logger.debug(f"Processing batch w/ size {len(bucket_batch)}")
+ wrapper = self.SingleBatchWrapper(self.context, bucket_batch)
  worker = self.worker_factory()
  worker.predecessor = wrapper
- result = next(worker.get_batch(len(partition)))
- self.logger.debug(f"finished processing partition with {result.statistics}")
+ result = next(worker.get_batch(len(bucket_batch)))
+ self.logger.debug(f"Finished bucket batch stats={result.statistics}")
  return result
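
Taken together with the splitter, this processor forms a two-stage pipeline: the splitter emits waves of non-overlapping bucket-batches, and this processor runs one fresh worker per bucket-batch. A minimal wiring sketch, based only on the signatures visible in this diff; `context`, `source` and `make_worker` are placeholders, and the id-extractor key names are illustrative:

splitter = SplittingBatchProcessor(
    context=context,                        # placeholder ETLContext
    table_size=10,
    id_extractor=dict_id_extractor(table_size=10, start_key="start", end_key="end"),
    predecessor=source,                     # placeholder upstream BatchProcessor
)
parallel = ParallelBatchProcessor(
    context=context,
    worker_factory=make_worker,             # placeholder zero-arg factory; fresh worker per bucket-batch
    predecessor=splitter,
    max_workers=4,
    prefetch=2,
)
for wave_result in parallel.get_batch(5_000):
    pass                                    # each yielded BatchResults covers one fully processed wave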
etl_lib/core/SplittingBatchProcessor.py CHANGED
@@ -16,6 +16,10 @@ def tuple_id_extractor(table_size: int = 10) -> Callable[[Tuple[str | int, str |

  Returns:
  A callable that maps a tuple `(a, b)` to a tuple `(row, col)` using the last digit of `a` and `b`.
+
+ Notes:
+ This extractor does not validate the returned indices. Range validation is performed by
+ :class:`~etl_lib.core.SplittingBatchProcessor.SplittingBatchProcessor`.
  """

  def extractor(item: Tuple[Any, Any]) -> Tuple[int, int]:
@@ -39,15 +43,14 @@ def dict_id_extractor(
  """
  Build an ID extractor for dict rows. The extractor reads two fields (configurable via
  `start_key` and `end_key`) and returns (row, col) based on the last decimal digit of each.
- Range validation remains the responsibility of the SplittingBatchProcessor.

  Args:
- table_size: Informational hint carried on the extractor; used by callers to sanity-check.
+ table_size: Informational hint carried on the extractor.
  start_key: Field name for the start node identifier.
  end_key: Field name for the end node identifier.

  Returns:
- Callable[[Mapping[str, Any]], tuple[int, int]]: Maps {start_key, end_key} → (row, col).
+ Callable that maps {start_key, end_key} → (row, col).
  """

  def extractor(item: Dict[str, Any]) -> Tuple[int, int]:
@@ -65,180 +68,224 @@ def dict_id_extractor(
  return extractor


+ def canonical_integer_id_extractor(
+ table_size: int = 10,
+ start_key: str = "start",
+ end_key: str = "end",
+ ) -> Callable[[Dict[str, Any]], Tuple[int, int]]:
+ """
+ ID extractor for integer IDs with canonical folding.
+
+ - Uses Knuth's multiplicative hashing to scatter sequential integers across the grid.
+ - Canonical folding ensures (A,B) and (B,A) land in the same bucket by folding the lower
+ triangle into the upper triangle (row <= col).
+
+ The extractor marks itself as mono-partite by setting `extractor.monopartite = True`.
+ """
+ MAGIC = 2654435761
+
+ def extractor(item: Dict[str, Any]) -> Tuple[int, int]:
+ try:
+ s_val = item[start_key]
+ e_val = item[end_key]
+
+ s_hash = (s_val * MAGIC) & 0xffffffff
+ e_hash = (e_val * MAGIC) & 0xffffffff
+
+ row = s_hash % table_size
+ col = e_hash % table_size
+
+ if row > col:
+ row, col = col, row
+
+ return row, col
+ except KeyError:
+ raise KeyError(f"Item missing keys: {start_key} or {end_key}")
+ except Exception as e:
+ raise ValueError(f"Failed to extract ID: {e}")
+
+ extractor.table_size = table_size
+ extractor.monopartite = True
+ return extractor
+
+
  class SplittingBatchProcessor(BatchProcessor):
  """
- BatchProcessor that splits incoming BatchResults into non-overlapping partitions based
- on row/col indices extracted by the id_extractor, and emits full or remaining batches
- using the mix-and-batch algorithm from https://neo4j.com/blog/developer/mix-and-batch-relationship-load/
- Each emitted batch is a list of per-cell lists (array of arrays), so callers
- can process each partition (other name for a cell) in parallel.
-
- A batch for a schedule group is emitted when all cells in that group have at least `batch_size` items.
- In addition, when a cell/partition reaches 3x the configured max_batch_size, the group is emitted to avoid
- overflowing the buffer when the distribution per cell is uneven.
- Leftovers are flushed after source exhaustion.
- Emitted batches never exceed the configured max_batch_size.
+ Streaming wave scheduler for mix-and-batch style loading.
+
+ Incoming rows are assigned to buckets via an `id_extractor(item) -> (row, col)` inside a
+ `table_size x table_size` grid. The processor emits waves; each wave contains bucket-batches
+ that are safe to process concurrently under the configured non-overlap rule.
+
+ Non-overlap rules
+ -----------------
+ - Bi-partite (default): within a wave, no two buckets share a row index and no two buckets share a col index.
+ - Mono-partite: within a wave, no node index is used more than once (row/col indices are the same domain).
+ Enable by setting `id_extractor.monopartite = True` (as done by `canonical_integer_id_extractor`).
+
+ Emission strategy
+ -----------------
+ - During streaming: emit a wave when at least one bucket is full (>= max_batch_size).
+ The wave is then filled with additional non-overlapping buckets that are near-full to
+ keep parallelism high without producing tiny batches.
+ - If a bucket backlog grows beyond a burst threshold, emit a burst wave to bound memory.
+ - After source exhaustion: flush leftovers in capped waves (max_batch_size per bucket).
+
+ Statistics policy
+ -----------------
+ - Every emission except the last carries {}.
+ - The last emission carries the accumulated upstream statistics (unfiltered).
  """

- def __init__(self, context, table_size: int, id_extractor: Callable[[Any], Tuple[int, int]],
- task=None, predecessor=None):
+ def __init__(
+ self,
+ context,
+ table_size: int,
+ id_extractor: Callable[[Any], Tuple[int, int]],
+ task=None,
+ predecessor=None,
+ near_full_ratio: float = 0.85,
+ burst_multiplier: int = 25,
+ ):
  super().__init__(context, task, predecessor)

- # If the extractor carries an expected table size, use or validate it
  if hasattr(id_extractor, "table_size"):
  expected_size = id_extractor.table_size
  if table_size is None:
- table_size = expected_size # determine table size from extractor if not provided
+ table_size = expected_size
  elif table_size != expected_size:
- raise ValueError(f"Mismatch between provided table_size ({table_size}) and "
- f"id_extractor table_size ({expected_size}).")
+ raise ValueError(
+ f"Mismatch between provided table_size ({table_size}) and id_extractor table_size ({expected_size})."
+ )
  elif table_size is None:
  raise ValueError("table_size must be specified if id_extractor has no defined table_size")
+
+ if not (0 < near_full_ratio <= 1.0):
+ raise ValueError(f"near_full_ratio must be in (0, 1], got {near_full_ratio}")
+ if burst_multiplier < 1:
+ raise ValueError(f"burst_multiplier must be >= 1, got {burst_multiplier}")
+
  self.table_size = table_size
  self._id_extractor = id_extractor
+ self._monopartite = bool(getattr(id_extractor, "monopartite", False))
+
+ self.near_full_ratio = float(near_full_ratio)
+ self.burst_multiplier = int(burst_multiplier)

- # Initialize 2D buffer for partitions
  self.buffer: Dict[int, Dict[int, List[Any]]] = {
- i: {j: [] for j in range(self.table_size)} for i in range(self.table_size)
+ r: {c: [] for c in range(self.table_size)}
+ for r in range(self.table_size)
  }
  self.logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")

- def _generate_batch_schedule(self) -> List[List[Tuple[int, int]]]:
+ def _bucket_claims(self, row: int, col: int) -> Tuple[Any, ...]:
  """
- Create diagonal stripes (row, col) partitions to ensure no overlapping IDs
- across emitted batches.
- Example grid:
- || 0 | 1 | 2
- =====++=====+=====+=====
- 0 || 0 | 1 | 2
- -----++-----+-----+----
- 1 || 2 | 0 | 1
- -----++-----+-----+-----
- 2 || 1 | 2 | 0
-
- would return [[(0, 0), (1, 1), (2, 2)], [(0, 1), (1, 2), (2, 0)], [(0, 2), (1, 0), (2, 1)]]
+ Return the resource claims a bucket consumes within a wave.
+
+ - Bi-partite: claims (row-slot, col-slot)
+ - Mono-partite: claims node indices touched by the bucket
  """
- schedule: List[List[Tuple[int, int]]] = []
- for shift in range(self.table_size):
- partition: List[Tuple[int, int]] = [
- (i, (i + shift) % self.table_size)
- for i in range(self.table_size)
- ]
- schedule.append(partition)
- return schedule
-
- def _flush_group(
- self,
- partitions: List[Tuple[int, int]],
- batch_size: int,
- statistics: Dict[str, Any] | None = None,
- ) -> Generator[BatchResults, None, None]:
+ if self._monopartite:
+ return (row,) if row == col else (row, col)
+ return ("row", row), ("col", col)
+
+ def _all_bucket_sizes(self) -> List[Tuple[int, int, int]]:
  """
- Extract up to `batch_size` items from each cell in `partitions`, remove them from the buffer,
- and yield a BatchResults whose chunks is a list of per-cell lists from these partitions.
-
- Args:
- partitions: Cell coordinates forming a diagonal group from the schedule.
- batch_size: Max number of items to take from each cell.
- statistics: Stats dict to attach to this emission (use {} for interim waves).
- The "final" emission will pass the accumulated stats here.
-
- Notes:
- - Debug-only: prints a 2D matrix of cell sizes when logger is in DEBUG.
- - batch_size in the returned BatchResults equals the number of emitted items.
+ Return all non-empty buckets as (size, row, col).
  """
- self._log_buffer_matrix(partition=partitions)
-
- cell_chunks: List[List[Any]] = []
- for row, col in partitions:
- q = self.buffer[row][col]
- take = min(batch_size, len(q))
- part = q[:take]
- cell_chunks.append(part)
- # remove flushed items
- self.buffer[row][col] = q[take:]
-
- emitted = sum(len(c) for c in cell_chunks)
-
- result = BatchResults(
- chunk=cell_chunks,
- statistics=statistics or {},
- batch_size=emitted,
- )
- yield result
+ out: List[Tuple[int, int, int]] = []
+ for r in range(self.table_size):
+ for c in range(self.table_size):
+ n = len(self.buffer[r][c])
+ if n > 0:
+ out.append((n, r, c))
+ return out

- def get_batch(self, max_batch__size: int) -> Generator[BatchResults, None, None]:
+ def _select_wave(self, *, min_bucket_len: int, seed: List[Tuple[int, int]] | None = None) -> List[Tuple[int, int]]:
  """
- Consume upstream batches, split data across cells, and emit diagonal partitions:
- - During streaming: emit a full partition when all its cells have >= max_batch__size.
- - Also during streaming: if any cell in a partition builds up beyond a 'burst' threshold
- (3 * of max_batch__size), emit that partition early, taking up to max_batch__size
- from each cell.
- - After source exhaustion: flush leftovers in waves capped at max_batch__size per cell.
+ Greedy wave scheduler: pick a non-overlapping set of buckets with len >= min_bucket_len.

- Statistics policy:
- * Every emission except the last carries {}.
- * The last emission carries the accumulated upstream stats (unfiltered).
+ If `seed` is provided, it is taken as fixed and the wave is extended greedily.
  """
- schedule = self._generate_batch_schedule()
-
- accum_stats: Dict[str, Any] = {}
- pending: BatchResults | None = None # hold back the most recent emission so we know what's final
-
- burst_threshold = 3 * max_batch__size
-
- for upstream in self.predecessor.get_batch(max_batch__size):
- # accumulate upstream stats (unfiltered)
- if upstream.statistics:
- accum_stats = merge_summery(accum_stats, upstream.statistics)
+ candidates: List[Tuple[int, int, int]] = []
+ for r in range(self.table_size):
+ for c in range(self.table_size):
+ n = len(self.buffer[r][c])
+ if n >= min_bucket_len:
+ candidates.append((n, r, c))
+
+ if not candidates and not seed:
+ return []
+
+ candidates.sort(key=lambda x: (-x[0], x[1], x[2]))
+
+ used: set[Any] = set()
+ wave: List[Tuple[int, int]] = []
+
+ if seed:
+ for r, c in seed:
+ claims = self._bucket_claims(r, c)
+ used.update(claims)
+ wave.append((r, c))
+
+ for _, r, c in candidates:
+ if (r, c) in wave:
+ continue
+ claims = self._bucket_claims(r, c)
+ if any(claim in used for claim in claims):
+ continue
+ wave.append((r, c))
+ used.update(claims)
+ if len(wave) >= self.table_size:
+ break
+
+ return wave
+
+ def _find_hottest_bucket(self, *, threshold: int) -> Tuple[int, int, int] | None:
+ """
+ Find the single hottest bucket (largest backlog) whose size is >= threshold.
+ Returns (row, col, size) or None.
+ """
+ best: Tuple[int, int, int] | None = None
+ for r in range(self.table_size):
+ for c in range(self.table_size):
+ n = len(self.buffer[r][c])
+ if n < threshold:
+ continue
+ if best is None or n > best[2]:
+ best = (r, c, n)
+ return best

- # add data to cells
- for item in upstream.chunk:
- r, c = self._id_extractor(item)
- if not (0 <= r < self.table_size and 0 <= c < self.table_size):
- raise ValueError(f"partition id out of range: {(r, c)} for table_size={self.table_size}")
- self.buffer[r][c].append(item)
+ def _flush_wave(
+ self,
+ wave: List[Tuple[int, int]],
+ max_batch_size: int,
+ statistics: Dict[str, Any] | None = None,
+ ) -> BatchResults:
+ """
+ Extract up to `max_batch_size` items from each bucket in `wave`, remove them from the buffer,
+ and return a BatchResults whose chunk is a list of per-bucket lists (aligned with `wave`).
+ """
+ self._log_buffer_matrix(wave=wave)

- # process partitions
- for partition in schedule:
- # normal full flush when all cells are ready
- if all(len(self.buffer[r][c]) >= max_batch__size for r, c in partition):
- br = next(self._flush_group(partition, max_batch__size, statistics={}))
- if pending is not None:
- yield pending
- pending = br
- continue
+ bucket_batches: List[List[Any]] = []
+ for r, c in wave:
+ q = self.buffer[r][c]
+ take = min(max_batch_size, len(q))
+ bucket_batches.append(q[:take])
+ self.buffer[r][c] = q[take:]

- # burst flush if any cell backlog explodes
- hot_cells = [(r, c, len(self.buffer[r][c])) for r, c in partition if
- len(self.buffer[r][c]) >= burst_threshold]
- if hot_cells:
- top_r, top_c, top_len = max(hot_cells, key=lambda x: x[2])
- self.logger.debug(
- "burst flush: partition=%s threshold=%d top_cell=(%d,%d len=%d)",
- partition, burst_threshold, top_r, top_c, top_len
- )
- br = next(self._flush_group(partition, max_batch__size, statistics={}))
- if pending is not None:
- yield pending
- pending = br
-
- # source exhausted: drain leftovers in capped waves (respecting batch size)
- self.logger.debug("start flushing leftovers")
- for partition in (p for p in schedule if any(self.buffer[r][c] for r, c in p)):
- while any(self.buffer[r][c] for r, c in partition):
- br = next(self._flush_group(partition, max_batch__size, statistics={}))
- if pending is not None:
- yield pending
- pending = br
+ emitted = sum(len(b) for b in bucket_batches)

- # final emission carries accumulated stats once
- if pending is not None:
- yield BatchResults(chunk=pending.chunk, statistics=accum_stats, batch_size=pending.batch_size)
+ return BatchResults(
+ chunk=bucket_batches,
+ statistics=statistics or {},
+ batch_size=emitted,
+ )

- def _log_buffer_matrix(self, *, partition: List[Tuple[int, int]]) -> None:
+ def _log_buffer_matrix(self, *, wave: List[Tuple[int, int]]) -> None:
  """
- Dumps a compact 2D matrix of per-cell sizes (len of each buffer) when DEBUG is enabled.
+ Dumps a compact 2D matrix of per-bucket sizes (len of each buffer) when DEBUG is enabled.
  """
  if not self.logger.isEnabledFor(logging.DEBUG):
  return
@@ -247,7 +294,7 @@ class SplittingBatchProcessor(BatchProcessor):
  [len(self.buffer[r][c]) for c in range(self.table_size)]
  for r in range(self.table_size)
  ]
- marks = set(partition)
+ marks = set(wave)

  pad = max(2, len(str(self.table_size - 1)))
  col_headers = [f"c{c:0{pad}d}" for c in range(self.table_size)]
@@ -266,3 +313,79 @@ class SplittingBatchProcessor(BatchProcessor):
  disable_numparse=True,
  )
  self.logger.debug("buffer matrix:\n%s", table)
+
+ def get_batch(self, max_batch_size: int) -> Generator[BatchResults, None, None]:
+ """
+ Consume upstream batches, bucket incoming rows, and emit waves of non-overlapping buckets.
+
+ Streaming behavior:
+ - If at least one bucket is full (>= max_batch_size), emit a wave seeded with full buckets
+ and extended with near-full buckets to keep parallelism high.
+ - If a bucket exceeds a burst threshold, emit a burst wave (seeded by the hottest bucket)
+ and extended with near-full buckets.
+
+ End-of-stream behavior:
+ - Flush leftovers in capped waves (max_batch_size per bucket).
+
+ Statistics policy:
+ * Every emission except the last carries {}.
+ * The last emission carries the accumulated upstream stats (unfiltered).
+ """
+ if self.predecessor is None:
+ return
+
+ accum_stats: Dict[str, Any] = {}
+ pending: BatchResults | None = None
+
+ near_full_threshold = max(1, int(max_batch_size * self.near_full_ratio))
+ burst_threshold = self.burst_multiplier * max_batch_size
+
+ for upstream in self.predecessor.get_batch(max_batch_size):
+ if upstream.statistics:
+ accum_stats = merge_summery(accum_stats, upstream.statistics)
+
+ for item in upstream.chunk:
+ r, c = self._id_extractor(item)
+ if self._monopartite and r > c:
+ r, c = c, r
+ if not (0 <= r < self.table_size and 0 <= c < self.table_size):
+ raise ValueError(f"bucket id out of range: {(r, c)} for table_size={self.table_size}")
+ self.buffer[r][c].append(item)
+
+ while True:
+ full_seed = self._select_wave(min_bucket_len=max_batch_size)
+ if not full_seed:
+ break
+ wave = self._select_wave(min_bucket_len=near_full_threshold, seed=full_seed)
+ br = self._flush_wave(wave, max_batch_size, statistics={})
+ if pending is not None:
+ yield pending
+ pending = br
+
+ while True:
+ hot = self._find_hottest_bucket(threshold=burst_threshold)
+ if hot is None:
+ break
+ hot_r, hot_c, hot_n = hot
+ wave = self._select_wave(min_bucket_len=near_full_threshold, seed=[(hot_r, hot_c)])
+ self.logger.debug(
+ "burst flush: hottest_bucket=(%d,%d len=%d) threshold=%d near_full_threshold=%d wave_size=%d",
+ hot_r, hot_c, hot_n, burst_threshold, near_full_threshold, len(wave)
+ )
+ br = self._flush_wave(wave, max_batch_size, statistics={})
+ if pending is not None:
+ yield pending
+ pending = br
+
+ self.logger.debug("start flushing leftovers")
+ while True:
+ wave = self._select_wave(min_bucket_len=1)
+ if not wave:
+ break
+ br = self._flush_wave(wave, max_batch_size, statistics={})
+ if pending is not None:
+ yield pending
+ pending = br
+
+ if pending is not None:
+ yield BatchResults(chunk=pending.chunk, statistics=accum_stats, batch_size=pending.batch_size)
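
For orientation, the bucket assignment added in canonical_integer_id_extractor can be reproduced in isolation; a small standalone sketch of the same arithmetic (the example IDs are arbitrary):

MAGIC = 2654435761            # Knuth multiplicative hashing constant used above
TABLE_SIZE = 10

def bucket(a: int, b: int) -> tuple[int, int]:
    row = ((a * MAGIC) & 0xffffffff) % TABLE_SIZE
    col = ((b * MAGIC) & 0xffffffff) % TABLE_SIZE
    return (col, row) if row > col else (row, col)   # canonical folding: row <= col

# (A, B) and (B, A) fold into the same bucket, so both directions of a
# relationship land in the same bucket-batch within a wave.
assert bucket(12345, 678) == bucket(678, 12345)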
etl_lib/data_source/SQLBatchSource.py CHANGED
@@ -1,10 +1,16 @@
- import time
- from typing import Generator, Callable, Optional, List, Dict
+ import logging
+ from typing import Generator, Callable, Optional

- from psycopg2 import OperationalError as PsycopgOperationalError
  from sqlalchemy import text
  from sqlalchemy.exc import OperationalError as SAOperationalError, DBAPIError

+ # Conditional import for psycopg2 to avoid crashing if not installed
+ try:
+ from psycopg2 import OperationalError as PsycopgOperationalError
+ except ImportError:
+ class PsycopgOperationalError(Exception):
+ pass
+
  from etl_lib.core.BatchProcessor import BatchResults, BatchProcessor
  from etl_lib.core.ETLContext import ETLContext
  from etl_lib.core.Task import Task
@@ -20,95 +26,59 @@ class SQLBatchSource(BatchProcessor):
  **kwargs
  ):
  """
- Constructs a new SQLBatchSource.
-
- Args:
- context: :class:`etl_lib.core.ETLContext.ETLContext` instance.
- task: :class:`etl_lib.core.Task.Task` instance owning this batchProcessor.
- query: SQL query to execute.
- record_transformer: Optional function to transform each row (dict format).
- kwargs: Arguments passed as parameters with the query.
+ Constructs a new SQLBatchSource that streams results instead of paging them.
  """
  super().__init__(context, task)
+ # Remove any trailing semicolons to prevent SQL syntax errors
  self.query = query.strip().rstrip(";")
  self.record_transformer = record_transformer
  self.kwargs = kwargs
+ self.logger = logging.getLogger(__name__)

- def _fetch_page(self, limit: int, offset: int) -> Optional[List[Dict]]:
+ def get_batch(self, max_batch_size: int) -> Generator[BatchResults, None, None]:
  """
- Fetch a single batch of rows using LIMIT/OFFSET, with retry/backoff.
-
- Each page is executed in its own transaction. On transient
- disconnects or DB errors, it retries up to 3 times with exponential backoff.
-
- Args:
- limit: maximum number of rows to return.
- offset: number of rows to skip before starting this page.
+ Yield successive batches using a Server-Side Cursor (Streaming).

- Returns:
- A list of row dicts (after applying record_transformer, if any),
- or None if no rows are returned.
-
- Raises:
- Exception: re-raises the last caught error on final failure.
+ This avoids 'LIMIT/OFFSET' pagination, which causes performance degradation
+ on large tables. Instead, it holds a cursor open and fetches rows incrementally.
  """
- paged_sql = f"{self.query} LIMIT :limit OFFSET :offset"
- params = {**self.kwargs, "limit": limit, "offset": offset}
- max_retries = 5
- backoff = 2.0

- for attempt in range(1, max_retries + 1):
- try:
- with self.context.sql.engine.connect() as conn:
- with conn.begin():
- rows = conn.execute(text(paged_sql), params).mappings().all()
- result = [
- self.record_transformer(dict(r)) if self.record_transformer else dict(r)
- for r in rows
- ]
- return result if result else None
+ with self.context.sql.engine.connect() as conn:

- except (PsycopgOperationalError, SAOperationalError, DBAPIError) as err:
+ conn = conn.execution_options(stream_results=True)

- if attempt == max_retries:
- self.logger.error(
- f"Page fetch failed after {max_retries} attempts "
- f"(limit={limit}, offset={offset}): {err}"
+ try:
+ self.logger.info("Starting SQL Result Stream...")
+
+ result_proxy = conn.execute(text(self.query), self.kwargs)
+
+ chunk = []
+ count = 0
+
+ for row in result_proxy.mappings():
+ item = self.record_transformer(dict(row)) if self.record_transformer else dict(row)
+ chunk.append(item)
+ count += 1
+
+ # Yield when batch is full
+ if len(chunk) >= max_batch_size:
+ yield BatchResults(
+ chunk=chunk,
+ statistics={"sql_rows_read": len(chunk)},
+ batch_size=len(chunk),
+ )
+ chunk = [] # Clear memory
+
+ # Yield any remaining rows
+ if chunk:
+ yield BatchResults(
+ chunk=chunk,
+ statistics={"sql_rows_read": len(chunk)},
+ batch_size=len(chunk),
  )
- raise

- self.logger.warning(
- f"Transient DB error on page fetch {attempt}/{max_retries}: {err!r}, "
- f"retrying in {backoff:.1f}s"
- )
- time.sleep(backoff)
- backoff *= 2
+ self.logger.info(f"SQL Stream finished. Total rows read: {count}")

- return None
-
- def get_batch(self, max_batch_size: int) -> Generator[BatchResults, None, None]:
- """
- Yield successive batches until the query is exhausted.
-
- Calls _fetch_page() repeatedly, advancing the offset by the
- number of rows returned. Stops when no more rows are returned.
-
- Args:
- max_batch_size: upper limit on rows per batch.
-
- Yields:
- BatchResults for each non-empty page.
- """
- offset = 0
- while True:
- chunk = self._fetch_page(max_batch_size, offset)
- if not chunk:
- break
-
- yield BatchResults(
- chunk=chunk,
- statistics={"sql_rows_read": len(chunk)},
- batch_size=len(chunk),
- )
-
- offset += len(chunk)
+ except (PsycopgOperationalError, SAOperationalError, DBAPIError) as err:
+ self.logger.error(f"Stream failed: {err}")
+ raise
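
The new get_batch relies on SQLAlchemy's server-side cursor support. A minimal standalone sketch of that pattern, assuming SQLAlchemy 1.4+ with the psycopg2 driver; the URL and query are placeholders:

from sqlalchemy import create_engine, text

engine = create_engine("postgresql+psycopg2://user:password@host/db")   # placeholder URL

with engine.connect() as conn:
    conn = conn.execution_options(stream_results=True)       # server-side cursor
    result = conn.execute(text("SELECT * FROM big_table"))   # placeholder query
    for rows in result.mappings().partitions(5_000):         # fetch in chunks, bounded memory
        batch = [dict(r) for r in rows]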
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: neo4j-etl-lib
- Version: 0.3.1
+ Version: 0.3.2
  Summary: Building blocks for ETL pipelines.
  Keywords: etl,graph,database
  Author-email: Bert Radke <bert.radke@pm.me>
@@ -1,12 +1,12 @@
- etl_lib/__init__.py,sha256=x6coFV38ytJ_wPhR3c0UEzX65oTz2ouKwygkC_tyRLM,65
+ etl_lib/__init__.py,sha256=q7f7YqfTzmaIBlhdhwum8vxg-YfqXBBdyqS-LS5Bq9U,65
  etl_lib/cli/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  etl_lib/cli/run_tools.py,sha256=KIar-y22P4kKm-yoJjecYsPwqC7U76M71dEgFO5-ZBo,8561
  etl_lib/core/BatchProcessor.py,sha256=mRpdxZ6ZMKI8XsY3TPuy4dVcvRqLKCO-p63KeOhFyKE,3417
  etl_lib/core/ClosedLoopBatchProcessor.py,sha256=WzML1nldhZRbP8fhlD6utuK5SBYRl1cJgEobVDIdBP4,1626
- etl_lib/core/ETLContext.py,sha256=wmEnbs3n_80B6La9Py_-MHG8BN0FajE9MjGPej0A3To,8045
- etl_lib/core/ParallelBatchProcessor.py,sha256=jNo1Xv1Ts34UZIseoQLDZOhHOVeEr8dUibKUt0FJ4Hw,7318
+ etl_lib/core/ETLContext.py,sha256=7dVU3qc8tK-8vbrF-pofKVih8yey7tJUiUjKKjOS28o,7625
+ etl_lib/core/ParallelBatchProcessor.py,sha256=mR6N3C75NgYOVQvmioGmEkQo7RXQ0tDj13-IqH4TkeY,6067
  etl_lib/core/ProgressReporter.py,sha256=tkE-W6qlR25nU8nUoECcxZDnjnG8AtQH9s9s5WBh_-Q,9377
- etl_lib/core/SplittingBatchProcessor.py,sha256=OIRMUVFpUoZc0w__JJjUr7B9QC3sBlqQp41xghrQzC0,11616
+ etl_lib/core/SplittingBatchProcessor.py,sha256=ZxLYoY41D9f1wEmiGgqSO0q1iqdnMyU98e0B9iycWh8,14909
  etl_lib/core/Task.py,sha256=muQFY5qj2n-ZVV8F6vlHqo2lVSvB3wtGdIgkSXVpOFM,9365
  etl_lib/core/ValidationBatchProcessor.py,sha256=U1M2Qp9Ledt8qFiHAg8zMxE9lLRkBrr51NKs_Y8skK8,3400
  etl_lib/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -17,7 +17,7 @@ etl_lib/data_sink/SQLBatchSink.py,sha256=vyGrrxpdmCLUZMI2_W2ORej3FLGbwN9-b2GMYHd
  etl_lib/data_sink/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  etl_lib/data_source/CSVBatchSource.py,sha256=0q1XdPhAIKw1HcTpnp_F4WxRUzk-24Q8Qd-WeIo5OZ0,2779
  etl_lib/data_source/CypherBatchSource.py,sha256=06WuW11BqYjAXBZqL96Qr9MR8JrcjujDpxXe8cI-SYY,2238
- etl_lib/data_source/SQLBatchSource.py,sha256=O3ZA2GXvo5j_KGwOILzguYZMPY_FJkV5j8FIa3-d9oM,4067
+ etl_lib/data_source/SQLBatchSource.py,sha256=99NF0-H1tRqYY8yQBlfuF0ORMToysXpOQXrS62_KIeQ,3038
  etl_lib/data_source/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  etl_lib/task/CreateReportingConstraintsTask.py,sha256=nTcHLBIgXz_h2OQg-SHjQr68bhH974u0MwrtWPnVwng,762
  etl_lib/task/ExecuteCypherTask.py,sha256=thE8YTZzv1abxNhhDcb4p4ke6qmI6kWR4XQ-GrCBBBU,1284
@@ -30,7 +30,7 @@ etl_lib/task/data_loading/SQLLoad2Neo4jTask.py,sha256=HR3DcjOUkQN4SbCkgQYzljQCYh
  etl_lib/task/data_loading/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  etl_lib/test_utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  etl_lib/test_utils/utils.py,sha256=CgYOCXcUyndOdRAmGyPLoCIuEik0yzy6FLV2k16cpDM,5673
- neo4j_etl_lib-0.3.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- neo4j_etl_lib-0.3.1.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
- neo4j_etl_lib-0.3.1.dist-info/METADATA,sha256=Pm921qyxL36Ed_Ppp2cW3OFPxUGMv7IyRTmtba3n96o,2580
- neo4j_etl_lib-0.3.1.dist-info/RECORD,,
+ neo4j_etl_lib-0.3.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ neo4j_etl_lib-0.3.2.dist-info/WHEEL,sha256=G2gURzTEtmeR8nrdXUJfNiB3VYVxigPQ-bEQujpNiNs,82
+ neo4j_etl_lib-0.3.2.dist-info/METADATA,sha256=nLCy9Iu1V3_WTKoo551tN__hZEbcM8Aeg1YxVv2eU1M,2580
+ neo4j_etl_lib-0.3.2.dist-info/RECORD,,