faiss 0.2.5 → 0.2.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/LICENSE.txt +1 -1
- data/ext/faiss/extconf.rb +1 -1
- data/ext/faiss/index.cpp +13 -0
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +2 -2
- data/vendor/faiss/faiss/AutoTune.cpp +15 -4
- data/vendor/faiss/faiss/AutoTune.h +0 -1
- data/vendor/faiss/faiss/Clustering.cpp +1 -5
- data/vendor/faiss/faiss/Clustering.h +0 -2
- data/vendor/faiss/faiss/IVFlib.h +0 -2
- data/vendor/faiss/faiss/Index.h +1 -2
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
- data/vendor/faiss/faiss/IndexBinary.h +0 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
- data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
- data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
- data/vendor/faiss/faiss/IndexFastScan.h +5 -1
- data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
- data/vendor/faiss/faiss/IndexFlat.h +1 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
- data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
- data/vendor/faiss/faiss/IndexHNSW.h +0 -1
- data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
- data/vendor/faiss/faiss/IndexIDMap.h +0 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
- data/vendor/faiss/faiss/IndexIVF.h +121 -61
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
- data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
- data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
- data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
- data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
- data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
- data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
- data/vendor/faiss/faiss/IndexReplicas.h +0 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
- data/vendor/faiss/faiss/IndexShards.cpp +26 -109
- data/vendor/faiss/faiss/IndexShards.h +2 -3
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
- data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
- data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
- data/vendor/faiss/faiss/MetaIndexes.h +29 -0
- data/vendor/faiss/faiss/MetricType.h +14 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
- data/vendor/faiss/faiss/VectorTransform.h +1 -3
- data/vendor/faiss/faiss/clone_index.cpp +232 -18
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
- data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
- data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
- data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
- data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
- data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
- data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
- data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
- data/vendor/faiss/faiss/impl/HNSW.h +6 -9
- data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
- data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
- data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
- data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
- data/vendor/faiss/faiss/impl/NSG.h +4 -7
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
- data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
- data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
- data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
- data/vendor/faiss/faiss/index_factory.cpp +8 -10
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
- data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
- data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
- data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
- data/vendor/faiss/faiss/utils/Heap.h +35 -1
- data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
- data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
- data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
- data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
- data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
- data/vendor/faiss/faiss/utils/distances.cpp +61 -7
- data/vendor/faiss/faiss/utils/distances.h +11 -0
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
- data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
- data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
- data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
- data/vendor/faiss/faiss/utils/fp16.h +7 -0
- data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
- data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
- data/vendor/faiss/faiss/utils/hamming.h +21 -10
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
- data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
- data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
- data/vendor/faiss/faiss/utils/sorting.h +71 -0
- data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
- data/vendor/faiss/faiss/utils/utils.cpp +4 -176
- data/vendor/faiss/faiss/utils/utils.h +2 -9
- metadata +30 -4
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
@@ -0,0 +1,692 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
// -*- c++ -*-
|
9
|
+
|
10
|
+
#include <faiss/utils/sorting.h>
|
11
|
+
|
12
|
+
#include <omp.h>
|
13
|
+
#include <algorithm>
|
14
|
+
|
15
|
+
#include <faiss/impl/FaissAssert.h>
|
16
|
+
#include <faiss/utils/utils.h>
|
17
|
+
|
18
|
+
namespace faiss {
|
19
|
+
|
20
|
+
/*****************************************************************************
|
21
|
+
* Argsort
|
22
|
+
****************************************************************************/
|
23
|
+
|
24
|
+
namespace {

/// Orders indices into `vals` by increasing value: a < b iff vals[a] < vals[b].
struct ArgsortComparator {
    const float* vals;
    bool operator()(const size_t a, const size_t b) const {
        return vals[a] < vals[b];
    }
};

/// Half-open range [i0, i1) of positions in a permutation array.
struct SegmentS {
    size_t i0; // begin pointer in the permutation array
    size_t i1; // end
    size_t len() const {
        return i1 - i0;
    }
};

// see https://en.wikipedia.org/wiki/Merge_algorithm#Parallel_merge
// extended to > 1 merge thread
//
// Merges 2 sorted ranges s1 and s2 that should be consecutive on the source
// into the union of the two on the destination. On return, s1 and s2 are
// both set to that union range.
//
// The larger range is cut into nt equal pieces. For each cut point (the
// pivot), the smaller range is split with a binary search. Each thread then
// performs an independent sub-merge into a disjoint part of dst.
// nt < 1 is treated as 1.
template <typename T>
void parallel_merge(
        const T* src,
        T* dst,
        SegmentS& s1,
        SegmentS& s2,
        int nt,
        const ArgsortComparator& comp) {
    if (nt < 1) {
        nt = 1; // the per-thread tables below need at least one slot
    }
    if (s2.len() > s1.len()) { // make sure that s1 larger than s2
        std::swap(s1, s2);
    }

    // compute sub-ranges for each thread
    std::vector<SegmentS> s1s(nt), s2s(nt), sws(nt);
    s2s[0].i0 = s2.i0;
    s2s[nt - 1].i1 = s2.i1;

    // not sure parallel actually helps here
#pragma omp parallel for num_threads(nt)
    for (int t = 0; t < nt; t++) {
        s1s[t].i0 = s1.i0 + s1.len() * t / nt;
        s1s[t].i1 = s1.i0 + s1.len() * (t + 1) / nt;

        if (t + 1 < nt) {
            // s2 is split at its first element that compares greater than
            // the pivot (the first element of s1s[t + 1]). Everything thread
            // t outputs is then <= pivot, and everything thread t + 1
            // outputs is >= pivot. std::upper_bound considers every element
            // of s2, including the first one.
            // When s1 is empty there is no pivot to read (s2 is empty too).
            size_t split = s2.i1;
            if (s1s[t].i1 < s1.i1) {
                const T pivot = src[s1s[t].i1];
                split = size_t(
                        std::upper_bound(
                                src + s2.i0, src + s2.i1, pivot, comp) -
                        src);
            }
            s2s[t].i1 = s2s[t + 1].i0 = split;
        }
    }
    s1.i0 = std::min(s1.i0, s2.i0);
    s1.i1 = std::max(s1.i1, s2.i1);
    s2 = s1;
    // output ranges: thread t writes s1s[t].len() + s2s[t].len() elements
    sws[0].i0 = s1.i0;
    for (int t = 0; t < nt; t++) {
        sws[t].i1 = sws[t].i0 + s1s[t].len() + s2s[t].len();
        if (t + 1 < nt) {
            sws[t + 1].i0 = sws[t].i1;
        }
    }
    assert(sws[nt - 1].i1 == s1.i1);

    // do the actual merging
#pragma omp parallel for num_threads(nt)
    for (int t = 0; t < nt; t++) {
        SegmentS sw = sws[t];
        SegmentS s1t = s1s[t];
        SegmentS s2t = s2s[t];
        if (s1t.i0 < s1t.i1 && s2t.i0 < s2t.i1) {
            for (;;) {
                // assert (sw.len() == s1t.len() + s2t.len());
                if (comp(src[s1t.i0], src[s2t.i0])) {
                    dst[sw.i0++] = src[s1t.i0++];
                    if (s1t.i0 == s1t.i1) {
                        break;
                    }
                } else {
                    dst[sw.i0++] = src[s2t.i0++];
                    if (s2t.i0 == s2t.i1) {
                        break;
                    }
                }
            }
        }
        // copy the remaining tail of whichever input is not exhausted
        if (s1t.len() > 0) {
            assert(s1t.len() == sw.len());
            memcpy(dst + sw.i0, src + s1t.i0, s1t.len() * sizeof(dst[0]));
        } else if (s2t.len() > 0) {
            assert(s2t.len() == sw.len());
            memcpy(dst + sw.i0, src + s2t.i0, s2t.len() * sizeof(dst[0]));
        }
    }
}

} // namespace
|
127
|
+
|
128
|
+
// Writes into perm[0..n) the permutation that orders vals[0..n) by
// increasing value (std::sort, so ties are not kept in index order).
void fvec_argsort(size_t n, const float* vals, size_t* perm) {
    // start from the identity permutation
    size_t* const perm_end = perm + n;
    size_t next_index = 0;
    for (size_t* slot = perm; slot != perm_end; ++slot) {
        *slot = next_index++;
    }
    // then reorder it by the values it points to
    std::sort(perm, perm_end, ArgsortComparator{vals});
}
|
135
|
+
|
136
|
+
void fvec_argsort_parallel(size_t n, const float* vals, size_t* perm) {
|
137
|
+
size_t* perm2 = new size_t[n];
|
138
|
+
// 2 result tables, during merging, flip between them
|
139
|
+
size_t *permB = perm2, *permA = perm;
|
140
|
+
|
141
|
+
int nt = omp_get_max_threads();
|
142
|
+
{ // prepare correct permutation so that the result ends in perm
|
143
|
+
// at final iteration
|
144
|
+
int nseg = nt;
|
145
|
+
while (nseg > 1) {
|
146
|
+
nseg = (nseg + 1) / 2;
|
147
|
+
std::swap(permA, permB);
|
148
|
+
}
|
149
|
+
}
|
150
|
+
|
151
|
+
#pragma omp parallel
|
152
|
+
for (size_t i = 0; i < n; i++) {
|
153
|
+
permA[i] = i;
|
154
|
+
}
|
155
|
+
|
156
|
+
ArgsortComparator comp = {vals};
|
157
|
+
|
158
|
+
std::vector<SegmentS> segs(nt);
|
159
|
+
|
160
|
+
// independent sorts
|
161
|
+
#pragma omp parallel for
|
162
|
+
for (int t = 0; t < nt; t++) {
|
163
|
+
size_t i0 = t * n / nt;
|
164
|
+
size_t i1 = (t + 1) * n / nt;
|
165
|
+
SegmentS seg = {i0, i1};
|
166
|
+
std::sort(permA + seg.i0, permA + seg.i1, comp);
|
167
|
+
segs[t] = seg;
|
168
|
+
}
|
169
|
+
int prev_nested = omp_get_nested();
|
170
|
+
omp_set_nested(1);
|
171
|
+
|
172
|
+
int nseg = nt;
|
173
|
+
while (nseg > 1) {
|
174
|
+
int nseg1 = (nseg + 1) / 2;
|
175
|
+
int sub_nt = nseg % 2 == 0 ? nt : nt - 1;
|
176
|
+
int sub_nseg1 = nseg / 2;
|
177
|
+
|
178
|
+
#pragma omp parallel for num_threads(nseg1)
|
179
|
+
for (int s = 0; s < nseg; s += 2) {
|
180
|
+
if (s + 1 == nseg) { // otherwise isolated segment
|
181
|
+
memcpy(permB + segs[s].i0,
|
182
|
+
permA + segs[s].i0,
|
183
|
+
segs[s].len() * sizeof(size_t));
|
184
|
+
} else {
|
185
|
+
int t0 = s * sub_nt / sub_nseg1;
|
186
|
+
int t1 = (s + 1) * sub_nt / sub_nseg1;
|
187
|
+
printf("merge %d %d, %d threads\n", s, s + 1, t1 - t0);
|
188
|
+
parallel_merge(
|
189
|
+
permA, permB, segs[s], segs[s + 1], t1 - t0, comp);
|
190
|
+
}
|
191
|
+
}
|
192
|
+
for (int s = 0; s < nseg; s += 2) {
|
193
|
+
segs[s / 2] = segs[s];
|
194
|
+
}
|
195
|
+
nseg = nseg1;
|
196
|
+
std::swap(permA, permB);
|
197
|
+
}
|
198
|
+
assert(permA == perm);
|
199
|
+
omp_set_nested(prev_nested);
|
200
|
+
delete[] perm2;
|
201
|
+
}
|
202
|
+
|
203
|
+
/*****************************************************************************
|
204
|
+
* Bucket sort
|
205
|
+
****************************************************************************/
|
206
|
+
|
207
|
+
// Verbosity of the bucket sort routines in this file (0 = no output); the
// matching extern declaration lives in the header.
int bucket_sort_verbose{0};
|
209
|
+
|
210
|
+
namespace {
|
211
|
+
|
212
|
+
void bucket_sort_ref(
|
213
|
+
size_t nval,
|
214
|
+
const uint64_t* vals,
|
215
|
+
uint64_t vmax,
|
216
|
+
int64_t* lims,
|
217
|
+
int64_t* perm) {
|
218
|
+
double t0 = getmillisecs();
|
219
|
+
memset(lims, 0, sizeof(*lims) * (vmax + 1));
|
220
|
+
for (size_t i = 0; i < nval; i++) {
|
221
|
+
FAISS_THROW_IF_NOT(vals[i] < vmax);
|
222
|
+
lims[vals[i] + 1]++;
|
223
|
+
}
|
224
|
+
double t1 = getmillisecs();
|
225
|
+
// compute cumulative sum
|
226
|
+
for (size_t i = 0; i < vmax; i++) {
|
227
|
+
lims[i + 1] += lims[i];
|
228
|
+
}
|
229
|
+
FAISS_THROW_IF_NOT(lims[vmax] == nval);
|
230
|
+
double t2 = getmillisecs();
|
231
|
+
// populate buckets
|
232
|
+
for (size_t i = 0; i < nval; i++) {
|
233
|
+
perm[lims[vals[i]]++] = i;
|
234
|
+
}
|
235
|
+
double t3 = getmillisecs();
|
236
|
+
// reset pointers
|
237
|
+
for (size_t i = vmax; i > 0; i--) {
|
238
|
+
lims[i] = lims[i - 1];
|
239
|
+
}
|
240
|
+
lims[0] = 0;
|
241
|
+
double t4 = getmillisecs();
|
242
|
+
if (bucket_sort_verbose) {
|
243
|
+
printf("times %.3f %.3f %.3f %.3f\n",
|
244
|
+
t1 - t0,
|
245
|
+
t2 - t1,
|
246
|
+
t3 - t2,
|
247
|
+
t4 - t3);
|
248
|
+
}
|
249
|
+
}
|
250
|
+
|
251
|
+
void bucket_sort_parallel(
|
252
|
+
size_t nval,
|
253
|
+
const uint64_t* vals,
|
254
|
+
uint64_t vmax,
|
255
|
+
int64_t* lims,
|
256
|
+
int64_t* perm,
|
257
|
+
int nt_in) {
|
258
|
+
memset(lims, 0, sizeof(*lims) * (vmax + 1));
|
259
|
+
#pragma omp parallel num_threads(nt_in)
|
260
|
+
{
|
261
|
+
int nt = omp_get_num_threads(); // might be different from nt_in
|
262
|
+
int rank = omp_get_thread_num();
|
263
|
+
std::vector<int64_t> local_lims(vmax + 1);
|
264
|
+
|
265
|
+
// range of indices handled by this thread
|
266
|
+
size_t i0 = nval * rank / nt;
|
267
|
+
size_t i1 = nval * (rank + 1) / nt;
|
268
|
+
|
269
|
+
// build histogram in local lims
|
270
|
+
double t0 = getmillisecs();
|
271
|
+
for (size_t i = i0; i < i1; i++) {
|
272
|
+
local_lims[vals[i]]++;
|
273
|
+
}
|
274
|
+
#pragma omp critical
|
275
|
+
{ // accumulate histograms (not shifted indices to prepare cumsum)
|
276
|
+
for (size_t i = 0; i < vmax; i++) {
|
277
|
+
lims[i + 1] += local_lims[i];
|
278
|
+
}
|
279
|
+
}
|
280
|
+
#pragma omp barrier
|
281
|
+
|
282
|
+
double t1 = getmillisecs();
|
283
|
+
#pragma omp master
|
284
|
+
{
|
285
|
+
// compute cumulative sum
|
286
|
+
for (size_t i = 0; i < vmax; i++) {
|
287
|
+
lims[i + 1] += lims[i];
|
288
|
+
}
|
289
|
+
FAISS_THROW_IF_NOT(lims[vmax] == nval);
|
290
|
+
}
|
291
|
+
#pragma omp barrier
|
292
|
+
|
293
|
+
#pragma omp critical
|
294
|
+
{ // current thread grabs a slot in the buckets
|
295
|
+
for (size_t i = 0; i < vmax; i++) {
|
296
|
+
size_t nv = local_lims[i];
|
297
|
+
local_lims[i] = lims[i]; // where we should start writing
|
298
|
+
lims[i] += nv;
|
299
|
+
}
|
300
|
+
}
|
301
|
+
|
302
|
+
double t2 = getmillisecs();
|
303
|
+
#pragma omp barrier
|
304
|
+
{ // populate buckets, this is the slowest operation
|
305
|
+
for (size_t i = i0; i < i1; i++) {
|
306
|
+
perm[local_lims[vals[i]]++] = i;
|
307
|
+
}
|
308
|
+
}
|
309
|
+
#pragma omp barrier
|
310
|
+
double t3 = getmillisecs();
|
311
|
+
|
312
|
+
#pragma omp master
|
313
|
+
{ // shift back lims
|
314
|
+
for (size_t i = vmax; i > 0; i--) {
|
315
|
+
lims[i] = lims[i - 1];
|
316
|
+
}
|
317
|
+
lims[0] = 0;
|
318
|
+
double t4 = getmillisecs();
|
319
|
+
if (bucket_sort_verbose) {
|
320
|
+
printf("times %.3f %.3f %.3f %.3f\n",
|
321
|
+
t1 - t0,
|
322
|
+
t2 - t1,
|
323
|
+
t3 - t2,
|
324
|
+
t4 - t3);
|
325
|
+
}
|
326
|
+
}
|
327
|
+
}
|
328
|
+
}
|
329
|
+
|
330
|
+
/***********************************************
|
331
|
+
* in-place bucket sort
|
332
|
+
*/
|
333
|
+
|
334
|
+
template <class TI>
|
335
|
+
void bucket_sort_inplace_ref(
|
336
|
+
size_t nrow,
|
337
|
+
size_t ncol,
|
338
|
+
TI* vals,
|
339
|
+
TI nbucket,
|
340
|
+
int64_t* lims) {
|
341
|
+
double t0 = getmillisecs();
|
342
|
+
size_t nval = nrow * ncol;
|
343
|
+
FAISS_THROW_IF_NOT(
|
344
|
+
nbucket < nval); // unclear what would happen in this case...
|
345
|
+
|
346
|
+
memset(lims, 0, sizeof(*lims) * (nbucket + 1));
|
347
|
+
for (size_t i = 0; i < nval; i++) {
|
348
|
+
FAISS_THROW_IF_NOT(vals[i] < nbucket);
|
349
|
+
lims[vals[i] + 1]++;
|
350
|
+
}
|
351
|
+
double t1 = getmillisecs();
|
352
|
+
// compute cumulative sum
|
353
|
+
for (size_t i = 0; i < nbucket; i++) {
|
354
|
+
lims[i + 1] += lims[i];
|
355
|
+
}
|
356
|
+
FAISS_THROW_IF_NOT(lims[nbucket] == nval);
|
357
|
+
double t2 = getmillisecs();
|
358
|
+
|
359
|
+
std::vector<size_t> ptrs(nbucket);
|
360
|
+
for (size_t i = 0; i < nbucket; i++) {
|
361
|
+
ptrs[i] = lims[i];
|
362
|
+
}
|
363
|
+
|
364
|
+
// find loops in the permutation and follow them
|
365
|
+
TI row = -1;
|
366
|
+
TI init_bucket_no = 0, bucket_no = 0;
|
367
|
+
for (;;) {
|
368
|
+
size_t idx = ptrs[bucket_no];
|
369
|
+
if (row >= 0) {
|
370
|
+
ptrs[bucket_no] += 1;
|
371
|
+
}
|
372
|
+
assert(idx < lims[bucket_no + 1]);
|
373
|
+
TI next_bucket_no = vals[idx];
|
374
|
+
vals[idx] = row;
|
375
|
+
if (next_bucket_no != -1) {
|
376
|
+
row = idx / ncol;
|
377
|
+
bucket_no = next_bucket_no;
|
378
|
+
} else {
|
379
|
+
// start new loop
|
380
|
+
for (; init_bucket_no < nbucket; init_bucket_no++) {
|
381
|
+
if (ptrs[init_bucket_no] < lims[init_bucket_no + 1]) {
|
382
|
+
break;
|
383
|
+
}
|
384
|
+
}
|
385
|
+
if (init_bucket_no == nbucket) { // we're done
|
386
|
+
break;
|
387
|
+
}
|
388
|
+
bucket_no = init_bucket_no;
|
389
|
+
row = -1;
|
390
|
+
}
|
391
|
+
}
|
392
|
+
|
393
|
+
for (size_t i = 0; i < nbucket; i++) {
|
394
|
+
assert(ptrs[i] == lims[i + 1]);
|
395
|
+
}
|
396
|
+
double t3 = getmillisecs();
|
397
|
+
if (bucket_sort_verbose) {
|
398
|
+
printf("times %.3f %.3f %.3f\n", t1 - t0, t2 - t1, t3 - t2);
|
399
|
+
}
|
400
|
+
}
|
401
|
+
|
402
|
+
// Buffer of (row, bucket) pairs that one thread still has to write into the
// matrix being sorted in place. bucket_sort() groups the pending rows by
// bucket: afterwards the rows for bucket b are rows[lims[b]] ..
// rows[lims[b + 1] - 1], and `buckets` is emptied.
template <class TI>
struct ToWrite {
    TI nbucket;
    std::vector<TI> buckets;  // destination bucket of each pending row
    std::vector<TI> rows;     // pending row numbers
    std::vector<size_t> lims; // nbucket + 1 offsets into rows

    explicit ToWrite(TI nbucket) : nbucket(nbucket), lims(nbucket + 1) {}

    /// add one element (row) to write in bucket b
    void add(TI row, TI b) {
        assert(b >= 0 && b < nbucket);
        rows.push_back(row);
        buckets.push_back(b);
    }

    /// counting sort of the pending rows by destination bucket
    void bucket_sort() {
        FAISS_THROW_IF_NOT(buckets.size() == rows.size());
        lims.assign(size_t(nbucket) + 1, 0);

        // histogram, shifted by one slot to prepare the prefix sum
        for (size_t i = 0; i < buckets.size(); i++) {
            assert(buckets[i] >= 0 && buckets[i] < nbucket);
            lims[buckets[i] + 1]++;
        }
        // prefix sum: lims[b] becomes the first slot of bucket b
        for (TI b = 1; b <= nbucket; b++) {
            lims[b] += lims[b - 1];
        }
        FAISS_THROW_IF_NOT(lims[nbucket] == buckets.size());

        // stable scatter of the rows into their bucket's range
        std::vector<size_t> ptrs(lims);
        std::vector<TI> grouped(rows.size());
        size_t i = 0;
        for (const TI b : buckets) {
            assert(ptrs[b] < lims[b + 1]);
            grouped[ptrs[b]++] = rows[i++];
        }
        buckets.clear();
        rows.swap(grouped);
    }

    /// exchange contents with another buffer of the same bucket count
    void swap(ToWrite& other) {
        assert(nbucket == other.nbucket);
        std::swap(buckets, other.buckets);
        std::swap(rows, other.rows);
        std::swap(lims, other.lims);
    }
};
|
455
|
+
|
456
|
+
template <class TI>
|
457
|
+
void bucket_sort_inplace_parallel(
|
458
|
+
size_t nrow,
|
459
|
+
size_t ncol,
|
460
|
+
TI* vals,
|
461
|
+
TI nbucket,
|
462
|
+
int64_t* lims,
|
463
|
+
int nt_in) {
|
464
|
+
int verbose = bucket_sort_verbose;
|
465
|
+
memset(lims, 0, sizeof(*lims) * (nbucket + 1));
|
466
|
+
std::vector<ToWrite<TI>> all_to_write;
|
467
|
+
size_t nval = nrow * ncol;
|
468
|
+
FAISS_THROW_IF_NOT(
|
469
|
+
nbucket < nval); // unclear what would happen in this case...
|
470
|
+
|
471
|
+
// try to keep size of all_to_write < 5GiB
|
472
|
+
// but we need at least one element per bucket
|
473
|
+
size_t init_to_write = std::max(
|
474
|
+
size_t(nbucket),
|
475
|
+
std::min(nval / 10, ((size_t)5 << 30) / (sizeof(TI) * 3 * nt_in)));
|
476
|
+
if (verbose > 0) {
|
477
|
+
printf("init_to_write=%zd\n", init_to_write);
|
478
|
+
}
|
479
|
+
|
480
|
+
std::vector<size_t> ptrs(nbucket); // ptrs is shared across all threads
|
481
|
+
std::vector<char> did_wrap(
|
482
|
+
nbucket); // DON'T use std::vector<bool> that cannot be accessed
|
483
|
+
// safely from multiple threads!!!
|
484
|
+
|
485
|
+
#pragma omp parallel num_threads(nt_in)
|
486
|
+
{
|
487
|
+
int nt = omp_get_num_threads(); // might be different from nt_in (?)
|
488
|
+
int rank = omp_get_thread_num();
|
489
|
+
std::vector<int64_t> local_lims(nbucket + 1);
|
490
|
+
|
491
|
+
// range of indices handled by this thread
|
492
|
+
size_t i0 = nval * rank / nt;
|
493
|
+
size_t i1 = nval * (rank + 1) / nt;
|
494
|
+
|
495
|
+
// build histogram in local lims
|
496
|
+
for (size_t i = i0; i < i1; i++) {
|
497
|
+
local_lims[vals[i]]++;
|
498
|
+
}
|
499
|
+
#pragma omp critical
|
500
|
+
{ // accumulate histograms (not shifted indices to prepare cumsum)
|
501
|
+
for (size_t i = 0; i < nbucket; i++) {
|
502
|
+
lims[i + 1] += local_lims[i];
|
503
|
+
}
|
504
|
+
all_to_write.push_back(ToWrite<TI>(nbucket));
|
505
|
+
}
|
506
|
+
|
507
|
+
#pragma omp barrier
|
508
|
+
// this thread's things to write
|
509
|
+
ToWrite<TI>& to_write = all_to_write[rank];
|
510
|
+
|
511
|
+
#pragma omp master
|
512
|
+
{
|
513
|
+
// compute cumulative sum
|
514
|
+
for (size_t i = 0; i < nbucket; i++) {
|
515
|
+
lims[i + 1] += lims[i];
|
516
|
+
}
|
517
|
+
FAISS_THROW_IF_NOT(lims[nbucket] == nval);
|
518
|
+
// at this point lims is final (read only!)
|
519
|
+
|
520
|
+
memcpy(ptrs.data(), lims, sizeof(lims[0]) * nbucket);
|
521
|
+
|
522
|
+
// initial values to write (we write -1s to get the process running)
|
523
|
+
// make sure at least one element per bucket
|
524
|
+
size_t written = 0;
|
525
|
+
for (TI b = 0; b < nbucket; b++) {
|
526
|
+
size_t l0 = lims[b], l1 = lims[b + 1];
|
527
|
+
size_t target_to_write = l1 * init_to_write / nval;
|
528
|
+
do {
|
529
|
+
if (l0 == l1) {
|
530
|
+
break;
|
531
|
+
}
|
532
|
+
to_write.add(-1, b);
|
533
|
+
l0++;
|
534
|
+
written++;
|
535
|
+
} while (written < target_to_write);
|
536
|
+
}
|
537
|
+
|
538
|
+
to_write.bucket_sort();
|
539
|
+
}
|
540
|
+
|
541
|
+
// this thread writes only buckets b0:b1
|
542
|
+
size_t b0 = (rank * nbucket + nt - 1) / nt;
|
543
|
+
size_t b1 = ((rank + 1) * nbucket + nt - 1) / nt;
|
544
|
+
|
545
|
+
// in this loop, we write elements collected in the previous round
|
546
|
+
// and collect the elements that are overwritten for the next round
|
547
|
+
size_t tot_written = 0;
|
548
|
+
int round = 0;
|
549
|
+
for (;;) {
|
550
|
+
#pragma omp barrier
|
551
|
+
|
552
|
+
size_t n_to_write = 0;
|
553
|
+
for (const ToWrite<TI>& to_write_2 : all_to_write) {
|
554
|
+
n_to_write += to_write_2.lims.back();
|
555
|
+
}
|
556
|
+
|
557
|
+
tot_written += n_to_write;
|
558
|
+
// assert(tot_written <= nval);
|
559
|
+
|
560
|
+
#pragma omp master
|
561
|
+
{
|
562
|
+
if (verbose >= 1) {
|
563
|
+
printf("ROUND %d n_to_write=%zd\n", round, n_to_write);
|
564
|
+
}
|
565
|
+
if (verbose > 2) {
|
566
|
+
for (size_t b = 0; b < nbucket; b++) {
|
567
|
+
printf(" b=%zd [", b);
|
568
|
+
for (size_t i = lims[b]; i < lims[b + 1]; i++) {
|
569
|
+
printf(" %s%d",
|
570
|
+
ptrs[b] == i ? ">" : "",
|
571
|
+
int(vals[i]));
|
572
|
+
}
|
573
|
+
printf(" %s] %s\n",
|
574
|
+
ptrs[b] == lims[b + 1] ? ">" : "",
|
575
|
+
did_wrap[b] ? "w" : "");
|
576
|
+
}
|
577
|
+
printf("To write\n");
|
578
|
+
for (size_t b = 0; b < nbucket; b++) {
|
579
|
+
printf(" b=%zd ", b);
|
580
|
+
const char* sep = "[";
|
581
|
+
for (const ToWrite<TI>& to_write_2 : all_to_write) {
|
582
|
+
printf("%s", sep);
|
583
|
+
sep = " |";
|
584
|
+
size_t l0 = to_write_2.lims[b];
|
585
|
+
size_t l1 = to_write_2.lims[b + 1];
|
586
|
+
for (size_t i = l0; i < l1; i++) {
|
587
|
+
printf(" %d", int(to_write_2.rows[i]));
|
588
|
+
}
|
589
|
+
}
|
590
|
+
printf(" ]\n");
|
591
|
+
}
|
592
|
+
}
|
593
|
+
}
|
594
|
+
if (n_to_write == 0) {
|
595
|
+
break;
|
596
|
+
}
|
597
|
+
round++;
|
598
|
+
|
599
|
+
#pragma omp barrier
|
600
|
+
|
601
|
+
ToWrite<TI> next_to_write(nbucket);
|
602
|
+
|
603
|
+
for (size_t b = b0; b < b1; b++) {
|
604
|
+
for (const ToWrite<TI>& to_write_2 : all_to_write) {
|
605
|
+
size_t l0 = to_write_2.lims[b];
|
606
|
+
size_t l1 = to_write_2.lims[b + 1];
|
607
|
+
for (size_t i = l0; i < l1; i++) {
|
608
|
+
TI row = to_write_2.rows[i];
|
609
|
+
size_t idx = ptrs[b];
|
610
|
+
if (verbose > 2) {
|
611
|
+
printf(" bucket %d (rank %d) idx %zd\n",
|
612
|
+
int(row),
|
613
|
+
rank,
|
614
|
+
idx);
|
615
|
+
}
|
616
|
+
if (idx < lims[b + 1]) {
|
617
|
+
ptrs[b]++;
|
618
|
+
} else {
|
619
|
+
// wrapping around
|
620
|
+
assert(!did_wrap[b]);
|
621
|
+
did_wrap[b] = true;
|
622
|
+
idx = lims[b];
|
623
|
+
ptrs[b] = idx + 1;
|
624
|
+
}
|
625
|
+
|
626
|
+
// check if we need to remember the overwritten number
|
627
|
+
if (vals[idx] >= 0) {
|
628
|
+
TI new_row = idx / ncol;
|
629
|
+
next_to_write.add(new_row, vals[idx]);
|
630
|
+
if (verbose > 2) {
|
631
|
+
printf(" new_row=%d\n", int(new_row));
|
632
|
+
}
|
633
|
+
} else {
|
634
|
+
assert(did_wrap[b]);
|
635
|
+
}
|
636
|
+
|
637
|
+
vals[idx] = row;
|
638
|
+
}
|
639
|
+
}
|
640
|
+
}
|
641
|
+
next_to_write.bucket_sort();
|
642
|
+
#pragma omp barrier
|
643
|
+
all_to_write[rank].swap(next_to_write);
|
644
|
+
}
|
645
|
+
}
|
646
|
+
}
|
647
|
+
|
648
|
+
} // anonymous namespace
|
649
|
+
|
650
|
+
void bucket_sort(
|
651
|
+
size_t nval,
|
652
|
+
const uint64_t* vals,
|
653
|
+
uint64_t vmax,
|
654
|
+
int64_t* lims,
|
655
|
+
int64_t* perm,
|
656
|
+
int nt) {
|
657
|
+
if (nt == 0) {
|
658
|
+
bucket_sort_ref(nval, vals, vmax, lims, perm);
|
659
|
+
} else {
|
660
|
+
bucket_sort_parallel(nval, vals, vmax, lims, perm, nt);
|
661
|
+
}
|
662
|
+
}
|
663
|
+
|
664
|
+
void matrix_bucket_sort_inplace(
|
665
|
+
size_t nrow,
|
666
|
+
size_t ncol,
|
667
|
+
int32_t* vals,
|
668
|
+
int32_t vmax,
|
669
|
+
int64_t* lims,
|
670
|
+
int nt) {
|
671
|
+
if (nt == 0) {
|
672
|
+
bucket_sort_inplace_ref(nrow, ncol, vals, vmax, lims);
|
673
|
+
} else {
|
674
|
+
bucket_sort_inplace_parallel(nrow, ncol, vals, vmax, lims, nt);
|
675
|
+
}
|
676
|
+
}
|
677
|
+
|
678
|
+
void matrix_bucket_sort_inplace(
|
679
|
+
size_t nrow,
|
680
|
+
size_t ncol,
|
681
|
+
int64_t* vals,
|
682
|
+
int64_t vmax,
|
683
|
+
int64_t* lims,
|
684
|
+
int nt) {
|
685
|
+
if (nt == 0) {
|
686
|
+
bucket_sort_inplace_ref(nrow, ncol, vals, vmax, lims);
|
687
|
+
} else {
|
688
|
+
bucket_sort_inplace_parallel(nrow, ncol, vals, vmax, lims, nt);
|
689
|
+
}
|
690
|
+
}
|
691
|
+
|
692
|
+
} // namespace faiss
|