nucliadb 6.2.0.post2679__py3-none-any.whl → 6.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105) hide show
  1. migrations/0028_extracted_vectors_reference.py +61 -0
  2. migrations/0029_backfill_field_status.py +149 -0
  3. migrations/0030_label_deduplication.py +60 -0
  4. nucliadb/common/cluster/manager.py +41 -331
  5. nucliadb/common/cluster/rebalance.py +2 -2
  6. nucliadb/common/cluster/rollover.py +12 -71
  7. nucliadb/common/cluster/settings.py +3 -0
  8. nucliadb/common/cluster/standalone/utils.py +0 -43
  9. nucliadb/common/cluster/utils.py +0 -16
  10. nucliadb/common/counters.py +1 -0
  11. nucliadb/common/datamanagers/fields.py +48 -7
  12. nucliadb/common/datamanagers/vectorsets.py +11 -2
  13. nucliadb/common/external_index_providers/base.py +2 -1
  14. nucliadb/common/external_index_providers/pinecone.py +3 -5
  15. nucliadb/common/ids.py +18 -4
  16. nucliadb/common/models_utils/from_proto.py +479 -0
  17. nucliadb/common/models_utils/to_proto.py +60 -0
  18. nucliadb/common/nidx.py +76 -37
  19. nucliadb/export_import/models.py +3 -3
  20. nucliadb/health.py +0 -7
  21. nucliadb/ingest/app.py +0 -8
  22. nucliadb/ingest/consumer/auditing.py +1 -1
  23. nucliadb/ingest/consumer/shard_creator.py +1 -1
  24. nucliadb/ingest/fields/base.py +83 -21
  25. nucliadb/ingest/orm/brain.py +55 -56
  26. nucliadb/ingest/orm/broker_message.py +12 -2
  27. nucliadb/ingest/orm/entities.py +6 -17
  28. nucliadb/ingest/orm/knowledgebox.py +44 -22
  29. nucliadb/ingest/orm/processor/data_augmentation.py +7 -29
  30. nucliadb/ingest/orm/processor/processor.py +5 -2
  31. nucliadb/ingest/orm/resource.py +222 -413
  32. nucliadb/ingest/processing.py +8 -2
  33. nucliadb/ingest/serialize.py +77 -46
  34. nucliadb/ingest/service/writer.py +2 -56
  35. nucliadb/ingest/settings.py +1 -4
  36. nucliadb/learning_proxy.py +6 -4
  37. nucliadb/purge/__init__.py +102 -12
  38. nucliadb/purge/orphan_shards.py +6 -4
  39. nucliadb/reader/api/models.py +3 -3
  40. nucliadb/reader/api/v1/__init__.py +1 -0
  41. nucliadb/reader/api/v1/download.py +2 -2
  42. nucliadb/reader/api/v1/knowledgebox.py +3 -3
  43. nucliadb/reader/api/v1/resource.py +23 -12
  44. nucliadb/reader/api/v1/services.py +4 -4
  45. nucliadb/reader/api/v1/vectorsets.py +48 -0
  46. nucliadb/search/api/v1/ask.py +11 -1
  47. nucliadb/search/api/v1/feedback.py +3 -3
  48. nucliadb/search/api/v1/knowledgebox.py +8 -13
  49. nucliadb/search/api/v1/search.py +3 -2
  50. nucliadb/search/api/v1/suggest.py +0 -2
  51. nucliadb/search/predict.py +6 -4
  52. nucliadb/search/requesters/utils.py +1 -2
  53. nucliadb/search/search/chat/ask.py +77 -13
  54. nucliadb/search/search/chat/prompt.py +16 -5
  55. nucliadb/search/search/chat/query.py +74 -34
  56. nucliadb/search/search/exceptions.py +2 -7
  57. nucliadb/search/search/find.py +9 -5
  58. nucliadb/search/search/find_merge.py +10 -4
  59. nucliadb/search/search/graph_strategy.py +884 -0
  60. nucliadb/search/search/hydrator.py +6 -0
  61. nucliadb/search/search/merge.py +79 -24
  62. nucliadb/search/search/query.py +74 -245
  63. nucliadb/search/search/query_parser/exceptions.py +11 -1
  64. nucliadb/search/search/query_parser/fetcher.py +405 -0
  65. nucliadb/search/search/query_parser/models.py +0 -3
  66. nucliadb/search/search/query_parser/parser.py +22 -21
  67. nucliadb/search/search/rerankers.py +1 -42
  68. nucliadb/search/search/shards.py +19 -0
  69. nucliadb/standalone/api_router.py +2 -14
  70. nucliadb/standalone/settings.py +4 -0
  71. nucliadb/train/generators/field_streaming.py +7 -3
  72. nucliadb/train/lifecycle.py +3 -6
  73. nucliadb/train/nodes.py +14 -12
  74. nucliadb/train/resource.py +380 -0
  75. nucliadb/writer/api/constants.py +20 -16
  76. nucliadb/writer/api/v1/__init__.py +1 -0
  77. nucliadb/writer/api/v1/export_import.py +1 -1
  78. nucliadb/writer/api/v1/field.py +13 -7
  79. nucliadb/writer/api/v1/knowledgebox.py +3 -46
  80. nucliadb/writer/api/v1/resource.py +20 -13
  81. nucliadb/writer/api/v1/services.py +10 -1
  82. nucliadb/writer/api/v1/upload.py +61 -34
  83. nucliadb/writer/{vectorsets.py → api/v1/vectorsets.py} +99 -47
  84. nucliadb/writer/back_pressure.py +17 -46
  85. nucliadb/writer/resource/basic.py +9 -7
  86. nucliadb/writer/resource/field.py +42 -9
  87. nucliadb/writer/settings.py +2 -2
  88. nucliadb/writer/tus/gcs.py +11 -10
  89. {nucliadb-6.2.0.post2679.dist-info → nucliadb-6.2.1.dist-info}/METADATA +11 -14
  90. {nucliadb-6.2.0.post2679.dist-info → nucliadb-6.2.1.dist-info}/RECORD +94 -96
  91. {nucliadb-6.2.0.post2679.dist-info → nucliadb-6.2.1.dist-info}/WHEEL +1 -1
  92. nucliadb/common/cluster/discovery/base.py +0 -178
  93. nucliadb/common/cluster/discovery/k8s.py +0 -301
  94. nucliadb/common/cluster/discovery/manual.py +0 -57
  95. nucliadb/common/cluster/discovery/single.py +0 -51
  96. nucliadb/common/cluster/discovery/types.py +0 -32
  97. nucliadb/common/cluster/discovery/utils.py +0 -67
  98. nucliadb/common/cluster/standalone/grpc_node_binding.py +0 -349
  99. nucliadb/common/cluster/standalone/index_node.py +0 -123
  100. nucliadb/common/cluster/standalone/service.py +0 -84
  101. nucliadb/standalone/introspect.py +0 -208
  102. nucliadb-6.2.0.post2679.dist-info/zip-safe +0 -1
  103. /nucliadb/common/{cluster/discovery → models_utils}/__init__.py +0 -0
  104. {nucliadb-6.2.0.post2679.dist-info → nucliadb-6.2.1.dist-info}/entry_points.txt +0 -0
  105. {nucliadb-6.2.0.post2679.dist-info → nucliadb-6.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,884 @@
1
+ # Copyright (C) 2021 Bosutech XXI S.L.
2
+ #
3
+ # nucliadb is offered under the AGPL v3.0 and as commercial software.
4
+ # For commercial licensing, contact us at info@nuclia.com.
5
+ #
6
+ # AGPL:
7
+ # This program is free software: you can redistribute it and/or modify
8
+ # it under the terms of the GNU Affero General Public License as
9
+ # published by the Free Software Foundation, either version 3 of the
10
+ # License, or (at your option) any later version.
11
+ #
12
+ # This program is distributed in the hope that it will be useful,
13
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
14
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15
+ # GNU Affero General Public License for more details.
16
+ #
17
+ # You should have received a copy of the GNU Affero General Public License
18
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
+
20
+ import heapq
21
+ import json
22
+ from collections import defaultdict
23
+ from typing import Any, Collection, Iterable, Optional, Union
24
+
25
+ from nuclia_models.predict.generative_responses import (
26
+ JSONGenerativeResponse,
27
+ MetaGenerativeResponse,
28
+ StatusGenerativeResponse,
29
+ )
30
+ from pydantic import BaseModel
31
+ from sentry_sdk import capture_exception
32
+
33
+ from nucliadb.common.external_index_providers.base import TextBlockMatch
34
+ from nucliadb.common.ids import FieldId, ParagraphId
35
+ from nucliadb.search import logger
36
+ from nucliadb.search.requesters.utils import Method, node_query
37
+ from nucliadb.search.search.chat.query import (
38
+ find_request_from_ask_request,
39
+ get_relations_results_from_entities,
40
+ )
41
+ from nucliadb.search.search.find import query_parser_from_find_request
42
+ from nucliadb.search.search.find_merge import (
43
+ compose_find_resources,
44
+ hydrate_and_rerank,
45
+ )
46
+ from nucliadb.search.search.hydrator import ResourceHydrationOptions, TextBlockHydrationOptions
47
+ from nucliadb.search.search.merge import merge_relation_prefix_results
48
+ from nucliadb.search.search.metrics import RAGMetrics
49
+ from nucliadb.search.search.rerankers import Reranker, RerankingOptions
50
+ from nucliadb.search.utilities import get_predict
51
+ from nucliadb_models.common import FieldTypeName
52
+ from nucliadb_models.internal.predict import (
53
+ RerankModel,
54
+ )
55
+ from nucliadb_models.resource import ExtractedDataTypeName
56
+ from nucliadb_models.search import (
57
+ SCORE_TYPE,
58
+ AskRequest,
59
+ ChatModel,
60
+ DirectionalRelation,
61
+ EntitySubgraph,
62
+ FindRequest,
63
+ GraphStrategy,
64
+ KnowledgeboxFindResults,
65
+ NucliaDBClientType,
66
+ QueryEntityDetection,
67
+ RelatedEntities,
68
+ RelationDirection,
69
+ RelationRanking,
70
+ Relations,
71
+ ResourceProperties,
72
+ TextPosition,
73
+ UserPrompt,
74
+ )
75
+ from nucliadb_protos import nodereader_pb2
76
+ from nucliadb_protos.utils_pb2 import RelationNode
77
+
78
# JSON Schema describing the structured answer requested from the generative
# model when scoring triplets. It is composed from smaller pieces so each level
# of nesting can be read on its own.

# Properties of a single scored triplet.
_SCORED_TRIPLET_PROPERTIES = {
    "head_entity": {"type": "string", "description": "The first entity in the triplet."},
    "relationship": {
        "type": "string",
        "description": "The relationship between the two entities.",
    },
    "tail_entity": {
        "type": "string",
        "description": "The second entity in the triplet.",
    },
    "score": {
        "type": "integer",
        "description": "A relevance score in the range 0 to 10.",
        "minimum": 0,
        "maximum": 10,
    },
}

# One element of the "triplets" array: all four properties are mandatory.
_SCORED_TRIPLET_SCHEMA = {
    "type": "object",
    "properties": _SCORED_TRIPLET_PROPERTIES,
    "required": ["head_entity", "relationship", "tail_entity", "score"],
}

SCHEMA = {
    "title": "score_triplets",
    "description": "Return a list of triplets and their relevance scores (0-10) for the supplied question.",
    "type": "object",
    "properties": {
        "triplets": {
            "type": "array",
            "description": "A list of triplets with their relevance scores.",
            "items": _SCORED_TRIPLET_SCHEMA,
        }
    },
    "required": ["triplets"],
}
111
+
112
+ PROMPT = """\
113
+ You are an advanced language model assisting in scoring relationships (edges) between two entities in a knowledge graph, given a user’s question.
114
+
115
+ For each provided **(head_entity, relationship, tail_entity)**, you must:
116
+ 1. Assign a **relevance score** between **0** and **10**.
117
+ 2. **0** means “this relationship can’t be relevant at all to the question.”
118
+ 3. **10** means “this relationship is extremely relevant to the question.”
119
+ 4. You may use **any integer** between 0 and 10 (e.g., 3, 7, etc.) based on how relevant you deem the relationship to be.
120
+ 5. **Language Agnosticism**: The question and the relationships may be in different languages. The relevance scoring should still work and be agnostic of the language.
121
+ 6. Relationships that may not answer the question directly but expand knowledge in a relevant way, should also be scored positively.
122
+
123
+ Once you have decided the best score for each triplet, return these results **using a function call** in JSON format with the following rules:
124
+
125
+ - The function name should be `score_triplets`.
126
+ - The first argument should be the list of triplets.
127
+ - Each triplet should have the following keys:
128
+ - `head_entity`: The first entity in the triplet.
129
+ - `relationship`: The relationship between the two entities.
130
+ - `tail_entity`: The second entity in the triplet.
131
+ - `score`: The relevance score in the range 0 to 10.
132
+
133
+ You **must** comply with the provided JSON Schema to ensure a well-structured response and mantain the order of the triplets.
134
+
135
+
136
+ ## Examples:
137
+
138
+ ### Example 1:
139
+
140
+ **Input**
141
+
142
+ {
143
+ "question": "Who is the mayor of the capital city of Australia?",
144
+ "triplets": [
145
+ {
146
+ "head_entity": "Australia",
147
+ "relationship": "has prime minister",
148
+ "tail_entity": "Scott Morrison"
149
+ },
150
+ {
151
+ "head_entity": "Canberra",
152
+ "relationship": "is capital of",
153
+ "tail_entity": "Australia"
154
+ },
155
+ {
156
+ "head_entity": "Scott Knowles",
157
+ "relationship": "holds position",
158
+ "tail_entity": "Mayor"
159
+ },
160
+ {
161
+ "head_entity": "Barbera Smith",
162
+ "relationship": "tiene cargo",
163
+ "tail_entity": "Alcalde"
164
+ },
165
+ {
166
+ "head_entity": "Austria",
167
+ "relationship": "has capital",
168
+ "tail_entity": "Vienna"
169
+ }
170
+ ]
171
+ }
172
+
173
+ **Output**
174
+
175
+ {
176
+ "triplets": [
177
+ {
178
+ "head_entity": "Australia",
179
+ "relationship": "has prime minister",
180
+ "tail_entity": "Scott Morrison",
181
+ "score": 4
182
+ },
183
+ {
184
+ "head_entity": "Canberra",
185
+ "relationship": "is capital of",
186
+ "tail_entity": "Australia",
187
+ "score": 8
188
+ },
189
+ {
190
+ "head_entity": "Scott Knowles",
191
+ "relationship": "holds position",
192
+ "tail_entity": "Mayor",
193
+ "score": 8
194
+ },
195
+ {
196
+ "head_entity": "Barbera Smith",
197
+ "relationship": "tiene cargo",
198
+ "tail_entity": "Alcalde",
199
+ "score": 8
200
+ },
201
+ {
202
+ "head_entity": "Austria",
203
+ "relationship": "has capital",
204
+ "tail_entity": "Vienna",
205
+ "score": 0
206
+ }
207
+ ]
208
+ }
209
+
210
+
211
+
212
+ ### Example 2:
213
+
214
+ **Input**
215
+
216
+ {
217
+ "question": "How many products does John Adams Roofing Inc. offer?",
218
+ "triplets": [
219
+ {
220
+ "head_entity": "John Adams Roofing Inc.",
221
+ "relationship": "has product",
222
+ "tail_entity": "Titanium Grade 3 Roofing Nails"
223
+ },
224
+ {
225
+ "head_entity": "John Adams Roofing Inc.",
226
+ "relationship": "is located in",
227
+ "tail_entity": "New York"
228
+ },
229
+ {
230
+ "head_entity": "John Adams Roofing Inc.",
231
+ "relationship": "was founded by",
232
+ "tail_entity": "John Adams"
233
+ },
234
+ {
235
+ "head_entity": "John Adams Roofing Inc.",
236
+ "relationship": "tiene stock",
237
+ "tail_entity": "Baldosas solares"
238
+ },
239
+ {
240
+ "head_entity": "John Adams Roofing Inc.",
241
+ "relationship": "has product",
242
+ "tail_entity": "Mercerized Cotton Thread"
243
+ }
244
+ ]
245
+ }
246
+
247
+ **Output**
248
+
249
+ {
250
+ "triplets": [
251
+ {
252
+ "head_entity": "John Adams Roofing Inc.",
253
+ "relationship": "has product",
254
+ "tail_entity": "Titanium Grade 3 Roofing Nails",
255
+ "score": 10
256
+ },
257
+ {
258
+ "head_entity": "John Adams Roofing Inc.",
259
+ "relationship": "is located in",
260
+ "tail_entity": "New York",
261
+ "score": 6
262
+ },
263
+ {
264
+ "head_entity": "John Adams Roofing Inc.",
265
+ "relationship": "was founded by",
266
+ "tail_entity": "John Adams",
267
+ "score": 5
268
+ },
269
+ {
270
+ "head_entity": "John Adams Roofing Inc.",
271
+ "relationship": "tiene stock",
272
+ "tail_entity": "Baldosas solares",
273
+ "score": 10
274
+ },
275
+ {
276
+ "head_entity": "John Adams Roofing Inc.",
277
+ "relationship": "has product",
278
+ "tail_entity": "Mercerized Cotton Thread",
279
+ "score": 10
280
+ }
281
+ ]
282
+ }
283
+
284
+ Now, let's get started! Here are the triplets you need to score:
285
+
286
+ **Input**
287
+
288
+ """
289
+
290
+
291
# Pairs a paragraph with the relations found in it and the score assigned to
# that paragraph. Instances are produced by get_paragraph_info_from_relations.
class RelationsParagraphMatch(BaseModel):
    # Paragraph referenced by the relation metadata.
    paragraph_id: ParagraphId
    # Score of the paragraph; when paragraphs are merged the highest score wins.
    score: float
    # Relations attached to this paragraph, grouped by entity.
    relations: Relations
295
+
296
+
297
async def get_graph_results(
    *,
    kbid: str,
    query: str,
    item: AskRequest,
    ndb_client: NucliaDBClientType,
    user: str,
    origin: str,
    graph_strategy: GraphStrategy,
    generative_model: Optional[str] = None,
    metrics: Optional[RAGMetrics] = None,
    shards: Optional[list[str]] = None,
) -> tuple[KnowledgeboxFindResults, FindRequest]:
    """Explore the knowledge graph around the query and build find results from it.

    The graph is walked for up to ``graph_strategy.hops`` hops:

    * On the first hop the starting entities are detected in the query, either
      with ``fuzzy_search_entities`` (``QueryEntityDetection.SUGGEST``) or with
      ``predict.detect_entities``. Predict is also used as a fallback when the
      suggest lookup yields no entities.
    * On later hops the neighbours of the relations kept so far are explored,
      skipping entities that were already explored.

    After each hop the accumulated relations are ranked (reranker or generative
    model, depending on ``graph_strategy.relation_ranking``). Exploration stops
    early when a hop brings no new relations or when ranking fails.

    Args:
        kbid: Knowledge box identifier.
        query: User query.
        item: Original ask request, used to derive the find request.
        ndb_client: Not used by this function body.
        user: User identifier forwarded to the ranking helpers.
        origin: Not used by this function body.
        graph_strategy: Strategy settings (hops, top_k, detection and ranking modes...).
        generative_model: Optional generative model name forwarded to the
            generative ranker and to the query parser.
        metrics: Metrics collector used to time each phase. When omitted, a
            fresh ``RAGMetrics`` is created for this call.
        shards: Optional shard replicas to target when fetching relations.

    Returns:
        The find results built from the ranked relations and the find request
        derived from ``item``.
    """
    # A default of ``RAGMetrics()`` in the signature would be evaluated once at
    # definition time and shared by every call; create one per call instead.
    if metrics is None:
        metrics = RAGMetrics()

    relations = Relations(entities={})
    explored_entities: set[str] = set()
    scores: dict[str, list[float]] = {}
    predict = get_predict()

    for hop in range(graph_strategy.hops):
        entities_to_explore: Iterable[RelationNode] = []

        if hop == 0:
            # Get the entities from the query
            with metrics.time("graph_strat_query_entities"):
                detection = graph_strategy.query_entity_detection
                # Materialised as a list so that emptiness can actually be tested
                # (a generator is always truthy).
                query_entities: list[RelationNode] = []
                if detection == QueryEntityDetection.SUGGEST:
                    relation_result = await fuzzy_search_entities(
                        kbid=kbid,
                        query=query,
                    )
                    if relation_result is not None:
                        query_entities = [
                            RelationNode(
                                ntype=RelationNode.NodeType.ENTITY,
                                value=result.value,
                                subtype=result.family,
                            )
                            for result in relation_result.entities
                        ]
                # Separate ``if`` (not ``elif``): this must also run when the
                # suggest lookup above failed or found nothing.
                if not query_entities or detection == QueryEntityDetection.PREDICT:
                    try:
                        # Purposely ignore the entity subtype. This is done so we find all entities that match
                        # the entity by name. e.g: in a query like "2000", predict might detect the number as
                        # a year entity or as a currency entity. We want graph results for both, so we ignore the
                        # subtype just in this case.
                        query_entities = [
                            RelationNode(ntype=r.ntype, value=r.value, subtype="")
                            for r in await predict.detect_entities(kbid, query)
                        ]
                    except Exception as e:
                        capture_exception(e)
                        logger.exception("Error in detecting entities for graph strategy")
                        query_entities = []
                entities_to_explore = query_entities
        else:
            # Find neighbors of the current relations and remove the ones already explored
            entities_to_explore = (
                RelationNode(
                    ntype=RelationNode.NodeType.ENTITY,
                    value=relation.entity,
                    subtype=relation.entity_subtype,
                )
                for subgraph in relations.entities.values()
                for relation in subgraph.related_to
                if relation.entity not in explored_entities
            )

        # Get the relations for the new entities
        with metrics.time("graph_strat_neighbor_relations"):
            try:
                new_relations = await get_relations_results_from_entities(
                    kbid=kbid,
                    entities=entities_to_explore,
                    target_shard_replicas=shards,
                    timeout=5.0,
                    only_with_metadata=True,
                    only_agentic_relations=graph_strategy.agentic_graph_only,
                    deleted_entities=explored_entities,
                )
            except Exception as e:
                capture_exception(e)
                logger.exception("Error in getting query relations for graph strategy")
                new_relations = Relations(entities={})

            new_subgraphs = new_relations.entities

            explored_entities.update(new_subgraphs.keys())

            # Nothing new was discovered on this hop: further hops cannot help
            if not new_subgraphs or all(not subgraph.related_to for subgraph in new_subgraphs.values()):
                break

            relations.entities.update(new_subgraphs)

        # Rank the relevance of the relations
        with metrics.time("graph_strat_rank_relations"):
            try:
                if graph_strategy.relation_ranking == RelationRanking.RERANKER:
                    relations, scores = await rank_relations_reranker(
                        relations,
                        query,
                        kbid,
                        user,
                        top_k=graph_strategy.top_k,
                    )
                elif graph_strategy.relation_ranking == RelationRanking.GENERATIVE:
                    relations, scores = await rank_relations_generative(
                        relations,
                        query,
                        kbid,
                        user,
                        top_k=graph_strategy.top_k,
                        generative_model=generative_model,
                    )
            except Exception as e:
                capture_exception(e)
                logger.exception("Error in ranking relations for graph strategy")
                # Without scores the relations cannot be used: return an empty graph
                relations = Relations(entities={})
                scores = {}
                break

    # Get the text blocks of the paragraphs that contain the top relations
    with metrics.time("graph_strat_build_response"):
        find_request = find_request_from_ask_request(item, query)
        query_parser, rank_fusion, reranker = await query_parser_from_find_request(
            kbid, find_request, generative_model=generative_model
        )
        find_results = await build_graph_response(
            kbid=kbid,
            query=query,
            final_relations=relations,
            scores=scores,
            top_k=graph_strategy.top_k,
            reranker=reranker,
            show=find_request.show,
            extracted=find_request.extracted,
            field_type_filter=find_request.field_type_filter,
            relation_text_as_paragraphs=graph_strategy.relation_text_as_paragraphs,
        )
    return find_results, find_request
437
+
438
+
439
async def fuzzy_search_entities(
    kbid: str,
    query: str,
) -> Optional[RelatedEntities]:
    """Fuzzy find entities in KB given a query using the same methodology as /suggest, but split by words.

    A search request carrying only a relation-prefix query is sent to the
    read-replica nodes (without retrying on the primary) and the responses are
    merged with ``merge_relation_prefix_results``.

    Returns ``None`` when anything fails; the error is reported to Sentry and
    logged rather than propagated.
    """
    search_request = nodereader_pb2.SearchRequest()
    search_request.relation_prefix.query = query

    try:
        # node_query returns a 3-tuple; only the responses are needed here
        shard_responses, _, _ = await node_query(
            kbid,
            Method.SEARCH,
            search_request,
            use_read_replica_nodes=True,
            retry_on_primary=False,
        )
        return merge_relation_prefix_results(shard_responses)
    except Exception as exc:
        capture_exception(exc)
        logger.exception("Error in finding entities in query for graph strategy")
        return None
466
+
467
+
468
async def rank_relations_reranker(
    relations: Relations,
    query: str,
    kbid: str,
    user: str,
    top_k: int,
    score_threshold: float = 0.02,
) -> tuple[Relations, dict[str, list[float]]]:
    """Score every relation against the query with the predict reranker.

    Relations are flattened, turned into (head, relationship, tail) triplets
    honouring the relation direction, and deduplicated so each distinct triplet
    is scored once. The scores are then mapped back onto the original relations
    by ``_scores_to_ranked_rels``, which applies ``top_k`` and ``score_threshold``.

    Returns the kept relations and, per entity, the scores aligned with them.
    """
    # Flatten the graph, remembering where each relation came from
    flat_rels: list[tuple[str, int, DirectionalRelation]] = []
    for entity, subgraph in relations.entities.items():
        for position, relation in enumerate(subgraph.related_to):
            flat_rels.append((entity, position, relation))

    # Build each triplet and dedupe in a single pass; map triplet -> [flat indices]
    unique_triplets: list[dict[str, str]] = []
    triplet_to_orig_indices: dict[tuple[str, str, str], list[int]] = {}
    for flat_index, (entity, _, relation) in enumerate(flat_rels):
        # Outgoing relations start at the entity; any other direction ends at it
        if relation.direction == RelationDirection.OUT:
            head, tail = entity, relation.entity
        else:
            head, tail = relation.entity, entity
        label = relation.relation_label

        key = (head, label, tail)
        occurrences = triplet_to_orig_indices.get(key)
        if occurrences is None:
            occurrences = []
            triplet_to_orig_indices[key] = occurrences
            unique_triplets.append(
                {
                    "head_entity": head,
                    "relationship": label,
                    "tail_entity": tail,
                }
            )
        occurrences.append(flat_index)

    # The reranker receives each unique triplet as a short sentence keyed by its index
    predict = get_predict()
    context: dict[str, str] = {}
    for unique_index, triplet in enumerate(unique_triplets):
        context[str(unique_index)] = (
            f"{triplet['head_entity']} {triplet['relationship']} {triplet['tail_entity']}"
        )
    rerank_model = RerankModel(
        question=query,
        user_id=user,
        context=context,
    )
    reranked = await predict.rerank(kbid, rerank_model)

    # Keys come back as strings; turn them into indices of unique_triplets again
    reranked_indices_scores = [
        (int(raw_index), score) for raw_index, score in reranked.context_scores.items()
    ]

    return _scores_to_ranked_rels(
        unique_triplets,
        reranked_indices_scores,
        triplet_to_orig_indices,
        flat_rels,
        top_k,
        score_threshold,
    )
534
+
535
+
536
async def rank_relations_generative(
    relations: Relations,
    query: str,
    kbid: str,
    user: str,
    top_k: int,
    generative_model: Optional[str] = None,
    score_threshold: float = 2,
    max_rels_to_eval: int = 100,
) -> tuple[Relations, dict[str, list[float]]]:
    """Score every relation against the query with a generative model.

    The distinct (head, relationship, tail) triplets are sent, together with the
    question, appended to ``PROMPT``; the model must answer following ``SCHEMA``.
    Scores are matched to triplets by position, then mapped back onto the
    original relations by ``_scores_to_ranked_rels``.

    When there are more than ``max_rels_to_eval`` relations, the reranker result
    (limited to ``max_rels_to_eval``) is returned instead and the generative
    model is not called.

    Raises:
        ValueError: If the stream contains an unknown chunk type, finishes
            without a successful status, carries no JSON answer, or the answer
            does not contain one entry per input triplet.
    """
    # Store the index for keeping track after scoring
    flat_rels: list[tuple[str, int, DirectionalRelation]] = [
        (ent, idx, rel)
        for (ent, rels) in relations.entities.items()
        for (idx, rel) in enumerate(rels.related_to)
    ]

    # Check the size limit before building triplets: the fallback does not need them
    if len(flat_rels) > max_rels_to_eval:
        logger.warning(f"Too many relations to evaluate ({len(flat_rels)}), using reranker to reduce")
        return await rank_relations_reranker(relations, query, kbid, user, top_k=max_rels_to_eval)

    triplets: list[dict[str, str]] = [
        {
            "head_entity": ent,
            "relationship": rel.relation_label,
            "tail_entity": rel.entity,
        }
        if rel.direction == RelationDirection.OUT
        else {
            "head_entity": rel.entity,
            "relationship": rel.relation_label,
            "tail_entity": ent,
        }
        for (ent, _, rel) in flat_rels
    ]

    # Dedupe triplets so that they get evaluated once, we will re-associate the scores later
    triplet_to_orig_indices: dict[tuple[str, str, str], list[int]] = {}
    unique_triplets: list[dict[str, str]] = []

    for i, t in enumerate(triplets):
        key = (t["head_entity"], t["relationship"], t["tail_entity"])
        if key not in triplet_to_orig_indices:
            triplet_to_orig_indices[key] = []
            unique_triplets.append(t)
        triplet_to_orig_indices[key].append(i)

    data = {
        "question": query,
        "triplets": unique_triplets,
    }
    prompt = PROMPT + json.dumps(data, indent=4)

    predict = get_predict()
    chat_model = ChatModel(
        question=prompt,
        user_id=user,
        json_schema=SCHEMA,
        format_prompt=False,  # We supply our own prompt
        query_context_order={},
        query_context={},
        user_prompt=UserPrompt(prompt=prompt),
        max_tokens=4096,
        generative_model=generative_model,
    )

    # Only the stream is needed; the other two returned values are ignored
    _, _, answer_stream = await predict.chat_query_ndjson(kbid, chat_model)
    response_json: Optional[JSONGenerativeResponse] = None
    status: Optional[StatusGenerativeResponse] = None

    async for generative_chunk in answer_stream:
        chunk = generative_chunk.chunk
        if isinstance(chunk, JSONGenerativeResponse):
            response_json = chunk
        elif isinstance(chunk, StatusGenerativeResponse):
            status = chunk
        elif not isinstance(chunk, MetaGenerativeResponse):
            # Meta chunks are expected but carry nothing needed here
            raise ValueError(f"Unknown generative chunk type: {chunk}")

    # Report the actual failure instead of a generic "no JSON" message
    if status is None or status.code != "0":
        raise ValueError(f"Generative ranking did not finish successfully (status: {status})")
    if response_json is None:
        raise ValueError("No JSON response found")

    scored_unique_triplets: list[dict[str, Union[str, Any]]] = response_json.object["triplets"]

    # Scores are associated by position, so the answer must mirror the input
    if len(scored_unique_triplets) != len(unique_triplets):
        raise ValueError("Mismatch between input and output triplets")

    unique_indices_scores = ((idx, float(t["score"])) for (idx, t) in enumerate(scored_unique_triplets))

    return _scores_to_ranked_rels(
        unique_triplets,
        unique_indices_scores,
        triplet_to_orig_indices,
        flat_rels,
        top_k,
        score_threshold,
    )
635
+
636
+
637
def _scores_to_ranked_rels(
    unique_triplets: list[dict[str, str]],
    unique_indices_scores: Iterable[tuple[int, float]],
    triplet_to_orig_indices: dict[tuple[str, str, str], list[int]],
    flat_rels: list[tuple[str, int, DirectionalRelation]],
    top_k: int,
    score_threshold: float,
) -> tuple[Relations, dict[str, list[float]]]:
    """
    Map scores given to unique triplets back onto the original relations.

    Only the ``top_k`` best scored unique triplets are considered, and among
    those any whose score is not strictly above ``score_threshold`` is dropped.
    Every original relation sharing a kept triplet receives that triplet's score.

    Returns the kept relations grouped by entity and, for each entity, the list
    of scores in the same order as its relations.
    """
    best_scored = heapq.nlargest(top_k, unique_indices_scores, key=lambda pair: pair[1])

    ranked_subgraphs: dict[str, EntitySubgraph] = {}
    ranked_scores: dict[str, list[float]] = {}

    for unique_index, model_score in best_scored:
        # Scores at or below the threshold are discarded
        if model_score <= score_threshold:
            continue

        # Locate every original relation (in flat_rels) that produced this triplet
        triplet = unique_triplets[unique_index]
        triplet_key = (triplet["head_entity"], triplet["relationship"], triplet["tail_entity"])

        for flat_index in triplet_to_orig_indices[triplet_key]:
            entity, _, relation = flat_rels[flat_index]

            if entity not in ranked_subgraphs:
                ranked_subgraphs[entity] = EntitySubgraph(related_to=[])
            ranked_subgraphs[entity].related_to.append(relation)

            # Scores are appended in lockstep with the relations of the entity
            if entity not in ranked_scores:
                ranked_scores[entity] = []
            ranked_scores[entity].append(model_score)

    return Relations(entities=ranked_subgraphs), ranked_scores
674
+
675
+
676
def build_text_blocks_from_relations(
    relations: Relations,
    scores: dict[str, list[float]],
) -> list[TextBlockMatch]:
    """
    Produce one TextBlockMatch, with hand-built text, per distinct triplet of the graph.

    Relations sharing the same (head, relationship, tail) triplet are grouped:
    the highest score is kept, all of the relations are collected, and the first
    paragraph id found among them is used. Triplets without any paragraph id
    are left out.

    This is a hacky way to generate paragraphs from relations, and it is not the intended use of TextBlockMatch.
    """
    # triplet -> (best score, relations sharing the triplet, first paragraph id seen)
    grouped: dict[tuple[str, str, str], tuple[float, Relations, Optional[ParagraphId]]] = {}

    for entity, subgraph in relations.entities.items():
        for relation, relation_score in zip(subgraph.related_to, scores[entity]):
            if relation.direction == RelationDirection.OUT:
                triplet = (entity, relation.relation_label, relation.entity)
            else:
                triplet = (relation.entity, relation.relation_label, entity)

            if triplet in grouped:
                best_score, shared_relations, paragraph_id = grouped[triplet]
            else:
                best_score, shared_relations, paragraph_id = 0.0, Relations(entities={}), None

            if entity not in shared_relations.entities:
                shared_relations.entities[entity] = EntitySubgraph(related_to=[])

            # XXX: Since relations with the same triplet can point to different paragraphs,
            # we keep the first one, but we lose the other ones
            if paragraph_id is None and relation.metadata and relation.metadata.paragraph_id:
                paragraph_id = ParagraphId.from_string(relation.metadata.paragraph_id)
            shared_relations.entities[entity].related_to.append(relation)

            # XXX: Here we use the max even though all relations with same triplet should have same score
            grouped[triplet] = (max(best_score, relation_score), shared_relations, paragraph_id)

    # Build the text blocks
    text_blocks: list[TextBlockMatch] = []
    for (head, label, tail), (best_score, shared_relations, paragraph_id) in grouped.items():
        if paragraph_id is None:
            continue
        # The position is a placeholder: the text does not come from the paragraph
        placeholder_position = TextPosition(
            page_number=0,
            index=0,
            start=0,
            end=0,
            start_seconds=[],
            end_seconds=[],
        )
        text_blocks.append(
            TextBlockMatch(
                # XXX: Even though we are setting a paragraph_id, the text is not coming from the paragraph
                paragraph_id=paragraph_id,
                score=best_score,
                score_type=SCORE_TYPE.RELATION_RELEVANCE,
                order=0,
                text=f"- {head} {label} {tail}",  # Manually build the text
                position=placeholder_position,
                field_labels=[],
                paragraph_labels=[],
                fuzzy_search=False,
                is_a_table=False,
                representation_file="",
                page_with_visual=False,
                relevant_relations=shared_relations,
            )
        )
    return text_blocks
741
+
742
+
743
def get_paragraph_info_from_relations(
    relations: Relations,
    scores: dict[str, list[float]],
) -> list[RelationsParagraphMatch]:
    """
    Collect the paragraphs referenced by the relations and fold contained ones.

    Each relation carrying a paragraph id yields a match with its score. Within
    a field, a paragraph whose span lies inside the previously kept paragraph is
    merged into it: the higher score is kept and its relations are added to the
    container.
    """
    # Matches are bucketed per field, since containment only makes sense inside one field
    matches_by_field: dict[FieldId, list[RelationsParagraphMatch]] = defaultdict(list)

    for entity, subgraph in relations.entities.items():
        for relation_score, relation in zip(scores[entity], subgraph.related_to):
            # Relations without a paragraph reference cannot be placed anywhere
            if not (relation.metadata and relation.metadata.paragraph_id):
                continue
            paragraph_id = ParagraphId.from_string(relation.metadata.paragraph_id)
            matches_by_field[paragraph_id.field_id].append(
                RelationsParagraphMatch(
                    paragraph_id=paragraph_id,
                    score=relation_score,
                    relations=Relations(entities={entity: EntitySubgraph(related_to=[relation])}),
                )
            )

    final_paragraphs: list[RelationsParagraphMatch] = []

    for field_matches in matches_by_field.values():
        # Start ascending, and for equal starts the longest first, so that a
        # container always precedes whatever it contains
        ordered = sorted(
            field_matches,
            key=lambda m: (m.paragraph_id.paragraph_start, -m.paragraph_id.paragraph_end),
        )

        survivors: list[RelationsParagraphMatch] = []
        furthest_end = -1

        for candidate in ordered:
            candidate_end = candidate.paragraph_id.paragraph_end

            if candidate_end > furthest_end:
                # Reaches beyond everything kept so far: it is not contained
                survivors.append(candidate)
                furthest_end = candidate_end
                continue

            # Contained in the last kept paragraph: merge score and relations into it
            enclosing = survivors[-1]
            enclosing.score = max(enclosing.score, candidate.score)
            for contained_entity, contained_subgraph in candidate.relations.entities.items():
                if contained_entity not in enclosing.relations.entities:
                    enclosing.relations.entities[contained_entity] = EntitySubgraph(related_to=[])
                enclosing.relations.entities[contained_entity].related_to.extend(
                    contained_subgraph.related_to
                )

        final_paragraphs.extend(survivors)

    return final_paragraphs
799
+
800
+
801
async def build_graph_response(
    *,
    kbid: str,
    query: str,
    final_relations: Relations,
    scores: dict[str, list[float]],
    top_k: int,
    reranker: Reranker,
    relation_text_as_paragraphs: bool,
    show: Collection[ResourceProperties] = (),
    extracted: Collection[ExtractedDataTypeName] = (),
    field_type_filter: Collection[FieldTypeName] = (),
) -> KnowledgeboxFindResults:
    """
    Build a find-style response out of the relations selected by the graph strategy.

    Text blocks are obtained either directly from the relations text
    (`relation_text_as_paragraphs=True`) or from the paragraphs the relations
    point to. They are then passed through `hydrate_and_rerank` together with
    the given reranker and `top_k`, and the outcome is wrapped in a
    KnowledgeboxFindResults that also carries `final_relations`.

    All arguments are keyword-only. `show`, `extracted` and
    `field_type_filter` are forwarded as resource hydration options; they
    default to empty and are copied into fresh lists on every call, so no
    list object is shared between calls.
    """
    if relation_text_as_paragraphs:
        text_blocks = build_text_blocks_from_relations(final_relations, scores)
    else:
        paragraphs_info = get_paragraph_info_from_relations(final_relations, scores)
        text_blocks = relations_matches_to_text_block_matches(paragraphs_info)

    # hydrate and rerank
    resource_hydration_options = ResourceHydrationOptions(
        # Copy so the options never alias the caller's collections or a shared default
        show=list(show),
        extracted=list(extracted),
        field_type_filter=list(field_type_filter),
    )
    text_block_hydration_options = TextBlockHydrationOptions(only_hydrate_empty=True)
    reranking_options = RerankingOptions(kbid=kbid, query=query)
    text_blocks, resources, best_matches = await hydrate_and_rerank(
        text_blocks,
        kbid,
        resource_hydration_options=resource_hydration_options,
        text_block_hydration_options=text_block_hydration_options,
        reranker=reranker,
        reranking_options=reranking_options,
        top_k=top_k,
    )

    find_resources = compose_find_resources(text_blocks, resources)

    return KnowledgeboxFindResults(
        query=query,
        resources=find_resources,
        best_matches=best_matches,
        relations=final_relations,
        # Counts the text blocks returned by hydrate_and_rerank
        total=len(text_blocks),
    )
845
+
846
+
847
def relations_match_to_text_block_match(
    paragraph_match: RelationsParagraphMatch,
) -> TextBlockMatch:
    """
    Convert a RelationsParagraphMatch into a TextBlockMatch with the bare minimum fields.

    The Graph Strategy needs this to obtain text blocks for the paragraphs
    its relations point to. The match's relations are carried over in
    `relevant_relations`.
    """
    # XXX: workaround. Retrieval is always assumed to be keyword/semantic search, and the
    # hydration and find response building code works with TextBlockMatch, so that type
    # was extended to carry the relevant relations information.
    pid = paragraph_match.paragraph_id

    # Only the character span comes from the paragraph id; the rest is placeholder data
    span = TextPosition(
        page_number=0,
        index=0,
        start=pid.paragraph_start,
        end=pid.paragraph_end,
        start_seconds=[],
        end_seconds=[],
    )

    return TextBlockMatch(
        paragraph_id=pid,
        score=paragraph_match.score,
        score_type=SCORE_TYPE.RELATION_RELEVANCE,
        # NOTE: order and text are placeholders that get filled later
        order=0,
        text="",
        position=span,
        field_labels=[],
        paragraph_labels=[],
        fuzzy_search=False,
        is_a_table=False,
        representation_file="",
        page_with_visual=False,
        relevant_relations=paragraph_match.relations,
    )
879
+
880
+
881
def relations_matches_to_text_block_matches(
    paragraph_matches: Collection[RelationsParagraphMatch],
) -> list[TextBlockMatch]:
    """
    Convert each RelationsParagraphMatch into a TextBlockMatch, preserving input order.
    """
    return list(map(relations_match_to_text_block_match, paragraph_matches))