graphiti-core 0.17.4__py3-none-any.whl → 0.24.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphiti_core/cross_encoder/gemini_reranker_client.py +1 -1
- graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
- graphiti_core/decorators.py +110 -0
- graphiti_core/driver/driver.py +62 -2
- graphiti_core/driver/falkordb_driver.py +215 -23
- graphiti_core/driver/graph_operations/graph_operations.py +191 -0
- graphiti_core/driver/kuzu_driver.py +182 -0
- graphiti_core/driver/neo4j_driver.py +61 -8
- graphiti_core/driver/neptune_driver.py +305 -0
- graphiti_core/driver/search_interface/search_interface.py +89 -0
- graphiti_core/edges.py +264 -132
- graphiti_core/embedder/azure_openai.py +10 -3
- graphiti_core/embedder/client.py +2 -1
- graphiti_core/graph_queries.py +114 -101
- graphiti_core/graphiti.py +582 -255
- graphiti_core/graphiti_types.py +2 -0
- graphiti_core/helpers.py +21 -14
- graphiti_core/llm_client/anthropic_client.py +142 -52
- graphiti_core/llm_client/azure_openai_client.py +57 -19
- graphiti_core/llm_client/client.py +83 -21
- graphiti_core/llm_client/config.py +1 -1
- graphiti_core/llm_client/gemini_client.py +75 -57
- graphiti_core/llm_client/openai_base_client.py +94 -50
- graphiti_core/llm_client/openai_client.py +28 -8
- graphiti_core/llm_client/openai_generic_client.py +91 -56
- graphiti_core/models/edges/edge_db_queries.py +259 -35
- graphiti_core/models/nodes/node_db_queries.py +311 -32
- graphiti_core/nodes.py +388 -164
- graphiti_core/prompts/dedupe_edges.py +42 -31
- graphiti_core/prompts/dedupe_nodes.py +56 -39
- graphiti_core/prompts/eval.py +4 -4
- graphiti_core/prompts/extract_edges.py +23 -14
- graphiti_core/prompts/extract_nodes.py +73 -32
- graphiti_core/prompts/prompt_helpers.py +39 -0
- graphiti_core/prompts/snippets.py +29 -0
- graphiti_core/prompts/summarize_nodes.py +23 -25
- graphiti_core/search/search.py +154 -74
- graphiti_core/search/search_config.py +39 -4
- graphiti_core/search/search_filters.py +109 -31
- graphiti_core/search/search_helpers.py +5 -6
- graphiti_core/search/search_utils.py +1360 -473
- graphiti_core/tracer.py +193 -0
- graphiti_core/utils/bulk_utils.py +216 -90
- graphiti_core/utils/datetime_utils.py +13 -0
- graphiti_core/utils/maintenance/community_operations.py +62 -38
- graphiti_core/utils/maintenance/dedup_helpers.py +262 -0
- graphiti_core/utils/maintenance/edge_operations.py +286 -126
- graphiti_core/utils/maintenance/graph_data_operations.py +44 -74
- graphiti_core/utils/maintenance/node_operations.py +320 -158
- graphiti_core/utils/maintenance/temporal_operations.py +11 -3
- graphiti_core/utils/ontology_utils/entity_types_utils.py +1 -1
- graphiti_core/utils/text_utils.py +53 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/METADATA +221 -87
- graphiti_core-0.24.3.dist-info/RECORD +86 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/WHEEL +1 -1
- graphiti_core-0.17.4.dist-info/RECORD +0 -77
- /graphiti_core/{utils/maintenance/utils.py → migrations/__init__.py} +0 -0
- {graphiti_core-0.17.4.dist-info → graphiti_core-0.24.3.dist-info}/licenses/LICENSE +0 -0
|
@@ -23,7 +23,10 @@ import numpy as np
|
|
|
23
23
|
from numpy._typing import NDArray
|
|
24
24
|
from typing_extensions import LiteralString
|
|
25
25
|
|
|
26
|
-
from graphiti_core.driver.driver import
|
|
26
|
+
from graphiti_core.driver.driver import (
|
|
27
|
+
GraphDriver,
|
|
28
|
+
GraphProvider,
|
|
29
|
+
)
|
|
27
30
|
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
|
28
31
|
from graphiti_core.graph_queries import (
|
|
29
32
|
get_nodes_query,
|
|
@@ -31,13 +34,17 @@ from graphiti_core.graph_queries import (
|
|
|
31
34
|
get_vector_cosine_func_query,
|
|
32
35
|
)
|
|
33
36
|
from graphiti_core.helpers import (
|
|
34
|
-
RUNTIME_QUERY,
|
|
35
37
|
lucene_sanitize,
|
|
36
38
|
normalize_l2,
|
|
37
39
|
semaphore_gather,
|
|
38
40
|
)
|
|
41
|
+
from graphiti_core.models.edges.edge_db_queries import get_entity_edge_return_query
|
|
42
|
+
from graphiti_core.models.nodes.node_db_queries import (
|
|
43
|
+
COMMUNITY_NODE_RETURN,
|
|
44
|
+
EPISODIC_NODE_RETURN,
|
|
45
|
+
get_entity_node_return_query,
|
|
46
|
+
)
|
|
39
47
|
from graphiti_core.nodes import (
|
|
40
|
-
ENTITY_NODE_RETURN,
|
|
41
48
|
CommunityNode,
|
|
42
49
|
EntityNode,
|
|
43
50
|
EpisodicNode,
|
|
@@ -57,12 +64,35 @@ RELEVANT_SCHEMA_LIMIT = 10
|
|
|
57
64
|
DEFAULT_MIN_SCORE = 0.6
|
|
58
65
|
DEFAULT_MMR_LAMBDA = 0.5
|
|
59
66
|
MAX_SEARCH_DEPTH = 3
|
|
60
|
-
MAX_QUERY_LENGTH =
|
|
67
|
+
MAX_QUERY_LENGTH = 128
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def calculate_cosine_similarity(vector1: list[float], vector2: list[float]) -> float:
    """
    Calculates the cosine similarity between two vectors using NumPy.

    Args:
        vector1: First vector.
        vector2: Second vector; must have the same length as ``vector1``.

    Returns:
        The cosine similarity as a plain Python ``float``, clipped to
        ``[-1.0, 1.0]``. ``0.0`` is returned when either vector has zero norm
        (including empty vectors), because the similarity is undefined there.

    Raises:
        ValueError: If the vectors have different lengths (raised by ``np.dot``).
    """
    a = np.asarray(vector1, dtype=float)
    b = np.asarray(vector2, dtype=float)

    # Compute the dot product first so mismatched lengths always raise,
    # even when one of the vectors is all zeros.
    dot_product = np.dot(a, b)
    norm_a = np.linalg.norm(a)
    norm_b = np.linalg.norm(b)

    if norm_a == 0 or norm_b == 0:
        # Zero vector(s): similarity is undefined; report "not similar" as a float
        # so callers always get the same type back.
        return 0.0

    # Rounding can push the ratio marginally outside [-1, 1]; clamp it.
    similarity = dot_product / (norm_a * norm_b)
    return float(np.clip(similarity, -1.0, 1.0))
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def fulltext_query(query: str, group_ids: list[str] | None, driver: GraphDriver):
|
|
85
|
+
if driver.provider == GraphProvider.KUZU:
|
|
86
|
+
# Kuzu only supports simple queries.
|
|
87
|
+
if len(query.split(' ')) > MAX_QUERY_LENGTH:
|
|
88
|
+
return ''
|
|
89
|
+
return query
|
|
90
|
+
elif driver.provider == GraphProvider.FALKORDB:
|
|
91
|
+
return driver.build_fulltext_query(query, group_ids, MAX_QUERY_LENGTH)
|
|
64
92
|
group_ids_filter_list = (
|
|
65
|
-
[f'group_id:"{
|
|
93
|
+
[driver.fulltext_syntax + f'group_id:"{g}"' for g in group_ids]
|
|
94
|
+
if group_ids is not None
|
|
95
|
+
else []
|
|
66
96
|
)
|
|
67
97
|
group_ids_filter = ''
|
|
68
98
|
for f in group_ids_filter_list:
|
|
@@ -100,25 +130,18 @@ async def get_mentioned_nodes(
|
|
|
100
130
|
) -> list[EntityNode]:
|
|
101
131
|
episode_uuids = [episode.uuid for episode in episodes]
|
|
102
132
|
|
|
103
|
-
|
|
104
|
-
|
|
133
|
+
records, _, _ = await driver.execute_query(
|
|
134
|
+
"""
|
|
135
|
+
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity)
|
|
136
|
+
WHERE episode.uuid IN $uuids
|
|
105
137
|
RETURN DISTINCT
|
|
106
|
-
n.uuid As uuid,
|
|
107
|
-
n.group_id AS group_id,
|
|
108
|
-
n.name AS name,
|
|
109
|
-
n.created_at AS created_at,
|
|
110
|
-
n.summary AS summary,
|
|
111
|
-
labels(n) AS labels,
|
|
112
|
-
properties(n) AS attributes
|
|
113
138
|
"""
|
|
114
|
-
|
|
115
|
-
records, _, _ = await driver.execute_query(
|
|
116
|
-
query,
|
|
139
|
+
+ get_entity_node_return_query(driver.provider),
|
|
117
140
|
uuids=episode_uuids,
|
|
118
141
|
routing_='r',
|
|
119
142
|
)
|
|
120
143
|
|
|
121
|
-
nodes = [get_entity_node_from_record(record) for record in records]
|
|
144
|
+
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
122
145
|
|
|
123
146
|
return nodes
|
|
124
147
|
|
|
@@ -128,18 +151,13 @@ async def get_communities_by_nodes(
|
|
|
128
151
|
) -> list[CommunityNode]:
|
|
129
152
|
node_uuids = [node.uuid for node in nodes]
|
|
130
153
|
|
|
131
|
-
query = """
|
|
132
|
-
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
|
|
133
|
-
RETURN DISTINCT
|
|
134
|
-
c.uuid As uuid,
|
|
135
|
-
c.group_id AS group_id,
|
|
136
|
-
c.name AS name,
|
|
137
|
-
c.created_at AS created_at,
|
|
138
|
-
c.summary AS summary
|
|
139
|
-
"""
|
|
140
|
-
|
|
141
154
|
records, _, _ = await driver.execute_query(
|
|
142
|
-
|
|
155
|
+
"""
|
|
156
|
+
MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)
|
|
157
|
+
WHERE m.uuid IN $uuids
|
|
158
|
+
RETURN DISTINCT
|
|
159
|
+
"""
|
|
160
|
+
+ COMMUNITY_NODE_RETURN,
|
|
143
161
|
uuids=node_uuids,
|
|
144
162
|
routing_='r',
|
|
145
163
|
)
|
|
@@ -156,49 +174,110 @@ async def edge_fulltext_search(
|
|
|
156
174
|
group_ids: list[str] | None = None,
|
|
157
175
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
158
176
|
) -> list[EntityEdge]:
|
|
177
|
+
if driver.search_interface:
|
|
178
|
+
return await driver.search_interface.edge_fulltext_search(
|
|
179
|
+
driver, query, search_filter, group_ids, limit
|
|
180
|
+
)
|
|
181
|
+
|
|
159
182
|
# fulltext search over facts
|
|
160
|
-
fuzzy_query = fulltext_query(query, group_ids)
|
|
183
|
+
fuzzy_query = fulltext_query(query, group_ids, driver)
|
|
184
|
+
|
|
161
185
|
if fuzzy_query == '':
|
|
162
186
|
return []
|
|
163
187
|
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
+ filter_query
|
|
173
|
-
+ """
|
|
174
|
-
WITH r, score, startNode(r) AS n, endNode(r) AS m
|
|
175
|
-
RETURN
|
|
176
|
-
r.uuid AS uuid,
|
|
177
|
-
r.group_id AS group_id,
|
|
178
|
-
n.uuid AS source_node_uuid,
|
|
179
|
-
m.uuid AS target_node_uuid,
|
|
180
|
-
r.created_at AS created_at,
|
|
181
|
-
r.name AS name,
|
|
182
|
-
r.fact AS fact,
|
|
183
|
-
r.episodes AS episodes,
|
|
184
|
-
r.expired_at AS expired_at,
|
|
185
|
-
r.valid_at AS valid_at,
|
|
186
|
-
r.invalid_at AS invalid_at,
|
|
187
|
-
properties(r) AS attributes
|
|
188
|
-
ORDER BY score DESC LIMIT $limit
|
|
188
|
+
match_query = """
|
|
189
|
+
YIELD relationship AS rel, score
|
|
190
|
+
MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
|
|
191
|
+
"""
|
|
192
|
+
if driver.provider == GraphProvider.KUZU:
|
|
193
|
+
match_query = """
|
|
194
|
+
YIELD node, score
|
|
195
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: node.uuid})-[:RELATES_TO]->(m:Entity)
|
|
189
196
|
"""
|
|
190
|
-
)
|
|
191
197
|
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
params=filter_params,
|
|
195
|
-
query=fuzzy_query,
|
|
196
|
-
group_ids=group_ids,
|
|
197
|
-
limit=limit,
|
|
198
|
-
routing_='r',
|
|
198
|
+
filter_queries, filter_params = edge_search_filter_query_constructor(
|
|
199
|
+
search_filter, driver.provider
|
|
199
200
|
)
|
|
200
201
|
|
|
201
|
-
|
|
202
|
+
if group_ids is not None:
|
|
203
|
+
filter_queries.append('e.group_id IN $group_ids')
|
|
204
|
+
filter_params['group_ids'] = group_ids
|
|
205
|
+
|
|
206
|
+
filter_query = ''
|
|
207
|
+
if filter_queries:
|
|
208
|
+
filter_query = ' WHERE ' + (' AND '.join(filter_queries))
|
|
209
|
+
|
|
210
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
211
|
+
res = driver.run_aoss_query('edge_name_and_fact', query) # pyright: ignore reportAttributeAccessIssue
|
|
212
|
+
if res['hits']['total']['value'] > 0:
|
|
213
|
+
input_ids = []
|
|
214
|
+
for r in res['hits']['hits']:
|
|
215
|
+
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
|
216
|
+
|
|
217
|
+
# Match the edge ids and return the values
|
|
218
|
+
query = (
|
|
219
|
+
"""
|
|
220
|
+
UNWIND $ids as id
|
|
221
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
222
|
+
WHERE e.group_id IN $group_ids
|
|
223
|
+
AND id(e)=id
|
|
224
|
+
"""
|
|
225
|
+
+ filter_query
|
|
226
|
+
+ """
|
|
227
|
+
AND id(e)=id
|
|
228
|
+
WITH e, id.score as score, startNode(e) AS n, endNode(e) AS m
|
|
229
|
+
RETURN
|
|
230
|
+
e.uuid AS uuid,
|
|
231
|
+
e.group_id AS group_id,
|
|
232
|
+
n.uuid AS source_node_uuid,
|
|
233
|
+
m.uuid AS target_node_uuid,
|
|
234
|
+
e.created_at AS created_at,
|
|
235
|
+
e.name AS name,
|
|
236
|
+
e.fact AS fact,
|
|
237
|
+
split(e.episodes, ",") AS episodes,
|
|
238
|
+
e.expired_at AS expired_at,
|
|
239
|
+
e.valid_at AS valid_at,
|
|
240
|
+
e.invalid_at AS invalid_at,
|
|
241
|
+
properties(e) AS attributes
|
|
242
|
+
ORDER BY score DESC LIMIT $limit
|
|
243
|
+
"""
|
|
244
|
+
)
|
|
245
|
+
|
|
246
|
+
records, _, _ = await driver.execute_query(
|
|
247
|
+
query,
|
|
248
|
+
query=fuzzy_query,
|
|
249
|
+
ids=input_ids,
|
|
250
|
+
limit=limit,
|
|
251
|
+
routing_='r',
|
|
252
|
+
**filter_params,
|
|
253
|
+
)
|
|
254
|
+
else:
|
|
255
|
+
return []
|
|
256
|
+
else:
|
|
257
|
+
query = (
|
|
258
|
+
get_relationships_query('edge_name_and_fact', limit=limit, provider=driver.provider)
|
|
259
|
+
+ match_query
|
|
260
|
+
+ filter_query
|
|
261
|
+
+ """
|
|
262
|
+
WITH e, score, n, m
|
|
263
|
+
RETURN
|
|
264
|
+
"""
|
|
265
|
+
+ get_entity_edge_return_query(driver.provider)
|
|
266
|
+
+ """
|
|
267
|
+
ORDER BY score DESC
|
|
268
|
+
LIMIT $limit
|
|
269
|
+
"""
|
|
270
|
+
)
|
|
271
|
+
|
|
272
|
+
records, _, _ = await driver.execute_query(
|
|
273
|
+
query,
|
|
274
|
+
query=fuzzy_query,
|
|
275
|
+
limit=limit,
|
|
276
|
+
routing_='r',
|
|
277
|
+
**filter_params,
|
|
278
|
+
)
|
|
279
|
+
|
|
280
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
202
281
|
|
|
203
282
|
return edges
|
|
204
283
|
|
|
@@ -213,95 +292,86 @@ async def edge_similarity_search(
|
|
|
213
292
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
214
293
|
min_score: float = DEFAULT_MIN_SCORE,
|
|
215
294
|
) -> list[EntityEdge]:
|
|
216
|
-
|
|
217
|
-
|
|
295
|
+
if driver.search_interface:
|
|
296
|
+
return await driver.search_interface.edge_similarity_search(
|
|
297
|
+
driver,
|
|
298
|
+
search_vector,
|
|
299
|
+
source_node_uuid,
|
|
300
|
+
target_node_uuid,
|
|
301
|
+
search_filter,
|
|
302
|
+
group_ids,
|
|
303
|
+
limit,
|
|
304
|
+
min_score,
|
|
305
|
+
)
|
|
218
306
|
|
|
219
|
-
|
|
220
|
-
|
|
307
|
+
match_query = """
|
|
308
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
309
|
+
"""
|
|
310
|
+
if driver.provider == GraphProvider.KUZU:
|
|
311
|
+
match_query = """
|
|
312
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
|
|
313
|
+
"""
|
|
314
|
+
|
|
315
|
+
filter_queries, filter_params = edge_search_filter_query_constructor(
|
|
316
|
+
search_filter, driver.provider
|
|
317
|
+
)
|
|
221
318
|
|
|
222
|
-
group_filter_query: LiteralString = 'WHERE r.group_id IS NOT NULL'
|
|
223
319
|
if group_ids is not None:
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
query_params['source_node_uuid'] = source_node_uuid
|
|
227
|
-
query_params['target_node_uuid'] = target_node_uuid
|
|
320
|
+
filter_queries.append('e.group_id IN $group_ids')
|
|
321
|
+
filter_params['group_ids'] = group_ids
|
|
228
322
|
|
|
229
323
|
if source_node_uuid is not None:
|
|
230
|
-
|
|
324
|
+
filter_params['source_uuid'] = source_node_uuid
|
|
325
|
+
filter_queries.append('n.uuid = $source_uuid')
|
|
231
326
|
|
|
232
327
|
if target_node_uuid is not None:
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
query = (
|
|
236
|
-
RUNTIME_QUERY
|
|
237
|
-
+ """
|
|
238
|
-
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
|
239
|
-
"""
|
|
240
|
-
+ group_filter_query
|
|
241
|
-
+ filter_query
|
|
242
|
-
+ """
|
|
243
|
-
WITH DISTINCT r, """
|
|
244
|
-
+ get_vector_cosine_func_query('r.fact_embedding', '$search_vector', driver.provider)
|
|
245
|
-
+ """ AS score
|
|
246
|
-
WHERE score > $min_score
|
|
247
|
-
RETURN
|
|
248
|
-
r.uuid AS uuid,
|
|
249
|
-
r.group_id AS group_id,
|
|
250
|
-
startNode(r).uuid AS source_node_uuid,
|
|
251
|
-
endNode(r).uuid AS target_node_uuid,
|
|
252
|
-
r.created_at AS created_at,
|
|
253
|
-
r.name AS name,
|
|
254
|
-
r.fact AS fact,
|
|
255
|
-
r.episodes AS episodes,
|
|
256
|
-
r.expired_at AS expired_at,
|
|
257
|
-
r.valid_at AS valid_at,
|
|
258
|
-
r.invalid_at AS invalid_at,
|
|
259
|
-
properties(r) AS attributes
|
|
260
|
-
ORDER BY score DESC
|
|
261
|
-
LIMIT $limit
|
|
262
|
-
"""
|
|
263
|
-
)
|
|
264
|
-
records, header, _ = await driver.execute_query(
|
|
265
|
-
query,
|
|
266
|
-
params=query_params,
|
|
267
|
-
search_vector=search_vector,
|
|
268
|
-
source_uuid=source_node_uuid,
|
|
269
|
-
target_uuid=target_node_uuid,
|
|
270
|
-
group_ids=group_ids,
|
|
271
|
-
limit=limit,
|
|
272
|
-
min_score=min_score,
|
|
273
|
-
routing_='r',
|
|
274
|
-
)
|
|
275
|
-
|
|
276
|
-
edges = [get_entity_edge_from_record(record) for record in records]
|
|
328
|
+
filter_params['target_uuid'] = target_node_uuid
|
|
329
|
+
filter_queries.append('m.uuid = $target_uuid')
|
|
277
330
|
|
|
278
|
-
|
|
331
|
+
filter_query = ''
|
|
332
|
+
if filter_queries:
|
|
333
|
+
filter_query = ' WHERE ' + (' AND '.join(filter_queries))
|
|
279
334
|
|
|
335
|
+
search_vector_var = '$search_vector'
|
|
336
|
+
if driver.provider == GraphProvider.KUZU:
|
|
337
|
+
search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
|
|
280
338
|
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
339
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
340
|
+
query = (
|
|
341
|
+
"""
|
|
342
|
+
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
|
|
343
|
+
"""
|
|
344
|
+
+ filter_query
|
|
345
|
+
+ """
|
|
346
|
+
RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
|
|
347
|
+
"""
|
|
348
|
+
)
|
|
349
|
+
resp, header, _ = await driver.execute_query(
|
|
350
|
+
query,
|
|
351
|
+
search_vector=search_vector,
|
|
352
|
+
limit=limit,
|
|
353
|
+
min_score=min_score,
|
|
354
|
+
routing_='r',
|
|
355
|
+
**filter_params,
|
|
356
|
+
)
|
|
293
357
|
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
358
|
+
if len(resp) > 0:
|
|
359
|
+
# Calculate Cosine similarity then return the edge ids
|
|
360
|
+
input_ids = []
|
|
361
|
+
for r in resp:
|
|
362
|
+
if r['embedding']:
|
|
363
|
+
score = calculate_cosine_similarity(
|
|
364
|
+
search_vector, list(map(float, r['embedding'].split(',')))
|
|
365
|
+
)
|
|
366
|
+
if score > min_score:
|
|
367
|
+
input_ids.append({'id': r['id'], 'score': score})
|
|
368
|
+
|
|
369
|
+
# Match the edge ides and return the values
|
|
370
|
+
query = """
|
|
371
|
+
UNWIND $ids as i
|
|
372
|
+
MATCH ()-[r]->()
|
|
373
|
+
WHERE id(r) = i.id
|
|
374
|
+
RETURN
|
|
305
375
|
r.uuid AS uuid,
|
|
306
376
|
r.group_id AS group_id,
|
|
307
377
|
startNode(r).uuid AS source_node_uuid,
|
|
@@ -309,25 +379,176 @@ async def edge_bfs_search(
|
|
|
309
379
|
r.created_at AS created_at,
|
|
310
380
|
r.name AS name,
|
|
311
381
|
r.fact AS fact,
|
|
312
|
-
r.episodes AS episodes,
|
|
382
|
+
split(r.episodes, ",") AS episodes,
|
|
313
383
|
r.expired_at AS expired_at,
|
|
314
384
|
r.valid_at AS valid_at,
|
|
315
385
|
r.invalid_at AS invalid_at,
|
|
316
386
|
properties(r) AS attributes
|
|
387
|
+
ORDER BY i.score DESC
|
|
317
388
|
LIMIT $limit
|
|
318
|
-
|
|
319
|
-
|
|
389
|
+
"""
|
|
390
|
+
records, _, _ = await driver.execute_query(
|
|
391
|
+
query,
|
|
392
|
+
ids=input_ids,
|
|
393
|
+
search_vector=search_vector,
|
|
394
|
+
limit=limit,
|
|
395
|
+
min_score=min_score,
|
|
396
|
+
routing_='r',
|
|
397
|
+
**filter_params,
|
|
398
|
+
)
|
|
399
|
+
else:
|
|
400
|
+
return []
|
|
401
|
+
else:
|
|
402
|
+
query = (
|
|
403
|
+
match_query
|
|
404
|
+
+ filter_query
|
|
405
|
+
+ """
|
|
406
|
+
WITH DISTINCT e, n, m, """
|
|
407
|
+
+ get_vector_cosine_func_query('e.fact_embedding', search_vector_var, driver.provider)
|
|
408
|
+
+ """ AS score
|
|
409
|
+
WHERE score > $min_score
|
|
410
|
+
RETURN
|
|
411
|
+
"""
|
|
412
|
+
+ get_entity_edge_return_query(driver.provider)
|
|
413
|
+
+ """
|
|
414
|
+
ORDER BY score DESC
|
|
415
|
+
LIMIT $limit
|
|
416
|
+
"""
|
|
417
|
+
)
|
|
320
418
|
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
419
|
+
records, _, _ = await driver.execute_query(
|
|
420
|
+
query,
|
|
421
|
+
search_vector=search_vector,
|
|
422
|
+
limit=limit,
|
|
423
|
+
min_score=min_score,
|
|
424
|
+
routing_='r',
|
|
425
|
+
**filter_params,
|
|
426
|
+
)
|
|
427
|
+
|
|
428
|
+
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
|
|
429
|
+
|
|
430
|
+
return edges
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
async def edge_bfs_search(
    driver: GraphDriver,
    bfs_origin_node_uuids: list[str] | None,
    bfs_max_depth: int,
    search_filter: SearchFilters,
    group_ids: list[str] | None = None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
    """Return entity edges found by traversing the graph from the given origin nodes.

    From each uuid in ``bfs_origin_node_uuids`` the query follows
    ``RELATES_TO``/``MENTIONS`` relationships for up to ``bfs_max_depth`` hops
    and collects the ``RELATES_TO`` entity edges found along those paths. This is
    a structural traversal; no embeddings are used.

    Args:
        driver: Graph driver that executes the queries; ``driver.provider``
            selects the provider-specific query shape (Kuzu, Neptune, default).
        bfs_origin_node_uuids: uuids of the nodes to start from. ``None`` or an
            empty list yields no results.
        bfs_max_depth: Maximum number of hops to traverse. Values below 1 yield
            no results.
        search_filter: Additional edge filters, translated into WHERE clauses by
            ``edge_search_filter_query_constructor``.
        group_ids: When given, only edges whose ``group_id`` is in this list are
            returned.
        limit: Maximum number of edges to return.

    Returns:
        The matching edges, converted with ``get_entity_edge_from_record``.
    """
    # Nothing to start from, or no hops requested: a `*1..0` (or, for Kuzu,
    # `*1..-1`) range is not a meaningful traversal, so skip the round-trip.
    if not bfs_origin_node_uuids or bfs_max_depth < 1:
        return []

    filter_queries, filter_params = edge_search_filter_query_constructor(
        search_filter, driver.provider
    )

    if group_ids is not None:
        filter_queries.append('e.group_id IN $group_ids')
        filter_params['group_ids'] = group_ids

    filter_query = ''
    if filter_queries:
        filter_query = ' WHERE ' + (' AND '.join(filter_queries))

    if driver.provider == GraphProvider.KUZU:
        # Kuzu stores entity edges twice with an intermediate node, so we need to match them
        # separately for the correct BFS depth.
        depth = bfs_max_depth * 2 - 1
        match_queries = [
            f"""
            UNWIND $bfs_origin_node_uuids AS origin_uuid
            MATCH path = (origin:Entity {{uuid: origin_uuid}})-[:RELATES_TO*1..{depth}]->(:RelatesToNode_)
            UNWIND nodes(path) AS relNode
            MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {{uuid: relNode.uuid}})-[:RELATES_TO]->(m:Entity)
            """,
        ]
        if bfs_max_depth > 1:
            # Episodic origins spend one hop on MENTIONS, leaving one fewer
            # entity-edge hop for the RELATES_TO part of the path.
            depth = (bfs_max_depth - 1) * 2 - 1
            match_queries.append(f"""
                UNWIND $bfs_origin_node_uuids AS origin_uuid
                MATCH path = (origin:Episodic {{uuid: origin_uuid}})-[:MENTIONS]->(:Entity)-[:RELATES_TO*1..{depth}]->(:RelatesToNode_)
                UNWIND nodes(path) AS relNode
                MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {{uuid: relNode.uuid}})-[:RELATES_TO]->(m:Entity)
            """)

        records = []
        for match_query in match_queries:
            sub_records, _, _ = await driver.execute_query(
                match_query
                + filter_query
                + """
                RETURN DISTINCT
                """
                + get_entity_edge_return_query(driver.provider)
                + """
                LIMIT $limit
                """,
                bfs_origin_node_uuids=bfs_origin_node_uuids,
                limit=limit,
                routing_='r',
                **filter_params,
            )
            records.extend(sub_records)
    else:
        if driver.provider == GraphProvider.NEPTUNE:
            query = (
                f"""
                UNWIND $bfs_origin_node_uuids AS origin_uuid
                MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS *1..{bfs_max_depth}]->(n:Entity)
                WHERE origin:Entity OR origin:Episodic
                UNWIND relationships(path) AS rel
                MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
                """
                + filter_query
                + """
                RETURN DISTINCT
                    e.uuid AS uuid,
                    e.group_id AS group_id,
                    startNode(e).uuid AS source_node_uuid,
                    endNode(e).uuid AS target_node_uuid,
                    e.created_at AS created_at,
                    e.name AS name,
                    e.fact AS fact,
                    split(e.episodes, ',') AS episodes,
                    e.expired_at AS expired_at,
                    e.valid_at AS valid_at,
                    e.invalid_at AS invalid_at,
                    properties(e) AS attributes
                LIMIT $limit
                """
            )
        else:
            query = (
                f"""
                UNWIND $bfs_origin_node_uuids AS origin_uuid
                MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
                UNWIND relationships(path) AS rel
                MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
                """
                + filter_query
                + """
                RETURN DISTINCT
                """
                + get_entity_edge_return_query(driver.provider)
                + """
                LIMIT $limit
                """
            )

        records, _, _ = await driver.execute_query(
            query,
            bfs_origin_node_uuids=bfs_origin_node_uuids,
            depth=bfs_max_depth,
            limit=limit,
            routing_='r',
            **filter_params,
        )

    edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

    if driver.provider == GraphProvider.KUZU:
        # The Kuzu branch concatenates up to two sub-queries, each limited to
        # `limit` rows on its own. The same edge can be reached from both an
        # Entity origin and an Episodic origin, so keep the first occurrence of
        # each uuid and enforce the overall limit.
        unique_edges: dict[str, EntityEdge] = {}
        for edge in edges:
            unique_edges.setdefault(edge.uuid, edge)
        edges = list(unique_edges.values())[:limit]

    return edges
|
|
333
554
|
|
|
@@ -339,36 +560,88 @@ async def node_fulltext_search(
|
|
|
339
560
|
group_ids: list[str] | None = None,
|
|
340
561
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
341
562
|
) -> list[EntityNode]:
|
|
563
|
+
if driver.search_interface:
|
|
564
|
+
return await driver.search_interface.node_fulltext_search(
|
|
565
|
+
driver, query, search_filter, group_ids, limit
|
|
566
|
+
)
|
|
567
|
+
|
|
342
568
|
# BM25 search to get top nodes
|
|
343
|
-
fuzzy_query = fulltext_query(query, group_ids)
|
|
569
|
+
fuzzy_query = fulltext_query(query, group_ids, driver)
|
|
344
570
|
if fuzzy_query == '':
|
|
345
571
|
return []
|
|
346
|
-
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
|
347
572
|
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
573
|
+
filter_queries, filter_params = node_search_filter_query_constructor(
|
|
574
|
+
search_filter, driver.provider
|
|
575
|
+
)
|
|
576
|
+
|
|
577
|
+
if group_ids is not None:
|
|
578
|
+
filter_queries.append('n.group_id IN $group_ids')
|
|
579
|
+
filter_params['group_ids'] = group_ids
|
|
580
|
+
|
|
581
|
+
filter_query = ''
|
|
582
|
+
if filter_queries:
|
|
583
|
+
filter_query = ' WHERE ' + (' AND '.join(filter_queries))
|
|
584
|
+
|
|
585
|
+
yield_query = 'YIELD node AS n, score'
|
|
586
|
+
if driver.provider == GraphProvider.KUZU:
|
|
587
|
+
yield_query = 'WITH node AS n, score'
|
|
588
|
+
|
|
589
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
590
|
+
res = driver.run_aoss_query('node_name_and_summary', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
|
591
|
+
if res['hits']['total']['value'] > 0:
|
|
592
|
+
input_ids = []
|
|
593
|
+
for r in res['hits']['hits']:
|
|
594
|
+
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
|
595
|
+
|
|
596
|
+
# Match the edge ides and return the values
|
|
597
|
+
query = (
|
|
598
|
+
"""
|
|
599
|
+
UNWIND $ids as i
|
|
600
|
+
MATCH (n:Entity)
|
|
601
|
+
WHERE n.uuid=i.id
|
|
602
|
+
RETURN
|
|
603
|
+
"""
|
|
604
|
+
+ get_entity_node_return_query(driver.provider)
|
|
605
|
+
+ """
|
|
606
|
+
ORDER BY i.score DESC
|
|
607
|
+
LIMIT $limit
|
|
608
|
+
"""
|
|
609
|
+
)
|
|
610
|
+
records, _, _ = await driver.execute_query(
|
|
611
|
+
query,
|
|
612
|
+
ids=input_ids,
|
|
613
|
+
query=fuzzy_query,
|
|
614
|
+
limit=limit,
|
|
615
|
+
routing_='r',
|
|
616
|
+
**filter_params,
|
|
617
|
+
)
|
|
618
|
+
else:
|
|
619
|
+
return []
|
|
620
|
+
else:
|
|
621
|
+
query = (
|
|
622
|
+
get_nodes_query(
|
|
623
|
+
'node_name_and_summary', '$query', limit=limit, provider=driver.provider
|
|
624
|
+
)
|
|
625
|
+
+ yield_query
|
|
626
|
+
+ filter_query
|
|
627
|
+
+ """
|
|
352
628
|
WITH n, score
|
|
629
|
+
ORDER BY score DESC
|
|
353
630
|
LIMIT $limit
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
group_ids=group_ids,
|
|
367
|
-
limit=limit,
|
|
368
|
-
routing_='r',
|
|
369
|
-
)
|
|
631
|
+
RETURN
|
|
632
|
+
"""
|
|
633
|
+
+ get_entity_node_return_query(driver.provider)
|
|
634
|
+
)
|
|
635
|
+
|
|
636
|
+
records, _, _ = await driver.execute_query(
|
|
637
|
+
query,
|
|
638
|
+
query=fuzzy_query,
|
|
639
|
+
limit=limit,
|
|
640
|
+
routing_='r',
|
|
641
|
+
**filter_params,
|
|
642
|
+
)
|
|
370
643
|
|
|
371
|
-
nodes = [get_entity_node_from_record(record) for record in records]
|
|
644
|
+
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
372
645
|
|
|
373
646
|
return nodes
|
|
374
647
|
|
|
@@ -381,47 +654,112 @@ async def node_similarity_search(
|
|
|
381
654
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
382
655
|
min_score: float = DEFAULT_MIN_SCORE,
|
|
383
656
|
) -> list[EntityNode]:
|
|
384
|
-
|
|
385
|
-
|
|
657
|
+
if driver.search_interface:
|
|
658
|
+
return await driver.search_interface.node_similarity_search(
|
|
659
|
+
driver, search_vector, search_filter, group_ids, limit, min_score
|
|
660
|
+
)
|
|
661
|
+
|
|
662
|
+
filter_queries, filter_params = node_search_filter_query_constructor(
|
|
663
|
+
search_filter, driver.provider
|
|
664
|
+
)
|
|
386
665
|
|
|
387
|
-
group_filter_query: LiteralString = 'WHERE n.group_id IS NOT NULL'
|
|
388
666
|
if group_ids is not None:
|
|
389
|
-
|
|
390
|
-
|
|
667
|
+
filter_queries.append('n.group_id IN $group_ids')
|
|
668
|
+
filter_params['group_ids'] = group_ids
|
|
391
669
|
|
|
392
|
-
filter_query
|
|
393
|
-
|
|
670
|
+
filter_query = ''
|
|
671
|
+
if filter_queries:
|
|
672
|
+
filter_query = ' WHERE ' + (' AND '.join(filter_queries))
|
|
394
673
|
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
+ filter_query
|
|
402
|
-
+ """
|
|
403
|
-
WITH n, """
|
|
404
|
-
+ get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
|
|
405
|
-
+ """ AS score
|
|
406
|
-
WHERE score > $min_score"""
|
|
407
|
-
+ ENTITY_NODE_RETURN
|
|
408
|
-
+ """
|
|
409
|
-
ORDER BY score DESC
|
|
410
|
-
LIMIT $limit
|
|
674
|
+
search_vector_var = '$search_vector'
|
|
675
|
+
if driver.provider == GraphProvider.KUZU:
|
|
676
|
+
search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
|
|
677
|
+
|
|
678
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
679
|
+
query = (
|
|
411
680
|
"""
|
|
412
|
-
|
|
681
|
+
MATCH (n:Entity)
|
|
682
|
+
"""
|
|
683
|
+
+ filter_query
|
|
684
|
+
+ """
|
|
685
|
+
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
|
686
|
+
"""
|
|
687
|
+
)
|
|
688
|
+
resp, header, _ = await driver.execute_query(
|
|
689
|
+
query,
|
|
690
|
+
params=filter_params,
|
|
691
|
+
search_vector=search_vector,
|
|
692
|
+
limit=limit,
|
|
693
|
+
min_score=min_score,
|
|
694
|
+
routing_='r',
|
|
695
|
+
)
|
|
413
696
|
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
697
|
+
if len(resp) > 0:
|
|
698
|
+
# Calculate Cosine similarity then return the edge ids
|
|
699
|
+
input_ids = []
|
|
700
|
+
for r in resp:
|
|
701
|
+
if r['embedding']:
|
|
702
|
+
score = calculate_cosine_similarity(
|
|
703
|
+
search_vector, list(map(float, r['embedding'].split(',')))
|
|
704
|
+
)
|
|
705
|
+
if score > min_score:
|
|
706
|
+
input_ids.append({'id': r['id'], 'score': score})
|
|
707
|
+
|
|
708
|
+
# Match the edge ides and return the values
|
|
709
|
+
query = (
|
|
710
|
+
"""
|
|
711
|
+
UNWIND $ids as i
|
|
712
|
+
MATCH (n:Entity)
|
|
713
|
+
WHERE id(n)=i.id
|
|
714
|
+
RETURN
|
|
715
|
+
"""
|
|
716
|
+
+ get_entity_node_return_query(driver.provider)
|
|
717
|
+
+ """
|
|
718
|
+
ORDER BY i.score DESC
|
|
719
|
+
LIMIT $limit
|
|
720
|
+
"""
|
|
721
|
+
)
|
|
722
|
+
records, header, _ = await driver.execute_query(
|
|
723
|
+
query,
|
|
724
|
+
ids=input_ids,
|
|
725
|
+
search_vector=search_vector,
|
|
726
|
+
limit=limit,
|
|
727
|
+
min_score=min_score,
|
|
728
|
+
routing_='r',
|
|
729
|
+
**filter_params,
|
|
730
|
+
)
|
|
731
|
+
else:
|
|
732
|
+
return []
|
|
733
|
+
else:
|
|
734
|
+
query = (
|
|
735
|
+
"""
|
|
736
|
+
MATCH (n:Entity)
|
|
737
|
+
"""
|
|
738
|
+
+ filter_query
|
|
739
|
+
+ """
|
|
740
|
+
WITH n, """
|
|
741
|
+
+ get_vector_cosine_func_query('n.name_embedding', search_vector_var, driver.provider)
|
|
742
|
+
+ """ AS score
|
|
743
|
+
WHERE score > $min_score
|
|
744
|
+
RETURN
|
|
745
|
+
"""
|
|
746
|
+
+ get_entity_node_return_query(driver.provider)
|
|
747
|
+
+ """
|
|
748
|
+
ORDER BY score DESC
|
|
749
|
+
LIMIT $limit
|
|
750
|
+
"""
|
|
751
|
+
)
|
|
752
|
+
|
|
753
|
+
records, _, _ = await driver.execute_query(
|
|
754
|
+
query,
|
|
755
|
+
search_vector=search_vector,
|
|
756
|
+
limit=limit,
|
|
757
|
+
min_score=min_score,
|
|
758
|
+
routing_='r',
|
|
759
|
+
**filter_params,
|
|
760
|
+
)
|
|
423
761
|
|
|
424
|
-
nodes = [get_entity_node_from_record(record) for record in records]
|
|
762
|
+
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
425
763
|
|
|
426
764
|
return nodes
|
|
427
765
|
|
|
@@ -431,35 +769,85 @@ async def node_bfs_search(
|
|
|
431
769
|
bfs_origin_node_uuids: list[str] | None,
|
|
432
770
|
search_filter: SearchFilters,
|
|
433
771
|
bfs_max_depth: int,
|
|
434
|
-
|
|
772
|
+
group_ids: list[str] | None = None,
|
|
773
|
+
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
435
774
|
) -> list[EntityNode]:
|
|
436
|
-
|
|
437
|
-
if bfs_origin_node_uuids is None:
|
|
775
|
+
if bfs_origin_node_uuids is None or len(bfs_origin_node_uuids) == 0 or bfs_max_depth < 1:
|
|
438
776
|
return []
|
|
439
777
|
|
|
440
|
-
|
|
778
|
+
filter_queries, filter_params = node_search_filter_query_constructor(
|
|
779
|
+
search_filter, driver.provider
|
|
780
|
+
)
|
|
441
781
|
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
+
|
|
450
|
-
|
|
451
|
-
|
|
782
|
+
if group_ids is not None:
|
|
783
|
+
filter_queries.append('n.group_id IN $group_ids')
|
|
784
|
+
filter_queries.append('origin.group_id IN $group_ids')
|
|
785
|
+
filter_params['group_ids'] = group_ids
|
|
786
|
+
|
|
787
|
+
filter_query = ''
|
|
788
|
+
if filter_queries:
|
|
789
|
+
filter_query = ' AND ' + (' AND '.join(filter_queries))
|
|
790
|
+
|
|
791
|
+
match_queries = [
|
|
792
|
+
f"""
|
|
793
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
794
|
+
MATCH (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
|
|
795
|
+
WHERE n.group_id = origin.group_id
|
|
452
796
|
"""
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
797
|
+
]
|
|
798
|
+
|
|
799
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
800
|
+
match_queries = [
|
|
801
|
+
f"""
|
|
802
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
803
|
+
MATCH (origin {{uuid: origin_uuid}})-[e:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
|
|
804
|
+
WHERE origin:Entity OR origin.Episode
|
|
805
|
+
AND n.group_id = origin.group_id
|
|
806
|
+
"""
|
|
807
|
+
]
|
|
808
|
+
|
|
809
|
+
if driver.provider == GraphProvider.KUZU:
|
|
810
|
+
depth = bfs_max_depth * 2
|
|
811
|
+
match_queries = [
|
|
812
|
+
"""
|
|
813
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
814
|
+
MATCH (origin:Episodic {uuid: origin_uuid})-[:MENTIONS]->(n:Entity)
|
|
815
|
+
WHERE n.group_id = origin.group_id
|
|
816
|
+
""",
|
|
817
|
+
f"""
|
|
818
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
819
|
+
MATCH (origin:Entity {{uuid: origin_uuid}})-[:RELATES_TO*2..{depth}]->(n:Entity)
|
|
820
|
+
WHERE n.group_id = origin.group_id
|
|
821
|
+
""",
|
|
822
|
+
]
|
|
823
|
+
if bfs_max_depth > 1:
|
|
824
|
+
depth = (bfs_max_depth - 1) * 2
|
|
825
|
+
match_queries.append(f"""
|
|
826
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
827
|
+
MATCH (origin:Episodic {{uuid: origin_uuid}})-[:MENTIONS]->(:Entity)-[:RELATES_TO*2..{depth}]->(n:Entity)
|
|
828
|
+
WHERE n.group_id = origin.group_id
|
|
829
|
+
""")
|
|
830
|
+
|
|
831
|
+
records = []
|
|
832
|
+
for match_query in match_queries:
|
|
833
|
+
sub_records, _, _ = await driver.execute_query(
|
|
834
|
+
match_query
|
|
835
|
+
+ filter_query
|
|
836
|
+
+ """
|
|
837
|
+
RETURN
|
|
838
|
+
"""
|
|
839
|
+
+ get_entity_node_return_query(driver.provider)
|
|
840
|
+
+ """
|
|
841
|
+
LIMIT $limit
|
|
842
|
+
""",
|
|
843
|
+
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
|
844
|
+
limit=limit,
|
|
845
|
+
routing_='r',
|
|
846
|
+
**filter_params,
|
|
847
|
+
)
|
|
848
|
+
records.extend(sub_records)
|
|
849
|
+
|
|
850
|
+
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
|
|
463
851
|
|
|
464
852
|
return nodes
|
|
465
853
|
|
|
@@ -471,39 +859,80 @@ async def episode_fulltext_search(
|
|
|
471
859
|
group_ids: list[str] | None = None,
|
|
472
860
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
473
861
|
) -> list[EpisodicNode]:
|
|
862
|
+
if driver.search_interface:
|
|
863
|
+
return await driver.search_interface.episode_fulltext_search(
|
|
864
|
+
driver, query, _search_filter, group_ids, limit
|
|
865
|
+
)
|
|
866
|
+
|
|
474
867
|
# BM25 search to get top episodes
|
|
475
|
-
fuzzy_query = fulltext_query(query, group_ids)
|
|
868
|
+
fuzzy_query = fulltext_query(query, group_ids, driver)
|
|
476
869
|
if fuzzy_query == '':
|
|
477
870
|
return []
|
|
478
871
|
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
872
|
+
filter_params: dict[str, Any] = {}
|
|
873
|
+
group_filter_query: LiteralString = ''
|
|
874
|
+
if group_ids is not None:
|
|
875
|
+
group_filter_query += '\nAND e.group_id IN $group_ids'
|
|
876
|
+
filter_params['group_ids'] = group_ids
|
|
877
|
+
|
|
878
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
879
|
+
res = driver.run_aoss_query('episode_content', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
|
880
|
+
if res['hits']['total']['value'] > 0:
|
|
881
|
+
input_ids = []
|
|
882
|
+
for r in res['hits']['hits']:
|
|
883
|
+
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
|
884
|
+
|
|
885
|
+
# Match the edge ides and return the values
|
|
886
|
+
query = """
|
|
887
|
+
UNWIND $ids as i
|
|
888
|
+
MATCH (e:Episodic)
|
|
889
|
+
WHERE e.uuid=i.uuid
|
|
890
|
+
RETURN
|
|
891
|
+
e.content AS content,
|
|
892
|
+
e.created_at AS created_at,
|
|
893
|
+
e.valid_at AS valid_at,
|
|
894
|
+
e.uuid AS uuid,
|
|
895
|
+
e.name AS name,
|
|
896
|
+
e.group_id AS group_id,
|
|
897
|
+
e.source_description AS source_description,
|
|
898
|
+
e.source AS source,
|
|
899
|
+
e.entity_edges AS entity_edges
|
|
900
|
+
ORDER BY i.score DESC
|
|
901
|
+
LIMIT $limit
|
|
902
|
+
"""
|
|
903
|
+
records, _, _ = await driver.execute_query(
|
|
904
|
+
query,
|
|
905
|
+
ids=input_ids,
|
|
906
|
+
query=fuzzy_query,
|
|
907
|
+
limit=limit,
|
|
908
|
+
routing_='r',
|
|
909
|
+
**filter_params,
|
|
910
|
+
)
|
|
911
|
+
else:
|
|
912
|
+
return []
|
|
913
|
+
else:
|
|
914
|
+
query = (
|
|
915
|
+
get_nodes_query('episode_content', '$query', limit=limit, provider=driver.provider)
|
|
916
|
+
+ """
|
|
917
|
+
YIELD node AS episode, score
|
|
918
|
+
MATCH (e:Episodic)
|
|
919
|
+
WHERE e.uuid = episode.uuid
|
|
920
|
+
"""
|
|
921
|
+
+ group_filter_query
|
|
922
|
+
+ """
|
|
923
|
+
RETURN
|
|
924
|
+
"""
|
|
925
|
+
+ EPISODIC_NODE_RETURN
|
|
926
|
+
+ """
|
|
927
|
+
ORDER BY score DESC
|
|
928
|
+
LIMIT $limit
|
|
929
|
+
"""
|
|
930
|
+
)
|
|
931
|
+
|
|
932
|
+
records, _, _ = await driver.execute_query(
|
|
933
|
+
query, query=fuzzy_query, limit=limit, routing_='r', **filter_params
|
|
934
|
+
)
|
|
499
935
|
|
|
500
|
-
records, _, _ = await driver.execute_query(
|
|
501
|
-
query,
|
|
502
|
-
query=fuzzy_query,
|
|
503
|
-
group_ids=group_ids,
|
|
504
|
-
limit=limit,
|
|
505
|
-
routing_='r',
|
|
506
|
-
)
|
|
507
936
|
episodes = [get_episodic_node_from_record(record) for record in records]
|
|
508
937
|
|
|
509
938
|
return episodes
|
|
@@ -516,33 +945,75 @@ async def community_fulltext_search(
|
|
|
516
945
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
517
946
|
) -> list[CommunityNode]:
|
|
518
947
|
# BM25 search to get top communities
|
|
519
|
-
fuzzy_query = fulltext_query(query, group_ids)
|
|
948
|
+
fuzzy_query = fulltext_query(query, group_ids, driver)
|
|
520
949
|
if fuzzy_query == '':
|
|
521
950
|
return []
|
|
522
951
|
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
952
|
+
filter_params: dict[str, Any] = {}
|
|
953
|
+
group_filter_query: LiteralString = ''
|
|
954
|
+
if group_ids is not None:
|
|
955
|
+
group_filter_query = 'WHERE c.group_id IN $group_ids'
|
|
956
|
+
filter_params['group_ids'] = group_ids
|
|
957
|
+
|
|
958
|
+
yield_query = 'YIELD node AS c, score'
|
|
959
|
+
if driver.provider == GraphProvider.KUZU:
|
|
960
|
+
yield_query = 'WITH node AS c, score'
|
|
961
|
+
|
|
962
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
963
|
+
res = driver.run_aoss_query('community_name', query, limit=limit) # pyright: ignore reportAttributeAccessIssue
|
|
964
|
+
if res['hits']['total']['value'] > 0:
|
|
965
|
+
# Calculate Cosine similarity then return the edge ids
|
|
966
|
+
input_ids = []
|
|
967
|
+
for r in res['hits']['hits']:
|
|
968
|
+
input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
|
|
969
|
+
|
|
970
|
+
# Match the edge ides and return the values
|
|
971
|
+
query = """
|
|
972
|
+
UNWIND $ids as i
|
|
973
|
+
MATCH (comm:Community)
|
|
974
|
+
WHERE comm.uuid=i.id
|
|
975
|
+
RETURN
|
|
976
|
+
comm.uuid AS uuid,
|
|
977
|
+
comm.group_id AS group_id,
|
|
978
|
+
comm.name AS name,
|
|
979
|
+
comm.created_at AS created_at,
|
|
980
|
+
comm.summary AS summary,
|
|
981
|
+
[x IN split(comm.name_embedding, ",") | toFloat(x)]AS name_embedding
|
|
982
|
+
ORDER BY i.score DESC
|
|
983
|
+
LIMIT $limit
|
|
984
|
+
"""
|
|
985
|
+
records, _, _ = await driver.execute_query(
|
|
986
|
+
query,
|
|
987
|
+
ids=input_ids,
|
|
988
|
+
query=fuzzy_query,
|
|
989
|
+
limit=limit,
|
|
990
|
+
routing_='r',
|
|
991
|
+
**filter_params,
|
|
992
|
+
)
|
|
993
|
+
else:
|
|
994
|
+
return []
|
|
995
|
+
else:
|
|
996
|
+
query = (
|
|
997
|
+
get_nodes_query('community_name', '$query', limit=limit, provider=driver.provider)
|
|
998
|
+
+ yield_query
|
|
999
|
+
+ """
|
|
1000
|
+
WITH c, score
|
|
1001
|
+
"""
|
|
1002
|
+
+ group_filter_query
|
|
1003
|
+
+ """
|
|
1004
|
+
RETURN
|
|
1005
|
+
"""
|
|
1006
|
+
+ COMMUNITY_NODE_RETURN
|
|
1007
|
+
+ """
|
|
1008
|
+
ORDER BY score DESC
|
|
1009
|
+
LIMIT $limit
|
|
1010
|
+
"""
|
|
1011
|
+
)
|
|
1012
|
+
|
|
1013
|
+
records, _, _ = await driver.execute_query(
|
|
1014
|
+
query, query=fuzzy_query, limit=limit, routing_='r', **filter_params
|
|
1015
|
+
)
|
|
538
1016
|
|
|
539
|
-
records, _, _ = await driver.execute_query(
|
|
540
|
-
query,
|
|
541
|
-
query=fuzzy_query,
|
|
542
|
-
group_ids=group_ids,
|
|
543
|
-
limit=limit,
|
|
544
|
-
routing_='r',
|
|
545
|
-
)
|
|
546
1017
|
communities = [get_community_node_from_record(record) for record in records]
|
|
547
1018
|
|
|
548
1019
|
return communities
|
|
@@ -560,40 +1031,99 @@ async def community_similarity_search(
|
|
|
560
1031
|
|
|
561
1032
|
group_filter_query: LiteralString = ''
|
|
562
1033
|
if group_ids is not None:
|
|
563
|
-
group_filter_query += 'WHERE
|
|
1034
|
+
group_filter_query += ' WHERE c.group_id IN $group_ids'
|
|
564
1035
|
query_params['group_ids'] = group_ids
|
|
565
1036
|
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
|
|
570
|
-
|
|
571
|
-
|
|
572
|
-
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
1037
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
1038
|
+
query = (
|
|
1039
|
+
"""
|
|
1040
|
+
MATCH (n:Community)
|
|
1041
|
+
"""
|
|
1042
|
+
+ group_filter_query
|
|
1043
|
+
+ """
|
|
1044
|
+
RETURN DISTINCT id(n) as id, n.name_embedding as embedding
|
|
1045
|
+
"""
|
|
1046
|
+
)
|
|
1047
|
+
resp, header, _ = await driver.execute_query(
|
|
1048
|
+
query,
|
|
1049
|
+
search_vector=search_vector,
|
|
1050
|
+
limit=limit,
|
|
1051
|
+
min_score=min_score,
|
|
1052
|
+
routing_='r',
|
|
1053
|
+
**query_params,
|
|
1054
|
+
)
|
|
1055
|
+
|
|
1056
|
+
if len(resp) > 0:
|
|
1057
|
+
# Calculate Cosine similarity then return the edge ids
|
|
1058
|
+
input_ids = []
|
|
1059
|
+
for r in resp:
|
|
1060
|
+
if r['embedding']:
|
|
1061
|
+
score = calculate_cosine_similarity(
|
|
1062
|
+
search_vector, list(map(float, r['embedding'].split(',')))
|
|
1063
|
+
)
|
|
1064
|
+
if score > min_score:
|
|
1065
|
+
input_ids.append({'id': r['id'], 'score': score})
|
|
1066
|
+
|
|
1067
|
+
# Match the edge ides and return the values
|
|
1068
|
+
query = """
|
|
1069
|
+
UNWIND $ids as i
|
|
1070
|
+
MATCH (comm:Community)
|
|
1071
|
+
WHERE id(comm)=i.id
|
|
1072
|
+
RETURN
|
|
1073
|
+
comm.uuid As uuid,
|
|
1074
|
+
comm.group_id AS group_id,
|
|
1075
|
+
comm.name AS name,
|
|
1076
|
+
comm.created_at AS created_at,
|
|
1077
|
+
comm.summary AS summary,
|
|
1078
|
+
comm.name_embedding AS name_embedding
|
|
1079
|
+
ORDER BY i.score DESC
|
|
1080
|
+
LIMIT $limit
|
|
1081
|
+
"""
|
|
1082
|
+
records, header, _ = await driver.execute_query(
|
|
1083
|
+
query,
|
|
1084
|
+
ids=input_ids,
|
|
1085
|
+
search_vector=search_vector,
|
|
1086
|
+
limit=limit,
|
|
1087
|
+
min_score=min_score,
|
|
1088
|
+
routing_='r',
|
|
1089
|
+
**query_params,
|
|
1090
|
+
)
|
|
1091
|
+
else:
|
|
1092
|
+
return []
|
|
1093
|
+
else:
|
|
1094
|
+
search_vector_var = '$search_vector'
|
|
1095
|
+
if driver.provider == GraphProvider.KUZU:
|
|
1096
|
+
search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
|
|
1097
|
+
|
|
1098
|
+
query = (
|
|
1099
|
+
"""
|
|
1100
|
+
MATCH (c:Community)
|
|
1101
|
+
"""
|
|
1102
|
+
+ group_filter_query
|
|
1103
|
+
+ """
|
|
1104
|
+
WITH c,
|
|
1105
|
+
"""
|
|
1106
|
+
+ get_vector_cosine_func_query('c.name_embedding', search_vector_var, driver.provider)
|
|
1107
|
+
+ """ AS score
|
|
1108
|
+
WHERE score > $min_score
|
|
1109
|
+
RETURN
|
|
1110
|
+
"""
|
|
1111
|
+
+ COMMUNITY_NODE_RETURN
|
|
1112
|
+
+ """
|
|
1113
|
+
ORDER BY score DESC
|
|
1114
|
+
LIMIT $limit
|
|
1115
|
+
"""
|
|
1116
|
+
)
|
|
1117
|
+
|
|
1118
|
+
records, _, _ = await driver.execute_query(
|
|
1119
|
+
query,
|
|
1120
|
+
search_vector=search_vector,
|
|
1121
|
+
limit=limit,
|
|
1122
|
+
min_score=min_score,
|
|
1123
|
+
routing_='r',
|
|
1124
|
+
**query_params,
|
|
1125
|
+
)
|
|
588
1126
|
|
|
589
|
-
records, _, _ = await driver.execute_query(
|
|
590
|
-
query,
|
|
591
|
-
search_vector=search_vector,
|
|
592
|
-
group_ids=group_ids,
|
|
593
|
-
limit=limit,
|
|
594
|
-
min_score=min_score,
|
|
595
|
-
routing_='r',
|
|
596
|
-
)
|
|
597
1127
|
communities = [get_community_node_from_record(record) for record in records]
|
|
598
1128
|
|
|
599
1129
|
return communities
|
|
@@ -664,7 +1194,7 @@ async def hybrid_node_search(
|
|
|
664
1194
|
}
|
|
665
1195
|
result_uuids = [[node.uuid for node in result] for result in results]
|
|
666
1196
|
|
|
667
|
-
ranked_uuids = rrf(result_uuids)
|
|
1197
|
+
ranked_uuids, _ = rrf(result_uuids)
|
|
668
1198
|
|
|
669
1199
|
relevant_nodes: list[EntityNode] = [node_uuid_map[uuid] for uuid in ranked_uuids]
|
|
670
1200
|
|
|
@@ -684,80 +1214,140 @@ async def get_relevant_nodes(
|
|
|
684
1214
|
return []
|
|
685
1215
|
|
|
686
1216
|
group_id = nodes[0].group_id
|
|
687
|
-
|
|
688
|
-
# vector similarity search over entity names
|
|
689
|
-
query_params: dict[str, Any] = {}
|
|
690
|
-
|
|
691
|
-
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
|
692
|
-
query_params.update(filter_params)
|
|
693
|
-
|
|
694
|
-
query = (
|
|
695
|
-
RUNTIME_QUERY
|
|
696
|
-
+ """
|
|
697
|
-
UNWIND $nodes AS node
|
|
698
|
-
MATCH (n:Entity {group_id: $group_id})
|
|
699
|
-
"""
|
|
700
|
-
+ filter_query
|
|
701
|
-
+ """
|
|
702
|
-
WITH node, n, """
|
|
703
|
-
+ get_vector_cosine_func_query('n.name_embedding', 'node.name_embedding', driver.provider)
|
|
704
|
-
+ """ AS score
|
|
705
|
-
WHERE score > $min_score
|
|
706
|
-
WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
|
|
707
|
-
"""
|
|
708
|
-
+ get_nodes_query(driver.provider, 'node_name_and_summary', 'node.fulltext_query')
|
|
709
|
-
+ """
|
|
710
|
-
YIELD node AS m
|
|
711
|
-
WHERE m.group_id = $group_id
|
|
712
|
-
WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
|
|
713
|
-
|
|
714
|
-
WITH node,
|
|
715
|
-
top_vector_nodes,
|
|
716
|
-
[m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
|
|
717
|
-
|
|
718
|
-
WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
|
|
719
|
-
|
|
720
|
-
UNWIND combined_nodes AS combined_node
|
|
721
|
-
WITH node, collect(DISTINCT combined_node) AS deduped_nodes
|
|
722
|
-
|
|
723
|
-
RETURN
|
|
724
|
-
node.uuid AS search_node_uuid,
|
|
725
|
-
[x IN deduped_nodes | {
|
|
726
|
-
uuid: x.uuid,
|
|
727
|
-
name: x.name,
|
|
728
|
-
name_embedding: x.name_embedding,
|
|
729
|
-
group_id: x.group_id,
|
|
730
|
-
created_at: x.created_at,
|
|
731
|
-
summary: x.summary,
|
|
732
|
-
labels: labels(x),
|
|
733
|
-
attributes: properties(x)
|
|
734
|
-
}] AS matches
|
|
735
|
-
"""
|
|
736
|
-
)
|
|
737
|
-
|
|
738
1217
|
query_nodes = [
|
|
739
1218
|
{
|
|
740
1219
|
'uuid': node.uuid,
|
|
741
1220
|
'name': node.name,
|
|
742
1221
|
'name_embedding': node.name_embedding,
|
|
743
|
-
'fulltext_query': fulltext_query(node.name, [node.group_id]),
|
|
1222
|
+
'fulltext_query': fulltext_query(node.name, [node.group_id], driver),
|
|
744
1223
|
}
|
|
745
1224
|
for node in nodes
|
|
746
1225
|
]
|
|
747
1226
|
|
|
1227
|
+
filter_queries, filter_params = node_search_filter_query_constructor(
|
|
1228
|
+
search_filter, driver.provider
|
|
1229
|
+
)
|
|
1230
|
+
|
|
1231
|
+
filter_query = ''
|
|
1232
|
+
if filter_queries:
|
|
1233
|
+
filter_query = 'WHERE ' + (' AND '.join(filter_queries))
|
|
1234
|
+
|
|
1235
|
+
if driver.provider == GraphProvider.KUZU:
|
|
1236
|
+
embedding_size = len(nodes[0].name_embedding) if nodes[0].name_embedding is not None else 0
|
|
1237
|
+
if embedding_size == 0:
|
|
1238
|
+
return []
|
|
1239
|
+
|
|
1240
|
+
# FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
|
|
1241
|
+
query = (
|
|
1242
|
+
"""
|
|
1243
|
+
UNWIND $nodes AS node
|
|
1244
|
+
MATCH (n:Entity {group_id: $group_id})
|
|
1245
|
+
"""
|
|
1246
|
+
+ filter_query
|
|
1247
|
+
+ """
|
|
1248
|
+
WITH node, n, """
|
|
1249
|
+
+ get_vector_cosine_func_query(
|
|
1250
|
+
'n.name_embedding',
|
|
1251
|
+
f'CAST(node.name_embedding AS FLOAT[{embedding_size}])',
|
|
1252
|
+
driver.provider,
|
|
1253
|
+
)
|
|
1254
|
+
+ """ AS score
|
|
1255
|
+
WHERE score > $min_score
|
|
1256
|
+
WITH node, collect(n)[:$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
|
|
1257
|
+
"""
|
|
1258
|
+
+ get_nodes_query(
|
|
1259
|
+
'node_name_and_summary',
|
|
1260
|
+
'node.fulltext_query',
|
|
1261
|
+
limit=limit,
|
|
1262
|
+
provider=driver.provider,
|
|
1263
|
+
)
|
|
1264
|
+
+ """
|
|
1265
|
+
WITH node AS m
|
|
1266
|
+
WHERE m.group_id = $group_id AND NOT m.uuid IN vector_node_uuids
|
|
1267
|
+
WITH node, top_vector_nodes, collect(m) AS fulltext_nodes
|
|
1268
|
+
|
|
1269
|
+
WITH node, list_concat(top_vector_nodes, fulltext_nodes) AS combined_nodes
|
|
1270
|
+
|
|
1271
|
+
UNWIND combined_nodes AS x
|
|
1272
|
+
WITH node, collect(DISTINCT {
|
|
1273
|
+
uuid: x.uuid,
|
|
1274
|
+
name: x.name,
|
|
1275
|
+
name_embedding: x.name_embedding,
|
|
1276
|
+
group_id: x.group_id,
|
|
1277
|
+
created_at: x.created_at,
|
|
1278
|
+
summary: x.summary,
|
|
1279
|
+
labels: x.labels,
|
|
1280
|
+
attributes: x.attributes
|
|
1281
|
+
}) AS matches
|
|
1282
|
+
|
|
1283
|
+
RETURN
|
|
1284
|
+
node.uuid AS search_node_uuid, matches
|
|
1285
|
+
"""
|
|
1286
|
+
)
|
|
1287
|
+
else:
|
|
1288
|
+
query = (
|
|
1289
|
+
"""
|
|
1290
|
+
UNWIND $nodes AS node
|
|
1291
|
+
MATCH (n:Entity {group_id: $group_id})
|
|
1292
|
+
"""
|
|
1293
|
+
+ filter_query
|
|
1294
|
+
+ """
|
|
1295
|
+
WITH node, n, """
|
|
1296
|
+
+ get_vector_cosine_func_query(
|
|
1297
|
+
'n.name_embedding', 'node.name_embedding', driver.provider
|
|
1298
|
+
)
|
|
1299
|
+
+ """ AS score
|
|
1300
|
+
WHERE score > $min_score
|
|
1301
|
+
WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
|
|
1302
|
+
"""
|
|
1303
|
+
+ get_nodes_query(
|
|
1304
|
+
'node_name_and_summary',
|
|
1305
|
+
'node.fulltext_query',
|
|
1306
|
+
limit=limit,
|
|
1307
|
+
provider=driver.provider,
|
|
1308
|
+
)
|
|
1309
|
+
+ """
|
|
1310
|
+
YIELD node AS m
|
|
1311
|
+
WHERE m.group_id = $group_id
|
|
1312
|
+
WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
|
|
1313
|
+
|
|
1314
|
+
WITH node,
|
|
1315
|
+
top_vector_nodes,
|
|
1316
|
+
[m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
|
|
1317
|
+
|
|
1318
|
+
WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
|
|
1319
|
+
|
|
1320
|
+
UNWIND combined_nodes AS combined_node
|
|
1321
|
+
WITH node, collect(DISTINCT combined_node) AS deduped_nodes
|
|
1322
|
+
|
|
1323
|
+
RETURN
|
|
1324
|
+
node.uuid AS search_node_uuid,
|
|
1325
|
+
[x IN deduped_nodes | {
|
|
1326
|
+
uuid: x.uuid,
|
|
1327
|
+
name: x.name,
|
|
1328
|
+
name_embedding: x.name_embedding,
|
|
1329
|
+
group_id: x.group_id,
|
|
1330
|
+
created_at: x.created_at,
|
|
1331
|
+
summary: x.summary,
|
|
1332
|
+
labels: labels(x),
|
|
1333
|
+
attributes: properties(x)
|
|
1334
|
+
}] AS matches
|
|
1335
|
+
"""
|
|
1336
|
+
)
|
|
1337
|
+
|
|
748
1338
|
results, _, _ = await driver.execute_query(
|
|
749
1339
|
query,
|
|
750
|
-
params=query_params,
|
|
751
1340
|
nodes=query_nodes,
|
|
752
1341
|
group_id=group_id,
|
|
753
1342
|
limit=limit,
|
|
754
1343
|
min_score=min_score,
|
|
755
1344
|
routing_='r',
|
|
1345
|
+
**filter_params,
|
|
756
1346
|
)
|
|
757
1347
|
|
|
758
1348
|
relevant_nodes_dict: dict[str, list[EntityNode]] = {
|
|
759
1349
|
result['search_node_uuid']: [
|
|
760
|
-
get_entity_node_from_record(record) for record in result['matches']
|
|
1350
|
+
get_entity_node_from_record(record, driver.provider) for record in result['matches']
|
|
761
1351
|
]
|
|
762
1352
|
for result in results
|
|
763
1353
|
}
|
|
@@ -777,25 +1367,52 @@ async def get_relevant_edges(
|
|
|
777
1367
|
if len(edges) == 0:
|
|
778
1368
|
return []
|
|
779
1369
|
|
|
780
|
-
|
|
1370
|
+
filter_queries, filter_params = edge_search_filter_query_constructor(
|
|
1371
|
+
search_filter, driver.provider
|
|
1372
|
+
)
|
|
781
1373
|
|
|
782
|
-
filter_query
|
|
783
|
-
|
|
1374
|
+
filter_query = ''
|
|
1375
|
+
if filter_queries:
|
|
1376
|
+
filter_query = ' WHERE ' + (' AND '.join(filter_queries))
|
|
784
1377
|
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
1378
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
1379
|
+
query = (
|
|
1380
|
+
"""
|
|
1381
|
+
UNWIND $edges AS edge
|
|
1382
|
+
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
|
1383
|
+
"""
|
|
1384
|
+
+ filter_query
|
|
1385
|
+
+ """
|
|
1386
|
+
WITH e, edge
|
|
1387
|
+
RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding, edge.uuid as search_edge_uuid,
|
|
1388
|
+
edge.fact_embedding as target_embedding
|
|
1389
|
+
"""
|
|
1390
|
+
)
|
|
1391
|
+
resp, _, _ = await driver.execute_query(
|
|
1392
|
+
query,
|
|
1393
|
+
edges=[edge.model_dump() for edge in edges],
|
|
1394
|
+
limit=limit,
|
|
1395
|
+
min_score=min_score,
|
|
1396
|
+
routing_='r',
|
|
1397
|
+
**filter_params,
|
|
1398
|
+
)
|
|
1399
|
+
|
|
1400
|
+
# Calculate Cosine similarity then return the edge ids
|
|
1401
|
+
input_ids = []
|
|
1402
|
+
for r in resp:
|
|
1403
|
+
score = calculate_cosine_similarity(
|
|
1404
|
+
list(map(float, r['source_embedding'].split(','))), r['target_embedding']
|
|
1405
|
+
)
|
|
1406
|
+
if score > min_score:
|
|
1407
|
+
input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})
|
|
1408
|
+
|
|
1409
|
+
# Match the edge ides and return the values
|
|
1410
|
+
query = """
|
|
1411
|
+
UNWIND $ids AS edge
|
|
1412
|
+
MATCH ()-[e]->()
|
|
1413
|
+
WHERE id(e) = edge.id
|
|
1414
|
+
WITH edge, e
|
|
1415
|
+
ORDER BY edge.score DESC
|
|
799
1416
|
RETURN edge.uuid AS search_edge_uuid,
|
|
800
1417
|
collect({
|
|
801
1418
|
uuid: e.uuid,
|
|
@@ -805,28 +1422,117 @@ async def get_relevant_edges(
|
|
|
805
1422
|
name: e.name,
|
|
806
1423
|
group_id: e.group_id,
|
|
807
1424
|
fact: e.fact,
|
|
808
|
-
fact_embedding: e.fact_embedding,
|
|
809
|
-
episodes: e.episodes,
|
|
1425
|
+
fact_embedding: [x IN split(e.fact_embedding, ",") | toFloat(x)],
|
|
1426
|
+
episodes: split(e.episodes, ","),
|
|
810
1427
|
expired_at: e.expired_at,
|
|
811
1428
|
valid_at: e.valid_at,
|
|
812
1429
|
invalid_at: e.invalid_at,
|
|
813
1430
|
attributes: properties(e)
|
|
814
1431
|
})[..$limit] AS matches
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
1432
|
+
"""
|
|
1433
|
+
|
|
1434
|
+
results, _, _ = await driver.execute_query(
|
|
1435
|
+
query,
|
|
1436
|
+
ids=input_ids,
|
|
1437
|
+
edges=[edge.model_dump() for edge in edges],
|
|
1438
|
+
limit=limit,
|
|
1439
|
+
min_score=min_score,
|
|
1440
|
+
routing_='r',
|
|
1441
|
+
**filter_params,
|
|
1442
|
+
)
|
|
1443
|
+
else:
|
|
1444
|
+
if driver.provider == GraphProvider.KUZU:
|
|
1445
|
+
embedding_size = (
|
|
1446
|
+
len(edges[0].fact_embedding) if edges[0].fact_embedding is not None else 0
|
|
1447
|
+
)
|
|
1448
|
+
if embedding_size == 0:
|
|
1449
|
+
return []
|
|
1450
|
+
|
|
1451
|
+
query = (
|
|
1452
|
+
"""
|
|
1453
|
+
UNWIND $edges AS edge
|
|
1454
|
+
MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
|
|
1455
|
+
"""
|
|
1456
|
+
+ filter_query
|
|
1457
|
+
+ """
|
|
1458
|
+
WITH e, edge, n, m, """
|
|
1459
|
+
+ get_vector_cosine_func_query(
|
|
1460
|
+
'e.fact_embedding',
|
|
1461
|
+
f'CAST(edge.fact_embedding AS FLOAT[{embedding_size}])',
|
|
1462
|
+
driver.provider,
|
|
1463
|
+
)
|
|
1464
|
+
+ """ AS score
|
|
1465
|
+
WHERE score > $min_score
|
|
1466
|
+
WITH e, edge, n, m, score
|
|
1467
|
+
ORDER BY score DESC
|
|
1468
|
+
LIMIT $limit
|
|
1469
|
+
RETURN
|
|
1470
|
+
edge.uuid AS search_edge_uuid,
|
|
1471
|
+
collect({
|
|
1472
|
+
uuid: e.uuid,
|
|
1473
|
+
source_node_uuid: n.uuid,
|
|
1474
|
+
target_node_uuid: m.uuid,
|
|
1475
|
+
created_at: e.created_at,
|
|
1476
|
+
name: e.name,
|
|
1477
|
+
group_id: e.group_id,
|
|
1478
|
+
fact: e.fact,
|
|
1479
|
+
fact_embedding: e.fact_embedding,
|
|
1480
|
+
episodes: e.episodes,
|
|
1481
|
+
expired_at: e.expired_at,
|
|
1482
|
+
valid_at: e.valid_at,
|
|
1483
|
+
invalid_at: e.invalid_at,
|
|
1484
|
+
attributes: e.attributes
|
|
1485
|
+
}) AS matches
|
|
1486
|
+
"""
|
|
1487
|
+
)
|
|
1488
|
+
else:
|
|
1489
|
+
query = (
|
|
1490
|
+
"""
|
|
1491
|
+
UNWIND $edges AS edge
|
|
1492
|
+
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
|
1493
|
+
"""
|
|
1494
|
+
+ filter_query
|
|
1495
|
+
+ """
|
|
1496
|
+
WITH e, edge, """
|
|
1497
|
+
+ get_vector_cosine_func_query(
|
|
1498
|
+
'e.fact_embedding', 'edge.fact_embedding', driver.provider
|
|
1499
|
+
)
|
|
1500
|
+
+ """ AS score
|
|
1501
|
+
WHERE score > $min_score
|
|
1502
|
+
WITH edge, e, score
|
|
1503
|
+
ORDER BY score DESC
|
|
1504
|
+
RETURN
|
|
1505
|
+
edge.uuid AS search_edge_uuid,
|
|
1506
|
+
collect({
|
|
1507
|
+
uuid: e.uuid,
|
|
1508
|
+
source_node_uuid: startNode(e).uuid,
|
|
1509
|
+
target_node_uuid: endNode(e).uuid,
|
|
1510
|
+
created_at: e.created_at,
|
|
1511
|
+
name: e.name,
|
|
1512
|
+
group_id: e.group_id,
|
|
1513
|
+
fact: e.fact,
|
|
1514
|
+
fact_embedding: e.fact_embedding,
|
|
1515
|
+
episodes: e.episodes,
|
|
1516
|
+
expired_at: e.expired_at,
|
|
1517
|
+
valid_at: e.valid_at,
|
|
1518
|
+
invalid_at: e.invalid_at,
|
|
1519
|
+
attributes: properties(e)
|
|
1520
|
+
})[..$limit] AS matches
|
|
1521
|
+
"""
|
|
1522
|
+
)
|
|
1523
|
+
|
|
1524
|
+
results, _, _ = await driver.execute_query(
|
|
1525
|
+
query,
|
|
1526
|
+
edges=[edge.model_dump() for edge in edges],
|
|
1527
|
+
limit=limit,
|
|
1528
|
+
min_score=min_score,
|
|
1529
|
+
routing_='r',
|
|
1530
|
+
**filter_params,
|
|
1531
|
+
)
|
|
826
1532
|
|
|
827
1533
|
relevant_edges_dict: dict[str, list[EntityEdge]] = {
|
|
828
1534
|
result['search_edge_uuid']: [
|
|
829
|
-
get_entity_edge_from_record(record) for record in result['matches']
|
|
1535
|
+
get_entity_edge_from_record(record, driver.provider) for record in result['matches']
|
|
830
1536
|
]
|
|
831
1537
|
for result in results
|
|
832
1538
|
}
|
|
@@ -846,26 +1552,54 @@ async def get_edge_invalidation_candidates(
|
|
|
846
1552
|
if len(edges) == 0:
|
|
847
1553
|
return []
|
|
848
1554
|
|
|
849
|
-
|
|
1555
|
+
filter_queries, filter_params = edge_search_filter_query_constructor(
|
|
1556
|
+
search_filter, driver.provider
|
|
1557
|
+
)
|
|
850
1558
|
|
|
851
|
-
filter_query
|
|
852
|
-
|
|
1559
|
+
filter_query = ''
|
|
1560
|
+
if filter_queries:
|
|
1561
|
+
filter_query = ' AND ' + (' AND '.join(filter_queries))
|
|
853
1562
|
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
868
|
-
|
|
1563
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
1564
|
+
query = (
|
|
1565
|
+
"""
|
|
1566
|
+
UNWIND $edges AS edge
|
|
1567
|
+
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
|
1568
|
+
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
|
1569
|
+
"""
|
|
1570
|
+
+ filter_query
|
|
1571
|
+
+ """
|
|
1572
|
+
WITH e, edge
|
|
1573
|
+
RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding,
|
|
1574
|
+
edge.fact_embedding as target_embedding,
|
|
1575
|
+
edge.uuid as search_edge_uuid
|
|
1576
|
+
"""
|
|
1577
|
+
)
|
|
1578
|
+
resp, _, _ = await driver.execute_query(
|
|
1579
|
+
query,
|
|
1580
|
+
edges=[edge.model_dump() for edge in edges],
|
|
1581
|
+
limit=limit,
|
|
1582
|
+
min_score=min_score,
|
|
1583
|
+
routing_='r',
|
|
1584
|
+
**filter_params,
|
|
1585
|
+
)
|
|
1586
|
+
|
|
1587
|
+
# Calculate Cosine similarity then return the edge ids
|
|
1588
|
+
input_ids = []
|
|
1589
|
+
for r in resp:
|
|
1590
|
+
score = calculate_cosine_similarity(
|
|
1591
|
+
list(map(float, r['source_embedding'].split(','))), r['target_embedding']
|
|
1592
|
+
)
|
|
1593
|
+
if score > min_score:
|
|
1594
|
+
input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})
|
|
1595
|
+
|
|
1596
|
+
# Match the edge ides and return the values
|
|
1597
|
+
query = """
|
|
1598
|
+
UNWIND $ids AS edge
|
|
1599
|
+
MATCH ()-[e]->()
|
|
1600
|
+
WHERE id(e) = edge.id
|
|
1601
|
+
WITH edge, e
|
|
1602
|
+
ORDER BY edge.score DESC
|
|
869
1603
|
RETURN edge.uuid AS search_edge_uuid,
|
|
870
1604
|
collect({
|
|
871
1605
|
uuid: e.uuid,
|
|
@@ -875,27 +1609,117 @@ async def get_edge_invalidation_candidates(
|
|
|
875
1609
|
name: e.name,
|
|
876
1610
|
group_id: e.group_id,
|
|
877
1611
|
fact: e.fact,
|
|
878
|
-
fact_embedding: e.fact_embedding,
|
|
879
|
-
episodes: e.episodes,
|
|
1612
|
+
fact_embedding: [x IN split(e.fact_embedding, ",") | toFloat(x)],
|
|
1613
|
+
episodes: split(e.episodes, ","),
|
|
880
1614
|
expired_at: e.expired_at,
|
|
881
1615
|
valid_at: e.valid_at,
|
|
882
1616
|
invalid_at: e.invalid_at,
|
|
883
1617
|
attributes: properties(e)
|
|
884
1618
|
})[..$limit] AS matches
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
1619
|
+
"""
|
|
1620
|
+
results, _, _ = await driver.execute_query(
|
|
1621
|
+
query,
|
|
1622
|
+
ids=input_ids,
|
|
1623
|
+
edges=[edge.model_dump() for edge in edges],
|
|
1624
|
+
limit=limit,
|
|
1625
|
+
min_score=min_score,
|
|
1626
|
+
routing_='r',
|
|
1627
|
+
**filter_params,
|
|
1628
|
+
)
|
|
1629
|
+
else:
|
|
1630
|
+
if driver.provider == GraphProvider.KUZU:
|
|
1631
|
+
embedding_size = (
|
|
1632
|
+
len(edges[0].fact_embedding) if edges[0].fact_embedding is not None else 0
|
|
1633
|
+
)
|
|
1634
|
+
if embedding_size == 0:
|
|
1635
|
+
return []
|
|
1636
|
+
|
|
1637
|
+
query = (
|
|
1638
|
+
"""
|
|
1639
|
+
UNWIND $edges AS edge
|
|
1640
|
+
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
|
|
1641
|
+
WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
|
|
1642
|
+
"""
|
|
1643
|
+
+ filter_query
|
|
1644
|
+
+ """
|
|
1645
|
+
WITH edge, e, n, m, """
|
|
1646
|
+
+ get_vector_cosine_func_query(
|
|
1647
|
+
'e.fact_embedding',
|
|
1648
|
+
f'CAST(edge.fact_embedding AS FLOAT[{embedding_size}])',
|
|
1649
|
+
driver.provider,
|
|
1650
|
+
)
|
|
1651
|
+
+ """ AS score
|
|
1652
|
+
WHERE score > $min_score
|
|
1653
|
+
WITH edge, e, n, m, score
|
|
1654
|
+
ORDER BY score DESC
|
|
1655
|
+
LIMIT $limit
|
|
1656
|
+
RETURN
|
|
1657
|
+
edge.uuid AS search_edge_uuid,
|
|
1658
|
+
collect({
|
|
1659
|
+
uuid: e.uuid,
|
|
1660
|
+
source_node_uuid: n.uuid,
|
|
1661
|
+
target_node_uuid: m.uuid,
|
|
1662
|
+
created_at: e.created_at,
|
|
1663
|
+
name: e.name,
|
|
1664
|
+
group_id: e.group_id,
|
|
1665
|
+
fact: e.fact,
|
|
1666
|
+
fact_embedding: e.fact_embedding,
|
|
1667
|
+
episodes: e.episodes,
|
|
1668
|
+
expired_at: e.expired_at,
|
|
1669
|
+
valid_at: e.valid_at,
|
|
1670
|
+
invalid_at: e.invalid_at,
|
|
1671
|
+
attributes: e.attributes
|
|
1672
|
+
}) AS matches
|
|
1673
|
+
"""
|
|
1674
|
+
)
|
|
1675
|
+
else:
|
|
1676
|
+
query = (
|
|
1677
|
+
"""
|
|
1678
|
+
UNWIND $edges AS edge
|
|
1679
|
+
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
|
1680
|
+
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
|
1681
|
+
"""
|
|
1682
|
+
+ filter_query
|
|
1683
|
+
+ """
|
|
1684
|
+
WITH edge, e, """
|
|
1685
|
+
+ get_vector_cosine_func_query(
|
|
1686
|
+
'e.fact_embedding', 'edge.fact_embedding', driver.provider
|
|
1687
|
+
)
|
|
1688
|
+
+ """ AS score
|
|
1689
|
+
WHERE score > $min_score
|
|
1690
|
+
WITH edge, e, score
|
|
1691
|
+
ORDER BY score DESC
|
|
1692
|
+
RETURN
|
|
1693
|
+
edge.uuid AS search_edge_uuid,
|
|
1694
|
+
collect({
|
|
1695
|
+
uuid: e.uuid,
|
|
1696
|
+
source_node_uuid: startNode(e).uuid,
|
|
1697
|
+
target_node_uuid: endNode(e).uuid,
|
|
1698
|
+
created_at: e.created_at,
|
|
1699
|
+
name: e.name,
|
|
1700
|
+
group_id: e.group_id,
|
|
1701
|
+
fact: e.fact,
|
|
1702
|
+
fact_embedding: e.fact_embedding,
|
|
1703
|
+
episodes: e.episodes,
|
|
1704
|
+
expired_at: e.expired_at,
|
|
1705
|
+
valid_at: e.valid_at,
|
|
1706
|
+
invalid_at: e.invalid_at,
|
|
1707
|
+
attributes: properties(e)
|
|
1708
|
+
})[..$limit] AS matches
|
|
1709
|
+
"""
|
|
1710
|
+
)
|
|
1711
|
+
|
|
1712
|
+
results, _, _ = await driver.execute_query(
|
|
1713
|
+
query,
|
|
1714
|
+
edges=[edge.model_dump() for edge in edges],
|
|
1715
|
+
limit=limit,
|
|
1716
|
+
min_score=min_score,
|
|
1717
|
+
routing_='r',
|
|
1718
|
+
**filter_params,
|
|
1719
|
+
)
|
|
896
1720
|
invalidation_edges_dict: dict[str, list[EntityEdge]] = {
|
|
897
1721
|
result['search_edge_uuid']: [
|
|
898
|
-
get_entity_edge_from_record(record) for record in result['matches']
|
|
1722
|
+
get_entity_edge_from_record(record, driver.provider) for record in result['matches']
|
|
899
1723
|
]
|
|
900
1724
|
for result in results
|
|
901
1725
|
}
|
|
@@ -906,7 +1730,9 @@ async def get_edge_invalidation_candidates(
|
|
|
906
1730
|
|
|
907
1731
|
|
|
908
1732
|
# takes in a list of rankings of uuids
|
|
909
|
-
def rrf(
|
|
1733
|
+
def rrf(
|
|
1734
|
+
results: list[list[str]], rank_const=1, min_score: float = 0
|
|
1735
|
+
) -> tuple[list[str], list[float]]:
|
|
910
1736
|
scores: dict[str, float] = defaultdict(float)
|
|
911
1737
|
for result in results:
|
|
912
1738
|
for i, uuid in enumerate(result):
|
|
@@ -917,7 +1743,9 @@ def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[st
|
|
|
917
1743
|
|
|
918
1744
|
sorted_uuids = [term[0] for term in scored_uuids]
|
|
919
1745
|
|
|
920
|
-
return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
|
|
1746
|
+
return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [
|
|
1747
|
+
scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score
|
|
1748
|
+
]
|
|
921
1749
|
|
|
922
1750
|
|
|
923
1751
|
async def node_distance_reranker(
|
|
@@ -925,24 +1753,31 @@ async def node_distance_reranker(
|
|
|
925
1753
|
node_uuids: list[str],
|
|
926
1754
|
center_node_uuid: str,
|
|
927
1755
|
min_score: float = 0,
|
|
928
|
-
) -> list[str]:
|
|
1756
|
+
) -> tuple[list[str], list[float]]:
|
|
929
1757
|
# filter out node_uuid center node node uuid
|
|
930
1758
|
filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
|
|
931
1759
|
scores: dict[str, float] = {center_node_uuid: 0.0}
|
|
932
1760
|
|
|
933
|
-
# Find the shortest path to center node
|
|
934
1761
|
query = """
|
|
1762
|
+
UNWIND $node_uuids AS node_uuid
|
|
1763
|
+
MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
|
|
1764
|
+
RETURN 1 AS score, node_uuid AS uuid
|
|
1765
|
+
"""
|
|
1766
|
+
if driver.provider == GraphProvider.KUZU:
|
|
1767
|
+
query = """
|
|
935
1768
|
UNWIND $node_uuids AS node_uuid
|
|
936
|
-
MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
|
|
1769
|
+
MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(n:Entity {uuid: node_uuid})
|
|
937
1770
|
RETURN 1 AS score, node_uuid AS uuid
|
|
938
1771
|
"""
|
|
1772
|
+
|
|
1773
|
+
# Find the shortest path to center node
|
|
939
1774
|
results, header, _ = await driver.execute_query(
|
|
940
1775
|
query,
|
|
941
1776
|
node_uuids=filtered_uuids,
|
|
942
1777
|
center_uuid=center_node_uuid,
|
|
943
1778
|
routing_='r',
|
|
944
1779
|
)
|
|
945
|
-
if driver.provider ==
|
|
1780
|
+
if driver.provider == GraphProvider.FALKORDB:
|
|
946
1781
|
results = [dict(zip(header, row, strict=True)) for row in results]
|
|
947
1782
|
|
|
948
1783
|
for result in results:
|
|
@@ -962,24 +1797,25 @@ async def node_distance_reranker(
|
|
|
962
1797
|
scores[center_node_uuid] = 0.1
|
|
963
1798
|
filtered_uuids = [center_node_uuid] + filtered_uuids
|
|
964
1799
|
|
|
965
|
-
return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score]
|
|
1800
|
+
return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score], [
|
|
1801
|
+
1 / scores[uuid] for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score
|
|
1802
|
+
]
|
|
966
1803
|
|
|
967
1804
|
|
|
968
1805
|
async def episode_mentions_reranker(
|
|
969
1806
|
driver: GraphDriver, node_uuids: list[list[str]], min_score: float = 0
|
|
970
|
-
) -> list[str]:
|
|
1807
|
+
) -> tuple[list[str], list[float]]:
|
|
971
1808
|
# use rrf as a preliminary ranker
|
|
972
|
-
sorted_uuids = rrf(node_uuids)
|
|
1809
|
+
sorted_uuids, _ = rrf(node_uuids)
|
|
973
1810
|
scores: dict[str, float] = {}
|
|
974
1811
|
|
|
975
1812
|
# Find the shortest path to center node
|
|
976
|
-
|
|
977
|
-
|
|
1813
|
+
results, _, _ = await driver.execute_query(
|
|
1814
|
+
"""
|
|
1815
|
+
UNWIND $node_uuids AS node_uuid
|
|
978
1816
|
MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
|
|
979
1817
|
RETURN count(*) AS score, n.uuid AS uuid
|
|
980
|
-
"""
|
|
981
|
-
results, _, _ = await driver.execute_query(
|
|
982
|
-
query,
|
|
1818
|
+
""",
|
|
983
1819
|
node_uuids=sorted_uuids,
|
|
984
1820
|
routing_='r',
|
|
985
1821
|
)
|
|
@@ -987,10 +1823,16 @@ async def episode_mentions_reranker(
|
|
|
987
1823
|
for result in results:
|
|
988
1824
|
scores[result['uuid']] = result['score']
|
|
989
1825
|
|
|
1826
|
+
for uuid in sorted_uuids:
|
|
1827
|
+
if uuid not in scores:
|
|
1828
|
+
scores[uuid] = float('inf')
|
|
1829
|
+
|
|
990
1830
|
# rerank on shortest distance
|
|
991
1831
|
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
|
992
1832
|
|
|
993
|
-
return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
|
|
1833
|
+
return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [
|
|
1834
|
+
scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score
|
|
1835
|
+
]
|
|
994
1836
|
|
|
995
1837
|
|
|
996
1838
|
def maximal_marginal_relevance(
|
|
@@ -998,7 +1840,7 @@ def maximal_marginal_relevance(
|
|
|
998
1840
|
candidates: dict[str, list[float]],
|
|
999
1841
|
mmr_lambda: float = DEFAULT_MMR_LAMBDA,
|
|
1000
1842
|
min_score: float = -2.0,
|
|
1001
|
-
) -> list[str]:
|
|
1843
|
+
) -> tuple[list[str], list[float]]:
|
|
1002
1844
|
start = time()
|
|
1003
1845
|
query_array = np.array(query_vector)
|
|
1004
1846
|
candidate_arrays: dict[str, NDArray] = {}
|
|
@@ -1029,21 +1871,36 @@ def maximal_marginal_relevance(
|
|
|
1029
1871
|
end = time()
|
|
1030
1872
|
logger.debug(f'Completed MMR reranking in {(end - start) * 1000} ms')
|
|
1031
1873
|
|
|
1032
|
-
return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score]
|
|
1874
|
+
return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score], [
|
|
1875
|
+
mmr_scores[uuid] for uuid in uuids if mmr_scores[uuid] >= min_score
|
|
1876
|
+
]
|
|
1033
1877
|
|
|
1034
1878
|
|
|
1035
1879
|
async def get_embeddings_for_nodes(
|
|
1036
1880
|
driver: GraphDriver, nodes: list[EntityNode]
|
|
1037
1881
|
) -> dict[str, list[float]]:
|
|
1038
|
-
|
|
1039
|
-
|
|
1040
|
-
|
|
1041
|
-
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
|
|
1882
|
+
if driver.graph_operations_interface:
|
|
1883
|
+
return await driver.graph_operations_interface.node_load_embeddings_bulk(driver, nodes)
|
|
1884
|
+
elif driver.provider == GraphProvider.NEPTUNE:
|
|
1885
|
+
query = """
|
|
1886
|
+
MATCH (n:Entity)
|
|
1887
|
+
WHERE n.uuid IN $node_uuids
|
|
1888
|
+
RETURN DISTINCT
|
|
1889
|
+
n.uuid AS uuid,
|
|
1890
|
+
split(n.name_embedding, ",") AS name_embedding
|
|
1891
|
+
"""
|
|
1892
|
+
else:
|
|
1893
|
+
query = """
|
|
1894
|
+
MATCH (n:Entity)
|
|
1895
|
+
WHERE n.uuid IN $node_uuids
|
|
1896
|
+
RETURN DISTINCT
|
|
1897
|
+
n.uuid AS uuid,
|
|
1898
|
+
n.name_embedding AS name_embedding
|
|
1899
|
+
"""
|
|
1045
1900
|
results, _, _ = await driver.execute_query(
|
|
1046
|
-
query,
|
|
1901
|
+
query,
|
|
1902
|
+
node_uuids=[node.uuid for node in nodes],
|
|
1903
|
+
routing_='r',
|
|
1047
1904
|
)
|
|
1048
1905
|
|
|
1049
1906
|
embeddings_dict: dict[str, list[float]] = {}
|
|
@@ -1059,13 +1916,22 @@ async def get_embeddings_for_nodes(
|
|
|
1059
1916
|
async def get_embeddings_for_communities(
|
|
1060
1917
|
driver: GraphDriver, communities: list[CommunityNode]
|
|
1061
1918
|
) -> dict[str, list[float]]:
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
-
|
|
1919
|
+
if driver.provider == GraphProvider.NEPTUNE:
|
|
1920
|
+
query = """
|
|
1921
|
+
MATCH (c:Community)
|
|
1922
|
+
WHERE c.uuid IN $community_uuids
|
|
1923
|
+
RETURN DISTINCT
|
|
1924
|
+
c.uuid AS uuid,
|
|
1925
|
+
split(c.name_embedding, ",") AS name_embedding
|
|
1926
|
+
"""
|
|
1927
|
+
else:
|
|
1928
|
+
query = """
|
|
1929
|
+
MATCH (c:Community)
|
|
1930
|
+
WHERE c.uuid IN $community_uuids
|
|
1931
|
+
RETURN DISTINCT
|
|
1932
|
+
c.uuid AS uuid,
|
|
1933
|
+
c.name_embedding AS name_embedding
|
|
1934
|
+
"""
|
|
1069
1935
|
results, _, _ = await driver.execute_query(
|
|
1070
1936
|
query,
|
|
1071
1937
|
community_uuids=[community.uuid for community in communities],
|
|
@@ -1085,13 +1951,34 @@ async def get_embeddings_for_communities(
|
|
|
1085
1951
|
async def get_embeddings_for_edges(
|
|
1086
1952
|
driver: GraphDriver, edges: list[EntityEdge]
|
|
1087
1953
|
) -> dict[str, list[float]]:
|
|
1088
|
-
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
|
|
1092
|
-
|
|
1093
|
-
|
|
1954
|
+
if driver.graph_operations_interface:
|
|
1955
|
+
return await driver.graph_operations_interface.edge_load_embeddings_bulk(driver, edges)
|
|
1956
|
+
elif driver.provider == GraphProvider.NEPTUNE:
|
|
1957
|
+
query = """
|
|
1958
|
+
MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
|
1959
|
+
WHERE e.uuid IN $edge_uuids
|
|
1960
|
+
RETURN DISTINCT
|
|
1961
|
+
e.uuid AS uuid,
|
|
1962
|
+
split(e.fact_embedding, ",") AS fact_embedding
|
|
1963
|
+
"""
|
|
1964
|
+
else:
|
|
1965
|
+
match_query = """
|
|
1966
|
+
MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
|
1967
|
+
"""
|
|
1968
|
+
if driver.provider == GraphProvider.KUZU:
|
|
1969
|
+
match_query = """
|
|
1970
|
+
MATCH (n:Entity)-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(m:Entity)
|
|
1971
|
+
"""
|
|
1094
1972
|
|
|
1973
|
+
query = (
|
|
1974
|
+
match_query
|
|
1975
|
+
+ """
|
|
1976
|
+
WHERE e.uuid IN $edge_uuids
|
|
1977
|
+
RETURN DISTINCT
|
|
1978
|
+
e.uuid AS uuid,
|
|
1979
|
+
e.fact_embedding AS fact_embedding
|
|
1980
|
+
"""
|
|
1981
|
+
)
|
|
1095
1982
|
results, _, _ = await driver.execute_query(
|
|
1096
1983
|
query,
|
|
1097
1984
|
edge_uuids=[edge.uuid for edge in edges],
|