faiss 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +1 -2
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -22
  13. data/vendor/faiss/faiss/Clustering.h +40 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +26 -12
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +40 -10
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
  34. data/vendor/faiss/faiss/IndexHNSW.h +62 -49
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
  38. data/vendor/faiss/faiss/IndexIVF.h +46 -6
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  59. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  60. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
  61. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  62. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  63. data/vendor/faiss/faiss/IndexNSG.h +11 -11
  64. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  65. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  66. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  67. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  68. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  69. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  70. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  71. data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
  72. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  73. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  76. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  77. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  78. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  79. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  81. data/vendor/faiss/faiss/MetricType.h +7 -2
  82. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  83. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  84. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  85. data/vendor/faiss/faiss/clone_index.h +3 -0
  86. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  87. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  88. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  89. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  90. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
  91. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
  96. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  107. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
  108. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  110. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  111. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
  112. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  113. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  114. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  115. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  116. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  117. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  118. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  119. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  121. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
  123. data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  125. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  126. data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
  127. data/vendor/faiss/faiss/impl/HNSW.h +52 -30
  128. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
  130. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
  132. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  133. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  134. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
  138. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
  144. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
  145. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  146. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  147. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  148. data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
  149. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  150. data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
  151. data/vendor/faiss/faiss/impl/io.cpp +23 -15
  152. data/vendor/faiss/faiss/impl/io.h +4 -4
  153. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  154. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  155. data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  159. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
  160. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  161. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  162. data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
  163. data/vendor/faiss/faiss/index_factory.cpp +41 -20
  164. data/vendor/faiss/faiss/index_io.h +12 -5
  165. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  166. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  167. data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
  168. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
  169. data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
  170. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  172. data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
  173. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  174. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  175. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  176. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  177. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  178. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  179. data/vendor/faiss/faiss/utils/distances.cpp +147 -123
  180. data/vendor/faiss/faiss/utils/distances.h +86 -9
  181. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  182. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  183. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  184. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  185. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  186. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  187. data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
  188. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  189. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  190. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  191. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  192. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  193. data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
  194. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  195. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  196. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  197. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
  198. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
  199. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  200. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  201. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  202. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  203. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  204. data/vendor/faiss/faiss/utils/random.h +25 -0
  205. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  206. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  207. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  208. data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
  209. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  210. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  211. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  212. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +120 -7
  214. data/vendor/faiss/faiss/utils/utils.h +60 -20
  215. metadata +23 -4
  216. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -140,8 +140,12 @@ std::map<std::string, ScalarQuantizer::QuantizerType> sq_types = {
140
140
  {"SQ4", ScalarQuantizer::QT_4bit},
141
141
  {"SQ6", ScalarQuantizer::QT_6bit},
142
142
  {"SQfp16", ScalarQuantizer::QT_fp16},
143
+ {"SQbf16", ScalarQuantizer::QT_bf16},
144
+ {"SQ8_direct_signed", ScalarQuantizer::QT_8bit_direct_signed},
145
+ {"SQ8_direct", ScalarQuantizer::QT_8bit_direct},
143
146
  };
144
- const std::string sq_pattern = "(SQ4|SQ8|SQ6|SQfp16)";
147
+ const std::string sq_pattern =
148
+ "(SQ4|SQ8|SQ6|SQfp16|SQbf16|SQ8_direct_signed|SQ8_direct)";
145
149
 
146
150
  std::map<std::string, AdditiveQuantizer::Search_type_t> aq_search_type = {
147
151
  {"_Nfloat", AdditiveQuantizer::ST_norm_float},
@@ -216,12 +220,25 @@ VectorTransform* parse_VectorTransform(const std::string& description, int d) {
216
220
  return new RemapDimensionsTransform(d, std::max(d_out, d), false);
217
221
  }
218
222
  return nullptr;
219
- };
223
+ }
220
224
 
221
225
  /***************************************************************
222
226
  * Parse IndexIVF
223
227
  */
224
228
 
229
+ size_t parse_nlist(std::string s) {
230
+ size_t multiplier = 1;
231
+ if (s.back() == 'k') {
232
+ s.pop_back();
233
+ multiplier = 1024;
234
+ }
235
+ if (s.back() == 'M') {
236
+ s.pop_back();
237
+ multiplier = 1024 * 1024;
238
+ }
239
+ return std::stoi(s) * multiplier;
240
+ }
241
+
225
242
  // parsing guard + function
226
243
  Index* parse_coarse_quantizer(
227
244
  const std::string& description,
@@ -236,8 +253,8 @@ Index* parse_coarse_quantizer(
236
253
  };
237
254
  use_2layer = false;
238
255
 
239
- if (match("IVF([0-9]+)")) {
240
- nlist = std::stoi(sm[1].str());
256
+ if (match("IVF([0-9]+[kM]?)")) {
257
+ nlist = parse_nlist(sm[1].str());
241
258
  return new IndexFlat(d, mt);
242
259
  }
243
260
  if (match("IMI2x([0-9]+)")) {
@@ -248,18 +265,18 @@ Index* parse_coarse_quantizer(
248
265
  nlist = (size_t)1 << (2 * nbit);
249
266
  return new MultiIndexQuantizer(d, 2, nbit);
250
267
  }
251
- if (match("IVF([0-9]+)_HNSW([0-9]*)")) {
252
- nlist = std::stoi(sm[1].str());
268
+ if (match("IVF([0-9]+[kM]?)_HNSW([0-9]*)")) {
269
+ nlist = parse_nlist(sm[1].str());
253
270
  int hnsw_M = sm[2].length() > 0 ? std::stoi(sm[2]) : 32;
254
271
  return new IndexHNSWFlat(d, hnsw_M, mt);
255
272
  }
256
- if (match("IVF([0-9]+)_NSG([0-9]+)")) {
257
- nlist = std::stoi(sm[1].str());
273
+ if (match("IVF([0-9]+[kM]?)_NSG([0-9]+)")) {
274
+ nlist = parse_nlist(sm[1].str());
258
275
  int R = std::stoi(sm[2]);
259
276
  return new IndexNSGFlat(d, R, mt);
260
277
  }
261
- if (match("IVF([0-9]+)\\(Index([0-9])\\)")) {
262
- nlist = std::stoi(sm[1].str());
278
+ if (match("IVF([0-9]+[kM]?)\\(Index([0-9])\\)")) {
279
+ nlist = parse_nlist(sm[1].str());
263
280
  int no = std::stoi(sm[2].str());
264
281
  FAISS_ASSERT(no >= 0 && no < parenthesis_indexes.size());
265
282
  return parenthesis_indexes[no].release();
@@ -440,11 +457,13 @@ IndexHNSW* parse_IndexHNSW(
440
457
  if (match("Flat|")) {
441
458
  return new IndexHNSWFlat(d, hnsw_M, mt);
442
459
  }
443
- if (match("PQ([0-9]+)(np)?")) {
460
+
461
+ if (match("PQ([0-9]+)(x[0-9]+)?(np)?")) {
444
462
  int M = std::stoi(sm[1].str());
445
- IndexHNSWPQ* ipq = new IndexHNSWPQ(d, M, hnsw_M);
463
+ int nbit = mres_to_int(sm[2], 8, 1);
464
+ IndexHNSWPQ* ipq = new IndexHNSWPQ(d, M, hnsw_M, nbit);
446
465
  dynamic_cast<IndexPQ*>(ipq->storage)->do_polysemous_training =
447
- sm[2].str() != "np";
466
+ sm[3].str() != "np";
448
467
  return ipq;
449
468
  }
450
469
  if (match(sq_pattern)) {
@@ -490,11 +509,12 @@ IndexNSG* parse_IndexNSG(
490
509
  if (match("Flat|")) {
491
510
  return new IndexNSGFlat(d, nsg_R, mt);
492
511
  }
493
- if (match("PQ([0-9]+)(np)?")) {
512
+ if (match("PQ([0-9]+)(x[0-9]+)?(np)?")) {
494
513
  int M = std::stoi(sm[1].str());
495
- IndexNSGPQ* ipq = new IndexNSGPQ(d, M, nsg_R);
514
+ int nbit = mres_to_int(sm[2], 8, 1);
515
+ IndexNSGPQ* ipq = new IndexNSGPQ(d, M, nsg_R, nbit);
496
516
  dynamic_cast<IndexPQ*>(ipq->storage)->do_polysemous_training =
497
- sm[2].str() != "np";
517
+ sm[3].str() != "np";
498
518
  return ipq;
499
519
  }
500
520
  if (match(sq_pattern)) {
@@ -523,11 +543,12 @@ Index* parse_other_indexes(
523
543
  }
524
544
 
525
545
  // IndexLSH
526
- if (match("LSH(r?)(t?)")) {
527
- bool rotate_data = sm[1].length() > 0;
528
- bool train_thresholds = sm[2].length() > 0;
546
+ if (match("LSH([0-9]*)(r?)(t?)")) {
547
+ int nbits = sm[1].length() > 0 ? std::stoi(sm[1].str()) : d;
548
+ bool rotate_data = sm[2].length() > 0;
549
+ bool train_thresholds = sm[3].length() > 0;
529
550
  FAISS_THROW_IF_NOT(metric == METRIC_L2);
530
- return new IndexLSH(d, d, rotate_data, train_thresholds);
551
+ return new IndexLSH(d, nbits, rotate_data, train_thresholds);
531
552
  }
532
553
 
533
554
  // IndexLattice
@@ -5,8 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  // I/O code for indexes
11
9
 
12
10
  #ifndef FAISS_INDEX_IO_H
@@ -35,9 +33,12 @@ struct IOReader;
35
33
  struct IOWriter;
36
34
  struct InvertedLists;
37
35
 
38
- void write_index(const Index* idx, const char* fname);
39
- void write_index(const Index* idx, FILE* f);
40
- void write_index(const Index* idx, IOWriter* writer);
36
+ /// skip the storage for graph-based indexes
37
+ const int IO_FLAG_SKIP_STORAGE = 1;
38
+
39
+ void write_index(const Index* idx, const char* fname, int io_flags = 0);
40
+ void write_index(const Index* idx, FILE* f, int io_flags = 0);
41
+ void write_index(const Index* idx, IOWriter* writer, int io_flags = 0);
41
42
 
42
43
  void write_index_binary(const IndexBinary* idx, const char* fname);
43
44
  void write_index_binary(const IndexBinary* idx, FILE* f);
@@ -52,6 +53,12 @@ const int IO_FLAG_ONDISK_SAME_DIR = 4;
52
53
  const int IO_FLAG_SKIP_IVF_DATA = 8;
53
54
  // don't initialize precomputed table after loading
54
55
  const int IO_FLAG_SKIP_PRECOMPUTE_TABLE = 16;
56
+ // don't compute the sdc table for PQ-based indices
57
+ // this will prevent distances from being computed
58
+ // between elements in the index. For indices like HNSWPQ,
59
+ // this will prevent graph building because sdc
60
+ // computations are required to construct the graph
61
+ const int IO_FLAG_PQ_SKIP_SDC_TABLE = 32;
55
62
  // try to memmap data (useful to load an ArrayInvertedLists as an
56
63
  // OnDiskInvertedLists)
57
64
  const int IO_FLAG_MMAP = IO_FLAG_SKIP_IVF_DATA | 0x646f0000;
@@ -9,6 +9,7 @@
9
9
 
10
10
  #include <faiss/impl/CodePacker.h>
11
11
  #include <faiss/impl/FaissAssert.h>
12
+ #include <faiss/impl/IDSelector.h>
12
13
 
13
14
  #include <faiss/impl/io.h>
14
15
  #include <faiss/impl/io_macros.h>
@@ -54,7 +55,9 @@ size_t BlockInvertedLists::add_entries(
54
55
  codes[list_no].resize(n_block * block_size);
55
56
  if (o % block_size == 0) {
56
57
  // copy whole blocks
57
- memcpy(&codes[list_no][o * code_size], code, n_block * block_size);
58
+ memcpy(&codes[list_no][o * packer->code_size],
59
+ code,
60
+ n_block * block_size);
58
61
  } else {
59
62
  FAISS_THROW_IF_NOT_MSG(packer, "missing code packer");
60
63
  std::vector<uint8_t> buffer(packer->code_size);
@@ -76,6 +79,29 @@ const uint8_t* BlockInvertedLists::get_codes(size_t list_no) const {
76
79
  return codes[list_no].get();
77
80
  }
78
81
 
82
+ size_t BlockInvertedLists::remove_ids(const IDSelector& sel) {
83
+ idx_t nremove = 0;
84
+ #pragma omp parallel for
85
+ for (idx_t i = 0; i < nlist; i++) {
86
+ std::vector<uint8_t> buffer(packer->code_size);
87
+ idx_t l = ids[i].size(), j = 0;
88
+ while (j < l) {
89
+ if (sel.is_member(ids[i][j])) {
90
+ l--;
91
+ ids[i][j] = ids[i][l];
92
+ packer->unpack_1(codes[i].data(), l, buffer.data());
93
+ packer->pack_1(buffer.data(), j, codes[i].data());
94
+ } else {
95
+ j++;
96
+ }
97
+ }
98
+ resize(i, l);
99
+ nremove += ids[i].size() - l;
100
+ }
101
+
102
+ return nremove;
103
+ }
104
+
79
105
  const idx_t* BlockInvertedLists::get_ids(size_t list_no) const {
80
106
  assert(list_no < nlist);
81
107
  return ids[list_no].data();
@@ -101,13 +127,7 @@ void BlockInvertedLists::update_entries(
101
127
  size_t,
102
128
  const idx_t*,
103
129
  const uint8_t*) {
104
- FAISS_THROW_MSG("not impemented");
105
- /*
106
- assert (list_no < nlist);
107
- assert (n_entry + offset <= ids[list_no].size());
108
- memcpy (&ids[list_no][offset], ids_in, sizeof(ids_in[0]) * n_entry);
109
- memcpy (&codes[list_no][offset * code_size], codes_in, code_size * n_entry);
110
- */
130
+ FAISS_THROW_MSG("not implemented");
111
131
  }
112
132
 
113
133
  BlockInvertedLists::~BlockInvertedLists() {
@@ -15,6 +15,7 @@
15
15
  namespace faiss {
16
16
 
17
17
  struct CodePacker;
18
+ struct IDSelector;
18
19
 
19
20
  /** Inverted Lists that are organized by blocks.
20
21
  *
@@ -47,6 +48,8 @@ struct BlockInvertedLists : InvertedLists {
47
48
  size_t list_size(size_t list_no) const override;
48
49
  const uint8_t* get_codes(size_t list_no) const override;
49
50
  const idx_t* get_ids(size_t list_no) const override;
51
+ /// remove ids from the InvertedLists
52
+ size_t remove_ids(const IDSelector& sel);
50
53
 
51
54
  // works only on empty BlockInvertedLists
52
55
  // the codes should be of size ceil(n_entry / n_per_block) * block_size
@@ -15,6 +15,7 @@
15
15
  #include <faiss/impl/AuxIndexStructures.h>
16
16
  #include <faiss/impl/FaissAssert.h>
17
17
  #include <faiss/impl/IDSelector.h>
18
+ #include <faiss/invlists/BlockInvertedLists.h>
18
19
 
19
20
  namespace faiss {
20
21
 
@@ -148,8 +149,12 @@ size_t DirectMap::remove_ids(const IDSelector& sel, InvertedLists* invlists) {
148
149
  std::vector<idx_t> toremove(nlist);
149
150
 
150
151
  size_t nremove = 0;
151
-
152
+ BlockInvertedLists* block_invlists =
153
+ dynamic_cast<BlockInvertedLists*>(invlists);
152
154
  if (type == NoMap) {
155
+ if (block_invlists != nullptr) {
156
+ return block_invlists->remove_ids(sel);
157
+ }
153
158
  // exhaustive scan of IVF
154
159
  #pragma omp parallel for
155
160
  for (idx_t i = 0; i < nlist; i++) {
@@ -178,6 +183,9 @@ size_t DirectMap::remove_ids(const IDSelector& sel, InvertedLists* invlists) {
178
183
  }
179
184
  }
180
185
  } else if (type == Hashtable) {
186
+ FAISS_THROW_IF_MSG(
187
+ block_invlists,
188
+ "remove with hashtable is not supported with BlockInvertedLists");
181
189
  const IDSelectorArray* sela =
182
190
  dynamic_cast<const IDSelectorArray*>(&sel);
183
191
  FAISS_THROW_IF_NOT_MSG(
@@ -199,7 +207,7 @@ size_t DirectMap::remove_ids(const IDSelector& sel, InvertedLists* invlists) {
199
207
  last_id,
200
208
  ScopedCodes(invlists, list_no, last).get());
201
209
  // update hash entry for last element
202
- hashtable[last_id] = list_no << 32 | offset;
210
+ hashtable[last_id] = lo_build(list_no, offset);
203
211
  }
204
212
  invlists->resize(list_no, last);
205
213
  nremove++;
@@ -5,8 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #include <faiss/invlists/InvertedLists.h>
11
9
 
12
10
  #include <cstdio>
@@ -24,17 +22,10 @@ InvertedListsIterator::~InvertedListsIterator() {}
24
22
  ******************************************/
25
23
 
26
24
  InvertedLists::InvertedLists(size_t nlist, size_t code_size)
27
- : nlist(nlist), code_size(code_size), use_iterator(false) {}
25
+ : nlist(nlist), code_size(code_size) {}
28
26
 
29
27
  InvertedLists::~InvertedLists() {}
30
28
 
31
- bool InvertedLists::is_empty(size_t list_no) const {
32
- return use_iterator
33
- ? !std::unique_ptr<InvertedListsIterator>(get_iterator(list_no))
34
- ->is_available()
35
- : list_size(list_no) == 0;
36
- }
37
-
38
29
  idx_t InvertedLists::get_single_id(size_t list_no, size_t offset) const {
39
30
  assert(offset < list_size(list_no));
40
31
  const idx_t* ids = get_ids(list_no);
@@ -58,7 +49,8 @@ const uint8_t* InvertedLists::get_single_code(size_t list_no, size_t offset)
58
49
  size_t InvertedLists::add_entry(
59
50
  size_t list_no,
60
51
  idx_t theid,
61
- const uint8_t* code) {
52
+ const uint8_t* code,
53
+ void* /*inverted_list_context*/) {
62
54
  return add_entries(list_no, 1, &theid, code);
63
55
  }
64
56
 
@@ -76,10 +68,6 @@ void InvertedLists::reset() {
76
68
  }
77
69
  }
78
70
 
79
- InvertedListsIterator* InvertedLists::get_iterator(size_t /*list_no*/) const {
80
- FAISS_THROW_MSG("get_iterator is not supported");
81
- }
82
-
83
71
  void InvertedLists::merge_from(InvertedLists* oivf, size_t add_id) {
84
72
  #pragma omp parallel for
85
73
  for (idx_t i = 0; i < nlist; i++) {
@@ -229,6 +217,54 @@ size_t InvertedLists::compute_ntotal() const {
229
217
  return tot;
230
218
  }
231
219
 
220
+ bool InvertedLists::is_empty(size_t list_no, void* inverted_list_context)
221
+ const {
222
+ if (use_iterator) {
223
+ return !std::unique_ptr<InvertedListsIterator>(
224
+ get_iterator(list_no, inverted_list_context))
225
+ ->is_available();
226
+ } else {
227
+ FAISS_THROW_IF_NOT(inverted_list_context == nullptr);
228
+ return list_size(list_no) == 0;
229
+ }
230
+ }
231
+
232
+ // implemnent iterator on top of get_codes / get_ids
233
+ namespace {
234
+
235
+ struct CodeArrayIterator : InvertedListsIterator {
236
+ size_t list_size;
237
+ size_t code_size;
238
+ InvertedLists::ScopedCodes codes;
239
+ InvertedLists::ScopedIds ids;
240
+ size_t idx = 0;
241
+
242
+ CodeArrayIterator(const InvertedLists* il, size_t list_no)
243
+ : list_size(il->list_size(list_no)),
244
+ code_size(il->code_size),
245
+ codes(il, list_no),
246
+ ids(il, list_no) {}
247
+
248
+ bool is_available() const override {
249
+ return idx < list_size;
250
+ }
251
+ void next() override {
252
+ idx++;
253
+ }
254
+ std::pair<idx_t, const uint8_t*> get_id_and_codes() override {
255
+ return {ids[idx], codes.get() + code_size * idx};
256
+ }
257
+ };
258
+
259
+ } // namespace
260
+
261
+ InvertedListsIterator* InvertedLists::get_iterator(
262
+ size_t list_no,
263
+ void* inverted_list_context) const {
264
+ FAISS_THROW_IF_NOT(inverted_list_context == nullptr);
265
+ return new CodeArrayIterator(this, list_no);
266
+ }
267
+
232
268
  /*****************************************
233
269
  * ArrayInvertedLists implementation
234
270
  ******************************************/
@@ -260,6 +296,12 @@ size_t ArrayInvertedLists::list_size(size_t list_no) const {
260
296
  return ids[list_no].size();
261
297
  }
262
298
 
299
+ bool ArrayInvertedLists::is_empty(size_t list_no, void* inverted_list_context)
300
+ const {
301
+ FAISS_THROW_IF_NOT(inverted_list_context == nullptr);
302
+ return ids[list_no].size() == 0;
303
+ }
304
+
263
305
  const uint8_t* ArrayInvertedLists::get_codes(size_t list_no) const {
264
306
  assert(list_no < nlist);
265
307
  return codes[list_no].data();
@@ -287,6 +329,20 @@ void ArrayInvertedLists::update_entries(
287
329
  memcpy(&codes[list_no][offset * code_size], codes_in, code_size * n_entry);
288
330
  }
289
331
 
332
+ void ArrayInvertedLists::permute_invlists(const idx_t* map) {
333
+ std::vector<std::vector<uint8_t>> new_codes(nlist);
334
+ std::vector<std::vector<idx_t>> new_ids(nlist);
335
+
336
+ for (size_t i = 0; i < nlist; i++) {
337
+ size_t o = map[i];
338
+ FAISS_THROW_IF_NOT(o < nlist);
339
+ std::swap(new_codes[i], codes[o]);
340
+ std::swap(new_ids[i], ids[o]);
341
+ }
342
+ std::swap(codes, new_codes);
343
+ std::swap(ids, new_ids);
344
+ }
345
+
290
346
  ArrayInvertedLists::~ArrayInvertedLists() {}
291
347
 
292
348
  /*****************************************************************
@@ -423,7 +479,7 @@ idx_t translate_list_no(const SliceInvertedLists* sil, idx_t list_no) {
423
479
  return list_no + sil->i0;
424
480
  }
425
481
 
426
- }; // namespace
482
+ } // namespace
427
483
 
428
484
  SliceInvertedLists::SliceInvertedLists(
429
485
  const InvertedLists* il,
@@ -508,7 +564,7 @@ idx_t sum_il_sizes(int nil, const InvertedLists** ils_in) {
508
564
  return tot;
509
565
  }
510
566
 
511
- }; // namespace
567
+ } // namespace
512
568
 
513
569
  VStackInvertedLists::VStackInvertedLists(int nil, const InvertedLists** ils_in)
514
570
  : ReadOnlyInvertedLists(
@@ -37,7 +37,9 @@ struct InvertedListsIterator {
37
37
  struct InvertedLists {
38
38
  size_t nlist; ///< number of possible key values
39
39
  size_t code_size; ///< code size per vector in bytes
40
- bool use_iterator;
40
+
41
+ /// request to use iterator rather than get_codes / get_ids
42
+ bool use_iterator = false;
41
43
 
42
44
  InvertedLists(size_t nlist, size_t code_size);
43
45
 
@@ -50,15 +52,9 @@ struct InvertedLists {
50
52
  /*************************
51
53
  * Read only functions */
52
54
 
53
- // check if the list is empty
54
- bool is_empty(size_t list_no) const;
55
-
56
55
  /// get the size of a list
57
56
  virtual size_t list_size(size_t list_no) const = 0;
58
57
 
59
- /// get iterable for lists that use_iterator
60
- virtual InvertedListsIterator* get_iterator(size_t list_no) const;
61
-
62
58
  /** get the codes for an inverted list
63
59
  * must be released by release_codes
64
60
  *
@@ -90,11 +86,27 @@ struct InvertedLists {
90
86
  /// a list can be -1 hence the signed long
91
87
  virtual void prefetch_lists(const idx_t* list_nos, int nlist) const;
92
88
 
89
+ /*****************************************
90
+ * Iterator interface (with context) */
91
+
92
+ /// check if the list is empty
93
+ virtual bool is_empty(size_t list_no, void* inverted_list_context = nullptr)
94
+ const;
95
+
96
+ /// get iterable for lists that use_iterator
97
+ virtual InvertedListsIterator* get_iterator(
98
+ size_t list_no,
99
+ void* inverted_list_context = nullptr) const;
100
+
93
101
  /*************************
94
102
  * writing functions */
95
103
 
96
104
  /// add one entry to an inverted list
97
- virtual size_t add_entry(size_t list_no, idx_t theid, const uint8_t* code);
105
+ virtual size_t add_entry(
106
+ size_t list_no,
107
+ idx_t theid,
108
+ const uint8_t* code,
109
+ void* inverted_list_context = nullptr);
98
110
 
99
111
  virtual size_t add_entries(
100
112
  size_t list_no,
@@ -253,6 +265,12 @@ struct ArrayInvertedLists : InvertedLists {
253
265
 
254
266
  void resize(size_t list_no, size_t new_size) override;
255
267
 
268
+ /// permute the inverted lists, map maps new_id to old_id
269
+ void permute_invlists(const idx_t* map);
270
+
271
+ bool is_empty(size_t list_no, void* inverted_list_context = nullptr)
272
+ const override;
273
+
256
274
  ~ArrayInvertedLists() override;
257
275
  };
258
276
 
@@ -394,8 +394,8 @@ const idx_t* OnDiskInvertedLists::get_ids(size_t list_no) const {
394
394
  return nullptr;
395
395
  }
396
396
 
397
- return (
398
- const idx_t*)(ptr + lists[list_no].offset + code_size * lists[list_no].capacity);
397
+ return (const idx_t*)(ptr + lists[list_no].offset +
398
+ code_size * lists[list_no].capacity);
399
399
  }
400
400
 
401
401
  void OnDiskInvertedLists::update_entries(
@@ -407,7 +407,7 @@ void OnDiskInvertedLists::update_entries(
407
407
  FAISS_THROW_IF_NOT(!read_only);
408
408
  if (n_entry == 0)
409
409
  return;
410
- const List& l = lists[list_no];
410
+ [[maybe_unused]] const List& l = lists[list_no];
411
411
  assert(n_entry + offset <= l.size);
412
412
  idx_t* ids = const_cast<idx_t*>(get_ids(list_no));
413
413
  memcpy(ids + offset, ids_in, sizeof(ids_in[0]) * n_entry);
@@ -524,7 +524,7 @@ void OnDiskInvertedLists::free_slot(size_t offset, size_t capacity) {
524
524
  it++;
525
525
  }
526
526
 
527
- size_t inf = 1UL << 60;
527
+ size_t inf = ((size_t)1) << 60;
528
528
 
529
529
  size_t end_prev = inf;
530
530
  if (it != slots.begin()) {
@@ -533,7 +533,7 @@ void OnDiskInvertedLists::free_slot(size_t offset, size_t capacity) {
533
533
  end_prev = prev->offset + prev->capacity;
534
534
  }
535
535
 
536
- size_t begin_next = 1L << 60;
536
+ size_t begin_next = ((size_t)1) << 60;
537
537
  if (it != slots.end()) {
538
538
  begin_next = it->offset;
539
539
  }
@@ -565,15 +565,16 @@ void OnDiskInvertedLists::free_slot(size_t offset, size_t capacity) {
565
565
  /*****************************************
566
566
  * Compact form
567
567
  *****************************************/
568
-
569
- size_t OnDiskInvertedLists::merge_from(
568
+ size_t OnDiskInvertedLists::merge_from_multiple(
570
569
  const InvertedLists** ils,
571
570
  int n_il,
571
+ bool shift_ids,
572
572
  bool verbose) {
573
573
  FAISS_THROW_IF_NOT_MSG(
574
574
  totsize == 0, "works only on an empty InvertedLists");
575
575
 
576
576
  std::vector<size_t> sizes(nlist);
577
+ std::vector<size_t> shift_id_offsets(n_il);
577
578
  for (int i = 0; i < n_il; i++) {
578
579
  const InvertedLists* il = ils[i];
579
580
  FAISS_THROW_IF_NOT(il->nlist == nlist && il->code_size == code_size);
@@ -581,6 +582,10 @@ size_t OnDiskInvertedLists::merge_from(
581
582
  for (size_t j = 0; j < nlist; j++) {
582
583
  sizes[j] += il->list_size(j);
583
584
  }
585
+
586
+ size_t il_totsize = il->compute_ntotal();
587
+ shift_id_offsets[i] =
588
+ (shift_ids && i > 0) ? shift_id_offsets[i - 1] + il_totsize : 0;
584
589
  }
585
590
 
586
591
  size_t cums = 0;
@@ -605,11 +610,21 @@ size_t OnDiskInvertedLists::merge_from(
605
610
  const InvertedLists* il = ils[i];
606
611
  size_t n_entry = il->list_size(j);
607
612
  l.size += n_entry;
613
+ ScopedIds scope_ids(il, j);
614
+ const idx_t* scope_ids_data = scope_ids.get();
615
+ std::vector<idx_t> new_ids;
616
+ if (shift_ids) {
617
+ new_ids.resize(n_entry);
618
+ for (size_t k = 0; k < n_entry; k++) {
619
+ new_ids[k] = scope_ids[k] + shift_id_offsets[i];
620
+ }
621
+ scope_ids_data = new_ids.data();
622
+ }
608
623
  update_entries(
609
624
  j,
610
625
  l.size - n_entry,
611
626
  n_entry,
612
- ScopedIds(il, j).get(),
627
+ scope_ids_data,
613
628
  ScopedCodes(il, j).get());
614
629
  }
615
630
  assert(l.size == l.capacity);
@@ -638,7 +653,7 @@ size_t OnDiskInvertedLists::merge_from(
638
653
  size_t OnDiskInvertedLists::merge_from_1(
639
654
  const InvertedLists* ils,
640
655
  bool verbose) {
641
- return merge_from(&ils, 1, verbose);
656
+ return merge_from_multiple(&ils, 1, verbose);
642
657
  }
643
658
 
644
659
  void OnDiskInvertedLists::crop_invlists(size_t l0, size_t l1) {
@@ -101,9 +101,10 @@ struct OnDiskInvertedLists : InvertedLists {
101
101
 
102
102
  // copy all inverted lists into *this, in compact form (without
103
103
  // allocating slots)
104
- size_t merge_from(
104
+ size_t merge_from_multiple(
105
105
  const InvertedLists** ils,
106
106
  int n_il,
107
+ bool shift_ids = false,
107
108
  bool verbose = false);
108
109
 
109
110
  /// same as merge_from for a single invlist
@@ -22,7 +22,7 @@ struct PyThreadLock {
22
22
  }
23
23
  };
24
24
 
25
- }; // namespace
25
+ } // namespace
26
26
 
27
27
  /***********************************************************
28
28
  * Callbacks for IO reader and writer
@@ -46,7 +46,7 @@ size_t PyCallbackIOWriter::operator()(
46
46
  size_t wi = ws > bs ? bs : ws;
47
47
  PyObject* result = PyObject_CallFunction(
48
48
  callback, "(N)", PyBytes_FromStringAndSize(ptr, wi));
49
- if (result == NULL) {
49
+ if (result == nullptr) {
50
50
  FAISS_THROW_MSG("py err");
51
51
  }
52
52
  // TODO check nb of bytes written
@@ -77,7 +77,7 @@ size_t PyCallbackIOReader::operator()(void* ptrv, size_t size, size_t nitems) {
77
77
  while (rs > 0) {
78
78
  size_t ri = rs > bs ? bs : rs;
79
79
  PyObject* result = PyObject_CallFunction(callback, "(n)", ri);
80
- if (result == NULL) {
80
+ if (result == nullptr) {
81
81
  FAISS_THROW_MSG("propagate py error");
82
82
  }
83
83
  if (!PyBytes_Check(result)) {
@@ -122,7 +122,7 @@ bool PyCallbackIDSelector::is_member(faiss::idx_t id) const {
122
122
  FAISS_THROW_IF_NOT((id >> 32) == 0);
123
123
  PyThreadLock gil;
124
124
  PyObject* result = PyObject_CallFunction(callback, "(n)", int(id));
125
- if (result == NULL) {
125
+ if (result == nullptr) {
126
126
  FAISS_THROW_MSG("propagate py error");
127
127
  }
128
128
  bool b = PyObject_IsTrue(result);
@@ -93,7 +93,7 @@ void HeapArray<C>::addn_query_subset_with_ids(
93
93
  }
94
94
  #pragma omp parallel for if (nsubset * nj > 100000)
95
95
  for (int64_t si = 0; si < nsubset; si++) {
96
- T i = subset[si];
96
+ TI i = subset[si];
97
97
  T* __restrict simi = get_val(i);
98
98
  TI* __restrict idxi = get_ids(i);
99
99
  const T* ip_line = vin + si * nj;
@@ -136,6 +136,8 @@ void HeapArray<C>::per_line_extrema(T* out_val, TI* out_ids) const {
136
136
 
137
137
  template struct HeapArray<CMin<float, int64_t>>;
138
138
  template struct HeapArray<CMax<float, int64_t>>;
139
+ template struct HeapArray<CMin<float, int32_t>>;
140
+ template struct HeapArray<CMax<float, int32_t>>;
139
141
  template struct HeapArray<CMin<int, int64_t>>;
140
142
  template struct HeapArray<CMax<int, int64_t>>;
141
143