unstructured-ingest 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of unstructured-ingest might be problematic.
- test/integration/connectors/elasticsearch/__init__.py +0 -0
- test/integration/connectors/elasticsearch/conftest.py +34 -0
- test/integration/connectors/elasticsearch/test_elasticsearch.py +308 -0
- test/integration/connectors/elasticsearch/test_opensearch.py +302 -0
- test/integration/connectors/sql/test_postgres.py +10 -4
- test/integration/connectors/sql/test_singlestore.py +8 -4
- test/integration/connectors/sql/test_snowflake.py +10 -6
- test/integration/connectors/sql/test_sqlite.py +4 -4
- test/integration/connectors/test_astradb.py +50 -3
- test/integration/connectors/test_delta_table.py +46 -0
- test/integration/connectors/test_kafka.py +40 -6
- test/integration/connectors/test_lancedb.py +210 -0
- test/integration/connectors/test_milvus.py +141 -0
- test/integration/connectors/test_mongodb.py +332 -0
- test/integration/connectors/test_pinecone.py +53 -1
- test/integration/connectors/utils/docker.py +81 -15
- test/integration/connectors/utils/validation.py +10 -0
- test/integration/connectors/weaviate/__init__.py +0 -0
- test/integration/connectors/weaviate/conftest.py +15 -0
- test/integration/connectors/weaviate/test_local.py +131 -0
- unstructured_ingest/__version__.py +1 -1
- unstructured_ingest/pipeline/reformat/embedding.py +1 -1
- unstructured_ingest/utils/data_prep.py +9 -1
- unstructured_ingest/v2/processes/connectors/__init__.py +3 -16
- unstructured_ingest/v2/processes/connectors/astradb.py +2 -2
- unstructured_ingest/v2/processes/connectors/azure_ai_search.py +4 -0
- unstructured_ingest/v2/processes/connectors/delta_table.py +20 -4
- unstructured_ingest/v2/processes/connectors/elasticsearch/__init__.py +19 -0
- unstructured_ingest/v2/processes/connectors/{elasticsearch.py → elasticsearch/elasticsearch.py} +92 -46
- unstructured_ingest/v2/processes/connectors/{opensearch.py → elasticsearch/opensearch.py} +1 -1
- unstructured_ingest/v2/processes/connectors/google_drive.py +1 -1
- unstructured_ingest/v2/processes/connectors/kafka/kafka.py +6 -0
- unstructured_ingest/v2/processes/connectors/lancedb/__init__.py +17 -0
- unstructured_ingest/v2/processes/connectors/lancedb/aws.py +43 -0
- unstructured_ingest/v2/processes/connectors/lancedb/azure.py +43 -0
- unstructured_ingest/v2/processes/connectors/lancedb/gcp.py +44 -0
- unstructured_ingest/v2/processes/connectors/lancedb/lancedb.py +161 -0
- unstructured_ingest/v2/processes/connectors/lancedb/local.py +44 -0
- unstructured_ingest/v2/processes/connectors/milvus.py +72 -27
- unstructured_ingest/v2/processes/connectors/mongodb.py +122 -111
- unstructured_ingest/v2/processes/connectors/pinecone.py +24 -7
- unstructured_ingest/v2/processes/connectors/sql/sql.py +97 -26
- unstructured_ingest/v2/processes/connectors/weaviate/__init__.py +25 -0
- unstructured_ingest/v2/processes/connectors/weaviate/cloud.py +164 -0
- unstructured_ingest/v2/processes/connectors/weaviate/embedded.py +90 -0
- unstructured_ingest/v2/processes/connectors/weaviate/local.py +73 -0
- unstructured_ingest/v2/processes/connectors/weaviate/weaviate.py +299 -0
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/METADATA +19 -19
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/RECORD +54 -33
- unstructured_ingest/v2/processes/connectors/weaviate.py +0 -242
- /test/integration/connectors/{test_azure_cog_search.py → test_azure_ai_search.py} +0 -0
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/LICENSE.md +0 -0
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/WHEEL +0 -0
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/entry_points.txt +0 -0
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/top_level.txt +0 -0
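Beyond the new tests, the file list shows a reorganization of the connector package: the flat weaviate.py, elasticsearch.py, and opensearch.py modules are replaced by weaviate/ and elasticsearch/ subpackages, and a new lancedb/ subpackage is added. The sketch below is not part of the package; it only illustrates, using the module paths listed above, how a downstream project might probe whether it is running against the new 0.3.2 layout. The probe logic itself is an assumption for illustration.

```python
# Minimal sketch (not from this diff): probe for the 0.3.2 connector-module layout.
# The module paths come from the file list above; everything else here is assumed.
import importlib

NEW_LAYOUT_MODULES = [
    "unstructured_ingest.v2.processes.connectors.weaviate.local",
    "unstructured_ingest.v2.processes.connectors.weaviate.cloud",
    "unstructured_ingest.v2.processes.connectors.elasticsearch.elasticsearch",
    "unstructured_ingest.v2.processes.connectors.elasticsearch.opensearch",
    "unstructured_ingest.v2.processes.connectors.lancedb.lancedb",
]

for name in NEW_LAYOUT_MODULES:
    try:
        importlib.import_module(name)
        print(f"ok: {name}")
    except ModuleNotFoundError:
        # On 0.3.0 these subpackages do not exist (the flat modules do),
        # so imports keyed to the old paths may need updating.
        print(f"not found (pre-0.3.2 layout?): {name}")
```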
test/integration/connectors/sql/test_snowflake.py
@@ -86,7 +86,7 @@ async def test_snowflake_source():
         image="localstack/snowflake",
         environment={"LOCALSTACK_AUTH_TOKEN": token, "EXTRA_CORS_ALLOWED_ORIGINS": "*"},
         ports={4566: 4566, 443: 443},
-
+        healthcheck_retries=30,
     ):
         seed_data()
         with tempfile.TemporaryDirectory() as tmpdir:
@@ -156,7 +156,7 @@ async def test_snowflake_destination(upload_file: Path):
         image="localstack/snowflake",
         environment={"LOCALSTACK_AUTH_TOKEN": token, "EXTRA_CORS_ALLOWED_ORIGINS": "*"},
         ports={4566: 4566, 443: 443},
-
+        healthcheck_retries=30,
     ):
         init_db_destination()
         with tempfile.TemporaryDirectory() as tmpdir:
@@ -192,10 +192,8 @@ async def test_snowflake_destination(upload_file: Path):
                     host=connect_params["host"],
                 )
             )
-            if uploader.is_async():
-                await uploader.run_async(path=staged_path, file_data=mock_file_data)
-            else:
-                uploader.run(path=staged_path, file_data=mock_file_data)
+
+            uploader.run(path=staged_path, file_data=mock_file_data)
 
             staged_df = pd.read_json(staged_path, orient="records", lines=True)
             expected_num_elements = len(staged_df)
@@ -203,3 +201,9 @@ async def test_snowflake_destination(upload_file: Path):
                 connect_params=connect_params,
                 expected_num_elements=expected_num_elements,
             )
+
+            uploader.run(path=staged_path, file_data=mock_file_data)
+            validate_destination(
+                connect_params=connect_params,
+                expected_num_elements=expected_num_elements,
+            )
test/integration/connectors/sql/test_sqlite.py
@@ -138,10 +138,10 @@ async def test_sqlite_destination(upload_file: Path):
         uploader = SQLiteUploader(
             connection_config=SQLiteConnectionConfig(database_path=db_path)
         )
-        if uploader.is_async():
-            await uploader.run_async(path=staged_path, file_data=mock_file_data)
-        else:
-            uploader.run(path=staged_path, file_data=mock_file_data)
+        uploader.run(path=staged_path, file_data=mock_file_data)
 
         staged_df = pd.read_json(staged_path, orient="records", lines=True)
         validate_destination(db_path=db_path, expected_num_elements=len(staged_df))
+
+        uploader.run(path=staged_path, file_data=mock_file_data)
+        validate_destination(db_path=db_path, expected_num_elements=len(staged_df))
test/integration/connectors/test_astradb.py
@@ -8,20 +8,67 @@ import pytest
 from astrapy import Collection
 from astrapy import DataAPIClient as AstraDBClient
 
-from test.integration.connectors.utils.constants import (
-    DESTINATION_TAG,
-)
+from test.integration.connectors.utils.constants import DESTINATION_TAG, SOURCE_TAG
 from test.integration.utils import requires_env
 from unstructured_ingest.v2.interfaces import FileData, SourceIdentifiers
 from unstructured_ingest.v2.processes.connectors.astradb import (
     CONNECTOR_TYPE,
     AstraDBAccessConfig,
     AstraDBConnectionConfig,
+    AstraDBIndexer,
+    AstraDBIndexerConfig,
     AstraDBUploader,
     AstraDBUploaderConfig,
     AstraDBUploadStager,
+    DestinationConnectionError,
+    SourceConnectionError,
 )
 
+EXISTENT_COLLECTION_NAME = "ingest_test_src"
+NONEXISTENT_COLLECTION_NAME = "nonexistant"
+
+
+@pytest.fixture
+def connection_config() -> AstraDBConnectionConfig:
+    return AstraDBConnectionConfig(
+        access_config=AstraDBAccessConfig(
+            token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
+            api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
+        )
+    )
+
+
+@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG, DESTINATION_TAG)
+@requires_env("ASTRA_DB_APPLICATION_TOKEN", "ASTRA_DB_API_ENDPOINT")
+def test_precheck_succeeds(connection_config: AstraDBConnectionConfig):
+    indexer = AstraDBIndexer(
+        connection_config=connection_config,
+        index_config=AstraDBIndexerConfig(collection_name=EXISTENT_COLLECTION_NAME),
+    )
+    uploader = AstraDBUploader(
+        connection_config=connection_config,
+        upload_config=AstraDBUploaderConfig(collection_name=EXISTENT_COLLECTION_NAME),
+    )
+    indexer.precheck()
+    uploader.precheck()
+
+
+@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG, DESTINATION_TAG)
+@requires_env("ASTRA_DB_APPLICATION_TOKEN", "ASTRA_DB_API_ENDPOINT")
+def test_precheck_fails(connection_config: AstraDBConnectionConfig):
+    indexer = AstraDBIndexer(
+        connection_config=connection_config,
+        index_config=AstraDBIndexerConfig(collection_name=NONEXISTENT_COLLECTION_NAME),
+    )
+    uploader = AstraDBUploader(
+        connection_config=connection_config,
+        upload_config=AstraDBUploaderConfig(collection_name=NONEXISTENT_COLLECTION_NAME),
+    )
+    with pytest.raises(expected_exception=SourceConnectionError):
+        indexer.precheck()
+    with pytest.raises(expected_exception=DestinationConnectionError):
+        uploader.precheck()
+
 
 @dataclass(frozen=True)
 class EnvData:
test/integration/connectors/test_delta_table.py
@@ -136,3 +136,49 @@ async def test_delta_table_destination_s3(upload_file: Path, temp_dir: Path):
         secret=aws_credentials["AWS_SECRET_ACCESS_KEY"],
     )
     s3fs.rm(path=destination_path, recursive=True)
+
+
+@pytest.mark.asyncio
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
+@requires_env("S3_INGEST_TEST_ACCESS_KEY", "S3_INGEST_TEST_SECRET_KEY")
+async def test_delta_table_destination_s3_bad_creds(upload_file: Path, temp_dir: Path):
+    aws_credentials = {
+        "AWS_ACCESS_KEY_ID": "bad key",
+        "AWS_SECRET_ACCESS_KEY": "bad secret",
+        "AWS_REGION": "us-east-2",
+    }
+    s3_bucket = "s3://utic-platform-test-destination"
+    destination_path = f"{s3_bucket}/destination/test"
+    connection_config = DeltaTableConnectionConfig(
+        access_config=DeltaTableAccessConfig(
+            aws_access_key_id=aws_credentials["AWS_ACCESS_KEY_ID"],
+            aws_secret_access_key=aws_credentials["AWS_SECRET_ACCESS_KEY"],
+        ),
+        aws_region=aws_credentials["AWS_REGION"],
+        table_uri=destination_path,
+    )
+    stager_config = DeltaTableUploadStagerConfig()
+    stager = DeltaTableUploadStager(upload_stager_config=stager_config)
+    new_upload_file = stager.run(
+        elements_filepath=upload_file,
+        output_dir=temp_dir,
+        output_filename=upload_file.name,
+    )
+
+    upload_config = DeltaTableUploaderConfig()
+    uploader = DeltaTableUploader(connection_config=connection_config, upload_config=upload_config)
+    file_data = FileData(
+        source_identifiers=SourceIdentifiers(
+            fullpath=upload_file.name, filename=new_upload_file.name
+        ),
+        connector_type=CONNECTOR_TYPE,
+        identifier="mock file data",
+    )
+
+    with pytest.raises(Exception) as excinfo:
+        if uploader.is_async():
+            await uploader.run_async(path=new_upload_file, file_data=file_data)
+        else:
+            uploader.run(path=new_upload_file, file_data=file_data)
+
+    assert "403 Forbidden" in str(excinfo.value), f"Exception message did not match: {str(excinfo)}"
test/integration/connectors/test_kafka.py
@@ -1,5 +1,6 @@
 import json
 import tempfile
+import time
 from pathlib import Path
 
 import pytest
@@ -33,12 +34,36 @@ SEED_MESSAGES = 10
 TOPIC = "fake-topic"
 
 
+def get_admin_client() -> AdminClient:
+    conf = {
+        "bootstrap.servers": "localhost:29092",
+    }
+    return AdminClient(conf)
+
+
 @pytest.fixture
 def docker_compose_ctx():
     with docker_compose_context(docker_compose_path=env_setup_path / "kafka") as ctx:
         yield ctx
 
 
+def wait_for_topic(topic: str, retries: int = 10, interval: int = 1):
+    admin_client = get_admin_client()
+    current_topics = admin_client.list_topics().topics
+    attempts = 0
+    while topic not in current_topics and attempts < retries:
+        attempts += 1
+        print(
+            "Attempt {}: Waiting for topic {} to exist in {}".format(
+                attempts, topic, ", ".join(current_topics)
+            )
+        )
+        time.sleep(interval)
+        current_topics = admin_client.list_topics().topics
+    if topic not in current_topics:
+        raise TimeoutError(f"Timeout out waiting for topic {topic} to exist")
+
+
 @pytest.fixture
 def kafka_seed_topic(docker_compose_ctx) -> str:
     conf = {
@@ -50,15 +75,13 @@ def kafka_seed_topic(docker_compose_ctx) -> str:
         producer.produce(topic=TOPIC, value=message)
     producer.flush(timeout=10)
     print(f"kafka topic {TOPIC} seeded with {SEED_MESSAGES} messages")
+    wait_for_topic(topic=TOPIC)
    return TOPIC
 
 
 @pytest.fixture
 def kafka_upload_topic(docker_compose_ctx) -> str:
-    conf = {
-        "bootstrap.servers": "localhost:29092",
-    }
-    admin_client = AdminClient(conf)
+    admin_client = get_admin_client()
     admin_client.create_topics([NewTopic(TOPIC, 1, 1)])
     return TOPIC
 
@@ -88,7 +111,7 @@ async def test_kafka_source_local(kafka_seed_topic: str):
 
 
 @pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG)
-def test_kafak_source_local_precheck_fail():
+def test_kafka_source_local_precheck_fail_no_cluster():
     connection_config = LocalKafkaConnectionConfig(bootstrap_server="localhost", port=29092)
     indexer = LocalKafkaIndexer(
         connection_config=connection_config,
@@ -98,6 +121,17 @@ def test_kafak_source_local_precheck_fail():
         indexer.precheck()
 
 
+@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG)
+def test_kafka_source_local_precheck_fail_no_topic(kafka_seed_topic: str):
+    connection_config = LocalKafkaConnectionConfig(bootstrap_server="localhost", port=29092)
+    indexer = LocalKafkaIndexer(
+        connection_config=connection_config,
+        index_config=LocalKafkaIndexerConfig(topic="topic", num_messages_to_consume=5),
+    )
+    with pytest.raises(SourceConnectionError):
+        indexer.precheck()
+
+
 def get_all_messages(topic: str, max_empty_messages: int = 5) -> list[dict]:
     conf = {
         "bootstrap.servers": "localhost:29092",
@@ -158,7 +192,7 @@ async def test_kafka_destination_local(upload_file: Path, kafka_upload_topic: str):
 
 
 @pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
-def
+def test_kafka_destination_local_precheck_fail_no_cluster():
     uploader = LocalKafkaUploader(
         connection_config=LocalKafkaConnectionConfig(bootstrap_server="localhost", port=29092),
         upload_config=LocalKafkaUploaderConfig(topic=TOPIC, batch_size=10),
test/integration/connectors/test_lancedb.py
@@ -0,0 +1,210 @@
+import os
+from pathlib import Path
+from typing import Literal, Union
+from uuid import uuid4
+
+import lancedb
+import pandas as pd
+import pyarrow as pa
+import pytest
+import pytest_asyncio
+from lancedb import AsyncConnection
+from upath import UPath
+
+from test.integration.connectors.utils.constants import DESTINATION_TAG
+from unstructured_ingest.v2.interfaces.file_data import FileData, SourceIdentifiers
+from unstructured_ingest.v2.processes.connectors.lancedb.aws import (
+    LanceDBS3AccessConfig,
+    LanceDBS3ConnectionConfig,
+    LanceDBS3Uploader,
+)
+from unstructured_ingest.v2.processes.connectors.lancedb.azure import (
+    LanceDBAzureAccessConfig,
+    LanceDBAzureConnectionConfig,
+    LanceDBAzureUploader,
+)
+from unstructured_ingest.v2.processes.connectors.lancedb.gcp import (
+    LanceDBGCSAccessConfig,
+    LanceDBGCSConnectionConfig,
+    LanceDBGSPUploader,
+)
+from unstructured_ingest.v2.processes.connectors.lancedb.lancedb import (
+    CONNECTOR_TYPE,
+    LanceDBUploaderConfig,
+    LanceDBUploadStager,
+)
+from unstructured_ingest.v2.processes.connectors.lancedb.local import (
+    LanceDBLocalAccessConfig,
+    LanceDBLocalConnectionConfig,
+    LanceDBLocalUploader,
+)
+
+DATABASE_NAME = "database"
+TABLE_NAME = "elements"
+DIMENSION = 384
+NUMBER_EXPECTED_ROWS = 22
+NUMBER_EXPECTED_COLUMNS = 10
+S3_BUCKET = "s3://utic-ingest-test-fixtures/"
+GS_BUCKET = "gs://utic-test-ingest-fixtures-output/"
+AZURE_BUCKET = "az://utic-ingest-test-fixtures-output/"
+REQUIRED_ENV_VARS = {
+    "s3": ("S3_INGEST_TEST_ACCESS_KEY", "S3_INGEST_TEST_SECRET_KEY"),
+    "gcs": ("GCP_INGEST_SERVICE_KEY",),
+    "az": ("AZURE_DEST_CONNECTION_STR",),
+    "local": (),
+}
+
+
+SCHEMA = pa.schema(
+    [
+        pa.field("vector", pa.list_(pa.float16(), DIMENSION)),
+        pa.field("text", pa.string(), nullable=True),
+        pa.field("type", pa.string(), nullable=True),
+        pa.field("element_id", pa.string(), nullable=True),
+        pa.field("metadata-text_as_html", pa.string(), nullable=True),
+        pa.field("metadata-filetype", pa.string(), nullable=True),
+        pa.field("metadata-filename", pa.string(), nullable=True),
+        pa.field("metadata-languages", pa.list_(pa.string()), nullable=True),
+        pa.field("metadata-is_continuation", pa.bool_(), nullable=True),
+        pa.field("metadata-page_number", pa.int32(), nullable=True),
+    ]
+)
+
+
+@pytest_asyncio.fixture
+async def connection_with_uri(request, tmp_path: Path):
+    target = request.param
+    uri = _get_uri(target, local_base_path=tmp_path)
+
+    unset_variables = [env for env in REQUIRED_ENV_VARS[target] if env not in os.environ]
+    if unset_variables:
+        pytest.skip(
+            reason="Following required environment variables were not set: "
+            + f"{', '.join(unset_variables)}"
+        )
+
+    storage_options = {
+        "aws_access_key_id": os.getenv("S3_INGEST_TEST_ACCESS_KEY"),
+        "aws_secret_access_key": os.getenv("S3_INGEST_TEST_SECRET_KEY"),
+        "google_service_account_key": os.getenv("GCP_INGEST_SERVICE_KEY"),
+    }
+    azure_connection_string = os.getenv("AZURE_DEST_CONNECTION_STR")
+    if azure_connection_string:
+        storage_options.update(_parse_azure_connection_string(azure_connection_string))
+
+    storage_options = {key: value for key, value in storage_options.items() if value is not None}
+    connection = await lancedb.connect_async(
+        uri=uri,
+        storage_options=storage_options,
+    )
+    await connection.create_table(name=TABLE_NAME, schema=SCHEMA)
+
+    yield connection, uri
+
+    await connection.drop_database()
+
+
+@pytest.mark.asyncio
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
+@pytest.mark.parametrize("connection_with_uri", ["local", "s3", "gcs", "az"], indirect=True)
+async def test_lancedb_destination(
+    upload_file: Path,
+    connection_with_uri: tuple[AsyncConnection, str],
+    tmp_path: Path,
+) -> None:
+    connection, uri = connection_with_uri
+    file_data = FileData(
+        source_identifiers=SourceIdentifiers(fullpath=upload_file.name, filename=upload_file.name),
+        connector_type=CONNECTOR_TYPE,
+        identifier="mock file data",
+    )
+    stager = LanceDBUploadStager()
+    uploader = _get_uploader(uri)
+    staged_file_path = stager.run(
+        elements_filepath=upload_file,
+        file_data=file_data,
+        output_dir=tmp_path,
+        output_filename=upload_file.name,
+    )
+
+    await uploader.run_async(path=staged_file_path, file_data=file_data)
+
+    table = await connection.open_table(TABLE_NAME)
+    table_df: pd.DataFrame = await table.to_pandas()
+
+    assert len(table_df) == NUMBER_EXPECTED_ROWS
+    assert len(table_df.columns) == NUMBER_EXPECTED_COLUMNS
+
+    assert table_df["element_id"][0] == "2470d8dc42215b3d68413b55bf00fed2"
+    assert table_df["type"][0] == "CompositeElement"
+    assert table_df["metadata-filename"][0] == "DA-1p-with-duplicate-pages.pdf.json"
+    assert table_df["metadata-text_as_html"][0] is None
+
+
+def _get_uri(target: Literal["local", "s3", "gcs", "az"], local_base_path: Path) -> str:
+    if target == "local":
+        return str(local_base_path / DATABASE_NAME)
+    if target == "s3":
+        base_uri = UPath(S3_BUCKET)
+    elif target == "gcs":
+        base_uri = UPath(GS_BUCKET)
+    elif target == "az":
+        base_uri = UPath(AZURE_BUCKET)
+
+    return str(base_uri / "destination" / "lancedb" / str(uuid4()) / DATABASE_NAME)
+
+
+def _get_uploader(
+    uri: str,
+) -> Union[LanceDBAzureUploader, LanceDBAzureUploader, LanceDBS3Uploader, LanceDBGSPUploader]:
+    target = uri.split("://", maxsplit=1)[0] if uri.startswith(("s3", "az", "gs")) else "local"
+    if target == "az":
+        azure_connection_string = os.getenv("AZURE_DEST_CONNECTION_STR")
+        access_config_kwargs = _parse_azure_connection_string(azure_connection_string)
+        return LanceDBAzureUploader(
+            upload_config=LanceDBUploaderConfig(table_name=TABLE_NAME),
+            connection_config=LanceDBAzureConnectionConfig(
+                access_config=LanceDBAzureAccessConfig(**access_config_kwargs),
+                uri=uri,
+            ),
+        )
+
+    elif target == "s3":
+        return LanceDBS3Uploader(
+            upload_config=LanceDBUploaderConfig(table_name=TABLE_NAME),
+            connection_config=LanceDBS3ConnectionConfig(
+                access_config=LanceDBS3AccessConfig(
+                    aws_access_key_id=os.getenv("S3_INGEST_TEST_ACCESS_KEY"),
+                    aws_secret_access_key=os.getenv("S3_INGEST_TEST_SECRET_KEY"),
+                ),
+                uri=uri,
+            ),
+        )
+    elif target == "gs":
+        return LanceDBGSPUploader(
+            upload_config=LanceDBUploaderConfig(table_name=TABLE_NAME),
+            connection_config=LanceDBGCSConnectionConfig(
+                access_config=LanceDBGCSAccessConfig(
+                    google_service_account_key=os.getenv("GCP_INGEST_SERVICE_KEY")
+                ),
+                uri=uri,
+            ),
+        )
+    else:
+        return LanceDBLocalUploader(
+            upload_config=LanceDBUploaderConfig(table_name=TABLE_NAME),
+            connection_config=LanceDBLocalConnectionConfig(
+                access_config=LanceDBLocalAccessConfig(),
+                uri=uri,
+            ),
+        )
+
+
+def _parse_azure_connection_string(
+    connection_str: str,
+) -> dict[Literal["azure_storage_account_name", "azure_storage_account_key"], str]:
+    parameters = dict(keyvalue.split("=", maxsplit=1) for keyvalue in connection_str.split(";"))
+    return {
+        "azure_storage_account_name": parameters.get("AccountName"),
+        "azure_storage_account_key": parameters.get("AccountKey"),
+    }
test/integration/connectors/test_milvus.py
@@ -0,0 +1,141 @@
+import json
+import time
+from pathlib import Path
+
+import docker
+import pytest
+from pymilvus import (
+    CollectionSchema,
+    DataType,
+    FieldSchema,
+    MilvusClient,
+)
+from pymilvus.milvus_client import IndexParams
+
+from test.integration.connectors.utils.constants import DESTINATION_TAG, env_setup_path
+from test.integration.connectors.utils.docker import healthcheck_wait
+from test.integration.connectors.utils.docker_compose import docker_compose_context
+from unstructured_ingest.v2.interfaces import FileData, SourceIdentifiers
+from unstructured_ingest.v2.processes.connectors.milvus import (
+    CONNECTOR_TYPE,
+    MilvusConnectionConfig,
+    MilvusUploader,
+    MilvusUploaderConfig,
+    MilvusUploadStager,
+)
+
+DB_URI = "http://localhost:19530"
+DB_NAME = "test_database"
+COLLECTION_NAME = "test_collection"
+
+
+def get_schema() -> CollectionSchema:
+    id_field = FieldSchema(
+        name="id", dtype=DataType.INT64, description="primary field", is_primary=True, auto_id=True
+    )
+    embeddings_field = FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=384)
+    record_id_field = FieldSchema(name="record_id", dtype=DataType.VARCHAR, max_length=64)
+
+    schema = CollectionSchema(
+        enable_dynamic_field=True,
+        fields=[
+            id_field,
+            record_id_field,
+            embeddings_field,
+        ],
+    )
+
+    return schema
+
+
+def get_index_params() -> IndexParams:
+    index_params = IndexParams()
+    index_params.add_index(field_name="embeddings", index_type="AUTOINDEX", metric_type="COSINE")
+    index_params.add_index(field_name="record_id", index_type="Trie")
+    return index_params
+
+
+@pytest.fixture
+def collection():
+    docker_client = docker.from_env()
+    with docker_compose_context(docker_compose_path=env_setup_path / "milvus"):
+        milvus_container = docker_client.containers.get("milvus-standalone")
+        healthcheck_wait(container=milvus_container)
+        milvus_client = MilvusClient(uri=DB_URI)
+        try:
+            # Create the database
+            database_resp = milvus_client._get_connection().create_database(db_name=DB_NAME)
+            milvus_client.using_database(db_name=DB_NAME)
+
+            print(f"Created database {DB_NAME}: {database_resp}")
+
+            # Create the collection
+            schema = get_schema()
+            index_params = get_index_params()
+            collection_resp = milvus_client.create_collection(
+                collection_name=COLLECTION_NAME, schema=schema, index_params=index_params
+            )
+            print(f"Created collection {COLLECTION_NAME}: {collection_resp}")
+            yield COLLECTION_NAME
+        finally:
+            milvus_client.close()
+
+
+def get_count(client: MilvusClient) -> int:
+    count_field = "count(*)"
+    resp = client.query(collection_name="test_collection", output_fields=[count_field])
+    return resp[0][count_field]
+
+
+def validate_count(
+    client: MilvusClient, expected_count: int, retries: int = 10, interval: int = 1
+) -> None:
+    current_count = get_count(client=client)
+    retry_count = 0
+    while current_count != expected_count and retry_count < retries:
+        time.sleep(interval)
+        current_count = get_count(client=client)
+        retry_count += 1
+    assert current_count == expected_count, (
+        f"Expected count ({expected_count}) doesn't match how "
+        f"much came back from collection: {current_count}"
+    )
+
+
+@pytest.mark.asyncio
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
+async def test_milvus_destination(
+    upload_file: Path,
+    collection: str,
+    tmp_path: Path,
+):
+    file_data = FileData(
+        source_identifiers=SourceIdentifiers(fullpath=upload_file.name, filename=upload_file.name),
+        connector_type=CONNECTOR_TYPE,
+        identifier="mock file data",
+    )
+    stager = MilvusUploadStager()
+    uploader = MilvusUploader(
+        connection_config=MilvusConnectionConfig(uri=DB_URI),
+        upload_config=MilvusUploaderConfig(collection_name=collection, db_name=DB_NAME),
+    )
+    staged_filepath = stager.run(
+        elements_filepath=upload_file,
+        file_data=file_data,
+        output_dir=tmp_path,
+        output_filename=upload_file.name,
+    )
+    uploader.precheck()
+    uploader.run(path=staged_filepath, file_data=file_data)
+
+    # Run validation
+    with staged_filepath.open() as f:
+        staged_elements = json.load(f)
+    expected_count = len(staged_elements)
+    with uploader.get_client() as client:
+        validate_count(client=client, expected_count=expected_count)
+
+    # Rerun and make sure the same documents get updated
+    uploader.run(path=staged_filepath, file_data=file_data)
+    with uploader.get_client() as client:
+        validate_count(client=client, expected_count=expected_count)