unstructured-ingest 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- test/integration/connectors/elasticsearch/__init__.py +0 -0
- test/integration/connectors/elasticsearch/conftest.py +34 -0
- test/integration/connectors/elasticsearch/test_elasticsearch.py +308 -0
- test/integration/connectors/elasticsearch/test_opensearch.py +302 -0
- test/integration/connectors/sql/test_postgres.py +10 -4
- test/integration/connectors/sql/test_singlestore.py +8 -4
- test/integration/connectors/sql/test_snowflake.py +10 -6
- test/integration/connectors/sql/test_sqlite.py +4 -4
- test/integration/connectors/test_astradb.py +50 -3
- test/integration/connectors/test_delta_table.py +46 -0
- test/integration/connectors/test_kafka.py +40 -6
- test/integration/connectors/test_lancedb.py +209 -0
- test/integration/connectors/test_milvus.py +141 -0
- test/integration/connectors/test_pinecone.py +53 -1
- test/integration/connectors/utils/docker.py +81 -15
- test/integration/connectors/utils/validation.py +10 -0
- test/integration/connectors/weaviate/__init__.py +0 -0
- test/integration/connectors/weaviate/conftest.py +15 -0
- test/integration/connectors/weaviate/test_local.py +131 -0
- unstructured_ingest/__version__.py +1 -1
- unstructured_ingest/pipeline/reformat/embedding.py +1 -1
- unstructured_ingest/utils/data_prep.py +9 -1
- unstructured_ingest/v2/processes/connectors/__init__.py +3 -16
- unstructured_ingest/v2/processes/connectors/astradb.py +2 -2
- unstructured_ingest/v2/processes/connectors/azure_ai_search.py +4 -0
- unstructured_ingest/v2/processes/connectors/delta_table.py +20 -4
- unstructured_ingest/v2/processes/connectors/elasticsearch/__init__.py +19 -0
- unstructured_ingest/v2/processes/connectors/{elasticsearch.py → elasticsearch/elasticsearch.py} +92 -46
- unstructured_ingest/v2/processes/connectors/{opensearch.py → elasticsearch/opensearch.py} +1 -1
- unstructured_ingest/v2/processes/connectors/kafka/kafka.py +6 -0
- unstructured_ingest/v2/processes/connectors/lancedb/__init__.py +17 -0
- unstructured_ingest/v2/processes/connectors/lancedb/aws.py +43 -0
- unstructured_ingest/v2/processes/connectors/lancedb/azure.py +43 -0
- unstructured_ingest/v2/processes/connectors/lancedb/gcp.py +44 -0
- unstructured_ingest/v2/processes/connectors/lancedb/lancedb.py +161 -0
- unstructured_ingest/v2/processes/connectors/lancedb/local.py +44 -0
- unstructured_ingest/v2/processes/connectors/milvus.py +72 -27
- unstructured_ingest/v2/processes/connectors/pinecone.py +24 -7
- unstructured_ingest/v2/processes/connectors/sql/sql.py +97 -26
- unstructured_ingest/v2/processes/connectors/weaviate/__init__.py +22 -0
- unstructured_ingest/v2/processes/connectors/weaviate/cloud.py +164 -0
- unstructured_ingest/v2/processes/connectors/weaviate/embedded.py +90 -0
- unstructured_ingest/v2/processes/connectors/weaviate/local.py +73 -0
- unstructured_ingest/v2/processes/connectors/weaviate/weaviate.py +289 -0
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.1.dist-info}/METADATA +15 -15
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.1.dist-info}/RECORD +50 -30
- unstructured_ingest/v2/processes/connectors/weaviate.py +0 -242
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.1.dist-info}/LICENSE.md +0 -0
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.1.dist-info}/WHEEL +0 -0
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.1.dist-info}/entry_points.txt +0 -0
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.1.dist-info}/top_level.txt +0 -0
unstructured_ingest/v2/processes/connectors/milvus.py

```diff
@@ -1,7 +1,8 @@
 import json
+from contextlib import contextmanager
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Generator, Optional, Union

 import pandas as pd
 from dateutil import parser
@@ -10,6 +11,7 @@ from pydantic import Field, Secret
 from unstructured_ingest.error import WriteError
 from unstructured_ingest.utils.data_prep import flatten_dict
 from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.v2.constants import RECORD_ID_LABEL
 from unstructured_ingest.v2.interfaces import (
     AccessConfig,
     ConnectionConfig,
@@ -90,24 +92,27 @@ class MilvusUploadStager(UploadStager):
             pass
         return parser.parse(date_string).timestamp()

-    def conform_dict(self, data: dict) ->
-
-
+    def conform_dict(self, data: dict, file_data: FileData) -> dict:
+        working_data = data.copy()
+        if self.upload_stager_config.flatten_metadata and (
+            metadata := working_data.pop("metadata", None)
+        ):
+            working_data.update(flatten_dict(metadata, keys_to_omit=["data_source_record_locator"]))

         # TODO: milvus sdk doesn't seem to support defaults via the schema yet,
         # remove once that gets updated
         defaults = {"is_continuation": False}
         for default in defaults:
-            if default not in
-
+            if default not in working_data:
+                working_data[default] = defaults[default]

         if self.upload_stager_config.fields_to_include:
-            data_keys = set(
+            data_keys = set(working_data.keys())
             for data_key in data_keys:
                 if data_key not in self.upload_stager_config.fields_to_include:
-
+                    working_data.pop(data_key)
             for field_include_key in self.upload_stager_config.fields_to_include:
-                if field_include_key not in
+                if field_include_key not in working_data:
                     raise KeyError(f"Field '{field_include_key}' is missing in data!")

         datetime_columns = [
@@ -120,11 +125,15 @@ class MilvusUploadStager(UploadStager):
         json_dumps_fields = ["languages", "data_source_permissions_data"]

         for datetime_column in datetime_columns:
-            if datetime_column in
-
+            if datetime_column in working_data:
+                working_data[datetime_column] = self.parse_date_string(
+                    working_data[datetime_column]
+                )
         for json_dumps_field in json_dumps_fields:
-            if json_dumps_field in
-
+            if json_dumps_field in working_data:
+                working_data[json_dumps_field] = json.dumps(working_data[json_dumps_field])
+        working_data[RECORD_ID_LABEL] = file_data.identifier
+        return working_data

     def run(
         self,
@@ -136,18 +145,27 @@ class MilvusUploadStager(UploadStager):
     ) -> Path:
         with open(elements_filepath) as elements_file:
             elements_contents: list[dict[str, Any]] = json.load(elements_file)
-
-            self.conform_dict(data=element)
-
-
+        new_content = [
+            self.conform_dict(data=element, file_data=file_data) for element in elements_contents
+        ]
+        output_filename_path = Path(output_filename)
+        if output_filename_path.suffix == ".json":
+            output_path = Path(output_dir) / output_filename_path
+        else:
+            output_path = Path(output_dir) / output_filename_path.with_suffix(".json")
         output_path.parent.mkdir(parents=True, exist_ok=True)
         with output_path.open("w") as output_file:
-            json.dump(
+            json.dump(new_content, output_file, indent=2)
         return output_path


 class MilvusUploaderConfig(UploaderConfig):
+    db_name: Optional[str] = Field(default=None, description="Milvus database name")
     collection_name: str = Field(description="Milvus collections to write to")
+    record_id_key: str = Field(
+        default=RECORD_ID_LABEL,
+        description="searchable key to find entries for the same record on previous runs",
+    )


 @dataclass
@@ -156,6 +174,16 @@ class MilvusUploader(Uploader):
     upload_config: MilvusUploaderConfig
     connector_type: str = CONNECTOR_TYPE

+    @contextmanager
+    def get_client(self) -> Generator["MilvusClient", None, None]:
+        client = self.connection_config.get_client()
+        if db_name := self.upload_config.db_name:
+            client.using_database(db_name=db_name)
+        try:
+            yield client
+        finally:
+            client.close()
+
     def upload(self, content: UploadContent) -> None:
         file_extension = content.path.suffix
         if file_extension == ".json":
@@ -165,23 +193,39 @@ class MilvusUploader(Uploader):
         else:
             raise ValueError(f"Unsupported file extension: {file_extension}")

+    def delete_by_record_id(self, file_data: FileData) -> None:
+        logger.info(
+            f"deleting any content with metadata {RECORD_ID_LABEL}={file_data.identifier} "
+            f"from milvus collection {self.upload_config.collection_name}"
+        )
+        with self.get_client() as client:
+            delete_filter = f'{self.upload_config.record_id_key} == "{file_data.identifier}"'
+            resp = client.delete(
+                collection_name=self.upload_config.collection_name, filter=delete_filter
+            )
+            logger.info(
+                "deleted {} records from milvus collection {}".format(
+                    resp["delete_count"], self.upload_config.collection_name
+                )
+            )
+
     @requires_dependencies(["pymilvus"], extras="milvus")
     def insert_results(self, data: Union[dict, list[dict]]):
         from pymilvus import MilvusException

-        logger.
+        logger.info(
             f"uploading {len(data)} entries to {self.connection_config.db_name} "
             f"db in collection {self.upload_config.collection_name}"
         )
-
+        with self.get_client() as client:

-
-
-
-
-
-
+            try:
+                res = client.insert(collection_name=self.upload_config.collection_name, data=data)
+            except MilvusException as milvus_exception:
+                raise WriteError("failed to upload records to milvus") from milvus_exception
+            if "err_count" in res and isinstance(res["err_count"], int) and res["err_count"] > 0:
+                err_count = res["err_count"]
+                raise WriteError(f"failed to upload {err_count} docs")

     def upload_csv(self, content: UploadContent) -> None:
         df = pd.read_csv(content.path)
@@ -194,6 +238,7 @@ class MilvusUploader(Uploader):
         self.insert_results(data=data)

     def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
+        self.delete_by_record_id(file_data=file_data)
         self.upload(content=UploadContent(path=path, file_data=file_data))


```
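For orientation, a minimal sketch (not part of the diff) of what the reworked Milvus stager does per element: it copies the dict instead of mutating it, flattens metadata, fills the `is_continuation` default, and stamps the source record id so a later run can delete stale rows. The literal value of `RECORD_ID_LABEL` and the simplified flattening are assumptions here.

```python
# Minimal sketch, not part of the diff: the shape of the new conform_dict logic.
# RECORD_ID_LABEL's literal value and the flattening scheme are assumptions.
import json

RECORD_ID_LABEL = "record_id"  # assumed placeholder for the real constant


def conform_element(element: dict, record_id: str) -> dict:
    working = element.copy()  # copy rather than mutate the caller's dict
    metadata = working.pop("metadata", None) or {}
    # simplified stand-in for unstructured_ingest.utils.data_prep.flatten_dict
    working.update({k: v for k, v in metadata.items() if not isinstance(v, (dict, list))})
    working.setdefault("is_continuation", False)  # default the Milvus schema can't supply yet
    working[RECORD_ID_LABEL] = record_id  # lets later runs delete rows from the same source file
    return working


print(json.dumps(conform_element({"text": "hello", "metadata": {"page_number": 1}}, "doc-1"), indent=2))
```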
unstructured_ingest/v2/processes/connectors/pinecone.py

```diff
@@ -30,6 +30,7 @@ if TYPE_CHECKING:
 CONNECTOR_TYPE = "pinecone"
 MAX_PAYLOAD_SIZE = 2 * 1024 * 1024  # 2MB
 MAX_POOL_THREADS = 100
+MAX_METADATA_BYTES = 40960  # 40KB https://docs.pinecone.io/reference/quotas-and-limits#hard-limits


 class PineconeAccessConfig(AccessConfig):
@@ -103,6 +104,10 @@ class PineconeUploaderConfig(UploaderConfig):
         default=None,
         description="The namespace to write to. If not specified, the default namespace is used",
     )
+    record_id_key: str = Field(
+        default=RECORD_ID_LABEL,
+        description="searchable key to find entries for the same record on previous runs",
+    )


 @dataclass
@@ -133,6 +138,13 @@ class PineconeUploadStager(UploadStager):
             remove_none=True,
         )
         metadata[RECORD_ID_LABEL] = file_data.identifier
+        metadata_size_bytes = len(json.dumps(metadata).encode())
+        if metadata_size_bytes > MAX_METADATA_BYTES:
+            logger.info(
+                f"Metadata size is {metadata_size_bytes} bytes, which exceeds the limit of"
+                f" {MAX_METADATA_BYTES} bytes per vector. Dropping the metadata."
+            )
+            metadata = {}

         return {
             "id": str(uuid.uuid4()),
@@ -183,23 +195,28 @@ class PineconeUploader(Uploader):

     def pod_delete_by_record_id(self, file_data: FileData) -> None:
         logger.debug(
-            f"deleting any content with metadata
+            f"deleting any content with metadata "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
             f"from pinecone pod index"
         )
         index = self.connection_config.get_index(pool_threads=MAX_POOL_THREADS)
-        delete_kwargs = {
+        delete_kwargs = {
+            "filter": {self.upload_config.record_id_key: {"$eq": file_data.identifier}}
+        }
         if namespace := self.upload_config.namespace:
             delete_kwargs["namespace"] = namespace

         resp = index.delete(**delete_kwargs)
         logger.debug(
-            f"deleted any content with metadata
+            f"deleted any content with metadata "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
             f"from pinecone index: {resp}"
         )

     def serverless_delete_by_record_id(self, file_data: FileData) -> None:
         logger.debug(
-            f"deleting any content with metadata
+            f"deleting any content with metadata "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
             f"from pinecone serverless index"
         )
         index = self.connection_config.get_index(pool_threads=MAX_POOL_THREADS)
@@ -209,7 +226,7 @@ class PineconeUploader(Uploader):
             return
         dimension = index_stats["dimension"]
         query_params = {
-            "filter": {
+            "filter": {self.upload_config.record_id_key: {"$eq": file_data.identifier}},
             "vector": [0] * dimension,
             "top_k": total_vectors,
         }
@@ -226,7 +243,8 @@ class PineconeUploader(Uploader):
             delete_params["namespace"] = namespace
         index.delete(**delete_params)
         logger.debug(
-            f"deleted any content with metadata
+            f"deleted any content with metadata "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
             f"from pinecone index"
         )

@@ -269,7 +287,6 @@ class PineconeUploader(Uploader):
             f"writing a total of {len(elements_dict)} elements via"
             f" document batches to destination"
             f" index named {self.connection_config.index_name}"
-            f" with batch size {self.upload_config.batch_size}"
         )
         # Determine if serverless or pod based index
         pinecone_client = self.connection_config.get_client()
```
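The new `MAX_METADATA_BYTES` guard keeps a vector's metadata under Pinecone's documented 40 KB per-vector limit by dropping any metadata that serializes too large. A standalone sketch of that check (the element shapes are made up for illustration):

```python
# Minimal sketch, not part of the diff: the 40 KB metadata guard added to the stager.
import json

MAX_METADATA_BYTES = 40960  # Pinecone's documented per-vector metadata limit


def bounded_metadata(metadata: dict) -> dict:
    size = len(json.dumps(metadata).encode())
    if size > MAX_METADATA_BYTES:
        # oversized metadata is dropped entirely rather than rejected by the API
        return {}
    return metadata


print(bounded_metadata({"filename": "report.pdf", "page_number": 3}))
print(bounded_metadata({"blob": "x" * 50_000}))  # prints {} because it exceeds the limit
```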
unstructured_ingest/v2/processes/connectors/sql/sql.py

```diff
@@ -16,6 +16,8 @@ from dateutil import parser
 from pydantic import Field, Secret

 from unstructured_ingest.error import DestinationConnectionError, SourceConnectionError
+from unstructured_ingest.utils.data_prep import split_dataframe
+from unstructured_ingest.v2.constants import RECORD_ID_LABEL
 from unstructured_ingest.v2.interfaces import (
     AccessConfig,
     ConnectionConfig,
@@ -236,35 +238,25 @@ class SQLUploadStagerConfig(UploadStagerConfig):
 class SQLUploadStager(UploadStager):
     upload_stager_config: SQLUploadStagerConfig = field(default_factory=SQLUploadStagerConfig)

-
-
-
-        file_data: FileData,
-        output_dir: Path,
-        output_filename: str,
-        **kwargs: Any,
-    ) -> Path:
-        with open(elements_filepath) as elements_file:
-            elements_contents: list[dict] = json.load(elements_file)
-        output_path = Path(output_dir) / Path(f"{output_filename}.json")
-        output_path.parent.mkdir(parents=True, exist_ok=True)
-
+    @staticmethod
+    def conform_dict(data: dict, file_data: FileData) -> pd.DataFrame:
+        working_data = data.copy()
         output = []
-        for
-            metadata: dict[str, Any] =
+        for element in working_data:
+            metadata: dict[str, Any] = element.pop("metadata", {})
             data_source = metadata.pop("data_source", {})
             coordinates = metadata.pop("coordinates", {})

-
-
-
+            element.update(metadata)
+            element.update(data_source)
+            element.update(coordinates)

-
+            element["id"] = str(uuid.uuid4())

             # remove extraneous, not supported columns
-
-
-            output.append(
+            element = {k: v for k, v in element.items() if k in _COLUMNS}
+            element[RECORD_ID_LABEL] = file_data.identifier
+            output.append(element)

         df = pd.DataFrame.from_dict(output)
         for column in filter(lambda x: x in df.columns, _DATE_COLUMNS):
@@ -281,6 +273,26 @@ class SQLUploadStager(UploadStager):
             ("version", "page_number", "regex_metadata"),
         ):
             df[column] = df[column].apply(str)
+        return df
+
+    def run(
+        self,
+        elements_filepath: Path,
+        file_data: FileData,
+        output_dir: Path,
+        output_filename: str,
+        **kwargs: Any,
+    ) -> Path:
+        with open(elements_filepath) as elements_file:
+            elements_contents: list[dict] = json.load(elements_file)
+
+        df = self.conform_dict(data=elements_contents, file_data=file_data)
+        if Path(output_filename).suffix != ".json":
+            output_filename = f"{output_filename}.json"
+        else:
+            output_filename = f"{Path(output_filename).stem}.json"
+        output_path = Path(output_dir) / Path(f"{output_filename}")
+        output_path.parent.mkdir(parents=True, exist_ok=True)

         with output_path.open("w") as output_file:
             df.to_json(output_file, orient="records", lines=True)
@@ -290,6 +302,10 @@ class SQLUploadStager(UploadStager):
 class SQLUploaderConfig(UploaderConfig):
     batch_size: int = Field(default=50, description="Number of records per batch")
     table_name: str = Field(default="elements", description="which table to upload contents to")
+    record_id_key: str = Field(
+        default=RECORD_ID_LABEL,
+        description="searchable key to find entries for the same record on previous runs",
+    )


 @dataclass
@@ -323,18 +339,45 @@ class SQLUploader(Uploader):
             output.append(tuple(parsed))
         return output

+    def _fit_to_schema(self, df: pd.DataFrame, columns: list[str]) -> pd.DataFrame:
+        columns = set(df.columns)
+        schema_fields = set(columns)
+        columns_to_drop = columns - schema_fields
+        missing_columns = schema_fields - columns
+
+        if columns_to_drop:
+            logger.warning(
+                "Following columns will be dropped to match the table's schema: "
+                f"{', '.join(columns_to_drop)}"
+            )
+        if missing_columns:
+            logger.info(
+                "Following null filled columns will be added to match the table's schema:"
+                f" {', '.join(missing_columns)} "
+            )
+
+        df = df.drop(columns=columns_to_drop)
+
+        for column in missing_columns:
+            df[column] = pd.Series()
+
     def upload_contents(self, path: Path) -> None:
         df = pd.read_json(path, orient="records", lines=True)
         df.replace({np.nan: None}, inplace=True)
+        self._fit_to_schema(df=df, columns=self.get_table_columns())

         columns = list(df.columns)
         stmt = f"INSERT INTO {self.upload_config.table_name} ({','.join(columns)}) VALUES({','.join([self.values_delimiter for x in columns])})"  # noqa E501
-
-
-
-
+        logger.info(
+            f"writing a total of {len(df)} elements via"
+            f" document batches to destination"
+            f" table named {self.upload_config.table_name}"
+            f" with batch size {self.upload_config.batch_size}"
+        )
+        for rows in split_dataframe(df=df, chunk_size=self.upload_config.batch_size):
             with self.connection_config.get_cursor() as cursor:
                 values = self.prepare_data(columns, tuple(rows.itertuples(index=False, name=None)))
+                # For debugging purposes:
                 # for val in values:
                 #     try:
                 #         cursor.execute(stmt, val)
@@ -343,5 +386,33 @@ class SQLUploader(Uploader):
                 #         print(f"failed to write {len(columns)}, {len(val)}: {stmt} -> {val}")
                 cursor.executemany(stmt, values)

+    def get_table_columns(self) -> list[str]:
+        with self.connection_config.get_cursor() as cursor:
+            cursor.execute(f"SELECT * from {self.upload_config.table_name}")
+            return [desc[0] for desc in cursor.description]
+
+    def can_delete(self) -> bool:
+        return self.upload_config.record_id_key in self.get_table_columns()
+
+    def delete_by_record_id(self, file_data: FileData) -> None:
+        logger.debug(
+            f"deleting any content with data "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
+            f"from table {self.upload_config.table_name}"
+        )
+        stmt = f"DELETE FROM {self.upload_config.table_name} WHERE {self.upload_config.record_id_key} = {self.values_delimiter}"  # noqa: E501
+        with self.connection_config.get_cursor() as cursor:
+            cursor.execute(stmt, [file_data.identifier])
+            rowcount = cursor.rowcount
+            logger.info(f"deleted {rowcount} rows from table {self.upload_config.table_name}")
+
     def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
+        if self.can_delete():
+            self.delete_by_record_id(file_data=file_data)
+        else:
+            logger.warning(
+                f"table doesn't contain expected "
+                f"record id column "
+                f"{self.upload_config.record_id_key}, skipping delete"
+            )
         self.upload_contents(path=path)
```
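The SQL uploader now aligns the staged DataFrame with the destination table's columns before inserting, and deletes any prior rows carrying the same record id. A self-contained sketch of the intended column alignment, with hypothetical column names:

```python
# Minimal sketch, not part of the diff: aligning staged data with a table's columns,
# in the spirit of the new _fit_to_schema helper. Column names here are hypothetical.
import pandas as pd


def fit_to_schema(df: pd.DataFrame, table_columns: list[str]) -> pd.DataFrame:
    present, expected = set(df.columns), set(table_columns)
    df = df.drop(columns=list(present - expected))  # drop columns the table doesn't have
    for missing in expected - present:  # add null-filled columns the table expects
        df[missing] = pd.Series(dtype=object)
    return df[table_columns]  # stable column order for the INSERT statement


staged = pd.DataFrame([{"text": "hello", "page_number": 1, "unexpected": "x"}])
print(fit_to_schema(staged, ["text", "page_number", "record_id"]))
```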
unstructured_ingest/v2/processes/connectors/weaviate/__init__.py

```diff
@@ -0,0 +1,22 @@
+from __future__ import annotations
+
+from unstructured_ingest.v2.processes.connector_registry import (
+    add_destination_entry,
+)
+
+from .cloud import CONNECTOR_TYPE as CLOUD_WEAVIATE_CONNECTOR_TYPE
+from .cloud import weaviate_cloud_destination_entry
+from .embedded import CONNECTOR_TYPE as EMBEDDED_WEAVIATE_CONNECTOR_TYPE
+from .embedded import weaviate_embedded_destination_entry
+from .local import CONNECTOR_TYPE as LOCAL_WEAVIATE_CONNECTOR_TYPE
+from .local import weaviate_local_destination_entry
+
+add_destination_entry(
+    destination_type=LOCAL_WEAVIATE_CONNECTOR_TYPE, entry=weaviate_local_destination_entry
+)
+add_destination_entry(
+    destination_type=CLOUD_WEAVIATE_CONNECTOR_TYPE, entry=weaviate_cloud_destination_entry
+)
+add_destination_entry(
+    destination_type=EMBEDDED_WEAVIATE_CONNECTOR_TYPE, entry=weaviate_embedded_destination_entry
+)
```
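The new package `__init__` follows the connector-registry pattern already used elsewhere in the project: each submodule exposes a `CONNECTOR_TYPE` string and a destination entry, and importing the package registers all three. A toy sketch of that pattern (only "weaviate-cloud" is confirmed by this diff; the other type strings are assumptions):

```python
# Minimal sketch, not part of the diff: the register-on-import pattern used by the
# new weaviate package. Type strings other than "weaviate-cloud" are assumptions.
from dataclasses import dataclass


@dataclass
class DestinationEntry:  # stand-in for DestinationRegistryEntry
    uploader: str


_destination_registry: dict[str, DestinationEntry] = {}


def add_destination_entry(destination_type: str, entry: DestinationEntry) -> None:
    _destination_registry[destination_type] = entry


for connector_type in ("weaviate-local", "weaviate-cloud", "weaviate-embedded"):
    add_destination_entry(connector_type, DestinationEntry(uploader=connector_type))

print(sorted(_destination_registry))
```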
unstructured_ingest/v2/processes/connectors/weaviate/cloud.py

```diff
@@ -0,0 +1,164 @@
+from contextlib import contextmanager
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any, Generator, Optional
+
+from pydantic import Field, Secret
+
+from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry
+from unstructured_ingest.v2.processes.connectors.weaviate.weaviate import (
+    WeaviateAccessConfig,
+    WeaviateConnectionConfig,
+    WeaviateUploader,
+    WeaviateUploaderConfig,
+    WeaviateUploadStager,
+    WeaviateUploadStagerConfig,
+)
+
+if TYPE_CHECKING:
+    from weaviate.auth import AuthCredentials
+    from weaviate.client import WeaviateClient
+
+CONNECTOR_TYPE = "weaviate-cloud"
+
+
+class CloudWeaviateAccessConfig(WeaviateAccessConfig):
+    access_token: Optional[str] = Field(
+        default=None, description="Used to create the bearer token."
+    )
+    api_key: Optional[str] = None
+    client_secret: Optional[str] = None
+    password: Optional[str] = None
+
+
+class CloudWeaviateConnectionConfig(WeaviateConnectionConfig):
+    cluster_url: str = Field(
+        description="The WCD cluster URL or hostname to connect to. "
+        "Usually in the form: rAnD0mD1g1t5.something.weaviate.cloud"
+    )
+    username: Optional[str] = None
+    anonymous: bool = Field(default=False, description="if set, all auth values will be ignored")
+    refresh_token: Optional[str] = Field(
+        default=None,
+        description="Will tie this value to the bearer token. If not provided, "
+        "the authentication will expire once the lifetime of the access token is up.",
+    )
+    access_config: Secret[CloudWeaviateAccessConfig]
+
+    def model_post_init(self, __context: Any) -> None:
+        if self.anonymous:
+            return
+        access_config = self.access_config.get_secret_value()
+        auths = {
+            "api_key": access_config.api_key is not None,
+            "bearer_token": access_config.access_token is not None,
+            "client_secret": access_config.client_secret is not None,
+            "client_password": access_config.password is not None and self.username is not None,
+        }
+        if len(auths) == 0:
+            raise ValueError("No auth values provided and anonymous is False")
+        if len(auths) > 1:
+            existing_auths = [auth_method for auth_method, flag in auths.items() if flag]
+            raise ValueError(
+                "Multiple auth values provided, only one approach can be used: {}".format(
+                    ", ".join(existing_auths)
+                )
+            )
+
+    @requires_dependencies(["weaviate"], extras="weaviate")
+    def get_api_key_auth(self) -> Optional["AuthCredentials"]:
+        from weaviate.classes.init import Auth
+
+        if api_key := self.access_config.get_secret_value().api_key:
+            return Auth.api_key(api_key=api_key)
+        return None
+
+    @requires_dependencies(["weaviate"], extras="weaviate")
+    def get_bearer_token_auth(self) -> Optional["AuthCredentials"]:
+        from weaviate.classes.init import Auth
+
+        if access_token := self.access_config.get_secret_value().access_token:
+            return Auth.bearer_token(access_token=access_token, refresh_token=self.refresh_token)
+        return None
+
+    @requires_dependencies(["weaviate"], extras="weaviate")
+    def get_client_secret_auth(self) -> Optional["AuthCredentials"]:
+        from weaviate.classes.init import Auth
+
+        if client_secret := self.access_config.get_secret_value().client_secret:
+            return Auth.client_credentials(client_secret=client_secret)
+        return None
+
+    @requires_dependencies(["weaviate"], extras="weaviate")
+    def get_client_password_auth(self) -> Optional["AuthCredentials"]:
+        from weaviate.classes.init import Auth
+
+        if (username := self.username) and (
+            password := self.access_config.get_secret_value().password
+        ):
+            return Auth.client_password(username=username, password=password)
+        return None
+
+    @requires_dependencies(["weaviate"], extras="weaviate")
+    def get_auth(self) -> "AuthCredentials":
+        auths = [
+            self.get_api_key_auth(),
+            self.get_client_secret_auth(),
+            self.get_bearer_token_auth(),
+            self.get_client_password_auth(),
+        ]
+        auths = [auth for auth in auths if auth]
+        if len(auths) == 0:
+            raise ValueError("No auth values provided and anonymous is False")
+        if len(auths) > 1:
+            raise ValueError("Multiple auth values provided, only one approach can be used")
+        return auths[0]
+
+    @contextmanager
+    @requires_dependencies(["weaviate"], extras="weaviate")
+    def get_client(self) -> Generator["WeaviateClient", None, None]:
+        from weaviate import connect_to_weaviate_cloud
+        from weaviate.classes.init import AdditionalConfig
+
+        auth_credentials = None if self.anonymous else self.get_auth()
+        with connect_to_weaviate_cloud(
+            cluster_url=self.cluster_url,
+            auth_credentials=auth_credentials,
+            additional_config=AdditionalConfig(timeout=self.get_timeout()),
+        ) as weaviate_client:
+            yield weaviate_client
+
+
+class CloudWeaviateUploadStagerConfig(WeaviateUploadStagerConfig):
+    pass
+
+
+@dataclass
+class CloudWeaviateUploadStager(WeaviateUploadStager):
+    upload_stager_config: CloudWeaviateUploadStagerConfig = field(
+        default_factory=lambda: WeaviateUploadStagerConfig()
+    )
+
+
+class CloudWeaviateUploaderConfig(WeaviateUploaderConfig):
+    pass
+
+
+@dataclass
+class CloudWeaviateUploader(WeaviateUploader):
+    connection_config: CloudWeaviateConnectionConfig = field(
+        default_factory=lambda: CloudWeaviateConnectionConfig()
+    )
+    upload_config: CloudWeaviateUploaderConfig = field(
+        default_factory=lambda: CloudWeaviateUploaderConfig()
+    )
+    connector_type: str = CONNECTOR_TYPE
+
+
+weaviate_cloud_destination_entry = DestinationRegistryEntry(
+    connection_config=CloudWeaviateConnectionConfig,
+    uploader=CloudWeaviateUploader,
+    uploader_config=CloudWeaviateUploaderConfig,
+    upload_stager=CloudWeaviateUploadStager,
+    upload_stager_config=CloudWeaviateUploadStagerConfig,
+)
```
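For reference, a usage sketch of how the new cloud connection config opens its client: the same weaviate-client v4 calls shown above (`connect_to_weaviate_cloud`, `Auth.api_key`), here with placeholder credentials.

```python
# Minimal sketch, not part of the diff: connecting to Weaviate Cloud the way
# CloudWeaviateConnectionConfig.get_client() does. URL and key are placeholders.
import weaviate
from weaviate.classes.init import Auth

with weaviate.connect_to_weaviate_cloud(
    cluster_url="rAnD0mD1g1t5.something.weaviate.cloud",  # placeholder cluster URL
    auth_credentials=Auth.api_key(api_key="YOUR-WCD-API-KEY"),  # placeholder key
) as client:
    print(client.is_ready())  # sanity-check the connection before writing collections
```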