faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -7,16 +7,16 @@
7
7
 
8
8
  #pragma once
9
9
 
10
- #include <vector>
11
10
  #include <algorithm>
12
11
  #include <type_traits>
12
+ #include <vector>
13
13
 
14
14
  #include <faiss/utils/Heap.h>
15
15
  #include <faiss/utils/simdlib.h>
16
16
 
17
+ #include <faiss/impl/platform_macros.h>
17
18
  #include <faiss/utils/AlignedTable.h>
18
19
  #include <faiss/utils/partitioning.h>
19
- #include <faiss/impl/platform_macros.h>
20
20
 
21
21
  /** This file contains callbacks for kernels that compute distances.
22
22
  *
@@ -31,7 +31,6 @@ namespace faiss {
31
31
 
32
32
  namespace simd_result_handlers {
33
33
 
34
-
35
34
  /** Dummy structure that just computes a checksum on results
36
35
  * (to avoid the computation to be optimized away) */
37
36
  struct DummyResultHandler {
@@ -41,8 +40,7 @@ struct DummyResultHandler {
41
40
  cs += q * 123 + b * 789 + d0.get_scalar_0() + d1.get_scalar_0();
42
41
  }
43
42
 
44
- void set_block_origin(size_t, size_t) {
45
- }
43
+ void set_block_origin(size_t, size_t) {}
46
44
  };
47
45
 
48
46
  /** memorize results in a nq-by-nb matrix.
@@ -50,14 +48,12 @@ struct DummyResultHandler {
50
48
  * j0 is the current upper-left block of the matrix
51
49
  */
52
50
  struct StoreResultHandler {
53
- uint16_t *data;
51
+ uint16_t* data;
54
52
  size_t ld; // total number of columns
55
53
  size_t i0 = 0;
56
54
  size_t j0 = 0;
57
55
 
58
- StoreResultHandler(uint16_t *data, size_t ld):
59
- data(data), ld(ld) {
60
- }
56
+ StoreResultHandler(uint16_t* data, size_t ld) : data(data), ld(ld) {}
61
57
 
62
58
  void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
63
59
  size_t ofs = (q + i0) * ld + j0 + b * 32;
@@ -71,9 +67,8 @@ struct StoreResultHandler {
71
67
  }
72
68
  };
73
69
 
74
-
75
70
  /** stores results in fixed-size matrix. */
76
- template<int NQ, int BB>
71
+ template <int NQ, int BB>
77
72
  struct FixedStorageHandler {
78
73
  simd16uint16 dis[NQ][BB];
79
74
  int i0 = 0;
@@ -88,47 +83,42 @@ struct FixedStorageHandler {
88
83
  assert(j0 == 0);
89
84
  }
90
85
 
91
- template<class OtherResultHandler>
92
- void to_other_handler(OtherResultHandler & other) const {
86
+ template <class OtherResultHandler>
87
+ void to_other_handler(OtherResultHandler& other) const {
93
88
  for (int q = 0; q < NQ; q++) {
94
- for(int b = 0; b < BB; b += 2) {
89
+ for (int b = 0; b < BB; b += 2) {
95
90
  other.handle(q, b / 2, dis[q][b], dis[q][b + 1]);
96
91
  }
97
92
  }
98
93
  }
99
-
100
94
  };
101
95
 
102
-
103
96
  /** Record origin of current block */
104
- template<class C, bool with_id_map>
97
+ template <class C, bool with_id_map>
105
98
  struct SIMDResultHandler {
106
99
  using TI = typename C::TI;
107
100
 
108
101
  bool disable = false;
109
102
 
110
- int64_t i0 = 0; // query origin
111
- int64_t j0 = 0; // db origin
112
- size_t ntotal; // ignore excess elements after ntotal
103
+ int64_t i0 = 0; // query origin
104
+ int64_t j0 = 0; // db origin
105
+ size_t ntotal; // ignore excess elements after ntotal
113
106
 
114
107
  /// these fields are used mainly for the IVF variants (with_id_map=true)
115
- const TI *id_map; // map offset in invlist to vector id
116
- const int *q_map; // map q to global query
117
- const uint16_t *dbias; // table of biases to add to each query
108
+ const TI* id_map; // map offset in invlist to vector id
109
+ const int* q_map; // map q to global query
110
+ const uint16_t* dbias; // table of biases to add to each query
118
111
 
119
- explicit SIMDResultHandler(size_t ntotal):
120
- ntotal(ntotal), id_map(nullptr), q_map(nullptr), dbias(nullptr)
121
- {}
112
+ explicit SIMDResultHandler(size_t ntotal)
113
+ : ntotal(ntotal), id_map(nullptr), q_map(nullptr), dbias(nullptr) {}
122
114
 
123
115
  void set_block_origin(size_t i0, size_t j0) {
124
116
  this->i0 = i0;
125
117
  this->j0 = j0;
126
118
  }
127
119
 
128
-
129
120
  // adjust handler data for IVF.
130
- void adjust_with_origin(size_t & q, simd16uint16 & d0, simd16uint16 & d1)
131
- {
121
+ void adjust_with_origin(size_t& q, simd16uint16& d0, simd16uint16& d1) {
132
122
  q += i0;
133
123
 
134
124
  if (dbias) {
@@ -154,9 +144,10 @@ struct SIMDResultHandler {
154
144
  /// return binary mask of elements below thr in (d0, d1)
155
145
  /// inverse_test returns elements above
156
146
  uint32_t get_lt_mask(
157
- uint16_t thr, size_t b,
158
- simd16uint16 d0, simd16uint16 d1
159
- ) {
147
+ uint16_t thr,
148
+ size_t b,
149
+ simd16uint16 d0,
150
+ simd16uint16 d1) {
160
151
  simd16uint16 thr16(thr);
161
152
  uint32_t lt_mask;
162
153
 
@@ -182,18 +173,16 @@ struct SIMDResultHandler {
182
173
  }
183
174
 
184
175
  virtual void to_flat_arrays(
185
- float *distances, int64_t *labels,
186
- const float *normalizers = nullptr
187
- ) = 0;
176
+ float* distances,
177
+ int64_t* labels,
178
+ const float* normalizers = nullptr) = 0;
188
179
 
189
180
  virtual ~SIMDResultHandler() {}
190
-
191
181
  };
192
182
 
193
-
194
183
  /** Special version for k=1 */
195
- template<class C, bool with_id_map = false>
196
- struct SingleResultHandler: SIMDResultHandler<C, with_id_map> {
184
+ template <class C, bool with_id_map = false>
185
+ struct SingleResultHandler : SIMDResultHandler<C, with_id_map> {
197
186
  using T = typename C::T;
198
187
  using TI = typename C::TI;
199
188
 
@@ -203,9 +192,8 @@ struct SingleResultHandler: SIMDResultHandler<C, with_id_map> {
203
192
  };
204
193
  std::vector<Result> results;
205
194
 
206
- SingleResultHandler(size_t nq, size_t ntotal):
207
- SIMDResultHandler<C, with_id_map>(ntotal), results(nq)
208
- {
195
+ SingleResultHandler(size_t nq, size_t ntotal)
196
+ : SIMDResultHandler<C, with_id_map>(ntotal), results(nq) {
209
197
  for (int i = 0; i < nq; i++) {
210
198
  Result res = {C::neutral(), -1};
211
199
  results[i] = res;
@@ -213,13 +201,13 @@ struct SingleResultHandler: SIMDResultHandler<C, with_id_map> {
213
201
  }
214
202
 
215
203
  void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
216
- if(this->disable) {
204
+ if (this->disable) {
217
205
  return;
218
206
  }
219
207
 
220
208
  this->adjust_with_origin(q, d0, d1);
221
209
 
222
- Result & res = results[q];
210
+ Result& res = results[q];
223
211
  uint32_t lt_mask = this->get_lt_mask(res.val, b, d0, d1);
224
212
  if (!lt_mask) {
225
213
  return;
@@ -242,9 +230,9 @@ struct SingleResultHandler: SIMDResultHandler<C, with_id_map> {
242
230
  }
243
231
 
244
232
  void to_flat_arrays(
245
- float *distances, int64_t *labels,
246
- const float *normalizers = nullptr
247
- ) override {
233
+ float* distances,
234
+ int64_t* labels,
235
+ const float* normalizers = nullptr) override {
248
236
  for (int q = 0; q < results.size(); q++) {
249
237
  if (!normalizers) {
250
238
  distances[q] = results[q].val;
@@ -256,48 +244,50 @@ struct SingleResultHandler: SIMDResultHandler<C, with_id_map> {
256
244
  labels[q] = results[q].id;
257
245
  }
258
246
  }
259
-
260
247
  };
261
248
 
262
249
  /** Structure that collects results in a min- or max-heap */
263
- template<class C, bool with_id_map = false>
264
- struct HeapHandler: SIMDResultHandler<C, with_id_map> {
250
+ template <class C, bool with_id_map = false>
251
+ struct HeapHandler : SIMDResultHandler<C, with_id_map> {
265
252
  using T = typename C::T;
266
253
  using TI = typename C::TI;
267
254
 
268
255
  int nq;
269
- T *heap_dis_tab;
270
- TI *heap_ids_tab;
256
+ T* heap_dis_tab;
257
+ TI* heap_ids_tab;
271
258
 
272
- int64_t k; // number of results to keep
259
+ int64_t k; // number of results to keep
273
260
 
274
261
  HeapHandler(
275
- int nq,
276
- T * heap_dis_tab, TI * heap_ids_tab,
277
- size_t k, size_t ntotal
278
- ):
279
- SIMDResultHandler<C, with_id_map>(ntotal), nq(nq),
280
- heap_dis_tab(heap_dis_tab), heap_ids_tab(heap_ids_tab), k(k)
281
- {
282
- for (int q = 0; q < nq; q++) {
283
- T *heap_dis_in = heap_dis_tab + q * k;
284
- TI *heap_ids_in = heap_ids_tab + q * k;
285
- heap_heapify<C> (k, heap_dis_in, heap_ids_in);
262
+ int nq,
263
+ T* heap_dis_tab,
264
+ TI* heap_ids_tab,
265
+ size_t k,
266
+ size_t ntotal)
267
+ : SIMDResultHandler<C, with_id_map>(ntotal),
268
+ nq(nq),
269
+ heap_dis_tab(heap_dis_tab),
270
+ heap_ids_tab(heap_ids_tab),
271
+ k(k) {
272
+ for (int q = 0; q < nq; q++) {
273
+ T* heap_dis_in = heap_dis_tab + q * k;
274
+ TI* heap_ids_in = heap_ids_tab + q * k;
275
+ heap_heapify<C>(k, heap_dis_in, heap_ids_in);
286
276
  }
287
277
  }
288
278
 
289
279
  void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
290
- if(this->disable) {
280
+ if (this->disable) {
291
281
  return;
292
282
  }
293
283
 
294
284
  this->adjust_with_origin(q, d0, d1);
295
285
 
296
- T *heap_dis = heap_dis_tab + q * k;
297
- TI *heap_ids = heap_ids_tab + q * k;
286
+ T* heap_dis = heap_dis_tab + q * k;
287
+ TI* heap_ids = heap_ids_tab + q * k;
298
288
 
299
- uint16_t cur_thresh = heap_dis[0] < 65536 ? (uint16_t)(heap_dis[0]) :
300
- 0xffff;
289
+ uint16_t cur_thresh =
290
+ heap_dis[0] < 65536 ? (uint16_t)(heap_dis[0]) : 0xffff;
301
291
 
302
292
  // here we handle the reverse comparison case as well
303
293
  uint32_t lt_mask = this->get_lt_mask(cur_thresh, b, d0, d1);
@@ -306,7 +296,7 @@ struct HeapHandler: SIMDResultHandler<C, with_id_map> {
306
296
  return;
307
297
  }
308
298
 
309
- ALIGNED(32) uint16_t d32tab[32] ;
299
+ ALIGNED(32) uint16_t d32tab[32];
310
300
  d0.store(d32tab);
311
301
  d1.store(d32tab + 16);
312
302
 
@@ -321,20 +311,18 @@ struct HeapHandler: SIMDResultHandler<C, with_id_map> {
321
311
  heap_push<C>(k, heap_dis, heap_ids, dis, idx);
322
312
  }
323
313
  }
324
-
325
314
  }
326
315
 
327
316
  void to_flat_arrays(
328
- float *distances, int64_t *labels,
329
- const float *normalizers = nullptr
330
- ) override {
331
-
317
+ float* distances,
318
+ int64_t* labels,
319
+ const float* normalizers = nullptr) override {
332
320
  for (int q = 0; q < nq; q++) {
333
- T *heap_dis_in = heap_dis_tab + q * k;
334
- TI *heap_ids_in = heap_ids_tab + q * k;
335
- heap_reorder<C> (k, heap_dis_in, heap_ids_in);
336
- int64_t *heap_ids = labels + q * k;
337
- float *heap_dis = distances + q * k;
321
+ T* heap_dis_in = heap_dis_tab + q * k;
322
+ TI* heap_ids_in = heap_ids_tab + q * k;
323
+ heap_reorder<C>(k, heap_dis_in, heap_ids_in);
324
+ int64_t* heap_ids = labels + q * k;
325
+ float* heap_dis = distances + q * k;
338
326
 
339
327
  float one_a = 1.0, b = 0.0;
340
328
  if (normalizers) {
@@ -347,10 +335,8 @@ struct HeapHandler: SIMDResultHandler<C, with_id_map> {
347
335
  }
348
336
  }
349
337
  }
350
-
351
338
  };
352
339
 
353
-
354
340
  /** Simple top-N implementation using a reservoir.
355
341
  *
356
342
  * Results are stored when they are below the threshold until the capacity is
@@ -358,12 +344,10 @@ struct HeapHandler: SIMDResultHandler<C, with_id_map> {
358
344
 
359
345
  namespace {
360
346
 
361
- uint64_t get_cy () {
362
- #ifdef MICRO_BENCHMARK
347
+ uint64_t get_cy() {
348
+ #ifdef MICRO_BENCHMARK
363
349
  uint32_t high, low;
364
- asm volatile("rdtsc \n\t"
365
- : "=a" (low),
366
- "=d" (high));
350
+ asm volatile("rdtsc \n\t" : "=a"(low), "=d"(high));
367
351
  return ((uint64_t)high << 32) | (low);
368
352
  #else
369
353
  return 0;
@@ -372,27 +356,23 @@ uint64_t get_cy () {
372
356
 
373
357
  } // anonymous namespace
374
358
 
375
- template<class C>
359
+ template <class C>
376
360
  struct ReservoirTopN {
377
361
  using T = typename C::T;
378
362
  using TI = typename C::TI;
379
363
 
380
- T *vals;
381
- TI *ids;
364
+ T* vals;
365
+ TI* ids;
382
366
 
383
- size_t i; // number of stored elements
384
- size_t n; // number of requested elements
385
- size_t capacity; // size of storage
367
+ size_t i; // number of stored elements
368
+ size_t n; // number of requested elements
369
+ size_t capacity; // size of storage
386
370
  size_t cycles = 0;
387
371
 
388
372
  T threshold; // current threshold
389
373
 
390
- ReservoirTopN(
391
- size_t n, size_t capacity,
392
- T *vals, TI *ids
393
- ):
394
- vals(vals), ids(ids),
395
- i(0), n(n), capacity(capacity) {
374
+ ReservoirTopN(size_t n, size_t capacity, T* vals, TI* ids)
375
+ : vals(vals), ids(ids), i(0), n(n), capacity(capacity) {
396
376
  assert(n < capacity);
397
377
  threshold = C::neutral();
398
378
  }
@@ -411,11 +391,11 @@ struct ReservoirTopN {
411
391
  /// shrink number of stored elements to n
412
392
  void shrink_xx() {
413
393
  uint64_t t0 = get_cy();
414
- qselect (vals, ids, i, n);
415
- i = n; // forget all elements above i = n
394
+ qselect(vals, ids, i, n);
395
+ i = n; // forget all elements above i = n
416
396
  threshold = C::Crev::neutral();
417
- for(size_t j = 0; j < n; j++) {
418
- if(C::cmp(vals[j], threshold)) {
397
+ for (size_t j = 0; j < n; j++) {
398
+ if (C::cmp(vals[j], threshold)) {
419
399
  threshold = vals[j];
420
400
  }
421
401
  }
@@ -433,16 +413,14 @@ struct ReservoirTopN {
433
413
  uint64_t t0 = get_cy();
434
414
  assert(i == capacity);
435
415
  threshold = partition_fuzzy<C>(
436
- vals, ids, capacity, n, (capacity + n) / 2,
437
- &i);
416
+ vals, ids, capacity, n, (capacity + n) / 2, &i);
438
417
  cycles += get_cy() - t0;
439
418
  }
440
419
  };
441
420
 
442
-
443
421
  /** Handler built from several ReservoirTopN (one per query) */
444
- template<class C, bool with_id_map = false>
445
- struct ReservoirHandler: SIMDResultHandler<C, with_id_map> {
422
+ template <class C, bool with_id_map = false>
423
+ struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
446
424
  using T = typename C::T;
447
425
  using TI = typename C::TI;
448
426
 
@@ -454,30 +432,30 @@ struct ReservoirHandler: SIMDResultHandler<C, with_id_map> {
454
432
 
455
433
  uint64_t times[4];
456
434
 
457
- ReservoirHandler(size_t nq, size_t ntotal, size_t n, size_t capacity_in):
458
- SIMDResultHandler<C, with_id_map>(ntotal), capacity((capacity_in + 15) & ~15),
459
- all_ids(nq * capacity), all_vals(nq * capacity)
460
- {
435
+ ReservoirHandler(size_t nq, size_t ntotal, size_t n, size_t capacity_in)
436
+ : SIMDResultHandler<C, with_id_map>(ntotal),
437
+ capacity((capacity_in + 15) & ~15),
438
+ all_ids(nq * capacity),
439
+ all_vals(nq * capacity) {
461
440
  assert(capacity % 16 == 0);
462
441
  for (size_t i = 0; i < nq; i++) {
463
442
  reservoirs.emplace_back(
464
- n, capacity,
465
- all_vals.get() + i * capacity,
466
- all_ids.data() + i * capacity
467
- );
443
+ n,
444
+ capacity,
445
+ all_vals.get() + i * capacity,
446
+ all_ids.data() + i * capacity);
468
447
  }
469
448
  times[0] = times[1] = times[2] = times[3] = 0;
470
449
  }
471
450
 
472
-
473
451
  void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
474
452
  uint64_t t0 = get_cy();
475
- if(this->disable) {
453
+ if (this->disable) {
476
454
  return;
477
455
  }
478
456
  this->adjust_with_origin(q, d0, d1);
479
457
 
480
- ReservoirTopN<C> & res = reservoirs[q];
458
+ ReservoirTopN<C>& res = reservoirs[q];
481
459
  uint32_t lt_mask = this->get_lt_mask(res.threshold, b, d0, d1);
482
460
  uint64_t t1 = get_cy();
483
461
  times[0] += t1 - t0;
@@ -499,27 +477,27 @@ struct ReservoirHandler: SIMDResultHandler<C, with_id_map> {
499
477
  times[1] += get_cy() - t1;
500
478
  }
501
479
 
502
-
503
480
  void to_flat_arrays(
504
- float *distances, int64_t *labels,
505
- const float *normalizers = nullptr
506
- ) override {
481
+ float* distances,
482
+ int64_t* labels,
483
+ const float* normalizers = nullptr) override {
507
484
  using Cf = typename std::conditional<
508
485
  C::is_max,
509
- CMax<float, int64_t>, CMin<float, int64_t>>::type;
486
+ CMax<float, int64_t>,
487
+ CMin<float, int64_t>>::type;
510
488
 
511
489
  uint64_t t0 = get_cy();
512
490
  uint64_t t3 = 0;
513
491
  std::vector<int> perm(reservoirs[0].n);
514
492
  for (int q = 0; q < reservoirs.size(); q++) {
515
- ReservoirTopN<C> & res = reservoirs[q];
493
+ ReservoirTopN<C>& res = reservoirs[q];
516
494
  size_t n = res.n;
517
495
 
518
496
  if (res.i > res.n) {
519
497
  res.shrink();
520
498
  }
521
- int64_t *heap_ids = labels + q * n;
522
- float *heap_dis = distances + q * n;
499
+ int64_t* heap_ids = labels + q * n;
500
+ float* heap_dis = distances + q * n;
523
501
 
524
502
  float one_a = 1.0, b = 0.0;
525
503
  if (normalizers) {
@@ -530,30 +508,24 @@ struct ReservoirHandler: SIMDResultHandler<C, with_id_map> {
530
508
  perm[i] = i;
531
509
  }
532
510
  // indirect sort of result arrays
533
- std::sort(
534
- perm.begin(), perm.begin() + res.i,
535
- [&res](int i, int j) {
536
- return C::cmp(res.vals[j], res.vals[i]);
537
- }
538
- );
511
+ std::sort(perm.begin(), perm.begin() + res.i, [&res](int i, int j) {
512
+ return C::cmp(res.vals[j], res.vals[i]);
513
+ });
539
514
  for (int i = 0; i < res.i; i++) {
540
515
  heap_dis[i] = res.vals[perm[i]] * one_a + b;
541
516
  heap_ids[i] = res.ids[perm[i]];
542
517
  }
543
518
 
544
519
  // possibly add empty results
545
- heap_heapify<Cf> (n - res.i, heap_dis + res.i, heap_ids + res.i);
520
+ heap_heapify<Cf>(n - res.i, heap_dis + res.i, heap_ids + res.i);
546
521
 
547
522
  t3 += res.cycles;
548
523
  }
549
524
  times[2] += get_cy() - t0;
550
525
  times[3] += t3;
551
526
  }
552
-
553
527
  };
554
528
 
555
-
556
529
  } // namespace simd_result_handlers
557
530
 
558
-
559
531
  } // namespace faiss