graphiti-core 0.22.0rc5__py3-none-any.whl → 0.22.1rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/driver/driver.py +5 -7
- graphiti_core/driver/falkordb_driver.py +54 -3
- graphiti_core/driver/graph_operations/__init__.py +0 -0
- graphiti_core/driver/graph_operations/graph_operations.py +195 -0
- graphiti_core/driver/neo4j_driver.py +9 -0
- graphiti_core/driver/search_interface/__init__.py +0 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +11 -34
- graphiti_core/models/edges/edge_db_queries.py +1 -0
- graphiti_core/models/nodes/node_db_queries.py +1 -0
- graphiti_core/nodes.py +26 -99
- graphiti_core/search/search_filters.py +0 -38
- graphiti_core/search/search_utils.py +84 -220
- graphiti_core/utils/bulk_utils.py +14 -28
- graphiti_core/utils/maintenance/edge_operations.py +20 -15
- graphiti_core/utils/maintenance/graph_data_operations.py +6 -25
- {graphiti_core-0.22.0rc5.dist-info → graphiti_core-0.22.1rc1.dist-info}/METADATA +36 -3
- {graphiti_core-0.22.0rc5.dist-info → graphiti_core-0.22.1rc1.dist-info}/RECORD +20 -16
- {graphiti_core-0.22.0rc5.dist-info → graphiti_core-0.22.1rc1.dist-info}/WHEEL +0 -0
- {graphiti_core-0.22.0rc5.dist-info → graphiti_core-0.22.1rc1.dist-info}/licenses/LICENSE +0 -0
graphiti_core/nodes.py
CHANGED
|
@@ -27,10 +27,6 @@ from pydantic import BaseModel, Field
|
|
|
27
27
|
from typing_extensions import LiteralString
|
|
28
28
|
|
|
29
29
|
from graphiti_core.driver.driver import (
|
|
30
|
-
COMMUNITY_INDEX_NAME,
|
|
31
|
-
ENTITY_EDGE_INDEX_NAME,
|
|
32
|
-
ENTITY_INDEX_NAME,
|
|
33
|
-
EPISODE_INDEX_NAME,
|
|
34
30
|
GraphDriver,
|
|
35
31
|
GraphProvider,
|
|
36
32
|
)
|
|
@@ -99,6 +95,9 @@ class Node(BaseModel, ABC):
|
|
|
99
95
|
async def save(self, driver: GraphDriver): ...
|
|
100
96
|
|
|
101
97
|
async def delete(self, driver: GraphDriver):
|
|
98
|
+
if driver.graph_operations_interface:
|
|
99
|
+
return await driver.graph_operations_interface.node_delete(self, driver)
|
|
100
|
+
|
|
102
101
|
match driver.provider:
|
|
103
102
|
case GraphProvider.NEO4J:
|
|
104
103
|
records, _, _ = await driver.execute_query(
|
|
@@ -113,27 +112,6 @@ class Node(BaseModel, ABC):
|
|
|
113
112
|
uuid=self.uuid,
|
|
114
113
|
)
|
|
115
114
|
|
|
116
|
-
edge_uuids: list[str] = records[0].get('edge_uuids', []) if records else []
|
|
117
|
-
|
|
118
|
-
if driver.aoss_client:
|
|
119
|
-
# Delete the node from OpenSearch indices
|
|
120
|
-
for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
|
|
121
|
-
await driver.aoss_client.delete(
|
|
122
|
-
index=index,
|
|
123
|
-
id=self.uuid,
|
|
124
|
-
params={'routing': self.group_id},
|
|
125
|
-
)
|
|
126
|
-
|
|
127
|
-
# Bulk delete the detached edges
|
|
128
|
-
if edge_uuids:
|
|
129
|
-
actions = []
|
|
130
|
-
for eid in edge_uuids:
|
|
131
|
-
actions.append(
|
|
132
|
-
{'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}}
|
|
133
|
-
)
|
|
134
|
-
|
|
135
|
-
await driver.aoss_client.bulk(body=actions)
|
|
136
|
-
|
|
137
115
|
case GraphProvider.KUZU:
|
|
138
116
|
for label in ['Episodic', 'Community']:
|
|
139
117
|
await driver.execute_query(
|
|
@@ -181,14 +159,18 @@ class Node(BaseModel, ABC):
|
|
|
181
159
|
|
|
182
160
|
@classmethod
|
|
183
161
|
async def delete_by_group_id(cls, driver: GraphDriver, group_id: str, batch_size: int = 100):
|
|
162
|
+
if driver.graph_operations_interface:
|
|
163
|
+
return await driver.graph_operations_interface.node_delete_by_group_id(
|
|
164
|
+
cls, driver, group_id, batch_size
|
|
165
|
+
)
|
|
166
|
+
|
|
184
167
|
match driver.provider:
|
|
185
168
|
case GraphProvider.NEO4J:
|
|
186
169
|
async with driver.session() as session:
|
|
187
170
|
await session.run(
|
|
188
171
|
"""
|
|
189
172
|
MATCH (n:Entity|Episodic|Community {group_id: $group_id})
|
|
190
|
-
CALL {
|
|
191
|
-
WITH n
|
|
173
|
+
CALL (n) {
|
|
192
174
|
DETACH DELETE n
|
|
193
175
|
} IN TRANSACTIONS OF $batch_size ROWS
|
|
194
176
|
""",
|
|
@@ -196,31 +178,6 @@ class Node(BaseModel, ABC):
|
|
|
196
178
|
batch_size=batch_size,
|
|
197
179
|
)
|
|
198
180
|
|
|
199
|
-
if driver.aoss_client:
|
|
200
|
-
await driver.aoss_client.delete_by_query(
|
|
201
|
-
index=EPISODE_INDEX_NAME,
|
|
202
|
-
body={'query': {'term': {'group_id': group_id}}},
|
|
203
|
-
params={'routing': group_id},
|
|
204
|
-
)
|
|
205
|
-
|
|
206
|
-
await driver.aoss_client.delete_by_query(
|
|
207
|
-
index=ENTITY_INDEX_NAME,
|
|
208
|
-
body={'query': {'term': {'group_id': group_id}}},
|
|
209
|
-
params={'routing': group_id},
|
|
210
|
-
)
|
|
211
|
-
|
|
212
|
-
await driver.aoss_client.delete_by_query(
|
|
213
|
-
index=COMMUNITY_INDEX_NAME,
|
|
214
|
-
body={'query': {'term': {'group_id': group_id}}},
|
|
215
|
-
params={'routing': group_id},
|
|
216
|
-
)
|
|
217
|
-
|
|
218
|
-
await driver.aoss_client.delete_by_query(
|
|
219
|
-
index=ENTITY_EDGE_INDEX_NAME,
|
|
220
|
-
body={'query': {'term': {'group_id': group_id}}},
|
|
221
|
-
params={'routing': group_id},
|
|
222
|
-
)
|
|
223
|
-
|
|
224
181
|
case GraphProvider.KUZU:
|
|
225
182
|
for label in ['Episodic', 'Community']:
|
|
226
183
|
await driver.execute_query(
|
|
@@ -258,6 +215,11 @@ class Node(BaseModel, ABC):
|
|
|
258
215
|
|
|
259
216
|
@classmethod
|
|
260
217
|
async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
|
|
218
|
+
if driver.graph_operations_interface:
|
|
219
|
+
return await driver.graph_operations_interface.node_delete_by_uuids(
|
|
220
|
+
cls, driver, uuids, group_id=None, batch_size=batch_size
|
|
221
|
+
)
|
|
222
|
+
|
|
261
223
|
match driver.provider:
|
|
262
224
|
case GraphProvider.FALKORDB:
|
|
263
225
|
for label in ['Entity', 'Episodic', 'Community']:
|
|
@@ -300,7 +262,7 @@ class Node(BaseModel, ABC):
|
|
|
300
262
|
case _: # Neo4J, Neptune
|
|
301
263
|
async with driver.session() as session:
|
|
302
264
|
# Collect all edge UUIDs before deleting nodes
|
|
303
|
-
|
|
265
|
+
await session.run(
|
|
304
266
|
"""
|
|
305
267
|
MATCH (n:Entity|Episodic|Community)
|
|
306
268
|
WHERE n.uuid IN $uuids
|
|
@@ -310,18 +272,12 @@ class Node(BaseModel, ABC):
|
|
|
310
272
|
uuids=uuids,
|
|
311
273
|
)
|
|
312
274
|
|
|
313
|
-
record = await result.single()
|
|
314
|
-
edge_uuids: list[str] = (
|
|
315
|
-
record['edge_uuids'] if record and record['edge_uuids'] else []
|
|
316
|
-
)
|
|
317
|
-
|
|
318
275
|
# Now delete the nodes in batches
|
|
319
276
|
await session.run(
|
|
320
277
|
"""
|
|
321
278
|
MATCH (n:Entity|Episodic|Community)
|
|
322
279
|
WHERE n.uuid IN $uuids
|
|
323
|
-
CALL {
|
|
324
|
-
WITH n
|
|
280
|
+
CALL (n) {
|
|
325
281
|
DETACH DELETE n
|
|
326
282
|
} IN TRANSACTIONS OF $batch_size ROWS
|
|
327
283
|
""",
|
|
@@ -329,20 +285,6 @@ class Node(BaseModel, ABC):
|
|
|
329
285
|
batch_size=batch_size,
|
|
330
286
|
)
|
|
331
287
|
|
|
332
|
-
if driver.aoss_client:
|
|
333
|
-
for index in (EPISODE_INDEX_NAME, ENTITY_INDEX_NAME, COMMUNITY_INDEX_NAME):
|
|
334
|
-
await driver.aoss_client.delete_by_query(
|
|
335
|
-
index=index,
|
|
336
|
-
body={'query': {'terms': {'uuid': uuids}}},
|
|
337
|
-
)
|
|
338
|
-
|
|
339
|
-
if edge_uuids:
|
|
340
|
-
actions = [
|
|
341
|
-
{'delete': {'_index': ENTITY_EDGE_INDEX_NAME, '_id': eid}}
|
|
342
|
-
for eid in edge_uuids
|
|
343
|
-
]
|
|
344
|
-
await driver.aoss_client.bulk(body=actions)
|
|
345
|
-
|
|
346
288
|
@classmethod
|
|
347
289
|
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
|
|
348
290
|
|
|
@@ -363,6 +305,9 @@ class EpisodicNode(Node):
|
|
|
363
305
|
)
|
|
364
306
|
|
|
365
307
|
async def save(self, driver: GraphDriver):
|
|
308
|
+
if driver.graph_operations_interface:
|
|
309
|
+
return await driver.graph_operations_interface.episodic_node_save(self, driver)
|
|
310
|
+
|
|
366
311
|
episode_args = {
|
|
367
312
|
'uuid': self.uuid,
|
|
368
313
|
'name': self.name,
|
|
@@ -375,12 +320,6 @@ class EpisodicNode(Node):
|
|
|
375
320
|
'source': self.source.value,
|
|
376
321
|
}
|
|
377
322
|
|
|
378
|
-
if driver.aoss_client:
|
|
379
|
-
await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
|
|
380
|
-
'episodes',
|
|
381
|
-
[episode_args],
|
|
382
|
-
)
|
|
383
|
-
|
|
384
323
|
result = await driver.execute_query(
|
|
385
324
|
get_episode_node_save_query(driver.provider), **episode_args
|
|
386
325
|
)
|
|
@@ -510,26 +449,14 @@ class EntityNode(Node):
|
|
|
510
449
|
return self.name_embedding
|
|
511
450
|
|
|
512
451
|
async def load_name_embedding(self, driver: GraphDriver):
|
|
452
|
+
if driver.graph_operations_interface:
|
|
453
|
+
return await driver.graph_operations_interface.node_load_embeddings(self, driver)
|
|
454
|
+
|
|
513
455
|
if driver.provider == GraphProvider.NEPTUNE:
|
|
514
456
|
query: LiteralString = """
|
|
515
457
|
MATCH (n:Entity {uuid: $uuid})
|
|
516
458
|
RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
|
|
517
459
|
"""
|
|
518
|
-
elif driver.aoss_client:
|
|
519
|
-
resp = await driver.aoss_client.search(
|
|
520
|
-
body={
|
|
521
|
-
'query': {'multi_match': {'query': self.uuid, 'fields': ['uuid']}},
|
|
522
|
-
'size': 1,
|
|
523
|
-
},
|
|
524
|
-
index=ENTITY_INDEX_NAME,
|
|
525
|
-
params={'routing': self.group_id},
|
|
526
|
-
)
|
|
527
|
-
|
|
528
|
-
if resp['hits']['hits']:
|
|
529
|
-
self.name_embedding = resp['hits']['hits'][0]['_source']['name_embedding']
|
|
530
|
-
return
|
|
531
|
-
else:
|
|
532
|
-
raise NodeNotFoundError(self.uuid)
|
|
533
460
|
|
|
534
461
|
else:
|
|
535
462
|
query: LiteralString = """
|
|
@@ -548,6 +475,9 @@ class EntityNode(Node):
|
|
|
548
475
|
self.name_embedding = records[0]['name_embedding']
|
|
549
476
|
|
|
550
477
|
async def save(self, driver: GraphDriver):
|
|
478
|
+
if driver.graph_operations_interface:
|
|
479
|
+
return await driver.graph_operations_interface.node_save(self, driver)
|
|
480
|
+
|
|
551
481
|
entity_data: dict[str, Any] = {
|
|
552
482
|
'uuid': self.uuid,
|
|
553
483
|
'name': self.name,
|
|
@@ -568,11 +498,8 @@ class EntityNode(Node):
|
|
|
568
498
|
entity_data.update(self.attributes or {})
|
|
569
499
|
labels = ':'.join(self.labels + ['Entity'])
|
|
570
500
|
|
|
571
|
-
if driver.aoss_client:
|
|
572
|
-
await driver.save_to_aoss(ENTITY_INDEX_NAME, [entity_data]) # pyright: ignore reportAttributeAccessIssue
|
|
573
|
-
|
|
574
501
|
result = await driver.execute_query(
|
|
575
|
-
get_entity_node_save_query(driver.provider, labels
|
|
502
|
+
get_entity_node_save_query(driver.provider, labels),
|
|
576
503
|
entity_data=entity_data,
|
|
577
504
|
)
|
|
578
505
|
|
|
@@ -249,41 +249,3 @@ def edge_search_filter_query_constructor(
|
|
|
249
249
|
filter_queries.append(expired_at_filter)
|
|
250
250
|
|
|
251
251
|
return filter_queries, filter_params
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
def build_aoss_node_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
|
|
255
|
-
filters = [{'terms': {'group_id': group_ids}}]
|
|
256
|
-
|
|
257
|
-
if search_filters.node_labels:
|
|
258
|
-
filters.append({'terms': {'node_labels': search_filters.node_labels}})
|
|
259
|
-
|
|
260
|
-
return filters
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
def build_aoss_edge_filters(group_ids: list[str], search_filters: SearchFilters) -> list[dict]:
|
|
264
|
-
filters: list[dict] = [{'terms': {'group_id': group_ids}}]
|
|
265
|
-
|
|
266
|
-
if search_filters.edge_types:
|
|
267
|
-
filters.append({'terms': {'edge_types': search_filters.edge_types}})
|
|
268
|
-
|
|
269
|
-
if search_filters.edge_uuids:
|
|
270
|
-
filters.append({'terms': {'uuid': search_filters.edge_uuids}})
|
|
271
|
-
|
|
272
|
-
for field in ['valid_at', 'invalid_at', 'created_at', 'expired_at']:
|
|
273
|
-
ranges = getattr(search_filters, field)
|
|
274
|
-
if ranges:
|
|
275
|
-
# OR of ANDs
|
|
276
|
-
should_clauses = []
|
|
277
|
-
for and_group in ranges:
|
|
278
|
-
and_filters = []
|
|
279
|
-
for df in and_group: # df is a DateFilter
|
|
280
|
-
range_query = {
|
|
281
|
-
'range': {
|
|
282
|
-
field: {cypher_to_opensearch_operator(df.comparison_operator): df.date}
|
|
283
|
-
}
|
|
284
|
-
}
|
|
285
|
-
and_filters.append(range_query)
|
|
286
|
-
should_clauses.append({'bool': {'filter': and_filters}})
|
|
287
|
-
filters.append({'bool': {'should': should_clauses, 'minimum_should_match': 1}})
|
|
288
|
-
|
|
289
|
-
return filters
|