nucliadb 6.2.1.post2864__py3-none-any.whl → 6.2.1.post2869__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -24,7 +24,7 @@ import logging
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
-from typing import TYPE_CHECKING, Any, AsyncIterator, MutableMapping, Optional, Type
+from typing import TYPE_CHECKING, Any, Optional, Sequence, Type
 
 from nucliadb.common import datamanagers
 from nucliadb.common.datamanagers.resources import KB_RESOURCE_SLUG
@@ -52,7 +52,6 @@ from nucliadb_protos.resources_pb2 import (
     FieldComputedMetadataWrapper,
     FieldFile,
     FieldID,
-    FieldMetadata,
     FieldQuestionAnswerWrapper,
     FieldText,
     FieldType,
@@ -61,7 +60,6 @@ from nucliadb_protos.resources_pb2 import (
     LinkExtractedData,
     Metadata,
     Paragraph,
-    ParagraphAnnotation,
 )
 from nucliadb_protos.resources_pb2 import Basic as PBBasic
 from nucliadb_protos.resources_pb2 import Conversation as PBConversation
@@ -69,15 +67,6 @@ from nucliadb_protos.resources_pb2 import Extra as PBExtra
 from nucliadb_protos.resources_pb2 import Metadata as PBMetadata
 from nucliadb_protos.resources_pb2 import Origin as PBOrigin
 from nucliadb_protos.resources_pb2 import Relations as PBRelations
-from nucliadb_protos.train_pb2 import (
-    EnabledMetadata,
-    TrainField,
-    TrainMetadata,
-    TrainParagraph,
-    TrainResource,
-    TrainSentence,
-)
-from nucliadb_protos.train_pb2 import Position as TrainPosition
 from nucliadb_protos.utils_pb2 import Relation as PBRelation
 from nucliadb_protos.writer_pb2 import BrokerMessage
 from nucliadb_utils.storages.storage import Storage
@@ -343,36 +332,24 @@ class Resource:
                 )
 
             if self.disable_vectors is False:
-                # XXX: while we don't remove the "default" vectorset concept, we
-                # need to do use None as the default one
-                vo = await field.get_vectors()
-                if vo is not None:
-                    async with datamanagers.with_ro_transaction() as ro_txn:
-                        dimension = await datamanagers.kb.get_matryoshka_vector_dimension(
-                            ro_txn, kbid=self.kb.kbid
-                        )
-                    brain.apply_field_vectors(
-                        field_key,
-                        vo,
-                        matryoshka_vector_dimension=dimension,
-                        replace_field=reindex,
-                    )
-
                 vectorset_configs = []
-                async with datamanagers.with_ro_transaction() as ro_txn:
-                    async for vectorset_id, vectorset_config in datamanagers.vectorsets.iter(
-                        ro_txn, kbid=self.kb.kbid
-                    ):
-                        vectorset_configs.append(vectorset_config)
+                async for vectorset_id, vectorset_config in datamanagers.vectorsets.iter(
+                    self.txn, kbid=self.kb.kbid
+                ):
+                    vectorset_configs.append(vectorset_config)
+
                 for vectorset_config in vectorset_configs:
-                    vo = await field.get_vectors(vectorset=vectorset_config.vectorset_id)
+                    vo = await field.get_vectors(
+                        vectorset=vectorset_config.vectorset_id,
+                        storage_key_kind=vectorset_config.storage_key_kind,
+                    )
                     if vo is not None:
                         dimension = vectorset_config.vectorset_index_config.vector_dimension
                         brain.apply_field_vectors(
                             field_key,
                             vo,
                             vectorset=vectorset_config.vectorset_id,
-                            matryoshka_vector_dimension=dimension,
+                            vector_dimension=dimension,
                             replace_field=reindex,
                         )
         return brain
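The hunk above drops the legacy "default vectorset" path (fetching vectors without a vectorset and looking up a matryoshka dimension in a separate read-only transaction) and keeps a single loop over the KB's vectorset configs, read through the transaction the resource already holds; each vectorset is then indexed with its own dimension and storage key kind. A minimal, self-contained sketch of that per-vectorset loop follows; every name in it (VectorsetConfig, get_vectors, apply_field_vectors) is an illustrative stand-in, not the real nucliadb API.

    import asyncio
    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class VectorsetConfig:
        vectorset_id: str
        vector_dimension: int
        storage_key_kind: str  # where this vectorset stores its vectors


    async def get_vectors(field_key: str, vectorset: str, storage_key_kind: str) -> Optional[list[float]]:
        # stand-in for field.get_vectors(vectorset=..., storage_key_kind=...)
        return [0.1, 0.2, 0.3]


    def apply_field_vectors(field_key: str, vo: list[float], *, vectorset: str, vector_dimension: int) -> None:
        print(f"indexing {field_key} into vectorset={vectorset} dim={vector_dimension}")


    async def index_field(field_key: str, configs: list[VectorsetConfig]) -> None:
        # one indexing pass per configured vectorset, each with its own dimension
        for config in configs:
            vo = await get_vectors(field_key, config.vectorset_id, config.storage_key_kind)
            if vo is None:
                continue  # nothing extracted for this vectorset yet
            apply_field_vectors(
                field_key, vo, vectorset=config.vectorset_id, vector_dimension=config.vector_dimension
            )


    asyncio.run(index_field("t/text1", [VectorsetConfig("multilingual-2024", 3, "VECTORSET_PREFIX")]))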
@@ -671,9 +648,7 @@ class Resource:
         # Upload to binary storage
         # Vector indexing
         if self.disable_vectors is False:
-            await self.get_fields(force=True)
-            for field_vectors in message.field_vectors:
-                await self._apply_extracted_vectors(field_vectors)
+            await self._apply_extracted_vectors(message.field_vectors)
 
         # Only uploading to binary storage
         for field_large_metadata in message.field_large_metadata:
@@ -857,55 +832,69 @@ class Resource:
 
         add_field_classifications(self.basic, field_metadata)
 
-    async def _apply_extracted_vectors(self, field_vectors: ExtractedVectorsWrapper):
-        # Store vectors in the resource
-
-        if not self.has_field(field_vectors.field.field_type, field_vectors.field.field):
-            # skipping because field does not exist
-            logger.warning(f'Field "{field_vectors.field.field}" does not exist, skipping vectors')
-            return
-
-        field_obj = await self.get_field(
-            field_vectors.field.field,
-            field_vectors.field.field_type,
-            load=False,
-        )
-        vo = await field_obj.set_vectors(field_vectors)
-
-        # Prepare vectors to be indexed
+    async def _apply_extracted_vectors(
+        self,
+        fields_vectors: Sequence[ExtractedVectorsWrapper],
+    ):
+        await self.get_fields(force=True)
+        vectorsets = {
+            vectorset_id: vs
+            async for vectorset_id, vs in datamanagers.vectorsets.iter(self.txn, kbid=self.kb.kbid)
+        }
+
+        for field_vectors in fields_vectors:
+            # Bw/c with extracted vectors without vectorsets
+            if not field_vectors.vectorset_id:
+                assert (
+                    len(vectorsets) == 1
+                ), "Invalid broker message, can't ingest vectors from unknown vectorset to KB with multiple vectorsets"
+                vectorset = list(vectorsets.values())[0]
 
-        field_key = self.generate_field_id(field_vectors.field)
-        if vo is not None:
-            vectorset_id = field_vectors.vectorset_id or None
-            if vectorset_id is None:
-                dimension = await datamanagers.kb.get_matryoshka_vector_dimension(
-                    self.txn, kbid=self.kb.kbid
-                )
             else:
-                config = await datamanagers.vectorsets.get(
-                    self.txn, kbid=self.kb.kbid, vectorset_id=vectorset_id
-                )
-                if config is None:
+                if field_vectors.vectorset_id not in vectorsets:
                     logger.warning(
-                        f"Trying to apply a resource on vectorset '{vectorset_id}' that doesn't exist."
+                        "Dropping extracted vectors for unknown vectorset",
+                        extra={"kbid": self.kb.kbid, "vectorset": field_vectors.vectorset_id},
                     )
-                    return
-                dimension = config.vectorset_index_config.vector_dimension
-                if not dimension:
-                    raise ValueError(f"Vector dimension not set for vectorset '{vectorset_id}'")
+                    continue
+
+                vectorset = vectorsets[field_vectors.vectorset_id]
+
+            # Store vectors in the resource
+
+            if not self.has_field(field_vectors.field.field_type, field_vectors.field.field):
+                # skipping because field does not exist
+                logger.warning(f'Field "{field_vectors.field.field}" does not exist, skipping vectors')
+                return
+
+            field_obj = await self.get_field(
+                field_vectors.field.field,
+                field_vectors.field.field_type,
+                load=False,
+            )
+            vo = await field_obj.set_vectors(
+                field_vectors, vectorset.vectorset_id, vectorset.storage_key_kind
+            )
+            if vo is None:
+                raise AttributeError("Vector object not found on set_vectors")
+
+            # Prepare vectors to be indexed
+
+            field_key = self.generate_field_id(field_vectors.field)
+            dimension = vectorset.vectorset_index_config.vector_dimension
+            if not dimension:
+                raise ValueError(f"Vector dimension not set for vectorset '{vectorset.vectorset_id}'")
 
             apply_field_vectors_partial = partial(
                 self.indexer.apply_field_vectors,
                 field_key,
                 vo,
-                vectorset=vectorset_id,
+                vectorset=vectorset.vectorset_id,
                 replace_field=True,
-                matryoshka_vector_dimension=dimension,
+                vector_dimension=dimension,
             )
             loop = asyncio.get_running_loop()
             await loop.run_in_executor(_executor, apply_field_vectors_partial)
-        else:
-            raise AttributeError("VO not found on set")
 
     async def _apply_field_large_metadata(self, field_large_metadata: LargeComputedMetadataWrapper):
         field_obj = await self.get_field(
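The rewritten _apply_extracted_vectors above now takes the whole sequence of extracted vectors, loads the KB's vectorsets once, and resolves which vectorset each item belongs to before storing and indexing it. A compact, runnable sketch of just that resolution rule, with plain dicts standing in for the vectorset config objects:

    from typing import Optional


    def resolve_vectorset(vectorsets: dict[str, dict], vectorset_id: Optional[str]) -> Optional[dict]:
        """Pick the vectorset the extracted vectors belong to, or None to drop them."""
        if not vectorset_id:
            # legacy message without a vectorset: only acceptable if the KB has exactly one
            assert len(vectorsets) == 1, "can't ingest vectors from an unknown vectorset"
            return next(iter(vectorsets.values()))
        # unknown vectorsets are logged and skipped rather than failing the whole message
        return vectorsets.get(vectorset_id)


    configs = {"multilingual-2024": {"dimension": 1024, "storage_key_kind": "VECTORSET_PREFIX"}}
    assert resolve_vectorset(configs, None) is configs["multilingual-2024"]
    assert resolve_vectorset(configs, "unknown") is None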
@@ -978,291 +967,6 @@ class Resource:
         self._indexer = None
         self.txn = None
 
-    async def iterate_sentences(
-        self, enabled_metadata: EnabledMetadata
-    ) -> AsyncIterator[TrainSentence]:  # pragma: no cover
-        fields = await self.get_fields(force=True)
-        metadata = TrainMetadata()
-        userdefinedparagraphclass: dict[str, ParagraphAnnotation] = {}
-        if enabled_metadata.labels:
-            if self.basic is None:
-                self.basic = await self.get_basic()
-            if self.basic is not None:
-                metadata.labels.resource.extend(self.basic.usermetadata.classifications)
-                for fieldmetadata in self.basic.fieldmetadata:
-                    field_id = self.generate_field_id(fieldmetadata.field)
-                    for annotationparagraph in fieldmetadata.paragraphs:
-                        userdefinedparagraphclass[annotationparagraph.key] = annotationparagraph
-
-        for (type_id, field_id), field in fields.items():
-            fieldid = FieldID(field_type=type_id, field=field_id)
-            field_key = self.generate_field_id(fieldid)
-            fm = await field.get_field_metadata()
-            extracted_text = None
-            vo = None
-            text = None
-
-            if enabled_metadata.vector:
-                vo = await field.get_vectors()
-
-            extracted_text = await field.get_extracted_text()
-
-            if fm is None:
-                continue
-
-            field_metadatas: list[tuple[Optional[str], FieldMetadata]] = [(None, fm.metadata)]
-            for subfield_metadata, splitted_metadata in fm.split_metadata.items():
-                field_metadatas.append((subfield_metadata, splitted_metadata))
-
-            for subfield, field_metadata in field_metadatas:
-                if enabled_metadata.labels:
-                    metadata.labels.ClearField("field")
-                    metadata.labels.field.extend(field_metadata.classifications)
-
-                entities: dict[str, str] = {}
-                if enabled_metadata.entities:
-                    _update_entities_dict(entities, field_metadata)
-
-                precomputed_vectors = {}
-                if vo is not None:
-                    if subfield is not None:
-                        vectors = vo.split_vectors[subfield]
-                        base_vector_key = f"{self.uuid}/{field_key}/{subfield}"
-                    else:
-                        vectors = vo.vectors
-                        base_vector_key = f"{self.uuid}/{field_key}"
-                    for index, vector in enumerate(vectors.vectors):
-                        vector_key = f"{base_vector_key}/{index}/{vector.start}-{vector.end}"
-                        precomputed_vectors[vector_key] = vector.vector
-
-                if extracted_text is not None:
-                    if subfield is not None:
-                        text = extracted_text.split_text[subfield]
-                    else:
-                        text = extracted_text.text
-
-                for paragraph in field_metadata.paragraphs:
-                    if subfield is not None:
-                        paragraph_key = (
-                            f"{self.uuid}/{field_key}/{subfield}/{paragraph.start}-{paragraph.end}"
-                        )
-                    else:
-                        paragraph_key = f"{self.uuid}/{field_key}/{paragraph.start}-{paragraph.end}"
-
-                    if enabled_metadata.labels:
-                        metadata.labels.ClearField("field")
-                        metadata.labels.paragraph.extend(paragraph.classifications)
-                        if paragraph_key in userdefinedparagraphclass:
-                            metadata.labels.paragraph.extend(
-                                userdefinedparagraphclass[paragraph_key].classifications
-                            )
-
-                    for index, sentence in enumerate(paragraph.sentences):
-                        if subfield is not None:
-                            sentence_key = f"{self.uuid}/{field_key}/{subfield}/{index}/{sentence.start}-{sentence.end}"
-                        else:
-                            sentence_key = (
-                                f"{self.uuid}/{field_key}/{index}/{sentence.start}-{sentence.end}"
-                            )
-
-                        if vo is not None:
-                            metadata.ClearField("vector")
-                            vector_tmp = precomputed_vectors.get(sentence_key)
-                            if vector_tmp:
-                                metadata.vector.extend(vector_tmp)
-
-                        if extracted_text is not None and text is not None:
-                            metadata.text = text[sentence.start : sentence.end]
-
-                        metadata.ClearField("entities")
-                        metadata.ClearField("entity_positions")
-                        if enabled_metadata.entities and text is not None:
-                            local_text = text[sentence.start : sentence.end]
-                            add_entities_to_metadata(entities, local_text, metadata)
-
-                        pb_sentence = TrainSentence()
-                        pb_sentence.uuid = self.uuid
-                        pb_sentence.field.CopyFrom(fieldid)
-                        pb_sentence.paragraph = paragraph_key
-                        pb_sentence.sentence = sentence_key
-                        pb_sentence.metadata.CopyFrom(metadata)
-                        yield pb_sentence
-
-    async def iterate_paragraphs(
-        self, enabled_metadata: EnabledMetadata
-    ) -> AsyncIterator[TrainParagraph]:
-        fields = await self.get_fields(force=True)
-        metadata = TrainMetadata()
-        userdefinedparagraphclass: dict[str, ParagraphAnnotation] = {}
-        if enabled_metadata.labels:
-            if self.basic is None:
-                self.basic = await self.get_basic()
-            if self.basic is not None:
-                metadata.labels.resource.extend(self.basic.usermetadata.classifications)
-                for fieldmetadata in self.basic.fieldmetadata:
-                    field_id = self.generate_field_id(fieldmetadata.field)
-                    for annotationparagraph in fieldmetadata.paragraphs:
-                        userdefinedparagraphclass[annotationparagraph.key] = annotationparagraph
-
-        for (type_id, field_id), field in fields.items():
-            fieldid = FieldID(field_type=type_id, field=field_id)
-            field_key = self.generate_field_id(fieldid)
-            fm = await field.get_field_metadata()
-            extracted_text = None
-            text = None
-
-            extracted_text = await field.get_extracted_text()
-
-            if fm is None:
-                continue
-
-            field_metadatas: list[tuple[Optional[str], FieldMetadata]] = [(None, fm.metadata)]
-            for subfield_metadata, splitted_metadata in fm.split_metadata.items():
-                field_metadatas.append((subfield_metadata, splitted_metadata))
-
-            for subfield, field_metadata in field_metadatas:
-                if enabled_metadata.labels:
-                    metadata.labels.ClearField("field")
-                    metadata.labels.field.extend(field_metadata.classifications)
-
-                entities: dict[str, str] = {}
-                if enabled_metadata.entities:
-                    _update_entities_dict(entities, field_metadata)
-
-                if extracted_text is not None:
-                    if subfield is not None:
-                        text = extracted_text.split_text[subfield]
-                    else:
-                        text = extracted_text.text
-
-                for paragraph in field_metadata.paragraphs:
-                    if subfield is not None:
-                        paragraph_key = (
-                            f"{self.uuid}/{field_key}/{subfield}/{paragraph.start}-{paragraph.end}"
-                        )
-                    else:
-                        paragraph_key = f"{self.uuid}/{field_key}/{paragraph.start}-{paragraph.end}"
-
-                    if enabled_metadata.labels:
-                        metadata.labels.ClearField("paragraph")
-                        metadata.labels.paragraph.extend(paragraph.classifications)
-
-                    if extracted_text is not None and text is not None:
-                        metadata.text = text[paragraph.start : paragraph.end]
-
-                    metadata.ClearField("entities")
-                    metadata.ClearField("entity_positions")
-                    if enabled_metadata.entities and text is not None:
-                        local_text = text[paragraph.start : paragraph.end]
-                        add_entities_to_metadata(entities, local_text, metadata)
-
-                    if paragraph_key in userdefinedparagraphclass:
-                        metadata.labels.paragraph.extend(
-                            userdefinedparagraphclass[paragraph_key].classifications
-                        )
-
-                    pb_paragraph = TrainParagraph()
-                    pb_paragraph.uuid = self.uuid
-                    pb_paragraph.field.CopyFrom(fieldid)
-                    pb_paragraph.paragraph = paragraph_key
-                    pb_paragraph.metadata.CopyFrom(metadata)
-
-                    yield pb_paragraph
-
-    async def iterate_fields(self, enabled_metadata: EnabledMetadata) -> AsyncIterator[TrainField]:
-        fields = await self.get_fields(force=True)
-        metadata = TrainMetadata()
-        if enabled_metadata.labels:
-            if self.basic is None:
-                self.basic = await self.get_basic()
-            if self.basic is not None:
-                metadata.labels.resource.extend(self.basic.usermetadata.classifications)
-
-        for (type_id, field_id), field in fields.items():
-            fieldid = FieldID(field_type=type_id, field=field_id)
-            fm = await field.get_field_metadata()
-            extracted_text = None
-
-            if enabled_metadata.text:
-                extracted_text = await field.get_extracted_text()
-
-            if fm is None:
-                continue
-
-            field_metadatas: list[tuple[Optional[str], FieldMetadata]] = [(None, fm.metadata)]
-            for subfield_metadata, splitted_metadata in fm.split_metadata.items():
-                field_metadatas.append((subfield_metadata, splitted_metadata))
-
-            for subfield, splitted_metadata in field_metadatas:
-                if enabled_metadata.labels:
-                    metadata.labels.ClearField("field")
-                    metadata.labels.field.extend(splitted_metadata.classifications)
-
-                if extracted_text is not None:
-                    if subfield is not None:
-                        metadata.text = extracted_text.split_text[subfield]
-                    else:
-                        metadata.text = extracted_text.text
-
-                if enabled_metadata.entities:
-                    metadata.ClearField("entities")
-                    _update_entities_dict(metadata.entities, splitted_metadata)
-
-                pb_field = TrainField()
-                pb_field.uuid = self.uuid
-                pb_field.field.CopyFrom(fieldid)
-                pb_field.metadata.CopyFrom(metadata)
-                yield pb_field
-
-    async def generate_train_resource(self, enabled_metadata: EnabledMetadata) -> TrainResource:
-        fields = await self.get_fields(force=True)
-        metadata = TrainMetadata()
-        if enabled_metadata.labels:
-            if self.basic is None:
-                self.basic = await self.get_basic()
-            if self.basic is not None:
-                metadata.labels.resource.extend(self.basic.usermetadata.classifications)
-
-        metadata.labels.ClearField("field")
-        metadata.ClearField("entities")
-
-        for (_, _), field in fields.items():
-            extracted_text = None
-            fm = await field.get_field_metadata()
-
-            if enabled_metadata.text:
-                extracted_text = await field.get_extracted_text()
-
-            if extracted_text is not None:
-                metadata.text += extracted_text.text
-                for text in extracted_text.split_text.values():
-                    metadata.text += f" {text}"
-
-            if fm is None:
-                continue
-
-            field_metadatas: list[tuple[Optional[str], FieldMetadata]] = [(None, fm.metadata)]
-            for subfield_metadata, splitted_metadata in fm.split_metadata.items():
-                field_metadatas.append((subfield_metadata, splitted_metadata))
-
-            for _, splitted_metadata in field_metadatas:
-                if enabled_metadata.labels:
-                    metadata.labels.field.extend(splitted_metadata.classifications)
-
-                if enabled_metadata.entities:
-                    _update_entities_dict(metadata.entities, splitted_metadata)
-
-        pb_resource = TrainResource()
-        pb_resource.uuid = self.uuid
-        if self.basic is not None:
-            pb_resource.title = self.basic.title
-            pb_resource.icon = self.basic.icon
-            pb_resource.slug = self.basic.slug
-            pb_resource.modified.CopyFrom(self.basic.modified)
-            pb_resource.created.CopyFrom(self.basic.created)
-        pb_resource.metadata.CopyFrom(metadata)
-        return pb_resource
-
 
 async def get_file_page_positions(field: File) -> FilePagePositions:
     positions: FilePagePositions = {}
@@ -1307,24 +1011,6 @@ def add_field_classifications(basic: PBBasic, fcmw: FieldComputedMetadataWrapper
     return True
 
 
-def add_entities_to_metadata(entities: dict[str, str], local_text: str, metadata: TrainMetadata) -> None:
-    for entity_key, entity_value in entities.items():
-        if entity_key not in local_text:
-            # Add the entity only if found in text
-            continue
-        metadata.entities[entity_key] = entity_value
-
-        # Add positions for the entity relative to the local text
-        poskey = f"{entity_value}/{entity_key}"
-        metadata.entity_positions[poskey].entity = entity_key
-        last_occurrence_end = 0
-        for _ in range(local_text.count(entity_key)):
-            start = local_text.index(entity_key, last_occurrence_end)
-            end = start + len(entity_key)
-            metadata.entity_positions[poskey].positions.append(TrainPosition(start=start, end=end))
-            last_occurrence_end = end
-
-
 def maybe_update_basic_summary(basic: PBBasic, summary_text: str) -> bool:
     if basic.summary or not summary_text:
         return False
@@ -1393,23 +1079,3 @@ def extract_field_metadata_languages(
     for _, splitted_metadata in field_metadata.metadata.split_metadata.items():
         languages.add(splitted_metadata.language)
     return list(languages)
-
-
-def _update_entities_dict(target_entites_dict: MutableMapping[str, str], field_metadata: FieldMetadata):
-    """
-    Update the entities dict with the entities from the field metadata.
-    Method created to ease the transition from legacy ner field to new entities field.
-    """
-    # Data Augmentation + Processor entities
-    # This will overwrite entities detected from more than one data augmentation task
-    # TODO: Change TrainMetadata proto to accept multiple entities with the same text
-    entity_map = {
-        entity.text: entity.label
-        for data_augmentation_task_id, entities_wrapper in field_metadata.entities.items()
-        for entity in entities_wrapper.entities
-    }
-    target_entites_dict.update(entity_map)
-
-    # Legacy processor entities
-    # TODO: Remove once processor doesn't use this anymore and remove the positions and ner fields from the message
-    target_entites_dict.update(field_metadata.ner)
@@ -17,10 +17,10 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
-
 from typing import Optional
 
 import nucliadb_models as models
+from nucliadb.common import datamanagers
 from nucliadb.common.maindb.driver import Transaction
 from nucliadb.common.maindb.utils import get_driver
 from nucliadb.common.models_utils import from_proto
@@ -83,7 +83,18 @@ async def set_resource_field_extracted_data(
         field_data.large_metadata = from_proto.large_computed_metadata(data_lcm)
 
     if ExtractedDataTypeName.VECTOR in wanted_extracted_data:
-        data_vec = await field.get_vectors()
+        # XXX: our extracted API is not vectorset-compatible, so we'll get the
+        # first vectorset and return the values. Ideally, we should provide a
+        # way to select a vectorset
+        vectorset_id = None
+        async with datamanagers.with_ro_transaction() as txn:
+            async for vectorset_id, vs in datamanagers.vectorsets.iter(
+                txn=txn,
+                kbid=field.resource.kb.kbid,
+            ):
+                break
+        assert vectorset_id is not None, "All KBs must have at least a vectorset"
+        data_vec = await field.get_vectors(vectorset_id, vs.storage_key_kind)
         if data_vec is not None:
             field_data.vectors = from_proto.vector_object(data_vec)
 
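The hunk above grabs whichever vectorset the iterator yields first, since the extracted-data API cannot yet select one. A tiny self-contained illustration of that pattern; the iterator and its values are made up for the example and are not nucliadb's:

    import asyncio
    from typing import AsyncIterator


    async def iter_vectorsets(kbid: str) -> AsyncIterator[tuple[str, dict]]:
        # stand-in for datamanagers.vectorsets.iter(txn=..., kbid=...)
        yield "multilingual-2024", {"storage_key_kind": "VECTORSET_PREFIX"}
        yield "english-2024", {"storage_key_kind": "VECTORSET_PREFIX"}


    async def first_vectorset(kbid: str) -> tuple[str, dict]:
        vectorset_id, vs = None, None
        async for vectorset_id, vs in iter_vectorsets(kbid):
            break  # the API is not vectorset-aware yet, so the first one wins
        assert vectorset_id is not None, "every KB is expected to have at least one vectorset"
        return vectorset_id, vs


    print(asyncio.run(first_vectorset("my-kb")))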
@@ -485,6 +485,7 @@ class WriterServicer(writer_pb2_grpc.WriterServicer):
                 vector_dimension=request.vector_dimension,
             ),
             matryoshka_dimensions=request.matryoshka_dimensions,
+            storage_key_kind=VectorSetConfig.StorageKeyKind.VECTORSET_PREFIX,
         )
         response = NewVectorSetResponse()
         try:
@@ -513,6 +514,9 @@ class WriterServicer(writer_pb2_grpc.WriterServicer):
                 kbobj = KnowledgeBoxORM(txn, self.storage, request.kbid)
                 await kbobj.delete_vectorset(request.vectorset_id)
                 await txn.commit()
+        except VectorSetConflict as exc:
+            response.status = DelVectorSetResponse.Status.ERROR
+            response.details = str(exc)
         except Exception as exc:
             errors.capture_exception(exc)
             logger.error("Error in ingest gRPC while deleting a vectorset", exc_info=True)
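For context, a toy sketch of the error mapping added above: a VectorSetConflict raised by the deletion is surfaced in the response status and details instead of falling through to the generic error handler. The classes below are simplified stand-ins, not the real gRPC types:

    class VectorSetConflict(Exception):
        pass


    class DelVectorSetResponse:
        OK, ERROR = "OK", "ERROR"

        def __init__(self) -> None:
            self.status = self.OK
            self.details = ""


    def delete_vectorset(delete) -> DelVectorSetResponse:
        response = DelVectorSetResponse()
        try:
            delete()
        except VectorSetConflict as exc:
            # expected business error: report it to the caller
            response.status = DelVectorSetResponse.ERROR
            response.details = str(exc)
        except Exception:
            # unexpected failures still go through the generic error path
            response.status = DelVectorSetResponse.ERROR
            response.details = "internal error"
        return response


    def refuse():
        raise VectorSetConflict("KBs with a single vectorset cannot delete it")


    print(delete_vectorset(refuse).details)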
@@ -26,6 +26,7 @@ from nucliadb.common.cluster.utils import setup_cluster, teardown_cluster
 from nucliadb.common.maindb.driver import Driver
 from nucliadb.common.maindb.utils import setup_driver, teardown_driver
 from nucliadb.ingest import SERVICE_NAME, logger
+from nucliadb.ingest.fields.base import Field
 from nucliadb.ingest.orm.knowledgebox import (
     KB_TO_DELETE,
     KB_TO_DELETE_BASE,
@@ -35,6 +36,7 @@ from nucliadb.ingest.orm.knowledgebox import (
     RESOURCE_TO_DELETE_STORAGE_BASE,
     KnowledgeBox,
 )
+from nucliadb_protos.knowledgebox_pb2 import VectorSetConfig, VectorSetPurge
 from nucliadb_telemetry import errors
 from nucliadb_telemetry.logs import setup_logging
 from nucliadb_utils.storages.storage import Storage
@@ -201,8 +203,8 @@ async def purge_kb_vectorsets(driver: Driver, storage: Storage):
     """
     logger.info("START PURGING KB VECTORSETS")
 
-    purged = []
-    async for key in _iter_keys(driver, KB_VECTORSET_TO_DELETE_BASE):
+    vectorsets_to_delete = [key async for key in _iter_keys(driver, KB_VECTORSET_TO_DELETE_BASE)]
+    for key in vectorsets_to_delete:
         logger.info(f"Purging vectorsets {key}")
         try:
             _base, kbid, vectorset = key.lstrip("/").split("/")
@@ -211,13 +213,38 @@ async def purge_kb_vectorsets(driver: Driver, storage: Storage):
             continue
 
         try:
+            async with driver.transaction(read_only=True) as txn:
+                value = await txn.get(key)
+                assert value is not None, "Key must exist or we wouldn't had fetch it iterating keys"
+                purge_payload = VectorSetPurge()
+                purge_payload.ParseFromString(value)
+
+            fields: list[Field] = []
             async with driver.transaction(read_only=True) as txn:
                 kb = KnowledgeBox(txn, storage, kbid)
                 async for resource in kb.iterate_resources():
-                    fields = await resource.get_fields(force=True)
+                    fields.extend((await resource.get_fields(force=True)).values())
+
             # we don't need the maindb transaction anymore to remove vectors from storage
-            for field in fields.values():
-                await field.delete_vectors(vectorset)
+            for field in fields:
+                if purge_payload.storage_key_kind == VectorSetConfig.StorageKeyKind.UNSET:
+                    # Bw/c for purge before adding purge payload. We assume
+                    # there's only 2 kinds of KBs: with one or with more than
+                    # one vectorset. KBs with one vectorset are not allowed to
+                    # delete their vectorset, so we wouldn't be here. It has to
+                    # be a KB with multiple, so the storage key kind has to be
+                    # this:
+                    await field.delete_vectors(
+                        vectorset, VectorSetConfig.StorageKeyKind.VECTORSET_PREFIX
+                    )
+                else:
+                    await field.delete_vectors(vectorset, purge_payload.storage_key_kind)
+
+            # Finally, delete the key
+            async with driver.transaction() as txn:
+                await txn.delete(key)
+                await txn.commit()
+
         except Exception as exc:
             errors.capture_exception(exc)
             logger.error(
@@ -227,13 +254,6 @@ async def purge_kb_vectorsets(driver: Driver, storage: Storage):
             )
             continue
 
-        purged.append(key)
-
-    async with driver.transaction() as txn:
-        for key in purged:
-            await txn.delete(key)
-        await txn.commit()
-
     logger.info("FINISH PURGING KB VECTORSETS")
 
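In the purge path above, the deletion marker now carries a VectorSetPurge payload whose storage_key_kind is forwarded to delete_vectors, the keys are materialized up front so each one can be deleted in its own transaction once its vectors are gone, and markers written before the payload existed (which parse as UNSET) are assumed to belong to multi-vectorset KBs, which keep their vectors under a per-vectorset prefix. A minimal sketch of that fallback, using a stand-in enum rather than the real proto:

    from enum import Enum


    class StorageKeyKind(Enum):
        UNSET = 0
        LEGACY = 1
        VECTORSET_PREFIX = 2


    def storage_key_kind_for_purge(payload_kind: StorageKeyKind) -> StorageKeyKind:
        if payload_kind is StorageKeyKind.UNSET:
            # Bw/c: only KBs with several vectorsets may delete one, and those keep
            # their vectors under a per-vectorset prefix
            return StorageKeyKind.VECTORSET_PREFIX
        return payload_kind


    assert storage_key_kind_for_purge(StorageKeyKind.UNSET) is StorageKeyKind.VECTORSET_PREFIX
    assert storage_key_kind_for_purge(StorageKeyKind.LEGACY) is StorageKeyKind.LEGACY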