nucliadb 6.3.7.post4066__py3-none-any.whl → 6.3.7.post4068__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,77 @@
1
+ # Copyright (C) 2021 Bosutech XXI S.L.
2
+ #
3
+ # nucliadb is offered under the AGPL v3.0 and as commercial software.
4
+ # For commercial licensing, contact us at info@nuclia.com.
5
+ #
6
+ # AGPL:
7
+ # This program is free software: you can redistribute it and/or modify
8
+ # it under the terms of the GNU Affero General Public License as
9
+ # published by the Free Software Foundation, either version 3 of the
10
+ # License, or (at your option) any later version.
11
+ #
12
+ # This program is distributed in the hope that it will be useful,
13
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
14
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15
+ # GNU Affero General Public License for more details.
16
+ #
17
+ # You should have received a copy of the GNU Affero General Public License
18
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
+ #
20
+ from typing import Optional
21
+
22
+ from nucliadb.search.search.query_parser.fetcher import Fetcher
23
+ from nucliadb.search.search.query_parser.models import (
24
+ Generation,
25
+ )
26
+ from nucliadb_models.search import AskRequest, MaxTokens
27
+
28
+
29
async def parse_ask(kbid: str, item: AskRequest, *, fetcher: Optional[Fetcher] = None) -> Generation:
    """Parse the generation options of an ask request into a `Generation`.

    A pre-built ``fetcher`` can be injected (e.g. to share cached predict
    data); otherwise a default one is created with `fetcher_for_ask`.
    """
    fetcher = fetcher or fetcher_for_ask(kbid, item)
    return await _AskParser(kbid, item, fetcher).parse()
33
+
34
+
35
def fetcher_for_ask(kbid: str, item: AskRequest) -> Fetcher:
    """Build the default `Fetcher` used while parsing an ask request."""
    params = dict(
        kbid=kbid,
        query=item.query,
        # ask requests never carry a precomputed user vector nor a rephrase prompt
        user_vector=None,
        vectorset=item.vectorset,
        rephrase=item.rephrase,
        rephrase_prompt=None,
        generative_model=item.generative_model,
    )
    return Fetcher(**params)
45
+
46
+
47
class _AskParser:
    """Turns an `AskRequest` into the `Generation` parameters for the LLM call."""

    def __init__(self, kbid: str, item: AskRequest, fetcher: Fetcher):
        self.kbid = kbid
        self.item = item
        self.fetcher = fetcher

    async def parse(self) -> Generation:
        """Resolve visual-LLM usage and token limits for this request."""
        use_visual_llm = await self.fetcher.get_visual_llm_enabled()
        max_tokens = self._normalize_max_tokens()
        return Generation(
            use_visual_llm=use_visual_llm,
            max_context_tokens=await self.fetcher.get_max_context_tokens(max_tokens),
            max_answer_tokens=self.fetcher.get_max_answer_tokens(max_tokens),
        )

    def _normalize_max_tokens(self) -> Optional[MaxTokens]:
        """Normalize the request's max_tokens (None | int | MaxTokens) to MaxTokens."""
        requested = self.item.max_tokens
        if requested is None:
            return None
        if isinstance(requested, int):
            # A bare int only limits the answer, not the context
            return MaxTokens(context=None, answer=requested)
        if isinstance(requested, MaxTokens):
            return requested
        # pragma: nocover
        # This is a trick so mypy generates an error if this branch can be reached,
        # that is, if we are missing some ifs
        _a: int = "a"
@@ -0,0 +1,189 @@
1
+ # Copyright (C) 2021 Bosutech XXI S.L.
2
+ #
3
+ # nucliadb is offered under the AGPL v3.0 and as commercial software.
4
+ # For commercial licensing, contact us at info@nuclia.com.
5
+ #
6
+ # AGPL:
7
+ # This program is free software: you can redistribute it and/or modify
8
+ # it under the terms of the GNU Affero General Public License as
9
+ # published by the Free Software Foundation, either version 3 of the
10
+ # License, or (at your option) any later version.
11
+ #
12
+ # This program is distributed in the hope that it will be useful,
13
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
14
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15
+ # GNU Affero General Public License for more details.
16
+ #
17
+ # You should have received a copy of the GNU Affero General Public License
18
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
+ #
20
+ import re
21
+ import string
22
+ from typing import Optional, Union
23
+
24
+ from nucliadb.search import logger
25
+ from nucliadb.search.search.query_parser.exceptions import InvalidQueryError
26
+ from nucliadb.search.search.query_parser.fetcher import Fetcher
27
+ from nucliadb.search.search.query_parser.models import (
28
+ KeywordQuery,
29
+ SemanticQuery,
30
+ )
31
+ from nucliadb.search.search.utils import should_disable_vector_search
32
+ from nucliadb_models import search as search_models
33
+
34
# Generic fallback semantic threshold, used when neither the user nor the
# Predict API provide one (see parse_semantic_min_score below)
DEFAULT_GENERIC_SEMANTIC_THRESHOLD = 0.7

# -* is an invalid query in tantivy and it won't return results but if you add some whitespaces
# between - and *, it will actually trigger a tantivy bug and panic
INVALID_QUERY = re.compile(r"- +\*")
39
+
40
+
41
def validate_base_request(item: search_models.BaseSearchRequest):
    """Validate a search request, raising `InvalidQueryError` on bad input.

    Also mutates ``item.features``: semantic search is dropped when it must
    be disabled for this request.
    """
    # Reject queries known to panic tantivy, better than returning a 500
    if INVALID_QUERY.search(item.query):
        raise InvalidQueryError("query", "Invalid query syntax")

    wants_semantic = search_models.SearchOptions.SEMANTIC in item.features
    wants_relations = search_models.SearchOptions.RELATIONS in item.features

    # synonyms are not compatible with vector/graph search
    if item.with_synonyms and item.query and (wants_semantic or wants_relations):
        raise InvalidQueryError(
            "synonyms",
            "Search with custom synonyms is only supported on paragraph and document search",
        )

    if wants_semantic and should_disable_vector_search(item):
        item.features.remove(search_models.SearchOptions.SEMANTIC)
63
+
64
+
65
def parse_top_k(item: search_models.BaseSearchRequest) -> int:
    """Return the request's top_k; it is expected to be defaulted upstream."""
    assert item.top_k is not None, "top_k must have an int value"
    return item.top_k
69
+
70
+
71
async def parse_keyword_query(
    item: search_models.BaseSearchRequest,
    *,
    fetcher: Fetcher,
) -> KeywordQuery:
    """Build the keyword (bm25) part of a retrieval query.

    When the request enables synonyms and an expanded query can be built,
    the expanded query replaces the original one.
    """
    expanded: Optional[str] = None
    if item.with_synonyms:
        expanded = await query_with_synonyms(item.query, fetcher=fetcher)

    return KeywordQuery(
        query=item.query if expanded is None else expanded,
        is_synonyms_query=expanded is not None,
        min_score=parse_keyword_min_score(item.min_score),
    )
92
+
93
+
94
async def parse_semantic_query(
    item: search_models.BaseSearchRequest,
    *,
    fetcher: Fetcher,
) -> SemanticQuery:
    """Build the semantic (vector) part of a retrieval query."""
    # Resolve the vectorset first, then the query vector computed against it
    vectorset = await fetcher.get_vectorset()
    query_vector = await fetcher.get_query_vector()
    threshold = await parse_semantic_min_score(item.min_score, fetcher=fetcher)
    return SemanticQuery(query=query_vector, vectorset=vectorset, min_score=threshold)
105
+
106
+
107
def parse_keyword_min_score(
    min_score: Optional[Union[float, search_models.MinScore]],
) -> float:
    """Return the bm25 threshold for keyword search.

    The deprecated float form of ``min_score`` used to specify the *semantic*
    threshold, so it carries no bm25 value and falls back to 0.
    """
    if min_score is None or isinstance(min_score, float):
        return 0.0
    return min_score.bm25
116
+
117
+
118
async def parse_semantic_min_score(
    min_score: Optional[Union[float, search_models.MinScore]],
    *,
    fetcher: Fetcher,
) -> float:
    """Resolve the semantic similarity threshold for a request.

    Precedence: explicit user value (deprecated float payload or
    ``MinScore.semantic``) > default fetched from the Predict API > generic
    fallback (`DEFAULT_GENERIC_SEMANTIC_THRESHOLD`). Always returns a float.
    """
    threshold: Optional[float]
    if isinstance(min_score, search_models.MinScore):
        threshold = min_score.semantic
    else:
        # None, or the deprecated float payload (which meant semantic)
        threshold = min_score

    if threshold is None:
        # min score not defined by the user, we'll try to get the default
        # from Predict API
        threshold = await fetcher.get_semantic_min_score()
        if threshold is None:
            logger.warning(
                "Semantic threshold not found in query information, using default",
                extra={"kbid": fetcher.kbid},
            )
            threshold = DEFAULT_GENERIC_SEMANTIC_THRESHOLD

    return threshold
142
+
143
+
144
async def query_with_synonyms(
    query: str,
    *,
    fetcher: Fetcher,
) -> Optional[str]:
    """
    Replace the terms in the query with an expression that will make it match with the configured synonyms.
    We're using the Tantivy's query language here: https://docs.rs/tantivy/latest/tantivy/query/struct.QueryParser.html

    Example:
    - Synonyms: Foo -> Bar, Baz
    - Query: "What is Foo?"
    - Advanced Query: "What is (Foo OR Bar OR Baz)?"

    Returns None when there's nothing to expand (empty query, no synonyms
    configured, or no query term has synonyms).
    """
    if not query:
        return None

    kb_synonyms = await fetcher.get_synonyms()
    if kb_synonyms is None:
        # No synonyms found
        return None

    # Map each term that has synonyms to its OR-expression:
    # 'term' -> '(term OR synonym1 OR synonym2)'
    expansions: dict[str, str] = {
        term: "({})".format(" OR ".join([term] + list(entry.synonyms)))
        for term, entry in kb_synonyms.terms.items()
        if len(entry.synonyms) > 0
    }

    tokens = query.split()
    expanded_any = False
    for i, token in enumerate(tokens):
        # Strip punctuation so e.g. 'Foo?' still matches the synonym term 'Foo'
        bare = token.strip(string.punctuation)
        if bare in expansions:
            expanded_any = True
            # Swap only the bare term, keeping any surrounding punctuation
            tokens[i] = token.replace(bare, expansions[bare])

    if expanded_any:
        return " ".join(tokens)

    return None
@@ -18,37 +18,103 @@
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
20
 
21
+ from typing import Optional
22
+
21
23
  from pydantic import ValidationError
22
24
 
25
+ from nucliadb.common.models_utils.from_proto import RelationNodeTypeMap
23
26
  from nucliadb.search.search.metrics import query_parser_observer
24
- from nucliadb.search.search.query_parser.exceptions import InternalParserError
27
+ from nucliadb.search.search.query import expand_entities
28
+ from nucliadb.search.search.query_parser.exceptions import InternalParserError, InvalidQueryError
29
+ from nucliadb.search.search.query_parser.fetcher import Fetcher
30
+ from nucliadb.search.search.query_parser.filter_expression import parse_expression
25
31
  from nucliadb.search.search.query_parser.models import (
32
+ Filters,
26
33
  NoopReranker,
34
+ ParsedQuery,
27
35
  PredictReranker,
36
+ Query,
28
37
  RankFusion,
29
38
  ReciprocalRankFusion,
39
+ RelationQuery,
30
40
  Reranker,
31
41
  UnitRetrieval,
32
42
  )
43
+ from nucliadb.search.search.query_parser.old_filters import OldFilterParams, parse_old_filters
44
+ from nucliadb.search.search.utils import filter_hidden_resources
33
45
  from nucliadb_models import search as search_models
46
+ from nucliadb_models.filters import FilterExpression
34
47
  from nucliadb_models.search import (
35
48
  FindRequest,
36
49
  )
50
+ from nucliadb_protos import nodereader_pb2, utils_pb2
51
+
52
+ from .common import (
53
+ parse_keyword_query,
54
+ parse_semantic_query,
55
+ parse_top_k,
56
+ validate_base_request,
57
+ )
37
58
 
38
59
 
39
60
@query_parser_observer.wrap({"type": "parse_find"})
async def parse_find(
    kbid: str,
    item: FindRequest,
    generative_model: Optional[str] = None,
    *,
    fetcher: Optional[Fetcher] = None,
) -> ParsedQuery:
    """Parse a find request into a `ParsedQuery` (retrieval only, no generation).

    A pre-built ``fetcher`` can be injected; otherwise a default one is
    created with `fetcher_for_find`.
    """
    fetcher = fetcher or fetcher_for_find(kbid, item, generative_model)
    retrieval = await _FindParser(kbid, item, fetcher).parse()
    return ParsedQuery(fetcher=fetcher, retrieval=retrieval, generation=None)
72
+
73
+
74
def fetcher_for_find(kbid: str, item: FindRequest, generative_model: Optional[str]) -> Fetcher:
    """Build the default `Fetcher` used while parsing a find request."""
    return Fetcher(
        kbid=kbid,
        query=item.query,
        user_vector=item.vector,
        vectorset=item.vectorset,
        rephrase=item.rephrase,
        rephrase_prompt=item.rephrase_prompt,
        generative_model=generative_model,
    )
43
84
 
44
85
 
45
86
  class _FindParser:
46
    def __init__(self, kbid: str, item: FindRequest, fetcher: Fetcher):
        self.kbid = kbid
        self.item = item
        # Fetcher shared across parsing steps; it caches predict/KB lookups
        self.fetcher = fetcher

        # cached data while parsing (filled in by parse())
        self._query: Optional[Query] = None
        self._top_k: Optional[int] = None
49
95
 
50
96
  async def parse(self) -> UnitRetrieval:
51
- top_k = self._parse_top_k()
97
+ validate_base_request(self.item)
98
+
99
+ self._top_k = parse_top_k(self.item)
100
+
101
+ # parse search types (features)
102
+
103
+ self._query = Query()
104
+
105
+ if search_models.SearchOptions.KEYWORD in self.item.features:
106
+ self._query.keyword = await parse_keyword_query(self.item, fetcher=self.fetcher)
107
+
108
+ if search_models.SearchOptions.SEMANTIC in self.item.features:
109
+ self._query.semantic = await parse_semantic_query(self.item, fetcher=self.fetcher)
110
+
111
+ if search_models.SearchOptions.RELATIONS in self.item.features:
112
+ self._query.relation = await self._parse_relation_query()
113
+
114
+ # TODO: graph search
115
+
116
+ filters = await self._parse_filters()
117
+
52
118
  try:
53
119
  rank_fusion = self._parse_rank_fusion()
54
120
  except ValidationError as exc:
@@ -66,20 +132,116 @@ class _FindParser:
66
132
  rank_fusion.window = max(rank_fusion.window, reranker.window)
67
133
 
68
134
  return UnitRetrieval(
69
- top_k=top_k,
135
+ query=self._query,
136
+ top_k=self._top_k,
137
+ filters=filters,
70
138
  rank_fusion=rank_fusion,
71
139
  reranker=reranker,
72
140
  )
73
141
 
74
- def _parse_top_k(self) -> int:
75
- assert self.item.top_k is not None, "top_k must have an int value"
76
- top_k = self.item.top_k
77
- return top_k
142
+ async def _parse_relation_query(self) -> RelationQuery:
143
+ detected_entities = await self._get_detected_entities()
144
+
145
+ deleted_entity_groups = await self.fetcher.get_deleted_entity_groups()
146
+
147
+ meta_cache = await self.fetcher.get_entities_meta_cache()
148
+ deleted_entities = meta_cache.deleted_entities
149
+
150
+ return RelationQuery(
151
+ detected_entities=detected_entities,
152
+ deleted_entity_groups=deleted_entity_groups,
153
+ deleted_entities=deleted_entities,
154
+ )
155
+
156
+ async def _get_detected_entities(self) -> list[utils_pb2.RelationNode]:
157
+ """Get entities from request, either automatically detected or
158
+ explicitly set by the user."""
159
+
160
+ if self.item.query_entities:
161
+ detected_entities = []
162
+ for entity in self.item.query_entities:
163
+ relation_node = utils_pb2.RelationNode()
164
+ relation_node.value = entity.name
165
+ if entity.type is not None:
166
+ relation_node.ntype = RelationNodeTypeMap[entity.type]
167
+ if entity.subtype is not None:
168
+ relation_node.subtype = entity.subtype
169
+ detected_entities.append(relation_node)
170
+ else:
171
+ detected_entities = await self.fetcher.get_detected_entities()
172
+
173
+ meta_cache = await self.fetcher.get_entities_meta_cache()
174
+ detected_entities = expand_entities(meta_cache, detected_entities)
175
+
176
+ return detected_entities
177
+
178
    async def _parse_filters(self) -> Filters:
        """Build the `Filters` for the retrieval, merging the deprecated
        flat filter parameters with the newer ``filter_expression``.

        Requires `parse()` to have populated ``self._query`` first (the
        relation query is reused for autofilter).
        """
        assert self._query is not None, "query must be parsed before filters"

        # Any of the deprecated filtering parameters present?
        has_old_filters = (
            len(self.item.filters) > 0
            or len(self.item.resource_filters) > 0
            or len(self.item.fields) > 0
            or len(self.item.keyword_filters) > 0
            or self.item.range_creation_start is not None
            or self.item.range_creation_end is not None
            or self.item.range_modification_start is not None
            or self.item.range_modification_end is not None
        )
        # Old-style filters and the new filter_expression are mutually exclusive
        if self.item.filter_expression is not None and has_old_filters:
            raise InvalidQueryError("filter_expression", "Cannot mix old filters with filter_expression")

        field_expr = None
        paragraph_expr = None
        filter_operator = nodereader_pb2.FilterOperator.AND

        if has_old_filters:
            # Translate the deprecated parameters into field/paragraph expressions
            old_filters = OldFilterParams(
                label_filters=self.item.filters,
                keyword_filters=self.item.keyword_filters,
                range_creation_start=self.item.range_creation_start,
                range_creation_end=self.item.range_creation_end,
                range_modification_start=self.item.range_modification_start,
                range_modification_end=self.item.range_modification_end,
                fields=self.item.fields,
                key_filters=self.item.resource_filters,
            )
            field_expr, paragraph_expr = await parse_old_filters(old_filters, self.fetcher)

        if self.item.filter_expression is not None:
            if self.item.filter_expression.field:
                field_expr = await parse_expression(self.item.filter_expression.field, self.kbid)
            if self.item.filter_expression.paragraph:
                paragraph_expr = await parse_expression(self.item.filter_expression.paragraph, self.kbid)
            # Operator only applies when combining field and paragraph expressions
            if self.item.filter_expression.operator == FilterExpression.Operator.OR:
                filter_operator = nodereader_pb2.FilterOperator.OR
            else:
                filter_operator = nodereader_pb2.FilterOperator.AND

        autofilter = None
        if self.item.autofilter:
            # Reuse entities already detected for the relation query when available
            if self._query.relation is not None:
                autofilter = self._query.relation.detected_entities
            else:
                autofilter = await self._get_detected_entities()

        hidden = await filter_hidden_resources(self.kbid, self.item.show_hidden)

        return Filters(
            autofilter=autofilter,
            facets=[],
            field_expression=field_expr,
            paragraph_expression=paragraph_expr,
            filter_expression_operator=filter_operator,
            security=self.item.security,
            hidden=hidden,
            with_duplicates=self.item.with_duplicates,
        )
239
+ )
78
240
 
79
241
  def _parse_rank_fusion(self) -> RankFusion:
80
242
  rank_fusion: RankFusion
81
243
 
82
- top_k = self._parse_top_k()
244
+ top_k = parse_top_k(self.item)
83
245
  window = min(top_k, 500)
84
246
 
85
247
  if isinstance(self.item.rank_fusion, search_models.RankFusionName):
@@ -104,7 +266,7 @@ class _FindParser:
104
266
  def _parse_reranker(self) -> Reranker:
105
267
  reranking: Reranker
106
268
 
107
- top_k = self._parse_top_k()
269
+ top_k = parse_top_k(self.item)
108
270
 
109
271
  if isinstance(self.item.reranker, search_models.RerankerName):
110
272
  if self.item.reranker == search_models.RerankerName.NOOP: