unstructured-ingest 0.2.2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of unstructured-ingest might be problematic.

Files changed (59)
  1. test/integration/connectors/test_astradb.py +109 -0
  2. test/integration/connectors/test_azure_cog_search.py +233 -0
  3. test/integration/connectors/test_kafka.py +116 -16
  4. test/integration/connectors/test_pinecone.py +161 -0
  5. test/integration/connectors/test_s3.py +23 -0
  6. test/unit/v2/__init__.py +0 -0
  7. test/unit/v2/chunkers/__init__.py +0 -0
  8. test/unit/v2/chunkers/test_chunkers.py +49 -0
  9. test/unit/v2/connectors/__init__.py +0 -0
  10. test/unit/v2/embedders/__init__.py +0 -0
  11. test/unit/v2/embedders/test_bedrock.py +36 -0
  12. test/unit/v2/embedders/test_huggingface.py +48 -0
  13. test/unit/v2/embedders/test_mixedbread.py +37 -0
  14. test/unit/v2/embedders/test_octoai.py +35 -0
  15. test/unit/v2/embedders/test_openai.py +35 -0
  16. test/unit/v2/embedders/test_togetherai.py +37 -0
  17. test/unit/v2/embedders/test_vertexai.py +37 -0
  18. test/unit/v2/embedders/test_voyageai.py +38 -0
  19. test/unit/v2/partitioners/__init__.py +0 -0
  20. test/unit/v2/partitioners/test_partitioner.py +63 -0
  21. test/unit/v2/utils/__init__.py +0 -0
  22. test/unit/v2/utils/data_generator.py +32 -0
  23. unstructured_ingest/__version__.py +1 -1
  24. unstructured_ingest/cli/cmds/__init__.py +2 -2
  25. unstructured_ingest/cli/cmds/{azure_cognitive_search.py → azure_ai_search.py} +9 -9
  26. unstructured_ingest/connector/{azure_cognitive_search.py → azure_ai_search.py} +9 -9
  27. unstructured_ingest/runner/writers/__init__.py +2 -2
  28. unstructured_ingest/runner/writers/azure_ai_search.py +24 -0
  29. unstructured_ingest/v2/constants.py +2 -0
  30. unstructured_ingest/v2/processes/connectors/__init__.py +4 -4
  31. unstructured_ingest/v2/processes/connectors/airtable.py +2 -2
  32. unstructured_ingest/v2/processes/connectors/astradb.py +33 -21
  33. unstructured_ingest/v2/processes/connectors/{azure_cognitive_search.py → azure_ai_search.py} +112 -35
  34. unstructured_ingest/v2/processes/connectors/confluence.py +2 -2
  35. unstructured_ingest/v2/processes/connectors/couchbase.py +1 -0
  36. unstructured_ingest/v2/processes/connectors/delta_table.py +17 -5
  37. unstructured_ingest/v2/processes/connectors/elasticsearch.py +1 -0
  38. unstructured_ingest/v2/processes/connectors/fsspec/fsspec.py +27 -0
  39. unstructured_ingest/v2/processes/connectors/google_drive.py +3 -3
  40. unstructured_ingest/v2/processes/connectors/kafka/__init__.py +6 -2
  41. unstructured_ingest/v2/processes/connectors/kafka/cloud.py +38 -2
  42. unstructured_ingest/v2/processes/connectors/kafka/kafka.py +78 -23
  43. unstructured_ingest/v2/processes/connectors/kafka/local.py +32 -4
  44. unstructured_ingest/v2/processes/connectors/onedrive.py +2 -3
  45. unstructured_ingest/v2/processes/connectors/outlook.py +2 -2
  46. unstructured_ingest/v2/processes/connectors/pinecone.py +83 -12
  47. unstructured_ingest/v2/processes/connectors/sharepoint.py +3 -2
  48. unstructured_ingest/v2/processes/connectors/slack.py +2 -2
  49. unstructured_ingest/v2/processes/connectors/sql/postgres.py +16 -8
  50. {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.0.dist-info}/METADATA +20 -19
  51. {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.0.dist-info}/RECORD +58 -37
  52. unstructured_ingest/runner/writers/azure_cognitive_search.py +0 -24
  53. /test/integration/embedders/{togetherai.py → test_togetherai.py} +0 -0
  54. /test/unit/{test_interfaces_v2.py → v2/test_interfaces.py} +0 -0
  55. /test/unit/{test_utils_v2.py → v2/test_utils.py} +0 -0
  56. {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.0.dist-info}/LICENSE.md +0 -0
  57. {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.0.dist-info}/WHEEL +0 -0
  58. {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.0.dist-info}/entry_points.txt +0 -0
  59. {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.0.dist-info}/top_level.txt +0 -0
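
The most visible change in this release is the rename of the Azure Cognitive Search connector to Azure AI Search (items 25, 26, 28, 33, and 52 in the list above). Going only by the new integration test shipped in this diff, a v2 destination upload through the renamed connector might be wired up roughly as sketched below; the endpoint, index name, API key, and file paths are illustrative placeholders, not values taken from the package.

from pathlib import Path

from unstructured_ingest.v2.interfaces import FileData, SourceIdentifiers
from unstructured_ingest.v2.processes.connectors.azure_ai_search import (
    CONNECTOR_TYPE,
    AzureAISearchAccessConfig,
    AzureAISearchConnectionConfig,
    AzureAISearchUploader,
    AzureAISearchUploaderConfig,
    AzureAISearchUploadStager,
    AzureAISearchUploadStagerConfig,
)

# Placeholder inputs: a JSON file of partitioned elements and a directory to stage into.
elements_file = Path("example-output/elements.json")
file_data = FileData(
    source_identifiers=SourceIdentifiers(fullpath=elements_file.name, filename=elements_file.name),
    connector_type=CONNECTOR_TYPE,
    identifier="example-file",
)

# Stage the elements into the shape the index expects, then push them up.
stager = AzureAISearchUploadStager(upload_stager_config=AzureAISearchUploadStagerConfig())
staged_file = stager.run(
    elements_filepath=elements_file,
    file_data=file_data,
    output_dir=Path("staged-output"),
    output_filename=elements_file.name,
)
uploader = AzureAISearchUploader(
    connection_config=AzureAISearchConnectionConfig(
        access_config=AzureAISearchAccessConfig(key="<search-api-key>"),  # placeholder credential
        endpoint="https://<service-name>.search.windows.net",  # placeholder endpoint
        index="<index-name>",  # placeholder index
    ),
    upload_config=AzureAISearchUploaderConfig(),
)
uploader.precheck()  # verify the index is reachable before writing
uploader.run(path=staged_file, file_data=file_data)

The same stage-then-upload pattern (stager.run followed by uploader.precheck and uploader.run) also appears in the new AstraDB, Kafka, and Pinecone destination tests below.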
@@ -0,0 +1,109 @@
+ import json
+ import os
+ from dataclasses import dataclass
+ from pathlib import Path
+ from uuid import uuid4
+
+ import pytest
+ from astrapy import Collection
+ from astrapy import DataAPIClient as AstraDBClient
+
+ from test.integration.connectors.utils.constants import (
+     DESTINATION_TAG,
+ )
+ from test.integration.utils import requires_env
+ from unstructured_ingest.v2.interfaces import FileData, SourceIdentifiers
+ from unstructured_ingest.v2.processes.connectors.astradb import (
+     CONNECTOR_TYPE,
+     AstraDBAccessConfig,
+     AstraDBConnectionConfig,
+     AstraDBUploader,
+     AstraDBUploaderConfig,
+     AstraDBUploadStager,
+ )
+
+
+ @dataclass(frozen=True)
+ class EnvData:
+     api_endpoint: str
+     token: str
+
+
+ def get_env_data() -> EnvData:
+     return EnvData(
+         api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
+         token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
+     )
+
+
+ @pytest.fixture
+ def collection(upload_file: Path) -> Collection:
+     random_id = str(uuid4())[:8]
+     collection_name = f"utic_test_{random_id}"
+     with upload_file.open("r") as upload_fp:
+         upload_data = json.load(upload_fp)
+     first_content = upload_data[0]
+     embeddings = first_content["embeddings"]
+     embedding_dimension = len(embeddings)
+     my_client = AstraDBClient()
+     env_data = get_env_data()
+     astra_db = my_client.get_database(
+         api_endpoint=env_data.api_endpoint,
+         token=env_data.token,
+     )
+     collection = astra_db.create_collection(collection_name, dimension=embedding_dimension)
+     try:
+         yield collection
+     finally:
+         astra_db.drop_collection(collection)
+
+
+ @pytest.mark.asyncio
+ @pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
+ @requires_env("ASTRA_DB_API_ENDPOINT", "ASTRA_DB_APPLICATION_TOKEN")
+ async def test_astradb_destination(
+     upload_file: Path,
+     collection: Collection,
+     tmp_path: Path,
+ ):
+     file_data = FileData(
+         source_identifiers=SourceIdentifiers(fullpath=upload_file.name, filename=upload_file.name),
+         connector_type=CONNECTOR_TYPE,
+         identifier="mock file data",
+     )
+     stager = AstraDBUploadStager()
+     env_data = get_env_data()
+     uploader = AstraDBUploader(
+         connection_config=AstraDBConnectionConfig(
+             access_config=AstraDBAccessConfig(
+                 api_endpoint=env_data.api_endpoint, token=env_data.token
+             ),
+         ),
+         upload_config=AstraDBUploaderConfig(collection_name=collection.name),
+     )
+     staged_filepath = stager.run(
+         elements_filepath=upload_file,
+         file_data=file_data,
+         output_dir=tmp_path,
+         output_filename=upload_file.name,
+     )
+     uploader.precheck()
+     uploader.run(path=staged_filepath, file_data=file_data)
+
+     # Run validation
+     with staged_filepath.open() as f:
+         staged_elements = json.load(f)
+     expected_count = len(staged_elements)
+     current_count = collection.count_documents(filter={}, upper_bound=expected_count * 2)
+     assert current_count == expected_count, (
+         f"Expected count ({expected_count}) doesn't match how "
+         f"much came back from collection: {current_count}"
+     )
+
+     # Rerun and make sure the same documents get updated
+     uploader.run(path=staged_filepath, file_data=file_data)
+     current_count = collection.count_documents(filter={}, upper_bound=expected_count * 2)
+     assert current_count == expected_count, (
+         f"Expected count ({expected_count}) doesn't match how "
+         f"much came back from collection: {current_count}"
+     )
@@ -0,0 +1,233 @@
+ import json
+ import os
+ import time
+ from pathlib import Path
+ from uuid import uuid4
+
+ import pytest
+ from azure.core.credentials import AzureKeyCredential
+ from azure.search.documents import SearchClient
+ from azure.search.documents.indexes import SearchIndexClient
+ from azure.search.documents.indexes.models import (
+     ComplexField,
+     CorsOptions,
+     HnswAlgorithmConfiguration,
+     HnswParameters,
+     SearchField,
+     SearchFieldDataType,
+     SearchIndex,
+     SimpleField,
+     VectorSearch,
+     VectorSearchAlgorithmMetric,
+     VectorSearchProfile,
+ )
+
+ from test.integration.connectors.utils.constants import (
+     DESTINATION_TAG,
+ )
+ from test.integration.utils import requires_env
+ from unstructured_ingest.v2.interfaces import FileData, SourceIdentifiers
+ from unstructured_ingest.v2.processes.connectors.azure_ai_search import (
+     CONNECTOR_TYPE,
+     RECORD_ID_LABEL,
+     AzureAISearchAccessConfig,
+     AzureAISearchConnectionConfig,
+     AzureAISearchUploader,
+     AzureAISearchUploaderConfig,
+     AzureAISearchUploadStager,
+     AzureAISearchUploadStagerConfig,
+ )
+
+ repo_path = Path(__file__).parent.resolve()
+
+ API_KEY = "AZURE_SEARCH_API_KEY"
+ ENDPOINT = "https://ingest-test-azure-cognitive-search.search.windows.net"
+
+
+ def get_api_key() -> str:
+     key = os.environ[API_KEY]
+     return key
+
+
+ def get_fields() -> list:
+     data_source_fields = [
+         SimpleField(name="url", type=SearchFieldDataType.String),
+         SimpleField(name="version", type=SearchFieldDataType.String),
+         SimpleField(name="date_created", type=SearchFieldDataType.DateTimeOffset),
+         SimpleField(name="date_modified", type=SearchFieldDataType.DateTimeOffset),
+         SimpleField(name="date_processed", type=SearchFieldDataType.DateTimeOffset),
+         SimpleField(name="permissions_data", type=SearchFieldDataType.String),
+         SimpleField(name="record_locator", type=SearchFieldDataType.String),
+     ]
+     coordinates_fields = [
+         SimpleField(name="system", type=SearchFieldDataType.String),
+         SimpleField(name="layout_width", type=SearchFieldDataType.Double),
+         SimpleField(name="layout_height", type=SearchFieldDataType.Double),
+         SimpleField(name="points", type=SearchFieldDataType.String),
+     ]
+     metadata_fields = [
+         SimpleField(name="orig_elements", type=SearchFieldDataType.String),
+         SimpleField(name="category_depth", type=SearchFieldDataType.Int32),
+         SimpleField(name="parent_id", type=SearchFieldDataType.String),
+         SimpleField(name="attached_to_filename", type=SearchFieldDataType.String),
+         SimpleField(name="filetype", type=SearchFieldDataType.String),
+         SimpleField(name="last_modified", type=SearchFieldDataType.DateTimeOffset),
+         SimpleField(name="is_continuation", type=SearchFieldDataType.Boolean),
+         SimpleField(name="file_directory", type=SearchFieldDataType.String),
+         SimpleField(name="filename", type=SearchFieldDataType.String),
+         ComplexField(name="data_source", fields=data_source_fields),
+         ComplexField(name="coordinates", fields=coordinates_fields),
+         SimpleField(
+             name="languages", type=SearchFieldDataType.Collection(SearchFieldDataType.String)
+         ),
+         SimpleField(name="page_number", type=SearchFieldDataType.String),
+         SimpleField(name="links", type=SearchFieldDataType.Collection(SearchFieldDataType.String)),
+         SimpleField(name="page_name", type=SearchFieldDataType.String),
+         SimpleField(name="url", type=SearchFieldDataType.String),
+         SimpleField(
+             name="link_urls", type=SearchFieldDataType.Collection(SearchFieldDataType.String)
+         ),
+         SimpleField(
+             name="link_texts", type=SearchFieldDataType.Collection(SearchFieldDataType.String)
+         ),
+         SimpleField(
+             name="sent_from", type=SearchFieldDataType.Collection(SearchFieldDataType.String)
+         ),
+         SimpleField(
+             name="sent_to", type=SearchFieldDataType.Collection(SearchFieldDataType.String)
+         ),
+         SimpleField(name="subject", type=SearchFieldDataType.String),
+         SimpleField(name="section", type=SearchFieldDataType.String),
+         SimpleField(name="header_footer_type", type=SearchFieldDataType.String),
+         SimpleField(
+             name="emphasized_text_contents",
+             type=SearchFieldDataType.Collection(SearchFieldDataType.String),
+         ),
+         SimpleField(
+             name="emphasized_text_tags",
+             type=SearchFieldDataType.Collection(SearchFieldDataType.String),
+         ),
+         SimpleField(name="text_as_html", type=SearchFieldDataType.String),
+         SimpleField(name="regex_metadata", type=SearchFieldDataType.String),
+         SimpleField(name="detection_class_prob", type=SearchFieldDataType.Double),
+     ]
+     fields = [
+         SimpleField(name="id", type=SearchFieldDataType.String, key=True),
+         SimpleField(name=RECORD_ID_LABEL, type=SearchFieldDataType.String, filterable=True),
+         SimpleField(name="element_id", type=SearchFieldDataType.String),
+         SimpleField(name="text", type=SearchFieldDataType.String),
+         SimpleField(name="type", type=SearchFieldDataType.String),
+         ComplexField(name="metadata", fields=metadata_fields),
+         SearchField(
+             name="embeddings",
+             type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
+             vector_search_dimensions=384,
+             vector_search_profile_name="embeddings-config-profile",
+         ),
+     ]
+     return fields
+
+
+ def get_vector_search() -> VectorSearch:
+     return VectorSearch(
+         algorithms=[
+             HnswAlgorithmConfiguration(
+                 name="hnsw-config",
+                 parameters=HnswParameters(
+                     metric=VectorSearchAlgorithmMetric.COSINE,
+                 ),
+             )
+         ],
+         profiles=[
+             VectorSearchProfile(
+                 name="embeddings-config-profile", algorithm_configuration_name="hnsw-config"
+             )
+         ],
+     )
+
+
+ def get_search_index_client() -> SearchIndexClient:
+     api_key = get_api_key()
+     return SearchIndexClient(ENDPOINT, AzureKeyCredential(api_key))
+
+
+ @pytest.fixture
+ def index() -> str:
+     random_id = str(uuid4())[:8]
+     index_name = f"utic-test-{random_id}"
+     client = get_search_index_client()
+     index = SearchIndex(
+         name=index_name,
+         fields=get_fields(),
+         vector_search=get_vector_search(),
+         cors_options=CorsOptions(allowed_origins=["*"], max_age_in_seconds=60),
+     )
+     print(f"creating index: {index_name}")
+     client.create_index(index=index)
+     try:
+         yield index_name
+     finally:
+         print(f"deleting index: {index_name}")
+         client.delete_index(index)
+
+
+ def validate_count(
+     search_client: SearchClient, expected_count: int, retries: int = 10, interval: int = 1
+ ) -> None:
+     index_count = search_client.get_document_count()
+     if index_count == expected_count:
+         return
+     tries = 0
+     while tries < retries:
+         time.sleep(interval)
+         index_count = search_client.get_document_count()
+         if index_count == expected_count:
+             break
+         tries += 1
+     assert index_count == expected_count, (
+         f"Expected count ({expected_count}) doesn't match how "
+         f"much came back from index: {index_count}"
+     )
+
+
+ @pytest.mark.asyncio
+ @pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
+ @requires_env("AZURE_SEARCH_API_KEY")
+ async def test_azure_ai_search_destination(
+     upload_file: Path,
+     index: str,
+     tmp_path: Path,
+ ):
+     file_data = FileData(
+         source_identifiers=SourceIdentifiers(fullpath=upload_file.name, filename=upload_file.name),
+         connector_type=CONNECTOR_TYPE,
+         identifier="mock file data",
+     )
+     stager = AzureAISearchUploadStager(upload_stager_config=AzureAISearchUploadStagerConfig())
+
+     uploader = AzureAISearchUploader(
+         connection_config=AzureAISearchConnectionConfig(
+             access_config=AzureAISearchAccessConfig(key=get_api_key()),
+             endpoint=ENDPOINT,
+             index=index,
+         ),
+         upload_config=AzureAISearchUploaderConfig(),
+     )
+     staged_filepath = stager.run(
+         elements_filepath=upload_file,
+         file_data=file_data,
+         output_dir=tmp_path,
+         output_filename=upload_file.name,
+     )
+     uploader.precheck()
+     uploader.run(path=staged_filepath, file_data=file_data)
+
+     # Run validation
+     with staged_filepath.open() as f:
+         staged_elements = json.load(f)
+     expected_count = len(staged_elements)
+     search_client: SearchClient = uploader.connection_config.get_search_client()
+     validate_count(search_client=search_client, expected_count=expected_count)
+
+     # Rerun and make sure the same documents get updated
+     uploader.run(path=staged_filepath, file_data=file_data)
+     validate_count(search_client=search_client, expected_count=expected_count)
@@ -1,11 +1,13 @@
- import socket
+ import json
  import tempfile
  from pathlib import Path

  import pytest
- from confluent_kafka import Producer
+ from confluent_kafka import Consumer, KafkaError, KafkaException, Producer
+ from confluent_kafka.admin import AdminClient, NewTopic

  from test.integration.connectors.utils.constants import (
+     DESTINATION_TAG,
      SOURCE_TAG,
      env_setup_path,
  )
@@ -14,6 +16,8 @@ from test.integration.connectors.utils.validation import (
      ValidationConfigs,
      source_connector_validation,
  )
+ from unstructured_ingest.error import DestinationConnectionError, SourceConnectionError
+ from unstructured_ingest.v2.interfaces import FileData, SourceIdentifiers
  from unstructured_ingest.v2.processes.connectors.kafka.local import (
      CONNECTOR_TYPE,
      LocalKafkaConnectionConfig,
@@ -21,6 +25,8 @@ from unstructured_ingest.v2.processes.connectors.kafka.local import (
      LocalKafkaDownloaderConfig,
      LocalKafkaIndexer,
      LocalKafkaIndexerConfig,
+     LocalKafkaUploader,
+     LocalKafkaUploaderConfig,
  )

  SEED_MESSAGES = 10
@@ -28,20 +34,33 @@ TOPIC = "fake-topic"


  @pytest.fixture
- def kafka_seed_topic() -> str:
-     with docker_compose_context(docker_compose_path=env_setup_path / "kafka"):
-         conf = {
-             "bootstrap.servers": "localhost:29092",
-             "client.id": socket.gethostname(),
-             "message.max.bytes": 10485760,
-         }
-         producer = Producer(conf)
-         for i in range(SEED_MESSAGES):
-             message = f"This is some text for message {i}"
-             producer.produce(topic=TOPIC, value=message)
-         producer.flush(timeout=10)
-         print(f"kafka topic {TOPIC} seeded with {SEED_MESSAGES} messages")
-         yield TOPIC
+ def docker_compose_ctx():
+     with docker_compose_context(docker_compose_path=env_setup_path / "kafka") as ctx:
+         yield ctx
+
+
+ @pytest.fixture
+ def kafka_seed_topic(docker_compose_ctx) -> str:
+     conf = {
+         "bootstrap.servers": "localhost:29092",
+     }
+     producer = Producer(conf)
+     for i in range(SEED_MESSAGES):
+         message = f"This is some text for message {i}"
+         producer.produce(topic=TOPIC, value=message)
+     producer.flush(timeout=10)
+     print(f"kafka topic {TOPIC} seeded with {SEED_MESSAGES} messages")
+     return TOPIC
+
+
+ @pytest.fixture
+ def kafka_upload_topic(docker_compose_ctx) -> str:
+     conf = {
+         "bootstrap.servers": "localhost:29092",
+     }
+     admin_client = AdminClient(conf)
+     admin_client.create_topics([NewTopic(TOPIC, 1, 1)])
+     return TOPIC


  @pytest.mark.asyncio
@@ -58,6 +77,7 @@ async def test_kafka_source_local(kafka_seed_topic: str):
      downloader = LocalKafkaDownloader(
          connection_config=connection_config, download_config=download_config
      )
+     indexer.precheck()
      await source_connector_validation(
          indexer=indexer,
          downloader=downloader,
@@ -65,3 +85,83 @@ async def test_kafka_source_local(kafka_seed_topic: str):
              test_id="kafka", expected_num_files=5, validate_downloaded_files=True
          ),
      )
+
+
+ @pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG)
+ def test_kafka_source_local_precheck_fail():
+     connection_config = LocalKafkaConnectionConfig(bootstrap_server="localhost", port=29092)
+     indexer = LocalKafkaIndexer(
+         connection_config=connection_config,
+         index_config=LocalKafkaIndexerConfig(topic=TOPIC, num_messages_to_consume=5),
+     )
+     with pytest.raises(SourceConnectionError):
+         indexer.precheck()
+
+
+ def get_all_messages(topic: str, max_empty_messages: int = 5) -> list[dict]:
+     conf = {
+         "bootstrap.servers": "localhost:29092",
+         "group.id": "default_group_id",
+         "enable.auto.commit": "false",
+         "auto.offset.reset": "earliest",
+     }
+     consumer = Consumer(conf)
+     consumer.subscribe([topic])
+     messages = []
+     try:
+         empty_count = 0
+         while empty_count < max_empty_messages:
+             msg = consumer.poll(timeout=1)
+             if msg is None:
+                 empty_count += 1
+                 continue
+             if msg.error():
+                 if msg.error().code() == KafkaError._PARTITION_EOF:
+                     break
+                 else:
+                     raise KafkaException(msg.error())
+             try:
+                 message = json.loads(msg.value().decode("utf8"))
+                 messages.append(message)
+             finally:
+                 consumer.commit(asynchronous=False)
+     finally:
+         print("closing consumer")
+         consumer.close()
+     return messages
+
+
+ @pytest.mark.asyncio
+ @pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
+ async def test_kafka_destination_local(upload_file: Path, kafka_upload_topic: str):
+     uploader = LocalKafkaUploader(
+         connection_config=LocalKafkaConnectionConfig(bootstrap_server="localhost", port=29092),
+         upload_config=LocalKafkaUploaderConfig(topic=TOPIC, batch_size=10),
+     )
+     file_data = FileData(
+         source_identifiers=SourceIdentifiers(fullpath=upload_file.name, filename=upload_file.name),
+         connector_type=CONNECTOR_TYPE,
+         identifier="mock file data",
+     )
+     uploader.precheck()
+     if uploader.is_async():
+         await uploader.run_async(path=upload_file, file_data=file_data)
+     else:
+         uploader.run(path=upload_file, file_data=file_data)
+     all_messages = get_all_messages(topic=kafka_upload_topic)
+     with upload_file.open("r") as upload_fs:
+         content_to_upload = json.load(upload_fs)
+     assert len(all_messages) == len(content_to_upload), (
+         f"expected number of messages ({len(content_to_upload)}) doesn't "
+         f"match how many messages read off of kafka topic {kafka_upload_topic}: {len(all_messages)}"
+     )
+
+
+ @pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
+ def test_kafka_destination_local_precheck_fail():
+     uploader = LocalKafkaUploader(
+         connection_config=LocalKafkaConnectionConfig(bootstrap_server="localhost", port=29092),
+         upload_config=LocalKafkaUploaderConfig(topic=TOPIC, batch_size=10),
+     )
+     with pytest.raises(DestinationConnectionError):
+         uploader.precheck()
@@ -0,0 +1,161 @@
+ import json
+ import os
+ import time
+ from pathlib import Path
+ from uuid import uuid4
+
+ import pytest
+ from pinecone import Pinecone, ServerlessSpec
+ from pinecone.core.openapi.shared.exceptions import NotFoundException
+
+ from test.integration.connectors.utils.constants import (
+     DESTINATION_TAG,
+ )
+ from test.integration.utils import requires_env
+ from unstructured_ingest.v2.interfaces import FileData, SourceIdentifiers
+ from unstructured_ingest.v2.logger import logger
+ from unstructured_ingest.v2.processes.connectors.pinecone import (
+     CONNECTOR_TYPE,
+     PineconeAccessConfig,
+     PineconeConnectionConfig,
+     PineconeUploader,
+     PineconeUploaderConfig,
+     PineconeUploadStager,
+     PineconeUploadStagerConfig,
+ )
+
+ API_KEY = "PINECONE_API_KEY"
+
+
+ def get_api_key() -> str:
+     api_key = os.getenv(API_KEY, None)
+     assert api_key
+     return api_key
+
+
+ def wait_for_delete(client: Pinecone, index_name: str, timeout=60, interval=1) -> None:
+     start = time.time()
+     while True and time.time() - start < timeout:
+         try:
+             description = client.describe_index(name=index_name)
+             logger.info(f"current index status: {description}")
+         except NotFoundException:
+             return
+         time.sleep(interval)
+
+     raise TimeoutError("time out waiting for index to delete")
+
+
+ def wait_for_ready(client: Pinecone, index_name: str, timeout=60, interval=1) -> None:
+     def is_ready_status():
+         description = client.describe_index(name=index_name)
+         status = description["status"]
+         return status["ready"]
+
+     start = time.time()
+     is_ready = is_ready_status()
+     while not is_ready and time.time() - start < timeout:
+         time.sleep(interval)
+         is_ready = is_ready_status()
+     if not is_ready:
+         raise TimeoutError("time out waiting for index to be ready")
+
+
+ @pytest.fixture
+ def pinecone_index() -> str:
+     pinecone = Pinecone(api_key=get_api_key())
+     random_id = str(uuid4()).split("-")[0]
+     index_name = f"ingest-test-{random_id}"
+     assert len(index_name) < 45
+     logger.info(f"Creating index: {index_name}")
+     try:
+         pinecone.create_index(
+             name=index_name,
+             dimension=384,
+             metric="cosine",
+             spec=ServerlessSpec(
+                 cloud="aws",
+                 region="us-east-1",
+             ),
+             deletion_protection="disabled",
+         )
+         wait_for_ready(client=pinecone, index_name=index_name)
+         yield index_name
+     except Exception as e:
+         logger.error(f"failed to create index {index_name}: {e}")
+     finally:
+         try:
+             logger.info(f"deleting index: {index_name}")
+             pinecone.delete_index(name=index_name)
+             wait_for_delete(client=pinecone, index_name=index_name)
+         except NotFoundException:
+             return
+
+
+ def validate_pinecone_index(
+     index_name: str, expected_num_of_vectors: int, retries=30, interval=1
+ ) -> None:
+     # Because there's a delay for the index to catch up to the recent writes, add in a retry
+     pinecone = Pinecone(api_key=get_api_key())
+     index = pinecone.Index(name=index_name)
+     vector_count = -1
+     for i in range(retries):
+         index_stats = index.describe_index_stats()
+         vector_count = index_stats["total_vector_count"]
+         if vector_count == expected_num_of_vectors:
+             logger.info(f"expected {expected_num_of_vectors} == vector count {vector_count}")
+             break
+         logger.info(
+             f"retry attempt {i}: expected {expected_num_of_vectors} != vector count {vector_count}"
+         )
+         time.sleep(interval)
+     assert vector_count == expected_num_of_vectors
+
+
+ @requires_env(API_KEY)
+ @pytest.mark.asyncio
+ @pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG)
+ async def test_pinecone_destination(pinecone_index: str, upload_file: Path, temp_dir: Path):
+     file_data = FileData(
+         source_identifiers=SourceIdentifiers(fullpath=upload_file.name, filename=upload_file.name),
+         connector_type=CONNECTOR_TYPE,
+         identifier="pinecone_mock_id",
+     )
+     connection_config = PineconeConnectionConfig(
+         index_name=pinecone_index,
+         access_config=PineconeAccessConfig(api_key=get_api_key()),
+     )
+     stager_config = PineconeUploadStagerConfig()
+     stager = PineconeUploadStager(upload_stager_config=stager_config)
+     new_upload_file = stager.run(
+         elements_filepath=upload_file,
+         output_dir=temp_dir,
+         output_filename=upload_file.name,
+         file_data=file_data,
+     )
+
+     upload_config = PineconeUploaderConfig()
+     uploader = PineconeUploader(connection_config=connection_config, upload_config=upload_config)
+     uploader.precheck()
+
+     if uploader.is_async():
+         await uploader.run_async(path=new_upload_file, file_data=file_data)
+     else:
+         uploader.run(path=new_upload_file, file_data=file_data)
+     with new_upload_file.open() as f:
+         staged_content = json.load(f)
+     expected_num_of_vectors = len(staged_content)
+     logger.info("validating first upload")
+     validate_pinecone_index(
+         index_name=pinecone_index, expected_num_of_vectors=expected_num_of_vectors
+     )
+
+     # Rerun uploader and make sure no duplicates exist
+     if uploader.is_async():
+         await uploader.run_async(path=new_upload_file, file_data=file_data)
+     else:
+         uploader.run(path=new_upload_file, file_data=file_data)
+     logger.info("validating second upload")
+     validate_pinecone_index(
+         index_name=pinecone_index, expected_num_of_vectors=expected_num_of_vectors
+     )