faiss 0.1.3 → 0.1.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +1 -1
  6. data/lib/faiss/version.rb +1 -1
  7. data/vendor/faiss/faiss/AutoTune.cpp +36 -33
  8. data/vendor/faiss/faiss/AutoTune.h +6 -3
  9. data/vendor/faiss/faiss/Clustering.cpp +16 -12
  10. data/vendor/faiss/faiss/Index.cpp +3 -4
  11. data/vendor/faiss/faiss/Index.h +3 -3
  12. data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
  13. data/vendor/faiss/faiss/IndexBinary.h +1 -1
  14. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
  15. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
  16. data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
  17. data/vendor/faiss/faiss/IndexFlat.h +0 -51
  18. data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
  19. data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
  20. data/vendor/faiss/faiss/IndexIVF.h +22 -15
  21. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
  22. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
  23. data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
  24. data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
  25. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
  26. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
  27. data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
  28. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
  29. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
  30. data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
  31. data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
  32. data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
  33. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
  34. data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
  35. data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
  36. data/vendor/faiss/faiss/IndexRefine.h +73 -0
  37. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
  38. data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
  39. data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
  40. data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
  41. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
  42. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
  43. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
  44. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
  45. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
  46. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
  47. data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
  48. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
  49. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
  50. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
  51. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
  52. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
  53. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
  54. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
  55. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
  56. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
  57. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
  58. data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
  59. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
  60. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
  61. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
  62. data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
  63. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
  64. data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
  65. data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
  66. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
  67. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
  68. data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
  69. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
  70. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
  71. data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
  72. data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
  73. data/vendor/faiss/faiss/impl/io.cpp +33 -2
  74. data/vendor/faiss/faiss/impl/io.h +7 -2
  75. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
  76. data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
  77. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
  78. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
  79. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
  80. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
  81. data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
  82. data/vendor/faiss/faiss/index_factory.cpp +112 -7
  83. data/vendor/faiss/faiss/index_io.h +1 -48
  84. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
  85. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
  86. data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
  87. data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
  88. data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
  89. data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
  90. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
  91. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
  92. data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
  93. data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
  94. data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
  95. data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
  96. data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
  97. data/vendor/faiss/faiss/utils/Heap.h +61 -50
  98. data/vendor/faiss/faiss/utils/distances.cpp +164 -319
  99. data/vendor/faiss/faiss/utils/distances.h +28 -20
  100. data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
  101. data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
  102. data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
  103. data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
  104. data/vendor/faiss/faiss/utils/hamming.h +2 -7
  105. data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
  106. data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
  107. data/vendor/faiss/faiss/utils/partitioning.h +69 -0
  108. data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
  109. data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
  110. data/vendor/faiss/faiss/utils/simdlib.h +31 -0
  111. data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
  112. data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
  113. metadata +43 -141
  114. data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
  115. data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
  116. data/vendor/faiss/c_api/AutoTune_c.h +0 -66
  117. data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
  118. data/vendor/faiss/c_api/Clustering_c.h +0 -123
  119. data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
  120. data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
  121. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
  122. data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
  123. data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
  124. data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
  125. data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
  126. data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
  127. data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
  128. data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
  129. data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
  130. data/vendor/faiss/c_api/IndexShards_c.h +0 -39
  131. data/vendor/faiss/c_api/Index_c.cpp +0 -105
  132. data/vendor/faiss/c_api/Index_c.h +0 -183
  133. data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
  134. data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
  135. data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
  136. data/vendor/faiss/c_api/clone_index_c.h +0 -32
  137. data/vendor/faiss/c_api/error_c.h +0 -42
  138. data/vendor/faiss/c_api/error_impl.cpp +0 -27
  139. data/vendor/faiss/c_api/error_impl.h +0 -16
  140. data/vendor/faiss/c_api/faiss_c.h +0 -58
  141. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
  142. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
  143. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
  144. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
  145. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
  146. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
  147. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
  148. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
  149. data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
  150. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
  151. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
  152. data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
  153. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
  154. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
  155. data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
  156. data/vendor/faiss/c_api/index_factory_c.h +0 -30
  157. data/vendor/faiss/c_api/index_io_c.cpp +0 -42
  158. data/vendor/faiss/c_api/index_io_c.h +0 -50
  159. data/vendor/faiss/c_api/macros_impl.h +0 -110
  160. data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
  161. data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
  162. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
  163. data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
  164. data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
  165. data/vendor/faiss/misc/test_blas.cpp +0 -87
  166. data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
  167. data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
  168. data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
  169. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
  170. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
  171. data/vendor/faiss/tests/test_merge.cpp +0 -260
  172. data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
  173. data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
  174. data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
  175. data/vendor/faiss/tests/test_params_override.cpp +0 -236
  176. data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
  177. data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
  178. data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
  179. data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
  180. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
  181. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
  182. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
  183. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
  184. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -176,8 +176,7 @@ void knn_extra_metrics_template (
176
176
  float disij = vd (x_i, y_j);
177
177
 
178
178
  if (disij < simi[0]) {
179
- maxheap_pop (k, simi, idxi);
180
- maxheap_push (k, simi, idxi, disij, j);
179
+ maxheap_replace_top (k, simi, idxi, disij, j);
181
180
  }
182
181
  y_j += d;
183
182
  }
@@ -8,7 +8,7 @@
8
8
  namespace faiss {
9
9
 
10
10
 
11
- inline BitstringWriter::BitstringWriter(uint8_t *code, int code_size):
11
+ inline BitstringWriter::BitstringWriter(uint8_t *code, size_t code_size):
12
12
  code (code), code_size (code_size), i(0)
13
13
  {
14
14
  memset (code, 0, code_size);
@@ -24,7 +24,7 @@ inline void BitstringWriter::write(uint64_t x, int nbit) {
24
24
  i += nbit;
25
25
  return;
26
26
  } else {
27
- int j = i >> 3;
27
+ size_t j = i >> 3;
28
28
  code[j++] |= x << (i & 7);
29
29
  i += nbit;
30
30
  x >>= na;
@@ -36,7 +36,7 @@ inline void BitstringWriter::write(uint64_t x, int nbit) {
36
36
  }
37
37
 
38
38
 
39
- inline BitstringReader::BitstringReader(const uint8_t *code, int code_size):
39
+ inline BitstringReader::BitstringReader(const uint8_t *code, size_t code_size):
40
40
  code (code), code_size (code_size), i(0)
41
41
  {}
42
42
 
@@ -52,7 +52,7 @@ inline uint64_t BitstringReader::read(int nbit) {
52
52
  return res;
53
53
  } else {
54
54
  int ofs = na;
55
- int j = (i >> 3) + 1;
55
+ size_t j = (i >> 3) + 1;
56
56
  i += nbit;
57
57
  nbit -= na;
58
58
  while (nbit > 8) {
@@ -292,8 +292,7 @@ void hammings_knn_hc (
292
292
  for (j = j0; j < j1; j++, bs2_+= bytes_per_code) {
293
293
  dis = hc.hamming (bs2_);
294
294
  if (dis < bh_val_[0]) {
295
- faiss::maxheap_pop<hamdis_t> (k, bh_val_, bh_ids_);
296
- faiss::maxheap_push<hamdis_t> (k, bh_val_, bh_ids_, dis, j);
295
+ faiss::maxheap_replace_top<hamdis_t> (k, bh_val_, bh_ids_, dis, j);
297
296
  }
298
297
  }
299
298
  }
@@ -391,8 +390,7 @@ void hammings_knn_hc_1 (
391
390
  for (j = 0; j < n2; j++, bs2_+= nwords) {
392
391
  dis = popcount64 (bs1_ ^ *bs2_);
393
392
  if (dis < bh_val_0) {
394
- faiss::maxheap_pop<hamdis_t> (k, bh_val_, bh_ids_);
395
- faiss::maxheap_push<hamdis_t> (k, bh_val_, bh_ids_, dis, j);
393
+ faiss::maxheap_replace_top<hamdis_t> (k, bh_val_, bh_ids_, dis, j);
396
394
  bh_val_0 = bh_val_[0];
397
395
  }
398
396
  }
@@ -818,8 +816,7 @@ static void hamming_dis_inner_loop (
818
816
  int ndiff = hc.hamming (cb);
819
817
  cb += code_size;
820
818
  if (ndiff < bh_val_[0]) {
821
- maxheap_pop<hamdis_t> (k, bh_val_, bh_ids_);
822
- maxheap_push<hamdis_t> (k, bh_val_, bh_ids_, ndiff, j);
819
+ maxheap_replace_top<hamdis_t> (k, bh_val_, bh_ids_, ndiff, j);
823
820
  }
824
821
  }
825
822
  }
@@ -27,11 +27,6 @@
27
27
 
28
28
  #include <stdint.h>
29
29
 
30
- #ifdef _MSC_VER
31
- #include <intrin.h>
32
- #define __builtin_popcountl __popcnt64
33
- #endif // _MSC_VER
34
-
35
30
  #include <faiss/impl/platform_macros.h>
36
31
  #include <faiss/utils/Heap.h>
37
32
 
@@ -91,7 +86,7 @@ struct BitstringWriter {
91
86
  size_t i; // current bit offset
92
87
 
93
88
  // code_size in bytes
94
- BitstringWriter(uint8_t *code, int code_size);
89
+ BitstringWriter(uint8_t *code, size_t code_size);
95
90
 
96
91
  // write the nbit low bits of x
97
92
  void write(uint64_t x, int nbit);
@@ -103,7 +98,7 @@ struct BitstringReader {
103
98
  size_t i;
104
99
 
105
100
  // code_size in bytes
106
- BitstringReader(const uint8_t *code, int code_size);
101
+ BitstringReader(const uint8_t *code, size_t code_size);
107
102
 
108
103
  // read nbit bits from the code
109
104
  uint64_t read(int nbit);
@@ -0,0 +1,98 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+
9
+
10
+ #pragma once
11
+
12
+ #include <climits>
13
+ #include <cmath>
14
+
15
+ #include <limits>
16
+
17
+
18
+ namespace faiss {
19
+
20
+ /*******************************************************************
21
+ * C object: uniform handling of min and max heap
22
+ *******************************************************************/
23
+
24
+ /** The C object gives the type T of the values of a key-value storage, the type
25
+ * of the keys, TI and the comparison that is done: CMax for a decreasing
26
+ * series and CMin for increasing series. In other words, for a given threshold
27
+ * threshold, an incoming value x is kept if
28
+ *
29
+ * C::cmp(threshold, x)
30
+ *
31
+ * is true.
32
+ */
33
+
34
+ template <typename T_, typename TI_>
35
+ struct CMax;
36
+
37
+ template<typename T> inline T cmin_nextafter(T x);
38
+ template<typename T> inline T cmax_nextafter(T x);
39
+
40
+ // traits of minheaps = heaps where the minimum value is stored on top
41
+ // useful to find the *max* values of an array
42
+ template <typename T_, typename TI_>
43
+ struct CMin {
44
+ typedef T_ T;
45
+ typedef TI_ TI;
46
+ typedef CMax<T_, TI_> Crev; // reference to reverse comparison
47
+ inline static bool cmp (T a, T b) {
48
+ return a < b;
49
+ }
50
+ inline static T neutral () {
51
+ return std::numeric_limits<T>::lowest();
52
+ }
53
+ static const bool is_max = false;
54
+
55
+ inline static T nextafter(T x) {
56
+ return cmin_nextafter(x);
57
+ }
58
+ };
59
+
60
+
61
+
62
+
63
+ template <typename T_, typename TI_>
64
+ struct CMax {
65
+ typedef T_ T;
66
+ typedef TI_ TI;
67
+ typedef CMin<T_, TI_> Crev;
68
+ inline static bool cmp (T a, T b) {
69
+ return a > b;
70
+ }
71
+ inline static T neutral () {
72
+ return std::numeric_limits<T>::max();
73
+ }
74
+ static const bool is_max = true;
75
+ inline static T nextafter(T x) {
76
+ return cmax_nextafter(x);
77
+ }
78
+ };
79
+
80
+
81
+ template<> inline float cmin_nextafter<float>(float x) {
82
+ return std::nextafterf(x, -HUGE_VALF);
83
+ }
84
+
85
+ template<> inline float cmax_nextafter<float>(float x) {
86
+ return std::nextafterf(x, HUGE_VALF);
87
+ }
88
+
89
+ template<> inline uint16_t cmin_nextafter<uint16_t>(uint16_t x) {
90
+ return x - 1;
91
+ }
92
+
93
+ template<> inline uint16_t cmax_nextafter<uint16_t>(uint16_t x) {
94
+ return x + 1;
95
+ }
96
+
97
+
98
+ } // namespace faiss
@@ -0,0 +1,1256 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/utils/partitioning.h>
9
+
10
+ #include <cmath>
11
+ #include <cassert>
12
+
13
+ #include <faiss/impl/FaissAssert.h>
14
+ #include <faiss/utils/AlignedTable.h>
15
+ #include <faiss/utils/ordered_key_value.h>
16
+ #include <faiss/utils/simdlib.h>
17
+
18
+ #include <faiss/impl/platform_macros.h>
19
+
20
+ namespace faiss {
21
+
22
+
23
+ /******************************************************************
24
+ * Internal routines
25
+ ******************************************************************/
26
+
27
+
28
+ namespace partitioning {
29
+
30
+ template<typename T>
31
+ T median3(T a, T b, T c) {
32
+ if (a > b) {
33
+ std::swap(a, b);
34
+ }
35
+ if (c > b) {
36
+ return b;
37
+ }
38
+ if (c > a) {
39
+ return c;
40
+ }
41
+ return a;
42
+ }
43
+
44
+
45
+ template<class C>
46
+ typename C::T sample_threshold_median3(
47
+ const typename C::T * vals, int n,
48
+ typename C::T thresh_inf, typename C::T thresh_sup
49
+ ) {
50
+ using T = typename C::T;
51
+ size_t big_prime = 6700417;
52
+ T val3[3];
53
+ int vi = 0;
54
+
55
+ for (size_t i = 0; i < n; i++) {
56
+ T v = vals[(i * big_prime) % n];
57
+ // thresh_inf < v < thresh_sup (for CMax)
58
+ if (C::cmp(v, thresh_inf) && C::cmp(thresh_sup, v)) {
59
+ val3[vi++] = v;
60
+ if (vi == 3) {
61
+ break;
62
+ }
63
+ }
64
+ }
65
+
66
+ if (vi == 3) {
67
+ return median3(val3[0], val3[1], val3[2]);
68
+ } else if (vi != 0) {
69
+ return val3[0];
70
+ } else {
71
+ return thresh_inf;
72
+ // FAISS_THROW_MSG("too few values to compute a median");
73
+ }
74
+ }
75
+
76
+ template<class C>
77
+ void count_lt_and_eq(
78
+ const typename C::T * vals, size_t n, typename C::T thresh,
79
+ size_t & n_lt, size_t & n_eq
80
+ ) {
81
+ n_lt = n_eq = 0;
82
+
83
+ for(size_t i = 0; i < n; i++) {
84
+ typename C::T v = *vals++;
85
+ if(C::cmp(thresh, v)) {
86
+ n_lt++;
87
+ } else if(v == thresh) {
88
+ n_eq++;
89
+ }
90
+ }
91
+ }
92
+
93
+
94
+ template<class C>
95
+ size_t compress_array(
96
+ typename C::T *vals, typename C::TI * ids,
97
+ size_t n, typename C::T thresh, size_t n_eq
98
+ ) {
99
+ size_t wp = 0;
100
+ for(size_t i = 0; i < n; i++) {
101
+ if (C::cmp(thresh, vals[i])) {
102
+ vals[wp] = vals[i];
103
+ ids[wp] = ids[i];
104
+ wp++;
105
+ } else if (n_eq > 0 && vals[i] == thresh) {
106
+ vals[wp] = vals[i];
107
+ ids[wp] = ids[i];
108
+ wp++;
109
+ n_eq--;
110
+ }
111
+ }
112
+ assert(n_eq == 0);
113
+ return wp;
114
+ }
115
+
116
+
117
+ #define IFV if(false)
118
+
119
+ template<class C>
120
+ typename C::T partition_fuzzy_median3(
121
+ typename C::T *vals, typename C::TI * ids, size_t n,
122
+ size_t q_min, size_t q_max, size_t * q_out)
123
+ {
124
+
125
+ if (q_min == 0) {
126
+ if (q_out) {
127
+ *q_out = C::Crev::neutral();
128
+ }
129
+ return 0;
130
+ }
131
+ if (q_max >= n) {
132
+ if (q_out) {
133
+ *q_out = q_max;
134
+ }
135
+ return C::neutral();
136
+ }
137
+
138
+ using T = typename C::T;
139
+
140
+ // here we use bissection with a median of 3 to find the threshold and
141
+ // compress the arrays afterwards. So it's a n*log(n) algoirithm rather than
142
+ // qselect's O(n) but it avoids shuffling around the array.
143
+
144
+ FAISS_THROW_IF_NOT(n >= 3);
145
+
146
+ T thresh_inf = C::Crev::neutral();
147
+ T thresh_sup = C::neutral();
148
+ T thresh = median3(vals[0], vals[n / 2], vals[n - 1]);
149
+
150
+ size_t n_eq = 0, n_lt = 0;
151
+ size_t q = 0;
152
+
153
+ for(int it = 0; it < 200; it++) {
154
+ count_lt_and_eq<C>(vals, n, thresh, n_lt, n_eq);
155
+
156
+ IFV printf(" thresh=%g [%g %g] n_lt=%ld n_eq=%ld, q=%ld:%ld/%ld\n",
157
+ float(thresh), float(thresh_inf), float(thresh_sup),
158
+ long(n_lt), long(n_eq), long(q_min), long(q_max), long(n));
159
+
160
+ if (n_lt <= q_min) {
161
+ if (n_lt + n_eq >= q_min) {
162
+ q = q_min;
163
+ break;
164
+ } else {
165
+ thresh_inf = thresh;
166
+ }
167
+ } else if (n_lt <= q_max) {
168
+ q = n_lt;
169
+ break;
170
+ } else {
171
+ thresh_sup = thresh;
172
+ }
173
+
174
+ // FIXME avoid a second pass over the array to sample the threshold
175
+ IFV printf(" sample thresh in [%g %g]\n", float(thresh_inf), float(thresh_sup));
176
+ T new_thresh = sample_threshold_median3<C>(vals, n, thresh_inf, thresh_sup);
177
+ if (new_thresh == thresh_inf) {
178
+ // then there is nothing between thresh_inf and thresh_sup
179
+ break;
180
+ }
181
+ thresh = new_thresh;
182
+ }
183
+
184
+ int64_t n_eq_1 = q - n_lt;
185
+
186
+ IFV printf("shrink: thresh=%g n_eq_1=%ld\n", float(thresh), long(n_eq_1));
187
+
188
+ if (n_eq_1 < 0) { // happens when > q elements are at lower bound
189
+ q = q_min;
190
+ thresh = C::Crev::nextafter(thresh);
191
+ n_eq_1 = q;
192
+ } else {
193
+ assert(n_eq_1 <= n_eq);
194
+ }
195
+
196
+ int wp = compress_array<C>(vals, ids, n, thresh, n_eq_1);
197
+
198
+ assert(wp == q);
199
+ if (q_out) {
200
+ *q_out = q;
201
+ }
202
+
203
+ return thresh;
204
+ }
205
+
206
+
207
+ } // namespace partitioning
208
+
209
+
210
+
211
+ /******************************************************************
212
+ * SIMD routines when vals is an aligned array of uint16_t
213
+ ******************************************************************/
214
+
215
+
216
+ namespace simd_partitioning {
217
+
218
+
219
+
220
+ void find_minimax(
221
+ const uint16_t * vals, size_t n,
222
+ uint16_t & smin, uint16_t & smax
223
+ ) {
224
+
225
+ simd16uint16 vmin(0xffff), vmax(0);
226
+ for (size_t i = 0; i + 15 < n; i += 16) {
227
+ simd16uint16 v(vals + i);
228
+ vmin.accu_min(v);
229
+ vmax.accu_max(v);
230
+ }
231
+
232
+ ALIGNED(32) uint16_t tab32[32];
233
+ vmin.store(tab32);
234
+ vmax.store(tab32 + 16);
235
+
236
+ smin = tab32[0], smax = tab32[16];
237
+
238
+ for(int i = 1; i < 16; i++) {
239
+ smin = std::min(smin, tab32[i]);
240
+ smax = std::max(smax, tab32[i + 16]);
241
+ }
242
+
243
+ // missing values
244
+ for(size_t i = (n & ~15); i < n; i++) {
245
+ smin = std::min(smin, vals[i]);
246
+ smax = std::max(smax, vals[i]);
247
+ }
248
+
249
+ }
250
+
251
+
252
+ // max func differentiates between CMin and CMax (keep lowest or largest)
253
+ template<class C>
254
+ simd16uint16 max_func(simd16uint16 v, simd16uint16 thr16) {
255
+ constexpr bool is_max = C::is_max;
256
+ if (is_max) {
257
+ return max(v, thr16);
258
+ } else {
259
+ return min(v, thr16);
260
+ }
261
+ }
262
+
263
+ template<class C>
264
+ void count_lt_and_eq(
265
+ const uint16_t * vals, int n, uint16_t thresh,
266
+ size_t & n_lt, size_t & n_eq
267
+ ) {
268
+ n_lt = n_eq = 0;
269
+ simd16uint16 thr16(thresh);
270
+
271
+ size_t n1 = n / 16;
272
+
273
+ for (size_t i = 0; i < n1; i++) {
274
+ simd16uint16 v(vals);
275
+ vals += 16;
276
+ simd16uint16 eqmask = (v == thr16);
277
+ simd16uint16 max2 = max_func<C>(v, thr16);
278
+ simd16uint16 gemask = (v == max2);
279
+ uint32_t bits = get_MSBs(uint16_to_uint8_saturate(eqmask, gemask));
280
+ int i_eq = __builtin_popcount(bits & 0x00ff00ff);
281
+ int i_ge = __builtin_popcount(bits) - i_eq;
282
+ n_eq += i_eq;
283
+ n_lt += 16 - i_ge;
284
+ }
285
+
286
+ for(size_t i = n1 * 16; i < n; i++) {
287
+ uint16_t v = *vals++;
288
+ if(C::cmp(thresh, v)) {
289
+ n_lt++;
290
+ } else if(v == thresh) {
291
+ n_eq++;
292
+ }
293
+ }
294
+ }
295
+
296
+
297
+
298
+ /* compress separated values and ids table, keeping all values < thresh and at
299
+ * most n_eq equal values */
300
+ template<class C>
301
+ int simd_compress_array(
302
+ uint16_t *vals, typename C::TI * ids, size_t n, uint16_t thresh, int n_eq
303
+ ) {
304
+ simd16uint16 thr16(thresh);
305
+ simd16uint16 mixmask(0xff00);
306
+
307
+ int wp = 0;
308
+ size_t i0;
309
+
310
+ // loop while there are eqs to collect
311
+ for (i0 = 0; i0 + 15 < n && n_eq > 0; i0 += 16) {
312
+ simd16uint16 v(vals + i0);
313
+ simd16uint16 max2 = max_func<C>(v, thr16);
314
+ simd16uint16 gemask = (v == max2);
315
+ simd16uint16 eqmask = (v == thr16);
316
+ uint32_t bits = get_MSBs(blendv(
317
+ simd32uint8(eqmask), simd32uint8(gemask), simd32uint8(mixmask)));
318
+ bits ^= 0xAAAAAAAA;
319
+ // bit 2*i : eq
320
+ // bit 2*i + 1 : lt
321
+
322
+ while(bits) {
323
+ int j = __builtin_ctz(bits) & (~1);
324
+ bool is_eq = (bits >> j) & 1;
325
+ bool is_lt = (bits >> j) & 2;
326
+ bits &= ~(3 << j);
327
+ j >>= 1;
328
+
329
+ if (is_lt) {
330
+ vals[wp] = vals[i0 + j];
331
+ ids[wp] = ids[i0 + j];
332
+ wp++;
333
+ } else if(is_eq && n_eq > 0) {
334
+ vals[wp] = vals[i0 + j];
335
+ ids[wp] = ids[i0 + j];
336
+ wp++;
337
+ n_eq--;
338
+ }
339
+ }
340
+ }
341
+
342
+ // handle remaining, only striclty lt ones.
343
+ for (; i0 + 15 < n; i0 += 16) {
344
+ simd16uint16 v(vals + i0);
345
+ simd16uint16 max2 = max_func<C>(v, thr16);
346
+ simd16uint16 gemask = (v == max2);
347
+ uint32_t bits = ~get_MSBs(simd32uint8(gemask));
348
+
349
+ while(bits) {
350
+ int j = __builtin_ctz(bits);
351
+ bits &= ~(3 << j);
352
+ j >>= 1;
353
+
354
+ vals[wp] = vals[i0 + j];
355
+ ids[wp] = ids[i0 + j];
356
+ wp++;
357
+ }
358
+ }
359
+
360
+ // end with scalar
361
+ for(int i = (n & ~15); i < n; i++) {
362
+ if (C::cmp(thresh, vals[i])) {
363
+ vals[wp] = vals[i];
364
+ ids[wp] = ids[i];
365
+ wp++;
366
+ } else if (vals[i] == thresh && n_eq > 0) {
367
+ vals[wp] = vals[i];
368
+ ids[wp] = ids[i];
369
+ wp++;
370
+ n_eq--;
371
+ }
372
+ }
373
+ assert(n_eq == 0);
374
+ return wp;
375
+ }
376
+
377
+ // #define MICRO_BENCHMARK
378
+
379
+ static uint64_t get_cy () {
380
+ #ifdef MICRO_BENCHMARK
381
+ uint32_t high, low;
382
+ asm volatile("rdtsc \n\t"
383
+ : "=a" (low),
384
+ "=d" (high));
385
+ return ((uint64_t)high << 32) | (low);
386
+ #else
387
+ return 0;
388
+ #endif
389
+ }
390
+
391
+
392
+
393
+ #define IFV if(false)
394
+
395
+ template<class C>
396
+ uint16_t simd_partition_fuzzy_with_bounds(
397
+ uint16_t *vals, typename C::TI * ids, size_t n,
398
+ size_t q_min, size_t q_max, size_t * q_out,
399
+ uint16_t s0i, uint16_t s1i)
400
+ {
401
+
402
+ if (q_min == 0) {
403
+ if (q_out) {
404
+ *q_out = 0;
405
+ }
406
+ return 0;
407
+ }
408
+ if (q_max >= n) {
409
+ if (q_out) {
410
+ *q_out = q_max;
411
+ }
412
+ return 0xffff;
413
+ }
414
+ if (s0i == s1i) {
415
+ if (q_out) {
416
+ *q_out = q_min;
417
+ }
418
+ return s0i;
419
+ }
420
+ uint64_t t0 = get_cy();
421
+
422
+ // lower bound inclusive, upper exclusive
423
+ size_t s0 = s0i, s1 = s1i + 1;
424
+
425
+ IFV printf("bounds: %ld %ld\n", s0, s1 - 1);
426
+
427
+ int thresh;
428
+ size_t n_eq = 0, n_lt = 0;
429
+ size_t q = 0;
430
+
431
+ for(int it = 0; it < 200; it++) {
432
+ // while(s0 + 1 < s1) {
433
+ thresh = (s0 + s1) / 2;
434
+ count_lt_and_eq<C>(vals, n, thresh, n_lt, n_eq);
435
+
436
+ IFV printf(" [%ld %ld] thresh=%d n_lt=%ld n_eq=%ld, q=%ld:%ld/%ld\n",
437
+ s0, s1, thresh, n_lt, n_eq, q_min, q_max, n);
438
+ if (n_lt <= q_min) {
439
+ if (n_lt + n_eq >= q_min) {
440
+ q = q_min;
441
+ break;
442
+ } else {
443
+ if (C::is_max) {
444
+ s0 = thresh;
445
+ } else {
446
+ s1 = thresh;
447
+ }
448
+ }
449
+ } else if (n_lt <= q_max) {
450
+ q = n_lt;
451
+ break;
452
+ } else {
453
+ if (C::is_max) {
454
+ s1 = thresh;
455
+ } else {
456
+ s0 = thresh;
457
+ }
458
+ }
459
+
460
+ }
461
+
462
+ uint64_t t1 = get_cy();
463
+
464
+ // number of equal values to keep
465
+ int64_t n_eq_1 = q - n_lt;
466
+
467
+ IFV printf("shrink: thresh=%d q=%ld n_eq_1=%ld\n", thresh, q, n_eq_1);
468
+ if (n_eq_1 < 0) { // happens when > q elements are at lower bound
469
+ assert(s0 + 1 == s1);
470
+ q = q_min;
471
+ if (C::is_max) {
472
+ thresh--;
473
+ } else {
474
+ thresh++;
475
+ }
476
+ n_eq_1 = q;
477
+ IFV printf(" override: thresh=%d n_eq_1=%ld\n", thresh, n_eq_1);
478
+ } else {
479
+ assert(n_eq_1 <= n_eq);
480
+ }
481
+
482
+ size_t wp = simd_compress_array<C>(vals, ids, n, thresh, n_eq_1);
483
+
484
+ IFV printf("wp=%ld\n", wp);
485
+ assert(wp == q);
486
+ if (q_out) {
487
+ *q_out = q;
488
+ }
489
+
490
+ uint64_t t2 = get_cy();
491
+
492
+ partition_stats.bissect_cycles += t1 - t0;
493
+ partition_stats.compress_cycles += t2 - t1;
494
+
495
+ return thresh;
496
+ }
497
+
498
+
499
+ template<class C>
500
+ uint16_t simd_partition_fuzzy_with_bounds_histogram(
501
+ uint16_t *vals, typename C::TI * ids, size_t n,
502
+ size_t q_min, size_t q_max, size_t * q_out,
503
+ uint16_t s0i, uint16_t s1i)
504
+ {
505
+
506
+ if (q_min == 0) {
507
+ if (q_out) {
508
+ *q_out = 0;
509
+ }
510
+ return 0;
511
+ }
512
+ if (q_max >= n) {
513
+ if (q_out) {
514
+ *q_out = q_max;
515
+ }
516
+ return 0xffff;
517
+ }
518
+ if (s0i == s1i) {
519
+ if (q_out) {
520
+ *q_out = q_min;
521
+ }
522
+ return s0i;
523
+ }
524
+
525
+ IFV printf("partition fuzzy, q=%ld:%ld / %ld, bounds=%d %d\n",
526
+ q_min, q_max, n, s0i, s1i);
527
+
528
+ if (!C::is_max) {
529
+ IFV printf("revert due to CMin, q_min:q_max -> %ld:%ld\n", q_min, q_max);
530
+ q_min = n - q_min;
531
+ q_max = n - q_max;
532
+ }
533
+
534
+ // lower and upper bound of range, inclusive
535
+ int s0 = s0i, s1 = s1i;
536
+ // number of values < s0 and > s1
537
+ size_t n_lt = 0, n_gt = 0;
538
+
539
+ // output of loop:
540
+ int thresh; // final threshold
541
+ uint64_t tot_eq = 0; // total nb of equal values
542
+ uint64_t n_eq = 0; // nb of equal values to keep
543
+ size_t q; // final quantile
544
+
545
+ // buffer for the histograms
546
+ int hist[16];
547
+
548
+ for(int it = 0; it < 20; it++) {
549
+ // otherwise we would be done already
550
+
551
+ int shift = 0;
552
+
553
+ IFV printf(" it %d bounds: %d %d n_lt=%ld n_gt=%ld\n",
554
+ it, s0, s1, n_lt, n_gt);
555
+
556
+ int maxval = s1 - s0;
557
+
558
+ while(maxval > 15) {
559
+ shift++;
560
+ maxval >>= 1;
561
+ }
562
+
563
+ IFV printf(" histogram shift %d maxval %d ?= %d\n",
564
+ shift, maxval, int((s1 - s0) >> shift));
565
+
566
+ if (maxval > 7) {
567
+ simd_histogram_16(vals, n, s0, shift, hist);
568
+ } else {
569
+ simd_histogram_8(vals, n, s0, shift, hist);
570
+ }
571
+ IFV {
572
+ int sum = n_lt + n_gt;
573
+ printf(" n_lt=%ld hist=[", n_lt);
574
+ for(int i = 0; i <= maxval; i++) {
575
+ printf("%d ", hist[i]);
576
+ sum += hist[i];
577
+ }
578
+ printf("] n_gt=%ld sum=%d\n", n_gt, sum);
579
+ assert(sum == n);
580
+ }
581
+
582
+ size_t sum_below = n_lt;
583
+ int i;
584
+ for (i = 0; i <= maxval; i++) {
585
+ sum_below += hist[i];
586
+ if (sum_below >= q_min) {
587
+ break;
588
+ }
589
+ }
590
+ IFV printf(" i=%d sum_below=%ld\n", i, sum_below);
591
+ if (i <= maxval) {
592
+ s0 = s0 + (i << shift);
593
+ s1 = s0 + (1 << shift) - 1;
594
+ n_lt = sum_below - hist[i];
595
+ n_gt = n - sum_below;
596
+ } else {
597
+ assert(!"not implemented");
598
+ }
599
+
600
+ IFV printf(" new bin: s0=%d s1=%d n_lt=%ld n_gt=%ld\n", s0, s1, n_lt, n_gt);
601
+
602
+ if (s1 > s0) {
603
+ if (n_lt >= q_min && q_max >= n_lt) {
604
+ IFV printf(" FOUND1\n");
605
+ thresh = s0;
606
+ q = n_lt;
607
+ break;
608
+ }
609
+
610
+ size_t n_lt_2 = n - n_gt;
611
+ if (n_lt_2 >= q_min && q_max >= n_lt_2) {
612
+ thresh = s1 + 1;
613
+ q = n_lt_2;
614
+ IFV printf(" FOUND2\n");
615
+ break;
616
+ }
617
+ } else {
618
+ thresh = s0;
619
+ q = q_min;
620
+ tot_eq = n - n_gt - n_lt;
621
+ n_eq = q_min - n_lt;
622
+ IFV printf(" FOUND3\n");
623
+ break;
624
+ }
625
+ }
626
+
627
+ IFV printf("end bissection: thresh=%d q=%ld n_eq=%ld\n", thresh, q, n_eq);
628
+
629
+ if (!C::is_max) {
630
+ if (n_eq == 0) {
631
+ thresh --;
632
+ } else {
633
+ // thresh unchanged
634
+ n_eq = tot_eq - n_eq;
635
+ }
636
+ q = n - q;
637
+ IFV printf("revert due to CMin, q->%ld n_eq->%ld\n", q, n_eq);
638
+ }
639
+
640
+ size_t wp = simd_compress_array<C>(vals, ids, n, thresh, n_eq);
641
+ IFV printf("wp=%ld ?= %ld\n", wp, q);
642
+ assert(wp == q);
643
+ if (q_out) {
644
+ *q_out = wp;
645
+ }
646
+
647
+ return thresh;
648
+ }
649
+
650
+
651
+
652
+ template<class C>
653
+ uint16_t simd_partition_fuzzy(
654
+ uint16_t *vals, typename C::TI * ids, size_t n,
655
+ size_t q_min, size_t q_max, size_t * q_out
656
+ ) {
657
+
658
+ assert(is_aligned_pointer(vals));
659
+
660
+ uint16_t s0i, s1i;
661
+ find_minimax(vals, n, s0i, s1i);
662
+ // QSelect_stats.t0 += get_cy() - t0;
663
+
664
+ return simd_partition_fuzzy_with_bounds<C>(
665
+ vals, ids, n, q_min, q_max, q_out, s0i, s1i);
666
+ }
667
+
668
+
669
+
670
+ template<class C>
671
+ uint16_t simd_partition(uint16_t *vals, typename C::TI * ids, size_t n, size_t q) {
672
+
673
+ assert(is_aligned_pointer(vals));
674
+
675
+ if (q == 0) {
676
+ return 0;
677
+ }
678
+ if (q >= n) {
679
+ return 0xffff;
680
+ }
681
+
682
+ uint16_t s0i, s1i;
683
+ find_minimax(vals, n, s0i, s1i);
684
+
685
+ return simd_partition_fuzzy_with_bounds<C>(
686
+ vals, ids, n, q, q, nullptr, s0i, s1i);
687
+ }
688
+
689
+ template<class C>
690
+ uint16_t simd_partition_with_bounds(
691
+ uint16_t *vals, typename C::TI * ids, size_t n, size_t q,
692
+ uint16_t s0i, uint16_t s1i)
693
+ {
694
+ return simd_partition_fuzzy_with_bounds<C>(
695
+ vals, ids, n, q, q, nullptr, s0i, s1i);
696
+ }
697
+
698
+ } // namespace simd_partitioning
699
+
700
+
701
+ /******************************************************************
702
+ * Driver routine
703
+ ******************************************************************/
704
+
705
+
706
/** Partition vals (and the matching ids) around some quantile in
 * [q_min, q_max], according to comparator C.
 *
 * Dispatches to the SIMD uint16 implementation when the value type and
 * the array alignment allow it; otherwise falls back to the scalar
 * median-of-3 quickselect.
 *
 * @param q_out  if non-null, receives the quantile actually used
 * @return       the threshold value of the partition
 */
template<class C>
typename C::T partition_fuzzy(
    typename C::T *vals, typename C::TI * ids, size_t n,
    size_t q_min, size_t q_max, size_t * q_out)
{
    // the code below compiles and runs without AVX2 but it's slower than
    // the scalar implementation
#ifdef __AVX2__
    constexpr bool is_uint16 = std::is_same<typename C::T, uint16_t>::value;
    // SIMD path requires uint16 values and an aligned array
    if (is_uint16 && is_aligned_pointer(vals)) {
        return simd_partitioning::simd_partition_fuzzy<C>(
            (uint16_t*)vals, ids, n, q_min, q_max, q_out);
    }
#endif
    return partitioning::partition_fuzzy_median3<C>(
        vals, ids, n, q_min, q_max, q_out);
}
723
+
724
+
725
// explicit template instantiations -- presumably the comparator / value /
// id type combinations used by callers of partition_fuzzy

template float partition_fuzzy<CMin<float, int64_t>> (
    float *vals, int64_t * ids, size_t n,
    size_t q_min, size_t q_max, size_t * q_out);

template float partition_fuzzy<CMax<float, int64_t>> (
    float *vals, int64_t * ids, size_t n,
    size_t q_min, size_t q_max, size_t * q_out);

template uint16_t partition_fuzzy<CMin<uint16_t, int64_t>> (
    uint16_t *vals, int64_t * ids, size_t n,
    size_t q_min, size_t q_max, size_t * q_out);

template uint16_t partition_fuzzy<CMax<uint16_t, int64_t>> (
    uint16_t *vals, int64_t * ids, size_t n,
    size_t q_min, size_t q_max, size_t * q_out);

template uint16_t partition_fuzzy<CMin<uint16_t, int>> (
    uint16_t *vals, int * ids, size_t n,
    size_t q_min, size_t q_max, size_t * q_out);

template uint16_t partition_fuzzy<CMax<uint16_t, int>> (
    uint16_t *vals, int * ids, size_t n,
    size_t q_min, size_t q_max, size_t * q_out);
750
+
751
+
752
+
753
+ /******************************************************************
754
+ * Histogram subroutines
755
+ ******************************************************************/
756
+
757
+ #ifdef __AVX2__
758
/// FIXME: behavior is incorrect when the MSB of a uint16 value is set
// NOTE: this code does not compile properly with GCC 7.4.0
760
+
761
+ namespace {
762
+
763
+ /************************************************************
764
+ * 8 bins
765
+ ************************************************************/
766
+
767
// Widen packed 4-bit counters to 8-bit counters: extract the low and high
// nibble of each byte and sum adjacent 16-bit lanes with hadd.
simd32uint8 accu4to8(simd16uint16 a4) {
    simd16uint16 mask4(0x0f0f);

    // low nibbles and high nibbles, as two separate vectors
    simd16uint16 a8_0 = a4 & mask4;
    simd16uint16 a8_1 = (a4 >> 4) & mask4;

    return simd32uint8(_mm256_hadd_epi16(a8_0.i, a8_1.i));
}
775
+
776
+
777
// Widen packed 8-bit counters to 16-bit counters: extract the low and
// high byte of each 16-bit lane and sum adjacent lanes with hadd.
simd16uint16 accu8to16(simd32uint8 a8) {
    simd16uint16 mask8(0x00ff);

    // low bytes and high bytes, as two separate vectors
    simd16uint16 a8_0 = simd16uint16(a8) & mask8;
    simd16uint16 a8_1 = (simd16uint16(a8) >> 8) & mask8;

    return simd16uint16(_mm256_hadd_epi16(a8_0.i, a8_1.i));
}
785
+
786
+
787
// pshufb lookup table for compute_accu2: an in-range index selects an
// increment (1, 4, 16 or 64, i.e. one of the packed 2-bit counter
// positions); indices with the MSB set (0xffff from the preprocessor)
// make pshufb return 0, so out-of-range values are dropped
static const simd32uint8 shifts(_mm256_setr_epi8(
    1, 16, 0, 0, 4, 64, 0, 0,
    0, 0, 1, 16, 0, 0, 4, 64,
    1, 16, 0, 0, 4, 64, 0, 0,
    0, 0, 1, 16, 0, 0, 4, 64
));
793
+
794
// 2-bit accumulator: we can add only up to 3 elements
// on output we return 2*4-bit results
// preproc returns either an index in 0..7 or 0xffff
// that yields a 0 when used in the table look-up
template<int N, class Preproc>
void compute_accu2(
        const uint16_t * & data,
        Preproc & pp,
        simd16uint16 & a4lo, simd16uint16 & a4hi
) {
    simd16uint16 mask2(0x3333);
    simd16uint16 a2((uint16_t)0); // 2-bit accu
    for (int j = 0; j < N; j ++) {
        // load the next 16 values and map them to bin indices
        simd16uint16 v(data);
        data += 16;
        v = pp(v);
        // 0x800 -> force second half of table
        simd16uint16 idx = v | (v << 8) | simd16uint16(0x800);
        a2 += simd16uint16(shifts.lookup_2_lanes(simd32uint8(idx)));
    }
    // fold the packed 2-bit counters into the caller's 4-bit accumulators
    a4lo += a2 & mask2;
    a4hi += (a2 >> 2) & mask2;
}
817
+
818
+
819
// Compute an 8-bin histogram of n_in uint16 values (n_in must be a
// multiple of 16).  pp maps each value to a bin index in 0..7, or 0xffff
// for out-of-range values, which are then dropped.  The caller still has
// to add the two 128-bit lanes of the result to get the final counts.
template<class Preproc>
simd16uint16 histogram_8(
        const uint16_t * data, Preproc pp,
        size_t n_in) {

    assert (n_in % 16 == 0);
    int n = n_in / 16;

    // 8-bit accumulators, flushed into 16-bit form at the end
    simd32uint8 a8lo(0);
    simd32uint8 a8hi(0);

    // blocks of at most 15 vectors keep the 4-bit counters from
    // overflowing before they are widened
    for(int i0 = 0; i0 < n; i0 += 15) {
        simd16uint16 a4lo(0); // 4-bit accus
        simd16uint16 a4hi(0);

        int i1 = std::min(i0 + 15, n);
        int i;
        // the 2-bit accumulators can take at most 3 increments per call
        for(i = i0; i + 2 < i1; i += 3) {
            compute_accu2<3>(data, pp, a4lo, a4hi); // adds 3 max
        }
        // 1 or 2 vectors remaining in the block
        switch (i1 - i) {
        case 2:
            compute_accu2<2>(data, pp, a4lo, a4hi);
            break;
        case 1:
            compute_accu2<1>(data, pp, a4lo, a4hi);
            break;
        }

        a8lo += accu4to8(a4lo);
        a8hi += accu4to8(a4hi);
    }

    // move to 16-bit accu
    simd16uint16 a16lo = accu8to16(a8lo);
    simd16uint16 a16hi = accu8to16(a8hi);

    simd16uint16 a16 = simd16uint16(_mm256_hadd_epi16(a16lo.i, a16hi.i));

    // the 2 lanes must still be combined
    return a16;
}
861
+
862
+
863
+ /************************************************************
864
+ * 16 bins
865
+ ************************************************************/
866
+
867
+
868
+
869
// pshufb lookup table for compute_accu2_16: maps the low 3 bits of a bin
// index i to the one-hot byte 1 << i; an index byte with the MSB set
// makes pshufb return 0
static const simd32uint8 shifts2(_mm256_setr_epi8(
    1, 2, 4, 8, 16, 32, 64, (char)128,
    1, 2, 4, 8, 16, 32, 64, (char)128,
    1, 2, 4, 8, 16, 32, 64, (char)128,
    1, 2, 4, 8, 16, 32, 64, (char)128
));
875
+
876
+
877
+ simd32uint8 shiftr_16(simd32uint8 x, int n)
878
+ {
879
+ return simd32uint8(simd16uint16(x) >> n);
880
+ }
881
+
882
+
883
// Per-byte sum of the two 128-bit lanes of each input: the low lane of
// the result is (low lane of a) + (high lane of a), the high lane is
// (low lane of b) + (high lane of b).
inline simd32uint8 combine_2x2(simd32uint8 a, simd32uint8 b) {

    // a1b0 = [high lane of a, low lane of b]
    __m256i a1b0 = _mm256_permute2f128_si256(a.i, b.i, 0x21);
    // a0b1 = [low lane of a, high lane of b]
    __m256i a0b1 = _mm256_blend_epi32(a.i, b.i, 0xF0);

    return simd32uint8(a1b0) + simd32uint8(a0b1);
}
890
+
891
+
892
// 2-bit accumulator: we can add only up to 3 elements
// on output we return 2*4-bit results
template<int N, class Preproc>
void compute_accu2_16(
        const uint16_t * & data, Preproc pp,
        simd32uint8 & a4_0, simd32uint8 & a4_1,
        simd32uint8 & a4_2, simd32uint8 & a4_3
) {
    simd32uint8 mask1(0x55);
    simd32uint8 a2_0; // 2-bit accu
    simd32uint8 a2_1; // 2-bit accu
    a2_0.clear(); a2_1.clear();

    for (int j = 0; j < N; j ++) {
        // load the next 16 values and map them to bin indices
        simd16uint16 v(data);
        data += 16;
        v = pp(v);

        // replicate the index in both bytes of the lane; the lookup maps
        // it to a one-hot byte (1 << (idx & 7))
        simd16uint16 idx = v | (v << 8);
        simd32uint8 a1 = shifts2.lookup_2_lanes(simd32uint8(idx));
        // contains 0s for out-of-bounds elements

        // lt8 = 0xffff for values in 0..7, 0 otherwise
        simd16uint16 lt8 = (v >> 3) == simd16uint16(0);
        // flip the high byte: 0x00ff where v < 8, 0xff00 where v >= 8, so
        // the mask below selects the low byte for bins 0..7 and the high
        // byte for bins 8..15
        lt8.i = _mm256_xor_si256(lt8.i, _mm256_set1_epi16(0xff00));

        a1 = a1 & lt8;

        a2_0 += a1 & mask1;
        a2_1 += shiftr_16(a1, 1) & mask1;
    }
    simd32uint8 mask2(0x33);

    // fold the interleaved 2-bit counters into the four 4-bit accumulators
    a4_0 += a2_0 & mask2;
    a4_1 += a2_1 & mask2;
    a4_2 += shiftr_16(a2_0, 2) & mask2;
    a4_3 += shiftr_16(a2_1, 2) & mask2;

}
930
+
931
+
932
// Widen 4-bit counters to 8-bit counters (16-bin variant): split each
// input into its low and high nibbles, sum the two 128-bit lanes of each
// input with combine_2x2, then sum adjacent lanes with hadd.
simd32uint8 accu4to8_2(simd32uint8 a4_0, simd32uint8 a4_1) {
    simd32uint8 mask4(0x0f);

    simd32uint8 a8_0 = combine_2x2(
        a4_0 & mask4,
        shiftr_16(a4_0, 4) & mask4
    );

    simd32uint8 a8_1 = combine_2x2(
        a4_1 & mask4,
        shiftr_16(a4_1, 4) & mask4
    );

    return simd32uint8(_mm256_hadd_epi16(a8_0.i, a8_1.i));
}
947
+
948
+
949
+
950
// Compute a 16-bin histogram of n_in uint16 values (n_in must be a
// multiple of 16).  pp maps each value to a bin index in 0..15, or 0xffff
// for out-of-range values, which are then dropped.  The result holds the
// 16 bin counts as 16 consecutive uint16 lanes.
template<class Preproc>
simd16uint16 histogram_16(const uint16_t * data, Preproc pp, size_t n_in) {

    assert (n_in % 16 == 0);
    int n = n_in / 16;

    simd32uint8 a8lo((uint8_t)0);
    simd32uint8 a8hi((uint8_t)0);

    // blocks of at most 7 vectors keep the narrow counters from
    // overflowing before they are widened
    for(int i0 = 0; i0 < n; i0 += 7) {
        simd32uint8 a4_0(0); // 0, 4, 8, 12
        simd32uint8 a4_1(0); // 1, 5, 9, 13
        simd32uint8 a4_2(0); // 2, 6, 10, 14
        simd32uint8 a4_3(0); // 3, 7, 11, 15

        int i1 = std::min(i0 + 7, n);
        int i;
        // the 2-bit accumulators can take at most 3 increments per call
        for(i = i0; i + 2 < i1; i += 3) {
            compute_accu2_16<3>(data, pp, a4_0, a4_1, a4_2, a4_3);
        }
        // 1 or 2 vectors remaining in the block
        switch (i1 - i) {
        case 2:
            compute_accu2_16<2>(data, pp, a4_0, a4_1, a4_2, a4_3);
            break;
        case 1:
            compute_accu2_16<1>(data, pp, a4_0, a4_1, a4_2, a4_3);
            break;
        }

        a8lo += accu4to8_2(a4_0, a4_1);
        a8hi += accu4to8_2(a4_2, a4_3);
    }

    // move to 16-bit accu
    simd16uint16 a16lo = accu8to16(a8lo);
    simd16uint16 a16hi = accu8to16(a8hi);

    simd16uint16 a16 = simd16uint16(_mm256_hadd_epi16(a16lo.i, a16hi.i));

    // reorder the dwords so that the bin counts appear in order
    __m256i perm32 = _mm256_setr_epi32(
        0, 2, 4, 6, 1, 3, 5, 7
    );
    a16.i = _mm256_permutevar8x32_epi32(a16.i, perm32);

    return a16;
}
996
+
997
+ struct PreprocNOP {
998
+ simd16uint16 operator () (simd16uint16 x) {
999
+ return x;
1000
+ }
1001
+
1002
+ };
1003
+
1004
+
1005
// Maps a value v to the bin index (v - min) >> shift.  Values outside
// [min, min + (nbin << shift) - 1] are mapped to 0xffff, which the
// histogram table look-ups turn into a no-op.
template<int shift, int nbin>
struct PreprocMinShift {
    simd16uint16 min16;  // broadcast of min
    simd16uint16 max16;  // broadcast of the largest in-range (v - min)

    explicit PreprocMinShift(uint16_t min) {
        min16.set1(min);
        // clamp the upper bound to the uint16 domain
        int vmax0 = std::min((nbin << shift) + min, 65536);
        uint16_t vmax = uint16_t(vmax0 - 1 - min);
        max16.set1(vmax); // vmax inclusive
    }

    simd16uint16 operator () (simd16uint16 x) {
        // values below min wrap around to large values and are then
        // caught by the x > max16 test below
        x = x - min16;
        // mask = 0xffff where x > max16, 0 elsewhere:
        // (x == max(x, max16)) is all-ones for x >= max16; subtracting
        // (x == max16) cancels the equality case
        simd16uint16 mask = (x == max(x, max16)) - (x == max16);
        return (x >> shift) | mask;
    }

};
1024
+
1025
+ /* unbounded versions of the functions */
1026
+
1027
// 8-bin histogram without bounds checking: values are used as bin indices
// directly.  NOTE(review): in the scalar tail a value >= 8 would index
// past hist[7]; callers presumably guarantee data[i] < 8 -- TODO confirm.
void simd_histogram_8_unbounded(
        const uint16_t *data, int n,
        int *hist)
{
    PreprocNOP pp;
    // SIMD part over the largest multiple of 16 elements
    simd16uint16 a16 = histogram_8(data, pp, (n & ~15));

    ALIGNED(32) uint16_t a16_tab[16];
    a16.store(a16_tab);

    // add the two 128-bit lanes of the accumulator (hist is overwritten)
    for(int i = 0; i < 8; i++) {
        hist[i] = a16_tab[i] + a16_tab[i + 8];
    }

    // scalar tail for the remaining < 16 elements
    for(int i = (n & ~15); i < n; i++) {
        hist[data[i]]++;
    }

}
1046
+
1047
+
1048
// 16-bin histogram without bounds checking: values are used as bin
// indices directly.  NOTE(review): in the scalar tail a value >= 16 would
// index past hist[15]; callers presumably guarantee data[i] < 16 -- TODO
// confirm.
void simd_histogram_16_unbounded(
        const uint16_t *data, int n,
        int *hist)
{

    // SIMD part over the largest multiple of 16 elements
    simd16uint16 a16 = histogram_16(data, PreprocNOP(), (n & ~15));

    ALIGNED(32) uint16_t a16_tab[16];
    a16.store(a16_tab);

    // the 16 counts come out in order (hist is overwritten)
    for(int i = 0; i < 16; i++) {
        hist[i] = a16_tab[i];
    }

    // scalar tail for the remaining < 16 elements
    for(int i = (n & ~15); i < n; i++) {
        hist[data[i]]++;
    }

}
1067
+
1068
+
1069
+
1070
+ } // anonymous namespace
1071
+
1072
+ /************************************************************
1073
+ * Driver routines
1074
+ ************************************************************/
1075
+
1076
// Histogram of (data[i] - min) >> shift into 8 bins; values outside the
// bin range are ignored.  hist is overwritten, not accumulated.  shift < 0
// selects the unbounded variant.  The switch is needed because shift is a
// compile-time parameter of PreprocMinShift.
void simd_histogram_8(
        const uint16_t *data, int n,
        uint16_t min, int shift,
        int *hist)
{
    if (shift < 0) {
        simd_histogram_8_unbounded(data, n, hist);
        return;
    }

    simd16uint16 a16;

#define DISPATCH(s)  \
    case s: \
        a16 = histogram_8(data, PreprocMinShift<s, 8>(min), (n & ~15)); \
        break

    switch(shift) {
        DISPATCH(0);
        DISPATCH(1);
        DISPATCH(2);
        DISPATCH(3);
        DISPATCH(4);
        DISPATCH(5);
        DISPATCH(6);
        DISPATCH(7);
        DISPATCH(8);
        DISPATCH(9);
        DISPATCH(10);
        DISPATCH(11);
        DISPATCH(12);
        DISPATCH(13);
    default:
        FAISS_THROW_FMT("dispatch for shift=%d not instantiated", shift);
    }
#undef DISPATCH

    ALIGNED(32) uint16_t a16_tab[16];
    a16.store(a16_tab);

    // add the two 128-bit lanes of the accumulator to get the 8 counts
    for(int i = 0; i < 8; i++) {
        hist[i] = a16_tab[i] + a16_tab[i + 8];
    }

    // complete with remaining bins
    for(int i = (n & ~15); i < n; i++) {
        if (data[i] < min) continue;
        uint16_t v = data[i] - min;
        v >>= shift;
        if (v < 8) hist[v]++;
    }

}
1129
+
1130
+
1131
+
1132
// Histogram of (data[i] - min) >> shift into 16 bins; values outside the
// bin range are ignored.  hist is overwritten, not accumulated.  shift < 0
// selects the unbounded variant.  Shifts above 12 cannot occur for 16
// bins of uint16 values, hence the shorter dispatch table.
void simd_histogram_16(
        const uint16_t *data, int n,
        uint16_t min, int shift,
        int *hist)
{
    if (shift < 0) {
        simd_histogram_16_unbounded(data, n, hist);
        return;
    }

    simd16uint16 a16;

#define DISPATCH(s)  \
    case s: \
        a16 = histogram_16(data, PreprocMinShift<s, 16>(min), (n & ~15)); \
        break

    switch(shift) {
        DISPATCH(0);
        DISPATCH(1);
        DISPATCH(2);
        DISPATCH(3);
        DISPATCH(4);
        DISPATCH(5);
        DISPATCH(6);
        DISPATCH(7);
        DISPATCH(8);
        DISPATCH(9);
        DISPATCH(10);
        DISPATCH(11);
        DISPATCH(12);
    default:
        FAISS_THROW_FMT("dispatch for shift=%d not instantiated", shift);
    }
#undef DISPATCH

    ALIGNED(32) uint16_t a16_tab[16];
    a16.store(a16_tab);

    // the 16 counts come out in order
    for(int i = 0; i < 16; i++) {
        hist[i] = a16_tab[i];
    }

    // scalar tail for the remaining < 16 elements
    for(int i = (n & ~15); i < n; i++) {
        if (data[i] < min) continue;
        uint16_t v = data[i] - min;
        v >>= shift;
        if (v < 16) hist[v]++;
    }

}
1183
+
1184
+
1185
+ // no AVX2
1186
+ #else
1187
+
1188
+
1189
+
1190
// Scalar fallback (no AVX2): 16-bin histogram of (data[i] - min) >> shift.
// hist is overwritten, not accumulated.  shift < 0 means the values are
// used as bin indices directly.
void simd_histogram_16(
    const uint16_t *data, int n,
    uint16_t min, int shift,
    int *hist)
{
    memset(hist, 0, 16 * sizeof(hist[0]));

    if (shift < 0) {
        // unbounded variant: values are assumed to be bin indices already
        for (size_t i = 0; i < size_t(n); i++) {
            hist[data[i]]++;
        }
        return;
    }

    // largest (data[i] - min) that still falls in a bin, inclusive,
    // clamped to the uint16 domain
    int upper = std::min((16 << shift) + min, 65536);
    uint16_t vmax = uint16_t(upper - 1 - min);

    for (size_t i = 0; i < size_t(n); i++) {
        // unsigned wrap-around makes values below min fail the range
        // test as well
        uint16_t v = uint16_t(data[i] - min);
        if (v <= vmax) {
            hist[v >> shift]++;
        }
    }
}
1222
+
1223
// Scalar fallback (no AVX2): 8-bin histogram of (data[i] - min) >> shift.
// hist is overwritten, not accumulated.  shift < 0 means the values are
// used as bin indices directly.
void simd_histogram_8(
    const uint16_t *data, int n,
    uint16_t min, int shift,
    int *hist)
{
    memset(hist, 0, 8 * sizeof(hist[0]));

    if (shift < 0) {
        // unbounded variant: values are assumed to be bin indices already
        for (size_t i = 0; i < size_t(n); i++) {
            hist[data[i]]++;
        }
        return;
    }

    for (size_t i = 0; i < size_t(n); i++) {
        if (data[i] >= min) {
            uint16_t v = uint16_t((data[i] - min) >> shift);
            if (v < 8) {
                hist[v]++;
            }
        }
    }
}
1243
+
1244
+
1245
+ #endif
1246
+
1247
+
1248
// Zero all counters.  NOTE(review): assumes PartitionStats is a trivially
// copyable struct of plain counters, so memset is well-defined -- confirm
// against the header if fields are ever added.
void PartitionStats::reset() {
    memset(this, 0, sizeof(*this));
}

// global instance of the partitioning statistics
PartitionStats partition_stats;
1253
+
1254
+
1255
+
1256
+ } // namespace faiss