graphiti-core 0.17.4__py3-none-any.whl → 0.24.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/cross_encoder/gemini_reranker_client.py +1 -1
- graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
- graphiti_core/decorators.py +110 -0
- graphiti_core/driver/driver.py +62 -2
- graphiti_core/driver/falkordb_driver.py +215 -23
- graphiti_core/driver/graph_operations/graph_operations.py +191 -0
- graphiti_core/driver/kuzu_driver.py +182 -0
- graphiti_core/driver/neo4j_driver.py +61 -8
- graphiti_core/driver/neptune_driver.py +305 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +264 -132
- graphiti_core/embedder/azure_openai.py +10 -3
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/graph_queries.py +114 -101
- graphiti_core/graphiti.py +582 -255
- graphiti_core/graphiti_types.py +2 -0
- graphiti_core/helpers.py +21 -14
- graphiti_core/llm_client/anthropic_client.py +142 -52
- graphiti_core/llm_client/azure_openai_client.py +57 -19
- graphiti_core/llm_client/client.py +83 -21
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/gemini_client.py +75 -57
- graphiti_core/llm_client/openai_base_client.py +94 -50
- graphiti_core/llm_client/openai_client.py +28 -8
- graphiti_core/llm_client/openai_generic_client.py +91 -56
- graphiti_core/models/edges/edge_db_queries.py +259 -35
- graphiti_core/models/nodes/node_db_queries.py +311 -32
- graphiti_core/nodes.py +388 -164
- graphiti_core/prompts/dedupe_edges.py +42 -31
- graphiti_core/prompts/dedupe_nodes.py +56 -39
- graphiti_core/prompts/eval.py +4 -4
- graphiti_core/prompts/extract_edges.py +23 -14
- graphiti_core/prompts/extract_nodes.py +73 -32
- graphiti_core/prompts/prompt_helpers.py +39 -0
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +23 -25
- graphiti_core/search/search.py +154 -74
- graphiti_core/search/search_config.py +39 -4
- graphiti_core/search/search_filters.py +109 -31
- graphiti_core/search/search_helpers.py +5 -6
- graphiti_core/search/search_utils.py +1360 -473
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +216 -90
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +62 -38
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +286 -126
- graphiti_core/utils/maintenance/graph_data_operations.py +44 -74
- graphiti_core/utils/maintenance/node_operations.py +320 -158
- graphiti_core/utils/maintenance/temporal_operations.py +11 -3
- graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
- graphiti_core/utils/text_utils.py +53 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/METADATA +221 -87
- graphiti_core-0.24.3.dist-info/RECORD +86 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
- graphiti_core-0.17.4.dist-info/RECORD +0 -77
- /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/licenses/LICENSE +0 -0
graphiti_core/nodes.py
CHANGED
|
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
+
import json
|
|
17
18
|
import logging
|
|
18
19
|
from abc import ABC, abstractmethod
|
|
19
20
|
from datetime import datetime
|
|
@@ -25,30 +26,27 @@ from uuid import uuid4
|
|
|
25
26
|
from pydantic import BaseModel, Field
|
|
26
27
|
from typing_extensions import LiteralString
|
|
27
28
|
|
|
28
|
-
from graphiti_core.driver.driver import
|
|
29
|
+
from graphiti_core.driver.driver import (
|
|
30
|
+
GraphDriver,
|
|
31
|
+
GraphProvider,
|
|
32
|
+
)
|
|
29
33
|
from graphiti_core.embedder import EmbedderClient
|
|
30
34
|
from graphiti_core.errors import NodeNotFoundError
|
|
31
35
|
from graphiti_core.helpers import parse_db_date
|
|
32
36
|
from graphiti_core.models.nodes.node_db_queries import (
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
37
|
+
COMMUNITY_NODE_RETURN,
|
|
38
|
+
COMMUNITY_NODE_RETURN_NEPTUNE,
|
|
39
|
+
EPISODIC_NODE_RETURN,
|
|
40
|
+
EPISODIC_NODE_RETURN_NEPTUNE,
|
|
41
|
+
get_community_node_save_query,
|
|
42
|
+
get_entity_node_return_query,
|
|
43
|
+
get_entity_node_save_query,
|
|
44
|
+
get_episode_node_save_query,
|
|
36
45
|
)
|
|
37
46
|
from graphiti_core.utils.datetime_utils import utc_now
|
|
38
47
|
|
|
39
48
|
logger = logging.getLogger(__name__)
|
|
40
49
|
|
|
41
|
-
ENTITY_NODE_RETURN: LiteralString = """
|
|
42
|
-
RETURN
|
|
43
|
-
n.uuid As uuid,
|
|
44
|
-
n.name AS name,
|
|
45
|
-
n.group_id AS group_id,
|
|
46
|
-
n.created_at AS created_at,
|
|
47
|
-
n.summary AS summary,
|
|
48
|
-
labels(n) AS labels,
|
|
49
|
-
properties(n) AS attributes
|
|
50
|
-
"""
|
|
51
|
-
|
|
52
50
|
|
|
53
51
|
class EpisodeType(Enum):
|
|
54
52
|
"""
|
|
@@ -97,18 +95,60 @@ class Node(BaseModel, ABC):
|
|
|
97
95
|
async def save(self, driver: GraphDriver): ...
|
|
98
96
|
|
|
99
97
|
async def delete(self, driver: GraphDriver):
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
98
|
+
if driver.graph_operations_interface:
|
|
99
|
+
return await driver.graph_operations_interface.node_delete(self, driver)
|
|
100
|
+
|
|
101
|
+
match driver.provider:
|
|
102
|
+
case GraphProvider.NEO4J:
|
|
103
|
+
records, _, _ = await driver.execute_query(
|
|
104
|
+
"""
|
|
105
|
+
MATCH (n {uuid: $uuid})
|
|
106
|
+
WHERE n:Entity OR n:Episodic OR n:Community
|
|
107
|
+
OPTIONAL MATCH (n)-[r]-()
|
|
108
|
+
WITH collect(r.uuid) AS edge_uuids, n
|
|
109
|
+
DETACH DELETE n
|
|
110
|
+
RETURN edge_uuids
|
|
111
|
+
""",
|
|
112
|
+
uuid=self.uuid,
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
case GraphProvider.KUZU:
|
|
116
|
+
for label in ['Episodic', 'Community']:
|
|
117
|
+
await driver.execute_query(
|
|
118
|
+
f"""
|
|
119
|
+
MATCH (n:{label} {{uuid: $uuid}})
|
|
120
|
+
DETACH DELETE n
|
|
121
|
+
""",
|
|
122
|
+
uuid=self.uuid,
|
|
123
|
+
)
|
|
124
|
+
# Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
|
|
125
|
+
# Explicitly delete the "edge" nodes first, then the entity node.
|
|
126
|
+
await driver.execute_query(
|
|
127
|
+
"""
|
|
128
|
+
MATCH (n:Entity {uuid: $uuid})-[:RELATES_TO]->(e:RelatesToNode_)
|
|
129
|
+
DETACH DELETE e
|
|
130
|
+
""",
|
|
131
|
+
uuid=self.uuid,
|
|
132
|
+
)
|
|
133
|
+
await driver.execute_query(
|
|
134
|
+
"""
|
|
135
|
+
MATCH (n:Entity {uuid: $uuid})
|
|
136
|
+
DETACH DELETE n
|
|
137
|
+
""",
|
|
138
|
+
uuid=self.uuid,
|
|
139
|
+
)
|
|
140
|
+
case _: # FalkorDB, Neptune
|
|
141
|
+
for label in ['Entity', 'Episodic', 'Community']:
|
|
142
|
+
await driver.execute_query(
|
|
143
|
+
f"""
|
|
144
|
+
MATCH (n:{label} {{uuid: $uuid}})
|
|
145
|
+
DETACH DELETE n
|
|
146
|
+
""",
|
|
147
|
+
uuid=self.uuid,
|
|
148
|
+
)
|
|
107
149
|
|
|
108
150
|
logger.debug(f'Deleted Node: {self.uuid}')
|
|
109
151
|
|
|
110
|
-
return result
|
|
111
|
-
|
|
112
152
|
def __hash__(self):
|
|
113
153
|
return hash(self.uuid)
|
|
114
154
|
|
|
@@ -118,16 +158,132 @@ class Node(BaseModel, ABC):
|
|
|
118
158
|
return False
|
|
119
159
|
|
|
120
160
|
@classmethod
|
|
121
|
-
async def delete_by_group_id(cls, driver: GraphDriver, group_id: str):
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
161
|
+
async def delete_by_group_id(cls, driver: GraphDriver, group_id: str, batch_size: int = 100):
|
|
162
|
+
if driver.graph_operations_interface:
|
|
163
|
+
return await driver.graph_operations_interface.node_delete_by_group_id(
|
|
164
|
+
cls, driver, group_id, batch_size
|
|
165
|
+
)
|
|
166
|
+
|
|
167
|
+
match driver.provider:
|
|
168
|
+
case GraphProvider.NEO4J:
|
|
169
|
+
async with driver.session() as session:
|
|
170
|
+
await session.run(
|
|
171
|
+
"""
|
|
172
|
+
MATCH (n:Entity|Episodic|Community {group_id: $group_id})
|
|
173
|
+
CALL (n) {
|
|
174
|
+
DETACH DELETE n
|
|
175
|
+
} IN TRANSACTIONS OF $batch_size ROWS
|
|
176
|
+
""",
|
|
177
|
+
group_id=group_id,
|
|
178
|
+
batch_size=batch_size,
|
|
179
|
+
)
|
|
180
|
+
|
|
181
|
+
case GraphProvider.KUZU:
|
|
182
|
+
for label in ['Episodic', 'Community']:
|
|
183
|
+
await driver.execute_query(
|
|
184
|
+
f"""
|
|
185
|
+
MATCH (n:{label} {{group_id: $group_id}})
|
|
186
|
+
DETACH DELETE n
|
|
187
|
+
""",
|
|
188
|
+
group_id=group_id,
|
|
189
|
+
)
|
|
190
|
+
# Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
|
|
191
|
+
# Explicitly delete the "edge" nodes first, then the entity node.
|
|
192
|
+
await driver.execute_query(
|
|
193
|
+
"""
|
|
194
|
+
MATCH (n:Entity {group_id: $group_id})-[:RELATES_TO]->(e:RelatesToNode_)
|
|
195
|
+
DETACH DELETE e
|
|
196
|
+
""",
|
|
197
|
+
group_id=group_id,
|
|
198
|
+
)
|
|
199
|
+
await driver.execute_query(
|
|
200
|
+
"""
|
|
201
|
+
MATCH (n:Entity {group_id: $group_id})
|
|
202
|
+
DETACH DELETE n
|
|
203
|
+
""",
|
|
204
|
+
group_id=group_id,
|
|
205
|
+
)
|
|
206
|
+
case _: # FalkorDB, Neptune
|
|
207
|
+
for label in ['Entity', 'Episodic', 'Community']:
|
|
208
|
+
await driver.execute_query(
|
|
209
|
+
f"""
|
|
210
|
+
MATCH (n:{label} {{group_id: $group_id}})
|
|
211
|
+
DETACH DELETE n
|
|
212
|
+
""",
|
|
213
|
+
group_id=group_id,
|
|
214
|
+
)
|
|
129
215
|
|
|
130
|
-
|
|
216
|
+
@classmethod
|
|
217
|
+
async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
|
|
218
|
+
if driver.graph_operations_interface:
|
|
219
|
+
return await driver.graph_operations_interface.node_delete_by_uuids(
|
|
220
|
+
cls, driver, uuids, group_id=None, batch_size=batch_size
|
|
221
|
+
)
|
|
222
|
+
|
|
223
|
+
match driver.provider:
|
|
224
|
+
case GraphProvider.FALKORDB:
|
|
225
|
+
for label in ['Entity', 'Episodic', 'Community']:
|
|
226
|
+
await driver.execute_query(
|
|
227
|
+
f"""
|
|
228
|
+
MATCH (n:{label})
|
|
229
|
+
WHERE n.uuid IN $uuids
|
|
230
|
+
DETACH DELETE n
|
|
231
|
+
""",
|
|
232
|
+
uuids=uuids,
|
|
233
|
+
)
|
|
234
|
+
case GraphProvider.KUZU:
|
|
235
|
+
for label in ['Episodic', 'Community']:
|
|
236
|
+
await driver.execute_query(
|
|
237
|
+
f"""
|
|
238
|
+
MATCH (n:{label})
|
|
239
|
+
WHERE n.uuid IN $uuids
|
|
240
|
+
DETACH DELETE n
|
|
241
|
+
""",
|
|
242
|
+
uuids=uuids,
|
|
243
|
+
)
|
|
244
|
+
# Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
|
|
245
|
+
# Explicitly delete the "edge" nodes first, then the entity node.
|
|
246
|
+
await driver.execute_query(
|
|
247
|
+
"""
|
|
248
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)
|
|
249
|
+
WHERE n.uuid IN $uuids
|
|
250
|
+
DETACH DELETE e
|
|
251
|
+
""",
|
|
252
|
+
uuids=uuids,
|
|
253
|
+
)
|
|
254
|
+
await driver.execute_query(
|
|
255
|
+
"""
|
|
256
|
+
MATCH (n:Entity)
|
|
257
|
+
WHERE n.uuid IN $uuids
|
|
258
|
+
DETACH DELETE n
|
|
259
|
+
""",
|
|
260
|
+
uuids=uuids,
|
|
261
|
+
)
|
|
262
|
+
case _: # Neo4J, Neptune
|
|
263
|
+
async with driver.session() as session:
|
|
264
|
+
# Collect all edge UUIDs before deleting nodes
|
|
265
|
+
await session.run(
|
|
266
|
+
"""
|
|
267
|
+
MATCH (n:Entity|Episodic|Community)
|
|
268
|
+
WHERE n.uuid IN $uuids
|
|
269
|
+
MATCH (n)-[r]-()
|
|
270
|
+
RETURN collect(r.uuid) AS edge_uuids
|
|
271
|
+
""",
|
|
272
|
+
uuids=uuids,
|
|
273
|
+
)
|
|
274
|
+
|
|
275
|
+
# Now delete the nodes in batches
|
|
276
|
+
await session.run(
|
|
277
|
+
"""
|
|
278
|
+
MATCH (n:Entity|Episodic|Community)
|
|
279
|
+
WHERE n.uuid IN $uuids
|
|
280
|
+
CALL (n) {
|
|
281
|
+
DETACH DELETE n
|
|
282
|
+
} IN TRANSACTIONS OF $batch_size ROWS
|
|
283
|
+
""",
|
|
284
|
+
uuids=uuids,
|
|
285
|
+
batch_size=batch_size,
|
|
286
|
+
)
|
|
131
287
|
|
|
132
288
|
@classmethod
|
|
133
289
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
|
|
@@ -149,17 +305,23 @@ class EpisodicNode(Node):
|
|
|
149
305
|
)
|
|
150
306
|
|
|
151
307
|
async def save(self, driver: GraphDriver):
|
|
308
|
+
if driver.graph_operations_interface:
|
|
309
|
+
return await driver.graph_operations_interface.episodic_node_save(self, driver)
|
|
310
|
+
|
|
311
|
+
episode_args = {
|
|
312
|
+
'uuid': self.uuid,
|
|
313
|
+
'name': self.name,
|
|
314
|
+
'group_id': self.group_id,
|
|
315
|
+
'source_description': self.source_description,
|
|
316
|
+
'content': self.content,
|
|
317
|
+
'entity_edges': self.entity_edges,
|
|
318
|
+
'created_at': self.created_at,
|
|
319
|
+
'valid_at': self.valid_at,
|
|
320
|
+
'source': self.source.value,
|
|
321
|
+
}
|
|
322
|
+
|
|
152
323
|
result = await driver.execute_query(
|
|
153
|
-
|
|
154
|
-
uuid=self.uuid,
|
|
155
|
-
name=self.name,
|
|
156
|
-
group_id=self.group_id,
|
|
157
|
-
source_description=self.source_description,
|
|
158
|
-
content=self.content,
|
|
159
|
-
entity_edges=self.entity_edges,
|
|
160
|
-
created_at=self.created_at,
|
|
161
|
-
valid_at=self.valid_at,
|
|
162
|
-
source=self.source.value,
|
|
324
|
+
get_episode_node_save_query(driver.provider), **episode_args
|
|
163
325
|
)
|
|
164
326
|
|
|
165
327
|
logger.debug(f'Saved Node to Graph: {self.uuid}')
|
|
@@ -170,17 +332,14 @@ class EpisodicNode(Node):
|
|
|
170
332
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
171
333
|
records, _, _ = await driver.execute_query(
|
|
172
334
|
"""
|
|
173
|
-
|
|
174
|
-
RETURN
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
e.source AS source,
|
|
182
|
-
e.entity_edges AS entity_edges
|
|
183
|
-
""",
|
|
335
|
+
MATCH (e:Episodic {uuid: $uuid})
|
|
336
|
+
RETURN
|
|
337
|
+
"""
|
|
338
|
+
+ (
|
|
339
|
+
EPISODIC_NODE_RETURN_NEPTUNE
|
|
340
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
341
|
+
else EPISODIC_NODE_RETURN
|
|
342
|
+
),
|
|
184
343
|
uuid=uuid,
|
|
185
344
|
routing_='r',
|
|
186
345
|
)
|
|
@@ -196,18 +355,15 @@ class EpisodicNode(Node):
|
|
|
196
355
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
197
356
|
records, _, _ = await driver.execute_query(
|
|
198
357
|
"""
|
|
199
|
-
|
|
358
|
+
MATCH (e:Episodic)
|
|
359
|
+
WHERE e.uuid IN $uuids
|
|
200
360
|
RETURN DISTINCT
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
e.source_description AS source_description,
|
|
208
|
-
e.source AS source,
|
|
209
|
-
e.entity_edges AS entity_edges
|
|
210
|
-
""",
|
|
361
|
+
"""
|
|
362
|
+
+ (
|
|
363
|
+
EPISODIC_NODE_RETURN_NEPTUNE
|
|
364
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
365
|
+
else EPISODIC_NODE_RETURN
|
|
366
|
+
),
|
|
211
367
|
uuids=uuids,
|
|
212
368
|
routing_='r',
|
|
213
369
|
)
|
|
@@ -229,22 +385,21 @@ class EpisodicNode(Node):
|
|
|
229
385
|
|
|
230
386
|
records, _, _ = await driver.execute_query(
|
|
231
387
|
"""
|
|
232
|
-
|
|
233
|
-
|
|
388
|
+
MATCH (e:Episodic)
|
|
389
|
+
WHERE e.group_id IN $group_ids
|
|
390
|
+
"""
|
|
234
391
|
+ cursor_query
|
|
235
392
|
+ """
|
|
236
393
|
RETURN DISTINCT
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
ORDER BY e.uuid DESC
|
|
247
|
-
"""
|
|
394
|
+
"""
|
|
395
|
+
+ (
|
|
396
|
+
EPISODIC_NODE_RETURN_NEPTUNE
|
|
397
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
398
|
+
else EPISODIC_NODE_RETURN
|
|
399
|
+
)
|
|
400
|
+
+ """
|
|
401
|
+
ORDER BY uuid DESC
|
|
402
|
+
"""
|
|
248
403
|
+ limit_query,
|
|
249
404
|
group_ids=group_ids,
|
|
250
405
|
uuid=uuid_cursor,
|
|
@@ -260,18 +415,14 @@ class EpisodicNode(Node):
|
|
|
260
415
|
async def get_by_entity_node_uuid(cls, driver: GraphDriver, entity_node_uuid: str):
|
|
261
416
|
records, _, _ = await driver.execute_query(
|
|
262
417
|
"""
|
|
263
|
-
|
|
418
|
+
MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
|
|
264
419
|
RETURN DISTINCT
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
e.source_description AS source_description,
|
|
272
|
-
e.source AS source,
|
|
273
|
-
e.entity_edges AS entity_edges
|
|
274
|
-
""",
|
|
420
|
+
"""
|
|
421
|
+
+ (
|
|
422
|
+
EPISODIC_NODE_RETURN_NEPTUNE
|
|
423
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
424
|
+
else EPISODIC_NODE_RETURN
|
|
425
|
+
),
|
|
275
426
|
entity_node_uuid=entity_node_uuid,
|
|
276
427
|
routing_='r',
|
|
277
428
|
)
|
|
@@ -298,11 +449,25 @@ class EntityNode(Node):
|
|
|
298
449
|
return self.name_embedding
|
|
299
450
|
|
|
300
451
|
async def load_name_embedding(self, driver: GraphDriver):
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
452
|
+
if driver.graph_operations_interface:
|
|
453
|
+
return await driver.graph_operations_interface.node_load_embeddings(self, driver)
|
|
454
|
+
|
|
455
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
456
|
+
query: LiteralString = """
|
|
457
|
+
MATCH (n:Entity {uuid: $uuid})
|
|
458
|
+
RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
|
|
459
|
+
"""
|
|
460
|
+
|
|
461
|
+
else:
|
|
462
|
+
query: LiteralString = """
|
|
463
|
+
MATCH (n:Entity {uuid: $uuid})
|
|
464
|
+
RETURN n.name_embedding AS name_embedding
|
|
465
|
+
"""
|
|
466
|
+
records, _, _ = await driver.execute_query(
|
|
467
|
+
query,
|
|
468
|
+
uuid=self.uuid,
|
|
469
|
+
routing_='r',
|
|
470
|
+
)
|
|
306
471
|
|
|
307
472
|
if len(records) == 0:
|
|
308
473
|
raise NodeNotFoundError(self.uuid)
|
|
@@ -310,6 +475,9 @@ class EntityNode(Node):
|
|
|
310
475
|
self.name_embedding = records[0]['name_embedding']
|
|
311
476
|
|
|
312
477
|
async def save(self, driver: GraphDriver):
|
|
478
|
+
if driver.graph_operations_interface:
|
|
479
|
+
return await driver.graph_operations_interface.node_save(self, driver)
|
|
480
|
+
|
|
313
481
|
entity_data: dict[str, Any] = {
|
|
314
482
|
'uuid': self.uuid,
|
|
315
483
|
'name': self.name,
|
|
@@ -319,13 +487,21 @@ class EntityNode(Node):
|
|
|
319
487
|
'created_at': self.created_at,
|
|
320
488
|
}
|
|
321
489
|
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
490
|
+
if driver.provider == GraphProvider.KUZU:
|
|
491
|
+
entity_data['attributes'] = json.dumps(self.attributes)
|
|
492
|
+
entity_data['labels'] = list(set(self.labels + ['Entity']))
|
|
493
|
+
result = await driver.execute_query(
|
|
494
|
+
get_entity_node_save_query(driver.provider, labels=''),
|
|
495
|
+
**entity_data,
|
|
496
|
+
)
|
|
497
|
+
else:
|
|
498
|
+
entity_data.update(self.attributes or {})
|
|
499
|
+
labels = ':'.join(self.labels + ['Entity'])
|
|
500
|
+
|
|
501
|
+
result = await driver.execute_query(
|
|
502
|
+
get_entity_node_save_query(driver.provider, labels),
|
|
503
|
+
entity_data=entity_data,
|
|
504
|
+
)
|
|
329
505
|
|
|
330
506
|
logger.debug(f'Saved Node to Graph: {self.uuid}')
|
|
331
507
|
|
|
@@ -333,19 +509,17 @@ class EntityNode(Node):
|
|
|
333
509
|
|
|
334
510
|
@classmethod
|
|
335
511
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
336
|
-
query = (
|
|
337
|
-
"""
|
|
338
|
-
MATCH (n:Entity {uuid: $uuid})
|
|
339
|
-
"""
|
|
340
|
-
+ ENTITY_NODE_RETURN
|
|
341
|
-
)
|
|
342
512
|
records, _, _ = await driver.execute_query(
|
|
343
|
-
|
|
513
|
+
"""
|
|
514
|
+
MATCH (n:Entity {uuid: $uuid})
|
|
515
|
+
RETURN
|
|
516
|
+
"""
|
|
517
|
+
+ get_entity_node_return_query(driver.provider),
|
|
344
518
|
uuid=uuid,
|
|
345
519
|
routing_='r',
|
|
346
520
|
)
|
|
347
521
|
|
|
348
|
-
nodes = [get_entity_node_from_record(record) for record in records]
|
|
522
|
+
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
349
523
|
|
|
350
524
|
if len(nodes) == 0:
|
|
351
525
|
raise NodeNotFoundError(uuid)
|
|
@@ -356,14 +530,16 @@ class EntityNode(Node):
|
|
|
356
530
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
357
531
|
records, _, _ = await driver.execute_query(
|
|
358
532
|
"""
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
533
|
+
MATCH (n:Entity)
|
|
534
|
+
WHERE n.uuid IN $uuids
|
|
535
|
+
RETURN
|
|
536
|
+
"""
|
|
537
|
+
+ get_entity_node_return_query(driver.provider),
|
|
362
538
|
uuids=uuids,
|
|
363
539
|
routing_='r',
|
|
364
540
|
)
|
|
365
541
|
|
|
366
|
-
nodes = [get_entity_node_from_record(record) for record in records]
|
|
542
|
+
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
367
543
|
|
|
368
544
|
return nodes
|
|
369
545
|
|
|
@@ -374,19 +550,32 @@ class EntityNode(Node):
|
|
|
374
550
|
group_ids: list[str],
|
|
375
551
|
limit: int | None = None,
|
|
376
552
|
uuid_cursor: str | None = None,
|
|
553
|
+
with_embeddings: bool = False,
|
|
377
554
|
):
|
|
378
555
|
cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else ''
|
|
379
556
|
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
|
557
|
+
with_embeddings_query: LiteralString = (
|
|
558
|
+
""",
|
|
559
|
+
n.name_embedding AS name_embedding
|
|
560
|
+
"""
|
|
561
|
+
if with_embeddings
|
|
562
|
+
else ''
|
|
563
|
+
)
|
|
380
564
|
|
|
381
565
|
records, _, _ = await driver.execute_query(
|
|
382
566
|
"""
|
|
383
|
-
|
|
384
|
-
|
|
567
|
+
MATCH (n:Entity)
|
|
568
|
+
WHERE n.group_id IN $group_ids
|
|
569
|
+
"""
|
|
385
570
|
+ cursor_query
|
|
386
|
-
+ ENTITY_NODE_RETURN
|
|
387
571
|
+ """
|
|
388
|
-
|
|
389
|
-
|
|
572
|
+
RETURN
|
|
573
|
+
"""
|
|
574
|
+
+ get_entity_node_return_query(driver.provider)
|
|
575
|
+
+ with_embeddings_query
|
|
576
|
+
+ """
|
|
577
|
+
ORDER BY n.uuid DESC
|
|
578
|
+
"""
|
|
390
579
|
+ limit_query,
|
|
391
580
|
group_ids=group_ids,
|
|
392
581
|
uuid=uuid_cursor,
|
|
@@ -394,7 +583,7 @@ class EntityNode(Node):
|
|
|
394
583
|
routing_='r',
|
|
395
584
|
)
|
|
396
585
|
|
|
397
|
-
nodes = [get_entity_node_from_record(record) for record in records]
|
|
586
|
+
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
398
587
|
|
|
399
588
|
return nodes
|
|
400
589
|
|
|
@@ -404,8 +593,13 @@ class CommunityNode(Node):
|
|
|
404
593
|
summary: str = Field(description='region summary of member nodes', default_factory=str)
|
|
405
594
|
|
|
406
595
|
async def save(self, driver: GraphDriver):
|
|
596
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
597
|
+
await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
|
598
|
+
'communities',
|
|
599
|
+
[{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
|
|
600
|
+
)
|
|
407
601
|
result = await driver.execute_query(
|
|
408
|
-
|
|
602
|
+
get_community_node_save_query(driver.provider), # type: ignore
|
|
409
603
|
uuid=self.uuid,
|
|
410
604
|
name=self.name,
|
|
411
605
|
group_id=self.group_id,
|
|
@@ -428,11 +622,22 @@ class CommunityNode(Node):
|
|
|
428
622
|
return self.name_embedding
|
|
429
623
|
|
|
430
624
|
async def load_name_embedding(self, driver: GraphDriver):
|
|
431
|
-
|
|
625
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
626
|
+
query: LiteralString = """
|
|
627
|
+
MATCH (c:Community {uuid: $uuid})
|
|
628
|
+
RETURN [x IN split(c.name_embedding, ",") | toFloat(x)] as name_embedding
|
|
629
|
+
"""
|
|
630
|
+
else:
|
|
631
|
+
query: LiteralString = """
|
|
432
632
|
MATCH (c:Community {uuid: $uuid})
|
|
433
633
|
RETURN c.name_embedding AS name_embedding
|
|
434
|
-
|
|
435
|
-
|
|
634
|
+
"""
|
|
635
|
+
|
|
636
|
+
records, _, _ = await driver.execute_query(
|
|
637
|
+
query,
|
|
638
|
+
uuid=self.uuid,
|
|
639
|
+
routing_='r',
|
|
640
|
+
)
|
|
436
641
|
|
|
437
642
|
if len(records) == 0:
|
|
438
643
|
raise NodeNotFoundError(self.uuid)
|
|
@@ -443,14 +648,14 @@ class CommunityNode(Node):
|
|
|
443
648
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
444
649
|
records, _, _ = await driver.execute_query(
|
|
445
650
|
"""
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
651
|
+
MATCH (c:Community {uuid: $uuid})
|
|
652
|
+
RETURN
|
|
653
|
+
"""
|
|
654
|
+
+ (
|
|
655
|
+
COMMUNITY_NODE_RETURN_NEPTUNE
|
|
656
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
657
|
+
else COMMUNITY_NODE_RETURN
|
|
658
|
+
),
|
|
454
659
|
uuid=uuid,
|
|
455
660
|
routing_='r',
|
|
456
661
|
)
|
|
@@ -466,14 +671,15 @@ class CommunityNode(Node):
|
|
|
466
671
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
467
672
|
records, _, _ = await driver.execute_query(
|
|
468
673
|
"""
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
674
|
+
MATCH (c:Community)
|
|
675
|
+
WHERE c.uuid IN $uuids
|
|
676
|
+
RETURN
|
|
677
|
+
"""
|
|
678
|
+
+ (
|
|
679
|
+
COMMUNITY_NODE_RETURN_NEPTUNE
|
|
680
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
681
|
+
else COMMUNITY_NODE_RETURN
|
|
682
|
+
),
|
|
477
683
|
uuids=uuids,
|
|
478
684
|
routing_='r',
|
|
479
685
|
)
|
|
@@ -490,23 +696,26 @@ class CommunityNode(Node):
|
|
|
490
696
|
limit: int | None = None,
|
|
491
697
|
uuid_cursor: str | None = None,
|
|
492
698
|
):
|
|
493
|
-
cursor_query: LiteralString = 'AND
|
|
699
|
+
cursor_query: LiteralString = 'AND c.uuid < $uuid' if uuid_cursor else ''
|
|
494
700
|
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
|
495
701
|
|
|
496
702
|
records, _, _ = await driver.execute_query(
|
|
497
703
|
"""
|
|
498
|
-
|
|
499
|
-
|
|
704
|
+
MATCH (c:Community)
|
|
705
|
+
WHERE c.group_id IN $group_ids
|
|
706
|
+
"""
|
|
500
707
|
+ cursor_query
|
|
501
708
|
+ """
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
709
|
+
RETURN
|
|
710
|
+
"""
|
|
711
|
+
+ (
|
|
712
|
+
COMMUNITY_NODE_RETURN_NEPTUNE
|
|
713
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
714
|
+
else COMMUNITY_NODE_RETURN
|
|
715
|
+
)
|
|
716
|
+
+ """
|
|
717
|
+
ORDER BY c.uuid DESC
|
|
718
|
+
"""
|
|
510
719
|
+ limit_query,
|
|
511
720
|
group_ids=group_ids,
|
|
512
721
|
uuid=uuid_cursor,
|
|
@@ -542,24 +751,35 @@ def get_episodic_node_from_record(record: Any) -> EpisodicNode:
|
|
|
542
751
|
)
|
|
543
752
|
|
|
544
753
|
|
|
545
|
-
def get_entity_node_from_record(record: Any) -> EntityNode:
|
|
754
|
+
def get_entity_node_from_record(record: Any, provider: GraphProvider) -> EntityNode:
|
|
755
|
+
if provider == GraphProvider.KUZU:
|
|
756
|
+
attributes = json.loads(record['attributes']) if record['attributes'] else {}
|
|
757
|
+
else:
|
|
758
|
+
attributes = record['attributes']
|
|
759
|
+
attributes.pop('uuid', None)
|
|
760
|
+
attributes.pop('name', None)
|
|
761
|
+
attributes.pop('group_id', None)
|
|
762
|
+
attributes.pop('name_embedding', None)
|
|
763
|
+
attributes.pop('summary', None)
|
|
764
|
+
attributes.pop('created_at', None)
|
|
765
|
+
attributes.pop('labels', None)
|
|
766
|
+
|
|
767
|
+
labels = record.get('labels', [])
|
|
768
|
+
group_id = record.get('group_id')
|
|
769
|
+
if 'Entity_' + group_id.replace('-', '') in labels:
|
|
770
|
+
labels.remove('Entity_' + group_id.replace('-', ''))
|
|
771
|
+
|
|
546
772
|
entity_node = EntityNode(
|
|
547
773
|
uuid=record['uuid'],
|
|
548
774
|
name=record['name'],
|
|
549
|
-
|
|
550
|
-
|
|
775
|
+
name_embedding=record.get('name_embedding'),
|
|
776
|
+
group_id=group_id,
|
|
777
|
+
labels=labels,
|
|
551
778
|
created_at=parse_db_date(record['created_at']), # type: ignore
|
|
552
779
|
summary=record['summary'],
|
|
553
|
-
attributes=
|
|
780
|
+
attributes=attributes,
|
|
554
781
|
)
|
|
555
782
|
|
|
556
|
-
entity_node.attributes.pop('uuid', None)
|
|
557
|
-
entity_node.attributes.pop('name', None)
|
|
558
|
-
entity_node.attributes.pop('group_id', None)
|
|
559
|
-
entity_node.attributes.pop('name_embedding', None)
|
|
560
|
-
entity_node.attributes.pop('summary', None)
|
|
561
|
-
entity_node.attributes.pop('created_at', None)
|
|
562
|
-
|
|
563
783
|
return entity_node
|
|
564
784
|
|
|
565
785
|
|
|
@@ -575,8 +795,12 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
|
|
|
575
795
|
|
|
576
796
|
|
|
577
797
|
async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]):
|
|
578
|
-
|
|
798
|
+
# filter out falsey values from nodes
|
|
799
|
+
filtered_nodes = [node for node in nodes if node.name]
|
|
800
|
+
|
|
801
|
+
if not filtered_nodes:
|
|
579
802
|
return
|
|
580
|
-
|
|
581
|
-
for node
|
|
803
|
+
|
|
804
|
+
name_embeddings = await embedder.create_batch([node.name for node in filtered_nodes])
|
|
805
|
+
for node, name_embedding in zip(filtered_nodes, name_embeddings, strict=True):
|
|
582
806
|
node.name_embedding = name_embedding
|