unstructured-ingest 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of unstructured-ingest might be problematic.
Files changed (34)
  1. test/integration/connectors/conftest.py +13 -0
  2. test/integration/connectors/databricks_tests/test_volumes_native.py +8 -4
  3. test/integration/connectors/sql/__init__.py +0 -0
  4. test/integration/connectors/{test_postgres.py → sql/test_postgres.py} +76 -2
  5. test/integration/connectors/sql/test_snowflake.py +205 -0
  6. test/integration/connectors/{test_sqlite.py → sql/test_sqlite.py} +68 -12
  7. test/integration/connectors/test_delta_table.py +138 -0
  8. test/integration/connectors/utils/constants.py +1 -1
  9. test/integration/connectors/utils/docker.py +78 -0
  10. test/integration/connectors/utils/validation.py +100 -4
  11. unstructured_ingest/__version__.py +1 -1
  12. unstructured_ingest/v2/cli/utils/click.py +32 -1
  13. unstructured_ingest/v2/cli/utils/model_conversion.py +10 -3
  14. unstructured_ingest/v2/interfaces/indexer.py +4 -1
  15. unstructured_ingest/v2/pipeline/pipeline.py +10 -2
  16. unstructured_ingest/v2/pipeline/steps/index.py +18 -1
  17. unstructured_ingest/v2/processes/connectors/__init__.py +10 -0
  18. unstructured_ingest/v2/processes/connectors/databricks/volumes.py +1 -1
  19. unstructured_ingest/v2/processes/connectors/delta_table.py +185 -0
  20. unstructured_ingest/v2/processes/connectors/fsspec/fsspec.py +17 -0
  21. unstructured_ingest/v2/processes/connectors/kdbai.py +14 -6
  22. unstructured_ingest/v2/processes/connectors/slack.py +248 -0
  23. unstructured_ingest/v2/processes/connectors/sql/__init__.py +10 -2
  24. unstructured_ingest/v2/processes/connectors/sql/postgres.py +77 -25
  25. unstructured_ingest/v2/processes/connectors/sql/snowflake.py +164 -0
  26. unstructured_ingest/v2/processes/connectors/sql/sql.py +163 -6
  27. unstructured_ingest/v2/processes/connectors/sql/sqlite.py +86 -24
  28. {unstructured_ingest-0.1.0.dist-info → unstructured_ingest-0.2.0.dist-info}/METADATA +16 -14
  29. {unstructured_ingest-0.1.0.dist-info → unstructured_ingest-0.2.0.dist-info}/RECORD +33 -27
  30. unstructured_ingest/v2/processes/connectors/databricks_volumes.py +0 -250
  31. {unstructured_ingest-0.1.0.dist-info → unstructured_ingest-0.2.0.dist-info}/LICENSE.md +0 -0
  32. {unstructured_ingest-0.1.0.dist-info → unstructured_ingest-0.2.0.dist-info}/WHEEL +0 -0
  33. {unstructured_ingest-0.1.0.dist-info → unstructured_ingest-0.2.0.dist-info}/entry_points.txt +0 -0
  34. {unstructured_ingest-0.1.0.dist-info → unstructured_ingest-0.2.0.dist-info}/top_level.txt +0 -0

unstructured_ingest/v2/processes/connectors/fsspec/fsspec.py

@@ -1,5 +1,6 @@
  from __future__ import annotations
 
+ import random
  from dataclasses import dataclass, field
  from pathlib import Path
  from typing import TYPE_CHECKING, Any, Generator, Optional, TypeVar
@@ -63,6 +64,7 @@ class FileConfig(BaseModel):
 
  class FsspecIndexerConfig(FileConfig, IndexerConfig):
      recursive: bool = False
+     sample_n_files: Optional[int] = None
 
 
  class FsspecAccessConfig(AccessConfig):
@@ -128,8 +130,23 @@ class FsspecIndexer(Indexer):
          filtered_files = [
              file for file in files if file.get("size") > 0 and file.get("type") == "file"
          ]
+
+         if self.index_config.sample_n_files:
+             filtered_files = self.sample_n_files(filtered_files, self.index_config.sample_n_files)
+
          return filtered_files
 
+     def sample_n_files(self, files: list[dict[str, Any]], n) -> list[dict[str, Any]]:
+         if len(files) <= n:
+             logger.warning(
+                 f"number of files to be sampled={n} is not smaller than the number"
+                 f" of files found ({len(files)}). Returning all of the files as the"
+                 " sample."
+             )
+             return files
+
+         return random.sample(files, n)
+
      def get_metadata(self, file_data: dict) -> FileDataSourceMetadata:
          raise NotImplementedError()
 
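The new sample_n_files option tells any fsspec-based indexer to hand only a random subset of the discovered files to the rest of the pipeline. A minimal sketch of setting it, assuming remote_url is the required field inherited from FileConfig (the bucket path is a placeholder, not taken from this diff):

from unstructured_ingest.v2.processes.connectors.fsspec.fsspec import FsspecIndexerConfig

# Index recursively, but only process a random sample of 10 of the files found.
# With 10 or fewer files available, the indexer logs a warning and returns them all.
index_config = FsspecIndexerConfig(
    remote_url="s3://example-bucket/docs/",  # assumed FileConfig field; placeholder value
    recursive=True,
    sample_n_files=10,
)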

unstructured_ingest/v2/processes/connectors/kdbai.py

@@ -26,7 +26,7 @@ from unstructured_ingest.v2.processes.connector_registry import (
  )
 
  if TYPE_CHECKING:
-     from kdbai_client import Session, Table
+     from kdbai_client import Database, Session, Table
 
  CONNECTOR_TYPE = "kdbai"
 
@@ -99,6 +99,9 @@ class KdbaiUploadStager(UploadStager):
 
 
  class KdbaiUploaderConfig(UploaderConfig):
+     database_name: str = Field(
+         default="default", description="The name of the KDBAI database to write into."
+     )
      table_name: str = Field(description="The name of the KDBAI table to write into.")
      batch_size: int = Field(default=100, description="Number of records per batch")
 
@@ -111,24 +114,29 @@ class KdbaiUploader(Uploader):
 
      def precheck(self) -> None:
          try:
-             self.get_table()
+             self.get_database()
          except Exception as e:
              logger.error(f"Failed to validate connection {e}", exc_info=True)
              raise DestinationConnectionError(f"failed to validate connection: {e}")
 
-     def get_table(self) -> "Table":
+     def get_database(self) -> "Database":
          session: Session = self.connection_config.get_session()
-         table = session.table(self.upload_config.table_name)
+         db = session.database(self.upload_config.database_name)
+         return db
+
+     def get_table(self) -> "Table":
+         db = self.get_database()
+         table = db.table(self.upload_config.table_name)
          return table
 
      def upsert_batch(self, batch: pd.DataFrame):
          table = self.get_table()
-         table.insert(data=batch)
+         table.insert(batch)
 
      def process_dataframe(self, df: pd.DataFrame):
          logger.debug(
              f"uploading {len(df)} entries to {self.connection_config.endpoint} "
-             f"db in table {self.upload_config.table_name}"
+             f"db {self.upload_config.database_name} in table {self.upload_config.table_name}"
          )
          for _, batch_df in df.groupby(np.arange(len(df)) // self.upload_config.batch_size):
              self.upsert_batch(batch=batch_df)
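The uploader now resolves its table through an explicit database handle (session.database(...).table(...)), so the destination config gains a database_name. A short sketch of the updated config (the table name is a placeholder):

from unstructured_ingest.v2.processes.connectors.kdbai import KdbaiUploaderConfig

upload_config = KdbaiUploaderConfig(
    database_name="default",  # new in 0.2.0; falls back to the "default" database
    table_name="elements",    # placeholder table name
    batch_size=100,
)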

unstructured_ingest/v2/processes/connectors/slack.py (new file)

@@ -0,0 +1,248 @@
+ import hashlib
+ import time
+ import xml.etree.ElementTree as ET
+ from dataclasses import dataclass, field
+ from datetime import datetime
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, Generator, Optional
+
+ from pydantic import Field, Secret
+
+ from unstructured_ingest.error import SourceConnectionError
+ from unstructured_ingest.logger import logger
+ from unstructured_ingest.utils.dep_check import requires_dependencies
+ from unstructured_ingest.v2.interfaces import (
+     AccessConfig,
+     ConnectionConfig,
+     Downloader,
+     DownloaderConfig,
+     Indexer,
+     IndexerConfig,
+     download_responses,
+ )
+ from unstructured_ingest.v2.interfaces.file_data import (
+     FileData,
+     FileDataSourceMetadata,
+     SourceIdentifiers,
+ )
+ from unstructured_ingest.v2.processes.connector_registry import SourceRegistryEntry
+
+ if TYPE_CHECKING:
+     from slack_sdk import WebClient
+     from slack_sdk.web.async_client import AsyncWebClient
+
+ # NOTE: Pagination limit set to the upper end of the recommended range
+ # https://api.slack.com/apis/pagination#facts
+ PAGINATION_LIMIT = 200
+
+ CONNECTOR_TYPE = "slack"
+
+
+ class SlackAccessConfig(AccessConfig):
+     token: str = Field(
+         description="Bot token used to access Slack API, must have channels:history scope for the"
+         " bot user."
+     )
+
+
+ class SlackConnectionConfig(ConnectionConfig):
+     access_config: Secret[SlackAccessConfig]
+
+     @requires_dependencies(["slack_sdk"], extras="slack")
+     @SourceConnectionError.wrap
+     def get_client(self) -> "WebClient":
+         from slack_sdk import WebClient
+
+         return WebClient(token=self.access_config.get_secret_value().token)
+
+     @requires_dependencies(["slack_sdk"], extras="slack")
+     @SourceConnectionError.wrap
+     def get_async_client(self) -> "AsyncWebClient":
+         from slack_sdk.web.async_client import AsyncWebClient
+
+         return AsyncWebClient(token=self.access_config.get_secret_value().token)
+
+
+ class SlackIndexerConfig(IndexerConfig):
+     channels: list[str] = Field(
+         description="Comma-delimited list of Slack channel IDs to pull messages from, can be"
+         " both public or private channels."
+     )
+     start_date: Optional[datetime] = Field(
+         default=None,
+         description="Start date/time in formats YYYY-MM-DD[T]HH:MM[:SS[.ffffff]][Z or [±]HH[:]MM]"
+         " or YYYY-MM-DD",
+     )
+     end_date: Optional[datetime] = Field(
+         default=None,
+         description="End date/time in formats YYYY-MM-DD[T]HH:MM[:SS[.ffffff]][Z or [±]HH[:]MM]"
+         " or YYYY-MM-DD",
+     )
+
+
+ @dataclass
+ class SlackIndexer(Indexer):
+     index_config: SlackIndexerConfig
+     connection_config: SlackConnectionConfig
+     connector_type: str = CONNECTOR_TYPE
+
+     def run(self, **kwargs: Any) -> Generator[FileData, None, None]:
+         client = self.connection_config.get_client()
+         for channel in self.index_config.channels:
+             messages = []
+             oldest = (
+                 str(self.index_config.start_date.timestamp())
+                 if self.index_config.start_date is not None
+                 else None
+             )
+             latest = (
+                 str(self.index_config.end_date.timestamp())
+                 if self.index_config.end_date is not None
+                 else None
+             )
+             for conversation_history in client.conversations_history(
+                 channel=channel,
+                 oldest=oldest,
+                 latest=latest,
+                 limit=PAGINATION_LIMIT,
+             ):
+                 messages = conversation_history.get("messages", [])
+                 if messages:
+                     yield self._messages_to_file_data(messages, channel)
+
+     def _messages_to_file_data(
+         self,
+         messages: list[dict],
+         channel: str,
+     ) -> FileData:
+         ts_oldest = min((message["ts"] for message in messages), key=lambda m: float(m))
+         ts_newest = max((message["ts"] for message in messages), key=lambda m: float(m))
+
+         identifier_base = f"{channel}-{ts_oldest}-{ts_newest}"
+         identifier = hashlib.sha256(identifier_base.encode("utf-8")).hexdigest()
+         filename = identifier[:16]
+
+         return FileData(
+             identifier=identifier,
+             connector_type=CONNECTOR_TYPE,
+             source_identifiers=SourceIdentifiers(
+                 filename=f"{filename}.xml", fullpath=f"{filename}.xml"
+             ),
+             metadata=FileDataSourceMetadata(
+                 date_created=ts_oldest,
+                 date_modified=ts_newest,
+                 date_processed=str(time.time()),
+                 record_locator={
+                     "channel": channel,
+                     "oldest": ts_oldest,
+                     "latest": ts_newest,
+                 },
+             ),
+         )
+
+     @SourceConnectionError.wrap
+     def precheck(self) -> None:
+         client = self.connection_config.get_client()
+         for channel in self.index_config.channels:
+             # NOTE: Querying conversations history guarantees that the bot is in the channel
+             client.conversations_history(channel=channel, limit=1)
+
+
+ class SlackDownloaderConfig(DownloaderConfig):
+     pass
+
+
+ @dataclass
+ class SlackDownloader(Downloader):
+     connector_type: str = CONNECTOR_TYPE
+     connection_config: SlackConnectionConfig
+     download_config: SlackDownloaderConfig = field(default_factory=SlackDownloaderConfig)
+
+     def run(self, file_data, **kwargs):
+         raise NotImplementedError
+
+     async def run_async(self, file_data: FileData, **kwargs) -> download_responses:
+         # NOTE: Indexer should provide source identifiers required to generate the download path
+         download_path = self.get_download_path(file_data)
+         if download_path is None:
+             logger.error(
+                 "Generated download path is None, source_identifiers might be missing"
+                 "from FileData."
+             )
+             raise ValueError("Generated invalid download path.")
+
+         await self._download_conversation(file_data, download_path)
+         return self.generate_download_response(file_data, download_path)
+
+     def is_async(self):
+         return True
+
+     async def _download_conversation(self, file_data: FileData, download_path: Path) -> None:
+         # NOTE: Indexer should supply the record locator in metadata
+         if (
+             file_data.metadata.record_locator is None
+             or "channel" not in file_data.metadata.record_locator
+             or "oldest" not in file_data.metadata.record_locator
+             or "latest" not in file_data.metadata.record_locator
+         ):
+             logger.error(
+                 f"Invalid record locator in metadata: {file_data.metadata.record_locator}."
+                 "Keys 'channel', 'oldest' and 'latest' must be present."
+             )
+             raise ValueError("Invalid record locator.")
+
+         client = self.connection_config.get_async_client()
+         messages = []
+         async for conversation_history in await client.conversations_history(
+             channel=file_data.metadata.record_locator["channel"],
+             oldest=file_data.metadata.record_locator["oldest"],
+             latest=file_data.metadata.record_locator["latest"],
+             limit=PAGINATION_LIMIT,
+             # NOTE: In order to get the exact same range of messages as indexer, it provides
+             # timestamps of oldest and newest messages, inclusive=True is necessary to include them
+             inclusive=True,
+         ):
+             messages += conversation_history.get("messages", [])
+
+         conversation = []
+         for message in messages:
+             thread_messages = []
+             async for conversations_replies in await client.conversations_replies(
+                 channel=file_data.metadata.record_locator["channel"],
+                 ts=message["ts"],
+                 limit=PAGINATION_LIMIT,
+             ):
+                 thread_messages += conversations_replies.get("messages", [])
+
+             # NOTE: Replies contains the whole thread, including the message references by the `ts`
+             # parameter even if it's the only message (there were no replies).
+             # Reference: https://api.slack.com/methods/conversations.replies#markdown
+             conversation.append(thread_messages)
+
+         conversation_xml = self._conversation_to_xml(conversation)
+         download_path.parent.mkdir(exist_ok=True, parents=True)
+         conversation_xml.write(download_path, encoding="utf-8", xml_declaration=True)
+
+     def _conversation_to_xml(self, conversation: list[list[dict]]) -> ET.ElementTree:
+         root = ET.Element("messages")
+
+         for thread in conversation:
+             message, *replies = thread
+             message_elem = ET.SubElement(root, "message")
+             text_elem = ET.SubElement(message_elem, "text")
+             text_elem.text = message.get("text")
+
+             for reply in replies:
+                 reply_msg = reply.get("text", "")
+                 text_elem.text = "".join([str(text_elem.text), " <reply> ", reply_msg])
+
+         return ET.ElementTree(root)
+
+
+ slack_source_entry = SourceRegistryEntry(
+     indexer=SlackIndexer,
+     indexer_config=SlackIndexerConfig,
+     downloader=SlackDownloader,
+     downloader_config=DownloaderConfig,
+     connection_config=SlackConnectionConfig,
+ )
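A minimal sketch of driving the new Slack source directly; the token and channel ID are placeholders, and the bot token needs the channels:history scope described in SlackAccessConfig above:

from datetime import datetime

from unstructured_ingest.v2.processes.connectors.slack import (
    SlackAccessConfig,
    SlackConnectionConfig,
    SlackIndexer,
    SlackIndexerConfig,
)

indexer = SlackIndexer(
    connection_config=SlackConnectionConfig(
        access_config=SlackAccessConfig(token="xoxb-placeholder"),
    ),
    index_config=SlackIndexerConfig(
        channels=["C0123456789"],
        start_date=datetime(2024, 1, 1),
    ),
)

# Each yielded FileData covers one fetched page of a channel's messages; the downloader
# later re-queries that range and writes the threads out as XML.
for file_data in indexer.run():
    print(file_data.identifier, file_data.metadata.record_locator)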

unstructured_ingest/v2/processes/connectors/sql/__init__.py

@@ -2,12 +2,20 @@ from __future__ import annotations
 
  from unstructured_ingest.v2.processes.connector_registry import (
      add_destination_entry,
+     add_source_entry,
  )
 
  from .postgres import CONNECTOR_TYPE as POSTGRES_CONNECTOR_TYPE
- from .postgres import postgres_destination_entry
+ from .postgres import postgres_destination_entry, postgres_source_entry
+ from .snowflake import CONNECTOR_TYPE as SNOWFLAKE_CONNECTOR_TYPE
+ from .snowflake import snowflake_destination_entry, snowflake_source_entry
  from .sqlite import CONNECTOR_TYPE as SQLITE_CONNECTOR_TYPE
- from .sqlite import sqlite_destination_entry
+ from .sqlite import sqlite_destination_entry, sqlite_source_entry
+
+ add_source_entry(source_type=SQLITE_CONNECTOR_TYPE, entry=sqlite_source_entry)
+ add_source_entry(source_type=POSTGRES_CONNECTOR_TYPE, entry=postgres_source_entry)
+ add_source_entry(source_type=SNOWFLAKE_CONNECTOR_TYPE, entry=snowflake_source_entry)
 
  add_destination_entry(destination_type=SQLITE_CONNECTOR_TYPE, entry=sqlite_destination_entry)
  add_destination_entry(destination_type=POSTGRES_CONNECTOR_TYPE, entry=postgres_destination_entry)
+ add_destination_entry(destination_type=SNOWFLAKE_CONNECTOR_TYPE, entry=snowflake_destination_entry)

unstructured_ingest/v2/processes/connectors/sql/postgres.py

@@ -1,18 +1,24 @@
+ from contextlib import contextmanager
  from dataclasses import dataclass, field
- from pathlib import Path
- from typing import TYPE_CHECKING, Any, Optional
+ from typing import TYPE_CHECKING, Any, Generator, Optional
 
- import numpy as np
- import pandas as pd
  from pydantic import Field, Secret
 
  from unstructured_ingest.utils.dep_check import requires_dependencies
+ from unstructured_ingest.v2.interfaces import FileData
  from unstructured_ingest.v2.logger import logger
- from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry
+ from unstructured_ingest.v2.processes.connector_registry import (
+     DestinationRegistryEntry,
+     SourceRegistryEntry,
+ )
  from unstructured_ingest.v2.processes.connectors.sql.sql import (
      _DATE_COLUMNS,
      SQLAccessConfig,
      SQLConnectionConfig,
+     SQLDownloader,
+     SQLDownloaderConfig,
+     SQLIndexer,
+     SQLIndexerConfig,
      SQLUploader,
      SQLUploaderConfig,
      SQLUploadStager,
@@ -22,6 +28,7 @@ from unstructured_ingest.v2.processes.connectors.sql.sql import (
 
  if TYPE_CHECKING:
      from psycopg2.extensions import connection as PostgresConnection
+     from psycopg2.extensions import cursor as PostgresCursor
 
  CONNECTOR_TYPE = "postgres"
 
@@ -43,18 +50,73 @@ class PostgresConnectionConfig(SQLConnectionConfig):
      port: Optional[int] = Field(default=5432, description="DB host connection port")
      connector_type: str = Field(default=CONNECTOR_TYPE, init=False)
 
+     @contextmanager
      @requires_dependencies(["psycopg2"], extras="postgres")
-     def get_connection(self) -> "PostgresConnection":
+     def get_connection(self) -> Generator["PostgresConnection", None, None]:
          from psycopg2 import connect
 
          access_config = self.access_config.get_secret_value()
-         return connect(
+         connection = connect(
              user=self.username,
              password=access_config.password,
             dbname=self.database,
             host=self.host,
             port=self.port,
         )
+         try:
+             yield connection
+         finally:
+             connection.commit()
+             connection.close()
+
+     @contextmanager
+     def get_cursor(self) -> Generator["PostgresCursor", None, None]:
+         with self.get_connection() as connection:
+             cursor = connection.cursor()
+             try:
+                 yield cursor
+             finally:
+                 cursor.close()
+
+
+ class PostgresIndexerConfig(SQLIndexerConfig):
+     pass
+
+
+ @dataclass
+ class PostgresIndexer(SQLIndexer):
+     connection_config: PostgresConnectionConfig
+     index_config: PostgresIndexerConfig
+     connector_type: str = CONNECTOR_TYPE
+
+
+ class PostgresDownloaderConfig(SQLDownloaderConfig):
+     pass
+
+
+ @dataclass
+ class PostgresDownloader(SQLDownloader):
+     connection_config: PostgresConnectionConfig
+     download_config: PostgresDownloaderConfig
+     connector_type: str = CONNECTOR_TYPE
+
+     def query_db(self, file_data: FileData) -> tuple[list[tuple], list[str]]:
+         table_name = file_data.additional_metadata["table_name"]
+         id_column = file_data.additional_metadata["id_column"]
+         ids = file_data.additional_metadata["ids"]
+         with self.connection_config.get_cursor() as cursor:
+             fields = ",".join(self.download_config.fields) if self.download_config.fields else "*"
+             query = "SELECT {fields} FROM {table_name} WHERE {id_column} in ({ids})".format(
+                 fields=fields,
+                 table_name=table_name,
+                 id_column=id_column,
+                 ids=",".join([str(i) for i in ids]),
+             )
+             logger.debug(f"running query: {query}")
+             cursor.execute(query)
+             rows = cursor.fetchall()
+             columns = [col[0] for col in cursor.description]
+             return rows, columns
 
 
  class PostgresUploadStagerConfig(SQLUploadStagerConfig):
@@ -74,6 +136,7 @@ class PostgresUploader(SQLUploader):
      upload_config: PostgresUploaderConfig = field(default_factory=PostgresUploaderConfig)
      connection_config: PostgresConnectionConfig
      connector_type: str = CONNECTOR_TYPE
+     values_delimiter: str = "%s"
 
      def prepare_data(
          self, columns: list[str], data: tuple[tuple[Any, ...], ...]
@@ -92,25 +155,14 @@ class PostgresUploader(SQLUploader):
              output.append(tuple(parsed))
          return output
 
-     def upload_contents(self, path: Path) -> None:
-         df = pd.read_json(path, orient="records", lines=True)
-         logger.debug(f"uploading {len(df)} entries to {self.connection_config.database} ")
-         df.replace({np.nan: None}, inplace=True)
-
-         columns = tuple(df.columns)
-         stmt = f"INSERT INTO {self.upload_config.table_name} ({','.join(columns)}) \
-             VALUES({','.join(['%s' for x in columns])})"  # noqa E501
-
-         for rows in pd.read_json(
-             path, orient="records", lines=True, chunksize=self.upload_config.batch_size
-         ):
-             with self.connection_config.get_connection() as conn:
-                 values = self.prepare_data(columns, tuple(rows.itertuples(index=False, name=None)))
-                 with conn.cursor() as cur:
-                     cur.executemany(stmt, values)
-
-             conn.commit()
 
+ postgres_source_entry = SourceRegistryEntry(
+     connection_config=PostgresConnectionConfig,
+     indexer_config=PostgresIndexerConfig,
+     indexer=PostgresIndexer,
+     downloader_config=PostgresDownloaderConfig,
+     downloader=PostgresDownloader,
+ )
 
  postgres_destination_entry = DestinationRegistryEntry(
      connection_config=PostgresConnectionConfig,
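get_connection now yields a context-managed connection (committed and closed on exit) and the new get_cursor helper wraps cursor cleanup, replacing the raw connection object returned in 0.1.0. A sketch of the pattern, assuming connection_config is a PostgresConnectionConfig built elsewhere and the table and columns are placeholders:

# connection_config: a PostgresConnectionConfig constructed with host, port, database and credentials
with connection_config.get_cursor() as cursor:
    cursor.execute("SELECT id, text FROM elements LIMIT 5")  # placeholder table and columns
    rows = cursor.fetchall()
    columns = [col[0] for col in cursor.description]
# Leaving the block closes the cursor, commits, and closes the connection.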

unstructured_ingest/v2/processes/connectors/sql/snowflake.py (new file)

@@ -0,0 +1,164 @@
+ from contextlib import contextmanager
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Generator, Optional
+
+ import numpy as np
+ import pandas as pd
+ from pydantic import Field, Secret
+
+ from unstructured_ingest.utils.dep_check import requires_dependencies
+ from unstructured_ingest.v2.processes.connector_registry import (
+     DestinationRegistryEntry,
+     SourceRegistryEntry,
+ )
+ from unstructured_ingest.v2.processes.connectors.sql.postgres import (
+     PostgresDownloader,
+     PostgresDownloaderConfig,
+     PostgresIndexer,
+     PostgresIndexerConfig,
+     PostgresUploader,
+     PostgresUploaderConfig,
+     PostgresUploadStager,
+     PostgresUploadStagerConfig,
+ )
+ from unstructured_ingest.v2.processes.connectors.sql.sql import SQLAccessConfig, SQLConnectionConfig
+
+ if TYPE_CHECKING:
+     from snowflake.connector import SnowflakeConnection
+     from snowflake.connector.cursor import SnowflakeCursor
+
+ CONNECTOR_TYPE = "snowflake"
+
+
+ class SnowflakeAccessConfig(SQLAccessConfig):
+     password: Optional[str] = Field(default=None, description="DB password")
+
+
+ class SnowflakeConnectionConfig(SQLConnectionConfig):
+     access_config: Secret[SnowflakeAccessConfig] = Field(
+         default=SnowflakeAccessConfig(), validate_default=True
+     )
+     account: str = Field(
+         default=None,
+         description="Your account identifier. The account identifier "
+         "does not include the snowflakecomputing.com suffix.",
+     )
+     user: Optional[str] = Field(default=None, description="DB username")
+     host: Optional[str] = Field(default=None, description="DB host")
+     port: Optional[int] = Field(default=443, description="DB host connection port")
+     database: str = Field(
+         default=None,
+         description="Database name.",
+     )
+     schema: str = Field(
+         default=None,
+         description="Database schema.",
+     )
+     role: str = Field(
+         default=None,
+         description="Database role.",
+     )
+     connector_type: str = Field(default=CONNECTOR_TYPE, init=False)
+
+     @contextmanager
+     @requires_dependencies(["snowflake"], extras="snowflake")
+     def get_connection(self) -> Generator["SnowflakeConnection", None, None]:
+         # https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-api#label-snowflake-connector-methods-connect
+         from snowflake.connector import connect
+
+         connect_kwargs = self.model_dump()
+         connect_kwargs.pop("access_configs", None)
+         connect_kwargs["password"] = self.access_config.get_secret_value().password
+         # https://peps.python.org/pep-0249/#paramstyle
+         connect_kwargs["paramstyle"] = "qmark"
+         connection = connect(**connect_kwargs)
+         try:
+             yield connection
+         finally:
+             connection.commit()
+             connection.close()
+
+     @contextmanager
+     def get_cursor(self) -> Generator["SnowflakeCursor", None, None]:
+         with self.get_connection() as connection:
+             cursor = connection.cursor()
+             try:
+                 yield cursor
+             finally:
+                 cursor.close()
+
+
+ class SnowflakeIndexerConfig(PostgresIndexerConfig):
+     pass
+
+
+ @dataclass
+ class SnowflakeIndexer(PostgresIndexer):
+     connection_config: SnowflakeConnectionConfig
+     index_config: SnowflakeIndexerConfig
+     connector_type: str = CONNECTOR_TYPE
+
+
+ class SnowflakeDownloaderConfig(PostgresDownloaderConfig):
+     pass
+
+
+ @dataclass
+ class SnowflakeDownloader(PostgresDownloader):
+     connection_config: SnowflakeConnectionConfig
+     download_config: SnowflakeDownloaderConfig
+     connector_type: str = CONNECTOR_TYPE
+
+
+ class SnowflakeUploadStagerConfig(PostgresUploadStagerConfig):
+     pass
+
+
+ class SnowflakeUploadStager(PostgresUploadStager):
+     upload_stager_config: SnowflakeUploadStagerConfig
+
+
+ class SnowflakeUploaderConfig(PostgresUploaderConfig):
+     pass
+
+
+ @dataclass
+ class SnowflakeUploader(PostgresUploader):
+     upload_config: SnowflakeUploaderConfig = field(default_factory=SnowflakeUploaderConfig)
+     connection_config: SnowflakeConnectionConfig
+     connector_type: str = CONNECTOR_TYPE
+     values_delimiter: str = "?"
+
+     def upload_contents(self, path: Path) -> None:
+         df = pd.read_json(path, orient="records", lines=True)
+         df.replace({np.nan: None}, inplace=True)
+
+         columns = list(df.columns)
+         stmt = f"INSERT INTO {self.upload_config.table_name} ({','.join(columns)}) VALUES({','.join([self.values_delimiter for x in columns])})"  # noqa E501
+
+         for rows in pd.read_json(
+             path, orient="records", lines=True, chunksize=self.upload_config.batch_size
+         ):
+             with self.connection_config.get_cursor() as cursor:
+                 values = self.prepare_data(columns, tuple(rows.itertuples(index=False, name=None)))
+                 # TODO: executemany break on 'Binding data in type (list) is not supported'
+                 for val in values:
+                     cursor.execute(stmt, val)
+
+
+ snowflake_source_entry = SourceRegistryEntry(
+     connection_config=SnowflakeConnectionConfig,
+     indexer_config=SnowflakeIndexerConfig,
+     indexer=SnowflakeIndexer,
+     downloader_config=SnowflakeDownloaderConfig,
+     downloader=SnowflakeDownloader,
+ )
+
+ snowflake_destination_entry = DestinationRegistryEntry(
+     connection_config=SnowflakeConnectionConfig,
+     uploader=SnowflakeUploader,
+     uploader_config=SnowflakeUploaderConfig,
+     upload_stager=SnowflakeUploadStager,
+     upload_stager_config=SnowflakeUploadStagerConfig,
+ )
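A sketch of wiring up the new Snowflake destination; every identifier and credential below is a placeholder, and the fields mirror SnowflakeConnectionConfig and the uploader config shown above:

from unstructured_ingest.v2.processes.connectors.sql.snowflake import (
    SnowflakeAccessConfig,
    SnowflakeConnectionConfig,
    SnowflakeUploader,
    SnowflakeUploaderConfig,
)

uploader = SnowflakeUploader(
    connection_config=SnowflakeConnectionConfig(
        account="myorg-myaccount",
        user="ingest_user",
        database="INGEST_DB",
        schema="PUBLIC",
        role="INGEST_ROLE",
        access_config=SnowflakeAccessConfig(password="***"),
    ),
    upload_config=SnowflakeUploaderConfig(table_name="elements", batch_size=50),
)
# upload_contents binds values with "?" (qmark paramstyle) and inserts row by row,
# per the TODO note about executemany above.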