graphiti-core 0.22.0rc5__py3-none-any.whl → 0.22.1rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic; see the registry's advisory page for this release for more details.

graphiti_core/nodes.py CHANGED
@@ -27,10 +27,6 @@ from pydantic import BaseModel, Field
27
27
  from typing_extensions import LiteralString
28
28
 
29
29
  from graphiti_core.driver.driver import (
30
- COMMUNITY_INDEX_NAME,
31
- ENTITY_EDGE_INDEX_NAME,
32
- ENTITY_INDEX_NAME,
33
- EPISODE_INDEX_NAME,
34
30
  GraphDriver,
35
31
  GraphProvider,
36
32
  )
@@ -99,6 +95,9 @@ class Node(BaseModel, ABC):
99
95
  async def save(self, driver: GraphDriver): ...
100
96
 
101
97
  async def delete(self, driver: GraphDriver):
98
+ if driver.graph_operations_interface:
99
+ return await driver.graph_operations_interface.node_delete(self, driver)
100
+
102
101
  match driver.provider:
103
102
  case GraphProvider.NEO4J:
104
103
  records, _, _ = await driver.execute_query(
@@ -113,27 +112,6 @@ class Node(BaseModel, ABC):
113
112
  uuid=self.uuid,
114
113
  )
115
114
 
116
- edge_uuids: list[str] = records[0].get('edge_uuids', []) if records else []
117
-
118
- if driver.aoss_client:
119
- # Delete the node from OpenSearch indices
120
- for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
121
- await driver.aoss_client.delete(
122
- index=index,
123
- id=self.uuid,
124
- params={'routing': self.group_id},
125
- )
126
-
127
- # Bulk delete the detached edges
128
- if edge_uuids:
129
- actions = []
130
- for eid in edge_uuids:
131
- actions.append(
132
- {'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}}
133
- )
134
-
135
- await driver.aoss_client.bulk(body=actions)
136
-
137
115
  case GraphProvider.KUZU:
138
116
  for label in ['Episodic', 'Community']:
139
117
  await driver.execute_query(
@@ -181,14 +159,18 @@ class Node(BaseModel, ABC):
181
159
 
182
160
  @classmethod
183
161
  async def delete_by_group_id(cls, driver: GraphDriver, group_id: str, batch_size: int = 100):
162
+ if driver.graph_operations_interface:
163
+ return await driver.graph_operations_interface.node_delete_by_group_id(
164
+ cls, driver, group_id, batch_size
165
+ )
166
+
184
167
  match driver.provider:
185
168
  case GraphProvider.NEO4J:
186
169
  async with driver.session() as session:
187
170
  await session.run(
188
171
  """
189
172
  MATCH (n:Entity|Episodic|Community {group_id: $group_id})
190
- CALL {
191
- WITH n
173
+ CALL (n) {
192
174
  DETACH DELETE n
193
175
  } IN TRANSACTIONS OF $batch_size ROWS
194
176
  """,
@@ -196,31 +178,6 @@ class Node(BaseModel, ABC):
196
178
  batch_size=batch_size,
197
179
  )
198
180
 
199
- if driver.aoss_client:
200
- await driver.aoss_client.delete_by_query(
201
- index=EPISODE_INDEX_NAME,
202
- body={'query': {'term': {'group_id': group_id}}},
203
- params={'routing': group_id},
204
- )
205
-
206
- await driver.aoss_client.delete_by_query(
207
- index=ENTITY_INDEX_NAME,
208
- body={'query': {'term': {'group_id': group_id}}},
209
- params={'routing': group_id},
210
- )
211
-
212
- await driver.aoss_client.delete_by_query(
213
- index=COMMUNITY_INDEX_NAME,
214
- body={'query': {'term': {'group_id': group_id}}},
215
- params={'routing': group_id},
216
- )
217
-
218
- await driver.aoss_client.delete_by_query(
219
- index=ENTITY_EDGE_INDEX_NAME,
220
- body={'query': {'term': {'group_id': group_id}}},
221
- params={'routing': group_id},
222
- )
223
-
224
181
  case GraphProvider.KUZU:
225
182
  for label in ['Episodic', 'Community']:
226
183
  await driver.execute_query(
@@ -258,6 +215,11 @@ class Node(BaseModel, ABC):
258
215
 
259
216
  @classmethod
260
217
  async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
218
+ if driver.graph_operations_interface:
219
+ return await driver.graph_operations_interface.node_delete_by_uuids(
220
+ cls, driver, uuids, group_id=None, batch_size=batch_size
221
+ )
222
+
261
223
  match driver.provider:
262
224
  case GraphProvider.FALKORDB:
263
225
  for label in ['Entity', 'Episodic', 'Community']:
@@ -300,7 +262,7 @@ class Node(BaseModel, ABC):
300
262
  case _: # Neo4J, Neptune
301
263
  async with driver.session() as session:
302
264
  # Collect all edge UUIDs before deleting nodes
303
- result = await session.run(
265
+ await session.run(
304
266
  """
305
267
  MATCH (n:Entity|Episodic|Community)
306
268
  WHERE n.uuid IN $uuids
@@ -310,18 +272,12 @@ class Node(BaseModel, ABC):
310
272
  uuids=uuids,
311
273
  )
312
274
 
313
- record = await result.single()
314
- edge_uuids: list[str] = (
315
- record['edge_uuids'] if record and record['edge_uuids'] else []
316
- )
317
-
318
275
  # Now delete the nodes in batches
319
276
  await session.run(
320
277
  """
321
278
  MATCH (n:Entity|Episodic|Community)
322
279
  WHERE n.uuid IN $uuids
323
- CALL {
324
- WITH n
280
+ CALL (n) {
325
281
  DETACH DELETE n
326
282
  } IN TRANSACTIONS OF $batch_size ROWS
327
283
  """,
@@ -329,20 +285,6 @@ class Node(BaseModel, ABC):
329
285
  batch_size=batch_size,
330
286
  )
331
287
 
332
- if driver.aoss_client:
333
- for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
334
- await driver.aoss_client.delete_by_query(
335
- index=index,
336
- body={'query': {'terms': {'uuid': uuids}}},
337
- )
338
-
339
- if edge_uuids:
340
- actions = [
341
- {'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}}
342
- for eid in edge_uuids
343
- ]
344
- await driver.aoss_client.bulk(body=actions)
345
-
346
288
  @classmethod
347
289
  async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
348
290
 
@@ -363,6 +305,9 @@ class EpisodicNode(Node):
363
305
  )
364
306
 
365
307
  async def save(self, driver: GraphDriver):
308
+ if driver.graph_operations_interface:
309
+ return await driver.graph_operations_interface.episodic_node_save(self, driver)
310
+
366
311
  episode_args = {
367
312
  'uuid': self.uuid,
368
313
  'name': self.name,
@@ -375,12 +320,6 @@ class EpisodicNode(Node):
375
320
  'source': self.source.value,
376
321
  }
377
322
 
378
- if driver.aoss_client:
379
- await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
380
- 'episodes',
381
- [episode_args],
382
- )
383
-
384
323
  result = await driver.execute_query(
385
324
  get_episode_node_save_query(driver.provider), **episode_args
386
325
  )
@@ -510,26 +449,14 @@ class EntityNode(Node):
510
449
  return self.name_embedding
511
450
 
512
451
  async def load_name_embedding(self, driver: GraphDriver):
452
+ if driver.graph_operations_interface:
453
+ return await driver.graph_operations_interface.node_load_embeddings(self, driver)
454
+
513
455
  if driver.provider == GraphProvider.NEPTUNE:
514
456
  query: LiteralString = """
515
457
  MATCH (n:Entity {uuid: $uuid})
516
458
  RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
517
459
  """
518
- elif driver.aoss_client:
519
- resp = await driver.aoss_client.search(
520
- body={
521
- 'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
522
- 'size': 1,
523
- },
524
- index=ENTITY_INDEX_NAME,
525
- params={'routing': self.group_id},
526
- )
527
-
528
- if resp['hits']['hits']:
529
- self.name_embedding = resp['hits']['hits'][0]['_source']['name_embedding']
530
- return
531
- else:
532
- raise NodeNotFoundError(self.uuid)
533
460
 
534
461
  else:
535
462
  query: LiteralString = """
@@ -548,6 +475,9 @@ class EntityNode(Node):
548
475
  self.name_embedding = records[0]['name_embedding']
549
476
 
550
477
  async def save(self, driver: GraphDriver):
478
+ if driver.graph_operations_interface:
479
+ return await driver.graph_operations_interface.node_save(self, driver)
480
+
551
481
  entity_data: dict[str, Any] = {
552
482
  'uuid': self.uuid,
553
483
  'name': self.name,
@@ -568,11 +498,8 @@ class EntityNode(Node):
568
498
  entity_data.update(self.attributes or {})
569
499
  labels = ':'.join(self.labels + ['Entity'])
570
500
 
571
- if driver.aoss_client:
572
- await driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue
573
-
574
501
  result = await driver.execute_query(
575
- get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),
502
+ get_entity_node_save_query(driver.provider, labels),
576
503
  entity_data=entity_data,
577
504
  )
578
505
 
graphiti_core/search/search_filters.py CHANGED
@@ -249,41 +249,3 @@ def edge_search_filter_query_constructor(
249
249
  filter_queries.append(expired_at_filter)
250
250
 
251
251
  return filter_queries, filter_params
252
-
253
-
254
- def build_aoss_node_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
255
- filters = [{'terms': {'group_id': group_ids}}]
256
-
257
- if search_filters.node_labels:
258
- filters.append({'terms': {'node_labels': search_filters.node_labels}})
259
-
260
- return filters
261
-
262
-
263
- def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
264
- filters: list[dict] = [{'terms': {'group_id': group_ids}}]
265
-
266
- if search_filters.edge_types:
267
- filters.append({'terms': {'edge_types': search_filters.edge_types}})
268
-
269
- if search_filters.edge_uuids:
270
- filters.append({'terms': {'uuid': search_filters.edge_uuids}})
271
-
272
- for field in ['valid_at', 'invalid_at', 'created_at', 'expired_at']:
273
- ranges = getattr(search_filters, field)
274
- if ranges:
275
- # OR of ANDs
276
- should_clauses = []
277
- for and_group in ranges:
278
- and_filters = []
279
- for df in and_group: # df is a DateFilter
280
- range_query = {
281
- 'range': {
282
- field: {cypher_to_opensearch_operator(df.comparison_operator): df.date}
283
- }
284
- }
285
- and_filters.append(range_query)
286
- should_clauses.append({'bool': {'filter': and_filters}})
287
- filters.append({'bool': {'should': should_clauses, 'minimum_should_match': 1}})
288
-
289
- return filters