unstructured-ingest 0.3.8__py3-none-any.whl → 0.3.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of unstructured-ingest might be problematic.
- test/integration/chunkers/test_chunkers.py +0 -11
- test/integration/connectors/conftest.py +11 -1
- test/integration/connectors/databricks_tests/test_volumes_native.py +4 -3
- test/integration/connectors/duckdb/conftest.py +14 -0
- test/integration/connectors/duckdb/test_duckdb.py +51 -44
- test/integration/connectors/duckdb/test_motherduck.py +37 -48
- test/integration/connectors/elasticsearch/test_elasticsearch.py +26 -4
- test/integration/connectors/elasticsearch/test_opensearch.py +26 -3
- test/integration/connectors/sql/test_postgres.py +102 -91
- test/integration/connectors/sql/test_singlestore.py +111 -99
- test/integration/connectors/sql/test_snowflake.py +142 -117
- test/integration/connectors/sql/test_sqlite.py +86 -75
- test/integration/connectors/test_astradb.py +22 -1
- test/integration/connectors/test_azure_ai_search.py +25 -3
- test/integration/connectors/test_chroma.py +120 -0
- test/integration/connectors/test_confluence.py +4 -4
- test/integration/connectors/test_delta_table.py +1 -0
- test/integration/connectors/test_kafka.py +4 -4
- test/integration/connectors/test_milvus.py +21 -0
- test/integration/connectors/test_mongodb.py +3 -3
- test/integration/connectors/test_neo4j.py +236 -0
- test/integration/connectors/test_pinecone.py +25 -1
- test/integration/connectors/test_qdrant.py +25 -2
- test/integration/connectors/test_s3.py +9 -6
- test/integration/connectors/utils/docker.py +6 -0
- test/integration/connectors/utils/validation/__init__.py +0 -0
- test/integration/connectors/utils/validation/destination.py +88 -0
- test/integration/connectors/utils/validation/equality.py +75 -0
- test/integration/connectors/utils/{validation.py → validation/source.py} +15 -91
- test/integration/connectors/utils/validation/utils.py +36 -0
- unstructured_ingest/__version__.py +1 -1
- unstructured_ingest/utils/chunking.py +11 -0
- unstructured_ingest/utils/data_prep.py +36 -0
- unstructured_ingest/v2/interfaces/upload_stager.py +70 -6
- unstructured_ingest/v2/interfaces/uploader.py +11 -2
- unstructured_ingest/v2/pipeline/steps/stage.py +3 -1
- unstructured_ingest/v2/processes/connectors/astradb.py +8 -30
- unstructured_ingest/v2/processes/connectors/azure_ai_search.py +16 -40
- unstructured_ingest/v2/processes/connectors/chroma.py +36 -59
- unstructured_ingest/v2/processes/connectors/couchbase.py +42 -52
- unstructured_ingest/v2/processes/connectors/delta_table.py +11 -33
- unstructured_ingest/v2/processes/connectors/duckdb/base.py +26 -26
- unstructured_ingest/v2/processes/connectors/duckdb/duckdb.py +29 -20
- unstructured_ingest/v2/processes/connectors/duckdb/motherduck.py +37 -44
- unstructured_ingest/v2/processes/connectors/elasticsearch/elasticsearch.py +5 -30
- unstructured_ingest/v2/processes/connectors/gitlab.py +32 -31
- unstructured_ingest/v2/processes/connectors/google_drive.py +32 -29
- unstructured_ingest/v2/processes/connectors/kafka/kafka.py +2 -4
- unstructured_ingest/v2/processes/connectors/kdbai.py +44 -70
- unstructured_ingest/v2/processes/connectors/lancedb/lancedb.py +8 -10
- unstructured_ingest/v2/processes/connectors/local.py +13 -2
- unstructured_ingest/v2/processes/connectors/milvus.py +16 -57
- unstructured_ingest/v2/processes/connectors/mongodb.py +4 -8
- unstructured_ingest/v2/processes/connectors/neo4j.py +381 -0
- unstructured_ingest/v2/processes/connectors/pinecone.py +3 -33
- unstructured_ingest/v2/processes/connectors/qdrant/qdrant.py +32 -41
- unstructured_ingest/v2/processes/connectors/sql/sql.py +41 -40
- unstructured_ingest/v2/processes/connectors/weaviate/weaviate.py +9 -31
- {unstructured_ingest-0.3.8.dist-info → unstructured_ingest-0.3.9.dist-info}/METADATA +18 -14
- {unstructured_ingest-0.3.8.dist-info → unstructured_ingest-0.3.9.dist-info}/RECORD +64 -56
- {unstructured_ingest-0.3.8.dist-info → unstructured_ingest-0.3.9.dist-info}/LICENSE.md +0 -0
- {unstructured_ingest-0.3.8.dist-info → unstructured_ingest-0.3.9.dist-info}/WHEEL +0 -0
- {unstructured_ingest-0.3.8.dist-info → unstructured_ingest-0.3.9.dist-info}/entry_points.txt +0 -0
- {unstructured_ingest-0.3.8.dist-info → unstructured_ingest-0.3.9.dist-info}/top_level.txt +0 -0
unstructured_ingest/v2/processes/connectors/mongodb.py

@@ -1,9 +1,7 @@
-import json
 import sys
 from contextlib import contextmanager
 from dataclasses import dataclass, replace
 from datetime import datetime
-from pathlib import Path
 from time import time
 from typing import TYPE_CHECKING, Any, Generator, Optional
 
@@ -332,18 +330,16 @@ class MongoDBUploader(Uploader):
             f"deleted {delete_results.deleted_count} records from collection {collection.name}"
         )
 
-    def
-        with path.open("r") as file:
-            elements_dict = json.load(file)
+    def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None:
         logger.info(
-            f"writing {len(
+            f"writing {len(data)} objects to destination "
             f"db, {self.upload_config.database}, "
             f"collection {self.upload_config.collection} "
             f"at {self.connection_config.host}",
         )
         # This would typically live in the stager but since no other manipulation
         # is done, setting the record id field in the uploader
-        for element in
+        for element in data:
             element[self.upload_config.record_id_key] = file_data.identifier
         with self.connection_config.get_client() as client:
             db = client[self.upload_config.database]
@@ -352,7 +348,7 @@ class MongoDBUploader(Uploader):
                 self.delete_by_record_id(file_data=file_data, collection=collection)
             else:
                 logger.warning("criteria for deleting previous content not met, skipping")
-            for chunk in batch_generator(
+            for chunk in batch_generator(data, self.upload_config.batch_size):
                 collection.insert_many(chunk)
 
 
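The MongoDB uploader now receives already-parsed element dicts through run_data and writes them in batches produced by batch_generator from unstructured_ingest.utils.data_prep. For orientation only, a minimal stand-in for that batching helper could look like the sketch below; it is not the library's implementation, just an equivalent illustration of how the dicts are sliced before collection.insert_many is called.

# Illustrative stand-in for the batching helper the uploader calls; the real
# batch_generator lives in unstructured_ingest.utils.data_prep and may differ.
from typing import Iterable, Iterator


def batch_generator(items: Iterable[dict], batch_size: int = 100) -> Iterator[list[dict]]:
    batch: list[dict] = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch  # a full batch, e.g. passed to collection.insert_many(batch)
            batch = []
    if batch:
        yield batch  # the final partial batch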
unstructured_ingest/v2/processes/connectors/neo4j.py (new file)

@@ -0,0 +1,381 @@
+from __future__ import annotations
+
+import asyncio
+import json
+import uuid
+from collections import defaultdict
+from contextlib import asynccontextmanager
+from dataclasses import dataclass
+from enum import Enum
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional
+
+import networkx as nx
+from pydantic import BaseModel, ConfigDict, Field, Secret
+
+from unstructured_ingest.error import DestinationConnectionError
+from unstructured_ingest.logger import logger
+from unstructured_ingest.utils.chunking import elements_from_base64_gzipped_json
+from unstructured_ingest.utils.data_prep import batch_generator
+from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.v2.interfaces import (
+    AccessConfig,
+    ConnectionConfig,
+    FileData,
+    Uploader,
+    UploaderConfig,
+    UploadStager,
+    UploadStagerConfig,
+)
+from unstructured_ingest.v2.processes.connector_registry import (
+    DestinationRegistryEntry,
+)
+
+if TYPE_CHECKING:
+    from neo4j import AsyncDriver, Auth
+
+CONNECTOR_TYPE = "neo4j"
+
+
+class Neo4jAccessConfig(AccessConfig):
+    password: str
+
+
+class Neo4jConnectionConfig(ConnectionConfig):
+    access_config: Secret[Neo4jAccessConfig]
+    connector_type: str = Field(default=CONNECTOR_TYPE, init=False)
+    username: str
+    uri: str = Field(description="Neo4j Connection URI <scheme>://<host>:<port>")
+    database: str = Field(description="Name of the target database")
+
+    @requires_dependencies(["neo4j"], extras="neo4j")
+    @asynccontextmanager
+    async def get_client(self) -> AsyncGenerator["AsyncDriver", None]:
+        from neo4j import AsyncGraphDatabase
+
+        driver = AsyncGraphDatabase.driver(**self._get_driver_parameters())
+        logger.info(f"Created driver connecting to the database '{self.database}' at {self.uri}.")
+        try:
+            yield driver
+        finally:
+            await driver.close()
+            logger.info(
+                f"Closed driver connecting to the database '{self.database}' at {self.uri}."
+            )
+
+    def _get_driver_parameters(self) -> dict:
+        return {
+            "uri": self.uri,
+            "auth": self._get_auth(),
+            "database": self.database,
+        }
+
+    @requires_dependencies(["neo4j"], extras="neo4j")
+    def _get_auth(self) -> "Auth":
+        from neo4j import Auth
+
+        return Auth("basic", self.username, self.access_config.get_secret_value().password)
+
+
+class Neo4jUploadStagerConfig(UploadStagerConfig):
+    pass
+
+
+@dataclass
+class Neo4jUploadStager(UploadStager):
+    upload_stager_config: Neo4jUploadStagerConfig = Field(
+        default_factory=Neo4jUploadStagerConfig, validate_default=True
+    )
+
+    def run(  # type: ignore
+        self,
+        elements_filepath: Path,
+        file_data: FileData,
+        output_dir: Path,
+        output_filename: str,
+        **kwargs: Any,
+    ) -> Path:
+        with elements_filepath.open() as file:
+            elements = json.load(file)
+
+        nx_graph = self._create_lexical_graph(
+            elements, self._create_document_node(file_data=file_data)
+        )
+        output_filepath = Path(output_dir) / f"{output_filename}.json"
+        output_filepath.parent.mkdir(parents=True, exist_ok=True)
+
+        with open(output_filepath, "w") as file:
+            json.dump(_GraphData.from_nx(nx_graph).model_dump(), file, indent=4)
+
+        return output_filepath
+
+    def _create_lexical_graph(self, elements: list[dict], document_node: _Node) -> nx.Graph:
+        graph = nx.MultiDiGraph()
+        graph.add_node(document_node)
+
+        previous_node: Optional[_Node] = None
+        for element in elements:
+            element_node = self._create_element_node(element)
+            order_relationship = (
+                Relationship.NEXT_CHUNK if self._is_chunk(element) else Relationship.NEXT_ELEMENT
+            )
+            if previous_node:
+                graph.add_edge(element_node, previous_node, relationship=order_relationship)
+
+            previous_node = element_node
+            graph.add_edge(element_node, document_node, relationship=Relationship.PART_OF_DOCUMENT)
+
+            if self._is_chunk(element):
+                origin_element_nodes = [
+                    self._create_element_node(origin_element)
+                    for origin_element in self._get_origin_elements(element)
+                ]
+                graph.add_edges_from(
+                    [
+                        (origin_element_node, element_node)
+                        for origin_element_node in origin_element_nodes
+                    ],
+                    relationship=Relationship.PART_OF_CHUNK,
+                )
+                graph.add_edges_from(
+                    [
+                        (origin_element_node, document_node)
+                        for origin_element_node in origin_element_nodes
+                    ],
+                    relationship=Relationship.PART_OF_DOCUMENT,
+                )
+
+        return graph
+
+    # TODO(Filip Knefel): Ensure _is_chunk is as reliable as possible, consider different checks
+    def _is_chunk(self, element: dict) -> bool:
+        return "orig_elements" in element.get("metadata", {})
+
+    def _create_document_node(self, file_data: FileData) -> _Node:
+        properties = {}
+        if file_data.source_identifiers:
+            properties["name"] = file_data.source_identifiers.filename
+        if file_data.metadata.date_created:
+            properties["date_created"] = file_data.metadata.date_created
+        if file_data.metadata.date_modified:
+            properties["date_modified"] = file_data.metadata.date_modified
+        return _Node(id_=file_data.identifier, properties=properties, labels=[Label.DOCUMENT])
+
+    def _create_element_node(self, element: dict) -> _Node:
+        properties = {"id": element["element_id"], "text": element["text"]}
+
+        if embeddings := element.get("embeddings"):
+            properties["embeddings"] = embeddings
+
+        label = Label.CHUNK if self._is_chunk(element) else Label.UNSTRUCTURED_ELEMENT
+        return _Node(id_=element["element_id"], properties=properties, labels=[label])
+
+    def _get_origin_elements(self, chunk_element: dict) -> list[dict]:
+        orig_elements = chunk_element.get("metadata", {}).get("orig_elements")
+        return elements_from_base64_gzipped_json(raw_s=orig_elements)
+
+
+class _GraphData(BaseModel):
+    nodes: list[_Node]
+    edges: list[_Edge]
+
+    @classmethod
+    def from_nx(cls, nx_graph: nx.MultiDiGraph) -> _GraphData:
+        nodes = list(nx_graph.nodes())
+        edges = [
+            _Edge(
+                source_id=u.id_,
+                destination_id=v.id_,
+                relationship=Relationship(data_dict["relationship"]),
+            )
+            for u, v, data_dict in nx_graph.edges(data=True)
+        ]
+        return _GraphData(nodes=nodes, edges=edges)
+
+
+class _Node(BaseModel):
+    model_config = ConfigDict(use_enum_values=True)
+
+    id_: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    labels: list[Label] = Field(default_factory=list)
+    properties: dict = Field(default_factory=dict)
+
+    def __hash__(self):
+        return hash(self.id_)
+
+
+class _Edge(BaseModel):
+    model_config = ConfigDict(use_enum_values=True)
+
+    source_id: str
+    destination_id: str
+    relationship: Relationship
+
+
+class Label(str, Enum):
+    UNSTRUCTURED_ELEMENT = "UnstructuredElement"
+    CHUNK = "Chunk"
+    DOCUMENT = "Document"
+
+
+class Relationship(str, Enum):
+    PART_OF_DOCUMENT = "PART_OF_DOCUMENT"
+    PART_OF_CHUNK = "PART_OF_CHUNK"
+    NEXT_CHUNK = "NEXT_CHUNK"
+    NEXT_ELEMENT = "NEXT_ELEMENT"
+
+
+class Neo4jUploaderConfig(UploaderConfig):
+    batch_size: int = Field(
+        default=100, description="Maximal number of nodes/relationships created per transaction."
+    )
+
+
+@dataclass
+class Neo4jUploader(Uploader):
+    upload_config: Neo4jUploaderConfig
+    connection_config: Neo4jConnectionConfig
+    connector_type: str = CONNECTOR_TYPE
+
+    @DestinationConnectionError.wrap
+    def precheck(self) -> None:
+        async def verify_auth():
+            async with self.connection_config.get_client() as client:
+                await client.verify_connectivity()
+
+        asyncio.run(verify_auth())
+
+    def is_async(self):
+        return True
+
+    async def run_async(self, path: Path, file_data: FileData, **kwargs) -> None:  # type: ignore
+        with path.open() as file:
+            staged_data = json.load(file)
+
+        graph_data = _GraphData.model_validate(staged_data)
+        async with self.connection_config.get_client() as client:
+            await self._create_uniqueness_constraints(client)
+            await self._delete_old_data_if_exists(file_data, client=client)
+            await self._merge_graph(graph_data=graph_data, client=client)
+
+    async def _create_uniqueness_constraints(self, client: AsyncDriver) -> None:
+        for label in Label:
+            logger.info(
+                f"Adding id uniqueness constraint for nodes labeled '{label}'"
+                " if it does not already exist."
+            )
+            constraint_name = f"{label.lower()}_id"
+            await client.execute_query(
+                f"""
+                CREATE CONSTRAINT {constraint_name} IF NOT EXISTS
+                FOR (n: {label}) REQUIRE n.id IS UNIQUE
+                """
+            )
+
+    async def _delete_old_data_if_exists(self, file_data: FileData, client: AsyncDriver) -> None:
+        logger.info(f"Deleting old data for the record '{file_data.identifier}' (if present).")
+        _, summary, _ = await client.execute_query(
+            f"""
+            MATCH (n: {Label.DOCUMENT} {{id: $identifier}})
+            MATCH (n)--(m: {Label.CHUNK}|{Label.UNSTRUCTURED_ELEMENT})
+            DETACH DELETE m""",
+            identifier=file_data.identifier,
+        )
+        logger.info(
+            f"Deleted {summary.counters.nodes_deleted} nodes"
+            f" and {summary.counters.relationships_deleted} relationships."
+        )
+
+    async def _merge_graph(self, graph_data: _GraphData, client: AsyncDriver) -> None:
+        nodes_by_labels: defaultdict[tuple[Label, ...], list[_Node]] = defaultdict(list)
+        for node in graph_data.nodes:
+            nodes_by_labels[tuple(node.labels)].append(node)
+
+        logger.info(f"Merging {len(graph_data.nodes)} graph nodes.")
+        # NOTE: Processed in parallel as there's no overlap between accessed nodes
+        await self._execute_queries(
+            [
+                self._create_nodes_query(nodes_batch, labels)
+                for labels, nodes in nodes_by_labels.items()
+                for nodes_batch in batch_generator(nodes, batch_size=self.upload_config.batch_size)
+            ],
+            client=client,
+            in_parallel=True,
+        )
+        logger.info(f"Finished merging {len(graph_data.nodes)} graph nodes.")
+
+        edges_by_relationship: defaultdict[Relationship, list[_Edge]] = defaultdict(list)
+        for edge in graph_data.edges:
+            edges_by_relationship[edge.relationship].append(edge)
+
+        logger.info(f"Merging {len(graph_data.edges)} graph relationships (edges).")
+        # NOTE: Processed sequentially to avoid queries locking node access to one another
+        await self._execute_queries(
+            [
+                self._create_edges_query(edges_batch, relationship)
+                for relationship, edges in edges_by_relationship.items()
+                for edges_batch in batch_generator(edges, batch_size=self.upload_config.batch_size)
+            ],
+            client=client,
+        )
+        logger.info(f"Finished merging {len(graph_data.edges)} graph relationships (edges).")
+
+    @staticmethod
+    async def _execute_queries(
+        queries_with_parameters: list[tuple[str, dict]],
+        client: AsyncDriver,
+        in_parallel: bool = False,
+    ) -> None:
+        if in_parallel:
+            logger.info(f"Executing {len(queries_with_parameters)} queries in parallel.")
+            await asyncio.gather(
+                *[
+                    client.execute_query(query, parameters_=parameters)
+                    for query, parameters in queries_with_parameters
+                ]
+            )
+            logger.info("Finished executing parallel queries.")
+        else:
+            logger.info(f"Executing {len(queries_with_parameters)} queries sequentially.")
+            for i, (query, parameters) in enumerate(queries_with_parameters):
+                logger.info(f"Query #{i} started.")
+                await client.execute_query(query, parameters_=parameters)
+                logger.info(f"Query #{i} finished.")
+            logger.info(
+                f"Finished executing all ({len(queries_with_parameters)}) sequential queries."
+            )
+
+    @staticmethod
+    def _create_nodes_query(nodes: list[_Node], labels: tuple[Label, ...]) -> tuple[str, dict]:
+        labels_string = ", ".join(labels)
+        logger.info(f"Preparing MERGE query for {len(nodes)} nodes labeled '{labels_string}'.")
+        query_string = f"""
+            UNWIND $nodes AS node
+            MERGE (n: {labels_string} {{id: node.id}})
+            SET n += node.properties
+        """
+        parameters = {"nodes": [{"id": node.id_, "properties": node.properties} for node in nodes]}
+        return query_string, parameters
+
+    @staticmethod
+    def _create_edges_query(edges: list[_Edge], relationship: Relationship) -> tuple[str, dict]:
+        logger.info(f"Preparing MERGE query for {len(edges)} {relationship} relationships.")
+        query_string = f"""
+            UNWIND $edges AS edge
+            MATCH (u {{id: edge.source}})
+            MATCH (v {{id: edge.destination}})
+            MERGE (u)-[:{relationship}]->(v)
+        """
+        parameters = {
+            "edges": [
+                {"source": edge.source_id, "destination": edge.destination_id} for edge in edges
+            ]
+        }
+        return query_string, parameters
+
+
+neo4j_destination_entry = DestinationRegistryEntry(
+    connection_config=Neo4jConnectionConfig,
+    uploader=Neo4jUploader,
+    uploader_config=Neo4jUploaderConfig,
+)
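For orientation, the Neo4jUploadStager above serializes the lexical graph as a _GraphData JSON document, which Neo4jUploader later validates with model_validate and merges into the database. A hand-written example of that staged payload is shown below; the field names come from the _GraphData, _Node, and _Edge models in the diff, while all identifiers and property values are made up for illustration.

# Illustrative staged payload matching the _GraphData/_Node/_Edge models above.
# Every id and property value here is invented for the example.
staged_example = {
    "nodes": [
        {"id_": "file-123", "labels": ["Document"], "properties": {"name": "report.pdf"}},
        {"id_": "el-1", "labels": ["UnstructuredElement"], "properties": {"id": "el-1", "text": "Hello"}},
        {"id_": "chunk-1", "labels": ["Chunk"], "properties": {"id": "chunk-1", "text": "Hello world"}},
    ],
    "edges": [
        {"source_id": "el-1", "destination_id": "chunk-1", "relationship": "PART_OF_CHUNK"},
        {"source_id": "el-1", "destination_id": "file-123", "relationship": "PART_OF_DOCUMENT"},
        {"source_id": "chunk-1", "destination_id": "file-123", "relationship": "PART_OF_DOCUMENT"},
    ],
}

As the uploader code above shows, nodes are then MERGEd per label set in parallel batches, while relationships are merged per type sequentially to avoid lock contention.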
unstructured_ingest/v2/processes/connectors/pinecone.py

@@ -1,6 +1,5 @@
 import json
 from dataclasses import dataclass, field
-from pathlib import Path
 from typing import TYPE_CHECKING, Any, Optional
 
 from pydantic import Field, Secret
@@ -159,33 +158,6 @@ class PineconeUploadStager(UploadStager):
             "metadata": metadata,
         }
 
-    def run(
-        self,
-        file_data: FileData,
-        elements_filepath: Path,
-        output_dir: Path,
-        output_filename: str,
-        **kwargs: Any,
-    ) -> Path:
-        with open(elements_filepath) as elements_file:
-            elements_contents = json.load(elements_file)
-
-        conformed_elements = [
-            self.conform_dict(element_dict=element, file_data=file_data)
-            for element in elements_contents
-        ]
-
-        if Path(output_filename).suffix != ".json":
-            output_filename = f"{output_filename}.json"
-        else:
-            output_filename = f"{Path(output_filename).stem}.json"
-        output_path = Path(output_dir) / Path(f"{output_filename}")
-        output_path.parent.mkdir(parents=True, exist_ok=True)
-
-        with open(output_path, "w") as output_file:
-            json.dump(conformed_elements, output_file)
-        return output_path
-
 
 @dataclass
 class PineconeUploader(Uploader):
@@ -278,11 +250,9 @@ class PineconeUploader(Uploader):
             raise DestinationConnectionError(f"http error: {api_error}") from api_error
         logger.debug(f"results: {results}")
 
-    def
-        with path.open("r") as file:
-            elements_dict = json.load(file)
+    def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None:
         logger.info(
-            f"writing a total of {len(
+            f"writing a total of {len(data)} elements via"
             f" document batches to destination"
             f" index named {self.connection_config.index_name}"
         )
@@ -295,7 +265,7 @@ class PineconeUploader(Uploader):
             self.pod_delete_by_record_id(file_data=file_data)
         else:
             raise ValueError(f"unexpected spec type in index description: {index_description}")
-        self.upsert_batches_async(elements_dict=
+        self.upsert_batches_async(elements_dict=data)
 
 
 pinecone_destination_entry = DestinationRegistryEntry(
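The per-connector run() boilerplate deleted here (read the elements file, conform each element, write the staged JSON) appears to move into the shared base class, consistent with the unstructured_ingest/v2/interfaces/upload_stager.py +70 -6 entry above, so a stager only needs to implement conform_dict. The following is a minimal sketch of a custom stager under that assumption; MyStager and the record_id key are hypothetical.

# Sketch only: assumes the base UploadStager.run now handles file IO and calls
# conform_dict per element, mirroring the boilerplate removed from this connector.
from dataclasses import dataclass, field

from unstructured_ingest.v2.interfaces import FileData, UploadStager, UploadStagerConfig


@dataclass
class MyStager(UploadStager):  # hypothetical connector-specific stager
    upload_stager_config: UploadStagerConfig = field(default_factory=UploadStagerConfig)

    def conform_dict(self, element_dict: dict, file_data: FileData) -> dict:
        conformed = element_dict.copy()
        # Tag each element with the originating record so re-runs can replace it.
        conformed["record_id"] = file_data.identifier
        return conformed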
unstructured_ingest/v2/processes/connectors/qdrant/qdrant.py

@@ -1,10 +1,9 @@
 import asyncio
 import json
 from abc import ABC, abstractmethod
-from contextlib import asynccontextmanager
+from contextlib import asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
-from
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Generator, Optional
 
 from pydantic import Field, Secret
 
@@ -24,7 +23,7 @@ from unstructured_ingest.v2.logger import logger
 from unstructured_ingest.v2.utils import get_enhanced_element_id
 
 if TYPE_CHECKING:
-    from qdrant_client import AsyncQdrantClient
+    from qdrant_client import AsyncQdrantClient, QdrantClient
 
 
 class QdrantAccessConfig(AccessConfig, ABC):
@@ -42,8 +41,8 @@ class QdrantConnectionConfig(ConnectionConfig, ABC):
 
     @requires_dependencies(["qdrant_client"], extras="qdrant")
     @asynccontextmanager
-    async def
-        from qdrant_client
+    async def get_async_client(self) -> AsyncGenerator["AsyncQdrantClient", None]:
+        from qdrant_client import AsyncQdrantClient
 
         client_kwargs = self.get_client_kwargs()
         client = AsyncQdrantClient(**client_kwargs)
@@ -52,6 +51,18 @@ class QdrantConnectionConfig(ConnectionConfig, ABC):
         finally:
             await client.close()
 
+    @requires_dependencies(["qdrant_client"], extras="qdrant")
+    @contextmanager
+    def get_client(self) -> Generator["QdrantClient", None, None]:
+        from qdrant_client import QdrantClient
+
+        client_kwargs = self.get_client_kwargs()
+        client = QdrantClient(**client_kwargs)
+        try:
+            yield client
+        finally:
+            client.close()
+
 
 class QdrantUploadStagerConfig(UploadStagerConfig):
     pass
@@ -63,9 +74,9 @@ class QdrantUploadStager(UploadStager, ABC):
         default_factory=lambda: QdrantUploadStagerConfig()
     )
 
-
-    def conform_dict(data: dict, file_data: FileData) -> dict:
+    def conform_dict(self, element_dict: dict, file_data: FileData) -> dict:
         """Prepares dictionary in the format that Chroma requires"""
+        data = element_dict.copy()
         return {
             "id": get_enhanced_element_id(element_dict=data, file_data=file_data),
             "vector": data.pop("embeddings", {}),
@@ -80,26 +91,6 @@ class QdrantUploadStager(UploadStager, ABC):
             },
         }
 
-    def run(
-        self,
-        elements_filepath: Path,
-        file_data: FileData,
-        output_dir: Path,
-        output_filename: str,
-        **kwargs: Any,
-    ) -> Path:
-        with open(elements_filepath) as elements_file:
-            elements_contents = json.load(elements_file)
-
-        conformed_elements = [
-            self.conform_dict(data=element, file_data=file_data) for element in elements_contents
-        ]
-        output_path = Path(output_dir) / Path(f"{output_filename}.json")
-
-        with open(output_path, "w") as output_file:
-            json.dump(conformed_elements, output_file)
-        return output_path
-
 
 class QdrantUploaderConfig(UploaderConfig):
     collection_name: str = Field(description="Name of the collection.")
@@ -118,27 +109,27 @@ class QdrantUploader(Uploader, ABC):
 
     @DestinationConnectionError.wrap
     def precheck(self) -> None:
-
-
-
-
-
+        with self.connection_config.get_client() as client:
+            collections_response = client.get_collections()
+            collection_names = [c.name for c in collections_response.collections]
+            if self.upload_config.collection_name not in collection_names:
+                raise DestinationConnectionError(
+                    "collection '{}' not found: {}".format(
+                        self.upload_config.collection_name, ", ".join(collection_names)
+                    )
+                )
 
     def is_async(self):
         return True
 
-    async def
+    async def run_data_async(
         self,
-
+        data: list[dict],
         file_data: FileData,
         **kwargs: Any,
     ) -> None:
-        with path.open("r") as file:
-            elements: list[dict] = json.load(file)
-
-        logger.debug("Loaded %i elements from %s", len(elements), path)
 
-        batches = list(batch_generator(
+        batches = list(batch_generator(data, batch_size=self.upload_config.batch_size))
         logger.debug(
             "Elements split into %i batches of size %i.",
             len(batches),
@@ -156,7 +147,7 @@ class QdrantUploader(Uploader, ABC):
             len(points),
             self.upload_config.collection_name,
         )
-        async with self.connection_config.
+        async with self.connection_config.get_async_client() as async_client:
             await async_client.upsert(
                 self.upload_config.collection_name, points=points, wait=True
            )