faiss 0.1.0 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +103 -3
- data/ext/faiss/ext.cpp +99 -32
- data/ext/faiss/extconf.rb +12 -2
- data/lib/faiss/ext.bundle +0 -0
- data/lib/faiss/index.rb +3 -3
- data/lib/faiss/index_binary.rb +3 -3
- data/lib/faiss/kmeans.rb +1 -1
- data/lib/faiss/pca_matrix.rb +2 -2
- data/lib/faiss/product_quantizer.rb +3 -3
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/AutoTune.cpp +719 -0
- data/vendor/faiss/AutoTune.h +212 -0
- data/vendor/faiss/Clustering.cpp +261 -0
- data/vendor/faiss/Clustering.h +101 -0
- data/vendor/faiss/IVFlib.cpp +339 -0
- data/vendor/faiss/IVFlib.h +132 -0
- data/vendor/faiss/Index.cpp +171 -0
- data/vendor/faiss/Index.h +261 -0
- data/vendor/faiss/Index2Layer.cpp +437 -0
- data/vendor/faiss/Index2Layer.h +85 -0
- data/vendor/faiss/IndexBinary.cpp +77 -0
- data/vendor/faiss/IndexBinary.h +163 -0
- data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
- data/vendor/faiss/IndexBinaryFlat.h +54 -0
- data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
- data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
- data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
- data/vendor/faiss/IndexBinaryHNSW.h +56 -0
- data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
- data/vendor/faiss/IndexBinaryIVF.h +211 -0
- data/vendor/faiss/IndexFlat.cpp +508 -0
- data/vendor/faiss/IndexFlat.h +175 -0
- data/vendor/faiss/IndexHNSW.cpp +1090 -0
- data/vendor/faiss/IndexHNSW.h +170 -0
- data/vendor/faiss/IndexIVF.cpp +909 -0
- data/vendor/faiss/IndexIVF.h +353 -0
- data/vendor/faiss/IndexIVFFlat.cpp +502 -0
- data/vendor/faiss/IndexIVFFlat.h +118 -0
- data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
- data/vendor/faiss/IndexIVFPQ.h +161 -0
- data/vendor/faiss/IndexIVFPQR.cpp +219 -0
- data/vendor/faiss/IndexIVFPQR.h +65 -0
- data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
- data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
- data/vendor/faiss/IndexLSH.cpp +225 -0
- data/vendor/faiss/IndexLSH.h +87 -0
- data/vendor/faiss/IndexLattice.cpp +143 -0
- data/vendor/faiss/IndexLattice.h +68 -0
- data/vendor/faiss/IndexPQ.cpp +1188 -0
- data/vendor/faiss/IndexPQ.h +199 -0
- data/vendor/faiss/IndexPreTransform.cpp +288 -0
- data/vendor/faiss/IndexPreTransform.h +91 -0
- data/vendor/faiss/IndexReplicas.cpp +123 -0
- data/vendor/faiss/IndexReplicas.h +76 -0
- data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
- data/vendor/faiss/IndexScalarQuantizer.h +127 -0
- data/vendor/faiss/IndexShards.cpp +317 -0
- data/vendor/faiss/IndexShards.h +100 -0
- data/vendor/faiss/InvertedLists.cpp +623 -0
- data/vendor/faiss/InvertedLists.h +334 -0
- data/vendor/faiss/LICENSE +21 -0
- data/vendor/faiss/MatrixStats.cpp +252 -0
- data/vendor/faiss/MatrixStats.h +62 -0
- data/vendor/faiss/MetaIndexes.cpp +351 -0
- data/vendor/faiss/MetaIndexes.h +126 -0
- data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
- data/vendor/faiss/OnDiskInvertedLists.h +127 -0
- data/vendor/faiss/VectorTransform.cpp +1157 -0
- data/vendor/faiss/VectorTransform.h +322 -0
- data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
- data/vendor/faiss/c_api/AutoTune_c.h +64 -0
- data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
- data/vendor/faiss/c_api/Clustering_c.h +117 -0
- data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
- data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
- data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
- data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
- data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
- data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
- data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
- data/vendor/faiss/c_api/IndexShards_c.h +42 -0
- data/vendor/faiss/c_api/Index_c.cpp +105 -0
- data/vendor/faiss/c_api/Index_c.h +183 -0
- data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
- data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
- data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
- data/vendor/faiss/c_api/clone_index_c.h +32 -0
- data/vendor/faiss/c_api/error_c.h +42 -0
- data/vendor/faiss/c_api/error_impl.cpp +27 -0
- data/vendor/faiss/c_api/error_impl.h +16 -0
- data/vendor/faiss/c_api/faiss_c.h +58 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
- data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
- data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
- data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
- data/vendor/faiss/c_api/index_factory_c.h +30 -0
- data/vendor/faiss/c_api/index_io_c.cpp +42 -0
- data/vendor/faiss/c_api/index_io_c.h +50 -0
- data/vendor/faiss/c_api/macros_impl.h +110 -0
- data/vendor/faiss/clone_index.cpp +147 -0
- data/vendor/faiss/clone_index.h +38 -0
- data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
- data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
- data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
- data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
- data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
- data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
- data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
- data/vendor/faiss/gpu/GpuCloner.h +82 -0
- data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
- data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
- data/vendor/faiss/gpu/GpuDistance.h +52 -0
- data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
- data/vendor/faiss/gpu/GpuIndex.h +148 -0
- data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
- data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
- data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
- data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
- data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
- data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
- data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
- data/vendor/faiss/gpu/GpuResources.cpp +52 -0
- data/vendor/faiss/gpu/GpuResources.h +73 -0
- data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
- data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
- data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
- data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
- data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
- data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
- data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
- data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
- data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
- data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
- data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
- data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
- data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
- data/vendor/faiss/gpu/test/TestUtils.h +93 -0
- data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
- data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
- data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
- data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
- data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
- data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
- data/vendor/faiss/gpu/utils/Timer.h +52 -0
- data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
- data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
- data/vendor/faiss/impl/FaissAssert.h +95 -0
- data/vendor/faiss/impl/FaissException.cpp +66 -0
- data/vendor/faiss/impl/FaissException.h +71 -0
- data/vendor/faiss/impl/HNSW.cpp +818 -0
- data/vendor/faiss/impl/HNSW.h +275 -0
- data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
- data/vendor/faiss/impl/PolysemousTraining.h +158 -0
- data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
- data/vendor/faiss/impl/ProductQuantizer.h +242 -0
- data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
- data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
- data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
- data/vendor/faiss/impl/ThreadedIndex.h +80 -0
- data/vendor/faiss/impl/index_read.cpp +793 -0
- data/vendor/faiss/impl/index_write.cpp +558 -0
- data/vendor/faiss/impl/io.cpp +142 -0
- data/vendor/faiss/impl/io.h +98 -0
- data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
- data/vendor/faiss/impl/lattice_Zn.h +199 -0
- data/vendor/faiss/index_factory.cpp +392 -0
- data/vendor/faiss/index_factory.h +25 -0
- data/vendor/faiss/index_io.h +75 -0
- data/vendor/faiss/misc/test_blas.cpp +84 -0
- data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
- data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
- data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
- data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
- data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
- data/vendor/faiss/tests/test_merge.cpp +258 -0
- data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
- data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
- data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
- data/vendor/faiss/tests/test_params_override.cpp +231 -0
- data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
- data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
- data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
- data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
- data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
- data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
- data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
- data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
- data/vendor/faiss/utils/Heap.cpp +122 -0
- data/vendor/faiss/utils/Heap.h +495 -0
- data/vendor/faiss/utils/WorkerThread.cpp +126 -0
- data/vendor/faiss/utils/WorkerThread.h +61 -0
- data/vendor/faiss/utils/distances.cpp +765 -0
- data/vendor/faiss/utils/distances.h +243 -0
- data/vendor/faiss/utils/distances_simd.cpp +809 -0
- data/vendor/faiss/utils/extra_distances.cpp +336 -0
- data/vendor/faiss/utils/extra_distances.h +54 -0
- data/vendor/faiss/utils/hamming-inl.h +472 -0
- data/vendor/faiss/utils/hamming.cpp +792 -0
- data/vendor/faiss/utils/hamming.h +220 -0
- data/vendor/faiss/utils/random.cpp +192 -0
- data/vendor/faiss/utils/random.h +60 -0
- data/vendor/faiss/utils/utils.cpp +783 -0
- data/vendor/faiss/utils/utils.h +181 -0
- metadata +216 -2
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// -*- c++ -*-
|
|
9
|
+
|
|
10
|
+
#ifndef FAISS_INDEX_LATTICE_H
|
|
11
|
+
#define FAISS_INDEX_LATTICE_H
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
#include <vector>
|
|
15
|
+
|
|
16
|
+
#include <faiss/IndexIVF.h>
|
|
17
|
+
#include <faiss/impl/lattice_Zn.h>
|
|
18
|
+
|
|
19
|
+
namespace faiss {
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
/** Index that encodes a vector with a series of Zn lattice quantizers
|
|
26
|
+
*/
|
|
27
|
+
struct IndexLattice: Index {
|
|
28
|
+
|
|
29
|
+
/// number of sub-vectors
|
|
30
|
+
int nsq;
|
|
31
|
+
/// dimension of sub-vectors
|
|
32
|
+
size_t dsq;
|
|
33
|
+
|
|
34
|
+
/// the lattice quantizer
|
|
35
|
+
ZnSphereCodecAlt zn_sphere_codec;
|
|
36
|
+
|
|
37
|
+
/// nb bits used to encode the scale, per subvector
|
|
38
|
+
int scale_nbit, lattice_nbit;
|
|
39
|
+
/// total, in bytes
|
|
40
|
+
size_t code_size;
|
|
41
|
+
|
|
42
|
+
/// mins and maxes of the vector norms, per subquantizer
|
|
43
|
+
std::vector<float> trained;
|
|
44
|
+
|
|
45
|
+
IndexLattice (idx_t d, int nsq, int scale_nbit, int r2);
|
|
46
|
+
|
|
47
|
+
void train(idx_t n, const float* x) override;
|
|
48
|
+
|
|
49
|
+
/* The standalone codec interface */
|
|
50
|
+
size_t sa_code_size () const override;
|
|
51
|
+
|
|
52
|
+
void sa_encode (idx_t n, const float *x,
|
|
53
|
+
uint8_t *bytes) const override;
|
|
54
|
+
|
|
55
|
+
void sa_decode (idx_t n, const uint8_t *bytes,
|
|
56
|
+
float *x) const override;
|
|
57
|
+
|
|
58
|
+
/// not implemented
|
|
59
|
+
void add(idx_t n, const float* x) override;
|
|
60
|
+
void search(idx_t n, const float* x, idx_t k,
|
|
61
|
+
float* distances, idx_t* labels) const override;
|
|
62
|
+
void reset() override;
|
|
63
|
+
|
|
64
|
+
};
|
|
65
|
+
|
|
66
|
+
} // namespace faiss
|
|
67
|
+
|
|
68
|
+
#endif
|
|
@@ -0,0 +1,1188 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// -*- c++ -*-
|
|
9
|
+
|
|
10
|
+
#include <faiss/IndexPQ.h>
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
#include <cstddef>
|
|
14
|
+
#include <cstring>
|
|
15
|
+
#include <cstdio>
|
|
16
|
+
#include <cmath>
|
|
17
|
+
|
|
18
|
+
#include <algorithm>
|
|
19
|
+
|
|
20
|
+
#include <faiss/impl/FaissAssert.h>
|
|
21
|
+
#include <faiss/impl/AuxIndexStructures.h>
|
|
22
|
+
#include <faiss/utils/hamming.h>
|
|
23
|
+
|
|
24
|
+
namespace faiss {
|
|
25
|
+
|
|
26
|
+
/*********************************************************
|
|
27
|
+
* IndexPQ implementation
|
|
28
|
+
********************************************************/
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
/// Build an untrained PQ index over d-dimensional vectors, with a product
/// quantizer of M sub-quantizers of nbits bits each.
IndexPQ::IndexPQ (int d, size_t M, size_t nbits, MetricType metric)
    : Index (d, metric),
      pq (d, M, nbits)
{
    // the quantizer has to be trained before vectors can be added
    is_trained = false;

    // default search mode
    search_type = ST_PQ;
    encode_signs = false;

    // polysemous defaults: no polysemous training, and a Hamming
    // threshold one above the total number of code bits
    do_polysemous_training = false;
    polysemous_ht = nbits * M + 1;
}
|
|
40
|
+
|
|
41
|
+
/// Default constructor: L2 metric, untrained, same defaults as the
/// parameterized constructor, with the Hamming threshold derived from the
/// default-constructed product quantizer.
IndexPQ::IndexPQ ()
{
    metric_type = METRIC_L2;
    is_trained = false;

    // default search mode
    search_type = ST_PQ;
    encode_signs = false;

    // polysemous defaults
    do_polysemous_training = false;
    polysemous_ht = pq.nbits * pq.M + 1;
}
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
/** Train the product quantizer on n vectors stored contiguously in x.
 *
 * Without polysemous training all n vectors go to pq.train. With it, the
 * last ntrain_perm vectors (at most n / 4) are held out and passed to
 * optimize_pq_for_hamming, and the PQ is trained on the first
 * n - ntrain_perm vectors.
 *
 * @param n  number of training vectors
 * @param x  training vectors, size n * d
 */
void IndexPQ::train (idx_t n, const float *x)
{
    if (!do_polysemous_training) { // standard training
        pq.train (n, x);
    } else {
        idx_t ntrain_perm = polysemous_training.ntrain_permutation;

        // never hold out more than a quarter of the training set
        if (ntrain_perm > n / 4)
            ntrain_perm = n / 4;
        if (verbose) {
            // idx_t is a 64-bit integer: cast explicitly so the format
            // specifier matches on every ABI (LP64 and LLP64 alike)
            printf ("PQ training on %lld points, remains %lld points: "
                    "training polysemous on %s\n",
                    static_cast<long long> (n - ntrain_perm),
                    static_cast<long long> (ntrain_perm),
                    ntrain_perm == 0 ? "centroids" : "these");
        }
        pq.train (n - ntrain_perm, x);

        // the held-out vectors are the tail of x
        polysemous_training.optimize_pq_for_hamming (
            pq, ntrain_perm, x + (n - ntrain_perm) * d);
    }
    is_trained = true;
}
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
/// Encode n vectors from x and append their codes to the index.
/// Throws if the index has not been trained.
void IndexPQ::add (idx_t n, const float *x)
{
    FAISS_THROW_IF_NOT (is_trained);

    // grow the code storage, then encode straight into the new tail
    codes.resize ((n + ntotal) * pq.code_size);
    uint8_t *tail = &codes[ntotal * pq.code_size];
    pq.compute_codes (x, tail, n);

    ntotal += n;
}
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
size_t IndexPQ::remove_ids (const IDSelector & sel)
|
|
86
|
+
{
|
|
87
|
+
idx_t j = 0;
|
|
88
|
+
for (idx_t i = 0; i < ntotal; i++) {
|
|
89
|
+
if (sel.is_member (i)) {
|
|
90
|
+
// should be removed
|
|
91
|
+
} else {
|
|
92
|
+
if (i > j) {
|
|
93
|
+
memmove (&codes[pq.code_size * j], &codes[pq.code_size * i], pq.code_size);
|
|
94
|
+
}
|
|
95
|
+
j++;
|
|
96
|
+
}
|
|
97
|
+
}
|
|
98
|
+
size_t nremove = ntotal - j;
|
|
99
|
+
if (nremove > 0) {
|
|
100
|
+
ntotal = j;
|
|
101
|
+
codes.resize (ntotal * pq.code_size);
|
|
102
|
+
}
|
|
103
|
+
return nremove;
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
void IndexPQ::reset()
|
|
108
|
+
{
|
|
109
|
+
codes.clear();
|
|
110
|
+
ntotal = 0;
|
|
111
|
+
}
|
|
112
|
+
|
|
113
|
+
/// Decode the ni consecutive codes starting at index i0 into recons
/// (ni * d floats). Throws if the range is not inside [0, ntotal).
void IndexPQ::reconstruct_n (idx_t i0, idx_t ni, float *recons) const
{
    FAISS_THROW_IF_NOT (ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));

    for (idx_t off = 0; off < ni; off++) {
        pq.decode (&codes[(i0 + off) * pq.code_size], recons + off * d);
    }
}
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
/// Decode the code stored at index key into recons.
/// Throws if key is outside [0, ntotal).
void IndexPQ::reconstruct (idx_t key, float * recons) const
{
    FAISS_THROW_IF_NOT (key >= 0 && key < ntotal);

    const uint8_t *code = &codes[key * pq.code_size];
    pq.decode (code, recons);
}
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
namespace {
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
struct PQDis: DistanceComputer {
|
|
134
|
+
size_t d;
|
|
135
|
+
Index::idx_t nb;
|
|
136
|
+
const uint8_t *codes;
|
|
137
|
+
size_t code_size;
|
|
138
|
+
const ProductQuantizer & pq;
|
|
139
|
+
const float *sdc;
|
|
140
|
+
std::vector<float> precomputed_table;
|
|
141
|
+
size_t ndis;
|
|
142
|
+
|
|
143
|
+
float operator () (idx_t i) override
|
|
144
|
+
{
|
|
145
|
+
const uint8_t *code = codes + i * code_size;
|
|
146
|
+
const float *dt = precomputed_table.data();
|
|
147
|
+
float accu = 0;
|
|
148
|
+
for (int j = 0; j < pq.M; j++) {
|
|
149
|
+
accu += dt[*code++];
|
|
150
|
+
dt += 256;
|
|
151
|
+
}
|
|
152
|
+
ndis++;
|
|
153
|
+
return accu;
|
|
154
|
+
}
|
|
155
|
+
|
|
156
|
+
float symmetric_dis(idx_t i, idx_t j) override
|
|
157
|
+
{
|
|
158
|
+
const float * sdci = sdc;
|
|
159
|
+
float accu = 0;
|
|
160
|
+
const uint8_t *codei = codes + i * code_size;
|
|
161
|
+
const uint8_t *codej = codes + j * code_size;
|
|
162
|
+
|
|
163
|
+
for (int l = 0; l < pq.M; l++) {
|
|
164
|
+
accu += sdci[(*codei++) + (*codej++) * 256];
|
|
165
|
+
sdci += 256 * 256;
|
|
166
|
+
}
|
|
167
|
+
return accu;
|
|
168
|
+
}
|
|
169
|
+
|
|
170
|
+
explicit PQDis(const IndexPQ& storage, const float* /*q*/ = nullptr)
|
|
171
|
+
: pq(storage.pq) {
|
|
172
|
+
precomputed_table.resize(pq.M * pq.ksub);
|
|
173
|
+
nb = storage.ntotal;
|
|
174
|
+
d = storage.d;
|
|
175
|
+
codes = storage.codes.data();
|
|
176
|
+
code_size = pq.code_size;
|
|
177
|
+
FAISS_ASSERT(pq.ksub == 256);
|
|
178
|
+
FAISS_ASSERT(pq.sdc_table.size() == pq.ksub * pq.ksub * pq.M);
|
|
179
|
+
sdc = pq.sdc_table.data();
|
|
180
|
+
ndis = 0;
|
|
181
|
+
}
|
|
182
|
+
|
|
183
|
+
void set_query(const float *x) override {
|
|
184
|
+
pq.compute_distance_table(x, precomputed_table.data());
|
|
185
|
+
}
|
|
186
|
+
};
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
} // namespace
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
/// Allocate a distance computer over this index's codes (8-bit codes only).
/// The object is created with new; the caller receives the raw pointer.
DistanceComputer * IndexPQ::get_distance_computer() const {
    FAISS_THROW_IF_NOT(pq.nbits == 8);

    DistanceComputer *dc = new PQDis(*this);
    return dc;
}
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
/*****************************************
|
|
199
|
+
* IndexPQ polysemous search routines
|
|
200
|
+
******************************************/
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
/** k-nearest-neighbour search for n query vectors.
 *
 * Dispatches on search_type:
 *  - ST_PQ: table-based PQ search (L2 via pq.search, otherwise
 *    pq.search_ip);
 *  - ST_polysemous / ST_polysemous_generalize: search_core_polysemous
 *    (L2 metric only);
 *  - otherwise: code-to-code distances. The queries are first encoded
 *    (with the PQ, or as sign bits when encode_signs is set), then compared
 *    with SDC, Hamming or generalized Hamming distances.
 *
 * @param n          number of queries
 * @param x          query vectors, size n * d
 * @param k          number of results per query
 * @param distances  output distances, size n * k
 * @param labels     output labels, size n * k
 */
void IndexPQ::search (idx_t n, const float *x, idx_t k,
                      float *distances, idx_t *labels) const
{
    FAISS_THROW_IF_NOT (is_trained);
    if (search_type == ST_PQ) {  // Simple PQ search

        if (metric_type == METRIC_L2) {
            float_maxheap_array_t res = {
                size_t(n), size_t(k), labels, distances };
            pq.search (x, n, codes.data(), ntotal, &res, true);
        } else {
            float_minheap_array_t res = {
                size_t(n), size_t(k), labels, distances };
            pq.search_ip (x, n, codes.data(), ntotal, &res, true);
        }
        indexPQ_stats.nq += n;
        indexPQ_stats.ncode += n * ntotal;

    } else if (search_type == ST_polysemous ||
               search_type == ST_polysemous_generalize) {

        FAISS_THROW_IF_NOT (metric_type == METRIC_L2);

        search_core_polysemous (n, x, k, distances, labels);

    } else { // code-to-code distances

        uint8_t * q_codes = new uint8_t [n * pq.code_size];
        ScopeDeleter<uint8_t> del (q_codes);

        if (!encode_signs) {
            pq.compute_codes (x, q_codes, n);
        } else {
            // one bit per dimension: bit j is set when component j is > 0
            FAISS_THROW_IF_NOT (d == pq.nbits * pq.M);
            memset (q_codes, 0, n * pq.code_size);
            for (idx_t i = 0; i < n; i++) {
                const float *xi = x + i * d;
                uint8_t *code = q_codes + i * pq.code_size;
                for (int j = 0; j < d; j++)
                    if (xi[j] > 0) code [j>>3] |= 1 << (j & 7);
            }
        }

        if (search_type == ST_SDC) {

            float_maxheap_array_t res = {
                size_t(n), size_t(k), labels, distances};

            pq.search_sdc (q_codes, n, codes.data(), ntotal, &res, true);

        } else {
            int * idistances = new int [n * k];
            ScopeDeleter<int> del_idistances (idistances);

            int_maxheap_array_t res = {
                size_t (n), size_t (k), labels, idistances};

            if (search_type == ST_HE) {

                hammings_knn_hc (&res, q_codes, codes.data(),
                                 ntotal, pq.code_size, true);

            } else if (search_type == ST_generalized_HE) {

                generalized_hammings_knn_hc (&res, q_codes, codes.data(),
                                             ntotal, pq.code_size, true);
            }

            // convert distances to floats; the counter is size_t because
            // n * k can exceed the range of int
            const size_t nres = size_t (n) * size_t (k);
            for (size_t i = 0; i < nres; i++)
                distances[i] = idistances[i];

        }

        indexPQ_stats.nq += n;
        indexPQ_stats.ncode += n * ntotal;
    }
}
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
void IndexPQStats::reset()
|
|
292
|
+
{
|
|
293
|
+
nq = ncode = n_hamming_pass = 0;
|
|
294
|
+
}
|
|
295
|
+
|
|
296
|
+
IndexPQStats indexPQ_stats;
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
template <class HammingComputer>
|
|
300
|
+
static size_t polysemous_inner_loop (
|
|
301
|
+
const IndexPQ & index,
|
|
302
|
+
const float *dis_table_qi, const uint8_t *q_code,
|
|
303
|
+
size_t k, float *heap_dis, int64_t *heap_ids)
|
|
304
|
+
{
|
|
305
|
+
|
|
306
|
+
int M = index.pq.M;
|
|
307
|
+
int code_size = index.pq.code_size;
|
|
308
|
+
int ksub = index.pq.ksub;
|
|
309
|
+
size_t ntotal = index.ntotal;
|
|
310
|
+
int ht = index.polysemous_ht;
|
|
311
|
+
|
|
312
|
+
const uint8_t *b_code = index.codes.data();
|
|
313
|
+
|
|
314
|
+
size_t n_pass_i = 0;
|
|
315
|
+
|
|
316
|
+
HammingComputer hc (q_code, code_size);
|
|
317
|
+
|
|
318
|
+
for (int64_t bi = 0; bi < ntotal; bi++) {
|
|
319
|
+
int hd = hc.hamming (b_code);
|
|
320
|
+
|
|
321
|
+
if (hd < ht) {
|
|
322
|
+
n_pass_i ++;
|
|
323
|
+
|
|
324
|
+
float dis = 0;
|
|
325
|
+
const float * dis_table = dis_table_qi;
|
|
326
|
+
for (int m = 0; m < M; m++) {
|
|
327
|
+
dis += dis_table [b_code[m]];
|
|
328
|
+
dis_table += ksub;
|
|
329
|
+
}
|
|
330
|
+
|
|
331
|
+
if (dis < heap_dis[0]) {
|
|
332
|
+
maxheap_pop (k, heap_dis, heap_ids);
|
|
333
|
+
maxheap_push (k, heap_dis, heap_ids, dis, bi);
|
|
334
|
+
}
|
|
335
|
+
}
|
|
336
|
+
b_code += code_size;
|
|
337
|
+
}
|
|
338
|
+
return n_pass_i;
|
|
339
|
+
}
|
|
340
|
+
|
|
341
|
+
|
|
342
|
+
/** Polysemous search for n queries (8-bit codes only).
 *
 * For each query, the PQ distance table is computed and the query code is
 * derived from it. The database scan is then delegated to
 * polysemous_inner_loop, instantiated with a Hamming computer chosen from
 * pq.code_size (plain Hamming for ST_polysemous, generalized Hamming
 * otherwise).
 *
 * An unsupported code size is recorded inside the parallel loop and
 * reported with an exception after the loop, because an exception must
 * not leave an OpenMP parallel region.
 *
 * @param n          number of queries
 * @param x          query vectors, size n * d
 * @param k          number of results per query
 * @param distances  output distances, size n * k
 * @param labels     output labels, size n * k
 */
void IndexPQ::search_core_polysemous (idx_t n, const float *x, idx_t k,
                                       float *distances, idx_t *labels) const
{
    FAISS_THROW_IF_NOT (pq.nbits == 8);

    // PQ distance tables
    float * dis_tables = new float [n * pq.ksub * pq.M];
    ScopeDeleter<float> del (dis_tables);
    pq.compute_distance_tables (n, x, dis_tables);

    // Hamming embedding queries, derived from the distance tables
    uint8_t * q_codes = new uint8_t [n * pq.code_size];
    ScopeDeleter<uint8_t> del2 (q_codes);

#pragma omp parallel for
    for (idx_t qi = 0; qi < n; qi++) {
        pq.compute_code_from_distance_table
            (dis_tables + qi * pq.M * pq.ksub,
             q_codes + qi * pq.code_size);
    }

    size_t n_pass = 0;
    // number of queries that hit an unsupported code size
    size_t bad_code_size = 0;

#pragma omp parallel for reduction (+: n_pass, bad_code_size)
    for (idx_t qi = 0; qi < n; qi++) {
        const uint8_t * q_code = q_codes + qi * pq.code_size;

        const float * dis_table_qi = dis_tables + qi * pq.M * pq.ksub;

        int64_t * heap_ids = labels + qi * k;
        float *heap_dis = distances + qi * k;
        maxheap_heapify (k, heap_dis, heap_ids);

        if (search_type == ST_polysemous) {

            switch (pq.code_size) {
            case 4:
                n_pass += polysemous_inner_loop<HammingComputer4>
                    (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                break;
            case 8:
                n_pass += polysemous_inner_loop<HammingComputer8>
                    (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                break;
            case 16:
                n_pass += polysemous_inner_loop<HammingComputer16>
                    (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                break;
            case 32:
                n_pass += polysemous_inner_loop<HammingComputer32>
                    (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                break;
            case 20:
                n_pass += polysemous_inner_loop<HammingComputer20>
                    (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                break;
            default:
                if (pq.code_size % 8 == 0) {
                    n_pass += polysemous_inner_loop<HammingComputerM8>
                        (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                } else if (pq.code_size % 4 == 0) {
                    n_pass += polysemous_inner_loop<HammingComputerM4>
                        (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                } else {
                    // reported after the parallel region
                    bad_code_size++;
                }
                break;
            }
        } else {
            switch (pq.code_size) {
            case 8:
                n_pass += polysemous_inner_loop<GenHammingComputer8>
                    (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                break;
            case 16:
                n_pass += polysemous_inner_loop<GenHammingComputer16>
                    (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                break;
            case 32:
                n_pass += polysemous_inner_loop<GenHammingComputer32>
                    (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                break;
            default:
                if (pq.code_size % 8 == 0) {
                    n_pass += polysemous_inner_loop<GenHammingComputerM8>
                        (*this, dis_table_qi, q_code, k, heap_dis, heap_ids);
                } else {
                    // reported after the parallel region
                    bad_code_size++;
                }
                break;
            }
        }
        maxheap_reorder (k, heap_dis, heap_ids);
    }

    if (bad_code_size) {
        FAISS_THROW_FMT(
            "code size %zd not supported for polysemous",
            pq.code_size);
    }

    indexPQ_stats.nq += n;
    indexPQ_stats.ncode += n * ntotal;
    indexPQ_stats.n_hamming_pass += n_pass;
}
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
/* The standalone codec interface (just remaps to the PQ functions) */
|
|
454
|
+
size_t IndexPQ::sa_code_size () const
|
|
455
|
+
{
|
|
456
|
+
return pq.code_size;
|
|
457
|
+
}
|
|
458
|
+
|
|
459
|
+
/// Standalone codec: encode n vectors from x into bytes, using the PQ.
void IndexPQ::sa_encode (idx_t n, const float *x, uint8_t *bytes) const
{
    uint8_t *out = bytes;
    pq.compute_codes (x, out, n);
}
|
|
463
|
+
|
|
464
|
+
/// Standalone codec: decode n codes from bytes into x, using the PQ.
void IndexPQ::sa_decode (idx_t n, const uint8_t *bytes, float *x) const
{
    float *out = x;
    pq.decode (bytes, out, n);
}
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
|
|
472
|
+
/*****************************************
|
|
473
|
+
* Stats of IndexPQ codes
|
|
474
|
+
******************************************/
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
|
|
479
|
+
/// Encode the n query vectors with the PQ and write the Hamming distances
/// between the query codes and the ntotal stored codes to dis.
void IndexPQ::hamming_distance_table (idx_t n, const float *x,
                                      int32_t *dis) const
{
    // temporary buffer for the query codes, released on scope exit
    uint8_t * query_codes = new uint8_t [n * pq.code_size];
    ScopeDeleter<uint8_t> query_codes_guard (query_codes);

    pq.compute_codes (x, query_codes, n);

    hammings (query_codes, codes.data(), n, ntotal, pq.code_size, dis);
}
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
/** Histogram of Hamming distances between encoded queries and database
 *  codes.
 *
 * The queries x are encoded with the PQ. The database side is xb encoded
 * with the PQ when xb is non-null, otherwise the codes already stored in
 * the index (nb is then replaced by ntotal). hist must have room for
 * pq.M * pq.nbits + 1 bins; it is zeroed here before accumulation.
 *
 * Requires the L2 metric, 8-bit codes and a code size multiple of 8.
 */
void IndexPQ::hamming_distance_histogram (idx_t n, const float *x,
                                          idx_t nb, const float *xb,
                                          int64_t *hist)
{
    FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
    FAISS_THROW_IF_NOT (pq.code_size % 8 == 0);
    FAISS_THROW_IF_NOT (pq.nbits == 8);

    // Hamming embedding queries
    uint8_t * q_codes = new uint8_t [n * pq.code_size];
    ScopeDeleter <uint8_t> del (q_codes);
    pq.compute_codes (x, q_codes, n);

    // database codes: owned by del_b_codes only when encoded here
    uint8_t * b_codes ;
    ScopeDeleter <uint8_t> del_b_codes;

    if (!xb) {
        nb = ntotal;
        b_codes = codes.data();
    } else {
        b_codes = new uint8_t [nb * pq.code_size];
        del_b_codes.set (b_codes);
        pq.compute_codes (xb, b_codes, nb);
    }

    int nbits = pq.M * pq.nbits;
    memset (hist, 0, sizeof(*hist) * (nbits + 1));

    // queries are handled in blocks of bs
    size_t bs = 256;

#pragma omp parallel
    {
        // per-thread histogram and distance buffer
        std::vector<int64_t> histi (nbits + 1);
        hamdis_t *distances = new hamdis_t [nb * bs];
        ScopeDeleter<hamdis_t> del (distances);
#pragma omp for
        for (size_t q0 = 0; q0 < n; q0 += bs) {
            size_t q1 = q0 + bs;
            if (q1 > n) q1 = n;

            hammings (q_codes + q0 * pq.code_size, b_codes,
                      q1 - q0, nb,
                      pq.code_size, distances);

            const size_t ndist = nb * (q1 - q0);
            for (size_t i = 0; i < ndist; i++)
                histi [distances [i]]++;
        }
        // merge the per-thread histogram into the output
#pragma omp critical
        {
            for (int b = 0; b <= nbits; b++)
                hist[b] += histi[b];
        }
    }

}
|
|
545
|
+
|
|
546
|
+
|
|
547
|
+
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
|
|
552
|
+
|
|
553
|
+
|
|
554
|
+
|
|
555
|
+
|
|
556
|
+
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
|
|
560
|
+
|
|
561
|
+
|
|
562
|
+
|
|
563
|
+
|
|
564
|
+
|
|
565
|
+
/*****************************************
|
|
566
|
+
* MultiIndexQuantizer
|
|
567
|
+
******************************************/
|
|
568
|
+
|
|
569
|
+
namespace {
|
|
570
|
+
|
|
571
|
+
/// View over an array whose N values are already in increasing order.
/// The array pointer is supplied through init(), not the constructor.
template <typename T>
struct PreSortedArray {

    const T * x;
    int N;

    explicit PreSortedArray (int N): N(N) {}

    void init (const T*x) {
        this->x = x;
    }

    // get smallest value
    T get_0 () {
        return *x;
    }

    // get delta between n-smallest and n-1 -smallest
    T get_diff (int n) {
        const T upper = x[n];
        const T lower = x[n - 1];
        return upper - lower;
    }

    // remap orders counted from smallest to indices in array
    // (identity, since the array is already sorted)
    int get_ord (int n) {
        return n;
    }

};
|
|
598
|
+
|
|
599
|
+
/** Comparison functor that orders two indices by the values stored at
 * those indices in an external array. Used to sort a permutation of
 * indices instead of moving the values themselves.
 */
template <typename T>
struct ArgSort {
    const T * x;  ///< borrowed array of values, indexed by the compared ids

    /// @return true if the value at index i is strictly below the one at j
    // const-qualified so the functor also works when held by const reference
    bool operator() (size_t i, size_t j) const {
        return x[i] < x[j];
    }
};
|
|
606
|
+
|
|
607
|
+
|
|
608
|
+
/** Array that maintains a permutation of its elements so that the
 * array's elements are sorted.
 *
 * The values themselves are borrowed (never copied or moved); only the
 * permutation perm is owned. After init(), perm[r] is the position in x
 * of the r-th smallest value.
 */
template <typename T>
struct SortedArray {
    const T * x;            ///< values to read, set by init(); null until then
    int N;                  ///< number of values
    std::vector<int> perm;  ///< perm[r] = index in x of the r-th smallest value

    explicit SortedArray (int N): x(nullptr), N(N), perm(N) {
    }

    /// attach the array and fully sort its permutation
    void init (const T*x) {
        this->x = x;
        for (int n = 0; n < N; n++)
            perm[n] = n;
        // sort indices by the value they designate
        std::sort (perm.begin(), perm.end(),
                   [x] (int i, int j) { return x[i] < x[j]; });
    }

    /// get smallest value
    T get_0 () const {
        return x[perm[0]];
    }

    /// get delta between n-smallest and n-1 -smallest (reads perm[n - 1],
    /// so n must be at least 1)
    T get_diff (int n) const {
        return x[perm[n]] - x[perm[n - 1]];
    }

    /// remap orders counted from smallest to indices in array
    int get_ord (int n) const {
        return perm[n];
    }
};
|
|
645
|
+
|
|
646
|
+
|
|
647
|
+
|
|
648
|
+
/** Array has n values. Sort the k first ones and copy the other ones
|
|
649
|
+
* into elements k..n-1
|
|
650
|
+
*/
|
|
651
|
+
template <class C>
|
|
652
|
+
void partial_sort (int k, int n,
|
|
653
|
+
const typename C::T * vals, typename C::TI * perm) {
|
|
654
|
+
// insert first k elts in heap
|
|
655
|
+
for (int i = 1; i < k; i++) {
|
|
656
|
+
indirect_heap_push<C> (i + 1, vals, perm, perm[i]);
|
|
657
|
+
}
|
|
658
|
+
|
|
659
|
+
// insert next n - k elts in heap
|
|
660
|
+
for (int i = k; i < n; i++) {
|
|
661
|
+
typename C::TI id = perm[i];
|
|
662
|
+
typename C::TI top = perm[0];
|
|
663
|
+
|
|
664
|
+
if (C::cmp(vals[top], vals[id])) {
|
|
665
|
+
indirect_heap_pop<C> (k, vals, perm);
|
|
666
|
+
indirect_heap_push<C> (k, vals, perm, id);
|
|
667
|
+
perm[i] = top;
|
|
668
|
+
} else {
|
|
669
|
+
// nothing, elt at i is good where it is.
|
|
670
|
+
}
|
|
671
|
+
}
|
|
672
|
+
|
|
673
|
+
// order the k first elements in heap
|
|
674
|
+
for (int i = k - 1; i > 0; i--) {
|
|
675
|
+
typename C::TI top = perm[0];
|
|
676
|
+
indirect_heap_pop<C> (i + 1, vals, perm);
|
|
677
|
+
perm[i] = top;
|
|
678
|
+
}
|
|
679
|
+
}
|
|
680
|
+
|
|
681
|
+
/** same as SortedArray, but only the k first elements are sorted */
|
|
682
|
+
template <typename T>
|
|
683
|
+
struct SemiSortedArray {
|
|
684
|
+
const T * x;
|
|
685
|
+
int N;
|
|
686
|
+
|
|
687
|
+
// type of the heap: CMax = sort ascending
|
|
688
|
+
typedef CMax<T, int> HC;
|
|
689
|
+
std::vector<int> perm;
|
|
690
|
+
|
|
691
|
+
int k; // k elements are sorted
|
|
692
|
+
|
|
693
|
+
int initial_k, k_factor;
|
|
694
|
+
|
|
695
|
+
explicit SemiSortedArray (int N) {
|
|
696
|
+
this->N = N;
|
|
697
|
+
perm.resize (N);
|
|
698
|
+
perm.resize (N);
|
|
699
|
+
initial_k = 3;
|
|
700
|
+
k_factor = 4;
|
|
701
|
+
}
|
|
702
|
+
|
|
703
|
+
void init (const T*x) {
|
|
704
|
+
this->x = x;
|
|
705
|
+
for (int n = 0; n < N; n++)
|
|
706
|
+
perm[n] = n;
|
|
707
|
+
k = 0;
|
|
708
|
+
grow (initial_k);
|
|
709
|
+
}
|
|
710
|
+
|
|
711
|
+
/// grow the sorted part of the array to size next_k
|
|
712
|
+
void grow (int next_k) {
|
|
713
|
+
if (next_k < N) {
|
|
714
|
+
partial_sort<HC> (next_k - k, N - k, x, &perm[k]);
|
|
715
|
+
k = next_k;
|
|
716
|
+
} else { // full sort of remainder of array
|
|
717
|
+
ArgSort<T> cmp = {x };
|
|
718
|
+
std::sort (perm.begin() + k, perm.end(), cmp);
|
|
719
|
+
k = N;
|
|
720
|
+
}
|
|
721
|
+
}
|
|
722
|
+
|
|
723
|
+
// get smallest value
|
|
724
|
+
T get_0 () {
|
|
725
|
+
return x[perm[0]];
|
|
726
|
+
}
|
|
727
|
+
|
|
728
|
+
// get delta between n-smallest and n-1 -smallest
|
|
729
|
+
T get_diff (int n) {
|
|
730
|
+
if (n >= k) {
|
|
731
|
+
// want to keep powers of 2 - 1
|
|
732
|
+
int next_k = (k + 1) * k_factor - 1;
|
|
733
|
+
grow (next_k);
|
|
734
|
+
}
|
|
735
|
+
return x[perm[n]] - x[perm[n - 1]];
|
|
736
|
+
}
|
|
737
|
+
|
|
738
|
+
// remap orders counted from smallest to indices in array
|
|
739
|
+
int get_ord (int n) {
|
|
740
|
+
assert (n < k);
|
|
741
|
+
return perm[n];
|
|
742
|
+
}
|
|
743
|
+
};
|
|
744
|
+
|
|
745
|
+
|
|
746
|
+
|
|
747
|
+
/*****************************************
|
|
748
|
+
* Find the k smallest sums of M terms, where each term is taken in a
|
|
749
|
+
* table x of n values.
|
|
750
|
+
*
|
|
751
|
+
* A combination of terms is encoded as a scalar 0 <= t < n^M. The
|
|
752
|
+
* combination t0 ... t(M-1) that correspond to the sum
|
|
753
|
+
*
|
|
754
|
+
* sum = x[0, t0] + x[1, t1] + .... + x[M-1, t(M-1)]
|
|
755
|
+
*
|
|
756
|
+
* is encoded as
|
|
757
|
+
*
|
|
758
|
+
* t = t0 + t1 * n + t2 * n^2 + ... + t(M-1) * n^(M-1)
|
|
759
|
+
*
|
|
760
|
+
* MinSumK is an object rather than a function, so that storage can be
|
|
761
|
+
* re-used over several computations with the same sizes. use_seen is
|
|
762
|
+
* good when there may be ties in the x array and it is a concern if
|
|
763
|
+
* occasionally several t's are returned.
|
|
764
|
+
*
|
|
765
|
+
* @param x size M * n, values to add up
|
|
766
|
+
* @parms k nb of results to retrieve
|
|
767
|
+
* @param M nb of terms
|
|
768
|
+
* @param n nb of distinct values
|
|
769
|
+
* @param sums output, size k, sorted
|
|
770
|
+
* @prarm terms output, size k, with encoding as above
|
|
771
|
+
*
|
|
772
|
+
******************************************/
|
|
773
|
+
template <typename T, class SSA, bool use_seen>
|
|
774
|
+
struct MinSumK {
|
|
775
|
+
int K; ///< nb of sums to return
|
|
776
|
+
int M; ///< nb of elements to sum up
|
|
777
|
+
int nbit; ///< nb of bits to encode one entry
|
|
778
|
+
int N; ///< nb of possible elements for each of the M terms
|
|
779
|
+
|
|
780
|
+
/** the heap.
|
|
781
|
+
* We use a heap to maintain a queue of sums, with the associated
|
|
782
|
+
* terms involved in the sum.
|
|
783
|
+
*/
|
|
784
|
+
typedef CMin<T, int64_t> HC;
|
|
785
|
+
size_t heap_capacity, heap_size;
|
|
786
|
+
T *bh_val;
|
|
787
|
+
int64_t *bh_ids;
|
|
788
|
+
|
|
789
|
+
std::vector <SSA> ssx;
|
|
790
|
+
|
|
791
|
+
// all results get pushed several times. When there are ties, they
|
|
792
|
+
// are popped interleaved with others, so it is not easy to
|
|
793
|
+
// identify them. Therefore, this bit array just marks elements
|
|
794
|
+
// that were seen before.
|
|
795
|
+
std::vector <uint8_t> seen;
|
|
796
|
+
|
|
797
|
+
MinSumK (int K, int M, int nbit, int N):
|
|
798
|
+
K(K), M(M), nbit(nbit), N(N) {
|
|
799
|
+
heap_capacity = K * M;
|
|
800
|
+
assert (N <= (1 << nbit));
|
|
801
|
+
|
|
802
|
+
// we'll do k steps, each step pushes at most M vals
|
|
803
|
+
bh_val = new T[heap_capacity];
|
|
804
|
+
bh_ids = new int64_t[heap_capacity];
|
|
805
|
+
|
|
806
|
+
if (use_seen) {
|
|
807
|
+
int64_t n_ids = weight(M);
|
|
808
|
+
seen.resize ((n_ids + 7) / 8);
|
|
809
|
+
}
|
|
810
|
+
|
|
811
|
+
for (int m = 0; m < M; m++)
|
|
812
|
+
ssx.push_back (SSA(N));
|
|
813
|
+
|
|
814
|
+
}
|
|
815
|
+
|
|
816
|
+
int64_t weight (int i) {
|
|
817
|
+
return 1 << (i * nbit);
|
|
818
|
+
}
|
|
819
|
+
|
|
820
|
+
bool is_seen (int64_t i) {
|
|
821
|
+
return (seen[i >> 3] >> (i & 7)) & 1;
|
|
822
|
+
}
|
|
823
|
+
|
|
824
|
+
void mark_seen (int64_t i) {
|
|
825
|
+
if (use_seen)
|
|
826
|
+
seen [i >> 3] |= 1 << (i & 7);
|
|
827
|
+
}
|
|
828
|
+
|
|
829
|
+
void run (const T *x, int64_t ldx,
|
|
830
|
+
T * sums, int64_t * terms) {
|
|
831
|
+
heap_size = 0;
|
|
832
|
+
|
|
833
|
+
for (int m = 0; m < M; m++) {
|
|
834
|
+
ssx[m].init(x);
|
|
835
|
+
x += ldx;
|
|
836
|
+
}
|
|
837
|
+
|
|
838
|
+
{ // intial result: take min for all elements
|
|
839
|
+
T sum = 0;
|
|
840
|
+
terms[0] = 0;
|
|
841
|
+
mark_seen (0);
|
|
842
|
+
for (int m = 0; m < M; m++) {
|
|
843
|
+
sum += ssx[m].get_0();
|
|
844
|
+
}
|
|
845
|
+
sums[0] = sum;
|
|
846
|
+
for (int m = 0; m < M; m++) {
|
|
847
|
+
heap_push<HC> (++heap_size, bh_val, bh_ids,
|
|
848
|
+
sum + ssx[m].get_diff(1),
|
|
849
|
+
weight(m));
|
|
850
|
+
}
|
|
851
|
+
}
|
|
852
|
+
|
|
853
|
+
for (int k = 1; k < K; k++) {
|
|
854
|
+
// pop smallest value from heap
|
|
855
|
+
if (use_seen) {// skip already seen elements
|
|
856
|
+
while (is_seen (bh_ids[0])) {
|
|
857
|
+
assert (heap_size > 0);
|
|
858
|
+
heap_pop<HC> (heap_size--, bh_val, bh_ids);
|
|
859
|
+
}
|
|
860
|
+
}
|
|
861
|
+
assert (heap_size > 0);
|
|
862
|
+
|
|
863
|
+
T sum = sums[k] = bh_val[0];
|
|
864
|
+
int64_t ti = terms[k] = bh_ids[0];
|
|
865
|
+
|
|
866
|
+
if (use_seen) {
|
|
867
|
+
mark_seen (ti);
|
|
868
|
+
heap_pop<HC> (heap_size--, bh_val, bh_ids);
|
|
869
|
+
} else {
|
|
870
|
+
do {
|
|
871
|
+
heap_pop<HC> (heap_size--, bh_val, bh_ids);
|
|
872
|
+
} while (heap_size > 0 && bh_ids[0] == ti);
|
|
873
|
+
}
|
|
874
|
+
|
|
875
|
+
// enqueue followers
|
|
876
|
+
int64_t ii = ti;
|
|
877
|
+
for (int m = 0; m < M; m++) {
|
|
878
|
+
int64_t n = ii & ((1L << nbit) - 1);
|
|
879
|
+
ii >>= nbit;
|
|
880
|
+
if (n + 1 >= N) continue;
|
|
881
|
+
|
|
882
|
+
enqueue_follower (ti, m, n, sum);
|
|
883
|
+
}
|
|
884
|
+
}
|
|
885
|
+
|
|
886
|
+
/*
|
|
887
|
+
for (int k = 0; k < K; k++)
|
|
888
|
+
for (int l = k + 1; l < K; l++)
|
|
889
|
+
assert (terms[k] != terms[l]);
|
|
890
|
+
*/
|
|
891
|
+
|
|
892
|
+
// convert indices by applying permutation
|
|
893
|
+
for (int k = 0; k < K; k++) {
|
|
894
|
+
int64_t ii = terms[k];
|
|
895
|
+
if (use_seen) {
|
|
896
|
+
// clear seen for reuse at next loop
|
|
897
|
+
seen[ii >> 3] = 0;
|
|
898
|
+
}
|
|
899
|
+
int64_t ti = 0;
|
|
900
|
+
for (int m = 0; m < M; m++) {
|
|
901
|
+
int64_t n = ii & ((1L << nbit) - 1);
|
|
902
|
+
ti += int64_t(ssx[m].get_ord(n)) << (nbit * m);
|
|
903
|
+
ii >>= nbit;
|
|
904
|
+
}
|
|
905
|
+
terms[k] = ti;
|
|
906
|
+
}
|
|
907
|
+
}
|
|
908
|
+
|
|
909
|
+
|
|
910
|
+
void enqueue_follower (int64_t ti, int m, int n, T sum) {
|
|
911
|
+
T next_sum = sum + ssx[m].get_diff(n + 1);
|
|
912
|
+
int64_t next_ti = ti + weight(m);
|
|
913
|
+
heap_push<HC> (++heap_size, bh_val, bh_ids, next_sum, next_ti);
|
|
914
|
+
}
|
|
915
|
+
|
|
916
|
+
~MinSumK () {
|
|
917
|
+
delete [] bh_ids;
|
|
918
|
+
delete [] bh_val;
|
|
919
|
+
}
|
|
920
|
+
};
|
|
921
|
+
|
|
922
|
+
} // anonymous namespace
|
|
923
|
+
|
|
924
|
+
|
|
925
|
+
/** Build an untrained multi-index quantizer.
 *
 * @param d      dimension handed to the base Index and to the product
 *               quantizer
 * @param M      forwarded to the product quantizer
 * @param nbits  forwarded to the product quantizer
 */
MultiIndexQuantizer::MultiIndexQuantizer (int d, size_t M, size_t nbits):
    Index(d, METRIC_L2),
    pq(d, M, nbits)
{
    // the product quantizer follows the verbosity of the index
    pq.verbose = verbose;
    // train() must be called before the index can be used
    is_trained = false;
}
|
|
933
|
+
|
|
934
|
+
|
|
935
|
+
|
|
936
|
+
/** Train the underlying product quantizer on n vectors, then set ntotal
 * to the number of virtual elements.
 *
 * @param n  number of training vectors
 * @param x  training vectors, passed unchanged to pq.train
 */
void MultiIndexQuantizer::train(idx_t n, const float *x)
{
    pq.verbose = verbose;
    pq.train (n, x);
    is_trained = true;

    // count virtual elements in index: one factor pq.ksub per
    // sub-quantizer
    ntotal = 1;
    for (size_t left = pq.M; left > 0; left--) {
        ntotal *= pq.ksub;
    }
}
|
|
946
|
+
|
|
947
|
+
|
|
948
|
+
/** Find, for each query, the k combinations of sub-centroids with the
 * smallest summed distance-table entries.
 *
 * @param n          number of query vectors
 * @param x          query vectors
 * @param k          number of results per query
 * @param distances  output, k values per query
 * @param labels     output, k values per query; each label packs one
 *                   sub-centroid number per sub-quantizer, pq.nbits
 *                   bits each
 */
void MultiIndexQuantizer::search (idx_t n, const float *x, idx_t k,
                                  float *distances, idx_t *labels) const {
    if (n == 0) return;

    // the allocation just below can be severe... so large query sets
    // are processed in slices of at most bs queries
    idx_t bs = 32768;
    if (n > bs) {
        idx_t i0 = 0;
        while (i0 < n) {
            idx_t i1 = std::min(i0 + bs, n);
            if (verbose) {
                printf("MultiIndexQuantizer::search: %ld:%ld / %ld\n",
                       i0, i1, n);
            }
            search (i1 - i0, x + i0 * d, k,
                    distances + i0 * k,
                    labels + i0 * k);
            i0 += bs;
        }
        return;
    }

    // one block of pq.ksub * pq.M table entries per query
    float * dis_tables = new float [n * pq.ksub * pq.M];
    ScopeDeleter<float> del (dis_tables);

    pq.compute_distance_tables (n, x, dis_tables);

    if (k != 1) {
        // general case: each thread owns a MinSumK object and reuses
        // it for all the queries it handles

#pragma omp parallel if(n > 1)
        {
            MinSumK <float, SemiSortedArray<float>, false>
                msk(k, pq.M, pq.nbits, pq.ksub);
#pragma omp for
            for (int i = 0; i < n; i++) {
                msk.run (dis_tables + i * pq.ksub * pq.M, pq.ksub,
                         distances + i * k, labels + i * k);
            }
        }
        return;
    }

    // k == 1: simple version that just finds the min in each table

#pragma omp parallel for
    for (int i = 0; i < n; i++) {
        const float * table = dis_tables + i * pq.ksub * pq.M;
        float total = 0;
        idx_t packed = 0;

        for (int s = 0; s < pq.M; s++) {
            float best = HUGE_VALF;
            idx_t best_j = -1;

            for (idx_t j = 0; j < pq.ksub; j++) {
                const float v = table[j];
                if (v < best) {
                    best = v;
                    best_j = j;
                }
            }
            total += best;
            packed |= best_j << (s * pq.nbits);
            table += pq.ksub;
        }

        distances [i] = total;
        labels [i] = packed;
    }
}
|
|
1018
|
+
|
|
1019
|
+
|
|
1020
|
+
/** Rebuild the vector designated by a packed key.
 *
 * The key holds one sub-centroid number per sub-quantizer, pq.nbits
 * bits each, lowest bits first; the matching centroids are copied one
 * after the other into recons.
 *
 * @param key     packed sub-centroid numbers
 * @param recons  output, receives pq.dsub floats per sub-quantizer
 */
void MultiIndexQuantizer::reconstruct (idx_t key, float * recons) const
{
    const int64_t sub_mask = (1L << pq.nbits) - 1;
    int64_t remaining = key;
    float *out = recons;

    for (int m = 0; m < pq.M; m++, out += pq.dsub) {
        const int64_t code = remaining & sub_mask;
        remaining >>= pq.nbits;
        memcpy(out, pq.get_centroids(m, code), sizeof(out[0]) * pq.dsub);
    }
}
|
|
1031
|
+
|
|
1032
|
+
/// Not supported: always raises through FAISS_THROW_MSG; both arguments
/// are ignored.
void MultiIndexQuantizer::add(idx_t /*n*/, const float* /*x*/)
{
    FAISS_THROW_MSG("This index has virtual elements, it does not support add");
}
|
|
1037
|
+
|
|
1038
|
+
void MultiIndexQuantizer::reset ()
|
|
1039
|
+
{
|
|
1040
|
+
FAISS_THROW_MSG ( "This index has virtual elements, "
|
|
1041
|
+
"it does not support reset");
|
|
1042
|
+
}
|
|
1043
|
+
|
|
1044
|
+
|
|
1045
|
+
|
|
1046
|
+
|
|
1047
|
+
|
|
1048
|
+
|
|
1049
|
+
|
|
1050
|
+
|
|
1051
|
+
|
|
1052
|
+
|
|
1053
|
+
/*****************************************
|
|
1054
|
+
* MultiIndexQuantizer2
|
|
1055
|
+
******************************************/
|
|
1056
|
+
|
|
1057
|
+
|
|
1058
|
+
|
|
1059
|
+
/** Build a quantizer that delegates the assignment of each sub-vector
 * to an externally provided index.
 *
 * @param d        dimension of the full vectors
 * @param M        number of sub-indexes read from indexes
 * @param nbits    forwarded to MultiIndexQuantizer
 * @param indexes  array of M sub-indexes, each of dimension pq.dsub;
 *                 the pointers are stored, and own_fields is set to false
 */
MultiIndexQuantizer2::MultiIndexQuantizer2 (int d, size_t M, size_t nbits,
                                            Index **indexes):
    MultiIndexQuantizer (d, M, nbits)
{
    assign_indexes.resize (M);

    size_t i = 0;
    while (i < M) {
        // every sub-index must work on sub-vectors of the PQ
        FAISS_THROW_IF_NOT_MSG(indexes[i]->d == pq.dsub,
                               "Provided sub-index has incorrect size");
        assign_indexes[i] = indexes[i];
        i++;
    }

    own_fields = false;
}
|
|
1073
|
+
|
|
1074
|
+
/** Two sub-quantizer variant of the constructor.
 *
 * @param d               dimension of the full vectors
 * @param nbits           forwarded to MultiIndexQuantizer
 * @param assign_index_0  index used for the first sub-vector
 * @param assign_index_1  index used for the second sub-vector
 *
 * Both indexes must have dimension pq.dsub; the pointers are stored,
 * and own_fields is set to false.
 */
MultiIndexQuantizer2::MultiIndexQuantizer2 (int d, size_t nbits,
                                            Index *assign_index_0,
                                            Index *assign_index_1):
    MultiIndexQuantizer (d, 2, nbits)
{
    FAISS_THROW_IF_NOT_MSG(assign_index_0->d == pq.dsub &&
                           assign_index_1->d == pq.dsub,
                           "Provided sub-index has incorrect size");

    assign_indexes.resize (2);
    assign_indexes [1] = assign_index_1;
    assign_indexes [0] = assign_index_0;

    own_fields = false;
}
|
|
1089
|
+
|
|
1090
|
+
/** Train as MultiIndexQuantizer, then hand the centroids of each
 * sub-quantizer to the matching assignment index.
 *
 * @param n  number of training vectors
 * @param x  training vectors
 */
void MultiIndexQuantizer2::train(idx_t n, const float* x)
{
    MultiIndexQuantizer::train(n, x);

    // add centroids to sub-indexes: pq.ksub of them per sub-quantizer,
    // starting from centroid 0
    int m = 0;
    while (m < pq.M) {
        assign_indexes[m]->add(pq.ksub, pq.get_centroids(m, 0));
        m++;
    }
}
|
|
1098
|
+
|
|
1099
|
+
|
|
1100
|
+
/** Same results layout as MultiIndexQuantizer::search, but the nearest
 * sub-centroids of every sub-vector are obtained from the assignment
 * indexes instead of full distance tables.
 *
 * @param n          number of query vectors
 * @param x          query vectors, d floats each
 * @param K          number of results per query
 * @param distances  output, K values per query
 * @param labels     output, K values per query, packed on pq.nbits
 *                   bits per sub-quantizer
 *
 * All offsets are computed with 64-bit integers: this function does not
 * slice the queries, so products such as i * k2 are not bounded by a
 * batch size and can exceed the range of int.
 */
void MultiIndexQuantizer2::search(
        idx_t n, const float* x, idx_t K,
        float* distances, idx_t* labels) const
{

    if (n == 0) return;

    // nb of candidates requested from each sub-index
    // NOTE(review): as in the original code, assumes idx_t is int64_t
    const int64_t k2 = std::min(K, int64_t(pq.ksub));

    const int64_t M = pq.M;
    const int64_t dsub = pq.dsub, ksub = pq.ksub;

    // size (M, n, k2)
    std::vector<idx_t> sub_ids(n * M * k2);
    std::vector<float> sub_dis(n * M * k2);
    std::vector<float> xsub(n * dsub);

    for (int64_t m = 0; m < M; m++) {
        // gather the m-th sub-vector of every query contiguously
        float *xdest = xsub.data();
        const float *xsrc = x + m * dsub;
        for (idx_t j = 0; j < n; j++) {
            memcpy(xdest, xsrc, dsub * sizeof(xdest[0]));
            xsrc += d;
            xdest += dsub;
        }

        assign_indexes[m]->search(
              n, xsub.data(), k2,
              &sub_dis[k2 * n * m],
              &sub_ids[k2 * n * m]);
    }

    if (K == 1) {
        // simple version that just finds the min in each table
        assert (k2 == 1);

        for (idx_t i = 0; i < n; i++) {
            float dis = 0;
            idx_t label = 0;

            for (int64_t m = 0; m < M; m++) {
                float vmin = sub_dis[i + m * n];
                idx_t lmin = sub_ids[i + m * n];
                dis += vmin;
                label |= lmin << (m * pq.nbits);
            }
            distances [i] = dis;
            labels [i] = label;
        }

    } else {

        // stride between the candidate lists of two sub-quantizers
        const int64_t ld_idmap = k2 * n;
        // mask that extracts one term from a packed label
        const int64_t mask1 = ksub - 1L;

#pragma omp parallel if(n > 1)
        {
            MinSumK <float, PreSortedArray<float>, false>
                msk(K, pq.M, pq.nbits, k2);
#pragma omp for
            for (idx_t i = 0; i < n; i++) {
                idx_t *li = labels + i * K;
                msk.run (&sub_dis[i * k2], ld_idmap,
                         distances + i * K, li);

                // remap ids: msk returns ranks in the candidate lists,
                // translate them to the ids given by the sub-indexes

                const idx_t *idmap0 = sub_ids.data() + i * k2;

                for (idx_t k = 0; k < K; k++) {
                    const idx_t *idmap = idmap0;
                    int64_t vin = li[k];
                    int64_t vout = 0;
                    int bs = 0;
                    for (int64_t m = 0; m < M; m++) {
                        int64_t s = vin & mask1;
                        vin >>= pq.nbits;
                        vout |= idmap[s] << bs;
                        bs += pq.nbits;
                        idmap += ld_idmap;
                    }
                    li[k] = vout;
                }
            }
        }
    }
}
|
|
1186
|
+
|
|
1187
|
+
|
|
1188
|
+
} // namespace faiss
|