graphiti-core 0.12.0rc5__py3-none-any.whl → 0.12.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
- graphiti_core/driver/__init__.py +17 -0
- graphiti_core/driver/driver.py +66 -0
- graphiti_core/driver/falkordb_driver.py +131 -0
- graphiti_core/driver/neo4j_driver.py +61 -0
- graphiti_core/edges.py +26 -26
- graphiti_core/embedder/azure_openai.py +64 -0
- graphiti_core/graph_queries.py +149 -0
- graphiti_core/graphiti.py +21 -8
- graphiti_core/graphiti_types.py +2 -2
- graphiti_core/helpers.py +9 -3
- graphiti_core/llm_client/__init__.py +16 -0
- graphiti_core/llm_client/azure_openai_client.py +73 -0
- graphiti_core/nodes.py +31 -31
- graphiti_core/prompts/dedupe_nodes.py +5 -1
- graphiti_core/prompts/extract_edges.py +2 -0
- graphiti_core/prompts/extract_nodes.py +2 -0
- graphiti_core/search/search.py +6 -10
- graphiti_core/search/search_utils.py +243 -187
- graphiti_core/utils/bulk_utils.py +21 -11
- graphiti_core/utils/maintenance/community_operations.py +6 -7
- graphiti_core/utils/maintenance/edge_operations.py +68 -3
- graphiti_core/utils/maintenance/graph_data_operations.py +13 -42
- graphiti_core/utils/maintenance/node_operations.py +19 -5
- {graphiti_core-0.12.0rc5.dist-info → graphiti_core-0.12.2.dist-info}/METADATA +4 -3
- {graphiti_core-0.12.0rc5.dist-info → graphiti_core-0.12.2.dist-info}/RECORD +28 -21
- {graphiti_core-0.12.0rc5.dist-info → graphiti_core-0.12.2.dist-info}/LICENSE +0 -0
- {graphiti_core-0.12.0rc5.dist-info → graphiti_core-0.12.2.dist-info}/WHEEL +0 -0
|
@@ -20,11 +20,16 @@ from time import time
|
|
|
20
20
|
from typing import Any
|
|
21
21
|
|
|
22
22
|
import numpy as np
|
|
23
|
-
from neo4j import AsyncDriver, Query
|
|
24
23
|
from numpy._typing import NDArray
|
|
25
24
|
from typing_extensions import LiteralString
|
|
26
25
|
|
|
26
|
+
from graphiti_core.driver.driver import GraphDriver
|
|
27
27
|
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
|
28
|
+
from graphiti_core.graph_queries import (
|
|
29
|
+
get_nodes_query,
|
|
30
|
+
get_relationships_query,
|
|
31
|
+
get_vector_cosine_func_query,
|
|
32
|
+
)
|
|
28
33
|
from graphiti_core.helpers import (
|
|
29
34
|
DEFAULT_DATABASE,
|
|
30
35
|
RUNTIME_QUERY,
|
|
@@ -58,7 +63,7 @@ MAX_QUERY_LENGTH = 32
|
|
|
58
63
|
|
|
59
64
|
def fulltext_query(query: str, group_ids: list[str] | None = None):
|
|
60
65
|
group_ids_filter_list = (
|
|
61
|
-
[f'
|
|
66
|
+
[f"group_id-'{lucene_sanitize(g)}'" for g in group_ids] if group_ids is not None else []
|
|
62
67
|
)
|
|
63
68
|
group_ids_filter = ''
|
|
64
69
|
for f in group_ids_filter_list:
|
|
@@ -77,7 +82,7 @@ def fulltext_query(query: str, group_ids: list[str] | None = None):
|
|
|
77
82
|
|
|
78
83
|
|
|
79
84
|
async def get_episodes_by_mentions(
|
|
80
|
-
driver:
|
|
85
|
+
driver: GraphDriver,
|
|
81
86
|
nodes: list[EntityNode],
|
|
82
87
|
edges: list[EntityEdge],
|
|
83
88
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
@@ -92,11 +97,11 @@ async def get_episodes_by_mentions(
|
|
|
92
97
|
|
|
93
98
|
|
|
94
99
|
async def get_mentioned_nodes(
|
|
95
|
-
driver:
|
|
100
|
+
driver: GraphDriver, episodes: list[EpisodicNode]
|
|
96
101
|
) -> list[EntityNode]:
|
|
97
102
|
episode_uuids = [episode.uuid for episode in episodes]
|
|
98
|
-
|
|
99
|
-
|
|
103
|
+
|
|
104
|
+
query = """
|
|
100
105
|
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
|
|
101
106
|
RETURN DISTINCT
|
|
102
107
|
n.uuid As uuid,
|
|
@@ -106,7 +111,10 @@ async def get_mentioned_nodes(
|
|
|
106
111
|
n.summary AS summary,
|
|
107
112
|
labels(n) AS labels,
|
|
108
113
|
properties(n) AS attributes
|
|
109
|
-
"""
|
|
114
|
+
"""
|
|
115
|
+
|
|
116
|
+
records, _, _ = await driver.execute_query(
|
|
117
|
+
query,
|
|
110
118
|
uuids=episode_uuids,
|
|
111
119
|
database_=DEFAULT_DATABASE,
|
|
112
120
|
routing_='r',
|
|
@@ -118,11 +126,11 @@ async def get_mentioned_nodes(
|
|
|
118
126
|
|
|
119
127
|
|
|
120
128
|
async def get_communities_by_nodes(
|
|
121
|
-
driver:
|
|
129
|
+
driver: GraphDriver, nodes: list[EntityNode]
|
|
122
130
|
) -> list[CommunityNode]:
|
|
123
131
|
node_uuids = [node.uuid for node in nodes]
|
|
124
|
-
|
|
125
|
-
|
|
132
|
+
|
|
133
|
+
query = """
|
|
126
134
|
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
|
|
127
135
|
RETURN DISTINCT
|
|
128
136
|
c.uuid As uuid,
|
|
@@ -130,7 +138,10 @@ async def get_communities_by_nodes(
|
|
|
130
138
|
c.name AS name,
|
|
131
139
|
c.created_at AS created_at,
|
|
132
140
|
c.summary AS summary
|
|
133
|
-
"""
|
|
141
|
+
"""
|
|
142
|
+
|
|
143
|
+
records, _, _ = await driver.execute_query(
|
|
144
|
+
query,
|
|
134
145
|
uuids=node_uuids,
|
|
135
146
|
database_=DEFAULT_DATABASE,
|
|
136
147
|
routing_='r',
|
|
@@ -142,7 +153,7 @@ async def get_communities_by_nodes(
|
|
|
142
153
|
|
|
143
154
|
|
|
144
155
|
async def edge_fulltext_search(
|
|
145
|
-
driver:
|
|
156
|
+
driver: GraphDriver,
|
|
146
157
|
query: str,
|
|
147
158
|
search_filter: SearchFilters,
|
|
148
159
|
group_ids: list[str] | None = None,
|
|
@@ -155,34 +166,35 @@ async def edge_fulltext_search(
|
|
|
155
166
|
|
|
156
167
|
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
|
157
168
|
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
169
|
+
query = (
|
|
170
|
+
get_relationships_query('edge_name_and_fact', db_type=driver.provider)
|
|
171
|
+
+ """
|
|
172
|
+
YIELD relationship AS rel, score
|
|
173
|
+
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
|
174
|
+
WHERE r.group_id IN $group_ids """
|
|
164
175
|
+ filter_query
|
|
165
|
-
+ """
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
176
|
+
+ """
|
|
177
|
+
WITH r, score, startNode(r) AS n, endNode(r) AS m
|
|
178
|
+
RETURN
|
|
179
|
+
r.uuid AS uuid,
|
|
180
|
+
r.group_id AS group_id,
|
|
181
|
+
n.uuid AS source_node_uuid,
|
|
182
|
+
m.uuid AS target_node_uuid,
|
|
183
|
+
r.created_at AS created_at,
|
|
184
|
+
r.name AS name,
|
|
185
|
+
r.fact AS fact,
|
|
186
|
+
r.episodes AS episodes,
|
|
187
|
+
r.expired_at AS expired_at,
|
|
188
|
+
r.valid_at AS valid_at,
|
|
189
|
+
r.invalid_at AS invalid_at,
|
|
190
|
+
properties(r) AS attributes
|
|
191
|
+
ORDER BY score DESC LIMIT $limit
|
|
192
|
+
"""
|
|
181
193
|
)
|
|
182
194
|
|
|
183
195
|
records, _, _ = await driver.execute_query(
|
|
184
|
-
|
|
185
|
-
filter_params,
|
|
196
|
+
query,
|
|
197
|
+
params=filter_params,
|
|
186
198
|
query=fuzzy_query,
|
|
187
199
|
group_ids=group_ids,
|
|
188
200
|
limit=limit,
|
|
@@ -196,7 +208,7 @@ async def edge_fulltext_search(
|
|
|
196
208
|
|
|
197
209
|
|
|
198
210
|
async def edge_similarity_search(
|
|
199
|
-
driver:
|
|
211
|
+
driver: GraphDriver,
|
|
200
212
|
search_vector: list[float],
|
|
201
213
|
source_node_uuid: str | None,
|
|
202
214
|
target_node_uuid: str | None,
|
|
@@ -224,36 +236,38 @@ async def edge_similarity_search(
|
|
|
224
236
|
if target_node_uuid is not None:
|
|
225
237
|
group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
|
|
226
238
|
|
|
227
|
-
query
|
|
239
|
+
query = (
|
|
228
240
|
RUNTIME_QUERY
|
|
229
241
|
+ """
|
|
230
242
|
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
|
231
|
-
|
|
243
|
+
"""
|
|
232
244
|
+ group_filter_query
|
|
233
245
|
+ filter_query
|
|
234
|
-
+ """
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
246
|
+
+ """
|
|
247
|
+
WITH DISTINCT r, """
|
|
248
|
+
+ get_vector_cosine_func_query('r.fact_embedding', '$search_vector', driver.provider)
|
|
249
|
+
+ """ AS score
|
|
250
|
+
WHERE score > $min_score
|
|
251
|
+
RETURN
|
|
252
|
+
r.uuid AS uuid,
|
|
253
|
+
r.group_id AS group_id,
|
|
254
|
+
startNode(r).uuid AS source_node_uuid,
|
|
255
|
+
endNode(r).uuid AS target_node_uuid,
|
|
256
|
+
r.created_at AS created_at,
|
|
257
|
+
r.name AS name,
|
|
258
|
+
r.fact AS fact,
|
|
259
|
+
r.episodes AS episodes,
|
|
260
|
+
r.expired_at AS expired_at,
|
|
261
|
+
r.valid_at AS valid_at,
|
|
262
|
+
r.invalid_at AS invalid_at,
|
|
263
|
+
properties(r) AS attributes
|
|
264
|
+
ORDER BY score DESC
|
|
265
|
+
LIMIT $limit
|
|
251
266
|
"""
|
|
252
267
|
)
|
|
253
|
-
|
|
254
|
-
records, _, _ = await driver.execute_query(
|
|
268
|
+
records, header, _ = await driver.execute_query(
|
|
255
269
|
query,
|
|
256
|
-
query_params,
|
|
270
|
+
params=query_params,
|
|
257
271
|
search_vector=search_vector,
|
|
258
272
|
source_uuid=source_node_uuid,
|
|
259
273
|
target_uuid=target_node_uuid,
|
|
@@ -264,13 +278,16 @@ async def edge_similarity_search(
|
|
|
264
278
|
routing_='r',
|
|
265
279
|
)
|
|
266
280
|
|
|
281
|
+
if driver.provider == 'falkordb':
|
|
282
|
+
records = [dict(zip(header, row, strict=True)) for row in records]
|
|
283
|
+
|
|
267
284
|
edges = [get_entity_edge_from_record(record) for record in records]
|
|
268
285
|
|
|
269
286
|
return edges
|
|
270
287
|
|
|
271
288
|
|
|
272
289
|
async def edge_bfs_search(
|
|
273
|
-
driver:
|
|
290
|
+
driver: GraphDriver,
|
|
274
291
|
bfs_origin_node_uuids: list[str] | None,
|
|
275
292
|
bfs_max_depth: int,
|
|
276
293
|
search_filter: SearchFilters,
|
|
@@ -282,14 +299,14 @@ async def edge_bfs_search(
|
|
|
282
299
|
|
|
283
300
|
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
|
284
301
|
|
|
285
|
-
query =
|
|
302
|
+
query = (
|
|
286
303
|
"""
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
304
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
305
|
+
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
|
306
|
+
UNWIND relationships(path) AS rel
|
|
307
|
+
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
|
|
308
|
+
WHERE r.uuid = rel.uuid
|
|
309
|
+
"""
|
|
293
310
|
+ filter_query
|
|
294
311
|
+ """
|
|
295
312
|
RETURN DISTINCT
|
|
@@ -311,7 +328,7 @@ async def edge_bfs_search(
|
|
|
311
328
|
|
|
312
329
|
records, _, _ = await driver.execute_query(
|
|
313
330
|
query,
|
|
314
|
-
filter_params,
|
|
331
|
+
params=filter_params,
|
|
315
332
|
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
|
316
333
|
depth=bfs_max_depth,
|
|
317
334
|
limit=limit,
|
|
@@ -325,7 +342,7 @@ async def edge_bfs_search(
|
|
|
325
342
|
|
|
326
343
|
|
|
327
344
|
async def node_fulltext_search(
|
|
328
|
-
driver:
|
|
345
|
+
driver: GraphDriver,
|
|
329
346
|
query: str,
|
|
330
347
|
search_filter: SearchFilters,
|
|
331
348
|
group_ids: list[str] | None = None,
|
|
@@ -335,38 +352,41 @@ async def node_fulltext_search(
|
|
|
335
352
|
fuzzy_query = fulltext_query(query, group_ids)
|
|
336
353
|
if fuzzy_query == '':
|
|
337
354
|
return []
|
|
338
|
-
|
|
339
355
|
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
|
340
356
|
|
|
341
357
|
query = (
|
|
358
|
+
get_nodes_query(driver.provider, 'node_name_and_summary', '$query')
|
|
359
|
+
+ """
|
|
360
|
+
YIELD node AS n, score
|
|
361
|
+
WITH n, score
|
|
362
|
+
LIMIT $limit
|
|
363
|
+
WHERE n:Entity
|
|
342
364
|
"""
|
|
343
|
-
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
|
|
344
|
-
YIELD node AS n, score
|
|
345
|
-
WHERE n:Entity
|
|
346
|
-
"""
|
|
347
365
|
+ filter_query
|
|
348
366
|
+ ENTITY_NODE_RETURN
|
|
349
367
|
+ """
|
|
350
368
|
ORDER BY score DESC
|
|
351
369
|
"""
|
|
352
370
|
)
|
|
353
|
-
|
|
354
|
-
records, _, _ = await driver.execute_query(
|
|
371
|
+
records, header, _ = await driver.execute_query(
|
|
355
372
|
query,
|
|
356
|
-
filter_params,
|
|
373
|
+
params=filter_params,
|
|
357
374
|
query=fuzzy_query,
|
|
358
375
|
group_ids=group_ids,
|
|
359
376
|
limit=limit,
|
|
360
377
|
database_=DEFAULT_DATABASE,
|
|
361
378
|
routing_='r',
|
|
362
379
|
)
|
|
380
|
+
if driver.provider == 'falkordb':
|
|
381
|
+
records = [dict(zip(header, row, strict=True)) for row in records]
|
|
382
|
+
|
|
363
383
|
nodes = [get_entity_node_from_record(record) for record in records]
|
|
364
384
|
|
|
365
385
|
return nodes
|
|
366
386
|
|
|
367
387
|
|
|
368
388
|
async def node_similarity_search(
|
|
369
|
-
driver:
|
|
389
|
+
driver: GraphDriver,
|
|
370
390
|
search_vector: list[float],
|
|
371
391
|
search_filter: SearchFilters,
|
|
372
392
|
group_ids: list[str] | None = None,
|
|
@@ -384,22 +404,28 @@ async def node_similarity_search(
|
|
|
384
404
|
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
|
385
405
|
query_params.update(filter_params)
|
|
386
406
|
|
|
387
|
-
|
|
407
|
+
query = (
|
|
388
408
|
RUNTIME_QUERY
|
|
389
409
|
+ """
|
|
390
|
-
|
|
391
|
-
|
|
410
|
+
MATCH (n:Entity)
|
|
411
|
+
"""
|
|
392
412
|
+ group_filter_query
|
|
393
413
|
+ filter_query
|
|
394
414
|
+ """
|
|
395
|
-
|
|
396
|
-
|
|
415
|
+
WITH n, """
|
|
416
|
+
+ get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
|
|
417
|
+
+ """ AS score
|
|
418
|
+
WHERE score > $min_score"""
|
|
397
419
|
+ ENTITY_NODE_RETURN
|
|
398
420
|
+ """
|
|
399
421
|
ORDER BY score DESC
|
|
400
422
|
LIMIT $limit
|
|
401
|
-
|
|
402
|
-
|
|
423
|
+
"""
|
|
424
|
+
)
|
|
425
|
+
|
|
426
|
+
records, header, _ = await driver.execute_query(
|
|
427
|
+
query,
|
|
428
|
+
params=query_params,
|
|
403
429
|
search_vector=search_vector,
|
|
404
430
|
group_ids=group_ids,
|
|
405
431
|
limit=limit,
|
|
@@ -407,13 +433,15 @@ async def node_similarity_search(
|
|
|
407
433
|
database_=DEFAULT_DATABASE,
|
|
408
434
|
routing_='r',
|
|
409
435
|
)
|
|
436
|
+
if driver.provider == 'falkordb':
|
|
437
|
+
records = [dict(zip(header, row, strict=True)) for row in records]
|
|
410
438
|
nodes = [get_entity_node_from_record(record) for record in records]
|
|
411
439
|
|
|
412
440
|
return nodes
|
|
413
441
|
|
|
414
442
|
|
|
415
443
|
async def node_bfs_search(
|
|
416
|
-
driver:
|
|
444
|
+
driver: GraphDriver,
|
|
417
445
|
bfs_origin_node_uuids: list[str] | None,
|
|
418
446
|
search_filter: SearchFilters,
|
|
419
447
|
bfs_max_depth: int,
|
|
@@ -425,18 +453,21 @@ async def node_bfs_search(
|
|
|
425
453
|
|
|
426
454
|
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
|
427
455
|
|
|
428
|
-
|
|
456
|
+
query = (
|
|
429
457
|
"""
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
458
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
459
|
+
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
|
460
|
+
WHERE n.group_id = origin.group_id
|
|
461
|
+
"""
|
|
434
462
|
+ filter_query
|
|
435
463
|
+ ENTITY_NODE_RETURN
|
|
436
464
|
+ """
|
|
437
465
|
LIMIT $limit
|
|
438
|
-
"""
|
|
439
|
-
|
|
466
|
+
"""
|
|
467
|
+
)
|
|
468
|
+
records, _, _ = await driver.execute_query(
|
|
469
|
+
query,
|
|
470
|
+
params=filter_params,
|
|
440
471
|
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
|
441
472
|
depth=bfs_max_depth,
|
|
442
473
|
limit=limit,
|
|
@@ -449,7 +480,7 @@ async def node_bfs_search(
|
|
|
449
480
|
|
|
450
481
|
|
|
451
482
|
async def episode_fulltext_search(
|
|
452
|
-
driver:
|
|
483
|
+
driver: GraphDriver,
|
|
453
484
|
query: str,
|
|
454
485
|
_search_filter: SearchFilters,
|
|
455
486
|
group_ids: list[str] | None = None,
|
|
@@ -460,9 +491,9 @@ async def episode_fulltext_search(
|
|
|
460
491
|
if fuzzy_query == '':
|
|
461
492
|
return []
|
|
462
493
|
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
494
|
+
query = (
|
|
495
|
+
get_nodes_query(driver.provider, 'episode_content', '$query')
|
|
496
|
+
+ """
|
|
466
497
|
YIELD node AS episode, score
|
|
467
498
|
MATCH (e:Episodic)
|
|
468
499
|
WHERE e.uuid = episode.uuid
|
|
@@ -478,7 +509,11 @@ async def episode_fulltext_search(
|
|
|
478
509
|
e.entity_edges AS entity_edges
|
|
479
510
|
ORDER BY score DESC
|
|
480
511
|
LIMIT $limit
|
|
481
|
-
"""
|
|
512
|
+
"""
|
|
513
|
+
)
|
|
514
|
+
|
|
515
|
+
records, _, _ = await driver.execute_query(
|
|
516
|
+
query,
|
|
482
517
|
query=fuzzy_query,
|
|
483
518
|
group_ids=group_ids,
|
|
484
519
|
limit=limit,
|
|
@@ -491,7 +526,7 @@ async def episode_fulltext_search(
|
|
|
491
526
|
|
|
492
527
|
|
|
493
528
|
async def community_fulltext_search(
|
|
494
|
-
driver:
|
|
529
|
+
driver: GraphDriver,
|
|
495
530
|
query: str,
|
|
496
531
|
group_ids: list[str] | None = None,
|
|
497
532
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
@@ -501,9 +536,9 @@ async def community_fulltext_search(
|
|
|
501
536
|
if fuzzy_query == '':
|
|
502
537
|
return []
|
|
503
538
|
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
539
|
+
query = (
|
|
540
|
+
get_nodes_query(driver.provider, 'community_name', '$query')
|
|
541
|
+
+ """
|
|
507
542
|
YIELD node AS comm, score
|
|
508
543
|
RETURN
|
|
509
544
|
comm.uuid AS uuid,
|
|
@@ -513,7 +548,11 @@ async def community_fulltext_search(
|
|
|
513
548
|
comm.summary AS summary
|
|
514
549
|
ORDER BY score DESC
|
|
515
550
|
LIMIT $limit
|
|
516
|
-
"""
|
|
551
|
+
"""
|
|
552
|
+
)
|
|
553
|
+
|
|
554
|
+
records, _, _ = await driver.execute_query(
|
|
555
|
+
query,
|
|
517
556
|
query=fuzzy_query,
|
|
518
557
|
group_ids=group_ids,
|
|
519
558
|
limit=limit,
|
|
@@ -526,7 +565,7 @@ async def community_fulltext_search(
|
|
|
526
565
|
|
|
527
566
|
|
|
528
567
|
async def community_similarity_search(
|
|
529
|
-
driver:
|
|
568
|
+
driver: GraphDriver,
|
|
530
569
|
search_vector: list[float],
|
|
531
570
|
group_ids: list[str] | None = None,
|
|
532
571
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
@@ -540,14 +579,16 @@ async def community_similarity_search(
|
|
|
540
579
|
group_filter_query += 'WHERE comm.group_id IN $group_ids'
|
|
541
580
|
query_params['group_ids'] = group_ids
|
|
542
581
|
|
|
543
|
-
|
|
582
|
+
query = (
|
|
544
583
|
RUNTIME_QUERY
|
|
545
584
|
+ """
|
|
546
585
|
MATCH (comm:Community)
|
|
547
586
|
"""
|
|
548
587
|
+ group_filter_query
|
|
549
588
|
+ """
|
|
550
|
-
WITH comm,
|
|
589
|
+
WITH comm, """
|
|
590
|
+
+ get_vector_cosine_func_query('comm.name_embedding', '$search_vector', driver.provider)
|
|
591
|
+
+ """ AS score
|
|
551
592
|
WHERE score > $min_score
|
|
552
593
|
RETURN
|
|
553
594
|
comm.uuid As uuid,
|
|
@@ -557,7 +598,11 @@ async def community_similarity_search(
|
|
|
557
598
|
comm.summary AS summary
|
|
558
599
|
ORDER BY score DESC
|
|
559
600
|
LIMIT $limit
|
|
560
|
-
"""
|
|
601
|
+
"""
|
|
602
|
+
)
|
|
603
|
+
|
|
604
|
+
records, _, _ = await driver.execute_query(
|
|
605
|
+
query,
|
|
561
606
|
search_vector=search_vector,
|
|
562
607
|
group_ids=group_ids,
|
|
563
608
|
limit=limit,
|
|
@@ -573,7 +618,7 @@ async def community_similarity_search(
|
|
|
573
618
|
async def hybrid_node_search(
|
|
574
619
|
queries: list[str],
|
|
575
620
|
embeddings: list[list[float]],
|
|
576
|
-
driver:
|
|
621
|
+
driver: GraphDriver,
|
|
577
622
|
search_filter: SearchFilters,
|
|
578
623
|
group_ids: list[str] | None = None,
|
|
579
624
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
@@ -590,7 +635,7 @@ async def hybrid_node_search(
|
|
|
590
635
|
A list of text queries to search for.
|
|
591
636
|
embeddings : list[list[float]]
|
|
592
637
|
A list of embedding vectors corresponding to the queries. If empty only fulltext search is performed.
|
|
593
|
-
driver :
|
|
638
|
+
driver : GraphDriver
|
|
594
639
|
The graph driver instance (e.g. Neo4j or FalkorDB) used for database operations.
|
|
595
640
|
group_ids : list[str] | None, optional
|
|
596
641
|
The list of group ids to retrieve nodes from.
|
|
@@ -645,7 +690,7 @@ async def hybrid_node_search(
|
|
|
645
690
|
|
|
646
691
|
|
|
647
692
|
async def get_relevant_nodes(
|
|
648
|
-
driver:
|
|
693
|
+
driver: GraphDriver,
|
|
649
694
|
nodes: list[EntityNode],
|
|
650
695
|
search_filter: SearchFilters,
|
|
651
696
|
min_score: float = DEFAULT_MIN_SCORE,
|
|
@@ -664,29 +709,33 @@ async def get_relevant_nodes(
|
|
|
664
709
|
|
|
665
710
|
query = (
|
|
666
711
|
RUNTIME_QUERY
|
|
667
|
-
+ """
|
|
668
|
-
|
|
669
|
-
|
|
712
|
+
+ """
|
|
713
|
+
UNWIND $nodes AS node
|
|
714
|
+
MATCH (n:Entity {group_id: $group_id})
|
|
715
|
+
"""
|
|
670
716
|
+ filter_query
|
|
671
717
|
+ """
|
|
672
|
-
WITH node, n,
|
|
718
|
+
WITH node, n, """
|
|
719
|
+
+ get_vector_cosine_func_query('n.name_embedding', 'node.name_embedding', driver.provider)
|
|
720
|
+
+ """ AS score
|
|
673
721
|
WHERE score > $min_score
|
|
674
722
|
WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
|
|
675
|
-
|
|
676
|
-
|
|
723
|
+
"""
|
|
724
|
+
+ get_nodes_query(driver.provider, 'node_name_and_summary', 'node.fulltext_query')
|
|
725
|
+
+ """
|
|
677
726
|
YIELD node AS m
|
|
678
727
|
WHERE m.group_id = $group_id
|
|
679
728
|
WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
|
|
680
|
-
|
|
729
|
+
|
|
681
730
|
WITH node,
|
|
682
731
|
top_vector_nodes,
|
|
683
732
|
[m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
|
|
684
|
-
|
|
733
|
+
|
|
685
734
|
WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
|
|
686
|
-
|
|
735
|
+
|
|
687
736
|
UNWIND combined_nodes AS combined_node
|
|
688
737
|
WITH node, collect(DISTINCT combined_node) AS deduped_nodes
|
|
689
|
-
|
|
738
|
+
|
|
690
739
|
RETURN
|
|
691
740
|
node.uuid AS search_node_uuid,
|
|
692
741
|
[x IN deduped_nodes | {
|
|
@@ -714,7 +763,7 @@ async def get_relevant_nodes(
|
|
|
714
763
|
|
|
715
764
|
results, _, _ = await driver.execute_query(
|
|
716
765
|
query,
|
|
717
|
-
query_params,
|
|
766
|
+
params=query_params,
|
|
718
767
|
nodes=query_nodes,
|
|
719
768
|
group_id=group_id,
|
|
720
769
|
limit=limit,
|
|
@@ -736,7 +785,7 @@ async def get_relevant_nodes(
|
|
|
736
785
|
|
|
737
786
|
|
|
738
787
|
async def get_relevant_edges(
|
|
739
|
-
driver:
|
|
788
|
+
driver: GraphDriver,
|
|
740
789
|
edges: list[EntityEdge],
|
|
741
790
|
search_filter: SearchFilters,
|
|
742
791
|
min_score: float = DEFAULT_MIN_SCORE,
|
|
@@ -752,43 +801,47 @@ async def get_relevant_edges(
|
|
|
752
801
|
|
|
753
802
|
query = (
|
|
754
803
|
RUNTIME_QUERY
|
|
755
|
-
+ """
|
|
756
|
-
|
|
757
|
-
|
|
804
|
+
+ """
|
|
805
|
+
UNWIND $edges AS edge
|
|
806
|
+
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
|
807
|
+
"""
|
|
758
808
|
+ filter_query
|
|
759
809
|
+ """
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
810
|
+
WITH e, edge, """
|
|
811
|
+
+ get_vector_cosine_func_query('e.fact_embedding', 'edge.fact_embedding', driver.provider)
|
|
812
|
+
+ """ AS score
|
|
813
|
+
WHERE score > $min_score
|
|
814
|
+
WITH edge, e, score
|
|
815
|
+
ORDER BY score DESC
|
|
816
|
+
RETURN edge.uuid AS search_edge_uuid,
|
|
817
|
+
collect({
|
|
818
|
+
uuid: e.uuid,
|
|
819
|
+
source_node_uuid: startNode(e).uuid,
|
|
820
|
+
target_node_uuid: endNode(e).uuid,
|
|
821
|
+
created_at: e.created_at,
|
|
822
|
+
name: e.name,
|
|
823
|
+
group_id: e.group_id,
|
|
824
|
+
fact: e.fact,
|
|
825
|
+
fact_embedding: e.fact_embedding,
|
|
826
|
+
episodes: e.episodes,
|
|
827
|
+
expired_at: e.expired_at,
|
|
828
|
+
valid_at: e.valid_at,
|
|
829
|
+
invalid_at: e.invalid_at,
|
|
830
|
+
attributes: properties(e)
|
|
831
|
+
})[..$limit] AS matches
|
|
780
832
|
"""
|
|
781
833
|
)
|
|
782
834
|
|
|
783
835
|
results, _, _ = await driver.execute_query(
|
|
784
836
|
query,
|
|
785
|
-
query_params,
|
|
837
|
+
params=query_params,
|
|
786
838
|
edges=[edge.model_dump() for edge in edges],
|
|
787
839
|
limit=limit,
|
|
788
840
|
min_score=min_score,
|
|
789
841
|
database_=DEFAULT_DATABASE,
|
|
790
842
|
routing_='r',
|
|
791
843
|
)
|
|
844
|
+
|
|
792
845
|
relevant_edges_dict: dict[str, list[EntityEdge]] = {
|
|
793
846
|
result['search_edge_uuid']: [
|
|
794
847
|
get_entity_edge_from_record(record) for record in result['matches']
|
|
@@ -802,7 +855,7 @@ async def get_relevant_edges(
|
|
|
802
855
|
|
|
803
856
|
|
|
804
857
|
async def get_edge_invalidation_candidates(
|
|
805
|
-
driver:
|
|
858
|
+
driver: GraphDriver,
|
|
806
859
|
edges: list[EntityEdge],
|
|
807
860
|
search_filter: SearchFilters,
|
|
808
861
|
min_score: float = DEFAULT_MIN_SCORE,
|
|
@@ -818,38 +871,41 @@ async def get_edge_invalidation_candidates(
|
|
|
818
871
|
|
|
819
872
|
query = (
|
|
820
873
|
RUNTIME_QUERY
|
|
821
|
-
+ """
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
874
|
+
+ """
|
|
875
|
+
UNWIND $edges AS edge
|
|
876
|
+
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
|
877
|
+
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
|
878
|
+
"""
|
|
825
879
|
+ filter_query
|
|
826
880
|
+ """
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
881
|
+
WITH edge, e, """
|
|
882
|
+
+ get_vector_cosine_func_query('e.fact_embedding', 'edge.fact_embedding', driver.provider)
|
|
883
|
+
+ """ AS score
|
|
884
|
+
WHERE score > $min_score
|
|
885
|
+
WITH edge, e, score
|
|
886
|
+
ORDER BY score DESC
|
|
887
|
+
RETURN edge.uuid AS search_edge_uuid,
|
|
888
|
+
collect({
|
|
889
|
+
uuid: e.uuid,
|
|
890
|
+
source_node_uuid: startNode(e).uuid,
|
|
891
|
+
target_node_uuid: endNode(e).uuid,
|
|
892
|
+
created_at: e.created_at,
|
|
893
|
+
name: e.name,
|
|
894
|
+
group_id: e.group_id,
|
|
895
|
+
fact: e.fact,
|
|
896
|
+
fact_embedding: e.fact_embedding,
|
|
897
|
+
episodes: e.episodes,
|
|
898
|
+
expired_at: e.expired_at,
|
|
899
|
+
valid_at: e.valid_at,
|
|
900
|
+
invalid_at: e.invalid_at,
|
|
901
|
+
attributes: properties(e)
|
|
902
|
+
})[..$limit] AS matches
|
|
847
903
|
"""
|
|
848
904
|
)
|
|
849
905
|
|
|
850
906
|
results, _, _ = await driver.execute_query(
|
|
851
907
|
query,
|
|
852
|
-
query_params,
|
|
908
|
+
params=query_params,
|
|
853
909
|
edges=[edge.model_dump() for edge in edges],
|
|
854
910
|
limit=limit,
|
|
855
911
|
min_score=min_score,
|
|
@@ -884,7 +940,7 @@ def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[st
|
|
|
884
940
|
|
|
885
941
|
|
|
886
942
|
async def node_distance_reranker(
|
|
887
|
-
driver:
|
|
943
|
+
driver: GraphDriver,
|
|
888
944
|
node_uuids: list[str],
|
|
889
945
|
center_node_uuid: str,
|
|
890
946
|
min_score: float = 0,
|
|
@@ -894,21 +950,22 @@ async def node_distance_reranker(
|
|
|
894
950
|
scores: dict[str, float] = {center_node_uuid: 0.0}
|
|
895
951
|
|
|
896
952
|
# Score nodes that are directly connected to the center node (distance 1)
|
|
897
|
-
query =
|
|
953
|
+
query = """
|
|
898
954
|
UNWIND $node_uuids AS node_uuid
|
|
899
|
-
MATCH
|
|
900
|
-
RETURN
|
|
901
|
-
"""
|
|
902
|
-
|
|
903
|
-
path_results, _, _ = await driver.execute_query(
|
|
955
|
+
MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
|
|
956
|
+
RETURN 1 AS score, node_uuid AS uuid
|
|
957
|
+
"""
|
|
958
|
+
results, header, _ = await driver.execute_query(
|
|
904
959
|
query,
|
|
905
960
|
node_uuids=filtered_uuids,
|
|
906
961
|
center_uuid=center_node_uuid,
|
|
907
962
|
database_=DEFAULT_DATABASE,
|
|
908
963
|
routing_='r',
|
|
909
964
|
)
|
|
965
|
+
if driver.provider == 'falkordb':
|
|
966
|
+
results = [dict(zip(header, row, strict=True)) for row in results]
|
|
910
967
|
|
|
911
|
-
for result in
|
|
968
|
+
for result in results:
|
|
912
969
|
uuid = result['uuid']
|
|
913
970
|
score = result['score']
|
|
914
971
|
scores[uuid] = score
|
|
@@ -929,19 +986,18 @@ async def node_distance_reranker(
|
|
|
929
986
|
|
|
930
987
|
|
|
931
988
|
async def episode_mentions_reranker(
|
|
932
|
-
driver:
|
|
989
|
+
driver: GraphDriver, node_uuids: list[list[str]], min_score: float = 0
|
|
933
990
|
) -> list[str]:
|
|
934
991
|
# use rrf as a preliminary ranker
|
|
935
992
|
sorted_uuids = rrf(node_uuids)
|
|
936
993
|
scores: dict[str, float] = {}
|
|
937
994
|
|
|
938
995
|
# Count how many episodes mention each node
|
|
939
|
-
query =
|
|
996
|
+
query = """
|
|
940
997
|
UNWIND $node_uuids AS node_uuid
|
|
941
998
|
MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
|
|
942
999
|
RETURN count(*) AS score, n.uuid AS uuid
|
|
943
|
-
"""
|
|
944
|
-
|
|
1000
|
+
"""
|
|
945
1001
|
results, _, _ = await driver.execute_query(
|
|
946
1002
|
query,
|
|
947
1003
|
node_uuids=sorted_uuids,
|
|
@@ -998,7 +1054,7 @@ def maximal_marginal_relevance(
|
|
|
998
1054
|
|
|
999
1055
|
|
|
1000
1056
|
async def get_embeddings_for_nodes(
|
|
1001
|
-
driver:
|
|
1057
|
+
driver: GraphDriver, nodes: list[EntityNode]
|
|
1002
1058
|
) -> dict[str, list[float]]:
|
|
1003
1059
|
query: LiteralString = """MATCH (n:Entity)
|
|
1004
1060
|
WHERE n.uuid IN $node_uuids
|
|
@@ -1022,7 +1078,7 @@ async def get_embeddings_for_nodes(
|
|
|
1022
1078
|
|
|
1023
1079
|
|
|
1024
1080
|
async def get_embeddings_for_communities(
|
|
1025
|
-
driver:
|
|
1081
|
+
driver: GraphDriver, communities: list[CommunityNode]
|
|
1026
1082
|
) -> dict[str, list[float]]:
|
|
1027
1083
|
query: LiteralString = """MATCH (c:Community)
|
|
1028
1084
|
WHERE c.uuid IN $community_uuids
|
|
@@ -1049,7 +1105,7 @@ async def get_embeddings_for_communities(
|
|
|
1049
1105
|
|
|
1050
1106
|
|
|
1051
1107
|
async def get_embeddings_for_edges(
|
|
1052
|
-
driver:
|
|
1108
|
+
driver: GraphDriver, edges: list[EntityEdge]
|
|
1053
1109
|
) -> dict[str, list[float]]:
|
|
1054
1110
|
query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
|
1055
1111
|
WHERE e.uuid IN $edge_uuids
|