graphiti-core 0.12.0rc1__py3-none-any.whl → 0.24.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. graphiti_core/cross_encoder/bge_reranker_client.py +12 -2
  2. graphiti_core/cross_encoder/gemini_reranker_client.py +161 -0
  3. graphiti_core/cross_encoder/openai_reranker_client.py +7 -5
  4. graphiti_core/decorators.py +110 -0
  5. graphiti_core/driver/__init__.py +19 -0
  6. graphiti_core/driver/driver.py +124 -0
  7. graphiti_core/driver/falkordb_driver.py +362 -0
  8. graphiti_core/driver/graph_operations/graph_operations.py +191 -0
  9. graphiti_core/driver/kuzu_driver.py +182 -0
  10. graphiti_core/driver/neo4j_driver.py +117 -0
  11. graphiti_core/driver/neptune_driver.py +305 -0
  12. graphiti_core/driver/search_interface/search_interface.py +89 -0
  13. graphiti_core/edges.py +287 -172
  14. graphiti_core/embedder/azure_openai.py +71 -0
  15. graphiti_core/embedder/client.py +2 -1
  16. graphiti_core/embedder/gemini.py +116 -22
  17. graphiti_core/embedder/voyage.py +13 -2
  18. graphiti_core/errors.py +8 -0
  19. graphiti_core/graph_queries.py +162 -0
  20. graphiti_core/graphiti.py +705 -193
  21. graphiti_core/graphiti_types.py +4 -2
  22. graphiti_core/helpers.py +87 -10
  23. graphiti_core/llm_client/__init__.py +16 -0
  24. graphiti_core/llm_client/anthropic_client.py +159 -56
  25. graphiti_core/llm_client/azure_openai_client.py +115 -0
  26. graphiti_core/llm_client/client.py +98 -21
  27. graphiti_core/llm_client/config.py +1 -1
  28. graphiti_core/llm_client/gemini_client.py +290 -41
  29. graphiti_core/llm_client/groq_client.py +14 -3
  30. graphiti_core/llm_client/openai_base_client.py +261 -0
  31. graphiti_core/llm_client/openai_client.py +56 -132
  32. graphiti_core/llm_client/openai_generic_client.py +91 -56
  33. graphiti_core/models/edges/edge_db_queries.py +259 -35
  34. graphiti_core/models/nodes/node_db_queries.py +311 -32
  35. graphiti_core/nodes.py +420 -205
  36. graphiti_core/prompts/dedupe_edges.py +46 -32
  37. graphiti_core/prompts/dedupe_nodes.py +67 -42
  38. graphiti_core/prompts/eval.py +4 -4
  39. graphiti_core/prompts/extract_edges.py +27 -16
  40. graphiti_core/prompts/extract_nodes.py +74 -31
  41. graphiti_core/prompts/prompt_helpers.py +39 -0
  42. graphiti_core/prompts/snippets.py +29 -0
  43. graphiti_core/prompts/summarize_nodes.py +23 -25
  44. graphiti_core/search/search.py +158 -82
  45. graphiti_core/search/search_config.py +39 -4
  46. graphiti_core/search/search_filters.py +126 -35
  47. graphiti_core/search/search_helpers.py +5 -6
  48. graphiti_core/search/search_utils.py +1405 -485
  49. graphiti_core/telemetry/__init__.py +9 -0
  50. graphiti_core/telemetry/telemetry.py +117 -0
  51. graphiti_core/tracer.py +193 -0
  52. graphiti_core/utils/bulk_utils.py +364 -285
  53. graphiti_core/utils/datetime_utils.py +13 -0
  54. graphiti_core/utils/maintenance/community_operations.py +67 -49
  55. graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
  56. graphiti_core/utils/maintenance/edge_operations.py +339 -197
  57. graphiti_core/utils/maintenance/graph_data_operations.py +50 -114
  58. graphiti_core/utils/maintenance/node_operations.py +319 -238
  59. graphiti_core/utils/maintenance/temporal_operations.py +11 -3
  60. graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
  61. graphiti_core/utils/text_utils.py +53 -0
  62. graphiti_core-0.24.3.dist-info/METADATA +726 -0
  63. graphiti_core-0.24.3.dist-info/RECORD +86 -0
  64. {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
  65. graphiti_core-0.12.0rc1.dist-info/METADATA +0 -350
  66. graphiti_core-0.12.0rc1.dist-info/RECORD +0 -66
  67. /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
  68. {graphiti_core-0.12.0rc1.dist-info → graphiti_core-0.24.3.dist-info/licenses}/LICENSE +0 -0
@@ -20,20 +20,31 @@ from time import time
20
20
  from typing import Any
21
21
 
22
22
  import numpy as np
23
- from neo4j import AsyncDriver, Query
24
23
  from numpy._typing import NDArray
25
24
  from typing_extensions import LiteralString
26
25
 
26
+ from graphiti_core.driver.driver import (
27
+ GraphDriver,
28
+ GraphProvider,
29
+ )
27
30
  from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
31
+ from graphiti_core.graph_queries import (
32
+ get_nodes_query,
33
+ get_relationships_query,
34
+ get_vector_cosine_func_query,
35
+ )
28
36
  from graphiti_core.helpers import (
29
- DEFAULT_DATABASE,
30
- RUNTIME_QUERY,
31
37
  lucene_sanitize,
32
38
  normalize_l2,
33
39
  semaphore_gather,
34
40
  )
41
+ from graphiti_core.models.edges.edge_db_queries import get_entity_edge_return_query
42
+ from graphiti_core.models.nodes.node_db_queries import (
43
+ COMMUNITY_NODE_RETURN,
44
+ EPISODIC_NODE_RETURN,
45
+ get_entity_node_return_query,
46
+ )
35
47
  from graphiti_core.nodes import (
36
- ENTITY_NODE_RETURN,
37
48
  CommunityNode,
38
49
  EntityNode,
39
50
  EpisodicNode,
@@ -53,16 +64,39 @@ RELEVANT_SCHEMA_LIMIT = 10
53
64
  DEFAULT_MIN_SCORE = 0.6
54
65
  DEFAULT_MMR_LAMBDA = 0.5
55
66
  MAX_SEARCH_DEPTH = 3
56
- MAX_QUERY_LENGTH = 32
67
+ MAX_QUERY_LENGTH = 128
68
+
69
+
70
+ def calculate_cosine_similarity(vector1: list[float], vector2: list[float]) -> float:
71
+ """
72
+ Calculates the cosine similarity between two vectors using NumPy.
73
+ """
74
+ dot_product = np.dot(vector1, vector2)
75
+ norm_vector1 = np.linalg.norm(vector1)
76
+ norm_vector2 = np.linalg.norm(vector2)
77
+
78
+ if norm_vector1 == 0 or norm_vector2 == 0:
79
+ return 0 # Handle cases where one or both vectors are zero vectors
80
+
81
+ return dot_product / (norm_vector1 * norm_vector2)
57
82
 
58
83
 
59
- def fulltext_query(query: str, group_ids: list[str] | None = None):
84
+ def fulltext_query(query: str, group_ids: list[str] | None, driver: GraphDriver):
85
+ if driver.provider == GraphProvider.KUZU:
86
+ # Kuzu only supports simple queries.
87
+ if len(query.split(' ')) > MAX_QUERY_LENGTH:
88
+ return ''
89
+ return query
90
+ elif driver.provider == GraphProvider.FALKORDB:
91
+ return driver.build_fulltext_query(query, group_ids, MAX_QUERY_LENGTH)
60
92
  group_ids_filter_list = (
61
- [f'group_id:"{lucene_sanitize(g)}"' for g in group_ids] if group_ids is not None else []
93
+ [driver.fulltext_syntax + f'group_id:"{g}"' for g in group_ids]
94
+ if group_ids is not None
95
+ else []
62
96
  )
63
97
  group_ids_filter = ''
64
98
  for f in group_ids_filter_list:
65
- group_ids_filter += f if not group_ids_filter else f'OR {f}'
99
+ group_ids_filter += f if not group_ids_filter else f' OR {f}'
66
100
 
67
101
  group_ids_filter += ' AND ' if group_ids_filter else ''
68
102
 
@@ -77,7 +111,7 @@ def fulltext_query(query: str, group_ids: list[str] | None = None):
77
111
 
78
112
 
79
113
  async def get_episodes_by_mentions(
80
- driver: AsyncDriver,
114
+ driver: GraphDriver,
81
115
  nodes: list[EntityNode],
82
116
  edges: list[EntityEdge],
83
117
  limit: int = RELEVANT_SCHEMA_LIMIT,
@@ -92,47 +126,39 @@ async def get_episodes_by_mentions(
92
126
 
93
127
 
94
128
  async def get_mentioned_nodes(
95
- driver: AsyncDriver, episodes: list[EpisodicNode]
129
+ driver: GraphDriver, episodes: list[EpisodicNode]
96
130
  ) -> list[EntityNode]:
97
131
  episode_uuids = [episode.uuid for episode in episodes]
132
+
98
133
  records, _, _ = await driver.execute_query(
99
134
  """
100
- MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
135
+ MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity)
136
+ WHERE episode.uuid IN $uuids
101
137
  RETURN DISTINCT
102
- n.uuid As uuid,
103
- n.group_id AS group_id,
104
- n.name AS name,
105
- n.created_at AS created_at,
106
- n.summary AS summary,
107
- labels(n) AS labels,
108
- properties(n) AS attributes
109
- """,
138
+ """
139
+ + get_entity_node_return_query(driver.provider),
110
140
  uuids=episode_uuids,
111
- database_=DEFAULT_DATABASE,
112
141
  routing_='r',
113
142
  )
114
143
 
115
- nodes = [get_entity_node_from_record(record) for record in records]
144
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
116
145
 
117
146
  return nodes
118
147
 
119
148
 
120
149
  async def get_communities_by_nodes(
121
- driver: AsyncDriver, nodes: list[EntityNode]
150
+ driver: GraphDriver, nodes: list[EntityNode]
122
151
  ) -> list[CommunityNode]:
123
152
  node_uuids = [node.uuid for node in nodes]
153
+
124
154
  records, _, _ = await driver.execute_query(
125
155
  """
126
- MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
127
- RETURN DISTINCT
128
- c.uuid As uuid,
129
- c.group_id AS group_id,
130
- c.name AS name,
131
- c.created_at AS created_at,
132
- c.summary AS summary
133
- """,
156
+ MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)
157
+ WHERE m.uuid IN $uuids
158
+ RETURN DISTINCT
159
+ """
160
+ + COMMUNITY_NODE_RETURN,
134
161
  uuids=node_uuids,
135
- database_=DEFAULT_DATABASE,
136
162
  routing_='r',
137
163
  )
138
164
 
@@ -142,61 +168,122 @@ async def get_communities_by_nodes(
142
168
 
143
169
 
144
170
  async def edge_fulltext_search(
145
- driver: AsyncDriver,
171
+ driver: GraphDriver,
146
172
  query: str,
147
173
  search_filter: SearchFilters,
148
174
  group_ids: list[str] | None = None,
149
175
  limit=RELEVANT_SCHEMA_LIMIT,
150
176
  ) -> list[EntityEdge]:
177
+ if driver.search_interface:
178
+ return await driver.search_interface.edge_fulltext_search(
179
+ driver, query, search_filter, group_ids, limit
180
+ )
181
+
151
182
  # fulltext search over facts
152
- fuzzy_query = fulltext_query(query, group_ids)
183
+ fuzzy_query = fulltext_query(query, group_ids, driver)
184
+
153
185
  if fuzzy_query == '':
154
186
  return []
155
187
 
156
- filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
157
-
158
- cypher_query = Query(
188
+ match_query = """
189
+ YIELD relationship AS rel, score
190
+ MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
191
+ """
192
+ if driver.provider == GraphProvider.KUZU:
193
+ match_query = """
194
+ YIELD node, score
195
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: node.uuid})-[:RELATES_TO]->(m:Entity)
159
196
  """
160
- CALL db.index.fulltext.queryRelationships("edge_name_and_fact", $query, {limit: $limit})
161
- YIELD relationship AS rel, score
162
- MATCH (:Entity)-[r:RELATES_TO]->(:Entity)
163
- WHERE r.group_id IN $group_ids"""
164
- + filter_query
165
- + """\nWITH r, score, startNode(r) AS n, endNode(r) AS m
166
- RETURN
167
- r.uuid AS uuid,
168
- r.group_id AS group_id,
169
- n.uuid AS source_node_uuid,
170
- m.uuid AS target_node_uuid,
171
- r.created_at AS created_at,
172
- r.name AS name,
173
- r.fact AS fact,
174
- r.episodes AS episodes,
175
- r.expired_at AS expired_at,
176
- r.valid_at AS valid_at,
177
- r.invalid_at AS invalid_at,
178
- properties(r) AS attributes
179
- ORDER BY score DESC LIMIT $limit
180
- """
181
- )
182
197
 
183
- records, _, _ = await driver.execute_query(
184
- cypher_query,
185
- filter_params,
186
- query=fuzzy_query,
187
- group_ids=group_ids,
188
- limit=limit,
189
- database_=DEFAULT_DATABASE,
190
- routing_='r',
198
+ filter_queries, filter_params = edge_search_filter_query_constructor(
199
+ search_filter, driver.provider
191
200
  )
192
201
 
193
- edges = [get_entity_edge_from_record(record) for record in records]
202
+ if group_ids is not None:
203
+ filter_queries.append('e.group_id IN $group_ids')
204
+ filter_params['group_ids'] = group_ids
205
+
206
+ filter_query = ''
207
+ if filter_queries:
208
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
209
+
210
+ if driver.provider == GraphProvider.NEPTUNE:
211
+ res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue
212
+ if res['hits']['total']['value'] > 0:
213
+ input_ids = []
214
+ for r in res['hits']['hits']:
215
+ input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
216
+
217
+ # Match the edge ids and return the values
218
+ query = (
219
+ """
220
+ UNWIND $ids as id
221
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
222
+ WHERE e.group_id IN $group_ids
223
+ AND id(e)=id
224
+ """
225
+ + filter_query
226
+ + """
227
+ AND id(e)=id
228
+ WITH e, id.score as score, startNode(e) AS n, endNode(e) AS m
229
+ RETURN
230
+ e.uuid AS uuid,
231
+ e.group_id AS group_id,
232
+ n.uuid AS source_node_uuid,
233
+ m.uuid AS target_node_uuid,
234
+ e.created_at AS created_at,
235
+ e.name AS name,
236
+ e.fact AS fact,
237
+ split(e.episodes, ",") AS episodes,
238
+ e.expired_at AS expired_at,
239
+ e.valid_at AS valid_at,
240
+ e.invalid_at AS invalid_at,
241
+ properties(e) AS attributes
242
+ ORDER BY score DESC LIMIT $limit
243
+ """
244
+ )
245
+
246
+ records, _, _ = await driver.execute_query(
247
+ query,
248
+ query=fuzzy_query,
249
+ ids=input_ids,
250
+ limit=limit,
251
+ routing_='r',
252
+ **filter_params,
253
+ )
254
+ else:
255
+ return []
256
+ else:
257
+ query = (
258
+ get_relationships_query('edge_name_and_fact', limit=limit, provider=driver.provider)
259
+ + match_query
260
+ + filter_query
261
+ + """
262
+ WITH e, score, n, m
263
+ RETURN
264
+ """
265
+ + get_entity_edge_return_query(driver.provider)
266
+ + """
267
+ ORDER BY score DESC
268
+ LIMIT $limit
269
+ """
270
+ )
271
+
272
+ records, _, _ = await driver.execute_query(
273
+ query,
274
+ query=fuzzy_query,
275
+ limit=limit,
276
+ routing_='r',
277
+ **filter_params,
278
+ )
279
+
280
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
194
281
 
195
282
  return edges
196
283
 
197
284
 
198
285
  async def edge_similarity_search(
199
- driver: AsyncDriver,
286
+ driver: GraphDriver,
200
287
  search_vector: list[float],
201
288
  source_node_uuid: str | None,
202
289
  target_node_uuid: str | None,
@@ -205,34 +292,85 @@ async def edge_similarity_search(
205
292
  limit: int = RELEVANT_SCHEMA_LIMIT,
206
293
  min_score: float = DEFAULT_MIN_SCORE,
207
294
  ) -> list[EntityEdge]:
208
- # vector similarity search over embedded facts
209
- query_params: dict[str, Any] = {}
295
+ if driver.search_interface:
296
+ return await driver.search_interface.edge_similarity_search(
297
+ driver,
298
+ search_vector,
299
+ source_node_uuid,
300
+ target_node_uuid,
301
+ search_filter,
302
+ group_ids,
303
+ limit,
304
+ min_score,
305
+ )
210
306
 
211
- filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
212
- query_params.update(filter_params)
307
+ match_query = """
308
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
309
+ """
310
+ if driver.provider == GraphProvider.KUZU:
311
+ match_query = """
312
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
313
+ """
314
+
315
+ filter_queries, filter_params = edge_search_filter_query_constructor(
316
+ search_filter, driver.provider
317
+ )
213
318
 
214
- group_filter_query: LiteralString = ''
215
319
  if group_ids is not None:
216
- group_filter_query += 'WHERE r.group_id IN $group_ids'
217
- query_params['group_ids'] = group_ids
218
- query_params['source_node_uuid'] = source_node_uuid
219
- query_params['target_node_uuid'] = target_node_uuid
320
+ filter_queries.append('e.group_id IN $group_ids')
321
+ filter_params['group_ids'] = group_ids
220
322
 
221
323
  if source_node_uuid is not None:
222
- group_filter_query += '\nAND (n.uuid IN [$source_uuid, $target_uuid])'
324
+ filter_params['source_uuid'] = source_node_uuid
325
+ filter_queries.append('n.uuid = $source_uuid')
223
326
 
224
327
  if target_node_uuid is not None:
225
- group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
226
-
227
- query: LiteralString = (
228
- RUNTIME_QUERY
229
- + """
230
- MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
231
- """
232
- + group_filter_query
233
- + filter_query
234
- + """\nWITH DISTINCT r, vector.similarity.cosine(r.fact_embedding, $search_vector) AS score
235
- WHERE score > $min_score
328
+ filter_params['target_uuid'] = target_node_uuid
329
+ filter_queries.append('m.uuid = $target_uuid')
330
+
331
+ filter_query = ''
332
+ if filter_queries:
333
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
334
+
335
+ search_vector_var = '$search_vector'
336
+ if driver.provider == GraphProvider.KUZU:
337
+ search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
338
+
339
+ if driver.provider == GraphProvider.NEPTUNE:
340
+ query = (
341
+ """
342
+ MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
343
+ """
344
+ + filter_query
345
+ + """
346
+ RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
347
+ """
348
+ )
349
+ resp, header, _ = await driver.execute_query(
350
+ query,
351
+ search_vector=search_vector,
352
+ limit=limit,
353
+ min_score=min_score,
354
+ routing_='r',
355
+ **filter_params,
356
+ )
357
+
358
+ if len(resp) > 0:
359
+ # Calculate Cosine similarity then return the edge ids
360
+ input_ids = []
361
+ for r in resp:
362
+ if r['embedding']:
363
+ score = calculate_cosine_similarity(
364
+ search_vector, list(map(float, r['embedding'].split(',')))
365
+ )
366
+ if score > min_score:
367
+ input_ids.append({'id': r['id'], 'score': score})
368
+
369
+ # Match the edge ides and return the values
370
+ query = """
371
+ UNWIND $ids as i
372
+ MATCH ()-[r]->()
373
+ WHERE id(r) = i.id
236
374
  RETURN
237
375
  r.uuid AS uuid,
238
376
  r.group_id AS group_id,
@@ -241,292 +379,648 @@ async def edge_similarity_search(
241
379
  r.created_at AS created_at,
242
380
  r.name AS name,
243
381
  r.fact AS fact,
244
- r.episodes AS episodes,
382
+ split(r.episodes, ",") AS episodes,
245
383
  r.expired_at AS expired_at,
246
384
  r.valid_at AS valid_at,
247
385
  r.invalid_at AS invalid_at,
248
386
  properties(r) AS attributes
249
- ORDER BY score DESC
387
+ ORDER BY i.score DESC
250
388
  LIMIT $limit
251
- """
252
- )
389
+ """
390
+ records, _, _ = await driver.execute_query(
391
+ query,
392
+ ids=input_ids,
393
+ search_vector=search_vector,
394
+ limit=limit,
395
+ min_score=min_score,
396
+ routing_='r',
397
+ **filter_params,
398
+ )
399
+ else:
400
+ return []
401
+ else:
402
+ query = (
403
+ match_query
404
+ + filter_query
405
+ + """
406
+ WITH DISTINCT e, n, m, """
407
+ + get_vector_cosine_func_query('e.fact_embedding', search_vector_var, driver.provider)
408
+ + """ AS score
409
+ WHERE score > $min_score
410
+ RETURN
411
+ """
412
+ + get_entity_edge_return_query(driver.provider)
413
+ + """
414
+ ORDER BY score DESC
415
+ LIMIT $limit
416
+ """
417
+ )
253
418
 
254
- records, _, _ = await driver.execute_query(
255
- query,
256
- query_params,
257
- search_vector=search_vector,
258
- source_uuid=source_node_uuid,
259
- target_uuid=target_node_uuid,
260
- group_ids=group_ids,
261
- limit=limit,
262
- min_score=min_score,
263
- database_=DEFAULT_DATABASE,
264
- routing_='r',
265
- )
419
+ records, _, _ = await driver.execute_query(
420
+ query,
421
+ search_vector=search_vector,
422
+ limit=limit,
423
+ min_score=min_score,
424
+ routing_='r',
425
+ **filter_params,
426
+ )
266
427
 
267
- edges = [get_entity_edge_from_record(record) for record in records]
428
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
268
429
 
269
430
  return edges
270
431
 
271
432
 
272
433
  async def edge_bfs_search(
273
- driver: AsyncDriver,
434
+ driver: GraphDriver,
274
435
  bfs_origin_node_uuids: list[str] | None,
275
436
  bfs_max_depth: int,
276
437
  search_filter: SearchFilters,
277
- limit: int,
438
+ group_ids: list[str] | None = None,
439
+ limit: int = RELEVANT_SCHEMA_LIMIT,
278
440
  ) -> list[EntityEdge]:
279
441
  # vector similarity search over embedded facts
280
- if bfs_origin_node_uuids is None:
442
+ if bfs_origin_node_uuids is None or len(bfs_origin_node_uuids) == 0:
281
443
  return []
282
444
 
283
- filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
445
+ filter_queries, filter_params = edge_search_filter_query_constructor(
446
+ search_filter, driver.provider
447
+ )
284
448
 
285
- query = Query(
286
- """
449
+ if group_ids is not None:
450
+ filter_queries.append('e.group_id IN $group_ids')
451
+ filter_params['group_ids'] = group_ids
452
+
453
+ filter_query = ''
454
+ if filter_queries:
455
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
456
+
457
+ if driver.provider == GraphProvider.KUZU:
458
+ # Kuzu stores entity edges twice with an intermediate node, so we need to match them
459
+ # separately for the correct BFS depth.
460
+ depth = bfs_max_depth * 2 - 1
461
+ match_queries = [
462
+ f"""
463
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
464
+ MATCH path = (origin:Entity {{uuid: origin_uuid}})-[:RELATES_TO*1..{depth}]->(:RelatesToNode_)
465
+ UNWIND nodes(path) AS relNode
466
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {{uuid: relNode.uuid}})-[:RELATES_TO]->(m:Entity)
467
+ """,
468
+ ]
469
+ if bfs_max_depth > 1:
470
+ depth = (bfs_max_depth - 1) * 2 - 1
471
+ match_queries.append(f"""
287
472
  UNWIND $bfs_origin_node_uuids AS origin_uuid
288
- MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
473
+ MATCH path = (origin:Episodic {{uuid: origin_uuid}})-[:MENTIONS]->(:Entity)-[:RELATES_TO*1..{depth}]->(:RelatesToNode_)
474
+ UNWIND nodes(path) AS relNode
475
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {{uuid: relNode.uuid}})-[:RELATES_TO]->(m:Entity)
476
+ """)
477
+
478
+ records = []
479
+ for match_query in match_queries:
480
+ sub_records, _, _ = await driver.execute_query(
481
+ match_query
482
+ + filter_query
483
+ + """
484
+ RETURN DISTINCT
485
+ """
486
+ + get_entity_edge_return_query(driver.provider)
487
+ + """
488
+ LIMIT $limit
489
+ """,
490
+ bfs_origin_node_uuids=bfs_origin_node_uuids,
491
+ limit=limit,
492
+ routing_='r',
493
+ **filter_params,
494
+ )
495
+ records.extend(sub_records)
496
+ else:
497
+ if driver.provider == GraphProvider.NEPTUNE:
498
+ query = (
499
+ f"""
500
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
501
+ MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS *1..{bfs_max_depth}]->(n:Entity)
502
+ WHERE origin:Entity OR origin:Episodic
289
503
  UNWIND relationships(path) AS rel
290
- MATCH ()-[r:RELATES_TO]-()
291
- WHERE r.uuid = rel.uuid
504
+ MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
292
505
  """
293
- + filter_query
294
- + """
506
+ + filter_query
507
+ + """
295
508
  RETURN DISTINCT
296
- r.uuid AS uuid,
297
- r.group_id AS group_id,
298
- startNode(r).uuid AS source_node_uuid,
299
- endNode(r).uuid AS target_node_uuid,
300
- r.created_at AS created_at,
301
- r.name AS name,
302
- r.fact AS fact,
303
- r.episodes AS episodes,
304
- r.expired_at AS expired_at,
305
- r.valid_at AS valid_at,
306
- r.invalid_at AS invalid_at,
307
- properties(r) AS attributes
509
+ e.uuid AS uuid,
510
+ e.group_id AS group_id,
511
+ startNode(e).uuid AS source_node_uuid,
512
+ endNode(e).uuid AS target_node_uuid,
513
+ e.created_at AS created_at,
514
+ e.name AS name,
515
+ e.fact AS fact,
516
+ split(e.episodes, ',') AS episodes,
517
+ e.expired_at AS expired_at,
518
+ e.valid_at AS valid_at,
519
+ e.invalid_at AS invalid_at,
520
+ properties(e) AS attributes
308
521
  LIMIT $limit
309
- """
310
- )
311
-
312
- records, _, _ = await driver.execute_query(
313
- query,
314
- filter_params,
315
- bfs_origin_node_uuids=bfs_origin_node_uuids,
316
- depth=bfs_max_depth,
317
- limit=limit,
318
- database_=DEFAULT_DATABASE,
319
- routing_='r',
320
- )
522
+ """
523
+ )
524
+ else:
525
+ query = (
526
+ f"""
527
+ UNWIND $bfs_origin_node_uuids AS origin_uuid
528
+ MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
529
+ UNWIND relationships(path) AS rel
530
+ MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
531
+ """
532
+ + filter_query
533
+ + """
534
+ RETURN DISTINCT
535
+ """
536
+ + get_entity_edge_return_query(driver.provider)
537
+ + """
538
+ LIMIT $limit
539
+ """
540
+ )
541
+
542
+ records, _, _ = await driver.execute_query(
543
+ query,
544
+ bfs_origin_node_uuids=bfs_origin_node_uuids,
545
+ depth=bfs_max_depth,
546
+ limit=limit,
547
+ routing_='r',
548
+ **filter_params,
549
+ )
321
550
 
322
- edges = [get_entity_edge_from_record(record) for record in records]
551
+ edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
323
552
 
324
553
  return edges
325
554
 
326
555
 
327
556
  async def node_fulltext_search(
328
- driver: AsyncDriver,
557
+ driver: GraphDriver,
329
558
  query: str,
330
559
  search_filter: SearchFilters,
331
560
  group_ids: list[str] | None = None,
332
561
  limit=RELEVANT_SCHEMA_LIMIT,
333
562
  ) -> list[EntityNode]:
563
+ if driver.search_interface:
564
+ return await driver.search_interface.node_fulltext_search(
565
+ driver, query, search_filter, group_ids, limit
566
+ )
567
+
334
568
  # BM25 search to get top nodes
335
- fuzzy_query = fulltext_query(query, group_ids)
569
+ fuzzy_query = fulltext_query(query, group_ids, driver)
336
570
  if fuzzy_query == '':
337
571
  return []
338
572
 
339
- filter_query, filter_params = node_search_filter_query_constructor(search_filter)
340
-
341
- query = (
342
- """
343
- CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
344
- YIELD node AS n, score
345
- WHERE n:Entity
346
- """
347
- + filter_query
348
- + ENTITY_NODE_RETURN
349
- + """
350
- ORDER BY score DESC
351
- """
573
+ filter_queries, filter_params = node_search_filter_query_constructor(
574
+ search_filter, driver.provider
352
575
  )
353
576
 
354
- records, _, _ = await driver.execute_query(
355
- query,
356
- filter_params,
357
- query=fuzzy_query,
358
- group_ids=group_ids,
359
- limit=limit,
360
- database_=DEFAULT_DATABASE,
361
- routing_='r',
362
- )
363
- nodes = [get_entity_node_from_record(record) for record in records]
577
+ if group_ids is not None:
578
+ filter_queries.append('n.group_id IN $group_ids')
579
+ filter_params['group_ids'] = group_ids
580
+
581
+ filter_query = ''
582
+ if filter_queries:
583
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
584
+
585
+ yield_query = 'YIELD node AS n, score'
586
+ if driver.provider == GraphProvider.KUZU:
587
+ yield_query = 'WITH node AS n, score'
588
+
589
+ if driver.provider == GraphProvider.NEPTUNE:
590
+ res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
591
+ if res['hits']['total']['value'] > 0:
592
+ input_ids = []
593
+ for r in res['hits']['hits']:
594
+ input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
595
+
596
+ # Match the edge ides and return the values
597
+ query = (
598
+ """
599
+ UNWIND $ids as i
600
+ MATCH (n:Entity)
601
+ WHERE n.uuid=i.id
602
+ RETURN
603
+ """
604
+ + get_entity_node_return_query(driver.provider)
605
+ + """
606
+ ORDER BY i.score DESC
607
+ LIMIT $limit
608
+ """
609
+ )
610
+ records, _, _ = await driver.execute_query(
611
+ query,
612
+ ids=input_ids,
613
+ query=fuzzy_query,
614
+ limit=limit,
615
+ routing_='r',
616
+ **filter_params,
617
+ )
618
+ else:
619
+ return []
620
+ else:
621
+ query = (
622
+ get_nodes_query(
623
+ 'node_name_and_summary', '$query', limit=limit, provider=driver.provider
624
+ )
625
+ + yield_query
626
+ + filter_query
627
+ + """
628
+ WITH n, score
629
+ ORDER BY score DESC
630
+ LIMIT $limit
631
+ RETURN
632
+ """
633
+ + get_entity_node_return_query(driver.provider)
634
+ )
635
+
636
+ records, _, _ = await driver.execute_query(
637
+ query,
638
+ query=fuzzy_query,
639
+ limit=limit,
640
+ routing_='r',
641
+ **filter_params,
642
+ )
643
+
644
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
364
645
 
365
646
  return nodes
366
647
 
367
648
 
368
649
  async def node_similarity_search(
369
- driver: AsyncDriver,
650
+ driver: GraphDriver,
370
651
  search_vector: list[float],
371
652
  search_filter: SearchFilters,
372
653
  group_ids: list[str] | None = None,
373
654
  limit=RELEVANT_SCHEMA_LIMIT,
374
655
  min_score: float = DEFAULT_MIN_SCORE,
375
656
  ) -> list[EntityNode]:
376
- # vector similarity search over entity names
377
- query_params: dict[str, Any] = {}
657
+ if driver.search_interface:
658
+ return await driver.search_interface.node_similarity_search(
659
+ driver, search_vector, search_filter, group_ids, limit, min_score
660
+ )
661
+
662
+ filter_queries, filter_params = node_search_filter_query_constructor(
663
+ search_filter, driver.provider
664
+ )
378
665
 
379
- group_filter_query: LiteralString = ''
380
666
  if group_ids is not None:
381
- group_filter_query += 'WHERE n.group_id IN $group_ids'
382
- query_params['group_ids'] = group_ids
667
+ filter_queries.append('n.group_id IN $group_ids')
668
+ filter_params['group_ids'] = group_ids
383
669
 
384
- filter_query, filter_params = node_search_filter_query_constructor(search_filter)
385
- query_params.update(filter_params)
670
+ filter_query = ''
671
+ if filter_queries:
672
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
386
673
 
387
- records, _, _ = await driver.execute_query(
388
- RUNTIME_QUERY
389
- + """
390
- MATCH (n:Entity)
674
+ search_vector_var = '$search_vector'
675
+ if driver.provider == GraphProvider.KUZU:
676
+ search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
677
+
678
+ if driver.provider == GraphProvider.NEPTUNE:
679
+ query = (
391
680
  """
392
- + group_filter_query
393
- + filter_query
394
- + """
395
- WITH n, vector.similarity.cosine(n.name_embedding, $search_vector) AS score
396
- WHERE score > $min_score"""
397
- + ENTITY_NODE_RETURN
398
- + """
399
- ORDER BY score DESC
400
- LIMIT $limit
401
- """,
402
- query_params,
403
- search_vector=search_vector,
404
- group_ids=group_ids,
405
- limit=limit,
406
- min_score=min_score,
407
- database_=DEFAULT_DATABASE,
408
- routing_='r',
409
- )
410
- nodes = [get_entity_node_from_record(record) for record in records]
681
+ MATCH (n:Entity)
682
+ """
683
+ + filter_query
684
+ + """
685
+ RETURN DISTINCT id(n) as id, n.name_embedding as embedding
686
+ """
687
+ )
688
+ resp, header, _ = await driver.execute_query(
689
+ query,
690
+ params=filter_params,
691
+ search_vector=search_vector,
692
+ limit=limit,
693
+ min_score=min_score,
694
+ routing_='r',
695
+ )
696
+
697
+ if len(resp) > 0:
698
+ # Calculate Cosine similarity then return the edge ids
699
+ input_ids = []
700
+ for r in resp:
701
+ if r['embedding']:
702
+ score = calculate_cosine_similarity(
703
+ search_vector, list(map(float, r['embedding'].split(',')))
704
+ )
705
+ if score > min_score:
706
+ input_ids.append({'id': r['id'], 'score': score})
707
+
708
+ # Match the edge ides and return the values
709
+ query = (
710
+ """
711
+ UNWIND $ids as i
712
+ MATCH (n:Entity)
713
+ WHERE id(n)=i.id
714
+ RETURN
715
+ """
716
+ + get_entity_node_return_query(driver.provider)
717
+ + """
718
+ ORDER BY i.score DESC
719
+ LIMIT $limit
720
+ """
721
+ )
722
+ records, header, _ = await driver.execute_query(
723
+ query,
724
+ ids=input_ids,
725
+ search_vector=search_vector,
726
+ limit=limit,
727
+ min_score=min_score,
728
+ routing_='r',
729
+ **filter_params,
730
+ )
731
+ else:
732
+ return []
733
+ else:
734
+ query = (
735
+ """
736
+ MATCH (n:Entity)
737
+ """
738
+ + filter_query
739
+ + """
740
+ WITH n, """
741
+ + get_vector_cosine_func_query('n.name_embedding', search_vector_var, driver.provider)
742
+ + """ AS score
743
+ WHERE score > $min_score
744
+ RETURN
745
+ """
746
+ + get_entity_node_return_query(driver.provider)
747
+ + """
748
+ ORDER BY score DESC
749
+ LIMIT $limit
750
+ """
751
+ )
752
+
753
+ records, _, _ = await driver.execute_query(
754
+ query,
755
+ search_vector=search_vector,
756
+ limit=limit,
757
+ min_score=min_score,
758
+ routing_='r',
759
+ **filter_params,
760
+ )
761
+
762
+ nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
411
763
 
412
764
  return nodes
413
765
 
414
766
 
415
767
async def node_bfs_search(
    driver: GraphDriver,
    bfs_origin_node_uuids: list[str] | None,
    search_filter: SearchFilters,
    bfs_max_depth: int,
    group_ids: list[str] | None = None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
    """
    Collect entity nodes reachable from a set of origin nodes by graph traversal.

    Parameters
    ----------
    driver : GraphDriver
        Graph driver used to run the queries; ``driver.provider`` selects the
        provider-specific traversal pattern.
    bfs_origin_node_uuids : list[str] | None
        UUIDs of the nodes the traversal starts from. ``None`` or an empty list
        yields an empty result.
    search_filter : SearchFilters
        Filters turned into extra ``AND`` predicates by
        ``node_search_filter_query_constructor``.
    bfs_max_depth : int
        Maximum number of hops to follow. Values below 1 yield an empty result.
    group_ids : list[str] | None, optional
        When given, both the origin and the reached node must belong to one of
        these groups.
    limit : int, optional
        ``LIMIT`` applied to every individual traversal query.

    Returns
    -------
    list[EntityNode]
        The nodes built from all returned records, in query order. When several
        traversal queries are issued (Kuzu), their results are concatenated.
    """
    if bfs_origin_node_uuids is None or len(bfs_origin_node_uuids) == 0 or bfs_max_depth < 1:
        return []

    filter_queries, filter_params = node_search_filter_query_constructor(
        search_filter, driver.provider
    )

    if group_ids is not None:
        filter_queries.append('n.group_id IN $group_ids')
        filter_queries.append('origin.group_id IN $group_ids')
        filter_params['group_ids'] = group_ids

    # Every match query below ends inside a WHERE clause, so the extra
    # predicates are chained on with a leading AND.
    filter_query = ''
    if filter_queries:
        filter_query = ' AND ' + (' AND '.join(filter_queries))

    if driver.provider == GraphProvider.NEPTUNE:
        # The label test is parenthesised on purpose: AND binds tighter than OR,
        # so without the parentheses Entity origins would skip the group check
        # and every appended filter. `origin:Episodic` is a label predicate; the
        # former `origin.Episode` was a property lookup and never matched.
        match_queries = [
            f"""
            UNWIND $bfs_origin_node_uuids AS origin_uuid
            MATCH (origin {{uuid: origin_uuid}})-[e:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
            WHERE (origin:Entity OR origin:Episodic)
            AND n.group_id = origin.group_id
            """
        ]
    elif driver.provider == GraphProvider.KUZU:
        # The RELATES_TO hop count is doubled for Kuzu.
        # NOTE(review): presumably because each logical edge is stored as
        # Entity -> RelatesToNode_ -> Entity (two relationships) - confirm.
        depth = bfs_max_depth * 2
        match_queries = [
            """
            UNWIND $bfs_origin_node_uuids AS origin_uuid
            MATCH (origin:Episodic {uuid: origin_uuid})-[:MENTIONS]->(n:Entity)
            WHERE n.group_id = origin.group_id
            """,
            f"""
            UNWIND $bfs_origin_node_uuids AS origin_uuid
            MATCH (origin:Entity {{uuid: origin_uuid}})-[:RELATES_TO*2..{depth}]->(n:Entity)
            WHERE n.group_id = origin.group_id
            """,
        ]
        if bfs_max_depth > 1:
            # One hop is spent on the MENTIONS edge out of the episode.
            depth = (bfs_max_depth - 1) * 2
            match_queries.append(f"""
                UNWIND $bfs_origin_node_uuids AS origin_uuid
                MATCH (origin:Episodic {{uuid: origin_uuid}})-[:MENTIONS]->(:Entity)-[:RELATES_TO*2..{depth}]->(n:Entity)
                WHERE n.group_id = origin.group_id
                """)
    else:
        match_queries = [
            f"""
            UNWIND $bfs_origin_node_uuids AS origin_uuid
            MATCH (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
            WHERE n.group_id = origin.group_id
            """
        ]

    # The tail of the statement is identical for every match query.
    return_clause = (
        """
        RETURN
        """
        + get_entity_node_return_query(driver.provider)
        + """
        LIMIT $limit
        """
    )

    records = []
    for match_query in match_queries:
        sub_records, _, _ = await driver.execute_query(
            match_query + filter_query + return_clause,
            bfs_origin_node_uuids=bfs_origin_node_uuids,
            limit=limit,
            routing_='r',
            **filter_params,
        )
        records.extend(sub_records)

    nodes = [get_entity_node_from_record(record, driver.provider) for record in records]

    return nodes
449
853
 
450
854
 
451
855
async def episode_fulltext_search(
    driver: GraphDriver,
    query: str,
    _search_filter: SearchFilters,
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EpisodicNode]:
    """
    Run a fulltext search over episode content and return the matching episodes.

    Parameters
    ----------
    driver : GraphDriver
        Graph driver used to run the search. If ``driver.search_interface`` is
        set, the whole search is delegated to it.
    query : str
        The text to search for.
    _search_filter : SearchFilters
        Only forwarded to ``driver.search_interface``; unused otherwise.
    group_ids : list[str] | None, optional
        When given, only episodes whose ``group_id`` is in this list are returned.
    limit : int, optional
        Maximum number of episodes to return.

    Returns
    -------
    list[EpisodicNode]
        Matching episodes ordered by descending score; empty when the fulltext
        query is empty or nothing matches.
    """
    if driver.search_interface:
        return await driver.search_interface.episode_fulltext_search(
            driver, query, _search_filter, group_ids, limit
        )

    # BM25 search to get top episodes
    fuzzy_query = fulltext_query(query, group_ids, driver)
    if fuzzy_query == '':
        return []

    filter_params: dict[str, Any] = {}
    group_filter_query: LiteralString = ''
    if group_ids is not None:
        group_filter_query += '\nAND e.group_id IN $group_ids'
        filter_params['group_ids'] = group_ids

    if driver.provider == GraphProvider.NEPTUNE:
        res = driver.run_aoss_query('episode_content', query, limit=limit)  # pyright: ignore reportAttributeAccessIssue
        if res['hits']['total']['value'] <= 0:
            return []

        # Each row carries the episode uuid under the key `id`; the Cypher
        # below must therefore read `i.id` (reading `i.uuid` is always null).
        input_ids = [
            {'id': hit['_source']['uuid'], 'score': hit['_score']}
            for hit in res['hits']['hits']
        ]

        # Match the episode uuids and return the values. The group filter is
        # applied here as well because the search above is called without
        # group_ids.
        cypher = (
            """
            UNWIND $ids as i
            MATCH (e:Episodic)
            WHERE e.uuid=i.id
            """
            + group_filter_query
            + """
            RETURN
                e.content AS content,
                e.created_at AS created_at,
                e.valid_at AS valid_at,
                e.uuid AS uuid,
                e.name AS name,
                e.group_id AS group_id,
                e.source_description AS source_description,
                e.source AS source,
                e.entity_edges AS entity_edges
            ORDER BY i.score DESC
            LIMIT $limit
            """
        )
        records, _, _ = await driver.execute_query(
            cypher,
            ids=input_ids,
            query=fuzzy_query,
            limit=limit,
            routing_='r',
            **filter_params,
        )
    else:
        # NOTE(review): community_fulltext_search swaps YIELD for WITH when the
        # provider is Kuzu; this query always uses YIELD - confirm it is valid
        # for Kuzu.
        cypher = (
            get_nodes_query('episode_content', '$query', limit=limit, provider=driver.provider)
            + """
            YIELD node AS episode, score
            MATCH (e:Episodic)
            WHERE e.uuid = episode.uuid
            """
            + group_filter_query
            + """
            RETURN
            """
            + EPISODIC_NODE_RETURN
            + """
            ORDER BY score DESC
            LIMIT $limit
            """
        )

        records, _, _ = await driver.execute_query(
            cypher, query=fuzzy_query, limit=limit, routing_='r', **filter_params
        )

    episodes = [get_episodic_node_from_record(record) for record in records]

    return episodes
491
939
 
492
940
 
493
941
async def community_fulltext_search(
    driver: GraphDriver,
    query: str,
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[CommunityNode]:
    """
    Run a fulltext search over community names and return the matching communities.

    Parameters
    ----------
    driver : GraphDriver
        Graph driver used to run the search; ``driver.provider`` selects the
        provider-specific query.
    query : str
        The text to search for.
    group_ids : list[str] | None, optional
        When given, only communities whose ``group_id`` is in this list are
        returned.
    limit : int, optional
        Maximum number of communities to return.

    Returns
    -------
    list[CommunityNode]
        Matching communities ordered by descending score; empty when the
        fulltext query is empty or nothing matches.
    """
    # BM25 search to get top communities
    fuzzy_query = fulltext_query(query, group_ids, driver)
    if fuzzy_query == '':
        return []

    filter_params: dict[str, Any] = {}
    if group_ids is not None:
        filter_params['group_ids'] = group_ids

    if driver.provider == GraphProvider.NEPTUNE:
        res = driver.run_aoss_query('community_name', query, limit=limit)  # pyright: ignore reportAttributeAccessIssue
        if res['hits']['total']['value'] <= 0:
            return []

        # Pair every hit's community uuid with its fulltext score.
        input_ids = [
            {'id': hit['_source']['uuid'], 'score': hit['_score']}
            for hit in res['hits']['hits']
        ]

        # The search above is called without group_ids, so the group
        # restriction has to be enforced here, on this query's own alias.
        neptune_group_filter: LiteralString = ''
        if group_ids is not None:
            neptune_group_filter = '\nAND comm.group_id IN $group_ids'

        # Match the community uuids and return the values
        cypher = (
            """
            UNWIND $ids as i
            MATCH (comm:Community)
            WHERE comm.uuid=i.id
            """
            + neptune_group_filter
            + """
            RETURN
                comm.uuid AS uuid,
                comm.group_id AS group_id,
                comm.name AS name,
                comm.created_at AS created_at,
                comm.summary AS summary,
                [x IN split(comm.name_embedding, ",") | toFloat(x)]AS name_embedding
            ORDER BY i.score DESC
            LIMIT $limit
            """
        )
        records, _, _ = await driver.execute_query(
            cypher,
            ids=input_ids,
            query=fuzzy_query,
            limit=limit,
            routing_='r',
            **filter_params,
        )
    else:
        group_filter_query: LiteralString = ''
        if group_ids is not None:
            group_filter_query = 'WHERE c.group_id IN $group_ids'

        # Kuzu takes the procedure output through WITH rather than YIELD.
        yield_query = 'YIELD node AS c, score'
        if driver.provider == GraphProvider.KUZU:
            yield_query = 'WITH node AS c, score'

        cypher = (
            get_nodes_query('community_name', '$query', limit=limit, provider=driver.provider)
            # Explicit separator so the clause never fuses with the call text.
            + '\n'
            + yield_query
            + """
            WITH c, score
            """
            + group_filter_query
            + """
            RETURN
            """
            + COMMUNITY_NODE_RETURN
            + """
            ORDER BY score DESC
            LIMIT $limit
            """
        )

        records, _, _ = await driver.execute_query(
            cypher, query=fuzzy_query, limit=limit, routing_='r', **filter_params
        )

    communities = [get_community_node_from_record(record) for record in records]

    return communities
526
1020
 
527
1021
 
528
1022
  async def community_similarity_search(
529
- driver: AsyncDriver,
1023
+ driver: GraphDriver,
530
1024
  search_vector: list[float],
531
1025
  group_ids: list[str] | None = None,
532
1026
  limit=RELEVANT_SCHEMA_LIMIT,
@@ -537,34 +1031,99 @@ async def community_similarity_search(
537
1031
 
538
1032
  group_filter_query: LiteralString = ''
539
1033
  if group_ids is not None:
540
- group_filter_query += 'WHERE comm.group_id IN $group_ids'
1034
+ group_filter_query += ' WHERE c.group_id IN $group_ids'
541
1035
  query_params['group_ids'] = group_ids
542
1036
 
543
- records, _, _ = await driver.execute_query(
544
- RUNTIME_QUERY
545
- + """
546
- MATCH (comm:Community)
547
- """
548
- + group_filter_query
549
- + """
550
- WITH comm, vector.similarity.cosine(comm.name_embedding, $search_vector) AS score
551
- WHERE score > $min_score
552
- RETURN
553
- comm.uuid As uuid,
554
- comm.group_id AS group_id,
555
- comm.name AS name,
556
- comm.created_at AS created_at,
557
- comm.summary AS summary
558
- ORDER BY score DESC
559
- LIMIT $limit
560
- """,
561
- search_vector=search_vector,
562
- group_ids=group_ids,
563
- limit=limit,
564
- min_score=min_score,
565
- database_=DEFAULT_DATABASE,
566
- routing_='r',
567
- )
1037
+ if driver.provider == GraphProvider.NEPTUNE:
1038
+ query = (
1039
+ """
1040
+ MATCH (n:Community)
1041
+ """
1042
+ + group_filter_query
1043
+ + """
1044
+ RETURN DISTINCT id(n) as id, n.name_embedding as embedding
1045
+ """
1046
+ )
1047
+ resp, header, _ = await driver.execute_query(
1048
+ query,
1049
+ search_vector=search_vector,
1050
+ limit=limit,
1051
+ min_score=min_score,
1052
+ routing_='r',
1053
+ **query_params,
1054
+ )
1055
+
1056
+ if len(resp) > 0:
1057
+ # Calculate Cosine similarity then return the edge ids
1058
+ input_ids = []
1059
+ for r in resp:
1060
+ if r['embedding']:
1061
+ score = calculate_cosine_similarity(
1062
+ search_vector, list(map(float, r['embedding'].split(',')))
1063
+ )
1064
+ if score > min_score:
1065
+ input_ids.append({'id': r['id'], 'score': score})
1066
+
1067
+ # Match the edge ides and return the values
1068
+ query = """
1069
+ UNWIND $ids as i
1070
+ MATCH (comm:Community)
1071
+ WHERE id(comm)=i.id
1072
+ RETURN
1073
+ comm.uuid As uuid,
1074
+ comm.group_id AS group_id,
1075
+ comm.name AS name,
1076
+ comm.created_at AS created_at,
1077
+ comm.summary AS summary,
1078
+ comm.name_embedding AS name_embedding
1079
+ ORDER BY i.score DESC
1080
+ LIMIT $limit
1081
+ """
1082
+ records, header, _ = await driver.execute_query(
1083
+ query,
1084
+ ids=input_ids,
1085
+ search_vector=search_vector,
1086
+ limit=limit,
1087
+ min_score=min_score,
1088
+ routing_='r',
1089
+ **query_params,
1090
+ )
1091
+ else:
1092
+ return []
1093
+ else:
1094
+ search_vector_var = '$search_vector'
1095
+ if driver.provider == GraphProvider.KUZU:
1096
+ search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
1097
+
1098
+ query = (
1099
+ """
1100
+ MATCH (c:Community)
1101
+ """
1102
+ + group_filter_query
1103
+ + """
1104
+ WITH c,
1105
+ """
1106
+ + get_vector_cosine_func_query('c.name_embedding', search_vector_var, driver.provider)
1107
+ + """ AS score
1108
+ WHERE score > $min_score
1109
+ RETURN
1110
+ """
1111
+ + COMMUNITY_NODE_RETURN
1112
+ + """
1113
+ ORDER BY score DESC
1114
+ LIMIT $limit
1115
+ """
1116
+ )
1117
+
1118
+ records, _, _ = await driver.execute_query(
1119
+ query,
1120
+ search_vector=search_vector,
1121
+ limit=limit,
1122
+ min_score=min_score,
1123
+ routing_='r',
1124
+ **query_params,
1125
+ )
1126
+
568
1127
  communities = [get_community_node_from_record(record) for record in records]
569
1128
 
570
1129
  return communities
@@ -573,7 +1132,7 @@ async def community_similarity_search(
573
1132
  async def hybrid_node_search(
574
1133
  queries: list[str],
575
1134
  embeddings: list[list[float]],
576
- driver: AsyncDriver,
1135
+ driver: GraphDriver,
577
1136
  search_filter: SearchFilters,
578
1137
  group_ids: list[str] | None = None,
579
1138
  limit: int = RELEVANT_SCHEMA_LIMIT,
@@ -590,7 +1149,7 @@ async def hybrid_node_search(
590
1149
  A list of text queries to search for.
591
1150
  embeddings : list[list[float]]
592
1151
  A list of embedding vectors corresponding to the queries. If empty only fulltext search is performed.
593
- driver : AsyncDriver
1152
+ driver : GraphDriver
594
1153
  The Neo4j driver instance for database operations.
595
1154
  group_ids : list[str] | None, optional
596
1155
  The list of group ids to retrieve nodes from.
@@ -635,7 +1194,7 @@ async def hybrid_node_search(
635
1194
  }
636
1195
  result_uuids = [[node.uuid for node in result] for result in results]
637
1196
 
638
- ranked_uuids = rrf(result_uuids)
1197
+ ranked_uuids, _ = rrf(result_uuids)
639
1198
 
640
1199
  relevant_nodes: list[EntityNode] = [node_uuid_map[uuid] for uuid in ranked_uuids]
641
1200
 
@@ -645,7 +1204,7 @@ async def hybrid_node_search(
645
1204
 
646
1205
 
647
1206
  async def get_relevant_nodes(
648
- driver: AsyncDriver,
1207
+ driver: GraphDriver,
649
1208
  nodes: list[EntityNode],
650
1209
  search_filter: SearchFilters,
651
1210
  min_score: float = DEFAULT_MIN_SCORE,
@@ -655,77 +1214,140 @@ async def get_relevant_nodes(
655
1214
  return []
656
1215
 
657
1216
  group_id = nodes[0].group_id
658
-
659
- # vector similarity search over entity names
660
- query_params: dict[str, Any] = {}
661
-
662
- filter_query, filter_params = node_search_filter_query_constructor(search_filter)
663
- query_params.update(filter_params)
664
-
665
- query = (
666
- RUNTIME_QUERY
667
- + """UNWIND $nodes AS node
668
- MATCH (n:Entity {group_id: $group_id})
669
- """
670
- + filter_query
671
- + """
672
- WITH node, n, vector.similarity.cosine(n.name_embedding, node.name_embedding) AS score
673
- WHERE score > $min_score
674
- WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
675
-
676
- CALL db.index.fulltext.queryNodes("node_name_and_summary", node.fulltext_query, {limit: $limit})
677
- YIELD node AS m
678
- WHERE m.group_id = $group_id
679
- WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
680
-
681
- WITH node,
682
- top_vector_nodes,
683
- [m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
684
-
685
- WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
686
-
687
- UNWIND combined_nodes AS combined_node
688
- WITH node, collect(DISTINCT combined_node) AS deduped_nodes
689
-
690
- RETURN
691
- node.uuid AS search_node_uuid,
692
- [x IN deduped_nodes | {
693
- uuid: x.uuid,
694
- name: x.name,
695
- name_embedding: x.name_embedding,
696
- group_id: x.group_id,
697
- created_at: x.created_at,
698
- summary: x.summary,
699
- labels: labels(x),
700
- attributes: properties(x)
701
- }] AS matches
702
- """
703
- )
704
-
705
1217
  query_nodes = [
706
1218
  {
707
1219
  'uuid': node.uuid,
708
1220
  'name': node.name,
709
1221
  'name_embedding': node.name_embedding,
710
- 'fulltext_query': fulltext_query(node.name, [node.group_id]),
1222
+ 'fulltext_query': fulltext_query(node.name, [node.group_id], driver),
711
1223
  }
712
1224
  for node in nodes
713
1225
  ]
714
1226
 
1227
+ filter_queries, filter_params = node_search_filter_query_constructor(
1228
+ search_filter, driver.provider
1229
+ )
1230
+
1231
+ filter_query = ''
1232
+ if filter_queries:
1233
+ filter_query = 'WHERE ' + (' AND '.join(filter_queries))
1234
+
1235
+ if driver.provider == GraphProvider.KUZU:
1236
+ embedding_size = len(nodes[0].name_embedding) if nodes[0].name_embedding is not None else 0
1237
+ if embedding_size == 0:
1238
+ return []
1239
+
1240
+ # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
1241
+ query = (
1242
+ """
1243
+ UNWIND $nodes AS node
1244
+ MATCH (n:Entity {group_id: $group_id})
1245
+ """
1246
+ + filter_query
1247
+ + """
1248
+ WITH node, n, """
1249
+ + get_vector_cosine_func_query(
1250
+ 'n.name_embedding',
1251
+ f'CAST(node.name_embedding AS FLOAT[{embedding_size}])',
1252
+ driver.provider,
1253
+ )
1254
+ + """ AS score
1255
+ WHERE score > $min_score
1256
+ WITH node, collect(n)[:$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
1257
+ """
1258
+ + get_nodes_query(
1259
+ 'node_name_and_summary',
1260
+ 'node.fulltext_query',
1261
+ limit=limit,
1262
+ provider=driver.provider,
1263
+ )
1264
+ + """
1265
+ WITH node AS m
1266
+ WHERE m.group_id = $group_id AND NOT m.uuid IN vector_node_uuids
1267
+ WITH node, top_vector_nodes, collect(m) AS fulltext_nodes
1268
+
1269
+ WITH node, list_concat(top_vector_nodes, fulltext_nodes) AS combined_nodes
1270
+
1271
+ UNWIND combined_nodes AS x
1272
+ WITH node, collect(DISTINCT {
1273
+ uuid: x.uuid,
1274
+ name: x.name,
1275
+ name_embedding: x.name_embedding,
1276
+ group_id: x.group_id,
1277
+ created_at: x.created_at,
1278
+ summary: x.summary,
1279
+ labels: x.labels,
1280
+ attributes: x.attributes
1281
+ }) AS matches
1282
+
1283
+ RETURN
1284
+ node.uuid AS search_node_uuid, matches
1285
+ """
1286
+ )
1287
+ else:
1288
+ query = (
1289
+ """
1290
+ UNWIND $nodes AS node
1291
+ MATCH (n:Entity {group_id: $group_id})
1292
+ """
1293
+ + filter_query
1294
+ + """
1295
+ WITH node, n, """
1296
+ + get_vector_cosine_func_query(
1297
+ 'n.name_embedding', 'node.name_embedding', driver.provider
1298
+ )
1299
+ + """ AS score
1300
+ WHERE score > $min_score
1301
+ WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
1302
+ """
1303
+ + get_nodes_query(
1304
+ 'node_name_and_summary',
1305
+ 'node.fulltext_query',
1306
+ limit=limit,
1307
+ provider=driver.provider,
1308
+ )
1309
+ + """
1310
+ YIELD node AS m
1311
+ WHERE m.group_id = $group_id
1312
+ WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
1313
+
1314
+ WITH node,
1315
+ top_vector_nodes,
1316
+ [m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
1317
+
1318
+ WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
1319
+
1320
+ UNWIND combined_nodes AS combined_node
1321
+ WITH node, collect(DISTINCT combined_node) AS deduped_nodes
1322
+
1323
+ RETURN
1324
+ node.uuid AS search_node_uuid,
1325
+ [x IN deduped_nodes | {
1326
+ uuid: x.uuid,
1327
+ name: x.name,
1328
+ name_embedding: x.name_embedding,
1329
+ group_id: x.group_id,
1330
+ created_at: x.created_at,
1331
+ summary: x.summary,
1332
+ labels: labels(x),
1333
+ attributes: properties(x)
1334
+ }] AS matches
1335
+ """
1336
+ )
1337
+
715
1338
  results, _, _ = await driver.execute_query(
716
1339
  query,
717
- query_params,
718
1340
  nodes=query_nodes,
719
1341
  group_id=group_id,
720
1342
  limit=limit,
721
1343
  min_score=min_score,
722
- database_=DEFAULT_DATABASE,
723
1344
  routing_='r',
1345
+ **filter_params,
724
1346
  )
725
1347
 
726
1348
  relevant_nodes_dict: dict[str, list[EntityNode]] = {
727
1349
  result['search_node_uuid']: [
728
- get_entity_node_from_record(record) for record in result['matches']
1350
+ get_entity_node_from_record(record, driver.provider) for record in result['matches']
729
1351
  ]
730
1352
  for result in results
731
1353
  }
@@ -736,7 +1358,7 @@ async def get_relevant_nodes(
736
1358
 
737
1359
 
738
1360
  async def get_relevant_edges(
739
- driver: AsyncDriver,
1361
+ driver: GraphDriver,
740
1362
  edges: list[EntityEdge],
741
1363
  search_filter: SearchFilters,
742
1364
  min_score: float = DEFAULT_MIN_SCORE,
@@ -745,53 +1367,172 @@ async def get_relevant_edges(
745
1367
  if len(edges) == 0:
746
1368
  return []
747
1369
 
748
- query_params: dict[str, Any] = {}
1370
+ filter_queries, filter_params = edge_search_filter_query_constructor(
1371
+ search_filter, driver.provider
1372
+ )
749
1373
 
750
- filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
751
- query_params.update(filter_params)
1374
+ filter_query = ''
1375
+ if filter_queries:
1376
+ filter_query = ' WHERE ' + (' AND '.join(filter_queries))
752
1377
 
753
- query = (
754
- RUNTIME_QUERY
755
- + """UNWIND $edges AS edge
756
- MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
1378
+ if driver.provider == GraphProvider.NEPTUNE:
1379
+ query = (
757
1380
  """
758
- + filter_query
759
- + """
760
- WITH e, edge, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
761
- WHERE score > $min_score
762
- WITH edge, e, score
763
- ORDER BY score DESC
764
- RETURN edge.uuid AS search_edge_uuid,
765
- collect({
766
- uuid: e.uuid,
767
- source_node_uuid: startNode(e).uuid,
768
- target_node_uuid: endNode(e).uuid,
769
- created_at: e.created_at,
770
- name: e.name,
771
- group_id: e.group_id,
772
- fact: e.fact,
773
- fact_embedding: e.fact_embedding,
774
- episodes: e.episodes,
775
- expired_at: e.expired_at,
776
- valid_at: e.valid_at,
777
- invalid_at: e.invalid_at,
778
- attributes: properties(e)
779
- })[..$limit] AS matches
780
- """
781
- )
1381
+ UNWIND $edges AS edge
1382
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
1383
+ """
1384
+ + filter_query
1385
+ + """
1386
+ WITH e, edge
1387
+ RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding, edge.uuid as search_edge_uuid,
1388
+ edge.fact_embedding as target_embedding
1389
+ """
1390
+ )
1391
+ resp, _, _ = await driver.execute_query(
1392
+ query,
1393
+ edges=[edge.model_dump() for edge in edges],
1394
+ limit=limit,
1395
+ min_score=min_score,
1396
+ routing_='r',
1397
+ **filter_params,
1398
+ )
1399
+
1400
+ # Calculate Cosine similarity then return the edge ids
1401
+ input_ids = []
1402
+ for r in resp:
1403
+ score = calculate_cosine_similarity(
1404
+ list(map(float, r['source_embedding'].split(','))), r['target_embedding']
1405
+ )
1406
+ if score > min_score:
1407
+ input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})
1408
+
1409
+ # Match the edge ides and return the values
1410
+ query = """
1411
+ UNWIND $ids AS edge
1412
+ MATCH ()-[e]->()
1413
+ WHERE id(e) = edge.id
1414
+ WITH edge, e
1415
+ ORDER BY edge.score DESC
1416
+ RETURN edge.uuid AS search_edge_uuid,
1417
+ collect({
1418
+ uuid: e.uuid,
1419
+ source_node_uuid: startNode(e).uuid,
1420
+ target_node_uuid: endNode(e).uuid,
1421
+ created_at: e.created_at,
1422
+ name: e.name,
1423
+ group_id: e.group_id,
1424
+ fact: e.fact,
1425
+ fact_embedding: [x IN split(e.fact_embedding, ",") | toFloat(x)],
1426
+ episodes: split(e.episodes, ","),
1427
+ expired_at: e.expired_at,
1428
+ valid_at: e.valid_at,
1429
+ invalid_at: e.invalid_at,
1430
+ attributes: properties(e)
1431
+ })[..$limit] AS matches
1432
+ """
1433
+
1434
+ results, _, _ = await driver.execute_query(
1435
+ query,
1436
+ ids=input_ids,
1437
+ edges=[edge.model_dump() for edge in edges],
1438
+ limit=limit,
1439
+ min_score=min_score,
1440
+ routing_='r',
1441
+ **filter_params,
1442
+ )
1443
+ else:
1444
+ if driver.provider == GraphProvider.KUZU:
1445
+ embedding_size = (
1446
+ len(edges[0].fact_embedding) if edges[0].fact_embedding is not None else 0
1447
+ )
1448
+ if embedding_size == 0:
1449
+ return []
1450
+
1451
+ query = (
1452
+ """
1453
+ UNWIND $edges AS edge
1454
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
1455
+ """
1456
+ + filter_query
1457
+ + """
1458
+ WITH e, edge, n, m, """
1459
+ + get_vector_cosine_func_query(
1460
+ 'e.fact_embedding',
1461
+ f'CAST(edge.fact_embedding AS FLOAT[{embedding_size}])',
1462
+ driver.provider,
1463
+ )
1464
+ + """ AS score
1465
+ WHERE score > $min_score
1466
+ WITH e, edge, n, m, score
1467
+ ORDER BY score DESC
1468
+ LIMIT $limit
1469
+ RETURN
1470
+ edge.uuid AS search_edge_uuid,
1471
+ collect({
1472
+ uuid: e.uuid,
1473
+ source_node_uuid: n.uuid,
1474
+ target_node_uuid: m.uuid,
1475
+ created_at: e.created_at,
1476
+ name: e.name,
1477
+ group_id: e.group_id,
1478
+ fact: e.fact,
1479
+ fact_embedding: e.fact_embedding,
1480
+ episodes: e.episodes,
1481
+ expired_at: e.expired_at,
1482
+ valid_at: e.valid_at,
1483
+ invalid_at: e.invalid_at,
1484
+ attributes: e.attributes
1485
+ }) AS matches
1486
+ """
1487
+ )
1488
+ else:
1489
+ query = (
1490
+ """
1491
+ UNWIND $edges AS edge
1492
+ MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
1493
+ """
1494
+ + filter_query
1495
+ + """
1496
+ WITH e, edge, """
1497
+ + get_vector_cosine_func_query(
1498
+ 'e.fact_embedding', 'edge.fact_embedding', driver.provider
1499
+ )
1500
+ + """ AS score
1501
+ WHERE score > $min_score
1502
+ WITH edge, e, score
1503
+ ORDER BY score DESC
1504
+ RETURN
1505
+ edge.uuid AS search_edge_uuid,
1506
+ collect({
1507
+ uuid: e.uuid,
1508
+ source_node_uuid: startNode(e).uuid,
1509
+ target_node_uuid: endNode(e).uuid,
1510
+ created_at: e.created_at,
1511
+ name: e.name,
1512
+ group_id: e.group_id,
1513
+ fact: e.fact,
1514
+ fact_embedding: e.fact_embedding,
1515
+ episodes: e.episodes,
1516
+ expired_at: e.expired_at,
1517
+ valid_at: e.valid_at,
1518
+ invalid_at: e.invalid_at,
1519
+ attributes: properties(e)
1520
+ })[..$limit] AS matches
1521
+ """
1522
+ )
1523
+
1524
+ results, _, _ = await driver.execute_query(
1525
+ query,
1526
+ edges=[edge.model_dump() for edge in edges],
1527
+ limit=limit,
1528
+ min_score=min_score,
1529
+ routing_='r',
1530
+ **filter_params,
1531
+ )
782
1532
 
783
- results, _, _ = await driver.execute_query(
784
- query,
785
- query_params,
786
- edges=[edge.model_dump() for edge in edges],
787
- limit=limit,
788
- min_score=min_score,
789
- database_=DEFAULT_DATABASE,
790
- routing_='r',
791
- )
792
1533
  relevant_edges_dict: dict[str, list[EntityEdge]] = {
793
1534
  result['search_edge_uuid']: [
794
- get_entity_edge_from_record(record) for record in result['matches']
1535
+ get_entity_edge_from_record(record, driver.provider) for record in result['matches']
795
1536
  ]
796
1537
  for result in results
797
1538
  }
@@ -802,7 +1543,7 @@ async def get_relevant_edges(
802
1543
 
803
1544
 
804
1545
  async def get_edge_invalidation_candidates(
805
- driver: AsyncDriver,
1546
+ driver: GraphDriver,
806
1547
  edges: list[EntityEdge],
807
1548
  search_filter: SearchFilters,
808
1549
  min_score: float = DEFAULT_MIN_SCORE,
@@ -811,54 +1552,174 @@ async def get_edge_invalidation_candidates(
811
1552
  if len(edges) == 0:
812
1553
  return []
813
1554
 
814
- query_params: dict[str, Any] = {}
1555
+ filter_queries, filter_params = edge_search_filter_query_constructor(
1556
+ search_filter, driver.provider
1557
+ )
815
1558
 
816
- filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
817
- query_params.update(filter_params)
1559
+ filter_query = ''
1560
+ if filter_queries:
1561
+ filter_query = ' AND ' + (' AND '.join(filter_queries))
818
1562
 
819
- query = (
820
- RUNTIME_QUERY
821
- + """UNWIND $edges AS edge
822
- MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
823
- WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
1563
+ if driver.provider == GraphProvider.NEPTUNE:
1564
+ query = (
824
1565
  """
825
- + filter_query
826
- + """
827
- WITH edge, e, vector.similarity.cosine(e.fact_embedding, edge.fact_embedding) AS score
828
- WHERE score > $min_score
829
- WITH edge, e, score
830
- ORDER BY score DESC
831
- RETURN edge.uuid AS search_edge_uuid,
832
- collect({
833
- uuid: e.uuid,
834
- source_node_uuid: startNode(e).uuid,
835
- target_node_uuid: endNode(e).uuid,
836
- created_at: e.created_at,
837
- name: e.name,
838
- group_id: e.group_id,
839
- fact: e.fact,
840
- fact_embedding: e.fact_embedding,
841
- episodes: e.episodes,
842
- expired_at: e.expired_at,
843
- valid_at: e.valid_at,
844
- invalid_at: e.invalid_at,
845
- attributes: properties(e)
846
- })[..$limit] AS matches
847
- """
848
- )
1566
+ UNWIND $edges AS edge
1567
+ MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
1568
+ WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
1569
+ """
1570
+ + filter_query
1571
+ + """
1572
+ WITH e, edge
1573
+ RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding,
1574
+ edge.fact_embedding as target_embedding,
1575
+ edge.uuid as search_edge_uuid
1576
+ """
1577
+ )
1578
+ resp, _, _ = await driver.execute_query(
1579
+ query,
1580
+ edges=[edge.model_dump() for edge in edges],
1581
+ limit=limit,
1582
+ min_score=min_score,
1583
+ routing_='r',
1584
+ **filter_params,
1585
+ )
849
1586
 
850
- results, _, _ = await driver.execute_query(
851
- query,
852
- query_params,
853
- edges=[edge.model_dump() for edge in edges],
854
- limit=limit,
855
- min_score=min_score,
856
- database_=DEFAULT_DATABASE,
857
- routing_='r',
858
- )
1587
+ # Calculate Cosine similarity then return the edge ids
1588
+ input_ids = []
1589
+ for r in resp:
1590
+ score = calculate_cosine_similarity(
1591
+ list(map(float, r['source_embedding'].split(','))), r['target_embedding']
1592
+ )
1593
+ if score > min_score:
1594
+ input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})
1595
+
1596
+ # Match the edge ides and return the values
1597
+ query = """
1598
+ UNWIND $ids AS edge
1599
+ MATCH ()-[e]->()
1600
+ WHERE id(e) = edge.id
1601
+ WITH edge, e
1602
+ ORDER BY edge.score DESC
1603
+ RETURN edge.uuid AS search_edge_uuid,
1604
+ collect({
1605
+ uuid: e.uuid,
1606
+ source_node_uuid: startNode(e).uuid,
1607
+ target_node_uuid: endNode(e).uuid,
1608
+ created_at: e.created_at,
1609
+ name: e.name,
1610
+ group_id: e.group_id,
1611
+ fact: e.fact,
1612
+ fact_embedding: [x IN split(e.fact_embedding, ",") | toFloat(x)],
1613
+ episodes: split(e.episodes, ","),
1614
+ expired_at: e.expired_at,
1615
+ valid_at: e.valid_at,
1616
+ invalid_at: e.invalid_at,
1617
+ attributes: properties(e)
1618
+ })[..$limit] AS matches
1619
+ """
1620
+ results, _, _ = await driver.execute_query(
1621
+ query,
1622
+ ids=input_ids,
1623
+ edges=[edge.model_dump() for edge in edges],
1624
+ limit=limit,
1625
+ min_score=min_score,
1626
+ routing_='r',
1627
+ **filter_params,
1628
+ )
1629
+ else:
1630
+ if driver.provider == GraphProvider.KUZU:
1631
+ embedding_size = (
1632
+ len(edges[0].fact_embedding) if edges[0].fact_embedding is not None else 0
1633
+ )
1634
+ if embedding_size == 0:
1635
+ return []
1636
+
1637
+ query = (
1638
+ """
1639
+ UNWIND $edges AS edge
1640
+ MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
1641
+ WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
1642
+ """
1643
+ + filter_query
1644
+ + """
1645
+ WITH edge, e, n, m, """
1646
+ + get_vector_cosine_func_query(
1647
+ 'e.fact_embedding',
1648
+ f'CAST(edge.fact_embedding AS FLOAT[{embedding_size}])',
1649
+ driver.provider,
1650
+ )
1651
+ + """ AS score
1652
+ WHERE score > $min_score
1653
+ WITH edge, e, n, m, score
1654
+ ORDER BY score DESC
1655
+ LIMIT $limit
1656
+ RETURN
1657
+ edge.uuid AS search_edge_uuid,
1658
+ collect({
1659
+ uuid: e.uuid,
1660
+ source_node_uuid: n.uuid,
1661
+ target_node_uuid: m.uuid,
1662
+ created_at: e.created_at,
1663
+ name: e.name,
1664
+ group_id: e.group_id,
1665
+ fact: e.fact,
1666
+ fact_embedding: e.fact_embedding,
1667
+ episodes: e.episodes,
1668
+ expired_at: e.expired_at,
1669
+ valid_at: e.valid_at,
1670
+ invalid_at: e.invalid_at,
1671
+ attributes: e.attributes
1672
+ }) AS matches
1673
+ """
1674
+ )
1675
+ else:
1676
+ query = (
1677
+ """
1678
+ UNWIND $edges AS edge
1679
+ MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
1680
+ WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
1681
+ """
1682
+ + filter_query
1683
+ + """
1684
+ WITH edge, e, """
1685
+ + get_vector_cosine_func_query(
1686
+ 'e.fact_embedding', 'edge.fact_embedding', driver.provider
1687
+ )
1688
+ + """ AS score
1689
+ WHERE score > $min_score
1690
+ WITH edge, e, score
1691
+ ORDER BY score DESC
1692
+ RETURN
1693
+ edge.uuid AS search_edge_uuid,
1694
+ collect({
1695
+ uuid: e.uuid,
1696
+ source_node_uuid: startNode(e).uuid,
1697
+ target_node_uuid: endNode(e).uuid,
1698
+ created_at: e.created_at,
1699
+ name: e.name,
1700
+ group_id: e.group_id,
1701
+ fact: e.fact,
1702
+ fact_embedding: e.fact_embedding,
1703
+ episodes: e.episodes,
1704
+ expired_at: e.expired_at,
1705
+ valid_at: e.valid_at,
1706
+ invalid_at: e.invalid_at,
1707
+ attributes: properties(e)
1708
+ })[..$limit] AS matches
1709
+ """
1710
+ )
1711
+
1712
+ results, _, _ = await driver.execute_query(
1713
+ query,
1714
+ edges=[edge.model_dump() for edge in edges],
1715
+ limit=limit,
1716
+ min_score=min_score,
1717
+ routing_='r',
1718
+ **filter_params,
1719
+ )
859
1720
  invalidation_edges_dict: dict[str, list[EntityEdge]] = {
860
1721
  result['search_edge_uuid']: [
861
- get_entity_edge_from_record(record) for record in result['matches']
1722
+ get_entity_edge_from_record(record, driver.provider) for record in result['matches']
862
1723
  ]
863
1724
  for result in results
864
1725
  }
@@ -869,7 +1730,9 @@ async def get_edge_invalidation_candidates(
869
1730
 
870
1731
 
871
1732
  # takes in a list of rankings of uuids
872
- def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[str]:
1733
+ def rrf(
1734
+ results: list[list[str]], rank_const=1, min_score: float = 0
1735
+ ) -> tuple[list[str], list[float]]:
873
1736
  scores: dict[str, float] = defaultdict(float)
874
1737
  for result in results:
875
1738
  for i, uuid in enumerate(result):
@@ -880,35 +1743,44 @@ def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[st
880
1743
 
881
1744
  sorted_uuids = [term[0] for term in scored_uuids]
882
1745
 
883
- return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
1746
+ return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [
1747
+ scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score
1748
+ ]
884
1749
 
885
1750
 
886
1751
  async def node_distance_reranker(
887
- driver: AsyncDriver,
1752
+ driver: GraphDriver,
888
1753
  node_uuids: list[str],
889
1754
  center_node_uuid: str,
890
1755
  min_score: float = 0,
891
- ) -> list[str]:
1756
+ ) -> tuple[list[str], list[float]]:
892
1757
  # filter out node_uuid center node node uuid
893
1758
  filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
894
1759
  scores: dict[str, float] = {center_node_uuid: 0.0}
895
1760
 
896
- # Find the shortest path to center node
897
- query = Query("""
1761
+ query = """
1762
+ UNWIND $node_uuids AS node_uuid
1763
+ MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
1764
+ RETURN 1 AS score, node_uuid AS uuid
1765
+ """
1766
+ if driver.provider == GraphProvider.KUZU:
1767
+ query = """
898
1768
  UNWIND $node_uuids AS node_uuid
899
- MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: node_uuid})
900
- RETURN length(p) AS score, node_uuid AS uuid
901
- """)
1769
+ MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(n:Entity {uuid: node_uuid})
1770
+ RETURN 1 AS score, node_uuid AS uuid
1771
+ """
902
1772
 
903
- path_results, _, _ = await driver.execute_query(
1773
+ # Find the shortest path to center node
1774
+ results, header, _ = await driver.execute_query(
904
1775
  query,
905
1776
  node_uuids=filtered_uuids,
906
1777
  center_uuid=center_node_uuid,
907
- database_=DEFAULT_DATABASE,
908
1778
  routing_='r',
909
1779
  )
1780
+ if driver.provider == GraphProvider.FALKORDB:
1781
+ results = [dict(zip(header, row, strict=True)) for row in results]
910
1782
 
911
- for result in path_results:
1783
+ for result in results:
912
1784
  uuid = result['uuid']
913
1785
  score = result['score']
914
1786
  scores[uuid] = score
@@ -925,37 +1797,42 @@ async def node_distance_reranker(
925
1797
  scores[center_node_uuid] = 0.1
926
1798
  filtered_uuids = [center_node_uuid] + filtered_uuids
927
1799
 
928
- return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score]
1800
+ return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score], [
1801
+ 1 / scores[uuid] for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score
1802
+ ]
929
1803
 
930
1804
 
931
1805
  async def episode_mentions_reranker(
932
- driver: AsyncDriver, node_uuids: list[list[str]], min_score: float = 0
933
- ) -> list[str]:
1806
+ driver: GraphDriver, node_uuids: list[list[str]], min_score: float = 0
1807
+ ) -> tuple[list[str], list[float]]:
934
1808
  # use rrf as a preliminary ranker
935
- sorted_uuids = rrf(node_uuids)
1809
+ sorted_uuids, _ = rrf(node_uuids)
936
1810
  scores: dict[str, float] = {}
937
1811
 
938
1812
  # Find the shortest path to center node
939
- query = Query("""
940
- UNWIND $node_uuids AS node_uuid
1813
+ results, _, _ = await driver.execute_query(
1814
+ """
1815
+ UNWIND $node_uuids AS node_uuid
941
1816
  MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
942
1817
  RETURN count(*) AS score, n.uuid AS uuid
943
- """)
944
-
945
- results, _, _ = await driver.execute_query(
946
- query,
1818
+ """,
947
1819
  node_uuids=sorted_uuids,
948
- database_=DEFAULT_DATABASE,
949
1820
  routing_='r',
950
1821
  )
951
1822
 
952
1823
  for result in results:
953
1824
  scores[result['uuid']] = result['score']
954
1825
 
1826
+ for uuid in sorted_uuids:
1827
+ if uuid not in scores:
1828
+ scores[uuid] = float('inf')
1829
+
955
1830
  # rerank on shortest distance
956
1831
  sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
957
1832
 
958
- return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
1833
+ return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [
1834
+ scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score
1835
+ ]
959
1836
 
960
1837
 
961
1838
  def maximal_marginal_relevance(
@@ -963,7 +1840,7 @@ def maximal_marginal_relevance(
963
1840
  candidates: dict[str, list[float]],
964
1841
  mmr_lambda: float = DEFAULT_MMR_LAMBDA,
965
1842
  min_score: float = -2.0,
966
- ) -> list[str]:
1843
+ ) -> tuple[list[str], list[float]]:
967
1844
  start = time()
968
1845
  query_array = np.array(query_vector)
969
1846
  candidate_arrays: dict[str, NDArray] = {}
@@ -994,21 +1871,36 @@ def maximal_marginal_relevance(
994
1871
  end = time()
995
1872
  logger.debug(f'Completed MMR reranking in {(end - start) * 1000} ms')
996
1873
 
997
- return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score]
1874
+ return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score], [
1875
+ mmr_scores[uuid] for uuid in uuids if mmr_scores[uuid] >= min_score
1876
+ ]
998
1877
 
999
1878
 
1000
1879
  async def get_embeddings_for_nodes(
1001
- driver: AsyncDriver, nodes: list[EntityNode]
1880
+ driver: GraphDriver, nodes: list[EntityNode]
1002
1881
  ) -> dict[str, list[float]]:
1003
- query: LiteralString = """MATCH (n:Entity)
1004
- WHERE n.uuid IN $node_uuids
1005
- RETURN DISTINCT
1006
- n.uuid AS uuid,
1007
- n.name_embedding AS name_embedding
1008
- """
1009
-
1882
+ if driver.graph_operations_interface:
1883
+ return await driver.graph_operations_interface.node_load_embeddings_bulk(driver, nodes)
1884
+ elif driver.provider == GraphProvider.NEPTUNE:
1885
+ query = """
1886
+ MATCH (n:Entity)
1887
+ WHERE n.uuid IN $node_uuids
1888
+ RETURN DISTINCT
1889
+ n.uuid AS uuid,
1890
+ split(n.name_embedding, ",") AS name_embedding
1891
+ """
1892
+ else:
1893
+ query = """
1894
+ MATCH (n:Entity)
1895
+ WHERE n.uuid IN $node_uuids
1896
+ RETURN DISTINCT
1897
+ n.uuid AS uuid,
1898
+ n.name_embedding AS name_embedding
1899
+ """
1010
1900
  results, _, _ = await driver.execute_query(
1011
- query, node_uuids=[node.uuid for node in nodes], database_=DEFAULT_DATABASE, routing_='r'
1901
+ query,
1902
+ node_uuids=[node.uuid for node in nodes],
1903
+ routing_='r',
1012
1904
  )
1013
1905
 
1014
1906
  embeddings_dict: dict[str, list[float]] = {}
@@ -1022,19 +1914,27 @@ async def get_embeddings_for_nodes(
1022
1914
 
1023
1915
 
1024
1916
  async def get_embeddings_for_communities(
1025
- driver: AsyncDriver, communities: list[CommunityNode]
1917
+ driver: GraphDriver, communities: list[CommunityNode]
1026
1918
  ) -> dict[str, list[float]]:
1027
- query: LiteralString = """MATCH (c:Community)
1028
- WHERE c.uuid IN $community_uuids
1029
- RETURN DISTINCT
1030
- c.uuid AS uuid,
1031
- c.name_embedding AS name_embedding
1032
- """
1033
-
1919
+ if driver.provider == GraphProvider.NEPTUNE:
1920
+ query = """
1921
+ MATCH (c:Community)
1922
+ WHERE c.uuid IN $community_uuids
1923
+ RETURN DISTINCT
1924
+ c.uuid AS uuid,
1925
+ split(c.name_embedding, ",") AS name_embedding
1926
+ """
1927
+ else:
1928
+ query = """
1929
+ MATCH (c:Community)
1930
+ WHERE c.uuid IN $community_uuids
1931
+ RETURN DISTINCT
1932
+ c.uuid AS uuid,
1933
+ c.name_embedding AS name_embedding
1934
+ """
1034
1935
  results, _, _ = await driver.execute_query(
1035
1936
  query,
1036
1937
  community_uuids=[community.uuid for community in communities],
1037
- database_=DEFAULT_DATABASE,
1038
1938
  routing_='r',
1039
1939
  )
1040
1940
 
@@ -1049,19 +1949,39 @@ async def get_embeddings_for_communities(
1049
1949
 
1050
1950
 
1051
1951
  async def get_embeddings_for_edges(
1052
- driver: AsyncDriver, edges: list[EntityEdge]
1952
+ driver: GraphDriver, edges: list[EntityEdge]
1053
1953
  ) -> dict[str, list[float]]:
1054
- query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
1055
- WHERE e.uuid IN $edge_uuids
1056
- RETURN DISTINCT
1057
- e.uuid AS uuid,
1058
- e.fact_embedding AS fact_embedding
1059
- """
1954
+ if driver.graph_operations_interface:
1955
+ return await driver.graph_operations_interface.edge_load_embeddings_bulk(driver, edges)
1956
+ elif driver.provider == GraphProvider.NEPTUNE:
1957
+ query = """
1958
+ MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
1959
+ WHERE e.uuid IN $edge_uuids
1960
+ RETURN DISTINCT
1961
+ e.uuid AS uuid,
1962
+ split(e.fact_embedding, ",") AS fact_embedding
1963
+ """
1964
+ else:
1965
+ match_query = """
1966
+ MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
1967
+ """
1968
+ if driver.provider == GraphProvider.KUZU:
1969
+ match_query = """
1970
+ MATCH (n:Entity)-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(m:Entity)
1971
+ """
1060
1972
 
1973
+ query = (
1974
+ match_query
1975
+ + """
1976
+ WHERE e.uuid IN $edge_uuids
1977
+ RETURN DISTINCT
1978
+ e.uuid AS uuid,
1979
+ e.fact_embedding AS fact_embedding
1980
+ """
1981
+ )
1061
1982
  results, _, _ = await driver.execute_query(
1062
1983
  query,
1063
1984
  edge_uuids=[edge.uuid for edge in edges],
1064
- database_=DEFAULT_DATABASE,
1065
1985
  routing_='r',
1066
1986
  )
1067
1987