unstructured-ingest 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of unstructured-ingest might be problematic.

Files changed (51)
  1. test/integration/connectors/elasticsearch/__init__.py +0 -0
  2. test/integration/connectors/elasticsearch/conftest.py +34 -0
  3. test/integration/connectors/elasticsearch/test_elasticsearch.py +308 -0
  4. test/integration/connectors/elasticsearch/test_opensearch.py +302 -0
  5. test/integration/connectors/sql/test_postgres.py +10 -4
  6. test/integration/connectors/sql/test_singlestore.py +8 -4
  7. test/integration/connectors/sql/test_snowflake.py +10 -6
  8. test/integration/connectors/sql/test_sqlite.py +4 -4
  9. test/integration/connectors/test_astradb.py +50 -3
  10. test/integration/connectors/test_delta_table.py +46 -0
  11. test/integration/connectors/test_kafka.py +40 -6
  12. test/integration/connectors/test_lancedb.py +209 -0
  13. test/integration/connectors/test_milvus.py +141 -0
  14. test/integration/connectors/test_pinecone.py +53 -1
  15. test/integration/connectors/utils/docker.py +81 -15
  16. test/integration/connectors/utils/validation.py +10 -0
  17. test/integration/connectors/weaviate/__init__.py +0 -0
  18. test/integration/connectors/weaviate/conftest.py +15 -0
  19. test/integration/connectors/weaviate/test_local.py +131 -0
  20. unstructured_ingest/__version__.py +1 -1
  21. unstructured_ingest/pipeline/reformat/embedding.py +1 -1
  22. unstructured_ingest/utils/data_prep.py +9 -1
  23. unstructured_ingest/v2/processes/connectors/__init__.py +3 -16
  24. unstructured_ingest/v2/processes/connectors/astradb.py +2 -2
  25. unstructured_ingest/v2/processes/connectors/azure_ai_search.py +4 -0
  26. unstructured_ingest/v2/processes/connectors/delta_table.py +20 -4
  27. unstructured_ingest/v2/processes/connectors/elasticsearch/__init__.py +19 -0
  28. unstructured_ingest/v2/processes/connectors/{elasticsearch.py → elasticsearch/elasticsearch.py} +92 -46
  29. unstructured_ingest/v2/processes/connectors/{opensearch.py → elasticsearch/opensearch.py} +1 -1
  30. unstructured_ingest/v2/processes/connectors/kafka/kafka.py +6 -0
  31. unstructured_ingest/v2/processes/connectors/lancedb/__init__.py +17 -0
  32. unstructured_ingest/v2/processes/connectors/lancedb/aws.py +43 -0
  33. unstructured_ingest/v2/processes/connectors/lancedb/azure.py +43 -0
  34. unstructured_ingest/v2/processes/connectors/lancedb/gcp.py +44 -0
  35. unstructured_ingest/v2/processes/connectors/lancedb/lancedb.py +161 -0
  36. unstructured_ingest/v2/processes/connectors/lancedb/local.py +44 -0
  37. unstructured_ingest/v2/processes/connectors/milvus.py +72 -27
  38. unstructured_ingest/v2/processes/connectors/pinecone.py +24 -7
  39. unstructured_ingest/v2/processes/connectors/sql/sql.py +97 -26
  40. unstructured_ingest/v2/processes/connectors/weaviate/__init__.py +22 -0
  41. unstructured_ingest/v2/processes/connectors/weaviate/cloud.py +164 -0
  42. unstructured_ingest/v2/processes/connectors/weaviate/embedded.py +90 -0
  43. unstructured_ingest/v2/processes/connectors/weaviate/local.py +73 -0
  44. unstructured_ingest/v2/processes/connectors/weaviate/weaviate.py +289 -0
  45. {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.1.dist-info}/METADATA +15 -15
  46. {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.1.dist-info}/RECORD +50 -30
  47. unstructured_ingest/v2/processes/connectors/weaviate.py +0 -242
  48. {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.1.dist-info}/LICENSE.md +0 -0
  49. {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.1.dist-info}/WHEEL +0 -0
  50. {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.1.dist-info}/entry_points.txt +0 -0
  51. {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.1.dist-info}/top_level.txt +0 -0

unstructured_ingest/v2/processes/connectors/{elasticsearch.py → elasticsearch/elasticsearch.py}

@@ -2,6 +2,7 @@ import hashlib
  import json
  import sys
  import uuid
+ from contextlib import contextmanager
  from dataclasses import dataclass, field
  from pathlib import Path
  from time import time
@@ -13,9 +14,11 @@ from unstructured_ingest.error import (
  DestinationConnectionError,
  SourceConnectionError,
  SourceConnectionNetworkError,
+ WriteError,
  )
  from unstructured_ingest.utils.data_prep import flatten_dict, generator_batching_wbytes
  from unstructured_ingest.utils.dep_check import requires_dependencies
+ from unstructured_ingest.v2.constants import RECORD_ID_LABEL
  from unstructured_ingest.v2.interfaces import (
  AccessConfig,
  ConnectionConfig,
@@ -26,6 +29,7 @@ from unstructured_ingest.v2.interfaces import (
  FileDataSourceMetadata,
  Indexer,
  IndexerConfig,
+ SourceIdentifiers,
  Uploader,
  UploaderConfig,
  UploadStager,
@@ -116,19 +120,12 @@ class ElasticsearchConnectionConfig(ConnectionConfig):
  return client_kwargs

  @requires_dependencies(["elasticsearch"], extras="elasticsearch")
- def get_client(self) -> "ElasticsearchClient":
+ @contextmanager
+ def get_client(self) -> Generator["ElasticsearchClient", None, None]:
  from elasticsearch import Elasticsearch as ElasticsearchClient

- client = ElasticsearchClient(**self.get_client_kwargs())
- self.check_connection(client=client)
- return client
-
- def check_connection(self, client: "ElasticsearchClient"):
- try:
- client.perform_request("HEAD", "/", headers={"accept": "application/json"})
- except Exception as e:
- logger.error(f"failed to validate connection: {e}", exc_info=True)
- raise SourceConnectionError(f"failed to validate connection: {e}")
+ with ElasticsearchClient(**self.get_client_kwargs()) as client:
+ yield client


  class ElasticsearchIndexerConfig(IndexerConfig):
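
Since get_client() now returns a context manager rather than a long-lived client, callers open and close the connection with a with-block. A minimal usage sketch, assuming an already-built ElasticsearchConnectionConfig instance (the surrounding code is illustrative, not part of this diff):

    # sketch: connection_config is an ElasticsearchConnectionConfig built elsewhere
    with connection_config.get_client() as client:
        if not client.ping():
            raise RuntimeError("cluster not reachable")
        print(list(client.indices.get_alias(index="*").keys()))
    # the underlying Elasticsearch client is closed when the with-block exits
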
@@ -144,7 +141,16 @@ class ElasticsearchIndexer(Indexer):

  def precheck(self) -> None:
  try:
- self.connection_config.get_client()
+ with self.connection_config.get_client() as client:
+ if not client.ping():
+ raise SourceConnectionError("cluster not detected")
+ indices = client.indices.get_alias(index="*")
+ if self.index_config.index_name not in indices:
+ raise SourceConnectionError(
+ "index {} not found: {}".format(
+ self.index_config.index_name, ", ".join(indices.keys())
+ )
+ )
  except Exception as e:
  logger.error(f"failed to validate connection: {e}", exc_info=True)
  raise SourceConnectionError(f"failed to validate connection: {e}")
@@ -160,15 +166,15 @@ class ElasticsearchIndexer(Indexer):
  scan = self.load_scan()

  scan_query: dict = {"stored_fields": [], "query": {"match_all": {}}}
- client = self.connection_config.get_client()
- hits = scan(
- client,
- query=scan_query,
- scroll="1m",
- index=self.index_config.index_name,
- )
+ with self.connection_config.get_client() as client:
+ hits = scan(
+ client,
+ query=scan_query,
+ scroll="1m",
+ index=self.index_config.index_name,
+ )

- return {hit["_id"] for hit in hits}
+ return {hit["_id"] for hit in hits}

  def run(self, **kwargs: Any) -> Generator[FileData, None, None]:
  all_ids = self._get_doc_ids()
@@ -257,6 +263,7 @@ class ElasticsearchDownloader(Downloader):
  file_data=FileData(
  identifier=filename_id,
  connector_type=CONNECTOR_TYPE,
+ source_identifiers=SourceIdentifiers(filename=filename, fullpath=filename),
  metadata=FileDataSourceMetadata(
  version=str(result["_version"]) if "_version" in result else None,
  date_processed=str(time()),
@@ -318,7 +325,7 @@ class ElasticsearchUploadStagerConfig(UploadStagerConfig):
  class ElasticsearchUploadStager(UploadStager):
  upload_stager_config: ElasticsearchUploadStagerConfig

- def conform_dict(self, data: dict) -> dict:
+ def conform_dict(self, data: dict, file_data: FileData) -> dict:
  resp = {
  "_index": self.upload_stager_config.index_name,
  "_id": str(uuid.uuid4()),
@@ -327,6 +334,7 @@ class ElasticsearchUploadStager(UploadStager):
  "embeddings": data.pop("embeddings", None),
  "text": data.pop("text", None),
  "type": data.pop("type", None),
+ RECORD_ID_LABEL: file_data.identifier,
  },
  }
  if "metadata" in data and isinstance(data["metadata"], dict):
@@ -343,10 +351,17 @@ class ElasticsearchUploadStager(UploadStager):
  ) -> Path:
  with open(elements_filepath) as elements_file:
  elements_contents = json.load(elements_file)
- conformed_elements = [self.conform_dict(data=element) for element in elements_contents]
- output_path = Path(output_dir) / Path(f"{output_filename}.json")
+ conformed_elements = [
+ self.conform_dict(data=element, file_data=file_data) for element in elements_contents
+ ]
+ if Path(output_filename).suffix != ".json":
+ output_filename = f"{output_filename}.json"
+ else:
+ output_filename = f"{Path(output_filename).stem}.json"
+ output_path = Path(output_dir) / output_filename
+ output_path.parent.mkdir(parents=True, exist_ok=True)
  with open(output_path, "w") as output_file:
- json.dump(conformed_elements, output_file)
+ json.dump(conformed_elements, output_file, indent=2)
  return output_path
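
The staged output name is now forced to a .json suffix before writing. A standalone sketch of that suffix handling (not a library function, values illustrative):

    from pathlib import Path

    def normalize(output_filename: str) -> str:
        # mirrors the suffix logic in the stager's run() above
        if Path(output_filename).suffix != ".json":
            return f"{output_filename}.json"
        return f"{Path(output_filename).stem}.json"

    assert normalize("elements") == "elements.json"
    assert normalize("elements.json") == "elements.json"
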
@@ -363,6 +378,10 @@ class ElasticsearchUploaderConfig(UploaderConfig):
  num_threads: int = Field(
  default=4, description="Number of threads to be used while uploading content"
  )
+ record_id_key: str = Field(
+ default=RECORD_ID_LABEL,
+ description="searchable key to find entries for the same record on previous runs",
+ )


  @dataclass
@@ -373,7 +392,16 @@ class ElasticsearchUploader(Uploader):

  def precheck(self) -> None:
  try:
- self.connection_config.get_client()
+ with self.connection_config.get_client() as client:
+ if not client.ping():
+ raise DestinationConnectionError("cluster not detected")
+ indices = client.indices.get_alias(index="*")
+ if self.upload_config.index_name not in indices:
+ raise SourceConnectionError(
+ "index {} not found: {}".format(
+ self.upload_config.index_name, ", ".join(indices.keys())
+ )
+ )
  except Exception as e:
  logger.error(f"failed to validate connection: {e}", exc_info=True)
  raise DestinationConnectionError(f"failed to validate connection: {e}")
@@ -384,6 +412,23 @@ class ElasticsearchUploader(Uploader):

  return parallel_bulk

+ def delete_by_record_id(self, client, file_data: FileData) -> None:
+ logger.debug(
+ f"deleting any content with metadata {RECORD_ID_LABEL}={file_data.identifier} "
+ f"from {self.upload_config.index_name} index"
+ )
+ delete_resp = client.delete_by_query(
+ index=self.upload_config.index_name,
+ body={"query": {"match": {self.upload_config.record_id_key: file_data.identifier}}},
+ )
+ logger.info(
+ "deleted {} records from index {}".format(
+ delete_resp["deleted"], self.upload_config.index_name
+ )
+ )
+ if failures := delete_resp.get("failures"):
+ raise WriteError(f"failed to delete records: {failures}")
+
  def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
  parallel_bulk = self.load_parallel_bulk()
  with path.open("r") as file:
@@ -397,28 +442,29 @@ class ElasticsearchUploader(Uploader):
  f"{self.upload_config.num_threads} (number of) threads"
  )

- client = self.connection_config.get_client()
- if not client.indices.exists(index=self.upload_config.index_name):
- logger.warning(
- f"{(self.__class__.__name__).replace('Uploader', '')} index does not exist: "
- f"{self.upload_config.index_name}. "
- f"This may cause issues when uploading."
- )
- for batch in generator_batching_wbytes(
- elements_dict, batch_size_limit_bytes=self.upload_config.batch_size_bytes
- ):
- for success, info in parallel_bulk(
- client=client,
- actions=batch,
- thread_count=self.upload_config.num_threads,
+ with self.connection_config.get_client() as client:
+ self.delete_by_record_id(client=client, file_data=file_data)
+ if not client.indices.exists(index=self.upload_config.index_name):
+ logger.warning(
+ f"{(self.__class__.__name__).replace('Uploader', '')} index does not exist: "
+ f"{self.upload_config.index_name}. "
+ f"This may cause issues when uploading."
+ )
+ for batch in generator_batching_wbytes(
+ elements_dict, batch_size_limit_bytes=self.upload_config.batch_size_bytes
  ):
- if not success:
- logger.error(
- "upload failed for a batch in "
- f"{(self.__class__.__name__).replace('Uploader', '')} "
- "destination connector:",
- info,
- )
+ for success, info in parallel_bulk(
+ client=client,
+ actions=batch,
+ thread_count=self.upload_config.num_threads,
+ ):
+ if not success:
+ logger.error(
+ "upload failed for a batch in "
+ f"{(self.__class__.__name__).replace('Uploader', '')} "
+ "destination connector:",
+ info,
+ )


  elasticsearch_source_entry = SourceRegistryEntry(
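
Net effect: each uploader run is now an overwrite per record. Before bulk-indexing, documents previously written for the same file_data.identifier are removed with delete_by_query against the configured record_id_key (which defaults to RECORD_ID_LABEL). A hedged sketch of the cleanup query shape; the index name, identifier, and literal field name below are placeholders:

    # illustrative shape of the query issued by delete_by_record_id
    delete_query = {"query": {"match": {"record_id_key": "<file_data.identifier>"}}}
    # resp = client.delete_by_query(index="<index_name>", body=delete_query)
    # resp["deleted"] reports how many stale documents were removed before re-upload
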

unstructured_ingest/v2/processes/connectors/{opensearch.py → elasticsearch/opensearch.py}

@@ -17,7 +17,7 @@ from unstructured_ingest.v2.processes.connector_registry import (
  DestinationRegistryEntry,
  SourceRegistryEntry,
  )
- from unstructured_ingest.v2.processes.connectors.elasticsearch import (
+ from unstructured_ingest.v2.processes.connectors.elasticsearch.elasticsearch import (
  ElasticsearchDownloader,
  ElasticsearchDownloaderConfig,
  ElasticsearchIndexer,

unstructured_ingest/v2/processes/connectors/kafka/kafka.py

@@ -161,6 +161,12 @@ class KafkaIndexer(Indexer, ABC):
  current_topics = [
  topic for topic in cluster_meta.topics if topic != "__consumer_offsets"
  ]
+ if self.index_config.topic not in current_topics:
+ raise SourceConnectionError(
+ "expected topic {} not detected in cluster: {}".format(
+ self.index_config.topic, ", ".join(current_topics)
+ )
+ )
  logger.info(f"successfully checked available topics: {current_topics}")
  except Exception as e:
  logger.error(f"failed to validate connection: {e}", exc_info=True)

unstructured_ingest/v2/processes/connectors/lancedb/__init__.py (new file)

@@ -0,0 +1,17 @@
+ from __future__ import annotations
+
+ from unstructured_ingest.v2.processes.connector_registry import add_destination_entry
+
+ from .aws import CONNECTOR_TYPE as LANCEDB_S3_CONNECTOR_TYPE
+ from .aws import lancedb_aws_destination_entry
+ from .azure import CONNECTOR_TYPE as LANCEDB_AZURE_CONNECTOR_TYPE
+ from .azure import lancedb_azure_destination_entry
+ from .gcp import CONNECTOR_TYPE as LANCEDB_GCS_CONNECTOR_TYPE
+ from .gcp import lancedb_gcp_destination_entry
+ from .local import CONNECTOR_TYPE as LANCEDB_LOCAL_CONNECTOR_TYPE
+ from .local import lancedb_local_destination_entry
+
+ add_destination_entry(LANCEDB_S3_CONNECTOR_TYPE, lancedb_aws_destination_entry)
+ add_destination_entry(LANCEDB_AZURE_CONNECTOR_TYPE, lancedb_azure_destination_entry)
+ add_destination_entry(LANCEDB_GCS_CONNECTOR_TYPE, lancedb_gcp_destination_entry)
+ add_destination_entry(LANCEDB_LOCAL_CONNECTOR_TYPE, lancedb_local_destination_entry)
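
Importing the package registers one destination entry per storage backend, so the lancedb_aws, lancedb_azure, lancedb_gcs, and lancedb_local connector types each resolve to their own connection config while sharing the uploader and stager. A hedged sketch of how an additional backend could plug into the same registry (the "lancedb_custom" name is illustrative, not part of this release):

    # illustrative: registering a hypothetical extra LanceDB backend
    from unstructured_ingest.v2.processes.connector_registry import (
        DestinationRegistryEntry,
        add_destination_entry,
    )
    from unstructured_ingest.v2.processes.connectors.lancedb.lancedb import (
        LanceDBUploaderConfig,
        LanceDBUploadStager,
        LanceDBUploadStagerConfig,
    )
    from unstructured_ingest.v2.processes.connectors.lancedb.local import (
        LanceDBLocalConnectionConfig,
        LanceDBLocalUploader,
    )

    custom_entry = DestinationRegistryEntry(
        connection_config=LanceDBLocalConnectionConfig,  # stand-in; a real backend defines its own
        uploader=LanceDBLocalUploader,
        uploader_config=LanceDBUploaderConfig,
        upload_stager=LanceDBUploadStager,
        upload_stager_config=LanceDBUploadStagerConfig,
    )
    add_destination_entry("lancedb_custom", custom_entry)
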
unstructured_ingest/v2/processes/connectors/lancedb/aws.py (new file)

@@ -0,0 +1,43 @@
+ from dataclasses import dataclass
+
+ from pydantic import Field, Secret
+
+ from unstructured_ingest.v2.interfaces.connector import AccessConfig
+ from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry
+ from unstructured_ingest.v2.processes.connectors.lancedb.lancedb import (
+ LanceDBRemoteConnectionConfig,
+ LanceDBUploader,
+ LanceDBUploaderConfig,
+ LanceDBUploadStager,
+ LanceDBUploadStagerConfig,
+ )
+
+ CONNECTOR_TYPE = "lancedb_aws"
+
+
+ class LanceDBS3AccessConfig(AccessConfig):
+ aws_access_key_id: str = Field(description="The AWS access key ID to use.")
+ aws_secret_access_key: str = Field(description="The AWS secret access key to use.")
+
+
+ class LanceDBS3ConnectionConfig(LanceDBRemoteConnectionConfig):
+ access_config: Secret[LanceDBS3AccessConfig]
+
+ def get_storage_options(self) -> dict:
+ return {**self.access_config.get_secret_value().model_dump(), "timeout": self.timeout}
+
+
+ @dataclass
+ class LanceDBS3Uploader(LanceDBUploader):
+ upload_config: LanceDBUploaderConfig
+ connection_config: LanceDBS3ConnectionConfig
+ connector_type: str = CONNECTOR_TYPE
+
+
+ lancedb_aws_destination_entry = DestinationRegistryEntry(
+ connection_config=LanceDBS3ConnectionConfig,
+ uploader=LanceDBS3Uploader,
+ uploader_config=LanceDBUploaderConfig,
+ upload_stager_config=LanceDBUploadStagerConfig,
+ upload_stager=LanceDBUploadStager,
+ )
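
For reference, a hedged sketch of instantiating the S3-backed config; the uri, credentials, and timeout values are placeholders, and the uri and timeout fields come from the LanceDB connection config base classes shown later in this diff:

    # illustrative only; bucket and credentials are placeholders
    conn = LanceDBS3ConnectionConfig(
        uri="s3://my-bucket/lancedb",
        access_config=LanceDBS3AccessConfig(
            aws_access_key_id="<key id>",
            aws_secret_access_key="<secret key>",
        ),
        timeout="60s",  # must match the [0-9]+(ns|us|ms|[smhdwy]) pattern
    )
    # get_storage_options() merges the unwrapped credentials with the timeout and is
    # passed through to lancedb.connect_async(uri, storage_options=...)
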
unstructured_ingest/v2/processes/connectors/lancedb/azure.py (new file)

@@ -0,0 +1,43 @@
+ from dataclasses import dataclass
+
+ from pydantic import Field, Secret
+
+ from unstructured_ingest.v2.interfaces.connector import AccessConfig
+ from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry
+ from unstructured_ingest.v2.processes.connectors.lancedb.lancedb import (
+ LanceDBRemoteConnectionConfig,
+ LanceDBUploader,
+ LanceDBUploaderConfig,
+ LanceDBUploadStager,
+ LanceDBUploadStagerConfig,
+ )
+
+ CONNECTOR_TYPE = "lancedb_azure"
+
+
+ class LanceDBAzureAccessConfig(AccessConfig):
+ azure_storage_account_name: str = Field(description="The name of the azure storage account.")
+ azure_storage_account_key: str = Field(description="The serialized azure service account key.")
+
+
+ class LanceDBAzureConnectionConfig(LanceDBRemoteConnectionConfig):
+ access_config: Secret[LanceDBAzureAccessConfig]
+
+ def get_storage_options(self) -> dict:
+ return {**self.access_config.get_secret_value().model_dump(), "timeout": self.timeout}
+
+
+ @dataclass
+ class LanceDBAzureUploader(LanceDBUploader):
+ upload_config: LanceDBUploaderConfig
+ connection_config: LanceDBAzureConnectionConfig
+ connector_type: str = CONNECTOR_TYPE
+
+
+ lancedb_azure_destination_entry = DestinationRegistryEntry(
+ connection_config=LanceDBAzureConnectionConfig,
+ uploader=LanceDBAzureUploader,
+ uploader_config=LanceDBUploaderConfig,
+ upload_stager_config=LanceDBUploadStagerConfig,
+ upload_stager=LanceDBUploadStager,
+ )

unstructured_ingest/v2/processes/connectors/lancedb/gcp.py (new file)

@@ -0,0 +1,44 @@
+ from dataclasses import dataclass
+
+ from pydantic import Field, Secret
+
+ from unstructured_ingest.v2.interfaces.connector import AccessConfig
+ from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry
+ from unstructured_ingest.v2.processes.connectors.lancedb.lancedb import (
+ LanceDBRemoteConnectionConfig,
+ LanceDBUploader,
+ LanceDBUploaderConfig,
+ LanceDBUploadStager,
+ LanceDBUploadStagerConfig,
+ )
+
+ CONNECTOR_TYPE = "lancedb_gcs"
+
+
+ class LanceDBGCSAccessConfig(AccessConfig):
+ google_service_account_key: str = Field(
+ description="The serialized google service account key."
+ )
+
+
+ class LanceDBGCSConnectionConfig(LanceDBRemoteConnectionConfig):
+ access_config: Secret[LanceDBGCSAccessConfig]
+
+ def get_storage_options(self) -> dict:
+ return {**self.access_config.get_secret_value().model_dump(), "timeout": self.timeout}
+
+
+ @dataclass
+ class LanceDBGSPUploader(LanceDBUploader):
+ upload_config: LanceDBUploaderConfig
+ connection_config: LanceDBGCSConnectionConfig
+ connector_type: str = CONNECTOR_TYPE
+
+
+ lancedb_gcp_destination_entry = DestinationRegistryEntry(
+ connection_config=LanceDBGCSConnectionConfig,
+ uploader=LanceDBGSPUploader,
+ uploader_config=LanceDBUploaderConfig,
+ upload_stager_config=LanceDBUploadStagerConfig,
+ upload_stager=LanceDBUploadStager,
+ )

unstructured_ingest/v2/processes/connectors/lancedb/lancedb.py (new file)

@@ -0,0 +1,161 @@
+ from __future__ import annotations
+
+ import asyncio
+ import json
+ from abc import ABC, abstractmethod
+ from contextlib import asynccontextmanager
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional
+
+ import pandas as pd
+ from pydantic import Field
+
+ from unstructured_ingest.error import DestinationConnectionError
+ from unstructured_ingest.logger import logger
+ from unstructured_ingest.utils.data_prep import flatten_dict
+ from unstructured_ingest.utils.dep_check import requires_dependencies
+ from unstructured_ingest.v2.interfaces.connector import ConnectionConfig
+ from unstructured_ingest.v2.interfaces.file_data import FileData
+ from unstructured_ingest.v2.interfaces.upload_stager import UploadStager, UploadStagerConfig
+ from unstructured_ingest.v2.interfaces.uploader import Uploader, UploaderConfig
+
+ CONNECTOR_TYPE = "lancedb"
+
+ if TYPE_CHECKING:
+ from lancedb import AsyncConnection
+ from lancedb.table import AsyncTable
+
+
+ class LanceDBConnectionConfig(ConnectionConfig, ABC):
+ uri: str = Field(description="The uri of the database.")
+
+ @abstractmethod
+ def get_storage_options(self) -> Optional[dict[str, str]]:
+ raise NotImplementedError
+
+ @asynccontextmanager
+ @requires_dependencies(["lancedb"], extras="lancedb")
+ @DestinationConnectionError.wrap
+ async def get_async_connection(self) -> AsyncGenerator["AsyncConnection", None]:
+ import lancedb
+
+ connection = await lancedb.connect_async(
+ self.uri,
+ storage_options=self.get_storage_options(),
+ )
+ try:
+ yield connection
+ finally:
+ connection.close()
+
+
+ class LanceDBRemoteConnectionConfig(LanceDBConnectionConfig):
+ timeout: str = Field(
+ default="30s",
+ description=(
+ "Timeout for the entire request, from connection until the response body has finished"
+ "in a [0-9]+(ns|us|ms|[smhdwy]) format."
+ ),
+ pattern=r"[0-9]+(ns|us|ms|[smhdwy])",
+ )
+
+
+ class LanceDBUploadStagerConfig(UploadStagerConfig):
+ pass
+
+
+ @dataclass
+ class LanceDBUploadStager(UploadStager):
+ upload_stager_config: LanceDBUploadStagerConfig = field(
+ default_factory=LanceDBUploadStagerConfig
+ )
+
+ def run(
+ self,
+ elements_filepath: Path,
+ file_data: FileData,
+ output_dir: Path,
+ output_filename: str,
+ **kwargs: Any,
+ ) -> Path:
+ with open(elements_filepath) as elements_file:
+ elements_contents: list[dict] = json.load(elements_file)
+
+ df = pd.DataFrame(
+ [
+ self._conform_element_contents(element_contents)
+ for element_contents in elements_contents
+ ]
+ )
+
+ output_path = (output_dir / output_filename).with_suffix(".feather")
+ df.to_feather(output_path)
+
+ return output_path
+
+ def _conform_element_contents(self, element: dict) -> dict:
+ return {
+ "vector": element.pop("embeddings", None),
+ **flatten_dict(element, separator="-"),
+ }
+
+
+ class LanceDBUploaderConfig(UploaderConfig):
+ table_name: str = Field(description="The name of the table.")
+
+
+ @dataclass
+ class LanceDBUploader(Uploader):
+ upload_config: LanceDBUploaderConfig
+ connection_config: LanceDBConnectionConfig
+ connector_type: str = CONNECTOR_TYPE
+
+ @DestinationConnectionError.wrap
+ def precheck(self):
+ async def _precheck() -> None:
+ async with self.connection_config.get_async_connection() as conn:
+ table = await conn.open_table(self.upload_config.table_name)
+ table.close()
+
+ asyncio.run(_precheck())
+
+ @asynccontextmanager
+ async def get_table(self) -> AsyncGenerator["AsyncTable", None]:
+ async with self.connection_config.get_async_connection() as conn:
+ table = await conn.open_table(self.upload_config.table_name)
+ try:
+ yield table
+ finally:
+ table.close()
+
+ async def run_async(self, path, file_data, **kwargs):
+ df = pd.read_feather(path)
+ async with self.get_table() as table:
+ schema = await table.schema()
+ df = self._fit_to_schema(df, schema)
+ await table.add(data=df)
+
+ def _fit_to_schema(self, df: pd.DataFrame, schema) -> pd.DataFrame:
+ columns = set(df.columns)
+ schema_fields = set(schema.names)
+ columns_to_drop = columns - schema_fields
+ missing_columns = schema_fields - columns
+
+ if columns_to_drop:
+ logger.info(
+ "Following columns will be dropped to match the table's schema: "
+ f"{', '.join(columns_to_drop)}"
+ )
+ if missing_columns:
+ logger.info(
+ "Following null filled columns will be added to match the table's schema:"
+ f" {', '.join(missing_columns)} "
+ )
+
+ df = df.drop(columns=columns_to_drop)
+
+ for column in missing_columns:
+ df[column] = pd.Series()
+
+ return df
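
The stager moves each element's embeddings into a "vector" column, flattens the rest of the element with "-"-joined keys, and writes the rows to a Feather file; the uploader then aligns that frame with the destination table's schema (dropping unknown columns, null-filling missing ones) before calling table.add. A rough sketch of the per-element transform, with illustrative field values:

    # illustrative input element and the row the stager produces from it
    element = {
        "text": "Hello world",
        "embeddings": [0.1, 0.2, 0.3],
        "metadata": {"filename": "doc.pdf", "page_number": 1},
    }
    stager = LanceDBUploadStager()
    row = stager._conform_element_contents(element)
    # row is roughly:
    # {"vector": [0.1, 0.2, 0.3], "text": "Hello world",
    #  "metadata-filename": "doc.pdf", "metadata-page_number": 1}
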
unstructured_ingest/v2/processes/connectors/lancedb/local.py (new file)

@@ -0,0 +1,44 @@
+ from dataclasses import dataclass
+
+ from pydantic import Field, Secret
+
+ from unstructured_ingest.v2.interfaces.connector import AccessConfig
+ from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry
+ from unstructured_ingest.v2.processes.connectors.lancedb.lancedb import (
+ LanceDBConnectionConfig,
+ LanceDBUploader,
+ LanceDBUploaderConfig,
+ LanceDBUploadStager,
+ LanceDBUploadStagerConfig,
+ )
+
+ CONNECTOR_TYPE = "lancedb_local"
+
+
+ class LanceDBLocalAccessConfig(AccessConfig):
+ pass
+
+
+ class LanceDBLocalConnectionConfig(LanceDBConnectionConfig):
+ access_config: Secret[LanceDBLocalAccessConfig] = Field(
+ default_factory=LanceDBLocalAccessConfig, validate_default=True
+ )
+
+ def get_storage_options(self) -> None:
+ return None
+
+
+ @dataclass
+ class LanceDBLocalUploader(LanceDBUploader):
+ upload_config: LanceDBUploaderConfig
+ connection_config: LanceDBLocalConnectionConfig
+ connector_type: str = CONNECTOR_TYPE
+
+
+ lancedb_local_destination_entry = DestinationRegistryEntry(
+ connection_config=LanceDBLocalConnectionConfig,
+ uploader=LanceDBLocalUploader,
+ uploader_config=LanceDBUploaderConfig,
+ upload_stager_config=LanceDBUploadStagerConfig,
+ upload_stager=LanceDBUploadStager,
+ )
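
Finally, a hedged end-to-end sketch of wiring the local destination by hand; the path and table name are placeholders, and in normal use the registry entry above is resolved by the pipeline rather than constructed directly:

    # illustrative manual wiring of the local LanceDB destination
    from pathlib import Path

    connection_config = LanceDBLocalConnectionConfig(uri="/tmp/lancedb")
    uploader = LanceDBLocalUploader(
        connection_config=connection_config,
        upload_config=LanceDBUploaderConfig(table_name="elements"),
    )
    uploader.precheck()  # opens the table once to confirm it exists and can be opened
    # staged_path = LanceDBUploadStager().run(
    #     elements_filepath=Path("elements.json"),
    #     file_data=file_data,          # FileData produced by an indexer/downloader
    #     output_dir=Path("staged"),
    #     output_filename="elements",
    # )
    # asyncio.run(uploader.run_async(staged_path, file_data))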