nucliadb 6.2.1.post2835__py3-none-any.whl → 6.2.1.post2842__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- nucliadb/common/external_index_providers/base.py +2 -1
- nucliadb/common/ids.py +18 -4
- nucliadb/search/api/v1/suggest.py +0 -2
- nucliadb/search/search/chat/ask.py +35 -10
- nucliadb/search/search/chat/prompt.py +4 -2
- nucliadb/search/search/chat/query.py +56 -28
- nucliadb/search/search/graph_strategy.py +913 -0
- nucliadb/search/search/hydrator.py +6 -0
- nucliadb/search/search/merge.py +54 -22
- {nucliadb-6.2.1.post2835.dist-info → nucliadb-6.2.1.post2842.dist-info}/METADATA +5 -5
- {nucliadb-6.2.1.post2835.dist-info → nucliadb-6.2.1.post2842.dist-info}/RECORD +15 -14
- {nucliadb-6.2.1.post2835.dist-info → nucliadb-6.2.1.post2842.dist-info}/WHEEL +0 -0
- {nucliadb-6.2.1.post2835.dist-info → nucliadb-6.2.1.post2842.dist-info}/entry_points.txt +0 -0
- {nucliadb-6.2.1.post2835.dist-info → nucliadb-6.2.1.post2842.dist-info}/top_level.txt +0 -0
- {nucliadb-6.2.1.post2835.dist-info → nucliadb-6.2.1.post2842.dist-info}/zip-safe +0 -0
nucliadb/search/search/graph_strategy.py (new file)
@@ -0,0 +1,913 @@
# Copyright (C) 2021 Bosutech XXI S.L.
#
# nucliadb is offered under the AGPL v3.0 and as commercial software.
# For commercial licensing, contact us at info@nuclia.com.
#
# AGPL:
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import asyncio
import heapq
import json
from collections import defaultdict
from datetime import datetime
from typing import Any, Collection, Iterable, Optional, Union

from nuclia_models.predict.generative_responses import (
    JSONGenerativeResponse,
    MetaGenerativeResponse,
    StatusGenerativeResponse,
)
from pydantic import BaseModel
from sentry_sdk import capture_exception

from nucliadb.common.external_index_providers.base import TextBlockMatch
from nucliadb.common.ids import FieldId, ParagraphId
from nucliadb.search import logger
from nucliadb.search.requesters.utils import Method, node_query
from nucliadb.search.search.chat.query import (
    find_request_from_ask_request,
    get_relations_results_from_entities,
)
from nucliadb.search.search.find import query_parser_from_find_request
from nucliadb.search.search.find_merge import (
    compose_find_resources,
    hydrate_and_rerank,
)
from nucliadb.search.search.hydrator import ResourceHydrationOptions, TextBlockHydrationOptions
from nucliadb.search.search.merge import merge_suggest_results
from nucliadb.search.search.metrics import RAGMetrics
from nucliadb.search.search.query import QueryParser
from nucliadb.search.search.rerankers import Reranker, RerankingOptions
from nucliadb.search.utilities import get_predict
from nucliadb_models.common import FieldTypeName
from nucliadb_models.internal.predict import (
    RerankModel,
)
from nucliadb_models.resource import ExtractedDataTypeName
from nucliadb_models.search import (
    SCORE_TYPE,
    AskRequest,
    ChatModel,
    DirectionalRelation,
    EntitySubgraph,
    GraphStrategy,
    KnowledgeboxFindResults,
    KnowledgeboxSuggestResults,
    NucliaDBClientType,
    QueryEntityDetection,
    RelationDirection,
    RelationRanking,
    Relations,
    ResourceProperties,
    TextPosition,
    UserPrompt,
)
from nucliadb_protos import nodereader_pb2
from nucliadb_protos.utils_pb2 import RelationNode

SCHEMA = {
    "title": "score_triplets",
    "description": "Return a list of triplets and their relevance scores (0-10) for the supplied question.",
    "type": "object",
    "properties": {
        "triplets": {
            "type": "array",
            "description": "A list of triplets with their relevance scores.",
            "items": {
                "type": "object",
                "properties": {
                    "head_entity": {"type": "string", "description": "The first entity in the triplet."},
                    "relationship": {
                        "type": "string",
                        "description": "The relationship between the two entities.",
                    },
                    "tail_entity": {
                        "type": "string",
                        "description": "The second entity in the triplet.",
                    },
                    "score": {
                        "type": "integer",
                        "description": "A relevance score in the range 0 to 10.",
                        "minimum": 0,
                        "maximum": 10,
                    },
                },
                "required": ["head_entity", "relationship", "tail_entity", "score"],
            },
        }
    },
    "required": ["triplets"],
}

PROMPT = """\
You are an advanced language model assisting in scoring relationships (edges) between two entities in a knowledge graph, given a user’s question.

For each provided **(head_entity, relationship, tail_entity)**, you must:
1. Assign a **relevance score** between **0** and **10**.
2. **0** means “this relationship can’t be relevant at all to the question.”
3. **10** means “this relationship is extremely relevant to the question.”
4. You may use **any integer** between 0 and 10 (e.g., 3, 7, etc.) based on how relevant you deem the relationship to be.
5. **Language Agnosticism**: The question and the relationships may be in different languages. The relevance scoring should still work and be agnostic of the language.
6. Relationships that may not answer the question directly but expand knowledge in a relevant way should also be scored positively.

Once you have decided the best score for each triplet, return these results **using a function call** in JSON format with the following rules:

- The function name should be `score_triplets`.
- The first argument should be the list of triplets.
- Each triplet should have the following keys:
    - `head_entity`: The first entity in the triplet.
    - `relationship`: The relationship between the two entities.
    - `tail_entity`: The second entity in the triplet.
    - `score`: The relevance score in the range 0 to 10.

You **must** comply with the provided JSON Schema to ensure a well-structured response and maintain the order of the triplets.


## Examples:

### Example 1:

**Input**

{
    "question": "Who is the mayor of the capital city of Australia?",
    "triplets": [
        {
            "head_entity": "Australia",
            "relationship": "has prime minister",
            "tail_entity": "Scott Morrison"
        },
        {
            "head_entity": "Canberra",
            "relationship": "is capital of",
            "tail_entity": "Australia"
        },
        {
            "head_entity": "Scott Knowles",
            "relationship": "holds position",
            "tail_entity": "Mayor"
        },
        {
            "head_entity": "Barbera Smith",
            "relationship": "tiene cargo",
            "tail_entity": "Alcalde"
        },
        {
            "head_entity": "Austria",
            "relationship": "has capital",
            "tail_entity": "Vienna"
        }
    ]
}

**Output**

{
    "triplets": [
        {
            "head_entity": "Australia",
            "relationship": "has prime minister",
            "tail_entity": "Scott Morrison",
            "score": 4
        },
        {
            "head_entity": "Canberra",
            "relationship": "is capital of",
            "tail_entity": "Australia",
            "score": 8
        },
        {
            "head_entity": "Scott Knowles",
            "relationship": "holds position",
            "tail_entity": "Mayor",
            "score": 8
        },
        {
            "head_entity": "Barbera Smith",
            "relationship": "tiene cargo",
            "tail_entity": "Alcalde",
            "score": 8
        },
        {
            "head_entity": "Austria",
            "relationship": "has capital",
            "tail_entity": "Vienna",
            "score": 0
        }
    ]
}



### Example 2:

**Input**

{
    "question": "How many products does John Adams Roofing Inc. offer?",
    "triplets": [
        {
            "head_entity": "John Adams Roofing Inc.",
            "relationship": "has product",
            "tail_entity": "Titanium Grade 3 Roofing Nails"
        },
        {
            "head_entity": "John Adams Roofing Inc.",
            "relationship": "is located in",
            "tail_entity": "New York"
        },
        {
            "head_entity": "John Adams Roofing Inc.",
            "relationship": "was founded by",
            "tail_entity": "John Adams"
        },
        {
            "head_entity": "John Adams Roofing Inc.",
            "relationship": "tiene stock",
            "tail_entity": "Baldosas solares"
        },
        {
            "head_entity": "John Adams Roofing Inc.",
            "relationship": "has product",
            "tail_entity": "Mercerized Cotton Thread"
        }
    ]
}

**Output**

{
    "triplets": [
        {
            "head_entity": "John Adams Roofing Inc.",
            "relationship": "has product",
            "tail_entity": "Titanium Grade 3 Roofing Nails",
            "score": 10
        },
        {
            "head_entity": "John Adams Roofing Inc.",
            "relationship": "is located in",
            "tail_entity": "New York",
            "score": 6
        },
        {
            "head_entity": "John Adams Roofing Inc.",
            "relationship": "was founded by",
            "tail_entity": "John Adams",
            "score": 5
        },
        {
            "head_entity": "John Adams Roofing Inc.",
            "relationship": "tiene stock",
            "tail_entity": "Baldosas solares",
            "score": 10
        },
        {
            "head_entity": "John Adams Roofing Inc.",
            "relationship": "has product",
            "tail_entity": "Mercerized Cotton Thread",
            "score": 10
        }
    ]
}

Now, let's get started! Here are the triplets you need to score:

**Input**

"""


class RelationsParagraphMatch(BaseModel):
    paragraph_id: ParagraphId
    score: float
    relations: Relations


async def get_graph_results(
    *,
    kbid: str,
    query: str,
    item: AskRequest,
    ndb_client: NucliaDBClientType,
    user: str,
    origin: str,
    graph_strategy: GraphStrategy,
    generative_model: Optional[str] = None,
    metrics: RAGMetrics = RAGMetrics(),
    shards: Optional[list[str]] = None,
) -> tuple[KnowledgeboxFindResults, QueryParser]:
    relations = Relations(entities={})
    explored_entities: set[str] = set()
    predict = get_predict()

    for hop in range(graph_strategy.hops):
        entities_to_explore: Iterable[RelationNode] = []
        scores: dict[str, list[float]] = {}
        if hop == 0:
            # Get the entities from the query
            with metrics.time("graph_strat_query_entities"):
                if graph_strategy.query_entity_detection == QueryEntityDetection.SUGGEST:
                    suggest_result = await fuzzy_search_entities(
                        kbid=kbid,
                        query=query,
                        range_creation_start=item.range_creation_start,
                        range_creation_end=item.range_creation_end,
                        range_modification_start=item.range_modification_start,
                        range_modification_end=item.range_modification_end,
                        target_shard_replicas=shards,
                    )
                    if suggest_result.entities is not None:
                        entities_to_explore = (
                            RelationNode(
                                ntype=RelationNode.NodeType.ENTITY,
                                value=result.value,
                                subtype=result.family,
                            )
                            for result in suggest_result.entities.entities
                        )
                elif (
                    not entities_to_explore
                    or graph_strategy.query_entity_detection == QueryEntityDetection.PREDICT
                ):
                    try:
                        entities_to_explore = await predict.detect_entities(kbid, query)
                    except Exception as e:
                        capture_exception(e)
                        logger.exception("Error in detecting entities for graph strategy")
                        entities_to_explore = []
        else:
            # Find neighbors of the current relations and remove the ones already explored
            entities_to_explore = (
                RelationNode(
                    ntype=RelationNode.NodeType.ENTITY,
                    value=relation.entity,
                    subtype=relation.entity_subtype,
                )
                for subgraph in relations.entities.values()
                for relation in subgraph.related_to
                if relation.entity not in explored_entities
            )
        # Get the relations for the new entities
        with metrics.time("graph_strat_neighbor_relations"):
            try:
                new_relations = await get_relations_results_from_entities(
                    kbid=kbid,
                    entities=entities_to_explore,
                    target_shard_replicas=shards,
                    timeout=5.0,
                    only_with_metadata=True,
                    only_agentic_relations=graph_strategy.agentic_graph_only,
                )
            except Exception as e:
                capture_exception(e)
                logger.exception("Error in getting query relations for graph strategy")
                new_relations = Relations(entities={})

            # Removing the relations connected to the entities that were already explored
            # XXX: This could be optimized by implementing a filter in the index
            # so we don't have to remove them after
            new_subgraphs = {
                entity: filter_subgraph(subgraph, explored_entities)
                for entity, subgraph in new_relations.entities.items()
            }

            explored_entities.update(new_subgraphs.keys())

            if not new_subgraphs or all(not subgraph.related_to for subgraph in new_subgraphs.values()):
                break

            relations.entities.update(new_subgraphs)

        # Rank the relevance of the relations
        with metrics.time("graph_strat_rank_relations"):
            try:
                if graph_strategy.relation_ranking == RelationRanking.RERANKER:
                    relations, scores = await rank_relations_reranker(
                        relations,
                        query,
                        kbid,
                        user,
                        top_k=graph_strategy.top_k,
                    )
                elif graph_strategy.relation_ranking == RelationRanking.GENERATIVE:
                    relations, scores = await rank_relations_generative(
                        relations,
                        query,
                        kbid,
                        user,
                        top_k=graph_strategy.top_k,
                        generative_model=generative_model,
                    )
            except Exception as e:
                capture_exception(e)
                logger.exception("Error in ranking relations for graph strategy")
                relations = Relations(entities={})
                break

    # Get the text blocks of the paragraphs that contain the top relations
    with metrics.time("graph_strat_build_response"):
        find_request = find_request_from_ask_request(item, query)
        query_parser, rank_fusion, reranker = await query_parser_from_find_request(
            kbid, find_request, generative_model=generative_model
        )
        find_results = await build_graph_response(
            kbid=kbid,
            query=query,
            final_relations=relations,
            scores=scores,
            top_k=graph_strategy.top_k,
            reranker=reranker,
            show=find_request.show,
            extracted=find_request.extracted,
            field_type_filter=find_request.field_type_filter,
            relation_text_as_paragraphs=graph_strategy.relation_text_as_paragraphs,
        )
    return find_results, query_parser


async def fuzzy_search_entities(
    kbid: str,
    query: str,
    range_creation_start: Optional[datetime] = None,
    range_creation_end: Optional[datetime] = None,
    range_modification_start: Optional[datetime] = None,
    range_modification_end: Optional[datetime] = None,
    target_shard_replicas: Optional[list[str]] = None,
) -> KnowledgeboxSuggestResults:
    """Fuzzy find entities in KB given a query using the same methodology as /suggest, but split by words."""

    base_request = nodereader_pb2.SuggestRequest(
        body="", features=[nodereader_pb2.SuggestFeatures.ENTITIES]
    )
    if range_creation_start is not None:
        base_request.timestamps.from_created.FromDatetime(range_creation_start)
    if range_creation_end is not None:
        base_request.timestamps.to_created.FromDatetime(range_creation_end)
    if range_modification_start is not None:
        base_request.timestamps.from_modified.FromDatetime(range_modification_start)
    if range_modification_end is not None:
        base_request.timestamps.to_modified.FromDatetime(range_modification_end)

    tasks = []
    # XXX: Splitting by words is not ideal, in the future, modify suggest to better handle this
    for word in query.split():
        if len(word) < 3:
            continue
        request = nodereader_pb2.SuggestRequest()
        request.CopyFrom(base_request)
        request.body = word
        tasks.append(
            node_query(kbid, Method.SUGGEST, request, target_shard_replicas=target_shard_replicas)
        )

    try:
        results_raw = await asyncio.gather(*tasks)
        return await merge_suggest_results(
            [item for r in results_raw for item in r[0]],
            kbid=kbid,
        )
    except Exception as e:
        capture_exception(e)
        logger.exception("Error in finding entities in query for graph strategy")
        return KnowledgeboxSuggestResults(entities=None)


async def rank_relations_reranker(
    relations: Relations,
    query: str,
    kbid: str,
    user: str,
    top_k: int,
    score_threshold: float = 0.02,
) -> tuple[Relations, dict[str, list[float]]]:
    # Store the index for keeping track after scoring
    flat_rels: list[tuple[str, int, DirectionalRelation]] = [
        (ent, idx, rel)
        for (ent, rels) in relations.entities.items()
        for (idx, rel) in enumerate(rels.related_to)
    ]
    # Build triplets (dict) from each relation for use in reranker
    triplets: list[dict[str, str]] = [
        {
            "head_entity": ent,
            "relationship": rel.relation_label,
            "tail_entity": rel.entity,
        }
        if rel.direction == RelationDirection.OUT
        else {
            "head_entity": rel.entity,
            "relationship": rel.relation_label,
            "tail_entity": ent,
        }
        for (ent, _, rel) in flat_rels
    ]

    # Dedupe triplets so that they get evaluated once; map triplet -> [orig_indices]
    triplet_to_orig_indices: dict[tuple[str, str, str], list[int]] = {}
    unique_triplets: list[dict[str, str]] = []

    for i, t in enumerate(triplets):
        key = (t["head_entity"], t["relationship"], t["tail_entity"])
        if key not in triplet_to_orig_indices:
            triplet_to_orig_indices[key] = []
            unique_triplets.append(t)
        triplet_to_orig_indices[key].append(i)

    # Build the reranker model input
    predict = get_predict()
    rerank_model = RerankModel(
        question=query,
        user_id=user,
        context={
            str(idx): f"{t['head_entity']} {t['relationship']} {t['tail_entity']}"
            for idx, t in enumerate(unique_triplets)
        },
    )
    # Get the rerank scores
    res = await predict.rerank(kbid, rerank_model)

    # Convert returned scores to a list of (int_idx, score)
    # where int_idx corresponds to indices in unique_triplets
    reranked_indices_scores = [(int(idx), score) for idx, score in res.context_scores.items()]

    return _scores_to_ranked_rels(
        unique_triplets,
        reranked_indices_scores,
        triplet_to_orig_indices,
        flat_rels,
        top_k,
        score_threshold,
    )


async def rank_relations_generative(
    relations: Relations,
    query: str,
    kbid: str,
    user: str,
    top_k: int,
    generative_model: Optional[str] = None,
    score_threshold: float = 2,
    max_rels_to_eval: int = 100,
) -> tuple[Relations, dict[str, list[float]]]:
    # Store the index for keeping track after scoring
    flat_rels: list[tuple[str, int, DirectionalRelation]] = [
        (ent, idx, rel)
        for (ent, rels) in relations.entities.items()
        for (idx, rel) in enumerate(rels.related_to)
    ]
    triplets: list[dict[str, str]] = [
        {
            "head_entity": ent,
            "relationship": rel.relation_label,
            "tail_entity": rel.entity,
        }
        if rel.direction == RelationDirection.OUT
        else {
            "head_entity": rel.entity,
            "relationship": rel.relation_label,
            "tail_entity": ent,
        }
        for (ent, _, rel) in flat_rels
    ]

    # Dedupe triplets so that they get evaluated once, we will re-associate the scores later
    triplet_to_orig_indices: dict[tuple[str, str, str], list[int]] = {}
    unique_triplets: list[dict[str, str]] = []

    for i, t in enumerate(triplets):
        key = (t["head_entity"], t["relationship"], t["tail_entity"])
        if key not in triplet_to_orig_indices:
            triplet_to_orig_indices[key] = []
            unique_triplets.append(t)
        triplet_to_orig_indices[key].append(i)

    if len(flat_rels) > max_rels_to_eval:
        logger.warning(f"Too many relations to evaluate ({len(flat_rels)}), using reranker to reduce")
        return await rank_relations_reranker(relations, query, kbid, user, top_k=max_rels_to_eval)

    data = {
        "question": query,
        "triplets": unique_triplets,
    }
    prompt = PROMPT + json.dumps(data, indent=4)

    predict = get_predict()
    chat_model = ChatModel(
        question=prompt,
        user_id=user,
        json_schema=SCHEMA,
        format_prompt=False,  # We supply our own prompt
        query_context_order={},
        query_context={},
        user_prompt=UserPrompt(prompt=prompt),
        max_tokens=4096,
        generative_model=generative_model,
    )

    ident, model, answer_stream = await predict.chat_query_ndjson(kbid, chat_model)
    response_json = None
    status = None
    _ = None

    async for generative_chunk in answer_stream:
        item = generative_chunk.chunk
        if isinstance(item, JSONGenerativeResponse):
            response_json = item
        elif isinstance(item, StatusGenerativeResponse):
            status = item
        elif isinstance(item, MetaGenerativeResponse):
            _ = item
        else:
            raise ValueError(f"Unknown generative chunk type: {item}")

    if response_json is None or status is None or status.code != "0":
        raise ValueError("No JSON response found")

    scored_unique_triplets: list[dict[str, Union[str, Any]]] = response_json.object["triplets"]

    if len(scored_unique_triplets) != len(unique_triplets):
        raise ValueError("Mismatch between input and output triplets")

    unique_indices_scores = ((idx, float(t["score"])) for (idx, t) in enumerate(scored_unique_triplets))

    return _scores_to_ranked_rels(
        unique_triplets,
        unique_indices_scores,
        triplet_to_orig_indices,
        flat_rels,
        top_k,
        score_threshold,
    )


def _scores_to_ranked_rels(
    unique_triplets: list[dict[str, str]],
    unique_indices_scores: Iterable[tuple[int, float]],
    triplet_to_orig_indices: dict[tuple[str, str, str], list[int]],
    flat_rels: list[tuple[str, int, DirectionalRelation]],
    top_k: int,
    score_threshold: float,
) -> tuple[Relations, dict[str, list[float]]]:
    """
    Helper function to convert unique scores assigned by a model back to the original relations while taking
    care of threshold
    """
    top_k_indices_scores = heapq.nlargest(top_k, unique_indices_scores, key=lambda x: x[1])

    # Prepare a new Relations object + a dict of top scores by entity
    top_k_rels: dict[str, EntitySubgraph] = defaultdict(lambda: EntitySubgraph(related_to=[]))
    top_k_scores_by_ent: dict[str, list[float]] = defaultdict(list)
    # Re-expand model scores to the original triplets
    for idx, score in top_k_indices_scores:
        # If the model's score is below threshold, skip
        if score <= score_threshold:
            continue

        # Identify which original triplets (in flat_rels) this corresponds to
        t = unique_triplets[idx]
        key = (t["head_entity"], t["relationship"], t["tail_entity"])
        orig_indices = triplet_to_orig_indices[key]

        for orig_i in orig_indices:
            ent, rel_idx, rel = flat_rels[orig_i]
            # Insert the relation into top_k_rels
            top_k_rels[ent].related_to.append(rel)

            # Keep track of which indices were chosen per entity
            top_k_scores_by_ent[ent].append(score)

    return Relations(entities=top_k_rels), dict(top_k_scores_by_ent)


def build_text_blocks_from_relations(
    relations: Relations,
    scores: dict[str, list[float]],
) -> list[TextBlockMatch]:
    """
    The goal of this function is to generate TextBlockMatch with custom text for each unique relation in the graph.

    This is a hacky way to generate paragraphs from relations, and it is not the intended use of TextBlockMatch.
    """
    # Build a set of unique triplets with their scores
    triplets: dict[tuple[str, str, str], tuple[float, Relations, Optional[ParagraphId]]] = defaultdict(
        lambda: (0.0, Relations(entities={}), None)
    )
    for ent, subgraph in relations.entities.items():
        for rel, score in zip(subgraph.related_to, scores[ent]):
            key = (
                (
                    ent,
                    rel.relation_label,
                    rel.entity,
                )
                if rel.direction == RelationDirection.OUT
                else (rel.entity, rel.relation_label, ent)
            )
            existing_score, existing_relations, p_id = triplets[key]
            if ent not in existing_relations.entities:
                existing_relations.entities[ent] = EntitySubgraph(related_to=[])

            # XXX: Since relations with the same triplet can point to different paragraphs,
            # we keep the first one, but we lose the other ones
            if p_id is None and rel.metadata and rel.metadata.paragraph_id:
                p_id = ParagraphId.from_string(rel.metadata.paragraph_id)
            existing_relations.entities[ent].related_to.append(rel)
            # XXX: Here we use the max even though all relations with same triplet should have same score
            triplets[key] = (max(existing_score, score), existing_relations, p_id)

    # Build the text blocks
    text_blocks = [
        TextBlockMatch(
            # XXX: Even though we are setting a paragraph_id, the text is not coming from the paragraph
            paragraph_id=p_id,
            score=score,
            score_type=SCORE_TYPE.RELATION_RELEVANCE,
            order=0,
            text=f"- {ent} {rel} {tail}",  # Manually build the text
            position=TextPosition(
                page_number=0,
                index=0,
                start=0,
                end=0,
                start_seconds=[],
                end_seconds=[],
            ),
            field_labels=[],
            paragraph_labels=[],
            fuzzy_search=False,
            is_a_table=False,
            representation_file="",
            page_with_visual=False,
            relevant_relations=relations,
        )
        for (ent, rel, tail), (score, relations, p_id) in triplets.items()
        if p_id is not None
    ]
    return text_blocks


def get_paragraph_info_from_relations(
    relations: Relations,
    scores: dict[str, list[float]],
) -> list[RelationsParagraphMatch]:
    """
    Gathers paragraph info from the 'relations' object, merges relations by paragraph,
    and removes paragraphs contained entirely within others.
    """

    # Group paragraphs by field so we can detect containment
    paragraphs_by_field: dict[FieldId, list[RelationsParagraphMatch]] = defaultdict(list)

    # Loop over each entity in the relation graph
    for ent, subgraph in relations.entities.items():
        for rel_score, rel in zip(scores[ent], subgraph.related_to):
            if rel.metadata and rel.metadata.paragraph_id:
                p_id = ParagraphId.from_string(rel.metadata.paragraph_id)
                match = RelationsParagraphMatch(
                    paragraph_id=p_id,
                    score=rel_score,
                    relations=Relations(entities={ent: EntitySubgraph(related_to=[rel])}),
                )
                paragraphs_by_field[p_id.field_id].append(match)

    # For each field, sort paragraphs by start asc, end desc, and do one pass to remove contained ones
    final_paragraphs: list[RelationsParagraphMatch] = []

    for _, paragraph_list in paragraphs_by_field.items():
        # Sort by paragraph_start ascending; if tie, paragraph_end descending
        paragraph_list.sort(
            key=lambda m: (m.paragraph_id.paragraph_start, -m.paragraph_id.paragraph_end)
        )

        kept: list[RelationsParagraphMatch] = []
        current_max_end = -1

        for match in paragraph_list:
            end = match.paragraph_id.paragraph_end

            # If end <= current_max_end, this paragraph is contained in the last one
            if end <= current_max_end:
                # We merge the scores and relations
                container = kept[-1]
                container.score = max(container.score, match.score)
                for ent, subgraph in match.relations.entities.items():
                    if ent not in container.relations.entities:
                        container.relations.entities[ent] = EntitySubgraph(related_to=[])
                    container.relations.entities[ent].related_to.extend(subgraph.related_to)

            else:
                # Not contained; keep it
                kept.append(match)
                current_max_end = end
        final_paragraphs.extend(kept)

    return final_paragraphs


async def build_graph_response(
    *,
    kbid: str,
    query: str,
    final_relations: Relations,
    scores: dict[str, list[float]],
    top_k: int,
    reranker: Reranker,
    relation_text_as_paragraphs: bool,
    show: list[ResourceProperties] = [],
    extracted: list[ExtractedDataTypeName] = [],
    field_type_filter: list[FieldTypeName] = [],
) -> KnowledgeboxFindResults:
    if relation_text_as_paragraphs:
        text_blocks = build_text_blocks_from_relations(final_relations, scores)
    else:
        paragraphs_info = get_paragraph_info_from_relations(final_relations, scores)
        text_blocks = relations_matches_to_text_block_matches(paragraphs_info)

    # hydrate and rerank
    resource_hydration_options = ResourceHydrationOptions(
        show=show, extracted=extracted, field_type_filter=field_type_filter
    )
    text_block_hydration_options = TextBlockHydrationOptions(only_hydrate_empty=True)
    reranking_options = RerankingOptions(kbid=kbid, query=query)
    text_blocks, resources, best_matches = await hydrate_and_rerank(
        text_blocks,
        kbid,
        resource_hydration_options=resource_hydration_options,
        text_block_hydration_options=text_block_hydration_options,
        reranker=reranker,
        reranking_options=reranking_options,
        top_k=top_k,
    )

    find_resources = compose_find_resources(text_blocks, resources)

    return KnowledgeboxFindResults(
        query=query,
        resources=find_resources,
        best_matches=best_matches,
        relations=final_relations,
        total=len(text_blocks),
    )


def filter_subgraph(subgraph: EntitySubgraph, entities_to_remove: Collection[str]) -> EntitySubgraph:
    """
    Removes the relationships with entities in `entities_to_remove` from the subgraph.
    """
    return EntitySubgraph(
        related_to=[rel for rel in subgraph.related_to if rel.entity not in entities_to_remove]
    )


def relations_match_to_text_block_match(
    paragraph_match: RelationsParagraphMatch,
) -> TextBlockMatch:
    """
    Given a paragraph_id, return a TextBlockMatch with the bare minimum fields
    This is required by the Graph Strategy to get text blocks from the relevant paragraphs
    """
    # XXX: this is a workaround for the fact we always assume retrieval means keyword/semantic search and
    # the hydration and find response building code works with TextBlockMatch, we extended it to have relevant relations information
    parsed_paragraph_id = paragraph_match.paragraph_id
    return TextBlockMatch(
        paragraph_id=parsed_paragraph_id,
        score=paragraph_match.score,
        score_type=SCORE_TYPE.RELATION_RELEVANCE,
        order=0,  # NOTE: this will be filled later
        text="",  # NOTE: this will be filled later too
        position=TextPosition(
            page_number=0,
            index=0,
            start=parsed_paragraph_id.paragraph_start,
            end=parsed_paragraph_id.paragraph_end,
            start_seconds=[],
            end_seconds=[],
        ),
        field_labels=[],
        paragraph_labels=[],
        fuzzy_search=False,
        is_a_table=False,
        representation_file="",
        page_with_visual=False,
        relevant_relations=paragraph_match.relations,
    )


def relations_matches_to_text_block_matches(
    paragraph_matches: Collection[RelationsParagraphMatch],
) -> list[TextBlockMatch]:
    return [relations_match_to_text_block_match(match) for match in paragraph_matches]
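
Both ranking paths in the new module funnel into `_scores_to_ranked_rels`, whose dedupe-then-fan-out pattern is the interesting part: duplicate triplets are scored once, the top-k survivors above the threshold are kept, and each surviving score is propagated back to every original occurrence. The sketch below is illustrative only; it mirrors that pattern with plain tuples and dicts instead of nucliadb's `Relations`/`EntitySubgraph` models, and all names in it are hypothetical, not nucliadb APIs.

```python
# Minimal sketch of the dedupe/score/fan-out pattern used by _scores_to_ranked_rels.
# Assumes scores are supplied directly; the real code gets them from a reranker
# or a generative model.
import heapq
from collections import defaultdict

Triplet = tuple[str, str, str]  # (head_entity, relationship, tail_entity)


def rank_triplets(
    triplets: list[Triplet],
    scores_for_unique: dict[Triplet, float],
    top_k: int,
    score_threshold: float,
) -> dict[str, list[tuple[Triplet, float]]]:
    # 1) Dedupe while remembering where each duplicate came from.
    key_to_orig_indices: dict[Triplet, list[int]] = {}
    unique: list[Triplet] = []
    for i, t in enumerate(triplets):
        if t not in key_to_orig_indices:
            key_to_orig_indices[t] = []
            unique.append(t)
        key_to_orig_indices[t].append(i)

    # 2) Score each unique triplet exactly once.
    indices_scores = [(idx, scores_for_unique[t]) for idx, t in enumerate(unique)]

    # 3) Keep the top_k unique triplets by score.
    top = heapq.nlargest(top_k, indices_scores, key=lambda x: x[1])

    # 4) Fan scores back out to every original occurrence, applying the threshold.
    by_head: dict[str, list[tuple[Triplet, float]]] = defaultdict(list)
    for idx, score in top:
        if score <= score_threshold:
            continue
        for orig_i in key_to_orig_indices[unique[idx]]:
            head = triplets[orig_i][0]
            by_head[head].append((triplets[orig_i], score))
    return dict(by_head)


if __name__ == "__main__":
    trips = [
        ("Canberra", "is capital of", "Australia"),
        ("Canberra", "is capital of", "Australia"),  # duplicate occurrence
        ("Austria", "has capital", "Vienna"),
    ]
    scores = {
        ("Canberra", "is capital of", "Australia"): 8.0,
        ("Austria", "has capital", "Vienna"): 0.0,
    }
    print(rank_triplets(trips, scores, top_k=2, score_threshold=2.0))
```

Running the sketch returns both occurrences of the Canberra triplet with score 8.0 and drops the Vienna triplet at the threshold, mirroring how duplicate relations across paragraphs all inherit the single model score in the real code while only one entry per triplet is sent to the model.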