unstructured-ingest 0.0.3 (py3-none-any.whl) → 0.0.5 (py3-none-any.whl)

This diff shows the changes between publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

Potentially problematic release: this version of unstructured-ingest might be problematic.
Files changed (125)
  1. unstructured_ingest/__version__.py +1 -1
  2. unstructured_ingest/cli/cli.py +6 -1
  3. unstructured_ingest/cli/cmds/__init__.py +4 -4
  4. unstructured_ingest/cli/cmds/{astra.py → astradb.py} +9 -9
  5. unstructured_ingest/cli/interfaces.py +13 -6
  6. unstructured_ingest/connector/{astra.py → astradb.py} +29 -29
  7. unstructured_ingest/connector/biomed.py +12 -5
  8. unstructured_ingest/connector/confluence.py +3 -3
  9. unstructured_ingest/connector/github.py +3 -2
  10. unstructured_ingest/connector/google_drive.py +1 -2
  11. unstructured_ingest/connector/mongodb.py +1 -2
  12. unstructured_ingest/connector/notion/client.py +31 -16
  13. unstructured_ingest/connector/notion/connector.py +3 -2
  14. unstructured_ingest/connector/registry.py +2 -2
  15. unstructured_ingest/connector/vectara.py +7 -2
  16. unstructured_ingest/interfaces.py +13 -9
  17. unstructured_ingest/pipeline/interfaces.py +8 -3
  18. unstructured_ingest/pipeline/reformat/chunking.py +13 -9
  19. unstructured_ingest/pipeline/reformat/embedding.py +3 -3
  20. unstructured_ingest/runner/__init__.py +2 -2
  21. unstructured_ingest/runner/{astra.py → astradb.py} +7 -7
  22. unstructured_ingest/runner/writers/__init__.py +2 -2
  23. unstructured_ingest/runner/writers/{astra.py → astradb.py} +7 -7
  24. unstructured_ingest/utils/chunking.py +45 -0
  25. unstructured_ingest/utils/dep_check.py +1 -1
  26. unstructured_ingest/utils/google_filetype.py +9 -0
  27. unstructured_ingest/v2/cli/base/cmd.py +57 -13
  28. unstructured_ingest/v2/cli/base/dest.py +21 -12
  29. unstructured_ingest/v2/cli/base/src.py +35 -23
  30. unstructured_ingest/v2/cli/cmds.py +14 -0
  31. unstructured_ingest/v2/cli/{utils.py → utils/click.py} +36 -89
  32. unstructured_ingest/v2/cli/utils/model_conversion.py +199 -0
  33. unstructured_ingest/v2/interfaces/connector.py +5 -7
  34. unstructured_ingest/v2/interfaces/downloader.py +8 -5
  35. unstructured_ingest/v2/interfaces/file_data.py +8 -2
  36. unstructured_ingest/v2/interfaces/indexer.py +3 -4
  37. unstructured_ingest/v2/interfaces/processor.py +10 -10
  38. unstructured_ingest/v2/interfaces/upload_stager.py +3 -3
  39. unstructured_ingest/v2/interfaces/uploader.py +3 -3
  40. unstructured_ingest/v2/pipeline/pipeline.py +9 -6
  41. unstructured_ingest/v2/pipeline/steps/chunk.py +5 -11
  42. unstructured_ingest/v2/pipeline/steps/download.py +13 -11
  43. unstructured_ingest/v2/pipeline/steps/embed.py +5 -11
  44. unstructured_ingest/v2/pipeline/steps/filter.py +1 -6
  45. unstructured_ingest/v2/pipeline/steps/index.py +14 -10
  46. unstructured_ingest/v2/pipeline/steps/partition.py +5 -5
  47. unstructured_ingest/v2/pipeline/steps/stage.py +4 -7
  48. unstructured_ingest/v2/pipeline/steps/uncompress.py +1 -6
  49. unstructured_ingest/v2/pipeline/steps/upload.py +2 -9
  50. unstructured_ingest/v2/processes/__init__.py +18 -0
  51. unstructured_ingest/v2/processes/chunker.py +74 -28
  52. unstructured_ingest/v2/processes/connector_registry.py +8 -2
  53. unstructured_ingest/v2/processes/connectors/__init__.py +18 -3
  54. unstructured_ingest/v2/processes/connectors/{astra.py → astradb.py} +46 -39
  55. unstructured_ingest/v2/processes/connectors/azure_cognitive_search.py +30 -27
  56. unstructured_ingest/v2/processes/connectors/chroma.py +30 -21
  57. unstructured_ingest/v2/processes/connectors/couchbase.py +333 -0
  58. unstructured_ingest/v2/processes/connectors/databricks_volumes.py +87 -32
  59. unstructured_ingest/v2/processes/connectors/elasticsearch.py +70 -45
  60. unstructured_ingest/v2/processes/connectors/fsspec/azure.py +39 -16
  61. unstructured_ingest/v2/processes/connectors/fsspec/box.py +15 -13
  62. unstructured_ingest/v2/processes/connectors/fsspec/dropbox.py +10 -11
  63. unstructured_ingest/v2/processes/connectors/fsspec/fsspec.py +20 -34
  64. unstructured_ingest/v2/processes/connectors/fsspec/gcs.py +38 -13
  65. unstructured_ingest/v2/processes/connectors/fsspec/s3.py +31 -17
  66. unstructured_ingest/v2/processes/connectors/fsspec/sftp.py +19 -28
  67. unstructured_ingest/v2/processes/connectors/google_drive.py +40 -34
  68. unstructured_ingest/v2/processes/connectors/kdbai.py +170 -0
  69. unstructured_ingest/v2/processes/connectors/local.py +27 -16
  70. unstructured_ingest/v2/processes/connectors/milvus.py +22 -18
  71. unstructured_ingest/v2/processes/connectors/mongodb.py +22 -18
  72. unstructured_ingest/v2/processes/connectors/onedrive.py +17 -14
  73. unstructured_ingest/v2/processes/connectors/opensearch.py +66 -56
  74. unstructured_ingest/v2/processes/connectors/pinecone.py +22 -21
  75. unstructured_ingest/v2/processes/connectors/salesforce.py +26 -18
  76. unstructured_ingest/v2/processes/connectors/sharepoint.py +51 -26
  77. unstructured_ingest/v2/processes/connectors/singlestore.py +11 -15
  78. unstructured_ingest/v2/processes/connectors/sql.py +29 -31
  79. unstructured_ingest/v2/processes/connectors/weaviate.py +22 -13
  80. unstructured_ingest/v2/processes/embedder.py +106 -47
  81. unstructured_ingest/v2/processes/filter.py +11 -5
  82. unstructured_ingest/v2/processes/partitioner.py +79 -33
  83. unstructured_ingest/v2/processes/uncompress.py +3 -3
  84. unstructured_ingest/v2/utils.py +45 -0
  85. unstructured_ingest-0.0.5.dist-info/LICENSE.md +201 -0
  86. unstructured_ingest-0.0.5.dist-info/METADATA +574 -0
  87. {unstructured_ingest-0.0.3.dist-info → unstructured_ingest-0.0.5.dist-info}/RECORD +91 -116
  88. {unstructured_ingest-0.0.3.dist-info → unstructured_ingest-0.0.5.dist-info}/WHEEL +1 -1
  89. unstructured_ingest/v2/cli/cmds/__init__.py +0 -89
  90. unstructured_ingest/v2/cli/cmds/astra.py +0 -85
  91. unstructured_ingest/v2/cli/cmds/azure_cognitive_search.py +0 -72
  92. unstructured_ingest/v2/cli/cmds/chroma.py +0 -108
  93. unstructured_ingest/v2/cli/cmds/databricks_volumes.py +0 -161
  94. unstructured_ingest/v2/cli/cmds/elasticsearch.py +0 -159
  95. unstructured_ingest/v2/cli/cmds/fsspec/azure.py +0 -84
  96. unstructured_ingest/v2/cli/cmds/fsspec/box.py +0 -58
  97. unstructured_ingest/v2/cli/cmds/fsspec/dropbox.py +0 -58
  98. unstructured_ingest/v2/cli/cmds/fsspec/fsspec.py +0 -69
  99. unstructured_ingest/v2/cli/cmds/fsspec/gcs.py +0 -81
  100. unstructured_ingest/v2/cli/cmds/fsspec/s3.py +0 -84
  101. unstructured_ingest/v2/cli/cmds/fsspec/sftp.py +0 -80
  102. unstructured_ingest/v2/cli/cmds/google_drive.py +0 -74
  103. unstructured_ingest/v2/cli/cmds/local.py +0 -52
  104. unstructured_ingest/v2/cli/cmds/milvus.py +0 -72
  105. unstructured_ingest/v2/cli/cmds/mongodb.py +0 -62
  106. unstructured_ingest/v2/cli/cmds/onedrive.py +0 -91
  107. unstructured_ingest/v2/cli/cmds/opensearch.py +0 -93
  108. unstructured_ingest/v2/cli/cmds/pinecone.py +0 -62
  109. unstructured_ingest/v2/cli/cmds/salesforce.py +0 -79
  110. unstructured_ingest/v2/cli/cmds/sharepoint.py +0 -112
  111. unstructured_ingest/v2/cli/cmds/singlestore.py +0 -96
  112. unstructured_ingest/v2/cli/cmds/sql.py +0 -84
  113. unstructured_ingest/v2/cli/cmds/weaviate.py +0 -100
  114. unstructured_ingest/v2/cli/configs/__init__.py +0 -13
  115. unstructured_ingest/v2/cli/configs/chunk.py +0 -89
  116. unstructured_ingest/v2/cli/configs/embed.py +0 -74
  117. unstructured_ingest/v2/cli/configs/filter.py +0 -28
  118. unstructured_ingest/v2/cli/configs/partition.py +0 -99
  119. unstructured_ingest/v2/cli/configs/processor.py +0 -88
  120. unstructured_ingest/v2/cli/interfaces.py +0 -27
  121. unstructured_ingest/v2/pipeline/utils.py +0 -15
  122. unstructured_ingest-0.0.3.dist-info/METADATA +0 -175
  123. /unstructured_ingest/v2/cli/{cmds/fsspec → utils}/__init__.py +0 -0
  124. {unstructured_ingest-0.0.3.dist-info → unstructured_ingest-0.0.5.dist-info}/entry_points.txt +0 -0
  125. {unstructured_ingest-0.0.3.dist-info → unstructured_ingest-0.0.5.dist-info}/top_level.txt +0 -0
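
Note the astra → astradb renames that run through the CLI, connector, runner, and v2 connector modules above. Code that imports these modules under the old name will presumably need its paths updated; a minimal sketch of the assumed change (old path commented out):

# 0.0.3 module name, per the renamed file above
# from unstructured_ingest.v2.processes.connectors import astra

# 0.0.5 module name
from unstructured_ingest.v2.processes.connectors import astradb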
unstructured_ingest/v2/processes/connectors/couchbase.py (new file)
@@ -0,0 +1,333 @@
+import hashlib
+import json
+import sys
+import time
+from dataclasses import dataclass, field
+from datetime import timedelta
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Generator, List
+
+from pydantic import Field, Secret
+
+from unstructured_ingest.error import (
+    DestinationConnectionError,
+    SourceConnectionError,
+    SourceConnectionNetworkError,
+)
+from unstructured_ingest.utils.data_prep import batch_generator, flatten_dict
+from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.v2.interfaces import (
+    AccessConfig,
+    ConnectionConfig,
+    Downloader,
+    DownloaderConfig,
+    DownloadResponse,
+    FileData,
+    FileDataSourceMetadata,
+    Indexer,
+    IndexerConfig,
+    UploadContent,
+    Uploader,
+    UploaderConfig,
+    UploadStager,
+    UploadStagerConfig,
+    download_responses,
+)
+from unstructured_ingest.v2.logger import logger
+from unstructured_ingest.v2.processes.connector_registry import (
+    DestinationRegistryEntry,
+    SourceRegistryEntry,
+)
+
+if TYPE_CHECKING:
+    from couchbase.cluster import Cluster
+
+CONNECTOR_TYPE = "couchbase"
+SERVER_API_VERSION = "1"
+
+
+class CouchbaseAccessConfig(AccessConfig):
+    password: str = Field(description="The password for the Couchbase server")
+
+
+class CouchbaseConnectionConfig(ConnectionConfig):
+    username: str = Field(description="The username for the Couchbase server")
+    bucket: str = Field(description="The bucket to connect to on the Couchbase server")
+    connection_string: str = Field(
+        default="couchbase://localhost", description="The connection string of the Couchbase server"
+    )
+    scope: str = Field(
+        default="_default", description="The scope to connect to on the Couchbase server"
+    )
+    collection: str = Field(
+        default="_default", description="The collection to connect to on the Couchbase server"
+    )
+    connector_type: str = Field(default=CONNECTOR_TYPE, init=False)
+    access_config: Secret[CouchbaseAccessConfig]
+
+    @requires_dependencies(["couchbase"], extras="couchbase")
+    def connect_to_couchbase(self) -> "Cluster":
+        from couchbase.auth import PasswordAuthenticator
+        from couchbase.cluster import Cluster
+        from couchbase.options import ClusterOptions
+
+        auth = PasswordAuthenticator(self.username, self.access_config.get_secret_value().password)
+        options = ClusterOptions(auth)
+        options.apply_profile("wan_development")
+        cluster = Cluster(self.connection_string, options)
+        cluster.wait_until_ready(timedelta(seconds=5))
+        return cluster
+
+
+class CouchbaseUploadStagerConfig(UploadStagerConfig):
+    pass
+
+
+@dataclass
+class CouchbaseUploadStager(UploadStager):
+    upload_stager_config: CouchbaseUploadStagerConfig = field(
+        default_factory=lambda: CouchbaseUploadStagerConfig()
+    )
+
+    def run(
+        self,
+        elements_filepath: Path,
+        output_dir: Path,
+        output_filename: str,
+        **kwargs: Any,
+    ) -> Path:
+        with open(elements_filepath) as elements_file:
+            elements_contents = json.load(elements_file)
+
+        output_elements = []
+        for element in elements_contents:
+            new_doc = {
+                element["element_id"]: {
+                    "embedding": element.get("embeddings", None),
+                    "text": element.get("text", None),
+                    "metadata": element.get("metadata", None),
+                    "type": element.get("type", None),
+                }
+            }
+            output_elements.append(new_doc)
+
+        output_path = Path(output_dir) / Path(f"{output_filename}.json")
+        with open(output_path, "w") as output_file:
+            json.dump(output_elements, output_file)
+        return output_path
+
+
+class CouchbaseUploaderConfig(UploaderConfig):
+    batch_size: int = Field(default=50, description="Number of documents to upload per batch")
+
+
+@dataclass
+class CouchbaseUploader(Uploader):
+    connection_config: CouchbaseConnectionConfig
+    upload_config: CouchbaseUploaderConfig
+    connector_type: str = CONNECTOR_TYPE
+
+    def precheck(self) -> None:
+        try:
+            self.connection_config.connect_to_couchbase()
+        except Exception as e:
+            logger.error(f"Failed to validate connection {e}", exc_info=True)
+            raise DestinationConnectionError(f"failed to validate connection: {e}")
+
+    def run(self, contents: list[UploadContent], **kwargs: Any) -> None:
+        elements = []
+        for content in contents:
+            with open(content.path) as elements_file:
+                elements.extend(json.load(elements_file))
+
+        logger.info(
+            f"writing {len(elements)} objects to destination "
+            f"bucket, {self.connection_config.bucket} "
+            f"at {self.connection_config.connection_string}",
+        )
+        cluster = self.connection_config.connect_to_couchbase()
+        bucket = cluster.bucket(self.connection_config.bucket)
+        scope = bucket.scope(self.connection_config.scope)
+        collection = scope.collection(self.connection_config.collection)
+
+        for chunk in batch_generator(elements, self.upload_config.batch_size):
+            collection.upsert_multi({doc_id: doc for doc in chunk for doc_id, doc in doc.items()})
+
+
+class CouchbaseIndexerConfig(IndexerConfig):
+    batch_size: int = Field(default=50, description="Number of documents to index per batch")
+
+
+@dataclass
+class CouchbaseIndexer(Indexer):
+    connection_config: CouchbaseConnectionConfig
+    index_config: CouchbaseIndexerConfig
+    connector_type: str = CONNECTOR_TYPE
+
+    def precheck(self) -> None:
+        try:
+            self.connection_config.connect_to_couchbase()
+        except Exception as e:
+            logger.error(f"Failed to validate connection {e}", exc_info=True)
+            raise DestinationConnectionError(f"failed to validate connection: {e}")
+
+    @requires_dependencies(["couchbase"], extras="couchbase")
+    def _get_doc_ids(self) -> List[str]:
+        query = (
+            f"SELECT META(d).id "
+            f"FROM `{self.connection_config.bucket}`."
+            f"`{self.connection_config.scope}`."
+            f"`{self.connection_config.collection}` as d"
+        )
+
+        max_attempts = 5
+        attempts = 0
+        while attempts < max_attempts:
+            try:
+                cluster = self.connection_config.connect_to_couchbase()
+                result = cluster.query(query)
+                document_ids = [row["id"] for row in result]
+                return document_ids
+            except Exception as e:
+                attempts += 1
+                time.sleep(3)
+                if attempts == max_attempts:
+                    raise SourceConnectionError(f"failed to get document ids: {e}")
+
+    def run(self, **kwargs: Any) -> Generator[FileData, None, None]:
+        ids = self._get_doc_ids()
+
+        id_batches = [
+            ids[i * self.index_config.batch_size : (i + 1) * self.index_config.batch_size]
+            for i in range(
+                (len(ids) + self.index_config.batch_size - 1) // self.index_config.batch_size
+            )
+        ]
+        for batch in id_batches:
+            # Make sure the hash is always a positive number to create identified
+            identified = str(hash(tuple(batch)) + sys.maxsize + 1)
+            yield FileData(
+                identifier=identified,
+                connector_type=CONNECTOR_TYPE,
+                metadata=FileDataSourceMetadata(
+                    url=f"{self.connection_config.connection_string}/"
+                    f"{self.connection_config.bucket}",
+                    date_processed=str(time.time()),
+                ),
+                additional_metadata={
+                    "ids": list(batch),
+                    "bucket": self.connection_config.bucket,
+                },
+            )
+
+
+class CouchbaseDownloaderConfig(DownloaderConfig):
+    fields: list[str] = field(default_factory=list)
+
+
+@dataclass
+class CouchbaseDownloader(Downloader):
+    connection_config: CouchbaseConnectionConfig
+    download_config: CouchbaseDownloaderConfig
+    connector_type: str = CONNECTOR_TYPE
+
+    def is_async(self) -> bool:
+        return False
+
+    def get_identifier(self, bucket: str, record_id: str) -> str:
+        f = f"{bucket}-{record_id}"
+        if self.download_config.fields:
+            f = "{}-{}".format(
+                f,
+                hashlib.sha256(",".join(self.download_config.fields).encode()).hexdigest()[:8],
+            )
+        return f
+
+    def map_cb_results(self, cb_results: dict) -> str:
+        doc_body = cb_results
+        flattened_dict = flatten_dict(dictionary=doc_body)
+        str_values = [str(value) for value in flattened_dict.values()]
+        concatenated_values = "\n".join(str_values)
+        return concatenated_values
+
+    def generate_download_response(
+        self, result: dict, bucket: str, file_data: FileData
+    ) -> DownloadResponse:
+        record_id = result["id"]
+        filename_id = self.get_identifier(bucket=bucket, record_id=record_id)
+        filename = f"{filename_id}.txt"
+        download_path = self.download_dir / Path(filename)
+        logger.debug(
+            f"Downloading results from bucket {bucket} and id {record_id} to {download_path}"
+        )
+        download_path.parent.mkdir(parents=True, exist_ok=True)
+        try:
+            with open(download_path, "w", encoding="utf8") as f:
+                f.write(self.map_cb_results(cb_results=result))
+        except Exception as e:
+            logger.error(
+                f"failed to download from bucket {bucket} "
+                f"and id {record_id} to {download_path}: {e}",
+                exc_info=True,
+            )
+            raise SourceConnectionNetworkError(f"failed to download file {file_data.identifier}")
+        return DownloadResponse(
+            file_data=FileData(
+                identifier=filename_id,
+                connector_type=CONNECTOR_TYPE,
+                metadata=FileDataSourceMetadata(
+                    version=None,
+                    date_processed=str(time.time()),
+                    record_locator={
+                        "connection_string": self.connection_config.connection_string,
+                        "bucket": bucket,
+                        "scope": self.connection_config.scope,
+                        "collection": self.connection_config.collection,
+                        "document_id": record_id,
+                    },
+                ),
+            ),
+            path=download_path,
+        )
+
+    def run(self, file_data: FileData, **kwargs: Any) -> download_responses:
+        bucket_name: str = file_data.additional_metadata["bucket"]
+        ids: list[str] = file_data.additional_metadata["ids"]
+
+        cluster = self.connection_config.connect_to_couchbase()
+        bucket = cluster.bucket(bucket_name)
+        scope = bucket.scope(self.connection_config.scope)
+        collection = scope.collection(self.connection_config.collection)
+
+        download_resp = self.process_all_doc_ids(ids, collection, bucket_name, file_data)
+        return list(download_resp)
+
+    def process_doc_id(self, doc_id, collection, bucket_name, file_data):
+        result = collection.get(doc_id)
+        return self.generate_download_response(
+            result=result.content_as[dict], bucket=bucket_name, file_data=file_data
+        )
+
+    def process_all_doc_ids(self, ids, collection, bucket_name, file_data):
+        for doc_id in ids:
+            yield self.process_doc_id(doc_id, collection, bucket_name, file_data)
+
+    async def run_async(self, file_data: FileData, **kwargs: Any) -> download_responses:
+        raise NotImplementedError()
+
+
+couchbase_destination_entry = DestinationRegistryEntry(
+    connection_config=CouchbaseConnectionConfig,
+    uploader=CouchbaseUploader,
+    uploader_config=CouchbaseUploaderConfig,
+    upload_stager=CouchbaseUploadStager,
+    upload_stager_config=CouchbaseUploadStagerConfig,
+)
+
+couchbase_source_entry = SourceRegistryEntry(
+    connection_config=CouchbaseConnectionConfig,
+    indexer=CouchbaseIndexer,
+    indexer_config=CouchbaseIndexerConfig,
+    downloader=CouchbaseDownloader,
+    downloader_config=CouchbaseDownloaderConfig,
+)
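
The new couchbase.py module above wires a Couchbase source and destination into the v2 connector registry. For orientation, a minimal, hedged sketch of how the pieces fit together; the credentials and bucket name are hypothetical, and it assumes pydantic will coerce the plain access config into the Secret[...] field:

from unstructured_ingest.v2.processes.connectors.couchbase import (
    CouchbaseAccessConfig,
    CouchbaseConnectionConfig,
    CouchbaseUploader,
    CouchbaseUploaderConfig,
)

# Hypothetical cluster details for illustration only.
connection_config = CouchbaseConnectionConfig(
    username="Administrator",
    bucket="ingest-bucket",
    connection_string="couchbase://localhost",
    access_config=CouchbaseAccessConfig(password="example-password"),
)

uploader = CouchbaseUploader(
    connection_config=connection_config,
    upload_config=CouchbaseUploaderConfig(batch_size=50),
)
# Raises DestinationConnectionError if the cluster is unreachable.
uploader.precheck()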
unstructured_ingest/v2/processes/connectors/databricks_volumes.py
@@ -1,8 +1,9 @@
 import os
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Optional
 
-from unstructured_ingest.enhanced_dataclass import enhanced_field
+from pydantic import Field, Secret
+
 from unstructured_ingest.error import DestinationConnectionError
 from unstructured_ingest.utils.dep_check import requires_dependencies
 from unstructured_ingest.v2.interfaces import (
@@ -21,45 +22,99 @@ if TYPE_CHECKING:
 CONNECTOR_TYPE = "databricks_volumes"
 
 
-@dataclass
 class DatabricksVolumesAccessConfig(AccessConfig):
-    account_id: Optional[str] = None
-    username: Optional[str] = None
-    password: Optional[str] = None
-    client_id: Optional[str] = None
-    client_secret: Optional[str] = None
-    token: Optional[str] = None
+    account_id: Optional[str] = Field(
+        default=None,
+        description="The Databricks account ID for the Databricks "
+        "accounts endpoint. Only has effect when Host is "
+        "either https://accounts.cloud.databricks.com/ (AWS), "
+        "https://accounts.azuredatabricks.net/ (Azure), "
+        "or https://accounts.gcp.databricks.com/ (GCP).",
+    )
+    username: Optional[str] = Field(
+        default=None,
+        description="The Databricks username part of basic authentication. "
+        "Only possible when Host is *.cloud.databricks.com (AWS).",
+    )
+    password: Optional[str] = Field(
+        default=None,
+        description="The Databricks password part of basic authentication. "
+        "Only possible when Host is *.cloud.databricks.com (AWS).",
+    )
+    client_id: Optional[str] = Field(default=None)
+    client_secret: Optional[str] = Field(default=None)
+    token: Optional[str] = Field(
+        default=None,
+        description="The Databricks personal access token (PAT) (AWS, Azure, and GCP) or "
+        "Azure Active Directory (Azure AD) token (Azure).",
+    )
     profile: Optional[str] = None
-    azure_workspace_resource_id: Optional[str] = None
-    azure_client_secret: Optional[str] = None
-    azure_client_id: Optional[str] = None
-    azure_tenant_id: Optional[str] = None
-    azure_environment: Optional[str] = None
-    auth_type: Optional[str] = None
+    azure_workspace_resource_id: Optional[str] = Field(
+        default=None,
+        description="The Azure Resource Manager ID for the Azure Databricks workspace, "
+        "which is exchanged for a Databricks host URL.",
+    )
+    azure_client_secret: Optional[str] = Field(
+        default=None, description="The Azure AD service principal’s client secret."
+    )
+    azure_client_id: Optional[str] = Field(
+        default=None, description="The Azure AD service principal’s application ID."
+    )
+    azure_tenant_id: Optional[str] = Field(
+        default=None, description="The Azure AD service principal’s tenant ID."
+    )
+    azure_environment: Optional[str] = Field(
+        default=None,
+        description="The Azure environment type for a " "specific set of API endpoints",
+        examples=["Public", "UsGov", "China", "Germany"],
+    )
+    auth_type: Optional[str] = Field(
+        default=None,
+        description="When multiple auth attributes are available in the "
+        "environment, use the auth type specified by this "
+        "argument. This argument also holds the currently "
+        "selected auth.",
+    )
     cluster_id: Optional[str] = None
     google_credentials: Optional[str] = None
     google_service_account: Optional[str] = None
 
 
-@dataclass
+SecretDatabricksVolumesAccessConfig = Secret[DatabricksVolumesAccessConfig]
+
+
 class DatabricksVolumesConnectionConfig(ConnectionConfig):
-    access_config: DatabricksVolumesAccessConfig = enhanced_field(
-        default_factory=DatabricksVolumesAccessConfig, sensitive=True
+    access_config: SecretDatabricksVolumesAccessConfig = Field(
+        default_factory=lambda: SecretDatabricksVolumesAccessConfig(
+            secret_value=DatabricksVolumesAccessConfig()
+        )
+    )
+    host: Optional[str] = Field(
+        default=None,
+        description="The Databricks host URL for either the "
+        "Databricks workspace endpoint or the "
+        "Databricks accounts endpoint.",
    )
-    host: Optional[str] = None
 
 
-@dataclass
 class DatabricksVolumesUploaderConfig(UploaderConfig):
-    volume: str
-    catalog: str
-    volume_path: Optional[str] = None
-    overwrite: bool = False
-    schema: str = "default"
+    volume: str = Field(description="Name of volume in the Unity Catalog")
+    catalog: str = Field(description="Name of the catalog in the Databricks Unity Catalog service")
+    volume_path: Optional[str] = Field(
+        default=None, description="Optional path within the volume to write to"
+    )
+    overwrite: bool = Field(
+        default=False, description="If true, an existing file will be overwritten."
+    )
+    databricks_schema: str = Field(
+        default="default",
+        alias="schema",
+        description="Schema associated with the volume to write to in the Unity Catalog service",
+    )
 
     @property
     def path(self) -> str:
-        path = f"/Volumes/{self.catalog}/{self.schema}/{self.volume}"
+        path = f"/Volumes/{self.catalog}/{self.databricks_schema}/{self.volume}"
         if self.volume_path:
             path = f"{path}/{self.volume_path}"
         return path
@@ -70,19 +125,19 @@ class DatabricksVolumesUploader(Uploader):
     connector_type: str = CONNECTOR_TYPE
     upload_config: DatabricksVolumesUploaderConfig
     connection_config: DatabricksVolumesConnectionConfig
-    client: Optional["WorkspaceClient"] = field(init=False, default=None)
 
     @requires_dependencies(dependencies=["databricks.sdk"], extras="databricks-volumes")
-    def __post_init__(self) -> "WorkspaceClient":
+    def get_client(self) -> "WorkspaceClient":
         from databricks.sdk import WorkspaceClient
 
-        self.client = WorkspaceClient(
-            host=self.connection_config.host, **self.connection_config.access_config.to_dict()
+        return WorkspaceClient(
+            host=self.connection_config.host,
+            **self.connection_config.access_config.get_secret_value().dict(),
         )
 
     def precheck(self) -> None:
         try:
-            assert self.client.current_user.me().active
+            assert self.get_client().current_user.me().active
         except Exception as e:
             logger.error(f"failed to validate connection: {e}", exc_info=True)
             raise DestinationConnectionError(f"failed to validate connection: {e}")
@@ -91,7 +146,7 @@ class DatabricksVolumesUploader(Uploader):
         for content in contents:
             with open(content.path, "rb") as elements_file:
                 output_path = os.path.join(self.upload_config.path, content.path.name)
-                self.client.files.upload(
+                self.get_client().files.upload(
                     file_path=output_path,
                     contents=elements_file,
                     overwrite=self.upload_config.overwrite,
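
The most visible behavioral change in databricks_volumes.py above is that the schema option is now stored as databricks_schema with a pydantic alias of "schema", and the path property reads the new attribute. A rough sketch of what that implies for callers, assuming the config is validated by alias (pydantic's default) and using hypothetical catalog/volume names:

from unstructured_ingest.v2.processes.connectors.databricks_volumes import (
    DatabricksVolumesUploaderConfig,
)

config = DatabricksVolumesUploaderConfig(
    volume="ingest-volume",  # hypothetical volume name
    catalog="main",          # hypothetical catalog name
    schema="default",        # still passed as "schema" via the alias
)
print(config.databricks_schema)  # -> "default"
print(config.path)               # -> "/Volumes/main/default/ingest-volume"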
unstructured_ingest/v2/processes/connectors/elasticsearch.py
@@ -5,9 +5,10 @@ import uuid
 from dataclasses import dataclass, field
 from pathlib import Path
 from time import time
-from typing import TYPE_CHECKING, Any, Generator, Optional
+from typing import TYPE_CHECKING, Any, Generator, Optional, Union
+
+from pydantic import BaseModel, Field, Secret, SecretStr
 
-from unstructured_ingest.enhanced_dataclass import EnhancedDataClassJsonMixin, enhanced_field
 from unstructured_ingest.error import (
     DestinationConnectionError,
     SourceConnectionError,
@@ -44,57 +45,74 @@ if TYPE_CHECKING:
 CONNECTOR_TYPE = "elasticsearch"
 
 
-@dataclass
 class ElasticsearchAccessConfig(AccessConfig):
-    password: Optional[str] = None
-    api_key: Optional[str] = enhanced_field(default=None, overload_name="es_api_key")
-    bearer_auth: Optional[str] = None
-    ssl_assert_fingerprint: Optional[str] = None
-
-
-@dataclass
-class ElasticsearchClientInput(EnhancedDataClassJsonMixin):
+    password: Optional[str] = Field(
+        default=None, description="password when using basic auth or connecting to a cloud instance"
+    )
+    es_api_key: Optional[str] = Field(default=None, description="api key used for authentication")
+    bearer_auth: Optional[str] = Field(
+        default=None, description="bearer token used for HTTP bearer authentication"
+    )
+    ssl_assert_fingerprint: Optional[str] = Field(
+        default=None, description="SHA256 fingerprint value"
+    )
+
+
+class ElasticsearchClientInput(BaseModel):
     hosts: Optional[list[str]] = None
     cloud_id: Optional[str] = None
-    ca_certs: Optional[str] = None
-    basic_auth: Optional[tuple[str, str]] = enhanced_field(sensitive=True, default=None)
-    api_key: Optional[str] = enhanced_field(sensitive=True, default=None)
+    ca_certs: Optional[Path] = None
+    basic_auth: Optional[Secret[tuple[str, str]]] = None
+    api_key: Optional[Union[Secret[tuple[str, str]], SecretStr]] = None
 
 
-@dataclass
 class ElasticsearchConnectionConfig(ConnectionConfig):
-    hosts: Optional[list[str]] = None
-    username: Optional[str] = None
-    cloud_id: Optional[str] = None
-    api_key_id: Optional[str] = None
-    ca_certs: Optional[str] = None
-    access_config: ElasticsearchAccessConfig = enhanced_field(sensitive=True)
+    hosts: Optional[list[str]] = Field(
+        default=None,
+        description="list of the Elasticsearch hosts to connect to",
+        examples=["http://localhost:9200"],
+    )
+    username: Optional[str] = Field(default=None, description="username when using basic auth")
+    cloud_id: Optional[str] = Field(default=None, description="id used to connect to Elastic Cloud")
+    api_key_id: Optional[str] = Field(
+        default=None,
+        description="id associated with api key used for authentication: "
+        "https://www.elastic.co/guide/en/elasticsearch/reference/current/security-api-create-api-key.html",  # noqa: E501
+    )
+    ca_certs: Optional[Path] = None
+    access_config: Secret[ElasticsearchAccessConfig]
 
     def get_client_kwargs(self) -> dict:
         # Update auth related fields to conform to what the SDK expects based on the
         # supported methods:
         # https://www.elastic.co/guide/en/elasticsearch/client/python-api/current/connecting.html
-        client_input = ElasticsearchClientInput()
+        client_input_kwargs: dict[str, Any] = {}
+        access_config = self.access_config.get_secret_value()
         if self.hosts:
-            client_input.hosts = self.hosts
+            client_input_kwargs["hosts"] = self.hosts
         if self.cloud_id:
-            client_input.cloud_id = self.cloud_id
+            client_input_kwargs["cloud_id"] = self.cloud_id
         if self.ca_certs:
-            client_input.ca_certs = self.ca_certs
-        if self.access_config.password and (
-            self.cloud_id or self.ca_certs or self.access_config.ssl_assert_fingerprint
+            client_input_kwargs["ca_certs"] = self.ca_certs
+        if access_config.password and (
+            self.cloud_id or self.ca_certs or access_config.ssl_assert_fingerprint
         ):
-            client_input.basic_auth = ("elastic", self.access_config.password)
-        elif not self.cloud_id and self.username and self.access_config.password:
-            client_input.basic_auth = (self.username, self.access_config.password)
-        elif self.access_config.api_key and self.api_key_id:
-            client_input.api_key = (self.api_key_id, self.access_config.api_key)
-        elif self.access_config.api_key:
-            client_input.api_key = self.access_config.api_key
-        logger.debug(
-            f"Elasticsearch client inputs mapped to: {client_input.to_dict(redact_sensitive=True)}"
+            client_input_kwargs["basic_auth"] = ("elastic", access_config.password)
+        elif not self.cloud_id and self.username and access_config.password:
+            client_input_kwargs["basic_auth"] = (self.username, access_config.password)
+        elif access_config.es_api_key and self.api_key_id:
+            client_input_kwargs["api_key"] = (self.api_key_id, access_config.es_api_key)
+        elif access_config.es_api_key:
+            client_input_kwargs["api_key"] = access_config.es_api_key
+        client_input = ElasticsearchClientInput(**client_input_kwargs)
+        logger.debug(f"Elasticsearch client inputs mapped to: {client_input.dict()}")
+        client_kwargs = client_input.dict()
+        client_kwargs["basic_auth"] = (
+            client_input.basic_auth.get_secret_value() if client_input.basic_auth else None
+        )
+        client_kwargs["api_key"] = (
+            client_input.api_key.get_secret_value() if client_input.api_key else None
         )
-        client_kwargs = client_input.to_dict(redact_sensitive=False)
         client_kwargs = {k: v for k, v in client_kwargs.items() if v is not None}
         return client_kwargs
 
@@ -114,7 +132,6 @@ class ElasticsearchConnectionConfig(ConnectionConfig):
         raise SourceConnectionError(f"failed to validate connection: {e}")
 
 
-@dataclass
 class ElasticsearchIndexerConfig(IndexerConfig):
     index_name: str
     batch_size: int = 100
@@ -186,7 +203,6 @@ class ElasticsearchIndexer(Indexer):
     )
 
 
-@dataclass
 class ElasticsearchDownloaderConfig(DownloaderConfig):
     fields: list[str] = field(default_factory=list)
 
@@ -292,9 +308,10 @@ class ElasticsearchDownloader(Downloader):
         return download_responses
 
 
-@dataclass
 class ElasticsearchUploadStagerConfig(UploadStagerConfig):
-    index_name: str
+    index_name: str = Field(
+        description="Name of the Elasticsearch index to pull data from, or upload data to."
+    )
 
 
 @dataclass
@@ -333,11 +350,19 @@ class ElasticsearchUploadStager(UploadStager):
         return output_path
 
 
-@dataclass
 class ElasticsearchUploaderConfig(UploaderConfig):
-    index_name: str
-    batch_size_bytes: int = 15_000_000
-    num_threads: int = 4
+    index_name: str = Field(
+        description="Name of the Elasticsearch index to pull data from, or upload data to."
+    )
+    batch_size_bytes: int = Field(
+        default=15_000_000,
+        description="Size limit (in bytes) for each batch of items to be uploaded. Check"
+        " https://www.elastic.co/guide/en/elasticsearch/guide/current/bulk.html"
+        "#_how_big_is_too_big for more information.",
+    )
+    num_threads: int = Field(
+        default=4, description="Number of threads to be used while uploading content"
+    )
 
 
 @dataclass
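
The elasticsearch.py changes above move credentials behind Secret[ElasticsearchAccessConfig], rename the api_key field to es_api_key, and have get_client_kwargs() unwrap the secrets before the kwargs reach the SDK. A hedged sketch of the expected mapping for plain basic auth, with hypothetical credentials and the same assumption that pydantic coerces the inner access config into the Secret wrapper:

from unstructured_ingest.v2.processes.connectors.elasticsearch import (
    ElasticsearchAccessConfig,
    ElasticsearchConnectionConfig,
)

config = ElasticsearchConnectionConfig(
    hosts=["http://localhost:9200"],
    username="elastic",
    access_config=ElasticsearchAccessConfig(password="example-password"),
)

# With no cloud_id, ca_certs, or ssl_assert_fingerprint set, the username/password
# branch should apply, so the expected result (assumption) is:
# {"hosts": ["http://localhost:9200"], "basic_auth": ("elastic", "example-password")}
print(config.get_client_kwargs())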