graphiti-core 0.17.10__py3-none-any.whl → 0.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of graphiti-core might be problematic. Click here for more details.
- graphiti_core/graphiti.py +1 -1
- graphiti_core/search/search.py +68 -36
- graphiti_core/search/search_config.py +8 -4
- graphiti_core/search/search_utils.py +40 -22
- {graphiti_core-0.17.10.dist-info → graphiti_core-0.18.0.dist-info}/METADATA +1 -1
- {graphiti_core-0.17.10.dist-info → graphiti_core-0.18.0.dist-info}/RECORD +8 -8
- {graphiti_core-0.17.10.dist-info → graphiti_core-0.18.0.dist-info}/WHEEL +0 -0
- {graphiti_core-0.17.10.dist-info → graphiti_core-0.18.0.dist-info}/licenses/LICENSE +0 -0
graphiti_core/graphiti.py
CHANGED
|
@@ -959,7 +959,7 @@ class Graphiti:
|
|
|
959
959
|
|
|
960
960
|
nodes = await get_mentioned_nodes(self.driver, episodes)
|
|
961
961
|
|
|
962
|
-
return SearchResults(edges=edges, nodes=nodes, episodes=[], communities=[])
|
|
962
|
+
return SearchResults(edges=edges, nodes=nodes)
|
|
963
963
|
|
|
964
964
|
async def add_triplet(self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode):
|
|
965
965
|
if source_node.name_embedding is None:
|
graphiti_core/search/search.py
CHANGED
|
@@ -80,12 +80,7 @@ async def search(
|
|
|
80
80
|
cross_encoder = clients.cross_encoder
|
|
81
81
|
|
|
82
82
|
if query.strip() == '':
|
|
83
|
-
return SearchResults(
|
|
84
|
-
edges=[],
|
|
85
|
-
nodes=[],
|
|
86
|
-
episodes=[],
|
|
87
|
-
communities=[],
|
|
88
|
-
)
|
|
83
|
+
return SearchResults()
|
|
89
84
|
query_vector = (
|
|
90
85
|
query_vector
|
|
91
86
|
if query_vector is not None
|
|
@@ -94,7 +89,12 @@ async def search(
|
|
|
94
89
|
|
|
95
90
|
# if group_ids is empty, set it to None
|
|
96
91
|
group_ids = group_ids if group_ids and group_ids != [''] else None
|
|
97
|
-
|
|
92
|
+
(
|
|
93
|
+
(edges, edge_reranker_scores),
|
|
94
|
+
(nodes, node_reranker_scores),
|
|
95
|
+
(episodes, episode_reranker_scores),
|
|
96
|
+
(communities, community_reranker_scores),
|
|
97
|
+
) = await semaphore_gather(
|
|
98
98
|
edge_search(
|
|
99
99
|
driver,
|
|
100
100
|
cross_encoder,
|
|
@@ -146,9 +146,13 @@ async def search(
|
|
|
146
146
|
|
|
147
147
|
results = SearchResults(
|
|
148
148
|
edges=edges,
|
|
149
|
+
edge_reranker_scores=edge_reranker_scores,
|
|
149
150
|
nodes=nodes,
|
|
151
|
+
node_reranker_scores=node_reranker_scores,
|
|
150
152
|
episodes=episodes,
|
|
153
|
+
episode_reranker_scores=episode_reranker_scores,
|
|
151
154
|
communities=communities,
|
|
155
|
+
community_reranker_scores=community_reranker_scores,
|
|
152
156
|
)
|
|
153
157
|
|
|
154
158
|
latency = (time() - start) * 1000
|
|
@@ -170,9 +174,9 @@ async def edge_search(
|
|
|
170
174
|
bfs_origin_node_uuids: list[str] | None = None,
|
|
171
175
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
172
176
|
reranker_min_score: float = 0,
|
|
173
|
-
) -> list[EntityEdge]:
|
|
177
|
+
) -> tuple[list[EntityEdge], list[float]]:
|
|
174
178
|
if config is None:
|
|
175
|
-
return []
|
|
179
|
+
return [], []
|
|
176
180
|
search_results: list[list[EntityEdge]] = list(
|
|
177
181
|
await semaphore_gather(
|
|
178
182
|
*[
|
|
@@ -188,7 +192,12 @@ async def edge_search(
|
|
|
188
192
|
config.sim_min_score,
|
|
189
193
|
),
|
|
190
194
|
edge_bfs_search(
|
|
191
|
-
driver,
|
|
195
|
+
driver,
|
|
196
|
+
bfs_origin_node_uuids,
|
|
197
|
+
config.bfs_max_depth,
|
|
198
|
+
search_filter,
|
|
199
|
+
group_ids,
|
|
200
|
+
2 * limit,
|
|
192
201
|
),
|
|
193
202
|
]
|
|
194
203
|
)
|
|
@@ -198,22 +207,28 @@ async def edge_search(
|
|
|
198
207
|
source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result]
|
|
199
208
|
search_results.append(
|
|
200
209
|
await edge_bfs_search(
|
|
201
|
-
driver,
|
|
210
|
+
driver,
|
|
211
|
+
source_node_uuids,
|
|
212
|
+
config.bfs_max_depth,
|
|
213
|
+
search_filter,
|
|
214
|
+
group_ids,
|
|
215
|
+
2 * limit,
|
|
202
216
|
)
|
|
203
217
|
)
|
|
204
218
|
|
|
205
219
|
edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
|
|
206
220
|
|
|
207
221
|
reranked_uuids: list[str] = []
|
|
222
|
+
edge_scores: list[float] = []
|
|
208
223
|
if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
|
|
209
224
|
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
|
|
210
225
|
|
|
211
|
-
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
226
|
+
reranked_uuids, edge_scores = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
212
227
|
elif config.reranker == EdgeReranker.mmr:
|
|
213
228
|
search_result_uuids_and_vectors = await get_embeddings_for_edges(
|
|
214
229
|
driver, list(edge_uuid_map.values())
|
|
215
230
|
)
|
|
216
|
-
reranked_uuids = maximal_marginal_relevance(
|
|
231
|
+
reranked_uuids, edge_scores = maximal_marginal_relevance(
|
|
217
232
|
query_vector,
|
|
218
233
|
search_result_uuids_and_vectors,
|
|
219
234
|
config.mmr_lambda,
|
|
@@ -225,12 +240,13 @@ async def edge_search(
|
|
|
225
240
|
reranked_uuids = [
|
|
226
241
|
fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
|
|
227
242
|
]
|
|
243
|
+
edge_scores = [score for _, score in reranked_facts if score >= reranker_min_score]
|
|
228
244
|
elif config.reranker == EdgeReranker.node_distance:
|
|
229
245
|
if center_node_uuid is None:
|
|
230
246
|
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
|
231
247
|
|
|
232
248
|
# use rrf as a preliminary sort
|
|
233
|
-
sorted_result_uuids = rrf(
|
|
249
|
+
sorted_result_uuids, node_scores = rrf(
|
|
234
250
|
[[edge.uuid for edge in result] for result in search_results],
|
|
235
251
|
min_score=reranker_min_score,
|
|
236
252
|
)
|
|
@@ -243,7 +259,7 @@ async def edge_search(
|
|
|
243
259
|
|
|
244
260
|
source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map]
|
|
245
261
|
|
|
246
|
-
reranked_node_uuids = await node_distance_reranker(
|
|
262
|
+
reranked_node_uuids, edge_scores = await node_distance_reranker(
|
|
247
263
|
driver, source_uuids, center_node_uuid, min_score=reranker_min_score
|
|
248
264
|
)
|
|
249
265
|
|
|
@@ -255,7 +271,7 @@ async def edge_search(
|
|
|
255
271
|
if config.reranker == EdgeReranker.episode_mentions:
|
|
256
272
|
reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes))
|
|
257
273
|
|
|
258
|
-
return reranked_edges[:limit]
|
|
274
|
+
return reranked_edges[:limit], edge_scores[:limit]
|
|
259
275
|
|
|
260
276
|
|
|
261
277
|
async def node_search(
|
|
@@ -270,9 +286,9 @@ async def node_search(
|
|
|
270
286
|
bfs_origin_node_uuids: list[str] | None = None,
|
|
271
287
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
272
288
|
reranker_min_score: float = 0,
|
|
273
|
-
) -> list[EntityNode]:
|
|
289
|
+
) -> tuple[list[EntityNode], list[float]]:
|
|
274
290
|
if config is None:
|
|
275
|
-
return []
|
|
291
|
+
return [], []
|
|
276
292
|
search_results: list[list[EntityNode]] = list(
|
|
277
293
|
await semaphore_gather(
|
|
278
294
|
*[
|
|
@@ -281,7 +297,12 @@ async def node_search(
|
|
|
281
297
|
driver, query_vector, search_filter, group_ids, 2 * limit, config.sim_min_score
|
|
282
298
|
),
|
|
283
299
|
node_bfs_search(
|
|
284
|
-
driver,
|
|
300
|
+
driver,
|
|
301
|
+
bfs_origin_node_uuids,
|
|
302
|
+
search_filter,
|
|
303
|
+
config.bfs_max_depth,
|
|
304
|
+
group_ids,
|
|
305
|
+
2 * limit,
|
|
285
306
|
),
|
|
286
307
|
]
|
|
287
308
|
)
|
|
@@ -291,7 +312,12 @@ async def node_search(
|
|
|
291
312
|
origin_node_uuids = [node.uuid for result in search_results for node in result]
|
|
292
313
|
search_results.append(
|
|
293
314
|
await node_bfs_search(
|
|
294
|
-
driver,
|
|
315
|
+
driver,
|
|
316
|
+
origin_node_uuids,
|
|
317
|
+
search_filter,
|
|
318
|
+
config.bfs_max_depth,
|
|
319
|
+
group_ids,
|
|
320
|
+
2 * limit,
|
|
295
321
|
)
|
|
296
322
|
)
|
|
297
323
|
|
|
@@ -299,14 +325,15 @@ async def node_search(
|
|
|
299
325
|
node_uuid_map = {node.uuid: node for result in search_results for node in result}
|
|
300
326
|
|
|
301
327
|
reranked_uuids: list[str] = []
|
|
328
|
+
node_scores: list[float] = []
|
|
302
329
|
if config.reranker == NodeReranker.rrf:
|
|
303
|
-
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
330
|
+
reranked_uuids, node_scores = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
304
331
|
elif config.reranker == NodeReranker.mmr:
|
|
305
332
|
search_result_uuids_and_vectors = await get_embeddings_for_nodes(
|
|
306
333
|
driver, list(node_uuid_map.values())
|
|
307
334
|
)
|
|
308
335
|
|
|
309
|
-
reranked_uuids = maximal_marginal_relevance(
|
|
336
|
+
reranked_uuids, node_scores = maximal_marginal_relevance(
|
|
310
337
|
query_vector,
|
|
311
338
|
search_result_uuids_and_vectors,
|
|
312
339
|
config.mmr_lambda,
|
|
@@ -321,23 +348,24 @@ async def node_search(
|
|
|
321
348
|
for name, score in reranked_node_names
|
|
322
349
|
if score >= reranker_min_score
|
|
323
350
|
]
|
|
351
|
+
node_scores = [score for _, score in reranked_node_names if score >= reranker_min_score]
|
|
324
352
|
elif config.reranker == NodeReranker.episode_mentions:
|
|
325
|
-
reranked_uuids = await episode_mentions_reranker(
|
|
353
|
+
reranked_uuids, node_scores = await episode_mentions_reranker(
|
|
326
354
|
driver, search_result_uuids, min_score=reranker_min_score
|
|
327
355
|
)
|
|
328
356
|
elif config.reranker == NodeReranker.node_distance:
|
|
329
357
|
if center_node_uuid is None:
|
|
330
358
|
raise SearchRerankerError('No center node provided for Node Distance reranker')
|
|
331
|
-
reranked_uuids = await node_distance_reranker(
|
|
359
|
+
reranked_uuids, node_scores = await node_distance_reranker(
|
|
332
360
|
driver,
|
|
333
|
-
rrf(search_result_uuids, min_score=reranker_min_score),
|
|
361
|
+
rrf(search_result_uuids, min_score=reranker_min_score)[0],
|
|
334
362
|
center_node_uuid,
|
|
335
363
|
min_score=reranker_min_score,
|
|
336
364
|
)
|
|
337
365
|
|
|
338
366
|
reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
|
|
339
367
|
|
|
340
|
-
return reranked_nodes[:limit]
|
|
368
|
+
return reranked_nodes[:limit], node_scores[:limit]
|
|
341
369
|
|
|
342
370
|
|
|
343
371
|
async def episode_search(
|
|
@@ -350,9 +378,9 @@ async def episode_search(
|
|
|
350
378
|
search_filter: SearchFilters,
|
|
351
379
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
352
380
|
reranker_min_score: float = 0,
|
|
353
|
-
) -> list[EpisodicNode]:
|
|
381
|
+
) -> tuple[list[EpisodicNode], list[float]]:
|
|
354
382
|
if config is None:
|
|
355
|
-
return []
|
|
383
|
+
return [], []
|
|
356
384
|
search_results: list[list[EpisodicNode]] = list(
|
|
357
385
|
await semaphore_gather(
|
|
358
386
|
*[
|
|
@@ -365,12 +393,13 @@ async def episode_search(
|
|
|
365
393
|
episode_uuid_map = {episode.uuid: episode for result in search_results for episode in result}
|
|
366
394
|
|
|
367
395
|
reranked_uuids: list[str] = []
|
|
396
|
+
episode_scores: list[float] = []
|
|
368
397
|
if config.reranker == EpisodeReranker.rrf:
|
|
369
|
-
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
398
|
+
reranked_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
370
399
|
|
|
371
400
|
elif config.reranker == EpisodeReranker.cross_encoder:
|
|
372
401
|
# use rrf as a preliminary reranker
|
|
373
|
-
rrf_result_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
402
|
+
rrf_result_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
374
403
|
rrf_results = [episode_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
|
|
375
404
|
|
|
376
405
|
content_to_uuid_map = {episode.content: episode.uuid for episode in rrf_results}
|
|
@@ -381,10 +410,11 @@ async def episode_search(
|
|
|
381
410
|
for content, score in reranked_contents
|
|
382
411
|
if score >= reranker_min_score
|
|
383
412
|
]
|
|
413
|
+
episode_scores = [score for _, score in reranked_contents if score >= reranker_min_score]
|
|
384
414
|
|
|
385
415
|
reranked_episodes = [episode_uuid_map[uuid] for uuid in reranked_uuids]
|
|
386
416
|
|
|
387
|
-
return reranked_episodes[:limit]
|
|
417
|
+
return reranked_episodes[:limit], episode_scores[:limit]
|
|
388
418
|
|
|
389
419
|
|
|
390
420
|
async def community_search(
|
|
@@ -396,9 +426,9 @@ async def community_search(
|
|
|
396
426
|
config: CommunitySearchConfig | None,
|
|
397
427
|
limit=DEFAULT_SEARCH_LIMIT,
|
|
398
428
|
reranker_min_score: float = 0,
|
|
399
|
-
) -> list[CommunityNode]:
|
|
429
|
+
) -> tuple[list[CommunityNode], list[float]]:
|
|
400
430
|
if config is None:
|
|
401
|
-
return []
|
|
431
|
+
return [], []
|
|
402
432
|
|
|
403
433
|
search_results: list[list[CommunityNode]] = list(
|
|
404
434
|
await semaphore_gather(
|
|
@@ -417,14 +447,15 @@ async def community_search(
|
|
|
417
447
|
}
|
|
418
448
|
|
|
419
449
|
reranked_uuids: list[str] = []
|
|
450
|
+
community_scores: list[float] = []
|
|
420
451
|
if config.reranker == CommunityReranker.rrf:
|
|
421
|
-
reranked_uuids = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
452
|
+
reranked_uuids, community_scores = rrf(search_result_uuids, min_score=reranker_min_score)
|
|
422
453
|
elif config.reranker == CommunityReranker.mmr:
|
|
423
454
|
search_result_uuids_and_vectors = await get_embeddings_for_communities(
|
|
424
455
|
driver, list(community_uuid_map.values())
|
|
425
456
|
)
|
|
426
457
|
|
|
427
|
-
reranked_uuids = maximal_marginal_relevance(
|
|
458
|
+
reranked_uuids, community_scores = maximal_marginal_relevance(
|
|
428
459
|
query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score
|
|
429
460
|
)
|
|
430
461
|
elif config.reranker == CommunityReranker.cross_encoder:
|
|
@@ -433,7 +464,8 @@ async def community_search(
|
|
|
433
464
|
reranked_uuids = [
|
|
434
465
|
name_to_uuid_map[name] for name, score in reranked_nodes if score >= reranker_min_score
|
|
435
466
|
]
|
|
467
|
+
community_scores = [score for _, score in reranked_nodes if score >= reranker_min_score]
|
|
436
468
|
|
|
437
469
|
reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
|
|
438
470
|
|
|
439
|
-
return reranked_communities[:limit]
|
|
471
|
+
return reranked_communities[:limit], community_scores[:limit]
|
|
graphiti_core/search/search_config.py
CHANGED
|
@@ -119,7 +119,11 @@ class SearchConfig(BaseModel):
|
|
|
119
119
|
|
|
120
120
|
|
|
121
121
|
class SearchResults(BaseModel):
|
|
122
|
-
edges: list[EntityEdge]
|
|
123
|
-
nodes: list[EntityNode]
|
|
124
|
-
episodes: list[EpisodicNode]
|
|
125
|
-
communities: list[CommunityNode]
|
|
122
|
+
edges: list[EntityEdge] = Field(default_factory=list)
|
|
123
|
+
edge_reranker_scores: list[float] = Field(default_factory=list)
|
|
124
|
+
nodes: list[EntityNode] = Field(default_factory=list)
|
|
125
|
+
node_reranker_scores: list[float] = Field(default_factory=list)
|
|
126
|
+
episodes: list[EpisodicNode] = Field(default_factory=list)
|
|
127
|
+
episode_reranker_scores: list[float] = Field(default_factory=list)
|
|
128
|
+
communities: list[CommunityNode] = Field(default_factory=list)
|
|
129
|
+
community_reranker_scores: list[float] = Field(default_factory=list)
|
|
graphiti_core/search/search_utils.py
CHANGED
|
@@ -283,7 +283,8 @@ async def edge_bfs_search(
|
|
|
283
283
|
bfs_origin_node_uuids: list[str] | None,
|
|
284
284
|
bfs_max_depth: int,
|
|
285
285
|
search_filter: SearchFilters,
|
|
286
|
-
|
|
286
|
+
group_ids: list[str] | None = None,
|
|
287
|
+
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
287
288
|
) -> list[EntityEdge]:
|
|
288
289
|
# vector similarity search over embedded facts
|
|
289
290
|
if bfs_origin_node_uuids is None:
|
|
@@ -293,12 +294,13 @@ async def edge_bfs_search(
|
|
|
293
294
|
|
|
294
295
|
query = (
|
|
295
296
|
"""
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
297
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
298
|
+
MATCH path = (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
|
299
|
+
UNWIND relationships(path) AS rel
|
|
300
|
+
MATCH (n:Entity)-[r:RELATES_TO]-(m:Entity)
|
|
301
|
+
WHERE r.uuid = rel.uuid
|
|
302
|
+
AND r.group_id IN $group_ids
|
|
303
|
+
"""
|
|
302
304
|
+ filter_query
|
|
303
305
|
+ """
|
|
304
306
|
RETURN DISTINCT
|
|
@@ -323,6 +325,7 @@ async def edge_bfs_search(
|
|
|
323
325
|
params=filter_params,
|
|
324
326
|
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
|
325
327
|
depth=bfs_max_depth,
|
|
328
|
+
group_ids=group_ids,
|
|
326
329
|
limit=limit,
|
|
327
330
|
routing_='r',
|
|
328
331
|
)
|
|
@@ -431,7 +434,8 @@ async def node_bfs_search(
|
|
|
431
434
|
bfs_origin_node_uuids: list[str] | None,
|
|
432
435
|
search_filter: SearchFilters,
|
|
433
436
|
bfs_max_depth: int,
|
|
434
|
-
|
|
437
|
+
group_ids: list[str] | None = None,
|
|
438
|
+
limit: int = RELEVANT_SCHEMA_LIMIT,
|
|
435
439
|
) -> list[EntityNode]:
|
|
436
440
|
# vector similarity search over entity names
|
|
437
441
|
if bfs_origin_node_uuids is None:
|
|
@@ -441,10 +445,11 @@ async def node_bfs_search(
|
|
|
441
445
|
|
|
442
446
|
query = (
|
|
443
447
|
"""
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
+
UNWIND $bfs_origin_node_uuids AS origin_uuid
|
|
449
|
+
MATCH (origin:Entity|Episodic {uuid: origin_uuid})-[:RELATES_TO|MENTIONS]->{1,3}(n:Entity)
|
|
450
|
+
WHERE n.group_id = origin.group_id
|
|
451
|
+
AND origin.group_id IN $group_ids
|
|
452
|
+
"""
|
|
448
453
|
+ filter_query
|
|
449
454
|
+ ENTITY_NODE_RETURN
|
|
450
455
|
+ """
|
|
@@ -456,6 +461,7 @@ async def node_bfs_search(
|
|
|
456
461
|
params=filter_params,
|
|
457
462
|
bfs_origin_node_uuids=bfs_origin_node_uuids,
|
|
458
463
|
depth=bfs_max_depth,
|
|
464
|
+
group_ids=group_ids,
|
|
459
465
|
limit=limit,
|
|
460
466
|
routing_='r',
|
|
461
467
|
)
|
|
@@ -482,6 +488,7 @@ async def episode_fulltext_search(
|
|
|
482
488
|
YIELD node AS episode, score
|
|
483
489
|
MATCH (e:Episodic)
|
|
484
490
|
WHERE e.uuid = episode.uuid
|
|
491
|
+
AND e.group_id IN $group_ids
|
|
485
492
|
RETURN
|
|
486
493
|
e.content AS content,
|
|
487
494
|
e.created_at AS created_at,
|
|
@@ -524,6 +531,7 @@ async def community_fulltext_search(
|
|
|
524
531
|
get_nodes_query(driver.provider, 'community_name', '$query')
|
|
525
532
|
+ """
|
|
526
533
|
YIELD node AS comm, score
|
|
534
|
+
WHERE comm.group_id IN $group_ids
|
|
527
535
|
RETURN
|
|
528
536
|
comm.uuid AS uuid,
|
|
529
537
|
comm.group_id AS group_id,
|
|
@@ -664,7 +672,7 @@ async def hybrid_node_search(
|
|
|
664
672
|
}
|
|
665
673
|
result_uuids = [[node.uuid for node in result] for result in results]
|
|
666
674
|
|
|
667
|
-
ranked_uuids = rrf(result_uuids)
|
|
675
|
+
ranked_uuids, _ = rrf(result_uuids)
|
|
668
676
|
|
|
669
677
|
relevant_nodes: list[EntityNode] = [node_uuid_map[uuid] for uuid in ranked_uuids]
|
|
670
678
|
|
|
@@ -906,7 +914,9 @@ async def get_edge_invalidation_candidates(
|
|
|
906
914
|
|
|
907
915
|
|
|
908
916
|
# takes in a list of rankings of uuids
|
|
909
|
-
def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[str]:
|
|
917
|
+
def rrf(
|
|
918
|
+
results: list[list[str]], rank_const=1, min_score: float = 0
|
|
919
|
+
) -> tuple[list[str], list[float]]:
|
|
910
920
|
scores: dict[str, float] = defaultdict(float)
|
|
911
921
|
for result in results:
|
|
912
922
|
for i, uuid in enumerate(result):
|
|
@@ -917,7 +927,9 @@ def rrf(results: list[list[str]], rank_const=1, min_score: float = 0) -> list[st
|
|
|
917
927
|
|
|
918
928
|
sorted_uuids = [term[0] for term in scored_uuids]
|
|
919
929
|
|
|
920
|
-
return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
|
|
930
|
+
return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [
|
|
931
|
+
scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score
|
|
932
|
+
]
|
|
921
933
|
|
|
922
934
|
|
|
923
935
|
async def node_distance_reranker(
|
|
@@ -925,7 +937,7 @@ async def node_distance_reranker(
|
|
|
925
937
|
node_uuids: list[str],
|
|
926
938
|
center_node_uuid: str,
|
|
927
939
|
min_score: float = 0,
|
|
928
|
-
) -> list[str]:
|
|
940
|
+
) -> tuple[list[str], list[float]]:
|
|
929
941
|
# filter out node_uuid center node node uuid
|
|
930
942
|
filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
|
|
931
943
|
scores: dict[str, float] = {center_node_uuid: 0.0}
|
|
@@ -962,14 +974,16 @@ async def node_distance_reranker(
|
|
|
962
974
|
scores[center_node_uuid] = 0.1
|
|
963
975
|
filtered_uuids = [center_node_uuid] + filtered_uuids
|
|
964
976
|
|
|
965
|
-
return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score]
|
|
977
|
+
return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score], [
|
|
978
|
+
1 / scores[uuid] for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score
|
|
979
|
+
]
|
|
966
980
|
|
|
967
981
|
|
|
968
982
|
async def episode_mentions_reranker(
|
|
969
983
|
driver: GraphDriver, node_uuids: list[list[str]], min_score: float = 0
|
|
970
|
-
) -> list[str]:
|
|
984
|
+
) -> tuple[list[str], list[float]]:
|
|
971
985
|
# use rrf as a preliminary ranker
|
|
972
|
-
sorted_uuids = rrf(node_uuids)
|
|
986
|
+
sorted_uuids, _ = rrf(node_uuids)
|
|
973
987
|
scores: dict[str, float] = {}
|
|
974
988
|
|
|
975
989
|
# Find the shortest path to center node
|
|
@@ -990,7 +1004,9 @@ async def episode_mentions_reranker(
|
|
|
990
1004
|
# rerank on shortest distance
|
|
991
1005
|
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
|
|
992
1006
|
|
|
993
|
-
return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score]
|
|
1007
|
+
return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [
|
|
1008
|
+
scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score
|
|
1009
|
+
]
|
|
994
1010
|
|
|
995
1011
|
|
|
996
1012
|
def maximal_marginal_relevance(
|
|
@@ -998,7 +1014,7 @@ def maximal_marginal_relevance(
|
|
|
998
1014
|
candidates: dict[str, list[float]],
|
|
999
1015
|
mmr_lambda: float = DEFAULT_MMR_LAMBDA,
|
|
1000
1016
|
min_score: float = -2.0,
|
|
1001
|
-
) -> list[str]:
|
|
1017
|
+
) -> tuple[list[str], list[float]]:
|
|
1002
1018
|
start = time()
|
|
1003
1019
|
query_array = np.array(query_vector)
|
|
1004
1020
|
candidate_arrays: dict[str, NDArray] = {}
|
|
@@ -1029,7 +1045,9 @@ def maximal_marginal_relevance(
|
|
|
1029
1045
|
end = time()
|
|
1030
1046
|
logger.debug(f'Completed MMR reranking in {(end - start) * 1000} ms')
|
|
1031
1047
|
|
|
1032
|
-
return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score]
|
|
1048
|
+
return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score], [
|
|
1049
|
+
mmr_scores[uuid] for uuid in uuids if mmr_scores[uuid] >= min_score
|
|
1050
|
+
]
|
|
1033
1051
|
|
|
1034
1052
|
|
|
1035
1053
|
async def get_embeddings_for_nodes(
|
|
{graphiti_core-0.17.10.dist-info → graphiti_core-0.18.0.dist-info}/RECORD
CHANGED
|
@@ -2,7 +2,7 @@ graphiti_core/__init__.py,sha256=e5SWFkRiaUwfprYIeIgVIh7JDedNiloZvd3roU-0aDY,55
|
|
|
2
2
|
graphiti_core/edges.py,sha256=-SSP6rhk8Dl8LwUZ08GHymJTT5pNDtzb3BV-6z1fBYY,16030
|
|
3
3
|
graphiti_core/errors.py,sha256=cH_v9TPgEPeQE6GFOHIg5TvejpUCBddGarMY2Whxbwc,2707
|
|
4
4
|
graphiti_core/graph_queries.py,sha256=KfWDp8xDnPa9bcHskw8NeMpeeHBtZWBCosVdu1Iwv34,7076
|
|
5
|
-
graphiti_core/graphiti.py,sha256=
|
|
5
|
+
graphiti_core/graphiti.py,sha256=fNBDDOtChAG9U0t4nFD1Il882mMlr2TedeTzMuvNfnM,39568
|
|
6
6
|
graphiti_core/graphiti_types.py,sha256=rL-9bvnLobunJfXU4hkD6mAj14pofKp_wq8QsFDZwDU,1035
|
|
7
7
|
graphiti_core/helpers.py,sha256=YoMAEhe_aMPz_Cd_t1dnIffNwDpenINJu4URePglt2s,5247
|
|
8
8
|
graphiti_core/nodes.py,sha256=AcqHvhNWyapQwBSuziMvPJ-HnOr4Pv1-OiYsEodJcAA,18613
|
|
@@ -52,12 +52,12 @@ graphiti_core/prompts/models.py,sha256=NgxdbPHJpBEcpbXovKyScgpBc73Q-GIW-CBDlBtDj
|
|
|
52
52
|
graphiti_core/prompts/prompt_helpers.py,sha256=-9TABwIcIQUVHcNANx6wIZd-FT2DgYKyGTfx4IGYq2I,64
|
|
53
53
|
graphiti_core/prompts/summarize_nodes.py,sha256=tbg-AgWlzgFBeImKkZ28h2SpmqfPPqvN2Ol1Q71VF9Y,4146
|
|
54
54
|
graphiti_core/search/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
55
|
-
graphiti_core/search/search.py,sha256=
|
|
56
|
-
graphiti_core/search/search_config.py,sha256=
|
|
55
|
+
graphiti_core/search/search.py,sha256=u-kTmSu3VlRHYlQhuYsbwDQ-AKKCp3BZ9JZNRv3ttVY,16720
|
|
56
|
+
graphiti_core/search/search_config.py,sha256=v_rUHsu1yo5OuPfEm21lSuXexQs-o8qYwSSemW2QWhU,4165
|
|
57
57
|
graphiti_core/search/search_config_recipes.py,sha256=4GquRphHhJlpXQhAZOySYnCzBWYoTwxlJj44eTOavZQ,7443
|
|
58
58
|
graphiti_core/search/search_filters.py,sha256=cxiFkqB-r7QzVMh8nmujECLhzgsbeCpBHUQqDXnCQ3A,6383
|
|
59
59
|
graphiti_core/search/search_helpers.py,sha256=G5Ceaq5Pfgx0Weelqgeylp_pUHwiBnINaUYsDbURJbE,2636
|
|
60
|
-
graphiti_core/search/search_utils.py,sha256
|
|
60
|
+
graphiti_core/search/search_utils.py,sha256=nz1Z2HrOt3ay64x_HXaFBPyBNMecyoIqzhc-7Ac0rws,34894
|
|
61
61
|
graphiti_core/telemetry/__init__.py,sha256=5kALLDlU9bb2v19CdN7qVANsJWyfnL9E60J6FFgzm3o,226
|
|
62
62
|
graphiti_core/telemetry/telemetry.py,sha256=47LrzOVBCcZxsYPsnSxWFiztHoxYKKxPwyRX0hnbDGc,3230
|
|
63
63
|
graphiti_core/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
@@ -71,7 +71,7 @@ graphiti_core/utils/maintenance/node_operations.py,sha256=ZnopNRTNdBjBotQ2uQiI7E
|
|
|
71
71
|
graphiti_core/utils/maintenance/temporal_operations.py,sha256=mJkw9xLB4W2BsLfC5POr0r-PHWL9SIfNj_l_xu0B5ug,3410
|
|
72
72
|
graphiti_core/utils/maintenance/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
73
73
|
graphiti_core/utils/ontology_utils/entity_types_utils.py,sha256=QJX5cG0GSSNF_Mm_yrldr69wjVAbN_MxLhOSznz85Hk,1279
|
|
74
|
-
graphiti_core-0.
|
|
75
|
-
graphiti_core-0.
|
|
76
|
-
graphiti_core-0.
|
|
77
|
-
graphiti_core-0.
|
|
74
|
+
graphiti_core-0.18.0.dist-info/METADATA,sha256=dexXmf1OnLDXtkztEYhSDUufywjqUA5mOLWkLB9wqPc,23812
|
|
75
|
+
graphiti_core-0.18.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
76
|
+
graphiti_core-0.18.0.dist-info/licenses/LICENSE,sha256=KCUwCyDXuVEgmDWkozHyniRyWjnWUWjkuDHfU6o3JlA,11325
|
|
77
|
+
graphiti_core-0.18.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|