graphiti-core 0.20.4 (py3-none-any wheel) → 0.21.0rc2 (py3-none-any wheel)

This diff shows the contents of two publicly available package versions that were released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic; see the package's registry page for more details.

@@ -60,7 +60,7 @@ EPISODIC_EDGE_RETURN = """
60
60
  """
61
61
 
62
62
 
63
- def get_entity_edge_save_query(provider: GraphProvider) -> str:
63
+ def get_entity_edge_save_query(provider: GraphProvider, has_aoss: bool = False) -> str:
64
64
  match provider:
65
65
  case GraphProvider.FALKORDB:
66
66
  return """
@@ -99,17 +99,28 @@ def get_entity_edge_save_query(provider: GraphProvider) -> str:
99
99
  RETURN e.uuid AS uuid
100
100
  """
101
101
  case _: # Neo4j
102
- return """
103
- MATCH (source:Entity {uuid: $edge_data.source_uuid})
104
- MATCH (target:Entity {uuid: $edge_data.target_uuid})
105
- MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
106
- SET e = $edge_data
107
- WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)
102
+ save_embedding_query = (
103
+ """WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)"""
104
+ if not has_aoss
105
+ else ''
106
+ )
107
+ return (
108
+ (
109
+ """
110
+ MATCH (source:Entity {uuid: $edge_data.source_uuid})
111
+ MATCH (target:Entity {uuid: $edge_data.target_uuid})
112
+ MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
113
+ SET e = $edge_data
114
+ """
115
+ + save_embedding_query
116
+ )
117
+ + """
108
118
  RETURN e.uuid AS uuid
109
- """
119
+ """
120
+ )
110
121
 
111
122
 
112
- def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
123
+ def get_entity_edge_save_bulk_query(provider: GraphProvider, has_aoss: bool = False) -> str:
113
124
  match provider:
114
125
  case GraphProvider.FALKORDB:
115
126
  return """
@@ -152,15 +163,24 @@ def get_entity_edge_save_bulk_query(provider: GraphProvider) -> str:
152
163
  RETURN e.uuid AS uuid
153
164
  """
154
165
  case _:
155
- return """
156
- UNWIND $entity_edges AS edge
157
- MATCH (source:Entity {uuid: edge.source_node_uuid})
158
- MATCH (target:Entity {uuid: edge.target_node_uuid})
159
- MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
160
- SET e = edge
161
- WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)
166
+ save_embedding_query = (
167
+ 'WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)'
168
+ if not has_aoss
169
+ else ''
170
+ )
171
+ return (
172
+ """
173
+ UNWIND $entity_edges AS edge
174
+ MATCH (source:Entity {uuid: edge.source_node_uuid})
175
+ MATCH (target:Entity {uuid: edge.target_node_uuid})
176
+ MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
177
+ SET e = edge
178
+ """
179
+ + save_embedding_query
180
+ + """
162
181
  RETURN edge.uuid AS uuid
163
182
  """
183
+ )
164
184
 
165
185
 
166
186
  def get_entity_edge_return_query(provider: GraphProvider) -> str:
@@ -126,7 +126,7 @@ EPISODIC_NODE_RETURN_NEPTUNE = """
126
126
  """
127
127
 
128
128
 
129
- def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
129
+ def get_entity_node_save_query(provider: GraphProvider, labels: str, has_aoss: bool = False) -> str:
130
130
  match provider:
131
131
  case GraphProvider.FALKORDB:
132
132
  return f"""
@@ -161,16 +161,27 @@ def get_entity_node_save_query(provider: GraphProvider, labels: str) -> str:
161
161
  RETURN n.uuid AS uuid
162
162
  """
163
163
  case _:
164
- return f"""
164
+ save_embedding_query = (
165
+ 'WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)'
166
+ if not has_aoss
167
+ else ''
168
+ )
169
+ return (
170
+ f"""
165
171
  MERGE (n:Entity {{uuid: $entity_data.uuid}})
166
172
  SET n:{labels}
167
173
  SET n = $entity_data
168
- WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)
174
+ """
175
+ + save_embedding_query
176
+ + """
169
177
  RETURN n.uuid AS uuid
170
178
  """
179
+ )
171
180
 
172
181
 
173
- def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict]) -> str | Any:
182
+ def get_entity_node_save_bulk_query(
183
+ provider: GraphProvider, nodes: list[dict], has_aoss: bool = False
184
+ ) -> str | Any:
174
185
  match provider:
175
186
  case GraphProvider.FALKORDB:
176
187
  queries = []
@@ -222,14 +233,23 @@ def get_entity_node_save_bulk_query(provider: GraphProvider, nodes: list[dict])
222
233
  RETURN n.uuid AS uuid
223
234
  """
224
235
  case _: # Neo4j
225
- return """
226
- UNWIND $nodes AS node
227
- MERGE (n:Entity {uuid: node.uuid})
228
- SET n:$(node.labels)
229
- SET n = node
230
- WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)
236
+ save_embedding_query = (
237
+ 'WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)'
238
+ if not has_aoss
239
+ else ''
240
+ )
241
+ return (
242
+ """
243
+ UNWIND $nodes AS node
244
+ MERGE (n:Entity {uuid: node.uuid})
245
+ SET n:$(node.labels)
246
+ SET n = node
247
+ """
248
+ + save_embedding_query
249
+ + """
231
250
  RETURN n.uuid AS uuid
232
251
  """
252
+ )
233
253
 
234
254
 
235
255
  def get_entity_node_return_query(provider: GraphProvider) -> str:
graphiti_core/nodes.py CHANGED
@@ -26,7 +26,14 @@ from uuid import uuid4
26
26
  from pydantic import BaseModel, Field
27
27
  from typing_extensions import LiteralString
28
28
 
29
- from graphiti_core.driver.driver import GraphDriver, GraphProvider
29
+ from graphiti_core.driver.driver import (
30
+ COMMUNITY_INDEX_NAME,
31
+ ENTITY_EDGE_INDEX_NAME,
32
+ ENTITY_INDEX_NAME,
33
+ EPISODE_INDEX_NAME,
34
+ GraphDriver,
35
+ GraphProvider,
36
+ )
30
37
  from graphiti_core.embedder import EmbedderClient
31
38
  from graphiti_core.errors import NodeNotFoundError
32
39
  from graphiti_core.helpers import parse_db_date
@@ -94,13 +101,39 @@ class Node(BaseModel, ABC):
94
101
  async def delete(self, driver: GraphDriver):
95
102
  match driver.provider:
96
103
  case GraphProvider.NEO4J:
97
- await driver.execute_query(
104
+ records, _, _ = await driver.execute_query(
98
105
  """
99
- MATCH (n:Entity|Episodic|Community {uuid: $uuid})
106
+ MATCH (n {uuid: $uuid})
107
+ WHERE n:Entity OR n:Episodic OR n:Community
108
+ OPTIONAL MATCH (n)-[r]-()
109
+ WITH collect(r.uuid) AS edge_uuids, n
100
110
  DETACH DELETE n
111
+ RETURN edge_uuids
101
112
  """,
102
113
  uuid=self.uuid,
103
114
  )
115
+
116
+ edge_uuids: list[str] = records[0].get('edge_uuids', []) if records else []
117
+
118
+ if driver.aoss_client:
119
+ # Delete the node from OpenSearch indices
120
+ for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
121
+ await driver.aoss_client.delete(
122
+ index=index,
123
+ id=self.uuid,
124
+ params={'routing': self.group_id},
125
+ )
126
+
127
+ # Bulk delete the detached edges
128
+ if edge_uuids:
129
+ actions = []
130
+ for eid in edge_uuids:
131
+ actions.append(
132
+ {'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}}
133
+ )
134
+
135
+ await driver.aoss_client.bulk(body=actions)
136
+
104
137
  case GraphProvider.KUZU:
105
138
  for label in ['Episodic', 'Community']:
106
139
  await driver.execute_query(
@@ -162,6 +195,32 @@ class Node(BaseModel, ABC):
162
195
  group_id=group_id,
163
196
  batch_size=batch_size,
164
197
  )
198
+
199
+ if driver.aoss_client:
200
+ await driver.aoss_client.delete_by_query(
201
+ index=EPISODE_INDEX_NAME,
202
+ body={'query': {'term': {'group_id': group_id}}},
203
+ params={'routing': group_id},
204
+ )
205
+
206
+ await driver.aoss_client.delete_by_query(
207
+ index=ENTITY_INDEX_NAME,
208
+ body={'query': {'term': {'group_id': group_id}}},
209
+ params={'routing': group_id},
210
+ )
211
+
212
+ await driver.aoss_client.delete_by_query(
213
+ index=COMMUNITY_INDEX_NAME,
214
+ body={'query': {'term': {'group_id': group_id}}},
215
+ params={'routing': group_id},
216
+ )
217
+
218
+ await driver.aoss_client.delete_by_query(
219
+ index=ENTITY_EDGE_INDEX_NAME,
220
+ body={'query': {'term': {'group_id': group_id}}},
221
+ params={'routing': group_id},
222
+ )
223
+
165
224
  case GraphProvider.KUZU:
166
225
  for label in ['Episodic', 'Community']:
167
226
  await driver.execute_query(
@@ -240,6 +299,23 @@ class Node(BaseModel, ABC):
240
299
  )
241
300
  case _: # Neo4J, Neptune
242
301
  async with driver.session() as session:
302
+ # Collect all edge UUIDs before deleting nodes
303
+ result = await session.run(
304
+ """
305
+ MATCH (n:Entity|Episodic|Community)
306
+ WHERE n.uuid IN $uuids
307
+ MATCH (n)-[r]-()
308
+ RETURN collect(r.uuid) AS edge_uuids
309
+ """,
310
+ uuids=uuids,
311
+ )
312
+
313
+ record = await result.single()
314
+ edge_uuids: list[str] = (
315
+ record['edge_uuids'] if record and record['edge_uuids'] else []
316
+ )
317
+
318
+ # Now delete the nodes in batches
243
319
  await session.run(
244
320
  """
245
321
  MATCH (n:Entity|Episodic|Community)
@@ -253,6 +329,20 @@ class Node(BaseModel, ABC):
253
329
  batch_size=batch_size,
254
330
  )
255
331
 
332
+ if driver.aoss_client:
333
+ for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
334
+ await driver.aoss_client.delete_by_query(
335
+ index=index,
336
+ body={'query': {'terms': {'uuid': uuids}}},
337
+ )
338
+
339
+ if edge_uuids:
340
+ actions = [
341
+ {'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}}
342
+ for eid in edge_uuids
343
+ ]
344
+ await driver.aoss_client.bulk(body=actions)
345
+
256
346
  @classmethod
257
347
  async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
258
348
 
@@ -273,20 +363,6 @@ class EpisodicNode(Node):
273
363
  )
274
364
 
275
365
  async def save(self, driver: GraphDriver):
276
- if driver.provider == GraphProvider.NEPTUNE:
277
- driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
278
- 'episode_content',
279
- [
280
- {
281
- 'uuid': self.uuid,
282
- 'group_id': self.group_id,
283
- 'source': self.source.value,
284
- 'content': self.content,
285
- 'source_description': self.source_description,
286
- }
287
- ],
288
- )
289
-
290
366
  episode_args = {
291
367
  'uuid': self.uuid,
292
368
  'name': self.name,
@@ -299,6 +375,12 @@ class EpisodicNode(Node):
299
375
  'source': self.source.value,
300
376
  }
301
377
 
378
+ if driver.aoss_client:
379
+ await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
380
+ 'episodes',
381
+ [episode_args],
382
+ )
383
+
302
384
  result = await driver.execute_query(
303
385
  get_episode_node_save_query(driver.provider), **episode_args
304
386
  )
@@ -433,6 +515,22 @@ class EntityNode(Node):
433
515
  MATCH (n:Entity {uuid: $uuid})
434
516
  RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
435
517
  """
518
+ elif driver.aoss_client:
519
+ resp = await driver.aoss_client.search(
520
+ body={
521
+ 'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
522
+ 'size': 1,
523
+ },
524
+ index=ENTITY_INDEX_NAME,
525
+ params={'routing': self.group_id},
526
+ )
527
+
528
+ if resp['hits']['hits']:
529
+ self.name_embedding = resp['hits']['hits'][0]['_source']['name_embedding']
530
+ return
531
+ else:
532
+ raise NodeNotFoundError(self.uuid)
533
+
436
534
  else:
437
535
  query: LiteralString = """
438
536
  MATCH (n:Entity {uuid: $uuid})
@@ -470,11 +568,11 @@ class EntityNode(Node):
470
568
  entity_data.update(self.attributes or {})
471
569
  labels = ':'.join(self.labels + ['Entity'])
472
570
 
473
- if driver.provider == GraphProvider.NEPTUNE:
474
- driver.save_to_aoss('node_name_and_summary', [entity_data]) # pyright: ignore reportAttributeAccessIssue
571
+ if driver.aoss_client:
572
+ await driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue
475
573
 
476
574
  result = await driver.execute_query(
477
- get_entity_node_save_query(driver.provider, labels),
575
+ get_entity_node_save_query(driver.provider, labels, bool(driver.aoss_client)),
478
576
  entity_data=entity_data,
479
577
  )
480
578
 
@@ -569,8 +667,8 @@ class CommunityNode(Node):
569
667
 
570
668
  async def save(self, driver: GraphDriver):
571
669
  if driver.provider == GraphProvider.NEPTUNE:
572
- driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
573
- 'community_name',
670
+ await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
671
+ 'communities',
574
672
  [{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
575
673
  )
576
674
  result = await driver.execute_query(
@@ -52,6 +52,17 @@ class SearchFilters(BaseModel):
52
52
  invalid_at: list[list[DateFilter]] | None = Field(default=None)
53
53
  created_at: list[list[DateFilter]] | None = Field(default=None)
54
54
  expired_at: list[list[DateFilter]] | None = Field(default=None)
55
+ edge_uuids: list[str] | None = Field(default=None)
56
+
57
+
58
+ def cypher_to_opensearch_operator(op: ComparisonOperator) -> str:
59
+ mapping = {
60
+ ComparisonOperator.greater_than: 'gt',
61
+ ComparisonOperator.less_than: 'lt',
62
+ ComparisonOperator.greater_than_equal: 'gte',
63
+ ComparisonOperator.less_than_equal: 'lte',
64
+ }
65
+ return mapping.get(op, op.value)
55
66
 
56
67
 
57
68
  def node_search_filter_query_constructor(
@@ -98,6 +109,10 @@ def edge_search_filter_query_constructor(
98
109
  filter_queries.append('e.name in $edge_types')
99
110
  filter_params['edge_types'] = edge_types
100
111
 
112
+ if filters.edge_uuids is not None:
113
+ filter_queries.append('e.uuid in $edge_uuids')
114
+ filter_params['edge_uuids'] = filters.edge_uuids
115
+
101
116
  if filters.node_labels is not None:
102
117
  if provider == GraphProvider.KUZU:
103
118
  node_label_filter = (
@@ -234,3 +249,41 @@ def edge_search_filter_query_constructor(
234
249
  filter_queries.append(expired_at_filter)
235
250
 
236
251
  return filter_queries, filter_params
252
+
253
+
254
+ def build_aoss_node_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
255
+ filters = [{'terms': {'group_id': group_ids}}]
256
+
257
+ if search_filters.node_labels:
258
+ filters.append({'terms': {'node_labels': search_filters.node_labels}})
259
+
260
+ return filters
261
+
262
+
263
+ def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
264
+ filters: list[dict] = [{'terms': {'group_id': group_ids}}]
265
+
266
+ if search_filters.edge_types:
267
+ filters.append({'terms': {'edge_types': search_filters.edge_types}})
268
+
269
+ if search_filters.edge_uuids:
270
+ filters.append({'terms': {'uuid': search_filters.edge_uuids}})
271
+
272
+ for field in ['valid_at', 'invalid_at', 'created_at', 'expired_at']:
273
+ ranges = getattr(search_filters, field)
274
+ if ranges:
275
+ # OR of ANDs
276
+ should_clauses = []
277
+ for and_group in ranges:
278
+ and_filters = []
279
+ for df in and_group: # df is a DateFilter
280
+ range_query = {
281
+ 'range': {
282
+ field: {cypher_to_opensearch_operator(df.comparison_operator): df.date}
283
+ }
284
+ }
285
+ and_filters.append(range_query)
286
+ should_clauses.append({'bool': {'filter': and_filters}})
287
+ filters.append({'bool': {'should': should_clauses, 'minimum_should_match': 1}})
288
+
289
+ return filters