nucliadb 6.9.1.post5192__py3-none-any.whl → 6.10.0.post5705__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- migrations/0023_backfill_pg_catalog.py +2 -2
- migrations/0029_backfill_field_status.py +3 -4
- migrations/0032_remove_old_relations.py +2 -3
- migrations/0038_backfill_catalog_field_labels.py +2 -2
- migrations/0039_backfill_converation_splits_metadata.py +2 -2
- migrations/0041_reindex_conversations.py +137 -0
- migrations/pg/0010_shards_index.py +34 -0
- nucliadb/search/api/v1/resource/utils.py → migrations/pg/0011_catalog_statistics.py +5 -6
- migrations/pg/0012_catalog_statistics_undo.py +26 -0
- nucliadb/backups/create.py +2 -15
- nucliadb/backups/restore.py +4 -15
- nucliadb/backups/tasks.py +4 -1
- nucliadb/common/back_pressure/cache.py +2 -3
- nucliadb/common/back_pressure/materializer.py +7 -13
- nucliadb/common/back_pressure/settings.py +6 -6
- nucliadb/common/back_pressure/utils.py +1 -0
- nucliadb/common/cache.py +9 -9
- nucliadb/common/catalog/interface.py +12 -12
- nucliadb/common/catalog/pg.py +41 -29
- nucliadb/common/catalog/utils.py +3 -3
- nucliadb/common/cluster/manager.py +5 -4
- nucliadb/common/cluster/rebalance.py +483 -114
- nucliadb/common/cluster/rollover.py +25 -9
- nucliadb/common/cluster/settings.py +3 -8
- nucliadb/common/cluster/utils.py +34 -8
- nucliadb/common/context/__init__.py +7 -8
- nucliadb/common/context/fastapi.py +1 -2
- nucliadb/common/datamanagers/__init__.py +2 -4
- nucliadb/common/datamanagers/atomic.py +4 -2
- nucliadb/common/datamanagers/cluster.py +1 -2
- nucliadb/common/datamanagers/fields.py +3 -4
- nucliadb/common/datamanagers/kb.py +6 -6
- nucliadb/common/datamanagers/labels.py +2 -3
- nucliadb/common/datamanagers/resources.py +10 -33
- nucliadb/common/datamanagers/rollover.py +5 -7
- nucliadb/common/datamanagers/search_configurations.py +1 -2
- nucliadb/common/datamanagers/synonyms.py +1 -2
- nucliadb/common/datamanagers/utils.py +4 -4
- nucliadb/common/datamanagers/vectorsets.py +4 -4
- nucliadb/common/external_index_providers/base.py +32 -5
- nucliadb/common/external_index_providers/manager.py +4 -5
- nucliadb/common/filter_expression.py +128 -40
- nucliadb/common/http_clients/processing.py +12 -23
- nucliadb/common/ids.py +6 -4
- nucliadb/common/locking.py +1 -2
- nucliadb/common/maindb/driver.py +9 -8
- nucliadb/common/maindb/local.py +5 -5
- nucliadb/common/maindb/pg.py +9 -8
- nucliadb/common/nidx.py +3 -4
- nucliadb/export_import/datamanager.py +4 -3
- nucliadb/export_import/exporter.py +11 -19
- nucliadb/export_import/importer.py +13 -6
- nucliadb/export_import/tasks.py +2 -0
- nucliadb/export_import/utils.py +6 -18
- nucliadb/health.py +2 -2
- nucliadb/ingest/app.py +8 -8
- nucliadb/ingest/consumer/consumer.py +8 -10
- nucliadb/ingest/consumer/pull.py +3 -8
- nucliadb/ingest/consumer/service.py +3 -3
- nucliadb/ingest/consumer/utils.py +1 -1
- nucliadb/ingest/fields/base.py +28 -49
- nucliadb/ingest/fields/conversation.py +12 -12
- nucliadb/ingest/fields/exceptions.py +1 -2
- nucliadb/ingest/fields/file.py +22 -8
- nucliadb/ingest/fields/link.py +7 -7
- nucliadb/ingest/fields/text.py +2 -3
- nucliadb/ingest/orm/brain_v2.py +78 -64
- nucliadb/ingest/orm/broker_message.py +2 -4
- nucliadb/ingest/orm/entities.py +10 -209
- nucliadb/ingest/orm/index_message.py +4 -4
- nucliadb/ingest/orm/knowledgebox.py +18 -27
- nucliadb/ingest/orm/processor/auditing.py +1 -3
- nucliadb/ingest/orm/processor/data_augmentation.py +1 -2
- nucliadb/ingest/orm/processor/processor.py +27 -27
- nucliadb/ingest/orm/processor/sequence_manager.py +1 -2
- nucliadb/ingest/orm/resource.py +72 -70
- nucliadb/ingest/orm/utils.py +1 -1
- nucliadb/ingest/processing.py +17 -17
- nucliadb/ingest/serialize.py +202 -145
- nucliadb/ingest/service/writer.py +3 -109
- nucliadb/ingest/settings.py +3 -4
- nucliadb/ingest/utils.py +1 -2
- nucliadb/learning_proxy.py +11 -11
- nucliadb/metrics_exporter.py +5 -4
- nucliadb/middleware/__init__.py +82 -1
- nucliadb/migrator/datamanager.py +3 -4
- nucliadb/migrator/migrator.py +1 -2
- nucliadb/migrator/models.py +1 -2
- nucliadb/migrator/settings.py +1 -2
- nucliadb/models/internal/augment.py +614 -0
- nucliadb/models/internal/processing.py +19 -19
- nucliadb/openapi.py +2 -2
- nucliadb/purge/__init__.py +3 -8
- nucliadb/purge/orphan_shards.py +1 -2
- nucliadb/reader/__init__.py +5 -0
- nucliadb/reader/api/models.py +6 -13
- nucliadb/reader/api/v1/download.py +59 -38
- nucliadb/reader/api/v1/export_import.py +4 -4
- nucliadb/reader/api/v1/learning_config.py +24 -4
- nucliadb/reader/api/v1/resource.py +61 -9
- nucliadb/reader/api/v1/services.py +18 -14
- nucliadb/reader/app.py +3 -1
- nucliadb/reader/reader/notifications.py +1 -2
- nucliadb/search/api/v1/__init__.py +2 -0
- nucliadb/search/api/v1/ask.py +3 -4
- nucliadb/search/api/v1/augment.py +585 -0
- nucliadb/search/api/v1/catalog.py +11 -15
- nucliadb/search/api/v1/find.py +16 -22
- nucliadb/search/api/v1/hydrate.py +25 -25
- nucliadb/search/api/v1/knowledgebox.py +1 -2
- nucliadb/search/api/v1/predict_proxy.py +1 -2
- nucliadb/search/api/v1/resource/ask.py +7 -7
- nucliadb/search/api/v1/resource/ingestion_agents.py +5 -6
- nucliadb/search/api/v1/resource/search.py +9 -11
- nucliadb/search/api/v1/retrieve.py +130 -0
- nucliadb/search/api/v1/search.py +28 -32
- nucliadb/search/api/v1/suggest.py +11 -14
- nucliadb/search/api/v1/summarize.py +1 -2
- nucliadb/search/api/v1/utils.py +2 -2
- nucliadb/search/app.py +3 -2
- nucliadb/search/augmentor/__init__.py +21 -0
- nucliadb/search/augmentor/augmentor.py +232 -0
- nucliadb/search/augmentor/fields.py +704 -0
- nucliadb/search/augmentor/metrics.py +24 -0
- nucliadb/search/augmentor/paragraphs.py +334 -0
- nucliadb/search/augmentor/resources.py +238 -0
- nucliadb/search/augmentor/utils.py +33 -0
- nucliadb/search/lifecycle.py +3 -1
- nucliadb/search/predict.py +24 -17
- nucliadb/search/predict_models.py +8 -9
- nucliadb/search/requesters/utils.py +11 -10
- nucliadb/search/search/cache.py +19 -23
- nucliadb/search/search/chat/ask.py +88 -59
- nucliadb/search/search/chat/exceptions.py +3 -5
- nucliadb/search/search/chat/fetcher.py +201 -0
- nucliadb/search/search/chat/images.py +6 -4
- nucliadb/search/search/chat/old_prompt.py +1375 -0
- nucliadb/search/search/chat/parser.py +510 -0
- nucliadb/search/search/chat/prompt.py +563 -615
- nucliadb/search/search/chat/query.py +449 -36
- nucliadb/search/search/chat/rpc.py +85 -0
- nucliadb/search/search/fetch.py +3 -4
- nucliadb/search/search/filters.py +8 -11
- nucliadb/search/search/find.py +33 -31
- nucliadb/search/search/find_merge.py +124 -331
- nucliadb/search/search/graph_strategy.py +14 -12
- nucliadb/search/search/hydrator/__init__.py +3 -152
- nucliadb/search/search/hydrator/fields.py +92 -50
- nucliadb/search/search/hydrator/images.py +7 -7
- nucliadb/search/search/hydrator/paragraphs.py +42 -26
- nucliadb/search/search/hydrator/resources.py +20 -16
- nucliadb/search/search/ingestion_agents.py +5 -5
- nucliadb/search/search/merge.py +90 -94
- nucliadb/search/search/metrics.py +10 -9
- nucliadb/search/search/paragraphs.py +7 -9
- nucliadb/search/search/predict_proxy.py +13 -9
- nucliadb/search/search/query.py +14 -86
- nucliadb/search/search/query_parser/fetcher.py +51 -82
- nucliadb/search/search/query_parser/models.py +19 -20
- nucliadb/search/search/query_parser/old_filters.py +20 -19
- nucliadb/search/search/query_parser/parsers/ask.py +4 -5
- nucliadb/search/search/query_parser/parsers/catalog.py +5 -6
- nucliadb/search/search/query_parser/parsers/common.py +5 -6
- nucliadb/search/search/query_parser/parsers/find.py +6 -26
- nucliadb/search/search/query_parser/parsers/graph.py +13 -23
- nucliadb/search/search/query_parser/parsers/retrieve.py +207 -0
- nucliadb/search/search/query_parser/parsers/search.py +15 -53
- nucliadb/search/search/query_parser/parsers/unit_retrieval.py +8 -29
- nucliadb/search/search/rank_fusion.py +18 -13
- nucliadb/search/search/rerankers.py +5 -6
- nucliadb/search/search/retrieval.py +300 -0
- nucliadb/search/search/summarize.py +5 -6
- nucliadb/search/search/utils.py +3 -4
- nucliadb/search/settings.py +1 -2
- nucliadb/standalone/api_router.py +1 -1
- nucliadb/standalone/app.py +4 -3
- nucliadb/standalone/auth.py +5 -6
- nucliadb/standalone/lifecycle.py +2 -2
- nucliadb/standalone/run.py +2 -4
- nucliadb/standalone/settings.py +5 -6
- nucliadb/standalone/versions.py +3 -4
- nucliadb/tasks/consumer.py +13 -8
- nucliadb/tasks/models.py +2 -1
- nucliadb/tasks/producer.py +3 -3
- nucliadb/tasks/retries.py +8 -7
- nucliadb/train/api/utils.py +1 -3
- nucliadb/train/api/v1/shards.py +1 -2
- nucliadb/train/api/v1/trainset.py +1 -2
- nucliadb/train/app.py +1 -1
- nucliadb/train/generator.py +4 -4
- nucliadb/train/generators/field_classifier.py +2 -2
- nucliadb/train/generators/field_streaming.py +6 -6
- nucliadb/train/generators/image_classifier.py +2 -2
- nucliadb/train/generators/paragraph_classifier.py +2 -2
- nucliadb/train/generators/paragraph_streaming.py +2 -2
- nucliadb/train/generators/question_answer_streaming.py +2 -2
- nucliadb/train/generators/sentence_classifier.py +2 -2
- nucliadb/train/generators/token_classifier.py +3 -2
- nucliadb/train/generators/utils.py +6 -5
- nucliadb/train/nodes.py +3 -3
- nucliadb/train/resource.py +6 -8
- nucliadb/train/settings.py +3 -4
- nucliadb/train/types.py +11 -11
- nucliadb/train/upload.py +3 -2
- nucliadb/train/uploader.py +1 -2
- nucliadb/train/utils.py +1 -2
- nucliadb/writer/api/v1/export_import.py +4 -1
- nucliadb/writer/api/v1/field.py +7 -11
- nucliadb/writer/api/v1/knowledgebox.py +3 -4
- nucliadb/writer/api/v1/resource.py +9 -20
- nucliadb/writer/api/v1/services.py +10 -132
- nucliadb/writer/api/v1/upload.py +73 -72
- nucliadb/writer/app.py +8 -2
- nucliadb/writer/resource/basic.py +12 -15
- nucliadb/writer/resource/field.py +7 -5
- nucliadb/writer/resource/origin.py +7 -0
- nucliadb/writer/settings.py +2 -3
- nucliadb/writer/tus/__init__.py +2 -3
- nucliadb/writer/tus/azure.py +1 -3
- nucliadb/writer/tus/dm.py +3 -3
- nucliadb/writer/tus/exceptions.py +3 -4
- nucliadb/writer/tus/gcs.py +5 -6
- nucliadb/writer/tus/s3.py +2 -3
- nucliadb/writer/tus/storage.py +3 -3
- {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/METADATA +9 -10
- nucliadb-6.10.0.post5705.dist-info/RECORD +410 -0
- nucliadb/common/datamanagers/entities.py +0 -139
- nucliadb-6.9.1.post5192.dist-info/RECORD +0 -392
- {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/WHEEL +0 -0
- {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/entry_points.txt +0 -0
- {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/top_level.txt +0 -0
nucliadb/search/search/query.py
CHANGED
@@ -18,17 +18,13 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 from datetime import datetime
-from typing import Any, Optional
+from typing import Any
 
 from nidx_protos import nodereader_pb2
 from nidx_protos.noderesources_pb2 import Resource
 
-from nucliadb.common import datamanagers
 from nucliadb.common.exceptions import InvalidQueryError
 from nucliadb.common.filter_expression import add_and_expression, parse_expression
-from nucliadb.search.search.filters import (
-    translate_label,
-)
 from nucliadb.search.search.query_parser.fetcher import Fetcher
 from nucliadb_models.filters import FilterExpression
 from nucliadb_models.labels import LABEL_HIDDEN
@@ -38,7 +34,6 @@ from nucliadb_models.search import (
     SortOrder,
     SuggestOptions,
 )
-from nucliadb_protos import utils_pb2
 
 from .query_parser.old_filters import OldFilterParams, parse_old_filters
 
@@ -47,16 +42,16 @@ async def paragraph_query_to_pb(
     kbid: str,
     rid: str,
     query: str,
-    filter_expression: Optional[FilterExpression],
+    filter_expression: FilterExpression | None,
     fields: list[str],
     filters: list[str],
     faceted: list[str],
     top_k: int,
-    range_creation_start: Optional[datetime] = None,
-    range_creation_end: Optional[datetime] = None,
-    range_modification_start: Optional[datetime] = None,
-    range_modification_end: Optional[datetime] = None,
-    sort: Optional[str] = None,
+    range_creation_start: datetime | None = None,
+    range_creation_end: datetime | None = None,
+    range_modification_start: datetime | None = None,
+    range_modification_end: datetime | None = None,
+    sort: str | None = None,
     sort_ord: str = SortOrder.DESC.value,
     with_duplicates: bool = False,
 ) -> nodereader_pb2.SearchRequest:
@@ -119,86 +114,19 @@ async def paragraph_query_to_pb(
     return request
 
 
-def expand_entities(
-    meta_cache: datamanagers.entities.EntitiesMetaCache,
-    detected_entities: list[utils_pb2.RelationNode],
-) -> list[utils_pb2.RelationNode]:
-    """
-    Iterate through duplicated entities in a kb.
-
-    The algorithm first makes it so we can look up duplicates by source and
-    by the referenced entity and expands from both directions.
-    """
-    result_entities = {entity.value: entity for entity in detected_entities}
-    duplicated_entities = meta_cache.duplicate_entities
-    duplicated_entities_by_value = meta_cache.duplicate_entities_by_value
-
-    for entity in detected_entities[:]:
-        if entity.subtype not in duplicated_entities:
-            continue
-
-        if entity.value in duplicated_entities[entity.subtype]:
-            for duplicate in duplicated_entities[entity.subtype][entity.value]:
-                result_entities[duplicate] = utils_pb2.RelationNode(
-                    ntype=utils_pb2.RelationNode.NodeType.ENTITY,
-                    subtype=entity.subtype,
-                    value=duplicate,
-                )
-
-        if entity.value in duplicated_entities_by_value[entity.subtype]:
-            source_duplicate = duplicated_entities_by_value[entity.subtype][entity.value]
-            result_entities[source_duplicate] = utils_pb2.RelationNode(
-                ntype=utils_pb2.RelationNode.NodeType.ENTITY,
-                subtype=entity.subtype,
-                value=source_duplicate,
-            )
-
-            if source_duplicate in duplicated_entities[entity.subtype]:
-                for duplicate in duplicated_entities[entity.subtype][source_duplicate]:
-                    if duplicate == entity.value:
-                        continue
-                    result_entities[duplicate] = utils_pb2.RelationNode(
-                        ntype=utils_pb2.RelationNode.NodeType.ENTITY,
-                        subtype=entity.subtype,
-                        value=duplicate,
-                    )
-
-    return list(result_entities.values())
-
-
-def apply_entities_filter(
-    request: nodereader_pb2.SearchRequest,
-    detected_entities: list[utils_pb2.RelationNode],
-) -> list[str]:
-    added_filters = []
-    for entity_filter in [
-        f"/e/{entity.subtype}/{entity.value}"
-        for entity in detected_entities
-        if entity.ntype == utils_pb2.RelationNode.NodeType.ENTITY
-    ]:
-        if entity_filter not in added_filters:
-            added_filters.append(entity_filter)
-            # Add the entity to the filter expression (with AND)
-            entity_expr = nodereader_pb2.FilterExpression()
-            entity_expr.facet.facet = translate_label(entity_filter)
-            add_and_expression(request.field_filter, entity_expr)
-
-    return added_filters
-
-
 async def suggest_query_to_pb(
     kbid: str,
     features: list[SuggestOptions],
     query: str,
-    filter_expression: Optional[FilterExpression],
+    filter_expression: FilterExpression | None,
     fields: list[str],
     filters: list[str],
     faceted: list[str],
-    range_creation_start: Optional[datetime] = None,
-    range_creation_end: Optional[datetime] = None,
-    range_modification_start: Optional[datetime] = None,
-    range_modification_end: Optional[datetime] = None,
-    hidden: Optional[bool] = None,
+    range_creation_start: datetime | None = None,
+    range_creation_end: datetime | None = None,
+    range_modification_start: datetime | None = None,
+    range_modification_end: datetime | None = None,
+    hidden: bool | None = None,
 ) -> nodereader_pb2.SuggestRequest:
     request = nodereader_pb2.SuggestRequest()
 
@@ -305,7 +233,7 @@ def check_supported_filters(filters: dict[str, Any], paragraph_labels: list[str]
     )
 
 
-def get_sort_field_proto(obj: SortField) -> Optional[nodereader_pb2.OrderBy.OrderField.ValueType]:
+def get_sort_field_proto(obj: SortField) -> nodereader_pb2.OrderBy.OrderField.ValueType | None:
     return {
         SortField.SCORE: None,
         SortField.CREATED: nodereader_pb2.OrderBy.OrderField.CREATED,
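Nearly every hunk above is the same modernization: `typing.Optional[X]` / `typing.Union[X, Y]` annotations rewritten as PEP 604 unions, with the now-unused `typing` imports dropped. A minimal sketch of the pattern (hypothetical function, not from this package):

    from datetime import datetime
    from typing import Optional

    # Before: Optional requires the typing import.
    def find_old(created_after: Optional[datetime] = None) -> Optional[str]: ...

    # After: PEP 604 union syntax (Python 3.10+), as applied throughout this diff.
    def find_new(created_after: datetime | None = None) -> str | None: ...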
nucliadb/search/search/query_parser/fetcher.py
CHANGED
@@ -18,7 +18,7 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 import asyncio
-from typing import Optional, TypeVar, Union
+from typing import TypeVar
 
 from async_lru import alru_cache
 from typing_extensions import TypeIs
@@ -29,15 +29,10 @@ from nucliadb.common.maindb.utils import get_driver
 from nucliadb.search import logger
 from nucliadb.search.predict import SendToPredictError, convert_relations
 from nucliadb.search.predict_models import QueryModel
-from nucliadb.search.search.metrics import (
-    query_parse_dependency_observer,
-)
+from nucliadb.search.search.metrics import query_parse_dependency_observer
 from nucliadb.search.utilities import get_predict
 from nucliadb_models.internal.predict import QueryInfo
-from nucliadb_models.search import (
-    Image,
-    MaxTokens,
-)
+from nucliadb_models.search import Image, MaxTokens
 from nucliadb_protos import knowledgebox_pb2, utils_pb2
 
 
@@ -53,23 +48,22 @@ not_cached = NotCached()
 T = TypeVar("T")
 
 
-def is_cached(field: Union[T, NotCached]) -> TypeIs[T]:
+def is_cached(field: T | NotCached) -> TypeIs[T]:
     return not isinstance(field, NotCached)
 
 
 class FetcherCache:
-    predict_query_info: Union[Optional[QueryInfo], NotCached] = not_cached
+    predict_query_info: QueryInfo | None | NotCached = not_cached
 
     # semantic search
-    vectorset: Union[str, NotCached] = not_cached
+    vectorset: str | NotCached = not_cached
 
-    labels: Union[knowledgebox_pb2.Labels, NotCached] = not_cached
+    labels: knowledgebox_pb2.Labels | NotCached = not_cached
 
-    synonyms: Union[Optional[knowledgebox_pb2.Synonyms], NotCached] = not_cached
+    synonyms: knowledgebox_pb2.Synonyms | None | NotCached = not_cached
 
-    entities_meta_cache: Union[datamanagers.entities.EntitiesMetaCache, NotCached] = not_cached
-    deleted_entity_groups: Union[list[str], NotCached] = not_cached
-    detected_entities: Union[list[utils_pb2.RelationNode], NotCached] = not_cached
+    deleted_entity_groups: list[str] | NotCached = not_cached
+    detected_entities: list[utils_pb2.RelationNode] | NotCached = not_cached
 
 
 class Fetcher:
@@ -90,12 +84,12 @@ class Fetcher:
         kbid: str,
         *,
         query: str,
-        user_vector: Optional[list[float]],
-        vectorset: Optional[str],
+        user_vector: list[float] | None,
+        vectorset: str | None,
         rephrase: bool,
-        rephrase_prompt: Optional[str],
-        generative_model: Optional[str],
-        query_image: Optional[Image],
+        rephrase_prompt: str | None,
+        generative_model: str | None,
+        query_image: Image | None,
     ):
         self.kbid = kbid
         self.query = query
@@ -112,11 +106,11 @@ class Fetcher:
 
     # Semantic search
 
-    async def get_matryoshka_dimension(self) -> Optional[int]:
+    async def get_matryoshka_dimension(self) -> int | None:
         vectorset = await self.get_vectorset()
-        return await get_matryoshka_dimension_cached(self.kbid, vectorset)
+        return await self.get_matryoshka_dimension_cached(self.kbid, vectorset)
 
-    async def _get_user_vectorset(self) -> Optional[str]:
+    async def get_user_vectorset(self) -> str | None:
         """Returns the user's requested vectorset and validates if it does exist
         in the KB.
 
@@ -124,7 +118,7 @@ class Fetcher:
         async with self.locks.setdefault("user_vectorset", asyncio.Lock()):
             if not self.user_vectorset_validated:
                 if self.user_vectorset is not None:
-                    await validate_vectorset(self.kbid, self.user_vectorset)
+                    await self.validate_vectorset(self.kbid, self.user_vectorset)
                 self.user_vectorset_validated = True
             return self.user_vectorset
 
@@ -137,7 +131,7 @@ class Fetcher:
             if is_cached(self.cache.vectorset):
                 return self.cache.vectorset
 
-            user_vectorset = await self._get_user_vectorset()
+            user_vectorset = await self.get_user_vectorset()
             if user_vectorset:
                 # user explicitly asked for a vectorset
                 self.cache.vectorset = user_vectorset
@@ -170,7 +164,7 @@ class Fetcher:
             self.cache.vectorset = vectorset
             return vectorset
 
-    async def get_query_vector(self) -> Optional[list[float]]:
+    async def get_query_vector(self) -> list[float] | None:
         if self.user_vector is not None:
             query_vector = self.user_vector
         else:
@@ -206,13 +200,20 @@ class Fetcher:
 
         return query_vector
 
-    async def get_rephrased_query(self) -> Optional[str]:
+    async def get_rephrased_query(self) -> str | None:
         query_info = await self._predict_query_endpoint()
         if query_info is None:
             return None
         return query_info.rephrased_query
 
-    async def get_semantic_min_score(self) -> Optional[float]:
+    def get_cached_rephrased_query(self) -> str | None:
+        if not is_cached(self.cache.predict_query_info):
+            return None
+        if self.cache.predict_query_info is None:
+            return None
+        return self.cache.predict_query_info.rephrased_query
+
+    async def get_semantic_min_score(self) -> float | None:
         query_info = await self._predict_query_endpoint()
         if query_info is None:
             return None
@@ -234,24 +235,6 @@ class Fetcher:
 
     # Entities
 
-    async def get_entities_meta_cache(self) -> datamanagers.entities.EntitiesMetaCache:
-        async with self.locks.setdefault("entities_meta_cache", asyncio.Lock()):
-            if is_cached(self.cache.entities_meta_cache):
-                return self.cache.entities_meta_cache
-
-            entities_meta_cache = await get_entities_meta_cache(self.kbid)
-            self.cache.entities_meta_cache = entities_meta_cache
-            return entities_meta_cache
-
-    async def get_deleted_entity_groups(self) -> list[str]:
-        async with self.locks.setdefault("deleted_entity_groups", asyncio.Lock()):
-            if is_cached(self.cache.deleted_entity_groups):
-                return self.cache.deleted_entity_groups
-
-            deleted_entity_groups = await get_deleted_entity_groups(self.kbid)
-            self.cache.deleted_entity_groups = deleted_entity_groups
-            return deleted_entity_groups
-
     async def get_detected_entities(self) -> list[utils_pb2.RelationNode]:
         async with self.locks.setdefault("detected_entities", asyncio.Lock()):
             if is_cached(self.cache.detected_entities):
@@ -275,7 +258,7 @@ class Fetcher:
 
     # Synonyms
 
-    async def get_synonyms(self) -> Optional[knowledgebox_pb2.Synonyms]:
+    async def get_synonyms(self) -> knowledgebox_pb2.Synonyms | None:
         async with self.locks.setdefault("synonyms", asyncio.Lock()):
             if is_cached(self.cache.synonyms):
                 return self.cache.synonyms
@@ -293,7 +276,7 @@ class Fetcher:
 
         return query_info.visual_llm
 
-    async def get_max_context_tokens(self, max_tokens: Optional[MaxTokens]) -> int:
+    async def get_max_context_tokens(self, max_tokens: MaxTokens | None) -> int:
         query_info = await self._predict_query_endpoint()
         if query_info is None:
             raise SendToPredictError("Error while using predict's query endpoint")
@@ -308,21 +291,21 @@ class Fetcher:
             return max_tokens.context
         return model_max
 
-    def get_max_answer_tokens(self, max_tokens: Optional[MaxTokens]) -> Optional[int]:
+    def get_max_answer_tokens(self, max_tokens: MaxTokens | None) -> int | None:
         if max_tokens is not None and max_tokens.answer is not None:
             return max_tokens.answer
         return None
 
     # Predict API
 
-    async def _predict_query_endpoint(self) -> Optional[QueryInfo]:
+    async def _predict_query_endpoint(self) -> QueryInfo | None:
         async with self.locks.setdefault("predict_query_endpoint", asyncio.Lock()):
             if is_cached(self.cache.predict_query_info):
                 return self.cache.predict_query_info
 
             # we can't call get_vectorset, as it would do a recirsive loop between
             # functions, so we'll manually parse it
-            vectorset = await self._get_user_vectorset()
+            vectorset = await self.get_user_vectorset()
             try:
                 query_info = await query_information(
                     self.kbid,
@@ -348,24 +331,28 @@ class Fetcher:
 
         return detected_entities
 
+    async def validate_vectorset(self, kbid: str, vectorset: str):
+        async with datamanagers.with_ro_transaction() as txn:
+            if not await datamanagers.vectorsets.exists(txn, kbid=kbid, vectorset_id=vectorset):
+                raise InvalidQueryError(
+                    "vectorset", f"Vectorset {vectorset} doesn't exist in your Knowledge Box"
+                )
 
-async def validate_vectorset(kbid: str, vectorset: str):
-    async with datamanagers.with_ro_transaction() as txn:
-        if not await datamanagers.vectorsets.exists(txn, kbid=kbid, vectorset_id=vectorset):
-            raise InvalidQueryError(
-                "vectorset", f"Vectorset {vectorset} doesn't exist in you Knowledge Box"
-            )
+    @alru_cache(maxsize=10)
+    async def get_matryoshka_dimension_cached(self, kbid: str, vectorset: str) -> int | None:
+        # This can be safely cached as the matryoshka dimension is not expected to change
+        return await get_matryoshka_dimension(kbid, vectorset)
 
 
 @query_parse_dependency_observer.wrap({"type": "query_information"})
 async def query_information(
     kbid: str,
     query: str,
-    semantic_model: Optional[str],
-    generative_model: Optional[str] = None,
+    semantic_model: str | None,
+    generative_model: str | None = None,
     rephrase: bool = False,
-    rephrase_prompt: Optional[str] = None,
-    query_image: Optional[Image] = None,
+    rephrase_prompt: str | None = None,
+    query_image: Image | None = None,
 ) -> QueryInfo:
     predict = get_predict()
     item = QueryModel(
@@ -385,14 +372,8 @@ async def detect_entities(kbid: str, query: str) -> list[utils_pb2.RelationNode]
     return await predict.detect_entities(kbid, query)
 
 
-@alru_cache(maxsize=None)
-async def get_matryoshka_dimension_cached(kbid: str, vectorset: str) -> Optional[int]:
-    # This can be safely cached as the matryoshka dimension is not expected to change
-    return await get_matryoshka_dimension(kbid, vectorset)
-
-
 @query_parse_dependency_observer.wrap({"type": "matryoshka_dimension"})
-async def get_matryoshka_dimension(kbid: str, vectorset: Optional[str]) -> Optional[int]:
+async def get_matryoshka_dimension(kbid: str, vectorset: str | None) -> int | None:
     async with get_driver().ro_transaction() as txn:
         matryoshka_dimension = None
         if not vectorset:
@@ -414,18 +395,6 @@ async def get_classification_labels(kbid: str) -> knowledgebox_pb2.Labels:
 
 
 @query_parse_dependency_observer.wrap({"type": "synonyms"})
-async def get_kb_synonyms(kbid: str) -> Optional[knowledgebox_pb2.Synonyms]:
+async def get_kb_synonyms(kbid: str) -> knowledgebox_pb2.Synonyms | None:
     async with get_driver().ro_transaction() as txn:
         return await datamanagers.synonyms.get(txn, kbid=kbid)
-
-
-@query_parse_dependency_observer.wrap({"type": "entities_meta_cache"})
-async def get_entities_meta_cache(kbid: str) -> datamanagers.entities.EntitiesMetaCache:
-    async with get_driver().ro_transaction() as txn:
-        return await datamanagers.entities.get_entities_meta_cache(txn, kbid=kbid)
-
-
-@query_parse_dependency_observer.wrap({"type": "deleted_entities_groups"})
-async def get_deleted_entity_groups(kbid: str) -> list[str]:
-    async with get_driver().ro_transaction() as txn:
-        return list((await datamanagers.entities.get_deleted_groups(txn, kbid=kbid)).entities_groups)
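The `FetcherCache` fields pair a `NotCached` sentinel with the `TypeIs`-based `is_cached()` helper, so a legitimately cached `None` can be told apart from "not fetched yet". A condensed, self-contained sketch of that pattern (the `Cache`/`synonyms` names here are illustrative):

    from typing import TypeVar
    from typing_extensions import TypeIs

    class NotCached:
        """Sentinel: value not computed yet, distinct from a cached None."""

    not_cached = NotCached()
    T = TypeVar("T")

    def is_cached(field: T | NotCached) -> TypeIs[T]:
        # On a True result, type checkers narrow `field` from T | NotCached to T.
        return not isinstance(field, NotCached)

    class Cache:
        synonyms: str | None | NotCached = not_cached  # hypothetical cached field

    cache = Cache()
    if is_cached(cache.synonyms):
        value: str | None = cache.synonyms  # narrowed: sentinel ruled out
    else:
        ...  # fetch the value and store it in the cache

Note also that `get_matryoshka_dimension_cached` moved from a module-level `@alru_cache(maxsize=None)` function onto the class with `maxsize=10`, which bounds the cache instead of letting it grow without limit.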
nucliadb/search/search/query_parser/models.py
CHANGED
@@ -18,7 +18,6 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 from datetime import datetime
-from typing import Optional, Union
 
 from nidx_protos import nodereader_pb2
 from pydantic import BaseModel, ConfigDict, Field
@@ -26,6 +25,7 @@ from pydantic import BaseModel, ConfigDict, Field
 from nucliadb.search.search.query_parser.fetcher import Fetcher
 from nucliadb_models import search as search_models
 from nucliadb_models.graph.requests import GraphPathQuery
+from nucliadb_models.search import MAX_RANK_FUSION_WINDOW
 from nucliadb_protos import utils_pb2
 
 ### Retrieval
@@ -46,7 +46,7 @@ KeywordQuery = _TextQuery
 
 
 class SemanticQuery(BaseModel):
-    query: Optional[list[float]]
+    query: list[float] | None
     vectorset: str
     min_score: float
 
@@ -66,11 +66,11 @@ class GraphQuery(BaseModel):
 
 
 class Query(BaseModel):
-    fulltext: Optional[FulltextQuery] = None
-    keyword: Optional[KeywordQuery] = None
-    semantic: Optional[SemanticQuery] = None
-    relation: Optional[RelationQuery] = None
-    graph: Optional[GraphQuery] = None
+    fulltext: FulltextQuery | None = None
+    keyword: KeywordQuery | None = None
+    semantic: SemanticQuery | None = None
+    relation: RelationQuery | None = None
+    graph: GraphQuery | None = None
 
 
 # filters
@@ -79,29 +79,28 @@ class Query(BaseModel):
 class Filters(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
-    field_expression: Optional[nodereader_pb2.FilterExpression] = None
-    paragraph_expression: Optional[nodereader_pb2.FilterExpression] = None
+    field_expression: nodereader_pb2.FilterExpression | None = None
+    paragraph_expression: nodereader_pb2.FilterExpression | None = None
     filter_expression_operator: nodereader_pb2.FilterOperator.ValueType = (
         nodereader_pb2.FilterOperator.AND
    )
 
-    autofilter: Optional[list[utils_pb2.RelationNode]] = None
     facets: list[str] = Field(default_factory=list)
-    hidden: Optional[bool] = None
-    security: Optional[search_models.RequestSecurity] = None
+    hidden: bool | None = None
+    security: search_models.RequestSecurity | None = None
     with_duplicates: bool = False
 
 
 class DateTimeFilter(BaseModel):
-    after: Optional[datetime] = None  # aka, start
-    before: Optional[datetime] = None  # aka, end
+    after: datetime | None = None  # aka, start
+    before: datetime | None = None  # aka, end
 
 
 # rank fusion
 
 
 class RankFusion(BaseModel):
-    window: int = Field(le=
+    window: int = Field(le=MAX_RANK_FUSION_WINDOW)
 
 
 class ReciprocalRankFusion(RankFusion):
@@ -122,7 +121,7 @@ class PredictReranker(BaseModel):
     window: int = Field(le=200)
 
 
-Reranker = Union[NoopReranker, PredictReranker]
+Reranker = NoopReranker | PredictReranker
 
 # retrieval and generation operations
 
@@ -131,8 +130,8 @@ class UnitRetrieval(BaseModel):
     query: Query
     top_k: int
     filters: Filters = Field(default_factory=Filters)
-    rank_fusion: Optional[RankFusion] = None
-    reranker: Optional[Reranker] = None
+    rank_fusion: RankFusion | None = None
+    reranker: Reranker | None = None
 
 
 # TODO: augmentation things: hydration...
@@ -141,7 +140,7 @@ class UnitRetrieval(BaseModel):
 class Generation(BaseModel):
     use_visual_llm: bool
     max_context_tokens: int
-    max_answer_tokens: Optional[int]
+    max_answer_tokens: int | None
 
 
 class ParsedQuery(BaseModel):
@@ -149,7 +148,7 @@ class ParsedQuery(BaseModel):
 
     fetcher: Fetcher
     retrieval: UnitRetrieval
-    generation: Optional[Generation] = None
+    generation: Generation | None = None
 
 
 ### Graph
@@ -20,7 +20,6 @@
|
|
|
20
20
|
|
|
21
21
|
from dataclasses import dataclass
|
|
22
22
|
from datetime import datetime
|
|
23
|
-
from typing import Optional, Union
|
|
24
23
|
|
|
25
24
|
from nidx_protos.nodereader_pb2 import FilterExpression
|
|
26
25
|
|
|
@@ -36,19 +35,19 @@ from .fetcher import Fetcher
|
|
|
36
35
|
|
|
37
36
|
@dataclass
|
|
38
37
|
class OldFilterParams:
|
|
39
|
-
label_filters:
|
|
40
|
-
keyword_filters:
|
|
41
|
-
range_creation_start:
|
|
42
|
-
range_creation_end:
|
|
43
|
-
range_modification_start:
|
|
44
|
-
range_modification_end:
|
|
45
|
-
fields:
|
|
46
|
-
key_filters:
|
|
38
|
+
label_filters: list[str] | list[Filter]
|
|
39
|
+
keyword_filters: list[str] | list[Filter]
|
|
40
|
+
range_creation_start: datetime | None = None
|
|
41
|
+
range_creation_end: datetime | None = None
|
|
42
|
+
range_modification_start: datetime | None = None
|
|
43
|
+
range_modification_end: datetime | None = None
|
|
44
|
+
fields: list[str] | None = None
|
|
45
|
+
key_filters: list[str] | None = None
|
|
47
46
|
|
|
48
47
|
|
|
49
48
|
async def parse_old_filters(
|
|
50
49
|
old: OldFilterParams, fetcher: Fetcher
|
|
51
|
-
) -> tuple[
|
|
50
|
+
) -> tuple[FilterExpression | None, FilterExpression | None]:
|
|
52
51
|
filters = []
|
|
53
52
|
paragraph_filter_expression = None
|
|
54
53
|
|
|
@@ -128,6 +127,7 @@ async def parse_old_filters(
|
|
|
128
127
|
f.field.field_type = parts[1]
|
|
129
128
|
if len(parts) > 2:
|
|
130
129
|
f.field.field_id = parts[2]
|
|
130
|
+
expr.bool_and.operands.append(f)
|
|
131
131
|
key_exprs.append(expr)
|
|
132
132
|
|
|
133
133
|
if len(key_exprs) == 1:
|
|
@@ -149,8 +149,8 @@ async def parse_old_filters(
|
|
|
149
149
|
|
|
150
150
|
|
|
151
151
|
def convert_label_filter_to_expressions(
|
|
152
|
-
fltr:
|
|
153
|
-
) -> tuple[
|
|
152
|
+
fltr: str | Filter, classification_labels: knowledgebox_pb2.Labels
|
|
153
|
+
) -> tuple[FilterExpression | None, FilterExpression | None]:
|
|
154
154
|
if isinstance(fltr, str):
|
|
155
155
|
fltr = translate_label(fltr)
|
|
156
156
|
f = FilterExpression()
|
|
@@ -174,7 +174,7 @@ def convert_label_filter_to_expressions(
|
|
|
174
174
|
|
|
175
175
|
def split_labels(
|
|
176
176
|
labels: list[str], classification_labels: knowledgebox_pb2.Labels, combinator: str, negate: bool
|
|
177
|
-
) -> tuple[
|
|
177
|
+
) -> tuple[FilterExpression | None, FilterExpression | None]:
|
|
178
178
|
field = []
|
|
179
179
|
paragraph = []
|
|
180
180
|
for label in labels:
|
|
@@ -223,13 +223,14 @@ def is_paragraph_label(label: str, classification_labels: knowledgebox_pb2.Label
|
|
|
223
223
|
if not label.startswith("/l/"):
|
|
224
224
|
return False
|
|
225
225
|
# Classification labels should have the form /l/labelset/label
|
|
226
|
+
# REVIEW: there's no technical reason why this has to be like this (/l/labelset could be valid)
|
|
226
227
|
parts = label.split("/")
|
|
227
228
|
if len(parts) < 4:
|
|
228
229
|
return False
|
|
229
230
|
labelset_id = parts[2]
|
|
230
231
|
|
|
231
232
|
try:
|
|
232
|
-
labelset:
|
|
233
|
+
labelset: knowledgebox_pb2.LabelSet | None = classification_labels.labelset.get(labelset_id)
|
|
233
234
|
if labelset is None:
|
|
234
235
|
return False
|
|
235
236
|
return knowledgebox_pb2.LabelSet.LabelSetKind.PARAGRAPHS in labelset.kind
|
|
@@ -238,19 +239,19 @@ def is_paragraph_label(label: str, classification_labels: knowledgebox_pb2.Label
|
|
|
238
239
|
return False
|
|
239
240
|
|
|
240
241
|
|
|
241
|
-
def convert_keyword_filter_to_expression(fltr:
|
|
242
|
+
def convert_keyword_filter_to_expression(fltr: str | Filter) -> FilterExpression:
|
|
242
243
|
if isinstance(fltr, str):
|
|
243
244
|
return convert_keyword_to_expression(fltr)
|
|
244
245
|
|
|
245
246
|
f = FilterExpression()
|
|
246
247
|
if fltr.all:
|
|
247
|
-
f.bool_and.operands.extend(
|
|
248
|
+
f.bool_and.operands.extend(convert_keyword_to_expression(f) for f in fltr.all)
|
|
248
249
|
if fltr.any:
|
|
249
|
-
f.bool_or.operands.extend(
|
|
250
|
+
f.bool_or.operands.extend(convert_keyword_to_expression(f) for f in fltr.any)
|
|
250
251
|
if fltr.none:
|
|
251
|
-
f.bool_not.bool_or.operands.extend(
|
|
252
|
+
f.bool_not.bool_or.operands.extend(convert_keyword_to_expression(f) for f in fltr.none)
|
|
252
253
|
if fltr.not_all:
|
|
253
|
-
f.bool_not.bool_and.operands.extend(
|
|
254
|
+
f.bool_not.bool_and.operands.extend(convert_keyword_to_expression(f) for f in fltr.not_all)
|
|
254
255
|
|
|
255
256
|
return f
|
|
256
257
|
|
|
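The one behavioral fix in this file is the added `expr.bool_and.operands.append(f)`: the per-key field sub-expression was previously built but never attached to its parent AND expression. A minimal sketch of attaching operands on a nidx `FilterExpression` (filter values hypothetical):

    from nidx_protos.nodereader_pb2 import FilterExpression

    expr = FilterExpression()

    f = FilterExpression()
    f.field.field_type = "f"          # hypothetical: a field-type filter
    expr.bool_and.operands.append(f)  # without this, `f` is silently dropped

    g = FilterExpression()
    g.facet.facet = "/l/labelset/label"  # hypothetical: a label facet filter
    expr.bool_and.operands.append(g)

    assert len(expr.bool_and.operands) == 2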
nucliadb/search/search/query_parser/parsers/ask.py
CHANGED
@@ -17,7 +17,8 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
-from typing import Optional
+
+from typing_extensions import assert_never
 
 from nucliadb.search.search.query_parser.fetcher import Fetcher
 from nucliadb.search.search.query_parser.models import (
@@ -26,7 +27,7 @@ from nucliadb.search.search.query_parser.models import (
 from nucliadb_models.search import AskRequest, MaxTokens
 
 
-async def parse_ask(kbid: str, item: AskRequest, *, fetcher: Optional[Fetcher] = None) -> Generation:
+async def parse_ask(kbid: str, item: AskRequest, *, fetcher: Fetcher | None = None) -> Generation:
     fetcher = fetcher or fetcher_for_ask(kbid, item)
     parser = _AskParser(kbid, item, fetcher)
     return await parser.parse()
@@ -64,9 +65,7 @@ class _AskParser:
         elif isinstance(self.item.max_tokens, MaxTokens):
             max_tokens = self.item.max_tokens
         else:  # pragma: no cover
-            # This is a trick so mypy generates an error if this branch can be reached,
-            # that is, if we are missing some ifs
-            _a: int = "a"
+            assert_never(self.item.max_tokens)
 
         max_context_tokens = await self.fetcher.get_max_context_tokens(max_tokens)
         max_answer_tokens = self.fetcher.get_max_answer_tokens(max_tokens)