faiss 0.3.0 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (171) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
@@ -7,8 +7,8 @@
7
7
 
8
8
  #include <faiss/IndexAdditiveQuantizerFastScan.h>
9
9
 
10
- #include <limits.h>
11
10
  #include <cassert>
11
+ #include <climits>
12
12
  #include <memory>
13
13
 
14
14
  #include <omp.h>
@@ -35,30 +35,30 @@ IndexAdditiveQuantizerFastScan::IndexAdditiveQuantizerFastScan(
35
35
  }
36
36
 
37
37
  void IndexAdditiveQuantizerFastScan::init(
38
- AdditiveQuantizer* aq,
38
+ AdditiveQuantizer* aq_2,
39
39
  MetricType metric,
40
40
  int bbs) {
41
- FAISS_THROW_IF_NOT(aq != nullptr);
42
- FAISS_THROW_IF_NOT(!aq->nbits.empty());
43
- FAISS_THROW_IF_NOT(aq->nbits[0] == 4);
41
+ FAISS_THROW_IF_NOT(aq_2 != nullptr);
42
+ FAISS_THROW_IF_NOT(!aq_2->nbits.empty());
43
+ FAISS_THROW_IF_NOT(aq_2->nbits[0] == 4);
44
44
  if (metric == METRIC_INNER_PRODUCT) {
45
45
  FAISS_THROW_IF_NOT_MSG(
46
- aq->search_type == AdditiveQuantizer::ST_LUT_nonorm,
46
+ aq_2->search_type == AdditiveQuantizer::ST_LUT_nonorm,
47
47
  "Search type must be ST_LUT_nonorm for IP metric");
48
48
  } else {
49
49
  FAISS_THROW_IF_NOT_MSG(
50
- aq->search_type == AdditiveQuantizer::ST_norm_lsq2x4 ||
51
- aq->search_type == AdditiveQuantizer::ST_norm_rq2x4,
50
+ aq_2->search_type == AdditiveQuantizer::ST_norm_lsq2x4 ||
51
+ aq_2->search_type == AdditiveQuantizer::ST_norm_rq2x4,
52
52
  "Search type must be lsq2x4 or rq2x4 for L2 metric");
53
53
  }
54
54
 
55
- this->aq = aq;
55
+ this->aq = aq_2;
56
56
  if (metric == METRIC_L2) {
57
- M = aq->M + 2; // 2x4 bits AQ
57
+ M = aq_2->M + 2; // 2x4 bits AQ
58
58
  } else {
59
- M = aq->M;
59
+ M = aq_2->M;
60
60
  }
61
- init_fastscan(aq->d, M, 4, metric, bbs);
61
+ init_fastscan(aq_2->d, M, 4, metric, bbs);
62
62
 
63
63
  max_train_points = 1024 * ksub * M;
64
64
  }
@@ -83,7 +83,7 @@ IndexAdditiveQuantizerFastScan::IndexAdditiveQuantizerFastScan(
83
83
  pq4_pack_codes(orig_codes, ntotal, M, ntotal2, bbs, M2, codes.get());
84
84
  }
85
85
 
86
- IndexAdditiveQuantizerFastScan::~IndexAdditiveQuantizerFastScan() {}
86
+ IndexAdditiveQuantizerFastScan::~IndexAdditiveQuantizerFastScan() = default;
87
87
 
88
88
  void IndexAdditiveQuantizerFastScan::train(idx_t n, const float* x_in) {
89
89
  if (is_trained) {
@@ -203,9 +203,9 @@ void IndexAdditiveQuantizerFastScan::search(
203
203
 
204
204
  NormTableScaler scaler(norm_scale);
205
205
  if (metric_type == METRIC_L2) {
206
- search_dispatch_implem<true>(n, x, k, distances, labels, scaler);
206
+ search_dispatch_implem<true>(n, x, k, distances, labels, &scaler);
207
207
  } else {
208
- search_dispatch_implem<false>(n, x, k, distances, labels, scaler);
208
+ search_dispatch_implem<false>(n, x, k, distances, labels, &scaler);
209
209
  }
210
210
  }
211
211
 
@@ -15,7 +15,12 @@
15
15
 
16
16
  namespace faiss {
17
17
 
18
- IndexBinary::~IndexBinary() {}
18
+ IndexBinary::IndexBinary(idx_t d, MetricType metric)
19
+ : d(d), code_size(d / 8), metric_type(metric) {
20
+ FAISS_THROW_IF_NOT(d % 8 == 0);
21
+ }
22
+
23
+ IndexBinary::~IndexBinary() = default;
19
24
 
20
25
  void IndexBinary::train(idx_t, const uint8_t*) {
21
26
  // Does nothing by default.
@@ -51,7 +56,7 @@ void IndexBinary::reconstruct(idx_t, uint8_t*) const {
51
56
 
52
57
  void IndexBinary::reconstruct_n(idx_t i0, idx_t ni, uint8_t* recons) const {
53
58
  for (idx_t i = 0; i < ni; i++) {
54
- reconstruct(i0 + i, recons + i * d);
59
+ reconstruct(i0 + i, recons + i * code_size);
55
60
  }
56
61
  }
57
62
 
@@ -70,10 +75,10 @@ void IndexBinary::search_and_reconstruct(
70
75
  for (idx_t j = 0; j < k; ++j) {
71
76
  idx_t ij = i * k + j;
72
77
  idx_t key = labels[ij];
73
- uint8_t* reconstructed = recons + ij * d;
78
+ uint8_t* reconstructed = recons + ij * code_size;
74
79
  if (key < 0) {
75
80
  // Fill with NaNs
76
- memset(reconstructed, -1, sizeof(*reconstructed) * d);
81
+ memset(reconstructed, -1, code_size);
77
82
  } else {
78
83
  reconstruct(key, reconstructed);
79
84
  }
@@ -5,8 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #ifndef FAISS_INDEX_BINARY_H
11
9
  #define FAISS_INDEX_BINARY_H
12
10
 
@@ -16,7 +14,6 @@
16
14
  #include <typeinfo>
17
15
 
18
16
  #include <faiss/Index.h>
19
- #include <faiss/impl/FaissAssert.h>
20
17
 
21
18
  namespace faiss {
22
19
 
@@ -35,27 +32,19 @@ struct IndexBinary {
35
32
  using component_t = uint8_t;
36
33
  using distance_t = int32_t;
37
34
 
38
- int d; ///< vector dimension
39
- int code_size; ///< number of bytes per vector ( = d / 8 )
40
- idx_t ntotal; ///< total nb of indexed vectors
41
- bool verbose; ///< verbosity level
35
+ int d = 0; ///< vector dimension
36
+ int code_size = 0; ///< number of bytes per vector ( = d / 8 )
37
+ idx_t ntotal = 0; ///< total nb of indexed vectors
38
+ bool verbose = false; ///< verbosity level
42
39
 
43
40
  /// set if the Index does not require training, or if training is done
44
41
  /// already
45
- bool is_trained;
42
+ bool is_trained = true;
46
43
 
47
44
  /// type of metric this index uses for search
48
- MetricType metric_type;
49
-
50
- explicit IndexBinary(idx_t d = 0, MetricType metric = METRIC_L2)
51
- : d(d),
52
- code_size(d / 8),
53
- ntotal(0),
54
- verbose(false),
55
- is_trained(true),
56
- metric_type(metric) {
57
- FAISS_THROW_IF_NOT(d % 8 == 0);
58
- }
45
+ MetricType metric_type = METRIC_L2;
46
+
47
+ explicit IndexBinary(idx_t d = 0, MetricType metric = METRIC_L2);
59
48
 
60
49
  virtual ~IndexBinary();
61
50
 
@@ -9,13 +9,14 @@
9
9
 
10
10
  #include <faiss/IndexBinaryFromFloat.h>
11
11
 
12
+ #include <faiss/impl/FaissAssert.h>
12
13
  #include <faiss/utils/utils.h>
13
14
  #include <algorithm>
14
15
  #include <memory>
15
16
 
16
17
  namespace faiss {
17
18
 
18
- IndexBinaryFromFloat::IndexBinaryFromFloat() {}
19
+ IndexBinaryFromFloat::IndexBinaryFromFloat() = default;
19
20
 
20
21
  IndexBinaryFromFloat::IndexBinaryFromFloat(Index* index)
21
22
  : IndexBinary(index->d), index(index), own_fields(false) {
@@ -5,8 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #include <faiss/IndexBinaryHNSW.h>
11
9
 
12
10
  #include <omp.h>
@@ -20,14 +18,15 @@
20
18
  #include <queue>
21
19
  #include <unordered_set>
22
20
 
23
- #include <stdint.h>
24
21
  #include <sys/stat.h>
25
22
  #include <sys/types.h>
23
+ #include <cstdint>
26
24
 
27
25
  #include <faiss/IndexBinaryFlat.h>
28
26
  #include <faiss/impl/AuxIndexStructures.h>
29
27
  #include <faiss/impl/DistanceComputer.h>
30
28
  #include <faiss/impl/FaissAssert.h>
29
+ #include <faiss/impl/ResultHandler.h>
31
30
  #include <faiss/utils/Heap.h>
32
31
  #include <faiss/utils/hamming.h>
33
32
  #include <faiss/utils/random.h>
@@ -201,27 +200,31 @@ void IndexBinaryHNSW::search(
201
200
  !params, "search params not supported for this index");
202
201
  FAISS_THROW_IF_NOT(k > 0);
203
202
 
203
+ // we use the buffer for distances as float but convert them back
204
+ // to int in the end
205
+ float* distances_f = (float*)distances;
206
+
207
+ using RH = HeapBlockResultHandler<HNSW::C>;
208
+ RH bres(n, distances_f, labels, k);
209
+
204
210
  #pragma omp parallel
205
211
  {
206
212
  VisitedTable vt(ntotal);
207
213
  std::unique_ptr<DistanceComputer> dis(get_distance_computer());
214
+ RH::SingleResultHandler res(bres);
208
215
 
209
216
  #pragma omp for
210
217
  for (idx_t i = 0; i < n; i++) {
211
- idx_t* idxi = labels + i * k;
212
- float* simi = (float*)(distances + i * k);
213
-
218
+ res.begin(i);
214
219
  dis->set_query((float*)(x + i * code_size));
215
-
216
- maxheap_heapify(k, simi, idxi);
217
- hnsw.search(*dis, k, idxi, simi, vt);
218
- maxheap_reorder(k, simi, idxi);
220
+ hnsw.search(*dis, res, vt);
221
+ res.end();
219
222
  }
220
223
  }
221
224
 
222
225
  #pragma omp parallel for
223
226
  for (int i = 0; i < n * k; ++i) {
224
- distances[i] = std::round(((float*)distances)[i]);
227
+ distances[i] = std::round(distances_f[i]);
225
228
  }
226
229
  }
227
230
 
@@ -281,31 +284,21 @@ struct FlatHammingDis : DistanceComputer {
281
284
  }
282
285
  };
283
286
 
287
+ struct BuildDistanceComputer {
288
+ using T = DistanceComputer*;
289
+ template <class HammingComputer>
290
+ DistanceComputer* f(IndexBinaryFlat* flat_storage) {
291
+ return new FlatHammingDis<HammingComputer>(*flat_storage);
292
+ }
293
+ };
294
+
284
295
  } // namespace
285
296
 
286
297
  DistanceComputer* IndexBinaryHNSW::get_distance_computer() const {
287
298
  IndexBinaryFlat* flat_storage = dynamic_cast<IndexBinaryFlat*>(storage);
288
-
289
299
  FAISS_ASSERT(flat_storage != nullptr);
290
-
291
- switch (code_size) {
292
- case 4:
293
- return new FlatHammingDis<HammingComputer4>(*flat_storage);
294
- case 8:
295
- return new FlatHammingDis<HammingComputer8>(*flat_storage);
296
- case 16:
297
- return new FlatHammingDis<HammingComputer16>(*flat_storage);
298
- case 20:
299
- return new FlatHammingDis<HammingComputer20>(*flat_storage);
300
- case 32:
301
- return new FlatHammingDis<HammingComputer32>(*flat_storage);
302
- case 64:
303
- return new FlatHammingDis<HammingComputer64>(*flat_storage);
304
- default:
305
- break;
306
- }
307
-
308
- return new FlatHammingDis<HammingComputerDefault>(*flat_storage);
300
+ BuildDistanceComputer bd;
301
+ return dispatch_HammingComputer(code_size, bd, flat_storage);
309
302
  }
310
303
 
311
304
  } // namespace faiss
@@ -176,6 +176,14 @@ void search_single_query_template(
176
176
  } while (fe.next());
177
177
  }
178
178
 
179
+ struct Run_search_single_query {
180
+ using T = void;
181
+ template <class HammingComputer, class... Types>
182
+ T f(Types... args) {
183
+ search_single_query_template<HammingComputer>(args...);
184
+ }
185
+ };
186
+
179
187
  template <class SearchResults>
180
188
  void search_single_query(
181
189
  const IndexBinaryHash& index,
@@ -184,29 +192,9 @@ void search_single_query(
184
192
  size_t& n0,
185
193
  size_t& nlist,
186
194
  size_t& ndis) {
187
- #define HC(name) \
188
- search_single_query_template<name>(index, q, res, n0, nlist, ndis);
189
- switch (index.code_size) {
190
- case 4:
191
- HC(HammingComputer4);
192
- break;
193
- case 8:
194
- HC(HammingComputer8);
195
- break;
196
- case 16:
197
- HC(HammingComputer16);
198
- break;
199
- case 20:
200
- HC(HammingComputer20);
201
- break;
202
- case 32:
203
- HC(HammingComputer32);
204
- break;
205
- default:
206
- HC(HammingComputerDefault);
207
- break;
208
- }
209
- #undef HC
195
+ Run_search_single_query r;
196
+ dispatch_HammingComputer(
197
+ index.code_size, r, index, q, res, n0, nlist, ndis);
210
198
  }
211
199
 
212
200
  } // anonymous namespace
@@ -349,15 +337,14 @@ namespace {
349
337
 
350
338
  template <class HammingComputer, class SearchResults>
351
339
  static void verify_shortlist(
352
- const IndexBinaryFlat& index,
340
+ const IndexBinaryFlat* index,
353
341
  const uint8_t* q,
354
342
  const std::unordered_set<idx_t>& shortlist,
355
343
  SearchResults& res) {
356
- size_t code_size = index.code_size;
357
- size_t nlist = 0, ndis = 0, n0 = 0;
344
+ size_t code_size = index->code_size;
358
345
 
359
346
  HammingComputer hc(q, code_size);
360
- const uint8_t* codes = index.xb.data();
347
+ const uint8_t* codes = index->xb.data();
361
348
 
362
349
  for (auto i : shortlist) {
363
350
  int dis = hc.hamming(codes + i * code_size);
@@ -365,6 +352,14 @@ static void verify_shortlist(
365
352
  }
366
353
  }
367
354
 
355
+ struct Run_verify_shortlist {
356
+ using T = void;
357
+ template <class HammingComputer, class... Types>
358
+ void f(Types... args) {
359
+ verify_shortlist<HammingComputer>(args...);
360
+ }
361
+ };
362
+
368
363
  template <class SearchResults>
369
364
  void search_1_query_multihash(
370
365
  const IndexBinaryMultiHash& index,
@@ -405,29 +400,9 @@ void search_1_query_multihash(
405
400
  ndis += shortlist.size();
406
401
 
407
402
  // verify shortlist
408
-
409
- #define HC(name) verify_shortlist<name>(*index.storage, xi, shortlist, res)
410
- switch (index.code_size) {
411
- case 4:
412
- HC(HammingComputer4);
413
- break;
414
- case 8:
415
- HC(HammingComputer8);
416
- break;
417
- case 16:
418
- HC(HammingComputer16);
419
- break;
420
- case 20:
421
- HC(HammingComputer20);
422
- break;
423
- case 32:
424
- HC(HammingComputer32);
425
- break;
426
- default:
427
- HC(HammingComputerDefault);
428
- break;
429
- }
430
- #undef HC
403
+ Run_verify_shortlist r;
404
+ dispatch_HammingComputer(
405
+ index.code_size, r, index.storage, xi, shortlist, res);
431
406
  }
432
407
 
433
408
  } // anonymous namespace