graphiti-core 0.17.4__py3-none-any.whl → 0.24.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. graphiti_core/cross_encoder/gemini_reranker_client.py +1 -1
  2. graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
  3. graphiti_core/decorators.py +110 -0
  4. graphiti_core/driver/driver.py +62 -2
  5. graphiti_core/driver/falkordb_driver.py +215 -23
  6. graphiti_core/driver/graph_operations/graph_operations.py +191 -0
  7. graphiti_core/driver/kuzu_driver.py +182 -0
  8. graphiti_core/driver/neo4j_driver.py +61 -8
  9. graphiti_core/driver/neptune_driver.py +305 -0
  10. graphiti_core/driver/search_interface/search_interface.py +89 -0
  11. graphiti_core/edges.py +264 -132
  12. graphiti_core/embedder/azure_openai.py +10 -3
  13. graphiti_core/embedder/client.py +2 -1
  14. graphiti_core/graph_queries.py +114 -101
  15. graphiti_core/graphiti.py +582 -255
  16. graphiti_core/graphiti_types.py +2 -0
  17. graphiti_core/helpers.py +21 -14
  18. graphiti_core/llm_client/anthropic_client.py +142 -52
  19. graphiti_core/llm_client/azure_openai_client.py +57 -19
  20. graphiti_core/llm_client/client.py +83 -21
  21. graphiti_core/llm_client/config.py +1 -1
  22. graphiti_core/llm_client/gemini_client.py +75 -57
  23. graphiti_core/llm_client/openai_base_client.py +94 -50
  24. graphiti_core/llm_client/openai_client.py +28 -8
  25. graphiti_core/llm_client/openai_generic_client.py +91 -56
  26. graphiti_core/models/edges/edge_db_queries.py +259 -35
  27. graphiti_core/models/nodes/node_db_queries.py +311 -32
  28. graphiti_core/nodes.py +388 -164
  29. graphiti_core/prompts/dedupe_edges.py +42 -31
  30. graphiti_core/prompts/dedupe_nodes.py +56 -39
  31. graphiti_core/prompts/eval.py +4 -4
  32. graphiti_core/prompts/extract_edges.py +23 -14
  33. graphiti_core/prompts/extract_nodes.py +73 -32
  34. graphiti_core/prompts/prompt_helpers.py +39 -0
  35. graphiti_core/prompts/snippets.py +29 -0
  36. graphiti_core/prompts/summarize_nodes.py +23 -25
  37. graphiti_core/search/search.py +154 -74
  38. graphiti_core/search/search_config.py +39 -4
  39. graphiti_core/search/search_filters.py +109 -31
  40. graphiti_core/search/search_helpers.py +5 -6
  41. graphiti_core/search/search_utils.py +1360 -473
  42. graphiti_core/tracer.py +193 -0
  43. graphiti_core/utils/bulk_utils.py +216 -90
  44. graphiti_core/utils/datetime_utils.py +13 -0
  45. graphiti_core/utils/maintenance/community_operations.py +62 -38
  46. graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
  47. graphiti_core/utils/maintenance/edge_operations.py +286 -126
  48. graphiti_core/utils/maintenance/graph_data_operations.py +44 -74
  49. graphiti_core/utils/maintenance/node_operations.py +320 -158
  50. graphiti_core/utils/maintenance/temporal_operations.py +11 -3
  51. graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
  52. graphiti_core/utils/text_utils.py +53 -0
  53. {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/METADATA +221 -87
  54. graphiti_core-0.24.3.dist-info/RECORD +86 -0
  55. {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
  56. graphiti_core-0.17.4.dist-info/RECORD +0 -77
  57. /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
  58. {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/licenses/LICENSE +0 -0
graphiti_core/nodes.py CHANGED
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
+ import json
17
18
  import logging
18
19
  from abc import ABC, abstractmethod
19
20
  from datetime import datetime
@@ -25,30 +26,27 @@ from uuid import uuid4
25
26
  from pydantic import BaseModel, Field
26
27
  from typing_extensions import LiteralString
27
28
 
28
- from graphiti_core.driver.driver import GraphDriver
29
+ from graphiti_core.driver.driver import (
30
+ GraphDriver,
31
+ GraphProvider,
32
+ )
29
33
  from graphiti_core.embedder import EmbedderClient
30
34
  from graphiti_core.errors import NodeNotFoundError
31
35
  from graphiti_core.helpers import parse_db_date
32
36
  from graphiti_core.models.nodes.node_db_queries import (
33
- COMMUNITY_NODE_SAVE,
34
- ENTITY_NODE_SAVE,
35
- EPISODIC_NODE_SAVE,
37
+ COMMUNITY_NODE_RETURN,
38
+ COMMUNITY_NODE_RETURN_NEPTUNE,
39
+ EPISODIC_NODE_RETURN,
40
+ EPISODIC_NODE_RETURN_NEPTUNE,
41
+ get_community_node_save_query,
42
+ get_entity_node_return_query,
43
+ get_entity_node_save_query,
44
+ get_episode_node_save_query,
36
45
  )
37
46
  from graphiti_core.utils.datetime_utils import utc_now
38
47
 
39
48
  logger = logging.getLogger(__name__)
40
49
 
41
- ENTITY_NODE_RETURN: LiteralString = """
42
- RETURN
43
- n.uuid As uuid,
44
- n.name AS name,
45
- n.group_id AS group_id,
46
- n.created_at AS created_at,
47
- n.summary AS summary,
48
- labels(n) AS labels,
49
- properties(n) AS attributes
50
- """
51
-
52
50
 
53
51
  class EpisodeType(Enum):
54
52
  """
@@ -97,18 +95,60 @@ class Node(BaseModel, ABC):
97
95
  async def save(self, driver: GraphDriver): ...
98
96
 
99
97
  async def delete(self, driver: GraphDriver):
100
- result = await driver.execute_query(
101
- """
102
- MATCH (n:Entity|Episodic|Community {uuid: $uuid})
103
- DETACH DELETE n
104
- """,
105
- uuid=self.uuid,
106
- )
98
+ if driver.graph_operations_interface:
99
+ return await driver.graph_operations_interface.node_delete(self, driver)
100
+
101
+ match driver.provider:
102
+ case GraphProvider.NEO4J:
103
+ records, _, _ = await driver.execute_query(
104
+ """
105
+ MATCH (n {uuid: $uuid})
106
+ WHERE n:Entity OR n:Episodic OR n:Community
107
+ OPTIONAL MATCH (n)-[r]-()
108
+ WITH collect(r.uuid) AS edge_uuids, n
109
+ DETACH DELETE n
110
+ RETURN edge_uuids
111
+ """,
112
+ uuid=self.uuid,
113
+ )
114
+
115
+ case GraphProvider.KUZU:
116
+ for label in ['Episodic', 'Community']:
117
+ await driver.execute_query(
118
+ f"""
119
+ MATCH (n:{label} {{uuid: $uuid}})
120
+ DETACH DELETE n
121
+ """,
122
+ uuid=self.uuid,
123
+ )
124
+ # Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
125
+ # Explicitly delete the "edge" nodes first, then the entity node.
126
+ await driver.execute_query(
127
+ """
128
+ MATCH (n:Entity {uuid: $uuid})-[:RELATES_TO]->(e:RelatesToNode_)
129
+ DETACH DELETE e
130
+ """,
131
+ uuid=self.uuid,
132
+ )
133
+ await driver.execute_query(
134
+ """
135
+ MATCH (n:Entity {uuid: $uuid})
136
+ DETACH DELETE n
137
+ """,
138
+ uuid=self.uuid,
139
+ )
140
+ case _: # FalkorDB, Neptune
141
+ for label in ['Entity', 'Episodic', 'Community']:
142
+ await driver.execute_query(
143
+ f"""
144
+ MATCH (n:{label} {{uuid: $uuid}})
145
+ DETACH DELETE n
146
+ """,
147
+ uuid=self.uuid,
148
+ )
107
149
 
108
150
  logger.debug(f'Deleted Node: {self.uuid}')
109
151
 
110
- return result
111
-
112
152
  def __hash__(self):
113
153
  return hash(self.uuid)
114
154
 
@@ -118,16 +158,132 @@ class Node(BaseModel, ABC):
118
158
  return False
119
159
 
120
160
  @classmethod
121
- async def delete_by_group_id(cls, driver: GraphDriver, group_id: str):
122
- await driver.execute_query(
123
- """
124
- MATCH (n:Entity|Episodic|Community {group_id: $group_id})
125
- DETACH DELETE n
126
- """,
127
- group_id=group_id,
128
- )
161
+ async def delete_by_group_id(cls, driver: GraphDriver, group_id: str, batch_size: int = 100):
162
+ if driver.graph_operations_interface:
163
+ return await driver.graph_operations_interface.node_delete_by_group_id(
164
+ cls, driver, group_id, batch_size
165
+ )
166
+
167
+ match driver.provider:
168
+ case GraphProvider.NEO4J:
169
+ async with driver.session() as session:
170
+ await session.run(
171
+ """
172
+ MATCH (n:Entity|Episodic|Community {group_id: $group_id})
173
+ CALL (n) {
174
+ DETACH DELETE n
175
+ } IN TRANSACTIONS OF $batch_size ROWS
176
+ """,
177
+ group_id=group_id,
178
+ batch_size=batch_size,
179
+ )
180
+
181
+ case GraphProvider.KUZU:
182
+ for label in ['Episodic', 'Community']:
183
+ await driver.execute_query(
184
+ f"""
185
+ MATCH (n:{label} {{group_id: $group_id}})
186
+ DETACH DELETE n
187
+ """,
188
+ group_id=group_id,
189
+ )
190
+ # Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
191
+ # Explicitly delete the "edge" nodes first, then the entity node.
192
+ await driver.execute_query(
193
+ """
194
+ MATCH (n:Entity {group_id: $group_id})-[:RELATES_TO]->(e:RelatesToNode_)
195
+ DETACH DELETE e
196
+ """,
197
+ group_id=group_id,
198
+ )
199
+ await driver.execute_query(
200
+ """
201
+ MATCH (n:Entity {group_id: $group_id})
202
+ DETACH DELETE n
203
+ """,
204
+ group_id=group_id,
205
+ )
206
+ case _: # FalkorDB, Neptune
207
+ for label in ['Entity', 'Episodic', 'Community']:
208
+ await driver.execute_query(
209
+ f"""
210
+ MATCH (n:{label} {{group_id: $group_id}})
211
+ DETACH DELETE n
212
+ """,
213
+ group_id=group_id,
214
+ )
129
215
 
130
- return 'SUCCESS'
216
+ @classmethod
217
+ async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
218
+ if driver.graph_operations_interface:
219
+ return await driver.graph_operations_interface.node_delete_by_uuids(
220
+ cls, driver, uuids, group_id=None, batch_size=batch_size
221
+ )
222
+
223
+ match driver.provider:
224
+ case GraphProvider.FALKORDB:
225
+ for label in ['Entity', 'Episodic', 'Community']:
226
+ await driver.execute_query(
227
+ f"""
228
+ MATCH (n:{label})
229
+ WHERE n.uuid IN $uuids
230
+ DETACH DELETE n
231
+ """,
232
+ uuids=uuids,
233
+ )
234
+ case GraphProvider.KUZU:
235
+ for label in ['Episodic', 'Community']:
236
+ await driver.execute_query(
237
+ f"""
238
+ MATCH (n:{label})
239
+ WHERE n.uuid IN $uuids
240
+ DETACH DELETE n
241
+ """,
242
+ uuids=uuids,
243
+ )
244
+ # Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
245
+ # Explicitly delete the "edge" nodes first, then the entity node.
246
+ await driver.execute_query(
247
+ """
248
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)
249
+ WHERE n.uuid IN $uuids
250
+ DETACH DELETE e
251
+ """,
252
+ uuids=uuids,
253
+ )
254
+ await driver.execute_query(
255
+ """
256
+ MATCH (n:Entity)
257
+ WHERE n.uuid IN $uuids
258
+ DETACH DELETE n
259
+ """,
260
+ uuids=uuids,
261
+ )
262
+ case _: # Neo4J, Neptune
263
+ async with driver.session() as session:
264
+ # Collect all edge UUIDs before deleting nodes
265
+ await session.run(
266
+ """
267
+ MATCH (n:Entity|Episodic|Community)
268
+ WHERE n.uuid IN $uuids
269
+ MATCH (n)-[r]-()
270
+ RETURN collect(r.uuid) AS edge_uuids
271
+ """,
272
+ uuids=uuids,
273
+ )
274
+
275
+ # Now delete the nodes in batches
276
+ await session.run(
277
+ """
278
+ MATCH (n:Entity|Episodic|Community)
279
+ WHERE n.uuid IN $uuids
280
+ CALL (n) {
281
+ DETACH DELETE n
282
+ } IN TRANSACTIONS OF $batch_size ROWS
283
+ """,
284
+ uuids=uuids,
285
+ batch_size=batch_size,
286
+ )
131
287
 
132
288
  @classmethod
133
289
  async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
@@ -149,17 +305,23 @@ class EpisodicNode(Node):
149
305
  )
150
306
 
151
307
  async def save(self, driver: GraphDriver):
308
+ if driver.graph_operations_interface:
309
+ return await driver.graph_operations_interface.episodic_node_save(self, driver)
310
+
311
+ episode_args = {
312
+ 'uuid': self.uuid,
313
+ 'name': self.name,
314
+ 'group_id': self.group_id,
315
+ 'source_description': self.source_description,
316
+ 'content': self.content,
317
+ 'entity_edges': self.entity_edges,
318
+ 'created_at': self.created_at,
319
+ 'valid_at': self.valid_at,
320
+ 'source': self.source.value,
321
+ }
322
+
152
323
  result = await driver.execute_query(
153
- EPISODIC_NODE_SAVE,
154
- uuid=self.uuid,
155
- name=self.name,
156
- group_id=self.group_id,
157
- source_description=self.source_description,
158
- content=self.content,
159
- entity_edges=self.entity_edges,
160
- created_at=self.created_at,
161
- valid_at=self.valid_at,
162
- source=self.source.value,
324
+ get_episode_node_save_query(driver.provider), **episode_args
163
325
  )
164
326
 
165
327
  logger.debug(f'Saved Node to Graph: {self.uuid}')
@@ -170,17 +332,14 @@ class EpisodicNode(Node):
170
332
  async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
171
333
  records, _, _ = await driver.execute_query(
172
334
  """
173
- MATCH (e:Episodic {uuid: $uuid})
174
- RETURN e.content AS content,
175
- e.created_at AS created_at,
176
- e.valid_at AS valid_at,
177
- e.uuid AS uuid,
178
- e.name AS name,
179
- e.group_id AS group_id,
180
- e.source_description AS source_description,
181
- e.source AS source,
182
- e.entity_edges AS entity_edges
183
- """,
335
+ MATCH (e:Episodic {uuid: $uuid})
336
+ RETURN
337
+ """
338
+ + (
339
+ EPISODIC_NODE_RETURN_NEPTUNE
340
+ if driver.provider == GraphProvider.NEPTUNE
341
+ else EPISODIC_NODE_RETURN
342
+ ),
184
343
  uuid=uuid,
185
344
  routing_='r',
186
345
  )
@@ -196,18 +355,15 @@ class EpisodicNode(Node):
196
355
  async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
197
356
  records, _, _ = await driver.execute_query(
198
357
  """
199
- MATCH (e:Episodic) WHERE e.uuid IN $uuids
358
+ MATCH (e:Episodic)
359
+ WHERE e.uuid IN $uuids
200
360
  RETURN DISTINCT
201
- e.content AS content,
202
- e.created_at AS created_at,
203
- e.valid_at AS valid_at,
204
- e.uuid AS uuid,
205
- e.name AS name,
206
- e.group_id AS group_id,
207
- e.source_description AS source_description,
208
- e.source AS source,
209
- e.entity_edges AS entity_edges
210
- """,
361
+ """
362
+ + (
363
+ EPISODIC_NODE_RETURN_NEPTUNE
364
+ if driver.provider == GraphProvider.NEPTUNE
365
+ else EPISODIC_NODE_RETURN
366
+ ),
211
367
  uuids=uuids,
212
368
  routing_='r',
213
369
  )
@@ -229,22 +385,21 @@ class EpisodicNode(Node):
229
385
 
230
386
  records, _, _ = await driver.execute_query(
231
387
  """
232
- MATCH (e:Episodic) WHERE e.group_id IN $group_ids
233
- """
388
+ MATCH (e:Episodic)
389
+ WHERE e.group_id IN $group_ids
390
+ """
234
391
  + cursor_query
235
392
  + """
236
393
  RETURN DISTINCT
237
- e.content AS content,
238
- e.created_at AS created_at,
239
- e.valid_at AS valid_at,
240
- e.uuid AS uuid,
241
- e.name AS name,
242
- e.group_id AS group_id,
243
- e.source_description AS source_description,
244
- e.source AS source,
245
- e.entity_edges AS entity_edges
246
- ORDER BY e.uuid DESC
247
- """
394
+ """
395
+ + (
396
+ EPISODIC_NODE_RETURN_NEPTUNE
397
+ if driver.provider == GraphProvider.NEPTUNE
398
+ else EPISODIC_NODE_RETURN
399
+ )
400
+ + """
401
+ ORDER BY uuid DESC
402
+ """
248
403
  + limit_query,
249
404
  group_ids=group_ids,
250
405
  uuid=uuid_cursor,
@@ -260,18 +415,14 @@ class EpisodicNode(Node):
260
415
  async def get_by_entity_node_uuid(cls, driver: GraphDriver, entity_node_uuid: str):
261
416
  records, _, _ = await driver.execute_query(
262
417
  """
263
- MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
418
+ MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
264
419
  RETURN DISTINCT
265
- e.content AS content,
266
- e.created_at AS created_at,
267
- e.valid_at AS valid_at,
268
- e.uuid AS uuid,
269
- e.name AS name,
270
- e.group_id AS group_id,
271
- e.source_description AS source_description,
272
- e.source AS source,
273
- e.entity_edges AS entity_edges
274
- """,
420
+ """
421
+ + (
422
+ EPISODIC_NODE_RETURN_NEPTUNE
423
+ if driver.provider == GraphProvider.NEPTUNE
424
+ else EPISODIC_NODE_RETURN
425
+ ),
275
426
  entity_node_uuid=entity_node_uuid,
276
427
  routing_='r',
277
428
  )
@@ -298,11 +449,25 @@ class EntityNode(Node):
298
449
  return self.name_embedding
299
450
 
300
451
  async def load_name_embedding(self, driver: GraphDriver):
301
- query: LiteralString = """
302
- MATCH (n:Entity {uuid: $uuid})
303
- RETURN n.name_embedding AS name_embedding
304
- """
305
- records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r')
452
+ if driver.graph_operations_interface:
453
+ return await driver.graph_operations_interface.node_load_embeddings(self, driver)
454
+
455
+ if driver.provider == GraphProvider.NEPTUNE:
456
+ query: LiteralString = """
457
+ MATCH (n:Entity {uuid: $uuid})
458
+ RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
459
+ """
460
+
461
+ else:
462
+ query: LiteralString = """
463
+ MATCH (n:Entity {uuid: $uuid})
464
+ RETURN n.name_embedding AS name_embedding
465
+ """
466
+ records, _, _ = await driver.execute_query(
467
+ query,
468
+ uuid=self.uuid,
469
+ routing_='r',
470
+ )
306
471
 
307
472
  if len(records) == 0:
308
473
  raise NodeNotFoundError(self.uuid)
@@ -310,6 +475,9 @@ class EntityNode(Node):
310
475
  self.name_embedding = records[0]['name_embedding']
311
476
 
312
477
  async def save(self, driver: GraphDriver):
478
+ if driver.graph_operations_interface:
479
+ return await driver.graph_operations_interface.node_save(self, driver)
480
+
313
481
  entity_data: dict[str, Any] = {
314
482
  'uuid': self.uuid,
315
483
  'name': self.name,
@@ -319,13 +487,21 @@ class EntityNode(Node):
319
487
  'created_at': self.created_at,
320
488
  }
321
489
 
322
- entity_data.update(self.attributes or {})
323
-
324
- result = await driver.execute_query(
325
- ENTITY_NODE_SAVE,
326
- labels=self.labels + ['Entity'],
327
- entity_data=entity_data,
328
- )
490
+ if driver.provider == GraphProvider.KUZU:
491
+ entity_data['attributes'] = json.dumps(self.attributes)
492
+ entity_data['labels'] = list(set(self.labels + ['Entity']))
493
+ result = await driver.execute_query(
494
+ get_entity_node_save_query(driver.provider, labels=''),
495
+ **entity_data,
496
+ )
497
+ else:
498
+ entity_data.update(self.attributes or {})
499
+ labels = ':'.join(self.labels + ['Entity'])
500
+
501
+ result = await driver.execute_query(
502
+ get_entity_node_save_query(driver.provider, labels),
503
+ entity_data=entity_data,
504
+ )
329
505
 
330
506
  logger.debug(f'Saved Node to Graph: {self.uuid}')
331
507
 
@@ -333,19 +509,17 @@ class EntityNode(Node):
333
509
 
334
510
  @classmethod
335
511
  async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
336
- query = (
337
- """
338
- MATCH (n:Entity {uuid: $uuid})
339
- """
340
- + ENTITY_NODE_RETURN
341
- )
342
512
  records, _, _ = await driver.execute_query(
343
- query,
513
+ """
514
+ MATCH (n:Entity {uuid: $uuid})
515
+ RETURN
516
+ """
517
+ + get_entity_node_return_query(driver.provider),
344
518
  uuid=uuid,
345
519
  routing_='r',
346
520
  )
347
521
 
348
- nodes = [get_entity_node_from_record(record) for record in records]
522
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
349
523
 
350
524
  if len(nodes) == 0:
351
525
  raise NodeNotFoundError(uuid)
@@ -356,14 +530,16 @@ class EntityNode(Node):
356
530
  async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
357
531
  records, _, _ = await driver.execute_query(
358
532
  """
359
- MATCH (n:Entity) WHERE n.uuid IN $uuids
360
- """
361
- + ENTITY_NODE_RETURN,
533
+ MATCH (n:Entity)
534
+ WHERE n.uuid IN $uuids
535
+ RETURN
536
+ """
537
+ + get_entity_node_return_query(driver.provider),
362
538
  uuids=uuids,
363
539
  routing_='r',
364
540
  )
365
541
 
366
- nodes = [get_entity_node_from_record(record) for record in records]
542
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
367
543
 
368
544
  return nodes
369
545
 
@@ -374,19 +550,32 @@ class EntityNode(Node):
374
550
  group_ids: list[str],
375
551
  limit: int | None = None,
376
552
  uuid_cursor: str | None = None,
553
+ with_embeddings: bool = False,
377
554
  ):
378
555
  cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else ''
379
556
  limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
557
+ with_embeddings_query: LiteralString = (
558
+ """,
559
+ n.name_embedding AS name_embedding
560
+ """
561
+ if with_embeddings
562
+ else ''
563
+ )
380
564
 
381
565
  records, _, _ = await driver.execute_query(
382
566
  """
383
- MATCH (n:Entity) WHERE n.group_id IN $group_ids
384
- """
567
+ MATCH (n:Entity)
568
+ WHERE n.group_id IN $group_ids
569
+ """
385
570
  + cursor_query
386
- + ENTITY_NODE_RETURN
387
571
  + """
388
- ORDER BY n.uuid DESC
389
- """
572
+ RETURN
573
+ """
574
+ + get_entity_node_return_query(driver.provider)
575
+ + with_embeddings_query
576
+ + """
577
+ ORDER BY n.uuid DESC
578
+ """
390
579
  + limit_query,
391
580
  group_ids=group_ids,
392
581
  uuid=uuid_cursor,
@@ -394,7 +583,7 @@ class EntityNode(Node):
394
583
  routing_='r',
395
584
  )
396
585
 
397
- nodes = [get_entity_node_from_record(record) for record in records]
586
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
398
587
 
399
588
  return nodes
400
589
 
@@ -404,8 +593,13 @@ class CommunityNode(Node):
404
593
  summary: str = Field(description='region summary of member nodes', default_factory=str)
405
594
 
406
595
  async def save(self, driver: GraphDriver):
596
+ if driver.provider == GraphProvider.NEPTUNE:
597
+ await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
598
+ 'communities',
599
+ [{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
600
+ )
407
601
  result = await driver.execute_query(
408
- COMMUNITY_NODE_SAVE,
602
+ get_community_node_save_query(driver.provider), # type: ignore
409
603
  uuid=self.uuid,
410
604
  name=self.name,
411
605
  group_id=self.group_id,
@@ -428,11 +622,22 @@ class CommunityNode(Node):
428
622
  return self.name_embedding
429
623
 
430
624
  async def load_name_embedding(self, driver: GraphDriver):
431
- query: LiteralString = """
625
+ if driver.provider == GraphProvider.NEPTUNE:
626
+ query: LiteralString = """
627
+ MATCH (c:Community {uuid: $uuid})
628
+ RETURN [x IN split(c.name_embedding, ",") | toFloat(x)] as name_embedding
629
+ """
630
+ else:
631
+ query: LiteralString = """
432
632
  MATCH (c:Community {uuid: $uuid})
433
633
  RETURN c.name_embedding AS name_embedding
434
- """
435
- records, _, _ = await driver.execute_query(query, uuid=self.uuid, routing_='r')
634
+ """
635
+
636
+ records, _, _ = await driver.execute_query(
637
+ query,
638
+ uuid=self.uuid,
639
+ routing_='r',
640
+ )
436
641
 
437
642
  if len(records) == 0:
438
643
  raise NodeNotFoundError(self.uuid)
@@ -443,14 +648,14 @@ class CommunityNode(Node):
443
648
  async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
444
649
  records, _, _ = await driver.execute_query(
445
650
  """
446
- MATCH (n:Community {uuid: $uuid})
447
- RETURN
448
- n.uuid As uuid,
449
- n.name AS name,
450
- n.group_id AS group_id,
451
- n.created_at AS created_at,
452
- n.summary AS summary
453
- """,
651
+ MATCH (c:Community {uuid: $uuid})
652
+ RETURN
653
+ """
654
+ + (
655
+ COMMUNITY_NODE_RETURN_NEPTUNE
656
+ if driver.provider == GraphProvider.NEPTUNE
657
+ else COMMUNITY_NODE_RETURN
658
+ ),
454
659
  uuid=uuid,
455
660
  routing_='r',
456
661
  )
@@ -466,14 +671,15 @@ class CommunityNode(Node):
466
671
  async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
467
672
  records, _, _ = await driver.execute_query(
468
673
  """
469
- MATCH (n:Community) WHERE n.uuid IN $uuids
470
- RETURN
471
- n.uuid As uuid,
472
- n.name AS name,
473
- n.group_id AS group_id,
474
- n.created_at AS created_at,
475
- n.summary AS summary
476
- """,
674
+ MATCH (c:Community)
675
+ WHERE c.uuid IN $uuids
676
+ RETURN
677
+ """
678
+ + (
679
+ COMMUNITY_NODE_RETURN_NEPTUNE
680
+ if driver.provider == GraphProvider.NEPTUNE
681
+ else COMMUNITY_NODE_RETURN
682
+ ),
477
683
  uuids=uuids,
478
684
  routing_='r',
479
685
  )
@@ -490,23 +696,26 @@ class CommunityNode(Node):
490
696
  limit: int | None = None,
491
697
  uuid_cursor: str | None = None,
492
698
  ):
493
- cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else ''
699
+ cursor_query: LiteralString = 'AND c.uuid < $uuid' if uuid_cursor else ''
494
700
  limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
495
701
 
496
702
  records, _, _ = await driver.execute_query(
497
703
  """
498
- MATCH (n:Community) WHERE n.group_id IN $group_ids
499
- """
704
+ MATCH (c:Community)
705
+ WHERE c.group_id IN $group_ids
706
+ """
500
707
  + cursor_query
501
708
  + """
502
- RETURN
503
- n.uuid As uuid,
504
- n.name AS name,
505
- n.group_id AS group_id,
506
- n.created_at AS created_at,
507
- n.summary AS summary
508
- ORDER BY n.uuid DESC
509
- """
709
+ RETURN
710
+ """
711
+ + (
712
+ COMMUNITY_NODE_RETURN_NEPTUNE
713
+ if driver.provider == GraphProvider.NEPTUNE
714
+ else COMMUNITY_NODE_RETURN
715
+ )
716
+ + """
717
+ ORDER BY c.uuid DESC
718
+ """
510
719
  + limit_query,
511
720
  group_ids=group_ids,
512
721
  uuid=uuid_cursor,
@@ -542,24 +751,35 @@ def get_episodic_node_from_record(record: Any) -> EpisodicNode:
542
751
  )
543
752
 
544
753
 
545
- def get_entity_node_from_record(record: Any) -> EntityNode:
754
+ def get_entity_node_from_record(record: Any, provider: GraphProvider) -> EntityNode:
755
+ if provider == GraphProvider.KUZU:
756
+ attributes = json.loads(record['attributes']) if record['attributes'] else {}
757
+ else:
758
+ attributes = record['attributes']
759
+ attributes.pop('uuid', None)
760
+ attributes.pop('name', None)
761
+ attributes.pop('group_id', None)
762
+ attributes.pop('name_embedding', None)
763
+ attributes.pop('summary', None)
764
+ attributes.pop('created_at', None)
765
+ attributes.pop('labels', None)
766
+
767
+ labels = record.get('labels', [])
768
+ group_id = record.get('group_id')
769
+ if 'Entity_' + group_id.replace('-', '') in labels:
770
+ labels.remove('Entity_' + group_id.replace('-', ''))
771
+
546
772
  entity_node = EntityNode(
547
773
  uuid=record['uuid'],
548
774
  name=record['name'],
549
- group_id=record['group_id'],
550
- labels=record['labels'],
775
+ name_embedding=record.get('name_embedding'),
776
+ group_id=group_id,
777
+ labels=labels,
551
778
  created_at=parse_db_date(record['created_at']), # type: ignore
552
779
  summary=record['summary'],
553
- attributes=record['attributes'],
780
+ attributes=attributes,
554
781
  )
555
782
 
556
- entity_node.attributes.pop('uuid', None)
557
- entity_node.attributes.pop('name', None)
558
- entity_node.attributes.pop('group_id', None)
559
- entity_node.attributes.pop('name_embedding', None)
560
- entity_node.attributes.pop('summary', None)
561
- entity_node.attributes.pop('created_at', None)
562
-
563
783
  return entity_node
564
784
 
565
785
 
@@ -575,8 +795,12 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
575
795
 
576
796
 
577
797
  async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]):
578
- if not nodes: # Handle empty list case
798
+ # filter out falsey values from nodes
799
+ filtered_nodes = [node for node in nodes if node.name]
800
+
801
+ if not filtered_nodes:
579
802
  return
580
- name_embeddings = await embedder.create_batch([node.name for node in nodes])
581
- for node, name_embedding in zip(nodes, name_embeddings, strict=True):
803
+
804
+ name_embeddings = await embedder.create_batch([node.name for node in filtered_nodes])
805
+ for node, name_embedding in zip(filtered_nodes, name_embeddings, strict=True):
582
806
  node.name_embedding = name_embedding