nucliadb 6.2.0.post2679__py3-none-any.whl → 6.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- migrations/0028_extracted_vectors_reference.py +61 -0
- migrations/0029_backfill_field_status.py +149 -0
- migrations/0030_label_deduplication.py +60 -0
- nucliadb/common/cluster/manager.py +41 -331
- nucliadb/common/cluster/rebalance.py +2 -2
- nucliadb/common/cluster/rollover.py +12 -71
- nucliadb/common/cluster/settings.py +3 -0
- nucliadb/common/cluster/standalone/utils.py +0 -43
- nucliadb/common/cluster/utils.py +0 -16
- nucliadb/common/counters.py +1 -0
- nucliadb/common/datamanagers/fields.py +48 -7
- nucliadb/common/datamanagers/vectorsets.py +11 -2
- nucliadb/common/external_index_providers/base.py +2 -1
- nucliadb/common/external_index_providers/pinecone.py +3 -5
- nucliadb/common/ids.py +18 -4
- nucliadb/common/models_utils/from_proto.py +479 -0
- nucliadb/common/models_utils/to_proto.py +60 -0
- nucliadb/common/nidx.py +76 -37
- nucliadb/export_import/models.py +3 -3
- nucliadb/health.py +0 -7
- nucliadb/ingest/app.py +0 -8
- nucliadb/ingest/consumer/auditing.py +1 -1
- nucliadb/ingest/consumer/shard_creator.py +1 -1
- nucliadb/ingest/fields/base.py +83 -21
- nucliadb/ingest/orm/brain.py +55 -56
- nucliadb/ingest/orm/broker_message.py +12 -2
- nucliadb/ingest/orm/entities.py +6 -17
- nucliadb/ingest/orm/knowledgebox.py +44 -22
- nucliadb/ingest/orm/processor/data_augmentation.py +7 -29
- nucliadb/ingest/orm/processor/processor.py +5 -2
- nucliadb/ingest/orm/resource.py +222 -413
- nucliadb/ingest/processing.py +8 -2
- nucliadb/ingest/serialize.py +77 -46
- nucliadb/ingest/service/writer.py +2 -56
- nucliadb/ingest/settings.py +1 -4
- nucliadb/learning_proxy.py +6 -4
- nucliadb/purge/__init__.py +102 -12
- nucliadb/purge/orphan_shards.py +6 -4
- nucliadb/reader/api/models.py +3 -3
- nucliadb/reader/api/v1/__init__.py +1 -0
- nucliadb/reader/api/v1/download.py +2 -2
- nucliadb/reader/api/v1/knowledgebox.py +3 -3
- nucliadb/reader/api/v1/resource.py +23 -12
- nucliadb/reader/api/v1/services.py +4 -4
- nucliadb/reader/api/v1/vectorsets.py +48 -0
- nucliadb/search/api/v1/ask.py +11 -1
- nucliadb/search/api/v1/feedback.py +3 -3
- nucliadb/search/api/v1/knowledgebox.py +8 -13
- nucliadb/search/api/v1/search.py +3 -2
- nucliadb/search/api/v1/suggest.py +0 -2
- nucliadb/search/predict.py +6 -4
- nucliadb/search/requesters/utils.py +1 -2
- nucliadb/search/search/chat/ask.py +77 -13
- nucliadb/search/search/chat/prompt.py +16 -5
- nucliadb/search/search/chat/query.py +74 -34
- nucliadb/search/search/exceptions.py +2 -7
- nucliadb/search/search/find.py +9 -5
- nucliadb/search/search/find_merge.py +10 -4
- nucliadb/search/search/graph_strategy.py +884 -0
- nucliadb/search/search/hydrator.py +6 -0
- nucliadb/search/search/merge.py +79 -24
- nucliadb/search/search/query.py +74 -245
- nucliadb/search/search/query_parser/exceptions.py +11 -1
- nucliadb/search/search/query_parser/fetcher.py +405 -0
- nucliadb/search/search/query_parser/models.py +0 -3
- nucliadb/search/search/query_parser/parser.py +22 -21
- nucliadb/search/search/rerankers.py +1 -42
- nucliadb/search/search/shards.py +19 -0
- nucliadb/standalone/api_router.py +2 -14
- nucliadb/standalone/settings.py +4 -0
- nucliadb/train/generators/field_streaming.py +7 -3
- nucliadb/train/lifecycle.py +3 -6
- nucliadb/train/nodes.py +14 -12
- nucliadb/train/resource.py +380 -0
- nucliadb/writer/api/constants.py +20 -16
- nucliadb/writer/api/v1/__init__.py +1 -0
- nucliadb/writer/api/v1/export_import.py +1 -1
- nucliadb/writer/api/v1/field.py +13 -7
- nucliadb/writer/api/v1/knowledgebox.py +3 -46
- nucliadb/writer/api/v1/resource.py +20 -13
- nucliadb/writer/api/v1/services.py +10 -1
- nucliadb/writer/api/v1/upload.py +61 -34
- nucliadb/writer/{vectorsets.py → api/v1/vectorsets.py} +99 -47
- nucliadb/writer/back_pressure.py +17 -46
- nucliadb/writer/resource/basic.py +9 -7
- nucliadb/writer/resource/field.py +42 -9
- nucliadb/writer/settings.py +2 -2
- nucliadb/writer/tus/gcs.py +11 -10
- {nucliadb-6.2.0.post2679.dist-info → nucliadb-6.2.1.dist-info}/METADATA +11 -14
- {nucliadb-6.2.0.post2679.dist-info → nucliadb-6.2.1.dist-info}/RECORD +94 -96
- {nucliadb-6.2.0.post2679.dist-info → nucliadb-6.2.1.dist-info}/WHEEL +1 -1
- nucliadb/common/cluster/discovery/base.py +0 -178
- nucliadb/common/cluster/discovery/k8s.py +0 -301
- nucliadb/common/cluster/discovery/manual.py +0 -57
- nucliadb/common/cluster/discovery/single.py +0 -51
- nucliadb/common/cluster/discovery/types.py +0 -32
- nucliadb/common/cluster/discovery/utils.py +0 -67
- nucliadb/common/cluster/standalone/grpc_node_binding.py +0 -349
- nucliadb/common/cluster/standalone/index_node.py +0 -123
- nucliadb/common/cluster/standalone/service.py +0 -84
- nucliadb/standalone/introspect.py +0 -208
- nucliadb-6.2.0.post2679.dist-info/zip-safe +0 -1
- /nucliadb/common/{cluster/discovery → models_utils}/__init__.py +0 -0
- {nucliadb-6.2.0.post2679.dist-info → nucliadb-6.2.1.dist-info}/entry_points.txt +0 -0
- {nucliadb-6.2.0.post2679.dist-info → nucliadb-6.2.1.dist-info}/top_level.txt +0 -0
nucliadb/search/search/query_parser/fetcher.py
ADDED

@@ -0,0 +1,405 @@

```python
# Copyright (C) 2021 Bosutech XXI S.L.
#
# nucliadb is offered under the AGPL v3.0 and as commercial software.
# For commercial licensing, contact us at info@nuclia.com.
#
# AGPL:
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
from typing import Optional, TypeVar, Union

from async_lru import alru_cache
from typing_extensions import TypeIs

from nucliadb.common import datamanagers
from nucliadb.common.maindb.utils import get_driver
from nucliadb.search import logger
from nucliadb.search.predict import SendToPredictError, convert_relations
from nucliadb.search.search.metrics import (
    query_parse_dependency_observer,
)
from nucliadb.search.search.query_parser.exceptions import InvalidQueryError
from nucliadb.search.utilities import get_predict
from nucliadb_models.internal.predict import QueryInfo
from nucliadb_protos import knowledgebox_pb2, utils_pb2


# We use a class as cache miss marker to allow None values in the cache and to
# make mypy happy with typing
class NotCached:
    pass


not_cached = NotCached()


T = TypeVar("T")


def is_cached(field: Union[T, NotCached]) -> TypeIs[T]:
    return not isinstance(field, NotCached)


class FetcherCache:
    predict_query_info: Union[Optional[QueryInfo], NotCached] = not_cached
    predict_detected_entities: Union[list[utils_pb2.RelationNode], NotCached] = not_cached

    # semantic search
    query_vector: Union[Optional[list[float]], NotCached] = not_cached
    vectorset: Union[str, NotCached] = not_cached
    matryoshka_dimension: Union[Optional[int], NotCached] = not_cached

    labels: Union[knowledgebox_pb2.Labels, NotCached] = not_cached

    synonyms: Union[Optional[knowledgebox_pb2.Synonyms], NotCached] = not_cached

    entities_meta_cache: Union[datamanagers.entities.EntitiesMetaCache, NotCached] = not_cached
    deleted_entity_groups: Union[list[str], NotCached] = not_cached
    detected_entities: Union[list[utils_pb2.RelationNode], NotCached] = not_cached


class Fetcher:
    """Queries are getting more and more complex and different phases of the
    query depend on different data, not only from the user but from other parts
    of the system.

    This class is an encapsulation of data gathering across different parts of
    the system. Given the user query input, it aims to be as efficient as
    possible removing redundant expensive calls to other parts of the system. An
    instance of a fetcher caches it's results and it's thought to be used in the
    context of a single request. DO NOT use this as a global object!

    """

    def __init__(
        self,
        kbid: str,
        *,
        query: str,
        user_vector: Optional[list[float]],
        vectorset: Optional[str],
        rephrase: bool,
        rephrase_prompt: Optional[str],
        generative_model: Optional[str],
    ):
        self.kbid = kbid
        self.query = query
        self.user_vector = user_vector
        self.user_vectorset = vectorset
        self.rephrase = rephrase
        self.rephrase_prompt = rephrase_prompt
        self.generative_model = generative_model

        self.cache = FetcherCache()
        self._validated = False

    # Validation

    async def initial_validate(self):
        """Runs a validation on the input parameters. It can raise errors if
        there's some wrong parameter.

        This function should be always called if validated input for fetching is
        desired
        """
        if self._validated:
            return

        self._validated = True

    async def _validate_vectorset(self):
        if self.user_vectorset is not None:
            await validate_vectorset(self.kbid, self.user_vectorset)

    # Semantic search

    async def get_matryoshka_dimension(self) -> Optional[int]:
        if is_cached(self.cache.matryoshka_dimension):
            return self.cache.matryoshka_dimension

        vectorset = await self.get_vectorset()
        matryoshka_dimension = await get_matryoshka_dimension_cached(self.kbid, vectorset)
        self.cache.matryoshka_dimension = matryoshka_dimension
        return matryoshka_dimension

    async def _get_user_vectorset(self) -> Optional[str]:
        """Returns the user's requested vectorset and validates if it does exist
        in the KB.

        """
        vectorset = self.user_vectorset
        if not self._validated:
            await self._validate_vectorset()
        return vectorset

    async def get_vectorset(self) -> str:
        """Get the vectorset to be used in the search. If not specified, by the
        user, Predict API or the own uses KB will provide a default.

        """

        if is_cached(self.cache.vectorset):
            return self.cache.vectorset

        if self.user_vectorset:
            # user explicitly asked for a vectorset
            self.cache.vectorset = self.user_vectorset
            return self.user_vectorset

        # when it's not provided, we get the default from Predict API
        query_info = await self._predict_query_endpoint()
        if query_info is None:
            vectorset = None
        else:
            if query_info.sentence is None:
                logger.error(
                    "Asking for a vectorset but /query didn't return one", extra={"kbid": self.kbid}
                )
                vectorset = None
            else:
                # vectors field is enforced by the data model to have at least one key
                for vectorset in query_info.sentence.vectors.keys():
                    vectorset = vectorset
                    break

        if vectorset is None:
            # in case predict don't answer which vectorset to use, fallback to
            # the first vectorset of the KB
            async with datamanagers.with_ro_transaction() as txn:
                async for vectorset, _ in datamanagers.vectorsets.iter(txn, kbid=self.kbid):
                    break
            assert vectorset is not None, "All KBs must have at least one vectorset in maindb"

        self.cache.vectorset = vectorset
        return vectorset

    async def get_query_vector(self) -> Optional[list[float]]:
        if is_cached(self.cache.query_vector):
            return self.cache.query_vector

        if self.user_vector is not None:
            query_vector = self.user_vector
        else:
            query_info = await self._predict_query_endpoint()
            if query_info is None or query_info.sentence is None:
                self.cache.query_vector = None
                return None

            vectorset = await self.get_vectorset()
            if vectorset not in query_info.sentence.vectors:
                logger.warning(
                    "Predict is not responding with a valid query nucliadb vectorset",
                    extra={
                        "kbid": self.kbid,
                        "vectorset": vectorset,
                        "predict_vectorsets": ",".join(query_info.sentence.vectors.keys()),
                    },
                )
                self.cache.query_vector = None
                return None

            query_vector = query_info.sentence.vectors[vectorset]

        matryoshka_dimension = await self.get_matryoshka_dimension()
        if matryoshka_dimension is not None:
            if self.user_vector is not None and len(query_vector) < matryoshka_dimension:
                raise InvalidQueryError(
                    "vector",
                    f"Invalid vector length, please check valid embedding size for {vectorset} model",
                )

            # KB using a matryoshka embeddings model, cut the query vector
            # accordingly
            query_vector = query_vector[:matryoshka_dimension]

        self.cache.query_vector = query_vector
        return query_vector

    async def get_rephrased_query(self) -> Optional[str]:
        query_info = await self._predict_query_endpoint()
        if query_info is None:
            return None
        return query_info.rephrased_query

    # Labels

    async def get_classification_labels(self) -> knowledgebox_pb2.Labels:
        if is_cached(self.cache.labels):
            return self.cache.labels

        labels = await get_classification_labels(self.kbid)
        self.cache.labels = labels
        return labels

    # Entities

    async def get_entities_meta_cache(self) -> datamanagers.entities.EntitiesMetaCache:
        if is_cached(self.cache.entities_meta_cache):
            return self.cache.entities_meta_cache

        entities_meta_cache = await get_entities_meta_cache(self.kbid)
        self.cache.entities_meta_cache = entities_meta_cache
        return entities_meta_cache

    async def get_deleted_entity_groups(self) -> list[str]:
        if is_cached(self.cache.deleted_entity_groups):
            return self.cache.deleted_entity_groups

        deleted_entity_groups = await get_deleted_entity_groups(self.kbid)
        self.cache.deleted_entity_groups = deleted_entity_groups
        return deleted_entity_groups

    async def get_detected_entities(self) -> list[utils_pb2.RelationNode]:
        if is_cached(self.cache.detected_entities):
            return self.cache.detected_entities

        # Optimization to avoid calling predict twice
        if is_cached(self.cache.predict_query_info):
            # /query supersets detect entities, so we already have them
            query_info = self.cache.predict_query_info
            if query_info is not None and query_info.entities is not None:
                detected_entities = convert_relations(query_info.entities.model_dump())
            else:
                detected_entities = []
        else:
            # No call to /query has been done, we'll use detect entities
            # endpoint instead (as it's faster)
            detected_entities = await self._predict_detect_entities()

        self.cache.detected_entities = detected_entities
        return detected_entities

    # Synonyms

    async def get_synonyms(self) -> Optional[knowledgebox_pb2.Synonyms]:
        if is_cached(self.cache.synonyms):
            return self.cache.synonyms

        synonyms = await get_kb_synonyms(self.kbid)
        self.cache.synonyms = synonyms
        return synonyms

    # Predict API

    async def _predict_query_endpoint(self) -> Optional[QueryInfo]:
        if is_cached(self.cache.predict_query_info):
            return self.cache.predict_query_info

        # calling twice should be avoided as query endpoint is a superset of detect entities
        if is_cached(self.cache.predict_detected_entities):
            logger.warning("Fetcher is not being efficient enough and has called predict twice!")

        # we can't call get_vectorset, as it would do a recirsive loop between
        # functions, so we'll manually parse it
        vectorset = await self._get_user_vectorset()
        try:
            query_info = await query_information(
                self.kbid,
                self.query,
                vectorset,
                self.generative_model,
                self.rephrase,
                self.rephrase_prompt,
            )
        except (SendToPredictError, TimeoutError):
            query_info = None

        self.cache.predict_query_info = query_info
        return query_info

    async def _predict_detect_entities(self) -> list[utils_pb2.RelationNode]:
        if is_cached(self.cache.predict_detected_entities):
            return self.cache.predict_detected_entities

        try:
            detected_entities = await detect_entities(self.kbid, self.query)
        except (SendToPredictError, TimeoutError) as ex:
            logger.warning(f"Errors on Predict API detecting entities: {ex}", extra={"kbid": self.kbid})
            detected_entities = []

        self.cache.predict_detected_entities = detected_entities
        return detected_entities


async def validate_vectorset(kbid: str, vectorset: str):
    async with datamanagers.with_ro_transaction() as txn:
        if not await datamanagers.vectorsets.exists(txn, kbid=kbid, vectorset_id=vectorset):
            raise InvalidQueryError(
                "vectorset", f"Vectorset {vectorset} doesn't exist in you Knowledge Box"
            )


@query_parse_dependency_observer.wrap({"type": "query_information"})
async def query_information(
    kbid: str,
    query: str,
    semantic_model: Optional[str],
    generative_model: Optional[str] = None,
    rephrase: bool = False,
    rephrase_prompt: Optional[str] = None,
) -> QueryInfo:
    predict = get_predict()
    return await predict.query(kbid, query, semantic_model, generative_model, rephrase, rephrase_prompt)


@query_parse_dependency_observer.wrap({"type": "detect_entities"})
async def detect_entities(kbid: str, query: str) -> list[utils_pb2.RelationNode]:
    predict = get_predict()
    return await predict.detect_entities(kbid, query)


@alru_cache(maxsize=None)
async def get_matryoshka_dimension_cached(kbid: str, vectorset: Optional[str]) -> Optional[int]:
    # This can be safely cached as the matryoshka dimension is not expected to change
    return await get_matryoshka_dimension(kbid, vectorset)


@query_parse_dependency_observer.wrap({"type": "matryoshka_dimension"})
async def get_matryoshka_dimension(kbid: str, vectorset: Optional[str]) -> Optional[int]:
    async with get_driver().transaction(read_only=True) as txn:
        matryoshka_dimension = None
        if not vectorset:
            # XXX this should be migrated once we remove the "default" vectorset
            # concept
            matryoshka_dimension = await datamanagers.kb.get_matryoshka_vector_dimension(txn, kbid=kbid)
        else:
            vectorset_config = await datamanagers.vectorsets.get(txn, kbid=kbid, vectorset_id=vectorset)
            if vectorset_config is not None and vectorset_config.vectorset_index_config.vector_dimension:
                matryoshka_dimension = vectorset_config.vectorset_index_config.vector_dimension

        return matryoshka_dimension


@query_parse_dependency_observer.wrap({"type": "classification_labels"})
async def get_classification_labels(kbid: str) -> knowledgebox_pb2.Labels:
    async with get_driver().transaction(read_only=True) as txn:
        return await datamanagers.labels.get_labels(txn, kbid=kbid)


@query_parse_dependency_observer.wrap({"type": "synonyms"})
async def get_kb_synonyms(kbid: str) -> Optional[knowledgebox_pb2.Synonyms]:
    async with get_driver().transaction(read_only=True) as txn:
        return await datamanagers.synonyms.get(txn, kbid=kbid)


@query_parse_dependency_observer.wrap({"type": "entities_meta_cache"})
async def get_entities_meta_cache(kbid: str) -> datamanagers.entities.EntitiesMetaCache:
    async with get_driver().transaction(read_only=True) as txn:
        return await datamanagers.entities.get_entities_meta_cache(txn, kbid=kbid)


@query_parse_dependency_observer.wrap({"type": "deleted_entities_groups"})
async def get_deleted_entity_groups(kbid: str) -> list[str]:
    async with get_driver().transaction(read_only=True) as txn:
        return list((await datamanagers.entities.get_deleted_groups(txn, kbid=kbid)).entities_groups)
```
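For context, here is a minimal sketch of how this per-request cache might be exercised. The call site and all values are illustrative assumptions, not part of the diff; only the `Fetcher` API shown above is taken from the new module.

```python
# Hypothetical call site: values and wiring are illustrative only.
from nucliadb.search.search.query_parser.fetcher import Fetcher


async def retrieve(kbid: str, query: str):
    fetcher = Fetcher(
        kbid,
        query=query,
        user_vector=None,       # let Predict API compute the embedding
        vectorset=None,         # fall back to the KB's default vectorset
        rephrase=False,
        rephrase_prompt=None,
        generative_model=None,
    )
    await fetcher.initial_validate()

    # Both calls below reuse the single cached /query response, so the
    # Predict API is only hit once for this request.
    vectorset = await fetcher.get_vectorset()
    query_vector = await fetcher.get_query_vector()
    return vectorset, query_vector
```

Because the cache lives on the instance, the fetcher must be created per request, as the docstring warns; sharing one instance across requests would leak one request's query data into another.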
nucliadb/search/search/query_parser/parser.py
CHANGED

```diff
@@ -26,12 +26,11 @@ from nucliadb.search.search.filters import (
     convert_to_node_filters,
     translate_label_filters,
 )
-from nucliadb.search.search.query_parser.exceptions import
+from nucliadb.search.search.query_parser.exceptions import InternalParserError
 from nucliadb.search.search.query_parser.models import (
     CatalogFilters,
     CatalogQuery,
     DateTimeFilter,
-    MultiMatchBoosterReranker,
     NoopReranker,
     PredictReranker,
     RankFusion,
@@ -50,25 +49,26 @@ from nucliadb_models.search import (
 )
 
 
-def parse_find(item: FindRequest) -> UnitRetrieval:
-    parser = _FindParser(item)
-    return parser.parse()
+async def parse_find(kbid: str, item: FindRequest) -> UnitRetrieval:
+    parser = _FindParser(kbid, item)
+    return await parser.parse()
 
 
 class _FindParser:
-    def __init__(self, item: FindRequest):
+    def __init__(self, kbid: str, item: FindRequest):
+        self.kbid = kbid
         self.item = item
 
-    def parse(self) -> UnitRetrieval:
+    async def parse(self) -> UnitRetrieval:
         top_k = self._parse_top_k()
         try:
             rank_fusion = self._parse_rank_fusion()
         except ValidationError as exc:
-            raise
+            raise InternalParserError(f"Parsing error in rank fusion: {str(exc)}") from exc
         try:
             reranker = self._parse_reranker()
         except ValidationError as exc:
-            raise
+            raise InternalParserError(f"Parsing error in reranker: {str(exc)}") from exc
 
         # Adjust retrieval windows. Our current implementation assume:
         # `top_k <= reranker.window <= rank_fusion.window`
@@ -98,7 +98,7 @@ class _FindParser:
             if self.item.rank_fusion == search_models.RankFusionName.RECIPROCAL_RANK_FUSION:
                 rank_fusion = ReciprocalRankFusion(window=window)
             else:
-                raise
+                raise InternalParserError(f"Unknown rank fusion algorithm: {self.item.rank_fusion}")
 
         elif isinstance(self.item.rank_fusion, search_models.ReciprocalRankFusion):
             user_window = self.item.rank_fusion.window
@@ -109,7 +109,7 @@ class _FindParser:
             )
 
         else:
-            raise
+            raise InternalParserError(f"Unknown rank fusion {self.item.rank_fusion}")
 
         return rank_fusion
 
@@ -122,33 +122,34 @@ class _FindParser:
             if self.item.reranker == search_models.RerankerName.NOOP:
                 reranking = NoopReranker()
 
-            elif self.item.reranker == search_models.RerankerName.MULTI_MATCH_BOOSTER:
-                reranking = MultiMatchBoosterReranker()
-
             elif self.item.reranker == search_models.RerankerName.PREDICT_RERANKER:
                 # for predict rearnker, by default, we want a x2 factor with a
                 # top of 200 results
                 reranking = PredictReranker(window=min(top_k * 2, 200))
 
             else:
-                raise
+                raise InternalParserError(f"Unknown reranker algorithm: {self.item.reranker}")
 
         elif isinstance(self.item.reranker, search_models.PredictReranker):
             user_window = self.item.reranker.window
             reranking = PredictReranker(window=min(max(user_window or 0, top_k), 200))
 
         else:
-            raise
+            raise InternalParserError(f"Unknown reranker {self.item.reranker}")
 
         return reranking
 
 
 def parse_catalog(kbid: str, item: search_models.CatalogRequest) -> CatalogQuery:
-
-
-
-
-
+    filters = item.filters
+
+    if item.hidden is not None:
+        if item.hidden:
+            filters.append(Filter(all=[LABEL_HIDDEN]))  # type: ignore
+        else:
+            filters.append(Filter(none=[LABEL_HIDDEN]))  # type: ignore
+
     label_filters: dict[str, Any] = convert_to_node_filters(item.filters)
     if len(label_filters) > 0:
         label_filters = translate_label_filters(label_filters)
```

(The content of the five removed `parse_catalog` lines was not captured by the diff rendering and is left blank above.)
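Since `parse_find` is now a coroutine and takes the `kbid`, callers have to await it. A hedged sketch of a call site; the surrounding request handling and error mapping are assumptions for illustration, not code from this diff:

```python
# Illustrative only: FindRequest construction and error handling are assumed.
from nucliadb.search.search.query_parser.exceptions import InternalParserError, InvalidQueryError
from nucliadb.search.search.query_parser.parser import parse_find
from nucliadb_models.search import FindRequest


async def build_retrieval(kbid: str, item: FindRequest):
    try:
        # parse_find() is now async and needs the kbid, presumably so parsing
        # can read per-KB state instead of working purely on the request body.
        retrieval = await parse_find(kbid, item)
    except InvalidQueryError:
        raise  # bad user input: would surface as a 4xx
    except InternalParserError:
        raise  # unexpected parser state: would surface as a 5xx
    return retrieval
```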
nucliadb/search/search/rerankers.py
CHANGED

```diff
@@ -169,58 +169,17 @@ class PredictReranker(Reranker):
         return best
 
 
-class MultiMatchBoosterReranker(Reranker):
-    """This reranker gives more value to items that come from different indices"""
-
-    @property
-    def window(self) -> Optional[int]:
-        return None
-
-    @reranker_observer.wrap({"type": "multi_match_booster"})
-    async def _rerank(self, items: list[RerankableItem], options: RerankingOptions) -> list[RankedItem]:
-        """Given a list of rerankable items, boost matches that appear multiple
-        times. The returned list can be smaller than the initial, as repeated
-        matches are deduplicated.
-        """
-        reranked_by_id = {}
-        for item in items:
-            if item.id not in reranked_by_id:
-                reranked_by_id[item.id] = RankedItem(
-                    id=item.id,
-                    score=item.score,
-                    score_type=item.score_type,
-                )
-            else:
-                # it's a mutiple match, boost the score
-                if reranked_by_id[item.id].score < item.score:
-                    # previous implementation noted that we are using vector
-                    # score x2 when we find a multiple match. However, this may
-                    # not be true, as the same paragraph could come in any
-                    # position in the rank fusioned result list
-                    reranked_by_id[item.id].score = item.score * 2
-
-                reranked_by_id[item.id].score_type = SCORE_TYPE.BOTH
-
-        reranked = list(reranked_by_id.values())
-        sort_by_score(reranked)
-        return reranked
-
-
 def get_reranker(reranker: parser_models.Reranker) -> Reranker:
     algorithm: Reranker
 
     if isinstance(reranker, parser_models.NoopReranker):
         algorithm = NoopReranker()
 
-    elif isinstance(reranker, parser_models.MultiMatchBoosterReranker):
-        algorithm = MultiMatchBoosterReranker()
-
     elif isinstance(reranker, parser_models.PredictReranker):
         algorithm = PredictReranker(reranker.window)
 
     else:
-
-        algorithm = MultiMatchBoosterReranker()
+        raise ValueError(f"Unknown reranker requested: {reranker}")
 
     return algorithm
 
```
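The behavioural change for callers: an unrecognised reranker model used to fall back to `MultiMatchBoosterReranker`, now it raises. A small illustrative sketch, assuming `parser_models.NoopReranker()` takes no required fields; none of this code is from the diff itself:

```python
# Illustrative only: demonstrates the new failure mode of get_reranker().
from nucliadb.search.search.query_parser import models as parser_models
from nucliadb.search.search.rerankers import get_reranker

reranker = get_reranker(parser_models.NoopReranker())  # still resolves normally

try:
    get_reranker("not-a-reranker")  # type: ignore[arg-type]
except ValueError as exc:
    # Previously this silently fell back to MultiMatchBoosterReranker,
    # which has been removed in 6.2.1.
    print(exc)
```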
nucliadb/search/search/shards.py
CHANGED

```diff
@@ -19,6 +19,10 @@
 #
 import asyncio
 
+import backoff
+from grpc import StatusCode
+from grpc.aio import AioRpcError
+
 from nucliadb.common.cluster.base import AbstractIndexNode
 from nucliadb_protos.nodereader_pb2 import (
     GetShardRequest,
@@ -39,6 +43,15 @@ node_observer = metrics.Observer(
 )
 
 
+def should_giveup(e: Exception):
+    if isinstance(e, AioRpcError) and e.code() != StatusCode.NOT_FOUND:
+        return True
+    return False
+
+
+@backoff.on_exception(
+    backoff.expo, Exception, jitter=None, factor=0.1, max_tries=3, giveup=should_giveup
+)
 async def query_shard(node: AbstractIndexNode, shard: str, query: SearchRequest) -> SearchResponse:
     req = SearchRequest()
     req.CopyFrom(query)
@@ -47,6 +60,9 @@ async def query_shard(node: AbstractIndexNode, shard: str, query: SearchRequest)
     return await node.reader.Search(req)  # type: ignore
 
 
+@backoff.on_exception(
+    backoff.expo, Exception, jitter=None, factor=0.1, max_tries=3, giveup=should_giveup
+)
 async def get_shard(node: AbstractIndexNode, shard_id: str) -> Shard:
     req = GetShardRequest()
     req.shard_id.id = shard_id
@@ -54,6 +70,9 @@ async def get_shard(node: AbstractIndexNode, shard_id: str) -> Shard:
     return await node.reader.GetShard(req)  # type: ignore
 
 
+@backoff.on_exception(
+    backoff.expo, Exception, jitter=None, factor=0.1, max_tries=3, giveup=should_giveup
+)
 async def suggest_shard(node: AbstractIndexNode, shard: str, query: SuggestRequest) -> SuggestResponse:
     req = SuggestRequest()
     req.CopyFrom(query)
```
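The effect of these decorators: each shard call is attempted up to 3 times with exponential backoff (no jitter, factor 0.1, so waits of roughly 0.1s and 0.2s), but gives up immediately on any gRPC error whose status is not `NOT_FOUND`. A standalone sketch of the same policy; `flaky_call()` is invented for illustration and is not part of nucliadb:

```python
# Standalone sketch of the retry policy added above.
import backoff
from grpc import StatusCode
from grpc.aio import AioRpcError


def should_giveup(e: Exception) -> bool:
    # Give up right away on gRPC errors other than NOT_FOUND;
    # NOT_FOUND and non-gRPC exceptions are retried.
    return isinstance(e, AioRpcError) and e.code() != StatusCode.NOT_FOUND


@backoff.on_exception(
    backoff.expo, Exception, jitter=None, factor=0.1, max_tries=3, giveup=should_giveup
)
async def flaky_call():
    # Up to 3 attempts in total; after the last failure the exception
    # propagates to the caller.
    ...
```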
nucliadb/standalone/api_router.py
CHANGED

```diff
@@ -17,14 +17,13 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
-import datetime
 import logging
 import time
 
 import orjson
 import pydantic
 from fastapi import Request
-from fastapi.responses import JSONResponse
+from fastapi.responses import JSONResponse
 from fastapi.routing import APIRouter
 from fastapi_versioning import version
 from jwcrypto import jwe, jwk  # type: ignore
@@ -33,7 +32,7 @@ from nucliadb.common import datamanagers
 from nucliadb.common.cluster import manager
 from nucliadb.common.http_clients import processing
 from nucliadb.common.http_clients.auth import NucliaAuthHTTPClient
-from nucliadb.standalone import
+from nucliadb.standalone import versions
 from nucliadb_models.resource import NucliaDBRoles
 from nucliadb_utils.authentication import requires
 from nucliadb_utils.settings import nuclia_settings
@@ -146,17 +145,6 @@ async def versions_endpoint(request: Request) -> JSONResponse:
     )
 
 
-@standalone_api_router.get("/introspect")
-def introspect_endpoint(request: Request) -> StreamingResponse:
-    introspect_id = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
-    return StreamingResponse(
-        content=introspect.stream_tar(request.app),
-        status_code=200,
-        headers={"Content-Disposition": f"attachment; filename=introspect_{introspect_id}.tar.gz"},
-        media_type="application/octet-stream",
-    )
-
-
 @standalone_api_router.get("/pull/position")
 async def pull_status(request: Request) -> JSONResponse:
     async with datamanagers.with_ro_transaction() as txn:
```

(The removed import lines above are shown as rendered by the diff viewer; the dropped names on the old side were not captured.)
nucliadb/standalone/settings.py
CHANGED

```diff
@@ -83,6 +83,10 @@ class Settings(DriverSettings, StorageSettings, ExtendedStorageSettings):
         default="X-NUCLIADB-ROLES",
         description="Only used for `upstream_naive` auth policy.",
     )
+    auth_policy_security_groups_header: str = pydantic.Field(
+        default="X-NUCLIADB-SECURITY_GROUPS",
+        description="Only used for `upstream_naive` auth policy.",
+    )
     auth_policy_user_default_roles: list[NucliaDBRoles] = pydantic.Field(
         default=[NucliaDBRoles.READER, NucliaDBRoles.WRITER, NucliaDBRoles.MANAGER],
         description="Default role to assign to user that is authenticated \
```
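For standalone deployments using the `upstream_naive` auth policy, the new header name can presumably be overridden the same way as the neighbouring fields. A hedged sketch; the environment-variable mapping is an assumption based on the usual pydantic settings behaviour and is not verified against this release:

```python
# Assumes pydantic settings read fields from same-named environment variables.
import os

os.environ["AUTH_POLICY_SECURITY_GROUPS_HEADER"] = "X-MY-PROXY-GROUPS"

from nucliadb.standalone.settings import Settings

settings = Settings()
print(settings.auth_policy_security_groups_header)  # X-MY-PROXY-GROUPS
```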
|