faiss 0.1.0 → 0.1.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +103 -3
- data/ext/faiss/ext.cpp +99 -32
- data/ext/faiss/extconf.rb +12 -2
- data/lib/faiss/ext.bundle +0 -0
- data/lib/faiss/index.rb +3 -3
- data/lib/faiss/index_binary.rb +3 -3
- data/lib/faiss/kmeans.rb +1 -1
- data/lib/faiss/pca_matrix.rb +2 -2
- data/lib/faiss/product_quantizer.rb +3 -3
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/AutoTune.cpp +719 -0
- data/vendor/faiss/AutoTune.h +212 -0
- data/vendor/faiss/Clustering.cpp +261 -0
- data/vendor/faiss/Clustering.h +101 -0
- data/vendor/faiss/IVFlib.cpp +339 -0
- data/vendor/faiss/IVFlib.h +132 -0
- data/vendor/faiss/Index.cpp +171 -0
- data/vendor/faiss/Index.h +261 -0
- data/vendor/faiss/Index2Layer.cpp +437 -0
- data/vendor/faiss/Index2Layer.h +85 -0
- data/vendor/faiss/IndexBinary.cpp +77 -0
- data/vendor/faiss/IndexBinary.h +163 -0
- data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
- data/vendor/faiss/IndexBinaryFlat.h +54 -0
- data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
- data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
- data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
- data/vendor/faiss/IndexBinaryHNSW.h +56 -0
- data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
- data/vendor/faiss/IndexBinaryIVF.h +211 -0
- data/vendor/faiss/IndexFlat.cpp +508 -0
- data/vendor/faiss/IndexFlat.h +175 -0
- data/vendor/faiss/IndexHNSW.cpp +1090 -0
- data/vendor/faiss/IndexHNSW.h +170 -0
- data/vendor/faiss/IndexIVF.cpp +909 -0
- data/vendor/faiss/IndexIVF.h +353 -0
- data/vendor/faiss/IndexIVFFlat.cpp +502 -0
- data/vendor/faiss/IndexIVFFlat.h +118 -0
- data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
- data/vendor/faiss/IndexIVFPQ.h +161 -0
- data/vendor/faiss/IndexIVFPQR.cpp +219 -0
- data/vendor/faiss/IndexIVFPQR.h +65 -0
- data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
- data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
- data/vendor/faiss/IndexLSH.cpp +225 -0
- data/vendor/faiss/IndexLSH.h +87 -0
- data/vendor/faiss/IndexLattice.cpp +143 -0
- data/vendor/faiss/IndexLattice.h +68 -0
- data/vendor/faiss/IndexPQ.cpp +1188 -0
- data/vendor/faiss/IndexPQ.h +199 -0
- data/vendor/faiss/IndexPreTransform.cpp +288 -0
- data/vendor/faiss/IndexPreTransform.h +91 -0
- data/vendor/faiss/IndexReplicas.cpp +123 -0
- data/vendor/faiss/IndexReplicas.h +76 -0
- data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
- data/vendor/faiss/IndexScalarQuantizer.h +127 -0
- data/vendor/faiss/IndexShards.cpp +317 -0
- data/vendor/faiss/IndexShards.h +100 -0
- data/vendor/faiss/InvertedLists.cpp +623 -0
- data/vendor/faiss/InvertedLists.h +334 -0
- data/vendor/faiss/LICENSE +21 -0
- data/vendor/faiss/MatrixStats.cpp +252 -0
- data/vendor/faiss/MatrixStats.h +62 -0
- data/vendor/faiss/MetaIndexes.cpp +351 -0
- data/vendor/faiss/MetaIndexes.h +126 -0
- data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
- data/vendor/faiss/OnDiskInvertedLists.h +127 -0
- data/vendor/faiss/VectorTransform.cpp +1157 -0
- data/vendor/faiss/VectorTransform.h +322 -0
- data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
- data/vendor/faiss/c_api/AutoTune_c.h +64 -0
- data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
- data/vendor/faiss/c_api/Clustering_c.h +117 -0
- data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
- data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
- data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
- data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
- data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
- data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
- data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
- data/vendor/faiss/c_api/IndexShards_c.h +42 -0
- data/vendor/faiss/c_api/Index_c.cpp +105 -0
- data/vendor/faiss/c_api/Index_c.h +183 -0
- data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
- data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
- data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
- data/vendor/faiss/c_api/clone_index_c.h +32 -0
- data/vendor/faiss/c_api/error_c.h +42 -0
- data/vendor/faiss/c_api/error_impl.cpp +27 -0
- data/vendor/faiss/c_api/error_impl.h +16 -0
- data/vendor/faiss/c_api/faiss_c.h +58 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
- data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
- data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
- data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
- data/vendor/faiss/c_api/index_factory_c.h +30 -0
- data/vendor/faiss/c_api/index_io_c.cpp +42 -0
- data/vendor/faiss/c_api/index_io_c.h +50 -0
- data/vendor/faiss/c_api/macros_impl.h +110 -0
- data/vendor/faiss/clone_index.cpp +147 -0
- data/vendor/faiss/clone_index.h +38 -0
- data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
- data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
- data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
- data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
- data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
- data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
- data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
- data/vendor/faiss/gpu/GpuCloner.h +82 -0
- data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
- data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
- data/vendor/faiss/gpu/GpuDistance.h +52 -0
- data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
- data/vendor/faiss/gpu/GpuIndex.h +148 -0
- data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
- data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
- data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
- data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
- data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
- data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
- data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
- data/vendor/faiss/gpu/GpuResources.cpp +52 -0
- data/vendor/faiss/gpu/GpuResources.h +73 -0
- data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
- data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
- data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
- data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
- data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
- data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
- data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
- data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
- data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
- data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
- data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
- data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
- data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
- data/vendor/faiss/gpu/test/TestUtils.h +93 -0
- data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
- data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
- data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
- data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
- data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
- data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
- data/vendor/faiss/gpu/utils/Timer.h +52 -0
- data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
- data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
- data/vendor/faiss/impl/FaissAssert.h +95 -0
- data/vendor/faiss/impl/FaissException.cpp +66 -0
- data/vendor/faiss/impl/FaissException.h +71 -0
- data/vendor/faiss/impl/HNSW.cpp +818 -0
- data/vendor/faiss/impl/HNSW.h +275 -0
- data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
- data/vendor/faiss/impl/PolysemousTraining.h +158 -0
- data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
- data/vendor/faiss/impl/ProductQuantizer.h +242 -0
- data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
- data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
- data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
- data/vendor/faiss/impl/ThreadedIndex.h +80 -0
- data/vendor/faiss/impl/index_read.cpp +793 -0
- data/vendor/faiss/impl/index_write.cpp +558 -0
- data/vendor/faiss/impl/io.cpp +142 -0
- data/vendor/faiss/impl/io.h +98 -0
- data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
- data/vendor/faiss/impl/lattice_Zn.h +199 -0
- data/vendor/faiss/index_factory.cpp +392 -0
- data/vendor/faiss/index_factory.h +25 -0
- data/vendor/faiss/index_io.h +75 -0
- data/vendor/faiss/misc/test_blas.cpp +84 -0
- data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
- data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
- data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
- data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
- data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
- data/vendor/faiss/tests/test_merge.cpp +258 -0
- data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
- data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
- data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
- data/vendor/faiss/tests/test_params_override.cpp +231 -0
- data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
- data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
- data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
- data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
- data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
- data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
- data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
- data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
- data/vendor/faiss/utils/Heap.cpp +122 -0
- data/vendor/faiss/utils/Heap.h +495 -0
- data/vendor/faiss/utils/WorkerThread.cpp +126 -0
- data/vendor/faiss/utils/WorkerThread.h +61 -0
- data/vendor/faiss/utils/distances.cpp +765 -0
- data/vendor/faiss/utils/distances.h +243 -0
- data/vendor/faiss/utils/distances_simd.cpp +809 -0
- data/vendor/faiss/utils/extra_distances.cpp +336 -0
- data/vendor/faiss/utils/extra_distances.h +54 -0
- data/vendor/faiss/utils/hamming-inl.h +472 -0
- data/vendor/faiss/utils/hamming.cpp +792 -0
- data/vendor/faiss/utils/hamming.h +220 -0
- data/vendor/faiss/utils/random.cpp +192 -0
- data/vendor/faiss/utils/random.h +60 -0
- data/vendor/faiss/utils/utils.cpp +783 -0
- data/vendor/faiss/utils/utils.h +181 -0
- metadata +216 -2
@@ -0,0 +1,170 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
// -*- c++ -*-
|
9
|
+
|
10
|
+
#pragma once
|
11
|
+
|
12
|
+
#include <vector>
|
13
|
+
|
14
|
+
#include <faiss/impl/HNSW.h>
|
15
|
+
#include <faiss/IndexFlat.h>
|
16
|
+
#include <faiss/IndexPQ.h>
|
17
|
+
#include <faiss/IndexScalarQuantizer.h>
|
18
|
+
#include <faiss/utils/utils.h>
|
19
|
+
|
20
|
+
|
21
|
+
namespace faiss {
|
22
|
+
|
23
|
+
struct IndexHNSW;
|
24
|
+
|
25
|
+
struct ReconstructFromNeighbors {
|
26
|
+
typedef Index::idx_t idx_t;
|
27
|
+
typedef HNSW::storage_idx_t storage_idx_t;
|
28
|
+
|
29
|
+
const IndexHNSW & index;
|
30
|
+
size_t M; // number of neighbors
|
31
|
+
size_t k; // number of codebook entries
|
32
|
+
size_t nsq; // number of subvectors
|
33
|
+
size_t code_size;
|
34
|
+
int k_reorder; // nb to reorder. -1 = all
|
35
|
+
|
36
|
+
std::vector<float> codebook; // size nsq * k * (M + 1)
|
37
|
+
|
38
|
+
std::vector<uint8_t> codes; // size ntotal * code_size
|
39
|
+
size_t ntotal;
|
40
|
+
size_t d, dsub; // derived values
|
41
|
+
|
42
|
+
explicit ReconstructFromNeighbors(const IndexHNSW& index,
|
43
|
+
size_t k=256, size_t nsq=1);
|
44
|
+
|
45
|
+
/// codes must be added in the correct order and the IndexHNSW
|
46
|
+
/// must be populated and sorted
|
47
|
+
void add_codes(size_t n, const float *x);
|
48
|
+
|
49
|
+
size_t compute_distances(size_t n, const idx_t *shortlist,
|
50
|
+
const float *query, float *distances) const;
|
51
|
+
|
52
|
+
/// called by add_codes
|
53
|
+
void estimate_code(const float *x, storage_idx_t i, uint8_t *code) const;
|
54
|
+
|
55
|
+
/// called by compute_distances
|
56
|
+
void reconstruct(storage_idx_t i, float *x, float *tmp) const;
|
57
|
+
|
58
|
+
void reconstruct_n(storage_idx_t n0, storage_idx_t ni, float *x) const;
|
59
|
+
|
60
|
+
/// get the M+1 -by-d table for neighbor coordinates for vector i
|
61
|
+
void get_neighbor_table(storage_idx_t i, float *out) const;
|
62
|
+
|
63
|
+
};
|
64
|
+
|
65
|
+
|
66
|
+
/** The HNSW index is a normal random-access index with a HNSW
|
67
|
+
* link structure built on top */
|
68
|
+
|
69
|
+
struct IndexHNSW : Index {
|
70
|
+
|
71
|
+
typedef HNSW::storage_idx_t storage_idx_t;
|
72
|
+
|
73
|
+
// the link strcuture
|
74
|
+
HNSW hnsw;
|
75
|
+
|
76
|
+
// the sequential storage
|
77
|
+
bool own_fields;
|
78
|
+
Index *storage;
|
79
|
+
|
80
|
+
ReconstructFromNeighbors *reconstruct_from_neighbors;
|
81
|
+
|
82
|
+
explicit IndexHNSW (int d = 0, int M = 32);
|
83
|
+
explicit IndexHNSW (Index *storage, int M = 32);
|
84
|
+
|
85
|
+
~IndexHNSW() override;
|
86
|
+
|
87
|
+
void add(idx_t n, const float *x) override;
|
88
|
+
|
89
|
+
/// Trains the storage if needed
|
90
|
+
void train(idx_t n, const float* x) override;
|
91
|
+
|
92
|
+
/// entry point for search
|
93
|
+
void search (idx_t n, const float *x, idx_t k,
|
94
|
+
float *distances, idx_t *labels) const override;
|
95
|
+
|
96
|
+
void reconstruct(idx_t key, float* recons) const override;
|
97
|
+
|
98
|
+
void reset () override;
|
99
|
+
|
100
|
+
void shrink_level_0_neighbors(int size);
|
101
|
+
|
102
|
+
/** Perform search only on level 0, given the starting points for
|
103
|
+
* each vertex.
|
104
|
+
*
|
105
|
+
* @param search_type 1:perform one search per nprobe, 2: enqueue
|
106
|
+
* all entry points
|
107
|
+
*/
|
108
|
+
void search_level_0(idx_t n, const float *x, idx_t k,
|
109
|
+
const storage_idx_t *nearest, const float *nearest_d,
|
110
|
+
float *distances, idx_t *labels, int nprobe = 1,
|
111
|
+
int search_type = 1) const;
|
112
|
+
|
113
|
+
/// alternative graph building
|
114
|
+
void init_level_0_from_knngraph(
|
115
|
+
int k, const float *D, const idx_t *I);
|
116
|
+
|
117
|
+
/// alternative graph building
|
118
|
+
void init_level_0_from_entry_points(
|
119
|
+
int npt, const storage_idx_t *points,
|
120
|
+
const storage_idx_t *nearests);
|
121
|
+
|
122
|
+
// reorder links from nearest to farthest
|
123
|
+
void reorder_links();
|
124
|
+
|
125
|
+
void link_singletons();
|
126
|
+
};
|
127
|
+
|
128
|
+
|
129
|
+
/** Flat index topped with with a HNSW structure to access elements
|
130
|
+
* more efficiently.
|
131
|
+
*/
|
132
|
+
|
133
|
+
struct IndexHNSWFlat : IndexHNSW {
|
134
|
+
IndexHNSWFlat();
|
135
|
+
IndexHNSWFlat(int d, int M);
|
136
|
+
};
|
137
|
+
|
138
|
+
/** PQ index topped with with a HNSW structure to access elements
|
139
|
+
* more efficiently.
|
140
|
+
*/
|
141
|
+
struct IndexHNSWPQ : IndexHNSW {
|
142
|
+
IndexHNSWPQ();
|
143
|
+
IndexHNSWPQ(int d, int pq_m, int M);
|
144
|
+
void train(idx_t n, const float* x) override;
|
145
|
+
};
|
146
|
+
|
147
|
+
/** SQ index topped with with a HNSW structure to access elements
|
148
|
+
* more efficiently.
|
149
|
+
*/
|
150
|
+
struct IndexHNSWSQ : IndexHNSW {
|
151
|
+
IndexHNSWSQ();
|
152
|
+
IndexHNSWSQ(int d, ScalarQuantizer::QuantizerType qtype, int M);
|
153
|
+
};
|
154
|
+
|
155
|
+
/** 2-level code structure with fast random access
|
156
|
+
*/
|
157
|
+
struct IndexHNSW2Level : IndexHNSW {
|
158
|
+
IndexHNSW2Level();
|
159
|
+
IndexHNSW2Level(Index *quantizer, size_t nlist, int m_pq, int M);
|
160
|
+
|
161
|
+
void flip_to_ivf();
|
162
|
+
|
163
|
+
/// entry point for search
|
164
|
+
void search (idx_t n, const float *x, idx_t k,
|
165
|
+
float *distances, idx_t *labels) const override;
|
166
|
+
|
167
|
+
};
|
168
|
+
|
169
|
+
|
170
|
+
} // namespace faiss
|
@@ -0,0 +1,909 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
// -*- c++ -*-
|
9
|
+
|
10
|
+
#include <faiss/IndexIVF.h>
|
11
|
+
|
12
|
+
|
13
|
+
#include <omp.h>
|
14
|
+
|
15
|
+
#include <cstdio>
|
16
|
+
#include <memory>
|
17
|
+
|
18
|
+
#include <faiss/utils/utils.h>
|
19
|
+
#include <faiss/utils/hamming.h>
|
20
|
+
|
21
|
+
#include <faiss/impl/FaissAssert.h>
|
22
|
+
#include <faiss/IndexFlat.h>
|
23
|
+
#include <faiss/impl/AuxIndexStructures.h>
|
24
|
+
|
25
|
+
namespace faiss {
|
26
|
+
|
27
|
+
using ScopedIds = InvertedLists::ScopedIds;
|
28
|
+
using ScopedCodes = InvertedLists::ScopedCodes;
|
29
|
+
|
30
|
+
/*****************************************
|
31
|
+
* Level1Quantizer implementation
|
32
|
+
******************************************/
|
33
|
+
|
34
|
+
|
35
|
+
Level1Quantizer::Level1Quantizer (Index * quantizer, size_t nlist):
|
36
|
+
quantizer (quantizer),
|
37
|
+
nlist (nlist),
|
38
|
+
quantizer_trains_alone (0),
|
39
|
+
own_fields (false),
|
40
|
+
clustering_index (nullptr)
|
41
|
+
{
|
42
|
+
// here we set a low # iterations because this is typically used
|
43
|
+
// for large clusterings (nb this is not used for the MultiIndex,
|
44
|
+
// for which quantizer_trains_alone = true)
|
45
|
+
cp.niter = 10;
|
46
|
+
}
|
47
|
+
|
48
|
+
Level1Quantizer::Level1Quantizer ():
|
49
|
+
quantizer (nullptr),
|
50
|
+
nlist (0),
|
51
|
+
quantizer_trains_alone (0), own_fields (false),
|
52
|
+
clustering_index (nullptr)
|
53
|
+
{}
|
54
|
+
|
55
|
+
Level1Quantizer::~Level1Quantizer ()
|
56
|
+
{
|
57
|
+
if (own_fields) delete quantizer;
|
58
|
+
}
|
59
|
+
|
60
|
+
void Level1Quantizer::train_q1 (size_t n, const float *x, bool verbose, MetricType metric_type)
|
61
|
+
{
|
62
|
+
size_t d = quantizer->d;
|
63
|
+
if (quantizer->is_trained && (quantizer->ntotal == nlist)) {
|
64
|
+
if (verbose)
|
65
|
+
printf ("IVF quantizer does not need training.\n");
|
66
|
+
} else if (quantizer_trains_alone == 1) {
|
67
|
+
if (verbose)
|
68
|
+
printf ("IVF quantizer trains alone...\n");
|
69
|
+
quantizer->train (n, x);
|
70
|
+
quantizer->verbose = verbose;
|
71
|
+
FAISS_THROW_IF_NOT_MSG (quantizer->ntotal == nlist,
|
72
|
+
"nlist not consistent with quantizer size");
|
73
|
+
} else if (quantizer_trains_alone == 0) {
|
74
|
+
if (verbose)
|
75
|
+
printf ("Training level-1 quantizer on %ld vectors in %ldD\n",
|
76
|
+
n, d);
|
77
|
+
|
78
|
+
Clustering clus (d, nlist, cp);
|
79
|
+
quantizer->reset();
|
80
|
+
if (clustering_index) {
|
81
|
+
clus.train (n, x, *clustering_index);
|
82
|
+
quantizer->add (nlist, clus.centroids.data());
|
83
|
+
} else {
|
84
|
+
clus.train (n, x, *quantizer);
|
85
|
+
}
|
86
|
+
quantizer->is_trained = true;
|
87
|
+
} else if (quantizer_trains_alone == 2) {
|
88
|
+
if (verbose)
|
89
|
+
printf (
|
90
|
+
"Training L2 quantizer on %ld vectors in %ldD%s\n",
|
91
|
+
n, d,
|
92
|
+
clustering_index ? "(user provided index)" : "");
|
93
|
+
FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
|
94
|
+
Clustering clus (d, nlist, cp);
|
95
|
+
if (!clustering_index) {
|
96
|
+
IndexFlatL2 assigner (d);
|
97
|
+
clus.train(n, x, assigner);
|
98
|
+
} else {
|
99
|
+
clus.train(n, x, *clustering_index);
|
100
|
+
}
|
101
|
+
if (verbose)
|
102
|
+
printf ("Adding centroids to quantizer\n");
|
103
|
+
quantizer->add (nlist, clus.centroids.data());
|
104
|
+
}
|
105
|
+
}
|
106
|
+
|
107
|
+
size_t Level1Quantizer::coarse_code_size () const
|
108
|
+
{
|
109
|
+
size_t nl = nlist - 1;
|
110
|
+
size_t nbyte = 0;
|
111
|
+
while (nl > 0) {
|
112
|
+
nbyte ++;
|
113
|
+
nl >>= 8;
|
114
|
+
}
|
115
|
+
return nbyte;
|
116
|
+
}
|
117
|
+
|
118
|
+
void Level1Quantizer::encode_listno (Index::idx_t list_no, uint8_t *code) const
|
119
|
+
{
|
120
|
+
// little endian
|
121
|
+
size_t nl = nlist - 1;
|
122
|
+
while (nl > 0) {
|
123
|
+
*code++ = list_no & 0xff;
|
124
|
+
list_no >>= 8;
|
125
|
+
nl >>= 8;
|
126
|
+
}
|
127
|
+
}
|
128
|
+
|
129
|
+
Index::idx_t Level1Quantizer::decode_listno (const uint8_t *code) const
|
130
|
+
{
|
131
|
+
size_t nl = nlist - 1;
|
132
|
+
int64_t list_no = 0;
|
133
|
+
int nbit = 0;
|
134
|
+
while (nl > 0) {
|
135
|
+
list_no |= int64_t(*code++) << nbit;
|
136
|
+
nbit += 8;
|
137
|
+
nl >>= 8;
|
138
|
+
}
|
139
|
+
FAISS_THROW_IF_NOT (list_no >= 0 && list_no < nlist);
|
140
|
+
return list_no;
|
141
|
+
}
|
142
|
+
|
143
|
+
|
144
|
+
|
145
|
+
/*****************************************
|
146
|
+
* IndexIVF implementation
|
147
|
+
******************************************/
|
148
|
+
|
149
|
+
|
150
|
+
IndexIVF::IndexIVF (Index * quantizer, size_t d,
|
151
|
+
size_t nlist, size_t code_size,
|
152
|
+
MetricType metric):
|
153
|
+
Index (d, metric),
|
154
|
+
Level1Quantizer (quantizer, nlist),
|
155
|
+
invlists (new ArrayInvertedLists (nlist, code_size)),
|
156
|
+
own_invlists (true),
|
157
|
+
code_size (code_size),
|
158
|
+
nprobe (1),
|
159
|
+
max_codes (0),
|
160
|
+
parallel_mode (0),
|
161
|
+
maintain_direct_map (false)
|
162
|
+
{
|
163
|
+
FAISS_THROW_IF_NOT (d == quantizer->d);
|
164
|
+
is_trained = quantizer->is_trained && (quantizer->ntotal == nlist);
|
165
|
+
// Spherical by default if the metric is inner_product
|
166
|
+
if (metric_type == METRIC_INNER_PRODUCT) {
|
167
|
+
cp.spherical = true;
|
168
|
+
}
|
169
|
+
|
170
|
+
}
|
171
|
+
|
172
|
+
IndexIVF::IndexIVF ():
|
173
|
+
invlists (nullptr), own_invlists (false),
|
174
|
+
code_size (0),
|
175
|
+
nprobe (1), max_codes (0), parallel_mode (0),
|
176
|
+
maintain_direct_map (false)
|
177
|
+
{}
|
178
|
+
|
179
|
+
void IndexIVF::add (idx_t n, const float * x)
|
180
|
+
{
|
181
|
+
add_with_ids (n, x, nullptr);
|
182
|
+
}
|
183
|
+
|
184
|
+
|
185
|
+
void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids)
|
186
|
+
{
|
187
|
+
// do some blocking to avoid excessive allocs
|
188
|
+
idx_t bs = 65536;
|
189
|
+
if (n > bs) {
|
190
|
+
for (idx_t i0 = 0; i0 < n; i0 += bs) {
|
191
|
+
idx_t i1 = std::min (n, i0 + bs);
|
192
|
+
if (verbose) {
|
193
|
+
printf(" IndexIVF::add_with_ids %ld:%ld\n", i0, i1);
|
194
|
+
}
|
195
|
+
add_with_ids (i1 - i0, x + i0 * d,
|
196
|
+
xids ? xids + i0 : nullptr);
|
197
|
+
}
|
198
|
+
return;
|
199
|
+
}
|
200
|
+
|
201
|
+
FAISS_THROW_IF_NOT (is_trained);
|
202
|
+
std::unique_ptr<idx_t []> idx(new idx_t[n]);
|
203
|
+
quantizer->assign (n, x, idx.get());
|
204
|
+
size_t nadd = 0, nminus1 = 0;
|
205
|
+
|
206
|
+
for (size_t i = 0; i < n; i++) {
|
207
|
+
if (idx[i] < 0) nminus1++;
|
208
|
+
}
|
209
|
+
|
210
|
+
std::unique_ptr<uint8_t []> flat_codes(new uint8_t [n * code_size]);
|
211
|
+
encode_vectors (n, x, idx.get(), flat_codes.get());
|
212
|
+
|
213
|
+
#pragma omp parallel reduction(+: nadd)
|
214
|
+
{
|
215
|
+
int nt = omp_get_num_threads();
|
216
|
+
int rank = omp_get_thread_num();
|
217
|
+
|
218
|
+
// each thread takes care of a subset of lists
|
219
|
+
for (size_t i = 0; i < n; i++) {
|
220
|
+
idx_t list_no = idx [i];
|
221
|
+
if (list_no >= 0 && list_no % nt == rank) {
|
222
|
+
idx_t id = xids ? xids[i] : ntotal + i;
|
223
|
+
invlists->add_entry (list_no, id,
|
224
|
+
flat_codes.get() + i * code_size);
|
225
|
+
nadd++;
|
226
|
+
}
|
227
|
+
}
|
228
|
+
}
|
229
|
+
|
230
|
+
if (verbose) {
|
231
|
+
printf(" added %ld / %ld vectors (%ld -1s)\n", nadd, n, nminus1);
|
232
|
+
}
|
233
|
+
|
234
|
+
ntotal += n;
|
235
|
+
}
|
236
|
+
|
237
|
+
|
238
|
+
void IndexIVF::make_direct_map (bool new_maintain_direct_map)
|
239
|
+
{
|
240
|
+
// nothing to do
|
241
|
+
if (new_maintain_direct_map == maintain_direct_map)
|
242
|
+
return;
|
243
|
+
|
244
|
+
if (new_maintain_direct_map) {
|
245
|
+
direct_map.resize (ntotal, -1);
|
246
|
+
for (size_t key = 0; key < nlist; key++) {
|
247
|
+
size_t list_size = invlists->list_size (key);
|
248
|
+
ScopedIds idlist (invlists, key);
|
249
|
+
|
250
|
+
for (long ofs = 0; ofs < list_size; ofs++) {
|
251
|
+
FAISS_THROW_IF_NOT_MSG (
|
252
|
+
0 <= idlist [ofs] && idlist[ofs] < ntotal,
|
253
|
+
"direct map supported only for seuquential ids");
|
254
|
+
direct_map [idlist [ofs]] = key << 32 | ofs;
|
255
|
+
}
|
256
|
+
}
|
257
|
+
} else {
|
258
|
+
direct_map.clear ();
|
259
|
+
}
|
260
|
+
maintain_direct_map = new_maintain_direct_map;
|
261
|
+
}
|
262
|
+
|
263
|
+
|
264
|
+
void IndexIVF::search (idx_t n, const float *x, idx_t k,
|
265
|
+
float *distances, idx_t *labels) const
|
266
|
+
{
|
267
|
+
std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
|
268
|
+
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
|
269
|
+
|
270
|
+
double t0 = getmillisecs();
|
271
|
+
quantizer->search (n, x, nprobe, coarse_dis.get(), idx.get());
|
272
|
+
indexIVF_stats.quantization_time += getmillisecs() - t0;
|
273
|
+
|
274
|
+
t0 = getmillisecs();
|
275
|
+
invlists->prefetch_lists (idx.get(), n * nprobe);
|
276
|
+
|
277
|
+
search_preassigned (n, x, k, idx.get(), coarse_dis.get(),
|
278
|
+
distances, labels, false);
|
279
|
+
indexIVF_stats.search_time += getmillisecs() - t0;
|
280
|
+
}
|
281
|
+
|
282
|
+
|
283
|
+
|
284
|
+
void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
|
285
|
+
const idx_t *keys,
|
286
|
+
const float *coarse_dis ,
|
287
|
+
float *distances, idx_t *labels,
|
288
|
+
bool store_pairs,
|
289
|
+
const IVFSearchParameters *params) const
|
290
|
+
{
|
291
|
+
long nprobe = params ? params->nprobe : this->nprobe;
|
292
|
+
long max_codes = params ? params->max_codes : this->max_codes;
|
293
|
+
|
294
|
+
size_t nlistv = 0, ndis = 0, nheap = 0;
|
295
|
+
|
296
|
+
using HeapForIP = CMin<float, idx_t>;
|
297
|
+
using HeapForL2 = CMax<float, idx_t>;
|
298
|
+
|
299
|
+
bool interrupt = false;
|
300
|
+
|
301
|
+
// don't start parallel section if single query
|
302
|
+
bool do_parallel =
|
303
|
+
parallel_mode == 0 ? n > 1 :
|
304
|
+
parallel_mode == 1 ? nprobe > 1 :
|
305
|
+
nprobe * n > 1;
|
306
|
+
|
307
|
+
#pragma omp parallel if(do_parallel) reduction(+: nlistv, ndis, nheap)
|
308
|
+
{
|
309
|
+
InvertedListScanner *scanner = get_InvertedListScanner(store_pairs);
|
310
|
+
ScopeDeleter1<InvertedListScanner> del(scanner);
|
311
|
+
|
312
|
+
/*****************************************************
|
313
|
+
* Depending on parallel_mode, there are two possible ways
|
314
|
+
* to organize the search. Here we define local functions
|
315
|
+
* that are in common between the two
|
316
|
+
******************************************************/
|
317
|
+
|
318
|
+
// intialize + reorder a result heap
|
319
|
+
|
320
|
+
auto init_result = [&](float *simi, idx_t *idxi) {
|
321
|
+
if (metric_type == METRIC_INNER_PRODUCT) {
|
322
|
+
heap_heapify<HeapForIP> (k, simi, idxi);
|
323
|
+
} else {
|
324
|
+
heap_heapify<HeapForL2> (k, simi, idxi);
|
325
|
+
}
|
326
|
+
};
|
327
|
+
|
328
|
+
auto reorder_result = [&] (float *simi, idx_t *idxi) {
|
329
|
+
if (metric_type == METRIC_INNER_PRODUCT) {
|
330
|
+
heap_reorder<HeapForIP> (k, simi, idxi);
|
331
|
+
} else {
|
332
|
+
heap_reorder<HeapForL2> (k, simi, idxi);
|
333
|
+
}
|
334
|
+
};
|
335
|
+
|
336
|
+
// single list scan using the current scanner (with query
|
337
|
+
// set porperly) and storing results in simi and idxi
|
338
|
+
auto scan_one_list = [&] (idx_t key, float coarse_dis_i,
|
339
|
+
float *simi, idx_t *idxi) {
|
340
|
+
|
341
|
+
if (key < 0) {
|
342
|
+
// not enough centroids for multiprobe
|
343
|
+
return (size_t)0;
|
344
|
+
}
|
345
|
+
FAISS_THROW_IF_NOT_FMT (key < (idx_t) nlist,
|
346
|
+
"Invalid key=%ld nlist=%ld\n",
|
347
|
+
key, nlist);
|
348
|
+
|
349
|
+
size_t list_size = invlists->list_size(key);
|
350
|
+
|
351
|
+
// don't waste time on empty lists
|
352
|
+
if (list_size == 0) {
|
353
|
+
return (size_t)0;
|
354
|
+
}
|
355
|
+
|
356
|
+
scanner->set_list (key, coarse_dis_i);
|
357
|
+
|
358
|
+
nlistv++;
|
359
|
+
|
360
|
+
InvertedLists::ScopedCodes scodes (invlists, key);
|
361
|
+
|
362
|
+
std::unique_ptr<InvertedLists::ScopedIds> sids;
|
363
|
+
const Index::idx_t * ids = nullptr;
|
364
|
+
|
365
|
+
if (!store_pairs) {
|
366
|
+
sids.reset (new InvertedLists::ScopedIds (invlists, key));
|
367
|
+
ids = sids->get();
|
368
|
+
}
|
369
|
+
|
370
|
+
nheap += scanner->scan_codes (list_size, scodes.get(),
|
371
|
+
ids, simi, idxi, k);
|
372
|
+
|
373
|
+
return list_size;
|
374
|
+
};
|
375
|
+
|
376
|
+
/****************************************************
|
377
|
+
* Actual loops, depending on parallel_mode
|
378
|
+
****************************************************/
|
379
|
+
|
380
|
+
if (parallel_mode == 0) {
|
381
|
+
|
382
|
+
#pragma omp for
|
383
|
+
for (size_t i = 0; i < n; i++) {
|
384
|
+
|
385
|
+
if (interrupt) {
|
386
|
+
continue;
|
387
|
+
}
|
388
|
+
|
389
|
+
// loop over queries
|
390
|
+
scanner->set_query (x + i * d);
|
391
|
+
float * simi = distances + i * k;
|
392
|
+
idx_t * idxi = labels + i * k;
|
393
|
+
|
394
|
+
init_result (simi, idxi);
|
395
|
+
|
396
|
+
long nscan = 0;
|
397
|
+
|
398
|
+
// loop over probes
|
399
|
+
for (size_t ik = 0; ik < nprobe; ik++) {
|
400
|
+
|
401
|
+
nscan += scan_one_list (
|
402
|
+
keys [i * nprobe + ik],
|
403
|
+
coarse_dis[i * nprobe + ik],
|
404
|
+
simi, idxi
|
405
|
+
);
|
406
|
+
|
407
|
+
if (max_codes && nscan >= max_codes) {
|
408
|
+
break;
|
409
|
+
}
|
410
|
+
}
|
411
|
+
|
412
|
+
ndis += nscan;
|
413
|
+
reorder_result (simi, idxi);
|
414
|
+
|
415
|
+
if (InterruptCallback::is_interrupted ()) {
|
416
|
+
interrupt = true;
|
417
|
+
}
|
418
|
+
|
419
|
+
} // parallel for
|
420
|
+
} else if (parallel_mode == 1) {
|
421
|
+
std::vector <idx_t> local_idx (k);
|
422
|
+
std::vector <float> local_dis (k);
|
423
|
+
|
424
|
+
for (size_t i = 0; i < n; i++) {
|
425
|
+
scanner->set_query (x + i * d);
|
426
|
+
init_result (local_dis.data(), local_idx.data());
|
427
|
+
|
428
|
+
#pragma omp for schedule(dynamic)
|
429
|
+
for (size_t ik = 0; ik < nprobe; ik++) {
|
430
|
+
ndis += scan_one_list
|
431
|
+
(keys [i * nprobe + ik],
|
432
|
+
coarse_dis[i * nprobe + ik],
|
433
|
+
local_dis.data(), local_idx.data());
|
434
|
+
|
435
|
+
// can't do the test on max_codes
|
436
|
+
}
|
437
|
+
// merge thread-local results
|
438
|
+
|
439
|
+
float * simi = distances + i * k;
|
440
|
+
idx_t * idxi = labels + i * k;
|
441
|
+
#pragma omp single
|
442
|
+
init_result (simi, idxi);
|
443
|
+
|
444
|
+
#pragma omp barrier
|
445
|
+
#pragma omp critical
|
446
|
+
{
|
447
|
+
if (metric_type == METRIC_INNER_PRODUCT) {
|
448
|
+
heap_addn<HeapForIP>
|
449
|
+
(k, simi, idxi,
|
450
|
+
local_dis.data(), local_idx.data(), k);
|
451
|
+
} else {
|
452
|
+
heap_addn<HeapForL2>
|
453
|
+
(k, simi, idxi,
|
454
|
+
local_dis.data(), local_idx.data(), k);
|
455
|
+
}
|
456
|
+
}
|
457
|
+
#pragma omp barrier
|
458
|
+
#pragma omp single
|
459
|
+
reorder_result (simi, idxi);
|
460
|
+
}
|
461
|
+
} else {
|
462
|
+
FAISS_THROW_FMT ("parallel_mode %d not supported\n",
|
463
|
+
parallel_mode);
|
464
|
+
}
|
465
|
+
} // parallel section
|
466
|
+
|
467
|
+
if (interrupt) {
|
468
|
+
FAISS_THROW_MSG ("computation interrupted");
|
469
|
+
}
|
470
|
+
|
471
|
+
indexIVF_stats.nq += n;
|
472
|
+
indexIVF_stats.nlist += nlistv;
|
473
|
+
indexIVF_stats.ndis += ndis;
|
474
|
+
indexIVF_stats.nheap_updates += nheap;
|
475
|
+
|
476
|
+
}
|
477
|
+
|
478
|
+
|
479
|
+
|
480
|
+
|
481
|
+
/** Range search: return all database vectors within `radius` of each query.
 *
 * Runs the coarse quantizer to pick nprobe lists per query, prefetches
 * those lists, then delegates the actual scanning to
 * range_search_preassigned(). Quantization and scan times are accounted
 * separately in indexIVF_stats.
 *
 * @param nx      number of query vectors
 * @param x       query vectors, size nx * d
 * @param radius  search radius (interpretation depends on metric_type)
 * @param result  receives the result set
 */
void IndexIVF::range_search (idx_t nx, const float *x, float radius,
                             RangeSearchResult *result) const
{
    // coarse assignment buffers, one (list id, centroid distance) pair
    // per (query, probe)
    std::unique_ptr<idx_t[]> assign (new idx_t[nx * nprobe]);
    std::unique_ptr<float[]> centroid_dis (new float[nx * nprobe]);

    double t_start = getmillisecs();
    quantizer->search (nx, x, nprobe, centroid_dis.get(), assign.get());
    indexIVF_stats.quantization_time += getmillisecs() - t_start;

    // scan phase: prefetch + per-list scanning are timed together
    t_start = getmillisecs();
    invlists->prefetch_lists (assign.get(), nx * nprobe);

    range_search_preassigned (nx, x, radius,
                              assign.get(), centroid_dis.get(),
                              result);

    indexIVF_stats.search_time += getmillisecs() - t_start;
}
|
499
|
+
|
500
|
+
/** Range search with the coarse quantization already performed.
 *
 * @param nx          number of queries
 * @param x           query vectors, size nx * d
 * @param radius      search radius
 * @param keys        coarse assignments, size nx * nprobe (entries < 0 skipped)
 * @param coarse_dis  distances to assigned centroids, size nx * nprobe
 * @param result      output result structure
 *
 * parallel_mode selects how OpenMP threads split the work:
 *   0 = one query per thread;
 *   1 = all threads cooperate on one query, split over probes;
 *   2 = one (query, probe) pair per loop iteration.
 */
void IndexIVF::range_search_preassigned (
         idx_t nx, const float *x, float radius,
         const idx_t *keys, const float *coarse_dis,
         RangeSearchResult *result) const
{

    size_t nlistv = 0, ndis = 0;
    bool store_pairs = false;

    // one partial result per thread; merged at the end for modes 1 and 2
    std::vector<RangeSearchPartialResult *> all_pres (omp_get_max_threads());

#pragma omp parallel reduction(+: nlistv, ndis)
    {
        RangeSearchPartialResult pres(result);
        std::unique_ptr<InvertedListScanner> scanner
            (get_InvertedListScanner(store_pairs));
        FAISS_THROW_IF_NOT (scanner.get ());
        all_pres[omp_get_thread_num()] = &pres;

        // prepare the list scanning function

        auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult &qres) {

            idx_t key = keys[i * nprobe + ik];  /* select the list */
            if (key < 0) return;   // probe slot left unassigned by quantizer
            FAISS_THROW_IF_NOT_FMT (
                key < (idx_t) nlist,
                "Invalid key=%ld at ik=%ld nlist=%ld\n",
                key, ik, nlist);
            const size_t list_size = invlists->list_size(key);

            if (list_size == 0) return;

            // RAII accessors to the list's codes and ids
            InvertedLists::ScopedCodes scodes (invlists, key);
            InvertedLists::ScopedIds ids (invlists, key);

            scanner->set_list (key, coarse_dis[i * nprobe + ik]);
            nlistv++;
            ndis += list_size;
            scanner->scan_codes_range (list_size, scodes.get(),
                                       ids.get(), radius, qres);
        };

        if (parallel_mode == 0) {
            // parallelize over queries

#pragma omp for
            for (size_t i = 0; i < nx; i++) {
                scanner->set_query (x + i * d);

                RangeQueryResult & qres = pres.new_result (i);

                for (size_t ik = 0; ik < nprobe; ik++) {
                    scan_list_func (i, ik, qres);
                }

            }

        } else if (parallel_mode == 1) {
            // every thread walks all queries; the probe loop is shared

            for (size_t i = 0; i < nx; i++) {
                scanner->set_query (x + i * d);

                RangeQueryResult & qres = pres.new_result (i);

#pragma omp for schedule(dynamic)
                for (size_t ik = 0; ik < nprobe; ik++) {
                    scan_list_func (i, ik, qres);
                }
            }
        } else if (parallel_mode == 2) {
            // flatten (query, probe) into one dynamically scheduled loop
            std::vector<RangeQueryResult *> all_qres (nx);
            RangeQueryResult *qres = nullptr;

#pragma omp for schedule(dynamic)
            for (size_t iik = 0; iik < nx * nprobe; iik++) {
                size_t i = iik / nprobe;
                size_t ik = iik % nprobe;
                if (qres == nullptr || qres->qno != i) {
                    // query index grows monotonically within a thread's chunk
                    FAISS_ASSERT (!qres || i > qres->qno);
                    qres = &pres.new_result (i);
                    scanner->set_query (x + i * d);
                }
                scan_list_func (i, ik, *qres);
            }
        } else {
            FAISS_THROW_FMT ("parallel_mode %d not supported\n", parallel_mode);
        }
        if (parallel_mode == 0) {
            // mode 0: each thread's results are independent; finalize locally
            pres.finalize ();
        } else {
            // modes 1/2: one thread merges all per-thread partial results;
            // barriers keep the pres objects alive until the merge is done
#pragma omp barrier
#pragma omp single
            RangeSearchPartialResult::merge (all_pres, false);
#pragma omp barrier

        }
    }
    indexIVF_stats.nq += nx;
    indexIVF_stats.nlist += nlistv;
    indexIVF_stats.ndis += ndis;
}
|
601
|
+
|
602
|
+
|
603
|
+
/// Base implementation of the scanner factory. Subclasses that support
/// scanner-based search override this; returning nullptr here makes the
/// FAISS_THROW_IF_NOT (scanner) checks in the *_preassigned methods fire.
InvertedListScanner *IndexIVF::get_InvertedListScanner (
    bool /*store_pairs*/) const
{
    return nullptr;
}
|
608
|
+
|
609
|
+
/** Reconstruct the stored vector with id `key` into recons (size d).
 *
 * Requires the direct map to be maintained; it gives O(1) access to the
 * (inverted list, offset) location of the vector. The actual decoding is
 * delegated to reconstruct_from_offset().
 */
void IndexIVF::reconstruct (idx_t key, float* recons) const
{
    FAISS_THROW_IF_NOT_MSG (direct_map.size() == ntotal,
                            "direct map is not initialized");
    FAISS_THROW_IF_NOT_MSG (key >= 0 && key < direct_map.size(),
                            "invalid key");

    // each direct_map entry packs (list_no << 32) | offset
    idx_t packed = direct_map[key];
    idx_t list_no = packed >> 32;
    idx_t offset = packed & 0xffffffff;

    reconstruct_from_offset (list_no, offset, recons);
}
|
619
|
+
|
620
|
+
|
621
|
+
/** Reconstruct the vectors whose ids lie in [i0, i0 + ni) into recons
 *  (size ni * d).
 *
 *  Works without a direct map: it scans every inverted list and picks the
 *  entries whose stored id falls in the range, so the cost is O(ntotal)
 *  regardless of ni.
 */
void IndexIVF::reconstruct_n (idx_t i0, idx_t ni, float* recons) const
{
    FAISS_THROW_IF_NOT (ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));

    for (idx_t list_no = 0; list_no < nlist; list_no++) {
        size_t list_size = invlists->list_size (list_no);
        ScopedIds idlist (invlists, list_no);

        for (idx_t offset = 0; offset < list_size; offset++) {
            idx_t id = idlist[offset];
            if (!(id >= i0 && id < i0 + ni)) {
                continue;   // id outside the requested range
            }

            // output slot is indexed by the id, not by scan order
            float* reconstructed = recons + (id - i0) * d;
            reconstruct_from_offset (list_no, offset, reconstructed);
        }
    }
}
|
640
|
+
|
641
|
+
|
642
|
+
/* standalone codec interface */
|
643
|
+
/// Size in bytes of one standalone code: the encoded coarse list id
/// followed by the per-vector code.
size_t IndexIVF::sa_code_size () const
{
    return coarse_code_size() + code_size;
}
|
648
|
+
|
649
|
+
/** Standalone-codec encoding of n vectors.
 *
 * Assigns each vector to its inverted list with the coarse quantizer,
 * then encodes it (include_listnos = true, so each output code is
 * sa_code_size() bytes).
 *
 * @param n      number of vectors
 * @param x      input vectors, size n * d
 * @param bytes  output codes, size n * sa_code_size()
 */
void IndexIVF::sa_encode (idx_t n, const float *x,
                          uint8_t *bytes) const
{
    FAISS_THROW_IF_NOT (is_trained);
    std::unique_ptr<int64_t[]> list_nos (new int64_t[n]);
    quantizer->assign (n, x, list_nos.get());
    encode_vectors (n, x, list_nos.get(), bytes, true);
}
|
657
|
+
|
658
|
+
|
659
|
+
/** Search the k nearest neighbors of n queries and also reconstruct the
 *  found database vectors.
 *
 * @param n          number of queries
 * @param x          queries, size n * d
 * @param k          number of neighbors per query
 * @param distances  output distances, size n * k
 * @param labels     output ids, size n * k (-1 for missing results)
 * @param recons     output reconstructions, size n * k * d
 */
void IndexIVF::search_and_reconstruct (idx_t n, const float *x, idx_t k,
                                       float *distances, idx_t *labels,
                                       float *recons) const
{
    // RAII buffers; was raw new + ScopeDeleter. unique_ptr matches
    // range_search()/sa_encode() in this file and cannot leak if
    // quantizer->search throws.
    std::unique_ptr<idx_t[]> idx (new idx_t[n * nprobe]);
    std::unique_ptr<float[]> coarse_dis (new float[n * nprobe]);

    quantizer->search (n, x, nprobe, coarse_dis.get(), idx.get());

    invlists->prefetch_lists (idx.get(), n * nprobe);

    // search_preassigned() with `store_pairs` enabled to obtain the list_no
    // and offset into `codes` for reconstruction
    search_preassigned (n, x, k, idx.get(), coarse_dis.get(),
                        distances, labels, true /* store_pairs */);
    for (idx_t i = 0; i < n; ++i) {
        for (idx_t j = 0; j < k; ++j) {
            idx_t ij = i * k + j;
            idx_t key = labels[ij];
            float* reconstructed = recons + ij * d;
            if (key < 0) {
                // Fill with NaNs: all-ones bytes form a float quiet NaN
                memset(reconstructed, -1, sizeof(*reconstructed) * d);
            } else {
                // store_pairs packed (list_no << 32) | offset into the label
                int list_no = key >> 32;
                int offset = key & 0xffffffff;

                // Update label to the actual id
                labels[ij] = invlists->get_single_id (list_no, offset);

                reconstruct_from_offset (list_no, offset, reconstructed);
            }
        }
    }
}
|
696
|
+
|
697
|
+
/// Reconstruct a vector from its (inverted list, offset) location.
/// Pure extension point: the base class cannot decode codes, so it always
/// throws; subclasses with decodable codes override it.
void IndexIVF::reconstruct_from_offset(
    int64_t /*list_no*/,
    int64_t /*offset*/,
    float* /*recons*/) const {
  FAISS_THROW_MSG ("reconstruct_from_offset not implemented");
}
|
703
|
+
|
704
|
+
/// Remove all vectors from the index. The quantizer and training state
/// are untouched, so vectors can be re-added without retraining.
void IndexIVF::reset ()
{
    direct_map.clear ();   // stored (list, offset) entries are now invalid
    invlists->reset ();
    ntotal = 0;
}
|
710
|
+
|
711
|
+
|
712
|
+
/** Remove all entries matched by sel; returns how many were removed.
 *
 *  Per list, a matched entry is overwritten by the current last entry
 *  (swap-with-last compaction), so the relative order of survivors within
 *  a list is not preserved. The actual shrinking of the lists happens in a
 *  second, sequential pass.
 */
size_t IndexIVF::remove_ids (const IDSelector & sel)
{
    FAISS_THROW_IF_NOT_MSG (!maintain_direct_map,
                    "direct map remove not implemented");

    // number of entries removed from each list
    std::vector<idx_t> toremove(nlist);

#pragma omp parallel for
    for (idx_t i = 0; i < nlist; i++) {
        idx_t l0 = invlists->list_size (i), l = l0, j = 0;
        ScopedIds idsi (invlists, i);
        while (j < l) {
            if (sel.is_member (idsi[j])) {
                // overwrite slot j with the last live entry (index l-1);
                // j is re-examined on the next iteration
                l--;
                invlists->update_entry (
                    i, j,
                    invlists->get_single_id (i, l),
                    ScopedCodes (invlists, i, l).get());
            } else {
                j++;
            }
        }
        toremove[i] = l0 - l;
    }
    // this will not run well in parallel on ondisk because of possible shrinks
    size_t nremove = 0;
    for (idx_t i = 0; i < nlist; i++) {
        if (toremove[i] > 0) {
            nremove += toremove[i];
            invlists->resize(
                i, invlists->list_size(i) - toremove[i]);
        }
    }
    ntotal -= nremove;
    return nremove;
}
|
748
|
+
|
749
|
+
|
750
|
+
|
751
|
+
|
752
|
+
/** Train the index on n representative vectors.
 *
 * First trains the level-1 (coarse) quantizer, then lets the subclass
 * train whatever it needs for encoding via train_residual().
 *
 * @param n  number of training vectors
 * @param x  training vectors, size n * d
 */
void IndexIVF::train (idx_t n, const float *x)
{
    if (verbose) {
        printf ("Training level-1 quantizer\n");
    }
    train_q1 (n, x, verbose, metric_type);

    if (verbose) {
        printf ("Training IVF residual\n");
    }
    train_residual (n, x);

    is_trained = true;
}
|
766
|
+
|
767
|
+
/// Train the per-list encoder. The base class has nothing to train beyond
/// the coarse quantizer, so this is a no-op; subclasses (IVFPQ, ...)
/// override it.
void IndexIVF::train_residual(idx_t /*n*/, const float* /*x*/) {
    if (verbose) {
        printf("IndexIVF: no residual training\n");
    }
}
|
772
|
+
|
773
|
+
|
774
|
+
/** Verify that `other` can be merged into this index; throws otherwise.
 *  Checks dimension, number of lists, code size, and that both objects
 *  have the exact same dynamic type.
 */
void IndexIVF::check_compatible_for_merge (const IndexIVF &other) const
{
    // minimal sanity checks
    FAISS_THROW_IF_NOT (other.d == d);
    FAISS_THROW_IF_NOT (other.nlist == nlist);
    FAISS_THROW_IF_NOT (other.code_size == code_size);
    FAISS_THROW_IF_NOT_MSG (typeid (*this) == typeid (other),
                            "can only merge indexes of the same type");
}
|
783
|
+
|
784
|
+
|
785
|
+
/** Move all entries of `other` into this index, adding add_id to the
 *  stored ids. `other` is emptied (its ntotal drops to 0) but remains a
 *  valid index. Neither index may maintain a direct map.
 */
void IndexIVF::merge_from (IndexIVF &other, idx_t add_id)
{
    check_compatible_for_merge (other);
    FAISS_THROW_IF_NOT_MSG ((!maintain_direct_map &&
                             !other.maintain_direct_map),
                            "direct map copy not implemented");

    // delegate the actual entry transfer to the inverted lists
    invlists->merge_from (other.invlists, add_id);

    ntotal += other.ntotal;
    other.ntotal = 0;
}
|
797
|
+
|
798
|
+
|
799
|
+
/** Replace the inverted lists of this index with `il`.
 *
 * @param il   the new inverted lists (may be nullptr)
 * @param own  if true, the index takes ownership and deletes il in its
 *             destructor
 */
void IndexIVF::replace_invlists (InvertedLists *il, bool own)
{
    // Validate the replacement *before* deleting the current lists: the
    // previous code deleted first, so a failing check below would leave
    // `invlists` dangling and the destructor would double-free it.
    if (il) {
        FAISS_THROW_IF_NOT (il->nlist == nlist &&
                            il->code_size == code_size);
    }
    // FAISS_THROW_IF_NOT (ntotal == 0);
    if (own_invlists) {
        delete invlists;
    }
    invlists = il;
    own_invlists = own;
}
|
812
|
+
|
813
|
+
|
814
|
+
/** Copy a subset of this index's entries to `other` (same nlist and
 *  code_size required; the destination must not maintain a direct map).
 *  Updates other.ntotal as entries are added.
 *
 *  subset_type selects which entries are copied:
 *    0: ids in the half-open range [a1, a2)
 *    1: ids with id % a1 == a2
 *    2: the entries occupying positions [a1, a2) of a virtual sequential
 *       enumeration of all ntotal entries, apportioned proportionally
 *       across lists
 */
void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type,
                               idx_t a1, idx_t a2) const
{

    FAISS_THROW_IF_NOT (nlist == other.nlist);
    FAISS_THROW_IF_NOT (code_size == other.code_size);
    FAISS_THROW_IF_NOT (!other.maintain_direct_map);
    FAISS_THROW_IF_NOT_FMT (
          subset_type == 0 || subset_type == 1 || subset_type == 2,
          "subset type %d not implemented", subset_type);

    // running totals used by subset_type == 2 to apportion elements
    size_t accu_n = 0;
    size_t accu_a1 = 0;
    size_t accu_a2 = 0;

    InvertedLists *oivf = other.invlists;

    for (idx_t list_no = 0; list_no < nlist; list_no++) {
        size_t n = invlists->list_size (list_no);
        ScopedIds ids_in (invlists, list_no);

        if (subset_type == 0) {
            // copy entries whose id falls in [a1, a2)
            for (idx_t i = 0; i < n; i++) {
                idx_t id = ids_in[i];
                if (a1 <= id && id < a2) {
                    oivf->add_entry (list_no,
                                     invlists->get_single_id (list_no, i),
                                     ScopedCodes (invlists, list_no, i).get());
                    other.ntotal++;
                }
            }
        } else if (subset_type == 1) {
            // copy entries with id % a1 == a2
            for (idx_t i = 0; i < n; i++) {
                idx_t id = ids_in[i];
                if (id % a1 == a2) {
                    oivf->add_entry (list_no,
                                     invlists->get_single_id (list_no, i),
                                     ScopedCodes (invlists, list_no, i).get());
                    other.ntotal++;
                }
            }
        } else if (subset_type == 2) {
            // see what is allocated to a1 and to a2
            size_t next_accu_n = accu_n + n;
            size_t next_accu_a1 = next_accu_n * a1 / ntotal;
            size_t i1 = next_accu_a1 - accu_a1;
            size_t next_accu_a2 = next_accu_n * a2 / ntotal;
            size_t i2 = next_accu_a2 - accu_a2;

            // copy this list's share: positions [i1, i2)
            for (idx_t i = i1; i < i2; i++) {
                oivf->add_entry (list_no,
                                 invlists->get_single_id (list_no, i),
                                 ScopedCodes (invlists, list_no, i).get());
            }

            other.ntotal += i2 - i1;
            accu_a1 = next_accu_a1;
            accu_a2 = next_accu_a2;
        }
        accu_n += n;
    }
    // sanity check: every entry was seen exactly once
    FAISS_ASSERT(accu_n == ntotal);

}
|
878
|
+
|
879
|
+
|
880
|
+
|
881
|
+
|
882
|
+
/// Destructor: the inverted lists are deleted only when this index owns
/// them (own_invlists), e.g. not when they were injected via
/// replace_invlists(il, /*own=*/false).
IndexIVF::~IndexIVF()
{
    if (own_invlists) {
        delete invlists;
    }
}
|
888
|
+
|
889
|
+
|
890
|
+
/// Zero all counters. The memset is only valid while IndexIVFStats stays a
/// plain aggregate of trivially-copyable counters (no virtuals, no
/// non-trivial members) -- keep it that way.
void IndexIVFStats::reset()
{
    memset ((void*)this, 0, sizeof (*this));
}
|
894
|
+
|
895
|
+
|
896
|
+
IndexIVFStats indexIVF_stats;
|
897
|
+
|
898
|
+
/// Default implementation of range scanning: optional for scanners, so the
/// base class throws. Scanners used by range_search override this.
void InvertedListScanner::scan_codes_range (size_t ,
                       const uint8_t *,
                       const idx_t *,
                       float ,
                       RangeQueryResult &) const
{
    FAISS_THROW_MSG ("scan_codes_range not implemented");
}
|
906
|
+
|
907
|
+
|
908
|
+
|
909
|
+
} // namespace faiss
|