unstructured-ingest 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff shows the changes between package versions as published to their public registries and is provided for informational purposes only.

Potentially problematic release: this version of unstructured-ingest might be problematic.
Files changed (55)
  1. test/integration/connectors/elasticsearch/__init__.py +0 -0
  2. test/integration/connectors/elasticsearch/conftest.py +34 -0
  3. test/integration/connectors/elasticsearch/test_elasticsearch.py +308 -0
  4. test/integration/connectors/elasticsearch/test_opensearch.py +302 -0
  5. test/integration/connectors/sql/test_postgres.py +10 -4
  6. test/integration/connectors/sql/test_singlestore.py +8 -4
  7. test/integration/connectors/sql/test_snowflake.py +10 -6
  8. test/integration/connectors/sql/test_sqlite.py +4 -4
  9. test/integration/connectors/test_astradb.py +50 -3
  10. test/integration/connectors/test_delta_table.py +46 -0
  11. test/integration/connectors/test_kafka.py +40 -6
  12. test/integration/connectors/test_lancedb.py +210 -0
  13. test/integration/connectors/test_milvus.py +141 -0
  14. test/integration/connectors/test_mongodb.py +332 -0
  15. test/integration/connectors/test_pinecone.py +53 -1
  16. test/integration/connectors/utils/docker.py +81 -15
  17. test/integration/connectors/utils/validation.py +10 -0
  18. test/integration/connectors/weaviate/__init__.py +0 -0
  19. test/integration/connectors/weaviate/conftest.py +15 -0
  20. test/integration/connectors/weaviate/test_local.py +131 -0
  21. unstructured_ingest/__version__.py +1 -1
  22. unstructured_ingest/pipeline/reformat/embedding.py +1 -1
  23. unstructured_ingest/utils/data_prep.py +9 -1
  24. unstructured_ingest/v2/processes/connectors/__init__.py +3 -16
  25. unstructured_ingest/v2/processes/connectors/astradb.py +2 -2
  26. unstructured_ingest/v2/processes/connectors/azure_ai_search.py +4 -0
  27. unstructured_ingest/v2/processes/connectors/delta_table.py +20 -4
  28. unstructured_ingest/v2/processes/connectors/elasticsearch/__init__.py +19 -0
  29. unstructured_ingest/v2/processes/connectors/{elasticsearch.py → elasticsearch/elasticsearch.py} +92 -46
  30. unstructured_ingest/v2/processes/connectors/{opensearch.py → elasticsearch/opensearch.py} +1 -1
  31. unstructured_ingest/v2/processes/connectors/google_drive.py +1 -1
  32. unstructured_ingest/v2/processes/connectors/kafka/kafka.py +6 -0
  33. unstructured_ingest/v2/processes/connectors/lancedb/__init__.py +17 -0
  34. unstructured_ingest/v2/processes/connectors/lancedb/aws.py +43 -0
  35. unstructured_ingest/v2/processes/connectors/lancedb/azure.py +43 -0
  36. unstructured_ingest/v2/processes/connectors/lancedb/gcp.py +44 -0
  37. unstructured_ingest/v2/processes/connectors/lancedb/lancedb.py +161 -0
  38. unstructured_ingest/v2/processes/connectors/lancedb/local.py +44 -0
  39. unstructured_ingest/v2/processes/connectors/milvus.py +72 -27
  40. unstructured_ingest/v2/processes/connectors/mongodb.py +122 -111
  41. unstructured_ingest/v2/processes/connectors/pinecone.py +24 -7
  42. unstructured_ingest/v2/processes/connectors/sql/sql.py +97 -26
  43. unstructured_ingest/v2/processes/connectors/weaviate/__init__.py +25 -0
  44. unstructured_ingest/v2/processes/connectors/weaviate/cloud.py +164 -0
  45. unstructured_ingest/v2/processes/connectors/weaviate/embedded.py +90 -0
  46. unstructured_ingest/v2/processes/connectors/weaviate/local.py +73 -0
  47. unstructured_ingest/v2/processes/connectors/weaviate/weaviate.py +299 -0
  48. {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/METADATA +19 -19
  49. {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/RECORD +54 -33
  50. unstructured_ingest/v2/processes/connectors/weaviate.py +0 -242
  51. /test/integration/connectors/{test_azure_cog_search.py → test_azure_ai_search.py} +0 -0
  52. {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/LICENSE.md +0 -0
  53. {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/WHEEL +0 -0
  54. {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/entry_points.txt +0 -0
  55. {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/top_level.txt +0 -0
unstructured_ingest/v2/processes/connectors/mongodb.py

@@ -1,6 +1,7 @@
 import json
 import sys
-from dataclasses import dataclass, field
+from contextlib import contextmanager
+from dataclasses import dataclass, replace
 from datetime import datetime
 from pathlib import Path
 from time import time
@@ -12,6 +13,7 @@ from unstructured_ingest.__version__ import __version__ as unstructured_version
 from unstructured_ingest.error import DestinationConnectionError, SourceConnectionError
 from unstructured_ingest.utils.data_prep import batch_generator, flatten_dict
 from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.v2.constants import RECORD_ID_LABEL
 from unstructured_ingest.v2.interfaces import (
     AccessConfig,
     ConnectionConfig,
@@ -24,8 +26,6 @@ from unstructured_ingest.v2.interfaces import (
     SourceIdentifiers,
     Uploader,
     UploaderConfig,
-    UploadStager,
-    UploadStagerConfig,
     download_responses,
 )
 from unstructured_ingest.v2.logger import logger
@@ -36,6 +36,7 @@ from unstructured_ingest.v2.processes.connector_registry import (
 
 if TYPE_CHECKING:
     from pymongo import MongoClient
+    from pymongo.collection import Collection
 
 CONNECTOR_TYPE = "mongodb"
 SERVER_API_VERSION = "1"
@@ -54,18 +55,37 @@ class MongoDBConnectionConfig(ConnectionConfig):
         description="hostname or IP address or Unix domain socket path of a single mongod or "
         "mongos instance to connect to, or a list of hostnames",
     )
-    database: Optional[str] = Field(default=None, description="database name to connect to")
-    collection: Optional[str] = Field(default=None, description="collection name to connect to")
     port: int = Field(default=27017)
     connector_type: str = Field(default=CONNECTOR_TYPE, init=False)
 
+    @contextmanager
+    @requires_dependencies(["pymongo"], extras="mongodb")
+    def get_client(self) -> Generator["MongoClient", None, None]:
+        from pymongo import MongoClient
+        from pymongo.driver_info import DriverInfo
+        from pymongo.server_api import ServerApi
 
-class MongoDBUploadStagerConfig(UploadStagerConfig):
-    pass
+        access_config = self.access_config.get_secret_value()
+        if uri := access_config.uri:
+            client_kwargs = {
+                "host": uri,
+                "server_api": ServerApi(version=SERVER_API_VERSION),
+                "driver": DriverInfo(name="unstructured", version=unstructured_version),
+            }
+        else:
+            client_kwargs = {
+                "host": self.host,
+                "port": self.port,
+                "server_api": ServerApi(version=SERVER_API_VERSION),
+            }
+        with MongoClient(**client_kwargs) as client:
+            yield client
 
 
 class MongoDBIndexerConfig(IndexerConfig):
     batch_size: int = Field(default=100, description="Number of records per batch")
+    database: Optional[str] = Field(default=None, description="database name to connect to")
+    collection: Optional[str] = Field(default=None, description="collection name to connect to")
 
 
 class MongoDBDownloaderConfig(DownloaderConfig):
@@ -81,42 +101,38 @@ class MongoDBIndexer(Indexer):
     def precheck(self) -> None:
         """Validates the connection to the MongoDB server."""
         try:
-            client = self.create_client()
-            client.admin.command("ping")
+            with self.connection_config.get_client() as client:
+                client.admin.command("ping")
+                database_names = client.list_database_names()
+                database_name = self.index_config.database
+                if database_name not in database_names:
+                    raise DestinationConnectionError(
+                        "database {} does not exist: {}".format(
+                            database_name, ", ".join(database_names)
+                        )
+                    )
+                database = client[database_name]
+                collection_names = database.list_collection_names()
+                collection_name = self.index_config.collection
+                if collection_name not in collection_names:
+                    raise SourceConnectionError(
+                        "collection {} does not exist: {}".format(
+                            collection_name, ", ".join(collection_names)
+                        )
+                    )
         except Exception as e:
             logger.error(f"Failed to validate connection: {e}", exc_info=True)
             raise SourceConnectionError(f"Failed to validate connection: {e}")
 
-    @requires_dependencies(["pymongo"], extras="mongodb")
-    def create_client(self) -> "MongoClient":
-        from pymongo import MongoClient
-        from pymongo.driver_info import DriverInfo
-        from pymongo.server_api import ServerApi
-
-        access_config = self.connection_config.access_config.get_secret_value()
-
-        if access_config.uri:
-            return MongoClient(
-                access_config.uri,
-                server_api=ServerApi(version=SERVER_API_VERSION),
-                driver=DriverInfo(name="unstructured", version=unstructured_version),
-            )
-        else:
-            return MongoClient(
-                host=self.connection_config.host,
-                port=self.connection_config.port,
-                server_api=ServerApi(version=SERVER_API_VERSION),
-            )
-
     def run(self, **kwargs: Any) -> Generator[FileData, None, None]:
         """Generates FileData objects for each document in the MongoDB collection."""
-        client = self.create_client()
-        database = client[self.connection_config.database]
-        collection = database[self.connection_config.collection]
+        with self.connection_config.get_client() as client:
+            database = client[self.index_config.database]
+            collection = database[self.index_config.collection]
 
-        # Get list of document IDs
-        ids = collection.distinct("_id")
-        batch_size = self.index_config.batch_size if self.index_config else 100
+            # Get list of document IDs
+            ids = collection.distinct("_id")
+            batch_size = self.index_config.batch_size if self.index_config else 100
 
         for id_batch in batch_generator(ids, batch_size=batch_size):
             # Make sure the hash is always a positive number to create identifier
@@ -125,8 +141,8 @@ class MongoDBIndexer(Indexer):
             metadata = FileDataSourceMetadata(
                 date_processed=str(time()),
                 record_locator={
-                    "database": self.connection_config.database,
-                    "collection": self.connection_config.collection,
+                    "database": self.index_config.database,
+                    "collection": self.index_config.collection,
                 },
             )
 
@@ -177,8 +193,8 @@ class MongoDBDownloader(Downloader):
         from bson.objectid import ObjectId
 
         client = self.create_client()
-        database = client[self.connection_config.database]
-        collection = database[self.connection_config.collection]
+        database = client[file_data.metadata.record_locator["database"]]
+        collection = database[file_data.metadata.record_locator["collection"]]
 
         ids = file_data.additional_metadata.get("ids", [])
         if not ids:
@@ -222,14 +238,12 @@ class MongoDBDownloader(Downloader):
            concatenated_values = "\n".join(str(value) for value in flattened_dict.values())
 
            # Create a FileData object for each document with source_identifiers
-            individual_file_data = FileData(
-                identifier=str(doc_id),
-                connector_type=self.connector_type,
-                source_identifiers=SourceIdentifiers(
-                    filename=str(doc_id),
-                    fullpath=str(doc_id),
-                    rel_path=str(doc_id),
-                ),
+            individual_file_data = replace(file_data)
+            individual_file_data.identifier = str(doc_id)
+            individual_file_data.source_identifiers = SourceIdentifiers(
+                filename=str(doc_id),
+                fullpath=str(doc_id),
+                rel_path=str(doc_id),
            )
 
            # Determine the download path
@@ -247,15 +261,8 @@ class MongoDBDownloader(Downloader):
            individual_file_data.local_download_path = str(download_path)
 
            # Update metadata
-            individual_file_data.metadata = FileDataSourceMetadata(
-                date_created=date_created,  # Include date_created here
-                date_processed=str(time()),
-                record_locator={
-                    "database": self.connection_config.database,
-                    "collection": self.connection_config.collection,
-                    "document_id": str(doc_id),
-                },
-            )
+            individual_file_data.metadata.record_locator["document_id"] = str(doc_id)
+            individual_file_data.metadata.date_created = date_created
 
            download_response = self.generate_download_response(
                file_data=individual_file_data, download_path=download_path
@@ -265,31 +272,14 @@
        return download_responses
 
 
-@dataclass
-class MongoDBUploadStager(UploadStager):
-    upload_stager_config: MongoDBUploadStagerConfig = field(
-        default_factory=lambda: MongoDBUploadStagerConfig()
-    )
-
-    def run(
-        self,
-        elements_filepath: Path,
-        file_data: FileData,
-        output_dir: Path,
-        output_filename: str,
-        **kwargs: Any,
-    ) -> Path:
-        with open(elements_filepath) as elements_file:
-            elements_contents = json.load(elements_file)
-
-        output_path = Path(output_dir) / Path(f"{output_filename}.json")
-        with open(output_path, "w") as output_file:
-            json.dump(elements_contents, output_file)
-        return output_path
-
-
 class MongoDBUploaderConfig(UploaderConfig):
     batch_size: int = Field(default=100, description="Number of records per batch")
+    database: Optional[str] = Field(default=None, description="database name to connect to")
+    collection: Optional[str] = Field(default=None, description="collection name to connect to")
+    record_id_key: str = Field(
+        default=RECORD_ID_LABEL,
+        description="searchable key to find entries for the same record on previous runs",
+    )
 
 
 @dataclass
@@ -300,55 +290,76 @@ class MongoDBUploader(Uploader):
 
     def precheck(self) -> None:
         try:
-            client = self.create_client()
-            client.admin.command("ping")
+            with self.connection_config.get_client() as client:
+                client.admin.command("ping")
+                database_names = client.list_database_names()
+                database_name = self.upload_config.database
+                if database_name not in database_names:
+                    raise DestinationConnectionError(
+                        "database {} does not exist: {}".format(
+                            database_name, ", ".join(database_names)
+                        )
+                    )
+                database = client[database_name]
+                collection_names = database.list_collection_names()
+                collection_name = self.upload_config.collection
+                if collection_name not in collection_names:
+                    raise SourceConnectionError(
+                        "collection {} does not exist: {}".format(
+                            collection_name, ", ".join(collection_names)
+                        )
+                    )
         except Exception as e:
             logger.error(f"failed to validate connection: {e}", exc_info=True)
             raise DestinationConnectionError(f"failed to validate connection: {e}")
 
-    @requires_dependencies(["pymongo"], extras="mongodb")
-    def create_client(self) -> "MongoClient":
-        from pymongo import MongoClient
-        from pymongo.driver_info import DriverInfo
-        from pymongo.server_api import ServerApi
-
-        access_config = self.connection_config.access_config.get_secret_value()
-
-        if access_config.uri:
-            return MongoClient(
-                access_config.uri,
-                server_api=ServerApi(version=SERVER_API_VERSION),
-                driver=DriverInfo(name="unstructured", version=unstructured_version),
-            )
-        else:
-            return MongoClient(
-                host=self.connection_config.host,
-                port=self.connection_config.port,
-                server_api=ServerApi(version=SERVER_API_VERSION),
-            )
+    def can_delete(self, collection: "Collection") -> bool:
+        indexed_keys = []
+        for index in collection.list_indexes():
+            key_bson = index["key"]
+            indexed_keys.extend(key_bson.keys())
+        return self.upload_config.record_id_key in indexed_keys
+
+    def delete_by_record_id(self, collection: "Collection", file_data: FileData) -> None:
+        logger.debug(
+            f"deleting any content with metadata "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
+            f"from collection: {collection.name}"
+        )
+        query = {self.upload_config.record_id_key: file_data.identifier}
+        delete_results = collection.delete_many(filter=query)
+        logger.info(
+            f"deleted {delete_results.deleted_count} records from collection {collection.name}"
+        )
 
     def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
         with path.open("r") as file:
             elements_dict = json.load(file)
         logger.info(
             f"writing {len(elements_dict)} objects to destination "
-            f"db, {self.connection_config.database}, "
-            f"collection {self.connection_config.collection} "
+            f"db, {self.upload_config.database}, "
+            f"collection {self.upload_config.collection} "
            f"at {self.connection_config.host}",
        )
-        client = self.create_client()
-        db = client[self.connection_config.database]
-        collection = db[self.connection_config.collection]
-        for chunk in batch_generator(elements_dict, self.upload_config.batch_size):
-            collection.insert_many(chunk)
+        # This would typically live in the stager but since no other manipulation
+        # is done, setting the record id field in the uploader
+        for element in elements_dict:
+            element[self.upload_config.record_id_key] = file_data.identifier
+        with self.connection_config.get_client() as client:
+            db = client[self.upload_config.database]
+            collection = db[self.upload_config.collection]
+            if self.can_delete(collection=collection):
+                self.delete_by_record_id(file_data=file_data, collection=collection)
+            else:
+                logger.warning("criteria for deleting previous content not met, skipping")
+            for chunk in batch_generator(elements_dict, self.upload_config.batch_size):
+                collection.insert_many(chunk)
 
 
 mongodb_destination_entry = DestinationRegistryEntry(
     connection_config=MongoDBConnectionConfig,
     uploader=MongoDBUploader,
     uploader_config=MongoDBUploaderConfig,
-    upload_stager=MongoDBUploadStager,
-    upload_stager_config=MongoDBUploadStagerConfig,
 )
 
 mongodb_source_entry = SourceRegistryEntry(
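
In 0.3.2 the MongoDB connector moves the database and collection settings from the connection config onto the indexer and uploader configs, replaces the duplicated create_client() helpers with a single context-managed get_client(), and makes re-runs idempotent by stamping every uploaded element with a record-id field and deleting previously written documents first. Below is a minimal standalone sketch of that pattern using pymongo directly; the database, collection, and field names are illustrative and not taken from the release.

from contextlib import contextmanager
from typing import Generator, Iterable

from pymongo import MongoClient
from pymongo.collection import Collection


@contextmanager
def get_client(host: str = "localhost", port: int = 27017) -> Generator[MongoClient, None, None]:
    # MongoClient supports the context-manager protocol, so the connection is
    # closed automatically when the block exits.
    with MongoClient(host=host, port=port) as client:
        yield client


def upsert_record(
    collection: Collection,
    record_id: str,
    elements: Iterable[dict],
    record_id_key: str = "record_id",
) -> None:
    # Stamp every element with the record id so a later run can find and replace it.
    docs = [{**element, record_id_key: record_id} for element in elements]
    # Only delete when the record-id field is indexed; otherwise skip, since an
    # unindexed delete_many filter could be expensive on a large collection.
    indexed_keys = [key for index in collection.list_indexes() for key in index["key"].keys()]
    if record_id_key in indexed_keys:
        deleted = collection.delete_many({record_id_key: record_id}).deleted_count
        print(f"deleted {deleted} stale documents for {record_id}")
    collection.insert_many(docs)


if __name__ == "__main__":
    with get_client() as client:
        coll = client["ingest_demo"]["elements"]
        coll.create_index("record_id")
        upsert_record(coll, "doc-123", [{"text": "hello"}, {"text": "world"}])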
unstructured_ingest/v2/processes/connectors/pinecone.py

@@ -30,6 +30,7 @@ if TYPE_CHECKING:
 CONNECTOR_TYPE = "pinecone"
 MAX_PAYLOAD_SIZE = 2 * 1024 * 1024  # 2MB
 MAX_POOL_THREADS = 100
+MAX_METADATA_BYTES = 40960  # 40KB https://docs.pinecone.io/reference/quotas-and-limits#hard-limits
 
 
 class PineconeAccessConfig(AccessConfig):
@@ -103,6 +104,10 @@ class PineconeUploaderConfig(UploaderConfig):
         default=None,
         description="The namespace to write to. If not specified, the default namespace is used",
     )
+    record_id_key: str = Field(
+        default=RECORD_ID_LABEL,
+        description="searchable key to find entries for the same record on previous runs",
+    )
 
 
 @dataclass
@@ -133,6 +138,13 @@ class PineconeUploadStager(UploadStager):
             remove_none=True,
         )
         metadata[RECORD_ID_LABEL] = file_data.identifier
+        metadata_size_bytes = len(json.dumps(metadata).encode())
+        if metadata_size_bytes > MAX_METADATA_BYTES:
+            logger.info(
+                f"Metadata size is {metadata_size_bytes} bytes, which exceeds the limit of"
+                f" {MAX_METADATA_BYTES} bytes per vector. Dropping the metadata."
+            )
+            metadata = {}
 
         return {
             "id": str(uuid.uuid4()),
@@ -183,23 +195,28 @@ class PineconeUploader(Uploader):
 
     def pod_delete_by_record_id(self, file_data: FileData) -> None:
         logger.debug(
-            f"deleting any content with metadata {RECORD_ID_LABEL}={file_data.identifier} "
+            f"deleting any content with metadata "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
             f"from pinecone pod index"
         )
         index = self.connection_config.get_index(pool_threads=MAX_POOL_THREADS)
-        delete_kwargs = {"filter": {RECORD_ID_LABEL: {"$eq": file_data.identifier}}}
+        delete_kwargs = {
+            "filter": {self.upload_config.record_id_key: {"$eq": file_data.identifier}}
+        }
         if namespace := self.upload_config.namespace:
             delete_kwargs["namespace"] = namespace
 
         resp = index.delete(**delete_kwargs)
         logger.debug(
-            f"deleted any content with metadata {RECORD_ID_LABEL}={file_data.identifier} "
+            f"deleted any content with metadata "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
             f"from pinecone index: {resp}"
         )
 
     def serverless_delete_by_record_id(self, file_data: FileData) -> None:
         logger.debug(
-            f"deleting any content with metadata {RECORD_ID_LABEL}={file_data.identifier} "
+            f"deleting any content with metadata "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
             f"from pinecone serverless index"
         )
         index = self.connection_config.get_index(pool_threads=MAX_POOL_THREADS)
@@ -209,7 +226,7 @@ class PineconeUploader(Uploader):
             return
         dimension = index_stats["dimension"]
         query_params = {
-            "filter": {RECORD_ID_LABEL: {"$eq": file_data.identifier}},
+            "filter": {self.upload_config.record_id_key: {"$eq": file_data.identifier}},
             "vector": [0] * dimension,
             "top_k": total_vectors,
         }
@@ -226,7 +243,8 @@ class PineconeUploader(Uploader):
             delete_params["namespace"] = namespace
         index.delete(**delete_params)
         logger.debug(
-            f"deleted any content with metadata {RECORD_ID_LABEL}={file_data.identifier} "
+            f"deleted any content with metadata "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
             f"from pinecone index"
         )
 
@@ -269,7 +287,6 @@ class PineconeUploader(Uploader):
             f"writing a total of {len(elements_dict)} elements via"
             f" document batches to destination"
             f" index named {self.connection_config.index_name}"
-            f" with batch size {self.upload_config.batch_size}"
         )
         # Determine if serverless or pod based index
         pinecone_client = self.connection_config.get_client()
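
The Pinecone changes make the record-id metadata key configurable (record_id_key, defaulting to RECORD_ID_LABEL) and guard against Pinecone's documented 40 KB per-vector metadata limit by dropping oversized metadata instead of failing the upsert. Here is a self-contained sketch of that guard and of the delete filter it enables; the helper names are illustrative and do not come from the library.

import json
import logging

logger = logging.getLogger(__name__)

# Pinecone's documented hard limit on metadata per vector (40 KB).
MAX_METADATA_BYTES = 40960


def prepare_vector(
    vector_id: str,
    values: list[float],
    metadata: dict,
    record_id_key: str,
    record_id: str,
) -> dict:
    # Stamp the metadata with the configurable record-id key so a later run can
    # locate vectors that belong to the same source record.
    metadata = {**metadata, record_id_key: record_id}
    size = len(json.dumps(metadata).encode())
    if size > MAX_METADATA_BYTES:
        # Mirrors the 0.3.2 behaviour: oversized metadata is dropped entirely
        # rather than aborting the upload of the vector itself.
        logger.info("metadata is %d bytes (limit %d); dropping it", size, MAX_METADATA_BYTES)
        metadata = {}
    return {"id": vector_id, "values": values, "metadata": metadata}


def delete_filter(record_id_key: str, record_id: str) -> dict:
    # Metadata filter used with an index delete/query call to target prior entries.
    return {record_id_key: {"$eq": record_id}}


if __name__ == "__main__":
    vec = prepare_vector("v1", [0.1, 0.2], {"filename": "a.pdf"}, "record_id", "doc-123")
    print(vec, delete_filter("record_id", "doc-123"))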
unstructured_ingest/v2/processes/connectors/sql/sql.py

@@ -16,6 +16,8 @@ from dateutil import parser
 from pydantic import Field, Secret
 
 from unstructured_ingest.error import DestinationConnectionError, SourceConnectionError
+from unstructured_ingest.utils.data_prep import split_dataframe
+from unstructured_ingest.v2.constants import RECORD_ID_LABEL
 from unstructured_ingest.v2.interfaces import (
     AccessConfig,
     ConnectionConfig,
@@ -236,35 +238,25 @@ class SQLUploadStagerConfig(UploadStagerConfig):
 class SQLUploadStager(UploadStager):
     upload_stager_config: SQLUploadStagerConfig = field(default_factory=SQLUploadStagerConfig)
 
-    def run(
-        self,
-        elements_filepath: Path,
-        file_data: FileData,
-        output_dir: Path,
-        output_filename: str,
-        **kwargs: Any,
-    ) -> Path:
-        with open(elements_filepath) as elements_file:
-            elements_contents: list[dict] = json.load(elements_file)
-        output_path = Path(output_dir) / Path(f"{output_filename}.json")
-        output_path.parent.mkdir(parents=True, exist_ok=True)
-
+    @staticmethod
+    def conform_dict(data: dict, file_data: FileData) -> pd.DataFrame:
+        working_data = data.copy()
         output = []
-        for data in elements_contents:
-            metadata: dict[str, Any] = data.pop("metadata", {})
+        for element in working_data:
+            metadata: dict[str, Any] = element.pop("metadata", {})
             data_source = metadata.pop("data_source", {})
             coordinates = metadata.pop("coordinates", {})
 
-            data.update(metadata)
-            data.update(data_source)
-            data.update(coordinates)
+            element.update(metadata)
+            element.update(data_source)
+            element.update(coordinates)
 
-            data["id"] = str(uuid.uuid4())
+            element["id"] = str(uuid.uuid4())
 
             # remove extraneous, not supported columns
-            data = {k: v for k, v in data.items() if k in _COLUMNS}
-
-            output.append(data)
+            element = {k: v for k, v in element.items() if k in _COLUMNS}
+            element[RECORD_ID_LABEL] = file_data.identifier
+            output.append(element)
 
         df = pd.DataFrame.from_dict(output)
         for column in filter(lambda x: x in df.columns, _DATE_COLUMNS):
@@ -281,6 +273,26 @@ class SQLUploadStager(UploadStager):
             ("version", "page_number", "regex_metadata"),
         ):
             df[column] = df[column].apply(str)
+        return df
+
+    def run(
+        self,
+        elements_filepath: Path,
+        file_data: FileData,
+        output_dir: Path,
+        output_filename: str,
+        **kwargs: Any,
+    ) -> Path:
+        with open(elements_filepath) as elements_file:
+            elements_contents: list[dict] = json.load(elements_file)
+
+        df = self.conform_dict(data=elements_contents, file_data=file_data)
+        if Path(output_filename).suffix != ".json":
+            output_filename = f"{output_filename}.json"
+        else:
+            output_filename = f"{Path(output_filename).stem}.json"
+        output_path = Path(output_dir) / Path(f"{output_filename}")
+        output_path.parent.mkdir(parents=True, exist_ok=True)
 
         with output_path.open("w") as output_file:
             df.to_json(output_file, orient="records", lines=True)
@@ -290,6 +302,10 @@ class SQLUploadStager(UploadStager):
 class SQLUploaderConfig(UploaderConfig):
     batch_size: int = Field(default=50, description="Number of records per batch")
     table_name: str = Field(default="elements", description="which table to upload contents to")
+    record_id_key: str = Field(
+        default=RECORD_ID_LABEL,
+        description="searchable key to find entries for the same record on previous runs",
+    )
 
 
 @dataclass
@@ -323,18 +339,45 @@ class SQLUploader(Uploader):
             output.append(tuple(parsed))
         return output
 
+    def _fit_to_schema(self, df: pd.DataFrame, columns: list[str]) -> pd.DataFrame:
+        columns = set(df.columns)
+        schema_fields = set(columns)
+        columns_to_drop = columns - schema_fields
+        missing_columns = schema_fields - columns
+
+        if columns_to_drop:
+            logger.warning(
+                "Following columns will be dropped to match the table's schema: "
+                f"{', '.join(columns_to_drop)}"
+            )
+        if missing_columns:
+            logger.info(
+                "Following null filled columns will be added to match the table's schema:"
+                f" {', '.join(missing_columns)} "
+            )
+
+        df = df.drop(columns=columns_to_drop)
+
+        for column in missing_columns:
+            df[column] = pd.Series()
+
     def upload_contents(self, path: Path) -> None:
         df = pd.read_json(path, orient="records", lines=True)
         df.replace({np.nan: None}, inplace=True)
+        self._fit_to_schema(df=df, columns=self.get_table_columns())
 
         columns = list(df.columns)
         stmt = f"INSERT INTO {self.upload_config.table_name} ({','.join(columns)}) VALUES({','.join([self.values_delimiter for x in columns])})"  # noqa E501
-
-        for rows in pd.read_json(
-            path, orient="records", lines=True, chunksize=self.upload_config.batch_size
-        ):
+        logger.info(
+            f"writing a total of {len(df)} elements via"
+            f" document batches to destination"
+            f" table named {self.upload_config.table_name}"
+            f" with batch size {self.upload_config.batch_size}"
+        )
+        for rows in split_dataframe(df=df, chunk_size=self.upload_config.batch_size):
            with self.connection_config.get_cursor() as cursor:
                values = self.prepare_data(columns, tuple(rows.itertuples(index=False, name=None)))
+                # For debugging purposes:
                # for val in values:
                # try:
                # cursor.execute(stmt, val)
@@ -343,5 +386,33 @@ class SQLUploader(Uploader):
                # print(f"failed to write {len(columns)}, {len(val)}: {stmt} -> {val}")
                cursor.executemany(stmt, values)
 
+    def get_table_columns(self) -> list[str]:
+        with self.connection_config.get_cursor() as cursor:
+            cursor.execute(f"SELECT * from {self.upload_config.table_name}")
+            return [desc[0] for desc in cursor.description]
+
+    def can_delete(self) -> bool:
+        return self.upload_config.record_id_key in self.get_table_columns()
+
+    def delete_by_record_id(self, file_data: FileData) -> None:
+        logger.debug(
+            f"deleting any content with data "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
+            f"from table {self.upload_config.table_name}"
+        )
+        stmt = f"DELETE FROM {self.upload_config.table_name} WHERE {self.upload_config.record_id_key} = {self.values_delimiter}"  # noqa: E501
+        with self.connection_config.get_cursor() as cursor:
+            cursor.execute(stmt, [file_data.identifier])
+            rowcount = cursor.rowcount
+            logger.info(f"deleted {rowcount} rows from table {self.upload_config.table_name}")
+
     def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
+        if self.can_delete():
+            self.delete_by_record_id(file_data=file_data)
+        else:
+            logger.warning(
+                f"table doesn't contain expected "
+                f"record id column "
+                f"{self.upload_config.record_id_key}, skipping delete"
+            )
         self.upload_contents(path=path)
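
The SQL uploader now stamps each staged row with RECORD_ID_LABEL, deletes prior rows for the same record when the destination table has that column, fits the staged dataframe to the table's schema (dropping unknown columns and adding null-filled missing ones), and batches inserts with the new split_dataframe helper instead of re-reading the JSON file. Below is a hedged sketch of that fit-and-batch flow against a DB-API connection, using sqlite3 for the demo; the helper names here are illustrative, not the library's.

import sqlite3
from typing import Iterator

import pandas as pd


def split_dataframe(df: pd.DataFrame, chunk_size: int) -> Iterator[pd.DataFrame]:
    # Yield successive row slices of at most chunk_size rows.
    for start in range(0, len(df), chunk_size):
        yield df.iloc[start:start + chunk_size]


def fit_to_schema(df: pd.DataFrame, table_columns: list[str]) -> pd.DataFrame:
    # Drop columns the table does not have and add null-filled columns it expects.
    df_columns, schema_fields = set(df.columns), set(table_columns)
    df = df.drop(columns=list(df_columns - schema_fields))
    for missing in schema_fields - df_columns:
        df[missing] = None
    return df


def upload(df: pd.DataFrame, table: str, conn: sqlite3.Connection, batch_size: int = 50) -> None:
    cursor = conn.cursor()
    cursor.execute(f"SELECT * FROM {table} LIMIT 0")  # cheap way to read the column names
    df = fit_to_schema(df, [desc[0] for desc in cursor.description])
    columns = list(df.columns)
    stmt = f"INSERT INTO {table} ({','.join(columns)}) VALUES ({','.join('?' for _ in columns)})"
    for chunk in split_dataframe(df, batch_size):
        cursor.executemany(stmt, list(chunk.itertuples(index=False, name=None)))
    conn.commit()


if __name__ == "__main__":
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE elements (id TEXT, text TEXT, record_id TEXT)")
    upload(pd.DataFrame([{"id": "1", "text": "hello", "extra": "dropped"}]), "elements", conn)
    print(conn.execute("SELECT * FROM elements").fetchall())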
unstructured_ingest/v2/processes/connectors/weaviate/__init__.py

@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+from unstructured_ingest.v2.processes.connector_registry import (
+    add_destination_entry,
+)
+
+from .cloud import CONNECTOR_TYPE as CLOUD_WEAVIATE_CONNECTOR_TYPE
+from .cloud import weaviate_cloud_destination_entry
+from .embedded import CONNECTOR_TYPE as EMBEDDED_WEAVIATE_CONNECTOR_TYPE
+from .embedded import weaviate_embedded_destination_entry
+from .local import CONNECTOR_TYPE as LOCAL_WEAVIATE_CONNECTOR_TYPE
+from .local import weaviate_local_destination_entry
+from .weaviate import CONNECTOR_TYPE as WEAVIATE_CONNECTOR_TYPE
+from .weaviate import weaviate_destination_entry
+
+add_destination_entry(
+    destination_type=LOCAL_WEAVIATE_CONNECTOR_TYPE, entry=weaviate_local_destination_entry
+)
+add_destination_entry(
+    destination_type=CLOUD_WEAVIATE_CONNECTOR_TYPE, entry=weaviate_cloud_destination_entry
+)
+add_destination_entry(
+    destination_type=EMBEDDED_WEAVIATE_CONNECTOR_TYPE, entry=weaviate_embedded_destination_entry
+)
+add_destination_entry(destination_type=WEAVIATE_CONNECTOR_TYPE, entry=weaviate_destination_entry)
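
The Weaviate destination was split from the single weaviate.py module (removed in file 50) into a package whose cloud, embedded, local, and legacy weaviate variants are each registered at import time through the connector registry. A toy sketch of that registration pattern in isolation; DestinationEntry and the type names below are stand-ins, not the library's actual classes.

from dataclasses import dataclass


@dataclass
class DestinationEntry:
    # Stand-in for the library's DestinationRegistryEntry: just enough to show the pattern.
    connection_config: type
    uploader: type


_DESTINATION_REGISTRY: dict[str, DestinationEntry] = {}


def add_destination_entry(destination_type: str, entry: DestinationEntry) -> None:
    # Registering the same connector type twice is almost certainly a bug, so fail loudly.
    if destination_type in _DESTINATION_REGISTRY:
        raise ValueError(f"destination {destination_type!r} already registered")
    _DESTINATION_REGISTRY[destination_type] = entry


# Each variant module would call this when the package's __init__ imports it,
# which is why importing the package is enough to make every variant
# discoverable by its connector type name.
add_destination_entry("weaviate-local", DestinationEntry(connection_config=dict, uploader=object))
add_destination_entry("weaviate-cloud", DestinationEntry(connection_config=dict, uploader=object))
print(sorted(_DESTINATION_REGISTRY))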