graphiti-core 0.18.0__py3-none-any.whl → 0.18.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/driver/driver.py +20 -2
- graphiti_core/driver/falkordb_driver.py +16 -9
- graphiti_core/driver/neo4j_driver.py +8 -6
- graphiti_core/edges.py +73 -99
- graphiti_core/graph_queries.py +51 -97
- graphiti_core/graphiti.py +27 -12
- graphiti_core/helpers.py +4 -3
- graphiti_core/models/edges/edge_db_queries.py +106 -32
- graphiti_core/models/nodes/node_db_queries.py +101 -20
- graphiti_core/nodes.py +113 -128
- graphiti_core/prompts/dedupe_nodes.py +1 -1
- graphiti_core/prompts/extract_edges.py +4 -4
- graphiti_core/prompts/extract_nodes.py +53 -11
- graphiti_core/search/search_filters.py +5 -5
- graphiti_core/search/search_utils.py +138 -185
- graphiti_core/utils/bulk_utils.py +7 -9
- graphiti_core/utils/maintenance/community_operations.py +11 -7
- graphiti_core/utils/maintenance/edge_operations.py +23 -54
- graphiti_core/utils/maintenance/graph_data_operations.py +14 -29
- graphiti_core/utils/maintenance/node_operations.py +40 -87
- graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
- {graphiti_core-0.18.0.dist-info → graphiti_core-0.18.2.dist-info}/METADATA +11 -3
- {graphiti_core-0.18.0.dist-info → graphiti_core-0.18.2.dist-info}/RECORD +25 -25
- {graphiti_core-0.18.0.dist-info → graphiti_core-0.18.2.dist-info}/WHEEL +0 -0
- {graphiti_core-0.18.0.dist-info → graphiti_core-0.18.2.dist-info}/licenses/LICENSE +0 -0
|
@@ -23,7 +23,7 @@ import numpy as np
|
|
|
23
23
|
from numpy._typing import NDArray
|
|
24
24
|
from typing_extensions import LiteralString
|
|
25
25
|
|
|
26
|
-
from graphiti_core.driver.driver import GraphDriver
|
|
26
|
+
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
|
27
27
|
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
|
28
28
|
from graphiti_core.graph_queries import (
|
|
29
29
|
get_nodes_query,
|
|
@@ -36,6 +36,8 @@ from graphiti_core.helpers import (
|
|
|
36
36
|
normalize_l2,
|
|
37
37
|
semaphore_gather,
|
|
38
38
|
)
|
|
39
|
+
from graphiti_core.models.edges.edge_db_queries import ENTITY_EDGE_RETURN
|
|
40
|
+
from graphiti_core.models.nodes.node_db_queries import COMMUNITY_NODE_RETURN, EPISODIC_NODE_RETURN
|
|
39
41
|
from graphiti_core.nodes import (
|
|
40
42
|
ENTITY_NODE_RETURN,
|
|
41
43
|
CommunityNode,
|
|
@@ -100,20 +102,13 @@ async def get_mentioned_nodes(
|
|
|
100
102
|
) -> list[EntityNode]:
|
|
101
103
|
episode_uuids = [episode.uuid for episode in episodes]
|
|
102
104
|
|
|
103
|
-
|
|
104
|
-
|
|
105
|
+
records, _, _ = await driver.execute_query(
|
|
106
|
+
"""
|
|
107
|
+
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity)
|
|
108
|
+
WHERE episode.uuid IN $uuids
|
|
105
109
|
RETURN DISTINCT
|
|
106
|
-
n.uuid As uuid,
|
|
107
|
-
n.group_id AS group_id,
|
|
108
|
-
n.name AS name,
|
|
109
|
-
n.created_at AS created_at,
|
|
110
|
-
n.summary AS summary,
|
|
111
|
-
labels(n) AS labels,
|
|
112
|
-
properties(n) AS attributes
|
|
113
110
|
"""
|
|
114
|
-
|
|
115
|
-
records, _, _ = await driver.execute_query(
|
|
116
|
-
query,
|
|
111
|
+
+ ENTITY_NODE_RETURN,
|
|
117
112
|
uuids=episode_uuids,
|
|
118
113
|
routing_='r',
|
|
119
114
|
)
|
|
@@ -128,18 +123,13 @@ async def get_communities_by_nodes(
|
|
|
128
123
|
) -> list[CommunityNode]:
|
|
129
124
|
node_uuids = [node.uuid for node in nodes]
|
|
130
125
|
|
|
131
|
-
query = """
|
|
132
|
-
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
|
|
133
|
-
RETURN DISTINCT
|
|
134
|
-
c.uuid As uuid,
|
|
135
|
-
c.group_id AS group_id,
|
|
136
|
-
c.name AS name,
|
|
137
|
-
c.created_at AS created_at,
|
|
138
|
-
c.summary AS summary
|
|
139
|
-
"""
|
|
140
|
-
|
|
141
126
|
records, _, _ = await driver.execute_query(
|
|
142
|
-
|
|
127
|
+
"""
|
|
128
|
+
MATCH (n:Community)-[:HAS_MEMBER]->(m:Entity)
|
|
129
|
+
WHERE m.uuid IN $uuids
|
|
130
|
+
RETURN DISTINCT
|
|
131
|
+
"""
|
|
132
|
+
+ COMMUNITY_NODE_RETURN,
|
|
143
133
|
uuids=node_uuids,
|
|
144
134
|
routing_='r',
|
|
145
135
|
)
|
|
@@ -164,38 +154,30 @@ async def edge_fulltext_search(
|
|
|
164
154
|
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
|
165
155
|
|
|
166
156
|
query = (
|
|
167
|
-
get_relationships_query('edge_name_and_fact',
|
|
157
|
+
get_relationships_query('edge_name_and_fact', provider=driver.provider)
|
|
168
158
|
+ """
|
|
169
159
|
YIELD relationship AS rel, score
|
|
170
|
-
MATCH (n:Entity)-[
|
|
171
|
-
WHERE
|
|
160
|
+
MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
|
|
161
|
+
WHERE e.group_id IN $group_ids """
|
|
172
162
|
+ filter_query
|
|
173
163
|
+ """
|
|
174
|
-
WITH
|
|
164
|
+
WITH e, score, n, m
|
|
175
165
|
RETURN
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
r.name AS name,
|
|
182
|
-
r.fact AS fact,
|
|
183
|
-
r.episodes AS episodes,
|
|
184
|
-
r.expired_at AS expired_at,
|
|
185
|
-
r.valid_at AS valid_at,
|
|
186
|
-
r.invalid_at AS invalid_at,
|
|
187
|
-
properties(r) AS attributes
|
|
188
|
-
ORDER BY score DESC LIMIT $limit
|
|
166
|
+
"""
|
|
167
|
+
+ ENTITY_EDGE_RETURN
|
|
168
|
+
+ """
|
|
169
|
+
ORDER BY score DESC
|
|
170
|
+
LIMIT $limit
|
|
189
171
|
"""
|
|
190
172
|
)
|
|
191
173
|
|
|
192
174
|
records, _, _ = await driver.execute_query(
|
|
193
175
|
query,
|
|
194
|
-
params=filter_params,
|
|
195
176
|
query=fuzzy_query,
|
|
196
177
|
group_ids=group_ids,
|
|
197
178
|
limit=limit,
|
|
198
179
|
routing_='r',
|
|
180
|
+
**filter_params,
|
|
199
181
|
)
|
|
200
182
|
|
|
201
183
|
edges = [get_entity_edge_from_record(record) for record in records]
|
|
@@ -219,58 +201,47 @@ async def edge_similarity_search(
|
|
|
219
201
|
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
|
220
202
|
query_params.update(filter_params)
|
|
221
203
|
|
|
222
|
-
group_filter_query: LiteralString = 'WHERE
|
|
204
|
+
group_filter_query: LiteralString = 'WHERE e.group_id IS NOT NULL'
|
|
223
205
|
if group_ids is not None:
|
|
224
|
-
group_filter_query += '\nAND
|
|
206
|
+
group_filter_query += '\nAND e.group_id IN $group_ids'
|
|
225
207
|
query_params['group_ids'] = group_ids
|
|
226
|
-
query_params['source_node_uuid'] = source_node_uuid
|
|
227
|
-
query_params['target_node_uuid'] = target_node_uuid
|
|
228
208
|
|
|
229
209
|
if source_node_uuid is not None:
|
|
230
|
-
|
|
210
|
+
query_params['source_uuid'] = source_node_uuid
|
|
211
|
+
group_filter_query += '\nAND (n.uuid = $source_uuid)'
|
|
231
212
|
|
|
232
213
|
if target_node_uuid is not None:
|
|
233
|
-
|
|
214
|
+
query_params['target_uuid'] = target_node_uuid
|
|
215
|
+
group_filter_query += '\nAND (m.uuid = $target_uuid)'
|
|
234
216
|
|
|
235
217
|
query = (
|
|
236
218
|
RUNTIME_QUERY
|
|
237
219
|
+ """
|
|
238
|
-
MATCH (n:Entity)-[
|
|
220
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
239
221
|
"""
|
|
240
222
|
+ group_filter_query
|
|
241
223
|
+ filter_query
|
|
242
224
|
+ """
|
|
243
|
-
WITH DISTINCT
|
|
244
|
-
+ get_vector_cosine_func_query('
|
|
225
|
+
WITH DISTINCT e, n, m, """
|
|
226
|
+
+ get_vector_cosine_func_query('e.fact_embedding', '$search_vector', driver.provider)
|
|
245
227
|
+ """ AS score
|
|
246
228
|
WHERE score > $min_score
|
|
247
229
|
RETURN
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
endNode(r).uuid AS target_node_uuid,
|
|
252
|
-
r.created_at AS created_at,
|
|
253
|
-
r.name AS name,
|
|
254
|
-
r.fact AS fact,
|
|
255
|
-
r.episodes AS episodes,
|
|
256
|
-
r.expired_at AS expired_at,
|
|
257
|
-
r.valid_at AS valid_at,
|
|
258
|
-
r.invalid_at AS invalid_at,
|
|
259
|
-
properties(r) AS attributes
|
|
230
|
+
"""
|
|
231
|
+
+ ENTITY_EDGE_RETURN
|
|
232
|
+
+ """
|
|
260
233
|
ORDER BY score DESC
|
|
261
234
|
LIMIT $limit
|
|
262
235
|
"""
|
|
263
236
|
)
|
|
264
|
-
|
|
237
|
+
|
|
238
|
+
records, _, _ = await driver.execute_query(
|
|
265
239
|
query,
|
|
266
|
-
params=query_params,
|
|
267
240
|
search_vector=search_vector,
|
|
268
|
-
source_uuid=source_node_uuid,
|
|
269
|
-
target_uuid=target_node_uuid,
|
|
270
|
-
group_ids=group_ids,
|
|
271
241
|
limit=limit,
|
|
272
242
|
min_score=min_score,
|
|
273
243
|
routing_='r',
|
|
244
|
+
**query_params,
|
|
274
245
|
)
|
|
275
246
|
|
|
276
247
|
edges = [get_entity_edge_from_record(record) for record in records]
|
|
@@ -293,41 +264,31 @@ async def edge_bfs_search(
|
|
|
293
264
|
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
|
294
265
|
|
|
295
266
|
query = (
|
|
267
|
+
f"""
|
|
268
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
269
|
+
MATCH path = (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
|
|
270
|
+
UNWIND relationships(path) AS rel
|
|
271
|
+
MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
|
272
|
+
WHERE e.uuid = rel.uuid
|
|
273
|
+
AND e.group_id IN $group_ids
|
|
296
274
|
"""
|
|
297
|
-
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
298
|
-
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
|
299
|
-
UNWIND relationships(path) AS rel
|
|
300
|
-
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
|
|
301
|
-
WHERE r.uuid = rel.uuid
|
|
302
|
-
AND r.group_id IN $group_ids
|
|
303
|
-
"""
|
|
304
275
|
+ filter_query
|
|
305
|
-
+ """
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
r.created_at AS created_at,
|
|
312
|
-
r.name AS name,
|
|
313
|
-
r.fact AS fact,
|
|
314
|
-
r.episodes AS episodes,
|
|
315
|
-
r.expired_at AS expired_at,
|
|
316
|
-
r.valid_at AS valid_at,
|
|
317
|
-
r.invalid_at AS invalid_at,
|
|
318
|
-
properties(r) AS attributes
|
|
319
|
-
LIMIT $limit
|
|
276
|
+
+ """
|
|
277
|
+
RETURN DISTINCT
|
|
278
|
+
"""
|
|
279
|
+
+ ENTITY_EDGE_RETURN
|
|
280
|
+
+ """
|
|
281
|
+
LIMIT $limit
|
|
320
282
|
"""
|
|
321
283
|
)
|
|
322
284
|
|
|
323
285
|
records, _, _ = await driver.execute_query(
|
|
324
286
|
query,
|
|
325
|
-
params=filter_params,
|
|
326
287
|
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
|
327
|
-
depth=bfs_max_depth,
|
|
328
288
|
group_ids=group_ids,
|
|
329
289
|
limit=limit,
|
|
330
290
|
routing_='r',
|
|
291
|
+
**filter_params,
|
|
331
292
|
)
|
|
332
293
|
|
|
333
294
|
edges = [get_entity_edge_from_record(record) for record in records]
|
|
@@ -352,23 +313,25 @@ async def node_fulltext_search(
|
|
|
352
313
|
get_nodes_query(driver.provider, 'node_name_and_summary', '$query')
|
|
353
314
|
+ """
|
|
354
315
|
YIELD node AS n, score
|
|
355
|
-
|
|
356
|
-
LIMIT $limit
|
|
357
|
-
WHERE n:Entity AND n.group_id IN $group_ids
|
|
316
|
+
WHERE n:Entity AND n.group_id IN $group_ids
|
|
358
317
|
"""
|
|
359
318
|
+ filter_query
|
|
360
|
-
+ ENTITY_NODE_RETURN
|
|
361
319
|
+ """
|
|
320
|
+
WITH n, score
|
|
362
321
|
ORDER BY score DESC
|
|
322
|
+
LIMIT $limit
|
|
323
|
+
RETURN
|
|
363
324
|
"""
|
|
325
|
+
+ ENTITY_NODE_RETURN
|
|
364
326
|
)
|
|
365
|
-
|
|
327
|
+
|
|
328
|
+
records, _, _ = await driver.execute_query(
|
|
366
329
|
query,
|
|
367
|
-
params=filter_params,
|
|
368
330
|
query=fuzzy_query,
|
|
369
331
|
group_ids=group_ids,
|
|
370
332
|
limit=limit,
|
|
371
333
|
routing_='r',
|
|
334
|
+
**filter_params,
|
|
372
335
|
)
|
|
373
336
|
|
|
374
337
|
nodes = [get_entity_node_from_record(record) for record in records]
|
|
@@ -406,22 +369,23 @@ async def node_similarity_search(
|
|
|
406
369
|
WITH n, """
|
|
407
370
|
+ get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
|
|
408
371
|
+ """ AS score
|
|
409
|
-
WHERE score > $min_score
|
|
372
|
+
WHERE score > $min_score
|
|
373
|
+
RETURN
|
|
374
|
+
"""
|
|
410
375
|
+ ENTITY_NODE_RETURN
|
|
411
376
|
+ """
|
|
412
377
|
ORDER BY score DESC
|
|
413
378
|
LIMIT $limit
|
|
414
|
-
|
|
379
|
+
"""
|
|
415
380
|
)
|
|
416
381
|
|
|
417
|
-
records,
|
|
382
|
+
records, _, _ = await driver.execute_query(
|
|
418
383
|
query,
|
|
419
|
-
params=query_params,
|
|
420
384
|
search_vector=search_vector,
|
|
421
|
-
group_ids=group_ids,
|
|
422
385
|
limit=limit,
|
|
423
386
|
min_score=min_score,
|
|
424
387
|
routing_='r',
|
|
388
|
+
**query_params,
|
|
425
389
|
)
|
|
426
390
|
|
|
427
391
|
nodes = [get_entity_node_from_record(record) for record in records]
|
|
@@ -444,26 +408,29 @@ async def node_bfs_search(
|
|
|
444
408
|
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
|
445
409
|
|
|
446
410
|
query = (
|
|
411
|
+
f"""
|
|
412
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
413
|
+
MATCH (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
|
|
414
|
+
WHERE n.group_id = origin.group_id
|
|
415
|
+
AND origin.group_id IN $group_ids
|
|
447
416
|
"""
|
|
448
|
-
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
449
|
-
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
|
450
|
-
WHERE n.group_id = origin.group_id
|
|
451
|
-
AND origin.group_id IN $group_ids
|
|
452
|
-
"""
|
|
453
417
|
+ filter_query
|
|
418
|
+
+ """
|
|
419
|
+
RETURN
|
|
420
|
+
"""
|
|
454
421
|
+ ENTITY_NODE_RETURN
|
|
455
422
|
+ """
|
|
456
423
|
LIMIT $limit
|
|
457
424
|
"""
|
|
458
425
|
)
|
|
426
|
+
|
|
459
427
|
records, _, _ = await driver.execute_query(
|
|
460
428
|
query,
|
|
461
|
-
params=filter_params,
|
|
462
429
|
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
|
463
|
-
depth=bfs_max_depth,
|
|
464
430
|
group_ids=group_ids,
|
|
465
431
|
limit=limit,
|
|
466
432
|
routing_='r',
|
|
433
|
+
**filter_params,
|
|
467
434
|
)
|
|
468
435
|
nodes = [get_entity_node_from_record(record) for record in records]
|
|
469
436
|
|
|
@@ -489,16 +456,10 @@ async def episode_fulltext_search(
|
|
|
489
456
|
MATCH (e:Episodic)
|
|
490
457
|
WHERE e.uuid = episode.uuid
|
|
491
458
|
AND e.group_id IN $group_ids
|
|
492
|
-
RETURN
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
e.uuid AS uuid,
|
|
497
|
-
e.name AS name,
|
|
498
|
-
e.group_id AS group_id,
|
|
499
|
-
e.source_description AS source_description,
|
|
500
|
-
e.source AS source,
|
|
501
|
-
e.entity_edges AS entity_edges
|
|
459
|
+
RETURN
|
|
460
|
+
"""
|
|
461
|
+
+ EPISODIC_NODE_RETURN
|
|
462
|
+
+ """
|
|
502
463
|
ORDER BY score DESC
|
|
503
464
|
LIMIT $limit
|
|
504
465
|
"""
|
|
@@ -530,15 +491,12 @@ async def community_fulltext_search(
|
|
|
530
491
|
query = (
|
|
531
492
|
get_nodes_query(driver.provider, 'community_name', '$query')
|
|
532
493
|
+ """
|
|
533
|
-
YIELD node AS
|
|
534
|
-
WHERE
|
|
494
|
+
YIELD node AS n, score
|
|
495
|
+
WHERE n.group_id IN $group_ids
|
|
535
496
|
RETURN
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
comm.created_at AS created_at,
|
|
540
|
-
comm.summary AS summary,
|
|
541
|
-
comm.name_embedding AS name_embedding
|
|
497
|
+
"""
|
|
498
|
+
+ COMMUNITY_NODE_RETURN
|
|
499
|
+
+ """
|
|
542
500
|
ORDER BY score DESC
|
|
543
501
|
LIMIT $limit
|
|
544
502
|
"""
|
|
@@ -568,39 +526,37 @@ async def community_similarity_search(
|
|
|
568
526
|
|
|
569
527
|
group_filter_query: LiteralString = ''
|
|
570
528
|
if group_ids is not None:
|
|
571
|
-
group_filter_query += 'WHERE
|
|
529
|
+
group_filter_query += 'WHERE n.group_id IN $group_ids'
|
|
572
530
|
query_params['group_ids'] = group_ids
|
|
573
531
|
|
|
574
532
|
query = (
|
|
575
533
|
RUNTIME_QUERY
|
|
576
534
|
+ """
|
|
577
|
-
|
|
578
|
-
|
|
535
|
+
MATCH (n:Community)
|
|
536
|
+
"""
|
|
579
537
|
+ group_filter_query
|
|
580
538
|
+ """
|
|
581
|
-
|
|
582
|
-
|
|
539
|
+
WITH n,
|
|
540
|
+
"""
|
|
541
|
+
+ get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
|
|
583
542
|
+ """ AS score
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
comm.name_embedding AS name_embedding
|
|
592
|
-
ORDER BY score DESC
|
|
593
|
-
LIMIT $limit
|
|
543
|
+
WHERE score > $min_score
|
|
544
|
+
RETURN
|
|
545
|
+
"""
|
|
546
|
+
+ COMMUNITY_NODE_RETURN
|
|
547
|
+
+ """
|
|
548
|
+
ORDER BY score DESC
|
|
549
|
+
LIMIT $limit
|
|
594
550
|
"""
|
|
595
551
|
)
|
|
596
552
|
|
|
597
553
|
records, _, _ = await driver.execute_query(
|
|
598
554
|
query,
|
|
599
555
|
search_vector=search_vector,
|
|
600
|
-
group_ids=group_ids,
|
|
601
556
|
limit=limit,
|
|
602
557
|
min_score=min_score,
|
|
603
558
|
routing_='r',
|
|
559
|
+
**query_params,
|
|
604
560
|
)
|
|
605
561
|
communities = [get_community_node_from_record(record) for record in records]
|
|
606
562
|
|
|
@@ -719,8 +675,8 @@ async def get_relevant_nodes(
|
|
|
719
675
|
WHERE m.group_id = $group_id
|
|
720
676
|
WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
|
|
721
677
|
|
|
722
|
-
WITH node,
|
|
723
|
-
top_vector_nodes,
|
|
678
|
+
WITH node,
|
|
679
|
+
top_vector_nodes,
|
|
724
680
|
[m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
|
|
725
681
|
|
|
726
682
|
WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
|
|
@@ -728,10 +684,10 @@ async def get_relevant_nodes(
|
|
|
728
684
|
UNWIND combined_nodes AS combined_node
|
|
729
685
|
WITH node, collect(DISTINCT combined_node) AS deduped_nodes
|
|
730
686
|
|
|
731
|
-
RETURN
|
|
687
|
+
RETURN
|
|
732
688
|
node.uuid AS search_node_uuid,
|
|
733
689
|
[x IN deduped_nodes | {
|
|
734
|
-
uuid: x.uuid,
|
|
690
|
+
uuid: x.uuid,
|
|
735
691
|
name: x.name,
|
|
736
692
|
name_embedding: x.name_embedding,
|
|
737
693
|
group_id: x.group_id,
|
|
@@ -755,12 +711,12 @@ async def get_relevant_nodes(
|
|
|
755
711
|
|
|
756
712
|
results, _, _ = await driver.execute_query(
|
|
757
713
|
query,
|
|
758
|
-
params=query_params,
|
|
759
714
|
nodes=query_nodes,
|
|
760
715
|
group_id=group_id,
|
|
761
716
|
limit=limit,
|
|
762
717
|
min_score=min_score,
|
|
763
718
|
routing_='r',
|
|
719
|
+
**query_params,
|
|
764
720
|
)
|
|
765
721
|
|
|
766
722
|
relevant_nodes_dict: dict[str, list[EntityNode]] = {
|
|
@@ -825,11 +781,11 @@ async def get_relevant_edges(
|
|
|
825
781
|
|
|
826
782
|
results, _, _ = await driver.execute_query(
|
|
827
783
|
query,
|
|
828
|
-
params=query_params,
|
|
829
784
|
edges=[edge.model_dump() for edge in edges],
|
|
830
785
|
limit=limit,
|
|
831
786
|
min_score=min_score,
|
|
832
787
|
routing_='r',
|
|
788
|
+
**query_params,
|
|
833
789
|
)
|
|
834
790
|
|
|
835
791
|
relevant_edges_dict: dict[str, list[EntityEdge]] = {
|
|
@@ -895,11 +851,11 @@ async def get_edge_invalidation_candidates(
|
|
|
895
851
|
|
|
896
852
|
results, _, _ = await driver.execute_query(
|
|
897
853
|
query,
|
|
898
|
-
params=query_params,
|
|
899
854
|
edges=[edge.model_dump() for edge in edges],
|
|
900
855
|
limit=limit,
|
|
901
856
|
min_score=min_score,
|
|
902
857
|
routing_='r',
|
|
858
|
+
**query_params,
|
|
903
859
|
)
|
|
904
860
|
invalidation_edges_dict: dict[str, list[EntityEdge]] = {
|
|
905
861
|
result['search_edge_uuid']: [
|
|
@@ -943,18 +899,17 @@ async def node_distance_reranker(
|
|
|
943
899
|
scores: dict[str, float] = {center_node_uuid: 0.0}
|
|
944
900
|
|
|
945
901
|
# Find the shortest path to center node
|
|
946
|
-
|
|
902
|
+
results, header, _ = await driver.execute_query(
|
|
903
|
+
"""
|
|
947
904
|
UNWIND $node_uuids AS node_uuid
|
|
948
905
|
MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
|
|
949
906
|
RETURN 1 AS score, node_uuid AS uuid
|
|
950
|
-
"""
|
|
951
|
-
results, header, _ = await driver.execute_query(
|
|
952
|
-
query,
|
|
907
|
+
""",
|
|
953
908
|
node_uuids=filtered_uuids,
|
|
954
909
|
center_uuid=center_node_uuid,
|
|
955
910
|
routing_='r',
|
|
956
911
|
)
|
|
957
|
-
if driver.provider ==
|
|
912
|
+
if driver.provider == GraphProvider.FALKORDB:
|
|
958
913
|
results = [dict(zip(header, row, strict=True)) for row in results]
|
|
959
914
|
|
|
960
915
|
for result in results:
|
|
@@ -987,13 +942,12 @@ async def episode_mentions_reranker(
|
|
|
987
942
|
scores: dict[str, float] = {}
|
|
988
943
|
|
|
989
944
|
# Find the shortest path to center node
|
|
990
|
-
|
|
991
|
-
|
|
945
|
+
results, _, _ = await driver.execute_query(
|
|
946
|
+
"""
|
|
947
|
+
UNWIND $node_uuids AS node_uuid
|
|
992
948
|
MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
|
|
993
949
|
RETURN count(*) AS score, n.uuid AS uuid
|
|
994
|
-
"""
|
|
995
|
-
results, _, _ = await driver.execute_query(
|
|
996
|
-
query,
|
|
950
|
+
""",
|
|
997
951
|
node_uuids=sorted_uuids,
|
|
998
952
|
routing_='r',
|
|
999
953
|
)
|
|
@@ -1053,15 +1007,16 @@ def maximal_marginal_relevance(
|
|
|
1053
1007
|
async def get_embeddings_for_nodes(
|
|
1054
1008
|
driver: GraphDriver, nodes: list[EntityNode]
|
|
1055
1009
|
) -> dict[str, list[float]]:
|
|
1056
|
-
query: LiteralString = """MATCH (n:Entity)
|
|
1057
|
-
WHERE n.uuid IN $node_uuids
|
|
1058
|
-
RETURN DISTINCT
|
|
1059
|
-
n.uuid AS uuid,
|
|
1060
|
-
n.name_embedding AS name_embedding
|
|
1061
|
-
"""
|
|
1062
|
-
|
|
1063
1010
|
results, _, _ = await driver.execute_query(
|
|
1064
|
-
|
|
1011
|
+
"""
|
|
1012
|
+
MATCH (n:Entity)
|
|
1013
|
+
WHERE n.uuid IN $node_uuids
|
|
1014
|
+
RETURN DISTINCT
|
|
1015
|
+
n.uuid AS uuid,
|
|
1016
|
+
n.name_embedding AS name_embedding
|
|
1017
|
+
""",
|
|
1018
|
+
node_uuids=[node.uuid for node in nodes],
|
|
1019
|
+
routing_='r',
|
|
1065
1020
|
)
|
|
1066
1021
|
|
|
1067
1022
|
embeddings_dict: dict[str, list[float]] = {}
|
|
@@ -1077,15 +1032,14 @@ async def get_embeddings_for_nodes(
|
|
|
1077
1032
|
async def get_embeddings_for_communities(
|
|
1078
1033
|
driver: GraphDriver, communities: list[CommunityNode]
|
|
1079
1034
|
) -> dict[str, list[float]]:
|
|
1080
|
-
query: LiteralString = """MATCH (c:Community)
|
|
1081
|
-
WHERE c.uuid IN $community_uuids
|
|
1082
|
-
RETURN DISTINCT
|
|
1083
|
-
c.uuid AS uuid,
|
|
1084
|
-
c.name_embedding AS name_embedding
|
|
1085
|
-
"""
|
|
1086
|
-
|
|
1087
1035
|
results, _, _ = await driver.execute_query(
|
|
1088
|
-
|
|
1036
|
+
"""
|
|
1037
|
+
MATCH (c:Community)
|
|
1038
|
+
WHERE c.uuid IN $community_uuids
|
|
1039
|
+
RETURN DISTINCT
|
|
1040
|
+
c.uuid AS uuid,
|
|
1041
|
+
c.name_embedding AS name_embedding
|
|
1042
|
+
""",
|
|
1089
1043
|
community_uuids=[community.uuid for community in communities],
|
|
1090
1044
|
routing_='r',
|
|
1091
1045
|
)
|
|
@@ -1103,15 +1057,14 @@ async def get_embeddings_for_communities(
|
|
|
1103
1057
|
async def get_embeddings_for_edges(
|
|
1104
1058
|
driver: GraphDriver, edges: list[EntityEdge]
|
|
1105
1059
|
) -> dict[str, list[float]]:
|
|
1106
|
-
query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
|
1107
|
-
WHERE e.uuid IN $edge_uuids
|
|
1108
|
-
RETURN DISTINCT
|
|
1109
|
-
e.uuid AS uuid,
|
|
1110
|
-
e.fact_embedding AS fact_embedding
|
|
1111
|
-
"""
|
|
1112
|
-
|
|
1113
1060
|
results, _, _ = await driver.execute_query(
|
|
1114
|
-
|
|
1061
|
+
"""
|
|
1062
|
+
MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
|
1063
|
+
WHERE e.uuid IN $edge_uuids
|
|
1064
|
+
RETURN DISTINCT
|
|
1065
|
+
e.uuid AS uuid,
|
|
1066
|
+
e.fact_embedding AS fact_embedding
|
|
1067
|
+
""",
|
|
1115
1068
|
edge_uuids=[edge.uuid for edge in edges],
|
|
1116
1069
|
routing_='r',
|
|
1117
1070
|
)
|
|
@@ -25,17 +25,15 @@ from typing_extensions import Any
|
|
|
25
25
|
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession
|
|
26
26
|
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
|
|
27
27
|
from graphiti_core.embedder import EmbedderClient
|
|
28
|
-
from graphiti_core.graph_queries import (
|
|
29
|
-
get_entity_edge_save_bulk_query,
|
|
30
|
-
get_entity_node_save_bulk_query,
|
|
31
|
-
)
|
|
32
28
|
from graphiti_core.graphiti_types import GraphitiClients
|
|
33
29
|
from graphiti_core.helpers import normalize_l2, semaphore_gather
|
|
34
30
|
from graphiti_core.models.edges.edge_db_queries import (
|
|
35
31
|
EPISODIC_EDGE_SAVE_BULK,
|
|
32
|
+
get_entity_edge_save_bulk_query,
|
|
36
33
|
)
|
|
37
34
|
from graphiti_core.models.nodes.node_db_queries import (
|
|
38
35
|
EPISODIC_NODE_SAVE_BULK,
|
|
36
|
+
get_entity_node_save_bulk_query,
|
|
39
37
|
)
|
|
40
38
|
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode, create_entity_node_embeddings
|
|
41
39
|
from graphiti_core.utils.maintenance.edge_operations import (
|
|
@@ -158,7 +156,7 @@ async def add_nodes_and_edges_bulk_tx(
|
|
|
158
156
|
edges.append(edge_data)
|
|
159
157
|
|
|
160
158
|
await tx.run(EPISODIC_NODE_SAVE_BULK, episodes=episodes)
|
|
161
|
-
entity_node_save_bulk = get_entity_node_save_bulk_query(
|
|
159
|
+
entity_node_save_bulk = get_entity_node_save_bulk_query(driver.provider, nodes)
|
|
162
160
|
await tx.run(entity_node_save_bulk, nodes=nodes)
|
|
163
161
|
await tx.run(
|
|
164
162
|
EPISODIC_EDGE_SAVE_BULK, episodic_edges=[edge.model_dump() for edge in episodic_edges]
|
|
@@ -171,9 +169,9 @@ async def extract_nodes_and_edges_bulk(
|
|
|
171
169
|
clients: GraphitiClients,
|
|
172
170
|
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
|
|
173
171
|
edge_type_map: dict[tuple[str, str], list[str]],
|
|
174
|
-
entity_types: dict[str, BaseModel] | None = None,
|
|
172
|
+
entity_types: dict[str, type[BaseModel]] | None = None,
|
|
175
173
|
excluded_entity_types: list[str] | None = None,
|
|
176
|
-
edge_types: dict[str, BaseModel] | None = None,
|
|
174
|
+
edge_types: dict[str, type[BaseModel]] | None = None,
|
|
177
175
|
) -> tuple[list[list[EntityNode]], list[list[EntityEdge]]]:
|
|
178
176
|
extracted_nodes_bulk: list[list[EntityNode]] = await semaphore_gather(
|
|
179
177
|
*[
|
|
@@ -204,7 +202,7 @@ async def dedupe_nodes_bulk(
|
|
|
204
202
|
clients: GraphitiClients,
|
|
205
203
|
extracted_nodes: list[list[EntityNode]],
|
|
206
204
|
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
|
|
207
|
-
entity_types: dict[str, BaseModel] | None = None,
|
|
205
|
+
entity_types: dict[str, type[BaseModel]] | None = None,
|
|
208
206
|
) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
|
|
209
207
|
embedder = clients.embedder
|
|
210
208
|
min_score = 0.8
|
|
@@ -292,7 +290,7 @@ async def dedupe_edges_bulk(
|
|
|
292
290
|
extracted_edges: list[list[EntityEdge]],
|
|
293
291
|
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
|
|
294
292
|
_entities: list[EntityNode],
|
|
295
|
-
edge_types: dict[str, BaseModel],
|
|
293
|
+
edge_types: dict[str, type[BaseModel]],
|
|
296
294
|
_edge_type_map: dict[tuple[str, str], list[str]],
|
|
297
295
|
) -> dict[str, list[EntityEdge]]:
|
|
298
296
|
embedder = clients.embedder
|