graphiti-core 0.18.8__py3-none-any.whl → 0.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of graphiti-core might be problematic. (The original page linked to further details at this point; the link target is not preserved in this text.)

Files changed (32)
  1. graphiti_core/driver/driver.py +4 -0
  2. graphiti_core/driver/falkordb_driver.py +3 -14
  3. graphiti_core/driver/kuzu_driver.py +175 -0
  4. graphiti_core/driver/neptune_driver.py +301 -0
  5. graphiti_core/edges.py +155 -62
  6. graphiti_core/graph_queries.py +31 -2
  7. graphiti_core/graphiti.py +6 -1
  8. graphiti_core/helpers.py +8 -8
  9. graphiti_core/llm_client/config.py +1 -1
  10. graphiti_core/llm_client/openai_base_client.py +15 -5
  11. graphiti_core/llm_client/openai_client.py +16 -6
  12. graphiti_core/migrations/__init__.py +0 -0
  13. graphiti_core/migrations/neo4j_node_group_labels.py +114 -0
  14. graphiti_core/models/edges/edge_db_queries.py +205 -76
  15. graphiti_core/models/nodes/node_db_queries.py +253 -74
  16. graphiti_core/nodes.py +271 -98
  17. graphiti_core/prompts/extract_edges.py +1 -0
  18. graphiti_core/prompts/extract_nodes.py +1 -1
  19. graphiti_core/search/search.py +42 -12
  20. graphiti_core/search/search_config.py +4 -0
  21. graphiti_core/search/search_filters.py +35 -22
  22. graphiti_core/search/search_utils.py +1329 -392
  23. graphiti_core/utils/bulk_utils.py +50 -15
  24. graphiti_core/utils/datetime_utils.py +13 -0
  25. graphiti_core/utils/maintenance/community_operations.py +39 -32
  26. graphiti_core/utils/maintenance/edge_operations.py +47 -13
  27. graphiti_core/utils/maintenance/graph_data_operations.py +100 -15
  28. graphiti_core/utils/maintenance/node_operations.py +7 -3
  29. {graphiti_core-0.18.8.dist-info → graphiti_core-0.19.0.dist-info}/METADATA +87 -13
  30. {graphiti_core-0.18.8.dist-info → graphiti_core-0.19.0.dist-info}/RECORD +32 -28
  31. {graphiti_core-0.18.8.dist-info → graphiti_core-0.19.0.dist-info}/WHEEL +0 -0
  32. {graphiti_core-0.18.8.dist-info → graphiti_core-0.19.0.dist-info}/licenses/LICENSE +0 -0
@@ -15,6 +15,7 @@ limitations under the License.
15
15
  """
16
16
 
17
17
  import logging
18
+ import os
18
19
  from collections import defaultdict
19
20
  from time import time
20
21
  from typing import Any
@@ -36,10 +37,13 @@ from graphiti_core.helpers import (
36
37
  normalize_l2,
37
38
  semaphore_gather,
38
39
  )
39
- from graphiti_core.models.edges.edge_db_queries import ENTITY_EDGE_RETURN
40
- from graphiti_core.models.nodes.node_db_queries import COMMUNITY_NODE_RETURN, EPISODIC_NODE_RETURN
40
+ from graphiti_core.models.edges.edge_db_queries import get_entity_edge_return_query
41
+ from graphiti_core.models.nodes.node_db_queries import (
42
+ COMMUNITY_NODE_RETURN,
43
+ EPISODIC_NODE_RETURN,
44
+ get_entity_node_return_query,
45
+ )
41
46
  from graphiti_core.nodes import (
42
- ENTITY_NODE_RETURN,
43
47
  CommunityNode,
44
48
  EntityNode,
45
49
  EpisodicNode,
@@ -54,6 +58,7 @@ from graphiti_core.search.search_filters import (
54
58
  )
55
59
 
56
60
  logger = logging.getLogger(__name__)
61
+ USE_HNSW = os.getenv('USE_HNSW', '').lower() in ('true', '1', 'yes')
57
62
 
58
63
  RELEVANT_SCHEMA_LIMIT = 10
59
64
  DEFAULT_MIN_SCORE = 0.6
@@ -62,9 +67,30 @@ MAX_SEARCH_DEPTH = 3
62
67
  MAX_QUERY_LENGTH = 128
63
68
 
64
69
 
65
- def fulltext_query(query: str, group_ids: list[str] | None = None, fulltext_syntax: str = ''):
70
+ def calculate_cosine_similarity(vector1: list[float], vector2: list[float]) -> float:
71
+ """
72
+ Calculates the cosine similarity between two vectors using NumPy.
73
+ """
74
+ dot_product = np.dot(vector1, vector2)
75
+ norm_vector1 = np.linalg.norm(vector1)
76
+ norm_vector2 = np.linalg.norm(vector2)
77
+
78
+ if norm_vector1 == 0 or norm_vector2 == 0:
79
+ return 0 # Handle cases where one or both vectors are zero vectors
80
+
81
+ return dot_product / (norm_vector1 * norm_vector2)
82
+
83
+
84
+ def fulltext_query(query: str, group_ids: list[str] | None, driver: GraphDriver):
85
+ if driver.provider == GraphProvider.KUZU:
86
+ # Kuzu only supports simple queries.
87
+ if len(query.split(' ')) > MAX_QUERY_LENGTH:
88
+ return ''
89
+ return query
66
90
  group_ids_filter_list = (
67
- [fulltext_syntax + f'group_id:"{g}"' for g in group_ids] if group_ids is not None else []
91
+ [driver.fulltext_syntax + f'group_id:"{g}"' for g in group_ids]
92
+ if group_ids is not None
93
+ else []
68
94
  )
69
95
  group_ids_filter = ''
70
96
  for f in group_ids_filter_list:
@@ -108,12 +134,12 @@ async def get_mentioned_nodes(
108
134
  WHERE episode.uuid IN $uuids
109
135
  RETURN DISTINCT
110
136
  """
111
- + ENTITY_NODE_RETURN,
137
+ + get_entity_node_return_query(driver.provider),
112
138
  uuids=episode_uuids,
113
139
  routing_='r',
114
140
  )
115
141
 
116
- nodes = [get_entity_node_from_record(record) for record in records]
142
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
117
143
 
118
144
  return nodes
119
145
 
@@ -125,7 +151,7 @@ async def get_communities_by_nodes(
125
151
 
126
152
  records, _, _ = await driver.execute_query(
127
153
  """
128
- MATCH (n:Community)-[:HAS_MEMBER]->(m:Entity)
154
+ MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)
129
155
  WHERE m.uuid IN $uuids
130
156
  RETURN DISTINCT
131
157
  """
@@ -147,40 +173,105 @@ async def edge_fulltext_search(
147
173
  limit=RELEVANT_SCHEMA_LIMIT,
148
174
  ) -> list[EntityEdge]:
149
175
  # fulltext search over facts
150
- fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
176
+ fuzzy_query = fulltext_query(query, group_ids, driver)
177
+
151
178
  if fuzzy_query == '':
152
179
  return []
153
180
 
154
- filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
155
-
156
- query = (
157
- get_relationships_query('edge_name_and_fact', provider=driver.provider)
158
- + """
159
- YIELD relationship AS rel, score
160
- MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
161
- WHERE e.group_id IN $group_ids """
162
- + filter_query
163
- + """
164
- WITH e, score, n, m
165
- RETURN
166
- """
167
- + ENTITY_EDGE_RETURN
168
- + """
169
- ORDER BY score DESC
170
- LIMIT $limit
181
+ match_query = """
182
+ YIELD relationship AS rel, score
183
+ MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
184
+ """
185
+ if driver.provider == GraphProvider.KUZU:
186
+ match_query = """
187
+ YIELD node, score
188
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: node.uuid})-[:RELATES_TO]->(m:Entity)
171
189
  """
172
- )
173
190
 
174
- records, _, _ = await driver.execute_query(
175
- query,
176
- query=fuzzy_query,
177
- group_ids=group_ids,
178
- limit=limit,
179
- routing_='r',
180
- **filter_params,
191
+ filter_queries, filter_params = edge_search_filter_query_constructor(
192
+ search_filter, driver.provider
181
193
  )
182
194
 
183
- edges = [get_entity_edge_from_record(record) for record in records]
195
+ if group_ids is not None:
196
+ filter_queries.append('e.group_id IN $group_ids')
197
+ filter_params['group_ids'] = group_ids
198
+
199
+ filter_query = ''
200
+ if filter_queries:
201
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
202
+
203
+ if driver.provider == GraphProvider.NEPTUNE:
204
+ res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue
205
+ if res['hits']['total']['value'] > 0:
206
+ # Calculate Cosine similarity then return the edge ids
207
+ input_ids = []
208
+ for r in res['hits']['hits']:
209
+ input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
210
+
211
+ # Match the edge ids and return the values
212
+ query = (
213
+ """
214
+ UNWIND $ids as id
215
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
216
+ WHERE e.group_id IN $group_ids
217
+ AND id(e)=id
218
+ """
219
+ + filter_query
220
+ + """
221
+ AND id(e)=id
222
+ WITH e, id.score as score, startNode(e) AS n, endNode(e) AS m
223
+ RETURN
224
+ e.uuid AS uuid,
225
+ e.group_id AS group_id,
226
+ n.uuid AS source_node_uuid,
227
+ m.uuid AS target_node_uuid,
228
+ e.created_at AS created_at,
229
+ e.name AS name,
230
+ e.fact AS fact,
231
+ split(e.episodes, ",") AS episodes,
232
+ e.expired_at AS expired_at,
233
+ e.valid_at AS valid_at,
234
+ e.invalid_at AS invalid_at,
235
+ properties(e) AS attributes
236
+ ORDER BY score DESC LIMIT $limit
237
+ """
238
+ )
239
+
240
+ records, _, _ = await driver.execute_query(
241
+ query,
242
+ query=fuzzy_query,
243
+ ids=input_ids,
244
+ limit=limit,
245
+ routing_='r',
246
+ **filter_params,
247
+ )
248
+ else:
249
+ return []
250
+ else:
251
+ query = (
252
+ get_relationships_query('edge_name_and_fact', limit=limit, provider=driver.provider)
253
+ + match_query
254
+ + filter_query
255
+ + """
256
+ WITH e, score, n, m
257
+ RETURN
258
+ """
259
+ + get_entity_edge_return_query(driver.provider)
260
+ + """
261
+ ORDER BY score DESC
262
+ LIMIT $limit
263
+ """
264
+ )
265
+
266
+ records, _, _ = await driver.execute_query(
267
+ query,
268
+ query=fuzzy_query,
269
+ limit=limit,
270
+ routing_='r',
271
+ **filter_params,
272
+ )
273
+
274
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
184
275
 
185
276
  return edges
186
277
 
@@ -195,56 +286,130 @@ async def edge_similarity_search(
195
286
  limit: int = RELEVANT_SCHEMA_LIMIT,
196
287
  min_score: float = DEFAULT_MIN_SCORE,
197
288
  ) -> list[EntityEdge]:
198
- # vector similarity search over embedded facts
199
- query_params: dict[str, Any] = {}
289
+ match_query = """
290
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
291
+ """
292
+ if driver.provider == GraphProvider.KUZU:
293
+ match_query = """
294
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
295
+ """
200
296
 
201
- filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
202
- query_params.update(filter_params)
297
+ filter_queries, filter_params = edge_search_filter_query_constructor(
298
+ search_filter, driver.provider
299
+ )
203
300
 
204
- group_filter_query: LiteralString = 'WHERE e.group_id IS NOT NULL'
205
301
  if group_ids is not None:
206
- group_filter_query += '\nAND e.group_id IN $group_ids'
207
- query_params['group_ids'] = group_ids
302
+ filter_queries.append('e.group_id IN $group_ids')
303
+ filter_params['group_ids'] = group_ids
208
304
 
209
305
  if source_node_uuid is not None:
210
- query_params['source_uuid'] = source_node_uuid
211
- group_filter_query += '\nAND (n.uuid = $source_uuid)'
306
+ filter_params['source_uuid'] = source_node_uuid
307
+ filter_queries.append('n.uuid = $source_uuid')
212
308
 
213
309
  if target_node_uuid is not None:
214
- query_params['target_uuid'] = target_node_uuid
215
- group_filter_query += '\nAND (m.uuid = $target_uuid)'
310
+ filter_params['target_uuid'] = target_node_uuid
311
+ filter_queries.append('m.uuid = $target_uuid')
312
+
313
+ filter_query = ''
314
+ if filter_queries:
315
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
316
+
317
+ search_vector_var = '$search_vector'
318
+ if driver.provider == GraphProvider.KUZU:
319
+ search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
320
+
321
+ if driver.provider == GraphProvider.NEPTUNE:
322
+ query = (
323
+ RUNTIME_QUERY
324
+ + """
325
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
326
+ """
327
+ + filter_query
328
+ + """
329
+ RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
330
+ """
331
+ )
332
+ resp, header, _ = await driver.execute_query(
333
+ query,
334
+ search_vector=search_vector,
335
+ limit=limit,
336
+ min_score=min_score,
337
+ routing_='r',
338
+ **filter_params,
339
+ )
216
340
 
217
- query = (
218
- RUNTIME_QUERY
219
- + """
220
- MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
221
- """
222
- + group_filter_query
223
- + filter_query
224
- + """
225
- WITH DISTINCT e, n, m, """
226
- + get_vector_cosine_func_query('e.fact_embedding', '$search_vector', driver.provider)
227
- + """ AS score
228
- WHERE score > $min_score
229
- RETURN
230
- """
231
- + ENTITY_EDGE_RETURN
232
- + """
233
- ORDER BY score DESC
234
- LIMIT $limit
235
- """
236
- )
341
+ if len(resp) > 0:
342
+ # Calculate Cosine similarity then return the edge ids
343
+ input_ids = []
344
+ for r in resp:
345
+ if r['embedding']:
346
+ score = calculate_cosine_similarity(
347
+ search_vector, list(map(float, r['embedding'].split(',')))
348
+ )
349
+ if score > min_score:
350
+ input_ids.append({'id': r['id'], 'score': score})
351
+
352
+ # Match the edge ides and return the values
353
+ query = """
354
+ UNWIND $ids as i
355
+ MATCH ()-[r]->()
356
+ WHERE id(r) = i.id
357
+ RETURN
358
+ r.uuid AS uuid,
359
+ r.group_id AS group_id,
360
+ startNode(r).uuid AS source_node_uuid,
361
+ endNode(r).uuid AS target_node_uuid,
362
+ r.created_at AS created_at,
363
+ r.name AS name,
364
+ r.fact AS fact,
365
+ split(r.episodes, ",") AS episodes,
366
+ r.expired_at AS expired_at,
367
+ r.valid_at AS valid_at,
368
+ r.invalid_at AS invalid_at,
369
+ properties(r) AS attributes
370
+ ORDER BY i.score DESC
371
+ LIMIT $limit
372
+ """
373
+ records, _, _ = await driver.execute_query(
374
+ query,
375
+ ids=input_ids,
376
+ search_vector=search_vector,
377
+ limit=limit,
378
+ min_score=min_score,
379
+ routing_='r',
380
+ **filter_params,
381
+ )
382
+ else:
383
+ return []
384
+ else:
385
+ query = (
386
+ RUNTIME_QUERY
387
+ + match_query
388
+ + filter_query
389
+ + """
390
+ WITH DISTINCT e, n, m, """
391
+ + get_vector_cosine_func_query('e.fact_embedding', search_vector_var, driver.provider)
392
+ + """ AS score
393
+ WHERE score > $min_score
394
+ RETURN
395
+ """
396
+ + get_entity_edge_return_query(driver.provider)
397
+ + """
398
+ ORDER BY score DESC
399
+ LIMIT $limit
400
+ """
401
+ )
237
402
 
238
- records, _, _ = await driver.execute_query(
239
- query,
240
- search_vector=search_vector,
241
- limit=limit,
242
- min_score=min_score,
243
- routing_='r',
244
- **query_params,
245
- )
403
+ records, _, _ = await driver.execute_query(
404
+ query,
405
+ search_vector=search_vector,
406
+ limit=limit,
407
+ min_score=min_score,
408
+ routing_='r',
409
+ **filter_params,
410
+ )
246
411
 
247
- edges = [get_entity_edge_from_record(record) for record in records]
412
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
248
413
 
249
414
  return edges
250
415
 
@@ -258,40 +423,116 @@ async def edge_bfs_search(
258
423
  limit: int = RELEVANT_SCHEMA_LIMIT,
259
424
  ) -> list[EntityEdge]:
260
425
  # vector similarity search over embedded facts
261
- if bfs_origin_node_uuids is None:
426
+ if bfs_origin_node_uuids is None or len(bfs_origin_node_uuids) == 0:
262
427
  return []
263
428
 
264
- filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
265
-
266
- query = (
267
- f"""
268
- UNWIND $bfs_origin_node_uuids AS origin_uuid
269
- MATCH path = (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
270
- UNWIND relationships(path) AS rel
271
- MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
272
- WHERE e.uuid = rel.uuid
273
- AND e.group_id IN $group_ids
274
- """
275
- + filter_query
276
- + """
277
- RETURN DISTINCT
278
- """
279
- + ENTITY_EDGE_RETURN
280
- + """
281
- LIMIT $limit
282
- """
429
+ filter_queries, filter_params = edge_search_filter_query_constructor(
430
+ search_filter, driver.provider
283
431
  )
284
432
 
285
- records, _, _ = await driver.execute_query(
286
- query,
287
- bfs_origin_node_uuids=bfs_origin_node_uuids,
288
- group_ids=group_ids,
289
- limit=limit,
290
- routing_='r',
291
- **filter_params,
292
- )
433
+ if group_ids is not None:
434
+ filter_queries.append('e.group_id IN $group_ids')
435
+ filter_params['group_ids'] = group_ids
436
+
437
+ filter_query = ''
438
+ if filter_queries:
439
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
440
+
441
+ if driver.provider == GraphProvider.KUZU:
442
+ # Kuzu stores entity edges twice with an intermediate node, so we need to match them
443
+ # separately for the correct BFS depth.
444
+ depth = bfs_max_depth * 2 - 1
445
+ match_queries = [
446
+ f"""
447
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
448
+ MATCH path = (origin:Entity {{uuid: origin_uuid}})-[:RELATES_TO*1..{depth}]->(:RelatesToNode_)
449
+ UNWIND nodes(path) AS relNode
450
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {{uuid: relNode.uuid}})-[:RELATES_TO]->(m:Entity)
451
+ """,
452
+ ]
453
+ if bfs_max_depth > 1:
454
+ depth = (bfs_max_depth - 1) * 2 - 1
455
+ match_queries.append(f"""
456
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
457
+ MATCH path = (origin:Episodic {{uuid: origin_uuid}})-[:MENTIONS]->(:Entity)-[:RELATES_TO*1..{depth}]->(:RelatesToNode_)
458
+ UNWIND nodes(path) AS relNode
459
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {{uuid: relNode.uuid}})-[:RELATES_TO]->(m:Entity)
460
+ """)
461
+
462
+ records = []
463
+ for match_query in match_queries:
464
+ sub_records, _, _ = await driver.execute_query(
465
+ match_query
466
+ + filter_query
467
+ + """
468
+ RETURN DISTINCT
469
+ """
470
+ + get_entity_edge_return_query(driver.provider)
471
+ + """
472
+ LIMIT $limit
473
+ """,
474
+ bfs_origin_node_uuids=bfs_origin_node_uuids,
475
+ limit=limit,
476
+ routing_='r',
477
+ **filter_params,
478
+ )
479
+ records.extend(sub_records)
480
+ else:
481
+ if driver.provider == GraphProvider.NEPTUNE:
482
+ query = (
483
+ f"""
484
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
485
+ MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS *1..{bfs_max_depth}]->(n:Entity)
486
+ WHERE origin:Entity OR origin:Episodic
487
+ UNWIND relationships(path) AS rel
488
+ MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
489
+ """
490
+ + filter_query
491
+ + """
492
+ RETURN DISTINCT
493
+ e.uuid AS uuid,
494
+ e.group_id AS group_id,
495
+ startNode(e).uuid AS source_node_uuid,
496
+ endNode(e).uuid AS target_node_uuid,
497
+ e.created_at AS created_at,
498
+ e.name AS name,
499
+ e.fact AS fact,
500
+ split(e.episodes, ',') AS episodes,
501
+ e.expired_at AS expired_at,
502
+ e.valid_at AS valid_at,
503
+ e.invalid_at AS invalid_at,
504
+ properties(e) AS attributes
505
+ LIMIT $limit
506
+ """
507
+ )
508
+ else:
509
+ query = (
510
+ f"""
511
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
512
+ MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
513
+ UNWIND relationships(path) AS rel
514
+ MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
515
+ """
516
+ + filter_query
517
+ + """
518
+ RETURN DISTINCT
519
+ """
520
+ + get_entity_edge_return_query(driver.provider)
521
+ + """
522
+ LIMIT $limit
523
+ """
524
+ )
525
+
526
+ records, _, _ = await driver.execute_query(
527
+ query,
528
+ bfs_origin_node_uuids=bfs_origin_node_uuids,
529
+ depth=bfs_max_depth,
530
+ limit=limit,
531
+ routing_='r',
532
+ **filter_params,
533
+ )
293
534
 
294
- edges = [get_entity_edge_from_record(record) for record in records]
535
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
295
536
 
296
537
  return edges
297
538
 
@@ -302,39 +543,90 @@ async def node_fulltext_search(
302
543
  search_filter: SearchFilters,
303
544
  group_ids: list[str] | None = None,
304
545
  limit=RELEVANT_SCHEMA_LIMIT,
546
+ use_local_indexes: bool = False,
305
547
  ) -> list[EntityNode]:
306
548
  # BM25 search to get top nodes
307
- fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
549
+ fuzzy_query = fulltext_query(query, group_ids, driver)
308
550
  if fuzzy_query == '':
309
551
  return []
310
- filter_query, filter_params = node_search_filter_query_constructor(search_filter)
311
552
 
312
- query = (
313
- get_nodes_query(driver.provider, 'node_name_and_summary', '$query')
314
- + """
315
- YIELD node AS n, score
316
- WHERE n:Entity AND n.group_id IN $group_ids
317
- """
318
- + filter_query
319
- + """
320
- WITH n, score
321
- ORDER BY score DESC
322
- LIMIT $limit
323
- RETURN
324
- """
325
- + ENTITY_NODE_RETURN
553
+ filter_queries, filter_params = node_search_filter_query_constructor(
554
+ search_filter, driver.provider
326
555
  )
327
556
 
328
- records, _, _ = await driver.execute_query(
329
- query,
330
- query=fuzzy_query,
331
- group_ids=group_ids,
332
- limit=limit,
333
- routing_='r',
334
- **filter_params,
335
- )
557
+ if group_ids is not None:
558
+ filter_queries.append('n.group_id IN $group_ids')
559
+ filter_params['group_ids'] = group_ids
560
+
561
+ filter_query = ''
562
+ if filter_queries:
563
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
564
+
565
+ yield_query = 'YIELD node AS n, score'
566
+ if driver.provider == GraphProvider.KUZU:
567
+ yield_query = 'WITH node AS n, score'
568
+
569
+ if driver.provider == GraphProvider.NEPTUNE:
570
+ res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
571
+ if res['hits']['total']['value'] > 0:
572
+ # Calculate Cosine similarity then return the edge ids
573
+ input_ids = []
574
+ for r in res['hits']['hits']:
575
+ input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
576
+
577
+ # Match the edge ides and return the values
578
+ query = (
579
+ """
580
+ UNWIND $ids as i
581
+ MATCH (n:Entity)
582
+ WHERE n.uuid=i.id
583
+ RETURN
584
+ """
585
+ + get_entity_node_return_query(driver.provider)
586
+ + """
587
+ ORDER BY i.score DESC
588
+ LIMIT $limit
589
+ """
590
+ )
591
+ records, _, _ = await driver.execute_query(
592
+ query,
593
+ ids=input_ids,
594
+ query=fuzzy_query,
595
+ limit=limit,
596
+ routing_='r',
597
+ **filter_params,
598
+ )
599
+ else:
600
+ return []
601
+ else:
602
+ index_name = (
603
+ 'node_name_and_summary'
604
+ if not use_local_indexes
605
+ else 'node_name_and_summary_'
606
+ + (group_ids[0].replace('-', '') if group_ids is not None else '')
607
+ )
608
+ query = (
609
+ get_nodes_query(index_name, '$query', limit=limit, provider=driver.provider)
610
+ + yield_query
611
+ + filter_query
612
+ + """
613
+ WITH n, score
614
+ ORDER BY score DESC
615
+ LIMIT $limit
616
+ RETURN
617
+ """
618
+ + get_entity_node_return_query(driver.provider)
619
+ )
336
620
 
337
- nodes = [get_entity_node_from_record(record) for record in records]
621
+ records, _, _ = await driver.execute_query(
622
+ query,
623
+ query=fuzzy_query,
624
+ limit=limit,
625
+ routing_='r',
626
+ **filter_params,
627
+ )
628
+
629
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
338
630
 
339
631
  return nodes
340
632
 
@@ -346,49 +638,140 @@ async def node_similarity_search(
346
638
  group_ids: list[str] | None = None,
347
639
  limit=RELEVANT_SCHEMA_LIMIT,
348
640
  min_score: float = DEFAULT_MIN_SCORE,
641
+ use_local_indexes: bool = False,
349
642
  ) -> list[EntityNode]:
350
- # vector similarity search over entity names
351
- query_params: dict[str, Any] = {}
643
+ filter_queries, filter_params = node_search_filter_query_constructor(
644
+ search_filter, driver.provider
645
+ )
352
646
 
353
- group_filter_query: LiteralString = 'WHERE n.group_id IS NOT NULL'
354
647
  if group_ids is not None:
355
- group_filter_query += ' AND n.group_id IN $group_ids'
356
- query_params['group_ids'] = group_ids
648
+ filter_queries.append('n.group_id IN $group_ids')
649
+ filter_params['group_ids'] = group_ids
650
+
651
+ filter_query = ''
652
+ if filter_queries:
653
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
654
+
655
+ search_vector_var = '$search_vector'
656
+ if driver.provider == GraphProvider.KUZU:
657
+ search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
658
+
659
+ if driver.provider == GraphProvider.NEPTUNE:
660
+ query = (
661
+ RUNTIME_QUERY
662
+ + """
663
+ MATCH (n:Entity)
664
+ """
665
+ + filter_query
666
+ + """
667
+ RETURN DISTINCT id(n) as id, n.name_embedding as embedding
668
+ """
669
+ )
670
+ resp, header, _ = await driver.execute_query(
671
+ query,
672
+ params=filter_params,
673
+ search_vector=search_vector,
674
+ limit=limit,
675
+ min_score=min_score,
676
+ routing_='r',
677
+ )
357
678
 
358
- filter_query, filter_params = node_search_filter_query_constructor(search_filter)
359
- query_params.update(filter_params)
679
+ if len(resp) > 0:
680
+ # Calculate Cosine similarity then return the edge ids
681
+ input_ids = []
682
+ for r in resp:
683
+ if r['embedding']:
684
+ score = calculate_cosine_similarity(
685
+ search_vector, list(map(float, r['embedding'].split(',')))
686
+ )
687
+ if score > min_score:
688
+ input_ids.append({'id': r['id'], 'score': score})
689
+
690
+ # Match the edge ides and return the values
691
+ query = (
692
+ """
693
+ UNWIND $ids as i
694
+ MATCH (n:Entity)
695
+ WHERE id(n)=i.id
696
+ RETURN
697
+ """
698
+ + get_entity_node_return_query(driver.provider)
699
+ + """
700
+ ORDER BY i.score DESC
701
+ LIMIT $limit
702
+ """
703
+ )
704
+ records, header, _ = await driver.execute_query(
705
+ query,
706
+ ids=input_ids,
707
+ search_vector=search_vector,
708
+ limit=limit,
709
+ min_score=min_score,
710
+ routing_='r',
711
+ **filter_params,
712
+ )
713
+ else:
714
+ return []
715
+ elif driver.provider == GraphProvider.NEO4J and use_local_indexes:
716
+ index_name = 'group_entity_vector_' + (
717
+ group_ids[0].replace('-', '') if group_ids is not None else ''
718
+ )
719
+ query = (
720
+ f"""
721
+ CALL db.index.vector.queryNodes('{index_name}', {limit}, $search_vector) YIELD node AS n, score
722
+ """
723
+ + filter_query
724
+ + """
725
+ AND score > $min_score
726
+ RETURN
727
+ """
728
+ + get_entity_node_return_query(driver.provider)
729
+ + """
730
+ ORDER BY score DESC
731
+ LIMIT $limit
732
+ """
733
+ )
360
734
 
361
- query = (
362
- RUNTIME_QUERY
363
- + """
364
- MATCH (n:Entity)
365
- """
366
- + group_filter_query
367
- + filter_query
368
- + """
369
- WITH n, """
370
- + get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
371
- + """ AS score
372
- WHERE score > $min_score
373
- RETURN
374
- """
375
- + ENTITY_NODE_RETURN
376
- + """
377
- ORDER BY score DESC
378
- LIMIT $limit
379
- """
380
- )
735
+ records, _, _ = await driver.execute_query(
736
+ query,
737
+ search_vector=search_vector,
738
+ limit=limit,
739
+ min_score=min_score,
740
+ routing_='r',
741
+ **filter_params,
742
+ )
381
743
 
382
- records, _, _ = await driver.execute_query(
383
- query,
384
- search_vector=search_vector,
385
- limit=limit,
386
- min_score=min_score,
387
- routing_='r',
388
- **query_params,
389
- )
744
+ else:
745
+ query = (
746
+ RUNTIME_QUERY
747
+ + """
748
+ MATCH (n:Entity)
749
+ """
750
+ + filter_query
751
+ + """
752
+ WITH n, """
753
+ + get_vector_cosine_func_query('n.name_embedding', search_vector_var, driver.provider)
754
+ + """ AS score
755
+ WHERE score > $min_score
756
+ RETURN
757
+ """
758
+ + get_entity_node_return_query(driver.provider)
759
+ + """
760
+ ORDER BY score DESC
761
+ LIMIT $limit
762
+ """
763
+ )
764
+
765
+ records, _, _ = await driver.execute_query(
766
+ query,
767
+ search_vector=search_vector,
768
+ limit=limit,
769
+ min_score=min_score,
770
+ routing_='r',
771
+ **filter_params,
772
+ )
390
773
 
391
- nodes = [get_entity_node_from_record(record) for record in records]
774
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
392
775
 
393
776
  return nodes
394
777
 
@@ -401,38 +784,82 @@ async def node_bfs_search(
401
784
  group_ids: list[str] | None = None,
402
785
  limit: int = RELEVANT_SCHEMA_LIMIT,
403
786
  ) -> list[EntityNode]:
404
- # vector similarity search over entity names
405
- if bfs_origin_node_uuids is None:
787
+ if bfs_origin_node_uuids is None or len(bfs_origin_node_uuids) == 0 or bfs_max_depth < 1:
406
788
  return []
407
789
 
408
- filter_query, filter_params = node_search_filter_query_constructor(search_filter)
790
+ filter_queries, filter_params = node_search_filter_query_constructor(
791
+ search_filter, driver.provider
792
+ )
793
+
794
+ if group_ids is not None:
795
+ filter_queries.append('n.group_id IN $group_ids')
796
+ filter_queries.append('origin.group_id IN $group_ids')
797
+ filter_params['group_ids'] = group_ids
798
+
799
+ filter_query = ''
800
+ if filter_queries:
801
+ filter_query = ' AND ' + (' AND '.join(filter_queries))
409
802
 
410
- query = (
803
+ match_queries = [
411
804
  f"""
805
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
806
+ MATCH (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
807
+ WHERE n.group_id = origin.group_id
808
+ """
809
+ ]
810
+
811
+ if driver.provider == GraphProvider.NEPTUNE:
812
+ match_queries = [
813
+ f"""
412
814
  UNWIND $bfs_origin_node_uuids AS origin_uuid
413
- MATCH (origin:Entity|Episodic {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
815
+ MATCH (origin {{uuid: origin_uuid}})-[e:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
816
+ WHERE origin:Entity OR origin.Episode
817
+ AND n.group_id = origin.group_id
818
+ """
819
+ ]
820
+
821
+ if driver.provider == GraphProvider.KUZU:
822
+ depth = bfs_max_depth * 2
823
+ match_queries = [
824
+ """
825
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
826
+ MATCH (origin:Episodic {uuid: origin_uuid})-[:MENTIONS]->(n:Entity)
414
827
  WHERE n.group_id = origin.group_id
415
- AND origin.group_id IN $group_ids
416
- """
417
- + filter_query
418
- + """
419
- RETURN
420
- """
421
- + ENTITY_NODE_RETURN
422
- + """
423
- LIMIT $limit
424
- """
425
- )
828
+ """,
829
+ f"""
830
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
831
+ MATCH (origin:Entity {{uuid: origin_uuid}})-[:RELATES_TO*2..{depth}]->(n:Entity)
832
+ WHERE n.group_id = origin.group_id
833
+ """,
834
+ ]
835
+ if bfs_max_depth > 1:
836
+ depth = (bfs_max_depth - 1) * 2
837
+ match_queries.append(f"""
838
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
839
+ MATCH (origin:Episodic {{uuid: origin_uuid}})-[:MENTIONS]->(:Entity)-[:RELATES_TO*2..{depth}]->(n:Entity)
840
+ WHERE n.group_id = origin.group_id
841
+ """)
842
+
843
+ records = []
844
+ for match_query in match_queries:
845
+ sub_records, _, _ = await driver.execute_query(
846
+ match_query
847
+ + filter_query
848
+ + """
849
+ RETURN
850
+ """
851
+ + get_entity_node_return_query(driver.provider)
852
+ + """
853
+ LIMIT $limit
854
+ """,
855
+ bfs_origin_node_uuids=bfs_origin_node_uuids,
856
+ limit=limit,
857
+ routing_='r',
858
+ **filter_params,
859
+ )
860
+ records.extend(sub_records)
426
861
 
427
- records, _, _ = await driver.execute_query(
428
- query,
429
- bfs_origin_node_uuids=bfs_origin_node_uuids,
430
- group_ids=group_ids,
431
- limit=limit,
432
- routing_='r',
433
- **filter_params,
434
- )
435
- nodes = [get_entity_node_from_record(record) for record in records]
862
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
436
863
 
437
864
  return nodes
438
865
 
@@ -443,35 +870,84 @@ async def episode_fulltext_search(
443
870
  _search_filter: SearchFilters,
444
871
  group_ids: list[str] | None = None,
445
872
  limit=RELEVANT_SCHEMA_LIMIT,
873
+ use_local_indexes: bool = False,
446
874
  ) -> list[EpisodicNode]:
447
875
  # BM25 search to get top episodes
448
- fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
876
+ fuzzy_query = fulltext_query(query, group_ids, driver)
449
877
  if fuzzy_query == '':
450
878
  return []
451
879
 
452
- query = (
453
- get_nodes_query(driver.provider, 'episode_content', '$query')
454
- + """
455
- YIELD node AS episode, score
456
- MATCH (e:Episodic)
457
- WHERE e.uuid = episode.uuid
458
- AND e.group_id IN $group_ids
459
- RETURN
460
- """
461
- + EPISODIC_NODE_RETURN
462
- + """
463
- ORDER BY score DESC
464
- LIMIT $limit
465
- """
466
- )
880
+ filter_params: dict[str, Any] = {}
881
+ group_filter_query: LiteralString = ''
882
+ if group_ids is not None:
883
+ group_filter_query += '\nAND e.group_id IN $group_ids'
884
+ filter_params['group_ids'] = group_ids
885
+
886
+ if driver.provider == GraphProvider.NEPTUNE:
887
+ res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
888
+ if res['hits']['total']['value'] > 0:
889
+ # Calculate Cosine similarity then return the edge ids
890
+ input_ids = []
891
+ for r in res['hits']['hits']:
892
+ input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
893
+
894
+ # Match the edge ides and return the values
895
+ query = """
896
+ UNWIND $ids as i
897
+ MATCH (e:Episodic)
898
+ WHERE e.uuid=i.id
899
+ RETURN
900
+ e.content AS content,
901
+ e.created_at AS created_at,
902
+ e.valid_at AS valid_at,
903
+ e.uuid AS uuid,
904
+ e.name AS name,
905
+ e.group_id AS group_id,
906
+ e.source_description AS source_description,
907
+ e.source AS source,
908
+ e.entity_edges AS entity_edges
909
+ ORDER BY i.score DESC
910
+ LIMIT $limit
911
+ """
912
+ records, _, _ = await driver.execute_query(
913
+ query,
914
+ ids=input_ids,
915
+ query=fuzzy_query,
916
+ limit=limit,
917
+ routing_='r',
918
+ **filter_params,
919
+ )
920
+ else:
921
+ return []
922
+ else:
923
+ index_name = (
924
+ 'episode_content'
925
+ if not use_local_indexes
926
+ else 'episode_content_'
927
+ + (group_ids[0].replace('-', '') if group_ids is not None else '')
928
+ )
929
+ query = (
930
+ get_nodes_query(index_name, '$query', limit=limit, provider=driver.provider)
931
+ + """
932
+ YIELD node AS episode, score
933
+ MATCH (e:Episodic)
934
+ WHERE e.uuid = episode.uuid
935
+ """
936
+ + group_filter_query
937
+ + """
938
+ RETURN
939
+ """
940
+ + EPISODIC_NODE_RETURN
941
+ + """
942
+ ORDER BY score DESC
943
+ LIMIT $limit
944
+ """
945
+ )
946
+
947
+ records, _, _ = await driver.execute_query(
948
+ query, query=fuzzy_query, limit=limit, routing_='r', **filter_params
949
+ )
467
950
 
468
- records, _, _ = await driver.execute_query(
469
- query,
470
- query=fuzzy_query,
471
- group_ids=group_ids,
472
- limit=limit,
473
- routing_='r',
474
- )
475
951
  episodes = [get_episodic_node_from_record(record) for record in records]
476
952
 
477
953
  return episodes
@@ -484,31 +960,75 @@ async def community_fulltext_search(
484
960
  limit=RELEVANT_SCHEMA_LIMIT,
485
961
  ) -> list[CommunityNode]:
486
962
  # BM25 search to get top communities
487
- fuzzy_query = fulltext_query(query, group_ids, driver.fulltext_syntax)
963
+ fuzzy_query = fulltext_query(query, group_ids, driver)
488
964
  if fuzzy_query == '':
489
965
  return []
490
966
 
491
- query = (
492
- get_nodes_query(driver.provider, 'community_name', '$query')
493
- + """
494
- YIELD node AS n, score
495
- WHERE n.group_id IN $group_ids
496
- RETURN
497
- """
498
- + COMMUNITY_NODE_RETURN
499
- + """
500
- ORDER BY score DESC
501
- LIMIT $limit
502
- """
503
- )
967
+ filter_params: dict[str, Any] = {}
968
+ group_filter_query: LiteralString = ''
969
+ if group_ids is not None:
970
+ group_filter_query = 'WHERE c.group_id IN $group_ids'
971
+ filter_params['group_ids'] = group_ids
972
+
973
+ yield_query = 'YIELD node AS c, score'
974
+ if driver.provider == GraphProvider.KUZU:
975
+ yield_query = 'WITH node AS c, score'
976
+
977
+ if driver.provider == GraphProvider.NEPTUNE:
978
+ res = driver.run_aoss_query('community_name', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
979
+ if res['hits']['total']['value'] > 0:
980
+ # Calculate Cosine similarity then return the edge ids
981
+ input_ids = []
982
+ for r in res['hits']['hits']:
983
+ input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
984
+
985
+ # Match the edge ides and return the values
986
+ query = """
987
+ UNWIND $ids as i
988
+ MATCH (comm:Community)
989
+ WHERE comm.uuid=i.id
990
+ RETURN
991
+ comm.uuid AS uuid,
992
+ comm.group_id AS group_id,
993
+ comm.name AS name,
994
+ comm.created_at AS created_at,
995
+ comm.summary AS summary,
996
+ [x IN split(comm.name_embedding, ",") | toFloat(x)]AS name_embedding
997
+ ORDER BY i.score DESC
998
+ LIMIT $limit
999
+ """
1000
+ records, _, _ = await driver.execute_query(
1001
+ query,
1002
+ ids=input_ids,
1003
+ query=fuzzy_query,
1004
+ limit=limit,
1005
+ routing_='r',
1006
+ **filter_params,
1007
+ )
1008
+ else:
1009
+ return []
1010
+ else:
1011
+ query = (
1012
+ get_nodes_query('community_name', '$query', limit=limit, provider=driver.provider)
1013
+ + yield_query
1014
+ + """
1015
+ WITH c, score
1016
+ """
1017
+ + group_filter_query
1018
+ + """
1019
+ RETURN
1020
+ """
1021
+ + COMMUNITY_NODE_RETURN
1022
+ + """
1023
+ ORDER BY score DESC
1024
+ LIMIT $limit
1025
+ """
1026
+ )
1027
+
1028
+ records, _, _ = await driver.execute_query(
1029
+ query, query=fuzzy_query, limit=limit, routing_='r', **filter_params
1030
+ )
504
1031
 
505
- records, _, _ = await driver.execute_query(
506
- query,
507
- query=fuzzy_query,
508
- group_ids=group_ids,
509
- limit=limit,
510
- routing_='r',
511
- )
512
1032
  communities = [get_community_node_from_record(record) for record in records]
513
1033
 
514
1034
  return communities
@@ -526,38 +1046,101 @@ async def community_similarity_search(
526
1046
 
527
1047
  group_filter_query: LiteralString = ''
528
1048
  if group_ids is not None:
529
- group_filter_query += 'WHERE n.group_id IN $group_ids'
1049
+ group_filter_query += ' WHERE c.group_id IN $group_ids'
530
1050
  query_params['group_ids'] = group_ids
531
1051
 
532
- query = (
533
- RUNTIME_QUERY
534
- + """
535
- MATCH (n:Community)
536
- """
537
- + group_filter_query
538
- + """
539
- WITH n,
540
- """
541
- + get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
542
- + """ AS score
543
- WHERE score > $min_score
544
- RETURN
545
- """
546
- + COMMUNITY_NODE_RETURN
547
- + """
548
- ORDER BY score DESC
549
- LIMIT $limit
550
- """
551
- )
1052
+ if driver.provider == GraphProvider.NEPTUNE:
1053
+ query = (
1054
+ RUNTIME_QUERY
1055
+ + """
1056
+ MATCH (n:Community)
1057
+ """
1058
+ + group_filter_query
1059
+ + """
1060
+ RETURN DISTINCT id(n) as id, n.name_embedding as embedding
1061
+ """
1062
+ )
1063
+ resp, header, _ = await driver.execute_query(
1064
+ query,
1065
+ search_vector=search_vector,
1066
+ limit=limit,
1067
+ min_score=min_score,
1068
+ routing_='r',
1069
+ **query_params,
1070
+ )
1071
+
1072
+ if len(resp) > 0:
1073
+ # Calculate Cosine similarity then return the edge ids
1074
+ input_ids = []
1075
+ for r in resp:
1076
+ if r['embedding']:
1077
+ score = calculate_cosine_similarity(
1078
+ search_vector, list(map(float, r['embedding'].split(',')))
1079
+ )
1080
+ if score > min_score:
1081
+ input_ids.append({'id': r['id'], 'score': score})
1082
+
1083
+ # Match the edge ides and return the values
1084
+ query = """
1085
+ UNWIND $ids as i
1086
+ MATCH (comm:Community)
1087
+ WHERE id(comm)=i.id
1088
+ RETURN
1089
+ comm.uuid As uuid,
1090
+ comm.group_id AS group_id,
1091
+ comm.name AS name,
1092
+ comm.created_at AS created_at,
1093
+ comm.summary AS summary,
1094
+ comm.name_embedding AS name_embedding
1095
+ ORDER BY i.score DESC
1096
+ LIMIT $limit
1097
+ """
1098
+ records, header, _ = await driver.execute_query(
1099
+ query,
1100
+ ids=input_ids,
1101
+ search_vector=search_vector,
1102
+ limit=limit,
1103
+ min_score=min_score,
1104
+ routing_='r',
1105
+ **query_params,
1106
+ )
1107
+ else:
1108
+ return []
1109
+ else:
1110
+ search_vector_var = '$search_vector'
1111
+ if driver.provider == GraphProvider.KUZU:
1112
+ search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
1113
+
1114
+ query = (
1115
+ RUNTIME_QUERY
1116
+ + """
1117
+ MATCH (c:Community)
1118
+ """
1119
+ + group_filter_query
1120
+ + """
1121
+ WITH c,
1122
+ """
1123
+ + get_vector_cosine_func_query('c.name_embedding', search_vector_var, driver.provider)
1124
+ + """ AS score
1125
+ WHERE score > $min_score
1126
+ RETURN
1127
+ """
1128
+ + COMMUNITY_NODE_RETURN
1129
+ + """
1130
+ ORDER BY score DESC
1131
+ LIMIT $limit
1132
+ """
1133
+ )
1134
+
1135
+ records, _, _ = await driver.execute_query(
1136
+ query,
1137
+ search_vector=search_vector,
1138
+ limit=limit,
1139
+ min_score=min_score,
1140
+ routing_='r',
1141
+ **query_params,
1142
+ )
552
1143
 
553
- records, _, _ = await driver.execute_query(
554
- query,
555
- search_vector=search_vector,
556
- limit=limit,
557
- min_score=min_score,
558
- routing_='r',
559
- **query_params,
560
- )
561
1144
  communities = [get_community_node_from_record(record) for record in records]
562
1145
 
563
1146
  return communities
@@ -648,67 +1231,129 @@ async def get_relevant_nodes(
648
1231
  return []
649
1232
 
650
1233
  group_id = nodes[0].group_id
651
-
652
- # vector similarity search over entity names
653
- query_params: dict[str, Any] = {}
654
-
655
- filter_query, filter_params = node_search_filter_query_constructor(search_filter)
656
- query_params.update(filter_params)
657
-
658
- query = (
659
- RUNTIME_QUERY
660
- + """
661
- UNWIND $nodes AS node
662
- MATCH (n:Entity {group_id: $group_id})
663
- """
664
- + filter_query
665
- + """
666
- WITH node, n, """
667
- + get_vector_cosine_func_query('n.name_embedding', 'node.name_embedding', driver.provider)
668
- + """ AS score
669
- WHERE score > $min_score
670
- WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
671
- """
672
- + get_nodes_query(driver.provider, 'node_name_and_summary', 'node.fulltext_query')
673
- + """
674
- YIELD node AS m
675
- WHERE m.group_id = $group_id
676
- WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
677
-
678
- WITH node,
679
- top_vector_nodes,
680
- [m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
681
-
682
- WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
683
-
684
- UNWIND combined_nodes AS combined_node
685
- WITH node, collect(DISTINCT combined_node) AS deduped_nodes
686
-
687
- RETURN
688
- node.uuid AS search_node_uuid,
689
- [x IN deduped_nodes | {
690
- uuid: x.uuid,
691
- name: x.name,
692
- name_embedding: x.name_embedding,
693
- group_id: x.group_id,
694
- created_at: x.created_at,
695
- summary: x.summary,
696
- labels: labels(x),
697
- attributes: properties(x)
698
- }] AS matches
699
- """
700
- )
701
-
702
1234
  query_nodes = [
703
1235
  {
704
1236
  'uuid': node.uuid,
705
1237
  'name': node.name,
706
1238
  'name_embedding': node.name_embedding,
707
- 'fulltext_query': fulltext_query(node.name, [node.group_id], driver.fulltext_syntax),
1239
+ 'fulltext_query': fulltext_query(node.name, [node.group_id], driver),
708
1240
  }
709
1241
  for node in nodes
710
1242
  ]
711
1243
 
1244
+ filter_queries, filter_params = node_search_filter_query_constructor(
1245
+ search_filter, driver.provider
1246
+ )
1247
+
1248
+ filter_query = ''
1249
+ if filter_queries:
1250
+ filter_query = 'WHERE ' + (' AND '.join(filter_queries))
1251
+
1252
+ if driver.provider == GraphProvider.KUZU:
1253
+ embedding_size = len(nodes[0].name_embedding) if nodes[0].name_embedding is not None else 0
1254
+ if embedding_size == 0:
1255
+ return []
1256
+
1257
+ # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
1258
+ query = (
1259
+ RUNTIME_QUERY
1260
+ + """
1261
+ UNWIND $nodes AS node
1262
+ MATCH (n:Entity {group_id: $group_id})
1263
+ """
1264
+ + filter_query
1265
+ + """
1266
+ WITH node, n, """
1267
+ + get_vector_cosine_func_query(
1268
+ 'n.name_embedding',
1269
+ f'CAST(node.name_embedding AS FLOAT[{embedding_size}])',
1270
+ driver.provider,
1271
+ )
1272
+ + """ AS score
1273
+ WHERE score > $min_score
1274
+ WITH node, collect(n)[:$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
1275
+ """
1276
+ + get_nodes_query(
1277
+ 'node_name_and_summary',
1278
+ 'node.fulltext_query',
1279
+ limit=limit,
1280
+ provider=driver.provider,
1281
+ )
1282
+ + """
1283
+ WITH node AS m
1284
+ WHERE m.group_id = $group_id AND NOT m.uuid IN vector_node_uuids
1285
+ WITH node, top_vector_nodes, collect(m) AS fulltext_nodes
1286
+
1287
+ WITH node, list_concat(top_vector_nodes, fulltext_nodes) AS combined_nodes
1288
+
1289
+ UNWIND combined_nodes AS x
1290
+ WITH node, collect(DISTINCT {
1291
+ uuid: x.uuid,
1292
+ name: x.name,
1293
+ name_embedding: x.name_embedding,
1294
+ group_id: x.group_id,
1295
+ created_at: x.created_at,
1296
+ summary: x.summary,
1297
+ labels: x.labels,
1298
+ attributes: x.attributes
1299
+ }) AS matches
1300
+
1301
+ RETURN
1302
+ node.uuid AS search_node_uuid, matches
1303
+ """
1304
+ )
1305
+ else:
1306
+ query = (
1307
+ RUNTIME_QUERY
1308
+ + """
1309
+ UNWIND $nodes AS node
1310
+ MATCH (n:Entity {group_id: $group_id})
1311
+ """
1312
+ + filter_query
1313
+ + """
1314
+ WITH node, n, """
1315
+ + get_vector_cosine_func_query(
1316
+ 'n.name_embedding', 'node.name_embedding', driver.provider
1317
+ )
1318
+ + """ AS score
1319
+ WHERE score > $min_score
1320
+ WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
1321
+ """
1322
+ + get_nodes_query(
1323
+ 'node_name_and_summary',
1324
+ 'node.fulltext_query',
1325
+ limit=limit,
1326
+ provider=driver.provider,
1327
+ )
1328
+ + """
1329
+ YIELD node AS m
1330
+ WHERE m.group_id = $group_id
1331
+ WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
1332
+
1333
+ WITH node,
1334
+ top_vector_nodes,
1335
+ [m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
1336
+
1337
+ WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
1338
+
1339
+ UNWIND combined_nodes AS combined_node
1340
+ WITH node, collect(DISTINCT combined_node) AS deduped_nodes
1341
+
1342
+ RETURN
1343
+ node.uuid AS search_node_uuid,
1344
+ [x IN deduped_nodes | {
1345
+ uuid: x.uuid,
1346
+ name: x.name,
1347
+ name_embedding: x.name_embedding,
1348
+ group_id: x.group_id,
1349
+ created_at: x.created_at,
1350
+ summary: x.summary,
1351
+ labels: labels(x),
1352
+ attributes: properties(x)
1353
+ }] AS matches
1354
+ """
1355
+ )
1356
+
712
1357
  results, _, _ = await driver.execute_query(
713
1358
  query,
714
1359
  nodes=query_nodes,
@@ -716,12 +1361,12 @@ async def get_relevant_nodes(
716
1361
  limit=limit,
717
1362
  min_score=min_score,
718
1363
  routing_='r',
719
- **query_params,
1364
+ **filter_params,
720
1365
  )
721
1366
 
722
1367
  relevant_nodes_dict: dict[str, list[EntityNode]] = {
723
1368
  result['search_node_uuid']: [
724
- get_entity_node_from_record(record) for record in result['matches']
1369
+ get_entity_node_from_record(record, driver.provider) for record in result['matches']
725
1370
  ]
726
1371
  for result in results
727
1372
  }
@@ -741,25 +1386,53 @@ async def get_relevant_edges(
741
1386
  if len(edges) == 0:
742
1387
  return []
743
1388
 
744
- query_params: dict[str, Any] = {}
1389
+ filter_queries, filter_params = edge_search_filter_query_constructor(
1390
+ search_filter, driver.provider
1391
+ )
745
1392
 
746
- filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
747
- query_params.update(filter_params)
1393
+ filter_query = ''
1394
+ if filter_queries:
1395
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
1396
+
1397
+ if driver.provider == GraphProvider.NEPTUNE:
1398
+ query = (
1399
+ RUNTIME_QUERY
1400
+ + """
1401
+ UNWIND $edges AS edge
1402
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
1403
+ """
1404
+ + filter_query
1405
+ + """
1406
+ WITH e, edge
1407
+ RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding, edge.uuid as search_edge_uuid,
1408
+ edge.fact_embedding as target_embedding
1409
+ """
1410
+ )
1411
+ resp, _, _ = await driver.execute_query(
1412
+ query,
1413
+ edges=[edge.model_dump() for edge in edges],
1414
+ limit=limit,
1415
+ min_score=min_score,
1416
+ routing_='r',
1417
+ **filter_params,
1418
+ )
748
1419
 
749
- query = (
750
- RUNTIME_QUERY
751
- + """
752
- UNWIND $edges AS edge
753
- MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
754
- """
755
- + filter_query
756
- + """
757
- WITH e, edge, """
758
- + get_vector_cosine_func_query('e.fact_embedding', 'edge.fact_embedding', driver.provider)
759
- + """ AS score
760
- WHERE score > $min_score
761
- WITH edge, e, score
762
- ORDER BY score DESC
1420
+ # Calculate Cosine similarity then return the edge ids
1421
+ input_ids = []
1422
+ for r in resp:
1423
+ score = calculate_cosine_similarity(
1424
+ list(map(float, r['source_embedding'].split(','))), r['target_embedding']
1425
+ )
1426
+ if score > min_score:
1427
+ input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})
1428
+
1429
+ # Match the edge ides and return the values
1430
+ query = """
1431
+ UNWIND $ids AS edge
1432
+ MATCH ()-[e]->()
1433
+ WHERE id(e) = edge.id
1434
+ WITH edge, e
1435
+ ORDER BY edge.score DESC
763
1436
  RETURN edge.uuid AS search_edge_uuid,
764
1437
  collect({
765
1438
  uuid: e.uuid,
@@ -769,28 +1442,119 @@ async def get_relevant_edges(
769
1442
  name: e.name,
770
1443
  group_id: e.group_id,
771
1444
  fact: e.fact,
772
- fact_embedding: e.fact_embedding,
773
- episodes: e.episodes,
1445
+ fact_embedding: [x IN split(e.fact_embedding, ",") | toFloat(x)],
1446
+ episodes: split(e.episodes, ","),
774
1447
  expired_at: e.expired_at,
775
1448
  valid_at: e.valid_at,
776
1449
  invalid_at: e.invalid_at,
777
1450
  attributes: properties(e)
778
1451
  })[..$limit] AS matches
779
- """
780
- )
781
-
782
- results, _, _ = await driver.execute_query(
783
- query,
784
- edges=[edge.model_dump() for edge in edges],
785
- limit=limit,
786
- min_score=min_score,
787
- routing_='r',
788
- **query_params,
789
- )
1452
+ """
1453
+
1454
+ results, _, _ = await driver.execute_query(
1455
+ query,
1456
+ ids=input_ids,
1457
+ edges=[edge.model_dump() for edge in edges],
1458
+ limit=limit,
1459
+ min_score=min_score,
1460
+ routing_='r',
1461
+ **filter_params,
1462
+ )
1463
+ else:
1464
+ if driver.provider == GraphProvider.KUZU:
1465
+ embedding_size = (
1466
+ len(edges[0].fact_embedding) if edges[0].fact_embedding is not None else 0
1467
+ )
1468
+ if embedding_size == 0:
1469
+ return []
1470
+
1471
+ query = (
1472
+ RUNTIME_QUERY
1473
+ + """
1474
+ UNWIND $edges AS edge
1475
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
1476
+ """
1477
+ + filter_query
1478
+ + """
1479
+ WITH e, edge, n, m, """
1480
+ + get_vector_cosine_func_query(
1481
+ 'e.fact_embedding',
1482
+ f'CAST(edge.fact_embedding AS FLOAT[{embedding_size}])',
1483
+ driver.provider,
1484
+ )
1485
+ + """ AS score
1486
+ WHERE score > $min_score
1487
+ WITH e, edge, n, m, score
1488
+ ORDER BY score DESC
1489
+ LIMIT $limit
1490
+ RETURN
1491
+ edge.uuid AS search_edge_uuid,
1492
+ collect({
1493
+ uuid: e.uuid,
1494
+ source_node_uuid: n.uuid,
1495
+ target_node_uuid: m.uuid,
1496
+ created_at: e.created_at,
1497
+ name: e.name,
1498
+ group_id: e.group_id,
1499
+ fact: e.fact,
1500
+ fact_embedding: e.fact_embedding,
1501
+ episodes: e.episodes,
1502
+ expired_at: e.expired_at,
1503
+ valid_at: e.valid_at,
1504
+ invalid_at: e.invalid_at,
1505
+ attributes: e.attributes
1506
+ }) AS matches
1507
+ """
1508
+ )
1509
+ else:
1510
+ query = (
1511
+ RUNTIME_QUERY
1512
+ + """
1513
+ UNWIND $edges AS edge
1514
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
1515
+ """
1516
+ + filter_query
1517
+ + """
1518
+ WITH e, edge, """
1519
+ + get_vector_cosine_func_query(
1520
+ 'e.fact_embedding', 'edge.fact_embedding', driver.provider
1521
+ )
1522
+ + """ AS score
1523
+ WHERE score > $min_score
1524
+ WITH edge, e, score
1525
+ ORDER BY score DESC
1526
+ RETURN
1527
+ edge.uuid AS search_edge_uuid,
1528
+ collect({
1529
+ uuid: e.uuid,
1530
+ source_node_uuid: startNode(e).uuid,
1531
+ target_node_uuid: endNode(e).uuid,
1532
+ created_at: e.created_at,
1533
+ name: e.name,
1534
+ group_id: e.group_id,
1535
+ fact: e.fact,
1536
+ fact_embedding: e.fact_embedding,
1537
+ episodes: e.episodes,
1538
+ expired_at: e.expired_at,
1539
+ valid_at: e.valid_at,
1540
+ invalid_at: e.invalid_at,
1541
+ attributes: properties(e)
1542
+ })[..$limit] AS matches
1543
+ """
1544
+ )
1545
+
1546
+ results, _, _ = await driver.execute_query(
1547
+ query,
1548
+ edges=[edge.model_dump() for edge in edges],
1549
+ limit=limit,
1550
+ min_score=min_score,
1551
+ routing_='r',
1552
+ **filter_params,
1553
+ )
790
1554
 
791
1555
  relevant_edges_dict: dict[str, list[EntityEdge]] = {
792
1556
  result['search_edge_uuid']: [
793
- get_entity_edge_from_record(record) for record in result['matches']
1557
+ get_entity_edge_from_record(record, driver.provider) for record in result['matches']
794
1558
  ]
795
1559
  for result in results
796
1560
  }
@@ -810,26 +1574,55 @@ async def get_edge_invalidation_candidates(
810
1574
  if len(edges) == 0:
811
1575
  return []
812
1576
 
813
- query_params: dict[str, Any] = {}
1577
+ filter_queries, filter_params = edge_search_filter_query_constructor(
1578
+ search_filter, driver.provider
1579
+ )
814
1580
 
815
- filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
816
- query_params.update(filter_params)
1581
+ filter_query = ''
1582
+ if filter_queries:
1583
+ filter_query = ' AND ' + (' AND '.join(filter_queries))
1584
+
1585
+ if driver.provider == GraphProvider.NEPTUNE:
1586
+ query = (
1587
+ RUNTIME_QUERY
1588
+ + """
1589
+ UNWIND $edges AS edge
1590
+ MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
1591
+ WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
1592
+ """
1593
+ + filter_query
1594
+ + """
1595
+ WITH e, edge
1596
+ RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding,
1597
+ edge.fact_embedding as target_embedding,
1598
+ edge.uuid as search_edge_uuid
1599
+ """
1600
+ )
1601
+ resp, _, _ = await driver.execute_query(
1602
+ query,
1603
+ edges=[edge.model_dump() for edge in edges],
1604
+ limit=limit,
1605
+ min_score=min_score,
1606
+ routing_='r',
1607
+ **filter_params,
1608
+ )
817
1609
 
818
- query = (
819
- RUNTIME_QUERY
820
- + """
821
- UNWIND $edges AS edge
822
- MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
823
- WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
824
- """
825
- + filter_query
826
- + """
827
- WITH edge, e, """
828
- + get_vector_cosine_func_query('e.fact_embedding', 'edge.fact_embedding', driver.provider)
829
- + """ AS score
830
- WHERE score > $min_score
831
- WITH edge, e, score
832
- ORDER BY score DESC
1610
+ # Calculate Cosine similarity then return the edge ids
1611
+ input_ids = []
1612
+ for r in resp:
1613
+ score = calculate_cosine_similarity(
1614
+ list(map(float, r['source_embedding'].split(','))), r['target_embedding']
1615
+ )
1616
+ if score > min_score:
1617
+ input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})
1618
+
1619
+ # Match the edge ides and return the values
1620
+ query = """
1621
+ UNWIND $ids AS edge
1622
+ MATCH ()-[e]->()
1623
+ WHERE id(e) = edge.id
1624
+ WITH edge, e
1625
+ ORDER BY edge.score DESC
833
1626
  RETURN edge.uuid AS search_edge_uuid,
834
1627
  collect({
835
1628
  uuid: e.uuid,
@@ -839,27 +1632,119 @@ async def get_edge_invalidation_candidates(
839
1632
  name: e.name,
840
1633
  group_id: e.group_id,
841
1634
  fact: e.fact,
842
- fact_embedding: e.fact_embedding,
843
- episodes: e.episodes,
1635
+ fact_embedding: [x IN split(e.fact_embedding, ",") | toFloat(x)],
1636
+ episodes: split(e.episodes, ","),
844
1637
  expired_at: e.expired_at,
845
1638
  valid_at: e.valid_at,
846
1639
  invalid_at: e.invalid_at,
847
1640
  attributes: properties(e)
848
1641
  })[..$limit] AS matches
849
- """
850
- )
851
-
852
- results, _, _ = await driver.execute_query(
853
- query,
854
- edges=[edge.model_dump() for edge in edges],
855
- limit=limit,
856
- min_score=min_score,
857
- routing_='r',
858
- **query_params,
859
- )
1642
+ """
1643
+ results, _, _ = await driver.execute_query(
1644
+ query,
1645
+ ids=input_ids,
1646
+ edges=[edge.model_dump() for edge in edges],
1647
+ limit=limit,
1648
+ min_score=min_score,
1649
+ routing_='r',
1650
+ **filter_params,
1651
+ )
1652
+ else:
1653
+ if driver.provider == GraphProvider.KUZU:
1654
+ embedding_size = (
1655
+ len(edges[0].fact_embedding) if edges[0].fact_embedding is not None else 0
1656
+ )
1657
+ if embedding_size == 0:
1658
+ return []
1659
+
1660
+ query = (
1661
+ RUNTIME_QUERY
1662
+ + """
1663
+ UNWIND $edges AS edge
1664
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
1665
+ WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
1666
+ """
1667
+ + filter_query
1668
+ + """
1669
+ WITH edge, e, n, m, """
1670
+ + get_vector_cosine_func_query(
1671
+ 'e.fact_embedding',
1672
+ f'CAST(edge.fact_embedding AS FLOAT[{embedding_size}])',
1673
+ driver.provider,
1674
+ )
1675
+ + """ AS score
1676
+ WHERE score > $min_score
1677
+ WITH edge, e, n, m, score
1678
+ ORDER BY score DESC
1679
+ LIMIT $limit
1680
+ RETURN
1681
+ edge.uuid AS search_edge_uuid,
1682
+ collect({
1683
+ uuid: e.uuid,
1684
+ source_node_uuid: n.uuid,
1685
+ target_node_uuid: m.uuid,
1686
+ created_at: e.created_at,
1687
+ name: e.name,
1688
+ group_id: e.group_id,
1689
+ fact: e.fact,
1690
+ fact_embedding: e.fact_embedding,
1691
+ episodes: e.episodes,
1692
+ expired_at: e.expired_at,
1693
+ valid_at: e.valid_at,
1694
+ invalid_at: e.invalid_at,
1695
+ attributes: e.attributes
1696
+ }) AS matches
1697
+ """
1698
+ )
1699
+ else:
1700
+ query = (
1701
+ RUNTIME_QUERY
1702
+ + """
1703
+ UNWIND $edges AS edge
1704
+ MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
1705
+ WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
1706
+ """
1707
+ + filter_query
1708
+ + """
1709
+ WITH edge, e, """
1710
+ + get_vector_cosine_func_query(
1711
+ 'e.fact_embedding', 'edge.fact_embedding', driver.provider
1712
+ )
1713
+ + """ AS score
1714
+ WHERE score > $min_score
1715
+ WITH edge, e, score
1716
+ ORDER BY score DESC
1717
+ RETURN
1718
+ edge.uuid AS search_edge_uuid,
1719
+ collect({
1720
+ uuid: e.uuid,
1721
+ source_node_uuid: startNode(e).uuid,
1722
+ target_node_uuid: endNode(e).uuid,
1723
+ created_at: e.created_at,
1724
+ name: e.name,
1725
+ group_id: e.group_id,
1726
+ fact: e.fact,
1727
+ fact_embedding: e.fact_embedding,
1728
+ episodes: e.episodes,
1729
+ expired_at: e.expired_at,
1730
+ valid_at: e.valid_at,
1731
+ invalid_at: e.invalid_at,
1732
+ attributes: properties(e)
1733
+ })[..$limit] AS matches
1734
+ """
1735
+ )
1736
+
1737
+ results, _, _ = await driver.execute_query(
1738
+ query,
1739
+ edges=[edge.model_dump() for edge in edges],
1740
+ limit=limit,
1741
+ min_score=min_score,
1742
+ routing_='r',
1743
+ **filter_params,
1744
+ )
860
1745
  invalidation_edges_dict: dict[str, list[EntityEdge]] = {
861
1746
  result['search_edge_uuid']: [
862
- get_entity_edge_from_record(record) for record in result['matches']
1747
+ get_entity_edge_from_record(record, driver.provider) for record in result['matches']
863
1748
  ]
864
1749
  for result in results
865
1750
  }
@@ -898,13 +1783,21 @@ async def node_distance_reranker(
898
1783
  filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
899
1784
  scores: dict[str, float] = {center_node_uuid: 0.0}
900
1785
 
901
- # Find the shortest path to center node
902
- results, header, _ = await driver.execute_query(
903
- """
1786
+ query = """
1787
+ UNWIND $node_uuids AS node_uuid
1788
+ MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
1789
+ RETURN 1 AS score, node_uuid AS uuid
1790
+ """
1791
+ if driver.provider == GraphProvider.KUZU:
1792
+ query = """
904
1793
  UNWIND $node_uuids AS node_uuid
905
- MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
1794
+ MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(n:Entity {uuid: node_uuid})
906
1795
  RETURN 1 AS score, node_uuid AS uuid
907
- """,
1796
+ """
1797
+
1798
+ # Find the shortest path to center node
1799
+ results, header, _ = await driver.execute_query(
1800
+ query,
908
1801
  node_uuids=filtered_uuids,
909
1802
  center_uuid=center_node_uuid,
910
1803
  routing_='r',
@@ -955,6 +1848,10 @@ async def episode_mentions_reranker(
955
1848
  for result in results:
956
1849
  scores[result['uuid']] = result['score']
957
1850
 
1851
+ for uuid in sorted_uuids:
1852
+ if uuid not in scores:
1853
+ scores[uuid] = float('inf')
1854
+
958
1855
  # rerank on shortest distance
959
1856
  sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
960
1857
 
@@ -1007,14 +1904,24 @@ def maximal_marginal_relevance(
1007
1904
  async def get_embeddings_for_nodes(
1008
1905
  driver: GraphDriver, nodes: list[EntityNode]
1009
1906
  ) -> dict[str, list[float]]:
1010
- results, _, _ = await driver.execute_query(
1907
+ if driver.provider == GraphProvider.NEPTUNE:
1908
+ query = """
1909
+ MATCH (n:Entity)
1910
+ WHERE n.uuid IN $node_uuids
1911
+ RETURN DISTINCT
1912
+ n.uuid AS uuid,
1913
+ split(n.name_embedding, ",") AS name_embedding
1011
1914
  """
1915
+ else:
1916
+ query = """
1012
1917
  MATCH (n:Entity)
1013
1918
  WHERE n.uuid IN $node_uuids
1014
1919
  RETURN DISTINCT
1015
1920
  n.uuid AS uuid,
1016
1921
  n.name_embedding AS name_embedding
1017
- """,
1922
+ """
1923
+ results, _, _ = await driver.execute_query(
1924
+ query,
1018
1925
  node_uuids=[node.uuid for node in nodes],
1019
1926
  routing_='r',
1020
1927
  )
@@ -1032,14 +1939,24 @@ async def get_embeddings_for_nodes(
1032
1939
  async def get_embeddings_for_communities(
1033
1940
  driver: GraphDriver, communities: list[CommunityNode]
1034
1941
  ) -> dict[str, list[float]]:
1035
- results, _, _ = await driver.execute_query(
1942
+ if driver.provider == GraphProvider.NEPTUNE:
1943
+ query = """
1944
+ MATCH (c:Community)
1945
+ WHERE c.uuid IN $community_uuids
1946
+ RETURN DISTINCT
1947
+ c.uuid AS uuid,
1948
+ split(c.name_embedding, ",") AS name_embedding
1036
1949
  """
1950
+ else:
1951
+ query = """
1037
1952
  MATCH (c:Community)
1038
1953
  WHERE c.uuid IN $community_uuids
1039
1954
  RETURN DISTINCT
1040
1955
  c.uuid AS uuid,
1041
1956
  c.name_embedding AS name_embedding
1042
- """,
1957
+ """
1958
+ results, _, _ = await driver.execute_query(
1959
+ query,
1043
1960
  community_uuids=[community.uuid for community in communities],
1044
1961
  routing_='r',
1045
1962
  )
@@ -1057,14 +1974,34 @@ async def get_embeddings_for_communities(
1057
1974
  async def get_embeddings_for_edges(
1058
1975
  driver: GraphDriver, edges: list[EntityEdge]
1059
1976
  ) -> dict[str, list[float]]:
1060
- results, _, _ = await driver.execute_query(
1061
- """
1977
+ if driver.provider == GraphProvider.NEPTUNE:
1978
+ query = """
1062
1979
  MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
1063
1980
  WHERE e.uuid IN $edge_uuids
1981
+ RETURN DISTINCT
1982
+ e.uuid AS uuid,
1983
+ split(e.fact_embedding, ",") AS fact_embedding
1984
+ """
1985
+ else:
1986
+ match_query = """
1987
+ MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
1988
+ """
1989
+ if driver.provider == GraphProvider.KUZU:
1990
+ match_query = """
1991
+ MATCH (n:Entity)-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(m:Entity)
1992
+ """
1993
+
1994
+ query = (
1995
+ match_query
1996
+ + """
1997
+ WHERE e.uuid IN $edge_uuids
1064
1998
  RETURN DISTINCT
1065
1999
  e.uuid AS uuid,
1066
2000
  e.fact_embedding AS fact_embedding
1067
- """,
2001
+ """
2002
+ )
2003
+ results, _, _ = await driver.execute_query(
2004
+ query,
1068
2005
  edge_uuids=[edge.uuid for edge in edges],
1069
2006
  routing_='r',
1070
2007
  )