faiss 0.2.0 → 0.2.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (202) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
@@ -7,16 +7,16 @@
7
7
 
8
8
  #pragma once
9
9
 
10
- #include <vector>
11
10
  #include <algorithm>
12
11
  #include <type_traits>
12
+ #include <vector>
13
13
 
14
14
  #include <faiss/utils/Heap.h>
15
15
  #include <faiss/utils/simdlib.h>
16
16
 
17
+ #include <faiss/impl/platform_macros.h>
17
18
  #include <faiss/utils/AlignedTable.h>
18
19
  #include <faiss/utils/partitioning.h>
19
- #include <faiss/impl/platform_macros.h>
20
20
 
21
21
  /** This file contains callbacks for kernels that compute distances.
22
22
  *
@@ -31,7 +31,6 @@ namespace faiss {
31
31
 
32
32
  namespace simd_result_handlers {
33
33
 
34
-
35
34
  /** Dummy structure that just computes a checksum on results
36
35
  * (to avoid the computation to be optimized away) */
37
36
  struct DummyResultHandler {
@@ -41,8 +40,7 @@ struct DummyResultHandler {
41
40
  cs += q * 123 + b * 789 + d0.get_scalar_0() + d1.get_scalar_0();
42
41
  }
43
42
 
44
- void set_block_origin(size_t, size_t) {
45
- }
43
+ void set_block_origin(size_t, size_t) {}
46
44
  };
47
45
 
48
46
  /** memorize results in a nq-by-nb matrix.
@@ -50,14 +48,12 @@ struct DummyResultHandler {
50
48
  * j0 is the current upper-left block of the matrix
51
49
  */
52
50
  struct StoreResultHandler {
53
- uint16_t *data;
51
+ uint16_t* data;
54
52
  size_t ld; // total number of columns
55
53
  size_t i0 = 0;
56
54
  size_t j0 = 0;
57
55
 
58
- StoreResultHandler(uint16_t *data, size_t ld):
59
- data(data), ld(ld) {
60
- }
56
+ StoreResultHandler(uint16_t* data, size_t ld) : data(data), ld(ld) {}
61
57
 
62
58
  void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
63
59
  size_t ofs = (q + i0) * ld + j0 + b * 32;
@@ -71,9 +67,8 @@ struct StoreResultHandler {
71
67
  }
72
68
  };
73
69
 
74
-
75
70
  /** stores results in fixed-size matrix. */
76
- template<int NQ, int BB>
71
+ template <int NQ, int BB>
77
72
  struct FixedStorageHandler {
78
73
  simd16uint16 dis[NQ][BB];
79
74
  int i0 = 0;
@@ -88,47 +83,42 @@ struct FixedStorageHandler {
88
83
  assert(j0 == 0);
89
84
  }
90
85
 
91
- template<class OtherResultHandler>
92
- void to_other_handler(OtherResultHandler & other) const {
86
+ template <class OtherResultHandler>
87
+ void to_other_handler(OtherResultHandler& other) const {
93
88
  for (int q = 0; q < NQ; q++) {
94
- for(int b = 0; b < BB; b += 2) {
89
+ for (int b = 0; b < BB; b += 2) {
95
90
  other.handle(q, b / 2, dis[q][b], dis[q][b + 1]);
96
91
  }
97
92
  }
98
93
  }
99
-
100
94
  };
101
95
 
102
-
103
96
  /** Record origin of current block */
104
- template<class C, bool with_id_map>
97
+ template <class C, bool with_id_map>
105
98
  struct SIMDResultHandler {
106
99
  using TI = typename C::TI;
107
100
 
108
101
  bool disable = false;
109
102
 
110
- int64_t i0 = 0; // query origin
111
- int64_t j0 = 0; // db origin
112
- size_t ntotal; // ignore excess elements after ntotal
103
+ int64_t i0 = 0; // query origin
104
+ int64_t j0 = 0; // db origin
105
+ size_t ntotal; // ignore excess elements after ntotal
113
106
 
114
107
  /// these fields are used mainly for the IVF variants (with_id_map=true)
115
- const TI *id_map; // map offset in invlist to vector id
116
- const int *q_map; // map q to global query
117
- const uint16_t *dbias; // table of biases to add to each query
108
+ const TI* id_map; // map offset in invlist to vector id
109
+ const int* q_map; // map q to global query
110
+ const uint16_t* dbias; // table of biases to add to each query
118
111
 
119
- explicit SIMDResultHandler(size_t ntotal):
120
- ntotal(ntotal), id_map(nullptr), q_map(nullptr), dbias(nullptr)
121
- {}
112
+ explicit SIMDResultHandler(size_t ntotal)
113
+ : ntotal(ntotal), id_map(nullptr), q_map(nullptr), dbias(nullptr) {}
122
114
 
123
115
  void set_block_origin(size_t i0, size_t j0) {
124
116
  this->i0 = i0;
125
117
  this->j0 = j0;
126
118
  }
127
119
 
128
-
129
120
  // adjust handler data for IVF.
130
- void adjust_with_origin(size_t & q, simd16uint16 & d0, simd16uint16 & d1)
131
- {
121
+ void adjust_with_origin(size_t& q, simd16uint16& d0, simd16uint16& d1) {
132
122
  q += i0;
133
123
 
134
124
  if (dbias) {
@@ -154,9 +144,10 @@ struct SIMDResultHandler {
154
144
  /// return binary mask of elements below thr in (d0, d1)
155
145
  /// inverse_test returns elements above
156
146
  uint32_t get_lt_mask(
157
- uint16_t thr, size_t b,
158
- simd16uint16 d0, simd16uint16 d1
159
- ) {
147
+ uint16_t thr,
148
+ size_t b,
149
+ simd16uint16 d0,
150
+ simd16uint16 d1) {
160
151
  simd16uint16 thr16(thr);
161
152
  uint32_t lt_mask;
162
153
 
@@ -182,18 +173,16 @@ struct SIMDResultHandler {
182
173
  }
183
174
 
184
175
  virtual void to_flat_arrays(
185
- float *distances, int64_t *labels,
186
- const float *normalizers = nullptr
187
- ) = 0;
176
+ float* distances,
177
+ int64_t* labels,
178
+ const float* normalizers = nullptr) = 0;
188
179
 
189
180
  virtual ~SIMDResultHandler() {}
190
-
191
181
  };
192
182
 
193
-
194
183
  /** Special version for k=1 */
195
- template<class C, bool with_id_map = false>
196
- struct SingleResultHandler: SIMDResultHandler<C, with_id_map> {
184
+ template <class C, bool with_id_map = false>
185
+ struct SingleResultHandler : SIMDResultHandler<C, with_id_map> {
197
186
  using T = typename C::T;
198
187
  using TI = typename C::TI;
199
188
 
@@ -203,9 +192,8 @@ struct SingleResultHandler: SIMDResultHandler<C, with_id_map> {
203
192
  };
204
193
  std::vector<Result> results;
205
194
 
206
- SingleResultHandler(size_t nq, size_t ntotal):
207
- SIMDResultHandler<C, with_id_map>(ntotal), results(nq)
208
- {
195
+ SingleResultHandler(size_t nq, size_t ntotal)
196
+ : SIMDResultHandler<C, with_id_map>(ntotal), results(nq) {
209
197
  for (int i = 0; i < nq; i++) {
210
198
  Result res = {C::neutral(), -1};
211
199
  results[i] = res;
@@ -213,13 +201,13 @@ struct SingleResultHandler: SIMDResultHandler<C, with_id_map> {
213
201
  }
214
202
 
215
203
  void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
216
- if(this->disable) {
204
+ if (this->disable) {
217
205
  return;
218
206
  }
219
207
 
220
208
  this->adjust_with_origin(q, d0, d1);
221
209
 
222
- Result & res = results[q];
210
+ Result& res = results[q];
223
211
  uint32_t lt_mask = this->get_lt_mask(res.val, b, d0, d1);
224
212
  if (!lt_mask) {
225
213
  return;
@@ -242,9 +230,9 @@ struct SingleResultHandler: SIMDResultHandler<C, with_id_map> {
242
230
  }
243
231
 
244
232
  void to_flat_arrays(
245
- float *distances, int64_t *labels,
246
- const float *normalizers = nullptr
247
- ) override {
233
+ float* distances,
234
+ int64_t* labels,
235
+ const float* normalizers = nullptr) override {
248
236
  for (int q = 0; q < results.size(); q++) {
249
237
  if (!normalizers) {
250
238
  distances[q] = results[q].val;
@@ -256,48 +244,50 @@ struct SingleResultHandler: SIMDResultHandler<C, with_id_map> {
256
244
  labels[q] = results[q].id;
257
245
  }
258
246
  }
259
-
260
247
  };
261
248
 
262
249
  /** Structure that collects results in a min- or max-heap */
263
- template<class C, bool with_id_map = false>
264
- struct HeapHandler: SIMDResultHandler<C, with_id_map> {
250
+ template <class C, bool with_id_map = false>
251
+ struct HeapHandler : SIMDResultHandler<C, with_id_map> {
265
252
  using T = typename C::T;
266
253
  using TI = typename C::TI;
267
254
 
268
255
  int nq;
269
- T *heap_dis_tab;
270
- TI *heap_ids_tab;
256
+ T* heap_dis_tab;
257
+ TI* heap_ids_tab;
271
258
 
272
- int64_t k; // number of results to keep
259
+ int64_t k; // number of results to keep
273
260
 
274
261
  HeapHandler(
275
- int nq,
276
- T * heap_dis_tab, TI * heap_ids_tab,
277
- size_t k, size_t ntotal
278
- ):
279
- SIMDResultHandler<C, with_id_map>(ntotal), nq(nq),
280
- heap_dis_tab(heap_dis_tab), heap_ids_tab(heap_ids_tab), k(k)
281
- {
282
- for (int q = 0; q < nq; q++) {
283
- T *heap_dis_in = heap_dis_tab + q * k;
284
- TI *heap_ids_in = heap_ids_tab + q * k;
285
- heap_heapify<C> (k, heap_dis_in, heap_ids_in);
262
+ int nq,
263
+ T* heap_dis_tab,
264
+ TI* heap_ids_tab,
265
+ size_t k,
266
+ size_t ntotal)
267
+ : SIMDResultHandler<C, with_id_map>(ntotal),
268
+ nq(nq),
269
+ heap_dis_tab(heap_dis_tab),
270
+ heap_ids_tab(heap_ids_tab),
271
+ k(k) {
272
+ for (int q = 0; q < nq; q++) {
273
+ T* heap_dis_in = heap_dis_tab + q * k;
274
+ TI* heap_ids_in = heap_ids_tab + q * k;
275
+ heap_heapify<C>(k, heap_dis_in, heap_ids_in);
286
276
  }
287
277
  }
288
278
 
289
279
  void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
290
- if(this->disable) {
280
+ if (this->disable) {
291
281
  return;
292
282
  }
293
283
 
294
284
  this->adjust_with_origin(q, d0, d1);
295
285
 
296
- T *heap_dis = heap_dis_tab + q * k;
297
- TI *heap_ids = heap_ids_tab + q * k;
286
+ T* heap_dis = heap_dis_tab + q * k;
287
+ TI* heap_ids = heap_ids_tab + q * k;
298
288
 
299
- uint16_t cur_thresh = heap_dis[0] < 65536 ? (uint16_t)(heap_dis[0]) :
300
- 0xffff;
289
+ uint16_t cur_thresh =
290
+ heap_dis[0] < 65536 ? (uint16_t)(heap_dis[0]) : 0xffff;
301
291
 
302
292
  // here we handle the reverse comparison case as well
303
293
  uint32_t lt_mask = this->get_lt_mask(cur_thresh, b, d0, d1);
@@ -306,7 +296,7 @@ struct HeapHandler: SIMDResultHandler<C, with_id_map> {
306
296
  return;
307
297
  }
308
298
 
309
- ALIGNED(32) uint16_t d32tab[32] ;
299
+ ALIGNED(32) uint16_t d32tab[32];
310
300
  d0.store(d32tab);
311
301
  d1.store(d32tab + 16);
312
302
 
@@ -321,20 +311,18 @@ struct HeapHandler: SIMDResultHandler<C, with_id_map> {
321
311
  heap_push<C>(k, heap_dis, heap_ids, dis, idx);
322
312
  }
323
313
  }
324
-
325
314
  }
326
315
 
327
316
  void to_flat_arrays(
328
- float *distances, int64_t *labels,
329
- const float *normalizers = nullptr
330
- ) override {
331
-
317
+ float* distances,
318
+ int64_t* labels,
319
+ const float* normalizers = nullptr) override {
332
320
  for (int q = 0; q < nq; q++) {
333
- T *heap_dis_in = heap_dis_tab + q * k;
334
- TI *heap_ids_in = heap_ids_tab + q * k;
335
- heap_reorder<C> (k, heap_dis_in, heap_ids_in);
336
- int64_t *heap_ids = labels + q * k;
337
- float *heap_dis = distances + q * k;
321
+ T* heap_dis_in = heap_dis_tab + q * k;
322
+ TI* heap_ids_in = heap_ids_tab + q * k;
323
+ heap_reorder<C>(k, heap_dis_in, heap_ids_in);
324
+ int64_t* heap_ids = labels + q * k;
325
+ float* heap_dis = distances + q * k;
338
326
 
339
327
  float one_a = 1.0, b = 0.0;
340
328
  if (normalizers) {
@@ -347,10 +335,8 @@ struct HeapHandler: SIMDResultHandler<C, with_id_map> {
347
335
  }
348
336
  }
349
337
  }
350
-
351
338
  };
352
339
 
353
-
354
340
  /** Simple top-N implementation using a reservoir.
355
341
  *
356
342
  * Results are stored when they are below the threshold until the capacity is
@@ -358,12 +344,10 @@ struct HeapHandler: SIMDResultHandler<C, with_id_map> {
358
344
 
359
345
  namespace {
360
346
 
361
- uint64_t get_cy () {
362
- #ifdef MICRO_BENCHMARK
347
+ uint64_t get_cy() {
348
+ #ifdef MICRO_BENCHMARK
363
349
  uint32_t high, low;
364
- asm volatile("rdtsc \n\t"
365
- : "=a" (low),
366
- "=d" (high));
350
+ asm volatile("rdtsc \n\t" : "=a"(low), "=d"(high));
367
351
  return ((uint64_t)high << 32) | (low);
368
352
  #else
369
353
  return 0;
@@ -372,27 +356,23 @@ uint64_t get_cy () {
372
356
 
373
357
  } // anonymous namespace
374
358
 
375
- template<class C>
359
+ template <class C>
376
360
  struct ReservoirTopN {
377
361
  using T = typename C::T;
378
362
  using TI = typename C::TI;
379
363
 
380
- T *vals;
381
- TI *ids;
364
+ T* vals;
365
+ TI* ids;
382
366
 
383
- size_t i; // number of stored elements
384
- size_t n; // number of requested elements
385
- size_t capacity; // size of storage
367
+ size_t i; // number of stored elements
368
+ size_t n; // number of requested elements
369
+ size_t capacity; // size of storage
386
370
  size_t cycles = 0;
387
371
 
388
372
  T threshold; // current threshold
389
373
 
390
- ReservoirTopN(
391
- size_t n, size_t capacity,
392
- T *vals, TI *ids
393
- ):
394
- vals(vals), ids(ids),
395
- i(0), n(n), capacity(capacity) {
374
+ ReservoirTopN(size_t n, size_t capacity, T* vals, TI* ids)
375
+ : vals(vals), ids(ids), i(0), n(n), capacity(capacity) {
396
376
  assert(n < capacity);
397
377
  threshold = C::neutral();
398
378
  }
@@ -411,11 +391,11 @@ struct ReservoirTopN {
411
391
  /// shrink number of stored elements to n
412
392
  void shrink_xx() {
413
393
  uint64_t t0 = get_cy();
414
- qselect (vals, ids, i, n);
415
- i = n; // forget all elements above i = n
394
+ qselect(vals, ids, i, n);
395
+ i = n; // forget all elements above i = n
416
396
  threshold = C::Crev::neutral();
417
- for(size_t j = 0; j < n; j++) {
418
- if(C::cmp(vals[j], threshold)) {
397
+ for (size_t j = 0; j < n; j++) {
398
+ if (C::cmp(vals[j], threshold)) {
419
399
  threshold = vals[j];
420
400
  }
421
401
  }
@@ -433,16 +413,14 @@ struct ReservoirTopN {
433
413
  uint64_t t0 = get_cy();
434
414
  assert(i == capacity);
435
415
  threshold = partition_fuzzy<C>(
436
- vals, ids, capacity, n, (capacity + n) / 2,
437
- &i);
416
+ vals, ids, capacity, n, (capacity + n) / 2, &i);
438
417
  cycles += get_cy() - t0;
439
418
  }
440
419
  };
441
420
 
442
-
443
421
  /** Handler built from several ReservoirTopN (one per query) */
444
- template<class C, bool with_id_map = false>
445
- struct ReservoirHandler: SIMDResultHandler<C, with_id_map> {
422
+ template <class C, bool with_id_map = false>
423
+ struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
446
424
  using T = typename C::T;
447
425
  using TI = typename C::TI;
448
426
 
@@ -454,30 +432,30 @@ struct ReservoirHandler: SIMDResultHandler<C, with_id_map> {
454
432
 
455
433
  uint64_t times[4];
456
434
 
457
- ReservoirHandler(size_t nq, size_t ntotal, size_t n, size_t capacity_in):
458
- SIMDResultHandler<C, with_id_map>(ntotal), capacity((capacity_in + 15) & ~15),
459
- all_ids(nq * capacity), all_vals(nq * capacity)
460
- {
435
+ ReservoirHandler(size_t nq, size_t ntotal, size_t n, size_t capacity_in)
436
+ : SIMDResultHandler<C, with_id_map>(ntotal),
437
+ capacity((capacity_in + 15) & ~15),
438
+ all_ids(nq * capacity),
439
+ all_vals(nq * capacity) {
461
440
  assert(capacity % 16 == 0);
462
441
  for (size_t i = 0; i < nq; i++) {
463
442
  reservoirs.emplace_back(
464
- n, capacity,
465
- all_vals.get() + i * capacity,
466
- all_ids.data() + i * capacity
467
- );
443
+ n,
444
+ capacity,
445
+ all_vals.get() + i * capacity,
446
+ all_ids.data() + i * capacity);
468
447
  }
469
448
  times[0] = times[1] = times[2] = times[3] = 0;
470
449
  }
471
450
 
472
-
473
451
  void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
474
452
  uint64_t t0 = get_cy();
475
- if(this->disable) {
453
+ if (this->disable) {
476
454
  return;
477
455
  }
478
456
  this->adjust_with_origin(q, d0, d1);
479
457
 
480
- ReservoirTopN<C> & res = reservoirs[q];
458
+ ReservoirTopN<C>& res = reservoirs[q];
481
459
  uint32_t lt_mask = this->get_lt_mask(res.threshold, b, d0, d1);
482
460
  uint64_t t1 = get_cy();
483
461
  times[0] += t1 - t0;
@@ -499,27 +477,27 @@ struct ReservoirHandler: SIMDResultHandler<C, with_id_map> {
499
477
  times[1] += get_cy() - t1;
500
478
  }
501
479
 
502
-
503
480
  void to_flat_arrays(
504
- float *distances, int64_t *labels,
505
- const float *normalizers = nullptr
506
- ) override {
481
+ float* distances,
482
+ int64_t* labels,
483
+ const float* normalizers = nullptr) override {
507
484
  using Cf = typename std::conditional<
508
485
  C::is_max,
509
- CMax<float, int64_t>, CMin<float, int64_t>>::type;
486
+ CMax<float, int64_t>,
487
+ CMin<float, int64_t>>::type;
510
488
 
511
489
  uint64_t t0 = get_cy();
512
490
  uint64_t t3 = 0;
513
491
  std::vector<int> perm(reservoirs[0].n);
514
492
  for (int q = 0; q < reservoirs.size(); q++) {
515
- ReservoirTopN<C> & res = reservoirs[q];
493
+ ReservoirTopN<C>& res = reservoirs[q];
516
494
  size_t n = res.n;
517
495
 
518
496
  if (res.i > res.n) {
519
497
  res.shrink();
520
498
  }
521
- int64_t *heap_ids = labels + q * n;
522
- float *heap_dis = distances + q * n;
499
+ int64_t* heap_ids = labels + q * n;
500
+ float* heap_dis = distances + q * n;
523
501
 
524
502
  float one_a = 1.0, b = 0.0;
525
503
  if (normalizers) {
@@ -530,30 +508,24 @@ struct ReservoirHandler: SIMDResultHandler<C, with_id_map> {
530
508
  perm[i] = i;
531
509
  }
532
510
  // indirect sort of result arrays
533
- std::sort(
534
- perm.begin(), perm.begin() + res.i,
535
- [&res](int i, int j) {
536
- return C::cmp(res.vals[j], res.vals[i]);
537
- }
538
- );
511
+ std::sort(perm.begin(), perm.begin() + res.i, [&res](int i, int j) {
512
+ return C::cmp(res.vals[j], res.vals[i]);
513
+ });
539
514
  for (int i = 0; i < res.i; i++) {
540
515
  heap_dis[i] = res.vals[perm[i]] * one_a + b;
541
516
  heap_ids[i] = res.ids[perm[i]];
542
517
  }
543
518
 
544
519
  // possibly add empty results
545
- heap_heapify<Cf> (n - res.i, heap_dis + res.i, heap_ids + res.i);
520
+ heap_heapify<Cf>(n - res.i, heap_dis + res.i, heap_ids + res.i);
546
521
 
547
522
  t3 += res.cycles;
548
523
  }
549
524
  times[2] += get_cy() - t0;
550
525
  times[3] += t3;
551
526
  }
552
-
553
527
  };
554
528
 
555
-
556
529
  } // namespace simd_result_handlers
557
530
 
558
-
559
531
  } // namespace faiss