faiss 0.2.7 → 0.3.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +0 -1
- data/vendor/faiss/faiss/Clustering.cpp +4 -18
- data/vendor/faiss/faiss/Clustering.h +31 -21
- data/vendor/faiss/faiss/IVFlib.cpp +22 -11
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +20 -5
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
- data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
- data/vendor/faiss/faiss/IndexHNSW.h +12 -48
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
- data/vendor/faiss/faiss/IndexIVF.h +37 -5
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
- data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +10 -10
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
- data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
- data/vendor/faiss/faiss/impl/HNSW.h +9 -8
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
- data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
- data/vendor/faiss/faiss/impl/io.cpp +10 -10
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
- data/vendor/faiss/faiss/index_factory.cpp +10 -7
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
- data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
- data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/distances.cpp +128 -74
- data/vendor/faiss/faiss/utils/distances.h +81 -4
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/utils.cpp +112 -6
- data/vendor/faiss/faiss/utils/utils.h +57 -20
- metadata +11 -4
@@ -195,8 +195,9 @@ void NNDescent::update() {
|
|
195
195
|
int l = 0;
|
196
196
|
|
197
197
|
while ((l < maxl) && (c < S)) {
|
198
|
-
if (nn.pool[l].flag)
|
198
|
+
if (nn.pool[l].flag) {
|
199
199
|
++c;
|
200
|
+
}
|
200
201
|
++l;
|
201
202
|
}
|
202
203
|
nn.M = l;
|
@@ -305,8 +306,9 @@ void NNDescent::generate_eval_set(
|
|
305
306
|
for (int i = 0; i < c.size(); i++) {
|
306
307
|
std::vector<Neighbor> tmp;
|
307
308
|
for (int j = 0; j < N; j++) {
|
308
|
-
if (c[i] == j)
|
309
|
+
if (c[i] == j) {
|
309
310
|
continue; // skip itself
|
311
|
+
}
|
310
312
|
float dist = qdis.symmetric_dis(c[i], j);
|
311
313
|
tmp.push_back(Neighbor(j, dist, true));
|
312
314
|
}
|
@@ -360,8 +362,9 @@ void NNDescent::init_graph(DistanceComputer& qdis) {
|
|
360
362
|
|
361
363
|
for (int j = 0; j < S; j++) {
|
362
364
|
int id = tmp[j];
|
363
|
-
if (id == i)
|
365
|
+
if (id == i) {
|
364
366
|
continue;
|
367
|
+
}
|
365
368
|
float dist = qdis.symmetric_dis(i, id);
|
366
369
|
|
367
370
|
graph[i].pool.push_back(Neighbor(id, dist, true));
|
@@ -374,6 +377,10 @@ void NNDescent::init_graph(DistanceComputer& qdis) {
|
|
374
377
|
|
375
378
|
void NNDescent::build(DistanceComputer& qdis, const int n, bool verbose) {
|
376
379
|
FAISS_THROW_IF_NOT_MSG(L >= K, "L should be >= K in NNDescent.build");
|
380
|
+
FAISS_THROW_IF_NOT_FMT(
|
381
|
+
n > NUM_EVAL_POINTS,
|
382
|
+
"NNDescent.build cannot build a graph smaller than %d",
|
383
|
+
int(NUM_EVAL_POINTS));
|
377
384
|
|
378
385
|
if (verbose) {
|
379
386
|
printf("Parameters: K=%d, S=%d, R=%d, L=%d, iter=%d\n",
|
@@ -403,7 +410,7 @@ void NNDescent::build(DistanceComputer& qdis, const int n, bool verbose) {
|
|
403
410
|
has_built = true;
|
404
411
|
|
405
412
|
if (verbose) {
|
406
|
-
printf("
|
413
|
+
printf("Added %d points into the index\n", ntotal);
|
407
414
|
}
|
408
415
|
}
|
409
416
|
|
@@ -414,30 +421,30 @@ void NNDescent::search(
|
|
414
421
|
float* dists,
|
415
422
|
VisitedTable& vt) const {
|
416
423
|
FAISS_THROW_IF_NOT_MSG(has_built, "The index is not build yet.");
|
417
|
-
int
|
424
|
+
int L_2 = std::max(search_L, topk);
|
418
425
|
|
419
426
|
// candidate pool, the K best items is the result.
|
420
|
-
std::vector<Neighbor> retset(
|
427
|
+
std::vector<Neighbor> retset(L_2 + 1);
|
421
428
|
|
422
|
-
// Randomly choose
|
423
|
-
std::vector<int> init_ids(
|
429
|
+
// Randomly choose L_2 points to initialize the candidate pool
|
430
|
+
std::vector<int> init_ids(L_2);
|
424
431
|
std::mt19937 rng(random_seed);
|
425
432
|
|
426
|
-
gen_random(rng, init_ids.data(),
|
427
|
-
for (int i = 0; i <
|
433
|
+
gen_random(rng, init_ids.data(), L_2, ntotal);
|
434
|
+
for (int i = 0; i < L_2; i++) {
|
428
435
|
int id = init_ids[i];
|
429
436
|
float dist = qdis(id);
|
430
437
|
retset[i] = Neighbor(id, dist, true);
|
431
438
|
}
|
432
439
|
|
433
440
|
// Maintain the candidate pool in ascending order
|
434
|
-
std::sort(retset.begin(), retset.begin() +
|
441
|
+
std::sort(retset.begin(), retset.begin() + L_2);
|
435
442
|
|
436
443
|
int k = 0;
|
437
444
|
|
438
|
-
// Stop until the smallest position updated is >=
|
439
|
-
while (k <
|
440
|
-
int nk =
|
445
|
+
// Stop until the smallest position updated is >= L_2
|
446
|
+
while (k < L_2) {
|
447
|
+
int nk = L_2;
|
441
448
|
|
442
449
|
if (retset[k].flag) {
|
443
450
|
retset[k].flag = false;
|
@@ -445,25 +452,28 @@ void NNDescent::search(
|
|
445
452
|
|
446
453
|
for (int m = 0; m < K; ++m) {
|
447
454
|
int id = final_graph[n * K + m];
|
448
|
-
if (vt.get(id))
|
455
|
+
if (vt.get(id)) {
|
449
456
|
continue;
|
457
|
+
}
|
450
458
|
|
451
459
|
vt.set(id);
|
452
460
|
float dist = qdis(id);
|
453
|
-
if (dist >= retset[
|
461
|
+
if (dist >= retset[L_2 - 1].distance) {
|
454
462
|
continue;
|
463
|
+
}
|
455
464
|
|
456
465
|
Neighbor nn(id, dist, true);
|
457
|
-
int r = insert_into_pool(retset.data(),
|
466
|
+
int r = insert_into_pool(retset.data(), L_2, nn);
|
458
467
|
|
459
468
|
if (r < nk)
|
460
469
|
nk = r;
|
461
470
|
}
|
462
471
|
}
|
463
|
-
if (nk <= k)
|
472
|
+
if (nk <= k) {
|
464
473
|
k = nk;
|
465
|
-
else
|
474
|
+
} else {
|
466
475
|
++k;
|
476
|
+
}
|
467
477
|
}
|
468
478
|
for (size_t i = 0; i < topk; i++) {
|
469
479
|
indices[i] = retset[i].id;
|
@@ -54,7 +54,7 @@ namespace nsg {
|
|
54
54
|
|
55
55
|
template <class node_t>
|
56
56
|
struct Graph {
|
57
|
-
node_t* data; ///< the flattened adjacency matrix
|
57
|
+
node_t* data; ///< the flattened adjacency matrix, size N-by-K
|
58
58
|
int K; ///< nb of neighbors per node
|
59
59
|
int N; ///< total nb of nodes
|
60
60
|
bool own_fields; ///< the underlying data owned by itself or not
|
@@ -12,11 +12,11 @@
|
|
12
12
|
#include <omp.h>
|
13
13
|
#include <stdint.h>
|
14
14
|
|
15
|
+
#include <algorithm>
|
15
16
|
#include <cmath>
|
16
17
|
#include <cstdlib>
|
17
18
|
#include <cstring>
|
18
|
-
|
19
|
-
#include <algorithm>
|
19
|
+
#include <memory>
|
20
20
|
|
21
21
|
#include <faiss/utils/distances.h>
|
22
22
|
#include <faiss/utils/hamming.h>
|
@@ -683,18 +683,21 @@ struct RankingScore2 : Score3Computer<float, double> {
|
|
683
683
|
double accum_gt_weight_diff(
|
684
684
|
const std::vector<int>& a,
|
685
685
|
const std::vector<int>& b) {
|
686
|
-
|
686
|
+
const auto nb_2 = b.size();
|
687
|
+
const auto na = a.size();
|
687
688
|
|
688
689
|
double accu = 0;
|
689
|
-
|
690
|
-
for (
|
691
|
-
|
692
|
-
while (j <
|
690
|
+
size_t j = 0;
|
691
|
+
for (size_t i = 0; i < na; i++) {
|
692
|
+
const auto ai = a[i];
|
693
|
+
while (j < nb_2 && ai >= b[j]) {
|
693
694
|
j++;
|
695
|
+
}
|
694
696
|
|
695
697
|
double accu_i = 0;
|
696
|
-
for (
|
698
|
+
for (auto k = j; k < b.size(); k++) {
|
697
699
|
accu_i += rank_weight(b[k] - ai);
|
700
|
+
}
|
698
701
|
|
699
702
|
accu += rank_weight(ai) * accu_i;
|
700
703
|
}
|
@@ -882,14 +885,13 @@ void PolysemousTraining::optimize_ranking(
|
|
882
885
|
|
883
886
|
double t0 = getmillisecs();
|
884
887
|
|
885
|
-
PermutationObjective
|
888
|
+
std::unique_ptr<PermutationObjective> obj(new RankingScore2(
|
886
889
|
nbits,
|
887
890
|
nq,
|
888
891
|
nb,
|
889
892
|
codes.data(),
|
890
893
|
codes.data() + nq,
|
891
|
-
gt_distances.data());
|
892
|
-
ScopeDeleter1<PermutationObjective> del(obj);
|
894
|
+
gt_distances.data()));
|
893
895
|
|
894
896
|
if (verbose > 0) {
|
895
897
|
printf(" m=%d, nq=%zd, nb=%zd, initialize RankingScore "
|
@@ -900,7 +902,7 @@ void PolysemousTraining::optimize_ranking(
|
|
900
902
|
getmillisecs() - t0);
|
901
903
|
}
|
902
904
|
|
903
|
-
SimulatedAnnealingOptimizer optim(obj, *this);
|
905
|
+
SimulatedAnnealingOptimizer optim(obj.get(), *this);
|
904
906
|
|
905
907
|
if (log_pattern.size()) {
|
906
908
|
char fname[256];
|
@@ -135,11 +135,10 @@ void ProductQuantizer::train(size_t n, const float* x) {
|
|
135
135
|
}
|
136
136
|
}
|
137
137
|
|
138
|
-
float
|
139
|
-
ScopeDeleter<float> del(xslice);
|
138
|
+
std::unique_ptr<float[]> xslice(new float[n * dsub]);
|
140
139
|
for (int m = 0; m < M; m++) {
|
141
140
|
for (int j = 0; j < n; j++)
|
142
|
-
memcpy(xslice + j * dsub,
|
141
|
+
memcpy(xslice.get() + j * dsub,
|
143
142
|
x + j * d + m * dsub,
|
144
143
|
dsub * sizeof(float));
|
145
144
|
|
@@ -153,11 +152,19 @@ void ProductQuantizer::train(size_t n, const float* x) {
|
|
153
152
|
switch (final_train_type) {
|
154
153
|
case Train_hypercube:
|
155
154
|
init_hypercube(
|
156
|
-
dsub,
|
155
|
+
dsub,
|
156
|
+
nbits,
|
157
|
+
n,
|
158
|
+
xslice.get(),
|
159
|
+
clus.centroids.data());
|
157
160
|
break;
|
158
161
|
case Train_hypercube_pca:
|
159
162
|
init_hypercube_pca(
|
160
|
-
dsub,
|
163
|
+
dsub,
|
164
|
+
nbits,
|
165
|
+
n,
|
166
|
+
xslice.get(),
|
167
|
+
clus.centroids.data());
|
161
168
|
break;
|
162
169
|
case Train_hot_start:
|
163
170
|
memcpy(clus.centroids.data(),
|
@@ -172,7 +179,7 @@ void ProductQuantizer::train(size_t n, const float* x) {
|
|
172
179
|
printf("Training PQ slice %d/%zd\n", m, M);
|
173
180
|
}
|
174
181
|
IndexFlatL2 index(dsub);
|
175
|
-
clus.train(n, xslice, assign_index ? *assign_index : index);
|
182
|
+
clus.train(n, xslice.get(), assign_index ? *assign_index : index);
|
176
183
|
set_params(clus.centroids.data(), m);
|
177
184
|
}
|
178
185
|
|
@@ -306,7 +313,8 @@ void ProductQuantizer::decode(const uint8_t* code, float* x) const {
|
|
306
313
|
}
|
307
314
|
|
308
315
|
void ProductQuantizer::decode(const uint8_t* code, float* x, size_t n) const {
|
309
|
-
|
316
|
+
#pragma omp parallel for if (n > 100)
|
317
|
+
for (int64_t i = 0; i < n; i++) {
|
310
318
|
this->decode(code + code_size * i, x + d * i);
|
311
319
|
}
|
312
320
|
}
|
@@ -342,21 +350,20 @@ void ProductQuantizer::compute_codes_with_assign_index(
|
|
342
350
|
assign_index->reset();
|
343
351
|
assign_index->add(ksub, get_centroids(m, 0));
|
344
352
|
size_t bs = 65536;
|
345
|
-
|
346
|
-
|
347
|
-
idx_t
|
348
|
-
ScopeDeleter<idx_t> del2(assign);
|
353
|
+
|
354
|
+
std::unique_ptr<float[]> xslice(new float[bs * dsub]);
|
355
|
+
std::unique_ptr<idx_t[]> assign(new idx_t[bs]);
|
349
356
|
|
350
357
|
for (size_t i0 = 0; i0 < n; i0 += bs) {
|
351
358
|
size_t i1 = std::min(i0 + bs, n);
|
352
359
|
|
353
360
|
for (size_t i = i0; i < i1; i++) {
|
354
|
-
memcpy(xslice + (i - i0) * dsub,
|
361
|
+
memcpy(xslice.get() + (i - i0) * dsub,
|
355
362
|
x + i * d + m * dsub,
|
356
363
|
dsub * sizeof(float));
|
357
364
|
}
|
358
365
|
|
359
|
-
assign_index->assign(i1 - i0, xslice, assign);
|
366
|
+
assign_index->assign(i1 - i0, xslice.get(), assign.get());
|
360
367
|
|
361
368
|
if (nbits == 8) {
|
362
369
|
uint8_t* c = codes + code_size * i0 + m;
|
@@ -405,15 +412,14 @@ void ProductQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
|
|
405
412
|
for (int64_t i = 0; i < n; i++)
|
406
413
|
compute_code(x + i * d, codes + i * code_size);
|
407
414
|
|
408
|
-
} else { //
|
409
|
-
float
|
410
|
-
|
411
|
-
compute_distance_tables(n, x, dis_tables);
|
415
|
+
} else { // worthwhile to use BLAS
|
416
|
+
std::unique_ptr<float[]> dis_tables(new float[n * ksub * M]);
|
417
|
+
compute_distance_tables(n, x, dis_tables.get());
|
412
418
|
|
413
419
|
#pragma omp parallel for
|
414
420
|
for (int64_t i = 0; i < n; i++) {
|
415
421
|
uint8_t* code = codes + i * code_size;
|
416
|
-
const float* tab = dis_tables + i * ksub * M;
|
422
|
+
const float* tab = dis_tables.get() + i * ksub * M;
|
417
423
|
compute_code_from_distance_table(tab, code);
|
418
424
|
}
|
419
425
|
}
|
@@ -774,10 +780,6 @@ void ProductQuantizer::search_ip(
|
|
774
780
|
init_finalize_heap);
|
775
781
|
}
|
776
782
|
|
777
|
-
static float sqr(float x) {
|
778
|
-
return x * x;
|
779
|
-
}
|
780
|
-
|
781
783
|
void ProductQuantizer::compute_sdc_table() {
|
782
784
|
sdc_table.resize(M * ksub * ksub);
|
783
785
|
|
@@ -35,7 +35,7 @@ struct ProductQuantizer : Quantizer {
|
|
35
35
|
enum train_type_t {
|
36
36
|
Train_default,
|
37
37
|
Train_hot_start, ///< the centroids are already initialized
|
38
|
-
Train_shared, ///< share dictionary
|
38
|
+
Train_shared, ///< share dictionary across PQ segments
|
39
39
|
Train_hypercube, ///< initialize centroids with nbits-D hypercube
|
40
40
|
Train_hypercube_pca, ///< initialize centroids with nbits-D hypercube
|
41
41
|
};
|