faiss 0.2.7 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (172) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/lib/faiss.rb +1 -1
  11. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  12. data/vendor/faiss/faiss/AutoTune.h +0 -1
  13. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  14. data/vendor/faiss/faiss/Clustering.h +31 -21
  15. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  16. data/vendor/faiss/faiss/Index.cpp +1 -1
  17. data/vendor/faiss/faiss/Index.h +20 -5
  18. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  21. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  22. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  23. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  34. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  38. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  59. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  60. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  61. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  62. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  63. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  64. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  65. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  66. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  67. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  69. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  70. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  71. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  72. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  73. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  74. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  75. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  76. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  77. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  78. data/vendor/faiss/faiss/clone_index.h +3 -0
  79. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  80. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  81. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  82. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  90. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  92. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  93. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  97. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  98. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  99. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  101. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  103. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  104. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  105. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  106. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  107. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  108. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  109. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  110. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  111. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  112. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  113. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  115. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  118. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  119. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  121. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  124. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  125. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  126. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  127. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  128. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  129. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  133. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  135. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  136. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  137. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  138. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  139. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  140. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  141. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  142. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  143. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  144. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  145. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  146. data/vendor/faiss/faiss/utils/distances.h +81 -4
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  148. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  150. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  152. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  153. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  154. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  155. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  156. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  157. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  158. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  159. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  160. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  161. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  162. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  163. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  164. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  165. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  166. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  167. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  168. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  169. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  170. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  171. data/vendor/faiss/faiss/utils/utils.h +57 -20
  172. metadata +11 -4
@@ -14,6 +14,7 @@
14
14
  #include <cstring>
15
15
 
16
16
  #include <algorithm>
17
+ #include <memory>
17
18
 
18
19
  #include <faiss/impl/DistanceComputer.h>
19
20
  #include <faiss/impl/FaissAssert.h>
@@ -86,7 +87,7 @@ struct PQDistanceComputer : FlatCodesDistanceComputer {
86
87
  ndis++;
87
88
 
88
89
  float dis = distance_single_code<PQDecoder>(
89
- pq, precomputed_table.data(), code);
90
+ pq.M, pq.nbits, precomputed_table.data(), code);
90
91
  return dis;
91
92
  }
92
93
 
@@ -198,17 +199,16 @@ void IndexPQ::search(
198
199
 
199
200
  } else { // code-to-code distances
200
201
 
201
- uint8_t* q_codes = new uint8_t[n * pq.code_size];
202
- ScopeDeleter<uint8_t> del(q_codes);
202
+ std::unique_ptr<uint8_t[]> q_codes(new uint8_t[n * pq.code_size]);
203
203
 
204
204
  if (!encode_signs) {
205
- pq.compute_codes(x, q_codes, n);
205
+ pq.compute_codes(x, q_codes.get(), n);
206
206
  } else {
207
207
  FAISS_THROW_IF_NOT(d == pq.nbits * pq.M);
208
- memset(q_codes, 0, n * pq.code_size);
208
+ memset(q_codes.get(), 0, n * pq.code_size);
209
209
  for (size_t i = 0; i < n; i++) {
210
210
  const float* xi = x + i * d;
211
- uint8_t* code = q_codes + i * pq.code_size;
211
+ uint8_t* code = q_codes.get() + i * pq.code_size;
212
212
  for (int j = 0; j < d; j++)
213
213
  if (xi[j] > 0)
214
214
  code[j >> 3] |= 1 << (j & 7);
@@ -219,19 +219,18 @@ void IndexPQ::search(
219
219
  float_maxheap_array_t res = {
220
220
  size_t(n), size_t(k), labels, distances};
221
221
 
222
- pq.search_sdc(q_codes, n, codes.data(), ntotal, &res, true);
222
+ pq.search_sdc(q_codes.get(), n, codes.data(), ntotal, &res, true);
223
223
 
224
224
  } else {
225
- int* idistances = new int[n * k];
226
- ScopeDeleter<int> del(idistances);
225
+ std::unique_ptr<int[]> idistances(new int[n * k]);
227
226
 
228
227
  int_maxheap_array_t res = {
229
- size_t(n), size_t(k), labels, idistances};
228
+ size_t(n), size_t(k), labels, idistances.get()};
230
229
 
231
230
  if (search_type == ST_HE) {
232
231
  hammings_knn_hc(
233
232
  &res,
234
- q_codes,
233
+ q_codes.get(),
235
234
  codes.data(),
236
235
  ntotal,
237
236
  pq.code_size,
@@ -240,7 +239,7 @@ void IndexPQ::search(
240
239
  } else if (search_type == ST_generalized_HE) {
241
240
  generalized_hammings_knn_hc(
242
241
  &res,
243
- q_codes,
242
+ q_codes.get(),
244
243
  codes.data(),
245
244
  ntotal,
246
245
  pq.code_size,
@@ -263,21 +262,23 @@ void IndexPQStats::reset() {
263
262
 
264
263
  IndexPQStats indexPQ_stats;
265
264
 
265
+ namespace {
266
+
266
267
  template <class HammingComputer>
267
- static size_t polysemous_inner_loop(
268
- const IndexPQ& index,
268
+ size_t polysemous_inner_loop(
269
+ const IndexPQ* index,
269
270
  const float* dis_table_qi,
270
271
  const uint8_t* q_code,
271
272
  size_t k,
272
273
  float* heap_dis,
273
274
  int64_t* heap_ids,
274
275
  int ht) {
275
- int M = index.pq.M;
276
- int code_size = index.pq.code_size;
277
- int ksub = index.pq.ksub;
278
- size_t ntotal = index.ntotal;
276
+ int M = index->pq.M;
277
+ int code_size = index->pq.code_size;
278
+ int ksub = index->pq.ksub;
279
+ size_t ntotal = index->ntotal;
279
280
 
280
- const uint8_t* b_code = index.codes.data();
281
+ const uint8_t* b_code = index->codes.data();
281
282
 
282
283
  size_t n_pass_i = 0;
283
284
 
@@ -305,6 +306,16 @@ static size_t polysemous_inner_loop(
305
306
  return n_pass_i;
306
307
  }
307
308
 
309
+ struct Run_polysemous_inner_loop {
310
+ using T = size_t;
311
+ template <class HammingComputer, class... Types>
312
+ size_t f(Types... args) {
313
+ return polysemous_inner_loop<HammingComputer>(args...);
314
+ }
315
+ };
316
+
317
+ } // anonymous namespace
318
+
308
319
  void IndexPQ::search_core_polysemous(
309
320
  idx_t n,
310
321
  const float* x,
@@ -321,22 +332,20 @@ void IndexPQ::search_core_polysemous(
321
332
  }
322
333
 
323
334
  // PQ distance tables
324
- float* dis_tables = new float[n * pq.ksub * pq.M];
325
- ScopeDeleter<float> del(dis_tables);
326
- pq.compute_distance_tables(n, x, dis_tables);
335
+ std::unique_ptr<float[]> dis_tables(new float[n * pq.ksub * pq.M]);
336
+ pq.compute_distance_tables(n, x, dis_tables.get());
327
337
 
328
338
  // Hamming embedding queries
329
- uint8_t* q_codes = new uint8_t[n * pq.code_size];
330
- ScopeDeleter<uint8_t> del2(q_codes);
339
+ std::unique_ptr<uint8_t[]> q_codes(new uint8_t[n * pq.code_size]);
331
340
 
332
341
  if (false) {
333
- pq.compute_codes(x, q_codes, n);
342
+ pq.compute_codes(x, q_codes.get(), n);
334
343
  } else {
335
344
  #pragma omp parallel for
336
345
  for (idx_t qi = 0; qi < n; qi++) {
337
346
  pq.compute_code_from_distance_table(
338
- dis_tables + qi * pq.M * pq.ksub,
339
- q_codes + qi * pq.code_size);
347
+ dis_tables.get() + qi * pq.M * pq.ksub,
348
+ q_codes.get() + qi * pq.code_size);
340
349
  }
341
350
  }
342
351
 
@@ -346,54 +355,33 @@ void IndexPQ::search_core_polysemous(
346
355
 
347
356
  #pragma omp parallel for reduction(+ : n_pass, bad_code_size)
348
357
  for (idx_t qi = 0; qi < n; qi++) {
349
- const uint8_t* q_code = q_codes + qi * pq.code_size;
358
+ const uint8_t* q_code = q_codes.get() + qi * pq.code_size;
350
359
 
351
- const float* dis_table_qi = dis_tables + qi * pq.M * pq.ksub;
360
+ const float* dis_table_qi = dis_tables.get() + qi * pq.M * pq.ksub;
352
361
 
353
362
  int64_t* heap_ids = labels + qi * k;
354
363
  float* heap_dis = distances + qi * k;
355
364
  maxheap_heapify(k, heap_dis, heap_ids);
356
365
 
357
366
  if (!generalized_hamming) {
358
- switch (pq.code_size) {
359
- #define DISPATCH(cs) \
360
- case cs: \
361
- n_pass += polysemous_inner_loop<HammingComputer##cs>( \
362
- *this, \
363
- dis_table_qi, \
364
- q_code, \
365
- k, \
366
- heap_dis, \
367
- heap_ids, \
368
- polysemous_ht); \
369
- break;
370
- DISPATCH(4)
371
- DISPATCH(8)
372
- DISPATCH(16)
373
- DISPATCH(32)
374
- DISPATCH(20)
375
- default:
376
- if (pq.code_size % 4 == 0) {
377
- n_pass += polysemous_inner_loop<HammingComputerDefault>(
378
- *this,
379
- dis_table_qi,
380
- q_code,
381
- k,
382
- heap_dis,
383
- heap_ids,
384
- polysemous_ht);
385
- } else {
386
- bad_code_size++;
387
- }
388
- break;
389
- }
390
- #undef DISPATCH
367
+ Run_polysemous_inner_loop r;
368
+ n_pass += dispatch_HammingComputer(
369
+ pq.code_size,
370
+ r,
371
+ this,
372
+ dis_table_qi,
373
+ q_code,
374
+ k,
375
+ heap_dis,
376
+ heap_ids,
377
+ polysemous_ht);
378
+
391
379
  } else { // generalized hamming
392
380
  switch (pq.code_size) {
393
381
  #define DISPATCH(cs) \
394
382
  case cs: \
395
383
  n_pass += polysemous_inner_loop<GenHammingComputer##cs>( \
396
- *this, \
384
+ this, \
397
385
  dis_table_qi, \
398
386
  q_code, \
399
387
  k, \
@@ -407,7 +395,7 @@ void IndexPQ::search_core_polysemous(
407
395
  default:
408
396
  if (pq.code_size % 8 == 0) {
409
397
  n_pass += polysemous_inner_loop<GenHammingComputerM8>(
410
- *this,
398
+ this,
411
399
  dis_table_qi,
412
400
  q_code,
413
401
  k,
@@ -450,12 +438,11 @@ void IndexPQ::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
450
438
 
451
439
  void IndexPQ::hamming_distance_table(idx_t n, const float* x, int32_t* dis)
452
440
  const {
453
- uint8_t* q_codes = new uint8_t[n * pq.code_size];
454
- ScopeDeleter<uint8_t> del(q_codes);
441
+ std::unique_ptr<uint8_t[]> q_codes(new uint8_t[n * pq.code_size]);
455
442
 
456
- pq.compute_codes(x, q_codes, n);
443
+ pq.compute_codes(x, q_codes.get(), n);
457
444
 
458
- hammings(q_codes, codes.data(), n, ntotal, pq.code_size, dis);
445
+ hammings(q_codes.get(), codes.data(), n, ntotal, pq.code_size, dis);
459
446
  }
460
447
 
461
448
  void IndexPQ::hamming_distance_histogram(
@@ -469,16 +456,15 @@ void IndexPQ::hamming_distance_histogram(
469
456
  FAISS_THROW_IF_NOT(pq.nbits == 8);
470
457
 
471
458
  // Hamming embedding queries
472
- uint8_t* q_codes = new uint8_t[n * pq.code_size];
473
- ScopeDeleter<uint8_t> del(q_codes);
474
- pq.compute_codes(x, q_codes, n);
459
+ std::unique_ptr<uint8_t[]> q_codes(new uint8_t[n * pq.code_size]);
460
+ pq.compute_codes(x, q_codes.get(), n);
475
461
 
476
462
  uint8_t* b_codes;
477
- ScopeDeleter<uint8_t> del_b_codes;
463
+ std::unique_ptr<uint8_t[]> del_b_codes;
478
464
 
479
465
  if (xb) {
480
466
  b_codes = new uint8_t[nb * pq.code_size];
481
- del_b_codes.set(b_codes);
467
+ del_b_codes.reset(b_codes);
482
468
  pq.compute_codes(xb, b_codes, nb);
483
469
  } else {
484
470
  nb = ntotal;
@@ -491,8 +477,7 @@ void IndexPQ::hamming_distance_histogram(
491
477
  #pragma omp parallel
492
478
  {
493
479
  std::vector<int64_t> histi(nbits + 1);
494
- hamdis_t* distances = new hamdis_t[nb * bs];
495
- ScopeDeleter<hamdis_t> del(distances);
480
+ std::unique_ptr<hamdis_t[]> distances(new hamdis_t[nb * bs]);
496
481
  #pragma omp for
497
482
  for (idx_t q0 = 0; q0 < n; q0 += bs) {
498
483
  // printf ("dis stats: %zd/%zd\n", q0, n);
@@ -501,12 +486,12 @@ void IndexPQ::hamming_distance_histogram(
501
486
  q1 = n;
502
487
 
503
488
  hammings(
504
- q_codes + q0 * pq.code_size,
489
+ q_codes.get() + q0 * pq.code_size,
505
490
  b_codes,
506
491
  q1 - q0,
507
492
  nb,
508
493
  pq.code_size,
509
- distances);
494
+ distances.get());
510
495
 
511
496
  for (size_t i = 0; i < nb * (q1 - q0); i++)
512
497
  histi[distances[i]]++;
@@ -639,7 +624,7 @@ struct SemiSortedArray {
639
624
  int N;
640
625
 
641
626
  // type of the heap: CMax = sort ascending
642
- typedef CMax<T, int> HC;
627
+ using HC = CMax<T, int>;
643
628
  std::vector<int> perm;
644
629
 
645
630
  int k; // k elements are sorted
@@ -733,7 +718,7 @@ struct MinSumK {
733
718
  * We use a heap to maintain a queue of sums, with the associated
734
719
  * terms involved in the sum.
735
720
  */
736
- typedef CMin<T, int64_t> HC;
721
+ using HC = CMin<T, int64_t>;
737
722
  size_t heap_capacity, heap_size;
738
723
  T* bh_val;
739
724
  int64_t* bh_ids;
@@ -827,7 +812,7 @@ struct MinSumK {
827
812
  // enqueue followers
828
813
  int64_t ii = ti;
829
814
  for (int m = 0; m < M; m++) {
830
- int64_t n = ii & ((1L << nbit) - 1);
815
+ int64_t n = ii & (((int64_t)1 << nbit) - 1);
831
816
  ii >>= nbit;
832
817
  if (n + 1 >= N)
833
818
  continue;
@@ -851,7 +836,7 @@ struct MinSumK {
851
836
  }
852
837
  int64_t ti = 0;
853
838
  for (int m = 0; m < M; m++) {
854
- int64_t n = ii & ((1L << nbit) - 1);
839
+ int64_t n = ii & (((int64_t)1 << nbit) - 1);
855
840
  ti += int64_t(ssx[m].get_ord(n)) << (nbit * m);
856
841
  ii >>= nbit;
857
842
  }
@@ -923,17 +908,16 @@ void MultiIndexQuantizer::search(
923
908
  return;
924
909
  }
925
910
 
926
- float* dis_tables = new float[n * pq.ksub * pq.M];
927
- ScopeDeleter<float> del(dis_tables);
911
+ std::unique_ptr<float[]> dis_tables(new float[n * pq.ksub * pq.M]);
928
912
 
929
- pq.compute_distance_tables(n, x, dis_tables);
913
+ pq.compute_distance_tables(n, x, dis_tables.get());
930
914
 
931
915
  if (k == 1) {
932
916
  // simple version that just finds the min in each table
933
917
 
934
918
  #pragma omp parallel for
935
919
  for (int i = 0; i < n; i++) {
936
- const float* dis_table = dis_tables + i * pq.ksub * pq.M;
920
+ const float* dis_table = dis_tables.get() + i * pq.ksub * pq.M;
937
921
  float dis = 0;
938
922
  idx_t label = 0;
939
923
 
@@ -963,7 +947,7 @@ void MultiIndexQuantizer::search(
963
947
  k, pq.M, pq.nbits, pq.ksub);
964
948
  #pragma omp for
965
949
  for (int i = 0; i < n; i++) {
966
- msk.run(dis_tables + i * pq.ksub * pq.M,
950
+ msk.run(dis_tables.get() + i * pq.ksub * pq.M,
967
951
  pq.ksub,
968
952
  distances + i * k,
969
953
  labels + i * k);
@@ -975,7 +959,7 @@ void MultiIndexQuantizer::search(
975
959
  void MultiIndexQuantizer::reconstruct(idx_t key, float* recons) const {
976
960
  int64_t jj = key;
977
961
  for (int m = 0; m < pq.M; m++) {
978
- int64_t n = jj & ((1L << pq.nbits) - 1);
962
+ int64_t n = jj & (((int64_t)1 << pq.nbits) - 1);
979
963
  jj >>= pq.nbits;
980
964
  memcpy(recons, pq.get_centroids(m, n), sizeof(recons[0]) * pq.dsub);
981
965
  recons += pq.dsub;
@@ -1107,7 +1091,7 @@ void MultiIndexQuantizer2::search(
1107
1091
 
1108
1092
  const idx_t* idmap0 = sub_ids.data() + i * k2;
1109
1093
  int64_t ld_idmap = k2 * n;
1110
- int64_t mask1 = ksub - 1L;
1094
+ int64_t mask1 = ksub - (int64_t)1;
1111
1095
 
1112
1096
  for (int k = 0; k < K; k++) {
1113
1097
  const idx_t* idmap = idmap0;
@@ -31,10 +31,7 @@ struct IndexPQ : IndexFlatCodes {
31
31
  * @param M number of subquantizers
32
32
  * @param nbits number of bit per subvector index
33
33
  */
34
- IndexPQ(int d, ///< dimensionality of the input vectors
35
- size_t M, ///< number of subquantizers
36
- size_t nbits, ///< number of bit per subvector index
37
- MetricType metric = METRIC_L2);
34
+ IndexPQ(int d, size_t M, size_t nbits, MetricType metric = METRIC_L2);
38
35
 
39
36
  IndexPQ();
40
37
 
@@ -7,8 +7,8 @@
7
7
 
8
8
  #include <faiss/IndexPQFastScan.h>
9
9
 
10
- #include <limits.h>
11
10
  #include <cassert>
11
+ #include <climits>
12
12
  #include <memory>
13
13
 
14
14
  #include <omp.h>
@@ -67,7 +67,7 @@ void IndexPreTransform::train(idx_t n, const float* x) {
67
67
  }
68
68
  }
69
69
  const float* prev_x = x;
70
- ScopeDeleter<float> del;
70
+ std::unique_ptr<const float[]> del;
71
71
 
72
72
  if (verbose) {
73
73
  printf("IndexPreTransform::train: training chain 0 to %d\n",
@@ -102,10 +102,12 @@ void IndexPreTransform::train(idx_t n, const float* x) {
102
102
 
103
103
  float* xt = chain[i]->apply(n, prev_x);
104
104
 
105
- if (prev_x != x)
106
- delete[] prev_x;
105
+ if (prev_x != x) {
106
+ del.reset();
107
+ }
108
+
107
109
  prev_x = xt;
108
- del.set(xt);
110
+ del.reset(xt);
109
111
  }
110
112
 
111
113
  is_trained = true;
@@ -113,11 +115,11 @@ void IndexPreTransform::train(idx_t n, const float* x) {
113
115
 
114
116
  const float* IndexPreTransform::apply_chain(idx_t n, const float* x) const {
115
117
  const float* prev_x = x;
116
- ScopeDeleter<float> del;
118
+ std::unique_ptr<const float[]> del;
117
119
 
118
120
  for (int i = 0; i < chain.size(); i++) {
119
121
  float* xt = chain[i]->apply(n, prev_x);
120
- ScopeDeleter<float> del2(xt);
122
+ std::unique_ptr<const float[]> del2(xt);
121
123
  del2.swap(del);
122
124
  prev_x = xt;
123
125
  }
@@ -128,11 +130,11 @@ const float* IndexPreTransform::apply_chain(idx_t n, const float* x) const {
128
130
  void IndexPreTransform::reverse_chain(idx_t n, const float* xt, float* x)
129
131
  const {
130
132
  const float* next_x = xt;
131
- ScopeDeleter<float> del;
133
+ std::unique_ptr<const float[]> del;
132
134
 
133
135
  for (int i = chain.size() - 1; i >= 0; i--) {
134
136
  float* prev_x = (i == 0) ? x : new float[n * chain[i]->d_in];
135
- ScopeDeleter<float> del2((prev_x == x) ? nullptr : prev_x);
137
+ std::unique_ptr<const float[]> del2((prev_x == x) ? nullptr : prev_x);
136
138
  chain[i]->reverse_transform(n, next_x, prev_x);
137
139
  del2.swap(del);
138
140
  next_x = prev_x;
@@ -141,9 +143,8 @@ void IndexPreTransform::reverse_chain(idx_t n, const float* xt, float* x)
141
143
 
142
144
  void IndexPreTransform::add(idx_t n, const float* x) {
143
145
  FAISS_THROW_IF_NOT(is_trained);
144
- const float* xt = apply_chain(n, x);
145
- ScopeDeleter<float> del(xt == x ? nullptr : xt);
146
- index->add(n, xt);
146
+ TransformedVectors tv(x, apply_chain(n, x));
147
+ index->add(n, tv.x);
147
148
  ntotal = index->ntotal;
148
149
  }
149
150
 
@@ -152,9 +153,8 @@ void IndexPreTransform::add_with_ids(
152
153
  const float* x,
153
154
  const idx_t* xids) {
154
155
  FAISS_THROW_IF_NOT(is_trained);
155
- const float* xt = apply_chain(n, x);
156
- ScopeDeleter<float> del(xt == x ? nullptr : xt);
157
- index->add_with_ids(n, xt, xids);
156
+ TransformedVectors tv(x, apply_chain(n, x));
157
+ index->add_with_ids(n, tv.x, xids);
158
158
  ntotal = index->ntotal;
159
159
  }
160
160
 
@@ -178,7 +178,7 @@ void IndexPreTransform::search(
178
178
  FAISS_THROW_IF_NOT(k > 0);
179
179
  FAISS_THROW_IF_NOT(is_trained);
180
180
  const float* xt = apply_chain(n, x);
181
- ScopeDeleter<float> del(xt == x ? nullptr : xt);
181
+ std::unique_ptr<const float[]> del(xt == x ? nullptr : xt);
182
182
  index->search(
183
183
  n, xt, k, distances, labels, extract_index_search_params(params));
184
184
  }
@@ -190,10 +190,9 @@ void IndexPreTransform::range_search(
190
190
  RangeSearchResult* result,
191
191
  const SearchParameters* params) const {
192
192
  FAISS_THROW_IF_NOT(is_trained);
193
- const float* xt = apply_chain(n, x);
194
- ScopeDeleter<float> del(xt == x ? nullptr : xt);
193
+ TransformedVectors tv(x, apply_chain(n, x));
195
194
  index->range_search(
196
- n, xt, radius, result, extract_index_search_params(params));
195
+ n, tv.x, radius, result, extract_index_search_params(params));
197
196
  }
198
197
 
199
198
  void IndexPreTransform::reset() {
@@ -209,7 +208,7 @@ size_t IndexPreTransform::remove_ids(const IDSelector& sel) {
209
208
 
210
209
  void IndexPreTransform::reconstruct(idx_t key, float* recons) const {
211
210
  float* x = chain.empty() ? recons : new float[index->d];
212
- ScopeDeleter<float> del(recons == x ? nullptr : x);
211
+ std::unique_ptr<float[]> del(recons == x ? nullptr : x);
213
212
  // Initial reconstruction
214
213
  index->reconstruct(key, x);
215
214
 
@@ -219,7 +218,7 @@ void IndexPreTransform::reconstruct(idx_t key, float* recons) const {
219
218
 
220
219
  void IndexPreTransform::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
221
220
  float* x = chain.empty() ? recons : new float[ni * index->d];
222
- ScopeDeleter<float> del(recons == x ? nullptr : x);
221
+ std::unique_ptr<float[]> del(recons == x ? nullptr : x);
223
222
  // Initial reconstruction
224
223
  index->reconstruct_n(i0, ni, x);
225
224
 
@@ -238,14 +237,14 @@ void IndexPreTransform::search_and_reconstruct(
238
237
  FAISS_THROW_IF_NOT(k > 0);
239
238
  FAISS_THROW_IF_NOT(is_trained);
240
239
 
241
- const float* xt = apply_chain(n, x);
242
- ScopeDeleter<float> del((xt == x) ? nullptr : xt);
240
+ TransformedVectors trans(x, apply_chain(n, x));
243
241
 
244
242
  float* recons_temp = chain.empty() ? recons : new float[n * k * index->d];
245
- ScopeDeleter<float> del2((recons_temp == recons) ? nullptr : recons_temp);
243
+ std::unique_ptr<float[]> del2(
244
+ (recons_temp == recons) ? nullptr : recons_temp);
246
245
  index->search_and_reconstruct(
247
246
  n,
248
- xt,
247
+ trans.x,
249
248
  k,
250
249
  distances,
251
250
  labels,
@@ -262,13 +261,8 @@ size_t IndexPreTransform::sa_code_size() const {
262
261
 
263
262
  void IndexPreTransform::sa_encode(idx_t n, const float* x, uint8_t* bytes)
264
263
  const {
265
- if (chain.empty()) {
266
- index->sa_encode(n, x, bytes);
267
- } else {
268
- const float* xt = apply_chain(n, x);
269
- ScopeDeleter<float> del(xt == x ? nullptr : xt);
270
- index->sa_encode(n, xt, bytes);
271
- }
264
+ TransformedVectors tv(x, apply_chain(n, x));
265
+ index->sa_encode(n, tv.x, bytes);
272
266
  }
273
267
 
274
268
  void IndexPreTransform::sa_decode(idx_t n, const uint8_t* bytes, float* x)
@@ -62,7 +62,7 @@ void IndexRefine::reset() {
62
62
 
63
63
  namespace {
64
64
 
65
- typedef faiss::idx_t idx_t;
65
+ using idx_t = faiss::idx_t;
66
66
 
67
67
  template <class C>
68
68
  static void reorder_2_heaps(
@@ -96,25 +96,40 @@ void IndexRefine::search(
96
96
  idx_t k,
97
97
  float* distances,
98
98
  idx_t* labels,
99
- const SearchParameters* params) const {
100
- FAISS_THROW_IF_NOT_MSG(
101
- !params, "search params not supported for this index");
99
+ const SearchParameters* params_in) const {
100
+ const IndexRefineSearchParameters* params = nullptr;
101
+ if (params_in) {
102
+ params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
103
+ FAISS_THROW_IF_NOT_MSG(
104
+ params, "IndexRefine params have incorrect type");
105
+ }
106
+
107
+ idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor)
108
+ : idx_t(k * k_factor);
109
+ SearchParameters* base_index_params =
110
+ (params != nullptr) ? params->base_index_params : nullptr;
111
+
112
+ FAISS_THROW_IF_NOT(k_base >= k);
113
+
114
+ FAISS_THROW_IF_NOT(base_index);
115
+ FAISS_THROW_IF_NOT(refine_index);
116
+
102
117
  FAISS_THROW_IF_NOT(k > 0);
103
118
  FAISS_THROW_IF_NOT(is_trained);
104
- idx_t k_base = idx_t(k * k_factor);
105
119
  idx_t* base_labels = labels;
106
120
  float* base_distances = distances;
107
- ScopeDeleter<idx_t> del1;
108
- ScopeDeleter<float> del2;
121
+ std::unique_ptr<idx_t[]> del1;
122
+ std::unique_ptr<float[]> del2;
109
123
 
110
124
  if (k != k_base) {
111
125
  base_labels = new idx_t[n * k_base];
112
- del1.set(base_labels);
126
+ del1.reset(base_labels);
113
127
  base_distances = new float[n * k_base];
114
- del2.set(base_distances);
128
+ del2.reset(base_distances);
115
129
  }
116
130
 
117
- base_index->search(n, x, k_base, base_distances, base_labels);
131
+ base_index->search(
132
+ n, x, k_base, base_distances, base_labels, base_index_params);
118
133
 
119
134
  for (int i = 0; i < n * k_base; i++)
120
135
  assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
@@ -225,25 +240,40 @@ void IndexRefineFlat::search(
225
240
  idx_t k,
226
241
  float* distances,
227
242
  idx_t* labels,
228
- const SearchParameters* params) const {
229
- FAISS_THROW_IF_NOT_MSG(
230
- !params, "search params not supported for this index");
243
+ const SearchParameters* params_in) const {
244
+ const IndexRefineSearchParameters* params = nullptr;
245
+ if (params_in) {
246
+ params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
247
+ FAISS_THROW_IF_NOT_MSG(
248
+ params, "IndexRefineFlat params have incorrect type");
249
+ }
250
+
251
+ idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor)
252
+ : idx_t(k * k_factor);
253
+ SearchParameters* base_index_params =
254
+ (params != nullptr) ? params->base_index_params : nullptr;
255
+
256
+ FAISS_THROW_IF_NOT(k_base >= k);
257
+
258
+ FAISS_THROW_IF_NOT(base_index);
259
+ FAISS_THROW_IF_NOT(refine_index);
260
+
231
261
  FAISS_THROW_IF_NOT(k > 0);
232
262
  FAISS_THROW_IF_NOT(is_trained);
233
- idx_t k_base = idx_t(k * k_factor);
234
263
  idx_t* base_labels = labels;
235
264
  float* base_distances = distances;
236
- ScopeDeleter<idx_t> del1;
237
- ScopeDeleter<float> del2;
265
+ std::unique_ptr<idx_t[]> del1;
266
+ std::unique_ptr<float[]> del2;
238
267
 
239
268
  if (k != k_base) {
240
269
  base_labels = new idx_t[n * k_base];
241
- del1.set(base_labels);
270
+ del1.reset(base_labels);
242
271
  base_distances = new float[n * k_base];
243
- del2.set(base_distances);
272
+ del2.reset(base_distances);
244
273
  }
245
274
 
246
- base_index->search(n, x, k_base, base_distances, base_labels);
275
+ base_index->search(
276
+ n, x, k_base, base_distances, base_labels, base_index_params);
247
277
 
248
278
  for (int i = 0; i < n * k_base; i++)
249
279
  assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
@@ -11,6 +11,13 @@
11
11
 
12
12
  namespace faiss {
13
13
 
14
+ struct IndexRefineSearchParameters : SearchParameters {
15
+ float k_factor = 1;
16
+ SearchParameters* base_index_params = nullptr; // non-owning
17
+
18
+ virtual ~IndexRefineSearchParameters() = default;
19
+ };
20
+
14
21
  /** Index that queries in a base_index (a fast one) and refines the
15
22
  * results with an exact search, hopefully improving the results.
16
23
  */