neo4j-etl-lib 0.1.0__tar.gz → 0.2.0__tar.gz

This diff shows the changes between two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
Files changed (32)
  1. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/PKG-INFO +13 -2
  2. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/README.md +7 -1
  3. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/pyproject.toml +4 -2
  4. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/__init__.py +1 -1
  5. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/cli/run_tools.py +45 -13
  6. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/core/ClosedLoopBatchProcessor.py +8 -2
  7. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/core/ETLContext.py +50 -18
  8. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/core/ProgressReporter.py +14 -16
  9. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/core/Task.py +3 -7
  10. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/data_sink/CypherBatchSink.py +4 -3
  11. neo4j_etl_lib-0.2.0/src/etl_lib/data_sink/SQLBatchSink.py +36 -0
  12. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/data_source/CypherBatchSource.py +18 -3
  13. neo4j_etl_lib-0.2.0/src/etl_lib/data_source/SQLBatchSource.py +60 -0
  14. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/task/CreateReportingConstraintsTask.py +2 -2
  15. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/task/ExecuteCypherTask.py +2 -2
  16. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/task/data_loading/CSVLoad2Neo4jTask.py +41 -5
  17. neo4j_etl_lib-0.2.0/src/etl_lib/task/data_loading/SQLLoad2Neo4jTask.py +90 -0
  18. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/test_utils/utils.py +20 -2
  19. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/LICENSE +0 -0
  20. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/cli/__init__.py +0 -0
  21. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/core/BatchProcessor.py +0 -0
  22. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/core/ValidationBatchProcessor.py +0 -0
  23. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/core/__init__.py +0 -0
  24. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/core/utils.py +0 -0
  25. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/data_sink/CSVBatchSink.py +0 -0
  26. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/data_sink/__init__.py +0 -0
  27. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/data_source/CSVBatchSource.py +0 -0
  28. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/data_source/__init__.py +0 -0
  29. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/task/GDSTask.py +0 -0
  30. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/task/__init__.py +0 -0
  31. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/task/data_loading/__init__.py +0 -0
  32. {neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/test_utils/__init__.py +0 -0
{neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: neo4j-etl-lib
- Version: 0.1.0
+ Version: 0.2.0
  Summary: Building blocks for ETL pipelines.
  Keywords: etl,graph,database
  Author-email: Bert Radke <bert.radke@pm.me>
@@ -19,6 +19,7 @@ Requires-Dist: neo4j>=5.27.0; python_version >= '3.7'
  Requires-Dist: python-dotenv>=1.0.1; python_version >= '3.8'
  Requires-Dist: tabulate>=0.9.0; python_version >= '3.7'
  Requires-Dist: click>=8.1.8; python_version >= '3.7'
+ Requires-Dist: pydantic[email-validator]
  Requires-Dist: pytest>=8.3.0 ; extra == "dev" and ( python_version >= '3.8')
  Requires-Dist: testcontainers[neo4j]==4.9.0 ; extra == "dev" and ( python_version >= '3.9' and python_version < '4.0')
  Requires-Dist: pytest-cov ; extra == "dev"
@@ -31,11 +32,15 @@ Requires-Dist: pydata-sphinx-theme ; extra == "dev"
  Requires-Dist: sphinx-autodoc-typehints ; extra == "dev"
  Requires-Dist: sphinxcontrib-napoleon ; extra == "dev"
  Requires-Dist: sphinx-autoapi ; extra == "dev"
+ Requires-Dist: sqlalchemy ; extra == "dev"
+ Requires-Dist: psycopg2-binary ; extra == "dev"
  Requires-Dist: graphdatascience>=1.13 ; extra == "gds" and ( python_version >= '3.9')
+ Requires-Dist: sqlalchemy ; extra == "sql"
  Project-URL: Documentation, https://neo-technology-field.github.io/python-etl-lib/index.html
  Project-URL: Home, https://github.com/neo-technology-field/python-etl-lib
  Provides-Extra: dev
  Provides-Extra: gds
+ Provides-Extra: sql

  # Neo4j ETL Toolbox

@@ -43,7 +48,13 @@ A Python library of building blocks to assemble etl pipelines.

  Complete documentation can be found on https://neo-technology-field.github.io/python-etl-lib/index.html

- See https://github.com/neo-technology-field/python-etl-lib/tree/main/examples/gtfs for an example project.
+ See https://github.com/neo-technology-field/python-etl-lib/tree/main/examples/gtfs
+
+ or
+
+ https://github.com/neo-technology-field/python-etl-lib/tree/main/examples/musicbrainz
+
+ for example projects.


  The library can be installed via
{neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/README.md

@@ -4,7 +4,13 @@ A Python library of building blocks to assemble etl pipelines.

  Complete documentation can be found on https://neo-technology-field.github.io/python-etl-lib/index.html

- See https://github.com/neo-technology-field/python-etl-lib/tree/main/examples/gtfs for an example project.
+ See https://github.com/neo-technology-field/python-etl-lib/tree/main/examples/gtfs
+
+ or
+
+ https://github.com/neo-technology-field/python-etl-lib/tree/main/examples/musicbrainz
+
+ for example projects.


  The library can be installed via
{neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/pyproject.toml

@@ -26,7 +26,8 @@ dependencies = [
  "neo4j>=5.27.0; python_version >= '3.7'",
  "python-dotenv>=1.0.1; python_version >= '3.8'",
  "tabulate>=0.9.0; python_version >= '3.7'",
- "click>=8.1.8; python_version >= '3.7'"
+ "click>=8.1.8; python_version >= '3.7'",
+ "pydantic[email_validator]"
  ]

  [project.optional-dependencies]
@@ -35,9 +36,10 @@ dev = [
  "testcontainers[neo4j]==4.9.0; python_version >= '3.9' and python_version < '4.0'",
  "pytest-cov", "bumpver", "isort", "pip-tools",
  "sphinx", "sphinx-rtd-theme", "pydata-sphinx-theme", "sphinx-autodoc-typehints",
- "sphinxcontrib-napoleon", "sphinx-autoapi"
+ "sphinxcontrib-napoleon", "sphinx-autoapi", "sqlalchemy", "psycopg2-binary"
  ]
  gds = ["graphdatascience>=1.13; python_version >= '3.9'"]
+ sql = ["sqlalchemy"]

  [project.urls]
  Home = "https://github.com/neo-technology-field/python-etl-lib"
{neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/__init__.py

@@ -1,4 +1,4 @@
  """
  Building blocks for ETL pipelines.
  """
- __version__ = "0.1.0"
+ __version__ = "0.2.0"
{neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/cli/run_tools.py

@@ -55,7 +55,7 @@ def __driver(ctx):
  database_name = ctx.obj["database_name"]
  neo4j_password = ctx.obj["neo4j_password"]
  return GraphDatabase.driver(neo4j_uri, auth=(neo4j_user, neo4j_password), database=database_name,
- notifications_min_severity="OFF", user_agent="ETL CLI 0.1")
+ notifications_min_severity="OFF", user_agent="ETL CLI")


  @click.group()
@@ -67,7 +67,7 @@ def __driver(ctx):
  @click.pass_context
  def cli(ctx, neo4j_uri, neo4j_user, neo4j_password, log_file, database_name):
  """
- Command-line tool to process files in INPUT_DIRECTORY.
+ Command-line tool for ETL pipelines.

  Environment variables can be configured via a .env file or overridden via CLI options:

@@ -165,25 +165,57 @@ def detail(ctx, run_id, details):
  __print_details(driver, run_id)


+ # noinspection PyTypeChecker
  @cli.command()
- @click.option('--run-id', required=False, help='Run ID to delete')
- @click.option('--since', help='Delete runs since a specific date')
- @click.option('--older', help='Delete runs older than a specific date')
+ @click.option('--run-id', required=False, type=str, help='Run IDs to delete, works with comma separated list')
+ @click.option('--before', type=click.DateTime(formats=["%Y-%m-%d"]), help='Delete runs before a specific date in format YYYY-MM-DD')
+ @click.option('--older', help='Delete runs older than x days', type=int)
  @click.pass_context
- def delete(ctx, run_id, since, older):
+ def delete(ctx, run_id, before, older):
  """
- Delete runs based on run ID, date, or age. One and only one of --run-id, --since, or --older must be provided.
+ Delete runs based on run ID, date, or age. One and only one of --run-id, --before, or --older must be provided.
  """
  # Ensure mutual exclusivity
- options = [run_id, since, older]
+ options = [run_id, before, older]
  if sum(bool(opt) for opt in options) != 1:
- print("You must specify exactly one of --run-id, --since, or --older.")
+ print("You must specify exactly one of --run-id, --before, or --older.")
  return

  if run_id:
- print(f"Deleting run ID: {run_id}")
- elif since:
- print(f"Deleting runs since: {since}")
+ ids = run_id.split(',')
+ delete_runs(ctx, ids)
+ elif before:
+ print(f"Deleting runs before: {before}")
+ with __driver(ctx) as driver:
+ record= driver.execute_query(
+ """MATCH (r:ETLRun) WHERE date(r.startTime) < date($before)
+ RETURN collect(r.uuid) AS ids
+ """,
+ result_transformer_=neo4j.Result.single,
+ before=before)
+ ids = record[0]
+ delete_runs(ctx, ids)
+
  elif older:
  print(f"Deleting runs older than: {older}")
- # Implement delete logic here
+ with __driver(ctx) as driver:
+ record = driver.execute_query(
+ """MATCH (r:ETLRun) WHERE date(r.startTime) < (date() - duration({days: $days}))
+ RETURN collect(r.uuid) AS ids
+ """,
+ result_transformer_=neo4j.Result.single,
+ days=older)
+ ids = record[0]
+ delete_runs(ctx, ids)
+
+
+ def delete_runs(ctx, ids):
+ print(f"Deleting run IDs: {ids}")
+ with __driver(ctx) as driver:
+ records, _, _ = driver.execute_query(
+ """
+ MATCH (r:ETLRun)-[*]->(n) WHERE r.uuid IN $ids
+ DETACH DELETE n
+ DETACH DELETE r
+ """, ids=ids, routing_=neo4j.RoutingControl.WRITE)
+ print(f"Deleted run IDs: {ids} successfully")
{neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/core/ClosedLoopBatchProcessor.py

@@ -1,7 +1,7 @@
  from typing import Generator

- from etl_lib.core.ETLContext import ETLContext
  from etl_lib.core.BatchProcessor import BatchProcessor, BatchResults, append_result
+ from etl_lib.core.ETLContext import ETLContext
  from etl_lib.core.Task import Task


@@ -24,7 +24,13 @@ class ClosedLoopBatchProcessor(BatchProcessor):
  for batch in self.predecessor.get_batch(max_batch__size):
  result = append_result(result, batch.statistics)
  batch_cnt += 1
- self.context.reporter.report_progress(self.task, batch_cnt, self.expected_rows, result.statistics)
+ self.context.reporter.report_progress(self.task, batch_cnt, self._safe_calculate_count(max_batch__size),
+ result.statistics)

  self.logger.debug(result.statistics)
  yield result
+
+ def _safe_calculate_count(self, batch_size: int | None) -> int:
+ if not self.expected_rows or not batch_size:
+ return 0
+ return (self.expected_rows + batch_size - 1) // batch_size  # ceiling division
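
For example, with expected_rows = 10500 and a batch size of 5000, the ceiling division (10500 + 5000 - 1) // 5000 reports 3 expected batches; when either value is missing the method falls back to 0 instead of raising.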
{neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/core/ETLContext.py

@@ -1,9 +1,26 @@
  import logging
  from typing import NamedTuple, Any

- from graphdatascience import GraphDataScience
+ try:
+ from graphdatascience import GraphDataScience
+ gds_available = False
+ except ImportError:
+ gds_available = False
+ logging.info("Graph Data Science not installed, skipping")
+ GraphDataScience = None
+
  from neo4j import GraphDatabase, WRITE_ACCESS, SummaryCounters

+ try:
+ from sqlalchemy import create_engine
+ from sqlalchemy.engine import Engine
+ sqlalchemy_available = True
+ except ImportError:
+ sqlalchemy_available = False
+ logging.info("SQL Alchemy not installed, skipping")
+ create_engine = None  # this and next line needed to prevent PyCharm warning
+ Engine = None
+
  from etl_lib.core.ProgressReporter import get_reporter


@@ -99,22 +116,6 @@ class Neo4jContext:
  else:
  return self.driver.session(database=database, default_access_mode=WRITE_ACCESS)

- def gds(self, database=None) -> GraphDataScience:
- """
- Creates a new GraphDataScience client.
-
- Args:
- database: Name of the database to use for this dgs client.
- If not provided, the database name provided during construction will be used.
-
- Returns:
- gds client.
- """
- if database is None:
- return GraphDataScience.from_neo4j_driver(driver=self.driver, database=self.database)
- else:
- return GraphDataScience.from_neo4j_driver(driver=self.driver, database=database)
-
  def __neo4j_connect(self):
  self.driver = GraphDatabase.driver(uri=self.uri, auth=self.auth,
  notifications_min_severity="OFF")
@@ -122,6 +123,32 @@ class Neo4jContext:
  self.logger.info(
  f"driver connected to instance at {self.uri} with username {self.auth[0]} and database {self.database}")

+ def gds(neo4j_context) -> GraphDataScience:
+ """
+ Creates a new GraphDataScience client.
+
+ Args:
+ neo4j_context: Neo4j context containing driver and database name.
+
+ Returns:
+ gds client.
+ """
+ return GraphDataScience.from_neo4j_driver(driver=neo4j_context.driver, database=neo4j_context.database)
+
+
+ if sqlalchemy_available:
+ class SQLContext:
+ def __init__(self, database_url: str, pool_size: int = 10, max_overflow: int = 20):
+ """
+ Initializes the SQL context with an SQLAlchemy engine.
+
+ Args:
+ database_url (str): SQLAlchemy connection URL.
+ pool_size (int): Number of connections to maintain in the pool.
+ max_overflow (int): Additional connections allowed beyond pool_size.
+ """
+ self.engine: Engine = create_engine(database_url, pool_size=pool_size, max_overflow=max_overflow)
+

  class ETLContext:
  """
@@ -145,6 +172,11 @@ class ETLContext:
  self.neo4j = Neo4jContext(env_vars)
  self.__env_vars = env_vars
  self.reporter = get_reporter(self)
+ sql_uri = self.env("SQLALCHEMY_URI")
+ if sql_uri is not None and sqlalchemy_available:
+ self.sql = SQLContext(sql_uri)
+ if gds_available:
+ self.gds =gds(self.neo4j)

  def env(self, key: str) -> Any:
  """
@@ -154,7 +186,7 @@ class ETLContext:
  key: name of the entry to read.

  Returns:
- va lue of the entry, or None if the key is not in the dict.
+ value of the entry, or None if the key is not in the dict.
  """
  if key in self.__env_vars:
  return self.__env_vars[key]
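
A sketch of how the optional SQL support appears to be switched on: ETLContext only creates an SQLContext when the env-var dict it receives contains SQLALCHEMY_URI and SQLAlchemy is importable. The Neo4j key names below are placeholders, not taken from the diff:

    from etl_lib.core.ETLContext import ETLContext

    env_vars = {
        # Neo4j settings omitted / placeholder names; see the library docs for the real keys
        "SQLALCHEMY_URI": "postgresql+psycopg2://user:secret@localhost:5432/musicbrainz",
    }
    context = ETLContext(env_vars)
    # context.sql.engine is an SQLAlchemy Engine; it is absent when SQLALCHEMY_URI is not set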
{neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/core/ProgressReporter.py

@@ -3,7 +3,7 @@ from datetime import datetime

  from tabulate import tabulate

- from etl_lib.core.Task import Task, TaskGroup
+ from etl_lib.core.Task import Task, TaskGroup, TaskReturn


  class ProgressReporter:
@@ -47,7 +47,7 @@ class ProgressReporter:
  self.logger.info(f"{'\t' * task.depth}starting {task.task_name()}")
  return task

- def finished_task(self, task: Task, success: bool, summery: dict, error: str = None) -> Task:
+ def finished_task(self, task: Task, result: TaskReturn) -> Task:
  """
  Marks the task as finished.

@@ -55,23 +55,21 @@

  Args:
  task: Task to be marked as finished.
- success: True if the task has successfully finished.
- summery: statistics for this task (such as `nodes_created`)
- error: If an exception occurred, the exception text should be provided here.
+ result: result of the task execution, such as status and summery information.

  Returns:
  Task to be marked as started.
  """
  task.end_time = datetime.now()
- task.success = success
- task.summery = summery
+ task.success = result.success
+ task.summery = result.summery

- report = f"{'\t' * task.depth} finished {task.task_name()} in {task.end_time - task.start_time} with success: {success}"
- if error is not None:
- report += f", error: \n{error}"
+ report = f"finished {task.task_name()} in {task.end_time - task.start_time} with status: {'success' if result.success else 'failed'}"
+ if result.error is not None:
+ report += f", error: \n{result.error}"
  else:
  # for the logger, remove entries with 0, but keep them in the original for reporting
- cleaned_summery = {key: value for key, value in summery.items() if value != 0}
+ cleaned_summery = {key: value for key, value in result.summery.items() if value != 0}
  if len(cleaned_summery) > 0:
  report += f"\n{tabulate([cleaned_summery], headers='keys', tablefmt='psql')}"
  self.logger.info(report)
@@ -87,7 +85,7 @@ class ProgressReporter:
  task: Task reporting updates.
  batches: Number of batches processed so far.
  expected_batches: Number of expected batches. Can be `None` if the overall number of
- batches is not know before execution.
+ batches is not known before execution.
  stats: dict of statistics so far (such as `nodes_created`).
  """
  pass
@@ -168,9 +166,9 @@ class Neo4jProgressReporter(ProgressReporter):
  start_time=task.start_time)
  return task

- def finished_task(self, task: Task, success: bool, summery: dict, error: str = None) -> Task:
- super().finished_task(task=task, success=success, summery=summery, error=error)
- if success:
+ def finished_task(self, task: Task, result: TaskReturn) -> Task:
+ super().finished_task(task=task, result=result)
+ if result.success:
  status = "success"
  else:
  status = "failure"
@@ -179,7 +177,7 @@ class Neo4jProgressReporter(ProgressReporter):
  MATCH (t:ETLTask {uuid:$id}) SET t.endTime = $end_time, t.status = $status, t.error = $error
  CREATE (s:ETLStats) SET s=$summery
  CREATE (t)-[:HAS_STATS]->(s)
- """, id=task.uuid, end_time=task.end_time, summery=summery, status=status, error=error)
+ """, id=task.uuid, end_time=task.end_time, summery=result.summery, status=status, error=result.error)
  return task

  def __create_constraints(self):
{neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/core/Task.py

@@ -46,7 +46,8 @@ class TaskReturn:

  # Combine success values and errors
  combined_success = self.success and other.success
- combined_error = f"{self.error or ''} | {other.error or ''}".strip(" |")
+ combined_error = None if not (self.error or other.error) \
+ else f"{self.error or ''} | {other.error or ''}".strip(" |")

  return TaskReturn(
  success=combined_success, summery=merged_summery, error=combined_error
@@ -99,12 +100,7 @@ class Task:
  except Exception as e:
  result = TaskReturn(success=False, summery={}, error=str(e))

- self.context.reporter.finished_task(
- task=self,
- success=result.success,
- summery=result.summery,
- error=result.error,
- )
+ self.context.reporter.finished_task(task=self,result=result)

  return result

{neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/data_sink/CypherBatchSink.py

@@ -10,7 +10,7 @@ class CypherBatchSink(BatchProcessor):
  BatchProcessor to write batches of data to a Neo4j database.
  """

- def __init__(self, context: ETLContext, task: Task, predecessor: BatchProcessor, query: str):
+ def __init__(self, context: ETLContext, task: Task, predecessor: BatchProcessor, query: str, **kwargs):
  """
  Constructs a new CypherBatchSink.

@@ -20,16 +20,17 @@ class CypherBatchSink(BatchProcessor):
  predecessor: BatchProcessor which :func:`~get_batch` function will be called to receive batches to process.
  query: Cypher to write the query to Neo4j.
  Data will be passed as `batch` parameter.
- Therefor, the query should start with a `UNWIND $batch AS row`.
+ Therefore, the query should start with a `UNWIND $batch AS row`.
  """
  super().__init__(context, task, predecessor)
  self.query = query
  self.neo4j = context.neo4j
+ self.kwargs = kwargs

  def get_batch(self, batch_size: int) -> Generator[BatchResults, None, None]:
  assert self.predecessor is not None

  with self.neo4j.session() as session:
  for batch_result in self.predecessor.get_batch(batch_size):
- result = self.neo4j.query_database(session=session, query=self.query, batch=batch_result.chunk)
+ result = self.neo4j.query_database(session=session, query=self.query, batch=batch_result.chunk, **self.kwargs)
  yield append_result(batch_result, result.summery)
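
The new **kwargs are forwarded to query_database alongside the batch, so additional Cypher parameters can be bound per sink. A sketch under that assumption; `import_date` is a made-up parameter name and `source` stands for any upstream BatchProcessor:

    sink = CypherBatchSink(
        context, task, source,
        query="""
        UNWIND $batch AS row
        MERGE (n:Thing {id: row.id})
        SET n.importedAt = date($import_date)
        """,
        import_date="2025-01-01",  # hypothetical extra parameter, available as $import_date
    )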
neo4j_etl_lib-0.2.0/src/etl_lib/data_sink/SQLBatchSink.py

@@ -0,0 +1,36 @@
+ from typing import Generator
+ from sqlalchemy import text
+ from etl_lib.core.ETLContext import ETLContext
+ from etl_lib.core.BatchProcessor import BatchProcessor, BatchResults, append_result
+ from etl_lib.core.Task import Task
+
+
+ class SQLBatchSink(BatchProcessor):
+ """
+ BatchProcessor to write batches of data to an SQL database.
+ """
+
+ def __init__(self, context: ETLContext, task: Task, predecessor: BatchProcessor, query: str):
+ """
+ Constructs a new SQLBatchSink.
+
+ Args:
+ context: ETLContext instance.
+ task: Task instance owning this batchProcessor.
+ predecessor: BatchProcessor which `get_batch` function will be called to receive batches to process.
+ query: SQL query to write data.
+ Data will be passed as a batch using parameterized statements (`:param_name` syntax).
+ """
+ super().__init__(context, task, predecessor)
+ self.query = query
+ self.engine = context.sql.engine
+
+ def get_batch(self, batch_size: int) -> Generator[BatchResults, None, None]:
+ assert self.predecessor is not None
+
+ with self.engine.connect() as conn:
+ with conn.begin():
+ for batch_result in self.predecessor.get_batch(batch_size):
+ conn.execute(text(self.query), batch_result.chunk)
+ yield append_result(batch_result, {"sql_rows_written": len(batch_result.chunk)})
+
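
A self-contained sketch (not part of the package) of the write style SQLBatchSink relies on: SQLAlchemy's conn.execute(text(query), batch) runs an executemany when the batch is a list of dicts, binding the `:name` placeholders per row. The in-memory SQLite URL and table are assumptions for illustration:

    from sqlalchemy import create_engine, text

    engine = create_engine("sqlite:///:memory:")
    with engine.connect() as conn:
        with conn.begin():
            conn.execute(text("CREATE TABLE stops (id TEXT, name TEXT)"))
            batch = [{"id": "s1", "name": "Alpha"}, {"id": "s2", "name": "Beta"}]
            # one statement, executed once per dict in the batch
            conn.execute(text("INSERT INTO stops (id, name) VALUES (:id, :name)"), batch)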
{neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/data_source/CypherBatchSource.py

@@ -1,4 +1,6 @@
- from typing import Generator
+ from typing import Generator, Callable, Optional
+
+ from neo4j import Record

  from etl_lib.core.BatchProcessor import BatchResults, BatchProcessor
  from etl_lib.core.ETLContext import ETLContext
@@ -7,7 +9,14 @@ from etl_lib.core.Task import Task

  class CypherBatchSource(BatchProcessor):

- def __init__(self, context: ETLContext, task: Task, query: str, **kwargs):
+ def __init__(
+ self,
+ context: ETLContext,
+ task: Task,
+ query: str,
+ record_transformer: Optional[Callable[[Record], dict]] = None,
+ **kwargs
+ ):
  """
  Constructs a new CypherBatchSource.

@@ -15,10 +24,12 @@ class CypherBatchSource(BatchProcessor):
  context: :class:`etl_lib.core.ETLContext.ETLContext` instance.
  task: :class:`etl_lib.core.Task.Task` instance owning this batchProcessor.
  query: Cypher query to execute.
+ record_transformer: Optional function to transform each record. See Neo4j API documentation on `result_transformer_`
  kwargs: Arguments passed as parameters with the query.
  """
  super().__init__(context, task)
  self.query = query
+ self.record_transformer = record_transformer
  self.kwargs = kwargs

  def __read_records(self, tx, batch_size):
@@ -26,7 +37,11 @@ class CypherBatchSource(BatchProcessor):
  result = tx.run(self.query, **self.kwargs)

  for record in result:
- batch_.append(record.data())
+ data = record.data()
+ if self.record_transformer:
+ data = self.record_transformer(data)
+ batch_.append(data)
+
  if len(batch_) == batch_size:
  yield batch_
  batch_ = []
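
A sketch of a record_transformer; as __read_records above shows, the callable actually receives the output of record.data() (a plain dict), which makes it handy for flattening Neo4j temporal values before they reach a sink:

    # hypothetical transformer, not part of the package
    def flatten_temporals(row: dict) -> dict:
        # neo4j.time types (Date, DateTime, ...) expose iso_format(); everything else passes through
        return {key: (value.iso_format() if hasattr(value, "iso_format") else value)
                for key, value in row.items()}

    # source = CypherBatchSource(context, task, query, record_transformer=flatten_temporals)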
neo4j_etl_lib-0.2.0/src/etl_lib/data_source/SQLBatchSource.py

@@ -0,0 +1,60 @@
+ from typing import Generator, Callable, Optional
+ from sqlalchemy import text
+ from etl_lib.core.BatchProcessor import BatchResults, BatchProcessor
+ from etl_lib.core.ETLContext import ETLContext
+ from etl_lib.core.Task import Task
+
+
+ class SQLBatchSource(BatchProcessor):
+ def __init__(
+ self,
+ context: ETLContext,
+ task: Task,
+ query: str,
+ record_transformer: Optional[Callable[[dict], dict]] = None,
+ **kwargs
+ ):
+ """
+ Constructs a new SQLBatchSource.
+
+ Args:
+ context: :class:`etl_lib.core.ETLContext.ETLContext` instance.
+ task: :class:`etl_lib.core.Task.Task` instance owning this batchProcessor.
+ query: SQL query to execute.
+ record_transformer: Optional function to transform each row (dict format).
+ kwargs: Arguments passed as parameters with the query.
+ """
+ super().__init__(context, task)
+ self.query = query
+ self.record_transformer = record_transformer
+ self.kwargs = kwargs  # Query parameters
+
+ def __read_records(self, conn, batch_size: int):
+ batch_ = []
+ result = conn.execute(text(self.query), self.kwargs)  # Safe execution with bound parameters
+
+ for row in result.mappings():  # Returns row as dict (like Neo4j's `record.data()`)
+ data = dict(row)  # Convert to dictionary
+ if self.record_transformer:
+ data = self.record_transformer(data)
+ batch_.append(data)
+
+ if len(batch_) == batch_size:
+ yield batch_
+ batch_ = []  # Reset batch
+
+ if batch_:
+ yield batch_
+
+ def get_batch(self, max_batch_size: int) -> Generator[BatchResults, None, None]:
+ """
+ Fetches data in batches using an open transaction, similar to Neo4j's approach.
+ """
+ with self.context.sql.engine.connect() as conn:  # Keep transaction open
+ with conn.begin():  # Ensures rollback on failure
+ for chunk in self.__read_records(conn, max_batch_size):
+ yield BatchResults(
+ chunk=chunk,
+ statistics={"sql_rows_read": len(chunk)},
+ batch_size=len(chunk)
+ )
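
For reference, a minimal sketch (SQLite URL and data are assumptions) of the read idiom SQLBatchSource builds on: Result.mappings() yields dict-like rows, and named `:parameters` are bound from a dict:

    from sqlalchemy import create_engine, text

    engine = create_engine("sqlite:///:memory:")
    with engine.connect() as conn:
        conn.execute(text("CREATE TABLE artist (id INTEGER, name TEXT)"))
        conn.execute(text("INSERT INTO artist VALUES (:id, :name)"),
                     [{"id": 1, "name": "A"}, {"id": 2, "name": "B"}])
        rows = conn.execute(text("SELECT * FROM artist WHERE id >= :min_id"), {"min_id": 1})
        for row in rows.mappings():
            print(dict(row))  # {'id': 1, 'name': 'A'}, {'id': 2, 'name': 'B'}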
{neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/task/CreateReportingConstraintsTask.py

@@ -4,8 +4,8 @@ from etl_lib.core.Task import Task, TaskReturn
  class CreateReportingConstraintsTask(Task):
  """Creates the constraint in the REPORTER_DATABASE database."""

- def __init__(self, config):
- super().__init__(config)
+ def __init__(self, context):
+ super().__init__(context)

  def run_internal(self, **kwargs) -> TaskReturn:
  database = self.context.env("REPORTER_DATABASE")
{neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/task/ExecuteCypherTask.py

@@ -24,10 +24,10 @@ class ExecuteCypherTask(Task):
  for query in self._query():
  result = self.context.neo4j.query_database(session=session, query=query, **kwargs)
  stats = merge_summery(stats, result.summery)
- return TaskReturn(True, stats)
+ return TaskReturn(success=True, summery=stats)
  else:
  result = self.context.neo4j.query_database(session=session, query=self._query(), **kwargs)
- return TaskReturn(True, result.summery)
+ return TaskReturn(success=True, summery=result.summery)

  @abc.abstractmethod
  def _query(self) -> str | list[str]:
{neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/task/data_loading/CSVLoad2Neo4jTask.py

@@ -14,19 +14,55 @@ from etl_lib.data_source.CSVBatchSource import CSVBatchSource


  class CSVLoad2Neo4jTask(Task):
- """
+ '''
  Loads the specified CSV file to Neo4j.

  Uses BatchProcessors to read, validate and write to Neo4j.
  The validation step is using pydantic, hence a Pydantic model needs to be provided.
- Rows that fail the validation, will be written to en error file. The location of the error file is determined as
+ Rows with fail validation will be written to en error file. The location of the error file is determined as
  follows:

- If the context env vars hold an entry `ETL_ERROR_PATH` the file will be place there, with the name set to name
+ If the context env vars hold an entry `ETL_ERROR_PATH` the file will be placed there, with the name set to name
  of the provided filename appended with `.error.json`

- If `ETL_ERROR_PATH` is not set, the file will be placed in the same directory as the CSV file.
- """
+ If `ETL_ERROR_PATH` is not set, the file will be placed in the same directory as the CSV file.
+
+ Example usage: (from the gtfs demo)
+
+ .. code-block:: python
+
+ class LoadStopsTask(CSVLoad2Neo4jTask):
+ class Stop(BaseModel):
+ id: str = Field(alias="stop_id")
+ name: str = Field(alias="stop_name")
+ latitude: float = Field(alias="stop_lat")
+ longitude: float = Field(alias="stop_lon")
+ platform_code: Optional[str] = None
+ parent_station: Optional[str] = None
+ type: Optional[str] = Field(alias="location_type", default=None)
+ timezone: Optional[str] = Field(alias="stop_timezone", default=None)
+ code: Optional[str] = Field(alias="stop_code", default=None)
+
+ def __init__(self, context: ETLContext, file: Path):
+ super().__init__(context, LoadStopsTask.Stop, file)
+
+ def task_name(self) -> str:
+ return f"{self.__class__.__name__}('{self.file}')"
+
+ def _query(self):
+ return """
+ UNWIND $batch AS row
+ MERGE (s:Stop {id: row.id})
+ SET s.name = row.name,
+ s.location= point({latitude: row.latitude, longitude: row.longitude}),
+ s.platformCode= row.platform_code,
+ s.parentStation= row.parent_station,
+ s.type= row.type,
+ s.timezone= row.timezone,
+ s.code= row.code
+ """
+
+ '''
  def __init__(self, context: ETLContext, model: Type[BaseModel], file: Path, batch_size: int = 5000):
  super().__init__(context)
  self.batch_size = batch_size
neo4j_etl_lib-0.2.0/src/etl_lib/task/data_loading/SQLLoad2Neo4jTask.py

@@ -0,0 +1,90 @@
+ from abc import abstractmethod
+
+ from sqlalchemy import text
+
+ from etl_lib.core import ETLContext
+ from etl_lib.core.ClosedLoopBatchProcessor import ClosedLoopBatchProcessor
+ from etl_lib.core.Task import Task, TaskReturn
+ from etl_lib.data_sink.CypherBatchSink import CypherBatchSink
+ from etl_lib.data_source.SQLBatchSource import SQLBatchSource
+
+
+ class SQLLoad2Neo4jTask(Task):
+ '''
+ Load the output of the specified SQL query to Neo4j.
+
+ Uses BatchProcessors to read and write data.
+ Subclasses must implement the methods returning the SQL and Cypher queries.
+
+ Example usage: (from the MusicBrainz example)
+
+ .. code-block:: python
+
+ class LoadArtistCreditTask(SQLLoad2Neo4jTask):
+ def _sql_query(self) -> str:
+ return """
+ SELECT ac.id AS artist_credit_id, ac.name AS credit_name
+ FROM artist_credit ac;
+ """
+
+ def _cypher_query(self) -> str:
+ return """
+ UNWIND $batch AS row
+ MERGE (ac:ArtistCredit {id: row.artist_credit_id})
+ SET ac.name = row.credit_name
+ """
+
+ def _count_query(self) -> str | None:
+ return "SELECT COUNT(*) FROM artist_credit;"
+
+ '''
+
+ def __init__(self, context: ETLContext, batch_size: int = 5000):
+ super().__init__(context)
+ self.context = context
+ self.batch_size = batch_size
+
+ @abstractmethod
+ def _sql_query(self) -> str:
+ """
+ Return the SQL query to load the source data.
+ """
+ pass
+
+ @abstractmethod
+ def _cypher_query(self) -> str:
+ """
+ Return the Cypher query to write the data in batches to Neo4j.
+ """
+ pass
+
+ def _count_query(self) -> str | None:
+ """
+ Return the SQL query to count the number of rows returned from :func:`_sql_query`.
+
+ Optional. If provided, it will run once at the beginning of the task and
+ provide the :class:`etl_lib.core.ClosedLoopBatchProcessor` with the total number of rows.
+ """
+ return None
+
+ def run_internal(self) -> TaskReturn:
+ total_count = self.__get_source_count()
+ source = SQLBatchSource(self.context, self, self._sql_query())
+ sink = CypherBatchSink(self.context, self, source, self._cypher_query())
+
+ end = ClosedLoopBatchProcessor(self.context, self, sink, total_count)
+
+ result = next(end.get_batch(self.batch_size))
+ return TaskReturn(True, result.statistics)
+
+ def __get_source_count(self):
+ count_query = self._count_query()
+ if count_query is None:
+ return None
+
+ with self.context.sql.engine.connect() as conn:
+ with conn.begin():
+ result = conn.execute(text(count_query))
+ row = result.fetchone()
+ return row[0] if row else None
{neo4j_etl_lib-0.1.0 → neo4j_etl_lib-0.2.0}/src/etl_lib/test_utils/utils.py

@@ -7,7 +7,7 @@ from _pytest.tmpdir import tmp_path
  from neo4j import Driver
  from neo4j.time import Date

- from etl_lib.core.ETLContext import QueryResult, Neo4jContext, ETLContext
+ from etl_lib.core.ETLContext import QueryResult, Neo4jContext, ETLContext, SQLContext, gds
  from etl_lib.core.Task import Task


@@ -102,6 +102,7 @@ class TestNeo4jContext(Neo4jContext):
  self.logger = logging.getLogger(self.__class__.__name__)
  self.driver = driver
  self.database = get_database_name()
+ self.gds = gds(self)


  class TestETLContext(ETLContext):
@@ -116,6 +117,16 @@ class TestETLContext(ETLContext):
  if key in self.__env_vars:
  return self.__env_vars[key]

+ class TestSQLETLContext(ETLContext):
+
+ def __init__(self, sql_uri):
+ self.logger = logging.getLogger(self.__class__.__name__)
+ self.reporter = DummyReporter()
+ self.sql = SQLContext(sql_uri)
+
+ def env(self, key: str) -> Any:
+ if key in self.__env_vars:
+ return self.__env_vars[key]

  class DummyReporter:

@@ -125,7 +136,7 @@ class DummyReporter:
  def started_task(self, task: Task) -> Task:
  pass

- def finished_task(self, task, success: bool, summery: dict, error: str = None) -> Task:
+ def finished_task(self, task, result) -> Task:
  pass

  def report_progress(self, task, batches: int, expected_batches: int, stats: dict) -> None:
@@ -151,3 +162,10 @@ class DummyContext:

  def env(self, key: str) -> Any:
  pass
+
+ class DummyPredecessor:
+ def __init__(self, batches):
+ self.batches = batches
+
+ def get_batch(self, batch_size):
+ yield from self.batches
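
A sketch of how the new DummyPredecessor helper might feed pre-built batches into a processor under test; the BatchResults fields (chunk, statistics, batch_size) are taken from the SQLBatchSource code shown above:

    from etl_lib.core.BatchProcessor import BatchResults

    batches = [BatchResults(chunk=[{"id": 1}], statistics={"rows": 1}, batch_size=1)]
    predecessor = DummyPredecessor(batches)
    for batch in predecessor.get_batch(batch_size=10):
        print(batch.chunk)  # [{'id': 1}]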