nucliadb 6.2.1.post2838__py3-none-any.whl → 6.2.1.post2853__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,915 @@
1
+ # Copyright (C) 2021 Bosutech XXI S.L.
2
+ #
3
+ # nucliadb is offered under the AGPL v3.0 and as commercial software.
4
+ # For commercial licensing, contact us at info@nuclia.com.
5
+ #
6
+ # AGPL:
7
+ # This program is free software: you can redistribute it and/or modify
8
+ # it under the terms of the GNU Affero General Public License as
9
+ # published by the Free Software Foundation, either version 3 of the
10
+ # License, or (at your option) any later version.
11
+ #
12
+ # This program is distributed in the hope that it will be useful,
13
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
14
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15
+ # GNU Affero General Public License for more details.
16
+ #
17
+ # You should have received a copy of the GNU Affero General Public License
18
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
+
20
+ import asyncio
21
+ import heapq
22
+ import json
23
+ from collections import defaultdict
24
+ from datetime import datetime
25
+ from typing import Any, Collection, Iterable, Optional, Union
26
+
27
+ from nuclia_models.predict.generative_responses import (
28
+ JSONGenerativeResponse,
29
+ MetaGenerativeResponse,
30
+ StatusGenerativeResponse,
31
+ )
32
+ from pydantic import BaseModel
33
+ from sentry_sdk import capture_exception
34
+
35
+ from nucliadb.common.external_index_providers.base import TextBlockMatch
36
+ from nucliadb.common.ids import FieldId, ParagraphId
37
+ from nucliadb.search import logger
38
+ from nucliadb.search.requesters.utils import Method, node_query
39
+ from nucliadb.search.search.chat.query import (
40
+ find_request_from_ask_request,
41
+ get_relations_results_from_entities,
42
+ )
43
+ from nucliadb.search.search.find import query_parser_from_find_request
44
+ from nucliadb.search.search.find_merge import (
45
+ compose_find_resources,
46
+ hydrate_and_rerank,
47
+ )
48
+ from nucliadb.search.search.hydrator import ResourceHydrationOptions, TextBlockHydrationOptions
49
+ from nucliadb.search.search.merge import merge_suggest_results
50
+ from nucliadb.search.search.metrics import RAGMetrics
51
+ from nucliadb.search.search.query import QueryParser
52
+ from nucliadb.search.search.rerankers import Reranker, RerankingOptions
53
+ from nucliadb.search.utilities import get_predict
54
+ from nucliadb_models.common import FieldTypeName
55
+ from nucliadb_models.internal.predict import (
56
+ RerankModel,
57
+ )
58
+ from nucliadb_models.resource import ExtractedDataTypeName
59
+ from nucliadb_models.search import (
60
+ SCORE_TYPE,
61
+ AskRequest,
62
+ ChatModel,
63
+ DirectionalRelation,
64
+ EntitySubgraph,
65
+ GraphStrategy,
66
+ KnowledgeboxFindResults,
67
+ KnowledgeboxSuggestResults,
68
+ NucliaDBClientType,
69
+ QueryEntityDetection,
70
+ RelationDirection,
71
+ RelationRanking,
72
+ Relations,
73
+ ResourceProperties,
74
+ TextPosition,
75
+ UserPrompt,
76
+ )
77
+ from nucliadb_protos import nodereader_pb2
78
+ from nucliadb_protos.utils_pb2 import RelationNode
79
+
80
# JSON Schema describing one scored triplet: the three triplet parts plus an integer score in [0, 10].
_SCORED_TRIPLET_SCHEMA = {
    "type": "object",
    "properties": {
        "head_entity": {
            "type": "string",
            "description": "The first entity in the triplet.",
        },
        "relationship": {
            "type": "string",
            "description": "The relationship between the two entities.",
        },
        "tail_entity": {
            "type": "string",
            "description": "The second entity in the triplet.",
        },
        "score": {
            "type": "integer",
            "description": "A relevance score in the range 0 to 10.",
            "minimum": 0,
            "maximum": 10,
        },
    },
    "required": ["head_entity", "relationship", "tail_entity", "score"],
}

# JSON Schema passed to the generative model so that it answers with a `triplets` array of scored triplets.
SCHEMA = {
    "title": "score_triplets",
    "description": "Return a list of triplets and their relevance scores (0-10) for the supplied question.",
    "type": "object",
    "properties": {
        "triplets": {
            "type": "array",
            "description": "A list of triplets with their relevance scores.",
            "items": _SCORED_TRIPLET_SCHEMA,
        }
    },
    "required": ["triplets"],
}
113
+
114
+ PROMPT = """\
115
+ You are an advanced language model assisting in scoring relationships (edges) between two entities in a knowledge graph, given a user’s question.
116
+
117
+ For each provided **(head_entity, relationship, tail_entity)**, you must:
118
+ 1. Assign a **relevance score** between **0** and **10**.
119
+ 2. **0** means “this relationship can’t be relevant at all to the question.”
120
+ 3. **10** means “this relationship is extremely relevant to the question.”
121
+ 4. You may use **any integer** between 0 and 10 (e.g., 3, 7, etc.) based on how relevant you deem the relationship to be.
122
+ 5. **Language Agnosticism**: The question and the relationships may be in different languages. The relevance scoring should still work and be agnostic of the language.
123
+ 6. Relationships that may not answer the question directly but expand knowledge in a relevant way, should also be scored positively.
124
+
125
+ Once you have decided the best score for each triplet, return these results **using a function call** in JSON format with the following rules:
126
+
127
+ - The function name should be `score_triplets`.
128
+ - The first argument should be the list of triplets.
129
+ - Each triplet should have the following keys:
130
+ - `head_entity`: The first entity in the triplet.
131
+ - `relationship`: The relationship between the two entities.
132
+ - `tail_entity`: The second entity in the triplet.
133
+ - `score`: The relevance score in the range 0 to 10.
134
+
135
+ You **must** comply with the provided JSON Schema to ensure a well-structured response and maintain the order of the triplets.
136
+
137
+
138
+ ## Examples:
139
+
140
+ ### Example 1:
141
+
142
+ **Input**
143
+
144
+ {
145
+ "question": "Who is the mayor of the capital city of Australia?",
146
+ "triplets": [
147
+ {
148
+ "head_entity": "Australia",
149
+ "relationship": "has prime minister",
150
+ "tail_entity": "Scott Morrison"
151
+ },
152
+ {
153
+ "head_entity": "Canberra",
154
+ "relationship": "is capital of",
155
+ "tail_entity": "Australia"
156
+ },
157
+ {
158
+ "head_entity": "Scott Knowles",
159
+ "relationship": "holds position",
160
+ "tail_entity": "Mayor"
161
+ },
162
+ {
163
+ "head_entity": "Barbera Smith",
164
+ "relationship": "tiene cargo",
165
+ "tail_entity": "Alcalde"
166
+ },
167
+ {
168
+ "head_entity": "Austria",
169
+ "relationship": "has capital",
170
+ "tail_entity": "Vienna"
171
+ }
172
+ ]
173
+ }
174
+
175
+ **Output**
176
+
177
+ {
178
+ "triplets": [
179
+ {
180
+ "head_entity": "Australia",
181
+ "relationship": "has prime minister",
182
+ "tail_entity": "Scott Morrison",
183
+ "score": 4
184
+ },
185
+ {
186
+ "head_entity": "Canberra",
187
+ "relationship": "is capital of",
188
+ "tail_entity": "Australia",
189
+ "score": 8
190
+ },
191
+ {
192
+ "head_entity": "Scott Knowles",
193
+ "relationship": "holds position",
194
+ "tail_entity": "Mayor",
195
+ "score": 8
196
+ },
197
+ {
198
+ "head_entity": "Barbera Smith",
199
+ "relationship": "tiene cargo",
200
+ "tail_entity": "Alcalde",
201
+ "score": 8
202
+ },
203
+ {
204
+ "head_entity": "Austria",
205
+ "relationship": "has capital",
206
+ "tail_entity": "Vienna",
207
+ "score": 0
208
+ }
209
+ ]
210
+ }
211
+
212
+
213
+
214
+ ### Example 2:
215
+
216
+ **Input**
217
+
218
+ {
219
+ "question": "How many products does John Adams Roofing Inc. offer?",
220
+ "triplets": [
221
+ {
222
+ "head_entity": "John Adams Roofing Inc.",
223
+ "relationship": "has product",
224
+ "tail_entity": "Titanium Grade 3 Roofing Nails"
225
+ },
226
+ {
227
+ "head_entity": "John Adams Roofing Inc.",
228
+ "relationship": "is located in",
229
+ "tail_entity": "New York"
230
+ },
231
+ {
232
+ "head_entity": "John Adams Roofing Inc.",
233
+ "relationship": "was founded by",
234
+ "tail_entity": "John Adams"
235
+ },
236
+ {
237
+ "head_entity": "John Adams Roofing Inc.",
238
+ "relationship": "tiene stock",
239
+ "tail_entity": "Baldosas solares"
240
+ },
241
+ {
242
+ "head_entity": "John Adams Roofing Inc.",
243
+ "relationship": "has product",
244
+ "tail_entity": "Mercerized Cotton Thread"
245
+ }
246
+ ]
247
+ }
248
+
249
+ **Output**
250
+
251
+ {
252
+ "triplets": [
253
+ {
254
+ "head_entity": "John Adams Roofing Inc.",
255
+ "relationship": "has product",
256
+ "tail_entity": "Titanium Grade 3 Roofing Nails",
257
+ "score": 10
258
+ },
259
+ {
260
+ "head_entity": "John Adams Roofing Inc.",
261
+ "relationship": "is located in",
262
+ "tail_entity": "New York",
263
+ "score": 6
264
+ },
265
+ {
266
+ "head_entity": "John Adams Roofing Inc.",
267
+ "relationship": "was founded by",
268
+ "tail_entity": "John Adams",
269
+ "score": 5
270
+ },
271
+ {
272
+ "head_entity": "John Adams Roofing Inc.",
273
+ "relationship": "tiene stock",
274
+ "tail_entity": "Baldosas solares",
275
+ "score": 10
276
+ },
277
+ {
278
+ "head_entity": "John Adams Roofing Inc.",
279
+ "relationship": "has product",
280
+ "tail_entity": "Mercerized Cotton Thread",
281
+ "score": 10
282
+ }
283
+ ]
284
+ }
285
+
286
+ Now, let's get started! Here are the triplets you need to score:
287
+
288
+ **Input**
289
+
290
+ """
291
+
292
+
293
class RelationsParagraphMatch(BaseModel):
    """A paragraph referenced by graph relations, together with its score and those relations."""

    # Paragraph the relations' metadata points at
    paragraph_id: ParagraphId
    # Relevance score of the paragraph (the highest one when several matches are merged)
    score: float
    # Relations, grouped by entity, that reference this paragraph
    relations: Relations
297
+
298
+
299
async def get_graph_results(
    *,
    kbid: str,
    query: str,
    item: AskRequest,
    ndb_client: NucliaDBClientType,
    user: str,
    origin: str,
    graph_strategy: GraphStrategy,
    generative_model: Optional[str] = None,
    metrics: Optional[RAGMetrics] = None,
    shards: Optional[list[str]] = None,
) -> tuple[KnowledgeboxFindResults, QueryParser]:
    """
    Run the graph retrieval strategy for an ask request.

    Starting from entities found in the query, the relation graph is explored for up to
    `graph_strategy.hops` hops. After each hop the accumulated relations are ranked with the
    configured strategy and replaced by the ranked subset. The surviving relations are then
    converted into find results.

    Args:
        kbid: Knowledge box to query.
        query: User question.
        item: Original ask request; supplies the date ranges for entity search and the find options.
        ndb_client: Accepted for interface compatibility; not used in this function.
        user: User id forwarded to the ranking models.
        origin: Accepted for interface compatibility; not used in this function.
        graph_strategy: Strategy options (hops, top_k, entity detection, ranking, ...).
        generative_model: Optional generative model name used for ranking and query parsing.
        metrics: Metrics collector used to time each stage; a new one is created when omitted.
        shards: Optional shard replicas to target.

    Returns:
        The find results built from the selected relations and the query parser of the derived
        find request.

    Failures during entity detection, relation lookup and ranking are reported (Sentry + log) and
    do not propagate: the exploration continues with empty data or stops.
    """
    # Build the collector per call: a default instantiated in the signature is evaluated once at
    # definition time and would be shared by every call that omits the argument.
    if metrics is None:
        metrics = RAGMetrics()

    relations = Relations(entities={})
    explored_entities: set[str] = set()
    scores: dict[str, list[float]] = {}
    predict = get_predict()

    for hop in range(graph_strategy.hops):
        entities_to_explore: Iterable[RelationNode] = []

        if hop == 0:
            # Get the entities from the query
            with metrics.time("graph_strat_query_entities"):
                if graph_strategy.query_entity_detection == QueryEntityDetection.SUGGEST:
                    suggest_result = await fuzzy_search_entities(
                        kbid=kbid,
                        query=query,
                        range_creation_start=item.range_creation_start,
                        range_creation_end=item.range_creation_end,
                        range_modification_start=item.range_modification_start,
                        range_modification_end=item.range_modification_end,
                        target_shard_replicas=shards,
                    )
                    if suggest_result.entities is not None:
                        entities_to_explore = (
                            RelationNode(
                                ntype=RelationNode.NodeType.ENTITY,
                                value=result.value,
                                subtype=result.family,
                            )
                            for result in suggest_result.entities.entities
                        )
                # NOTE(review): this branch is only reached when detection is not SUGGEST, and at
                # that point `entities_to_explore` is still the empty list assigned above, so the
                # first operand is always true. If a fallback from SUGGEST to PREDICT was intended,
                # this needs restructuring -- confirm the intent before changing it.
                elif (
                    not entities_to_explore
                    or graph_strategy.query_entity_detection == QueryEntityDetection.PREDICT
                ):
                    try:
                        entities_to_explore = await predict.detect_entities(kbid, query)
                    except Exception as e:
                        capture_exception(e)
                        logger.exception("Error in detecting entities for graph strategy")
                        entities_to_explore = []
        else:
            # Find neighbors of the current relations and remove the ones already explored
            entities_to_explore = (
                RelationNode(
                    ntype=RelationNode.NodeType.ENTITY,
                    value=relation.entity,
                    subtype=relation.entity_subtype,
                )
                for subgraph in relations.entities.values()
                for relation in subgraph.related_to
                if relation.entity not in explored_entities
            )

        # Get the relations for the new entities
        with metrics.time("graph_strat_neighbor_relations"):
            try:
                new_relations = await get_relations_results_from_entities(
                    kbid=kbid,
                    entities=entities_to_explore,
                    target_shard_replicas=shards,
                    timeout=5.0,
                    only_with_metadata=True,
                    only_agentic_relations=graph_strategy.agentic_graph_only,
                )
            except Exception as e:
                capture_exception(e)
                logger.exception("Error in getting query relations for graph strategy")
                # An empty result makes the emptiness check below stop the exploration
                new_relations = Relations(entities={})

            # Removing the relations connected to the entities that were already explored
            # XXX: This could be optimized by implementing a filter in the index
            # so we don't have to remove them after
            new_subgraphs = {
                entity: filter_subgraph(subgraph, explored_entities)
                for entity, subgraph in new_relations.entities.items()
            }

            explored_entities.update(new_subgraphs.keys())

            # Nothing new was discovered in this hop: keep what was ranked so far and stop
            if not new_subgraphs or all(not subgraph.related_to for subgraph in new_subgraphs.values()):
                break

            relations.entities.update(new_subgraphs)

        # Rank the relevance of the relations
        with metrics.time("graph_strat_rank_relations"):
            try:
                if graph_strategy.relation_ranking == RelationRanking.RERANKER:
                    relations, scores = await rank_relations_reranker(
                        relations,
                        query,
                        kbid,
                        user,
                        top_k=graph_strategy.top_k,
                    )
                elif graph_strategy.relation_ranking == RelationRanking.GENERATIVE:
                    relations, scores = await rank_relations_generative(
                        relations,
                        query,
                        kbid,
                        user,
                        top_k=graph_strategy.top_k,
                        generative_model=generative_model,
                    )
            except Exception as e:
                capture_exception(e)
                logger.exception("Error in ranking relations for graph strategy")
                # Unranked relations have no scores to pair with, so drop everything and stop
                relations = Relations(entities={})
                scores = {}
                break

    # Get the text blocks of the paragraphs that contain the top relations
    with metrics.time("graph_strat_build_response"):
        find_request = find_request_from_ask_request(item, query)
        query_parser, rank_fusion, reranker = await query_parser_from_find_request(
            kbid, find_request, generative_model=generative_model
        )
        find_results = await build_graph_response(
            kbid=kbid,
            query=query,
            final_relations=relations,
            scores=scores,
            top_k=graph_strategy.top_k,
            reranker=reranker,
            show=find_request.show,
            extracted=find_request.extracted,
            field_type_filter=find_request.field_type_filter,
            relation_text_as_paragraphs=graph_strategy.relation_text_as_paragraphs,
        )
    return find_results, query_parser
441
+
442
+
443
async def fuzzy_search_entities(
    kbid: str,
    query: str,
    range_creation_start: Optional[datetime] = None,
    range_creation_end: Optional[datetime] = None,
    range_modification_start: Optional[datetime] = None,
    range_modification_end: Optional[datetime] = None,
    target_shard_replicas: Optional[list[str]] = None,
) -> KnowledgeboxSuggestResults:
    """
    Fuzzy find the KB entities mentioned in `query`.

    Uses the same methodology as /suggest, but issues one entities-only suggest request per word
    of the query and merges all the answers. Any error is reported and results in a response with
    `entities=None`.
    """
    template = nodereader_pb2.SuggestRequest(
        body="", features=[nodereader_pb2.SuggestFeatures.ENTITIES]
    )
    # Only the date bounds that were actually supplied are written into the request
    date_bounds = (
        ("from_created", range_creation_start),
        ("to_created", range_creation_end),
        ("from_modified", range_modification_start),
        ("to_modified", range_modification_end),
    )
    for timestamp_field, moment in date_bounds:
        if moment is not None:
            getattr(template.timestamps, timestamp_field).FromDatetime(moment)

    # XXX: Splitting by words is not ideal, in the future, modify suggest to better handle this
    pending = []
    for token in query.split():
        # Words shorter than three characters are not searched
        if len(token) >= 3:
            word_request = nodereader_pb2.SuggestRequest()
            word_request.CopyFrom(template)
            word_request.body = token
            pending.append(
                node_query(
                    kbid, Method.SUGGEST, word_request, target_shard_replicas=target_shard_replicas
                )
            )

    try:
        raw_results = await asyncio.gather(*pending)
        # The first element of every query result is flattened into a single list before merging
        flattened = [entry for raw in raw_results for entry in raw[0]]
        return await merge_suggest_results(flattened, kbid=kbid)
    except Exception as e:
        capture_exception(e)
        logger.exception("Error in finding entities in query for graph strategy")
        return KnowledgeboxSuggestResults(entities=None)
488
+
489
+
490
async def rank_relations_reranker(
    relations: Relations,
    query: str,
    kbid: str,
    user: str,
    top_k: int,
    score_threshold: float = 0.02,
) -> tuple[Relations, dict[str, list[float]]]:
    """
    Rank the relations against the query with the predict reranker.

    Every relation is rendered as a "head relationship tail" sentence, identical sentences are
    scored once, and the best `top_k` scores above `score_threshold` are mapped back to the
    original relations.

    Returns the selected relations and, per entity, the scores aligned with its relations.
    """
    # Remember (entity, position, relation) so scores can be mapped back after ranking
    indexed_rels: list[tuple[str, int, DirectionalRelation]] = []
    for entity, subgraph in relations.entities.items():
        for position, relation in enumerate(subgraph.related_to):
            indexed_rels.append((entity, position, relation))

    # One triplet per relation; the direction decides which side is the head
    all_triplets: list[dict[str, str]] = []
    for entity, _, relation in indexed_rels:
        if relation.direction == RelationDirection.OUT:
            head, tail = entity, relation.entity
        else:
            head, tail = relation.entity, entity
        all_triplets.append(
            {
                "head_entity": head,
                "relationship": relation.relation_label,
                "tail_entity": tail,
            }
        )

    # Evaluate each distinct triplet once and remember which relations it stands for
    positions_by_triplet: dict[tuple[str, str, str], list[int]] = {}
    distinct_triplets: list[dict[str, str]] = []
    for position, triplet in enumerate(all_triplets):
        signature = (triplet["head_entity"], triplet["relationship"], triplet["tail_entity"])
        if signature not in positions_by_triplet:
            positions_by_triplet[signature] = []
            distinct_triplets.append(triplet)
        positions_by_triplet[signature].append(position)

    # The reranker context is keyed by the index of the distinct triplet (as a string)
    predict = get_predict()
    context: dict[str, str] = {}
    for number, triplet in enumerate(distinct_triplets):
        context[str(number)] = (
            f"{triplet['head_entity']} {triplet['relationship']} {triplet['tail_entity']}"
        )
    rerank_request = RerankModel(question=query, user_id=user, context=context)
    reranked = await predict.rerank(kbid, rerank_request)

    # Turn the string keys back into indices of `distinct_triplets`
    indexed_scores = [(int(key), value) for key, value in reranked.context_scores.items()]

    return _scores_to_ranked_rels(
        distinct_triplets,
        indexed_scores,
        positions_by_triplet,
        indexed_rels,
        top_k,
        score_threshold,
    )
556
+
557
+
558
async def rank_relations_generative(
    relations: Relations,
    query: str,
    kbid: str,
    user: str,
    top_k: int,
    generative_model: Optional[str] = None,
    score_threshold: float = 2,
    max_rels_to_eval: int = 100,
) -> tuple[Relations, dict[str, list[float]]]:
    """
    Rank the relations against the query by asking a generative model to score them.

    Relations are rendered as (head, relationship, tail) triplets, deduplicated, and sent to the
    model together with the question. The model must answer following `SCHEMA`, one scored triplet
    per input triplet, in the same order.

    Args:
        relations: Relations to rank.
        query: User question.
        kbid: Knowledge box id.
        user: User id forwarded to the model.
        top_k: Maximum number of distinct triplets to keep.
        generative_model: Optional model name.
        score_threshold: Triplets scoring at or below this value are discarded.
        max_rels_to_eval: Above this number of relations the reranker is used instead.

    Returns:
        The selected relations and, per entity, the scores aligned with its relations.

    Raises:
        ValueError: on an unknown stream chunk, a missing JSON answer, an unsuccessful status or
            an answer whose length differs from the input.
    """
    # Store the index for keeping track after scoring
    flat_rels: list[tuple[str, int, DirectionalRelation]] = [
        (ent, idx, rel)
        for (ent, rels) in relations.entities.items()
        for (idx, rel) in enumerate(rels.related_to)
    ]

    # Nothing to score: skip the model round trip altogether
    if not flat_rels:
        return Relations(entities={}), {}

    # Decide on the fallback before doing any work that it would throw away.
    # NOTE(review): the fallback keeps up to `max_rels_to_eval` triplets, not `top_k` -- confirm
    # whether the reranked subset was meant to be scored generatively afterwards.
    if len(flat_rels) > max_rels_to_eval:
        logger.warning(f"Too many relations to evaluate ({len(flat_rels)}), using reranker to reduce")
        return await rank_relations_reranker(relations, query, kbid, user, top_k=max_rels_to_eval)

    triplets: list[dict[str, str]] = [
        {
            "head_entity": ent,
            "relationship": rel.relation_label,
            "tail_entity": rel.entity,
        }
        if rel.direction == RelationDirection.OUT
        else {
            "head_entity": rel.entity,
            "relationship": rel.relation_label,
            "tail_entity": ent,
        }
        for (ent, _, rel) in flat_rels
    ]

    # Dedupe triplets so that they get evaluated once, we will re-associate the scores later
    triplet_to_orig_indices: dict[tuple[str, str, str], list[int]] = {}
    unique_triplets: list[dict[str, str]] = []

    for i, t in enumerate(triplets):
        key = (t["head_entity"], t["relationship"], t["tail_entity"])
        if key not in triplet_to_orig_indices:
            triplet_to_orig_indices[key] = []
            unique_triplets.append(t)
        triplet_to_orig_indices[key].append(i)

    data = {
        "question": query,
        "triplets": unique_triplets,
    }
    prompt = PROMPT + json.dumps(data, indent=4)

    predict = get_predict()
    chat_model = ChatModel(
        question=prompt,
        user_id=user,
        json_schema=SCHEMA,
        format_prompt=False,  # We supply our own prompt
        query_context_order={},
        query_context={},
        user_prompt=UserPrompt(prompt=prompt),
        max_tokens=4096,
        generative_model=generative_model,
    )

    # Only the answer stream is needed; the first two returned values are not used here
    _, _, answer_stream = await predict.chat_query_ndjson(kbid, chat_model)
    response_json = None
    status = None

    async for generative_chunk in answer_stream:
        item = generative_chunk.chunk
        if isinstance(item, JSONGenerativeResponse):
            response_json = item
        elif isinstance(item, StatusGenerativeResponse):
            status = item
        elif isinstance(item, MetaGenerativeResponse):
            # Metadata chunks carry nothing needed for scoring
            continue
        else:
            raise ValueError(f"Unknown generative chunk type: {item}")

    if response_json is None:
        raise ValueError("No JSON response found")
    if status is None or status.code != "0":
        status_code = None if status is None else status.code
        raise ValueError(f"Generative ranking did not finish successfully (status code: {status_code})")

    scored_unique_triplets: list[dict[str, Union[str, Any]]] = response_json.object["triplets"]

    # Scores are matched to the input by position, so the lengths have to agree
    if len(scored_unique_triplets) != len(unique_triplets):
        raise ValueError("Mismatch between input and output triplets")

    unique_indices_scores = ((idx, float(t["score"])) for (idx, t) in enumerate(scored_unique_triplets))

    return _scores_to_ranked_rels(
        unique_triplets,
        unique_indices_scores,
        triplet_to_orig_indices,
        flat_rels,
        top_k,
        score_threshold,
    )
657
+
658
+
659
def _scores_to_ranked_rels(
    unique_triplets: list[dict[str, str]],
    unique_indices_scores: Iterable[tuple[int, float]],
    triplet_to_orig_indices: dict[tuple[str, str, str], list[int]],
    flat_rels: list[tuple[str, int, DirectionalRelation]],
    top_k: int,
    score_threshold: float,
) -> tuple[Relations, dict[str, list[float]]]:
    """
    Map the scores given to deduplicated triplets back onto the original relations.

    The `top_k` best-scored triplets are taken, those scoring at or below `score_threshold` are
    dropped, and every original relation behind a surviving triplet is emitted with that score.

    Returns the selected relations grouped by entity and, per entity, the list of scores in the
    same order as its relations.
    """

    def by_score(pair: tuple[int, float]) -> float:
        return pair[1]

    best_pairs = heapq.nlargest(top_k, unique_indices_scores, key=by_score)

    selected_rels: dict[str, EntitySubgraph] = defaultdict(lambda: EntitySubgraph(related_to=[]))
    selected_scores: dict[str, list[float]] = defaultdict(list)

    for triplet_idx, triplet_score in best_pairs:
        # Only scores strictly above the threshold survive
        if not (triplet_score <= score_threshold):
            triplet = unique_triplets[triplet_idx]
            signature = (triplet["head_entity"], triplet["relationship"], triplet["tail_entity"])

            # A deduplicated triplet may stand for several original relations
            for flat_position in triplet_to_orig_indices[signature]:
                entity, _, relation = flat_rels[flat_position]
                # Relation and score are appended together so both lists stay aligned
                selected_rels[entity].related_to.append(relation)
                selected_scores[entity].append(triplet_score)

    return Relations(entities=selected_rels), dict(selected_scores)
696
+
697
+
698
def build_text_blocks_from_relations(
    relations: Relations,
    scores: dict[str, list[float]],
) -> list[TextBlockMatch]:
    """
    Build one TextBlockMatch with hand-made text for every unique triplet in the graph.

    Relations that render to the same (head, relationship, tail) triplet are merged into a single
    block. Triplets for which no relation carries a paragraph id are left out.

    This is a hacky way to generate paragraphs from relations, and it is not the intended use of TextBlockMatch.
    """
    # triplet -> (best score, relations sharing the triplet, first paragraph id seen)
    merged: dict[tuple[str, str, str], tuple[float, Relations, Optional[ParagraphId]]] = {}

    for entity, subgraph in relations.entities.items():
        for relation, relation_score in zip(subgraph.related_to, scores[entity]):
            if relation.direction == RelationDirection.OUT:
                triplet = (entity, relation.relation_label, relation.entity)
            else:
                triplet = (relation.entity, relation.relation_label, entity)

            if triplet in merged:
                best_score, grouped, paragraph_id = merged[triplet]
            else:
                best_score, grouped, paragraph_id = 0.0, Relations(entities={}), None

            if entity not in grouped.entities:
                grouped.entities[entity] = EntitySubgraph(related_to=[])

            # XXX: Since relations with the same triplet can point to different paragraphs,
            # we keep the first one, but we lose the other ones
            if paragraph_id is None and relation.metadata and relation.metadata.paragraph_id:
                paragraph_id = ParagraphId.from_string(relation.metadata.paragraph_id)
            grouped.entities[entity].related_to.append(relation)
            # XXX: Here we use the max even though all relations with same triplet should have same score
            merged[triplet] = (max(best_score, relation_score), grouped, paragraph_id)

    # Build the text blocks
    text_blocks: list[TextBlockMatch] = []
    for (head, label, tail), (triplet_score, triplet_relations, paragraph_id) in merged.items():
        if paragraph_id is None:
            continue
        text_blocks.append(
            TextBlockMatch(
                # XXX: Even though we are setting a paragraph_id, the text is not coming from the paragraph
                paragraph_id=paragraph_id,
                score=triplet_score,
                score_type=SCORE_TYPE.RELATION_RELEVANCE,
                order=0,
                text=f"- {head} {label} {tail}",  # Manually build the text
                position=TextPosition(
                    page_number=0,
                    index=0,
                    start=0,
                    end=0,
                    start_seconds=[],
                    end_seconds=[],
                ),
                field_labels=[],
                paragraph_labels=[],
                fuzzy_search=False,
                is_a_table=False,
                representation_file="",
                page_with_visual=False,
                relevant_relations=triplet_relations,
            )
        )
    return text_blocks
763
+
764
+
765
def get_paragraph_info_from_relations(
    relations: Relations,
    scores: dict[str, list[float]],
) -> list[RelationsParagraphMatch]:
    """
    Collect the paragraphs referenced by the relations, one match per relation with a paragraph id.

    Within each field, a paragraph that ends no later than the previously kept one (after sorting
    by start) is folded into it: the kept match takes the highest score and gains the relations.
    """
    # Containment is only meaningful between paragraphs of the same field
    matches_by_field: dict[FieldId, list[RelationsParagraphMatch]] = defaultdict(list)

    for entity, entity_subgraph in relations.entities.items():
        for relation_score, relation in zip(scores[entity], entity_subgraph.related_to):
            # Relations without a paragraph reference cannot be mapped to a text block
            if not (relation.metadata and relation.metadata.paragraph_id):
                continue
            paragraph_id = ParagraphId.from_string(relation.metadata.paragraph_id)
            matches_by_field[paragraph_id.field_id].append(
                RelationsParagraphMatch(
                    paragraph_id=paragraph_id,
                    score=relation_score,
                    relations=Relations(entities={entity: EntitySubgraph(related_to=[relation])}),
                )
            )

    final_paragraphs: list[RelationsParagraphMatch] = []

    for field_matches in matches_by_field.values():
        # Start ascending and, on ties, end descending: a container always precedes its contents
        ordered = sorted(
            field_matches,
            key=lambda m: (m.paragraph_id.paragraph_start, -m.paragraph_id.paragraph_end),
        )

        survivors: list[RelationsParagraphMatch] = []
        furthest_end = -1

        for candidate in ordered:
            candidate_end = candidate.paragraph_id.paragraph_end

            if candidate_end > furthest_end:
                # Reaches past everything kept so far: keep it as its own paragraph
                survivors.append(candidate)
                furthest_end = candidate_end
                continue

            # Contained in the last kept paragraph: merge score and relations into it
            host = survivors[-1]
            host.score = max(host.score, candidate.score)
            for entity, entity_subgraph in candidate.relations.entities.items():
                if entity not in host.relations.entities:
                    host.relations.entities[entity] = EntitySubgraph(related_to=[])
                host.relations.entities[entity].related_to.extend(entity_subgraph.related_to)

        final_paragraphs.extend(survivors)

    return final_paragraphs
821
+
822
+
823
async def build_graph_response(
    *,
    kbid: str,
    query: str,
    final_relations: Relations,
    scores: dict[str, list[float]],
    top_k: int,
    reranker: Reranker,
    relation_text_as_paragraphs: bool,
    show: Optional[list[ResourceProperties]] = None,
    extracted: Optional[list[ExtractedDataTypeName]] = None,
    field_type_filter: Optional[list[FieldTypeName]] = None,
) -> KnowledgeboxFindResults:
    """
    Turn the selected relations into find results.

    Args:
        kbid: Knowledge box id.
        query: User question, echoed in the results and used for reranking.
        final_relations: Relations selected by the graph strategy.
        scores: Per entity, the scores aligned with its relations in `final_relations`.
        top_k: Maximum number of text blocks requested from hydrate-and-rerank.
        reranker: Reranker applied to the text blocks.
        relation_text_as_paragraphs: When true, each unique triplet becomes a text block with
            generated text; otherwise the paragraphs referenced by the relations are used.
        show: Resource properties to hydrate (defaults to none).
        extracted: Extracted data types to hydrate (defaults to none).
        field_type_filter: Field types to hydrate (defaults to none).

    Returns:
        The find results, including `final_relations` and the number of text blocks as total.
    """
    # Fresh lists per call: list defaults in the signature would be shared between calls and are
    # handed over to the hydration options object.
    show = [] if show is None else show
    extracted = [] if extracted is None else extracted
    field_type_filter = [] if field_type_filter is None else field_type_filter

    if relation_text_as_paragraphs:
        text_blocks = build_text_blocks_from_relations(final_relations, scores)
    else:
        paragraphs_info = get_paragraph_info_from_relations(final_relations, scores)
        text_blocks = relations_matches_to_text_block_matches(paragraphs_info)

    # hydrate and rerank
    resource_hydration_options = ResourceHydrationOptions(
        show=show, extracted=extracted, field_type_filter=field_type_filter
    )
    # Only blocks whose text is still empty get hydrated, so generated relation text is preserved
    text_block_hydration_options = TextBlockHydrationOptions(only_hydrate_empty=True)
    reranking_options = RerankingOptions(kbid=kbid, query=query)
    text_blocks, resources, best_matches = await hydrate_and_rerank(
        text_blocks,
        kbid,
        resource_hydration_options=resource_hydration_options,
        text_block_hydration_options=text_block_hydration_options,
        reranker=reranker,
        reranking_options=reranking_options,
        top_k=top_k,
    )

    find_resources = compose_find_resources(text_blocks, resources)

    return KnowledgeboxFindResults(
        query=query,
        resources=find_resources,
        best_matches=best_matches,
        relations=final_relations,
        total=len(text_blocks),
    )
867
+
868
+
869
def filter_subgraph(subgraph: EntitySubgraph, entities_to_remove: Collection[str]) -> EntitySubgraph:
    """
    Return a new subgraph without the relations whose entity is in `entities_to_remove`.

    The given subgraph is not modified.
    """
    remaining = []
    for relation in subgraph.related_to:
        if relation.entity in entities_to_remove:
            continue
        remaining.append(relation)
    return EntitySubgraph(related_to=remaining)
876
+
877
+
878
def relations_match_to_text_block_match(
    paragraph_match: RelationsParagraphMatch,
) -> TextBlockMatch:
    """
    Convert a relations paragraph match into a TextBlockMatch with the bare minimum fields.

    The Graph Strategy needs this to obtain text blocks for the relevant paragraphs: paragraph id,
    score and relations come from the match, the position spans the paragraph, and everything else
    is left empty.
    """
    # XXX: this is a workaround for the fact we always assume retrieval means keyword/semantic search and
    # the hydration and find response building code works with TextBlockMatch, we extended it to have relevant relations information
    pid = paragraph_match.paragraph_id
    paragraph_span = TextPosition(
        page_number=0,
        index=0,
        start=pid.paragraph_start,
        end=pid.paragraph_end,
        start_seconds=[],
        end_seconds=[],
    )
    return TextBlockMatch(
        paragraph_id=pid,
        score=paragraph_match.score,
        score_type=SCORE_TYPE.RELATION_RELEVANCE,
        order=0,  # NOTE: this will be filled later
        text="",  # NOTE: this will be filled later too
        position=paragraph_span,
        field_labels=[],
        paragraph_labels=[],
        fuzzy_search=False,
        is_a_table=False,
        representation_file="",
        page_with_visual=False,
        relevant_relations=paragraph_match.relations,
    )
910
+
911
+
912
def relations_matches_to_text_block_matches(
    paragraph_matches: Collection[RelationsParagraphMatch],
) -> list[TextBlockMatch]:
    """Convert every relations paragraph match into its TextBlockMatch, keeping the input order."""
    converted: list[TextBlockMatch] = []
    for paragraph_match in paragraph_matches:
        converted.append(relations_match_to_text_block_match(paragraph_match))
    return converted