faiss 0.1.3 → 0.2.0

Files changed (199)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +25 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +16 -4
  5. data/ext/faiss/ext.cpp +12 -308
  6. data/ext/faiss/extconf.rb +6 -3
  7. data/ext/faiss/index.cpp +189 -0
  8. data/ext/faiss/index_binary.cpp +75 -0
  9. data/ext/faiss/kmeans.cpp +40 -0
  10. data/ext/faiss/numo.hpp +867 -0
  11. data/ext/faiss/pca_matrix.cpp +33 -0
  12. data/ext/faiss/product_quantizer.cpp +53 -0
  13. data/ext/faiss/utils.cpp +13 -0
  14. data/ext/faiss/utils.h +5 -0
  15. data/lib/faiss.rb +0 -5
  16. data/lib/faiss/version.rb +1 -1
  17. data/vendor/faiss/faiss/AutoTune.cpp +36 -33
  18. data/vendor/faiss/faiss/AutoTune.h +6 -3
  19. data/vendor/faiss/faiss/Clustering.cpp +16 -12
  20. data/vendor/faiss/faiss/Index.cpp +3 -4
  21. data/vendor/faiss/faiss/Index.h +3 -3
  22. data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
  23. data/vendor/faiss/faiss/IndexBinary.h +1 -1
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
  26. data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
  27. data/vendor/faiss/faiss/IndexFlat.h +0 -51
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
  29. data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
  30. data/vendor/faiss/faiss/IndexIVF.h +22 -15
  31. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
  32. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
  33. data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
  34. data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
  35. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
  37. data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
  38. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
  39. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
  41. data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
  42. data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
  43. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
  44. data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
  45. data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
  46. data/vendor/faiss/faiss/IndexRefine.h +73 -0
  47. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
  48. data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
  49. data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
  50. data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
  51. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
  52. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
  53. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
  54. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
  55. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
  56. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
  57. data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
  58. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
  59. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
  60. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
  61. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
  62. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
  63. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
  64. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
  65. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
  66. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
  67. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
  68. data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
  69. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
  70. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
  71. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
  72. data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
  73. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
  74. data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
  75. data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
  76. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
  77. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
  78. data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
  79. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
  80. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
  81. data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
  82. data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
  83. data/vendor/faiss/faiss/impl/io.cpp +33 -2
  84. data/vendor/faiss/faiss/impl/io.h +7 -2
  85. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
  86. data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
  87. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
  88. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
  89. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
  90. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
  91. data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
  92. data/vendor/faiss/faiss/index_factory.cpp +112 -7
  93. data/vendor/faiss/faiss/index_io.h +1 -48
  94. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
  95. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
  96. data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
  97. data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
  98. data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
  99. data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
  100. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
  101. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
  102. data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
  103. data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
  104. data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
  105. data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
  106. data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
  107. data/vendor/faiss/faiss/utils/Heap.h +61 -50
  108. data/vendor/faiss/faiss/utils/distances.cpp +164 -319
  109. data/vendor/faiss/faiss/utils/distances.h +28 -20
  110. data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
  111. data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
  112. data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
  113. data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
  114. data/vendor/faiss/faiss/utils/hamming.h +2 -7
  115. data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
  116. data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
  117. data/vendor/faiss/faiss/utils/partitioning.h +69 -0
  118. data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
  119. data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
  120. data/vendor/faiss/faiss/utils/simdlib.h +31 -0
  121. data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
  122. data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
  123. metadata +54 -149
  124. data/lib/faiss/index.rb +0 -20
  125. data/lib/faiss/index_binary.rb +0 -20
  126. data/lib/faiss/kmeans.rb +0 -15
  127. data/lib/faiss/pca_matrix.rb +0 -15
  128. data/lib/faiss/product_quantizer.rb +0 -22
  129. data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
  130. data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
  131. data/vendor/faiss/c_api/AutoTune_c.h +0 -66
  132. data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
  133. data/vendor/faiss/c_api/Clustering_c.h +0 -123
  134. data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
  135. data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
  136. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
  137. data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
  138. data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
  139. data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
  140. data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
  141. data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
  142. data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
  143. data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
  144. data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
  145. data/vendor/faiss/c_api/IndexShards_c.h +0 -39
  146. data/vendor/faiss/c_api/Index_c.cpp +0 -105
  147. data/vendor/faiss/c_api/Index_c.h +0 -183
  148. data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
  149. data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
  150. data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
  151. data/vendor/faiss/c_api/clone_index_c.h +0 -32
  152. data/vendor/faiss/c_api/error_c.h +0 -42
  153. data/vendor/faiss/c_api/error_impl.cpp +0 -27
  154. data/vendor/faiss/c_api/error_impl.h +0 -16
  155. data/vendor/faiss/c_api/faiss_c.h +0 -58
  156. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
  157. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
  158. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
  159. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
  160. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
  161. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
  162. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
  163. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
  164. data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
  165. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
  166. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
  167. data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
  168. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
  169. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
  170. data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
  171. data/vendor/faiss/c_api/index_factory_c.h +0 -30
  172. data/vendor/faiss/c_api/index_io_c.cpp +0 -42
  173. data/vendor/faiss/c_api/index_io_c.h +0 -50
  174. data/vendor/faiss/c_api/macros_impl.h +0 -110
  175. data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
  176. data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
  177. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
  178. data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
  179. data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
  180. data/vendor/faiss/misc/test_blas.cpp +0 -87
  181. data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
  182. data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
  183. data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
  184. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
  185. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
  186. data/vendor/faiss/tests/test_merge.cpp +0 -260
  187. data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
  188. data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
  189. data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
  190. data/vendor/faiss/tests/test_params_override.cpp +0 -236
  191. data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
  192. data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
  193. data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
  194. data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
  195. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
  196. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
  197. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
  198. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
  199. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -176,8 +176,7 @@ void knn_extra_metrics_template (
             float disij = vd (x_i, y_j);

             if (disij < simi[0]) {
-                maxheap_pop (k, simi, idxi);
-                maxheap_push (k, simi, idxi, disij, j);
+                maxheap_replace_top (k, simi, idxi, disij, j);
             }
             y_j += d;
         }
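Several hunks in this release collapse the maxheap_pop / maxheap_push pair into a single maxheap_replace_top call. The sketch below is not the faiss implementation, just a minimal standalone illustration of the idea: when the new candidate beats the current worst element, overwriting the heap root and sifting it down once does the same job as a full pop followed by a push, with one traversal instead of two.

    #include <cstddef>
    #include <cstdio>
    #include <utility>
    #include <vector>

    // Illustration only (not the faiss code): replace the root of a binary
    // max-heap stored in heap[0..k) and sift it down to restore the heap
    // property. The real faiss heaps also carry a parallel array of ids and
    // use 1-based indexing internally.
    static void replace_top(std::vector<float>& heap, float val) {
        const size_t k = heap.size();
        size_t i = 0;
        heap[0] = val;
        for (;;) {
            size_t left = 2 * i + 1, right = left + 1, largest = i;
            if (left < k && heap[left] > heap[largest]) largest = left;
            if (right < k && heap[right] > heap[largest]) largest = right;
            if (largest == i) break;
            std::swap(heap[i], heap[largest]);
            i = largest;
        }
    }

    int main() {
        // max-heap of the 3 best (smallest) distances seen so far; worst on top
        std::vector<float> heap = {5.0f, 3.0f, 4.0f};
        float dis = 2.0f;
        if (dis < heap[0]) {        // same guard as in the patched loops
            replace_top(heap, dis); // one sift-down instead of pop + push
        }
        std::printf("worst kept distance: %g\n", heap[0]);  // prints 4
        return 0;
    }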
@@ -8,7 +8,7 @@
 namespace faiss {


-inline BitstringWriter::BitstringWriter(uint8_t *code, int code_size):
+inline BitstringWriter::BitstringWriter(uint8_t *code, size_t code_size):
     code (code), code_size (code_size), i(0)
 {
     memset (code, 0, code_size);
@@ -24,7 +24,7 @@ inline void BitstringWriter::write(uint64_t x, int nbit) {
         i += nbit;
         return;
     } else {
-        int j = i >> 3;
+        size_t j = i >> 3;
         code[j++] |= x << (i & 7);
         i += nbit;
         x >>= na;
@@ -36,7 +36,7 @@ inline void BitstringWriter::write(uint64_t x, int nbit) {
 }


-inline BitstringReader::BitstringReader(const uint8_t *code, int code_size):
+inline BitstringReader::BitstringReader(const uint8_t *code, size_t code_size):
     code (code), code_size (code_size), i(0)
 {}

@@ -52,7 +52,7 @@ inline uint64_t BitstringReader::read(int nbit) {
         return res;
     } else {
         int ofs = na;
-        int j = (i >> 3) + 1;
+        size_t j = (i >> 3) + 1;
        i += nbit;
        nbit -= na;
        while (nbit > 8) {
@@ -292,8 +292,7 @@ void hammings_knn_hc (
             for (j = j0; j < j1; j++, bs2_+= bytes_per_code) {
                 dis = hc.hamming (bs2_);
                 if (dis < bh_val_[0]) {
-                    faiss::maxheap_pop<hamdis_t> (k, bh_val_, bh_ids_);
-                    faiss::maxheap_push<hamdis_t> (k, bh_val_, bh_ids_, dis, j);
+                    faiss::maxheap_replace_top<hamdis_t> (k, bh_val_, bh_ids_, dis, j);
                 }
             }
         }
@@ -391,8 +390,7 @@ void hammings_knn_hc_1 (
         for (j = 0; j < n2; j++, bs2_+= nwords) {
             dis = popcount64 (bs1_ ^ *bs2_);
             if (dis < bh_val_0) {
-                faiss::maxheap_pop<hamdis_t> (k, bh_val_, bh_ids_);
-                faiss::maxheap_push<hamdis_t> (k, bh_val_, bh_ids_, dis, j);
+                faiss::maxheap_replace_top<hamdis_t> (k, bh_val_, bh_ids_, dis, j);
                 bh_val_0 = bh_val_[0];
             }
         }
@@ -818,8 +816,7 @@ static void hamming_dis_inner_loop (
         int ndiff = hc.hamming (cb);
         cb += code_size;
         if (ndiff < bh_val_[0]) {
-            maxheap_pop<hamdis_t> (k, bh_val_, bh_ids_);
-            maxheap_push<hamdis_t> (k, bh_val_, bh_ids_, ndiff, j);
+            maxheap_replace_top<hamdis_t> (k, bh_val_, bh_ids_, ndiff, j);
         }
     }
 }
@@ -27,11 +27,6 @@

 #include <stdint.h>

-#ifdef _MSC_VER
-#include <intrin.h>
-#define __builtin_popcountl __popcnt64
-#endif // _MSC_VER
-
 #include <faiss/impl/platform_macros.h>
 #include <faiss/utils/Heap.h>

@@ -91,7 +86,7 @@ struct BitstringWriter {
     size_t i; // current bit offset

     // code_size in bytes
-    BitstringWriter(uint8_t *code, int code_size);
+    BitstringWriter(uint8_t *code, size_t code_size);

     // write the nbit low bits of x
     void write(uint64_t x, int nbit);
@@ -103,7 +98,7 @@ struct BitstringReader {
     size_t i;

     // code_size in bytes
-    BitstringReader(const uint8_t *code, int code_size);
+    BitstringReader(const uint8_t *code, size_t code_size);

     // read nbit bits from the code
     uint64_t read(int nbit);
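The four hunks above widen the code_size argument of BitstringWriter and BitstringReader from int to size_t so that long codes do not overflow the byte offset. Below is a minimal usage sketch of the declared API; it assumes the declarations are reachable through <faiss/utils/hamming.h>, as the file list above suggests. Values written with write(x, nbit) come back in the same order via read(nbit).

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    #include <faiss/utils/hamming.h>  // assumed location of the declarations above

    int main() {
        std::vector<uint8_t> code(8, 0);  // code_size is now passed as a size_t

        faiss::BitstringWriter writer(code.data(), code.size());
        writer.write(5, 3);     // 3 low bits of 5
        writer.write(10, 4);    // 4 low bits of 10
        writer.write(300, 12);  // 12 low bits of 300

        faiss::BitstringReader reader(code.data(), code.size());
        int a = int(reader.read(3));
        int b = int(reader.read(4));
        int c = int(reader.read(12));
        std::printf("%d %d %d\n", a, b, c);  // expected: 5 10 300
        return 0;
    }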
@@ -0,0 +1,98 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+
+
+#pragma once
+
+#include <climits>
+#include <cmath>
+
+#include <limits>
+
+
+namespace faiss {
+
+/*******************************************************************
+ * C object: uniform handling of min and max heap
+ *******************************************************************/
+
+/** The C object gives the type T of the values of a key-value storage, the type
+ * of the keys, TI and the comparison that is done: CMax for a decreasing
+ * series and CMin for increasing series. In other words, for a given threshold
+ * threshold, an incoming value x is kept if
+ *
+ *      C::cmp(threshold, x)
+ *
+ * is true.
+ */
+
+template <typename T_, typename TI_>
+struct CMax;
+
+template<typename T> inline T cmin_nextafter(T x);
+template<typename T> inline T cmax_nextafter(T x);
+
+// traits of minheaps = heaps where the minimum value is stored on top
+// useful to find the *max* values of an array
+template <typename T_, typename TI_>
+struct CMin {
+    typedef T_ T;
+    typedef TI_ TI;
+    typedef CMax<T_, TI_> Crev; // reference to reverse comparison
+    inline static bool cmp (T a, T b) {
+        return a < b;
+    }
+    inline static T neutral () {
+        return std::numeric_limits<T>::lowest();
+    }
+    static const bool is_max = false;
+
+    inline static T nextafter(T x) {
+        return cmin_nextafter(x);
+    }
+};
+
+
+
+
+template <typename T_, typename TI_>
+struct CMax {
+    typedef T_ T;
+    typedef TI_ TI;
+    typedef CMin<T_, TI_> Crev;
+    inline static bool cmp (T a, T b) {
+        return a > b;
+    }
+    inline static T neutral () {
+        return std::numeric_limits<T>::max();
+    }
+    static const bool is_max = true;
+    inline static T nextafter(T x) {
+        return cmax_nextafter(x);
+    }
+};
+
+
+template<> inline float cmin_nextafter<float>(float x) {
+    return std::nextafterf(x, -HUGE_VALF);
+}
+
+template<> inline float cmax_nextafter<float>(float x) {
+    return std::nextafterf(x, HUGE_VALF);
+}
+
+template<> inline uint16_t cmin_nextafter<uint16_t>(uint16_t x) {
+    return x - 1;
+}
+
+template<> inline uint16_t cmax_nextafter<uint16_t>(uint16_t x) {
+    return x + 1;
+}
+
+
+} // namespace faiss
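The new ordered_key_value.h centralizes the CMin / CMax comparator traits shared by the heap code and the new partitioning module. Here is a small sketch of the convention documented in the header comment, assuming the header is reachable as <faiss/utils/ordered_key_value.h> (its path in this gem): with C = CMax, a value x is kept when C::cmp(threshold, x) holds, i.e. when x is strictly better than the current worst distance; CMin is the mirror image for similarity scores.

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    #include <faiss/utils/ordered_key_value.h>  // the header added above

    // Count how many values would be kept against a given threshold, using the
    // C::cmp(threshold, x) convention from the header comment.
    template <class C>
    static size_t count_kept(const std::vector<typename C::T>& vals,
                             typename C::T threshold) {
        size_t n = 0;
        for (typename C::T v : vals) {
            if (C::cmp(threshold, v)) {
                n++;
            }
        }
        return n;
    }

    int main() {
        std::vector<float> dis = {0.5f, 2.0f, 1.5f, 0.1f};
        using C = faiss::CMax<float, int64_t>;  // distance-style: smaller is better
        // neutral() is the worst possible value, so initially everything is kept
        std::printf("kept vs neutral(): %zu\n", count_kept<C>(dis, C::neutral()));  // 4
        std::printf("kept vs 1.0:       %zu\n", count_kept<C>(dis, 1.0f));          // 2 (0.5 and 0.1)
        return 0;
    }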
@@ -0,0 +1,1256 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/utils/partitioning.h>
9
+
10
+ #include <cmath>
11
+ #include <cassert>
12
+
13
+ #include <faiss/impl/FaissAssert.h>
14
+ #include <faiss/utils/AlignedTable.h>
15
+ #include <faiss/utils/ordered_key_value.h>
16
+ #include <faiss/utils/simdlib.h>
17
+
18
+ #include <faiss/impl/platform_macros.h>
19
+
20
+ namespace faiss {
21
+
22
+
23
+ /******************************************************************
24
+ * Internal routines
25
+ ******************************************************************/
26
+
27
+
28
+ namespace partitioning {
29
+
30
+ template<typename T>
31
+ T median3(T a, T b, T c) {
32
+ if (a > b) {
33
+ std::swap(a, b);
34
+ }
35
+ if (c > b) {
36
+ return b;
37
+ }
38
+ if (c > a) {
39
+ return c;
40
+ }
41
+ return a;
42
+ }
43
+
44
+
45
+ template<class C>
46
+ typename C::T sample_threshold_median3(
47
+ const typename C::T * vals, int n,
48
+ typename C::T thresh_inf, typename C::T thresh_sup
49
+ ) {
50
+ using T = typename C::T;
51
+ size_t big_prime = 6700417;
52
+ T val3[3];
53
+ int vi = 0;
54
+
55
+ for (size_t i = 0; i < n; i++) {
56
+ T v = vals[(i * big_prime) % n];
57
+ // thresh_inf < v < thresh_sup (for CMax)
58
+ if (C::cmp(v, thresh_inf) && C::cmp(thresh_sup, v)) {
59
+ val3[vi++] = v;
60
+ if (vi == 3) {
61
+ break;
62
+ }
63
+ }
64
+ }
65
+
66
+ if (vi == 3) {
67
+ return median3(val3[0], val3[1], val3[2]);
68
+ } else if (vi != 0) {
69
+ return val3[0];
70
+ } else {
71
+ return thresh_inf;
72
+ // FAISS_THROW_MSG("too few values to compute a median");
73
+ }
74
+ }
75
+
76
+ template<class C>
77
+ void count_lt_and_eq(
78
+ const typename C::T * vals, size_t n, typename C::T thresh,
79
+ size_t & n_lt, size_t & n_eq
80
+ ) {
81
+ n_lt = n_eq = 0;
82
+
83
+ for(size_t i = 0; i < n; i++) {
84
+ typename C::T v = *vals++;
85
+ if(C::cmp(thresh, v)) {
86
+ n_lt++;
87
+ } else if(v == thresh) {
88
+ n_eq++;
89
+ }
90
+ }
91
+ }
92
+
93
+
94
+ template<class C>
95
+ size_t compress_array(
96
+ typename C::T *vals, typename C::TI * ids,
97
+ size_t n, typename C::T thresh, size_t n_eq
98
+ ) {
99
+ size_t wp = 0;
100
+ for(size_t i = 0; i < n; i++) {
101
+ if (C::cmp(thresh, vals[i])) {
102
+ vals[wp] = vals[i];
103
+ ids[wp] = ids[i];
104
+ wp++;
105
+ } else if (n_eq > 0 && vals[i] == thresh) {
106
+ vals[wp] = vals[i];
107
+ ids[wp] = ids[i];
108
+ wp++;
109
+ n_eq--;
110
+ }
111
+ }
112
+ assert(n_eq == 0);
113
+ return wp;
114
+ }
115
+
116
+
117
+ #define IFV if(false)
118
+
119
+ template<class C>
120
+ typename C::T partition_fuzzy_median3(
121
+ typename C::T *vals, typename C::TI * ids, size_t n,
122
+ size_t q_min, size_t q_max, size_t * q_out)
123
+ {
124
+
125
+ if (q_min == 0) {
126
+ if (q_out) {
127
+ *q_out = C::Crev::neutral();
128
+ }
129
+ return 0;
130
+ }
131
+ if (q_max >= n) {
132
+ if (q_out) {
133
+ *q_out = q_max;
134
+ }
135
+ return C::neutral();
136
+ }
137
+
138
+ using T = typename C::T;
139
+
140
+ // here we use bissection with a median of 3 to find the threshold and
141
+ // compress the arrays afterwards. So it's a n*log(n) algoirithm rather than
142
+ // qselect's O(n) but it avoids shuffling around the array.
143
+
144
+ FAISS_THROW_IF_NOT(n >= 3);
145
+
146
+ T thresh_inf = C::Crev::neutral();
147
+ T thresh_sup = C::neutral();
148
+ T thresh = median3(vals[0], vals[n / 2], vals[n - 1]);
149
+
150
+ size_t n_eq = 0, n_lt = 0;
151
+ size_t q = 0;
152
+
153
+ for(int it = 0; it < 200; it++) {
154
+ count_lt_and_eq<C>(vals, n, thresh, n_lt, n_eq);
155
+
156
+ IFV printf(" thresh=%g [%g %g] n_lt=%ld n_eq=%ld, q=%ld:%ld/%ld\n",
157
+ float(thresh), float(thresh_inf), float(thresh_sup),
158
+ long(n_lt), long(n_eq), long(q_min), long(q_max), long(n));
159
+
160
+ if (n_lt <= q_min) {
161
+ if (n_lt + n_eq >= q_min) {
162
+ q = q_min;
163
+ break;
164
+ } else {
165
+ thresh_inf = thresh;
166
+ }
167
+ } else if (n_lt <= q_max) {
168
+ q = n_lt;
169
+ break;
170
+ } else {
171
+ thresh_sup = thresh;
172
+ }
173
+
174
+ // FIXME avoid a second pass over the array to sample the threshold
175
+ IFV printf(" sample thresh in [%g %g]\n", float(thresh_inf), float(thresh_sup));
176
+ T new_thresh = sample_threshold_median3<C>(vals, n, thresh_inf, thresh_sup);
177
+ if (new_thresh == thresh_inf) {
178
+ // then there is nothing between thresh_inf and thresh_sup
179
+ break;
180
+ }
181
+ thresh = new_thresh;
182
+ }
183
+
184
+ int64_t n_eq_1 = q - n_lt;
185
+
186
+ IFV printf("shrink: thresh=%g n_eq_1=%ld\n", float(thresh), long(n_eq_1));
187
+
188
+ if (n_eq_1 < 0) { // happens when > q elements are at lower bound
189
+ q = q_min;
190
+ thresh = C::Crev::nextafter(thresh);
191
+ n_eq_1 = q;
192
+ } else {
193
+ assert(n_eq_1 <= n_eq);
194
+ }
195
+
196
+ int wp = compress_array<C>(vals, ids, n, thresh, n_eq_1);
197
+
198
+ assert(wp == q);
199
+ if (q_out) {
200
+ *q_out = q;
201
+ }
202
+
203
+ return thresh;
204
+ }
205
+
206
+
207
+ } // namespace partitioning
208
+
209
+
210
+
211
+ /******************************************************************
212
+ * SIMD routines when vals is an aligned array of uint16_t
213
+ ******************************************************************/
214
+
215
+
216
+ namespace simd_partitioning {
217
+
218
+
219
+
220
+ void find_minimax(
221
+ const uint16_t * vals, size_t n,
222
+ uint16_t & smin, uint16_t & smax
223
+ ) {
224
+
225
+ simd16uint16 vmin(0xffff), vmax(0);
226
+ for (size_t i = 0; i + 15 < n; i += 16) {
227
+ simd16uint16 v(vals + i);
228
+ vmin.accu_min(v);
229
+ vmax.accu_max(v);
230
+ }
231
+
232
+ ALIGNED(32) uint16_t tab32[32];
233
+ vmin.store(tab32);
234
+ vmax.store(tab32 + 16);
235
+
236
+ smin = tab32[0], smax = tab32[16];
237
+
238
+ for(int i = 1; i < 16; i++) {
239
+ smin = std::min(smin, tab32[i]);
240
+ smax = std::max(smax, tab32[i + 16]);
241
+ }
242
+
243
+ // missing values
244
+ for(size_t i = (n & ~15); i < n; i++) {
245
+ smin = std::min(smin, vals[i]);
246
+ smax = std::max(smax, vals[i]);
247
+ }
248
+
249
+ }
250
+
251
+
252
+ // max func differentiates between CMin and CMax (keep lowest or largest)
253
+ template<class C>
254
+ simd16uint16 max_func(simd16uint16 v, simd16uint16 thr16) {
255
+ constexpr bool is_max = C::is_max;
256
+ if (is_max) {
257
+ return max(v, thr16);
258
+ } else {
259
+ return min(v, thr16);
260
+ }
261
+ }
262
+
263
+ template<class C>
264
+ void count_lt_and_eq(
265
+ const uint16_t * vals, int n, uint16_t thresh,
266
+ size_t & n_lt, size_t & n_eq
267
+ ) {
268
+ n_lt = n_eq = 0;
269
+ simd16uint16 thr16(thresh);
270
+
271
+ size_t n1 = n / 16;
272
+
273
+ for (size_t i = 0; i < n1; i++) {
274
+ simd16uint16 v(vals);
275
+ vals += 16;
276
+ simd16uint16 eqmask = (v == thr16);
277
+ simd16uint16 max2 = max_func<C>(v, thr16);
278
+ simd16uint16 gemask = (v == max2);
279
+ uint32_t bits = get_MSBs(uint16_to_uint8_saturate(eqmask, gemask));
280
+ int i_eq = __builtin_popcount(bits & 0x00ff00ff);
281
+ int i_ge = __builtin_popcount(bits) - i_eq;
282
+ n_eq += i_eq;
283
+ n_lt += 16 - i_ge;
284
+ }
285
+
286
+ for(size_t i = n1 * 16; i < n; i++) {
287
+ uint16_t v = *vals++;
288
+ if(C::cmp(thresh, v)) {
289
+ n_lt++;
290
+ } else if(v == thresh) {
291
+ n_eq++;
292
+ }
293
+ }
294
+ }
295
+
296
+
297
+
298
+ /* compress separated values and ids table, keeping all values < thresh and at
299
+ * most n_eq equal values */
300
+ template<class C>
301
+ int simd_compress_array(
302
+ uint16_t *vals, typename C::TI * ids, size_t n, uint16_t thresh, int n_eq
303
+ ) {
304
+ simd16uint16 thr16(thresh);
305
+ simd16uint16 mixmask(0xff00);
306
+
307
+ int wp = 0;
308
+ size_t i0;
309
+
310
+ // loop while there are eqs to collect
311
+ for (i0 = 0; i0 + 15 < n && n_eq > 0; i0 += 16) {
312
+ simd16uint16 v(vals + i0);
313
+ simd16uint16 max2 = max_func<C>(v, thr16);
314
+ simd16uint16 gemask = (v == max2);
315
+ simd16uint16 eqmask = (v == thr16);
316
+ uint32_t bits = get_MSBs(blendv(
317
+ simd32uint8(eqmask), simd32uint8(gemask), simd32uint8(mixmask)));
318
+ bits ^= 0xAAAAAAAA;
319
+ // bit 2*i : eq
320
+ // bit 2*i + 1 : lt
321
+
322
+ while(bits) {
323
+ int j = __builtin_ctz(bits) & (~1);
324
+ bool is_eq = (bits >> j) & 1;
325
+ bool is_lt = (bits >> j) & 2;
326
+ bits &= ~(3 << j);
327
+ j >>= 1;
328
+
329
+ if (is_lt) {
330
+ vals[wp] = vals[i0 + j];
331
+ ids[wp] = ids[i0 + j];
332
+ wp++;
333
+ } else if(is_eq && n_eq > 0) {
334
+ vals[wp] = vals[i0 + j];
335
+ ids[wp] = ids[i0 + j];
336
+ wp++;
337
+ n_eq--;
338
+ }
339
+ }
340
+ }
341
+
342
+ // handle remaining, only striclty lt ones.
343
+ for (; i0 + 15 < n; i0 += 16) {
344
+ simd16uint16 v(vals + i0);
345
+ simd16uint16 max2 = max_func<C>(v, thr16);
346
+ simd16uint16 gemask = (v == max2);
347
+ uint32_t bits = ~get_MSBs(simd32uint8(gemask));
348
+
349
+ while(bits) {
350
+ int j = __builtin_ctz(bits);
351
+ bits &= ~(3 << j);
352
+ j >>= 1;
353
+
354
+ vals[wp] = vals[i0 + j];
355
+ ids[wp] = ids[i0 + j];
356
+ wp++;
357
+ }
358
+ }
359
+
360
+ // end with scalar
361
+ for(int i = (n & ~15); i < n; i++) {
362
+ if (C::cmp(thresh, vals[i])) {
363
+ vals[wp] = vals[i];
364
+ ids[wp] = ids[i];
365
+ wp++;
366
+ } else if (vals[i] == thresh && n_eq > 0) {
367
+ vals[wp] = vals[i];
368
+ ids[wp] = ids[i];
369
+ wp++;
370
+ n_eq--;
371
+ }
372
+ }
373
+ assert(n_eq == 0);
374
+ return wp;
375
+ }
376
+
377
+ // #define MICRO_BENCHMARK
378
+
379
+ static uint64_t get_cy () {
380
+ #ifdef MICRO_BENCHMARK
381
+ uint32_t high, low;
382
+ asm volatile("rdtsc \n\t"
383
+ : "=a" (low),
384
+ "=d" (high));
385
+ return ((uint64_t)high << 32) | (low);
386
+ #else
387
+ return 0;
388
+ #endif
389
+ }
390
+
391
+
392
+
393
+ #define IFV if(false)
394
+
395
+ template<class C>
396
+ uint16_t simd_partition_fuzzy_with_bounds(
397
+ uint16_t *vals, typename C::TI * ids, size_t n,
398
+ size_t q_min, size_t q_max, size_t * q_out,
399
+ uint16_t s0i, uint16_t s1i)
400
+ {
401
+
402
+ if (q_min == 0) {
403
+ if (q_out) {
404
+ *q_out = 0;
405
+ }
406
+ return 0;
407
+ }
408
+ if (q_max >= n) {
409
+ if (q_out) {
410
+ *q_out = q_max;
411
+ }
412
+ return 0xffff;
413
+ }
414
+ if (s0i == s1i) {
415
+ if (q_out) {
416
+ *q_out = q_min;
417
+ }
418
+ return s0i;
419
+ }
420
+ uint64_t t0 = get_cy();
421
+
422
+ // lower bound inclusive, upper exclusive
423
+ size_t s0 = s0i, s1 = s1i + 1;
424
+
425
+ IFV printf("bounds: %ld %ld\n", s0, s1 - 1);
426
+
427
+ int thresh;
428
+ size_t n_eq = 0, n_lt = 0;
429
+ size_t q = 0;
430
+
431
+ for(int it = 0; it < 200; it++) {
432
+ // while(s0 + 1 < s1) {
433
+ thresh = (s0 + s1) / 2;
434
+ count_lt_and_eq<C>(vals, n, thresh, n_lt, n_eq);
435
+
436
+ IFV printf(" [%ld %ld] thresh=%d n_lt=%ld n_eq=%ld, q=%ld:%ld/%ld\n",
437
+ s0, s1, thresh, n_lt, n_eq, q_min, q_max, n);
438
+ if (n_lt <= q_min) {
439
+ if (n_lt + n_eq >= q_min) {
440
+ q = q_min;
441
+ break;
442
+ } else {
443
+ if (C::is_max) {
444
+ s0 = thresh;
445
+ } else {
446
+ s1 = thresh;
447
+ }
448
+ }
449
+ } else if (n_lt <= q_max) {
450
+ q = n_lt;
451
+ break;
452
+ } else {
453
+ if (C::is_max) {
454
+ s1 = thresh;
455
+ } else {
456
+ s0 = thresh;
457
+ }
458
+ }
459
+
460
+ }
461
+
462
+ uint64_t t1 = get_cy();
463
+
464
+ // number of equal values to keep
465
+ int64_t n_eq_1 = q - n_lt;
466
+
467
+ IFV printf("shrink: thresh=%d q=%ld n_eq_1=%ld\n", thresh, q, n_eq_1);
468
+ if (n_eq_1 < 0) { // happens when > q elements are at lower bound
469
+ assert(s0 + 1 == s1);
470
+ q = q_min;
471
+ if (C::is_max) {
472
+ thresh--;
473
+ } else {
474
+ thresh++;
475
+ }
476
+ n_eq_1 = q;
477
+ IFV printf(" override: thresh=%d n_eq_1=%ld\n", thresh, n_eq_1);
478
+ } else {
479
+ assert(n_eq_1 <= n_eq);
480
+ }
481
+
482
+ size_t wp = simd_compress_array<C>(vals, ids, n, thresh, n_eq_1);
483
+
484
+ IFV printf("wp=%ld\n", wp);
485
+ assert(wp == q);
486
+ if (q_out) {
487
+ *q_out = q;
488
+ }
489
+
490
+ uint64_t t2 = get_cy();
491
+
492
+ partition_stats.bissect_cycles += t1 - t0;
493
+ partition_stats.compress_cycles += t2 - t1;
494
+
495
+ return thresh;
496
+ }
497
+
498
+
499
+ template<class C>
500
+ uint16_t simd_partition_fuzzy_with_bounds_histogram(
501
+ uint16_t *vals, typename C::TI * ids, size_t n,
502
+ size_t q_min, size_t q_max, size_t * q_out,
503
+ uint16_t s0i, uint16_t s1i)
504
+ {
505
+
506
+ if (q_min == 0) {
507
+ if (q_out) {
508
+ *q_out = 0;
509
+ }
510
+ return 0;
511
+ }
512
+ if (q_max >= n) {
513
+ if (q_out) {
514
+ *q_out = q_max;
515
+ }
516
+ return 0xffff;
517
+ }
518
+ if (s0i == s1i) {
519
+ if (q_out) {
520
+ *q_out = q_min;
521
+ }
522
+ return s0i;
523
+ }
524
+
525
+ IFV printf("partition fuzzy, q=%ld:%ld / %ld, bounds=%d %d\n",
526
+ q_min, q_max, n, s0i, s1i);
527
+
528
+ if (!C::is_max) {
529
+ IFV printf("revert due to CMin, q_min:q_max -> %ld:%ld\n", q_min, q_max);
530
+ q_min = n - q_min;
531
+ q_max = n - q_max;
532
+ }
533
+
534
+ // lower and upper bound of range, inclusive
535
+ int s0 = s0i, s1 = s1i;
536
+ // number of values < s0 and > s1
537
+ size_t n_lt = 0, n_gt = 0;
538
+
539
+ // output of loop:
540
+ int thresh; // final threshold
541
+ uint64_t tot_eq = 0; // total nb of equal values
542
+ uint64_t n_eq = 0; // nb of equal values to keep
543
+ size_t q; // final quantile
544
+
545
+ // buffer for the histograms
546
+ int hist[16];
547
+
548
+ for(int it = 0; it < 20; it++) {
549
+ // otherwise we would be done already
550
+
551
+ int shift = 0;
552
+
553
+ IFV printf(" it %d bounds: %d %d n_lt=%ld n_gt=%ld\n",
554
+ it, s0, s1, n_lt, n_gt);
555
+
556
+ int maxval = s1 - s0;
557
+
558
+ while(maxval > 15) {
559
+ shift++;
560
+ maxval >>= 1;
561
+ }
562
+
563
+ IFV printf(" histogram shift %d maxval %d ?= %d\n",
564
+ shift, maxval, int((s1 - s0) >> shift));
565
+
566
+ if (maxval > 7) {
567
+ simd_histogram_16(vals, n, s0, shift, hist);
568
+ } else {
569
+ simd_histogram_8(vals, n, s0, shift, hist);
570
+ }
571
+ IFV {
572
+ int sum = n_lt + n_gt;
573
+ printf(" n_lt=%ld hist=[", n_lt);
574
+ for(int i = 0; i <= maxval; i++) {
575
+ printf("%d ", hist[i]);
576
+ sum += hist[i];
577
+ }
578
+ printf("] n_gt=%ld sum=%d\n", n_gt, sum);
579
+ assert(sum == n);
580
+ }
581
+
582
+ size_t sum_below = n_lt;
583
+ int i;
584
+ for (i = 0; i <= maxval; i++) {
585
+ sum_below += hist[i];
586
+ if (sum_below >= q_min) {
587
+ break;
588
+ }
589
+ }
590
+ IFV printf(" i=%d sum_below=%ld\n", i, sum_below);
591
+ if (i <= maxval) {
592
+ s0 = s0 + (i << shift);
593
+ s1 = s0 + (1 << shift) - 1;
594
+ n_lt = sum_below - hist[i];
595
+ n_gt = n - sum_below;
596
+ } else {
597
+ assert(!"not implemented");
598
+ }
599
+
600
+ IFV printf(" new bin: s0=%d s1=%d n_lt=%ld n_gt=%ld\n", s0, s1, n_lt, n_gt);
601
+
602
+ if (s1 > s0) {
603
+ if (n_lt >= q_min && q_max >= n_lt) {
604
+ IFV printf(" FOUND1\n");
605
+ thresh = s0;
606
+ q = n_lt;
607
+ break;
608
+ }
609
+
610
+ size_t n_lt_2 = n - n_gt;
611
+ if (n_lt_2 >= q_min && q_max >= n_lt_2) {
612
+ thresh = s1 + 1;
613
+ q = n_lt_2;
614
+ IFV printf(" FOUND2\n");
615
+ break;
616
+ }
617
+ } else {
618
+ thresh = s0;
619
+ q = q_min;
620
+ tot_eq = n - n_gt - n_lt;
621
+ n_eq = q_min - n_lt;
622
+ IFV printf(" FOUND3\n");
623
+ break;
624
+ }
625
+ }
626
+
627
+ IFV printf("end bissection: thresh=%d q=%ld n_eq=%ld\n", thresh, q, n_eq);
628
+
629
+ if (!C::is_max) {
630
+ if (n_eq == 0) {
631
+ thresh --;
632
+ } else {
633
+ // thresh unchanged
634
+ n_eq = tot_eq - n_eq;
635
+ }
636
+ q = n - q;
637
+ IFV printf("revert due to CMin, q->%ld n_eq->%ld\n", q, n_eq);
638
+ }
639
+
640
+ size_t wp = simd_compress_array<C>(vals, ids, n, thresh, n_eq);
641
+ IFV printf("wp=%ld ?= %ld\n", wp, q);
642
+ assert(wp == q);
643
+ if (q_out) {
644
+ *q_out = wp;
645
+ }
646
+
647
+ return thresh;
648
+ }
649
+
650
+
651
+
652
+ template<class C>
653
+ uint16_t simd_partition_fuzzy(
654
+ uint16_t *vals, typename C::TI * ids, size_t n,
655
+ size_t q_min, size_t q_max, size_t * q_out
656
+ ) {
657
+
658
+ assert(is_aligned_pointer(vals));
659
+
660
+ uint16_t s0i, s1i;
661
+ find_minimax(vals, n, s0i, s1i);
662
+ // QSelect_stats.t0 += get_cy() - t0;
663
+
664
+ return simd_partition_fuzzy_with_bounds<C>(
665
+ vals, ids, n, q_min, q_max, q_out, s0i, s1i);
666
+ }
667
+
668
+
669
+
670
+ template<class C>
671
+ uint16_t simd_partition(uint16_t *vals, typename C::TI * ids, size_t n, size_t q) {
672
+
673
+ assert(is_aligned_pointer(vals));
674
+
675
+ if (q == 0) {
676
+ return 0;
677
+ }
678
+ if (q >= n) {
679
+ return 0xffff;
680
+ }
681
+
682
+ uint16_t s0i, s1i;
683
+ find_minimax(vals, n, s0i, s1i);
684
+
685
+ return simd_partition_fuzzy_with_bounds<C>(
686
+ vals, ids, n, q, q, nullptr, s0i, s1i);
687
+ }
688
+
689
+ template<class C>
690
+ uint16_t simd_partition_with_bounds(
691
+ uint16_t *vals, typename C::TI * ids, size_t n, size_t q,
692
+ uint16_t s0i, uint16_t s1i)
693
+ {
694
+ return simd_partition_fuzzy_with_bounds<C>(
695
+ vals, ids, n, q, q, nullptr, s0i, s1i);
696
+ }
697
+
698
+ } // namespace simd_partitioning
699
+
700
+
701
+ /******************************************************************
702
+ * Driver routine
703
+ ******************************************************************/
704
+
705
+
706
+ template<class C>
707
+ typename C::T partition_fuzzy(
708
+ typename C::T *vals, typename C::TI * ids, size_t n,
709
+ size_t q_min, size_t q_max, size_t * q_out)
710
+ {
711
+ // the code below compiles and runs without AVX2 but it's slower than
712
+ // the scalar implementation
713
+ #ifdef __AVX2__
714
+ constexpr bool is_uint16 = std::is_same<typename C::T, uint16_t>::value;
715
+ if (is_uint16 && is_aligned_pointer(vals)) {
716
+ return simd_partitioning::simd_partition_fuzzy<C>(
717
+ (uint16_t*)vals, ids, n, q_min, q_max, q_out);
718
+ }
719
+ #endif
720
+ return partitioning::partition_fuzzy_median3<C>(
721
+ vals, ids, n, q_min, q_max, q_out);
722
+ }
723
+
724
+
725
+ // explicit template instanciations
726
+
727
+ template float partition_fuzzy<CMin<float, int64_t>> (
728
+ float *vals, int64_t * ids, size_t n,
729
+ size_t q_min, size_t q_max, size_t * q_out);
730
+
731
+ template float partition_fuzzy<CMax<float, int64_t>> (
732
+ float *vals, int64_t * ids, size_t n,
733
+ size_t q_min, size_t q_max, size_t * q_out);
734
+
735
+ template uint16_t partition_fuzzy<CMin<uint16_t, int64_t>> (
736
+ uint16_t *vals, int64_t * ids, size_t n,
737
+ size_t q_min, size_t q_max, size_t * q_out);
738
+
739
+ template uint16_t partition_fuzzy<CMax<uint16_t, int64_t>> (
740
+ uint16_t *vals, int64_t * ids, size_t n,
741
+ size_t q_min, size_t q_max, size_t * q_out);
742
+
743
+ template uint16_t partition_fuzzy<CMin<uint16_t, int>> (
744
+ uint16_t *vals, int * ids, size_t n,
745
+ size_t q_min, size_t q_max, size_t * q_out);
746
+
747
+ template uint16_t partition_fuzzy<CMax<uint16_t, int>> (
748
+ uint16_t *vals, int * ids, size_t n,
749
+ size_t q_min, size_t q_max, size_t * q_out);
750
+
751
+
752
+
753
+ /******************************************************************
754
+ * Histogram subroutines
755
+ ******************************************************************/
756
+
757
+ #ifdef __AVX2__
758
+ /// FIXME when MSB of uint16 is set
759
+ // this code does not compile properly with GCC 7.4.0
760
+
761
+ namespace {
762
+
763
+ /************************************************************
764
+ * 8 bins
765
+ ************************************************************/
766
+
767
+ simd32uint8 accu4to8(simd16uint16 a4) {
768
+ simd16uint16 mask4(0x0f0f);
769
+
770
+ simd16uint16 a8_0 = a4 & mask4;
771
+ simd16uint16 a8_1 = (a4 >> 4) & mask4;
772
+
773
+ return simd32uint8(_mm256_hadd_epi16(a8_0.i, a8_1.i));
774
+ }
775
+
776
+
777
+ simd16uint16 accu8to16(simd32uint8 a8) {
778
+ simd16uint16 mask8(0x00ff);
779
+
780
+ simd16uint16 a8_0 = simd16uint16(a8) & mask8;
781
+ simd16uint16 a8_1 = (simd16uint16(a8) >> 8) & mask8;
782
+
783
+ return simd16uint16(_mm256_hadd_epi16(a8_0.i, a8_1.i));
784
+ }
785
+
786
+
787
+ static const simd32uint8 shifts(_mm256_setr_epi8(
788
+ 1, 16, 0, 0, 4, 64, 0, 0,
789
+ 0, 0, 1, 16, 0, 0, 4, 64,
790
+ 1, 16, 0, 0, 4, 64, 0, 0,
791
+ 0, 0, 1, 16, 0, 0, 4, 64
792
+ ));
793
+
794
+ // 2-bit accumulator: we can add only up to 3 elements
795
+ // on output we return 2*4-bit results
796
+ // preproc returns either an index in 0..7 or 0xffff
797
+ // that yeilds a 0 when used in the table look-up
798
+ template<int N, class Preproc>
799
+ void compute_accu2(
800
+ const uint16_t * & data,
801
+ Preproc & pp,
802
+ simd16uint16 & a4lo, simd16uint16 & a4hi
803
+ ) {
804
+ simd16uint16 mask2(0x3333);
805
+ simd16uint16 a2((uint16_t)0); // 2-bit accu
806
+ for (int j = 0; j < N; j ++) {
807
+ simd16uint16 v(data);
808
+ data += 16;
809
+ v = pp(v);
810
+ // 0x800 -> force second half of table
811
+ simd16uint16 idx = v | (v << 8) | simd16uint16(0x800);
812
+ a2 += simd16uint16(shifts.lookup_2_lanes(simd32uint8(idx)));
813
+ }
814
+ a4lo += a2 & mask2;
815
+ a4hi += (a2 >> 2) & mask2;
816
+ }
817
+
818
+
819
+ template<class Preproc>
820
+ simd16uint16 histogram_8(
821
+ const uint16_t * data, Preproc pp,
822
+ size_t n_in) {
823
+
824
+ assert (n_in % 16 == 0);
825
+ int n = n_in / 16;
826
+
827
+ simd32uint8 a8lo(0);
828
+ simd32uint8 a8hi(0);
829
+
830
+ for(int i0 = 0; i0 < n; i0 += 15) {
831
+ simd16uint16 a4lo(0); // 4-bit accus
832
+ simd16uint16 a4hi(0);
833
+
834
+ int i1 = std::min(i0 + 15, n);
835
+ int i;
836
+ for(i = i0; i + 2 < i1; i += 3) {
837
+ compute_accu2<3>(data, pp, a4lo, a4hi); // adds 3 max
838
+ }
839
+ switch (i1 - i) {
840
+ case 2:
841
+ compute_accu2<2>(data, pp, a4lo, a4hi);
842
+ break;
843
+ case 1:
844
+ compute_accu2<1>(data, pp, a4lo, a4hi);
845
+ break;
846
+ }
847
+
848
+ a8lo += accu4to8(a4lo);
849
+ a8hi += accu4to8(a4hi);
850
+ }
851
+
852
+ // move to 16-bit accu
853
+ simd16uint16 a16lo = accu8to16(a8lo);
854
+ simd16uint16 a16hi = accu8to16(a8hi);
855
+
856
+ simd16uint16 a16 = simd16uint16(_mm256_hadd_epi16(a16lo.i, a16hi.i));
857
+
858
+ // the 2 lanes must still be combined
859
+ return a16;
860
+ }
861
+
862
+
863
+ /************************************************************
864
+ * 16 bins
865
+ ************************************************************/
866
+
867
+
868
+
869
+ static const simd32uint8 shifts2(_mm256_setr_epi8(
870
+ 1, 2, 4, 8, 16, 32, 64, (char)128,
871
+ 1, 2, 4, 8, 16, 32, 64, (char)128,
872
+ 1, 2, 4, 8, 16, 32, 64, (char)128,
873
+ 1, 2, 4, 8, 16, 32, 64, (char)128
874
+ ));
875
+
876
+
877
+ simd32uint8 shiftr_16(simd32uint8 x, int n)
878
+ {
879
+ return simd32uint8(simd16uint16(x) >> n);
880
+ }
881
+
882
+
883
+ inline simd32uint8 combine_2x2(simd32uint8 a, simd32uint8 b) {
884
+
885
+ __m256i a1b0 = _mm256_permute2f128_si256(a.i, b.i, 0x21);
886
+ __m256i a0b1 = _mm256_blend_epi32(a.i, b.i, 0xF0);
887
+
888
+ return simd32uint8(a1b0) + simd32uint8(a0b1);
889
+ }
890
+
891
+
892
+ // 2-bit accumulator: we can add only up to 3 elements
893
+ // on output we return 2*4-bit results
894
+ template<int N, class Preproc>
895
+ void compute_accu2_16(
896
+ const uint16_t * & data, Preproc pp,
897
+ simd32uint8 & a4_0, simd32uint8 & a4_1,
898
+ simd32uint8 & a4_2, simd32uint8 & a4_3
899
+ ) {
900
+ simd32uint8 mask1(0x55);
901
+ simd32uint8 a2_0; // 2-bit accu
902
+ simd32uint8 a2_1; // 2-bit accu
903
+ a2_0.clear(); a2_1.clear();
904
+
905
+ for (int j = 0; j < N; j ++) {
906
+ simd16uint16 v(data);
907
+ data += 16;
908
+ v = pp(v);
909
+
910
+ simd16uint16 idx = v | (v << 8);
911
+ simd32uint8 a1 = shifts2.lookup_2_lanes(simd32uint8(idx));
912
+ // contains 0s for out-of-bounds elements
913
+
914
+ simd16uint16 lt8 = (v >> 3) == simd16uint16(0);
915
+ lt8.i = _mm256_xor_si256(lt8.i, _mm256_set1_epi16(0xff00));
916
+
917
+ a1 = a1 & lt8;
918
+
919
+ a2_0 += a1 & mask1;
920
+ a2_1 += shiftr_16(a1, 1) & mask1;
921
+ }
922
+ simd32uint8 mask2(0x33);
923
+
924
+ a4_0 += a2_0 & mask2;
925
+ a4_1 += a2_1 & mask2;
926
+ a4_2 += shiftr_16(a2_0, 2) & mask2;
927
+ a4_3 += shiftr_16(a2_1, 2) & mask2;
928
+
929
+ }
930
+
931
+
932
+ simd32uint8 accu4to8_2(simd32uint8 a4_0, simd32uint8 a4_1) {
933
+ simd32uint8 mask4(0x0f);
934
+
935
+ simd32uint8 a8_0 = combine_2x2(
936
+ a4_0 & mask4,
937
+ shiftr_16(a4_0, 4) & mask4
938
+ );
939
+
940
+ simd32uint8 a8_1 = combine_2x2(
941
+ a4_1 & mask4,
942
+ shiftr_16(a4_1, 4) & mask4
943
+ );
944
+
945
+ return simd32uint8(_mm256_hadd_epi16(a8_0.i, a8_1.i));
946
+ }
947
+
948
+
949
+
950
+ template<class Preproc>
951
+ simd16uint16 histogram_16(const uint16_t * data, Preproc pp, size_t n_in) {
952
+
953
+ assert (n_in % 16 == 0);
954
+ int n = n_in / 16;
955
+
956
+ simd32uint8 a8lo((uint8_t)0);
957
+ simd32uint8 a8hi((uint8_t)0);
958
+
959
+ for(int i0 = 0; i0 < n; i0 += 7) {
960
+ simd32uint8 a4_0(0); // 0, 4, 8, 12
961
+ simd32uint8 a4_1(0); // 1, 5, 9, 13
962
+ simd32uint8 a4_2(0); // 2, 6, 10, 14
963
+ simd32uint8 a4_3(0); // 3, 7, 11, 15
964
+
965
+ int i1 = std::min(i0 + 7, n);
966
+ int i;
967
+ for(i = i0; i + 2 < i1; i += 3) {
968
+ compute_accu2_16<3>(data, pp, a4_0, a4_1, a4_2, a4_3);
969
+ }
970
+ switch (i1 - i) {
971
+ case 2:
972
+ compute_accu2_16<2>(data, pp, a4_0, a4_1, a4_2, a4_3);
973
+ break;
974
+ case 1:
975
+ compute_accu2_16<1>(data, pp, a4_0, a4_1, a4_2, a4_3);
976
+ break;
977
+ }
978
+
979
+ a8lo += accu4to8_2(a4_0, a4_1);
980
+ a8hi += accu4to8_2(a4_2, a4_3);
981
+ }
982
+
983
+ // move to 16-bit accu
984
+ simd16uint16 a16lo = accu8to16(a8lo);
985
+ simd16uint16 a16hi = accu8to16(a8hi);
986
+
987
+ simd16uint16 a16 = simd16uint16(_mm256_hadd_epi16(a16lo.i, a16hi.i));
988
+
989
+ __m256i perm32 = _mm256_setr_epi32(
990
+ 0, 2, 4, 6, 1, 3, 5, 7
991
+ );
992
+ a16.i = _mm256_permutevar8x32_epi32(a16.i, perm32);
993
+
994
+ return a16;
995
+ }
996
+
997
+ struct PreprocNOP {
998
+ simd16uint16 operator () (simd16uint16 x) {
999
+ return x;
1000
+ }
1001
+
1002
+ };
1003
+
1004
+
1005
+ template<int shift, int nbin>
1006
+ struct PreprocMinShift {
1007
+ simd16uint16 min16;
1008
+ simd16uint16 max16;
1009
+
1010
+ explicit PreprocMinShift(uint16_t min) {
1011
+ min16.set1(min);
1012
+ int vmax0 = std::min((nbin << shift) + min, 65536);
1013
+ uint16_t vmax = uint16_t(vmax0 - 1 - min);
1014
+ max16.set1(vmax); // vmax inclusive
1015
+ }
1016
+
1017
+ simd16uint16 operator () (simd16uint16 x) {
1018
+ x = x - min16;
1019
+ simd16uint16 mask = (x == max(x, max16)) - (x == max16);
1020
+ return (x >> shift) | mask;
1021
+ }
1022
+
1023
+ };
1024
+
1025
+ /* unbounded versions of the functions */
1026
+
1027
+ void simd_histogram_8_unbounded(
1028
+ const uint16_t *data, int n,
1029
+ int *hist)
1030
+ {
1031
+ PreprocNOP pp;
1032
+ simd16uint16 a16 = histogram_8(data, pp, (n & ~15));
1033
+
1034
+ ALIGNED(32) uint16_t a16_tab[16];
1035
+ a16.store(a16_tab);
1036
+
1037
+ for(int i = 0; i < 8; i++) {
1038
+ hist[i] = a16_tab[i] + a16_tab[i + 8];
1039
+ }
1040
+
1041
+ for(int i = (n & ~15); i < n; i++) {
1042
+ hist[data[i]]++;
1043
+ }
1044
+
1045
+ }
1046
+
1047
+
1048
+ void simd_histogram_16_unbounded(
1049
+ const uint16_t *data, int n,
1050
+ int *hist)
1051
+ {
1052
+
1053
+ simd16uint16 a16 = histogram_16(data, PreprocNOP(), (n & ~15));
1054
+
1055
+ ALIGNED(32) uint16_t a16_tab[16];
1056
+ a16.store(a16_tab);
1057
+
1058
+ for(int i = 0; i < 16; i++) {
1059
+ hist[i] = a16_tab[i];
1060
+ }
1061
+
1062
+ for(int i = (n & ~15); i < n; i++) {
1063
+ hist[data[i]]++;
1064
+ }
1065
+
1066
+ }
1067
+
1068
+
1069
+
1070
+ } // anonymous namespace
1071
+
1072
+ /************************************************************
1073
+ * Driver routines
1074
+ ************************************************************/
1075
+
1076
+ void simd_histogram_8(
1077
+ const uint16_t *data, int n,
1078
+ uint16_t min, int shift,
1079
+ int *hist)
1080
+ {
1081
+ if (shift < 0) {
1082
+ simd_histogram_8_unbounded(data, n, hist);
1083
+ return;
1084
+ }
1085
+
1086
+ simd16uint16 a16;
1087
+
1088
+ #define DISPATCH(s) \
1089
+ case s: \
1090
+ a16 = histogram_8(data, PreprocMinShift<s, 8>(min), (n & ~15)); \
1091
+ break
1092
+
1093
+ switch(shift) {
1094
+ DISPATCH(0);
1095
+ DISPATCH(1);
1096
+ DISPATCH(2);
1097
+ DISPATCH(3);
1098
+ DISPATCH(4);
1099
+ DISPATCH(5);
1100
+ DISPATCH(6);
1101
+ DISPATCH(7);
1102
+ DISPATCH(8);
1103
+ DISPATCH(9);
1104
+ DISPATCH(10);
1105
+ DISPATCH(11);
1106
+ DISPATCH(12);
1107
+ DISPATCH(13);
1108
+ default:
1109
+ FAISS_THROW_FMT("dispatch for shift=%d not instantiated", shift);
1110
+ }
1111
+ #undef DISPATCH
1112
+
1113
+ ALIGNED(32) uint16_t a16_tab[16];
1114
+ a16.store(a16_tab);
1115
+
1116
+ for(int i = 0; i < 8; i++) {
1117
+ hist[i] = a16_tab[i] + a16_tab[i + 8];
1118
+ }
1119
+
1120
+ // complete with remaining bins
1121
+ for(int i = (n & ~15); i < n; i++) {
1122
+ if (data[i] < min) continue;
1123
+ uint16_t v = data[i] - min;
1124
+ v >>= shift;
1125
+ if (v < 8) hist[v]++;
1126
+ }
1127
+
1128
+ }
1129
+
1130
+
1131
+
1132
+ void simd_histogram_16(
1133
+ const uint16_t *data, int n,
1134
+ uint16_t min, int shift,
1135
+ int *hist)
1136
+ {
1137
+ if (shift < 0) {
1138
+ simd_histogram_16_unbounded(data, n, hist);
1139
+ return;
1140
+ }
1141
+
1142
+ simd16uint16 a16;
1143
+
1144
+ #define DISPATCH(s) \
1145
+ case s: \
1146
+ a16 = histogram_16(data, PreprocMinShift<s, 16>(min), (n & ~15)); \
1147
+ break
1148
+
1149
+ switch(shift) {
1150
+ DISPATCH(0);
1151
+ DISPATCH(1);
1152
+ DISPATCH(2);
1153
+ DISPATCH(3);
1154
+ DISPATCH(4);
1155
+ DISPATCH(5);
1156
+ DISPATCH(6);
1157
+ DISPATCH(7);
1158
+ DISPATCH(8);
1159
+ DISPATCH(9);
1160
+ DISPATCH(10);
1161
+ DISPATCH(11);
1162
+ DISPATCH(12);
1163
+ default:
1164
+ FAISS_THROW_FMT("dispatch for shift=%d not instantiated", shift);
1165
+ }
1166
+ #undef DISPATCH
1167
+
1168
+ ALIGNED(32) uint16_t a16_tab[16];
1169
+ a16.store(a16_tab);
1170
+
1171
+ for(int i = 0; i < 16; i++) {
1172
+ hist[i] = a16_tab[i];
1173
+ }
1174
+
1175
+ for(int i = (n & ~15); i < n; i++) {
1176
+ if (data[i] < min) continue;
1177
+ uint16_t v = data[i] - min;
1178
+ v >>= shift;
1179
+ if (v < 16) hist[v]++;
1180
+ }
1181
+
1182
+ }
1183
+
1184
+
1185
+ // no AVX2
1186
+ #else
1187
+
1188
+
1189
+
1190
+ void simd_histogram_16(
1191
+ const uint16_t *data, int n,
1192
+ uint16_t min, int shift,
1193
+ int *hist)
1194
+ {
1195
+ memset(hist, 0, sizeof(*hist) * 16);
1196
+ if (shift < 0) {
1197
+ for(size_t i = 0; i < n; i++) {
1198
+ hist[data[i]]++;
1199
+ }
1200
+ } else {
1201
+ int vmax0 = std::min((16 << shift) + min, 65536);
1202
+ uint16_t vmax = uint16_t(vmax0 - 1 - min);
1203
+
1204
+ for(size_t i = 0; i < n; i++) {
1205
+ uint16_t v = data[i];
1206
+ v -= min;
1207
+ if (!(v <= vmax))
1208
+ continue;
1209
+ v >>= shift;
1210
+ hist[v]++;
1211
+
1212
+ /*
1213
+ if (data[i] < min) continue;
1214
+ uint16_t v = data[i] - min;
1215
+ v >>= shift;
1216
+ if (v < 16) hist[v]++;
1217
+ */
1218
+ }
1219
+ }
1220
+
1221
+ }
1222
+
1223
+ void simd_histogram_8(
1224
+ const uint16_t *data, int n,
1225
+ uint16_t min, int shift,
1226
+ int *hist)
1227
+ {
1228
+ memset(hist, 0, sizeof(*hist) * 8);
1229
+ if (shift < 0) {
1230
+ for(size_t i = 0; i < n; i++) {
1231
+ hist[data[i]]++;
1232
+ }
1233
+ } else {
1234
+ for(size_t i = 0; i < n; i++) {
1235
+ if (data[i] < min) continue;
1236
+ uint16_t v = data[i] - min;
1237
+ v >>= shift;
1238
+ if (v < 8) hist[v]++;
1239
+ }
1240
+ }
1241
+
1242
+ }
1243
+
1244
+
1245
+ #endif
1246
+
1247
+
1248
+ void PartitionStats::reset() {
1249
+ memset(this, 0, sizeof(*this));
1250
+ }
1251
+
1252
+ PartitionStats partition_stats;
1253
+
1254
+
1255
+
1256
+ } // namespace faiss
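partition_fuzzy is the public entry point of this new module: it compacts vals / ids in place so that the first q elements (with q_min <= q <= q_max) are the best ones under the comparator C, returns the threshold it settled on, and reports q through q_out. Below is a hedged usage sketch against the CMax<float, int64_t> instantiation listed above; it assumes the function is declared in the companion header faiss/utils/partitioning.h added in the same release.

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    #include <faiss/utils/ordered_key_value.h>
    #include <faiss/utils/partitioning.h>  // assumed to declare partition_fuzzy

    int main() {
        std::vector<float> vals = {0.9f, 0.1f, 0.5f, 0.7f, 0.3f, 0.8f};
        std::vector<int64_t> ids = {0, 1, 2, 3, 4, 5};

        // Keep somewhere between 2 and 3 of the smallest distances (CMax keeps
        // the values that compare better than the returned threshold).
        size_t q_out = 0;
        float thresh = faiss::partition_fuzzy<faiss::CMax<float, int64_t>>(
            vals.data(), ids.data(), vals.size(),
            /*q_min=*/2, /*q_max=*/3, &q_out);

        std::printf("kept %zu values, threshold %g\n", q_out, thresh);
        for (size_t i = 0; i < q_out; i++) {
            std::printf("  id=%d val=%g\n", int(ids[i]), vals[i]);
        }
        return 0;
    }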