graphiti-core 0.18.9__py3-none-any.whl → 0.19.0rc2__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. The original page linked to further details, which are not included in this extract.

graphiti_core/nodes.py CHANGED
@@ -31,11 +31,13 @@ from graphiti_core.errors import NodeNotFoundError
31
31
  from graphiti_core.helpers import parse_db_date
32
32
  from graphiti_core.models.nodes.node_db_queries import (
33
33
  COMMUNITY_NODE_RETURN,
34
+ COMMUNITY_NODE_RETURN_NEPTUNE,
34
35
  ENTITY_NODE_RETURN,
35
36
  EPISODIC_NODE_RETURN,
36
- EPISODIC_NODE_SAVE,
37
+ EPISODIC_NODE_RETURN_NEPTUNE,
37
38
  get_community_node_save_query,
38
39
  get_entity_node_save_query,
40
+ get_episode_node_save_query,
39
41
  )
40
42
  from graphiti_core.utils.datetime_utils import utc_now
41
43
 
@@ -89,23 +91,24 @@ class Node(BaseModel, ABC):
89
91
  async def save(self, driver: GraphDriver): ...
90
92
 
91
93
  async def delete(self, driver: GraphDriver):
92
- if driver.provider == GraphProvider.FALKORDB:
93
- for label in ['Entity', 'Episodic', 'Community']:
94
+ match driver.provider:
95
+ case GraphProvider.NEO4J:
94
96
  await driver.execute_query(
95
- f"""
96
- MATCH (n:{label} {{uuid: $uuid}})
97
- DETACH DELETE n
98
- """,
99
- uuid=self.uuid,
100
- )
101
- else:
102
- await driver.execute_query(
103
- """
97
+ """
104
98
  MATCH (n:Entity|Episodic|Community {uuid: $uuid})
105
99
  DETACH DELETE n
106
100
  """,
107
- uuid=self.uuid,
108
- )
101
+ uuid=self.uuid,
102
+ )
103
+ case _: # FalkorDB and Neptune
104
+ for label in ['Entity', 'Episodic', 'Community']:
105
+ await driver.execute_query(
106
+ f"""
107
+ MATCH (n:{label} {{uuid: $uuid}})
108
+ DETACH DELETE n
109
+ """,
110
+ uuid=self.uuid,
111
+ )
109
112
 
110
113
  logger.debug(f'Deleted Node: {self.uuid}')
111
114
 
@@ -119,28 +122,30 @@ class Node(BaseModel, ABC):
119
122
 
120
123
  @classmethod
121
124
  async def delete_by_group_id(cls, driver: GraphDriver, group_id: str, batch_size: int = 100):
122
- if driver.provider == GraphProvider.FALKORDB:
123
- for label in ['Entity', 'Episodic', 'Community']:
124
- await driver.execute_query(
125
- f"""
126
- MATCH (n:{label} {{group_id: $group_id}})
127
- DETACH DELETE n
128
- """,
129
- group_id=group_id,
130
- )
131
- else:
132
- async with driver.session() as session:
133
- await session.run(
134
- """
135
- MATCH (n:Entity|Episodic|Community {group_id: $group_id})
136
- CALL {
137
- WITH n
125
+ match driver.provider:
126
+ case GraphProvider.NEO4J:
127
+ async with driver.session() as session:
128
+ await session.run(
129
+ """
130
+ MATCH (n:Entity|Episodic|Community {group_id: $group_id})
131
+ CALL {
132
+ WITH n
133
+ DETACH DELETE n
134
+ } IN TRANSACTIONS OF $batch_size ROWS
135
+ """,
136
+ group_id=group_id,
137
+ batch_size=batch_size,
138
+ )
139
+
140
+ case _: # FalkorDB and Neptune
141
+ for label in ['Entity', 'Episodic', 'Community']:
142
+ await driver.execute_query(
143
+ f"""
144
+ MATCH (n:{label} {{group_id: $group_id}})
138
145
  DETACH DELETE n
139
- } IN TRANSACTIONS OF $batch_size ROWS
140
- """,
141
- group_id=group_id,
142
- batch_size=batch_size,
143
- )
146
+ """,
147
+ group_id=group_id,
148
+ )
144
149
 
145
150
  @classmethod
146
151
  async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
@@ -189,11 +194,25 @@ class EpisodicNode(Node):
189
194
  )
190
195
 
191
196
  async def save(self, driver: GraphDriver):
197
+ if driver.provider == GraphProvider.NEPTUNE:
198
+ driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
199
+ 'episode_content',
200
+ [
201
+ {
202
+ 'uuid': self.uuid,
203
+ 'group_id': self.group_id,
204
+ 'source': self.source.value,
205
+ 'content': self.content,
206
+ 'source_description': self.source_description,
207
+ }
208
+ ],
209
+ )
192
210
  result = await driver.execute_query(
193
- EPISODIC_NODE_SAVE,
211
+ get_episode_node_save_query(driver.provider),
194
212
  uuid=self.uuid,
195
213
  name=self.name,
196
214
  group_id=self.group_id,
215
+ group_label='Episodic_' + self.group_id.replace('-', ''),
197
216
  source_description=self.source_description,
198
217
  content=self.content,
199
218
  entity_edges=self.entity_edges,
@@ -213,7 +232,11 @@ class EpisodicNode(Node):
213
232
  MATCH (e:Episodic {uuid: $uuid})
214
233
  RETURN
215
234
  """
216
- + EPISODIC_NODE_RETURN,
235
+ + (
236
+ EPISODIC_NODE_RETURN_NEPTUNE
237
+ if driver.provider == GraphProvider.NEPTUNE
238
+ else EPISODIC_NODE_RETURN
239
+ ),
217
240
  uuid=uuid,
218
241
  routing_='r',
219
242
  )
@@ -233,7 +256,11 @@ class EpisodicNode(Node):
233
256
  WHERE e.uuid IN $uuids
234
257
  RETURN DISTINCT
235
258
  """
236
- + EPISODIC_NODE_RETURN,
259
+ + (
260
+ EPISODIC_NODE_RETURN_NEPTUNE
261
+ if driver.provider == GraphProvider.NEPTUNE
262
+ else EPISODIC_NODE_RETURN
263
+ ),
237
264
  uuids=uuids,
238
265
  routing_='r',
239
266
  )
@@ -262,7 +289,11 @@ class EpisodicNode(Node):
262
289
  + """
263
290
  RETURN DISTINCT
264
291
  """
265
- + EPISODIC_NODE_RETURN
292
+ + (
293
+ EPISODIC_NODE_RETURN_NEPTUNE
294
+ if driver.provider == GraphProvider.NEPTUNE
295
+ else EPISODIC_NODE_RETURN
296
+ )
266
297
  + """
267
298
  ORDER BY uuid DESC
268
299
  """
@@ -284,7 +315,11 @@ class EpisodicNode(Node):
284
315
  MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
285
316
  RETURN DISTINCT
286
317
  """
287
- + EPISODIC_NODE_RETURN,
318
+ + (
319
+ EPISODIC_NODE_RETURN_NEPTUNE
320
+ if driver.provider == GraphProvider.NEPTUNE
321
+ else EPISODIC_NODE_RETURN
322
+ ),
288
323
  entity_node_uuid=entity_node_uuid,
289
324
  routing_='r',
290
325
  )
@@ -311,11 +346,18 @@ class EntityNode(Node):
311
346
  return self.name_embedding
312
347
 
313
348
  async def load_name_embedding(self, driver: GraphDriver):
314
- records, _, _ = await driver.execute_query(
349
+ if driver.provider == GraphProvider.NEPTUNE:
350
+ query: LiteralString = """
351
+ MATCH (n:Entity {uuid: $uuid})
352
+ RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
315
353
  """
316
- MATCH (n:Entity {uuid: $uuid})
317
- RETURN n.name_embedding AS name_embedding
318
- """,
354
+ else:
355
+ query: LiteralString = """
356
+ MATCH (n:Entity {uuid: $uuid})
357
+ RETURN n.name_embedding AS name_embedding
358
+ """
359
+ records, _, _ = await driver.execute_query(
360
+ query,
319
361
  uuid=self.uuid,
320
362
  routing_='r',
321
363
  )
@@ -336,7 +378,10 @@ class EntityNode(Node):
336
378
  }
337
379
  entity_data.update(self.attributes or {})
338
380
 
339
- labels = ':'.join(self.labels + ['Entity'])
381
+ if driver.provider == GraphProvider.NEPTUNE:
382
+ driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue
383
+
384
+ labels = ':'.join(self.labels + ['Entity', 'Entity_' + self.group_id.replace('-', '')])
340
385
 
341
386
  result = await driver.execute_query(
342
387
  get_entity_node_save_query(driver.provider, labels),
@@ -433,8 +478,13 @@ class CommunityNode(Node):
433
478
  summary: str = Field(description='region summary of member nodes', default_factory=str)
434
479
 
435
480
  async def save(self, driver: GraphDriver):
481
+ if driver.provider == GraphProvider.NEPTUNE:
482
+ driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
483
+ 'community_name',
484
+ [{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
485
+ )
436
486
  result = await driver.execute_query(
437
- get_community_node_save_query(driver.provider),
487
+ get_community_node_save_query(driver.provider), # type: ignore
438
488
  uuid=self.uuid,
439
489
  name=self.name,
440
490
  group_id=self.group_id,
@@ -457,11 +507,19 @@ class CommunityNode(Node):
457
507
  return self.name_embedding
458
508
 
459
509
  async def load_name_embedding(self, driver: GraphDriver):
460
- records, _, _ = await driver.execute_query(
510
+ if driver.provider == GraphProvider.NEPTUNE:
511
+ query: LiteralString = """
512
+ MATCH (c:Community {uuid: $uuid})
513
+ RETURN [x IN split(c.name_embedding, ",") | toFloat(x)] as name_embedding
461
514
  """
515
+ else:
516
+ query: LiteralString = """
462
517
  MATCH (c:Community {uuid: $uuid})
463
518
  RETURN c.name_embedding AS name_embedding
464
- """,
519
+ """
520
+
521
+ records, _, _ = await driver.execute_query(
522
+ query,
465
523
  uuid=self.uuid,
466
524
  routing_='r',
467
525
  )
@@ -478,7 +536,11 @@ class CommunityNode(Node):
478
536
  MATCH (n:Community {uuid: $uuid})
479
537
  RETURN
480
538
  """
481
- + COMMUNITY_NODE_RETURN,
539
+ + (
540
+ COMMUNITY_NODE_RETURN_NEPTUNE
541
+ if driver.provider == GraphProvider.NEPTUNE
542
+ else COMMUNITY_NODE_RETURN
543
+ ),
482
544
  uuid=uuid,
483
545
  routing_='r',
484
546
  )
@@ -498,7 +560,11 @@ class CommunityNode(Node):
498
560
  WHERE n.uuid IN $uuids
499
561
  RETURN
500
562
  """
501
- + COMMUNITY_NODE_RETURN,
563
+ + (
564
+ COMMUNITY_NODE_RETURN_NEPTUNE
565
+ if driver.provider == GraphProvider.NEPTUNE
566
+ else COMMUNITY_NODE_RETURN
567
+ ),
502
568
  uuids=uuids,
503
569
  routing_='r',
504
570
  )
@@ -527,7 +593,11 @@ class CommunityNode(Node):
527
593
  + """
528
594
  RETURN
529
595
  """
530
- + COMMUNITY_NODE_RETURN
596
+ + (
597
+ COMMUNITY_NODE_RETURN_NEPTUNE
598
+ if driver.provider == GraphProvider.NEPTUNE
599
+ else COMMUNITY_NODE_RETURN
600
+ )
531
601
  + """
532
602
  ORDER BY n.uuid DESC
533
603
  """
@@ -21,6 +21,7 @@ from time import time
21
21
  from graphiti_core.cross_encoder.client import CrossEncoderClient
22
22
  from graphiti_core.driver.driver import GraphDriver
23
23
  from graphiti_core.edges import EntityEdge
24
+ from graphiti_core.embedder.client import EMBEDDING_DIM
24
25
  from graphiti_core.errors import SearchRerankerError
25
26
  from graphiti_core.graphiti_types import GraphitiClients
26
27
  from graphiti_core.helpers import semaphore_gather
@@ -29,6 +30,7 @@ from graphiti_core.search.search_config import (
29
30
  DEFAULT_SEARCH_LIMIT,
30
31
  CommunityReranker,
31
32
  CommunitySearchConfig,
33
+ CommunitySearchMethod,
32
34
  EdgeReranker,
33
35
  EdgeSearchConfig,
34
36
  EdgeSearchMethod,
@@ -81,11 +83,29 @@ async def search(
81
83
 
82
84
  if query.strip() == '':
83
85
  return SearchResults()
84
- query_vector = (
85
- query_vector
86
- if query_vector is not None
87
- else await embedder.create(input_data=[query.replace('\n', ' ')])
88
- )
86
+
87
+ if (
88
+ config.edge_config
89
+ and EdgeSearchMethod.cosine_similarity in config.edge_config.search_methods
90
+ or config.edge_config
91
+ and EdgeReranker.mmr == config.edge_config.reranker
92
+ or config.node_config
93
+ and NodeSearchMethod.cosine_similarity in config.node_config.search_methods
94
+ or config.node_config
95
+ and NodeReranker.mmr == config.node_config.reranker
96
+ or (
97
+ config.community_config
98
+ and CommunitySearchMethod.cosine_similarity in config.community_config.search_methods
99
+ )
100
+ or (config.community_config and CommunityReranker.mmr == config.community_config.reranker)
101
+ ):
102
+ search_vector = (
103
+ query_vector
104
+ if query_vector is not None
105
+ else await embedder.create(input_data=[query.replace('\n', ' ')])
106
+ )
107
+ else:
108
+ search_vector = [0.0] * EMBEDDING_DIM
89
109
 
90
110
  # if group_ids is empty, set it to None
91
111
  group_ids = group_ids if group_ids and group_ids != [''] else None
@@ -99,7 +119,7 @@ async def search(
99
119
  driver,
100
120
  cross_encoder,
101
121
  query,
102
- query_vector,
122
+ search_vector,
103
123
  group_ids,
104
124
  config.edge_config,
105
125
  search_filter,
@@ -112,7 +132,7 @@ async def search(
112
132
  driver,
113
133
  cross_encoder,
114
134
  query,
115
- query_vector,
135
+ search_vector,
116
136
  group_ids,
117
137
  config.node_config,
118
138
  search_filter,
@@ -125,7 +145,7 @@ async def search(
125
145
  driver,
126
146
  cross_encoder,
127
147
  query,
128
- query_vector,
148
+ search_vector,
129
149
  group_ids,
130
150
  config.episode_config,
131
151
  search_filter,
@@ -136,7 +156,7 @@ async def search(
136
156
  driver,
137
157
  cross_encoder,
138
158
  query,
139
- query_vector,
159
+ search_vector,
140
160
  group_ids,
141
161
  config.community_config,
142
162
  config.limit,