graphiti-core 0.19.0rc3__py3-none-any.whl → 0.20.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic; see the registry's linked advisory for more details.

graphiti_core/nodes.py CHANGED
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
+ import json
17
18
  import logging
18
19
  from abc import ABC, abstractmethod
19
20
  from datetime import datetime
@@ -32,10 +33,10 @@ from graphiti_core.helpers import parse_db_date
32
33
  from graphiti_core.models.nodes.node_db_queries import (
33
34
  COMMUNITY_NODE_RETURN,
34
35
  COMMUNITY_NODE_RETURN_NEPTUNE,
35
- ENTITY_NODE_RETURN,
36
36
  EPISODIC_NODE_RETURN,
37
37
  EPISODIC_NODE_RETURN_NEPTUNE,
38
38
  get_community_node_save_query,
39
+ get_entity_node_return_query,
39
40
  get_entity_node_save_query,
40
41
  get_episode_node_save_query,
41
42
  )
@@ -95,12 +96,37 @@ class Node(BaseModel, ABC):
95
96
  case GraphProvider.NEO4J:
96
97
  await driver.execute_query(
97
98
  """
98
- MATCH (n:Entity|Episodic|Community {uuid: $uuid})
99
- DETACH DELETE n
100
- """,
99
+ MATCH (n:Entity|Episodic|Community {uuid: $uuid})
100
+ DETACH DELETE n
101
+ """,
102
+ uuid=self.uuid,
103
+ )
104
+ case GraphProvider.KUZU:
105
+ for label in ['Episodic', 'Community']:
106
+ await driver.execute_query(
107
+ f"""
108
+ MATCH (n:{label} {{uuid: $uuid}})
109
+ DETACH DELETE n
110
+ """,
111
+ uuid=self.uuid,
112
+ )
113
+ # Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
114
+ # Explicitly delete the "edge" nodes first, then the entity node.
115
+ await driver.execute_query(
116
+ """
117
+ MATCH (n:Entity {uuid: $uuid})-[:RELATES_TO]->(e:RelatesToNode_)
118
+ DETACH DELETE e
119
+ """,
101
120
  uuid=self.uuid,
102
121
  )
103
- case _: # FalkorDB and Neptune
122
+ await driver.execute_query(
123
+ """
124
+ MATCH (n:Entity {uuid: $uuid})
125
+ DETACH DELETE n
126
+ """,
127
+ uuid=self.uuid,
128
+ )
129
+ case _: # FalkorDB, Neptune
104
130
  for label in ['Entity', 'Episodic', 'Community']:
105
131
  await driver.execute_query(
106
132
  f"""
@@ -136,8 +162,32 @@ class Node(BaseModel, ABC):
136
162
  group_id=group_id,
137
163
  batch_size=batch_size,
138
164
  )
139
-
140
- case _: # FalkorDB and Neptune
165
+ case GraphProvider.KUZU:
166
+ for label in ['Episodic', 'Community']:
167
+ await driver.execute_query(
168
+ f"""
169
+ MATCH (n:{label} {{group_id: $group_id}})
170
+ DETACH DELETE n
171
+ """,
172
+ group_id=group_id,
173
+ )
174
+ # Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
175
+ # Explicitly delete the "edge" nodes first, then the entity node.
176
+ await driver.execute_query(
177
+ """
178
+ MATCH (n:Entity {group_id: $group_id})-[:RELATES_TO]->(e:RelatesToNode_)
179
+ DETACH DELETE e
180
+ """,
181
+ group_id=group_id,
182
+ )
183
+ await driver.execute_query(
184
+ """
185
+ MATCH (n:Entity {group_id: $group_id})
186
+ DETACH DELETE n
187
+ """,
188
+ group_id=group_id,
189
+ )
190
+ case _: # FalkorDB, Neptune
141
191
  for label in ['Entity', 'Episodic', 'Community']:
142
192
  await driver.execute_query(
143
193
  f"""
@@ -149,30 +199,59 @@ class Node(BaseModel, ABC):
149
199
 
150
200
  @classmethod
151
201
  async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
152
- if driver.provider == GraphProvider.FALKORDB:
153
- for label in ['Entity', 'Episodic', 'Community']:
202
+ match driver.provider:
203
+ case GraphProvider.FALKORDB:
204
+ for label in ['Entity', 'Episodic', 'Community']:
205
+ await driver.execute_query(
206
+ f"""
207
+ MATCH (n:{label})
208
+ WHERE n.uuid IN $uuids
209
+ DETACH DELETE n
210
+ """,
211
+ uuids=uuids,
212
+ )
213
+ case GraphProvider.KUZU:
214
+ for label in ['Episodic', 'Community']:
215
+ await driver.execute_query(
216
+ f"""
217
+ MATCH (n:{label})
218
+ WHERE n.uuid IN $uuids
219
+ DETACH DELETE n
220
+ """,
221
+ uuids=uuids,
222
+ )
223
+ # Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
224
+ # Explicitly delete the "edge" nodes first, then the entity node.
154
225
  await driver.execute_query(
155
- f"""
156
- MATCH (n:{label})
157
- WHERE n.uuid IN $uuids
158
- DETACH DELETE n
159
- """,
226
+ """
227
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)
228
+ WHERE n.uuid IN $uuids
229
+ DETACH DELETE e
230
+ """,
160
231
  uuids=uuids,
161
232
  )
162
- else:
163
- async with driver.session() as session:
164
- await session.run(
233
+ await driver.execute_query(
165
234
  """
166
- MATCH (n:Entity|Episodic|Community)
235
+ MATCH (n:Entity)
167
236
  WHERE n.uuid IN $uuids
168
- CALL {
169
- WITH n
170
- DETACH DELETE n
171
- } IN TRANSACTIONS OF $batch_size ROWS
237
+ DETACH DELETE n
172
238
  """,
173
239
  uuids=uuids,
174
- batch_size=batch_size,
175
240
  )
241
+ case _: # Neo4J, Neptune
242
+ async with driver.session() as session:
243
+ await session.run(
244
+ """
245
+ MATCH (n:Entity|Episodic|Community)
246
+ WHERE n.uuid IN $uuids
247
+ CALL {
248
+ WITH n
249
+ DETACH DELETE n
250
+ } IN TRANSACTIONS OF $batch_size ROWS
251
+ """,
252
+ uuids=uuids,
253
+ batch_size=batch_size,
254
+ )
176
255
 
177
256
  @classmethod
178
257
  async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
@@ -207,18 +286,24 @@ class EpisodicNode(Node):
207
286
  }
208
287
  ],
209
288
  )
289
+
290
+ episode_args = {
291
+ 'uuid': self.uuid,
292
+ 'name': self.name,
293
+ 'group_id': self.group_id,
294
+ 'source_description': self.source_description,
295
+ 'content': self.content,
296
+ 'entity_edges': self.entity_edges,
297
+ 'created_at': self.created_at,
298
+ 'valid_at': self.valid_at,
299
+ 'source': self.source.value,
300
+ }
301
+
302
+ if driver.provider == GraphProvider.NEO4J:
303
+ episode_args['group_label'] = 'Episodic_' + self.group_id.replace('-', '')
304
+
210
305
  result = await driver.execute_query(
211
- get_episode_node_save_query(driver.provider),
212
- uuid=self.uuid,
213
- name=self.name,
214
- group_id=self.group_id,
215
- group_label='Episodic_' + self.group_id.replace('-', ''),
216
- source_description=self.source_description,
217
- content=self.content,
218
- entity_edges=self.entity_edges,
219
- created_at=self.created_at,
220
- valid_at=self.valid_at,
221
- source=self.source.value,
306
+ get_episode_node_save_query(driver.provider), **episode_args
222
307
  )
223
308
 
224
309
  logger.debug(f'Saved Node to Graph: {self.uuid}')
@@ -376,17 +461,25 @@ class EntityNode(Node):
376
461
  'summary': self.summary,
377
462
  'created_at': self.created_at,
378
463
  }
379
- entity_data.update(self.attributes or {})
380
464
 
381
- if driver.provider == GraphProvider.NEPTUNE:
382
- driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue
465
+ if driver.provider == GraphProvider.KUZU:
466
+ entity_data['attributes'] = json.dumps(self.attributes)
467
+ entity_data['labels'] = list(set(self.labels + ['Entity']))
468
+ result = await driver.execute_query(
469
+ get_entity_node_save_query(driver.provider, labels=''),
470
+ **entity_data,
471
+ )
472
+ else:
473
+ entity_data.update(self.attributes or {})
474
+ labels = ':'.join(self.labels + ['Entity', 'Entity_' + self.group_id.replace('-', '')])
383
475
 
384
- labels = ':'.join(self.labels + ['Entity', 'Entity_' + self.group_id.replace('-', '')])
476
+ if driver.provider == GraphProvider.NEPTUNE:
477
+ driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue
385
478
 
386
- result = await driver.execute_query(
387
- get_entity_node_save_query(driver.provider, labels),
388
- entity_data=entity_data,
389
- )
479
+ result = await driver.execute_query(
480
+ get_entity_node_save_query(driver.provider, labels),
481
+ entity_data=entity_data,
482
+ )
390
483
 
391
484
  logger.debug(f'Saved Node to Graph: {self.uuid}')
392
485
 
@@ -399,12 +492,12 @@ class EntityNode(Node):
399
492
  MATCH (n:Entity {uuid: $uuid})
400
493
  RETURN
401
494
  """
402
- + ENTITY_NODE_RETURN,
495
+ + get_entity_node_return_query(driver.provider),
403
496
  uuid=uuid,
404
497
  routing_='r',
405
498
  )
406
499
 
407
- nodes = [get_entity_node_from_record(record) for record in records]
500
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
408
501
 
409
502
  if len(nodes) == 0:
410
503
  raise NodeNotFoundError(uuid)
@@ -419,12 +512,12 @@ class EntityNode(Node):
419
512
  WHERE n.uuid IN $uuids
420
513
  RETURN
421
514
  """
422
- + ENTITY_NODE_RETURN,
515
+ + get_entity_node_return_query(driver.provider),
423
516
  uuids=uuids,
424
517
  routing_='r',
425
518
  )
426
519
 
427
- nodes = [get_entity_node_from_record(record) for record in records]
520
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
428
521
 
429
522
  return nodes
430
523
 
@@ -456,7 +549,7 @@ class EntityNode(Node):
456
549
  + """
457
550
  RETURN
458
551
  """
459
- + ENTITY_NODE_RETURN
552
+ + get_entity_node_return_query(driver.provider)
460
553
  + with_embeddings_query
461
554
  + """
462
555
  ORDER BY n.uuid DESC
@@ -468,7 +561,7 @@ class EntityNode(Node):
468
561
  routing_='r',
469
562
  )
470
563
 
471
- nodes = [get_entity_node_from_record(record) for record in records]
564
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
472
565
 
473
566
  return nodes
474
567
 
@@ -533,7 +626,7 @@ class CommunityNode(Node):
533
626
  async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
534
627
  records, _, _ = await driver.execute_query(
535
628
  """
536
- MATCH (n:Community {uuid: $uuid})
629
+ MATCH (c:Community {uuid: $uuid})
537
630
  RETURN
538
631
  """
539
632
  + (
@@ -556,8 +649,8 @@ class CommunityNode(Node):
556
649
  async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
557
650
  records, _, _ = await driver.execute_query(
558
651
  """
559
- MATCH (n:Community)
560
- WHERE n.uuid IN $uuids
652
+ MATCH (c:Community)
653
+ WHERE c.uuid IN $uuids
561
654
  RETURN
562
655
  """
563
656
  + (
@@ -581,13 +674,13 @@ class CommunityNode(Node):
581
674
  limit: int | None = None,
582
675
  uuid_cursor: str | None = None,
583
676
  ):
584
- cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else ''
677
+ cursor_query: LiteralString = 'AND c.uuid < $uuid' if uuid_cursor else ''
585
678
  limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
586
679
 
587
680
  records, _, _ = await driver.execute_query(
588
681
  """
589
- MATCH (n:Community)
590
- WHERE n.group_id IN $group_ids
682
+ MATCH (c:Community)
683
+ WHERE c.group_id IN $group_ids
591
684
  """
592
685
  + cursor_query
593
686
  + """
@@ -599,7 +692,7 @@ class CommunityNode(Node):
599
692
  else COMMUNITY_NODE_RETURN
600
693
  )
601
694
  + """
602
- ORDER BY n.uuid DESC
695
+ ORDER BY c.uuid DESC
603
696
  """
604
697
  + limit_query,
605
698
  group_ids=group_ids,
@@ -636,25 +729,35 @@ def get_episodic_node_from_record(record: Any) -> EpisodicNode:
636
729
  )
637
730
 
638
731
 
639
- def get_entity_node_from_record(record: Any) -> EntityNode:
732
+ def get_entity_node_from_record(record: Any, provider: GraphProvider) -> EntityNode:
733
+ if provider == GraphProvider.KUZU:
734
+ attributes = json.loads(record['attributes']) if record['attributes'] else {}
735
+ else:
736
+ attributes = record['attributes']
737
+ attributes.pop('uuid', None)
738
+ attributes.pop('name', None)
739
+ attributes.pop('group_id', None)
740
+ attributes.pop('name_embedding', None)
741
+ attributes.pop('summary', None)
742
+ attributes.pop('created_at', None)
743
+ attributes.pop('labels', None)
744
+
745
+ labels = record.get('labels', [])
746
+ group_id = record.get('group_id')
747
+ if 'Entity_' + group_id.replace('-', '') in labels:
748
+ labels.remove('Entity_' + group_id.replace('-', ''))
749
+
640
750
  entity_node = EntityNode(
641
751
  uuid=record['uuid'],
642
752
  name=record['name'],
643
753
  name_embedding=record.get('name_embedding'),
644
- group_id=record['group_id'],
645
- labels=record['labels'],
754
+ group_id=group_id,
755
+ labels=labels,
646
756
  created_at=parse_db_date(record['created_at']), # type: ignore
647
757
  summary=record['summary'],
648
- attributes=record['attributes'],
758
+ attributes=attributes,
649
759
  )
650
760
 
651
- entity_node.attributes.pop('uuid', None)
652
- entity_node.attributes.pop('name', None)
653
- entity_node.attributes.pop('group_id', None)
654
- entity_node.attributes.pop('name_embedding', None)
655
- entity_node.attributes.pop('summary', None)
656
- entity_node.attributes.pop('created_at', None)
657
-
658
761
  return entity_node
659
762
 
660
763
 
@@ -325,12 +325,20 @@ async def node_search(
325
325
  search_tasks = []
326
326
  if NodeSearchMethod.bm25 in config.search_methods:
327
327
  search_tasks.append(
328
- node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
328
+ node_fulltext_search(
329
+ driver, query, search_filter, group_ids, 2 * limit, config.use_local_indexes
330
+ )
329
331
  )
330
332
  if NodeSearchMethod.cosine_similarity in config.search_methods:
331
333
  search_tasks.append(
332
334
  node_similarity_search(
333
- driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score
335
+ driver,
336
+ query_vector,
337
+ search_filter,
338
+ group_ids,
339
+ 2 * limit,
340
+ config.sim_min_score,
341
+ config.use_local_indexes,
334
342
  )
335
343
  )
336
344
  if NodeSearchMethod.bfs in config.search_methods:
@@ -426,7 +434,9 @@ async def episode_search(
426
434
  search_results: list[list[EpisodicNode]] = list(
427
435
  await semaphore_gather(
428
436
  *[
429
- episode_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
437
+ episode_fulltext_search(
438
+ driver, query, search_filter, group_ids, 2 * limit, config.use_local_indexes
439
+ ),
430
440
  ]
431
441
  )
432
442
  )
@@ -24,6 +24,7 @@ from graphiti_core.search.search_utils import (
24
24
  DEFAULT_MIN_SCORE,
25
25
  DEFAULT_MMR_LAMBDA,
26
26
  MAX_SEARCH_DEPTH,
27
+ USE_HNSW,
27
28
  )
28
29
 
29
30
  DEFAULT_SEARCH_LIMIT = 10
@@ -91,6 +92,7 @@ class NodeSearchConfig(BaseModel):
91
92
  sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
92
93
  mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
93
94
  bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
95
+ use_local_indexes: bool = Field(default=USE_HNSW)
94
96
 
95
97
 
96
98
  class EpisodeSearchConfig(BaseModel):
@@ -99,6 +101,7 @@ class EpisodeSearchConfig(BaseModel):
99
101
  sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
100
102
  mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
101
103
  bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
104
+ use_local_indexes: bool = Field(default=USE_HNSW)
102
105
 
103
106
 
104
107
  class CommunitySearchConfig(BaseModel):
@@ -107,6 +110,7 @@ class CommunitySearchConfig(BaseModel):
107
110
  sim_min_score: float = Field(default=DEFAULT_MIN_SCORE)
108
111
  mmr_lambda: float = Field(default=DEFAULT_MMR_LAMBDA)
109
112
  bfs_max_depth: int = Field(default=MAX_SEARCH_DEPTH)
113
+ use_local_indexes: bool = Field(default=USE_HNSW)
110
114
 
111
115
 
112
116
  class SearchConfig(BaseModel):
@@ -20,6 +20,8 @@ from typing import Any
20
20
 
21
21
  from pydantic import BaseModel, Field
22
22
 
23
+ from graphiti_core.driver.driver import GraphProvider
24
+
23
25
 
24
26
  class ComparisonOperator(Enum):
25
27
  equals = '='
@@ -54,16 +56,21 @@ class SearchFilters(BaseModel):
54
56
 
55
57
  def node_search_filter_query_constructor(
56
58
  filters: SearchFilters,
57
- ) -> tuple[str, dict[str, Any]]:
58
- filter_query: str = ''
59
+ provider: GraphProvider,
60
+ ) -> tuple[list[str], dict[str, Any]]:
61
+ filter_queries: list[str] = []
59
62
  filter_params: dict[str, Any] = {}
60
63
 
61
64
  if filters.node_labels is not None:
62
- node_labels = '|'.join(filters.node_labels)
63
- node_label_filter = ' AND n:' + node_labels
64
- filter_query += node_label_filter
65
+ if provider == GraphProvider.KUZU:
66
+ node_label_filter = 'list_has_all(n.labels, $labels)'
67
+ filter_params['labels'] = filters.node_labels
68
+ else:
69
+ node_labels = '|'.join(filters.node_labels)
70
+ node_label_filter = 'n:' + node_labels
71
+ filter_queries.append(node_label_filter)
65
72
 
66
- return filter_query, filter_params
73
+ return filter_queries, filter_params
67
74
 
68
75
 
69
76
  def date_filter_query_constructor(
@@ -81,23 +88,29 @@ def date_filter_query_constructor(
81
88
 
82
89
  def edge_search_filter_query_constructor(
83
90
  filters: SearchFilters,
84
- ) -> tuple[str, dict[str, Any]]:
85
- filter_query: str = ''
91
+ provider: GraphProvider,
92
+ ) -> tuple[list[str], dict[str, Any]]:
93
+ filter_queries: list[str] = []
86
94
  filter_params: dict[str, Any] = {}
87
95
 
88
96
  if filters.edge_types is not None:
89
97
  edge_types = filters.edge_types
90
- edge_types_filter = '\nAND e.name in $edge_types'
91
- filter_query += edge_types_filter
98
+ filter_queries.append('e.name in $edge_types')
92
99
  filter_params['edge_types'] = edge_types
93
100
 
94
101
  if filters.node_labels is not None:
95
- node_labels = '|'.join(filters.node_labels)
96
- node_label_filter = '\nAND n:' + node_labels + ' AND m:' + node_labels
97
- filter_query += node_label_filter
102
+ if provider == GraphProvider.KUZU:
103
+ node_label_filter = (
104
+ 'list_has_all(n.labels, $labels) AND list_has_all(m.labels, $labels)'
105
+ )
106
+ filter_params['labels'] = filters.node_labels
107
+ else:
108
+ node_labels = '|'.join(filters.node_labels)
109
+ node_label_filter = 'n:' + node_labels + ' AND m:' + node_labels
110
+ filter_queries.append(node_label_filter)
98
111
 
99
112
  if filters.valid_at is not None:
100
- valid_at_filter = '\nAND ('
113
+ valid_at_filter = '('
101
114
  for i, or_list in enumerate(filters.valid_at):
102
115
  for j, date_filter in enumerate(or_list):
103
116
  if date_filter.comparison_operator not in [
@@ -125,10 +138,10 @@ def edge_search_filter_query_constructor(
125
138
  else:
126
139
  valid_at_filter += ' OR '
127
140
 
128
- filter_query += valid_at_filter
141
+ filter_queries.append(valid_at_filter)
129
142
 
130
143
  if filters.invalid_at is not None:
131
- invalid_at_filter = ' AND ('
144
+ invalid_at_filter = '('
132
145
  for i, or_list in enumerate(filters.invalid_at):
133
146
  for j, date_filter in enumerate(or_list):
134
147
  if date_filter.comparison_operator not in [
@@ -156,10 +169,10 @@ def edge_search_filter_query_constructor(
156
169
  else:
157
170
  invalid_at_filter += ' OR '
158
171
 
159
- filter_query += invalid_at_filter
172
+ filter_queries.append(invalid_at_filter)
160
173
 
161
174
  if filters.created_at is not None:
162
- created_at_filter = ' AND ('
175
+ created_at_filter = '('
163
176
  for i, or_list in enumerate(filters.created_at):
164
177
  for j, date_filter in enumerate(or_list):
165
178
  if date_filter.comparison_operator not in [
@@ -187,10 +200,10 @@ def edge_search_filter_query_constructor(
187
200
  else:
188
201
  created_at_filter += ' OR '
189
202
 
190
- filter_query += created_at_filter
203
+ filter_queries.append(created_at_filter)
191
204
 
192
205
  if filters.expired_at is not None:
193
- expired_at_filter = ' AND ('
206
+ expired_at_filter = '('
194
207
  for i, or_list in enumerate(filters.expired_at):
195
208
  for j, date_filter in enumerate(or_list):
196
209
  if date_filter.comparison_operator not in [
@@ -218,6 +231,6 @@ def edge_search_filter_query_constructor(
218
231
  else:
219
232
  expired_at_filter += ' OR '
220
233
 
221
- filter_query += expired_at_filter
234
+ filter_queries.append(expired_at_filter)
222
235
 
223
- return filter_query, filter_params
236
+ return filter_queries, filter_params