graphiti-core 0.20.4__py3-none-any.whl → 0.21.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic; the original page linked to further details, which are not reproduced here.

@@ -14,15 +14,40 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
+ import asyncio
17
18
  import copy
18
19
  import logging
20
+ import os
19
21
  from abc import ABC, abstractmethod
20
22
  from collections.abc import Coroutine
23
+ from datetime import datetime
21
24
  from enum import Enum
22
25
  from typing import Any
23
26
 
27
+ from dotenv import load_dotenv
28
+
29
+ from graphiti_core.embedder.client import EMBEDDING_DIM
30
+
31
+ try:
32
+ from opensearchpy import AsyncOpenSearch, helpers
33
+
34
+ _HAS_OPENSEARCH = True
35
+ except ImportError:
36
+ OpenSearch = None
37
+ helpers = None
38
+ _HAS_OPENSEARCH = False
39
+
24
40
  logger = logging.getLogger(__name__)
25
41
 
42
# Default result-size limit; also imported by other drivers (e.g. the Neptune driver).
DEFAULT_SIZE = 10

# Populate os.environ from a local .env file before the index names below are read.
# NOTE(review): doing this at import time affects the whole host process; consider moving
# it to application start-up.
load_dotenv()

# OpenSearch index/alias names, each overridable through an environment variable of the
# same name.
ENTITY_INDEX_NAME = os.getenv('ENTITY_INDEX_NAME', 'entities')
EPISODE_INDEX_NAME = os.getenv('EPISODE_INDEX_NAME', 'episodes')
COMMUNITY_INDEX_NAME = os.getenv('COMMUNITY_INDEX_NAME', 'communities')
ENTITY_EDGE_INDEX_NAME = os.getenv('ENTITY_EDGE_INDEX_NAME', 'entity_edges')
50
+
26
51
 
27
52
  class GraphProvider(Enum):
28
53
  NEO4J = 'neo4j'
@@ -31,6 +56,91 @@ class GraphProvider(Enum):
31
56
  NEPTUNE = 'neptune'
32
57
 
33
58
 
59
def _field(field_type: str) -> dict[str, str]:
    """Return a minimal OpenSearch field mapping of the given type."""
    return {'type': field_type}


def _date_field() -> dict[str, str]:
    """Return a date field mapping using the ``strict_date_optional_time_nanos`` format."""
    return {'type': 'date', 'format': 'strict_date_optional_time_nanos'}


def _knn_vector_field() -> dict[str, Any]:
    """Return a k-NN vector mapping of EMBEDDING_DIM dimensions (FAISS HNSW, cosine space)."""
    return {
        'type': 'knn_vector',
        'dimension': EMBEDDING_DIM,
        'method': {
            'engine': 'faiss',
            'space_type': 'cosinesimil',
            'name': 'hnsw',
            'parameters': {'ef_construction': 128, 'm': 16},
        },
    }


def _index_spec(
    index_name: str, properties: dict[str, Any], *, knn: bool = False
) -> dict[str, Any]:
    """Build an ``{'index_name', 'body'}`` entry; ``knn=True`` enables the k-NN index setting."""
    body: dict[str, Any] = {'settings': {'index': {'knn': True}}} if knn else {}
    body['mappings'] = {'properties': properties}
    return {'index_name': index_name, 'body': body}


# Index definitions (name + create-index body) used by the GraphDriver OpenSearch helpers.
aoss_indices = [
    _index_spec(
        ENTITY_INDEX_NAME,
        {
            'uuid': _field('keyword'),
            'name': _field('text'),
            'summary': _field('text'),
            'group_id': _field('keyword'),
            'created_at': _date_field(),
            'name_embedding': _knn_vector_field(),
        },
        knn=True,
    ),
    _index_spec(
        COMMUNITY_INDEX_NAME,
        {
            'uuid': _field('keyword'),
            'name': _field('text'),
            'group_id': _field('keyword'),
        },
    ),
    _index_spec(
        EPISODE_INDEX_NAME,
        {
            'uuid': _field('keyword'),
            'content': _field('text'),
            'source': _field('text'),
            'source_description': _field('text'),
            'group_id': _field('keyword'),
            'created_at': _date_field(),
            'valid_at': _date_field(),
        },
    ),
    _index_spec(
        ENTITY_EDGE_INDEX_NAME,
        {
            'uuid': _field('keyword'),
            'name': _field('text'),
            'fact': _field('text'),
            'group_id': _field('keyword'),
            'created_at': _date_field(),
            'valid_at': _date_field(),
            'expired_at': _date_field(),
            'invalid_at': _date_field(),
            'fact_embedding': _knn_vector_field(),
        },
        knn=True,
    ),
]
142
+
143
+
34
144
  class GraphDriverSession(ABC):
35
145
  provider: GraphProvider
36
146
 
@@ -61,6 +171,7 @@ class GraphDriver(ABC):
61
171
  '' # Neo4j (default) syntax does not require a prefix for fulltext queries
62
172
  )
63
173
  _database: str
174
+ aoss_client: AsyncOpenSearch | None # type: ignore
64
175
 
65
176
  @abstractmethod
66
177
  def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
@@ -87,3 +198,116 @@ class GraphDriver(ABC):
87
198
  cloned._database = database
88
199
 
89
200
  return cloned
201
+
202
+ async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
203
+ # No matter what happens above, always return True
204
+ return self.delete_aoss_indices()
205
+
206
+ async def create_aoss_indices(self):
207
+ client = self.aoss_client
208
+ if not client:
209
+ logger.warning('No OpenSearch client found')
210
+ return
211
+
212
+ for index in aoss_indices:
213
+ alias_name = index['index_name']
214
+
215
+ # If alias already exists, skip (idempotent behavior)
216
+ if await client.indices.exists_alias(name=alias_name):
217
+ continue
218
+
219
+ # Build a physical index name with timestamp
220
+ ts_suffix = datetime.utcnow().strftime('%Y%m%d%H%M%S')
221
+ physical_index_name = f'{alias_name}_{ts_suffix}'
222
+
223
+ # Create the index
224
+ await client.indices.create(index=physical_index_name, body=index['body'])
225
+
226
+ # Point alias to it
227
+ await client.indices.put_alias(index=physical_index_name, name=alias_name)
228
+
229
+ # Allow some time for index creation
230
+ await asyncio.sleep(1)
231
+
232
+ async def delete_aoss_indices(self):
233
+ client = self.aoss_client
234
+
235
+ if not client:
236
+ logger.warning('No OpenSearch client found')
237
+ return
238
+
239
+ for entry in aoss_indices:
240
+ alias_name = entry['index_name']
241
+
242
+ try:
243
+ # Resolve alias → indices
244
+ alias_info = await client.indices.get_alias(name=alias_name)
245
+ indices = list(alias_info.keys())
246
+
247
+ if not indices:
248
+ logger.info(f"No indices found for alias '{alias_name}'")
249
+ continue
250
+
251
+ for index in indices:
252
+ if await client.indices.exists(index=index):
253
+ await client.indices.delete(index=index)
254
+ logger.info(f"Deleted index '{index}' (alias: {alias_name})")
255
+ else:
256
+ logger.warning(f"Index '{index}' not found for alias '{alias_name}'")
257
+
258
+ except Exception as e:
259
+ logger.error(f"Error deleting indices for alias '{alias_name}': {e}")
260
+
261
+ async def clear_aoss_indices(self):
262
+ client = self.aoss_client
263
+
264
+ if not client:
265
+ logger.warning('No OpenSearch client found')
266
+ return
267
+
268
+ for index in aoss_indices:
269
+ index_name = index['index_name']
270
+
271
+ if await client.indices.exists(index=index_name):
272
+ try:
273
+ # Delete all documents but keep the index
274
+ response = await client.delete_by_query(
275
+ index=index_name,
276
+ body={'query': {'match_all': {}}},
277
+ )
278
+ logger.info(f"Cleared index '{index_name}': {response}")
279
+ except Exception as e:
280
+ logger.error(f"Error clearing index '{index_name}': {e}")
281
+ else:
282
+ logger.warning(f"Index '{index_name}' does not exist")
283
+
284
+ async def save_to_aoss(self, name: str, data: list[dict]) -> int:
285
+ client = self.aoss_client
286
+ if not client or not helpers:
287
+ logger.warning('No OpenSearch client found')
288
+ return 0
289
+
290
+ for index in aoss_indices:
291
+ if name.lower() == index['index_name']:
292
+ to_index = []
293
+ for d in data:
294
+ doc = {}
295
+ for p in index['body']['mappings']['properties']:
296
+ if p in d: # protect against missing fields
297
+ doc[p] = d[p]
298
+
299
+ item = {
300
+ '_index': name,
301
+ '_id': d['uuid'],
302
+ '_routing': d.get('group_id'),
303
+ '_source': doc,
304
+ }
305
+ to_index.append(item)
306
+
307
+ success, failed = await helpers.async_bulk(
308
+ client, to_index, stats_only=True, request_timeout=60
309
+ )
310
+
311
+ return success if failed == 0 else success
312
+
313
+ return 0
@@ -74,6 +74,7 @@ class FalkorDriverSession(GraphDriverSession):
74
74
 
75
75
  class FalkorDriver(GraphDriver):
76
76
  provider = GraphProvider.FALKORDB
77
+ aoss_client: None = None
77
78
 
78
79
  def __init__(
79
80
  self,
@@ -92,6 +92,7 @@ SCHEMA_QUERIES = """
92
92
 
93
93
  class KuzuDriver(GraphDriver):
94
94
  provider: GraphProvider = GraphProvider.KUZU
95
+ aoss_client: None = None
95
96
 
96
97
  def __init__(
97
98
  self,
@@ -22,14 +22,44 @@ from neo4j import AsyncGraphDatabase, EagerResult
22
22
  from typing_extensions import LiteralString
23
23
 
24
24
  from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
25
+ from graphiti_core.helpers import semaphore_gather
25
26
 
26
27
  logger = logging.getLogger(__name__)
27
28
 
29
try:
    import boto3
    from opensearchpy import (
        AIOHttpConnection,
        AsyncOpenSearch,
        AWSV4SignerAuth,
        Urllib3AWSV4SignerAuth,
        Urllib3HttpConnection,
    )

    _HAS_OPENSEARCH = True
except ImportError:
    # Optional dependencies are missing: bind every name the try-branch imports so nothing
    # is left undefined. OpenSearch setup in __init__ is skipped because it checks
    # ``boto3 is not None``.
    boto3 = None
    AIOHttpConnection = None
    AsyncOpenSearch = None
    AWSV4SignerAuth = None
    Urllib3AWSV4SignerAuth = None
    Urllib3HttpConnection = None
    _HAS_OPENSEARCH = False
46
+
28
47
 
29
48
  class Neo4jDriver(GraphDriver):
30
49
  provider = GraphProvider.NEO4J
31
50
 
32
- def __init__(self, uri: str, user: str | None, password: str | None, database: str = 'neo4j'):
51
+ def __init__(
52
+ self,
53
+ uri: str,
54
+ user: str | None,
55
+ password: str | None,
56
+ database: str = 'neo4j',
57
+ aoss_host: str | None = None,
58
+ aoss_port: int | None = None,
59
+ aws_profile_name: str | None = None,
60
+ aws_region: str | None = None,
61
+ aws_service: str | None = None,
62
+ ):
33
63
  super().__init__()
34
64
  self.client = AsyncGraphDatabase.driver(
35
65
  uri=uri,
@@ -37,6 +67,26 @@ class Neo4jDriver(GraphDriver):
37
67
  )
38
68
  self._database = database
39
69
 
70
+ self.aoss_client = None
71
+ if aoss_host and aoss_port and boto3 is not None:
72
+ try:
73
+ region = aws_region
74
+ service = aws_service
75
+ credentials = boto3.Session(profile_name=aws_profile_name).get_credentials()
76
+ auth = AWSV4SignerAuth(credentials, region or '', service or '')
77
+
78
+ self.aoss_client = AsyncOpenSearch(
79
+ hosts=[{'host': aoss_host, 'port': aoss_port}],
80
+ auth=auth,
81
+ use_ssl=True,
82
+ verify_certs=True,
83
+ connection_class=AIOHttpConnection,
84
+ pool_maxsize=20,
85
+ ) # type: ignore
86
+ except Exception as e:
87
+ logger.warning(f'Failed to initialize OpenSearch client: {e}')
88
+ self.aoss_client = None
89
+
40
90
  async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
41
91
  # Check if database_ is provided in kwargs.
42
92
  # If not populated, set the value to retain backwards compatibility
@@ -60,7 +110,14 @@ class Neo4jDriver(GraphDriver):
60
110
  async def close(self) -> None:
61
111
  return await self.client.close()
62
112
 
63
- def delete_all_indexes(self) -> Coroutine[Any, Any, EagerResult]:
113
+ def delete_all_indexes(self) -> Coroutine:
114
+ if self.aoss_client:
115
+ return semaphore_gather(
116
+ self.client.execute_query(
117
+ 'CALL db.indexes() YIELD name DROP INDEX name',
118
+ ),
119
+ self.delete_aoss_indices(),
120
+ )
64
121
  return self.client.execute_query(
65
122
  'CALL db.indexes() YIELD name DROP INDEX name',
66
123
  )
@@ -22,16 +22,21 @@ from typing import Any
22
22
 
23
23
  import boto3
24
24
  from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
25
- from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
25
+ from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
26
26
 
27
- from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
27
+ from graphiti_core.driver.driver import (
28
+ DEFAULT_SIZE,
29
+ GraphDriver,
30
+ GraphDriverSession,
31
+ GraphProvider,
32
+ )
28
33
 
29
34
  logger = logging.getLogger(__name__)
30
- DEFAULT_SIZE = 10
31
35
 
32
- aoss_indices = [
36
+ neptune_aoss_indices = [
33
37
  {
34
38
  'index_name': 'node_name_and_summary',
39
+ 'alias_name': 'entities',
35
40
  'body': {
36
41
  'mappings': {
37
42
  'properties': {
@@ -49,6 +54,7 @@ aoss_indices = [
49
54
  },
50
55
  {
51
56
  'index_name': 'community_name',
57
+ 'alias_name': 'communities',
52
58
  'body': {
53
59
  'mappings': {
54
60
  'properties': {
@@ -65,6 +71,7 @@ aoss_indices = [
65
71
  },
66
72
  {
67
73
  'index_name': 'episode_content',
74
+ 'alias_name': 'episodes',
68
75
  'body': {
69
76
  'mappings': {
70
77
  'properties': {
@@ -88,6 +95,7 @@ aoss_indices = [
88
95
  },
89
96
  {
90
97
  'index_name': 'edge_name_and_fact',
98
+ 'alias_name': 'facts',
91
99
  'body': {
92
100
  'mappings': {
93
101
  'properties': {
@@ -220,54 +228,27 @@ class NeptuneDriver(GraphDriver):
220
228
  async def _delete_all_data(self) -> Any:
221
229
  return await self.execute_query('MATCH (n) DETACH DELETE n')
222
230
 
223
- def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
224
- return self.delete_all_indexes_impl()
225
-
226
- async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
227
- # No matter what happens above, always return True
228
- return self.delete_aoss_indices()
229
-
230
231
  async def create_aoss_indices(self):
231
- for index in aoss_indices:
232
+ for index in neptune_aoss_indices:
232
233
  index_name = index['index_name']
233
234
  client = self.aoss_client
235
+ if not client:
236
+ raise ValueError(
237
+ 'You must provide an AOSS endpoint to create an OpenSearch driver.'
238
+ )
234
239
  if not client.indices.exists(index=index_name):
235
- client.indices.create(index=index_name, body=index['body'])
240
+ await client.indices.create(index=index_name, body=index['body'])
241
+
242
+ alias_name = index.get('alias_name', index_name)
243
+
244
+ if not client.indices.exists_alias(name=alias_name, index=index_name):
245
+ await client.indices.put_alias(index=index_name, name=alias_name)
246
+
236
247
  # Sleep for 1 minute to let the index creation complete
237
248
  await asyncio.sleep(60)
238
249
 
239
- async def delete_aoss_indices(self):
240
- for index in aoss_indices:
241
- index_name = index['index_name']
242
- client = self.aoss_client
243
- if client.indices.exists(index=index_name):
244
- client.indices.delete(index=index_name)
245
-
246
- def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
247
- for index in aoss_indices:
248
- if name.lower() == index['index_name']:
249
- index['query']['query']['multi_match']['query'] = query_text
250
- query = {'size': limit, 'query': index['query']}
251
- resp = self.aoss_client.search(body=query['query'], index=index['index_name'])
252
- return resp
253
- return {}
254
-
255
- def save_to_aoss(self, name: str, data: list[dict]) -> int:
256
- for index in aoss_indices:
257
- if name.lower() == index['index_name']:
258
- to_index = []
259
- for d in data:
260
- item = {'_index': name}
261
- for p in index['body']['mappings']['properties']:
262
- item[p] = d[p]
263
- to_index.append(item)
264
- success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
265
- if failed > 0:
266
- return success
267
- else:
268
- return 0
269
-
270
- return 0
250
+ def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
251
+ return self.delete_all_indexes_impl()
271
252
 
272
253
 
273
254
  class NeptuneDriverSession(GraphDriverSession):
graphiti_core/edges.py CHANGED
@@ -25,7 +25,7 @@ from uuid import uuid4
25
25
  from pydantic import BaseModel, Field
26
26
  from typing_extensions import LiteralString
27
27
 
28
- from graphiti_core.driver.driver import GraphDriver, GraphProvider
28
+ from graphiti_core.driver.driver import ENTITY_EDGE_INDEX_NAME, GraphDriver, GraphProvider
29
29
  from graphiti_core.embedder import EmbedderClient
30
30
  from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
31
31
  from graphiti_core.helpers import parse_db_date
@@ -77,6 +77,13 @@ class Edge(BaseModel, ABC):
77
77
  uuid=self.uuid,
78
78
  )
79
79
 
80
+ if driver.aoss_client:
81
+ await driver.aoss_client.delete(
82
+ index=ENTITY_EDGE_INDEX_NAME,
83
+ id=self.uuid,
84
+ params={'routing': self.group_id},
85
+ )
86
+
80
87
  logger.debug(f'Deleted Edge: {self.uuid}')
81
88
 
82
89
  @classmethod
@@ -108,6 +115,12 @@ class Edge(BaseModel, ABC):
108
115
  uuids=uuids,
109
116
  )
110
117
 
118
+ if driver.aoss_client:
119
+ await driver.aoss_client.delete_by_query(
120
+ index=ENTITY_EDGE_INDEX_NAME,
121
+ body={'query': {'terms': {'uuid': uuids}}},
122
+ )
123
+
111
124
  logger.debug(f'Deleted Edges: {uuids}')
112
125
 
113
126
  def __hash__(self):
@@ -255,6 +268,21 @@ class EntityEdge(Edge):
255
268
  MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
256
269
  RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
257
270
  """
271
+ elif driver.aoss_client:
272
+ resp = await driver.aoss_client.search(
273
+ body={
274
+ 'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
275
+ 'size': 1,
276
+ },
277
+ index=ENTITY_EDGE_INDEX_NAME,
278
+ params={'routing': self.group_id},
279
+ )
280
+
281
+ if resp['hits']['hits']:
282
+ self.fact_embedding = resp['hits']['hits'][0]['_source']['fact_embedding']
283
+ return
284
+ else:
285
+ raise EdgeNotFoundError(self.uuid)
258
286
 
259
287
  if driver.provider == GraphProvider.KUZU:
260
288
  query = """
@@ -292,14 +320,14 @@ class EntityEdge(Edge):
292
320
  if driver.provider == GraphProvider.KUZU:
293
321
  edge_data['attributes'] = json.dumps(self.attributes)
294
322
  result = await driver.execute_query(
295
- get_entity_edge_save_query(driver.provider),
323
+ get_entity_edge_save_query(driver.provider, has_aoss=bool(driver.aoss_client)),
296
324
  **edge_data,
297
325
  )
298
326
  else:
299
327
  edge_data.update(self.attributes or {})
300
328
 
301
- if driver.provider == GraphProvider.NEPTUNE:
302
- driver.save_to_aoss('edge_name_and_fact', [edge_data]) # pyright: ignore reportAttributeAccessIssue
329
+ if driver.aoss_client:
330
+ await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, [edge_data]) # pyright: ignore reportAttributeAccessIssue
303
331
 
304
332
  result = await driver.execute_query(
305
333
  get_entity_edge_save_query(driver.provider),
@@ -336,6 +364,35 @@ class EntityEdge(Edge):
336
364
  raise EdgeNotFoundError(uuid)
337
365
  return edges[0]
338
366
 
367
+ @classmethod
368
+ async def get_between_nodes(
369
+ cls, driver: GraphDriver, source_node_uuid: str, target_node_uuid: str
370
+ ):
371
+ match_query = """
372
+ MATCH (n:Entity {uuid: $source_node_uuid})-[e:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
373
+ """
374
+ if driver.provider == GraphProvider.KUZU:
375
+ match_query = """
376
+ MATCH (n:Entity {uuid: $source_node_uuid})
377
+ -[:RELATES_TO]->(e:RelatesToNode_)
378
+ -[:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
379
+ """
380
+
381
+ records, _, _ = await driver.execute_query(
382
+ match_query
383
+ + """
384
+ RETURN
385
+ """
386
+ + get_entity_edge_return_query(driver.provider),
387
+ source_node_uuid=source_node_uuid,
388
+ target_node_uuid=target_node_uuid,
389
+ routing_='r',
390
+ )
391
+
392
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
393
+
394
+ return edges
395
+
339
396
  @classmethod
340
397
  async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
341
398
  if len(uuids) == 0:
@@ -14,12 +14,13 @@ See the License for the specific language governing permissions and
14
14
  limitations under the License.
15
15
  """
16
16
 
17
+ import os
17
18
  from abc import ABC, abstractmethod
18
19
  from collections.abc import Iterable
19
20
 
20
21
  from pydantic import BaseModel, Field
21
22
 
22
- EMBEDDING_DIM = 1024
23
# Dimensionality of embedding vectors (used, e.g., as the OpenSearch knn_vector dimension).
# Override via the EMBEDDING_DIM environment variable; an unset *or empty* value falls back
# to 1024 instead of failing with ValueError at import time.
EMBEDDING_DIM = int(os.getenv('EMBEDDING_DIM') or 1024)
23
24
 
24
25
 
25
26
  class EmbedderConfig(BaseModel):
graphiti_core/graphiti.py CHANGED
@@ -60,9 +60,7 @@ from graphiti_core.search.search_config_recipes import (
60
60
  from graphiti_core.search.search_filters import SearchFilters
61
61
  from graphiti_core.search.search_utils import (
62
62
  RELEVANT_SCHEMA_LIMIT,
63
- get_edge_invalidation_candidates,
64
63
  get_mentioned_nodes,
65
- get_relevant_edges,
66
64
  )
67
65
  from graphiti_core.telemetry import capture_event
68
66
  from graphiti_core.utils.bulk_utils import (
@@ -1037,10 +1035,28 @@ class Graphiti:
1037
1035
 
1038
1036
  updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
1039
1037
 
1040
- related_edges = (await get_relevant_edges(self.driver, [updated_edge], SearchFilters()))[0]
1038
+ valid_edges = await EntityEdge.get_between_nodes(
1039
+ self.driver, edge.source_node_uuid, edge.target_node_uuid
1040
+ )
1041
+
1042
+ related_edges = (
1043
+ await search(
1044
+ self.clients,
1045
+ updated_edge.fact,
1046
+ group_ids=[updated_edge.group_id],
1047
+ config=EDGE_HYBRID_SEARCH_RRF,
1048
+ search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
1049
+ )
1050
+ ).edges
1041
1051
  existing_edges = (
1042
- await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
1043
- )[0]
1052
+ await search(
1053
+ self.clients,
1054
+ updated_edge.fact,
1055
+ group_ids=[updated_edge.group_id],
1056
+ config=EDGE_HYBRID_SEARCH_RRF,
1057
+ search_filter=SearchFilters(),
1058
+ )
1059
+ ).edges
1044
1060
 
1045
1061
  resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
1046
1062
  self.llm_client,