nucliadb 6.2.0.post2679-py3-none-any.whl → 6.2.1-py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
Files changed (105)
  1. migrations/0028_extracted_vectors_reference.py +61 -0
  2. migrations/0029_backfill_field_status.py +149 -0
  3. migrations/0030_label_deduplication.py +60 -0
  4. nucliadb/common/cluster/manager.py +41 -331
  5. nucliadb/common/cluster/rebalance.py +2 -2
  6. nucliadb/common/cluster/rollover.py +12 -71
  7. nucliadb/common/cluster/settings.py +3 -0
  8. nucliadb/common/cluster/standalone/utils.py +0 -43
  9. nucliadb/common/cluster/utils.py +0 -16
  10. nucliadb/common/counters.py +1 -0
  11. nucliadb/common/datamanagers/fields.py +48 -7
  12. nucliadb/common/datamanagers/vectorsets.py +11 -2
  13. nucliadb/common/external_index_providers/base.py +2 -1
  14. nucliadb/common/external_index_providers/pinecone.py +3 -5
  15. nucliadb/common/ids.py +18 -4
  16. nucliadb/common/models_utils/from_proto.py +479 -0
  17. nucliadb/common/models_utils/to_proto.py +60 -0
  18. nucliadb/common/nidx.py +76 -37
  19. nucliadb/export_import/models.py +3 -3
  20. nucliadb/health.py +0 -7
  21. nucliadb/ingest/app.py +0 -8
  22. nucliadb/ingest/consumer/auditing.py +1 -1
  23. nucliadb/ingest/consumer/shard_creator.py +1 -1
  24. nucliadb/ingest/fields/base.py +83 -21
  25. nucliadb/ingest/orm/brain.py +55 -56
  26. nucliadb/ingest/orm/broker_message.py +12 -2
  27. nucliadb/ingest/orm/entities.py +6 -17
  28. nucliadb/ingest/orm/knowledgebox.py +44 -22
  29. nucliadb/ingest/orm/processor/data_augmentation.py +7 -29
  30. nucliadb/ingest/orm/processor/processor.py +5 -2
  31. nucliadb/ingest/orm/resource.py +222 -413
  32. nucliadb/ingest/processing.py +8 -2
  33. nucliadb/ingest/serialize.py +77 -46
  34. nucliadb/ingest/service/writer.py +2 -56
  35. nucliadb/ingest/settings.py +1 -4
  36. nucliadb/learning_proxy.py +6 -4
  37. nucliadb/purge/__init__.py +102 -12
  38. nucliadb/purge/orphan_shards.py +6 -4
  39. nucliadb/reader/api/models.py +3 -3
  40. nucliadb/reader/api/v1/__init__.py +1 -0
  41. nucliadb/reader/api/v1/download.py +2 -2
  42. nucliadb/reader/api/v1/knowledgebox.py +3 -3
  43. nucliadb/reader/api/v1/resource.py +23 -12
  44. nucliadb/reader/api/v1/services.py +4 -4
  45. nucliadb/reader/api/v1/vectorsets.py +48 -0
  46. nucliadb/search/api/v1/ask.py +11 -1
  47. nucliadb/search/api/v1/feedback.py +3 -3
  48. nucliadb/search/api/v1/knowledgebox.py +8 -13
  49. nucliadb/search/api/v1/search.py +3 -2
  50. nucliadb/search/api/v1/suggest.py +0 -2
  51. nucliadb/search/predict.py +6 -4
  52. nucliadb/search/requesters/utils.py +1 -2
  53. nucliadb/search/search/chat/ask.py +77 -13
  54. nucliadb/search/search/chat/prompt.py +16 -5
  55. nucliadb/search/search/chat/query.py +74 -34
  56. nucliadb/search/search/exceptions.py +2 -7
  57. nucliadb/search/search/find.py +9 -5
  58. nucliadb/search/search/find_merge.py +10 -4
  59. nucliadb/search/search/graph_strategy.py +884 -0
  60. nucliadb/search/search/hydrator.py +6 -0
  61. nucliadb/search/search/merge.py +79 -24
  62. nucliadb/search/search/query.py +74 -245
  63. nucliadb/search/search/query_parser/exceptions.py +11 -1
  64. nucliadb/search/search/query_parser/fetcher.py +405 -0
  65. nucliadb/search/search/query_parser/models.py +0 -3
  66. nucliadb/search/search/query_parser/parser.py +22 -21
  67. nucliadb/search/search/rerankers.py +1 -42
  68. nucliadb/search/search/shards.py +19 -0
  69. nucliadb/standalone/api_router.py +2 -14
  70. nucliadb/standalone/settings.py +4 -0
  71. nucliadb/train/generators/field_streaming.py +7 -3
  72. nucliadb/train/lifecycle.py +3 -6
  73. nucliadb/train/nodes.py +14 -12
  74. nucliadb/train/resource.py +380 -0
  75. nucliadb/writer/api/constants.py +20 -16
  76. nucliadb/writer/api/v1/__init__.py +1 -0
  77. nucliadb/writer/api/v1/export_import.py +1 -1
  78. nucliadb/writer/api/v1/field.py +13 -7
  79. nucliadb/writer/api/v1/knowledgebox.py +3 -46
  80. nucliadb/writer/api/v1/resource.py +20 -13
  81. nucliadb/writer/api/v1/services.py +10 -1
  82. nucliadb/writer/api/v1/upload.py +61 -34
  83. nucliadb/writer/{vectorsets.py → api/v1/vectorsets.py} +99 -47
  84. nucliadb/writer/back_pressure.py +17 -46
  85. nucliadb/writer/resource/basic.py +9 -7
  86. nucliadb/writer/resource/field.py +42 -9
  87. nucliadb/writer/settings.py +2 -2
  88. nucliadb/writer/tus/gcs.py +11 -10
  89. {nucliadb-6.2.0.post2679.dist-info → nucliadb-6.2.1.dist-info}/METADATA +11 -14
  90. {nucliadb-6.2.0.post2679.dist-info → nucliadb-6.2.1.dist-info}/RECORD +94 -96
  91. {nucliadb-6.2.0.post2679.dist-info → nucliadb-6.2.1.dist-info}/WHEEL +1 -1
  92. nucliadb/common/cluster/discovery/base.py +0 -178
  93. nucliadb/common/cluster/discovery/k8s.py +0 -301
  94. nucliadb/common/cluster/discovery/manual.py +0 -57
  95. nucliadb/common/cluster/discovery/single.py +0 -51
  96. nucliadb/common/cluster/discovery/types.py +0 -32
  97. nucliadb/common/cluster/discovery/utils.py +0 -67
  98. nucliadb/common/cluster/standalone/grpc_node_binding.py +0 -349
  99. nucliadb/common/cluster/standalone/index_node.py +0 -123
  100. nucliadb/common/cluster/standalone/service.py +0 -84
  101. nucliadb/standalone/introspect.py +0 -208
  102. nucliadb-6.2.0.post2679.dist-info/zip-safe +0 -1
  103. /nucliadb/common/{cluster/discovery → models_utils}/__init__.py +0 -0
  104. {nucliadb-6.2.0.post2679.dist-info → nucliadb-6.2.1.dist-info}/entry_points.txt +0 -0
  105. {nucliadb-6.2.0.post2679.dist-info → nucliadb-6.2.1.dist-info}/top_level.txt +0 -0
nucliadb/train/generators/field_streaming.py CHANGED
@@ -97,9 +97,13 @@ async def generate_field_streaming_payloads(
         tl.field = field
         tl.field_type = field_type
         tl.split = split
-        extracted = await get_field_text(kbid, rid, field, field_type)
-        if extracted is not None:
-            tl.text.CopyFrom(extracted)
+
+        if trainset.exclude_text:
+            tl.text.text = ""
+        else:
+            extracted = await get_field_text(kbid, rid, field, field_type)
+            if extracted is not None:
+                tl.text.CopyFrom(extracted)
 
         metadata_obj = await get_field_metadata(kbid, rid, field, field_type)
         if metadata_obj is not None:
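The new branch makes text extraction opt-out: when a trainset sets exclude_text, the payload carries an empty string and the extracted-text fetch is skipped entirely. A minimal sketch of the gating logic, with plain dataclasses standing in for the protobuf types (TextBlock and fetch_text are illustrative names, not nucliadb APIs):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class TextBlock:  # stand-in for the protobuf text submessage
        text: str = ""

    @dataclass
    class TrainSet:  # stand-in; only the flag shown in the diff
        exclude_text: bool = False

    async def fetch_text(field_id: str) -> Optional[str]:  # hypothetical fetcher
        return f"extracted text for {field_id}"

    async def build_payload(trainset: TrainSet, field_id: str) -> TextBlock:
        tl = TextBlock()
        if trainset.exclude_text:
            tl.text = ""  # caller opted out: no storage round-trip
        else:
            extracted = await fetch_text(field_id)
            if extracted is not None:
                tl.text = extracted
        return tl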
nucliadb/train/lifecycle.py CHANGED
@@ -22,10 +22,7 @@ from contextlib import asynccontextmanager
 
 from fastapi import FastAPI
 
-from nucliadb.common.cluster.discovery.utils import (
-    setup_cluster_discovery,
-    teardown_cluster_discovery,
-)
+from nucliadb.common.nidx import start_nidx_utility, stop_nidx_utility
 from nucliadb.train import SERVICE_NAME
 from nucliadb.train.utils import (
     start_shard_manager,
@@ -40,7 +37,7 @@ from nucliadb_utils.utilities import start_audit_utility, stop_audit_utility
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     await setup_telemetry(SERVICE_NAME)
-    await setup_cluster_discovery()
+    await start_nidx_utility()
     await start_shard_manager()
     await start_train_grpc(SERVICE_NAME)
     await start_audit_utility(SERVICE_NAME)
@@ -50,5 +47,5 @@ async def lifespan(app: FastAPI):
     await stop_audit_utility()
     await stop_train_grpc()
     await stop_shard_manager()
-    await teardown_cluster_discovery()
+    await stop_nidx_utility()
     await clean_telemetry(SERVICE_NAME)
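The swap keeps the lifespan symmetric: start_nidx_utility() takes the slot setup_cluster_discovery() held during startup, and stop_nidx_utility() replaces teardown_cluster_discovery() in the mirrored shutdown sequence after the yield. For reference, a minimal sketch of FastAPI's lifespan pattern (start_x/stop_x are placeholder utilities):

    from contextlib import asynccontextmanager

    from fastapi import FastAPI

    async def start_x() -> None: ...  # placeholder startup utility
    async def stop_x() -> None: ...  # placeholder shutdown utility

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        await start_x()  # runs once, before requests are served
        yield  # the application handles traffic here
        await stop_x()  # runs once, on shutdown, mirroring setup in reverse

    app = FastAPI(lifespan=lifespan)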
nucliadb/train/nodes.py CHANGED
@@ -28,6 +28,12 @@ from nucliadb.common.datamanagers.resources import KB_RESOURCE_SLUG_BASE
 from nucliadb.common.maindb.driver import Driver, Transaction
 from nucliadb.ingest.orm.entities import EntitiesManager
 from nucliadb.ingest.orm.knowledgebox import KnowledgeBox
+from nucliadb.train.resource import (
+    generate_train_resource,
+    iterate_fields,
+    iterate_paragraphs,
+    iterate_sentences,
+)
 from nucliadb_protos.train_pb2 import (
     GetFieldsRequest,
     GetParagraphsRequest,
@@ -39,9 +45,7 @@ from nucliadb_protos.train_pb2 import (
     TrainSentence,
 )
 from nucliadb_protos.writer_pb2 import ShardObject
-from nucliadb_utils import const
 from nucliadb_utils.storages.storage import Storage
-from nucliadb_utils.utilities import has_feature
 
 
 class TrainShardManager(manager.KBShardManager):
@@ -57,9 +61,7 @@ class TrainShardManager(manager.KBShardManager):
         except StopIteration:
             raise KeyError("Shard not found")
 
-        node_obj, shard_id = manager.choose_node(
-            shard_object, use_nidx=has_feature(const.Features.NIDX_READS, context={"kbid": kbid})
-        )
+        node_obj, shard_id = manager.choose_node(shard_object)
         return node_obj, shard_id
 
     async def get_kb_obj(self, txn: Transaction, kbid: str) -> Optional[KnowledgeBox]:
@@ -87,11 +89,11 @@ class TrainShardManager(manager.KBShardManager):
         # Filter by uuid
         resource = await kb.get(request.uuid)
         if resource:
-            async for sentence in resource.iterate_sentences(request.metadata):
+            async for sentence in iterate_sentences(resource, request.metadata):
                 yield sentence
         else:
             async for resource in kb.iterate_resources():
-                async for sentence in resource.iterate_sentences(request.metadata):
+                async for sentence in iterate_sentences(resource, request.metadata):
                     yield sentence
 
     async def kb_paragraphs(self, request: GetParagraphsRequest) -> AsyncIterator[TrainParagraph]:
@@ -101,11 +103,11 @@ class TrainShardManager(manager.KBShardManager):
         # Filter by uuid
         resource = await kb.get(request.uuid)
         if resource:
-            async for paragraph in resource.iterate_paragraphs(request.metadata):
+            async for paragraph in iterate_paragraphs(resource, request.metadata):
                 yield paragraph
         else:
             async for resource in kb.iterate_resources():
-                async for paragraph in resource.iterate_paragraphs(request.metadata):
+                async for paragraph in iterate_paragraphs(resource, request.metadata):
                     yield paragraph
 
     async def kb_fields(self, request: GetFieldsRequest) -> AsyncIterator[TrainField]:
@@ -115,11 +117,11 @@ class TrainShardManager(manager.KBShardManager):
         # Filter by uuid
         resource = await kb.get(request.uuid)
         if resource:
-            async for field in resource.iterate_fields(request.metadata):
+            async for field in iterate_fields(resource, request.metadata):
                 yield field
         else:
             async for resource in kb.iterate_resources():
-                async for field in resource.iterate_fields(request.metadata):
+                async for field in iterate_fields(resource, request.metadata):
                     yield field
 
     async def kb_resources(self, request: GetResourcesRequest) -> AsyncIterator[TrainResource]:
@@ -132,4 +134,4 @@ class TrainShardManager(manager.KBShardManager):
         if rid is not None:
            resource = await kb.get(rid.decode())
             if resource is not None:
-                yield await resource.generate_train_resource(request.metadata)
+                yield await generate_train_resource(resource, request.metadata)
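Every hunk in this file is the same mechanical refactor: the train-only iteration helpers move off the Resource ORM class into the new nucliadb.train.resource module (next file) as free functions that take the resource as their first argument. A toy version of the resulting shape (the Resource stand-in below is illustrative):

    from typing import AsyncIterator

    class Resource:  # minimal stand-in for the ORM class
        def __init__(self, sentences: list[str]) -> None:
            self.sentences = sentences

    # After the refactor, training behavior lives in a free function that
    # receives the resource, instead of a method defined on Resource itself.
    async def iterate_sentences(resource: Resource, metadata: dict) -> AsyncIterator[str]:
        for sentence in resource.sentences:
            yield sentence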
nucliadb/train/resource.py ADDED
@@ -0,0 +1,380 @@
+# Copyright (C) 2021 Bosutech XXI S.L.
+#
+# nucliadb is offered under the AGPL v3.0 and as commercial software.
+# For commercial licensing, contact us at info@nuclia.com.
+#
+# AGPL:
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+#
+from __future__ import annotations
+
+from typing import AsyncIterator, MutableMapping, Optional
+
+from nucliadb.common import datamanagers
+from nucliadb.ingest.orm.resource import Resource
+from nucliadb_protos.resources_pb2 import (
+    FieldID,
+    FieldMetadata,
+    ParagraphAnnotation,
+)
+from nucliadb_protos.train_pb2 import (
+    EnabledMetadata,
+    TrainField,
+    TrainMetadata,
+    TrainParagraph,
+    TrainResource,
+    TrainSentence,
+)
+from nucliadb_protos.train_pb2 import Position as TrainPosition
+
+
+async def iterate_sentences(
+    resource: Resource,
+    enabled_metadata: EnabledMetadata,
+) -> AsyncIterator[TrainSentence]:  # pragma: no cover
+    fields = await resource.get_fields(force=True)
+    metadata = TrainMetadata()
+    userdefinedparagraphclass: dict[str, ParagraphAnnotation] = {}
+    if enabled_metadata.labels:
+        if resource.basic is None:
+            resource.basic = await resource.get_basic()
+        if resource.basic is not None:
+            metadata.labels.resource.extend(resource.basic.usermetadata.classifications)
+            for fieldmetadata in resource.basic.fieldmetadata:
+                field_id = resource.generate_field_id(fieldmetadata.field)
+                for annotationparagraph in fieldmetadata.paragraphs:
+                    userdefinedparagraphclass[annotationparagraph.key] = annotationparagraph
+
+    for (type_id, field_id), field in fields.items():
+        fieldid = FieldID(field_type=type_id, field=field_id)
+        field_key = resource.generate_field_id(fieldid)
+        fm = await field.get_field_metadata()
+        extracted_text = None
+        vo = None
+        text = None
+
+        if enabled_metadata.vector:
+            # XXX: Given that nobody requested any particular vectorset, we'll
+            # return any
+            vectorset_id = None
+            async with datamanagers.with_ro_transaction() as txn:
+                async for vectorset_id, vs in datamanagers.vectorsets.iter(
+                    txn=txn, kbid=resource.kb.kbid
+                ):
+                    break
+            assert vectorset_id is not None, "All KBs must have at least a vectorset"
+            vo = await field.get_vectors(vectorset_id, vs.storage_key_kind)
+
+        extracted_text = await field.get_extracted_text()
+
+        if fm is None:
+            continue
+
+        field_metadatas: list[tuple[Optional[str], FieldMetadata]] = [(None, fm.metadata)]
+        for subfield_metadata, splitted_metadata in fm.split_metadata.items():
+            field_metadatas.append((subfield_metadata, splitted_metadata))
+
+        for subfield, field_metadata in field_metadatas:
+            if enabled_metadata.labels:
+                metadata.labels.ClearField("field")
+                metadata.labels.field.extend(field_metadata.classifications)
+
+            entities: dict[str, str] = {}
+            if enabled_metadata.entities:
+                _update_entities_dict(entities, field_metadata)
+
+            precomputed_vectors = {}
+            if vo is not None:
+                if subfield is not None:
+                    vectors = vo.split_vectors[subfield]
+                    base_vector_key = f"{resource.uuid}/{field_key}/{subfield}"
+                else:
+                    vectors = vo.vectors
+                    base_vector_key = f"{resource.uuid}/{field_key}"
+                for index, vector in enumerate(vectors.vectors):
+                    vector_key = f"{base_vector_key}/{index}/{vector.start}-{vector.end}"
+                    precomputed_vectors[vector_key] = vector.vector
+
+            if extracted_text is not None:
+                if subfield is not None:
+                    text = extracted_text.split_text[subfield]
+                else:
+                    text = extracted_text.text
+
+            for paragraph in field_metadata.paragraphs:
+                if subfield is not None:
+                    paragraph_key = (
+                        f"{resource.uuid}/{field_key}/{subfield}/{paragraph.start}-{paragraph.end}"
+                    )
+                else:
+                    paragraph_key = f"{resource.uuid}/{field_key}/{paragraph.start}-{paragraph.end}"
+
+                if enabled_metadata.labels:
+                    metadata.labels.ClearField("field")
+                    metadata.labels.paragraph.extend(paragraph.classifications)
+                    if paragraph_key in userdefinedparagraphclass:
+                        metadata.labels.paragraph.extend(
+                            userdefinedparagraphclass[paragraph_key].classifications
+                        )
+
+                for index, sentence in enumerate(paragraph.sentences):
+                    if subfield is not None:
+                        sentence_key = f"{resource.uuid}/{field_key}/{subfield}/{index}/{sentence.start}-{sentence.end}"
+                    else:
+                        sentence_key = (
+                            f"{resource.uuid}/{field_key}/{index}/{sentence.start}-{sentence.end}"
+                        )
+
+                    if vo is not None:
+                        metadata.ClearField("vector")
+                        vector_tmp = precomputed_vectors.get(sentence_key)
+                        if vector_tmp:
+                            metadata.vector.extend(vector_tmp)
+
+                    if extracted_text is not None and text is not None:
+                        metadata.text = text[sentence.start : sentence.end]
+
+                    metadata.ClearField("entities")
+                    metadata.ClearField("entity_positions")
+                    if enabled_metadata.entities and text is not None:
+                        local_text = text[sentence.start : sentence.end]
+                        add_entities_to_metadata(entities, local_text, metadata)
+
+                    pb_sentence = TrainSentence()
+                    pb_sentence.uuid = resource.uuid
+                    pb_sentence.field.CopyFrom(fieldid)
+                    pb_sentence.paragraph = paragraph_key
+                    pb_sentence.sentence = sentence_key
+                    pb_sentence.metadata.CopyFrom(metadata)
+                    yield pb_sentence
+
+
+async def iterate_paragraphs(
+    resource: Resource, enabled_metadata: EnabledMetadata
+) -> AsyncIterator[TrainParagraph]:
+    fields = await resource.get_fields(force=True)
+    metadata = TrainMetadata()
+    userdefinedparagraphclass: dict[str, ParagraphAnnotation] = {}
+    if enabled_metadata.labels:
+        if resource.basic is None:
+            resource.basic = await resource.get_basic()
+        if resource.basic is not None:
+            metadata.labels.resource.extend(resource.basic.usermetadata.classifications)
+            for fieldmetadata in resource.basic.fieldmetadata:
+                field_id = resource.generate_field_id(fieldmetadata.field)
+                for annotationparagraph in fieldmetadata.paragraphs:
+                    userdefinedparagraphclass[annotationparagraph.key] = annotationparagraph
+
+    for (type_id, field_id), field in fields.items():
+        fieldid = FieldID(field_type=type_id, field=field_id)
+        field_key = resource.generate_field_id(fieldid)
+        fm = await field.get_field_metadata()
+        extracted_text = None
+        text = None
+
+        extracted_text = await field.get_extracted_text()
+
+        if fm is None:
+            continue
+
+        field_metadatas: list[tuple[Optional[str], FieldMetadata]] = [(None, fm.metadata)]
+        for subfield_metadata, splitted_metadata in fm.split_metadata.items():
+            field_metadatas.append((subfield_metadata, splitted_metadata))
+
+        for subfield, field_metadata in field_metadatas:
+            if enabled_metadata.labels:
+                metadata.labels.ClearField("field")
+                metadata.labels.field.extend(field_metadata.classifications)
+
+            entities: dict[str, str] = {}
+            if enabled_metadata.entities:
+                _update_entities_dict(entities, field_metadata)
+
+            if extracted_text is not None:
+                if subfield is not None:
+                    text = extracted_text.split_text[subfield]
+                else:
+                    text = extracted_text.text
+
+            for paragraph in field_metadata.paragraphs:
+                if subfield is not None:
+                    paragraph_key = (
+                        f"{resource.uuid}/{field_key}/{subfield}/{paragraph.start}-{paragraph.end}"
+                    )
+                else:
+                    paragraph_key = f"{resource.uuid}/{field_key}/{paragraph.start}-{paragraph.end}"
+
+                if enabled_metadata.labels:
+                    metadata.labels.ClearField("paragraph")
+                    metadata.labels.paragraph.extend(paragraph.classifications)
+
+                if extracted_text is not None and text is not None:
+                    metadata.text = text[paragraph.start : paragraph.end]
+
+                metadata.ClearField("entities")
+                metadata.ClearField("entity_positions")
+                if enabled_metadata.entities and text is not None:
+                    local_text = text[paragraph.start : paragraph.end]
+                    add_entities_to_metadata(entities, local_text, metadata)
+
+                if paragraph_key in userdefinedparagraphclass:
+                    metadata.labels.paragraph.extend(
+                        userdefinedparagraphclass[paragraph_key].classifications
+                    )
+
+                pb_paragraph = TrainParagraph()
+                pb_paragraph.uuid = resource.uuid
+                pb_paragraph.field.CopyFrom(fieldid)
+                pb_paragraph.paragraph = paragraph_key
+                pb_paragraph.metadata.CopyFrom(metadata)
+
+                yield pb_paragraph
+
+
+async def iterate_fields(
+    resource: Resource, enabled_metadata: EnabledMetadata
+) -> AsyncIterator[TrainField]:
+    fields = await resource.get_fields(force=True)
+    metadata = TrainMetadata()
+    if enabled_metadata.labels:
+        if resource.basic is None:
+            resource.basic = await resource.get_basic()
+        if resource.basic is not None:
+            metadata.labels.resource.extend(resource.basic.usermetadata.classifications)
+
+    for (type_id, field_id), field in fields.items():
+        fieldid = FieldID(field_type=type_id, field=field_id)
+        fm = await field.get_field_metadata()
+        extracted_text = None
+
+        if enabled_metadata.text:
+            extracted_text = await field.get_extracted_text()
+
+        if fm is None:
+            continue
+
+        field_metadatas: list[tuple[Optional[str], FieldMetadata]] = [(None, fm.metadata)]
+        for subfield_metadata, splitted_metadata in fm.split_metadata.items():
+            field_metadatas.append((subfield_metadata, splitted_metadata))
+
+        for subfield, splitted_metadata in field_metadatas:
+            if enabled_metadata.labels:
+                metadata.labels.ClearField("field")
+                metadata.labels.field.extend(splitted_metadata.classifications)
+
+            if extracted_text is not None:
+                if subfield is not None:
+                    metadata.text = extracted_text.split_text[subfield]
+                else:
+                    metadata.text = extracted_text.text
+
+            if enabled_metadata.entities:
+                metadata.ClearField("entities")
+                _update_entities_dict(metadata.entities, splitted_metadata)
+
+            pb_field = TrainField()
+            pb_field.uuid = resource.uuid
+            pb_field.field.CopyFrom(fieldid)
+            pb_field.metadata.CopyFrom(metadata)
+            yield pb_field
+
+
+async def generate_train_resource(
+    resource: Resource, enabled_metadata: EnabledMetadata
+) -> TrainResource:
+    fields = await resource.get_fields(force=True)
+    metadata = TrainMetadata()
+    if enabled_metadata.labels:
+        if resource.basic is None:
+            resource.basic = await resource.get_basic()
+        if resource.basic is not None:
+            metadata.labels.resource.extend(resource.basic.usermetadata.classifications)
+
+    metadata.labels.ClearField("field")
+    metadata.ClearField("entities")
+
+    for (_, _), field in fields.items():
+        extracted_text = None
+        fm = await field.get_field_metadata()
+
+        if enabled_metadata.text:
+            extracted_text = await field.get_extracted_text()
+
+        if extracted_text is not None:
+            metadata.text += extracted_text.text
+            for text in extracted_text.split_text.values():
+                metadata.text += f" {text}"
+
+        if fm is None:
+            continue
+
+        field_metadatas: list[tuple[Optional[str], FieldMetadata]] = [(None, fm.metadata)]
+        for subfield_metadata, splitted_metadata in fm.split_metadata.items():
+            field_metadatas.append((subfield_metadata, splitted_metadata))
+
+        for _, splitted_metadata in field_metadatas:
+            if enabled_metadata.labels:
+                metadata.labels.field.extend(splitted_metadata.classifications)
+
+            if enabled_metadata.entities:
+                _update_entities_dict(metadata.entities, splitted_metadata)
+
+    pb_resource = TrainResource()
+    pb_resource.uuid = resource.uuid
+    if resource.basic is not None:
+        pb_resource.title = resource.basic.title
+        pb_resource.icon = resource.basic.icon
+        pb_resource.slug = resource.basic.slug
+        pb_resource.modified.CopyFrom(resource.basic.modified)
+        pb_resource.created.CopyFrom(resource.basic.created)
+    pb_resource.metadata.CopyFrom(metadata)
+    return pb_resource
+
+
+def add_entities_to_metadata(entities: dict[str, str], local_text: str, metadata: TrainMetadata) -> None:
+    for entity_key, entity_value in entities.items():
+        if entity_key not in local_text:
+            # Add the entity only if found in text
+            continue
+        metadata.entities[entity_key] = entity_value
+
+        # Add positions for the entity relative to the local text
+        poskey = f"{entity_value}/{entity_key}"
+        metadata.entity_positions[poskey].entity = entity_key
+        last_occurrence_end = 0
+        for _ in range(local_text.count(entity_key)):
+            start = local_text.index(entity_key, last_occurrence_end)
+            end = start + len(entity_key)
+            metadata.entity_positions[poskey].positions.append(TrainPosition(start=start, end=end))
+            last_occurrence_end = end
+
+
+def _update_entities_dict(target_entites_dict: MutableMapping[str, str], field_metadata: FieldMetadata):
+    """
+    Update the entities dict with the entities from the field metadata.
+    Method created to ease the transition from legacy ner field to new entities field.
+    """
+    # Data Augmentation + Processor entities
+    # This will overwrite entities detected from more than one data augmentation task
+    # TODO: Change TrainMetadata proto to accept multiple entities with the same text
+    entity_map = {
+        entity.text: entity.label
+        for data_augmentation_task_id, entities_wrapper in field_metadata.entities.items()
+        for entity in entities_wrapper.entities
+    }
+    target_entites_dict.update(entity_map)
+
+    # Legacy processor entities
+    # TODO: Remove once processor doesn't use this anymore and remove the positions and ner fields from the message
+    target_entites_dict.update(field_metadata.ner)
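One detail in the new module worth calling out: add_entities_to_metadata scans for non-overlapping occurrences by advancing a cursor past each match, which pairs correctly with str.count (also non-overlapping). The same loop in isolation:

    def find_occurrences(haystack: str, needle: str) -> list[tuple[int, int]]:
        """Return (start, end) spans of each non-overlapping occurrence."""
        spans = []
        last_end = 0
        for _ in range(haystack.count(needle)):
            start = haystack.index(needle, last_end)
            end = start + len(needle)
            spans.append((start, end))
            last_end = end
        return spans

    assert find_occurrences("a cat, a cat", "cat") == [(2, 5), (9, 12)]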
nucliadb/writer/api/constants.py CHANGED
@@ -17,21 +17,25 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
-from typing import TYPE_CHECKING
-
 from fastapi.params import Header
 
-if TYPE_CHECKING:  # pragma: no cover
-    SKIP_STORE_DEFAULT = False
-    X_NUCLIADB_USER = ""
-    X_FILE_PASSWORD = None
-else:
-    SKIP_STORE_DEFAULT = Header(
-        False,
-        description="If set to true, file fields will not be saved in the blob storage. They will only be sent to process.",  # noqa
-    )
-    X_NUCLIADB_USER = Header("")
-    X_FILE_PASSWORD = Header(
-        None,
-        description="If a file is password protected, the password must be provided here for the file to be processed",  # noqa
-    )
+X_SKIP_STORE = Header(
+    description="If set to true, file fields will not be saved in the blob storage. They will only be sent to process.",
+)
+X_NUCLIADB_USER = Header()
+X_FILE_PASSWORD = Header(
+    description="If a file is password protected, the password must be provided here for the file to be processed",
+)
+X_EXTRACT_STRATEGY = Header(
+    description="Extract strategy to use when uploading a file. If not provided, the default strategy will be used.",
+)
+X_FILENAME = Header(min_length=1, description="Name of the file being uploaded.")
+X_MD5 = Header(
+    min_length=32,
+    max_length=32,
+    description="MD5 hash of the file being uploaded. This is used to check if the file has been uploaded before.",
+)
+X_PASSWORD = Header(
+    min_length=1, description="If the file is password protected, the password must be provided here."
+)
+X_LANGUAGE = Header()
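The module drops the TYPE_CHECKING workaround because these Header(...) objects are now attached to endpoint parameters through typing.Annotated (see the field.py hunks below): the plain default value lives in the signature and the header metadata lives in the annotation, so type checkers no longer need a fake branch. A minimal sketch of the pattern, with an illustrative endpoint:

    from typing import Annotated, Optional

    from fastapi import FastAPI
    from fastapi.params import Header

    X_FILE_PASSWORD = Header(description="Password for protected files")

    app = FastAPI()

    @app.post("/upload")
    async def upload(
        x_file_password: Annotated[Optional[str], X_FILE_PASSWORD] = None,
    ) -> dict:
        # FastAPI maps the X-File-Password request header into the parameter
        return {"protected": x_file_password is not None}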
nucliadb/writer/api/v1/__init__.py CHANGED
@@ -24,4 +24,5 @@ from . import learning_config  # noqa
 from . import resource  # noqa
 from . import services  # noqa
 from . import upload  # noqa
+from . import vectorsets  # noqa
 from .router import api  # noqa
nucliadb/writer/api/v1/export_import.py CHANGED
@@ -112,7 +112,7 @@ async def kb_create_and_import_endpoint(request: Request):
     now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     import_kb_config = KnowledgeBoxConfig(
         title=f"Imported KB - {now}",
-        learning_configuration=learning_config.dict(),
+        learning_configuration=learning_config.model_dump(),
     )
     kbid, slug = await create_kb(import_kb_config)
 
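This is the Pydantic v2 rename: .dict() survives only as a deprecated alias of .model_dump(). A quick illustration with a stand-in model (LearningConfig here is illustrative, not the real nucliadb class):

    from pydantic import BaseModel

    class LearningConfig(BaseModel):  # illustrative stand-in
        embedding_model: str = "default"
        chunk_size: int = 512

    config = LearningConfig()
    assert config.model_dump() == {"embedding_model": "default", "chunk_size": 512}
    # config.dict() still works in v2, but emits a deprecation warning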
nucliadb/writer/api/v1/field.py CHANGED
@@ -18,7 +18,7 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 from inspect import iscoroutinefunction
-from typing import TYPE_CHECKING, Callable, Optional, Type, Union
+from typing import TYPE_CHECKING, Annotated, Callable, Optional, Type, Union
 
 from fastapi import HTTPException, Response
 from fastapi_versioning import version
@@ -30,9 +30,9 @@ from nucliadb.ingest.orm.knowledgebox import KnowledgeBox
 from nucliadb.ingest.processing import PushPayload, Source
 from nucliadb.writer import SERVICE_NAME
 from nucliadb.writer.api.constants import (
-    SKIP_STORE_DEFAULT,
     X_FILE_PASSWORD,
     X_NUCLIADB_USER,
+    X_SKIP_STORE,
 )
 from nucliadb.writer.api.v1 import transaction
 from nucliadb.writer.api.v1.resource import (
@@ -55,7 +55,7 @@ from nucliadb_models.utils import FieldIdString
 from nucliadb_models.writer import ResourceFieldAdded, ResourceUpdated
 from nucliadb_protos import resources_pb2
 from nucliadb_protos.resources_pb2 import FieldID, Metadata
-from nucliadb_protos.writer_pb2 import BrokerMessage
+from nucliadb_protos.writer_pb2 import BrokerMessage, FieldIDStatus, FieldStatus
 from nucliadb_utils.authentication import requires
 from nucliadb_utils.exceptions import LimitsExceededError, SendToProcessError
 from nucliadb_utils.utilities import (
@@ -380,7 +380,7 @@ async def add_resource_field_file_rslug_prefix(
     rslug: str,
     field_id: FieldIdString,
     field_payload: models.FileField,
-    x_skip_store: bool = SKIP_STORE_DEFAULT,
+    x_skip_store: Annotated[bool, X_SKIP_STORE] = False,
 ) -> ResourceFieldAdded:
     return await add_field_to_resource_by_slug(
         request, kbid, rslug, field_id, field_payload, skip_store=x_skip_store
@@ -402,7 +402,7 @@ async def add_resource_field_file_rid_prefix(
     rid: str,
     field_id: FieldIdString,
     field_payload: models.FileField,
-    x_skip_store: bool = SKIP_STORE_DEFAULT,
+    x_skip_store: Annotated[bool, X_SKIP_STORE] = False,
 ) -> ResourceFieldAdded:
     return await add_field_to_resource(
         request, kbid, rid, field_id, field_payload, skip_store=x_skip_store
@@ -503,8 +503,8 @@ async def reprocess_file_field(
     kbid: str,
     rid: str,
     field_id: FieldIdString,
-    x_nucliadb_user: str = X_NUCLIADB_USER,
-    x_file_password: Optional[str] = X_FILE_PASSWORD,
+    x_nucliadb_user: Annotated[str, X_NUCLIADB_USER] = "",
+    x_file_password: Annotated[Optional[str], X_FILE_PASSWORD] = None,
 ) -> ResourceUpdated:
     await maybe_back_pressure(request, kbid, resource_uuid=rid)
 
@@ -553,6 +553,12 @@ async def reprocess_file_field(
     writer.source = BrokerMessage.MessageSource.WRITER
     writer.basic.metadata.useful = True
     writer.basic.metadata.status = Metadata.Status.PENDING
+    writer.field_statuses.append(
+        FieldIDStatus(
+            id=FieldID(field_type=resources_pb2.FieldType.FILE, field=field_id),
+            status=FieldStatus.Status.PENDING,
+        )
+    )
     await transaction.commit(writer, partition, wait=False)
     # Send current resource to reprocess.
     try:
  try: