ingestr 0.2.6__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

ingestr/main.py CHANGED
@@ -5,7 +5,7 @@ from typing import Optional
 import dlt
 import humanize
 import typer
-from dlt.common.runtime.collector import Collector
+from dlt.common.runtime.collector import Collector, LogCollector
 from rich.console import Console
 from rich.status import Status
 from typing_extensions import Annotated
@@ -32,10 +32,17 @@ DATE_FORMATS = [
     "%Y-%m-%dT%H:%M:%S.%f%z",
 ]
 
+# https://dlthub.com/docs/dlt-ecosystem/file-formats/parquet#supported-destinations
+PARQUET_SUPPORTED_DESTINATIONS = [
+    "bigquery",
+    "duckdb",
+    "snowflake",
+    "databricks",
+    "synapse",
+]
 
-class SpinnerCollector(Collector):
-    """A Collector that shows progress with `tqdm` progress bars"""
 
+class SpinnerCollector(Collector):
     status: Status
     current_step: str
     started: bool
@@ -150,6 +157,27 @@ def ingest(
             envvar="FULL_REFRESH",
         ),
     ] = False,  # type: ignore
+    progress: Annotated[
+        Optional[str],
+        typer.Option(
+            help="The progress display type, must be one of 'interactive', 'log'",
+            envvar="PROGRESS",
+        ),
+    ] = "interactive",  # type: ignore
+    sql_backend: Annotated[
+        Optional[str],
+        typer.Option(
+            help="The SQL backend to use, must be one of 'sqlalchemy', 'pyarrow'",
+            envvar="SQL_BACKEND",
+        ),
+    ] = "pyarrow",  # type: ignore
+    loader_file_format: Annotated[
+        Optional[str],
+        typer.Option(
+            help="The file format to use when loading data, must be one of 'jsonl', 'parquet', 'default'",
+            envvar="LOADER_FILE_FORMAT",
+        ),
+    ] = "default",  # type: ignore
 ):
     track(
         "command_triggered",
@@ -186,12 +214,16 @@ def ingest(
     m = hashlib.sha256()
     m.update(dest_table.encode("utf-8"))
 
+    progressInstance: Collector = SpinnerCollector()
+    if progress == "log":
+        progressInstance = LogCollector()
+
     pipeline = dlt.pipeline(
         pipeline_name=m.hexdigest(),
         destination=destination.dlt_dest(
             uri=dest_uri,
         ),
-        progress=SpinnerCollector(),
+        progress=progressInstance,
         pipelines_dir="pipeline_data",
         full_refresh=full_refresh,
     )
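The collector swap above is all the new --progress option controls: 'interactive' keeps the rich spinner, while 'log' switches to dlt's LogCollector, which emits periodic progress lines and so suits CI and other non-TTY runs. The selection in isolation (pick_collector is a made-up name for illustration):

from dlt.common.runtime.collector import Collector, LogCollector

def pick_collector(progress: str, interactive: Collector) -> Collector:
    """Hypothetical helper mirroring the branch in the hunk above."""
    # "log" logs progress at intervals; anything else keeps the spinner UI.
    return LogCollector() if progress == "log" else interactive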
@@ -231,11 +263,20 @@ def ingest(
         merge_key=merge_key,
         interval_start=interval_start,
         interval_end=interval_end,
+        sql_backend=sql_backend,
     )
 
     if original_incremental_strategy == "delete+insert":
         dlt_source.incremental.primary_key = ()
 
+    if (
+        factory.destination_scheme in PARQUET_SUPPORTED_DESTINATIONS
+        and loader_file_format == "default"
+    ):
+        loader_file_format = "parquet"
+    elif loader_file_format == "default":
+        loader_file_format = "jsonl"
+
     run_info = pipeline.run(
         dlt_source,
         **destination.dlt_run_params(
@@ -244,6 +285,7 @@ def ingest(
         ),
         write_disposition=incremental_strategy,  # type: ignore
         primary_key=(primary_key if primary_key and len(primary_key) > 0 else None),  # type: ignore
+        loader_file_format=loader_file_format,  # type: ignore
     )
 
     destination.post_load()
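The net effect of the loader_file_format fallback is that 'default' now resolves per destination: parquet where dlt can load it, jsonl everywhere else, while explicit choices pass through untouched. The rule in isolation (resolve_loader_file_format is a made-up name for illustration):

PARQUET_SUPPORTED_DESTINATIONS = ["bigquery", "duckdb", "snowflake", "databricks", "synapse"]

def resolve_loader_file_format(destination_scheme: str, requested: str = "default") -> str:
    """Mirrors the inline branch in ingest() above."""
    if requested != "default":
        return requested  # explicit 'jsonl' / 'parquet' is honored as-is
    return "parquet" if destination_scheme in PARQUET_SUPPORTED_DESTINATIONS else "jsonl"

assert resolve_loader_file_format("duckdb") == "parquet"
assert resolve_loader_file_format("postgres") == "jsonl"
assert resolve_loader_file_format("bigquery", "jsonl") == "jsonl"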
ingestr/main_test.py CHANGED
@@ -24,6 +24,8 @@ def invoke_ingest_command(
    merge_key=None,
    interval_start=None,
    interval_end=None,
+    sql_backend=None,
+    loader_file_format=None,
 ):
     args = [
         "ingest",
@@ -61,6 +63,14 @@
         args.append("--interval-end")
         args.append(interval_end)
 
+    if sql_backend:
+        args.append("--sql-backend")
+        args.append(sql_backend)
+
+    if loader_file_format:
+        args.append("--loader-file-format")
+        args.append(loader_file_format)
+
     result = runner.invoke(
         app,
         args,
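The helper simply forwards the new options as CLI flags, so an equivalent direct invocation through Typer's test runner would look roughly like this (the URIs are placeholders and the app import path is assumed from context; running it would actually execute a pipeline):

from typer.testing import CliRunner

from ingestr.main import app  # assumed import, matching `app` used in these tests

runner = CliRunner()
result = runner.invoke(
    app,
    [
        "ingest",
        "--source-uri", "duckdb:///input.db",   # placeholder URIs
        "--source-table", "schema.table",
        "--dest-uri", "duckdb:///output.db",
        "--dest-table", "schema.output",
        "--sql-backend", "sqlalchemy",
        "--loader-file-format", "jsonl",
    ],
)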
@@ -137,6 +147,7 @@ def test_append():
         "testschema_append.output",
         "append",
         "updated_at",
+        sql_backend="sqlalchemy",
     )
     assert res.exit_code == 0
 
@@ -193,6 +204,7 @@ def test_merge_with_primary_key():
         "merge",
         "updated_at",
         "id",
+        sql_backend="sqlalchemy",
     )
     assert res.exit_code == 0
     return res
@@ -333,10 +345,10 @@ def test_delete_insert_without_primary_key():
         "CREATE TABLE testschema_delete_insert.input (id INTEGER, val VARCHAR, updated_at TIMESTAMP WITH TIME ZONE)"
     )
     conn.execute(
-        "INSERT INTO testschema_delete_insert.input VALUES (1, 'val1', '2022-01-01')"
+        "INSERT INTO testschema_delete_insert.input VALUES (1, 'val1', '2022-01-01 00:00:00+00:00')"
     )
     conn.execute(
-        "INSERT INTO testschema_delete_insert.input VALUES (2, 'val2', '2022-02-01')"
+        "INSERT INTO testschema_delete_insert.input VALUES (2, 'val2', '2022-02-01 00:00:00+00:00')"
     )
 
     res = conn.sql("select count(*) from testschema_delete_insert.input").fetchall()
@@ -350,6 +362,8 @@ def test_delete_insert_without_primary_key():
         "testschema_delete_insert.output",
         inc_strategy="delete+insert",
         inc_key="updated_at",
+        sql_backend="sqlalchemy",
+        loader_file_format="jsonl",
     )
     assert res.exit_code == 0
     return res
@@ -357,7 +371,7 @@ def test_delete_insert_without_primary_key():
     def get_output_rows():
         conn.execute("CHECKPOINT")
         return conn.sql(
-            "select id, val, strftime(updated_at, '%Y-%m-%d') as updated_at from testschema_delete_insert.output order by id asc"
+            "select id, val, strftime(CAST(updated_at AT TIME ZONE 'UTC' AS TIMESTAMP), '%Y-%m-%d %H:%M:%S') from testschema_delete_insert.output order by id asc"
         ).fetchall()
 
     def assert_output_equals(expected):
@@ -367,7 +381,9 @@ def test_delete_insert_without_primary_key():
             assert res[i] == row
 
     run()
-    assert_output_equals([(1, "val1", "2022-01-01"), (2, "val2", "2022-02-01")])
+    assert_output_equals(
+        [(1, "val1", "2022-01-01 00:00:00"), (2, "val2", "2022-02-01 00:00:00")]
+    )
 
     first_run_id = conn.sql(
         "select _dlt_load_id from testschema_delete_insert.output limit 1"
@@ -375,8 +391,10 @@ def test_delete_insert_without_primary_key():
 
     ##############################
     # we'll run again, since this is a delete+insert, we expect the run ID to change for the last one
-    run()
-    assert_output_equals([(1, "val1", "2022-01-01"), (2, "val2", "2022-02-01")])
+    res = run()
+    assert_output_equals(
+        [(1, "val1", "2022-01-01 00:00:00"), (2, "val2", "2022-02-01 00:00:00")]
+    )
 
     # we ensure that one of the rows is updated with a new run
     count_by_run_id = conn.sql(
@@ -392,17 +410,17 @@ def test_delete_insert_without_primary_key():
     ##############################
     # now we'll insert a few more lines for the same day, the new rows should show up
     conn.execute(
-        "INSERT INTO testschema_delete_insert.input VALUES (3, 'val3', '2022-02-01'), (4, 'val4', '2022-02-01')"
+        "INSERT INTO testschema_delete_insert.input VALUES (3, 'val3', '2022-02-01 00:00:00+00:00'), (4, 'val4', '2022-02-01 00:00:00+00:00')"
     )
     conn.execute("CHECKPOINT")
 
     run()
     assert_output_equals(
         [
-            (1, "val1", "2022-01-01"),
-            (2, "val2", "2022-02-01"),
-            (3, "val3", "2022-02-01"),
-            (4, "val4", "2022-02-01"),
+            (1, "val1", "2022-01-01 00:00:00"),
+            (2, "val2", "2022-02-01 00:00:00"),
+            (3, "val3", "2022-02-01 00:00:00"),
+            (4, "val4", "2022-02-01 00:00:00"),
         ]
     )
 
@@ -460,6 +478,8 @@ def test_delete_insert_with_timerange():
         inc_key="updated_at",
         interval_start=start_date,
         interval_end=end_date,
+        sql_backend="sqlalchemy",
+        loader_file_format="jsonl",
     )
     assert res.exit_code == 0
     return res
ingestr/src/destinations.py CHANGED
@@ -177,7 +177,8 @@ class CsvDestination(GenericSqlDestination):
             )
 
         output_path = self.uri.split("://")[1]
-        os.makedirs(os.path.dirname(output_path), exist_ok=True)
+        if output_path.count("/") > 1:
+            os.makedirs(os.path.dirname(output_path), exist_ok=True)
 
         with gzip.open(first_file_path, "rt", encoding="utf-8") as jsonl_file:  # type: ignore
             with open(output_path, "w", newline="") as csv_file:
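The new guard sidesteps a crash for CSV URIs that point at a bare filename: os.path.dirname returns an empty string for them, and os.makedirs("") raises FileNotFoundError. A small sketch of the fixed behavior, with illustrative paths:

import os

def ensure_parent_dir(output_path: str) -> None:
    # os.path.dirname("out.csv") == "", and os.makedirs("") raises
    # FileNotFoundError, so directories are only created for nested
    # targets, mirroring the count("/") check in the diff.
    if output_path.count("/") > 1:
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

ensure_parent_dir("out.csv")                  # no-op; previously crashed
ensure_parent_dir("exports/2024/01/out.csv")  # creates exports/2024/01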
ingestr/src/factory.py CHANGED
@@ -14,7 +14,7 @@ from ingestr.src.destinations import (
     SnowflakeDestination,
     SynapseDestination,
 )
-from ingestr.src.sources import LocalCsvSource, MongoDbSource, SqlSource
+from ingestr.src.sources import LocalCsvSource, MongoDbSource, NotionSource, SqlSource
 
 SQL_SOURCE_SCHEMES = [
     "bigquery",
@@ -80,6 +80,8 @@ class SourceDestinationFactory:
             return LocalCsvSource()
         elif self.source_scheme == "mongodb":
             return MongoDbSource()
+        elif self.source_scheme == "notion":
+            return NotionSource()
         else:
             raise ValueError(f"Unsupported source scheme: {self.source_scheme}")
 
ingestr/src/mongodb/__init__.py CHANGED
@@ -70,7 +70,7 @@
     collection: str = dlt.config.value,
     incremental: Optional[dlt.sources.incremental] = None,  # type: ignore[type-arg]
     write_disposition: Optional[str] = dlt.config.value,
-    parallel: Optional[bool] = dlt.config.value,
+    parallel: Optional[bool] = False,
 ) -> Any:
     """
     A DLT source which loads a collection from a mongo database using PyMongo.
ingestr/src/mongodb/helpers.py CHANGED
@@ -83,7 +83,7 @@ class CollectionLoaderParallell(CollectionLoader):
     def _get_cursor(self) -> TCursor:
         cursor = self.collection.find(filter=self._filter_op)
         if self._sort_op:
-            cursor = cursor.sort(self._sort_op)  # type: ignore
+            cursor = cursor.sort(self._sort_op)
         return cursor
 
     @dlt.defer
@@ -155,11 +155,11 @@ class MongoDbCollectionConfiguration(BaseConfiguration):
 
 @configspec
 class MongoDbCollectionResourceConfiguration(BaseConfiguration):
-    connection_url: str
-    database: Optional[str]
-    collection: str
+    connection_url: str = dlt.secrets.value
+    database: Optional[str] = dlt.config.value
+    collection: str = dlt.config.value
     incremental: Optional[dlt.sources.incremental] = None  # type: ignore[type-arg]
-    write_disposition: Optional[str] = None
+    write_disposition: Optional[str] = dlt.config.value
     parallel: Optional[bool] = False
 
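These spec changes replace bare annotations and None defaults with dlt's sentinel defaults, so the fields resolve from dlt's config and secrets providers (environment variables, secrets.toml, and so on) rather than having to be passed explicitly. A toy spec showing the pattern, not the ingestr class itself; import paths follow dlt's public layout:

from typing import Optional

import dlt
from dlt.common.configuration import configspec
from dlt.common.configuration.specs import BaseConfiguration

@configspec
class ExampleCollectionConfiguration(BaseConfiguration):
    connection_url: str = dlt.secrets.value     # injected from secrets providers
    database: Optional[str] = dlt.config.value  # injected from config providers
    parallel: Optional[bool] = False            # ordinary Python default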
 
ingestr/src/notion/__init__.py ADDED
@@ -0,0 +1,55 @@
+"""A source that extracts data from Notion API"""
+
+from typing import Dict, Iterator, List, Optional
+
+import dlt
+from dlt.sources import DltResource
+
+from .helpers.client import NotionClient
+from .helpers.database import NotionDatabase
+
+
+@dlt.source
+def notion_databases(
+    database_ids: Optional[List[Dict[str, str]]] = None,
+    api_key: str = dlt.secrets.value,
+) -> Iterator[DltResource]:
+    """
+    Retrieves data from Notion databases.
+
+    Args:
+        database_ids (List[Dict[str, str]], optional): A list of dictionaries
+            each containing a database id and a name.
+            Defaults to None. If None, the function will generate all databases
+            in the workspace that are accessible to the integration.
+        api_key (str): The Notion API secret key.
+
+    Yields:
+        DltResource: Data resources from Notion databases.
+    """
+    notion_client = NotionClient(api_key)
+
+    if database_ids is None:
+        search_results = notion_client.search(
+            filter_criteria={"value": "database", "property": "object"}
+        )
+        database_ids = [
+            {"id": result["id"], "use_name": result["title"][0]["plain_text"]}
+            for result in search_results
+        ]
+
+    for database in database_ids:
+        if "use_name" not in database:
+            # Fetch the database details from Notion
+            details = notion_client.get_database(database["id"])
+
+            # Extract the name/title from the details
+            database["use_name"] = details["title"][0]["plain_text"]
+
+        notion_database = NotionDatabase(database["id"], notion_client)
+        yield dlt.resource(  # type: ignore
+            notion_database.query(),
+            primary_key="id",
+            name=database["use_name"],
+            write_disposition="replace",
+        )
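For orientation, the new source can be exercised directly with dlt, roughly as below; the database id, name, token, and pipeline settings are placeholders, and the import path follows the package layout above:

import dlt

from ingestr.src.notion import notion_databases

pipeline = dlt.pipeline(destination="duckdb", dataset_name="notion_data")
source = notion_databases(
    database_ids=[{"id": "0123456789abcdef", "use_name": "my_database"}],  # placeholders
    api_key="secret_xxx",  # placeholder integration token
)
# Each database becomes one table loaded with write_disposition="replace":
# pipeline.run(source)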
ingestr/src/notion/helpers/client.py ADDED
@@ -0,0 +1,164 @@
+from typing import Any, Dict, Iterator, Optional
+
+from dlt.sources.helpers import requests
+
+from ..settings import API_URL
+
+
+class NotionClient:
+    """A client to interact with the Notion API.
+
+    Attributes:
+        api_key (str): The Notion API secret key.
+    """
+
+    def __init__(self, api_key: Optional[str] = None):
+        self.api_key = api_key
+
+    def _create_headers(self) -> Dict[str, str]:
+        headers = {
+            "accept": "application/json",
+            "Notion-Version": "2022-06-28",
+            "Authorization": f"Bearer {self.api_key}",
+        }
+        return headers
+
+    def _filter_out_none_values(self, dict_in: Dict[str, Any]) -> Dict[str, Any]:
+        return {k: v for k, v in dict_in.items() if v is not None}
+
+    def get_endpoint(
+        self, resource: str, resource_id: str, subresource: Optional[str] = None
+    ) -> str:
+        """Returns the endpoint for a given resource.
+
+        Args:
+            resource (str): The resource to get the endpoint for.
+            resource_id (str): The id of the resource.
+            subresource (str, optional): The subresource to get the endpoint for.
+
+        Returns:
+            str: The endpoint for the resource.
+        """
+        url = f"{API_URL}/{resource}/{resource_id}"
+        if subresource:
+            url += f"/{subresource}"
+        return url
+
+    def fetch_resource(
+        self, resource: str, resource_id: str, subresource: Optional[str] = None
+    ) -> Any:
+        """Fetches a resource from the Notion API.
+
+        Args:
+            resource (str): The resource to fetch.
+            resource_id (str): The id of the resource.
+            subresource (str, optional): The subresource to fetch. Defaults to None.
+
+        Returns:
+            Any: The resource from the Notion API.
+        """
+        url = self.get_endpoint(resource, resource_id, subresource)
+        headers = self._create_headers()
+        response = requests.get(url, headers=headers)
+        response.raise_for_status()
+        return response.json()
+
+    def send_payload(
+        self,
+        resource: str,
+        resource_id: str,
+        subresource: Optional[str] = None,
+        query_params: Optional[Dict[str, Any]] = None,
+        payload: Optional[Dict[str, Any]] = None,
+    ) -> Any:
+        """Sends a payload to the Notion API using the POST method.
+
+        Args:
+            resource (str): The resource to send the payload to.
+            resource_id (str): The id of the resource.
+            subresource (str, optional): The subresource to send the payload to.
+                Defaults to None.
+            query_params (Dict[str, Any], optional): The query parameters to send
+                with the payload. Defaults to None.
+            payload (Dict[str, Any], optional): The payload to send. Defaults to None.
+
+        Returns:
+            Any: The response from the Notion API.
+
+        Raises:
+            requests.HTTPError: If the response from the Notion API is not 200.
+        """
+
+        url = self.get_endpoint(resource, resource_id, subresource)
+        headers = self._create_headers()
+
+        if payload is None:
+            payload = {}
+
+        filtered_payload = self._filter_out_none_values(payload)
+
+        response = requests.post(
+            url, headers=headers, params=query_params, json=filtered_payload
+        )
+        response.raise_for_status()
+        return response.json()
+
+    def search(
+        self,
+        query: Optional[str] = None,
+        filter_criteria: Optional[Dict[str, Any]] = None,
+        sort: Optional[Dict[str, Any]] = None,
+        start_cursor: Optional[str] = None,
+        page_size: Optional[int] = None,
+    ) -> Iterator[Dict[str, Any]]:
+        """Searches all parent or child pages and databases that have been
+        shared with an integration.
+
+        Notion API Reference. Search:
+        https://developers.notion.com/reference/post-search
+
+        Args:
+            query (str, optional): The string to search for. Defaults to None.
+            filter_criteria (Dict[str, Any], optional): The filter to apply to
+                the results.
+            sort (Dict[str, Any], optional): The sort to apply to the results.
+            start_cursor (str, optional): The cursor to start the query at.
+                Defaults to None.
+            page_size (int, optional): The number of results to return.
+                Defaults to None.
+
+        Yields:
+            Dict[str, Any]: A result from the search.
+        """
+        has_more = True
+
+        while has_more:
+            payload = {
+                "query": query,
+                "sort": sort,
+                "filter": filter_criteria,
+                "start_cursor": start_cursor,
+                "page_size": page_size,
+            }
+
+            filtered_payload = self._filter_out_none_values(payload)
+
+            response = self.send_payload("search", "", payload=filtered_payload)
+
+            for result in response.get("results", []):
+                yield result
+
+            next_cursor = response.get("next_cursor")
+            has_more = next_cursor is not None
+            start_cursor = next_cursor
+
+    def get_database(self, database_id: str) -> Any:
+        """Fetches the details of a specific database by its ID.
+
+        Args:
+            database_id (str): The ID of the database to fetch.
+
+        Returns:
+            Any: The details of the database.
+        """
+        return self.fetch_resource("databases", database_id)
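search() follows Notion's cursor pagination contract: POST a payload, yield the page's results, then repeat while next_cursor is non-null. The same loop in isolation, with fetch_page standing in for the POST round-trip:

from typing import Any, Callable, Dict, Iterator

def paginate(fetch_page: Callable[[Dict[str, Any]], Dict[str, Any]]) -> Iterator[Any]:
    """Generic sketch of the cursor loop in NotionClient.search."""
    payload: Dict[str, Any] = {}
    while True:
        response = fetch_page(payload)
        yield from response.get("results", [])
        cursor = response.get("next_cursor")
        if cursor is None:  # Notion pairs has_more=False with next_cursor=null
            break
        payload["start_cursor"] = cursor

pages = iter([{"results": [1, 2], "next_cursor": "abc"}, {"results": [3], "next_cursor": None}])
assert list(paginate(lambda p: next(pages))) == [1, 2, 3]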
ingestr/src/notion/helpers/database.py ADDED
@@ -0,0 +1,78 @@
+from typing import Any, Dict, Iterable, Optional
+
+from dlt.common.typing import TDataItem
+
+from .client import NotionClient
+
+
+class NotionDatabase:
+    """
+    A class to represent a Notion database.
+
+    Attributes:
+        database_id (str): The ID of the Notion database.
+        notion_client (NotionClient): A client to interact with the Notion API.
+    """
+
+    def __init__(self, database_id: str, notion_client: NotionClient):
+        self.database_id = database_id
+        self.notion_client = notion_client
+
+    def get_structure(self) -> Any:
+        """Retrieves the structure of the database.
+
+        Notion API Reference. Retrieve a database:
+        https://developers.notion.com/reference/retrieve-a-database
+
+        Returns:
+            Any: The structure of the database.
+        """
+        return self.notion_client.fetch_resource("databases", self.database_id)
+
+    def query(
+        self,
+        filter_properties: Optional[Dict[str, Any]] = None,
+        filter_criteria: Optional[Dict[str, Any]] = None,
+        sorts: Optional[Dict[str, Any]] = None,
+        start_cursor: Optional[str] = None,
+        page_size: Optional[int] = None,
+    ) -> Iterable[TDataItem]:
+        """Queries the database for records.
+
+        Notion API Reference. Query a database:
+        https://developers.notion.com/reference/post-database-query
+
+        Args:
+            filter_properties (Dict[str, Any], optional): A dictionary of
+                properties to filter the records by. Defaults to None.
+            filter_criteria (Dict[str, Any], optional): A dictionary of filters
+                to apply to the records. Defaults to None.
+            sorts (Dict[str, Any], optional): A dictionary of sorts to apply
+                to the records. Defaults to None.
+            start_cursor (str, optional): The cursor to start the query at.
+                Defaults to None.
+            page_size (int, optional): The number of records to return.
+                Defaults to None.
+
+        Yields:
+            List[Dict[str, Any]]: A record from the database.
+        """
+        while True:
+            payload = {
+                "filter": filter_criteria,
+                "sorts": sorts,
+                "start_cursor": start_cursor,
+                "page_size": page_size,
+            }
+            response = self.notion_client.send_payload(
+                "databases",
+                self.database_id,
+                subresource="query",
+                query_params=filter_properties,
+                payload=payload,
+            )
+
+            yield response.get("results", [])
+            if not response.get("has_more"):
+                break
+            start_cursor = response.get("next_cursor")
ingestr/src/notion/settings.py ADDED
@@ -0,0 +1,3 @@
+"""Notion source settings and constants"""
+
+API_URL = "https://api.notion.com/v1"
ingestr/src/sources.py CHANGED
@@ -1,9 +1,11 @@
 import csv
 from typing import Callable
+from urllib.parse import parse_qs, urlparse
 
 import dlt
 
 from ingestr.src.mongodb import mongodb_collection
+from ingestr.src.notion import notion_databases
 from ingestr.src.sql_database import sql_table
 
 
@@ -39,6 +41,7 @@ class SqlSource:
             table=table_fields[-1],
             incremental=incremental,
             merge_key=kwargs.get("merge_key"),
+            backend=kwargs.get("sql_backend", "sqlalchemy"),
         )
 
         return table_instance
@@ -104,3 +107,25 @@ class LocalCsvSource:
                 csv_file,
                 merge_key=kwargs.get("merge_key"),  # type: ignore
             )
+
+
+class NotionSource:
+    table_builder: Callable
+
+    def __init__(self, table_builder=notion_databases) -> None:
+        self.table_builder = table_builder
+
+    def dlt_source(self, uri: str, table: str, **kwargs):
+        if kwargs.get("incremental_key"):
+            raise ValueError("Incremental loads are not supported for Notion")
+
+        source_fields = urlparse(uri)
+        source_params = parse_qs(source_fields.query)
+        api_key = source_params.get("api_key")
+        if not api_key:
+            raise ValueError("api_key in the URI is required to connect to Notion")
+
+        return self.table_builder(
+            database_ids=[{"id": table}],
+            api_key=api_key[0],
+        )
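NotionSource reads the API key from the source URI's query string and treats the requested table as the Notion database id, so a URI of the form notion://?api_key=<token> is what the factory's notion scheme expects. The parsing can be checked standalone with the same stdlib calls (placeholder token):

from urllib.parse import parse_qs, urlparse

uri = "notion://?api_key=secret_xxx"
params = parse_qs(urlparse(uri).query)
assert params["api_key"] == ["secret_xxx"]  # parse_qs returns lists, hence api_key[0] above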
@@ -22,10 +22,11 @@ class SqlSourceTest(unittest.TestCase):
         table = "schema.table"
 
         # monkey patch the sql_table function
-        def sql_table(credentials, schema, table, incremental, merge_key):
+        def sql_table(credentials, schema, table, incremental, merge_key, backend):
             self.assertEqual(credentials, uri)
             self.assertEqual(schema, "schema")
             self.assertEqual(table, "table")
+            self.assertEqual(backend, "sqlalchemy")
             self.assertIsNone(incremental)
             self.assertIsNone(merge_key)
             return dlt.resource()
@@ -40,10 +41,11 @@ class SqlSourceTest(unittest.TestCase):
         incremental_key = "id"
 
         # monkey patch the sql_table function
-        def sql_table(credentials, schema, table, incremental, merge_key):
+        def sql_table(credentials, schema, table, incremental, merge_key, backend):
             self.assertEqual(credentials, uri)
             self.assertEqual(schema, "schema")
             self.assertEqual(table, "table")
+            self.assertEqual(backend, "sqlalchemy")
             self.assertIsInstance(incremental, dlt.sources.incremental)
             self.assertEqual(incremental.cursor_path, incremental_key)
             self.assertIsNone(merge_key)