neo4j-etl-lib 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
etl_lib/__init__.py ADDED
@@ -0,0 +1,4 @@
+ """
+ Building blocks for ETL pipelines.
+ """
+ __version__ = "0.0.2"
@@ -0,0 +1,189 @@
+ from datetime import datetime
+ 
+ import click
+ import neo4j
+ from neo4j import GraphDatabase
+ from neo4j.time import DateTime
+ from tabulate import tabulate
+ 
+ 
+ def __convert_date_time(input_date_time) -> str | None:
+     if input_date_time is None:
+         return None
+     return input_date_time.to_native().strftime("%Y-%m-%d %H:%M")
+ 
+ 
+ def __duration_from_start_end(start_time: DateTime | None, end_time: DateTime | None) -> str | None:
+     if start_time is None or end_time is None:
+         return None
+ 
+     # Convert neo4j.time.DateTime to native Python datetime
+     start_time = start_time.to_native()
+     end_time = end_time.to_native()
+ 
+     # Calculate the duration as a timedelta
+     duration = end_time - start_time
+ 
+     # Extract hours, minutes, and seconds
+     total_seconds = int(duration.total_seconds())
+     hours = total_seconds // 3600
+     minutes = (total_seconds % 3600) // 60
+     seconds = total_seconds % 60
+ 
+     # Format as H:MM:SS
+     return f"{hours}:{minutes:02}:{seconds:02}"
+ 
+ 
+ def __print_details(driver, database_name, run_id):
+     records, _, _ = driver.execute_query("""
+         MATCH (:ETLRun {uuid : $id})-[:HAS_SUB_TASK*]->(task)-[:HAS_STATS]->(stats)
+         WITH task, stats ORDER BY task.order ASC
+         RETURN task.task AS task, task.status AS status, properties(stats) AS stats
+         """, id=run_id, routing_=neo4j.RoutingControl.READ, database_=database_name)
+ 
+     print("Showing detailed stats for each task. Tasks without non-zero stats are omitted.")
+     for record in records:
+         rows = [(key, value) for key, value in record["stats"].items() if value != 0]
+         if rows:
+             print(f"Showing statistics for Task '{record['task']}' with status '{record['status']}'")
+             print(tabulate(rows, headers=["Name", "Value"], tablefmt='psql'))
+ 
+ 
+ def __driver(ctx):
+     neo4j_uri = ctx.obj["neo4j_uri"]
+     neo4j_user = ctx.obj["neo4j_user"]
+     neo4j_password = ctx.obj["neo4j_password"]
+     # The database is not a driver-level setting; it is selected per query via `database_`.
+     return GraphDatabase.driver(neo4j_uri, auth=(neo4j_user, neo4j_password),
+                                 notifications_min_severity="OFF", user_agent="ETL CLI 0.1")
+ 
+ 
+ @click.group()
+ @click.option('--neo4j-uri', envvar='NEO4J_URI', help='Neo4j database URI')
+ @click.option('--neo4j-user', envvar='NEO4J_USERNAME', help='Neo4j username')
+ @click.option('--neo4j-password', envvar='NEO4J_PASSWORD', help='Neo4j password')
+ @click.option('--log-file', envvar='LOG_FILE', help='Path to the log file', default=None)
+ @click.option('--database-name', envvar='DATABASE_NAME', default='neo4j', help='Neo4j database name (default: neo4j)')
+ @click.pass_context
+ def cli(ctx, neo4j_uri, neo4j_user, neo4j_password, log_file, database_name):
+     """
+     Command-line tool to inspect and manage ETL runs recorded in a Neo4j database.
+ 
+     Environment variables can be configured via a .env file or overridden via CLI options:
+ 
+     \b
+     - NEO4J_URI: Neo4j database URI
+     - NEO4J_USERNAME: Neo4j username
+     - NEO4J_PASSWORD: Neo4j password
+     - LOG_FILE: Path to the log file
+     - DATABASE_NAME: Neo4j database name (default: neo4j)
+     """
+ 
+     # Validate Neo4j connection details
+     if not neo4j_uri or not neo4j_user or not neo4j_password:
+         print(
+             "Neo4j connection details are incomplete. Please provide NEO4J_URI, NEO4J_USERNAME, and NEO4J_PASSWORD.")
+         return
+ 
+     ctx.ensure_object(dict)
+     ctx.obj['neo4j_uri'] = neo4j_uri
+     ctx.obj['neo4j_user'] = neo4j_user
+     ctx.obj['neo4j_password'] = neo4j_password
+     ctx.obj['database_name'] = database_name
+     ctx.obj['log_file'] = log_file
+ 
+ 
+ @cli.command()
+ @click.option("--number-runs", default=10, help="Number of runs to display, defaults to 10", type=int)
+ @click.pass_context
+ def query(ctx, number_runs):
+     """
+     Retrieve the last x ETL runs from the database and display them.
+     """
+     print(f"Listing runs in database '{ctx.obj['database_name']}'")
+     with __driver(ctx) as driver:
+         records, _, _ = driver.execute_query("""
+             MATCH (r:ETLRun:ETLTask)
+             WITH r, r.name AS name, r.uuid AS id, r.startTime AS startTime, r.endTime AS endTime
+             CALL (r) {
+                 MATCH (r)-[:HAS_STATS]->(stats)
+                 WITH [k IN keys(stats) | stats[k]] AS stats
+                 UNWIND stats AS stat
+                 RETURN sum(stat) AS changes
+             }
+             RETURN name, id, startTime, endTime, changes
+             ORDER BY startTime DESC LIMIT $number_runs
+             """, number_runs=number_runs, routing_=neo4j.RoutingControl.READ,
+             database_=ctx.obj["database_name"])
+         data = [
+             {
+                 "name": record["name"], "ID": record["id"],
+                 "startTime": __convert_date_time(record["startTime"]),
+                 "endTime": __convert_date_time(record["endTime"]),
+                 "changes": record["changes"]
+             } for record in records]
+ 
+         print(tabulate(data, headers='keys', tablefmt='psql'))
+ 
+ 
+ @cli.command()
+ @click.argument('run-id', required=True)
+ @click.option("--details", default=False, is_flag=True, help="Show stats for each task", type=bool)
+ @click.pass_context
+ def detail(ctx, run_id, details):
+     """
+     Show a breakdown of the tasks for the specified run, including statistics.
+     """
+     print(f"Showing details for run ID: {run_id}")
+     with __driver(ctx) as driver:
+         records, _, _ = driver.execute_query("""
+             MATCH (r:ETLRun {uuid : $id})-[:HAS_SUB_TASK*]->(task)
+             WITH task ORDER BY task.order ASC
+             CALL (task) {
+                 MATCH (task)-[:HAS_STATS]->(stats)
+                 WITH [k IN keys(stats) | stats[k]] AS stats
+                 UNWIND stats AS stat
+                 RETURN sum(stat) AS changes
+             }
+             RETURN
+                 task.task AS task, task.status AS status,
+                 coalesce(toString(task.batches), '-') + ' / ' + coalesce(toString(task.expected_batches), '-') AS batches,
+                 task.startTime AS startTime, task.endTime AS endTime, changes
+             """, id=run_id, routing_=neo4j.RoutingControl.READ, database_=ctx.obj["database_name"])
+         data = [
+             {
+                 "task": record["task"],
+                 "status": record["status"],
+                 "batches": record["batches"],
+                 "duration": __duration_from_start_end(record["startTime"], record["endTime"]),
+                 "changes": record["changes"]
+             }
+             for record in records
+         ]
+ 
+         print(tabulate(data, headers='keys', tablefmt='psql'))
+         if details:
+             __print_details(driver, ctx.obj['database_name'], run_id)
+ 
+ 
+ @cli.command()
+ @click.option('--run-id', required=False, help='Run ID to delete')
+ @click.option('--since', help='Delete runs since a specific date')
+ @click.option('--older', help='Delete runs older than a specific date')
+ @click.pass_context
+ def delete(ctx, run_id, since, older):
+     """
+     Delete runs based on run ID, date, or age. Exactly one of --run-id, --since, or --older must be provided.
+     """
+     # Ensure mutual exclusivity
+     options = [run_id, since, older]
+     if sum(bool(opt) for opt in options) != 1:
+         print("You must specify exactly one of --run-id, --since, or --older.")
+         return
+ 
+     if run_id:
+         print(f"Deleting run ID: {run_id}")
+     elif since:
+         print(f"Deleting runs since: {since}")
+     elif older:
+         print(f"Deleting runs older than: {older}")
+     # Implement delete logic here
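For a quick local check of the commands above, the group can be driven through click's built-in test runner without installing a console script. A minimal sketch follows; the module path `etl_lib.cli.run_tools` is an assumption (the file name is not shown in this diff), and the connection values are placeholders that need a reachable Neo4j instance:

from click.testing import CliRunner

# NOTE: the import path below is an assumption; adjust to where the CLI module lives.
from etl_lib.cli.run_tools import cli

runner = CliRunner()
result = runner.invoke(cli, [
    "--neo4j-uri", "neo4j://localhost:7687",   # placeholder connection details
    "--neo4j-user", "neo4j",
    "--neo4j-password", "secret",
    "query", "--number-runs", "5",
])
print(result.output)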
@@ -0,0 +1,88 @@
+ import abc
+ import logging
+ import sys
+ from dataclasses import dataclass, field
+ from typing import Generator
+ 
+ from etl_lib.core.ETLContext import ETLContext
+ from etl_lib.core.Task import Task
+ from etl_lib.core.utils import merge_summery
+ 
+ 
+ @dataclass
+ class BatchResults:
+     """
+     Return object of the :py:func:`~BatchProcessor.get_batch` method, wrapping a batch of data together with meta information.
+     """
+     chunk: list
+     """The batch of data."""
+     statistics: dict = field(default_factory=dict)
+     """`dict` of statistic information, such as rows processed, nodes written, ..."""
+     batch_size: int = field(default=sys.maxsize)
+     """Size of the batch."""
+ 
+ 
+ def append_result(org: BatchResults, stats: dict) -> BatchResults:
+     """
+     Appends the stats dict to the provided `org`.
+ 
+     Args:
+         org: The original `BatchResults` object.
+         stats: dict containing statistics to be added to the org object.
+ 
+     Returns:
+         New `BatchResults` object, where the :py:attr:`~BatchResults.statistics` attribute is the merged result of the
+         provided parameters. Values in the dicts with the same key are added.
+     """
+     return BatchResults(chunk=org.chunk, statistics=merge_summery(org.statistics, stats),
+                         batch_size=org.batch_size)
+ 
+ 
+ class BatchProcessor:
+     """
+     Allows assembly of a :py:class:`etl_lib.core.Task.Task` out of smaller building blocks.
+ 
+     This way, functionality such as reading from a CSV file, writing to a database or validation
+     can be implemented and tested independently and re-used.
+ 
+     BatchProcessors form a linked list, where each processor only knows about its predecessor.
+ 
+     BatchProcessors process data in batches. A batch of data is requested from the provided predecessor via
+     :py:func:`~get_batch` and returned in batches to the caller. Usage of `Generators` ensures that not all
+     data must be loaded at once.
+     """
+ 
+     def __init__(self, context: ETLContext, task: Task, predecessor=None):
+         """
+         Constructs a new :py:class:`etl_lib.core.BatchProcessor` instance.
+ 
+         Args:
+             context: :py:class:`etl_lib.core.ETLContext.ETLContext` instance. Will be available to subclasses.
+             task: :py:class:`etl_lib.core.Task.Task` this processor is part of.
+                 Needed for status reporting only.
+             predecessor: Source of batches for this processor.
+                 Can be `None` if no predecessor is needed (such as when this processor is the start of the chain).
+         """
+         self.context = context
+         """:py:class:`etl_lib.core.ETLContext.ETLContext` instance, providing access to general facilities."""
+         self.predecessor = predecessor
+         """Predecessor, used as a source of batches."""
+         self.logger = logging.getLogger(self.__class__.__name__)
+         self.task = task
+         """The :py:class:`etl_lib.core.Task.Task` owning this instance."""
+ 
+     @abc.abstractmethod
+     def get_batch(self, max_batch_size: int) -> Generator[BatchResults, None, None]:
+         """
+         Provides a batch of data to the caller.
+ 
+         The batch itself could be pulled from the provided predecessor and processed, or generated from other sources.
+ 
+         Args:
+             max_batch_size: The max size of the batch the caller expects to receive.
+ 
+         Returns:
+             A generator that yields batches.
+         """
+         pass
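To illustrate the chaining contract sketched in the docstrings, here is a minimal, hypothetical source processor (not part of the package; the name `RangeSource` and the `rows_read` statistic are invented for this sketch) that starts a chain by generating its own batches:

from typing import Generator

from etl_lib.core.BatchProcessor import BatchProcessor, BatchResults


class RangeSource(BatchProcessor):
    """Hypothetical start of a chain: no predecessor, generates its own data."""

    def __init__(self, context, task, total: int):
        super().__init__(context, task, predecessor=None)
        self.total = total

    def get_batch(self, max_batch_size: int) -> Generator[BatchResults, None, None]:
        # Yield the numbers 0..total-1 in chunks of at most max_batch_size.
        for start in range(0, self.total, max_batch_size):
            chunk = list(range(start, min(start + max_batch_size, self.total)))
            yield BatchResults(chunk=chunk, statistics={"rows_read": len(chunk)}, batch_size=len(chunk))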
@@ -0,0 +1,30 @@
+ from typing import Generator
+ 
+ from etl_lib.core.ETLContext import ETLContext
+ from etl_lib.core.BatchProcessor import BatchProcessor, BatchResults, append_result
+ from etl_lib.core.Task import Task
+ 
+ 
+ class ClosedLoopBatchProcessor(BatchProcessor):
+     """
+     Reporting implementation of a BatchProcessor.
+ 
+     Meant to be the last entry in the chain of :py:class:`etl_lib.core.BatchProcessor` instances, driving the
+     processing and reporting updates on the processed batches using the :py:class:`etl_lib.core.ProgressReporter`
+     from the context.
+     """
+ 
+     def __init__(self, context: ETLContext, task: Task, predecessor: BatchProcessor, expected_rows: int | None = None):
+         super().__init__(context, task, predecessor)
+         self.expected_rows = expected_rows
+ 
+     def get_batch(self, max_batch_size: int) -> Generator[BatchResults, None, None]:
+         assert self.predecessor is not None
+         batch_cnt = 0
+         result = BatchResults(chunk=[], statistics={}, batch_size=max_batch_size)
+         for batch in self.predecessor.get_batch(max_batch_size):
+             result = append_result(result, batch.statistics)
+             batch_cnt += 1
+             self.context.reporter.report_progress(self.task, batch_cnt, self.expected_rows, result.statistics)
+ 
+         self.logger.debug(result.statistics)
+         yield result
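Wiring the hypothetical `RangeSource` from the previous sketch into a chain shows how the closed loop drains its predecessor and yields a single aggregated result. This assumes an existing `ETLContext`/`Task` pair named `context` and `task`:

# Sketch only: context, task and RangeSource are assumed from the examples above.
source = RangeSource(context, task, total=1000)
sink = ClosedLoopBatchProcessor(context, task, predecessor=source)
summary = next(sink.get_batch(100))   # drains all 10 batches from the source
print(summary.statistics)             # expected: {"rows_read": 1000}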
@@ -0,0 +1,136 @@
+ import logging
+ from typing import NamedTuple, Any
+ 
+ from graphdatascience import GraphDataScience
+ from neo4j import Driver, GraphDatabase, WRITE_ACCESS, SummaryCounters
+ 
+ from etl_lib.core.ProgressReporter import get_reporter
+ from etl_lib.core.utils import merge_summery
+ 
+ 
+ class QueryResult(NamedTuple):
+     """Result of a query against the neo4j database."""
+     data: list
+     """Data as returned from the query."""
+     summery: dict
+     """Counters as reported by neo4j. Contains entries such as `nodes_created`, `nodes_deleted`, etc."""
+ 
+ 
+ def append_results(r1: QueryResult, r2: QueryResult) -> QueryResult:
+     return QueryResult(r1.data + r2.data, merge_summery(r1.summery, r2.summery))
+ 
+ 
+ class Neo4jContext:
+     uri: str
+     auth: tuple[str, str]
+     driver: Driver
+     database: str
+ 
+     def __init__(self, env_vars: dict):
+         """
+         Create a new Neo4j context.
+         Reads the following env_vars keys:
+         - `NEO4J_URI`,
+         - `NEO4J_USERNAME`,
+         - `NEO4J_PASSWORD`,
+         - `NEO4J_DATABASE`.
+         """
+         self.logger = logging.getLogger(self.__class__.__name__)
+         self.uri = env_vars["NEO4J_URI"]
+         self.auth = (env_vars["NEO4J_USERNAME"],
+                      env_vars["NEO4J_PASSWORD"])
+         self.database = env_vars["NEO4J_DATABASE"]
+         self.__neo4j_connect()
+ 
+     def query_database(self, session, query, **kwargs) -> QueryResult:
+         """
+         Executes a Cypher query (or a list of queries) on the Neo4j database.
+         """
+         if isinstance(query, list):
+             results = QueryResult(data=[], summery={})
+             for single_query in query:
+                 result = self.query_database(session, single_query, **kwargs)
+                 results = append_results(results, result)
+             return results
+         else:
+             try:
+                 res = session.run(query, **kwargs)
+                 # Read the records before consume() exhausts the result stream.
+                 data = res.data()
+                 counters = res.consume().counters
+ 
+                 return QueryResult(data, self.__counters_2_dict(counters))
+ 
+             except Exception as e:
+                 self.logger.error(e)
+                 raise e
+ 
+     @staticmethod
+     def __counters_2_dict(counters: SummaryCounters):
+         return {
+             "constraints_added": counters.constraints_added,
+             "constraints_removed": counters.constraints_removed,
+             "indexes_added": counters.indexes_added,
+             "indexes_removed": counters.indexes_removed,
+             "labels_added": counters.labels_added,
+             "labels_removed": counters.labels_removed,
+             "nodes_created": counters.nodes_created,
+             "nodes_deleted": counters.nodes_deleted,
+             "properties_set": counters.properties_set,
+             "relationships_created": counters.relationships_created,
+             "relationships_deleted": counters.relationships_deleted,
+         }
+ 
+     def session(self, database=None):
+         if database is None:
+             return self.driver.session(database=self.database, default_access_mode=WRITE_ACCESS)
+         else:
+             return self.driver.session(database=database, default_access_mode=WRITE_ACCESS)
+ 
+     def gds(self, database=None) -> GraphDataScience:
+         if database is None:
+             return GraphDataScience.from_neo4j_driver(driver=self.driver, database=self.database)
+         else:
+             return GraphDataScience.from_neo4j_driver(driver=self.driver, database=database)
+ 
+     def __neo4j_connect(self):
+         self.driver = GraphDatabase.driver(uri=self.uri, auth=self.auth,
+                                            notifications_min_severity="OFF")
+         self.driver.verify_connectivity()
+         self.logger.info(
+             f"driver connected to instance at {self.uri} with username {self.auth[0]} and database {self.database}")
+ 
+ 
+ class ETLContext:
+     """
+     General context information.
+ 
+     Will be passed to all :py:class:`etl_lib.core.Task` instances to provide access to environment variables and
+     functionality deemed general enough that all parts of the ETL pipeline would need it.
+     """
+     neo4j: Neo4jContext
+     __env_vars: dict
+ 
+     def __init__(self, env_vars: dict):
+         """
+         Create a new ETLContext.
+ 
+         Args:
+             env_vars: Environment variables. Stored internally and can be accessed via :py:func:`~env`.
+ 
+         The context created will contain a :py:class:`~Neo4jContext` and a :py:class:`ProgressReporter`.
+         See there for keys used from the provided `env_vars` dict.
+         """
+         self.logger = logging.getLogger(self.__class__.__name__)
+         self.neo4j = Neo4jContext(env_vars)
+         self.__env_vars = env_vars
+         self.reporter = get_reporter(self)
+ 
+     def env(self, key: str) -> Any:
+         """
+         Returns the value of an entry in the `env_vars` dict.
+ 
+         Args:
+             key: name of the entry to read.
+ 
+         Returns:
+             Value of the entry, or None if the key is not in the dict.
+         """
+         if key in self.__env_vars:
+             return self.__env_vars[key]
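A minimal sketch of bootstrapping the context from the process environment; the values below are placeholders, and `NEO4J_DATABASE` must be present since `Neo4jContext` reads it unconditionally:

import os

from etl_lib.core.ETLContext import ETLContext

# Placeholders; a reachable Neo4j instance is required.
os.environ.setdefault("NEO4J_URI", "neo4j://localhost:7687")
os.environ.setdefault("NEO4J_USERNAME", "neo4j")
os.environ.setdefault("NEO4J_PASSWORD", "secret")
os.environ.setdefault("NEO4J_DATABASE", "neo4j")

context = ETLContext(env_vars=dict(os.environ))
with context.neo4j.session() as session:
    result = context.neo4j.query_database(session, "RETURN 1 AS n")
print(result.data, result.summery)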
@@ -0,0 +1,210 @@
+ import logging
+ from datetime import datetime
+ 
+ from tabulate import tabulate
+ 
+ from etl_lib.core.Task import Task, TaskGroup
+ 
+ 
+ class ProgressReporter:
+     """
+     Responsible for reporting the progress of :py:class:`etl_lib.core.Task` instances.
+ 
+     This specific implementation uses the python logging module to log progress.
+     Non-error logging uses the INFO level.
+     """
+     start_time: datetime
+     end_time: datetime
+ 
+     def __init__(self, context):
+         self.context = context
+         self.logger = logging.getLogger(self.__class__.__name__)
+ 
+     def register_tasks(self, main: Task):
+         """
+         Registers a :py:class:`etl_lib.core.Task` with this reporter.
+ 
+         Needs to be called once with the root task. The function will walk the tree of tasks and register them in turn.
+ 
+         Args:
+             main: Root of the task tree.
+         """
+         self.logger.info("\n" + self.__print_tree(main))
+ 
+     def started_task(self, task: Task) -> Task:
+         """
+         Marks the task as started.
+ 
+         Starts the time keeping for this task and performs logging.
+ 
+         Args:
+             task: Task to be marked as started.
+ 
+         Returns:
+             The task that was provided.
+         """
+         task.start_time = datetime.now()
+         self.logger.info("\t" * task.depth + f"starting {task.task_name()}")
+         return task
+ 
+     def finished_task(self, task: Task, success: bool, summery: dict, error: str = None) -> Task:
+         """
+         Marks the task as finished.
+ 
+         Stops the time recording for the task and performs logging. Logging will include details from the provided summery.
+ 
+         Args:
+             task: Task to be marked as finished.
+             success: True if the task finished successfully.
+             summery: statistics for this task (such as `nodes_created`).
+             error: If an exception occurred, the exception text should be provided here.
+ 
+         Returns:
+             The task that was provided.
+         """
+         task.end_time = datetime.now()
+         task.success = success
+         task.summery = summery
+ 
+         report = "\t" * task.depth + f"finished {task.task_name()} with success: {success}"
+         if error is not None:
+             report += f", error: \n{error}"
+         else:
+             # for the logger, remove entries with 0, but keep them in the original for reporting
+             cleaned_summery = {key: value for key, value in summery.items() if value != 0}
+             if len(cleaned_summery) > 0:
+                 report += f"\n{tabulate([cleaned_summery], headers='keys', tablefmt='psql')}"
+         self.logger.info(report)
+         return task
+ 
+     def report_progress(self, task: Task, batches: int, expected_batches: int, stats: dict) -> None:
+         """
+         Optionally provides updates during the execution of a task, such as batches processed so far.
+ 
+         This is an optional call, as not all :py:class:`etl_lib.core.Task` need batching.
+ 
+         Args:
+             task: Task reporting updates.
+             batches: Number of batches processed so far.
+             expected_batches: Number of expected batches. Can be `None` if the overall number of
+                 batches is not known before execution.
+             stats: dict of statistics so far (such as `nodes_created`).
+         """
+         pass
+ 
+     def __print_tree(self, task: Task, last=True, header='') -> str:
+         """Generates a tree view of the task tree."""
+         elbow = "└──"
+         pipe = "│  "
+         tee = "├──"
+         blank = "   "
+         tree_string = header + (elbow if last else tee) + task.task_name() + "\n"
+         if isinstance(task, TaskGroup):
+             children = list(task.sub_tasks())
+             for i, c in enumerate(children):
+                 tree_string += self.__print_tree(c, header=header + (blank if last else pipe),
+                                                  last=i == len(children) - 1)
+         return tree_string
+ 
+ 
+ class Neo4jProgressReporter(ProgressReporter):
+     """
+     Extends the ProgressReporter to additionally write the status updates from the tasks to a Neo4j database.
+     """
+ 
+     def __init__(self, context, database: str):
+         """
+         Creates a new Neo4j progress reporter.
+ 
+         Args:
+             context: :py:class:`etl_lib.core.ETLContext` containing a Neo4jContext instance.
+             database: Name of the database to write the status updates to.
+         """
+         super().__init__(context)
+         self.database = database
+         self.logger.info(f"progress reporting to database: {self.database}")
+         self.__create_constraints()
+ 
+     def register_tasks(self, root: Task, **kwargs):
+         super().register_tasks(root)
+ 
+         with self.context.neo4j.session(self.database) as session:
+             order = 0
+             session.run(
+                 "CREATE (t:ETLTask:ETLRun {uuid:$id, task:$task, order:$order, name:$name, status: 'open'}) SET t += $other",
+                 id=root.uuid, order=order, task=root.__repr__(), name=root.task_name(), other=kwargs)
+             self.__persist_task(session, root, order)
+ 
+     def __persist_task(self, session, task: Task | TaskGroup, order: int) -> int:
+         """Writes task information to the database."""
+ 
+         if type(task) is Task:
+             order += 1
+             session.run(
+                 """
+                 MERGE (t:ETLTask { uuid: $id })
+                 SET t.task=$task, t.order=$order, t.name=$name, t.status='open'
+                 """,
+                 id=task.uuid, task=task.__repr__(), order=order, name=task.task_name())
+         else:
+             for child in task.sub_tasks():
+                 order += 1
+                 session.run(
+                     """
+                     MATCH (p:ETLTask { uuid: $parent_id }) SET p.type='TaskGroup'
+                     CREATE (t:ETLTask { uuid:$id, task:$task, order:$order, name:$name, status: 'open' })
+                     CREATE (p)-[:HAS_SUB_TASK]->(t)
+                     """,
+                     parent_id=task.uuid, id=child.uuid, task=child.__repr__(), order=order, name=child.task_name())
+                 if isinstance(child, TaskGroup):
+                     order = self.__persist_task(session, child, order)
+         return order
+ 
+     def started_task(self, task: Task) -> Task:
+         super().started_task(task=task)
+         with self.context.neo4j.session(self.database) as session:
+             session.run("MATCH (t:ETLTask { uuid: $id }) SET t.startTime = $start_time, t.status = 'running'",
+                         id=task.uuid,
+                         start_time=task.start_time)
+         return task
+ 
+     def finished_task(self, task: Task, success: bool, summery: dict, error: str = None) -> Task:
+         super().finished_task(task=task, success=success, summery=summery, error=error)
+         if success:
+             status = "success"
+         else:
+             status = "failure"
+         with self.context.neo4j.session(self.database) as session:
+             session.run("""
+                 MATCH (t:ETLTask {uuid:$id}) SET t.endTime = $end_time, t.status = $status, t.error = $error
+                 CREATE (s:ETLStats) SET s = $summery
+                 CREATE (t)-[:HAS_STATS]->(s)
+                 """, id=task.uuid, end_time=task.end_time, summery=summery, status=status, error=error)
+         return task
+ 
+     def __create_constraints(self):
+         with self.context.neo4j.session(self.database) as session:
+             session.run("CREATE CONSTRAINT etl_task_unique IF NOT EXISTS FOR (n:ETLTask) REQUIRE n.uuid IS UNIQUE;")
+ 
+     def report_progress(self, task: Task, batches: int, expected_batches: int, stats: dict) -> None:
+         self.logger.debug(f"{batches=}, {expected_batches=}, {stats=}")
+         with self.context.neo4j.session(self.database) as session:
+             session.run("MATCH (t:ETLTask {uuid:$id}) SET t.batches = $batches, t.expected_batches = $expected_batches",
+                         id=task.uuid, batches=batches, expected_batches=expected_batches)
+ 
+ 
+ def get_reporter(context) -> ProgressReporter:
+     """
+     Returns a ProgressReporter instance.
+ 
+     If the :py:class:`ETLContext <etl_lib.core.ETLContext>` env holds the key `REPORTER_DATABASE`, then
+     a :py:class:`Neo4jProgressReporter` instance is created with the given database name.
+ 
+     Otherwise, a :py:class:`ProgressReporter` (no logging to database) instance will be created.
+     """
+ 
+     db = context.env("REPORTER_DATABASE")
+     if db is None:
+         return ProgressReporter(context)
+     else:
+         return Neo4jProgressReporter(context, db)
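The selection logic can be exercised directly. A sketch follows; the connection values and the database name `etl_meta` are placeholders, and a reachable Neo4j instance is assumed. With `REPORTER_DATABASE` set, task status is mirrored into Neo4j; without it, progress is only logged:

from etl_lib.core.ETLContext import ETLContext

env_vars = {
    "NEO4J_URI": "neo4j://localhost:7687",   # placeholder connection details
    "NEO4J_USERNAME": "neo4j",
    "NEO4J_PASSWORD": "secret",
    "NEO4J_DATABASE": "neo4j",
    "REPORTER_DATABASE": "etl_meta",          # omit this key for log-only reporting
}

# ETLContext.__init__ already calls get_reporter(self).
context = ETLContext(env_vars)
print(type(context.reporter).__name__)       # -> Neo4jProgressReporter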