unstructured-ingest 0.2.2__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of unstructured-ingest might be problematic.
- test/integration/connectors/elasticsearch/__init__.py +0 -0
- test/integration/connectors/elasticsearch/conftest.py +34 -0
- test/integration/connectors/elasticsearch/test_elasticsearch.py +308 -0
- test/integration/connectors/elasticsearch/test_opensearch.py +302 -0
- test/integration/connectors/sql/test_postgres.py +10 -4
- test/integration/connectors/sql/test_singlestore.py +8 -4
- test/integration/connectors/sql/test_snowflake.py +10 -6
- test/integration/connectors/sql/test_sqlite.py +4 -4
- test/integration/connectors/test_astradb.py +156 -0
- test/integration/connectors/test_azure_cog_search.py +233 -0
- test/integration/connectors/test_delta_table.py +46 -0
- test/integration/connectors/test_kafka.py +150 -16
- test/integration/connectors/test_lancedb.py +209 -0
- test/integration/connectors/test_milvus.py +141 -0
- test/integration/connectors/test_pinecone.py +213 -0
- test/integration/connectors/test_s3.py +23 -0
- test/integration/connectors/utils/docker.py +81 -15
- test/integration/connectors/utils/validation.py +10 -0
- test/integration/connectors/weaviate/__init__.py +0 -0
- test/integration/connectors/weaviate/conftest.py +15 -0
- test/integration/connectors/weaviate/test_local.py +131 -0
- test/unit/v2/__init__.py +0 -0
- test/unit/v2/chunkers/__init__.py +0 -0
- test/unit/v2/chunkers/test_chunkers.py +49 -0
- test/unit/v2/connectors/__init__.py +0 -0
- test/unit/v2/embedders/__init__.py +0 -0
- test/unit/v2/embedders/test_bedrock.py +36 -0
- test/unit/v2/embedders/test_huggingface.py +48 -0
- test/unit/v2/embedders/test_mixedbread.py +37 -0
- test/unit/v2/embedders/test_octoai.py +35 -0
- test/unit/v2/embedders/test_openai.py +35 -0
- test/unit/v2/embedders/test_togetherai.py +37 -0
- test/unit/v2/embedders/test_vertexai.py +37 -0
- test/unit/v2/embedders/test_voyageai.py +38 -0
- test/unit/v2/partitioners/__init__.py +0 -0
- test/unit/v2/partitioners/test_partitioner.py +63 -0
- test/unit/v2/utils/__init__.py +0 -0
- test/unit/v2/utils/data_generator.py +32 -0
- unstructured_ingest/__version__.py +1 -1
- unstructured_ingest/cli/cmds/__init__.py +2 -2
- unstructured_ingest/cli/cmds/{azure_cognitive_search.py → azure_ai_search.py} +9 -9
- unstructured_ingest/connector/{azure_cognitive_search.py → azure_ai_search.py} +9 -9
- unstructured_ingest/pipeline/reformat/embedding.py +1 -1
- unstructured_ingest/runner/writers/__init__.py +2 -2
- unstructured_ingest/runner/writers/azure_ai_search.py +24 -0
- unstructured_ingest/utils/data_prep.py +9 -1
- unstructured_ingest/v2/constants.py +2 -0
- unstructured_ingest/v2/processes/connectors/__init__.py +7 -20
- unstructured_ingest/v2/processes/connectors/airtable.py +2 -2
- unstructured_ingest/v2/processes/connectors/astradb.py +35 -23
- unstructured_ingest/v2/processes/connectors/{azure_cognitive_search.py → azure_ai_search.py} +116 -35
- unstructured_ingest/v2/processes/connectors/confluence.py +2 -2
- unstructured_ingest/v2/processes/connectors/couchbase.py +1 -0
- unstructured_ingest/v2/processes/connectors/delta_table.py +37 -9
- unstructured_ingest/v2/processes/connectors/elasticsearch/__init__.py +19 -0
- unstructured_ingest/v2/processes/connectors/{elasticsearch.py → elasticsearch/elasticsearch.py} +93 -46
- unstructured_ingest/v2/processes/connectors/{opensearch.py → elasticsearch/opensearch.py} +1 -1
- unstructured_ingest/v2/processes/connectors/fsspec/fsspec.py +27 -0
- unstructured_ingest/v2/processes/connectors/google_drive.py +3 -3
- unstructured_ingest/v2/processes/connectors/kafka/__init__.py +6 -2
- unstructured_ingest/v2/processes/connectors/kafka/cloud.py +38 -2
- unstructured_ingest/v2/processes/connectors/kafka/kafka.py +84 -23
- unstructured_ingest/v2/processes/connectors/kafka/local.py +32 -4
- unstructured_ingest/v2/processes/connectors/lancedb/__init__.py +17 -0
- unstructured_ingest/v2/processes/connectors/lancedb/aws.py +43 -0
- unstructured_ingest/v2/processes/connectors/lancedb/azure.py +43 -0
- unstructured_ingest/v2/processes/connectors/lancedb/gcp.py +44 -0
- unstructured_ingest/v2/processes/connectors/lancedb/lancedb.py +161 -0
- unstructured_ingest/v2/processes/connectors/lancedb/local.py +44 -0
- unstructured_ingest/v2/processes/connectors/milvus.py +72 -27
- unstructured_ingest/v2/processes/connectors/onedrive.py +2 -3
- unstructured_ingest/v2/processes/connectors/outlook.py +2 -2
- unstructured_ingest/v2/processes/connectors/pinecone.py +101 -13
- unstructured_ingest/v2/processes/connectors/sharepoint.py +3 -2
- unstructured_ingest/v2/processes/connectors/slack.py +2 -2
- unstructured_ingest/v2/processes/connectors/sql/postgres.py +16 -8
- unstructured_ingest/v2/processes/connectors/sql/sql.py +97 -26
- unstructured_ingest/v2/processes/connectors/weaviate/__init__.py +22 -0
- unstructured_ingest/v2/processes/connectors/weaviate/cloud.py +164 -0
- unstructured_ingest/v2/processes/connectors/weaviate/embedded.py +90 -0
- unstructured_ingest/v2/processes/connectors/weaviate/local.py +73 -0
- unstructured_ingest/v2/processes/connectors/weaviate/weaviate.py +289 -0
- {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/METADATA +20 -19
- {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/RECORD +91 -50
- unstructured_ingest/runner/writers/azure_cognitive_search.py +0 -24
- unstructured_ingest/v2/processes/connectors/weaviate.py +0 -242
- /test/integration/embedders/{togetherai.py → test_togetherai.py} +0 -0
- /test/unit/{test_interfaces_v2.py → v2/test_interfaces.py} +0 -0
- /test/unit/{test_utils_v2.py → v2/test_utils.py} +0 -0
- {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/LICENSE.md +0 -0
- {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/WHEEL +0 -0
- {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/entry_points.txt +0 -0
- {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/top_level.txt +0 -0
test/integration/connectors/test_kafka.py (+150 -16)

@@ -1,11 +1,14 @@
-import
+import json
 import tempfile
+import time
 from pathlib import Path
 
 import pytest
-from confluent_kafka import Producer
+from confluent_kafka import Consumer, KafkaError, KafkaException, Producer
+from confluent_kafka.admin import AdminClient, NewTopic
 
 from test.integration.connectors.utils.constants import (
+    DESTINATION_TAG,
     SOURCE_TAG,
     env_setup_path,
 )
@@ -14,6 +17,8 @@ from test.integration.connectors.utils.validation import (
     ValidationConfigs,
     source_connector_validation,
 )
+from unstructured_ingest.error import DestinationConnectionError, SourceConnectionError
+from unstructured_ingest.v2.interfaces import FileData, SourceIdentifiers
 from unstructured_ingest.v2.processes.connectors.kafka.local import (
     CONNECTOR_TYPE,
     LocalKafkaConnectionConfig,
@@ -21,27 +26,64 @@ from unstructured_ingest.v2.processes.connectors.kafka.local import (
     LocalKafkaDownloaderConfig,
     LocalKafkaIndexer,
     LocalKafkaIndexerConfig,
+    LocalKafkaUploader,
+    LocalKafkaUploaderConfig,
 )
 
 SEED_MESSAGES = 10
 TOPIC = "fake-topic"
 
 
+def get_admin_client() -> AdminClient:
+    conf = {
+        "bootstrap.servers": "localhost:29092",
+    }
+    return AdminClient(conf)
+
+
 @pytest.fixture
-def
-    with docker_compose_context(docker_compose_path=env_setup_path / "kafka"):
-    [12 further removed lines are not rendered in the source diff]
+def docker_compose_ctx():
+    with docker_compose_context(docker_compose_path=env_setup_path / "kafka") as ctx:
+        yield ctx
+
+
+def wait_for_topic(topic: str, retries: int = 10, interval: int = 1):
+    admin_client = get_admin_client()
+    current_topics = admin_client.list_topics().topics
+    attempts = 0
+    while topic not in current_topics and attempts < retries:
+        attempts += 1
+        print(
+            "Attempt {}: Waiting for topic {} to exist in {}".format(
+                attempts, topic, ", ".join(current_topics)
+            )
+        )
+        time.sleep(interval)
+        current_topics = admin_client.list_topics().topics
+    if topic not in current_topics:
+        raise TimeoutError(f"Timeout out waiting for topic {topic} to exist")
+
+
+@pytest.fixture
+def kafka_seed_topic(docker_compose_ctx) -> str:
+    conf = {
+        "bootstrap.servers": "localhost:29092",
+    }
+    producer = Producer(conf)
+    for i in range(SEED_MESSAGES):
+        message = f"This is some text for message {i}"
+        producer.produce(topic=TOPIC, value=message)
+    producer.flush(timeout=10)
+    print(f"kafka topic {TOPIC} seeded with {SEED_MESSAGES} messages")
+    wait_for_topic(topic=TOPIC)
+    return TOPIC
+
+
+@pytest.fixture
+def kafka_upload_topic(docker_compose_ctx) -> str:
+    admin_client = get_admin_client()
+    admin_client.create_topics([NewTopic(TOPIC, 1, 1)])
+    return TOPIC
 
 
 @pytest.mark.asyncio
@@ -58,6 +100,7 @@ async def test_kafka_source_local(kafka_seed_topic: str):
     downloader = LocalKafkaDownloader(
         connection_config=connection_config, download_config=download_config
     )
+    indexer.precheck()
     await source_connector_validation(
         indexer=indexer,
         downloader=downloader,
@@ -65,3 +108,94 @@ async def test_kafka_source_local(kafka_seed_topic: str):
             test_id="kafka", expected_num_files=5, validate_downloaded_files=True
         ),
     )
+
+
+@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG)
+def test_kafka_source_local_precheck_fail_no_cluster():
+    connection_config = LocalKafkaConnectionConfig(bootstrap_server="localhost", port=29092)
+    indexer = LocalKafkaIndexer(
+        connection_config=connection_config,
+        index_config=LocalKafkaIndexerConfig(topic=TOPIC, num_messages_to_consume=5),
+    )
+    with pytest.raises(SourceConnectionError):
+        indexer.precheck()
+
+
+@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG)
+def test_kafka_source_local_precheck_fail_no_topic(kafka_seed_topic: str):
+    connection_config = LocalKafkaConnectionConfig(bootstrap_server="localhost", port=29092)
+    indexer = LocalKafkaIndexer(
+        connection_config=connection_config,
+        index_config=LocalKafkaIndexerConfig(topic="topic", num_messages_to_consume=5),
+    )
+    with pytest.raises(SourceConnectionError):
+        indexer.precheck()
+
+
+def get_all_messages(topic: str, max_empty_messages: int = 5) -> list[dict]:
+    conf = {
+        "bootstrap.servers": "localhost:29092",
+        "group.id": "default_group_id",
+        "enable.auto.commit": "false",
+        "auto.offset.reset": "earliest",
+    }
+    consumer = Consumer(conf)
+    consumer.subscribe([topic])
+    messages = []
+    try:
+        empty_count = 0
+        while empty_count < max_empty_messages:
+            msg = consumer.poll(timeout=1)
+            if msg is None:
+                empty_count += 1
+                continue
+            if msg.error():
+                if msg.error().code() == KafkaError._PARTITION_EOF:
+                    break
+                else:
+                    raise KafkaException(msg.error())
+            try:
+                message = json.loads(msg.value().decode("utf8"))
+                messages.append(message)
+            finally:
+                consumer.commit(asynchronous=False)
+    finally:
+        print("closing consumer")
+        consumer.close()
+    return messages
+
+
+@pytest.mark.asyncio
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
+async def test_kafka_destination_local(upload_file: Path, kafka_upload_topic: str):
+    uploader = LocalKafkaUploader(
+        connection_config=LocalKafkaConnectionConfig(bootstrap_server="localhost", port=29092),
+        upload_config=LocalKafkaUploaderConfig(topic=TOPIC, batch_size=10),
+    )
+    file_data = FileData(
+        source_identifiers=SourceIdentifiers(fullpath=upload_file.name, filename=upload_file.name),
+        connector_type=CONNECTOR_TYPE,
+        identifier="mock file data",
+    )
+    uploader.precheck()
+    if uploader.is_async():
+        await uploader.run_async(path=upload_file, file_data=file_data)
+    else:
+        uploader.run(path=upload_file, file_data=file_data)
+    all_messages = get_all_messages(topic=kafka_upload_topic)
+    with upload_file.open("r") as upload_fs:
+        content_to_upload = json.load(upload_fs)
+    assert len(all_messages) == len(content_to_upload), (
+        f"expected number of messages ({len(content_to_upload)}) doesn't "
+        f"match how many messages read off of kakfa topic {kafka_upload_topic}: {len(all_messages)}"
+    )
+
+
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
+def test_kafka_destination_local_precheck_fail_no_cluster():
+    uploader = LocalKafkaUploader(
+        connection_config=LocalKafkaConnectionConfig(bootstrap_server="localhost", port=29092),
+        upload_config=LocalKafkaUploaderConfig(topic=TOPIC, batch_size=10),
+    )
+    with pytest.raises(DestinationConnectionError):
+        uploader.precheck()
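For orientation, the following is a minimal, hypothetical driver for the new local Kafka uploader exercised by the test above. It is a sketch, not documented usage: it only uses classes and keyword arguments that appear in this diff, and the broker address, topic name, and elements JSON path are assumptions. The test branches on uploader.is_async(); the sketch takes the synchronous path for brevity.

    from pathlib import Path

    from unstructured_ingest.v2.interfaces import FileData, SourceIdentifiers
    from unstructured_ingest.v2.processes.connectors.kafka.local import (
        CONNECTOR_TYPE,
        LocalKafkaConnectionConfig,
        LocalKafkaUploader,
        LocalKafkaUploaderConfig,
    )

    elements_json = Path("elements.json")  # assumed: a partitioned-elements JSON file

    uploader = LocalKafkaUploader(
        connection_config=LocalKafkaConnectionConfig(bootstrap_server="localhost", port=29092),
        upload_config=LocalKafkaUploaderConfig(topic="fake-topic", batch_size=10),
    )
    file_data = FileData(
        source_identifiers=SourceIdentifiers(fullpath=elements_json.name, filename=elements_json.name),
        connector_type=CONNECTOR_TYPE,
        identifier="example-file-data",
    )

    uploader.precheck()  # raises DestinationConnectionError if the broker is unreachable
    uploader.run(path=elements_json, file_data=file_data)  # the test expects one message per element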
test/integration/connectors/test_lancedb.py (+209 -0)

@@ -0,0 +1,209 @@
+import os
+from pathlib import Path
+from typing import Literal, Union
+
+import lancedb
+import pandas as pd
+import pyarrow as pa
+import pytest
+import pytest_asyncio
+from lancedb import AsyncConnection
+from upath import UPath
+
+from test.integration.connectors.utils.constants import DESTINATION_TAG
+from unstructured_ingest.v2.interfaces.file_data import FileData, SourceIdentifiers
+from unstructured_ingest.v2.processes.connectors.lancedb.aws import (
+    LanceDBS3AccessConfig,
+    LanceDBS3ConnectionConfig,
+    LanceDBS3Uploader,
+)
+from unstructured_ingest.v2.processes.connectors.lancedb.azure import (
+    LanceDBAzureAccessConfig,
+    LanceDBAzureConnectionConfig,
+    LanceDBAzureUploader,
+)
+from unstructured_ingest.v2.processes.connectors.lancedb.gcp import (
+    LanceDBGCSAccessConfig,
+    LanceDBGCSConnectionConfig,
+    LanceDBGSPUploader,
+)
+from unstructured_ingest.v2.processes.connectors.lancedb.lancedb import (
+    CONNECTOR_TYPE,
+    LanceDBUploaderConfig,
+    LanceDBUploadStager,
+)
+from unstructured_ingest.v2.processes.connectors.lancedb.local import (
+    LanceDBLocalAccessConfig,
+    LanceDBLocalConnectionConfig,
+    LanceDBLocalUploader,
+)
+
+DATABASE_NAME = "database"
+TABLE_NAME = "elements"
+DIMENSION = 384
+NUMBER_EXPECTED_ROWS = 22
+NUMBER_EXPECTED_COLUMNS = 10
+S3_BUCKET = "s3://utic-ingest-test-fixtures/"
+GS_BUCKET = "gs://utic-test-ingest-fixtures-output/"
+AZURE_BUCKET = "az://utic-ingest-test-fixtures-output/"
+REQUIRED_ENV_VARS = {
+    "s3": ("S3_INGEST_TEST_ACCESS_KEY", "S3_INGEST_TEST_SECRET_KEY"),
+    "gcs": ("GCP_INGEST_SERVICE_KEY",),
+    "az": ("AZURE_DEST_CONNECTION_STR",),
+    "local": (),
+}
+
+
+SCHEMA = pa.schema(
+    [
+        pa.field("vector", pa.list_(pa.float16(), DIMENSION)),
+        pa.field("text", pa.string(), nullable=True),
+        pa.field("type", pa.string(), nullable=True),
+        pa.field("element_id", pa.string(), nullable=True),
+        pa.field("metadata-text_as_html", pa.string(), nullable=True),
+        pa.field("metadata-filetype", pa.string(), nullable=True),
+        pa.field("metadata-filename", pa.string(), nullable=True),
+        pa.field("metadata-languages", pa.list_(pa.string()), nullable=True),
+        pa.field("metadata-is_continuation", pa.bool_(), nullable=True),
+        pa.field("metadata-page_number", pa.int32(), nullable=True),
+    ]
+)
+
+
+@pytest_asyncio.fixture
+async def connection_with_uri(request, tmp_path: Path):
+    target = request.param
+    uri = _get_uri(target, local_base_path=tmp_path)
+
+    unset_variables = [env for env in REQUIRED_ENV_VARS[target] if env not in os.environ]
+    if unset_variables:
+        pytest.skip(
+            reason="Following required environment variables were not set: "
+            + f"{', '.join(unset_variables)}"
+        )
+
+    storage_options = {
+        "aws_access_key_id": os.getenv("S3_INGEST_TEST_ACCESS_KEY"),
+        "aws_secret_access_key": os.getenv("S3_INGEST_TEST_SECRET_KEY"),
+        "google_service_account_key": os.getenv("GCP_INGEST_SERVICE_KEY"),
+    }
+    azure_connection_string = os.getenv("AZURE_DEST_CONNECTION_STR")
+    if azure_connection_string:
+        storage_options.update(_parse_azure_connection_string(azure_connection_string))
+
+    storage_options = {key: value for key, value in storage_options.items() if value is not None}
+    connection = await lancedb.connect_async(
+        uri=uri,
+        storage_options=storage_options,
+    )
+    await connection.create_table(name=TABLE_NAME, schema=SCHEMA)
+
+    yield connection, uri
+
+    await connection.drop_database()
+
+
+@pytest.mark.asyncio
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
+@pytest.mark.parametrize("connection_with_uri", ["local", "s3", "gcs", "az"], indirect=True)
+async def test_lancedb_destination(
+    upload_file: Path,
+    connection_with_uri: tuple[AsyncConnection, str],
+    tmp_path: Path,
+) -> None:
+    connection, uri = connection_with_uri
+    file_data = FileData(
+        source_identifiers=SourceIdentifiers(fullpath=upload_file.name, filename=upload_file.name),
+        connector_type=CONNECTOR_TYPE,
+        identifier="mock file data",
+    )
+    stager = LanceDBUploadStager()
+    uploader = _get_uploader(uri)
+    staged_file_path = stager.run(
+        elements_filepath=upload_file,
+        file_data=file_data,
+        output_dir=tmp_path,
+        output_filename=upload_file.name,
+    )
+
+    await uploader.run_async(path=staged_file_path, file_data=file_data)
+
+    table = await connection.open_table(TABLE_NAME)
+    table_df: pd.DataFrame = await table.to_pandas()
+
+    assert len(table_df) == NUMBER_EXPECTED_ROWS
+    assert len(table_df.columns) == NUMBER_EXPECTED_COLUMNS
+
+    assert table_df["element_id"][0] == "2470d8dc42215b3d68413b55bf00fed2"
+    assert table_df["type"][0] == "CompositeElement"
+    assert table_df["metadata-filename"][0] == "DA-1p-with-duplicate-pages.pdf.json"
+    assert table_df["metadata-text_as_html"][0] is None
+
+
+def _get_uri(target: Literal["local", "s3", "gcs", "az"], local_base_path: Path) -> str:
+    if target == "local":
+        return str(local_base_path / DATABASE_NAME)
+    if target == "s3":
+        base_uri = UPath(S3_BUCKET)
+    elif target == "gcs":
+        base_uri = UPath(GS_BUCKET)
+    elif target == "az":
+        base_uri = UPath(AZURE_BUCKET)
+
+    return str(base_uri / "destination" / "lancedb" / DATABASE_NAME)
+
+
+def _get_uploader(
+    uri: str,
+) -> Union[LanceDBAzureUploader, LanceDBAzureUploader, LanceDBS3Uploader, LanceDBGSPUploader]:
+    target = uri.split("://", maxsplit=1)[0] if uri.startswith(("s3", "az", "gs")) else "local"
+    if target == "az":
+        azure_connection_string = os.getenv("AZURE_DEST_CONNECTION_STR")
+        access_config_kwargs = _parse_azure_connection_string(azure_connection_string)
+        return LanceDBAzureUploader(
+            upload_config=LanceDBUploaderConfig(table_name=TABLE_NAME),
+            connection_config=LanceDBAzureConnectionConfig(
+                access_config=LanceDBAzureAccessConfig(**access_config_kwargs),
+                uri=uri,
+            ),
+        )
+
+    elif target == "s3":
+        return LanceDBS3Uploader(
+            upload_config=LanceDBUploaderConfig(table_name=TABLE_NAME),
+            connection_config=LanceDBS3ConnectionConfig(
+                access_config=LanceDBS3AccessConfig(
+                    aws_access_key_id=os.getenv("S3_INGEST_TEST_ACCESS_KEY"),
+                    aws_secret_access_key=os.getenv("S3_INGEST_TEST_SECRET_KEY"),
+                ),
+                uri=uri,
+            ),
+        )
+    elif target == "gs":
+        return LanceDBGSPUploader(
+            upload_config=LanceDBUploaderConfig(table_name=TABLE_NAME),
+            connection_config=LanceDBGCSConnectionConfig(
+                access_config=LanceDBGCSAccessConfig(
+                    google_service_account_key=os.getenv("GCP_INGEST_SERVICE_KEY")
+                ),
+                uri=uri,
+            ),
+        )
+    else:
+        return LanceDBLocalUploader(
+            upload_config=LanceDBUploaderConfig(table_name=TABLE_NAME),
+            connection_config=LanceDBLocalConnectionConfig(
+                access_config=LanceDBLocalAccessConfig(),
+                uri=uri,
+            ),
+        )
+
+
+def _parse_azure_connection_string(
+    connection_str: str,
+) -> dict[Literal["azure_storage_account_name", "azure_storage_account_key"], str]:
+    parameters = dict(keyvalue.split("=", maxsplit=1) for keyvalue in connection_str.split(";"))
+    return {
+        "azure_storage_account_name": parameters.get("AccountName"),
+        "azure_storage_account_key": parameters.get("AccountKey"),
+    }
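The new LanceDB destination follows a stager-then-uploader pattern. Below is a hypothetical local-filesystem sketch assembled only from the classes and keyword arguments shown in this diff; the database path, table name, input file, and scratch directory are assumptions, and, as in the test, the target table is expected to exist (with a matching schema) before the upload runs.

    import asyncio
    from pathlib import Path

    from unstructured_ingest.v2.interfaces.file_data import FileData, SourceIdentifiers
    from unstructured_ingest.v2.processes.connectors.lancedb.lancedb import (
        CONNECTOR_TYPE,
        LanceDBUploaderConfig,
        LanceDBUploadStager,
    )
    from unstructured_ingest.v2.processes.connectors.lancedb.local import (
        LanceDBLocalAccessConfig,
        LanceDBLocalConnectionConfig,
        LanceDBLocalUploader,
    )

    elements_json = Path("elements.json")  # assumed: embedded elements (the test schema uses 384-dim vectors)
    work_dir = Path("staged")              # assumed scratch directory for the stager output
    work_dir.mkdir(exist_ok=True)

    file_data = FileData(
        source_identifiers=SourceIdentifiers(fullpath=elements_json.name, filename=elements_json.name),
        connector_type=CONNECTOR_TYPE,
        identifier="example-file-data",
    )

    # Stage the raw elements JSON into the shape the uploader expects.
    staged_path = LanceDBUploadStager().run(
        elements_filepath=elements_json,
        file_data=file_data,
        output_dir=work_dir,
        output_filename=elements_json.name,
    )

    uploader = LanceDBLocalUploader(
        upload_config=LanceDBUploaderConfig(table_name="elements"),
        connection_config=LanceDBLocalConnectionConfig(
            access_config=LanceDBLocalAccessConfig(),
            uri="lancedb-db",  # assumed local database directory; the "elements" table must already exist
        ),
    )

    # The test drives the uploader through its async entry point.
    asyncio.run(uploader.run_async(path=staged_path, file_data=file_data))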
test/integration/connectors/test_milvus.py (+141 -0)

@@ -0,0 +1,141 @@
+import json
+import time
+from pathlib import Path
+
+import docker
+import pytest
+from pymilvus import (
+    CollectionSchema,
+    DataType,
+    FieldSchema,
+    MilvusClient,
+)
+from pymilvus.milvus_client import IndexParams
+
+from test.integration.connectors.utils.constants import DESTINATION_TAG, env_setup_path
+from test.integration.connectors.utils.docker import healthcheck_wait
+from test.integration.connectors.utils.docker_compose import docker_compose_context
+from unstructured_ingest.v2.interfaces import FileData, SourceIdentifiers
+from unstructured_ingest.v2.processes.connectors.milvus import (
+    CONNECTOR_TYPE,
+    MilvusConnectionConfig,
+    MilvusUploader,
+    MilvusUploaderConfig,
+    MilvusUploadStager,
+)
+
+DB_URI = "http://localhost:19530"
+DB_NAME = "test_database"
+COLLECTION_NAME = "test_collection"
+
+
+def get_schema() -> CollectionSchema:
+    id_field = FieldSchema(
+        name="id", dtype=DataType.INT64, description="primary field", is_primary=True, auto_id=True
+    )
+    embeddings_field = FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=384)
+    record_id_field = FieldSchema(name="record_id", dtype=DataType.VARCHAR, max_length=64)
+
+    schema = CollectionSchema(
+        enable_dynamic_field=True,
+        fields=[
+            id_field,
+            record_id_field,
+            embeddings_field,
+        ],
+    )
+
+    return schema
+
+
+def get_index_params() -> IndexParams:
+    index_params = IndexParams()
+    index_params.add_index(field_name="embeddings", index_type="AUTOINDEX", metric_type="COSINE")
+    index_params.add_index(field_name="record_id", index_type="Trie")
+    return index_params
+
+
+@pytest.fixture
+def collection():
+    docker_client = docker.from_env()
+    with docker_compose_context(docker_compose_path=env_setup_path / "milvus"):
+        milvus_container = docker_client.containers.get("milvus-standalone")
+        healthcheck_wait(container=milvus_container)
+        milvus_client = MilvusClient(uri=DB_URI)
+        try:
+            # Create the database
+            database_resp = milvus_client._get_connection().create_database(db_name=DB_NAME)
+            milvus_client.using_database(db_name=DB_NAME)
+
+            print(f"Created database {DB_NAME}: {database_resp}")
+
+            # Create the collection
+            schema = get_schema()
+            index_params = get_index_params()
+            collection_resp = milvus_client.create_collection(
+                collection_name=COLLECTION_NAME, schema=schema, index_params=index_params
+            )
+            print(f"Created collection {COLLECTION_NAME}: {collection_resp}")
+            yield COLLECTION_NAME
+        finally:
+            milvus_client.close()
+
+
+def get_count(client: MilvusClient) -> int:
+    count_field = "count(*)"
+    resp = client.query(collection_name="test_collection", output_fields=[count_field])
+    return resp[0][count_field]
+
+
+def validate_count(
+    client: MilvusClient, expected_count: int, retries: int = 10, interval: int = 1
+) -> None:
+    current_count = get_count(client=client)
+    retry_count = 0
+    while current_count != expected_count and retry_count < retries:
+        time.sleep(interval)
+        current_count = get_count(client=client)
+        retry_count += 1
+    assert current_count == expected_count, (
+        f"Expected count ({expected_count}) doesn't match how "
+        f"much came back from collection: {current_count}"
+    )
+
+
+@pytest.mark.asyncio
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
+async def test_milvus_destination(
+    upload_file: Path,
+    collection: str,
+    tmp_path: Path,
+):
+    file_data = FileData(
+        source_identifiers=SourceIdentifiers(fullpath=upload_file.name, filename=upload_file.name),
+        connector_type=CONNECTOR_TYPE,
+        identifier="mock file data",
+    )
+    stager = MilvusUploadStager()
+    uploader = MilvusUploader(
+        connection_config=MilvusConnectionConfig(uri=DB_URI),
+        upload_config=MilvusUploaderConfig(collection_name=collection, db_name=DB_NAME),
+    )
+    staged_filepath = stager.run(
+        elements_filepath=upload_file,
+        file_data=file_data,
+        output_dir=tmp_path,
+        output_filename=upload_file.name,
+    )
+    uploader.precheck()
+    uploader.run(path=staged_filepath, file_data=file_data)
+
+    # Run validation
+    with staged_filepath.open() as f:
+        staged_elements = json.load(f)
+    expected_count = len(staged_elements)
+    with uploader.get_client() as client:
+        validate_count(client=client, expected_count=expected_count)
+
+    # Rerun and make sure the same documents get updated
+    uploader.run(path=staged_filepath, file_data=file_data)
+    with uploader.get_client() as client:
+        validate_count(client=client, expected_count=expected_count)