unstructured-ingest 0.2.2__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of unstructured-ingest might be problematic.

Files changed (93)
  1. test/integration/connectors/elasticsearch/__init__.py +0 -0
  2. test/integration/connectors/elasticsearch/conftest.py +34 -0
  3. test/integration/connectors/elasticsearch/test_elasticsearch.py +308 -0
  4. test/integration/connectors/elasticsearch/test_opensearch.py +302 -0
  5. test/integration/connectors/sql/test_postgres.py +10 -4
  6. test/integration/connectors/sql/test_singlestore.py +8 -4
  7. test/integration/connectors/sql/test_snowflake.py +10 -6
  8. test/integration/connectors/sql/test_sqlite.py +4 -4
  9. test/integration/connectors/test_astradb.py +156 -0
  10. test/integration/connectors/test_azure_cog_search.py +233 -0
  11. test/integration/connectors/test_delta_table.py +46 -0
  12. test/integration/connectors/test_kafka.py +150 -16
  13. test/integration/connectors/test_lancedb.py +209 -0
  14. test/integration/connectors/test_milvus.py +141 -0
  15. test/integration/connectors/test_pinecone.py +213 -0
  16. test/integration/connectors/test_s3.py +23 -0
  17. test/integration/connectors/utils/docker.py +81 -15
  18. test/integration/connectors/utils/validation.py +10 -0
  19. test/integration/connectors/weaviate/__init__.py +0 -0
  20. test/integration/connectors/weaviate/conftest.py +15 -0
  21. test/integration/connectors/weaviate/test_local.py +131 -0
  22. test/unit/v2/__init__.py +0 -0
  23. test/unit/v2/chunkers/__init__.py +0 -0
  24. test/unit/v2/chunkers/test_chunkers.py +49 -0
  25. test/unit/v2/connectors/__init__.py +0 -0
  26. test/unit/v2/embedders/__init__.py +0 -0
  27. test/unit/v2/embedders/test_bedrock.py +36 -0
  28. test/unit/v2/embedders/test_huggingface.py +48 -0
  29. test/unit/v2/embedders/test_mixedbread.py +37 -0
  30. test/unit/v2/embedders/test_octoai.py +35 -0
  31. test/unit/v2/embedders/test_openai.py +35 -0
  32. test/unit/v2/embedders/test_togetherai.py +37 -0
  33. test/unit/v2/embedders/test_vertexai.py +37 -0
  34. test/unit/v2/embedders/test_voyageai.py +38 -0
  35. test/unit/v2/partitioners/__init__.py +0 -0
  36. test/unit/v2/partitioners/test_partitioner.py +63 -0
  37. test/unit/v2/utils/__init__.py +0 -0
  38. test/unit/v2/utils/data_generator.py +32 -0
  39. unstructured_ingest/__version__.py +1 -1
  40. unstructured_ingest/cli/cmds/__init__.py +2 -2
  41. unstructured_ingest/cli/cmds/{azure_cognitive_search.py → azure_ai_search.py} +9 -9
  42. unstructured_ingest/connector/{azure_cognitive_search.py → azure_ai_search.py} +9 -9
  43. unstructured_ingest/pipeline/reformat/embedding.py +1 -1
  44. unstructured_ingest/runner/writers/__init__.py +2 -2
  45. unstructured_ingest/runner/writers/azure_ai_search.py +24 -0
  46. unstructured_ingest/utils/data_prep.py +9 -1
  47. unstructured_ingest/v2/constants.py +2 -0
  48. unstructured_ingest/v2/processes/connectors/__init__.py +7 -20
  49. unstructured_ingest/v2/processes/connectors/airtable.py +2 -2
  50. unstructured_ingest/v2/processes/connectors/astradb.py +35 -23
  51. unstructured_ingest/v2/processes/connectors/{azure_cognitive_search.py → azure_ai_search.py} +116 -35
  52. unstructured_ingest/v2/processes/connectors/confluence.py +2 -2
  53. unstructured_ingest/v2/processes/connectors/couchbase.py +1 -0
  54. unstructured_ingest/v2/processes/connectors/delta_table.py +37 -9
  55. unstructured_ingest/v2/processes/connectors/elasticsearch/__init__.py +19 -0
  56. unstructured_ingest/v2/processes/connectors/{elasticsearch.py → elasticsearch/elasticsearch.py} +93 -46
  57. unstructured_ingest/v2/processes/connectors/{opensearch.py → elasticsearch/opensearch.py} +1 -1
  58. unstructured_ingest/v2/processes/connectors/fsspec/fsspec.py +27 -0
  59. unstructured_ingest/v2/processes/connectors/google_drive.py +3 -3
  60. unstructured_ingest/v2/processes/connectors/kafka/__init__.py +6 -2
  61. unstructured_ingest/v2/processes/connectors/kafka/cloud.py +38 -2
  62. unstructured_ingest/v2/processes/connectors/kafka/kafka.py +84 -23
  63. unstructured_ingest/v2/processes/connectors/kafka/local.py +32 -4
  64. unstructured_ingest/v2/processes/connectors/lancedb/__init__.py +17 -0
  65. unstructured_ingest/v2/processes/connectors/lancedb/aws.py +43 -0
  66. unstructured_ingest/v2/processes/connectors/lancedb/azure.py +43 -0
  67. unstructured_ingest/v2/processes/connectors/lancedb/gcp.py +44 -0
  68. unstructured_ingest/v2/processes/connectors/lancedb/lancedb.py +161 -0
  69. unstructured_ingest/v2/processes/connectors/lancedb/local.py +44 -0
  70. unstructured_ingest/v2/processes/connectors/milvus.py +72 -27
  71. unstructured_ingest/v2/processes/connectors/onedrive.py +2 -3
  72. unstructured_ingest/v2/processes/connectors/outlook.py +2 -2
  73. unstructured_ingest/v2/processes/connectors/pinecone.py +101 -13
  74. unstructured_ingest/v2/processes/connectors/sharepoint.py +3 -2
  75. unstructured_ingest/v2/processes/connectors/slack.py +2 -2
  76. unstructured_ingest/v2/processes/connectors/sql/postgres.py +16 -8
  77. unstructured_ingest/v2/processes/connectors/sql/sql.py +97 -26
  78. unstructured_ingest/v2/processes/connectors/weaviate/__init__.py +22 -0
  79. unstructured_ingest/v2/processes/connectors/weaviate/cloud.py +164 -0
  80. unstructured_ingest/v2/processes/connectors/weaviate/embedded.py +90 -0
  81. unstructured_ingest/v2/processes/connectors/weaviate/local.py +73 -0
  82. unstructured_ingest/v2/processes/connectors/weaviate/weaviate.py +289 -0
  83. {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/METADATA +20 -19
  84. {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/RECORD +91 -50
  85. unstructured_ingest/runner/writers/azure_cognitive_search.py +0 -24
  86. unstructured_ingest/v2/processes/connectors/weaviate.py +0 -242
  87. /test/integration/embedders/{togetherai.py → test_togetherai.py} +0 -0
  88. /test/unit/{test_interfaces_v2.py → v2/test_interfaces.py} +0 -0
  89. /test/unit/{test_utils_v2.py → v2/test_utils.py} +0 -0
  90. {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/LICENSE.md +0 -0
  91. {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/WHEEL +0 -0
  92. {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/entry_points.txt +0 -0
  93. {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/top_level.txt +0 -0
unstructured_ingest/v2/processes/connectors/{elasticsearch.py → elasticsearch/elasticsearch.py}

@@ -2,6 +2,7 @@ import hashlib
  import json
  import sys
  import uuid
+ from contextlib import contextmanager
  from dataclasses import dataclass, field
  from pathlib import Path
  from time import time
@@ -13,9 +14,11 @@ from unstructured_ingest.error import (
  DestinationConnectionError,
  SourceConnectionError,
  SourceConnectionNetworkError,
+ WriteError,
  )
  from unstructured_ingest.utils.data_prep import flatten_dict, generator_batching_wbytes
  from unstructured_ingest.utils.dep_check import requires_dependencies
+ from unstructured_ingest.v2.constants import RECORD_ID_LABEL
  from unstructured_ingest.v2.interfaces import (
  AccessConfig,
  ConnectionConfig,
@@ -26,6 +29,7 @@ from unstructured_ingest.v2.interfaces import (
  FileDataSourceMetadata,
  Indexer,
  IndexerConfig,
+ SourceIdentifiers,
  Uploader,
  UploaderConfig,
  UploadStager,
@@ -116,19 +120,12 @@ class ElasticsearchConnectionConfig(ConnectionConfig):
  return client_kwargs

  @requires_dependencies(["elasticsearch"], extras="elasticsearch")
- def get_client(self) -> "ElasticsearchClient":
+ @contextmanager
+ def get_client(self) -> Generator["ElasticsearchClient", None, None]:
  from elasticsearch import Elasticsearch as ElasticsearchClient

- client = ElasticsearchClient(**self.get_client_kwargs())
- self.check_connection(client=client)
- return client
-
- def check_connection(self, client: "ElasticsearchClient"):
- try:
- client.perform_request("HEAD", "/", headers={"accept": "application/json"})
- except Exception as e:
- logger.error(f"failed to validate connection: {e}", exc_info=True)
- raise SourceConnectionError(f"failed to validate connection: {e}")
+ with ElasticsearchClient(**self.get_client_kwargs()) as client:
+ yield client


  class ElasticsearchIndexerConfig(IndexerConfig):
@@ -144,7 +141,16 @@ class ElasticsearchIndexer(Indexer):

  def precheck(self) -> None:
  try:
- self.connection_config.get_client()
+ with self.connection_config.get_client() as client:
+ if not client.ping():
+ raise SourceConnectionError("cluster not detected")
+ indices = client.indices.get_alias(index="*")
+ if self.index_config.index_name not in indices:
+ raise SourceConnectionError(
+ "index {} not found: {}".format(
+ self.index_config.index_name, ", ".join(indices.keys())
+ )
+ )
  except Exception as e:
  logger.error(f"failed to validate connection: {e}", exc_info=True)
  raise SourceConnectionError(f"failed to validate connection: {e}")
@@ -160,15 +166,15 @@ class ElasticsearchIndexer(Indexer):
  scan = self.load_scan()

  scan_query: dict = {"stored_fields": [], "query": {"match_all": {}}}
- client = self.connection_config.get_client()
- hits = scan(
- client,
- query=scan_query,
- scroll="1m",
- index=self.index_config.index_name,
- )
+ with self.connection_config.get_client() as client:
+ hits = scan(
+ client,
+ query=scan_query,
+ scroll="1m",
+ index=self.index_config.index_name,
+ )

- return {hit["_id"] for hit in hits}
+ return {hit["_id"] for hit in hits}

  def run(self, **kwargs: Any) -> Generator[FileData, None, None]:
  all_ids = self._get_doc_ids()
@@ -191,6 +197,7 @@ class ElasticsearchIndexer(Indexer):
  yield FileData(
  identifier=identified,
  connector_type=CONNECTOR_TYPE,
+ doc_type="batch",
  metadata=FileDataSourceMetadata(
  url=f"{self.connection_config.hosts[0]}/{self.index_config.index_name}",
  date_processed=str(time()),
@@ -256,6 +263,7 @@ class ElasticsearchDownloader(Downloader):
  file_data=FileData(
  identifier=filename_id,
  connector_type=CONNECTOR_TYPE,
+ source_identifiers=SourceIdentifiers(filename=filename, fullpath=filename),
  metadata=FileDataSourceMetadata(
  version=str(result["_version"]) if "_version" in result else None,
  date_processed=str(time()),
@@ -317,7 +325,7 @@ class ElasticsearchUploadStagerConfig(UploadStagerConfig):
  class ElasticsearchUploadStager(UploadStager):
  upload_stager_config: ElasticsearchUploadStagerConfig

- def conform_dict(self, data: dict) -> dict:
+ def conform_dict(self, data: dict, file_data: FileData) -> dict:
  resp = {
  "_index": self.upload_stager_config.index_name,
  "_id": str(uuid.uuid4()),
@@ -326,6 +334,7 @@ class ElasticsearchUploadStager(UploadStager):
  "embeddings": data.pop("embeddings", None),
  "text": data.pop("text", None),
  "type": data.pop("type", None),
+ RECORD_ID_LABEL: file_data.identifier,
  },
  }
  if "metadata" in data and isinstance(data["metadata"], dict):
@@ -342,10 +351,17 @@
  ) -> Path:
  with open(elements_filepath) as elements_file:
  elements_contents = json.load(elements_file)
- conformed_elements = [self.conform_dict(data=element) for element in elements_contents]
- output_path = Path(output_dir) / Path(f"{output_filename}.json")
+ conformed_elements = [
+ self.conform_dict(data=element, file_data=file_data) for element in elements_contents
+ ]
+ if Path(output_filename).suffix != ".json":
+ output_filename = f"{output_filename}.json"
+ else:
+ output_filename = f"{Path(output_filename).stem}.json"
+ output_path = Path(output_dir) / output_filename
+ output_path.parent.mkdir(parents=True, exist_ok=True)
  with open(output_path, "w") as output_file:
- json.dump(conformed_elements, output_file)
+ json.dump(conformed_elements, output_file, indent=2)
  return output_path


@@ -362,6 +378,10 @@ class ElasticsearchUploaderConfig(UploaderConfig):
  num_threads: int = Field(
  default=4, description="Number of threads to be used while uploading content"
  )
+ record_id_key: str = Field(
+ default=RECORD_ID_LABEL,
+ description="searchable key to find entries for the same record on previous runs",
+ )


  @dataclass
@@ -372,7 +392,16 @@

  def precheck(self) -> None:
  try:
- self.connection_config.get_client()
+ with self.connection_config.get_client() as client:
+ if not client.ping():
+ raise DestinationConnectionError("cluster not detected")
+ indices = client.indices.get_alias(index="*")
+ if self.upload_config.index_name not in indices:
+ raise SourceConnectionError(
+ "index {} not found: {}".format(
+ self.upload_config.index_name, ", ".join(indices.keys())
+ )
+ )
  except Exception as e:
  logger.error(f"failed to validate connection: {e}", exc_info=True)
  raise DestinationConnectionError(f"failed to validate connection: {e}")
@@ -383,6 +412,23 @@

  return parallel_bulk

+ def delete_by_record_id(self, client, file_data: FileData) -> None:
+ logger.debug(
+ f"deleting any content with metadata {RECORD_ID_LABEL}={file_data.identifier} "
+ f"from {self.upload_config.index_name} index"
+ )
+ delete_resp = client.delete_by_query(
+ index=self.upload_config.index_name,
+ body={"query": {"match": {self.upload_config.record_id_key: file_data.identifier}}},
+ )
+ logger.info(
+ "deleted {} records from index {}".format(
+ delete_resp["deleted"], self.upload_config.index_name
+ )
+ )
+ if failures := delete_resp.get("failures"):
+ raise WriteError(f"failed to delete records: {failures}")
+
  def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
  parallel_bulk = self.load_parallel_bulk()
  with path.open("r") as file:
@@ -396,28 +442,29 @@
  f"{self.upload_config.num_threads} (number of) threads"
  )

- client = self.connection_config.get_client()
- if not client.indices.exists(index=self.upload_config.index_name):
- logger.warning(
- f"{(self.__class__.__name__).replace('Uploader', '')} index does not exist: "
- f"{self.upload_config.index_name}. "
- f"This may cause issues when uploading."
- )
- for batch in generator_batching_wbytes(
- elements_dict, batch_size_limit_bytes=self.upload_config.batch_size_bytes
- ):
- for success, info in parallel_bulk(
- client=client,
- actions=batch,
- thread_count=self.upload_config.num_threads,
+ with self.connection_config.get_client() as client:
+ self.delete_by_record_id(client=client, file_data=file_data)
+ if not client.indices.exists(index=self.upload_config.index_name):
+ logger.warning(
+ f"{(self.__class__.__name__).replace('Uploader', '')} index does not exist: "
+ f"{self.upload_config.index_name}. "
+ f"This may cause issues when uploading."
+ )
+ for batch in generator_batching_wbytes(
+ elements_dict, batch_size_limit_bytes=self.upload_config.batch_size_bytes
  ):
- if not success:
- logger.error(
- "upload failed for a batch in "
- f"{(self.__class__.__name__).replace('Uploader', '')} "
- "destination connector:",
- info,
- )
+ for success, info in parallel_bulk(
+ client=client,
+ actions=batch,
+ thread_count=self.upload_config.num_threads,
+ ):
+ if not success:
+ logger.error(
+ "upload failed for a batch in "
+ f"{(self.__class__.__name__).replace('Uploader', '')} "
+ "destination connector:",
+ info,
+ )


  elasticsearch_source_entry = SourceRegistryEntry(
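
The main behavioral change in this file is that ElasticsearchConnectionConfig.get_client() is now a context manager, and both prechecks ping the cluster and verify that the target index exists. Below is a minimal sketch of that pattern using the elasticsearch client directly; the localhost URL and the "ingest-test" index name are placeholders, not values from the diff.

    from contextlib import contextmanager
    from typing import Generator

    from elasticsearch import Elasticsearch


    @contextmanager
    def get_client() -> Generator[Elasticsearch, None, None]:
        # Same shape as the reworked get_client(): open the client lazily and
        # close it automatically when the with-block exits.
        with Elasticsearch(hosts=["http://localhost:9200"]) as client:  # placeholder host
            yield client


    with get_client() as client:
        # Mirrors the checks added to precheck(): ping the cluster, then confirm
        # the target index is present among the cluster's indices.
        if not client.ping():
            raise RuntimeError("cluster not detected")
        indices = client.indices.get_alias(index="*")
        if "ingest-test" not in indices:  # placeholder index name
            raise RuntimeError("index ingest-test not found: " + ", ".join(indices))

The same client handling underpins the new upsert behavior: because every staged element now carries RECORD_ID_LABEL, the uploader's delete_by_record_id() can delete_by_query any documents left over from a previous run of the same record before re-indexing the new batch.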
unstructured_ingest/v2/processes/connectors/{opensearch.py → elasticsearch/opensearch.py}

@@ -17,7 +17,7 @@ from unstructured_ingest.v2.processes.connector_registry import (
  DestinationRegistryEntry,
  SourceRegistryEntry,
  )
- from unstructured_ingest.v2.processes.connectors.elasticsearch import (
+ from unstructured_ingest.v2.processes.connectors.elasticsearch.elasticsearch import (
  ElasticsearchDownloader,
  ElasticsearchDownloaderConfig,
  ElasticsearchIndexer,
unstructured_ingest/v2/processes/connectors/fsspec/fsspec.py

@@ -1,6 +1,9 @@
  from __future__ import annotations

+ import os
  import random
+ import shutil
+ import tempfile
  from dataclasses import dataclass, field
  from pathlib import Path
  from typing import TYPE_CHECKING, Any, Generator, Optional, TypeVar
@@ -207,12 +210,35 @@ class FsspecDownloader(Downloader):
  **self.connection_config.get_access_config(),
  )

+ def handle_directory_download(self, lpath: Path) -> None:
+ # If the object's name contains certain characters (i.e. '?'), it
+ # gets downloaded into a new directory of the same name. This
+ # reconciles that with what is expected, which is to download it
+ # as a file that is not within a directory.
+ if not lpath.is_dir():
+ return
+ desired_name = lpath.name
+ files_in_dir = [file for file in lpath.iterdir() if file.is_file()]
+ if not files_in_dir:
+ raise ValueError(f"no files in {lpath}")
+ if len(files_in_dir) > 1:
+ raise ValueError(
+ "Multiple files in {}: {}".format(lpath, ", ".join([str(f) for f in files_in_dir]))
+ )
+ file = files_in_dir[0]
+ with tempfile.TemporaryDirectory() as temp_dir:
+ temp_location = os.path.join(temp_dir, desired_name)
+ shutil.copyfile(src=file, dst=temp_location)
+ shutil.rmtree(lpath)
+ shutil.move(src=temp_location, dst=lpath)
+
  def run(self, file_data: FileData, **kwargs: Any) -> DownloadResponse:
  download_path = self.get_download_path(file_data=file_data)
  download_path.parent.mkdir(parents=True, exist_ok=True)
  try:
  rpath = file_data.additional_metadata["original_file_path"]
  self.fs.get(rpath=rpath, lpath=download_path.as_posix())
+ self.handle_directory_download(lpath=download_path)
  except Exception as e:
  logger.error(f"failed to download file {file_data.identifier}: {e}", exc_info=True)
  raise SourceConnectionNetworkError(f"failed to download file {file_data.identifier}")
@@ -224,6 +250,7 @@ class FsspecDownloader(Downloader):
  try:
  rpath = file_data.additional_metadata["original_file_path"]
  await self.fs.get(rpath=rpath, lpath=download_path.as_posix())
+ self.handle_directory_download(lpath=download_path)
  except Exception as e:
  logger.error(f"failed to download file {file_data.identifier}: {e}", exc_info=True)
  raise SourceConnectionNetworkError(f"failed to download file {file_data.identifier}")
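
The new handle_directory_download() hook exists because fsspec can materialize an object whose name contains characters such as '?' as a directory wrapping a single file. The standalone sketch below reproduces that reconciliation outside the connector; the helper name flatten_single_file_dir and the "report?.pdf" layout are made up for illustration and assume a POSIX filesystem where '?' is a legal filename character.

    import shutil
    import tempfile
    from pathlib import Path


    def flatten_single_file_dir(lpath: Path) -> None:
        # Same idea as handle_directory_download(): if the download landed as a
        # directory containing exactly one file, replace the directory with that file.
        if not lpath.is_dir():
            return
        files_in_dir = [f for f in lpath.iterdir() if f.is_file()]
        if len(files_in_dir) != 1:
            raise ValueError(f"expected exactly one file in {lpath}, found {len(files_in_dir)}")
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_location = Path(temp_dir) / lpath.name
            shutil.copyfile(files_in_dir[0], temp_location)
            shutil.rmtree(lpath)
            shutil.move(str(temp_location), str(lpath))


    # Illustrative scenario: "report?.pdf" was written as a directory holding the real file.
    workdir = Path(tempfile.mkdtemp())
    odd_path = workdir / "report?.pdf"
    odd_path.mkdir(parents=True)
    (odd_path / "report?.pdf").write_text("file contents")
    flatten_single_file_dir(odd_path)
    assert odd_path.is_file()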
unstructured_ingest/v2/processes/connectors/google_drive.py

@@ -19,12 +19,12 @@ from unstructured_ingest.v2.interfaces import (
  ConnectionConfig,
  Downloader,
  DownloaderConfig,
+ DownloadResponse,
  FileData,
  FileDataSourceMetadata,
  Indexer,
  IndexerConfig,
  SourceIdentifiers,
- download_responses,
  )
  from unstructured_ingest.v2.logger import logger
  from unstructured_ingest.v2.processes.connector_registry import SourceRegistryEntry
@@ -294,7 +294,7 @@ class GoogleDriveDownloader(Downloader):
  _, downloaded = downloader.next_chunk()
  return downloaded

- def _write_file(self, file_data: FileData, file_contents: io.BytesIO):
+ def _write_file(self, file_data: FileData, file_contents: io.BytesIO) -> DownloadResponse:
  download_path = self.get_download_path(file_data=file_data)
  download_path.parent.mkdir(parents=True, exist_ok=True)
  logger.debug(f"writing {file_data.source_identifiers.fullpath} to {download_path}")
@@ -303,7 +303,7 @@
  return self.generate_download_response(file_data=file_data, download_path=download_path)

  @requires_dependencies(["googleapiclient"], extras="google-drive")
- def run(self, file_data: FileData, **kwargs: Any) -> download_responses:
+ def run(self, file_data: FileData, **kwargs: Any) -> DownloadResponse:
  from googleapiclient.http import MediaIoBaseDownload

  logger.debug(f"fetching file: {file_data.source_identifiers.fullpath}")
unstructured_ingest/v2/processes/connectors/kafka/__init__.py

@@ -1,13 +1,17 @@
  from __future__ import annotations

  from unstructured_ingest.v2.processes.connector_registry import (
+ add_destination_entry,
  add_source_entry,
  )

  from .cloud import CONNECTOR_TYPE as CLOUD_CONNECTOR
- from .cloud import kafka_cloud_source_entry
+ from .cloud import kafka_cloud_destination_entry, kafka_cloud_source_entry
  from .local import CONNECTOR_TYPE as LOCAL_CONNECTOR
- from .local import kafka_local_source_entry
+ from .local import kafka_local_destination_entry, kafka_local_source_entry

  add_source_entry(source_type=LOCAL_CONNECTOR, entry=kafka_local_source_entry)
+ add_destination_entry(destination_type=LOCAL_CONNECTOR, entry=kafka_local_destination_entry)
+
  add_source_entry(source_type=CLOUD_CONNECTOR, entry=kafka_cloud_source_entry)
+ add_destination_entry(destination_type=CLOUD_CONNECTOR, entry=kafka_cloud_destination_entry)
unstructured_ingest/v2/processes/connectors/kafka/cloud.py

@@ -4,7 +4,10 @@ from typing import TYPE_CHECKING, Optional

  from pydantic import Field, Secret, SecretStr

- from unstructured_ingest.v2.processes.connector_registry import SourceRegistryEntry
+ from unstructured_ingest.v2.processes.connector_registry import (
+ DestinationRegistryEntry,
+ SourceRegistryEntry,
+ )
  from unstructured_ingest.v2.processes.connectors.kafka.kafka import (
  KafkaAccessConfig,
  KafkaConnectionConfig,
@@ -12,6 +15,8 @@ from unstructured_ingest.v2.processes.connectors.kafka.kafka import (
  KafkaDownloaderConfig,
  KafkaIndexer,
  KafkaIndexerConfig,
+ KafkaUploader,
+ KafkaUploaderConfig,
  )

  if TYPE_CHECKING:
@@ -41,7 +46,21 @@ class CloudKafkaConnectionConfig(KafkaConnectionConfig):
  "group.id": "default_group_id",
  "enable.auto.commit": "false",
  "auto.offset.reset": "earliest",
- "message.max.bytes": 10485760,
+ "sasl.username": access_config.api_key,
+ "sasl.password": access_config.secret,
+ "sasl.mechanism": "PLAIN",
+ "security.protocol": "SASL_SSL",
+ }
+
+ return conf
+
+ def get_producer_configuration(self) -> dict:
+ bootstrap = self.bootstrap_server
+ port = self.port
+ access_config = self.access_config.get_secret_value()
+
+ conf = {
+ "bootstrap.servers": f"{bootstrap}:{port}",
  "sasl.username": access_config.api_key,
  "sasl.password": access_config.secret,
  "sasl.mechanism": "PLAIN",
@@ -73,6 +92,17 @@ class CloudKafkaDownloader(KafkaDownloader):
  connector_type: str = CONNECTOR_TYPE


+ class CloudKafkaUploaderConfig(KafkaUploaderConfig):
+ pass
+
+
+ @dataclass
+ class CloudKafkaUploader(KafkaUploader):
+ connection_config: CloudKafkaConnectionConfig
+ upload_config: CloudKafkaUploaderConfig
+ connector_type: str = CONNECTOR_TYPE
+
+
  kafka_cloud_source_entry = SourceRegistryEntry(
  connection_config=CloudKafkaConnectionConfig,
  indexer=CloudKafkaIndexer,
@@ -80,3 +110,9 @@ kafka_cloud_source_entry = SourceRegistryEntry(
  downloader=CloudKafkaDownloader,
  downloader_config=CloudKafkaDownloaderConfig,
  )
+
+ kafka_cloud_destination_entry = DestinationRegistryEntry(
+ connection_config=CloudKafkaConnectionConfig,
+ uploader=CloudKafkaUploader,
+ uploader_config=CloudKafkaUploaderConfig,
+ )
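
For context, the dict built by the new get_producer_configuration() follows the standard librdkafka SASL_SSL settings for a hosted Kafka cluster. A short sketch of handing that shape of config straight to confluent_kafka is shown below; the broker address and credentials are placeholders, and in the connector they come from CloudKafkaConnectionConfig and its access config.

    from confluent_kafka import Producer

    # Placeholder broker and credentials; the connector fills these from
    # bootstrap_server, port, api_key, and secret.
    conf = {
        "bootstrap.servers": "broker.example.com:9092",
        "sasl.username": "<api key>",
        "sasl.password": "<api secret>",
        "sasl.mechanism": "PLAIN",
        "security.protocol": "SASL_SSL",
    }
    producer = Producer(conf)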
unstructured_ingest/v2/processes/connectors/kafka/kafka.py

@@ -1,3 +1,4 @@
+ import json
  from abc import ABC, abstractmethod
  from contextlib import contextmanager
  from dataclasses import dataclass, field
@@ -5,32 +6,33 @@ from pathlib import Path
  from time import time
  from typing import TYPE_CHECKING, Any, ContextManager, Generator, Optional

- from pydantic import Secret
+ from pydantic import Field, Secret

  from unstructured_ingest.error import (
+ DestinationConnectionError,
  SourceConnectionError,
  SourceConnectionNetworkError,
  )
+ from unstructured_ingest.utils.data_prep import batch_generator
  from unstructured_ingest.utils.dep_check import requires_dependencies
  from unstructured_ingest.v2.interfaces import (
  AccessConfig,
  ConnectionConfig,
  Downloader,
  DownloaderConfig,
+ DownloadResponse,
  FileData,
  FileDataSourceMetadata,
  Indexer,
  IndexerConfig,
  SourceIdentifiers,
- download_responses,
+ Uploader,
+ UploaderConfig,
  )
  from unstructured_ingest.v2.logger import logger
- from unstructured_ingest.v2.processes.connector_registry import SourceRegistryEntry

  if TYPE_CHECKING:
- from confluent_kafka import Consumer
-
- CONNECTOR_TYPE = "kafka"
+ from confluent_kafka import Consumer, Producer


  class KafkaAccessConfig(AccessConfig, ABC):
@@ -39,7 +41,6 @@ class KafkaAccessConfig(AccessConfig, ABC):

  class KafkaConnectionConfig(ConnectionConfig, ABC):
  access_config: Secret[KafkaAccessConfig]
- timeout: Optional[float] = 1.0
  bootstrap_server: str
  port: int

@@ -47,6 +48,10 @@ class KafkaConnectionConfig(ConnectionConfig, ABC):
  def get_consumer_configuration(self) -> dict:
  pass

+ @abstractmethod
+ def get_producer_configuration(self) -> dict:
+ pass
+
  @contextmanager
  @requires_dependencies(["confluent_kafka"], extras="kafka")
  def get_consumer(self) -> ContextManager["Consumer"]:
@@ -59,20 +64,27 @@ class KafkaConnectionConfig(ConnectionConfig, ABC):
  finally:
  consumer.close()

+ @requires_dependencies(["confluent_kafka"], extras="kafka")
+ def get_producer(self) -> "Producer":
+ from confluent_kafka import Producer
+
+ producer = Producer(self.get_producer_configuration())
+ return producer
+

  class KafkaIndexerConfig(IndexerConfig):
- topic: str
+ topic: str = Field(description="which topic to consume from")
  num_messages_to_consume: Optional[int] = 100
+ timeout: Optional[float] = Field(default=1.0, description="polling timeout")

  def update_consumer(self, consumer: "Consumer") -> None:
  consumer.subscribe([self.topic])


  @dataclass
- class KafkaIndexer(Indexer):
+ class KafkaIndexer(Indexer, ABC):
  connection_config: KafkaConnectionConfig
  index_config: KafkaIndexerConfig
- connector_type: str = CONNECTOR_TYPE

  @contextmanager
  def get_consumer(self) -> ContextManager["Consumer"]:
@@ -90,7 +102,7 @@ class KafkaIndexer(Indexer):
  num_messages_to_consume = self.index_config.num_messages_to_consume
  with self.get_consumer() as consumer:
  while messages_consumed < num_messages_to_consume and empty_polls < max_empty_polls:
- msg = consumer.poll(timeout=self.connection_config.timeout)
+ msg = consumer.poll(timeout=self.index_config.timeout)
  if msg is None:
  logger.debug("No Kafka messages found")
  empty_polls += 1
@@ -139,16 +151,22 @@ class KafkaIndexer(Indexer):
  for message in self.generate_messages():
  yield self.generate_file_data(message)

- async def run_async(self, file_data: FileData, **kwargs: Any) -> download_responses:
+ async def run_async(self, file_data: FileData, **kwargs: Any) -> DownloadResponse:
  raise NotImplementedError()

  def precheck(self):
  try:
  with self.get_consumer() as consumer:
- cluster_meta = consumer.list_topics(timeout=self.connection_config.timeout)
+ cluster_meta = consumer.list_topics(timeout=self.index_config.timeout)
  current_topics = [
  topic for topic in cluster_meta.topics if topic != "__consumer_offsets"
  ]
+ if self.index_config.topic not in current_topics:
+ raise SourceConnectionError(
+ "expected topic {} not detected in cluster: {}".format(
+ self.index_config.topic, ", ".join(current_topics)
+ )
+ )
  logger.info(f"successfully checked available topics: {current_topics}")
  except Exception as e:
  logger.error(f"failed to validate connection: {e}", exc_info=True)
@@ -160,14 +178,13 @@ class KafkaDownloaderConfig(DownloaderConfig):


  @dataclass
- class KafkaDownloader(Downloader):
+ class KafkaDownloader(Downloader, ABC):
  connection_config: KafkaConnectionConfig
  download_config: KafkaDownloaderConfig = field(default_factory=KafkaDownloaderConfig)
- connector_type: str = CONNECTOR_TYPE
  version: Optional[str] = None
  source_url: Optional[str] = None

- def run(self, file_data: FileData, **kwargs: Any) -> download_responses:
+ def run(self, file_data: FileData, **kwargs: Any) -> DownloadResponse:
  source_identifiers = file_data.source_identifiers
  if source_identifiers is None:
  raise ValueError("FileData is missing source_identifiers")
@@ -187,10 +204,54 @@ class KafkaDownloader(Downloader):
  return self.generate_download_response(file_data=file_data, download_path=download_path)


- kafka_source_entry = SourceRegistryEntry(
- connection_config=KafkaConnectionConfig,
- indexer=KafkaIndexer,
- indexer_config=KafkaIndexerConfig,
- downloader=KafkaDownloader,
- downloader_config=KafkaDownloaderConfig,
- )
+ class KafkaUploaderConfig(UploaderConfig):
+ batch_size: int = Field(default=100, description="Batch size")
+ topic: str = Field(description="which topic to write to")
+ timeout: Optional[float] = Field(
+ default=10.0, description="Timeout in seconds to flush batch of messages"
+ )
+
+
+ @dataclass
+ class KafkaUploader(Uploader, ABC):
+ connection_config: KafkaConnectionConfig
+ upload_config: KafkaUploaderConfig
+
+ def precheck(self):
+ try:
+ with self.connection_config.get_consumer() as consumer:
+ cluster_meta = consumer.list_topics(timeout=self.upload_config.timeout)
+ current_topics = [
+ topic for topic in cluster_meta.topics if topic != "__consumer_offsets"
+ ]
+ logger.info(f"successfully checked available topics: {current_topics}")
+ except Exception as e:
+ logger.error(f"failed to validate connection: {e}", exc_info=True)
+ raise DestinationConnectionError(f"failed to validate connection: {e}")
+
+ def produce_batch(self, elements: list[dict]) -> None:
+ from confluent_kafka.error import KafkaException
+
+ producer = self.connection_config.get_producer()
+ failed_producer = False
+
+ def acked(err, msg):
+ if err is not None:
+ logger.error("Failed to deliver message: %s: %s" % (str(msg), str(err)))
+
+ for element in elements:
+ producer.produce(
+ topic=self.upload_config.topic,
+ value=json.dumps(element),
+ callback=acked,
+ )
+
+ producer.flush(timeout=self.upload_config.timeout)
+ if failed_producer:
+ raise KafkaException("failed to produce all messages in batch")
+
+ def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
+ with path.open("r") as elements_file:
+ elements = json.load(elements_file)
+ for element_batch in batch_generator(elements, batch_size=self.upload_config.batch_size):
+ self.produce_batch(elements=element_batch)
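
To make the new upload path concrete, the sketch below reproduces the produce/flush cycle that KafkaUploader.produce_batch() performs: read a staged elements file and send each element as a JSON message with a delivery callback. The file path, topic, and broker address are illustrative; in a real pipeline this is driven by KafkaUploader.run() over batches from batch_generator().

    import json

    from confluent_kafka import Producer

    producer = Producer({"bootstrap.servers": "localhost:29092"})  # placeholder broker


    def acked(err, msg):
        # Delivery callback, analogous to the connector's logging of failed deliveries.
        if err is not None:
            print(f"failed to deliver message: {err}")


    with open("structured-output/example.json") as f:  # placeholder staged output file
        elements = json.load(f)

    for element in elements:
        producer.produce(topic="ingest-output", value=json.dumps(element), callback=acked)

    # flush() blocks until outstanding deliveries complete or the timeout elapses,
    # matching the flush(timeout=...) call in produce_batch().
    producer.flush(timeout=10.0)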