graphiti-core 0.22.0rc5__py3-none-any.whl → 0.22.1rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/driver/driver.py +5 -7
- graphiti_core/driver/falkordb_driver.py +54 -3
- graphiti_core/driver/graph_operations/__init__.py +0 -0
- graphiti_core/driver/graph_operations/graph_operations.py +195 -0
- graphiti_core/driver/neo4j_driver.py +9 -0
- graphiti_core/driver/search_interface/__init__.py +0 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +11 -34
- graphiti_core/models/edges/edge_db_queries.py +1 -0
- graphiti_core/models/nodes/node_db_queries.py +1 -0
- graphiti_core/nodes.py +26 -99
- graphiti_core/search/search_filters.py +0 -38
- graphiti_core/search/search_utils.py +84 -220
- graphiti_core/utils/bulk_utils.py +14 -28
- graphiti_core/utils/maintenance/edge_operations.py +20 -15
- graphiti_core/utils/maintenance/graph_data_operations.py +6 -25
- {graphiti_core-0.22.0rc5.dist-info → graphiti_core-0.22.1rc1.dist-info}/METADATA +36 -3
- {graphiti_core-0.22.0rc5.dist-info → graphiti_core-0.22.1rc1.dist-info}/RECORD +20 -16
- {graphiti_core-0.22.0rc5.dist-info → graphiti_core-0.22.1rc1.dist-info}/WHEEL +0 -0
- {graphiti_core-0.22.0rc5.dist-info → graphiti_core-0.22.1rc1.dist-info}/licenses/LICENSE +0 -0
graphiti_core/driver/driver.py
CHANGED
|
@@ -24,6 +24,9 @@ from typing import Any
|
|
|
24
24
|
|
|
25
25
|
from dotenv import load_dotenv
|
|
26
26
|
|
|
27
|
+
from graphiti_core.driver.graph_operations.graph_operations import GraphOperationsInterface
|
|
28
|
+
from graphiti_core.driver.search_interface.search_interface import SearchInterface
|
|
29
|
+
|
|
27
30
|
logger = logging.getLogger(__name__)
|
|
28
31
|
|
|
29
32
|
DEFAULT_SIZE = 10
|
|
@@ -73,7 +76,8 @@ class GraphDriver(ABC):
|
|
|
73
76
|
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
|
|
74
77
|
)
|
|
75
78
|
_database: str
|
|
76
|
-
|
|
79
|
+
search_interface: SearchInterface | None = None
|
|
80
|
+
graph_operations_interface: GraphOperationsInterface | None = None
|
|
77
81
|
|
|
78
82
|
@abstractmethod
|
|
79
83
|
def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
|
|
@@ -109,9 +113,3 @@ class GraphDriver(ABC):
|
|
|
109
113
|
Only implemented by providers that need custom fulltext query building.
|
|
110
114
|
"""
|
|
111
115
|
raise NotImplementedError(f'build_fulltext_query not implemented for {self.provider}')
|
|
112
|
-
|
|
113
|
-
async def save_to_aoss(self, name: str, data: list[dict]) -> int:
|
|
114
|
-
return 0
|
|
115
|
-
|
|
116
|
-
async def clear_aoss_indices(self):
|
|
117
|
-
return 1
|
|
@@ -14,6 +14,8 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
+
import asyncio
|
|
18
|
+
import datetime
|
|
17
19
|
import logging
|
|
18
20
|
from typing import TYPE_CHECKING, Any
|
|
19
21
|
|
|
@@ -191,9 +193,36 @@ class FalkorDriver(GraphDriver):
|
|
|
191
193
|
await self.client.connection.close()
|
|
192
194
|
|
|
193
195
|
async def delete_all_indexes(self) -> None:
    """Drop every RANGE and FULLTEXT index reported by ``CALL db.indexes()``.

    Each returned record is read for its ``label``, ``entitytype`` and
    ``types`` fields. ``types`` is iterated as a mapping of field name to the
    index type(s) registered on that field. One DROP statement is built per
    index and all of them are awaited together.

    Returns immediately when the index listing comes back empty/falsy.
    """
    result = await self.execute_query('CALL db.indexes()')
    if not result:
        return

    records, _, _ = result
    drop_tasks = []

    for record in records:
        label = record['label']
        entity_type = record['entitytype']

        for field_name, index_type in record['types'].items():
            # Two independent checks rather than if/elif: a field that carries
            # both a RANGE and a FULLTEXT index must have both of them dropped.
            if 'RANGE' in index_type:
                drop_tasks.append(self.execute_query(f'DROP INDEX ON :{label}({field_name})'))

            if 'FULLTEXT' in index_type:
                if entity_type == 'NODE':
                    drop_tasks.append(
                        self.execute_query(
                            f'DROP FULLTEXT INDEX FOR (n:{label}) ON (n.{field_name})'
                        )
                    )
                elif entity_type == 'RELATIONSHIP':
                    drop_tasks.append(
                        self.execute_query(
                            f'DROP FULLTEXT INDEX FOR ()-[e:{label}]-() ON (e.{field_name})'
                        )
                    )

    if drop_tasks:
        await asyncio.gather(*drop_tasks)
|
|
197
226
|
|
|
198
227
|
def clone(self, database: str) -> 'GraphDriver':
|
|
199
228
|
"""
|
|
@@ -204,6 +233,28 @@ class FalkorDriver(GraphDriver):
|
|
|
204
233
|
|
|
205
234
|
return cloned
|
|
206
235
|
|
|
236
|
+
async def health_check(self) -> None:
    """Check FalkorDB connectivity by running a simple query.

    Returns ``None`` when the query succeeds.

    Raises:
        Exception: whatever ``execute_query`` raised, re-raised unchanged
            after the failure has been logged.
    """
    try:
        await self.execute_query('MATCH (n) RETURN 1 LIMIT 1')
    except Exception as e:
        # Report through logging instead of print() so the message follows the
        # application's logging configuration rather than going to stdout.
        logging.getLogger(__name__).error('FalkorDB health check failed: %s', e)
        raise
|
|
244
|
+
|
|
245
|
+
@staticmethod
def convert_datetimes_to_strings(obj):
    """Return ``obj`` with every ``datetime`` replaced by its ISO 8601 string.

    Dicts, lists and tuples are rebuilt recursively (dict keys are left
    untouched); any other value is returned unchanged.
    """
    # Import the class locally under an alias: this module binds the name
    # ``datetime`` to the *module* (``import datetime``), and passing a module
    # as the second argument of isinstance() raises TypeError.
    from datetime import datetime as _datetime

    def _convert(value):
        if isinstance(value, dict):
            return {k: _convert(v) for k, v in value.items()}
        if isinstance(value, list):
            return [_convert(item) for item in value]
        if isinstance(value, tuple):
            return tuple(_convert(item) for item in value)
        if isinstance(value, _datetime):
            return value.isoformat()
        return value

    return _convert(obj)
|
|
257
|
+
|
|
207
258
|
def sanitize(self, query: str) -> str:
|
|
208
259
|
"""
|
|
209
260
|
Replace FalkorDB special characters with whitespace.
|
|
File without changes
|
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
from pydantic import BaseModel
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class GraphOperationsInterface(BaseModel):
    """
    Interface for updating graph mutation behavior.

    Every method defined here raises ``NotImplementedError``; an implementation
    is expected to override the operations it supports.
    """

    # -----------------
    # Node: Save/Delete
    # -----------------

    async def node_save(self, node: Any, driver: Any) -> None:
        """Persist (create or update) a single node."""
        raise NotImplementedError

    async def node_delete(self, node: Any, driver: Any) -> None:
        """Delete a single node."""
        raise NotImplementedError

    async def node_save_bulk(
        self,
        _cls: Any,  # kept for parity; callers won't pass it
        driver: Any,
        transaction: Any,
        nodes: list[Any],
        batch_size: int = 100,
    ) -> None:
        """Persist (create or update) many nodes in batches."""
        raise NotImplementedError

    async def node_delete_by_group_id(
        self,
        _cls: Any,
        driver: Any,
        group_id: str,
        batch_size: int = 100,
    ) -> None:
        """Delete the nodes belonging to ``group_id``."""
        raise NotImplementedError

    async def node_delete_by_uuids(
        self,
        _cls: Any,
        driver: Any,
        uuids: list[str],
        group_id: str | None = None,
        batch_size: int = 100,
    ) -> None:
        """Delete the nodes identified by ``uuids``."""
        raise NotImplementedError

    # --------------------------
    # Node: Embeddings (load)
    # --------------------------

    async def node_load_embeddings(self, node: Any, driver: Any) -> None:
        """
        Load embedding vectors for a single node into the instance (e.g., set node.embedding or similar).
        """
        raise NotImplementedError

    async def node_load_embeddings_bulk(
        self,
        _cls: Any,
        driver: Any,
        transaction: Any,
        nodes: list[Any],
        batch_size: int = 100,
    ) -> None:
        """
        Load embedding vectors for many nodes in batches. Mutates the provided node instances.
        """
        raise NotImplementedError

    # --------------------------
    # EpisodicNode: Save/Delete
    # --------------------------

    async def episodic_node_save(self, node: Any, driver: Any) -> None:
        """Persist (create or update) a single episodic node."""
        raise NotImplementedError

    async def episodic_node_delete(self, node: Any, driver: Any) -> None:
        """Delete a single episodic node."""
        raise NotImplementedError

    async def episodic_node_save_bulk(
        self,
        _cls: Any,
        driver: Any,
        transaction: Any,
        nodes: list[Any],
        batch_size: int = 100,
    ) -> None:
        """Persist (create or update) many episodic nodes in batches."""
        raise NotImplementedError

    async def episodic_edge_save_bulk(
        self,
        _cls: Any,
        driver: Any,
        transaction: Any,
        episodic_edges: list[Any],
        batch_size: int = 100,
    ) -> None:
        """Persist (create or update) many episodic edges in batches."""
        raise NotImplementedError

    async def episodic_node_delete_by_group_id(
        self,
        _cls: Any,
        driver: Any,
        group_id: str,
        batch_size: int = 100,
    ) -> None:
        """Delete the episodic nodes belonging to ``group_id``."""
        raise NotImplementedError

    async def episodic_node_delete_by_uuids(
        self,
        _cls: Any,
        driver: Any,
        uuids: list[str],
        group_id: str | None = None,
        batch_size: int = 100,
    ) -> None:
        """Delete the episodic nodes identified by ``uuids``."""
        raise NotImplementedError

    # -----------------
    # Edge: Save/Delete
    # -----------------

    async def edge_save(self, edge: Any, driver: Any) -> None:
        """Persist (create or update) a single edge."""
        raise NotImplementedError

    async def edge_delete(self, edge: Any, driver: Any) -> None:
        """Delete a single edge."""
        raise NotImplementedError

    async def edge_save_bulk(
        self,
        _cls: Any,
        driver: Any,
        transaction: Any,
        edges: list[Any],
        batch_size: int = 100,
    ) -> None:
        """Persist (create or update) many edges in batches."""
        raise NotImplementedError

    async def edge_delete_by_uuids(
        self,
        _cls: Any,
        driver: Any,
        uuids: list[str],
        group_id: str | None = None,
    ) -> None:
        """Delete the edges identified by ``uuids``."""
        raise NotImplementedError

    # -----------------
    # Edge: Embeddings (load)
    # -----------------

    async def edge_load_embeddings(self, edge: Any, driver: Any) -> None:
        """
        Load embedding vectors for a single edge into the instance (e.g., set edge.embedding or similar).
        """
        raise NotImplementedError

    async def edge_load_embeddings_bulk(
        self,
        _cls: Any,
        driver: Any,
        transaction: Any,
        edges: list[Any],
        batch_size: int = 100,
    ) -> None:
        """
        Load embedding vectors for many edges in batches. Mutates the provided edge instances.
        """
        raise NotImplementedError
|
|
@@ -72,3 +72,12 @@ class Neo4jDriver(GraphDriver):
|
|
|
72
72
|
return self.client.execute_query(
|
|
73
73
|
'CALL db.indexes() YIELD name DROP INDEX name',
|
|
74
74
|
)
|
|
75
|
+
|
|
76
|
+
async def health_check(self) -> None:
    """Check Neo4j connectivity by running the driver's verify_connectivity method.

    Returns ``None`` when ``self.client.verify_connectivity()`` completes.

    Raises:
        Exception: whatever ``verify_connectivity`` raised, re-raised
            unchanged after the failure has been logged.
    """
    # Local import keeps this method self-contained.
    import logging

    try:
        await self.client.verify_connectivity()
    except Exception as e:
        # Report through logging instead of print() so the message follows the
        # application's logging configuration rather than going to stdout.
        logging.getLogger(__name__).error('Neo4j health check failed: %s', e)
        raise
|
|
File without changes
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright 2024, Zep Software, Inc.
|
|
3
|
+
|
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
you may not use this file except in compliance with the License.
|
|
6
|
+
You may obtain a copy of the License at
|
|
7
|
+
|
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
|
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
See the License for the specific language governing permissions and
|
|
14
|
+
limitations under the License.
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from typing import Any
|
|
18
|
+
|
|
19
|
+
from pydantic import BaseModel
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class SearchInterface(BaseModel):
    """
    This is an interface for implementing custom search logic

    Every method defined here raises ``NotImplementedError``; an implementation
    is expected to override the searches and filter builders it supports.
    """

    async def edge_fulltext_search(
        self,
        driver: Any,
        query: str,
        search_filter: Any,
        group_ids: list[str] | None = None,
        limit: int = 100,
    ) -> list[Any]:
        """Fulltext search over edges for ``query``; returns a list of results."""
        raise NotImplementedError

    async def edge_similarity_search(
        self,
        driver: Any,
        search_vector: list[float],
        source_node_uuid: str | None,
        target_node_uuid: str | None,
        search_filter: Any,
        group_ids: list[str] | None = None,
        limit: int = 100,
        min_score: float = 0.7,
    ) -> list[Any]:
        """Similarity search over edges for ``search_vector``; returns a list of results."""
        raise NotImplementedError

    async def node_fulltext_search(
        self,
        driver: Any,
        query: str,
        search_filter: Any,
        group_ids: list[str] | None = None,
        limit: int = 100,
    ) -> list[Any]:
        """Fulltext search over nodes for ``query``; returns a list of results."""
        raise NotImplementedError

    async def node_similarity_search(
        self,
        driver: Any,
        search_vector: list[float],
        search_filter: Any,
        group_ids: list[str] | None = None,
        limit: int = 100,
        min_score: float = 0.7,
    ) -> list[Any]:
        """Similarity search over nodes for ``search_vector``; returns a list of results."""
        raise NotImplementedError

    async def episode_fulltext_search(
        self,
        driver: Any,
        query: str,
        search_filter: Any,  # kept for parity even if unused in your impl
        group_ids: list[str] | None = None,
        limit: int = 100,
    ) -> list[Any]:
        """Fulltext search over episodes for ``query``; returns a list of results."""
        raise NotImplementedError

    # ---------- SEARCH FILTERS (sync) ----------
    def build_node_search_filters(self, search_filters: Any) -> Any:
        """Build the node-side filter representation from ``search_filters``."""
        raise NotImplementedError

    def build_edge_search_filters(self, search_filters: Any) -> Any:
        """Build the edge-side filter representation from ``search_filters``."""
        raise NotImplementedError

    class Config:
        # Pydantic model configuration: allow fields of arbitrary types.
        arbitrary_types_allowed = True
|
graphiti_core/edges.py
CHANGED
|
@@ -25,7 +25,7 @@ from uuid import uuid4
|
|
|
25
25
|
from pydantic import BaseModel, Field
|
|
26
26
|
from typing_extensions import LiteralString
|
|
27
27
|
|
|
28
|
-
from graphiti_core.driver.driver import
|
|
28
|
+
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
|
29
29
|
from graphiti_core.embedder import EmbedderClient
|
|
30
30
|
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
|
|
31
31
|
from graphiti_core.helpers import parse_db_date
|
|
@@ -53,6 +53,9 @@ class Edge(BaseModel, ABC):
|
|
|
53
53
|
async def save(self, driver: GraphDriver): ...
|
|
54
54
|
|
|
55
55
|
async def delete(self, driver: GraphDriver):
|
|
56
|
+
if driver.graph_operations_interface:
|
|
57
|
+
return await driver.graph_operations_interface.edge_delete(self, driver)
|
|
58
|
+
|
|
56
59
|
if driver.provider == GraphProvider.KUZU:
|
|
57
60
|
await driver.execute_query(
|
|
58
61
|
"""
|
|
@@ -77,17 +80,13 @@ class Edge(BaseModel, ABC):
|
|
|
77
80
|
uuid=self.uuid,
|
|
78
81
|
)
|
|
79
82
|
|
|
80
|
-
if driver.aoss_client:
|
|
81
|
-
await driver.aoss_client.delete(
|
|
82
|
-
index=ENTITY_EDGE_INDEX_NAME,
|
|
83
|
-
id=self.uuid,
|
|
84
|
-
params={'routing': self.group_id},
|
|
85
|
-
)
|
|
86
|
-
|
|
87
83
|
logger.debug(f'Deleted Edge: {self.uuid}')
|
|
88
84
|
|
|
89
85
|
@classmethod
|
|
90
86
|
async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
87
|
+
if driver.graph_operations_interface:
|
|
88
|
+
return await driver.graph_operations_interface.edge_delete_by_uuids(cls, driver, uuids)
|
|
89
|
+
|
|
91
90
|
if driver.provider == GraphProvider.KUZU:
|
|
92
91
|
await driver.execute_query(
|
|
93
92
|
"""
|
|
@@ -115,12 +114,6 @@ class Edge(BaseModel, ABC):
|
|
|
115
114
|
uuids=uuids,
|
|
116
115
|
)
|
|
117
116
|
|
|
118
|
-
if driver.aoss_client:
|
|
119
|
-
await driver.aoss_client.delete_by_query(
|
|
120
|
-
index=ENTITY_EDGE_INDEX_NAME,
|
|
121
|
-
body={'query': {'terms': {'uuid': uuids}}},
|
|
122
|
-
)
|
|
123
|
-
|
|
124
117
|
logger.debug(f'Deleted Edges: {uuids}')
|
|
125
118
|
|
|
126
119
|
def __hash__(self):
|
|
@@ -258,6 +251,9 @@ class EntityEdge(Edge):
|
|
|
258
251
|
return self.fact_embedding
|
|
259
252
|
|
|
260
253
|
async def load_fact_embedding(self, driver: GraphDriver):
|
|
254
|
+
if driver.graph_operations_interface:
|
|
255
|
+
return await driver.graph_operations_interface.edge_load_embeddings(self, driver)
|
|
256
|
+
|
|
261
257
|
query = """
|
|
262
258
|
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
263
259
|
RETURN e.fact_embedding AS fact_embedding
|
|
@@ -268,21 +264,6 @@ class EntityEdge(Edge):
|
|
|
268
264
|
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
|
|
269
265
|
RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
|
|
270
266
|
"""
|
|
271
|
-
elif driver.aoss_client:
|
|
272
|
-
resp = await driver.aoss_client.search(
|
|
273
|
-
body={
|
|
274
|
-
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
|
|
275
|
-
'size': 1,
|
|
276
|
-
},
|
|
277
|
-
index=ENTITY_EDGE_INDEX_NAME,
|
|
278
|
-
params={'routing': self.group_id},
|
|
279
|
-
)
|
|
280
|
-
|
|
281
|
-
if resp['hits']['hits']:
|
|
282
|
-
self.fact_embedding = resp['hits']['hits'][0]['_source']['fact_embedding']
|
|
283
|
-
return
|
|
284
|
-
else:
|
|
285
|
-
raise EdgeNotFoundError(self.uuid)
|
|
286
267
|
|
|
287
268
|
if driver.provider == GraphProvider.KUZU:
|
|
288
269
|
query = """
|
|
@@ -320,15 +301,11 @@ class EntityEdge(Edge):
|
|
|
320
301
|
if driver.provider == GraphProvider.KUZU:
|
|
321
302
|
edge_data['attributes'] = json.dumps(self.attributes)
|
|
322
303
|
result = await driver.execute_query(
|
|
323
|
-
get_entity_edge_save_query(driver.provider
|
|
304
|
+
get_entity_edge_save_query(driver.provider),
|
|
324
305
|
**edge_data,
|
|
325
306
|
)
|
|
326
307
|
else:
|
|
327
308
|
edge_data.update(self.attributes or {})
|
|
328
|
-
|
|
329
|
-
if driver.aoss_client:
|
|
330
|
-
await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, [edge_data]) # pyright: ignore reportAttributeAccessIssue
|
|
331
|
-
|
|
332
309
|
result = await driver.execute_query(
|
|
333
310
|
get_entity_edge_save_query(driver.provider),
|
|
334
311
|
edge_data=edge_data,
|
|
@@ -68,6 +68,7 @@ def get_entity_edge_save_query(provider: GraphProvider, has_aoss: bool = False)
|
|
|
68
68
|
MATCH (target:Entity {uuid: $edge_data.target_uuid})
|
|
69
69
|
MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
|
|
70
70
|
SET e = $edge_data
|
|
71
|
+
SET e.fact_embedding = vecf32($edge_data.fact_embedding)
|
|
71
72
|
RETURN e.uuid AS uuid
|
|
72
73
|
"""
|
|
73
74
|
case GraphProvider.NEPTUNE:
|
|
@@ -133,6 +133,7 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str, has_aoss: b
|
|
|
133
133
|
MERGE (n:Entity {{uuid: $entity_data.uuid}})
|
|
134
134
|
SET n:{labels}
|
|
135
135
|
SET n = $entity_data
|
|
136
|
+
SET n.name_embedding = vecf32($entity_data.name_embedding)
|
|
136
137
|
RETURN n.uuid AS uuid
|
|
137
138
|
"""
|
|
138
139
|
case GraphProvider.KUZU:
|