unstructured-ingest 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of unstructured-ingest might be problematic.
- test/integration/connectors/test_lancedb.py +9 -8
- test/integration/connectors/test_milvus.py +34 -6
- test/integration/connectors/test_mongodb.py +332 -0
- test/integration/connectors/weaviate/test_cloud.py +34 -0
- test/unit/test_utils.py +21 -1
- unstructured_ingest/__version__.py +1 -1
- unstructured_ingest/utils/string_and_date_utils.py +10 -0
- unstructured_ingest/v2/processes/connectors/astradb.py +16 -0
- unstructured_ingest/v2/processes/connectors/google_drive.py +1 -1
- unstructured_ingest/v2/processes/connectors/lancedb/__init__.py +17 -4
- unstructured_ingest/v2/processes/connectors/lancedb/aws.py +7 -7
- unstructured_ingest/v2/processes/connectors/lancedb/cloud.py +42 -0
- unstructured_ingest/v2/processes/connectors/milvus.py +9 -3
- unstructured_ingest/v2/processes/connectors/mongodb.py +122 -111
- unstructured_ingest/v2/processes/connectors/weaviate/__init__.py +3 -0
- unstructured_ingest/v2/processes/connectors/weaviate/cloud.py +4 -3
- unstructured_ingest/v2/processes/connectors/weaviate/weaviate.py +10 -0
- {unstructured_ingest-0.3.1.dist-info → unstructured_ingest-0.3.3.dist-info}/METADATA +14 -12
- {unstructured_ingest-0.3.1.dist-info → unstructured_ingest-0.3.3.dist-info}/RECORD +24 -21
- /test/integration/connectors/{test_azure_cog_search.py → test_azure_ai_search.py} +0 -0
- {unstructured_ingest-0.3.1.dist-info → unstructured_ingest-0.3.3.dist-info}/LICENSE.md +0 -0
- {unstructured_ingest-0.3.1.dist-info → unstructured_ingest-0.3.3.dist-info}/WHEEL +0 -0
- {unstructured_ingest-0.3.1.dist-info → unstructured_ingest-0.3.3.dist-info}/entry_points.txt +0 -0
- {unstructured_ingest-0.3.1.dist-info → unstructured_ingest-0.3.3.dist-info}/top_level.txt +0 -0
test/integration/connectors/test_lancedb.py
CHANGED

@@ -1,6 +1,7 @@
 import os
 from pathlib import Path
 from typing import Literal, Union
+from uuid import uuid4

 import lancedb
 import pandas as pd
@@ -13,9 +14,9 @@ from upath import UPath
 from test.integration.connectors.utils.constants import DESTINATION_TAG
 from unstructured_ingest.v2.interfaces.file_data import FileData, SourceIdentifiers
 from unstructured_ingest.v2.processes.connectors.lancedb.aws import (
-
-
-
+    LanceDBAwsAccessConfig,
+    LanceDBAwsConnectionConfig,
+    LanceDBAwsUploader,
 )
 from unstructured_ingest.v2.processes.connectors.lancedb.azure import (
     LanceDBAzureAccessConfig,
@@ -150,12 +151,12 @@ def _get_uri(target: Literal["local", "s3", "gcs", "az"], local_base_path: Path)
     elif target == "az":
         base_uri = UPath(AZURE_BUCKET)

-    return str(base_uri / "destination" / "lancedb" / DATABASE_NAME)
+    return str(base_uri / "destination" / "lancedb" / str(uuid4()) / DATABASE_NAME)


 def _get_uploader(
     uri: str,
-) -> Union[LanceDBAzureUploader, LanceDBAzureUploader,
+) -> Union[LanceDBAzureUploader, LanceDBAzureUploader, LanceDBAwsUploader, LanceDBGSPUploader]:
     target = uri.split("://", maxsplit=1)[0] if uri.startswith(("s3", "az", "gs")) else "local"
     if target == "az":
         azure_connection_string = os.getenv("AZURE_DEST_CONNECTION_STR")
@@ -169,10 +170,10 @@ def _get_uploader(
         )

     elif target == "s3":
-        return
+        return LanceDBAwsUploader(
             upload_config=LanceDBUploaderConfig(table_name=TABLE_NAME),
-            connection_config=
-                access_config=
+            connection_config=LanceDBAwsConnectionConfig(
+                access_config=LanceDBAwsAccessConfig(
                     aws_access_key_id=os.getenv("S3_INGEST_TEST_ACCESS_KEY"),
                     aws_secret_access_key=os.getenv("S3_INGEST_TEST_SECRET_KEY"),
                 ),
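The two changes above work together: `_get_uri` now inserts a random `uuid4()` segment into the destination path so parallel test runs write to distinct prefixes, and `_get_uploader` gains a real S3 branch built on the restored `LanceDBAws*` classes. A minimal sketch of the uuid-isolation pattern (bucket and database names here are illustrative, not the test's real values):

from uuid import uuid4
from upath import UPath

# Illustrative bucket and database name.
base_uri = UPath("az://test-bucket")
uri = str(base_uri / "destination" / "lancedb" / str(uuid4()) / "test-db")
# Each invocation yields a unique prefix such as
# "az://test-bucket/destination/lancedb/<uuid>/test-db", so concurrent
# CI runs cannot read or clobber each other's tables.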
test/integration/connectors/test_milvus.py
CHANGED

@@ -15,6 +15,7 @@ from pymilvus.milvus_client import IndexParams
 from test.integration.connectors.utils.constants import DESTINATION_TAG, env_setup_path
 from test.integration.connectors.utils.docker import healthcheck_wait
 from test.integration.connectors.utils.docker_compose import docker_compose_context
+from unstructured_ingest.error import DestinationConnectionError
 from unstructured_ingest.v2.interfaces import FileData, SourceIdentifiers
 from unstructured_ingest.v2.processes.connectors.milvus import (
     CONNECTOR_TYPE,
@@ -24,9 +25,10 @@ from unstructured_ingest.v2.processes.connectors.milvus import (
     MilvusUploadStager,
 )

-DB_URI = "http://localhost:19530"
 DB_NAME = "test_database"
-
+EXISTENT_COLLECTION_NAME = "test_collection"
+NONEXISTENT_COLLECTION_NAME = "nonexistent_collection"
+DB_URI = "http://localhost:19530"


 def get_schema() -> CollectionSchema:
@@ -55,7 +57,9 @@ def get_index_params() -> IndexParams:
     return index_params


-
+# NOTE: Precheck tests are read-only so they don't interfere with destination test,
+# using scope="module" we can limit number of times the docker-compose has to be run
+@pytest.fixture(scope="module")
 def collection():
     docker_client = docker.from_env()
     with docker_compose_context(docker_compose_path=env_setup_path / "milvus"):
@@ -73,10 +77,10 @@ def collection():
             schema = get_schema()
             index_params = get_index_params()
             collection_resp = milvus_client.create_collection(
-                collection_name=
+                collection_name=EXISTENT_COLLECTION_NAME, schema=schema, index_params=index_params
             )
-            print(f"Created collection {
-            yield
+            print(f"Created collection {EXISTENT_COLLECTION_NAME}: {collection_resp}")
+            yield EXISTENT_COLLECTION_NAME
         finally:
             milvus_client.close()

@@ -139,3 +143,27 @@ async def test_milvus_destination(
     uploader.run(path=staged_filepath, file_data=file_data)
     with uploader.get_client() as client:
         validate_count(client=client, expected_count=expected_count)
+
+
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
+def test_precheck_succeeds(collection: str):
+    uploader = MilvusUploader(
+        connection_config=MilvusConnectionConfig(uri=DB_URI),
+        upload_config=MilvusUploaderConfig(db_name=DB_NAME, collection_name=collection),
+    )
+    uploader.precheck()
+
+
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
+def test_precheck_fails_on_nonexistent_collection(collection: str):
+    uploader = MilvusUploader(
+        connection_config=MilvusConnectionConfig(uri=DB_URI),
+        upload_config=MilvusUploaderConfig(
+            db_name=DB_NAME, collection_name=NONEXISTENT_COLLECTION_NAME
+        ),
+    )
+    with pytest.raises(
+        DestinationConnectionError,
+        match=f"Collection '{NONEXISTENT_COLLECTION_NAME}' does not exist",
+    ):
+        uploader.precheck()
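The new precheck tests pin down the uploader's fail-fast contract: `precheck()` must raise `DestinationConnectionError` (with the missing collection named in the message) before any data is staged or uploaded. A minimal sketch of how a caller would rely on this, using only the classes the tests above import (URI and names match the test's local defaults):

from unstructured_ingest.error import DestinationConnectionError
from unstructured_ingest.v2.processes.connectors.milvus import (
    MilvusConnectionConfig,
    MilvusUploader,
    MilvusUploaderConfig,
)

uploader = MilvusUploader(
    connection_config=MilvusConnectionConfig(uri="http://localhost:19530"),
    upload_config=MilvusUploaderConfig(
        db_name="test_database", collection_name="test_collection"
    ),
)
try:
    uploader.precheck()  # cheap connectivity and collection-existence check
except DestinationConnectionError as err:
    # Surface misconfiguration before any documents are written.
    raise SystemExit(f"Milvus destination not ready: {err}")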
test/integration/connectors/test_mongodb.py
ADDED

@@ -0,0 +1,332 @@
+import json
+import os
+import time
+import uuid
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Generator
+
+import pytest
+from pydantic import BaseModel, SecretStr
+from pymongo.collection import Collection
+from pymongo.database import Database
+from pymongo.mongo_client import MongoClient
+from pymongo.operations import SearchIndexModel
+
+from test.integration.connectors.utils.constants import DESTINATION_TAG, SOURCE_TAG
+from test.integration.connectors.utils.validation import (
+    ValidationConfigs,
+    source_connector_validation,
+)
+from test.integration.utils import requires_env
+from unstructured_ingest.error import DestinationConnectionError, SourceConnectionError
+from unstructured_ingest.v2.interfaces import FileData, SourceIdentifiers
+from unstructured_ingest.v2.processes.connectors.mongodb import (
+    CONNECTOR_TYPE,
+    MongoDBAccessConfig,
+    MongoDBConnectionConfig,
+    MongoDBDownloader,
+    MongoDBDownloaderConfig,
+    MongoDBIndexer,
+    MongoDBIndexerConfig,
+    MongoDBUploader,
+    MongoDBUploaderConfig,
+)
+
+SOURCE_COLLECTION = "sample-mongodb-data"
+
+
+class EnvData(BaseModel):
+    uri: SecretStr
+    database: str
+
+
+def get_env_data() -> EnvData:
+    uri = os.getenv("MONGODB_URI")
+    assert uri
+    database = os.getenv("MONGODB_DATABASE")
+    assert database
+    return EnvData(uri=uri, database=database)
+
+
+@contextmanager
+def get_client() -> Generator[MongoClient, None, None]:
+    uri = get_env_data().uri.get_secret_value()
+    with MongoClient(uri) as client:
+        assert client.admin.command("ping")
+        yield client
+
+
+def wait_for_collection(
+    database: Database, collection_name: str, retries: int = 10, interval: int = 1
+):
+    collections = database.list_collection_names()
+    attempts = 0
+    while collection_name not in collections and attempts < retries:
+        attempts += 1
+        print(
+            "Waiting for collection {} to be recognized: {}".format(
+                collection_name, ", ".join(collections)
+            )
+        )
+        time.sleep(interval)
+        collections = database.list_collection_names()
+    if collection_name not in collections:
+        raise TimeoutError(f"Collection {collection_name} was not recognized")
+
+
+def get_search_index_status(collection: Collection, index_name: str) -> str:
+    search_indexes = collection.list_search_indexes(name=index_name)
+    search_index = list(search_indexes)[0]
+    return search_index["status"]
+
+
+def wait_for_search_index(
+    collection: Collection, index_name: str, retries: int = 60, interval: int = 1
+):
+    current_status = get_search_index_status(collection, index_name)
+    attempts = 0
+    while current_status != "READY" and attempts < retries:
+        attempts += 1
+        print(f"attempt {attempts}: waiting for search index to be READY: {current_status}")
+        time.sleep(interval)
+        current_status = get_search_index_status(collection, index_name)
+
+    if current_status != "READY":
+        raise TimeoutError("search index never detected as READY")
+
+
+@pytest.fixture
+def destination_collection() -> Collection:
+    env_data = get_env_data()
+    collection_name = f"utic-test-output-{uuid.uuid4()}"
+    with get_client() as client:
+        database = client[env_data.database]
+        print(f"creating collection in database {database}: {collection_name}")
+        collection = database.create_collection(name=collection_name)
+        search_index_name = "embeddings"
+        collection.create_search_index(
+            model=SearchIndexModel(
+                name=search_index_name,
+                definition={
+                    "mappings": {
+                        "dynamic": True,
+                        "fields": {
+                            "embeddings": [
+                                {"type": "knnVector", "dimensions": 384, "similarity": "euclidean"}
+                            ]
+                        },
+                    }
+                },
+            )
+        )
+        collection.create_index("record_id")
+        wait_for_collection(database=database, collection_name=collection_name)
+        wait_for_search_index(collection=collection, index_name=search_index_name)
+        try:
+            yield collection
+        finally:
+            print(f"deleting collection: {collection_name}")
+            collection.drop()
+
+
+def validate_collection_count(
+    collection: Collection, expected_records: int, retries: int = 10, interval: int = 1
+) -> None:
+    count = collection.count_documents(filter={})
+    attempt = 0
+    while count != expected_records and attempt < retries:
+        attempt += 1
+        print(f"attempt {attempt} to get count of collection {count} to match {expected_records}")
+        time.sleep(interval)
+        count = collection.count_documents(filter={})
+    assert (
+        count == expected_records
+    ), f"expected count ({expected_records}) does not match how many records were found: {count}"
+
+
+def validate_collection_vector(
+    collection: Collection, embedding: list[float], text: str, retries: int = 30, interval: int = 1
+) -> None:
+    pipeline = [
+        {
+            "$vectorSearch": {
+                "index": "embeddings",
+                "path": "embeddings",
+                "queryVector": embedding,
+                "numCandidates": 150,
+                "limit": 10,
+            },
+        },
+        {"$project": {"_id": 0, "text": 1, "score": {"$meta": "vectorSearchScore"}}},
+    ]
+    attempts = 0
+    results = list(collection.aggregate(pipeline=pipeline))
+    while not results and attempts < retries:
+        attempts += 1
+        print(f"attempt {attempts}, waiting for valid results: {results}")
+        time.sleep(interval)
+        results = list(collection.aggregate(pipeline=pipeline))
+    if not results:
+        raise TimeoutError("Timed out waiting for valid results")
+    print(f"found results on attempt {attempts}")
+    top_result = results[0]
+    assert top_result["score"] == 1.0, "score detected should be 1: {}".format(top_result["score"])
+    assert top_result["text"] == text, "text detected should be {}, found: {}".format(
+        text, top_result["text"]
+    )
+    for r in results[1:]:
+        assert r["score"] < 1.0, "score detected should be less than 1: {}".format(r["score"])
+
+
+@pytest.mark.asyncio
+@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG)
+@requires_env("MONGODB_URI", "MONGODB_DATABASE")
+async def test_mongodb_source(temp_dir: Path):
+    env_data = get_env_data()
+    indexer_config = MongoDBIndexerConfig(database=env_data.database, collection=SOURCE_COLLECTION)
+    download_config = MongoDBDownloaderConfig(download_dir=temp_dir)
+    connection_config = MongoDBConnectionConfig(
+        access_config=MongoDBAccessConfig(uri=env_data.uri.get_secret_value()),
+    )
+    indexer = MongoDBIndexer(connection_config=connection_config, index_config=indexer_config)
+    downloader = MongoDBDownloader(
+        connection_config=connection_config, download_config=download_config
+    )
+    await source_connector_validation(
+        indexer=indexer,
+        downloader=downloader,
+        configs=ValidationConfigs(
+            test_id=CONNECTOR_TYPE, expected_num_files=4, validate_downloaded_files=True
+        ),
+    )
+
+
+@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG)
+def test_mongodb_indexer_precheck_fail_no_host():
+    indexer_config = MongoDBIndexerConfig(
+        database="non-existent-database", collection="non-existent-database"
+    )
+    connection_config = MongoDBConnectionConfig(
+        access_config=MongoDBAccessConfig(uri="mongodb+srv://ingest-test.hgaig.mongodb"),
+    )
+    indexer = MongoDBIndexer(connection_config=connection_config, index_config=indexer_config)
+    with pytest.raises(SourceConnectionError):
+        indexer.precheck()
+
+
+@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG)
+@requires_env("MONGODB_URI", "MONGODB_DATABASE")
+def test_mongodb_indexer_precheck_fail_no_database():
+    env_data = get_env_data()
+    indexer_config = MongoDBIndexerConfig(
+        database="non-existent-database", collection=SOURCE_COLLECTION
+    )
+    connection_config = MongoDBConnectionConfig(
+        access_config=MongoDBAccessConfig(uri=env_data.uri.get_secret_value()),
+    )
+    indexer = MongoDBIndexer(connection_config=connection_config, index_config=indexer_config)
+    with pytest.raises(SourceConnectionError):
+        indexer.precheck()
+
+
+@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG)
+@requires_env("MONGODB_URI", "MONGODB_DATABASE")
+def test_mongodb_indexer_precheck_fail_no_collection():
+    env_data = get_env_data()
+    indexer_config = MongoDBIndexerConfig(
+        database=env_data.database, collection="non-existent-collection"
+    )
+    connection_config = MongoDBConnectionConfig(
+        access_config=MongoDBAccessConfig(uri=env_data.uri.get_secret_value()),
+    )
+    indexer = MongoDBIndexer(connection_config=connection_config, index_config=indexer_config)
+    with pytest.raises(SourceConnectionError):
+        indexer.precheck()
+
+
+@pytest.mark.asyncio
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
+@requires_env("MONGODB_URI", "MONGODB_DATABASE")
+async def test_mongodb_destination(
+    upload_file: Path,
+    destination_collection: Collection,
+    tmp_path: Path,
+):
+    env_data = get_env_data()
+    file_data = FileData(
+        source_identifiers=SourceIdentifiers(fullpath=upload_file.name, filename=upload_file.name),
+        connector_type=CONNECTOR_TYPE,
+        identifier="mongodb_mock_id",
+    )
+    connection_config = MongoDBConnectionConfig(
+        access_config=MongoDBAccessConfig(uri=env_data.uri.get_secret_value()),
+    )
+
+    upload_config = MongoDBUploaderConfig(
+        database=env_data.database,
+        collection=destination_collection.name,
+    )
+    uploader = MongoDBUploader(connection_config=connection_config, upload_config=upload_config)
+    uploader.precheck()
+    uploader.run(path=upload_file, file_data=file_data)
+
+    with upload_file.open() as f:
+        staged_elements = json.load(f)
+    expected_records = len(staged_elements)
+    validate_collection_count(collection=destination_collection, expected_records=expected_records)
+    first_element = staged_elements[0]
+    validate_collection_vector(
+        collection=destination_collection,
+        embedding=first_element["embeddings"],
+        text=first_element["text"],
+    )
+
+    uploader.run(path=upload_file, file_data=file_data)
+    validate_collection_count(collection=destination_collection, expected_records=expected_records)
+
+
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
+def test_mongodb_uploader_precheck_fail_no_host():
+    upload_config = MongoDBUploaderConfig(
+        database="database",
+        collection="collection",
+    )
+    connection_config = MongoDBConnectionConfig(
+        access_config=MongoDBAccessConfig(uri="mongodb+srv://ingest-test.hgaig.mongodb"),
+    )
+    uploader = MongoDBUploader(connection_config=connection_config, upload_config=upload_config)
+    with pytest.raises(DestinationConnectionError):
+        uploader.precheck()
+
+
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
+@requires_env("MONGODB_URI", "MONGODB_DATABASE")
+def test_mongodb_uploader_precheck_fail_no_database():
+    env_data = get_env_data()
+    upload_config = MongoDBUploaderConfig(
+        database="database",
+        collection="collection",
+    )
+    connection_config = MongoDBConnectionConfig(
+        access_config=MongoDBAccessConfig(uri=env_data.uri.get_secret_value()),
+    )
+    uploader = MongoDBUploader(connection_config=connection_config, upload_config=upload_config)
+    with pytest.raises(DestinationConnectionError):
+        uploader.precheck()
+
+
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
+@requires_env("MONGODB_URI", "MONGODB_DATABASE")
+def test_mongodb_uploader_precheck_fail_no_collection():
+    env_data = get_env_data()
+    upload_config = MongoDBUploaderConfig(
+        database=env_data.database,
+        collection="collection",
+    )
+    connection_config = MongoDBConnectionConfig(
+        access_config=MongoDBAccessConfig(uri=env_data.uri.get_secret_value()),
+    )
+    uploader = MongoDBUploader(connection_config=connection_config, upload_config=upload_config)
+    with pytest.raises(DestinationConnectionError):
+        uploader.precheck()
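`validate_collection_vector` above is the interesting assertion: it runs an Atlas `$vectorSearch` aggregation with the exact embedding that was uploaded and expects the matching document back with a similarity score of exactly 1.0, and all other candidates below 1.0. A standalone sketch of that query shape (the URI, database, and collection names are placeholders; the index must be the 384-dimension `knnVector` index the fixture creates):

from pymongo import MongoClient

client = MongoClient("mongodb+srv://<user>:<password>@<cluster>/")  # placeholder URI
collection = client["my-database"]["my-collection"]  # placeholder names
query_vector = [0.0] * 384  # must match the dimensions of the indexed vectors

pipeline = [
    {
        "$vectorSearch": {
            "index": "embeddings",       # name of the Atlas search index
            "path": "embeddings",        # document field holding the vectors
            "queryVector": query_vector,
            "numCandidates": 150,
            "limit": 10,
        }
    },
    # Keep only the text plus the similarity score Atlas computes.
    {"$project": {"_id": 0, "text": 1, "score": {"$meta": "vectorSearchScore"}}},
]
for doc in collection.aggregate(pipeline):
    print(doc["score"], doc["text"][:60])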
test/integration/connectors/weaviate/test_cloud.py
ADDED

@@ -0,0 +1,34 @@
+import pytest
+from pydantic import ValidationError
+
+from unstructured_ingest.v2.processes.connectors.weaviate.cloud import (
+    CloudWeaviateAccessConfig,
+    CloudWeaviateConnectionConfig,
+)
+
+
+def test_weaviate_failing_connection_config():
+    with pytest.raises(ValidationError):
+        CloudWeaviateConnectionConfig(
+            access_config=CloudWeaviateAccessConfig(api_key="my key", password="password"),
+            username="username",
+            cluster_url="clusterurl",
+        )
+
+
+def test_weaviate_connection_config_happy_path():
+    CloudWeaviateConnectionConfig(
+        access_config=CloudWeaviateAccessConfig(
+            api_key="my key",
+        ),
+        cluster_url="clusterurl",
+    )
+
+
+def test_weaviate_connection_config_anonymous():
+    CloudWeaviateConnectionConfig(
+        access_config=CloudWeaviateAccessConfig(api_key="my key", password="password"),
+        username="username",
+        anonymous=True,
+        cluster_url="clusterurl",
+    )
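Taken together, the three tests document the config's validation rule: API-key and username/password credentials are mutually exclusive unless `anonymous=True` is set, in which case both are tolerated. A sketch of the same rule as a standalone pydantic model; this illustrates the pattern and is not the connector's actual implementation:

from typing import Optional

from pydantic import BaseModel, model_validator


class ExclusiveAuthConfig(BaseModel):
    # Hypothetical mirror of the validation the tests above assert.
    api_key: Optional[str] = None
    username: Optional[str] = None
    password: Optional[str] = None
    anonymous: bool = False

    @model_validator(mode="after")
    def check_exclusive_auth(self) -> "ExclusiveAuthConfig":
        # Mixing API-key auth with username/password is ambiguous, so it is
        # rejected, unless anonymous access is explicitly requested.
        if not self.anonymous and self.api_key and (self.username or self.password):
            raise ValueError("provide either api_key or username/password, not both")
        return self


ExclusiveAuthConfig(api_key="my key")                           # ok
ExclusiveAuthConfig(api_key="k", password="p", anonymous=True)  # ok
# ExclusiveAuthConfig(api_key="k", password="p")  # raises ValidationError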
test/unit/test_utils.py
CHANGED

@@ -8,7 +8,11 @@ import pytz

 from unstructured_ingest.cli.utils import extract_config
 from unstructured_ingest.interfaces import BaseConfig
-from unstructured_ingest.utils.string_and_date_utils import
+from unstructured_ingest.utils.string_and_date_utils import (
+    ensure_isoformat_datetime,
+    json_to_dict,
+    truncate_string_bytes,
+)


 @dataclass
@@ -162,3 +166,19 @@ def test_ensure_isoformat_datetime_fails_on_string():
 def test_ensure_isoformat_datetime_fails_on_int():
     with pytest.raises(TypeError):
         ensure_isoformat_datetime(1111)
+
+
+def test_truncate_string_bytes_return_truncated_string():
+    test_string = "abcdef안녕하세요ghijklmn방갑습니opqrstu 더 길어지면 안되는 문자열vwxyz"
+    max_bytes = 11
+    result = truncate_string_bytes(test_string, max_bytes)
+    assert result == "abcdef안"
+    assert len(result.encode("utf-8")) <= max_bytes
+
+
+def test_truncate_string_bytes_return_untouched_string():
+    test_string = "abcdef"
+    max_bytes = 11
+    result = truncate_string_bytes(test_string, max_bytes)
+    assert result == "abcdef"
+    assert len(result.encode("utf-8")) <= max_bytes
unstructured_ingest/__version__.py
CHANGED

@@ -1 +1 @@
-__version__ = "0.3.1"  # pragma: no cover
+__version__ = "0.3.3"  # pragma: no cover
unstructured_ingest/utils/string_and_date_utils.py
CHANGED

@@ -37,3 +37,13 @@ def ensure_isoformat_datetime(timestamp: t.Union[datetime, str]) -> str:
         raise ValueError(f"String '{timestamp}' could not be parsed as a datetime.") from e
     else:
         raise TypeError(f"Expected input type datetime or str, but got {type(timestamp)}.")
+
+
+def truncate_string_bytes(string: str, max_bytes: int, encoding: str = "utf-8") -> str:
+    """
+    Truncates a string to a specified maximum number of bytes.
+    """
+    encoded_string = str(string).encode(encoding)
+    if len(encoded_string) <= max_bytes:
+        return string
+    return encoded_string[:max_bytes].decode(encoding, errors="ignore")
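One subtlety worth noting: slicing the encoded bytes can land in the middle of a multi-byte character, and `errors="ignore"` silently drops the dangling partial sequence, so the result can be shorter than `max_bytes`. This is exactly the behavior the new unit test pins down:

s = "abcdef안녕"
encoded = s.encode("utf-8")
print(len(encoded))  # 12: six ASCII bytes plus two 3-byte Hangul characters

truncated = encoded[:11].decode("utf-8", errors="ignore")
print(truncated)  # "abcdef안": the 11-byte cut lands mid-"녕", so the
# partial sequence is discarded and only 9 bytes survive, matching the
# unit test's expectation for max_bytes=11.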
unstructured_ingest/v2/processes/connectors/astradb.py
CHANGED

@@ -19,6 +19,7 @@ from unstructured_ingest.error import (
 )
 from unstructured_ingest.utils.data_prep import batch_generator
 from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.utils.string_and_date_utils import truncate_string_bytes
 from unstructured_ingest.v2.constants import RECORD_ID_LABEL
 from unstructured_ingest.v2.interfaces import (
     AccessConfig,
@@ -50,6 +51,8 @@ if TYPE_CHECKING:

 CONNECTOR_TYPE = "astradb"

+MAX_CONTENT_PARAM_BYTE_SIZE = 8000
+

 class AstraDBAccessConfig(AccessConfig):
     token: str = Field(description="Astra DB Token with access to the database.")
@@ -301,7 +304,20 @@ class AstraDBUploadStager(UploadStager):
         default_factory=lambda: AstraDBUploadStagerConfig()
     )

+    def truncate_dict_elements(self, element_dict: dict) -> None:
+        text = element_dict.pop("text", None)
+        if text is not None:
+            element_dict["text"] = truncate_string_bytes(text, MAX_CONTENT_PARAM_BYTE_SIZE)
+        metadata = element_dict.get("metadata")
+        if metadata is not None and isinstance(metadata, dict):
+            text_as_html = element_dict["metadata"].pop("text_as_html", None)
+            if text_as_html is not None:
+                element_dict["metadata"]["text_as_html"] = truncate_string_bytes(
+                    text_as_html, MAX_CONTENT_PARAM_BYTE_SIZE
+                )
+
     def conform_dict(self, element_dict: dict, file_data: FileData) -> dict:
+        self.truncate_dict_elements(element_dict)
         return {
             "$vector": element_dict.pop("embeddings", None),
             "content": element_dict.pop("text", None),
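The stager now caps `text` and `metadata["text_as_html"]` at `MAX_CONTENT_PARAM_BYTE_SIZE` (8000 bytes) before the element is conformed into an Astra DB document, presumably to stay under the service's limit on indexed field size. A small sketch of the effect on an element dict (the contents are illustrative):

from unstructured_ingest.v2.processes.connectors.astradb import AstraDBUploadStager

stager = AstraDBUploadStager()
element = {
    "text": "x" * 10_000,  # oversized body, illustrative
    "metadata": {"text_as_html": "<p>" + "y" * 10_000 + "</p>"},
}
stager.truncate_dict_elements(element)
# Both oversized fields are truncated in place to at most 8000 UTF-8 bytes.
assert len(element["text"].encode("utf-8")) <= 8000
assert len(element["metadata"]["text_as_html"].encode("utf-8")) <= 8000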
unstructured_ingest/v2/processes/connectors/google_drive.py
CHANGED

@@ -161,7 +161,7 @@ class GoogleDriveIndexer(Indexer):
                 and isinstance(parent_root_path, str)
             ):
                 fullpath = f"{parent_path}/{filename}"
-                rel_path = fullpath.
+                rel_path = Path(fullpath).relative_to(parent_root_path).as_posix()
                 source_identifiers = SourceIdentifiers(
                     filename=filename, fullpath=fullpath, rel_path=rel_path
                 )
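The one-line fix swaps brittle string manipulation for `pathlib`, which normalizes separators and fails loudly when the file is not actually under the root. For example:

from pathlib import Path

# Illustrative paths, not real Google Drive values.
fullpath = "drive-root/reports/2024/q1.pdf"
parent_root_path = "drive-root"

rel_path = Path(fullpath).relative_to(parent_root_path).as_posix()
print(rel_path)  # "reports/2024/q1.pdf"

# Path.relative_to raises ValueError if fullpath does not live under the
# root, surfacing indexer bugs instead of emitting a mangled rel_path.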
unstructured_ingest/v2/processes/connectors/lancedb/__init__.py
CHANGED

@@ -6,12 +6,25 @@ from .aws import CONNECTOR_TYPE as LANCEDB_S3_CONNECTOR_TYPE
 from .aws import lancedb_aws_destination_entry
 from .azure import CONNECTOR_TYPE as LANCEDB_AZURE_CONNECTOR_TYPE
 from .azure import lancedb_azure_destination_entry
+from .cloud import CONNECTOR_TYPE as LANCEDB_CLOUD_CONNECTOR_TYPE
+from .cloud import lancedb_cloud_destination_entry
 from .gcp import CONNECTOR_TYPE as LANCEDB_GCS_CONNECTOR_TYPE
 from .gcp import lancedb_gcp_destination_entry
 from .local import CONNECTOR_TYPE as LANCEDB_LOCAL_CONNECTOR_TYPE
 from .local import lancedb_local_destination_entry

-add_destination_entry(
-
-
-add_destination_entry(
+add_destination_entry(
+    destination_type=LANCEDB_S3_CONNECTOR_TYPE, entry=lancedb_aws_destination_entry
+)
+add_destination_entry(
+    destination_type=LANCEDB_AZURE_CONNECTOR_TYPE, entry=lancedb_azure_destination_entry
+)
+add_destination_entry(
+    destination_type=LANCEDB_GCS_CONNECTOR_TYPE, entry=lancedb_gcp_destination_entry
+)
+add_destination_entry(
+    destination_type=LANCEDB_LOCAL_CONNECTOR_TYPE, entry=lancedb_local_destination_entry
+)
+add_destination_entry(
+    destination_type=LANCEDB_CLOUD_CONNECTOR_TYPE, entry=lancedb_cloud_destination_entry
+)
unstructured_ingest/v2/processes/connectors/lancedb/aws.py
CHANGED

@@ -15,28 +15,28 @@ from unstructured_ingest.v2.processes.connectors.lancedb.lancedb import (
 CONNECTOR_TYPE = "lancedb_aws"


-class
+class LanceDBAwsAccessConfig(AccessConfig):
     aws_access_key_id: str = Field(description="The AWS access key ID to use.")
     aws_secret_access_key: str = Field(description="The AWS secret access key to use.")


-class
-    access_config: Secret[
+class LanceDBAwsConnectionConfig(LanceDBRemoteConnectionConfig):
+    access_config: Secret[LanceDBAwsAccessConfig]

     def get_storage_options(self) -> dict:
         return {**self.access_config.get_secret_value().model_dump(), "timeout": self.timeout}


 @dataclass
-class
+class LanceDBAwsUploader(LanceDBUploader):
     upload_config: LanceDBUploaderConfig
-    connection_config:
+    connection_config: LanceDBAwsConnectionConfig
     connector_type: str = CONNECTOR_TYPE


 lancedb_aws_destination_entry = DestinationRegistryEntry(
-    connection_config=
-    uploader=
+    connection_config=LanceDBAwsConnectionConfig,
+    uploader=LanceDBAwsUploader,
     uploader_config=LanceDBUploaderConfig,
     upload_stager_config=LanceDBUploadStagerConfig,
     upload_stager=LanceDBUploadStager,