faiss 0.1.5 → 0.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +24 -0
  3. data/README.md +12 -0
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +6 -2
  6. data/ext/faiss/index.cpp +114 -43
  7. data/ext/faiss/index_binary.cpp +24 -30
  8. data/ext/faiss/kmeans.cpp +20 -16
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +13 -14
  11. data/ext/faiss/product_quantizer.cpp +23 -24
  12. data/ext/faiss/utils.cpp +10 -37
  13. data/ext/faiss/utils.h +2 -13
  14. data/lib/faiss.rb +0 -5
  15. data/lib/faiss/version.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +24 -10
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
@@ -7,16 +7,16 @@
7
7
 
8
8
  #pragma once
9
9
 
10
- #include <vector>
11
10
  #include <algorithm>
12
11
  #include <type_traits>
12
+ #include <vector>
13
13
 
14
14
  #include <faiss/utils/Heap.h>
15
15
  #include <faiss/utils/simdlib.h>
16
16
 
17
+ #include <faiss/impl/platform_macros.h>
17
18
  #include <faiss/utils/AlignedTable.h>
18
19
  #include <faiss/utils/partitioning.h>
19
- #include <faiss/impl/platform_macros.h>
20
20
 
21
21
  /** This file contains callbacks for kernels that compute distances.
22
22
  *
@@ -31,7 +31,6 @@ namespace faiss {
31
31
 
32
32
  namespace simd_result_handlers {
33
33
 
34
-
35
34
  /** Dummy structure that just computes a checksum on results
36
35
  * (to avoid the computation to be optimized away) */
37
36
  struct DummyResultHandler {
@@ -41,8 +40,7 @@ struct DummyResultHandler {
41
40
  cs += q * 123 + b * 789 + d0.get_scalar_0() + d1.get_scalar_0();
42
41
  }
43
42
 
44
- void set_block_origin(size_t, size_t) {
45
- }
43
+ void set_block_origin(size_t, size_t) {}
46
44
  };
47
45
 
48
46
  /** memorize results in a nq-by-nb matrix.
@@ -50,14 +48,12 @@ struct DummyResultHandler {
50
48
  * j0 is the current upper-left block of the matrix
51
49
  */
52
50
  struct StoreResultHandler {
53
- uint16_t *data;
51
+ uint16_t* data;
54
52
  size_t ld; // total number of columns
55
53
  size_t i0 = 0;
56
54
  size_t j0 = 0;
57
55
 
58
- StoreResultHandler(uint16_t *data, size_t ld):
59
- data(data), ld(ld) {
60
- }
56
+ StoreResultHandler(uint16_t* data, size_t ld) : data(data), ld(ld) {}
61
57
 
62
58
  void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
63
59
  size_t ofs = (q + i0) * ld + j0 + b * 32;
@@ -71,9 +67,8 @@ struct StoreResultHandler {
71
67
  }
72
68
  };
73
69
 
74
-
75
70
  /** stores results in fixed-size matrix. */
76
- template<int NQ, int BB>
71
+ template <int NQ, int BB>
77
72
  struct FixedStorageHandler {
78
73
  simd16uint16 dis[NQ][BB];
79
74
  int i0 = 0;
@@ -88,47 +83,42 @@ struct FixedStorageHandler {
88
83
  assert(j0 == 0);
89
84
  }
90
85
 
91
- template<class OtherResultHandler>
92
- void to_other_handler(OtherResultHandler & other) const {
86
+ template <class OtherResultHandler>
87
+ void to_other_handler(OtherResultHandler& other) const {
93
88
  for (int q = 0; q < NQ; q++) {
94
- for(int b = 0; b < BB; b += 2) {
89
+ for (int b = 0; b < BB; b += 2) {
95
90
  other.handle(q, b / 2, dis[q][b], dis[q][b + 1]);
96
91
  }
97
92
  }
98
93
  }
99
-
100
94
  };
101
95
 
102
-
103
96
  /** Record origin of current block */
104
- template<class C, bool with_id_map>
97
+ template <class C, bool with_id_map>
105
98
  struct SIMDResultHandler {
106
99
  using TI = typename C::TI;
107
100
 
108
101
  bool disable = false;
109
102
 
110
- int64_t i0 = 0; // query origin
111
- int64_t j0 = 0; // db origin
112
- size_t ntotal; // ignore excess elements after ntotal
103
+ int64_t i0 = 0; // query origin
104
+ int64_t j0 = 0; // db origin
105
+ size_t ntotal; // ignore excess elements after ntotal
113
106
 
114
107
  /// these fields are used mainly for the IVF variants (with_id_map=true)
115
- const TI *id_map; // map offset in invlist to vector id
116
- const int *q_map; // map q to global query
117
- const uint16_t *dbias; // table of biases to add to each query
108
+ const TI* id_map; // map offset in invlist to vector id
109
+ const int* q_map; // map q to global query
110
+ const uint16_t* dbias; // table of biases to add to each query
118
111
 
119
- explicit SIMDResultHandler(size_t ntotal):
120
- ntotal(ntotal), id_map(nullptr), q_map(nullptr), dbias(nullptr)
121
- {}
112
+ explicit SIMDResultHandler(size_t ntotal)
113
+ : ntotal(ntotal), id_map(nullptr), q_map(nullptr), dbias(nullptr) {}
122
114
 
123
115
  void set_block_origin(size_t i0, size_t j0) {
124
116
  this->i0 = i0;
125
117
  this->j0 = j0;
126
118
  }
127
119
 
128
-
129
120
  // adjust handler data for IVF.
130
- void adjust_with_origin(size_t & q, simd16uint16 & d0, simd16uint16 & d1)
131
- {
121
+ void adjust_with_origin(size_t& q, simd16uint16& d0, simd16uint16& d1) {
132
122
  q += i0;
133
123
 
134
124
  if (dbias) {
@@ -154,9 +144,10 @@ struct SIMDResultHandler {
154
144
  /// return binary mask of elements below thr in (d0, d1)
155
145
  /// inverse_test returns elements above
156
146
  uint32_t get_lt_mask(
157
- uint16_t thr, size_t b,
158
- simd16uint16 d0, simd16uint16 d1
159
- ) {
147
+ uint16_t thr,
148
+ size_t b,
149
+ simd16uint16 d0,
150
+ simd16uint16 d1) {
160
151
  simd16uint16 thr16(thr);
161
152
  uint32_t lt_mask;
162
153
 
@@ -182,18 +173,16 @@ struct SIMDResultHandler {
182
173
  }
183
174
 
184
175
  virtual void to_flat_arrays(
185
- float *distances, int64_t *labels,
186
- const float *normalizers = nullptr
187
- ) = 0;
176
+ float* distances,
177
+ int64_t* labels,
178
+ const float* normalizers = nullptr) = 0;
188
179
 
189
180
  virtual ~SIMDResultHandler() {}
190
-
191
181
  };
192
182
 
193
-
194
183
  /** Special version for k=1 */
195
- template<class C, bool with_id_map = false>
196
- struct SingleResultHandler: SIMDResultHandler<C, with_id_map> {
184
+ template <class C, bool with_id_map = false>
185
+ struct SingleResultHandler : SIMDResultHandler<C, with_id_map> {
197
186
  using T = typename C::T;
198
187
  using TI = typename C::TI;
199
188
 
@@ -203,9 +192,8 @@ struct SingleResultHandler: SIMDResultHandler<C, with_id_map> {
203
192
  };
204
193
  std::vector<Result> results;
205
194
 
206
- SingleResultHandler(size_t nq, size_t ntotal):
207
- SIMDResultHandler<C, with_id_map>(ntotal), results(nq)
208
- {
195
+ SingleResultHandler(size_t nq, size_t ntotal)
196
+ : SIMDResultHandler<C, with_id_map>(ntotal), results(nq) {
209
197
  for (int i = 0; i < nq; i++) {
210
198
  Result res = {C::neutral(), -1};
211
199
  results[i] = res;
@@ -213,13 +201,13 @@ struct SingleResultHandler: SIMDResultHandler<C, with_id_map> {
213
201
  }
214
202
 
215
203
  void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
216
- if(this->disable) {
204
+ if (this->disable) {
217
205
  return;
218
206
  }
219
207
 
220
208
  this->adjust_with_origin(q, d0, d1);
221
209
 
222
- Result & res = results[q];
210
+ Result& res = results[q];
223
211
  uint32_t lt_mask = this->get_lt_mask(res.val, b, d0, d1);
224
212
  if (!lt_mask) {
225
213
  return;
@@ -242,9 +230,9 @@ struct SingleResultHandler: SIMDResultHandler<C, with_id_map> {
242
230
  }
243
231
 
244
232
  void to_flat_arrays(
245
- float *distances, int64_t *labels,
246
- const float *normalizers = nullptr
247
- ) override {
233
+ float* distances,
234
+ int64_t* labels,
235
+ const float* normalizers = nullptr) override {
248
236
  for (int q = 0; q < results.size(); q++) {
249
237
  if (!normalizers) {
250
238
  distances[q] = results[q].val;
@@ -256,48 +244,50 @@ struct SingleResultHandler: SIMDResultHandler<C, with_id_map> {
256
244
  labels[q] = results[q].id;
257
245
  }
258
246
  }
259
-
260
247
  };
261
248
 
262
249
  /** Structure that collects results in a min- or max-heap */
263
- template<class C, bool with_id_map = false>
264
- struct HeapHandler: SIMDResultHandler<C, with_id_map> {
250
+ template <class C, bool with_id_map = false>
251
+ struct HeapHandler : SIMDResultHandler<C, with_id_map> {
265
252
  using T = typename C::T;
266
253
  using TI = typename C::TI;
267
254
 
268
255
  int nq;
269
- T *heap_dis_tab;
270
- TI *heap_ids_tab;
256
+ T* heap_dis_tab;
257
+ TI* heap_ids_tab;
271
258
 
272
- int64_t k; // number of results to keep
259
+ int64_t k; // number of results to keep
273
260
 
274
261
  HeapHandler(
275
- int nq,
276
- T * heap_dis_tab, TI * heap_ids_tab,
277
- size_t k, size_t ntotal
278
- ):
279
- SIMDResultHandler<C, with_id_map>(ntotal), nq(nq),
280
- heap_dis_tab(heap_dis_tab), heap_ids_tab(heap_ids_tab), k(k)
281
- {
282
- for (int q = 0; q < nq; q++) {
283
- T *heap_dis_in = heap_dis_tab + q * k;
284
- TI *heap_ids_in = heap_ids_tab + q * k;
285
- heap_heapify<C> (k, heap_dis_in, heap_ids_in);
262
+ int nq,
263
+ T* heap_dis_tab,
264
+ TI* heap_ids_tab,
265
+ size_t k,
266
+ size_t ntotal)
267
+ : SIMDResultHandler<C, with_id_map>(ntotal),
268
+ nq(nq),
269
+ heap_dis_tab(heap_dis_tab),
270
+ heap_ids_tab(heap_ids_tab),
271
+ k(k) {
272
+ for (int q = 0; q < nq; q++) {
273
+ T* heap_dis_in = heap_dis_tab + q * k;
274
+ TI* heap_ids_in = heap_ids_tab + q * k;
275
+ heap_heapify<C>(k, heap_dis_in, heap_ids_in);
286
276
  }
287
277
  }
288
278
 
289
279
  void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
290
- if(this->disable) {
280
+ if (this->disable) {
291
281
  return;
292
282
  }
293
283
 
294
284
  this->adjust_with_origin(q, d0, d1);
295
285
 
296
- T *heap_dis = heap_dis_tab + q * k;
297
- TI *heap_ids = heap_ids_tab + q * k;
286
+ T* heap_dis = heap_dis_tab + q * k;
287
+ TI* heap_ids = heap_ids_tab + q * k;
298
288
 
299
- uint16_t cur_thresh = heap_dis[0] < 65536 ? (uint16_t)(heap_dis[0]) :
300
- 0xffff;
289
+ uint16_t cur_thresh =
290
+ heap_dis[0] < 65536 ? (uint16_t)(heap_dis[0]) : 0xffff;
301
291
 
302
292
  // here we handle the reverse comparison case as well
303
293
  uint32_t lt_mask = this->get_lt_mask(cur_thresh, b, d0, d1);
@@ -306,7 +296,7 @@ struct HeapHandler: SIMDResultHandler<C, with_id_map> {
306
296
  return;
307
297
  }
308
298
 
309
- ALIGNED(32) uint16_t d32tab[32] ;
299
+ ALIGNED(32) uint16_t d32tab[32];
310
300
  d0.store(d32tab);
311
301
  d1.store(d32tab + 16);
312
302
 
@@ -321,20 +311,18 @@ struct HeapHandler: SIMDResultHandler<C, with_id_map> {
321
311
  heap_push<C>(k, heap_dis, heap_ids, dis, idx);
322
312
  }
323
313
  }
324
-
325
314
  }
326
315
 
327
316
  void to_flat_arrays(
328
- float *distances, int64_t *labels,
329
- const float *normalizers = nullptr
330
- ) override {
331
-
317
+ float* distances,
318
+ int64_t* labels,
319
+ const float* normalizers = nullptr) override {
332
320
  for (int q = 0; q < nq; q++) {
333
- T *heap_dis_in = heap_dis_tab + q * k;
334
- TI *heap_ids_in = heap_ids_tab + q * k;
335
- heap_reorder<C> (k, heap_dis_in, heap_ids_in);
336
- int64_t *heap_ids = labels + q * k;
337
- float *heap_dis = distances + q * k;
321
+ T* heap_dis_in = heap_dis_tab + q * k;
322
+ TI* heap_ids_in = heap_ids_tab + q * k;
323
+ heap_reorder<C>(k, heap_dis_in, heap_ids_in);
324
+ int64_t* heap_ids = labels + q * k;
325
+ float* heap_dis = distances + q * k;
338
326
 
339
327
  float one_a = 1.0, b = 0.0;
340
328
  if (normalizers) {
@@ -347,10 +335,8 @@ struct HeapHandler: SIMDResultHandler<C, with_id_map> {
347
335
  }
348
336
  }
349
337
  }
350
-
351
338
  };
352
339
 
353
-
354
340
  /** Simple top-N implementation using a reservoir.
355
341
  *
356
342
  * Results are stored when they are below the threshold until the capacity is
@@ -358,12 +344,10 @@ struct HeapHandler: SIMDResultHandler<C, with_id_map> {
358
344
 
359
345
  namespace {
360
346
 
361
- uint64_t get_cy () {
362
- #ifdef MICRO_BENCHMARK
347
+ uint64_t get_cy() {
348
+ #ifdef MICRO_BENCHMARK
363
349
  uint32_t high, low;
364
- asm volatile("rdtsc \n\t"
365
- : "=a" (low),
366
- "=d" (high));
350
+ asm volatile("rdtsc \n\t" : "=a"(low), "=d"(high));
367
351
  return ((uint64_t)high << 32) | (low);
368
352
  #else
369
353
  return 0;
@@ -372,27 +356,23 @@ uint64_t get_cy () {
372
356
 
373
357
  } // anonymous namespace
374
358
 
375
- template<class C>
359
+ template <class C>
376
360
  struct ReservoirTopN {
377
361
  using T = typename C::T;
378
362
  using TI = typename C::TI;
379
363
 
380
- T *vals;
381
- TI *ids;
364
+ T* vals;
365
+ TI* ids;
382
366
 
383
- size_t i; // number of stored elements
384
- size_t n; // number of requested elements
385
- size_t capacity; // size of storage
367
+ size_t i; // number of stored elements
368
+ size_t n; // number of requested elements
369
+ size_t capacity; // size of storage
386
370
  size_t cycles = 0;
387
371
 
388
372
  T threshold; // current threshold
389
373
 
390
- ReservoirTopN(
391
- size_t n, size_t capacity,
392
- T *vals, TI *ids
393
- ):
394
- vals(vals), ids(ids),
395
- i(0), n(n), capacity(capacity) {
374
+ ReservoirTopN(size_t n, size_t capacity, T* vals, TI* ids)
375
+ : vals(vals), ids(ids), i(0), n(n), capacity(capacity) {
396
376
  assert(n < capacity);
397
377
  threshold = C::neutral();
398
378
  }
@@ -411,11 +391,11 @@ struct ReservoirTopN {
411
391
  /// shrink number of stored elements to n
412
392
  void shrink_xx() {
413
393
  uint64_t t0 = get_cy();
414
- qselect (vals, ids, i, n);
415
- i = n; // forget all elements above i = n
394
+ qselect(vals, ids, i, n);
395
+ i = n; // forget all elements above i = n
416
396
  threshold = C::Crev::neutral();
417
- for(size_t j = 0; j < n; j++) {
418
- if(C::cmp(vals[j], threshold)) {
397
+ for (size_t j = 0; j < n; j++) {
398
+ if (C::cmp(vals[j], threshold)) {
419
399
  threshold = vals[j];
420
400
  }
421
401
  }
@@ -433,16 +413,14 @@ struct ReservoirTopN {
433
413
  uint64_t t0 = get_cy();
434
414
  assert(i == capacity);
435
415
  threshold = partition_fuzzy<C>(
436
- vals, ids, capacity, n, (capacity + n) / 2,
437
- &i);
416
+ vals, ids, capacity, n, (capacity + n) / 2, &i);
438
417
  cycles += get_cy() - t0;
439
418
  }
440
419
  };
441
420
 
442
-
443
421
  /** Handler built from several ReservoirTopN (one per query) */
444
- template<class C, bool with_id_map = false>
445
- struct ReservoirHandler: SIMDResultHandler<C, with_id_map> {
422
+ template <class C, bool with_id_map = false>
423
+ struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
446
424
  using T = typename C::T;
447
425
  using TI = typename C::TI;
448
426
 
@@ -454,30 +432,30 @@ struct ReservoirHandler: SIMDResultHandler<C, with_id_map> {
454
432
 
455
433
  uint64_t times[4];
456
434
 
457
- ReservoirHandler(size_t nq, size_t ntotal, size_t n, size_t capacity_in):
458
- SIMDResultHandler<C, with_id_map>(ntotal), capacity((capacity_in + 15) & ~15),
459
- all_ids(nq * capacity), all_vals(nq * capacity)
460
- {
435
+ ReservoirHandler(size_t nq, size_t ntotal, size_t n, size_t capacity_in)
436
+ : SIMDResultHandler<C, with_id_map>(ntotal),
437
+ capacity((capacity_in + 15) & ~15),
438
+ all_ids(nq * capacity),
439
+ all_vals(nq * capacity) {
461
440
  assert(capacity % 16 == 0);
462
441
  for (size_t i = 0; i < nq; i++) {
463
442
  reservoirs.emplace_back(
464
- n, capacity,
465
- all_vals.get() + i * capacity,
466
- all_ids.data() + i * capacity
467
- );
443
+ n,
444
+ capacity,
445
+ all_vals.get() + i * capacity,
446
+ all_ids.data() + i * capacity);
468
447
  }
469
448
  times[0] = times[1] = times[2] = times[3] = 0;
470
449
  }
471
450
 
472
-
473
451
  void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
474
452
  uint64_t t0 = get_cy();
475
- if(this->disable) {
453
+ if (this->disable) {
476
454
  return;
477
455
  }
478
456
  this->adjust_with_origin(q, d0, d1);
479
457
 
480
- ReservoirTopN<C> & res = reservoirs[q];
458
+ ReservoirTopN<C>& res = reservoirs[q];
481
459
  uint32_t lt_mask = this->get_lt_mask(res.threshold, b, d0, d1);
482
460
  uint64_t t1 = get_cy();
483
461
  times[0] += t1 - t0;
@@ -499,27 +477,27 @@ struct ReservoirHandler: SIMDResultHandler<C, with_id_map> {
499
477
  times[1] += get_cy() - t1;
500
478
  }
501
479
 
502
-
503
480
  void to_flat_arrays(
504
- float *distances, int64_t *labels,
505
- const float *normalizers = nullptr
506
- ) override {
481
+ float* distances,
482
+ int64_t* labels,
483
+ const float* normalizers = nullptr) override {
507
484
  using Cf = typename std::conditional<
508
485
  C::is_max,
509
- CMax<float, int64_t>, CMin<float, int64_t>>::type;
486
+ CMax<float, int64_t>,
487
+ CMin<float, int64_t>>::type;
510
488
 
511
489
  uint64_t t0 = get_cy();
512
490
  uint64_t t3 = 0;
513
491
  std::vector<int> perm(reservoirs[0].n);
514
492
  for (int q = 0; q < reservoirs.size(); q++) {
515
- ReservoirTopN<C> & res = reservoirs[q];
493
+ ReservoirTopN<C>& res = reservoirs[q];
516
494
  size_t n = res.n;
517
495
 
518
496
  if (res.i > res.n) {
519
497
  res.shrink();
520
498
  }
521
- int64_t *heap_ids = labels + q * n;
522
- float *heap_dis = distances + q * n;
499
+ int64_t* heap_ids = labels + q * n;
500
+ float* heap_dis = distances + q * n;
523
501
 
524
502
  float one_a = 1.0, b = 0.0;
525
503
  if (normalizers) {
@@ -530,30 +508,24 @@ struct ReservoirHandler: SIMDResultHandler<C, with_id_map> {
530
508
  perm[i] = i;
531
509
  }
532
510
  // indirect sort of result arrays
533
- std::sort(
534
- perm.begin(), perm.begin() + res.i,
535
- [&res](int i, int j) {
536
- return C::cmp(res.vals[j], res.vals[i]);
537
- }
538
- );
511
+ std::sort(perm.begin(), perm.begin() + res.i, [&res](int i, int j) {
512
+ return C::cmp(res.vals[j], res.vals[i]);
513
+ });
539
514
  for (int i = 0; i < res.i; i++) {
540
515
  heap_dis[i] = res.vals[perm[i]] * one_a + b;
541
516
  heap_ids[i] = res.ids[perm[i]];
542
517
  }
543
518
 
544
519
  // possibly add empty results
545
- heap_heapify<Cf> (n - res.i, heap_dis + res.i, heap_ids + res.i);
520
+ heap_heapify<Cf>(n - res.i, heap_dis + res.i, heap_ids + res.i);
546
521
 
547
522
  t3 += res.cycles;
548
523
  }
549
524
  times[2] += get_cy() - t0;
550
525
  times[3] += t3;
551
526
  }
552
-
553
527
  };
554
528
 
555
-
556
529
  } // namespace simd_result_handlers
557
530
 
558
-
559
531
  } // namespace faiss