unstructured-ingest 0.3.10__py3-none-any.whl → 0.3.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of unstructured-ingest might be problematic.

Files changed (53)
  1. test/integration/connectors/{databricks_tests → databricks}/test_volumes_native.py +75 -19
  2. test/integration/connectors/sql/test_postgres.py +6 -2
  3. test/integration/connectors/sql/test_singlestore.py +6 -2
  4. test/integration/connectors/sql/test_snowflake.py +6 -2
  5. test/integration/connectors/sql/test_sqlite.py +6 -2
  6. test/integration/connectors/test_milvus.py +13 -0
  7. test/integration/connectors/test_onedrive.py +6 -0
  8. test/integration/connectors/test_redis.py +119 -0
  9. test/integration/connectors/test_vectara.py +270 -0
  10. test/integration/embedders/test_bedrock.py +28 -0
  11. test/integration/embedders/test_octoai.py +14 -0
  12. test/integration/embedders/test_openai.py +13 -0
  13. test/integration/embedders/test_togetherai.py +10 -0
  14. test/integration/partitioners/test_partitioner.py +2 -2
  15. test/unit/embed/test_octoai.py +8 -1
  16. unstructured_ingest/__version__.py +1 -1
  17. unstructured_ingest/embed/bedrock.py +39 -11
  18. unstructured_ingest/embed/interfaces.py +5 -0
  19. unstructured_ingest/embed/octoai.py +44 -3
  20. unstructured_ingest/embed/openai.py +37 -1
  21. unstructured_ingest/embed/togetherai.py +28 -1
  22. unstructured_ingest/embed/voyageai.py +33 -1
  23. unstructured_ingest/v2/errors.py +18 -0
  24. unstructured_ingest/v2/interfaces/file_data.py +11 -1
  25. unstructured_ingest/v2/processes/connectors/__init__.py +7 -0
  26. unstructured_ingest/v2/processes/connectors/astradb.py +2 -0
  27. unstructured_ingest/v2/processes/connectors/chroma.py +0 -1
  28. unstructured_ingest/v2/processes/connectors/couchbase.py +2 -0
  29. unstructured_ingest/v2/processes/connectors/databricks/volumes.py +5 -0
  30. unstructured_ingest/v2/processes/connectors/databricks/volumes_aws.py +2 -2
  31. unstructured_ingest/v2/processes/connectors/databricks/volumes_azure.py +2 -2
  32. unstructured_ingest/v2/processes/connectors/databricks/volumes_gcp.py +2 -2
  33. unstructured_ingest/v2/processes/connectors/databricks/volumes_native.py +2 -2
  34. unstructured_ingest/v2/processes/connectors/elasticsearch/elasticsearch.py +1 -1
  35. unstructured_ingest/v2/processes/connectors/kafka/cloud.py +5 -2
  36. unstructured_ingest/v2/processes/connectors/kafka/kafka.py +14 -3
  37. unstructured_ingest/v2/processes/connectors/milvus.py +15 -6
  38. unstructured_ingest/v2/processes/connectors/mongodb.py +3 -4
  39. unstructured_ingest/v2/processes/connectors/neo4j.py +2 -0
  40. unstructured_ingest/v2/processes/connectors/onedrive.py +79 -25
  41. unstructured_ingest/v2/processes/connectors/qdrant/qdrant.py +0 -1
  42. unstructured_ingest/v2/processes/connectors/redisdb.py +182 -0
  43. unstructured_ingest/v2/processes/connectors/sql/sql.py +5 -0
  44. unstructured_ingest/v2/processes/connectors/vectara.py +350 -0
  45. unstructured_ingest/v2/unstructured_api.py +25 -2
  46. {unstructured_ingest-0.3.10.dist-info → unstructured_ingest-0.3.12.dist-info}/METADATA +20 -16
  47. {unstructured_ingest-0.3.10.dist-info → unstructured_ingest-0.3.12.dist-info}/RECORD +52 -48
  48. test/integration/connectors/test_kafka.py +0 -304
  49. /test/integration/connectors/{databricks_tests → databricks}/__init__.py +0 -0
  50. {unstructured_ingest-0.3.10.dist-info → unstructured_ingest-0.3.12.dist-info}/LICENSE.md +0 -0
  51. {unstructured_ingest-0.3.10.dist-info → unstructured_ingest-0.3.12.dist-info}/WHEEL +0 -0
  52. {unstructured_ingest-0.3.10.dist-info → unstructured_ingest-0.3.12.dist-info}/entry_points.txt +0 -0
  53. {unstructured_ingest-0.3.10.dist-info → unstructured_ingest-0.3.12.dist-info}/top_level.txt +0 -0
unstructured_ingest/v2/processes/connectors/databricks/volumes_gcp.py

@@ -3,12 +3,12 @@ from typing import Optional
 
 from pydantic import Field, Secret
 
-from unstructured_ingest.v2.interfaces import AccessConfig
 from unstructured_ingest.v2.processes.connector_registry import (
     DestinationRegistryEntry,
     SourceRegistryEntry,
 )
 from unstructured_ingest.v2.processes.connectors.databricks.volumes import (
+    DatabricksVolumesAccessConfig,
     DatabricksVolumesConnectionConfig,
     DatabricksVolumesDownloader,
     DatabricksVolumesDownloaderConfig,
@@ -21,7 +21,7 @@ from unstructured_ingest.v2.processes.connectors.databricks.volumes import (
 CONNECTOR_TYPE = "databricks_volumes_gcp"
 
 
-class DatabricksGoogleVolumesAccessConfig(AccessConfig):
+class DatabricksGoogleVolumesAccessConfig(DatabricksVolumesAccessConfig):
     account_id: Optional[str] = Field(
         default=None,
         description="The Databricks account ID for the Databricks " "accounts endpoint.",
unstructured_ingest/v2/processes/connectors/databricks/volumes_native.py

@@ -3,12 +3,12 @@ from typing import Optional
 
 from pydantic import Field, Secret
 
-from unstructured_ingest.v2.interfaces import AccessConfig
 from unstructured_ingest.v2.processes.connector_registry import (
     DestinationRegistryEntry,
     SourceRegistryEntry,
 )
 from unstructured_ingest.v2.processes.connectors.databricks.volumes import (
+    DatabricksVolumesAccessConfig,
     DatabricksVolumesConnectionConfig,
     DatabricksVolumesDownloader,
     DatabricksVolumesDownloaderConfig,
@@ -21,7 +21,7 @@ from unstructured_ingest.v2.processes.connectors.databricks.volumes import (
 CONNECTOR_TYPE = "databricks_volumes"
 
 
-class DatabricksNativeVolumesAccessConfig(AccessConfig):
+class DatabricksNativeVolumesAccessConfig(DatabricksVolumesAccessConfig):
     client_id: Optional[str] = Field(default=None, description="Client ID of the OAuth app.")
     client_secret: Optional[str] = Field(
         default=None, description="Client Secret of the OAuth app."
unstructured_ingest/v2/processes/connectors/elasticsearch/elasticsearch.py

@@ -255,6 +255,7 @@ class ElasticsearchDownloader(Downloader):
                 exc_info=True,
             )
             raise SourceConnectionNetworkError(f"failed to download file {file_data.identifier}")
+        file_data.source_identifiers = SourceIdentifiers(filename=filename, fullpath=filename)
         cast_file_data = FileData.cast(file_data=file_data)
         cast_file_data.identifier = filename_id
         cast_file_data.metadata.date_processed = str(time())
@@ -264,7 +265,6 @@ class ElasticsearchDownloader(Downloader):
             "index_name": index_name,
             "document_id": record_id,
         }
-        cast_file_data.source_identifiers = SourceIdentifiers(filename=filename, fullpath=filename)
         return super().generate_download_response(
             file_data=cast_file_data,
             download_path=download_path,
unstructured_ingest/v2/processes/connectors/kafka/cloud.py

@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING
 
 from pydantic import Field, Secret, SecretStr
 
+from unstructured_ingest.v2.logger import logger
 from unstructured_ingest.v2.processes.connector_registry import (
     DestinationRegistryEntry,
     SourceRegistryEntry,
@@ -50,6 +51,7 @@ class CloudKafkaConnectionConfig(KafkaConnectionConfig):
             "sasl.password": access_config.secret.get_secret_value(),
             "sasl.mechanism": "PLAIN",
             "security.protocol": "SASL_SSL",
+            "logger": logger,
         }
 
         return conf
@@ -61,10 +63,11 @@ class CloudKafkaConnectionConfig(KafkaConnectionConfig):
 
         conf = {
             "bootstrap.servers": f"{bootstrap}:{port}",
-            "sasl.username": access_config.kafka_api_key,
-            "sasl.password": access_config.secret,
+            "sasl.username": access_config.kafka_api_key.get_secret_value(),
+            "sasl.password": access_config.secret.get_secret_value(),
             "sasl.mechanism": "PLAIN",
             "security.protocol": "SASL_SSL",
+            "logger": logger,
         }
 
         return conf
unstructured_ingest/v2/processes/connectors/kafka/kafka.py

@@ -170,7 +170,7 @@ class KafkaIndexer(Indexer, ABC):
             ]
             if self.index_config.topic not in current_topics:
                 raise SourceConnectionError(
-                    "expected topic {} not detected in cluster: {}".format(
+                    "expected topic '{}' not detected in cluster: '{}'".format(
                         self.index_config.topic, ", ".join(current_topics)
                     )
                 )
@@ -232,6 +232,13 @@ class KafkaUploader(Uploader, ABC):
                 topic for topic in cluster_meta.topics if topic != "__consumer_offsets"
             ]
             logger.info(f"successfully checked available topics: {current_topics}")
+            if self.upload_config.topic not in current_topics:
+                raise DestinationConnectionError(
+                    "expected topic '{}' not detected in cluster: '{}'".format(
+                        self.upload_config.topic, ", ".join(current_topics)
+                    )
+                )
+
         except Exception as e:
             logger.error(f"failed to validate connection: {e}", exc_info=True)
             raise DestinationConnectionError(f"failed to validate connection: {e}")
@@ -243,8 +250,10 @@ class KafkaUploader(Uploader, ABC):
         failed_producer = False
 
         def acked(err, msg):
+            nonlocal failed_producer
             if err is not None:
-                logger.error("Failed to deliver message: %s: %s" % (str(msg), str(err)))
+                failed_producer = True
+                logger.error("Failed to deliver kafka message: %s: %s" % (str(msg), str(err)))
 
         for element in elements:
             producer.produce(
@@ -253,7 +262,9 @@ class KafkaUploader(Uploader, ABC):
                 callback=acked,
             )
 
-        producer.flush(timeout=self.upload_config.timeout)
+        while producer_len := len(producer):
+            logger.debug(f"another iteration of kafka producer flush. Queue length: {producer_len}")
+            producer.flush(timeout=self.upload_config.timeout)
         if failed_producer:
             raise KafkaException("failed to produce all messages in batch")
 
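The uploader changes above follow confluent-kafka's delivery-callback pattern and drain the producer queue before checking for failures. A minimal standalone sketch of that pattern, assuming a broker at localhost:9092; the topic name and payloads are placeholders, not values from this release:

import json

from confluent_kafka import KafkaException, Producer

producer = Producer({"bootstrap.servers": "localhost:9092"})  # placeholder broker
failed = False


def acked(err, msg):
    # Delivery reports fire while poll()/flush() runs; record failures in an outer flag.
    global failed
    if err is not None:
        failed = True


for i in range(10):
    producer.produce("example-topic", value=json.dumps({"n": i}), callback=acked)

# len(producer) is the number of messages still queued; keep flushing until it drains
# so every delivery callback has had a chance to run.
while len(producer):
    producer.flush(timeout=10)

if failed:
    raise KafkaException("failed to produce all messages in batch")

Because delivery reports are only dispatched during poll() or flush(), looping until the queue is empty ensures the failure flag reflects every message before it is checked.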
unstructured_ingest/v2/processes/connectors/milvus.py

@@ -156,11 +156,18 @@ class MilvusUploader(Uploader):
 
     @DestinationConnectionError.wrap
     def precheck(self):
-        with self.get_client() as client:
-            if not client.has_collection(self.upload_config.collection_name):
-                raise DestinationConnectionError(
-                    f"Collection '{self.upload_config.collection_name}' does not exist"
-                )
+        from pymilvus import MilvusException
+
+        try:
+            with self.get_client() as client:
+                if not client.has_collection(self.upload_config.collection_name):
+                    raise DestinationConnectionError(
+                        f"Collection '{self.upload_config.collection_name}' does not exist"
+                    )
+        except MilvusException as milvus_exception:
+            raise DestinationConnectionError(
+                f"failed to precheck Milvus: {str(milvus_exception.message)}"
+            ) from milvus_exception
 
     @contextmanager
     def get_client(self) -> Generator["MilvusClient", None, None]:
@@ -197,7 +204,9 @@ class MilvusUploader(Uploader):
         try:
             res = client.insert(collection_name=self.upload_config.collection_name, data=data)
         except MilvusException as milvus_exception:
-            raise WriteError("failed to upload records to milvus") from milvus_exception
+            raise WriteError(
+                f"failed to upload records to Milvus: {str(milvus_exception.message)}"
+            ) from milvus_exception
         if "err_count" in res and isinstance(res["err_count"], int) and res["err_count"] > 0:
             err_count = res["err_count"]
             raise WriteError(f"failed to upload {err_count} docs")
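The precheck above now surfaces MilvusException.message instead of a bare connection error. A rough sketch of the same wrapping outside the connector, assuming a Milvus instance at a placeholder URI and using ValueError in place of the package's DestinationConnectionError:

from pymilvus import MilvusClient, MilvusException


def precheck_collection(uri: str, collection_name: str) -> None:
    # Fold connection problems and a missing collection into one readable error.
    try:
        client = MilvusClient(uri=uri)
        if not client.has_collection(collection_name):
            raise ValueError(f"Collection '{collection_name}' does not exist")
    except MilvusException as exc:
        # MilvusException exposes the server-side detail on .message.
        raise ValueError(f"failed to precheck Milvus: {exc.message}") from exc


precheck_collection(uri="http://localhost:19530", collection_name="elements")  # placeholders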
unstructured_ingest/v2/processes/connectors/mongodb.py

@@ -198,14 +198,13 @@ class MongoDBDownloader(Downloader):
         concatenated_values = "\n".join(str(value) for value in flattened_dict.values())
 
         # Create a FileData object for each document with source_identifiers
-        cast_file_data = FileData.cast(file_data=file_data)
-        cast_file_data.identifier = str(doc_id)
         filename = f"{doc_id}.txt"
-        cast_file_data.source_identifiers = SourceIdentifiers(
+        file_data.source_identifiers = SourceIdentifiers(
             filename=filename,
             fullpath=filename,
-            rel_path=filename,
         )
+        cast_file_data = FileData.cast(file_data=file_data)
+        cast_file_data.identifier = str(doc_id)
 
         # Determine the download path
         download_path = self.get_download_path(file_data=cast_file_data)
unstructured_ingest/v2/processes/connectors/neo4j.py

@@ -378,6 +378,8 @@ class Neo4jUploader(Uploader):
 
 neo4j_destination_entry = DestinationRegistryEntry(
     connection_config=Neo4jConnectionConfig,
+    upload_stager=Neo4jUploadStager,
+    upload_stager_config=Neo4jUploadStagerConfig,
     uploader=Neo4jUploader,
     uploader_config=Neo4jUploaderConfig,
 )
unstructured_ingest/v2/processes/connectors/onedrive.py

@@ -1,10 +1,11 @@
 from __future__ import annotations
 
+import asyncio
 import json
 from dataclasses import dataclass
 from pathlib import Path
 from time import time
-from typing import TYPE_CHECKING, Any, Generator, Optional
+from typing import TYPE_CHECKING, Any, AsyncIterator, Generator, Iterator, Optional, TypeVar
 
 from dateutil import parser
 from pydantic import Field, Secret
@@ -100,6 +101,27 @@ class OnedriveIndexerConfig(IndexerConfig):
     recursive: bool = False
 
 
+T = TypeVar("T")
+
+
+def async_iterable_to_sync_iterable(iterator: AsyncIterator[T]) -> Iterator[T]:
+    # This version works on Python 3.9 by manually handling the async iteration.
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    try:
+        while True:
+            try:
+                # Instead of anext(iterator), we directly call __anext__().
+                # __anext__ returns a coroutine that we must run until complete.
+                future = iterator.__anext__()
+                result = loop.run_until_complete(future)
+                yield result
+            except StopAsyncIteration:
+                break
+    finally:
+        loop.close()
+
+
 @dataclass
 class OnedriveIndexer(Indexer):
     connection_config: OnedriveConnectionConfig
@@ -116,17 +138,21 @@ class OnedriveIndexer(Indexer):
             logger.error(f"failed to validate connection: {e}", exc_info=True)
             raise SourceConnectionError(f"failed to validate connection: {e}")
 
-    def list_objects(self, folder: DriveItem, recursive: bool) -> list["DriveItem"]:
+    def list_objects_sync(self, folder: DriveItem, recursive: bool) -> list["DriveItem"]:
         drive_items = folder.children.get().execute_query()
         files = [d for d in drive_items if d.is_file]
         if not recursive:
             return files
+
         folders = [d for d in drive_items if d.is_folder]
         for f in folders:
-            files.extend(self.list_objects(f, recursive))
+            files.extend(self.list_objects_sync(f, recursive))
         return files
 
-    def get_root(self, client: "GraphClient") -> "DriveItem":
+    async def list_objects(self, folder: "DriveItem", recursive: bool) -> list["DriveItem"]:
+        return await asyncio.to_thread(self.list_objects_sync, folder, recursive)
+
+    def get_root_sync(self, client: "GraphClient") -> "DriveItem":
         root = client.users[self.connection_config.user_pname].drive.get().execute_query().root
         if fpath := self.index_config.path:
             root = root.get_by_path(fpath).get().execute_query()
@@ -134,7 +160,10 @@ class OnedriveIndexer(Indexer):
                 raise ValueError(f"Unable to find directory, given: {fpath}")
         return root
 
-    def get_properties(self, drive_item: "DriveItem") -> dict:
+    async def get_root(self, client: "GraphClient") -> "DriveItem":
+        return await asyncio.to_thread(self.get_root_sync, client)
+
+    def get_properties_sync(self, drive_item: "DriveItem") -> dict:
         properties = drive_item.properties
         filtered_properties = {}
         for k, v in properties.items():
@@ -145,7 +174,10 @@ class OnedriveIndexer(Indexer):
                 pass
         return filtered_properties
 
-    def drive_item_to_file_data(self, drive_item: "DriveItem") -> FileData:
+    async def get_properties(self, drive_item: "DriveItem") -> dict:
+        return await asyncio.to_thread(self.get_properties_sync, drive_item)
+
+    def drive_item_to_file_data_sync(self, drive_item: "DriveItem") -> FileData:
         file_path = drive_item.parent_reference.path.split(":")[-1]
         file_path = file_path[1:] if file_path and file_path[0] == "/" else file_path
         filename = drive_item.name
@@ -176,17 +208,34 @@ class OnedriveIndexer(Indexer):
                     "server_relative_path": server_path,
                 },
             ),
-            additional_metadata=self.get_properties(drive_item=drive_item),
+            additional_metadata=self.get_properties_sync(drive_item=drive_item),
         )
 
-    def run(self, **kwargs: Any) -> Generator[FileData, None, None]:
-        client = self.connection_config.get_client()
-        root = self.get_root(client=client)
-        drive_items = self.list_objects(folder=root, recursive=self.index_config.recursive)
+    async def drive_item_to_file_data(self, drive_item: "DriveItem") -> FileData:
+        # Offload the file data creation if it's not guaranteed async
+        return await asyncio.to_thread(self.drive_item_to_file_data_sync, drive_item)
+
+    async def _run_async(self, **kwargs: Any) -> AsyncIterator[FileData]:
+        token_resp = await asyncio.to_thread(self.connection_config.get_token)
+        if "error" in token_resp:
+            raise SourceConnectionError(
+                f"[{CONNECTOR_TYPE}]: {token_resp['error']} ({token_resp.get('error_description')})"
+            )
+
+        client = await asyncio.to_thread(self.connection_config.get_client)
+        root = await self.get_root(client=client)
+        drive_items = await self.list_objects(folder=root, recursive=self.index_config.recursive)
+
         for drive_item in drive_items:
-            file_data = self.drive_item_to_file_data(drive_item=drive_item)
+            file_data = await self.drive_item_to_file_data(drive_item=drive_item)
             yield file_data
 
+    def run(self, **kwargs: Any) -> Generator[FileData, None, None]:
+        # Convert the async generator to a sync generator without loading all data into memory
+        async_gen = self._run_async(**kwargs)
+        for item in async_iterable_to_sync_iterable(async_gen):
+            yield item
+
 
 class OnedriveDownloaderConfig(DownloaderConfig):
     pass
@@ -220,19 +269,24 @@ class OnedriveDownloader(Downloader):
 
     @SourceConnectionError.wrap
    def run(self, file_data: FileData, **kwargs: Any) -> DownloadResponse:
-        file = self._fetch_file(file_data=file_data)
-        fsize = file.get_property("size", 0)
-        download_path = self.get_download_path(file_data=file_data)
-        download_path.parent.mkdir(parents=True, exist_ok=True)
-        logger.info(f"downloading {file_data.source_identifiers.fullpath} to {download_path}")
-        if fsize > MAX_MB_SIZE:
-            logger.info(f"downloading file with size: {fsize} bytes in chunks")
-            with download_path.open(mode="wb") as f:
-                file.download_session(f, chunk_size=1024 * 1024 * 100).execute_query()
-        else:
-            with download_path.open(mode="wb") as f:
-                file.download(f).execute_query()
-        return self.generate_download_response(file_data=file_data, download_path=download_path)
+        try:
+            file = self._fetch_file(file_data=file_data)
+            fsize = file.get_property("size", 0)
+            download_path = self.get_download_path(file_data=file_data)
+            download_path.parent.mkdir(parents=True, exist_ok=True)
+            logger.info(f"downloading {file_data.source_identifiers.fullpath} to {download_path}")
+            if fsize > MAX_MB_SIZE:
+                logger.info(f"downloading file with size: {fsize} bytes in chunks")
+                with download_path.open(mode="wb") as f:
+                    file.download_session(f, chunk_size=1024 * 1024 * 100).execute_query()
+            else:
+                with download_path.open(mode="wb") as f:
+                    file.download(f).execute_query()
+            return self.generate_download_response(file_data=file_data, download_path=download_path)
+        except Exception as e:
+            logger.error(f"[{CONNECTOR_TYPE}] Exception during downloading: {e}", exc_info=True)
+            # Re-raise to see full stack trace locally
+            raise
 
 
 class OnedriveUploaderConfig(UploaderConfig):
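The indexer now keeps a synchronous run() as a thin facade over an async generator, bridged by the async_iterable_to_sync_iterable helper added above. A self-contained sketch of that facade pattern with a toy indexer; the names here are illustrative, not part of the connector:

import asyncio
from typing import AsyncIterator, Iterator


def async_to_sync(ait: AsyncIterator[int]) -> Iterator[int]:
    # Same idea as async_iterable_to_sync_iterable: pull items one at a time on a
    # dedicated event loop so callers can consume them with a plain for-loop.
    loop = asyncio.new_event_loop()
    try:
        while True:
            try:
                yield loop.run_until_complete(ait.__anext__())
            except StopAsyncIteration:
                break
    finally:
        loop.close()


class ToyIndexer:
    async def _run_async(self) -> AsyncIterator[int]:
        for i in range(3):
            await asyncio.sleep(0)  # stands in for asyncio.to_thread(...) calls to a blocking SDK
            yield i

    def run(self) -> Iterator[int]:
        # Sync facade over the async generator, mirroring OnedriveIndexer.run
        yield from async_to_sync(self._run_async())


print(list(ToyIndexer().run()))  # [0, 1, 2]

Items are produced lazily, one await at a time, so nothing is buffered beyond the element currently being yielded.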
unstructured_ingest/v2/processes/connectors/qdrant/qdrant.py

@@ -128,7 +128,6 @@ class QdrantUploader(Uploader, ABC):
         file_data: FileData,
         **kwargs: Any,
     ) -> None:
-
         batches = list(batch_generator(data, batch_size=self.upload_config.batch_size))
         logger.debug(
             "Elements split into %i batches of size %i.",
unstructured_ingest/v2/processes/connectors/redisdb.py (new file)

@@ -0,0 +1,182 @@
+import json
+from contextlib import asynccontextmanager, contextmanager
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Generator, Optional
+
+from pydantic import Field, Secret, model_validator
+
+from unstructured_ingest.error import DestinationConnectionError
+from unstructured_ingest.utils.data_prep import batch_generator
+from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.v2.interfaces import (
+    AccessConfig,
+    ConnectionConfig,
+    FileData,
+    Uploader,
+    UploaderConfig,
+)
+from unstructured_ingest.v2.logger import logger
+from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry
+
+if TYPE_CHECKING:
+    from redis.asyncio import Redis
+
+import asyncio
+
+CONNECTOR_TYPE = "redis"
+SERVER_API_VERSION = "1"
+
+
+class RedisAccessConfig(AccessConfig):
+    uri: Optional[str] = Field(
+        default=None, description="If not anonymous, use this uri, if specified."
+    )
+    password: Optional[str] = Field(
+        default=None, description="If not anonymous, use this password, if specified."
+    )
+
+
+class RedisConnectionConfig(ConnectionConfig):
+    access_config: Secret[RedisAccessConfig] = Field(
+        default=RedisAccessConfig(), validate_default=True
+    )
+    host: Optional[str] = Field(
+        default=None, description="Hostname or IP address of a Redis instance to connect to."
+    )
+    database: int = Field(default=0, description="Database index to connect to.")
+    port: int = Field(default=6379, description="port used to connect to database.")
+    username: Optional[str] = Field(
+        default=None, description="Username used to connect to database."
+    )
+    ssl: bool = Field(default=True, description="Whether the connection should use SSL encryption.")
+    connector_type: str = Field(default=CONNECTOR_TYPE, init=False)
+
+    @model_validator(mode="after")
+    def validate_host_or_url(self) -> "RedisConnectionConfig":
+        if not self.access_config.get_secret_value().uri and not self.host:
+            raise ValueError("Please pass a hostname either directly or through uri")
+        return self
+
+    @requires_dependencies(["redis"], extras="redis")
+    @asynccontextmanager
+    async def create_async_client(self) -> AsyncGenerator["Redis", None]:
+        from redis.asyncio import Redis, from_url
+
+        access_config = self.access_config.get_secret_value()
+
+        options = {
+            "host": self.host,
+            "port": self.port,
+            "db": self.database,
+            "ssl": self.ssl,
+            "username": self.username,
+        }
+
+        if access_config.password:
+            options["password"] = access_config.password
+
+        if access_config.uri:
+            async with from_url(access_config.uri) as client:
+                yield client
+        else:
+            async with Redis(**options) as client:
+                yield client
+
+    @requires_dependencies(["redis"], extras="redis")
+    @contextmanager
+    def create_client(self) -> Generator["Redis", None, None]:
+        from redis import Redis, from_url
+
+        access_config = self.access_config.get_secret_value()
+
+        options = {
+            "host": self.host,
+            "port": self.port,
+            "db": self.database,
+            "ssl": self.ssl,
+            "username": self.username,
+        }
+
+        if access_config.password:
+            options["password"] = access_config.password
+
+        if access_config.uri:
+            with from_url(access_config.uri) as client:
+                yield client
+        else:
+            with Redis(**options) as client:
+                yield client
+
+
+class RedisUploaderConfig(UploaderConfig):
+    batch_size: int = Field(default=100, description="Number of records per batch")
+
+
+@dataclass
+class RedisUploader(Uploader):
+    upload_config: RedisUploaderConfig
+    connection_config: RedisConnectionConfig
+    connector_type: str = CONNECTOR_TYPE
+
+    def is_async(self) -> bool:
+        return True
+
+    def precheck(self) -> None:
+        try:
+            with self.connection_config.create_client() as client:
+                client.ping()
+        except Exception as e:
+            logger.error(f"failed to validate connection: {e}", exc_info=True)
+            raise DestinationConnectionError(f"failed to validate connection: {e}")
+
+    async def run_data_async(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None:
+        first_element = data[0]
+        redis_stack = await self._check_redis_stack(first_element)
+        logger.info(
+            f"writing {len(data)} objects to destination asynchronously, "
+            f"db, {self.connection_config.database}, "
+            f"at {self.connection_config.host}",
+        )
+
+        batches = list(batch_generator(data, batch_size=self.upload_config.batch_size))
+        await asyncio.gather(*[self._write_batch(batch, redis_stack) for batch in batches])
+
+    async def _write_batch(self, batch: list[dict], redis_stack: bool) -> None:
+        async with self.connection_config.create_async_client() as async_client:
+            async with async_client.pipeline(transaction=True) as pipe:
+                for element in batch:
+                    element_id = element["element_id"]
+                    if redis_stack:
+                        pipe.json().set(element_id, "$", element)
+                    else:
+                        pipe.set(element_id, json.dumps(element))
+                await pipe.execute()
+
+    @requires_dependencies(["redis"], extras="redis")
+    async def _check_redis_stack(self, element: dict) -> bool:
+        from redis import exceptions as redis_exceptions
+
+        redis_stack = True
+        async with self.connection_config.create_async_client() as async_client:
+            async with async_client.pipeline(transaction=True) as pipe:
+                element_id = element["element_id"]
+                try:
+                    # Redis with stack extension supports JSON type
+                    await pipe.json().set(element_id, "$", element).execute()
+                except redis_exceptions.ResponseError as e:
+                    message = str(e)
+                    if "unknown command `JSON.SET`" in message:
+                        # if this error occurs, Redis server doesn't support JSON type,
+                        # so save as string type instead
+                        await pipe.set(element_id, json.dumps(element)).execute()
+                        redis_stack = False
+                    else:
+                        raise e
+        return redis_stack
+
+
+redis_destination_entry = DestinationRegistryEntry(
+    connection_config=RedisConnectionConfig,
+    uploader=RedisUploader,
+    uploader_config=RedisUploaderConfig,
+)
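For orientation, a rough sketch of wiring the new Redis destination classes together directly, assuming a plain Redis reachable at localhost:6379 without TLS; the credentials and sample element are placeholders, and file_data is passed as None only because run_data_async above never reads it:

import asyncio

from unstructured_ingest.v2.processes.connectors.redisdb import (
    RedisAccessConfig,
    RedisConnectionConfig,
    RedisUploader,
    RedisUploaderConfig,
)

connection_config = RedisConnectionConfig(
    access_config=RedisAccessConfig(password="example-password"),  # placeholder credentials
    host="localhost",
    port=6379,
    ssl=False,
)
uploader = RedisUploader(
    upload_config=RedisUploaderConfig(batch_size=50),
    connection_config=connection_config,
)

uploader.precheck()  # pings the server through the sync client

# Each element is keyed by its element_id and stored as JSON when the server has
# the RedisJSON module, otherwise as a plain string.
elements = [{"element_id": "demo-1", "text": "hello"}]
asyncio.run(uploader.run_data_async(data=elements, file_data=None))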
unstructured_ingest/v2/processes/connectors/sql/sql.py

@@ -28,6 +28,7 @@ from unstructured_ingest.v2.interfaces import (
     FileDataSourceMetadata,
     Indexer,
     IndexerConfig,
+    SourceIdentifiers,
     Uploader,
     UploaderConfig,
     UploadStager,
@@ -218,6 +219,10 @@ class SQLDownloader(Downloader, ABC):
         )
         download_path.parent.mkdir(parents=True, exist_ok=True)
         result.to_csv(download_path, index=False)
+        file_data.source_identifiers = SourceIdentifiers(
+            filename=filename,
+            fullpath=filename,
+        )
         cast_file_data = FileData.cast(file_data=file_data)
         cast_file_data.identifier = filename_id
         return super().generate_download_response(