faiss 0.2.7 → 0.3.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +0 -1
- data/vendor/faiss/faiss/Clustering.cpp +4 -18
- data/vendor/faiss/faiss/Clustering.h +31 -21
- data/vendor/faiss/faiss/IVFlib.cpp +22 -11
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +20 -5
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
- data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
- data/vendor/faiss/faiss/IndexHNSW.h +12 -48
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
- data/vendor/faiss/faiss/IndexIVF.h +37 -5
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
- data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +10 -10
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
- data/vendor/faiss/faiss/impl/HNSW.h +9 -8
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
- data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
- data/vendor/faiss/faiss/impl/io.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
- data/vendor/faiss/faiss/index_factory.cpp +10 -7
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
- data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
- data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/distances.cpp +128 -74
- data/vendor/faiss/faiss/utils/distances.h +81 -4
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/utils.cpp +112 -6
- data/vendor/faiss/faiss/utils/utils.h +57 -20
- metadata +11 -4
@@ -72,6 +72,8 @@ struct SearchParametersIVF : SearchParameters {
|
|
72
72
|
size_t nprobe = 1; ///< number of probes at query time
|
73
73
|
size_t max_codes = 0; ///< max nb of codes to visit to do a query
|
74
74
|
SearchParameters* quantizer_params = nullptr;
|
75
|
+
/// context object to pass to InvertedLists
|
76
|
+
void* inverted_list_context = nullptr;
|
75
77
|
|
76
78
|
virtual ~SearchParametersIVF() {}
|
77
79
|
};
|
@@ -177,6 +179,7 @@ struct IndexIVF : Index, IndexIVFInterface {
|
|
177
179
|
bool own_invlists = false;
|
178
180
|
|
179
181
|
size_t code_size = 0; ///< code size per vector in bytes
|
182
|
+
|
180
183
|
/** Parallel mode determines how queries are parallelized with OpenMP
|
181
184
|
*
|
182
185
|
* 0 (default): split over queries
|
@@ -194,6 +197,10 @@ struct IndexIVF : Index, IndexIVFInterface {
|
|
194
197
|
* enables reconstruct() */
|
195
198
|
DirectMap direct_map;
|
196
199
|
|
200
|
+
/// do the codes in the invlists encode the vectors relative to the
|
201
|
+
/// centroids?
|
202
|
+
bool by_residual = true;
|
203
|
+
|
197
204
|
/** The Inverted file takes a quantizer (an Index) on input,
|
198
205
|
* which implements the function mapping a vector to a list
|
199
206
|
* identifier.
|
@@ -207,7 +214,7 @@ struct IndexIVF : Index, IndexIVFInterface {
|
|
207
214
|
|
208
215
|
void reset() override;
|
209
216
|
|
210
|
-
/// Trains the quantizer and calls
|
217
|
+
/// Trains the quantizer and calls train_encoder to train sub-quantizers
|
211
218
|
void train(idx_t n, const float* x) override;
|
212
219
|
|
213
220
|
/// Calls add_with_ids with NULL ids
|
@@ -227,7 +234,8 @@ struct IndexIVF : Index, IndexIVFInterface {
|
|
227
234
|
idx_t n,
|
228
235
|
const float* x,
|
229
236
|
const idx_t* xids,
|
230
|
-
const idx_t* precomputed_idx
|
237
|
+
const idx_t* precomputed_idx,
|
238
|
+
void* inverted_list_context = nullptr);
|
231
239
|
|
232
240
|
/** Encodes a set of vectors as they would appear in the inverted lists
|
233
241
|
*
|
@@ -252,9 +260,15 @@ struct IndexIVF : Index, IndexIVFInterface {
|
|
252
260
|
*/
|
253
261
|
void add_sa_codes(idx_t n, const uint8_t* codes, const idx_t* xids);
|
254
262
|
|
255
|
-
|
256
|
-
|
257
|
-
|
263
|
+
/** Train the encoder for the vectors.
|
264
|
+
*
|
265
|
+
* If by_residual then it is called with residuals and corresponding assign
|
266
|
+
* array, otherwise x is the raw training vectors and assign=nullptr */
|
267
|
+
virtual void train_encoder(idx_t n, const float* x, const idx_t* assign);
|
268
|
+
|
269
|
+
/// can be redefined by subclasses to indicate how many training vectors
|
270
|
+
/// they need
|
271
|
+
virtual idx_t train_encoder_num_vectors() const;
|
258
272
|
|
259
273
|
void search_preassigned(
|
260
274
|
idx_t n,
|
@@ -346,6 +360,24 @@ struct IndexIVF : Index, IndexIVFInterface {
|
|
346
360
|
float* recons,
|
347
361
|
const SearchParameters* params = nullptr) const override;
|
348
362
|
|
363
|
+
/** Similar to search, but also returns the codes corresponding to the
|
364
|
+
* stored vectors for the search results.
|
365
|
+
*
|
366
|
+
* @param codes codes (n, k, code_size)
|
367
|
+
* @param include_listno
|
368
|
+
* include the list ids in the code (in this case add
|
369
|
+
* ceil(log8(nlist)) to the code size)
|
370
|
+
*/
|
371
|
+
void search_and_return_codes(
|
372
|
+
idx_t n,
|
373
|
+
const float* x,
|
374
|
+
idx_t k,
|
375
|
+
float* distances,
|
376
|
+
idx_t* labels,
|
377
|
+
uint8_t* recons,
|
378
|
+
bool include_listno = false,
|
379
|
+
const SearchParameters* params = nullptr) const;
|
380
|
+
|
349
381
|
/** Reconstruct a vector given the location in terms of (inv list index +
|
350
382
|
* inv list offset) instead of the id.
|
351
383
|
*
|
@@ -37,30 +37,20 @@ IndexIVFAdditiveQuantizer::IndexIVFAdditiveQuantizer(
|
|
37
37
|
IndexIVFAdditiveQuantizer::IndexIVFAdditiveQuantizer(AdditiveQuantizer* aq)
|
38
38
|
: IndexIVF(), aq(aq) {}
|
39
39
|
|
40
|
-
void IndexIVFAdditiveQuantizer::
|
41
|
-
|
40
|
+
void IndexIVFAdditiveQuantizer::train_encoder(
|
41
|
+
idx_t n,
|
42
|
+
const float* x,
|
43
|
+
const idx_t* assign) {
|
44
|
+
aq->train(n, x);
|
45
|
+
}
|
42
46
|
|
47
|
+
idx_t IndexIVFAdditiveQuantizer::train_encoder_num_vectors() const {
|
43
48
|
size_t max_train_points = 1024 * ((size_t)1 << aq->nbits[0]);
|
44
49
|
// we need more data to train LSQ
|
45
50
|
if (dynamic_cast<LocalSearchQuantizer*>(aq)) {
|
46
51
|
max_train_points = 1024 * aq->M * ((size_t)1 << aq->nbits[0]);
|
47
52
|
}
|
48
|
-
|
49
|
-
x = fvecs_maybe_subsample(
|
50
|
-
d, (size_t*)&n, max_train_points, x, verbose, 1234);
|
51
|
-
ScopeDeleter<float> del_x(x_in == x ? nullptr : x);
|
52
|
-
|
53
|
-
if (by_residual) {
|
54
|
-
std::vector<idx_t> idx(n);
|
55
|
-
quantizer->assign(n, x, idx.data());
|
56
|
-
|
57
|
-
std::vector<float> residuals(n * d);
|
58
|
-
quantizer->compute_residual_n(n, x, residuals.data(), idx.data());
|
59
|
-
|
60
|
-
aq->train(n, residuals.data());
|
61
|
-
} else {
|
62
|
-
aq->train(n, x);
|
63
|
-
}
|
53
|
+
return max_train_points;
|
64
54
|
}
|
65
55
|
|
66
56
|
void IndexIVFAdditiveQuantizer::encode_vectors(
|
@@ -126,7 +116,7 @@ void IndexIVFAdditiveQuantizer::sa_decode(
|
|
126
116
|
}
|
127
117
|
}
|
128
118
|
|
129
|
-
IndexIVFAdditiveQuantizer::~IndexIVFAdditiveQuantizer()
|
119
|
+
IndexIVFAdditiveQuantizer::~IndexIVFAdditiveQuantizer() = default;
|
130
120
|
|
131
121
|
/*********************************************
|
132
122
|
* AQInvertedListScanner
|
@@ -159,6 +149,7 @@ struct AQInvertedListScanner : InvertedListScanner {
|
|
159
149
|
const float* q;
|
160
150
|
/// following codes come from this inverted list
|
161
151
|
void set_list(idx_t list_no, float coarse_dis) override {
|
152
|
+
this->list_no = list_no;
|
162
153
|
if (ia.metric_type == METRIC_L2 && ia.by_residual) {
|
163
154
|
ia.quantizer->compute_residual(q0, tmp.data(), list_no);
|
164
155
|
q = tmp.data();
|
@@ -167,7 +158,7 @@ struct AQInvertedListScanner : InvertedListScanner {
|
|
167
158
|
}
|
168
159
|
}
|
169
160
|
|
170
|
-
~AQInvertedListScanner()
|
161
|
+
~AQInvertedListScanner() = default;
|
171
162
|
};
|
172
163
|
|
173
164
|
template <bool is_IP>
|
@@ -198,7 +189,7 @@ struct AQInvertedListScannerDecompress : AQInvertedListScanner {
|
|
198
189
|
: fvec_L2sqr(q, b.data(), aq.d);
|
199
190
|
}
|
200
191
|
|
201
|
-
~AQInvertedListScannerDecompress() override
|
192
|
+
~AQInvertedListScannerDecompress() override = default;
|
202
193
|
};
|
203
194
|
|
204
195
|
template <bool is_IP, Search_type_t search_type>
|
@@ -241,7 +232,7 @@ struct AQInvertedListScannerLUT : AQInvertedListScanner {
|
|
241
232
|
aq.compute_1_distance_LUT<is_IP, search_type>(code, LUT.data());
|
242
233
|
}
|
243
234
|
|
244
|
-
~AQInvertedListScannerLUT() override
|
235
|
+
~AQInvertedListScannerLUT() override = default;
|
245
236
|
};
|
246
237
|
|
247
238
|
} // anonymous namespace
|
@@ -320,7 +311,7 @@ IndexIVFResidualQuantizer::IndexIVFResidualQuantizer(
|
|
320
311
|
metric,
|
321
312
|
search_type) {}
|
322
313
|
|
323
|
-
IndexIVFResidualQuantizer::~IndexIVFResidualQuantizer()
|
314
|
+
IndexIVFResidualQuantizer::~IndexIVFResidualQuantizer() = default;
|
324
315
|
|
325
316
|
/**************************************************************************************
|
326
317
|
* IndexIVFLocalSearchQuantizer
|
@@ -342,7 +333,7 @@ IndexIVFLocalSearchQuantizer::IndexIVFLocalSearchQuantizer(
|
|
342
333
|
IndexIVFLocalSearchQuantizer::IndexIVFLocalSearchQuantizer()
|
343
334
|
: IndexIVFAdditiveQuantizer(&lsq) {}
|
344
335
|
|
345
|
-
IndexIVFLocalSearchQuantizer::~IndexIVFLocalSearchQuantizer()
|
336
|
+
IndexIVFLocalSearchQuantizer::~IndexIVFLocalSearchQuantizer() = default;
|
346
337
|
|
347
338
|
/**************************************************************************************
|
348
339
|
* IndexIVFProductResidualQuantizer
|
@@ -365,7 +356,7 @@ IndexIVFProductResidualQuantizer::IndexIVFProductResidualQuantizer(
|
|
365
356
|
IndexIVFProductResidualQuantizer::IndexIVFProductResidualQuantizer()
|
366
357
|
: IndexIVFAdditiveQuantizer(&prq) {}
|
367
358
|
|
368
|
-
IndexIVFProductResidualQuantizer::~IndexIVFProductResidualQuantizer()
|
359
|
+
IndexIVFProductResidualQuantizer::~IndexIVFProductResidualQuantizer() = default;
|
369
360
|
|
370
361
|
/**************************************************************************************
|
371
362
|
* IndexIVFProductLocalSearchQuantizer
|
@@ -388,6 +379,7 @@ IndexIVFProductLocalSearchQuantizer::IndexIVFProductLocalSearchQuantizer(
|
|
388
379
|
IndexIVFProductLocalSearchQuantizer::IndexIVFProductLocalSearchQuantizer()
|
389
380
|
: IndexIVFAdditiveQuantizer(&plsq) {}
|
390
381
|
|
391
|
-
IndexIVFProductLocalSearchQuantizer::~IndexIVFProductLocalSearchQuantizer()
|
382
|
+
IndexIVFProductLocalSearchQuantizer::~IndexIVFProductLocalSearchQuantizer() =
|
383
|
+
default;
|
392
384
|
|
393
385
|
} // namespace faiss
|
@@ -26,7 +26,6 @@ namespace faiss {
|
|
26
26
|
struct IndexIVFAdditiveQuantizer : IndexIVF {
|
27
27
|
// the quantizer
|
28
28
|
AdditiveQuantizer* aq;
|
29
|
-
bool by_residual = true;
|
30
29
|
int use_precomputed_table = 0; // for future use
|
31
30
|
|
32
31
|
using Search_type_t = AdditiveQuantizer::Search_type_t;
|
@@ -40,7 +39,9 @@ struct IndexIVFAdditiveQuantizer : IndexIVF {
|
|
40
39
|
|
41
40
|
explicit IndexIVFAdditiveQuantizer(AdditiveQuantizer* aq);
|
42
41
|
|
43
|
-
void
|
42
|
+
void train_encoder(idx_t n, const float* x, const idx_t* assign) override;
|
43
|
+
|
44
|
+
idx_t train_encoder_num_vectors() const override;
|
44
45
|
|
45
46
|
void encode_vectors(
|
46
47
|
idx_t n,
|
@@ -125,51 +125,27 @@ IndexIVFAdditiveQuantizerFastScan::IndexIVFAdditiveQuantizerFastScan() {
|
|
125
125
|
is_trained = false;
|
126
126
|
}
|
127
127
|
|
128
|
-
IndexIVFAdditiveQuantizerFastScan::~IndexIVFAdditiveQuantizerFastScan()
|
128
|
+
IndexIVFAdditiveQuantizerFastScan::~IndexIVFAdditiveQuantizerFastScan() =
|
129
|
+
default;
|
129
130
|
|
130
131
|
/*********************************************************
|
131
132
|
* Training
|
132
133
|
*********************************************************/
|
133
134
|
|
134
|
-
|
135
|
+
idx_t IndexIVFAdditiveQuantizerFastScan::train_encoder_num_vectors() const {
|
136
|
+
return max_train_points;
|
137
|
+
}
|
138
|
+
|
139
|
+
void IndexIVFAdditiveQuantizerFastScan::train_encoder(
|
135
140
|
idx_t n,
|
136
|
-
const float*
|
141
|
+
const float* x,
|
142
|
+
const idx_t* assign) {
|
137
143
|
if (aq->is_trained) {
|
138
144
|
return;
|
139
145
|
}
|
140
146
|
|
141
|
-
const int seed = 0x12345;
|
142
|
-
size_t nt = n;
|
143
|
-
const float* x = fvecs_maybe_subsample(
|
144
|
-
d, &nt, max_train_points, x_in, verbose, seed);
|
145
|
-
n = nt;
|
146
147
|
if (verbose) {
|
147
|
-
printf("training additive quantizer on %
|
148
|
-
}
|
149
|
-
aq->verbose = verbose;
|
150
|
-
|
151
|
-
std::unique_ptr<float[]> del_x;
|
152
|
-
if (x != x_in) {
|
153
|
-
del_x.reset((float*)x);
|
154
|
-
}
|
155
|
-
|
156
|
-
const float* trainset;
|
157
|
-
std::vector<float> residuals(n * d);
|
158
|
-
std::vector<idx_t> assign(n);
|
159
|
-
|
160
|
-
if (by_residual) {
|
161
|
-
if (verbose) {
|
162
|
-
printf("computing residuals\n");
|
163
|
-
}
|
164
|
-
quantizer->assign(n, x, assign.data());
|
165
|
-
residuals.resize(n * d);
|
166
|
-
for (idx_t i = 0; i < n; i++) {
|
167
|
-
quantizer->compute_residual(
|
168
|
-
x + i * d, residuals.data() + i * d, assign[i]);
|
169
|
-
}
|
170
|
-
trainset = residuals.data();
|
171
|
-
} else {
|
172
|
-
trainset = x;
|
148
|
+
printf("training additive quantizer on %d vectors\n", int(n));
|
173
149
|
}
|
174
150
|
|
175
151
|
if (verbose) {
|
@@ -181,17 +157,16 @@ void IndexIVFAdditiveQuantizerFastScan::train_residual(
|
|
181
157
|
d);
|
182
158
|
}
|
183
159
|
aq->verbose = verbose;
|
184
|
-
aq->train(n,
|
160
|
+
aq->train(n, x);
|
185
161
|
|
186
162
|
// train norm quantizer
|
187
163
|
if (by_residual && metric_type == METRIC_L2) {
|
188
164
|
std::vector<float> decoded_x(n * d);
|
189
165
|
std::vector<uint8_t> x_codes(n * aq->code_size);
|
190
|
-
aq->compute_codes(
|
166
|
+
aq->compute_codes(x, x_codes.data(), n);
|
191
167
|
aq->decode(x_codes.data(), decoded_x.data(), n);
|
192
168
|
|
193
169
|
// add coarse centroids
|
194
|
-
FAISS_THROW_IF_NOT(assign.size() == n);
|
195
170
|
std::vector<float> centroid(d);
|
196
171
|
for (idx_t i = 0; i < n; i++) {
|
197
172
|
auto xi = decoded_x.data() + i * d;
|
@@ -236,7 +211,8 @@ void IndexIVFAdditiveQuantizerFastScan::estimate_norm_scale(
|
|
236
211
|
|
237
212
|
size_t index_nprobe = nprobe;
|
238
213
|
nprobe = 1;
|
239
|
-
|
214
|
+
CoarseQuantized cq{index_nprobe, coarse_dis.data(), coarse_ids.data()};
|
215
|
+
compute_LUT(n, x, cq, dis_tables, biases);
|
240
216
|
nprobe = index_nprobe;
|
241
217
|
|
242
218
|
float scale = 0;
|
@@ -338,11 +314,8 @@ void IndexIVFAdditiveQuantizerFastScan::search(
|
|
338
314
|
}
|
339
315
|
|
340
316
|
NormTableScaler scaler(norm_scale);
|
341
|
-
|
342
|
-
|
343
|
-
} else {
|
344
|
-
search_dispatch_implem<false>(n, x, k, distances, labels, scaler);
|
345
|
-
}
|
317
|
+
IndexIVFFastScan::CoarseQuantized cq{nprobe};
|
318
|
+
search_dispatch_implem(n, x, k, distances, labels, cq, &scaler);
|
346
319
|
}
|
347
320
|
|
348
321
|
/*********************************************************
|
@@ -408,12 +381,12 @@ bool IndexIVFAdditiveQuantizerFastScan::lookup_table_is_3d() const {
|
|
408
381
|
void IndexIVFAdditiveQuantizerFastScan::compute_LUT(
|
409
382
|
size_t n,
|
410
383
|
const float* x,
|
411
|
-
const
|
412
|
-
const float*,
|
384
|
+
const CoarseQuantized& cq,
|
413
385
|
AlignedTable<float>& dis_tables,
|
414
386
|
AlignedTable<float>& biases) const {
|
415
387
|
const size_t dim12 = ksub * M;
|
416
388
|
const size_t ip_dim12 = aq->M * ksub;
|
389
|
+
const size_t nprobe = cq.nprobe;
|
417
390
|
|
418
391
|
dis_tables.resize(n * dim12);
|
419
392
|
|
@@ -434,7 +407,7 @@ void IndexIVFAdditiveQuantizerFastScan::compute_LUT(
|
|
434
407
|
#pragma omp for
|
435
408
|
for (idx_t ij = 0; ij < n * nprobe; ij++) {
|
436
409
|
int i = ij / nprobe;
|
437
|
-
quantizer->reconstruct(
|
410
|
+
quantizer->reconstruct(cq.ids[ij], c);
|
438
411
|
biases[ij] = coef * fvec_inner_product(c, x + i * d, d);
|
439
412
|
}
|
440
413
|
}
|
@@ -63,7 +63,9 @@ struct IndexIVFAdditiveQuantizerFastScan : IndexIVFFastScan {
|
|
63
63
|
const IndexIVFAdditiveQuantizer& orig,
|
64
64
|
int bbs = 32);
|
65
65
|
|
66
|
-
void
|
66
|
+
void train_encoder(idx_t n, const float* x, const idx_t* assign) override;
|
67
|
+
|
68
|
+
idx_t train_encoder_num_vectors() const override;
|
67
69
|
|
68
70
|
void estimate_norm_scale(idx_t n, const float* x);
|
69
71
|
|
@@ -91,8 +93,7 @@ struct IndexIVFAdditiveQuantizerFastScan : IndexIVFFastScan {
|
|
91
93
|
void compute_LUT(
|
92
94
|
size_t n,
|
93
95
|
const float* x,
|
94
|
-
const
|
95
|
-
const float* coarse_dis,
|
96
|
+
const CoarseQuantized& cq,
|
96
97
|
AlignedTable<float>& dis_tables,
|
97
98
|
AlignedTable<float>& biases) const override;
|
98
99
|
|