unstructured-ingest 0.5.2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of unstructured-ingest might be problematic.

@@ -1,24 +1,19 @@
  from dataclasses import dataclass
- from typing import TYPE_CHECKING
+ from typing import TYPE_CHECKING, Any

  from pydantic import Field, SecretStr

  from unstructured_ingest.embed.interfaces import (
-     EMBEDDINGS_KEY,
      AsyncBaseEmbeddingEncoder,
      BaseEmbeddingEncoder,
      EmbeddingConfig,
  )
  from unstructured_ingest.logger import logger
- from unstructured_ingest.utils.data_prep import batch_generator
  from unstructured_ingest.utils.dep_check import requires_dependencies
  from unstructured_ingest.v2.errors import (
      RateLimitError as CustomRateLimitError,
  )
- from unstructured_ingest.v2.errors import (
-     UserAuthError,
-     UserError,
- )
+ from unstructured_ingest.v2.errors import UserAuthError, UserError, is_internal_error

  if TYPE_CHECKING:
      from together import AsyncTogether, Together
@@ -31,6 +26,8 @@ class TogetherAIEmbeddingConfig(EmbeddingConfig):
      )

      def wrap_error(self, e: Exception) -> Exception:
+         if is_internal_error(e=e):
+             return e
          # https://docs.together.ai/docs/error-codes
          from together.error import AuthenticationError, RateLimitError, TogetherException

@@ -64,31 +61,12 @@ class TogetherAIEmbeddingEncoder(BaseEmbeddingEncoder):
      def wrap_error(self, e: Exception) -> Exception:
          return self.config.wrap_error(e=e)

-     def embed_query(self, query: str) -> list[float]:
-         return self._embed_documents(elements=[query])[0]
-
-     def embed_documents(self, elements: list[dict]) -> list[dict]:
-         elements = elements.copy()
-         elements_with_text = [e for e in elements if e.get("text")]
-         embeddings = self._embed_documents([e["text"] for e in elements_with_text])
-         for element, embedding in zip(elements_with_text, embeddings):
-             element[EMBEDDINGS_KEY] = embedding
-         return elements
-
-     def _embed_documents(self, elements: list[str]) -> list[list[float]]:
-         client = self.config.get_client()
-         embeddings = []
-         try:
-             for batch in batch_generator(
-                 elements, batch_size=self.config.batch_size or len(elements)
-             ):
-                 outputs = client.embeddings.create(
-                     model=self.config.embedder_model_name, input=batch
-                 )
-                 embeddings.extend([outputs.data[i].embedding for i in range(len(batch))])
-         except Exception as e:
-             raise self.wrap_error(e=e)
-         return embeddings
+     def get_client(self) -> "Together":
+         return self.config.get_client()
+
+     def embed_batch(self, client: "Together", batch: list[str]) -> list[list[float]]:
+         outputs = client.embeddings.create(model=self.config.embedder_model_name, input=batch)
+         return [outputs.data[i].embedding for i in range(len(batch))]


  @dataclass
@@ -98,29 +76,9 @@ class AsyncTogetherAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
      def wrap_error(self, e: Exception) -> Exception:
          return self.config.wrap_error(e=e)

-     async def embed_query(self, query: str) -> list[float]:
-         embedding = await self._embed_documents(elements=[query])
-         return embedding[0]
-
-     async def embed_documents(self, elements: list[dict]) -> list[dict]:
-         elements = elements.copy()
-         elements_with_text = [e for e in elements if e.get("text")]
-         embeddings = await self._embed_documents([e["text"] for e in elements_with_text])
-         for element, embedding in zip(elements_with_text, embeddings):
-             element[EMBEDDINGS_KEY] = embedding
-         return elements
-
-     async def _embed_documents(self, elements: list[str]) -> list[list[float]]:
-         client = self.config.get_async_client()
-         embeddings = []
-         try:
-             for batch in batch_generator(
-                 elements, batch_size=self.config.batch_size or len(elements)
-             ):
-                 outputs = await client.embeddings.create(
-                     model=self.config.embedder_model_name, input=batch
-                 )
-                 embeddings.extend([outputs.data[i].embedding for i in range(len(batch))])
-         except Exception as e:
-             raise self.wrap_error(e=e)
-         return embeddings
+     def get_client(self) -> "AsyncTogether":
+         return self.config.get_async_client()
+
+     async def embed_batch(self, client: Any, batch: list[str]) -> list[list[float]]:
+         outputs = await client.embeddings.create(model=self.config.embedder_model_name, input=batch)
+         return [outputs.data[i].embedding for i in range(len(batch))]
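
The Together AI encoder above, and the Vertex AI and Voyage AI encoders below, all shed their duplicated embed_query/embed_documents/_embed_documents methods and keep only two hooks, get_client and embed_batch. The batching, per-element bookkeeping, and error wrapping each encoder used to carry evidently moved up into the shared base classes in unstructured_ingest.embed.interfaces, which are not part of this diff. A minimal sketch of what such a hook-based base class could look like (the names and details here are assumptions, not the package's actual code):

    # Hypothetical sketch of the template-method base class this refactor implies;
    # the real implementation lives in unstructured_ingest.embed.interfaces.
    from abc import ABC, abstractmethod
    from typing import Any

    EMBEDDINGS_KEY = "embeddings"  # assumed constant, matching the import the diff removes


    def batch_generator(items: list, batch_size: int):
        # Mirrors unstructured_ingest.utils.data_prep.batch_generator: fixed-size slices.
        for i in range(0, len(items), batch_size):
            yield items[i : i + batch_size]


    class HookBasedEncoderSketch(ABC):
        config: Any  # an EmbeddingConfig carrying batch_size and wrap_error()

        @abstractmethod
        def get_client(self) -> Any: ...

        @abstractmethod
        def embed_batch(self, client: Any, batch: list[str]) -> list[list[float]]: ...

        def _embed_documents(self, texts: list[str]) -> list[list[float]]:
            client = self.get_client()
            embeddings: list[list[float]] = []
            try:
                for batch in batch_generator(texts, self.config.batch_size or len(texts)):
                    embeddings.extend(self.embed_batch(client=client, batch=batch))
            except Exception as e:
                raise self.config.wrap_error(e=e)  # error mapping stays in one place
            return embeddings

        def embed_documents(self, elements: list[dict]) -> list[dict]:
            elements = elements.copy()
            with_text = [e for e in elements if e.get("text")]
            vectors = self._embed_documents([e["text"] for e in with_text])
            for element, vector in zip(with_text, vectors):
                element[EMBEDDINGS_KEY] = vector
            return elements

        def embed_query(self, query: str) -> list[float]:
            return self._embed_documents([query])[0]

Under this pattern a new provider only implements the two hooks, and fixes to batching or error handling land in one place.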
@@ -9,14 +9,12 @@ from pydantic import Field, Secret, ValidationError
  from pydantic.functional_validators import BeforeValidator

  from unstructured_ingest.embed.interfaces import (
-     EMBEDDINGS_KEY,
      AsyncBaseEmbeddingEncoder,
      BaseEmbeddingEncoder,
      EmbeddingConfig,
  )
- from unstructured_ingest.utils.data_prep import batch_generator
  from unstructured_ingest.utils.dep_check import requires_dependencies
- from unstructured_ingest.v2.errors import UserAuthError
+ from unstructured_ingest.v2.errors import UserAuthError, is_internal_error

  if TYPE_CHECKING:
      from vertexai.language_models import TextEmbeddingModel
@@ -40,6 +38,8 @@ class VertexAIEmbeddingConfig(EmbeddingConfig):
      )

      def wrap_error(self, e: Exception) -> Exception:
+         if is_internal_error(e=e):
+             return e
          from google.auth.exceptions import GoogleAuthError

          if isinstance(e, GoogleAuthError):
@@ -72,34 +72,19 @@ class VertexAIEmbeddingEncoder(BaseEmbeddingEncoder):
      def wrap_error(self, e: Exception) -> Exception:
          return self.config.wrap_error(e=e)

-     def embed_query(self, query):
-         return self._embed_documents(elements=[query])[0]
-
-     def embed_documents(self, elements: list[dict]) -> list[dict]:
-         elements = elements.copy()
-         elements_with_text = [e for e in elements if e.get("text")]
-         embeddings = self._embed_documents([e["text"] for e in elements_with_text])
-         for element, embedding in zip(elements_with_text, embeddings):
-             element[EMBEDDINGS_KEY] = embedding
-         return elements
+     def get_client(self) -> "TextEmbeddingModel":
+         return self.config.get_client()

      @requires_dependencies(
          ["vertexai"],
          extras="embed-vertexai",
      )
-     def _embed_documents(self, elements: list[str]) -> list[list[float]]:
+     def embed_batch(self, client: "TextEmbeddingModel", batch: list[str]) -> list[list[float]]:
          from vertexai.language_models import TextEmbeddingInput

-         inputs = [TextEmbeddingInput(text=element) for element in elements]
-         client = self.config.get_client()
-         embeddings = []
-         try:
-             for batch in batch_generator(inputs, batch_size=self.config.batch_size or len(inputs)):
-                 response = client.get_embeddings(batch)
-                 embeddings.extend([e.values for e in response])
-         except Exception as e:
-             raise self.wrap_error(e=e)
-         return embeddings
+         inputs = [TextEmbeddingInput(text=text) for text in batch]
+         response = client.get_embeddings(inputs)
+         return [e.values for e in response]


  @dataclass
@@ -109,32 +94,16 @@ class AsyncVertexAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
      def wrap_error(self, e: Exception) -> Exception:
          return self.config.wrap_error(e=e)

-     async def embed_query(self, query):
-         embedding = await self._embed_documents(elements=[query])
-         return embedding[0]
-
-     async def embed_documents(self, elements: list[dict]) -> list[dict]:
-         elements = elements.copy()
-         elements_with_text = [e for e in elements if e.get("text")]
-         embeddings = await self._embed_documents([e["text"] for e in elements_with_text])
-         for element, embedding in zip(elements_with_text, embeddings):
-             element[EMBEDDINGS_KEY] = embedding
-         return elements
+     def get_client(self) -> "TextEmbeddingModel":
+         return self.config.get_client()

      @requires_dependencies(
          ["vertexai"],
          extras="embed-vertexai",
      )
-     async def _embed_documents(self, elements: list[str]) -> list[list[float]]:
+     async def embed_batch(self, client: Any, batch: list[str]) -> list[list[float]]:
          from vertexai.language_models import TextEmbeddingInput

-         inputs = [TextEmbeddingInput(text=element) for element in elements]
-         client = self.config.get_client()
-         embeddings = []
-         try:
-             for batch in batch_generator(inputs, batch_size=self.config.batch_size or len(inputs)):
-                 response = await client.get_embeddings_async(batch)
-                 embeddings.extend([e.values for e in response])
-         except Exception as e:
-             raise self.wrap_error(e=e)
-         return embeddings
+         inputs = [TextEmbeddingInput(text=text) for text in batch]
+         response = await client.get_embeddings_async(inputs)
+         return [e.values for e in response]
@@ -4,19 +4,13 @@ from typing import TYPE_CHECKING, Optional
  from pydantic import Field, SecretStr

  from unstructured_ingest.embed.interfaces import (
-     EMBEDDINGS_KEY,
      AsyncBaseEmbeddingEncoder,
      BaseEmbeddingEncoder,
      EmbeddingConfig,
  )
  from unstructured_ingest.logger import logger
- from unstructured_ingest.utils.data_prep import batch_generator
  from unstructured_ingest.utils.dep_check import requires_dependencies
- from unstructured_ingest.v2.errors import (
-     ProviderError,
-     UserAuthError,
-     UserError,
- )
+ from unstructured_ingest.v2.errors import ProviderError, UserAuthError, UserError, is_internal_error
  from unstructured_ingest.v2.errors import (
      RateLimitError as CustomRateLimitError,
  )
@@ -39,6 +33,8 @@ class VoyageAIEmbeddingConfig(EmbeddingConfig):
      timeout_in_seconds: Optional[int] = None

      def wrap_error(self, e: Exception) -> Exception:
+         if is_internal_error(e=e):
+             return e
          # https://docs.voyageai.com/docs/error-codes
          from voyageai.error import AuthenticationError, RateLimitError, VoyageError

@@ -96,27 +92,12 @@ class VoyageAIEmbeddingEncoder(BaseEmbeddingEncoder):
      def wrap_error(self, e: Exception) -> Exception:
          return self.config.wrap_error(e=e)

-     def _embed_documents(self, elements: list[str]) -> list[list[float]]:
-         client = self.config.get_client()
-         embeddings = []
-         try:
-             for batch in batch_generator(elements, batch_size=self.config.batch_size):
-                 response = client.embed(texts=batch, model=self.config.embedder_model_name)
-                 embeddings.extend(response.embeddings)
-         except Exception as e:
-             raise self.wrap_error(e=e)
-         return embeddings
-
-     def embed_documents(self, elements: list[dict]) -> list[dict]:
-         elements = elements.copy()
-         elements_with_text = [e for e in elements if e.get("text")]
-         embeddings = self._embed_documents([e["text"] for e in elements_with_text])
-         for element, embedding in zip(elements_with_text, embeddings):
-             element[EMBEDDINGS_KEY] = embedding
-         return elements
-
-     def embed_query(self, query: str) -> list[float]:
-         return self._embed_documents(elements=[query])[0]
+     def get_client(self) -> "VoyageAIClient":
+         return self.config.get_client()
+
+     def embed_batch(self, client: "VoyageAIClient", batch: list[str]) -> list[list[float]]:
+         response = client.embed(texts=batch, model=self.config.embedder_model_name)
+         return response.embeddings


  @dataclass
@@ -126,27 +107,11 @@ class AsyncVoyageAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
      def wrap_error(self, e: Exception) -> Exception:
          return self.config.wrap_error(e=e)

-     async def _embed_documents(self, elements: list[str]) -> list[list[float]]:
-         client = self.config.get_async_client()
-         embeddings = []
-         try:
-             for batch in batch_generator(
-                 elements, batch_size=self.config.batch_size or len(elements)
-             ):
-                 response = await client.embed(texts=batch, model=self.config.embedder_model_name)
-                 embeddings.extend(response.embeddings)
-         except Exception as e:
-             raise self.wrap_error(e=e)
-         return embeddings
-
-     async def embed_documents(self, elements: list[dict]) -> list[dict]:
-         elements = elements.copy()
-         elements_with_text = [e for e in elements if e.get("text")]
-         embeddings = await self._embed_documents([e["text"] for e in elements_with_text])
-         for element, embedding in zip(elements_with_text, embeddings):
-             element[EMBEDDINGS_KEY] = embedding
-         return elements
-
-     async def embed_query(self, query: str) -> list[float]:
-         embedding = await self._embed_documents(elements=[query])
-         return embedding[0]
+     def get_client(self) -> "AsyncVoyageAIClient":
+         return self.config.get_async_client()
+
+     async def embed_batch(
+         self, client: "AsyncVoyageAIClient", batch: list[str]
+     ) -> list[list[float]]:
+         response = await client.embed(texts=batch, model=self.config.embedder_model_name)
+         return response.embeddings
@@ -16,3 +16,10 @@ class QuotaError(UserError):

  class ProviderError(Exception):
      pass
+
+
+ recognized_errors = [UserError, UserAuthError, RateLimitError, QuotaError, ProviderError]
+
+
+ def is_internal_error(e: Exception) -> bool:
+     return any(isinstance(e, recognized_error) for recognized_error in recognized_errors)
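
The new is_internal_error helper is what the wrap_error methods above check first: an exception that has already been mapped to one of the library's own error types is returned untouched instead of being wrapped a second time as it bubbles up through nested calls. A small illustration using only the names defined in this file:

    from unstructured_ingest.v2.errors import UserAuthError, is_internal_error

    # Already one of the recognized error types: wrap_error returns it as-is.
    assert is_internal_error(UserAuthError("bad credentials"))
    # Raw provider/library errors are not recognized and still get mapped.
    assert not is_internal_error(ValueError("boom"))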
@@ -8,9 +8,9 @@ from contextlib import asynccontextmanager
  from dataclasses import dataclass
  from enum import Enum
  from pathlib import Path
- from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional
+ from typing import TYPE_CHECKING, Any, AsyncGenerator, Literal, Optional

- from pydantic import BaseModel, ConfigDict, Field, Secret
+ from pydantic import BaseModel, ConfigDict, Field, Secret, field_validator

  from unstructured_ingest.error import DestinationConnectionError
  from unstructured_ingest.logger import logger
@@ -30,6 +30,8 @@ from unstructured_ingest.v2.processes.connector_registry import (
      DestinationRegistryEntry,
  )

+ SimilarityFunction = Literal["cosine"]
+
  if TYPE_CHECKING:
      from neo4j import AsyncDriver, Auth
      from networkx import Graph, MultiDiGraph
@@ -44,9 +46,9 @@ class Neo4jAccessConfig(AccessConfig):
  class Neo4jConnectionConfig(ConnectionConfig):
      access_config: Secret[Neo4jAccessConfig]
      connector_type: str = Field(default=CONNECTOR_TYPE, init=False)
-     username: str
+     username: str = Field(default="neo4j")
      uri: str = Field(description="Neo4j Connection URI <scheme>://<host>:<port>")
-     database: str = Field(description="Name of the target database")
+     database: str = Field(default="neo4j", description="Name of the target database")

      @requires_dependencies(["neo4j"], extras="neo4j")
      @asynccontextmanager
@@ -186,8 +188,8 @@ class _GraphData(BaseModel):
          nodes = list(nx_graph.nodes())
          edges = [
              _Edge(
-                 source_id=u.id_,
-                 destination_id=v.id_,
+                 source=u,
+                 destination=v,
                  relationship=Relationship(data_dict["relationship"]),
              )
              for u, v, data_dict in nx_graph.edges(data=True)
@@ -198,19 +200,30 @@ class _GraphData(BaseModel):
  class _Node(BaseModel):
      model_config = ConfigDict()

-     id_: str = Field(default_factory=lambda: str(uuid.uuid4()))
-     labels: list[Label] = Field(default_factory=list)
+     labels: list[Label]
      properties: dict = Field(default_factory=dict)
+     id_: str = Field(default_factory=lambda: str(uuid.uuid4()))

      def __hash__(self):
          return hash(self.id_)

+     @property
+     def main_label(self) -> Label:
+         return self.labels[0]
+
+     @classmethod
+     @field_validator("labels", mode="after")
+     def require_at_least_one_label(cls, value: list[Label]) -> list[Label]:
+         if not value:
+             raise ValueError("Node must have at least one label.")
+         return value
+

  class _Edge(BaseModel):
      model_config = ConfigDict()

-     source_id: str
-     destination_id: str
+     source: _Node
+     destination: _Node
      relationship: Relationship


@@ -229,7 +242,14 @@ class Relationship(Enum):
  class Neo4jUploaderConfig(UploaderConfig):
      batch_size: int = Field(
-         default=100, description="Maximal number of nodes/relationships created per transaction."
+         default=1000, description="Maximal number of nodes/relationships created per transaction."
+     )
+     similarity_function: SimilarityFunction = Field(
+         default="cosine",
+         description="Vector similarity function used to create index on Chunk nodes",
+     )
+     create_destination: bool = Field(
+         default=True, description="Create destination if it does not exist"
      )


@@ -257,6 +277,13 @@ class Neo4jUploader(Uploader):
          graph_data = _GraphData.model_validate(staged_data)
          async with self.connection_config.get_client() as client:
              await self._create_uniqueness_constraints(client)
+             embedding_dimensions = self._get_embedding_dimensions(graph_data)
+             if embedding_dimensions and self.upload_config.create_destination:
+                 await self._create_vector_index(
+                     client,
+                     dimensions=embedding_dimensions,
+                     similarity_function=self.upload_config.similarity_function,
+                 )
              await self._delete_old_data_if_exists(file_data, client=client)
              await self._merge_graph(graph_data=graph_data, client=client)

@@ -274,13 +301,33 @@ class Neo4jUploader(Uploader):
              """
          )

+     async def _create_vector_index(
+         self, client: AsyncDriver, dimensions: int, similarity_function: SimilarityFunction
+     ) -> None:
+         label = Label.CHUNK
+         logger.info(
+             f"Creating index on nodes labeled '{label.value}' if it does not already exist."
+         )
+         index_name = f"{label.value.lower()}_vector"
+         await client.execute_query(
+             f"""
+             CREATE VECTOR INDEX {index_name} IF NOT EXISTS
+             FOR (n:{label.value}) ON n.embedding
+             OPTIONS {{indexConfig: {{
+                 `vector.similarity_function`: '{similarity_function}',
+                 `vector.dimensions`: {dimensions}}}
+             }}
+             """
+         )
+
      async def _delete_old_data_if_exists(self, file_data: FileData, client: AsyncDriver) -> None:
          logger.info(f"Deleting old data for the record '{file_data.identifier}' (if present).")
          _, summary, _ = await client.execute_query(
              f"""
-             MATCH (n: {Label.DOCUMENT.value} {{id: $identifier}})
-             MATCH (n)--(m: {Label.CHUNK.value}|{Label.UNSTRUCTURED_ELEMENT.value})
-             DETACH DELETE m""",
+             MATCH (n: `{Label.DOCUMENT.value}` {{id: $identifier}})
+             MATCH (n)--(m: `{Label.CHUNK.value}`|`{Label.UNSTRUCTURED_ELEMENT.value}`)
+             DETACH DELETE m
+             DETACH DELETE n""",
              identifier=file_data.identifier,
          )
          logger.info(
@@ -289,16 +336,15 @@ class Neo4jUploader(Uploader):
          )

      async def _merge_graph(self, graph_data: _GraphData, client: AsyncDriver) -> None:
-         nodes_by_labels: defaultdict[tuple[Label, ...], list[_Node]] = defaultdict(list)
+         nodes_by_labels: defaultdict[Label, list[_Node]] = defaultdict(list)
          for node in graph_data.nodes:
-             nodes_by_labels[tuple(node.labels)].append(node)
-
+             nodes_by_labels[node.main_label].append(node)
          logger.info(f"Merging {len(graph_data.nodes)} graph nodes.")
          # NOTE: Processed in parallel as there's no overlap between accessed nodes
          await self._execute_queries(
              [
-                 self._create_nodes_query(nodes_batch, labels)
-                 for labels, nodes in nodes_by_labels.items()
+                 self._create_nodes_query(nodes_batch, label)
+                 for label, nodes in nodes_by_labels.items()
                  for nodes_batch in batch_generator(nodes, batch_size=self.upload_config.batch_size)
              ],
              client=client,
@@ -306,16 +352,23 @@ class Neo4jUploader(Uploader):
          )
          logger.info(f"Finished merging {len(graph_data.nodes)} graph nodes.")

-         edges_by_relationship: defaultdict[Relationship, list[_Edge]] = defaultdict(list)
+         edges_by_relationship: defaultdict[tuple[Relationship, Label, Label], list[_Edge]] = (
+             defaultdict(list)
+         )
          for edge in graph_data.edges:
-             edges_by_relationship[edge.relationship].append(edge)
+             key = (edge.relationship, edge.source.main_label, edge.destination.main_label)
+             edges_by_relationship[key].append(edge)

          logger.info(f"Merging {len(graph_data.edges)} graph relationships (edges).")
          # NOTE: Processed sequentially to avoid queries locking node access to one another
          await self._execute_queries(
              [
-                 self._create_edges_query(edges_batch, relationship)
-                 for relationship, edges in edges_by_relationship.items()
+                 self._create_edges_query(edges_batch, relationship, source_label, destination_label)
+                 for (
+                     relationship,
+                     source_label,
+                     destination_label,
+                 ), edges in edges_by_relationship.items()
                  for edges_batch in batch_generator(edges, batch_size=self.upload_config.batch_size)
              ],
              client=client,
@@ -328,53 +381,86 @@ class Neo4jUploader(Uploader):
          client: AsyncDriver,
          in_parallel: bool = False,
      ) -> None:
+         from neo4j import EagerResult
+
+         results: list[EagerResult] = []
+         logger.info(
+             f"Executing {len(queries_with_parameters)} "
+             + f"{'parallel' if in_parallel else 'sequential'} Cypher statements."
+         )
          if in_parallel:
-             logger.info(f"Executing {len(queries_with_parameters)} queries in parallel.")
-             await asyncio.gather(
+             results = await asyncio.gather(
                  *[
                      client.execute_query(query, parameters_=parameters)
                      for query, parameters in queries_with_parameters
                  ]
              )
-             logger.info("Finished executing parallel queries.")
          else:
-             logger.info(f"Executing {len(queries_with_parameters)} queries sequentially.")
              for i, (query, parameters) in enumerate(queries_with_parameters):
-                 logger.info(f"Query #{i} started.")
-                 await client.execute_query(query, parameters_=parameters)
-                 logger.info(f"Query #{i} finished.")
-             logger.info(
-                 f"Finished executing all ({len(queries_with_parameters)}) sequential queries."
-             )
+                 logger.info(f"Statement #{i} started.")
+                 results.append(await client.execute_query(query, parameters_=parameters))
+                 logger.info(f"Statement #{i} finished.")
+         nodeCount = sum([res.summary.counters.nodes_created for res in results])
+         relCount = sum([res.summary.counters.relationships_created for res in results])
+         logger.info(
+             f"Finished executing all ({len(queries_with_parameters)}) "
+             + f"{'parallel' if in_parallel else 'sequential'} Cypher statements. "
+             + f"Created {nodeCount} nodes, {relCount} relationships."
+         )

      @staticmethod
-     def _create_nodes_query(nodes: list[_Node], labels: tuple[Label, ...]) -> tuple[str, dict]:
-         labels_string = ", ".join([label.value for label in labels])
-         logger.info(f"Preparing MERGE query for {len(nodes)} nodes labeled '{labels_string}'.")
+     def _create_nodes_query(nodes: list[_Node], label: Label) -> tuple[str, dict]:
+         logger.info(f"Preparing MERGE query for {len(nodes)} nodes labeled '{label}'.")
          query_string = f"""
              UNWIND $nodes AS node
-             MERGE (n: {labels_string} {{id: node.id}})
+             MERGE (n: `{label.value}` {{id: node.id}})
              SET n += node.properties
+             SET n:$(node.labels)
+             WITH * WHERE node.vector IS NOT NULL
+             CALL db.create.setNodeVectorProperty(n, 'embedding', node.vector)
          """
-         parameters = {"nodes": [{"id": node.id_, "properties": node.properties} for node in nodes]}
+         parameters = {
+             "nodes": [
+                 {
+                     "id": node.id_,
+                     "labels": [l.value for l in node.labels if l != label],  # noqa: E741
+                     "vector": node.properties.pop("embedding", None),
+                     "properties": node.properties,
+                 }
+                 for node in nodes
+             ]
+         }
          return query_string, parameters

      @staticmethod
-     def _create_edges_query(edges: list[_Edge], relationship: Relationship) -> tuple[str, dict]:
+     def _create_edges_query(
+         edges: list[_Edge],
+         relationship: Relationship,
+         source_label: Label,
+         destination_label: Label,
+     ) -> tuple[str, dict]:
          logger.info(f"Preparing MERGE query for {len(edges)} {relationship} relationships.")
          query_string = f"""
              UNWIND $edges AS edge
-             MATCH (u {{id: edge.source}})
-             MATCH (v {{id: edge.destination}})
-             MERGE (u)-[:{relationship.value}]->(v)
+             MATCH (u: `{source_label.value}` {{id: edge.source}})
+             MATCH (v: `{destination_label.value}` {{id: edge.destination}})
+             MERGE (u)-[:`{relationship.value}`]->(v)
          """
          parameters = {
              "edges": [
-                 {"source": edge.source_id, "destination": edge.destination_id} for edge in edges
+                 {"source": edge.source.id_, "destination": edge.destination.id_} for edge in edges
              ]
          }
          return query_string, parameters

+     def _get_embedding_dimensions(self, graph_data: _GraphData) -> int | None:
+         """Embedding dimensions inferred from chunk nodes or None if it can't be determined."""
+         for node in graph_data.nodes:
+             if Label.CHUNK in node.labels and "embeddings" in node.properties:
+                 return len(node.properties["embeddings"])
+
+         return None
+

  neo4j_destination_entry = DestinationRegistryEntry(
      connection_config=Neo4jConnectionConfig,
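
Taken together, the uploader changes backtick-quote all labels and relationship types in generated Cypher, key edges by (relationship, source label, destination label) so the MATCH clauses can use label-based lookups, and, when create_destination is enabled, create a vector index on Chunk nodes sized from the first node carrying an 'embeddings' property. A hedged read-side sketch of querying such an index with the official neo4j driver (the index name follows the f"{label.value.lower()}_vector" pattern above, e.g. 'chunk_vector'; URI, credentials, database, and vector size are placeholders):

    # db.index.vector.queryNodes is Neo4j 5.x's similarity-search procedure over a
    # vector index; it returns the k nearest nodes and their similarity scores.
    from neo4j import GraphDatabase

    driver = GraphDatabase.driver("neo4j://localhost:7687", auth=("neo4j", "password"))
    with driver.session(database="neo4j") as session:
        records = session.run(
            """
            CALL db.index.vector.queryNodes('chunk_vector', 5, $query_embedding)
            YIELD node, score
            RETURN node.id AS id, score
            ORDER BY score DESC
            """,
            query_embedding=[0.1] * 1536,  # must match the indexed vector dimensions
        )
        for record in records:
            print(record["id"], record["score"])
    driver.close()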
@@ -92,18 +92,20 @@ class EmbedderConfig(BaseModel):

          return OctoAIEmbeddingEncoder(config=OctoAiEmbeddingConfig.model_validate(embedding_kwargs))

-     def get_bedrock_embedder(self) -> "BaseEmbeddingEncoder":
+     def get_bedrock_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEncoder":
          from unstructured_ingest.embed.bedrock import (
              BedrockEmbeddingConfig,
              BedrockEmbeddingEncoder,
          )

+         embedding_kwargs = embedding_kwargs | {
+             "aws_access_key_id": self.embedding_aws_access_key_id,
+             "aws_secret_access_key": self.embedding_aws_secret_access_key.get_secret_value(),
+             "region_name": self.embedding_aws_region,
+         }
+
          return BedrockEmbeddingEncoder(
-             config=BedrockEmbeddingConfig(
-                 aws_access_key_id=self.embedding_aws_access_key_id,
-                 aws_secret_access_key=self.embedding_aws_secret_access_key.get_secret_value(),
-                 region_name=self.embedding_aws_region,
-             )
+             config=BedrockEmbeddingConfig.model_validate(embedding_kwargs)
          )

      def get_vertexai_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEncoder":
@@ -163,7 +165,7 @@ class EmbedderConfig(BaseModel):
              return self.get_octoai_embedder(embedding_kwargs=kwargs)

          if self.embedding_provider == "bedrock":
-             return self.get_bedrock_embedder()
+             return self.get_bedrock_embedder(embedding_kwargs=kwargs)

          if self.embedding_provider == "vertexai":
              return self.get_vertexai_embedder(embedding_kwargs=kwargs)
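
With this change, get_bedrock_embedder receives the caller's embedding kwargs and overlays the AWS credentials before validating the merged dict, so provider options such as the model name are no longer dropped on the bedrock path. A rough usage sketch (the EmbedderConfig field names and the get_embedder() entry point shown here are assumptions inferred from this hunk, and the credentials are placeholders):

    # Hypothetical usage; field names such as embedding_model_name are assumptions
    # about EmbedderConfig, not confirmed by this diff.
    from unstructured_ingest.v2.processes.embedder import EmbedderConfig

    config = EmbedderConfig(
        embedding_provider="bedrock",
        embedding_model_name="amazon.titan-embed-text-v2:0",  # example model id, now passed through
        embedding_aws_access_key_id="AKIA...",
        embedding_aws_secret_access_key="...",
        embedding_aws_region="us-east-1",
    )
    embedder = config.get_embedder()  # dispatches to get_bedrock_embedder(embedding_kwargs=...)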