faiss 0.1.0 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +103 -3
- data/ext/faiss/ext.cpp +99 -32
- data/ext/faiss/extconf.rb +12 -2
- data/lib/faiss/ext.bundle +0 -0
- data/lib/faiss/index.rb +3 -3
- data/lib/faiss/index_binary.rb +3 -3
- data/lib/faiss/kmeans.rb +1 -1
- data/lib/faiss/pca_matrix.rb +2 -2
- data/lib/faiss/product_quantizer.rb +3 -3
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/AutoTune.cpp +719 -0
- data/vendor/faiss/AutoTune.h +212 -0
- data/vendor/faiss/Clustering.cpp +261 -0
- data/vendor/faiss/Clustering.h +101 -0
- data/vendor/faiss/IVFlib.cpp +339 -0
- data/vendor/faiss/IVFlib.h +132 -0
- data/vendor/faiss/Index.cpp +171 -0
- data/vendor/faiss/Index.h +261 -0
- data/vendor/faiss/Index2Layer.cpp +437 -0
- data/vendor/faiss/Index2Layer.h +85 -0
- data/vendor/faiss/IndexBinary.cpp +77 -0
- data/vendor/faiss/IndexBinary.h +163 -0
- data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
- data/vendor/faiss/IndexBinaryFlat.h +54 -0
- data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
- data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
- data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
- data/vendor/faiss/IndexBinaryHNSW.h +56 -0
- data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
- data/vendor/faiss/IndexBinaryIVF.h +211 -0
- data/vendor/faiss/IndexFlat.cpp +508 -0
- data/vendor/faiss/IndexFlat.h +175 -0
- data/vendor/faiss/IndexHNSW.cpp +1090 -0
- data/vendor/faiss/IndexHNSW.h +170 -0
- data/vendor/faiss/IndexIVF.cpp +909 -0
- data/vendor/faiss/IndexIVF.h +353 -0
- data/vendor/faiss/IndexIVFFlat.cpp +502 -0
- data/vendor/faiss/IndexIVFFlat.h +118 -0
- data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
- data/vendor/faiss/IndexIVFPQ.h +161 -0
- data/vendor/faiss/IndexIVFPQR.cpp +219 -0
- data/vendor/faiss/IndexIVFPQR.h +65 -0
- data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
- data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
- data/vendor/faiss/IndexLSH.cpp +225 -0
- data/vendor/faiss/IndexLSH.h +87 -0
- data/vendor/faiss/IndexLattice.cpp +143 -0
- data/vendor/faiss/IndexLattice.h +68 -0
- data/vendor/faiss/IndexPQ.cpp +1188 -0
- data/vendor/faiss/IndexPQ.h +199 -0
- data/vendor/faiss/IndexPreTransform.cpp +288 -0
- data/vendor/faiss/IndexPreTransform.h +91 -0
- data/vendor/faiss/IndexReplicas.cpp +123 -0
- data/vendor/faiss/IndexReplicas.h +76 -0
- data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
- data/vendor/faiss/IndexScalarQuantizer.h +127 -0
- data/vendor/faiss/IndexShards.cpp +317 -0
- data/vendor/faiss/IndexShards.h +100 -0
- data/vendor/faiss/InvertedLists.cpp +623 -0
- data/vendor/faiss/InvertedLists.h +334 -0
- data/vendor/faiss/LICENSE +21 -0
- data/vendor/faiss/MatrixStats.cpp +252 -0
- data/vendor/faiss/MatrixStats.h +62 -0
- data/vendor/faiss/MetaIndexes.cpp +351 -0
- data/vendor/faiss/MetaIndexes.h +126 -0
- data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
- data/vendor/faiss/OnDiskInvertedLists.h +127 -0
- data/vendor/faiss/VectorTransform.cpp +1157 -0
- data/vendor/faiss/VectorTransform.h +322 -0
- data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
- data/vendor/faiss/c_api/AutoTune_c.h +64 -0
- data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
- data/vendor/faiss/c_api/Clustering_c.h +117 -0
- data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
- data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
- data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
- data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
- data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
- data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
- data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
- data/vendor/faiss/c_api/IndexShards_c.h +42 -0
- data/vendor/faiss/c_api/Index_c.cpp +105 -0
- data/vendor/faiss/c_api/Index_c.h +183 -0
- data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
- data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
- data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
- data/vendor/faiss/c_api/clone_index_c.h +32 -0
- data/vendor/faiss/c_api/error_c.h +42 -0
- data/vendor/faiss/c_api/error_impl.cpp +27 -0
- data/vendor/faiss/c_api/error_impl.h +16 -0
- data/vendor/faiss/c_api/faiss_c.h +58 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
- data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
- data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
- data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
- data/vendor/faiss/c_api/index_factory_c.h +30 -0
- data/vendor/faiss/c_api/index_io_c.cpp +42 -0
- data/vendor/faiss/c_api/index_io_c.h +50 -0
- data/vendor/faiss/c_api/macros_impl.h +110 -0
- data/vendor/faiss/clone_index.cpp +147 -0
- data/vendor/faiss/clone_index.h +38 -0
- data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
- data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
- data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
- data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
- data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
- data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
- data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
- data/vendor/faiss/gpu/GpuCloner.h +82 -0
- data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
- data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
- data/vendor/faiss/gpu/GpuDistance.h +52 -0
- data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
- data/vendor/faiss/gpu/GpuIndex.h +148 -0
- data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
- data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
- data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
- data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
- data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
- data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
- data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
- data/vendor/faiss/gpu/GpuResources.cpp +52 -0
- data/vendor/faiss/gpu/GpuResources.h +73 -0
- data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
- data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
- data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
- data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
- data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
- data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
- data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
- data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
- data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
- data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
- data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
- data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
- data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
- data/vendor/faiss/gpu/test/TestUtils.h +93 -0
- data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
- data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
- data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
- data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
- data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
- data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
- data/vendor/faiss/gpu/utils/Timer.h +52 -0
- data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
- data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
- data/vendor/faiss/impl/FaissAssert.h +95 -0
- data/vendor/faiss/impl/FaissException.cpp +66 -0
- data/vendor/faiss/impl/FaissException.h +71 -0
- data/vendor/faiss/impl/HNSW.cpp +818 -0
- data/vendor/faiss/impl/HNSW.h +275 -0
- data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
- data/vendor/faiss/impl/PolysemousTraining.h +158 -0
- data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
- data/vendor/faiss/impl/ProductQuantizer.h +242 -0
- data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
- data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
- data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
- data/vendor/faiss/impl/ThreadedIndex.h +80 -0
- data/vendor/faiss/impl/index_read.cpp +793 -0
- data/vendor/faiss/impl/index_write.cpp +558 -0
- data/vendor/faiss/impl/io.cpp +142 -0
- data/vendor/faiss/impl/io.h +98 -0
- data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
- data/vendor/faiss/impl/lattice_Zn.h +199 -0
- data/vendor/faiss/index_factory.cpp +392 -0
- data/vendor/faiss/index_factory.h +25 -0
- data/vendor/faiss/index_io.h +75 -0
- data/vendor/faiss/misc/test_blas.cpp +84 -0
- data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
- data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
- data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
- data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
- data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
- data/vendor/faiss/tests/test_merge.cpp +258 -0
- data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
- data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
- data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
- data/vendor/faiss/tests/test_params_override.cpp +231 -0
- data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
- data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
- data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
- data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
- data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
- data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
- data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
- data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
- data/vendor/faiss/utils/Heap.cpp +122 -0
- data/vendor/faiss/utils/Heap.h +495 -0
- data/vendor/faiss/utils/WorkerThread.cpp +126 -0
- data/vendor/faiss/utils/WorkerThread.h +61 -0
- data/vendor/faiss/utils/distances.cpp +765 -0
- data/vendor/faiss/utils/distances.h +243 -0
- data/vendor/faiss/utils/distances_simd.cpp +809 -0
- data/vendor/faiss/utils/extra_distances.cpp +336 -0
- data/vendor/faiss/utils/extra_distances.h +54 -0
- data/vendor/faiss/utils/hamming-inl.h +472 -0
- data/vendor/faiss/utils/hamming.cpp +792 -0
- data/vendor/faiss/utils/hamming.h +220 -0
- data/vendor/faiss/utils/random.cpp +192 -0
- data/vendor/faiss/utils/random.h +60 -0
- data/vendor/faiss/utils/utils.cpp +783 -0
- data/vendor/faiss/utils/utils.h +181 -0
- metadata +216 -2
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// -*- c++ -*-
|
|
9
|
+
|
|
10
|
+
#ifndef INDEX_FLAT_H
|
|
11
|
+
#define INDEX_FLAT_H
|
|
12
|
+
|
|
13
|
+
#include <vector>
|
|
14
|
+
|
|
15
|
+
#include <faiss/Index.h>
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
namespace faiss {
|
|
19
|
+
|
|
20
|
+
/** Index that stores the full vectors and performs exhaustive search */
struct IndexFlat: Index {
    /// database vectors, size ntotal * d
    std::vector<float> xb;

    /// @param d       vector dimension
    /// @param metric  metric used for comparisons; defaults to METRIC_L2
    explicit IndexFlat (idx_t d, MetricType metric = METRIC_L2);

    /// overrides Index::add (presumably x holds n vectors of size d that
    /// are appended to xb -- confirm against the implementation)
    void add(idx_t n, const float* x) override;

    /// overrides Index::reset
    void reset() override;

    /// overrides Index::search (presumably distances and labels each
    /// receive n * k entries -- confirm against Index.h)
    void search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels) const override;

    /// overrides Index::range_search
    void range_search(
        idx_t n,
        const float* x,
        float radius,
        RangeSearchResult* result) const override;

    /// overrides Index::reconstruct
    void reconstruct(idx_t key, float* recons) const override;

    /** compute distance with a subset of vectors
     *
     * @param x       query vectors, size n * d
     * @param labels  indices of the vectors that should be compared
     *                for each query vector, size n * k
     * @param distances
     *                corresponding output distances, size n * k
     */
    void compute_distance_subset (
        idx_t n,
        const float *x,
        idx_t k,
        float *distances,
        const idx_t *labels) const;

    /** remove some ids. NB: because of the structure of the
     * index, the semantics of this operation are
     * different from the usual ones: the new ids are shifted */
    size_t remove_ids(const IDSelector& sel) override;

    /// default constructor with an empty body; fields are left to the
    /// default construction of Index and of xb
    IndexFlat () {}

    /// overrides Index::get_distance_computer; returns a raw pointer
    /// (NOTE(review): callers in IndexHNSW.cpp wrap it in ScopeDeleter1,
    /// so the caller presumably owns the result -- confirm)
    DistanceComputer * get_distance_computer() const override;

    /* The standalone codec interface (just memcopies in this case) */
    size_t sa_code_size () const override;

    void sa_encode (idx_t n, const float *x,
                          uint8_t *bytes) const override;

    void sa_decode (idx_t n, const uint8_t *bytes,
                            float *x) const override;

};
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
/// IndexFlat specialization whose sized constructor selects
/// METRIC_INNER_PRODUCT.
struct IndexFlatIP:IndexFlat {
    /// @param d  vector dimension, forwarded to IndexFlat
    explicit IndexFlatIP (idx_t d): IndexFlat (d, METRIC_INNER_PRODUCT) {}
    /// default constructor with an empty body; does not pass a metric
    /// itself (uses IndexFlat's default constructor)
    IndexFlatIP () {}
};
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
/// IndexFlat specialization whose sized constructor selects METRIC_L2.
struct IndexFlatL2:IndexFlat {
    /// @param d  vector dimension, forwarded to IndexFlat
    explicit IndexFlatL2 (idx_t d): IndexFlat (d, METRIC_L2) {}
    /// default constructor with an empty body; does not pass a metric
    /// itself (uses IndexFlat's default constructor)
    IndexFlatL2 () {}
};
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
// same as an IndexFlatL2 but a value is subtracted from each distance
struct IndexFlatL2BaseShift: IndexFlatL2 {
    /// values subtracted from the distances (NOTE(review): presumably one
    /// entry per database vector -- confirm against the implementation)
    std::vector<float> shift;

    /// @param d       vector dimension
    /// @param nshift  presumably the number of floats readable at shift
    ///                -- confirm against the implementation
    /// @param shift   source values for the shift member
    IndexFlatL2BaseShift (idx_t d, size_t nshift, const float *shift);

    /// overrides the IndexFlat search to apply the shift
    void search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels) const override;
};
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
/** Index that queries in a base_index (a fast one) and refines the
 * results with an exact search, hopefully improving the results.
 *
 * NOTE(review): the struct holds a raw Index pointer together with an
 * own_fields flag and declares a destructor, but no copy constructor or
 * copy assignment is declared here. If the destructor deletes base_index
 * when own_fields is set, copying an instance would lead to a double
 * delete -- confirm against the implementation.
 */
struct IndexRefineFlat: Index {

    /// storage for full vectors
    IndexFlat refine_index;

    /// faster index to pre-select the vectors that should be filtered
    Index *base_index;
    bool own_fields;  ///< should the base index be deallocated?

    /// factor between k requested in search and the k requested from
    /// the base_index (should be >= 1)
    float k_factor;

    /// @param base_index  index used for the pre-selection step; whether
    ///                    ownership is taken is governed by own_fields
    explicit IndexRefineFlat (Index *base_index);

    IndexRefineFlat ();

    /// overrides Index::train
    void train(idx_t n, const float* x) override;

    /// overrides Index::add
    void add(idx_t n, const float* x) override;

    /// overrides Index::reset
    void reset() override;

    /// overrides Index::search
    void search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels) const override;

    ~IndexRefineFlat() override;
};
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
/// optimized version for 1D "vectors"
struct IndexFlat1D:IndexFlatL2 {
    bool continuous_update; ///< is the permutation updated continuously?

    std::vector<idx_t> perm; ///< sorted database indices

    /// @param continuous_update  initial value for the member of the same
    ///                           name; defaults to true
    explicit IndexFlat1D (bool continuous_update=true);

    /// if not continuous_update, call this between the last add and
    /// the first search
    void update_permutation ();

    /// overrides the IndexFlat add
    void add(idx_t n, const float* x) override;

    /// overrides the IndexFlat reset
    void reset() override;

    /// Warn: the distances returned are L1 not L2
    void search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels) const override;
};
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
}
|
|
174
|
+
|
|
175
|
+
#endif
|
|
@@ -0,0 +1,1090 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// -*- c++ -*-
|
|
9
|
+
|
|
10
|
+
#include <faiss/IndexHNSW.h>
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
#include <cstdlib>
|
|
14
|
+
#include <cassert>
|
|
15
|
+
#include <cstring>
|
|
16
|
+
#include <cstdio>
|
|
17
|
+
#include <cmath>
|
|
18
|
+
#include <omp.h>
|
|
19
|
+
|
|
20
|
+
#include <unordered_set>
|
|
21
|
+
#include <queue>
|
|
22
|
+
|
|
23
|
+
#include <sys/types.h>
|
|
24
|
+
#include <sys/stat.h>
|
|
25
|
+
#include <unistd.h>
|
|
26
|
+
#include <stdint.h>
|
|
27
|
+
|
|
28
|
+
#ifdef __SSE__
|
|
29
|
+
#include <immintrin.h>
|
|
30
|
+
#endif
|
|
31
|
+
|
|
32
|
+
#include <faiss/utils/distances.h>
|
|
33
|
+
#include <faiss/utils/random.h>
|
|
34
|
+
#include <faiss/utils/Heap.h>
|
|
35
|
+
#include <faiss/impl/FaissAssert.h>
|
|
36
|
+
#include <faiss/IndexFlat.h>
|
|
37
|
+
#include <faiss/IndexIVFPQ.h>
|
|
38
|
+
#include <faiss/Index2Layer.h>
|
|
39
|
+
#include <faiss/impl/AuxIndexStructures.h>
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
extern "C" {
|
|
43
|
+
|
|
44
|
+
/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
|
|
45
|
+
|
|
46
|
+
int sgemm_ (const char *transa, const char *transb, FINTEGER *m, FINTEGER *
|
|
47
|
+
n, FINTEGER *k, const float *alpha, const float *a,
|
|
48
|
+
FINTEGER *lda, const float *b, FINTEGER *
|
|
49
|
+
ldb, float *beta, float *c, FINTEGER *ldc);
|
|
50
|
+
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
namespace faiss {
|
|
54
|
+
|
|
55
|
+
using idx_t = Index::idx_t;
|
|
56
|
+
using MinimaxHeap = HNSW::MinimaxHeap;
|
|
57
|
+
using storage_idx_t = HNSW::storage_idx_t;
|
|
58
|
+
using NodeDistCloser = HNSW::NodeDistCloser;
|
|
59
|
+
using NodeDistFarther = HNSW::NodeDistFarther;
|
|
60
|
+
|
|
61
|
+
HNSWStats hnsw_stats;
|
|
62
|
+
|
|
63
|
+
/**************************************************************
|
|
64
|
+
* add / search blocks of descriptors
|
|
65
|
+
**************************************************************/
|
|
66
|
+
|
|
67
|
+
namespace {
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
/**
 * Link the vectors with ids n0 .. n0 + n - 1 into the HNSW graph.
 *
 * The vectors are grouped by their HNSW level and inserted from the
 * highest level down to level 0; inside a level the insertion order is
 * shuffled with a fixed-seed generator.
 *
 * @param index_hnsw     index whose hnsw graph is extended; its storage
 *                       provides the distance computer
 * @param n0             id of the first vector to insert
 * @param n              number of vectors to insert
 * @param x              the n vectors, d floats each (vector with id
 *                       pt_id is read at x + (pt_id - n0) * d)
 * @param verbose        print progress to stdout
 * @param preset_levels  forwarded to HNSW::prepare_level_tab
 *
 * Throws (FAISS_THROW_MSG) when InterruptCallback reports an interrupt.
 */
void hnsw_add_vertices(IndexHNSW &index_hnsw,
                       size_t n0,
                       size_t n, const float *x,
                       bool verbose,
                       bool preset_levels = false) {
    size_t d = index_hnsw.d;
    HNSW & hnsw = index_hnsw.hnsw;
    size_t ntotal = n0 + n;
    double t0 = getmillisecs();
    if (verbose) {
        // n and n0 are size_t: %zu is the matching conversion
        printf("hnsw_add_vertices: adding %zu elements on top of %zu "
               "(preset_levels=%d)\n",
               n, n0, int(preset_levels));
    }

    if (n == 0) {
        return;
    }

    int max_level = hnsw.prepare_level_tab(n, preset_levels);

    if (verbose) {
        printf(" max_level = %d\n", max_level);
    }

    // one lock per vertex, handed to HNSW::add_with_locks
    std::vector<omp_lock_t> locks(ntotal);
    for (size_t i = 0; i < ntotal; i++)
        omp_init_lock(&locks[i]);

    // Destroys the locks on every exit path, including the exception
    // thrown below when the computation is interrupted.
    struct LockCleanup {
        std::vector<omp_lock_t> &held;
        explicit LockCleanup(std::vector<omp_lock_t> &l): held(l) {}
        ~LockCleanup() {
            for (size_t i = 0; i < held.size(); i++)
                omp_destroy_lock(&held[i]);
        }
    };
    LockCleanup lock_cleanup(locks);

    // add vectors from highest to lowest level
    std::vector<int> hist;
    std::vector<int> order(n);

    { // make buckets with vectors of the same level

        // build histogram
        for (size_t i = 0; i < n; i++) {
            storage_idx_t pt_id = i + n0;
            int pt_level = hnsw.levels[pt_id] - 1;
            while (static_cast<size_t>(pt_level) >= hist.size())
                hist.push_back(0);
            hist[pt_level] ++;
        }

        // accumulate: offsets[l] = number of vectors with level < l
        // (hist is non-empty here because n > 0)
        std::vector<int> offsets(hist.size() + 1, 0);
        for (size_t i = 0; i + 1 < hist.size(); i++) {
            offsets[i + 1] = offsets[i] + hist[i];
        }

        // bucket sort
        for (size_t i = 0; i < n; i++) {
            storage_idx_t pt_id = i + n0;
            int pt_level = hnsw.levels[pt_id] - 1;
            order[offsets[pt_level]++] = pt_id;
        }
    }

    idx_t check_period = InterruptCallback::get_period_hint
        (max_level * index_hnsw.d * hnsw.efConstruction);

    { // perform add
        RandomGenerator rng2(789);

        // order[i0 .. i1) is the bucket of the level being processed
        int i1 = n;

        for (int pt_level = hist.size() - 1; pt_level >= 0; pt_level--) {
            int i0 = i1 - hist[pt_level];

            if (verbose) {
                printf("Adding %d elements at level %d\n",
                       i1 - i0, pt_level);
            }

            // random permutation to get rid of dataset order bias
            for (int j = i0; j < i1; j++)
                std::swap(order[j], order[j + rng2.rand_int(i1 - j)]);

            bool interrupt = false;

#pragma omp parallel if(i1 > i0 + 100)
            {
                // per-thread scratch state
                VisitedTable vt (ntotal);

                DistanceComputer *dis =
                    index_hnsw.storage->get_distance_computer();
                ScopeDeleter1<DistanceComputer> del(dis);
                // only thread 0 reports progress
                int prev_display = verbose && omp_get_thread_num() == 0 ? 0 : -1;
                size_t counter = 0;

#pragma omp for schedule(dynamic)
                for (int i = i0; i < i1; i++) {
                    storage_idx_t pt_id = order[i];
                    dis->set_query (x + (pt_id - n0) * d);

                    // cannot break out of an omp for: skip the
                    // remaining iterations instead
                    if (interrupt) {
                        continue;
                    }

                    hnsw.add_with_locks(*dis, pt_level, pt_id, locks, vt);

                    if (prev_display >= 0 && i - i0 > prev_display + 10000) {
                        prev_display = i - i0;
                        printf(" %d / %d\r", i - i0, i1 - i0);
                        fflush(stdout);
                    }

                    if (counter % check_period == 0) {
                        if (InterruptCallback::is_interrupted ()) {
                            interrupt = true;
                        }
                    }
                    counter++;
                }

            }
            if (interrupt) {
                FAISS_THROW_MSG ("computation interrupted");
            }
            i1 = i0;
        }
        FAISS_ASSERT(i1 == 0);
    }
    if (verbose) {
        printf("Done in %.3f ms\n", getmillisecs() - t0);
    }
    // locks are destroyed by lock_cleanup
}
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
} // namespace
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
/**************************************************************
|
|
210
|
+
* IndexHNSW implementation
|
|
211
|
+
**************************************************************/
|
|
212
|
+
|
|
213
|
+
/// Builds an IndexHNSW of dimension d with METRIC_L2 and no storage:
/// storage and reconstruct_from_neighbors are null and own_fields is
/// false. M is forwarded to the HNSW constructor.
/// train(), add() and search() throw while storage is null.
IndexHNSW::IndexHNSW(int d, int M):
    Index(d, METRIC_L2),
    hnsw(M),
    own_fields(false),
    storage(nullptr),
    reconstruct_from_neighbors(nullptr)
{}
|
|
220
|
+
|
|
221
|
+
/// Builds an IndexHNSW on top of an existing storage index; dimension
/// and metric type are copied from it. storage is dereferenced in the
/// initializer list, so it must not be null. own_fields starts as false,
/// so the destructor does not delete storage unless the flag is changed
/// afterwards. M is forwarded to the HNSW constructor.
IndexHNSW::IndexHNSW(Index *storage, int M):
    Index(storage->d, storage->metric_type),
    hnsw(M),
    own_fields(false),
    storage(storage),
    reconstruct_from_neighbors(nullptr)
{}
|
|
228
|
+
|
|
229
|
+
/// Releases the storage index, but only when this object owns it.
IndexHNSW::~IndexHNSW() {
    if (!own_fields) {
        return; // storage belongs to someone else
    }
    delete storage;
}
|
|
234
|
+
|
|
235
|
+
/// Trains the underlying storage index on the n vectors in x and marks
/// this index as trained.
/// Throws if no storage index is attached.
void IndexHNSW::train(idx_t n, const float* x)
{
    FAISS_THROW_IF_NOT_MSG(storage,
       "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
    // hnsw structure does not require training
    storage->train (n, x);
    is_trained = true;
}
|
|
243
|
+
|
|
244
|
+
/**
 * Search the HNSW graph for the k nearest neighbors of each query.
 *
 * @param n          number of queries
 * @param x          queries; query i is read at x + i * d
 * @param k          number of results per query
 * @param distances  output; row i is written at distances + i * k
 * @param labels     output; row i is written at labels + i * k
 *
 * Queries are processed in slices of check_period so that
 * InterruptCallback::check() runs between slices. Throws if no storage
 * index is attached.
 */
void IndexHNSW::search (idx_t n, const float *x, idx_t k,
                        float *distances, idx_t *labels) const

{
    FAISS_THROW_IF_NOT_MSG(storage,
       "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
    size_t nreorder = 0;

    idx_t check_period = InterruptCallback::get_period_hint (
          hnsw.max_level * d * hnsw.efSearch);

    for (idx_t i0 = 0; i0 < n; i0 += check_period) {
        idx_t i1 = std::min(i0 + check_period, n);

#pragma omp parallel reduction(+ : nreorder)
        {
            // per-thread visited table and distance computer
            VisitedTable vt (ntotal);
            DistanceComputer *dis = storage->get_distance_computer();
            ScopeDeleter1<DistanceComputer> del(dis);

#pragma omp for
            for(idx_t i = i0; i < i1; i++) {
                idx_t * idxi = labels + i * k;
                float * simi = distances + i * k;
                dis->set_query(x + i * d);

                maxheap_heapify (k, simi, idxi);
                hnsw.search(*dis, k, idxi, simi, vt);

                maxheap_reorder (k, simi, idxi);

                // optional re-ranking of the first k_reorder results
                if (reconstruct_from_neighbors &&
                    reconstruct_from_neighbors->k_reorder != 0) {
                    int k_reorder = reconstruct_from_neighbors->k_reorder;
                    // -1 or a value above k means "all k results"
                    if (k_reorder == -1 || k_reorder > k) k_reorder = k;

                    nreorder += reconstruct_from_neighbors->compute_distances(
                           k_reorder, idxi, x + i * d, simi);

                    // sort top k_reorder
                    maxheap_heapify (k_reorder, simi, idxi, simi, idxi, k_reorder);
                    maxheap_reorder (k_reorder, simi, idxi);
                }

            }

        }
        InterruptCallback::check ();
    }
    hnsw_stats.nreorder += nreorder;
}
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
/// Adds the n vectors in x to the storage index, then links them into
/// the HNSW graph. Levels are treated as preset when hnsw.levels already
/// has one entry per stored vector after the storage add.
/// Throws if no storage index is attached or the index is not trained.
void IndexHNSW::add(idx_t n, const float *x)
{
    FAISS_THROW_IF_NOT_MSG(storage,
       "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
    FAISS_THROW_IF_NOT(is_trained);
    // keep the full idx_t width; hnsw_add_vertices takes a size_t
    const idx_t n0 = ntotal;
    storage->add(n, x);
    ntotal = storage->ntotal;

    const bool preset_levels =
        hnsw.levels.size() == static_cast<size_t>(ntotal);

    hnsw_add_vertices (*this, n0, n, x, verbose, preset_levels);
}
|
|
309
|
+
|
|
310
|
+
void IndexHNSW::reset()
|
|
311
|
+
{
|
|
312
|
+
hnsw.reset();
|
|
313
|
+
storage->reset();
|
|
314
|
+
ntotal = 0;
|
|
315
|
+
}
|
|
316
|
+
|
|
317
|
+
/// Reconstructs vector `key` into `recons` by delegating to the storage
/// index. Throws if no storage index is attached (same guard as train,
/// add and search) instead of dereferencing a null pointer.
void IndexHNSW::reconstruct (idx_t key, float* recons) const
{
    FAISS_THROW_IF_NOT_MSG(storage,
       "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
    storage->reconstruct(key, recons);
}
|
|
321
|
+
|
|
322
|
+
void IndexHNSW::shrink_level_0_neighbors(int new_size)
|
|
323
|
+
{
|
|
324
|
+
#pragma omp parallel
|
|
325
|
+
{
|
|
326
|
+
DistanceComputer *dis = storage->get_distance_computer();
|
|
327
|
+
ScopeDeleter1<DistanceComputer> del(dis);
|
|
328
|
+
|
|
329
|
+
#pragma omp for
|
|
330
|
+
for (idx_t i = 0; i < ntotal; i++) {
|
|
331
|
+
|
|
332
|
+
size_t begin, end;
|
|
333
|
+
hnsw.neighbor_range(i, 0, &begin, &end);
|
|
334
|
+
|
|
335
|
+
std::priority_queue<NodeDistFarther> initial_list;
|
|
336
|
+
|
|
337
|
+
for (size_t j = begin; j < end; j++) {
|
|
338
|
+
int v1 = hnsw.neighbors[j];
|
|
339
|
+
if (v1 < 0) break;
|
|
340
|
+
initial_list.emplace(dis->symmetric_dis(i, v1), v1);
|
|
341
|
+
|
|
342
|
+
// initial_list.emplace(qdis(v1), v1);
|
|
343
|
+
}
|
|
344
|
+
|
|
345
|
+
std::vector<NodeDistFarther> shrunk_list;
|
|
346
|
+
HNSW::shrink_neighbor_list(*dis, initial_list,
|
|
347
|
+
shrunk_list, new_size);
|
|
348
|
+
|
|
349
|
+
for (size_t j = begin; j < end; j++) {
|
|
350
|
+
if (j - begin < shrunk_list.size())
|
|
351
|
+
hnsw.neighbors[j] = shrunk_list[j - begin].id;
|
|
352
|
+
else
|
|
353
|
+
hnsw.neighbors[j] = -1;
|
|
354
|
+
}
|
|
355
|
+
}
|
|
356
|
+
}
|
|
357
|
+
|
|
358
|
+
}
|
|
359
|
+
|
|
360
|
+
/**
 * Search restricted to level 0 of the graph, starting from entry points
 * supplied by the caller.
 *
 * @param n            number of queries
 * @param x            queries; query i is read at x + i * d
 * @param k            number of results per query
 * @param nearest      entry point ids, nprobe per query; a negative id
 *                     ends the list for that query
 * @param nearest_d    distances matching nearest
 * @param distances    output; row i is written at distances + i * k
 * @param labels       output; row i is written at labels + i * k
 * @param nprobe       number of entry points per query
 * @param search_type  1: one search per not-yet-visited entry point,
 *                     2: a single search seeded with all entry points;
 *                     any other value performs no graph search
 */
void IndexHNSW::search_level_0(
    idx_t n, const float *x, idx_t k,
    const storage_idx_t *nearest, const float *nearest_d,
    float *distances, idx_t *labels, int nprobe,
    int search_type) const
{

    const storage_idx_t table_size = hnsw.levels.size();
#pragma omp parallel
    {
        DistanceComputer *qdis = storage->get_distance_computer();
        ScopeDeleter1<DistanceComputer> del(qdis);

        VisitedTable vt (table_size);

#pragma omp for
        for (idx_t q = 0; q < n; q++) {
            idx_t *out_ids = labels + q * k;
            float *out_dis = distances + q * k;

            qdis->set_query(x + q * d);
            maxheap_heapify (k, out_dis, out_ids);

            switch (search_type) {
            case 1: {
                int nres = 0;

                for (int p = 0; p < nprobe; p++) {
                    const storage_idx_t entry = nearest[q * nprobe + p];

                    if (entry < 0) {
                        break;
                    }
                    if (vt.get(entry)) {
                        continue;
                    }

                    // fresh candidate heap for every entry point
                    const int heap_size = std::max(hnsw.efSearch, int(k));
                    MinimaxHeap candidates(heap_size);
                    candidates.push(entry, nearest_d[q * nprobe + p]);

                    nres = hnsw.search_from_candidates(
                        *qdis, k, out_ids, out_dis,
                        candidates, vt, 0, nres);
                }
                break;
            }
            case 2: {
                // the heap must hold every entry point
                const int heap_size =
                    std::max(std::max(hnsw.efSearch, int(k)), nprobe);
                MinimaxHeap candidates(heap_size);

                for (int p = 0; p < nprobe; p++) {
                    const storage_idx_t entry = nearest[q * nprobe + p];

                    if (entry < 0) {
                        break;
                    }
                    candidates.push(entry, nearest_d[q * nprobe + p]);
                }
                hnsw.search_from_candidates(
                    *qdis, k, out_ids, out_dis,
                    candidates, vt, 0);
                break;
            }
            default:
                break;
            }
            vt.advance();

            maxheap_reorder (k, out_dis, out_ids);
        }
    }

}
|
|
431
|
+
|
|
432
|
+
void IndexHNSW::init_level_0_from_knngraph(
|
|
433
|
+
int k, const float *D, const idx_t *I)
|
|
434
|
+
{
|
|
435
|
+
int dest_size = hnsw.nb_neighbors (0);
|
|
436
|
+
|
|
437
|
+
#pragma omp parallel for
|
|
438
|
+
for (idx_t i = 0; i < ntotal; i++) {
|
|
439
|
+
DistanceComputer *qdis = storage->get_distance_computer();
|
|
440
|
+
float vec[d];
|
|
441
|
+
storage->reconstruct(i, vec);
|
|
442
|
+
qdis->set_query(vec);
|
|
443
|
+
|
|
444
|
+
std::priority_queue<NodeDistFarther> initial_list;
|
|
445
|
+
|
|
446
|
+
for (size_t j = 0; j < k; j++) {
|
|
447
|
+
int v1 = I[i * k + j];
|
|
448
|
+
if (v1 == i) continue;
|
|
449
|
+
if (v1 < 0) break;
|
|
450
|
+
initial_list.emplace(D[i * k + j], v1);
|
|
451
|
+
}
|
|
452
|
+
|
|
453
|
+
std::vector<NodeDistFarther> shrunk_list;
|
|
454
|
+
HNSW::shrink_neighbor_list(*qdis, initial_list, shrunk_list, dest_size);
|
|
455
|
+
|
|
456
|
+
size_t begin, end;
|
|
457
|
+
hnsw.neighbor_range(i, 0, &begin, &end);
|
|
458
|
+
|
|
459
|
+
for (size_t j = begin; j < end; j++) {
|
|
460
|
+
if (j - begin < shrunk_list.size())
|
|
461
|
+
hnsw.neighbors[j] = shrunk_list[j - begin].id;
|
|
462
|
+
else
|
|
463
|
+
hnsw.neighbors[j] = -1;
|
|
464
|
+
}
|
|
465
|
+
}
|
|
466
|
+
}
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
void IndexHNSW::init_level_0_from_entry_points(
|
|
471
|
+
int n, const storage_idx_t *points,
|
|
472
|
+
const storage_idx_t *nearests)
|
|
473
|
+
{
|
|
474
|
+
|
|
475
|
+
std::vector<omp_lock_t> locks(ntotal);
|
|
476
|
+
for(int i = 0; i < ntotal; i++)
|
|
477
|
+
omp_init_lock(&locks[i]);
|
|
478
|
+
|
|
479
|
+
#pragma omp parallel
|
|
480
|
+
{
|
|
481
|
+
VisitedTable vt (ntotal);
|
|
482
|
+
|
|
483
|
+
DistanceComputer *dis = storage->get_distance_computer();
|
|
484
|
+
ScopeDeleter1<DistanceComputer> del(dis);
|
|
485
|
+
float vec[storage->d];
|
|
486
|
+
|
|
487
|
+
#pragma omp for schedule(dynamic)
|
|
488
|
+
for (int i = 0; i < n; i++) {
|
|
489
|
+
storage_idx_t pt_id = points[i];
|
|
490
|
+
storage_idx_t nearest = nearests[i];
|
|
491
|
+
storage->reconstruct (pt_id, vec);
|
|
492
|
+
dis->set_query (vec);
|
|
493
|
+
|
|
494
|
+
hnsw.add_links_starting_from(*dis, pt_id,
|
|
495
|
+
nearest, (*dis)(nearest),
|
|
496
|
+
0, locks.data(), vt);
|
|
497
|
+
|
|
498
|
+
if (verbose && i % 10000 == 0) {
|
|
499
|
+
printf(" %d / %d\r", i, n);
|
|
500
|
+
fflush(stdout);
|
|
501
|
+
}
|
|
502
|
+
}
|
|
503
|
+
}
|
|
504
|
+
if (verbose) {
|
|
505
|
+
printf("\n");
|
|
506
|
+
}
|
|
507
|
+
|
|
508
|
+
for(int i = 0; i < ntotal; i++)
|
|
509
|
+
omp_destroy_lock(&locks[i]);
|
|
510
|
+
}
|
|
511
|
+
|
|
512
|
+
void IndexHNSW::reorder_links()
|
|
513
|
+
{
|
|
514
|
+
int M = hnsw.nb_neighbors(0);
|
|
515
|
+
|
|
516
|
+
#pragma omp parallel
|
|
517
|
+
{
|
|
518
|
+
std::vector<float> distances (M);
|
|
519
|
+
std::vector<size_t> order (M);
|
|
520
|
+
std::vector<storage_idx_t> tmp (M);
|
|
521
|
+
DistanceComputer *dis = storage->get_distance_computer();
|
|
522
|
+
ScopeDeleter1<DistanceComputer> del(dis);
|
|
523
|
+
|
|
524
|
+
#pragma omp for
|
|
525
|
+
for(storage_idx_t i = 0; i < ntotal; i++) {
|
|
526
|
+
|
|
527
|
+
size_t begin, end;
|
|
528
|
+
hnsw.neighbor_range(i, 0, &begin, &end);
|
|
529
|
+
|
|
530
|
+
for (size_t j = begin; j < end; j++) {
|
|
531
|
+
storage_idx_t nj = hnsw.neighbors[j];
|
|
532
|
+
if (nj < 0) {
|
|
533
|
+
end = j;
|
|
534
|
+
break;
|
|
535
|
+
}
|
|
536
|
+
distances[j - begin] = dis->symmetric_dis(i, nj);
|
|
537
|
+
tmp [j - begin] = nj;
|
|
538
|
+
}
|
|
539
|
+
|
|
540
|
+
fvec_argsort (end - begin, distances.data(), order.data());
|
|
541
|
+
for (size_t j = begin; j < end; j++) {
|
|
542
|
+
hnsw.neighbors[j] = tmp[order[j - begin]];
|
|
543
|
+
}
|
|
544
|
+
}
|
|
545
|
+
|
|
546
|
+
}
|
|
547
|
+
}
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
void IndexHNSW::link_singletons()
|
|
551
|
+
{
|
|
552
|
+
printf("search for singletons\n");
|
|
553
|
+
|
|
554
|
+
std::vector<bool> seen(ntotal);
|
|
555
|
+
|
|
556
|
+
for (size_t i = 0; i < ntotal; i++) {
|
|
557
|
+
size_t begin, end;
|
|
558
|
+
hnsw.neighbor_range(i, 0, &begin, &end);
|
|
559
|
+
for (size_t j = begin; j < end; j++) {
|
|
560
|
+
storage_idx_t ni = hnsw.neighbors[j];
|
|
561
|
+
if (ni >= 0) seen[ni] = true;
|
|
562
|
+
}
|
|
563
|
+
}
|
|
564
|
+
|
|
565
|
+
int n_sing = 0, n_sing_l1 = 0;
|
|
566
|
+
std::vector<storage_idx_t> singletons;
|
|
567
|
+
for (storage_idx_t i = 0; i < ntotal; i++) {
|
|
568
|
+
if (!seen[i]) {
|
|
569
|
+
singletons.push_back(i);
|
|
570
|
+
n_sing++;
|
|
571
|
+
if (hnsw.levels[i] > 1)
|
|
572
|
+
n_sing_l1++;
|
|
573
|
+
}
|
|
574
|
+
}
|
|
575
|
+
|
|
576
|
+
printf(" Found %d / %ld singletons (%d appear in a level above)\n",
|
|
577
|
+
n_sing, ntotal, n_sing_l1);
|
|
578
|
+
|
|
579
|
+
std::vector<float>recons(singletons.size() * d);
|
|
580
|
+
for (int i = 0; i < singletons.size(); i++) {
|
|
581
|
+
|
|
582
|
+
FAISS_ASSERT(!"not implemented");
|
|
583
|
+
|
|
584
|
+
}
|
|
585
|
+
|
|
586
|
+
|
|
587
|
+
}
|
|
588
|
+
|
|
589
|
+
|
|
590
|
+
/**************************************************************
|
|
591
|
+
* ReconstructFromNeighbors implementation
|
|
592
|
+
**************************************************************/
|
|
593
|
+
|
|
594
|
+
|
|
595
|
+
/**
 * Set up the sizes used by the reconstruction.
 *
 * @param index  HNSW index whose level-0 links and storage are used
 * @param k      number of codebook entries per sub-quantizer (at most 256)
 * @param nsq    number of sub-vectors the dimension is split into;
 *               must divide the index dimension
 */
ReconstructFromNeighbors::ReconstructFromNeighbors(
        const IndexHNSW & index, size_t k, size_t nsq):
    index(index), k(k), nsq(nsq)
{
    M = index.hnsw.nb_neighbors(0);

    FAISS_ASSERT(k <= 256);
    // with a single codebook entry there is nothing to store per vector
    if (k == 1) {
        code_size = 0;
    } else {
        code_size = nsq;
    }

    ntotal = 0;
    d = index.d;

    FAISS_ASSERT(d % nsq == 0);
    dsub = d / nsq;

    k_reorder = -1;
}
|
|
607
|
+
|
|
608
|
+
/**
 * Reconstruct vector i as a weighted sum of its own stored vector and the
 * stored vectors of its level-0 neighbors.
 *
 * The weights come from `codebook`, organised in rows of M + 1 floats:
 * entry 0 weights the vector itself, entry j + 1 weights neighbor slot j.
 * Empty neighbor slots (negative id) are replaced by the vector itself.
 * When nsq > 1 each of the nsq sub-vectors of size dsub uses its own row,
 * selected by the corresponding byte of the code of i.
 *
 * @param i    id of the vector to reconstruct
 * @param x    output, d floats
 * @param tmp  scratch buffer, written through storage->reconstruct
 *             (presumably d floats -- confirm against the storage type)
 */
void ReconstructFromNeighbors::reconstruct(storage_idx_t i, float *x, float *tmp) const
{

    const HNSW & hnsw = index.hnsw;
    size_t begin, end;
    hnsw.neighbor_range(i, 0, &begin, &end);

    if (k == 1 || nsq == 1) {
        // a single weight row covers the whole vector
        const float * beta;
        if (k == 1) {
            beta = codebook.data();
        } else {
            int idx = codes[i];
            beta = codebook.data() + idx * (M + 1);
        }

        float w0 = beta[0]; // weight of image itself
        index.storage->reconstruct(i, tmp);

        for (int l = 0; l < d; l++)
            x[l] = w0 * tmp[l];

        for (size_t j = begin; j < end; j++) {

            storage_idx_t ji = hnsw.neighbors[j];
            if (ji < 0) ji = i;
            float w = beta[j - begin + 1];
            index.storage->reconstruct(ji, tmp);
            for (int l = 0; l < d; l++)
                x[l] += w * tmp[l];
        }
    } else if (nsq == 2) {
        // two sub-vectors: [0, dsub) uses beta0, [dsub, d) uses beta1
        int idx0 = codes[2 * i];
        int idx1 = codes[2 * i + 1];

        const float *beta0 = codebook.data() + idx0 * (M + 1);
        const float *beta1 = codebook.data() + (idx1 + k) * (M + 1);

        index.storage->reconstruct(i, tmp);

        float w0;

        w0 = beta0[0];
        for (int l = 0; l < dsub; l++)
            x[l] = w0 * tmp[l];

        w0 = beta1[0];
        for (int l = dsub; l < d; l++)
            x[l] = w0 * tmp[l];

        for (size_t j = begin; j < end; j++) {
            storage_idx_t ji = hnsw.neighbors[j];
            if (ji < 0) ji = i;
            index.storage->reconstruct(ji, tmp);
            float w;
            w = beta0[j - begin + 1];
            for (int l = 0; l < dsub; l++)
                x[l] += w * tmp[l];

            w = beta1[j - begin + 1];
            for (int l = dsub; l < d; l++)
                x[l] += w * tmp[l];
        }
    } else {
        // general case: one running weight pointer per sub-vector.
        // std::vector replaces a variable-length array, which is a
        // compiler extension and not part of ISO C++.
        std::vector<const float *> betas(nsq);
        {
            const float *b = codebook.data();
            const uint8_t *c = &codes[i * code_size];
            for (int sq = 0; sq < nsq; sq++) {
                betas[sq] = b + (*c++) * (M + 1);
                b += (M + 1) * k;
            }
        }

        index.storage->reconstruct(i, tmp);
        {
            int d0 = 0;
            for (int sq = 0; sq < nsq; sq++) {
                // first weight of each row applies to the vector itself
                float w = *(betas[sq]++);
                int d1 = d0 + dsub;
                for (int l = d0; l < d1; l++) {
                    x[l] = w * tmp[l];
                }
                d0 = d1;
            }
        }

        for (size_t j = begin; j < end; j++) {
            storage_idx_t ji = hnsw.neighbors[j];
            if (ji < 0) ji = i;

            index.storage->reconstruct(ji, tmp);
            int d0 = 0;
            for (int sq = 0; sq < nsq; sq++) {
                // each pointer advances by one weight per neighbor slot
                float w = *(betas[sq]++);
                int d1 = d0 + dsub;
                for (int l = d0; l < d1; l++) {
                    x[l] += w * tmp[l];
                }
                d0 = d1;
            }
        }
    }
}
|
|
713
|
+
|
|
714
|
+
/**
 * Reconstruct the ni consecutive vectors n0 .. n0 + ni - 1.
 *
 * @param n0  id of the first vector
 * @param ni  number of vectors to reconstruct
 * @param x   output, ni * index.d floats; vector n0 + i is written at
 *            offset i * index.d
 */
void ReconstructFromNeighbors::reconstruct_n(storage_idx_t n0,
                                             storage_idx_t ni,
                                             float *x) const
{
#pragma omp parallel
    {
        // per-thread scratch buffer passed to reconstruct()
        std::vector<float> tmp(index.d);
#pragma omp for
        for (storage_idx_t i = 0; i < ni; i++) {
            // widen before multiplying so the offset cannot overflow a
            // narrow integer type for large outputs (i is never negative)
            const size_t offset = size_t(i) * size_t(index.d);
            reconstruct(n0 + i, x + offset, tmp.data());
        }
    }
}
|
|
727
|
+
|
|
728
|
+
size_t ReconstructFromNeighbors::compute_distances(
|
|
729
|
+
size_t n, const idx_t *shortlist,
|
|
730
|
+
const float *query, float *distances) const
|
|
731
|
+
{
|
|
732
|
+
std::vector<float> tmp(2 * index.d);
|
|
733
|
+
size_t ncomp = 0;
|
|
734
|
+
for (int i = 0; i < n; i++) {
|
|
735
|
+
if (shortlist[i] < 0) break;
|
|
736
|
+
reconstruct(shortlist[i], tmp.data(), tmp.data() + index.d);
|
|
737
|
+
distances[i] = fvec_L2sqr(query, tmp.data(), index.d);
|
|
738
|
+
ncomp++;
|
|
739
|
+
}
|
|
740
|
+
return ncomp;
|
|
741
|
+
}
|
|
742
|
+
|
|
743
|
+
/**
 * Collect the stored vector of i followed by the stored vectors of its
 * level-0 neighbor slots into one table.
 *
 * @param i     id of the vector
 * @param tmp1  output table of rows of index.d floats: row 0 is vector i,
 *              row s + 1 is neighbor slot s (an empty slot, i.e. a
 *              negative id, is filled with vector i again)
 */
void ReconstructFromNeighbors::get_neighbor_table(storage_idx_t i, float *tmp1) const
{
    const HNSW & hnsw = index.hnsw;
    const size_t d = index.d;

    size_t begin, end;
    hnsw.neighbor_range(i, 0, &begin, &end);

    // row 0: the vector itself
    index.storage->reconstruct(i, tmp1);

    // rows 1..: one per neighbor slot
    float *row = tmp1 + d;
    for (size_t pos = begin; pos < end; pos++, row += d) {
        const storage_idx_t nb = hnsw.neighbors[pos];
        index.storage->reconstruct(nb < 0 ? i : nb, row);
    }
}
|
|
759
|
+
|
|
760
|
+
|
|
761
|
+
/// called by add_codes
|
|
762
|
+
void ReconstructFromNeighbors::estimate_code(
|
|
763
|
+
const float *x, storage_idx_t i, uint8_t *code) const
|
|
764
|
+
{
|
|
765
|
+
|
|
766
|
+
// fill in tmp table with the neighbor values
|
|
767
|
+
float *tmp1 = new float[d * (M + 1) + (d * k)];
|
|
768
|
+
float *tmp2 = tmp1 + d * (M + 1);
|
|
769
|
+
ScopeDeleter<float> del(tmp1);
|
|
770
|
+
|
|
771
|
+
// collect coordinates of base
|
|
772
|
+
get_neighbor_table (i, tmp1);
|
|
773
|
+
|
|
774
|
+
for (size_t sq = 0; sq < nsq; sq++) {
|
|
775
|
+
int d0 = sq * dsub;
|
|
776
|
+
|
|
777
|
+
{
|
|
778
|
+
FINTEGER ki = k, di = d, m1 = M + 1;
|
|
779
|
+
FINTEGER dsubi = dsub;
|
|
780
|
+
float zero = 0, one = 1;
|
|
781
|
+
|
|
782
|
+
sgemm_ ("N", "N", &dsubi, &ki, &m1, &one,
|
|
783
|
+
tmp1 + d0, &di,
|
|
784
|
+
codebook.data() + sq * (m1 * k), &m1,
|
|
785
|
+
&zero, tmp2, &dsubi);
|
|
786
|
+
}
|
|
787
|
+
|
|
788
|
+
float min = HUGE_VAL;
|
|
789
|
+
int argmin = -1;
|
|
790
|
+
for (size_t j = 0; j < k; j++) {
|
|
791
|
+
float dis = fvec_L2sqr(x + d0, tmp2 + j * dsub, dsub);
|
|
792
|
+
if (dis < min) {
|
|
793
|
+
min = dis;
|
|
794
|
+
argmin = j;
|
|
795
|
+
}
|
|
796
|
+
}
|
|
797
|
+
code[sq] = argmin;
|
|
798
|
+
}
|
|
799
|
+
|
|
800
|
+
}
|
|
801
|
+
|
|
802
|
+
void ReconstructFromNeighbors::add_codes(size_t n, const float *x)
|
|
803
|
+
{
|
|
804
|
+
if (k == 1) { // nothing to encode
|
|
805
|
+
ntotal += n;
|
|
806
|
+
return;
|
|
807
|
+
}
|
|
808
|
+
codes.resize(codes.size() + code_size * n);
|
|
809
|
+
#pragma omp parallel for
|
|
810
|
+
for (int i = 0; i < n; i++) {
|
|
811
|
+
estimate_code(x + i * index.d, ntotal + i,
|
|
812
|
+
codes.data() + (ntotal + i) * code_size);
|
|
813
|
+
}
|
|
814
|
+
ntotal += n;
|
|
815
|
+
FAISS_ASSERT (codes.size() == ntotal * code_size);
|
|
816
|
+
}
|
|
817
|
+
|
|
818
|
+
|
|
819
|
+
/**************************************************************
|
|
820
|
+
* IndexHNSWFlat implementation
|
|
821
|
+
**************************************************************/
|
|
822
|
+
|
|
823
|
+
|
|
824
|
+
/// Default constructor: only marks the index as trained.
IndexHNSWFlat::IndexHNSWFlat() {
    is_trained = true;
}
|
|
828
|
+
|
|
829
|
+
/// Build an HNSW index on top of a newly allocated IndexFlatL2.
///
/// @param d  dimension passed to IndexFlatL2
/// @param M  passed through to the IndexHNSW constructor
IndexHNSWFlat::IndexHNSWFlat(int d, int M)
    : IndexHNSW(new IndexFlatL2(d), M) {
    is_trained = true;
    // the storage index was allocated above
    own_fields = true;
}
|
|
835
|
+
|
|
836
|
+
|
|
837
|
+
/**************************************************************
|
|
838
|
+
* IndexHNSWPQ implementation
|
|
839
|
+
**************************************************************/
|
|
840
|
+
|
|
841
|
+
|
|
842
|
+
/// Default constructor: no additional initialization.
IndexHNSWPQ::IndexHNSWPQ()
{
}
|
|
843
|
+
|
|
844
|
+
/// Build an HNSW index on top of a newly allocated IndexPQ.
///
/// @param d     dimension passed to IndexPQ
/// @param pq_m  second argument of IndexPQ (the third one is fixed to 8)
/// @param M     passed through to the IndexHNSW constructor
IndexHNSWPQ::IndexHNSWPQ(int d, int pq_m, int M)
    : IndexHNSW(new IndexPQ(d, pq_m, 8), M) {
    // starts untrained: train() must be called first
    is_trained = false;
    // the storage index was allocated above
    own_fields = true;
}
|
|
850
|
+
|
|
851
|
+
/**
 * Train the index, then precompute the SDC table of the product quantizer
 * held by the storage index.
 *
 * @param n  number of training vectors
 * @param x  training vectors
 * @throws an exception (via FAISS_THROW_IF_NOT) if the storage index is
 *         not an IndexPQ
 */
void IndexHNSWPQ::train(idx_t n, const float* x)
{
    IndexHNSW::train (n, x);

    // dynamic_cast yields a null pointer when storage is missing or of
    // another type; check it instead of dereferencing blindly
    IndexPQ *storage_pq = dynamic_cast<IndexPQ*> (storage);
    FAISS_THROW_IF_NOT (storage_pq);
    storage_pq->pq.compute_sdc_table();
}
|
|
856
|
+
|
|
857
|
+
|
|
858
|
+
/**************************************************************
|
|
859
|
+
* IndexHNSWSQ implementation
|
|
860
|
+
**************************************************************/
|
|
861
|
+
|
|
862
|
+
|
|
863
|
+
/// Build an HNSW index on top of a newly allocated IndexScalarQuantizer.
///
/// @param d      dimension passed to IndexScalarQuantizer
/// @param qtype  quantizer type passed to IndexScalarQuantizer
/// @param M      passed through to the IndexHNSW constructor
IndexHNSWSQ::IndexHNSWSQ(int d, ScalarQuantizer::QuantizerType qtype, int M)
    : IndexHNSW (new IndexScalarQuantizer (d, qtype), M) {
    // the storage index was allocated above
    own_fields = true;
    // starts untrained
    is_trained = false;
}
|
|
869
|
+
|
|
870
|
+
/// Default constructor: no additional initialization.
IndexHNSWSQ::IndexHNSWSQ()
{
}
|
|
871
|
+
|
|
872
|
+
|
|
873
|
+
/**************************************************************
|
|
874
|
+
* IndexHNSW2Level implementation
|
|
875
|
+
**************************************************************/
|
|
876
|
+
|
|
877
|
+
|
|
878
|
+
/// Build an HNSW index on top of a newly allocated Index2Layer.
///
/// @param quantizer  passed to Index2Layer
/// @param nlist      passed to Index2Layer
/// @param m_pq       passed to Index2Layer
/// @param M          passed through to the IndexHNSW constructor
IndexHNSW2Level::IndexHNSW2Level(Index *quantizer, size_t nlist, int m_pq, int M)
    : IndexHNSW (new Index2Layer (quantizer, nlist, m_pq), M) {
    // starts untrained
    is_trained = false;
    // the storage index was allocated above
    own_fields = true;
}
|
|
884
|
+
|
|
885
|
+
/// Default constructor: no additional initialization.
IndexHNSW2Level::IndexHNSW2Level()
{
}
|
|
886
|
+
|
|
887
|
+
|
|
888
|
+
namespace {
|
|
889
|
+
|
|
890
|
+
|
|
891
|
+
// same as search_from_candidates but uses v
|
|
892
|
+
// visno -> is in result list
|
|
893
|
+
// visno + 1 -> in result list + in candidates
|
|
894
|
+
int search_from_candidates_2(const HNSW & hnsw,
|
|
895
|
+
DistanceComputer & qdis, int k,
|
|
896
|
+
idx_t *I, float * D,
|
|
897
|
+
MinimaxHeap &candidates,
|
|
898
|
+
VisitedTable &vt,
|
|
899
|
+
int level, int nres_in = 0)
|
|
900
|
+
{
|
|
901
|
+
int nres = nres_in;
|
|
902
|
+
int ndis = 0;
|
|
903
|
+
for (int i = 0; i < candidates.size(); i++) {
|
|
904
|
+
idx_t v1 = candidates.ids[i];
|
|
905
|
+
FAISS_ASSERT(v1 >= 0);
|
|
906
|
+
vt.visited[v1] = vt.visno + 1;
|
|
907
|
+
}
|
|
908
|
+
|
|
909
|
+
int nstep = 0;
|
|
910
|
+
|
|
911
|
+
while (candidates.size() > 0) {
|
|
912
|
+
float d0 = 0;
|
|
913
|
+
int v0 = candidates.pop_min(&d0);
|
|
914
|
+
|
|
915
|
+
size_t begin, end;
|
|
916
|
+
hnsw.neighbor_range(v0, level, &begin, &end);
|
|
917
|
+
|
|
918
|
+
for (size_t j = begin; j < end; j++) {
|
|
919
|
+
int v1 = hnsw.neighbors[j];
|
|
920
|
+
if (v1 < 0) break;
|
|
921
|
+
if (vt.visited[v1] == vt.visno + 1) {
|
|
922
|
+
// nothing to do
|
|
923
|
+
} else {
|
|
924
|
+
ndis++;
|
|
925
|
+
float d = qdis(v1);
|
|
926
|
+
candidates.push(v1, d);
|
|
927
|
+
|
|
928
|
+
// never seen before --> add to heap
|
|
929
|
+
if (vt.visited[v1] < vt.visno) {
|
|
930
|
+
if (nres < k) {
|
|
931
|
+
faiss::maxheap_push (++nres, D, I, d, v1);
|
|
932
|
+
} else if (d < D[0]) {
|
|
933
|
+
faiss::maxheap_pop (nres--, D, I);
|
|
934
|
+
faiss::maxheap_push (++nres, D, I, d, v1);
|
|
935
|
+
}
|
|
936
|
+
}
|
|
937
|
+
vt.visited[v1] = vt.visno + 1;
|
|
938
|
+
}
|
|
939
|
+
}
|
|
940
|
+
|
|
941
|
+
nstep++;
|
|
942
|
+
if (nstep > hnsw.efSearch) {
|
|
943
|
+
break;
|
|
944
|
+
}
|
|
945
|
+
}
|
|
946
|
+
|
|
947
|
+
if (level == 0) {
|
|
948
|
+
#pragma omp critical
|
|
949
|
+
{
|
|
950
|
+
hnsw_stats.n1 ++;
|
|
951
|
+
if (candidates.size() == 0)
|
|
952
|
+
hnsw_stats.n2 ++;
|
|
953
|
+
}
|
|
954
|
+
}
|
|
955
|
+
|
|
956
|
+
|
|
957
|
+
return nres;
|
|
958
|
+
}
|
|
959
|
+
|
|
960
|
+
|
|
961
|
+
} // namespace
|
|
962
|
+
|
|
963
|
+
/**
 * Search the k nearest neighbors of n query vectors.
 *
 * If the storage is an Index2Layer, this is the plain IndexHNSW search.
 * Otherwise the storage must be an IndexIVFPQ ("mixed" search): an IVFPQ
 * search with the storage's nprobe fills the result lists, then the best
 * results seed a level-0 HNSW exploration that refines them.
 *
 * @param n          number of queries
 * @param x          query vectors, n * d floats
 * @param k          number of results per query
 * @param distances  output distances, n * k floats
 * @param labels     output ids, n * k entries
 * @throws an exception (via FAISS_THROW_IF_NOT) if the storage is neither
 *         an Index2Layer nor an IndexIVFPQ
 */
void IndexHNSW2Level::search (idx_t n, const float *x, idx_t k,
                              float *distances, idx_t *labels) const
{
    if (dynamic_cast<const Index2Layer*>(storage)) {
        IndexHNSW::search (n, x, k, distances, labels);

    } else { // "mixed" search

        const IndexIVFPQ *index_ivfpq =
            dynamic_cast<const IndexIVFPQ*>(storage);

        // the cast yields a null pointer for any other storage type;
        // fail cleanly instead of dereferencing it
        FAISS_THROW_IF_NOT (index_ivfpq);

        int nprobe = index_ivfpq->nprobe;

        std::unique_ptr<idx_t[]> coarse_assign(new idx_t[n * nprobe]);
        std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);

        index_ivfpq->quantizer->search (n, x, nprobe, coarse_dis.get(),
                                        coarse_assign.get());

        index_ivfpq->search_preassigned (n, x, k, coarse_assign.get(),
                                         coarse_dis.get(), distances, labels,
                                         false);

#pragma omp parallel
        {
            VisitedTable vt (ntotal);
            DistanceComputer *dis = storage->get_distance_computer();
            ScopeDeleter1<DistanceComputer> del(dis);

            int candidates_size = hnsw.upper_beam;
            MinimaxHeap candidates(candidates_size);

#pragma omp for
            for(idx_t i = 0; i < n; i++) {
                idx_t * idxi = labels + i * k;
                float * simi = distances + i * k;
                dis->set_query(x + i * d);

                // mark all inverted list elements as visited

                for (int j = 0; j < nprobe; j++) {
                    idx_t key = coarse_assign[j + i * nprobe];
                    if (key < 0) break;
                    size_t list_length = index_ivfpq->get_list_size (key);
                    const idx_t * ids = index_ivfpq->invlists->get_ids (key);

                    for (size_t jj = 0; jj < list_length; jj++) {
                        vt.set (ids[jj]);
                    }
                }

                candidates.clear();
                // copy the upper_beam elements to candidates list

                // hard-coded selection between the two refinement variants
                int search_policy = 2;

                if (search_policy == 1) {

                    for (int j = 0 ; j < hnsw.upper_beam && j < k; j++) {
                        if (idxi[j] < 0) break;
                        candidates.push (idxi[j], simi[j]);
                        // search_from_candidates adds them back
                        idxi[j] = -1;
                        simi[j] = HUGE_VAL;
                    }

                    // reorder from sorted to heap
                    maxheap_heapify (k, simi, idxi, simi, idxi, k);

                    hnsw.search_from_candidates(
                        *dis, k, idxi, simi,
                        candidates, vt, 0, k
                    );

                    vt.advance();

                } else if (search_policy == 2) {

                    for (int j = 0 ; j < hnsw.upper_beam && j < k; j++) {
                        if (idxi[j] < 0) break;
                        candidates.push (idxi[j], simi[j]);
                    }

                    // reorder from sorted to heap
                    maxheap_heapify (k, simi, idxi, simi, idxi, k);

                    search_from_candidates_2 (
                        hnsw, *dis, k, idxi, simi,
                        candidates, vt, 0, k);
                    // this variant uses the marks visno and visno + 1,
                    // so the table is advanced twice
                    vt.advance ();
                    vt.advance ();

                }

                maxheap_reorder (k, simi, idxi);
            }
        }
    }


}
|
|
1064
|
+
|
|
1065
|
+
|
|
1066
|
+
void IndexHNSW2Level::flip_to_ivf ()
|
|
1067
|
+
{
|
|
1068
|
+
Index2Layer *storage2l =
|
|
1069
|
+
dynamic_cast<Index2Layer*>(storage);
|
|
1070
|
+
|
|
1071
|
+
FAISS_THROW_IF_NOT (storage2l);
|
|
1072
|
+
|
|
1073
|
+
IndexIVFPQ * index_ivfpq =
|
|
1074
|
+
new IndexIVFPQ (storage2l->q1.quantizer,
|
|
1075
|
+
d, storage2l->q1.nlist,
|
|
1076
|
+
storage2l->pq.M, 8);
|
|
1077
|
+
index_ivfpq->pq = storage2l->pq;
|
|
1078
|
+
index_ivfpq->is_trained = storage2l->is_trained;
|
|
1079
|
+
index_ivfpq->precompute_table();
|
|
1080
|
+
index_ivfpq->own_fields = storage2l->q1.own_fields;
|
|
1081
|
+
storage2l->transfer_to_IVFPQ(*index_ivfpq);
|
|
1082
|
+
index_ivfpq->make_direct_map (true);
|
|
1083
|
+
|
|
1084
|
+
storage = index_ivfpq;
|
|
1085
|
+
delete storage2l;
|
|
1086
|
+
|
|
1087
|
+
}
|
|
1088
|
+
|
|
1089
|
+
|
|
1090
|
+
} // namespace faiss
|