graphiti-core 0.12.0rc1__py3-none-any.whl → 0.24.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. graphiti_core/cross_encoder/bge_reranker_client.py +12 -2
  2. graphiti_core/cross_encoder/gemini_reranker_client.py +161 -0
  3. graphiti_core/cross_encoder/openai_reranker_client.py +7 -5
  4. graphiti_core/decorators.py +110 -0
  5. graphiti_core/driver/__init__.py +19 -0
  6. graphiti_core/driver/driver.py +124 -0
  7. graphiti_core/driver/falkordb_driver.py +362 -0
  8. graphiti_core/driver/graph_operations/graph_operations.py +191 -0
  9. graphiti_core/driver/kuzu_driver.py +182 -0
  10. graphiti_core/driver/neo4j_driver.py +117 -0
  11. graphiti_core/driver/neptune_driver.py +305 -0
  12. graphiti_core/driver/search_interface/search_interface.py +89 -0
  13. graphiti_core/edges.py +287 -172
  14. graphiti_core/embedder/azure_openai.py +71 -0
  15. graphiti_core/embedder/client.py +2 -1
  16. graphiti_core/embedder/gemini.py +116 -22
  17. graphiti_core/embedder/voyage.py +13 -2
  18. graphiti_core/errors.py +8 -0
  19. graphiti_core/graph_queries.py +162 -0
  20. graphiti_core/graphiti.py +705 -193
  21. graphiti_core/graphiti_types.py +4 -2
  22. graphiti_core/helpers.py +87 -10
  23. graphiti_core/llm_client/__init__.py +16 -0
  24. graphiti_core/llm_client/anthropic_client.py +159 -56
  25. graphiti_core/llm_client/azure_openai_client.py +115 -0
  26. graphiti_core/llm_client/client.py +98 -21
  27. graphiti_core/llm_client/config.py +1 -1
  28. graphiti_core/llm_client/gemini_client.py +290 -41
  29. graphiti_core/llm_client/groq_client.py +14 -3
  30. graphiti_core/llm_client/openai_base_client.py +261 -0
  31. graphiti_core/llm_client/openai_client.py +56 -132
  32. graphiti_core/llm_client/openai_generic_client.py +91 -56
  33. graphiti_core/models/edges/edge_db_queries.py +259 -35
  34. graphiti_core/models/nodes/node_db_queries.py +311 -32
  35. graphiti_core/nodes.py +420 -205
  36. graphiti_core/prompts/dedupe_edges.py +46 -32
  37. graphiti_core/prompts/dedupe_nodes.py +67 -42
  38. graphiti_core/prompts/eval.py +4 -4
  39. graphiti_core/prompts/extract_edges.py +27 -16
  40. graphiti_core/prompts/extract_nodes.py +74 -31
  41. graphiti_core/prompts/prompt_helpers.py +39 -0
  42. graphiti_core/prompts/snippets.py +29 -0
  43. graphiti_core/prompts/summarize_nodes.py +23 -25
  44. graphiti_core/search/search.py +158 -82
  45. graphiti_core/search/search_config.py +39 -4
  46. graphiti_core/search/search_filters.py +126 -35
  47. graphiti_core/search/search_helpers.py +5 -6
  48. graphiti_core/search/search_utils.py +1405 -485
  49. graphiti_core/telemetry/__init__.py +9 -0
  50. graphiti_core/telemetry/telemetry.py +117 -0
  51. graphiti_core/tracer.py +193 -0
  52. graphiti_core/utils/bulk_utils.py +364 -285
  53. graphiti_core/utils/datetime_utils.py +13 -0
  54. graphiti_core/utils/maintenance/community_operations.py +67 -49
  55. graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
  56. graphiti_core/utils/maintenance/edge_operations.py +339 -197
  57. graphiti_core/utils/maintenance/graph_data_operations.py +50 -114
  58. graphiti_core/utils/maintenance/node_operations.py +319 -238
  59. graphiti_core/utils/maintenance/temporal_operations.py +11 -3
  60. graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
  61. graphiti_core/utils/text_utils.py +53 -0
  62. graphiti_core-0.24.3.dist-info/METADATA +726 -0
  63. graphiti_core-0.24.3.dist-info/RECORD +86 -0
  64. {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
  65. graphiti_core-0.12.0rc1.dist-info/METADATA +0 -350
  66. graphiti_core-0.12.0rc1.dist-info/RECORD +0 -66
  67. /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
  68. {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info/licenses}/LICENSE +0 -0
graphiti_core/nodes.py CHANGED
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
+ import json
17
18
  import logging
18
19
  from abc import ABC, abstractmethod
19
20
  from datetime import datetime
@@ -22,33 +23,30 @@ from time import time
22
23
  from typing import Any
23
24
  from uuid import uuid4
24
25
 
25
- from neo4j import AsyncDriver
26
26
  from pydantic import BaseModel, Field
27
27
  from typing_extensions import LiteralString
28
28
 
29
+ from graphiti_core.driver.driver import (
30
+ GraphDriver,
31
+ GraphProvider,
32
+ )
29
33
  from graphiti_core.embedder import EmbedderClient
30
34
  from graphiti_core.errors import NodeNotFoundError
31
- from graphiti_core.helpers import DEFAULT_DATABASE
35
+ from graphiti_core.helpers import parse_db_date
32
36
  from graphiti_core.models.nodes.node_db_queries import (
33
- COMMUNITY_NODE_SAVE,
34
- ENTITY_NODE_SAVE,
35
- EPISODIC_NODE_SAVE,
37
+ COMMUNITY_NODE_RETURN,
38
+ COMMUNITY_NODE_RETURN_NEPTUNE,
39
+ EPISODIC_NODE_RETURN,
40
+ EPISODIC_NODE_RETURN_NEPTUNE,
41
+ get_community_node_save_query,
42
+ get_entity_node_return_query,
43
+ get_entity_node_save_query,
44
+ get_episode_node_save_query,
36
45
  )
37
46
  from graphiti_core.utils.datetime_utils import utc_now
38
47
 
39
48
  logger = logging.getLogger(__name__)
40
49
 
41
- ENTITY_NODE_RETURN: LiteralString = """
42
- RETURN
43
- n.uuid As uuid,
44
- n.name AS name,
45
- n.group_id AS group_id,
46
- n.created_at AS created_at,
47
- n.summary AS summary,
48
- labels(n) AS labels,
49
- properties(n) AS attributes
50
- """
51
-
52
50
 
53
51
  class EpisodeType(Enum):
54
52
  """
@@ -94,22 +92,63 @@ class Node(BaseModel, ABC):
94
92
  created_at: datetime = Field(default_factory=lambda: utc_now())
95
93
 
96
94
  @abstractmethod
97
- async def save(self, driver: AsyncDriver): ...
98
-
99
- async def delete(self, driver: AsyncDriver):
100
- result = await driver.execute_query(
101
- """
102
- MATCH (n:Entity|Episodic|Community {uuid: $uuid})
103
- DETACH DELETE n
104
- """,
105
- uuid=self.uuid,
106
- database_=DEFAULT_DATABASE,
107
- )
95
+ async def save(self, driver: GraphDriver): ...
96
+
97
+ async def delete(self, driver: GraphDriver):
98
+ if driver.graph_operations_interface:
99
+ return await driver.graph_operations_interface.node_delete(self, driver)
100
+
101
+ match driver.provider:
102
+ case GraphProvider.NEO4J:
103
+ records, _, _ = await driver.execute_query(
104
+ """
105
+ MATCH (n {uuid: $uuid})
106
+ WHERE n:Entity OR n:Episodic OR n:Community
107
+ OPTIONAL MATCH (n)-[r]-()
108
+ WITH collect(r.uuid) AS edge_uuids, n
109
+ DETACH DELETE n
110
+ RETURN edge_uuids
111
+ """,
112
+ uuid=self.uuid,
113
+ )
114
+
115
+ case GraphProvider.KUZU:
116
+ for label in ['Episodic', 'Community']:
117
+ await driver.execute_query(
118
+ f"""
119
+ MATCH (n:{label} {{uuid: $uuid}})
120
+ DETACH DELETE n
121
+ """,
122
+ uuid=self.uuid,
123
+ )
124
+ # Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
125
+ # Explicitly delete the "edge" nodes first, then the entity node.
126
+ await driver.execute_query(
127
+ """
128
+ MATCH (n:Entity {uuid: $uuid})-[:RELATES_TO]->(e:RelatesToNode_)
129
+ DETACH DELETE e
130
+ """,
131
+ uuid=self.uuid,
132
+ )
133
+ await driver.execute_query(
134
+ """
135
+ MATCH (n:Entity {uuid: $uuid})
136
+ DETACH DELETE n
137
+ """,
138
+ uuid=self.uuid,
139
+ )
140
+ case _: # FalkorDB, Neptune
141
+ for label in ['Entity', 'Episodic', 'Community']:
142
+ await driver.execute_query(
143
+ f"""
144
+ MATCH (n:{label} {{uuid: $uuid}})
145
+ DETACH DELETE n
146
+ """,
147
+ uuid=self.uuid,
148
+ )
108
149
 
109
150
  logger.debug(f'Deleted Node: {self.uuid}')
110
151
 
111
- return result
112
-
113
152
  def __hash__(self):
114
153
  return hash(self.uuid)
115
154
 
@@ -119,23 +158,138 @@ class Node(BaseModel, ABC):
119
158
  return False
120
159
 
121
160
  @classmethod
122
- async def delete_by_group_id(cls, driver: AsyncDriver, group_id: str):
123
- await driver.execute_query(
124
- """
125
- MATCH (n:Entity|Episodic|Community {group_id: $group_id})
126
- DETACH DELETE n
127
- """,
128
- group_id=group_id,
129
- database_=DEFAULT_DATABASE,
130
- )
161
+ async def delete_by_group_id(cls, driver: GraphDriver, group_id: str, batch_size: int = 100):
162
+ if driver.graph_operations_interface:
163
+ return await driver.graph_operations_interface.node_delete_by_group_id(
164
+ cls, driver, group_id, batch_size
165
+ )
166
+
167
+ match driver.provider:
168
+ case GraphProvider.NEO4J:
169
+ async with driver.session() as session:
170
+ await session.run(
171
+ """
172
+ MATCH (n:Entity|Episodic|Community {group_id: $group_id})
173
+ CALL (n) {
174
+ DETACH DELETE n
175
+ } IN TRANSACTIONS OF $batch_size ROWS
176
+ """,
177
+ group_id=group_id,
178
+ batch_size=batch_size,
179
+ )
180
+
181
+ case GraphProvider.KUZU:
182
+ for label in ['Episodic', 'Community']:
183
+ await driver.execute_query(
184
+ f"""
185
+ MATCH (n:{label} {{group_id: $group_id}})
186
+ DETACH DELETE n
187
+ """,
188
+ group_id=group_id,
189
+ )
190
+ # Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
191
+ # Explicitly delete the "edge" nodes first, then the entity node.
192
+ await driver.execute_query(
193
+ """
194
+ MATCH (n:Entity {group_id: $group_id})-[:RELATES_TO]->(e:RelatesToNode_)
195
+ DETACH DELETE e
196
+ """,
197
+ group_id=group_id,
198
+ )
199
+ await driver.execute_query(
200
+ """
201
+ MATCH (n:Entity {group_id: $group_id})
202
+ DETACH DELETE n
203
+ """,
204
+ group_id=group_id,
205
+ )
206
+ case _: # FalkorDB, Neptune
207
+ for label in ['Entity', 'Episodic', 'Community']:
208
+ await driver.execute_query(
209
+ f"""
210
+ MATCH (n:{label} {{group_id: $group_id}})
211
+ DETACH DELETE n
212
+ """,
213
+ group_id=group_id,
214
+ )
131
215
 
132
- return 'SUCCESS'
216
+ @classmethod
217
+ async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
218
+ if driver.graph_operations_interface:
219
+ return await driver.graph_operations_interface.node_delete_by_uuids(
220
+ cls, driver, uuids, group_id=None, batch_size=batch_size
221
+ )
222
+
223
+ match driver.provider:
224
+ case GraphProvider.FALKORDB:
225
+ for label in ['Entity', 'Episodic', 'Community']:
226
+ await driver.execute_query(
227
+ f"""
228
+ MATCH (n:{label})
229
+ WHERE n.uuid IN $uuids
230
+ DETACH DELETE n
231
+ """,
232
+ uuids=uuids,
233
+ )
234
+ case GraphProvider.KUZU:
235
+ for label in ['Episodic', 'Community']:
236
+ await driver.execute_query(
237
+ f"""
238
+ MATCH (n:{label})
239
+ WHERE n.uuid IN $uuids
240
+ DETACH DELETE n
241
+ """,
242
+ uuids=uuids,
243
+ )
244
+ # Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
245
+ # Explicitly delete the "edge" nodes first, then the entity node.
246
+ await driver.execute_query(
247
+ """
248
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)
249
+ WHERE n.uuid IN $uuids
250
+ DETACH DELETE e
251
+ """,
252
+ uuids=uuids,
253
+ )
254
+ await driver.execute_query(
255
+ """
256
+ MATCH (n:Entity)
257
+ WHERE n.uuid IN $uuids
258
+ DETACH DELETE n
259
+ """,
260
+ uuids=uuids,
261
+ )
262
+ case _: # Neo4J, Neptune
263
+ async with driver.session() as session:
264
+ # Collect all edge UUIDs before deleting nodes
265
+ await session.run(
266
+ """
267
+ MATCH (n:Entity|Episodic|Community)
268
+ WHERE n.uuid IN $uuids
269
+ MATCH (n)-[r]-()
270
+ RETURN collect(r.uuid) AS edge_uuids
271
+ """,
272
+ uuids=uuids,
273
+ )
274
+
275
+ # Now delete the nodes in batches
276
+ await session.run(
277
+ """
278
+ MATCH (n:Entity|Episodic|Community)
279
+ WHERE n.uuid IN $uuids
280
+ CALL (n) {
281
+ DETACH DELETE n
282
+ } IN TRANSACTIONS OF $batch_size ROWS
283
+ """,
284
+ uuids=uuids,
285
+ batch_size=batch_size,
286
+ )
133
287
 
134
288
  @classmethod
135
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): ...
289
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
136
290
 
137
291
  @classmethod
138
- async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]): ...
292
+ async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): ...
139
293
 
140
294
 
141
295
  class EpisodicNode(Node):
@@ -150,42 +304,43 @@ class EpisodicNode(Node):
150
304
  default_factory=list,
151
305
  )
152
306
 
153
- async def save(self, driver: AsyncDriver):
307
+ async def save(self, driver: GraphDriver):
308
+ if driver.graph_operations_interface:
309
+ return await driver.graph_operations_interface.episodic_node_save(self, driver)
310
+
311
+ episode_args = {
312
+ 'uuid': self.uuid,
313
+ 'name': self.name,
314
+ 'group_id': self.group_id,
315
+ 'source_description': self.source_description,
316
+ 'content': self.content,
317
+ 'entity_edges': self.entity_edges,
318
+ 'created_at': self.created_at,
319
+ 'valid_at': self.valid_at,
320
+ 'source': self.source.value,
321
+ }
322
+
154
323
  result = await driver.execute_query(
155
- EPISODIC_NODE_SAVE,
156
- uuid=self.uuid,
157
- name=self.name,
158
- group_id=self.group_id,
159
- source_description=self.source_description,
160
- content=self.content,
161
- entity_edges=self.entity_edges,
162
- created_at=self.created_at,
163
- valid_at=self.valid_at,
164
- source=self.source.value,
165
- database_=DEFAULT_DATABASE,
324
+ get_episode_node_save_query(driver.provider), **episode_args
166
325
  )
167
326
 
168
- logger.debug(f'Saved Node to neo4j: {self.uuid}')
327
+ logger.debug(f'Saved Node to Graph: {self.uuid}')
169
328
 
170
329
  return result
171
330
 
172
331
  @classmethod
173
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
332
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
174
333
  records, _, _ = await driver.execute_query(
175
334
  """
176
- MATCH (e:Episodic {uuid: $uuid})
177
- RETURN e.content AS content,
178
- e.created_at AS created_at,
179
- e.valid_at AS valid_at,
180
- e.uuid AS uuid,
181
- e.name AS name,
182
- e.group_id AS group_id,
183
- e.source_description AS source_description,
184
- e.source AS source,
185
- e.entity_edges AS entity_edges
186
- """,
335
+ MATCH (e:Episodic {uuid: $uuid})
336
+ RETURN
337
+ """
338
+ + (
339
+ EPISODIC_NODE_RETURN_NEPTUNE
340
+ if driver.provider == GraphProvider.NEPTUNE
341
+ else EPISODIC_NODE_RETURN
342
+ ),
187
343
  uuid=uuid,
188
- database_=DEFAULT_DATABASE,
189
344
  routing_='r',
190
345
  )
191
346
 
@@ -197,23 +352,19 @@ class EpisodicNode(Node):
197
352
  return episodes[0]
198
353
 
199
354
  @classmethod
200
- async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
355
+ async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
201
356
  records, _, _ = await driver.execute_query(
202
357
  """
203
- MATCH (e:Episodic) WHERE e.uuid IN $uuids
358
+ MATCH (e:Episodic)
359
+ WHERE e.uuid IN $uuids
204
360
  RETURN DISTINCT
205
- e.content AS content,
206
- e.created_at AS created_at,
207
- e.valid_at AS valid_at,
208
- e.uuid AS uuid,
209
- e.name AS name,
210
- e.group_id AS group_id,
211
- e.source_description AS source_description,
212
- e.source AS source,
213
- e.entity_edges AS entity_edges
214
- """,
361
+ """
362
+ + (
363
+ EPISODIC_NODE_RETURN_NEPTUNE
364
+ if driver.provider == GraphProvider.NEPTUNE
365
+ else EPISODIC_NODE_RETURN
366
+ ),
215
367
  uuids=uuids,
216
- database_=DEFAULT_DATABASE,
217
368
  routing_='r',
218
369
  )
219
370
 
@@ -224,7 +375,7 @@ class EpisodicNode(Node):
224
375
  @classmethod
225
376
  async def get_by_group_ids(
226
377
  cls,
227
- driver: AsyncDriver,
378
+ driver: GraphDriver,
228
379
  group_ids: list[str],
229
380
  limit: int | None = None,
230
381
  uuid_cursor: str | None = None,
@@ -234,27 +385,25 @@ class EpisodicNode(Node):
234
385
 
235
386
  records, _, _ = await driver.execute_query(
236
387
  """
237
- MATCH (e:Episodic) WHERE e.group_id IN $group_ids
238
- """
388
+ MATCH (e:Episodic)
389
+ WHERE e.group_id IN $group_ids
390
+ """
239
391
  + cursor_query
240
392
  + """
241
393
  RETURN DISTINCT
242
- e.content AS content,
243
- e.created_at AS created_at,
244
- e.valid_at AS valid_at,
245
- e.uuid AS uuid,
246
- e.name AS name,
247
- e.group_id AS group_id,
248
- e.source_description AS source_description,
249
- e.source AS source,
250
- e.entity_edges AS entity_edges
251
- ORDER BY e.uuid DESC
252
- """
394
+ """
395
+ + (
396
+ EPISODIC_NODE_RETURN_NEPTUNE
397
+ if driver.provider == GraphProvider.NEPTUNE
398
+ else EPISODIC_NODE_RETURN
399
+ )
400
+ + """
401
+ ORDER BY uuid DESC
402
+ """
253
403
  + limit_query,
254
404
  group_ids=group_ids,
255
405
  uuid=uuid_cursor,
256
406
  limit=limit,
257
- database_=DEFAULT_DATABASE,
258
407
  routing_='r',
259
408
  )
260
409
 
@@ -263,23 +412,18 @@ class EpisodicNode(Node):
263
412
  return episodes
264
413
 
265
414
  @classmethod
266
- async def get_by_entity_node_uuid(cls, driver: AsyncDriver, entity_node_uuid: str):
415
+ async def get_by_entity_node_uuid(cls, driver: GraphDriver, entity_node_uuid: str):
267
416
  records, _, _ = await driver.execute_query(
268
417
  """
269
- MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
418
+ MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
270
419
  RETURN DISTINCT
271
- e.content AS content,
272
- e.created_at AS created_at,
273
- e.valid_at AS valid_at,
274
- e.uuid AS uuid,
275
- e.name AS name,
276
- e.group_id AS group_id,
277
- e.source_description AS source_description,
278
- e.source AS source,
279
- e.entity_edges AS entity_edges
280
- """,
420
+ """
421
+ + (
422
+ EPISODIC_NODE_RETURN_NEPTUNE
423
+ if driver.provider == GraphProvider.NEPTUNE
424
+ else EPISODIC_NODE_RETURN
425
+ ),
281
426
  entity_node_uuid=entity_node_uuid,
282
- database_=DEFAULT_DATABASE,
283
427
  routing_='r',
284
428
  )
285
429
 
@@ -304,13 +448,25 @@ class EntityNode(Node):
304
448
 
305
449
  return self.name_embedding
306
450
 
307
- async def load_name_embedding(self, driver: AsyncDriver):
308
- query: LiteralString = """
309
- MATCH (n:Entity {uuid: $uuid})
310
- RETURN n.name_embedding AS name_embedding
311
- """
451
+ async def load_name_embedding(self, driver: GraphDriver):
452
+ if driver.graph_operations_interface:
453
+ return await driver.graph_operations_interface.node_load_embeddings(self, driver)
454
+
455
+ if driver.provider == GraphProvider.NEPTUNE:
456
+ query: LiteralString = """
457
+ MATCH (n:Entity {uuid: $uuid})
458
+ RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
459
+ """
460
+
461
+ else:
462
+ query: LiteralString = """
463
+ MATCH (n:Entity {uuid: $uuid})
464
+ RETURN n.name_embedding AS name_embedding
465
+ """
312
466
  records, _, _ = await driver.execute_query(
313
- query, uuid=self.uuid, database_=DEFAULT_DATABASE, routing_='r'
467
+ query,
468
+ uuid=self.uuid,
469
+ routing_='r',
314
470
  )
315
471
 
316
472
  if len(records) == 0:
@@ -318,7 +474,10 @@ class EntityNode(Node):
318
474
 
319
475
  self.name_embedding = records[0]['name_embedding']
320
476
 
321
- async def save(self, driver: AsyncDriver):
477
+ async def save(self, driver: GraphDriver):
478
+ if driver.graph_operations_interface:
479
+ return await driver.graph_operations_interface.node_save(self, driver)
480
+
322
481
  entity_data: dict[str, Any] = {
323
482
  'uuid': self.uuid,
324
483
  'name': self.name,
@@ -328,35 +487,39 @@ class EntityNode(Node):
328
487
  'created_at': self.created_at,
329
488
  }
330
489
 
331
- entity_data.update(self.attributes or {})
490
+ if driver.provider == GraphProvider.KUZU:
491
+ entity_data['attributes'] = json.dumps(self.attributes)
492
+ entity_data['labels'] = list(set(self.labels + ['Entity']))
493
+ result = await driver.execute_query(
494
+ get_entity_node_save_query(driver.provider, labels=''),
495
+ **entity_data,
496
+ )
497
+ else:
498
+ entity_data.update(self.attributes or {})
499
+ labels = ':'.join(self.labels + ['Entity'])
332
500
 
333
- result = await driver.execute_query(
334
- ENTITY_NODE_SAVE,
335
- labels=self.labels + ['Entity'],
336
- entity_data=entity_data,
337
- database_=DEFAULT_DATABASE,
338
- )
501
+ result = await driver.execute_query(
502
+ get_entity_node_save_query(driver.provider, labels),
503
+ entity_data=entity_data,
504
+ )
339
505
 
340
- logger.debug(f'Saved Node to neo4j: {self.uuid}')
506
+ logger.debug(f'Saved Node to Graph: {self.uuid}')
341
507
 
342
508
  return result
343
509
 
344
510
  @classmethod
345
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
346
- query = (
347
- """
348
- MATCH (n:Entity {uuid: $uuid})
349
- """
350
- + ENTITY_NODE_RETURN
351
- )
511
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
352
512
  records, _, _ = await driver.execute_query(
353
- query,
513
+ """
514
+ MATCH (n:Entity {uuid: $uuid})
515
+ RETURN
516
+ """
517
+ + get_entity_node_return_query(driver.provider),
354
518
  uuid=uuid,
355
- database_=DEFAULT_DATABASE,
356
519
  routing_='r',
357
520
  )
358
521
 
359
- nodes = [get_entity_node_from_record(record) for record in records]
522
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
360
523
 
361
524
  if len(nodes) == 0:
362
525
  raise NodeNotFoundError(uuid)
@@ -364,50 +527,63 @@ class EntityNode(Node):
364
527
  return nodes[0]
365
528
 
366
529
  @classmethod
367
- async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
530
+ async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
368
531
  records, _, _ = await driver.execute_query(
369
532
  """
370
- MATCH (n:Entity) WHERE n.uuid IN $uuids
371
- """
372
- + ENTITY_NODE_RETURN,
533
+ MATCH (n:Entity)
534
+ WHERE n.uuid IN $uuids
535
+ RETURN
536
+ """
537
+ + get_entity_node_return_query(driver.provider),
373
538
  uuids=uuids,
374
- database_=DEFAULT_DATABASE,
375
539
  routing_='r',
376
540
  )
377
541
 
378
- nodes = [get_entity_node_from_record(record) for record in records]
542
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
379
543
 
380
544
  return nodes
381
545
 
382
546
  @classmethod
383
547
  async def get_by_group_ids(
384
548
  cls,
385
- driver: AsyncDriver,
549
+ driver: GraphDriver,
386
550
  group_ids: list[str],
387
551
  limit: int | None = None,
388
552
  uuid_cursor: str | None = None,
553
+ with_embeddings: bool = False,
389
554
  ):
390
555
  cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else ''
391
556
  limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
557
+ with_embeddings_query: LiteralString = (
558
+ """,
559
+ n.name_embedding AS name_embedding
560
+ """
561
+ if with_embeddings
562
+ else ''
563
+ )
392
564
 
393
565
  records, _, _ = await driver.execute_query(
394
566
  """
395
- MATCH (n:Entity) WHERE n.group_id IN $group_ids
396
- """
567
+ MATCH (n:Entity)
568
+ WHERE n.group_id IN $group_ids
569
+ """
397
570
  + cursor_query
398
- + ENTITY_NODE_RETURN
399
571
  + """
400
- ORDER BY n.uuid DESC
401
- """
572
+ RETURN
573
+ """
574
+ + get_entity_node_return_query(driver.provider)
575
+ + with_embeddings_query
576
+ + """
577
+ ORDER BY n.uuid DESC
578
+ """
402
579
  + limit_query,
403
580
  group_ids=group_ids,
404
581
  uuid=uuid_cursor,
405
582
  limit=limit,
406
- database_=DEFAULT_DATABASE,
407
583
  routing_='r',
408
584
  )
409
585
 
410
- nodes = [get_entity_node_from_record(record) for record in records]
586
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
411
587
 
412
588
  return nodes
413
589
 
@@ -416,19 +592,23 @@ class CommunityNode(Node):
416
592
  name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
417
593
  summary: str = Field(description='region summary of member nodes', default_factory=str)
418
594
 
419
- async def save(self, driver: AsyncDriver):
595
+ async def save(self, driver: GraphDriver):
596
+ if driver.provider == GraphProvider.NEPTUNE:
597
+ await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
598
+ 'communities',
599
+ [{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
600
+ )
420
601
  result = await driver.execute_query(
421
- COMMUNITY_NODE_SAVE,
602
+ get_community_node_save_query(driver.provider), # type: ignore
422
603
  uuid=self.uuid,
423
604
  name=self.name,
424
605
  group_id=self.group_id,
425
606
  summary=self.summary,
426
607
  name_embedding=self.name_embedding,
427
608
  created_at=self.created_at,
428
- database_=DEFAULT_DATABASE,
429
609
  )
430
610
 
431
- logger.debug(f'Saved Node to neo4j: {self.uuid}')
611
+ logger.debug(f'Saved Node to Graph: {self.uuid}')
432
612
 
433
613
  return result
434
614
 
@@ -441,13 +621,22 @@ class CommunityNode(Node):
441
621
 
442
622
  return self.name_embedding
443
623
 
444
- async def load_name_embedding(self, driver: AsyncDriver):
445
- query: LiteralString = """
624
+ async def load_name_embedding(self, driver: GraphDriver):
625
+ if driver.provider == GraphProvider.NEPTUNE:
626
+ query: LiteralString = """
627
+ MATCH (c:Community {uuid: $uuid})
628
+ RETURN [x IN split(c.name_embedding, ",") | toFloat(x)] as name_embedding
629
+ """
630
+ else:
631
+ query: LiteralString = """
446
632
  MATCH (c:Community {uuid: $uuid})
447
633
  RETURN c.name_embedding AS name_embedding
448
- """
634
+ """
635
+
449
636
  records, _, _ = await driver.execute_query(
450
- query, uuid=self.uuid, database_=DEFAULT_DATABASE, routing_='r'
637
+ query,
638
+ uuid=self.uuid,
639
+ routing_='r',
451
640
  )
452
641
 
453
642
  if len(records) == 0:
@@ -456,19 +645,18 @@ class CommunityNode(Node):
456
645
  self.name_embedding = records[0]['name_embedding']
457
646
 
458
647
  @classmethod
459
- async def get_by_uuid(cls, driver: AsyncDriver, uuid: str):
648
+ async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
460
649
  records, _, _ = await driver.execute_query(
461
650
  """
462
- MATCH (n:Community {uuid: $uuid})
463
- RETURN
464
- n.uuid As uuid,
465
- n.name AS name,
466
- n.group_id AS group_id,
467
- n.created_at AS created_at,
468
- n.summary AS summary
469
- """,
651
+ MATCH (c:Community {uuid: $uuid})
652
+ RETURN
653
+ """
654
+ + (
655
+ COMMUNITY_NODE_RETURN_NEPTUNE
656
+ if driver.provider == GraphProvider.NEPTUNE
657
+ else COMMUNITY_NODE_RETURN
658
+ ),
470
659
  uuid=uuid,
471
- database_=DEFAULT_DATABASE,
472
660
  routing_='r',
473
661
  )
474
662
 
@@ -480,19 +668,19 @@ class CommunityNode(Node):
480
668
  return nodes[0]
481
669
 
482
670
  @classmethod
483
- async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
671
+ async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
484
672
  records, _, _ = await driver.execute_query(
485
673
  """
486
- MATCH (n:Community) WHERE n.uuid IN $uuids
487
- RETURN
488
- n.uuid As uuid,
489
- n.name AS name,
490
- n.group_id AS group_id,
491
- n.created_at AS created_at,
492
- n.summary AS summary
493
- """,
674
+ MATCH (c:Community)
675
+ WHERE c.uuid IN $uuids
676
+ RETURN
677
+ """
678
+ + (
679
+ COMMUNITY_NODE_RETURN_NEPTUNE
680
+ if driver.provider == GraphProvider.NEPTUNE
681
+ else COMMUNITY_NODE_RETURN
682
+ ),
494
683
  uuids=uuids,
495
- database_=DEFAULT_DATABASE,
496
684
  routing_='r',
497
685
  )
498
686
 
@@ -503,33 +691,35 @@ class CommunityNode(Node):
503
691
  @classmethod
504
692
  async def get_by_group_ids(
505
693
  cls,
506
- driver: AsyncDriver,
694
+ driver: GraphDriver,
507
695
  group_ids: list[str],
508
696
  limit: int | None = None,
509
697
  uuid_cursor: str | None = None,
510
698
  ):
511
- cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else ''
699
+ cursor_query: LiteralString = 'AND c.uuid < $uuid' if uuid_cursor else ''
512
700
  limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
513
701
 
514
702
  records, _, _ = await driver.execute_query(
515
703
  """
516
- MATCH (n:Community) WHERE n.group_id IN $group_ids
517
- """
704
+ MATCH (c:Community)
705
+ WHERE c.group_id IN $group_ids
706
+ """
518
707
  + cursor_query
519
708
  + """
520
- RETURN
521
- n.uuid As uuid,
522
- n.name AS name,
523
- n.group_id AS group_id,
524
- n.created_at AS created_at,
525
- n.summary AS summary
526
- ORDER BY n.uuid DESC
527
- """
709
+ RETURN
710
+ """
711
+ + (
712
+ COMMUNITY_NODE_RETURN_NEPTUNE
713
+ if driver.provider == GraphProvider.NEPTUNE
714
+ else COMMUNITY_NODE_RETURN
715
+ )
716
+ + """
717
+ ORDER BY c.uuid DESC
718
+ """
528
719
  + limit_query,
529
720
  group_ids=group_ids,
530
721
  uuid=uuid_cursor,
531
722
  limit=limit,
532
- database_=DEFAULT_DATABASE,
533
723
  routing_='r',
534
724
  )
535
725
 
@@ -540,10 +730,18 @@ class CommunityNode(Node):
540
730
 
541
731
  # Node helpers
542
732
  def get_episodic_node_from_record(record: Any) -> EpisodicNode:
733
+ created_at = parse_db_date(record['created_at'])
734
+ valid_at = parse_db_date(record['valid_at'])
735
+
736
+ if created_at is None:
737
+ raise ValueError(f'created_at cannot be None for episode {record.get("uuid", "unknown")}')
738
+ if valid_at is None:
739
+ raise ValueError(f'valid_at cannot be None for episode {record.get("uuid", "unknown")}')
740
+
543
741
  return EpisodicNode(
544
742
  content=record['content'],
545
- created_at=record['created_at'].to_native().timestamp(),
546
- valid_at=(record['valid_at'].to_native()),
743
+ created_at=created_at,
744
+ valid_at=valid_at,
547
745
  uuid=record['uuid'],
548
746
  group_id=record['group_id'],
549
747
  source=EpisodeType.from_str(record['source']),
@@ -553,24 +751,35 @@ def get_episodic_node_from_record(record: Any) -> EpisodicNode:
553
751
  )
554
752
 
555
753
 
556
- def get_entity_node_from_record(record: Any) -> EntityNode:
754
+ def get_entity_node_from_record(record: Any, provider: GraphProvider) -> EntityNode:
755
+ if provider == GraphProvider.KUZU:
756
+ attributes = json.loads(record['attributes']) if record['attributes'] else {}
757
+ else:
758
+ attributes = record['attributes']
759
+ attributes.pop('uuid', None)
760
+ attributes.pop('name', None)
761
+ attributes.pop('group_id', None)
762
+ attributes.pop('name_embedding', None)
763
+ attributes.pop('summary', None)
764
+ attributes.pop('created_at', None)
765
+ attributes.pop('labels', None)
766
+
767
+ labels = record.get('labels', [])
768
+ group_id = record.get('group_id')
769
+ if 'Entity_' + group_id.replace('-', '') in labels:
770
+ labels.remove('Entity_' + group_id.replace('-', ''))
771
+
557
772
  entity_node = EntityNode(
558
773
  uuid=record['uuid'],
559
774
  name=record['name'],
560
- group_id=record['group_id'],
561
- labels=record['labels'],
562
- created_at=record['created_at'].to_native(),
775
+ name_embedding=record.get('name_embedding'),
776
+ group_id=group_id,
777
+ labels=labels,
778
+ created_at=parse_db_date(record['created_at']), # type: ignore
563
779
  summary=record['summary'],
564
- attributes=record['attributes'],
780
+ attributes=attributes,
565
781
  )
566
782
 
567
- entity_node.attributes.pop('uuid', None)
568
- entity_node.attributes.pop('name', None)
569
- entity_node.attributes.pop('group_id', None)
570
- entity_node.attributes.pop('name_embedding', None)
571
- entity_node.attributes.pop('summary', None)
572
- entity_node.attributes.pop('created_at', None)
573
-
574
783
  return entity_node
575
784
 
576
785
 
@@ -580,12 +789,18 @@ def get_community_node_from_record(record: Any) -> CommunityNode:
580
789
  name=record['name'],
581
790
  group_id=record['group_id'],
582
791
  name_embedding=record['name_embedding'],
583
- created_at=record['created_at'].to_native(),
792
+ created_at=parse_db_date(record['created_at']), # type: ignore
584
793
  summary=record['summary'],
585
794
  )
586
795
 
587
796
 
588
797
  async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]):
589
- name_embeddings = await embedder.create_batch([node.name for node in nodes])
590
- for node, name_embedding in zip(nodes, name_embeddings, strict=True):
798
+ # filter out falsey values from nodes
799
+ filtered_nodes = [node for node in nodes if node.name]
800
+
801
+ if not filtered_nodes:
802
+ return
803
+
804
+ name_embeddings = await embedder.create_batch([node.name for node in filtered_nodes])
805
+ for node, name_embedding in zip(filtered_nodes, name_embeddings, strict=True):
591
806
  node.name_embedding = name_embedding