faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -7,8 +7,8 @@

  #include <faiss/utils/partitioning.h>

- #include <cmath>
  #include <cassert>
+ #include <cmath>

  #include <faiss/impl/FaissAssert.h>
  #include <faiss/utils/AlignedTable.h>
@@ -19,15 +19,13 @@

  namespace faiss {

-
  /******************************************************************
  * Internal routines
  ******************************************************************/

-
  namespace partitioning {

- template<typename T>
+ template <typename T>
  T median3(T a, T b, T c) {
  if (a > b) {
  std::swap(a, b);
@@ -41,12 +39,12 @@ T median3(T a, T b, T c) {
  return a;
  }

-
- template<class C>
+ template <class C>
  typename C::T sample_threshold_median3(
- const typename C::T * vals, int n,
- typename C::T thresh_inf, typename C::T thresh_sup
- ) {
+ const typename C::T* vals,
+ int n,
+ typename C::T thresh_inf,
+ typename C::T thresh_sup) {
  using T = typename C::T;
  size_t big_prime = 6700417;
  T val3[3];
@@ -73,31 +71,34 @@ typename C::T sample_threshold_median3(
  }
  }

- template<class C>
+ template <class C>
  void count_lt_and_eq(
- const typename C::T * vals, size_t n, typename C::T thresh,
- size_t & n_lt, size_t & n_eq
- ) {
+ const typename C::T* vals,
+ size_t n,
+ typename C::T thresh,
+ size_t& n_lt,
+ size_t& n_eq) {
  n_lt = n_eq = 0;

- for(size_t i = 0; i < n; i++) {
+ for (size_t i = 0; i < n; i++) {
  typename C::T v = *vals++;
- if(C::cmp(thresh, v)) {
+ if (C::cmp(thresh, v)) {
  n_lt++;
- } else if(v == thresh) {
+ } else if (v == thresh) {
  n_eq++;
  }
  }
  }

-
- template<class C>
+ template <class C>
  size_t compress_array(
- typename C::T *vals, typename C::TI * ids,
- size_t n, typename C::T thresh, size_t n_eq
- ) {
+ typename C::T* vals,
+ typename C::TI* ids,
+ size_t n,
+ typename C::T thresh,
+ size_t n_eq) {
  size_t wp = 0;
- for(size_t i = 0; i < n; i++) {
+ for (size_t i = 0; i < n; i++) {
  if (C::cmp(thresh, vals[i])) {
  vals[wp] = vals[i];
  ids[wp] = ids[i];
@@ -113,15 +114,16 @@ size_t compress_array(
  return wp;
  }

+ #define IFV if (false)

- #define IFV if(false)
-
- template<class C>
+ template <class C>
  typename C::T partition_fuzzy_median3(
- typename C::T *vals, typename C::TI * ids, size_t n,
- size_t q_min, size_t q_max, size_t * q_out)
- {
-
+ typename C::T* vals,
+ typename C::TI* ids,
+ size_t n,
+ size_t q_min,
+ size_t q_max,
+ size_t* q_out) {
  if (q_min == 0) {
  if (q_out) {
  *q_out = C::Crev::neutral();
@@ -150,12 +152,19 @@ typename C::T partition_fuzzy_median3(
  size_t n_eq = 0, n_lt = 0;
  size_t q = 0;

- for(int it = 0; it < 200; it++) {
+ for (int it = 0; it < 200; it++) {
  count_lt_and_eq<C>(vals, n, thresh, n_lt, n_eq);

- IFV printf(" thresh=%g [%g %g] n_lt=%ld n_eq=%ld, q=%ld:%ld/%ld\n",
- float(thresh), float(thresh_inf), float(thresh_sup),
- long(n_lt), long(n_eq), long(q_min), long(q_max), long(n));
+ IFV printf(
+ " thresh=%g [%g %g] n_lt=%ld n_eq=%ld, q=%ld:%ld/%ld\n",
+ float(thresh),
+ float(thresh_inf),
+ float(thresh_sup),
+ long(n_lt),
+ long(n_eq),
+ long(q_min),
+ long(q_max),
+ long(n));

  if (n_lt <= q_min) {
  if (n_lt + n_eq >= q_min) {
@@ -172,8 +181,12 @@
  }
  }
  // FIXME avoid a second pass over the array to sample the threshold
- IFV printf(" sample thresh in [%g %g]\n", float(thresh_inf), float(thresh_sup));
- T new_thresh = sample_threshold_median3<C>(vals, n, thresh_inf, thresh_sup);
+ IFV printf(
+ " sample thresh in [%g %g]\n",
+ float(thresh_inf),
+ float(thresh_sup));
+ T new_thresh =
+ sample_threshold_median3<C>(vals, n, thresh_inf, thresh_sup);
  if (new_thresh == thresh_inf) {
  // then there is nothing between thresh_inf and thresh_sup
  break;
@@ -203,25 +216,19 @@ typename C::T partition_fuzzy_median3(
  return thresh;
  }

-
  } // namespace partitioning

-
-
  /******************************************************************
  * SIMD routines when vals is an aligned array of uint16_t
  ******************************************************************/

-
  namespace simd_partitioning {

-
-
  void find_minimax(
- const uint16_t * vals, size_t n,
- uint16_t & smin, uint16_t & smax
- ) {
-
+ const uint16_t* vals,
+ size_t n,
+ uint16_t& smin,
+ uint16_t& smax) {
  simd16uint16 vmin(0xffff), vmax(0);
  for (size_t i = 0; i + 15 < n; i += 16) {
  simd16uint16 v(vals + i);
@@ -235,22 +242,20 @@ void find_minimax(

  smin = tab32[0], smax = tab32[16];

- for(int i = 1; i < 16; i++) {
+ for (int i = 1; i < 16; i++) {
  smin = std::min(smin, tab32[i]);
  smax = std::max(smax, tab32[i + 16]);
  }

  // missing values
- for(size_t i = (n & ~15); i < n; i++) {
+ for (size_t i = (n & ~15); i < n; i++) {
  smin = std::min(smin, vals[i]);
  smax = std::max(smax, vals[i]);
  }
-
  }

-
  // max func differentiates between CMin and CMax (keep lowest or largest)
- template<class C>
+ template <class C>
  simd16uint16 max_func(simd16uint16 v, simd16uint16 thr16) {
  constexpr bool is_max = C::is_max;
  if (is_max) {
@@ -260,11 +265,13 @@ simd16uint16 max_func(simd16uint16 v, simd16uint16 thr16) {
  }
  }

- template<class C>
+ template <class C>
  void count_lt_and_eq(
- const uint16_t * vals, int n, uint16_t thresh,
- size_t & n_lt, size_t & n_eq
- ) {
+ const uint16_t* vals,
+ int n,
+ uint16_t thresh,
+ size_t& n_lt,
+ size_t& n_eq) {
  n_lt = n_eq = 0;
  simd16uint16 thr16(thresh);

@@ -283,24 +290,25 @@ void count_lt_and_eq(
  n_lt += 16 - i_ge;
  }

- for(size_t i = n1 * 16; i < n; i++) {
+ for (size_t i = n1 * 16; i < n; i++) {
  uint16_t v = *vals++;
- if(C::cmp(thresh, v)) {
+ if (C::cmp(thresh, v)) {
  n_lt++;
- } else if(v == thresh) {
+ } else if (v == thresh) {
  n_eq++;
  }
  }
  }

-
-
  /* compress separated values and ids table, keeping all values < thresh and at
  * most n_eq equal values */
- template<class C>
+ template <class C>
  int simd_compress_array(
- uint16_t *vals, typename C::TI * ids, size_t n, uint16_t thresh, int n_eq
- ) {
+ uint16_t* vals,
+ typename C::TI* ids,
+ size_t n,
+ uint16_t thresh,
+ int n_eq) {
  simd16uint16 thr16(thresh);
  simd16uint16 mixmask(0xff00);

@@ -313,13 +321,15 @@ int simd_compress_array(
  simd16uint16 max2 = max_func<C>(v, thr16);
  simd16uint16 gemask = (v == max2);
  simd16uint16 eqmask = (v == thr16);
- uint32_t bits = get_MSBs(blendv(
- simd32uint8(eqmask), simd32uint8(gemask), simd32uint8(mixmask)));
+ uint32_t bits = get_MSBs(
+ blendv(simd32uint8(eqmask),
+ simd32uint8(gemask),
+ simd32uint8(mixmask)));
  bits ^= 0xAAAAAAAA;
  // bit 2*i : eq
  // bit 2*i + 1 : lt

- while(bits) {
+ while (bits) {
  int j = __builtin_ctz(bits) & (~1);
  bool is_eq = (bits >> j) & 1;
  bool is_lt = (bits >> j) & 2;
@@ -330,7 +340,7 @@ int simd_compress_array(
  vals[wp] = vals[i0 + j];
  ids[wp] = ids[i0 + j];
  wp++;
- } else if(is_eq && n_eq > 0) {
+ } else if (is_eq && n_eq > 0) {
  vals[wp] = vals[i0 + j];
  ids[wp] = ids[i0 + j];
  wp++;
@@ -346,7 +356,7 @@ int simd_compress_array(
  simd16uint16 gemask = (v == max2);
  uint32_t bits = ~get_MSBs(simd32uint8(gemask));

- while(bits) {
+ while (bits) {
  int j = __builtin_ctz(bits);
  bits &= ~(3 << j);
  j >>= 1;
@@ -358,7 +368,7 @@ int simd_compress_array(
  }

  // end with scalar
- for(int i = (n & ~15); i < n; i++) {
+ for (int i = (n & ~15); i < n; i++) {
  if (C::cmp(thresh, vals[i])) {
  vals[wp] = vals[i];
  ids[wp] = ids[i];
@@ -376,29 +386,28 @@ int simd_compress_array(

  // #define MICRO_BENCHMARK

- static uint64_t get_cy () {
- #ifdef MICRO_BENCHMARK
+ static uint64_t get_cy() {
+ #ifdef MICRO_BENCHMARK
  uint32_t high, low;
- asm volatile("rdtsc \n\t"
- : "=a" (low),
- "=d" (high));
+ asm volatile("rdtsc \n\t" : "=a"(low), "=d"(high));
  return ((uint64_t)high << 32) | (low);
  #else
  return 0;
  #endif
  }

+ #define IFV if (false)

-
- #define IFV if(false)
-
- template<class C>
+ template <class C>
  uint16_t simd_partition_fuzzy_with_bounds(
- uint16_t *vals, typename C::TI * ids, size_t n,
- size_t q_min, size_t q_max, size_t * q_out,
- uint16_t s0i, uint16_t s1i)
- {
-
+ uint16_t* vals,
+ typename C::TI* ids,
+ size_t n,
+ size_t q_min,
+ size_t q_max,
+ size_t* q_out,
+ uint16_t s0i,
+ uint16_t s1i) {
  if (q_min == 0) {
  if (q_out) {
  *q_out = 0;
@@ -428,13 +437,21 @@ uint16_t simd_partition_fuzzy_with_bounds(
  size_t n_eq = 0, n_lt = 0;
  size_t q = 0;

- for(int it = 0; it < 200; it++) {
+ for (int it = 0; it < 200; it++) {
  // while(s0 + 1 < s1) {
  thresh = (s0 + s1) / 2;
  count_lt_and_eq<C>(vals, n, thresh, n_lt, n_eq);

- IFV printf(" [%ld %ld] thresh=%d n_lt=%ld n_eq=%ld, q=%ld:%ld/%ld\n",
- s0, s1, thresh, n_lt, n_eq, q_min, q_max, n);
+ IFV printf(
+ " [%ld %ld] thresh=%d n_lt=%ld n_eq=%ld, q=%ld:%ld/%ld\n",
+ s0,
+ s1,
+ thresh,
+ n_lt,
+ n_eq,
+ q_min,
+ q_max,
+ n);
  if (n_lt <= q_min) {
  if (n_lt + n_eq >= q_min) {
  q = q_min;
@@ -456,7 +473,6 @@ uint16_t simd_partition_fuzzy_with_bounds(
  s0 = thresh;
  }
  }
-
  }

  uint64_t t1 = get_cy();
@@ -495,14 +511,16 @@ uint16_t simd_partition_fuzzy_with_bounds(
  return thresh;
  }

-
- template<class C>
+ template <class C>
  uint16_t simd_partition_fuzzy_with_bounds_histogram(
- uint16_t *vals, typename C::TI * ids, size_t n,
- size_t q_min, size_t q_max, size_t * q_out,
- uint16_t s0i, uint16_t s1i)
- {
-
+ uint16_t* vals,
+ typename C::TI* ids,
+ size_t n,
+ size_t q_min,
+ size_t q_max,
+ size_t* q_out,
+ uint16_t s0i,
+ uint16_t s1i) {
  if (q_min == 0) {
  if (q_out) {
  *q_out = 0;
@@ -522,11 +540,17 @@ uint16_t simd_partition_fuzzy_with_bounds_histogram(
  return s0i;
  }

- IFV printf("partition fuzzy, q=%ld:%ld / %ld, bounds=%d %d\n",
- q_min, q_max, n, s0i, s1i);
+ IFV printf(
+ "partition fuzzy, q=%ld:%ld / %ld, bounds=%d %d\n",
+ q_min,
+ q_max,
+ n,
+ s0i,
+ s1i);

  if (!C::is_max) {
- IFV printf("revert due to CMin, q_min:q_max -> %ld:%ld\n", q_min, q_max);
+ IFV printf(
+ "revert due to CMin, q_min:q_max -> %ld:%ld\n", q_min, q_max);
  q_min = n - q_min;
  q_max = n - q_max;
  }
@@ -537,31 +561,39 @@ uint16_t simd_partition_fuzzy_with_bounds_histogram(
  size_t n_lt = 0, n_gt = 0;

  // output of loop:
- int thresh; // final threshold
- uint64_t tot_eq = 0; // total nb of equal values
- uint64_t n_eq = 0; // nb of equal values to keep
- size_t q; // final quantile
+ int thresh; // final threshold
+ uint64_t tot_eq = 0; // total nb of equal values
+ uint64_t n_eq = 0; // nb of equal values to keep
+ size_t q; // final quantile

  // buffer for the histograms
  int hist[16];

- for(int it = 0; it < 20; it++) {
+ for (int it = 0; it < 20; it++) {
  // otherwise we would be done already

  int shift = 0;

- IFV printf(" it %d bounds: %d %d n_lt=%ld n_gt=%ld\n",
- it, s0, s1, n_lt, n_gt);
+ IFV printf(
+ " it %d bounds: %d %d n_lt=%ld n_gt=%ld\n",
+ it,
+ s0,
+ s1,
+ n_lt,
+ n_gt);

  int maxval = s1 - s0;

- while(maxval > 15) {
+ while (maxval > 15) {
  shift++;
  maxval >>= 1;
  }

- IFV printf(" histogram shift %d maxval %d ?= %d\n",
- shift, maxval, int((s1 - s0) >> shift));
+ IFV printf(
+ " histogram shift %d maxval %d ?= %d\n",
+ shift,
+ maxval,
+ int((s1 - s0) >> shift));

  if (maxval > 7) {
  simd_histogram_16(vals, n, s0, shift, hist);
@@ -571,7 +603,7 @@ uint16_t simd_partition_fuzzy_with_bounds_histogram(
  IFV {
  int sum = n_lt + n_gt;
  printf(" n_lt=%ld hist=[", n_lt);
- for(int i = 0; i <= maxval; i++) {
+ for (int i = 0; i <= maxval; i++) {
  printf("%d ", hist[i]);
  sum += hist[i];
  }
@@ -597,7 +629,12 @@ uint16_t simd_partition_fuzzy_with_bounds_histogram(
  assert(!"not implemented");
  }

- IFV printf(" new bin: s0=%d s1=%d n_lt=%ld n_gt=%ld\n", s0, s1, n_lt, n_gt);
+ IFV printf(
+ " new bin: s0=%d s1=%d n_lt=%ld n_gt=%ld\n",
+ s0,
+ s1,
+ n_lt,
+ n_gt);

  if (s1 > s0) {
  if (n_lt >= q_min && q_max >= n_lt) {
@@ -628,7 +665,7 @@ uint16_t simd_partition_fuzzy_with_bounds_histogram(

  if (!C::is_max) {
  if (n_eq == 0) {
- thresh --;
+ thresh--;
  } else {
  // thresh unchanged
  n_eq = tot_eq - n_eq;
@@ -647,14 +684,14 @@ uint16_t simd_partition_fuzzy_with_bounds_histogram(
  return thresh;
  }

-
-
- template<class C>
+ template <class C>
  uint16_t simd_partition_fuzzy(
- uint16_t *vals, typename C::TI * ids, size_t n,
- size_t q_min, size_t q_max, size_t * q_out
- ) {
-
+ uint16_t* vals,
+ typename C::TI* ids,
+ size_t n,
+ size_t q_min,
+ size_t q_max,
+ size_t* q_out) {
  assert(is_aligned_pointer(vals));

  uint16_t s0i, s1i;
@@ -662,14 +699,15 @@ uint16_t simd_partition_fuzzy(
  // QSelect_stats.t0 += get_cy() - t0;

  return simd_partition_fuzzy_with_bounds<C>(
- vals, ids, n, q_min, q_max, q_out, s0i, s1i);
+ vals, ids, n, q_min, q_max, q_out, s0i, s1i);
  }

-
-
- template<class C>
- uint16_t simd_partition(uint16_t *vals, typename C::TI * ids, size_t n, size_t q) {
-
+ template <class C>
+ uint16_t simd_partition(
+ uint16_t* vals,
+ typename C::TI* ids,
+ size_t n,
+ size_t q) {
  assert(is_aligned_pointer(vals));

  if (q == 0) {
@@ -683,72 +721,97 @@ uint16_t simd_partition(uint16_t *vals, typename C::TI * ids, size_t n, size_t q
  find_minimax(vals, n, s0i, s1i);

  return simd_partition_fuzzy_with_bounds<C>(
- vals, ids, n, q, q, nullptr, s0i, s1i);
+ vals, ids, n, q, q, nullptr, s0i, s1i);
  }

- template<class C>
+ template <class C>
  uint16_t simd_partition_with_bounds(
- uint16_t *vals, typename C::TI * ids, size_t n, size_t q,
- uint16_t s0i, uint16_t s1i)
- {
+ uint16_t* vals,
+ typename C::TI* ids,
+ size_t n,
+ size_t q,
+ uint16_t s0i,
+ uint16_t s1i) {
  return simd_partition_fuzzy_with_bounds<C>(
- vals, ids, n, q, q, nullptr, s0i, s1i);
+ vals, ids, n, q, q, nullptr, s0i, s1i);
  }

  } // namespace simd_partitioning

-
  /******************************************************************
  * Driver routine
  ******************************************************************/

-
- template<class C>
+ template <class C>
  typename C::T partition_fuzzy(
- typename C::T *vals, typename C::TI * ids, size_t n,
- size_t q_min, size_t q_max, size_t * q_out)
- {
+ typename C::T* vals,
+ typename C::TI* ids,
+ size_t n,
+ size_t q_min,
+ size_t q_max,
+ size_t* q_out) {
  // the code below compiles and runs without AVX2 but it's slower than
  // the scalar implementation
  #ifdef __AVX2__
  constexpr bool is_uint16 = std::is_same<typename C::T, uint16_t>::value;
  if (is_uint16 && is_aligned_pointer(vals)) {
  return simd_partitioning::simd_partition_fuzzy<C>(
- (uint16_t*)vals, ids, n, q_min, q_max, q_out);
+ (uint16_t*)vals, ids, n, q_min, q_max, q_out);
  }
  #endif
  return partitioning::partition_fuzzy_median3<C>(
- vals, ids, n, q_min, q_max, q_out);
+ vals, ids, n, q_min, q_max, q_out);
  }

-
  // explicit template instanciations

- template float partition_fuzzy<CMin<float, int64_t>> (
- float *vals, int64_t * ids, size_t n,
- size_t q_min, size_t q_max, size_t * q_out);
-
- template float partition_fuzzy<CMax<float, int64_t>> (
- float *vals, int64_t * ids, size_t n,
- size_t q_min, size_t q_max, size_t * q_out);
-
- template uint16_t partition_fuzzy<CMin<uint16_t, int64_t>> (
- uint16_t *vals, int64_t * ids, size_t n,
- size_t q_min, size_t q_max, size_t * q_out);
-
- template uint16_t partition_fuzzy<CMax<uint16_t, int64_t>> (
- uint16_t *vals, int64_t * ids, size_t n,
- size_t q_min, size_t q_max, size_t * q_out);
-
- template uint16_t partition_fuzzy<CMin<uint16_t, int>> (
- uint16_t *vals, int * ids, size_t n,
- size_t q_min, size_t q_max, size_t * q_out);
-
- template uint16_t partition_fuzzy<CMax<uint16_t, int>> (
- uint16_t *vals, int * ids, size_t n,
- size_t q_min, size_t q_max, size_t * q_out);
-
-
+ template float partition_fuzzy<CMin<float, int64_t>>(
+ float* vals,
+ int64_t* ids,
+ size_t n,
+ size_t q_min,
+ size_t q_max,
+ size_t* q_out);
+
+ template float partition_fuzzy<CMax<float, int64_t>>(
+ float* vals,
+ int64_t* ids,
+ size_t n,
+ size_t q_min,
+ size_t q_max,
+ size_t* q_out);
+
+ template uint16_t partition_fuzzy<CMin<uint16_t, int64_t>>(
+ uint16_t* vals,
+ int64_t* ids,
+ size_t n,
+ size_t q_min,
+ size_t q_max,
+ size_t* q_out);
+
+ template uint16_t partition_fuzzy<CMax<uint16_t, int64_t>>(
+ uint16_t* vals,
+ int64_t* ids,
+ size_t n,
+ size_t q_min,
+ size_t q_max,
+ size_t* q_out);
+
+ template uint16_t partition_fuzzy<CMin<uint16_t, int>>(
+ uint16_t* vals,
+ int* ids,
+ size_t n,
+ size_t q_min,
+ size_t q_max,
+ size_t* q_out);
+
+ template uint16_t partition_fuzzy<CMax<uint16_t, int>>(
+ uint16_t* vals,
+ int* ids,
+ size_t n,
+ size_t q_min,
+ size_t q_max,
+ size_t* q_out);

  /******************************************************************
  * Histogram subroutines
@@ -758,7 +821,7 @@ template uint16_t partition_fuzzy<CMax<uint16_t, int>> (
  /// FIXME when MSB of uint16 is set
  // this code does not compile properly with GCC 7.4.0

- namespace {
+ namespace {

  /************************************************************
  * 8 bins
@@ -773,7 +836,6 @@ simd32uint8 accu4to8(simd16uint16 a4) {
  return simd32uint8(_mm256_hadd_epi16(a8_0.i, a8_1.i));
  }

-
  simd16uint16 accu8to16(simd32uint8 a8) {
  simd16uint16 mask8(0x00ff);

@@ -783,27 +845,53 @@ simd16uint16 accu8to16(simd32uint8 a8) {
  return simd16uint16(_mm256_hadd_epi16(a8_0.i, a8_1.i));
  }

-
  static const simd32uint8 shifts(_mm256_setr_epi8(
- 1, 16, 0, 0, 4, 64, 0, 0,
- 0, 0, 1, 16, 0, 0, 4, 64,
- 1, 16, 0, 0, 4, 64, 0, 0,
- 0, 0, 1, 16, 0, 0, 4, 64
- ));
+ 1,
+ 16,
+ 0,
+ 0,
+ 4,
+ 64,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 16,
+ 0,
+ 0,
+ 4,
+ 64,
+ 1,
+ 16,
+ 0,
+ 0,
+ 4,
+ 64,
+ 0,
+ 0,
+ 0,
+ 0,
+ 1,
+ 16,
+ 0,
+ 0,
+ 4,
+ 64));

  // 2-bit accumulator: we can add only up to 3 elements
  // on output we return 2*4-bit results
  // preproc returns either an index in 0..7 or 0xffff
  // that yeilds a 0 when used in the table look-up
- template<int N, class Preproc>
+ template <int N, class Preproc>
  void compute_accu2(
- const uint16_t * & data,
- Preproc & pp,
- simd16uint16 & a4lo, simd16uint16 & a4hi
- ) {
+ const uint16_t*& data,
+ Preproc& pp,
+ simd16uint16& a4lo,
+ simd16uint16& a4hi) {
  simd16uint16 mask2(0x3333);
  simd16uint16 a2((uint16_t)0); // 2-bit accu
- for (int j = 0; j < N; j ++) {
+ for (int j = 0; j < N; j++) {
  simd16uint16 v(data);
  data += 16;
  v = pp(v);
@@ -815,34 +903,30 @@ void compute_accu2(
  a4hi += (a2 >> 2) & mask2;
  }

-
- template<class Preproc>
- simd16uint16 histogram_8(
- const uint16_t * data, Preproc pp,
- size_t n_in) {
-
- assert (n_in % 16 == 0);
+ template <class Preproc>
+ simd16uint16 histogram_8(const uint16_t* data, Preproc pp, size_t n_in) {
+ assert(n_in % 16 == 0);
  int n = n_in / 16;

  simd32uint8 a8lo(0);
  simd32uint8 a8hi(0);

- for(int i0 = 0; i0 < n; i0 += 15) {
- simd16uint16 a4lo(0); // 4-bit accus
+ for (int i0 = 0; i0 < n; i0 += 15) {
+ simd16uint16 a4lo(0); // 4-bit accus
  simd16uint16 a4hi(0);

  int i1 = std::min(i0 + 15, n);
  int i;
- for(i = i0; i + 2 < i1; i += 3) {
+ for (i = i0; i + 2 < i1; i += 3) {
  compute_accu2<3>(data, pp, a4lo, a4hi); // adds 3 max
  }
  switch (i1 - i) {
- case 2:
- compute_accu2<2>(data, pp, a4lo, a4hi);
- break;
- case 1:
- compute_accu2<1>(data, pp, a4lo, a4hi);
- break;
+ case 2:
+ compute_accu2<2>(data, pp, a4lo, a4hi);
+ break;
+ case 1:
+ compute_accu2<1>(data, pp, a4lo, a4hi);
+ break;
  }

  a8lo += accu4to8(a4lo);
@@ -859,50 +943,72 @@ simd16uint16 histogram_8(
  return a16;
  }

-
  /************************************************************
  * 16 bins
  ************************************************************/

-
-
  static const simd32uint8 shifts2(_mm256_setr_epi8(
- 1, 2, 4, 8, 16, 32, 64, (char)128,
- 1, 2, 4, 8, 16, 32, 64, (char)128,
- 1, 2, 4, 8, 16, 32, 64, (char)128,
- 1, 2, 4, 8, 16, 32, 64, (char)128
- ));
-
-
- simd32uint8 shiftr_16(simd32uint8 x, int n)
- {
+ 1,
+ 2,
+ 4,
+ 8,
+ 16,
+ 32,
+ 64,
+ (char)128,
+ 1,
+ 2,
+ 4,
+ 8,
+ 16,
+ 32,
+ 64,
+ (char)128,
+ 1,
+ 2,
+ 4,
+ 8,
+ 16,
+ 32,
+ 64,
+ (char)128,
+ 1,
+ 2,
+ 4,
+ 8,
+ 16,
+ 32,
+ 64,
+ (char)128));
+
+ simd32uint8 shiftr_16(simd32uint8 x, int n) {
  return simd32uint8(simd16uint16(x) >> n);
  }

-
  inline simd32uint8 combine_2x2(simd32uint8 a, simd32uint8 b) {
-
  __m256i a1b0 = _mm256_permute2f128_si256(a.i, b.i, 0x21);
  __m256i a0b1 = _mm256_blend_epi32(a.i, b.i, 0xF0);

  return simd32uint8(a1b0) + simd32uint8(a0b1);
  }

-
  // 2-bit accumulator: we can add only up to 3 elements
  // on output we return 2*4-bit results
- template<int N, class Preproc>
+ template <int N, class Preproc>
  void compute_accu2_16(
- const uint16_t * & data, Preproc pp,
- simd32uint8 & a4_0, simd32uint8 & a4_1,
- simd32uint8 & a4_2, simd32uint8 & a4_3
- ) {
+ const uint16_t*& data,
+ Preproc pp,
+ simd32uint8& a4_0,
+ simd32uint8& a4_1,
+ simd32uint8& a4_2,
+ simd32uint8& a4_3) {
  simd32uint8 mask1(0x55);
  simd32uint8 a2_0; // 2-bit accu
  simd32uint8 a2_1; // 2-bit accu
- a2_0.clear(); a2_1.clear();
+ a2_0.clear();
+ a2_1.clear();

- for (int j = 0; j < N; j ++) {
+ for (int j = 0; j < N; j++) {
  simd16uint16 v(data);
  data += 16;
  v = pp(v);
@@ -925,38 +1031,27 @@ void compute_accu2_16(
  a4_1 += a2_1 & mask2;
  a4_2 += shiftr_16(a2_0, 2) & mask2;
  a4_3 += shiftr_16(a2_1, 2) & mask2;
-
  }

-
  simd32uint8 accu4to8_2(simd32uint8 a4_0, simd32uint8 a4_1) {
  simd32uint8 mask4(0x0f);

- simd32uint8 a8_0 = combine_2x2(
- a4_0 & mask4,
- shiftr_16(a4_0, 4) & mask4
- );
+ simd32uint8 a8_0 = combine_2x2(a4_0 & mask4, shiftr_16(a4_0, 4) & mask4);

- simd32uint8 a8_1 = combine_2x2(
- a4_1 & mask4,
- shiftr_16(a4_1, 4) & mask4
- );
+ simd32uint8 a8_1 = combine_2x2(a4_1 & mask4, shiftr_16(a4_1, 4) & mask4);

  return simd32uint8(_mm256_hadd_epi16(a8_0.i, a8_1.i));
  }

-
-
- template<class Preproc>
- simd16uint16 histogram_16(const uint16_t * data, Preproc pp, size_t n_in) {
-
- assert (n_in % 16 == 0);
+ template <class Preproc>
+ simd16uint16 histogram_16(const uint16_t* data, Preproc pp, size_t n_in) {
+ assert(n_in % 16 == 0);
  int n = n_in / 16;

  simd32uint8 a8lo((uint8_t)0);
  simd32uint8 a8hi((uint8_t)0);

- for(int i0 = 0; i0 < n; i0 += 7) {
+ for (int i0 = 0; i0 < n; i0 += 7) {
  simd32uint8 a4_0(0); // 0, 4, 8, 12
  simd32uint8 a4_1(0); // 1, 5, 9, 13
  simd32uint8 a4_2(0); // 2, 6, 10, 14
@@ -964,16 +1059,16 @@ simd16uint16 histogram_16(const uint16_t * data, Preproc pp, size_t n_in) {

  int i1 = std::min(i0 + 7, n);
  int i;
- for(i = i0; i + 2 < i1; i += 3) {
+ for (i = i0; i + 2 < i1; i += 3) {
  compute_accu2_16<3>(data, pp, a4_0, a4_1, a4_2, a4_3);
  }
  switch (i1 - i) {
- case 2:
- compute_accu2_16<2>(data, pp, a4_0, a4_1, a4_2, a4_3);
- break;
- case 1:
- compute_accu2_16<1>(data, pp, a4_0, a4_1, a4_2, a4_3);
- break;
+ case 2:
+ compute_accu2_16<2>(data, pp, a4_0, a4_1, a4_2, a4_3);
+ break;
+ case 1:
+ compute_accu2_16<1>(data, pp, a4_0, a4_1, a4_2, a4_3);
+ break;
  }

  a8lo += accu4to8_2(a4_0, a4_1);
@@ -986,23 +1081,19 @@ simd16uint16 histogram_16(const uint16_t * data, Preproc pp, size_t n_in) {

  simd16uint16 a16 = simd16uint16(_mm256_hadd_epi16(a16lo.i, a16hi.i));

- __m256i perm32 = _mm256_setr_epi32(
- 0, 2, 4, 6, 1, 3, 5, 7
- );
+ __m256i perm32 = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
  a16.i = _mm256_permutevar8x32_epi32(a16.i, perm32);

  return a16;
  }

  struct PreprocNOP {
- simd16uint16 operator () (simd16uint16 x) {
+ simd16uint16 operator()(simd16uint16 x) {
  return x;
  }
-
  };

-
- template<int shift, int nbin>
+ template <int shift, int nbin>
  struct PreprocMinShift {
  simd16uint16 min16;
  simd16uint16 max16;
@@ -1014,59 +1105,46 @@ struct PreprocMinShift {
  max16.set1(vmax); // vmax inclusive
  }

- simd16uint16 operator () (simd16uint16 x) {
+ simd16uint16 operator()(simd16uint16 x) {
  x = x - min16;
  simd16uint16 mask = (x == max(x, max16)) - (x == max16);
  return (x >> shift) | mask;
  }
-
  };

  /* unbounded versions of the functions */

- void simd_histogram_8_unbounded(
- const uint16_t *data, int n,
- int *hist)
- {
+ void simd_histogram_8_unbounded(const uint16_t* data, int n, int* hist) {
  PreprocNOP pp;
  simd16uint16 a16 = histogram_8(data, pp, (n & ~15));

  ALIGNED(32) uint16_t a16_tab[16];
  a16.store(a16_tab);

- for(int i = 0; i < 8; i++) {
+ for (int i = 0; i < 8; i++) {
  hist[i] = a16_tab[i] + a16_tab[i + 8];
  }

- for(int i = (n & ~15); i < n; i++) {
+ for (int i = (n & ~15); i < n; i++) {
  hist[data[i]]++;
  }
-
  }

-
- void simd_histogram_16_unbounded(
- const uint16_t *data, int n,
- int *hist)
- {
-
+ void simd_histogram_16_unbounded(const uint16_t* data, int n, int* hist) {
  simd16uint16 a16 = histogram_16(data, PreprocNOP(), (n & ~15));

  ALIGNED(32) uint16_t a16_tab[16];
  a16.store(a16_tab);

- for(int i = 0; i < 16; i++) {
+ for (int i = 0; i < 16; i++) {
  hist[i] = a16_tab[i];
  }

- for(int i = (n & ~15); i < n; i++) {
+ for (int i = (n & ~15); i < n; i++) {
  hist[data[i]]++;
  }
-
  }

-
-
  } // anonymous namespace

  /************************************************************
@@ -1074,10 +1152,11 @@ void simd_histogram_16_unbounded(
  ************************************************************/

  void simd_histogram_8(
- const uint16_t *data, int n,
- uint16_t min, int shift,
- int *hist)
- {
+ const uint16_t* data,
+ int n,
+ uint16_t min,
+ int shift,
+ int* hist) {
  if (shift < 0) {
  simd_histogram_8_unbounded(data, n, hist);
  return;
@@ -1085,12 +1164,12 @@ void simd_histogram_8(

  simd16uint16 a16;

- #define DISPATCH(s) \
- case s: \
+ #define DISPATCH(s) \
+ case s: \
  a16 = histogram_8(data, PreprocMinShift<s, 8>(min), (n & ~15)); \
  break

- switch(shift) {
+ switch (shift) {
  DISPATCH(0);
  DISPATCH(1);
  DISPATCH(2);
@@ -1105,35 +1184,35 @@ void simd_histogram_8(
  DISPATCH(11);
  DISPATCH(12);
  DISPATCH(13);
- default:
- FAISS_THROW_FMT("dispatch for shift=%d not instantiated", shift);
+ default:
+ FAISS_THROW_FMT("dispatch for shift=%d not instantiated", shift);
  }
  #undef DISPATCH

  ALIGNED(32) uint16_t a16_tab[16];
  a16.store(a16_tab);

- for(int i = 0; i < 8; i++) {
+ for (int i = 0; i < 8; i++) {
  hist[i] = a16_tab[i] + a16_tab[i + 8];
  }

  // complete with remaining bins
- for(int i = (n & ~15); i < n; i++) {
- if (data[i] < min) continue;
+ for (int i = (n & ~15); i < n; i++) {
+ if (data[i] < min)
+ continue;
  uint16_t v = data[i] - min;
  v >>= shift;
- if (v < 8) hist[v]++;
+ if (v < 8)
+ hist[v]++;
  }
-
  }

-
-
  void simd_histogram_16(
- const uint16_t *data, int n,
- uint16_t min, int shift,
- int *hist)
- {
+ const uint16_t* data,
+ int n,
+ uint16_t min,
+ int shift,
+ int* hist) {
  if (shift < 0) {
  simd_histogram_16_unbounded(data, n, hist);
  return;
@@ -1141,12 +1220,12 @@ void simd_histogram_16(

  simd16uint16 a16;

- #define DISPATCH(s) \
- case s: \
+ #define DISPATCH(s) \
+ case s: \
  a16 = histogram_16(data, PreprocMinShift<s, 16>(min), (n & ~15)); \
  break

- switch(shift) {
+ switch (shift) {
  DISPATCH(0);
  DISPATCH(1);
  DISPATCH(2);
@@ -1160,48 +1239,47 @@ void simd_histogram_16(
  DISPATCH(10);
  DISPATCH(11);
  DISPATCH(12);
- default:
- FAISS_THROW_FMT("dispatch for shift=%d not instantiated", shift);
+ default:
+ FAISS_THROW_FMT("dispatch for shift=%d not instantiated", shift);
  }
  #undef DISPATCH

  ALIGNED(32) uint16_t a16_tab[16];
  a16.store(a16_tab);

- for(int i = 0; i < 16; i++) {
+ for (int i = 0; i < 16; i++) {
  hist[i] = a16_tab[i];
  }

- for(int i = (n & ~15); i < n; i++) {
- if (data[i] < min) continue;
+ for (int i = (n & ~15); i < n; i++) {
+ if (data[i] < min)
+ continue;
  uint16_t v = data[i] - min;
  v >>= shift;
- if (v < 16) hist[v]++;
+ if (v < 16)
+ hist[v]++;
  }
-
  }

-
  // no AVX2
  #else

-
-
  void simd_histogram_16(
- const uint16_t *data, int n,
- uint16_t min, int shift,
- int *hist)
- {
+ const uint16_t* data,
+ int n,
+ uint16_t min,
+ int shift,
+ int* hist) {
  memset(hist, 0, sizeof(*hist) * 16);
  if (shift < 0) {
- for(size_t i = 0; i < n; i++) {
+ for (size_t i = 0; i < n; i++) {
  hist[data[i]]++;
  }
  } else {
  int vmax0 = std::min((16 << shift) + min, 65536);
  uint16_t vmax = uint16_t(vmax0 - 1 - min);

- for(size_t i = 0; i < n; i++) {
+ for (size_t i = 0; i < n; i++) {
  uint16_t v = data[i];
  v -= min;
  if (!(v <= vmax))
@@ -1217,40 +1295,37 @@ void simd_histogram_16(
  */
  }
  }
-
  }

  void simd_histogram_8(
- const uint16_t *data, int n,
- uint16_t min, int shift,
- int *hist)
- {
+ const uint16_t* data,
+ int n,
+ uint16_t min,
+ int shift,
+ int* hist) {
  memset(hist, 0, sizeof(*hist) * 8);
  if (shift < 0) {
- for(size_t i = 0; i < n; i++) {
+ for (size_t i = 0; i < n; i++) {
  hist[data[i]]++;
  }
  } else {
- for(size_t i = 0; i < n; i++) {
- if (data[i] < min) continue;
+ for (size_t i = 0; i < n; i++) {
+ if (data[i] < min)
+ continue;
  uint16_t v = data[i] - min;
  v >>= shift;
- if (v < 8) hist[v]++;
+ if (v < 8)
+ hist[v]++;
  }
  }
-
  }

-
  #endif

-
  void PartitionStats::reset() {
  memset(this, 0, sizeof(*this));
  }

  PartitionStats partition_stats;

-
-
  } // namespace faiss
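
For context on the file shown above: partitioning.cpp exposes a single driver template, partition_fuzzy, whose explicit instantiations appear at the end of the diff. A minimal usage sketch follows, assuming the declarations in faiss/utils/partitioning.h and the CMin/CMax comparators from faiss/utils/ordered_key_value.h (both files are listed among the changes); the exact keep/threshold semantics are defined by the faiss sources above, so treat this as an illustration rather than a reference.

// Hedged sketch, not part of the published diff.
#include <cstdint>
#include <cstdio>
#include <vector>

#include <faiss/utils/ordered_key_value.h> // CMin / CMax comparators
#include <faiss/utils/partitioning.h>      // partition_fuzzy

int main() {
    // Eight distance values with their ids; ids are permuted alongside vals.
    std::vector<float> vals = {5, 1, 4, 8, 2, 7, 3, 6};
    std::vector<int64_t> ids = {0, 1, 2, 3, 4, 5, 6, 7};

    // Ask for a "fuzzy" partition: roughly q_min..q_max (here 3..4) of the
    // values that compare best under the chosen comparator end up in
    // vals[0..q_out); the function returns the threshold it used.
    size_t q_out = 0;
    float thresh = faiss::partition_fuzzy<faiss::CMax<float, int64_t>>(
            vals.data(), ids.data(), vals.size(), 3, 4, &q_out);

    std::printf("kept %zu values, threshold %g\n", q_out, thresh);
    return 0;
}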