faiss 0.2.5 → 0.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191):
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/ext/faiss/extconf.rb +1 -1
  5. data/ext/faiss/index.cpp +13 -0
  6. data/lib/faiss/version.rb +1 -1
  7. data/lib/faiss.rb +2 -2
  8. data/vendor/faiss/faiss/AutoTune.cpp +15 -4
  9. data/vendor/faiss/faiss/AutoTune.h +0 -1
  10. data/vendor/faiss/faiss/Clustering.cpp +1 -5
  11. data/vendor/faiss/faiss/Clustering.h +0 -2
  12. data/vendor/faiss/faiss/IVFlib.h +0 -2
  13. data/vendor/faiss/faiss/Index.h +1 -2
  14. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
  15. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
  16. data/vendor/faiss/faiss/IndexBinary.h +0 -1
  17. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
  18. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
  19. data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
  20. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
  21. data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
  22. data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
  23. data/vendor/faiss/faiss/IndexFastScan.h +5 -1
  24. data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
  25. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  26. data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
  27. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
  29. data/vendor/faiss/faiss/IndexHNSW.h +0 -1
  30. data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
  31. data/vendor/faiss/faiss/IndexIDMap.h +0 -2
  32. data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
  33. data/vendor/faiss/faiss/IndexIVF.h +121 -61
  34. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  35. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
  36. data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
  37. data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
  38. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
  39. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
  41. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
  42. data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
  43. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  44. data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
  45. data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
  46. data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
  47. data/vendor/faiss/faiss/IndexReplicas.h +0 -1
  48. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
  49. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
  50. data/vendor/faiss/faiss/IndexShards.cpp +26 -109
  51. data/vendor/faiss/faiss/IndexShards.h +2 -3
  52. data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
  53. data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
  54. data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
  55. data/vendor/faiss/faiss/MetaIndexes.h +29 -0
  56. data/vendor/faiss/faiss/MetricType.h +14 -0
  57. data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
  58. data/vendor/faiss/faiss/VectorTransform.h +1 -3
  59. data/vendor/faiss/faiss/clone_index.cpp +232 -18
  60. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
  61. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
  62. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
  63. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
  64. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
  65. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
  66. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
  67. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
  68. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
  69. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
  70. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
  71. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
  72. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
  73. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  74. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
  75. data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
  76. data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
  77. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
  78. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
  79. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
  80. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
  81. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
  82. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
  83. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
  84. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
  85. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
  86. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  87. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
  88. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
  89. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
  90. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
  91. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
  92. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
  93. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
  94. data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
  95. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  96. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  97. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
  98. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
  99. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  100. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
  101. data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
  102. data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
  104. data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
  105. data/vendor/faiss/faiss/impl/HNSW.h +6 -9
  106. data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
  107. data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
  109. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
  110. data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
  111. data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
  112. data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
  113. data/vendor/faiss/faiss/impl/NSG.h +4 -7
  114. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
  115. data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
  116. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
  117. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
  118. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
  119. data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
  120. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
  121. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
  122. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
  123. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
  124. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
  125. data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
  126. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
  127. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
  128. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
  129. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
  130. data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
  131. data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
  132. data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
  133. data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
  134. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  135. data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
  137. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
  138. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
  139. data/vendor/faiss/faiss/index_factory.cpp +8 -10
  140. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
  141. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
  142. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  143. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
  144. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
  145. data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
  146. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  147. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  148. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  149. data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
  150. data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
  151. data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
  152. data/vendor/faiss/faiss/utils/Heap.h +35 -1
  153. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
  154. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
  155. data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
  156. data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
  157. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
  158. data/vendor/faiss/faiss/utils/distances.cpp +61 -7
  159. data/vendor/faiss/faiss/utils/distances.h +11 -0
  160. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
  161. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
  162. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
  163. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
  164. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
  165. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
  166. data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
  167. data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
  168. data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
  169. data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
  170. data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
  171. data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
  172. data/vendor/faiss/faiss/utils/fp16.h +7 -0
  173. data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
  174. data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
  175. data/vendor/faiss/faiss/utils/hamming.h +21 -10
  176. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
  177. data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
  178. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
  179. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
  180. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
  181. data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
  182. data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
  183. data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
  184. data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
  185. data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
  186. data/vendor/faiss/faiss/utils/sorting.h +71 -0
  187. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
  188. data/vendor/faiss/faiss/utils/utils.cpp +4 -176
  189. data/vendor/faiss/faiss/utils/utils.h +2 -9
  190. metadata +30 -4
  191. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
@@ -0,0 +1,692 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/utils/sorting.h>
11
+
12
+ #include <omp.h>
13
+ #include <algorithm>
14
+
15
+ #include <faiss/impl/FaissAssert.h>
16
+ #include <faiss/utils/utils.h>
17
+
18
+ namespace faiss {
19
+
20
+ /*****************************************************************************
21
+ * Argsort
22
+ ****************************************************************************/
23
+
24
+ namespace {
25
/// Comparator for argsort: orders permutation indices by the float
/// values they refer to, ascending.
struct ArgsortComparator {
    const float* vals;
    bool operator()(const size_t lhs, const size_t rhs) const {
        return vals[lhs] < vals[rhs];
    }
};
31
+
32
/// Half-open range [i0, i1) within the permutation array being merged.
struct SegmentS {
    size_t i0; ///< begin pointer in the permutation array
    size_t i1; ///< end (one past the last element)
    /// number of elements covered by the segment
    size_t len() const {
        return i1 - i0;
    }
};
39
+
40
+ // see https://en.wikipedia.org/wiki/Merge_algorithm#Parallel_merge
41
+ // extended to > 1 merge thread
42
+
43
+ // merges 2 ranges that should be consecutive on the source into
44
+ // the union of the two on the destination
45
// see https://en.wikipedia.org/wiki/Merge_algorithm#Parallel_merge
// extended to > 1 merge thread

// Merges 2 sorted ranges that should be consecutive on the source array
// into their union on the destination array, using nt OpenMP threads.
// On return s1 and s2 are both set to the merged (union) range.
// The split points of the smaller range are found by binary search on
// pivots taken from the larger range, so each thread merges an
// independent pair of sub-ranges.
template <typename T>
void parallel_merge(
        const T* src,
        T* dst,
        SegmentS& s1,
        SegmentS& s2,
        int nt,
        const ArgsortComparator& comp) {
    if (s2.len() > s1.len()) { // make sure that s1 larger than s2
        std::swap(s1, s2);
    }

    // compute sub-ranges for each thread:
    // s1s[t]/s2s[t] = the slices of s1/s2 thread t merges,
    // sws[t] = where thread t writes in dst
    std::vector<SegmentS> s1s(nt), s2s(nt), sws(nt);
    s2s[0].i0 = s2.i0;
    s2s[nt - 1].i1 = s2.i1;

    // not sure parallel actually helps here
#pragma omp parallel for num_threads(nt)
    for (int t = 0; t < nt; t++) {
        // even split of the larger range s1
        s1s[t].i0 = s1.i0 + s1.len() * t / nt;
        s1s[t].i1 = s1.i0 + s1.len() * (t + 1) / nt;

        if (t + 1 < nt) {
            // binary search in s2 for the boundary matching the s1 pivot
            T pivot = src[s1s[t].i1];
            size_t i0 = s2.i0, i1 = s2.i1;
            while (i0 + 1 < i1) {
                size_t imed = (i1 + i0) / 2;
                if (comp(pivot, src[imed])) {
                    i1 = imed;
                } else {
                    i0 = imed;
                }
            }
            s2s[t].i1 = s2s[t + 1].i0 = i1;
        }
    }
    // the output range is the union of the two input ranges
    s1.i0 = std::min(s1.i0, s2.i0);
    s1.i1 = std::max(s1.i1, s2.i1);
    s2 = s1;
    sws[0].i0 = s1.i0;
    for (int t = 0; t < nt; t++) {
        // each thread writes exactly as many elements as it reads
        sws[t].i1 = sws[t].i0 + s1s[t].len() + s2s[t].len();
        if (t + 1 < nt) {
            sws[t + 1].i0 = sws[t].i1;
        }
    }
    assert(sws[nt - 1].i1 == s1.i1);

    // do the actual merging: each thread performs a classical 2-way merge
    // of its private sub-ranges, then copies whichever side has leftovers
#pragma omp parallel for num_threads(nt)
    for (int t = 0; t < nt; t++) {
        SegmentS sw = sws[t];
        SegmentS s1t = s1s[t];
        SegmentS s2t = s2s[t];
        if (s1t.i0 < s1t.i1 && s2t.i0 < s2t.i1) {
            for (;;) {
                // assert (sw.len() == s1t.len() + s2t.len());
                if (comp(src[s1t.i0], src[s2t.i0])) {
                    dst[sw.i0++] = src[s1t.i0++];
                    if (s1t.i0 == s1t.i1) {
                        break;
                    }
                } else {
                    dst[sw.i0++] = src[s2t.i0++];
                    if (s2t.i0 == s2t.i1) {
                        break;
                    }
                }
            }
        }
        // copy the tail of the non-exhausted side, if any
        if (s1t.len() > 0) {
            assert(s1t.len() == sw.len());
            memcpy(dst + sw.i0, src + s1t.i0, s1t.len() * sizeof(dst[0]));
        } else if (s2t.len() > 0) {
            assert(s2t.len() == sw.len());
            memcpy(dst + sw.i0, src + s2t.i0, s2t.len() * sizeof(dst[0]));
        }
    }
}
125
+
126
+ }; // namespace
127
+
128
/// Single-threaded argsort: fills perm with 0..n-1 so that
/// vals[perm[0]] <= vals[perm[1]] <= ... (ascending order).
void fvec_argsort(size_t n, const float* vals, size_t* perm) {
    // start from the identity permutation
    for (size_t idx = 0; idx < n; idx++) {
        perm[idx] = idx;
    }
    // sort the indices by the values they point to
    std::sort(perm, perm + n, [vals](size_t a, size_t b) {
        return vals[a] < vals[b];
    });
}
135
+
136
+ void fvec_argsort_parallel(size_t n, const float* vals, size_t* perm) {
137
+ size_t* perm2 = new size_t[n];
138
+ // 2 result tables, during merging, flip between them
139
+ size_t *permB = perm2, *permA = perm;
140
+
141
+ int nt = omp_get_max_threads();
142
+ { // prepare correct permutation so that the result ends in perm
143
+ // at final iteration
144
+ int nseg = nt;
145
+ while (nseg > 1) {
146
+ nseg = (nseg + 1) / 2;
147
+ std::swap(permA, permB);
148
+ }
149
+ }
150
+
151
+ #pragma omp parallel
152
+ for (size_t i = 0; i < n; i++) {
153
+ permA[i] = i;
154
+ }
155
+
156
+ ArgsortComparator comp = {vals};
157
+
158
+ std::vector<SegmentS> segs(nt);
159
+
160
+ // independent sorts
161
+ #pragma omp parallel for
162
+ for (int t = 0; t < nt; t++) {
163
+ size_t i0 = t * n / nt;
164
+ size_t i1 = (t + 1) * n / nt;
165
+ SegmentS seg = {i0, i1};
166
+ std::sort(permA + seg.i0, permA + seg.i1, comp);
167
+ segs[t] = seg;
168
+ }
169
+ int prev_nested = omp_get_nested();
170
+ omp_set_nested(1);
171
+
172
+ int nseg = nt;
173
+ while (nseg > 1) {
174
+ int nseg1 = (nseg + 1) / 2;
175
+ int sub_nt = nseg % 2 == 0 ? nt : nt - 1;
176
+ int sub_nseg1 = nseg / 2;
177
+
178
+ #pragma omp parallel for num_threads(nseg1)
179
+ for (int s = 0; s < nseg; s += 2) {
180
+ if (s + 1 == nseg) { // otherwise isolated segment
181
+ memcpy(permB + segs[s].i0,
182
+ permA + segs[s].i0,
183
+ segs[s].len() * sizeof(size_t));
184
+ } else {
185
+ int t0 = s * sub_nt / sub_nseg1;
186
+ int t1 = (s + 1) * sub_nt / sub_nseg1;
187
+ printf("merge %d %d, %d threads\n", s, s + 1, t1 - t0);
188
+ parallel_merge(
189
+ permA, permB, segs[s], segs[s + 1], t1 - t0, comp);
190
+ }
191
+ }
192
+ for (int s = 0; s < nseg; s += 2) {
193
+ segs[s / 2] = segs[s];
194
+ }
195
+ nseg = nseg1;
196
+ std::swap(permA, permB);
197
+ }
198
+ assert(permA == perm);
199
+ omp_set_nested(prev_nested);
200
+ delete[] perm2;
201
+ }
202
+
203
+ /*****************************************************************************
204
+ * Bucket sort
205
+ ****************************************************************************/
206
+
207
+ // extern symbol in the .h
208
+ int bucket_sort_verbose = 0;
209
+
210
+ namespace {
211
+
212
// Single-threaded counting ("bucket") sort reference implementation.
// On output, lims[b]..lims[b+1] delimits, inside perm, the indices i
// such that vals[i] == b; lims must have vmax+1 entries.
// The lims array is temporarily advanced while populating the buckets
// and shifted back at the end.
void bucket_sort_ref(
        size_t nval,
        const uint64_t* vals,
        uint64_t vmax,
        int64_t* lims,
        int64_t* perm) {
    double t0 = getmillisecs();
    // histogram, shifted by one so the cumsum lands in place
    memset(lims, 0, sizeof(*lims) * (vmax + 1));
    for (size_t i = 0; i < nval; i++) {
        FAISS_THROW_IF_NOT(vals[i] < vmax);
        lims[vals[i] + 1]++;
    }
    double t1 = getmillisecs();
    // compute cumulative sum
    for (size_t i = 0; i < vmax; i++) {
        lims[i + 1] += lims[i];
    }
    FAISS_THROW_IF_NOT(lims[vmax] == nval);
    double t2 = getmillisecs();
    // populate buckets; lims[b] is used as the write cursor of bucket b
    for (size_t i = 0; i < nval; i++) {
        perm[lims[vals[i]]++] = i;
    }
    double t3 = getmillisecs();
    // reset pointers: after the loop above, lims[b] == old lims[b+1],
    // so shifting right by one restores the bucket boundaries
    for (size_t i = vmax; i > 0; i--) {
        lims[i] = lims[i - 1];
    }
    lims[0] = 0;
    double t4 = getmillisecs();
    if (bucket_sort_verbose) {
        printf("times %.3f %.3f %.3f %.3f\n",
               t1 - t0,
               t2 - t1,
               t3 - t2,
               t4 - t3);
    }
}
250
+
251
// Parallel counting sort: same contract as bucket_sort_ref, with the
// histogram, cumulative sum and bucket population distributed over
// nt_in OpenMP threads. Each thread histograms its slice of vals
// locally, the global histogram/cumsum is built under critical/master
// sections, then each thread reserves disjoint write slots per bucket
// so the population phase needs no synchronization.
void bucket_sort_parallel(
        size_t nval,
        const uint64_t* vals,
        uint64_t vmax,
        int64_t* lims,
        int64_t* perm,
        int nt_in) {
    memset(lims, 0, sizeof(*lims) * (vmax + 1));
#pragma omp parallel num_threads(nt_in)
    {
        int nt = omp_get_num_threads(); // might be different from nt_in
        int rank = omp_get_thread_num();
        std::vector<int64_t> local_lims(vmax + 1);

        // range of indices handled by this thread
        size_t i0 = nval * rank / nt;
        size_t i1 = nval * (rank + 1) / nt;

        // build histogram in local lims
        double t0 = getmillisecs();
        for (size_t i = i0; i < i1; i++) {
            local_lims[vals[i]]++;
        }
#pragma omp critical
        { // accumulate histograms (not shifted indices to prepare cumsum)
            for (size_t i = 0; i < vmax; i++) {
                lims[i + 1] += local_lims[i];
            }
        }
#pragma omp barrier

        double t1 = getmillisecs();
#pragma omp master
        {
            // compute cumulative sum (single thread)
            for (size_t i = 0; i < vmax; i++) {
                lims[i + 1] += lims[i];
            }
            FAISS_THROW_IF_NOT(lims[vmax] == nval);
        }
#pragma omp barrier

#pragma omp critical
        { // current thread grabs a slot in the buckets: local_lims[b]
          // becomes this thread's private write cursor for bucket b
            for (size_t i = 0; i < vmax; i++) {
                size_t nv = local_lims[i];
                local_lims[i] = lims[i]; // where we should start writing
                lims[i] += nv;
            }
        }

        double t2 = getmillisecs();
#pragma omp barrier
        { // populate buckets, this is the slowest operation;
          // slots are disjoint so no locking is needed here
            for (size_t i = i0; i < i1; i++) {
                perm[local_lims[vals[i]]++] = i;
            }
        }
#pragma omp barrier
        double t3 = getmillisecs();

#pragma omp master
        { // shift back lims: the slot grabbing advanced lims[b] to the
          // old lims[b+1], so shifting right restores the boundaries
            for (size_t i = vmax; i > 0; i--) {
                lims[i] = lims[i - 1];
            }
            lims[0] = 0;
            double t4 = getmillisecs();
            if (bucket_sort_verbose) {
                printf("times %.3f %.3f %.3f %.3f\n",
                       t1 - t0,
                       t2 - t1,
                       t3 - t2,
                       t4 - t3);
            }
        }
    }
}
329
+
330
+ /***********************************************
331
+ * in-place bucket sort
332
+ */
333
+
334
// In-place bucket sort of an nrow x ncol matrix of bucket ids, reference
// (single-threaded) implementation. On output, vals is overwritten so
// that entries lims[b]..lims[b+1] contain the ROW numbers of the
// elements that had bucket id b (an element's row is idx / ncol).
// Works by following the cycles of the implied permutation, carrying the
// displaced row number along each cycle; -1 marks the start of a cycle.
// NOTE(review): assumes TI is a signed integer type (uses -1 sentinels).
template <class TI>
void bucket_sort_inplace_ref(
        size_t nrow,
        size_t ncol,
        TI* vals,
        TI nbucket,
        int64_t* lims) {
    double t0 = getmillisecs();
    size_t nval = nrow * ncol;
    FAISS_THROW_IF_NOT(
            nbucket < nval); // unclear what would happen in this case...

    // histogram, shifted by one for the in-place cumsum
    memset(lims, 0, sizeof(*lims) * (nbucket + 1));
    for (size_t i = 0; i < nval; i++) {
        FAISS_THROW_IF_NOT(vals[i] < nbucket);
        lims[vals[i] + 1]++;
    }
    double t1 = getmillisecs();
    // compute cumulative sum
    for (size_t i = 0; i < nbucket; i++) {
        lims[i + 1] += lims[i];
    }
    FAISS_THROW_IF_NOT(lims[nbucket] == nval);
    double t2 = getmillisecs();

    // per-bucket write cursors (lims itself must stay intact)
    std::vector<size_t> ptrs(nbucket);
    for (size_t i = 0; i < nbucket; i++) {
        ptrs[i] = lims[i];
    }

    // find loops in the permutation and follow them:
    // row == -1 means we are (re)starting a cycle, so the slot is read
    // but not yet consumed (cursor not advanced)
    TI row = -1;
    TI init_bucket_no = 0, bucket_no = 0;
    for (;;) {
        size_t idx = ptrs[bucket_no];
        if (row >= 0) {
            ptrs[bucket_no] += 1;
        }
        assert(idx < lims[bucket_no + 1]);
        TI next_bucket_no = vals[idx]; // bucket of the displaced element
        vals[idx] = row;               // store the carried row number
        if (next_bucket_no != -1) {
            // continue the cycle with the displaced element
            row = idx / ncol;
            bucket_no = next_bucket_no;
        } else {
            // start new loop: scan for the first unfinished bucket
            for (; init_bucket_no < nbucket; init_bucket_no++) {
                if (ptrs[init_bucket_no] < lims[init_bucket_no + 1]) {
                    break;
                }
            }
            if (init_bucket_no == nbucket) { // we're done
                break;
            }
            bucket_no = init_bucket_no;
            row = -1;
        }
    }

    // every bucket must be completely filled
    for (size_t i = 0; i < nbucket; i++) {
        assert(ptrs[i] == lims[i + 1]);
    }
    double t3 = getmillisecs();
    if (bucket_sort_verbose) {
        printf("times %.3f %.3f %.3f\n", t1 - t0, t2 - t1, t3 - t2);
    }
}
401
+
402
+ // collects row numbers to write into buckets
403
// collects row numbers to write into buckets (one instance per thread in
// bucket_sort_inplace_parallel). After bucket_sort() is called, rows is
// grouped by destination bucket and lims[b]..lims[b+1] delimits the rows
// destined to bucket b; buckets is cleared.
template <class TI>
struct ToWrite {
    TI nbucket;              // number of buckets
    std::vector<TI> buckets; // destination bucket of each pending row
    std::vector<TI> rows;    // pending row numbers
    std::vector<size_t> lims; // bucket boundaries, valid after bucket_sort()

    explicit ToWrite(TI nbucket) : nbucket(nbucket) {
        lims.resize(nbucket + 1);
    }

    /// add one element (row) to write in bucket b
    void add(TI row, TI b) {
        assert(b >= 0 && b < nbucket);
        rows.push_back(row);
        buckets.push_back(b);
    }

    // group the pending rows by bucket (counting sort) and rebuild lims;
    // clears buckets when done
    void bucket_sort() {
        FAISS_THROW_IF_NOT(buckets.size() == rows.size());
        lims.resize(nbucket + 1);
        memset(lims.data(), 0, sizeof(lims[0]) * (nbucket + 1));

        // histogram, shifted by one for the cumsum
        for (size_t i = 0; i < buckets.size(); i++) {
            assert(buckets[i] >= 0 && buckets[i] < nbucket);
            lims[buckets[i] + 1]++;
        }
        // compute cumulative sum
        for (size_t i = 0; i < nbucket; i++) {
            lims[i + 1] += lims[i];
        }
        FAISS_THROW_IF_NOT(lims[nbucket] == buckets.size());

        // could also do a circular perm...
        std::vector<TI> new_rows(rows.size());
        std::vector<size_t> ptrs = lims; // per-bucket write cursors
        for (size_t i = 0; i < buckets.size(); i++) {
            TI b = buckets[i];
            assert(ptrs[b] < lims[b + 1]);
            new_rows[ptrs[b]++] = rows[i];
        }
        buckets.resize(0);
        std::swap(rows, new_rows);
    }

    // exchange contents with another ToWrite over the same bucket count
    void swap(ToWrite& other) {
        assert(nbucket == other.nbucket);
        buckets.swap(other.buckets);
        rows.swap(other.rows);
        lims.swap(other.lims);
    }
};
455
+
456
// Parallel in-place bucket sort: same contract as bucket_sort_inplace_ref
// but distributed over nt_in OpenMP threads. Instead of following one
// permutation cycle at a time, each thread repeatedly writes a batch of
// pending rows into its assigned range of buckets and collects the
// elements it overwrites as the next batch (per-thread ToWrite tables),
// iterating rounds until nothing is left to write. -1 seed values play
// the role of the cycle-start sentinels of the reference version.
// NOTE(review): assumes TI is a signed integer type (uses -1 sentinels).
template <class TI>
void bucket_sort_inplace_parallel(
        size_t nrow,
        size_t ncol,
        TI* vals,
        TI nbucket,
        int64_t* lims,
        int nt_in) {
    int verbose = bucket_sort_verbose;
    memset(lims, 0, sizeof(*lims) * (nbucket + 1));
    std::vector<ToWrite<TI>> all_to_write;
    size_t nval = nrow * ncol;
    FAISS_THROW_IF_NOT(
            nbucket < nval); // unclear what would happen in this case...

    // try to keep size of all_to_write < 5GiB
    // but we need at least one element per bucket
    size_t init_to_write = std::max(
            size_t(nbucket),
            std::min(nval / 10, ((size_t)5 << 30) / (sizeof(TI) * 3 * nt_in)));
    if (verbose > 0) {
        printf("init_to_write=%zd\n", init_to_write);
    }

    std::vector<size_t> ptrs(nbucket); // ptrs is shared across all threads
    std::vector<char> did_wrap(
            nbucket); // DON'T use std::vector<bool> that cannot be accessed
                      // safely from multiple threads!!!

#pragma omp parallel num_threads(nt_in)
    {
        int nt = omp_get_num_threads(); // might be different from nt_in (?)
        int rank = omp_get_thread_num();
        std::vector<int64_t> local_lims(nbucket + 1);

        // range of indices handled by this thread
        size_t i0 = nval * rank / nt;
        size_t i1 = nval * (rank + 1) / nt;

        // build histogram in local lims
        for (size_t i = i0; i < i1; i++) {
            local_lims[vals[i]]++;
        }
#pragma omp critical
        { // accumulate histograms (not shifted indices to prepare cumsum)
            for (size_t i = 0; i < nbucket; i++) {
                lims[i + 1] += local_lims[i];
            }
            all_to_write.push_back(ToWrite<TI>(nbucket));
        }

#pragma omp barrier
        // this thread's things to write
        ToWrite<TI>& to_write = all_to_write[rank];

#pragma omp master
        {
            // compute cumulative sum
            for (size_t i = 0; i < nbucket; i++) {
                lims[i + 1] += lims[i];
            }
            FAISS_THROW_IF_NOT(lims[nbucket] == nval);
            // at this point lims is final (read only!)

            memcpy(ptrs.data(), lims, sizeof(lims[0]) * nbucket);

            // initial values to write (we write -1s to get the process running)
            // make sure at least one element per bucket
            size_t written = 0;
            for (TI b = 0; b < nbucket; b++) {
                size_t l0 = lims[b], l1 = lims[b + 1];
                size_t target_to_write = l1 * init_to_write / nval;
                do {
                    if (l0 == l1) {
                        break;
                    }
                    to_write.add(-1, b);
                    l0++;
                    written++;
                } while (written < target_to_write);
            }

            to_write.bucket_sort();
        }

        // this thread writes only buckets b0:b1 (static bucket partition,
        // so concurrent writes to ptrs/did_wrap/vals never collide)
        size_t b0 = (rank * nbucket + nt - 1) / nt;
        size_t b1 = ((rank + 1) * nbucket + nt - 1) / nt;

        // in this loop, we write elements collected in the previous round
        // and collect the elements that are overwritten for the next round
        size_t tot_written = 0;
        int round = 0;
        for (;;) {
#pragma omp barrier

            // total pending writes across all threads this round
            size_t n_to_write = 0;
            for (const ToWrite<TI>& to_write_2 : all_to_write) {
                n_to_write += to_write_2.lims.back();
            }

            tot_written += n_to_write;
            // assert(tot_written <= nval);

#pragma omp master
            {
                if (verbose >= 1) {
                    printf("ROUND %d n_to_write=%zd\n", round, n_to_write);
                }
                if (verbose > 2) {
                    // dump full state: bucket contents and pending writes
                    for (size_t b = 0; b < nbucket; b++) {
                        printf("  b=%zd [", b);
                        for (size_t i = lims[b]; i < lims[b + 1]; i++) {
                            printf(" %s%d",
                                   ptrs[b] == i ? ">" : "",
                                   int(vals[i]));
                        }
                        printf(" %s] %s\n",
                               ptrs[b] == lims[b + 1] ? ">" : "",
                               did_wrap[b] ? "w" : "");
                    }
                    printf("To write\n");
                    for (size_t b = 0; b < nbucket; b++) {
                        printf("  b=%zd ", b);
                        const char* sep = "[";
                        for (const ToWrite<TI>& to_write_2 : all_to_write) {
                            printf("%s", sep);
                            sep = " |";
                            size_t l0 = to_write_2.lims[b];
                            size_t l1 = to_write_2.lims[b + 1];
                            for (size_t i = l0; i < l1; i++) {
                                printf(" %d", int(to_write_2.rows[i]));
                            }
                        }
                        printf(" ]\n");
                    }
                }
            }
            if (n_to_write == 0) {
                break; // all threads see the same count -> all exit together
            }
            round++;

#pragma omp barrier

            // rows collected this round, to be written next round
            ToWrite<TI> next_to_write(nbucket);

            for (size_t b = b0; b < b1; b++) {
                // drain every thread's pending rows destined to bucket b
                for (const ToWrite<TI>& to_write_2 : all_to_write) {
                    size_t l0 = to_write_2.lims[b];
                    size_t l1 = to_write_2.lims[b + 1];
                    for (size_t i = l0; i < l1; i++) {
                        TI row = to_write_2.rows[i];
                        size_t idx = ptrs[b];
                        if (verbose > 2) {
                            printf("  bucket %d (rank %d) idx %zd\n",
                                   int(row),
                                   rank,
                                   idx);
                        }
                        if (idx < lims[b + 1]) {
                            ptrs[b]++;
                        } else {
                            // wrapping around: reuse the slot freed by
                            // the -1 seed at the start of the bucket
                            assert(!did_wrap[b]);
                            did_wrap[b] = true;
                            idx = lims[b];
                            ptrs[b] = idx + 1;
                        }

                        // check if we need to remember the overwritten number
                        if (vals[idx] >= 0) {
                            TI new_row = idx / ncol;
                            next_to_write.add(new_row, vals[idx]);
                            if (verbose > 2) {
                                printf("    new_row=%d\n", int(new_row));
                            }
                        } else {
                            // -1 seed slot: nothing to carry over
                            assert(did_wrap[b]);
                        }

                        vals[idx] = row;
                    }
                }
            }
            next_to_write.bucket_sort();
#pragma omp barrier
            // publish this thread's collected rows for the next round
            all_to_write[rank].swap(next_to_write);
        }
    }
}
647
+
648
+ } // anonymous namespace
649
+
650
+ void bucket_sort(
651
+ size_t nval,
652
+ const uint64_t* vals,
653
+ uint64_t vmax,
654
+ int64_t* lims,
655
+ int64_t* perm,
656
+ int nt) {
657
+ if (nt == 0) {
658
+ bucket_sort_ref(nval, vals, vmax, lims, perm);
659
+ } else {
660
+ bucket_sort_parallel(nval, vals, vmax, lims, perm, nt);
661
+ }
662
+ }
663
+
664
+ void matrix_bucket_sort_inplace(
665
+ size_t nrow,
666
+ size_t ncol,
667
+ int32_t* vals,
668
+ int32_t vmax,
669
+ int64_t* lims,
670
+ int nt) {
671
+ if (nt == 0) {
672
+ bucket_sort_inplace_ref(nrow, ncol, vals, vmax, lims);
673
+ } else {
674
+ bucket_sort_inplace_parallel(nrow, ncol, vals, vmax, lims, nt);
675
+ }
676
+ }
677
+
678
+ void matrix_bucket_sort_inplace(
679
+ size_t nrow,
680
+ size_t ncol,
681
+ int64_t* vals,
682
+ int64_t vmax,
683
+ int64_t* lims,
684
+ int nt) {
685
+ if (nt == 0) {
686
+ bucket_sort_inplace_ref(nrow, ncol, vals, vmax, lims);
687
+ } else {
688
+ bucket_sort_inplace_parallel(nrow, ncol, vals, vmax, lims, nt);
689
+ }
690
+ }
691
+
692
+ } // namespace faiss