neo4j-etl-lib 0.1.1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- etl_lib/__init__.py +1 -1
- etl_lib/cli/run_tools.py +1 -1
- etl_lib/core/BatchProcessor.py +7 -7
- etl_lib/core/ClosedLoopBatchProcessor.py +8 -2
- etl_lib/core/ETLContext.py +112 -46
- etl_lib/core/ParallelBatchProcessor.py +180 -0
- etl_lib/core/ProgressReporter.py +23 -4
- etl_lib/core/SplittingBatchProcessor.py +268 -0
- etl_lib/core/Task.py +10 -8
- etl_lib/core/ValidationBatchProcessor.py +2 -0
- etl_lib/core/utils.py +52 -11
- etl_lib/data_sink/CypherBatchSink.py +4 -3
- etl_lib/data_sink/SQLBatchSink.py +36 -0
- etl_lib/data_source/SQLBatchSource.py +114 -0
- etl_lib/task/CreateReportingConstraintsTask.py +2 -2
- etl_lib/task/data_loading/CSVLoad2Neo4jTask.py +42 -6
- etl_lib/task/data_loading/ParallelCSVLoad2Neo4jTask.py +98 -0
- etl_lib/task/data_loading/ParallelSQLLoad2Neo4jTask.py +122 -0
- etl_lib/task/data_loading/SQLLoad2Neo4jTask.py +90 -0
- etl_lib/test_utils/utils.py +19 -1
- {neo4j_etl_lib-0.1.1.dist-info → neo4j_etl_lib-0.3.0.dist-info}/METADATA +14 -3
- neo4j_etl_lib-0.3.0.dist-info/RECORD +36 -0
- {neo4j_etl_lib-0.1.1.dist-info → neo4j_etl_lib-0.3.0.dist-info}/WHEEL +1 -1
- neo4j_etl_lib-0.1.1.dist-info/RECORD +0 -29
- {neo4j_etl_lib-0.1.1.dist-info → neo4j_etl_lib-0.3.0.dist-info}/licenses/LICENSE +0 -0
etl_lib/core/SplittingBatchProcessor.py
ADDED
@@ -0,0 +1,268 @@
+import logging
+from typing import Any, Callable, Dict, Generator, List, Tuple
+
+from etl_lib.core.BatchProcessor import BatchProcessor, BatchResults
+from etl_lib.core.utils import merge_summery
+from tabulate import tabulate
+
+
+def tuple_id_extractor(table_size: int = 10) -> Callable[[Tuple[str | int, str | int]], Tuple[int, int]]:
+    """
+    Create an ID extractor function for tuple items, using the last decimal digit of each element.
+    The output is a `(row, col)` tuple within a `table_size x table_size` grid (default 10x10).
+
+    Args:
+        table_size: The dimension of the grid (number of rows/cols). Defaults to 10.
+
+    Returns:
+        A callable that maps a tuple `(a, b)` to a tuple `(row, col)` using the last digit of `a` and `b`.
+    """
+
+    def extractor(item: Tuple[Any, Any]) -> Tuple[int, int]:
+        a, b = item
+        try:
+            row = int(str(a)[-1])
+            col = int(str(b)[-1])
+        except Exception as e:
+            raise ValueError(f"Failed to extract ID from item {item}: {e}")
+        return row, col
+
+    extractor.table_size = table_size
+    return extractor
+
+
+def dict_id_extractor(
+        table_size: int = 10,
+        start_key: str = "start",
+        end_key: str = "end",
+) -> Callable[[Dict[str, Any]], Tuple[int, int]]:
+    """
+    Build an ID extractor for dict rows. The extractor reads two fields (configurable via
+    `start_key` and `end_key`) and returns (row, col) based on the last decimal digit of each.
+    Range validation remains the responsibility of the SplittingBatchProcessor.
+
+    Args:
+        table_size: Informational hint carried on the extractor; used by callers to sanity-check.
+        start_key: Field name for the start node identifier.
+        end_key: Field name for the end node identifier.
+
+    Returns:
+        Callable[[Mapping[str, Any]], tuple[int, int]]: Maps {start_key, end_key} → (row, col).
+    """
+
+    def extractor(item: Dict[str, Any]) -> Tuple[int, int]:
+        missing = [k for k in (start_key, end_key) if k not in item]
+        if missing:
+            raise KeyError(f"Item missing required keys: {', '.join(missing)}")
+        try:
+            row = int(str(item[start_key])[-1])
+            col = int(str(item[end_key])[-1])
+        except Exception as e:
+            raise ValueError(f"Failed to extract ID from item {item}: {e}")
+        return row, col
+
+    extractor.table_size = table_size
+    return extractor
+
+
+class SplittingBatchProcessor(BatchProcessor):
+    """
+    BatchProcessor that splits incoming BatchResults into non-overlapping partitions based
+    on row/col indices extracted by the id_extractor, and emits full or remaining batches
+    using the mix-and-batch algorithm from https://neo4j.com/blog/developer/mix-and-batch-relationship-load/
+    Each emitted batch is a list of per-cell lists (array of arrays), so callers
+    can process each partition (other name for a cell) in parallel.
+
+    A batch for a schedule group is emitted when all cells in that group have at least `batch_size` items.
+    In addition, when a cell/partition reaches 3x the configured max_batch_size, the group is emitted to avoid
+    overflowing the buffer when the distribution per cell is uneven.
+    Leftovers are flushed after source exhaustion.
+    Emitted batches never exceed the configured max_batch_size.
+    """
+
+    def __init__(self, context, table_size: int, id_extractor: Callable[[Any], Tuple[int, int]],
+                 task=None, predecessor=None):
+        super().__init__(context, task, predecessor)
+
+        # If the extractor carries an expected table size, use or validate it
+        if hasattr(id_extractor, "table_size"):
+            expected_size = id_extractor.table_size
+            if table_size is None:
+                table_size = expected_size  # determine table size from extractor if not provided
+            elif table_size != expected_size:
+                raise ValueError(f"Mismatch between provided table_size ({table_size}) and "
+                                 f"id_extractor table_size ({expected_size}).")
+        elif table_size is None:
+            raise ValueError("table_size must be specified if id_extractor has no defined table_size")
+        self.table_size = table_size
+        self._id_extractor = id_extractor
+
+        # Initialize 2D buffer for partitions
+        self.buffer: Dict[int, Dict[int, List[Any]]] = {
+            i: {j: [] for j in range(self.table_size)} for i in range(self.table_size)
+        }
+        self.logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")
+
+    def _generate_batch_schedule(self) -> List[List[Tuple[int, int]]]:
+        """
+        Create diagonal stripes (row, col) partitions to ensure no overlapping IDs
+        across emitted batches.
+        Example grid:
+                 ||  0  |  1  |  2
+            =====++=====+=====+=====
+              0  ||  0  |  1  |  2
+            -----++-----+-----+-----
+              1  ||  2  |  0  |  1
+            -----++-----+-----+-----
+              2  ||  1  |  2  |  0
+
+        would return [[(0, 0), (1, 1), (2, 2)], [(0, 1), (1, 2), (2, 0)], [(0, 2), (1, 0), (2, 1)]]
+        """
+        schedule: List[List[Tuple[int, int]]] = []
+        for shift in range(self.table_size):
+            partition: List[Tuple[int, int]] = [
+                (i, (i + shift) % self.table_size)
+                for i in range(self.table_size)
+            ]
+            schedule.append(partition)
+        return schedule
+
+    def _flush_group(
+            self,
+            partitions: List[Tuple[int, int]],
+            batch_size: int,
+            statistics: Dict[str, Any] | None = None,
+    ) -> Generator[BatchResults, None, None]:
+        """
+        Extract up to `batch_size` items from each cell in `partitions`, remove them from the buffer,
+        and yield a BatchResults whose chunk is a list of per-cell lists from these partitions.
+
+        Args:
+            partitions: Cell coordinates forming a diagonal group from the schedule.
+            batch_size: Max number of items to take from each cell.
+            statistics: Stats dict to attach to this emission (use {} for interim waves).
+                The "final" emission will pass the accumulated stats here.
+
+        Notes:
+            - Debug-only: prints a 2D matrix of cell sizes when logger is in DEBUG.
+            - batch_size in the returned BatchResults equals the number of emitted items.
+        """
+        self._log_buffer_matrix(partition=partitions)
+
+        cell_chunks: List[List[Any]] = []
+        for row, col in partitions:
+            q = self.buffer[row][col]
+            take = min(batch_size, len(q))
+            part = q[:take]
+            cell_chunks.append(part)
+            # remove flushed items
+            self.buffer[row][col] = q[take:]
+
+        emitted = sum(len(c) for c in cell_chunks)
+
+        result = BatchResults(
+            chunk=cell_chunks,
+            statistics=statistics or {},
+            batch_size=emitted,
+        )
+        yield result
+
+    def get_batch(self, max_batch__size: int) -> Generator[BatchResults, None, None]:
+        """
+        Consume upstream batches, split data across cells, and emit diagonal partitions:
+        - During streaming: emit a full partition when all its cells have >= max_batch__size.
+        - Also during streaming: if any cell in a partition builds up beyond a 'burst' threshold
+          (3 * max_batch__size), emit that partition early, taking up to max_batch__size
+          from each cell.
+        - After source exhaustion: flush leftovers in waves capped at max_batch__size per cell.
+
+        Statistics policy:
+          * Every emission except the last carries {}.
+          * The last emission carries the accumulated upstream stats (unfiltered).
+        """
+        schedule = self._generate_batch_schedule()
+
+        accum_stats: Dict[str, Any] = {}
+        pending: BatchResults | None = None  # hold back the most recent emission so we know what's final
+
+        burst_threshold = 3 * max_batch__size
+
+        for upstream in self.predecessor.get_batch(max_batch__size):
+            # accumulate upstream stats (unfiltered)
+            if upstream.statistics:
+                accum_stats = merge_summery(accum_stats, upstream.statistics)
+
+            # add data to cells
+            for item in upstream.chunk:
+                r, c = self._id_extractor(item)
+                if not (0 <= r < self.table_size and 0 <= c < self.table_size):
+                    raise ValueError(f"partition id out of range: {(r, c)} for table_size={self.table_size}")
+                self.buffer[r][c].append(item)
+
+            # process partitions
+            for partition in schedule:
+                # normal full flush when all cells are ready
+                if all(len(self.buffer[r][c]) >= max_batch__size for r, c in partition):
+                    br = next(self._flush_group(partition, max_batch__size, statistics={}))
+                    if pending is not None:
+                        yield pending
+                    pending = br
+                    continue
+
+                # burst flush if any cell backlog explodes
+                hot_cells = [(r, c, len(self.buffer[r][c])) for r, c in partition if
+                             len(self.buffer[r][c]) >= burst_threshold]
+                if hot_cells:
+                    top_r, top_c, top_len = max(hot_cells, key=lambda x: x[2])
+                    self.logger.debug(
+                        "burst flush: partition=%s threshold=%d top_cell=(%d,%d len=%d)",
+                        partition, burst_threshold, top_r, top_c, top_len
+                    )
+                    br = next(self._flush_group(partition, max_batch__size, statistics={}))
+                    if pending is not None:
+                        yield pending
+                    pending = br
+
+        # source exhausted: drain leftovers in capped waves (respecting batch size)
+        self.logger.debug("start flushing leftovers")
+        for partition in (p for p in schedule if any(self.buffer[r][c] for r, c in p)):
+            while any(self.buffer[r][c] for r, c in partition):
+                br = next(self._flush_group(partition, max_batch__size, statistics={}))
+                if pending is not None:
+                    yield pending
+                pending = br
+
+        # final emission carries accumulated stats once
+        if pending is not None:
+            yield BatchResults(chunk=pending.chunk, statistics=accum_stats, batch_size=pending.batch_size)
+
+    def _log_buffer_matrix(self, *, partition: List[Tuple[int, int]]) -> None:
+        """
+        Dumps a compact 2D matrix of per-cell sizes (len of each buffer) when DEBUG is enabled.
+        """
+        if not self.logger.isEnabledFor(logging.DEBUG):
+            return
+
+        counts = [
+            [len(self.buffer[r][c]) for c in range(self.table_size)]
+            for r in range(self.table_size)
+        ]
+        marks = set(partition)
+
+        pad = max(2, len(str(self.table_size - 1)))
+        col_headers = [f"c{c:0{pad}d}" for c in range(self.table_size)]
+
+        rows = []
+        for r in range(self.table_size):
+            row_label = f"r{r:0{pad}d}"
+            row_vals = [f"[{v}]" if (r, c) in marks else f"{v}" for c, v in enumerate(counts[r])]
+            rows.append([row_label, *row_vals])
+
+        table = tabulate(
+            rows,
+            headers=["", *col_headers],
+            tablefmt="psql",
+            stralign="right",
+            disable_numparse=True,
+        )
+        self.logger.debug("buffer matrix:\n%s", table)
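For intuition, here is a minimal standalone sketch of the diagonal ("mix-and-batch") scheduling used above, on a 3x3 grid; the function mirrors _generate_batch_schedule but is illustrative only. Each group touches every row and every column exactly once, so two items in the same group never share a row or column partition, which is what makes processing the per-cell lists in parallel safe.

from typing import List, Tuple

def generate_batch_schedule(table_size: int) -> List[List[Tuple[int, int]]]:
    # group g contains the cells (i, (i + g) % n): one cell per row, shifted per group
    return [
        [(i, (i + shift) % table_size) for i in range(table_size)]
        for shift in range(table_size)
    ]

if __name__ == "__main__":
    for group in generate_batch_schedule(3):
        rows = sorted({r for r, _ in group})
        cols = sorted({c for _, c in group})
        print(group, "rows:", rows, "cols:", cols)
    # each of the 3 groups covers rows [0, 1, 2] and cols [0, 1, 2] exactly once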
etl_lib/core/Task.py
CHANGED
@@ -7,7 +7,7 @@ from datetime import datetime
 
 class TaskReturn:
     """
-    Return object for the :
+    Return object for the :func:`~Task.execute` function, transporting result information.
     """
 
     success: bool
@@ -59,7 +59,7 @@ class Task:
     ETL job that can be executed.
 
     Provides reporting, time tracking and error handling.
-    Implementations must provide the :
+    Implementations must provide the :func:`~run_internal` function.
     """
 
     def __init__(self, context):
@@ -67,16 +67,17 @@
         Construct a Task object.
 
         Args:
-            context: :
+            context: :class:`~etl_lib.core.ETLContext.ETLContext` instance. Will be available to subclasses.
         """
         self.context = context
-
+        """:class:`~etl_lib.core.ETLContext.ETLContext` giving access to various resources."""
+        self.logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")
         self.uuid = str(uuid.uuid4())
         """Uniquely identifies a Task."""
         self.start_time: datetime
-        """Time when the :
+        """Time when the :func:`~execute` was called., `None` before."""
        self.end_time: datetime
-        """Time when the :
+        """Time when the :func:`~execute` has finished., `None` before."""
         self.success: bool
         """True if the task has finished successful. False otherwise, `None` before the task has finished."""
         self.depth: int = 0
@@ -87,8 +88,9 @@
         Executes the task.
 
         Implementations of this Interface should not overwrite this method, but provide the
-        Task functionality inside :
-        Will use the :
+        Task functionality inside :func:`~run_internal` which will be called from here.
+        Will use the :class:`~etl_lib.core.ProgressReporter.ProgressReporter` from
+        :attr:`~etl_lib.core.Task.Task.context` to report status updates.
 
         Args:
             kwargs: will be passed to `run_internal`
etl_lib/core/ValidationBatchProcessor.py
CHANGED
@@ -34,6 +34,8 @@ class ValidationBatchProcessor(BatchProcessor):
         Each row in this file will contain the original data together with all validation errors for this row.
         """
         super().__init__(context, task, predecessor)
+        if model is not None and error_file is None:
+            raise ValueError('you must provide error file if the model is specified')
         self.error_file = error_file
         self.model = model
 
etl_lib/core/utils.py
CHANGED
@@ -1,4 +1,6 @@
 import logging
+import os
+import signal
 
 
 def merge_summery(summery_1: dict, summery_2: dict) -> dict:
@@ -12,17 +14,56 @@ def merge_summery(summery_1: dict, summery_2: dict) -> dict:
 
 
 def setup_logging(log_file=None):
     """
-    Set up logging
-
-
-    :type log_file: str, optional
+    Set up the logging. INFO is used for the root logger.
+    Via ETL_LIB_LOG_LEVEL environment variable, the log level of the library itself can be set to another level.
+    It also defaults to INFO.
     """
-
+    fmt = '%(asctime)s - %(levelname)s - %(name)s - [%(threadName)s] - %(message)s'
+    formatter = logging.Formatter(fmt)
+
+    root_handlers = [logging.StreamHandler()]
+    if log_file:
+        root_handlers.append(logging.FileHandler(log_file))
+    for h in root_handlers:
+        h.setLevel(logging.INFO)
+        h.setFormatter(formatter)
+    logging.basicConfig(level=logging.INFO, handlers=root_handlers, force=True)
+
+    raw = os.getenv("ETL_LIB_LOG_LEVEL", "INFO")
+    try:
+        etl_level = int(raw) if str(raw).isdigit() else getattr(logging, str(raw).upper())
+    except Exception:
+        etl_level = logging.DEBUG
+
+    etl_logger = logging.getLogger('etl_lib')
+    etl_logger.setLevel(etl_level)
+    etl_logger.propagate = False
+    etl_logger.handlers.clear()
+
+    dbg_console = logging.StreamHandler()
+    dbg_console.setLevel(logging.NOTSET)
+    dbg_console.setFormatter(formatter)
+    etl_logger.addHandler(dbg_console)
+
     if log_file:
-
+        dbg_file = logging.FileHandler(log_file)
+        dbg_file.setLevel(logging.NOTSET)
+        dbg_file.setFormatter(formatter)
+        etl_logger.addHandler(dbg_file)
+
+
+def add_sigint_handler(handler_to_add):
+    """
+    Register handler_to_add(signum, frame) to run on Ctrl-C,
+    chaining any previously registered handler afterward.
+    """
+    old_handler = signal.getsignal(signal.SIGINT)
+
+    def chained_handler(signum, frame):
+        # first, run the new handler
+        handler_to_add(signum, frame)
+        # then, if there was an old handler, call it
+        if callable(old_handler):
+            old_handler(signum, frame)
 
-
-        level=logging.INFO,
-        format='%(asctime)s - %(levelname)s - %(message)s',
-        handlers=handlers
-    )
+    signal.signal(signal.SIGINT, chained_handler)
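A hedged usage sketch of the two helpers above, assuming the package is installed; the handler body and the flag are illustrative. The custom handler only records that a stop was requested, and because add_sigint_handler chains the previously installed SIGINT handler, the prior behaviour still runs afterwards.

import logging

from etl_lib.core.utils import add_sigint_handler, setup_logging

stop_requested = False

def request_stop(signum, frame):
    # record the request; long-running loops can check this flag between batches
    global stop_requested
    stop_requested = True
    logging.getLogger("etl_lib").info("SIGINT received, finishing current batch")

setup_logging(log_file=None)      # root logger at INFO; etl_lib level taken from ETL_LIB_LOG_LEVEL
add_sigint_handler(request_stop)  # Ctrl-C sets the flag, then the previous handler is chained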
etl_lib/data_sink/CypherBatchSink.py
CHANGED
@@ -10,7 +10,7 @@ class CypherBatchSink(BatchProcessor):
     BatchProcessor to write batches of data to a Neo4j database.
     """
 
-    def __init__(self, context: ETLContext, task: Task, predecessor: BatchProcessor, query: str):
+    def __init__(self, context: ETLContext, task: Task, predecessor: BatchProcessor, query: str, **kwargs):
         """
         Constructs a new CypherBatchSink.
 
@@ -20,16 +20,17 @@ class CypherBatchSink(BatchProcessor):
             predecessor: BatchProcessor which :func:`~get_batch` function will be called to receive batches to process.
             query: Cypher to write the query to Neo4j.
                 Data will be passed as `batch` parameter.
-
+                Therefore, the query should start with a `UNWIND $batch AS row`.
         """
         super().__init__(context, task, predecessor)
         self.query = query
         self.neo4j = context.neo4j
+        self.kwargs = kwargs
 
     def get_batch(self, batch_size: int) -> Generator[BatchResults, None, None]:
         assert self.predecessor is not None
 
         with self.neo4j.session() as session:
             for batch_result in self.predecessor.get_batch(batch_size):
-                result = self.neo4j.query_database(session=session, query=self.query, batch=batch_result.chunk)
+                result = self.neo4j.query_database(session=session, query=self.query, batch=batch_result.chunk, **self.kwargs)
                 yield append_result(batch_result, result.summery)
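As the updated docstring notes, the sink hands each chunk to the database as the `batch` parameter, so the Cypher should unwind it. Below is a minimal sketch of such a query; the label, property names, and the extra `$source` parameter are illustrative only, the latter assuming that keyword arguments forwarded to query_database surface as additional query parameters.

MERGE_PERSONS = """
UNWIND $batch AS row
MERGE (p:Person {id: row.id})
SET p.name = row.name,
    p.source = $source
"""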
etl_lib/data_sink/SQLBatchSink.py
ADDED
@@ -0,0 +1,36 @@
+from typing import Generator
+from sqlalchemy import text
+from etl_lib.core.ETLContext import ETLContext
+from etl_lib.core.BatchProcessor import BatchProcessor, BatchResults, append_result
+from etl_lib.core.Task import Task
+
+
+class SQLBatchSink(BatchProcessor):
+    """
+    BatchProcessor to write batches of data to an SQL database.
+    """
+
+    def __init__(self, context: ETLContext, task: Task, predecessor: BatchProcessor, query: str):
+        """
+        Constructs a new SQLBatchSink.
+
+        Args:
+            context: ETLContext instance.
+            task: Task instance owning this batchProcessor.
+            predecessor: BatchProcessor which `get_batch` function will be called to receive batches to process.
+            query: SQL query to write data.
+                Data will be passed as a batch using parameterized statements (`:param_name` syntax).
+        """
+        super().__init__(context, task, predecessor)
+        self.query = query
+        self.engine = context.sql.engine
+
+    def get_batch(self, batch_size: int) -> Generator[BatchResults, None, None]:
+        assert self.predecessor is not None
+
+        with self.engine.connect() as conn:
+            with conn.begin():
+                for batch_result in self.predecessor.get_batch(batch_size):
+                    conn.execute(text(self.query), batch_result.chunk)
+                    yield append_result(batch_result, {"sql_rows_written": len(batch_result.chunk)})
+
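A hedged sketch of a write query for SQLBatchSink; the table and column names are made up. Because the sink passes the whole chunk (a list of dicts) to conn.execute together with a text() statement, SQLAlchemy runs it executemany-style, binding each dict's keys to the `:named` placeholders.

INSERT_PERSONS = """
INSERT INTO persons (id, name)
VALUES (:id, :name)
"""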
etl_lib/data_source/SQLBatchSource.py
ADDED
@@ -0,0 +1,114 @@
+import time
+from typing import Generator, Callable, Optional, List, Dict
+
+from psycopg2 import OperationalError as PsycopgOperationalError
+from sqlalchemy import text
+from sqlalchemy.exc import OperationalError as SAOperationalError, DBAPIError
+
+from etl_lib.core.BatchProcessor import BatchResults, BatchProcessor
+from etl_lib.core.ETLContext import ETLContext
+from etl_lib.core.Task import Task
+
+
+class SQLBatchSource(BatchProcessor):
+    def __init__(
+            self,
+            context: ETLContext,
+            task: Task,
+            query: str,
+            record_transformer: Optional[Callable[[dict], dict]] = None,
+            **kwargs
+    ):
+        """
+        Constructs a new SQLBatchSource.
+
+        Args:
+            context: :class:`etl_lib.core.ETLContext.ETLContext` instance.
+            task: :class:`etl_lib.core.Task.Task` instance owning this batchProcessor.
+            query: SQL query to execute.
+            record_transformer: Optional function to transform each row (dict format).
+            kwargs: Arguments passed as parameters with the query.
+        """
+        super().__init__(context, task)
+        self.query = query.strip().rstrip(";")
+        self.record_transformer = record_transformer
+        self.kwargs = kwargs
+
+    def _fetch_page(self, limit: int, offset: int) -> Optional[List[Dict]]:
+        """
+        Fetch a single batch of rows using LIMIT/OFFSET, with retry/backoff.
+
+        Each page is executed in its own transaction. On transient
+        disconnects or DB errors, it retries up to 3 times with exponential backoff.
+
+        Args:
+            limit: maximum number of rows to return.
+            offset: number of rows to skip before starting this page.
+
+        Returns:
+            A list of row dicts (after applying record_transformer, if any),
+            or None if no rows are returned.
+
+        Raises:
+            Exception: re-raises the last caught error on final failure.
+        """
+        paged_sql = f"{self.query} LIMIT :limit OFFSET :offset"
+        params = {**self.kwargs, "limit": limit, "offset": offset}
+        max_retries = 5
+        backoff = 2.0
+
+        for attempt in range(1, max_retries + 1):
+            try:
+                with self.context.sql.engine.connect() as conn:
+                    with conn.begin():
+                        rows = conn.execute(text(paged_sql), params).mappings().all()
+                result = [
+                    self.record_transformer(dict(r)) if self.record_transformer else dict(r)
+                    for r in rows
+                ]
+                return result if result else None
+
+            except (PsycopgOperationalError, SAOperationalError, DBAPIError) as err:
+
+                if attempt == max_retries:
+                    self.logger.error(
+                        f"Page fetch failed after {max_retries} attempts "
+                        f"(limit={limit}, offset={offset}): {err}"
+                    )
+                    raise
+
+                self.logger.warning(
+                    f"Transient DB error on page fetch {attempt}/{max_retries}: {err!r}, "
+                    f"retrying in {backoff:.1f}s"
+                )
+                time.sleep(backoff)
+                backoff *= 2
+
+        return None
+
+    def get_batch(self, max_batch_size: int) -> Generator[BatchResults, None, None]:
+        """
+        Yield successive batches until the query is exhausted.
+
+        Calls _fetch_page() repeatedly, advancing the offset by the
+        number of rows returned. Stops when no more rows are returned.
+
+        Args:
+            max_batch_size: upper limit on rows per batch.
+
+        Yields:
+            BatchResults for each non-empty page.
+        """
+        offset = 0
+        while True:
+            chunk = self._fetch_page(max_batch_size, offset)
+            if not chunk:
+                break
+
+            yield BatchResults(
+                chunk=chunk,
+                statistics={"sql_rows_read": len(chunk)},
+                batch_size=len(chunk),
+            )
+
+            offset += len(chunk)
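A hedged usage sketch for SQLBatchSource. The table, column names, transformer, and the min_id parameter are illustrative; the context and task are assumed to come from the surrounding pipeline. Since the source appends LIMIT/OFFSET for paging, the query should be a plain SELECT without its own LIMIT or trailing semicolon (the latter is stripped anyway).

from etl_lib.data_source.SQLBatchSource import SQLBatchSource

def to_relationship(row: dict) -> dict:
    # rename SQL columns to the keys expected further down the pipeline
    return {"start": row["person_id"], "end": row["company_id"], "since": row["since"]}

def build_employment_source(context, task) -> SQLBatchSource:
    return SQLBatchSource(
        context,  # ETLContext with a configured context.sql.engine
        task,     # owning Task
        "SELECT person_id, company_id, since FROM employments WHERE person_id >= :min_id",
        record_transformer=to_relationship,
        min_id=0,  # forwarded as a bound query parameter on every page fetch
    )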
etl_lib/task/CreateReportingConstraintsTask.py
CHANGED
@@ -4,8 +4,8 @@ from etl_lib.core.Task import Task, TaskReturn
 class CreateReportingConstraintsTask(Task):
     """Creates the constraint in the REPORTER_DATABASE database."""
 
-    def __init__(self,
-        super().__init__(
+    def __init__(self, context):
+        super().__init__(context)
 
     def run_internal(self, **kwargs) -> TaskReturn:
         database = self.context.env("REPORTER_DATABASE")
etl_lib/task/data_loading/CSVLoad2Neo4jTask.py
CHANGED
@@ -14,24 +14,60 @@ from etl_lib.data_source.CSVBatchSource import CSVBatchSource
 
 
 class CSVLoad2Neo4jTask(Task):
-
+    '''
     Loads the specified CSV file to Neo4j.
 
     Uses BatchProcessors to read, validate and write to Neo4j.
     The validation step is using pydantic, hence a Pydantic model needs to be provided.
-    Rows
+    Rows with fail validation will be written to en error file. The location of the error file is determined as
     follows:
 
-    If the context env vars hold an entry `ETL_ERROR_PATH` the file will be
+    If the context env vars hold an entry `ETL_ERROR_PATH` the file will be placed there, with the name set to name
     of the provided filename appended with `.error.json`
 
-    If
-
+    If `ETL_ERROR_PATH` is not set, the file will be placed in the same directory as the CSV file.
+
+    Example usage: (from the gtfs demo)
+
+    .. code-block:: python
+
+        class LoadStopsTask(CSVLoad2Neo4jTask):
+            class Stop(BaseModel):
+                id: str = Field(alias="stop_id")
+                name: str = Field(alias="stop_name")
+                latitude: float = Field(alias="stop_lat")
+                longitude: float = Field(alias="stop_lon")
+                platform_code: Optional[str] = None
+                parent_station: Optional[str] = None
+                type: Optional[str] = Field(alias="location_type", default=None)
+                timezone: Optional[str] = Field(alias="stop_timezone", default=None)
+                code: Optional[str] = Field(alias="stop_code", default=None)
+
+            def __init__(self, context: ETLContext, file: Path):
+                super().__init__(context, LoadStopsTask.Stop, file)
+
+            def task_name(self) -> str:
+                return f"{self.__class__.__name__}('{self.file}')"
+
+            def _query(self):
+                return """
+                    UNWIND $batch AS row
+                    MERGE (s:Stop {id: row.id})
+                    SET s.name = row.name,
+                        s.location= point({latitude: row.latitude, longitude: row.longitude}),
+                        s.platformCode= row.platform_code,
+                        s.parentStation= row.parent_station,
+                        s.type= row.type,
+                        s.timezone= row.timezone,
+                        s.code= row.code
+                    """
+
+    '''
     def __init__(self, context: ETLContext, model: Type[BaseModel], file: Path, batch_size: int = 5000):
         super().__init__(context)
         self.batch_size = batch_size
         self.model = model
-        self.logger = logging.getLogger(self.__class__.__name__)
+        self.logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")
         self.file = file
 
     def run_internal(self, **kwargs) -> TaskReturn:
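The error-file placement rule described in the docstring above can be summarised in a small sketch; this helper is purely illustrative and not part of the library.

import os
from pathlib import Path

def error_file_for(csv_file: Path) -> Path:
    # If ETL_ERROR_PATH is set, place the error file there; otherwise next to the CSV.
    base = Path(os.environ["ETL_ERROR_PATH"]) if "ETL_ERROR_PATH" in os.environ else csv_file.parent
    return base / f"{csv_file.name}.error.json"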