unstructured-ingest 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (55)
  1. test/integration/connectors/elasticsearch/__init__.py +0 -0
  2. test/integration/connectors/elasticsearch/conftest.py +34 -0
  3. test/integration/connectors/elasticsearch/test_elasticsearch.py +308 -0
  4. test/integration/connectors/elasticsearch/test_opensearch.py +302 -0
  5. test/integration/connectors/sql/test_postgres.py +10 -4
  6. test/integration/connectors/sql/test_singlestore.py +8 -4
  7. test/integration/connectors/sql/test_snowflake.py +10 -6
  8. test/integration/connectors/sql/test_sqlite.py +4 -4
  9. test/integration/connectors/test_astradb.py +50 -3
  10. test/integration/connectors/test_delta_table.py +46 -0
  11. test/integration/connectors/test_kafka.py +40 -6
  12. test/integration/connectors/test_lancedb.py +210 -0
  13. test/integration/connectors/test_milvus.py +141 -0
  14. test/integration/connectors/test_mongodb.py +332 -0
  15. test/integration/connectors/test_pinecone.py +53 -1
  16. test/integration/connectors/utils/docker.py +81 -15
  17. test/integration/connectors/utils/validation.py +10 -0
  18. test/integration/connectors/weaviate/__init__.py +0 -0
  19. test/integration/connectors/weaviate/conftest.py +15 -0
  20. test/integration/connectors/weaviate/test_local.py +131 -0
  21. unstructured_ingest/__version__.py +1 -1
  22. unstructured_ingest/pipeline/reformat/embedding.py +1 -1
  23. unstructured_ingest/utils/data_prep.py +9 -1
  24. unstructured_ingest/v2/processes/connectors/__init__.py +3 -16
  25. unstructured_ingest/v2/processes/connectors/astradb.py +2 -2
  26. unstructured_ingest/v2/processes/connectors/azure_ai_search.py +4 -0
  27. unstructured_ingest/v2/processes/connectors/delta_table.py +20 -4
  28. unstructured_ingest/v2/processes/connectors/elasticsearch/__init__.py +19 -0
  29. unstructured_ingest/v2/processes/connectors/{elasticsearch.py → elasticsearch/elasticsearch.py} +92 -46
  30. unstructured_ingest/v2/processes/connectors/{opensearch.py → elasticsearch/opensearch.py} +1 -1
  31. unstructured_ingest/v2/processes/connectors/google_drive.py +1 -1
  32. unstructured_ingest/v2/processes/connectors/kafka/kafka.py +6 -0
  33. unstructured_ingest/v2/processes/connectors/lancedb/__init__.py +17 -0
  34. unstructured_ingest/v2/processes/connectors/lancedb/aws.py +43 -0
  35. unstructured_ingest/v2/processes/connectors/lancedb/azure.py +43 -0
  36. unstructured_ingest/v2/processes/connectors/lancedb/gcp.py +44 -0
  37. unstructured_ingest/v2/processes/connectors/lancedb/lancedb.py +161 -0
  38. unstructured_ingest/v2/processes/connectors/lancedb/local.py +44 -0
  39. unstructured_ingest/v2/processes/connectors/milvus.py +72 -27
  40. unstructured_ingest/v2/processes/connectors/mongodb.py +122 -111
  41. unstructured_ingest/v2/processes/connectors/pinecone.py +24 -7
  42. unstructured_ingest/v2/processes/connectors/sql/sql.py +97 -26
  43. unstructured_ingest/v2/processes/connectors/weaviate/__init__.py +25 -0
  44. unstructured_ingest/v2/processes/connectors/weaviate/cloud.py +164 -0
  45. unstructured_ingest/v2/processes/connectors/weaviate/embedded.py +90 -0
  46. unstructured_ingest/v2/processes/connectors/weaviate/local.py +73 -0
  47. unstructured_ingest/v2/processes/connectors/weaviate/weaviate.py +299 -0
  48. {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/METADATA +19 -19
  49. {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/RECORD +54 -33
  50. unstructured_ingest/v2/processes/connectors/weaviate.py +0 -242
  51. /test/integration/connectors/{test_azure_cog_search.py → test_azure_ai_search.py} +0 -0
  52. {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/LICENSE.md +0 -0
  53. {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/WHEEL +0 -0
  54. {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/entry_points.txt +0 -0
  55. {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/top_level.txt +0 -0
test/integration/connectors/sql/test_snowflake.py

@@ -86,7 +86,7 @@ async def test_snowflake_source():
         image="localstack/snowflake",
         environment={"LOCALSTACK_AUTH_TOKEN": token, "EXTRA_CORS_ALLOWED_ORIGINS": "*"},
         ports={4566: 4566, 443: 443},
-        healthcheck_timeout=30,
+        healthcheck_retries=30,
     ):
         seed_data()
         with tempfile.TemporaryDirectory() as tmpdir:
@@ -156,7 +156,7 @@ async def test_snowflake_destination(upload_file: Path):
         image="localstack/snowflake",
         environment={"LOCALSTACK_AUTH_TOKEN": token, "EXTRA_CORS_ALLOWED_ORIGINS": "*"},
         ports={4566: 4566, 443: 443},
-        healthcheck_timeout=30,
+        healthcheck_retries=30,
     ):
         init_db_destination()
         with tempfile.TemporaryDirectory() as tmpdir:
@@ -192,10 +192,8 @@ async def test_snowflake_destination(upload_file: Path):
                    host=connect_params["host"],
                )
            )
-            if uploader.is_async():
-                await uploader.run_async(path=staged_path, file_data=mock_file_data)
-            else:
-                uploader.run(path=staged_path, file_data=mock_file_data)
+
+            uploader.run(path=staged_path, file_data=mock_file_data)

            staged_df = pd.read_json(staged_path, orient="records", lines=True)
            expected_num_elements = len(staged_df)
@@ -203,3 +201,9 @@ async def test_snowflake_destination(upload_file: Path):
                connect_params=connect_params,
                expected_num_elements=expected_num_elements,
            )
+
+            uploader.run(path=staged_path, file_data=mock_file_data)
+            validate_destination(
+                connect_params=connect_params,
+                expected_num_elements=expected_num_elements,
+            )
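
The functional change in these hunks is the switch from healthcheck_timeout to healthcheck_retries on the container context helper used by the SQL tests. That helper lives in test/integration/connectors/utils/docker.py (changed +81 -15 in this release) but its body is not part of this diff, so the snippet below is only an illustrative sketch of what a retry-count based health check can look like; the function name, signature, and attribute access are assumptions, not code from the wheel.

# Illustrative sketch only -- not taken from the package. It shows how a health
# check driven by a retry count (healthcheck_retries) differs from one driven by
# a wall-clock timeout (healthcheck_timeout).
import time

from docker.models.containers import Container


def wait_until_healthy(container: Container, retries: int = 30, interval: int = 1) -> None:
    """Poll the container's reported health a fixed number of times before giving up."""
    for _ in range(retries):
        container.reload()  # refresh container.attrs from the Docker daemon
        status = container.attrs.get("State", {}).get("Health", {}).get("Status")
        if status == "healthy":
            return
        time.sleep(interval)
    raise TimeoutError(f"container {container.name} not healthy after {retries} checks")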
test/integration/connectors/sql/test_sqlite.py

@@ -138,10 +138,10 @@ async def test_sqlite_destination(upload_file: Path):
        uploader = SQLiteUploader(
            connection_config=SQLiteConnectionConfig(database_path=db_path)
        )
-        if uploader.is_async():
-            await uploader.run_async(path=staged_path, file_data=mock_file_data)
-        else:
-            uploader.run(path=staged_path, file_data=mock_file_data)
+        uploader.run(path=staged_path, file_data=mock_file_data)

        staged_df = pd.read_json(staged_path, orient="records", lines=True)
        validate_destination(db_path=db_path, expected_num_elements=len(staged_df))
+
+        uploader.run(path=staged_path, file_data=mock_file_data)
+        validate_destination(db_path=db_path, expected_num_elements=len(staged_df))
test/integration/connectors/test_astradb.py

@@ -8,20 +8,67 @@ import pytest
 from astrapy import Collection
 from astrapy import DataAPIClient as AstraDBClient

-from test.integration.connectors.utils.constants import (
-    DESTINATION_TAG,
-)
+from test.integration.connectors.utils.constants import DESTINATION_TAG, SOURCE_TAG
 from test.integration.utils import requires_env
 from unstructured_ingest.v2.interfaces import FileData, SourceIdentifiers
 from unstructured_ingest.v2.processes.connectors.astradb import (
     CONNECTOR_TYPE,
     AstraDBAccessConfig,
     AstraDBConnectionConfig,
+    AstraDBIndexer,
+    AstraDBIndexerConfig,
     AstraDBUploader,
     AstraDBUploaderConfig,
     AstraDBUploadStager,
+    DestinationConnectionError,
+    SourceConnectionError,
 )

+EXISTENT_COLLECTION_NAME = "ingest_test_src"
+NONEXISTENT_COLLECTION_NAME = "nonexistant"
+
+
+@pytest.fixture
+def connection_config() -> AstraDBConnectionConfig:
+    return AstraDBConnectionConfig(
+        access_config=AstraDBAccessConfig(
+            token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
+            api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
+        )
+    )
+
+
+@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG, DESTINATION_TAG)
+@requires_env("ASTRA_DB_APPLICATION_TOKEN", "ASTRA_DB_API_ENDPOINT")
+def test_precheck_succeeds(connection_config: AstraDBConnectionConfig):
+    indexer = AstraDBIndexer(
+        connection_config=connection_config,
+        index_config=AstraDBIndexerConfig(collection_name=EXISTENT_COLLECTION_NAME),
+    )
+    uploader = AstraDBUploader(
+        connection_config=connection_config,
+        upload_config=AstraDBUploaderConfig(collection_name=EXISTENT_COLLECTION_NAME),
+    )
+    indexer.precheck()
+    uploader.precheck()
+
+
+@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG, DESTINATION_TAG)
+@requires_env("ASTRA_DB_APPLICATION_TOKEN", "ASTRA_DB_API_ENDPOINT")
+def test_precheck_fails(connection_config: AstraDBConnectionConfig):
+    indexer = AstraDBIndexer(
+        connection_config=connection_config,
+        index_config=AstraDBIndexerConfig(collection_name=NONEXISTENT_COLLECTION_NAME),
+    )
+    uploader = AstraDBUploader(
+        connection_config=connection_config,
+        upload_config=AstraDBUploaderConfig(collection_name=NONEXISTENT_COLLECTION_NAME),
+    )
+    with pytest.raises(expected_exception=SourceConnectionError):
+        indexer.precheck()
+    with pytest.raises(expected_exception=DestinationConnectionError):
+        uploader.precheck()
+

 @dataclass(frozen=True)
 class EnvData:
test/integration/connectors/test_delta_table.py

@@ -136,3 +136,49 @@ async def test_delta_table_destination_s3(upload_file: Path, temp_dir: Path):
         secret=aws_credentials["AWS_SECRET_ACCESS_KEY"],
     )
     s3fs.rm(path=destination_path, recursive=True)
+
+
+@pytest.mark.asyncio
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
+@requires_env("S3_INGEST_TEST_ACCESS_KEY", "S3_INGEST_TEST_SECRET_KEY")
+async def test_delta_table_destination_s3_bad_creds(upload_file: Path, temp_dir: Path):
+    aws_credentials = {
+        "AWS_ACCESS_KEY_ID": "bad key",
+        "AWS_SECRET_ACCESS_KEY": "bad secret",
+        "AWS_REGION": "us-east-2",
+    }
+    s3_bucket = "s3://utic-platform-test-destination"
+    destination_path = f"{s3_bucket}/destination/test"
+    connection_config = DeltaTableConnectionConfig(
+        access_config=DeltaTableAccessConfig(
+            aws_access_key_id=aws_credentials["AWS_ACCESS_KEY_ID"],
+            aws_secret_access_key=aws_credentials["AWS_SECRET_ACCESS_KEY"],
+        ),
+        aws_region=aws_credentials["AWS_REGION"],
+        table_uri=destination_path,
+    )
+    stager_config = DeltaTableUploadStagerConfig()
+    stager = DeltaTableUploadStager(upload_stager_config=stager_config)
+    new_upload_file = stager.run(
+        elements_filepath=upload_file,
+        output_dir=temp_dir,
+        output_filename=upload_file.name,
+    )
+
+    upload_config = DeltaTableUploaderConfig()
+    uploader = DeltaTableUploader(connection_config=connection_config, upload_config=upload_config)
+    file_data = FileData(
+        source_identifiers=SourceIdentifiers(
+            fullpath=upload_file.name, filename=new_upload_file.name
+        ),
+        connector_type=CONNECTOR_TYPE,
+        identifier="mock file data",
+    )
+
+    with pytest.raises(Exception) as excinfo:
+        if uploader.is_async():
+            await uploader.run_async(path=new_upload_file, file_data=file_data)
+        else:
+            uploader.run(path=new_upload_file, file_data=file_data)
+
+    assert "403 Forbidden" in str(excinfo.value), f"Exception message did not match: {str(excinfo)}"
test/integration/connectors/test_kafka.py

@@ -1,5 +1,6 @@
 import json
 import tempfile
+import time
 from pathlib import Path

 import pytest
@@ -33,12 +34,36 @@ SEED_MESSAGES = 10
 TOPIC = "fake-topic"


+def get_admin_client() -> AdminClient:
+    conf = {
+        "bootstrap.servers": "localhost:29092",
+    }
+    return AdminClient(conf)
+
+
 @pytest.fixture
 def docker_compose_ctx():
     with docker_compose_context(docker_compose_path=env_setup_path / "kafka") as ctx:
         yield ctx


+def wait_for_topic(topic: str, retries: int = 10, interval: int = 1):
+    admin_client = get_admin_client()
+    current_topics = admin_client.list_topics().topics
+    attempts = 0
+    while topic not in current_topics and attempts < retries:
+        attempts += 1
+        print(
+            "Attempt {}: Waiting for topic {} to exist in {}".format(
+                attempts, topic, ", ".join(current_topics)
+            )
+        )
+        time.sleep(interval)
+        current_topics = admin_client.list_topics().topics
+    if topic not in current_topics:
+        raise TimeoutError(f"Timeout out waiting for topic {topic} to exist")
+
+

 @pytest.fixture
 def kafka_seed_topic(docker_compose_ctx) -> str:
@@ -50,15 +75,13 @@ def kafka_seed_topic(docker_compose_ctx) -> str:
         producer.produce(topic=TOPIC, value=message)
     producer.flush(timeout=10)
     print(f"kafka topic {TOPIC} seeded with {SEED_MESSAGES} messages")
+    wait_for_topic(topic=TOPIC)
     return TOPIC


 @pytest.fixture
 def kafka_upload_topic(docker_compose_ctx) -> str:
-    conf = {
-        "bootstrap.servers": "localhost:29092",
-    }
-    admin_client = AdminClient(conf)
+    admin_client = get_admin_client()
     admin_client.create_topics([NewTopic(TOPIC, 1, 1)])
     return TOPIC

@@ -88,7 +111,7 @@ async def test_kafka_source_local(kafka_seed_topic: str):


 @pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG)
-def test_kafak_source_local_precheck_fail():
+def test_kafka_source_local_precheck_fail_no_cluster():
     connection_config = LocalKafkaConnectionConfig(bootstrap_server="localhost", port=29092)
     indexer = LocalKafkaIndexer(
         connection_config=connection_config,
@@ -98,6 +121,17 @@ def test_kafak_source_local_precheck_fail():
         indexer.precheck()


+@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG)
+def test_kafka_source_local_precheck_fail_no_topic(kafka_seed_topic: str):
+    connection_config = LocalKafkaConnectionConfig(bootstrap_server="localhost", port=29092)
+    indexer = LocalKafkaIndexer(
+        connection_config=connection_config,
+        index_config=LocalKafkaIndexerConfig(topic="topic", num_messages_to_consume=5),
+    )
+    with pytest.raises(SourceConnectionError):
+        indexer.precheck()
+
+
 def get_all_messages(topic: str, max_empty_messages: int = 5) -> list[dict]:
     conf = {
         "bootstrap.servers": "localhost:29092",
@@ -158,7 +192,7 @@ async def test_kafka_destination_local(upload_file: Path, kafka_upload_topic: st


 @pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
-def test_kafak_destination_local_precheck_fail():
+def test_kafka_destination_local_precheck_fail_no_cluster():
     uploader = LocalKafkaUploader(
         connection_config=LocalKafkaConnectionConfig(bootstrap_server="localhost", port=29092),
         upload_config=LocalKafkaUploaderConfig(topic=TOPIC, batch_size=10),
test/integration/connectors/test_lancedb.py

@@ -0,0 +1,210 @@
+import os
+from pathlib import Path
+from typing import Literal, Union
+from uuid import uuid4
+
+import lancedb
+import pandas as pd
+import pyarrow as pa
+import pytest
+import pytest_asyncio
+from lancedb import AsyncConnection
+from upath import UPath
+
+from test.integration.connectors.utils.constants import DESTINATION_TAG
+from unstructured_ingest.v2.interfaces.file_data import FileData, SourceIdentifiers
+from unstructured_ingest.v2.processes.connectors.lancedb.aws import (
+    LanceDBS3AccessConfig,
+    LanceDBS3ConnectionConfig,
+    LanceDBS3Uploader,
+)
+from unstructured_ingest.v2.processes.connectors.lancedb.azure import (
+    LanceDBAzureAccessConfig,
+    LanceDBAzureConnectionConfig,
+    LanceDBAzureUploader,
+)
+from unstructured_ingest.v2.processes.connectors.lancedb.gcp import (
+    LanceDBGCSAccessConfig,
+    LanceDBGCSConnectionConfig,
+    LanceDBGSPUploader,
+)
+from unstructured_ingest.v2.processes.connectors.lancedb.lancedb import (
+    CONNECTOR_TYPE,
+    LanceDBUploaderConfig,
+    LanceDBUploadStager,
+)
+from unstructured_ingest.v2.processes.connectors.lancedb.local import (
+    LanceDBLocalAccessConfig,
+    LanceDBLocalConnectionConfig,
+    LanceDBLocalUploader,
+)
+
+DATABASE_NAME = "database"
+TABLE_NAME = "elements"
+DIMENSION = 384
+NUMBER_EXPECTED_ROWS = 22
+NUMBER_EXPECTED_COLUMNS = 10
+S3_BUCKET = "s3://utic-ingest-test-fixtures/"
+GS_BUCKET = "gs://utic-test-ingest-fixtures-output/"
+AZURE_BUCKET = "az://utic-ingest-test-fixtures-output/"
+REQUIRED_ENV_VARS = {
+    "s3": ("S3_INGEST_TEST_ACCESS_KEY", "S3_INGEST_TEST_SECRET_KEY"),
+    "gcs": ("GCP_INGEST_SERVICE_KEY",),
+    "az": ("AZURE_DEST_CONNECTION_STR",),
+    "local": (),
+}
+
+
+SCHEMA = pa.schema(
+    [
+        pa.field("vector", pa.list_(pa.float16(), DIMENSION)),
+        pa.field("text", pa.string(), nullable=True),
+        pa.field("type", pa.string(), nullable=True),
+        pa.field("element_id", pa.string(), nullable=True),
+        pa.field("metadata-text_as_html", pa.string(), nullable=True),
+        pa.field("metadata-filetype", pa.string(), nullable=True),
+        pa.field("metadata-filename", pa.string(), nullable=True),
+        pa.field("metadata-languages", pa.list_(pa.string()), nullable=True),
+        pa.field("metadata-is_continuation", pa.bool_(), nullable=True),
+        pa.field("metadata-page_number", pa.int32(), nullable=True),
+    ]
+)
+
+
+@pytest_asyncio.fixture
+async def connection_with_uri(request, tmp_path: Path):
+    target = request.param
+    uri = _get_uri(target, local_base_path=tmp_path)
+
+    unset_variables = [env for env in REQUIRED_ENV_VARS[target] if env not in os.environ]
+    if unset_variables:
+        pytest.skip(
+            reason="Following required environment variables were not set: "
+            + f"{', '.join(unset_variables)}"
+        )
+
+    storage_options = {
+        "aws_access_key_id": os.getenv("S3_INGEST_TEST_ACCESS_KEY"),
+        "aws_secret_access_key": os.getenv("S3_INGEST_TEST_SECRET_KEY"),
+        "google_service_account_key": os.getenv("GCP_INGEST_SERVICE_KEY"),
+    }
+    azure_connection_string = os.getenv("AZURE_DEST_CONNECTION_STR")
+    if azure_connection_string:
+        storage_options.update(_parse_azure_connection_string(azure_connection_string))
+
+    storage_options = {key: value for key, value in storage_options.items() if value is not None}
+    connection = await lancedb.connect_async(
+        uri=uri,
+        storage_options=storage_options,
+    )
+    await connection.create_table(name=TABLE_NAME, schema=SCHEMA)
+
+    yield connection, uri
+
+    await connection.drop_database()
+
+
+@pytest.mark.asyncio
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
+@pytest.mark.parametrize("connection_with_uri", ["local", "s3", "gcs", "az"], indirect=True)
+async def test_lancedb_destination(
+    upload_file: Path,
+    connection_with_uri: tuple[AsyncConnection, str],
+    tmp_path: Path,
+) -> None:
+    connection, uri = connection_with_uri
+    file_data = FileData(
+        source_identifiers=SourceIdentifiers(fullpath=upload_file.name, filename=upload_file.name),
+        connector_type=CONNECTOR_TYPE,
+        identifier="mock file data",
+    )
+    stager = LanceDBUploadStager()
+    uploader = _get_uploader(uri)
+    staged_file_path = stager.run(
+        elements_filepath=upload_file,
+        file_data=file_data,
+        output_dir=tmp_path,
+        output_filename=upload_file.name,
+    )
+
+    await uploader.run_async(path=staged_file_path, file_data=file_data)
+
+    table = await connection.open_table(TABLE_NAME)
+    table_df: pd.DataFrame = await table.to_pandas()
+
+    assert len(table_df) == NUMBER_EXPECTED_ROWS
+    assert len(table_df.columns) == NUMBER_EXPECTED_COLUMNS
+
+    assert table_df["element_id"][0] == "2470d8dc42215b3d68413b55bf00fed2"
+    assert table_df["type"][0] == "CompositeElement"
+    assert table_df["metadata-filename"][0] == "DA-1p-with-duplicate-pages.pdf.json"
+    assert table_df["metadata-text_as_html"][0] is None
+
+
+def _get_uri(target: Literal["local", "s3", "gcs", "az"], local_base_path: Path) -> str:
+    if target == "local":
+        return str(local_base_path / DATABASE_NAME)
+    if target == "s3":
+        base_uri = UPath(S3_BUCKET)
+    elif target == "gcs":
+        base_uri = UPath(GS_BUCKET)
+    elif target == "az":
+        base_uri = UPath(AZURE_BUCKET)
+
+    return str(base_uri / "destination" / "lancedb" / str(uuid4()) / DATABASE_NAME)
+
+
+def _get_uploader(
+    uri: str,
+) -> Union[LanceDBAzureUploader, LanceDBAzureUploader, LanceDBS3Uploader, LanceDBGSPUploader]:
+    target = uri.split("://", maxsplit=1)[0] if uri.startswith(("s3", "az", "gs")) else "local"
+    if target == "az":
+        azure_connection_string = os.getenv("AZURE_DEST_CONNECTION_STR")
+        access_config_kwargs = _parse_azure_connection_string(azure_connection_string)
+        return LanceDBAzureUploader(
+            upload_config=LanceDBUploaderConfig(table_name=TABLE_NAME),
+            connection_config=LanceDBAzureConnectionConfig(
+                access_config=LanceDBAzureAccessConfig(**access_config_kwargs),
+                uri=uri,
+            ),
+        )
+
+    elif target == "s3":
+        return LanceDBS3Uploader(
+            upload_config=LanceDBUploaderConfig(table_name=TABLE_NAME),
+            connection_config=LanceDBS3ConnectionConfig(
+                access_config=LanceDBS3AccessConfig(
+                    aws_access_key_id=os.getenv("S3_INGEST_TEST_ACCESS_KEY"),
+                    aws_secret_access_key=os.getenv("S3_INGEST_TEST_SECRET_KEY"),
+                ),
+                uri=uri,
+            ),
+        )
+    elif target == "gs":
+        return LanceDBGSPUploader(
+            upload_config=LanceDBUploaderConfig(table_name=TABLE_NAME),
+            connection_config=LanceDBGCSConnectionConfig(
+                access_config=LanceDBGCSAccessConfig(
+                    google_service_account_key=os.getenv("GCP_INGEST_SERVICE_KEY")
+                ),
+                uri=uri,
+            ),
+        )
+    else:
+        return LanceDBLocalUploader(
+            upload_config=LanceDBUploaderConfig(table_name=TABLE_NAME),
+            connection_config=LanceDBLocalConnectionConfig(
+                access_config=LanceDBLocalAccessConfig(),
+                uri=uri,
+            ),
+        )
+
+
+def _parse_azure_connection_string(
+    connection_str: str,
+) -> dict[Literal["azure_storage_account_name", "azure_storage_account_key"], str]:
+    parameters = dict(keyvalue.split("=", maxsplit=1) for keyvalue in connection_str.split(";"))
+    return {
+        "azure_storage_account_name": parameters.get("AccountName"),
+        "azure_storage_account_key": parameters.get("AccountKey"),
+    }
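
For orientation (this is not part of the diff), the classes exercised by the new test above can also be combined outside pytest. The following is a minimal sketch of writing a staged elements file into a local LanceDB table, using only the calls shown in the test; the input path, output directory, and the assumption that the target table already exists with a compatible schema (the test creates it up front) are illustrative, not part of the release.

# Minimal usage sketch based on the test above -- paths and table setup are assumed.
import asyncio
from pathlib import Path

from unstructured_ingest.v2.interfaces.file_data import FileData, SourceIdentifiers
from unstructured_ingest.v2.processes.connectors.lancedb.lancedb import (
    CONNECTOR_TYPE,
    LanceDBUploaderConfig,
    LanceDBUploadStager,
)
from unstructured_ingest.v2.processes.connectors.lancedb.local import (
    LanceDBLocalAccessConfig,
    LanceDBLocalConnectionConfig,
    LanceDBLocalUploader,
)

elements_file = Path("elements.json")  # assumed: a partitioned/embedded elements file
output_dir = Path("staged")
output_dir.mkdir(exist_ok=True)

file_data = FileData(
    source_identifiers=SourceIdentifiers(fullpath=elements_file.name, filename=elements_file.name),
    connector_type=CONNECTOR_TYPE,
    identifier="example file data",
)

# Stage the raw elements into the shape the uploader expects.
staged_path = LanceDBUploadStager().run(
    elements_filepath=elements_file,
    file_data=file_data,
    output_dir=output_dir,
    output_filename=elements_file.name,
)

# Write into a pre-existing "elements" table in a local LanceDB database.
uploader = LanceDBLocalUploader(
    upload_config=LanceDBUploaderConfig(table_name="elements"),
    connection_config=LanceDBLocalConnectionConfig(
        access_config=LanceDBLocalAccessConfig(),
        uri="./lancedb",  # assumed local database location
    ),
)
asyncio.run(uploader.run_async(path=staged_path, file_data=file_data))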
test/integration/connectors/test_milvus.py

@@ -0,0 +1,141 @@
+import json
+import time
+from pathlib import Path
+
+import docker
+import pytest
+from pymilvus import (
+    CollectionSchema,
+    DataType,
+    FieldSchema,
+    MilvusClient,
+)
+from pymilvus.milvus_client import IndexParams
+
+from test.integration.connectors.utils.constants import DESTINATION_TAG, env_setup_path
+from test.integration.connectors.utils.docker import healthcheck_wait
+from test.integration.connectors.utils.docker_compose import docker_compose_context
+from unstructured_ingest.v2.interfaces import FileData, SourceIdentifiers
+from unstructured_ingest.v2.processes.connectors.milvus import (
+    CONNECTOR_TYPE,
+    MilvusConnectionConfig,
+    MilvusUploader,
+    MilvusUploaderConfig,
+    MilvusUploadStager,
+)
+
+DB_URI = "http://localhost:19530"
+DB_NAME = "test_database"
+COLLECTION_NAME = "test_collection"
+
+
+def get_schema() -> CollectionSchema:
+    id_field = FieldSchema(
+        name="id", dtype=DataType.INT64, description="primary field", is_primary=True, auto_id=True
+    )
+    embeddings_field = FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=384)
+    record_id_field = FieldSchema(name="record_id", dtype=DataType.VARCHAR, max_length=64)
+
+    schema = CollectionSchema(
+        enable_dynamic_field=True,
+        fields=[
+            id_field,
+            record_id_field,
+            embeddings_field,
+        ],
+    )
+
+    return schema
+
+
+def get_index_params() -> IndexParams:
+    index_params = IndexParams()
+    index_params.add_index(field_name="embeddings", index_type="AUTOINDEX", metric_type="COSINE")
+    index_params.add_index(field_name="record_id", index_type="Trie")
+    return index_params
+
+
+@pytest.fixture
+def collection():
+    docker_client = docker.from_env()
+    with docker_compose_context(docker_compose_path=env_setup_path / "milvus"):
+        milvus_container = docker_client.containers.get("milvus-standalone")
+        healthcheck_wait(container=milvus_container)
+        milvus_client = MilvusClient(uri=DB_URI)
+        try:
+            # Create the database
+            database_resp = milvus_client._get_connection().create_database(db_name=DB_NAME)
+            milvus_client.using_database(db_name=DB_NAME)
+
+            print(f"Created database {DB_NAME}: {database_resp}")
+
+            # Create the collection
+            schema = get_schema()
+            index_params = get_index_params()
+            collection_resp = milvus_client.create_collection(
+                collection_name=COLLECTION_NAME, schema=schema, index_params=index_params
+            )
+            print(f"Created collection {COLLECTION_NAME}: {collection_resp}")
+            yield COLLECTION_NAME
+        finally:
+            milvus_client.close()
+
+
+def get_count(client: MilvusClient) -> int:
+    count_field = "count(*)"
+    resp = client.query(collection_name="test_collection", output_fields=[count_field])
+    return resp[0][count_field]
+
+
+def validate_count(
+    client: MilvusClient, expected_count: int, retries: int = 10, interval: int = 1
+) -> None:
+    current_count = get_count(client=client)
+    retry_count = 0
+    while current_count != expected_count and retry_count < retries:
+        time.sleep(interval)
+        current_count = get_count(client=client)
+        retry_count += 1
+    assert current_count == expected_count, (
+        f"Expected count ({expected_count}) doesn't match how "
+        f"much came back from collection: {current_count}"
+    )
+
+
+@pytest.mark.asyncio
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
+async def test_milvus_destination(
+    upload_file: Path,
+    collection: str,
+    tmp_path: Path,
+):
+    file_data = FileData(
+        source_identifiers=SourceIdentifiers(fullpath=upload_file.name, filename=upload_file.name),
+        connector_type=CONNECTOR_TYPE,
+        identifier="mock file data",
+    )
+    stager = MilvusUploadStager()
+    uploader = MilvusUploader(
+        connection_config=MilvusConnectionConfig(uri=DB_URI),
+        upload_config=MilvusUploaderConfig(collection_name=collection, db_name=DB_NAME),
+    )
+    staged_filepath = stager.run(
+        elements_filepath=upload_file,
+        file_data=file_data,
+        output_dir=tmp_path,
+        output_filename=upload_file.name,
+    )
+    uploader.precheck()
+    uploader.run(path=staged_filepath, file_data=file_data)
+
+    # Run validation
+    with staged_filepath.open() as f:
+        staged_elements = json.load(f)
+    expected_count = len(staged_elements)
+    with uploader.get_client() as client:
+        validate_count(client=client, expected_count=expected_count)
+
+    # Rerun and make sure the same documents get updated
+    uploader.run(path=staged_filepath, file_data=file_data)
+    with uploader.get_client() as client:
+        validate_count(client=client, expected_count=expected_count)