unstructured-ingest 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Note: this version of unstructured-ingest has been flagged as a potentially problematic release.
- test/integration/connectors/elasticsearch/__init__.py +0 -0
- test/integration/connectors/elasticsearch/conftest.py +34 -0
- test/integration/connectors/elasticsearch/test_elasticsearch.py +308 -0
- test/integration/connectors/elasticsearch/test_opensearch.py +302 -0
- test/integration/connectors/sql/test_postgres.py +10 -4
- test/integration/connectors/sql/test_singlestore.py +8 -4
- test/integration/connectors/sql/test_snowflake.py +10 -6
- test/integration/connectors/sql/test_sqlite.py +4 -4
- test/integration/connectors/test_astradb.py +50 -3
- test/integration/connectors/test_delta_table.py +46 -0
- test/integration/connectors/test_kafka.py +40 -6
- test/integration/connectors/test_lancedb.py +209 -0
- test/integration/connectors/test_milvus.py +141 -0
- test/integration/connectors/test_pinecone.py +53 -1
- test/integration/connectors/utils/docker.py +81 -15
- test/integration/connectors/utils/validation.py +10 -0
- test/integration/connectors/weaviate/__init__.py +0 -0
- test/integration/connectors/weaviate/conftest.py +15 -0
- test/integration/connectors/weaviate/test_local.py +131 -0
- unstructured_ingest/__version__.py +1 -1
- unstructured_ingest/pipeline/reformat/embedding.py +1 -1
- unstructured_ingest/utils/data_prep.py +9 -1
- unstructured_ingest/v2/processes/connectors/__init__.py +3 -16
- unstructured_ingest/v2/processes/connectors/astradb.py +2 -2
- unstructured_ingest/v2/processes/connectors/azure_ai_search.py +4 -0
- unstructured_ingest/v2/processes/connectors/delta_table.py +20 -4
- unstructured_ingest/v2/processes/connectors/elasticsearch/__init__.py +19 -0
- unstructured_ingest/v2/processes/connectors/{elasticsearch.py → elasticsearch/elasticsearch.py} +92 -46
- unstructured_ingest/v2/processes/connectors/{opensearch.py → elasticsearch/opensearch.py} +1 -1
- unstructured_ingest/v2/processes/connectors/kafka/kafka.py +6 -0
- unstructured_ingest/v2/processes/connectors/lancedb/__init__.py +17 -0
- unstructured_ingest/v2/processes/connectors/lancedb/aws.py +43 -0
- unstructured_ingest/v2/processes/connectors/lancedb/azure.py +43 -0
- unstructured_ingest/v2/processes/connectors/lancedb/gcp.py +44 -0
- unstructured_ingest/v2/processes/connectors/lancedb/lancedb.py +161 -0
- unstructured_ingest/v2/processes/connectors/lancedb/local.py +44 -0
- unstructured_ingest/v2/processes/connectors/milvus.py +72 -27
- unstructured_ingest/v2/processes/connectors/pinecone.py +24 -7
- unstructured_ingest/v2/processes/connectors/sql/sql.py +97 -26
- unstructured_ingest/v2/processes/connectors/weaviate/__init__.py +22 -0
- unstructured_ingest/v2/processes/connectors/weaviate/cloud.py +164 -0
- unstructured_ingest/v2/processes/connectors/weaviate/embedded.py +90 -0
- unstructured_ingest/v2/processes/connectors/weaviate/local.py +73 -0
- unstructured_ingest/v2/processes/connectors/weaviate/weaviate.py +289 -0
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.1.dist-info}/METADATA +15 -15
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.1.dist-info}/RECORD +50 -30
- unstructured_ingest/v2/processes/connectors/weaviate.py +0 -242
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.1.dist-info}/LICENSE.md +0 -0
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.1.dist-info}/WHEEL +0 -0
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.1.dist-info}/entry_points.txt +0 -0
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.1.dist-info}/top_level.txt +0 -0
unstructured_ingest/v2/processes/connectors/{elasticsearch.py → elasticsearch/elasticsearch.py}
RENAMED
@@ -2,6 +2,7 @@ import hashlib
 import json
 import sys
 import uuid
+from contextlib import contextmanager
 from dataclasses import dataclass, field
 from pathlib import Path
 from time import time
@@ -13,9 +14,11 @@ from unstructured_ingest.error import (
     DestinationConnectionError,
     SourceConnectionError,
     SourceConnectionNetworkError,
+    WriteError,
 )
 from unstructured_ingest.utils.data_prep import flatten_dict, generator_batching_wbytes
 from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.v2.constants import RECORD_ID_LABEL
 from unstructured_ingest.v2.interfaces import (
     AccessConfig,
     ConnectionConfig,
@@ -26,6 +29,7 @@ from unstructured_ingest.v2.interfaces import (
     FileDataSourceMetadata,
     Indexer,
     IndexerConfig,
+    SourceIdentifiers,
     Uploader,
     UploaderConfig,
     UploadStager,
@@ -116,19 +120,12 @@ class ElasticsearchConnectionConfig(ConnectionConfig):
         return client_kwargs
 
     @requires_dependencies(["elasticsearch"], extras="elasticsearch")
-    def get_client(self) -> "ElasticsearchClient":
+    @contextmanager
+    def get_client(self) -> Generator["ElasticsearchClient", None, None]:
         from elasticsearch import Elasticsearch as ElasticsearchClient
 
-        client = ElasticsearchClient(**self.get_client_kwargs())
-        self.check_connection(client=client)
-        return client
-
-    def check_connection(self, client: "ElasticsearchClient"):
-        try:
-            client.perform_request("HEAD", "/", headers={"accept": "application/json"})
-        except Exception as e:
-            logger.error(f"failed to validate connection: {e}", exc_info=True)
-            raise SourceConnectionError(f"failed to validate connection: {e}")
+        with ElasticsearchClient(**self.get_client_kwargs()) as client:
+            yield client
 
 
 class ElasticsearchIndexerConfig(IndexerConfig):
@@ -144,7 +141,16 @@ class ElasticsearchIndexer(Indexer):
 
     def precheck(self) -> None:
         try:
-            self.connection_config.get_client()
+            with self.connection_config.get_client() as client:
+                if not client.ping():
+                    raise SourceConnectionError("cluster not detected")
+                indices = client.indices.get_alias(index="*")
+                if self.index_config.index_name not in indices:
+                    raise SourceConnectionError(
+                        "index {} not found: {}".format(
+                            self.index_config.index_name, ", ".join(indices.keys())
+                        )
+                    )
         except Exception as e:
             logger.error(f"failed to validate connection: {e}", exc_info=True)
             raise SourceConnectionError(f"failed to validate connection: {e}")
@@ -160,15 +166,15 @@ class ElasticsearchIndexer(Indexer):
         scan = self.load_scan()
 
         scan_query: dict = {"stored_fields": [], "query": {"match_all": {}}}
-        client = self.connection_config.get_client()
-        hits = scan(
-            client,
-            query=scan_query,
-            scroll="1m",
-            index=self.index_config.index_name,
-        )
+        with self.connection_config.get_client() as client:
+            hits = scan(
+                client,
+                query=scan_query,
+                scroll="1m",
+                index=self.index_config.index_name,
+            )
 
-        return {hit["_id"] for hit in hits}
+            return {hit["_id"] for hit in hits}
 
     def run(self, **kwargs: Any) -> Generator[FileData, None, None]:
         all_ids = self._get_doc_ids()
@@ -257,6 +263,7 @@ class ElasticsearchDownloader(Downloader):
             file_data=FileData(
                 identifier=filename_id,
                 connector_type=CONNECTOR_TYPE,
+                source_identifiers=SourceIdentifiers(filename=filename, fullpath=filename),
                 metadata=FileDataSourceMetadata(
                     version=str(result["_version"]) if "_version" in result else None,
                     date_processed=str(time()),
@@ -318,7 +325,7 @@ class ElasticsearchUploadStagerConfig(UploadStagerConfig):
 class ElasticsearchUploadStager(UploadStager):
     upload_stager_config: ElasticsearchUploadStagerConfig
 
-    def conform_dict(self, data: dict) -> dict:
+    def conform_dict(self, data: dict, file_data: FileData) -> dict:
         resp = {
             "_index": self.upload_stager_config.index_name,
             "_id": str(uuid.uuid4()),
@@ -327,6 +334,7 @@ class ElasticsearchUploadStager(UploadStager):
                 "embeddings": data.pop("embeddings", None),
                 "text": data.pop("text", None),
                 "type": data.pop("type", None),
+                RECORD_ID_LABEL: file_data.identifier,
             },
         }
         if "metadata" in data and isinstance(data["metadata"], dict):
@@ -343,10 +351,17 @@ class ElasticsearchUploadStager(UploadStager):
     ) -> Path:
         with open(elements_filepath) as elements_file:
             elements_contents = json.load(elements_file)
-        conformed_elements = [self.conform_dict(data=element) for element in elements_contents]
-        output_path = Path(output_dir) / Path(f"{output_filename}.json")
+        conformed_elements = [
+            self.conform_dict(data=element, file_data=file_data) for element in elements_contents
+        ]
+        if Path(output_filename).suffix != ".json":
+            output_filename = f"{output_filename}.json"
+        else:
+            output_filename = f"{Path(output_filename).stem}.json"
+        output_path = Path(output_dir) / output_filename
+        output_path.parent.mkdir(parents=True, exist_ok=True)
         with open(output_path, "w") as output_file:
-            json.dump(conformed_elements, output_file)
+            json.dump(conformed_elements, output_file, indent=2)
         return output_path
 
 
@@ -363,6 +378,10 @@ class ElasticsearchUploaderConfig(UploaderConfig):
     num_threads: int = Field(
         default=4, description="Number of threads to be used while uploading content"
     )
+    record_id_key: str = Field(
+        default=RECORD_ID_LABEL,
+        description="searchable key to find entries for the same record on previous runs",
+    )
 
 
 @dataclass
@@ -373,7 +392,16 @@ class ElasticsearchUploader(Uploader):
 
     def precheck(self) -> None:
         try:
-            self.connection_config.get_client()
+            with self.connection_config.get_client() as client:
+                if not client.ping():
+                    raise DestinationConnectionError("cluster not detected")
+                indices = client.indices.get_alias(index="*")
+                if self.upload_config.index_name not in indices:
+                    raise SourceConnectionError(
+                        "index {} not found: {}".format(
+                            self.upload_config.index_name, ", ".join(indices.keys())
+                        )
+                    )
         except Exception as e:
             logger.error(f"failed to validate connection: {e}", exc_info=True)
             raise DestinationConnectionError(f"failed to validate connection: {e}")
@@ -384,6 +412,23 @@ class ElasticsearchUploader(Uploader):
 
         return parallel_bulk
 
+    def delete_by_record_id(self, client, file_data: FileData) -> None:
+        logger.debug(
+            f"deleting any content with metadata {RECORD_ID_LABEL}={file_data.identifier} "
+            f"from {self.upload_config.index_name} index"
+        )
+        delete_resp = client.delete_by_query(
+            index=self.upload_config.index_name,
+            body={"query": {"match": {self.upload_config.record_id_key: file_data.identifier}}},
+        )
+        logger.info(
+            "deleted {} records from index {}".format(
+                delete_resp["deleted"], self.upload_config.index_name
+            )
+        )
+        if failures := delete_resp.get("failures"):
+            raise WriteError(f"failed to delete records: {failures}")
+
     def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
         parallel_bulk = self.load_parallel_bulk()
         with path.open("r") as file:
@@ -397,28 +442,29 @@ class ElasticsearchUploader(Uploader):
             f"{self.upload_config.num_threads} (number of) threads"
         )
 
-        client = self.connection_config.get_client()
-        if not client.indices.exists(index=self.upload_config.index_name):
-            logger.warning(
-                f"{(self.__class__.__name__).replace('Uploader', '')} index does not exist: "
-                f"{self.upload_config.index_name}. "
-                f"This may cause issues when uploading."
-            )
-        for batch in generator_batching_wbytes(
-            elements_dict, batch_size_limit_bytes=self.upload_config.batch_size_bytes
-        ):
-            for success, info in parallel_bulk(
-                client=client,
-                actions=batch,
-                thread_count=self.upload_config.num_threads,
+        with self.connection_config.get_client() as client:
+            self.delete_by_record_id(client=client, file_data=file_data)
+            if not client.indices.exists(index=self.upload_config.index_name):
+                logger.warning(
+                    f"{(self.__class__.__name__).replace('Uploader', '')} index does not exist: "
+                    f"{self.upload_config.index_name}. "
+                    f"This may cause issues when uploading."
+                )
+            for batch in generator_batching_wbytes(
+                elements_dict, batch_size_limit_bytes=self.upload_config.batch_size_bytes
             ):
-                if not success:
-                    logger.error(
-                        "upload failed for a batch in "
-                        f"{(self.__class__.__name__).replace('Uploader', '')} "
-                        "destination connector:",
-                        info,
-                    )
+                for success, info in parallel_bulk(
+                    client=client,
+                    actions=batch,
+                    thread_count=self.upload_config.num_threads,
+                ):
+                    if not success:
+                        logger.error(
+                            "upload failed for a batch in "
+                            f"{(self.__class__.__name__).replace('Uploader', '')} "
+                            "destination connector:",
+                            info,
+                        )
 
 
 elasticsearch_source_entry = SourceRegistryEntry(
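The net effect of these changes is that ElasticsearchConnectionConfig.get_client() is now a context manager rather than a method returning a live client, and the uploader first deletes any documents tagged with the same record id before bulk-indexing the new batch. A minimal sketch of driving the new client interface, assuming an already-built connection_config; the helper name list_indices is made up for illustration:

def list_indices(connection_config) -> list[str]:
    # get_client() now yields the client and closes the transport on exit.
    with connection_config.get_client() as client:
        if not client.ping():
            raise RuntimeError("cluster not detected")
        # Same call the new prechecks use to enumerate available indices.
        return sorted(client.indices.get_alias(index="*"))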
unstructured_ingest/v2/processes/connectors/{opensearch.py → elasticsearch/opensearch.py}
RENAMED
@@ -17,7 +17,7 @@ from unstructured_ingest.v2.processes.connector_registry import (
     DestinationRegistryEntry,
     SourceRegistryEntry,
 )
-from unstructured_ingest.v2.processes.connectors.elasticsearch import (
+from unstructured_ingest.v2.processes.connectors.elasticsearch.elasticsearch import (
     ElasticsearchDownloader,
     ElasticsearchDownloaderConfig,
     ElasticsearchIndexer,
unstructured_ingest/v2/processes/connectors/kafka/kafka.py
CHANGED
@@ -161,6 +161,12 @@ class KafkaIndexer(Indexer, ABC):
             current_topics = [
                 topic for topic in cluster_meta.topics if topic != "__consumer_offsets"
             ]
+            if self.index_config.topic not in current_topics:
+                raise SourceConnectionError(
+                    "expected topic {} not detected in cluster: {}".format(
+                        self.index_config.topic, ", ".join(current_topics)
+                    )
+                )
             logger.info(f"successfully checked available topics: {current_topics}")
         except Exception as e:
             logger.error(f"failed to validate connection: {e}", exc_info=True)
unstructured_ingest/v2/processes/connectors/lancedb/__init__.py
ADDED
@@ -0,0 +1,17 @@
+from __future__ import annotations
+
+from unstructured_ingest.v2.processes.connector_registry import add_destination_entry
+
+from .aws import CONNECTOR_TYPE as LANCEDB_S3_CONNECTOR_TYPE
+from .aws import lancedb_aws_destination_entry
+from .azure import CONNECTOR_TYPE as LANCEDB_AZURE_CONNECTOR_TYPE
+from .azure import lancedb_azure_destination_entry
+from .gcp import CONNECTOR_TYPE as LANCEDB_GCS_CONNECTOR_TYPE
+from .gcp import lancedb_gcp_destination_entry
+from .local import CONNECTOR_TYPE as LANCEDB_LOCAL_CONNECTOR_TYPE
+from .local import lancedb_local_destination_entry
+
+add_destination_entry(LANCEDB_S3_CONNECTOR_TYPE, lancedb_aws_destination_entry)
+add_destination_entry(LANCEDB_AZURE_CONNECTOR_TYPE, lancedb_azure_destination_entry)
+add_destination_entry(LANCEDB_GCS_CONNECTOR_TYPE, lancedb_gcp_destination_entry)
+add_destination_entry(LANCEDB_LOCAL_CONNECTOR_TYPE, lancedb_local_destination_entry)
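Importing this subpackage registers all four LanceDB storage flavors with the destination registry as a side effect. A sketch of how an additional variant of the same shape could be registered, reusing the shared uploader and stager classes; the "lancedb_custom" key is invented purely for illustration:

from unstructured_ingest.v2.processes.connector_registry import (
    DestinationRegistryEntry,
    add_destination_entry,
)
from unstructured_ingest.v2.processes.connectors.lancedb.lancedb import (
    LanceDBUploader,
    LanceDBUploaderConfig,
    LanceDBUploadStager,
    LanceDBUploadStagerConfig,
)
from unstructured_ingest.v2.processes.connectors.lancedb.local import (
    LanceDBLocalConnectionConfig,
)

# Register an existing connection/uploader pair under a new, hypothetical key.
add_destination_entry(
    "lancedb_custom",
    DestinationRegistryEntry(
        connection_config=LanceDBLocalConnectionConfig,
        uploader=LanceDBUploader,
        uploader_config=LanceDBUploaderConfig,
        upload_stager_config=LanceDBUploadStagerConfig,
        upload_stager=LanceDBUploadStager,
    ),
)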
unstructured_ingest/v2/processes/connectors/lancedb/aws.py
ADDED
@@ -0,0 +1,43 @@
+from dataclasses import dataclass
+
+from pydantic import Field, Secret
+
+from unstructured_ingest.v2.interfaces.connector import AccessConfig
+from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry
+from unstructured_ingest.v2.processes.connectors.lancedb.lancedb import (
+    LanceDBRemoteConnectionConfig,
+    LanceDBUploader,
+    LanceDBUploaderConfig,
+    LanceDBUploadStager,
+    LanceDBUploadStagerConfig,
+)
+
+CONNECTOR_TYPE = "lancedb_aws"
+
+
+class LanceDBS3AccessConfig(AccessConfig):
+    aws_access_key_id: str = Field(description="The AWS access key ID to use.")
+    aws_secret_access_key: str = Field(description="The AWS secret access key to use.")
+
+
+class LanceDBS3ConnectionConfig(LanceDBRemoteConnectionConfig):
+    access_config: Secret[LanceDBS3AccessConfig]
+
+    def get_storage_options(self) -> dict:
+        return {**self.access_config.get_secret_value().model_dump(), "timeout": self.timeout}
+
+
+@dataclass
+class LanceDBS3Uploader(LanceDBUploader):
+    upload_config: LanceDBUploaderConfig
+    connection_config: LanceDBS3ConnectionConfig
+    connector_type: str = CONNECTOR_TYPE
+
+
+lancedb_aws_destination_entry = DestinationRegistryEntry(
+    connection_config=LanceDBS3ConnectionConfig,
+    uploader=LanceDBS3Uploader,
+    uploader_config=LanceDBUploaderConfig,
+    upload_stager_config=LanceDBUploadStagerConfig,
+    upload_stager=LanceDBUploadStager,
+)
unstructured_ingest/v2/processes/connectors/lancedb/azure.py
ADDED
@@ -0,0 +1,43 @@
+from dataclasses import dataclass
+
+from pydantic import Field, Secret
+
+from unstructured_ingest.v2.interfaces.connector import AccessConfig
+from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry
+from unstructured_ingest.v2.processes.connectors.lancedb.lancedb import (
+    LanceDBRemoteConnectionConfig,
+    LanceDBUploader,
+    LanceDBUploaderConfig,
+    LanceDBUploadStager,
+    LanceDBUploadStagerConfig,
+)
+
+CONNECTOR_TYPE = "lancedb_azure"
+
+
+class LanceDBAzureAccessConfig(AccessConfig):
+    azure_storage_account_name: str = Field(description="The name of the azure storage account.")
+    azure_storage_account_key: str = Field(description="The serialized azure service account key.")
+
+
+class LanceDBAzureConnectionConfig(LanceDBRemoteConnectionConfig):
+    access_config: Secret[LanceDBAzureAccessConfig]
+
+    def get_storage_options(self) -> dict:
+        return {**self.access_config.get_secret_value().model_dump(), "timeout": self.timeout}
+
+
+@dataclass
+class LanceDBAzureUploader(LanceDBUploader):
+    upload_config: LanceDBUploaderConfig
+    connection_config: LanceDBAzureConnectionConfig
+    connector_type: str = CONNECTOR_TYPE
+
+
+lancedb_azure_destination_entry = DestinationRegistryEntry(
+    connection_config=LanceDBAzureConnectionConfig,
+    uploader=LanceDBAzureUploader,
+    uploader_config=LanceDBUploaderConfig,
+    upload_stager_config=LanceDBUploadStagerConfig,
+    upload_stager=LanceDBUploadStager,
+)
unstructured_ingest/v2/processes/connectors/lancedb/gcp.py
ADDED
@@ -0,0 +1,44 @@
+from dataclasses import dataclass
+
+from pydantic import Field, Secret
+
+from unstructured_ingest.v2.interfaces.connector import AccessConfig
+from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry
+from unstructured_ingest.v2.processes.connectors.lancedb.lancedb import (
+    LanceDBRemoteConnectionConfig,
+    LanceDBUploader,
+    LanceDBUploaderConfig,
+    LanceDBUploadStager,
+    LanceDBUploadStagerConfig,
+)
+
+CONNECTOR_TYPE = "lancedb_gcs"
+
+
+class LanceDBGCSAccessConfig(AccessConfig):
+    google_service_account_key: str = Field(
+        description="The serialized google service account key."
+    )
+
+
+class LanceDBGCSConnectionConfig(LanceDBRemoteConnectionConfig):
+    access_config: Secret[LanceDBGCSAccessConfig]
+
+    def get_storage_options(self) -> dict:
+        return {**self.access_config.get_secret_value().model_dump(), "timeout": self.timeout}
+
+
+@dataclass
+class LanceDBGSPUploader(LanceDBUploader):
+    upload_config: LanceDBUploaderConfig
+    connection_config: LanceDBGCSConnectionConfig
+    connector_type: str = CONNECTOR_TYPE
+
+
+lancedb_gcp_destination_entry = DestinationRegistryEntry(
+    connection_config=LanceDBGCSConnectionConfig,
+    uploader=LanceDBGSPUploader,
+    uploader_config=LanceDBUploaderConfig,
+    upload_stager_config=LanceDBUploadStagerConfig,
+    upload_stager=LanceDBUploadStager,
+)
unstructured_ingest/v2/processes/connectors/lancedb/lancedb.py
ADDED
@@ -0,0 +1,161 @@
+from __future__ import annotations
+
+import asyncio
+import json
+from abc import ABC, abstractmethod
+from contextlib import asynccontextmanager
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional
+
+import pandas as pd
+from pydantic import Field
+
+from unstructured_ingest.error import DestinationConnectionError
+from unstructured_ingest.logger import logger
+from unstructured_ingest.utils.data_prep import flatten_dict
+from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.v2.interfaces.connector import ConnectionConfig
+from unstructured_ingest.v2.interfaces.file_data import FileData
+from unstructured_ingest.v2.interfaces.upload_stager import UploadStager, UploadStagerConfig
+from unstructured_ingest.v2.interfaces.uploader import Uploader, UploaderConfig
+
+CONNECTOR_TYPE = "lancedb"
+
+if TYPE_CHECKING:
+    from lancedb import AsyncConnection
+    from lancedb.table import AsyncTable
+
+
+class LanceDBConnectionConfig(ConnectionConfig, ABC):
+    uri: str = Field(description="The uri of the database.")
+
+    @abstractmethod
+    def get_storage_options(self) -> Optional[dict[str, str]]:
+        raise NotImplementedError
+
+    @asynccontextmanager
+    @requires_dependencies(["lancedb"], extras="lancedb")
+    @DestinationConnectionError.wrap
+    async def get_async_connection(self) -> AsyncGenerator["AsyncConnection", None]:
+        import lancedb
+
+        connection = await lancedb.connect_async(
+            self.uri,
+            storage_options=self.get_storage_options(),
+        )
+        try:
+            yield connection
+        finally:
+            connection.close()
+
+
+class LanceDBRemoteConnectionConfig(LanceDBConnectionConfig):
+    timeout: str = Field(
+        default="30s",
+        description=(
+            "Timeout for the entire request, from connection until the response body has finished"
+            "in a [0-9]+(ns|us|ms|[smhdwy]) format."
+        ),
+        pattern=r"[0-9]+(ns|us|ms|[smhdwy])",
+    )
+
+
+class LanceDBUploadStagerConfig(UploadStagerConfig):
+    pass
+
+
+@dataclass
+class LanceDBUploadStager(UploadStager):
+    upload_stager_config: LanceDBUploadStagerConfig = field(
+        default_factory=LanceDBUploadStagerConfig
+    )
+
+    def run(
+        self,
+        elements_filepath: Path,
+        file_data: FileData,
+        output_dir: Path,
+        output_filename: str,
+        **kwargs: Any,
+    ) -> Path:
+        with open(elements_filepath) as elements_file:
+            elements_contents: list[dict] = json.load(elements_file)
+
+        df = pd.DataFrame(
+            [
+                self._conform_element_contents(element_contents)
+                for element_contents in elements_contents
+            ]
+        )
+
+        output_path = (output_dir / output_filename).with_suffix(".feather")
+        df.to_feather(output_path)
+
+        return output_path
+
+    def _conform_element_contents(self, element: dict) -> dict:
+        return {
+            "vector": element.pop("embeddings", None),
+            **flatten_dict(element, separator="-"),
+        }
+
+
+class LanceDBUploaderConfig(UploaderConfig):
+    table_name: str = Field(description="The name of the table.")
+
+
+@dataclass
+class LanceDBUploader(Uploader):
+    upload_config: LanceDBUploaderConfig
+    connection_config: LanceDBConnectionConfig
+    connector_type: str = CONNECTOR_TYPE
+
+    @DestinationConnectionError.wrap
+    def precheck(self):
+        async def _precheck() -> None:
+            async with self.connection_config.get_async_connection() as conn:
+                table = await conn.open_table(self.upload_config.table_name)
+                table.close()
+
+        asyncio.run(_precheck())
+
+    @asynccontextmanager
+    async def get_table(self) -> AsyncGenerator["AsyncTable", None]:
+        async with self.connection_config.get_async_connection() as conn:
+            table = await conn.open_table(self.upload_config.table_name)
+            try:
+                yield table
+            finally:
+                table.close()
+
+    async def run_async(self, path, file_data, **kwargs):
+        df = pd.read_feather(path)
+        async with self.get_table() as table:
+            schema = await table.schema()
+            df = self._fit_to_schema(df, schema)
+            await table.add(data=df)
+
+    def _fit_to_schema(self, df: pd.DataFrame, schema) -> pd.DataFrame:
+        columns = set(df.columns)
+        schema_fields = set(schema.names)
+        columns_to_drop = columns - schema_fields
+        missing_columns = schema_fields - columns
+
+        if columns_to_drop:
+            logger.info(
+                "Following columns will be dropped to match the table's schema: "
+                f"{', '.join(columns_to_drop)}"
+            )
+        if missing_columns:
+            logger.info(
+                "Following null filled columns will be added to match the table's schema:"
+                f" {', '.join(missing_columns)} "
+            )
+
+        df = df.drop(columns=columns_to_drop)
+
+        for column in missing_columns:
+            df[column] = pd.Series()
+
+        return df
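The stager turns each element into one row: the element's "embeddings" become a "vector" column and everything else is flattened with a "-" separator, then the frame is written to Feather for the uploader to read back and fit to the target table's schema. A small standalone sketch of that transform on invented elements (no LanceDB connection required; pandas' Feather support needs pyarrow installed):

import pandas as pd

from unstructured_ingest.utils.data_prep import flatten_dict

# Two made-up elements in the rough shape the stager receives from earlier pipeline steps.
elements = [
    {"text": "hello", "embeddings": [0.1, 0.2], "metadata": {"filename": "a.pdf"}},
    {"text": "world", "embeddings": [0.3, 0.4], "metadata": {"filename": "a.pdf"}},
]

rows = [
    {"vector": element.pop("embeddings", None), **flatten_dict(element, separator="-")}
    for element in elements
]
df = pd.DataFrame(rows)
df.to_feather("staged-elements.feather")  # same on-disk format LanceDBUploader.run_async reads
print(df.columns.tolist())  # expect something like ['vector', 'text', 'metadata-filename']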
unstructured_ingest/v2/processes/connectors/lancedb/local.py
ADDED
@@ -0,0 +1,44 @@
+from dataclasses import dataclass
+
+from pydantic import Field, Secret
+
+from unstructured_ingest.v2.interfaces.connector import AccessConfig
+from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry
+from unstructured_ingest.v2.processes.connectors.lancedb.lancedb import (
+    LanceDBConnectionConfig,
+    LanceDBUploader,
+    LanceDBUploaderConfig,
+    LanceDBUploadStager,
+    LanceDBUploadStagerConfig,
+)
+
+CONNECTOR_TYPE = "lancedb_local"
+
+
+class LanceDBLocalAccessConfig(AccessConfig):
+    pass
+
+
+class LanceDBLocalConnectionConfig(LanceDBConnectionConfig):
+    access_config: Secret[LanceDBLocalAccessConfig] = Field(
+        default_factory=LanceDBLocalAccessConfig, validate_default=True
+    )
+
+    def get_storage_options(self) -> None:
+        return None
+
+
+@dataclass
+class LanceDBLocalUploader(LanceDBUploader):
+    upload_config: LanceDBUploaderConfig
+    connection_config: LanceDBLocalConnectionConfig
+    connector_type: str = CONNECTOR_TYPE
+
+
+lancedb_local_destination_entry = DestinationRegistryEntry(
+    connection_config=LanceDBLocalConnectionConfig,
+    uploader=LanceDBLocalUploader,
+    uploader_config=LanceDBUploaderConfig,
+    upload_stager_config=LanceDBUploadStagerConfig,
+    upload_stager=LanceDBUploadStager,
+)
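For the local flavor, precheck simply opens the configured table over the async connection, so a quick way to validate a setup is to wire the configs together and call it. A sketch with made-up values for the database directory and table name:

from unstructured_ingest.v2.processes.connectors.lancedb.lancedb import LanceDBUploaderConfig
from unstructured_ingest.v2.processes.connectors.lancedb.local import (
    LanceDBLocalConnectionConfig,
    LanceDBLocalUploader,
)

uploader = LanceDBLocalUploader(
    connection_config=LanceDBLocalConnectionConfig(uri="./lancedb-data"),  # hypothetical path
    upload_config=LanceDBUploaderConfig(table_name="elements"),  # hypothetical table
)
# Raises DestinationConnectionError (via the wrap decorator) if the table cannot be opened.
uploader.precheck()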