unstructured-ingest 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of unstructured-ingest might be problematic.
- test/integration/connectors/conftest.py +13 -0
- test/integration/connectors/databricks_tests/test_volumes_native.py +8 -4
- test/integration/connectors/sql/__init__.py +0 -0
- test/integration/connectors/{test_postgres.py → sql/test_postgres.py} +76 -2
- test/integration/connectors/sql/test_snowflake.py +205 -0
- test/integration/connectors/{test_sqlite.py → sql/test_sqlite.py} +68 -12
- test/integration/connectors/test_delta_table.py +138 -0
- test/integration/connectors/utils/constants.py +1 -1
- test/integration/connectors/utils/docker.py +78 -0
- test/integration/connectors/utils/validation.py +100 -4
- unstructured_ingest/__version__.py +1 -1
- unstructured_ingest/v2/cli/utils/click.py +32 -1
- unstructured_ingest/v2/cli/utils/model_conversion.py +10 -3
- unstructured_ingest/v2/interfaces/indexer.py +4 -1
- unstructured_ingest/v2/pipeline/pipeline.py +10 -2
- unstructured_ingest/v2/pipeline/steps/index.py +18 -1
- unstructured_ingest/v2/processes/connectors/__init__.py +10 -0
- unstructured_ingest/v2/processes/connectors/databricks/volumes.py +1 -1
- unstructured_ingest/v2/processes/connectors/delta_table.py +185 -0
- unstructured_ingest/v2/processes/connectors/fsspec/fsspec.py +17 -0
- unstructured_ingest/v2/processes/connectors/kdbai.py +14 -6
- unstructured_ingest/v2/processes/connectors/slack.py +248 -0
- unstructured_ingest/v2/processes/connectors/sql/__init__.py +10 -2
- unstructured_ingest/v2/processes/connectors/sql/postgres.py +77 -25
- unstructured_ingest/v2/processes/connectors/sql/snowflake.py +164 -0
- unstructured_ingest/v2/processes/connectors/sql/sql.py +163 -6
- unstructured_ingest/v2/processes/connectors/sql/sqlite.py +86 -24
- {unstructured_ingest-0.1.0.dist-info → unstructured_ingest-0.2.0.dist-info}/METADATA +16 -14
- {unstructured_ingest-0.1.0.dist-info → unstructured_ingest-0.2.0.dist-info}/RECORD +33 -27
- unstructured_ingest/v2/processes/connectors/databricks_volumes.py +0 -250
- {unstructured_ingest-0.1.0.dist-info → unstructured_ingest-0.2.0.dist-info}/LICENSE.md +0 -0
- {unstructured_ingest-0.1.0.dist-info → unstructured_ingest-0.2.0.dist-info}/WHEEL +0 -0
- {unstructured_ingest-0.1.0.dist-info → unstructured_ingest-0.2.0.dist-info}/entry_points.txt +0 -0
- {unstructured_ingest-0.1.0.dist-info → unstructured_ingest-0.2.0.dist-info}/top_level.txt +0 -0
unstructured_ingest/v2/processes/connectors/fsspec/fsspec.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import random
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Generator, Optional, TypeVar
@@ -63,6 +64,7 @@ class FileConfig(BaseModel):
 
 class FsspecIndexerConfig(FileConfig, IndexerConfig):
     recursive: bool = False
+    sample_n_files: Optional[int] = None
 
 
 class FsspecAccessConfig(AccessConfig):
@@ -128,8 +130,23 @@ class FsspecIndexer(Indexer):
         filtered_files = [
             file for file in files if file.get("size") > 0 and file.get("type") == "file"
         ]
+
+        if self.index_config.sample_n_files:
+            filtered_files = self.sample_n_files(filtered_files, self.index_config.sample_n_files)
+
         return filtered_files
 
+    def sample_n_files(self, files: list[dict[str, Any]], n) -> list[dict[str, Any]]:
+        if len(files) <= n:
+            logger.warning(
+                f"number of files to be sampled={n} is not smaller than the number"
+                f" of files found ({len(files)}). Returning all of the files as the"
+                " sample."
+            )
+            return files
+
+        return random.sample(files, n)
+
     def get_metadata(self, file_data: dict) -> FileDataSourceMetadata:
         raise NotImplementedError()
 
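A minimal, standalone sketch (not part of the diff) of the sampling behaviour the new sample_n_files option adds to FsspecIndexer: when the requested sample size is smaller than the candidate list, a random subset is drawn, otherwise all files are returned. The helper name and the candidate data below are illustrative only.

import random
from typing import Any, Optional

def sample_files(files: list[dict[str, Any]], sample_n_files: Optional[int]) -> list[dict[str, Any]]:
    # Mirrors FsspecIndexer.sample_n_files: keep the full list when the sample
    # size is unset or not smaller than the number of candidates.
    if not sample_n_files or len(files) <= sample_n_files:
        return files
    return random.sample(files, sample_n_files)

candidates = [{"name": f"doc-{i}.pdf", "size": 1, "type": "file"} for i in range(10)]
print(len(sample_files(candidates, sample_n_files=3)))  # -> 3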
unstructured_ingest/v2/processes/connectors/kdbai.py
@@ -26,7 +26,7 @@ from unstructured_ingest.v2.processes.connector_registry import (
 )
 
 if TYPE_CHECKING:
-    from kdbai_client import Session, Table
+    from kdbai_client import Database, Session, Table
 
 CONNECTOR_TYPE = "kdbai"
 
@@ -99,6 +99,9 @@ class KdbaiUploadStager(UploadStager):
 
 
 class KdbaiUploaderConfig(UploaderConfig):
+    database_name: str = Field(
+        default="default", description="The name of the KDBAI database to write into."
+    )
     table_name: str = Field(description="The name of the KDBAI table to write into.")
     batch_size: int = Field(default=100, description="Number of records per batch")
 
@@ -111,24 +114,29 @@ class KdbaiUploader(Uploader):
 
     def precheck(self) -> None:
        try:
-            self.
+            self.get_database()
         except Exception as e:
             logger.error(f"Failed to validate connection {e}", exc_info=True)
             raise DestinationConnectionError(f"failed to validate connection: {e}")
 
-    def
+    def get_database(self) -> "Database":
         session: Session = self.connection_config.get_session()
-
+        db = session.database(self.upload_config.database_name)
+        return db
+
+    def get_table(self) -> "Table":
+        db = self.get_database()
+        table = db.table(self.upload_config.table_name)
         return table
 
     def upsert_batch(self, batch: pd.DataFrame):
         table = self.get_table()
-        table.insert(
+        table.insert(batch)
 
     def process_dataframe(self, df: pd.DataFrame):
         logger.debug(
             f"uploading {len(df)} entries to {self.connection_config.endpoint} "
-            f"db in table {self.upload_config.table_name}"
+            f"db {self.upload_config.database_name} in table {self.upload_config.table_name}"
         )
         for _, batch_df in df.groupby(np.arange(len(df)) // self.upload_config.batch_size):
             self.upsert_batch(batch=batch_df)
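A hedged sketch (not part of the diff) of how the new database_name setting flows through the uploader: the session resolves a database first, then the table inside it. The helper name and the session argument are illustrative; the kdbai_client calls are limited to the ones that appear in the hunk above.

import pandas as pd

def upsert(session, database_name: str, table_name: str, batch: pd.DataFrame) -> None:
    # session is assumed to be a kdbai_client.Session, e.g. as returned by the
    # connector's connection config; database()/table()/insert() mirror the diff.
    db = session.database(database_name)
    table = db.table(table_name)
    table.insert(batch)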
unstructured_ingest/v2/processes/connectors/slack.py (new file)
@@ -0,0 +1,248 @@
+import hashlib
+import time
+import xml.etree.ElementTree as ET
+from dataclasses import dataclass, field
+from datetime import datetime
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Generator, Optional
+
+from pydantic import Field, Secret
+
+from unstructured_ingest.error import SourceConnectionError
+from unstructured_ingest.logger import logger
+from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.v2.interfaces import (
+    AccessConfig,
+    ConnectionConfig,
+    Downloader,
+    DownloaderConfig,
+    Indexer,
+    IndexerConfig,
+    download_responses,
+)
+from unstructured_ingest.v2.interfaces.file_data import (
+    FileData,
+    FileDataSourceMetadata,
+    SourceIdentifiers,
+)
+from unstructured_ingest.v2.processes.connector_registry import SourceRegistryEntry
+
+if TYPE_CHECKING:
+    from slack_sdk import WebClient
+    from slack_sdk.web.async_client import AsyncWebClient
+
+# NOTE: Pagination limit set to the upper end of the recommended range
+# https://api.slack.com/apis/pagination#facts
+PAGINATION_LIMIT = 200
+
+CONNECTOR_TYPE = "slack"
+
+
+class SlackAccessConfig(AccessConfig):
+    token: str = Field(
+        description="Bot token used to access Slack API, must have channels:history scope for the"
+        " bot user."
+    )
+
+
+class SlackConnectionConfig(ConnectionConfig):
+    access_config: Secret[SlackAccessConfig]
+
+    @requires_dependencies(["slack_sdk"], extras="slack")
+    @SourceConnectionError.wrap
+    def get_client(self) -> "WebClient":
+        from slack_sdk import WebClient
+
+        return WebClient(token=self.access_config.get_secret_value().token)
+
+    @requires_dependencies(["slack_sdk"], extras="slack")
+    @SourceConnectionError.wrap
+    def get_async_client(self) -> "AsyncWebClient":
+        from slack_sdk.web.async_client import AsyncWebClient
+
+        return AsyncWebClient(token=self.access_config.get_secret_value().token)
+
+
+class SlackIndexerConfig(IndexerConfig):
+    channels: list[str] = Field(
+        description="Comma-delimited list of Slack channel IDs to pull messages from, can be"
+        " both public or private channels."
+    )
+    start_date: Optional[datetime] = Field(
+        default=None,
+        description="Start date/time in formats YYYY-MM-DD[T]HH:MM[:SS[.ffffff]][Z or [±]HH[:]MM]"
+        " or YYYY-MM-DD",
+    )
+    end_date: Optional[datetime] = Field(
+        default=None,
+        description="End date/time in formats YYYY-MM-DD[T]HH:MM[:SS[.ffffff]][Z or [±]HH[:]MM]"
+        " or YYYY-MM-DD",
+    )
+
+
+@dataclass
+class SlackIndexer(Indexer):
+    index_config: SlackIndexerConfig
+    connection_config: SlackConnectionConfig
+    connector_type: str = CONNECTOR_TYPE
+
+    def run(self, **kwargs: Any) -> Generator[FileData, None, None]:
+        client = self.connection_config.get_client()
+        for channel in self.index_config.channels:
+            messages = []
+            oldest = (
+                str(self.index_config.start_date.timestamp())
+                if self.index_config.start_date is not None
+                else None
+            )
+            latest = (
+                str(self.index_config.end_date.timestamp())
+                if self.index_config.end_date is not None
+                else None
+            )
+            for conversation_history in client.conversations_history(
+                channel=channel,
+                oldest=oldest,
+                latest=latest,
+                limit=PAGINATION_LIMIT,
+            ):
+                messages = conversation_history.get("messages", [])
+                if messages:
+                    yield self._messages_to_file_data(messages, channel)
+
+    def _messages_to_file_data(
+        self,
+        messages: list[dict],
+        channel: str,
+    ) -> FileData:
+        ts_oldest = min((message["ts"] for message in messages), key=lambda m: float(m))
+        ts_newest = max((message["ts"] for message in messages), key=lambda m: float(m))
+
+        identifier_base = f"{channel}-{ts_oldest}-{ts_newest}"
+        identifier = hashlib.sha256(identifier_base.encode("utf-8")).hexdigest()
+        filename = identifier[:16]
+
+        return FileData(
+            identifier=identifier,
+            connector_type=CONNECTOR_TYPE,
+            source_identifiers=SourceIdentifiers(
+                filename=f"{filename}.xml", fullpath=f"{filename}.xml"
+            ),
+            metadata=FileDataSourceMetadata(
+                date_created=ts_oldest,
+                date_modified=ts_newest,
+                date_processed=str(time.time()),
+                record_locator={
+                    "channel": channel,
+                    "oldest": ts_oldest,
+                    "latest": ts_newest,
+                },
+            ),
+        )
+
+    @SourceConnectionError.wrap
+    def precheck(self) -> None:
+        client = self.connection_config.get_client()
+        for channel in self.index_config.channels:
+            # NOTE: Querying conversations history guarantees that the bot is in the channel
+            client.conversations_history(channel=channel, limit=1)
+
+
+class SlackDownloaderConfig(DownloaderConfig):
+    pass
+
+
+@dataclass
+class SlackDownloader(Downloader):
+    connector_type: str = CONNECTOR_TYPE
+    connection_config: SlackConnectionConfig
+    download_config: SlackDownloaderConfig = field(default_factory=SlackDownloaderConfig)
+
+    def run(self, file_data, **kwargs):
+        raise NotImplementedError
+
+    async def run_async(self, file_data: FileData, **kwargs) -> download_responses:
+        # NOTE: Indexer should provide source identifiers required to generate the download path
+        download_path = self.get_download_path(file_data)
+        if download_path is None:
+            logger.error(
+                "Generated download path is None, source_identifiers might be missing"
+                "from FileData."
+            )
+            raise ValueError("Generated invalid download path.")
+
+        await self._download_conversation(file_data, download_path)
+        return self.generate_download_response(file_data, download_path)
+
+    def is_async(self):
+        return True
+
+    async def _download_conversation(self, file_data: FileData, download_path: Path) -> None:
+        # NOTE: Indexer should supply the record locator in metadata
+        if (
+            file_data.metadata.record_locator is None
+            or "channel" not in file_data.metadata.record_locator
+            or "oldest" not in file_data.metadata.record_locator
+            or "latest" not in file_data.metadata.record_locator
+        ):
+            logger.error(
+                f"Invalid record locator in metadata: {file_data.metadata.record_locator}."
+                "Keys 'channel', 'oldest' and 'latest' must be present."
+            )
+            raise ValueError("Invalid record locator.")
+
+        client = self.connection_config.get_async_client()
+        messages = []
+        async for conversation_history in await client.conversations_history(
+            channel=file_data.metadata.record_locator["channel"],
+            oldest=file_data.metadata.record_locator["oldest"],
+            latest=file_data.metadata.record_locator["latest"],
+            limit=PAGINATION_LIMIT,
+            # NOTE: In order to get the exact same range of messages as indexer, it provides
+            # timestamps of oldest and newest messages, inclusive=True is necessary to include them
+            inclusive=True,
+        ):
+            messages += conversation_history.get("messages", [])
+
+        conversation = []
+        for message in messages:
+            thread_messages = []
+            async for conversations_replies in await client.conversations_replies(
+                channel=file_data.metadata.record_locator["channel"],
+                ts=message["ts"],
+                limit=PAGINATION_LIMIT,
+            ):
+                thread_messages += conversations_replies.get("messages", [])
+
+            # NOTE: Replies contains the whole thread, including the message references by the `ts`
+            # parameter even if it's the only message (there were no replies).
+            # Reference: https://api.slack.com/methods/conversations.replies#markdown
+            conversation.append(thread_messages)
+
+        conversation_xml = self._conversation_to_xml(conversation)
+        download_path.parent.mkdir(exist_ok=True, parents=True)
+        conversation_xml.write(download_path, encoding="utf-8", xml_declaration=True)
+
+    def _conversation_to_xml(self, conversation: list[list[dict]]) -> ET.ElementTree:
+        root = ET.Element("messages")
+
+        for thread in conversation:
+            message, *replies = thread
+            message_elem = ET.SubElement(root, "message")
+            text_elem = ET.SubElement(message_elem, "text")
+            text_elem.text = message.get("text")
+
+            for reply in replies:
+                reply_msg = reply.get("text", "")
+                text_elem.text = "".join([str(text_elem.text), " <reply> ", reply_msg])
+
+        return ET.ElementTree(root)
+
+
+slack_source_entry = SourceRegistryEntry(
+    indexer=SlackIndexer,
+    indexer_config=SlackIndexerConfig,
+    downloader=SlackDownloader,
+    downloader_config=DownloaderConfig,
+    connection_config=SlackConnectionConfig,
+)
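A possible usage sketch (not from the diff) wiring the new Slack source pieces together. Only classes defined in the new slack.py are used; the bot token and channel ID are placeholders.

from unstructured_ingest.v2.processes.connectors.slack import (
    SlackAccessConfig,
    SlackConnectionConfig,
    SlackIndexer,
    SlackIndexerConfig,
)

connection_config = SlackConnectionConfig(
    access_config=SlackAccessConfig(token="xoxb-placeholder"),  # hypothetical bot token
)
indexer = SlackIndexer(
    connection_config=connection_config,
    index_config=SlackIndexerConfig(channels=["C0000000000"]),  # hypothetical channel ID
)
# Each yielded FileData covers one batch of channel messages, keyed by the channel ID
# plus the oldest/newest message timestamps stored in its record_locator.
for file_data in indexer.run():
    print(file_data.identifier, file_data.metadata.record_locator)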
unstructured_ingest/v2/processes/connectors/sql/__init__.py
@@ -2,12 +2,20 @@ from __future__ import annotations
 
 from unstructured_ingest.v2.processes.connector_registry import (
     add_destination_entry,
+    add_source_entry,
 )
 
 from .postgres import CONNECTOR_TYPE as POSTGRES_CONNECTOR_TYPE
-from .postgres import postgres_destination_entry
+from .postgres import postgres_destination_entry, postgres_source_entry
+from .snowflake import CONNECTOR_TYPE as SNOWFLAKE_CONNECTOR_TYPE
+from .snowflake import snowflake_destination_entry, snowflake_source_entry
 from .sqlite import CONNECTOR_TYPE as SQLITE_CONNECTOR_TYPE
-from .sqlite import sqlite_destination_entry
+from .sqlite import sqlite_destination_entry, sqlite_source_entry
+
+add_source_entry(source_type=SQLITE_CONNECTOR_TYPE, entry=sqlite_source_entry)
+add_source_entry(source_type=POSTGRES_CONNECTOR_TYPE, entry=postgres_source_entry)
+add_source_entry(source_type=SNOWFLAKE_CONNECTOR_TYPE, entry=snowflake_source_entry)
 
 add_destination_entry(destination_type=SQLITE_CONNECTOR_TYPE, entry=sqlite_destination_entry)
 add_destination_entry(destination_type=POSTGRES_CONNECTOR_TYPE, entry=postgres_destination_entry)
+add_destination_entry(destination_type=SNOWFLAKE_CONNECTOR_TYPE, entry=snowflake_destination_entry)
unstructured_ingest/v2/processes/connectors/sql/postgres.py
@@ -1,18 +1,24 @@
+from contextlib import contextmanager
 from dataclasses import dataclass, field
-from
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Generator, Optional
 
-import numpy as np
-import pandas as pd
 from pydantic import Field, Secret
 
 from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.v2.interfaces import FileData
 from unstructured_ingest.v2.logger import logger
-from unstructured_ingest.v2.processes.connector_registry import
+from unstructured_ingest.v2.processes.connector_registry import (
+    DestinationRegistryEntry,
+    SourceRegistryEntry,
+)
 from unstructured_ingest.v2.processes.connectors.sql.sql import (
     _DATE_COLUMNS,
     SQLAccessConfig,
     SQLConnectionConfig,
+    SQLDownloader,
+    SQLDownloaderConfig,
+    SQLIndexer,
+    SQLIndexerConfig,
     SQLUploader,
     SQLUploaderConfig,
     SQLUploadStager,
@@ -22,6 +28,7 @@ from unstructured_ingest.v2.processes.connectors.sql.sql import (
 
 if TYPE_CHECKING:
     from psycopg2.extensions import connection as PostgresConnection
+    from psycopg2.extensions import cursor as PostgresCursor
 
 CONNECTOR_TYPE = "postgres"
 
@@ -43,18 +50,73 @@ class PostgresConnectionConfig(SQLConnectionConfig):
     port: Optional[int] = Field(default=5432, description="DB host connection port")
     connector_type: str = Field(default=CONNECTOR_TYPE, init=False)
 
+    @contextmanager
     @requires_dependencies(["psycopg2"], extras="postgres")
-    def get_connection(self) -> "PostgresConnection":
+    def get_connection(self) -> Generator["PostgresConnection", None, None]:
         from psycopg2 import connect
 
         access_config = self.access_config.get_secret_value()
-
+        connection = connect(
             user=self.username,
             password=access_config.password,
             dbname=self.database,
             host=self.host,
             port=self.port,
         )
+        try:
+            yield connection
+        finally:
+            connection.commit()
+            connection.close()
+
+    @contextmanager
+    def get_cursor(self) -> Generator["PostgresCursor", None, None]:
+        with self.get_connection() as connection:
+            cursor = connection.cursor()
+            try:
+                yield cursor
+            finally:
+                cursor.close()
+
+
+class PostgresIndexerConfig(SQLIndexerConfig):
+    pass
+
+
+@dataclass
+class PostgresIndexer(SQLIndexer):
+    connection_config: PostgresConnectionConfig
+    index_config: PostgresIndexerConfig
+    connector_type: str = CONNECTOR_TYPE
+
+
+class PostgresDownloaderConfig(SQLDownloaderConfig):
+    pass
+
+
+@dataclass
+class PostgresDownloader(SQLDownloader):
+    connection_config: PostgresConnectionConfig
+    download_config: PostgresDownloaderConfig
+    connector_type: str = CONNECTOR_TYPE
+
+    def query_db(self, file_data: FileData) -> tuple[list[tuple], list[str]]:
+        table_name = file_data.additional_metadata["table_name"]
+        id_column = file_data.additional_metadata["id_column"]
+        ids = file_data.additional_metadata["ids"]
+        with self.connection_config.get_cursor() as cursor:
+            fields = ",".join(self.download_config.fields) if self.download_config.fields else "*"
+            query = "SELECT {fields} FROM {table_name} WHERE {id_column} in ({ids})".format(
+                fields=fields,
+                table_name=table_name,
+                id_column=id_column,
+                ids=",".join([str(i) for i in ids]),
+            )
+            logger.debug(f"running query: {query}")
+            cursor.execute(query)
+            rows = cursor.fetchall()
+            columns = [col[0] for col in cursor.description]
+            return rows, columns
 
 
 class PostgresUploadStagerConfig(SQLUploadStagerConfig):
@@ -74,6 +136,7 @@ class PostgresUploader(SQLUploader):
     upload_config: PostgresUploaderConfig = field(default_factory=PostgresUploaderConfig)
     connection_config: PostgresConnectionConfig
     connector_type: str = CONNECTOR_TYPE
+    values_delimiter: str = "%s"
 
     def prepare_data(
         self, columns: list[str], data: tuple[tuple[Any, ...], ...]
@@ -92,25 +155,14 @@ class PostgresUploader(SQLUploader):
             output.append(tuple(parsed))
         return output
 
-    def upload_contents(self, path: Path) -> None:
-        df = pd.read_json(path, orient="records", lines=True)
-        logger.debug(f"uploading {len(df)} entries to {self.connection_config.database} ")
-        df.replace({np.nan: None}, inplace=True)
-
-        columns = tuple(df.columns)
-        stmt = f"INSERT INTO {self.upload_config.table_name} ({','.join(columns)}) \
-            VALUES({','.join(['%s' for x in columns])})"  # noqa E501
-
-        for rows in pd.read_json(
-            path, orient="records", lines=True, chunksize=self.upload_config.batch_size
-        ):
-            with self.connection_config.get_connection() as conn:
-                values = self.prepare_data(columns, tuple(rows.itertuples(index=False, name=None)))
-                with conn.cursor() as cur:
-                    cur.executemany(stmt, values)
-
-                conn.commit()
 
+postgres_source_entry = SourceRegistryEntry(
+    connection_config=PostgresConnectionConfig,
+    indexer_config=PostgresIndexerConfig,
+    indexer=PostgresIndexer,
+    downloader_config=PostgresDownloaderConfig,
+    downloader=PostgresDownloader,
+)
 
 postgres_destination_entry = DestinationRegistryEntry(
     connection_config=PostgresConnectionConfig,
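A small standalone sketch (not from the diff) of how the new PostgresDownloader builds its SELECT statement from the indexer-supplied additional_metadata. The helper name and the table/column values below are illustrative.

from typing import Optional

def build_query(table_name: str, id_column: str, ids: list, fields: Optional[list[str]] = None) -> str:
    # Mirrors PostgresDownloader.query_db: select the requested fields (or all columns)
    # for the batch of ids handed over by the SQL indexer.
    field_list = ",".join(fields) if fields else "*"
    return "SELECT {fields} FROM {table_name} WHERE {id_column} in ({ids})".format(
        fields=field_list,
        table_name=table_name,
        id_column=id_column,
        ids=",".join(str(i) for i in ids),
    )

print(build_query("elements", "id", [1, 2, 3]))
# SELECT * FROM elements WHERE id in (1,2,3)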
unstructured_ingest/v2/processes/connectors/sql/snowflake.py (new file)
@@ -0,0 +1,164 @@
+from contextlib import contextmanager
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import TYPE_CHECKING, Generator, Optional
+
+import numpy as np
+import pandas as pd
+from pydantic import Field, Secret
+
+from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.v2.processes.connector_registry import (
+    DestinationRegistryEntry,
+    SourceRegistryEntry,
+)
+from unstructured_ingest.v2.processes.connectors.sql.postgres import (
+    PostgresDownloader,
+    PostgresDownloaderConfig,
+    PostgresIndexer,
+    PostgresIndexerConfig,
+    PostgresUploader,
+    PostgresUploaderConfig,
+    PostgresUploadStager,
+    PostgresUploadStagerConfig,
+)
+from unstructured_ingest.v2.processes.connectors.sql.sql import SQLAccessConfig, SQLConnectionConfig
+
+if TYPE_CHECKING:
+    from snowflake.connector import SnowflakeConnection
+    from snowflake.connector.cursor import SnowflakeCursor
+
+CONNECTOR_TYPE = "snowflake"
+
+
+class SnowflakeAccessConfig(SQLAccessConfig):
+    password: Optional[str] = Field(default=None, description="DB password")
+
+
+class SnowflakeConnectionConfig(SQLConnectionConfig):
+    access_config: Secret[SnowflakeAccessConfig] = Field(
+        default=SnowflakeAccessConfig(), validate_default=True
+    )
+    account: str = Field(
+        default=None,
+        description="Your account identifier. The account identifier "
+        "does not include the snowflakecomputing.com suffix.",
+    )
+    user: Optional[str] = Field(default=None, description="DB username")
+    host: Optional[str] = Field(default=None, description="DB host")
+    port: Optional[int] = Field(default=443, description="DB host connection port")
+    database: str = Field(
+        default=None,
+        description="Database name.",
+    )
+    schema: str = Field(
+        default=None,
+        description="Database schema.",
+    )
+    role: str = Field(
+        default=None,
+        description="Database role.",
+    )
+    connector_type: str = Field(default=CONNECTOR_TYPE, init=False)
+
+    @contextmanager
+    @requires_dependencies(["snowflake"], extras="snowflake")
+    def get_connection(self) -> Generator["SnowflakeConnection", None, None]:
+        # https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-api#label-snowflake-connector-methods-connect
+        from snowflake.connector import connect
+
+        connect_kwargs = self.model_dump()
+        connect_kwargs.pop("access_configs", None)
+        connect_kwargs["password"] = self.access_config.get_secret_value().password
+        # https://peps.python.org/pep-0249/#paramstyle
+        connect_kwargs["paramstyle"] = "qmark"
+        connection = connect(**connect_kwargs)
+        try:
+            yield connection
+        finally:
+            connection.commit()
+            connection.close()
+
+    @contextmanager
+    def get_cursor(self) -> Generator["SnowflakeCursor", None, None]:
+        with self.get_connection() as connection:
+            cursor = connection.cursor()
+            try:
+                yield cursor
+            finally:
+                cursor.close()
+
+
+class SnowflakeIndexerConfig(PostgresIndexerConfig):
+    pass
+
+
+@dataclass
+class SnowflakeIndexer(PostgresIndexer):
+    connection_config: SnowflakeConnectionConfig
+    index_config: SnowflakeIndexerConfig
+    connector_type: str = CONNECTOR_TYPE
+
+
+class SnowflakeDownloaderConfig(PostgresDownloaderConfig):
+    pass
+
+
+@dataclass
+class SnowflakeDownloader(PostgresDownloader):
+    connection_config: SnowflakeConnectionConfig
+    download_config: SnowflakeDownloaderConfig
+    connector_type: str = CONNECTOR_TYPE
+
+
+class SnowflakeUploadStagerConfig(PostgresUploadStagerConfig):
+    pass
+
+
+class SnowflakeUploadStager(PostgresUploadStager):
+    upload_stager_config: SnowflakeUploadStagerConfig
+
+
+class SnowflakeUploaderConfig(PostgresUploaderConfig):
+    pass
+
+
+@dataclass
+class SnowflakeUploader(PostgresUploader):
+    upload_config: SnowflakeUploaderConfig = field(default_factory=SnowflakeUploaderConfig)
+    connection_config: SnowflakeConnectionConfig
+    connector_type: str = CONNECTOR_TYPE
+    values_delimiter: str = "?"
+
+    def upload_contents(self, path: Path) -> None:
+        df = pd.read_json(path, orient="records", lines=True)
+        df.replace({np.nan: None}, inplace=True)
+
+        columns = list(df.columns)
+        stmt = f"INSERT INTO {self.upload_config.table_name} ({','.join(columns)}) VALUES({','.join([self.values_delimiter for x in columns])})"  # noqa E501
+
+        for rows in pd.read_json(
+            path, orient="records", lines=True, chunksize=self.upload_config.batch_size
+        ):
+            with self.connection_config.get_cursor() as cursor:
+                values = self.prepare_data(columns, tuple(rows.itertuples(index=False, name=None)))
+                # TODO: executemany break on 'Binding data in type (list) is not supported'
+                for val in values:
+                    cursor.execute(stmt, val)
+
+
+snowflake_source_entry = SourceRegistryEntry(
+    connection_config=SnowflakeConnectionConfig,
+    indexer_config=SnowflakeIndexerConfig,
+    indexer=SnowflakeIndexer,
+    downloader_config=SnowflakeDownloaderConfig,
+    downloader=SnowflakeDownloader,
+)
+
+snowflake_destination_entry = DestinationRegistryEntry(
+    connection_config=SnowflakeConnectionConfig,
+    uploader=SnowflakeUploader,
+    uploader_config=SnowflakeUploaderConfig,
+    upload_stager=SnowflakeUploadStager,
+    upload_stager_config=SnowflakeUploadStagerConfig,
+)
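A standalone sketch (not from the diff) of how the shared INSERT template changes with the per-connector values_delimiter introduced above: "%s" for the psycopg2/Postgres uploader versus "?" (qmark paramstyle) for the Snowflake uploader. The helper name and the table/column names are illustrative.

def insert_statement(table_name: str, columns: list[str], values_delimiter: str) -> str:
    # Same shape as the statement built in SnowflakeUploader.upload_contents.
    placeholders = ",".join(values_delimiter for _ in columns)
    return f"INSERT INTO {table_name} ({','.join(columns)}) VALUES({placeholders})"

print(insert_statement("elements", ["id", "text"], "%s"))  # Postgres-style placeholders
print(insert_statement("elements", ["id", "text"], "?"))   # Snowflake qmark placeholders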