neo4j-etl-lib 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
etl_lib/core/Task.py ADDED
@@ -0,0 +1,267 @@
+ import abc
+ import logging
+ import uuid
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from datetime import datetime
+
+
+ class TaskReturn:
+     """
+     Return object for the :py:func:`~Task.execute` function, transporting result information.
+     """
+
+     success: bool
+     """Success or failure of the task."""
+     summery: dict
+     """dict holding statistics about the task performed, such as rows inserted or updated."""
+     error: str
+     """Error message."""
+
+     def __init__(self, success: bool = True, summery: dict = None, error: str = None):
+         self.success = success
+         self.summery = summery if summery else {}
+         self.error = error
+
+     def __repr__(self):
+         return f"TaskReturn({self.success=}, {self.summery=}, {self.error=})"
+
+     def __add__(self, other):
+         """
+         Adds two instances of TaskReturn.
+
+         Args:
+             other: Instance to add.
+
+         Returns:
+             New TaskReturn instance. `success` is the logical AND of both instances.
+             `summery` is the merged dict; values of keys present in both are added.
+         """
+         if not isinstance(other, TaskReturn):
+             return NotImplemented
+
+         # Merge the summery dictionaries by summing their values
+         merged_summery = self.summery.copy()
+         for key, value in other.summery.items():
+             merged_summery[key] = merged_summery.get(key, 0) + value
+
+         # Combine success values and errors
+         combined_success = self.success and other.success
+         combined_error = f"{self.error or ''} | {other.error or ''}".strip(" |")
+
+         return TaskReturn(
+             success=combined_success, summery=merged_summery, error=combined_error
+         )
+
+
+ class Task:
+     """
+     ETL job that can be executed.
+
+     Provides reporting, time tracking and error handling.
+     Implementations must provide the :py:func:`~run_internal` function.
+     """
+
+     def __init__(self, context):
+         """
+         Construct a Task object.
+
+         Args:
+             context: :py:class:`etl_lib.core.ETLContext.ETLContext` instance. Will be available to subclasses.
+         """
+         self.context = context
+         self.logger = logging.getLogger(self.__class__.__name__)
+         self.uuid = str(uuid.uuid4())
+         """Uniquely identifies a Task."""
+         self.start_time: datetime | None = None
+         """Time when :py:func:`~execute` was called; `None` before."""
+         self.end_time: datetime | None = None
+         """Time when :py:func:`~execute` has finished; `None` before."""
+         self.success: bool | None = None
+         """True if the task has finished successfully, False otherwise; `None` before the task has finished."""
+         self.summery: dict | None = None  # TODO: still in use?
+         """Summery statistics about the task performed, such as rows inserted or updated."""
+         self.error: str | None = None  # TODO: still in use?
+         self.depth: int = 0
+         """Level or depth of the task in the hierarchy. The root task is depth 0. Updated by the Reporter."""
+
+     def execute(self, **kwargs) -> TaskReturn:
+         """
+         Executes the task.
+
+         Implementations of this interface should not override this method, but provide the
+         Task functionality inside :py:func:`~run_internal`, which will be called from here.
+         Will use the :py:class:`ProgressReporter` from the :py:attr:`~context` to report status updates.
+
+         Args:
+             kwargs: will be passed to `run_internal`
+         """
+         self.context.reporter.started_task(self)
+
+         try:
+             result = self.run_internal(**kwargs)
+         except Exception as e:
+             result = TaskReturn(success=False, summery={}, error=str(e))
+
+         self.context.reporter.finished_task(
+             task=self,
+             success=result.success,
+             summery=result.summery,
+             error=result.error,
+         )
+
+         return result
+
+     @abc.abstractmethod
+     def run_internal(self, **kwargs) -> TaskReturn:
+         """
+         Place to provide the logic to be performed.
+
+         This base class provides all the housekeeping and reporting, so that implementations
+         do not need to care about them.
+         Exceptions should not be captured by implementations. They are handled by this base class.
+
+         Args:
+             kwargs: arguments passed through from :py:func:`~execute`
+         Returns:
+             An instance of :py:class:`~etl_lib.core.Task.TaskReturn`.
+         """
+         pass
+
+     def abort_on_fail(self) -> bool:
+         """
+         Should the pipeline abort when this task fails?
+
+         Returns:
+             `True` indicates that no other Tasks should be executed if :py:func:`~run_internal` fails.
+         """
+         return True
+
+     def task_name(self) -> str:
+         """
+         Option to override the name of this Task.
+
+         The name is used in reporting only.
+
+         Returns:
+             String describing the task. Defaults to the class name.
+         """
+         return self.__class__.__name__
+
+     def __repr__(self):
+         return f"Task({self.task_name()})"
+
+
+ class TaskGroup(Task):
+     """
+     Base class to wrap Tasks or TaskGroups into a hierarchy of jobs.
+
+     Implementations only need to provide the Tasks to execute as a list.
+     The summery statistics returned from the group's execute method are merged/aggregated over all tasks.
+     """
+
+     def __init__(self, context, tasks: list[Task], name: str):
+         """
+         Construct a TaskGroup object.
+
+         Args:
+             context: :py:class:`etl_lib.core.ETLContext.ETLContext` instance.
+             tasks: a list of :py:class:`etl_lib.core.Task.Task` instances.
+                 These will be executed in the order provided when :py:func:`~run_internal` is called.
+             name: short name of the TaskGroup for reporting.
+         """
+         super().__init__(context)
+         self.tasks = tasks
+         self.name = name
+
+     def sub_tasks(self) -> list[Task]:
+         return self.tasks
+
+     def run_internal(self, **kwargs) -> TaskReturn:
+         ret = TaskReturn()
+         for task in self.tasks:
+             task_ret = task.execute(**kwargs)
+             if not task_ret.success and task.abort_on_fail():
+                 self.logger.warning(
+                     f"Task {task.task_name()} failed. Aborting execution."
+                 )
+                 return task_ret
+             ret = ret + task_ret
+         return ret
+
+     def abort_on_fail(self):
+         for task in self.tasks:
+             if task.abort_on_fail():
+                 return True
+
+     def task_name(self) -> str:
+         return self.name
+
+     def __repr__(self):
+         return f"TaskGroup({self.task_name()})"
+
+
+ class ParallelTaskGroup(TaskGroup):
+     """
+     Task group for parallel execution of jobs.
+
+     This class uses a ThreadPoolExecutor to run the provided tasks' :py:func:`~run_internal` functions in parallel.
+     Care should be taken that the Tasks can operate without blocking or locking each other.
+     """
+
+     def __init__(self, context, tasks: list[Task], name: str):
+         """
+         Construct a ParallelTaskGroup object.
+
+         Args:
+             context: :py:class:`etl_lib.core.ETLContext.ETLContext` instance.
+             tasks: a list of `Task` instances.
+                 These will be executed in parallel when :py:func:`~run_internal` is called.
+                 The entries may themselves be other TaskGroups.
+             name: short name of the TaskGroup.
+         """
+         super().__init__(context, tasks, name)
+
+     def run_internal(self, **kwargs) -> TaskReturn:
+         combined_result = TaskReturn()
+
+         with ThreadPoolExecutor() as executor:
+             future_to_task = {
+                 executor.submit(task.execute, **kwargs): task for task in self.tasks
+             }
+
+             for future in as_completed(future_to_task):
+                 task = future_to_task[future]
+                 try:
+                     result = future.result()
+                     combined_result += result
+
+                     # If a task fails and it has abort_on_fail set, stop further execution
+                     if not result.success and task.abort_on_fail():
+                         self.logger.warning(
+                             f"Task {task.task_name()} failed. Aborting execution of TaskGroup {self.task_name()}."
+                         )
+                         # Cancel any pending tasks
+                         for f in future_to_task:
+                             if not f.done():
+                                 f.cancel()
+                         return combined_result
+
+                 except Exception as e:
+                     self.logger.error(
+                         f"Task {task.task_name()} encountered an error: {str(e)}"
+                     )
+                     error_result = TaskReturn(success=False, summery={}, error=str(e))
+                     combined_result += error_result
+
+                     # Handle abort logic for unexpected exceptions
+                     if task.abort_on_fail():
+                         self.logger.warning(
+                             f"Unexpected failure in {task.task_name()}. Aborting execution of TaskGroup {self.task_name()}."
+                         )
+                         # Cancel any pending tasks
+                         for f in future_to_task:
+                             if not f.done():
+                                 f.cancel()
+                         return combined_result
+
+         return combined_result
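Usage sketch (not part of the package): a minimal Task subclass composed into a TaskGroup. DummyReporter and DummyContext below are hypothetical stand-ins for etl_lib.core.ETLContext.ETLContext and its ProgressReporter, which are not shown in this diff.

from etl_lib.core.Task import Task, TaskGroup, TaskReturn

class DummyReporter:
    # Stand-in reporter: accepts the same calls Task.execute makes.
    def started_task(self, task):
        print(f"started {task.task_name()}")

    def finished_task(self, task, success, summery, error):
        print(f"finished {task.task_name()}: {success=} {summery=}")

class DummyContext:
    reporter = DummyReporter()

class CountRowsTask(Task):
    def run_internal(self, **kwargs) -> TaskReturn:
        return TaskReturn(success=True, summery={"rows_processed": 10})

context = DummyContext()
group = TaskGroup(context, [CountRowsTask(context), CountRowsTask(context)], name="demo-group")
print(group.execute())  # summery values of both tasks are added: {'rows_processed': 20}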
@@ -0,0 +1,74 @@
+ import json
+ from pathlib import Path
+ from typing import Type, Generator
+
+ from pydantic import BaseModel, ValidationError
+
+ from etl_lib.core.BatchProcessor import BatchProcessor, BatchResults
+ from etl_lib.core.ETLContext import ETLContext
+ from etl_lib.core.Task import Task
+ from etl_lib.core.utils import merge_summery
+
+
+ class ValidationBatchProcessor(BatchProcessor):
+     """
+     Batch processor for validation, using Pydantic.
+     """
+
+     def __init__(self, context: ETLContext, task: Task, predecessor, model: Type[BaseModel], error_file: Path):
+         """
+         Constructs a new ValidationBatchProcessor.
+
+         The :py:class:`etl_lib.core.BatchProcessor.BatchResults` returned from the :py:func:`~get_batch` of this
+         implementation will contain the following additional entries:
+
+         - `valid_rows`: Number of valid rows.
+         - `invalid_rows`: Number of invalid rows.
+
+         Args:
+             context: :py:class:`etl_lib.core.ETLContext.ETLContext` instance.
+             task: :py:class:`etl_lib.core.Task.Task` instance owning this BatchProcessor.
+             predecessor: BatchProcessor whose :py:func:`~get_batch` function will be called to receive batches to process.
+             model: Pydantic model class used to validate each row in the batch.
+             error_file: Path to the file that will receive each row that did not pass validation.
+                 Each row in this file will contain the original data together with all validation errors for this row.
+         """
+         super().__init__(context, task, predecessor)
+         self.error_file = error_file
+         self.model = model
+
+     def get_batch(self, max_batch__size: int) -> Generator[BatchResults, None, None]:
+         assert self.predecessor is not None
+
+         for batch in self.predecessor.get_batch(max_batch__size):
+             valid_rows = []
+             invalid_rows = []
+
+             for row in batch.chunk:
+                 try:
+                     # Validate and transform the row
+                     validated_row = self.model(**row).model_dump()
+                     valid_rows.append(validated_row)
+                 except ValidationError as e:
+                     # Collect invalid rows with errors
+                     invalid_rows.append({"row": row, "errors": e.errors()})
+
+             # Write invalid rows to the error file
+             if invalid_rows:
+                 with open(self.error_file, "a") as f:
+                     for invalid in invalid_rows:
+                         # the following is needed as ValueError (contained in 'ctx') is not JSON serializable
+                         serializable = {"row": invalid["row"],
+                                         "errors": [{k: v for k, v in e.items() if k != "ctx"} for e in
+                                                    invalid["errors"]]}
+                         f.write(f"{json.dumps(serializable)}\n")
+
+             # Yield BatchResults with statistics
+             yield BatchResults(
+                 chunk=valid_rows,
+                 statistics=merge_summery(batch.statistics, {
+                     "valid_rows": len(valid_rows),
+                     "invalid_rows": len(invalid_rows)
+                 }),
+                 batch_size=len(batch.chunk)
+             )
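To illustrate the per-row validation step above: the processor calls model(**row).model_dump() for every row and collects pydantic ValidationErrors. A minimal, self-contained sketch with a hypothetical Person model (assuming Pydantic v2, which provides model_dump):

from pydantic import BaseModel, ValidationError

class Person(BaseModel):
    name: str
    age: int

print(Person(name="Ada", age="36").model_dump())  # {'name': 'Ada', 'age': 36}, "36" is coerced to int
try:
    Person(name="Bob", age="not-a-number")
except ValidationError as e:
    print(e.errors()[0]["type"])                   # e.g. 'int_parsing'; the 'ctx' entry is what gets stripped above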
etl_lib/core/utils.py ADDED
@@ -0,0 +1,7 @@
+ def merge_summery(summery_1: dict, summery_2: dict) -> dict:
+     """
+     Helper function to merge dicts, assuming that the values are numbers.
+     If a key exists in both dicts, the result will contain that key with the two values added.
+     """
+     return {i: summery_1.get(i, 0) + summery_2.get(i, 0)
+             for i in set(summery_1).union(summery_2)}
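For example, merging statistics from two batches:

from etl_lib.core.utils import merge_summery

print(merge_summery({"rows": 10, "nodes_created": 3}, {"rows": 5}))
# {'rows': 15, 'nodes_created': 3} (key order may vary)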
@@ -0,0 +1,35 @@
+ from typing import Generator
+
+ from etl_lib.core.ETLContext import ETLContext
+ from etl_lib.core.BatchProcessor import BatchProcessor, BatchResults, append_result
+ from etl_lib.core.Task import Task
+
+
+ class CypherBatchProcessor(BatchProcessor):
+     """
+     BatchProcessor to write batches of data to a Neo4j database.
+     """
+
+     def __init__(self, context: ETLContext, task: Task, predecessor: BatchProcessor, query: str):
+         """
+         Constructs a new CypherBatchProcessor.
+
+         Args:
+             context: :py:class:`etl_lib.core.ETLContext.ETLContext` instance.
+             task: :py:class:`etl_lib.core.Task.Task` instance owning this BatchProcessor.
+             predecessor: BatchProcessor whose :py:func:`~get_batch` function will be called to receive batches to process.
+             query: Cypher query used to write each batch to Neo4j.
+                 Data will be passed as the `batch` parameter.
+                 Therefore, the query should start with `UNWIND $batch AS row`.
+         """
+         super().__init__(context, task, predecessor)
+         self.query = query
+         self.neo4j = context.neo4j
+
+     def get_batch(self, batch_size: int) -> Generator[BatchResults, None, None]:
+         assert self.predecessor is not None
+
+         with self.neo4j.session() as session:
+             for batch_result in self.predecessor.get_batch(batch_size):
+                 result = self.neo4j.query_database(session=session, query=self.query, batch=batch_result.chunk)
+                 yield append_result(batch_result, result.summery)
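A hypothetical query of the expected shape: the batch is passed as the $batch parameter, so the statement unwinds it row by row (the label and properties below are made up for illustration).

query = """
UNWIND $batch AS row
MERGE (p:Person {id: row.id})
SET p.name = row.name
"""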
@@ -0,0 +1,90 @@
+ import csv
+ import gzip
+ from pathlib import Path
+ from typing import Generator
+
+ from etl_lib.core.BatchProcessor import BatchProcessor, BatchResults
+ from etl_lib.core.ETLContext import ETLContext
+ from etl_lib.core.Task import Task
+
+
+ class CSVBatchProcessor(BatchProcessor):
+     """
+     BatchProcessor that reads a CSV file using the `csv` package.
+
+     The file can optionally be gzipped.
+     The returned batch of rows will have an additional `_row` column, containing the source row number of the data,
+     starting with 0.
+     """
+
+     def __init__(self, csv_file: Path, context: ETLContext, task: Task, **kwargs):
+         """
+         Constructs a new CSVBatchProcessor.
+
+         Args:
+             csv_file: Path to the CSV file.
+             context: :py:class:`etl_lib.core.ETLContext.ETLContext` instance.
+             kwargs: Will be passed on to the `csv.DictReader`, providing a way to customise the reading for different
+                 CSV formats.
+         """
+         super().__init__(context, task)
+         self.csv_file = csv_file
+         self.kwargs = kwargs
+
+     def get_batch(self, max_batch__size: int) -> Generator[BatchResults, None, None]:
+         for batch_size, chunks_ in self.read_csv(self.csv_file, batch_size=max_batch__size, **self.kwargs):
+             yield BatchResults(chunk=chunks_, statistics={"csv_lines_read": batch_size}, batch_size=batch_size)
+
+     def read_csv(self, file: Path, batch_size: int, **kwargs):
+         if file.suffix == ".gz":
+             with gzip.open(file, "rt", encoding='utf-8-sig') as f:
+                 yield from self.__parse_csv(batch_size, file=f, **kwargs)
+         else:
+             with open(file, "rt", encoding='utf-8-sig') as f:
+                 yield from self.__parse_csv(batch_size, file=f, **kwargs)
+
+     def __parse_csv(self, batch_size, file, **kwargs):
+         csv_file = csv.DictReader(file, **kwargs)
+         yield from self.__split_to_batches(csv_file, batch_size)
+
+     def __split_to_batches(self, source, batch_size):
+         """
+         Splits the provided source into batches.
+
+         Args:
+             source: Anything that can be looped over; ideally this should be a generator.
+             batch_size: desired batch size
+
+         Returns:
+             Generator object to loop over the batches. Each batch is a list.
+         """
+         cnt = 0
+         batch_ = []
+         for i in source:
+             i["_row"] = cnt
+             cnt += 1
+             batch_.append(self.__clean_dict(i))
+             if len(batch_) == batch_size:
+                 yield len(batch_), batch_
+                 batch_ = []
+         if len(batch_) > 0:
+             yield len(batch_), batch_
+
+     def __clean_dict(self, input_dict):
+         """
+         Needed in Python versions < 3.13.
+         Cleans the dictionary as follows:
+         - Values that are empty (or whitespace-only) strings are replaced with None.
+         - Entries whose key is None are removed.
+
+         Args:
+             input_dict (dict): The dictionary to clean.
+
+         Returns:
+             dict: A cleaned dictionary.
+         """
+         return {
+             k: (None if isinstance(v, str) and v.strip() == "" else v)
+             for k, v in input_dict.items()
+             if k is not None
+         }
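A standalone illustration (not using etl_lib) of the same batching idea: read the rows with csv.DictReader, tag each with its source index as `_row`, and yield lists of at most batch_size rows.

import csv
import io

def split_to_batches(rows, batch_size):
    # Collect rows into fixed-size batches, numbering each row from 0.
    batch = []
    for idx, row in enumerate(rows):
        row["_row"] = idx
        batch.append(row)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch

data = io.StringIO("name,age\nAda,36\nBob,41\nEve,29\n")
for batch in split_to_batches(csv.DictReader(data), batch_size=2):
    print(batch)
# [{'name': 'Ada', 'age': '36', '_row': 0}, {'name': 'Bob', 'age': '41', '_row': 1}]
# [{'name': 'Eve', 'age': '29', '_row': 2}]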
@@ -0,0 +1,29 @@
+ import abc
+
+ from etl_lib.core.ETLContext import ETLContext
+ from etl_lib.core.Task import Task, TaskReturn
+ from etl_lib.core.utils import merge_summery
+
+
+ class ExecuteCypherTask(Task):
+
+     def __init__(self, context: ETLContext):
+         super().__init__(context)
+         self.context = context
+
+     def run_internal(self, **kwargs) -> TaskReturn:
+         with self.context.neo4j.session() as session:
+
+             if isinstance(self._query(), list):
+                 stats = {}
+                 for query in self._query():
+                     result = self.context.neo4j.query_database(session=session, query=query, **kwargs)
+                     stats = merge_summery(stats, result.summery)
+                 return TaskReturn(True, stats)
+             else:
+                 result = self.context.neo4j.query_database(session=session, query=self._query(), **kwargs)
+                 return TaskReturn(True, result.summery)
+
+     @abc.abstractmethod
+     def _query(self) -> str | list[str]:
+         pass
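A hypothetical subclass: _query may return a single Cypher string or a list of statements, which are run in order with their counters merged into one summery dict. The import path is an assumption; adjust it to wherever ExecuteCypherTask lives in the package.

from etl_lib.core.ExecuteCypherTask import ExecuteCypherTask  # import path assumed

class CreateIndexesTask(ExecuteCypherTask):
    def _query(self) -> list[str]:
        # Two statements, executed one after the other in the same session.
        return [
            "CREATE INDEX person_id IF NOT EXISTS FOR (p:Person) ON (p.id)",
            "CREATE INDEX city_name IF NOT EXISTS FOR (c:City) ON (c.name)",
        ]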
@@ -0,0 +1,44 @@
+ from etl_lib.core.Task import Task, TaskReturn
+
+ def transform_dict(input_dict):
+     """
+     Recursively transforms the input dictionary, converting any list values to string representations.
+
+     Helpful to transform a GDS call result into a storable representation.
+     :param input_dict: The input dictionary with values that can be of any type.
+
+     Returns:
+         dict: A new dictionary with transformed values.
+     """
+     def transform_value(value):
+         if isinstance(value, dict):
+             return {k: transform_value(v) for k, v in value.items()}
+         elif isinstance(value, list):
+             return str(value)
+         else:
+             return value
+
+     return {key: transform_value(value) for key, value in input_dict.items()}
+
+
+ class GDSTask(Task):
+
+     def __init__(self, context, func):
+         """
+         Task that runs a function using the GDS client to perform its work. See the following example:
+
+             def gds_fun(etl_context):
+                 with etl_context.neo4j.gds() as gds:
+                     gds.graph.drop("neo4j-offices", failIfMissing=False)
+                     g_office, project_result = gds.graph.project("neo4j-offices", "City", "FLY_TO")
+                     mutate_result = gds.pageRank.mutate(g_office, tolerance=0.5, mutateProperty="rank")
+                     return TaskReturn(success=True, summery=transform_dict(mutate_result.to_dict()))
+
+         :param context: The ETLContext to use. Provides the GDS client to `func` via `etl_context.neo4j.gds()`.
+         :param func: A function that expects a parameter `etl_context` and returns a `TaskReturn` object.
+         """
+         super().__init__(context)
+         self.func = func
+
+     def run_internal(self, **kwargs) -> TaskReturn:
+         return self.func(etl_context=self.context, **kwargs)
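For example, transform_dict keeps scalar values, walks nested dicts recursively, and turns lists into strings, so a GDS result can be stored as summery statistics (the result dict below is made up):

# assuming transform_dict from the module above is in scope
result = {"ranIterations": 20, "configuration": {"tolerance": 0.5, "relationshipTypes": ["FLY_TO"]}}
print(transform_dict(result))
# {'ranIterations': 20, 'configuration': {'tolerance': 0.5, 'relationshipTypes': "['FLY_TO']"}}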
@@ -0,0 +1,41 @@
+ import abc
+ import logging
+ from pathlib import Path
+ from typing import Type
+
+ from pydantic import BaseModel
+
+ from etl_lib.core.ETLContext import ETLContext
+ from etl_lib.core.ClosedLoopBatchProcessor import ClosedLoopBatchProcessor
+ from etl_lib.core.Task import Task, TaskReturn
+ from etl_lib.core.ValidationBatchProcessor import ValidationBatchProcessor
+ from etl_lib.data_sink.CypherBatchProcessor import CypherBatchProcessor
+ from etl_lib.data_source.CSVBatchProcessor import CSVBatchProcessor
+
+
+ class CSVLoad2Neo4jTasks(Task):
+
+     def __init__(self, context: ETLContext, model: Type[BaseModel], file: Path, batch_size: int = 5000):
+         super().__init__(context)
+         self.batch_size = batch_size
+         self.model = model
+         self.logger = logging.getLogger(self.__class__.__name__)
+         self.file = file
+
+     def run_internal(self, **kwargs) -> TaskReturn:
+         error_file = self.file.with_suffix(".error.json")
+
+         csv = CSVBatchProcessor(self.file, self.context, self)
+         validator = ValidationBatchProcessor(self.context, self, csv, self.model, error_file)
+         cypher = CypherBatchProcessor(self.context, self, validator, self._query())
+         end = ClosedLoopBatchProcessor(self.context, self, cypher)
+         result = next(end.get_batch(self.batch_size))
+
+         return TaskReturn(True, result.statistics)
+
+     def __repr__(self):
+         return f"{self.__class__.__name__}({self.file})"
+
+     @abc.abstractmethod
+     def _query(self):
+         pass
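A hypothetical subclass wiring a CSV file through validation into a Cypher MERGE. The Person model, the file name, and the import path of CSVLoad2Neo4jTasks are assumptions for illustration.

from pathlib import Path

from pydantic import BaseModel

from etl_lib.core.CSVLoad2Neo4jTasks import CSVLoad2Neo4jTasks  # import path assumed

class Person(BaseModel):
    id: int
    name: str

class LoadPersonsTask(CSVLoad2Neo4jTasks):
    def __init__(self, context):
        # Rows from persons.csv are validated against Person before being written.
        super().__init__(context, Person, Path("persons.csv"))

    def _query(self):
        # Each validated batch is passed as $batch to this statement.
        return """
        UNWIND $batch AS row
        MERGE (p:Person {id: row.id})
        SET p.name = row.name
        """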