faiss 0.3.0 → 0.3.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +0 -1
- data/vendor/faiss/faiss/Clustering.cpp +4 -18
- data/vendor/faiss/faiss/Clustering.h +31 -21
- data/vendor/faiss/faiss/IVFlib.cpp +22 -11
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +20 -5
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
- data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
- data/vendor/faiss/faiss/IndexHNSW.h +12 -48
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
- data/vendor/faiss/faiss/IndexIVF.h +37 -5
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
- data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +10 -10
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
- data/vendor/faiss/faiss/impl/HNSW.h +9 -8
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
- data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
- data/vendor/faiss/faiss/impl/io.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
- data/vendor/faiss/faiss/index_factory.cpp +10 -7
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
- data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
- data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/distances.cpp +128 -74
- data/vendor/faiss/faiss/utils/distances.h +81 -4
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/utils.cpp +112 -6
- data/vendor/faiss/faiss/utils/utils.h +57 -20
- metadata +10 -3
@@ -14,6 +14,7 @@
|
|
14
14
|
#include <cstring>
|
15
15
|
|
16
16
|
#include <algorithm>
|
17
|
+
#include <memory>
|
17
18
|
|
18
19
|
#include <faiss/impl/DistanceComputer.h>
|
19
20
|
#include <faiss/impl/FaissAssert.h>
|
@@ -86,7 +87,7 @@ struct PQDistanceComputer : FlatCodesDistanceComputer {
|
|
86
87
|
ndis++;
|
87
88
|
|
88
89
|
float dis = distance_single_code<PQDecoder>(
|
89
|
-
pq, precomputed_table.data(), code);
|
90
|
+
pq.M, pq.nbits, precomputed_table.data(), code);
|
90
91
|
return dis;
|
91
92
|
}
|
92
93
|
|
@@ -198,17 +199,16 @@ void IndexPQ::search(
|
|
198
199
|
|
199
200
|
} else { // code-to-code distances
|
200
201
|
|
201
|
-
uint8_t
|
202
|
-
ScopeDeleter<uint8_t> del(q_codes);
|
202
|
+
std::unique_ptr<uint8_t[]> q_codes(new uint8_t[n * pq.code_size]);
|
203
203
|
|
204
204
|
if (!encode_signs) {
|
205
|
-
pq.compute_codes(x, q_codes, n);
|
205
|
+
pq.compute_codes(x, q_codes.get(), n);
|
206
206
|
} else {
|
207
207
|
FAISS_THROW_IF_NOT(d == pq.nbits * pq.M);
|
208
|
-
memset(q_codes, 0, n * pq.code_size);
|
208
|
+
memset(q_codes.get(), 0, n * pq.code_size);
|
209
209
|
for (size_t i = 0; i < n; i++) {
|
210
210
|
const float* xi = x + i * d;
|
211
|
-
uint8_t* code = q_codes + i * pq.code_size;
|
211
|
+
uint8_t* code = q_codes.get() + i * pq.code_size;
|
212
212
|
for (int j = 0; j < d; j++)
|
213
213
|
if (xi[j] > 0)
|
214
214
|
code[j >> 3] |= 1 << (j & 7);
|
@@ -219,19 +219,18 @@ void IndexPQ::search(
|
|
219
219
|
float_maxheap_array_t res = {
|
220
220
|
size_t(n), size_t(k), labels, distances};
|
221
221
|
|
222
|
-
pq.search_sdc(q_codes, n, codes.data(), ntotal, &res, true);
|
222
|
+
pq.search_sdc(q_codes.get(), n, codes.data(), ntotal, &res, true);
|
223
223
|
|
224
224
|
} else {
|
225
|
-
int
|
226
|
-
ScopeDeleter<int> del(idistances);
|
225
|
+
std::unique_ptr<int[]> idistances(new int[n * k]);
|
227
226
|
|
228
227
|
int_maxheap_array_t res = {
|
229
|
-
size_t(n), size_t(k), labels, idistances};
|
228
|
+
size_t(n), size_t(k), labels, idistances.get()};
|
230
229
|
|
231
230
|
if (search_type == ST_HE) {
|
232
231
|
hammings_knn_hc(
|
233
232
|
&res,
|
234
|
-
q_codes,
|
233
|
+
q_codes.get(),
|
235
234
|
codes.data(),
|
236
235
|
ntotal,
|
237
236
|
pq.code_size,
|
@@ -240,7 +239,7 @@ void IndexPQ::search(
|
|
240
239
|
} else if (search_type == ST_generalized_HE) {
|
241
240
|
generalized_hammings_knn_hc(
|
242
241
|
&res,
|
243
|
-
q_codes,
|
242
|
+
q_codes.get(),
|
244
243
|
codes.data(),
|
245
244
|
ntotal,
|
246
245
|
pq.code_size,
|
@@ -263,21 +262,23 @@ void IndexPQStats::reset() {
|
|
263
262
|
|
264
263
|
IndexPQStats indexPQ_stats;
|
265
264
|
|
265
|
+
namespace {
|
266
|
+
|
266
267
|
template <class HammingComputer>
|
267
|
-
|
268
|
-
const IndexPQ
|
268
|
+
size_t polysemous_inner_loop(
|
269
|
+
const IndexPQ* index,
|
269
270
|
const float* dis_table_qi,
|
270
271
|
const uint8_t* q_code,
|
271
272
|
size_t k,
|
272
273
|
float* heap_dis,
|
273
274
|
int64_t* heap_ids,
|
274
275
|
int ht) {
|
275
|
-
int M = index
|
276
|
-
int code_size = index
|
277
|
-
int ksub = index
|
278
|
-
size_t ntotal = index
|
276
|
+
int M = index->pq.M;
|
277
|
+
int code_size = index->pq.code_size;
|
278
|
+
int ksub = index->pq.ksub;
|
279
|
+
size_t ntotal = index->ntotal;
|
279
280
|
|
280
|
-
const uint8_t* b_code = index
|
281
|
+
const uint8_t* b_code = index->codes.data();
|
281
282
|
|
282
283
|
size_t n_pass_i = 0;
|
283
284
|
|
@@ -305,6 +306,16 @@ static size_t polysemous_inner_loop(
|
|
305
306
|
return n_pass_i;
|
306
307
|
}
|
307
308
|
|
309
|
+
struct Run_polysemous_inner_loop {
|
310
|
+
using T = size_t;
|
311
|
+
template <class HammingComputer, class... Types>
|
312
|
+
size_t f(Types... args) {
|
313
|
+
return polysemous_inner_loop<HammingComputer>(args...);
|
314
|
+
}
|
315
|
+
};
|
316
|
+
|
317
|
+
} // anonymous namespace
|
318
|
+
|
308
319
|
void IndexPQ::search_core_polysemous(
|
309
320
|
idx_t n,
|
310
321
|
const float* x,
|
@@ -321,22 +332,20 @@ void IndexPQ::search_core_polysemous(
|
|
321
332
|
}
|
322
333
|
|
323
334
|
// PQ distance tables
|
324
|
-
float
|
325
|
-
|
326
|
-
pq.compute_distance_tables(n, x, dis_tables);
|
335
|
+
std::unique_ptr<float[]> dis_tables(new float[n * pq.ksub * pq.M]);
|
336
|
+
pq.compute_distance_tables(n, x, dis_tables.get());
|
327
337
|
|
328
338
|
// Hamming embedding queries
|
329
|
-
uint8_t
|
330
|
-
ScopeDeleter<uint8_t> del2(q_codes);
|
339
|
+
std::unique_ptr<uint8_t[]> q_codes(new uint8_t[n * pq.code_size]);
|
331
340
|
|
332
341
|
if (false) {
|
333
|
-
pq.compute_codes(x, q_codes, n);
|
342
|
+
pq.compute_codes(x, q_codes.get(), n);
|
334
343
|
} else {
|
335
344
|
#pragma omp parallel for
|
336
345
|
for (idx_t qi = 0; qi < n; qi++) {
|
337
346
|
pq.compute_code_from_distance_table(
|
338
|
-
dis_tables + qi * pq.M * pq.ksub,
|
339
|
-
q_codes + qi * pq.code_size);
|
347
|
+
dis_tables.get() + qi * pq.M * pq.ksub,
|
348
|
+
q_codes.get() + qi * pq.code_size);
|
340
349
|
}
|
341
350
|
}
|
342
351
|
|
@@ -346,54 +355,33 @@ void IndexPQ::search_core_polysemous(
|
|
346
355
|
|
347
356
|
#pragma omp parallel for reduction(+ : n_pass, bad_code_size)
|
348
357
|
for (idx_t qi = 0; qi < n; qi++) {
|
349
|
-
const uint8_t* q_code = q_codes + qi * pq.code_size;
|
358
|
+
const uint8_t* q_code = q_codes.get() + qi * pq.code_size;
|
350
359
|
|
351
|
-
const float* dis_table_qi = dis_tables + qi * pq.M * pq.ksub;
|
360
|
+
const float* dis_table_qi = dis_tables.get() + qi * pq.M * pq.ksub;
|
352
361
|
|
353
362
|
int64_t* heap_ids = labels + qi * k;
|
354
363
|
float* heap_dis = distances + qi * k;
|
355
364
|
maxheap_heapify(k, heap_dis, heap_ids);
|
356
365
|
|
357
366
|
if (!generalized_hamming) {
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
DISPATCH(4)
|
371
|
-
DISPATCH(8)
|
372
|
-
DISPATCH(16)
|
373
|
-
DISPATCH(32)
|
374
|
-
DISPATCH(20)
|
375
|
-
default:
|
376
|
-
if (pq.code_size % 4 == 0) {
|
377
|
-
n_pass += polysemous_inner_loop<HammingComputerDefault>(
|
378
|
-
*this,
|
379
|
-
dis_table_qi,
|
380
|
-
q_code,
|
381
|
-
k,
|
382
|
-
heap_dis,
|
383
|
-
heap_ids,
|
384
|
-
polysemous_ht);
|
385
|
-
} else {
|
386
|
-
bad_code_size++;
|
387
|
-
}
|
388
|
-
break;
|
389
|
-
}
|
390
|
-
#undef DISPATCH
|
367
|
+
Run_polysemous_inner_loop r;
|
368
|
+
n_pass += dispatch_HammingComputer(
|
369
|
+
pq.code_size,
|
370
|
+
r,
|
371
|
+
this,
|
372
|
+
dis_table_qi,
|
373
|
+
q_code,
|
374
|
+
k,
|
375
|
+
heap_dis,
|
376
|
+
heap_ids,
|
377
|
+
polysemous_ht);
|
378
|
+
|
391
379
|
} else { // generalized hamming
|
392
380
|
switch (pq.code_size) {
|
393
381
|
#define DISPATCH(cs) \
|
394
382
|
case cs: \
|
395
383
|
n_pass += polysemous_inner_loop<GenHammingComputer##cs>( \
|
396
|
-
|
384
|
+
this, \
|
397
385
|
dis_table_qi, \
|
398
386
|
q_code, \
|
399
387
|
k, \
|
@@ -407,7 +395,7 @@ void IndexPQ::search_core_polysemous(
|
|
407
395
|
default:
|
408
396
|
if (pq.code_size % 8 == 0) {
|
409
397
|
n_pass += polysemous_inner_loop<GenHammingComputerM8>(
|
410
|
-
|
398
|
+
this,
|
411
399
|
dis_table_qi,
|
412
400
|
q_code,
|
413
401
|
k,
|
@@ -450,12 +438,11 @@ void IndexPQ::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
|
|
450
438
|
|
451
439
|
void IndexPQ::hamming_distance_table(idx_t n, const float* x, int32_t* dis)
|
452
440
|
const {
|
453
|
-
uint8_t
|
454
|
-
ScopeDeleter<uint8_t> del(q_codes);
|
441
|
+
std::unique_ptr<uint8_t[]> q_codes(new uint8_t[n * pq.code_size]);
|
455
442
|
|
456
|
-
pq.compute_codes(x, q_codes, n);
|
443
|
+
pq.compute_codes(x, q_codes.get(), n);
|
457
444
|
|
458
|
-
hammings(q_codes, codes.data(), n, ntotal, pq.code_size, dis);
|
445
|
+
hammings(q_codes.get(), codes.data(), n, ntotal, pq.code_size, dis);
|
459
446
|
}
|
460
447
|
|
461
448
|
void IndexPQ::hamming_distance_histogram(
|
@@ -469,16 +456,15 @@ void IndexPQ::hamming_distance_histogram(
|
|
469
456
|
FAISS_THROW_IF_NOT(pq.nbits == 8);
|
470
457
|
|
471
458
|
// Hamming embedding queries
|
472
|
-
uint8_t
|
473
|
-
|
474
|
-
pq.compute_codes(x, q_codes, n);
|
459
|
+
std::unique_ptr<uint8_t[]> q_codes(new uint8_t[n * pq.code_size]);
|
460
|
+
pq.compute_codes(x, q_codes.get(), n);
|
475
461
|
|
476
462
|
uint8_t* b_codes;
|
477
|
-
|
463
|
+
std::unique_ptr<uint8_t[]> del_b_codes;
|
478
464
|
|
479
465
|
if (xb) {
|
480
466
|
b_codes = new uint8_t[nb * pq.code_size];
|
481
|
-
del_b_codes.
|
467
|
+
del_b_codes.reset(b_codes);
|
482
468
|
pq.compute_codes(xb, b_codes, nb);
|
483
469
|
} else {
|
484
470
|
nb = ntotal;
|
@@ -491,8 +477,7 @@ void IndexPQ::hamming_distance_histogram(
|
|
491
477
|
#pragma omp parallel
|
492
478
|
{
|
493
479
|
std::vector<int64_t> histi(nbits + 1);
|
494
|
-
hamdis_t
|
495
|
-
ScopeDeleter<hamdis_t> del(distances);
|
480
|
+
std::unique_ptr<hamdis_t[]> distances(new hamdis_t[nb * bs]);
|
496
481
|
#pragma omp for
|
497
482
|
for (idx_t q0 = 0; q0 < n; q0 += bs) {
|
498
483
|
// printf ("dis stats: %zd/%zd\n", q0, n);
|
@@ -501,12 +486,12 @@ void IndexPQ::hamming_distance_histogram(
|
|
501
486
|
q1 = n;
|
502
487
|
|
503
488
|
hammings(
|
504
|
-
q_codes + q0 * pq.code_size,
|
489
|
+
q_codes.get() + q0 * pq.code_size,
|
505
490
|
b_codes,
|
506
491
|
q1 - q0,
|
507
492
|
nb,
|
508
493
|
pq.code_size,
|
509
|
-
distances);
|
494
|
+
distances.get());
|
510
495
|
|
511
496
|
for (size_t i = 0; i < nb * (q1 - q0); i++)
|
512
497
|
histi[distances[i]]++;
|
@@ -639,7 +624,7 @@ struct SemiSortedArray {
|
|
639
624
|
int N;
|
640
625
|
|
641
626
|
// type of the heap: CMax = sort ascending
|
642
|
-
|
627
|
+
using HC = CMax<T, int>;
|
643
628
|
std::vector<int> perm;
|
644
629
|
|
645
630
|
int k; // k elements are sorted
|
@@ -733,7 +718,7 @@ struct MinSumK {
|
|
733
718
|
* We use a heap to maintain a queue of sums, with the associated
|
734
719
|
* terms involved in the sum.
|
735
720
|
*/
|
736
|
-
|
721
|
+
using HC = CMin<T, int64_t>;
|
737
722
|
size_t heap_capacity, heap_size;
|
738
723
|
T* bh_val;
|
739
724
|
int64_t* bh_ids;
|
@@ -827,7 +812,7 @@ struct MinSumK {
|
|
827
812
|
// enqueue followers
|
828
813
|
int64_t ii = ti;
|
829
814
|
for (int m = 0; m < M; m++) {
|
830
|
-
int64_t n = ii & ((
|
815
|
+
int64_t n = ii & (((int64_t)1 << nbit) - 1);
|
831
816
|
ii >>= nbit;
|
832
817
|
if (n + 1 >= N)
|
833
818
|
continue;
|
@@ -851,7 +836,7 @@ struct MinSumK {
|
|
851
836
|
}
|
852
837
|
int64_t ti = 0;
|
853
838
|
for (int m = 0; m < M; m++) {
|
854
|
-
int64_t n = ii & ((
|
839
|
+
int64_t n = ii & (((int64_t)1 << nbit) - 1);
|
855
840
|
ti += int64_t(ssx[m].get_ord(n)) << (nbit * m);
|
856
841
|
ii >>= nbit;
|
857
842
|
}
|
@@ -923,17 +908,16 @@ void MultiIndexQuantizer::search(
|
|
923
908
|
return;
|
924
909
|
}
|
925
910
|
|
926
|
-
float
|
927
|
-
ScopeDeleter<float> del(dis_tables);
|
911
|
+
std::unique_ptr<float[]> dis_tables(new float[n * pq.ksub * pq.M]);
|
928
912
|
|
929
|
-
pq.compute_distance_tables(n, x, dis_tables);
|
913
|
+
pq.compute_distance_tables(n, x, dis_tables.get());
|
930
914
|
|
931
915
|
if (k == 1) {
|
932
916
|
// simple version that just finds the min in each table
|
933
917
|
|
934
918
|
#pragma omp parallel for
|
935
919
|
for (int i = 0; i < n; i++) {
|
936
|
-
const float* dis_table = dis_tables + i * pq.ksub * pq.M;
|
920
|
+
const float* dis_table = dis_tables.get() + i * pq.ksub * pq.M;
|
937
921
|
float dis = 0;
|
938
922
|
idx_t label = 0;
|
939
923
|
|
@@ -963,7 +947,7 @@ void MultiIndexQuantizer::search(
|
|
963
947
|
k, pq.M, pq.nbits, pq.ksub);
|
964
948
|
#pragma omp for
|
965
949
|
for (int i = 0; i < n; i++) {
|
966
|
-
msk.run(dis_tables + i * pq.ksub * pq.M,
|
950
|
+
msk.run(dis_tables.get() + i * pq.ksub * pq.M,
|
967
951
|
pq.ksub,
|
968
952
|
distances + i * k,
|
969
953
|
labels + i * k);
|
@@ -975,7 +959,7 @@ void MultiIndexQuantizer::search(
|
|
975
959
|
void MultiIndexQuantizer::reconstruct(idx_t key, float* recons) const {
|
976
960
|
int64_t jj = key;
|
977
961
|
for (int m = 0; m < pq.M; m++) {
|
978
|
-
int64_t n = jj & ((
|
962
|
+
int64_t n = jj & (((int64_t)1 << pq.nbits) - 1);
|
979
963
|
jj >>= pq.nbits;
|
980
964
|
memcpy(recons, pq.get_centroids(m, n), sizeof(recons[0]) * pq.dsub);
|
981
965
|
recons += pq.dsub;
|
@@ -1107,7 +1091,7 @@ void MultiIndexQuantizer2::search(
|
|
1107
1091
|
|
1108
1092
|
const idx_t* idmap0 = sub_ids.data() + i * k2;
|
1109
1093
|
int64_t ld_idmap = k2 * n;
|
1110
|
-
int64_t mask1 = ksub -
|
1094
|
+
int64_t mask1 = ksub - (int64_t)1;
|
1111
1095
|
|
1112
1096
|
for (int k = 0; k < K; k++) {
|
1113
1097
|
const idx_t* idmap = idmap0;
|
@@ -31,10 +31,7 @@ struct IndexPQ : IndexFlatCodes {
|
|
31
31
|
* @param M number of subquantizers
|
32
32
|
* @param nbits number of bit per subvector index
|
33
33
|
*/
|
34
|
-
IndexPQ(int d,
|
35
|
-
size_t M, ///< number of subquantizers
|
36
|
-
size_t nbits, ///< number of bit per subvector index
|
37
|
-
MetricType metric = METRIC_L2);
|
34
|
+
IndexPQ(int d, size_t M, size_t nbits, MetricType metric = METRIC_L2);
|
38
35
|
|
39
36
|
IndexPQ();
|
40
37
|
|
@@ -67,7 +67,7 @@ void IndexPreTransform::train(idx_t n, const float* x) {
|
|
67
67
|
}
|
68
68
|
}
|
69
69
|
const float* prev_x = x;
|
70
|
-
|
70
|
+
std::unique_ptr<const float[]> del;
|
71
71
|
|
72
72
|
if (verbose) {
|
73
73
|
printf("IndexPreTransform::train: training chain 0 to %d\n",
|
@@ -102,10 +102,12 @@ void IndexPreTransform::train(idx_t n, const float* x) {
|
|
102
102
|
|
103
103
|
float* xt = chain[i]->apply(n, prev_x);
|
104
104
|
|
105
|
-
if (prev_x != x)
|
106
|
-
|
105
|
+
if (prev_x != x) {
|
106
|
+
del.reset();
|
107
|
+
}
|
108
|
+
|
107
109
|
prev_x = xt;
|
108
|
-
del.
|
110
|
+
del.reset(xt);
|
109
111
|
}
|
110
112
|
|
111
113
|
is_trained = true;
|
@@ -113,11 +115,11 @@ void IndexPreTransform::train(idx_t n, const float* x) {
|
|
113
115
|
|
114
116
|
const float* IndexPreTransform::apply_chain(idx_t n, const float* x) const {
|
115
117
|
const float* prev_x = x;
|
116
|
-
|
118
|
+
std::unique_ptr<const float[]> del;
|
117
119
|
|
118
120
|
for (int i = 0; i < chain.size(); i++) {
|
119
121
|
float* xt = chain[i]->apply(n, prev_x);
|
120
|
-
|
122
|
+
std::unique_ptr<const float[]> del2(xt);
|
121
123
|
del2.swap(del);
|
122
124
|
prev_x = xt;
|
123
125
|
}
|
@@ -128,11 +130,11 @@ const float* IndexPreTransform::apply_chain(idx_t n, const float* x) const {
|
|
128
130
|
void IndexPreTransform::reverse_chain(idx_t n, const float* xt, float* x)
|
129
131
|
const {
|
130
132
|
const float* next_x = xt;
|
131
|
-
|
133
|
+
std::unique_ptr<const float[]> del;
|
132
134
|
|
133
135
|
for (int i = chain.size() - 1; i >= 0; i--) {
|
134
136
|
float* prev_x = (i == 0) ? x : new float[n * chain[i]->d_in];
|
135
|
-
|
137
|
+
std::unique_ptr<const float[]> del2((prev_x == x) ? nullptr : prev_x);
|
136
138
|
chain[i]->reverse_transform(n, next_x, prev_x);
|
137
139
|
del2.swap(del);
|
138
140
|
next_x = prev_x;
|
@@ -141,9 +143,8 @@ void IndexPreTransform::reverse_chain(idx_t n, const float* xt, float* x)
|
|
141
143
|
|
142
144
|
void IndexPreTransform::add(idx_t n, const float* x) {
|
143
145
|
FAISS_THROW_IF_NOT(is_trained);
|
144
|
-
|
145
|
-
|
146
|
-
index->add(n, xt);
|
146
|
+
TransformedVectors tv(x, apply_chain(n, x));
|
147
|
+
index->add(n, tv.x);
|
147
148
|
ntotal = index->ntotal;
|
148
149
|
}
|
149
150
|
|
@@ -152,9 +153,8 @@ void IndexPreTransform::add_with_ids(
|
|
152
153
|
const float* x,
|
153
154
|
const idx_t* xids) {
|
154
155
|
FAISS_THROW_IF_NOT(is_trained);
|
155
|
-
|
156
|
-
|
157
|
-
index->add_with_ids(n, xt, xids);
|
156
|
+
TransformedVectors tv(x, apply_chain(n, x));
|
157
|
+
index->add_with_ids(n, tv.x, xids);
|
158
158
|
ntotal = index->ntotal;
|
159
159
|
}
|
160
160
|
|
@@ -178,7 +178,7 @@ void IndexPreTransform::search(
|
|
178
178
|
FAISS_THROW_IF_NOT(k > 0);
|
179
179
|
FAISS_THROW_IF_NOT(is_trained);
|
180
180
|
const float* xt = apply_chain(n, x);
|
181
|
-
|
181
|
+
std::unique_ptr<const float[]> del(xt == x ? nullptr : xt);
|
182
182
|
index->search(
|
183
183
|
n, xt, k, distances, labels, extract_index_search_params(params));
|
184
184
|
}
|
@@ -190,10 +190,9 @@ void IndexPreTransform::range_search(
|
|
190
190
|
RangeSearchResult* result,
|
191
191
|
const SearchParameters* params) const {
|
192
192
|
FAISS_THROW_IF_NOT(is_trained);
|
193
|
-
|
194
|
-
ScopeDeleter<float> del(xt == x ? nullptr : xt);
|
193
|
+
TransformedVectors tv(x, apply_chain(n, x));
|
195
194
|
index->range_search(
|
196
|
-
n,
|
195
|
+
n, tv.x, radius, result, extract_index_search_params(params));
|
197
196
|
}
|
198
197
|
|
199
198
|
void IndexPreTransform::reset() {
|
@@ -209,7 +208,7 @@ size_t IndexPreTransform::remove_ids(const IDSelector& sel) {
|
|
209
208
|
|
210
209
|
void IndexPreTransform::reconstruct(idx_t key, float* recons) const {
|
211
210
|
float* x = chain.empty() ? recons : new float[index->d];
|
212
|
-
|
211
|
+
std::unique_ptr<float[]> del(recons == x ? nullptr : x);
|
213
212
|
// Initial reconstruction
|
214
213
|
index->reconstruct(key, x);
|
215
214
|
|
@@ -219,7 +218,7 @@ void IndexPreTransform::reconstruct(idx_t key, float* recons) const {
|
|
219
218
|
|
220
219
|
void IndexPreTransform::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
|
221
220
|
float* x = chain.empty() ? recons : new float[ni * index->d];
|
222
|
-
|
221
|
+
std::unique_ptr<float[]> del(recons == x ? nullptr : x);
|
223
222
|
// Initial reconstruction
|
224
223
|
index->reconstruct_n(i0, ni, x);
|
225
224
|
|
@@ -238,14 +237,14 @@ void IndexPreTransform::search_and_reconstruct(
|
|
238
237
|
FAISS_THROW_IF_NOT(k > 0);
|
239
238
|
FAISS_THROW_IF_NOT(is_trained);
|
240
239
|
|
241
|
-
|
242
|
-
ScopeDeleter<float> del((xt == x) ? nullptr : xt);
|
240
|
+
TransformedVectors trans(x, apply_chain(n, x));
|
243
241
|
|
244
242
|
float* recons_temp = chain.empty() ? recons : new float[n * k * index->d];
|
245
|
-
|
243
|
+
std::unique_ptr<float[]> del2(
|
244
|
+
(recons_temp == recons) ? nullptr : recons_temp);
|
246
245
|
index->search_and_reconstruct(
|
247
246
|
n,
|
248
|
-
|
247
|
+
trans.x,
|
249
248
|
k,
|
250
249
|
distances,
|
251
250
|
labels,
|
@@ -262,13 +261,8 @@ size_t IndexPreTransform::sa_code_size() const {
|
|
262
261
|
|
263
262
|
void IndexPreTransform::sa_encode(idx_t n, const float* x, uint8_t* bytes)
|
264
263
|
const {
|
265
|
-
|
266
|
-
|
267
|
-
} else {
|
268
|
-
const float* xt = apply_chain(n, x);
|
269
|
-
ScopeDeleter<float> del(xt == x ? nullptr : xt);
|
270
|
-
index->sa_encode(n, xt, bytes);
|
271
|
-
}
|
264
|
+
TransformedVectors tv(x, apply_chain(n, x));
|
265
|
+
index->sa_encode(n, tv.x, bytes);
|
272
266
|
}
|
273
267
|
|
274
268
|
void IndexPreTransform::sa_decode(idx_t n, const uint8_t* bytes, float* x)
|
@@ -62,7 +62,7 @@ void IndexRefine::reset() {
|
|
62
62
|
|
63
63
|
namespace {
|
64
64
|
|
65
|
-
|
65
|
+
using idx_t = faiss::idx_t;
|
66
66
|
|
67
67
|
template <class C>
|
68
68
|
static void reorder_2_heaps(
|
@@ -96,25 +96,40 @@ void IndexRefine::search(
|
|
96
96
|
idx_t k,
|
97
97
|
float* distances,
|
98
98
|
idx_t* labels,
|
99
|
-
const SearchParameters*
|
100
|
-
|
101
|
-
|
99
|
+
const SearchParameters* params_in) const {
|
100
|
+
const IndexRefineSearchParameters* params = nullptr;
|
101
|
+
if (params_in) {
|
102
|
+
params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
|
103
|
+
FAISS_THROW_IF_NOT_MSG(
|
104
|
+
params, "IndexRefine params have incorrect type");
|
105
|
+
}
|
106
|
+
|
107
|
+
idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor)
|
108
|
+
: idx_t(k * k_factor);
|
109
|
+
SearchParameters* base_index_params =
|
110
|
+
(params != nullptr) ? params->base_index_params : nullptr;
|
111
|
+
|
112
|
+
FAISS_THROW_IF_NOT(k_base >= k);
|
113
|
+
|
114
|
+
FAISS_THROW_IF_NOT(base_index);
|
115
|
+
FAISS_THROW_IF_NOT(refine_index);
|
116
|
+
|
102
117
|
FAISS_THROW_IF_NOT(k > 0);
|
103
118
|
FAISS_THROW_IF_NOT(is_trained);
|
104
|
-
idx_t k_base = idx_t(k * k_factor);
|
105
119
|
idx_t* base_labels = labels;
|
106
120
|
float* base_distances = distances;
|
107
|
-
|
108
|
-
|
121
|
+
std::unique_ptr<idx_t[]> del1;
|
122
|
+
std::unique_ptr<float[]> del2;
|
109
123
|
|
110
124
|
if (k != k_base) {
|
111
125
|
base_labels = new idx_t[n * k_base];
|
112
|
-
del1.
|
126
|
+
del1.reset(base_labels);
|
113
127
|
base_distances = new float[n * k_base];
|
114
|
-
del2.
|
128
|
+
del2.reset(base_distances);
|
115
129
|
}
|
116
130
|
|
117
|
-
base_index->search(
|
131
|
+
base_index->search(
|
132
|
+
n, x, k_base, base_distances, base_labels, base_index_params);
|
118
133
|
|
119
134
|
for (int i = 0; i < n * k_base; i++)
|
120
135
|
assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
|
@@ -225,25 +240,40 @@ void IndexRefineFlat::search(
|
|
225
240
|
idx_t k,
|
226
241
|
float* distances,
|
227
242
|
idx_t* labels,
|
228
|
-
const SearchParameters*
|
229
|
-
|
230
|
-
|
243
|
+
const SearchParameters* params_in) const {
|
244
|
+
const IndexRefineSearchParameters* params = nullptr;
|
245
|
+
if (params_in) {
|
246
|
+
params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
|
247
|
+
FAISS_THROW_IF_NOT_MSG(
|
248
|
+
params, "IndexRefineFlat params have incorrect type");
|
249
|
+
}
|
250
|
+
|
251
|
+
idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor)
|
252
|
+
: idx_t(k * k_factor);
|
253
|
+
SearchParameters* base_index_params =
|
254
|
+
(params != nullptr) ? params->base_index_params : nullptr;
|
255
|
+
|
256
|
+
FAISS_THROW_IF_NOT(k_base >= k);
|
257
|
+
|
258
|
+
FAISS_THROW_IF_NOT(base_index);
|
259
|
+
FAISS_THROW_IF_NOT(refine_index);
|
260
|
+
|
231
261
|
FAISS_THROW_IF_NOT(k > 0);
|
232
262
|
FAISS_THROW_IF_NOT(is_trained);
|
233
|
-
idx_t k_base = idx_t(k * k_factor);
|
234
263
|
idx_t* base_labels = labels;
|
235
264
|
float* base_distances = distances;
|
236
|
-
|
237
|
-
|
265
|
+
std::unique_ptr<idx_t[]> del1;
|
266
|
+
std::unique_ptr<float[]> del2;
|
238
267
|
|
239
268
|
if (k != k_base) {
|
240
269
|
base_labels = new idx_t[n * k_base];
|
241
|
-
del1.
|
270
|
+
del1.reset(base_labels);
|
242
271
|
base_distances = new float[n * k_base];
|
243
|
-
del2.
|
272
|
+
del2.reset(base_distances);
|
244
273
|
}
|
245
274
|
|
246
|
-
base_index->search(
|
275
|
+
base_index->search(
|
276
|
+
n, x, k_base, base_distances, base_labels, base_index_params);
|
247
277
|
|
248
278
|
for (int i = 0; i < n * k_base; i++)
|
249
279
|
assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
|
@@ -11,6 +11,13 @@
|
|
11
11
|
|
12
12
|
namespace faiss {
|
13
13
|
|
14
|
+
struct IndexRefineSearchParameters : SearchParameters {
|
15
|
+
float k_factor = 1;
|
16
|
+
SearchParameters* base_index_params = nullptr; // non-owning
|
17
|
+
|
18
|
+
virtual ~IndexRefineSearchParameters() = default;
|
19
|
+
};
|
20
|
+
|
14
21
|
/** Index that queries in a base_index (a fast one) and refines the
|
15
22
|
* results with an exact search, hopefully improving the results.
|
16
23
|
*/
|