unstructured-ingest 0.0.25__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of unstructured-ingest has been flagged as potentially problematic; see the package's registry page for details.
- test/__init__.py +0 -0
- test/integration/__init__.py +0 -0
- test/integration/chunkers/__init__.py +0 -0
- test/integration/chunkers/test_chunkers.py +42 -0
- test/integration/connectors/__init__.py +0 -0
- test/integration/connectors/conftest.py +15 -0
- test/integration/connectors/databricks_tests/__init__.py +0 -0
- test/integration/connectors/databricks_tests/test_volumes_native.py +165 -0
- test/integration/connectors/sql/__init__.py +0 -0
- test/integration/connectors/sql/test_postgres.py +178 -0
- test/integration/connectors/sql/test_sqlite.py +151 -0
- test/integration/connectors/test_s3.py +152 -0
- test/integration/connectors/utils/__init__.py +0 -0
- test/integration/connectors/utils/constants.py +7 -0
- test/integration/connectors/utils/docker_compose.py +44 -0
- test/integration/connectors/utils/validation.py +203 -0
- test/integration/embedders/__init__.py +0 -0
- test/integration/embedders/conftest.py +13 -0
- test/integration/embedders/test_bedrock.py +49 -0
- test/integration/embedders/test_huggingface.py +26 -0
- test/integration/embedders/test_mixedbread.py +47 -0
- test/integration/embedders/test_octoai.py +41 -0
- test/integration/embedders/test_openai.py +41 -0
- test/integration/embedders/test_vertexai.py +41 -0
- test/integration/embedders/test_voyageai.py +41 -0
- test/integration/embedders/togetherai.py +43 -0
- test/integration/embedders/utils.py +44 -0
- test/integration/partitioners/__init__.py +0 -0
- test/integration/partitioners/test_partitioner.py +75 -0
- test/integration/utils.py +15 -0
- test/unit/__init__.py +0 -0
- test/unit/embed/__init__.py +0 -0
- test/unit/embed/test_mixedbreadai.py +41 -0
- test/unit/embed/test_octoai.py +20 -0
- test/unit/embed/test_openai.py +20 -0
- test/unit/embed/test_vertexai.py +25 -0
- test/unit/embed/test_voyageai.py +24 -0
- test/unit/test_chunking_utils.py +36 -0
- test/unit/test_error.py +27 -0
- test/unit/test_interfaces.py +280 -0
- test/unit/test_interfaces_v2.py +26 -0
- test/unit/test_logger.py +78 -0
- test/unit/test_utils.py +164 -0
- test/unit/test_utils_v2.py +82 -0
- unstructured_ingest/__version__.py +1 -1
- unstructured_ingest/cli/interfaces.py +2 -2
- unstructured_ingest/connector/notion/types/block.py +1 -0
- unstructured_ingest/connector/notion/types/database.py +1 -0
- unstructured_ingest/connector/notion/types/page.py +1 -0
- unstructured_ingest/embed/bedrock.py +0 -20
- unstructured_ingest/embed/huggingface.py +0 -21
- unstructured_ingest/embed/interfaces.py +29 -3
- unstructured_ingest/embed/mixedbreadai.py +0 -36
- unstructured_ingest/embed/octoai.py +2 -24
- unstructured_ingest/embed/openai.py +0 -20
- unstructured_ingest/embed/togetherai.py +40 -0
- unstructured_ingest/embed/vertexai.py +0 -20
- unstructured_ingest/embed/voyageai.py +1 -24
- unstructured_ingest/interfaces.py +1 -1
- unstructured_ingest/v2/cli/utils/click.py +21 -2
- unstructured_ingest/v2/interfaces/connector.py +22 -2
- unstructured_ingest/v2/interfaces/downloader.py +1 -0
- unstructured_ingest/v2/processes/chunker.py +1 -1
- unstructured_ingest/v2/processes/connectors/__init__.py +5 -18
- unstructured_ingest/v2/processes/connectors/databricks/__init__.py +52 -0
- unstructured_ingest/v2/processes/connectors/databricks/volumes.py +175 -0
- unstructured_ingest/v2/processes/connectors/databricks/volumes_aws.py +87 -0
- unstructured_ingest/v2/processes/connectors/databricks/volumes_azure.py +102 -0
- unstructured_ingest/v2/processes/connectors/databricks/volumes_gcp.py +85 -0
- unstructured_ingest/v2/processes/connectors/databricks/volumes_native.py +86 -0
- unstructured_ingest/v2/processes/connectors/fsspec/fsspec.py +17 -0
- unstructured_ingest/v2/processes/connectors/kdbai.py +14 -6
- unstructured_ingest/v2/processes/connectors/mongodb.py +223 -3
- unstructured_ingest/v2/processes/connectors/sql/__init__.py +13 -0
- unstructured_ingest/v2/processes/connectors/sql/postgres.py +177 -0
- unstructured_ingest/v2/processes/connectors/sql/sql.py +310 -0
- unstructured_ingest/v2/processes/connectors/sql/sqlite.py +172 -0
- unstructured_ingest/v2/processes/embedder.py +13 -0
- unstructured_ingest/v2/processes/partitioner.py +2 -1
- {unstructured_ingest-0.0.25.dist-info → unstructured_ingest-0.1.1.dist-info}/METADATA +16 -14
- {unstructured_ingest-0.0.25.dist-info → unstructured_ingest-0.1.1.dist-info}/RECORD +85 -31
- {unstructured_ingest-0.0.25.dist-info → unstructured_ingest-0.1.1.dist-info}/top_level.txt +1 -0
- unstructured_ingest/v2/processes/connectors/sql.py +0 -275
- {unstructured_ingest-0.0.25.dist-info → unstructured_ingest-0.1.1.dist-info}/LICENSE.md +0 -0
- {unstructured_ingest-0.0.25.dist-info → unstructured_ingest-0.1.1.dist-info}/WHEEL +0 -0
- {unstructured_ingest-0.0.25.dist-info → unstructured_ingest-0.1.1.dist-info}/entry_points.txt +0 -0
|
@@ -1,26 +1,37 @@
|
|
|
1
1
|
import json
|
|
2
|
+
import sys
|
|
2
3
|
from dataclasses import dataclass, field
|
|
4
|
+
from datetime import datetime
|
|
3
5
|
from pathlib import Path
|
|
4
|
-
from
|
|
6
|
+
from time import time
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Generator, Optional
|
|
5
8
|
|
|
6
9
|
from pydantic import Field, Secret
|
|
7
10
|
|
|
8
11
|
from unstructured_ingest.__version__ import __version__ as unstructured_version
|
|
9
|
-
from unstructured_ingest.error import DestinationConnectionError
|
|
10
|
-
from unstructured_ingest.utils.data_prep import batch_generator
|
|
12
|
+
from unstructured_ingest.error import DestinationConnectionError, SourceConnectionError
|
|
13
|
+
from unstructured_ingest.utils.data_prep import batch_generator, flatten_dict
|
|
11
14
|
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
12
15
|
from unstructured_ingest.v2.interfaces import (
|
|
13
16
|
AccessConfig,
|
|
14
17
|
ConnectionConfig,
|
|
18
|
+
Downloader,
|
|
19
|
+
DownloaderConfig,
|
|
15
20
|
FileData,
|
|
21
|
+
FileDataSourceMetadata,
|
|
22
|
+
Indexer,
|
|
23
|
+
IndexerConfig,
|
|
24
|
+
SourceIdentifiers,
|
|
16
25
|
Uploader,
|
|
17
26
|
UploaderConfig,
|
|
18
27
|
UploadStager,
|
|
19
28
|
UploadStagerConfig,
|
|
29
|
+
download_responses,
|
|
20
30
|
)
|
|
21
31
|
from unstructured_ingest.v2.logger import logger
|
|
22
32
|
from unstructured_ingest.v2.processes.connector_registry import (
|
|
23
33
|
DestinationRegistryEntry,
|
|
34
|
+
SourceRegistryEntry,
|
|
24
35
|
)
|
|
25
36
|
|
|
26
37
|
if TYPE_CHECKING:
|
|
@@ -53,6 +64,207 @@ class MongoDBUploadStagerConfig(UploadStagerConfig):
|
|
|
53
64
|
pass
|
|
54
65
|
|
|
55
66
|
|
|
67
|
+
class MongoDBIndexerConfig(IndexerConfig):
    """Indexer options for the MongoDB source connector."""

    # Number of document ids grouped into each batch FileData. Must be positive
    # so that every emitted batch holds at least one id.
    batch_size: int = Field(default=100, gt=0, description="Number of records per batch")
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class MongoDBDownloaderConfig(DownloaderConfig):
    """Downloader options for the MongoDB source connector.

    Adds nothing MongoDB-specific; every setting is inherited from ``DownloaderConfig``.
    """
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@dataclass
class MongoDBIndexer(Indexer):
    """Index a MongoDB collection into batches of document ids.

    Each yielded ``FileData`` describes one batch; its
    ``additional_metadata["ids"]`` holds the stringified ``_id`` values of the
    documents in that batch.
    """

    connection_config: MongoDBConnectionConfig
    index_config: MongoDBIndexerConfig
    connector_type: str = CONNECTOR_TYPE

    def precheck(self) -> None:
        """Validate the connection to the MongoDB server by sending a ``ping``.

        Raises:
            SourceConnectionError: if the client cannot be created or the ping fails.
        """
        try:
            # Use the client as a context manager so the check does not leak it.
            with self.create_client() as client:
                client.admin.command("ping")
        except Exception as e:
            logger.error(f"Failed to validate connection: {e}", exc_info=True)
            raise SourceConnectionError(f"Failed to validate connection: {e}") from e

    @requires_dependencies(["pymongo"], extras="mongodb")
    def create_client(self) -> "MongoClient":
        """Build a ``MongoClient`` from the secret URI if set, otherwise from host/port."""
        from pymongo import MongoClient
        from pymongo.driver_info import DriverInfo
        from pymongo.server_api import ServerApi

        access_config = self.connection_config.access_config.get_secret_value()
        # Both branches use the same server API version and driver metadata.
        common_kwargs = {
            "server_api": ServerApi(version=SERVER_API_VERSION),
            "driver": DriverInfo(name="unstructured", version=unstructured_version),
        }
        if access_config.uri:
            return MongoClient(access_config.uri, **common_kwargs)
        return MongoClient(
            host=self.connection_config.host,
            port=self.connection_config.port,
            **common_kwargs,
        )

    @staticmethod
    def _batch_identifier(id_batch: list) -> str:
        """Return a deterministic identifier for a batch of document ids.

        Uses a SHA-256 digest instead of built-in ``hash()``, whose value for
        str/bytes-derived objects is randomized per interpreter process, so the
        same batch gets the same identifier across runs. Ids are sorted so the
        result does not depend on their order, as with the former frozenset.
        """
        import hashlib

        payload = "\n".join(sorted(str(doc_id) for doc_id in id_batch))
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()

    def run(self, **kwargs: Any) -> Generator[FileData, None, None]:
        """Yield one batch ``FileData`` per group of document ids in the collection."""
        with self.create_client() as client:
            collection = client[self.connection_config.database][
                self.connection_config.collection
            ]
            # A projected find() iterates ids through a cursor. distinct() would
            # return all of them in a single reply bounded by the BSON size limit.
            ids = [doc["_id"] for doc in collection.find({}, {"_id": 1})]

        batch_size = self.index_config.batch_size if self.index_config else 100

        for id_batch in batch_generator(ids, batch_size=batch_size):
            metadata = FileDataSourceMetadata(
                date_processed=str(time()),
                record_locator={
                    "database": self.connection_config.database,
                    "collection": self.connection_config.collection,
                },
            )
            yield FileData(
                identifier=self._batch_identifier(id_batch),
                doc_type="batch",
                connector_type=self.connector_type,
                metadata=metadata,
                additional_metadata={
                    "ids": [str(doc_id) for doc_id in id_batch],
                },
            )
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
@dataclass
class MongoDBDownloader(Downloader):
    """Fetch the documents of one indexed batch and write each to a ``.txt`` file."""

    download_config: MongoDBDownloaderConfig
    connection_config: MongoDBConnectionConfig
    connector_type: str = CONNECTOR_TYPE

    @requires_dependencies(["pymongo"], extras="mongodb")
    def create_client(self) -> "MongoClient":
        """Build a ``MongoClient`` from the secret URI if set, otherwise from host/port."""
        from pymongo import MongoClient
        from pymongo.driver_info import DriverInfo
        from pymongo.server_api import ServerApi

        access_config = self.connection_config.access_config.get_secret_value()
        # Both branches use the same server API version and driver metadata.
        common_kwargs = {
            "server_api": ServerApi(version=SERVER_API_VERSION),
            "driver": DriverInfo(name="unstructured", version=unstructured_version),
        }
        if access_config.uri:
            return MongoClient(access_config.uri, **common_kwargs)
        return MongoClient(
            host=self.connection_config.host,
            port=self.connection_config.port,
            **common_kwargs,
        )

    @SourceConnectionError.wrap
    @requires_dependencies(["bson"], extras="mongodb")
    def run(self, file_data: FileData, **kwargs: Any) -> download_responses:
        """Fetch the batch's documents from MongoDB and write each one to a file.

        Args:
            file_data: batch descriptor whose ``additional_metadata["ids"]`` lists
                the stringified ObjectIds to fetch.

        Returns:
            One download response per document actually found.

        Raises:
            ValueError: if no ids are given or an id is not a valid ObjectId.
        """
        from bson.errors import InvalidId
        from bson.objectid import ObjectId

        ids = file_data.additional_metadata.get("ids", [])
        if not ids:
            raise ValueError("No document IDs provided in additional_metadata")

        object_ids = []
        for doc_id in ids:
            try:
                object_ids.append(ObjectId(doc_id))
            except InvalidId as e:
                error_message = f"Invalid ObjectId for doc_id '{doc_id}': {str(e)}"
                logger.error(error_message)
                raise ValueError(error_message) from e

        # The client is closed as soon as the documents are in memory; writing the
        # files below does not need the connection.
        with self.create_client() as client:
            collection = client[self.connection_config.database][
                self.connection_config.collection
            ]
            try:
                docs = list(collection.find({"_id": {"$in": object_ids}}))
            except Exception as e:
                logger.error(f"Failed to fetch documents: {e}", exc_info=True)
                raise

        if len(docs) < len(object_ids):
            logger.warning(
                f"Requested {len(object_ids)} documents but found {len(docs)} in "
                f"{self.connection_config.database}.{self.connection_config.collection}"
            )

        # Named `responses` so it does not shadow the imported `download_responses` type.
        return [self._write_document(doc) for doc in docs]

    def _write_document(self, doc: dict) -> Any:
        """Write one fetched document's values to a text file and return its response."""
        from bson.objectid import ObjectId

        doc_id = doc.pop("_id")

        # Prefer an explicit 'date_created' field (ISO-formatted if it is a datetime,
        # otherwise stringified as-is); fall back to the ObjectId's generation time.
        date_created = None
        if "date_created" in doc:
            raw_date = doc["date_created"]
            if isinstance(raw_date, datetime):
                date_created = raw_date.isoformat()
            else:
                date_created = str(raw_date)
        elif isinstance(doc_id, ObjectId):
            date_created = doc_id.generation_time.isoformat()

        flattened_dict = flatten_dict(dictionary=doc)
        concatenated_values = "\n".join(str(value) for value in flattened_dict.values())

        individual_file_data = FileData(
            identifier=str(doc_id),
            connector_type=self.connector_type,
            source_identifiers=SourceIdentifiers(
                filename=str(doc_id),
                fullpath=str(doc_id),
                rel_path=str(doc_id),
            ),
        )

        download_path = self.get_download_path(individual_file_data)
        if download_path is None:
            raise ValueError("Download path could not be determined")

        download_path.parent.mkdir(parents=True, exist_ok=True)
        download_path = download_path.with_suffix(".txt")

        with open(download_path, "w", encoding="utf8") as f:
            f.write(concatenated_values)

        individual_file_data.local_download_path = str(download_path)
        individual_file_data.metadata = FileDataSourceMetadata(
            date_created=date_created,
            date_processed=str(time()),
            record_locator={
                "database": self.connection_config.database,
                "collection": self.connection_config.collection,
                "document_id": str(doc_id),
            },
        )

        return self.generate_download_response(
            file_data=individual_file_data, download_path=download_path
        )
|
|
266
|
+
|
|
267
|
+
|
|
56
268
|
@dataclass
|
|
57
269
|
class MongoDBUploadStager(UploadStager):
|
|
58
270
|
upload_stager_config: MongoDBUploadStagerConfig = field(
|
|
@@ -138,3 +350,11 @@ mongodb_destination_entry = DestinationRegistryEntry(
|
|
|
138
350
|
upload_stager=MongoDBUploadStager,
|
|
139
351
|
upload_stager_config=MongoDBUploadStagerConfig,
|
|
140
352
|
)
|
|
353
|
+
|
|
354
|
+
# Registers the MongoDB indexer/downloader pair (and their configs) as a source connector.
mongodb_source_entry = SourceRegistryEntry(
    indexer=MongoDBIndexer,
    indexer_config=MongoDBIndexerConfig,
    downloader=MongoDBDownloader,
    downloader_config=MongoDBDownloaderConfig,
    connection_config=MongoDBConnectionConfig,
)
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from unstructured_ingest.v2.processes.connector_registry import (
|
|
4
|
+
add_destination_entry,
|
|
5
|
+
)
|
|
6
|
+
|
|
7
|
+
from .postgres import CONNECTOR_TYPE as POSTGRES_CONNECTOR_TYPE
|
|
8
|
+
from .postgres import postgres_destination_entry
|
|
9
|
+
from .sqlite import CONNECTOR_TYPE as SQLITE_CONNECTOR_TYPE
|
|
10
|
+
from .sqlite import sqlite_destination_entry
|
|
11
|
+
|
|
12
|
+
# Register each SQL destination under its connector type (sqlite first, then postgres).
for _destination_type, _entry in (
    (SQLITE_CONNECTOR_TYPE, sqlite_destination_entry),
    (POSTGRES_CONNECTOR_TYPE, postgres_destination_entry),
):
    add_destination_entry(destination_type=_destination_type, entry=_entry)
del _destination_type, _entry
|
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import TYPE_CHECKING, Any, Optional
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pandas as pd
|
|
7
|
+
from pydantic import Field, Secret
|
|
8
|
+
|
|
9
|
+
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
10
|
+
from unstructured_ingest.v2.interfaces import FileData
|
|
11
|
+
from unstructured_ingest.v2.logger import logger
|
|
12
|
+
from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry
|
|
13
|
+
from unstructured_ingest.v2.processes.connectors.sql.sql import (
|
|
14
|
+
_DATE_COLUMNS,
|
|
15
|
+
SQLAccessConfig,
|
|
16
|
+
SQLConnectionConfig,
|
|
17
|
+
SQLDownloader,
|
|
18
|
+
SQLDownloaderConfig,
|
|
19
|
+
SQLIndexer,
|
|
20
|
+
SQLIndexerConfig,
|
|
21
|
+
SQLUploader,
|
|
22
|
+
SQLUploaderConfig,
|
|
23
|
+
SQLUploadStager,
|
|
24
|
+
SQLUploadStagerConfig,
|
|
25
|
+
parse_date_string,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
if TYPE_CHECKING:
|
|
29
|
+
from psycopg2.extensions import connection as PostgresConnection
|
|
30
|
+
|
|
31
|
+
CONNECTOR_TYPE = "postgres"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class PostgresAccessConfig(SQLAccessConfig):
    """Secret credentials for a PostgreSQL connection."""

    # Password for the configured username; optional.
    password: Optional[str] = Field(description="DB password", default=None)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class PostgresConnectionConfig(SQLConnectionConfig):
    """Connection settings for a PostgreSQL database accessed through psycopg2."""

    access_config: Secret[PostgresAccessConfig] = Field(
        default=PostgresAccessConfig(), validate_default=True
    )
    database: Optional[str] = Field(default=None, description="Database name.")
    username: Optional[str] = Field(default=None, description="DB username")
    host: Optional[str] = Field(default=None, description="DB host")
    port: Optional[int] = Field(default=5432, description="DB host connection port")
    connector_type: str = Field(default=CONNECTOR_TYPE, init=False)

    @requires_dependencies(["psycopg2"], extras="postgres")
    def get_connection(self) -> "PostgresConnection":
        """Open and return a new psycopg2 connection; closing it is the caller's job."""
        from psycopg2 import connect

        secrets = self.access_config.get_secret_value()
        connection_kwargs = {
            "user": self.username,
            "password": secrets.password,
            "dbname": self.database,
            "host": self.host,
            "port": self.port,
        }
        return connect(**connection_kwargs)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class PostgresIndexerConfig(SQLIndexerConfig):
    """Indexer options for PostgreSQL; identical to the generic SQL indexer config."""
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@dataclass
class PostgresIndexer(SQLIndexer):
    """Index a PostgreSQL table by listing the values of its id column."""

    connection_config: PostgresConnectionConfig
    index_config: PostgresIndexerConfig
    connector_type: str = CONNECTOR_TYPE

    def _get_doc_ids(self) -> list[str]:
        """Return every value of ``index_config.id_column`` in ``index_config.table_name``.

        The connection is explicitly closed afterwards. A psycopg2 cursor used
        as a context manager closes only the cursor, not the connection.
        """
        connection = self.connection_config.get_connection()
        try:
            with connection.cursor() as cursor:
                # Table and column names come from the indexer config. They are SQL
                # identifiers and cannot be bound as query parameters.
                cursor.execute(
                    f"SELECT {self.index_config.id_column} FROM {self.index_config.table_name}"
                )
                return [row[0] for row in cursor.fetchall()]
        finally:
            connection.close()
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class PostgresDownloaderConfig(SQLDownloaderConfig):
    """Downloader options for PostgreSQL; identical to the generic SQL downloader config."""
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@dataclass
class PostgresDownloader(SQLDownloader):
    """Fetch the rows belonging to one indexed batch from PostgreSQL."""

    connection_config: PostgresConnectionConfig
    download_config: PostgresDownloaderConfig
    connector_type: str = CONNECTOR_TYPE

    def query_db(self, file_data: FileData) -> tuple[list[tuple], list[str]]:
        """Select the batch's rows by id.

        Args:
            file_data: batch descriptor whose ``additional_metadata`` carries
                ``table_name``, ``id_column`` and ``ids``.

        Returns:
            ``(rows, columns)``: the fetched row tuples and the result column names.
        """
        table_name = file_data.additional_metadata["table_name"]
        id_column = file_data.additional_metadata["id_column"]
        ids = file_data.additional_metadata["ids"]
        fields = ",".join(self.download_config.fields) if self.download_config.fields else "*"

        # Identifiers must be interpolated, but the id values are bound as
        # parameters. Binding quotes string ids correctly and keeps values out
        # of the SQL text (the old str()-join broke non-numeric ids).
        placeholders = ",".join(["%s"] * len(ids))
        query = f"SELECT {fields} FROM {table_name} WHERE {id_column} in ({placeholders})"

        connection = self.connection_config.get_connection()
        try:
            with connection.cursor() as cursor:
                logger.debug(f"running query: {query}")
                cursor.execute(query, list(ids))
                rows = cursor.fetchall()
                columns = [col[0] for col in cursor.description]
        finally:
            # The cursor context does not close the connection; do it here.
            connection.close()
        return rows, columns
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class PostgresUploadStagerConfig(SQLUploadStagerConfig):
    """Upload-stager options for PostgreSQL; identical to the generic SQL stager config."""
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class PostgresUploadStager(SQLUploadStager):
    """Upload stager for PostgreSQL; reuses the generic SQL staging logic unchanged."""

    # Narrows the inherited config annotation to the Postgres-specific type.
    upload_stager_config: PostgresUploadStagerConfig
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class PostgresUploaderConfig(SQLUploaderConfig):
    """Uploader options for PostgreSQL; identical to the generic SQL uploader config."""
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@dataclass
class PostgresUploader(SQLUploader):
    """Insert staged element records into a PostgreSQL table."""

    upload_config: PostgresUploaderConfig = field(default_factory=PostgresUploaderConfig)
    connection_config: PostgresConnectionConfig
    connector_type: str = CONNECTOR_TYPE

    def prepare_data(
        self, columns: list[str], data: tuple[tuple[Any, ...], ...]
    ) -> list[tuple[Any, ...]]:
        """Convert rows into insertable tuples.

        Values in columns listed in ``_DATE_COLUMNS`` go through
        ``parse_date_string`` unless they are None. All other values pass
        through unchanged.
        """
        return [
            tuple(
                parse_date_string(value)
                if column_name in _DATE_COLUMNS and value is not None
                else value
                for column_name, value in zip(columns, row)
            )
            for row in data
        ]

    def upload_contents(self, path: Path) -> None:
        """Insert every record of the JSON-lines file at ``path`` in batches.

        Each batch of ``upload_config.batch_size`` rows is inserted with
        ``executemany`` and committed separately. All batches share one
        connection, which is closed when the upload finishes or fails.
        """
        df = pd.read_json(path, orient="records", lines=True)
        logger.debug(f"uploading {len(df)} entries to {self.connection_config.database} ")
        if df.empty:
            return

        # Missing values become None (sent as NULL), applied to the same frame that
        # is inserted below. The previous version replaced NaN on this frame but then
        # re-read raw chunks from disk, so NaN still reached the database. Casting to
        # object keeps None from being coerced back to NaN in numeric columns.
        df = df.astype(object).where(df.notna(), None)

        columns = tuple(df.columns)
        stmt = (
            f"INSERT INTO {self.upload_config.table_name} ({','.join(columns)}) "
            f"VALUES({','.join(['%s'] * len(columns))})"
        )

        # Slicing the single frame keeps every batch aligned with `columns`; a falsy
        # batch size uploads everything in one batch.
        batch_size = self.upload_config.batch_size or len(df)
        connection = self.connection_config.get_connection()
        try:
            for start in range(0, len(df), batch_size):
                batch = df.iloc[start : start + batch_size]
                values = self.prepare_data(
                    columns, tuple(batch.itertuples(index=False, name=None))
                )
                with connection.cursor() as cur:
                    cur.executemany(stmt, values)
                connection.commit()
        finally:
            # `with psycopg2_connection:` only ends a transaction; close explicitly.
            connection.close()
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
# Registers PostgreSQL as a destination: connection config, staging step, then uploader.
postgres_destination_entry = DestinationRegistryEntry(
    connection_config=PostgresConnectionConfig,
    upload_stager_config=PostgresUploadStagerConfig,
    upload_stager=PostgresUploadStager,
    uploader_config=PostgresUploaderConfig,
    uploader=PostgresUploader,
)
|