neo4j-etl-lib 0.2.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,268 @@
1
+ import logging
2
+ from typing import Any, Callable, Dict, Generator, List, Tuple
3
+
4
+ from etl_lib.core.BatchProcessor import BatchProcessor, BatchResults
5
+ from etl_lib.core.utils import merge_summery
6
+ from tabulate import tabulate
7
+
8
+
9
+ def tuple_id_extractor(table_size: int = 10) -> Callable[[Tuple[str | int, str | int]], Tuple[int, int]]:
10
+ """
11
+ Create an ID extractor function for tuple items, using the last decimal digit of each element.
12
+ The output is a `(row, col)` tuple within a `table_size x table_size` grid (default 10x10).
13
+
14
+ Args:
15
+ table_size: The dimension of the grid (number of rows/cols). Defaults to 10.
16
+
17
+ Returns:
18
+ A callable that maps a tuple `(a, b)` to a tuple `(row, col)` using the last digit of `a` and `b`.
19
+ """
20
+
21
+ def extractor(item: Tuple[Any, Any]) -> Tuple[int, int]:
22
+ a, b = item
23
+ try:
24
+ row = int(str(a)[-1])
25
+ col = int(str(b)[-1])
26
+ except Exception as e:
27
+ raise ValueError(f"Failed to extract ID from item {item}: {e}")
28
+ return row, col
29
+
30
+ extractor.table_size = table_size
31
+ return extractor
32
+
33
+
34
+ def dict_id_extractor(
35
+ table_size: int = 10,
36
+ start_key: str = "start",
37
+ end_key: str = "end",
38
+ ) -> Callable[[Dict[str, Any]], Tuple[int, int]]:
39
+ """
40
+ Build an ID extractor for dict rows. The extractor reads two fields (configurable via
41
+ `start_key` and `end_key`) and returns (row, col) based on the last decimal digit of each.
42
+ Range validation remains the responsibility of the SplittingBatchProcessor.
43
+
44
+ Args:
45
+ table_size: Informational hint carried on the extractor; used by callers to sanity-check.
46
+ start_key: Field name for the start node identifier.
47
+ end_key: Field name for the end node identifier.
48
+
49
+ Returns:
50
+ Callable[[Mapping[str, Any]], tuple[int, int]]: Maps {start_key, end_key} → (row, col).
51
+ """
52
+
53
+ def extractor(item: Dict[str, Any]) -> Tuple[int, int]:
54
+ missing = [k for k in (start_key, end_key) if k not in item]
55
+ if missing:
56
+ raise KeyError(f"Item missing required keys: {', '.join(missing)}")
57
+ try:
58
+ row = int(str(item[start_key])[-1])
59
+ col = int(str(item[end_key])[-1])
60
+ except Exception as e:
61
+ raise ValueError(f"Failed to extract ID from item {item}: {e}")
62
+ return row, col
63
+
64
+ extractor.table_size = table_size
65
+ return extractor
66
+
67
+
68
+ class SplittingBatchProcessor(BatchProcessor):
69
+ """
70
+ BatchProcessor that splits incoming BatchResults into non-overlapping partitions based
71
+ on row/col indices extracted by the id_extractor, and emits full or remaining batches
72
+ using the mix-and-batch algorithm from https://neo4j.com/blog/developer/mix-and-batch-relationship-load/
73
+ Each emitted batch is a list of per-cell lists (array of arrays), so callers
74
+ can process each partition (another name for a cell here) in parallel.
75
+
76
+ A batch for a schedule group is emitted when all cells in that group have at least `batch_size` items.
77
+ In addition, when a cell/partition reaches 3x the configured max_batch_size, the group is emitted to avoid
78
+ overflowing the buffer when the distribution per cell is uneven.
79
+ Leftovers are flushed after source exhaustion.
80
+ Emitted batches never exceed the configured max_batch_size.
81
+ """
82
+
83
+ def __init__(self, context, table_size: int, id_extractor: Callable[[Any], Tuple[int, int]],
84
+ task=None, predecessor=None):
85
+ super().__init__(context, task, predecessor)
86
+
87
+ # If the extractor carries an expected table size, use or validate it
88
+ if hasattr(id_extractor, "table_size"):
89
+ expected_size = id_extractor.table_size
90
+ if table_size is None:
91
+ table_size = expected_size # determine table size from extractor if not provided
92
+ elif table_size != expected_size:
93
+ raise ValueError(f"Mismatch between provided table_size ({table_size}) and "
94
+ f"id_extractor table_size ({expected_size}).")
95
+ elif table_size is None:
96
+ raise ValueError("table_size must be specified if id_extractor has no defined table_size")
97
+ self.table_size = table_size
98
+ self._id_extractor = id_extractor
99
+
100
+ # Initialize 2D buffer for partitions
101
+ self.buffer: Dict[int, Dict[int, List[Any]]] = {
102
+ i: {j: [] for j in range(self.table_size)} for i in range(self.table_size)
103
+ }
104
+ self.logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")
105
+
106
+ def _generate_batch_schedule(self) -> List[List[Tuple[int, int]]]:
107
+ """
108
+ Create diagonal stripes (row, col) partitions to ensure no overlapping IDs
109
+ across emitted batches.
110
+ Example grid:
111
+ || 0 | 1 | 2
112
+ =====++=====+=====+=====
113
+ 0 || 0 | 1 | 2
114
+ -----++-----+-----+-----
115
+ 1 || 2 | 0 | 1
116
+ -----++-----+-----+-----
117
+ 2 || 1 | 2 | 0
118
+
119
+ would return [[(0, 0), (1, 1), (2, 2)], [(0, 1), (1, 2), (2, 0)], [(0, 2), (1, 0), (2, 1)]]
120
+ """
121
+ schedule: List[List[Tuple[int, int]]] = []
122
+ for shift in range(self.table_size):
123
+ partition: List[Tuple[int, int]] = [
124
+ (i, (i + shift) % self.table_size)
125
+ for i in range(self.table_size)
126
+ ]
127
+ schedule.append(partition)
128
+ return schedule
129
+
130
+ def _flush_group(
131
+ self,
132
+ partitions: List[Tuple[int, int]],
133
+ batch_size: int,
134
+ statistics: Dict[str, Any] | None = None,
135
+ ) -> Generator[BatchResults, None, None]:
136
+ """
137
+ Extract up to `batch_size` items from each cell in `partitions`, remove them from the buffer,
138
+ and yield a BatchResults whose chunks is a list of per-cell lists from these partitions.
139
+
140
+ Args:
141
+ partitions: Cell coordinates forming a diagonal group from the schedule.
142
+ batch_size: Max number of items to take from each cell.
143
+ statistics: Stats dict to attach to this emission (use {} for interim waves).
144
+ The "final" emission will pass the accumulated stats here.
145
+
146
+ Notes:
147
+ - Debug-only: logs a 2D matrix of cell sizes when the logger is enabled for DEBUG.
148
+ - batch_size in the returned BatchResults equals the number of emitted items.
149
+ """
150
+ self._log_buffer_matrix(partition=partitions)
151
+
152
+ cell_chunks: List[List[Any]] = []
153
+ for row, col in partitions:
154
+ q = self.buffer[row][col]
155
+ take = min(batch_size, len(q))
156
+ part = q[:take]
157
+ cell_chunks.append(part)
158
+ # remove flushed items
159
+ self.buffer[row][col] = q[take:]
160
+
161
+ emitted = sum(len(c) for c in cell_chunks)
162
+
163
+ result = BatchResults(
164
+ chunk=cell_chunks,
165
+ statistics=statistics or {},
166
+ batch_size=emitted,
167
+ )
168
+ yield result
169
+
170
+ def get_batch(self, max_batch__size: int) -> Generator[BatchResults, None, None]:
171
+ """
172
+ Consume upstream batches, split data across cells, and emit diagonal partitions:
173
+ - During streaming: emit a full partition when all its cells hold >= max_batch__size items.
174
+ - Also during streaming: if any cell in a partition builds up beyond a 'burst' threshold
175
+ (3 * max_batch__size), emit that partition early, taking up to max_batch__size
176
+ from each cell.
177
+ - After source exhaustion: flush leftovers in waves capped at max_batch__size per cell.
178
+
179
+ Statistics policy:
180
+ * Every emission except the last carries {}.
181
+ * The last emission carries the accumulated upstream stats (unfiltered).
182
+ """
183
+ schedule = self._generate_batch_schedule()
184
+
185
+ accum_stats: Dict[str, Any] = {}
186
+ pending: BatchResults | None = None # hold back the most recent emission so we know what's final
187
+
188
+ burst_threshold = 3 * max_batch__size
189
+
190
+ for upstream in self.predecessor.get_batch(max_batch__size):
191
+ # accumulate upstream stats (unfiltered)
192
+ if upstream.statistics:
193
+ accum_stats = merge_summery(accum_stats, upstream.statistics)
194
+
195
+ # add data to cells
196
+ for item in upstream.chunk:
197
+ r, c = self._id_extractor(item)
198
+ if not (0 <= r < self.table_size and 0 <= c < self.table_size):
199
+ raise ValueError(f"partition id out of range: {(r, c)} for table_size={self.table_size}")
200
+ self.buffer[r][c].append(item)
201
+
202
+ # process partitions
203
+ for partition in schedule:
204
+ # normal full flush when all cells are ready
205
+ if all(len(self.buffer[r][c]) >= max_batch__size for r, c in partition):
206
+ br = next(self._flush_group(partition, max_batch__size, statistics={}))
207
+ if pending is not None:
208
+ yield pending
209
+ pending = br
210
+ continue
211
+
212
+ # burst flush if any cell backlog explodes
213
+ hot_cells = [(r, c, len(self.buffer[r][c])) for r, c in partition if
214
+ len(self.buffer[r][c]) >= burst_threshold]
215
+ if hot_cells:
216
+ top_r, top_c, top_len = max(hot_cells, key=lambda x: x[2])
217
+ self.logger.debug(
218
+ "burst flush: partition=%s threshold=%d top_cell=(%d,%d len=%d)",
219
+ partition, burst_threshold, top_r, top_c, top_len
220
+ )
221
+ br = next(self._flush_group(partition, max_batch__size, statistics={}))
222
+ if pending is not None:
223
+ yield pending
224
+ pending = br
225
+
226
+ # source exhausted: drain leftovers in capped waves (respecting batch size)
227
+ self.logger.debug("start flushing leftovers")
228
+ for partition in (p for p in schedule if any(self.buffer[r][c] for r, c in p)):
229
+ while any(self.buffer[r][c] for r, c in partition):
230
+ br = next(self._flush_group(partition, max_batch__size, statistics={}))
231
+ if pending is not None:
232
+ yield pending
233
+ pending = br
234
+
235
+ # final emission carries accumulated stats once
236
+ if pending is not None:
237
+ yield BatchResults(chunk=pending.chunk, statistics=accum_stats, batch_size=pending.batch_size)
238
+
239
+ def _log_buffer_matrix(self, *, partition: List[Tuple[int, int]]) -> None:
240
+ """
241
+ Dumps a compact 2D matrix of per-cell sizes (len of each buffer) when DEBUG is enabled.
242
+ """
243
+ if not self.logger.isEnabledFor(logging.DEBUG):
244
+ return
245
+
246
+ counts = [
247
+ [len(self.buffer[r][c]) for c in range(self.table_size)]
248
+ for r in range(self.table_size)
249
+ ]
250
+ marks = set(partition)
251
+
252
+ pad = max(2, len(str(self.table_size - 1)))
253
+ col_headers = [f"c{c:0{pad}d}" for c in range(self.table_size)]
254
+
255
+ rows = []
256
+ for r in range(self.table_size):
257
+ row_label = f"r{r:0{pad}d}"
258
+ row_vals = [f"[{v}]" if (r, c) in marks else f"{v}" for c, v in enumerate(counts[r])]
259
+ rows.append([row_label, *row_vals])
260
+
261
+ table = tabulate(
262
+ rows,
263
+ headers=["", *col_headers],
264
+ tablefmt="psql",
265
+ stralign="right",
266
+ disable_numparse=True,
267
+ )
268
+ self.logger.debug("buffer matrix:\n%s", table)
etl_lib/core/Task.py CHANGED
@@ -7,7 +7,7 @@ from datetime import datetime
7
7
 
8
8
  class TaskReturn:
9
9
  """
10
- Return object for the :py:func:`~Task.execute` function, transporting result information.
10
+ Return object for the :func:`~Task.execute` function, transporting result information.
11
11
  """
12
12
 
13
13
  success: bool
@@ -59,7 +59,7 @@ class Task:
59
59
  ETL job that can be executed.
60
60
 
61
61
  Provides reporting, time tracking and error handling.
62
- Implementations must provide the :py:func:`~run_internal` function.
62
+ Implementations must provide the :func:`~run_internal` function.
63
63
  """
64
64
 
65
65
  def __init__(self, context):
@@ -67,16 +67,17 @@ class Task:
67
67
  Construct a Task object.
68
68
 
69
69
  Args:
70
- context: :py:class:`etl_lib.core.ETLContext.ETLContext` instance. Will be available to subclasses.
70
+ context: :class:`~etl_lib.core.ETLContext.ETLContext` instance. Will be available to subclasses.
71
71
  """
72
72
  self.context = context
73
- self.logger = logging.getLogger(self.__class__.__name__)
73
+ """:class:`~etl_lib.core.ETLContext.ETLContext` giving access to various resources."""
74
+ self.logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")
74
75
  self.uuid = str(uuid.uuid4())
75
76
  """Uniquely identifies a Task."""
76
77
  self.start_time: datetime
77
- """Time when the :py:func:`~execute` was called., `None` before."""
78
+ """Time when the :func:`~execute` was called., `None` before."""
78
79
  self.end_time: datetime
79
- """Time when the :py:func:`~execute` has finished., `None` before."""
80
+ """Time when the :func:`~execute` has finished., `None` before."""
80
81
  self.success: bool
81
82
  """True if the task has finished successful. False otherwise, `None` before the task has finished."""
82
83
  self.depth: int = 0
@@ -87,8 +88,9 @@ class Task:
87
88
  Executes the task.
88
89
 
89
90
  Implementations of this Interface should not overwrite this method, but provide the
90
- Task functionality inside :py:func:`~run_internal` which will be called from here.
91
- Will use the :py:class:`ProgressReporter` from the :py:attr:`~context` to report status updates.
91
+ Task functionality inside :func:`~run_internal` which will be called from here.
92
+ Will use the :class:`~etl_lib.core.ProgressReporter.ProgressReporter` from
93
+ :attr:`~etl_lib.core.Task.Task.context` to report status updates.
92
94
 
93
95
  Args:
94
96
  kwargs: will be passed to `run_internal`
@@ -34,6 +34,8 @@ class ValidationBatchProcessor(BatchProcessor):
34
34
  Each row in this file will contain the original data together with all validation errors for this row.
35
35
  """
36
36
  super().__init__(context, task, predecessor)
37
+ if model is not None and error_file is None:
38
+ raise ValueError('an error_file must be provided when a model is specified')
37
39
  self.error_file = error_file
38
40
  self.model = model
39
41
 
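A short sketch of the Task contract described above (hypothetical subclass and summary key; the `TaskReturn(success=..., summery=...)` call follows the docstring example shown further below, and `execute()` wraps `run_internal()` with timing, logging and reporting):

    from pathlib import Path

    from etl_lib.core.Task import Task, TaskReturn


    class TouchFileTask(Task):
        """Hypothetical example task: creates an empty marker file."""

        def __init__(self, context, path: Path):
            super().__init__(context)
            self.path = path

        def run_internal(self, **kwargs) -> TaskReturn:
            # Only the actual work lives here; execute() adds timing,
            # the module-qualified logger and progress reporting around it.
            self.path.touch()
            return TaskReturn(success=True, summery={"files_created": 1})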
etl_lib/core/utils.py CHANGED
@@ -1,4 +1,6 @@
1
1
  import logging
2
+ import os
3
+ import signal
2
4
 
3
5
 
4
6
  def merge_summery(summery_1: dict, summery_2: dict) -> dict:
@@ -12,17 +14,56 @@ def merge_summery(summery_1: dict, summery_2: dict) -> dict:
12
14
 
13
15
  def setup_logging(log_file=None):
14
16
  """
15
- Set up logging to console and optionally to a log file.
16
-
17
- :param log_file: Path to the log file
18
- :type log_file: str, optional
17
+ Set up logging to the console and, optionally, to `log_file`. The root logger is configured at INFO level.
18
+ The level of the library's own `etl_lib` logger can be set separately via the ETL_LIB_LOG_LEVEL
19
+ environment variable; it also defaults to INFO.
19
20
  """
20
- handlers = [logging.StreamHandler()]
21
+ fmt = '%(asctime)s - %(levelname)s - %(name)s - [%(threadName)s] - %(message)s'
22
+ formatter = logging.Formatter(fmt)
23
+
24
+ root_handlers = [logging.StreamHandler()]
25
+ if log_file:
26
+ root_handlers.append(logging.FileHandler(log_file))
27
+ for h in root_handlers:
28
+ h.setLevel(logging.INFO)
29
+ h.setFormatter(formatter)
30
+ logging.basicConfig(level=logging.INFO, handlers=root_handlers, force=True)
31
+
32
+ raw = os.getenv("ETL_LIB_LOG_LEVEL", "INFO")
33
+ try:
34
+ etl_level = int(raw) if str(raw).isdigit() else getattr(logging, str(raw).upper())
35
+ except Exception:
36
+ etl_level = logging.DEBUG
37
+
38
+ etl_logger = logging.getLogger('etl_lib')
39
+ etl_logger.setLevel(etl_level)
40
+ etl_logger.propagate = False
41
+ etl_logger.handlers.clear()
42
+
43
+ dbg_console = logging.StreamHandler()
44
+ dbg_console.setLevel(logging.NOTSET)
45
+ dbg_console.setFormatter(formatter)
46
+ etl_logger.addHandler(dbg_console)
47
+
21
48
  if log_file:
22
- handlers.append(logging.FileHandler(log_file))
49
+ dbg_file = logging.FileHandler(log_file)
50
+ dbg_file.setLevel(logging.NOTSET)
51
+ dbg_file.setFormatter(formatter)
52
+ etl_logger.addHandler(dbg_file)
53
+
54
+
55
+ def add_sigint_handler(handler_to_add):
56
+ """
57
+ Register handler_to_add(signum, frame) to run on Ctrl-C,
58
+ chaining any previously registered handler afterward.
59
+ """
60
+ old_handler = signal.getsignal(signal.SIGINT)
61
+
62
+ def chained_handler(signum, frame):
63
+ # first, run the new handler
64
+ handler_to_add(signum, frame)
65
+ # then, if there was an old handler, call it
66
+ if callable(old_handler):
67
+ old_handler(signum, frame)
23
68
 
24
- logging.basicConfig(
25
- level=logging.INFO,
26
- format='%(asctime)s - %(levelname)s - %(message)s',
27
- handlers=handlers
28
- )
69
+ signal.signal(signal.SIGINT, chained_handler)
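A sketch of how the reworked `setup_logging` and the new `add_sigint_handler` are meant to be used (the log-file name and the handler body are placeholders):

    import os

    from etl_lib.core.utils import setup_logging, add_sigint_handler

    # Root logging stays at INFO; the etl_lib logger follows ETL_LIB_LOG_LEVEL.
    os.environ["ETL_LIB_LOG_LEVEL"] = "DEBUG"
    setup_logging(log_file="etl_run.log")

    def on_interrupt(signum, frame):
        # runs first on Ctrl-C, then any previously installed SIGINT handler
        print("interrupted, shutting down")

    add_sigint_handler(on_interrupt)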
@@ -30,7 +30,7 @@ class CSVBatchSource(BatchProcessor):
30
30
  self.csv_file = csv_file
31
31
  self.kwargs = kwargs
32
32
 
33
- def get_batch(self, max_batch__size: int) -> Generator[BatchResults]:
33
+ def get_batch(self, max_batch__size: int) -> Generator[BatchResults, None, None]:
34
34
  for batch_size, chunks_ in self.__read_csv(self.csv_file, batch_size=max_batch__size, **self.kwargs):
35
35
  yield BatchResults(chunk=chunks_, statistics={"csv_lines_read": batch_size}, batch_size=batch_size)
36
36
 
@@ -1,5 +1,10 @@
1
- from typing import Generator, Callable, Optional
1
+ import time
2
+ from typing import Generator, Callable, Optional, List, Dict
3
+
4
+ from psycopg2 import OperationalError as PsycopgOperationalError
2
5
  from sqlalchemy import text
6
+ from sqlalchemy.exc import OperationalError as SAOperationalError, DBAPIError
7
+
3
8
  from etl_lib.core.BatchProcessor import BatchResults, BatchProcessor
4
9
  from etl_lib.core.ETLContext import ETLContext
5
10
  from etl_lib.core.Task import Task
@@ -25,36 +30,85 @@ class SQLBatchSource(BatchProcessor):
25
30
  kwargs: Arguments passed as parameters with the query.
26
31
  """
27
32
  super().__init__(context, task)
28
- self.query = query
33
+ self.query = query.strip().rstrip(";")
29
34
  self.record_transformer = record_transformer
30
- self.kwargs = kwargs # Query parameters
35
+ self.kwargs = kwargs
36
+
37
+ def _fetch_page(self, limit: int, offset: int) -> Optional[List[Dict]]:
38
+ """
39
+ Fetch a single batch of rows using LIMIT/OFFSET, with retry/backoff.
40
+
41
+ Each page is executed in its own transaction. On transient
42
+ disconnects or DB errors, it retries with exponential backoff, making up to 5 attempts in total.
31
43
 
32
- def __read_records(self, conn, batch_size: int):
33
- batch_ = []
34
- result = conn.execute(text(self.query), self.kwargs) # Safe execution with bound parameters
44
+ Args:
45
+ limit: maximum number of rows to return.
46
+ offset: number of rows to skip before starting this page.
47
+
48
+ Returns:
49
+ A list of row dicts (after applying record_transformer, if any),
50
+ or None if no rows are returned.
35
51
 
36
- for row in result.mappings(): # Returns row as dict (like Neo4j's `record.data()`)
37
- data = dict(row) # Convert to dictionary
38
- if self.record_transformer:
39
- data = self.record_transformer(data)
40
- batch_.append(data)
52
+ Raises:
53
+ Exception: re-raises the last caught error on final failure.
54
+ """
55
+ paged_sql = f"{self.query} LIMIT :limit OFFSET :offset"
56
+ params = {**self.kwargs, "limit": limit, "offset": offset}
57
+ max_retries = 5
58
+ backoff = 2.0
41
59
 
42
- if len(batch_) == batch_size:
43
- yield batch_
44
- batch_ = [] # Reset batch
60
+ for attempt in range(1, max_retries + 1):
61
+ try:
62
+ with self.context.sql.engine.connect() as conn:
63
+ with conn.begin():
64
+ rows = conn.execute(text(paged_sql), params).mappings().all()
65
+ result = [
66
+ self.record_transformer(dict(r)) if self.record_transformer else dict(r)
67
+ for r in rows
68
+ ]
69
+ return result if result else None
45
70
 
46
- if batch_:
47
- yield batch_
71
+ except (PsycopgOperationalError, SAOperationalError, DBAPIError) as err:
72
+
73
+ if attempt == max_retries:
74
+ self.logger.error(
75
+ f"Page fetch failed after {max_retries} attempts "
76
+ f"(limit={limit}, offset={offset}): {err}"
77
+ )
78
+ raise
79
+
80
+ self.logger.warning(
81
+ f"Transient DB error on page fetch {attempt}/{max_retries}: {err!r}, "
82
+ f"retrying in {backoff:.1f}s"
83
+ )
84
+ time.sleep(backoff)
85
+ backoff *= 2
86
+
87
+ return None
48
88
 
49
89
  def get_batch(self, max_batch_size: int) -> Generator[BatchResults, None, None]:
50
90
  """
51
- Fetches data in batches using an open transaction, similar to Neo4j's approach.
91
+ Yield successive batches until the query is exhausted.
92
+
93
+ Calls _fetch_page() repeatedly, advancing the offset by the
94
+ number of rows returned. Stops when no more rows are returned.
95
+
96
+ Args:
97
+ max_batch_size: upper limit on rows per batch.
98
+
99
+ Yields:
100
+ BatchResults for each non-empty page.
52
101
  """
53
- with self.context.sql.engine.connect() as conn: # Keep transaction open
54
- with conn.begin(): # Ensures rollback on failure
55
- for chunk in self.__read_records(conn, max_batch_size):
56
- yield BatchResults(
57
- chunk=chunk,
58
- statistics={"sql_rows_read": len(chunk)},
59
- batch_size=len(chunk)
60
- )
102
+ offset = 0
103
+ while True:
104
+ chunk = self._fetch_page(max_batch_size, offset)
105
+ if not chunk:
106
+ break
107
+
108
+ yield BatchResults(
109
+ chunk=chunk,
110
+ statistics={"sql_rows_read": len(chunk)},
111
+ batch_size=len(chunk),
112
+ )
113
+
114
+ offset += len(chunk)
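Because `LIMIT :limit OFFSET :offset` is appended to the configured query and each page runs in its own transaction, queries passed to `SQLBatchSource` should carry a stable `ORDER BY` so retried or later pages neither skip nor repeat rows. The retry loop itself follows this generic shape (illustrative sketch; `fetch_once` is a placeholder for the per-page execution):

    import time

    def fetch_page_with_backoff(fetch_once, max_retries=5, backoff=2.0):
        # Mirrors _fetch_page(): try up to max_retries attempts, doubling the
        # sleep after each transient failure, and re-raise on the final attempt.
        for attempt in range(1, max_retries + 1):
            try:
                return fetch_once()
            except Exception:
                if attempt == max_retries:
                    raise
                time.sleep(backoff)
                backoff *= 2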
etl_lib/task/GDSTask.py CHANGED
@@ -28,11 +28,14 @@ class GDSTask(Task):
28
28
  Function that uses the gds client to perform tasks. See the following example:
29
29
 
30
30
  def gds_fun(etl_context):
31
- with etl_context.neo4j.gds() as gds:
32
- gds.graph.drop("neo4j-offices", failIfMissing=False)
33
- g_office, project_result = gds.graph.project("neo4j-offices", "City", "FLY_TO")
34
- mutate_result = gds.pageRank.mutate(g_office, tolerance=0.5, mutateProperty="rank")
35
- return TaskReturn(success=True, summery=transform_dict(mutate_result.to_dict()))
31
+ gds = etl_context.neo4j.gds
32
+ gds.graph.drop("neo4j-offices", failIfMissing=False)
33
+ g_office, project_result = gds.graph.project("neo4j-offices", "City", "FLY_TO")
34
+ mutate_result = gds.pageRank.write(g_office, tolerance=0.5, writeProperty="rank")
35
+ return TaskReturn(success=True, summery=transform_dict(mutate_result.to_dict()))
36
+
37
+ Notes: Do *NOT* use `etl_context.neo4j.gds` with a context manager. The GDS client closes the underlying
38
+ connection when exiting the context.
36
39
 
37
40
  :param context: The ETLContext to use. Provides the gds client to the func via `etl_context.neo4j.gds()`
38
41
  :param func: a function that expects a param `etl_context` and returns a `TaskReturn` object.
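For context, wiring the function from the docstring into a task might look like this (sketch only; the argument order follows the `:param` list above, and `summery` is the field name used by the docstring example):

    task = GDSTask(context, gds_fun)   # gds_fun as defined in the docstring example
    result = task.execute()
    print(result.success, result.summery)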
@@ -67,7 +67,7 @@ class CSVLoad2Neo4jTask(Task):
67
67
  super().__init__(context)
68
68
  self.batch_size = batch_size
69
69
  self.model = model
70
- self.logger = logging.getLogger(self.__class__.__name__)
70
+ self.logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")
71
71
  self.file = file
72
72
 
73
73
  def run_internal(self, **kwargs) -> TaskReturn:
@@ -0,0 +1,98 @@
1
+ import abc
2
+ from pathlib import Path
3
+ from typing import Type
4
+
5
+
6
+ from etl_lib.core.ClosedLoopBatchProcessor import ClosedLoopBatchProcessor
7
+ from etl_lib.core.ETLContext import ETLContext
8
+ from etl_lib.core.ParallelBatchProcessor import ParallelBatchProcessor
9
+ from etl_lib.core.SplittingBatchProcessor import SplittingBatchProcessor, dict_id_extractor
10
+ from etl_lib.core.Task import Task, TaskReturn
11
+ from etl_lib.core.ValidationBatchProcessor import ValidationBatchProcessor
12
+ from etl_lib.data_sink.CypherBatchSink import CypherBatchSink
13
+ from etl_lib.data_source.CSVBatchSource import CSVBatchSource
14
+ from pydantic import BaseModel
15
+
16
+ class ParallelCSVLoad2Neo4jTask(Task):
17
+ """
18
+ Parallel CSV → Neo4j load using the mix-and-batch strategy.
19
+
20
+ Wires a CSV reader, optional Pydantic validation, a diagonal splitter
21
+ (to avoid overlapping node locks), and a Cypher sink. Rows are
22
+ distributed into (row, col) partitions and processed in non-overlapping groups.
23
+
24
+ Args:
25
+ context: Shared ETL context.
26
+ file: CSV file to load.
27
+ model: Optional Pydantic model for row validation; invalid rows go to `error_file`.
28
+ error_file: Output for invalid rows. Required when `model` is set.
29
+ table_size: Bucketing grid size for the splitter.
30
+ batch_size: Per-cell target batch size from the splitter.
31
+ max_workers: Worker threads per wave.
32
+ prefetch: Number of waves to prefetch from the splitter.
33
+ **csv_reader_kwargs: Forwarded to :py:class:`etl_lib.data_source.CSVBatchSource.CSVBatchSource`.
34
+
35
+ Returns:
36
+ :py:class:`~etl_lib.core.Task.TaskReturn` with merged validation and Neo4j counters.
37
+
38
+ Notes:
39
+ - `_query()` must return Cypher that starts with ``UNWIND $batch AS row``.
40
+ - Override `_id_extractor()` if your CSV schema doesn’t expose ``start``/``end``; the default uses
41
+ :py:func:`etl_lib.core.SplittingBatchProcessor.dict_id_extractor`.
42
+ - See the nyc-taxi example for a working subclass.
43
+ """
44
+ def __init__(self,
45
+ context: ETLContext,
46
+ file: Path,
47
+ model: Type[BaseModel] | None = None,
48
+ error_file: Path | None = None,
49
+ table_size: int = 10,
50
+ batch_size: int = 5000,
51
+ max_workers: int | None = None,
52
+ prefetch: int = 4,
53
+ **csv_reader_kwargs):
54
+ super().__init__(context)
55
+ self.file = file
56
+ self.model = model
57
+ if model is not None and error_file is None:
58
+ raise ValueError('an error_file must be provided when a model is specified')
59
+ self.error_file = error_file
60
+ self.table_size = table_size
61
+ self.batch_size = batch_size
62
+ self.max_workers = max_workers or table_size
63
+ self.prefetch = prefetch
64
+ self.csv_reader_kwargs = csv_reader_kwargs
65
+
66
+ def run_internal(self) -> TaskReturn:
67
+ csv = CSVBatchSource(self.file, self.context, self, **self.csv_reader_kwargs)
68
+ predecessor = csv
69
+ if self.model is not None:
70
+ predecessor = ValidationBatchProcessor(self.context, self, csv, self.model, self.error_file)
71
+
72
+ splitter = SplittingBatchProcessor(
73
+ context=self.context,
74
+ task=self,
75
+ predecessor=predecessor,
76
+ table_size=self.table_size,
77
+ id_extractor=self._id_extractor()
78
+ )
79
+
80
+ parallel = ParallelBatchProcessor(
81
+ context=self.context,
82
+ task=self,
83
+ predecessor=splitter,
84
+ worker_factory=lambda: CypherBatchSink(self.context, self, None, self._query()),
85
+ max_workers=self.max_workers,
86
+ prefetch=self.prefetch
87
+ )
88
+
89
+ closing = ClosedLoopBatchProcessor(self.context, self, parallel)
90
+ result = next(closing.get_batch(self.batch_size))
91
+ return TaskReturn(True, result.statistics)
92
+
93
+ def _id_extractor(self):
94
+ return dict_id_extractor()
95
+
96
+ @abc.abstractmethod
97
+ def _query(self):
98
+ pass
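A hypothetical subclass, loosely in the spirit of the nyc-taxi example the docstring refers to (CSV column names, node label and Cypher are placeholders, not taken from the package):

    from pathlib import Path

    from etl_lib.core.SplittingBatchProcessor import dict_id_extractor


    class LoadTripsTask(ParallelCSVLoad2Neo4jTask):

        def _id_extractor(self):
            # this CSV exposes pickup_zone / dropoff_zone instead of the default start / end keys
            return dict_id_extractor(start_key="pickup_zone", end_key="dropoff_zone")

        def _query(self):
            # must start with "UNWIND $batch AS row", as required by the class docstring
            return (
                "UNWIND $batch AS row "
                "MATCH (s:Zone {id: row.pickup_zone}) "
                "MATCH (e:Zone {id: row.dropoff_zone}) "
                "MERGE (s)-[:TRIP {distance: row.distance}]->(e)"
            )


    # LoadTripsTask(context, file=Path("trips.csv"), table_size=10, batch_size=5000).execute()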