nucliadb 6.2.1.post2835__py3-none-any.whl → 6.2.1.post2842__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,913 @@
1
+ # Copyright (C) 2021 Bosutech XXI S.L.
2
+ #
3
+ # nucliadb is offered under the AGPL v3.0 and as commercial software.
4
+ # For commercial licensing, contact us at info@nuclia.com.
5
+ #
6
+ # AGPL:
7
+ # This program is free software: you can redistribute it and/or modify
8
+ # it under the terms of the GNU Affero General Public License as
9
+ # published by the Free Software Foundation, either version 3 of the
10
+ # License, or (at your option) any later version.
11
+ #
12
+ # This program is distributed in the hope that it will be useful,
13
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
14
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15
+ # GNU Affero General Public License for more details.
16
+ #
17
+ # You should have received a copy of the GNU Affero General Public License
18
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
+
20
+ import asyncio
21
+ import heapq
22
+ import json
23
+ from collections import defaultdict
24
+ from datetime import datetime
25
+ from typing import Any, Collection, Iterable, Optional, Union
26
+
27
+ from nuclia_models.predict.generative_responses import (
28
+ JSONGenerativeResponse,
29
+ MetaGenerativeResponse,
30
+ StatusGenerativeResponse,
31
+ )
32
+ from pydantic import BaseModel
33
+ from sentry_sdk import capture_exception
34
+
35
+ from nucliadb.common.external_index_providers.base import TextBlockMatch
36
+ from nucliadb.common.ids import FieldId, ParagraphId
37
+ from nucliadb.search import logger
38
+ from nucliadb.search.requesters.utils import Method, node_query
39
+ from nucliadb.search.search.chat.query import (
40
+ find_request_from_ask_request,
41
+ get_relations_results_from_entities,
42
+ )
43
+ from nucliadb.search.search.find import query_parser_from_find_request
44
+ from nucliadb.search.search.find_merge import (
45
+ compose_find_resources,
46
+ hydrate_and_rerank,
47
+ )
48
+ from nucliadb.search.search.hydrator import ResourceHydrationOptions, TextBlockHydrationOptions
49
+ from nucliadb.search.search.merge import merge_suggest_results
50
+ from nucliadb.search.search.metrics import RAGMetrics
51
+ from nucliadb.search.search.query import QueryParser
52
+ from nucliadb.search.search.rerankers import Reranker, RerankingOptions
53
+ from nucliadb.search.utilities import get_predict
54
+ from nucliadb_models.common import FieldTypeName
55
+ from nucliadb_models.internal.predict import (
56
+ RerankModel,
57
+ )
58
+ from nucliadb_models.resource import ExtractedDataTypeName
59
+ from nucliadb_models.search import (
60
+ SCORE_TYPE,
61
+ AskRequest,
62
+ ChatModel,
63
+ DirectionalRelation,
64
+ EntitySubgraph,
65
+ GraphStrategy,
66
+ KnowledgeboxFindResults,
67
+ KnowledgeboxSuggestResults,
68
+ NucliaDBClientType,
69
+ QueryEntityDetection,
70
+ RelationDirection,
71
+ RelationRanking,
72
+ Relations,
73
+ ResourceProperties,
74
+ TextPosition,
75
+ UserPrompt,
76
+ )
77
+ from nucliadb_protos import nodereader_pb2
78
+ from nucliadb_protos.utils_pb2 import RelationNode
79
+
80
# JSON Schema handed to the generative model by `rank_relations_generative`:
# an object whose "triplets" array holds one
# {head_entity, relationship, tail_entity, score} object per scored triplet,
# where score is an integer between 0 and 10.
SCHEMA = dict(
    title="score_triplets",
    description="Return a list of triplets and their relevance scores (0-10) for the supplied question.",
    type="object",
    properties=dict(
        triplets=dict(
            type="array",
            description="A list of triplets with their relevance scores.",
            items=dict(
                type="object",
                properties=dict(
                    head_entity=dict(type="string", description="The first entity in the triplet."),
                    relationship=dict(
                        type="string",
                        description="The relationship between the two entities.",
                    ),
                    tail_entity=dict(
                        type="string",
                        description="The second entity in the triplet.",
                    ),
                    score=dict(
                        type="integer",
                        description="A relevance score in the range 0 to 10.",
                        minimum=0,
                        maximum=10,
                    ),
                ),
                required=["head_entity", "relationship", "tail_entity", "score"],
            ),
        ),
    ),
    required=["triplets"],
)
113
+
114
# Instructions and few-shot examples for the generative model used by
# `rank_relations_generative`. The question and the triplets to score are
# appended after the trailing "**Input**" at runtime with
# `json.dumps(..., indent=4)`, so the examples below use that same layout.
PROMPT = """\
You are an advanced language model assisting in scoring relationships (edges) between two entities in a knowledge graph, given a user’s question.

For each provided **(head_entity, relationship, tail_entity)**, you must:
1. Assign a **relevance score** between **0** and **10**.
2. **0** means “this relationship can’t be relevant at all to the question.”
3. **10** means “this relationship is extremely relevant to the question.”
4. You may use **any integer** between 0 and 10 (e.g., 3, 7, etc.) based on how relevant you deem the relationship to be.
5. **Language Agnosticism**: The question and the relationships may be in different languages. The relevance scoring should still work and be agnostic of the language.
6. Relationships that may not answer the question directly but expand knowledge in a relevant way, should also be scored positively.

Once you have decided the best score for each triplet, return these results **using a function call** in JSON format with the following rules:

- The function name should be `score_triplets`.
- The first argument should be the list of triplets.
- Each triplet should have the following keys:
    - `head_entity`: The first entity in the triplet.
    - `relationship`: The relationship between the two entities.
    - `tail_entity`: The second entity in the triplet.
    - `score`: The relevance score in the range 0 to 10.

You **must** comply with the provided JSON Schema to ensure a well-structured response and maintain the order of the triplets.


## Examples:

### Example 1:

**Input**

{
    "question": "Who is the mayor of the capital city of Australia?",
    "triplets": [
        {
            "head_entity": "Australia",
            "relationship": "has prime minister",
            "tail_entity": "Scott Morrison"
        },
        {
            "head_entity": "Canberra",
            "relationship": "is capital of",
            "tail_entity": "Australia"
        },
        {
            "head_entity": "Scott Knowles",
            "relationship": "holds position",
            "tail_entity": "Mayor"
        },
        {
            "head_entity": "Barbera Smith",
            "relationship": "tiene cargo",
            "tail_entity": "Alcalde"
        },
        {
            "head_entity": "Austria",
            "relationship": "has capital",
            "tail_entity": "Vienna"
        }
    ]
}

**Output**

{
    "triplets": [
        {
            "head_entity": "Australia",
            "relationship": "has prime minister",
            "tail_entity": "Scott Morrison",
            "score": 4
        },
        {
            "head_entity": "Canberra",
            "relationship": "is capital of",
            "tail_entity": "Australia",
            "score": 8
        },
        {
            "head_entity": "Scott Knowles",
            "relationship": "holds position",
            "tail_entity": "Mayor",
            "score": 8
        },
        {
            "head_entity": "Barbera Smith",
            "relationship": "tiene cargo",
            "tail_entity": "Alcalde",
            "score": 8
        },
        {
            "head_entity": "Austria",
            "relationship": "has capital",
            "tail_entity": "Vienna",
            "score": 0
        }
    ]
}



### Example 2:

**Input**

{
    "question": "How many products does John Adams Roofing Inc. offer?",
    "triplets": [
        {
            "head_entity": "John Adams Roofing Inc.",
            "relationship": "has product",
            "tail_entity": "Titanium Grade 3 Roofing Nails"
        },
        {
            "head_entity": "John Adams Roofing Inc.",
            "relationship": "is located in",
            "tail_entity": "New York"
        },
        {
            "head_entity": "John Adams Roofing Inc.",
            "relationship": "was founded by",
            "tail_entity": "John Adams"
        },
        {
            "head_entity": "John Adams Roofing Inc.",
            "relationship": "tiene stock",
            "tail_entity": "Baldosas solares"
        },
        {
            "head_entity": "John Adams Roofing Inc.",
            "relationship": "has product",
            "tail_entity": "Mercerized Cotton Thread"
        }
    ]
}

**Output**

{
    "triplets": [
        {
            "head_entity": "John Adams Roofing Inc.",
            "relationship": "has product",
            "tail_entity": "Titanium Grade 3 Roofing Nails",
            "score": 10
        },
        {
            "head_entity": "John Adams Roofing Inc.",
            "relationship": "is located in",
            "tail_entity": "New York",
            "score": 6
        },
        {
            "head_entity": "John Adams Roofing Inc.",
            "relationship": "was founded by",
            "tail_entity": "John Adams",
            "score": 5
        },
        {
            "head_entity": "John Adams Roofing Inc.",
            "relationship": "tiene stock",
            "tail_entity": "Baldosas solares",
            "score": 10
        },
        {
            "head_entity": "John Adams Roofing Inc.",
            "relationship": "has product",
            "tail_entity": "Mercerized Cotton Thread",
            "score": 10
        }
    ]
}

Now, let's get started! Here are the triplets you need to score:

**Input**

"""
291
+
292
+
293
class RelationsParagraphMatch(BaseModel):
    """A paragraph referenced by graph relations, together with those relations.

    Created by `get_paragraph_info_from_relations` (one per relation carrying a paragraph id,
    with contained paragraphs later merged into their container) and converted into a
    `TextBlockMatch` by `relations_match_to_text_block_match`.
    """

    # Paragraph the relation metadata points to
    paragraph_id: ParagraphId
    # Relation ranking score; after merging, the highest score among the merged relations
    score: float
    # Relations (grouped by entity) that reference this paragraph
    relations: Relations
297
+
298
+
299
async def get_graph_results(
    *,
    kbid: str,
    query: str,
    item: AskRequest,
    ndb_client: NucliaDBClientType,
    user: str,
    origin: str,
    graph_strategy: GraphStrategy,
    generative_model: Optional[str] = None,
    metrics: Optional[RAGMetrics] = None,
    shards: Optional[list[str]] = None,
) -> tuple[KnowledgeboxFindResults, QueryParser]:
    """
    Retrieve context for `query` by exploring the knowledge graph of the KB.

    Up to `graph_strategy.hops` hops are performed:
    - hop 0 starts from entities detected in the query, using fuzzy suggest and/or the predict
      API (predict is used when configured, or as a fallback when suggest finds nothing);
    - later hops start from the not-yet-explored neighbours of the relations kept so far.
    After each hop the accumulated relations are ranked according to
    `graph_strategy.relation_ranking` (reranker or generative model, keeping the best
    `graph_strategy.top_k` triplets). Exploration stops early when a hop yields no new
    relations, or when ranking fails (then no relations are returned).

    The kept relations are finally converted into text blocks, hydrated and reranked.
    `ndb_client` and `origin` are not used by this function.

    Returns the find results and the query parser built from the ask request.
    """
    if metrics is None:
        # A default instance in the signature would be created once and shared by every call
        metrics = RAGMetrics()

    relations = Relations(entities={})
    # Scores for `relations`: per entity, aligned with that entity's `related_to` list.
    # Declared outside the loop so it keeps describing `relations` when a later hop finds
    # nothing new and breaks out (resetting it per hop left ranked relations without scores).
    scores: dict[str, list[float]] = {}
    explored_entities: set[str] = set()
    predict = get_predict()

    for hop in range(graph_strategy.hops):
        entities_to_explore: list[RelationNode] = []
        if hop == 0:
            # Get the entities from the query
            with metrics.time("graph_strat_query_entities"):
                if graph_strategy.query_entity_detection == QueryEntityDetection.SUGGEST:
                    suggest_result = await fuzzy_search_entities(
                        kbid=kbid,
                        query=query,
                        range_creation_start=item.range_creation_start,
                        range_creation_end=item.range_creation_end,
                        range_modification_start=item.range_modification_start,
                        range_modification_end=item.range_modification_end,
                        target_shard_replicas=shards,
                    )
                    if suggest_result.entities is not None:
                        # A list (not a generator) so the emptiness check below is meaningful
                        entities_to_explore = [
                            RelationNode(
                                ntype=RelationNode.NodeType.ENTITY,
                                value=result.value,
                                subtype=result.family,
                            )
                            for result in suggest_result.entities.entities
                        ]
                # Deliberately not an `elif`: predict is also the fallback when suggest found
                # no entity in the query
                if (
                    not entities_to_explore
                    or graph_strategy.query_entity_detection == QueryEntityDetection.PREDICT
                ):
                    try:
                        entities_to_explore = list(await predict.detect_entities(kbid, query))
                    except Exception as e:
                        capture_exception(e)
                        logger.exception("Error in detecting entities for graph strategy")
                        entities_to_explore = []
        else:
            # Find neighbors of the current relations and remove the ones already explored
            entities_to_explore = [
                RelationNode(
                    ntype=RelationNode.NodeType.ENTITY,
                    value=relation.entity,
                    subtype=relation.entity_subtype,
                )
                for subgraph in relations.entities.values()
                for relation in subgraph.related_to
                if relation.entity not in explored_entities
            ]

        # Get the relations for the new entities
        with metrics.time("graph_strat_neighbor_relations"):
            try:
                new_relations = await get_relations_results_from_entities(
                    kbid=kbid,
                    entities=entities_to_explore,
                    target_shard_replicas=shards,
                    timeout=5.0,
                    only_with_metadata=True,
                    only_agentic_relations=graph_strategy.agentic_graph_only,
                )
            except Exception as e:
                capture_exception(e)
                logger.exception("Error in getting query relations for graph strategy")
                new_relations = Relations(entities={})

        # Removing the relations connected to the entities that were already explored
        # XXX: This could be optimized by implementing a filter in the index
        # so we don't have to remove them after
        new_subgraphs = {
            entity: filter_subgraph(subgraph, explored_entities)
            for entity, subgraph in new_relations.entities.items()
        }

        explored_entities.update(new_subgraphs.keys())

        if not new_subgraphs or all(not subgraph.related_to for subgraph in new_subgraphs.values()):
            # Nothing new: keep the relations (and their scores) from previous hops
            break

        relations.entities.update(new_subgraphs)

        # Rank the relevance of the relations
        with metrics.time("graph_strat_rank_relations"):
            try:
                if graph_strategy.relation_ranking == RelationRanking.RERANKER:
                    relations, scores = await rank_relations_reranker(
                        relations,
                        query,
                        kbid,
                        user,
                        top_k=graph_strategy.top_k,
                    )
                elif graph_strategy.relation_ranking == RelationRanking.GENERATIVE:
                    relations, scores = await rank_relations_generative(
                        relations,
                        query,
                        kbid,
                        user,
                        top_k=graph_strategy.top_k,
                        generative_model=generative_model,
                    )
            except Exception as e:
                capture_exception(e)
                logger.exception("Error in ranking relations for graph strategy")
                relations = Relations(entities={})
                scores = {}
                break

    # Get the text blocks of the paragraphs that contain the top relations
    with metrics.time("graph_strat_build_response"):
        find_request = find_request_from_ask_request(item, query)
        query_parser, rank_fusion, reranker = await query_parser_from_find_request(
            kbid, find_request, generative_model=generative_model
        )
        find_results = await build_graph_response(
            kbid=kbid,
            query=query,
            final_relations=relations,
            scores=scores,
            top_k=graph_strategy.top_k,
            reranker=reranker,
            show=find_request.show,
            extracted=find_request.extracted,
            field_type_filter=find_request.field_type_filter,
            relation_text_as_paragraphs=graph_strategy.relation_text_as_paragraphs,
        )
    return find_results, query_parser
439
+
440
+
441
async def fuzzy_search_entities(
    kbid: str,
    query: str,
    range_creation_start: Optional[datetime] = None,
    range_creation_end: Optional[datetime] = None,
    range_modification_start: Optional[datetime] = None,
    range_modification_end: Optional[datetime] = None,
    target_shard_replicas: Optional[list[str]] = None,
) -> KnowledgeboxSuggestResults:
    """
    Find KB entities that fuzzily match the words of `query`.

    Works like /suggest restricted to the entities feature, but sends one suggest request per
    word of at least 3 characters and merges the per-word results. The optional creation and
    modification bounds are applied to every request. Any failure while querying or merging
    is reported to Sentry, logged, and turned into an empty result (`entities=None`).
    """
    template = nodereader_pb2.SuggestRequest(
        body="", features=[nodereader_pb2.SuggestFeatures.ENTITIES]
    )
    # (timestamp field, bound) pairs; only the bounds that were provided get set
    date_bounds = (
        ("from_created", range_creation_start),
        ("to_created", range_creation_end),
        ("from_modified", range_modification_start),
        ("to_modified", range_modification_end),
    )
    for field_name, bound in date_bounds:
        if bound is not None:
            getattr(template.timestamps, field_name).FromDatetime(bound)

    # XXX: Splitting by words is not ideal, in the future, modify suggest to better handle this
    searchable_words = [word for word in query.split() if len(word) >= 3]
    pending_queries = []
    for word in searchable_words:
        per_word_request = nodereader_pb2.SuggestRequest()
        per_word_request.CopyFrom(template)
        per_word_request.body = word
        pending_queries.append(
            node_query(
                kbid, Method.SUGGEST, per_word_request, target_shard_replicas=target_shard_replicas
            )
        )

    try:
        responses = await asyncio.gather(*pending_queries)
        # Each node_query response carries its suggest results as its first element
        suggest_items = [entry for response in responses for entry in response[0]]
        return await merge_suggest_results(suggest_items, kbid=kbid)
    except Exception as e:
        capture_exception(e)
        logger.exception("Error in finding entities in query for graph strategy")
        return KnowledgeboxSuggestResults(entities=None)
486
+
487
+
488
async def rank_relations_reranker(
    relations: Relations,
    query: str,
    kbid: str,
    user: str,
    top_k: int,
    score_threshold: float = 0.02,
) -> tuple[Relations, dict[str, list[float]]]:
    """
    Score every relation in `relations` against `query` with the predict reranker.

    Each relation is rendered as a "head relationship tail" triplet, oriented by the relation
    direction. Identical triplets are sent to the reranker only once. The `top_k` best unique
    triplets scoring above `score_threshold` are mapped back to every original relation that
    produced them.

    Returns the kept relations grouped by entity, plus their scores keyed by entity and aligned
    with each entity's `related_to` list. With no relations, returns empty results without
    calling the reranker.
    """
    # Store the index for keeping track after scoring
    flat_rels: list[tuple[str, int, DirectionalRelation]] = [
        (ent, idx, rel)
        for (ent, rels) in relations.entities.items()
        for (idx, rel) in enumerate(rels.related_to)
    ]
    if not flat_rels:
        # Nothing to score: skip the round-trip to the predict API
        return Relations(entities={}), {}

    # Build triplets (dict) from each relation for use in reranker; the relation direction
    # decides whether the entity is the head or the tail
    triplets: list[dict[str, str]] = [
        {
            "head_entity": ent,
            "relationship": rel.relation_label,
            "tail_entity": rel.entity,
        }
        if rel.direction == RelationDirection.OUT
        else {
            "head_entity": rel.entity,
            "relationship": rel.relation_label,
            "tail_entity": ent,
        }
        for (ent, _, rel) in flat_rels
    ]

    # Dedupe triplets so that they get evaluated once; map triplet -> [orig_indices]
    triplet_to_orig_indices: dict[tuple[str, str, str], list[int]] = {}
    unique_triplets: list[dict[str, str]] = []

    for i, t in enumerate(triplets):
        key = (t["head_entity"], t["relationship"], t["tail_entity"])
        if key not in triplet_to_orig_indices:
            triplet_to_orig_indices[key] = []
            unique_triplets.append(t)
        triplet_to_orig_indices[key].append(i)

    # Build the reranker model input; context keys are the positions in unique_triplets
    predict = get_predict()
    rerank_model = RerankModel(
        question=query,
        user_id=user,
        context={
            str(idx): f"{t['head_entity']} {t['relationship']} {t['tail_entity']}"
            for idx, t in enumerate(unique_triplets)
        },
    )
    # Get the rerank scores
    res = await predict.rerank(kbid, rerank_model)

    # Convert returned scores to a list of (int_idx, score)
    # where int_idx corresponds to indices in unique_triplets
    reranked_indices_scores = [(int(idx), score) for idx, score in res.context_scores.items()]

    return _scores_to_ranked_rels(
        unique_triplets,
        reranked_indices_scores,
        triplet_to_orig_indices,
        flat_rels,
        top_k,
        score_threshold,
    )
554
+
555
+
556
async def rank_relations_generative(
    relations: Relations,
    query: str,
    kbid: str,
    user: str,
    top_k: int,
    generative_model: Optional[str] = None,
    score_threshold: float = 2,
    max_rels_to_eval: int = 100,
) -> tuple[Relations, dict[str, list[float]]]:
    """
    Score every relation in `relations` against `query` with a generative model.

    Relations are rendered as (head_entity, relationship, tail_entity) triplets, oriented by their
    direction, and deduplicated. The unique triplets are sent with `PROMPT` and the `SCHEMA` JSON
    schema; the model is asked for an integer score (0-10) per triplet, in input order. The `top_k`
    best unique triplets scoring above `score_threshold` are mapped back to the original relations.

    If there are more than `max_rels_to_eval` unique triplets, the predict reranker first reduces
    them to the best `max_rels_to_eval`, and that reduced set is then ranked by the generative model.

    Returns the kept relations grouped by entity, plus their scores keyed by entity and aligned
    with each entity's `related_to` list.

    Raises ValueError if the model stream contains an unknown chunk type, has no JSON answer or
    no successful status (code "0"), or returns a different number of triplets than it was given.
    """
    # Store the index for keeping track after scoring
    flat_rels: list[tuple[str, int, DirectionalRelation]] = [
        (ent, idx, rel)
        for (ent, rels) in relations.entities.items()
        for (idx, rel) in enumerate(rels.related_to)
    ]
    if not flat_rels:
        # Nothing to score: skip the LLM call
        return Relations(entities={}), {}

    triplets: list[dict[str, str]] = [
        {
            "head_entity": ent,
            "relationship": rel.relation_label,
            "tail_entity": rel.entity,
        }
        if rel.direction == RelationDirection.OUT
        else {
            "head_entity": rel.entity,
            "relationship": rel.relation_label,
            "tail_entity": ent,
        }
        for (ent, _, rel) in flat_rels
    ]

    # Dedupe triplets so that they get evaluated once, we will re-associate the scores later
    triplet_to_orig_indices: dict[tuple[str, str, str], list[int]] = {}
    unique_triplets: list[dict[str, str]] = []

    for i, t in enumerate(triplets):
        key = (t["head_entity"], t["relationship"], t["tail_entity"])
        if key not in triplet_to_orig_indices:
            triplet_to_orig_indices[key] = []
            unique_triplets.append(t)
        triplet_to_orig_indices[key].append(i)

    # The limit applies to what the model actually evaluates: the unique triplets
    if len(unique_triplets) > max_rels_to_eval:
        logger.warning(
            f"Too many relations to evaluate ({len(unique_triplets)}), using reranker to reduce"
        )
        # Reduce to the `max_rels_to_eval` best triplets according to the reranker, then rank
        # that reduced set generatively so `top_k` is honoured as in the regular path.
        # The reduced set has at most `max_rels_to_eval` unique triplets, so this recurses once.
        reduced_relations, _ = await rank_relations_reranker(
            relations, query, kbid, user, top_k=max_rels_to_eval
        )
        return await rank_relations_generative(
            reduced_relations,
            query,
            kbid,
            user,
            top_k=top_k,
            generative_model=generative_model,
            score_threshold=score_threshold,
            max_rels_to_eval=max_rels_to_eval,
        )

    data = {
        "question": query,
        "triplets": unique_triplets,
    }
    prompt = PROMPT + json.dumps(data, indent=4)

    predict = get_predict()
    chat_model = ChatModel(
        question=prompt,
        user_id=user,
        json_schema=SCHEMA,
        format_prompt=False,  # We supply our own prompt
        query_context_order={},
        query_context={},
        user_prompt=UserPrompt(prompt=prompt),
        max_tokens=4096,
        generative_model=generative_model,
    )

    _ident, _model, answer_stream = await predict.chat_query_ndjson(kbid, chat_model)
    response_json: Optional[JSONGenerativeResponse] = None
    status: Optional[StatusGenerativeResponse] = None

    async for generative_chunk in answer_stream:
        item = generative_chunk.chunk
        if isinstance(item, JSONGenerativeResponse):
            response_json = item
        elif isinstance(item, StatusGenerativeResponse):
            status = item
        elif isinstance(item, MetaGenerativeResponse):
            # Generation metadata is not needed for scoring
            continue
        else:
            raise ValueError(f"Unknown generative chunk type: {item}")

    if response_json is None or status is None or status.code != "0":
        raise ValueError("No JSON response found")

    scored_unique_triplets: list[dict[str, Union[str, Any]]] = response_json.object["triplets"]

    if len(scored_unique_triplets) != len(unique_triplets):
        raise ValueError("Mismatch between input and output triplets")

    # Scores are matched by position: the prompt asks the model to keep the input order
    unique_indices_scores = [
        (idx, float(t["score"])) for (idx, t) in enumerate(scored_unique_triplets)
    ]

    return _scores_to_ranked_rels(
        unique_triplets,
        unique_indices_scores,
        triplet_to_orig_indices,
        flat_rels,
        top_k,
        score_threshold,
    )
655
+
656
+
657
def _scores_to_ranked_rels(
    unique_triplets: list[dict[str, str]],
    unique_indices_scores: Iterable[tuple[int, float]],
    triplet_to_orig_indices: dict[tuple[str, str, str], list[int]],
    flat_rels: list[tuple[str, int, DirectionalRelation]],
    top_k: int,
    score_threshold: float,
) -> tuple[Relations, dict[str, list[float]]]:
    """
    Map scores given to unique triplets back onto the original relations.

    Only the `top_k` highest (unique index, score) pairs are considered, and among those only
    the ones scoring strictly above `score_threshold`. Every original relation that produced a
    kept triplet is added under its entity along with the triplet's score, so each entity's
    score list stays aligned with its `related_to` list.
    """
    best_pairs = heapq.nlargest(top_k, unique_indices_scores, key=lambda pair: pair[1])

    ranked_subgraphs: dict[str, EntitySubgraph] = defaultdict(lambda: EntitySubgraph(related_to=[]))
    ranked_scores: dict[str, list[float]] = defaultdict(list)

    for unique_idx, score in best_pairs:
        if score <= score_threshold:
            # Not relevant enough
            continue

        triplet = unique_triplets[unique_idx]
        triplet_key = tuple(triplet[part] for part in ("head_entity", "relationship", "tail_entity"))

        # One unique triplet may stand for several original relations
        for flat_idx in triplet_to_orig_indices[triplet_key]:
            entity, _rel_idx, relation = flat_rels[flat_idx]
            ranked_subgraphs[entity].related_to.append(relation)
            ranked_scores[entity].append(score)

    return Relations(entities=ranked_subgraphs), dict(ranked_scores)
694
+
695
+
696
def build_text_blocks_from_relations(
    relations: Relations,
    scores: dict[str, list[float]],
) -> list[TextBlockMatch]:
    """
    The goal of this function is to generate TextBlockMatch with custom text for each unique relation in the graph.

    Relations are grouped by their (head, relationship, tail) triplet, oriented by each relation's
    direction. One block is built per triplet:
    - its text is "- head relationship tail";
    - its score is the highest score among the relations of that triplet;
    - its paragraph id is the first one found in those relations' metadata.
    Triplets without any paragraph id are dropped.

    `scores[entity]` is expected to be aligned with `relations.entities[entity].related_to`.
    Entities missing from `scores` contribute no blocks.

    This is a hacky way to generate paragraphs from relations, and it is not the intended use of TextBlockMatch.
    """
    # triplet -> (best score so far, relations that produced it, first paragraph id found)
    triplets: dict[tuple[str, str, str], tuple[float, Relations, Optional[ParagraphId]]] = defaultdict(
        lambda: (0.0, Relations(entities={}), None)
    )
    for ent, subgraph in relations.entities.items():
        # `.get` rather than indexing: an entity without scores is skipped instead of failing
        # the whole request with a KeyError (zip already truncates shorter score lists)
        for rel, score in zip(subgraph.related_to, scores.get(ent, [])):
            key = (
                (ent, rel.relation_label, rel.entity)
                if rel.direction == RelationDirection.OUT
                else (rel.entity, rel.relation_label, ent)
            )
            existing_score, existing_relations, p_id = triplets[key]
            if ent not in existing_relations.entities:
                existing_relations.entities[ent] = EntitySubgraph(related_to=[])

            # XXX: Since relations with the same triplet can point to different paragraphs,
            # we keep the first one, but we lose the other ones
            if p_id is None and rel.metadata and rel.metadata.paragraph_id:
                p_id = ParagraphId.from_string(rel.metadata.paragraph_id)
            existing_relations.entities[ent].related_to.append(rel)
            # XXX: Here we use the max even though all relations with same triplet should have same score
            triplets[key] = (max(existing_score, score), existing_relations, p_id)

    # Build the text blocks (names chosen so they don't shadow the `relations` parameter)
    text_blocks = [
        TextBlockMatch(
            # XXX: Even though we are setting a paragraph_id, the text is not coming from the paragraph
            paragraph_id=triplet_p_id,
            score=best_score,
            score_type=SCORE_TYPE.RELATION_RELEVANCE,
            order=0,
            text=f"- {head} {label} {tail}",  # Manually build the text
            position=TextPosition(
                page_number=0,
                index=0,
                start=0,
                end=0,
                start_seconds=[],
                end_seconds=[],
            ),
            field_labels=[],
            paragraph_labels=[],
            fuzzy_search=False,
            is_a_table=False,
            representation_file="",
            page_with_visual=False,
            relevant_relations=triplet_relations,
        )
        for (head, label, tail), (best_score, triplet_relations, triplet_p_id) in triplets.items()
        if triplet_p_id is not None
    ]
    return text_blocks
761
+
762
+
763
def get_paragraph_info_from_relations(
    relations: Relations,
    scores: dict[str, list[float]],
) -> list[RelationsParagraphMatch]:
    """
    Gathers paragraph info from the 'relations' object, merges relations by paragraph,
    and removes paragraphs contained entirely within others.

    Only relations whose metadata carries a paragraph id are considered. Within each field,
    a paragraph whose span lies inside a kept paragraph is folded into it: the container
    receives its relations and keeps the highest of both scores.

    `scores[entity]` is expected to be aligned with `relations.entities[entity].related_to`.
    Entities missing from `scores` contribute nothing.
    """

    # Group paragraphs by field so we can detect containment
    paragraphs_by_field: dict[FieldId, list[RelationsParagraphMatch]] = defaultdict(list)

    # Loop over each entity in the relation graph
    for ent, subgraph in relations.entities.items():
        # `.get` rather than indexing: an entity without scores is skipped instead of raising KeyError
        for rel_score, rel in zip(scores.get(ent, []), subgraph.related_to):
            if not (rel.metadata and rel.metadata.paragraph_id):
                continue
            p_id = ParagraphId.from_string(rel.metadata.paragraph_id)
            match = RelationsParagraphMatch(
                paragraph_id=p_id,
                score=rel_score,
                relations=Relations(entities={ent: EntitySubgraph(related_to=[rel])}),
            )
            paragraphs_by_field[p_id.field_id].append(match)

    # For each field, sort paragraphs by start asc, end desc, and do one pass to remove contained ones
    final_paragraphs: list[RelationsParagraphMatch] = []

    for paragraph_list in paragraphs_by_field.values():
        # Sort by paragraph_start ascending; if tie, paragraph_end descending
        paragraph_list.sort(
            key=lambda m: (m.paragraph_id.paragraph_start, -m.paragraph_id.paragraph_end)
        )

        kept: list[RelationsParagraphMatch] = []
        current_max_end = -1

        for match in paragraph_list:
            end = match.paragraph_id.paragraph_end

            # Kept paragraphs have strictly increasing ends and matches arrive in ascending start
            # order, so `end <= current_max_end` means this paragraph lies inside the last kept one
            if end <= current_max_end:
                # We merge the scores and relations into the container
                container = kept[-1]
                container.score = max(container.score, match.score)
                for merged_ent, merged_subgraph in match.relations.entities.items():
                    if merged_ent not in container.relations.entities:
                        container.relations.entities[merged_ent] = EntitySubgraph(related_to=[])
                    container.relations.entities[merged_ent].related_to.extend(
                        merged_subgraph.related_to
                    )
            else:
                # Not contained; keep it
                kept.append(match)
                current_max_end = end
        final_paragraphs.extend(kept)

    return final_paragraphs
819
+
820
+
821
async def build_graph_response(
    *,
    kbid: str,
    query: str,
    final_relations: Relations,
    scores: dict[str, list[float]],
    top_k: int,
    reranker: Reranker,
    relation_text_as_paragraphs: bool,
    show: Optional[list[ResourceProperties]] = None,
    extracted: Optional[list[ExtractedDataTypeName]] = None,
    field_type_filter: Optional[list[FieldTypeName]] = None,
) -> KnowledgeboxFindResults:
    """
    Turn the ranked graph relations into a find response.

    With `relation_text_as_paragraphs`, each unique triplet becomes a text block whose text is the
    triplet itself. Otherwise, the paragraphs referenced by the relations are used (contained
    paragraphs merged into their containers) as blocks with empty text. The blocks are then
    hydrated and reranked with `reranker`, keeping `top_k`. Text-block hydration uses
    `only_hydrate_empty=True`, presumably so triplet-text blocks keep their text (confirm in hydrator).

    `show`, `extracted` and `field_type_filter` default to empty lists.
    """
    # Fresh lists per call rather than shared mutable default arguments
    show = [] if show is None else show
    extracted = [] if extracted is None else extracted
    field_type_filter = [] if field_type_filter is None else field_type_filter

    if relation_text_as_paragraphs:
        text_blocks = build_text_blocks_from_relations(final_relations, scores)
    else:
        paragraphs_info = get_paragraph_info_from_relations(final_relations, scores)
        text_blocks = relations_matches_to_text_block_matches(paragraphs_info)

    # hydrate and rerank
    resource_hydration_options = ResourceHydrationOptions(
        show=show, extracted=extracted, field_type_filter=field_type_filter
    )
    text_block_hydration_options = TextBlockHydrationOptions(only_hydrate_empty=True)
    reranking_options = RerankingOptions(kbid=kbid, query=query)
    text_blocks, resources, best_matches = await hydrate_and_rerank(
        text_blocks,
        kbid,
        resource_hydration_options=resource_hydration_options,
        text_block_hydration_options=text_block_hydration_options,
        reranker=reranker,
        reranking_options=reranking_options,
        top_k=top_k,
    )

    find_resources = compose_find_resources(text_blocks, resources)

    return KnowledgeboxFindResults(
        query=query,
        resources=find_resources,
        best_matches=best_matches,
        relations=final_relations,
        total=len(text_blocks),
    )
865
+
866
+
867
def filter_subgraph(subgraph: EntitySubgraph, entities_to_remove: Collection[str]) -> EntitySubgraph:
    """
    Return a new subgraph without the relations whose target entity is in `entities_to_remove`.

    The input subgraph is not modified; the remaining relations keep their original order.
    """
    kept_relations = []
    for relation in subgraph.related_to:
        if relation.entity in entities_to_remove:
            continue
        kept_relations.append(relation)
    return EntitySubgraph(related_to=kept_relations)
874
+
875
+
876
def relations_match_to_text_block_match(
    paragraph_match: RelationsParagraphMatch,
) -> TextBlockMatch:
    """
    Convert a relations-derived paragraph match into a bare-minimum TextBlockMatch.

    The Graph Strategy needs TextBlockMatch objects to obtain text blocks for the relevant
    paragraphs. Only the paragraph id, score, character span and relevant relations are set.
    `order` (0) and `text` ("") are placeholders meant to be filled later.
    """
    # XXX: this is a workaround for the fact we always assume retrieval means keyword/semantic search and
    # the hydration and find response building code works with TextBlockMatch, we extended it to have relevant relations information
    pid = paragraph_match.paragraph_id
    span = TextPosition(
        page_number=0,
        index=0,
        start=pid.paragraph_start,
        end=pid.paragraph_end,
        start_seconds=[],
        end_seconds=[],
    )
    return TextBlockMatch(
        paragraph_id=pid,
        score=paragraph_match.score,
        score_type=SCORE_TYPE.RELATION_RELEVANCE,
        order=0,  # placeholder, filled later
        text="",  # placeholder, filled later
        position=span,
        field_labels=[],
        paragraph_labels=[],
        fuzzy_search=False,
        is_a_table=False,
        representation_file="",
        page_with_visual=False,
        relevant_relations=paragraph_match.relations,
    )
908
+
909
+
910
def relations_matches_to_text_block_matches(
    paragraph_matches: Collection[RelationsParagraphMatch],
) -> list[TextBlockMatch]:
    """Convert each paragraph match, preserving order, with `relations_match_to_text_block_match`."""
    return list(map(relations_match_to_text_block_match, paragraph_matches))