faiss 0.2.0 → 0.2.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/LICENSE.txt +1 -1
- data/README.md +7 -7
- data/ext/faiss/extconf.rb +6 -3
- data/ext/faiss/numo.hpp +4 -4
- data/ext/faiss/utils.cpp +1 -1
- data/ext/faiss/utils.h +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +365 -194
- data/vendor/faiss/faiss/Clustering.h +102 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
- data/vendor/faiss/faiss/Index2Layer.h +22 -36
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
- data/vendor/faiss/faiss/IndexFlat.h +42 -59
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
- data/vendor/faiss/faiss/IndexIVF.h +169 -118
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
- data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
- data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
- data/vendor/faiss/faiss/IndexLSH.h +20 -38
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
- data/vendor/faiss/faiss/IndexPQ.h +64 -82
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
- data/vendor/faiss/faiss/IndexRefine.h +32 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
- data/vendor/faiss/faiss/VectorTransform.h +64 -89
- data/vendor/faiss/faiss/clone_index.cpp +78 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
- data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
- data/vendor/faiss/faiss/impl/io.cpp +76 -95
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +60 -29
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +619 -397
- data/vendor/faiss/faiss/index_factory.h +8 -6
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +305 -312
- data/vendor/faiss/faiss/utils/distances.h +170 -122
- data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +54 -49
- metadata +29 -4
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#pragma once
|
|
9
|
+
|
|
10
|
+
#include <cstdint>
|
|
11
|
+
#include <vector>
|
|
12
|
+
|
|
13
|
+
#include <faiss/Index.h>
|
|
14
|
+
#include <faiss/IndexFlat.h>
|
|
15
|
+
|
|
16
|
+
namespace faiss {
|
|
17
|
+
|
|
18
|
+
/** Abstract structure for additive quantizers
|
|
19
|
+
*
|
|
20
|
+
* Different from the product quantizer in which the decoded vector is the
|
|
21
|
+
* concatenation of M sub-vectors, additive quantizers sum M sub-vectors
|
|
22
|
+
* to get the decoded vector.
|
|
23
|
+
*/
|
|
24
|
+
struct AdditiveQuantizer {
|
|
25
|
+
size_t d; ///< size of the input vectors
|
|
26
|
+
size_t M; ///< number of codebooks
|
|
27
|
+
std::vector<size_t> nbits; ///< bits for each step
|
|
28
|
+
std::vector<float> codebooks; ///< codebooks
|
|
29
|
+
|
|
30
|
+
// derived values
|
|
31
|
+
std::vector<uint64_t> codebook_offsets;
|
|
32
|
+
size_t code_size; ///< code size in bytes
|
|
33
|
+
size_t tot_bits; ///< total number of bits
|
|
34
|
+
size_t total_codebook_size; ///< size of the codebook in vectors
|
|
35
|
+
bool only_8bit; ///< are all nbits = 8 (use faster decoder)
|
|
36
|
+
|
|
37
|
+
bool verbose; ///< verbose during training?
|
|
38
|
+
bool is_trained; ///< is trained or not
|
|
39
|
+
|
|
40
|
+
IndexFlat1D qnorm; ///< store and search norms
|
|
41
|
+
|
|
42
|
+
uint32_t encode_qcint(
|
|
43
|
+
float x) const; ///< encode norm by non-uniform scalar quantization
|
|
44
|
+
|
|
45
|
+
float decode_qcint(uint32_t c)
|
|
46
|
+
const; ///< decode norm by non-uniform scalar quantization
|
|
47
|
+
|
|
48
|
+
/// Encodes how search is performed and how vectors are encoded
|
|
49
|
+
enum Search_type_t {
|
|
50
|
+
ST_decompress, ///< decompress database vector
|
|
51
|
+
ST_LUT_nonorm, ///< use a LUT, don't include norms (OK for IP or
|
|
52
|
+
///< normalized vectors)
|
|
53
|
+
ST_norm_from_LUT, ///< compute the norms from the look-up tables (cost
|
|
54
|
+
///< is in O(M^2))
|
|
55
|
+
ST_norm_float, ///< use a LUT, and store float32 norm with the vectors
|
|
56
|
+
ST_norm_qint8, ///< use a LUT, and store 8bit-quantized norm
|
|
57
|
+
ST_norm_qint4,
|
|
58
|
+
ST_norm_cqint8, ///< use a LUT, and store non-uniform quantized norm
|
|
59
|
+
ST_norm_cqint4,
|
|
60
|
+
};
|
|
61
|
+
|
|
62
|
+
AdditiveQuantizer(
|
|
63
|
+
size_t d,
|
|
64
|
+
const std::vector<size_t>& nbits,
|
|
65
|
+
Search_type_t search_type = ST_decompress);
|
|
66
|
+
|
|
67
|
+
AdditiveQuantizer();
|
|
68
|
+
|
|
69
|
+
/// compute derived values when d, M and nbits have been set
|
|
70
|
+
void set_derived_values();
|
|
71
|
+
|
|
72
|
+
/// Train the additive quantizer
|
|
73
|
+
virtual void train(size_t n, const float* x) = 0;
|
|
74
|
+
|
|
75
|
+
/** Encode a set of vectors
|
|
76
|
+
*
|
|
77
|
+
* @param x vectors to encode, size n * d
|
|
78
|
+
* @param codes output codes, size n * code_size
|
|
79
|
+
*/
|
|
80
|
+
virtual void compute_codes(const float* x, uint8_t* codes, size_t n)
|
|
81
|
+
const = 0;
|
|
82
|
+
|
|
83
|
+
/** pack a series of codes into bit-compact format
|
|
84
|
+
*
|
|
85
|
+
* @param codes codes to be packed, size n * code_size
|
|
86
|
+
* @param packed_codes output bit-compact codes
|
|
87
|
+
* @param ld_codes leading dimension of codes
|
|
88
|
+
* @param norms norms of the vectors (size n). Will be computed if
|
|
89
|
+
* needed but not provided
|
|
90
|
+
*/
|
|
91
|
+
void pack_codes(
|
|
92
|
+
size_t n,
|
|
93
|
+
const int32_t* codes,
|
|
94
|
+
uint8_t* packed_codes,
|
|
95
|
+
int64_t ld_codes = -1,
|
|
96
|
+
const float* norms = nullptr) const;
|
|
97
|
+
|
|
98
|
+
/** Decode a set of vectors
|
|
99
|
+
*
|
|
100
|
+
* @param codes codes to decode, size n * code_size
|
|
101
|
+
* @param x output vectors, size n * d
|
|
102
|
+
*/
|
|
103
|
+
void decode(const uint8_t* codes, float* x, size_t n) const;
|
|
104
|
+
|
|
105
|
+
/** Decode a set of vectors in non-packed format
|
|
106
|
+
*
|
|
107
|
+
* @param codes codes to decode, size n * ld_codes
|
|
108
|
+
* @param x output vectors, size n * d
|
|
109
|
+
*/
|
|
110
|
+
void decode_unpacked(
|
|
111
|
+
const int32_t* codes,
|
|
112
|
+
float* x,
|
|
113
|
+
size_t n,
|
|
114
|
+
int64_t ld_codes = -1) const;
|
|
115
|
+
|
|
116
|
+
/****************************************************************************
|
|
117
|
+
* Search functions in an external set of codes.
|
|
118
|
+
****************************************************************************/
|
|
119
|
+
|
|
120
|
+
/// Also determines what's in the codes
|
|
121
|
+
Search_type_t search_type;
|
|
122
|
+
|
|
123
|
+
/// min/max for quantization of norms
|
|
124
|
+
float norm_min, norm_max;
|
|
125
|
+
|
|
126
|
+
template <bool is_IP, Search_type_t effective_search_type>
|
|
127
|
+
float compute_1_distance_LUT(const uint8_t* codes, const float* LUT) const;
|
|
128
|
+
|
|
129
|
+
/*
|
|
130
|
+
float compute_1_L2sqr(const uint8_t* codes, const float* LUT);
|
|
131
|
+
*/
|
|
132
|
+
/****************************************************************************
|
|
133
|
+
* Support for exhaustive distance computations with all the centroids.
|
|
134
|
+
* Hence, the number of these centroids should not be too large.
|
|
135
|
+
****************************************************************************/
|
|
136
|
+
using idx_t = Index::idx_t;
|
|
137
|
+
|
|
138
|
+
/// decoding function for a code in a 64-bit word
|
|
139
|
+
void decode_64bit(idx_t n, float* x) const;
|
|
140
|
+
|
|
141
|
+
/** Compute inner-product look-up tables. Used in the centroid search
|
|
142
|
+
* functions.
|
|
143
|
+
*
|
|
144
|
+
* @param xq query vector, size (n, d)
|
|
145
|
+
* @param LUT look-up table, size (n, total_codebook_size)
|
|
146
|
+
*/
|
|
147
|
+
void compute_LUT(size_t n, const float* xq, float* LUT) const;
|
|
148
|
+
|
|
149
|
+
/// exact IP search
|
|
150
|
+
void knn_centroids_inner_product(
|
|
151
|
+
idx_t n,
|
|
152
|
+
const float* xq,
|
|
153
|
+
idx_t k,
|
|
154
|
+
float* distances,
|
|
155
|
+
idx_t* labels) const;
|
|
156
|
+
|
|
157
|
+
/** For L2 search we need the L2 norms of the centroids
|
|
158
|
+
*
|
|
159
|
+
* @param norms output norms table, size total_codebook_size
|
|
160
|
+
*/
|
|
161
|
+
void compute_centroid_norms(float* norms) const;
|
|
162
|
+
|
|
163
|
+
/** Exact L2 search, with precomputed norms */
|
|
164
|
+
void knn_centroids_L2(
|
|
165
|
+
idx_t n,
|
|
166
|
+
const float* xq,
|
|
167
|
+
idx_t k,
|
|
168
|
+
float* distances,
|
|
169
|
+
idx_t* labels,
|
|
170
|
+
const float* centroid_norms) const;
|
|
171
|
+
|
|
172
|
+
virtual ~AdditiveQuantizer();
|
|
173
|
+
};
|
|
174
|
+
|
|
175
|
+
}; // namespace faiss
|
|
@@ -14,18 +14,16 @@
|
|
|
14
14
|
|
|
15
15
|
#include <faiss/impl/FaissAssert.h>
|
|
16
16
|
|
|
17
|
-
|
|
18
17
|
namespace faiss {
|
|
19
18
|
|
|
20
|
-
|
|
21
19
|
/***********************************************************************
|
|
22
20
|
* RangeSearchResult
|
|
23
21
|
***********************************************************************/
|
|
24
22
|
|
|
25
|
-
RangeSearchResult::RangeSearchResult
|
|
23
|
+
RangeSearchResult::RangeSearchResult(idx_t nq, bool alloc_lims) : nq(nq) {
|
|
26
24
|
if (alloc_lims) {
|
|
27
|
-
lims = new size_t
|
|
28
|
-
memset
|
|
25
|
+
lims = new size_t[nq + 1];
|
|
26
|
+
memset(lims, 0, sizeof(*lims) * (nq + 1));
|
|
29
27
|
} else {
|
|
30
28
|
lims = nullptr;
|
|
31
29
|
}
|
|
@@ -36,145 +34,129 @@ RangeSearchResult::RangeSearchResult (idx_t nq, bool alloc_lims): nq (nq) {
|
|
|
36
34
|
|
|
37
35
|
/// called when lims contains the number of result entries
|
|
38
36
|
/// for each query
|
|
39
|
-
void RangeSearchResult::do_allocation
|
|
37
|
+
void RangeSearchResult::do_allocation() {
|
|
38
|
+
// works only if all the partial results are aggregated
|
|
39
|
+
// simultaneously
|
|
40
|
+
FAISS_THROW_IF_NOT(labels == nullptr && distances == nullptr);
|
|
40
41
|
size_t ofs = 0;
|
|
41
42
|
for (int i = 0; i < nq; i++) {
|
|
42
43
|
size_t n = lims[i];
|
|
43
|
-
lims
|
|
44
|
+
lims[i] = ofs;
|
|
44
45
|
ofs += n;
|
|
45
46
|
}
|
|
46
|
-
lims
|
|
47
|
-
labels = new idx_t
|
|
48
|
-
distances = new float
|
|
47
|
+
lims[nq] = ofs;
|
|
48
|
+
labels = new idx_t[ofs];
|
|
49
|
+
distances = new float[ofs];
|
|
49
50
|
}
|
|
50
51
|
|
|
51
|
-
RangeSearchResult::~RangeSearchResult
|
|
52
|
-
delete
|
|
53
|
-
delete
|
|
54
|
-
delete
|
|
52
|
+
RangeSearchResult::~RangeSearchResult() {
|
|
53
|
+
delete[] labels;
|
|
54
|
+
delete[] distances;
|
|
55
|
+
delete[] lims;
|
|
55
56
|
}
|
|
56
57
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
58
|
/***********************************************************************
|
|
62
59
|
* BufferList
|
|
63
60
|
***********************************************************************/
|
|
64
61
|
|
|
65
|
-
|
|
66
|
-
BufferList::BufferList (size_t buffer_size):
|
|
67
|
-
buffer_size (buffer_size)
|
|
68
|
-
{
|
|
62
|
+
BufferList::BufferList(size_t buffer_size) : buffer_size(buffer_size) {
|
|
69
63
|
wp = buffer_size;
|
|
70
64
|
}
|
|
71
65
|
|
|
72
|
-
BufferList::~BufferList
|
|
73
|
-
{
|
|
66
|
+
BufferList::~BufferList() {
|
|
74
67
|
for (int i = 0; i < buffers.size(); i++) {
|
|
75
|
-
delete
|
|
76
|
-
delete
|
|
68
|
+
delete[] buffers[i].ids;
|
|
69
|
+
delete[] buffers[i].dis;
|
|
77
70
|
}
|
|
78
71
|
}
|
|
79
72
|
|
|
80
|
-
void BufferList::add
|
|
73
|
+
void BufferList::add(idx_t id, float dis) {
|
|
81
74
|
if (wp == buffer_size) { // need new buffer
|
|
82
75
|
append_buffer();
|
|
83
76
|
}
|
|
84
|
-
Buffer
|
|
85
|
-
buf.ids
|
|
86
|
-
buf.dis
|
|
77
|
+
Buffer& buf = buffers.back();
|
|
78
|
+
buf.ids[wp] = id;
|
|
79
|
+
buf.dis[wp] = dis;
|
|
87
80
|
wp++;
|
|
88
81
|
}
|
|
89
82
|
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
Buffer buf = {new idx_t [buffer_size], new float [buffer_size]};
|
|
94
|
-
buffers.push_back (buf);
|
|
83
|
+
void BufferList::append_buffer() {
|
|
84
|
+
Buffer buf = {new idx_t[buffer_size], new float[buffer_size]};
|
|
85
|
+
buffers.push_back(buf);
|
|
95
86
|
wp = 0;
|
|
96
87
|
}
|
|
97
88
|
|
|
98
89
|
/// copy elements ofs:ofs+n-1 seen as linear data in the buffers to
|
|
99
90
|
/// tables dest_ids, dest_dis
|
|
100
|
-
void BufferList::copy_range
|
|
101
|
-
|
|
102
|
-
|
|
91
|
+
void BufferList::copy_range(
|
|
92
|
+
size_t ofs,
|
|
93
|
+
size_t n,
|
|
94
|
+
idx_t* dest_ids,
|
|
95
|
+
float* dest_dis) {
|
|
103
96
|
size_t bno = ofs / buffer_size;
|
|
104
97
|
ofs -= bno * buffer_size;
|
|
105
98
|
while (n > 0) {
|
|
106
99
|
size_t ncopy = ofs + n < buffer_size ? n : buffer_size - ofs;
|
|
107
|
-
Buffer buf = buffers
|
|
108
|
-
memcpy
|
|
109
|
-
memcpy
|
|
100
|
+
Buffer buf = buffers[bno];
|
|
101
|
+
memcpy(dest_ids, buf.ids + ofs, ncopy * sizeof(*dest_ids));
|
|
102
|
+
memcpy(dest_dis, buf.dis + ofs, ncopy * sizeof(*dest_dis));
|
|
110
103
|
dest_ids += ncopy;
|
|
111
104
|
dest_dis += ncopy;
|
|
112
105
|
ofs = 0;
|
|
113
|
-
bno
|
|
106
|
+
bno++;
|
|
114
107
|
n -= ncopy;
|
|
115
108
|
}
|
|
116
109
|
}
|
|
117
110
|
|
|
118
|
-
|
|
119
111
|
/***********************************************************************
|
|
120
112
|
* RangeSearchPartialResult
|
|
121
113
|
***********************************************************************/
|
|
122
114
|
|
|
123
|
-
void RangeQueryResult::add
|
|
115
|
+
void RangeQueryResult::add(float dis, idx_t id) {
|
|
124
116
|
nres++;
|
|
125
|
-
pres->add
|
|
117
|
+
pres->add(id, dis);
|
|
126
118
|
}
|
|
127
119
|
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
RangeSearchPartialResult::RangeSearchPartialResult (RangeSearchResult * res_in):
|
|
131
|
-
BufferList(res_in->buffer_size),
|
|
132
|
-
res(res_in)
|
|
133
|
-
{}
|
|
134
|
-
|
|
120
|
+
RangeSearchPartialResult::RangeSearchPartialResult(RangeSearchResult* res_in)
|
|
121
|
+
: BufferList(res_in->buffer_size), res(res_in) {}
|
|
135
122
|
|
|
136
123
|
/// begin a new result
|
|
137
|
-
RangeQueryResult
|
|
138
|
-
RangeSearchPartialResult::new_result (idx_t qno)
|
|
139
|
-
{
|
|
124
|
+
RangeQueryResult& RangeSearchPartialResult::new_result(idx_t qno) {
|
|
140
125
|
RangeQueryResult qres = {qno, 0, this};
|
|
141
|
-
queries.push_back
|
|
126
|
+
queries.push_back(qres);
|
|
142
127
|
return queries.back();
|
|
143
128
|
}
|
|
144
129
|
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
{
|
|
148
|
-
set_lims ();
|
|
130
|
+
void RangeSearchPartialResult::finalize() {
|
|
131
|
+
set_lims();
|
|
149
132
|
#pragma omp barrier
|
|
150
133
|
|
|
151
134
|
#pragma omp single
|
|
152
|
-
res->do_allocation
|
|
135
|
+
res->do_allocation();
|
|
153
136
|
|
|
154
137
|
#pragma omp barrier
|
|
155
|
-
copy_result
|
|
138
|
+
copy_result();
|
|
156
139
|
}
|
|
157
140
|
|
|
158
|
-
|
|
159
141
|
/// called by range_search before do_allocation
|
|
160
|
-
void RangeSearchPartialResult::set_lims
|
|
161
|
-
{
|
|
142
|
+
void RangeSearchPartialResult::set_lims() {
|
|
162
143
|
for (int i = 0; i < queries.size(); i++) {
|
|
163
|
-
RangeQueryResult
|
|
144
|
+
RangeQueryResult& qres = queries[i];
|
|
164
145
|
res->lims[qres.qno] = qres.nres;
|
|
165
146
|
}
|
|
166
147
|
}
|
|
167
148
|
|
|
168
149
|
/// called by range_search after do_allocation
|
|
169
|
-
void RangeSearchPartialResult::copy_result
|
|
170
|
-
{
|
|
150
|
+
void RangeSearchPartialResult::copy_result(bool incremental) {
|
|
171
151
|
size_t ofs = 0;
|
|
172
152
|
for (int i = 0; i < queries.size(); i++) {
|
|
173
|
-
RangeQueryResult
|
|
153
|
+
RangeQueryResult& qres = queries[i];
|
|
174
154
|
|
|
175
|
-
copy_range
|
|
176
|
-
|
|
177
|
-
|
|
155
|
+
copy_range(
|
|
156
|
+
ofs,
|
|
157
|
+
qres.nres,
|
|
158
|
+
res->labels + res->lims[qres.qno],
|
|
159
|
+
res->distances + res->lims[qres.qno]);
|
|
178
160
|
if (incremental) {
|
|
179
161
|
res->lims[qres.qno] += qres.nres;
|
|
180
162
|
}
|
|
@@ -182,26 +164,28 @@ void RangeSearchPartialResult::copy_result (bool incremental)
|
|
|
182
164
|
}
|
|
183
165
|
}
|
|
184
166
|
|
|
185
|
-
void RangeSearchPartialResult::merge
|
|
186
|
-
|
|
187
|
-
{
|
|
188
|
-
|
|
167
|
+
void RangeSearchPartialResult::merge(
|
|
168
|
+
std::vector<RangeSearchPartialResult*>& partial_results,
|
|
169
|
+
bool do_delete) {
|
|
189
170
|
int npres = partial_results.size();
|
|
190
|
-
if (npres == 0)
|
|
191
|
-
|
|
171
|
+
if (npres == 0)
|
|
172
|
+
return;
|
|
173
|
+
RangeSearchResult* result = partial_results[0]->res;
|
|
192
174
|
size_t nx = result->nq;
|
|
193
175
|
|
|
194
176
|
// count
|
|
195
|
-
for (const RangeSearchPartialResult
|
|
196
|
-
if (!pres)
|
|
197
|
-
|
|
177
|
+
for (const RangeSearchPartialResult* pres : partial_results) {
|
|
178
|
+
if (!pres)
|
|
179
|
+
continue;
|
|
180
|
+
for (const RangeQueryResult& qres : pres->queries) {
|
|
198
181
|
result->lims[qres.qno] += qres.nres;
|
|
199
182
|
}
|
|
200
183
|
}
|
|
201
|
-
result->do_allocation
|
|
184
|
+
result->do_allocation();
|
|
202
185
|
for (int j = 0; j < npres; j++) {
|
|
203
|
-
if (!partial_results[j])
|
|
204
|
-
|
|
186
|
+
if (!partial_results[j])
|
|
187
|
+
continue;
|
|
188
|
+
partial_results[j]->copy_result(true);
|
|
205
189
|
if (do_delete) {
|
|
206
190
|
delete partial_results[j];
|
|
207
191
|
partial_results[j] = nullptr;
|
|
@@ -210,22 +194,19 @@ void RangeSearchPartialResult::merge (std::vector <RangeSearchPartialResult *> &
|
|
|
210
194
|
|
|
211
195
|
// reset the limits
|
|
212
196
|
for (size_t i = nx; i > 0; i--) {
|
|
213
|
-
result->lims
|
|
197
|
+
result->lims[i] = result->lims[i - 1];
|
|
214
198
|
}
|
|
215
|
-
result->lims
|
|
199
|
+
result->lims[0] = 0;
|
|
216
200
|
}
|
|
217
201
|
|
|
218
202
|
/***********************************************************************
|
|
219
203
|
* IDSelectorRange
|
|
220
204
|
***********************************************************************/
|
|
221
205
|
|
|
222
|
-
IDSelectorRange::IDSelectorRange
|
|
223
|
-
|
|
224
|
-
{
|
|
225
|
-
}
|
|
206
|
+
IDSelectorRange::IDSelectorRange(idx_t imin, idx_t imax)
|
|
207
|
+
: imin(imin), imax(imax) {}
|
|
226
208
|
|
|
227
|
-
bool IDSelectorRange::is_member
|
|
228
|
-
{
|
|
209
|
+
bool IDSelectorRange::is_member(idx_t id) const {
|
|
229
210
|
return id >= imin && id < imax;
|
|
230
211
|
}
|
|
231
212
|
|
|
@@ -233,33 +214,29 @@ bool IDSelectorRange::is_member (idx_t id) const
|
|
|
233
214
|
* IDSelectorArray
|
|
234
215
|
***********************************************************************/
|
|
235
216
|
|
|
236
|
-
IDSelectorArray::IDSelectorArray
|
|
237
|
-
n (n), ids(ids)
|
|
238
|
-
{
|
|
239
|
-
}
|
|
217
|
+
IDSelectorArray::IDSelectorArray(size_t n, const idx_t* ids) : n(n), ids(ids) {}
|
|
240
218
|
|
|
241
|
-
bool IDSelectorArray::is_member
|
|
242
|
-
{
|
|
219
|
+
bool IDSelectorArray::is_member(idx_t id) const {
|
|
243
220
|
for (idx_t i = 0; i < n; i++) {
|
|
244
|
-
if (ids[i] == id)
|
|
221
|
+
if (ids[i] == id)
|
|
222
|
+
return true;
|
|
245
223
|
}
|
|
246
224
|
return false;
|
|
247
225
|
}
|
|
248
226
|
|
|
249
|
-
|
|
250
227
|
/***********************************************************************
|
|
251
228
|
* IDSelectorBatch
|
|
252
229
|
***********************************************************************/
|
|
253
230
|
|
|
254
|
-
IDSelectorBatch::IDSelectorBatch
|
|
255
|
-
{
|
|
231
|
+
IDSelectorBatch::IDSelectorBatch(size_t n, const idx_t* indices) {
|
|
256
232
|
nbits = 0;
|
|
257
|
-
while (n > (1L << nbits))
|
|
233
|
+
while (n > (1L << nbits))
|
|
234
|
+
nbits++;
|
|
258
235
|
nbits += 5;
|
|
259
236
|
// for n = 1M, nbits = 25 is optimal, see P56659518
|
|
260
237
|
|
|
261
238
|
mask = (1L << nbits) - 1;
|
|
262
|
-
bloom.resize
|
|
239
|
+
bloom.resize(1UL << (nbits - 3), 0);
|
|
263
240
|
for (long i = 0; i < n; i++) {
|
|
264
241
|
Index::idx_t id = indices[i];
|
|
265
242
|
set.insert(id);
|
|
@@ -268,39 +245,36 @@ IDSelectorBatch::IDSelectorBatch (size_t n, const idx_t *indices)
|
|
|
268
245
|
}
|
|
269
246
|
}
|
|
270
247
|
|
|
271
|
-
bool IDSelectorBatch::is_member
|
|
272
|
-
{
|
|
248
|
+
bool IDSelectorBatch::is_member(idx_t i) const {
|
|
273
249
|
long im = i & mask;
|
|
274
|
-
if(!(bloom[im>>3] & (1 << (im & 7)))) {
|
|
250
|
+
if (!(bloom[im >> 3] & (1 << (im & 7)))) {
|
|
275
251
|
return 0;
|
|
276
252
|
}
|
|
277
253
|
return set.count(i);
|
|
278
254
|
}
|
|
279
255
|
|
|
280
|
-
|
|
281
256
|
/***********************************************************
|
|
282
257
|
* Interrupt callback
|
|
283
258
|
***********************************************************/
|
|
284
259
|
|
|
285
|
-
|
|
286
260
|
std::unique_ptr<InterruptCallback> InterruptCallback::instance;
|
|
287
261
|
|
|
288
262
|
std::mutex InterruptCallback::lock;
|
|
289
263
|
|
|
290
|
-
void InterruptCallback::clear_instance
|
|
291
|
-
delete instance.release
|
|
264
|
+
void InterruptCallback::clear_instance() {
|
|
265
|
+
delete instance.release();
|
|
292
266
|
}
|
|
293
267
|
|
|
294
|
-
void InterruptCallback::check
|
|
268
|
+
void InterruptCallback::check() {
|
|
295
269
|
if (!instance.get()) {
|
|
296
270
|
return;
|
|
297
271
|
}
|
|
298
|
-
if (instance->want_interrupt
|
|
299
|
-
FAISS_THROW_MSG
|
|
272
|
+
if (instance->want_interrupt()) {
|
|
273
|
+
FAISS_THROW_MSG("computation interrupted");
|
|
300
274
|
}
|
|
301
275
|
}
|
|
302
276
|
|
|
303
|
-
bool InterruptCallback::is_interrupted
|
|
277
|
+
bool InterruptCallback::is_interrupted() {
|
|
304
278
|
if (!instance.get()) {
|
|
305
279
|
return false;
|
|
306
280
|
}
|
|
@@ -308,8 +282,7 @@ bool InterruptCallback::is_interrupted () {
|
|
|
308
282
|
return instance->want_interrupt();
|
|
309
283
|
}
|
|
310
284
|
|
|
311
|
-
|
|
312
|
-
size_t InterruptCallback::get_period_hint (size_t flops) {
|
|
285
|
+
size_t InterruptCallback::get_period_hint(size_t flops) {
|
|
313
286
|
if (!instance.get()) {
|
|
314
287
|
return 1L << 30; // never check
|
|
315
288
|
}
|
|
@@ -317,7 +290,4 @@ size_t InterruptCallback::get_period_hint (size_t flops) {
|
|
|
317
290
|
return std::max((size_t)10 * 10 * 1000 * 1000 / (flops + 1), (size_t)1);
|
|
318
291
|
}
|
|
319
292
|
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
293
|
} // namespace faiss
|