unstructured-ingest 0.2.2__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

Potentially problematic release: this version of unstructured-ingest might be problematic.

Files changed (93)
  1. test/integration/connectors/elasticsearch/__init__.py +0 -0
  2. test/integration/connectors/elasticsearch/conftest.py +34 -0
  3. test/integration/connectors/elasticsearch/test_elasticsearch.py +308 -0
  4. test/integration/connectors/elasticsearch/test_opensearch.py +302 -0
  5. test/integration/connectors/sql/test_postgres.py +10 -4
  6. test/integration/connectors/sql/test_singlestore.py +8 -4
  7. test/integration/connectors/sql/test_snowflake.py +10 -6
  8. test/integration/connectors/sql/test_sqlite.py +4 -4
  9. test/integration/connectors/test_astradb.py +156 -0
  10. test/integration/connectors/test_azure_cog_search.py +233 -0
  11. test/integration/connectors/test_delta_table.py +46 -0
  12. test/integration/connectors/test_kafka.py +150 -16
  13. test/integration/connectors/test_lancedb.py +209 -0
  14. test/integration/connectors/test_milvus.py +141 -0
  15. test/integration/connectors/test_pinecone.py +213 -0
  16. test/integration/connectors/test_s3.py +23 -0
  17. test/integration/connectors/utils/docker.py +81 -15
  18. test/integration/connectors/utils/validation.py +10 -0
  19. test/integration/connectors/weaviate/__init__.py +0 -0
  20. test/integration/connectors/weaviate/conftest.py +15 -0
  21. test/integration/connectors/weaviate/test_local.py +131 -0
  22. test/unit/v2/__init__.py +0 -0
  23. test/unit/v2/chunkers/__init__.py +0 -0
  24. test/unit/v2/chunkers/test_chunkers.py +49 -0
  25. test/unit/v2/connectors/__init__.py +0 -0
  26. test/unit/v2/embedders/__init__.py +0 -0
  27. test/unit/v2/embedders/test_bedrock.py +36 -0
  28. test/unit/v2/embedders/test_huggingface.py +48 -0
  29. test/unit/v2/embedders/test_mixedbread.py +37 -0
  30. test/unit/v2/embedders/test_octoai.py +35 -0
  31. test/unit/v2/embedders/test_openai.py +35 -0
  32. test/unit/v2/embedders/test_togetherai.py +37 -0
  33. test/unit/v2/embedders/test_vertexai.py +37 -0
  34. test/unit/v2/embedders/test_voyageai.py +38 -0
  35. test/unit/v2/partitioners/__init__.py +0 -0
  36. test/unit/v2/partitioners/test_partitioner.py +63 -0
  37. test/unit/v2/utils/__init__.py +0 -0
  38. test/unit/v2/utils/data_generator.py +32 -0
  39. unstructured_ingest/__version__.py +1 -1
  40. unstructured_ingest/cli/cmds/__init__.py +2 -2
  41. unstructured_ingest/cli/cmds/{azure_cognitive_search.py → azure_ai_search.py} +9 -9
  42. unstructured_ingest/connector/{azure_cognitive_search.py → azure_ai_search.py} +9 -9
  43. unstructured_ingest/pipeline/reformat/embedding.py +1 -1
  44. unstructured_ingest/runner/writers/__init__.py +2 -2
  45. unstructured_ingest/runner/writers/azure_ai_search.py +24 -0
  46. unstructured_ingest/utils/data_prep.py +9 -1
  47. unstructured_ingest/v2/constants.py +2 -0
  48. unstructured_ingest/v2/processes/connectors/__init__.py +7 -20
  49. unstructured_ingest/v2/processes/connectors/airtable.py +2 -2
  50. unstructured_ingest/v2/processes/connectors/astradb.py +35 -23
  51. unstructured_ingest/v2/processes/connectors/{azure_cognitive_search.py → azure_ai_search.py} +116 -35 (see the import-path sketch after this list)
  52. unstructured_ingest/v2/processes/connectors/confluence.py +2 -2
  53. unstructured_ingest/v2/processes/connectors/couchbase.py +1 -0
  54. unstructured_ingest/v2/processes/connectors/delta_table.py +37 -9
  55. unstructured_ingest/v2/processes/connectors/elasticsearch/__init__.py +19 -0
  56. unstructured_ingest/v2/processes/connectors/{elasticsearch.py → elasticsearch/elasticsearch.py} +93 -46
  57. unstructured_ingest/v2/processes/connectors/{opensearch.py → elasticsearch/opensearch.py} +1 -1
  58. unstructured_ingest/v2/processes/connectors/fsspec/fsspec.py +27 -0
  59. unstructured_ingest/v2/processes/connectors/google_drive.py +3 -3
  60. unstructured_ingest/v2/processes/connectors/kafka/__init__.py +6 -2
  61. unstructured_ingest/v2/processes/connectors/kafka/cloud.py +38 -2
  62. unstructured_ingest/v2/processes/connectors/kafka/kafka.py +84 -23
  63. unstructured_ingest/v2/processes/connectors/kafka/local.py +32 -4
  64. unstructured_ingest/v2/processes/connectors/lancedb/__init__.py +17 -0
  65. unstructured_ingest/v2/processes/connectors/lancedb/aws.py +43 -0
  66. unstructured_ingest/v2/processes/connectors/lancedb/azure.py +43 -0
  67. unstructured_ingest/v2/processes/connectors/lancedb/gcp.py +44 -0
  68. unstructured_ingest/v2/processes/connectors/lancedb/lancedb.py +161 -0
  69. unstructured_ingest/v2/processes/connectors/lancedb/local.py +44 -0
  70. unstructured_ingest/v2/processes/connectors/milvus.py +72 -27
  71. unstructured_ingest/v2/processes/connectors/onedrive.py +2 -3
  72. unstructured_ingest/v2/processes/connectors/outlook.py +2 -2
  73. unstructured_ingest/v2/processes/connectors/pinecone.py +101 -13
  74. unstructured_ingest/v2/processes/connectors/sharepoint.py +3 -2
  75. unstructured_ingest/v2/processes/connectors/slack.py +2 -2
  76. unstructured_ingest/v2/processes/connectors/sql/postgres.py +16 -8
  77. unstructured_ingest/v2/processes/connectors/sql/sql.py +97 -26
  78. unstructured_ingest/v2/processes/connectors/weaviate/__init__.py +22 -0
  79. unstructured_ingest/v2/processes/connectors/weaviate/cloud.py +164 -0
  80. unstructured_ingest/v2/processes/connectors/weaviate/embedded.py +90 -0
  81. unstructured_ingest/v2/processes/connectors/weaviate/local.py +73 -0
  82. unstructured_ingest/v2/processes/connectors/weaviate/weaviate.py +289 -0
  83. {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/METADATA +20 -19
  84. {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/RECORD +91 -50
  85. unstructured_ingest/runner/writers/azure_cognitive_search.py +0 -24
  86. unstructured_ingest/v2/processes/connectors/weaviate.py +0 -242
  87. /test/integration/embedders/{togetherai.py → test_togetherai.py} +0 -0
  88. /test/unit/{test_interfaces_v2.py → v2/test_interfaces.py} +0 -0
  89. /test/unit/{test_utils_v2.py → v2/test_utils.py} +0 -0
  90. {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/LICENSE.md +0 -0
  91. {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/WHEEL +0 -0
  92. {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/entry_points.txt +0 -0
  93. {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/top_level.txt +0 -0
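Several of the renames above (items 41, 42, 45, 51, and 85) retire the azure_cognitive_search naming in favor of azure_ai_search. A minimal import-path sketch of what that migration looks like for downstream code, based solely on the module renames listed above (the class names inside those modules are not visible in this diff):

# Hypothetical migration sketch; only the module paths are taken from the renames above.

# 0.2.2 (module removed in 0.3.1):
# from unstructured_ingest.v2.processes.connectors import azure_cognitive_search

# 0.3.1 (renamed module):
from unstructured_ingest.v2.processes.connectors import azure_ai_search  # noqa: F401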
test/integration/connectors/test_kafka.py
@@ -1,11 +1,14 @@
-import socket
+import json
 import tempfile
+import time
 from pathlib import Path
 
 import pytest
-from confluent_kafka import Producer
+from confluent_kafka import Consumer, KafkaError, KafkaException, Producer
+from confluent_kafka.admin import AdminClient, NewTopic
 
 from test.integration.connectors.utils.constants import (
+    DESTINATION_TAG,
     SOURCE_TAG,
     env_setup_path,
 )
@@ -14,6 +17,8 @@ from test.integration.connectors.utils.validation import (
     ValidationConfigs,
     source_connector_validation,
 )
+from unstructured_ingest.error import DestinationConnectionError, SourceConnectionError
+from unstructured_ingest.v2.interfaces import FileData, SourceIdentifiers
 from unstructured_ingest.v2.processes.connectors.kafka.local import (
     CONNECTOR_TYPE,
     LocalKafkaConnectionConfig,
@@ -21,27 +26,64 @@ from unstructured_ingest.v2.processes.connectors.kafka.local import (
     LocalKafkaDownloaderConfig,
     LocalKafkaIndexer,
     LocalKafkaIndexerConfig,
+    LocalKafkaUploader,
+    LocalKafkaUploaderConfig,
 )
 
 SEED_MESSAGES = 10
 TOPIC = "fake-topic"
 
 
+def get_admin_client() -> AdminClient:
+    conf = {
+        "bootstrap.servers": "localhost:29092",
+    }
+    return AdminClient(conf)
+
+
 @pytest.fixture
-def kafka_seed_topic() -> str:
-    with docker_compose_context(docker_compose_path=env_setup_path / "kafka"):
-        conf = {
-            "bootstrap.servers": "localhost:29092",
-            "client.id": socket.gethostname(),
-            "message.max.bytes": 10485760,
-        }
-        producer = Producer(conf)
-        for i in range(SEED_MESSAGES):
-            message = f"This is some text for message {i}"
-            producer.produce(topic=TOPIC, value=message)
-        producer.flush(timeout=10)
-        print(f"kafka topic {TOPIC} seeded with {SEED_MESSAGES} messages")
-        yield TOPIC
+def docker_compose_ctx():
+    with docker_compose_context(docker_compose_path=env_setup_path / "kafka") as ctx:
+        yield ctx
+
+
+def wait_for_topic(topic: str, retries: int = 10, interval: int = 1):
+    admin_client = get_admin_client()
+    current_topics = admin_client.list_topics().topics
+    attempts = 0
+    while topic not in current_topics and attempts < retries:
+        attempts += 1
+        print(
+            "Attempt {}: Waiting for topic {} to exist in {}".format(
+                attempts, topic, ", ".join(current_topics)
+            )
+        )
+        time.sleep(interval)
+        current_topics = admin_client.list_topics().topics
+    if topic not in current_topics:
+        raise TimeoutError(f"Timeout out waiting for topic {topic} to exist")
+
+
+@pytest.fixture
+def kafka_seed_topic(docker_compose_ctx) -> str:
+    conf = {
+        "bootstrap.servers": "localhost:29092",
+    }
+    producer = Producer(conf)
+    for i in range(SEED_MESSAGES):
+        message = f"This is some text for message {i}"
+        producer.produce(topic=TOPIC, value=message)
+    producer.flush(timeout=10)
+    print(f"kafka topic {TOPIC} seeded with {SEED_MESSAGES} messages")
+    wait_for_topic(topic=TOPIC)
+    return TOPIC
+
+
+@pytest.fixture
+def kafka_upload_topic(docker_compose_ctx) -> str:
+    admin_client = get_admin_client()
+    admin_client.create_topics([NewTopic(TOPIC, 1, 1)])
+    return TOPIC
 
 
 @pytest.mark.asyncio
@@ -58,6 +100,7 @@ async def test_kafka_source_local(kafka_seed_topic: str):
     downloader = LocalKafkaDownloader(
         connection_config=connection_config, download_config=download_config
     )
+    indexer.precheck()
     await source_connector_validation(
         indexer=indexer,
         downloader=downloader,
@@ -65,3 +108,94 @@ async def test_kafka_source_local(kafka_seed_topic: str):
             test_id="kafka", expected_num_files=5, validate_downloaded_files=True
         ),
     )
+
+
+@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG)
+def test_kafka_source_local_precheck_fail_no_cluster():
+    connection_config = LocalKafkaConnectionConfig(bootstrap_server="localhost", port=29092)
+    indexer = LocalKafkaIndexer(
+        connection_config=connection_config,
+        index_config=LocalKafkaIndexerConfig(topic=TOPIC, num_messages_to_consume=5),
+    )
+    with pytest.raises(SourceConnectionError):
+        indexer.precheck()
+
+
+@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG)
+def test_kafka_source_local_precheck_fail_no_topic(kafka_seed_topic: str):
+    connection_config = LocalKafkaConnectionConfig(bootstrap_server="localhost", port=29092)
+    indexer = LocalKafkaIndexer(
+        connection_config=connection_config,
+        index_config=LocalKafkaIndexerConfig(topic="topic", num_messages_to_consume=5),
+    )
+    with pytest.raises(SourceConnectionError):
+        indexer.precheck()
+
+
+def get_all_messages(topic: str, max_empty_messages: int = 5) -> list[dict]:
+    conf = {
+        "bootstrap.servers": "localhost:29092",
+        "group.id": "default_group_id",
+        "enable.auto.commit": "false",
+        "auto.offset.reset": "earliest",
+    }
+    consumer = Consumer(conf)
+    consumer.subscribe([topic])
+    messages = []
+    try:
+        empty_count = 0
+        while empty_count < max_empty_messages:
+            msg = consumer.poll(timeout=1)
+            if msg is None:
+                empty_count += 1
+                continue
+            if msg.error():
+                if msg.error().code() == KafkaError._PARTITION_EOF:
+                    break
+                else:
+                    raise KafkaException(msg.error())
+            try:
+                message = json.loads(msg.value().decode("utf8"))
+                messages.append(message)
+            finally:
+                consumer.commit(asynchronous=False)
+    finally:
+        print("closing consumer")
+        consumer.close()
+    return messages
+
+
+@pytest.mark.asyncio
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
+async def test_kafka_destination_local(upload_file: Path, kafka_upload_topic: str):
+    uploader = LocalKafkaUploader(
+        connection_config=LocalKafkaConnectionConfig(bootstrap_server="localhost", port=29092),
+        upload_config=LocalKafkaUploaderConfig(topic=TOPIC, batch_size=10),
+    )
+    file_data = FileData(
+        source_identifiers=SourceIdentifiers(fullpath=upload_file.name, filename=upload_file.name),
+        connector_type=CONNECTOR_TYPE,
+        identifier="mock file data",
+    )
+    uploader.precheck()
+    if uploader.is_async():
+        await uploader.run_async(path=upload_file, file_data=file_data)
+    else:
+        uploader.run(path=upload_file, file_data=file_data)
+    all_messages = get_all_messages(topic=kafka_upload_topic)
+    with upload_file.open("r") as upload_fs:
+        content_to_upload = json.load(upload_fs)
+    assert len(all_messages) == len(content_to_upload), (
+        f"expected number of messages ({len(content_to_upload)}) doesn't "
+        f"match how many messages read off of kakfa topic {kafka_upload_topic}: {len(all_messages)}"
+    )
+
+
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
+def test_kafka_destination_local_precheck_fail_no_cluster():
+    uploader = LocalKafkaUploader(
+        connection_config=LocalKafkaConnectionConfig(bootstrap_server="localhost", port=29092),
+        upload_config=LocalKafkaUploaderConfig(topic=TOPIC, batch_size=10),
+    )
+    with pytest.raises(DestinationConnectionError):
+        uploader.precheck()
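
Beyond the test itself, a minimal sketch of driving the new local Kafka destination directly, using only the classes and constructor arguments exercised above (the broker address, topic, and file path are illustrative assumptions, not documented defaults):

# Minimal sketch, assuming a reachable broker on localhost:29092 and an existing
# elements JSON file; it mirrors test_kafka_destination_local rather than a documented API.
from pathlib import Path

from unstructured_ingest.v2.interfaces import FileData, SourceIdentifiers
from unstructured_ingest.v2.processes.connectors.kafka.local import (
    CONNECTOR_TYPE,
    LocalKafkaConnectionConfig,
    LocalKafkaUploader,
    LocalKafkaUploaderConfig,
)

elements_file = Path("example-output/elements.json")  # hypothetical input file

uploader = LocalKafkaUploader(
    connection_config=LocalKafkaConnectionConfig(bootstrap_server="localhost", port=29092),
    upload_config=LocalKafkaUploaderConfig(topic="fake-topic", batch_size=10),
)
file_data = FileData(
    source_identifiers=SourceIdentifiers(fullpath=elements_file.name, filename=elements_file.name),
    connector_type=CONNECTOR_TYPE,
    identifier="example-file",
)
uploader.precheck()  # raises DestinationConnectionError if the broker is unreachable
# The test branches on uploader.is_async() and awaits run_async() when true;
# the synchronous path is shown here for brevity.
uploader.run(path=elements_file, file_data=file_data)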
test/integration/connectors/test_lancedb.py
@@ -0,0 +1,209 @@
+import os
+from pathlib import Path
+from typing import Literal, Union
+
+import lancedb
+import pandas as pd
+import pyarrow as pa
+import pytest
+import pytest_asyncio
+from lancedb import AsyncConnection
+from upath import UPath
+
+from test.integration.connectors.utils.constants import DESTINATION_TAG
+from unstructured_ingest.v2.interfaces.file_data import FileData, SourceIdentifiers
+from unstructured_ingest.v2.processes.connectors.lancedb.aws import (
+    LanceDBS3AccessConfig,
+    LanceDBS3ConnectionConfig,
+    LanceDBS3Uploader,
+)
+from unstructured_ingest.v2.processes.connectors.lancedb.azure import (
+    LanceDBAzureAccessConfig,
+    LanceDBAzureConnectionConfig,
+    LanceDBAzureUploader,
+)
+from unstructured_ingest.v2.processes.connectors.lancedb.gcp import (
+    LanceDBGCSAccessConfig,
+    LanceDBGCSConnectionConfig,
+    LanceDBGSPUploader,
+)
+from unstructured_ingest.v2.processes.connectors.lancedb.lancedb import (
+    CONNECTOR_TYPE,
+    LanceDBUploaderConfig,
+    LanceDBUploadStager,
+)
+from unstructured_ingest.v2.processes.connectors.lancedb.local import (
+    LanceDBLocalAccessConfig,
+    LanceDBLocalConnectionConfig,
+    LanceDBLocalUploader,
+)
+
+DATABASE_NAME = "database"
+TABLE_NAME = "elements"
+DIMENSION = 384
+NUMBER_EXPECTED_ROWS = 22
+NUMBER_EXPECTED_COLUMNS = 10
+S3_BUCKET = "s3://utic-ingest-test-fixtures/"
+GS_BUCKET = "gs://utic-test-ingest-fixtures-output/"
+AZURE_BUCKET = "az://utic-ingest-test-fixtures-output/"
+REQUIRED_ENV_VARS = {
+    "s3": ("S3_INGEST_TEST_ACCESS_KEY", "S3_INGEST_TEST_SECRET_KEY"),
+    "gcs": ("GCP_INGEST_SERVICE_KEY",),
+    "az": ("AZURE_DEST_CONNECTION_STR",),
+    "local": (),
+}
+
+
+SCHEMA = pa.schema(
+    [
+        pa.field("vector", pa.list_(pa.float16(), DIMENSION)),
+        pa.field("text", pa.string(), nullable=True),
+        pa.field("type", pa.string(), nullable=True),
+        pa.field("element_id", pa.string(), nullable=True),
+        pa.field("metadata-text_as_html", pa.string(), nullable=True),
+        pa.field("metadata-filetype", pa.string(), nullable=True),
+        pa.field("metadata-filename", pa.string(), nullable=True),
+        pa.field("metadata-languages", pa.list_(pa.string()), nullable=True),
+        pa.field("metadata-is_continuation", pa.bool_(), nullable=True),
+        pa.field("metadata-page_number", pa.int32(), nullable=True),
+    ]
+)
+
+
+@pytest_asyncio.fixture
+async def connection_with_uri(request, tmp_path: Path):
+    target = request.param
+    uri = _get_uri(target, local_base_path=tmp_path)
+
+    unset_variables = [env for env in REQUIRED_ENV_VARS[target] if env not in os.environ]
+    if unset_variables:
+        pytest.skip(
+            reason="Following required environment variables were not set: "
+            + f"{', '.join(unset_variables)}"
+        )
+
+    storage_options = {
+        "aws_access_key_id": os.getenv("S3_INGEST_TEST_ACCESS_KEY"),
+        "aws_secret_access_key": os.getenv("S3_INGEST_TEST_SECRET_KEY"),
+        "google_service_account_key": os.getenv("GCP_INGEST_SERVICE_KEY"),
+    }
+    azure_connection_string = os.getenv("AZURE_DEST_CONNECTION_STR")
+    if azure_connection_string:
+        storage_options.update(_parse_azure_connection_string(azure_connection_string))
+
+    storage_options = {key: value for key, value in storage_options.items() if value is not None}
+    connection = await lancedb.connect_async(
+        uri=uri,
+        storage_options=storage_options,
+    )
+    await connection.create_table(name=TABLE_NAME, schema=SCHEMA)
+
+    yield connection, uri
+
+    await connection.drop_database()
+
+
+@pytest.mark.asyncio
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
+@pytest.mark.parametrize("connection_with_uri", ["local", "s3", "gcs", "az"], indirect=True)
+async def test_lancedb_destination(
+    upload_file: Path,
+    connection_with_uri: tuple[AsyncConnection, str],
+    tmp_path: Path,
+) -> None:
+    connection, uri = connection_with_uri
+    file_data = FileData(
+        source_identifiers=SourceIdentifiers(fullpath=upload_file.name, filename=upload_file.name),
+        connector_type=CONNECTOR_TYPE,
+        identifier="mock file data",
+    )
+    stager = LanceDBUploadStager()
+    uploader = _get_uploader(uri)
+    staged_file_path = stager.run(
+        elements_filepath=upload_file,
+        file_data=file_data,
+        output_dir=tmp_path,
+        output_filename=upload_file.name,
+    )
+
+    await uploader.run_async(path=staged_file_path, file_data=file_data)
+
+    table = await connection.open_table(TABLE_NAME)
+    table_df: pd.DataFrame = await table.to_pandas()
+
+    assert len(table_df) == NUMBER_EXPECTED_ROWS
+    assert len(table_df.columns) == NUMBER_EXPECTED_COLUMNS
+
+    assert table_df["element_id"][0] == "2470d8dc42215b3d68413b55bf00fed2"
+    assert table_df["type"][0] == "CompositeElement"
+    assert table_df["metadata-filename"][0] == "DA-1p-with-duplicate-pages.pdf.json"
+    assert table_df["metadata-text_as_html"][0] is None
+
+
+def _get_uri(target: Literal["local", "s3", "gcs", "az"], local_base_path: Path) -> str:
+    if target == "local":
+        return str(local_base_path / DATABASE_NAME)
+    if target == "s3":
+        base_uri = UPath(S3_BUCKET)
+    elif target == "gcs":
+        base_uri = UPath(GS_BUCKET)
+    elif target == "az":
+        base_uri = UPath(AZURE_BUCKET)
+
+    return str(base_uri / "destination" / "lancedb" / DATABASE_NAME)
+
+
+def _get_uploader(
+    uri: str,
+) -> Union[LanceDBAzureUploader, LanceDBAzureUploader, LanceDBS3Uploader, LanceDBGSPUploader]:
+    target = uri.split("://", maxsplit=1)[0] if uri.startswith(("s3", "az", "gs")) else "local"
+    if target == "az":
+        azure_connection_string = os.getenv("AZURE_DEST_CONNECTION_STR")
+        access_config_kwargs = _parse_azure_connection_string(azure_connection_string)
+        return LanceDBAzureUploader(
+            upload_config=LanceDBUploaderConfig(table_name=TABLE_NAME),
+            connection_config=LanceDBAzureConnectionConfig(
+                access_config=LanceDBAzureAccessConfig(**access_config_kwargs),
+                uri=uri,
+            ),
+        )
+
+    elif target == "s3":
+        return LanceDBS3Uploader(
+            upload_config=LanceDBUploaderConfig(table_name=TABLE_NAME),
+            connection_config=LanceDBS3ConnectionConfig(
+                access_config=LanceDBS3AccessConfig(
+                    aws_access_key_id=os.getenv("S3_INGEST_TEST_ACCESS_KEY"),
+                    aws_secret_access_key=os.getenv("S3_INGEST_TEST_SECRET_KEY"),
+                ),
+                uri=uri,
+            ),
+        )
+    elif target == "gs":
+        return LanceDBGSPUploader(
+            upload_config=LanceDBUploaderConfig(table_name=TABLE_NAME),
+            connection_config=LanceDBGCSConnectionConfig(
+                access_config=LanceDBGCSAccessConfig(
+                    google_service_account_key=os.getenv("GCP_INGEST_SERVICE_KEY")
+                ),
+                uri=uri,
+            ),
+        )
+    else:
+        return LanceDBLocalUploader(
+            upload_config=LanceDBUploaderConfig(table_name=TABLE_NAME),
+            connection_config=LanceDBLocalConnectionConfig(
+                access_config=LanceDBLocalAccessConfig(),
+                uri=uri,
+            ),
+        )
+
+
+def _parse_azure_connection_string(
+    connection_str: str,
+) -> dict[Literal["azure_storage_account_name", "azure_storage_account_key"], str]:
+    parameters = dict(keyvalue.split("=", maxsplit=1) for keyvalue in connection_str.split(";"))
+    return {
+        "azure_storage_account_name": parameters.get("AccountName"),
+        "azure_storage_account_key": parameters.get("AccountKey"),
+    }
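
The same stager-plus-uploader flow, reduced to the local LanceDB target the test parametrizes over (paths and the table name are illustrative; the database and "elements" table are assumed to already exist with a compatible schema, as the fixture above creates them):

# Minimal local sketch mirroring test_lancedb_destination; the database and table
# are assumed to have been created beforehand, as the fixture does via lancedb.
import asyncio
from pathlib import Path

from unstructured_ingest.v2.interfaces.file_data import FileData, SourceIdentifiers
from unstructured_ingest.v2.processes.connectors.lancedb.lancedb import (
    CONNECTOR_TYPE,
    LanceDBUploaderConfig,
    LanceDBUploadStager,
)
from unstructured_ingest.v2.processes.connectors.lancedb.local import (
    LanceDBLocalAccessConfig,
    LanceDBLocalConnectionConfig,
    LanceDBLocalUploader,
)

elements_file = Path("example-output/elements.json")  # hypothetical input file
work_dir = Path("example-output/staged")  # hypothetical staging directory

file_data = FileData(
    source_identifiers=SourceIdentifiers(fullpath=elements_file.name, filename=elements_file.name),
    connector_type=CONNECTOR_TYPE,
    identifier="example-file",
)
stager = LanceDBUploadStager()
staged_path = stager.run(
    elements_filepath=elements_file,
    file_data=file_data,
    output_dir=work_dir,
    output_filename=elements_file.name,
)
uploader = LanceDBLocalUploader(
    upload_config=LanceDBUploaderConfig(table_name="elements"),
    connection_config=LanceDBLocalConnectionConfig(
        access_config=LanceDBLocalAccessConfig(),
        uri="/tmp/lancedb/database",  # hypothetical local database path
    ),
)
# The uploader is driven asynchronously in the test, so the sketch does the same.
asyncio.run(uploader.run_async(path=staged_path, file_data=file_data))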
test/integration/connectors/test_milvus.py
@@ -0,0 +1,141 @@
+import json
+import time
+from pathlib import Path
+
+import docker
+import pytest
+from pymilvus import (
+    CollectionSchema,
+    DataType,
+    FieldSchema,
+    MilvusClient,
+)
+from pymilvus.milvus_client import IndexParams
+
+from test.integration.connectors.utils.constants import DESTINATION_TAG, env_setup_path
+from test.integration.connectors.utils.docker import healthcheck_wait
+from test.integration.connectors.utils.docker_compose import docker_compose_context
+from unstructured_ingest.v2.interfaces import FileData, SourceIdentifiers
+from unstructured_ingest.v2.processes.connectors.milvus import (
+    CONNECTOR_TYPE,
+    MilvusConnectionConfig,
+    MilvusUploader,
+    MilvusUploaderConfig,
+    MilvusUploadStager,
+)
+
+DB_URI = "http://localhost:19530"
+DB_NAME = "test_database"
+COLLECTION_NAME = "test_collection"
+
+
+def get_schema() -> CollectionSchema:
+    id_field = FieldSchema(
+        name="id", dtype=DataType.INT64, description="primary field", is_primary=True, auto_id=True
+    )
+    embeddings_field = FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=384)
+    record_id_field = FieldSchema(name="record_id", dtype=DataType.VARCHAR, max_length=64)
+
+    schema = CollectionSchema(
+        enable_dynamic_field=True,
+        fields=[
+            id_field,
+            record_id_field,
+            embeddings_field,
+        ],
+    )
+
+    return schema
+
+
+def get_index_params() -> IndexParams:
+    index_params = IndexParams()
+    index_params.add_index(field_name="embeddings", index_type="AUTOINDEX", metric_type="COSINE")
+    index_params.add_index(field_name="record_id", index_type="Trie")
+    return index_params
+
+
+@pytest.fixture
+def collection():
+    docker_client = docker.from_env()
+    with docker_compose_context(docker_compose_path=env_setup_path / "milvus"):
+        milvus_container = docker_client.containers.get("milvus-standalone")
+        healthcheck_wait(container=milvus_container)
+        milvus_client = MilvusClient(uri=DB_URI)
+        try:
+            # Create the database
+            database_resp = milvus_client._get_connection().create_database(db_name=DB_NAME)
+            milvus_client.using_database(db_name=DB_NAME)
+
+            print(f"Created database {DB_NAME}: {database_resp}")
+
+            # Create the collection
+            schema = get_schema()
+            index_params = get_index_params()
+            collection_resp = milvus_client.create_collection(
+                collection_name=COLLECTION_NAME, schema=schema, index_params=index_params
+            )
+            print(f"Created collection {COLLECTION_NAME}: {collection_resp}")
+            yield COLLECTION_NAME
+        finally:
+            milvus_client.close()
+
+
+def get_count(client: MilvusClient) -> int:
+    count_field = "count(*)"
+    resp = client.query(collection_name="test_collection", output_fields=[count_field])
+    return resp[0][count_field]
+
+
+def validate_count(
+    client: MilvusClient, expected_count: int, retries: int = 10, interval: int = 1
+) -> None:
+    current_count = get_count(client=client)
+    retry_count = 0
+    while current_count != expected_count and retry_count < retries:
+        time.sleep(interval)
+        current_count = get_count(client=client)
+        retry_count += 1
+    assert current_count == expected_count, (
+        f"Expected count ({expected_count}) doesn't match how "
+        f"much came back from collection: {current_count}"
+    )
+
+
+@pytest.mark.asyncio
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
+async def test_milvus_destination(
+    upload_file: Path,
+    collection: str,
+    tmp_path: Path,
+):
+    file_data = FileData(
+        source_identifiers=SourceIdentifiers(fullpath=upload_file.name, filename=upload_file.name),
+        connector_type=CONNECTOR_TYPE,
+        identifier="mock file data",
+    )
+    stager = MilvusUploadStager()
+    uploader = MilvusUploader(
+        connection_config=MilvusConnectionConfig(uri=DB_URI),
+        upload_config=MilvusUploaderConfig(collection_name=collection, db_name=DB_NAME),
+    )
+    staged_filepath = stager.run(
+        elements_filepath=upload_file,
+        file_data=file_data,
+        output_dir=tmp_path,
+        output_filename=upload_file.name,
+    )
+    uploader.precheck()
+    uploader.run(path=staged_filepath, file_data=file_data)
+
+    # Run validation
+    with staged_filepath.open() as f:
+        staged_elements = json.load(f)
+    expected_count = len(staged_elements)
+    with uploader.get_client() as client:
+        validate_count(client=client, expected_count=expected_count)
+
+    # Rerun and make sure the same documents get updated
+    uploader.run(path=staged_filepath, file_data=file_data)
+    with uploader.get_client() as client:
+        validate_count(client=client, expected_count=expected_count)
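
And the Milvus counterpart, distilled from test_milvus_destination above (the URI, database, and collection names are the test's own values; the collection is assumed to already exist with the fixture's schema):

# Minimal sketch mirroring test_milvus_destination; assumes a reachable Milvus instance
# and a pre-created database/collection matching the fixture's schema.
from pathlib import Path

from unstructured_ingest.v2.interfaces import FileData, SourceIdentifiers
from unstructured_ingest.v2.processes.connectors.milvus import (
    CONNECTOR_TYPE,
    MilvusConnectionConfig,
    MilvusUploader,
    MilvusUploaderConfig,
    MilvusUploadStager,
)

elements_file = Path("example-output/elements.json")  # hypothetical input file
work_dir = Path("example-output/staged")  # hypothetical staging directory

file_data = FileData(
    source_identifiers=SourceIdentifiers(fullpath=elements_file.name, filename=elements_file.name),
    connector_type=CONNECTOR_TYPE,
    identifier="example-file",
)
stager = MilvusUploadStager()
uploader = MilvusUploader(
    connection_config=MilvusConnectionConfig(uri="http://localhost:19530"),
    upload_config=MilvusUploaderConfig(collection_name="test_collection", db_name="test_database"),
)
staged_path = stager.run(
    elements_filepath=elements_file,
    file_data=file_data,
    output_dir=work_dir,
    output_filename=elements_file.name,
)
uploader.precheck()
uploader.run(path=staged_path, file_data=file_data)
# Re-running with the same file_data leaves the row count unchanged,
# which is what the test's second uploader.run() call asserts.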