graphiti-core 0.18.9__py3-none-any.whl → 0.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/driver/driver.py +4 -0
- graphiti_core/driver/falkordb_driver.py +3 -14
- graphiti_core/driver/kuzu_driver.py +175 -0
- graphiti_core/driver/neptune_driver.py +301 -0
- graphiti_core/edges.py +155 -62
- graphiti_core/graph_queries.py +31 -2
- graphiti_core/graphiti.py +6 -1
- graphiti_core/helpers.py +8 -8
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/openai_base_client.py +12 -2
- graphiti_core/llm_client/openai_client.py +10 -2
- graphiti_core/migrations/__init__.py +0 -0
- graphiti_core/migrations/neo4j_node_group_labels.py +114 -0
- graphiti_core/models/edges/edge_db_queries.py +205 -76
- graphiti_core/models/nodes/node_db_queries.py +253 -74
- graphiti_core/nodes.py +271 -98
- graphiti_core/search/search.py +42 -12
- graphiti_core/search/search_config.py +4 -0
- graphiti_core/search/search_filters.py +35 -22
- graphiti_core/search/search_utils.py +1329 -392
- graphiti_core/utils/bulk_utils.py +50 -15
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +39 -32
- graphiti_core/utils/maintenance/edge_operations.py +47 -13
- graphiti_core/utils/maintenance/graph_data_operations.py +100 -15
- {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0.dist-info}/METADATA +87 -13
- {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0.dist-info}/RECORD +29 -25
- {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0.dist-info}/WHEEL +0 -0
- {graphiti_core-0.18.9.dist-info → graphiti_core-0.19.0.dist-info}/licenses/LICENSE +0 -0
graphiti_core/nodes.py
CHANGED
|
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
|
|
|
14
14
|
limitations under the License.
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
|
+
import json
|
|
17
18
|
import logging
|
|
18
19
|
from abc import ABC, abstractmethod
|
|
19
20
|
from datetime import datetime
|
|
@@ -31,11 +32,13 @@ from graphiti_core.errors import NodeNotFoundError
|
|
|
31
32
|
from graphiti_core.helpers import parse_db_date
|
|
32
33
|
from graphiti_core.models.nodes.node_db_queries import (
|
|
33
34
|
COMMUNITY_NODE_RETURN,
|
|
34
|
-
|
|
35
|
+
COMMUNITY_NODE_RETURN_NEPTUNE,
|
|
35
36
|
EPISODIC_NODE_RETURN,
|
|
36
|
-
|
|
37
|
+
EPISODIC_NODE_RETURN_NEPTUNE,
|
|
37
38
|
get_community_node_save_query,
|
|
39
|
+
get_entity_node_return_query,
|
|
38
40
|
get_entity_node_save_query,
|
|
41
|
+
get_episode_node_save_query,
|
|
39
42
|
)
|
|
40
43
|
from graphiti_core.utils.datetime_utils import utc_now
|
|
41
44
|
|
|
@@ -89,23 +92,49 @@ class Node(BaseModel, ABC):
|
|
|
89
92
|
async def save(self, driver: GraphDriver): ...
|
|
90
93
|
|
|
91
94
|
async def delete(self, driver: GraphDriver):
|
|
92
|
-
|
|
93
|
-
|
|
95
|
+
match driver.provider:
|
|
96
|
+
case GraphProvider.NEO4J:
|
|
94
97
|
await driver.execute_query(
|
|
95
|
-
|
|
96
|
-
MATCH (n:
|
|
98
|
+
"""
|
|
99
|
+
MATCH (n:Entity|Episodic|Community {uuid: $uuid})
|
|
97
100
|
DETACH DELETE n
|
|
98
101
|
""",
|
|
99
102
|
uuid=self.uuid,
|
|
100
103
|
)
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
104
|
+
case GraphProvider.KUZU:
|
|
105
|
+
for label in ['Episodic', 'Community']:
|
|
106
|
+
await driver.execute_query(
|
|
107
|
+
f"""
|
|
108
|
+
MATCH (n:{label} {{uuid: $uuid}})
|
|
109
|
+
DETACH DELETE n
|
|
110
|
+
""",
|
|
111
|
+
uuid=self.uuid,
|
|
112
|
+
)
|
|
113
|
+
# Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
|
|
114
|
+
# Explicitly delete the "edge" nodes first, then the entity node.
|
|
115
|
+
await driver.execute_query(
|
|
116
|
+
"""
|
|
117
|
+
MATCH (n:Entity {uuid: $uuid})-[:RELATES_TO]->(e:RelatesToNode_)
|
|
118
|
+
DETACH DELETE e
|
|
119
|
+
""",
|
|
120
|
+
uuid=self.uuid,
|
|
121
|
+
)
|
|
122
|
+
await driver.execute_query(
|
|
123
|
+
"""
|
|
124
|
+
MATCH (n:Entity {uuid: $uuid})
|
|
125
|
+
DETACH DELETE n
|
|
126
|
+
""",
|
|
127
|
+
uuid=self.uuid,
|
|
128
|
+
)
|
|
129
|
+
case _: # FalkorDB, Neptune
|
|
130
|
+
for label in ['Entity', 'Episodic', 'Community']:
|
|
131
|
+
await driver.execute_query(
|
|
132
|
+
f"""
|
|
133
|
+
MATCH (n:{label} {{uuid: $uuid}})
|
|
134
|
+
DETACH DELETE n
|
|
135
|
+
""",
|
|
136
|
+
uuid=self.uuid,
|
|
137
|
+
)
|
|
109
138
|
|
|
110
139
|
logger.debug(f'Deleted Node: {self.uuid}')
|
|
111
140
|
|
|
@@ -119,55 +148,110 @@ class Node(BaseModel, ABC):
|
|
|
119
148
|
|
|
120
149
|
@classmethod
|
|
121
150
|
async def delete_by_group_id(cls, driver: GraphDriver, group_id: str, batch_size: int = 100):
|
|
122
|
-
|
|
123
|
-
|
|
151
|
+
match driver.provider:
|
|
152
|
+
case GraphProvider.NEO4J:
|
|
153
|
+
async with driver.session() as session:
|
|
154
|
+
await session.run(
|
|
155
|
+
"""
|
|
156
|
+
MATCH (n:Entity|Episodic|Community {group_id: $group_id})
|
|
157
|
+
CALL {
|
|
158
|
+
WITH n
|
|
159
|
+
DETACH DELETE n
|
|
160
|
+
} IN TRANSACTIONS OF $batch_size ROWS
|
|
161
|
+
""",
|
|
162
|
+
group_id=group_id,
|
|
163
|
+
batch_size=batch_size,
|
|
164
|
+
)
|
|
165
|
+
case GraphProvider.KUZU:
|
|
166
|
+
for label in ['Episodic', 'Community']:
|
|
167
|
+
await driver.execute_query(
|
|
168
|
+
f"""
|
|
169
|
+
MATCH (n:{label} {{group_id: $group_id}})
|
|
170
|
+
DETACH DELETE n
|
|
171
|
+
""",
|
|
172
|
+
group_id=group_id,
|
|
173
|
+
)
|
|
174
|
+
# Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
|
|
175
|
+
# Explicitly delete the "edge" nodes first, then the entity node.
|
|
124
176
|
await driver.execute_query(
|
|
125
|
-
|
|
126
|
-
MATCH (n:
|
|
127
|
-
DETACH DELETE
|
|
177
|
+
"""
|
|
178
|
+
MATCH (n:Entity {group_id: $group_id})-[:RELATES_TO]->(e:RelatesToNode_)
|
|
179
|
+
DETACH DELETE e
|
|
128
180
|
""",
|
|
129
181
|
group_id=group_id,
|
|
130
182
|
)
|
|
131
|
-
|
|
132
|
-
async with driver.session() as session:
|
|
133
|
-
await session.run(
|
|
183
|
+
await driver.execute_query(
|
|
134
184
|
"""
|
|
135
|
-
MATCH (n:Entity
|
|
136
|
-
|
|
137
|
-
WITH n
|
|
138
|
-
DETACH DELETE n
|
|
139
|
-
} IN TRANSACTIONS OF $batch_size ROWS
|
|
185
|
+
MATCH (n:Entity {group_id: $group_id})
|
|
186
|
+
DETACH DELETE n
|
|
140
187
|
""",
|
|
141
188
|
group_id=group_id,
|
|
142
|
-
batch_size=batch_size,
|
|
143
189
|
)
|
|
190
|
+
case _: # FalkorDB, Neptune
|
|
191
|
+
for label in ['Entity', 'Episodic', 'Community']:
|
|
192
|
+
await driver.execute_query(
|
|
193
|
+
f"""
|
|
194
|
+
MATCH (n:{label} {{group_id: $group_id}})
|
|
195
|
+
DETACH DELETE n
|
|
196
|
+
""",
|
|
197
|
+
group_id=group_id,
|
|
198
|
+
)
|
|
144
199
|
|
|
145
200
|
@classmethod
|
|
146
201
|
async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
|
|
147
|
-
|
|
148
|
-
|
|
202
|
+
match driver.provider:
|
|
203
|
+
case GraphProvider.FALKORDB:
|
|
204
|
+
for label in ['Entity', 'Episodic', 'Community']:
|
|
205
|
+
await driver.execute_query(
|
|
206
|
+
f"""
|
|
207
|
+
MATCH (n:{label})
|
|
208
|
+
WHERE n.uuid IN $uuids
|
|
209
|
+
DETACH DELETE n
|
|
210
|
+
""",
|
|
211
|
+
uuids=uuids,
|
|
212
|
+
)
|
|
213
|
+
case GraphProvider.KUZU:
|
|
214
|
+
for label in ['Episodic', 'Community']:
|
|
215
|
+
await driver.execute_query(
|
|
216
|
+
f"""
|
|
217
|
+
MATCH (n:{label})
|
|
218
|
+
WHERE n.uuid IN $uuids
|
|
219
|
+
DETACH DELETE n
|
|
220
|
+
""",
|
|
221
|
+
uuids=uuids,
|
|
222
|
+
)
|
|
223
|
+
# Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
|
|
224
|
+
# Explicitly delete the "edge" nodes first, then the entity node.
|
|
149
225
|
await driver.execute_query(
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
226
|
+
"""
|
|
227
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)
|
|
228
|
+
WHERE n.uuid IN $uuids
|
|
229
|
+
DETACH DELETE e
|
|
230
|
+
""",
|
|
155
231
|
uuids=uuids,
|
|
156
232
|
)
|
|
157
|
-
|
|
158
|
-
async with driver.session() as session:
|
|
159
|
-
await session.run(
|
|
233
|
+
await driver.execute_query(
|
|
160
234
|
"""
|
|
161
|
-
MATCH (n:Entity
|
|
235
|
+
MATCH (n:Entity)
|
|
162
236
|
WHERE n.uuid IN $uuids
|
|
163
|
-
|
|
164
|
-
WITH n
|
|
165
|
-
DETACH DELETE n
|
|
166
|
-
} IN TRANSACTIONS OF $batch_size ROWS
|
|
237
|
+
DETACH DELETE n
|
|
167
238
|
""",
|
|
168
239
|
uuids=uuids,
|
|
169
|
-
batch_size=batch_size,
|
|
170
240
|
)
|
|
241
|
+
case _: # Neo4J, Neptune
|
|
242
|
+
async with driver.session() as session:
|
|
243
|
+
await session.run(
|
|
244
|
+
"""
|
|
245
|
+
MATCH (n:Entity|Episodic|Community)
|
|
246
|
+
WHERE n.uuid IN $uuids
|
|
247
|
+
CALL {
|
|
248
|
+
WITH n
|
|
249
|
+
DETACH DELETE n
|
|
250
|
+
} IN TRANSACTIONS OF $batch_size ROWS
|
|
251
|
+
""",
|
|
252
|
+
uuids=uuids,
|
|
253
|
+
batch_size=batch_size,
|
|
254
|
+
)
|
|
171
255
|
|
|
172
256
|
@classmethod
|
|
173
257
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
|
|
@@ -189,17 +273,37 @@ class EpisodicNode(Node):
|
|
|
189
273
|
)
|
|
190
274
|
|
|
191
275
|
async def save(self, driver: GraphDriver):
|
|
276
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
277
|
+
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
|
278
|
+
'episode_content',
|
|
279
|
+
[
|
|
280
|
+
{
|
|
281
|
+
'uuid': self.uuid,
|
|
282
|
+
'group_id': self.group_id,
|
|
283
|
+
'source': self.source.value,
|
|
284
|
+
'content': self.content,
|
|
285
|
+
'source_description': self.source_description,
|
|
286
|
+
}
|
|
287
|
+
],
|
|
288
|
+
)
|
|
289
|
+
|
|
290
|
+
episode_args = {
|
|
291
|
+
'uuid': self.uuid,
|
|
292
|
+
'name': self.name,
|
|
293
|
+
'group_id': self.group_id,
|
|
294
|
+
'source_description': self.source_description,
|
|
295
|
+
'content': self.content,
|
|
296
|
+
'entity_edges': self.entity_edges,
|
|
297
|
+
'created_at': self.created_at,
|
|
298
|
+
'valid_at': self.valid_at,
|
|
299
|
+
'source': self.source.value,
|
|
300
|
+
}
|
|
301
|
+
|
|
302
|
+
if driver.provider == GraphProvider.NEO4J:
|
|
303
|
+
episode_args['group_label'] = 'Episodic_' + self.group_id.replace('-', '')
|
|
304
|
+
|
|
192
305
|
result = await driver.execute_query(
|
|
193
|
-
|
|
194
|
-
uuid=self.uuid,
|
|
195
|
-
name=self.name,
|
|
196
|
-
group_id=self.group_id,
|
|
197
|
-
source_description=self.source_description,
|
|
198
|
-
content=self.content,
|
|
199
|
-
entity_edges=self.entity_edges,
|
|
200
|
-
created_at=self.created_at,
|
|
201
|
-
valid_at=self.valid_at,
|
|
202
|
-
source=self.source.value,
|
|
306
|
+
get_episode_node_save_query(driver.provider), **episode_args
|
|
203
307
|
)
|
|
204
308
|
|
|
205
309
|
logger.debug(f'Saved Node to Graph: {self.uuid}')
|
|
@@ -213,7 +317,11 @@ class EpisodicNode(Node):
|
|
|
213
317
|
MATCH (e:Episodic {uuid: $uuid})
|
|
214
318
|
RETURN
|
|
215
319
|
"""
|
|
216
|
-
+
|
|
320
|
+
+ (
|
|
321
|
+
EPISODIC_NODE_RETURN_NEPTUNE
|
|
322
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
323
|
+
else EPISODIC_NODE_RETURN
|
|
324
|
+
),
|
|
217
325
|
uuid=uuid,
|
|
218
326
|
routing_='r',
|
|
219
327
|
)
|
|
@@ -233,7 +341,11 @@ class EpisodicNode(Node):
|
|
|
233
341
|
WHERE e.uuid IN $uuids
|
|
234
342
|
RETURN DISTINCT
|
|
235
343
|
"""
|
|
236
|
-
+
|
|
344
|
+
+ (
|
|
345
|
+
EPISODIC_NODE_RETURN_NEPTUNE
|
|
346
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
347
|
+
else EPISODIC_NODE_RETURN
|
|
348
|
+
),
|
|
237
349
|
uuids=uuids,
|
|
238
350
|
routing_='r',
|
|
239
351
|
)
|
|
@@ -262,7 +374,11 @@ class EpisodicNode(Node):
|
|
|
262
374
|
+ """
|
|
263
375
|
RETURN DISTINCT
|
|
264
376
|
"""
|
|
265
|
-
+
|
|
377
|
+
+ (
|
|
378
|
+
EPISODIC_NODE_RETURN_NEPTUNE
|
|
379
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
380
|
+
else EPISODIC_NODE_RETURN
|
|
381
|
+
)
|
|
266
382
|
+ """
|
|
267
383
|
ORDER BY uuid DESC
|
|
268
384
|
"""
|
|
@@ -284,7 +400,11 @@ class EpisodicNode(Node):
|
|
|
284
400
|
MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
|
|
285
401
|
RETURN DISTINCT
|
|
286
402
|
"""
|
|
287
|
-
+
|
|
403
|
+
+ (
|
|
404
|
+
EPISODIC_NODE_RETURN_NEPTUNE
|
|
405
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
406
|
+
else EPISODIC_NODE_RETURN
|
|
407
|
+
),
|
|
288
408
|
entity_node_uuid=entity_node_uuid,
|
|
289
409
|
routing_='r',
|
|
290
410
|
)
|
|
@@ -311,11 +431,18 @@ class EntityNode(Node):
|
|
|
311
431
|
return self.name_embedding
|
|
312
432
|
|
|
313
433
|
async def load_name_embedding(self, driver: GraphDriver):
|
|
314
|
-
|
|
434
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
435
|
+
query: LiteralString = """
|
|
436
|
+
MATCH (n:Entity {uuid: $uuid})
|
|
437
|
+
RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
|
|
315
438
|
"""
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
439
|
+
else:
|
|
440
|
+
query: LiteralString = """
|
|
441
|
+
MATCH (n:Entity {uuid: $uuid})
|
|
442
|
+
RETURN n.name_embedding AS name_embedding
|
|
443
|
+
"""
|
|
444
|
+
records, _, _ = await driver.execute_query(
|
|
445
|
+
query,
|
|
319
446
|
uuid=self.uuid,
|
|
320
447
|
routing_='r',
|
|
321
448
|
)
|
|
@@ -334,14 +461,25 @@ class EntityNode(Node):
|
|
|
334
461
|
'summary': self.summary,
|
|
335
462
|
'created_at': self.created_at,
|
|
336
463
|
}
|
|
337
|
-
entity_data.update(self.attributes or {})
|
|
338
464
|
|
|
339
|
-
|
|
465
|
+
if driver.provider == GraphProvider.KUZU:
|
|
466
|
+
entity_data['attributes'] = json.dumps(self.attributes)
|
|
467
|
+
entity_data['labels'] = list(set(self.labels + ['Entity']))
|
|
468
|
+
result = await driver.execute_query(
|
|
469
|
+
get_entity_node_save_query(driver.provider, labels=''),
|
|
470
|
+
**entity_data,
|
|
471
|
+
)
|
|
472
|
+
else:
|
|
473
|
+
entity_data.update(self.attributes or {})
|
|
474
|
+
labels = ':'.join(self.labels + ['Entity', 'Entity_' + self.group_id.replace('-', '')])
|
|
340
475
|
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
476
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
477
|
+
driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue
|
|
478
|
+
|
|
479
|
+
result = await driver.execute_query(
|
|
480
|
+
get_entity_node_save_query(driver.provider, labels),
|
|
481
|
+
entity_data=entity_data,
|
|
482
|
+
)
|
|
345
483
|
|
|
346
484
|
logger.debug(f'Saved Node to Graph: {self.uuid}')
|
|
347
485
|
|
|
@@ -354,12 +492,12 @@ class EntityNode(Node):
|
|
|
354
492
|
MATCH (n:Entity {uuid: $uuid})
|
|
355
493
|
RETURN
|
|
356
494
|
"""
|
|
357
|
-
+
|
|
495
|
+
+ get_entity_node_return_query(driver.provider),
|
|
358
496
|
uuid=uuid,
|
|
359
497
|
routing_='r',
|
|
360
498
|
)
|
|
361
499
|
|
|
362
|
-
nodes = [get_entity_node_from_record(record) for record in records]
|
|
500
|
+
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
363
501
|
|
|
364
502
|
if len(nodes) == 0:
|
|
365
503
|
raise NodeNotFoundError(uuid)
|
|
@@ -374,12 +512,12 @@ class EntityNode(Node):
|
|
|
374
512
|
WHERE n.uuid IN $uuids
|
|
375
513
|
RETURN
|
|
376
514
|
"""
|
|
377
|
-
+
|
|
515
|
+
+ get_entity_node_return_query(driver.provider),
|
|
378
516
|
uuids=uuids,
|
|
379
517
|
routing_='r',
|
|
380
518
|
)
|
|
381
519
|
|
|
382
|
-
nodes = [get_entity_node_from_record(record) for record in records]
|
|
520
|
+
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
383
521
|
|
|
384
522
|
return nodes
|
|
385
523
|
|
|
@@ -411,7 +549,7 @@ class EntityNode(Node):
|
|
|
411
549
|
+ """
|
|
412
550
|
RETURN
|
|
413
551
|
"""
|
|
414
|
-
+
|
|
552
|
+
+ get_entity_node_return_query(driver.provider)
|
|
415
553
|
+ with_embeddings_query
|
|
416
554
|
+ """
|
|
417
555
|
ORDER BY n.uuid DESC
|
|
@@ -423,7 +561,7 @@ class EntityNode(Node):
|
|
|
423
561
|
routing_='r',
|
|
424
562
|
)
|
|
425
563
|
|
|
426
|
-
nodes = [get_entity_node_from_record(record) for record in records]
|
|
564
|
+
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
427
565
|
|
|
428
566
|
return nodes
|
|
429
567
|
|
|
@@ -433,8 +571,13 @@ class CommunityNode(Node):
|
|
|
433
571
|
summary: str = Field(description='region summary of member nodes', default_factory=str)
|
|
434
572
|
|
|
435
573
|
async def save(self, driver: GraphDriver):
|
|
574
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
575
|
+
driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
|
576
|
+
'community_name',
|
|
577
|
+
[{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
|
|
578
|
+
)
|
|
436
579
|
result = await driver.execute_query(
|
|
437
|
-
get_community_node_save_query(driver.provider),
|
|
580
|
+
get_community_node_save_query(driver.provider), # type: ignore
|
|
438
581
|
uuid=self.uuid,
|
|
439
582
|
name=self.name,
|
|
440
583
|
group_id=self.group_id,
|
|
@@ -457,11 +600,19 @@ class CommunityNode(Node):
|
|
|
457
600
|
return self.name_embedding
|
|
458
601
|
|
|
459
602
|
async def load_name_embedding(self, driver: GraphDriver):
|
|
460
|
-
|
|
603
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
604
|
+
query: LiteralString = """
|
|
605
|
+
MATCH (c:Community {uuid: $uuid})
|
|
606
|
+
RETURN [x IN split(c.name_embedding, ",") | toFloat(x)] as name_embedding
|
|
461
607
|
"""
|
|
608
|
+
else:
|
|
609
|
+
query: LiteralString = """
|
|
462
610
|
MATCH (c:Community {uuid: $uuid})
|
|
463
611
|
RETURN c.name_embedding AS name_embedding
|
|
464
|
-
"""
|
|
612
|
+
"""
|
|
613
|
+
|
|
614
|
+
records, _, _ = await driver.execute_query(
|
|
615
|
+
query,
|
|
465
616
|
uuid=self.uuid,
|
|
466
617
|
routing_='r',
|
|
467
618
|
)
|
|
@@ -475,10 +626,14 @@ class CommunityNode(Node):
|
|
|
475
626
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
|
|
476
627
|
records, _, _ = await driver.execute_query(
|
|
477
628
|
"""
|
|
478
|
-
MATCH (
|
|
629
|
+
MATCH (c:Community {uuid: $uuid})
|
|
479
630
|
RETURN
|
|
480
631
|
"""
|
|
481
|
-
+
|
|
632
|
+
+ (
|
|
633
|
+
COMMUNITY_NODE_RETURN_NEPTUNE
|
|
634
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
635
|
+
else COMMUNITY_NODE_RETURN
|
|
636
|
+
),
|
|
482
637
|
uuid=uuid,
|
|
483
638
|
routing_='r',
|
|
484
639
|
)
|
|
@@ -494,11 +649,15 @@ class CommunityNode(Node):
|
|
|
494
649
|
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
|
|
495
650
|
records, _, _ = await driver.execute_query(
|
|
496
651
|
"""
|
|
497
|
-
MATCH (
|
|
498
|
-
WHERE
|
|
652
|
+
MATCH (c:Community)
|
|
653
|
+
WHERE c.uuid IN $uuids
|
|
499
654
|
RETURN
|
|
500
655
|
"""
|
|
501
|
-
+
|
|
656
|
+
+ (
|
|
657
|
+
COMMUNITY_NODE_RETURN_NEPTUNE
|
|
658
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
659
|
+
else COMMUNITY_NODE_RETURN
|
|
660
|
+
),
|
|
502
661
|
uuids=uuids,
|
|
503
662
|
routing_='r',
|
|
504
663
|
)
|
|
@@ -515,21 +674,25 @@ class CommunityNode(Node):
|
|
|
515
674
|
limit: int | None = None,
|
|
516
675
|
uuid_cursor: str | None = None,
|
|
517
676
|
):
|
|
518
|
-
cursor_query: LiteralString = 'AND
|
|
677
|
+
cursor_query: LiteralString = 'AND c.uuid < $uuid' if uuid_cursor else ''
|
|
519
678
|
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
|
|
520
679
|
|
|
521
680
|
records, _, _ = await driver.execute_query(
|
|
522
681
|
"""
|
|
523
|
-
MATCH (
|
|
524
|
-
WHERE
|
|
682
|
+
MATCH (c:Community)
|
|
683
|
+
WHERE c.group_id IN $group_ids
|
|
525
684
|
"""
|
|
526
685
|
+ cursor_query
|
|
527
686
|
+ """
|
|
528
687
|
RETURN
|
|
529
688
|
"""
|
|
530
|
-
+
|
|
689
|
+
+ (
|
|
690
|
+
COMMUNITY_NODE_RETURN_NEPTUNE
|
|
691
|
+
if driver.provider == GraphProvider.NEPTUNE
|
|
692
|
+
else COMMUNITY_NODE_RETURN
|
|
693
|
+
)
|
|
531
694
|
+ """
|
|
532
|
-
ORDER BY
|
|
695
|
+
ORDER BY c.uuid DESC
|
|
533
696
|
"""
|
|
534
697
|
+ limit_query,
|
|
535
698
|
group_ids=group_ids,
|
|
@@ -566,25 +729,35 @@ def get_episodic_node_from_record(record: Any) -> EpisodicNode:
|
|
|
566
729
|
)
|
|
567
730
|
|
|
568
731
|
|
|
569
|
-
def get_entity_node_from_record(record: Any) -> EntityNode:
|
|
732
|
+
def get_entity_node_from_record(record: Any, provider: GraphProvider) -> EntityNode:
|
|
733
|
+
if provider == GraphProvider.KUZU:
|
|
734
|
+
attributes = json.loads(record['attributes']) if record['attributes'] else {}
|
|
735
|
+
else:
|
|
736
|
+
attributes = record['attributes']
|
|
737
|
+
attributes.pop('uuid', None)
|
|
738
|
+
attributes.pop('name', None)
|
|
739
|
+
attributes.pop('group_id', None)
|
|
740
|
+
attributes.pop('name_embedding', None)
|
|
741
|
+
attributes.pop('summary', None)
|
|
742
|
+
attributes.pop('created_at', None)
|
|
743
|
+
attributes.pop('labels', None)
|
|
744
|
+
|
|
745
|
+
labels = record.get('labels', [])
|
|
746
|
+
group_id = record.get('group_id')
|
|
747
|
+
if 'Entity_' + group_id.replace('-', '') in labels:
|
|
748
|
+
labels.remove('Entity_' + group_id.replace('-', ''))
|
|
749
|
+
|
|
570
750
|
entity_node = EntityNode(
|
|
571
751
|
uuid=record['uuid'],
|
|
572
752
|
name=record['name'],
|
|
573
753
|
name_embedding=record.get('name_embedding'),
|
|
574
|
-
group_id=
|
|
575
|
-
labels=
|
|
754
|
+
group_id=group_id,
|
|
755
|
+
labels=labels,
|
|
576
756
|
created_at=parse_db_date(record['created_at']), # type: ignore
|
|
577
757
|
summary=record['summary'],
|
|
578
|
-
attributes=
|
|
758
|
+
attributes=attributes,
|
|
579
759
|
)
|
|
580
760
|
|
|
581
|
-
entity_node.attributes.pop('uuid', None)
|
|
582
|
-
entity_node.attributes.pop('name', None)
|
|
583
|
-
entity_node.attributes.pop('group_id', None)
|
|
584
|
-
entity_node.attributes.pop('name_embedding', None)
|
|
585
|
-
entity_node.attributes.pop('summary', None)
|
|
586
|
-
entity_node.attributes.pop('created_at', None)
|
|
587
|
-
|
|
588
761
|
return entity_node
|
|
589
762
|
|
|
590
763
|
|