unstructured-ingest 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.

Potentially problematic release: this version of unstructured-ingest might be problematic.

Files changed (51)
  1. test/integration/connectors/elasticsearch/__init__.py +0 -0
  2. test/integration/connectors/elasticsearch/conftest.py +34 -0
  3. test/integration/connectors/elasticsearch/test_elasticsearch.py +308 -0
  4. test/integration/connectors/elasticsearch/test_opensearch.py +302 -0
  5. test/integration/connectors/sql/test_postgres.py +10 -4
  6. test/integration/connectors/sql/test_singlestore.py +8 -4
  7. test/integration/connectors/sql/test_snowflake.py +10 -6
  8. test/integration/connectors/sql/test_sqlite.py +4 -4
  9. test/integration/connectors/test_astradb.py +50 -3
  10. test/integration/connectors/test_delta_table.py +46 -0
  11. test/integration/connectors/test_kafka.py +40 -6
  12. test/integration/connectors/test_lancedb.py +209 -0
  13. test/integration/connectors/test_milvus.py +141 -0
  14. test/integration/connectors/test_pinecone.py +53 -1
  15. test/integration/connectors/utils/docker.py +81 -15
  16. test/integration/connectors/utils/validation.py +10 -0
  17. test/integration/connectors/weaviate/__init__.py +0 -0
  18. test/integration/connectors/weaviate/conftest.py +15 -0
  19. test/integration/connectors/weaviate/test_local.py +131 -0
  20. unstructured_ingest/__version__.py +1 -1
  21. unstructured_ingest/pipeline/reformat/embedding.py +1 -1
  22. unstructured_ingest/utils/data_prep.py +9 -1
  23. unstructured_ingest/v2/processes/connectors/__init__.py +3 -16
  24. unstructured_ingest/v2/processes/connectors/astradb.py +2 -2
  25. unstructured_ingest/v2/processes/connectors/azure_ai_search.py +4 -0
  26. unstructured_ingest/v2/processes/connectors/delta_table.py +20 -4
  27. unstructured_ingest/v2/processes/connectors/elasticsearch/__init__.py +19 -0
  28. unstructured_ingest/v2/processes/connectors/{elasticsearch.py → elasticsearch/elasticsearch.py} +92 -46
  29. unstructured_ingest/v2/processes/connectors/{opensearch.py → elasticsearch/opensearch.py} +1 -1
  30. unstructured_ingest/v2/processes/connectors/kafka/kafka.py +6 -0
  31. unstructured_ingest/v2/processes/connectors/lancedb/__init__.py +17 -0
  32. unstructured_ingest/v2/processes/connectors/lancedb/aws.py +43 -0
  33. unstructured_ingest/v2/processes/connectors/lancedb/azure.py +43 -0
  34. unstructured_ingest/v2/processes/connectors/lancedb/gcp.py +44 -0
  35. unstructured_ingest/v2/processes/connectors/lancedb/lancedb.py +161 -0
  36. unstructured_ingest/v2/processes/connectors/lancedb/local.py +44 -0
  37. unstructured_ingest/v2/processes/connectors/milvus.py +72 -27
  38. unstructured_ingest/v2/processes/connectors/pinecone.py +24 -7
  39. unstructured_ingest/v2/processes/connectors/sql/sql.py +97 -26
  40. unstructured_ingest/v2/processes/connectors/weaviate/__init__.py +22 -0
  41. unstructured_ingest/v2/processes/connectors/weaviate/cloud.py +164 -0
  42. unstructured_ingest/v2/processes/connectors/weaviate/embedded.py +90 -0
  43. unstructured_ingest/v2/processes/connectors/weaviate/local.py +73 -0
  44. unstructured_ingest/v2/processes/connectors/weaviate/weaviate.py +289 -0
  45. {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.1.dist-info}/METADATA +15 -15
  46. {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.1.dist-info}/RECORD +50 -30
  47. unstructured_ingest/v2/processes/connectors/weaviate.py +0 -242
  48. {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.1.dist-info}/LICENSE.md +0 -0
  49. {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.1.dist-info}/WHEEL +0 -0
  50. {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.1.dist-info}/entry_points.txt +0 -0
  51. {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.1.dist-info}/top_level.txt +0 -0
--- a/unstructured_ingest/v2/processes/connectors/milvus.py
+++ b/unstructured_ingest/v2/processes/connectors/milvus.py
@@ -1,7 +1,8 @@
 import json
+from contextlib import contextmanager
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Generator, Optional, Union
 
 import pandas as pd
 from dateutil import parser
@@ -10,6 +11,7 @@ from pydantic import Field, Secret
 from unstructured_ingest.error import WriteError
 from unstructured_ingest.utils.data_prep import flatten_dict
 from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.v2.constants import RECORD_ID_LABEL
 from unstructured_ingest.v2.interfaces import (
     AccessConfig,
     ConnectionConfig,
@@ -90,24 +92,27 @@ class MilvusUploadStager(UploadStager):
             pass
         return parser.parse(date_string).timestamp()
 
-    def conform_dict(self, data: dict) -> None:
-        if self.upload_stager_config.flatten_metadata and (metadata := data.pop("metadata", None)):
-            data.update(flatten_dict(metadata, keys_to_omit=["data_source_record_locator"]))
+    def conform_dict(self, data: dict, file_data: FileData) -> dict:
+        working_data = data.copy()
+        if self.upload_stager_config.flatten_metadata and (
+            metadata := working_data.pop("metadata", None)
+        ):
+            working_data.update(flatten_dict(metadata, keys_to_omit=["data_source_record_locator"]))
 
         # TODO: milvus sdk doesn't seem to support defaults via the schema yet,
         # remove once that gets updated
         defaults = {"is_continuation": False}
         for default in defaults:
-            if default not in data:
-                data[default] = defaults[default]
+            if default not in working_data:
+                working_data[default] = defaults[default]
 
         if self.upload_stager_config.fields_to_include:
-            data_keys = set(data.keys())
+            data_keys = set(working_data.keys())
             for data_key in data_keys:
                 if data_key not in self.upload_stager_config.fields_to_include:
-                    data.pop(data_key)
+                    working_data.pop(data_key)
             for field_include_key in self.upload_stager_config.fields_to_include:
-                if field_include_key not in data:
+                if field_include_key not in working_data:
                     raise KeyError(f"Field '{field_include_key}' is missing in data!")
 
         datetime_columns = [
@@ -120,11 +125,15 @@ class MilvusUploadStager(UploadStager):
         json_dumps_fields = ["languages", "data_source_permissions_data"]
 
         for datetime_column in datetime_columns:
-            if datetime_column in data:
-                data[datetime_column] = self.parse_date_string(data[datetime_column])
+            if datetime_column in working_data:
+                working_data[datetime_column] = self.parse_date_string(
+                    working_data[datetime_column]
+                )
         for json_dumps_field in json_dumps_fields:
-            if json_dumps_field in data:
-                data[json_dumps_field] = json.dumps(data[json_dumps_field])
+            if json_dumps_field in working_data:
+                working_data[json_dumps_field] = json.dumps(working_data[json_dumps_field])
+        working_data[RECORD_ID_LABEL] = file_data.identifier
+        return working_data
 
     def run(
         self,
@@ -136,18 +145,27 @@ class MilvusUploadStager(UploadStager):
     ) -> Path:
         with open(elements_filepath) as elements_file:
            elements_contents: list[dict[str, Any]] = json.load(elements_file)
-        for element in elements_contents:
-            self.conform_dict(data=element)
-
-        output_path = Path(output_dir) / Path(f"{output_filename}.json")
+        new_content = [
+            self.conform_dict(data=element, file_data=file_data) for element in elements_contents
+        ]
+        output_filename_path = Path(output_filename)
+        if output_filename_path.suffix == ".json":
+            output_path = Path(output_dir) / output_filename_path
+        else:
+            output_path = Path(output_dir) / output_filename_path.with_suffix(".json")
         output_path.parent.mkdir(parents=True, exist_ok=True)
         with output_path.open("w") as output_file:
-            json.dump(elements_contents, output_file, indent=2)
+            json.dump(new_content, output_file, indent=2)
         return output_path
 
 
 class MilvusUploaderConfig(UploaderConfig):
+    db_name: Optional[str] = Field(default=None, description="Milvus database name")
     collection_name: str = Field(description="Milvus collections to write to")
+    record_id_key: str = Field(
+        default=RECORD_ID_LABEL,
+        description="searchable key to find entries for the same record on previous runs",
+    )
 
 
 @dataclass
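For reference, the options added above are plain config fields; a minimal sketch of wiring them up (the collection and database names here are placeholders, not values taken from this release):

from unstructured_ingest.v2.processes.connectors.milvus import MilvusUploaderConfig

# db_name and record_id_key are the fields added in this release; record_id_key
# defaults to RECORD_ID_LABEL, so it only needs overriding for custom schemas.
upload_config = MilvusUploaderConfig(
    collection_name="elements",   # placeholder collection name
    db_name="my_database",        # optional: selected via client.using_database() before writes
)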
@@ -156,6 +174,16 @@ class MilvusUploader(Uploader):
     upload_config: MilvusUploaderConfig
     connector_type: str = CONNECTOR_TYPE
 
+    @contextmanager
+    def get_client(self) -> Generator["MilvusClient", None, None]:
+        client = self.connection_config.get_client()
+        if db_name := self.upload_config.db_name:
+            client.using_database(db_name=db_name)
+        try:
+            yield client
+        finally:
+            client.close()
+
     def upload(self, content: UploadContent) -> None:
         file_extension = content.path.suffix
         if file_extension == ".json":
@@ -165,23 +193,39 @@ class MilvusUploader(Uploader):
         else:
             raise ValueError(f"Unsupported file extension: {file_extension}")
 
+    def delete_by_record_id(self, file_data: FileData) -> None:
+        logger.info(
+            f"deleting any content with metadata {RECORD_ID_LABEL}={file_data.identifier} "
+            f"from milvus collection {self.upload_config.collection_name}"
+        )
+        with self.get_client() as client:
+            delete_filter = f'{self.upload_config.record_id_key} == "{file_data.identifier}"'
+            resp = client.delete(
+                collection_name=self.upload_config.collection_name, filter=delete_filter
+            )
+            logger.info(
+                "deleted {} records from milvus collection {}".format(
+                    resp["delete_count"], self.upload_config.collection_name
+                )
+            )
+
     @requires_dependencies(["pymilvus"], extras="milvus")
     def insert_results(self, data: Union[dict, list[dict]]):
         from pymilvus import MilvusException
 
-        logger.debug(
+        logger.info(
             f"uploading {len(data)} entries to {self.connection_config.db_name} "
             f"db in collection {self.upload_config.collection_name}"
         )
-        client = self.connection_config.get_client()
+        with self.get_client() as client:
 
-        try:
-            res = client.insert(collection_name=self.upload_config.collection_name, data=data)
-        except MilvusException as milvus_exception:
-            raise WriteError("failed to upload records to milvus") from milvus_exception
-        if "err_count" in res and isinstance(res["err_count"], int) and res["err_count"] > 0:
-            err_count = res["err_count"]
-            raise WriteError(f"failed to upload {err_count} docs")
+            try:
+                res = client.insert(collection_name=self.upload_config.collection_name, data=data)
+            except MilvusException as milvus_exception:
+                raise WriteError("failed to upload records to milvus") from milvus_exception
+            if "err_count" in res and isinstance(res["err_count"], int) and res["err_count"] > 0:
+                err_count = res["err_count"]
+                raise WriteError(f"failed to upload {err_count} docs")
 
     def upload_csv(self, content: UploadContent) -> None:
         df = pd.read_csv(content.path)
@@ -194,6 +238,7 @@ class MilvusUploader(Uploader):
         self.insert_results(data=data)
 
     def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
+        self.delete_by_record_id(file_data=file_data)
         self.upload(content=UploadContent(path=path, file_data=file_data))
 
 
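Taken together, the Milvus uploader now behaves as an upsert keyed on the file's identifier: run() first deletes rows whose record-id field matches, then inserts the newly staged rows. A standalone sketch of the filter it builds (the concrete value of RECORD_ID_LABEL is assumed here, not shown in this diff):

# Hypothetical values for illustration only.
record_id_key = "record_id"        # assumes RECORD_ID_LABEL resolves to "record_id"
identifier = "example-file-id"     # FileData.identifier of the document being written
delete_filter = f'{record_id_key} == "{identifier}"'
# delete_by_record_id() passes this Milvus boolean expression to
# client.delete(collection_name=..., filter=delete_filter) before insert_results() runs.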
--- a/unstructured_ingest/v2/processes/connectors/pinecone.py
+++ b/unstructured_ingest/v2/processes/connectors/pinecone.py
@@ -30,6 +30,7 @@ if TYPE_CHECKING:
 CONNECTOR_TYPE = "pinecone"
 MAX_PAYLOAD_SIZE = 2 * 1024 * 1024  # 2MB
 MAX_POOL_THREADS = 100
+MAX_METADATA_BYTES = 40960  # 40KB https://docs.pinecone.io/reference/quotas-and-limits#hard-limits
 
 
 class PineconeAccessConfig(AccessConfig):
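The constant added above backs a stager-side check introduced further down in this file: metadata is measured as serialized JSON and dropped entirely if it exceeds 40960 bytes (40 * 1024). A self-contained sketch of that check (the helper name is illustrative, not from the source):

import json

MAX_METADATA_BYTES = 40960  # 40 * 1024 bytes, per the Pinecone hard limit referenced above

def metadata_within_limit(metadata: dict) -> bool:
    # Size is computed on the JSON-serialized, UTF-8 encoded form, matching the stager's check.
    return len(json.dumps(metadata).encode("utf-8")) <= MAX_METADATA_BYTES

print(metadata_within_limit({"text": "x" * 50_000}))  # False: this metadata would be dropped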
@@ -103,6 +104,10 @@ class PineconeUploaderConfig(UploaderConfig):
         default=None,
         description="The namespace to write to. If not specified, the default namespace is used",
     )
+    record_id_key: str = Field(
+        default=RECORD_ID_LABEL,
+        description="searchable key to find entries for the same record on previous runs",
+    )
 
 
 @dataclass
@@ -133,6 +138,13 @@ class PineconeUploadStager(UploadStager):
             remove_none=True,
         )
         metadata[RECORD_ID_LABEL] = file_data.identifier
+        metadata_size_bytes = len(json.dumps(metadata).encode())
+        if metadata_size_bytes > MAX_METADATA_BYTES:
+            logger.info(
+                f"Metadata size is {metadata_size_bytes} bytes, which exceeds the limit of"
+                f" {MAX_METADATA_BYTES} bytes per vector. Dropping the metadata."
+            )
+            metadata = {}
 
         return {
             "id": str(uuid.uuid4()),
@@ -183,23 +195,28 @@
 
     def pod_delete_by_record_id(self, file_data: FileData) -> None:
         logger.debug(
-            f"deleting any content with metadata {RECORD_ID_LABEL}={file_data.identifier} "
+            f"deleting any content with metadata "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
             f"from pinecone pod index"
         )
         index = self.connection_config.get_index(pool_threads=MAX_POOL_THREADS)
-        delete_kwargs = {"filter": {RECORD_ID_LABEL: {"$eq": file_data.identifier}}}
+        delete_kwargs = {
+            "filter": {self.upload_config.record_id_key: {"$eq": file_data.identifier}}
+        }
         if namespace := self.upload_config.namespace:
             delete_kwargs["namespace"] = namespace
 
         resp = index.delete(**delete_kwargs)
         logger.debug(
-            f"deleted any content with metadata {RECORD_ID_LABEL}={file_data.identifier} "
+            f"deleted any content with metadata "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
             f"from pinecone index: {resp}"
         )
 
     def serverless_delete_by_record_id(self, file_data: FileData) -> None:
         logger.debug(
-            f"deleting any content with metadata {RECORD_ID_LABEL}={file_data.identifier} "
+            f"deleting any content with metadata "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
             f"from pinecone serverless index"
         )
         index = self.connection_config.get_index(pool_threads=MAX_POOL_THREADS)
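Both delete paths now filter on the configurable record_id_key instead of the fixed RECORD_ID_LABEL constant. A sketch of the metadata filter they construct (values are placeholders; the default key is assumed to resolve to "record_id"):

# Hypothetical values for illustration only.
record_id_key = "record_id"
identifier = "example-file-id"
metadata_filter = {record_id_key: {"$eq": identifier}}
# Pod-based indexes accept this directly via index.delete(filter=...); the serverless
# path instead queries matching vectors with the same filter and deletes the returned
# ids, as shown in the following hunks.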
@@ -209,7 +226,7 @@ class PineconeUploader(Uploader):
             return
         dimension = index_stats["dimension"]
         query_params = {
-            "filter": {RECORD_ID_LABEL: {"$eq": file_data.identifier}},
+            "filter": {self.upload_config.record_id_key: {"$eq": file_data.identifier}},
             "vector": [0] * dimension,
             "top_k": total_vectors,
         }
@@ -226,7 +243,8 @@ class PineconeUploader(Uploader):
             delete_params["namespace"] = namespace
         index.delete(**delete_params)
         logger.debug(
-            f"deleted any content with metadata {RECORD_ID_LABEL}={file_data.identifier} "
+            f"deleted any content with metadata "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
             f"from pinecone index"
         )
 
@@ -269,7 +287,6 @@ class PineconeUploader(Uploader):
             f"writing a total of {len(elements_dict)} elements via"
             f" document batches to destination"
             f" index named {self.connection_config.index_name}"
-            f" with batch size {self.upload_config.batch_size}"
         )
         # Determine if serverless or pod based index
         pinecone_client = self.connection_config.get_client()
--- a/unstructured_ingest/v2/processes/connectors/sql/sql.py
+++ b/unstructured_ingest/v2/processes/connectors/sql/sql.py
@@ -16,6 +16,8 @@ from dateutil import parser
 from pydantic import Field, Secret
 
 from unstructured_ingest.error import DestinationConnectionError, SourceConnectionError
+from unstructured_ingest.utils.data_prep import split_dataframe
+from unstructured_ingest.v2.constants import RECORD_ID_LABEL
 from unstructured_ingest.v2.interfaces import (
     AccessConfig,
     ConnectionConfig,
@@ -236,35 +238,25 @@ class SQLUploadStagerConfig(UploadStagerConfig):
 class SQLUploadStager(UploadStager):
     upload_stager_config: SQLUploadStagerConfig = field(default_factory=SQLUploadStagerConfig)
 
-    def run(
-        self,
-        elements_filepath: Path,
-        file_data: FileData,
-        output_dir: Path,
-        output_filename: str,
-        **kwargs: Any,
-    ) -> Path:
-        with open(elements_filepath) as elements_file:
-            elements_contents: list[dict] = json.load(elements_file)
-        output_path = Path(output_dir) / Path(f"{output_filename}.json")
-        output_path.parent.mkdir(parents=True, exist_ok=True)
-
+    @staticmethod
+    def conform_dict(data: dict, file_data: FileData) -> pd.DataFrame:
+        working_data = data.copy()
         output = []
-        for data in elements_contents:
-            metadata: dict[str, Any] = data.pop("metadata", {})
+        for element in working_data:
+            metadata: dict[str, Any] = element.pop("metadata", {})
             data_source = metadata.pop("data_source", {})
             coordinates = metadata.pop("coordinates", {})
 
-            data.update(metadata)
-            data.update(data_source)
-            data.update(coordinates)
+            element.update(metadata)
+            element.update(data_source)
+            element.update(coordinates)
 
-            data["id"] = str(uuid.uuid4())
+            element["id"] = str(uuid.uuid4())
 
             # remove extraneous, not supported columns
-            data = {k: v for k, v in data.items() if k in _COLUMNS}
-
-            output.append(data)
+            element = {k: v for k, v in element.items() if k in _COLUMNS}
+            element[RECORD_ID_LABEL] = file_data.identifier
+            output.append(element)
 
         df = pd.DataFrame.from_dict(output)
         for column in filter(lambda x: x in df.columns, _DATE_COLUMNS):
@@ -281,6 +273,26 @@ class SQLUploadStager(UploadStager):
             ("version", "page_number", "regex_metadata"),
         ):
             df[column] = df[column].apply(str)
+        return df
+
+    def run(
+        self,
+        elements_filepath: Path,
+        file_data: FileData,
+        output_dir: Path,
+        output_filename: str,
+        **kwargs: Any,
+    ) -> Path:
+        with open(elements_filepath) as elements_file:
+            elements_contents: list[dict] = json.load(elements_file)
+
+        df = self.conform_dict(data=elements_contents, file_data=file_data)
+        if Path(output_filename).suffix != ".json":
+            output_filename = f"{output_filename}.json"
+        else:
+            output_filename = f"{Path(output_filename).stem}.json"
+        output_path = Path(output_dir) / Path(f"{output_filename}")
+        output_path.parent.mkdir(parents=True, exist_ok=True)
 
         with output_path.open("w") as output_file:
             df.to_json(output_file, orient="records", lines=True)
@@ -290,6 +302,10 @@ class SQLUploadStager(UploadStager):
 class SQLUploaderConfig(UploaderConfig):
     batch_size: int = Field(default=50, description="Number of records per batch")
     table_name: str = Field(default="elements", description="which table to upload contents to")
+    record_id_key: str = Field(
+        default=RECORD_ID_LABEL,
+        description="searchable key to find entries for the same record on previous runs",
+    )
 
 
 @dataclass
@@ -323,18 +339,45 @@ class SQLUploader(Uploader):
             output.append(tuple(parsed))
         return output
 
+    def _fit_to_schema(self, df: pd.DataFrame, columns: list[str]) -> pd.DataFrame:
+        columns = set(df.columns)
+        schema_fields = set(columns)
+        columns_to_drop = columns - schema_fields
+        missing_columns = schema_fields - columns
+
+        if columns_to_drop:
+            logger.warning(
+                "Following columns will be dropped to match the table's schema: "
+                f"{', '.join(columns_to_drop)}"
+            )
+        if missing_columns:
+            logger.info(
+                "Following null filled columns will be added to match the table's schema:"
+                f" {', '.join(missing_columns)} "
+            )
+
+        df = df.drop(columns=columns_to_drop)
+
+        for column in missing_columns:
+            df[column] = pd.Series()
+
     def upload_contents(self, path: Path) -> None:
         df = pd.read_json(path, orient="records", lines=True)
         df.replace({np.nan: None}, inplace=True)
+        self._fit_to_schema(df=df, columns=self.get_table_columns())
 
         columns = list(df.columns)
         stmt = f"INSERT INTO {self.upload_config.table_name} ({','.join(columns)}) VALUES({','.join([self.values_delimiter for x in columns])})"  # noqa E501
-
-        for rows in pd.read_json(
-            path, orient="records", lines=True, chunksize=self.upload_config.batch_size
-        ):
+        logger.info(
+            f"writing a total of {len(df)} elements via"
+            f" document batches to destination"
+            f" table named {self.upload_config.table_name}"
+            f" with batch size {self.upload_config.batch_size}"
+        )
+        for rows in split_dataframe(df=df, chunk_size=self.upload_config.batch_size):
            with self.connection_config.get_cursor() as cursor:
                values = self.prepare_data(columns, tuple(rows.itertuples(index=False, name=None)))
+                # For debugging purposes:
                # for val in values:
                #     try:
                #         cursor.execute(stmt, val)
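Batching now happens on the already-loaded DataFrame via split_dataframe (imported from unstructured_ingest.utils.data_prep, which also changed in this release) rather than re-reading the JSON file in chunks. A hedged sketch of what a helper with that signature presumably does; the real implementation may differ:

from typing import Generator

import pandas as pd

# Sketch only; the actual helper lives in unstructured_ingest.utils.data_prep.
def split_dataframe(df: pd.DataFrame, chunk_size: int = 100) -> Generator[pd.DataFrame, None, None]:
    # Yield successive row slices of at most chunk_size rows each.
    for start in range(0, len(df), chunk_size):
        yield df.iloc[start : start + chunk_size]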
@@ -343,5 +386,33 @@ class SQLUploader(Uploader):
                #         print(f"failed to write {len(columns)}, {len(val)}: {stmt} -> {val}")
                cursor.executemany(stmt, values)
 
+    def get_table_columns(self) -> list[str]:
+        with self.connection_config.get_cursor() as cursor:
+            cursor.execute(f"SELECT * from {self.upload_config.table_name}")
+            return [desc[0] for desc in cursor.description]
+
+    def can_delete(self) -> bool:
+        return self.upload_config.record_id_key in self.get_table_columns()
+
+    def delete_by_record_id(self, file_data: FileData) -> None:
+        logger.debug(
+            f"deleting any content with data "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
+            f"from table {self.upload_config.table_name}"
+        )
+        stmt = f"DELETE FROM {self.upload_config.table_name} WHERE {self.upload_config.record_id_key} = {self.values_delimiter}"  # noqa: E501
+        with self.connection_config.get_cursor() as cursor:
+            cursor.execute(stmt, [file_data.identifier])
+            rowcount = cursor.rowcount
+            logger.info(f"deleted {rowcount} rows from table {self.upload_config.table_name}")
+
     def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
+        if self.can_delete():
+            self.delete_by_record_id(file_data=file_data)
+        else:
+            logger.warning(
+                f"table doesn't contain expected "
+                f"record id column "
+                f"{self.upload_config.record_id_key}, skipping delete"
+            )
         self.upload_contents(path=path)
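The SQL delete is parameterized through the connector's values_delimiter placeholder, so the same statement template works across drivers. An illustrative rendering under assumed values (the record-id column name and placeholder style are assumptions; the table name uses the config default):

table_name = "elements"
record_id_key = "record_id"   # assumed value of RECORD_ID_LABEL
values_delimiter = "%s"       # e.g. a psycopg-style paramstyle; sqlite would use "?"
stmt = f"DELETE FROM {table_name} WHERE {record_id_key} = {values_delimiter}"
# cursor.execute(stmt, ["example-file-id"]) removes rows written for that document
# on earlier runs before upload_contents() re-inserts the freshly staged rows.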
--- /dev/null
+++ b/unstructured_ingest/v2/processes/connectors/weaviate/__init__.py
@@ -0,0 +1,22 @@
+from __future__ import annotations
+
+from unstructured_ingest.v2.processes.connector_registry import (
+    add_destination_entry,
+)
+
+from .cloud import CONNECTOR_TYPE as CLOUD_WEAVIATE_CONNECTOR_TYPE
+from .cloud import weaviate_cloud_destination_entry
+from .embedded import CONNECTOR_TYPE as EMBEDDED_WEAVIATE_CONNECTOR_TYPE
+from .embedded import weaviate_embedded_destination_entry
+from .local import CONNECTOR_TYPE as LOCAL_WEAVIATE_CONNECTOR_TYPE
+from .local import weaviate_local_destination_entry
+
+add_destination_entry(
+    destination_type=LOCAL_WEAVIATE_CONNECTOR_TYPE, entry=weaviate_local_destination_entry
+)
+add_destination_entry(
+    destination_type=CLOUD_WEAVIATE_CONNECTOR_TYPE, entry=weaviate_cloud_destination_entry
+)
+add_destination_entry(
+    destination_type=EMBEDDED_WEAVIATE_CONNECTOR_TYPE, entry=weaviate_embedded_destination_entry
+)
--- /dev/null
+++ b/unstructured_ingest/v2/processes/connectors/weaviate/cloud.py
@@ -0,0 +1,164 @@
+from contextlib import contextmanager
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any, Generator, Optional
+
+from pydantic import Field, Secret
+
+from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry
+from unstructured_ingest.v2.processes.connectors.weaviate.weaviate import (
+    WeaviateAccessConfig,
+    WeaviateConnectionConfig,
+    WeaviateUploader,
+    WeaviateUploaderConfig,
+    WeaviateUploadStager,
+    WeaviateUploadStagerConfig,
+)
+
+if TYPE_CHECKING:
+    from weaviate.auth import AuthCredentials
+    from weaviate.client import WeaviateClient
+
+CONNECTOR_TYPE = "weaviate-cloud"
+
+
+class CloudWeaviateAccessConfig(WeaviateAccessConfig):
+    access_token: Optional[str] = Field(
+        default=None, description="Used to create the bearer token."
+    )
+    api_key: Optional[str] = None
+    client_secret: Optional[str] = None
+    password: Optional[str] = None
+
+
+class CloudWeaviateConnectionConfig(WeaviateConnectionConfig):
+    cluster_url: str = Field(
+        description="The WCD cluster URL or hostname to connect to. "
+        "Usually in the form: rAnD0mD1g1t5.something.weaviate.cloud"
+    )
+    username: Optional[str] = None
+    anonymous: bool = Field(default=False, description="if set, all auth values will be ignored")
+    refresh_token: Optional[str] = Field(
+        default=None,
+        description="Will tie this value to the bearer token. If not provided, "
+        "the authentication will expire once the lifetime of the access token is up.",
+    )
+    access_config: Secret[CloudWeaviateAccessConfig]
+
+    def model_post_init(self, __context: Any) -> None:
+        if self.anonymous:
+            return
+        access_config = self.access_config.get_secret_value()
+        auths = {
+            "api_key": access_config.api_key is not None,
+            "bearer_token": access_config.access_token is not None,
+            "client_secret": access_config.client_secret is not None,
+            "client_password": access_config.password is not None and self.username is not None,
+        }
+        if len(auths) == 0:
+            raise ValueError("No auth values provided and anonymous is False")
+        if len(auths) > 1:
+            existing_auths = [auth_method for auth_method, flag in auths.items() if flag]
+            raise ValueError(
+                "Multiple auth values provided, only one approach can be used: {}".format(
+                    ", ".join(existing_auths)
+                )
+            )
+
+    @requires_dependencies(["weaviate"], extras="weaviate")
+    def get_api_key_auth(self) -> Optional["AuthCredentials"]:
+        from weaviate.classes.init import Auth
+
+        if api_key := self.access_config.get_secret_value().api_key:
+            return Auth.api_key(api_key=api_key)
+        return None
+
+    @requires_dependencies(["weaviate"], extras="weaviate")
+    def get_bearer_token_auth(self) -> Optional["AuthCredentials"]:
+        from weaviate.classes.init import Auth
+
+        if access_token := self.access_config.get_secret_value().access_token:
+            return Auth.bearer_token(access_token=access_token, refresh_token=self.refresh_token)
+        return None
+
+    @requires_dependencies(["weaviate"], extras="weaviate")
+    def get_client_secret_auth(self) -> Optional["AuthCredentials"]:
+        from weaviate.classes.init import Auth
+
+        if client_secret := self.access_config.get_secret_value().client_secret:
+            return Auth.client_credentials(client_secret=client_secret)
+        return None
+
+    @requires_dependencies(["weaviate"], extras="weaviate")
+    def get_client_password_auth(self) -> Optional["AuthCredentials"]:
+        from weaviate.classes.init import Auth
+
+        if (username := self.username) and (
+            password := self.access_config.get_secret_value().password
+        ):
+            return Auth.client_password(username=username, password=password)
+        return None
+
+    @requires_dependencies(["weaviate"], extras="weaviate")
+    def get_auth(self) -> "AuthCredentials":
+        auths = [
+            self.get_api_key_auth(),
+            self.get_client_secret_auth(),
+            self.get_bearer_token_auth(),
+            self.get_client_password_auth(),
+        ]
+        auths = [auth for auth in auths if auth]
+        if len(auths) == 0:
+            raise ValueError("No auth values provided and anonymous is False")
+        if len(auths) > 1:
+            raise ValueError("Multiple auth values provided, only one approach can be used")
+        return auths[0]
+
+    @contextmanager
+    @requires_dependencies(["weaviate"], extras="weaviate")
+    def get_client(self) -> Generator["WeaviateClient", None, None]:
+        from weaviate import connect_to_weaviate_cloud
+        from weaviate.classes.init import AdditionalConfig
+
+        auth_credentials = None if self.anonymous else self.get_auth()
+        with connect_to_weaviate_cloud(
+            cluster_url=self.cluster_url,
+            auth_credentials=auth_credentials,
+            additional_config=AdditionalConfig(timeout=self.get_timeout()),
+        ) as weaviate_client:
+            yield weaviate_client
+
+
+class CloudWeaviateUploadStagerConfig(WeaviateUploadStagerConfig):
+    pass
+
+
+@dataclass
+class CloudWeaviateUploadStager(WeaviateUploadStager):
+    upload_stager_config: CloudWeaviateUploadStagerConfig = field(
+        default_factory=lambda: WeaviateUploadStagerConfig()
+    )
+
+
+class CloudWeaviateUploaderConfig(WeaviateUploaderConfig):
+    pass
+
+
+@dataclass
+class CloudWeaviateUploader(WeaviateUploader):
+    connection_config: CloudWeaviateConnectionConfig = field(
+        default_factory=lambda: CloudWeaviateConnectionConfig()
+    )
+    upload_config: CloudWeaviateUploaderConfig = field(
+        default_factory=lambda: CloudWeaviateUploaderConfig()
+    )
+    connector_type: str = CONNECTOR_TYPE
+
+
+weaviate_cloud_destination_entry = DestinationRegistryEntry(
+    connection_config=CloudWeaviateConnectionConfig,
+    uploader=CloudWeaviateUploader,
+    uploader_config=CloudWeaviateUploaderConfig,
+    upload_stager=CloudWeaviateUploadStager,
+    upload_stager_config=CloudWeaviateUploadStagerConfig,
+)
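The cloud connection config enforces its auth rules in model_post_init, and anonymous=True bypasses them entirely. A minimal sketch of the anonymous path (the cluster URL is a placeholder):

from unstructured_ingest.v2.processes.connectors.weaviate.cloud import (
    CloudWeaviateAccessConfig,
    CloudWeaviateConnectionConfig,
)

# anonymous=True returns early from model_post_init, so no api_key, bearer token,
# client secret, or username/password pair needs to be supplied.
connection_config = CloudWeaviateConnectionConfig(
    cluster_url="rAnD0mD1g1t5.something.weaviate.cloud",  # placeholder WCD hostname
    access_config=CloudWeaviateAccessConfig(),
    anonymous=True,
)
# get_client() then calls connect_to_weaviate_cloud(...) with auth_credentials=None.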