neo4j-etl-lib 0.1.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,268 @@
+ import logging
+ from typing import Any, Callable, Dict, Generator, List, Tuple
+
+ from etl_lib.core.BatchProcessor import BatchProcessor, BatchResults
+ from etl_lib.core.utils import merge_summery
+ from tabulate import tabulate
+
+
+ def tuple_id_extractor(table_size: int = 10) -> Callable[[Tuple[str | int, str | int]], Tuple[int, int]]:
+     """
+     Create an ID extractor function for tuple items, using the last decimal digit of each element.
+     The output is a `(row, col)` tuple within a `table_size x table_size` grid (default 10x10).
+
+     Args:
+         table_size: The dimension of the grid (number of rows/cols). Defaults to 10.
+
+     Returns:
+         A callable that maps a tuple `(a, b)` to a tuple `(row, col)` using the last digit of `a` and `b`.
+     """
+
+     def extractor(item: Tuple[Any, Any]) -> Tuple[int, int]:
+         a, b = item
+         try:
+             row = int(str(a)[-1])
+             col = int(str(b)[-1])
+         except Exception as e:
+             raise ValueError(f"Failed to extract ID from item {item}: {e}")
+         return row, col
+
+     extractor.table_size = table_size
+     return extractor
+
+
+ def dict_id_extractor(
+         table_size: int = 10,
+         start_key: str = "start",
+         end_key: str = "end",
+ ) -> Callable[[Dict[str, Any]], Tuple[int, int]]:
+     """
+     Build an ID extractor for dict rows. The extractor reads two fields (configurable via
+     `start_key` and `end_key`) and returns (row, col) based on the last decimal digit of each.
+     Range validation remains the responsibility of the SplittingBatchProcessor.
+
+     Args:
+         table_size: Informational hint carried on the extractor; used by callers to sanity-check.
+         start_key: Field name for the start node identifier.
+         end_key: Field name for the end node identifier.
+
+     Returns:
+         Callable[[Mapping[str, Any]], tuple[int, int]]: Maps {start_key, end_key} → (row, col).
+     """
+
+     def extractor(item: Dict[str, Any]) -> Tuple[int, int]:
+         missing = [k for k in (start_key, end_key) if k not in item]
+         if missing:
+             raise KeyError(f"Item missing required keys: {', '.join(missing)}")
+         try:
+             row = int(str(item[start_key])[-1])
+             col = int(str(item[end_key])[-1])
+         except Exception as e:
+             raise ValueError(f"Failed to extract ID from item {item}: {e}")
+         return row, col
+
+     extractor.table_size = table_size
+     return extractor
+
+
+ class SplittingBatchProcessor(BatchProcessor):
+     """
+     BatchProcessor that splits incoming BatchResults into non-overlapping partitions based
+     on row/col indices extracted by the id_extractor, and emits full or remaining batches
+     using the mix-and-batch algorithm from https://neo4j.com/blog/developer/mix-and-batch-relationship-load/
+     Each emitted batch is a list of per-cell lists (array of arrays), so callers
+     can process each partition (another name for a cell) in parallel.
+
+     A batch for a schedule group is emitted when all cells in that group have at least `batch_size` items.
+     In addition, when a cell/partition reaches 3x the configured max_batch_size, the group is emitted to avoid
+     overflowing the buffer when the distribution per cell is uneven.
+     Leftovers are flushed after source exhaustion.
+     Emitted batches never exceed the configured max_batch_size.
+     """
+
+     def __init__(self, context, table_size: int, id_extractor: Callable[[Any], Tuple[int, int]],
+                  task=None, predecessor=None):
+         super().__init__(context, task, predecessor)
+
+         # If the extractor carries an expected table size, use or validate it
+         if hasattr(id_extractor, "table_size"):
+             expected_size = id_extractor.table_size
+             if table_size is None:
+                 table_size = expected_size  # determine table size from extractor if not provided
+             elif table_size != expected_size:
+                 raise ValueError(f"Mismatch between provided table_size ({table_size}) and "
+                                  f"id_extractor table_size ({expected_size}).")
+         elif table_size is None:
+             raise ValueError("table_size must be specified if id_extractor has no defined table_size")
+         self.table_size = table_size
+         self._id_extractor = id_extractor
+
+         # Initialize 2D buffer for partitions
+         self.buffer: Dict[int, Dict[int, List[Any]]] = {
+             i: {j: [] for j in range(self.table_size)} for i in range(self.table_size)
+         }
+         self.logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")
+
+     def _generate_batch_schedule(self) -> List[List[Tuple[int, int]]]:
+         """
+         Create diagonal stripes (row, col) partitions to ensure no overlapping IDs
+         across emitted batches.
+         Example grid:
+              ||  0  |  1  |  2
+         =====++=====+=====+=====
+           0  ||  0  |  1  |  2
+         -----++-----+-----+-----
+           1  ||  2  |  0  |  1
+         -----++-----+-----+-----
+           2  ||  1  |  2  |  0
+
+         would return [[(0, 0), (1, 1), (2, 2)], [(0, 1), (1, 2), (2, 0)], [(0, 2), (1, 0), (2, 1)]]
+         """
+         schedule: List[List[Tuple[int, int]]] = []
+         for shift in range(self.table_size):
+             partition: List[Tuple[int, int]] = [
+                 (i, (i + shift) % self.table_size)
+                 for i in range(self.table_size)
+             ]
+             schedule.append(partition)
+         return schedule
+
+     def _flush_group(
+             self,
+             partitions: List[Tuple[int, int]],
+             batch_size: int,
+             statistics: Dict[str, Any] | None = None,
+     ) -> Generator[BatchResults, None, None]:
+         """
+         Extract up to `batch_size` items from each cell in `partitions`, remove them from the buffer,
+         and yield a BatchResults whose chunk is a list of per-cell lists from these partitions.
+
+         Args:
+             partitions: Cell coordinates forming a diagonal group from the schedule.
+             batch_size: Max number of items to take from each cell.
+             statistics: Stats dict to attach to this emission (use {} for interim waves).
+                 The "final" emission will pass the accumulated stats here.
+
+         Notes:
+             - Debug-only: prints a 2D matrix of cell sizes when logger is in DEBUG.
+             - batch_size in the returned BatchResults equals the number of emitted items.
+         """
+         self._log_buffer_matrix(partition=partitions)
+
+         cell_chunks: List[List[Any]] = []
+         for row, col in partitions:
+             q = self.buffer[row][col]
+             take = min(batch_size, len(q))
+             part = q[:take]
+             cell_chunks.append(part)
+             # remove flushed items
+             self.buffer[row][col] = q[take:]
+
+         emitted = sum(len(c) for c in cell_chunks)
+
+         result = BatchResults(
+             chunk=cell_chunks,
+             statistics=statistics or {},
+             batch_size=emitted,
+         )
+         yield result
+
+     def get_batch(self, max_batch__size: int) -> Generator[BatchResults, None, None]:
+         """
+         Consume upstream batches, split data across cells, and emit diagonal partitions:
+         - During streaming: emit a full partition when all its cells have >= max_batch__size.
+         - Also during streaming: if any cell in a partition builds up beyond a 'burst' threshold
+           (3 * max_batch__size), emit that partition early, taking up to max_batch__size
+           from each cell.
+         - After source exhaustion: flush leftovers in waves capped at max_batch__size per cell.
+
+         Statistics policy:
+             * Every emission except the last carries {}.
+             * The last emission carries the accumulated upstream stats (unfiltered).
+         """
+         schedule = self._generate_batch_schedule()
+
+         accum_stats: Dict[str, Any] = {}
+         pending: BatchResults | None = None  # hold back the most recent emission so we know what's final
+
+         burst_threshold = 3 * max_batch__size
+
+         for upstream in self.predecessor.get_batch(max_batch__size):
+             # accumulate upstream stats (unfiltered)
+             if upstream.statistics:
+                 accum_stats = merge_summery(accum_stats, upstream.statistics)
+
+             # add data to cells
+             for item in upstream.chunk:
+                 r, c = self._id_extractor(item)
+                 if not (0 <= r < self.table_size and 0 <= c < self.table_size):
+                     raise ValueError(f"partition id out of range: {(r, c)} for table_size={self.table_size}")
+                 self.buffer[r][c].append(item)
+
+             # process partitions
+             for partition in schedule:
+                 # normal full flush when all cells are ready
+                 if all(len(self.buffer[r][c]) >= max_batch__size for r, c in partition):
+                     br = next(self._flush_group(partition, max_batch__size, statistics={}))
+                     if pending is not None:
+                         yield pending
+                     pending = br
+                     continue
+
+                 # burst flush if any cell backlog explodes
+                 hot_cells = [(r, c, len(self.buffer[r][c])) for r, c in partition if
+                              len(self.buffer[r][c]) >= burst_threshold]
+                 if hot_cells:
+                     top_r, top_c, top_len = max(hot_cells, key=lambda x: x[2])
+                     self.logger.debug(
+                         "burst flush: partition=%s threshold=%d top_cell=(%d,%d len=%d)",
+                         partition, burst_threshold, top_r, top_c, top_len
+                     )
+                     br = next(self._flush_group(partition, max_batch__size, statistics={}))
+                     if pending is not None:
+                         yield pending
+                     pending = br
+
+         # source exhausted: drain leftovers in capped waves (respecting batch size)
+         self.logger.debug("start flushing leftovers")
+         for partition in (p for p in schedule if any(self.buffer[r][c] for r, c in p)):
+             while any(self.buffer[r][c] for r, c in partition):
+                 br = next(self._flush_group(partition, max_batch__size, statistics={}))
+                 if pending is not None:
+                     yield pending
+                 pending = br
+
+         # final emission carries accumulated stats once
+         if pending is not None:
+             yield BatchResults(chunk=pending.chunk, statistics=accum_stats, batch_size=pending.batch_size)
+
+     def _log_buffer_matrix(self, *, partition: List[Tuple[int, int]]) -> None:
+         """
+         Dumps a compact 2D matrix of per-cell sizes (len of each buffer) when DEBUG is enabled.
+         """
+         if not self.logger.isEnabledFor(logging.DEBUG):
+             return
+
+         counts = [
+             [len(self.buffer[r][c]) for c in range(self.table_size)]
+             for r in range(self.table_size)
+         ]
+         marks = set(partition)
+
+         pad = max(2, len(str(self.table_size - 1)))
+         col_headers = [f"c{c:0{pad}d}" for c in range(self.table_size)]
+
+         rows = []
+         for r in range(self.table_size):
+             row_label = f"r{r:0{pad}d}"
+             row_vals = [f"[{v}]" if (r, c) in marks else f"{v}" for c, v in enumerate(counts[r])]
+             rows.append([row_label, *row_vals])
+
+         table = tabulate(
+             rows,
+             headers=["", *col_headers],
+             tablefmt="psql",
+             stralign="right",
+             disable_numparse=True,
+         )
+         self.logger.debug("buffer matrix:\n%s", table)
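Usage sketch (not part of the package): the new SplittingBatchProcessor sits behind any BatchProcessor that yields relationship rows. A minimal sketch, assuming dict rows with `start`/`end` keys; the stub source and the `context=None` placeholder stand in for a real pipeline and ETLContext:

    from etl_lib.core.BatchProcessor import BatchResults

    rows = [{"start": 1234, "end": 5678}, {"start": 1239, "end": 5670}]

    class StubSource:
        # minimal stand-in for an upstream BatchProcessor
        def get_batch(self, max_batch_size):
            yield BatchResults(chunk=rows, statistics={"rows_read": len(rows)}, batch_size=len(rows))

    extractor = dict_id_extractor(table_size=10, start_key="start", end_key="end")
    splitter = SplittingBatchProcessor(context=None, table_size=10, id_extractor=extractor,
                                       predecessor=StubSource())

    for batch in splitter.get_batch(1000):
        # batch.chunk is a list of per-cell lists; within one emission no two cells share
        # a start or end digit, so each cell can be written to Neo4j in parallel.
        for cell in batch.chunk:
            ...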
etl_lib/core/Task.py CHANGED
@@ -7,7 +7,7 @@ from datetime import datetime
  
  class TaskReturn:
      """
-     Return object for the :py:func:`~Task.execute` function, transporting result information.
+     Return object for the :func:`~Task.execute` function, transporting result information.
      """
  
      success: bool
@@ -59,7 +59,7 @@ class Task:
      ETL job that can be executed.
  
      Provides reporting, time tracking and error handling.
-     Implementations must provide the :py:func:`~run_internal` function.
+     Implementations must provide the :func:`~run_internal` function.
      """
  
      def __init__(self, context):
@@ -67,16 +67,17 @@ class Task:
          Construct a Task object.
  
          Args:
-             context: :py:class:`etl_lib.core.ETLContext.ETLContext` instance. Will be available to subclasses.
+             context: :class:`~etl_lib.core.ETLContext.ETLContext` instance. Will be available to subclasses.
          """
          self.context = context
-         self.logger = logging.getLogger(self.__class__.__name__)
+         """:class:`~etl_lib.core.ETLContext.ETLContext` giving access to various resources."""
+         self.logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")
          self.uuid = str(uuid.uuid4())
          """Uniquely identifies a Task."""
          self.start_time: datetime
-         """Time when the :py:func:`~execute` was called., `None` before."""
+         """Time when the :func:`~execute` was called, `None` before."""
          self.end_time: datetime
-         """Time when the :py:func:`~execute` has finished., `None` before."""
+         """Time when the :func:`~execute` has finished, `None` before."""
          self.success: bool
          """True if the task has finished successful. False otherwise, `None` before the task has finished."""
          self.depth: int = 0
@@ -87,8 +88,9 @@ class Task:
          Executes the task.
  
          Implementations of this Interface should not overwrite this method, but provide the
-         Task functionality inside :py:func:`~run_internal` which will be called from here.
-         Will use the :py:class:`ProgressReporter` from the :py:attr:`~context` to report status updates.
+         Task functionality inside :func:`~run_internal` which will be called from here.
+         Will use the :class:`~etl_lib.core.ProgressReporter.ProgressReporter` from
+         :attr:`~etl_lib.core.Task.Task.context` to report status updates.
  
          Args:
              kwargs: will be passed to `run_internal`
@@ -34,6 +34,8 @@ class ValidationBatchProcessor(BatchProcessor):
          Each row in this file will contain the original data together with all validation errors for this row.
          """
          super().__init__(context, task, predecessor)
+         if model is not None and error_file is None:
+             raise ValueError('you must provide error file if the model is specified')
          self.error_file = error_file
          self.model = model
  
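With this new guard, a pydantic model can no longer be configured without an error file. A minimal sketch of the call pattern, assuming the keyword arguments visible in this hunk (`model`, `error_file`); the context, task, upstream source and file path are placeholders:

    from pathlib import Path
    from pydantic import BaseModel

    class Person(BaseModel):
        id: int
        name: str

    # valid: a model together with a file that will receive rejected rows
    validator = ValidationBatchProcessor(context, task, predecessor=csv_source,
                                         model=Person, error_file=Path("people.error.json"))

    # raises ValueError since 0.3.0: a model without an error_file
    ValidationBatchProcessor(context, task, predecessor=csv_source, model=Person, error_file=None)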
etl_lib/core/utils.py CHANGED
@@ -1,4 +1,6 @@
  import logging
+ import os
+ import signal
  
  
  def merge_summery(summery_1: dict, summery_2: dict) -> dict:
@@ -12,17 +14,56 @@ def merge_summery(summery_1: dict, summery_2: dict) -> dict:
  
  def setup_logging(log_file=None):
      """
-     Set up logging to console and optionally to a log file.
-
-     :param log_file: Path to the log file
-     :type log_file: str, optional
+     Set up the logging. INFO is used for the root logger.
+     Via the ETL_LIB_LOG_LEVEL environment variable, the log level of the library itself can be set to another level.
+     It also defaults to INFO.
      """
-     handlers = [logging.StreamHandler()]
+     fmt = '%(asctime)s - %(levelname)s - %(name)s - [%(threadName)s] - %(message)s'
+     formatter = logging.Formatter(fmt)
+
+     root_handlers = [logging.StreamHandler()]
+     if log_file:
+         root_handlers.append(logging.FileHandler(log_file))
+     for h in root_handlers:
+         h.setLevel(logging.INFO)
+         h.setFormatter(formatter)
+     logging.basicConfig(level=logging.INFO, handlers=root_handlers, force=True)
+
+     raw = os.getenv("ETL_LIB_LOG_LEVEL", "INFO")
+     try:
+         etl_level = int(raw) if str(raw).isdigit() else getattr(logging, str(raw).upper())
+     except Exception:
+         etl_level = logging.DEBUG
+
+     etl_logger = logging.getLogger('etl_lib')
+     etl_logger.setLevel(etl_level)
+     etl_logger.propagate = False
+     etl_logger.handlers.clear()
+
+     dbg_console = logging.StreamHandler()
+     dbg_console.setLevel(logging.NOTSET)
+     dbg_console.setFormatter(formatter)
+     etl_logger.addHandler(dbg_console)
+
      if log_file:
-         handlers.append(logging.FileHandler(log_file))
+         dbg_file = logging.FileHandler(log_file)
+         dbg_file.setLevel(logging.NOTSET)
+         dbg_file.setFormatter(formatter)
+         etl_logger.addHandler(dbg_file)
+
+
+ def add_sigint_handler(handler_to_add):
+     """
+     Register handler_to_add(signum, frame) to run on Ctrl-C,
+     chaining any previously registered handler afterward.
+     """
+     old_handler = signal.getsignal(signal.SIGINT)
+
+     def chained_handler(signum, frame):
+         # first, run the new handler
+         handler_to_add(signum, frame)
+         # then, if there was an old handler, call it
+         if callable(old_handler):
+             old_handler(signum, frame)
  
-     logging.basicConfig(
-         level=logging.INFO,
-         format='%(asctime)s - %(levelname)s - %(message)s',
-         handlers=handlers
-     )
+     signal.signal(signal.SIGINT, chained_handler)
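A brief usage sketch of the reworked logging setup and the new SIGINT chaining helper; the log file name and the cleanup callback are placeholders:

    import os
    from etl_lib.core.utils import setup_logging, add_sigint_handler

    # Root logging stays at INFO; the library's own 'etl_lib' logger follows ETL_LIB_LOG_LEVEL.
    # Set the variable before calling setup_logging(); numeric levels such as "10" also work.
    os.environ["ETL_LIB_LOG_LEVEL"] = "DEBUG"
    setup_logging(log_file="etl.log")  # the file handler is optional

    def on_interrupt(signum, frame):
        # placeholder cleanup hook; a previously installed SIGINT handler still runs afterwards
        print("Ctrl-C received, shutting down")

    add_sigint_handler(on_interrupt)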
@@ -10,7 +10,7 @@ class CypherBatchSink(BatchProcessor):
      BatchProcessor to write batches of data to a Neo4j database.
      """
  
-     def __init__(self, context: ETLContext, task: Task, predecessor: BatchProcessor, query: str):
+     def __init__(self, context: ETLContext, task: Task, predecessor: BatchProcessor, query: str, **kwargs):
          """
          Constructs a new CypherBatchSink.
  
@@ -20,16 +20,17 @@ class CypherBatchSink(BatchProcessor):
              predecessor: BatchProcessor which :func:`~get_batch` function will be called to receive batches to process.
              query: Cypher to write the query to Neo4j.
                  Data will be passed as `batch` parameter.
-                 Therefor, the query should start with a `UNWIND $batch AS row`.
+                 Therefore, the query should start with a `UNWIND $batch AS row`.
          """
          super().__init__(context, task, predecessor)
          self.query = query
          self.neo4j = context.neo4j
+         self.kwargs = kwargs
  
      def get_batch(self, batch_size: int) -> Generator[BatchResults, None, None]:
          assert self.predecessor is not None
  
          with self.neo4j.session() as session:
              for batch_result in self.predecessor.get_batch(batch_size):
-                 result = self.neo4j.query_database(session=session, query=self.query, batch=batch_result.chunk)
+                 result = self.neo4j.query_database(session=session, query=self.query, batch=batch_result.chunk, **self.kwargs)
                  yield append_result(batch_result, result.summery)
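The new `**kwargs` pass-through is stored on the sink and forwarded to `query_database` with every batch. A sketch of what that enables, assuming the extra keyword arguments end up as additional query parameters (how `query_database` interprets them is not shown in this diff); `lookup_date` and the context, task and source objects are placeholders:

    query = """
    UNWIND $batch AS row
    MERGE (p:Person {id: row.id})
    SET p.updatedAt = $lookup_date
    """

    sink = CypherBatchSink(context, task, predecessor=source, query=query,
                           lookup_date="2024-01-01")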
@@ -0,0 +1,36 @@
+ from typing import Generator
+ from sqlalchemy import text
+ from etl_lib.core.ETLContext import ETLContext
+ from etl_lib.core.BatchProcessor import BatchProcessor, BatchResults, append_result
+ from etl_lib.core.Task import Task
+
+
+ class SQLBatchSink(BatchProcessor):
+     """
+     BatchProcessor to write batches of data to an SQL database.
+     """
+
+     def __init__(self, context: ETLContext, task: Task, predecessor: BatchProcessor, query: str):
+         """
+         Constructs a new SQLBatchSink.
+
+         Args:
+             context: ETLContext instance.
+             task: Task instance owning this batchProcessor.
+             predecessor: BatchProcessor which `get_batch` function will be called to receive batches to process.
+             query: SQL query to write data.
+                 Data will be passed as a batch using parameterized statements (`:param_name` syntax).
+         """
+         super().__init__(context, task, predecessor)
+         self.query = query
+         self.engine = context.sql.engine
+
+     def get_batch(self, batch_size: int) -> Generator[BatchResults, None, None]:
+         assert self.predecessor is not None
+
+         with self.engine.connect() as conn:
+             with conn.begin():
+                 for batch_result in self.predecessor.get_batch(batch_size):
+                     conn.execute(text(self.query), batch_result.chunk)
+                     yield append_result(batch_result, {"sql_rows_written": len(batch_result.chunk)})
+
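A usage sketch of the new SQLBatchSink; the table, the columns and the PostgreSQL-style upsert are placeholders, as are the context, task and source objects:

    insert_query = """
    INSERT INTO persons (id, name)
    VALUES (:id, :name)
    ON CONFLICT (id) DO UPDATE SET name = :name
    """

    # each upstream chunk is a list of dicts whose keys match the :id / :name parameters
    sink = SQLBatchSink(context, task, predecessor=source, query=insert_query)
    for batch in sink.get_batch(1000):
        print(batch.statistics)  # includes {"sql_rows_written": <rows in this batch>}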
@@ -0,0 +1,114 @@
+ import time
+ from typing import Generator, Callable, Optional, List, Dict
+
+ from psycopg2 import OperationalError as PsycopgOperationalError
+ from sqlalchemy import text
+ from sqlalchemy.exc import OperationalError as SAOperationalError, DBAPIError
+
+ from etl_lib.core.BatchProcessor import BatchResults, BatchProcessor
+ from etl_lib.core.ETLContext import ETLContext
+ from etl_lib.core.Task import Task
+
+
+ class SQLBatchSource(BatchProcessor):
+     def __init__(
+             self,
+             context: ETLContext,
+             task: Task,
+             query: str,
+             record_transformer: Optional[Callable[[dict], dict]] = None,
+             **kwargs
+     ):
+         """
+         Constructs a new SQLBatchSource.
+
+         Args:
+             context: :class:`etl_lib.core.ETLContext.ETLContext` instance.
+             task: :class:`etl_lib.core.Task.Task` instance owning this batchProcessor.
+             query: SQL query to execute.
+             record_transformer: Optional function to transform each row (dict format).
+             kwargs: Arguments passed as parameters with the query.
+         """
+         super().__init__(context, task)
+         self.query = query.strip().rstrip(";")
+         self.record_transformer = record_transformer
+         self.kwargs = kwargs
+
+     def _fetch_page(self, limit: int, offset: int) -> Optional[List[Dict]]:
+         """
+         Fetch a single batch of rows using LIMIT/OFFSET, with retry/backoff.
+
+         Each page is executed in its own transaction. On transient
+         disconnects or DB errors, it retries with exponential backoff for up to 5 attempts.
+
+         Args:
+             limit: maximum number of rows to return.
+             offset: number of rows to skip before starting this page.
+
+         Returns:
+             A list of row dicts (after applying record_transformer, if any),
+             or None if no rows are returned.
+
+         Raises:
+             Exception: re-raises the last caught error on final failure.
+         """
+         paged_sql = f"{self.query} LIMIT :limit OFFSET :offset"
+         params = {**self.kwargs, "limit": limit, "offset": offset}
+         max_retries = 5
+         backoff = 2.0
+
+         for attempt in range(1, max_retries + 1):
+             try:
+                 with self.context.sql.engine.connect() as conn:
+                     with conn.begin():
+                         rows = conn.execute(text(paged_sql), params).mappings().all()
+                         result = [
+                             self.record_transformer(dict(r)) if self.record_transformer else dict(r)
+                             for r in rows
+                         ]
+                         return result if result else None
+
+             except (PsycopgOperationalError, SAOperationalError, DBAPIError) as err:
+
+                 if attempt == max_retries:
+                     self.logger.error(
+                         f"Page fetch failed after {max_retries} attempts "
+                         f"(limit={limit}, offset={offset}): {err}"
+                     )
+                     raise
+
+                 self.logger.warning(
+                     f"Transient DB error on page fetch {attempt}/{max_retries}: {err!r}, "
+                     f"retrying in {backoff:.1f}s"
+                 )
+                 time.sleep(backoff)
+                 backoff *= 2
+
+         return None
+
+     def get_batch(self, max_batch_size: int) -> Generator[BatchResults, None, None]:
+         """
+         Yield successive batches until the query is exhausted.
+
+         Calls _fetch_page() repeatedly, advancing the offset by the
+         number of rows returned. Stops when no more rows are returned.
+
+         Args:
+             max_batch_size: upper limit on rows per batch.
+
+         Yields:
+             BatchResults for each non-empty page.
+         """
+         offset = 0
+         while True:
+             chunk = self._fetch_page(max_batch_size, offset)
+             if not chunk:
+                 break
+
+             yield BatchResults(
+                 chunk=chunk,
+                 statistics={"sql_rows_read": len(chunk)},
+                 batch_size=len(chunk),
+             )
+
+             offset += len(chunk)
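A usage sketch of the new paged SQL source; the query, the `country` bind parameter and the transformer are placeholders, as are the context and task objects:

    def to_relationship(row: dict) -> dict:
        # rename columns so downstream Cypher can address row.start / row.end
        return {"start": row["person_id"], "end": row["company_id"]}

    source = SQLBatchSource(
        context,
        task,
        query="SELECT person_id, company_id FROM employments WHERE country = :country",
        record_transformer=to_relationship,
        country="DE",  # forwarded as a bind parameter alongside :limit / :offset
    )

    for batch in source.get_batch(5000):
        print(batch.statistics)  # {"sql_rows_read": <rows in this page>}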
@@ -4,8 +4,8 @@ from etl_lib.core.Task import Task, TaskReturn
  class CreateReportingConstraintsTask(Task):
      """Creates the constraint in the REPORTER_DATABASE database."""
  
-     def __init__(self, config):
-         super().__init__(config)
+     def __init__(self, context):
+         super().__init__(context)
  
      def run_internal(self, **kwargs) -> TaskReturn:
          database = self.context.env("REPORTER_DATABASE")
@@ -14,24 +14,60 @@ from etl_lib.data_source.CSVBatchSource import CSVBatchSource
  
  
  class CSVLoad2Neo4jTask(Task):
-     """
+     '''
      Loads the specified CSV file to Neo4j.
  
      Uses BatchProcessors to read, validate and write to Neo4j.
      The validation step is using pydantic, hence a Pydantic model needs to be provided.
-     Rows that fail the validation, will be written to en error file. The location of the error file is determined as
+     Rows that fail validation will be written to an error file. The location of the error file is determined as
      follows:
  
-     If the context env vars hold an entry `ETL_ERROR_PATH` the file will be place there, with the name set to name
+     If the context env vars hold an entry `ETL_ERROR_PATH` the file will be placed there, with the name set to the name
      of the provided filename appended with `.error.json`
  
-     If `ETL_ERROR_PATH` is not set, the file will be placed in the same directory as the CSV file.
-     """
+     If `ETL_ERROR_PATH` is not set, the file will be placed in the same directory as the CSV file.
+
+     Example usage: (from the gtfs demo)
+
+     .. code-block:: python
+
+         class LoadStopsTask(CSVLoad2Neo4jTask):
+             class Stop(BaseModel):
+                 id: str = Field(alias="stop_id")
+                 name: str = Field(alias="stop_name")
+                 latitude: float = Field(alias="stop_lat")
+                 longitude: float = Field(alias="stop_lon")
+                 platform_code: Optional[str] = None
+                 parent_station: Optional[str] = None
+                 type: Optional[str] = Field(alias="location_type", default=None)
+                 timezone: Optional[str] = Field(alias="stop_timezone", default=None)
+                 code: Optional[str] = Field(alias="stop_code", default=None)
+
+             def __init__(self, context: ETLContext, file: Path):
+                 super().__init__(context, LoadStopsTask.Stop, file)
+
+             def task_name(self) -> str:
+                 return f"{self.__class__.__name__}('{self.file}')"
+
+             def _query(self):
+                 return """
+                 UNWIND $batch AS row
+                 MERGE (s:Stop {id: row.id})
+                 SET s.name = row.name,
+                     s.location= point({latitude: row.latitude, longitude: row.longitude}),
+                     s.platformCode= row.platform_code,
+                     s.parentStation= row.parent_station,
+                     s.type= row.type,
+                     s.timezone= row.timezone,
+                     s.code= row.code
+                 """
+
+     '''
      def __init__(self, context: ETLContext, model: Type[BaseModel], file: Path, batch_size: int = 5000):
          super().__init__(context)
          self.batch_size = batch_size
          self.model = model
-         self.logger = logging.getLogger(self.__class__.__name__)
+         self.logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")
          self.file = file
  
      def run_internal(self, **kwargs) -> TaskReturn: