graphiti-core 0.11.6rc7__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
- graphiti_core/driver/__init__.py +17 -0
- graphiti_core/driver/driver.py +66 -0
- graphiti_core/driver/falkordb_driver.py +132 -0
- graphiti_core/driver/neo4j_driver.py +61 -0
- graphiti_core/edges.py +66 -40
- graphiti_core/embedder/azure_openai.py +64 -0
- graphiti_core/embedder/gemini.py +14 -3
- graphiti_core/graph_queries.py +149 -0
- graphiti_core/graphiti.py +41 -14
- graphiti_core/graphiti_types.py +2 -2
- graphiti_core/helpers.py +17 -30
- graphiti_core/llm_client/__init__.py +16 -0
- graphiti_core/llm_client/azure_openai_client.py +73 -0
- graphiti_core/llm_client/gemini_client.py +4 -1
- graphiti_core/models/edges/edge_db_queries.py +2 -4
- graphiti_core/nodes.py +31 -31
- graphiti_core/prompts/dedupe_edges.py +52 -1
- graphiti_core/prompts/dedupe_nodes.py +79 -4
- graphiti_core/prompts/extract_edges.py +50 -5
- graphiti_core/prompts/invalidate_edges.py +1 -1
- graphiti_core/search/search.py +25 -55
- graphiti_core/search/search_filters.py +23 -9
- graphiti_core/search/search_utils.py +360 -195
- graphiti_core/utils/bulk_utils.py +38 -11
- graphiti_core/utils/maintenance/community_operations.py +6 -7
- graphiti_core/utils/maintenance/edge_operations.py +149 -19
- graphiti_core/utils/maintenance/graph_data_operations.py +13 -42
- graphiti_core/utils/maintenance/node_operations.py +52 -71
- {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0.dist-info}/METADATA +14 -5
- {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0.dist-info}/RECORD +33 -26
- {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0.dist-info}/LICENSE +0 -0
- {graphiti_core-0.11.6rc7.dist-info → graphiti_core-0.12.0.dist-info}/WHEEL +0 -0
|
@@ -20,10 +20,16 @@ from time import time
|
|
|
20
20
|
from typing import Any
|
|
21
21
|
|
|
22
22
|
import numpy as np
|
|
23
|
-
from
|
|
23
|
+
from numpy._typing import NDArray
|
|
24
24
|
from typing_extensions import LiteralString
|
|
25
25
|
|
|
26
|
+
from graphiti_core.driver.driver import GraphDriver
|
|
26
27
|
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
|
28
|
+
from graphiti_core.graph_queries import (
|
|
29
|
+
get_nodes_query,
|
|
30
|
+
get_relationships_query,
|
|
31
|
+
get_vector_cosine_func_query,
|
|
32
|
+
)
|
|
27
33
|
from graphiti_core.helpers import (
|
|
28
34
|
DEFAULT_DATABASE,
|
|
29
35
|
RUNTIME_QUERY,
|
|
@@ -57,7 +63,7 @@ MAX_QUERY_LENGTH = 32
|
|
|
57
63
|
|
|
58
64
|
def fulltext_query(query: str, group_ids: list[str] | None = None):
|
|
59
65
|
group_ids_filter_list = (
|
|
60
|
-
[f'
|
|
66
|
+
[f"group_id-'{lucene_sanitize(g)}'" for g in group_ids] if group_ids is not None else []
|
|
61
67
|
)
|
|
62
68
|
group_ids_filter = ''
|
|
63
69
|
for f in group_ids_filter_list:
|
|
@@ -76,7 +82,7 @@ def fulltext_query(query: str, group_ids: list[str] | None = None):
|
|
|
76
82
|
|
|
77
83
|
|
|
78
84
|
async def get_episodes_by_mentions(
|
|
79
|
-
driver:
|
|
85
|
+
driver: GraphDriver,
|
|
80
86
|
nodes: list[EntityNode],
|
|
81
87
|
edges: list[EntityEdge],
|
|
82
88
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
@@ -91,11 +97,11 @@ async def get_episodes_by_mentions(
|
|
|
91
97
|
|
|
92
98
|
|
|
93
99
|
async def get_mentioned_nodes(
|
|
94
|
-
driver:
|
|
100
|
+
driver: GraphDriver, episodes: list[EpisodicNode]
|
|
95
101
|
) -> list[EntityNode]:
|
|
96
102
|
episode_uuids = [episode.uuid for episode in episodes]
|
|
97
|
-
|
|
98
|
-
|
|
103
|
+
|
|
104
|
+
query = """
|
|
99
105
|
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
|
|
100
106
|
RETURN DISTINCT
|
|
101
107
|
n.uuid As uuid,
|
|
@@ -105,7 +111,10 @@ async def get_mentioned_nodes(
|
|
|
105
111
|
n.summary AS summary,
|
|
106
112
|
labels(n) AS labels,
|
|
107
113
|
properties(n) AS attributes
|
|
108
|
-
"""
|
|
114
|
+
"""
|
|
115
|
+
|
|
116
|
+
records, _, _ = await driver.execute_query(
|
|
117
|
+
query,
|
|
109
118
|
uuids=episode_uuids,
|
|
110
119
|
database_=DEFAULT_DATABASE,
|
|
111
120
|
routing_='r',
|
|
@@ -117,11 +126,11 @@ async def get_mentioned_nodes(
|
|
|
117
126
|
|
|
118
127
|
|
|
119
128
|
async def get_communities_by_nodes(
|
|
120
|
-
driver:
|
|
129
|
+
driver: GraphDriver, nodes: list[EntityNode]
|
|
121
130
|
) -> list[CommunityNode]:
|
|
122
131
|
node_uuids = [node.uuid for node in nodes]
|
|
123
|
-
|
|
124
|
-
|
|
132
|
+
|
|
133
|
+
query = """
|
|
125
134
|
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
|
|
126
135
|
RETURN DISTINCT
|
|
127
136
|
c.uuid As uuid,
|
|
@@ -129,7 +138,10 @@ async def get_communities_by_nodes(
|
|
|
129
138
|
c.name AS name,
|
|
130
139
|
c.created_at AS created_at,
|
|
131
140
|
c.summary AS summary
|
|
132
|
-
"""
|
|
141
|
+
"""
|
|
142
|
+
|
|
143
|
+
records, _, _ = await driver.execute_query(
|
|
144
|
+
query,
|
|
133
145
|
uuids=node_uuids,
|
|
134
146
|
database_=DEFAULT_DATABASE,
|
|
135
147
|
routing_='r',
|
|
@@ -141,7 +153,7 @@ async def get_communities_by_nodes(
|
|
|
141
153
|
|
|
142
154
|
|
|
143
155
|
async def edge_fulltext_search(
|
|
144
|
-
driver:
|
|
156
|
+
driver: GraphDriver,
|
|
145
157
|
query: str,
|
|
146
158
|
search_filter: SearchFilters,
|
|
147
159
|
group_ids: list[str] | None = None,
|
|
@@ -154,33 +166,35 @@ async def edge_fulltext_search(
|
|
|
154
166
|
|
|
155
167
|
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
|
156
168
|
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
169
|
+
query = (
|
|
170
|
+
get_relationships_query('edge_name_and_fact', db_type=driver.provider)
|
|
171
|
+
+ """
|
|
172
|
+
YIELD relationship AS rel, score
|
|
173
|
+
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
|
174
|
+
WHERE r.group_id IN $group_ids """
|
|
163
175
|
+ filter_query
|
|
164
|
-
+ """
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
176
|
+
+ """
|
|
177
|
+
WITH r, score, startNode(r) AS n, endNode(r) AS m
|
|
178
|
+
RETURN
|
|
179
|
+
r.uuid AS uuid,
|
|
180
|
+
r.group_id AS group_id,
|
|
181
|
+
n.uuid AS source_node_uuid,
|
|
182
|
+
m.uuid AS target_node_uuid,
|
|
183
|
+
r.created_at AS created_at,
|
|
184
|
+
r.name AS name,
|
|
185
|
+
r.fact AS fact,
|
|
186
|
+
r.episodes AS episodes,
|
|
187
|
+
r.expired_at AS expired_at,
|
|
188
|
+
r.valid_at AS valid_at,
|
|
189
|
+
r.invalid_at AS invalid_at,
|
|
190
|
+
properties(r) AS attributes
|
|
191
|
+
ORDER BY score DESC LIMIT $limit
|
|
192
|
+
"""
|
|
179
193
|
)
|
|
180
194
|
|
|
181
195
|
records, _, _ = await driver.execute_query(
|
|
182
|
-
|
|
183
|
-
filter_params,
|
|
196
|
+
query,
|
|
197
|
+
params=filter_params,
|
|
184
198
|
query=fuzzy_query,
|
|
185
199
|
group_ids=group_ids,
|
|
186
200
|
limit=limit,
|
|
@@ -194,7 +208,7 @@ async def edge_fulltext_search(
|
|
|
194
208
|
|
|
195
209
|
|
|
196
210
|
async def edge_similarity_search(
|
|
197
|
-
driver:
|
|
211
|
+
driver: GraphDriver,
|
|
198
212
|
search_vector: list[float],
|
|
199
213
|
source_node_uuid: str | None,
|
|
200
214
|
target_node_uuid: str | None,
|
|
@@ -209,9 +223,9 @@ async def edge_similarity_search(
|
|
|
209
223
|
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
|
210
224
|
query_params.update(filter_params)
|
|
211
225
|
|
|
212
|
-
group_filter_query: LiteralString = ''
|
|
226
|
+
group_filter_query: LiteralString = 'WHERE r.group_id IS NOT NULL'
|
|
213
227
|
if group_ids is not None:
|
|
214
|
-
group_filter_query += '
|
|
228
|
+
group_filter_query += '\nAND r.group_id IN $group_ids'
|
|
215
229
|
query_params['group_ids'] = group_ids
|
|
216
230
|
query_params['source_node_uuid'] = source_node_uuid
|
|
217
231
|
query_params['target_node_uuid'] = target_node_uuid
|
|
@@ -222,35 +236,38 @@ async def edge_similarity_search(
|
|
|
222
236
|
if target_node_uuid is not None:
|
|
223
237
|
group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
|
|
224
238
|
|
|
225
|
-
query
|
|
239
|
+
query = (
|
|
226
240
|
RUNTIME_QUERY
|
|
227
241
|
+ """
|
|
228
|
-
|
|
229
|
-
|
|
242
|
+
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
|
243
|
+
"""
|
|
230
244
|
+ group_filter_query
|
|
231
245
|
+ filter_query
|
|
232
|
-
+ """
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
246
|
+
+ """
|
|
247
|
+
WITH DISTINCT r, """
|
|
248
|
+
+ get_vector_cosine_func_query('r.fact_embedding', '$search_vector', driver.provider)
|
|
249
|
+
+ """ AS score
|
|
250
|
+
WHERE score > $min_score
|
|
251
|
+
RETURN
|
|
252
|
+
r.uuid AS uuid,
|
|
253
|
+
r.group_id AS group_id,
|
|
254
|
+
startNode(r).uuid AS source_node_uuid,
|
|
255
|
+
endNode(r).uuid AS target_node_uuid,
|
|
256
|
+
r.created_at AS created_at,
|
|
257
|
+
r.name AS name,
|
|
258
|
+
r.fact AS fact,
|
|
259
|
+
r.episodes AS episodes,
|
|
260
|
+
r.expired_at AS expired_at,
|
|
261
|
+
r.valid_at AS valid_at,
|
|
262
|
+
r.invalid_at AS invalid_at,
|
|
263
|
+
properties(r) AS attributes
|
|
264
|
+
ORDER BY score DESC
|
|
265
|
+
LIMIT $limit
|
|
248
266
|
"""
|
|
249
267
|
)
|
|
250
|
-
|
|
251
|
-
records, _, _ = await driver.execute_query(
|
|
268
|
+
records, header, _ = await driver.execute_query(
|
|
252
269
|
query,
|
|
253
|
-
query_params,
|
|
270
|
+
params=query_params,
|
|
254
271
|
search_vector=search_vector,
|
|
255
272
|
source_uuid=source_node_uuid,
|
|
256
273
|
target_uuid=target_node_uuid,
|
|
@@ -261,13 +278,16 @@ async def edge_similarity_search(
|
|
|
261
278
|
routing_='r',
|
|
262
279
|
)
|
|
263
280
|
|
|
281
|
+
if driver.provider == 'falkordb':
|
|
282
|
+
records = [dict(zip(header, row, strict=True)) for row in records]
|
|
283
|
+
|
|
264
284
|
edges = [get_entity_edge_from_record(record) for record in records]
|
|
265
285
|
|
|
266
286
|
return edges
|
|
267
287
|
|
|
268
288
|
|
|
269
289
|
async def edge_bfs_search(
|
|
270
|
-
driver:
|
|
290
|
+
driver: GraphDriver,
|
|
271
291
|
bfs_origin_node_uuids: list[str] | None,
|
|
272
292
|
bfs_max_depth: int,
|
|
273
293
|
search_filter: SearchFilters,
|
|
@@ -279,14 +299,14 @@ async def edge_bfs_search(
|
|
|
279
299
|
|
|
280
300
|
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
|
281
301
|
|
|
282
|
-
query =
|
|
302
|
+
query = (
|
|
283
303
|
"""
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
304
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
305
|
+
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
|
306
|
+
UNWIND relationships(path) AS rel
|
|
307
|
+
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
|
|
308
|
+
WHERE r.uuid = rel.uuid
|
|
309
|
+
"""
|
|
290
310
|
+ filter_query
|
|
291
311
|
+ """
|
|
292
312
|
RETURN DISTINCT
|
|
@@ -300,14 +320,15 @@ async def edge_bfs_search(
|
|
|
300
320
|
r.episodes AS episodes,
|
|
301
321
|
r.expired_at AS expired_at,
|
|
302
322
|
r.valid_at AS valid_at,
|
|
303
|
-
r.invalid_at AS invalid_at
|
|
323
|
+
r.invalid_at AS invalid_at,
|
|
324
|
+
properties(r) AS attributes
|
|
304
325
|
LIMIT $limit
|
|
305
326
|
"""
|
|
306
327
|
)
|
|
307
328
|
|
|
308
329
|
records, _, _ = await driver.execute_query(
|
|
309
330
|
query,
|
|
310
|
-
filter_params,
|
|
331
|
+
params=filter_params,
|
|
311
332
|
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
|
312
333
|
depth=bfs_max_depth,
|
|
313
334
|
limit=limit,
|
|
@@ -321,7 +342,7 @@ async def edge_bfs_search(
|
|
|
321
342
|
|
|
322
343
|
|
|
323
344
|
async def node_fulltext_search(
|
|
324
|
-
driver:
|
|
345
|
+
driver: GraphDriver,
|
|
325
346
|
query: str,
|
|
326
347
|
search_filter: SearchFilters,
|
|
327
348
|
group_ids: list[str] | None = None,
|
|
@@ -331,38 +352,41 @@ async def node_fulltext_search(
|
|
|
331
352
|
fuzzy_query = fulltext_query(query, group_ids)
|
|
332
353
|
if fuzzy_query == '':
|
|
333
354
|
return []
|
|
334
|
-
|
|
335
355
|
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
|
336
356
|
|
|
337
357
|
query = (
|
|
358
|
+
get_nodes_query(driver.provider, 'node_name_and_summary', '$query')
|
|
359
|
+
+ """
|
|
360
|
+
YIELD node AS n, score
|
|
361
|
+
WITH n, score
|
|
362
|
+
LIMIT $limit
|
|
363
|
+
WHERE n:Entity
|
|
338
364
|
"""
|
|
339
|
-
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
|
|
340
|
-
YIELD node AS n, score
|
|
341
|
-
WHERE n:Entity
|
|
342
|
-
"""
|
|
343
365
|
+ filter_query
|
|
344
366
|
+ ENTITY_NODE_RETURN
|
|
345
367
|
+ """
|
|
346
368
|
ORDER BY score DESC
|
|
347
369
|
"""
|
|
348
370
|
)
|
|
349
|
-
|
|
350
|
-
records, _, _ = await driver.execute_query(
|
|
371
|
+
records, header, _ = await driver.execute_query(
|
|
351
372
|
query,
|
|
352
|
-
filter_params,
|
|
373
|
+
params=filter_params,
|
|
353
374
|
query=fuzzy_query,
|
|
354
375
|
group_ids=group_ids,
|
|
355
376
|
limit=limit,
|
|
356
377
|
database_=DEFAULT_DATABASE,
|
|
357
378
|
routing_='r',
|
|
358
379
|
)
|
|
380
|
+
if driver.provider == 'falkordb':
|
|
381
|
+
records = [dict(zip(header, row, strict=True)) for row in records]
|
|
382
|
+
|
|
359
383
|
nodes = [get_entity_node_from_record(record) for record in records]
|
|
360
384
|
|
|
361
385
|
return nodes
|
|
362
386
|
|
|
363
387
|
|
|
364
388
|
async def node_similarity_search(
|
|
365
|
-
driver:
|
|
389
|
+
driver: GraphDriver,
|
|
366
390
|
search_vector: list[float],
|
|
367
391
|
search_filter: SearchFilters,
|
|
368
392
|
group_ids: list[str] | None = None,
|
|
@@ -372,30 +396,36 @@ async def node_similarity_search(
|
|
|
372
396
|
# vector similarity search over entity names
|
|
373
397
|
query_params: dict[str, Any] = {}
|
|
374
398
|
|
|
375
|
-
group_filter_query: LiteralString = ''
|
|
399
|
+
group_filter_query: LiteralString = 'WHERE n.group_id IS NOT NULL'
|
|
376
400
|
if group_ids is not None:
|
|
377
|
-
group_filter_query += '
|
|
401
|
+
group_filter_query += ' AND n.group_id IN $group_ids'
|
|
378
402
|
query_params['group_ids'] = group_ids
|
|
379
403
|
|
|
380
404
|
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
|
381
405
|
query_params.update(filter_params)
|
|
382
406
|
|
|
383
|
-
|
|
407
|
+
query = (
|
|
384
408
|
RUNTIME_QUERY
|
|
385
409
|
+ """
|
|
386
|
-
|
|
387
|
-
|
|
410
|
+
MATCH (n:Entity)
|
|
411
|
+
"""
|
|
388
412
|
+ group_filter_query
|
|
389
413
|
+ filter_query
|
|
390
414
|
+ """
|
|
391
|
-
|
|
392
|
-
|
|
415
|
+
WITH n, """
|
|
416
|
+
+ get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
|
|
417
|
+
+ """ AS score
|
|
418
|
+
WHERE score > $min_score"""
|
|
393
419
|
+ ENTITY_NODE_RETURN
|
|
394
420
|
+ """
|
|
395
421
|
ORDER BY score DESC
|
|
396
422
|
LIMIT $limit
|
|
397
|
-
|
|
398
|
-
|
|
423
|
+
"""
|
|
424
|
+
)
|
|
425
|
+
|
|
426
|
+
records, header, _ = await driver.execute_query(
|
|
427
|
+
query,
|
|
428
|
+
params=query_params,
|
|
399
429
|
search_vector=search_vector,
|
|
400
430
|
group_ids=group_ids,
|
|
401
431
|
limit=limit,
|
|
@@ -403,13 +433,15 @@ async def node_similarity_search(
|
|
|
403
433
|
database_=DEFAULT_DATABASE,
|
|
404
434
|
routing_='r',
|
|
405
435
|
)
|
|
436
|
+
if driver.provider == 'falkordb':
|
|
437
|
+
records = [dict(zip(header, row, strict=True)) for row in records]
|
|
406
438
|
nodes = [get_entity_node_from_record(record) for record in records]
|
|
407
439
|
|
|
408
440
|
return nodes
|
|
409
441
|
|
|
410
442
|
|
|
411
443
|
async def node_bfs_search(
|
|
412
|
-
driver:
|
|
444
|
+
driver: GraphDriver,
|
|
413
445
|
bfs_origin_node_uuids: list[str] | None,
|
|
414
446
|
search_filter: SearchFilters,
|
|
415
447
|
bfs_max_depth: int,
|
|
@@ -421,18 +453,21 @@ async def node_bfs_search(
|
|
|
421
453
|
|
|
422
454
|
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
|
423
455
|
|
|
424
|
-
|
|
456
|
+
query = (
|
|
425
457
|
"""
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
458
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
459
|
+
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
|
460
|
+
WHERE n.group_id = origin.group_id
|
|
461
|
+
"""
|
|
430
462
|
+ filter_query
|
|
431
463
|
+ ENTITY_NODE_RETURN
|
|
432
464
|
+ """
|
|
433
465
|
LIMIT $limit
|
|
434
|
-
"""
|
|
435
|
-
|
|
466
|
+
"""
|
|
467
|
+
)
|
|
468
|
+
records, _, _ = await driver.execute_query(
|
|
469
|
+
query,
|
|
470
|
+
params=filter_params,
|
|
436
471
|
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
|
437
472
|
depth=bfs_max_depth,
|
|
438
473
|
limit=limit,
|
|
@@ -445,7 +480,7 @@ async def node_bfs_search(
|
|
|
445
480
|
|
|
446
481
|
|
|
447
482
|
async def episode_fulltext_search(
|
|
448
|
-
driver:
|
|
483
|
+
driver: GraphDriver,
|
|
449
484
|
query: str,
|
|
450
485
|
_search_filter: SearchFilters,
|
|
451
486
|
group_ids: list[str] | None = None,
|
|
@@ -456,9 +491,9 @@ async def episode_fulltext_search(
|
|
|
456
491
|
if fuzzy_query == '':
|
|
457
492
|
return []
|
|
458
493
|
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
494
|
+
query = (
|
|
495
|
+
get_nodes_query(driver.provider, 'episode_content', '$query')
|
|
496
|
+
+ """
|
|
462
497
|
YIELD node AS episode, score
|
|
463
498
|
MATCH (e:Episodic)
|
|
464
499
|
WHERE e.uuid = episode.uuid
|
|
@@ -474,7 +509,11 @@ async def episode_fulltext_search(
|
|
|
474
509
|
e.entity_edges AS entity_edges
|
|
475
510
|
ORDER BY score DESC
|
|
476
511
|
LIMIT $limit
|
|
477
|
-
"""
|
|
512
|
+
"""
|
|
513
|
+
)
|
|
514
|
+
|
|
515
|
+
records, _, _ = await driver.execute_query(
|
|
516
|
+
query,
|
|
478
517
|
query=fuzzy_query,
|
|
479
518
|
group_ids=group_ids,
|
|
480
519
|
limit=limit,
|
|
@@ -487,7 +526,7 @@ async def episode_fulltext_search(
|
|
|
487
526
|
|
|
488
527
|
|
|
489
528
|
async def community_fulltext_search(
|
|
490
|
-
driver:
|
|
529
|
+
driver: GraphDriver,
|
|
491
530
|
query: str,
|
|
492
531
|
group_ids: list[str] | None = None,
|
|
493
532
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
@@ -497,9 +536,9 @@ async def community_fulltext_search(
|
|
|
497
536
|
if fuzzy_query == '':
|
|
498
537
|
return []
|
|
499
538
|
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
539
|
+
query = (
|
|
540
|
+
get_nodes_query(driver.provider, 'community_name', '$query')
|
|
541
|
+
+ """
|
|
503
542
|
YIELD node AS comm, score
|
|
504
543
|
RETURN
|
|
505
544
|
comm.uuid AS uuid,
|
|
@@ -509,7 +548,11 @@ async def community_fulltext_search(
|
|
|
509
548
|
comm.summary AS summary
|
|
510
549
|
ORDER BY score DESC
|
|
511
550
|
LIMIT $limit
|
|
512
|
-
"""
|
|
551
|
+
"""
|
|
552
|
+
)
|
|
553
|
+
|
|
554
|
+
records, _, _ = await driver.execute_query(
|
|
555
|
+
query,
|
|
513
556
|
query=fuzzy_query,
|
|
514
557
|
group_ids=group_ids,
|
|
515
558
|
limit=limit,
|
|
@@ -522,7 +565,7 @@ async def community_fulltext_search(
|
|
|
522
565
|
|
|
523
566
|
|
|
524
567
|
async def community_similarity_search(
|
|
525
|
-
driver:
|
|
568
|
+
driver: GraphDriver,
|
|
526
569
|
search_vector: list[float],
|
|
527
570
|
group_ids: list[str] | None = None,
|
|
528
571
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
@@ -536,14 +579,16 @@ async def community_similarity_search(
|
|
|
536
579
|
group_filter_query += 'WHERE comm.group_id IN $group_ids'
|
|
537
580
|
query_params['group_ids'] = group_ids
|
|
538
581
|
|
|
539
|
-
|
|
582
|
+
query = (
|
|
540
583
|
RUNTIME_QUERY
|
|
541
584
|
+ """
|
|
542
585
|
MATCH (comm:Community)
|
|
543
586
|
"""
|
|
544
587
|
+ group_filter_query
|
|
545
588
|
+ """
|
|
546
|
-
WITH comm,
|
|
589
|
+
WITH comm, """
|
|
590
|
+
+ get_vector_cosine_func_query('comm.name_embedding', '$search_vector', driver.provider)
|
|
591
|
+
+ """ AS score
|
|
547
592
|
WHERE score > $min_score
|
|
548
593
|
RETURN
|
|
549
594
|
comm.uuid As uuid,
|
|
@@ -553,7 +598,11 @@ async def community_similarity_search(
|
|
|
553
598
|
comm.summary AS summary
|
|
554
599
|
ORDER BY score DESC
|
|
555
600
|
LIMIT $limit
|
|
556
|
-
"""
|
|
601
|
+
"""
|
|
602
|
+
)
|
|
603
|
+
|
|
604
|
+
records, _, _ = await driver.execute_query(
|
|
605
|
+
query,
|
|
557
606
|
search_vector=search_vector,
|
|
558
607
|
group_ids=group_ids,
|
|
559
608
|
limit=limit,
|
|
@@ -569,7 +618,7 @@ async def community_similarity_search(
|
|
|
569
618
|
async def hybrid_node_search(
|
|
570
619
|
queries: list[str],
|
|
571
620
|
embeddings: list[list[float]],
|
|
572
|
-
driver:
|
|
621
|
+
driver: GraphDriver,
|
|
573
622
|
search_filter: SearchFilters,
|
|
574
623
|
group_ids: list[str] | None = None,
|
|
575
624
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
@@ -586,7 +635,7 @@ async def hybrid_node_search(
|
|
|
586
635
|
A list of text queries to search for.
|
|
587
636
|
embeddings : list[list[float]]
|
|
588
637
|
A list of embedding vectors corresponding to the queries. If empty only fulltext search is performed.
|
|
589
|
-
driver :
|
|
638
|
+
driver : GraphDriver
|
|
590
639
|
The Neo4j driver instance for database operations.
|
|
591
640
|
group_ids : list[str] | None, optional
|
|
592
641
|
The list of group ids to retrieve nodes from.
|
|
@@ -641,7 +690,7 @@ async def hybrid_node_search(
|
|
|
641
690
|
|
|
642
691
|
|
|
643
692
|
async def get_relevant_nodes(
|
|
644
|
-
driver:
|
|
693
|
+
driver: GraphDriver,
|
|
645
694
|
nodes: list[EntityNode],
|
|
646
695
|
search_filter: SearchFilters,
|
|
647
696
|
min_score: float = DEFAULT_MIN_SCORE,
|
|
@@ -660,29 +709,33 @@ async def get_relevant_nodes(
|
|
|
660
709
|
|
|
661
710
|
query = (
|
|
662
711
|
RUNTIME_QUERY
|
|
663
|
-
+ """
|
|
664
|
-
|
|
665
|
-
|
|
712
|
+
+ """
|
|
713
|
+
UNWIND $nodes AS node
|
|
714
|
+
MATCH (n:Entity {group_id: $group_id})
|
|
715
|
+
"""
|
|
666
716
|
+ filter_query
|
|
667
717
|
+ """
|
|
668
|
-
WITH node, n,
|
|
718
|
+
WITH node, n, """
|
|
719
|
+
+ get_vector_cosine_func_query('n.name_embedding', 'node.name_embedding', driver.provider)
|
|
720
|
+
+ """ AS score
|
|
669
721
|
WHERE score > $min_score
|
|
670
722
|
WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
|
|
671
|
-
|
|
672
|
-
|
|
723
|
+
"""
|
|
724
|
+
+ get_nodes_query(driver.provider, 'node_name_and_summary', 'node.fulltext_query')
|
|
725
|
+
+ """
|
|
673
726
|
YIELD node AS m
|
|
674
727
|
WHERE m.group_id = $group_id
|
|
675
728
|
WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
|
|
676
|
-
|
|
729
|
+
|
|
677
730
|
WITH node,
|
|
678
731
|
top_vector_nodes,
|
|
679
732
|
[m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
|
|
680
|
-
|
|
733
|
+
|
|
681
734
|
WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
|
|
682
|
-
|
|
735
|
+
|
|
683
736
|
UNWIND combined_nodes AS combined_node
|
|
684
737
|
WITH node, collect(DISTINCT combined_node) AS deduped_nodes
|
|
685
|
-
|
|
738
|
+
|
|
686
739
|
RETURN
|
|
687
740
|
node.uuid AS search_node_uuid,
|
|
688
741
|
[x IN deduped_nodes | {
|
|
@@ -710,7 +763,7 @@ async def get_relevant_nodes(
|
|
|
710
763
|
|
|
711
764
|
results, _, _ = await driver.execute_query(
|
|
712
765
|
query,
|
|
713
|
-
query_params,
|
|
766
|
+
params=query_params,
|
|
714
767
|
nodes=query_nodes,
|
|
715
768
|
group_id=group_id,
|
|
716
769
|
limit=limit,
|
|
@@ -732,7 +785,7 @@ async def get_relevant_nodes(
|
|
|
732
785
|
|
|
733
786
|
|
|
734
787
|
async def get_relevant_edges(
|
|
735
|
-
driver:
|
|
788
|
+
driver: GraphDriver,
|
|
736
789
|
edges: list[EntityEdge],
|
|
737
790
|
search_filter: SearchFilters,
|
|
738
791
|
min_score: float = DEFAULT_MIN_SCORE,
|
|
@@ -748,42 +801,47 @@ async def get_relevant_edges(
|
|
|
748
801
|
|
|
749
802
|
query = (
|
|
750
803
|
RUNTIME_QUERY
|
|
751
|
-
+ """
|
|
752
|
-
|
|
753
|
-
|
|
804
|
+
+ """
|
|
805
|
+
UNWIND $edges AS edge
|
|
806
|
+
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
|
807
|
+
"""
|
|
754
808
|
+ filter_query
|
|
755
809
|
+ """
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
810
|
+
WITH e, edge, """
|
|
811
|
+
+ get_vector_cosine_func_query('e.fact_embedding', 'edge.fact_embedding', driver.provider)
|
|
812
|
+
+ """ AS score
|
|
813
|
+
WHERE score > $min_score
|
|
814
|
+
WITH edge, e, score
|
|
815
|
+
ORDER BY score DESC
|
|
816
|
+
RETURN edge.uuid AS search_edge_uuid,
|
|
817
|
+
collect({
|
|
818
|
+
uuid: e.uuid,
|
|
819
|
+
source_node_uuid: startNode(e).uuid,
|
|
820
|
+
target_node_uuid: endNode(e).uuid,
|
|
821
|
+
created_at: e.created_at,
|
|
822
|
+
name: e.name,
|
|
823
|
+
group_id: e.group_id,
|
|
824
|
+
fact: e.fact,
|
|
825
|
+
fact_embedding: e.fact_embedding,
|
|
826
|
+
episodes: e.episodes,
|
|
827
|
+
expired_at: e.expired_at,
|
|
828
|
+
valid_at: e.valid_at,
|
|
829
|
+
invalid_at: e.invalid_at,
|
|
830
|
+
attributes: properties(e)
|
|
831
|
+
})[..$limit] AS matches
|
|
775
832
|
"""
|
|
776
833
|
)
|
|
777
834
|
|
|
778
835
|
results, _, _ = await driver.execute_query(
|
|
779
836
|
query,
|
|
780
|
-
query_params,
|
|
837
|
+
params=query_params,
|
|
781
838
|
edges=[edge.model_dump() for edge in edges],
|
|
782
839
|
limit=limit,
|
|
783
840
|
min_score=min_score,
|
|
784
841
|
database_=DEFAULT_DATABASE,
|
|
785
842
|
routing_='r',
|
|
786
843
|
)
|
|
844
|
+
|
|
787
845
|
relevant_edges_dict: dict[str, list[EntityEdge]] = {
|
|
788
846
|
result['search_edge_uuid']: [
|
|
789
847
|
get_entity_edge_from_record(record) for record in result['matches']
|
|
@@ -797,7 +855,7 @@ async def get_relevant_edges(
|
|
|
797
855
|
|
|
798
856
|
|
|
799
857
|
async def get_edge_invalidation_candidates(
|
|
800
|
-
driver:
|
|
858
|
+
driver: GraphDriver,
|
|
801
859
|
edges: list[EntityEdge],
|
|
802
860
|
search_filter: SearchFilters,
|
|
803
861
|
min_score: float = DEFAULT_MIN_SCORE,
|
|
@@ -813,37 +871,41 @@ async def get_edge_invalidation_candidates(
|
|
|
813
871
|
|
|
814
872
|
query = (
|
|
815
873
|
RUNTIME_QUERY
|
|
816
|
-
+ """
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
874
|
+
+ """
|
|
875
|
+
UNWIND $edges AS edge
|
|
876
|
+
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
|
877
|
+
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
|
878
|
+
"""
|
|
820
879
|
+ filter_query
|
|
821
880
|
+ """
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
881
|
+
WITH edge, e, """
|
|
882
|
+
+ get_vector_cosine_func_query('e.fact_embedding', 'edge.fact_embedding', driver.provider)
|
|
883
|
+
+ """ AS score
|
|
884
|
+
WHERE score > $min_score
|
|
885
|
+
WITH edge, e, score
|
|
886
|
+
ORDER BY score DESC
|
|
887
|
+
RETURN edge.uuid AS search_edge_uuid,
|
|
888
|
+
collect({
|
|
889
|
+
uuid: e.uuid,
|
|
890
|
+
source_node_uuid: startNode(e).uuid,
|
|
891
|
+
target_node_uuid: endNode(e).uuid,
|
|
892
|
+
created_at: e.created_at,
|
|
893
|
+
name: e.name,
|
|
894
|
+
group_id: e.group_id,
|
|
895
|
+
fact: e.fact,
|
|
896
|
+
fact_embedding: e.fact_embedding,
|
|
897
|
+
episodes: e.episodes,
|
|
898
|
+
expired_at: e.expired_at,
|
|
899
|
+
valid_at: e.valid_at,
|
|
900
|
+
invalid_at: e.invalid_at,
|
|
901
|
+
attributes: properties(e)
|
|
902
|
+
})[..$limit] AS matches
|
|
841
903
|
"""
|
|
842
904
|
)
|
|
843
905
|
|
|
844
906
|
results, _, _ = await driver.execute_query(
|
|
845
907
|
query,
|
|
846
|
-
query_params,
|
|
908
|
+
params=query_params,
|
|
847
909
|
edges=[edge.model_dump() for edge in edges],
|
|
848
910
|
limit=limit,
|
|
849
911
|
min_score=min_score,
|
|
@@ -878,7 +940,7 @@ def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[st
|
|
|
878
940
|
|
|
879
941
|
|
|
880
942
|
async def node_distance_reranker(
|
|
881
|
-
driver:
|
|
943
|
+
driver: GraphDriver,
|
|
882
944
|
node_uuids: list[str],
|
|
883
945
|
center_node_uuid: str,
|
|
884
946
|
min_score: float = 0,
|
|
@@ -888,20 +950,22 @@ async def node_distance_reranker(
|
|
|
888
950
|
scores: dict[str, float] = {center_node_uuid: 0.0}
|
|
889
951
|
|
|
890
952
|
# Find the shortest path to center node
|
|
891
|
-
query =
|
|
953
|
+
query = """
|
|
892
954
|
UNWIND $node_uuids AS node_uuid
|
|
893
|
-
MATCH
|
|
894
|
-
RETURN
|
|
895
|
-
"""
|
|
896
|
-
|
|
897
|
-
path_results, _, _ = await driver.execute_query(
|
|
955
|
+
MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
|
|
956
|
+
RETURN 1 AS score, node_uuid AS uuid
|
|
957
|
+
"""
|
|
958
|
+
results, header, _ = await driver.execute_query(
|
|
898
959
|
query,
|
|
899
960
|
node_uuids=filtered_uuids,
|
|
900
961
|
center_uuid=center_node_uuid,
|
|
901
962
|
database_=DEFAULT_DATABASE,
|
|
963
|
+
routing_='r',
|
|
902
964
|
)
|
|
965
|
+
if driver.provider == 'falkordb':
|
|
966
|
+
results = [dict(zip(header, row, strict=True)) for row in results]
|
|
903
967
|
|
|
904
|
-
for result in
|
|
968
|
+
for result in results:
|
|
905
969
|
uuid = result['uuid']
|
|
906
970
|
score = result['score']
|
|
907
971
|
scores[uuid] = score
|
|
@@ -922,23 +986,23 @@ async def node_distance_reranker(
|
|
|
922
986
|
|
|
923
987
|
|
|
924
988
|
async def episode_mentions_reranker(
|
|
925
|
-
driver:
|
|
989
|
+
driver: GraphDriver, node_uuids: list[list[str]], min_score: float = 0
|
|
926
990
|
) -> list[str]:
|
|
927
991
|
# use rrf as a preliminary ranker
|
|
928
992
|
sorted_uuids = rrf(node_uuids)
|
|
929
993
|
scores: dict[str, float] = {}
|
|
930
994
|
|
|
931
995
|
# Find the shortest path to center node
|
|
932
|
-
query =
|
|
996
|
+
query = """
|
|
933
997
|
UNWIND $node_uuids AS node_uuid
|
|
934
998
|
MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
|
|
935
999
|
RETURN count(*) AS score, n.uuid AS uuid
|
|
936
|
-
"""
|
|
937
|
-
|
|
1000
|
+
"""
|
|
938
1001
|
results, _, _ = await driver.execute_query(
|
|
939
1002
|
query,
|
|
940
1003
|
node_uuids=sorted_uuids,
|
|
941
1004
|
database_=DEFAULT_DATABASE,
|
|
1005
|
+
routing_='r',
|
|
942
1006
|
)
|
|
943
1007
|
|
|
944
1008
|
for result in results:
|
|
@@ -952,15 +1016,116 @@ async def episode_mentions_reranker(
|
|
|
952
1016
|
|
|
953
1017
|
def maximal_marginal_relevance(
|
|
954
1018
|
query_vector: list[float],
|
|
955
|
-
candidates:
|
|
1019
|
+
candidates: dict[str, list[float]],
|
|
956
1020
|
mmr_lambda: float = DEFAULT_MMR_LAMBDA,
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
|
|
962
|
-
|
|
1021
|
+
min_score: float = -2.0,
|
|
1022
|
+
) -> list[str]:
|
|
1023
|
+
start = time()
|
|
1024
|
+
query_array = np.array(query_vector)
|
|
1025
|
+
candidate_arrays: dict[str, NDArray] = {}
|
|
1026
|
+
for uuid, embedding in candidates.items():
|
|
1027
|
+
candidate_arrays[uuid] = normalize_l2(embedding)
|
|
1028
|
+
|
|
1029
|
+
uuids: list[str] = list(candidate_arrays.keys())
|
|
1030
|
+
|
|
1031
|
+
similarity_matrix = np.zeros((len(uuids), len(uuids)))
|
|
1032
|
+
|
|
1033
|
+
for i, uuid_1 in enumerate(uuids):
|
|
1034
|
+
for j, uuid_2 in enumerate(uuids[:i]):
|
|
1035
|
+
u = candidate_arrays[uuid_1]
|
|
1036
|
+
v = candidate_arrays[uuid_2]
|
|
1037
|
+
similarity = np.dot(u, v)
|
|
1038
|
+
|
|
1039
|
+
similarity_matrix[i, j] = similarity
|
|
1040
|
+
similarity_matrix[j, i] = similarity
|
|
1041
|
+
|
|
1042
|
+
mmr_scores: dict[str, float] = {}
|
|
1043
|
+
for i, uuid in enumerate(uuids):
|
|
1044
|
+
max_sim = np.max(similarity_matrix[i, :])
|
|
1045
|
+
mmr = mmr_lambda * np.dot(query_array, candidate_arrays[uuid]) + (mmr_lambda - 1) * max_sim
|
|
1046
|
+
mmr_scores[uuid] = mmr
|
|
1047
|
+
|
|
1048
|
+
uuids.sort(reverse=True, key=lambda c: mmr_scores[c])
|
|
1049
|
+
|
|
1050
|
+
end = time()
|
|
1051
|
+
logger.debug(f'Completed MMR reranking in {(end - start) * 1000} ms')
|
|
1052
|
+
|
|
1053
|
+
return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score]
|
|
1054
|
+
|
|
1055
|
+
|
|
1056
|
+
async def get_embeddings_for_nodes(
|
|
1057
|
+
driver: GraphDriver, nodes: list[EntityNode]
|
|
1058
|
+
) -> dict[str, list[float]]:
|
|
1059
|
+
query: LiteralString = """MATCH (n:Entity)
|
|
1060
|
+
WHERE n.uuid IN $node_uuids
|
|
1061
|
+
RETURN DISTINCT
|
|
1062
|
+
n.uuid AS uuid,
|
|
1063
|
+
n.name_embedding AS name_embedding
|
|
1064
|
+
"""
|
|
1065
|
+
|
|
1066
|
+
results, _, _ = await driver.execute_query(
|
|
1067
|
+
query, node_uuids=[node.uuid for node in nodes], database_=DEFAULT_DATABASE, routing_='r'
|
|
1068
|
+
)
|
|
1069
|
+
|
|
1070
|
+
embeddings_dict: dict[str, list[float]] = {}
|
|
1071
|
+
for result in results:
|
|
1072
|
+
uuid: str = result.get('uuid')
|
|
1073
|
+
embedding: list[float] = result.get('name_embedding')
|
|
1074
|
+
if uuid is not None and embedding is not None:
|
|
1075
|
+
embeddings_dict[uuid] = embedding
|
|
1076
|
+
|
|
1077
|
+
return embeddings_dict
|
|
1078
|
+
|
|
963
1079
|
|
|
964
|
-
|
|
1080
|
+
async def get_embeddings_for_communities(
|
|
1081
|
+
driver: GraphDriver, communities: list[CommunityNode]
|
|
1082
|
+
) -> dict[str, list[float]]:
|
|
1083
|
+
query: LiteralString = """MATCH (c:Community)
|
|
1084
|
+
WHERE c.uuid IN $community_uuids
|
|
1085
|
+
RETURN DISTINCT
|
|
1086
|
+
c.uuid AS uuid,
|
|
1087
|
+
c.name_embedding AS name_embedding
|
|
1088
|
+
"""
|
|
1089
|
+
|
|
1090
|
+
results, _, _ = await driver.execute_query(
|
|
1091
|
+
query,
|
|
1092
|
+
community_uuids=[community.uuid for community in communities],
|
|
1093
|
+
database_=DEFAULT_DATABASE,
|
|
1094
|
+
routing_='r',
|
|
1095
|
+
)
|
|
1096
|
+
|
|
1097
|
+
embeddings_dict: dict[str, list[float]] = {}
|
|
1098
|
+
for result in results:
|
|
1099
|
+
uuid: str = result.get('uuid')
|
|
1100
|
+
embedding: list[float] = result.get('name_embedding')
|
|
1101
|
+
if uuid is not None and embedding is not None:
|
|
1102
|
+
embeddings_dict[uuid] = embedding
|
|
1103
|
+
|
|
1104
|
+
return embeddings_dict
|
|
1105
|
+
|
|
1106
|
+
|
|
1107
|
+
async def get_embeddings_for_edges(
|
|
1108
|
+
driver: GraphDriver, edges: list[EntityEdge]
|
|
1109
|
+
) -> dict[str, list[float]]:
|
|
1110
|
+
query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
|
1111
|
+
WHERE e.uuid IN $edge_uuids
|
|
1112
|
+
RETURN DISTINCT
|
|
1113
|
+
e.uuid AS uuid,
|
|
1114
|
+
e.fact_embedding AS fact_embedding
|
|
1115
|
+
"""
|
|
1116
|
+
|
|
1117
|
+
results, _, _ = await driver.execute_query(
|
|
1118
|
+
query,
|
|
1119
|
+
edge_uuids=[edge.uuid for edge in edges],
|
|
1120
|
+
database_=DEFAULT_DATABASE,
|
|
1121
|
+
routing_='r',
|
|
1122
|
+
)
|
|
1123
|
+
|
|
1124
|
+
embeddings_dict: dict[str, list[float]] = {}
|
|
1125
|
+
for result in results:
|
|
1126
|
+
uuid: str = result.get('uuid')
|
|
1127
|
+
embedding: list[float] = result.get('fact_embedding')
|
|
1128
|
+
if uuid is not None and embedding is not None:
|
|
1129
|
+
embeddings_dict[uuid] = embedding
|
|
965
1130
|
|
|
966
|
-
return
|
|
1131
|
+
return embeddings_dict
|