faiss 0.3.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
@@ -14,6 +14,7 @@
14
14
  #include <cstring>
15
15
 
16
16
  #include <algorithm>
17
+ #include <memory>
17
18
 
18
19
  #include <faiss/impl/DistanceComputer.h>
19
20
  #include <faiss/impl/FaissAssert.h>
@@ -86,7 +87,7 @@ struct PQDistanceComputer : FlatCodesDistanceComputer {
86
87
  ndis++;
87
88
 
88
89
  float dis = distance_single_code<PQDecoder>(
89
- pq, precomputed_table.data(), code);
90
+ pq.M, pq.nbits, precomputed_table.data(), code);
90
91
  return dis;
91
92
  }
92
93
 
@@ -198,17 +199,16 @@ void IndexPQ::search(
198
199
 
199
200
  } else { // code-to-code distances
200
201
 
201
- uint8_t* q_codes = new uint8_t[n * pq.code_size];
202
- ScopeDeleter<uint8_t> del(q_codes);
202
+ std::unique_ptr<uint8_t[]> q_codes(new uint8_t[n * pq.code_size]);
203
203
 
204
204
  if (!encode_signs) {
205
- pq.compute_codes(x, q_codes, n);
205
+ pq.compute_codes(x, q_codes.get(), n);
206
206
  } else {
207
207
  FAISS_THROW_IF_NOT(d == pq.nbits * pq.M);
208
- memset(q_codes, 0, n * pq.code_size);
208
+ memset(q_codes.get(), 0, n * pq.code_size);
209
209
  for (size_t i = 0; i < n; i++) {
210
210
  const float* xi = x + i * d;
211
- uint8_t* code = q_codes + i * pq.code_size;
211
+ uint8_t* code = q_codes.get() + i * pq.code_size;
212
212
  for (int j = 0; j < d; j++)
213
213
  if (xi[j] > 0)
214
214
  code[j >> 3] |= 1 << (j & 7);
@@ -219,19 +219,18 @@ void IndexPQ::search(
219
219
  float_maxheap_array_t res = {
220
220
  size_t(n), size_t(k), labels, distances};
221
221
 
222
- pq.search_sdc(q_codes, n, codes.data(), ntotal, &res, true);
222
+ pq.search_sdc(q_codes.get(), n, codes.data(), ntotal, &res, true);
223
223
 
224
224
  } else {
225
- int* idistances = new int[n * k];
226
- ScopeDeleter<int> del(idistances);
225
+ std::unique_ptr<int[]> idistances(new int[n * k]);
227
226
 
228
227
  int_maxheap_array_t res = {
229
- size_t(n), size_t(k), labels, idistances};
228
+ size_t(n), size_t(k), labels, idistances.get()};
230
229
 
231
230
  if (search_type == ST_HE) {
232
231
  hammings_knn_hc(
233
232
  &res,
234
- q_codes,
233
+ q_codes.get(),
235
234
  codes.data(),
236
235
  ntotal,
237
236
  pq.code_size,
@@ -240,7 +239,7 @@ void IndexPQ::search(
240
239
  } else if (search_type == ST_generalized_HE) {
241
240
  generalized_hammings_knn_hc(
242
241
  &res,
243
- q_codes,
242
+ q_codes.get(),
244
243
  codes.data(),
245
244
  ntotal,
246
245
  pq.code_size,
@@ -263,21 +262,23 @@ void IndexPQStats::reset() {
263
262
 
264
263
  IndexPQStats indexPQ_stats;
265
264
 
265
+ namespace {
266
+
266
267
  template <class HammingComputer>
267
- static size_t polysemous_inner_loop(
268
- const IndexPQ& index,
268
+ size_t polysemous_inner_loop(
269
+ const IndexPQ* index,
269
270
  const float* dis_table_qi,
270
271
  const uint8_t* q_code,
271
272
  size_t k,
272
273
  float* heap_dis,
273
274
  int64_t* heap_ids,
274
275
  int ht) {
275
- int M = index.pq.M;
276
- int code_size = index.pq.code_size;
277
- int ksub = index.pq.ksub;
278
- size_t ntotal = index.ntotal;
276
+ int M = index->pq.M;
277
+ int code_size = index->pq.code_size;
278
+ int ksub = index->pq.ksub;
279
+ size_t ntotal = index->ntotal;
279
280
 
280
- const uint8_t* b_code = index.codes.data();
281
+ const uint8_t* b_code = index->codes.data();
281
282
 
282
283
  size_t n_pass_i = 0;
283
284
 
@@ -305,6 +306,16 @@ static size_t polysemous_inner_loop(
305
306
  return n_pass_i;
306
307
  }
307
308
 
309
+ struct Run_polysemous_inner_loop {
310
+ using T = size_t;
311
+ template <class HammingComputer, class... Types>
312
+ size_t f(Types... args) {
313
+ return polysemous_inner_loop<HammingComputer>(args...);
314
+ }
315
+ };
316
+
317
+ } // anonymous namespace
318
+
308
319
  void IndexPQ::search_core_polysemous(
309
320
  idx_t n,
310
321
  const float* x,
@@ -321,22 +332,20 @@ void IndexPQ::search_core_polysemous(
321
332
  }
322
333
 
323
334
  // PQ distance tables
324
- float* dis_tables = new float[n * pq.ksub * pq.M];
325
- ScopeDeleter<float> del(dis_tables);
326
- pq.compute_distance_tables(n, x, dis_tables);
335
+ std::unique_ptr<float[]> dis_tables(new float[n * pq.ksub * pq.M]);
336
+ pq.compute_distance_tables(n, x, dis_tables.get());
327
337
 
328
338
  // Hamming embedding queries
329
- uint8_t* q_codes = new uint8_t[n * pq.code_size];
330
- ScopeDeleter<uint8_t> del2(q_codes);
339
+ std::unique_ptr<uint8_t[]> q_codes(new uint8_t[n * pq.code_size]);
331
340
 
332
341
  if (false) {
333
- pq.compute_codes(x, q_codes, n);
342
+ pq.compute_codes(x, q_codes.get(), n);
334
343
  } else {
335
344
  #pragma omp parallel for
336
345
  for (idx_t qi = 0; qi < n; qi++) {
337
346
  pq.compute_code_from_distance_table(
338
- dis_tables + qi * pq.M * pq.ksub,
339
- q_codes + qi * pq.code_size);
347
+ dis_tables.get() + qi * pq.M * pq.ksub,
348
+ q_codes.get() + qi * pq.code_size);
340
349
  }
341
350
  }
342
351
 
@@ -346,54 +355,33 @@ void IndexPQ::search_core_polysemous(
346
355
 
347
356
  #pragma omp parallel for reduction(+ : n_pass, bad_code_size)
348
357
  for (idx_t qi = 0; qi < n; qi++) {
349
- const uint8_t* q_code = q_codes + qi * pq.code_size;
358
+ const uint8_t* q_code = q_codes.get() + qi * pq.code_size;
350
359
 
351
- const float* dis_table_qi = dis_tables + qi * pq.M * pq.ksub;
360
+ const float* dis_table_qi = dis_tables.get() + qi * pq.M * pq.ksub;
352
361
 
353
362
  int64_t* heap_ids = labels + qi * k;
354
363
  float* heap_dis = distances + qi * k;
355
364
  maxheap_heapify(k, heap_dis, heap_ids);
356
365
 
357
366
  if (!generalized_hamming) {
358
- switch (pq.code_size) {
359
- #define DISPATCH(cs) \
360
- case cs: \
361
- n_pass += polysemous_inner_loop<HammingComputer##cs>( \
362
- *this, \
363
- dis_table_qi, \
364
- q_code, \
365
- k, \
366
- heap_dis, \
367
- heap_ids, \
368
- polysemous_ht); \
369
- break;
370
- DISPATCH(4)
371
- DISPATCH(8)
372
- DISPATCH(16)
373
- DISPATCH(32)
374
- DISPATCH(20)
375
- default:
376
- if (pq.code_size % 4 == 0) {
377
- n_pass += polysemous_inner_loop<HammingComputerDefault>(
378
- *this,
379
- dis_table_qi,
380
- q_code,
381
- k,
382
- heap_dis,
383
- heap_ids,
384
- polysemous_ht);
385
- } else {
386
- bad_code_size++;
387
- }
388
- break;
389
- }
390
- #undef DISPATCH
367
+ Run_polysemous_inner_loop r;
368
+ n_pass += dispatch_HammingComputer(
369
+ pq.code_size,
370
+ r,
371
+ this,
372
+ dis_table_qi,
373
+ q_code,
374
+ k,
375
+ heap_dis,
376
+ heap_ids,
377
+ polysemous_ht);
378
+
391
379
  } else { // generalized hamming
392
380
  switch (pq.code_size) {
393
381
  #define DISPATCH(cs) \
394
382
  case cs: \
395
383
  n_pass += polysemous_inner_loop<GenHammingComputer##cs>( \
396
- *this, \
384
+ this, \
397
385
  dis_table_qi, \
398
386
  q_code, \
399
387
  k, \
@@ -407,7 +395,7 @@ void IndexPQ::search_core_polysemous(
407
395
  default:
408
396
  if (pq.code_size % 8 == 0) {
409
397
  n_pass += polysemous_inner_loop<GenHammingComputerM8>(
410
- *this,
398
+ this,
411
399
  dis_table_qi,
412
400
  q_code,
413
401
  k,
@@ -450,12 +438,11 @@ void IndexPQ::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
450
438
 
451
439
  void IndexPQ::hamming_distance_table(idx_t n, const float* x, int32_t* dis)
452
440
  const {
453
- uint8_t* q_codes = new uint8_t[n * pq.code_size];
454
- ScopeDeleter<uint8_t> del(q_codes);
441
+ std::unique_ptr<uint8_t[]> q_codes(new uint8_t[n * pq.code_size]);
455
442
 
456
- pq.compute_codes(x, q_codes, n);
443
+ pq.compute_codes(x, q_codes.get(), n);
457
444
 
458
- hammings(q_codes, codes.data(), n, ntotal, pq.code_size, dis);
445
+ hammings(q_codes.get(), codes.data(), n, ntotal, pq.code_size, dis);
459
446
  }
460
447
 
461
448
  void IndexPQ::hamming_distance_histogram(
@@ -469,16 +456,15 @@ void IndexPQ::hamming_distance_histogram(
469
456
  FAISS_THROW_IF_NOT(pq.nbits == 8);
470
457
 
471
458
  // Hamming embedding queries
472
- uint8_t* q_codes = new uint8_t[n * pq.code_size];
473
- ScopeDeleter<uint8_t> del(q_codes);
474
- pq.compute_codes(x, q_codes, n);
459
+ std::unique_ptr<uint8_t[]> q_codes(new uint8_t[n * pq.code_size]);
460
+ pq.compute_codes(x, q_codes.get(), n);
475
461
 
476
462
  uint8_t* b_codes;
477
- ScopeDeleter<uint8_t> del_b_codes;
463
+ std::unique_ptr<uint8_t[]> del_b_codes;
478
464
 
479
465
  if (xb) {
480
466
  b_codes = new uint8_t[nb * pq.code_size];
481
- del_b_codes.set(b_codes);
467
+ del_b_codes.reset(b_codes);
482
468
  pq.compute_codes(xb, b_codes, nb);
483
469
  } else {
484
470
  nb = ntotal;
@@ -491,8 +477,7 @@ void IndexPQ::hamming_distance_histogram(
491
477
  #pragma omp parallel
492
478
  {
493
479
  std::vector<int64_t> histi(nbits + 1);
494
- hamdis_t* distances = new hamdis_t[nb * bs];
495
- ScopeDeleter<hamdis_t> del(distances);
480
+ std::unique_ptr<hamdis_t[]> distances(new hamdis_t[nb * bs]);
496
481
  #pragma omp for
497
482
  for (idx_t q0 = 0; q0 < n; q0 += bs) {
498
483
  // printf ("dis stats: %zd/%zd\n", q0, n);
@@ -501,12 +486,12 @@ void IndexPQ::hamming_distance_histogram(
501
486
  q1 = n;
502
487
 
503
488
  hammings(
504
- q_codes + q0 * pq.code_size,
489
+ q_codes.get() + q0 * pq.code_size,
505
490
  b_codes,
506
491
  q1 - q0,
507
492
  nb,
508
493
  pq.code_size,
509
- distances);
494
+ distances.get());
510
495
 
511
496
  for (size_t i = 0; i < nb * (q1 - q0); i++)
512
497
  histi[distances[i]]++;
@@ -639,7 +624,7 @@ struct SemiSortedArray {
639
624
  int N;
640
625
 
641
626
  // type of the heap: CMax = sort ascending
642
- typedef CMax<T, int> HC;
627
+ using HC = CMax<T, int>;
643
628
  std::vector<int> perm;
644
629
 
645
630
  int k; // k elements are sorted
@@ -733,7 +718,7 @@ struct MinSumK {
733
718
  * We use a heap to maintain a queue of sums, with the associated
734
719
  * terms involved in the sum.
735
720
  */
736
- typedef CMin<T, int64_t> HC;
721
+ using HC = CMin<T, int64_t>;
737
722
  size_t heap_capacity, heap_size;
738
723
  T* bh_val;
739
724
  int64_t* bh_ids;
@@ -827,7 +812,7 @@ struct MinSumK {
827
812
  // enqueue followers
828
813
  int64_t ii = ti;
829
814
  for (int m = 0; m < M; m++) {
830
- int64_t n = ii & ((1L << nbit) - 1);
815
+ int64_t n = ii & (((int64_t)1 << nbit) - 1);
831
816
  ii >>= nbit;
832
817
  if (n + 1 >= N)
833
818
  continue;
@@ -851,7 +836,7 @@ struct MinSumK {
851
836
  }
852
837
  int64_t ti = 0;
853
838
  for (int m = 0; m < M; m++) {
854
- int64_t n = ii & ((1L << nbit) - 1);
839
+ int64_t n = ii & (((int64_t)1 << nbit) - 1);
855
840
  ti += int64_t(ssx[m].get_ord(n)) << (nbit * m);
856
841
  ii >>= nbit;
857
842
  }
@@ -923,17 +908,16 @@ void MultiIndexQuantizer::search(
923
908
  return;
924
909
  }
925
910
 
926
- float* dis_tables = new float[n * pq.ksub * pq.M];
927
- ScopeDeleter<float> del(dis_tables);
911
+ std::unique_ptr<float[]> dis_tables(new float[n * pq.ksub * pq.M]);
928
912
 
929
- pq.compute_distance_tables(n, x, dis_tables);
913
+ pq.compute_distance_tables(n, x, dis_tables.get());
930
914
 
931
915
  if (k == 1) {
932
916
  // simple version that just finds the min in each table
933
917
 
934
918
  #pragma omp parallel for
935
919
  for (int i = 0; i < n; i++) {
936
- const float* dis_table = dis_tables + i * pq.ksub * pq.M;
920
+ const float* dis_table = dis_tables.get() + i * pq.ksub * pq.M;
937
921
  float dis = 0;
938
922
  idx_t label = 0;
939
923
 
@@ -963,7 +947,7 @@ void MultiIndexQuantizer::search(
963
947
  k, pq.M, pq.nbits, pq.ksub);
964
948
  #pragma omp for
965
949
  for (int i = 0; i < n; i++) {
966
- msk.run(dis_tables + i * pq.ksub * pq.M,
950
+ msk.run(dis_tables.get() + i * pq.ksub * pq.M,
967
951
  pq.ksub,
968
952
  distances + i * k,
969
953
  labels + i * k);
@@ -975,7 +959,7 @@ void MultiIndexQuantizer::search(
975
959
  void MultiIndexQuantizer::reconstruct(idx_t key, float* recons) const {
976
960
  int64_t jj = key;
977
961
  for (int m = 0; m < pq.M; m++) {
978
- int64_t n = jj & ((1L << pq.nbits) - 1);
962
+ int64_t n = jj & (((int64_t)1 << pq.nbits) - 1);
979
963
  jj >>= pq.nbits;
980
964
  memcpy(recons, pq.get_centroids(m, n), sizeof(recons[0]) * pq.dsub);
981
965
  recons += pq.dsub;
@@ -1107,7 +1091,7 @@ void MultiIndexQuantizer2::search(
1107
1091
 
1108
1092
  const idx_t* idmap0 = sub_ids.data() + i * k2;
1109
1093
  int64_t ld_idmap = k2 * n;
1110
- int64_t mask1 = ksub - 1L;
1094
+ int64_t mask1 = ksub - (int64_t)1;
1111
1095
 
1112
1096
  for (int k = 0; k < K; k++) {
1113
1097
  const idx_t* idmap = idmap0;
@@ -31,10 +31,7 @@ struct IndexPQ : IndexFlatCodes {
31
31
  * @param M number of subquantizers
32
32
  * @param nbits number of bit per subvector index
33
33
  */
34
- IndexPQ(int d, ///< dimensionality of the input vectors
35
- size_t M, ///< number of subquantizers
36
- size_t nbits, ///< number of bit per subvector index
37
- MetricType metric = METRIC_L2);
34
+ IndexPQ(int d, size_t M, size_t nbits, MetricType metric = METRIC_L2);
38
35
 
39
36
  IndexPQ();
40
37
 
@@ -7,8 +7,8 @@
7
7
 
8
8
  #include <faiss/IndexPQFastScan.h>
9
9
 
10
- #include <limits.h>
11
10
  #include <cassert>
11
+ #include <climits>
12
12
  #include <memory>
13
13
 
14
14
  #include <omp.h>
@@ -67,7 +67,7 @@ void IndexPreTransform::train(idx_t n, const float* x) {
67
67
  }
68
68
  }
69
69
  const float* prev_x = x;
70
- ScopeDeleter<float> del;
70
+ std::unique_ptr<const float[]> del;
71
71
 
72
72
  if (verbose) {
73
73
  printf("IndexPreTransform::train: training chain 0 to %d\n",
@@ -102,10 +102,12 @@ void IndexPreTransform::train(idx_t n, const float* x) {
102
102
 
103
103
  float* xt = chain[i]->apply(n, prev_x);
104
104
 
105
- if (prev_x != x)
106
- delete[] prev_x;
105
+ if (prev_x != x) {
106
+ del.reset();
107
+ }
108
+
107
109
  prev_x = xt;
108
- del.set(xt);
110
+ del.reset(xt);
109
111
  }
110
112
 
111
113
  is_trained = true;
@@ -113,11 +115,11 @@ void IndexPreTransform::train(idx_t n, const float* x) {
113
115
 
114
116
  const float* IndexPreTransform::apply_chain(idx_t n, const float* x) const {
115
117
  const float* prev_x = x;
116
- ScopeDeleter<float> del;
118
+ std::unique_ptr<const float[]> del;
117
119
 
118
120
  for (int i = 0; i < chain.size(); i++) {
119
121
  float* xt = chain[i]->apply(n, prev_x);
120
- ScopeDeleter<float> del2(xt);
122
+ std::unique_ptr<const float[]> del2(xt);
121
123
  del2.swap(del);
122
124
  prev_x = xt;
123
125
  }
@@ -128,11 +130,11 @@ const float* IndexPreTransform::apply_chain(idx_t n, const float* x) const {
128
130
  void IndexPreTransform::reverse_chain(idx_t n, const float* xt, float* x)
129
131
  const {
130
132
  const float* next_x = xt;
131
- ScopeDeleter<float> del;
133
+ std::unique_ptr<const float[]> del;
132
134
 
133
135
  for (int i = chain.size() - 1; i >= 0; i--) {
134
136
  float* prev_x = (i == 0) ? x : new float[n * chain[i]->d_in];
135
- ScopeDeleter<float> del2((prev_x == x) ? nullptr : prev_x);
137
+ std::unique_ptr<const float[]> del2((prev_x == x) ? nullptr : prev_x);
136
138
  chain[i]->reverse_transform(n, next_x, prev_x);
137
139
  del2.swap(del);
138
140
  next_x = prev_x;
@@ -141,9 +143,8 @@ void IndexPreTransform::reverse_chain(idx_t n, const float* xt, float* x)
141
143
 
142
144
  void IndexPreTransform::add(idx_t n, const float* x) {
143
145
  FAISS_THROW_IF_NOT(is_trained);
144
- const float* xt = apply_chain(n, x);
145
- ScopeDeleter<float> del(xt == x ? nullptr : xt);
146
- index->add(n, xt);
146
+ TransformedVectors tv(x, apply_chain(n, x));
147
+ index->add(n, tv.x);
147
148
  ntotal = index->ntotal;
148
149
  }
149
150
 
@@ -152,9 +153,8 @@ void IndexPreTransform::add_with_ids(
152
153
  const float* x,
153
154
  const idx_t* xids) {
154
155
  FAISS_THROW_IF_NOT(is_trained);
155
- const float* xt = apply_chain(n, x);
156
- ScopeDeleter<float> del(xt == x ? nullptr : xt);
157
- index->add_with_ids(n, xt, xids);
156
+ TransformedVectors tv(x, apply_chain(n, x));
157
+ index->add_with_ids(n, tv.x, xids);
158
158
  ntotal = index->ntotal;
159
159
  }
160
160
 
@@ -178,7 +178,7 @@ void IndexPreTransform::search(
178
178
  FAISS_THROW_IF_NOT(k > 0);
179
179
  FAISS_THROW_IF_NOT(is_trained);
180
180
  const float* xt = apply_chain(n, x);
181
- ScopeDeleter<float> del(xt == x ? nullptr : xt);
181
+ std::unique_ptr<const float[]> del(xt == x ? nullptr : xt);
182
182
  index->search(
183
183
  n, xt, k, distances, labels, extract_index_search_params(params));
184
184
  }
@@ -190,10 +190,9 @@ void IndexPreTransform::range_search(
190
190
  RangeSearchResult* result,
191
191
  const SearchParameters* params) const {
192
192
  FAISS_THROW_IF_NOT(is_trained);
193
- const float* xt = apply_chain(n, x);
194
- ScopeDeleter<float> del(xt == x ? nullptr : xt);
193
+ TransformedVectors tv(x, apply_chain(n, x));
195
194
  index->range_search(
196
- n, xt, radius, result, extract_index_search_params(params));
195
+ n, tv.x, radius, result, extract_index_search_params(params));
197
196
  }
198
197
 
199
198
  void IndexPreTransform::reset() {
@@ -209,7 +208,7 @@ size_t IndexPreTransform::remove_ids(const IDSelector& sel) {
209
208
 
210
209
  void IndexPreTransform::reconstruct(idx_t key, float* recons) const {
211
210
  float* x = chain.empty() ? recons : new float[index->d];
212
- ScopeDeleter<float> del(recons == x ? nullptr : x);
211
+ std::unique_ptr<float[]> del(recons == x ? nullptr : x);
213
212
  // Initial reconstruction
214
213
  index->reconstruct(key, x);
215
214
 
@@ -219,7 +218,7 @@ void IndexPreTransform::reconstruct(idx_t key, float* recons) const {
219
218
 
220
219
  void IndexPreTransform::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
221
220
  float* x = chain.empty() ? recons : new float[ni * index->d];
222
- ScopeDeleter<float> del(recons == x ? nullptr : x);
221
+ std::unique_ptr<float[]> del(recons == x ? nullptr : x);
223
222
  // Initial reconstruction
224
223
  index->reconstruct_n(i0, ni, x);
225
224
 
@@ -238,14 +237,14 @@ void IndexPreTransform::search_and_reconstruct(
238
237
  FAISS_THROW_IF_NOT(k > 0);
239
238
  FAISS_THROW_IF_NOT(is_trained);
240
239
 
241
- const float* xt = apply_chain(n, x);
242
- ScopeDeleter<float> del((xt == x) ? nullptr : xt);
240
+ TransformedVectors trans(x, apply_chain(n, x));
243
241
 
244
242
  float* recons_temp = chain.empty() ? recons : new float[n * k * index->d];
245
- ScopeDeleter<float> del2((recons_temp == recons) ? nullptr : recons_temp);
243
+ std::unique_ptr<float[]> del2(
244
+ (recons_temp == recons) ? nullptr : recons_temp);
246
245
  index->search_and_reconstruct(
247
246
  n,
248
- xt,
247
+ trans.x,
249
248
  k,
250
249
  distances,
251
250
  labels,
@@ -262,13 +261,8 @@ size_t IndexPreTransform::sa_code_size() const {
262
261
 
263
262
  void IndexPreTransform::sa_encode(idx_t n, const float* x, uint8_t* bytes)
264
263
  const {
265
- if (chain.empty()) {
266
- index->sa_encode(n, x, bytes);
267
- } else {
268
- const float* xt = apply_chain(n, x);
269
- ScopeDeleter<float> del(xt == x ? nullptr : xt);
270
- index->sa_encode(n, xt, bytes);
271
- }
264
+ TransformedVectors tv(x, apply_chain(n, x));
265
+ index->sa_encode(n, tv.x, bytes);
272
266
  }
273
267
 
274
268
  void IndexPreTransform::sa_decode(idx_t n, const uint8_t* bytes, float* x)
@@ -62,7 +62,7 @@ void IndexRefine::reset() {
62
62
 
63
63
  namespace {
64
64
 
65
- typedef faiss::idx_t idx_t;
65
+ using idx_t = faiss::idx_t;
66
66
 
67
67
  template <class C>
68
68
  static void reorder_2_heaps(
@@ -96,25 +96,40 @@ void IndexRefine::search(
96
96
  idx_t k,
97
97
  float* distances,
98
98
  idx_t* labels,
99
- const SearchParameters* params) const {
100
- FAISS_THROW_IF_NOT_MSG(
101
- !params, "search params not supported for this index");
99
+ const SearchParameters* params_in) const {
100
+ const IndexRefineSearchParameters* params = nullptr;
101
+ if (params_in) {
102
+ params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
103
+ FAISS_THROW_IF_NOT_MSG(
104
+ params, "IndexRefine params have incorrect type");
105
+ }
106
+
107
+ idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor)
108
+ : idx_t(k * k_factor);
109
+ SearchParameters* base_index_params =
110
+ (params != nullptr) ? params->base_index_params : nullptr;
111
+
112
+ FAISS_THROW_IF_NOT(k_base >= k);
113
+
114
+ FAISS_THROW_IF_NOT(base_index);
115
+ FAISS_THROW_IF_NOT(refine_index);
116
+
102
117
  FAISS_THROW_IF_NOT(k > 0);
103
118
  FAISS_THROW_IF_NOT(is_trained);
104
- idx_t k_base = idx_t(k * k_factor);
105
119
  idx_t* base_labels = labels;
106
120
  float* base_distances = distances;
107
- ScopeDeleter<idx_t> del1;
108
- ScopeDeleter<float> del2;
121
+ std::unique_ptr<idx_t[]> del1;
122
+ std::unique_ptr<float[]> del2;
109
123
 
110
124
  if (k != k_base) {
111
125
  base_labels = new idx_t[n * k_base];
112
- del1.set(base_labels);
126
+ del1.reset(base_labels);
113
127
  base_distances = new float[n * k_base];
114
- del2.set(base_distances);
128
+ del2.reset(base_distances);
115
129
  }
116
130
 
117
- base_index->search(n, x, k_base, base_distances, base_labels);
131
+ base_index->search(
132
+ n, x, k_base, base_distances, base_labels, base_index_params);
118
133
 
119
134
  for (int i = 0; i < n * k_base; i++)
120
135
  assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
@@ -225,25 +240,40 @@ void IndexRefineFlat::search(
225
240
  idx_t k,
226
241
  float* distances,
227
242
  idx_t* labels,
228
- const SearchParameters* params) const {
229
- FAISS_THROW_IF_NOT_MSG(
230
- !params, "search params not supported for this index");
243
+ const SearchParameters* params_in) const {
244
+ const IndexRefineSearchParameters* params = nullptr;
245
+ if (params_in) {
246
+ params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
247
+ FAISS_THROW_IF_NOT_MSG(
248
+ params, "IndexRefineFlat params have incorrect type");
249
+ }
250
+
251
+ idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor)
252
+ : idx_t(k * k_factor);
253
+ SearchParameters* base_index_params =
254
+ (params != nullptr) ? params->base_index_params : nullptr;
255
+
256
+ FAISS_THROW_IF_NOT(k_base >= k);
257
+
258
+ FAISS_THROW_IF_NOT(base_index);
259
+ FAISS_THROW_IF_NOT(refine_index);
260
+
231
261
  FAISS_THROW_IF_NOT(k > 0);
232
262
  FAISS_THROW_IF_NOT(is_trained);
233
- idx_t k_base = idx_t(k * k_factor);
234
263
  idx_t* base_labels = labels;
235
264
  float* base_distances = distances;
236
- ScopeDeleter<idx_t> del1;
237
- ScopeDeleter<float> del2;
265
+ std::unique_ptr<idx_t[]> del1;
266
+ std::unique_ptr<float[]> del2;
238
267
 
239
268
  if (k != k_base) {
240
269
  base_labels = new idx_t[n * k_base];
241
- del1.set(base_labels);
270
+ del1.reset(base_labels);
242
271
  base_distances = new float[n * k_base];
243
- del2.set(base_distances);
272
+ del2.reset(base_distances);
244
273
  }
245
274
 
246
- base_index->search(n, x, k_base, base_distances, base_labels);
275
+ base_index->search(
276
+ n, x, k_base, base_distances, base_labels, base_index_params);
247
277
 
248
278
  for (int i = 0; i < n * k_base; i++)
249
279
  assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
@@ -11,6 +11,13 @@
11
11
 
12
12
  namespace faiss {
13
13
 
14
+ struct IndexRefineSearchParameters : SearchParameters {
15
+ float k_factor = 1;
16
+ SearchParameters* base_index_params = nullptr; // non-owning
17
+
18
+ virtual ~IndexRefineSearchParameters() = default;
19
+ };
20
+
14
21
  /** Index that queries in a base_index (a fast one) and refines the
15
22
  * results with an exact search, hopefully improving the results.
16
23
  */