unstructured-ingest 0.2.2__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of unstructured-ingest might be problematic.
- test/integration/connectors/elasticsearch/__init__.py +0 -0
- test/integration/connectors/elasticsearch/conftest.py +34 -0
- test/integration/connectors/elasticsearch/test_elasticsearch.py +308 -0
- test/integration/connectors/elasticsearch/test_opensearch.py +302 -0
- test/integration/connectors/sql/test_postgres.py +10 -4
- test/integration/connectors/sql/test_singlestore.py +8 -4
- test/integration/connectors/sql/test_snowflake.py +10 -6
- test/integration/connectors/sql/test_sqlite.py +4 -4
- test/integration/connectors/test_astradb.py +156 -0
- test/integration/connectors/test_azure_cog_search.py +233 -0
- test/integration/connectors/test_delta_table.py +46 -0
- test/integration/connectors/test_kafka.py +150 -16
- test/integration/connectors/test_lancedb.py +209 -0
- test/integration/connectors/test_milvus.py +141 -0
- test/integration/connectors/test_pinecone.py +213 -0
- test/integration/connectors/test_s3.py +23 -0
- test/integration/connectors/utils/docker.py +81 -15
- test/integration/connectors/utils/validation.py +10 -0
- test/integration/connectors/weaviate/__init__.py +0 -0
- test/integration/connectors/weaviate/conftest.py +15 -0
- test/integration/connectors/weaviate/test_local.py +131 -0
- test/unit/v2/__init__.py +0 -0
- test/unit/v2/chunkers/__init__.py +0 -0
- test/unit/v2/chunkers/test_chunkers.py +49 -0
- test/unit/v2/connectors/__init__.py +0 -0
- test/unit/v2/embedders/__init__.py +0 -0
- test/unit/v2/embedders/test_bedrock.py +36 -0
- test/unit/v2/embedders/test_huggingface.py +48 -0
- test/unit/v2/embedders/test_mixedbread.py +37 -0
- test/unit/v2/embedders/test_octoai.py +35 -0
- test/unit/v2/embedders/test_openai.py +35 -0
- test/unit/v2/embedders/test_togetherai.py +37 -0
- test/unit/v2/embedders/test_vertexai.py +37 -0
- test/unit/v2/embedders/test_voyageai.py +38 -0
- test/unit/v2/partitioners/__init__.py +0 -0
- test/unit/v2/partitioners/test_partitioner.py +63 -0
- test/unit/v2/utils/__init__.py +0 -0
- test/unit/v2/utils/data_generator.py +32 -0
- unstructured_ingest/__version__.py +1 -1
- unstructured_ingest/cli/cmds/__init__.py +2 -2
- unstructured_ingest/cli/cmds/{azure_cognitive_search.py → azure_ai_search.py} +9 -9
- unstructured_ingest/connector/{azure_cognitive_search.py → azure_ai_search.py} +9 -9
- unstructured_ingest/pipeline/reformat/embedding.py +1 -1
- unstructured_ingest/runner/writers/__init__.py +2 -2
- unstructured_ingest/runner/writers/azure_ai_search.py +24 -0
- unstructured_ingest/utils/data_prep.py +9 -1
- unstructured_ingest/v2/constants.py +2 -0
- unstructured_ingest/v2/processes/connectors/__init__.py +7 -20
- unstructured_ingest/v2/processes/connectors/airtable.py +2 -2
- unstructured_ingest/v2/processes/connectors/astradb.py +35 -23
- unstructured_ingest/v2/processes/connectors/{azure_cognitive_search.py → azure_ai_search.py} +116 -35
- unstructured_ingest/v2/processes/connectors/confluence.py +2 -2
- unstructured_ingest/v2/processes/connectors/couchbase.py +1 -0
- unstructured_ingest/v2/processes/connectors/delta_table.py +37 -9
- unstructured_ingest/v2/processes/connectors/elasticsearch/__init__.py +19 -0
- unstructured_ingest/v2/processes/connectors/{elasticsearch.py → elasticsearch/elasticsearch.py} +93 -46
- unstructured_ingest/v2/processes/connectors/{opensearch.py → elasticsearch/opensearch.py} +1 -1
- unstructured_ingest/v2/processes/connectors/fsspec/fsspec.py +27 -0
- unstructured_ingest/v2/processes/connectors/google_drive.py +3 -3
- unstructured_ingest/v2/processes/connectors/kafka/__init__.py +6 -2
- unstructured_ingest/v2/processes/connectors/kafka/cloud.py +38 -2
- unstructured_ingest/v2/processes/connectors/kafka/kafka.py +84 -23
- unstructured_ingest/v2/processes/connectors/kafka/local.py +32 -4
- unstructured_ingest/v2/processes/connectors/lancedb/__init__.py +17 -0
- unstructured_ingest/v2/processes/connectors/lancedb/aws.py +43 -0
- unstructured_ingest/v2/processes/connectors/lancedb/azure.py +43 -0
- unstructured_ingest/v2/processes/connectors/lancedb/gcp.py +44 -0
- unstructured_ingest/v2/processes/connectors/lancedb/lancedb.py +161 -0
- unstructured_ingest/v2/processes/connectors/lancedb/local.py +44 -0
- unstructured_ingest/v2/processes/connectors/milvus.py +72 -27
- unstructured_ingest/v2/processes/connectors/onedrive.py +2 -3
- unstructured_ingest/v2/processes/connectors/outlook.py +2 -2
- unstructured_ingest/v2/processes/connectors/pinecone.py +101 -13
- unstructured_ingest/v2/processes/connectors/sharepoint.py +3 -2
- unstructured_ingest/v2/processes/connectors/slack.py +2 -2
- unstructured_ingest/v2/processes/connectors/sql/postgres.py +16 -8
- unstructured_ingest/v2/processes/connectors/sql/sql.py +97 -26
- unstructured_ingest/v2/processes/connectors/weaviate/__init__.py +22 -0
- unstructured_ingest/v2/processes/connectors/weaviate/cloud.py +164 -0
- unstructured_ingest/v2/processes/connectors/weaviate/embedded.py +90 -0
- unstructured_ingest/v2/processes/connectors/weaviate/local.py +73 -0
- unstructured_ingest/v2/processes/connectors/weaviate/weaviate.py +289 -0
- {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/METADATA +20 -19
- {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/RECORD +91 -50
- unstructured_ingest/runner/writers/azure_cognitive_search.py +0 -24
- unstructured_ingest/v2/processes/connectors/weaviate.py +0 -242
- /test/integration/embedders/{togetherai.py → test_togetherai.py} +0 -0
- /test/unit/{test_interfaces_v2.py → v2/test_interfaces.py} +0 -0
- /test/unit/{test_utils_v2.py → v2/test_utils.py} +0 -0
- {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/LICENSE.md +0 -0
- {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/WHEEL +0 -0
- {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/entry_points.txt +0 -0
- {unstructured_ingest-0.2.2.dist-info → unstructured_ingest-0.3.1.dist-info}/top_level.txt +0 -0

unstructured_ingest/v2/processes/connectors/pinecone.py

@@ -9,6 +9,7 @@ from pydantic import Field, Secret
 from unstructured_ingest.error import DestinationConnectionError
 from unstructured_ingest.utils.data_prep import flatten_dict, generator_batching_wbytes
 from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.v2.constants import RECORD_ID_LABEL
 from unstructured_ingest.v2.interfaces import (
     AccessConfig,
     ConnectionConfig,
@@ -23,11 +24,13 @@ from unstructured_ingest.v2.processes.connector_registry import DestinationRegis

 if TYPE_CHECKING:
     from pinecone import Index as PineconeIndex
+    from pinecone import Pinecone


 CONNECTOR_TYPE = "pinecone"
 MAX_PAYLOAD_SIZE = 2 * 1024 * 1024  # 2MB
 MAX_POOL_THREADS = 100
+MAX_METADATA_BYTES = 40960  # 40KB https://docs.pinecone.io/reference/quotas-and-limits#hard-limits


 class PineconeAccessConfig(AccessConfig):
@@ -43,16 +46,19 @@ class PineconeConnectionConfig(ConnectionConfig):
     )

     @requires_dependencies(["pinecone"], extras="pinecone")
-    def get_index(self, **index_kwargs) -> "PineconeIndex":
+    def get_client(self, **index_kwargs) -> "Pinecone":
         from pinecone import Pinecone

         from unstructured_ingest import __version__ as unstructured_version

-        pc = Pinecone(
+        return Pinecone(
             api_key=self.access_config.get_secret_value().pinecone_api_key,
             source_tag=f"unstructured_ingest=={unstructured_version}",
         )

+    def get_index(self, **index_kwargs) -> "PineconeIndex":
+        pc = self.get_client()
+
         index = pc.Index(name=self.index_name, **index_kwargs)
         logger.debug(f"connected to index: {pc.describe_index(self.index_name)}")
         return index
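
The connection config now splits client construction from index access: get_client() returns the bare Pinecone client (later used to describe the index), and get_index() builds an Index handle on top of it. A hedged caller-side sketch, assuming only the fields visible in this diff (index_name and an access config with pinecone_api_key); the index name and key are placeholders, and this is not code from the package:

# Hypothetical usage sketch of the new get_client()/get_index() split.
from unstructured_ingest.v2.processes.connectors.pinecone import (
    PineconeAccessConfig,
    PineconeConnectionConfig,
)

connection_config = PineconeConnectionConfig(
    index_name="my-index",  # placeholder index name
    access_config=PineconeAccessConfig(pinecone_api_key="PINECONE_API_KEY"),  # placeholder key
)

pc = connection_config.get_client()                    # bare Pinecone client
print(pc.describe_index("my-index"))                   # e.g. inspect the index spec
index = connection_config.get_index(pool_threads=10)   # Index handle for upserts and deletes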
@@ -98,6 +104,10 @@ class PineconeUploaderConfig(UploaderConfig):
         default=None,
         description="The namespace to write to. If not specified, the default namespace is used",
     )
+    record_id_key: str = Field(
+        default=RECORD_ID_LABEL,
+        description="searchable key to find entries for the same record on previous runs",
+    )


 @dataclass
@@ -106,7 +116,7 @@ class PineconeUploadStager(UploadStager):
         default_factory=lambda: PineconeUploadStagerConfig()
     )

-    def conform_dict(self, element_dict: dict) -> dict:
+    def conform_dict(self, element_dict: dict, file_data: FileData) -> dict:
         embeddings = element_dict.pop("embeddings", None)
         metadata: dict[str, Any] = element_dict.pop("metadata", {})
         data_source = metadata.pop("data_source", {})
@@ -121,19 +131,30 @@
             }
         )

+        metadata = flatten_dict(
+            pinecone_metadata,
+            separator="-",
+            flatten_lists=True,
+            remove_none=True,
+        )
+        metadata[RECORD_ID_LABEL] = file_data.identifier
+        metadata_size_bytes = len(json.dumps(metadata).encode())
+        if metadata_size_bytes > MAX_METADATA_BYTES:
+            logger.info(
+                f"Metadata size is {metadata_size_bytes} bytes, which exceeds the limit of"
+                f" {MAX_METADATA_BYTES} bytes per vector. Dropping the metadata."
+            )
+            metadata = {}
+
         return {
             "id": str(uuid.uuid4()),
             "values": embeddings,
-            "metadata": flatten_dict(
-                pinecone_metadata,
-                separator="-",
-                flatten_lists=True,
-                remove_none=True,
-            ),
+            "metadata": metadata,
         }

     def run(
         self,
+        file_data: FileData,
         elements_filepath: Path,
         output_dir: Path,
         output_filename: str,
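
conform_dict now flattens the element metadata, stamps it with the originating record id, and drops the metadata entirely if its serialized size would exceed Pinecone's per-vector limit. Below is a standalone sketch of that guard, with a tiny local flatten helper standing in for the package's flatten_dict; the "record_id" key name is illustrative (the connector uses RECORD_ID_LABEL):

import json

MAX_METADATA_BYTES = 40960  # Pinecone's documented per-vector metadata limit (40KB)


def flatten(prefix, value, out, separator="-"):
    # Tiny flatten helper for the sketch; the package uses its own flatten_dict.
    if isinstance(value, dict):
        for k, v in value.items():
            flatten(f"{prefix}{separator}{k}" if prefix else str(k), v, out, separator)
    elif isinstance(value, list):
        for i, v in enumerate(value):
            flatten(f"{prefix}{separator}{i}", v, out, separator)
    elif value is not None:
        out[prefix] = value


def conform_metadata(metadata: dict, record_id: str) -> dict:
    flat = {}
    flatten("", metadata, flat)
    flat["record_id"] = record_id  # RECORD_ID_LABEL in the connector
    if len(json.dumps(flat).encode()) > MAX_METADATA_BYTES:
        return {}  # too large for one vector: drop metadata rather than fail the upsert
    return flat


print(conform_metadata({"filename": "a.pdf", "languages": ["eng", "fra"]}, record_id="doc-1"))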
@@ -143,10 +164,15 @@
             elements_contents = json.load(elements_file)

         conformed_elements = [
-            self.conform_dict(element_dict=element) for element in elements_contents
+            self.conform_dict(element_dict=element, file_data=file_data)
+            for element in elements_contents
         ]

-        output_path = Path(output_dir) / Path(f"{output_filename}.json")
+        if Path(output_filename).suffix != ".json":
+            output_filename = f"{output_filename}.json"
+        else:
+            output_filename = f"{Path(output_filename).stem}.json"
+        output_path = Path(output_dir) / Path(f"{output_filename}")
         output_path.parent.mkdir(parents=True, exist_ok=True)

         with open(output_path, "w") as output_file:
@@ -167,6 +193,61 @@ class PineconeUploader(Uploader):
             logger.error(f"failed to validate connection: {e}", exc_info=True)
             raise DestinationConnectionError(f"failed to validate connection: {e}")

+    def pod_delete_by_record_id(self, file_data: FileData) -> None:
+        logger.debug(
+            f"deleting any content with metadata "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
+            f"from pinecone pod index"
+        )
+        index = self.connection_config.get_index(pool_threads=MAX_POOL_THREADS)
+        delete_kwargs = {
+            "filter": {self.upload_config.record_id_key: {"$eq": file_data.identifier}}
+        }
+        if namespace := self.upload_config.namespace:
+            delete_kwargs["namespace"] = namespace
+
+        resp = index.delete(**delete_kwargs)
+        logger.debug(
+            f"deleted any content with metadata "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
+            f"from pinecone index: {resp}"
+        )
+
+    def serverless_delete_by_record_id(self, file_data: FileData) -> None:
+        logger.debug(
+            f"deleting any content with metadata "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
+            f"from pinecone serverless index"
+        )
+        index = self.connection_config.get_index(pool_threads=MAX_POOL_THREADS)
+        index_stats = index.describe_index_stats()
+        total_vectors = index_stats["total_vector_count"]
+        if total_vectors == 0:
+            return
+        dimension = index_stats["dimension"]
+        query_params = {
+            "filter": {self.upload_config.record_id_key: {"$eq": file_data.identifier}},
+            "vector": [0] * dimension,
+            "top_k": total_vectors,
+        }
+        if namespace := self.upload_config.namespace:
+            query_params["namespace"] = namespace
+        while True:
+            query_results = index.query(**query_params)
+            matches = query_results.get("matches", [])
+            if not matches:
+                break
+            ids = [match["id"] for match in matches]
+            delete_params = {"ids": ids}
+            if namespace := self.upload_config.namespace:
+                delete_params["namespace"] = namespace
+            index.delete(**delete_params)
+        logger.debug(
+            f"deleted any content with metadata "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
+            f"from pinecone index"
+        )
+
     @requires_dependencies(["pinecone"], extras="pinecone")
     def upsert_batches_async(self, elements_dict: list[dict]):
         from pinecone.exceptions import PineconeApiException
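
The two delete paths exist because pod-based indexes accept a metadata filter directly on delete(), while serverless indexes do not: the serverless path instead queries for matching ids (a zero vector with top_k set to the total vector count) and deletes those ids until no matches remain. A self-contained sketch of that query-then-delete loop against a stand-in index object; FakeIndex and its methods are illustrative, not the Pinecone SDK:

class FakeIndex:
    """Stand-in for a vector index: maps vector id -> metadata."""

    def __init__(self, vectors: dict):
        self.vectors = vectors

    def query(self, vector, top_k, filter):
        key, cond = next(iter(filter.items()))
        ids = [i for i, meta in self.vectors.items() if meta.get(key) == cond["$eq"]]
        return {"matches": [{"id": i} for i in ids[:top_k]]}

    def delete(self, ids):
        for i in ids:
            self.vectors.pop(i, None)


def delete_by_record_id(index, record_id_key, record_id, dimension, total_vectors):
    query_params = {
        "filter": {record_id_key: {"$eq": record_id}},
        "vector": [0] * dimension,  # dummy vector; only the filter matters here
        "top_k": total_vectors,
    }
    while True:
        matches = index.query(**query_params).get("matches", [])
        if not matches:
            break
        index.delete(ids=[m["id"] for m in matches])


index = FakeIndex({"v1": {"record_id": "doc-1"}, "v2": {"record_id": "doc-2"}})
delete_by_record_id(index, "record_id", "doc-1", dimension=8, total_vectors=2)
print(index.vectors)  # only the vector for doc-2 remains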
@@ -206,9 +287,16 @@
             f"writing a total of {len(elements_dict)} elements via"
             f" document batches to destination"
             f" index named {self.connection_config.index_name}"
-            f" with batch size {self.upload_config.batch_size}"
         )
-
+        # Determine if serverless or pod based index
+        pinecone_client = self.connection_config.get_client()
+        index_description = pinecone_client.describe_index(name=self.connection_config.index_name)
+        if "serverless" in index_description.get("spec"):
+            self.serverless_delete_by_record_id(file_data=file_data)
+        elif "pod" in index_description.get("spec"):
+            self.pod_delete_by_record_id(file_data=file_data)
+        else:
+            raise ValueError(f"unexpected spec type in index description: {index_description}")
         self.upsert_batches_async(elements_dict=elements_dict)



unstructured_ingest/v2/processes/connectors/sharepoint.py

@@ -21,7 +21,6 @@ from unstructured_ingest.v2.interfaces import (
     Indexer,
     IndexerConfig,
     SourceIdentifiers,
-    download_responses,
 )
 from unstructured_ingest.v2.logger import logger
 from unstructured_ingest.v2.processes.connector_registry import (
@@ -426,7 +425,7 @@ class SharepointDownloader(Downloader):
             f.write(etree.tostring(document, encoding="unicode", pretty_print=True))
         return self.generate_download_response(file_data=file_data, download_path=download_path)

-    def run(self, file_data: FileData, **kwargs: Any) -> download_responses:
+    def run(self, file_data: FileData, **kwargs: Any) -> DownloadResponse:
         content_type = file_data.additional_metadata.get("sharepoint_content_type")
         if not content_type:
             raise ValueError(
@@ -436,6 +435,8 @@
             return self.get_document(file_data=file_data)
         elif content_type == SharepointContentType.SITEPAGE.value:
             return self.get_site_page(file_data=file_data)
+        else:
+            raise ValueError(f"content type not recognized: {content_type}")


 sharepoint_source_entry = SourceRegistryEntry(

unstructured_ingest/v2/processes/connectors/slack.py

@@ -16,9 +16,9 @@ from unstructured_ingest.v2.interfaces import (
     ConnectionConfig,
     Downloader,
     DownloaderConfig,
+    DownloadResponse,
     Indexer,
     IndexerConfig,
-    download_responses,
 )
 from unstructured_ingest.v2.interfaces.file_data import (
     FileData,
@@ -161,7 +161,7 @@
     def run(self, file_data, **kwargs):
         raise NotImplementedError

-    async def run_async(self, file_data: FileData, **kwargs) -> download_responses:
+    async def run_async(self, file_data: FileData, **kwargs) -> DownloadResponse:
         # NOTE: Indexer should provide source identifiers required to generate the download path
         download_path = self.get_download_path(file_data)
         if download_path is None:

unstructured_ingest/v2/processes/connectors/sql/postgres.py

@@ -98,20 +98,28 @@ class PostgresDownloader(SQLDownloader):
     download_config: PostgresDownloaderConfig
     connector_type: str = CONNECTOR_TYPE

+    @requires_dependencies(["psycopg2"], extras="postgres")
     def query_db(self, file_data: FileData) -> tuple[list[tuple], list[str]]:
+        from psycopg2 import sql
+
         table_name = file_data.additional_metadata["table_name"]
         id_column = file_data.additional_metadata["id_column"]
-        ids = file_data.additional_metadata["ids"]
+        ids = tuple(file_data.additional_metadata["ids"])
+
         with self.connection_config.get_cursor() as cursor:
-            fields =
-
+            fields = (
+                sql.SQL(",").join(sql.Identifier(field) for field in self.download_config.fields)
+                if self.download_config.fields
+                else sql.SQL("*")
+            )
+
+            query = sql.SQL("SELECT {fields} FROM {table_name} WHERE {id_column} IN %s").format(
                 fields=fields,
-                table_name=table_name,
-                id_column=id_column,
-                ids=",".join([str(i) for i in ids]),
+                table_name=sql.Identifier(table_name),
+                id_column=sql.Identifier(id_column),
             )
-            logger.debug(f"running query: {query}")
-            cursor.execute(query)
+            logger.debug(f"running query: {cursor.mogrify(query, (ids,))}")
+            cursor.execute(query, (ids,))
             rows = cursor.fetchall()
             columns = [col[0] for col in cursor.description]
             return rows, columns
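
The rewritten query_db builds identifiers with psycopg2's sql module and binds the id tuple as a query parameter instead of formatting strings, which avoids quoting and injection problems and lets cursor.mogrify render the final statement for the debug log. A minimal sketch of the same composition pattern; the DSN, table, and column names below are placeholders, not values from the package:

import psycopg2
from psycopg2 import sql

conn = psycopg2.connect("dbname=ingest user=ingest password=secret host=localhost")  # placeholder DSN
ids = tuple([1, 2, 3])

query = sql.SQL("SELECT {fields} FROM {table_name} WHERE {id_column} IN %s").format(
    fields=sql.SQL(",").join([sql.Identifier("id"), sql.Identifier("text")]),
    table_name=sql.Identifier("elements"),
    id_column=sql.Identifier("id"),
)

with conn.cursor() as cursor:
    print(cursor.mogrify(query, (ids,)))  # rendered statement, useful for debugging
    cursor.execute(query, (ids,))
    print(cursor.fetchall())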

unstructured_ingest/v2/processes/connectors/sql/sql.py

@@ -16,6 +16,8 @@ from dateutil import parser
 from pydantic import Field, Secret

 from unstructured_ingest.error import DestinationConnectionError, SourceConnectionError
+from unstructured_ingest.utils.data_prep import split_dataframe
+from unstructured_ingest.v2.constants import RECORD_ID_LABEL
 from unstructured_ingest.v2.interfaces import (
     AccessConfig,
     ConnectionConfig,
@@ -236,35 +238,25 @@ class SQLUploadStagerConfig(UploadStagerConfig):
 class SQLUploadStager(UploadStager):
     upload_stager_config: SQLUploadStagerConfig = field(default_factory=SQLUploadStagerConfig)

-    def run(
-        self,
-        elements_filepath: Path,
-        file_data: FileData,
-        output_dir: Path,
-        output_filename: str,
-        **kwargs: Any,
-    ) -> Path:
-        with open(elements_filepath) as elements_file:
-            elements_contents: list[dict] = json.load(elements_file)
-        output_path = Path(output_dir) / Path(f"{output_filename}.json")
-        output_path.parent.mkdir(parents=True, exist_ok=True)
-
+    @staticmethod
+    def conform_dict(data: dict, file_data: FileData) -> pd.DataFrame:
+        working_data = data.copy()
         output = []
-        for
-            metadata: dict[str, Any] =
+        for element in working_data:
+            metadata: dict[str, Any] = element.pop("metadata", {})
             data_source = metadata.pop("data_source", {})
             coordinates = metadata.pop("coordinates", {})

-
-
-
+            element.update(metadata)
+            element.update(data_source)
+            element.update(coordinates)

-
+            element["id"] = str(uuid.uuid4())

             # remove extraneous, not supported columns
-
-
-            output.append(
+            element = {k: v for k, v in element.items() if k in _COLUMNS}
+            element[RECORD_ID_LABEL] = file_data.identifier
+            output.append(element)

         df = pd.DataFrame.from_dict(output)
         for column in filter(lambda x: x in df.columns, _DATE_COLUMNS):
@@ -281,6 +273,26 @@
             ("version", "page_number", "regex_metadata"),
         ):
             df[column] = df[column].apply(str)
+        return df
+
+    def run(
+        self,
+        elements_filepath: Path,
+        file_data: FileData,
+        output_dir: Path,
+        output_filename: str,
+        **kwargs: Any,
+    ) -> Path:
+        with open(elements_filepath) as elements_file:
+            elements_contents: list[dict] = json.load(elements_file)
+
+        df = self.conform_dict(data=elements_contents, file_data=file_data)
+        if Path(output_filename).suffix != ".json":
+            output_filename = f"{output_filename}.json"
+        else:
+            output_filename = f"{Path(output_filename).stem}.json"
+        output_path = Path(output_dir) / Path(f"{output_filename}")
+        output_path.parent.mkdir(parents=True, exist_ok=True)

         with output_path.open("w") as output_file:
             df.to_json(output_file, orient="records", lines=True)
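
The stager now does all of its element-to-row shaping in conform_dict (hoist nested metadata onto the element, keep only known columns, stamp the record id) and run() simply writes the resulting DataFrame as JSON Lines for the uploader to stream. A hedged sketch of that flow with a toy column set; _COLUMNS and the "record_id" key below stand in for the connector's own constants:

import uuid
from pathlib import Path

import pandas as pd

_COLUMNS = ("id", "text", "filename", "record_id")  # stand-in for the connector's column set


def conform(elements: list, record_id: str) -> pd.DataFrame:
    output = []
    for element in elements:
        metadata = element.pop("metadata", {})
        element.update(metadata)                      # hoist nested metadata onto the row
        element["id"] = str(uuid.uuid4())
        element = {k: v for k, v in element.items() if k in _COLUMNS}
        element["record_id"] = record_id              # RECORD_ID_LABEL in the connector
        output.append(element)
    return pd.DataFrame.from_dict(output)


df = conform([{"text": "Hello", "metadata": {"filename": "a.pdf"}}], record_id="doc-1")
staged = Path("staged.json")
with staged.open("w") as f:
    df.to_json(f, orient="records", lines=True)
print(staged.read_text())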
@@ -290,6 +302,10 @@
 class SQLUploaderConfig(UploaderConfig):
     batch_size: int = Field(default=50, description="Number of records per batch")
     table_name: str = Field(default="elements", description="which table to upload contents to")
+    record_id_key: str = Field(
+        default=RECORD_ID_LABEL,
+        description="searchable key to find entries for the same record on previous runs",
+    )


 @dataclass
@@ -323,18 +339,45 @@ class SQLUploader(Uploader):
             output.append(tuple(parsed))
         return output

+    def _fit_to_schema(self, df: pd.DataFrame, columns: list[str]) -> pd.DataFrame:
+        columns = set(df.columns)
+        schema_fields = set(columns)
+        columns_to_drop = columns - schema_fields
+        missing_columns = schema_fields - columns
+
+        if columns_to_drop:
+            logger.warning(
+                "Following columns will be dropped to match the table's schema: "
+                f"{', '.join(columns_to_drop)}"
+            )
+        if missing_columns:
+            logger.info(
+                "Following null filled columns will be added to match the table's schema:"
+                f" {', '.join(missing_columns)} "
+            )
+
+        df = df.drop(columns=columns_to_drop)
+
+        for column in missing_columns:
+            df[column] = pd.Series()
+
     def upload_contents(self, path: Path) -> None:
         df = pd.read_json(path, orient="records", lines=True)
         df.replace({np.nan: None}, inplace=True)
+        self._fit_to_schema(df=df, columns=self.get_table_columns())

         columns = list(df.columns)
         stmt = f"INSERT INTO {self.upload_config.table_name} ({','.join(columns)}) VALUES({','.join([self.values_delimiter for x in columns])})"  # noqa E501
-
-
-
-
+        logger.info(
+            f"writing a total of {len(df)} elements via"
+            f" document batches to destination"
+            f" table named {self.upload_config.table_name}"
+            f" with batch size {self.upload_config.batch_size}"
+        )
+        for rows in split_dataframe(df=df, chunk_size=self.upload_config.batch_size):
             with self.connection_config.get_cursor() as cursor:
                 values = self.prepare_data(columns, tuple(rows.itertuples(index=False, name=None)))
+                # For debugging purposes:
                 # for val in values:
                 #     try:
                 #         cursor.execute(stmt, val)
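
upload_contents now fits the staged DataFrame to the destination table's columns and inserts it in fixed-size batches via split_dataframe from unstructured_ingest.utils.data_prep. The helper below only illustrates the assumed behavior of that utility (yield successive frames of at most chunk_size rows) feeding an executemany-style loop; it is a sketch, not the package's implementation:

from typing import Iterator

import pandas as pd


def split_dataframe(df: pd.DataFrame, chunk_size: int) -> Iterator[pd.DataFrame]:
    # Assumed behavior: hand back the frame in slices of chunk_size rows.
    for start in range(0, len(df), chunk_size):
        yield df.iloc[start : start + chunk_size]


df = pd.DataFrame({"id": range(5), "text": [f"row {i}" for i in range(5)]})
for rows in split_dataframe(df, chunk_size=2):
    values = list(rows.itertuples(index=False, name=None))
    print(f"would cursor.executemany() over {len(values)} rows: {values}")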
@@ -343,5 +386,33 @@
                 #         print(f"failed to write {len(columns)}, {len(val)}: {stmt} -> {val}")
                 cursor.executemany(stmt, values)

+    def get_table_columns(self) -> list[str]:
+        with self.connection_config.get_cursor() as cursor:
+            cursor.execute(f"SELECT * from {self.upload_config.table_name}")
+            return [desc[0] for desc in cursor.description]
+
+    def can_delete(self) -> bool:
+        return self.upload_config.record_id_key in self.get_table_columns()
+
+    def delete_by_record_id(self, file_data: FileData) -> None:
+        logger.debug(
+            f"deleting any content with data "
+            f"{self.upload_config.record_id_key}={file_data.identifier} "
+            f"from table {self.upload_config.table_name}"
+        )
+        stmt = f"DELETE FROM {self.upload_config.table_name} WHERE {self.upload_config.record_id_key} = {self.values_delimiter}"  # noqa: E501
+        with self.connection_config.get_cursor() as cursor:
+            cursor.execute(stmt, [file_data.identifier])
+            rowcount = cursor.rowcount
+            logger.info(f"deleted {rowcount} rows from table {self.upload_config.table_name}")
+
     def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
+        if self.can_delete():
+            self.delete_by_record_id(file_data=file_data)
+        else:
+            logger.warning(
+                f"table doesn't contain expected "
+                f"record id column "
+                f"{self.upload_config.record_id_key}, skipping delete"
+            )
         self.upload_contents(path=path)
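
With every staged row carrying the record id, a re-run of the same document becomes "delete the old rows for this record, then insert the fresh batch", and can_delete() skips the delete when the target table lacks the column. A self-contained sqlite3 sketch of that replace-on-rerun flow (the connector discovers columns via SELECT * and cursor.description; PRAGMA table_info is a local shortcut used only in this sketch):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE elements (id TEXT, text TEXT, record_id TEXT)")
conn.executemany(
    "INSERT INTO elements VALUES (?, ?, ?)",
    [("1", "old text", "doc-1"), ("2", "other doc", "doc-2")],
)

record_id = "doc-1"
table_columns = [row[1] for row in conn.execute("PRAGMA table_info(elements)")]
if "record_id" in table_columns:  # mirrors can_delete()
    deleted = conn.execute("DELETE FROM elements WHERE record_id = ?", (record_id,)).rowcount
    print(f"deleted {deleted} stale rows for {record_id}")
else:
    print("table has no record id column, skipping delete")

conn.executemany("INSERT INTO elements VALUES (?, ?, ?)", [("3", "new text", "doc-1")])
print(conn.execute("SELECT * FROM elements").fetchall())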

unstructured_ingest/v2/processes/connectors/weaviate/__init__.py

@@ -0,0 +1,22 @@
+from __future__ import annotations
+
+from unstructured_ingest.v2.processes.connector_registry import (
+    add_destination_entry,
+)
+
+from .cloud import CONNECTOR_TYPE as CLOUD_WEAVIATE_CONNECTOR_TYPE
+from .cloud import weaviate_cloud_destination_entry
+from .embedded import CONNECTOR_TYPE as EMBEDDED_WEAVIATE_CONNECTOR_TYPE
+from .embedded import weaviate_embedded_destination_entry
+from .local import CONNECTOR_TYPE as LOCAL_WEAVIATE_CONNECTOR_TYPE
+from .local import weaviate_local_destination_entry
+
+add_destination_entry(
+    destination_type=LOCAL_WEAVIATE_CONNECTOR_TYPE, entry=weaviate_local_destination_entry
+)
+add_destination_entry(
+    destination_type=CLOUD_WEAVIATE_CONNECTOR_TYPE, entry=weaviate_cloud_destination_entry
+)
+add_destination_entry(
+    destination_type=EMBEDDED_WEAVIATE_CONNECTOR_TYPE, entry=weaviate_embedded_destination_entry
+)

unstructured_ingest/v2/processes/connectors/weaviate/cloud.py

@@ -0,0 +1,164 @@
+from contextlib import contextmanager
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any, Generator, Optional
+
+from pydantic import Field, Secret
+
+from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry
+from unstructured_ingest.v2.processes.connectors.weaviate.weaviate import (
+    WeaviateAccessConfig,
+    WeaviateConnectionConfig,
+    WeaviateUploader,
+    WeaviateUploaderConfig,
+    WeaviateUploadStager,
+    WeaviateUploadStagerConfig,
+)
+
+if TYPE_CHECKING:
+    from weaviate.auth import AuthCredentials
+    from weaviate.client import WeaviateClient
+
+CONNECTOR_TYPE = "weaviate-cloud"
+
+
+class CloudWeaviateAccessConfig(WeaviateAccessConfig):
+    access_token: Optional[str] = Field(
+        default=None, description="Used to create the bearer token."
+    )
+    api_key: Optional[str] = None
+    client_secret: Optional[str] = None
+    password: Optional[str] = None
+
+
+class CloudWeaviateConnectionConfig(WeaviateConnectionConfig):
+    cluster_url: str = Field(
+        description="The WCD cluster URL or hostname to connect to. "
+        "Usually in the form: rAnD0mD1g1t5.something.weaviate.cloud"
+    )
+    username: Optional[str] = None
+    anonymous: bool = Field(default=False, description="if set, all auth values will be ignored")
+    refresh_token: Optional[str] = Field(
+        default=None,
+        description="Will tie this value to the bearer token. If not provided, "
+        "the authentication will expire once the lifetime of the access token is up.",
+    )
+    access_config: Secret[CloudWeaviateAccessConfig]
+
+    def model_post_init(self, __context: Any) -> None:
+        if self.anonymous:
+            return
+        access_config = self.access_config.get_secret_value()
+        auths = {
+            "api_key": access_config.api_key is not None,
+            "bearer_token": access_config.access_token is not None,
+            "client_secret": access_config.client_secret is not None,
+            "client_password": access_config.password is not None and self.username is not None,
+        }
+        if len(auths) == 0:
+            raise ValueError("No auth values provided and anonymous is False")
+        if len(auths) > 1:
+            existing_auths = [auth_method for auth_method, flag in auths.items() if flag]
+            raise ValueError(
+                "Multiple auth values provided, only one approach can be used: {}".format(
+                    ", ".join(existing_auths)
+                )
+            )
+
+    @requires_dependencies(["weaviate"], extras="weaviate")
+    def get_api_key_auth(self) -> Optional["AuthCredentials"]:
+        from weaviate.classes.init import Auth
+
+        if api_key := self.access_config.get_secret_value().api_key:
+            return Auth.api_key(api_key=api_key)
+        return None
+
+    @requires_dependencies(["weaviate"], extras="weaviate")
+    def get_bearer_token_auth(self) -> Optional["AuthCredentials"]:
+        from weaviate.classes.init import Auth
+
+        if access_token := self.access_config.get_secret_value().access_token:
+            return Auth.bearer_token(access_token=access_token, refresh_token=self.refresh_token)
+        return None
+
+    @requires_dependencies(["weaviate"], extras="weaviate")
+    def get_client_secret_auth(self) -> Optional["AuthCredentials"]:
+        from weaviate.classes.init import Auth
+
+        if client_secret := self.access_config.get_secret_value().client_secret:
+            return Auth.client_credentials(client_secret=client_secret)
+        return None
+
+    @requires_dependencies(["weaviate"], extras="weaviate")
+    def get_client_password_auth(self) -> Optional["AuthCredentials"]:
+        from weaviate.classes.init import Auth
+
+        if (username := self.username) and (
+            password := self.access_config.get_secret_value().password
+        ):
+            return Auth.client_password(username=username, password=password)
+        return None
+
+    @requires_dependencies(["weaviate"], extras="weaviate")
+    def get_auth(self) -> "AuthCredentials":
+        auths = [
+            self.get_api_key_auth(),
+            self.get_client_secret_auth(),
+            self.get_bearer_token_auth(),
+            self.get_client_password_auth(),
+        ]
+        auths = [auth for auth in auths if auth]
+        if len(auths) == 0:
+            raise ValueError("No auth values provided and anonymous is False")
+        if len(auths) > 1:
+            raise ValueError("Multiple auth values provided, only one approach can be used")
+        return auths[0]
+
+    @contextmanager
+    @requires_dependencies(["weaviate"], extras="weaviate")
+    def get_client(self) -> Generator["WeaviateClient", None, None]:
+        from weaviate import connect_to_weaviate_cloud
+        from weaviate.classes.init import AdditionalConfig
+
+        auth_credentials = None if self.anonymous else self.get_auth()
+        with connect_to_weaviate_cloud(
+            cluster_url=self.cluster_url,
+            auth_credentials=auth_credentials,
+            additional_config=AdditionalConfig(timeout=self.get_timeout()),
+        ) as weaviate_client:
+            yield weaviate_client
+
+
+class CloudWeaviateUploadStagerConfig(WeaviateUploadStagerConfig):
+    pass
+
+
+@dataclass
+class CloudWeaviateUploadStager(WeaviateUploadStager):
+    upload_stager_config: CloudWeaviateUploadStagerConfig = field(
+        default_factory=lambda: WeaviateUploadStagerConfig()
+    )
+
+
+class CloudWeaviateUploaderConfig(WeaviateUploaderConfig):
+    pass
+
+
+@dataclass
+class CloudWeaviateUploader(WeaviateUploader):
+    connection_config: CloudWeaviateConnectionConfig = field(
+        default_factory=lambda: CloudWeaviateConnectionConfig()
+    )
+    upload_config: CloudWeaviateUploaderConfig = field(
+        default_factory=lambda: CloudWeaviateUploaderConfig()
+    )
+    connector_type: str = CONNECTOR_TYPE
+
+
+weaviate_cloud_destination_entry = DestinationRegistryEntry(
+    connection_config=CloudWeaviateConnectionConfig,
+    uploader=CloudWeaviateUploader,
+    uploader_config=CloudWeaviateUploaderConfig,
+    upload_stager=CloudWeaviateUploadStager,
+    upload_stager_config=CloudWeaviateUploadStagerConfig,
+)
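
The cloud connection config accepts exactly one of four auth mechanisms (API key, bearer token, client secret, or username/password) and opens the client as a context manager around connect_to_weaviate_cloud. A hedged usage sketch with the weaviate v4 client and API-key auth; the cluster URL and key below are placeholders, and this is direct client usage rather than the connector itself:

import weaviate
from weaviate.classes.init import Auth

cluster_url = "rAnD0mD1g1t5.something.weaviate.cloud"      # placeholder WCD cluster URL
auth_credentials = Auth.api_key(api_key="WEAVIATE_API_KEY")  # placeholder key

with weaviate.connect_to_weaviate_cloud(
    cluster_url=cluster_url,
    auth_credentials=auth_credentials,
) as client:
    print(client.is_ready())  # confirm the connection before handing it to an uploader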