neo4j-etl-lib 0.3.0__tar.gz → 0.3.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/PKG-INFO +8 -6
  2. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/pyproject.toml +10 -5
  3. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/__init__.py +1 -1
  4. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/core/ETLContext.py +2 -9
  5. neo4j_etl_lib-0.3.2/src/etl_lib/core/ParallelBatchProcessor.py +150 -0
  6. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/core/ProgressReporter.py +1 -1
  7. neo4j_etl_lib-0.3.2/src/etl_lib/core/SplittingBatchProcessor.py +391 -0
  8. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/data_source/CSVBatchSource.py +1 -1
  9. neo4j_etl_lib-0.3.2/src/etl_lib/data_source/SQLBatchSource.py +84 -0
  10. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/task/GDSTask.py +8 -5
  11. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/test_utils/utils.py +9 -5
  12. neo4j_etl_lib-0.3.0/src/etl_lib/core/ParallelBatchProcessor.py +0 -180
  13. neo4j_etl_lib-0.3.0/src/etl_lib/core/SplittingBatchProcessor.py +0 -268
  14. neo4j_etl_lib-0.3.0/src/etl_lib/data_source/SQLBatchSource.py +0 -114
  15. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/LICENSE +0 -0
  16. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/README.md +0 -0
  17. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/cli/__init__.py +0 -0
  18. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/cli/run_tools.py +0 -0
  19. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/core/BatchProcessor.py +0 -0
  20. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/core/ClosedLoopBatchProcessor.py +0 -0
  21. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/core/Task.py +0 -0
  22. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/core/ValidationBatchProcessor.py +0 -0
  23. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/core/__init__.py +0 -0
  24. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/core/utils.py +0 -0
  25. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/data_sink/CSVBatchSink.py +0 -0
  26. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/data_sink/CypherBatchSink.py +0 -0
  27. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/data_sink/SQLBatchSink.py +0 -0
  28. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/data_sink/__init__.py +0 -0
  29. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/data_source/CypherBatchSource.py +0 -0
  30. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/data_source/__init__.py +0 -0
  31. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/task/CreateReportingConstraintsTask.py +0 -0
  32. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/task/ExecuteCypherTask.py +0 -0
  33. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/task/__init__.py +0 -0
  34. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/task/data_loading/CSVLoad2Neo4jTask.py +0 -0
  35. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/task/data_loading/ParallelCSVLoad2Neo4jTask.py +0 -0
  36. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/task/data_loading/ParallelSQLLoad2Neo4jTask.py +0 -0
  37. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/task/data_loading/SQLLoad2Neo4jTask.py +0 -0
  38. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/task/data_loading/__init__.py +0 -0
  39. {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/test_utils/__init__.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: neo4j-etl-lib
- Version: 0.3.0
+ Version: 0.3.2
  Summary: Building blocks for ETL pipelines.
  Keywords: etl,graph,database
  Author-email: Bert Radke <bert.radke@pm.me>
@@ -14,11 +14,11 @@ Classifier: Programming Language :: Python :: 3
  Classifier: Topic :: Database
  Classifier: Development Status :: 4 - Beta
  License-File: LICENSE
- Requires-Dist: pydantic>=2.10.5; python_version >= '3.8'
- Requires-Dist: neo4j-rust-ext>=5.27.0; python_version >= '3.7'
- Requires-Dist: python-dotenv>=1.0.1; python_version >= '3.8'
- Requires-Dist: tabulate>=0.9.0; python_version >= '3.7'
- Requires-Dist: click>=8.1.8; python_version >= '3.7'
+ Requires-Dist: pydantic>=2.10.5; python_version >= '3.10'
+ Requires-Dist: neo4j-rust-ext>=5.27.0,<6; python_version >= '3.10'
+ Requires-Dist: python-dotenv>=1.0.1; python_version >= '3.10'
+ Requires-Dist: tabulate>=0.9.0; python_version >= '3.10'
+ Requires-Dist: click>=8.1.8; python_version >= '3.10'
  Requires-Dist: pydantic[email-validator]
  Requires-Dist: pytest>=8.3.0 ; extra == "dev" and ( python_version >= '3.8')
  Requires-Dist: testcontainers[neo4j]==4.9.0 ; extra == "dev" and ( python_version >= '3.9' and python_version < '4.0')
@@ -35,11 +35,13 @@ Requires-Dist: sphinx-autoapi ; extra == "dev"
  Requires-Dist: sqlalchemy ; extra == "dev"
  Requires-Dist: psycopg2-binary ; extra == "dev"
  Requires-Dist: graphdatascience>=1.13 ; extra == "gds" and ( python_version >= '3.9')
+ Requires-Dist: nox>=2024.0.0 ; extra == "nox"
  Requires-Dist: sqlalchemy ; extra == "sql"
  Project-URL: Documentation, https://neo-technology-field.github.io/python-etl-lib/index.html
  Project-URL: Home, https://github.com/neo-technology-field/python-etl-lib
  Provides-Extra: dev
  Provides-Extra: gds
+ Provides-Extra: nox
  Provides-Extra: sql

  # Neo4j ETL Toolbox
@@ -22,11 +22,11 @@ dynamic = ["version", "description"]
  keywords = ["etl", "graph", "database"]

  dependencies = [
- "pydantic>=2.10.5; python_version >= '3.8'",
- "neo4j-rust-ext>=5.27.0; python_version >= '3.7'",
- "python-dotenv>=1.0.1; python_version >= '3.8'",
- "tabulate>=0.9.0; python_version >= '3.7'",
- "click>=8.1.8; python_version >= '3.7'",
+ "pydantic>=2.10.5; python_version >= '3.10'",
+ "neo4j-rust-ext>=5.27.0,<6; python_version >= '3.10'",
+ "python-dotenv>=1.0.1; python_version >= '3.10'",
+ "tabulate>=0.9.0; python_version >= '3.10'",
+ "click>=8.1.8; python_version >= '3.10'",
  "pydantic[email_validator]"
  ]

@@ -41,6 +41,11 @@ dev = [
  gds = ["graphdatascience>=1.13; python_version >= '3.9'"]
  sql = ["sqlalchemy"]

+ # Local-only multy-version testing, install via `pip install ".[dev,nox]"`
+ nox = [
+     "nox>=2024.0.0"
+ ]
+
  [project.urls]
  Home = "https://github.com/neo-technology-field/python-etl-lib"
  Documentation = "https://neo-technology-field.github.io/python-etl-lib/index.html"
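
The noxfile that drives this new "nox" extra is not included in the diff. As a rough orientation only, a minimal local multi-version setup could look like the sketch below; the session name and Python version matrix are assumptions, not part of the package.

# noxfile.py -- hypothetical sketch, not shipped with neo4j-etl-lib.
import nox


@nox.session(python=["3.10", "3.11", "3.12"])  # assumed interpreter matrix
def tests(session):
    # Install the library plus its dev extra into the session's virtualenv.
    session.install(".[dev]")
    # pytest is declared in the "dev" extra according to this diff.
    session.run("pytest")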
@@ -1,4 +1,4 @@
  """
  Building blocks for ETL pipelines.
  """
- __version__ = "0.3.0"
+ __version__ = "0.3.2"
@@ -172,15 +172,8 @@ if sqlalchemy_available:
  database_url,
  pool_pre_ping=True,
  pool_size=pool_size,
- max_overflow=max_overflow,
- pool_recycle=1800, # recycle connections older than 30m
- connect_args={
- # turn on TCP keepalives on the client socket:
- "keepalives": 1,
- "keepalives_idle": 60, # after 60s of idle
- "keepalives_interval": 10, # probe every 10s
- "keepalives_count": 5, # give up after 5 failed probes
- })
+ max_overflow=max_overflow
+ )


  class ETLContext:
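
Deployments that relied on the removed recycle/keepalive settings can still apply them at the application level when building their own SQLAlchemy engine. A minimal sketch; the values mirror the lines removed above, while the URL and pool sizes are placeholders:

# Hypothetical sketch: re-applying the options that 0.3.2 no longer sets.
from sqlalchemy import create_engine

engine = create_engine(
    "postgresql+psycopg2://user:pass@host/db",  # placeholder URL
    pool_pre_ping=True,
    pool_recycle=1800,              # recycle connections older than 30m
    connect_args={
        "keepalives": 1,            # turn on TCP keepalives on the client socket
        "keepalives_idle": 60,      # after 60s of idle
        "keepalives_interval": 10,  # probe every 10s
        "keepalives_count": 5,      # give up after 5 failed probes
    },
)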
@@ -0,0 +1,150 @@
+ import queue
+ import threading
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from typing import Any, Callable, Generator, List
+
+ from etl_lib.core.BatchProcessor import BatchProcessor, BatchResults
+ from etl_lib.core.utils import merge_summery
+
+
+ class ParallelBatchResult(BatchResults):
+     """
+     Represents one *wave* produced by the splitter.
+
+     `chunk` is a list of bucket-batches. Each sub-list is processed by one worker instance.
+     """
+     pass
+
+
+ class ParallelBatchProcessor(BatchProcessor):
+     """
+     BatchProcessor that runs a worker over the bucket-batches of each ParallelBatchResult
+     in parallel threads, while prefetching the next ParallelBatchResult from its predecessor.
+
+     Note:
+         - The predecessor must emit `ParallelBatchResult` instances (waves).
+         - This processor collects the BatchResults from all workers for one wave and merges them
+           into one BatchResults.
+         - The returned BatchResults will not obey the max_batch_size from get_batch() because
+           it represents the full wave.
+
+     Args:
+         context: ETL context.
+         worker_factory: A zero-arg callable that returns a new BatchProcessor each time it's called.
+         task: optional Task for reporting.
+         predecessor: upstream BatchProcessor that must emit ParallelBatchResult See
+             :class:`~etl_lib.core.SplittingBatchProcessor.SplittingBatchProcessor`.
+         max_workers: number of parallel threads for bucket processing.
+         prefetch: number of waves to prefetch.
+
+     Behavior:
+         - For every wave, spins up `max_workers` threads.
+         - Each thread processes one bucket-batch using a fresh worker from `worker_factory()`.
+         - Collects and merges worker results in a fail-fast manner.
+     """
+
+     def __init__(
+             self,
+             context,
+             worker_factory: Callable[[], BatchProcessor],
+             task=None,
+             predecessor=None,
+             max_workers: int = 4,
+             prefetch: int = 4,
+     ):
+         super().__init__(context, task, predecessor)
+         self.worker_factory = worker_factory
+         self.max_workers = max_workers
+         self.prefetch = prefetch
+
+     def _process_wave(self, wave: ParallelBatchResult) -> BatchResults:
+         """
+         Process one wave: run one worker per bucket-batch and merge their BatchResults.
+
+         Statistics:
+             `wave.statistics` is used as the initial merged stats, then merged with each worker's stats.
+         """
+         merged_stats = dict(wave.statistics or {})
+         merged_chunk = []
+         total = 0
+
+         self.logger.debug(f"Processing wave with {len(wave.chunk)} buckets")
+
+         with ThreadPoolExecutor(max_workers=self.max_workers, thread_name_prefix="PBP_worker_") as pool:
+             futures = [pool.submit(self._process_bucket_batch, bucket_batch) for bucket_batch in wave.chunk]
+             try:
+                 for f in as_completed(futures):
+                     out = f.result()
+                     merged_stats = merge_summery(merged_stats, out.statistics or {})
+                     total += out.batch_size
+                     merged_chunk.extend(out.chunk if isinstance(out.chunk, list) else [out.chunk])
+             except Exception as e:
+                 for g in futures:
+                     g.cancel()
+                 pool.shutdown(cancel_futures=True)
+                 raise RuntimeError("bucket processing failed") from e
+
+         self.logger.debug(f"Finished wave with stats={merged_stats}")
+         return BatchResults(chunk=merged_chunk, statistics=merged_stats, batch_size=total)
+
+     def get_batch(self, max_batch_size: int) -> Generator[BatchResults, None, None]:
+         """
+         Pull waves from the predecessor (prefetching up to `prefetch` ahead), process each wave's
+         buckets in parallel, and yield one flattened BatchResults per wave.
+         """
+         wave_queue: queue.Queue[ParallelBatchResult | object] = queue.Queue(self.prefetch)
+         SENTINEL = object()
+         exc: BaseException | None = None
+
+         def producer():
+             nonlocal exc
+             try:
+                 for wave in self.predecessor.get_batch(max_batch_size):
+                     self.logger.debug(
+                         f"adding wave stats={wave.statistics} buckets={len(wave.chunk)} to queue size={wave_queue.qsize()}"
+                     )
+                     wave_queue.put(wave)
+             except BaseException as e:
+                 exc = e
+             finally:
+                 wave_queue.put(SENTINEL)
+
+         threading.Thread(target=producer, daemon=True, name="prefetcher").start()
+
+         while True:
+             wave = wave_queue.get()
+             if wave is SENTINEL:
+                 if exc is not None:
+                     self.logger.error("Upstream producer failed", exc_info=True)
+                     raise exc
+                 break
+             yield self._process_wave(wave)
+
+     class SingleBatchWrapper(BatchProcessor):
+         """
+         Simple BatchProcessor that returns exactly one batch (the bucket-batch passed in via init).
+         Used as predecessor for the per-bucket worker.
+         """
+
+         def __init__(self, context, batch: List[Any]):
+             super().__init__(context=context, predecessor=None)
+             self._batch = batch
+
+         def get_batch(self, max_size: int) -> Generator[BatchResults, None, None]:
+             yield BatchResults(
+                 chunk=self._batch,
+                 statistics={},
+                 batch_size=len(self._batch),
+             )
+
+     def _process_bucket_batch(self, bucket_batch):
+         """
+         Process one bucket-batch by running a fresh worker over it.
+         """
+         self.logger.debug(f"Processing batch w/ size {len(bucket_batch)}")
+         wrapper = self.SingleBatchWrapper(self.context, bucket_batch)
+         worker = self.worker_factory()
+         worker.predecessor = wrapper
+         result = next(worker.get_batch(len(bucket_batch)))
+         self.logger.debug(f"Finished bucket batch stats={result.statistics}")
+         return result
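
Together with the SplittingBatchProcessor added further below, this forms a source → splitter → parallel-worker chain. The sketch below shows one plausible wiring, using only the constructor parameters visible in this diff; `make_source` and `make_worker` are placeholders for project-specific BatchProcessor factories and not part of the library.

# Hypothetical wiring sketch for the classes added in 0.3.2.
from etl_lib.core.ParallelBatchProcessor import ParallelBatchProcessor
from etl_lib.core.SplittingBatchProcessor import SplittingBatchProcessor, dict_id_extractor


def build_pipeline(context, make_source, make_worker):
    # make_source(context) -> BatchProcessor emitting dict rows with "start"/"end" keys.
    # make_worker(context) -> fresh BatchProcessor (e.g. a sink) used once per bucket-batch.
    splitter = SplittingBatchProcessor(
        context=context,
        table_size=10,
        id_extractor=dict_id_extractor(table_size=10, start_key="start", end_key="end"),
        predecessor=make_source(context),
    )
    return ParallelBatchProcessor(
        context=context,
        worker_factory=lambda: make_worker(context),
        predecessor=splitter,
        max_workers=4,
        prefetch=2,
    )


# Driving the chain: each yielded BatchResults covers one wave of non-overlapping buckets.
# for result in build_pipeline(ctx, my_source, my_sink).get_batch(max_batch_size=5_000):
#     print(result.batch_size, result.statistics)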
@@ -45,7 +45,7 @@ class ProgressReporter:
  The task that was provided.
  """
  task.start_time = datetime.now()
- self.logger.info(f"{'\t' * task.depth}starting {task.task_name()}")
+ self.logger.info(f"{' ' * (4 * task.depth)}starting {task.task_name()}")
  return task

  def finished_task(self, task: Task, result: TaskReturn) -> Task:
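
A plausible motivation for swapping the tab for spaces: until Python 3.12, a backslash such as `\t` is not allowed inside the expression part of an f-string, and this release lowers the supported floor to Python 3.10. A minimal illustration (the task name is made up):

# Before Python 3.12 (PEP 701), backslashes in f-string expressions are a SyntaxError:
#   f"{'\t' * depth}starting"                  # fails to compile on 3.10/3.11
# The space-based form compiles on all supported versions and keeps the visual nesting:
depth = 2
print(f"{' ' * (4 * depth)}starting load_nodes")  # "        starting load_nodes"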
@@ -0,0 +1,391 @@
+ import logging
+ from typing import Any, Callable, Dict, Generator, List, Tuple
+
+ from etl_lib.core.BatchProcessor import BatchProcessor, BatchResults
+ from etl_lib.core.utils import merge_summery
+ from tabulate import tabulate
+
+
+ def tuple_id_extractor(table_size: int = 10) -> Callable[[Tuple[str | int, str | int]], Tuple[int, int]]:
+     """
+     Create an ID extractor function for tuple items, using the last decimal digit of each element.
+     The output is a `(row, col)` tuple within a `table_size x table_size` grid (default 10x10).
+
+     Args:
+         table_size: The dimension of the grid (number of rows/cols). Defaults to 10.
+
+     Returns:
+         A callable that maps a tuple `(a, b)` to a tuple `(row, col)` using the last digit of `a` and `b`.
+
+     Notes:
+         This extractor does not validate the returned indices. Range validation is performed by
+         :class:`~etl_lib.core.SplittingBatchProcessor.SplittingBatchProcessor`.
+     """
+
+     def extractor(item: Tuple[Any, Any]) -> Tuple[int, int]:
+         a, b = item
+         try:
+             row = int(str(a)[-1])
+             col = int(str(b)[-1])
+         except Exception as e:
+             raise ValueError(f"Failed to extract ID from item {item}: {e}")
+         return row, col
+
+     extractor.table_size = table_size
+     return extractor
+
+
+ def dict_id_extractor(
+         table_size: int = 10,
+         start_key: str = "start",
+         end_key: str = "end",
+ ) -> Callable[[Dict[str, Any]], Tuple[int, int]]:
+     """
+     Build an ID extractor for dict rows. The extractor reads two fields (configurable via
+     `start_key` and `end_key`) and returns (row, col) based on the last decimal digit of each.
+
+     Args:
+         table_size: Informational hint carried on the extractor.
+         start_key: Field name for the start node identifier.
+         end_key: Field name for the end node identifier.
+
+     Returns:
+         Callable that maps {start_key, end_key} → (row, col).
+     """
+
+     def extractor(item: Dict[str, Any]) -> Tuple[int, int]:
+         missing = [k for k in (start_key, end_key) if k not in item]
+         if missing:
+             raise KeyError(f"Item missing required keys: {', '.join(missing)}")
+         try:
+             row = int(str(item[start_key])[-1])
+             col = int(str(item[end_key])[-1])
+         except Exception as e:
+             raise ValueError(f"Failed to extract ID from item {item}: {e}")
+         return row, col
+
+     extractor.table_size = table_size
+     return extractor
+
+
+ def canonical_integer_id_extractor(
+         table_size: int = 10,
+         start_key: str = "start",
+         end_key: str = "end",
+ ) -> Callable[[Dict[str, Any]], Tuple[int, int]]:
+     """
+     ID extractor for integer IDs with canonical folding.
+
+     - Uses Knuth's multiplicative hashing to scatter sequential integers across the grid.
+     - Canonical folding ensures (A,B) and (B,A) land in the same bucket by folding the lower
+       triangle into the upper triangle (row <= col).
+
+     The extractor marks itself as mono-partite by setting `extractor.monopartite = True`.
+     """
+     MAGIC = 2654435761
+
+     def extractor(item: Dict[str, Any]) -> Tuple[int, int]:
+         try:
+             s_val = item[start_key]
+             e_val = item[end_key]
+
+             s_hash = (s_val * MAGIC) & 0xffffffff
+             e_hash = (e_val * MAGIC) & 0xffffffff
+
+             row = s_hash % table_size
+             col = e_hash % table_size
+
+             if row > col:
+                 row, col = col, row
+
+             return row, col
+         except KeyError:
+             raise KeyError(f"Item missing keys: {start_key} or {end_key}")
+         except Exception as e:
+             raise ValueError(f"Failed to extract ID: {e}")
+
+     extractor.table_size = table_size
+     extractor.monopartite = True
+     return extractor
+
+
+ class SplittingBatchProcessor(BatchProcessor):
+     """
+     Streaming wave scheduler for mix-and-batch style loading.
+
+     Incoming rows are assigned to buckets via an `id_extractor(item) -> (row, col)` inside a
+     `table_size x table_size` grid. The processor emits waves; each wave contains bucket-batches
+     that are safe to process concurrently under the configured non-overlap rule.
+
+     Non-overlap rules
+     -----------------
+     - Bi-partite (default): within a wave, no two buckets share a row index and no two buckets share a col index.
+     - Mono-partite: within a wave, no node index is used more than once (row/col indices are the same domain).
+       Enable by setting `id_extractor.monopartite = True` (as done by `canonical_integer_id_extractor`).
+
+     Emission strategy
+     -----------------
+     - During streaming: emit a wave when at least one bucket is full (>= max_batch_size).
+       The wave is then filled with additional non-overlapping buckets that are near-full to
+       keep parallelism high without producing tiny batches.
+     - If a bucket backlog grows beyond a burst threshold, emit a burst wave to bound memory.
+     - After source exhaustion: flush leftovers in capped waves (max_batch_size per bucket).
+
+     Statistics policy
+     -----------------
+     - Every emission except the last carries {}.
+     - The last emission carries the accumulated upstream statistics (unfiltered).
+     """
+
+     def __init__(
+             self,
+             context,
+             table_size: int,
+             id_extractor: Callable[[Any], Tuple[int, int]],
+             task=None,
+             predecessor=None,
+             near_full_ratio: float = 0.85,
+             burst_multiplier: int = 25,
+     ):
+         super().__init__(context, task, predecessor)
+
+         if hasattr(id_extractor, "table_size"):
+             expected_size = id_extractor.table_size
+             if table_size is None:
+                 table_size = expected_size
+             elif table_size != expected_size:
+                 raise ValueError(
+                     f"Mismatch between provided table_size ({table_size}) and id_extractor table_size ({expected_size})."
+                 )
+         elif table_size is None:
+             raise ValueError("table_size must be specified if id_extractor has no defined table_size")
+
+         if not (0 < near_full_ratio <= 1.0):
+             raise ValueError(f"near_full_ratio must be in (0, 1], got {near_full_ratio}")
+         if burst_multiplier < 1:
+             raise ValueError(f"burst_multiplier must be >= 1, got {burst_multiplier}")
+
+         self.table_size = table_size
+         self._id_extractor = id_extractor
+         self._monopartite = bool(getattr(id_extractor, "monopartite", False))
+
+         self.near_full_ratio = float(near_full_ratio)
+         self.burst_multiplier = int(burst_multiplier)
+
+         self.buffer: Dict[int, Dict[int, List[Any]]] = {
+             r: {c: [] for c in range(self.table_size)}
+             for r in range(self.table_size)
+         }
+         self.logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")
+
+     def _bucket_claims(self, row: int, col: int) -> Tuple[Any, ...]:
+         """
+         Return the resource claims a bucket consumes within a wave.
+
+         - Bi-partite: claims (row-slot, col-slot)
+         - Mono-partite: claims node indices touched by the bucket
+         """
+         if self._monopartite:
+             return (row,) if row == col else (row, col)
+         return ("row", row), ("col", col)
+
+     def _all_bucket_sizes(self) -> List[Tuple[int, int, int]]:
+         """
+         Return all non-empty buckets as (size, row, col).
+         """
+         out: List[Tuple[int, int, int]] = []
+         for r in range(self.table_size):
+             for c in range(self.table_size):
+                 n = len(self.buffer[r][c])
+                 if n > 0:
+                     out.append((n, r, c))
+         return out
+
+     def _select_wave(self, *, min_bucket_len: int, seed: List[Tuple[int, int]] | None = None) -> List[Tuple[int, int]]:
+         """
+         Greedy wave scheduler: pick a non-overlapping set of buckets with len >= min_bucket_len.
+
+         If `seed` is provided, it is taken as fixed and the wave is extended greedily.
+         """
+         candidates: List[Tuple[int, int, int]] = []
+         for r in range(self.table_size):
+             for c in range(self.table_size):
+                 n = len(self.buffer[r][c])
+                 if n >= min_bucket_len:
+                     candidates.append((n, r, c))
+
+         if not candidates and not seed:
+             return []
+
+         candidates.sort(key=lambda x: (-x[0], x[1], x[2]))
+
+         used: set[Any] = set()
+         wave: List[Tuple[int, int]] = []
+
+         if seed:
+             for r, c in seed:
+                 claims = self._bucket_claims(r, c)
+                 used.update(claims)
+                 wave.append((r, c))
+
+         for _, r, c in candidates:
+             if (r, c) in wave:
+                 continue
+             claims = self._bucket_claims(r, c)
+             if any(claim in used for claim in claims):
+                 continue
+             wave.append((r, c))
+             used.update(claims)
+             if len(wave) >= self.table_size:
+                 break
+
+         return wave
+
+     def _find_hottest_bucket(self, *, threshold: int) -> Tuple[int, int, int] | None:
+         """
+         Find the single hottest bucket (largest backlog) whose size is >= threshold.
+         Returns (row, col, size) or None.
+         """
+         best: Tuple[int, int, int] | None = None
+         for r in range(self.table_size):
+             for c in range(self.table_size):
+                 n = len(self.buffer[r][c])
+                 if n < threshold:
+                     continue
+                 if best is None or n > best[2]:
+                     best = (r, c, n)
+         return best
+
+     def _flush_wave(
+             self,
+             wave: List[Tuple[int, int]],
+             max_batch_size: int,
+             statistics: Dict[str, Any] | None = None,
+     ) -> BatchResults:
+         """
+         Extract up to `max_batch_size` items from each bucket in `wave`, remove them from the buffer,
+         and return a BatchResults whose chunk is a list of per-bucket lists (aligned with `wave`).
+         """
+         self._log_buffer_matrix(wave=wave)
+
+         bucket_batches: List[List[Any]] = []
+         for r, c in wave:
+             q = self.buffer[r][c]
+             take = min(max_batch_size, len(q))
+             bucket_batches.append(q[:take])
+             self.buffer[r][c] = q[take:]
+
+         emitted = sum(len(b) for b in bucket_batches)
+
+         return BatchResults(
+             chunk=bucket_batches,
+             statistics=statistics or {},
+             batch_size=emitted,
+         )
+
+     def _log_buffer_matrix(self, *, wave: List[Tuple[int, int]]) -> None:
+         """
+         Dumps a compact 2D matrix of per-bucket sizes (len of each buffer) when DEBUG is enabled.
+         """
+         if not self.logger.isEnabledFor(logging.DEBUG):
+             return
+
+         counts = [
+             [len(self.buffer[r][c]) for c in range(self.table_size)]
+             for r in range(self.table_size)
+         ]
+         marks = set(wave)
+
+         pad = max(2, len(str(self.table_size - 1)))
+         col_headers = [f"c{c:0{pad}d}" for c in range(self.table_size)]
+
+         rows = []
+         for r in range(self.table_size):
+             row_label = f"r{r:0{pad}d}"
+             row_vals = [f"[{v}]" if (r, c) in marks else f"{v}" for c, v in enumerate(counts[r])]
+             rows.append([row_label, *row_vals])
+
+         table = tabulate(
+             rows,
+             headers=["", *col_headers],
+             tablefmt="psql",
+             stralign="right",
+             disable_numparse=True,
+         )
+         self.logger.debug("buffer matrix:\n%s", table)
+
+     def get_batch(self, max_batch_size: int) -> Generator[BatchResults, None, None]:
+         """
+         Consume upstream batches, bucket incoming rows, and emit waves of non-overlapping buckets.
+
+         Streaming behavior:
+             - If at least one bucket is full (>= max_batch_size), emit a wave seeded with full buckets
+               and extended with near-full buckets to keep parallelism high.
+             - If a bucket exceeds a burst threshold, emit a burst wave (seeded by the hottest bucket)
+               and extended with near-full buckets.
+
+         End-of-stream behavior:
+             - Flush leftovers in capped waves (max_batch_size per bucket).
+
+         Statistics policy:
+             * Every emission except the last carries {}.
+             * The last emission carries the accumulated upstream stats (unfiltered).
+         """
+         if self.predecessor is None:
+             return
+
+         accum_stats: Dict[str, Any] = {}
+         pending: BatchResults | None = None
+
+         near_full_threshold = max(1, int(max_batch_size * self.near_full_ratio))
+         burst_threshold = self.burst_multiplier * max_batch_size
+
+         for upstream in self.predecessor.get_batch(max_batch_size):
+             if upstream.statistics:
+                 accum_stats = merge_summery(accum_stats, upstream.statistics)
+
+             for item in upstream.chunk:
+                 r, c = self._id_extractor(item)
+                 if self._monopartite and r > c:
+                     r, c = c, r
+                 if not (0 <= r < self.table_size and 0 <= c < self.table_size):
+                     raise ValueError(f"bucket id out of range: {(r, c)} for table_size={self.table_size}")
+                 self.buffer[r][c].append(item)
+
+             while True:
+                 full_seed = self._select_wave(min_bucket_len=max_batch_size)
+                 if not full_seed:
+                     break
+                 wave = self._select_wave(min_bucket_len=near_full_threshold, seed=full_seed)
+                 br = self._flush_wave(wave, max_batch_size, statistics={})
+                 if pending is not None:
+                     yield pending
+                 pending = br
+
+             while True:
+                 hot = self._find_hottest_bucket(threshold=burst_threshold)
+                 if hot is None:
+                     break
+                 hot_r, hot_c, hot_n = hot
+                 wave = self._select_wave(min_bucket_len=near_full_threshold, seed=[(hot_r, hot_c)])
+                 self.logger.debug(
+                     "burst flush: hottest_bucket=(%d,%d len=%d) threshold=%d near_full_threshold=%d wave_size=%d",
+                     hot_r, hot_c, hot_n, burst_threshold, near_full_threshold, len(wave)
+                 )
+                 br = self._flush_wave(wave, max_batch_size, statistics={})
+                 if pending is not None:
+                     yield pending
+                 pending = br
+
+         self.logger.debug("start flushing leftovers")
+         while True:
+             wave = self._select_wave(min_bucket_len=1)
+             if not wave:
+                 break
+             br = self._flush_wave(wave, max_batch_size, statistics={})
+             if pending is not None:
+                 yield pending
+             pending = br
+
+         if pending is not None:
+             yield BatchResults(chunk=pending.chunk, statistics=accum_stats, batch_size=pending.batch_size)
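
A small sketch of how the bundled extractors map rows to buckets, matching the behavior shown in the file above (the input values are chosen purely for illustration):

# Illustration of bucket assignment with the extractors added in 0.3.2.
from etl_lib.core.SplittingBatchProcessor import (
    canonical_integer_id_extractor,
    dict_id_extractor,
)

by_last_digit = dict_id_extractor(table_size=10)
print(by_last_digit({"start": 1234, "end": 987}))  # (4, 7): last decimal digit of each key

folded = canonical_integer_id_extractor(table_size=10)
a_b = folded({"start": 42, "end": 7})
b_a = folded({"start": 7, "end": 42})
print(a_b == b_a)          # True: canonical folding puts (A,B) and (B,A) in the same bucket
print(folded.monopartite)  # True: flags the mono-partite non-overlap rule for the splitter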
@@ -30,7 +30,7 @@ class CSVBatchSource(BatchProcessor):
  self.csv_file = csv_file
  self.kwargs = kwargs

- def get_batch(self, max_batch__size: int) -> Generator[BatchResults]:
+ def get_batch(self, max_batch__size: int) -> Generator[BatchResults, None, None]:
  for batch_size, chunks_ in self.__read_csv(self.csv_file, batch_size=max_batch__size, **self.kwargs):
  yield BatchResults(chunk=chunks_, statistics={"csv_lines_read": batch_size}, batch_size=batch_size)
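
For context on the signature change above: `typing.Generator` is parameterized as `Generator[YieldType, SendType, ReturnType]`, and before Python 3.13 all three arguments are required, so the one-argument form fails when the module is imported. A small, self-contained illustration:

# Generator[YieldType, SendType, ReturnType]; the extra None, None mirror the fix above.
from typing import Generator


def count_up(limit: int) -> Generator[int, None, None]:
    # Yields ints, accepts nothing via send(), returns nothing.
    for i in range(limit):
        yield i


print(list(count_up(3)))  # [0, 1, 2]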