graphiti-core 0.21.0rc1__py3-none-any.whl → 0.21.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. Click here for more details.

@@ -17,16 +17,19 @@ limitations under the License.
17
17
  import asyncio
18
18
  import copy
19
19
  import logging
20
+ import os
20
21
  from abc import ABC, abstractmethod
21
22
  from collections.abc import Coroutine
22
23
  from datetime import datetime
23
24
  from enum import Enum
24
25
  from typing import Any
25
26
 
27
+ from dotenv import load_dotenv
28
+
26
29
  from graphiti_core.embedder.client import EMBEDDING_DIM
27
30
 
28
31
  try:
29
- from opensearchpy import OpenSearch, helpers
32
+ from opensearchpy import AsyncOpenSearch, helpers
30
33
 
31
34
  _HAS_OPENSEARCH = True
32
35
  except ImportError:
@@ -38,6 +41,13 @@ logger = logging.getLogger(__name__)
38
41
 
39
42
  DEFAULT_SIZE = 10
40
43
 
44
+ load_dotenv()
45
+
46
+ ENTITY_INDEX_NAME = os.environ.get('ENTITY_INDEX_NAME', 'entities')
47
+ EPISODE_INDEX_NAME = os.environ.get('EPISODE_INDEX_NAME', 'episodes')
48
+ COMMUNITY_INDEX_NAME = os.environ.get('COMMUNITY_INDEX_NAME', 'communities')
49
+ ENTITY_EDGE_INDEX_NAME = os.environ.get('ENTITY_EDGE_INDEX_NAME', 'entity_edges')
50
+
41
51
 
42
52
  class GraphProvider(Enum):
43
53
  NEO4J = 'neo4j'
@@ -48,20 +58,19 @@ class GraphProvider(Enum):
48
58
 
49
59
  aoss_indices = [
50
60
  {
51
- 'index_name': 'entities',
61
+ 'index_name': ENTITY_INDEX_NAME,
52
62
  'body': {
63
+ 'settings': {'index': {'knn': True}},
53
64
  'mappings': {
54
65
  'properties': {
55
66
  'uuid': {'type': 'keyword'},
56
67
  'name': {'type': 'text'},
57
68
  'summary': {'type': 'text'},
58
- 'group_id': {'type': 'text'},
59
- 'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
69
+ 'group_id': {'type': 'keyword'},
70
+ 'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
60
71
  'name_embedding': {
61
72
  'type': 'knn_vector',
62
- 'dims': EMBEDDING_DIM,
63
- 'index': True,
64
- 'similarity': 'cosine',
73
+ 'dimension': EMBEDDING_DIM,
65
74
  'method': {
66
75
  'engine': 'faiss',
67
76
  'space_type': 'cosinesimil',
@@ -70,23 +79,23 @@ aoss_indices = [
70
79
  },
71
80
  },
72
81
  }
73
- }
82
+ },
74
83
  },
75
84
  },
76
85
  {
77
- 'index_name': 'communities',
86
+ 'index_name': COMMUNITY_INDEX_NAME,
78
87
  'body': {
79
88
  'mappings': {
80
89
  'properties': {
81
90
  'uuid': {'type': 'keyword'},
82
91
  'name': {'type': 'text'},
83
- 'group_id': {'type': 'text'},
92
+ 'group_id': {'type': 'keyword'},
84
93
  }
85
94
  }
86
95
  },
87
96
  },
88
97
  {
89
- 'index_name': 'episodes',
98
+ 'index_name': EPISODE_INDEX_NAME,
90
99
  'body': {
91
100
  'mappings': {
92
101
  'properties': {
@@ -94,31 +103,30 @@ aoss_indices = [
94
103
  'content': {'type': 'text'},
95
104
  'source': {'type': 'text'},
96
105
  'source_description': {'type': 'text'},
97
- 'group_id': {'type': 'text'},
98
- 'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
99
- 'valid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
106
+ 'group_id': {'type': 'keyword'},
107
+ 'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
108
+ 'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
100
109
  }
101
110
  }
102
111
  },
103
112
  },
104
113
  {
105
- 'index_name': 'entity_edges',
114
+ 'index_name': ENTITY_EDGE_INDEX_NAME,
106
115
  'body': {
116
+ 'settings': {'index': {'knn': True}},
107
117
  'mappings': {
108
118
  'properties': {
109
119
  'uuid': {'type': 'keyword'},
110
120
  'name': {'type': 'text'},
111
121
  'fact': {'type': 'text'},
112
- 'group_id': {'type': 'text'},
113
- 'created_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
114
- 'valid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
115
- 'expired_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
116
- 'invalid_at': {'type': 'date', 'format': "yyyy-MM-dd'T'HH:mm:ss.SSSZ"},
122
+ 'group_id': {'type': 'keyword'},
123
+ 'created_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
124
+ 'valid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
125
+ 'expired_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
126
+ 'invalid_at': {'type': 'date', 'format': 'strict_date_optional_time_nanos'},
117
127
  'fact_embedding': {
118
128
  'type': 'knn_vector',
119
- 'dims': EMBEDDING_DIM,
120
- 'index': True,
121
- 'similarity': 'cosine',
129
+ 'dimension': EMBEDDING_DIM,
122
130
  'method': {
123
131
  'engine': 'faiss',
124
132
  'space_type': 'cosinesimil',
@@ -127,7 +135,7 @@ aoss_indices = [
127
135
  },
128
136
  },
129
137
  }
130
- }
138
+ },
131
139
  },
132
140
  },
133
141
  ]
@@ -163,7 +171,7 @@ class GraphDriver(ABC):
163
171
  '' # Neo4j (default) syntax does not require a prefix for fulltext queries
164
172
  )
165
173
  _database: str
166
- aoss_client: OpenSearch | None # type: ignore
174
+ aoss_client: AsyncOpenSearch | None # type: ignore
167
175
 
168
176
  @abstractmethod
169
177
  def execute_query(self, cypher_query_: str, **kwargs: Any) -> Coroutine:
@@ -205,7 +213,7 @@ class GraphDriver(ABC):
205
213
  alias_name = index['index_name']
206
214
 
207
215
  # If alias already exists, skip (idempotent behavior)
208
- if client.indices.exists_alias(name=alias_name):
216
+ if await client.indices.exists_alias(name=alias_name):
209
217
  continue
210
218
 
211
219
  # Build a physical index name with timestamp
@@ -213,27 +221,67 @@ class GraphDriver(ABC):
213
221
  physical_index_name = f'{alias_name}_{ts_suffix}'
214
222
 
215
223
  # Create the index
216
- client.indices.create(index=physical_index_name, body=index['body'])
224
+ await client.indices.create(index=physical_index_name, body=index['body'])
217
225
 
218
226
  # Point alias to it
219
- client.indices.put_alias(index=physical_index_name, name=alias_name)
227
+ await client.indices.put_alias(index=physical_index_name, name=alias_name)
220
228
 
221
229
  # Allow some time for index creation
222
- await asyncio.sleep(60)
230
+ await asyncio.sleep(1)
223
231
 
224
232
  async def delete_aoss_indices(self):
225
- for index in aoss_indices:
226
- index_name = index['index_name']
227
- client = self.aoss_client
233
+ client = self.aoss_client
234
+
235
+ if not client:
236
+ logger.warning('No OpenSearch client found')
237
+ return
238
+
239
+ for entry in aoss_indices:
240
+ alias_name = entry['index_name']
241
+
242
+ try:
243
+ # Resolve alias → indices
244
+ alias_info = await client.indices.get_alias(name=alias_name)
245
+ indices = list(alias_info.keys())
246
+
247
+ if not indices:
248
+ logger.info(f"No indices found for alias '{alias_name}'")
249
+ continue
228
250
 
229
- if not client:
230
- logger.warning('No OpenSearch client found')
231
- return
251
+ for index in indices:
252
+ if await client.indices.exists(index=index):
253
+ await client.indices.delete(index=index)
254
+ logger.info(f"Deleted index '{index}' (alias: {alias_name})")
255
+ else:
256
+ logger.warning(f"Index '{index}' not found for alias '{alias_name}'")
232
257
 
233
- if client.indices.exists(index=index_name):
234
- client.indices.delete(index=index_name)
258
+ except Exception as e:
259
+ logger.error(f"Error deleting indices for alias '{alias_name}': {e}")
235
260
 
236
- def save_to_aoss(self, name: str, data: list[dict]) -> int:
261
+ async def clear_aoss_indices(self):
262
+ client = self.aoss_client
263
+
264
+ if not client:
265
+ logger.warning('No OpenSearch client found')
266
+ return
267
+
268
+ for index in aoss_indices:
269
+ index_name = index['index_name']
270
+
271
+ if await client.indices.exists(index=index_name):
272
+ try:
273
+ # Delete all documents but keep the index
274
+ response = await client.delete_by_query(
275
+ index=index_name,
276
+ body={'query': {'match_all': {}}},
277
+ )
278
+ logger.info(f"Cleared index '{index_name}': {response}")
279
+ except Exception as e:
280
+ logger.error(f"Error clearing index '{index_name}': {e}")
281
+ else:
282
+ logger.warning(f"Index '{index_name}' does not exist")
283
+
284
+ async def save_to_aoss(self, name: str, data: list[dict]) -> int:
237
285
  client = self.aoss_client
238
286
  if not client or not helpers:
239
287
  logger.warning('No OpenSearch client found')
@@ -243,16 +291,22 @@ class GraphDriver(ABC):
243
291
  if name.lower() == index['index_name']:
244
292
  to_index = []
245
293
  for d in data:
294
+ doc = {}
295
+ for p in index['body']['mappings']['properties']:
296
+ if p in d: # protect against missing fields
297
+ doc[p] = d[p]
298
+
246
299
  item = {
247
300
  '_index': name,
248
- '_routing': d.get('group_id'), # shard routing
301
+ '_id': d['uuid'],
302
+ '_routing': d.get('group_id'),
303
+ '_source': doc,
249
304
  }
250
- for p in index['body']['mappings']['properties']:
251
- if p in d: # protect against missing fields
252
- item[p] = d[p]
253
305
  to_index.append(item)
254
306
 
255
- success, failed = helpers.bulk(client, to_index, stats_only=True)
307
+ success, failed = await helpers.async_bulk(
308
+ client, to_index, stats_only=True, request_timeout=60
309
+ )
256
310
 
257
311
  return success if failed == 0 else success
258
312
 
@@ -28,7 +28,13 @@ logger = logging.getLogger(__name__)
28
28
 
29
29
  try:
30
30
  import boto3
31
- from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection
31
+ from opensearchpy import (
32
+ AIOHttpConnection,
33
+ AsyncOpenSearch,
34
+ AWSV4SignerAuth,
35
+ Urllib3AWSV4SignerAuth,
36
+ Urllib3HttpConnection,
37
+ )
32
38
 
33
39
  _HAS_OPENSEARCH = True
34
40
  except ImportError:
@@ -50,6 +56,9 @@ class Neo4jDriver(GraphDriver):
50
56
  database: str = 'neo4j',
51
57
  aoss_host: str | None = None,
52
58
  aoss_port: int | None = None,
59
+ aws_profile_name: str | None = None,
60
+ aws_region: str | None = None,
61
+ aws_service: str | None = None,
53
62
  ):
54
63
  super().__init__()
55
64
  self.client = AsyncGraphDatabase.driver(
@@ -61,15 +70,17 @@ class Neo4jDriver(GraphDriver):
61
70
  self.aoss_client = None
62
71
  if aoss_host and aoss_port and boto3 is not None:
63
72
  try:
64
- session = boto3.Session()
65
- self.aoss_client = OpenSearch( # type: ignore
73
+ region = aws_region
74
+ service = aws_service
75
+ credentials = boto3.Session(profile_name=aws_profile_name).get_credentials()
76
+ auth = AWSV4SignerAuth(credentials, region or '', service or '')
77
+
78
+ self.aoss_client = AsyncOpenSearch(
66
79
  hosts=[{'host': aoss_host, 'port': aoss_port}],
67
- http_auth=Urllib3AWSV4SignerAuth( # type: ignore
68
- session.get_credentials(), session.region_name, 'aoss'
69
- ),
80
+ auth=auth,
70
81
  use_ssl=True,
71
82
  verify_certs=True,
72
- connection_class=Urllib3HttpConnection,
83
+ connection_class=AIOHttpConnection,
73
84
  pool_maxsize=20,
74
85
  ) # type: ignore
75
86
  except Exception as e:
@@ -237,12 +237,12 @@ class NeptuneDriver(GraphDriver):
237
237
  'You must provide an AOSS endpoint to create an OpenSearch driver.'
238
238
  )
239
239
  if not client.indices.exists(index=index_name):
240
- client.indices.create(index=index_name, body=index['body'])
240
+ await client.indices.create(index=index_name, body=index['body'])
241
241
 
242
242
  alias_name = index.get('alias_name', index_name)
243
243
 
244
244
  if not client.indices.exists_alias(name=alias_name, index=index_name):
245
- client.indices.put_alias(index=index_name, name=alias_name)
245
+ await client.indices.put_alias(index=index_name, name=alias_name)
246
246
 
247
247
  # Sleep for 1 minute to let the index creation complete
248
248
  await asyncio.sleep(60)
graphiti_core/edges.py CHANGED
@@ -25,7 +25,7 @@ from uuid import uuid4
25
25
  from pydantic import BaseModel, Field
26
26
  from typing_extensions import LiteralString
27
27
 
28
- from graphiti_core.driver.driver import GraphDriver, GraphProvider
28
+ from graphiti_core.driver.driver import ENTITY_EDGE_INDEX_NAME, GraphDriver, GraphProvider
29
29
  from graphiti_core.embedder import EmbedderClient
30
30
  from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
31
31
  from graphiti_core.helpers import parse_db_date
@@ -77,6 +77,13 @@ class Edge(BaseModel, ABC):
77
77
  uuid=self.uuid,
78
78
  )
79
79
 
80
+ if driver.aoss_client:
81
+ await driver.aoss_client.delete(
82
+ index=ENTITY_EDGE_INDEX_NAME,
83
+ id=self.uuid,
84
+ params={'routing': self.group_id},
85
+ )
86
+
80
87
  logger.debug(f'Deleted Edge: {self.uuid}')
81
88
 
82
89
  @classmethod
@@ -108,6 +115,12 @@ class Edge(BaseModel, ABC):
108
115
  uuids=uuids,
109
116
  )
110
117
 
118
+ if driver.aoss_client:
119
+ await driver.aoss_client.delete_by_query(
120
+ index=ENTITY_EDGE_INDEX_NAME,
121
+ body={'query': {'terms': {'uuid': uuids}}},
122
+ )
123
+
111
124
  logger.debug(f'Deleted Edges: {uuids}')
112
125
 
113
126
  def __hash__(self):
@@ -256,13 +269,13 @@ class EntityEdge(Edge):
256
269
  RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
257
270
  """
258
271
  elif driver.aoss_client:
259
- resp = driver.aoss_client.search(
272
+ resp = await driver.aoss_client.search(
260
273
  body={
261
274
  'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
262
275
  'size': 1,
263
276
  },
264
- index='entity_edges',
265
- routing=self.group_id,
277
+ index=ENTITY_EDGE_INDEX_NAME,
278
+ params={'routing': self.group_id},
266
279
  )
267
280
 
268
281
  if resp['hits']['hits']:
@@ -314,7 +327,7 @@ class EntityEdge(Edge):
314
327
  edge_data.update(self.attributes or {})
315
328
 
316
329
  if driver.aoss_client:
317
- driver.save_to_aoss('entity_edges', [edge_data]) # pyright: ignore reportAttributeAccessIssue
330
+ await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, [edge_data]) # pyright: ignore reportAttributeAccessIssue
318
331
 
319
332
  result = await driver.execute_query(
320
333
  get_entity_edge_save_query(driver.provider),
@@ -351,6 +364,35 @@ class EntityEdge(Edge):
351
364
  raise EdgeNotFoundError(uuid)
352
365
  return edges[0]
353
366
 
367
+ @classmethod
368
+ async def get_between_nodes(
369
+ cls, driver: GraphDriver, source_node_uuid: str, target_node_uuid: str
370
+ ):
371
+ match_query = """
372
+ MATCH (n:Entity {uuid: $source_node_uuid})-[e:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
373
+ """
374
+ if driver.provider == GraphProvider.KUZU:
375
+ match_query = """
376
+ MATCH (n:Entity {uuid: $source_node_uuid})
377
+ -[:RELATES_TO]->(e:RelatesToNode_)
378
+ -[:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
379
+ """
380
+
381
+ records, _, _ = await driver.execute_query(
382
+ match_query
383
+ + """
384
+ RETURN
385
+ """
386
+ + get_entity_edge_return_query(driver.provider),
387
+ source_node_uuid=source_node_uuid,
388
+ target_node_uuid=target_node_uuid,
389
+ routing_='r',
390
+ )
391
+
392
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
393
+
394
+ return edges
395
+
354
396
  @classmethod
355
397
  async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
356
398
  if len(uuids) == 0:
graphiti_core/graphiti.py CHANGED
@@ -60,9 +60,7 @@ from graphiti_core.search.search_config_recipes import (
60
60
  from graphiti_core.search.search_filters import SearchFilters
61
61
  from graphiti_core.search.search_utils import (
62
62
  RELEVANT_SCHEMA_LIMIT,
63
- get_edge_invalidation_candidates,
64
63
  get_mentioned_nodes,
65
- get_relevant_edges,
66
64
  )
67
65
  from graphiti_core.telemetry import capture_event
68
66
  from graphiti_core.utils.bulk_utils import (
@@ -1037,10 +1035,28 @@ class Graphiti:
1037
1035
 
1038
1036
  updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
1039
1037
 
1040
- related_edges = (await get_relevant_edges(self.driver, [updated_edge], SearchFilters()))[0]
1038
+ valid_edges = await EntityEdge.get_between_nodes(
1039
+ self.driver, edge.source_node_uuid, edge.target_node_uuid
1040
+ )
1041
+
1042
+ related_edges = (
1043
+ await search(
1044
+ self.clients,
1045
+ updated_edge.fact,
1046
+ group_ids=[updated_edge.group_id],
1047
+ config=EDGE_HYBRID_SEARCH_RRF,
1048
+ search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
1049
+ )
1050
+ ).edges
1041
1051
  existing_edges = (
1042
- await get_edge_invalidation_candidates(self.driver, [updated_edge], SearchFilters())
1043
- )[0]
1052
+ await search(
1053
+ self.clients,
1054
+ updated_edge.fact,
1055
+ group_ids=[updated_edge.group_id],
1056
+ config=EDGE_HYBRID_SEARCH_RRF,
1057
+ search_filter=SearchFilters(),
1058
+ )
1059
+ ).edges
1044
1060
 
1045
1061
  resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
1046
1062
  self.llm_client,
graphiti_core/nodes.py CHANGED
@@ -26,7 +26,14 @@ from uuid import uuid4
26
26
  from pydantic import BaseModel, Field
27
27
  from typing_extensions import LiteralString
28
28
 
29
- from graphiti_core.driver.driver import GraphDriver, GraphProvider
29
+ from graphiti_core.driver.driver import (
30
+ COMMUNITY_INDEX_NAME,
31
+ ENTITY_EDGE_INDEX_NAME,
32
+ ENTITY_INDEX_NAME,
33
+ EPISODE_INDEX_NAME,
34
+ GraphDriver,
35
+ GraphProvider,
36
+ )
30
37
  from graphiti_core.embedder import EmbedderClient
31
38
  from graphiti_core.errors import NodeNotFoundError
32
39
  from graphiti_core.helpers import parse_db_date
@@ -94,13 +101,39 @@ class Node(BaseModel, ABC):
94
101
  async def delete(self, driver: GraphDriver):
95
102
  match driver.provider:
96
103
  case GraphProvider.NEO4J:
97
- await driver.execute_query(
104
+ records, _, _ = await driver.execute_query(
98
105
  """
99
- MATCH (n:Entity|Episodic|Community {uuid: $uuid})
106
+ MATCH (n {uuid: $uuid})
107
+ WHERE n:Entity OR n:Episodic OR n:Community
108
+ OPTIONAL MATCH (n)-[r]-()
109
+ WITH collect(r.uuid) AS edge_uuids, n
100
110
  DETACH DELETE n
111
+ RETURN edge_uuids
101
112
  """,
102
113
  uuid=self.uuid,
103
114
  )
115
+
116
+ edge_uuids: list[str] = records[0].get('edge_uuids', []) if records else []
117
+
118
+ if driver.aoss_client:
119
+ # Delete the node from OpenSearch indices
120
+ for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
121
+ await driver.aoss_client.delete(
122
+ index=index,
123
+ id=self.uuid,
124
+ params={'routing': self.group_id},
125
+ )
126
+
127
+ # Bulk delete the detached edges
128
+ if edge_uuids:
129
+ actions = []
130
+ for eid in edge_uuids:
131
+ actions.append(
132
+ {'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}}
133
+ )
134
+
135
+ await driver.aoss_client.bulk(body=actions)
136
+
104
137
  case GraphProvider.KUZU:
105
138
  for label in ['Episodic', 'Community']:
106
139
  await driver.execute_query(
@@ -162,6 +195,32 @@ class Node(BaseModel, ABC):
162
195
  group_id=group_id,
163
196
  batch_size=batch_size,
164
197
  )
198
+
199
+ if driver.aoss_client:
200
+ await driver.aoss_client.delete_by_query(
201
+ index=EPISODE_INDEX_NAME,
202
+ body={'query': {'term': {'group_id': group_id}}},
203
+ params={'routing': group_id},
204
+ )
205
+
206
+ await driver.aoss_client.delete_by_query(
207
+ index=ENTITY_INDEX_NAME,
208
+ body={'query': {'term': {'group_id': group_id}}},
209
+ params={'routing': group_id},
210
+ )
211
+
212
+ await driver.aoss_client.delete_by_query(
213
+ index=COMMUNITY_INDEX_NAME,
214
+ body={'query': {'term': {'group_id': group_id}}},
215
+ params={'routing': group_id},
216
+ )
217
+
218
+ await driver.aoss_client.delete_by_query(
219
+ index=ENTITY_EDGE_INDEX_NAME,
220
+ body={'query': {'term': {'group_id': group_id}}},
221
+ params={'routing': group_id},
222
+ )
223
+
165
224
  case GraphProvider.KUZU:
166
225
  for label in ['Episodic', 'Community']:
167
226
  await driver.execute_query(
@@ -240,6 +299,23 @@ class Node(BaseModel, ABC):
240
299
  )
241
300
  case _: # Neo4J, Neptune
242
301
  async with driver.session() as session:
302
+ # Collect all edge UUIDs before deleting nodes
303
+ result = await session.run(
304
+ """
305
+ MATCH (n:Entity|Episodic|Community)
306
+ WHERE n.uuid IN $uuids
307
+ MATCH (n)-[r]-()
308
+ RETURN collect(r.uuid) AS edge_uuids
309
+ """,
310
+ uuids=uuids,
311
+ )
312
+
313
+ record = await result.single()
314
+ edge_uuids: list[str] = (
315
+ record['edge_uuids'] if record and record['edge_uuids'] else []
316
+ )
317
+
318
+ # Now delete the nodes in batches
243
319
  await session.run(
244
320
  """
245
321
  MATCH (n:Entity|Episodic|Community)
@@ -253,6 +329,20 @@ class Node(BaseModel, ABC):
253
329
  batch_size=batch_size,
254
330
  )
255
331
 
332
+ if driver.aoss_client:
333
+ for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
334
+ await driver.aoss_client.delete_by_query(
335
+ index=index,
336
+ body={'query': {'terms': {'uuid': uuids}}},
337
+ )
338
+
339
+ if edge_uuids:
340
+ actions = [
341
+ {'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}}
342
+ for eid in edge_uuids
343
+ ]
344
+ await driver.aoss_client.bulk(body=actions)
345
+
256
346
  @classmethod
257
347
  async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
258
348
 
@@ -286,7 +376,7 @@ class EpisodicNode(Node):
286
376
  }
287
377
 
288
378
  if driver.aoss_client:
289
- driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
379
+ await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
290
380
  'episodes',
291
381
  [episode_args],
292
382
  )
@@ -426,13 +516,13 @@ class EntityNode(Node):
426
516
  RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
427
517
  """
428
518
  elif driver.aoss_client:
429
- resp = driver.aoss_client.search(
519
+ resp = await driver.aoss_client.search(
430
520
  body={
431
521
  'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
432
522
  'size': 1,
433
523
  },
434
- index='entities',
435
- routing=self.group_id,
524
+ index=ENTITY_INDEX_NAME,
525
+ params={'routing': self.group_id},
436
526
  )
437
527
 
438
528
  if resp['hits']['hits']:
@@ -479,7 +569,7 @@ class EntityNode(Node):
479
569
  labels = ':'.join(self.labels + ['Entity'])
480
570
 
481
571
  if driver.aoss_client:
482
- driver.save_to_aoss('entities', [entity_data]) # pyright: ignore reportAttributeAccessIssue
572
+ await driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue
483
573
 
484
574
  result = await driver.execute_query(
485
575
  get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),
@@ -577,7 +667,7 @@ class CommunityNode(Node):
577
667
 
578
668
  async def save(self, driver: GraphDriver):
579
669
  if driver.provider == GraphProvider.NEPTUNE:
580
- driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
670
+ await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
581
671
  'communities',
582
672
  [{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
583
673
  )
@@ -52,6 +52,7 @@ class SearchFilters(BaseModel):
52
52
  invalid_at: list[list[DateFilter]] | None = Field(default=None)
53
53
  created_at: list[list[DateFilter]] | None = Field(default=None)
54
54
  expired_at: list[list[DateFilter]] | None = Field(default=None)
55
+ edge_uuids: list[str] | None = Field(default=None)
55
56
 
56
57
 
57
58
  def cypher_to_opensearch_operator(op: ComparisonOperator) -> str:
@@ -108,6 +109,10 @@ def edge_search_filter_query_constructor(
108
109
  filter_queries.append('e.name in $edge_types')
109
110
  filter_params['edge_types'] = edge_types
110
111
 
112
+ if filters.edge_uuids is not None:
113
+ filter_queries.append('e.uuid in $edge_uuids')
114
+ filter_params['edge_uuids'] = filters.edge_uuids
115
+
111
116
  if filters.node_labels is not None:
112
117
  if provider == GraphProvider.KUZU:
113
118
  node_label_filter = (
@@ -261,6 +266,9 @@ def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters)
261
266
  if search_filters.edge_types:
262
267
  filters.append({'terms': {'edge_types': search_filters.edge_types}})
263
268
 
269
+ if search_filters.edge_uuids:
270
+ filters.append({'terms': {'uuid': search_filters.edge_uuids}})
271
+
264
272
  for field in ['valid_at', 'invalid_at', 'created_at', 'expired_at']:
265
273
  ranges = getattr(search_filters, field)
266
274
  if ranges:
@@ -23,7 +23,13 @@ import numpy as np
23
23
  from numpy._typing import NDArray
24
24
  from typing_extensions import LiteralString
25
25
 
26
- from graphiti_core.driver.driver import GraphDriver, GraphProvider
26
+ from graphiti_core.driver.driver import (
27
+ ENTITY_EDGE_INDEX_NAME,
28
+ ENTITY_INDEX_NAME,
29
+ EPISODE_INDEX_NAME,
30
+ GraphDriver,
31
+ GraphProvider,
32
+ )
27
33
  from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
28
34
  from graphiti_core.graph_queries import (
29
35
  get_nodes_query,
@@ -209,11 +215,11 @@ async def edge_fulltext_search(
209
215
  # Match the edge ids and return the values
210
216
  query = (
211
217
  """
212
- UNWIND $ids as id
213
- MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
214
- WHERE e.group_id IN $group_ids
215
- AND id(e)=id
216
- """
218
+ UNWIND $ids as id
219
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
220
+ WHERE e.group_id IN $group_ids
221
+ AND id(e)=id
222
+ """
217
223
  + filter_query
218
224
  + """
219
225
  AND id(e)=id
@@ -248,17 +254,21 @@ async def edge_fulltext_search(
248
254
  elif driver.aoss_client:
249
255
  route = group_ids[0] if group_ids else None
250
256
  filters = build_aoss_edge_filters(group_ids or [], search_filter)
251
- res = driver.aoss_client.search(
252
- index='entity_edges',
253
- routing=route,
254
- _source=['uuid'],
255
- query={
256
- 'bool': {
257
- 'filter': filters,
258
- 'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}],
259
- }
257
+ res = await driver.aoss_client.search(
258
+ index=ENTITY_EDGE_INDEX_NAME,
259
+ params={'routing': route},
260
+ body={
261
+ 'size': limit,
262
+ '_source': ['uuid'],
263
+ 'query': {
264
+ 'bool': {
265
+ 'filter': filters,
266
+ 'must': [{'match': {'fact': {'query': query, 'operator': 'or'}}}],
267
+ }
268
+ },
260
269
  },
261
270
  )
271
+
262
272
  if res['hits']['total']['value'] > 0:
263
273
  input_uuids = {}
264
274
  for r in res['hits']['hits']:
@@ -344,8 +354,8 @@ async def edge_similarity_search(
344
354
  if driver.provider == GraphProvider.NEPTUNE:
345
355
  query = (
346
356
  """
347
- MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
348
- """
357
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
358
+ """
349
359
  + filter_query
350
360
  + """
351
361
  RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
@@ -406,17 +416,22 @@ async def edge_similarity_search(
406
416
  elif driver.aoss_client:
407
417
  route = group_ids[0] if group_ids else None
408
418
  filters = build_aoss_edge_filters(group_ids or [], search_filter)
409
- res = driver.aoss_client.search(
410
- index='entity_edges',
411
- routing=route,
412
- _source=['uuid'],
413
- knn={
414
- 'field': 'fact_embedding',
415
- 'query_vector': search_vector,
416
- 'k': limit,
417
- 'num_candidates': 1000,
419
+ res = await driver.aoss_client.search(
420
+ index=ENTITY_EDGE_INDEX_NAME,
421
+ params={'routing': route},
422
+ body={
423
+ 'size': limit,
424
+ '_source': ['uuid'],
425
+ 'query': {
426
+ 'knn': {
427
+ 'fact_embedding': {
428
+ 'vector': list(map(float, search_vector)),
429
+ 'k': limit,
430
+ 'filter': {'bool': {'filter': filters}},
431
+ }
432
+ }
433
+ },
418
434
  },
419
- query={'bool': {'filter': filters}},
420
435
  )
421
436
 
422
437
  if res['hits']['total']['value'] > 0:
@@ -428,6 +443,7 @@ async def edge_similarity_search(
428
443
  entity_edges = await EntityEdge.get_by_uuids(driver, list(input_uuids.keys()))
429
444
  entity_edges.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
430
445
  return entity_edges
446
+ return []
431
447
 
432
448
  else:
433
449
  query = (
@@ -622,11 +638,11 @@ async def node_fulltext_search(
622
638
  # Match the edge ides and return the values
623
639
  query = (
624
640
  """
625
- UNWIND $ids as i
626
- MATCH (n:Entity)
627
- WHERE n.uuid=i.id
628
- RETURN
629
- """
641
+ UNWIND $ids as i
642
+ MATCH (n:Entity)
643
+ WHERE n.uuid=i.id
644
+ RETURN
645
+ """
630
646
  + get_entity_node_return_query(driver.provider)
631
647
  + """
632
648
  ORDER BY i.score DESC
@@ -646,25 +662,27 @@ async def node_fulltext_search(
646
662
  elif driver.aoss_client:
647
663
  route = group_ids[0] if group_ids else None
648
664
  filters = build_aoss_node_filters(group_ids or [], search_filter)
649
- res = driver.aoss_client.search(
650
- 'entities',
651
- routing=route,
652
- _source=['uuid'],
653
- query={
654
- 'bool': {
655
- 'filter': filters,
656
- 'must': [
657
- {
658
- 'multi_match': {
659
- 'query': query,
660
- 'field': ['name', 'summary'],
661
- 'operator': 'or',
665
+ res = await driver.aoss_client.search(
666
+ index=ENTITY_INDEX_NAME,
667
+ params={'routing': route},
668
+ body={
669
+ '_source': ['uuid'],
670
+ 'size': limit,
671
+ 'query': {
672
+ 'bool': {
673
+ 'filter': filters,
674
+ 'must': [
675
+ {
676
+ 'multi_match': {
677
+ 'query': query,
678
+ 'fields': ['name', 'summary'],
679
+ 'operator': 'or',
680
+ }
662
681
  }
663
- }
664
- ],
665
- }
682
+ ],
683
+ }
684
+ },
666
685
  },
667
- limit=limit,
668
686
  )
669
687
 
670
688
  if res['hits']['total']['value'] > 0:
@@ -734,8 +752,8 @@ async def node_similarity_search(
734
752
  if driver.provider == GraphProvider.NEPTUNE:
735
753
  query = (
736
754
  """
737
- MATCH (n:Entity)
738
- """
755
+ MATCH (n:Entity)
756
+ """
739
757
  + filter_query
740
758
  + """
741
759
  RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@@ -764,11 +782,11 @@ async def node_similarity_search(
764
782
  # Match the edge ides and return the values
765
783
  query = (
766
784
  """
767
- UNWIND $ids as i
768
- MATCH (n:Entity)
769
- WHERE id(n)=i.id
770
- RETURN
771
- """
785
+ UNWIND $ids as i
786
+ MATCH (n:Entity)
787
+ WHERE id(n)=i.id
788
+ RETURN
789
+ """
772
790
  + get_entity_node_return_query(driver.provider)
773
791
  + """
774
792
  ORDER BY i.score DESC
@@ -789,17 +807,22 @@ async def node_similarity_search(
789
807
  elif driver.aoss_client:
790
808
  route = group_ids[0] if group_ids else None
791
809
  filters = build_aoss_node_filters(group_ids or [], search_filter)
792
- res = driver.aoss_client.search(
793
- index='entities',
794
- routing=route,
795
- _source=['uuid'],
796
- knn={
797
- 'field': 'fact_embedding',
798
- 'query_vector': search_vector,
799
- 'k': limit,
800
- 'num_candidates': 1000,
810
+ res = await driver.aoss_client.search(
811
+ index=ENTITY_INDEX_NAME,
812
+ params={'routing': route},
813
+ body={
814
+ 'size': limit,
815
+ '_source': ['uuid'],
816
+ 'query': {
817
+ 'knn': {
818
+ 'name_embedding': {
819
+ 'vector': list(map(float, search_vector)),
820
+ 'k': limit,
821
+ 'filter': {'bool': {'filter': filters}},
822
+ }
823
+ }
824
+ },
801
825
  },
802
- query={'bool': {'filter': filters}},
803
826
  )
804
827
 
805
828
  if res['hits']['total']['value'] > 0:
@@ -811,11 +834,12 @@ async def node_similarity_search(
811
834
  entity_nodes = await EntityNode.get_by_uuids(driver, list(input_uuids.keys()))
812
835
  entity_nodes.sort(key=lambda e: input_uuids.get(e.uuid, 0), reverse=True)
813
836
  return entity_nodes
837
+ return []
814
838
  else:
815
839
  query = (
816
840
  """
817
- MATCH (n:Entity)
818
- """
841
+ MATCH (n:Entity)
842
+ """
819
843
  + filter_query
820
844
  + """
821
845
  WITH n, """
@@ -988,11 +1012,12 @@ async def episode_fulltext_search(
988
1012
  return []
989
1013
  elif driver.aoss_client:
990
1014
  route = group_ids[0] if group_ids else None
991
- res = driver.aoss_client.search(
992
- 'episodes',
993
- routing=route,
994
- _source=['uuid'],
995
- query={
1015
+ res = await driver.aoss_client.search(
1016
+ index=EPISODE_INDEX_NAME,
1017
+ params={'routing': route},
1018
+ body={
1019
+ 'size': limit,
1020
+ '_source': ['uuid'],
996
1021
  'bool': {
997
1022
  'filter': {'terms': group_ids},
998
1023
  'must': [
@@ -1004,9 +1029,8 @@ async def episode_fulltext_search(
1004
1029
  }
1005
1030
  }
1006
1031
  ],
1007
- }
1032
+ },
1008
1033
  },
1009
- limit=limit,
1010
1034
  )
1011
1035
 
1012
1036
  if res['hits']['total']['value'] > 0:
@@ -1147,8 +1171,8 @@ async def community_similarity_search(
1147
1171
  if driver.provider == GraphProvider.NEPTUNE:
1148
1172
  query = (
1149
1173
  """
1150
- MATCH (n:Community)
1151
- """
1174
+ MATCH (n:Community)
1175
+ """
1152
1176
  + group_filter_query
1153
1177
  + """
1154
1178
  RETURN DISTINCT id(n) as id, n.name_embedding as embedding
@@ -1207,8 +1231,8 @@ async def community_similarity_search(
1207
1231
 
1208
1232
  query = (
1209
1233
  """
1210
- MATCH (c:Community)
1211
- """
1234
+ MATCH (c:Community)
1235
+ """
1212
1236
  + group_filter_query
1213
1237
  + """
1214
1238
  WITH c,
@@ -1350,9 +1374,9 @@ async def get_relevant_nodes(
1350
1374
  # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
1351
1375
  query = (
1352
1376
  """
1353
- UNWIND $nodes AS node
1354
- MATCH (n:Entity {group_id: $group_id})
1355
- """
1377
+ UNWIND $nodes AS node
1378
+ MATCH (n:Entity {group_id: $group_id})
1379
+ """
1356
1380
  + filter_query
1357
1381
  + """
1358
1382
  WITH node, n, """
@@ -1397,9 +1421,9 @@ async def get_relevant_nodes(
1397
1421
  else:
1398
1422
  query = (
1399
1423
  """
1400
- UNWIND $nodes AS node
1401
- MATCH (n:Entity {group_id: $group_id})
1402
- """
1424
+ UNWIND $nodes AS node
1425
+ MATCH (n:Entity {group_id: $group_id})
1426
+ """
1403
1427
  + filter_query
1404
1428
  + """
1405
1429
  WITH node, n, """
@@ -1488,9 +1512,9 @@ async def get_relevant_edges(
1488
1512
  if driver.provider == GraphProvider.NEPTUNE:
1489
1513
  query = (
1490
1514
  """
1491
- UNWIND $edges AS edge
1492
- MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
1493
- """
1515
+ UNWIND $edges AS edge
1516
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
1517
+ """
1494
1518
  + filter_query
1495
1519
  + """
1496
1520
  WITH e, edge
@@ -1560,9 +1584,9 @@ async def get_relevant_edges(
1560
1584
 
1561
1585
  query = (
1562
1586
  """
1563
- UNWIND $edges AS edge
1564
- MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
1565
- """
1587
+ UNWIND $edges AS edge
1588
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
1589
+ """
1566
1590
  + filter_query
1567
1591
  + """
1568
1592
  WITH e, edge, n, m, """
@@ -1598,9 +1622,9 @@ async def get_relevant_edges(
1598
1622
  else:
1599
1623
  query = (
1600
1624
  """
1601
- UNWIND $edges AS edge
1602
- MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
1603
- """
1625
+ UNWIND $edges AS edge
1626
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
1627
+ """
1604
1628
  + filter_query
1605
1629
  + """
1606
1630
  WITH e, edge, """
@@ -1673,10 +1697,10 @@ async def get_edge_invalidation_candidates(
1673
1697
  if driver.provider == GraphProvider.NEPTUNE:
1674
1698
  query = (
1675
1699
  """
1676
- UNWIND $edges AS edge
1677
- MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
1678
- WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
1679
- """
1700
+ UNWIND $edges AS edge
1701
+ MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
1702
+ WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
1703
+ """
1680
1704
  + filter_query
1681
1705
  + """
1682
1706
  WITH e, edge
@@ -1746,10 +1770,10 @@ async def get_edge_invalidation_candidates(
1746
1770
 
1747
1771
  query = (
1748
1772
  """
1749
- UNWIND $edges AS edge
1750
- MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
1751
- WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
1752
- """
1773
+ UNWIND $edges AS edge
1774
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
1775
+ WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
1776
+ """
1753
1777
  + filter_query
1754
1778
  + """
1755
1779
  WITH edge, e, n, m, """
@@ -1785,10 +1809,10 @@ async def get_edge_invalidation_candidates(
1785
1809
  else:
1786
1810
  query = (
1787
1811
  """
1788
- UNWIND $edges AS edge
1789
- MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
1790
- WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
1791
- """
1812
+ UNWIND $edges AS edge
1813
+ MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
1814
+ WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
1815
+ """
1792
1816
  + filter_query
1793
1817
  + """
1794
1818
  WITH edge, e, """
@@ -23,7 +23,14 @@ import numpy as np
23
23
  from pydantic import BaseModel, Field
24
24
  from typing_extensions import Any
25
25
 
26
- from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
26
+ from graphiti_core.driver.driver import (
27
+ ENTITY_EDGE_INDEX_NAME,
28
+ ENTITY_INDEX_NAME,
29
+ EPISODE_INDEX_NAME,
30
+ GraphDriver,
31
+ GraphDriverSession,
32
+ GraphProvider,
33
+ )
27
34
  from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
28
35
  from graphiti_core.embedder import EmbedderClient
29
36
  from graphiti_core.graphiti_types import GraphitiClients
@@ -203,9 +210,9 @@ async def add_nodes_and_edges_bulk_tx(
203
210
  )
204
211
 
205
212
  if driver.aoss_client:
206
- driver.save_to_aoss('episodes', episodes)
207
- driver.save_to_aoss('entities', nodes)
208
- driver.save_to_aoss('entity_edges', edges)
213
+ await driver.save_to_aoss(EPISODE_INDEX_NAME, episodes)
214
+ await driver.save_to_aoss(ENTITY_INDEX_NAME, nodes)
215
+ await driver.save_to_aoss(ENTITY_EDGE_INDEX_NAME, edges)
209
216
 
210
217
 
211
218
  async def extract_nodes_and_edges_bulk(
@@ -36,8 +36,10 @@ from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
36
36
  from graphiti_core.prompts import prompt_library
37
37
  from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
38
38
  from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
39
+ from graphiti_core.search.search import search
40
+ from graphiti_core.search.search_config import SearchResults
41
+ from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
39
42
  from graphiti_core.search.search_filters import SearchFilters
40
- from graphiti_core.search.search_utils import get_edge_invalidation_candidates, get_relevant_edges
41
43
  from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
42
44
 
43
45
  logger = logging.getLogger(__name__)
@@ -258,12 +260,44 @@ async def resolve_extracted_edges(
258
260
  embedder = clients.embedder
259
261
  await create_entity_edge_embeddings(embedder, extracted_edges)
260
262
 
261
- search_results = await semaphore_gather(
262
- get_relevant_edges(driver, extracted_edges, SearchFilters()),
263
- get_edge_invalidation_candidates(driver, extracted_edges, SearchFilters(), 0.2),
263
+ valid_edges_list: list[list[EntityEdge]] = await semaphore_gather(
264
+ *[
265
+ EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
266
+ for edge in extracted_edges
267
+ ]
268
+ )
269
+
270
+ related_edges_results: list[SearchResults] = await semaphore_gather(
271
+ *[
272
+ search(
273
+ clients,
274
+ extracted_edge.fact,
275
+ group_ids=[extracted_edge.group_id],
276
+ config=EDGE_HYBRID_SEARCH_RRF,
277
+ search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
278
+ )
279
+ for extracted_edge, valid_edges in zip(extracted_edges, valid_edges_list, strict=True)
280
+ ]
264
281
  )
265
282
 
266
- related_edges_lists, edge_invalidation_candidates = search_results
283
+ related_edges_lists: list[list[EntityEdge]] = [result.edges for result in related_edges_results]
284
+
285
+ edge_invalidation_candidate_results: list[SearchResults] = await semaphore_gather(
286
+ *[
287
+ search(
288
+ clients,
289
+ extracted_edge.fact,
290
+ group_ids=[extracted_edge.group_id],
291
+ config=EDGE_HYBRID_SEARCH_RRF,
292
+ search_filter=SearchFilters(),
293
+ )
294
+ for extracted_edge in extracted_edges
295
+ ]
296
+ )
297
+
298
+ edge_invalidation_candidates: list[list[EntityEdge]] = [
299
+ result.edges for result in edge_invalidation_candidate_results
300
+ ]
267
301
 
268
302
  logger.debug(
269
303
  f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
@@ -95,6 +95,8 @@ async def clear_data(driver: GraphDriver, group_ids: list[str] | None = None):
95
95
 
96
96
  async def delete_all(tx):
97
97
  await tx.run('MATCH (n) DETACH DELETE n')
98
+ if driver.aoss_client:
99
+ await driver.clear_aoss_indices()
98
100
 
99
101
  async def delete_group_ids(tx):
100
102
  labels = ['Entity', 'Episodic', 'Community']
@@ -151,9 +153,9 @@ async def retrieve_episodes(
151
153
 
152
154
  query: LiteralString = (
153
155
  """
154
- MATCH (e:Episodic)
155
- WHERE e.valid_at <= $reference_time
156
- """
156
+ MATCH (e:Episodic)
157
+ WHERE e.valid_at <= $reference_time
158
+ """
157
159
  + query_filter
158
160
  + """
159
161
  RETURN
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: graphiti-core
3
- Version: 0.21.0rc1
3
+ Version: 0.21.0rc2
4
4
  Summary: A temporal graph building library
5
5
  Project-URL: Homepage, https://help.getzep.com/graphiti/graphiti/overview
6
6
  Project-URL: Repository, https://github.com/getzep/graphiti
@@ -1,11 +1,11 @@
1
1
  graphiti_core/__init__.py,sha256=e5SWFkRiaUwfprYIeIgVIh7JDedNiloZvd3roU-0aDY,55
2
- graphiti_core/edges.py,sha256=eGDQBTtOoqLg5grykVFDwHEpxQRYTVZeM3u0NAkBI6Y,19380
2
+ graphiti_core/edges.py,sha256=PhJm_s28cHLEaIqcw66wP16hOq4P4bVQbC_sESHQkXU,20919
3
3
  graphiti_core/errors.py,sha256=cH_v9TPgEPeQE6GFOHIg5TvejpUCBddGarMY2Whxbwc,2707
4
4
  graphiti_core/graph_queries.py,sha256=9DWMiFTB-OmodMDaOws0lwzgiD7EUDNO7mAFJ1nxusE,6624
5
- graphiti_core/graphiti.py,sha256=eSQyajym9kefrf6PDCxCDzjWaco4hgFryyosg27husA,41601
5
+ graphiti_core/graphiti.py,sha256=UPa85sdCdO4xnl68EsFYnbkqTgta2LA-XwP7hzcAsyg,42072
6
6
  graphiti_core/graphiti_types.py,sha256=C_p2XwScQlCzo7ets097TrSLs9ATxPZQ4WCsxDS7QHc,1066
7
7
  graphiti_core/helpers.py,sha256=6q_wpiOW3_j28EfZ7FgWW7Hl5pONj_5zvVXZGW9FxTU,5175
8
- graphiti_core/nodes.py,sha256=xqVsSmMu7h7eorZ2TLFEBPBQLtvIID_haGl2fn7r_Gw,26591
8
+ graphiti_core/nodes.py,sha256=wYLQcVEXvQMxTpTc9LWSoPTzzaoUOm0rl07c9wS1XSY,30323
9
9
  graphiti_core/py.typed,sha256=vlmmzQOt7bmeQl9L3XJP4W6Ry0iiELepnOrinKz5KQg,79
10
10
  graphiti_core/cross_encoder/__init__.py,sha256=hry59vz21x-AtGZ0MJ7ugw0HTwJkXiddpp_Yqnwsen0,723
11
11
  graphiti_core/cross_encoder/bge_reranker_client.py,sha256=y3TfFxZh0Yvj6HUShmfUm6MC7OPXwWUlv1Qe5HF3S3I,1797
@@ -13,11 +13,11 @@ graphiti_core/cross_encoder/client.py,sha256=KLsbfWKOEaAV3adFe3XZlAeb-gje9_sVKCV
13
13
  graphiti_core/cross_encoder/gemini_reranker_client.py,sha256=hmITG5YIib52nrKvINwRi4xTfAO1U4jCCaEVIwImHw0,6208
14
14
  graphiti_core/cross_encoder/openai_reranker_client.py,sha256=WHMl6Q6gEslR2EzjwpFSZt2Kh6bnu8alkLvzmi0MDtg,4674
15
15
  graphiti_core/driver/__init__.py,sha256=kCWimqQU19airu5gKwCmZtZuXkDfaQfKSUhMDoL-rTA,626
16
- graphiti_core/driver/driver.py,sha256=JOoQ9omQHbBDgeNAMD5K7U0SyUScqmxdcdOB8KBfmMc,8299
16
+ graphiti_core/driver/driver.py,sha256=5YbpyUq7L1pf8s5R2FYv6B3KabTCmRvhSAHiVyRvt5o,10433
17
17
  graphiti_core/driver/falkordb_driver.py,sha256=JsNBRQHBVENA8eqAngD-8dw1aTH1ZKUtE1on8sd7owY,6431
18
18
  graphiti_core/driver/kuzu_driver.py,sha256=RcWu8E0CCdofrFe34NmCeqfuhaZr_7ZN5jqDkI3VQMI,5453
19
- graphiti_core/driver/neo4j_driver.py,sha256=he1DpxcrMBSzMD7ZIKF11VQ3U358cxHctrn9YXgaWLY,3831
20
- graphiti_core/driver/neptune_driver.py,sha256=ag4zr1bctB_GgAS-h6e9nNEmG-8P_f2H1MJQ8sPFifo,10299
19
+ graphiti_core/driver/neo4j_driver.py,sha256=E93PdOZaH7wzEbIfoiDSYht49jr6zSzvMMyo1INGEOw,4096
20
+ graphiti_core/driver/neptune_driver.py,sha256=akNLHhFHPEeQu-xO3PM51RomklntT6k5eA2CQ4AFbCc,10311
21
21
  graphiti_core/embedder/__init__.py,sha256=EL564ZuE-DZjcuKNUK_exMn_XHXm2LdO9fzdXePVKL4,179
22
22
  graphiti_core/embedder/azure_openai.py,sha256=OyomPwC1fIsddI-3n6g00kQFdQznZorBhHwkQKCLUok,2384
23
23
  graphiti_core/embedder/client.py,sha256=BXFMXvuPWxaAzPaPILnxtqQQ4JWBFQv9GdBLOXUWgwE,1158
@@ -58,23 +58,23 @@ graphiti_core/search/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hS
58
58
  graphiti_core/search/search.py,sha256=2kj7fybSFv6Fnf_cfEUhJhrpfzNtmkPPZ0hV3BQCDqg,18387
59
59
  graphiti_core/search/search_config.py,sha256=v_rUHsu1yo5OuPfEm21lSuXexQs-o8qYwSSemW2QWhU,4165
60
60
  graphiti_core/search/search_config_recipes.py,sha256=4GquRphHhJlpXQhAZOySYnCzBWYoTwxlJj44eTOavZQ,7443
61
- graphiti_core/search/search_filters.py,sha256=Gj7Lis62aFLV_3H9Vc4euDyYSkfX3fBS9rxjXNAN9es,9991
61
+ graphiti_core/search/search_filters.py,sha256=DOAmYkc6A0z20EZId5fJZj1RvLz4WeQcoPANk9k-Sh8,10304
62
62
  graphiti_core/search/search_helpers.py,sha256=wj3ARlCNnZixNNntgCdAqzGoE4de4lW3r4rSG-3WyGw,2877
63
- graphiti_core/search/search_utils.py,sha256=YQqHfquaDRZaKp-dSniBwyK9ybofGhfhEAGmiI83ZWo,73697
63
+ graphiti_core/search/search_utils.py,sha256=HqNYbyQklAq5GoOUo_W8Xut1GMGVYlNv0kE0rh03KHo,76827
64
64
  graphiti_core/telemetry/__init__.py,sha256=5kALLDlU9bb2v19CdN7qVANsJWyfnL9E60J6FFgzm3o,226
65
65
  graphiti_core/telemetry/telemetry.py,sha256=47LrzOVBCcZxsYPsnSxWFiztHoxYKKxPwyRX0hnbDGc,3230
66
66
  graphiti_core/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
67
- graphiti_core/utils/bulk_utils.py,sha256=KIi4E9muAtVQJQN8i40uBFbrizfNKAS60_VjS3RO7Nc,16938
67
+ graphiti_core/utils/bulk_utils.py,sha256=XyjWg_S6sQ1hM8HfHHoDnDn5yAWO1Ja0fYGFMfS7XV8,17071
68
68
  graphiti_core/utils/datetime_utils.py,sha256=J-zYSq7-H-2n9hYOXNIun12kM10vNX9mMATGR_egTmY,1806
69
69
  graphiti_core/utils/maintenance/__init__.py,sha256=vW4H1KyapTl-OOz578uZABYcpND4wPx3Vt6aAPaXh78,301
70
70
  graphiti_core/utils/maintenance/community_operations.py,sha256=XMiokEemn96GlvjkOvbo9hIX04Fea3eVj408NHG5P4o,11042
71
- graphiti_core/utils/maintenance/edge_operations.py,sha256=yxL5rc8eZh0GyduF_Vn04cqdmQQtCFwrbXEuoNF6G6E,20242
72
- graphiti_core/utils/maintenance/graph_data_operations.py,sha256=3UNSd2152q-EU1Cia1Nmlpn--nLFk6q3Mq6tcAx2sto,5988
71
+ graphiti_core/utils/maintenance/edge_operations.py,sha256=sejfmlbXCiMFcLAKFsw70_FHY1lVX0tLpdk4UCzuU-4,21418
72
+ graphiti_core/utils/maintenance/graph_data_operations.py,sha256=42icj3S_ELAJ-NK3jVS_rg_243dmnaZOyUitJj_uJ-M,6085
73
73
  graphiti_core/utils/maintenance/node_operations.py,sha256=r9ilkA01eq1z-nF8P_s1EXG6A6j15qmnfIqetnzqF50,13644
74
74
  graphiti_core/utils/maintenance/temporal_operations.py,sha256=IIaVtShpVkOYe6haxz3a1x3v54-MzaEXG8VsxFUNeoY,3582
75
75
  graphiti_core/utils/maintenance/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
76
76
  graphiti_core/utils/ontology_utils/entity_types_utils.py,sha256=4eVgxLWY6Q8k9cRJ5pW59IYF--U4nXZsZIGOVb_yHfQ,1285
77
- graphiti_core-0.21.0rc1.dist-info/METADATA,sha256=IzEjyOC4yTFqX-PIWWMTjew_atl6OvXNHZ8E7vx6hp4,26933
78
- graphiti_core-0.21.0rc1.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
79
- graphiti_core-0.21.0rc1.dist-info/licenses/LICENSE,sha256=KCUwCyDXuVEgmDWkozHyniRyWjnWUWjkuDHfU6o3JlA,11325
80
- graphiti_core-0.21.0rc1.dist-info/RECORD,,
77
+ graphiti_core-0.21.0rc2.dist-info/METADATA,sha256=Pqw2f7ySYgCDyyuWDBvT4KLrL-SCJXa2aFHlNzIJzJo,26933
78
+ graphiti_core-0.21.0rc2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
79
+ graphiti_core-0.21.0rc2.dist-info/licenses/LICENSE,sha256=KCUwCyDXuVEgmDWkozHyniRyWjnWUWjkuDHfU6o3JlA,11325
80
+ graphiti_core-0.21.0rc2.dist-info/RECORD,,