nucliadb 6.2.1.post2838__py3-none-any.whl → 6.2.1.post2842__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nucliadb/common/external_index_providers/base.py +2 -1
- nucliadb/common/ids.py +18 -4
- nucliadb/search/api/v1/suggest.py +0 -2
- nucliadb/search/search/chat/ask.py +35 -10
- nucliadb/search/search/chat/prompt.py +4 -2
- nucliadb/search/search/chat/query.py +56 -28
- nucliadb/search/search/graph_strategy.py +913 -0
- nucliadb/search/search/hydrator.py +6 -0
- nucliadb/search/search/merge.py +54 -22
- {nucliadb-6.2.1.post2838.dist-info → nucliadb-6.2.1.post2842.dist-info}/METADATA +5 -5
- {nucliadb-6.2.1.post2838.dist-info → nucliadb-6.2.1.post2842.dist-info}/RECORD +15 -14
- {nucliadb-6.2.1.post2838.dist-info → nucliadb-6.2.1.post2842.dist-info}/WHEEL +0 -0
- {nucliadb-6.2.1.post2838.dist-info → nucliadb-6.2.1.post2842.dist-info}/entry_points.txt +0 -0
- {nucliadb-6.2.1.post2838.dist-info → nucliadb-6.2.1.post2842.dist-info}/top_level.txt +0 -0
- {nucliadb-6.2.1.post2838.dist-info → nucliadb-6.2.1.post2842.dist-info}/zip-safe +0 -0
@@ -28,7 +28,7 @@ from nucliadb.common.counters import IndexCounts
|
|
28
28
|
from nucliadb.common.external_index_providers.exceptions import ExternalIndexingError
|
29
29
|
from nucliadb.common.ids import ParagraphId
|
30
30
|
from nucliadb_models.external_index_providers import ExternalIndexProviderType
|
31
|
-
from nucliadb_models.search import SCORE_TYPE, TextPosition
|
31
|
+
from nucliadb_models.search import SCORE_TYPE, Relations, TextPosition
|
32
32
|
from nucliadb_protos.knowledgebox_pb2 import (
|
33
33
|
CreateExternalIndexProviderMetadata,
|
34
34
|
StoredExternalIndexProviderMetadata,
|
@@ -73,6 +73,7 @@ class TextBlockMatch(BaseModel):
|
|
73
73
|
paragraph_labels: list[str] = []
|
74
74
|
field_labels: list[str] = []
|
75
75
|
text: Optional[str] = None
|
76
|
+
relevant_relations: Optional[Relations] = None
|
76
77
|
|
77
78
|
|
78
79
|
class QueryResults(BaseModel):
|
nucliadb/common/ids.py
CHANGED
@@ -111,13 +111,11 @@ class FieldId:
|
|
111
111
|
parts = value.split("/")
|
112
112
|
if len(parts) == 3:
|
113
113
|
rid, _type, key = parts
|
114
|
-
|
115
|
-
raise ValueError(f"Invalid FieldId: {value}")
|
114
|
+
_type = cls.parse_field_type(_type)
|
116
115
|
return cls(rid=rid, type=_type, key=key)
|
117
116
|
elif len(parts) == 4:
|
118
117
|
rid, _type, key, subfield_id = parts
|
119
|
-
|
120
|
-
raise ValueError(f"Invalid FieldId: {value}")
|
118
|
+
_type = cls.parse_field_type(_type)
|
121
119
|
return cls(
|
122
120
|
rid=rid,
|
123
121
|
type=_type,
|
@@ -127,6 +125,22 @@ class FieldId:
|
|
127
125
|
else:
|
128
126
|
raise ValueError(f"Invalid FieldId: {value}")
|
129
127
|
|
128
|
+
@classmethod
def parse_field_type(cls, _type: str) -> str:
    """Normalise the field-type token of a field id string.

    ``_type`` is returned unchanged when it is already a key of
    ``FIELD_TYPE_STR_TO_PB``. Otherwise it is parsed as the integer value of a
    ``FieldType`` enum member and translated to its string form through
    ``FIELD_TYPE_PB_TO_STR``.

    Args:
        _type: the type segment taken from a ``rid/type/key[/subfield]`` id.

    Returns:
        The string form of the field type.

    Raises:
        ValueError: if ``_type`` is neither a known string type nor an integer
            present in ``FIELD_TYPE_PB_TO_STR``.
    """
    if _type in FIELD_TYPE_STR_TO_PB:
        return _type

    # XXX: This is to support field types that are integer values of FieldType,
    # which is how legacy processor relations reported the paragraph_id.
    try:
        type_pb = FieldType.ValueType(int(_type))
    except ValueError as exc:
        # Chain explicitly so the traceback shows this error as a translation of
        # the int() failure, not as a second error raised while handling it.
        raise ValueError(f"Invalid FieldId: {_type}") from exc

    if type_pb not in FIELD_TYPE_PB_TO_STR:
        raise ValueError(f"Invalid FieldId: {_type}")
    return FIELD_TYPE_PB_TO_STR[type_pb]
|
143
|
+
|
130
144
|
|
131
145
|
@dataclass
|
132
146
|
class ParagraphId:
|
@@ -57,6 +57,7 @@ from nucliadb.search.search.exceptions import (
|
|
57
57
|
IncompleteFindResultsError,
|
58
58
|
InvalidQueryError,
|
59
59
|
)
|
60
|
+
from nucliadb.search.search.graph_strategy import get_graph_results
|
60
61
|
from nucliadb.search.search.metrics import RAGMetrics
|
61
62
|
from nucliadb.search.search.query import QueryParser
|
62
63
|
from nucliadb.search.utilities import get_predict
|
@@ -75,6 +76,7 @@ from nucliadb_models.search import (
|
|
75
76
|
ErrorAskResponseItem,
|
76
77
|
FindParagraph,
|
77
78
|
FindRequest,
|
79
|
+
GraphStrategy,
|
78
80
|
JSONAskResponseItem,
|
79
81
|
KnowledgeboxFindResults,
|
80
82
|
MetadataAskResponseItem,
|
@@ -629,6 +631,13 @@ def parse_prequeries(ask_request: AskRequest) -> Optional[PreQueriesStrategy]:
|
|
629
631
|
return None
|
630
632
|
|
631
633
|
|
634
|
+
def parse_graph_strategy(ask_request: AskRequest) -> Optional[GraphStrategy]:
    """Pick the GRAPH RAG strategy out of an ask request.

    Scans ``ask_request.rag_strategies`` in order and returns the first entry
    whose ``name`` is ``RagStrategyName.GRAPH``, typed as ``GraphStrategy``.
    Returns ``None`` when the request carries no graph strategy.
    """
    graph_candidates = (
        cast(GraphStrategy, candidate)
        for candidate in ask_request.rag_strategies
        if candidate.name == RagStrategyName.GRAPH
    )
    # next() stops at the first match, just like an early return in a loop.
    return next(graph_candidates, None)
|
639
|
+
|
640
|
+
|
632
641
|
async def retrieval_step(
|
633
642
|
kbid: str,
|
634
643
|
main_query: str,
|
@@ -675,17 +684,33 @@ async def retrieval_in_kb(
|
|
675
684
|
metrics: RAGMetrics,
|
676
685
|
) -> RetrievalResults:
|
677
686
|
prequeries = parse_prequeries(ask_request)
|
687
|
+
graph_strategy = parse_graph_strategy(ask_request)
|
678
688
|
with metrics.time("retrieval"):
|
679
|
-
|
680
|
-
|
681
|
-
|
682
|
-
|
683
|
-
|
684
|
-
|
685
|
-
|
686
|
-
|
687
|
-
|
688
|
-
|
689
|
+
prequeries_results = None
|
690
|
+
if graph_strategy is not None:
|
691
|
+
main_results, query_parser = await get_graph_results(
|
692
|
+
kbid=kbid,
|
693
|
+
query=main_query,
|
694
|
+
item=ask_request,
|
695
|
+
ndb_client=client_type,
|
696
|
+
user=user_id,
|
697
|
+
origin=origin,
|
698
|
+
graph_strategy=graph_strategy,
|
699
|
+
metrics=metrics,
|
700
|
+
shards=ask_request.shards,
|
701
|
+
)
|
702
|
+
# TODO (oni): Fallback to normal retrieval if no graph results are found
|
703
|
+
else:
|
704
|
+
main_results, prequeries_results, query_parser = await get_find_results(
|
705
|
+
kbid=kbid,
|
706
|
+
query=main_query,
|
707
|
+
item=ask_request,
|
708
|
+
ndb_client=client_type,
|
709
|
+
user=user_id,
|
710
|
+
origin=origin,
|
711
|
+
metrics=metrics,
|
712
|
+
prequeries_strategy=prequeries,
|
713
|
+
)
|
689
714
|
if len(main_results.resources) == 0 and all(
|
690
715
|
len(prequery_result.resources) == 0 for (_, prequery_result) in prequeries_results or []
|
691
716
|
):
|
@@ -1013,8 +1013,10 @@ class PromptContextBuilder:
|
|
1013
1013
|
neighbouring_paragraphs = cast(NeighbouringParagraphsStrategy, strategy)
|
1014
1014
|
elif strategy.name == RagStrategyName.METADATA_EXTENSION:
|
1015
1015
|
metadata_extension = cast(MetadataExtensionStrategy, strategy)
|
1016
|
-
elif
|
1017
|
-
|
1016
|
+
elif (
|
1017
|
+
strategy.name != RagStrategyName.PREQUERIES and strategy.name != RagStrategyName.GRAPH
|
1018
|
+
): # pragma: no cover
|
1019
|
+
# Prequeries and graph are not handled here
|
1018
1020
|
logger.warning(
|
1019
1021
|
"Unknown rag strategy",
|
1020
1022
|
extra={"strategy": strategy.name, "kbid": self.kbid},
|
@@ -18,7 +18,7 @@
|
|
18
18
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
19
19
|
#
|
20
20
|
import asyncio
|
21
|
-
from typing import Optional
|
21
|
+
from typing import Iterable, Optional
|
22
22
|
|
23
23
|
from nucliadb.common.models_utils import to_proto
|
24
24
|
from nucliadb.search import logger
|
@@ -51,6 +51,7 @@ from nucliadb_models.search import (
|
|
51
51
|
)
|
52
52
|
from nucliadb_protos import audit_pb2
|
53
53
|
from nucliadb_protos.nodereader_pb2 import RelationSearchResponse, SearchRequest, SearchResponse
|
54
|
+
from nucliadb_protos.utils_pb2 import RelationNode
|
54
55
|
from nucliadb_telemetry.errors import capture_exception
|
55
56
|
from nucliadb_utils.utilities import get_audit
|
56
57
|
|
@@ -145,15 +146,7 @@ async def get_find_results(
|
|
145
146
|
return main_results, prequeries_results, query_parser
|
146
147
|
|
147
148
|
|
148
|
-
|
149
|
-
kbid: str,
|
150
|
-
query: str,
|
151
|
-
item: AskRequest,
|
152
|
-
ndb_client: NucliaDBClientType,
|
153
|
-
user: str,
|
154
|
-
origin: str,
|
155
|
-
metrics: RAGMetrics = RAGMetrics(),
|
156
|
-
) -> tuple[KnowledgeboxFindResults, QueryParser]:
|
149
|
+
def find_request_from_ask_request(item: AskRequest, query: str) -> FindRequest:
|
157
150
|
find_request = FindRequest()
|
158
151
|
find_request.resource_filters = item.resource_filters
|
159
152
|
find_request.features = []
|
@@ -189,7 +182,19 @@ async def run_main_query(
|
|
189
182
|
find_request.show_hidden = item.show_hidden
|
190
183
|
|
191
184
|
# this executes the model validators, that can tweak some fields
|
192
|
-
FindRequest.model_validate(find_request)
|
185
|
+
return FindRequest.model_validate(find_request)
|
186
|
+
|
187
|
+
|
188
|
+
async def run_main_query(
|
189
|
+
kbid: str,
|
190
|
+
query: str,
|
191
|
+
item: AskRequest,
|
192
|
+
ndb_client: NucliaDBClientType,
|
193
|
+
user: str,
|
194
|
+
origin: str,
|
195
|
+
metrics: RAGMetrics = RAGMetrics(),
|
196
|
+
) -> tuple[KnowledgeboxFindResults, QueryParser]:
|
197
|
+
find_request = find_request_from_ask_request(item, query)
|
193
198
|
|
194
199
|
find_results, incomplete, query_parser = await find(
|
195
200
|
kbid,
|
@@ -211,36 +216,59 @@ async def get_relations_results(
|
|
211
216
|
text_answer: str,
|
212
217
|
target_shard_replicas: Optional[list[str]],
|
213
218
|
timeout: Optional[float] = None,
|
219
|
+
only_with_metadata: bool = False,
|
220
|
+
only_agentic_relations: bool = False,
|
214
221
|
) -> Relations:
|
215
222
|
try:
|
216
223
|
predict = get_predict()
|
217
224
|
detected_entities = await predict.detect_entities(kbid, text_answer)
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
results: list[SearchResponse]
|
223
|
-
(
|
224
|
-
results,
|
225
|
-
_,
|
226
|
-
_,
|
227
|
-
) = await node_query(
|
228
|
-
kbid,
|
229
|
-
Method.SEARCH,
|
230
|
-
request,
|
225
|
+
|
226
|
+
return await get_relations_results_from_entities(
|
227
|
+
kbid=kbid,
|
228
|
+
entities=detected_entities,
|
231
229
|
target_shard_replicas=target_shard_replicas,
|
232
230
|
timeout=timeout,
|
233
|
-
|
234
|
-
|
231
|
+
only_with_metadata=only_with_metadata,
|
232
|
+
only_agentic_relations=only_agentic_relations,
|
235
233
|
)
|
236
|
-
relations_results: list[RelationSearchResponse] = [result.relation for result in results]
|
237
|
-
return await merge_relations_results(relations_results, request.relation_subgraph)
|
238
234
|
except Exception as exc:
|
239
235
|
capture_exception(exc)
|
240
236
|
logger.exception("Error getting relations results")
|
241
237
|
return Relations(entities={})
|
242
238
|
|
243
239
|
|
240
|
+
async def get_relations_results_from_entities(
    *,
    kbid: str,
    entities: Iterable[RelationNode],
    target_shard_replicas: Optional[list[str]],
    timeout: Optional[float] = None,
    only_with_metadata: bool = False,
    only_agentic_relations: bool = False,
) -> Relations:
    """Query the relation graph around ``entities`` and merge the shard results.

    Builds a ``SearchRequest`` whose relation subgraph starts at ``entities`` with
    depth 1. The request is sent through ``node_query`` using read-replica nodes
    and without retrying on the primary. Each shard response's ``relation``
    part is passed to ``merge_relations_results``, together with the subgraph
    request and the two filtering flags.
    """
    search_request = SearchRequest()
    search_request.relation_subgraph.entry_points.extend(entities)
    search_request.relation_subgraph.depth = 1

    shard_responses: list[SearchResponse]
    shard_responses, _, _ = await node_query(
        kbid,
        Method.SEARCH,
        search_request,
        target_shard_replicas=target_shard_replicas,
        timeout=timeout,
        use_read_replica_nodes=True,
        retry_on_primary=False,
    )

    per_shard_relations: list[RelationSearchResponse] = [
        response.relation for response in shard_responses
    ]
    return await merge_relations_results(
        per_shard_relations,
        search_request.relation_subgraph,
        only_with_metadata,
        only_agentic_relations,
    )
|
270
|
+
|
271
|
+
|
244
272
|
def maybe_audit_chat(
|
245
273
|
*,
|
246
274
|
kbid: str,
|