graphiti-core 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff shows the changes between publicly available package versions as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release: this version of graphiti-core might be problematic.
- graphiti_core/graphiti.py +87 -67
- graphiti_core/llm_client/openai_client.py +0 -1
- graphiti_core/prompts/dedupe_edges.py +46 -8
- graphiti_core/prompts/dedupe_nodes.py +61 -13
- graphiti_core/prompts/extract_edges.py +2 -1
- graphiti_core/prompts/extract_nodes.py +2 -0
- graphiti_core/search/search.py +8 -8
- graphiti_core/search/search_utils.py +196 -54
- graphiti_core/utils/bulk_utils.py +138 -20
- graphiti_core/utils/maintenance/edge_operations.py +76 -9
- graphiti_core/utils/maintenance/node_operations.py +87 -29
- graphiti_core/utils/maintenance/temporal_operations.py +3 -4
- graphiti_core/utils/utils.py +22 -1
- {graphiti_core-0.1.0.dist-info → graphiti_core-0.2.1.dist-info}/METADATA +40 -38
- {graphiti_core-0.1.0.dist-info → graphiti_core-0.2.1.dist-info}/RECORD +17 -17
- {graphiti_core-0.1.0.dist-info → graphiti_core-0.2.1.dist-info}/LICENSE +0 -0
- {graphiti_core-0.1.0.dist-info → graphiti_core-0.2.1.dist-info}/WHEEL +0 -0
graphiti_core/search/search_utils.py

@@ -1,11 +1,11 @@
 import asyncio
 import logging
 import re
-import typing
 from collections import defaultdict
 from time import time
+from typing import Any
 
-from neo4j import AsyncDriver
+from neo4j import AsyncDriver, Query
 
 from graphiti_core.edges import EntityEdge
 from graphiti_core.helpers import parse_db_date
@@ -66,12 +66,12 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):
         r.expired_at AS expired_at,
         r.valid_at AS valid_at,
         r.invalid_at AS invalid_at
-
+
     """,
         node_ids=node_ids,
     )
 
-    context: dict[str,
+    context: dict[str, Any] = {}
 
     for record in records:
         n_uuid = record['source_node_uuid']
@@ -96,14 +96,17 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):
 
 
 async def edge_similarity_search(
-
+    driver: AsyncDriver,
+    search_vector: list[float],
+    source_node_uuid: str | None,
+    target_node_uuid: str | None,
+    limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
     # vector similarity search over embedded facts
-
-    """
+    query = Query("""
     CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
-    YIELD relationship AS
-    MATCH (n)-[r:
+    YIELD relationship AS rel, score
+    MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
     RETURN
         r.uuid AS uuid,
         n.uuid AS source_node_uuid,
@@ -117,8 +120,71 @@ async def edge_similarity_search(
         r.valid_at AS valid_at,
         r.invalid_at AS invalid_at
     ORDER BY score DESC
-
+    """)
+
+    if source_node_uuid is None and target_node_uuid is None:
+        query = Query("""
+        CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC
+        """)
+    elif source_node_uuid is None:
+        query = Query("""
+        CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC
+        """)
+    elif target_node_uuid is None:
+        query = Query("""
+        CALL db.index.vector.queryRelationships("fact_embedding", $limit, $search_vector)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC
+        """)
+
+    records, _, _ = await driver.execute_query(
+        query,
         search_vector=search_vector,
+        source_uuid=source_node_uuid,
+        target_uuid=target_node_uuid,
         limit=limit,
     )
 
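As context for the reworked edge_similarity_search above: the optional source/target parameters choose between the four Query branches. A minimal usage sketch, not taken from this diff; the driver, embedding, and UUID variables are assumed placeholders:

    # pin the search to a known pair of entities
    edges = await edge_similarity_search(driver, fact_embedding, source_uuid, target_uuid)
    # no endpoint filters: the None/None branch builds the unconstrained query
    all_edges = await edge_similarity_search(driver, fact_embedding, None, None, limit=20)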
@@ -145,7 +211,7 @@ async def edge_similarity_search(
 
 
 async def entity_similarity_search(
-
+    search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
 ) -> list[EntityNode]:
     # vector similarity search over entity names
     records, _, _ = await driver.execute_query(
@@ -155,6 +221,7 @@ async def entity_similarity_search(
     RETURN
         n.uuid As uuid,
         n.name AS name,
+        n.name_embeddings AS name_embedding,
         n.created_at AS created_at,
         n.summary AS summary
     ORDER BY score DESC
@@ -169,6 +236,7 @@ async def entity_similarity_search(
         EntityNode(
             uuid=record['uuid'],
             name=record['name'],
+            name_embedding=record['name_embedding'],
             labels=['Entity'],
             created_at=record['created_at'].to_native(),
             summary=record['summary'],
@@ -179,7 +247,7 @@ async def entity_similarity_search(
 
 
 async def entity_fulltext_search(
-
+    query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
 ) -> list[EntityNode]:
     # BM25 search to get top nodes
     fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
@@ -187,8 +255,9 @@ async def entity_fulltext_search(
         """
     CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
     RETURN
-        node.uuid
+        node.uuid AS uuid,
         node.name AS name,
+        node.name_embeddings AS name_embedding,
         node.created_at AS created_at,
         node.summary AS summary
     ORDER BY score DESC
@@ -204,6 +273,7 @@ async def entity_fulltext_search(
         EntityNode(
             uuid=record['uuid'],
             name=record['name'],
+            name_embedding=record['name_embedding'],
             labels=['Entity'],
             created_at=record['created_at'].to_native(),
             summary=record['summary'],
@@ -214,17 +284,18 @@ async def entity_fulltext_search(
 
 
 async def edge_fulltext_search(
-
+    driver: AsyncDriver,
+    query: str,
+    source_node_uuid: str | None,
+    target_node_uuid: str | None,
+    limit=RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
     # fulltext search over facts
-
-
-
-
-
-    YIELD relationship AS r, score
-    MATCH (n:Entity)-[r]->(m:Entity)
-    RETURN
+    cypher_query = Query("""
+    CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
+    YIELD relationship AS rel, score
+    MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
+    RETURN
         r.uuid AS uuid,
         n.uuid AS source_node_uuid,
         m.uuid AS target_node_uuid,
@@ -237,8 +308,73 @@ async def edge_fulltext_search(
         r.valid_at AS valid_at,
         r.invalid_at AS invalid_at
     ORDER BY score DESC LIMIT $limit
-    """
+    """)
+
+    if source_node_uuid is None and target_node_uuid is None:
+        cypher_query = Query("""
+        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity)
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC LIMIT $limit
+        """)
+    elif source_node_uuid is None:
+        cypher_query = Query("""
+        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity)-[r {uuid: rel.uuid}]-(m:Entity {uuid: $target_uuid})
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC LIMIT $limit
+        """)
+    elif target_node_uuid is None:
+        cypher_query = Query("""
+        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
+        YIELD relationship AS rel, score
+        MATCH (n:Entity {uuid: $source_uuid})-[r {uuid: rel.uuid}]-(m:Entity)
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC LIMIT $limit
+        """)
+
+    fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
+
+    records, _, _ = await driver.execute_query(
+        cypher_query,
         query=fuzzy_query,
+        source_uuid=source_node_uuid,
+        target_uuid=target_node_uuid,
         limit=limit,
     )
 
@@ -265,16 +401,16 @@ async def edge_fulltext_search(
 
 
 async def hybrid_node_search(
-
-
-
-
+    queries: list[str],
+    embeddings: list[list[float]],
+    driver: AsyncDriver,
+    limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityNode]:
     """
     Perform a hybrid search for nodes using both text queries and embeddings.
 
     This method combines fulltext search and vector similarity search to find
-    relevant nodes in the graph database.
+    relevant nodes in the graph database. It uses a rrf reranker.
 
     Parameters
     ----------
@@ -307,33 +443,31 @@ async def hybrid_node_search(
     """
 
     start = time()
-    relevant_nodes: list[EntityNode] = []
-    relevant_node_uuids = set()
 
-    results =
-
-
-    entity_similarity_search(e, driver, 2 *
-
-    ],
+    results: list[list[EntityNode]] = list(
+        await asyncio.gather(
+            *[entity_fulltext_search(q, driver, 2 * limit) for q in queries],
+            *[entity_similarity_search(e, driver, 2 * limit) for e in embeddings],
+        )
     )
 
-
-        for node in result
-
-
+    node_uuid_map: dict[str, EntityNode] = {
+        node.uuid: node for result in results for node in result
+    }
+    result_uuids = [[node.uuid for node in result] for result in results]
+
+    ranked_uuids = rrf(result_uuids)
 
-
-            relevant_nodes.append(node)
+    relevant_nodes: list[EntityNode] = [node_uuid_map[uuid] for uuid in ranked_uuids]
 
     end = time()
-    logger.info(f'Found relevant nodes: {
+    logger.info(f'Found relevant nodes: {ranked_uuids} in {(end - start) * 1000} ms')
     return relevant_nodes
 
 
 async def get_relevant_nodes(
-
-
+    nodes: list[EntityNode],
+    driver: AsyncDriver,
 ) -> list[EntityNode]:
     """
     Retrieve relevant nodes based on the provided list of EntityNodes.
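hybrid_node_search above now merges the fulltext and similarity result lists with the rrf helper (its signature, rrf(results: list[list[str]], rank_const=1) -> list[str], appears in the reranker hunk further down). As a rough illustration of reciprocal rank fusion, written here as an independent sketch rather than the library's exact implementation:

    from collections import defaultdict

    def rrf_sketch(results: list[list[str]], rank_const=1) -> list[str]:
        # each result list contributes 1 / (rank + rank_const) to a uuid's score,
        # so uuids that rank highly in several lists float to the top
        scores: dict[str, float] = defaultdict(float)
        for result in results:
            for rank, uuid in enumerate(result):
                scores[uuid] += 1 / (rank + rank_const)
        return sorted(scores, key=lambda uuid: scores[uuid], reverse=True)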
@@ -369,8 +503,11 @@ async def get_relevant_nodes(
 
 
 async def get_relevant_edges(
-
-
+    driver: AsyncDriver,
+    edges: list[EntityEdge],
+    source_node_uuid: str | None,
+    target_node_uuid: str | None,
+    limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityEdge]:
     start = time()
     relevant_edges: list[EntityEdge] = []
@@ -378,11 +515,16 @@ async def get_relevant_edges(
 
     results = await asyncio.gather(
         *[
-            edge_similarity_search(
+            edge_similarity_search(
+                driver, edge.fact_embedding, source_node_uuid, target_node_uuid, limit
+            )
             for edge in edges
             if edge.fact_embedding is not None
         ],
-        *[
+        *[
+            edge_fulltext_search(driver, edge.fact, source_node_uuid, target_node_uuid, limit)
+            for edge in edges
+        ],
     )
 
     for result in results:
@@ -415,18 +557,18 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
 
 
 async def node_distance_reranker(
-
+    driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
 ) -> list[str]:
     # use rrf as a preliminary ranker
     sorted_uuids = rrf(results)
     scores: dict[str, float] = {}
 
     for uuid in sorted_uuids:
-        # Find shortest path to center node
+        # Find the shortest path to center node
        records, _, _ = await driver.execute_query(
            """
        MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
-       MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO]
+       MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO*1..10]->(n:Entity)
        WHERE center.uuid = $center_uuid AND n.uuid IN [source.uuid, target.uuid]
        RETURN min(length(p)) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
        """,
@@ -437,8 +579,8 @@ async def node_distance_reranker(
 
     for record in records:
         if (
-
-
+            record['source_uuid'] == center_node_uuid
+            or record['target_uuid'] == center_node_uuid
         ):
             continue
         distance = record['score']
graphiti_core/utils/bulk_utils.py

@@ -15,11 +15,13 @@ limitations under the License.
 """
 
 import asyncio
+import logging
 import typing
 from datetime import datetime
+from math import ceil
 
 from neo4j import AsyncDriver
-from numpy import dot
+from numpy import dot, sqrt
 from pydantic import BaseModel
 
 from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge
@@ -39,8 +41,12 @@ from graphiti_core.utils.maintenance.node_operations import (
     dedupe_node_list,
     extract_nodes,
 )
+from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates
+from graphiti_core.utils.utils import chunk_edges_by_nodes
 
-
+logger = logging.getLogger(__name__)
+
+CHUNK_SIZE = 10
 
 
 class RawEpisode(BaseModel):
@@ -114,27 +120,58 @@ async def dedupe_nodes_bulk(
 
     compressed_nodes, compressed_map = await compress_nodes(llm_client, nodes, uuid_map)
 
-
+    node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)]
 
-
-
+    existing_nodes_chunks: list[list[EntityNode]] = list(
+        await asyncio.gather(
+            *[get_relevant_nodes(node_chunk, driver) for node_chunk in node_chunks]
+        )
     )
 
-
+    results: list[tuple[list[EntityNode], dict[str, str]]] = list(
+        await asyncio.gather(
+            *[
+                dedupe_extracted_nodes(llm_client, node_chunk, existing_nodes_chunks[i])
+                for i, node_chunk in enumerate(node_chunks)
+            ]
+        )
+    )
 
-
+    final_nodes: list[EntityNode] = []
+    for result in results:
+        final_nodes.extend(result[0])
+        partial_uuid_map = result[1]
+        compressed_map.update(partial_uuid_map)
+
+    return final_nodes, compressed_map
 
 
 async def dedupe_edges_bulk(
     driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge]
 ) -> list[EntityEdge]:
-    #
+    # First compress edges
     compressed_edges = await compress_edges(llm_client, extracted_edges)
 
-
+    edge_chunks = [
+        compressed_edges[i : i + CHUNK_SIZE] for i in range(0, len(compressed_edges), CHUNK_SIZE)
+    ]
 
-
+    relevant_edges_chunks: list[list[EntityEdge]] = list(
+        await asyncio.gather(
+            *[get_relevant_edges(driver, edge_chunk, None, None) for edge_chunk in edge_chunks]
+        )
+    )
+
+    resolved_edge_chunks: list[list[EntityEdge]] = list(
+        await asyncio.gather(
+            *[
+                dedupe_extracted_edges(llm_client, edge_chunk, relevant_edges_chunks[i])
+                for i, edge_chunk in enumerate(edge_chunks)
+            ]
+        )
+    )
 
+    edges = [edge for edge_chunk in resolved_edge_chunks for edge in edge_chunk]
     return edges
 
 
@@ -154,13 +191,58 @@ def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str
 async def compress_nodes(
     llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str]
 ) -> tuple[list[EntityNode], dict[str, str]]:
+    # We want to first compress the nodes by deduplicating nodes across each of the episodes added in bulk
     if len(nodes) == 0:
         return nodes, uuid_map
 
-
-
+    # Our approach involves us deduplicating chunks of nodes in parallel.
+    # We want n chunks of size n so that n ** 2 == len(nodes).
+    # We want chunk sizes to be at least 10 for optimizing LLM processing time
+    chunk_size = max(int(sqrt(len(nodes))), CHUNK_SIZE)
 
-
+    # First calculate similarity scores between nodes
+    similarity_scores: list[tuple[int, int, float]] = [
+        (i, j, dot(n.name_embedding or [], m.name_embedding or []))
+        for i, n in enumerate(nodes)
+        for j, m in enumerate(nodes[:i])
+    ]
+
+    # We now sort by semantic similarity
+    similarity_scores.sort(key=lambda score_tuple: score_tuple[2])
+
+    # initialize our chunks based on chunk size
+    node_chunks: list[list[EntityNode]] = [[] for _ in range(ceil(len(nodes) / chunk_size))]
+
+    # Draft the most similar nodes into the same chunk
+    while len(similarity_scores) > 0:
+        i, j, _ = similarity_scores.pop()
+        # determine if any of the nodes have already been drafted into a chunk
+        n = nodes[i]
+        m = nodes[j]
+        # make sure the shortest chunks get preference
+        node_chunks.sort(reverse=True, key=lambda chunk: len(chunk))
+
+        n_chunk = max([i if n in chunk else -1 for i, chunk in enumerate(node_chunks)])
+        m_chunk = max([i if m in chunk else -1 for i, chunk in enumerate(node_chunks)])
+
+        # both nodes already in a chunk
+        if n_chunk > -1 and m_chunk > -1:
+            continue
+
+        # n has a chunk and that chunk is not full
+        elif n_chunk > -1 and len(node_chunks[n_chunk]) < chunk_size:
+            # put m in the same chunk as n
+            node_chunks[n_chunk].append(m)
+
+        # m has a chunk and that chunk is not full
+        elif m_chunk > -1 and len(node_chunks[m_chunk]) < chunk_size:
+            # put n in the same chunk as m
+            node_chunks[m_chunk].append(n)
+
+        # neither node has a chunk or the chunk is full
+        else:
+            # add both nodes to the shortest chunk
+            node_chunks[-1].extend([n, m])
 
     results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks])
 
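The comments in the hunk above describe the chunking math in compress_nodes. A small worked example of how chunk_size and the chunk count fall out of it, with node counts assumed purely for illustration:

    from math import ceil, sqrt

    CHUNK_SIZE = 10  # minimum chunk size, as defined earlier in bulk_utils.py

    for n_nodes in (9, 100, 400):
        chunk_size = max(int(sqrt(n_nodes)), CHUNK_SIZE)
        n_chunks = ceil(n_nodes / chunk_size)
        print(n_nodes, chunk_size, n_chunks)  # 9 -> (10, 1), 100 -> (10, 10), 400 -> (20, 20)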
@@ -181,13 +263,9 @@ async def compress_nodes(
 async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list[EntityEdge]:
     if len(edges) == 0:
         return edges
-
-
-    edges
-        key=lambda embedding: dot(anchor.fact_embedding or [], embedding.fact_embedding or [])
-    )
-
-    edge_chunks = [edges[i : i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)]
+    # We only want to dedupe edges that are between the same pair of nodes
+    # We build a map of the edges based on their source and target nodes.
+    edge_chunks = chunk_edges_by_nodes(edges)
 
     results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks])
 
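chunk_edges_by_nodes comes from graphiti_core/utils/utils.py (+22 -1 in this release), whose body is not shown in this diff. Going only by the comment above, a hedged sketch of that kind of grouping, using the source_node_uuid and target_node_uuid fields seen elsewhere in the diff; the actual helper may differ:

    from collections import defaultdict

    def chunk_edges_by_nodes_sketch(edges: list[EntityEdge]) -> list[list[EntityEdge]]:
        # group edges by their source/target node pair so that only edges between
        # the same two nodes are deduplicated against each other
        chunk_map: dict[str, list[EntityEdge]] = defaultdict(list)
        for edge in edges:
            chunk_map[edge.source_node_uuid + edge.target_node_uuid].append(edge)
        return list(chunk_map.values())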
@@ -225,3 +303,43 @@ def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
         edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)
 
     return edges
+
+
+async def extract_edge_dates_bulk(
+    llm_client: LLMClient,
+    extracted_edges: list[EntityEdge],
+    episode_pairs: list[tuple[EpisodicNode, list[EpisodicNode]]],
+) -> list[EntityEdge]:
+    edges: list[EntityEdge] = []
+    # confirm that all of our edges have at least one episode
+    for edge in extracted_edges:
+        if edge.episodes is not None and len(edge.episodes) > 0:
+            edges.append(edge)
+
+    episode_uuid_map: dict[str, tuple[EpisodicNode, list[EpisodicNode]]] = {
+        episode.uuid: (episode, previous_episodes) for episode, previous_episodes in episode_pairs
+    }
+
+    results = await asyncio.gather(
+        *[
+            extract_edge_dates(
+                llm_client,
+                edge,
+                episode_uuid_map[edge.episodes[0]][0],  # type: ignore
+                episode_uuid_map[edge.episodes[0]][1],  # type: ignore
+            )
+            for edge in edges
+        ]
+    )
+
+    for i, result in enumerate(results):
+        valid_at = result[0]
+        invalid_at = result[1]
+        edge = edges[i]
+
+        edge.valid_at = valid_at
+        edge.invalid_at = invalid_at
+        if edge.invalid_at:
+            edge.expired_at = datetime.now()
+
+    return edges