unstructured-ingest 0.2.2__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of unstructured-ingest has been flagged as a potentially problematic release.

Files changed (93)
  1. test/integration/connectors/elasticsearch/__init__.py +0 -0
  2. test/integration/connectors/elasticsearch/conftest.py +34 -0
  3. test/integration/connectors/elasticsearch/test_elasticsearch.py +308 -0
  4. test/integration/connectors/elasticsearch/test_opensearch.py +302 -0
  5. test/integration/connectors/sql/test_postgres.py +10 -4
  6. test/integration/connectors/sql/test_singlestore.py +8 -4
  7. test/integration/connectors/sql/test_snowflake.py +10 -6
  8. test/integration/connectors/sql/test_sqlite.py +4 -4
  9. test/integration/connectors/test_astradb.py +156 -0
  10. test/integration/connectors/test_azure_cog_search.py +233 -0
  11. test/integration/connectors/test_delta_table.py +46 -0
  12. test/integration/connectors/test_kafka.py +150 -16
  13. test/integration/connectors/test_lancedb.py +209 -0
  14. test/integration/connectors/test_milvus.py +141 -0
  15. test/integration/connectors/test_pinecone.py +213 -0
  16. test/integration/connectors/test_s3.py +23 -0
  17. test/integration/connectors/utils/docker.py +81 -15
  18. test/integration/connectors/utils/validation.py +10 -0
  19. test/integration/connectors/weaviate/__init__.py +0 -0
  20. test/integration/connectors/weaviate/conftest.py +15 -0
  21. test/integration/connectors/weaviate/test_local.py +131 -0
  22. test/unit/v2/__init__.py +0 -0
  23. test/unit/v2/chunkers/__init__.py +0 -0
  24. test/unit/v2/chunkers/test_chunkers.py +49 -0
  25. test/unit/v2/connectors/__init__.py +0 -0
  26. test/unit/v2/embedders/__init__.py +0 -0
  27. test/unit/v2/embedders/test_bedrock.py +36 -0
  28. test/unit/v2/embedders/test_huggingface.py +48 -0
  29. test/unit/v2/embedders/test_mixedbread.py +37 -0
  30. test/unit/v2/embedders/test_octoai.py +35 -0
  31. test/unit/v2/embedders/test_openai.py +35 -0
  32. test/unit/v2/embedders/test_togetherai.py +37 -0
  33. test/unit/v2/embedders/test_vertexai.py +37 -0
  34. test/unit/v2/embedders/test_voyageai.py +38 -0
  35. test/unit/v2/partitioners/__init__.py +0 -0
  36. test/unit/v2/partitioners/test_partitioner.py +63 -0
  37. test/unit/v2/utils/__init__.py +0 -0
  38. test/unit/v2/utils/data_generator.py +32 -0
  39. unstructured_ingest/__version__.py +1 -1
  40. unstructured_ingest/cli/cmds/__init__.py +2 -2
  41. unstructured_ingest/cli/cmds/{azure_cognitive_search.py → azure_ai_search.py} +9 -9
  42. unstructured_ingest/connector/{azure_cognitive_search.py → azure_ai_search.py} +9 -9
  43. unstructured_ingest/pipeline/reformat/embedding.py +1 -1
  44. unstructured_ingest/runner/writers/__init__.py +2 -2
  45. unstructured_ingest/runner/writers/azure_ai_search.py +24 -0
  46. unstructured_ingest/utils/data_prep.py +9 -1
  47. unstructured_ingest/v2/constants.py +2 -0
  48. unstructured_ingest/v2/processes/connectors/__init__.py +7 -20
  49. unstructured_ingest/v2/processes/connectors/airtable.py +2 -2
  50. unstructured_ingest/v2/processes/connectors/astradb.py +35 -23
  51. unstructured_ingest/v2/processes/connectors/{azure_cognitive_search.py → azure_ai_search.py} +116 -35
  52. unstructured_ingest/v2/processes/connectors/confluence.py +2 -2
  53. unstructured_ingest/v2/processes/connectors/couchbase.py +1 -0
  54. unstructured_ingest/v2/processes/connectors/delta_table.py +37 -9
  55. unstructured_ingest/v2/processes/connectors/elasticsearch/__init__.py +19 -0
  56. unstructured_ingest/v2/processes/connectors/{elasticsearch.py → elasticsearch/elasticsearch.py} +93 -46
  57. unstructured_ingest/v2/processes/connectors/{opensearch.py → elasticsearch/opensearch.py} +1 -1
  58. unstructured_ingest/v2/processes/connectors/fsspec/fsspec.py +27 -0
  59. unstructured_ingest/v2/processes/connectors/google_drive.py +3 -3
  60. unstructured_ingest/v2/processes/connectors/kafka/__init__.py +6 -2
  61. unstructured_ingest/v2/processes/connectors/kafka/cloud.py +38 -2
  62. unstructured_ingest/v2/processes/connectors/kafka/kafka.py +84 -23
  63. unstructured_ingest/v2/processes/connectors/kafka/local.py +32 -4
  64. unstructured_ingest/v2/processes/connectors/lancedb/__init__.py +17 -0
  65. unstructured_ingest/v2/processes/connectors/lancedb/aws.py +43 -0
  66. unstructured_ingest/v2/processes/connectors/lancedb/azure.py +43 -0
  67. unstructured_ingest/v2/processes/connectors/lancedb/gcp.py +44 -0
  68. unstructured_ingest/v2/processes/connectors/lancedb/lancedb.py +161 -0
  69. unstructured_ingest/v2/processes/connectors/lancedb/local.py +44 -0
  70. unstructured_ingest/v2/processes/connectors/milvus.py +72 -27
  71. unstructured_ingest/v2/processes/connectors/onedrive.py +2 -3
  72. unstructured_ingest/v2/processes/connectors/outlook.py +2 -2
  73. unstructured_ingest/v2/processes/connectors/pinecone.py +101 -13
  74. unstructured_ingest/v2/processes/connectors/sharepoint.py +3 -2
  75. unstructured_ingest/v2/processes/connectors/slack.py +2 -2
  76. unstructured_ingest/v2/processes/connectors/sql/postgres.py +16 -8
  77. unstructured_ingest/v2/processes/connectors/sql/sql.py +97 -26
  78. unstructured_ingest/v2/processes/connectors/weaviate/__init__.py +22 -0
  79. unstructured_ingest/v2/processes/connectors/weaviate/cloud.py +164 -0
  80. unstructured_ingest/v2/processes/connectors/weaviate/embedded.py +90 -0
  81. unstructured_ingest/v2/processes/connectors/weaviate/local.py +73 -0
  82. unstructured_ingest/v2/processes/connectors/weaviate/weaviate.py +289 -0
  83. {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/METADATA +20 -19
  84. {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/RECORD +91 -50
  85. unstructured_ingest/runner/writers/azure_cognitive_search.py +0 -24
  86. unstructured_ingest/v2/processes/connectors/weaviate.py +0 -242
  87. /test/integration/embedders/{togetherai.py → test_togetherai.py} +0 -0
  88. /test/unit/{test_interfaces_v2.py → v2/test_interfaces.py} +0 -0
  89. /test/unit/{test_utils_v2.py → v2/test_utils.py} +0 -0
  90. {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/LICENSE.md +0 -0
  91. {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/WHEEL +0 -0
  92. {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/entry_points.txt +0 -0
  93. {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/top_level.txt +0 -0
unstructured_ingest/v2/processes/connectors/pinecone.py

@@ -9,6 +9,7 @@ from pydantic import Field, Secret
 from unstructured_ingest.error import DestinationConnectionError
 from unstructured_ingest.utils.data_prep import flatten_dict, generator_batching_wbytes
 from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.v2.constants import RECORD_ID_LABEL
 from unstructured_ingest.v2.interfaces import (
     AccessConfig,
     ConnectionConfig,
@@ -23,11 +24,13 @@ from unstructured_ingest.v2.processes.connector_registry import DestinationRegis
 
 if TYPE_CHECKING:
     from pinecone import Index as PineconeIndex
+    from pinecone import Pinecone
 
 
 CONNECTOR_TYPE = "pinecone"
 MAX_PAYLOAD_SIZE = 2 * 1024 * 1024  # 2MB
 MAX_POOL_THREADS = 100
+MAX_METADATA_BYTES = 40960  # 40KB https://docs.pinecone.io/reference/quotas-and-limits#hard-limits
 
 
 class PineconeAccessConfig(AccessConfig):
@@ -43,16 +46,19 @@ class PineconeConnectionConfig(ConnectionConfig):
     )
 
     @requires_dependencies(["pinecone"], extras="pinecone")
-    def get_index(self, **index_kwargs) -> "PineconeIndex":
+    def get_client(self, **index_kwargs) -> "Pinecone":
         from pinecone import Pinecone
 
         from unstructured_ingest import __version__ as unstructured_version
 
-        pc = Pinecone(
+        return Pinecone(
             api_key=self.access_config.get_secret_value().pinecone_api_key,
             source_tag=f"unstructured_ingest=={unstructured_version}",
         )
 
+    def get_index(self, **index_kwargs) -> "PineconeIndex":
+        pc = self.get_client()
+
         index = pc.Index(name=self.index_name, **index_kwargs)
         logger.debug(f"connected to index: {pc.describe_index(self.index_name)}")
         return index
@@ -98,6 +104,10 @@ class PineconeUploaderConfig(UploaderConfig):
         default=None,
         description="The namespace to write to. If not specified, the default namespace is used",
     )
+    record_id_key: str = Field(
+        default=RECORD_ID_LABEL,
+        description="searchable key to find entries for the same record on previous runs",
+    )
 
 
 @dataclass
@@ -106,7 +116,7 @@ class PineconeUploadStager(UploadStager):
         default_factory=lambda: PineconeUploadStagerConfig()
     )
 
-    def conform_dict(self, element_dict: dict) -> dict:
+    def conform_dict(self, element_dict: dict, file_data: FileData) -> dict:
         embeddings = element_dict.pop("embeddings", None)
         metadata: dict[str, Any] = element_dict.pop("metadata", {})
         data_source = metadata.pop("data_source", {})
@@ -121,19 +131,30 @@ class PineconeUploadStager(UploadStager):
             }
         )
 
+        metadata = flatten_dict(
+            pinecone_metadata,
+            separator="-",
+            flatten_lists=True,
+            remove_none=True,
+        )
+        metadata[RECORD_ID_LABEL] = file_data.identifier
+        metadata_size_bytes = len(json.dumps(metadata).encode())
+        if metadata_size_bytes > MAX_METADATA_BYTES:
+            logger.info(
+                f"Metadata size is {metadata_size_bytes} bytes, which exceeds the limit of"
+                f" {MAX_METADATA_BYTES} bytes per vector. Dropping the metadata."
+            )
+            metadata = {}
+
         return {
             "id": str(uuid.uuid4()),
             "values": embeddings,
-            "metadata": flatten_dict(
-                pinecone_metadata,
-                separator="-",
-                flatten_lists=True,
-                remove_none=True,
-            ),
+            "metadata": metadata,
         }
 
     def run(
         self,
+        file_data: FileData,
         elements_filepath: Path,
         output_dir: Path,
         output_filename: str,
@@ -143,10 +164,15 @@
             elements_contents = json.load(elements_file)
 
         conformed_elements = [
-            self.conform_dict(element_dict=element) for element in elements_contents
+            self.conform_dict(element_dict=element, file_data=file_data)
+            for element in elements_contents
         ]
 
-        output_path = Path(output_dir) / Path(f"{output_filename}.json")
+        if Path(output_filename).suffix != ".json":
+            output_filename = f"{output_filename}.json"
+        else:
+            output_filename = f"{Path(output_filename).stem}.json"
+        output_path = Path(output_dir) / Path(f"{output_filename}")
         output_path.parent.mkdir(parents=True, exist_ok=True)
 
         with open(output_path, "w") as output_file:
@@ -167,6 +193,61 @@ class PineconeUploader(Uploader):
             logger.error(f"failed to validate connection: {e}", exc_info=True)
             raise DestinationConnectionError(f"failed to validate connection: {e}")
 
+    def pod_delete_by_record_id(self, file_data: FileData) -> None:
+        logger.debug(
+            f"deleting any content with metadata "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
+            f"from pinecone pod index"
+        )
+        index = self.connection_config.get_index(pool_threads=MAX_POOL_THREADS)
+        delete_kwargs = {
+            "filter": {self.upload_config.record_id_key: {"$eq": file_data.identifier}}
+        }
+        if namespace := self.upload_config.namespace:
+            delete_kwargs["namespace"] = namespace
+
+        resp = index.delete(**delete_kwargs)
+        logger.debug(
+            f"deleted any content with metadata "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
+            f"from pinecone index: {resp}"
+        )
+
+    def serverless_delete_by_record_id(self, file_data: FileData) -> None:
+        logger.debug(
+            f"deleting any content with metadata "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
+            f"from pinecone serverless index"
+        )
+        index = self.connection_config.get_index(pool_threads=MAX_POOL_THREADS)
+        index_stats = index.describe_index_stats()
+        total_vectors = index_stats["total_vector_count"]
+        if total_vectors == 0:
+            return
+        dimension = index_stats["dimension"]
+        query_params = {
+            "filter": {self.upload_config.record_id_key: {"$eq": file_data.identifier}},
+            "vector": [0] * dimension,
+            "top_k": total_vectors,
+        }
+        if namespace := self.upload_config.namespace:
+            query_params["namespace"] = namespace
+        while True:
+            query_results = index.query(**query_params)
+            matches = query_results.get("matches", [])
+            if not matches:
+                break
+            ids = [match["id"] for match in matches]
+            delete_params = {"ids": ids}
+            if namespace := self.upload_config.namespace:
+                delete_params["namespace"] = namespace
+            index.delete(**delete_params)
+        logger.debug(
+            f"deleted any content with metadata "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
+            f"from pinecone index"
+        )
+
     @requires_dependencies(["pinecone"], extras="pinecone")
     def upsert_batches_async(self, elements_dict: list[dict]):
         from pinecone.exceptions import PineconeApiException
@@ -206,9 +287,16 @@
             f"writing a total of {len(elements_dict)} elements via"
             f" document batches to destination"
             f" index named {self.connection_config.index_name}"
-            f" with batch size {self.upload_config.batch_size}"
         )
-
+        # Determine if serverless or pod based index
+        pinecone_client = self.connection_config.get_client()
+        index_description = pinecone_client.describe_index(name=self.connection_config.index_name)
+        if "serverless" in index_description.get("spec"):
+            self.serverless_delete_by_record_id(file_data=file_data)
+        elif "pod" in index_description.get("spec"):
+            self.pod_delete_by_record_id(file_data=file_data)
+        else:
+            raise ValueError(f"unexpected spec type in index description: {index_description}")
         self.upsert_batches_async(elements_dict=elements_dict)
 
 
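Note: the uploader's run path now checks the index spec and removes any vectors previously written for the same record before upserting, while the stager caps per-vector metadata at Pinecone's 40 KB hard limit. A minimal sketch of that size guard, independent of the connector (the record_id key and sample values are purely illustrative):

import json

MAX_METADATA_BYTES = 40960  # Pinecone's hard limit on metadata per vector (40 KB)

def cap_metadata(metadata: dict) -> dict:
    # Keep the metadata only if its JSON serialization fits the per-vector limit.
    size_bytes = len(json.dumps(metadata).encode())
    return metadata if size_bytes <= MAX_METADATA_BYTES else {}

# Example staged vector; "record_id" mirrors the label the uploader later filters on.
vector = {
    "id": "illustrative-uuid",
    "values": [0.1, 0.2, 0.3],
    "metadata": cap_metadata({"text": "hello world", "record_id": "my-file-id"}),
}
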
unstructured_ingest/v2/processes/connectors/sharepoint.py

@@ -21,7 +21,6 @@ from unstructured_ingest.v2.interfaces import (
     Indexer,
     IndexerConfig,
     SourceIdentifiers,
-    download_responses,
 )
 from unstructured_ingest.v2.logger import logger
 from unstructured_ingest.v2.processes.connector_registry import (
@@ -426,7 +425,7 @@ class SharepointDownloader(Downloader):
             f.write(etree.tostring(document, encoding="unicode", pretty_print=True))
         return self.generate_download_response(file_data=file_data, download_path=download_path)
 
-    def run(self, file_data: FileData, **kwargs: Any) -> download_responses:
+    def run(self, file_data: FileData, **kwargs: Any) -> DownloadResponse:
         content_type = file_data.additional_metadata.get("sharepoint_content_type")
         if not content_type:
             raise ValueError(
@@ -436,6 +435,8 @@
             return self.get_document(file_data=file_data)
         elif content_type == SharepointContentType.SITEPAGE.value:
             return self.get_site_page(file_data=file_data)
+        else:
+            raise ValueError(f"content type not recognized: {content_type}")
 
 
 sharepoint_source_entry = SourceRegistryEntry(
unstructured_ingest/v2/processes/connectors/slack.py

@@ -16,9 +16,9 @@ from unstructured_ingest.v2.interfaces import (
     ConnectionConfig,
     Downloader,
     DownloaderConfig,
+    DownloadResponse,
     Indexer,
     IndexerConfig,
-    download_responses,
 )
 from unstructured_ingest.v2.interfaces.file_data import (
     FileData,
@@ -161,7 +161,7 @@ class SlackDownloader(Downloader):
     def run(self, file_data, **kwargs):
         raise NotImplementedError
 
-    async def run_async(self, file_data: FileData, **kwargs) -> download_responses:
+    async def run_async(self, file_data: FileData, **kwargs) -> DownloadResponse:
         # NOTE: Indexer should provide source identifiers required to generate the download path
         download_path = self.get_download_path(file_data)
         if download_path is None:
unstructured_ingest/v2/processes/connectors/sql/postgres.py

@@ -98,20 +98,28 @@ class PostgresDownloader(SQLDownloader):
     download_config: PostgresDownloaderConfig
     connector_type: str = CONNECTOR_TYPE
 
+    @requires_dependencies(["psycopg2"], extras="postgres")
     def query_db(self, file_data: FileData) -> tuple[list[tuple], list[str]]:
+        from psycopg2 import sql
+
         table_name = file_data.additional_metadata["table_name"]
         id_column = file_data.additional_metadata["id_column"]
-        ids = file_data.additional_metadata["ids"]
+        ids = tuple(file_data.additional_metadata["ids"])
+
         with self.connection_config.get_cursor() as cursor:
-            fields = ",".join(self.download_config.fields) if self.download_config.fields else "*"
-            query = "SELECT {fields} FROM {table_name} WHERE {id_column} in ({ids})".format(
+            fields = (
+                sql.SQL(",").join(sql.Identifier(field) for field in self.download_config.fields)
+                if self.download_config.fields
+                else sql.SQL("*")
+            )
+
+            query = sql.SQL("SELECT {fields} FROM {table_name} WHERE {id_column} IN %s").format(
                 fields=fields,
-                table_name=table_name,
-                id_column=id_column,
-                ids=",".join([str(i) for i in ids]),
+                table_name=sql.Identifier(table_name),
+                id_column=sql.Identifier(id_column),
             )
-            logger.debug(f"running query: {query}")
-            cursor.execute(query)
+            logger.debug(f"running query: {cursor.mogrify(query, (ids,))}")
+            cursor.execute(query, (ids,))
             rows = cursor.fetchall()
             columns = [col[0] for col in cursor.description]
             return rows, columns
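Note: the rewritten query_db composes identifiers with psycopg2's sql module and binds the id tuple as a parameter rather than formatting values into the SQL string. A standalone sketch of the same pattern, assuming a reachable database; the DSN, table, and column names are placeholders:

import psycopg2
from psycopg2 import sql

conn = psycopg2.connect("dbname=example user=example")  # placeholder DSN
ids = (1, 2, 3)

# Identifiers are quoted via sql.Identifier; values travel as bound parameters.
query = sql.SQL("SELECT {fields} FROM {table} WHERE {id_col} IN %s").format(
    fields=sql.SQL(",").join([sql.Identifier("id"), sql.Identifier("text")]),
    table=sql.Identifier("elements"),
    id_col=sql.Identifier("id"),
)
with conn.cursor() as cursor:
    cursor.execute(query, (ids,))
    rows = cursor.fetchall()
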
unstructured_ingest/v2/processes/connectors/sql/sql.py

@@ -16,6 +16,8 @@ from dateutil import parser
 from pydantic import Field, Secret
 
 from unstructured_ingest.error import DestinationConnectionError, SourceConnectionError
+from unstructured_ingest.utils.data_prep import split_dataframe
+from unstructured_ingest.v2.constants import RECORD_ID_LABEL
 from unstructured_ingest.v2.interfaces import (
     AccessConfig,
     ConnectionConfig,
@@ -236,35 +238,25 @@ class SQLUploadStagerConfig(UploadStagerConfig):
 class SQLUploadStager(UploadStager):
     upload_stager_config: SQLUploadStagerConfig = field(default_factory=SQLUploadStagerConfig)
 
-    def run(
-        self,
-        elements_filepath: Path,
-        file_data: FileData,
-        output_dir: Path,
-        output_filename: str,
-        **kwargs: Any,
-    ) -> Path:
-        with open(elements_filepath) as elements_file:
-            elements_contents: list[dict] = json.load(elements_file)
-        output_path = Path(output_dir) / Path(f"{output_filename}.json")
-        output_path.parent.mkdir(parents=True, exist_ok=True)
-
+    @staticmethod
+    def conform_dict(data: dict, file_data: FileData) -> pd.DataFrame:
+        working_data = data.copy()
         output = []
-        for data in elements_contents:
-            metadata: dict[str, Any] = data.pop("metadata", {})
+        for element in working_data:
+            metadata: dict[str, Any] = element.pop("metadata", {})
             data_source = metadata.pop("data_source", {})
             coordinates = metadata.pop("coordinates", {})
 
-            data.update(metadata)
-            data.update(data_source)
-            data.update(coordinates)
+            element.update(metadata)
+            element.update(data_source)
+            element.update(coordinates)
 
-            data["id"] = str(uuid.uuid4())
+            element["id"] = str(uuid.uuid4())
 
             # remove extraneous, not supported columns
-            data = {k: v for k, v in data.items() if k in _COLUMNS}
-
-            output.append(data)
+            element = {k: v for k, v in element.items() if k in _COLUMNS}
+            element[RECORD_ID_LABEL] = file_data.identifier
+            output.append(element)
 
         df = pd.DataFrame.from_dict(output)
         for column in filter(lambda x: x in df.columns, _DATE_COLUMNS):
@@ -281,6 +273,26 @@
             ("version", "page_number", "regex_metadata"),
         ):
             df[column] = df[column].apply(str)
+        return df
+
+    def run(
+        self,
+        elements_filepath: Path,
+        file_data: FileData,
+        output_dir: Path,
+        output_filename: str,
+        **kwargs: Any,
+    ) -> Path:
+        with open(elements_filepath) as elements_file:
+            elements_contents: list[dict] = json.load(elements_file)
+
+        df = self.conform_dict(data=elements_contents, file_data=file_data)
+        if Path(output_filename).suffix != ".json":
+            output_filename = f"{output_filename}.json"
+        else:
+            output_filename = f"{Path(output_filename).stem}.json"
+        output_path = Path(output_dir) / Path(f"{output_filename}")
+        output_path.parent.mkdir(parents=True, exist_ok=True)
 
         with output_path.open("w") as output_file:
             df.to_json(output_file, orient="records", lines=True)
@@ -290,6 +302,10 @@
 class SQLUploaderConfig(UploaderConfig):
     batch_size: int = Field(default=50, description="Number of records per batch")
     table_name: str = Field(default="elements", description="which table to upload contents to")
+    record_id_key: str = Field(
+        default=RECORD_ID_LABEL,
+        description="searchable key to find entries for the same record on previous runs",
+    )
 
 
 @dataclass
@@ -323,18 +339,45 @@ class SQLUploader(Uploader):
             output.append(tuple(parsed))
         return output
 
+    def _fit_to_schema(self, df: pd.DataFrame, columns: list[str]) -> pd.DataFrame:
+        columns = set(df.columns)
+        schema_fields = set(columns)
+        columns_to_drop = columns - schema_fields
+        missing_columns = schema_fields - columns
+
+        if columns_to_drop:
+            logger.warning(
+                "Following columns will be dropped to match the table's schema: "
+                f"{', '.join(columns_to_drop)}"
+            )
+        if missing_columns:
+            logger.info(
+                "Following null filled columns will be added to match the table's schema:"
+                f" {', '.join(missing_columns)} "
+            )
+
+        df = df.drop(columns=columns_to_drop)
+
+        for column in missing_columns:
+            df[column] = pd.Series()
+
     def upload_contents(self, path: Path) -> None:
         df = pd.read_json(path, orient="records", lines=True)
         df.replace({np.nan: None}, inplace=True)
+        self._fit_to_schema(df=df, columns=self.get_table_columns())
 
         columns = list(df.columns)
         stmt = f"INSERT INTO {self.upload_config.table_name} ({','.join(columns)}) VALUES({','.join([self.values_delimiter for x in columns])})"  # noqa E501
-
-        for rows in pd.read_json(
-            path, orient="records", lines=True, chunksize=self.upload_config.batch_size
-        ):
+        logger.info(
+            f"writing a total of {len(df)} elements via"
+            f" document batches to destination"
+            f" table named {self.upload_config.table_name}"
+            f" with batch size {self.upload_config.batch_size}"
+        )
+        for rows in split_dataframe(df=df, chunk_size=self.upload_config.batch_size):
            with self.connection_config.get_cursor() as cursor:
                values = self.prepare_data(columns, tuple(rows.itertuples(index=False, name=None)))
+                # For debugging purposes:
                # for val in values:
                #     try:
                #         cursor.execute(stmt, val)
@@ -343,5 +386,33 @@
                #         print(f"failed to write {len(columns)}, {len(val)}: {stmt} -> {val}")
                cursor.executemany(stmt, values)
 
+    def get_table_columns(self) -> list[str]:
+        with self.connection_config.get_cursor() as cursor:
+            cursor.execute(f"SELECT * from {self.upload_config.table_name}")
+            return [desc[0] for desc in cursor.description]
+
+    def can_delete(self) -> bool:
+        return self.upload_config.record_id_key in self.get_table_columns()
+
+    def delete_by_record_id(self, file_data: FileData) -> None:
+        logger.debug(
+            f"deleting any content with data "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
+            f"from table {self.upload_config.table_name}"
+        )
+        stmt = f"DELETE FROM {self.upload_config.table_name} WHERE {self.upload_config.record_id_key} = {self.values_delimiter}"  # noqa: E501
+        with self.connection_config.get_cursor() as cursor:
+            cursor.execute(stmt, [file_data.identifier])
+            rowcount = cursor.rowcount
+            logger.info(f"deleted {rowcount} rows from table {self.upload_config.table_name}")
+
     def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
+        if self.can_delete():
+            self.delete_by_record_id(file_data=file_data)
+        else:
+            logger.warning(
+                f"table doesn't contain expected "
+                f"record id column "
+                f"{self.upload_config.record_id_key}, skipping delete"
+            )
         self.upload_contents(path=path)
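Note: the uploader now deletes rows carrying the incoming record's id before inserting, so re-ingesting the same document no longer duplicates rows, and it batches inserts through the new split_dataframe helper from unstructured_ingest.utils.data_prep. A sketch of what such a chunker amounts to (this reimplementation is an assumption about the helper's behavior, not its actual source):

from typing import Generator

import pandas as pd

def split_dataframe(df: pd.DataFrame, chunk_size: int = 100) -> Generator[pd.DataFrame, None, None]:
    # Yield successive slices of at most chunk_size rows.
    for start in range(0, len(df), chunk_size):
        yield df.iloc[start : start + chunk_size]

df = pd.DataFrame({"id": range(5), "text": list("abcde")})
batch_sizes = [len(batch) for batch in split_dataframe(df=df, chunk_size=2)]
print(batch_sizes)  # [2, 2, 1]
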
unstructured_ingest/v2/processes/connectors/weaviate/__init__.py (new file)

@@ -0,0 +1,22 @@
+from __future__ import annotations
+
+from unstructured_ingest.v2.processes.connector_registry import (
+    add_destination_entry,
+)
+
+from .cloud import CONNECTOR_TYPE as CLOUD_WEAVIATE_CONNECTOR_TYPE
+from .cloud import weaviate_cloud_destination_entry
+from .embedded import CONNECTOR_TYPE as EMBEDDED_WEAVIATE_CONNECTOR_TYPE
+from .embedded import weaviate_embedded_destination_entry
+from .local import CONNECTOR_TYPE as LOCAL_WEAVIATE_CONNECTOR_TYPE
+from .local import weaviate_local_destination_entry
+
+add_destination_entry(
+    destination_type=LOCAL_WEAVIATE_CONNECTOR_TYPE, entry=weaviate_local_destination_entry
+)
+add_destination_entry(
+    destination_type=CLOUD_WEAVIATE_CONNECTOR_TYPE, entry=weaviate_cloud_destination_entry
+)
+add_destination_entry(
+    destination_type=EMBEDDED_WEAVIATE_CONNECTOR_TYPE, entry=weaviate_embedded_destination_entry
+)
unstructured_ingest/v2/processes/connectors/weaviate/cloud.py (new file)

@@ -0,0 +1,164 @@
+from contextlib import contextmanager
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any, Generator, Optional
+
+from pydantic import Field, Secret
+
+from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry
+from unstructured_ingest.v2.processes.connectors.weaviate.weaviate import (
+    WeaviateAccessConfig,
+    WeaviateConnectionConfig,
+    WeaviateUploader,
+    WeaviateUploaderConfig,
+    WeaviateUploadStager,
+    WeaviateUploadStagerConfig,
+)
+
+if TYPE_CHECKING:
+    from weaviate.auth import AuthCredentials
+    from weaviate.client import WeaviateClient
+
+CONNECTOR_TYPE = "weaviate-cloud"
+
+
+class CloudWeaviateAccessConfig(WeaviateAccessConfig):
+    access_token: Optional[str] = Field(
+        default=None, description="Used to create the bearer token."
+    )
+    api_key: Optional[str] = None
+    client_secret: Optional[str] = None
+    password: Optional[str] = None
+
+
+class CloudWeaviateConnectionConfig(WeaviateConnectionConfig):
+    cluster_url: str = Field(
+        description="The WCD cluster URL or hostname to connect to. "
+        "Usually in the form: rAnD0mD1g1t5.something.weaviate.cloud"
+    )
+    username: Optional[str] = None
+    anonymous: bool = Field(default=False, description="if set, all auth values will be ignored")
+    refresh_token: Optional[str] = Field(
+        default=None,
+        description="Will tie this value to the bearer token. If not provided, "
+        "the authentication will expire once the lifetime of the access token is up.",
+    )
+    access_config: Secret[CloudWeaviateAccessConfig]
+
+    def model_post_init(self, __context: Any) -> None:
+        if self.anonymous:
+            return
+        access_config = self.access_config.get_secret_value()
+        auths = {
+            "api_key": access_config.api_key is not None,
+            "bearer_token": access_config.access_token is not None,
+            "client_secret": access_config.client_secret is not None,
+            "client_password": access_config.password is not None and self.username is not None,
+        }
+        if len(auths) == 0:
+            raise ValueError("No auth values provided and anonymous is False")
+        if len(auths) > 1:
+            existing_auths = [auth_method for auth_method, flag in auths.items() if flag]
+            raise ValueError(
+                "Multiple auth values provided, only one approach can be used: {}".format(
+                    ", ".join(existing_auths)
+                )
+            )
+
+    @requires_dependencies(["weaviate"], extras="weaviate")
+    def get_api_key_auth(self) -> Optional["AuthCredentials"]:
+        from weaviate.classes.init import Auth
+
+        if api_key := self.access_config.get_secret_value().api_key:
+            return Auth.api_key(api_key=api_key)
+        return None
+
+    @requires_dependencies(["weaviate"], extras="weaviate")
+    def get_bearer_token_auth(self) -> Optional["AuthCredentials"]:
+        from weaviate.classes.init import Auth
+
+        if access_token := self.access_config.get_secret_value().access_token:
+            return Auth.bearer_token(access_token=access_token, refresh_token=self.refresh_token)
+        return None
+
+    @requires_dependencies(["weaviate"], extras="weaviate")
+    def get_client_secret_auth(self) -> Optional["AuthCredentials"]:
+        from weaviate.classes.init import Auth
+
+        if client_secret := self.access_config.get_secret_value().client_secret:
+            return Auth.client_credentials(client_secret=client_secret)
+        return None
+
+    @requires_dependencies(["weaviate"], extras="weaviate")
+    def get_client_password_auth(self) -> Optional["AuthCredentials"]:
+        from weaviate.classes.init import Auth
+
+        if (username := self.username) and (
+            password := self.access_config.get_secret_value().password
+        ):
+            return Auth.client_password(username=username, password=password)
+        return None
+
+    @requires_dependencies(["weaviate"], extras="weaviate")
+    def get_auth(self) -> "AuthCredentials":
+        auths = [
+            self.get_api_key_auth(),
+            self.get_client_secret_auth(),
+            self.get_bearer_token_auth(),
+            self.get_client_password_auth(),
+        ]
+        auths = [auth for auth in auths if auth]
+        if len(auths) == 0:
+            raise ValueError("No auth values provided and anonymous is False")
+        if len(auths) > 1:
+            raise ValueError("Multiple auth values provided, only one approach can be used")
+        return auths[0]
+
+    @contextmanager
+    @requires_dependencies(["weaviate"], extras="weaviate")
+    def get_client(self) -> Generator["WeaviateClient", None, None]:
+        from weaviate import connect_to_weaviate_cloud
+        from weaviate.classes.init import AdditionalConfig
+
+        auth_credentials = None if self.anonymous else self.get_auth()
+        with connect_to_weaviate_cloud(
+            cluster_url=self.cluster_url,
+            auth_credentials=auth_credentials,
+            additional_config=AdditionalConfig(timeout=self.get_timeout()),
+        ) as weaviate_client:
+            yield weaviate_client
+
+
+class CloudWeaviateUploadStagerConfig(WeaviateUploadStagerConfig):
+    pass
+
+
+@dataclass
+class CloudWeaviateUploadStager(WeaviateUploadStager):
+    upload_stager_config: CloudWeaviateUploadStagerConfig = field(
+        default_factory=lambda: WeaviateUploadStagerConfig()
+    )
+
+
+class CloudWeaviateUploaderConfig(WeaviateUploaderConfig):
+    pass
+
+
+@dataclass
+class CloudWeaviateUploader(WeaviateUploader):
+    connection_config: CloudWeaviateConnectionConfig = field(
+        default_factory=lambda: CloudWeaviateConnectionConfig()
+    )
+    upload_config: CloudWeaviateUploaderConfig = field(
+        default_factory=lambda: CloudWeaviateUploaderConfig()
+    )
+    connector_type: str = CONNECTOR_TYPE
+
+
+weaviate_cloud_destination_entry = DestinationRegistryEntry(
+    connection_config=CloudWeaviateConnectionConfig,
+    uploader=CloudWeaviateUploader,
+    uploader_config=CloudWeaviateUploaderConfig,
+    upload_stager=CloudWeaviateUploadStager,
+    upload_stager_config=CloudWeaviateUploadStagerConfig,
+)
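
Note: the new cloud connector accepts exactly one auth mechanism (unless anonymous is set) and hands out the v4 weaviate client through a context manager. A minimal sketch of the underlying client calls it wraps; the cluster URL and API key are placeholders:

from weaviate import connect_to_weaviate_cloud
from weaviate.classes.init import Auth

# Connect the same way CloudWeaviateConnectionConfig.get_client does (placeholders throughout).
with connect_to_weaviate_cloud(
    cluster_url="https://example-cluster.weaviate.cloud",
    auth_credentials=Auth.api_key(api_key="YOUR-API-KEY"),
) as client:
    print(client.is_ready())  # basic readiness check before writing any collections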