graphiti-core 0.11.6rc9__py3-none-any.whl → 0.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/cross_encoder/openai_reranker_client.py +1 -1
- graphiti_core/driver/__init__.py +17 -0
- graphiti_core/driver/driver.py +66 -0
- graphiti_core/driver/falkordb_driver.py +132 -0
- graphiti_core/driver/neo4j_driver.py +61 -0
- graphiti_core/edges.py +66 -40
- graphiti_core/embedder/azure_openai.py +64 -0
- graphiti_core/embedder/gemini.py +14 -3
- graphiti_core/graph_queries.py +149 -0
- graphiti_core/graphiti.py +41 -14
- graphiti_core/graphiti_types.py +2 -2
- graphiti_core/helpers.py +9 -4
- graphiti_core/llm_client/__init__.py +16 -0
- graphiti_core/llm_client/azure_openai_client.py +73 -0
- graphiti_core/llm_client/gemini_client.py +4 -1
- graphiti_core/models/edges/edge_db_queries.py +2 -4
- graphiti_core/nodes.py +31 -31
- graphiti_core/prompts/dedupe_edges.py +52 -1
- graphiti_core/prompts/dedupe_nodes.py +79 -4
- graphiti_core/prompts/extract_edges.py +50 -5
- graphiti_core/prompts/invalidate_edges.py +1 -1
- graphiti_core/search/search.py +6 -10
- graphiti_core/search/search_filters.py +23 -9
- graphiti_core/search/search_utils.py +250 -189
- graphiti_core/utils/bulk_utils.py +38 -11
- graphiti_core/utils/maintenance/community_operations.py +6 -7
- graphiti_core/utils/maintenance/edge_operations.py +149 -19
- graphiti_core/utils/maintenance/graph_data_operations.py +13 -42
- graphiti_core/utils/maintenance/node_operations.py +52 -71
- {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/METADATA +14 -5
- {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/RECORD +33 -26
- {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/LICENSE +0 -0
- {graphiti_core-0.11.6rc9.dist-info → graphiti_core-0.12.0.dist-info}/WHEEL +0 -0
|
@@ -20,11 +20,16 @@ from time import time
|
|
|
20
20
|
from typing import Any
|
|
21
21
|
|
|
22
22
|
import numpy as np
|
|
23
|
-
from neo4j import AsyncDriver, Query
|
|
24
23
|
from numpy._typing import NDArray
|
|
25
24
|
from typing_extensions import LiteralString
|
|
26
25
|
|
|
26
|
+
from graphiti_core.driver.driver import GraphDriver
|
|
27
27
|
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
|
|
28
|
+
from graphiti_core.graph_queries import (
|
|
29
|
+
get_nodes_query,
|
|
30
|
+
get_relationships_query,
|
|
31
|
+
get_vector_cosine_func_query,
|
|
32
|
+
)
|
|
28
33
|
from graphiti_core.helpers import (
|
|
29
34
|
DEFAULT_DATABASE,
|
|
30
35
|
RUNTIME_QUERY,
|
|
@@ -58,7 +63,7 @@ MAX_QUERY_LENGTH = 32
|
|
|
58
63
|
|
|
59
64
|
def fulltext_query(query: str, group_ids: list[str] | None = None):
|
|
60
65
|
group_ids_filter_list = (
|
|
61
|
-
[f'
|
|
66
|
+
[f"group_id-'{lucene_sanitize(g)}'" for g in group_ids] if group_ids is not None else []
|
|
62
67
|
)
|
|
63
68
|
group_ids_filter = ''
|
|
64
69
|
for f in group_ids_filter_list:
|
|
@@ -77,7 +82,7 @@ def fulltext_query(query: str, group_ids: list[str] | None = None):
|
|
|
77
82
|
|
|
78
83
|
|
|
79
84
|
async def get_episodes_by_mentions(
|
|
80
|
-
driver:
|
|
85
|
+
driver: GraphDriver,
|
|
81
86
|
nodes: list[EntityNode],
|
|
82
87
|
edges: list[EntityEdge],
|
|
83
88
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
@@ -92,11 +97,11 @@ async def get_episodes_by_mentions(
|
|
|
92
97
|
|
|
93
98
|
|
|
94
99
|
async def get_mentioned_nodes(
|
|
95
|
-
driver:
|
|
100
|
+
driver: GraphDriver, episodes: list[EpisodicNode]
|
|
96
101
|
) -> list[EntityNode]:
|
|
97
102
|
episode_uuids = [episode.uuid for episode in episodes]
|
|
98
|
-
|
|
99
|
-
|
|
103
|
+
|
|
104
|
+
query = """
|
|
100
105
|
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
|
|
101
106
|
RETURN DISTINCT
|
|
102
107
|
n.uuid As uuid,
|
|
@@ -106,7 +111,10 @@ async def get_mentioned_nodes(
|
|
|
106
111
|
n.summary AS summary,
|
|
107
112
|
labels(n) AS labels,
|
|
108
113
|
properties(n) AS attributes
|
|
109
|
-
"""
|
|
114
|
+
"""
|
|
115
|
+
|
|
116
|
+
records, _, _ = await driver.execute_query(
|
|
117
|
+
query,
|
|
110
118
|
uuids=episode_uuids,
|
|
111
119
|
database_=DEFAULT_DATABASE,
|
|
112
120
|
routing_='r',
|
|
@@ -118,11 +126,11 @@ async def get_mentioned_nodes(
|
|
|
118
126
|
|
|
119
127
|
|
|
120
128
|
async def get_communities_by_nodes(
|
|
121
|
-
driver:
|
|
129
|
+
driver: GraphDriver, nodes: list[EntityNode]
|
|
122
130
|
) -> list[CommunityNode]:
|
|
123
131
|
node_uuids = [node.uuid for node in nodes]
|
|
124
|
-
|
|
125
|
-
|
|
132
|
+
|
|
133
|
+
query = """
|
|
126
134
|
MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids
|
|
127
135
|
RETURN DISTINCT
|
|
128
136
|
c.uuid As uuid,
|
|
@@ -130,7 +138,10 @@ async def get_communities_by_nodes(
|
|
|
130
138
|
c.name AS name,
|
|
131
139
|
c.created_at AS created_at,
|
|
132
140
|
c.summary AS summary
|
|
133
|
-
"""
|
|
141
|
+
"""
|
|
142
|
+
|
|
143
|
+
records, _, _ = await driver.execute_query(
|
|
144
|
+
query,
|
|
134
145
|
uuids=node_uuids,
|
|
135
146
|
database_=DEFAULT_DATABASE,
|
|
136
147
|
routing_='r',
|
|
@@ -142,7 +153,7 @@ async def get_communities_by_nodes(
|
|
|
142
153
|
|
|
143
154
|
|
|
144
155
|
async def edge_fulltext_search(
|
|
145
|
-
driver:
|
|
156
|
+
driver: GraphDriver,
|
|
146
157
|
query: str,
|
|
147
158
|
search_filter: SearchFilters,
|
|
148
159
|
group_ids: list[str] | None = None,
|
|
@@ -155,33 +166,35 @@ async def edge_fulltext_search(
|
|
|
155
166
|
|
|
156
167
|
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
|
157
168
|
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
169
|
+
query = (
|
|
170
|
+
get_relationships_query('edge_name_and_fact', db_type=driver.provider)
|
|
171
|
+
+ """
|
|
172
|
+
YIELD relationship AS rel, score
|
|
173
|
+
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
|
174
|
+
WHERE r.group_id IN $group_ids """
|
|
164
175
|
+ filter_query
|
|
165
|
-
+ """
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
176
|
+
+ """
|
|
177
|
+
WITH r, score, startNode(r) AS n, endNode(r) AS m
|
|
178
|
+
RETURN
|
|
179
|
+
r.uuid AS uuid,
|
|
180
|
+
r.group_id AS group_id,
|
|
181
|
+
n.uuid AS source_node_uuid,
|
|
182
|
+
m.uuid AS target_node_uuid,
|
|
183
|
+
r.created_at AS created_at,
|
|
184
|
+
r.name AS name,
|
|
185
|
+
r.fact AS fact,
|
|
186
|
+
r.episodes AS episodes,
|
|
187
|
+
r.expired_at AS expired_at,
|
|
188
|
+
r.valid_at AS valid_at,
|
|
189
|
+
r.invalid_at AS invalid_at,
|
|
190
|
+
properties(r) AS attributes
|
|
191
|
+
ORDER BY score DESC LIMIT $limit
|
|
192
|
+
"""
|
|
180
193
|
)
|
|
181
194
|
|
|
182
195
|
records, _, _ = await driver.execute_query(
|
|
183
|
-
|
|
184
|
-
filter_params,
|
|
196
|
+
query,
|
|
197
|
+
params=filter_params,
|
|
185
198
|
query=fuzzy_query,
|
|
186
199
|
group_ids=group_ids,
|
|
187
200
|
limit=limit,
|
|
@@ -195,7 +208,7 @@ async def edge_fulltext_search(
|
|
|
195
208
|
|
|
196
209
|
|
|
197
210
|
async def edge_similarity_search(
|
|
198
|
-
driver:
|
|
211
|
+
driver: GraphDriver,
|
|
199
212
|
search_vector: list[float],
|
|
200
213
|
source_node_uuid: str | None,
|
|
201
214
|
target_node_uuid: str | None,
|
|
@@ -210,9 +223,9 @@ async def edge_similarity_search(
|
|
|
210
223
|
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
|
211
224
|
query_params.update(filter_params)
|
|
212
225
|
|
|
213
|
-
group_filter_query: LiteralString = ''
|
|
226
|
+
group_filter_query: LiteralString = 'WHERE r.group_id IS NOT NULL'
|
|
214
227
|
if group_ids is not None:
|
|
215
|
-
group_filter_query += '
|
|
228
|
+
group_filter_query += '\nAND r.group_id IN $group_ids'
|
|
216
229
|
query_params['group_ids'] = group_ids
|
|
217
230
|
query_params['source_node_uuid'] = source_node_uuid
|
|
218
231
|
query_params['target_node_uuid'] = target_node_uuid
|
|
@@ -223,35 +236,38 @@ async def edge_similarity_search(
|
|
|
223
236
|
if target_node_uuid is not None:
|
|
224
237
|
group_filter_query += '\nAND (m.uuid IN [$source_uuid, $target_uuid])'
|
|
225
238
|
|
|
226
|
-
query
|
|
239
|
+
query = (
|
|
227
240
|
RUNTIME_QUERY
|
|
228
241
|
+ """
|
|
229
|
-
|
|
230
|
-
|
|
242
|
+
MATCH (n:Entity)-[r:RELATES_TO]->(m:Entity)
|
|
243
|
+
"""
|
|
231
244
|
+ group_filter_query
|
|
232
245
|
+ filter_query
|
|
233
|
-
+ """
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
246
|
+
+ """
|
|
247
|
+
WITH DISTINCT r, """
|
|
248
|
+
+ get_vector_cosine_func_query('r.fact_embedding', '$search_vector', driver.provider)
|
|
249
|
+
+ """ AS score
|
|
250
|
+
WHERE score > $min_score
|
|
251
|
+
RETURN
|
|
252
|
+
r.uuid AS uuid,
|
|
253
|
+
r.group_id AS group_id,
|
|
254
|
+
startNode(r).uuid AS source_node_uuid,
|
|
255
|
+
endNode(r).uuid AS target_node_uuid,
|
|
256
|
+
r.created_at AS created_at,
|
|
257
|
+
r.name AS name,
|
|
258
|
+
r.fact AS fact,
|
|
259
|
+
r.episodes AS episodes,
|
|
260
|
+
r.expired_at AS expired_at,
|
|
261
|
+
r.valid_at AS valid_at,
|
|
262
|
+
r.invalid_at AS invalid_at,
|
|
263
|
+
properties(r) AS attributes
|
|
264
|
+
ORDER BY score DESC
|
|
265
|
+
LIMIT $limit
|
|
249
266
|
"""
|
|
250
267
|
)
|
|
251
|
-
|
|
252
|
-
records, _, _ = await driver.execute_query(
|
|
268
|
+
records, header, _ = await driver.execute_query(
|
|
253
269
|
query,
|
|
254
|
-
query_params,
|
|
270
|
+
params=query_params,
|
|
255
271
|
search_vector=search_vector,
|
|
256
272
|
source_uuid=source_node_uuid,
|
|
257
273
|
target_uuid=target_node_uuid,
|
|
@@ -262,13 +278,16 @@ async def edge_similarity_search(
|
|
|
262
278
|
routing_='r',
|
|
263
279
|
)
|
|
264
280
|
|
|
281
|
+
if driver.provider == 'falkordb':
|
|
282
|
+
records = [dict(zip(header, row, strict=True)) for row in records]
|
|
283
|
+
|
|
265
284
|
edges = [get_entity_edge_from_record(record) for record in records]
|
|
266
285
|
|
|
267
286
|
return edges
|
|
268
287
|
|
|
269
288
|
|
|
270
289
|
async def edge_bfs_search(
|
|
271
|
-
driver:
|
|
290
|
+
driver: GraphDriver,
|
|
272
291
|
bfs_origin_node_uuids: list[str] | None,
|
|
273
292
|
bfs_max_depth: int,
|
|
274
293
|
search_filter: SearchFilters,
|
|
@@ -280,14 +299,14 @@ async def edge_bfs_search(
|
|
|
280
299
|
|
|
281
300
|
filter_query, filter_params = edge_search_filter_query_constructor(search_filter)
|
|
282
301
|
|
|
283
|
-
query =
|
|
302
|
+
query = (
|
|
284
303
|
"""
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
304
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
305
|
+
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
|
306
|
+
UNWIND relationships(path) AS rel
|
|
307
|
+
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
|
|
308
|
+
WHERE r.uuid = rel.uuid
|
|
309
|
+
"""
|
|
291
310
|
+ filter_query
|
|
292
311
|
+ """
|
|
293
312
|
RETURN DISTINCT
|
|
@@ -301,14 +320,15 @@ async def edge_bfs_search(
|
|
|
301
320
|
r.episodes AS episodes,
|
|
302
321
|
r.expired_at AS expired_at,
|
|
303
322
|
r.valid_at AS valid_at,
|
|
304
|
-
r.invalid_at AS invalid_at
|
|
323
|
+
r.invalid_at AS invalid_at,
|
|
324
|
+
properties(r) AS attributes
|
|
305
325
|
LIMIT $limit
|
|
306
326
|
"""
|
|
307
327
|
)
|
|
308
328
|
|
|
309
329
|
records, _, _ = await driver.execute_query(
|
|
310
330
|
query,
|
|
311
|
-
filter_params,
|
|
331
|
+
params=filter_params,
|
|
312
332
|
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
|
313
333
|
depth=bfs_max_depth,
|
|
314
334
|
limit=limit,
|
|
@@ -322,7 +342,7 @@ async def edge_bfs_search(
|
|
|
322
342
|
|
|
323
343
|
|
|
324
344
|
async def node_fulltext_search(
|
|
325
|
-
driver:
|
|
345
|
+
driver: GraphDriver,
|
|
326
346
|
query: str,
|
|
327
347
|
search_filter: SearchFilters,
|
|
328
348
|
group_ids: list[str] | None = None,
|
|
@@ -332,38 +352,41 @@ async def node_fulltext_search(
|
|
|
332
352
|
fuzzy_query = fulltext_query(query, group_ids)
|
|
333
353
|
if fuzzy_query == '':
|
|
334
354
|
return []
|
|
335
|
-
|
|
336
355
|
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
|
337
356
|
|
|
338
357
|
query = (
|
|
358
|
+
get_nodes_query(driver.provider, 'node_name_and_summary', '$query')
|
|
359
|
+
+ """
|
|
360
|
+
YIELD node AS n, score
|
|
361
|
+
WITH n, score
|
|
362
|
+
LIMIT $limit
|
|
363
|
+
WHERE n:Entity
|
|
339
364
|
"""
|
|
340
|
-
CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})
|
|
341
|
-
YIELD node AS n, score
|
|
342
|
-
WHERE n:Entity
|
|
343
|
-
"""
|
|
344
365
|
+ filter_query
|
|
345
366
|
+ ENTITY_NODE_RETURN
|
|
346
367
|
+ """
|
|
347
368
|
ORDER BY score DESC
|
|
348
369
|
"""
|
|
349
370
|
)
|
|
350
|
-
|
|
351
|
-
records, _, _ = await driver.execute_query(
|
|
371
|
+
records, header, _ = await driver.execute_query(
|
|
352
372
|
query,
|
|
353
|
-
filter_params,
|
|
373
|
+
params=filter_params,
|
|
354
374
|
query=fuzzy_query,
|
|
355
375
|
group_ids=group_ids,
|
|
356
376
|
limit=limit,
|
|
357
377
|
database_=DEFAULT_DATABASE,
|
|
358
378
|
routing_='r',
|
|
359
379
|
)
|
|
380
|
+
if driver.provider == 'falkordb':
|
|
381
|
+
records = [dict(zip(header, row, strict=True)) for row in records]
|
|
382
|
+
|
|
360
383
|
nodes = [get_entity_node_from_record(record) for record in records]
|
|
361
384
|
|
|
362
385
|
return nodes
|
|
363
386
|
|
|
364
387
|
|
|
365
388
|
async def node_similarity_search(
|
|
366
|
-
driver:
|
|
389
|
+
driver: GraphDriver,
|
|
367
390
|
search_vector: list[float],
|
|
368
391
|
search_filter: SearchFilters,
|
|
369
392
|
group_ids: list[str] | None = None,
|
|
@@ -373,30 +396,36 @@ async def node_similarity_search(
|
|
|
373
396
|
# vector similarity search over entity names
|
|
374
397
|
query_params: dict[str, Any] = {}
|
|
375
398
|
|
|
376
|
-
group_filter_query: LiteralString = ''
|
|
399
|
+
group_filter_query: LiteralString = 'WHERE n.group_id IS NOT NULL'
|
|
377
400
|
if group_ids is not None:
|
|
378
|
-
group_filter_query += '
|
|
401
|
+
group_filter_query += ' AND n.group_id IN $group_ids'
|
|
379
402
|
query_params['group_ids'] = group_ids
|
|
380
403
|
|
|
381
404
|
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
|
382
405
|
query_params.update(filter_params)
|
|
383
406
|
|
|
384
|
-
|
|
407
|
+
query = (
|
|
385
408
|
RUNTIME_QUERY
|
|
386
409
|
+ """
|
|
387
|
-
|
|
388
|
-
|
|
410
|
+
MATCH (n:Entity)
|
|
411
|
+
"""
|
|
389
412
|
+ group_filter_query
|
|
390
413
|
+ filter_query
|
|
391
414
|
+ """
|
|
392
|
-
|
|
393
|
-
|
|
415
|
+
WITH n, """
|
|
416
|
+
+ get_vector_cosine_func_query('n.name_embedding', '$search_vector', driver.provider)
|
|
417
|
+
+ """ AS score
|
|
418
|
+
WHERE score > $min_score"""
|
|
394
419
|
+ ENTITY_NODE_RETURN
|
|
395
420
|
+ """
|
|
396
421
|
ORDER BY score DESC
|
|
397
422
|
LIMIT $limit
|
|
398
|
-
|
|
399
|
-
|
|
423
|
+
"""
|
|
424
|
+
)
|
|
425
|
+
|
|
426
|
+
records, header, _ = await driver.execute_query(
|
|
427
|
+
query,
|
|
428
|
+
params=query_params,
|
|
400
429
|
search_vector=search_vector,
|
|
401
430
|
group_ids=group_ids,
|
|
402
431
|
limit=limit,
|
|
@@ -404,13 +433,15 @@ async def node_similarity_search(
|
|
|
404
433
|
database_=DEFAULT_DATABASE,
|
|
405
434
|
routing_='r',
|
|
406
435
|
)
|
|
436
|
+
if driver.provider == 'falkordb':
|
|
437
|
+
records = [dict(zip(header, row, strict=True)) for row in records]
|
|
407
438
|
nodes = [get_entity_node_from_record(record) for record in records]
|
|
408
439
|
|
|
409
440
|
return nodes
|
|
410
441
|
|
|
411
442
|
|
|
412
443
|
async def node_bfs_search(
|
|
413
|
-
driver:
|
|
444
|
+
driver: GraphDriver,
|
|
414
445
|
bfs_origin_node_uuids: list[str] | None,
|
|
415
446
|
search_filter: SearchFilters,
|
|
416
447
|
bfs_max_depth: int,
|
|
@@ -422,18 +453,21 @@ async def node_bfs_search(
|
|
|
422
453
|
|
|
423
454
|
filter_query, filter_params = node_search_filter_query_constructor(search_filter)
|
|
424
455
|
|
|
425
|
-
|
|
456
|
+
query = (
|
|
426
457
|
"""
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
458
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
459
|
+
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
|
460
|
+
WHERE n.group_id = origin.group_id
|
|
461
|
+
"""
|
|
431
462
|
+ filter_query
|
|
432
463
|
+ ENTITY_NODE_RETURN
|
|
433
464
|
+ """
|
|
434
465
|
LIMIT $limit
|
|
435
|
-
"""
|
|
436
|
-
|
|
466
|
+
"""
|
|
467
|
+
)
|
|
468
|
+
records, _, _ = await driver.execute_query(
|
|
469
|
+
query,
|
|
470
|
+
params=filter_params,
|
|
437
471
|
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
|
438
472
|
depth=bfs_max_depth,
|
|
439
473
|
limit=limit,
|
|
@@ -446,7 +480,7 @@ async def node_bfs_search(
|
|
|
446
480
|
|
|
447
481
|
|
|
448
482
|
async def episode_fulltext_search(
|
|
449
|
-
driver:
|
|
483
|
+
driver: GraphDriver,
|
|
450
484
|
query: str,
|
|
451
485
|
_search_filter: SearchFilters,
|
|
452
486
|
group_ids: list[str] | None = None,
|
|
@@ -457,9 +491,9 @@ async def episode_fulltext_search(
|
|
|
457
491
|
if fuzzy_query == '':
|
|
458
492
|
return []
|
|
459
493
|
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
494
|
+
query = (
|
|
495
|
+
get_nodes_query(driver.provider, 'episode_content', '$query')
|
|
496
|
+
+ """
|
|
463
497
|
YIELD node AS episode, score
|
|
464
498
|
MATCH (e:Episodic)
|
|
465
499
|
WHERE e.uuid = episode.uuid
|
|
@@ -475,7 +509,11 @@ async def episode_fulltext_search(
|
|
|
475
509
|
e.entity_edges AS entity_edges
|
|
476
510
|
ORDER BY score DESC
|
|
477
511
|
LIMIT $limit
|
|
478
|
-
"""
|
|
512
|
+
"""
|
|
513
|
+
)
|
|
514
|
+
|
|
515
|
+
records, _, _ = await driver.execute_query(
|
|
516
|
+
query,
|
|
479
517
|
query=fuzzy_query,
|
|
480
518
|
group_ids=group_ids,
|
|
481
519
|
limit=limit,
|
|
@@ -488,7 +526,7 @@ async def episode_fulltext_search(
|
|
|
488
526
|
|
|
489
527
|
|
|
490
528
|
async def community_fulltext_search(
|
|
491
|
-
driver:
|
|
529
|
+
driver: GraphDriver,
|
|
492
530
|
query: str,
|
|
493
531
|
group_ids: list[str] | None = None,
|
|
494
532
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
@@ -498,9 +536,9 @@ async def community_fulltext_search(
|
|
|
498
536
|
if fuzzy_query == '':
|
|
499
537
|
return []
|
|
500
538
|
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
539
|
+
query = (
|
|
540
|
+
get_nodes_query(driver.provider, 'community_name', '$query')
|
|
541
|
+
+ """
|
|
504
542
|
YIELD node AS comm, score
|
|
505
543
|
RETURN
|
|
506
544
|
comm.uuid AS uuid,
|
|
@@ -510,7 +548,11 @@ async def community_fulltext_search(
|
|
|
510
548
|
comm.summary AS summary
|
|
511
549
|
ORDER BY score DESC
|
|
512
550
|
LIMIT $limit
|
|
513
|
-
"""
|
|
551
|
+
"""
|
|
552
|
+
)
|
|
553
|
+
|
|
554
|
+
records, _, _ = await driver.execute_query(
|
|
555
|
+
query,
|
|
514
556
|
query=fuzzy_query,
|
|
515
557
|
group_ids=group_ids,
|
|
516
558
|
limit=limit,
|
|
@@ -523,7 +565,7 @@ async def community_fulltext_search(
|
|
|
523
565
|
|
|
524
566
|
|
|
525
567
|
async def community_similarity_search(
|
|
526
|
-
driver:
|
|
568
|
+
driver: GraphDriver,
|
|
527
569
|
search_vector: list[float],
|
|
528
570
|
group_ids: list[str] | None = None,
|
|
529
571
|
limit=RELEVANT_SCHEMA_LIMIT,
|
|
@@ -537,14 +579,16 @@ async def community_similarity_search(
|
|
|
537
579
|
group_filter_query += 'WHERE comm.group_id IN $group_ids'
|
|
538
580
|
query_params['group_ids'] = group_ids
|
|
539
581
|
|
|
540
|
-
|
|
582
|
+
query = (
|
|
541
583
|
RUNTIME_QUERY
|
|
542
584
|
+ """
|
|
543
585
|
MATCH (comm:Community)
|
|
544
586
|
"""
|
|
545
587
|
+ group_filter_query
|
|
546
588
|
+ """
|
|
547
|
-
WITH comm,
|
|
589
|
+
WITH comm, """
|
|
590
|
+
+ get_vector_cosine_func_query('comm.name_embedding', '$search_vector', driver.provider)
|
|
591
|
+
+ """ AS score
|
|
548
592
|
WHERE score > $min_score
|
|
549
593
|
RETURN
|
|
550
594
|
comm.uuid As uuid,
|
|
@@ -554,7 +598,11 @@ async def community_similarity_search(
|
|
|
554
598
|
comm.summary AS summary
|
|
555
599
|
ORDER BY score DESC
|
|
556
600
|
LIMIT $limit
|
|
557
|
-
"""
|
|
601
|
+
"""
|
|
602
|
+
)
|
|
603
|
+
|
|
604
|
+
records, _, _ = await driver.execute_query(
|
|
605
|
+
query,
|
|
558
606
|
search_vector=search_vector,
|
|
559
607
|
group_ids=group_ids,
|
|
560
608
|
limit=limit,
|
|
@@ -570,7 +618,7 @@ async def community_similarity_search(
|
|
|
570
618
|
async def hybrid_node_search(
|
|
571
619
|
queries: list[str],
|
|
572
620
|
embeddings: list[list[float]],
|
|
573
|
-
driver:
|
|
621
|
+
driver: GraphDriver,
|
|
574
622
|
search_filter: SearchFilters,
|
|
575
623
|
group_ids: list[str] | None = None,
|
|
576
624
|
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
@@ -587,7 +635,7 @@ async def hybrid_node_search(
|
|
|
587
635
|
A list of text queries to search for.
|
|
588
636
|
embeddings : list[list[float]]
|
|
589
637
|
A list of embedding vectors corresponding to the queries. If empty only fulltext search is performed.
|
|
590
|
-
driver :
|
|
638
|
+
driver : GraphDriver
|
|
591
639
|
The Neo4j driver instance for database operations.
|
|
592
640
|
group_ids : list[str] | None, optional
|
|
593
641
|
The list of group ids to retrieve nodes from.
|
|
@@ -642,7 +690,7 @@ async def hybrid_node_search(
|
|
|
642
690
|
|
|
643
691
|
|
|
644
692
|
async def get_relevant_nodes(
|
|
645
|
-
driver:
|
|
693
|
+
driver: GraphDriver,
|
|
646
694
|
nodes: list[EntityNode],
|
|
647
695
|
search_filter: SearchFilters,
|
|
648
696
|
min_score: float = DEFAULT_MIN_SCORE,
|
|
@@ -661,29 +709,33 @@ async def get_relevant_nodes(
|
|
|
661
709
|
|
|
662
710
|
query = (
|
|
663
711
|
RUNTIME_QUERY
|
|
664
|
-
+ """
|
|
665
|
-
|
|
666
|
-
|
|
712
|
+
+ """
|
|
713
|
+
UNWIND $nodes AS node
|
|
714
|
+
MATCH (n:Entity {group_id: $group_id})
|
|
715
|
+
"""
|
|
667
716
|
+ filter_query
|
|
668
717
|
+ """
|
|
669
|
-
WITH node, n,
|
|
718
|
+
WITH node, n, """
|
|
719
|
+
+ get_vector_cosine_func_query('n.name_embedding', 'node.name_embedding', driver.provider)
|
|
720
|
+
+ """ AS score
|
|
670
721
|
WHERE score > $min_score
|
|
671
722
|
WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
|
|
672
|
-
|
|
673
|
-
|
|
723
|
+
"""
|
|
724
|
+
+ get_nodes_query(driver.provider, 'node_name_and_summary', 'node.fulltext_query')
|
|
725
|
+
+ """
|
|
674
726
|
YIELD node AS m
|
|
675
727
|
WHERE m.group_id = $group_id
|
|
676
728
|
WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
|
|
677
|
-
|
|
729
|
+
|
|
678
730
|
WITH node,
|
|
679
731
|
top_vector_nodes,
|
|
680
732
|
[m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
|
|
681
|
-
|
|
733
|
+
|
|
682
734
|
WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
|
|
683
|
-
|
|
735
|
+
|
|
684
736
|
UNWIND combined_nodes AS combined_node
|
|
685
737
|
WITH node, collect(DISTINCT combined_node) AS deduped_nodes
|
|
686
|
-
|
|
738
|
+
|
|
687
739
|
RETURN
|
|
688
740
|
node.uuid AS search_node_uuid,
|
|
689
741
|
[x IN deduped_nodes | {
|
|
@@ -711,7 +763,7 @@ async def get_relevant_nodes(
|
|
|
711
763
|
|
|
712
764
|
results, _, _ = await driver.execute_query(
|
|
713
765
|
query,
|
|
714
|
-
query_params,
|
|
766
|
+
params=query_params,
|
|
715
767
|
nodes=query_nodes,
|
|
716
768
|
group_id=group_id,
|
|
717
769
|
limit=limit,
|
|
@@ -733,7 +785,7 @@ async def get_relevant_nodes(
|
|
|
733
785
|
|
|
734
786
|
|
|
735
787
|
async def get_relevant_edges(
|
|
736
|
-
driver:
|
|
788
|
+
driver: GraphDriver,
|
|
737
789
|
edges: list[EntityEdge],
|
|
738
790
|
search_filter: SearchFilters,
|
|
739
791
|
min_score: float = DEFAULT_MIN_SCORE,
|
|
@@ -749,42 +801,47 @@ async def get_relevant_edges(
|
|
|
749
801
|
|
|
750
802
|
query = (
|
|
751
803
|
RUNTIME_QUERY
|
|
752
|
-
+ """
|
|
753
|
-
|
|
754
|
-
|
|
804
|
+
+ """
|
|
805
|
+
UNWIND $edges AS edge
|
|
806
|
+
MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
|
|
807
|
+
"""
|
|
755
808
|
+ filter_query
|
|
756
809
|
+ """
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
810
|
+
WITH e, edge, """
|
|
811
|
+
+ get_vector_cosine_func_query('e.fact_embedding', 'edge.fact_embedding', driver.provider)
|
|
812
|
+
+ """ AS score
|
|
813
|
+
WHERE score > $min_score
|
|
814
|
+
WITH edge, e, score
|
|
815
|
+
ORDER BY score DESC
|
|
816
|
+
RETURN edge.uuid AS search_edge_uuid,
|
|
817
|
+
collect({
|
|
818
|
+
uuid: e.uuid,
|
|
819
|
+
source_node_uuid: startNode(e).uuid,
|
|
820
|
+
target_node_uuid: endNode(e).uuid,
|
|
821
|
+
created_at: e.created_at,
|
|
822
|
+
name: e.name,
|
|
823
|
+
group_id: e.group_id,
|
|
824
|
+
fact: e.fact,
|
|
825
|
+
fact_embedding: e.fact_embedding,
|
|
826
|
+
episodes: e.episodes,
|
|
827
|
+
expired_at: e.expired_at,
|
|
828
|
+
valid_at: e.valid_at,
|
|
829
|
+
invalid_at: e.invalid_at,
|
|
830
|
+
attributes: properties(e)
|
|
831
|
+
})[..$limit] AS matches
|
|
776
832
|
"""
|
|
777
833
|
)
|
|
778
834
|
|
|
779
835
|
results, _, _ = await driver.execute_query(
|
|
780
836
|
query,
|
|
781
|
-
query_params,
|
|
837
|
+
params=query_params,
|
|
782
838
|
edges=[edge.model_dump() for edge in edges],
|
|
783
839
|
limit=limit,
|
|
784
840
|
min_score=min_score,
|
|
785
841
|
database_=DEFAULT_DATABASE,
|
|
786
842
|
routing_='r',
|
|
787
843
|
)
|
|
844
|
+
|
|
788
845
|
relevant_edges_dict: dict[str, list[EntityEdge]] = {
|
|
789
846
|
result['search_edge_uuid']: [
|
|
790
847
|
get_entity_edge_from_record(record) for record in result['matches']
|
|
@@ -798,7 +855,7 @@ async def get_relevant_edges(
|
|
|
798
855
|
|
|
799
856
|
|
|
800
857
|
async def get_edge_invalidation_candidates(
|
|
801
|
-
driver:
|
|
858
|
+
driver: GraphDriver,
|
|
802
859
|
edges: list[EntityEdge],
|
|
803
860
|
search_filter: SearchFilters,
|
|
804
861
|
min_score: float = DEFAULT_MIN_SCORE,
|
|
@@ -814,37 +871,41 @@ async def get_edge_invalidation_candidates(
|
|
|
814
871
|
|
|
815
872
|
query = (
|
|
816
873
|
RUNTIME_QUERY
|
|
817
|
-
+ """
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
874
|
+
+ """
|
|
875
|
+
UNWIND $edges AS edge
|
|
876
|
+
MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
|
|
877
|
+
WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
|
|
878
|
+
"""
|
|
821
879
|
+ filter_query
|
|
822
880
|
+ """
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
881
|
+
WITH edge, e, """
|
|
882
|
+
+ get_vector_cosine_func_query('e.fact_embedding', 'edge.fact_embedding', driver.provider)
|
|
883
|
+
+ """ AS score
|
|
884
|
+
WHERE score > $min_score
|
|
885
|
+
WITH edge, e, score
|
|
886
|
+
ORDER BY score DESC
|
|
887
|
+
RETURN edge.uuid AS search_edge_uuid,
|
|
888
|
+
collect({
|
|
889
|
+
uuid: e.uuid,
|
|
890
|
+
source_node_uuid: startNode(e).uuid,
|
|
891
|
+
target_node_uuid: endNode(e).uuid,
|
|
892
|
+
created_at: e.created_at,
|
|
893
|
+
name: e.name,
|
|
894
|
+
group_id: e.group_id,
|
|
895
|
+
fact: e.fact,
|
|
896
|
+
fact_embedding: e.fact_embedding,
|
|
897
|
+
episodes: e.episodes,
|
|
898
|
+
expired_at: e.expired_at,
|
|
899
|
+
valid_at: e.valid_at,
|
|
900
|
+
invalid_at: e.invalid_at,
|
|
901
|
+
attributes: properties(e)
|
|
902
|
+
})[..$limit] AS matches
|
|
842
903
|
"""
|
|
843
904
|
)
|
|
844
905
|
|
|
845
906
|
results, _, _ = await driver.execute_query(
|
|
846
907
|
query,
|
|
847
|
-
query_params,
|
|
908
|
+
params=query_params,
|
|
848
909
|
edges=[edge.model_dump() for edge in edges],
|
|
849
910
|
limit=limit,
|
|
850
911
|
min_score=min_score,
|
|
@@ -879,7 +940,7 @@ def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[st
|
|
|
879
940
|
|
|
880
941
|
|
|
881
942
|
async def node_distance_reranker(
|
|
882
|
-
driver:
|
|
943
|
+
driver: GraphDriver,
|
|
883
944
|
node_uuids: list[str],
|
|
884
945
|
center_node_uuid: str,
|
|
885
946
|
min_score: float = 0,
|
|
@@ -889,21 +950,22 @@ async def node_distance_reranker(
|
|
|
889
950
|
scores: dict[str, float] = {center_node_uuid: 0.0}
|
|
890
951
|
|
|
891
952
|
# Find the shortest path to center node
|
|
892
|
-
query =
|
|
953
|
+
query = """
|
|
893
954
|
UNWIND $node_uuids AS node_uuid
|
|
894
|
-
MATCH
|
|
895
|
-
RETURN
|
|
896
|
-
"""
|
|
897
|
-
|
|
898
|
-
path_results, _, _ = await driver.execute_query(
|
|
955
|
+
MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
|
|
956
|
+
RETURN 1 AS score, node_uuid AS uuid
|
|
957
|
+
"""
|
|
958
|
+
results, header, _ = await driver.execute_query(
|
|
899
959
|
query,
|
|
900
960
|
node_uuids=filtered_uuids,
|
|
901
961
|
center_uuid=center_node_uuid,
|
|
902
962
|
database_=DEFAULT_DATABASE,
|
|
903
963
|
routing_='r',
|
|
904
964
|
)
|
|
965
|
+
if driver.provider == 'falkordb':
|
|
966
|
+
results = [dict(zip(header, row, strict=True)) for row in results]
|
|
905
967
|
|
|
906
|
-
for result in
|
|
968
|
+
for result in results:
|
|
907
969
|
uuid = result['uuid']
|
|
908
970
|
score = result['score']
|
|
909
971
|
scores[uuid] = score
|
|
@@ -924,19 +986,18 @@ async def node_distance_reranker(
|
|
|
924
986
|
|
|
925
987
|
|
|
926
988
|
async def episode_mentions_reranker(
|
|
927
|
-
driver:
|
|
989
|
+
driver: GraphDriver, node_uuids: list[list[str]], min_score: float = 0
|
|
928
990
|
) -> list[str]:
|
|
929
991
|
# use rrf as a preliminary ranker
|
|
930
992
|
sorted_uuids = rrf(node_uuids)
|
|
931
993
|
scores: dict[str, float] = {}
|
|
932
994
|
|
|
933
995
|
# Find the shortest path to center node
|
|
934
|
-
query =
|
|
996
|
+
query = """
|
|
935
997
|
UNWIND $node_uuids AS node_uuid
|
|
936
998
|
MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
|
|
937
999
|
RETURN count(*) AS score, n.uuid AS uuid
|
|
938
|
-
"""
|
|
939
|
-
|
|
1000
|
+
"""
|
|
940
1001
|
results, _, _ = await driver.execute_query(
|
|
941
1002
|
query,
|
|
942
1003
|
node_uuids=sorted_uuids,
|
|
@@ -993,7 +1054,7 @@ def maximal_marginal_relevance(
|
|
|
993
1054
|
|
|
994
1055
|
|
|
995
1056
|
async def get_embeddings_for_nodes(
|
|
996
|
-
driver:
|
|
1057
|
+
driver: GraphDriver, nodes: list[EntityNode]
|
|
997
1058
|
) -> dict[str, list[float]]:
|
|
998
1059
|
query: LiteralString = """MATCH (n:Entity)
|
|
999
1060
|
WHERE n.uuid IN $node_uuids
|
|
@@ -1017,7 +1078,7 @@ async def get_embeddings_for_nodes(
|
|
|
1017
1078
|
|
|
1018
1079
|
|
|
1019
1080
|
async def get_embeddings_for_communities(
|
|
1020
|
-
driver:
|
|
1081
|
+
driver: GraphDriver, communities: list[CommunityNode]
|
|
1021
1082
|
) -> dict[str, list[float]]:
|
|
1022
1083
|
query: LiteralString = """MATCH (c:Community)
|
|
1023
1084
|
WHERE c.uuid IN $community_uuids
|
|
@@ -1044,7 +1105,7 @@ async def get_embeddings_for_communities(
|
|
|
1044
1105
|
|
|
1045
1106
|
|
|
1046
1107
|
async def get_embeddings_for_edges(
|
|
1047
|
-
driver:
|
|
1108
|
+
driver: GraphDriver, edges: list[EntityEdge]
|
|
1048
1109
|
) -> dict[str, list[float]]:
|
|
1049
1110
|
query: LiteralString = """MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
|
|
1050
1111
|
WHERE e.uuid IN $edge_uuids
|