faiss 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +1 -2
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -22
  13. data/vendor/faiss/faiss/Clustering.h +40 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +26 -12
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +40 -10
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
  34. data/vendor/faiss/faiss/IndexHNSW.h +62 -49
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
  38. data/vendor/faiss/faiss/IndexIVF.h +46 -6
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  59. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  60. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
  61. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  62. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  63. data/vendor/faiss/faiss/IndexNSG.h +11 -11
  64. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  65. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  66. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  67. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  68. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  69. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  70. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  71. data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
  72. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  73. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  76. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  77. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  78. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  79. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  81. data/vendor/faiss/faiss/MetricType.h +7 -2
  82. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  83. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  84. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  85. data/vendor/faiss/faiss/clone_index.h +3 -0
  86. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  87. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  88. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  89. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  90. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
  91. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
  96. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  107. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
  108. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  110. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  111. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
  112. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  113. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  114. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  115. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  116. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  117. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  118. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  119. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  121. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
  123. data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  125. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  126. data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
  127. data/vendor/faiss/faiss/impl/HNSW.h +52 -30
  128. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
  130. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
  132. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  133. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  134. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
  138. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
  144. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
  145. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  146. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  147. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  148. data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
  149. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  150. data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
  151. data/vendor/faiss/faiss/impl/io.cpp +23 -15
  152. data/vendor/faiss/faiss/impl/io.h +4 -4
  153. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  154. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  155. data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  159. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
  160. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  161. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  162. data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
  163. data/vendor/faiss/faiss/index_factory.cpp +41 -20
  164. data/vendor/faiss/faiss/index_io.h +12 -5
  165. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  166. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  167. data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
  168. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
  169. data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
  170. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  172. data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
  173. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  174. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  175. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  176. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  177. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  178. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  179. data/vendor/faiss/faiss/utils/distances.cpp +147 -123
  180. data/vendor/faiss/faiss/utils/distances.h +86 -9
  181. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  182. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  183. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  184. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  185. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  186. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  187. data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
  188. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  189. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  190. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  191. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  192. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  193. data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
  194. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  195. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  196. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  197. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
  198. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
  199. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  200. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  201. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  202. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  203. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  204. data/vendor/faiss/faiss/utils/random.h +25 -0
  205. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  206. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  207. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  208. data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
  209. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  210. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  211. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  212. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +120 -7
  214. data/vendor/faiss/faiss/utils/utils.h +60 -20
  215. metadata +23 -4
  216. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -7,8 +7,8 @@
7
7
 
8
8
  #include <faiss/IndexAdditiveQuantizerFastScan.h>
9
9
 
10
- #include <limits.h>
11
10
  #include <cassert>
11
+ #include <climits>
12
12
  #include <memory>
13
13
 
14
14
  #include <omp.h>
@@ -35,30 +35,30 @@ IndexAdditiveQuantizerFastScan::IndexAdditiveQuantizerFastScan(
35
35
  }
36
36
 
37
37
  void IndexAdditiveQuantizerFastScan::init(
38
- AdditiveQuantizer* aq,
38
+ AdditiveQuantizer* aq_2,
39
39
  MetricType metric,
40
40
  int bbs) {
41
- FAISS_THROW_IF_NOT(aq != nullptr);
42
- FAISS_THROW_IF_NOT(!aq->nbits.empty());
43
- FAISS_THROW_IF_NOT(aq->nbits[0] == 4);
41
+ FAISS_THROW_IF_NOT(aq_2 != nullptr);
42
+ FAISS_THROW_IF_NOT(!aq_2->nbits.empty());
43
+ FAISS_THROW_IF_NOT(aq_2->nbits[0] == 4);
44
44
  if (metric == METRIC_INNER_PRODUCT) {
45
45
  FAISS_THROW_IF_NOT_MSG(
46
- aq->search_type == AdditiveQuantizer::ST_LUT_nonorm,
46
+ aq_2->search_type == AdditiveQuantizer::ST_LUT_nonorm,
47
47
  "Search type must be ST_LUT_nonorm for IP metric");
48
48
  } else {
49
49
  FAISS_THROW_IF_NOT_MSG(
50
- aq->search_type == AdditiveQuantizer::ST_norm_lsq2x4 ||
51
- aq->search_type == AdditiveQuantizer::ST_norm_rq2x4,
50
+ aq_2->search_type == AdditiveQuantizer::ST_norm_lsq2x4 ||
51
+ aq_2->search_type == AdditiveQuantizer::ST_norm_rq2x4,
52
52
  "Search type must be lsq2x4 or rq2x4 for L2 metric");
53
53
  }
54
54
 
55
- this->aq = aq;
55
+ this->aq = aq_2;
56
56
  if (metric == METRIC_L2) {
57
- M = aq->M + 2; // 2x4 bits AQ
57
+ M = aq_2->M + 2; // 2x4 bits AQ
58
58
  } else {
59
- M = aq->M;
59
+ M = aq_2->M;
60
60
  }
61
- init_fastscan(aq->d, M, 4, metric, bbs);
61
+ init_fastscan(aq_2->d, M, 4, metric, bbs);
62
62
 
63
63
  max_train_points = 1024 * ksub * M;
64
64
  }
@@ -83,7 +83,7 @@ IndexAdditiveQuantizerFastScan::IndexAdditiveQuantizerFastScan(
83
83
  pq4_pack_codes(orig_codes, ntotal, M, ntotal2, bbs, M2, codes.get());
84
84
  }
85
85
 
86
- IndexAdditiveQuantizerFastScan::~IndexAdditiveQuantizerFastScan() {}
86
+ IndexAdditiveQuantizerFastScan::~IndexAdditiveQuantizerFastScan() = default;
87
87
 
88
88
  void IndexAdditiveQuantizerFastScan::train(idx_t n, const float* x_in) {
89
89
  if (is_trained) {
@@ -203,9 +203,9 @@ void IndexAdditiveQuantizerFastScan::search(
203
203
 
204
204
  NormTableScaler scaler(norm_scale);
205
205
  if (metric_type == METRIC_L2) {
206
- search_dispatch_implem<true>(n, x, k, distances, labels, scaler);
206
+ search_dispatch_implem<true>(n, x, k, distances, labels, &scaler);
207
207
  } else {
208
- search_dispatch_implem<false>(n, x, k, distances, labels, scaler);
208
+ search_dispatch_implem<false>(n, x, k, distances, labels, &scaler);
209
209
  }
210
210
  }
211
211
 
@@ -15,7 +15,12 @@
15
15
 
16
16
  namespace faiss {
17
17
 
18
- IndexBinary::~IndexBinary() {}
18
+ IndexBinary::IndexBinary(idx_t d, MetricType metric)
19
+ : d(d), code_size(d / 8), metric_type(metric) {
20
+ FAISS_THROW_IF_NOT(d % 8 == 0);
21
+ }
22
+
23
+ IndexBinary::~IndexBinary() = default;
19
24
 
20
25
  void IndexBinary::train(idx_t, const uint8_t*) {
21
26
  // Does nothing by default.
@@ -51,7 +56,7 @@ void IndexBinary::reconstruct(idx_t, uint8_t*) const {
51
56
 
52
57
  void IndexBinary::reconstruct_n(idx_t i0, idx_t ni, uint8_t* recons) const {
53
58
  for (idx_t i = 0; i < ni; i++) {
54
- reconstruct(i0 + i, recons + i * d);
59
+ reconstruct(i0 + i, recons + i * code_size);
55
60
  }
56
61
  }
57
62
 
@@ -70,10 +75,10 @@ void IndexBinary::search_and_reconstruct(
70
75
  for (idx_t j = 0; j < k; ++j) {
71
76
  idx_t ij = i * k + j;
72
77
  idx_t key = labels[ij];
73
- uint8_t* reconstructed = recons + ij * d;
78
+ uint8_t* reconstructed = recons + ij * code_size;
74
79
  if (key < 0) {
75
80
  // Fill with NaNs
76
- memset(reconstructed, -1, sizeof(*reconstructed) * d);
81
+ memset(reconstructed, -1, code_size);
77
82
  } else {
78
83
  reconstruct(key, reconstructed);
79
84
  }
@@ -5,8 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #ifndef FAISS_INDEX_BINARY_H
11
9
  #define FAISS_INDEX_BINARY_H
12
10
 
@@ -16,7 +14,6 @@
16
14
  #include <typeinfo>
17
15
 
18
16
  #include <faiss/Index.h>
19
- #include <faiss/impl/FaissAssert.h>
20
17
 
21
18
  namespace faiss {
22
19
 
@@ -35,27 +32,19 @@ struct IndexBinary {
35
32
  using component_t = uint8_t;
36
33
  using distance_t = int32_t;
37
34
 
38
- int d; ///< vector dimension
39
- int code_size; ///< number of bytes per vector ( = d / 8 )
40
- idx_t ntotal; ///< total nb of indexed vectors
41
- bool verbose; ///< verbosity level
35
+ int d = 0; ///< vector dimension
36
+ int code_size = 0; ///< number of bytes per vector ( = d / 8 )
37
+ idx_t ntotal = 0; ///< total nb of indexed vectors
38
+ bool verbose = false; ///< verbosity level
42
39
 
43
40
  /// set if the Index does not require training, or if training is done
44
41
  /// already
45
- bool is_trained;
42
+ bool is_trained = true;
46
43
 
47
44
  /// type of metric this index uses for search
48
- MetricType metric_type;
49
-
50
- explicit IndexBinary(idx_t d = 0, MetricType metric = METRIC_L2)
51
- : d(d),
52
- code_size(d / 8),
53
- ntotal(0),
54
- verbose(false),
55
- is_trained(true),
56
- metric_type(metric) {
57
- FAISS_THROW_IF_NOT(d % 8 == 0);
58
- }
45
+ MetricType metric_type = METRIC_L2;
46
+
47
+ explicit IndexBinary(idx_t d = 0, MetricType metric = METRIC_L2);
59
48
 
60
49
  virtual ~IndexBinary();
61
50
 
@@ -9,13 +9,14 @@
9
9
 
10
10
  #include <faiss/IndexBinaryFromFloat.h>
11
11
 
12
+ #include <faiss/impl/FaissAssert.h>
12
13
  #include <faiss/utils/utils.h>
13
14
  #include <algorithm>
14
15
  #include <memory>
15
16
 
16
17
  namespace faiss {
17
18
 
18
- IndexBinaryFromFloat::IndexBinaryFromFloat() {}
19
+ IndexBinaryFromFloat::IndexBinaryFromFloat() = default;
19
20
 
20
21
  IndexBinaryFromFloat::IndexBinaryFromFloat(Index* index)
21
22
  : IndexBinary(index->d), index(index), own_fields(false) {
@@ -5,8 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #include <faiss/IndexBinaryHNSW.h>
11
9
 
12
10
  #include <omp.h>
@@ -20,14 +18,15 @@
20
18
  #include <queue>
21
19
  #include <unordered_set>
22
20
 
23
- #include <stdint.h>
24
21
  #include <sys/stat.h>
25
22
  #include <sys/types.h>
23
+ #include <cstdint>
26
24
 
27
25
  #include <faiss/IndexBinaryFlat.h>
28
26
  #include <faiss/impl/AuxIndexStructures.h>
29
27
  #include <faiss/impl/DistanceComputer.h>
30
28
  #include <faiss/impl/FaissAssert.h>
29
+ #include <faiss/impl/ResultHandler.h>
31
30
  #include <faiss/utils/Heap.h>
32
31
  #include <faiss/utils/hamming.h>
33
32
  #include <faiss/utils/random.h>
@@ -201,27 +200,31 @@ void IndexBinaryHNSW::search(
201
200
  !params, "search params not supported for this index");
202
201
  FAISS_THROW_IF_NOT(k > 0);
203
202
 
203
+ // we use the buffer for distances as float but convert them back
204
+ // to int in the end
205
+ float* distances_f = (float*)distances;
206
+
207
+ using RH = HeapBlockResultHandler<HNSW::C>;
208
+ RH bres(n, distances_f, labels, k);
209
+
204
210
  #pragma omp parallel
205
211
  {
206
212
  VisitedTable vt(ntotal);
207
213
  std::unique_ptr<DistanceComputer> dis(get_distance_computer());
214
+ RH::SingleResultHandler res(bres);
208
215
 
209
216
  #pragma omp for
210
217
  for (idx_t i = 0; i < n; i++) {
211
- idx_t* idxi = labels + i * k;
212
- float* simi = (float*)(distances + i * k);
213
-
218
+ res.begin(i);
214
219
  dis->set_query((float*)(x + i * code_size));
215
-
216
- maxheap_heapify(k, simi, idxi);
217
- hnsw.search(*dis, k, idxi, simi, vt);
218
- maxheap_reorder(k, simi, idxi);
220
+ hnsw.search(*dis, res, vt);
221
+ res.end();
219
222
  }
220
223
  }
221
224
 
222
225
  #pragma omp parallel for
223
226
  for (int i = 0; i < n * k; ++i) {
224
- distances[i] = std::round(((float*)distances)[i]);
227
+ distances[i] = std::round(distances_f[i]);
225
228
  }
226
229
  }
227
230
 
@@ -281,31 +284,21 @@ struct FlatHammingDis : DistanceComputer {
281
284
  }
282
285
  };
283
286
 
287
+ struct BuildDistanceComputer {
288
+ using T = DistanceComputer*;
289
+ template <class HammingComputer>
290
+ DistanceComputer* f(IndexBinaryFlat* flat_storage) {
291
+ return new FlatHammingDis<HammingComputer>(*flat_storage);
292
+ }
293
+ };
294
+
284
295
  } // namespace
285
296
 
286
297
  DistanceComputer* IndexBinaryHNSW::get_distance_computer() const {
287
298
  IndexBinaryFlat* flat_storage = dynamic_cast<IndexBinaryFlat*>(storage);
288
-
289
299
  FAISS_ASSERT(flat_storage != nullptr);
290
-
291
- switch (code_size) {
292
- case 4:
293
- return new FlatHammingDis<HammingComputer4>(*flat_storage);
294
- case 8:
295
- return new FlatHammingDis<HammingComputer8>(*flat_storage);
296
- case 16:
297
- return new FlatHammingDis<HammingComputer16>(*flat_storage);
298
- case 20:
299
- return new FlatHammingDis<HammingComputer20>(*flat_storage);
300
- case 32:
301
- return new FlatHammingDis<HammingComputer32>(*flat_storage);
302
- case 64:
303
- return new FlatHammingDis<HammingComputer64>(*flat_storage);
304
- default:
305
- break;
306
- }
307
-
308
- return new FlatHammingDis<HammingComputerDefault>(*flat_storage);
300
+ BuildDistanceComputer bd;
301
+ return dispatch_HammingComputer(code_size, bd, flat_storage);
309
302
  }
310
303
 
311
304
  } // namespace faiss
@@ -21,7 +21,7 @@ namespace faiss {
21
21
  struct IndexBinaryHNSW : IndexBinary {
22
22
  typedef HNSW::storage_idx_t storage_idx_t;
23
23
 
24
- // the link strcuture
24
+ // the link structure
25
25
  HNSW hnsw;
26
26
 
27
27
  // the sequential storage
@@ -176,6 +176,14 @@ void search_single_query_template(
176
176
  } while (fe.next());
177
177
  }
178
178
 
179
+ struct Run_search_single_query {
180
+ using T = void;
181
+ template <class HammingComputer, class... Types>
182
+ T f(Types... args) {
183
+ search_single_query_template<HammingComputer>(args...);
184
+ }
185
+ };
186
+
179
187
  template <class SearchResults>
180
188
  void search_single_query(
181
189
  const IndexBinaryHash& index,
@@ -184,29 +192,9 @@ void search_single_query(
184
192
  size_t& n0,
185
193
  size_t& nlist,
186
194
  size_t& ndis) {
187
- #define HC(name) \
188
- search_single_query_template<name>(index, q, res, n0, nlist, ndis);
189
- switch (index.code_size) {
190
- case 4:
191
- HC(HammingComputer4);
192
- break;
193
- case 8:
194
- HC(HammingComputer8);
195
- break;
196
- case 16:
197
- HC(HammingComputer16);
198
- break;
199
- case 20:
200
- HC(HammingComputer20);
201
- break;
202
- case 32:
203
- HC(HammingComputer32);
204
- break;
205
- default:
206
- HC(HammingComputerDefault);
207
- break;
208
- }
209
- #undef HC
195
+ Run_search_single_query r;
196
+ dispatch_HammingComputer(
197
+ index.code_size, r, index, q, res, n0, nlist, ndis);
210
198
  }
211
199
 
212
200
  } // anonymous namespace
@@ -349,15 +337,14 @@ namespace {
349
337
 
350
338
  template <class HammingComputer, class SearchResults>
351
339
  static void verify_shortlist(
352
- const IndexBinaryFlat& index,
340
+ const IndexBinaryFlat* index,
353
341
  const uint8_t* q,
354
342
  const std::unordered_set<idx_t>& shortlist,
355
343
  SearchResults& res) {
356
- size_t code_size = index.code_size;
357
- size_t nlist = 0, ndis = 0, n0 = 0;
344
+ size_t code_size = index->code_size;
358
345
 
359
346
  HammingComputer hc(q, code_size);
360
- const uint8_t* codes = index.xb.data();
347
+ const uint8_t* codes = index->xb.data();
361
348
 
362
349
  for (auto i : shortlist) {
363
350
  int dis = hc.hamming(codes + i * code_size);
@@ -365,6 +352,14 @@ static void verify_shortlist(
365
352
  }
366
353
  }
367
354
 
355
+ struct Run_verify_shortlist {
356
+ using T = void;
357
+ template <class HammingComputer, class... Types>
358
+ void f(Types... args) {
359
+ verify_shortlist<HammingComputer>(args...);
360
+ }
361
+ };
362
+
368
363
  template <class SearchResults>
369
364
  void search_1_query_multihash(
370
365
  const IndexBinaryMultiHash& index,
@@ -405,29 +400,9 @@ void search_1_query_multihash(
405
400
  ndis += shortlist.size();
406
401
 
407
402
  // verify shortlist
408
-
409
- #define HC(name) verify_shortlist<name>(*index.storage, xi, shortlist, res)
410
- switch (index.code_size) {
411
- case 4:
412
- HC(HammingComputer4);
413
- break;
414
- case 8:
415
- HC(HammingComputer8);
416
- break;
417
- case 16:
418
- HC(HammingComputer16);
419
- break;
420
- case 20:
421
- HC(HammingComputer20);
422
- break;
423
- case 32:
424
- HC(HammingComputer32);
425
- break;
426
- default:
427
- HC(HammingComputerDefault);
428
- break;
429
- }
430
- #undef HC
403
+ Run_verify_shortlist r;
404
+ dispatch_HammingComputer(
405
+ index.code_size, r, index.storage, xi, shortlist, res);
431
406
  }
432
407
 
433
408
  } // anonymous namespace