unstructured-ingest 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff reflects the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of unstructured-ingest might be problematic.

Files changed (30)
  1. test/integration/connectors/conftest.py +13 -0
  2. test/integration/connectors/databricks_tests/test_volumes_native.py +8 -4
  3. test/integration/connectors/sql/test_postgres.py +6 -10
  4. test/integration/connectors/sql/test_snowflake.py +205 -0
  5. test/integration/connectors/sql/test_sqlite.py +6 -10
  6. test/integration/connectors/test_delta_table.py +138 -0
  7. test/integration/connectors/utils/docker.py +78 -0
  8. test/integration/connectors/utils/validation.py +93 -2
  9. unstructured_ingest/__version__.py +1 -1
  10. unstructured_ingest/v2/cli/utils/click.py +32 -1
  11. unstructured_ingest/v2/cli/utils/model_conversion.py +10 -3
  12. unstructured_ingest/v2/interfaces/indexer.py +4 -1
  13. unstructured_ingest/v2/pipeline/pipeline.py +10 -2
  14. unstructured_ingest/v2/pipeline/steps/index.py +18 -1
  15. unstructured_ingest/v2/processes/connectors/__init__.py +10 -0
  16. unstructured_ingest/v2/processes/connectors/databricks/volumes.py +1 -1
  17. unstructured_ingest/v2/processes/connectors/delta_table.py +185 -0
  18. unstructured_ingest/v2/processes/connectors/slack.py +248 -0
  19. unstructured_ingest/v2/processes/connectors/sql/__init__.py +10 -2
  20. unstructured_ingest/v2/processes/connectors/sql/postgres.py +33 -37
  21. unstructured_ingest/v2/processes/connectors/sql/snowflake.py +164 -0
  22. unstructured_ingest/v2/processes/connectors/sql/sql.py +38 -10
  23. unstructured_ingest/v2/processes/connectors/sql/sqlite.py +31 -32
  24. {unstructured_ingest-0.1.1.dist-info → unstructured_ingest-0.2.0.dist-info}/METADATA +14 -12
  25. {unstructured_ingest-0.1.1.dist-info → unstructured_ingest-0.2.0.dist-info}/RECORD +29 -24
  26. unstructured_ingest/v2/processes/connectors/databricks_volumes.py +0 -250
  27. {unstructured_ingest-0.1.1.dist-info → unstructured_ingest-0.2.0.dist-info}/LICENSE.md +0 -0
  28. {unstructured_ingest-0.1.1.dist-info → unstructured_ingest-0.2.0.dist-info}/WHEEL +0 -0
  29. {unstructured_ingest-0.1.1.dist-info → unstructured_ingest-0.2.0.dist-info}/entry_points.txt +0 -0
  30. {unstructured_ingest-0.1.1.dist-info → unstructured_ingest-0.2.0.dist-info}/top_level.txt +0 -0
unstructured_ingest/v2/processes/connectors/slack.py (new file)
@@ -0,0 +1,248 @@
+import hashlib
+import time
+import xml.etree.ElementTree as ET
+from dataclasses import dataclass, field
+from datetime import datetime
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Generator, Optional
+
+from pydantic import Field, Secret
+
+from unstructured_ingest.error import SourceConnectionError
+from unstructured_ingest.logger import logger
+from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.v2.interfaces import (
+    AccessConfig,
+    ConnectionConfig,
+    Downloader,
+    DownloaderConfig,
+    Indexer,
+    IndexerConfig,
+    download_responses,
+)
+from unstructured_ingest.v2.interfaces.file_data import (
+    FileData,
+    FileDataSourceMetadata,
+    SourceIdentifiers,
+)
+from unstructured_ingest.v2.processes.connector_registry import SourceRegistryEntry
+
+if TYPE_CHECKING:
+    from slack_sdk import WebClient
+    from slack_sdk.web.async_client import AsyncWebClient
+
+# NOTE: Pagination limit set to the upper end of the recommended range
+# https://api.slack.com/apis/pagination#facts
+PAGINATION_LIMIT = 200
+
+CONNECTOR_TYPE = "slack"
+
+
+class SlackAccessConfig(AccessConfig):
+    token: str = Field(
+        description="Bot token used to access Slack API, must have channels:history scope for the"
+        " bot user."
+    )
+
+
+class SlackConnectionConfig(ConnectionConfig):
+    access_config: Secret[SlackAccessConfig]
+
+    @requires_dependencies(["slack_sdk"], extras="slack")
+    @SourceConnectionError.wrap
+    def get_client(self) -> "WebClient":
+        from slack_sdk import WebClient
+
+        return WebClient(token=self.access_config.get_secret_value().token)
+
+    @requires_dependencies(["slack_sdk"], extras="slack")
+    @SourceConnectionError.wrap
+    def get_async_client(self) -> "AsyncWebClient":
+        from slack_sdk.web.async_client import AsyncWebClient
+
+        return AsyncWebClient(token=self.access_config.get_secret_value().token)
+
+
+class SlackIndexerConfig(IndexerConfig):
+    channels: list[str] = Field(
+        description="Comma-delimited list of Slack channel IDs to pull messages from, can be"
+        " both public or private channels."
+    )
+    start_date: Optional[datetime] = Field(
+        default=None,
+        description="Start date/time in formats YYYY-MM-DD[T]HH:MM[:SS[.ffffff]][Z or [±]HH[:]MM]"
+        " or YYYY-MM-DD",
+    )
+    end_date: Optional[datetime] = Field(
+        default=None,
+        description="End date/time in formats YYYY-MM-DD[T]HH:MM[:SS[.ffffff]][Z or [±]HH[:]MM]"
+        " or YYYY-MM-DD",
+    )
+
+
+@dataclass
+class SlackIndexer(Indexer):
+    index_config: SlackIndexerConfig
+    connection_config: SlackConnectionConfig
+    connector_type: str = CONNECTOR_TYPE
+
+    def run(self, **kwargs: Any) -> Generator[FileData, None, None]:
+        client = self.connection_config.get_client()
+        for channel in self.index_config.channels:
+            messages = []
+            oldest = (
+                str(self.index_config.start_date.timestamp())
+                if self.index_config.start_date is not None
+                else None
+            )
+            latest = (
+                str(self.index_config.end_date.timestamp())
+                if self.index_config.end_date is not None
+                else None
+            )
+            for conversation_history in client.conversations_history(
+                channel=channel,
+                oldest=oldest,
+                latest=latest,
+                limit=PAGINATION_LIMIT,
+            ):
+                messages = conversation_history.get("messages", [])
+                if messages:
+                    yield self._messages_to_file_data(messages, channel)
+
+    def _messages_to_file_data(
+        self,
+        messages: list[dict],
+        channel: str,
+    ) -> FileData:
+        ts_oldest = min((message["ts"] for message in messages), key=lambda m: float(m))
+        ts_newest = max((message["ts"] for message in messages), key=lambda m: float(m))
+
+        identifier_base = f"{channel}-{ts_oldest}-{ts_newest}"
+        identifier = hashlib.sha256(identifier_base.encode("utf-8")).hexdigest()
+        filename = identifier[:16]
+
+        return FileData(
+            identifier=identifier,
+            connector_type=CONNECTOR_TYPE,
+            source_identifiers=SourceIdentifiers(
+                filename=f"{filename}.xml", fullpath=f"{filename}.xml"
+            ),
+            metadata=FileDataSourceMetadata(
+                date_created=ts_oldest,
+                date_modified=ts_newest,
+                date_processed=str(time.time()),
+                record_locator={
+                    "channel": channel,
+                    "oldest": ts_oldest,
+                    "latest": ts_newest,
+                },
+            ),
+        )
+
+    @SourceConnectionError.wrap
+    def precheck(self) -> None:
+        client = self.connection_config.get_client()
+        for channel in self.index_config.channels:
+            # NOTE: Querying conversations history guarantees that the bot is in the channel
+            client.conversations_history(channel=channel, limit=1)
+
+
+class SlackDownloaderConfig(DownloaderConfig):
+    pass
+
+
+@dataclass
+class SlackDownloader(Downloader):
+    connector_type: str = CONNECTOR_TYPE
+    connection_config: SlackConnectionConfig
+    download_config: SlackDownloaderConfig = field(default_factory=SlackDownloaderConfig)
+
+    def run(self, file_data, **kwargs):
+        raise NotImplementedError
+
+    async def run_async(self, file_data: FileData, **kwargs) -> download_responses:
+        # NOTE: Indexer should provide source identifiers required to generate the download path
+        download_path = self.get_download_path(file_data)
+        if download_path is None:
+            logger.error(
+                "Generated download path is None, source_identifiers might be missing"
+                " from FileData."
+            )
+            raise ValueError("Generated invalid download path.")
+
+        await self._download_conversation(file_data, download_path)
+        return self.generate_download_response(file_data, download_path)
+
+    def is_async(self):
+        return True
+
+    async def _download_conversation(self, file_data: FileData, download_path: Path) -> None:
+        # NOTE: Indexer should supply the record locator in metadata
+        if (
+            file_data.metadata.record_locator is None
+            or "channel" not in file_data.metadata.record_locator
+            or "oldest" not in file_data.metadata.record_locator
+            or "latest" not in file_data.metadata.record_locator
+        ):
+            logger.error(
+                f"Invalid record locator in metadata: {file_data.metadata.record_locator}."
+                " Keys 'channel', 'oldest' and 'latest' must be present."
+            )
+            raise ValueError("Invalid record locator.")
+
+        client = self.connection_config.get_async_client()
+        messages = []
+        async for conversation_history in await client.conversations_history(
+            channel=file_data.metadata.record_locator["channel"],
+            oldest=file_data.metadata.record_locator["oldest"],
+            latest=file_data.metadata.record_locator["latest"],
+            limit=PAGINATION_LIMIT,
+            # NOTE: In order to get the exact same range of messages as the indexer, it provides
+            # timestamps of oldest and newest messages, inclusive=True is necessary to include them
+            inclusive=True,
+        ):
+            messages += conversation_history.get("messages", [])
+
+        conversation = []
+        for message in messages:
+            thread_messages = []
+            async for conversations_replies in await client.conversations_replies(
+                channel=file_data.metadata.record_locator["channel"],
+                ts=message["ts"],
+                limit=PAGINATION_LIMIT,
+            ):
+                thread_messages += conversations_replies.get("messages", [])
+
+            # NOTE: Replies contain the whole thread, including the message referenced by the `ts`
+            # parameter even if it's the only message (there were no replies).
+            # Reference: https://api.slack.com/methods/conversations.replies#markdown
+            conversation.append(thread_messages)
+
+        conversation_xml = self._conversation_to_xml(conversation)
+        download_path.parent.mkdir(exist_ok=True, parents=True)
+        conversation_xml.write(download_path, encoding="utf-8", xml_declaration=True)
+
+    def _conversation_to_xml(self, conversation: list[list[dict]]) -> ET.ElementTree:
+        root = ET.Element("messages")
+
+        for thread in conversation:
+            message, *replies = thread
+            message_elem = ET.SubElement(root, "message")
+            text_elem = ET.SubElement(message_elem, "text")
+            text_elem.text = message.get("text")
+
+            for reply in replies:
+                reply_msg = reply.get("text", "")
+                text_elem.text = "".join([str(text_elem.text), " <reply> ", reply_msg])
+
+        return ET.ElementTree(root)
+
+
+slack_source_entry = SourceRegistryEntry(
+    indexer=SlackIndexer,
+    indexer_config=SlackIndexerConfig,
+    downloader=SlackDownloader,
+    downloader_config=DownloaderConfig,
+    connection_config=SlackConnectionConfig,
+)
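
For orientation, a minimal, untested sketch of how the Slack classes introduced above could be wired together by hand (it is not part of the diff). The token and channel ID are placeholders, and the download_dir field passed to SlackDownloaderConfig is an assumption about the inherited DownloaderConfig, which is not shown here.

    import asyncio

    from unstructured_ingest.v2.processes.connectors.slack import (
        SlackAccessConfig,
        SlackConnectionConfig,
        SlackDownloader,
        SlackDownloaderConfig,
        SlackIndexer,
        SlackIndexerConfig,
    )

    connection_config = SlackConnectionConfig(
        access_config=SlackAccessConfig(token="xoxb-placeholder-token"),  # placeholder token
    )
    indexer = SlackIndexer(
        connection_config=connection_config,
        index_config=SlackIndexerConfig(channels=["C0000000000"]),  # placeholder channel ID
    )
    downloader = SlackDownloader(
        connection_config=connection_config,
        download_config=SlackDownloaderConfig(download_dir="slack-downloads"),  # assumed field
    )

    # The indexer groups each channel's messages into FileData batches; the downloader
    # writes each batch to XML via the async path (run() raises NotImplementedError).
    for file_data in indexer.run():
        asyncio.run(downloader.run_async(file_data=file_data))
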
unstructured_ingest/v2/processes/connectors/sql/__init__.py
@@ -2,12 +2,20 @@ from __future__ import annotations
 
 from unstructured_ingest.v2.processes.connector_registry import (
     add_destination_entry,
+    add_source_entry,
 )
 
 from .postgres import CONNECTOR_TYPE as POSTGRES_CONNECTOR_TYPE
-from .postgres import postgres_destination_entry
+from .postgres import postgres_destination_entry, postgres_source_entry
+from .snowflake import CONNECTOR_TYPE as SNOWFLAKE_CONNECTOR_TYPE
+from .snowflake import snowflake_destination_entry, snowflake_source_entry
 from .sqlite import CONNECTOR_TYPE as SQLITE_CONNECTOR_TYPE
-from .sqlite import sqlite_destination_entry
+from .sqlite import sqlite_destination_entry, sqlite_source_entry
+
+add_source_entry(source_type=SQLITE_CONNECTOR_TYPE, entry=sqlite_source_entry)
+add_source_entry(source_type=POSTGRES_CONNECTOR_TYPE, entry=postgres_source_entry)
+add_source_entry(source_type=SNOWFLAKE_CONNECTOR_TYPE, entry=snowflake_source_entry)
 
 add_destination_entry(destination_type=SQLITE_CONNECTOR_TYPE, entry=sqlite_destination_entry)
 add_destination_entry(destination_type=POSTGRES_CONNECTOR_TYPE, entry=postgres_destination_entry)
+add_destination_entry(destination_type=SNOWFLAKE_CONNECTOR_TYPE, entry=snowflake_destination_entry)
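
The __init__ hunk above only registers the new Postgres, SQLite and Snowflake source entries (plus the Snowflake destination entry) under their connector-type strings. As a toy illustration of that registration pattern, with names invented for the sketch rather than taken from the library:

    from dataclasses import dataclass
    from typing import Any, Type


    @dataclass
    class SourceEntry:  # stand-in for the library's SourceRegistryEntry
        indexer: Type[Any]
        indexer_config: Type[Any]
        downloader: Type[Any]
        downloader_config: Type[Any]
        connection_config: Type[Any]


    SOURCE_REGISTRY: dict[str, SourceEntry] = {}


    def add_source_entry(source_type: str, entry: SourceEntry) -> None:
        # Registration is just a keyed insert, so a CLI or pipeline can later
        # resolve a string such as "snowflake" to the classes that implement it.
        SOURCE_REGISTRY[source_type] = entry
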
unstructured_ingest/v2/processes/connectors/sql/postgres.py
@@ -1,15 +1,16 @@
+from contextlib import contextmanager
 from dataclasses import dataclass, field
-from pathlib import Path
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Generator, Optional
 
-import numpy as np
-import pandas as pd
 from pydantic import Field, Secret
 
 from unstructured_ingest.utils.dep_check import requires_dependencies
 from unstructured_ingest.v2.interfaces import FileData
 from unstructured_ingest.v2.logger import logger
-from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry
+from unstructured_ingest.v2.processes.connector_registry import (
+    DestinationRegistryEntry,
+    SourceRegistryEntry,
+)
 from unstructured_ingest.v2.processes.connectors.sql.sql import (
     _DATE_COLUMNS,
     SQLAccessConfig,
@@ -27,6 +28,7 @@ from unstructured_ingest.v2.processes.connectors.sql.sql import (
 
 if TYPE_CHECKING:
     from psycopg2.extensions import connection as PostgresConnection
+    from psycopg2.extensions import cursor as PostgresCursor
 
 CONNECTOR_TYPE = "postgres"
 
@@ -48,18 +50,33 @@ class PostgresConnectionConfig(SQLConnectionConfig):
     port: Optional[int] = Field(default=5432, description="DB host connection port")
     connector_type: str = Field(default=CONNECTOR_TYPE, init=False)
 
+    @contextmanager
     @requires_dependencies(["psycopg2"], extras="postgres")
-    def get_connection(self) -> "PostgresConnection":
+    def get_connection(self) -> Generator["PostgresConnection", None, None]:
         from psycopg2 import connect
 
         access_config = self.access_config.get_secret_value()
-        return connect(
+        connection = connect(
            user=self.username,
            password=access_config.password,
            dbname=self.database,
            host=self.host,
            port=self.port,
        )
+        try:
+            yield connection
+        finally:
+            connection.commit()
+            connection.close()
+
+    @contextmanager
+    def get_cursor(self) -> Generator["PostgresCursor", None, None]:
+        with self.get_connection() as connection:
+            cursor = connection.cursor()
+            try:
+                yield cursor
+            finally:
+                cursor.close()
 
 
 class PostgresIndexerConfig(SQLIndexerConfig):
@@ -72,16 +89,6 @@ class PostgresIndexer(SQLIndexer):
     index_config: PostgresIndexerConfig
     connector_type: str = CONNECTOR_TYPE
 
-    def _get_doc_ids(self) -> list[str]:
-        connection = self.connection_config.get_connection()
-        with connection.cursor() as cursor:
-            cursor.execute(
-                f"SELECT {self.index_config.id_column} FROM {self.index_config.table_name}"
-            )
-            results = cursor.fetchall()
-            ids = [result[0] for result in results]
-        return ids
-
 
 class PostgresDownloaderConfig(SQLDownloaderConfig):
     pass
@@ -97,8 +104,7 @@ class PostgresDownloader(SQLDownloader):
         table_name = file_data.additional_metadata["table_name"]
         id_column = file_data.additional_metadata["id_column"]
         ids = file_data.additional_metadata["ids"]
-        connection = self.connection_config.get_connection()
-        with connection.cursor() as cursor:
+        with self.connection_config.get_cursor() as cursor:
            fields = ",".join(self.download_config.fields) if self.download_config.fields else "*"
            query = "SELECT {fields} FROM {table_name} WHERE {id_column} in ({ids})".format(
                fields=fields,
@@ -130,6 +136,7 @@ class PostgresUploader(SQLUploader):
     upload_config: PostgresUploaderConfig = field(default_factory=PostgresUploaderConfig)
     connection_config: PostgresConnectionConfig
     connector_type: str = CONNECTOR_TYPE
+    values_delimiter: str = "%s"
 
     def prepare_data(
         self, columns: list[str], data: tuple[tuple[Any, ...], ...]
@@ -148,25 +155,14 @@ class PostgresUploader(SQLUploader):
             output.append(tuple(parsed))
         return output
 
-    def upload_contents(self, path: Path) -> None:
-        df = pd.read_json(path, orient="records", lines=True)
-        logger.debug(f"uploading {len(df)} entries to {self.connection_config.database} ")
-        df.replace({np.nan: None}, inplace=True)
-
-        columns = tuple(df.columns)
-        stmt = f"INSERT INTO {self.upload_config.table_name} ({','.join(columns)}) \
-            VALUES({','.join(['%s' for x in columns])})"  # noqa E501
-
-        for rows in pd.read_json(
-            path, orient="records", lines=True, chunksize=self.upload_config.batch_size
-        ):
-            with self.connection_config.get_connection() as conn:
-                values = self.prepare_data(columns, tuple(rows.itertuples(index=False, name=None)))
-                with conn.cursor() as cur:
-                    cur.executemany(stmt, values)
-
-                conn.commit()
 
+postgres_source_entry = SourceRegistryEntry(
+    connection_config=PostgresConnectionConfig,
+    indexer_config=PostgresIndexerConfig,
+    indexer=PostgresIndexer,
+    downloader_config=PostgresDownloaderConfig,
+    downloader=PostgresDownloader,
+)
 
 postgres_destination_entry = DestinationRegistryEntry(
     connection_config=PostgresConnectionConfig,
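
The postgres changes above move connection handling into contextmanager-based get_connection/get_cursor helpers that commit and close on exit, and drop the connector-specific _get_doc_ids and upload_contents bodies in favor of the shared SQL base classes. A self-contained sketch of the same lifecycle pattern, using sqlite3 from the standard library purely for illustration:

    import sqlite3
    from contextlib import contextmanager
    from typing import Generator


    @contextmanager
    def get_connection(db_path: str) -> Generator[sqlite3.Connection, None, None]:
        connection = sqlite3.connect(db_path)
        try:
            yield connection
        finally:
            # Mirrors the diff: commit whatever happened, then always close.
            connection.commit()
            connection.close()


    @contextmanager
    def get_cursor(db_path: str) -> Generator[sqlite3.Cursor, None, None]:
        with get_connection(db_path) as connection:
            cursor = connection.cursor()
            try:
                yield cursor
            finally:
                cursor.close()


    with get_cursor(":memory:") as cursor:
        cursor.execute("SELECT 1;")
        print(cursor.fetchone())  # (1,)
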
unstructured_ingest/v2/processes/connectors/sql/snowflake.py (new file)
@@ -0,0 +1,164 @@
+from contextlib import contextmanager
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import TYPE_CHECKING, Generator, Optional
+
+import numpy as np
+import pandas as pd
+from pydantic import Field, Secret
+
+from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.v2.processes.connector_registry import (
+    DestinationRegistryEntry,
+    SourceRegistryEntry,
+)
+from unstructured_ingest.v2.processes.connectors.sql.postgres import (
+    PostgresDownloader,
+    PostgresDownloaderConfig,
+    PostgresIndexer,
+    PostgresIndexerConfig,
+    PostgresUploader,
+    PostgresUploaderConfig,
+    PostgresUploadStager,
+    PostgresUploadStagerConfig,
+)
+from unstructured_ingest.v2.processes.connectors.sql.sql import SQLAccessConfig, SQLConnectionConfig
+
+if TYPE_CHECKING:
+    from snowflake.connector import SnowflakeConnection
+    from snowflake.connector.cursor import SnowflakeCursor
+
+CONNECTOR_TYPE = "snowflake"
+
+
+class SnowflakeAccessConfig(SQLAccessConfig):
+    password: Optional[str] = Field(default=None, description="DB password")
+
+
+class SnowflakeConnectionConfig(SQLConnectionConfig):
+    access_config: Secret[SnowflakeAccessConfig] = Field(
+        default=SnowflakeAccessConfig(), validate_default=True
+    )
+    account: str = Field(
+        default=None,
+        description="Your account identifier. The account identifier "
+        "does not include the snowflakecomputing.com suffix.",
+    )
+    user: Optional[str] = Field(default=None, description="DB username")
+    host: Optional[str] = Field(default=None, description="DB host")
+    port: Optional[int] = Field(default=443, description="DB host connection port")
+    database: str = Field(
+        default=None,
+        description="Database name.",
+    )
+    schema: str = Field(
+        default=None,
+        description="Database schema.",
+    )
+    role: str = Field(
+        default=None,
+        description="Database role.",
+    )
+    connector_type: str = Field(default=CONNECTOR_TYPE, init=False)
+
+    @contextmanager
+    @requires_dependencies(["snowflake"], extras="snowflake")
+    def get_connection(self) -> Generator["SnowflakeConnection", None, None]:
+        # https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-api#label-snowflake-connector-methods-connect
+        from snowflake.connector import connect
+
+        connect_kwargs = self.model_dump()
+        connect_kwargs.pop("access_configs", None)
+        connect_kwargs["password"] = self.access_config.get_secret_value().password
+        # https://peps.python.org/pep-0249/#paramstyle
+        connect_kwargs["paramstyle"] = "qmark"
+        connection = connect(**connect_kwargs)
+        try:
+            yield connection
+        finally:
+            connection.commit()
+            connection.close()
+
+    @contextmanager
+    def get_cursor(self) -> Generator["SnowflakeCursor", None, None]:
+        with self.get_connection() as connection:
+            cursor = connection.cursor()
+            try:
+                yield cursor
+            finally:
+                cursor.close()
+
+
+class SnowflakeIndexerConfig(PostgresIndexerConfig):
+    pass
+
+
+@dataclass
+class SnowflakeIndexer(PostgresIndexer):
+    connection_config: SnowflakeConnectionConfig
+    index_config: SnowflakeIndexerConfig
+    connector_type: str = CONNECTOR_TYPE
+
+
+class SnowflakeDownloaderConfig(PostgresDownloaderConfig):
+    pass
+
+
+@dataclass
+class SnowflakeDownloader(PostgresDownloader):
+    connection_config: SnowflakeConnectionConfig
+    download_config: SnowflakeDownloaderConfig
+    connector_type: str = CONNECTOR_TYPE
+
+
+class SnowflakeUploadStagerConfig(PostgresUploadStagerConfig):
+    pass
+
+
+class SnowflakeUploadStager(PostgresUploadStager):
+    upload_stager_config: SnowflakeUploadStagerConfig
+
+
+class SnowflakeUploaderConfig(PostgresUploaderConfig):
+    pass
+
+
+@dataclass
+class SnowflakeUploader(PostgresUploader):
+    upload_config: SnowflakeUploaderConfig = field(default_factory=SnowflakeUploaderConfig)
+    connection_config: SnowflakeConnectionConfig
+    connector_type: str = CONNECTOR_TYPE
+    values_delimiter: str = "?"
+
+    def upload_contents(self, path: Path) -> None:
+        df = pd.read_json(path, orient="records", lines=True)
+        df.replace({np.nan: None}, inplace=True)
+
+        columns = list(df.columns)
+        stmt = f"INSERT INTO {self.upload_config.table_name} ({','.join(columns)}) VALUES({','.join([self.values_delimiter for x in columns])})"  # noqa E501
+
+        for rows in pd.read_json(
+            path, orient="records", lines=True, chunksize=self.upload_config.batch_size
+        ):
+            with self.connection_config.get_cursor() as cursor:
+                values = self.prepare_data(columns, tuple(rows.itertuples(index=False, name=None)))
+                # TODO: executemany breaks on 'Binding data in type (list) is not supported'
+                for val in values:
+                    cursor.execute(stmt, val)
+
+
+snowflake_source_entry = SourceRegistryEntry(
+    connection_config=SnowflakeConnectionConfig,
+    indexer_config=SnowflakeIndexerConfig,
+    indexer=SnowflakeIndexer,
+    downloader_config=SnowflakeDownloaderConfig,
+    downloader=SnowflakeDownloader,
+)
+
+snowflake_destination_entry = DestinationRegistryEntry(
+    connection_config=SnowflakeConnectionConfig,
+    uploader=SnowflakeUploader,
+    uploader_config=SnowflakeUploaderConfig,
+    upload_stager=SnowflakeUploadStager,
+    upload_stager_config=SnowflakeUploadStagerConfig,
+)
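
The Snowflake connector mostly reuses the postgres classes; what it overrides is the connection itself, the qmark paramstyle, and a row-by-row upload loop (the TODO notes that executemany currently breaks on list-typed bindings). The values_delimiter attribute is what keeps the shared INSERT template driver-agnostic. A small sketch, with a made-up table and columns, of how the generated statement differs:

    def build_insert(table_name: str, columns: list[str], values_delimiter: str) -> str:
        # Same template the uploaders build above; the table and columns here are invented.
        placeholders = ",".join([values_delimiter for _ in columns])
        return f"INSERT INTO {table_name} ({','.join(columns)}) VALUES({placeholders})"


    columns = ["id", "text", "embeddings"]
    print(build_insert("elements", columns, "%s"))  # postgres:  ...VALUES(%s,%s,%s)
    print(build_insert("elements", columns, "?"))   # snowflake: ...VALUES(?,?,?)
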
unstructured_ingest/v2/processes/connectors/sql/sql.py
@@ -3,12 +3,14 @@ import json
 import sys
 import uuid
 from abc import ABC, abstractmethod
+from contextlib import contextmanager
 from dataclasses import dataclass, field, replace
 from datetime import date, datetime
 from pathlib import Path
 from time import time
 from typing import Any, Generator, Union
 
+import numpy as np
 import pandas as pd
 from dateutil import parser
 from pydantic import Field, Secret
@@ -94,7 +96,13 @@ class SQLConnectionConfig(ConnectionConfig, ABC):
     access_config: Secret[SQLAccessConfig] = Field(default=SQLAccessConfig(), validate_default=True)
 
     @abstractmethod
-    def get_connection(self) -> Any:
+    @contextmanager
+    def get_connection(self) -> Generator[Any, None, None]:
+        pass
+
+    @abstractmethod
+    @contextmanager
+    def get_cursor(self) -> Generator[Any, None, None]:
         pass
 
 
@@ -108,16 +116,19 @@ class SQLIndexer(Indexer, ABC):
     connection_config: SQLConnectionConfig
     index_config: SQLIndexerConfig
 
-    @abstractmethod
     def _get_doc_ids(self) -> list[str]:
-        pass
+        with self.connection_config.get_cursor() as cursor:
+            cursor.execute(
+                f"SELECT {self.index_config.id_column} FROM {self.index_config.table_name}"
+            )
+            results = cursor.fetchall()
+            ids = [result[0] for result in results]
+        return ids
 
     def precheck(self) -> None:
         try:
-            connection = self.connection_config.get_connection()
-            cursor = connection.cursor()
-            cursor.execute("SELECT 1;")
-            cursor.close()
+            with self.connection_config.get_cursor() as cursor:
+                cursor.execute("SELECT 1;")
         except Exception as e:
             logger.error(f"failed to validate connection: {e}", exc_info=True)
             raise SourceConnectionError(f"failed to validate connection: {e}")
@@ -198,7 +209,7 @@
             f"Downloading results from table {table_name} and id {record_id} to {download_path}"
         )
         download_path.parent.mkdir(parents=True, exist_ok=True)
-        result.to_csv(download_path)
+        result.to_csv(download_path, index=False)
         copied_file_data = replace(file_data)
         copied_file_data.identifier = filename_id
         copied_file_data.doc_type = "file"
@@ -285,6 +296,7 @@ class SQLUploaderConfig(UploaderConfig):
 class SQLUploader(Uploader):
     upload_config: SQLUploaderConfig
     connection_config: SQLConnectionConfig
+    values_delimiter: str = "?"
 
     def precheck(self) -> None:
         try:
@@ -302,9 +314,25 @@ class SQLUploader(Uploader):
     ) -> list[tuple[Any, ...]]:
         pass
 
-    @abstractmethod
     def upload_contents(self, path: Path) -> None:
-        pass
+        df = pd.read_json(path, orient="records", lines=True)
+        df.replace({np.nan: None}, inplace=True)
+
+        columns = list(df.columns)
+        stmt = f"INSERT INTO {self.upload_config.table_name} ({','.join(columns)}) VALUES({','.join([self.values_delimiter for x in columns])})"  # noqa E501
+
+        for rows in pd.read_json(
+            path, orient="records", lines=True, chunksize=self.upload_config.batch_size
+        ):
+            with self.connection_config.get_cursor() as cursor:
+                values = self.prepare_data(columns, tuple(rows.itertuples(index=False, name=None)))
+                # for val in values:
+                #     try:
+                #         cursor.execute(stmt, val)
+                #     except Exception as e:
+                #         print(f"Error: {e}")
+                #         print(f"failed to write {len(columns)}, {len(val)}: {stmt} -> {val}")
+                cursor.executemany(stmt, values)
 
     def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
         self.upload_contents(path=path)
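
With upload_contents now defined once on SQLUploader, the staged JSONL file is streamed in batches: pd.read_json(..., lines=True, chunksize=batch_size) yields one DataFrame per batch, and each batch is written with a single executemany inside a cursor context. A self-contained sketch of that flow against sqlite3, with a made-up table and records:

    import sqlite3
    from pathlib import Path

    import pandas as pd

    path = Path("staged.ndjson")
    path.write_text('{"id": 1, "text": "hello"}\n{"id": 2, "text": "world"}\n')

    connection = sqlite3.connect(":memory:")
    connection.execute("CREATE TABLE elements (id INTEGER, text TEXT)")

    columns = list(pd.read_json(path, orient="records", lines=True).columns)
    stmt = f"INSERT INTO elements ({','.join(columns)}) VALUES({','.join(['?' for _ in columns])})"

    # chunksize makes read_json return an iterator of DataFrames instead of one frame
    for rows in pd.read_json(path, orient="records", lines=True, chunksize=1):
        values = [tuple(row) for row in rows.itertuples(index=False, name=None)]
        cursor = connection.cursor()
        cursor.executemany(stmt, values)
        cursor.close()

    connection.commit()
    print(connection.execute("SELECT * FROM elements").fetchall())  # [(1, 'hello'), (2, 'world')]
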