neo4j-etl-lib 0.3.0__tar.gz → 0.3.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/PKG-INFO +8 -6
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/pyproject.toml +10 -5
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/__init__.py +1 -1
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/core/ETLContext.py +2 -9
- neo4j_etl_lib-0.3.2/src/etl_lib/core/ParallelBatchProcessor.py +150 -0
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/core/ProgressReporter.py +1 -1
- neo4j_etl_lib-0.3.2/src/etl_lib/core/SplittingBatchProcessor.py +391 -0
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/data_source/CSVBatchSource.py +1 -1
- neo4j_etl_lib-0.3.2/src/etl_lib/data_source/SQLBatchSource.py +84 -0
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/task/GDSTask.py +8 -5
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/test_utils/utils.py +9 -5
- neo4j_etl_lib-0.3.0/src/etl_lib/core/ParallelBatchProcessor.py +0 -180
- neo4j_etl_lib-0.3.0/src/etl_lib/core/SplittingBatchProcessor.py +0 -268
- neo4j_etl_lib-0.3.0/src/etl_lib/data_source/SQLBatchSource.py +0 -114
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/LICENSE +0 -0
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/README.md +0 -0
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/cli/__init__.py +0 -0
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/cli/run_tools.py +0 -0
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/core/BatchProcessor.py +0 -0
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/core/ClosedLoopBatchProcessor.py +0 -0
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/core/Task.py +0 -0
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/core/ValidationBatchProcessor.py +0 -0
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/core/__init__.py +0 -0
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/core/utils.py +0 -0
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/data_sink/CSVBatchSink.py +0 -0
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/data_sink/CypherBatchSink.py +0 -0
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/data_sink/SQLBatchSink.py +0 -0
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/data_sink/__init__.py +0 -0
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/data_source/CypherBatchSource.py +0 -0
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/data_source/__init__.py +0 -0
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/task/CreateReportingConstraintsTask.py +0 -0
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/task/ExecuteCypherTask.py +0 -0
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/task/__init__.py +0 -0
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/task/data_loading/CSVLoad2Neo4jTask.py +0 -0
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/task/data_loading/ParallelCSVLoad2Neo4jTask.py +0 -0
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/task/data_loading/ParallelSQLLoad2Neo4jTask.py +0 -0
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/task/data_loading/SQLLoad2Neo4jTask.py +0 -0
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/task/data_loading/__init__.py +0 -0
- {neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/test_utils/__init__.py +0 -0

{neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: neo4j-etl-lib
-Version: 0.3.0
+Version: 0.3.2
 Summary: Building blocks for ETL pipelines.
 Keywords: etl,graph,database
 Author-email: Bert Radke <bert.radke@pm.me>

@@ -14,11 +14,11 @@ Classifier: Programming Language :: Python :: 3
 Classifier: Topic :: Database
 Classifier: Development Status :: 4 - Beta
 License-File: LICENSE
-Requires-Dist: pydantic>=2.10.5; python_version >= '3.
-Requires-Dist: neo4j-rust-ext>=5.27.0; python_version >= '3.
-Requires-Dist: python-dotenv>=1.0.1; python_version >= '3.
-Requires-Dist: tabulate>=0.9.0; python_version >= '3.
-Requires-Dist: click>=8.1.8; python_version >= '3.
+Requires-Dist: pydantic>=2.10.5; python_version >= '3.10'
+Requires-Dist: neo4j-rust-ext>=5.27.0,<6; python_version >= '3.10'
+Requires-Dist: python-dotenv>=1.0.1; python_version >= '3.10'
+Requires-Dist: tabulate>=0.9.0; python_version >= '3.10'
+Requires-Dist: click>=8.1.8; python_version >= '3.10'
 Requires-Dist: pydantic[email-validator]
 Requires-Dist: pytest>=8.3.0 ; extra == "dev" and ( python_version >= '3.8')
 Requires-Dist: testcontainers[neo4j]==4.9.0 ; extra == "dev" and ( python_version >= '3.9' and python_version < '4.0')

@@ -35,11 +35,13 @@ Requires-Dist: sphinx-autoapi ; extra == "dev"
 Requires-Dist: sqlalchemy ; extra == "dev"
 Requires-Dist: psycopg2-binary ; extra == "dev"
 Requires-Dist: graphdatascience>=1.13 ; extra == "gds" and ( python_version >= '3.9')
+Requires-Dist: nox>=2024.0.0 ; extra == "nox"
 Requires-Dist: sqlalchemy ; extra == "sql"
 Project-URL: Documentation, https://neo-technology-field.github.io/python-etl-lib/index.html
 Project-URL: Home, https://github.com/neo-technology-field/python-etl-lib
 Provides-Extra: dev
 Provides-Extra: gds
+Provides-Extra: nox
 Provides-Extra: sql
 
 # Neo4j ETL Toolbox

{neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/pyproject.toml

@@ -22,11 +22,11 @@ dynamic = ["version", "description"]
 keywords = ["etl", "graph", "database"]
 
 dependencies = [
-    "pydantic>=2.10.5; python_version >= '3.
-    "neo4j-rust-ext>=5.27.0; python_version >= '3.
-    "python-dotenv>=1.0.1; python_version >= '3.
-    "tabulate>=0.9.0; python_version >= '3.
-    "click>=8.1.8; python_version >= '3.
+    "pydantic>=2.10.5; python_version >= '3.10'",
+    "neo4j-rust-ext>=5.27.0,<6; python_version >= '3.10'",
+    "python-dotenv>=1.0.1; python_version >= '3.10'",
+    "tabulate>=0.9.0; python_version >= '3.10'",
+    "click>=8.1.8; python_version >= '3.10'",
     "pydantic[email_validator]"
 ]
 

@@ -41,6 +41,11 @@ dev = [
 gds = ["graphdatascience>=1.13; python_version >= '3.9'"]
 sql = ["sqlalchemy"]
 
+# Local-only multy-version testing, install via `pip install ".[dev,nox]"`
+nox = [
+    "nox>=2024.0.0"
+]
+
 [project.urls]
 Home = "https://github.com/neo-technology-field/python-etl-lib"
 Documentation = "https://neo-technology-field.github.io/python-etl-lib/index.html"
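
The new `nox` extra only adds the `nox` test runner as an optional dependency; the corresponding noxfile is not part of this diff. As a rough sketch of the local multi-version testing flow described in the comment above (the interpreter list and session name are assumptions, not taken from the package):

    # noxfile.py -- illustrative only; not included in this package diff
    import nox

    @nox.session(python=["3.10", "3.11", "3.12"])  # assumed interpreter matrix
    def tests(session):
        # Install the project with its dev extra into a fresh venv per interpreter,
        # then run the test suite.
        session.install(".[dev]")
        session.run("pytest")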

{neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/core/ETLContext.py

@@ -172,15 +172,8 @@ if sqlalchemy_available:
             database_url,
             pool_pre_ping=True,
             pool_size=pool_size,
-            max_overflow=max_overflow
-
-            connect_args={
-                # turn on TCP keepalives on the client socket:
-                "keepalives": 1,
-                "keepalives_idle": 60,      # after 60s of idle
-                "keepalives_interval": 10,  # probe every 10s
-                "keepalives_count": 5,      # give up after 5 failed probes
-            })
+            max_overflow=max_overflow
+        )
 
 
 class ETLContext:
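
With the `connect_args` block removed, the library no longer forces TCP keepalives on the psycopg2 connection it creates. A minimal sketch, assuming an application still wants that behaviour, is to build its own SQLAlchemy engine and pass the same keepalive parameters directly (the URL below is a placeholder):

    from sqlalchemy import create_engine

    engine = create_engine(
        "postgresql+psycopg2://user:secret@db-host/etl",  # placeholder URL
        pool_pre_ping=True,
        connect_args={
            "keepalives": 1,            # turn on TCP keepalives on the client socket
            "keepalives_idle": 60,      # after 60s of idle
            "keepalives_interval": 10,  # probe every 10s
            "keepalives_count": 5,      # give up after 5 failed probes
        },
    )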

neo4j_etl_lib-0.3.2/src/etl_lib/core/ParallelBatchProcessor.py (new file)

@@ -0,0 +1,150 @@
+import queue
+import threading
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Any, Callable, Generator, List
+
+from etl_lib.core.BatchProcessor import BatchProcessor, BatchResults
+from etl_lib.core.utils import merge_summery
+
+
+class ParallelBatchResult(BatchResults):
+    """
+    Represents one *wave* produced by the splitter.
+
+    `chunk` is a list of bucket-batches. Each sub-list is processed by one worker instance.
+    """
+    pass
+
+
+class ParallelBatchProcessor(BatchProcessor):
+    """
+    BatchProcessor that runs a worker over the bucket-batches of each ParallelBatchResult
+    in parallel threads, while prefetching the next ParallelBatchResult from its predecessor.
+
+    Note:
+        - The predecessor must emit `ParallelBatchResult` instances (waves).
+        - This processor collects the BatchResults from all workers for one wave and merges them
+          into one BatchResults.
+        - The returned BatchResults will not obey the max_batch_size from get_batch() because
+          it represents the full wave.
+
+    Args:
+        context: ETL context.
+        worker_factory: A zero-arg callable that returns a new BatchProcessor each time it's called.
+        task: optional Task for reporting.
+        predecessor: upstream BatchProcessor that must emit ParallelBatchResult See
+            :class:`~etl_lib.core.SplittingBatchProcessor.SplittingBatchProcessor`.
+        max_workers: number of parallel threads for bucket processing.
+        prefetch: number of waves to prefetch.
+
+    Behavior:
+        - For every wave, spins up `max_workers` threads.
+        - Each thread processes one bucket-batch using a fresh worker from `worker_factory()`.
+        - Collects and merges worker results in a fail-fast manner.
+    """
+
+    def __init__(
+            self,
+            context,
+            worker_factory: Callable[[], BatchProcessor],
+            task=None,
+            predecessor=None,
+            max_workers: int = 4,
+            prefetch: int = 4,
+    ):
+        super().__init__(context, task, predecessor)
+        self.worker_factory = worker_factory
+        self.max_workers = max_workers
+        self.prefetch = prefetch
+
+    def _process_wave(self, wave: ParallelBatchResult) -> BatchResults:
+        """
+        Process one wave: run one worker per bucket-batch and merge their BatchResults.
+
+        Statistics:
+            `wave.statistics` is used as the initial merged stats, then merged with each worker's stats.
+        """
+        merged_stats = dict(wave.statistics or {})
+        merged_chunk = []
+        total = 0
+
+        self.logger.debug(f"Processing wave with {len(wave.chunk)} buckets")
+
+        with ThreadPoolExecutor(max_workers=self.max_workers, thread_name_prefix="PBP_worker_") as pool:
+            futures = [pool.submit(self._process_bucket_batch, bucket_batch) for bucket_batch in wave.chunk]
+            try:
+                for f in as_completed(futures):
+                    out = f.result()
+                    merged_stats = merge_summery(merged_stats, out.statistics or {})
+                    total += out.batch_size
+                    merged_chunk.extend(out.chunk if isinstance(out.chunk, list) else [out.chunk])
+            except Exception as e:
+                for g in futures:
+                    g.cancel()
+                pool.shutdown(cancel_futures=True)
+                raise RuntimeError("bucket processing failed") from e
+
+        self.logger.debug(f"Finished wave with stats={merged_stats}")
+        return BatchResults(chunk=merged_chunk, statistics=merged_stats, batch_size=total)
+
+    def get_batch(self, max_batch_size: int) -> Generator[BatchResults, None, None]:
+        """
+        Pull waves from the predecessor (prefetching up to `prefetch` ahead), process each wave's
+        buckets in parallel, and yield one flattened BatchResults per wave.
+        """
+        wave_queue: queue.Queue[ParallelBatchResult | object] = queue.Queue(self.prefetch)
+        SENTINEL = object()
+        exc: BaseException | None = None
+
+        def producer():
+            nonlocal exc
+            try:
+                for wave in self.predecessor.get_batch(max_batch_size):
+                    self.logger.debug(
+                        f"adding wave stats={wave.statistics} buckets={len(wave.chunk)} to queue size={wave_queue.qsize()}"
+                    )
+                    wave_queue.put(wave)
+            except BaseException as e:
+                exc = e
+            finally:
+                wave_queue.put(SENTINEL)
+
+        threading.Thread(target=producer, daemon=True, name="prefetcher").start()
+
+        while True:
+            wave = wave_queue.get()
+            if wave is SENTINEL:
+                if exc is not None:
+                    self.logger.error("Upstream producer failed", exc_info=True)
+                    raise exc
+                break
+            yield self._process_wave(wave)
+
+    class SingleBatchWrapper(BatchProcessor):
+        """
+        Simple BatchProcessor that returns exactly one batch (the bucket-batch passed in via init).
+        Used as predecessor for the per-bucket worker.
+        """
+
+        def __init__(self, context, batch: List[Any]):
+            super().__init__(context=context, predecessor=None)
+            self._batch = batch
+
+        def get_batch(self, max_size: int) -> Generator[BatchResults, None, None]:
+            yield BatchResults(
+                chunk=self._batch,
+                statistics={},
+                batch_size=len(self._batch),
+            )
+
+    def _process_bucket_batch(self, bucket_batch):
+        """
+        Process one bucket-batch by running a fresh worker over it.
+        """
+        self.logger.debug(f"Processing batch w/ size {len(bucket_batch)}")
+        wrapper = self.SingleBatchWrapper(self.context, bucket_batch)
+        worker = self.worker_factory()
+        worker.predecessor = wrapper
+        result = next(worker.get_batch(len(bucket_batch)))
+        self.logger.debug(f"Finished bucket batch stats={result.statistics}")
+        return result
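
A minimal wiring sketch for the two new processors, based only on the signatures shown in this diff; `context` and the upstream `source` processor are assumed to exist, and `CountingWorker` is a stand-in written for illustration, not a class from the library:

    from etl_lib.core.BatchProcessor import BatchProcessor, BatchResults
    from etl_lib.core.ParallelBatchProcessor import ParallelBatchProcessor
    from etl_lib.core.SplittingBatchProcessor import SplittingBatchProcessor, dict_id_extractor

    class CountingWorker(BatchProcessor):
        """Stand-in worker: consumes the single bucket-batch handed to it and reports a count."""

        def get_batch(self, max_batch_size):
            for batch in self.predecessor.get_batch(max_batch_size):
                yield BatchResults(chunk=batch.chunk,
                                   statistics={"rows_seen": batch.batch_size},
                                   batch_size=batch.batch_size)

    extractor = dict_id_extractor(table_size=10, start_key="start", end_key="end")
    splitter = SplittingBatchProcessor(context, table_size=10, id_extractor=extractor,
                                       predecessor=source)
    parallel = ParallelBatchProcessor(context, worker_factory=lambda: CountingWorker(context),
                                      predecessor=splitter, max_workers=4, prefetch=2)

    for wave in parallel.get_batch(10_000):
        print(wave.batch_size, wave.statistics)

Each yielded result is one merged wave, so the printed statistics are the union of all bucket workers that ran in that wave.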

{neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/core/ProgressReporter.py

@@ -45,7 +45,7 @@ class ProgressReporter:
             The task that was provided.
         """
         task.start_time = datetime.now()
-        self.logger.info(f"{'
+        self.logger.info(f"{' ' * (4 * task.depth)}starting {task.task_name()}")
         return task
 
     def finished_task(self, task: Task, result: TaskReturn) -> Task:

neo4j_etl_lib-0.3.2/src/etl_lib/core/SplittingBatchProcessor.py (new file)

@@ -0,0 +1,391 @@
+import logging
+from typing import Any, Callable, Dict, Generator, List, Tuple
+
+from etl_lib.core.BatchProcessor import BatchProcessor, BatchResults
+from etl_lib.core.utils import merge_summery
+from tabulate import tabulate
+
+
+def tuple_id_extractor(table_size: int = 10) -> Callable[[Tuple[str | int, str | int]], Tuple[int, int]]:
+    """
+    Create an ID extractor function for tuple items, using the last decimal digit of each element.
+    The output is a `(row, col)` tuple within a `table_size x table_size` grid (default 10x10).
+
+    Args:
+        table_size: The dimension of the grid (number of rows/cols). Defaults to 10.
+
+    Returns:
+        A callable that maps a tuple `(a, b)` to a tuple `(row, col)` using the last digit of `a` and `b`.
+
+    Notes:
+        This extractor does not validate the returned indices. Range validation is performed by
+        :class:`~etl_lib.core.SplittingBatchProcessor.SplittingBatchProcessor`.
+    """
+
+    def extractor(item: Tuple[Any, Any]) -> Tuple[int, int]:
+        a, b = item
+        try:
+            row = int(str(a)[-1])
+            col = int(str(b)[-1])
+        except Exception as e:
+            raise ValueError(f"Failed to extract ID from item {item}: {e}")
+        return row, col
+
+    extractor.table_size = table_size
+    return extractor
+
+
+def dict_id_extractor(
+        table_size: int = 10,
+        start_key: str = "start",
+        end_key: str = "end",
+) -> Callable[[Dict[str, Any]], Tuple[int, int]]:
+    """
+    Build an ID extractor for dict rows. The extractor reads two fields (configurable via
+    `start_key` and `end_key`) and returns (row, col) based on the last decimal digit of each.
+
+    Args:
+        table_size: Informational hint carried on the extractor.
+        start_key: Field name for the start node identifier.
+        end_key: Field name for the end node identifier.
+
+    Returns:
+        Callable that maps {start_key, end_key} → (row, col).
+    """
+
+    def extractor(item: Dict[str, Any]) -> Tuple[int, int]:
+        missing = [k for k in (start_key, end_key) if k not in item]
+        if missing:
+            raise KeyError(f"Item missing required keys: {', '.join(missing)}")
+        try:
+            row = int(str(item[start_key])[-1])
+            col = int(str(item[end_key])[-1])
+        except Exception as e:
+            raise ValueError(f"Failed to extract ID from item {item}: {e}")
+        return row, col
+
+    extractor.table_size = table_size
+    return extractor
+
+
+def canonical_integer_id_extractor(
+        table_size: int = 10,
+        start_key: str = "start",
+        end_key: str = "end",
+) -> Callable[[Dict[str, Any]], Tuple[int, int]]:
+    """
+    ID extractor for integer IDs with canonical folding.
+
+    - Uses Knuth's multiplicative hashing to scatter sequential integers across the grid.
+    - Canonical folding ensures (A,B) and (B,A) land in the same bucket by folding the lower
+      triangle into the upper triangle (row <= col).
+
+    The extractor marks itself as mono-partite by setting `extractor.monopartite = True`.
+    """
+    MAGIC = 2654435761
+
+    def extractor(item: Dict[str, Any]) -> Tuple[int, int]:
+        try:
+            s_val = item[start_key]
+            e_val = item[end_key]
+
+            s_hash = (s_val * MAGIC) & 0xffffffff
+            e_hash = (e_val * MAGIC) & 0xffffffff
+
+            row = s_hash % table_size
+            col = e_hash % table_size
+
+            if row > col:
+                row, col = col, row
+
+            return row, col
+        except KeyError:
+            raise KeyError(f"Item missing keys: {start_key} or {end_key}")
+        except Exception as e:
+            raise ValueError(f"Failed to extract ID: {e}")
+
+    extractor.table_size = table_size
+    extractor.monopartite = True
+    return extractor
+
+
+class SplittingBatchProcessor(BatchProcessor):
+    """
+    Streaming wave scheduler for mix-and-batch style loading.
+
+    Incoming rows are assigned to buckets via an `id_extractor(item) -> (row, col)` inside a
+    `table_size x table_size` grid. The processor emits waves; each wave contains bucket-batches
+    that are safe to process concurrently under the configured non-overlap rule.
+
+    Non-overlap rules
+    -----------------
+    - Bi-partite (default): within a wave, no two buckets share a row index and no two buckets share a col index.
+    - Mono-partite: within a wave, no node index is used more than once (row/col indices are the same domain).
+      Enable by setting `id_extractor.monopartite = True` (as done by `canonical_integer_id_extractor`).
+
+    Emission strategy
+    -----------------
+    - During streaming: emit a wave when at least one bucket is full (>= max_batch_size).
+      The wave is then filled with additional non-overlapping buckets that are near-full to
+      keep parallelism high without producing tiny batches.
+    - If a bucket backlog grows beyond a burst threshold, emit a burst wave to bound memory.
+    - After source exhaustion: flush leftovers in capped waves (max_batch_size per bucket).
+
+    Statistics policy
+    -----------------
+    - Every emission except the last carries {}.
+    - The last emission carries the accumulated upstream statistics (unfiltered).
+    """
+
+    def __init__(
+            self,
+            context,
+            table_size: int,
+            id_extractor: Callable[[Any], Tuple[int, int]],
+            task=None,
+            predecessor=None,
+            near_full_ratio: float = 0.85,
+            burst_multiplier: int = 25,
+    ):
+        super().__init__(context, task, predecessor)
+
+        if hasattr(id_extractor, "table_size"):
+            expected_size = id_extractor.table_size
+            if table_size is None:
+                table_size = expected_size
+            elif table_size != expected_size:
+                raise ValueError(
+                    f"Mismatch between provided table_size ({table_size}) and id_extractor table_size ({expected_size})."
+                )
+        elif table_size is None:
+            raise ValueError("table_size must be specified if id_extractor has no defined table_size")
+
+        if not (0 < near_full_ratio <= 1.0):
+            raise ValueError(f"near_full_ratio must be in (0, 1], got {near_full_ratio}")
+        if burst_multiplier < 1:
+            raise ValueError(f"burst_multiplier must be >= 1, got {burst_multiplier}")
+
+        self.table_size = table_size
+        self._id_extractor = id_extractor
+        self._monopartite = bool(getattr(id_extractor, "monopartite", False))
+
+        self.near_full_ratio = float(near_full_ratio)
+        self.burst_multiplier = int(burst_multiplier)
+
+        self.buffer: Dict[int, Dict[int, List[Any]]] = {
+            r: {c: [] for c in range(self.table_size)}
+            for r in range(self.table_size)
+        }
+        self.logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")
+
+    def _bucket_claims(self, row: int, col: int) -> Tuple[Any, ...]:
+        """
+        Return the resource claims a bucket consumes within a wave.
+
+        - Bi-partite: claims (row-slot, col-slot)
+        - Mono-partite: claims node indices touched by the bucket
+        """
+        if self._monopartite:
+            return (row,) if row == col else (row, col)
+        return ("row", row), ("col", col)
+
+    def _all_bucket_sizes(self) -> List[Tuple[int, int, int]]:
+        """
+        Return all non-empty buckets as (size, row, col).
+        """
+        out: List[Tuple[int, int, int]] = []
+        for r in range(self.table_size):
+            for c in range(self.table_size):
+                n = len(self.buffer[r][c])
+                if n > 0:
+                    out.append((n, r, c))
+        return out
+
+    def _select_wave(self, *, min_bucket_len: int, seed: List[Tuple[int, int]] | None = None) -> List[Tuple[int, int]]:
+        """
+        Greedy wave scheduler: pick a non-overlapping set of buckets with len >= min_bucket_len.
+
+        If `seed` is provided, it is taken as fixed and the wave is extended greedily.
+        """
+        candidates: List[Tuple[int, int, int]] = []
+        for r in range(self.table_size):
+            for c in range(self.table_size):
+                n = len(self.buffer[r][c])
+                if n >= min_bucket_len:
+                    candidates.append((n, r, c))
+
+        if not candidates and not seed:
+            return []
+
+        candidates.sort(key=lambda x: (-x[0], x[1], x[2]))
+
+        used: set[Any] = set()
+        wave: List[Tuple[int, int]] = []
+
+        if seed:
+            for r, c in seed:
+                claims = self._bucket_claims(r, c)
+                used.update(claims)
+                wave.append((r, c))
+
+        for _, r, c in candidates:
+            if (r, c) in wave:
+                continue
+            claims = self._bucket_claims(r, c)
+            if any(claim in used for claim in claims):
+                continue
+            wave.append((r, c))
+            used.update(claims)
+            if len(wave) >= self.table_size:
+                break
+
+        return wave
+
+    def _find_hottest_bucket(self, *, threshold: int) -> Tuple[int, int, int] | None:
+        """
+        Find the single hottest bucket (largest backlog) whose size is >= threshold.
+        Returns (row, col, size) or None.
+        """
+        best: Tuple[int, int, int] | None = None
+        for r in range(self.table_size):
+            for c in range(self.table_size):
+                n = len(self.buffer[r][c])
+                if n < threshold:
+                    continue
+                if best is None or n > best[2]:
+                    best = (r, c, n)
+        return best
+
+    def _flush_wave(
+            self,
+            wave: List[Tuple[int, int]],
+            max_batch_size: int,
+            statistics: Dict[str, Any] | None = None,
+    ) -> BatchResults:
+        """
+        Extract up to `max_batch_size` items from each bucket in `wave`, remove them from the buffer,
+        and return a BatchResults whose chunk is a list of per-bucket lists (aligned with `wave`).
+        """
+        self._log_buffer_matrix(wave=wave)
+
+        bucket_batches: List[List[Any]] = []
+        for r, c in wave:
+            q = self.buffer[r][c]
+            take = min(max_batch_size, len(q))
+            bucket_batches.append(q[:take])
+            self.buffer[r][c] = q[take:]
+
+        emitted = sum(len(b) for b in bucket_batches)
+
+        return BatchResults(
+            chunk=bucket_batches,
+            statistics=statistics or {},
+            batch_size=emitted,
+        )
+
+    def _log_buffer_matrix(self, *, wave: List[Tuple[int, int]]) -> None:
+        """
+        Dumps a compact 2D matrix of per-bucket sizes (len of each buffer) when DEBUG is enabled.
+        """
+        if not self.logger.isEnabledFor(logging.DEBUG):
+            return
+
+        counts = [
+            [len(self.buffer[r][c]) for c in range(self.table_size)]
+            for r in range(self.table_size)
+        ]
+        marks = set(wave)
+
+        pad = max(2, len(str(self.table_size - 1)))
+        col_headers = [f"c{c:0{pad}d}" for c in range(self.table_size)]
+
+        rows = []
+        for r in range(self.table_size):
+            row_label = f"r{r:0{pad}d}"
+            row_vals = [f"[{v}]" if (r, c) in marks else f"{v}" for c, v in enumerate(counts[r])]
+            rows.append([row_label, *row_vals])
+
+        table = tabulate(
+            rows,
+            headers=["", *col_headers],
+            tablefmt="psql",
+            stralign="right",
+            disable_numparse=True,
+        )
+        self.logger.debug("buffer matrix:\n%s", table)
+
+    def get_batch(self, max_batch_size: int) -> Generator[BatchResults, None, None]:
+        """
+        Consume upstream batches, bucket incoming rows, and emit waves of non-overlapping buckets.
+
+        Streaming behavior:
+        - If at least one bucket is full (>= max_batch_size), emit a wave seeded with full buckets
+          and extended with near-full buckets to keep parallelism high.
+        - If a bucket exceeds a burst threshold, emit a burst wave (seeded by the hottest bucket)
+          and extended with near-full buckets.
+
+        End-of-stream behavior:
+        - Flush leftovers in capped waves (max_batch_size per bucket).
+
+        Statistics policy:
+        * Every emission except the last carries {}.
+        * The last emission carries the accumulated upstream stats (unfiltered).
+        """
+        if self.predecessor is None:
+            return
+
+        accum_stats: Dict[str, Any] = {}
+        pending: BatchResults | None = None
+
+        near_full_threshold = max(1, int(max_batch_size * self.near_full_ratio))
+        burst_threshold = self.burst_multiplier * max_batch_size
+
+        for upstream in self.predecessor.get_batch(max_batch_size):
+            if upstream.statistics:
+                accum_stats = merge_summery(accum_stats, upstream.statistics)
+
+            for item in upstream.chunk:
+                r, c = self._id_extractor(item)
+                if self._monopartite and r > c:
+                    r, c = c, r
+                if not (0 <= r < self.table_size and 0 <= c < self.table_size):
+                    raise ValueError(f"bucket id out of range: {(r, c)} for table_size={self.table_size}")
+                self.buffer[r][c].append(item)
+
+            while True:
+                full_seed = self._select_wave(min_bucket_len=max_batch_size)
+                if not full_seed:
+                    break
+                wave = self._select_wave(min_bucket_len=near_full_threshold, seed=full_seed)
+                br = self._flush_wave(wave, max_batch_size, statistics={})
+                if pending is not None:
+                    yield pending
+                pending = br
+
+            while True:
+                hot = self._find_hottest_bucket(threshold=burst_threshold)
+                if hot is None:
+                    break
+                hot_r, hot_c, hot_n = hot
+                wave = self._select_wave(min_bucket_len=near_full_threshold, seed=[(hot_r, hot_c)])
+                self.logger.debug(
+                    "burst flush: hottest_bucket=(%d,%d len=%d) threshold=%d near_full_threshold=%d wave_size=%d",
+                    hot_r, hot_c, hot_n, burst_threshold, near_full_threshold, len(wave)
+                )
+                br = self._flush_wave(wave, max_batch_size, statistics={})
+                if pending is not None:
+                    yield pending
+                pending = br
+
+        self.logger.debug("start flushing leftovers")
+        while True:
+            wave = self._select_wave(min_bucket_len=1)
+            if not wave:
+                break
+            br = self._flush_wave(wave, max_batch_size, statistics={})
+            if pending is not None:
+                yield pending
+            pending = br
+
+        if pending is not None:
+            yield BatchResults(chunk=pending.chunk, statistics=accum_stats, batch_size=pending.batch_size)
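
A quick illustration of the extractor contracts above (values chosen arbitrarily): `dict_id_extractor` buckets by the last decimal digit of each field, while `canonical_integer_id_extractor` folds (A,B) and (B,A) into the same upper-triangle bucket, so a mono-partite wave never hands the same node pair to two different workers:

    from etl_lib.core.SplittingBatchProcessor import (
        canonical_integer_id_extractor,
        dict_id_extractor,
    )

    bi = dict_id_extractor(table_size=10)
    print(bi({"start": "person-17", "end": "city-42"}))   # last digits -> (7, 2)

    mono = canonical_integer_id_extractor(table_size=10)
    print(mono({"start": 17, "end": 42}))   # some (row, col) with row <= col
    print(mono({"start": 42, "end": 17}))   # identical bucket, thanks to canonical folding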

{neo4j_etl_lib-0.3.0 → neo4j_etl_lib-0.3.2}/src/etl_lib/data_source/CSVBatchSource.py

@@ -30,7 +30,7 @@ class CSVBatchSource(BatchProcessor):
         self.csv_file = csv_file
         self.kwargs = kwargs
 
-    def get_batch(self, max_batch__size: int) -> Generator[BatchResults]:
+    def get_batch(self, max_batch__size: int) -> Generator[BatchResults, None, None]:
         for batch_size, chunks_ in self.__read_csv(self.csv_file, batch_size=max_batch__size, **self.kwargs):
             yield BatchResults(chunk=chunks_, statistics={"csv_lines_read": batch_size}, batch_size=batch_size)
 
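
For context on this annotation fix (a general typing note, not specific to this library): `typing.Generator` is parameterised as `Generator[YieldType, SendType, ReturnType]`, so a yield-only generator such as `get_batch` is spelled with `None` for the send and return types:

    from typing import Generator

    def numbers() -> Generator[int, None, None]:  # yields ints, accepts/returns nothing
        yield 1
        yield 2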