faiss 0.1.7 → 0.2.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +18 -0
- data/README.md +7 -7
- data/ext/faiss/ext.cpp +1 -1
- data/ext/faiss/extconf.rb +8 -2
- data/ext/faiss/index.cpp +102 -69
- data/ext/faiss/index_binary.cpp +24 -30
- data/ext/faiss/kmeans.cpp +20 -16
- data/ext/faiss/numo.hpp +867 -0
- data/ext/faiss/pca_matrix.cpp +13 -14
- data/ext/faiss/product_quantizer.cpp +23 -24
- data/ext/faiss/utils.cpp +10 -37
- data/ext/faiss/utils.h +2 -13
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +0 -5
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +334 -195
- data/vendor/faiss/faiss/Clustering.h +88 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
- data/vendor/faiss/faiss/Index2Layer.h +22 -22
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
- data/vendor/faiss/faiss/IndexFlat.h +35 -46
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
- data/vendor/faiss/faiss/IndexIVF.h +146 -113
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
- data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
- data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
- data/vendor/faiss/faiss/IndexLSH.h +21 -26
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
- data/vendor/faiss/faiss/IndexPQ.h +64 -67
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
- data/vendor/faiss/faiss/IndexRefine.h +22 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
- data/vendor/faiss/faiss/IndexResidual.h +152 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
- data/vendor/faiss/faiss/VectorTransform.h +61 -89
- data/vendor/faiss/faiss/clone_index.cpp +77 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
- data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
- data/vendor/faiss/faiss/impl/io.cpp +75 -94
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +40 -29
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +269 -218
- data/vendor/faiss/faiss/index_factory.h +6 -7
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +301 -310
- data/vendor/faiss/faiss/utils/distances.h +133 -118
- data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +53 -48
- metadata +26 -12
- data/lib/faiss/index.rb +0 -20
- data/lib/faiss/index_binary.rb +0 -20
- data/lib/faiss/kmeans.rb +0 -15
- data/lib/faiss/pca_matrix.rb +0 -15
- data/lib/faiss/product_quantizer.rb +0 -22
data/vendor/faiss/faiss/impl/PolysemousTraining.cpp

@@ -8,18 +8,21 @@
 // -*- c++ -*-

 #include <faiss/impl/PolysemousTraining.h>
+#include "faiss/impl/FaissAssert.h"
+
+#include <omp.h>
+#include <stdint.h>

-#include <cstdlib>
 #include <cmath>
+#include <cstdlib>
 #include <cstring>
-#include <stdint.h>

 #include <algorithm>

-#include <faiss/utils/random.h>
-#include <faiss/utils/utils.h>
 #include <faiss/utils/distances.h>
 #include <faiss/utils/hamming.h>
+#include <faiss/utils/random.h>
+#include <faiss/utils/utils.h>

 #include <faiss/impl/FaissAssert.h>
Hunks @@ -29,16 +32,14 @@ through @@ -567,114 +564,128 @@ only reformat existing code to the vendored library's new clang-format style (no space before the opening parenthesis, re-wrapped argument lists, `Type* ptr` pointer placement). They cover SimulatedAnnealingParameters, PermutationObjective::cost_update, SimulatedAnnealingOptimizer (constructor, destructor, run_optimization, optimize), hamming_dis, ReproduceWithHammingObjective, ReproduceDistancesObjective, the Score3Computer template, IndirectSort, and the RankingScore2 constructor and rank_weight.
@@ -683,271 +694,290 @@
Most of this hunk likewise only reformats accum_gt_weight_diff, init_n_gt, optimize_ranking and optimize_pq_for_hamming. Its functional changes add a max_memory field (default 20 GiB) in the PolysemousTraining constructor, a per-thread memory estimate, and a cap on the number of OpenMP threads used by optimize_reproduce_distances:

-PolysemousTraining::PolysemousTraining ()
-{
+PolysemousTraining::PolysemousTraining() {
     optimization_type = OT_ReproduceDistances_affine;
     ntrain_permutation = 0;
     dis_weight_factor = log(2);
+    // max 20 G RAM
+    max_memory = (size_t)(20) * 1024 * 1024 * 1024;
 }

...

-void PolysemousTraining::optimize_reproduce_distances (
-        ProductQuantizer &pq) const
-{
+void PolysemousTraining::optimize_reproduce_distances(
+        ProductQuantizer& pq) const {
     int dsub = pq.dsub;

     int n = pq.ksub;
     int nbits = pq.nbits;

-#pragma omp parallel for
+    size_t mem1 = memory_usage_per_thread(pq);
+    int nt = std::min(omp_get_max_threads(), int(pq.M));
+    FAISS_THROW_IF_NOT_FMT(
+            mem1 < max_memory,
+            "Polysemous training will use %zd bytes per thread, while the max is set to %zd",
+            mem1,
+            max_memory);
+
+    if (mem1 * nt > max_memory) {
+        nt = max_memory / mem1;
+        fprintf(stderr,
+                "Polysemous training: WARN, reducing number of threads to %d to save memory",
+                nt);
+    }
+
+#pragma omp parallel for num_threads(nt)
     for (int m = 0; m < pq.M; m++) {

...

+size_t PolysemousTraining::memory_usage_per_thread(
+        const ProductQuantizer& pq) const {
+    size_t n = pq.ksub;
+
+    switch (optimization_type) {
+        case OT_None:
+            return 0;
+        case OT_ReproduceDistances_affine:
+            return n * n * sizeof(double) * 3;
+        case OT_Ranking_weighted_diff:
+            return n * n * n * sizeof(float);
+    }
+
+    FAISS_THROW_MSG("Invalid optmization type");
+    return 0;
+}

 } // namespace faiss