unstructured-ingest 0.3.8-py3-none-any.whl → 0.3.10-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of unstructured-ingest has been flagged as possibly problematic.

Files changed (87)
  1. test/integration/chunkers/test_chunkers.py +0 -11
  2. test/integration/connectors/conftest.py +11 -1
  3. test/integration/connectors/databricks_tests/test_volumes_native.py +4 -3
  4. test/integration/connectors/duckdb/conftest.py +14 -0
  5. test/integration/connectors/duckdb/test_duckdb.py +51 -44
  6. test/integration/connectors/duckdb/test_motherduck.py +37 -48
  7. test/integration/connectors/elasticsearch/test_elasticsearch.py +26 -4
  8. test/integration/connectors/elasticsearch/test_opensearch.py +26 -3
  9. test/integration/connectors/sql/test_postgres.py +103 -92
  10. test/integration/connectors/sql/test_singlestore.py +112 -100
  11. test/integration/connectors/sql/test_snowflake.py +142 -117
  12. test/integration/connectors/sql/test_sqlite.py +87 -76
  13. test/integration/connectors/test_astradb.py +62 -1
  14. test/integration/connectors/test_azure_ai_search.py +25 -3
  15. test/integration/connectors/test_chroma.py +120 -0
  16. test/integration/connectors/test_confluence.py +4 -4
  17. test/integration/connectors/test_delta_table.py +1 -0
  18. test/integration/connectors/test_kafka.py +6 -6
  19. test/integration/connectors/test_milvus.py +21 -0
  20. test/integration/connectors/test_mongodb.py +7 -4
  21. test/integration/connectors/test_neo4j.py +236 -0
  22. test/integration/connectors/test_pinecone.py +25 -1
  23. test/integration/connectors/test_qdrant.py +25 -2
  24. test/integration/connectors/test_s3.py +9 -6
  25. test/integration/connectors/utils/docker.py +6 -0
  26. test/integration/connectors/utils/validation/__init__.py +0 -0
  27. test/integration/connectors/utils/validation/destination.py +88 -0
  28. test/integration/connectors/utils/validation/equality.py +75 -0
  29. test/integration/connectors/utils/{validation.py → validation/source.py} +42 -98
  30. test/integration/connectors/utils/validation/utils.py +36 -0
  31. unstructured_ingest/__version__.py +1 -1
  32. unstructured_ingest/utils/chunking.py +11 -0
  33. unstructured_ingest/utils/data_prep.py +36 -0
  34. unstructured_ingest/v2/interfaces/__init__.py +3 -1
  35. unstructured_ingest/v2/interfaces/file_data.py +58 -14
  36. unstructured_ingest/v2/interfaces/upload_stager.py +70 -6
  37. unstructured_ingest/v2/interfaces/uploader.py +11 -2
  38. unstructured_ingest/v2/pipeline/steps/chunk.py +2 -1
  39. unstructured_ingest/v2/pipeline/steps/download.py +5 -4
  40. unstructured_ingest/v2/pipeline/steps/embed.py +2 -1
  41. unstructured_ingest/v2/pipeline/steps/filter.py +2 -2
  42. unstructured_ingest/v2/pipeline/steps/index.py +4 -4
  43. unstructured_ingest/v2/pipeline/steps/partition.py +3 -2
  44. unstructured_ingest/v2/pipeline/steps/stage.py +5 -3
  45. unstructured_ingest/v2/pipeline/steps/uncompress.py +2 -2
  46. unstructured_ingest/v2/pipeline/steps/upload.py +3 -3
  47. unstructured_ingest/v2/processes/connectors/__init__.py +3 -0
  48. unstructured_ingest/v2/processes/connectors/astradb.py +43 -63
  49. unstructured_ingest/v2/processes/connectors/azure_ai_search.py +16 -40
  50. unstructured_ingest/v2/processes/connectors/chroma.py +36 -59
  51. unstructured_ingest/v2/processes/connectors/couchbase.py +92 -93
  52. unstructured_ingest/v2/processes/connectors/delta_table.py +11 -33
  53. unstructured_ingest/v2/processes/connectors/duckdb/base.py +26 -26
  54. unstructured_ingest/v2/processes/connectors/duckdb/duckdb.py +29 -20
  55. unstructured_ingest/v2/processes/connectors/duckdb/motherduck.py +37 -44
  56. unstructured_ingest/v2/processes/connectors/elasticsearch/elasticsearch.py +46 -75
  57. unstructured_ingest/v2/processes/connectors/fsspec/azure.py +12 -35
  58. unstructured_ingest/v2/processes/connectors/fsspec/box.py +12 -35
  59. unstructured_ingest/v2/processes/connectors/fsspec/dropbox.py +15 -42
  60. unstructured_ingest/v2/processes/connectors/fsspec/fsspec.py +33 -29
  61. unstructured_ingest/v2/processes/connectors/fsspec/gcs.py +12 -34
  62. unstructured_ingest/v2/processes/connectors/fsspec/s3.py +13 -37
  63. unstructured_ingest/v2/processes/connectors/fsspec/sftp.py +19 -33
  64. unstructured_ingest/v2/processes/connectors/gitlab.py +32 -31
  65. unstructured_ingest/v2/processes/connectors/google_drive.py +32 -29
  66. unstructured_ingest/v2/processes/connectors/kafka/kafka.py +2 -4
  67. unstructured_ingest/v2/processes/connectors/kdbai.py +44 -70
  68. unstructured_ingest/v2/processes/connectors/lancedb/lancedb.py +8 -10
  69. unstructured_ingest/v2/processes/connectors/local.py +13 -2
  70. unstructured_ingest/v2/processes/connectors/milvus.py +16 -57
  71. unstructured_ingest/v2/processes/connectors/mongodb.py +99 -108
  72. unstructured_ingest/v2/processes/connectors/neo4j.py +383 -0
  73. unstructured_ingest/v2/processes/connectors/onedrive.py +1 -1
  74. unstructured_ingest/v2/processes/connectors/pinecone.py +3 -33
  75. unstructured_ingest/v2/processes/connectors/qdrant/qdrant.py +32 -41
  76. unstructured_ingest/v2/processes/connectors/sql/postgres.py +5 -5
  77. unstructured_ingest/v2/processes/connectors/sql/singlestore.py +5 -5
  78. unstructured_ingest/v2/processes/connectors/sql/snowflake.py +5 -5
  79. unstructured_ingest/v2/processes/connectors/sql/sql.py +72 -66
  80. unstructured_ingest/v2/processes/connectors/sql/sqlite.py +5 -5
  81. unstructured_ingest/v2/processes/connectors/weaviate/weaviate.py +9 -31
  82. {unstructured_ingest-0.3.8.dist-info → unstructured_ingest-0.3.10.dist-info}/METADATA +20 -15
  83. {unstructured_ingest-0.3.8.dist-info → unstructured_ingest-0.3.10.dist-info}/RECORD +87 -79
  84. {unstructured_ingest-0.3.8.dist-info → unstructured_ingest-0.3.10.dist-info}/LICENSE.md +0 -0
  85. {unstructured_ingest-0.3.8.dist-info → unstructured_ingest-0.3.10.dist-info}/WHEEL +0 -0
  86. {unstructured_ingest-0.3.8.dist-info → unstructured_ingest-0.3.10.dist-info}/entry_points.txt +0 -0
  87. {unstructured_ingest-0.3.8.dist-info → unstructured_ingest-0.3.10.dist-info}/top_level.txt +0 -0
unstructured_ingest/v2/processes/connectors/mongodb.py
@@ -1,13 +1,10 @@
-import json
-import sys
 from contextlib import contextmanager
-from dataclasses import dataclass, replace
+from dataclasses import dataclass
 from datetime import datetime
-from pathlib import Path
 from time import time
 from typing import TYPE_CHECKING, Any, Generator, Optional

-from pydantic import Field, Secret
+from pydantic import BaseModel, Field, Secret

 from unstructured_ingest.__version__ import __version__ as unstructured_version
 from unstructured_ingest.error import DestinationConnectionError, SourceConnectionError
@@ -16,9 +13,12 @@ from unstructured_ingest.utils.dep_check import requires_dependencies
 from unstructured_ingest.v2.constants import RECORD_ID_LABEL
 from unstructured_ingest.v2.interfaces import (
     AccessConfig,
+    BatchFileData,
+    BatchItem,
     ConnectionConfig,
     Downloader,
     DownloaderConfig,
+    DownloadResponse,
     FileData,
     FileDataSourceMetadata,
     Indexer,
@@ -42,6 +42,15 @@ CONNECTOR_TYPE = "mongodb"
 SERVER_API_VERSION = "1"


+class MongoDBAdditionalMetadata(BaseModel):
+    database: str
+    collection: str
+
+
+class MongoDBBatchFileData(BatchFileData):
+    additional_metadata: MongoDBAdditionalMetadata
+
+
 class MongoDBAccessConfig(AccessConfig):
     uri: Optional[str] = Field(default=None, description="URI to user when connecting")

@@ -124,7 +133,7 @@ class MongoDBIndexer(Indexer):
             logger.error(f"Failed to validate connection: {e}", exc_info=True)
             raise SourceConnectionError(f"Failed to validate connection: {e}")

-    def run(self, **kwargs: Any) -> Generator[FileData, None, None]:
+    def run(self, **kwargs: Any) -> Generator[BatchFileData, None, None]:
         """Generates FileData objects for each document in the MongoDB collection."""
         with self.connection_config.get_client() as client:
             database = client[self.index_config.database]
@@ -132,12 +141,12 @@ class MongoDBIndexer(Indexer):

             # Get list of document IDs
             ids = collection.distinct("_id")
-            batch_size = self.index_config.batch_size if self.index_config else 100
+
+            ids = sorted(ids)
+            batch_size = self.index_config.batch_size

             for id_batch in batch_generator(ids, batch_size=batch_size):
                 # Make sure the hash is always a positive number to create identifier
-                batch_id = str(hash(frozenset(id_batch)) + sys.maxsize + 1)
-
                 metadata = FileDataSourceMetadata(
                     date_processed=str(time()),
                     record_locator={
@@ -146,14 +155,13 @@
                     },
                 )

-                file_data = FileData(
-                    identifier=batch_id,
-                    doc_type="batch",
+                file_data = MongoDBBatchFileData(
                     connector_type=self.connector_type,
                     metadata=metadata,
-                    additional_metadata={
-                        "ids": [str(doc_id) for doc_id in id_batch],
-                    },
+                    batch_items=[BatchItem(identifier=str(doc_id)) for doc_id in id_batch],
+                    additional_metadata=MongoDBAdditionalMetadata(
+                        collection=self.index_config.collection, database=self.index_config.database
+                    ),
                 )
                 yield file_data

@@ -164,26 +172,59 @@ class MongoDBDownloader(Downloader):
     connection_config: MongoDBConnectionConfig
     connector_type: str = CONNECTOR_TYPE

-    @requires_dependencies(["pymongo"], extras="mongodb")
-    def create_client(self) -> "MongoClient":
-        from pymongo import MongoClient
-        from pymongo.driver_info import DriverInfo
-        from pymongo.server_api import ServerApi
+    def generate_download_response(
+        self, doc: dict, file_data: MongoDBBatchFileData
+    ) -> DownloadResponse:
+        from bson.objectid import ObjectId

-        access_config = self.connection_config.access_config.get_secret_value()
+        doc_id = doc["_id"]
+        doc.pop("_id", None)

-        if access_config.uri:
-            return MongoClient(
-                access_config.uri,
-                server_api=ServerApi(version=SERVER_API_VERSION),
-                driver=DriverInfo(name="unstructured", version=unstructured_version),
-            )
-        else:
-            return MongoClient(
-                host=self.connection_config.host,
-                port=self.connection_config.port,
-                server_api=ServerApi(version=SERVER_API_VERSION),
-            )
+        # Extract date_created from the document or ObjectId
+        date_created = None
+        if "date_created" in doc:
+            # If the document has a 'date_created' field, use it
+            date_created = doc["date_created"]
+            if isinstance(date_created, datetime):
+                date_created = date_created.isoformat()
+            else:
+                # Convert to ISO format if it's a string
+                date_created = str(date_created)
+        elif isinstance(doc_id, ObjectId):
+            # Use the ObjectId's generation time
+            date_created = doc_id.generation_time.isoformat()
+
+        flattened_dict = flatten_dict(dictionary=doc)
+        concatenated_values = "\n".join(str(value) for value in flattened_dict.values())
+
+        # Create a FileData object for each document with source_identifiers
+        cast_file_data = FileData.cast(file_data=file_data)
+        cast_file_data.identifier = str(doc_id)
+        filename = f"{doc_id}.txt"
+        cast_file_data.source_identifiers = SourceIdentifiers(
+            filename=filename,
+            fullpath=filename,
+            rel_path=filename,
+        )
+
+        # Determine the download path
+        download_path = self.get_download_path(file_data=cast_file_data)
+        if download_path is None:
+            raise ValueError("Download path could not be determined")
+
+        download_path.parent.mkdir(parents=True, exist_ok=True)
+
+        # Write the concatenated values to the file
+        with open(download_path, "w", encoding="utf8") as f:
+            f.write(concatenated_values)
+
+        # Update metadata
+        cast_file_data.metadata.record_locator["document_id"] = str(doc_id)
+        cast_file_data.metadata.date_created = date_created
+
+        return super().generate_download_response(
+            file_data=cast_file_data, download_path=download_path
+        )

     @SourceConnectionError.wrap
     @requires_dependencies(["bson"], extras="mongodb")
@@ -192,82 +233,34 @@ class MongoDBDownloader(Downloader):
         from bson.errors import InvalidId
         from bson.objectid import ObjectId

-        client = self.create_client()
-        database = client[file_data.metadata.record_locator["database"]]
-        collection = database[file_data.metadata.record_locator["collection"]]
+        mongo_file_data = MongoDBBatchFileData.cast(file_data=file_data)

-        ids = file_data.additional_metadata.get("ids", [])
-        if not ids:
-            raise ValueError("No document IDs provided in additional_metadata")
+        with self.connection_config.get_client() as client:
+            database = client[mongo_file_data.additional_metadata.database]
+            collection = database[mongo_file_data.additional_metadata.collection]

-        object_ids = []
-        for doc_id in ids:
-            try:
-                object_ids.append(ObjectId(doc_id))
-            except InvalidId as e:
-                error_message = f"Invalid ObjectId for doc_id '{doc_id}': {str(e)}"
-                logger.error(error_message)
-                raise ValueError(error_message) from e
+            ids = [item.identifier for item in mongo_file_data.batch_items]

-        try:
-            docs = list(collection.find({"_id": {"$in": object_ids}}))
-        except Exception as e:
-            logger.error(f"Failed to fetch documents: {e}", exc_info=True)
-            raise e
+            object_ids = []
+            for doc_id in ids:
+                try:
+                    object_ids.append(ObjectId(doc_id))
+                except InvalidId as e:
+                    error_message = f"Invalid ObjectId for doc_id '{doc_id}': {str(e)}"
+                    logger.error(error_message)
+                    raise ValueError(error_message) from e
+
+            try:
+                docs = list(collection.find({"_id": {"$in": object_ids}}))
+            except Exception as e:
+                logger.error(f"Failed to fetch documents: {e}", exc_info=True)
+                raise e

         download_responses = []
         for doc in docs:
-            doc_id = doc["_id"]
-            doc.pop("_id", None)
-
-            # Extract date_created from the document or ObjectId
-            date_created = None
-            if "date_created" in doc:
-                # If the document has a 'date_created' field, use it
-                date_created = doc["date_created"]
-                if isinstance(date_created, datetime):
-                    date_created = date_created.isoformat()
-                else:
-                    # Convert to ISO format if it's a string
-                    date_created = str(date_created)
-            elif isinstance(doc_id, ObjectId):
-                # Use the ObjectId's generation time
-                date_created = doc_id.generation_time.isoformat()
-
-            flattened_dict = flatten_dict(dictionary=doc)
-            concatenated_values = "\n".join(str(value) for value in flattened_dict.values())
-
-            # Create a FileData object for each document with source_identifiers
-            individual_file_data = replace(file_data)
-            individual_file_data.identifier = str(doc_id)
-            individual_file_data.source_identifiers = SourceIdentifiers(
-                filename=str(doc_id),
-                fullpath=str(doc_id),
-                rel_path=str(doc_id),
-            )
-
-            # Determine the download path
-            download_path = self.get_download_path(individual_file_data)
-            if download_path is None:
-                raise ValueError("Download path could not be determined")
-
-            download_path.parent.mkdir(parents=True, exist_ok=True)
-            download_path = download_path.with_suffix(".txt")
-
-            # Write the concatenated values to the file
-            with open(download_path, "w", encoding="utf8") as f:
-                f.write(concatenated_values)
-
-            individual_file_data.local_download_path = str(download_path)
-
-            # Update metadata
-            individual_file_data.metadata.record_locator["document_id"] = str(doc_id)
-            individual_file_data.metadata.date_created = date_created
-
-            download_response = self.generate_download_response(
-                file_data=individual_file_data, download_path=download_path
+            download_responses.append(
+                self.generate_download_response(doc=doc, file_data=mongo_file_data)
             )
-            download_responses.append(download_response)

         return download_responses

@@ -332,18 +325,16 @@ class MongoDBUploader(Uploader):
             f"deleted {delete_results.deleted_count} records from collection {collection.name}"
         )

-    def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
-        with path.open("r") as file:
-            elements_dict = json.load(file)
+    def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None:
         logger.info(
-            f"writing {len(elements_dict)} objects to destination "
+            f"writing {len(data)} objects to destination "
             f"db, {self.upload_config.database}, "
             f"collection {self.upload_config.collection} "
            f"at {self.connection_config.host}",
         )
         # This would typically live in the stager but since no other manipulation
         # is done, setting the record id field in the uploader
-        for element in elements_dict:
+        for element in data:
             element[self.upload_config.record_id_key] = file_data.identifier
         with self.connection_config.get_client() as client:
             db = client[self.upload_config.database]
@@ -352,7 +343,7 @@ class MongoDBUploader(Uploader):
                 self.delete_by_record_id(file_data=file_data, collection=collection)
             else:
                 logger.warning("criteria for deleting previous content not met, skipping")
-            for chunk in batch_generator(elements_dict, self.upload_config.batch_size):
+            for chunk in batch_generator(data, self.upload_config.batch_size):
                 collection.insert_many(chunk)

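For context, the change above replaces the old hash-based batch identifier with the typed batch interfaces (BatchFileData, BatchItem) and moves the per-document download logic into generate_download_response(). The sketch below shows the shape of the batch FileData the new indexer yields; it is illustrative only, and the database, collection, and ObjectId values are placeholders rather than anything taken from the package.

    # Illustrative sketch only; placeholder values, not from the package.
    from unstructured_ingest.v2.interfaces import BatchItem, FileDataSourceMetadata
    from unstructured_ingest.v2.processes.connectors.mongodb import (
        MongoDBAdditionalMetadata,
        MongoDBBatchFileData,
    )

    file_data = MongoDBBatchFileData(
        connector_type="mongodb",
        metadata=FileDataSourceMetadata(
            date_processed="1732000000.0",
            record_locator={"database": "ingest-test", "collection": "sample-docs"},
        ),
        batch_items=[BatchItem(identifier="65f0c1d2e3a4b5c6d7e8f901")],
        additional_metadata=MongoDBAdditionalMetadata(
            database="ingest-test", collection="sample-docs"
        ),
    )

    # The downloader narrows the generic FileData back to this type with
    # MongoDBBatchFileData.cast(file_data=file_data) and fetches each batch item by _id.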
unstructured_ingest/v2/processes/connectors/neo4j.py (new file)
@@ -0,0 +1,383 @@
+from __future__ import annotations
+
+import asyncio
+import json
+import uuid
+from collections import defaultdict
+from contextlib import asynccontextmanager
+from dataclasses import dataclass
+from enum import Enum
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional
+
+from pydantic import BaseModel, ConfigDict, Field, Secret
+
+from unstructured_ingest.error import DestinationConnectionError
+from unstructured_ingest.logger import logger
+from unstructured_ingest.utils.chunking import elements_from_base64_gzipped_json
+from unstructured_ingest.utils.data_prep import batch_generator
+from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.v2.interfaces import (
+    AccessConfig,
+    ConnectionConfig,
+    FileData,
+    Uploader,
+    UploaderConfig,
+    UploadStager,
+    UploadStagerConfig,
+)
+from unstructured_ingest.v2.processes.connector_registry import (
+    DestinationRegistryEntry,
+)
+
+if TYPE_CHECKING:
+    from neo4j import AsyncDriver, Auth
+    from networkx import Graph, MultiDiGraph
+
+CONNECTOR_TYPE = "neo4j"
+
+
+class Neo4jAccessConfig(AccessConfig):
+    password: str
+
+
+class Neo4jConnectionConfig(ConnectionConfig):
+    access_config: Secret[Neo4jAccessConfig]
+    connector_type: str = Field(default=CONNECTOR_TYPE, init=False)
+    username: str
+    uri: str = Field(description="Neo4j Connection URI <scheme>://<host>:<port>")
+    database: str = Field(description="Name of the target database")
+
+    @requires_dependencies(["neo4j"], extras="neo4j")
+    @asynccontextmanager
+    async def get_client(self) -> AsyncGenerator["AsyncDriver", None]:
+        from neo4j import AsyncGraphDatabase
+
+        driver = AsyncGraphDatabase.driver(**self._get_driver_parameters())
+        logger.info(f"Created driver connecting to the database '{self.database}' at {self.uri}.")
+        try:
+            yield driver
+        finally:
+            await driver.close()
+            logger.info(
+                f"Closed driver connecting to the database '{self.database}' at {self.uri}."
+            )
+
+    def _get_driver_parameters(self) -> dict:
+        return {
+            "uri": self.uri,
+            "auth": self._get_auth(),
+            "database": self.database,
+        }
+
+    @requires_dependencies(["neo4j"], extras="neo4j")
+    def _get_auth(self) -> "Auth":
+        from neo4j import Auth
+
+        return Auth("basic", self.username, self.access_config.get_secret_value().password)
+
+
+class Neo4jUploadStagerConfig(UploadStagerConfig):
+    pass
+
+
+@dataclass
+class Neo4jUploadStager(UploadStager):
+    upload_stager_config: Neo4jUploadStagerConfig = Field(
+        default_factory=Neo4jUploadStagerConfig, validate_default=True
+    )
+
+    def run(  # type: ignore
+        self,
+        elements_filepath: Path,
+        file_data: FileData,
+        output_dir: Path,
+        output_filename: str,
+        **kwargs: Any,
+    ) -> Path:
+        with elements_filepath.open() as file:
+            elements = json.load(file)
+
+        nx_graph = self._create_lexical_graph(
+            elements, self._create_document_node(file_data=file_data)
+        )
+        output_filepath = Path(output_dir) / f"{output_filename}.json"
+        output_filepath.parent.mkdir(parents=True, exist_ok=True)
+
+        with open(output_filepath, "w") as file:
+            json.dump(_GraphData.from_nx(nx_graph).model_dump(), file, indent=4)
+
+        return output_filepath
+
+    def _create_lexical_graph(self, elements: list[dict], document_node: _Node) -> "Graph":
+        import networkx as nx
+
+        graph = nx.MultiDiGraph()
+        graph.add_node(document_node)
+
+        previous_node: Optional[_Node] = None
+        for element in elements:
+            element_node = self._create_element_node(element)
+            order_relationship = (
+                Relationship.NEXT_CHUNK if self._is_chunk(element) else Relationship.NEXT_ELEMENT
+            )
+            if previous_node:
+                graph.add_edge(element_node, previous_node, relationship=order_relationship)
+
+            previous_node = element_node
+            graph.add_edge(element_node, document_node, relationship=Relationship.PART_OF_DOCUMENT)
+
+            if self._is_chunk(element):
+                origin_element_nodes = [
+                    self._create_element_node(origin_element)
+                    for origin_element in self._get_origin_elements(element)
+                ]
+                graph.add_edges_from(
+                    [
+                        (origin_element_node, element_node)
+                        for origin_element_node in origin_element_nodes
+                    ],
+                    relationship=Relationship.PART_OF_CHUNK,
+                )
+                graph.add_edges_from(
+                    [
+                        (origin_element_node, document_node)
+                        for origin_element_node in origin_element_nodes
+                    ],
+                    relationship=Relationship.PART_OF_DOCUMENT,
+                )
+
+        return graph
+
+    # TODO(Filip Knefel): Ensure _is_chunk is as reliable as possible, consider different checks
+    def _is_chunk(self, element: dict) -> bool:
+        return "orig_elements" in element.get("metadata", {})
+
+    def _create_document_node(self, file_data: FileData) -> _Node:
+        properties = {}
+        if file_data.source_identifiers:
+            properties["name"] = file_data.source_identifiers.filename
+        if file_data.metadata.date_created:
+            properties["date_created"] = file_data.metadata.date_created
+        if file_data.metadata.date_modified:
+            properties["date_modified"] = file_data.metadata.date_modified
+        return _Node(id_=file_data.identifier, properties=properties, labels=[Label.DOCUMENT])
+
+    def _create_element_node(self, element: dict) -> _Node:
+        properties = {"id": element["element_id"], "text": element["text"]}
+
+        if embeddings := element.get("embeddings"):
+            properties["embeddings"] = embeddings
+
+        label = Label.CHUNK if self._is_chunk(element) else Label.UNSTRUCTURED_ELEMENT
+        return _Node(id_=element["element_id"], properties=properties, labels=[label])
+
+    def _get_origin_elements(self, chunk_element: dict) -> list[dict]:
+        orig_elements = chunk_element.get("metadata", {}).get("orig_elements")
+        return elements_from_base64_gzipped_json(raw_s=orig_elements)
+
+
+class _GraphData(BaseModel):
+    nodes: list[_Node]
+    edges: list[_Edge]
+
+    @classmethod
+    def from_nx(cls, nx_graph: "MultiDiGraph") -> _GraphData:
+        nodes = list(nx_graph.nodes())
+        edges = [
+            _Edge(
+                source_id=u.id_,
+                destination_id=v.id_,
+                relationship=Relationship(data_dict["relationship"]),
+            )
+            for u, v, data_dict in nx_graph.edges(data=True)
+        ]
+        return _GraphData(nodes=nodes, edges=edges)
+
+
+class _Node(BaseModel):
+    model_config = ConfigDict(use_enum_values=True)
+
+    id_: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    labels: list[Label] = Field(default_factory=list)
+    properties: dict = Field(default_factory=dict)
+
+    def __hash__(self):
+        return hash(self.id_)
+
+
+class _Edge(BaseModel):
+    model_config = ConfigDict(use_enum_values=True)
+
+    source_id: str
+    destination_id: str
+    relationship: Relationship
+
+
+class Label(str, Enum):
+    UNSTRUCTURED_ELEMENT = "UnstructuredElement"
+    CHUNK = "Chunk"
+    DOCUMENT = "Document"
+
+
+class Relationship(str, Enum):
+    PART_OF_DOCUMENT = "PART_OF_DOCUMENT"
+    PART_OF_CHUNK = "PART_OF_CHUNK"
+    NEXT_CHUNK = "NEXT_CHUNK"
+    NEXT_ELEMENT = "NEXT_ELEMENT"
+
+
+class Neo4jUploaderConfig(UploaderConfig):
+    batch_size: int = Field(
+        default=100, description="Maximal number of nodes/relationships created per transaction."
+    )
+
+
+@dataclass
+class Neo4jUploader(Uploader):
+    upload_config: Neo4jUploaderConfig
+    connection_config: Neo4jConnectionConfig
+    connector_type: str = CONNECTOR_TYPE
+
+    @DestinationConnectionError.wrap
+    def precheck(self) -> None:
+        async def verify_auth():
+            async with self.connection_config.get_client() as client:
+                await client.verify_connectivity()
+
+        asyncio.run(verify_auth())
+
+    def is_async(self):
+        return True
+
+    async def run_async(self, path: Path, file_data: FileData, **kwargs) -> None:  # type: ignore
+        with path.open() as file:
+            staged_data = json.load(file)
+
+        graph_data = _GraphData.model_validate(staged_data)
+        async with self.connection_config.get_client() as client:
+            await self._create_uniqueness_constraints(client)
+            await self._delete_old_data_if_exists(file_data, client=client)
+            await self._merge_graph(graph_data=graph_data, client=client)
+
+    async def _create_uniqueness_constraints(self, client: AsyncDriver) -> None:
+        for label in Label:
+            logger.info(
+                f"Adding id uniqueness constraint for nodes labeled '{label}'"
+                " if it does not already exist."
+            )
+            constraint_name = f"{label.lower()}_id"
+            await client.execute_query(
+                f"""
+                CREATE CONSTRAINT {constraint_name} IF NOT EXISTS
+                FOR (n: {label}) REQUIRE n.id IS UNIQUE
+                """
+            )
+
+    async def _delete_old_data_if_exists(self, file_data: FileData, client: AsyncDriver) -> None:
+        logger.info(f"Deleting old data for the record '{file_data.identifier}' (if present).")
+        _, summary, _ = await client.execute_query(
+            f"""
+            MATCH (n: {Label.DOCUMENT} {{id: $identifier}})
+            MATCH (n)--(m: {Label.CHUNK}|{Label.UNSTRUCTURED_ELEMENT})
+            DETACH DELETE m""",
+            identifier=file_data.identifier,
+        )
+        logger.info(
+            f"Deleted {summary.counters.nodes_deleted} nodes"
+            f" and {summary.counters.relationships_deleted} relationships."
+        )
+
+    async def _merge_graph(self, graph_data: _GraphData, client: AsyncDriver) -> None:
+        nodes_by_labels: defaultdict[tuple[Label, ...], list[_Node]] = defaultdict(list)
+        for node in graph_data.nodes:
+            nodes_by_labels[tuple(node.labels)].append(node)
+
+        logger.info(f"Merging {len(graph_data.nodes)} graph nodes.")
+        # NOTE: Processed in parallel as there's no overlap between accessed nodes
+        await self._execute_queries(
+            [
+                self._create_nodes_query(nodes_batch, labels)
+                for labels, nodes in nodes_by_labels.items()
+                for nodes_batch in batch_generator(nodes, batch_size=self.upload_config.batch_size)
+            ],
+            client=client,
+            in_parallel=True,
+        )
+        logger.info(f"Finished merging {len(graph_data.nodes)} graph nodes.")
+
+        edges_by_relationship: defaultdict[Relationship, list[_Edge]] = defaultdict(list)
+        for edge in graph_data.edges:
+            edges_by_relationship[edge.relationship].append(edge)
+
+        logger.info(f"Merging {len(graph_data.edges)} graph relationships (edges).")
+        # NOTE: Processed sequentially to avoid queries locking node access to one another
+        await self._execute_queries(
+            [
+                self._create_edges_query(edges_batch, relationship)
+                for relationship, edges in edges_by_relationship.items()
+                for edges_batch in batch_generator(edges, batch_size=self.upload_config.batch_size)
+            ],
+            client=client,
+        )
+        logger.info(f"Finished merging {len(graph_data.edges)} graph relationships (edges).")
+
+    @staticmethod
+    async def _execute_queries(
+        queries_with_parameters: list[tuple[str, dict]],
+        client: AsyncDriver,
+        in_parallel: bool = False,
+    ) -> None:
+        if in_parallel:
+            logger.info(f"Executing {len(queries_with_parameters)} queries in parallel.")
+            await asyncio.gather(
+                *[
+                    client.execute_query(query, parameters_=parameters)
+                    for query, parameters in queries_with_parameters
+                ]
+            )
+            logger.info("Finished executing parallel queries.")
+        else:
+            logger.info(f"Executing {len(queries_with_parameters)} queries sequentially.")
+            for i, (query, parameters) in enumerate(queries_with_parameters):
+                logger.info(f"Query #{i} started.")
+                await client.execute_query(query, parameters_=parameters)
+                logger.info(f"Query #{i} finished.")
+            logger.info(
+                f"Finished executing all ({len(queries_with_parameters)}) sequential queries."
+            )
+
+    @staticmethod
+    def _create_nodes_query(nodes: list[_Node], labels: tuple[Label, ...]) -> tuple[str, dict]:
+        labels_string = ", ".join(labels)
+        logger.info(f"Preparing MERGE query for {len(nodes)} nodes labeled '{labels_string}'.")
+        query_string = f"""
+            UNWIND $nodes AS node
+            MERGE (n: {labels_string} {{id: node.id}})
+            SET n += node.properties
+        """
+        parameters = {"nodes": [{"id": node.id_, "properties": node.properties} for node in nodes]}
+        return query_string, parameters
+
+    @staticmethod
+    def _create_edges_query(edges: list[_Edge], relationship: Relationship) -> tuple[str, dict]:
+        logger.info(f"Preparing MERGE query for {len(edges)} {relationship} relationships.")
+        query_string = f"""
+            UNWIND $edges AS edge
+            MATCH (u {{id: edge.source}})
+            MATCH (v {{id: edge.destination}})
+            MERGE (u)-[:{relationship}]->(v)
+        """
+        parameters = {
+            "edges": [
+                {"source": edge.source_id, "destination": edge.destination_id} for edge in edges
+            ]
+        }
+        return query_string, parameters
+
+
+neo4j_destination_entry = DestinationRegistryEntry(
+    connection_config=Neo4jConnectionConfig,
+    uploader=Neo4jUploader,
+    uploader_config=Neo4jUploaderConfig,
+)
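The new file above registers Neo4j as a destination: the stager converts partitioned elements into a lexical graph (Document, Chunk, and UnstructuredElement nodes linked by PART_OF and NEXT relationships) serialized as _GraphData JSON, and the uploader MERGEs that graph through the async driver. Below is a hedged sketch of how these classes might be wired together; the URI, credentials, and database name are placeholders and the snippet is not taken from the package documentation.

    # Illustrative wiring of the classes defined above; all values are placeholders.
    from unstructured_ingest.v2.processes.connectors.neo4j import (
        Neo4jAccessConfig,
        Neo4jConnectionConfig,
        Neo4jUploader,
        Neo4jUploaderConfig,
    )

    connection_config = Neo4jConnectionConfig(
        access_config=Neo4jAccessConfig(password="example-password"),  # placeholder
        username="neo4j",
        uri="neo4j://localhost:7687",
        database="neo4j",
    )
    uploader = Neo4jUploader(
        connection_config=connection_config,
        upload_config=Neo4jUploaderConfig(batch_size=100),
    )
    # precheck() opens the async driver and calls verify_connectivity();
    # run_async() then loads the staged _GraphData JSON and merges nodes before edges.
    uploader.precheck()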
unstructured_ingest/v2/processes/connectors/onedrive.py
@@ -202,7 +202,7 @@ class OnedriveDownloader(Downloader):
         if file_data.source_identifiers is None or not file_data.source_identifiers.fullpath:
             raise ValueError(
                 f"file data doesn't have enough information to get "
-                f"file content: {file_data.to_dict()}"
+                f"file content: {file_data.model_dump()}"
             )

         server_relative_path = file_data.source_identifiers.fullpath
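The OneDrive change swaps the dataclass-style to_dict() for Pydantic v2's model_dump(), in line with FileData being handled as a Pydantic model elsewhere in this release (for example FileData.cast(...) and _GraphData.model_validate(...) above). A minimal sketch of the call, using a stand-in model rather than the real FileData class:

    # Stand-in model for illustration; not the real FileData definition.
    from pydantic import BaseModel


    class ExampleFileData(BaseModel):
        identifier: str
        connector_type: str


    fd = ExampleFileData(identifier="abc123", connector_type="onedrive")
    print(fd.model_dump())  # {'identifier': 'abc123', 'connector_type': 'onedrive'}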