faiss 0.1.0 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +103 -3
- data/ext/faiss/ext.cpp +99 -32
- data/ext/faiss/extconf.rb +12 -2
- data/lib/faiss/ext.bundle +0 -0
- data/lib/faiss/index.rb +3 -3
- data/lib/faiss/index_binary.rb +3 -3
- data/lib/faiss/kmeans.rb +1 -1
- data/lib/faiss/pca_matrix.rb +2 -2
- data/lib/faiss/product_quantizer.rb +3 -3
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/AutoTune.cpp +719 -0
- data/vendor/faiss/AutoTune.h +212 -0
- data/vendor/faiss/Clustering.cpp +261 -0
- data/vendor/faiss/Clustering.h +101 -0
- data/vendor/faiss/IVFlib.cpp +339 -0
- data/vendor/faiss/IVFlib.h +132 -0
- data/vendor/faiss/Index.cpp +171 -0
- data/vendor/faiss/Index.h +261 -0
- data/vendor/faiss/Index2Layer.cpp +437 -0
- data/vendor/faiss/Index2Layer.h +85 -0
- data/vendor/faiss/IndexBinary.cpp +77 -0
- data/vendor/faiss/IndexBinary.h +163 -0
- data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
- data/vendor/faiss/IndexBinaryFlat.h +54 -0
- data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
- data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
- data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
- data/vendor/faiss/IndexBinaryHNSW.h +56 -0
- data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
- data/vendor/faiss/IndexBinaryIVF.h +211 -0
- data/vendor/faiss/IndexFlat.cpp +508 -0
- data/vendor/faiss/IndexFlat.h +175 -0
- data/vendor/faiss/IndexHNSW.cpp +1090 -0
- data/vendor/faiss/IndexHNSW.h +170 -0
- data/vendor/faiss/IndexIVF.cpp +909 -0
- data/vendor/faiss/IndexIVF.h +353 -0
- data/vendor/faiss/IndexIVFFlat.cpp +502 -0
- data/vendor/faiss/IndexIVFFlat.h +118 -0
- data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
- data/vendor/faiss/IndexIVFPQ.h +161 -0
- data/vendor/faiss/IndexIVFPQR.cpp +219 -0
- data/vendor/faiss/IndexIVFPQR.h +65 -0
- data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
- data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
- data/vendor/faiss/IndexLSH.cpp +225 -0
- data/vendor/faiss/IndexLSH.h +87 -0
- data/vendor/faiss/IndexLattice.cpp +143 -0
- data/vendor/faiss/IndexLattice.h +68 -0
- data/vendor/faiss/IndexPQ.cpp +1188 -0
- data/vendor/faiss/IndexPQ.h +199 -0
- data/vendor/faiss/IndexPreTransform.cpp +288 -0
- data/vendor/faiss/IndexPreTransform.h +91 -0
- data/vendor/faiss/IndexReplicas.cpp +123 -0
- data/vendor/faiss/IndexReplicas.h +76 -0
- data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
- data/vendor/faiss/IndexScalarQuantizer.h +127 -0
- data/vendor/faiss/IndexShards.cpp +317 -0
- data/vendor/faiss/IndexShards.h +100 -0
- data/vendor/faiss/InvertedLists.cpp +623 -0
- data/vendor/faiss/InvertedLists.h +334 -0
- data/vendor/faiss/LICENSE +21 -0
- data/vendor/faiss/MatrixStats.cpp +252 -0
- data/vendor/faiss/MatrixStats.h +62 -0
- data/vendor/faiss/MetaIndexes.cpp +351 -0
- data/vendor/faiss/MetaIndexes.h +126 -0
- data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
- data/vendor/faiss/OnDiskInvertedLists.h +127 -0
- data/vendor/faiss/VectorTransform.cpp +1157 -0
- data/vendor/faiss/VectorTransform.h +322 -0
- data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
- data/vendor/faiss/c_api/AutoTune_c.h +64 -0
- data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
- data/vendor/faiss/c_api/Clustering_c.h +117 -0
- data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
- data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
- data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
- data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
- data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
- data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
- data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
- data/vendor/faiss/c_api/IndexShards_c.h +42 -0
- data/vendor/faiss/c_api/Index_c.cpp +105 -0
- data/vendor/faiss/c_api/Index_c.h +183 -0
- data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
- data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
- data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
- data/vendor/faiss/c_api/clone_index_c.h +32 -0
- data/vendor/faiss/c_api/error_c.h +42 -0
- data/vendor/faiss/c_api/error_impl.cpp +27 -0
- data/vendor/faiss/c_api/error_impl.h +16 -0
- data/vendor/faiss/c_api/faiss_c.h +58 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
- data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
- data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
- data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
- data/vendor/faiss/c_api/index_factory_c.h +30 -0
- data/vendor/faiss/c_api/index_io_c.cpp +42 -0
- data/vendor/faiss/c_api/index_io_c.h +50 -0
- data/vendor/faiss/c_api/macros_impl.h +110 -0
- data/vendor/faiss/clone_index.cpp +147 -0
- data/vendor/faiss/clone_index.h +38 -0
- data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
- data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
- data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
- data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
- data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
- data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
- data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
- data/vendor/faiss/gpu/GpuCloner.h +82 -0
- data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
- data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
- data/vendor/faiss/gpu/GpuDistance.h +52 -0
- data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
- data/vendor/faiss/gpu/GpuIndex.h +148 -0
- data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
- data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
- data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
- data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
- data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
- data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
- data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
- data/vendor/faiss/gpu/GpuResources.cpp +52 -0
- data/vendor/faiss/gpu/GpuResources.h +73 -0
- data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
- data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
- data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
- data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
- data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
- data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
- data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
- data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
- data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
- data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
- data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
- data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
- data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
- data/vendor/faiss/gpu/test/TestUtils.h +93 -0
- data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
- data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
- data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
- data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
- data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
- data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
- data/vendor/faiss/gpu/utils/Timer.h +52 -0
- data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
- data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
- data/vendor/faiss/impl/FaissAssert.h +95 -0
- data/vendor/faiss/impl/FaissException.cpp +66 -0
- data/vendor/faiss/impl/FaissException.h +71 -0
- data/vendor/faiss/impl/HNSW.cpp +818 -0
- data/vendor/faiss/impl/HNSW.h +275 -0
- data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
- data/vendor/faiss/impl/PolysemousTraining.h +158 -0
- data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
- data/vendor/faiss/impl/ProductQuantizer.h +242 -0
- data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
- data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
- data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
- data/vendor/faiss/impl/ThreadedIndex.h +80 -0
- data/vendor/faiss/impl/index_read.cpp +793 -0
- data/vendor/faiss/impl/index_write.cpp +558 -0
- data/vendor/faiss/impl/io.cpp +142 -0
- data/vendor/faiss/impl/io.h +98 -0
- data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
- data/vendor/faiss/impl/lattice_Zn.h +199 -0
- data/vendor/faiss/index_factory.cpp +392 -0
- data/vendor/faiss/index_factory.h +25 -0
- data/vendor/faiss/index_io.h +75 -0
- data/vendor/faiss/misc/test_blas.cpp +84 -0
- data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
- data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
- data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
- data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
- data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
- data/vendor/faiss/tests/test_merge.cpp +258 -0
- data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
- data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
- data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
- data/vendor/faiss/tests/test_params_override.cpp +231 -0
- data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
- data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
- data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
- data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
- data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
- data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
- data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
- data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
- data/vendor/faiss/utils/Heap.cpp +122 -0
- data/vendor/faiss/utils/Heap.h +495 -0
- data/vendor/faiss/utils/WorkerThread.cpp +126 -0
- data/vendor/faiss/utils/WorkerThread.h +61 -0
- data/vendor/faiss/utils/distances.cpp +765 -0
- data/vendor/faiss/utils/distances.h +243 -0
- data/vendor/faiss/utils/distances_simd.cpp +809 -0
- data/vendor/faiss/utils/extra_distances.cpp +336 -0
- data/vendor/faiss/utils/extra_distances.h +54 -0
- data/vendor/faiss/utils/hamming-inl.h +472 -0
- data/vendor/faiss/utils/hamming.cpp +792 -0
- data/vendor/faiss/utils/hamming.h +220 -0
- data/vendor/faiss/utils/random.cpp +192 -0
- data/vendor/faiss/utils/random.h +60 -0
- data/vendor/faiss/utils/utils.cpp +783 -0
- data/vendor/faiss/utils/utils.h +181 -0
- metadata +216 -2
|
@@ -0,0 +1,275 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// -*- c++ -*-
|
|
9
|
+
|
|
10
|
+
#pragma once
|
|
11
|
+
|
|
12
|
+
#include <vector>
|
|
13
|
+
#include <unordered_set>
|
|
14
|
+
#include <queue>
|
|
15
|
+
|
|
16
|
+
#include <omp.h>
|
|
17
|
+
|
|
18
|
+
#include <faiss/Index.h>
|
|
19
|
+
#include <faiss/impl/FaissAssert.h>
|
|
20
|
+
#include <faiss/utils/random.h>
|
|
21
|
+
#include <faiss/utils/Heap.h>
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
namespace faiss {
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
/** Implementation of the Hierarchical Navigable Small World
|
|
28
|
+
* datastructure.
|
|
29
|
+
*
|
|
30
|
+
* Efficient and robust approximate nearest neighbor search using
|
|
31
|
+
* Hierarchical Navigable Small World graphs
|
|
32
|
+
*
|
|
33
|
+
* Yu. A. Malkov, D. A. Yashunin, arXiv 2017
|
|
34
|
+
*
|
|
35
|
+
 * This implementation is heavily influenced by the NMSlib
|
|
36
|
+
 * implementation by Yury Malkov and Leonid Boytsov
|
|
37
|
+
* (https://github.com/searchivarius/nmslib)
|
|
38
|
+
*
|
|
39
|
+
* The HNSW object stores only the neighbor link structure, see
|
|
40
|
+
* IndexHNSW.h for the full index object.
|
|
41
|
+
*/
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
struct VisitedTable;
|
|
45
|
+
struct DistanceComputer; // from AuxIndexStructures
|
|
46
|
+
|
|
47
|
+
struct HNSW {
|
|
48
|
+
/// internal storage of vectors (32 bits: this is expensive)
|
|
49
|
+
typedef int storage_idx_t;
|
|
50
|
+
|
|
51
|
+
/// Faiss results are 64-bit
|
|
52
|
+
typedef Index::idx_t idx_t;
|
|
53
|
+
|
|
54
|
+
typedef std::pair<float, storage_idx_t> Node;
|
|
55
|
+
|
|
56
|
+
/** Heap structure that allows fast
|
|
57
|
+
*/
|
|
58
|
+
struct MinimaxHeap {
|
|
59
|
+
int n;
|
|
60
|
+
int k;
|
|
61
|
+
int nvalid;
|
|
62
|
+
|
|
63
|
+
std::vector<storage_idx_t> ids;
|
|
64
|
+
std::vector<float> dis;
|
|
65
|
+
typedef faiss::CMax<float, storage_idx_t> HC;
|
|
66
|
+
|
|
67
|
+
explicit MinimaxHeap(int n): n(n), k(0), nvalid(0), ids(n), dis(n) {}
|
|
68
|
+
|
|
69
|
+
void push(storage_idx_t i, float v);
|
|
70
|
+
|
|
71
|
+
float max() const;
|
|
72
|
+
|
|
73
|
+
int size() const;
|
|
74
|
+
|
|
75
|
+
void clear();
|
|
76
|
+
|
|
77
|
+
int pop_min(float *vmin_out = nullptr);
|
|
78
|
+
|
|
79
|
+
int count_below(float thresh);
|
|
80
|
+
};
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
/// to sort pairs of (id, distance) from nearest to fathest or the reverse
|
|
84
|
+
struct NodeDistCloser {
|
|
85
|
+
float d;
|
|
86
|
+
int id;
|
|
87
|
+
NodeDistCloser(float d, int id): d(d), id(id) {}
|
|
88
|
+
bool operator < (const NodeDistCloser &obj1) const { return d < obj1.d; }
|
|
89
|
+
};
|
|
90
|
+
|
|
91
|
+
struct NodeDistFarther {
|
|
92
|
+
float d;
|
|
93
|
+
int id;
|
|
94
|
+
NodeDistFarther(float d, int id): d(d), id(id) {}
|
|
95
|
+
bool operator < (const NodeDistFarther &obj1) const { return d > obj1.d; }
|
|
96
|
+
};
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
/// assignment probability to each layer (sum=1)
|
|
100
|
+
std::vector<double> assign_probas;
|
|
101
|
+
|
|
102
|
+
/// number of neighbors stored per layer (cumulative), should not
|
|
103
|
+
/// be changed after first add
|
|
104
|
+
std::vector<int> cum_nneighbor_per_level;
|
|
105
|
+
|
|
106
|
+
/// level of each vector (base level = 1), size = ntotal
|
|
107
|
+
std::vector<int> levels;
|
|
108
|
+
|
|
109
|
+
/// offsets[i] is the offset in the neighbors array where vector i is stored
|
|
110
|
+
/// size ntotal + 1
|
|
111
|
+
std::vector<size_t> offsets;
|
|
112
|
+
|
|
113
|
+
/// neighbors[offsets[i]:offsets[i+1]] is the list of neighbors of vector i
|
|
114
|
+
/// for all levels. this is where all storage goes.
|
|
115
|
+
std::vector<storage_idx_t> neighbors;
|
|
116
|
+
|
|
117
|
+
/// entry point in the search structure (one of the points with maximum level
|
|
118
|
+
storage_idx_t entry_point;
|
|
119
|
+
|
|
120
|
+
faiss::RandomGenerator rng;
|
|
121
|
+
|
|
122
|
+
/// maximum level
|
|
123
|
+
int max_level;
|
|
124
|
+
|
|
125
|
+
/// expansion factor at construction time
|
|
126
|
+
int efConstruction;
|
|
127
|
+
|
|
128
|
+
/// expansion factor at search time
|
|
129
|
+
int efSearch;
|
|
130
|
+
|
|
131
|
+
/// during search: do we check whether the next best distance is good enough?
|
|
132
|
+
bool check_relative_distance = true;
|
|
133
|
+
|
|
134
|
+
/// number of entry points in levels > 0.
|
|
135
|
+
int upper_beam;
|
|
136
|
+
|
|
137
|
+
/// use bounded queue during exploration
|
|
138
|
+
bool search_bounded_queue = true;
|
|
139
|
+
|
|
140
|
+
// methods that initialize the tree sizes
|
|
141
|
+
|
|
142
|
+
/// initialize the assign_probas and cum_nneighbor_per_level to
|
|
143
|
+
/// have 2*M links on level 0 and M links on levels > 0
|
|
144
|
+
void set_default_probas(int M, float levelMult);
|
|
145
|
+
|
|
146
|
+
/// set nb of neighbors for this level (before adding anything)
|
|
147
|
+
void set_nb_neighbors(int level_no, int n);
|
|
148
|
+
|
|
149
|
+
// methods that access the tree sizes
|
|
150
|
+
|
|
151
|
+
/// nb of neighbors for this level
|
|
152
|
+
int nb_neighbors(int layer_no) const;
|
|
153
|
+
|
|
154
|
+
/// cumumlative nb up to (and excluding) this level
|
|
155
|
+
int cum_nb_neighbors(int layer_no) const;
|
|
156
|
+
|
|
157
|
+
/// range of entries in the neighbors table of vertex no at layer_no
|
|
158
|
+
void neighbor_range(idx_t no, int layer_no,
|
|
159
|
+
size_t * begin, size_t * end) const;
|
|
160
|
+
|
|
161
|
+
/// only mandatory parameter: nb of neighbors
|
|
162
|
+
explicit HNSW(int M = 32);
|
|
163
|
+
|
|
164
|
+
/// pick a random level for a new point
|
|
165
|
+
int random_level();
|
|
166
|
+
|
|
167
|
+
/// add n random levels to table (for debugging...)
|
|
168
|
+
void fill_with_random_links(size_t n);
|
|
169
|
+
|
|
170
|
+
void add_links_starting_from(DistanceComputer& ptdis,
|
|
171
|
+
storage_idx_t pt_id,
|
|
172
|
+
storage_idx_t nearest,
|
|
173
|
+
float d_nearest,
|
|
174
|
+
int level,
|
|
175
|
+
omp_lock_t *locks,
|
|
176
|
+
VisitedTable &vt);
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
/** add point pt_id on all levels <= pt_level and build the link
|
|
180
|
+
* structure for them. */
|
|
181
|
+
void add_with_locks(DistanceComputer& ptdis, int pt_level, int pt_id,
|
|
182
|
+
std::vector<omp_lock_t>& locks,
|
|
183
|
+
VisitedTable& vt);
|
|
184
|
+
|
|
185
|
+
int search_from_candidates(DistanceComputer& qdis, int k,
|
|
186
|
+
idx_t *I, float *D,
|
|
187
|
+
MinimaxHeap& candidates,
|
|
188
|
+
VisitedTable &vt,
|
|
189
|
+
int level, int nres_in = 0) const;
|
|
190
|
+
|
|
191
|
+
std::priority_queue<Node> search_from_candidate_unbounded(
|
|
192
|
+
const Node& node,
|
|
193
|
+
DistanceComputer& qdis,
|
|
194
|
+
int ef,
|
|
195
|
+
VisitedTable *vt
|
|
196
|
+
) const;
|
|
197
|
+
|
|
198
|
+
/// search interface
|
|
199
|
+
void search(DistanceComputer& qdis, int k,
|
|
200
|
+
idx_t *I, float *D,
|
|
201
|
+
VisitedTable& vt) const;
|
|
202
|
+
|
|
203
|
+
void reset();
|
|
204
|
+
|
|
205
|
+
void clear_neighbor_tables(int level);
|
|
206
|
+
void print_neighbor_stats(int level) const;
|
|
207
|
+
|
|
208
|
+
int prepare_level_tab(size_t n, bool preset_levels = false);
|
|
209
|
+
|
|
210
|
+
static void shrink_neighbor_list(
|
|
211
|
+
DistanceComputer& qdis,
|
|
212
|
+
std::priority_queue<NodeDistFarther>& input,
|
|
213
|
+
std::vector<NodeDistFarther>& output,
|
|
214
|
+
int max_size);
|
|
215
|
+
|
|
216
|
+
};
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
/**************************************************************
|
|
220
|
+
* Auxiliary structures
|
|
221
|
+
**************************************************************/
|
|
222
|
+
|
|
223
|
+
/// Set implementation optimized for fast access.
///
/// Each entry holds a one-byte stamp and counts as "set" when that stamp
/// equals the current value of visno. advance() therefore empties the set
/// by bumping visno; the table itself is only wiped when visno wraps.
struct VisitedTable {
    std::vector<uint8_t> visited;  ///< one stamp per entry, 0 after a wipe
    int visno;                     ///< current stamp, starts at 1

    /// @param size number of entries; all start out unset
    explicit VisitedTable(int size)
        : visited(size), visno(1) {}

    /// set flag #no to true
    void set(int no) {
        visited[no] = static_cast<uint8_t>(visno);
    }

    /// get flag #no
    bool get(int no) const {
        return visited[no] == visno;
    }

    /// reset all flags to false
    void advance() {
        visno++;
        if (visno == 250) {
            // 250 rather than 255 because sometimes we use visno and visno+1
            //
            // vector::assign replaces the former memset: this header does
            // not include <cstring>, and assign is also well-defined when
            // the table is empty.
            visited.assign(visited.size(), 0);
            visno = 1;
        }
    }
};
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
/// Bundle of counters plus one flag. Every counter is 0 and view is false
/// right after construction and after each call to reset().
struct HNSWStats {
    size_t n1, n2, n3;
    size_t ndis;
    size_t nreorder;
    bool view;

    HNSWStats()
        : n1(0), n2(0), n3(0), ndis(0), nreorder(0), view(false) {}

    /// put every field back to its initial value
    void reset() {
        n1 = 0;
        n2 = 0;
        n3 = 0;
        ndis = 0;
        nreorder = 0;
        view = false;
    }
};
|
|
270
|
+
|
|
271
|
+
// global var that collects them all
|
|
272
|
+
extern HNSWStats hnsw_stats;
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
} // namespace faiss
|
|
@@ -0,0 +1,953 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// -*- c++ -*-
|
|
9
|
+
|
|
10
|
+
#include <faiss/impl/PolysemousTraining.h>
|
|
11
|
+
|
|
12
|
+
#include <cstdlib>
|
|
13
|
+
#include <cmath>
|
|
14
|
+
#include <cstring>
|
|
15
|
+
#include <stdint.h>
|
|
16
|
+
|
|
17
|
+
#include <algorithm>
|
|
18
|
+
|
|
19
|
+
#include <faiss/utils/random.h>
|
|
20
|
+
#include <faiss/utils/utils.h>
|
|
21
|
+
#include <faiss/utils/distances.h>
|
|
22
|
+
#include <faiss/utils/hamming.h>
|
|
23
|
+
|
|
24
|
+
#include <faiss/impl/FaissAssert.h>
|
|
25
|
+
|
|
26
|
+
/*****************************************
|
|
27
|
+
* Mixed PQ / Hamming
|
|
28
|
+
******************************************/
|
|
29
|
+
|
|
30
|
+
namespace faiss {
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
/****************************************************
|
|
34
|
+
* Optimization code
|
|
35
|
+
****************************************************/
|
|
36
|
+
|
|
37
|
+
SimulatedAnnealingParameters::SimulatedAnnealingParameters ()
{
    // Reasonable defaults for the optimization.

    // temperature schedule: start at 0.7 and apply a per-iteration factor
    // chosen so that 500 iterations multiply the temperature by 0.9
    init_temperature = 0.7;
    temperature_decay = pow (0.9, 1.0 / 500);

    // amount of work: iterations per run, and number of runs
    n_iter = 500000;
    n_redo = 2;

    // randomness and reporting
    seed = 123;
    verbose = 0;

    // behavior switches, both off by default
    only_bit_flips = false;
    init_random = false;
}
|
|
50
|
+
|
|
51
|
+
// what would the cost update be if iw and jw were swapped?
|
|
52
|
+
// default implementation just computes both and computes the difference
|
|
53
|
+
double PermutationObjective::cost_update (
|
|
54
|
+
const int *perm, int iw, int jw) const
|
|
55
|
+
{
|
|
56
|
+
double orig_cost = compute_cost (perm);
|
|
57
|
+
|
|
58
|
+
std::vector<int> perm2 (n);
|
|
59
|
+
for (int i = 0; i < n; i++)
|
|
60
|
+
perm2[i] = perm[i];
|
|
61
|
+
perm2[iw] = perm[jw];
|
|
62
|
+
perm2[jw] = perm[iw];
|
|
63
|
+
|
|
64
|
+
double new_cost = compute_cost (perm2.data());
|
|
65
|
+
return new_cost - orig_cost;
|
|
66
|
+
}
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
// Build an optimizer for objective obj using a copy of the parameters p.
// The permutation size n is taken from obj->n and must lie in [0, 100000),
// otherwise FAISS_THROW_IF_NOT fires.
SimulatedAnnealingOptimizer::SimulatedAnnealingOptimizer (
    PermutationObjective *obj,
    const SimulatedAnnealingParameters &p):
         SimulatedAnnealingParameters (p),
         obj (obj),
         n(obj->n),
         logfile (nullptr)
{
    // Validate before allocating: when a constructor exits through an
    // exception the destructor does not run, so a generator allocated
    // ahead of this check would never be freed.
    FAISS_THROW_IF_NOT (n < 100000 && n >=0 );
    rnd = new RandomGenerator (p.seed);
}

// Release the random generator allocated by the constructor.
SimulatedAnnealingOptimizer::~SimulatedAnnealingOptimizer ()
{
    delete rnd;
}
|
|
87
|
+
|
|
88
|
+
// run the optimization and return the best result in best_perm
|
|
89
|
+
double SimulatedAnnealingOptimizer::run_optimization (int * best_perm)
|
|
90
|
+
{
|
|
91
|
+
double min_cost = 1e30;
|
|
92
|
+
|
|
93
|
+
// just do a few runs of the annealing and keep the lowest output cost
|
|
94
|
+
for (int it = 0; it < n_redo; it++) {
|
|
95
|
+
std::vector<int> perm(n);
|
|
96
|
+
for (int i = 0; i < n; i++)
|
|
97
|
+
perm[i] = i;
|
|
98
|
+
if (init_random) {
|
|
99
|
+
for (int i = 0; i < n; i++) {
|
|
100
|
+
int j = i + rnd->rand_int (n - i);
|
|
101
|
+
std::swap (perm[i], perm[j]);
|
|
102
|
+
}
|
|
103
|
+
}
|
|
104
|
+
float cost = optimize (perm.data());
|
|
105
|
+
if (logfile) fprintf (logfile, "\n");
|
|
106
|
+
if(verbose > 1) {
|
|
107
|
+
printf (" optimization run %d: cost=%g %s\n",
|
|
108
|
+
it, cost, cost < min_cost ? "keep" : "");
|
|
109
|
+
}
|
|
110
|
+
if (cost < min_cost) {
|
|
111
|
+
memcpy (best_perm, perm.data(), sizeof(perm[0]) * n);
|
|
112
|
+
min_cost = cost;
|
|
113
|
+
}
|
|
114
|
+
}
|
|
115
|
+
return min_cost;
|
|
116
|
+
}
|
|
117
|
+
|
|
118
|
+
// perform the optimization loop, starting from and modifying
|
|
119
|
+
// permutation in-place
|
|
120
|
+
double SimulatedAnnealingOptimizer::optimize (int *perm)
|
|
121
|
+
{
|
|
122
|
+
double cost = init_cost = obj->compute_cost (perm);
|
|
123
|
+
int log2n = 0;
|
|
124
|
+
while (!(n <= (1 << log2n))) log2n++;
|
|
125
|
+
double temperature = init_temperature;
|
|
126
|
+
int n_swap = 0, n_hot = 0;
|
|
127
|
+
for (int it = 0; it < n_iter; it++) {
|
|
128
|
+
temperature = temperature * temperature_decay;
|
|
129
|
+
int iw, jw;
|
|
130
|
+
if (only_bit_flips) {
|
|
131
|
+
iw = rnd->rand_int (n);
|
|
132
|
+
jw = iw ^ (1 << rnd->rand_int (log2n));
|
|
133
|
+
} else {
|
|
134
|
+
iw = rnd->rand_int (n);
|
|
135
|
+
jw = rnd->rand_int (n - 1);
|
|
136
|
+
if (jw == iw) jw++;
|
|
137
|
+
}
|
|
138
|
+
double delta_cost = obj->cost_update (perm, iw, jw);
|
|
139
|
+
if (delta_cost < 0 || rnd->rand_float () < temperature) {
|
|
140
|
+
std::swap (perm[iw], perm[jw]);
|
|
141
|
+
cost += delta_cost;
|
|
142
|
+
n_swap++;
|
|
143
|
+
if (delta_cost >= 0) n_hot++;
|
|
144
|
+
}
|
|
145
|
+
if (verbose > 2 || (verbose > 1 && it % 10000 == 0)) {
|
|
146
|
+
printf (" iteration %d cost %g temp %g n_swap %d "
|
|
147
|
+
"(%d hot) \r",
|
|
148
|
+
it, cost, temperature, n_swap, n_hot);
|
|
149
|
+
fflush(stdout);
|
|
150
|
+
}
|
|
151
|
+
if (logfile) {
|
|
152
|
+
fprintf (logfile, "%d %g %g %d %d\n",
|
|
153
|
+
it, cost, temperature, n_swap, n_hot);
|
|
154
|
+
}
|
|
155
|
+
}
|
|
156
|
+
if (verbose > 1) printf("\n");
|
|
157
|
+
return cost;
|
|
158
|
+
}
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
/****************************************************
|
|
165
|
+
* Cost functions: ReproduceDistanceTable
|
|
166
|
+
****************************************************/
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
/// Number of bit positions in which a and b differ.
///
/// Uses the `long long` popcount builtin: `__builtin_popcountl` takes an
/// `unsigned long`, which is only 32 bits wide on LLP64 / ILP32 targets and
/// would silently drop the upper half of the 64-bit operands there.
static inline int hamming_dis (uint64_t a, uint64_t b)
{
    return __builtin_popcountll (a ^ b);
}
|
|
177
|
+
|
|
178
|
+
namespace {
|
|
179
|
+
|
|
180
|
+
/// optimize permutation to reproduce a distance table with Hamming distances
|
|
181
|
+
struct ReproduceWithHammingObjective : PermutationObjective {
|
|
182
|
+
int nbits;
|
|
183
|
+
double dis_weight_factor;
|
|
184
|
+
|
|
185
|
+
static double sqr (double x) { return x * x; }
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
// weihgting of distances: it is more important to reproduce small
|
|
189
|
+
// distances well
|
|
190
|
+
double dis_weight (double x) const
|
|
191
|
+
{
|
|
192
|
+
return exp (-dis_weight_factor * x);
|
|
193
|
+
}
|
|
194
|
+
|
|
195
|
+
std::vector<double> target_dis; // wanted distances (size n^2)
|
|
196
|
+
std::vector<double> weights; // weights for each distance (size n^2)
|
|
197
|
+
|
|
198
|
+
// cost = quadratic difference between actual distance and Hamming distance
|
|
199
|
+
double compute_cost(const int* perm) const override {
|
|
200
|
+
double cost = 0;
|
|
201
|
+
for (int i = 0; i < n; i++) {
|
|
202
|
+
for (int j = 0; j < n; j++) {
|
|
203
|
+
double wanted = target_dis[i * n + j];
|
|
204
|
+
double w = weights[i * n + j];
|
|
205
|
+
double actual = hamming_dis(perm[i], perm[j]);
|
|
206
|
+
cost += w * sqr(wanted - actual);
|
|
207
|
+
}
|
|
208
|
+
}
|
|
209
|
+
return cost;
|
|
210
|
+
}
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
// what would the cost update be if iw and jw were swapped?
|
|
214
|
+
// computed in O(n) instead of O(n^2) for the full re-computation
|
|
215
|
+
double cost_update(const int* perm, int iw, int jw) const override {
|
|
216
|
+
double delta_cost = 0;
|
|
217
|
+
|
|
218
|
+
for (int i = 0; i < n; i++) {
|
|
219
|
+
if (i == iw) {
|
|
220
|
+
for (int j = 0; j < n; j++) {
|
|
221
|
+
double wanted = target_dis[i * n + j], w = weights[i * n + j];
|
|
222
|
+
double actual = hamming_dis(perm[i], perm[j]);
|
|
223
|
+
delta_cost -= w * sqr(wanted - actual);
|
|
224
|
+
double new_actual =
|
|
225
|
+
hamming_dis(perm[jw], perm[j == iw ? jw : j == jw ? iw : j]);
|
|
226
|
+
delta_cost += w * sqr(wanted - new_actual);
|
|
227
|
+
}
|
|
228
|
+
} else if (i == jw) {
|
|
229
|
+
for (int j = 0; j < n; j++) {
|
|
230
|
+
double wanted = target_dis[i * n + j], w = weights[i * n + j];
|
|
231
|
+
double actual = hamming_dis(perm[i], perm[j]);
|
|
232
|
+
delta_cost -= w * sqr(wanted - actual);
|
|
233
|
+
double new_actual =
|
|
234
|
+
hamming_dis(perm[iw], perm[j == iw ? jw : j == jw ? iw : j]);
|
|
235
|
+
delta_cost += w * sqr(wanted - new_actual);
|
|
236
|
+
}
|
|
237
|
+
} else {
|
|
238
|
+
int j = iw;
|
|
239
|
+
{
|
|
240
|
+
double wanted = target_dis[i * n + j], w = weights[i * n + j];
|
|
241
|
+
double actual = hamming_dis(perm[i], perm[j]);
|
|
242
|
+
delta_cost -= w * sqr(wanted - actual);
|
|
243
|
+
double new_actual = hamming_dis(perm[i], perm[jw]);
|
|
244
|
+
delta_cost += w * sqr(wanted - new_actual);
|
|
245
|
+
}
|
|
246
|
+
j = jw;
|
|
247
|
+
{
|
|
248
|
+
double wanted = target_dis[i * n + j], w = weights[i * n + j];
|
|
249
|
+
double actual = hamming_dis(perm[i], perm[j]);
|
|
250
|
+
delta_cost -= w * sqr(wanted - actual);
|
|
251
|
+
double new_actual = hamming_dis(perm[i], perm[iw]);
|
|
252
|
+
delta_cost += w * sqr(wanted - new_actual);
|
|
253
|
+
}
|
|
254
|
+
}
|
|
255
|
+
}
|
|
256
|
+
|
|
257
|
+
return delta_cost;
|
|
258
|
+
}
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
ReproduceWithHammingObjective (
|
|
263
|
+
int nbits,
|
|
264
|
+
const std::vector<double> & dis_table,
|
|
265
|
+
double dis_weight_factor):
|
|
266
|
+
nbits (nbits), dis_weight_factor (dis_weight_factor)
|
|
267
|
+
{
|
|
268
|
+
n = 1 << nbits;
|
|
269
|
+
FAISS_THROW_IF_NOT (dis_table.size() == n * n);
|
|
270
|
+
set_affine_target_dis (dis_table);
|
|
271
|
+
}
|
|
272
|
+
|
|
273
|
+
void set_affine_target_dis (const std::vector<double> & dis_table)
|
|
274
|
+
{
|
|
275
|
+
double sum = 0, sum2 = 0;
|
|
276
|
+
int n2 = n * n;
|
|
277
|
+
for (int i = 0; i < n2; i++) {
|
|
278
|
+
sum += dis_table [i];
|
|
279
|
+
sum2 += dis_table [i] * dis_table [i];
|
|
280
|
+
}
|
|
281
|
+
double mean = sum / n2;
|
|
282
|
+
double stddev = sqrt(sum2 / n2 - (sum / n2) * (sum / n2));
|
|
283
|
+
|
|
284
|
+
target_dis.resize (n2);
|
|
285
|
+
|
|
286
|
+
for (int i = 0; i < n2; i++) {
|
|
287
|
+
// the mapping function
|
|
288
|
+
double td = (dis_table [i] - mean) / stddev * sqrt(nbits / 4) +
|
|
289
|
+
nbits / 2;
|
|
290
|
+
target_dis[i] = td;
|
|
291
|
+
// compute a weight
|
|
292
|
+
weights.push_back (dis_weight (td));
|
|
293
|
+
}
|
|
294
|
+
|
|
295
|
+
}
|
|
296
|
+
|
|
297
|
+
~ReproduceWithHammingObjective() override {}
|
|
298
|
+
};
|
|
299
|
+
|
|
300
|
+
} // anonymous namespace
|
|
301
|
+
|
|
302
|
+
// weihgting of distances: it is more important to reproduce small
|
|
303
|
+
// distances well
|
|
304
|
+
double ReproduceDistancesObjective::dis_weight (double x) const
|
|
305
|
+
{
|
|
306
|
+
return exp (-dis_weight_factor * x);
|
|
307
|
+
}
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
double ReproduceDistancesObjective::get_source_dis (int i, int j) const
|
|
311
|
+
{
|
|
312
|
+
return source_dis [i * n + j];
|
|
313
|
+
}
|
|
314
|
+
|
|
315
|
+
// cost = quadratic difference between actual distance and Hamming distance
|
|
316
|
+
double ReproduceDistancesObjective::compute_cost (const int *perm) const
|
|
317
|
+
{
|
|
318
|
+
double cost = 0;
|
|
319
|
+
for (int i = 0; i < n; i++) {
|
|
320
|
+
for (int j = 0; j < n; j++) {
|
|
321
|
+
double wanted = target_dis [i * n + j];
|
|
322
|
+
double w = weights [i * n + j];
|
|
323
|
+
double actual = get_source_dis (perm[i], perm[j]);
|
|
324
|
+
cost += w * sqr (wanted - actual);
|
|
325
|
+
}
|
|
326
|
+
}
|
|
327
|
+
return cost;
|
|
328
|
+
}
|
|
329
|
+
|
|
330
|
+
// what would the cost update be if iw and jw were swapped?
|
|
331
|
+
// computed in O(n) instead of O(n^2) for the full re-computation
|
|
332
|
+
double ReproduceDistancesObjective::cost_update(
|
|
333
|
+
const int *perm, int iw, int jw) const
|
|
334
|
+
{
|
|
335
|
+
double delta_cost = 0;
|
|
336
|
+
for (int i = 0; i < n; i++) {
|
|
337
|
+
if (i == iw) {
|
|
338
|
+
for (int j = 0; j < n; j++) {
|
|
339
|
+
double wanted = target_dis [i * n + j],
|
|
340
|
+
w = weights [i * n + j];
|
|
341
|
+
double actual = get_source_dis (perm[i], perm[j]);
|
|
342
|
+
delta_cost -= w * sqr (wanted - actual);
|
|
343
|
+
double new_actual = get_source_dis (
|
|
344
|
+
perm[jw],
|
|
345
|
+
perm[j == iw ? jw : j == jw ? iw : j]);
|
|
346
|
+
delta_cost += w * sqr (wanted - new_actual);
|
|
347
|
+
}
|
|
348
|
+
} else if (i == jw) {
|
|
349
|
+
for (int j = 0; j < n; j++) {
|
|
350
|
+
double wanted = target_dis [i * n + j],
|
|
351
|
+
w = weights [i * n + j];
|
|
352
|
+
double actual = get_source_dis (perm[i], perm[j]);
|
|
353
|
+
delta_cost -= w * sqr (wanted - actual);
|
|
354
|
+
double new_actual = get_source_dis (
|
|
355
|
+
perm[iw],
|
|
356
|
+
perm[j == iw ? jw : j == jw ? iw : j]);
|
|
357
|
+
delta_cost += w * sqr (wanted - new_actual);
|
|
358
|
+
}
|
|
359
|
+
} else {
|
|
360
|
+
int j = iw;
|
|
361
|
+
{
|
|
362
|
+
double wanted = target_dis [i * n + j],
|
|
363
|
+
w = weights [i * n + j];
|
|
364
|
+
double actual = get_source_dis (perm[i], perm[j]);
|
|
365
|
+
delta_cost -= w * sqr (wanted - actual);
|
|
366
|
+
double new_actual = get_source_dis (perm[i], perm[jw]);
|
|
367
|
+
delta_cost += w * sqr (wanted - new_actual);
|
|
368
|
+
}
|
|
369
|
+
j = jw;
|
|
370
|
+
{
|
|
371
|
+
double wanted = target_dis [i * n + j],
|
|
372
|
+
w = weights [i * n + j];
|
|
373
|
+
double actual = get_source_dis (perm[i], perm[j]);
|
|
374
|
+
delta_cost -= w * sqr (wanted - actual);
|
|
375
|
+
double new_actual = get_source_dis (perm[i], perm[iw]);
|
|
376
|
+
delta_cost += w * sqr (wanted - new_actual);
|
|
377
|
+
}
|
|
378
|
+
}
|
|
379
|
+
}
|
|
380
|
+
return delta_cost;
|
|
381
|
+
}
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
|
|
385
|
+
// Stores the weight factor and the target distance pointer, records n,
// then builds source_dis and weights from source_dis_in through
// set_affine_target_dis.
ReproduceDistancesObjective::ReproduceDistancesObjective (
        int n,
        const double *source_dis_in,
        const double *target_dis_in,
        double dis_weight_factor)
    : dis_weight_factor (dis_weight_factor),
      target_dis (target_dis_in)
{
    // the parameter hides the member of the same name, hence this->
    this->n = n;

    set_affine_target_dis (source_dis_in);
}
|
|
396
|
+
|
|
397
|
+
// Mean and (population) standard deviation of the n2 first entries of tab.
//
// @param tab         input values
// @param n2          number of values to read
// @param mean_out    receives the mean
// @param stddev_out  receives sqrt (E[x^2] - E[x]^2)
void ReproduceDistancesObjective::compute_mean_stdev (
        const double *tab, size_t n2,
        double *mean_out, double *stddev_out)
{
    double sum = 0, sum2 = 0;
    // size_t index: matches the type of the bound, so there is no
    // signed / unsigned comparison and no int overflow on large tables
    for (size_t i = 0; i < n2; i++) {
        sum += tab [i];
        sum2 += tab [i] * tab [i];
    }
    double mean = sum / n2;
    double variance = sum2 / n2 - (sum / n2) * (sum / n2);
    // Rounding can make the variance of a (near-)constant table slightly
    // negative, which would turn the standard deviation into NaN; clamp
    // that case to 0. A NaN variance (e.g. n2 == 0) still yields NaN.
    double stddev = variance < 0 ? 0 : sqrt(variance);
    *mean_out = mean;
    *stddev_out = stddev;
}
|
|
411
|
+
|
|
412
|
+
void ReproduceDistancesObjective::set_affine_target_dis (
|
|
413
|
+
const double *source_dis_in)
|
|
414
|
+
{
|
|
415
|
+
int n2 = n * n;
|
|
416
|
+
|
|
417
|
+
double mean_src, stddev_src;
|
|
418
|
+
compute_mean_stdev (source_dis_in, n2, &mean_src, &stddev_src);
|
|
419
|
+
|
|
420
|
+
double mean_target, stddev_target;
|
|
421
|
+
compute_mean_stdev (target_dis, n2, &mean_target, &stddev_target);
|
|
422
|
+
|
|
423
|
+
printf ("map mean %g std %g -> mean %g std %g\n",
|
|
424
|
+
mean_src, stddev_src, mean_target, stddev_target);
|
|
425
|
+
|
|
426
|
+
source_dis.resize (n2);
|
|
427
|
+
weights.resize (n2);
|
|
428
|
+
|
|
429
|
+
for (int i = 0; i < n2; i++) {
|
|
430
|
+
// the mapping function
|
|
431
|
+
source_dis[i] = (source_dis_in[i] - mean_src) / stddev_src
|
|
432
|
+
* stddev_target + mean_target;
|
|
433
|
+
|
|
434
|
+
// compute a weight
|
|
435
|
+
weights [i] = dis_weight (target_dis[i]);
|
|
436
|
+
}
|
|
437
|
+
|
|
438
|
+
}
|
|
439
|
+
|
|
440
|
+
/****************************************************
|
|
441
|
+
* Cost functions: RankingScore
|
|
442
|
+
****************************************************/
|
|
443
|
+
|
|
444
|
+
/// Maintains a 3D table of elementary costs.
|
|
445
|
+
/// Accumulates elements based on Hamming distance comparisons
|
|
446
|
+
template <typename Ttab, typename Taccu>
|
|
447
|
+
struct Score3Computer: PermutationObjective {
|
|
448
|
+
|
|
449
|
+
int nc;
|
|
450
|
+
|
|
451
|
+
// cost matrix of size nc * nc *nc
|
|
452
|
+
// n_gt (i,j,k) = count of d_gt(x, y-) < d_gt(x, y+)
|
|
453
|
+
// where x has PQ code i, y- PQ code j and y+ PQ code k
|
|
454
|
+
std::vector<Ttab> n_gt;
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
/// the cost is a triple loop on the nc * nc * nc matrix of entries.
|
|
458
|
+
///
|
|
459
|
+
Taccu compute (const int * perm) const
|
|
460
|
+
{
|
|
461
|
+
Taccu accu = 0;
|
|
462
|
+
const Ttab *p = n_gt.data();
|
|
463
|
+
for (int i = 0; i < nc; i++) {
|
|
464
|
+
int ip = perm [i];
|
|
465
|
+
for (int j = 0; j < nc; j++) {
|
|
466
|
+
int jp = perm [j];
|
|
467
|
+
for (int k = 0; k < nc; k++) {
|
|
468
|
+
int kp = perm [k];
|
|
469
|
+
if (hamming_dis (ip, jp) <
|
|
470
|
+
hamming_dis (ip, kp)) {
|
|
471
|
+
accu += *p; // n_gt [ ( i * nc + j) * nc + k];
|
|
472
|
+
}
|
|
473
|
+
p++;
|
|
474
|
+
}
|
|
475
|
+
}
|
|
476
|
+
}
|
|
477
|
+
return accu;
|
|
478
|
+
}
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
/** cost update if entries iw and jw of the permutation would be
|
|
482
|
+
* swapped.
|
|
483
|
+
*
|
|
484
|
+
* The computation is optimized by avoiding elements in the
|
|
485
|
+
* nc*nc*nc cube that are known not to change. For nc=256, this
|
|
486
|
+
* reduces the nb of cells to visit to about 6/256 th of the
|
|
487
|
+
* cells. Practical speedup is about 8x, and the code is quite
|
|
488
|
+
* complex :-/
|
|
489
|
+
*/
|
|
490
|
+
Taccu compute_update (const int *perm, int iw, int jw) const
|
|
491
|
+
{
|
|
492
|
+
assert (iw != jw);
|
|
493
|
+
if (iw > jw) std::swap (iw, jw);
|
|
494
|
+
|
|
495
|
+
Taccu accu = 0;
|
|
496
|
+
const Ttab * n_gt_i = n_gt.data();
|
|
497
|
+
for (int i = 0; i < nc; i++) {
|
|
498
|
+
int ip0 = perm [i];
|
|
499
|
+
int ip = perm [i == iw ? jw : i == jw ? iw : i];
|
|
500
|
+
|
|
501
|
+
//accu += update_i (perm, iw, jw, ip0, ip, n_gt_i);
|
|
502
|
+
|
|
503
|
+
accu += update_i_cross (perm, iw, jw,
|
|
504
|
+
ip0, ip, n_gt_i);
|
|
505
|
+
|
|
506
|
+
if (ip != ip0)
|
|
507
|
+
accu += update_i_plane (perm, iw, jw,
|
|
508
|
+
ip0, ip, n_gt_i);
|
|
509
|
+
|
|
510
|
+
n_gt_i += nc * nc;
|
|
511
|
+
}
|
|
512
|
+
|
|
513
|
+
return accu;
|
|
514
|
+
}
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
Taccu update_i (const int *perm, int iw, int jw,
|
|
518
|
+
int ip0, int ip, const Ttab * n_gt_i) const
|
|
519
|
+
{
|
|
520
|
+
Taccu accu = 0;
|
|
521
|
+
const Ttab *n_gt_ij = n_gt_i;
|
|
522
|
+
for (int j = 0; j < nc; j++) {
|
|
523
|
+
int jp0 = perm[j];
|
|
524
|
+
int jp = perm [j == iw ? jw : j == jw ? iw : j];
|
|
525
|
+
for (int k = 0; k < nc; k++) {
|
|
526
|
+
int kp0 = perm [k];
|
|
527
|
+
int kp = perm [k == iw ? jw : k == jw ? iw : k];
|
|
528
|
+
int ng = n_gt_ij [k];
|
|
529
|
+
if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
|
|
530
|
+
accu += ng;
|
|
531
|
+
}
|
|
532
|
+
if (hamming_dis (ip0, jp0) < hamming_dis (ip0, kp0)) {
|
|
533
|
+
accu -= ng;
|
|
534
|
+
}
|
|
535
|
+
}
|
|
536
|
+
n_gt_ij += nc;
|
|
537
|
+
}
|
|
538
|
+
return accu;
|
|
539
|
+
}
|
|
540
|
+
|
|
541
|
+
// 2 inner loops for the case ip0 != ip
|
|
542
|
+
Taccu update_i_plane (const int *perm, int iw, int jw,
|
|
543
|
+
int ip0, int ip, const Ttab * n_gt_i) const
|
|
544
|
+
{
|
|
545
|
+
Taccu accu = 0;
|
|
546
|
+
const Ttab *n_gt_ij = n_gt_i;
|
|
547
|
+
|
|
548
|
+
for (int j = 0; j < nc; j++) {
|
|
549
|
+
if (j != iw && j != jw) {
|
|
550
|
+
int jp = perm[j];
|
|
551
|
+
for (int k = 0; k < nc; k++) {
|
|
552
|
+
if (k != iw && k != jw) {
|
|
553
|
+
int kp = perm [k];
|
|
554
|
+
Ttab ng = n_gt_ij [k];
|
|
555
|
+
if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
|
|
556
|
+
accu += ng;
|
|
557
|
+
}
|
|
558
|
+
if (hamming_dis (ip0, jp) < hamming_dis (ip0, kp)) {
|
|
559
|
+
accu -= ng;
|
|
560
|
+
}
|
|
561
|
+
}
|
|
562
|
+
}
|
|
563
|
+
}
|
|
564
|
+
n_gt_ij += nc;
|
|
565
|
+
}
|
|
566
|
+
return accu;
|
|
567
|
+
}
|
|
568
|
+
|
|
569
|
+
/// used for the 8 cells were the 3 indices are swapped
|
|
570
|
+
inline Taccu update_k (const int *perm, int iw, int jw,
|
|
571
|
+
int ip0, int ip, int jp0, int jp,
|
|
572
|
+
int k,
|
|
573
|
+
const Ttab * n_gt_ij) const
|
|
574
|
+
{
|
|
575
|
+
Taccu accu = 0;
|
|
576
|
+
int kp0 = perm [k];
|
|
577
|
+
int kp = perm [k == iw ? jw : k == jw ? iw : k];
|
|
578
|
+
Ttab ng = n_gt_ij [k];
|
|
579
|
+
if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
|
|
580
|
+
accu += ng;
|
|
581
|
+
}
|
|
582
|
+
if (hamming_dis (ip0, jp0) < hamming_dis (ip0, kp0)) {
|
|
583
|
+
accu -= ng;
|
|
584
|
+
}
|
|
585
|
+
return accu;
|
|
586
|
+
}
|
|
587
|
+
|
|
588
|
+
/// compute update on a line of k's, where i and j are swapped
|
|
589
|
+
Taccu update_j_line (const int *perm, int iw, int jw,
|
|
590
|
+
int ip0, int ip, int jp0, int jp,
|
|
591
|
+
const Ttab * n_gt_ij) const
|
|
592
|
+
{
|
|
593
|
+
Taccu accu = 0;
|
|
594
|
+
for (int k = 0; k < nc; k++) {
|
|
595
|
+
if (k == iw || k == jw) continue;
|
|
596
|
+
int kp = perm [k];
|
|
597
|
+
Ttab ng = n_gt_ij [k];
|
|
598
|
+
if (hamming_dis (ip, jp) < hamming_dis (ip, kp)) {
|
|
599
|
+
accu += ng;
|
|
600
|
+
}
|
|
601
|
+
if (hamming_dis (ip0, jp0) < hamming_dis (ip0, kp)) {
|
|
602
|
+
accu -= ng;
|
|
603
|
+
}
|
|
604
|
+
}
|
|
605
|
+
return accu;
|
|
606
|
+
}
|
|
607
|
+
|
|
608
|
+
|
|
609
|
+
/// considers the 2 pairs of crossing lines j=iw or jw and k = iw or kw
|
|
610
|
+
Taccu update_i_cross (const int *perm, int iw, int jw,
|
|
611
|
+
int ip0, int ip, const Ttab * n_gt_i) const
|
|
612
|
+
{
|
|
613
|
+
Taccu accu = 0;
|
|
614
|
+
const Ttab *n_gt_ij = n_gt_i;
|
|
615
|
+
|
|
616
|
+
for (int j = 0; j < nc; j++) {
|
|
617
|
+
int jp0 = perm[j];
|
|
618
|
+
int jp = perm [j == iw ? jw : j == jw ? iw : j];
|
|
619
|
+
|
|
620
|
+
accu += update_k (perm, iw, jw, ip0, ip, jp0, jp, iw, n_gt_ij);
|
|
621
|
+
accu += update_k (perm, iw, jw, ip0, ip, jp0, jp, jw, n_gt_ij);
|
|
622
|
+
|
|
623
|
+
if (jp != jp0)
|
|
624
|
+
accu += update_j_line (perm, iw, jw, ip0, ip, jp0, jp, n_gt_ij);
|
|
625
|
+
|
|
626
|
+
n_gt_ij += nc;
|
|
627
|
+
}
|
|
628
|
+
return accu;
|
|
629
|
+
}
|
|
630
|
+
|
|
631
|
+
|
|
632
|
+
/// PermutationObjective implementeation (just negates the scores
|
|
633
|
+
/// for minimization)
|
|
634
|
+
|
|
635
|
+
double compute_cost(const int* perm) const override {
|
|
636
|
+
return -compute(perm);
|
|
637
|
+
}
|
|
638
|
+
|
|
639
|
+
double cost_update(const int* perm, int iw, int jw) const override {
|
|
640
|
+
double ret = -compute_update(perm, iw, jw);
|
|
641
|
+
return ret;
|
|
642
|
+
}
|
|
643
|
+
|
|
644
|
+
~Score3Computer() override {}
|
|
645
|
+
};
|
|
646
|
+
|
|
647
|
+
|
|
648
|
+
|
|
649
|
+
|
|
650
|
+
|
|
651
|
+
/// Comparator that orders integer indices by the value they designate
/// in an external float table.
struct IndirectSort {
    /// table the indices refer to (not owned)
    const float *tab;

    /// true iff tab[a] < tab[b]; const-qualified so that the comparator
    /// can also be invoked on a const object
    bool operator () (int a, int b) const {return tab[a] < tab[b]; }
};
|
|
655
|
+
|
|
656
|
+
|
|
657
|
+
|
|
658
|
+
/// Score3Computer whose table is filled from ground-truth distances
/// between nq query items and nb database items, each item being
/// represented by an nbits-bit code.
struct RankingScore2: Score3Computer<float, double> {
    int nbits;
    int nq, nb;
    const uint32_t *qcodes, *bcodes;
    const float *gt_distances;

    /// @param nbits         the table has 1 << nbits entries per dimension
    /// @param nq            number of queries
    /// @param nb            number of database items
    /// @param qcodes        code of each query (size nq)
    /// @param bcodes        code of each database item (size nb)
    /// @param gt_distances  nq rows of nb distances
    RankingScore2 (int nbits, int nq, int nb,
                   const uint32_t *qcodes, const uint32_t *bcodes,
                   const float *gt_distances):
        nbits(nbits), nq(nq), nb(nb), qcodes(qcodes),
        bcodes(bcodes), gt_distances(gt_distances)
    {
        n = nc = 1 << nbits;
        // size computed in size_t: the int product nc * nc * nc
        // overflows for nbits > 10
        n_gt.resize (static_cast<size_t> (nc) * nc * nc);
        init_n_gt ();
    }


    /// weight of rank r: 1 / (r + 1)
    double rank_weight (int r)
    {
        return 1.0 / (r + 1);
    }

    /// Accumulate, over the pairs (i, j) in a x b such that i < j,
    /// rank_weight (i) * rank_weight (j - i).
    /// a and b should be sorted on input;
    /// they are the ranks of j and k respectively.
    /// Specific version for diff-of-rank weighting, cannot be optimized
    /// with a cumulative table.
    double accum_gt_weight_diff (const std::vector<int> & a,
                                 const std::vector<int> & b)
    {
        // local sizes, named so that they do not shadow the nb member
        const int na = a.size(), nb_items = b.size();

        double accu = 0;
        int j = 0;
        for (int i = 0; i < na; i++) {
            int ai = a[i];
            // first element of b that is strictly larger than ai; j only
            // moves forward because a is sorted
            while (j < nb_items && ai >= b[j]) j++;

            double accu_i = 0;
            for (int k = j; k < nb_items; k++)
                accu_i += rank_weight (b[k] - ai);

            accu += rank_weight (ai) * accu_i;

        }
        return accu;
    }

    /// Fill n_gt: for each query, rank the database items by
    /// ground-truth distance, group the ranks per database code, and
    /// accumulate accum_gt_weight_diff for every pair of codes into the
    /// slice of n_gt selected by the query code.
    void init_n_gt ()
    {
        for (int q = 0; q < nq; q++) {
            // offsets computed in size_t: q * nb and qcodes[q] * nc * nc
            // can exceed the range of 32-bit arithmetic
            const float *gtd = gt_distances + static_cast<size_t> (q) * nb;
            const uint32_t *cb = bcodes;// all same codes
            float * n_gt_q = & n_gt [static_cast<size_t> (qcodes[q]) * nc * nc];

            printf("init gt for q=%d/%d \r", q, nq); fflush(stdout);

            std::vector<int> rankv (nb);
            int * ranks = rankv.data();

            // elements in each code bin, ordered by rank within each bin
            std::vector<std::vector<int> > tab (nc);

            { // build rank table
                IndirectSort s = {gtd};
                for (int j = 0; j < nb; j++) ranks[j] = j;
                std::sort (ranks, ranks + nb, s);
            }

            // ranks are visited in increasing order, so each bin ends up
            // sorted, as accum_gt_weight_diff expects
            for (int rank = 0; rank < nb; rank++) {
                int i = ranks [rank];
                tab [cb[i]].push_back (rank);
            }


            // this is very expensive. Any suggestion for improvement
            // welcome.
            for (int i = 0; i < nc; i++) {
                std::vector<int> & di = tab[i];
                for (int j = 0; j < nc; j++) {
                    std::vector<int> & dj = tab[j];
                    n_gt_q [i * nc + j] += accum_gt_weight_diff (di, dj);

                }
            }

        }

    }

};
|
|
750
|
+
|
|
751
|
+
|
|
752
|
+
/*****************************************
|
|
753
|
+
* PolysemousTraining
|
|
754
|
+
******************************************/
|
|
755
|
+
|
|
756
|
+
|
|
757
|
+
|
|
758
|
+
// Default settings: affine distance reproduction, ntrain_permutation
// set to 0 and a distance weight factor of log(2).
PolysemousTraining::PolysemousTraining ()
{
    dis_weight_factor = log(2);
    ntrain_permutation = 0;
    optimization_type = OT_ReproduceDistances_affine;
}
|
|
764
|
+
|
|
765
|
+
|
|
766
|
+
|
|
767
|
+
// For each of the pq.M sub-quantizers: build the table of squared L2
// distances between its centroids, search a permutation of the centroid
// indices with simulated annealing on a ReproduceWithHammingObjective,
// then reorder the centroids in place according to that permutation.
void PolysemousTraining::optimize_reproduce_distances (
        ProductQuantizer &pq) const
{
    int dsub = pq.dsub;
    int n = pq.ksub;
    int nbits = pq.nbits;

#pragma omp parallel for
    for (int m = 0; m < pq.M; m++) {
        // printf ("Optimizing quantizer %d\n", m);

        float * centroids = pq.get_centroids (m, 0);

        // n * n centroid-to-centroid distances, row by row
        std::vector<double> dis_table;
        for (int i = 0; i < n; i++) {
            const float *ci = centroids + i * dsub;
            for (int j = 0; j < n; j++) {
                const float *cj = centroids + j * dsub;
                dis_table.push_back (fvec_L2sqr (ci, cj, dsub));
            }
        }

        std::vector<int> perm (n);
        ReproduceWithHammingObjective obj (
               nbits, dis_table,
               dis_weight_factor);

        SimulatedAnnealingOptimizer optim (&obj, *this);

        // a log file is written only when a file name pattern is set
        const bool with_log = log_pattern.size() != 0;
        if (with_log) {
            char fname[256];
            snprintf (fname, 256, log_pattern.c_str(), m);
            printf ("opening log file %s\n", fname);
            optim.logfile = fopen (fname, "w");
            FAISS_THROW_IF_NOT_MSG (optim.logfile, "could not open logfile");
        }

        double final_cost = optim.run_optimization (perm.data());

        if (verbose > 0) {
            printf ("SimulatedAnnealingOptimizer for m=%d: %g -> %g\n",
                    m, optim.init_cost, final_cost);
        }

        if (with_log) fclose (optim.logfile);

        // snapshot of the centroids, so that the scatter below does not
        // read entries it has already overwritten
        std::vector<float> centroids_copy (centroids, centroids + dsub * n);

        // centroid i moves to slot perm[i]
        for (int i = 0; i < n; i++) {
            memcpy (centroids + perm[i] * dsub,
                    centroids_copy.data() + i * dsub,
                    dsub * sizeof(centroids[0]));
        }
    }
}
|
|
828
|
+
|
|
829
|
+
|
|
830
|
+
// Optimize the centroid order of each sub-quantizer with a RankingScore2
// objective and simulated annealing, then reorder the centroids in place.
//
// @param pq  product quantizer to update; pq.nbits must be 8
// @param n   number of training vectors. With n > 0 the first n / 4
//            vectors are used as queries and the others as database;
//            with n == 0 the centroids play both roles and the
//            distances are taken from pq.sdc_table.
// @param x   training vectors, n rows of pq.d floats
void PolysemousTraining::optimize_ranking (
        ProductQuantizer &pq, size_t n, const float *x) const
{

    int dsub = pq.dsub;

    int nbits = pq.nbits;

    std::vector<uint8_t> all_codes (pq.code_size * n);

    pq.compute_codes (x, all_codes.data(), n);

    FAISS_THROW_IF_NOT (pq.nbits == 8);

    if (n == 0)
        pq.compute_sdc_table ();

#pragma omp parallel for
    for (int m = 0; m < pq.M; m++) {
        size_t nq, nb;
        std::vector <uint32_t> codes;     // query codes, then db codes
        std::vector <float> gt_distances; // nq * nb matrix of distances

        if (n > 0) {
            // sub-vectors of the training set for sub-quantizer m.
            // size_t indices: same type as n, and i * dsub is no longer
            // an int product
            std::vector<float> xtrain (n * dsub);
            for (size_t i = 0; i < n; i++)
                memcpy (xtrain.data() + i * dsub,
                        x + i * pq.d + m * dsub,
                        sizeof(float) * dsub);

            // one byte of code per sub-quantizer (nbits == 8)
            codes.resize (n);
            for (size_t i = 0; i < n; i++)
                codes [i] = all_codes [i * pq.code_size + m];

            nq = n / 4; nb = n - nq;
            const float *xq = xtrain.data();
            const float *xb = xq + nq * dsub;

            gt_distances.resize (nq * nb);

            pairwise_L2sqr (dsub,
                            nq, xq,
                            nb, xb,
                            gt_distances.data());
        } else {
            // no training data: query i and database item i both have
            // code i
            nq = nb = pq.ksub;
            codes.resize (2 * nq);
            for (size_t i = 0; i < nq; i++)
                codes[i] = codes [i + nq] = i;

            gt_distances.resize (nq * nb);

            memcpy (gt_distances.data (),
                    pq.sdc_table.data () + m * nq * nb,
                    sizeof (float) * nq * nb);
        }

        double t0 = getmillisecs ();

        PermutationObjective *obj = new RankingScore2 (
                       nbits, nq, nb,
                       codes.data(), codes.data() + nq,
                       gt_distances.data ());
        ScopeDeleter1<PermutationObjective> del (obj);

        if (verbose > 0) {
            // %zu: nq and nb are size_t, for which %ld is not a valid
            // conversion on every platform
            printf(" m=%d, nq=%zu, nb=%zu, intialize RankingScore "
                   "in %.3f ms\n",
                   m, nq, nb, getmillisecs () - t0);
        }

        SimulatedAnnealingOptimizer optim (obj, *this);

        if (log_pattern.size()) {
            char fname[256];
            snprintf (fname, 256, log_pattern.c_str(), m);
            printf ("opening log file %s\n", fname);
            optim.logfile = fopen (fname, "w");
            FAISS_THROW_IF_NOT_FMT (optim.logfile,
                                    "could not open logfile %s", fname);
        }

        std::vector<int> perm (pq.ksub);

        double final_cost = optim.run_optimization (perm.data());
        printf ("SimulatedAnnealingOptimizer for m=%d: %g -> %g\n",
                m, optim.init_cost, final_cost);

        if (log_pattern.size()) fclose (optim.logfile);

        float * centroids = pq.get_centroids (m, 0);

        // snapshot of the centroids, so that the scatter below does not
        // read entries it has already overwritten
        std::vector<float> centroids_copy;
        for (int i = 0; i < dsub * pq.ksub; i++)
            centroids_copy.push_back (centroids[i]);

        // centroid i moves to slot perm[i]
        for (int i = 0; i < pq.ksub; i++)
            memcpy (centroids + perm[i] * dsub,
                    centroids_copy.data() + i * dsub,
                    dsub * sizeof(centroids[0]));

    }

}
|
|
934
|
+
|
|
935
|
+
|
|
936
|
+
|
|
937
|
+
// Run the optimization selected by optimization_type on pq (nothing for
// OT_None), then recompute the SDC table of pq in every case.
void PolysemousTraining::optimize_pq_for_hamming (ProductQuantizer &pq,
                                  size_t n, const float *x) const
{
    if (optimization_type == OT_ReproduceDistances_affine) {
        optimize_reproduce_distances (pq);
    } else if (optimization_type != OT_None) {
        // every remaining optimization type is a ranking optimization
        optimize_ranking (pq, n, x);
    }

    pq.compute_sdc_table ();
}
|
|
951
|
+
|
|
952
|
+
|
|
953
|
+
} // namespace faiss
|