unstructured-ingest 0.3.8__py3-none-any.whl → 0.3.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- test/integration/chunkers/test_chunkers.py +0 -11
- test/integration/connectors/conftest.py +11 -1
- test/integration/connectors/databricks_tests/test_volumes_native.py +4 -3
- test/integration/connectors/duckdb/conftest.py +14 -0
- test/integration/connectors/duckdb/test_duckdb.py +51 -44
- test/integration/connectors/duckdb/test_motherduck.py +37 -48
- test/integration/connectors/elasticsearch/test_elasticsearch.py +26 -4
- test/integration/connectors/elasticsearch/test_opensearch.py +26 -3
- test/integration/connectors/sql/test_postgres.py +103 -92
- test/integration/connectors/sql/test_singlestore.py +112 -100
- test/integration/connectors/sql/test_snowflake.py +142 -117
- test/integration/connectors/sql/test_sqlite.py +87 -76
- test/integration/connectors/test_astradb.py +62 -1
- test/integration/connectors/test_azure_ai_search.py +25 -3
- test/integration/connectors/test_chroma.py +120 -0
- test/integration/connectors/test_confluence.py +4 -4
- test/integration/connectors/test_delta_table.py +1 -0
- test/integration/connectors/test_kafka.py +6 -6
- test/integration/connectors/test_milvus.py +21 -0
- test/integration/connectors/test_mongodb.py +7 -4
- test/integration/connectors/test_neo4j.py +236 -0
- test/integration/connectors/test_pinecone.py +25 -1
- test/integration/connectors/test_qdrant.py +25 -2
- test/integration/connectors/test_s3.py +9 -6
- test/integration/connectors/utils/docker.py +6 -0
- test/integration/connectors/utils/validation/__init__.py +0 -0
- test/integration/connectors/utils/validation/destination.py +88 -0
- test/integration/connectors/utils/validation/equality.py +75 -0
- test/integration/connectors/utils/{validation.py → validation/source.py} +42 -98
- test/integration/connectors/utils/validation/utils.py +36 -0
- unstructured_ingest/__version__.py +1 -1
- unstructured_ingest/utils/chunking.py +11 -0
- unstructured_ingest/utils/data_prep.py +36 -0
- unstructured_ingest/v2/interfaces/__init__.py +3 -1
- unstructured_ingest/v2/interfaces/file_data.py +58 -14
- unstructured_ingest/v2/interfaces/upload_stager.py +70 -6
- unstructured_ingest/v2/interfaces/uploader.py +11 -2
- unstructured_ingest/v2/pipeline/steps/chunk.py +2 -1
- unstructured_ingest/v2/pipeline/steps/download.py +5 -4
- unstructured_ingest/v2/pipeline/steps/embed.py +2 -1
- unstructured_ingest/v2/pipeline/steps/filter.py +2 -2
- unstructured_ingest/v2/pipeline/steps/index.py +4 -4
- unstructured_ingest/v2/pipeline/steps/partition.py +3 -2
- unstructured_ingest/v2/pipeline/steps/stage.py +5 -3
- unstructured_ingest/v2/pipeline/steps/uncompress.py +2 -2
- unstructured_ingest/v2/pipeline/steps/upload.py +3 -3
- unstructured_ingest/v2/processes/connectors/__init__.py +3 -0
- unstructured_ingest/v2/processes/connectors/astradb.py +43 -63
- unstructured_ingest/v2/processes/connectors/azure_ai_search.py +16 -40
- unstructured_ingest/v2/processes/connectors/chroma.py +36 -59
- unstructured_ingest/v2/processes/connectors/couchbase.py +92 -93
- unstructured_ingest/v2/processes/connectors/delta_table.py +11 -33
- unstructured_ingest/v2/processes/connectors/duckdb/base.py +26 -26
- unstructured_ingest/v2/processes/connectors/duckdb/duckdb.py +29 -20
- unstructured_ingest/v2/processes/connectors/duckdb/motherduck.py +37 -44
- unstructured_ingest/v2/processes/connectors/elasticsearch/elasticsearch.py +46 -75
- unstructured_ingest/v2/processes/connectors/fsspec/azure.py +12 -35
- unstructured_ingest/v2/processes/connectors/fsspec/box.py +12 -35
- unstructured_ingest/v2/processes/connectors/fsspec/dropbox.py +15 -42
- unstructured_ingest/v2/processes/connectors/fsspec/fsspec.py +33 -29
- unstructured_ingest/v2/processes/connectors/fsspec/gcs.py +12 -34
- unstructured_ingest/v2/processes/connectors/fsspec/s3.py +13 -37
- unstructured_ingest/v2/processes/connectors/fsspec/sftp.py +19 -33
- unstructured_ingest/v2/processes/connectors/gitlab.py +32 -31
- unstructured_ingest/v2/processes/connectors/google_drive.py +32 -29
- unstructured_ingest/v2/processes/connectors/kafka/kafka.py +2 -4
- unstructured_ingest/v2/processes/connectors/kdbai.py +44 -70
- unstructured_ingest/v2/processes/connectors/lancedb/lancedb.py +8 -10
- unstructured_ingest/v2/processes/connectors/local.py +13 -2
- unstructured_ingest/v2/processes/connectors/milvus.py +16 -57
- unstructured_ingest/v2/processes/connectors/mongodb.py +99 -108
- unstructured_ingest/v2/processes/connectors/neo4j.py +383 -0
- unstructured_ingest/v2/processes/connectors/onedrive.py +1 -1
- unstructured_ingest/v2/processes/connectors/pinecone.py +3 -33
- unstructured_ingest/v2/processes/connectors/qdrant/qdrant.py +32 -41
- unstructured_ingest/v2/processes/connectors/sql/postgres.py +5 -5
- unstructured_ingest/v2/processes/connectors/sql/singlestore.py +5 -5
- unstructured_ingest/v2/processes/connectors/sql/snowflake.py +5 -5
- unstructured_ingest/v2/processes/connectors/sql/sql.py +72 -66
- unstructured_ingest/v2/processes/connectors/sql/sqlite.py +5 -5
- unstructured_ingest/v2/processes/connectors/weaviate/weaviate.py +9 -31
- {unstructured_ingest-0.3.8.dist-info → unstructured_ingest-0.3.10.dist-info}/METADATA +20 -15
- {unstructured_ingest-0.3.8.dist-info → unstructured_ingest-0.3.10.dist-info}/RECORD +87 -79
- {unstructured_ingest-0.3.8.dist-info → unstructured_ingest-0.3.10.dist-info}/LICENSE.md +0 -0
- {unstructured_ingest-0.3.8.dist-info → unstructured_ingest-0.3.10.dist-info}/WHEEL +0 -0
- {unstructured_ingest-0.3.8.dist-info → unstructured_ingest-0.3.10.dist-info}/entry_points.txt +0 -0
- {unstructured_ingest-0.3.8.dist-info → unstructured_ingest-0.3.10.dist-info}/top_level.txt +0 -0
unstructured_ingest/v2/processes/connectors/mongodb.py

```diff
@@ -1,13 +1,10 @@
-import json
-import sys
 from contextlib import contextmanager
-from dataclasses import dataclass, replace
+from dataclasses import dataclass
 from datetime import datetime
-from pathlib import Path
 from time import time
 from typing import TYPE_CHECKING, Any, Generator, Optional
 
-from pydantic import Field, Secret
+from pydantic import BaseModel, Field, Secret
 
 from unstructured_ingest.__version__ import __version__ as unstructured_version
 from unstructured_ingest.error import DestinationConnectionError, SourceConnectionError
@@ -16,9 +13,12 @@ from unstructured_ingest.utils.dep_check import requires_dependencies
 from unstructured_ingest.v2.constants import RECORD_ID_LABEL
 from unstructured_ingest.v2.interfaces import (
     AccessConfig,
+    BatchFileData,
+    BatchItem,
     ConnectionConfig,
     Downloader,
     DownloaderConfig,
+    DownloadResponse,
     FileData,
     FileDataSourceMetadata,
     Indexer,
@@ -42,6 +42,15 @@ CONNECTOR_TYPE = "mongodb"
 SERVER_API_VERSION = "1"
 
 
+class MongoDBAdditionalMetadata(BaseModel):
+    database: str
+    collection: str
+
+
+class MongoDBBatchFileData(BatchFileData):
+    additional_metadata: MongoDBAdditionalMetadata
+
+
 class MongoDBAccessConfig(AccessConfig):
     uri: Optional[str] = Field(default=None, description="URI to user when connecting")
 
@@ -124,7 +133,7 @@ class MongoDBIndexer(Indexer):
             logger.error(f"Failed to validate connection: {e}", exc_info=True)
             raise SourceConnectionError(f"Failed to validate connection: {e}")
 
-    def run(self, **kwargs: Any) -> Generator[FileData, None, None]:
+    def run(self, **kwargs: Any) -> Generator[BatchFileData, None, None]:
         """Generates FileData objects for each document in the MongoDB collection."""
         with self.connection_config.get_client() as client:
             database = client[self.index_config.database]
@@ -132,12 +141,12 @@ class MongoDBIndexer(Indexer):
 
             # Get list of document IDs
             ids = collection.distinct("_id")
-
+
+            ids = sorted(ids)
+            batch_size = self.index_config.batch_size
 
             for id_batch in batch_generator(ids, batch_size=batch_size):
                 # Make sure the hash is always a positive number to create identifier
-                batch_id = str(hash(frozenset(id_batch)) + sys.maxsize + 1)
-
                 metadata = FileDataSourceMetadata(
                     date_processed=str(time()),
                     record_locator={
@@ -146,14 +155,13 @@ class MongoDBIndexer(Indexer):
                     },
                 )
 
-                file_data = FileData(
-                    identifier=batch_id,
-                    doc_type="batch",
+                file_data = MongoDBBatchFileData(
                     connector_type=self.connector_type,
                     metadata=metadata,
-
-
-
+                    batch_items=[BatchItem(identifier=str(doc_id)) for doc_id in id_batch],
+                    additional_metadata=MongoDBAdditionalMetadata(
+                        collection=self.index_config.collection, database=self.index_config.database
+                    ),
                 )
                 yield file_data
 
@@ -164,26 +172,59 @@ class MongoDBDownloader(Downloader):
     connection_config: MongoDBConnectionConfig
     connector_type: str = CONNECTOR_TYPE
 
-
-
-
-        from
-        from pymongo.server_api import ServerApi
+    def generate_download_response(
+        self, doc: dict, file_data: MongoDBBatchFileData
+    ) -> DownloadResponse:
+        from bson.objectid import ObjectId
 
-
+        doc_id = doc["_id"]
+        doc.pop("_id", None)
 
-
-
-
-
-
-        )
-
-
-
-
-
-
+        # Extract date_created from the document or ObjectId
+        date_created = None
+        if "date_created" in doc:
+            # If the document has a 'date_created' field, use it
+            date_created = doc["date_created"]
+            if isinstance(date_created, datetime):
+                date_created = date_created.isoformat()
+            else:
+                # Convert to ISO format if it's a string
+                date_created = str(date_created)
+        elif isinstance(doc_id, ObjectId):
+            # Use the ObjectId's generation time
+            date_created = doc_id.generation_time.isoformat()
+
+        flattened_dict = flatten_dict(dictionary=doc)
+        concatenated_values = "\n".join(str(value) for value in flattened_dict.values())
+
+        # Create a FileData object for each document with source_identifiers
+        cast_file_data = FileData.cast(file_data=file_data)
+        cast_file_data.identifier = str(doc_id)
+        filename = f"{doc_id}.txt"
+        cast_file_data.source_identifiers = SourceIdentifiers(
+            filename=filename,
+            fullpath=filename,
+            rel_path=filename,
+        )
+
+        # Determine the download path
+        download_path = self.get_download_path(file_data=cast_file_data)
+        if download_path is None:
+            raise ValueError("Download path could not be determined")
+
+        download_path.parent.mkdir(parents=True, exist_ok=True)
+
+        # Write the concatenated values to the file
+        with open(download_path, "w", encoding="utf8") as f:
+            f.write(concatenated_values)
+
+        # Update metadata
+        cast_file_data.metadata.record_locator["document_id"] = str(doc_id)
+        cast_file_data.metadata.date_created = date_created
+
+        return super().generate_download_response(
+            file_data=cast_file_data, download_path=download_path
+        )
 
     @SourceConnectionError.wrap
     @requires_dependencies(["bson"], extras="mongodb")
@@ -192,82 +233,34 @@ class MongoDBDownloader(Downloader):
         from bson.errors import InvalidId
         from bson.objectid import ObjectId
 
-
-            database = client[file_data.metadata.record_locator["database"]]
-            collection = database[file_data.metadata.record_locator["collection"]]
+        mongo_file_data = MongoDBBatchFileData.cast(file_data=file_data)
 
-
-
-
+        with self.connection_config.get_client() as client:
+            database = client[mongo_file_data.additional_metadata.database]
+            collection = database[mongo_file_data.additional_metadata.collection]
 
-
-            for doc_id in ids:
-                try:
-                    object_ids.append(ObjectId(doc_id))
-                except InvalidId as e:
-                    error_message = f"Invalid ObjectId for doc_id '{doc_id}': {str(e)}"
-                    logger.error(error_message)
-                    raise ValueError(error_message) from e
+            ids = [item.identifier for item in mongo_file_data.batch_items]
 
-
-
-
-
-
+            object_ids = []
+            for doc_id in ids:
+                try:
+                    object_ids.append(ObjectId(doc_id))
+                except InvalidId as e:
+                    error_message = f"Invalid ObjectId for doc_id '{doc_id}': {str(e)}"
+                    logger.error(error_message)
+                    raise ValueError(error_message) from e
+
+            try:
+                docs = list(collection.find({"_id": {"$in": object_ids}}))
+            except Exception as e:
+                logger.error(f"Failed to fetch documents: {e}", exc_info=True)
+                raise e
 
         download_responses = []
         for doc in docs:
-
-
-
-            # Extract date_created from the document or ObjectId
-            date_created = None
-            if "date_created" in doc:
-                # If the document has a 'date_created' field, use it
-                date_created = doc["date_created"]
-                if isinstance(date_created, datetime):
-                    date_created = date_created.isoformat()
-                else:
-                    # Convert to ISO format if it's a string
-                    date_created = str(date_created)
-            elif isinstance(doc_id, ObjectId):
-                # Use the ObjectId's generation time
-                date_created = doc_id.generation_time.isoformat()
-
-            flattened_dict = flatten_dict(dictionary=doc)
-            concatenated_values = "\n".join(str(value) for value in flattened_dict.values())
-
-            # Create a FileData object for each document with source_identifiers
-            individual_file_data = replace(file_data)
-            individual_file_data.identifier = str(doc_id)
-            individual_file_data.source_identifiers = SourceIdentifiers(
-                filename=str(doc_id),
-                fullpath=str(doc_id),
-                rel_path=str(doc_id),
-            )
-
-            # Determine the download path
-            download_path = self.get_download_path(individual_file_data)
-            if download_path is None:
-                raise ValueError("Download path could not be determined")
-
-            download_path.parent.mkdir(parents=True, exist_ok=True)
-            download_path = download_path.with_suffix(".txt")
-
-            # Write the concatenated values to the file
-            with open(download_path, "w", encoding="utf8") as f:
-                f.write(concatenated_values)
-
-            individual_file_data.local_download_path = str(download_path)
-
-            # Update metadata
-            individual_file_data.metadata.record_locator["document_id"] = str(doc_id)
-            individual_file_data.metadata.date_created = date_created
-
-            download_response = self.generate_download_response(
-                file_data=individual_file_data, download_path=download_path
+            download_responses.append(
+                self.generate_download_response(doc=doc, file_data=mongo_file_data)
             )
-            download_responses.append(download_response)
 
         return download_responses
 
@@ -332,18 +325,16 @@ class MongoDBUploader(Uploader):
             f"deleted {delete_results.deleted_count} records from collection {collection.name}"
         )
 
-    def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
-        with path.open("r") as file:
-            elements_dict = json.load(file)
+    def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None:
         logger.info(
-            f"writing {len(elements_dict)} objects to destination "
+            f"writing {len(data)} objects to destination "
            f"db, {self.upload_config.database}, "
            f"collection {self.upload_config.collection} "
            f"at {self.connection_config.host}",
         )
         # This would typically live in the stager but since no other manipulation
         # is done, setting the record id field in the uploader
-        for element in elements_dict:
+        for element in data:
             element[self.upload_config.record_id_key] = file_data.identifier
         with self.connection_config.get_client() as client:
             db = client[self.upload_config.database]
@@ -352,7 +343,7 @@ class MongoDBUploader(Uploader):
                 self.delete_by_record_id(file_data=file_data, collection=collection)
             else:
                 logger.warning("criteria for deleting previous content not met, skipping")
-            for chunk in batch_generator(elements_dict, self.upload_config.batch_size):
+            for chunk in batch_generator(data, self.upload_config.batch_size):
                 collection.insert_many(chunk)
 
 
```
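The MongoDB changes above replace the old hash-based batch identifier with an explicit batch payload: the indexer yields a `MongoDBBatchFileData` whose `batch_items` carry the document `_id`s and whose `additional_metadata` carries the database and collection, and the downloader reads those fields back instead of parsing `metadata.record_locator`. A standalone sketch of that hand-off, using simplified stand-in models rather than the real `unstructured_ingest` classes (which carry additional fields such as `metadata` and `connector_type`):

```python
# Simplified stand-ins for the classes shown in the diff above, not the library's own models.
from pydantic import BaseModel


class BatchItem(BaseModel):
    identifier: str


class MongoDBAdditionalMetadata(BaseModel):
    database: str
    collection: str


class MongoDBBatchFileData(BaseModel):
    batch_items: list[BatchItem]
    additional_metadata: MongoDBAdditionalMetadata


# Indexer side: one file-data object per batch of sorted _ids.
batch = MongoDBBatchFileData(
    batch_items=[BatchItem(identifier="6650f1a2e13f4a0001c0ffee")],
    additional_metadata=MongoDBAdditionalMetadata(database="ingest_db", collection="docs"),
)

# Downloader side: recover the ids and the target collection from the batch itself.
ids = [item.identifier for item in batch.batch_items]
print(batch.additional_metadata.database, batch.additional_metadata.collection, ids)
```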
unstructured_ingest/v2/processes/connectors/neo4j.py (new file)

```diff
@@ -0,0 +1,383 @@
+from __future__ import annotations
+
+import asyncio
+import json
+import uuid
+from collections import defaultdict
+from contextlib import asynccontextmanager
+from dataclasses import dataclass
+from enum import Enum
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional
+
+from pydantic import BaseModel, ConfigDict, Field, Secret
+
+from unstructured_ingest.error import DestinationConnectionError
+from unstructured_ingest.logger import logger
+from unstructured_ingest.utils.chunking import elements_from_base64_gzipped_json
+from unstructured_ingest.utils.data_prep import batch_generator
+from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.v2.interfaces import (
+    AccessConfig,
+    ConnectionConfig,
+    FileData,
+    Uploader,
+    UploaderConfig,
+    UploadStager,
+    UploadStagerConfig,
+)
+from unstructured_ingest.v2.processes.connector_registry import (
+    DestinationRegistryEntry,
+)
+
+if TYPE_CHECKING:
+    from neo4j import AsyncDriver, Auth
+    from networkx import Graph, MultiDiGraph
+
+CONNECTOR_TYPE = "neo4j"
+
+
+class Neo4jAccessConfig(AccessConfig):
+    password: str
+
+
+class Neo4jConnectionConfig(ConnectionConfig):
+    access_config: Secret[Neo4jAccessConfig]
+    connector_type: str = Field(default=CONNECTOR_TYPE, init=False)
+    username: str
+    uri: str = Field(description="Neo4j Connection URI <scheme>://<host>:<port>")
+    database: str = Field(description="Name of the target database")
+
+    @requires_dependencies(["neo4j"], extras="neo4j")
+    @asynccontextmanager
+    async def get_client(self) -> AsyncGenerator["AsyncDriver", None]:
+        from neo4j import AsyncGraphDatabase
+
+        driver = AsyncGraphDatabase.driver(**self._get_driver_parameters())
+        logger.info(f"Created driver connecting to the database '{self.database}' at {self.uri}.")
+        try:
+            yield driver
+        finally:
+            await driver.close()
+            logger.info(
+                f"Closed driver connecting to the database '{self.database}' at {self.uri}."
+            )
+
+    def _get_driver_parameters(self) -> dict:
+        return {
+            "uri": self.uri,
+            "auth": self._get_auth(),
+            "database": self.database,
+        }
+
+    @requires_dependencies(["neo4j"], extras="neo4j")
+    def _get_auth(self) -> "Auth":
+        from neo4j import Auth
+
+        return Auth("basic", self.username, self.access_config.get_secret_value().password)
+
+
+class Neo4jUploadStagerConfig(UploadStagerConfig):
+    pass
+
+
+@dataclass
+class Neo4jUploadStager(UploadStager):
+    upload_stager_config: Neo4jUploadStagerConfig = Field(
+        default_factory=Neo4jUploadStagerConfig, validate_default=True
+    )
+
+    def run(  # type: ignore
+        self,
+        elements_filepath: Path,
+        file_data: FileData,
+        output_dir: Path,
+        output_filename: str,
+        **kwargs: Any,
+    ) -> Path:
+        with elements_filepath.open() as file:
+            elements = json.load(file)
+
+        nx_graph = self._create_lexical_graph(
+            elements, self._create_document_node(file_data=file_data)
+        )
+        output_filepath = Path(output_dir) / f"{output_filename}.json"
+        output_filepath.parent.mkdir(parents=True, exist_ok=True)
+
+        with open(output_filepath, "w") as file:
+            json.dump(_GraphData.from_nx(nx_graph).model_dump(), file, indent=4)
+
+        return output_filepath
+
+    def _create_lexical_graph(self, elements: list[dict], document_node: _Node) -> "Graph":
+        import networkx as nx
+
+        graph = nx.MultiDiGraph()
+        graph.add_node(document_node)
+
+        previous_node: Optional[_Node] = None
+        for element in elements:
+            element_node = self._create_element_node(element)
+            order_relationship = (
+                Relationship.NEXT_CHUNK if self._is_chunk(element) else Relationship.NEXT_ELEMENT
+            )
+            if previous_node:
+                graph.add_edge(element_node, previous_node, relationship=order_relationship)
+
+            previous_node = element_node
+            graph.add_edge(element_node, document_node, relationship=Relationship.PART_OF_DOCUMENT)
+
+            if self._is_chunk(element):
+                origin_element_nodes = [
+                    self._create_element_node(origin_element)
+                    for origin_element in self._get_origin_elements(element)
+                ]
+                graph.add_edges_from(
+                    [
+                        (origin_element_node, element_node)
+                        for origin_element_node in origin_element_nodes
+                    ],
+                    relationship=Relationship.PART_OF_CHUNK,
+                )
+                graph.add_edges_from(
+                    [
+                        (origin_element_node, document_node)
+                        for origin_element_node in origin_element_nodes
+                    ],
+                    relationship=Relationship.PART_OF_DOCUMENT,
+                )
+
+        return graph
+
+    # TODO(Filip Knefel): Ensure _is_chunk is as reliable as possible, consider different checks
+    def _is_chunk(self, element: dict) -> bool:
+        return "orig_elements" in element.get("metadata", {})
+
+    def _create_document_node(self, file_data: FileData) -> _Node:
+        properties = {}
+        if file_data.source_identifiers:
+            properties["name"] = file_data.source_identifiers.filename
+        if file_data.metadata.date_created:
+            properties["date_created"] = file_data.metadata.date_created
+        if file_data.metadata.date_modified:
+            properties["date_modified"] = file_data.metadata.date_modified
+        return _Node(id_=file_data.identifier, properties=properties, labels=[Label.DOCUMENT])
+
+    def _create_element_node(self, element: dict) -> _Node:
+        properties = {"id": element["element_id"], "text": element["text"]}
+
+        if embeddings := element.get("embeddings"):
+            properties["embeddings"] = embeddings
+
+        label = Label.CHUNK if self._is_chunk(element) else Label.UNSTRUCTURED_ELEMENT
+        return _Node(id_=element["element_id"], properties=properties, labels=[label])
+
+    def _get_origin_elements(self, chunk_element: dict) -> list[dict]:
+        orig_elements = chunk_element.get("metadata", {}).get("orig_elements")
+        return elements_from_base64_gzipped_json(raw_s=orig_elements)
+
+
+class _GraphData(BaseModel):
+    nodes: list[_Node]
+    edges: list[_Edge]
+
+    @classmethod
+    def from_nx(cls, nx_graph: "MultiDiGraph") -> _GraphData:
+        nodes = list(nx_graph.nodes())
+        edges = [
+            _Edge(
+                source_id=u.id_,
+                destination_id=v.id_,
+                relationship=Relationship(data_dict["relationship"]),
+            )
+            for u, v, data_dict in nx_graph.edges(data=True)
+        ]
+        return _GraphData(nodes=nodes, edges=edges)
+
+
+class _Node(BaseModel):
+    model_config = ConfigDict(use_enum_values=True)
+
+    id_: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    labels: list[Label] = Field(default_factory=list)
+    properties: dict = Field(default_factory=dict)
+
+    def __hash__(self):
+        return hash(self.id_)
+
+
+class _Edge(BaseModel):
+    model_config = ConfigDict(use_enum_values=True)
+
+    source_id: str
+    destination_id: str
+    relationship: Relationship
+
+
+class Label(str, Enum):
+    UNSTRUCTURED_ELEMENT = "UnstructuredElement"
+    CHUNK = "Chunk"
+    DOCUMENT = "Document"
+
+
+class Relationship(str, Enum):
+    PART_OF_DOCUMENT = "PART_OF_DOCUMENT"
+    PART_OF_CHUNK = "PART_OF_CHUNK"
+    NEXT_CHUNK = "NEXT_CHUNK"
+    NEXT_ELEMENT = "NEXT_ELEMENT"
+
+
+class Neo4jUploaderConfig(UploaderConfig):
+    batch_size: int = Field(
+        default=100, description="Maximal number of nodes/relationships created per transaction."
+    )
+
+
+@dataclass
+class Neo4jUploader(Uploader):
+    upload_config: Neo4jUploaderConfig
+    connection_config: Neo4jConnectionConfig
+    connector_type: str = CONNECTOR_TYPE
+
+    @DestinationConnectionError.wrap
+    def precheck(self) -> None:
+        async def verify_auth():
+            async with self.connection_config.get_client() as client:
+                await client.verify_connectivity()
+
+        asyncio.run(verify_auth())
+
+    def is_async(self):
+        return True
+
+    async def run_async(self, path: Path, file_data: FileData, **kwargs) -> None:  # type: ignore
+        with path.open() as file:
+            staged_data = json.load(file)
+
+        graph_data = _GraphData.model_validate(staged_data)
+        async with self.connection_config.get_client() as client:
+            await self._create_uniqueness_constraints(client)
+            await self._delete_old_data_if_exists(file_data, client=client)
+            await self._merge_graph(graph_data=graph_data, client=client)
+
+    async def _create_uniqueness_constraints(self, client: AsyncDriver) -> None:
+        for label in Label:
+            logger.info(
+                f"Adding id uniqueness constraint for nodes labeled '{label}'"
+                " if it does not already exist."
+            )
+            constraint_name = f"{label.lower()}_id"
+            await client.execute_query(
+                f"""
+                CREATE CONSTRAINT {constraint_name} IF NOT EXISTS
+                FOR (n: {label}) REQUIRE n.id IS UNIQUE
+                """
+            )
+
+    async def _delete_old_data_if_exists(self, file_data: FileData, client: AsyncDriver) -> None:
+        logger.info(f"Deleting old data for the record '{file_data.identifier}' (if present).")
+        _, summary, _ = await client.execute_query(
+            f"""
+            MATCH (n: {Label.DOCUMENT} {{id: $identifier}})
+            MATCH (n)--(m: {Label.CHUNK}|{Label.UNSTRUCTURED_ELEMENT})
+            DETACH DELETE m""",
+            identifier=file_data.identifier,
+        )
+        logger.info(
+            f"Deleted {summary.counters.nodes_deleted} nodes"
+            f" and {summary.counters.relationships_deleted} relationships."
+        )
+
+    async def _merge_graph(self, graph_data: _GraphData, client: AsyncDriver) -> None:
+        nodes_by_labels: defaultdict[tuple[Label, ...], list[_Node]] = defaultdict(list)
+        for node in graph_data.nodes:
+            nodes_by_labels[tuple(node.labels)].append(node)
+
+        logger.info(f"Merging {len(graph_data.nodes)} graph nodes.")
+        # NOTE: Processed in parallel as there's no overlap between accessed nodes
+        await self._execute_queries(
+            [
+                self._create_nodes_query(nodes_batch, labels)
+                for labels, nodes in nodes_by_labels.items()
+                for nodes_batch in batch_generator(nodes, batch_size=self.upload_config.batch_size)
+            ],
+            client=client,
+            in_parallel=True,
+        )
+        logger.info(f"Finished merging {len(graph_data.nodes)} graph nodes.")
+
+        edges_by_relationship: defaultdict[Relationship, list[_Edge]] = defaultdict(list)
+        for edge in graph_data.edges:
+            edges_by_relationship[edge.relationship].append(edge)
+
+        logger.info(f"Merging {len(graph_data.edges)} graph relationships (edges).")
+        # NOTE: Processed sequentially to avoid queries locking node access to one another
+        await self._execute_queries(
+            [
+                self._create_edges_query(edges_batch, relationship)
+                for relationship, edges in edges_by_relationship.items()
+                for edges_batch in batch_generator(edges, batch_size=self.upload_config.batch_size)
+            ],
+            client=client,
+        )
+        logger.info(f"Finished merging {len(graph_data.edges)} graph relationships (edges).")
+
+    @staticmethod
+    async def _execute_queries(
+        queries_with_parameters: list[tuple[str, dict]],
+        client: AsyncDriver,
+        in_parallel: bool = False,
+    ) -> None:
+        if in_parallel:
+            logger.info(f"Executing {len(queries_with_parameters)} queries in parallel.")
+            await asyncio.gather(
+                *[
+                    client.execute_query(query, parameters_=parameters)
+                    for query, parameters in queries_with_parameters
+                ]
+            )
+            logger.info("Finished executing parallel queries.")
+        else:
+            logger.info(f"Executing {len(queries_with_parameters)} queries sequentially.")
+            for i, (query, parameters) in enumerate(queries_with_parameters):
+                logger.info(f"Query #{i} started.")
+                await client.execute_query(query, parameters_=parameters)
+                logger.info(f"Query #{i} finished.")
+            logger.info(
+                f"Finished executing all ({len(queries_with_parameters)}) sequential queries."
+            )
+
+    @staticmethod
+    def _create_nodes_query(nodes: list[_Node], labels: tuple[Label, ...]) -> tuple[str, dict]:
+        labels_string = ", ".join(labels)
+        logger.info(f"Preparing MERGE query for {len(nodes)} nodes labeled '{labels_string}'.")
+        query_string = f"""
+            UNWIND $nodes AS node
+            MERGE (n: {labels_string} {{id: node.id}})
+            SET n += node.properties
+        """
+        parameters = {"nodes": [{"id": node.id_, "properties": node.properties} for node in nodes]}
+        return query_string, parameters
+
+    @staticmethod
+    def _create_edges_query(edges: list[_Edge], relationship: Relationship) -> tuple[str, dict]:
+        logger.info(f"Preparing MERGE query for {len(edges)} {relationship} relationships.")
+        query_string = f"""
+            UNWIND $edges AS edge
+            MATCH (u {{id: edge.source}})
+            MATCH (v {{id: edge.destination}})
+            MERGE (u)-[:{relationship}]->(v)
+        """
+        parameters = {
+            "edges": [
+                {"source": edge.source_id, "destination": edge.destination_id} for edge in edges
+            ]
+        }
+        return query_string, parameters
+
+
+neo4j_destination_entry = DestinationRegistryEntry(
+    connection_config=Neo4jConnectionConfig,
+    uploader=Neo4jUploader,
+    uploader_config=Neo4jUploaderConfig,
+)
```
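For orientation, the new stager serializes the lexical graph to JSON via `_GraphData.model_dump()`, and the uploader reads it back with `_GraphData.model_validate()` before issuing batched `MERGE` queries (nodes merged in parallel, relationships sequentially, in batches of `Neo4jUploaderConfig.batch_size`). A sketch of the staged payload shape implied by the `_Node` and `_Edge` models above; identifiers and property values here are illustrative placeholders:

```python
# Illustrative staged-file contents. Key names follow the _GraphData/_Node/_Edge models in the
# diff (labels and relationships dump as plain strings because of use_enum_values=True);
# the identifiers and properties are made up.
staged_graph = {
    "nodes": [
        {"id_": "file-123", "labels": ["Document"], "properties": {"name": "report.pdf"}},
        {
            "id_": "elem-1",
            "labels": ["UnstructuredElement"],
            "properties": {"id": "elem-1", "text": "Intro paragraph"},
        },
    ],
    "edges": [
        {"source_id": "elem-1", "destination_id": "file-123", "relationship": "PART_OF_DOCUMENT"},
    ],
}
```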
unstructured_ingest/v2/processes/connectors/onedrive.py

```diff
@@ -202,7 +202,7 @@ class OnedriveDownloader(Downloader):
         if file_data.source_identifiers is None or not file_data.source_identifiers.fullpath:
             raise ValueError(
                 f"file data doesn't have enough information to get "
-                f"file content: {file_data.
+                f"file content: {file_data.model_dump()}"
             )
 
         server_relative_path = file_data.source_identifiers.fullpath
```