graphiti-core 0.18.8__py3-none-any.whl → 0.19.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/driver/driver.py +1 -0
- graphiti_core/driver/neptune_driver.py +299 -0
- graphiti_core/edges.py +35 -7
- graphiti_core/graphiti.py +2 -0
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/openai_base_client.py +15 -5
- graphiti_core/llm_client/openai_client.py +16 -6
- graphiti_core/migrations/__init__.py +0 -0
- graphiti_core/migrations/neo4j_node_group_labels.py +53 -0
- graphiti_core/models/edges/edge_db_queries.py +104 -54
- graphiti_core/models/nodes/node_db_queries.py +165 -65
- graphiti_core/nodes.py +121 -51
- graphiti_core/prompts/extract_edges.py +1 -0
- graphiti_core/prompts/extract_nodes.py +1 -1
- graphiti_core/search/search_utils.py +878 -267
- graphiti_core/utils/bulk_utils.py +6 -3
- graphiti_core/utils/maintenance/edge_operations.py +36 -13
- graphiti_core/utils/maintenance/graph_data_operations.py +59 -7
- graphiti_core/utils/maintenance/node_operations.py +7 -3
- {graphiti_core-0.18.8.dist-info → graphiti_core-0.19.0rc1.dist-info}/METADATA +44 -6
- {graphiti_core-0.18.8.dist-info → graphiti_core-0.19.0rc1.dist-info}/RECORD +23 -20
- {graphiti_core-0.18.8.dist-info → graphiti_core-0.19.0rc1.dist-info}/WHEEL +0 -0
- {graphiti_core-0.18.8.dist-info → graphiti_core-0.19.0rc1.dist-info}/licenses/LICENSE +0 -0
|
@@ -15,6 +15,7 @@ limitations under the License.
|
|
|
15
15
|
"""
|
|
16
16
|
|
|
17
17
|
import logging
|
|
18
|
+
import os
|
|
18
19
|
from collections import defaultdict
|
|
19
20
|
from time import time
|
|
20
21
|
from typing import Any
|
|
@@ -54,6 +55,7 @@ from graphiti_core.search.search_filters import (
|
|
|
54
55
|
)
|
|
55
56
|
|
|
56
57
|
logger = logging.getLogger(__name__)
|
|
58
|
+
USE_HNSW = os.getenv('USE_HNSW', '').lower() in ('true', '1', 'yes')
|
|
57
59
|
|
|
58
60
|
RELEVANT_SCHEMA_LIMIT = 10
|
|
59
61
|
DEFAULT_MIN_SCORE = 0.6
|
|
@@ -62,6 +64,20 @@ MAX_SEARCH_DEPTH = 3
|
|
|
62
64
|
MAX_QUERY_LENGTH = 128
|
|
63
65
|
|
|
64
66
|
|
|
67
|
+
def calculate_cosine_similarity(vector1: list[float], vector2: list[float]) -> float:
    """
    Calculate the cosine similarity between two vectors using NumPy.

    Args:
        vector1: First vector, as a sequence of numbers.
        vector2: Second vector, as a sequence of numbers.

    Returns:
        The dot product of the two vectors divided by the product of their
        Euclidean norms, as a builtin ``float``. Returns ``0.0`` when either
        vector has zero norm, since the similarity is undefined in that case.
    """
    norm_vector1 = np.linalg.norm(vector1)
    norm_vector2 = np.linalg.norm(vector2)

    # A zero-length vector has no direction; avoid dividing by zero.
    if norm_vector1 == 0 or norm_vector2 == 0:
        return 0.0

    # Convert the NumPy scalar to a builtin float so the result matches the
    # annotated return type on every path.
    return float(np.dot(vector1, vector2) / (norm_vector1 * norm_vector2))
|
|
79
|
+
|
|
80
|
+
|
|
65
81
|
def fulltext_query(query: str, group_ids: list[str] | None = None, fulltext_syntax: str = ''):
|
|
66
82
|
group_ids_filter_list = (
|
|
67
83
|
[fulltext_syntax + f'group_id:"{g}"' for g in group_ids] if group_ids is not None else []
|
|
@@ -153,32 +169,80 @@ async def edge_fulltext_search(
|
|
|
153
169
|
|
|
154
170
|
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
|
155
171
|
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
172
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
173
|
+
res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue
|
|
174
|
+
if res['hits']['total']['value'] > 0:
|
|
175
|
+
# Calculate Cosine similarity then return the edge ids
|
|
176
|
+
input_ids = []
|
|
177
|
+
for r in res['hits']['hits']:
|
|
178
|
+
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
|
179
|
+
|
|
180
|
+
# Match the edge ids and return the values
|
|
181
|
+
query = (
|
|
182
|
+
"""
|
|
183
|
+
UNWIND $ids as id
|
|
184
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
185
|
+
WHERE e.group_id IN $group_ids
|
|
186
|
+
AND id(e)=id
|
|
187
|
+
"""
|
|
188
|
+
+ filter_query
|
|
189
|
+
+ """
|
|
190
|
+
WITH e, id.score as score, startNode(e) AS n, endNode(e) AS m
|
|
191
|
+
RETURN
|
|
192
|
+
e.uuid AS uuid,
|
|
193
|
+
e.group_id AS group_id,
|
|
194
|
+
n.uuid AS source_node_uuid,
|
|
195
|
+
m.uuid AS target_node_uuid,
|
|
196
|
+
e.created_at AS created_at,
|
|
197
|
+
e.name AS name,
|
|
198
|
+
e.fact AS fact,
|
|
199
|
+
split(e.episodes, ",") AS episodes,
|
|
200
|
+
e.expired_at AS expired_at,
|
|
201
|
+
e.valid_at AS valid_at,
|
|
202
|
+
e.invalid_at AS invalid_at,
|
|
203
|
+
properties(e) AS attributes
|
|
204
|
+
ORDER BY score DESC LIMIT $limit
|
|
205
|
+
"""
|
|
206
|
+
)
|
|
207
|
+
|
|
208
|
+
records, _, _ = await driver.execute_query(
|
|
209
|
+
query,
|
|
210
|
+
query=fuzzy_query,
|
|
211
|
+
group_ids=group_ids,
|
|
212
|
+
ids=input_ids,
|
|
213
|
+
limit=limit,
|
|
214
|
+
routing_='r',
|
|
215
|
+
**filter_params,
|
|
216
|
+
)
|
|
217
|
+
else:
|
|
218
|
+
return []
|
|
219
|
+
else:
|
|
220
|
+
query = (
|
|
221
|
+
get_relationships_query('edge_name_and_fact', provider=driver.provider)
|
|
222
|
+
+ """
|
|
223
|
+
YIELD relationship AS rel, score
|
|
224
|
+
MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
|
|
225
|
+
WHERE e.group_id IN $group_ids """
|
|
226
|
+
+ filter_query
|
|
227
|
+
+ """
|
|
228
|
+
WITH e, score, n, m
|
|
229
|
+
RETURN
|
|
230
|
+
"""
|
|
231
|
+
+ ENTITY_EDGE_RETURN
|
|
232
|
+
+ """
|
|
233
|
+
ORDER BY score DESC
|
|
234
|
+
LIMIT $limit
|
|
235
|
+
"""
|
|
236
|
+
)
|
|
173
237
|
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
238
|
+
records, _, _ = await driver.execute_query(
|
|
239
|
+
query,
|
|
240
|
+
query=fuzzy_query,
|
|
241
|
+
group_ids=group_ids,
|
|
242
|
+
limit=limit,
|
|
243
|
+
routing_='r',
|
|
244
|
+
**filter_params,
|
|
245
|
+
)
|
|
182
246
|
|
|
183
247
|
edges = [get_entity_edge_from_record(record) for record in records]
|
|
184
248
|
|
|
@@ -214,35 +278,100 @@ async def edge_similarity_search(
|
|
|
214
278
|
query_params['target_uuid'] = target_node_uuid
|
|
215
279
|
group_filter_query += '\nAND (m.uuid = $target_uuid)'
|
|
216
280
|
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
281
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
282
|
+
query = (
|
|
283
|
+
RUNTIME_QUERY
|
|
284
|
+
+ """
|
|
285
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
286
|
+
"""
|
|
287
|
+
+ group_filter_query
|
|
288
|
+
+ filter_query
|
|
289
|
+
+ """
|
|
290
|
+
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
|
|
291
|
+
"""
|
|
292
|
+
)
|
|
293
|
+
resp, header, _ = await driver.execute_query(
|
|
294
|
+
query,
|
|
295
|
+
search_vector=search_vector,
|
|
296
|
+
limit=limit,
|
|
297
|
+
min_score=min_score,
|
|
298
|
+
routing_='r',
|
|
299
|
+
**query_params,
|
|
300
|
+
)
|
|
237
301
|
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
302
|
+
if len(resp) > 0:
|
|
303
|
+
# Calculate Cosine similarity then return the edge ids
|
|
304
|
+
input_ids = []
|
|
305
|
+
for r in resp:
|
|
306
|
+
if r['embedding']:
|
|
307
|
+
score = calculate_cosine_similarity(
|
|
308
|
+
search_vector, list(map(float, r['embedding'].split(',')))
|
|
309
|
+
)
|
|
310
|
+
if score > min_score:
|
|
311
|
+
input_ids.append({'id': r['id'], 'score': score})
|
|
312
|
+
|
|
313
|
+
# Match the edge ides and return the values
|
|
314
|
+
query = """
|
|
315
|
+
UNWIND $ids as i
|
|
316
|
+
MATCH ()-[r]->()
|
|
317
|
+
WHERE id(r) = i.id
|
|
318
|
+
RETURN
|
|
319
|
+
r.uuid AS uuid,
|
|
320
|
+
r.group_id AS group_id,
|
|
321
|
+
startNode(r).uuid AS source_node_uuid,
|
|
322
|
+
endNode(r).uuid AS target_node_uuid,
|
|
323
|
+
r.created_at AS created_at,
|
|
324
|
+
r.name AS name,
|
|
325
|
+
r.fact AS fact,
|
|
326
|
+
split(r.episodes, ",") AS episodes,
|
|
327
|
+
r.expired_at AS expired_at,
|
|
328
|
+
r.valid_at AS valid_at,
|
|
329
|
+
r.invalid_at AS invalid_at,
|
|
330
|
+
properties(r) AS attributes
|
|
331
|
+
ORDER BY i.score DESC
|
|
332
|
+
LIMIT $limit
|
|
333
|
+
"""
|
|
334
|
+
records, _, _ = await driver.execute_query(
|
|
335
|
+
query,
|
|
336
|
+
ids=input_ids,
|
|
337
|
+
search_vector=search_vector,
|
|
338
|
+
limit=limit,
|
|
339
|
+
min_score=min_score,
|
|
340
|
+
routing_='r',
|
|
341
|
+
**query_params,
|
|
342
|
+
)
|
|
343
|
+
else:
|
|
344
|
+
return []
|
|
345
|
+
else:
|
|
346
|
+
query = (
|
|
347
|
+
RUNTIME_QUERY
|
|
348
|
+
+ """
|
|
349
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
350
|
+
"""
|
|
351
|
+
+ group_filter_query
|
|
352
|
+
+ filter_query
|
|
353
|
+
+ """
|
|
354
|
+
WITH DISTINCT e, n, m, """
|
|
355
|
+
+ get_vector_cosine_func_query('e.fact_embedding', '$search_vector', driver.provider)
|
|
356
|
+
+ """ AS score
|
|
357
|
+
WHERE score > $min_score
|
|
358
|
+
RETURN
|
|
359
|
+
"""
|
|
360
|
+
+ ENTITY_EDGE_RETURN
|
|
361
|
+
+ """
|
|
362
|
+
ORDER BY score DESC
|
|
363
|
+
LIMIT $limit
|
|
364
|
+
"""
|
|
365
|
+
)
|
|
366
|
+
|
|
367
|
+
records, _, _ = await driver.execute_query(
|
|
368
|
+
query,
|
|
369
|
+
search_vector=search_vector,
|
|
370
|
+
limit=limit,
|
|
371
|
+
min_score=min_score,
|
|
372
|
+
routing_='r',
|
|
373
|
+
**query_params,
|
|
374
|
+
)
|
|
246
375
|
|
|
247
376
|
edges = [get_entity_edge_from_record(record) for record in records]
|
|
248
377
|
|
|
@@ -263,28 +392,58 @@ async def edge_bfs_search(
|
|
|
263
392
|
|
|
264
393
|
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
|
265
394
|
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
395
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
396
|
+
query = (
|
|
397
|
+
f"""
|
|
398
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
399
|
+
MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS *1..{bfs_max_depth}]->(n:Entity)
|
|
400
|
+
WHERE origin:Entity OR origin:Episodic
|
|
401
|
+
UNWIND relationships(path) AS rel
|
|
402
|
+
MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
|
403
|
+
WHERE e.uuid = rel.uuid
|
|
404
|
+
"""
|
|
405
|
+
+ filter_query
|
|
406
|
+
+ """
|
|
407
|
+
RETURN DISTINCT
|
|
408
|
+
e.uuid AS uuid,
|
|
409
|
+
e.group_id AS group_id,
|
|
410
|
+
startNode(e).uuid AS source_node_uuid,
|
|
411
|
+
endNode(e).uuid AS target_node_uuid,
|
|
412
|
+
e.created_at AS created_at,
|
|
413
|
+
e.name AS name,
|
|
414
|
+
e.fact AS fact,
|
|
415
|
+
split(e.episodes, ',') AS episodes,
|
|
416
|
+
e.expired_at AS expired_at,
|
|
417
|
+
e.valid_at AS valid_at,
|
|
418
|
+
e.invalid_at AS invalid_at,
|
|
419
|
+
properties(e) AS attributes
|
|
420
|
+
LIMIT $limit
|
|
421
|
+
"""
|
|
422
|
+
)
|
|
423
|
+
else:
|
|
424
|
+
query = (
|
|
425
|
+
f"""
|
|
426
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
427
|
+
MATCH path = (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
|
|
428
|
+
UNWIND relationships(path) AS rel
|
|
429
|
+
MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
|
430
|
+
WHERE e.uuid = rel.uuid
|
|
431
|
+
AND e.group_id IN $group_ids
|
|
432
|
+
"""
|
|
433
|
+
+ filter_query
|
|
434
|
+
+ """
|
|
435
|
+
RETURN DISTINCT
|
|
436
|
+
"""
|
|
437
|
+
+ ENTITY_EDGE_RETURN
|
|
438
|
+
+ """
|
|
439
|
+
LIMIT $limit
|
|
440
|
+
"""
|
|
441
|
+
)
|
|
284
442
|
|
|
285
443
|
records, _, _ = await driver.execute_query(
|
|
286
444
|
query,
|
|
287
445
|
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
|
446
|
+
depth=bfs_max_depth,
|
|
288
447
|
group_ids=group_ids,
|
|
289
448
|
limit=limit,
|
|
290
449
|
routing_='r',
|
|
@@ -309,30 +468,70 @@ async def node_fulltext_search(
|
|
|
309
468
|
return []
|
|
310
469
|
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
|
311
470
|
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
471
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
472
|
+
res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
|
473
|
+
if res['hits']['total']['value'] > 0:
|
|
474
|
+
# Calculate Cosine similarity then return the edge ids
|
|
475
|
+
input_ids = []
|
|
476
|
+
for r in res['hits']['hits']:
|
|
477
|
+
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
|
478
|
+
|
|
479
|
+
# Match the edge ides and return the values
|
|
480
|
+
query = (
|
|
481
|
+
"""
|
|
482
|
+
UNWIND $ids as i
|
|
483
|
+
MATCH (n:Entity)
|
|
484
|
+
WHERE n.uuid=i.id
|
|
485
|
+
RETURN
|
|
486
|
+
"""
|
|
487
|
+
+ ENTITY_NODE_RETURN
|
|
488
|
+
+ """
|
|
489
|
+
ORDER BY i.score DESC
|
|
490
|
+
LIMIT $limit
|
|
491
|
+
"""
|
|
492
|
+
)
|
|
493
|
+
records, _, _ = await driver.execute_query(
|
|
494
|
+
query,
|
|
495
|
+
ids=input_ids,
|
|
496
|
+
query=fuzzy_query,
|
|
497
|
+
group_ids=group_ids,
|
|
498
|
+
limit=limit,
|
|
499
|
+
routing_='r',
|
|
500
|
+
**filter_params,
|
|
501
|
+
)
|
|
502
|
+
else:
|
|
503
|
+
return []
|
|
504
|
+
else:
|
|
505
|
+
index_name = (
|
|
506
|
+
'node_name_and_summary'
|
|
507
|
+
if not USE_HNSW
|
|
508
|
+
else 'node_name_and_summary_'
|
|
509
|
+
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
|
|
510
|
+
)
|
|
511
|
+
query = (
|
|
512
|
+
get_nodes_query(driver.provider, index_name, '$query')
|
|
513
|
+
+ """
|
|
514
|
+
YIELD node AS n, score
|
|
515
|
+
WHERE n:Entity AND n.group_id IN $group_ids
|
|
516
|
+
"""
|
|
517
|
+
+ filter_query
|
|
518
|
+
+ """
|
|
519
|
+
WITH n, score
|
|
520
|
+
ORDER BY score DESC
|
|
521
|
+
LIMIT $limit
|
|
522
|
+
RETURN
|
|
523
|
+
"""
|
|
524
|
+
+ ENTITY_NODE_RETURN
|
|
525
|
+
)
|
|
327
526
|
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
527
|
+
records, _, _ = await driver.execute_query(
|
|
528
|
+
query,
|
|
529
|
+
query=fuzzy_query,
|
|
530
|
+
group_ids=group_ids,
|
|
531
|
+
limit=limit,
|
|
532
|
+
routing_='r',
|
|
533
|
+
**filter_params,
|
|
534
|
+
)
|
|
336
535
|
|
|
337
536
|
nodes = [get_entity_node_from_record(record) for record in records]
|
|
338
537
|
|
|
@@ -358,35 +557,124 @@ async def node_similarity_search(
|
|
|
358
557
|
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
|
359
558
|
query_params.update(filter_params)
|
|
360
559
|
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
560
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
561
|
+
query = (
|
|
562
|
+
RUNTIME_QUERY
|
|
563
|
+
+ """
|
|
564
|
+
MATCH (n:Entity)
|
|
565
|
+
"""
|
|
566
|
+
+ group_filter_query
|
|
567
|
+
+ filter_query
|
|
568
|
+
+ """
|
|
569
|
+
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
|
570
|
+
"""
|
|
571
|
+
)
|
|
572
|
+
resp, header, _ = await driver.execute_query(
|
|
573
|
+
query,
|
|
574
|
+
params=query_params,
|
|
575
|
+
search_vector=search_vector,
|
|
576
|
+
group_ids=group_ids,
|
|
577
|
+
limit=limit,
|
|
578
|
+
min_score=min_score,
|
|
579
|
+
routing_='r',
|
|
580
|
+
)
|
|
381
581
|
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
582
|
+
if len(resp) > 0:
|
|
583
|
+
# Calculate Cosine similarity then return the edge ids
|
|
584
|
+
input_ids = []
|
|
585
|
+
for r in resp:
|
|
586
|
+
if r['embedding']:
|
|
587
|
+
score = calculate_cosine_similarity(
|
|
588
|
+
search_vector, list(map(float, r['embedding'].split(',')))
|
|
589
|
+
)
|
|
590
|
+
if score > min_score:
|
|
591
|
+
input_ids.append({'id': r['id'], 'score': score})
|
|
592
|
+
|
|
593
|
+
# Match the edge ides and return the values
|
|
594
|
+
query = (
|
|
595
|
+
"""
|
|
596
|
+
UNWIND $ids as i
|
|
597
|
+
MATCH (n:Entity)
|
|
598
|
+
WHERE id(n)=i.id
|
|
599
|
+
RETURN
|
|
600
|
+
"""
|
|
601
|
+
+ ENTITY_NODE_RETURN
|
|
602
|
+
+ """
|
|
603
|
+
ORDER BY i.score DESC
|
|
604
|
+
LIMIT $limit
|
|
605
|
+
"""
|
|
606
|
+
)
|
|
607
|
+
records, header, _ = await driver.execute_query(
|
|
608
|
+
query,
|
|
609
|
+
ids=input_ids,
|
|
610
|
+
search_vector=search_vector,
|
|
611
|
+
limit=limit,
|
|
612
|
+
min_score=min_score,
|
|
613
|
+
routing_='r',
|
|
614
|
+
**query_params,
|
|
615
|
+
)
|
|
616
|
+
else:
|
|
617
|
+
return []
|
|
618
|
+
elif driver.provider == GraphProvider.NEO4J and USE_HNSW:
|
|
619
|
+
index_name = 'group_entity_vector_' + (
|
|
620
|
+
group_ids[0].replace('-', '') if group_ids is not None else ''
|
|
621
|
+
)
|
|
622
|
+
query = (
|
|
623
|
+
f"""
|
|
624
|
+
CALL db.index.vector.queryNodes('{index_name}', {limit}, $search_vector) YIELD node AS n, score
|
|
625
|
+
"""
|
|
626
|
+
+ group_filter_query
|
|
627
|
+
+ filter_query
|
|
628
|
+
+ """
|
|
629
|
+
AND score > $min_score
|
|
630
|
+
RETURN
|
|
631
|
+
"""
|
|
632
|
+
+ ENTITY_NODE_RETURN
|
|
633
|
+
+ """
|
|
634
|
+
ORDER BY score DESC
|
|
635
|
+
LIMIT $limit
|
|
636
|
+
"""
|
|
637
|
+
)
|
|
638
|
+
|
|
639
|
+
records, _, _ = await driver.execute_query(
|
|
640
|
+
query,
|
|
641
|
+
search_vector=search_vector,
|
|
642
|
+
limit=limit,
|
|
643
|
+
min_score=min_score,
|
|
644
|
+
routing_='r',
|
|
645
|
+
**query_params,
|
|
646
|
+
)
|
|
647
|
+
|
|
648
|
+
else:
|
|
649
|
+
query = (
|
|
650
|
+
RUNTIME_QUERY
|
|
651
|
+
+ """
|
|
652
|
+
MATCH (n:Entity)
|
|
653
|
+
"""
|
|
654
|
+
+ group_filter_query
|
|
655
|
+
+ filter_query
|
|
656
|
+
+ """
|
|
657
|
+
WITH n, """
|
|
658
|
+
+ get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
|
|
659
|
+
+ """ AS score
|
|
660
|
+
WHERE score > $min_score
|
|
661
|
+
RETURN
|
|
662
|
+
"""
|
|
663
|
+
+ ENTITY_NODE_RETURN
|
|
664
|
+
+ """
|
|
665
|
+
ORDER BY score DESC
|
|
666
|
+
LIMIT $limit
|
|
667
|
+
"""
|
|
668
|
+
)
|
|
669
|
+
|
|
670
|
+
records, _, _ = await driver.execute_query(
|
|
671
|
+
query,
|
|
672
|
+
search_vector=search_vector,
|
|
673
|
+
limit=limit,
|
|
674
|
+
min_score=min_score,
|
|
675
|
+
routing_='r',
|
|
676
|
+
**query_params,
|
|
677
|
+
)
|
|
390
678
|
|
|
391
679
|
nodes = [get_entity_node_from_record(record) for record in records]
|
|
392
680
|
|
|
@@ -407,22 +695,40 @@ async def node_bfs_search(
|
|
|
407
695
|
|
|
408
696
|
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
|
409
697
|
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
698
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
699
|
+
query = (
|
|
700
|
+
f"""
|
|
701
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
702
|
+
MATCH (origin {{uuid: origin_uuid}})-[e:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
|
|
703
|
+
WHERE origin:Entity OR origin.Episode
|
|
704
|
+
AND n.group_id = origin.group_id
|
|
705
|
+
"""
|
|
706
|
+
+ filter_query
|
|
707
|
+
+ """
|
|
708
|
+
RETURN
|
|
709
|
+
"""
|
|
710
|
+
+ ENTITY_NODE_RETURN
|
|
711
|
+
+ """
|
|
712
|
+
LIMIT $limit
|
|
713
|
+
"""
|
|
714
|
+
)
|
|
715
|
+
else:
|
|
716
|
+
query = (
|
|
717
|
+
f"""
|
|
718
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
719
|
+
MATCH (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
|
|
720
|
+
WHERE n.group_id = origin.group_id
|
|
721
|
+
AND origin.group_id IN $group_ids
|
|
722
|
+
"""
|
|
723
|
+
+ filter_query
|
|
724
|
+
+ """
|
|
725
|
+
RETURN
|
|
726
|
+
"""
|
|
727
|
+
+ ENTITY_NODE_RETURN
|
|
728
|
+
+ """
|
|
729
|
+
LIMIT $limit
|
|
730
|
+
"""
|
|
731
|
+
)
|
|
426
732
|
|
|
427
733
|
records, _, _ = await driver.execute_query(
|
|
428
734
|
query,
|
|
@@ -449,29 +755,72 @@ async def episode_fulltext_search(
|
|
|
449
755
|
if fuzzy_query == '':
|
|
450
756
|
return []
|
|
451
757
|
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
758
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
759
|
+
res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
|
760
|
+
if res['hits']['total']['value'] > 0:
|
|
761
|
+
# Calculate Cosine similarity then return the edge ids
|
|
762
|
+
input_ids = []
|
|
763
|
+
for r in res['hits']['hits']:
|
|
764
|
+
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
|
765
|
+
|
|
766
|
+
# Match the edge ides and return the values
|
|
767
|
+
query = """
|
|
768
|
+
UNWIND $ids as i
|
|
769
|
+
MATCH (e:Episodic)
|
|
770
|
+
WHERE e.uuid=i.id
|
|
771
|
+
RETURN
|
|
772
|
+
e.content AS content,
|
|
773
|
+
e.created_at AS created_at,
|
|
774
|
+
e.valid_at AS valid_at,
|
|
775
|
+
e.uuid AS uuid,
|
|
776
|
+
e.name AS name,
|
|
777
|
+
e.group_id AS group_id,
|
|
778
|
+
e.source_description AS source_description,
|
|
779
|
+
e.source AS source,
|
|
780
|
+
e.entity_edges AS entity_edges
|
|
781
|
+
ORDER BY i.score DESC
|
|
782
|
+
LIMIT $limit
|
|
783
|
+
"""
|
|
784
|
+
records, _, _ = await driver.execute_query(
|
|
785
|
+
query,
|
|
786
|
+
ids=input_ids,
|
|
787
|
+
query=fuzzy_query,
|
|
788
|
+
group_ids=group_ids,
|
|
789
|
+
limit=limit,
|
|
790
|
+
routing_='r',
|
|
791
|
+
)
|
|
792
|
+
else:
|
|
793
|
+
return []
|
|
794
|
+
else:
|
|
795
|
+
index_name = (
|
|
796
|
+
'episode_content'
|
|
797
|
+
if not USE_HNSW
|
|
798
|
+
else 'episode_content_'
|
|
799
|
+
+ (group_ids[0].replace('-', '') if group_ids is not None else '')
|
|
800
|
+
)
|
|
801
|
+
query = (
|
|
802
|
+
get_nodes_query(driver.provider, index_name, '$query')
|
|
803
|
+
+ """
|
|
804
|
+
YIELD node AS episode, score
|
|
805
|
+
MATCH (e:Episodic)
|
|
806
|
+
WHERE e.uuid = episode.uuid
|
|
807
|
+
AND e.group_id IN $group_ids
|
|
808
|
+
RETURN
|
|
809
|
+
"""
|
|
810
|
+
+ EPISODIC_NODE_RETURN
|
|
811
|
+
+ """
|
|
812
|
+
ORDER BY score DESC
|
|
813
|
+
LIMIT $limit
|
|
814
|
+
"""
|
|
815
|
+
)
|
|
467
816
|
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
817
|
+
records, _, _ = await driver.execute_query(
|
|
818
|
+
query,
|
|
819
|
+
query=fuzzy_query,
|
|
820
|
+
group_ids=group_ids,
|
|
821
|
+
limit=limit,
|
|
822
|
+
routing_='r',
|
|
823
|
+
)
|
|
475
824
|
episodes = [get_episodic_node_from_record(record) for record in records]
|
|
476
825
|
|
|
477
826
|
return episodes
|
|
@@ -488,27 +837,61 @@ async def community_fulltext_search(
|
|
|
488
837
|
if fuzzy_query == '':
|
|
489
838
|
return []
|
|
490
839
|
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
840
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
841
|
+
res = driver.run_aoss_query('community_name', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
|
842
|
+
if res['hits']['total']['value'] > 0:
|
|
843
|
+
# Calculate Cosine similarity then return the edge ids
|
|
844
|
+
input_ids = []
|
|
845
|
+
for r in res['hits']['hits']:
|
|
846
|
+
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
|
847
|
+
|
|
848
|
+
# Match the edge ides and return the values
|
|
849
|
+
query = """
|
|
850
|
+
UNWIND $ids as i
|
|
851
|
+
MATCH (comm:Community)
|
|
852
|
+
WHERE comm.uuid=i.id
|
|
853
|
+
RETURN
|
|
854
|
+
comm.uuid AS uuid,
|
|
855
|
+
comm.group_id AS group_id,
|
|
856
|
+
comm.name AS name,
|
|
857
|
+
comm.created_at AS created_at,
|
|
858
|
+
comm.summary AS summary,
|
|
859
|
+
[x IN split(comm.name_embedding, ",") | toFloat(x)]AS name_embedding
|
|
860
|
+
ORDER BY i.score DESC
|
|
861
|
+
LIMIT $limit
|
|
862
|
+
"""
|
|
863
|
+
records, _, _ = await driver.execute_query(
|
|
864
|
+
query,
|
|
865
|
+
ids=input_ids,
|
|
866
|
+
query=fuzzy_query,
|
|
867
|
+
group_ids=group_ids,
|
|
868
|
+
limit=limit,
|
|
869
|
+
routing_='r',
|
|
870
|
+
)
|
|
871
|
+
else:
|
|
872
|
+
return []
|
|
873
|
+
else:
|
|
874
|
+
query = (
|
|
875
|
+
get_nodes_query(driver.provider, 'community_name', '$query')
|
|
876
|
+
+ """
|
|
877
|
+
YIELD node AS n, score
|
|
878
|
+
WHERE n.group_id IN $group_ids
|
|
879
|
+
RETURN
|
|
880
|
+
"""
|
|
881
|
+
+ COMMUNITY_NODE_RETURN
|
|
882
|
+
+ """
|
|
883
|
+
ORDER BY score DESC
|
|
884
|
+
LIMIT $limit
|
|
885
|
+
"""
|
|
886
|
+
)
|
|
504
887
|
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
888
|
+
records, _, _ = await driver.execute_query(
|
|
889
|
+
query,
|
|
890
|
+
query=fuzzy_query,
|
|
891
|
+
group_ids=group_ids,
|
|
892
|
+
limit=limit,
|
|
893
|
+
routing_='r',
|
|
894
|
+
)
|
|
512
895
|
communities = [get_community_node_from_record(record) for record in records]
|
|
513
896
|
|
|
514
897
|
return communities
|
|
@@ -529,35 +912,93 @@ async def community_similarity_search(
|
|
|
529
912
|
group_filter_query += 'WHERE n.group_id IN $group_ids'
|
|
530
913
|
query_params['group_ids'] = group_ids
|
|
531
914
|
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
)
|
|
915
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
916
|
+
query = (
|
|
917
|
+
RUNTIME_QUERY
|
|
918
|
+
+ """
|
|
919
|
+
MATCH (n:Community)
|
|
920
|
+
"""
|
|
921
|
+
+ group_filter_query
|
|
922
|
+
+ """
|
|
923
|
+
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
|
924
|
+
"""
|
|
925
|
+
)
|
|
926
|
+
resp, header, _ = await driver.execute_query(
|
|
927
|
+
query,
|
|
928
|
+
search_vector=search_vector,
|
|
929
|
+
limit=limit,
|
|
930
|
+
min_score=min_score,
|
|
931
|
+
routing_='r',
|
|
932
|
+
**query_params,
|
|
933
|
+
)
|
|
552
934
|
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
935
|
+
if len(resp) > 0:
|
|
936
|
+
# Calculate Cosine similarity then return the edge ids
|
|
937
|
+
input_ids = []
|
|
938
|
+
for r in resp:
|
|
939
|
+
if r['embedding']:
|
|
940
|
+
score = calculate_cosine_similarity(
|
|
941
|
+
search_vector, list(map(float, r['embedding'].split(',')))
|
|
942
|
+
)
|
|
943
|
+
if score > min_score:
|
|
944
|
+
input_ids.append({'id': r['id'], 'score': score})
|
|
945
|
+
|
|
946
|
+
# Match the edge ides and return the values
|
|
947
|
+
query = """
|
|
948
|
+
UNWIND $ids as i
|
|
949
|
+
MATCH (comm:Community)
|
|
950
|
+
WHERE id(comm)=i.id
|
|
951
|
+
RETURN
|
|
952
|
+
comm.uuid As uuid,
|
|
953
|
+
comm.group_id AS group_id,
|
|
954
|
+
comm.name AS name,
|
|
955
|
+
comm.created_at AS created_at,
|
|
956
|
+
comm.summary AS summary,
|
|
957
|
+
comm.name_embedding AS name_embedding
|
|
958
|
+
ORDER BY i.score DESC
|
|
959
|
+
LIMIT $limit
|
|
960
|
+
"""
|
|
961
|
+
records, header, _ = await driver.execute_query(
|
|
962
|
+
query,
|
|
963
|
+
ids=input_ids,
|
|
964
|
+
search_vector=search_vector,
|
|
965
|
+
limit=limit,
|
|
966
|
+
min_score=min_score,
|
|
967
|
+
routing_='r',
|
|
968
|
+
**query_params,
|
|
969
|
+
)
|
|
970
|
+
else:
|
|
971
|
+
return []
|
|
972
|
+
else:
|
|
973
|
+
query = (
|
|
974
|
+
RUNTIME_QUERY
|
|
975
|
+
+ """
|
|
976
|
+
MATCH (n:Community)
|
|
977
|
+
"""
|
|
978
|
+
+ group_filter_query
|
|
979
|
+
+ """
|
|
980
|
+
WITH n,
|
|
981
|
+
"""
|
|
982
|
+
+ get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
|
|
983
|
+
+ """ AS score
|
|
984
|
+
WHERE score > $min_score
|
|
985
|
+
RETURN
|
|
986
|
+
"""
|
|
987
|
+
+ COMMUNITY_NODE_RETURN
|
|
988
|
+
+ """
|
|
989
|
+
ORDER BY score DESC
|
|
990
|
+
LIMIT $limit
|
|
991
|
+
"""
|
|
992
|
+
)
|
|
993
|
+
|
|
994
|
+
records, _, _ = await driver.execute_query(
|
|
995
|
+
query,
|
|
996
|
+
search_vector=search_vector,
|
|
997
|
+
limit=limit,
|
|
998
|
+
min_score=min_score,
|
|
999
|
+
routing_='r',
|
|
1000
|
+
**query_params,
|
|
1001
|
+
)
|
|
561
1002
|
communities = [get_community_node_from_record(record) for record in records]
|
|
562
1003
|
|
|
563
1004
|
return communities
|
|
@@ -746,20 +1187,45 @@ async def get_relevant_edges(
|
|
|
746
1187
|
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
|
747
1188
|
query_params.update(filter_params)
|
|
748
1189
|
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
1190
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
1191
|
+
query = (
|
|
1192
|
+
RUNTIME_QUERY
|
|
1193
|
+
+ """
|
|
1194
|
+
UNWIND $edges AS edge
|
|
1195
|
+
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
|
1196
|
+
"""
|
|
1197
|
+
+ filter_query
|
|
1198
|
+
+ """
|
|
1199
|
+
WITH e, edge
|
|
1200
|
+
RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding, edge.uuid as search_edge_uuid,
|
|
1201
|
+
edge.fact_embedding as target_embedding
|
|
1202
|
+
"""
|
|
1203
|
+
)
|
|
1204
|
+
resp, _, _ = await driver.execute_query(
|
|
1205
|
+
query,
|
|
1206
|
+
edges=[edge.model_dump() for edge in edges],
|
|
1207
|
+
limit=limit,
|
|
1208
|
+
min_score=min_score,
|
|
1209
|
+
routing_='r',
|
|
1210
|
+
**query_params,
|
|
1211
|
+
)
|
|
1212
|
+
|
|
1213
|
+
# Calculate Cosine similarity then return the edge ids
|
|
1214
|
+
input_ids = []
|
|
1215
|
+
for r in resp:
|
|
1216
|
+
score = calculate_cosine_similarity(
|
|
1217
|
+
list(map(float, r['source_embedding'].split(','))), r['target_embedding']
|
|
1218
|
+
)
|
|
1219
|
+
if score > min_score:
|
|
1220
|
+
input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})
|
|
1221
|
+
|
|
1222
|
+
# Match the edge ides and return the values
|
|
1223
|
+
query = """
|
|
1224
|
+
UNWIND $ids AS edge
|
|
1225
|
+
MATCH ()-[e]->()
|
|
1226
|
+
WHERE id(e) = edge.id
|
|
1227
|
+
WITH edge, e
|
|
1228
|
+
ORDER BY edge.score DESC
|
|
763
1229
|
RETURN edge.uuid AS search_edge_uuid,
|
|
764
1230
|
collect({
|
|
765
1231
|
uuid: e.uuid,
|
|
@@ -769,24 +1235,69 @@ async def get_relevant_edges(
|
|
|
769
1235
|
name: e.name,
|
|
770
1236
|
group_id: e.group_id,
|
|
771
1237
|
fact: e.fact,
|
|
772
|
-
fact_embedding: e.fact_embedding,
|
|
773
|
-
episodes: e.episodes,
|
|
1238
|
+
fact_embedding: [x IN split(e.fact_embedding, ",") | toFloat(x)],
|
|
1239
|
+
episodes: split(e.episodes, ","),
|
|
774
1240
|
expired_at: e.expired_at,
|
|
775
1241
|
valid_at: e.valid_at,
|
|
776
1242
|
invalid_at: e.invalid_at,
|
|
777
1243
|
attributes: properties(e)
|
|
778
1244
|
})[..$limit] AS matches
|
|
779
|
-
|
|
780
|
-
|
|
1245
|
+
"""
|
|
1246
|
+
|
|
1247
|
+
results, _, _ = await driver.execute_query(
|
|
1248
|
+
query,
|
|
1249
|
+
params=query_params,
|
|
1250
|
+
ids=input_ids,
|
|
1251
|
+
edges=[edge.model_dump() for edge in edges],
|
|
1252
|
+
limit=limit,
|
|
1253
|
+
min_score=min_score,
|
|
1254
|
+
routing_='r',
|
|
1255
|
+
**query_params,
|
|
1256
|
+
)
|
|
1257
|
+
else:
|
|
1258
|
+
query = (
|
|
1259
|
+
RUNTIME_QUERY
|
|
1260
|
+
+ """
|
|
1261
|
+
UNWIND $edges AS edge
|
|
1262
|
+
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
|
1263
|
+
"""
|
|
1264
|
+
+ filter_query
|
|
1265
|
+
+ """
|
|
1266
|
+
WITH e, edge, """
|
|
1267
|
+
+ get_vector_cosine_func_query(
|
|
1268
|
+
'e.fact_embedding', 'edge.fact_embedding', driver.provider
|
|
1269
|
+
)
|
|
1270
|
+
+ """ AS score
|
|
1271
|
+
WHERE score > $min_score
|
|
1272
|
+
WITH edge, e, score
|
|
1273
|
+
ORDER BY score DESC
|
|
1274
|
+
RETURN edge.uuid AS search_edge_uuid,
|
|
1275
|
+
collect({
|
|
1276
|
+
uuid: e.uuid,
|
|
1277
|
+
source_node_uuid: startNode(e).uuid,
|
|
1278
|
+
target_node_uuid: endNode(e).uuid,
|
|
1279
|
+
created_at: e.created_at,
|
|
1280
|
+
name: e.name,
|
|
1281
|
+
group_id: e.group_id,
|
|
1282
|
+
fact: e.fact,
|
|
1283
|
+
fact_embedding: e.fact_embedding,
|
|
1284
|
+
episodes: e.episodes,
|
|
1285
|
+
expired_at: e.expired_at,
|
|
1286
|
+
valid_at: e.valid_at,
|
|
1287
|
+
invalid_at: e.invalid_at,
|
|
1288
|
+
attributes: properties(e)
|
|
1289
|
+
})[..$limit] AS matches
|
|
1290
|
+
"""
|
|
1291
|
+
)
|
|
781
1292
|
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
1293
|
+
results, _, _ = await driver.execute_query(
|
|
1294
|
+
query,
|
|
1295
|
+
edges=[edge.model_dump() for edge in edges],
|
|
1296
|
+
limit=limit,
|
|
1297
|
+
min_score=min_score,
|
|
1298
|
+
routing_='r',
|
|
1299
|
+
**query_params,
|
|
1300
|
+
)
|
|
790
1301
|
|
|
791
1302
|
relevant_edges_dict: dict[str, list[EntityEdge]] = {
|
|
792
1303
|
result['search_edge_uuid']: [
|
|
@@ -815,21 +1326,47 @@ async def get_edge_invalidation_candidates(
|
|
|
815
1326
|
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
|
816
1327
|
query_params.update(filter_params)
|
|
817
1328
|
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
1329
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
1330
|
+
query = (
|
|
1331
|
+
RUNTIME_QUERY
|
|
1332
|
+
+ """
|
|
1333
|
+
UNWIND $edges AS edge
|
|
1334
|
+
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
|
1335
|
+
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
|
1336
|
+
"""
|
|
1337
|
+
+ filter_query
|
|
1338
|
+
+ """
|
|
1339
|
+
WITH e, edge
|
|
1340
|
+
RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding,
|
|
1341
|
+
edge.fact_embedding as target_embedding,
|
|
1342
|
+
edge.uuid as search_edge_uuid
|
|
1343
|
+
"""
|
|
1344
|
+
)
|
|
1345
|
+
resp, _, _ = await driver.execute_query(
|
|
1346
|
+
query,
|
|
1347
|
+
edges=[edge.model_dump() for edge in edges],
|
|
1348
|
+
limit=limit,
|
|
1349
|
+
min_score=min_score,
|
|
1350
|
+
routing_='r',
|
|
1351
|
+
**query_params,
|
|
1352
|
+
)
|
|
1353
|
+
|
|
1354
|
+
# Calculate Cosine similarity then return the edge ids
|
|
1355
|
+
input_ids = []
|
|
1356
|
+
for r in resp:
|
|
1357
|
+
score = calculate_cosine_similarity(
|
|
1358
|
+
list(map(float, r['source_embedding'].split(','))), r['target_embedding']
|
|
1359
|
+
)
|
|
1360
|
+
if score > min_score:
|
|
1361
|
+
input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})
|
|
1362
|
+
|
|
1363
|
+
# Match the edge IDs and return the values
|
|
1364
|
+
query = """
|
|
1365
|
+
UNWIND $ids AS edge
|
|
1366
|
+
MATCH ()-[e]->()
|
|
1367
|
+
WHERE id(e) = edge.id
|
|
1368
|
+
WITH edge, e
|
|
1369
|
+
ORDER BY edge.score DESC
|
|
833
1370
|
RETURN edge.uuid AS search_edge_uuid,
|
|
834
1371
|
collect({
|
|
835
1372
|
uuid: e.uuid,
|
|
@@ -839,24 +1376,68 @@ async def get_edge_invalidation_candidates(
|
|
|
839
1376
|
name: e.name,
|
|
840
1377
|
group_id: e.group_id,
|
|
841
1378
|
fact: e.fact,
|
|
842
|
-
fact_embedding: e.fact_embedding,
|
|
843
|
-
episodes: e.episodes,
|
|
1379
|
+
fact_embedding: [x IN split(e.fact_embedding, ",") | toFloat(x)],
|
|
1380
|
+
episodes: split(e.episodes, ","),
|
|
844
1381
|
expired_at: e.expired_at,
|
|
845
1382
|
valid_at: e.valid_at,
|
|
846
1383
|
invalid_at: e.invalid_at,
|
|
847
1384
|
attributes: properties(e)
|
|
848
1385
|
})[..$limit] AS matches
|
|
849
|
-
|
|
850
|
-
|
|
1386
|
+
"""
|
|
1387
|
+
results, _, _ = await driver.execute_query(
|
|
1388
|
+
query,
|
|
1389
|
+
ids=input_ids,
|
|
1390
|
+
edges=[edge.model_dump() for edge in edges],
|
|
1391
|
+
limit=limit,
|
|
1392
|
+
min_score=min_score,
|
|
1393
|
+
routing_='r',
|
|
1394
|
+
**query_params,
|
|
1395
|
+
)
|
|
1396
|
+
else:
|
|
1397
|
+
query = (
|
|
1398
|
+
RUNTIME_QUERY
|
|
1399
|
+
+ """
|
|
1400
|
+
UNWIND $edges AS edge
|
|
1401
|
+
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
|
1402
|
+
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
|
1403
|
+
"""
|
|
1404
|
+
+ filter_query
|
|
1405
|
+
+ """
|
|
1406
|
+
WITH edge, e, """
|
|
1407
|
+
+ get_vector_cosine_func_query(
|
|
1408
|
+
'e.fact_embedding', 'edge.fact_embedding', driver.provider
|
|
1409
|
+
)
|
|
1410
|
+
+ """ AS score
|
|
1411
|
+
WHERE score > $min_score
|
|
1412
|
+
WITH edge, e, score
|
|
1413
|
+
ORDER BY score DESC
|
|
1414
|
+
RETURN edge.uuid AS search_edge_uuid,
|
|
1415
|
+
collect({
|
|
1416
|
+
uuid: e.uuid,
|
|
1417
|
+
source_node_uuid: startNode(e).uuid,
|
|
1418
|
+
target_node_uuid: endNode(e).uuid,
|
|
1419
|
+
created_at: e.created_at,
|
|
1420
|
+
name: e.name,
|
|
1421
|
+
group_id: e.group_id,
|
|
1422
|
+
fact: e.fact,
|
|
1423
|
+
fact_embedding: e.fact_embedding,
|
|
1424
|
+
episodes: e.episodes,
|
|
1425
|
+
expired_at: e.expired_at,
|
|
1426
|
+
valid_at: e.valid_at,
|
|
1427
|
+
invalid_at: e.invalid_at,
|
|
1428
|
+
attributes: properties(e)
|
|
1429
|
+
})[..$limit] AS matches
|
|
1430
|
+
"""
|
|
1431
|
+
)
|
|
851
1432
|
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
1433
|
+
results, _, _ = await driver.execute_query(
|
|
1434
|
+
query,
|
|
1435
|
+
edges=[edge.model_dump() for edge in edges],
|
|
1436
|
+
limit=limit,
|
|
1437
|
+
min_score=min_score,
|
|
1438
|
+
routing_='r',
|
|
1439
|
+
**query_params,
|
|
1440
|
+
)
|
|
860
1441
|
invalidation_edges_dict: dict[str, list[EntityEdge]] = {
|
|
861
1442
|
result['search_edge_uuid']: [
|
|
862
1443
|
get_entity_edge_from_record(record) for record in result['matches']
|
|
@@ -1007,14 +1588,24 @@ def maximal_marginal_relevance(
|
|
|
1007
1588
|
async def get_embeddings_for_nodes(
|
|
1008
1589
|
driver: GraphDriver, nodes: list[EntityNode]
|
|
1009
1590
|
) -> dict[str, list[float]]:
|
|
1010
|
-
|
|
1591
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
1592
|
+
query = """
|
|
1593
|
+
MATCH (n:Entity)
|
|
1594
|
+
WHERE n.uuid IN $node_uuids
|
|
1595
|
+
RETURN DISTINCT
|
|
1596
|
+
n.uuid AS uuid,
|
|
1597
|
+
split(n.name_embedding, ",") AS name_embedding
|
|
1011
1598
|
"""
|
|
1599
|
+
else:
|
|
1600
|
+
query = """
|
|
1012
1601
|
MATCH (n:Entity)
|
|
1013
1602
|
WHERE n.uuid IN $node_uuids
|
|
1014
1603
|
RETURN DISTINCT
|
|
1015
1604
|
n.uuid AS uuid,
|
|
1016
1605
|
n.name_embedding AS name_embedding
|
|
1017
|
-
"""
|
|
1606
|
+
"""
|
|
1607
|
+
results, _, _ = await driver.execute_query(
|
|
1608
|
+
query,
|
|
1018
1609
|
node_uuids=[node.uuid for node in nodes],
|
|
1019
1610
|
routing_='r',
|
|
1020
1611
|
)
|
|
@@ -1032,14 +1623,24 @@ async def get_embeddings_for_nodes(
|
|
|
1032
1623
|
async def get_embeddings_for_communities(
|
|
1033
1624
|
driver: GraphDriver, communities: list[CommunityNode]
|
|
1034
1625
|
) -> dict[str, list[float]]:
|
|
1035
|
-
|
|
1626
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
1627
|
+
query = """
|
|
1628
|
+
MATCH (c:Community)
|
|
1629
|
+
WHERE c.uuid IN $community_uuids
|
|
1630
|
+
RETURN DISTINCT
|
|
1631
|
+
c.uuid AS uuid,
|
|
1632
|
+
split(c.name_embedding, ",") AS name_embedding
|
|
1036
1633
|
"""
|
|
1634
|
+
else:
|
|
1635
|
+
query = """
|
|
1037
1636
|
MATCH (c:Community)
|
|
1038
1637
|
WHERE c.uuid IN $community_uuids
|
|
1039
1638
|
RETURN DISTINCT
|
|
1040
1639
|
c.uuid AS uuid,
|
|
1041
1640
|
c.name_embedding AS name_embedding
|
|
1042
|
-
"""
|
|
1641
|
+
"""
|
|
1642
|
+
results, _, _ = await driver.execute_query(
|
|
1643
|
+
query,
|
|
1043
1644
|
community_uuids=[community.uuid for community in communities],
|
|
1044
1645
|
routing_='r',
|
|
1045
1646
|
)
|
|
@@ -1057,14 +1658,24 @@ async def get_embeddings_for_communities(
|
|
|
1057
1658
|
async def get_embeddings_for_edges(
|
|
1058
1659
|
driver: GraphDriver, edges: list[EntityEdge]
|
|
1059
1660
|
) -> dict[str, list[float]]:
|
|
1060
|
-
|
|
1661
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
1662
|
+
query = """
|
|
1663
|
+
MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
|
1664
|
+
WHERE e.uuid IN $edge_uuids
|
|
1665
|
+
RETURN DISTINCT
|
|
1666
|
+
e.uuid AS uuid,
|
|
1667
|
+
split(e.fact_embedding, ",") AS fact_embedding
|
|
1061
1668
|
"""
|
|
1669
|
+
else:
|
|
1670
|
+
query = """
|
|
1062
1671
|
MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
|
1063
1672
|
WHERE e.uuid IN $edge_uuids
|
|
1064
1673
|
RETURN DISTINCT
|
|
1065
1674
|
e.uuid AS uuid,
|
|
1066
1675
|
e.fact_embedding AS fact_embedding
|
|
1067
|
-
"""
|
|
1676
|
+
"""
|
|
1677
|
+
results, _, _ = await driver.execute_query(
|
|
1678
|
+
query,
|
|
1068
1679
|
edge_uuids=[edge.uuid for edge in edges],
|
|
1069
1680
|
routing_='r',
|
|
1070
1681
|
)
|