graphiti-core 0.17.4__py3-none-any.whl → 0.24.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58)
  1. graphiti_core/cross_encoder/gemini_reranker_client.py +1 -1
  2. graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
  3. graphiti_core/decorators.py +110 -0
  4. graphiti_core/driver/driver.py +62 -2
  5. graphiti_core/driver/falkordb_driver.py +215 -23
  6. graphiti_core/driver/graph_operations/graph_operations.py +191 -0
  7. graphiti_core/driver/kuzu_driver.py +182 -0
  8. graphiti_core/driver/neo4j_driver.py +61 -8
  9. graphiti_core/driver/neptune_driver.py +305 -0
  10. graphiti_core/driver/search_interface/search_interface.py +89 -0
  11. graphiti_core/edges.py +264 -132
  12. graphiti_core/embedder/azure_openai.py +10 -3
  13. graphiti_core/embedder/client.py +2 -1
  14. graphiti_core/graph_queries.py +114 -101
  15. graphiti_core/graphiti.py +582 -255
  16. graphiti_core/graphiti_types.py +2 -0
  17. graphiti_core/helpers.py +21 -14
  18. graphiti_core/llm_client/anthropic_client.py +142 -52
  19. graphiti_core/llm_client/azure_openai_client.py +57 -19
  20. graphiti_core/llm_client/client.py +83 -21
  21. graphiti_core/llm_client/config.py +1 -1
  22. graphiti_core/llm_client/gemini_client.py +75 -57
  23. graphiti_core/llm_client/openai_base_client.py +94 -50
  24. graphiti_core/llm_client/openai_client.py +28 -8
  25. graphiti_core/llm_client/openai_generic_client.py +91 -56
  26. graphiti_core/models/edges/edge_db_queries.py +259 -35
  27. graphiti_core/models/nodes/node_db_queries.py +311 -32
  28. graphiti_core/nodes.py +388 -164
  29. graphiti_core/prompts/dedupe_edges.py +42 -31
  30. graphiti_core/prompts/dedupe_nodes.py +56 -39
  31. graphiti_core/prompts/eval.py +4 -4
  32. graphiti_core/prompts/extract_edges.py +23 -14
  33. graphiti_core/prompts/extract_nodes.py +73 -32
  34. graphiti_core/prompts/prompt_helpers.py +39 -0
  35. graphiti_core/prompts/snippets.py +29 -0
  36. graphiti_core/prompts/summarize_nodes.py +23 -25
  37. graphiti_core/search/search.py +154 -74
  38. graphiti_core/search/search_config.py +39 -4
  39. graphiti_core/search/search_filters.py +109 -31
  40. graphiti_core/search/search_helpers.py +5 -6
  41. graphiti_core/search/search_utils.py +1360 -473
  42. graphiti_core/tracer.py +193 -0
  43. graphiti_core/utils/bulk_utils.py +216 -90
  44. graphiti_core/utils/datetime_utils.py +13 -0
  45. graphiti_core/utils/maintenance/community_operations.py +62 -38
  46. graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
  47. graphiti_core/utils/maintenance/edge_operations.py +286 -126
  48. graphiti_core/utils/maintenance/graph_data_operations.py +44 -74
  49. graphiti_core/utils/maintenance/node_operations.py +320 -158
  50. graphiti_core/utils/maintenance/temporal_operations.py +11 -3
  51. graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
  52. graphiti_core/utils/text_utils.py +53 -0
  53. {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/METADATA +221 -87
  54. graphiti_core-0.24.3.dist-info/RECORD +86 -0
  55. {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
  56. graphiti_core-0.17.4.dist-info/RECORD +0 -77
  57. graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
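  58. {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/licenses/LICENSE +0 -0
The diff excerpt below is from graphiti_core/search/search_utils.py (entry 41) and is truncated partway through node_bfs_search.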
@@ -23,7 +23,10 @@ import numpy as np
23
23
  from numpy._typing import NDArray
24
24
  from typing_extensions import LiteralString
25
25
 
26
- from graphiti_core.driver.driver import GraphDriver
26
+ from graphiti_core.driver.driver import (
27
+ GraphDriver,
28
+ GraphProvider,
29
+ )
27
30
  from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
28
31
  from graphiti_core.graph_queries import (
29
32
  get_nodes_query,
@@ -31,13 +34,17 @@ from graphiti_core.graph_queries import (
31
34
  get_vector_cosine_func_query,
32
35
  )
33
36
  from graphiti_core.helpers import (
34
- RUNTIME_QUERY,
35
37
  lucene_sanitize,
36
38
  normalize_l2,
37
39
  semaphore_gather,
38
40
  )
41
+ from graphiti_core.models.edges.edge_db_queries import get_entity_edge_return_query
42
+ from graphiti_core.models.nodes.node_db_queries import (
43
+ COMMUNITY_NODE_RETURN,
44
+ EPISODIC_NODE_RETURN,
45
+ get_entity_node_return_query,
46
+ )
39
47
  from graphiti_core.nodes import (
40
- ENTITY_NODE_RETURN,
41
48
  CommunityNode,
42
49
  EntityNode,
43
50
  EpisodicNode,
@@ -57,12 +64,35 @@ RELEVANT_SCHEMA_LIMIT = 10
57
64
  DEFAULT_MIN_SCORE = 0.6
58
65
  DEFAULT_MMR_LAMBDA = 0.5
59
66
  MAX_SEARCH_DEPTH = 3
60
- MAX_QUERY_LENGTH = 32
67
+ MAX_QUERY_LENGTH = 128
68
+
69
+
70
+ def calculate_cosine_similarity(vector1: list[float], vector2: list[float]) -> float:
71
+ """
72
+ Calculates the cosine similarity between two vectors using NumPy.
73
+ """
74
+ dot_product = np.dot(vector1, vector2)
75
+ norm_vector1 = np.linalg.norm(vector1)
76
+ norm_vector2 = np.linalg.norm(vector2)
61
77
 
78
+ if norm_vector1 == 0 or norm_vector2 == 0:
79
+ return 0 # Handle cases where one or both vectors are zero vectors
62
80
 
63
- def fulltext_query(query: str, group_ids: list[str] | None = None):
81
+ return dot_product / (norm_vector1 * norm_vector2)
82
+
83
+
84
+ def fulltext_query(query: str, group_ids: list[str] | None, driver: GraphDriver):
85
+ if driver.provider == GraphProvider.KUZU:
86
+ # Kuzu only supports simple queries.
87
+ if len(query.split(' ')) > MAX_QUERY_LENGTH:
88
+ return ''
89
+ return query
90
+ elif driver.provider == GraphProvider.FALKORDB:
91
+ return driver.build_fulltext_query(query, group_ids, MAX_QUERY_LENGTH)
64
92
  group_ids_filter_list = (
65
- [f'group_id:"{lucene_sanitize(g)}"' for g in group_ids] if group_ids is not None else []
93
+ [driver.fulltext_syntax + f'group_id:"{g}"' for g in group_ids]
94
+ if group_ids is not None
95
+ else []
66
96
  )
67
97
  group_ids_filter = ''
68
98
  for f in group_ids_filter_list:
@@ -100,25 +130,18 @@ async def get_mentioned_nodes(
100
130
  ) -> list[EntityNode]:
101
131
  episode_uuids = [episode.uuid for episode in episodes]
102
132
 
103
- query = """
104
- MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
133
+ records, _, _ = await driver.execute_query(
134
+ """
135
+ MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity)
136
+ WHERE episode.uuid IN $uuids
105
137
  RETURN DISTINCT
106
- n.uuid As uuid,
107
- n.group_id AS group_id,
108
- n.name AS name,
109
- n.created_at AS created_at,
110
- n.summary AS summary,
111
- labels(n) AS labels,
112
- properties(n) AS attributes
113
138
  """
114
-
115
- records, _, _ = await driver.execute_query(
116
- query,
139
+ + get_entity_node_return_query(driver.provider),
117
140
  uuids=episode_uuids,
118
141
  routing_='r',
119
142
  )
120
143
 
121
- nodes = [get_entity_node_from_record(record) for record in records]
144
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
122
145
 
123
146
  return nodes
124
147
 
@@ -128,18 +151,13 @@ async def get_communities_by_nodes(
128
151
  ) -> list[CommunityNode]:
129
152
  node_uuids = [node.uuid for node in nodes]
130
153
 
131
- query = """
132
- MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
133
- RETURN DISTINCT
134
- c.uuid As uuid,
135
- c.group_id AS group_id,
136
- c.name AS name,
137
- c.created_at AS created_at,
138
- c.summary AS summary
139
- """
140
-
141
154
  records, _, _ = await driver.execute_query(
142
- query,
155
+ """
156
+ MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)
157
+ WHERE m.uuid IN $uuids
158
+ RETURN DISTINCT
159
+ """
160
+ + COMMUNITY_NODE_RETURN,
143
161
  uuids=node_uuids,
144
162
  routing_='r',
145
163
  )
@@ -156,49 +174,110 @@ async def edge_fulltext_search(
156
174
  group_ids: list[str] | None = None,
157
175
  limit=RELEVANT_SCHEMA_LIMIT,
158
176
  ) -> list[EntityEdge]:
177
+ if driver.search_interface:
178
+ return await driver.search_interface.edge_fulltext_search(
179
+ driver, query, search_filter, group_ids, limit
180
+ )
181
+
159
182
  # fulltext search over facts
160
- fuzzy_query = fulltext_query(query, group_ids)
183
+ fuzzy_query = fulltext_query(query, group_ids, driver)
184
+
161
185
  if fuzzy_query == '':
162
186
  return []
163
187
 
164
- filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
165
-
166
- query = (
167
- get_relationships_query('edge_name_and_fact', db_type=driver.provider)
168
- + """
169
- YIELD relationship AS rel, score
170
- MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
171
- WHERE r.group_id IN $group_ids """
172
- + filter_query
173
- + """
174
- WITH r, score, startNode(r) AS n, endNode(r) AS m
175
- RETURN
176
- r.uuid AS uuid,
177
- r.group_id AS group_id,
178
- n.uuid AS source_node_uuid,
179
- m.uuid AS target_node_uuid,
180
- r.created_at AS created_at,
181
- r.name AS name,
182
- r.fact AS fact,
183
- r.episodes AS episodes,
184
- r.expired_at AS expired_at,
185
- r.valid_at AS valid_at,
186
- r.invalid_at AS invalid_at,
187
- properties(r) AS attributes
188
- ORDER BY score DESC LIMIT $limit
188
+ match_query = """
189
+ YIELD relationship AS rel, score
190
+ MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
191
+ """
192
+ if driver.provider == GraphProvider.KUZU:
193
+ match_query = """
194
+ YIELD node, score
195
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: node.uuid})-[:RELATES_TO]->(m:Entity)
189
196
  """
190
- )
191
197
 
192
- records, _, _ = await driver.execute_query(
193
- query,
194
- params=filter_params,
195
- query=fuzzy_query,
196
- group_ids=group_ids,
197
- limit=limit,
198
- routing_='r',
198
+ filter_queries, filter_params = edge_search_filter_query_constructor(
199
+ search_filter, driver.provider
199
200
  )
200
201
 
201
- edges = [get_entity_edge_from_record(record) for record in records]
202
+ if group_ids is not None:
203
+ filter_queries.append('e.group_id IN $group_ids')
204
+ filter_params['group_ids'] = group_ids
205
+
206
+ filter_query = ''
207
+ if filter_queries:
208
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
209
+
210
+ if driver.provider == GraphProvider.NEPTUNE:
211
+ res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue
212
+ if res['hits']['total']['value'] > 0:
213
+ input_ids = []
214
+ for r in res['hits']['hits']:
215
+ input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
216
+
217
+ # Match the edge ids and return the values
218
+ query = (
219
+ """
220
+ UNWIND $ids as id
221
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
222
+ WHERE e.group_id IN $group_ids
223
+ AND id(e)=id
224
+ """
225
+ + filter_query
226
+ + """
227
+ AND id(e)=id
228
+ WITH e, id.score as score, startNode(e) AS n, endNode(e) AS m
229
+ RETURN
230
+ e.uuid AS uuid,
231
+ e.group_id AS group_id,
232
+ n.uuid AS source_node_uuid,
233
+ m.uuid AS target_node_uuid,
234
+ e.created_at AS created_at,
235
+ e.name AS name,
236
+ e.fact AS fact,
237
+ split(e.episodes, ",") AS episodes,
238
+ e.expired_at AS expired_at,
239
+ e.valid_at AS valid_at,
240
+ e.invalid_at AS invalid_at,
241
+ properties(e) AS attributes
242
+ ORDER BY score DESC LIMIT $limit
243
+ """
244
+ )
245
+
246
+ records, _, _ = await driver.execute_query(
247
+ query,
248
+ query=fuzzy_query,
249
+ ids=input_ids,
250
+ limit=limit,
251
+ routing_='r',
252
+ **filter_params,
253
+ )
254
+ else:
255
+ return []
256
+ else:
257
+ query = (
258
+ get_relationships_query('edge_name_and_fact', limit=limit, provider=driver.provider)
259
+ + match_query
260
+ + filter_query
261
+ + """
262
+ WITH e, score, n, m
263
+ RETURN
264
+ """
265
+ + get_entity_edge_return_query(driver.provider)
266
+ + """
267
+ ORDER BY score DESC
268
+ LIMIT $limit
269
+ """
270
+ )
271
+
272
+ records, _, _ = await driver.execute_query(
273
+ query,
274
+ query=fuzzy_query,
275
+ limit=limit,
276
+ routing_='r',
277
+ **filter_params,
278
+ )
279
+
280
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
202
281
 
203
282
  return edges
204
283
 
@@ -213,95 +292,86 @@ async def edge_similarity_search(
213
292
  limit: int = RELEVANT_SCHEMA_LIMIT,
214
293
  min_score: float = DEFAULT_MIN_SCORE,
215
294
  ) -> list[EntityEdge]:
216
- # vector similarity search over embedded facts
217
- query_params: dict[str, Any] = {}
295
+ if driver.search_interface:
296
+ return await driver.search_interface.edge_similarity_search(
297
+ driver,
298
+ search_vector,
299
+ source_node_uuid,
300
+ target_node_uuid,
301
+ search_filter,
302
+ group_ids,
303
+ limit,
304
+ min_score,
305
+ )
218
306
 
219
- filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
220
- query_params.update(filter_params)
307
+ match_query = """
308
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
309
+ """
310
+ if driver.provider == GraphProvider.KUZU:
311
+ match_query = """
312
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
313
+ """
314
+
315
+ filter_queries, filter_params = edge_search_filter_query_constructor(
316
+ search_filter, driver.provider
317
+ )
221
318
 
222
- group_filter_query: LiteralString = 'WHERE r.group_id IS NOT NULL'
223
319
  if group_ids is not None:
224
- group_filter_query += '\nAND r.group_id IN $group_ids'
225
- query_params['group_ids'] = group_ids
226
- query_params['source_node_uuid'] = source_node_uuid
227
- query_params['target_node_uuid'] = target_node_uuid
320
+ filter_queries.append('e.group_id IN $group_ids')
321
+ filter_params['group_ids'] = group_ids
228
322
 
229
323
  if source_node_uuid is not None:
230
- group_filter_query += '\nAND (n.uuid IN [$source_uuid, $target_uuid])'
324
+ filter_params['source_uuid'] = source_node_uuid
325
+ filter_queries.append('n.uuid = $source_uuid')
231
326
 
232
327
  if target_node_uuid is not None:
233
- group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
234
-
235
- query = (
236
- RUNTIME_QUERY
237
- + """
238
- MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
239
- """
240
- + group_filter_query
241
- + filter_query
242
- + """
243
- WITH DISTINCT r, """
244
- + get_vector_cosine_func_query('r.fact_embedding', '$search_vector', driver.provider)
245
- + """ AS score
246
- WHERE score > $min_score
247
- RETURN
248
- r.uuid AS uuid,
249
- r.group_id AS group_id,
250
- startNode(r).uuid AS source_node_uuid,
251
- endNode(r).uuid AS target_node_uuid,
252
- r.created_at AS created_at,
253
- r.name AS name,
254
- r.fact AS fact,
255
- r.episodes AS episodes,
256
- r.expired_at AS expired_at,
257
- r.valid_at AS valid_at,
258
- r.invalid_at AS invalid_at,
259
- properties(r) AS attributes
260
- ORDER BY score DESC
261
- LIMIT $limit
262
- """
263
- )
264
- records, header, _ = await driver.execute_query(
265
- query,
266
- params=query_params,
267
- search_vector=search_vector,
268
- source_uuid=source_node_uuid,
269
- target_uuid=target_node_uuid,
270
- group_ids=group_ids,
271
- limit=limit,
272
- min_score=min_score,
273
- routing_='r',
274
- )
275
-
276
- edges = [get_entity_edge_from_record(record) for record in records]
328
+ filter_params['target_uuid'] = target_node_uuid
329
+ filter_queries.append('m.uuid = $target_uuid')
277
330
 
278
- return edges
331
+ filter_query = ''
332
+ if filter_queries:
333
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
279
334
 
335
+ search_vector_var = '$search_vector'
336
+ if driver.provider == GraphProvider.KUZU:
337
+ search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
280
338
 
281
- async def edge_bfs_search(
282
- driver: GraphDriver,
283
- bfs_origin_node_uuids: list[str] | None,
284
- bfs_max_depth: int,
285
- search_filter: SearchFilters,
286
- limit: int,
287
- ) -> list[EntityEdge]:
288
- # vector similarity search over embedded facts
289
- if bfs_origin_node_uuids is None:
290
- return []
291
-
292
- filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
339
+ if driver.provider == GraphProvider.NEPTUNE:
340
+ query = (
341
+ """
342
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
343
+ """
344
+ + filter_query
345
+ + """
346
+ RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
347
+ """
348
+ )
349
+ resp, header, _ = await driver.execute_query(
350
+ query,
351
+ search_vector=search_vector,
352
+ limit=limit,
353
+ min_score=min_score,
354
+ routing_='r',
355
+ **filter_params,
356
+ )
293
357
 
294
- query = (
295
- """
296
- UNWIND $bfs_origin_node_uuids AS origin_uuid
297
- MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
298
- UNWIND relationships(path) AS rel
299
- MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
300
- WHERE r.uuid = rel.uuid
301
- """
302
- + filter_query
303
- + """
304
- RETURN DISTINCT
358
+ if len(resp) > 0:
359
+ # Calculate Cosine similarity then return the edge ids
360
+ input_ids = []
361
+ for r in resp:
362
+ if r['embedding']:
363
+ score = calculate_cosine_similarity(
364
+ search_vector, list(map(float, r['embedding'].split(',')))
365
+ )
366
+ if score > min_score:
367
+ input_ids.append({'id': r['id'], 'score': score})
368
+
369
+ # Match the edge ides and return the values
370
+ query = """
371
+ UNWIND $ids as i
372
+ MATCH ()-[r]->()
373
+ WHERE id(r) = i.id
374
+ RETURN
305
375
  r.uuid AS uuid,
306
376
  r.group_id AS group_id,
307
377
  startNode(r).uuid AS source_node_uuid,
@@ -309,25 +379,176 @@ async def edge_bfs_search(
309
379
  r.created_at AS created_at,
310
380
  r.name AS name,
311
381
  r.fact AS fact,
312
- r.episodes AS episodes,
382
+ split(r.episodes, ",") AS episodes,
313
383
  r.expired_at AS expired_at,
314
384
  r.valid_at AS valid_at,
315
385
  r.invalid_at AS invalid_at,
316
386
  properties(r) AS attributes
387
+ ORDER BY i.score DESC
317
388
  LIMIT $limit
318
- """
319
- )
389
+ """
390
+ records, _, _ = await driver.execute_query(
391
+ query,
392
+ ids=input_ids,
393
+ search_vector=search_vector,
394
+ limit=limit,
395
+ min_score=min_score,
396
+ routing_='r',
397
+ **filter_params,
398
+ )
399
+ else:
400
+ return []
401
+ else:
402
+ query = (
403
+ match_query
404
+ + filter_query
405
+ + """
406
+ WITH DISTINCT e, n, m, """
407
+ + get_vector_cosine_func_query('e.fact_embedding', search_vector_var, driver.provider)
408
+ + """ AS score
409
+ WHERE score > $min_score
410
+ RETURN
411
+ """
412
+ + get_entity_edge_return_query(driver.provider)
413
+ + """
414
+ ORDER BY score DESC
415
+ LIMIT $limit
416
+ """
417
+ )
320
418
 
321
- records, _, _ = await driver.execute_query(
322
- query,
323
- params=filter_params,
324
- bfs_origin_node_uuids=bfs_origin_node_uuids,
325
- depth=bfs_max_depth,
326
- limit=limit,
327
- routing_='r',
419
+ records, _, _ = await driver.execute_query(
420
+ query,
421
+ search_vector=search_vector,
422
+ limit=limit,
423
+ min_score=min_score,
424
+ routing_='r',
425
+ **filter_params,
426
+ )
427
+
428
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
429
+
430
+ return edges
431
+
432
+
433
+ async def edge_bfs_search(
434
+ driver: GraphDriver,
435
+ bfs_origin_node_uuids: list[str] | None,
436
+ bfs_max_depth: int,
437
+ search_filter: SearchFilters,
438
+ group_ids: list[str] | None = None,
439
+ limit: int = RELEVANT_SCHEMA_LIMIT,
440
+ ) -> list[EntityEdge]:
441
+ # vector similarity search over embedded facts
442
+ if bfs_origin_node_uuids is None or len(bfs_origin_node_uuids) == 0:
443
+ return []
444
+
445
+ filter_queries, filter_params = edge_search_filter_query_constructor(
446
+ search_filter, driver.provider
328
447
  )
329
448
 
330
- edges = [get_entity_edge_from_record(record) for record in records]
449
+ if group_ids is not None:
450
+ filter_queries.append('e.group_id IN $group_ids')
451
+ filter_params['group_ids'] = group_ids
452
+
453
+ filter_query = ''
454
+ if filter_queries:
455
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
456
+
457
+ if driver.provider == GraphProvider.KUZU:
458
+ # Kuzu stores entity edges twice with an intermediate node, so we need to match them
459
+ # separately for the correct BFS depth.
460
+ depth = bfs_max_depth * 2 - 1
461
+ match_queries = [
462
+ f"""
463
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
464
+ MATCH path = (origin:Entity {{uuid: origin_uuid}})-[:RELATES_TO*1..{depth}]->(:RelatesToNode_)
465
+ UNWIND nodes(path) AS relNode
466
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {{uuid: relNode.uuid}})-[:RELATES_TO]->(m:Entity)
467
+ """,
468
+ ]
469
+ if bfs_max_depth > 1:
470
+ depth = (bfs_max_depth - 1) * 2 - 1
471
+ match_queries.append(f"""
472
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
473
+ MATCH path = (origin:Episodic {{uuid: origin_uuid}})-[:MENTIONS]->(:Entity)-[:RELATES_TO*1..{depth}]->(:RelatesToNode_)
474
+ UNWIND nodes(path) AS relNode
475
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {{uuid: relNode.uuid}})-[:RELATES_TO]->(m:Entity)
476
+ """)
477
+
478
+ records = []
479
+ for match_query in match_queries:
480
+ sub_records, _, _ = await driver.execute_query(
481
+ match_query
482
+ + filter_query
483
+ + """
484
+ RETURN DISTINCT
485
+ """
486
+ + get_entity_edge_return_query(driver.provider)
487
+ + """
488
+ LIMIT $limit
489
+ """,
490
+ bfs_origin_node_uuids=bfs_origin_node_uuids,
491
+ limit=limit,
492
+ routing_='r',
493
+ **filter_params,
494
+ )
495
+ records.extend(sub_records)
496
+ else:
497
+ if driver.provider == GraphProvider.NEPTUNE:
498
+ query = (
499
+ f"""
500
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
501
+ MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS *1..{bfs_max_depth}]->(n:Entity)
502
+ WHERE origin:Entity OR origin:Episodic
503
+ UNWIND relationships(path) AS rel
504
+ MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
505
+ """
506
+ + filter_query
507
+ + """
508
+ RETURN DISTINCT
509
+ e.uuid AS uuid,
510
+ e.group_id AS group_id,
511
+ startNode(e).uuid AS source_node_uuid,
512
+ endNode(e).uuid AS target_node_uuid,
513
+ e.created_at AS created_at,
514
+ e.name AS name,
515
+ e.fact AS fact,
516
+ split(e.episodes, ',') AS episodes,
517
+ e.expired_at AS expired_at,
518
+ e.valid_at AS valid_at,
519
+ e.invalid_at AS invalid_at,
520
+ properties(e) AS attributes
521
+ LIMIT $limit
522
+ """
523
+ )
524
+ else:
525
+ query = (
526
+ f"""
527
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
528
+ MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
529
+ UNWIND relationships(path) AS rel
530
+ MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
531
+ """
532
+ + filter_query
533
+ + """
534
+ RETURN DISTINCT
535
+ """
536
+ + get_entity_edge_return_query(driver.provider)
537
+ + """
538
+ LIMIT $limit
539
+ """
540
+ )
541
+
542
+ records, _, _ = await driver.execute_query(
543
+ query,
544
+ bfs_origin_node_uuids=bfs_origin_node_uuids,
545
+ depth=bfs_max_depth,
546
+ limit=limit,
547
+ routing_='r',
548
+ **filter_params,
549
+ )
550
+
551
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
331
552
 
332
553
  return edges
333
554
 
@@ -339,36 +560,88 @@ async def node_fulltext_search(
339
560
  group_ids: list[str] | None = None,
340
561
  limit=RELEVANT_SCHEMA_LIMIT,
341
562
  ) -> list[EntityNode]:
563
+ if driver.search_interface:
564
+ return await driver.search_interface.node_fulltext_search(
565
+ driver, query, search_filter, group_ids, limit
566
+ )
567
+
342
568
  # BM25 search to get top nodes
343
- fuzzy_query = fulltext_query(query, group_ids)
569
+ fuzzy_query = fulltext_query(query, group_ids, driver)
344
570
  if fuzzy_query == '':
345
571
  return []
346
- filter_query, filter_params = node_search_filter_query_constructor(search_filter)
347
572
 
348
- query = (
349
- get_nodes_query(driver.provider, 'node_name_and_summary', '$query')
350
- + """
351
- YIELD node AS n, score
573
+ filter_queries, filter_params = node_search_filter_query_constructor(
574
+ search_filter, driver.provider
575
+ )
576
+
577
+ if group_ids is not None:
578
+ filter_queries.append('n.group_id IN $group_ids')
579
+ filter_params['group_ids'] = group_ids
580
+
581
+ filter_query = ''
582
+ if filter_queries:
583
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
584
+
585
+ yield_query = 'YIELD node AS n, score'
586
+ if driver.provider == GraphProvider.KUZU:
587
+ yield_query = 'WITH node AS n, score'
588
+
589
+ if driver.provider == GraphProvider.NEPTUNE:
590
+ res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
591
+ if res['hits']['total']['value'] > 0:
592
+ input_ids = []
593
+ for r in res['hits']['hits']:
594
+ input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
595
+
596
+ # Match the edge ides and return the values
597
+ query = (
598
+ """
599
+ UNWIND $ids as i
600
+ MATCH (n:Entity)
601
+ WHERE n.uuid=i.id
602
+ RETURN
603
+ """
604
+ + get_entity_node_return_query(driver.provider)
605
+ + """
606
+ ORDER BY i.score DESC
607
+ LIMIT $limit
608
+ """
609
+ )
610
+ records, _, _ = await driver.execute_query(
611
+ query,
612
+ ids=input_ids,
613
+ query=fuzzy_query,
614
+ limit=limit,
615
+ routing_='r',
616
+ **filter_params,
617
+ )
618
+ else:
619
+ return []
620
+ else:
621
+ query = (
622
+ get_nodes_query(
623
+ 'node_name_and_summary', '$query', limit=limit, provider=driver.provider
624
+ )
625
+ + yield_query
626
+ + filter_query
627
+ + """
352
628
  WITH n, score
629
+ ORDER BY score DESC
353
630
  LIMIT $limit
354
- WHERE n:Entity
355
- """
356
- + filter_query
357
- + ENTITY_NODE_RETURN
358
- + """
359
- ORDER BY score DESC
360
- """
361
- )
362
- records, header, _ = await driver.execute_query(
363
- query,
364
- params=filter_params,
365
- query=fuzzy_query,
366
- group_ids=group_ids,
367
- limit=limit,
368
- routing_='r',
369
- )
631
+ RETURN
632
+ """
633
+ + get_entity_node_return_query(driver.provider)
634
+ )
635
+
636
+ records, _, _ = await driver.execute_query(
637
+ query,
638
+ query=fuzzy_query,
639
+ limit=limit,
640
+ routing_='r',
641
+ **filter_params,
642
+ )
370
643
 
371
- nodes = [get_entity_node_from_record(record) for record in records]
644
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
372
645
 
373
646
  return nodes
374
647
 
@@ -381,47 +654,112 @@ async def node_similarity_search(
381
654
  limit=RELEVANT_SCHEMA_LIMIT,
382
655
  min_score: float = DEFAULT_MIN_SCORE,
383
656
  ) -> list[EntityNode]:
384
- # vector similarity search over entity names
385
- query_params: dict[str, Any] = {}
657
+ if driver.search_interface:
658
+ return await driver.search_interface.node_similarity_search(
659
+ driver, search_vector, search_filter, group_ids, limit, min_score
660
+ )
661
+
662
+ filter_queries, filter_params = node_search_filter_query_constructor(
663
+ search_filter, driver.provider
664
+ )
386
665
 
387
- group_filter_query: LiteralString = 'WHERE n.group_id IS NOT NULL'
388
666
  if group_ids is not None:
389
- group_filter_query += ' AND n.group_id IN $group_ids'
390
- query_params['group_ids'] = group_ids
667
+ filter_queries.append('n.group_id IN $group_ids')
668
+ filter_params['group_ids'] = group_ids
391
669
 
392
- filter_query, filter_params = node_search_filter_query_constructor(search_filter)
393
- query_params.update(filter_params)
670
+ filter_query = ''
671
+ if filter_queries:
672
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
394
673
 
395
- query = (
396
- RUNTIME_QUERY
397
- + """
398
- MATCH (n:Entity)
399
- """
400
- + group_filter_query
401
- + filter_query
402
- + """
403
- WITH n, """
404
- + get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
405
- + """ AS score
406
- WHERE score > $min_score"""
407
- + ENTITY_NODE_RETURN
408
- + """
409
- ORDER BY score DESC
410
- LIMIT $limit
674
+ search_vector_var = '$search_vector'
675
+ if driver.provider == GraphProvider.KUZU:
676
+ search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
677
+
678
+ if driver.provider == GraphProvider.NEPTUNE:
679
+ query = (
411
680
  """
412
- )
681
+ MATCH (n:Entity)
682
+ """
683
+ + filter_query
684
+ + """
685
+ RETURN DISTINCT id(n) as id, n.name_embedding as embedding
686
+ """
687
+ )
688
+ resp, header, _ = await driver.execute_query(
689
+ query,
690
+ params=filter_params,
691
+ search_vector=search_vector,
692
+ limit=limit,
693
+ min_score=min_score,
694
+ routing_='r',
695
+ )
413
696
 
414
- records, header, _ = await driver.execute_query(
415
- query,
416
- params=query_params,
417
- search_vector=search_vector,
418
- group_ids=group_ids,
419
- limit=limit,
420
- min_score=min_score,
421
- routing_='r',
422
- )
697
+ if len(resp) > 0:
698
+ # Calculate Cosine similarity then return the edge ids
699
+ input_ids = []
700
+ for r in resp:
701
+ if r['embedding']:
702
+ score = calculate_cosine_similarity(
703
+ search_vector, list(map(float, r['embedding'].split(',')))
704
+ )
705
+ if score > min_score:
706
+ input_ids.append({'id': r['id'], 'score': score})
707
+
708
+ # Match the edge ides and return the values
709
+ query = (
710
+ """
711
+ UNWIND $ids as i
712
+ MATCH (n:Entity)
713
+ WHERE id(n)=i.id
714
+ RETURN
715
+ """
716
+ + get_entity_node_return_query(driver.provider)
717
+ + """
718
+ ORDER BY i.score DESC
719
+ LIMIT $limit
720
+ """
721
+ )
722
+ records, header, _ = await driver.execute_query(
723
+ query,
724
+ ids=input_ids,
725
+ search_vector=search_vector,
726
+ limit=limit,
727
+ min_score=min_score,
728
+ routing_='r',
729
+ **filter_params,
730
+ )
731
+ else:
732
+ return []
733
+ else:
734
+ query = (
735
+ """
736
+ MATCH (n:Entity)
737
+ """
738
+ + filter_query
739
+ + """
740
+ WITH n, """
741
+ + get_vector_cosine_func_query('n.name_embedding', search_vector_var, driver.provider)
742
+ + """ AS score
743
+ WHERE score > $min_score
744
+ RETURN
745
+ """
746
+ + get_entity_node_return_query(driver.provider)
747
+ + """
748
+ ORDER BY score DESC
749
+ LIMIT $limit
750
+ """
751
+ )
752
+
753
+ records, _, _ = await driver.execute_query(
754
+ query,
755
+ search_vector=search_vector,
756
+ limit=limit,
757
+ min_score=min_score,
758
+ routing_='r',
759
+ **filter_params,
760
+ )
423
761
 
424
- nodes = [get_entity_node_from_record(record) for record in records]
762
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
425
763
 
426
764
  return nodes
427
765
 
@@ -431,35 +769,85 @@ async def node_bfs_search(
431
769
  bfs_origin_node_uuids: list[str] | None,
432
770
  search_filter: SearchFilters,
433
771
  bfs_max_depth: int,
434
- limit: int,
772
+ group_ids: list[str] | None = None,
773
+ limit: int = RELEVANT_SCHEMA_LIMIT,
435
774
  ) -> list[EntityNode]:
436
- # vector similarity search over entity names
437
- if bfs_origin_node_uuids is None:
775
+ if bfs_origin_node_uuids is None or len(bfs_origin_node_uuids) == 0 or bfs_max_depth < 1:
438
776
  return []
439
777
 
440
- filter_query, filter_params = node_search_filter_query_constructor(search_filter)
778
+ filter_queries, filter_params = node_search_filter_query_constructor(
779
+ search_filter, driver.provider
780
+ )
441
781
 
442
- query = (
443
- """
444
- UNWIND $bfs_origin_node_uuids AS origin_uuid
445
- MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
446
- WHERE n.group_id = origin.group_id
447
- """
448
- + filter_query
449
- + ENTITY_NODE_RETURN
450
- + """
451
- LIMIT $limit
782
+ if group_ids is not None:
783
+ filter_queries.append('n.group_id IN $group_ids')
784
+ filter_queries.append('origin.group_id IN $group_ids')
785
+ filter_params['group_ids'] = group_ids
786
+
787
+ filter_query = ''
788
+ if filter_queries:
789
+ filter_query = ' AND ' + (' AND '.join(filter_queries))
790
+
791
+ match_queries = [
792
+ f"""
793
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
794
+ MATCH (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
795
+ WHERE n.group_id = origin.group_id
452
796
  """
453
- )
454
- records, _, _ = await driver.execute_query(
455
- query,
456
- params=filter_params,
457
- bfs_origin_node_uuids=bfs_origin_node_uuids,
458
- depth=bfs_max_depth,
459
- limit=limit,
460
- routing_='r',
461
- )
462
- nodes = [get_entity_node_from_record(record) for record in records]
797
+ ]
798
+
799
+ if driver.provider == GraphProvider.NEPTUNE:
800
+ match_queries = [
801
+ f"""
802
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
803
+ MATCH (origin {{uuid: origin_uuid}})-[e:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
804
+ WHERE origin:Entity OR origin.Episode
805
+ AND n.group_id = origin.group_id
806
+ """
807
+ ]
808
+
809
+ if driver.provider == GraphProvider.KUZU:
810
+ depth = bfs_max_depth * 2
811
+ match_queries = [
812
+ """
813
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
814
+ MATCH (origin:Episodic {uuid: origin_uuid})-[:MENTIONS]->(n:Entity)
815
+ WHERE n.group_id = origin.group_id
816
+ """,
817
+ f"""
818
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
819
+ MATCH (origin:Entity {{uuid: origin_uuid}})-[:RELATES_TO*2..{depth}]->(n:Entity)
820
+ WHERE n.group_id = origin.group_id
821
+ """,
822
+ ]
823
+ if bfs_max_depth > 1:
824
+ depth = (bfs_max_depth - 1) * 2
825
+ match_queries.append(f"""
826
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
827
+ MATCH (origin:Episodic {{uuid: origin_uuid}})-[:MENTIONS]->(:Entity)-[:RELATES_TO*2..{depth}]->(n:Entity)
828
+ WHERE n.group_id = origin.group_id
829
+ """)
830
+
831
+ records = []
832
+ for match_query in match_queries:
833
+ sub_records, _, _ = await driver.execute_query(
834
+ match_query
835
+ + filter_query
836
+ + """
837
+ RETURN
838
+ """
839
+ + get_entity_node_return_query(driver.provider)
840
+ + """
841
+ LIMIT $limit
842
+ """,
843
+ bfs_origin_node_uuids=bfs_origin_node_uuids,
844
+ limit=limit,
845
+ routing_='r',
846
+ **filter_params,
847
+ )
848
+ records.extend(sub_records)
849
+
850
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
463
851
 
464
852
  return nodes
465
853
 
@@ -471,39 +859,80 @@ async def episode_fulltext_search(
471
859
  group_ids: list[str] | None = None,
472
860
  limit=RELEVANT_SCHEMA_LIMIT,
473
861
  ) -> list[EpisodicNode]:
862
+ if driver.search_interface:
863
+ return await driver.search_interface.episode_fulltext_search(
864
+ driver, query, _search_filter, group_ids, limit
865
+ )
866
+
474
867
  # BM25 search to get top episodes
475
- fuzzy_query = fulltext_query(query, group_ids)
868
+ fuzzy_query = fulltext_query(query, group_ids, driver)
476
869
  if fuzzy_query == '':
477
870
  return []
478
871
 
479
- query = (
480
- get_nodes_query(driver.provider, 'episode_content', '$query')
481
- + """
482
- YIELD node AS episode, score
483
- MATCH (e:Episodic)
484
- WHERE e.uuid = episode.uuid
485
- RETURN
486
- e.content AS content,
487
- e.created_at AS created_at,
488
- e.valid_at AS valid_at,
489
- e.uuid AS uuid,
490
- e.name AS name,
491
- e.group_id AS group_id,
492
- e.source_description AS source_description,
493
- e.source AS source,
494
- e.entity_edges AS entity_edges
495
- ORDER BY score DESC
496
- LIMIT $limit
497
- """
498
- )
872
+ filter_params: dict[str, Any] = {}
873
+ group_filter_query: LiteralString = ''
874
+ if group_ids is not None:
875
+ group_filter_query += '\nAND e.group_id IN $group_ids'
876
+ filter_params['group_ids'] = group_ids
877
+
878
+ if driver.provider == GraphProvider.NEPTUNE:
879
+ res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
880
+ if res['hits']['total']['value'] > 0:
881
+ input_ids = []
882
+ for r in res['hits']['hits']:
883
+ input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
884
+
885
+ # Match the edge ides and return the values
886
+ query = """
887
+ UNWIND $ids as i
888
+ MATCH (e:Episodic)
889
+ WHERE e.uuid=i.uuid
890
+ RETURN
891
+ e.content AS content,
892
+ e.created_at AS created_at,
893
+ e.valid_at AS valid_at,
894
+ e.uuid AS uuid,
895
+ e.name AS name,
896
+ e.group_id AS group_id,
897
+ e.source_description AS source_description,
898
+ e.source AS source,
899
+ e.entity_edges AS entity_edges
900
+ ORDER BY i.score DESC
901
+ LIMIT $limit
902
+ """
903
+ records, _, _ = await driver.execute_query(
904
+ query,
905
+ ids=input_ids,
906
+ query=fuzzy_query,
907
+ limit=limit,
908
+ routing_='r',
909
+ **filter_params,
910
+ )
911
+ else:
912
+ return []
913
+ else:
914
+ query = (
915
+ get_nodes_query('episode_content', '$query', limit=limit, provider=driver.provider)
916
+ + """
917
+ YIELD node AS episode, score
918
+ MATCH (e:Episodic)
919
+ WHERE e.uuid = episode.uuid
920
+ """
921
+ + group_filter_query
922
+ + """
923
+ RETURN
924
+ """
925
+ + EPISODIC_NODE_RETURN
926
+ + """
927
+ ORDER BY score DESC
928
+ LIMIT $limit
929
+ """
930
+ )
931
+
932
+ records, _, _ = await driver.execute_query(
933
+ query, query=fuzzy_query, limit=limit, routing_='r', **filter_params
934
+ )
499
935
 
500
- records, _, _ = await driver.execute_query(
501
- query,
502
- query=fuzzy_query,
503
- group_ids=group_ids,
504
- limit=limit,
505
- routing_='r',
506
- )
507
936
  episodes = [get_episodic_node_from_record(record) for record in records]
508
937
 
509
938
  return episodes
@@ -516,33 +945,75 @@ async def community_fulltext_search(
516
945
  limit=RELEVANT_SCHEMA_LIMIT,
517
946
  ) -> list[CommunityNode]:
518
947
  # BM25 search to get top communities
519
- fuzzy_query = fulltext_query(query, group_ids)
948
+ fuzzy_query = fulltext_query(query, group_ids, driver)
520
949
  if fuzzy_query == '':
521
950
  return []
522
951
 
523
- query = (
524
- get_nodes_query(driver.provider, 'community_name', '$query')
525
- + """
526
- YIELD node AS comm, score
527
- RETURN
528
- comm.uuid AS uuid,
529
- comm.group_id AS group_id,
530
- comm.name AS name,
531
- comm.created_at AS created_at,
532
- comm.summary AS summary,
533
- comm.name_embedding AS name_embedding
534
- ORDER BY score DESC
535
- LIMIT $limit
536
- """
537
- )
952
+ filter_params: dict[str, Any] = {}
953
+ group_filter_query: LiteralString = ''
954
+ if group_ids is not None:
955
+ group_filter_query = 'WHERE c.group_id IN $group_ids'
956
+ filter_params['group_ids'] = group_ids
957
+
958
+ yield_query = 'YIELD node AS c, score'
959
+ if driver.provider == GraphProvider.KUZU:
960
+ yield_query = 'WITH node AS c, score'
961
+
962
+ if driver.provider == GraphProvider.NEPTUNE:
963
+ res = driver.run_aoss_query('community_name', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
964
+ if res['hits']['total']['value'] > 0:
965
+ # Calculate Cosine similarity then return the edge ids
966
+ input_ids = []
967
+ for r in res['hits']['hits']:
968
+ input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
969
+
970
+ # Match the edge ides and return the values
971
+ query = """
972
+ UNWIND $ids as i
973
+ MATCH (comm:Community)
974
+ WHERE comm.uuid=i.id
975
+ RETURN
976
+ comm.uuid AS uuid,
977
+ comm.group_id AS group_id,
978
+ comm.name AS name,
979
+ comm.created_at AS created_at,
980
+ comm.summary AS summary,
981
+ [x IN split(comm.name_embedding, ",") | toFloat(x)]AS name_embedding
982
+ ORDER BY i.score DESC
983
+ LIMIT $limit
984
+ """
985
+ records, _, _ = await driver.execute_query(
986
+ query,
987
+ ids=input_ids,
988
+ query=fuzzy_query,
989
+ limit=limit,
990
+ routing_='r',
991
+ **filter_params,
992
+ )
993
+ else:
994
+ return []
995
+ else:
996
+ query = (
997
+ get_nodes_query('community_name', '$query', limit=limit, provider=driver.provider)
998
+ + yield_query
999
+ + """
1000
+ WITH c, score
1001
+ """
1002
+ + group_filter_query
1003
+ + """
1004
+ RETURN
1005
+ """
1006
+ + COMMUNITY_NODE_RETURN
1007
+ + """
1008
+ ORDER BY score DESC
1009
+ LIMIT $limit
1010
+ """
1011
+ )
1012
+
1013
+ records, _, _ = await driver.execute_query(
1014
+ query, query=fuzzy_query, limit=limit, routing_='r', **filter_params
1015
+ )
538
1016
 
539
- records, _, _ = await driver.execute_query(
540
- query,
541
- query=fuzzy_query,
542
- group_ids=group_ids,
543
- limit=limit,
544
- routing_='r',
545
- )
546
1017
  communities = [get_community_node_from_record(record) for record in records]
547
1018
 
548
1019
  return communities
@@ -560,40 +1031,99 @@ async def community_similarity_search(
560
1031
 
561
1032
  group_filter_query: LiteralString = ''
562
1033
  if group_ids is not None:
563
- group_filter_query += 'WHERE comm.group_id IN $group_ids'
1034
+ group_filter_query += ' WHERE c.group_id IN $group_ids'
564
1035
  query_params['group_ids'] = group_ids
565
1036
 
566
- query = (
567
- RUNTIME_QUERY
568
- + """
569
- MATCH (comm:Community)
570
- """
571
- + group_filter_query
572
- + """
573
- WITH comm, """
574
- + get_vector_cosine_func_query('comm.name_embedding', '$search_vector', driver.provider)
575
- + """ AS score
576
- WHERE score > $min_score
577
- RETURN
578
- comm.uuid As uuid,
579
- comm.group_id AS group_id,
580
- comm.name AS name,
581
- comm.created_at AS created_at,
582
- comm.summary AS summary,
583
- comm.name_embedding AS name_embedding
584
- ORDER BY score DESC
585
- LIMIT $limit
586
- """
587
- )
1037
+ if driver.provider == GraphProvider.NEPTUNE:
1038
+ query = (
1039
+ """
1040
+ MATCH (n:Community)
1041
+ """
1042
+ + group_filter_query
1043
+ + """
1044
+ RETURN DISTINCT id(n) as id, n.name_embedding as embedding
1045
+ """
1046
+ )
1047
+ resp, header, _ = await driver.execute_query(
1048
+ query,
1049
+ search_vector=search_vector,
1050
+ limit=limit,
1051
+ min_score=min_score,
1052
+ routing_='r',
1053
+ **query_params,
1054
+ )
1055
+
1056
+ if len(resp) > 0:
1057
+ # Calculate Cosine similarity then return the edge ids
1058
+ input_ids = []
1059
+ for r in resp:
1060
+ if r['embedding']:
1061
+ score = calculate_cosine_similarity(
1062
+ search_vector, list(map(float, r['embedding'].split(',')))
1063
+ )
1064
+ if score > min_score:
1065
+ input_ids.append({'id': r['id'], 'score': score})
1066
+
1067
+ # Match the edge ides and return the values
1068
+ query = """
1069
+ UNWIND $ids as i
1070
+ MATCH (comm:Community)
1071
+ WHERE id(comm)=i.id
1072
+ RETURN
1073
+ comm.uuid As uuid,
1074
+ comm.group_id AS group_id,
1075
+ comm.name AS name,
1076
+ comm.created_at AS created_at,
1077
+ comm.summary AS summary,
1078
+ comm.name_embedding AS name_embedding
1079
+ ORDER BY i.score DESC
1080
+ LIMIT $limit
1081
+ """
1082
+ records, header, _ = await driver.execute_query(
1083
+ query,
1084
+ ids=input_ids,
1085
+ search_vector=search_vector,
1086
+ limit=limit,
1087
+ min_score=min_score,
1088
+ routing_='r',
1089
+ **query_params,
1090
+ )
1091
+ else:
1092
+ return []
1093
+ else:
1094
+ search_vector_var = '$search_vector'
1095
+ if driver.provider == GraphProvider.KUZU:
1096
+ search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
1097
+
1098
+ query = (
1099
+ """
1100
+ MATCH (c:Community)
1101
+ """
1102
+ + group_filter_query
1103
+ + """
1104
+ WITH c,
1105
+ """
1106
+ + get_vector_cosine_func_query('c.name_embedding', search_vector_var, driver.provider)
1107
+ + """ AS score
1108
+ WHERE score > $min_score
1109
+ RETURN
1110
+ """
1111
+ + COMMUNITY_NODE_RETURN
1112
+ + """
1113
+ ORDER BY score DESC
1114
+ LIMIT $limit
1115
+ """
1116
+ )
1117
+
1118
+ records, _, _ = await driver.execute_query(
1119
+ query,
1120
+ search_vector=search_vector,
1121
+ limit=limit,
1122
+ min_score=min_score,
1123
+ routing_='r',
1124
+ **query_params,
1125
+ )
588
1126
 
589
- records, _, _ = await driver.execute_query(
590
- query,
591
- search_vector=search_vector,
592
- group_ids=group_ids,
593
- limit=limit,
594
- min_score=min_score,
595
- routing_='r',
596
- )
597
1127
  communities = [get_community_node_from_record(record) for record in records]
598
1128
 
599
1129
  return communities
@@ -664,7 +1194,7 @@ async def hybrid_node_search(
664
1194
  }
665
1195
  result_uuids = [[node.uuid for node in result] for result in results]
666
1196
 
667
- ranked_uuids = rrf(result_uuids)
1197
+ ranked_uuids, _ = rrf(result_uuids)
668
1198
 
669
1199
  relevant_nodes: list[EntityNode] = [node_uuid_map[uuid] for uuid in ranked_uuids]
670
1200
 
@@ -684,80 +1214,140 @@ async def get_relevant_nodes(
684
1214
  return []
685
1215
 
686
1216
  group_id = nodes[0].group_id
687
-
688
- # vector similarity search over entity names
689
- query_params: dict[str, Any] = {}
690
-
691
- filter_query, filter_params = node_search_filter_query_constructor(search_filter)
692
- query_params.update(filter_params)
693
-
694
- query = (
695
- RUNTIME_QUERY
696
- + """
697
- UNWIND $nodes AS node
698
- MATCH (n:Entity {group_id: $group_id})
699
- """
700
- + filter_query
701
- + """
702
- WITH node, n, """
703
- + get_vector_cosine_func_query('n.name_embedding', 'node.name_embedding', driver.provider)
704
- + """ AS score
705
- WHERE score > $min_score
706
- WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
707
- """
708
- + get_nodes_query(driver.provider, 'node_name_and_summary', 'node.fulltext_query')
709
- + """
710
- YIELD node AS m
711
- WHERE m.group_id = $group_id
712
- WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
713
-
714
- WITH node,
715
- top_vector_nodes,
716
- [m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
717
-
718
- WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
719
-
720
- UNWIND combined_nodes AS combined_node
721
- WITH node, collect(DISTINCT combined_node) AS deduped_nodes
722
-
723
- RETURN
724
- node.uuid AS search_node_uuid,
725
- [x IN deduped_nodes | {
726
- uuid: x.uuid,
727
- name: x.name,
728
- name_embedding: x.name_embedding,
729
- group_id: x.group_id,
730
- created_at: x.created_at,
731
- summary: x.summary,
732
- labels: labels(x),
733
- attributes: properties(x)
734
- }] AS matches
735
- """
736
- )
737
-
738
1217
  query_nodes = [
739
1218
  {
740
1219
  'uuid': node.uuid,
741
1220
  'name': node.name,
742
1221
  'name_embedding': node.name_embedding,
743
- 'fulltext_query': fulltext_query(node.name, [node.group_id]),
1222
+ 'fulltext_query': fulltext_query(node.name, [node.group_id], driver),
744
1223
  }
745
1224
  for node in nodes
746
1225
  ]
747
1226
 
1227
+ filter_queries, filter_params = node_search_filter_query_constructor(
1228
+ search_filter, driver.provider
1229
+ )
1230
+
1231
+ filter_query = ''
1232
+ if filter_queries:
1233
+ filter_query = 'WHERE ' + (' AND '.join(filter_queries))
1234
+
1235
+ if driver.provider == GraphProvider.KUZU:
1236
+ embedding_size = len(nodes[0].name_embedding) if nodes[0].name_embedding is not None else 0
1237
+ if embedding_size == 0:
1238
+ return []
1239
+
1240
+ # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
1241
+ query = (
1242
+ """
1243
+ UNWIND $nodes AS node
1244
+ MATCH (n:Entity {group_id: $group_id})
1245
+ """
1246
+ + filter_query
1247
+ + """
1248
+ WITH node, n, """
1249
+ + get_vector_cosine_func_query(
1250
+ 'n.name_embedding',
1251
+ f'CAST(node.name_embedding AS FLOAT[{embedding_size}])',
1252
+ driver.provider,
1253
+ )
1254
+ + """ AS score
1255
+ WHERE score > $min_score
1256
+ WITH node, collect(n)[:$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
1257
+ """
1258
+ + get_nodes_query(
1259
+ 'node_name_and_summary',
1260
+ 'node.fulltext_query',
1261
+ limit=limit,
1262
+ provider=driver.provider,
1263
+ )
1264
+ + """
1265
+ WITH node AS m
1266
+ WHERE m.group_id = $group_id AND NOT m.uuid IN vector_node_uuids
1267
+ WITH node, top_vector_nodes, collect(m) AS fulltext_nodes
1268
+
1269
+ WITH node, list_concat(top_vector_nodes, fulltext_nodes) AS combined_nodes
1270
+
1271
+ UNWIND combined_nodes AS x
1272
+ WITH node, collect(DISTINCT {
1273
+ uuid: x.uuid,
1274
+ name: x.name,
1275
+ name_embedding: x.name_embedding,
1276
+ group_id: x.group_id,
1277
+ created_at: x.created_at,
1278
+ summary: x.summary,
1279
+ labels: x.labels,
1280
+ attributes: x.attributes
1281
+ }) AS matches
1282
+
1283
+ RETURN
1284
+ node.uuid AS search_node_uuid, matches
1285
+ """
1286
+ )
1287
+ else:
1288
+ query = (
1289
+ """
1290
+ UNWIND $nodes AS node
1291
+ MATCH (n:Entity {group_id: $group_id})
1292
+ """
1293
+ + filter_query
1294
+ + """
1295
+ WITH node, n, """
1296
+ + get_vector_cosine_func_query(
1297
+ 'n.name_embedding', 'node.name_embedding', driver.provider
1298
+ )
1299
+ + """ AS score
1300
+ WHERE score > $min_score
1301
+ WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
1302
+ """
1303
+ + get_nodes_query(
1304
+ 'node_name_and_summary',
1305
+ 'node.fulltext_query',
1306
+ limit=limit,
1307
+ provider=driver.provider,
1308
+ )
1309
+ + """
1310
+ YIELD node AS m
1311
+ WHERE m.group_id = $group_id
1312
+ WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
1313
+
1314
+ WITH node,
1315
+ top_vector_nodes,
1316
+ [m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
1317
+
1318
+ WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
1319
+
1320
+ UNWIND combined_nodes AS combined_node
1321
+ WITH node, collect(DISTINCT combined_node) AS deduped_nodes
1322
+
1323
+ RETURN
1324
+ node.uuid AS search_node_uuid,
1325
+ [x IN deduped_nodes | {
1326
+ uuid: x.uuid,
1327
+ name: x.name,
1328
+ name_embedding: x.name_embedding,
1329
+ group_id: x.group_id,
1330
+ created_at: x.created_at,
1331
+ summary: x.summary,
1332
+ labels: labels(x),
1333
+ attributes: properties(x)
1334
+ }] AS matches
1335
+ """
1336
+ )
1337
+
748
1338
  results, _, _ = await driver.execute_query(
749
1339
  query,
750
- params=query_params,
751
1340
  nodes=query_nodes,
752
1341
  group_id=group_id,
753
1342
  limit=limit,
754
1343
  min_score=min_score,
755
1344
  routing_='r',
1345
+ **filter_params,
756
1346
  )
757
1347
 
758
1348
  relevant_nodes_dict: dict[str, list[EntityNode]] = {
759
1349
  result['search_node_uuid']: [
760
- get_entity_node_from_record(record) for record in result['matches']
1350
+ get_entity_node_from_record(record, driver.provider) for record in result['matches']
761
1351
  ]
762
1352
  for result in results
763
1353
  }
@@ -777,25 +1367,52 @@ async def get_relevant_edges(
777
1367
  if len(edges) == 0:
778
1368
  return []
779
1369
 
780
- query_params: dict[str, Any] = {}
1370
+ filter_queries, filter_params = edge_search_filter_query_constructor(
1371
+ search_filter, driver.provider
1372
+ )
781
1373
 
782
- filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
783
- query_params.update(filter_params)
1374
+ filter_query = ''
1375
+ if filter_queries:
1376
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
784
1377
 
785
- query = (
786
- RUNTIME_QUERY
787
- + """
788
- UNWIND $edges AS edge
789
- MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
790
- """
791
- + filter_query
792
- + """
793
- WITH e, edge, """
794
- + get_vector_cosine_func_query('e.fact_embedding', 'edge.fact_embedding', driver.provider)
795
- + """ AS score
796
- WHERE score > $min_score
797
- WITH edge, e, score
798
- ORDER BY score DESC
1378
+ if driver.provider == GraphProvider.NEPTUNE:
1379
+ query = (
1380
+ """
1381
+ UNWIND $edges AS edge
1382
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
1383
+ """
1384
+ + filter_query
1385
+ + """
1386
+ WITH e, edge
1387
+ RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding, edge.uuid as search_edge_uuid,
1388
+ edge.fact_embedding as target_embedding
1389
+ """
1390
+ )
1391
+ resp, _, _ = await driver.execute_query(
1392
+ query,
1393
+ edges=[edge.model_dump() for edge in edges],
1394
+ limit=limit,
1395
+ min_score=min_score,
1396
+ routing_='r',
1397
+ **filter_params,
1398
+ )
1399
+
1400
+ # Calculate Cosine similarity then return the edge ids
1401
+ input_ids = []
1402
+ for r in resp:
1403
+ score = calculate_cosine_similarity(
1404
+ list(map(float, r['source_embedding'].split(','))), r['target_embedding']
1405
+ )
1406
+ if score > min_score:
1407
+ input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})
1408
+
1409
+ # Match the edge ides and return the values
1410
+ query = """
1411
+ UNWIND $ids AS edge
1412
+ MATCH ()-[e]->()
1413
+ WHERE id(e) = edge.id
1414
+ WITH edge, e
1415
+ ORDER BY edge.score DESC
799
1416
  RETURN edge.uuid AS search_edge_uuid,
800
1417
  collect({
801
1418
  uuid: e.uuid,
@@ -805,28 +1422,117 @@ async def get_relevant_edges(
805
1422
  name: e.name,
806
1423
  group_id: e.group_id,
807
1424
  fact: e.fact,
808
- fact_embedding: e.fact_embedding,
809
- episodes: e.episodes,
1425
+ fact_embedding: [x IN split(e.fact_embedding, ",") | toFloat(x)],
1426
+ episodes: split(e.episodes, ","),
810
1427
  expired_at: e.expired_at,
811
1428
  valid_at: e.valid_at,
812
1429
  invalid_at: e.invalid_at,
813
1430
  attributes: properties(e)
814
1431
  })[..$limit] AS matches
815
- """
816
- )
817
-
818
- results, _, _ = await driver.execute_query(
819
- query,
820
- params=query_params,
821
- edges=[edge.model_dump() for edge in edges],
822
- limit=limit,
823
- min_score=min_score,
824
- routing_='r',
825
- )
1432
+ """
1433
+
1434
+ results, _, _ = await driver.execute_query(
1435
+ query,
1436
+ ids=input_ids,
1437
+ edges=[edge.model_dump() for edge in edges],
1438
+ limit=limit,
1439
+ min_score=min_score,
1440
+ routing_='r',
1441
+ **filter_params,
1442
+ )
1443
+ else:
1444
+ if driver.provider == GraphProvider.KUZU:
1445
+ embedding_size = (
1446
+ len(edges[0].fact_embedding) if edges[0].fact_embedding is not None else 0
1447
+ )
1448
+ if embedding_size == 0:
1449
+ return []
1450
+
1451
+ query = (
1452
+ """
1453
+ UNWIND $edges AS edge
1454
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
1455
+ """
1456
+ + filter_query
1457
+ + """
1458
+ WITH e, edge, n, m, """
1459
+ + get_vector_cosine_func_query(
1460
+ 'e.fact_embedding',
1461
+ f'CAST(edge.fact_embedding AS FLOAT[{embedding_size}])',
1462
+ driver.provider,
1463
+ )
1464
+ + """ AS score
1465
+ WHERE score > $min_score
1466
+ WITH e, edge, n, m, score
1467
+ ORDER BY score DESC
1468
+ LIMIT $limit
1469
+ RETURN
1470
+ edge.uuid AS search_edge_uuid,
1471
+ collect({
1472
+ uuid: e.uuid,
1473
+ source_node_uuid: n.uuid,
1474
+ target_node_uuid: m.uuid,
1475
+ created_at: e.created_at,
1476
+ name: e.name,
1477
+ group_id: e.group_id,
1478
+ fact: e.fact,
1479
+ fact_embedding: e.fact_embedding,
1480
+ episodes: e.episodes,
1481
+ expired_at: e.expired_at,
1482
+ valid_at: e.valid_at,
1483
+ invalid_at: e.invalid_at,
1484
+ attributes: e.attributes
1485
+ }) AS matches
1486
+ """
1487
+ )
1488
+ else:
1489
+ query = (
1490
+ """
1491
+ UNWIND $edges AS edge
1492
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
1493
+ """
1494
+ + filter_query
1495
+ + """
1496
+ WITH e, edge, """
1497
+ + get_vector_cosine_func_query(
1498
+ 'e.fact_embedding', 'edge.fact_embedding', driver.provider
1499
+ )
1500
+ + """ AS score
1501
+ WHERE score > $min_score
1502
+ WITH edge, e, score
1503
+ ORDER BY score DESC
1504
+ RETURN
1505
+ edge.uuid AS search_edge_uuid,
1506
+ collect({
1507
+ uuid: e.uuid,
1508
+ source_node_uuid: startNode(e).uuid,
1509
+ target_node_uuid: endNode(e).uuid,
1510
+ created_at: e.created_at,
1511
+ name: e.name,
1512
+ group_id: e.group_id,
1513
+ fact: e.fact,
1514
+ fact_embedding: e.fact_embedding,
1515
+ episodes: e.episodes,
1516
+ expired_at: e.expired_at,
1517
+ valid_at: e.valid_at,
1518
+ invalid_at: e.invalid_at,
1519
+ attributes: properties(e)
1520
+ })[..$limit] AS matches
1521
+ """
1522
+ )
1523
+
1524
+ results, _, _ = await driver.execute_query(
1525
+ query,
1526
+ edges=[edge.model_dump() for edge in edges],
1527
+ limit=limit,
1528
+ min_score=min_score,
1529
+ routing_='r',
1530
+ **filter_params,
1531
+ )
826
1532
 
827
1533
  relevant_edges_dict: dict[str, list[EntityEdge]] = {
828
1534
  result['search_edge_uuid']: [
829
- get_entity_edge_from_record(record) for record in result['matches']
1535
+ get_entity_edge_from_record(record, driver.provider) for record in result['matches']
830
1536
  ]
831
1537
  for result in results
832
1538
  }
@@ -846,26 +1552,54 @@ async def get_edge_invalidation_candidates(
846
1552
  if len(edges) == 0:
847
1553
  return []
848
1554
 
849
- query_params: dict[str, Any] = {}
1555
+ filter_queries, filter_params = edge_search_filter_query_constructor(
1556
+ search_filter, driver.provider
1557
+ )
850
1558
 
851
- filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
852
- query_params.update(filter_params)
1559
+ filter_query = ''
1560
+ if filter_queries:
1561
+ filter_query = ' AND ' + (' AND '.join(filter_queries))
853
1562
 
854
- query = (
855
- RUNTIME_QUERY
856
- + """
857
- UNWIND $edges AS edge
858
- MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
859
- WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
860
- """
861
- + filter_query
862
- + """
863
- WITH edge, e, """
864
- + get_vector_cosine_func_query('e.fact_embedding', 'edge.fact_embedding', driver.provider)
865
- + """ AS score
866
- WHERE score > $min_score
867
- WITH edge, e, score
868
- ORDER BY score DESC
1563
+ if driver.provider == GraphProvider.NEPTUNE:
1564
+ query = (
1565
+ """
1566
+ UNWIND $edges AS edge
1567
+ MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
1568
+ WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
1569
+ """
1570
+ + filter_query
1571
+ + """
1572
+ WITH e, edge
1573
+ RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding,
1574
+ edge.fact_embedding as target_embedding,
1575
+ edge.uuid as search_edge_uuid
1576
+ """
1577
+ )
1578
+ resp, _, _ = await driver.execute_query(
1579
+ query,
1580
+ edges=[edge.model_dump() for edge in edges],
1581
+ limit=limit,
1582
+ min_score=min_score,
1583
+ routing_='r',
1584
+ **filter_params,
1585
+ )
1586
+
1587
+ # Calculate Cosine similarity then return the edge ids
1588
+ input_ids = []
1589
+ for r in resp:
1590
+ score = calculate_cosine_similarity(
1591
+ list(map(float, r['source_embedding'].split(','))), r['target_embedding']
1592
+ )
1593
+ if score > min_score:
1594
+ input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})
1595
+
1596
+ # Match the edge ides and return the values
1597
+ query = """
1598
+ UNWIND $ids AS edge
1599
+ MATCH ()-[e]->()
1600
+ WHERE id(e) = edge.id
1601
+ WITH edge, e
1602
+ ORDER BY edge.score DESC
869
1603
  RETURN edge.uuid AS search_edge_uuid,
870
1604
  collect({
871
1605
  uuid: e.uuid,
@@ -875,27 +1609,117 @@ async def get_edge_invalidation_candidates(
875
1609
  name: e.name,
876
1610
  group_id: e.group_id,
877
1611
  fact: e.fact,
878
- fact_embedding: e.fact_embedding,
879
- episodes: e.episodes,
1612
+ fact_embedding: [x IN split(e.fact_embedding, ",") | toFloat(x)],
1613
+ episodes: split(e.episodes, ","),
880
1614
  expired_at: e.expired_at,
881
1615
  valid_at: e.valid_at,
882
1616
  invalid_at: e.invalid_at,
883
1617
  attributes: properties(e)
884
1618
  })[..$limit] AS matches
885
- """
886
- )
887
-
888
- results, _, _ = await driver.execute_query(
889
- query,
890
- params=query_params,
891
- edges=[edge.model_dump() for edge in edges],
892
- limit=limit,
893
- min_score=min_score,
894
- routing_='r',
895
- )
1619
+ """
1620
+ results, _, _ = await driver.execute_query(
1621
+ query,
1622
+ ids=input_ids,
1623
+ edges=[edge.model_dump() for edge in edges],
1624
+ limit=limit,
1625
+ min_score=min_score,
1626
+ routing_='r',
1627
+ **filter_params,
1628
+ )
1629
+ else:
1630
+ if driver.provider == GraphProvider.KUZU:
1631
+ embedding_size = (
1632
+ len(edges[0].fact_embedding) if edges[0].fact_embedding is not None else 0
1633
+ )
1634
+ if embedding_size == 0:
1635
+ return []
1636
+
1637
+ query = (
1638
+ """
1639
+ UNWIND $edges AS edge
1640
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
1641
+ WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
1642
+ """
1643
+ + filter_query
1644
+ + """
1645
+ WITH edge, e, n, m, """
1646
+ + get_vector_cosine_func_query(
1647
+ 'e.fact_embedding',
1648
+ f'CAST(edge.fact_embedding AS FLOAT[{embedding_size}])',
1649
+ driver.provider,
1650
+ )
1651
+ + """ AS score
1652
+ WHERE score > $min_score
1653
+ WITH edge, e, n, m, score
1654
+ ORDER BY score DESC
1655
+ LIMIT $limit
1656
+ RETURN
1657
+ edge.uuid AS search_edge_uuid,
1658
+ collect({
1659
+ uuid: e.uuid,
1660
+ source_node_uuid: n.uuid,
1661
+ target_node_uuid: m.uuid,
1662
+ created_at: e.created_at,
1663
+ name: e.name,
1664
+ group_id: e.group_id,
1665
+ fact: e.fact,
1666
+ fact_embedding: e.fact_embedding,
1667
+ episodes: e.episodes,
1668
+ expired_at: e.expired_at,
1669
+ valid_at: e.valid_at,
1670
+ invalid_at: e.invalid_at,
1671
+ attributes: e.attributes
1672
+ }) AS matches
1673
+ """
1674
+ )
1675
+ else:
1676
+ query = (
1677
+ """
1678
+ UNWIND $edges AS edge
1679
+ MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
1680
+ WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
1681
+ """
1682
+ + filter_query
1683
+ + """
1684
+ WITH edge, e, """
1685
+ + get_vector_cosine_func_query(
1686
+ 'e.fact_embedding', 'edge.fact_embedding', driver.provider
1687
+ )
1688
+ + """ AS score
1689
+ WHERE score > $min_score
1690
+ WITH edge, e, score
1691
+ ORDER BY score DESC
1692
+ RETURN
1693
+ edge.uuid AS search_edge_uuid,
1694
+ collect({
1695
+ uuid: e.uuid,
1696
+ source_node_uuid: startNode(e).uuid,
1697
+ target_node_uuid: endNode(e).uuid,
1698
+ created_at: e.created_at,
1699
+ name: e.name,
1700
+ group_id: e.group_id,
1701
+ fact: e.fact,
1702
+ fact_embedding: e.fact_embedding,
1703
+ episodes: e.episodes,
1704
+ expired_at: e.expired_at,
1705
+ valid_at: e.valid_at,
1706
+ invalid_at: e.invalid_at,
1707
+ attributes: properties(e)
1708
+ })[..$limit] AS matches
1709
+ """
1710
+ )
1711
+
1712
+ results, _, _ = await driver.execute_query(
1713
+ query,
1714
+ edges=[edge.model_dump() for edge in edges],
1715
+ limit=limit,
1716
+ min_score=min_score,
1717
+ routing_='r',
1718
+ **filter_params,
1719
+ )
896
1720
  invalidation_edges_dict: dict[str, list[EntityEdge]] = {
897
1721
  result['search_edge_uuid']: [
898
- get_entity_edge_from_record(record) for record in result['matches']
1722
+ get_entity_edge_from_record(record, driver.provider) for record in result['matches']
899
1723
  ]
900
1724
  for result in results
901
1725
  }
@@ -906,7 +1730,9 @@ async def get_edge_invalidation_candidates(
906
1730
 
907
1731
 
908
1732
  # takes in a list of rankings of uuids
909
- def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[str]:
1733
+ def rrf(
1734
+ results: list[list[str]], rank_const=1, min_score: float = 0
1735
+ ) -> tuple[list[str], list[float]]:
910
1736
  scores: dict[str, float] = defaultdict(float)
911
1737
  for result in results:
912
1738
  for i, uuid in enumerate(result):
@@ -917,7 +1743,9 @@ def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[st
917
1743
 
918
1744
  sorted_uuids = [term[0] for term in scored_uuids]
919
1745
 
920
- return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
1746
+ return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [
1747
+ scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score
1748
+ ]
921
1749
 
922
1750
 
923
1751
  async def node_distance_reranker(
@@ -925,24 +1753,31 @@ async def node_distance_reranker(
925
1753
  node_uuids: list[str],
926
1754
  center_node_uuid: str,
927
1755
  min_score: float = 0,
928
- ) -> list[str]:
1756
+ ) -> tuple[list[str], list[float]]:
929
1757
  # filter out node_uuid center node node uuid
930
1758
  filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
931
1759
  scores: dict[str, float] = {center_node_uuid: 0.0}
932
1760
 
933
- # Find the shortest path to center node
934
1761
  query = """
1762
+ UNWIND $node_uuids AS node_uuid
1763
+ MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
1764
+ RETURN 1 AS score, node_uuid AS uuid
1765
+ """
1766
+ if driver.provider == GraphProvider.KUZU:
1767
+ query = """
935
1768
  UNWIND $node_uuids AS node_uuid
936
- MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
1769
+ MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(n:Entity {uuid: node_uuid})
937
1770
  RETURN 1 AS score, node_uuid AS uuid
938
1771
  """
1772
+
1773
+ # Find the shortest path to center node
939
1774
  results, header, _ = await driver.execute_query(
940
1775
  query,
941
1776
  node_uuids=filtered_uuids,
942
1777
  center_uuid=center_node_uuid,
943
1778
  routing_='r',
944
1779
  )
945
- if driver.provider == 'falkordb':
1780
+ if driver.provider == GraphProvider.FALKORDB:
946
1781
  results = [dict(zip(header, row, strict=True)) for row in results]
947
1782
 
948
1783
  for result in results:
@@ -962,24 +1797,25 @@ async def node_distance_reranker(
962
1797
  scores[center_node_uuid] = 0.1
963
1798
  filtered_uuids = [center_node_uuid] + filtered_uuids
964
1799
 
965
- return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score]
1800
+ return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score], [
1801
+ 1 / scores[uuid] for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score
1802
+ ]
966
1803
 
967
1804
 
968
1805
  async def episode_mentions_reranker(
969
1806
  driver: GraphDriver, node_uuids: list[list[str]], min_score: float = 0
970
- ) -> list[str]:
1807
+ ) -> tuple[list[str], list[float]]:
971
1808
  # use rrf as a preliminary ranker
972
- sorted_uuids = rrf(node_uuids)
1809
+ sorted_uuids, _ = rrf(node_uuids)
973
1810
  scores: dict[str, float] = {}
974
1811
 
975
1812
  # Find the shortest path to center node
976
- query = """
977
- UNWIND $node_uuids AS node_uuid
1813
+ results, _, _ = await driver.execute_query(
1814
+ """
1815
+ UNWIND $node_uuids AS node_uuid
978
1816
  MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
979
1817
  RETURN count(*) AS score, n.uuid AS uuid
980
- """
981
- results, _, _ = await driver.execute_query(
982
- query,
1818
+ """,
983
1819
  node_uuids=sorted_uuids,
984
1820
  routing_='r',
985
1821
  )
@@ -987,10 +1823,16 @@ async def episode_mentions_reranker(
987
1823
  for result in results:
988
1824
  scores[result['uuid']] = result['score']
989
1825
 
1826
+ for uuid in sorted_uuids:
1827
+ if uuid not in scores:
1828
+ scores[uuid] = float('inf')
1829
+
990
1830
  # rerank on shortest distance
991
1831
  sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
992
1832
 
993
- return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
1833
+ return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [
1834
+ scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score
1835
+ ]
994
1836
 
995
1837
 
996
1838
  def maximal_marginal_relevance(
@@ -998,7 +1840,7 @@ def maximal_marginal_relevance(
998
1840
  candidates: dict[str, list[float]],
999
1841
  mmr_lambda: float = DEFAULT_MMR_LAMBDA,
1000
1842
  min_score: float = -2.0,
1001
- ) -> list[str]:
1843
+ ) -> tuple[list[str], list[float]]:
1002
1844
  start = time()
1003
1845
  query_array = np.array(query_vector)
1004
1846
  candidate_arrays: dict[str, NDArray] = {}
@@ -1029,21 +1871,36 @@ def maximal_marginal_relevance(
1029
1871
  end = time()
1030
1872
  logger.debug(f'Completed MMR reranking in {(end - start) * 1000} ms')
1031
1873
 
1032
- return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score]
1874
+ return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score], [
1875
+ mmr_scores[uuid] for uuid in uuids if mmr_scores[uuid] >= min_score
1876
+ ]
1033
1877
 
1034
1878
 
1035
1879
  async def get_embeddings_for_nodes(
1036
1880
  driver: GraphDriver, nodes: list[EntityNode]
1037
1881
  ) -> dict[str, list[float]]:
1038
- query: LiteralString = """MATCH (n:Entity)
1039
- WHERE n.uuid IN $node_uuids
1040
- RETURN DISTINCT
1041
- n.uuid AS uuid,
1042
- n.name_embedding AS name_embedding
1043
- """
1044
-
1882
+ if driver.graph_operations_interface:
1883
+ return await driver.graph_operations_interface.node_load_embeddings_bulk(driver, nodes)
1884
+ elif driver.provider == GraphProvider.NEPTUNE:
1885
+ query = """
1886
+ MATCH (n:Entity)
1887
+ WHERE n.uuid IN $node_uuids
1888
+ RETURN DISTINCT
1889
+ n.uuid AS uuid,
1890
+ split(n.name_embedding, ",") AS name_embedding
1891
+ """
1892
+ else:
1893
+ query = """
1894
+ MATCH (n:Entity)
1895
+ WHERE n.uuid IN $node_uuids
1896
+ RETURN DISTINCT
1897
+ n.uuid AS uuid,
1898
+ n.name_embedding AS name_embedding
1899
+ """
1045
1900
  results, _, _ = await driver.execute_query(
1046
- query, node_uuids=[node.uuid for node in nodes], routing_='r'
1901
+ query,
1902
+ node_uuids=[node.uuid for node in nodes],
1903
+ routing_='r',
1047
1904
  )
1048
1905
 
1049
1906
  embeddings_dict: dict[str, list[float]] = {}
@@ -1059,13 +1916,22 @@ async def get_embeddings_for_nodes(
1059
1916
  async def get_embeddings_for_communities(
1060
1917
  driver: GraphDriver, communities: list[CommunityNode]
1061
1918
  ) -> dict[str, list[float]]:
1062
- query: LiteralString = """MATCH (c:Community)
1063
- WHERE c.uuid IN $community_uuids
1064
- RETURN DISTINCT
1065
- c.uuid AS uuid,
1066
- c.name_embedding AS name_embedding
1067
- """
1068
-
1919
+ if driver.provider == GraphProvider.NEPTUNE:
1920
+ query = """
1921
+ MATCH (c:Community)
1922
+ WHERE c.uuid IN $community_uuids
1923
+ RETURN DISTINCT
1924
+ c.uuid AS uuid,
1925
+ split(c.name_embedding, ",") AS name_embedding
1926
+ """
1927
+ else:
1928
+ query = """
1929
+ MATCH (c:Community)
1930
+ WHERE c.uuid IN $community_uuids
1931
+ RETURN DISTINCT
1932
+ c.uuid AS uuid,
1933
+ c.name_embedding AS name_embedding
1934
+ """
1069
1935
  results, _, _ = await driver.execute_query(
1070
1936
  query,
1071
1937
  community_uuids=[community.uuid for community in communities],
@@ -1085,13 +1951,34 @@ async def get_embeddings_for_communities(
1085
1951
  async def get_embeddings_for_edges(
1086
1952
  driver: GraphDriver, edges: list[EntityEdge]
1087
1953
  ) -> dict[str, list[float]]:
1088
- query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
1089
- WHERE e.uuid IN $edge_uuids
1090
- RETURN DISTINCT
1091
- e.uuid AS uuid,
1092
- e.fact_embedding AS fact_embedding
1093
- """
1954
+ if driver.graph_operations_interface:
1955
+ return await driver.graph_operations_interface.edge_load_embeddings_bulk(driver, edges)
1956
+ elif driver.provider == GraphProvider.NEPTUNE:
1957
+ query = """
1958
+ MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
1959
+ WHERE e.uuid IN $edge_uuids
1960
+ RETURN DISTINCT
1961
+ e.uuid AS uuid,
1962
+ split(e.fact_embedding, ",") AS fact_embedding
1963
+ """
1964
+ else:
1965
+ match_query = """
1966
+ MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
1967
+ """
1968
+ if driver.provider == GraphProvider.KUZU:
1969
+ match_query = """
1970
+ MATCH (n:Entity)-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(m:Entity)
1971
+ """
1094
1972
 
1973
+ query = (
1974
+ match_query
1975
+ + """
1976
+ WHERE e.uuid IN $edge_uuids
1977
+ RETURN DISTINCT
1978
+ e.uuid AS uuid,
1979
+ e.fact_embedding AS fact_embedding
1980
+ """
1981
+ )
1095
1982
  results, _, _ = await driver.execute_query(
1096
1983
  query,
1097
1984
  edge_uuids=[edge.uuid for edge in edges],