graphiti-core 0.17.11__py3-none-any.whl → 0.18.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/driver/driver.py +20 -2
- graphiti_core/driver/falkordb_driver.py +16 -9
- graphiti_core/driver/neo4j_driver.py +8 -6
- graphiti_core/edges.py +73 -99
- graphiti_core/graph_queries.py +51 -97
- graphiti_core/graphiti.py +24 -9
- graphiti_core/helpers.py +3 -2
- graphiti_core/models/edges/edge_db_queries.py +106 -32
- graphiti_core/models/nodes/node_db_queries.py +101 -20
- graphiti_core/nodes.py +113 -128
- graphiti_core/prompts/dedupe_nodes.py +1 -1
- graphiti_core/prompts/extract_edges.py +4 -4
- graphiti_core/prompts/extract_nodes.py +12 -10
- graphiti_core/search/search.py +44 -32
- graphiti_core/search/search_config.py +8 -4
- graphiti_core/search/search_filters.py +5 -5
- graphiti_core/search/search_utils.py +154 -189
- graphiti_core/utils/bulk_utils.py +3 -5
- graphiti_core/utils/maintenance/community_operations.py +11 -7
- graphiti_core/utils/maintenance/edge_operations.py +19 -50
- graphiti_core/utils/maintenance/graph_data_operations.py +14 -29
- graphiti_core/utils/maintenance/node_operations.py +11 -55
- {graphiti_core-0.17.11.dist-info → graphiti_core-0.18.1.dist-info}/METADATA +11 -3
- {graphiti_core-0.17.11.dist-info → graphiti_core-0.18.1.dist-info}/RECORD +26 -26
- {graphiti_core-0.17.11.dist-info → graphiti_core-0.18.1.dist-info}/WHEEL +0 -0
- {graphiti_core-0.17.11.dist-info → graphiti_core-0.18.1.dist-info}/licenses/LICENSE +0 -0
|
@@ -23,7 +23,7 @@ import numpy as np
|
|
|
23
23
|
from numpy._typing import NDArray
|
|
24
24
|
from typing_extensions import LiteralString
|
|
25
25
|
|
|
26
|
-
from graphiti_core.driver.driver import GraphDriver
|
|
26
|
+
from graphiti_core.driver.driver import GraphDriver, GraphProvider
|
|
27
27
|
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
|
28
28
|
from graphiti_core.graph_queries import (
|
|
29
29
|
get_nodes_query,
|
|
@@ -36,6 +36,8 @@ from graphiti_core.helpers import (
|
|
|
36
36
|
normalize_l2,
|
|
37
37
|
semaphore_gather,
|
|
38
38
|
)
|
|
39
|
+
from graphiti_core.models.edges.edge_db_queries import ENTITY_EDGE_RETURN
|
|
40
|
+
from graphiti_core.models.nodes.node_db_queries import COMMUNITY_NODE_RETURN, EPISODIC_NODE_RETURN
|
|
39
41
|
from graphiti_core.nodes import (
|
|
40
42
|
ENTITY_NODE_RETURN,
|
|
41
43
|
CommunityNode,
|
|
@@ -100,20 +102,13 @@ async def get_mentioned_nodes(
|
|
|
100
102
|
) -> list[EntityNode]:
|
|
101
103
|
episode_uuids = [episode.uuid for episode in episodes]
|
|
102
104
|
|
|
103
|
-
|
|
104
|
-
|
|
105
|
+
records, _, _ = await driver.execute_query(
|
|
106
|
+
"""
|
|
107
|
+
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity)
|
|
108
|
+
WHERE episode.uuid IN $uuids
|
|
105
109
|
RETURN DISTINCT
|
|
106
|
-
n.uuid As uuid,
|
|
107
|
-
n.group_id AS group_id,
|
|
108
|
-
n.name AS name,
|
|
109
|
-
n.created_at AS created_at,
|
|
110
|
-
n.summary AS summary,
|
|
111
|
-
labels(n) AS labels,
|
|
112
|
-
properties(n) AS attributes
|
|
113
110
|
"""
|
|
114
|
-
|
|
115
|
-
records, _, _ = await driver.execute_query(
|
|
116
|
-
query,
|
|
111
|
+
+ ENTITY_NODE_RETURN,
|
|
117
112
|
uuids=episode_uuids,
|
|
118
113
|
routing_='r',
|
|
119
114
|
)
|
|
@@ -128,18 +123,13 @@ async def get_communities_by_nodes(
|
|
|
128
123
|
) -> list[CommunityNode]:
|
|
129
124
|
node_uuids = [node.uuid for node in nodes]
|
|
130
125
|
|
|
131
|
-
query = """
|
|
132
|
-
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
|
|
133
|
-
RETURN DISTINCT
|
|
134
|
-
c.uuid As uuid,
|
|
135
|
-
c.group_id AS group_id,
|
|
136
|
-
c.name AS name,
|
|
137
|
-
c.created_at AS created_at,
|
|
138
|
-
c.summary AS summary
|
|
139
|
-
"""
|
|
140
|
-
|
|
141
126
|
records, _, _ = await driver.execute_query(
|
|
142
|
-
|
|
127
|
+
"""
|
|
128
|
+
MATCH (n:Community)-[:HAS_MEMBER]->(m:Entity)
|
|
129
|
+
WHERE m.uuid IN $uuids
|
|
130
|
+
RETURN DISTINCT
|
|
131
|
+
"""
|
|
132
|
+
+ COMMUNITY_NODE_RETURN,
|
|
143
133
|
uuids=node_uuids,
|
|
144
134
|
routing_='r',
|
|
145
135
|
)
|
|
@@ -164,38 +154,30 @@ async def edge_fulltext_search(
|
|
|
164
154
|
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
|
165
155
|
|
|
166
156
|
query = (
|
|
167
|
-
get_relationships_query('edge_name_and_fact',
|
|
157
|
+
get_relationships_query('edge_name_and_fact', provider=driver.provider)
|
|
168
158
|
+ """
|
|
169
159
|
YIELD relationship AS rel, score
|
|
170
|
-
MATCH (n:Entity)-[
|
|
171
|
-
WHERE
|
|
160
|
+
MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
|
|
161
|
+
WHERE e.group_id IN $group_ids """
|
|
172
162
|
+ filter_query
|
|
173
163
|
+ """
|
|
174
|
-
WITH
|
|
164
|
+
WITH e, score, n, m
|
|
175
165
|
RETURN
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
r.name AS name,
|
|
182
|
-
r.fact AS fact,
|
|
183
|
-
r.episodes AS episodes,
|
|
184
|
-
r.expired_at AS expired_at,
|
|
185
|
-
r.valid_at AS valid_at,
|
|
186
|
-
r.invalid_at AS invalid_at,
|
|
187
|
-
properties(r) AS attributes
|
|
188
|
-
ORDER BY score DESC LIMIT $limit
|
|
166
|
+
"""
|
|
167
|
+
+ ENTITY_EDGE_RETURN
|
|
168
|
+
+ """
|
|
169
|
+
ORDER BY score DESC
|
|
170
|
+
LIMIT $limit
|
|
189
171
|
"""
|
|
190
172
|
)
|
|
191
173
|
|
|
192
174
|
records, _, _ = await driver.execute_query(
|
|
193
175
|
query,
|
|
194
|
-
params=filter_params,
|
|
195
176
|
query=fuzzy_query,
|
|
196
177
|
group_ids=group_ids,
|
|
197
178
|
limit=limit,
|
|
198
179
|
routing_='r',
|
|
180
|
+
**filter_params,
|
|
199
181
|
)
|
|
200
182
|
|
|
201
183
|
edges = [get_entity_edge_from_record(record) for record in records]
|
|
@@ -219,58 +201,47 @@ async def edge_similarity_search(
|
|
|
219
201
|
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
|
220
202
|
query_params.update(filter_params)
|
|
221
203
|
|
|
222
|
-
group_filter_query: LiteralString = 'WHERE
|
|
204
|
+
group_filter_query: LiteralString = 'WHERE e.group_id IS NOT NULL'
|
|
223
205
|
if group_ids is not None:
|
|
224
|
-
group_filter_query += '\nAND
|
|
206
|
+
group_filter_query += '\nAND e.group_id IN $group_ids'
|
|
225
207
|
query_params['group_ids'] = group_ids
|
|
226
|
-
query_params['source_node_uuid'] = source_node_uuid
|
|
227
|
-
query_params['target_node_uuid'] = target_node_uuid
|
|
228
208
|
|
|
229
209
|
if source_node_uuid is not None:
|
|
230
|
-
|
|
210
|
+
query_params['source_uuid'] = source_node_uuid
|
|
211
|
+
group_filter_query += '\nAND (n.uuid = $source_uuid)'
|
|
231
212
|
|
|
232
213
|
if target_node_uuid is not None:
|
|
233
|
-
|
|
214
|
+
query_params['target_uuid'] = target_node_uuid
|
|
215
|
+
group_filter_query += '\nAND (m.uuid = $target_uuid)'
|
|
234
216
|
|
|
235
217
|
query = (
|
|
236
218
|
RUNTIME_QUERY
|
|
237
219
|
+ """
|
|
238
|
-
MATCH (n:Entity)-[
|
|
220
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
239
221
|
"""
|
|
240
222
|
+ group_filter_query
|
|
241
223
|
+ filter_query
|
|
242
224
|
+ """
|
|
243
|
-
WITH DISTINCT
|
|
244
|
-
+ get_vector_cosine_func_query('
|
|
225
|
+
WITH DISTINCT e, n, m, """
|
|
226
|
+
+ get_vector_cosine_func_query('e.fact_embedding', '$search_vector', driver.provider)
|
|
245
227
|
+ """ AS score
|
|
246
228
|
WHERE score > $min_score
|
|
247
229
|
RETURN
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
endNode(r).uuid AS target_node_uuid,
|
|
252
|
-
r.created_at AS created_at,
|
|
253
|
-
r.name AS name,
|
|
254
|
-
r.fact AS fact,
|
|
255
|
-
r.episodes AS episodes,
|
|
256
|
-
r.expired_at AS expired_at,
|
|
257
|
-
r.valid_at AS valid_at,
|
|
258
|
-
r.invalid_at AS invalid_at,
|
|
259
|
-
properties(r) AS attributes
|
|
230
|
+
"""
|
|
231
|
+
+ ENTITY_EDGE_RETURN
|
|
232
|
+
+ """
|
|
260
233
|
ORDER BY score DESC
|
|
261
234
|
LIMIT $limit
|
|
262
235
|
"""
|
|
263
236
|
)
|
|
264
|
-
|
|
237
|
+
|
|
238
|
+
records, _, _ = await driver.execute_query(
|
|
265
239
|
query,
|
|
266
|
-
params=query_params,
|
|
267
240
|
search_vector=search_vector,
|
|
268
|
-
source_uuid=source_node_uuid,
|
|
269
|
-
target_uuid=target_node_uuid,
|
|
270
|
-
group_ids=group_ids,
|
|
271
241
|
limit=limit,
|
|
272
242
|
min_score=min_score,
|
|
273
243
|
routing_='r',
|
|
244
|
+
**query_params,
|
|
274
245
|
)
|
|
275
246
|
|
|
276
247
|
edges = [get_entity_edge_from_record(record) for record in records]
|
|
@@ -293,41 +264,31 @@ async def edge_bfs_search(
|
|
|
293
264
|
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
|
294
265
|
|
|
295
266
|
query = (
|
|
296
|
-
"""
|
|
267
|
+
f"""
|
|
297
268
|
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
298
|
-
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->
|
|
269
|
+
MATCH path = (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
|
|
299
270
|
UNWIND relationships(path) AS rel
|
|
300
|
-
MATCH (n:Entity)-[
|
|
301
|
-
WHERE
|
|
302
|
-
AND
|
|
271
|
+
MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
|
272
|
+
WHERE e.uuid = rel.uuid
|
|
273
|
+
AND e.group_id IN $group_ids
|
|
303
274
|
"""
|
|
304
275
|
+ filter_query
|
|
305
|
-
+ """
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
r.created_at AS created_at,
|
|
312
|
-
r.name AS name,
|
|
313
|
-
r.fact AS fact,
|
|
314
|
-
r.episodes AS episodes,
|
|
315
|
-
r.expired_at AS expired_at,
|
|
316
|
-
r.valid_at AS valid_at,
|
|
317
|
-
r.invalid_at AS invalid_at,
|
|
318
|
-
properties(r) AS attributes
|
|
319
|
-
LIMIT $limit
|
|
276
|
+
+ """
|
|
277
|
+
RETURN DISTINCT
|
|
278
|
+
"""
|
|
279
|
+
+ ENTITY_EDGE_RETURN
|
|
280
|
+
+ """
|
|
281
|
+
LIMIT $limit
|
|
320
282
|
"""
|
|
321
283
|
)
|
|
322
284
|
|
|
323
285
|
records, _, _ = await driver.execute_query(
|
|
324
286
|
query,
|
|
325
|
-
params=filter_params,
|
|
326
287
|
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
|
327
|
-
depth=bfs_max_depth,
|
|
328
288
|
group_ids=group_ids,
|
|
329
289
|
limit=limit,
|
|
330
290
|
routing_='r',
|
|
291
|
+
**filter_params,
|
|
331
292
|
)
|
|
332
293
|
|
|
333
294
|
edges = [get_entity_edge_from_record(record) for record in records]
|
|
@@ -352,23 +313,27 @@ async def node_fulltext_search(
|
|
|
352
313
|
get_nodes_query(driver.provider, 'node_name_and_summary', '$query')
|
|
353
314
|
+ """
|
|
354
315
|
YIELD node AS n, score
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
316
|
+
WHERE n:Entity AND n.group_id IN $group_ids
|
|
317
|
+
WITH n, score
|
|
318
|
+
LIMIT $limit
|
|
358
319
|
"""
|
|
359
320
|
+ filter_query
|
|
321
|
+
+ """
|
|
322
|
+
RETURN
|
|
323
|
+
"""
|
|
360
324
|
+ ENTITY_NODE_RETURN
|
|
361
325
|
+ """
|
|
362
326
|
ORDER BY score DESC
|
|
363
327
|
"""
|
|
364
328
|
)
|
|
365
|
-
|
|
329
|
+
|
|
330
|
+
records, _, _ = await driver.execute_query(
|
|
366
331
|
query,
|
|
367
|
-
params=filter_params,
|
|
368
332
|
query=fuzzy_query,
|
|
369
333
|
group_ids=group_ids,
|
|
370
334
|
limit=limit,
|
|
371
335
|
routing_='r',
|
|
336
|
+
**filter_params,
|
|
372
337
|
)
|
|
373
338
|
|
|
374
339
|
nodes = [get_entity_node_from_record(record) for record in records]
|
|
@@ -406,22 +371,23 @@ async def node_similarity_search(
|
|
|
406
371
|
WITH n, """
|
|
407
372
|
+ get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
|
|
408
373
|
+ """ AS score
|
|
409
|
-
WHERE score > $min_score
|
|
374
|
+
WHERE score > $min_score
|
|
375
|
+
RETURN
|
|
376
|
+
"""
|
|
410
377
|
+ ENTITY_NODE_RETURN
|
|
411
378
|
+ """
|
|
412
379
|
ORDER BY score DESC
|
|
413
380
|
LIMIT $limit
|
|
414
|
-
|
|
381
|
+
"""
|
|
415
382
|
)
|
|
416
383
|
|
|
417
|
-
records,
|
|
384
|
+
records, _, _ = await driver.execute_query(
|
|
418
385
|
query,
|
|
419
|
-
params=query_params,
|
|
420
386
|
search_vector=search_vector,
|
|
421
|
-
group_ids=group_ids,
|
|
422
387
|
limit=limit,
|
|
423
388
|
min_score=min_score,
|
|
424
389
|
routing_='r',
|
|
390
|
+
**query_params,
|
|
425
391
|
)
|
|
426
392
|
|
|
427
393
|
nodes = [get_entity_node_from_record(record) for record in records]
|
|
@@ -444,26 +410,29 @@ async def node_bfs_search(
|
|
|
444
410
|
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
|
445
411
|
|
|
446
412
|
query = (
|
|
447
|
-
"""
|
|
413
|
+
f"""
|
|
448
414
|
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
449
|
-
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->
|
|
415
|
+
MATCH (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
|
|
450
416
|
WHERE n.group_id = origin.group_id
|
|
451
417
|
AND origin.group_id IN $group_ids
|
|
452
418
|
"""
|
|
453
419
|
+ filter_query
|
|
420
|
+
+ """
|
|
421
|
+
RETURN
|
|
422
|
+
"""
|
|
454
423
|
+ ENTITY_NODE_RETURN
|
|
455
424
|
+ """
|
|
456
425
|
LIMIT $limit
|
|
457
426
|
"""
|
|
458
427
|
)
|
|
428
|
+
|
|
459
429
|
records, _, _ = await driver.execute_query(
|
|
460
430
|
query,
|
|
461
|
-
params=filter_params,
|
|
462
431
|
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
|
463
|
-
depth=bfs_max_depth,
|
|
464
432
|
group_ids=group_ids,
|
|
465
433
|
limit=limit,
|
|
466
434
|
routing_='r',
|
|
435
|
+
**filter_params,
|
|
467
436
|
)
|
|
468
437
|
nodes = [get_entity_node_from_record(record) for record in records]
|
|
469
438
|
|
|
@@ -489,16 +458,10 @@ async def episode_fulltext_search(
|
|
|
489
458
|
MATCH (e:Episodic)
|
|
490
459
|
WHERE e.uuid = episode.uuid
|
|
491
460
|
AND e.group_id IN $group_ids
|
|
492
|
-
RETURN
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
e.uuid AS uuid,
|
|
497
|
-
e.name AS name,
|
|
498
|
-
e.group_id AS group_id,
|
|
499
|
-
e.source_description AS source_description,
|
|
500
|
-
e.source AS source,
|
|
501
|
-
e.entity_edges AS entity_edges
|
|
461
|
+
RETURN
|
|
462
|
+
"""
|
|
463
|
+
+ EPISODIC_NODE_RETURN
|
|
464
|
+
+ """
|
|
502
465
|
ORDER BY score DESC
|
|
503
466
|
LIMIT $limit
|
|
504
467
|
"""
|
|
@@ -530,15 +493,12 @@ async def community_fulltext_search(
|
|
|
530
493
|
query = (
|
|
531
494
|
get_nodes_query(driver.provider, 'community_name', '$query')
|
|
532
495
|
+ """
|
|
533
|
-
YIELD node AS
|
|
534
|
-
WHERE
|
|
496
|
+
YIELD node AS n, score
|
|
497
|
+
WHERE n.group_id IN $group_ids
|
|
535
498
|
RETURN
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
comm.created_at AS created_at,
|
|
540
|
-
comm.summary AS summary,
|
|
541
|
-
comm.name_embedding AS name_embedding
|
|
499
|
+
"""
|
|
500
|
+
+ COMMUNITY_NODE_RETURN
|
|
501
|
+
+ """
|
|
542
502
|
ORDER BY score DESC
|
|
543
503
|
LIMIT $limit
|
|
544
504
|
"""
|
|
@@ -568,39 +528,37 @@ async def community_similarity_search(
|
|
|
568
528
|
|
|
569
529
|
group_filter_query: LiteralString = ''
|
|
570
530
|
if group_ids is not None:
|
|
571
|
-
group_filter_query += 'WHERE
|
|
531
|
+
group_filter_query += 'WHERE n.group_id IN $group_ids'
|
|
572
532
|
query_params['group_ids'] = group_ids
|
|
573
533
|
|
|
574
534
|
query = (
|
|
575
535
|
RUNTIME_QUERY
|
|
576
536
|
+ """
|
|
577
|
-
|
|
578
|
-
|
|
537
|
+
MATCH (n:Community)
|
|
538
|
+
"""
|
|
579
539
|
+ group_filter_query
|
|
580
540
|
+ """
|
|
581
|
-
|
|
582
|
-
|
|
541
|
+
WITH n,
|
|
542
|
+
"""
|
|
543
|
+
+ get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
|
|
583
544
|
+ """ AS score
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
comm.name_embedding AS name_embedding
|
|
592
|
-
ORDER BY score DESC
|
|
593
|
-
LIMIT $limit
|
|
545
|
+
WHERE score > $min_score
|
|
546
|
+
RETURN
|
|
547
|
+
"""
|
|
548
|
+
+ COMMUNITY_NODE_RETURN
|
|
549
|
+
+ """
|
|
550
|
+
ORDER BY score DESC
|
|
551
|
+
LIMIT $limit
|
|
594
552
|
"""
|
|
595
553
|
)
|
|
596
554
|
|
|
597
555
|
records, _, _ = await driver.execute_query(
|
|
598
556
|
query,
|
|
599
557
|
search_vector=search_vector,
|
|
600
|
-
group_ids=group_ids,
|
|
601
558
|
limit=limit,
|
|
602
559
|
min_score=min_score,
|
|
603
560
|
routing_='r',
|
|
561
|
+
**query_params,
|
|
604
562
|
)
|
|
605
563
|
communities = [get_community_node_from_record(record) for record in records]
|
|
606
564
|
|
|
@@ -672,7 +630,7 @@ async def hybrid_node_search(
|
|
|
672
630
|
}
|
|
673
631
|
result_uuids = [[node.uuid for node in result] for result in results]
|
|
674
632
|
|
|
675
|
-
ranked_uuids = rrf(result_uuids)
|
|
633
|
+
ranked_uuids, _ = rrf(result_uuids)
|
|
676
634
|
|
|
677
635
|
relevant_nodes: list[EntityNode] = [node_uuid_map[uuid] for uuid in ranked_uuids]
|
|
678
636
|
|
|
@@ -719,8 +677,8 @@ async def get_relevant_nodes(
|
|
|
719
677
|
WHERE m.group_id = $group_id
|
|
720
678
|
WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
|
|
721
679
|
|
|
722
|
-
WITH node,
|
|
723
|
-
top_vector_nodes,
|
|
680
|
+
WITH node,
|
|
681
|
+
top_vector_nodes,
|
|
724
682
|
[m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
|
|
725
683
|
|
|
726
684
|
WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
|
|
@@ -728,10 +686,10 @@ async def get_relevant_nodes(
|
|
|
728
686
|
UNWIND combined_nodes AS combined_node
|
|
729
687
|
WITH node, collect(DISTINCT combined_node) AS deduped_nodes
|
|
730
688
|
|
|
731
|
-
RETURN
|
|
689
|
+
RETURN
|
|
732
690
|
node.uuid AS search_node_uuid,
|
|
733
691
|
[x IN deduped_nodes | {
|
|
734
|
-
uuid: x.uuid,
|
|
692
|
+
uuid: x.uuid,
|
|
735
693
|
name: x.name,
|
|
736
694
|
name_embedding: x.name_embedding,
|
|
737
695
|
group_id: x.group_id,
|
|
@@ -755,12 +713,12 @@ async def get_relevant_nodes(
|
|
|
755
713
|
|
|
756
714
|
results, _, _ = await driver.execute_query(
|
|
757
715
|
query,
|
|
758
|
-
params=query_params,
|
|
759
716
|
nodes=query_nodes,
|
|
760
717
|
group_id=group_id,
|
|
761
718
|
limit=limit,
|
|
762
719
|
min_score=min_score,
|
|
763
720
|
routing_='r',
|
|
721
|
+
**query_params,
|
|
764
722
|
)
|
|
765
723
|
|
|
766
724
|
relevant_nodes_dict: dict[str, list[EntityNode]] = {
|
|
@@ -825,11 +783,11 @@ async def get_relevant_edges(
|
|
|
825
783
|
|
|
826
784
|
results, _, _ = await driver.execute_query(
|
|
827
785
|
query,
|
|
828
|
-
params=query_params,
|
|
829
786
|
edges=[edge.model_dump() for edge in edges],
|
|
830
787
|
limit=limit,
|
|
831
788
|
min_score=min_score,
|
|
832
789
|
routing_='r',
|
|
790
|
+
**query_params,
|
|
833
791
|
)
|
|
834
792
|
|
|
835
793
|
relevant_edges_dict: dict[str, list[EntityEdge]] = {
|
|
@@ -895,11 +853,11 @@ async def get_edge_invalidation_candidates(
|
|
|
895
853
|
|
|
896
854
|
results, _, _ = await driver.execute_query(
|
|
897
855
|
query,
|
|
898
|
-
params=query_params,
|
|
899
856
|
edges=[edge.model_dump() for edge in edges],
|
|
900
857
|
limit=limit,
|
|
901
858
|
min_score=min_score,
|
|
902
859
|
routing_='r',
|
|
860
|
+
**query_params,
|
|
903
861
|
)
|
|
904
862
|
invalidation_edges_dict: dict[str, list[EntityEdge]] = {
|
|
905
863
|
result['search_edge_uuid']: [
|
|
@@ -914,7 +872,9 @@ async def get_edge_invalidation_candidates(
|
|
|
914
872
|
|
|
915
873
|
|
|
916
874
|
# takes in a list of rankings of uuids
|
|
917
|
-
def rrf(
|
|
875
|
+
def rrf(
|
|
876
|
+
results: list[list[str]], rank_const=1, min_score: float = 0
|
|
877
|
+
) -> tuple[list[str], list[float]]:
|
|
918
878
|
scores: dict[str, float] = defaultdict(float)
|
|
919
879
|
for result in results:
|
|
920
880
|
for i, uuid in enumerate(result):
|
|
@@ -925,7 +885,9 @@ def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[st
|
|
|
925
885
|
|
|
926
886
|
sorted_uuids = [term[0] for term in scored_uuids]
|
|
927
887
|
|
|
928
|
-
return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
|
|
888
|
+
return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [
|
|
889
|
+
scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score
|
|
890
|
+
]
|
|
929
891
|
|
|
930
892
|
|
|
931
893
|
async def node_distance_reranker(
|
|
@@ -933,24 +895,23 @@ async def node_distance_reranker(
|
|
|
933
895
|
node_uuids: list[str],
|
|
934
896
|
center_node_uuid: str,
|
|
935
897
|
min_score: float = 0,
|
|
936
|
-
) -> list[str]:
|
|
898
|
+
) -> tuple[list[str], list[float]]:
|
|
937
899
|
# filter out node_uuid center node node uuid
|
|
938
900
|
filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
|
|
939
901
|
scores: dict[str, float] = {center_node_uuid: 0.0}
|
|
940
902
|
|
|
941
903
|
# Find the shortest path to center node
|
|
942
|
-
|
|
904
|
+
results, header, _ = await driver.execute_query(
|
|
905
|
+
"""
|
|
943
906
|
UNWIND $node_uuids AS node_uuid
|
|
944
907
|
MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
|
|
945
908
|
RETURN 1 AS score, node_uuid AS uuid
|
|
946
|
-
"""
|
|
947
|
-
results, header, _ = await driver.execute_query(
|
|
948
|
-
query,
|
|
909
|
+
""",
|
|
949
910
|
node_uuids=filtered_uuids,
|
|
950
911
|
center_uuid=center_node_uuid,
|
|
951
912
|
routing_='r',
|
|
952
913
|
)
|
|
953
|
-
if driver.provider ==
|
|
914
|
+
if driver.provider == GraphProvider.FALKORDB:
|
|
954
915
|
results = [dict(zip(header, row, strict=True)) for row in results]
|
|
955
916
|
|
|
956
917
|
for result in results:
|
|
@@ -970,24 +931,25 @@ async def node_distance_reranker(
|
|
|
970
931
|
scores[center_node_uuid] = 0.1
|
|
971
932
|
filtered_uuids = [center_node_uuid] + filtered_uuids
|
|
972
933
|
|
|
973
|
-
return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score]
|
|
934
|
+
return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score], [
|
|
935
|
+
1 / scores[uuid] for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score
|
|
936
|
+
]
|
|
974
937
|
|
|
975
938
|
|
|
976
939
|
async def episode_mentions_reranker(
|
|
977
940
|
driver: GraphDriver, node_uuids: list[list[str]], min_score: float = 0
|
|
978
|
-
) -> list[str]:
|
|
941
|
+
) -> tuple[list[str], list[float]]:
|
|
979
942
|
# use rrf as a preliminary ranker
|
|
980
|
-
sorted_uuids = rrf(node_uuids)
|
|
943
|
+
sorted_uuids, _ = rrf(node_uuids)
|
|
981
944
|
scores: dict[str, float] = {}
|
|
982
945
|
|
|
983
946
|
# Find the shortest path to center node
|
|
984
|
-
|
|
985
|
-
|
|
947
|
+
results, _, _ = await driver.execute_query(
|
|
948
|
+
"""
|
|
949
|
+
UNWIND $node_uuids AS node_uuid
|
|
986
950
|
MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
|
|
987
951
|
RETURN count(*) AS score, n.uuid AS uuid
|
|
988
|
-
"""
|
|
989
|
-
results, _, _ = await driver.execute_query(
|
|
990
|
-
query,
|
|
952
|
+
""",
|
|
991
953
|
node_uuids=sorted_uuids,
|
|
992
954
|
routing_='r',
|
|
993
955
|
)
|
|
@@ -998,7 +960,9 @@ async def episode_mentions_reranker(
|
|
|
998
960
|
# rerank on shortest distance
|
|
999
961
|
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
|
1000
962
|
|
|
1001
|
-
return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
|
|
963
|
+
return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [
|
|
964
|
+
scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score
|
|
965
|
+
]
|
|
1002
966
|
|
|
1003
967
|
|
|
1004
968
|
def maximal_marginal_relevance(
|
|
@@ -1006,7 +970,7 @@ def maximal_marginal_relevance(
|
|
|
1006
970
|
candidates: dict[str, list[float]],
|
|
1007
971
|
mmr_lambda: float = DEFAULT_MMR_LAMBDA,
|
|
1008
972
|
min_score: float = -2.0,
|
|
1009
|
-
) -> list[str]:
|
|
973
|
+
) -> tuple[list[str], list[float]]:
|
|
1010
974
|
start = time()
|
|
1011
975
|
query_array = np.array(query_vector)
|
|
1012
976
|
candidate_arrays: dict[str, NDArray] = {}
|
|
@@ -1037,21 +1001,24 @@ def maximal_marginal_relevance(
|
|
|
1037
1001
|
end = time()
|
|
1038
1002
|
logger.debug(f'Completed MMR reranking in {(end - start) * 1000} ms')
|
|
1039
1003
|
|
|
1040
|
-
return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score]
|
|
1004
|
+
return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score], [
|
|
1005
|
+
mmr_scores[uuid] for uuid in uuids if mmr_scores[uuid] >= min_score
|
|
1006
|
+
]
|
|
1041
1007
|
|
|
1042
1008
|
|
|
1043
1009
|
async def get_embeddings_for_nodes(
|
|
1044
1010
|
driver: GraphDriver, nodes: list[EntityNode]
|
|
1045
1011
|
) -> dict[str, list[float]]:
|
|
1046
|
-
query: LiteralString = """MATCH (n:Entity)
|
|
1047
|
-
WHERE n.uuid IN $node_uuids
|
|
1048
|
-
RETURN DISTINCT
|
|
1049
|
-
n.uuid AS uuid,
|
|
1050
|
-
n.name_embedding AS name_embedding
|
|
1051
|
-
"""
|
|
1052
|
-
|
|
1053
1012
|
results, _, _ = await driver.execute_query(
|
|
1054
|
-
|
|
1013
|
+
"""
|
|
1014
|
+
MATCH (n:Entity)
|
|
1015
|
+
WHERE n.uuid IN $node_uuids
|
|
1016
|
+
RETURN DISTINCT
|
|
1017
|
+
n.uuid AS uuid,
|
|
1018
|
+
n.name_embedding AS name_embedding
|
|
1019
|
+
""",
|
|
1020
|
+
node_uuids=[node.uuid for node in nodes],
|
|
1021
|
+
routing_='r',
|
|
1055
1022
|
)
|
|
1056
1023
|
|
|
1057
1024
|
embeddings_dict: dict[str, list[float]] = {}
|
|
@@ -1067,15 +1034,14 @@ async def get_embeddings_for_nodes(
|
|
|
1067
1034
|
async def get_embeddings_for_communities(
|
|
1068
1035
|
driver: GraphDriver, communities: list[CommunityNode]
|
|
1069
1036
|
) -> dict[str, list[float]]:
|
|
1070
|
-
query: LiteralString = """MATCH (c:Community)
|
|
1071
|
-
WHERE c.uuid IN $community_uuids
|
|
1072
|
-
RETURN DISTINCT
|
|
1073
|
-
c.uuid AS uuid,
|
|
1074
|
-
c.name_embedding AS name_embedding
|
|
1075
|
-
"""
|
|
1076
|
-
|
|
1077
1037
|
results, _, _ = await driver.execute_query(
|
|
1078
|
-
|
|
1038
|
+
"""
|
|
1039
|
+
MATCH (c:Community)
|
|
1040
|
+
WHERE c.uuid IN $community_uuids
|
|
1041
|
+
RETURN DISTINCT
|
|
1042
|
+
c.uuid AS uuid,
|
|
1043
|
+
c.name_embedding AS name_embedding
|
|
1044
|
+
""",
|
|
1079
1045
|
community_uuids=[community.uuid for community in communities],
|
|
1080
1046
|
routing_='r',
|
|
1081
1047
|
)
|
|
@@ -1093,15 +1059,14 @@ async def get_embeddings_for_communities(
|
|
|
1093
1059
|
async def get_embeddings_for_edges(
|
|
1094
1060
|
driver: GraphDriver, edges: list[EntityEdge]
|
|
1095
1061
|
) -> dict[str, list[float]]:
|
|
1096
|
-
query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
|
1097
|
-
WHERE e.uuid IN $edge_uuids
|
|
1098
|
-
RETURN DISTINCT
|
|
1099
|
-
e.uuid AS uuid,
|
|
1100
|
-
e.fact_embedding AS fact_embedding
|
|
1101
|
-
"""
|
|
1102
|
-
|
|
1103
1062
|
results, _, _ = await driver.execute_query(
|
|
1104
|
-
|
|
1063
|
+
"""
|
|
1064
|
+
MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
|
1065
|
+
WHERE e.uuid IN $edge_uuids
|
|
1066
|
+
RETURN DISTINCT
|
|
1067
|
+
e.uuid AS uuid,
|
|
1068
|
+
e.fact_embedding AS fact_embedding
|
|
1069
|
+
""",
|
|
1105
1070
|
edge_uuids=[edge.uuid for edge in edges],
|
|
1106
1071
|
routing_='r',
|
|
1107
1072
|
)
|