unstructured-ingest 0.5.2__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff shows the changes between two package versions as published to a supported public registry. It is provided for informational purposes only.
Potentially problematic release: this version of unstructured-ingest has been flagged as potentially problematic.
- test/unit/v2/embedders/test_bedrock.py +1 -1
- test/unit/v2/embedders/test_huggingface.py +1 -1
- unstructured_ingest/__version__.py +1 -1
- unstructured_ingest/embed/azure_openai.py +6 -0
- unstructured_ingest/embed/bedrock.py +16 -6
- unstructured_ingest/embed/huggingface.py +3 -1
- unstructured_ingest/embed/interfaces.py +61 -23
- unstructured_ingest/embed/mixedbreadai.py +28 -114
- unstructured_ingest/embed/octoai.py +19 -51
- unstructured_ingest/embed/openai.py +17 -55
- unstructured_ingest/embed/togetherai.py +16 -58
- unstructured_ingest/embed/vertexai.py +15 -46
- unstructured_ingest/embed/voyageai.py +17 -52
- unstructured_ingest/v2/errors.py +7 -0
- unstructured_ingest/v2/processes/connectors/neo4j.py +129 -43
- unstructured_ingest/v2/processes/embedder.py +9 -7
- {unstructured_ingest-0.5.2.dist-info → unstructured_ingest-0.5.3.dist-info}/METADATA +96 -84
- {unstructured_ingest-0.5.2.dist-info → unstructured_ingest-0.5.3.dist-info}/RECORD +22 -22
- {unstructured_ingest-0.5.2.dist-info → unstructured_ingest-0.5.3.dist-info}/WHEEL +1 -1
- {unstructured_ingest-0.5.2.dist-info → unstructured_ingest-0.5.3.dist-info}/LICENSE.md +0 -0
- {unstructured_ingest-0.5.2.dist-info → unstructured_ingest-0.5.3.dist-info}/entry_points.txt +0 -0
- {unstructured_ingest-0.5.2.dist-info → unstructured_ingest-0.5.3.dist-info}/top_level.txt +0 -0
unstructured_ingest/embed/togetherai.py
CHANGED
@@ -1,24 +1,19 @@
 from dataclasses import dataclass
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any

 from pydantic import Field, SecretStr

 from unstructured_ingest.embed.interfaces import (
-    EMBEDDINGS_KEY,
     AsyncBaseEmbeddingEncoder,
     BaseEmbeddingEncoder,
     EmbeddingConfig,
 )
 from unstructured_ingest.logger import logger
-from unstructured_ingest.utils.data_prep import batch_generator
 from unstructured_ingest.utils.dep_check import requires_dependencies
 from unstructured_ingest.v2.errors import (
     RateLimitError as CustomRateLimitError,
 )
-from unstructured_ingest.v2.errors import (
-    UserAuthError,
-    UserError,
-)
+from unstructured_ingest.v2.errors import UserAuthError, UserError, is_internal_error

 if TYPE_CHECKING:
     from together import AsyncTogether, Together
@@ -31,6 +26,8 @@ class TogetherAIEmbeddingConfig(EmbeddingConfig):
     )

     def wrap_error(self, e: Exception) -> Exception:
+        if is_internal_error(e=e):
+            return e
         # https://docs.together.ai/docs/error-codes
         from together.error import AuthenticationError, RateLimitError, TogetherException

@@ -64,31 +61,12 @@ class TogetherAIEmbeddingEncoder(BaseEmbeddingEncoder):
     def wrap_error(self, e: Exception) -> Exception:
         return self.config.wrap_error(e=e)

-    def embed_query(self, query: str) -> list[float]:
-        return self._embed_documents(elements=[query])[0]
-
-    def embed_documents(self, elements: list[dict]) -> list[dict]:
-        elements = elements.copy()
-        elements_with_text = [e for e in elements if e.get("text")]
-        embeddings = self._embed_documents([e["text"] for e in elements_with_text])
-        for element, embedding in zip(elements_with_text, embeddings):
-            element[EMBEDDINGS_KEY] = embedding
-        return elements
-
-    def _embed_documents(self, elements: list[str]) -> list[list[float]]:
-        client = self.config.get_client()
-        embeddings = []
-        try:
-            for batch in batch_generator(
-                elements, batch_size=self.config.batch_size or len(elements)
-            ):
-                outputs = client.embeddings.create(
-                    model=self.config.embedder_model_name, input=batch
-                )
-                embeddings.extend([outputs.data[i].embedding for i in range(len(batch))])
-        except Exception as e:
-            raise self.wrap_error(e=e)
-        return embeddings
+    def get_client(self) -> "Together":
+        return self.config.get_client()
+
+    def embed_batch(self, client: "Together", batch: list[str]) -> list[list[float]]:
+        outputs = client.embeddings.create(model=self.config.embedder_model_name, input=batch)
+        return [outputs.data[i].embedding for i in range(len(batch))]


 @dataclass
@@ -98,29 +76,9 @@ class AsyncTogetherAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
     def wrap_error(self, e: Exception) -> Exception:
         return self.config.wrap_error(e=e)

-    async def embed_query(self, query: str) -> list[float]:
-        embedding = await self._embed_documents(elements=[query])
-        return embedding[0]
-
-    async def embed_documents(self, elements: list[dict]) -> list[dict]:
-        elements = elements.copy()
-        elements_with_text = [e for e in elements if e.get("text")]
-        embeddings = await self._embed_documents([e["text"] for e in elements_with_text])
-        for element, embedding in zip(elements_with_text, embeddings):
-            element[EMBEDDINGS_KEY] = embedding
-        return elements
-
-    async def _embed_documents(self, elements: list[str]) -> list[list[float]]:
-        client = self.config.get_async_client()
-        embeddings = []
-        try:
-            for batch in batch_generator(
-                elements, batch_size=self.config.batch_size or len(elements)
-            ):
-                outputs = await client.embeddings.create(
-                    model=self.config.embedder_model_name, input=batch
-                )
-                embeddings.extend([outputs.data[i].embedding for i in range(len(batch))])
-        except Exception as e:
-            raise self.wrap_error(e=e)
-        return embeddings
+    def get_client(self) -> "AsyncTogether":
+        return self.config.get_async_client()
+
+    async def embed_batch(self, client: Any, batch: list[str]) -> list[list[float]]:
+        outputs = await client.embeddings.create(model=self.config.embedder_model_name, input=batch)
+        return [outputs.data[i].embedding for i in range(len(batch))]
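Across both encoder classes above (and the matching vertexai and voyageai hunks below), this release deletes each provider's private copy of the element-iteration and batching loop and leaves only two hooks, `get_client()` and `embed_batch()`. Together with the `+61 -23` change to `unstructured_ingest/embed/interfaces.py` in the file list, this indicates the shared loop moved into the base classes. A minimal sketch of what that shared driver plausibly looks like; the hook names and the loop body come from the removed provider code, but the base-class layout itself is an assumption, not the actual `interfaces.py` source:

```python
from dataclasses import dataclass
from typing import Any

EMBEDDINGS_KEY = "embeddings"  # assumed value of the constant re-exported by interfaces.py


def batch_generator(items: list, batch_size: int):
    # Yields fixed-size slices, mirroring unstructured_ingest.utils.data_prep.batch_generator.
    for i in range(0, len(items), batch_size):
        yield items[i : i + batch_size]


@dataclass
class BaseEmbeddingEncoder:
    config: Any  # provider-specific EmbeddingConfig

    # Hooks that each provider (e.g. TogetherAIEmbeddingEncoder) now implements.
    def get_client(self) -> Any:
        raise NotImplementedError

    def embed_batch(self, client: Any, batch: list[str]) -> list[list[float]]:
        raise NotImplementedError

    def wrap_error(self, e: Exception) -> Exception:
        return self.config.wrap_error(e=e)

    def _embed_documents(self, texts: list[str]) -> list[list[float]]:
        # The batching/error-wrapping loop previously duplicated in every provider module.
        client = self.get_client()
        embeddings: list[list[float]] = []
        try:
            for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
                embeddings.extend(self.embed_batch(client=client, batch=batch))
        except Exception as e:
            raise self.wrap_error(e=e)
        return embeddings

    def embed_documents(self, elements: list[dict]) -> list[dict]:
        elements = elements.copy()
        elements_with_text = [e for e in elements if e.get("text")]
        embeddings = self._embed_documents([e["text"] for e in elements_with_text])
        for element, embedding in zip(elements_with_text, embeddings):
            element[EMBEDDINGS_KEY] = embedding
        return elements
```

The async classes presumably mirror this in `AsyncBaseEmbeddingEncoder`, awaiting `embed_batch` instead.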
unstructured_ingest/embed/vertexai.py
CHANGED
@@ -9,14 +9,12 @@ from pydantic import Field, Secret, ValidationError
 from pydantic.functional_validators import BeforeValidator

 from unstructured_ingest.embed.interfaces import (
-    EMBEDDINGS_KEY,
     AsyncBaseEmbeddingEncoder,
     BaseEmbeddingEncoder,
     EmbeddingConfig,
 )
-from unstructured_ingest.utils.data_prep import batch_generator
 from unstructured_ingest.utils.dep_check import requires_dependencies
-from unstructured_ingest.v2.errors import UserAuthError
+from unstructured_ingest.v2.errors import UserAuthError, is_internal_error

 if TYPE_CHECKING:
     from vertexai.language_models import TextEmbeddingModel
@@ -40,6 +38,8 @@ class VertexAIEmbeddingConfig(EmbeddingConfig):
     )

     def wrap_error(self, e: Exception) -> Exception:
+        if is_internal_error(e=e):
+            return e
         from google.auth.exceptions import GoogleAuthError

         if isinstance(e, GoogleAuthError):
@@ -72,34 +72,19 @@ class VertexAIEmbeddingEncoder(BaseEmbeddingEncoder):
     def wrap_error(self, e: Exception) -> Exception:
         return self.config.wrap_error(e=e)

-    def embed_query(self, query: str) -> list[float]:
-        return self._embed_documents(elements=[query])[0]
-
-    def embed_documents(self, elements: list[dict]) -> list[dict]:
-        elements = elements.copy()
-        elements_with_text = [e for e in elements if e.get("text")]
-        embeddings = self._embed_documents([e["text"] for e in elements_with_text])
-        for element, embedding in zip(elements_with_text, embeddings):
-            element[EMBEDDINGS_KEY] = embedding
-        return elements
+    def get_client(self) -> "TextEmbeddingModel":
+        return self.config.get_client()

     @requires_dependencies(
         ["vertexai"],
         extras="embed-vertexai",
     )
-    def _embed_documents(self, elements: list[str]) -> list[list[float]]:
+    def embed_batch(self, client: "TextEmbeddingModel", batch: list[str]) -> list[list[float]]:
         from vertexai.language_models import TextEmbeddingInput

-        inputs = [TextEmbeddingInput(text=text) for text in elements]
-        client = self.config.get_client()
-        embeddings = []
-        try:
-            for batch in batch_generator(inputs, batch_size=self.config.batch_size or len(inputs)):
-                response = client.get_embeddings(batch)
-                embeddings.extend([e.values for e in response])
-        except Exception as e:
-            raise self.wrap_error(e=e)
-        return embeddings
+        inputs = [TextEmbeddingInput(text=text) for text in batch]
+        response = client.get_embeddings(inputs)
+        return [e.values for e in response]


 @dataclass
@@ -109,32 +94,16 @@ class AsyncVertexAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
     def wrap_error(self, e: Exception) -> Exception:
         return self.config.wrap_error(e=e)

-    async def embed_query(self, query: str) -> list[float]:
-        embedding = await self._embed_documents(elements=[query])
-        return embedding[0]
-
-    async def embed_documents(self, elements: list[dict]) -> list[dict]:
-        elements = elements.copy()
-        elements_with_text = [e for e in elements if e.get("text")]
-        embeddings = await self._embed_documents([e["text"] for e in elements_with_text])
-        for element, embedding in zip(elements_with_text, embeddings):
-            element[EMBEDDINGS_KEY] = embedding
-        return elements
+    def get_client(self) -> "TextEmbeddingModel":
+        return self.config.get_client()

     @requires_dependencies(
         ["vertexai"],
         extras="embed-vertexai",
     )
-    async def _embed_documents(self, elements: list[str]) -> list[list[float]]:
+    async def embed_batch(self, client: Any, batch: list[str]) -> list[list[float]]:
         from vertexai.language_models import TextEmbeddingInput

-        inputs = [TextEmbeddingInput(text=text) for text in elements]
-        client = self.config.get_client()
-        embeddings = []
-        try:
-            for batch in batch_generator(inputs, batch_size=self.config.batch_size or len(inputs)):
-                response = await client.get_embeddings_async(batch)
-                embeddings.extend([e.values for e in response])
-        except Exception as e:
-            raise self.wrap_error(e=e)
-        return embeddings
+        inputs = [TextEmbeddingInput(text=text) for text in batch]
+        response = await client.get_embeddings_async(inputs)
+        return [e.values for e in response]
unstructured_ingest/embed/voyageai.py
CHANGED
@@ -4,19 +4,13 @@ from typing import TYPE_CHECKING, Optional
 from pydantic import Field, SecretStr

 from unstructured_ingest.embed.interfaces import (
-    EMBEDDINGS_KEY,
     AsyncBaseEmbeddingEncoder,
     BaseEmbeddingEncoder,
     EmbeddingConfig,
 )
 from unstructured_ingest.logger import logger
-from unstructured_ingest.utils.data_prep import batch_generator
 from unstructured_ingest.utils.dep_check import requires_dependencies
-from unstructured_ingest.v2.errors import (
-    ProviderError,
-    UserAuthError,
-    UserError,
-)
+from unstructured_ingest.v2.errors import ProviderError, UserAuthError, UserError, is_internal_error
 from unstructured_ingest.v2.errors import (
     RateLimitError as CustomRateLimitError,
 )
@@ -39,6 +33,8 @@ class VoyageAIEmbeddingConfig(EmbeddingConfig):
     timeout_in_seconds: Optional[int] = None

     def wrap_error(self, e: Exception) -> Exception:
+        if is_internal_error(e=e):
+            return e
         # https://docs.voyageai.com/docs/error-codes
         from voyageai.error import AuthenticationError, RateLimitError, VoyageError

@@ -96,27 +92,12 @@ class VoyageAIEmbeddingEncoder(BaseEmbeddingEncoder):
     def wrap_error(self, e: Exception) -> Exception:
         return self.config.wrap_error(e=e)

-    def _embed_documents(self, elements: list[str]) -> list[list[float]]:
-        client = self.config.get_client()
-        embeddings = []
-        try:
-            for batch in batch_generator(elements, batch_size=self.config.batch_size or len(elements)):
-                response = client.embed(texts=batch, model=self.config.embedder_model_name)
-                embeddings.extend(response.embeddings)
-        except Exception as e:
-            raise self.wrap_error(e=e)
-        return embeddings
-
-    def embed_documents(self, elements: list[dict]) -> list[dict]:
-        elements = elements.copy()
-        elements_with_text = [e for e in elements if e.get("text")]
-        embeddings = self._embed_documents([e["text"] for e in elements_with_text])
-        for element, embedding in zip(elements_with_text, embeddings):
-            element[EMBEDDINGS_KEY] = embedding
-        return elements
-
-    def embed_query(self, query: str) -> list[float]:
-        return self._embed_documents(elements=[query])[0]
+    def get_client(self) -> "VoyageAIClient":
+        return self.config.get_client()
+
+    def embed_batch(self, client: "VoyageAIClient", batch: list[str]) -> list[list[float]]:
+        response = client.embed(texts=batch, model=self.config.embedder_model_name)
+        return response.embeddings


 @dataclass
@@ -126,27 +107,11 @@ class AsyncVoyageAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
     def wrap_error(self, e: Exception) -> Exception:
         return self.config.wrap_error(e=e)

-    async def _embed_documents(self, elements: list[str]) -> list[list[float]]:
-        client = self.config.get_async_client()
-        embeddings = []
-        try:
-            for batch in batch_generator(
-                elements, batch_size=self.config.batch_size or len(elements)
-            ):
-                response = await client.embed(texts=batch, model=self.config.embedder_model_name)
-                embeddings.extend(response.embeddings)
-        except Exception as e:
-            raise self.wrap_error(e=e)
-        return embeddings
-
-    async def embed_documents(self, elements: list[dict]) -> list[dict]:
-        elements = elements.copy()
-        elements_with_text = [e for e in elements if e.get("text")]
-        embeddings = await self._embed_documents([e["text"] for e in elements_with_text])
-        for element, embedding in zip(elements_with_text, embeddings):
-            element[EMBEDDINGS_KEY] = embedding
-        return elements
-
-    async def embed_query(self, query: str) -> list[float]:
-        embedding = await self._embed_documents(elements=[query])
-        return embedding[0]
+    def get_client(self) -> "AsyncVoyageAIClient":
+        return self.config.get_async_client()
+
+    async def embed_batch(
+        self, client: "AsyncVoyageAIClient", batch: list[str]
+    ) -> list[list[float]]:
+        response = await client.embed(texts=batch, model=self.config.embedder_model_name)
+        return response.embeddings
unstructured_ingest/v2/errors.py
CHANGED
@@ -16,3 +16,10 @@ class QuotaError(UserError):

 class ProviderError(Exception):
     pass
+
+
+recognized_errors = [UserError, UserAuthError, RateLimitError, QuotaError, ProviderError]
+
+
+def is_internal_error(e: Exception) -> bool:
+    return any(isinstance(e, recognized_error) for recognized_error in recognized_errors)
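Every provider's `wrap_error` now starts with the same guard: an exception that is already one of the library's own error types passes through unchanged instead of being wrapped a second time. A small self-contained illustration; `FakeProviderAuthError` is a hypothetical stand-in for a provider SDK exception:

```python
from unstructured_ingest.v2.errors import UserAuthError, is_internal_error


class FakeProviderAuthError(Exception):
    """Hypothetical stand-in for a provider SDK's authentication error."""


def wrap_error(e: Exception) -> Exception:
    # Same shape as the guards added to each embedder's wrap_error in this release.
    if is_internal_error(e=e):
        return e
    if isinstance(e, FakeProviderAuthError):
        return UserAuthError(e)
    return e


assert isinstance(wrap_error(FakeProviderAuthError("bad key")), UserAuthError)
already_wrapped = UserAuthError("bad key")
assert wrap_error(already_wrapped) is already_wrapped  # no double-wrapping
```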
unstructured_ingest/v2/processes/connectors/neo4j.py
CHANGED
@@ -8,9 +8,9 @@ from contextlib import asynccontextmanager
 from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Literal, Optional

-from pydantic import BaseModel, ConfigDict, Field, Secret
+from pydantic import BaseModel, ConfigDict, Field, Secret, field_validator

 from unstructured_ingest.error import DestinationConnectionError
 from unstructured_ingest.logger import logger
@@ -30,6 +30,8 @@ from unstructured_ingest.v2.processes.connector_registry import (
     DestinationRegistryEntry,
 )

+SimilarityFunction = Literal["cosine"]
+
 if TYPE_CHECKING:
     from neo4j import AsyncDriver, Auth
     from networkx import Graph, MultiDiGraph
@@ -44,9 +46,9 @@ class Neo4jAccessConfig(AccessConfig):
 class Neo4jConnectionConfig(ConnectionConfig):
     access_config: Secret[Neo4jAccessConfig]
     connector_type: str = Field(default=CONNECTOR_TYPE, init=False)
-    username: str
+    username: str = Field(default="neo4j")
     uri: str = Field(description="Neo4j Connection URI <scheme>://<host>:<port>")
-    database: str = Field(description="Name of the target database")
+    database: str = Field(default="neo4j", description="Name of the target database")

     @requires_dependencies(["neo4j"], extras="neo4j")
     @asynccontextmanager
@@ -186,8 +188,8 @@ class _GraphData(BaseModel):
         nodes = list(nx_graph.nodes())
         edges = [
             _Edge(
-                source=u.id_,
-                destination=v.id_,
+                source=u,
+                destination=v,
                 relationship=Relationship(data_dict["relationship"]),
             )
             for u, v, data_dict in nx_graph.edges(data=True)
@@ -198,19 +200,30 @@ class _GraphData(BaseModel):
 class _Node(BaseModel):
     model_config = ConfigDict()

-    id_: str = Field(default_factory=lambda: str(uuid.uuid4()))
-    labels: list[Label] = Field(default_factory=list)
+    labels: list[Label]
     properties: dict = Field(default_factory=dict)
+    id_: str = Field(default_factory=lambda: str(uuid.uuid4()))

     def __hash__(self):
         return hash(self.id_)

+    @property
+    def main_label(self) -> Label:
+        return self.labels[0]
+
+    @classmethod
+    @field_validator("labels", mode="after")
+    def require_at_least_one_label(cls, value: list[Label]) -> list[Label]:
+        if not value:
+            raise ValueError("Node must have at least one label.")
+        return value
+

 class _Edge(BaseModel):
     model_config = ConfigDict()

-    source: str
-    destination: str
+    source: _Node
+    destination: _Node
     relationship: Relationship

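`labels` is now a required, validated field, so building a `_Node` without at least one label fails at model-validation time instead of surfacing later as an empty Cypher label, and `main_label` can safely index `labels[0]`. A standalone sketch of the same pattern (simplified `Label` enum with assumed values; this sketch uses pydantic's documented decorator order, with `@field_validator` outermost):

```python
import uuid
from enum import Enum

from pydantic import BaseModel, Field, ValidationError, field_validator


class Label(str, Enum):
    DOCUMENT = "Document"
    CHUNK = "Chunk"


class Node(BaseModel):
    labels: list[Label]
    properties: dict = Field(default_factory=dict)
    id_: str = Field(default_factory=lambda: str(uuid.uuid4()))

    @field_validator("labels", mode="after")
    @classmethod
    def require_at_least_one_label(cls, value: list[Label]) -> list[Label]:
        if not value:
            raise ValueError("Node must have at least one label.")
        return value

    @property
    def main_label(self) -> Label:
        return self.labels[0]  # safe: the validator guarantees a non-empty list


assert Node(labels=[Label.CHUNK]).main_label is Label.CHUNK
try:
    Node(labels=[])
except ValidationError as err:
    print(err)  # 1 validation error ... Node must have at least one label.
```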
@@ -229,7 +242,14 @@ class Relationship(Enum):

 class Neo4jUploaderConfig(UploaderConfig):
     batch_size: int = Field(
-        default=
+        default=1000, description="Maximal number of nodes/relationships created per transaction."
+    )
+    similarity_function: SimilarityFunction = Field(
+        default="cosine",
+        description="Vector similarity function used to create index on Chunk nodes",
+    )
+    create_destination: bool = Field(
+        default=True, description="Create destination if it does not exist"
     )

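The uploader config gains two knobs on top of the larger `batch_size` default: `similarity_function` (currently only `"cosine"`, per the `Literal` alias above) and `create_destination`, which gates the vector-index creation shown in the next hunks. A minimal sketch, assuming `UploaderConfig` itself requires no other fields:

```python
from unstructured_ingest.v2.processes.connectors.neo4j import Neo4jUploaderConfig

config = Neo4jUploaderConfig()
assert config.batch_size == 1000
assert config.similarity_function == "cosine"
assert config.create_destination is True
```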
@@ -257,6 +277,13 @@ class Neo4jUploader(Uploader):
         graph_data = _GraphData.model_validate(staged_data)
         async with self.connection_config.get_client() as client:
             await self._create_uniqueness_constraints(client)
+            embedding_dimensions = self._get_embedding_dimensions(graph_data)
+            if embedding_dimensions and self.upload_config.create_destination:
+                await self._create_vector_index(
+                    client,
+                    dimensions=embedding_dimensions,
+                    similarity_function=self.upload_config.similarity_function,
+                )
             await self._delete_old_data_if_exists(file_data, client=client)
             await self._merge_graph(graph_data=graph_data, client=client)

@@ -274,13 +301,33 @@ class Neo4jUploader(Uploader):
             """
         )

+    async def _create_vector_index(
+        self, client: AsyncDriver, dimensions: int, similarity_function: SimilarityFunction
+    ) -> None:
+        label = Label.CHUNK
+        logger.info(
+            f"Creating index on nodes labeled '{label.value}' if it does not already exist."
+        )
+        index_name = f"{label.value.lower()}_vector"
+        await client.execute_query(
+            f"""
+            CREATE VECTOR INDEX {index_name} IF NOT EXISTS
+            FOR (n:{label.value}) ON n.embedding
+            OPTIONS {{indexConfig: {{
+                `vector.similarity_function`: '{similarity_function}',
+                `vector.dimensions`: {dimensions}}}
+            }}
+            """
+        )
+
     async def _delete_old_data_if_exists(self, file_data: FileData, client: AsyncDriver) -> None:
         logger.info(f"Deleting old data for the record '{file_data.identifier}' (if present).")
         _, summary, _ = await client.execute_query(
             f"""
-            MATCH (n: {Label.DOCUMENT.value} {{id: $identifier}})
-            MATCH (n)--(m: {Label.CHUNK.value})
-            DETACH DELETE m""",
+            MATCH (n: `{Label.DOCUMENT.value}` {{id: $identifier}})
+            MATCH (n)--(m: `{Label.CHUNK.value}`|`{Label.UNSTRUCTURED_ELEMENT.value}`)
+            DETACH DELETE m
+            DETACH DELETE n""",
             identifier=file_data.identifier,
         )
         logger.info(
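Once `_create_vector_index` has created the index (named `chunk_vector` for `Label.CHUNK`), it can be queried with Neo4j 5.x's built-in `db.index.vector.queryNodes` procedure. A hedged sketch of a similarity search; the URI and credentials are placeholders, and in the connector itself the driver comes from `Neo4jConnectionConfig.get_client()`:

```python
from neo4j import GraphDatabase

# Placeholder connection details.
driver = GraphDatabase.driver("neo4j://localhost:7687", auth=("neo4j", "password"))


def similar_chunks(embedding: list[float], k: int = 5) -> list[tuple[str, float]]:
    # 'chunk_vector' matches the index_name generated by _create_vector_index for Label.CHUNK.
    records, _, _ = driver.execute_query(
        """
        CALL db.index.vector.queryNodes('chunk_vector', $k, $embedding)
        YIELD node, score
        RETURN node.id AS id, score
        """,
        k=k,
        embedding=embedding,
    )
    return [(record["id"], record["score"]) for record in records]
```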
@@ -289,16 +336,15 @@ class Neo4jUploader(Uploader):
         )

     async def _merge_graph(self, graph_data: _GraphData, client: AsyncDriver) -> None:
-        nodes_by_labels: defaultdict[
+        nodes_by_labels: defaultdict[Label, list[_Node]] = defaultdict(list)
         for node in graph_data.nodes:
-            nodes_by_labels[
-
+            nodes_by_labels[node.main_label].append(node)
         logger.info(f"Merging {len(graph_data.nodes)} graph nodes.")
         # NOTE: Processed in parallel as there's no overlap between accessed nodes
         await self._execute_queries(
             [
-                self._create_nodes_query(nodes_batch,
-                for
+                self._create_nodes_query(nodes_batch, label)
+                for label, nodes in nodes_by_labels.items()
                 for nodes_batch in batch_generator(nodes, batch_size=self.upload_config.batch_size)
             ],
             client=client,
@@ -306,16 +352,23 @@ class Neo4jUploader(Uploader):
         )
         logger.info(f"Finished merging {len(graph_data.nodes)} graph nodes.")

-        edges_by_relationship: defaultdict[Relationship, list[_Edge]] = defaultdict(list)
+        edges_by_relationship: defaultdict[tuple[Relationship, Label, Label], list[_Edge]] = (
+            defaultdict(list)
+        )
         for edge in graph_data.edges:
-            edges_by_relationship[edge.relationship].append(edge)
+            key = (edge.relationship, edge.source.main_label, edge.destination.main_label)
+            edges_by_relationship[key].append(edge)

         logger.info(f"Merging {len(graph_data.edges)} graph relationships (edges).")
         # NOTE: Processed sequentially to avoid queries locking node access to one another
         await self._execute_queries(
             [
-                self._create_edges_query(edges_batch, relationship)
-                for relationship, edges in edges_by_relationship.items()
+                self._create_edges_query(edges_batch, relationship, source_label, destination_label)
+                for (
+                    relationship,
+                    source_label,
+                    destination_label,
+                ), edges in edges_by_relationship.items()
                 for edges_batch in batch_generator(edges, batch_size=self.upload_config.batch_size)
             ],
             client=client,
@@ -328,53 +381,86 @@ class Neo4jUploader(Uploader):
         client: AsyncDriver,
         in_parallel: bool = False,
     ) -> None:
+        from neo4j import EagerResult
+
+        results: list[EagerResult] = []
+        logger.info(
+            f"Executing {len(queries_with_parameters)} "
+            + f"{'parallel' if in_parallel else 'sequential'} Cypher statements."
+        )
         if in_parallel:
-            logger.info(f"Executing {len(queries_with_parameters)} queries in parallel.")
-            await asyncio.gather(
+            results = await asyncio.gather(
                 *[
                     client.execute_query(query, parameters_=parameters)
                     for query, parameters in queries_with_parameters
                 ]
             )
-            logger.info("Finished executing parallel queries.")
         else:
-            logger.info(f"Executing {len(queries_with_parameters)} queries sequentially.")
             for i, (query, parameters) in enumerate(queries_with_parameters):
-                logger.info(f"
-                await client.execute_query(query, parameters_=parameters)
-                logger.info(f"
-
-
-
+                logger.info(f"Statement #{i} started.")
+                results.append(await client.execute_query(query, parameters_=parameters))
+                logger.info(f"Statement #{i} finished.")
+        nodeCount = sum([res.summary.counters.nodes_created for res in results])
+        relCount = sum([res.summary.counters.relationships_created for res in results])
+        logger.info(
+            f"Finished executing all ({len(queries_with_parameters)}) "
+            + f"{'parallel' if in_parallel else 'sequential'} Cypher statements. "
+            + f"Created {nodeCount} nodes, {relCount} relationships."
+        )

     @staticmethod
-    def _create_nodes_query(nodes: list[_Node],
-
-        logger.info(f"Preparing MERGE query for {len(nodes)} nodes labeled '{labels_string}'.")
+    def _create_nodes_query(nodes: list[_Node], label: Label) -> tuple[str, dict]:
+        logger.info(f"Preparing MERGE query for {len(nodes)} nodes labeled '{label}'.")
         query_string = f"""
             UNWIND $nodes AS node
-            MERGE (n: {
+            MERGE (n: `{label.value}` {{id: node.id}})
             SET n += node.properties
+            SET n:$(node.labels)
+            WITH * WHERE node.vector IS NOT NULL
+            CALL db.create.setNodeVectorProperty(n, 'embedding', node.vector)
         """
-        parameters = {
+        parameters = {
+            "nodes": [
+                {
+                    "id": node.id_,
+                    "labels": [l.value for l in node.labels if l != label],  # noqa: E741
+                    "vector": node.properties.pop("embedding", None),
+                    "properties": node.properties,
+                }
+                for node in nodes
+            ]
+        }
         return query_string, parameters

     @staticmethod
-    def _create_edges_query(edges: list[_Edge], relationship: Relationship) -> tuple[str, dict]:
+    def _create_edges_query(
+        edges: list[_Edge],
+        relationship: Relationship,
+        source_label: Label,
+        destination_label: Label,
+    ) -> tuple[str, dict]:
         logger.info(f"Preparing MERGE query for {len(edges)} {relationship} relationships.")
         query_string = f"""
             UNWIND $edges AS edge
-            MATCH (u {{id: edge.source}})
-            MATCH (v {{id: edge.destination}})
-            MERGE (u)-[:{relationship.value}]->(v)
+            MATCH (u: `{source_label.value}` {{id: edge.source}})
+            MATCH (v: `{destination_label.value}` {{id: edge.destination}})
+            MERGE (u)-[:`{relationship.value}`]->(v)
         """
         parameters = {
             "edges": [
-                {"source": edge.source, "destination": edge.destination} for edge in edges
+                {"source": edge.source.id_, "destination": edge.destination.id_} for edge in edges
             ]
         }
         return query_string, parameters
+
+    def _get_embedding_dimensions(self, graph_data: _GraphData) -> int | None:
+        """Embedding dimensions inferred from chunk nodes or None if it can't be determined."""
+        for node in graph_data.nodes:
+            if Label.CHUNK in node.labels and "embeddings" in node.properties:
+                return len(node.properties["embeddings"])
+
+        return None


 neo4j_destination_entry = DestinationRegistryEntry(
     connection_config=Neo4jConnectionConfig,
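For a concrete sense of what the rewritten `_create_nodes_query` sends to the driver, here is the shape of one `$nodes` entry. The node values are hypothetical; the transformation (main label stripped from `labels` since it is baked into the MERGE clause, `embedding` popped out of `properties` into `vector` so `db.create.setNodeVectorProperty` can set it) follows the parameters block above:

```python
# Hypothetical chunk node, grouped under its main label "Chunk".
properties = {"text": "some chunk text", "embedding": [0.1, 0.2, 0.3]}

entry = {
    "id": "c0ffee00-0000-4000-8000-000000000001",
    "labels": [],  # extra labels only; the main label is already in the MERGE clause
    "vector": properties.pop("embedding", None),  # moved out of properties
    "properties": properties,  # no longer contains "embedding"
}

assert entry["vector"] == [0.1, 0.2, 0.3]
assert "embedding" not in entry["properties"]
```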
unstructured_ingest/v2/processes/embedder.py
CHANGED
@@ -92,18 +92,20 @@ class EmbedderConfig(BaseModel):

         return OctoAIEmbeddingEncoder(config=OctoAiEmbeddingConfig.model_validate(embedding_kwargs))

-    def get_bedrock_embedder(self) -> "BaseEmbeddingEncoder":
+    def get_bedrock_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEncoder":
         from unstructured_ingest.embed.bedrock import (
             BedrockEmbeddingConfig,
             BedrockEmbeddingEncoder,
         )

+        embedding_kwargs = embedding_kwargs | {
+            "aws_access_key_id": self.embedding_aws_access_key_id,
+            "aws_secret_access_key": self.embedding_aws_secret_access_key.get_secret_value(),
+            "region_name": self.embedding_aws_region,
+        }
+
         return BedrockEmbeddingEncoder(
-            config=BedrockEmbeddingConfig(
-                aws_access_key_id=self.embedding_aws_access_key_id,
-                aws_secret_access_key=self.embedding_aws_secret_access_key.get_secret_value(),
-                region_name=self.embedding_aws_region,
-            )
+            config=BedrockEmbeddingConfig.model_validate(embedding_kwargs)
         )

     def get_vertexai_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEncoder":

@@ -163,7 +165,7 @@ class EmbedderConfig(BaseModel):
             return self.get_octoai_embedder(embedding_kwargs=kwargs)

         if self.embedding_provider == "bedrock":
-            return self.get_bedrock_embedder()
+            return self.get_bedrock_embedder(embedding_kwargs=kwargs)

         if self.embedding_provider == "vertexai":
             return self.get_vertexai_embedder(embedding_kwargs=kwargs)
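With this change, Bedrock is configured like the other providers: arbitrary `embedding_kwargs` pass through, and the AWS settings from `EmbedderConfig` are merged on top before `BedrockEmbeddingConfig.model_validate` runs. A hedged usage sketch; the credential values are placeholders, and it assumes `get_embedder()` is the dispatching method containing the second hunk:

```python
from unstructured_ingest.v2.processes.embedder import EmbedderConfig

config = EmbedderConfig(
    embedding_provider="bedrock",
    embedding_aws_access_key_id="AKIA...",  # placeholder
    embedding_aws_secret_access_key="...",  # placeholder
    embedding_aws_region="us-west-2",
)
# Routes to get_bedrock_embedder(embedding_kwargs=kwargs); the kwargs dict now carries
# aws_access_key_id, aws_secret_access_key, and region_name into BedrockEmbeddingConfig.
embedder = config.get_embedder()
```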