faiss 0.1.0 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +103 -3
- data/ext/faiss/ext.cpp +99 -32
- data/ext/faiss/extconf.rb +12 -2
- data/lib/faiss/ext.bundle +0 -0
- data/lib/faiss/index.rb +3 -3
- data/lib/faiss/index_binary.rb +3 -3
- data/lib/faiss/kmeans.rb +1 -1
- data/lib/faiss/pca_matrix.rb +2 -2
- data/lib/faiss/product_quantizer.rb +3 -3
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/AutoTune.cpp +719 -0
- data/vendor/faiss/AutoTune.h +212 -0
- data/vendor/faiss/Clustering.cpp +261 -0
- data/vendor/faiss/Clustering.h +101 -0
- data/vendor/faiss/IVFlib.cpp +339 -0
- data/vendor/faiss/IVFlib.h +132 -0
- data/vendor/faiss/Index.cpp +171 -0
- data/vendor/faiss/Index.h +261 -0
- data/vendor/faiss/Index2Layer.cpp +437 -0
- data/vendor/faiss/Index2Layer.h +85 -0
- data/vendor/faiss/IndexBinary.cpp +77 -0
- data/vendor/faiss/IndexBinary.h +163 -0
- data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
- data/vendor/faiss/IndexBinaryFlat.h +54 -0
- data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
- data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
- data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
- data/vendor/faiss/IndexBinaryHNSW.h +56 -0
- data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
- data/vendor/faiss/IndexBinaryIVF.h +211 -0
- data/vendor/faiss/IndexFlat.cpp +508 -0
- data/vendor/faiss/IndexFlat.h +175 -0
- data/vendor/faiss/IndexHNSW.cpp +1090 -0
- data/vendor/faiss/IndexHNSW.h +170 -0
- data/vendor/faiss/IndexIVF.cpp +909 -0
- data/vendor/faiss/IndexIVF.h +353 -0
- data/vendor/faiss/IndexIVFFlat.cpp +502 -0
- data/vendor/faiss/IndexIVFFlat.h +118 -0
- data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
- data/vendor/faiss/IndexIVFPQ.h +161 -0
- data/vendor/faiss/IndexIVFPQR.cpp +219 -0
- data/vendor/faiss/IndexIVFPQR.h +65 -0
- data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
- data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
- data/vendor/faiss/IndexLSH.cpp +225 -0
- data/vendor/faiss/IndexLSH.h +87 -0
- data/vendor/faiss/IndexLattice.cpp +143 -0
- data/vendor/faiss/IndexLattice.h +68 -0
- data/vendor/faiss/IndexPQ.cpp +1188 -0
- data/vendor/faiss/IndexPQ.h +199 -0
- data/vendor/faiss/IndexPreTransform.cpp +288 -0
- data/vendor/faiss/IndexPreTransform.h +91 -0
- data/vendor/faiss/IndexReplicas.cpp +123 -0
- data/vendor/faiss/IndexReplicas.h +76 -0
- data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
- data/vendor/faiss/IndexScalarQuantizer.h +127 -0
- data/vendor/faiss/IndexShards.cpp +317 -0
- data/vendor/faiss/IndexShards.h +100 -0
- data/vendor/faiss/InvertedLists.cpp +623 -0
- data/vendor/faiss/InvertedLists.h +334 -0
- data/vendor/faiss/LICENSE +21 -0
- data/vendor/faiss/MatrixStats.cpp +252 -0
- data/vendor/faiss/MatrixStats.h +62 -0
- data/vendor/faiss/MetaIndexes.cpp +351 -0
- data/vendor/faiss/MetaIndexes.h +126 -0
- data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
- data/vendor/faiss/OnDiskInvertedLists.h +127 -0
- data/vendor/faiss/VectorTransform.cpp +1157 -0
- data/vendor/faiss/VectorTransform.h +322 -0
- data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
- data/vendor/faiss/c_api/AutoTune_c.h +64 -0
- data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
- data/vendor/faiss/c_api/Clustering_c.h +117 -0
- data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
- data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
- data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
- data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
- data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
- data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
- data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
- data/vendor/faiss/c_api/IndexShards_c.h +42 -0
- data/vendor/faiss/c_api/Index_c.cpp +105 -0
- data/vendor/faiss/c_api/Index_c.h +183 -0
- data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
- data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
- data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
- data/vendor/faiss/c_api/clone_index_c.h +32 -0
- data/vendor/faiss/c_api/error_c.h +42 -0
- data/vendor/faiss/c_api/error_impl.cpp +27 -0
- data/vendor/faiss/c_api/error_impl.h +16 -0
- data/vendor/faiss/c_api/faiss_c.h +58 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
- data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
- data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
- data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
- data/vendor/faiss/c_api/index_factory_c.h +30 -0
- data/vendor/faiss/c_api/index_io_c.cpp +42 -0
- data/vendor/faiss/c_api/index_io_c.h +50 -0
- data/vendor/faiss/c_api/macros_impl.h +110 -0
- data/vendor/faiss/clone_index.cpp +147 -0
- data/vendor/faiss/clone_index.h +38 -0
- data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
- data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
- data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
- data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
- data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
- data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
- data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
- data/vendor/faiss/gpu/GpuCloner.h +82 -0
- data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
- data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
- data/vendor/faiss/gpu/GpuDistance.h +52 -0
- data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
- data/vendor/faiss/gpu/GpuIndex.h +148 -0
- data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
- data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
- data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
- data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
- data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
- data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
- data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
- data/vendor/faiss/gpu/GpuResources.cpp +52 -0
- data/vendor/faiss/gpu/GpuResources.h +73 -0
- data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
- data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
- data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
- data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
- data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
- data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
- data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
- data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
- data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
- data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
- data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
- data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
- data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
- data/vendor/faiss/gpu/test/TestUtils.h +93 -0
- data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
- data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
- data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
- data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
- data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
- data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
- data/vendor/faiss/gpu/utils/Timer.h +52 -0
- data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
- data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
- data/vendor/faiss/impl/FaissAssert.h +95 -0
- data/vendor/faiss/impl/FaissException.cpp +66 -0
- data/vendor/faiss/impl/FaissException.h +71 -0
- data/vendor/faiss/impl/HNSW.cpp +818 -0
- data/vendor/faiss/impl/HNSW.h +275 -0
- data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
- data/vendor/faiss/impl/PolysemousTraining.h +158 -0
- data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
- data/vendor/faiss/impl/ProductQuantizer.h +242 -0
- data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
- data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
- data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
- data/vendor/faiss/impl/ThreadedIndex.h +80 -0
- data/vendor/faiss/impl/index_read.cpp +793 -0
- data/vendor/faiss/impl/index_write.cpp +558 -0
- data/vendor/faiss/impl/io.cpp +142 -0
- data/vendor/faiss/impl/io.h +98 -0
- data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
- data/vendor/faiss/impl/lattice_Zn.h +199 -0
- data/vendor/faiss/index_factory.cpp +392 -0
- data/vendor/faiss/index_factory.h +25 -0
- data/vendor/faiss/index_io.h +75 -0
- data/vendor/faiss/misc/test_blas.cpp +84 -0
- data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
- data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
- data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
- data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
- data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
- data/vendor/faiss/tests/test_merge.cpp +258 -0
- data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
- data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
- data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
- data/vendor/faiss/tests/test_params_override.cpp +231 -0
- data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
- data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
- data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
- data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
- data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
- data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
- data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
- data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
- data/vendor/faiss/utils/Heap.cpp +122 -0
- data/vendor/faiss/utils/Heap.h +495 -0
- data/vendor/faiss/utils/WorkerThread.cpp +126 -0
- data/vendor/faiss/utils/WorkerThread.h +61 -0
- data/vendor/faiss/utils/distances.cpp +765 -0
- data/vendor/faiss/utils/distances.h +243 -0
- data/vendor/faiss/utils/distances_simd.cpp +809 -0
- data/vendor/faiss/utils/extra_distances.cpp +336 -0
- data/vendor/faiss/utils/extra_distances.h +54 -0
- data/vendor/faiss/utils/hamming-inl.h +472 -0
- data/vendor/faiss/utils/hamming.cpp +792 -0
- data/vendor/faiss/utils/hamming.h +220 -0
- data/vendor/faiss/utils/random.cpp +192 -0
- data/vendor/faiss/utils/random.h +60 -0
- data/vendor/faiss/utils/utils.cpp +783 -0
- data/vendor/faiss/utils/utils.h +181 -0
- metadata +216 -2
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// -*- c++ -*-
|
|
9
|
+
|
|
10
|
+
#pragma once
|
|
11
|
+
|
|
12
|
+
#include <vector>
|
|
13
|
+
|
|
14
|
+
#include <faiss/impl/HNSW.h>
|
|
15
|
+
#include <faiss/IndexFlat.h>
|
|
16
|
+
#include <faiss/IndexPQ.h>
|
|
17
|
+
#include <faiss/IndexScalarQuantizer.h>
|
|
18
|
+
#include <faiss/utils/utils.h>
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
namespace faiss {
|
|
22
|
+
|
|
23
|
+
struct IndexHNSW;
|
|
24
|
+
|
|
25
|
+
struct ReconstructFromNeighbors {
|
|
26
|
+
typedef Index::idx_t idx_t;
|
|
27
|
+
typedef HNSW::storage_idx_t storage_idx_t;
|
|
28
|
+
|
|
29
|
+
const IndexHNSW & index;
|
|
30
|
+
size_t M; // number of neighbors
|
|
31
|
+
size_t k; // number of codebook entries
|
|
32
|
+
size_t nsq; // number of subvectors
|
|
33
|
+
size_t code_size;
|
|
34
|
+
int k_reorder; // nb to reorder. -1 = all
|
|
35
|
+
|
|
36
|
+
std::vector<float> codebook; // size nsq * k * (M + 1)
|
|
37
|
+
|
|
38
|
+
std::vector<uint8_t> codes; // size ntotal * code_size
|
|
39
|
+
size_t ntotal;
|
|
40
|
+
size_t d, dsub; // derived values
|
|
41
|
+
|
|
42
|
+
explicit ReconstructFromNeighbors(const IndexHNSW& index,
|
|
43
|
+
size_t k=256, size_t nsq=1);
|
|
44
|
+
|
|
45
|
+
/// codes must be added in the correct order and the IndexHNSW
|
|
46
|
+
/// must be populated and sorted
|
|
47
|
+
void add_codes(size_t n, const float *x);
|
|
48
|
+
|
|
49
|
+
size_t compute_distances(size_t n, const idx_t *shortlist,
|
|
50
|
+
const float *query, float *distances) const;
|
|
51
|
+
|
|
52
|
+
/// called by add_codes
|
|
53
|
+
void estimate_code(const float *x, storage_idx_t i, uint8_t *code) const;
|
|
54
|
+
|
|
55
|
+
/// called by compute_distances
|
|
56
|
+
void reconstruct(storage_idx_t i, float *x, float *tmp) const;
|
|
57
|
+
|
|
58
|
+
void reconstruct_n(storage_idx_t n0, storage_idx_t ni, float *x) const;
|
|
59
|
+
|
|
60
|
+
/// get the M+1 -by-d table for neighbor coordinates for vector i
|
|
61
|
+
void get_neighbor_table(storage_idx_t i, float *out) const;
|
|
62
|
+
|
|
63
|
+
};
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
/** The HNSW index is a normal random-access index with a HNSW
|
|
67
|
+
* link structure built on top */
|
|
68
|
+
|
|
69
|
+
struct IndexHNSW : Index {
|
|
70
|
+
|
|
71
|
+
typedef HNSW::storage_idx_t storage_idx_t;
|
|
72
|
+
|
|
73
|
+
// the link strcuture
|
|
74
|
+
HNSW hnsw;
|
|
75
|
+
|
|
76
|
+
// the sequential storage
|
|
77
|
+
bool own_fields;
|
|
78
|
+
Index *storage;
|
|
79
|
+
|
|
80
|
+
ReconstructFromNeighbors *reconstruct_from_neighbors;
|
|
81
|
+
|
|
82
|
+
explicit IndexHNSW (int d = 0, int M = 32);
|
|
83
|
+
explicit IndexHNSW (Index *storage, int M = 32);
|
|
84
|
+
|
|
85
|
+
~IndexHNSW() override;
|
|
86
|
+
|
|
87
|
+
void add(idx_t n, const float *x) override;
|
|
88
|
+
|
|
89
|
+
/// Trains the storage if needed
|
|
90
|
+
void train(idx_t n, const float* x) override;
|
|
91
|
+
|
|
92
|
+
/// entry point for search
|
|
93
|
+
void search (idx_t n, const float *x, idx_t k,
|
|
94
|
+
float *distances, idx_t *labels) const override;
|
|
95
|
+
|
|
96
|
+
void reconstruct(idx_t key, float* recons) const override;
|
|
97
|
+
|
|
98
|
+
void reset () override;
|
|
99
|
+
|
|
100
|
+
void shrink_level_0_neighbors(int size);
|
|
101
|
+
|
|
102
|
+
/** Perform search only on level 0, given the starting points for
|
|
103
|
+
* each vertex.
|
|
104
|
+
*
|
|
105
|
+
* @param search_type 1:perform one search per nprobe, 2: enqueue
|
|
106
|
+
* all entry points
|
|
107
|
+
*/
|
|
108
|
+
void search_level_0(idx_t n, const float *x, idx_t k,
|
|
109
|
+
const storage_idx_t *nearest, const float *nearest_d,
|
|
110
|
+
float *distances, idx_t *labels, int nprobe = 1,
|
|
111
|
+
int search_type = 1) const;
|
|
112
|
+
|
|
113
|
+
/// alternative graph building
|
|
114
|
+
void init_level_0_from_knngraph(
|
|
115
|
+
int k, const float *D, const idx_t *I);
|
|
116
|
+
|
|
117
|
+
/// alternative graph building
|
|
118
|
+
void init_level_0_from_entry_points(
|
|
119
|
+
int npt, const storage_idx_t *points,
|
|
120
|
+
const storage_idx_t *nearests);
|
|
121
|
+
|
|
122
|
+
// reorder links from nearest to farthest
|
|
123
|
+
void reorder_links();
|
|
124
|
+
|
|
125
|
+
void link_singletons();
|
|
126
|
+
};
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
/** Flat index topped with with a HNSW structure to access elements
|
|
130
|
+
* more efficiently.
|
|
131
|
+
*/
|
|
132
|
+
|
|
133
|
+
struct IndexHNSWFlat : IndexHNSW {
|
|
134
|
+
IndexHNSWFlat();
|
|
135
|
+
IndexHNSWFlat(int d, int M);
|
|
136
|
+
};
|
|
137
|
+
|
|
138
|
+
/** PQ index topped with with a HNSW structure to access elements
|
|
139
|
+
* more efficiently.
|
|
140
|
+
*/
|
|
141
|
+
struct IndexHNSWPQ : IndexHNSW {
|
|
142
|
+
IndexHNSWPQ();
|
|
143
|
+
IndexHNSWPQ(int d, int pq_m, int M);
|
|
144
|
+
void train(idx_t n, const float* x) override;
|
|
145
|
+
};
|
|
146
|
+
|
|
147
|
+
/** SQ index topped with with a HNSW structure to access elements
|
|
148
|
+
* more efficiently.
|
|
149
|
+
*/
|
|
150
|
+
struct IndexHNSWSQ : IndexHNSW {
|
|
151
|
+
IndexHNSWSQ();
|
|
152
|
+
IndexHNSWSQ(int d, ScalarQuantizer::QuantizerType qtype, int M);
|
|
153
|
+
};
|
|
154
|
+
|
|
155
|
+
/** 2-level code structure with fast random access
|
|
156
|
+
*/
|
|
157
|
+
struct IndexHNSW2Level : IndexHNSW {
|
|
158
|
+
IndexHNSW2Level();
|
|
159
|
+
IndexHNSW2Level(Index *quantizer, size_t nlist, int m_pq, int M);
|
|
160
|
+
|
|
161
|
+
void flip_to_ivf();
|
|
162
|
+
|
|
163
|
+
/// entry point for search
|
|
164
|
+
void search (idx_t n, const float *x, idx_t k,
|
|
165
|
+
float *distances, idx_t *labels) const override;
|
|
166
|
+
|
|
167
|
+
};
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
} // namespace faiss
|
|
@@ -0,0 +1,909 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// -*- c++ -*-
|
|
9
|
+
|
|
10
|
+
#include <faiss/IndexIVF.h>


#include <omp.h>

#include <cstdio>
#include <exception>
#include <memory>

#include <faiss/utils/utils.h>
#include <faiss/utils/hamming.h>

#include <faiss/impl/FaissAssert.h>
#include <faiss/IndexFlat.h>
#include <faiss/impl/AuxIndexStructures.h>
|
|
24
|
+
|
|
25
|
+
namespace faiss {
|
|
26
|
+
|
|
27
|
+
using ScopedIds = InvertedLists::ScopedIds;
|
|
28
|
+
using ScopedCodes = InvertedLists::ScopedCodes;
|
|
29
|
+
|
|
30
|
+
/*****************************************
|
|
31
|
+
* Level1Quantizer implementation
|
|
32
|
+
******************************************/
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
Level1Quantizer::Level1Quantizer (Index * quantizer, size_t nlist):
|
|
36
|
+
quantizer (quantizer),
|
|
37
|
+
nlist (nlist),
|
|
38
|
+
quantizer_trains_alone (0),
|
|
39
|
+
own_fields (false),
|
|
40
|
+
clustering_index (nullptr)
|
|
41
|
+
{
|
|
42
|
+
// here we set a low # iterations because this is typically used
|
|
43
|
+
// for large clusterings (nb this is not used for the MultiIndex,
|
|
44
|
+
// for which quantizer_trains_alone = true)
|
|
45
|
+
cp.niter = 10;
|
|
46
|
+
}
|
|
47
|
+
|
|
48
|
+
Level1Quantizer::Level1Quantizer ():
|
|
49
|
+
quantizer (nullptr),
|
|
50
|
+
nlist (0),
|
|
51
|
+
quantizer_trains_alone (0), own_fields (false),
|
|
52
|
+
clustering_index (nullptr)
|
|
53
|
+
{}
|
|
54
|
+
|
|
55
|
+
/// Deletes the coarse quantizer only when this object owns it.
Level1Quantizer::~Level1Quantizer ()
{
    if (!own_fields) {
        return;
    }
    delete quantizer;
}
|
|
59
|
+
|
|
60
|
+
void Level1Quantizer::train_q1 (size_t n, const float *x, bool verbose, MetricType metric_type)
|
|
61
|
+
{
|
|
62
|
+
size_t d = quantizer->d;
|
|
63
|
+
if (quantizer->is_trained && (quantizer->ntotal == nlist)) {
|
|
64
|
+
if (verbose)
|
|
65
|
+
printf ("IVF quantizer does not need training.\n");
|
|
66
|
+
} else if (quantizer_trains_alone == 1) {
|
|
67
|
+
if (verbose)
|
|
68
|
+
printf ("IVF quantizer trains alone...\n");
|
|
69
|
+
quantizer->train (n, x);
|
|
70
|
+
quantizer->verbose = verbose;
|
|
71
|
+
FAISS_THROW_IF_NOT_MSG (quantizer->ntotal == nlist,
|
|
72
|
+
"nlist not consistent with quantizer size");
|
|
73
|
+
} else if (quantizer_trains_alone == 0) {
|
|
74
|
+
if (verbose)
|
|
75
|
+
printf ("Training level-1 quantizer on %ld vectors in %ldD\n",
|
|
76
|
+
n, d);
|
|
77
|
+
|
|
78
|
+
Clustering clus (d, nlist, cp);
|
|
79
|
+
quantizer->reset();
|
|
80
|
+
if (clustering_index) {
|
|
81
|
+
clus.train (n, x, *clustering_index);
|
|
82
|
+
quantizer->add (nlist, clus.centroids.data());
|
|
83
|
+
} else {
|
|
84
|
+
clus.train (n, x, *quantizer);
|
|
85
|
+
}
|
|
86
|
+
quantizer->is_trained = true;
|
|
87
|
+
} else if (quantizer_trains_alone == 2) {
|
|
88
|
+
if (verbose)
|
|
89
|
+
printf (
|
|
90
|
+
"Training L2 quantizer on %ld vectors in %ldD%s\n",
|
|
91
|
+
n, d,
|
|
92
|
+
clustering_index ? "(user provided index)" : "");
|
|
93
|
+
FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
|
|
94
|
+
Clustering clus (d, nlist, cp);
|
|
95
|
+
if (!clustering_index) {
|
|
96
|
+
IndexFlatL2 assigner (d);
|
|
97
|
+
clus.train(n, x, assigner);
|
|
98
|
+
} else {
|
|
99
|
+
clus.train(n, x, *clustering_index);
|
|
100
|
+
}
|
|
101
|
+
if (verbose)
|
|
102
|
+
printf ("Adding centroids to quantizer\n");
|
|
103
|
+
quantizer->add (nlist, clus.centroids.data());
|
|
104
|
+
}
|
|
105
|
+
}
|
|
106
|
+
|
|
107
|
+
size_t Level1Quantizer::coarse_code_size () const
|
|
108
|
+
{
|
|
109
|
+
size_t nl = nlist - 1;
|
|
110
|
+
size_t nbyte = 0;
|
|
111
|
+
while (nl > 0) {
|
|
112
|
+
nbyte ++;
|
|
113
|
+
nl >>= 8;
|
|
114
|
+
}
|
|
115
|
+
return nbyte;
|
|
116
|
+
}
|
|
117
|
+
|
|
118
|
+
/// Writes list_no into `code`, least significant byte first, using as
/// many bytes as needed to represent nlist - 1.
void Level1Quantizer::encode_listno (Index::idx_t list_no, uint8_t *code) const
{
    // little endian
    for (size_t rest = nlist - 1; rest > 0; rest >>= 8) {
        *code++ = list_no & 0xff;
        list_no >>= 8;
    }
}
|
|
128
|
+
|
|
129
|
+
/// Inverse of encode_listno: reads a little-endian list number from
/// `code`. Throws if the decoded value is not in [0, nlist).
Index::idx_t Level1Quantizer::decode_listno (const uint8_t *code) const
{
    int64_t list_no = 0;
    int nbit = 0;
    for (size_t rest = nlist - 1; rest > 0; rest >>= 8, nbit += 8) {
        list_no |= int64_t(*code++) << nbit;
    }
    FAISS_THROW_IF_NOT (list_no >= 0 && list_no < nlist);
    return list_no;
}
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
/*****************************************
|
|
146
|
+
* IndexIVF implementation
|
|
147
|
+
******************************************/
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
/** Builds an IVF index of dimension d over `quantizer` (not owned: the
 * Level1Quantizer base starts with own_fields = false) with nlist lists.
 * The inverted lists are a freshly allocated ArrayInvertedLists owned by
 * this index. Throws if d differs from the quantizer's dimension.
 */
IndexIVF::IndexIVF (Index * quantizer, size_t d,
                    size_t nlist, size_t code_size,
                    MetricType metric):
    Index (d, metric),
    Level1Quantizer (quantizer, nlist),
    invlists (new ArrayInvertedLists (nlist, code_size)),
    own_invlists (true),
    code_size (code_size),
    nprobe (1),
    max_codes (0),
    parallel_mode (0),
    maintain_direct_map (false)
{
    if (!(d == quantizer->d)) {
        // ~IndexIVF does not run when a constructor throws, so the
        // inverted lists allocated in the initializer list would leak:
        // release them before reporting the dimension mismatch.
        delete invlists;
        invlists = nullptr;
        FAISS_THROW_IF_NOT (d == quantizer->d);
    }
    is_trained = quantizer->is_trained && (quantizer->ntotal == nlist);
    // Spherical by default if the metric is inner_product
    if (metric_type == METRIC_INNER_PRODUCT) {
        cp.spherical = true;
    }

}
|
|
171
|
+
|
|
172
|
+
IndexIVF::IndexIVF ():
|
|
173
|
+
invlists (nullptr), own_invlists (false),
|
|
174
|
+
code_size (0),
|
|
175
|
+
nprobe (1), max_codes (0), parallel_mode (0),
|
|
176
|
+
maintain_direct_map (false)
|
|
177
|
+
{}
|
|
178
|
+
|
|
179
|
+
/// Adds n vectors without explicit ids; forwards to add_with_ids with
/// a null id array.
void IndexIVF::add (idx_t n, const float * x)
{
    const idx_t *no_ids = nullptr;
    add_with_ids (n, x, no_ids);
}
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
/** Adds n vectors x with ids xids (or ntotal + i when xids is null).
 *
 * Batches larger than 65536 vectors are split and added recursively to
 * bound temporary allocations. Each vector is assigned to a list by the
 * coarse quantizer and encoded with encode_vectors. Vectors assigned to
 * a negative list number are skipped but still counted in ntotal.
 * Inside the parallel section, thread `rank` only appends to lists with
 * list_no % nt == rank, so no two threads write to the same list.
 *
 * NOTE(review): direct_map is not updated here even when
 * maintain_direct_map is set.
 */
void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids)
{
    // do some blocking to avoid excessive allocs
    idx_t bs = 65536;
    if (n > bs) {
        for (idx_t i0 = 0; i0 < n; i0 += bs) {
            idx_t i1 = std::min (n, i0 + bs);
            if (verbose) {
                // idx_t is int64_t: print through long long for portability
                printf(" IndexIVF::add_with_ids %lld:%lld\n",
                       static_cast<long long>(i0),
                       static_cast<long long>(i1));
            }
            add_with_ids (i1 - i0, x + i0 * d,
                          xids ? xids + i0 : nullptr);
        }
        return;
    }

    FAISS_THROW_IF_NOT (is_trained);
    std::unique_ptr<idx_t []> idx(new idx_t[n]);
    quantizer->assign (n, x, idx.get());
    size_t nadd = 0, nminus1 = 0;

    for (idx_t i = 0; i < n; i++) {
        if (idx[i] < 0) nminus1++;
    }

    std::unique_ptr<uint8_t []> flat_codes(new uint8_t [n * code_size]);
    encode_vectors (n, x, idx.get(), flat_codes.get());

#pragma omp parallel reduction(+: nadd)
    {
        int nt = omp_get_num_threads();
        int rank = omp_get_thread_num();

        // each thread takes care of a subset of lists
        for (idx_t i = 0; i < n; i++) {
            idx_t list_no = idx [i];
            if (list_no >= 0 && list_no % nt == rank) {
                idx_t id = xids ? xids[i] : ntotal + i;
                invlists->add_entry (list_no, id,
                                     flat_codes.get() + i * code_size);
                nadd++;
            }
        }
    }

    if (verbose) {
        // nadd / nminus1 are size_t, n is idx_t
        printf(" added %zu / %lld vectors (%zu -1s)\n",
               nadd, static_cast<long long>(n), nminus1);
    }

    ntotal += n;
}
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
void IndexIVF::make_direct_map (bool new_maintain_direct_map)
|
|
239
|
+
{
|
|
240
|
+
// nothing to do
|
|
241
|
+
if (new_maintain_direct_map == maintain_direct_map)
|
|
242
|
+
return;
|
|
243
|
+
|
|
244
|
+
if (new_maintain_direct_map) {
|
|
245
|
+
direct_map.resize (ntotal, -1);
|
|
246
|
+
for (size_t key = 0; key < nlist; key++) {
|
|
247
|
+
size_t list_size = invlists->list_size (key);
|
|
248
|
+
ScopedIds idlist (invlists, key);
|
|
249
|
+
|
|
250
|
+
for (long ofs = 0; ofs < list_size; ofs++) {
|
|
251
|
+
FAISS_THROW_IF_NOT_MSG (
|
|
252
|
+
0 <= idlist [ofs] && idlist[ofs] < ntotal,
|
|
253
|
+
"direct map supported only for seuquential ids");
|
|
254
|
+
direct_map [idlist [ofs]] = key << 32 | ofs;
|
|
255
|
+
}
|
|
256
|
+
}
|
|
257
|
+
} else {
|
|
258
|
+
direct_map.clear ();
|
|
259
|
+
}
|
|
260
|
+
maintain_direct_map = new_maintain_direct_map;
|
|
261
|
+
}
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
/** k-NN search for n queries.
 *
 * Two stages, each timed into indexIVF_stats:
 *  1. coarse quantization: the nprobe nearest lists per query;
 *  2. prefetch of those lists and scan via search_preassigned.
 */
void IndexIVF::search (idx_t n, const float *x, idx_t k,
                       float *distances, idx_t *labels) const
{
    const auto n_probes_total = n * nprobe;
    std::unique_ptr<idx_t[]> list_ids (new idx_t[n_probes_total]);
    std::unique_ptr<float[]> list_dis (new float[n_probes_total]);

    // stage 1: coarse quantization
    double t_start = getmillisecs();
    quantizer->search (n, x, nprobe, list_dis.get(), list_ids.get());
    indexIVF_stats.quantization_time += getmillisecs() - t_start;

    // stage 2: scan the selected inverted lists
    t_start = getmillisecs();
    invlists->prefetch_lists (list_ids.get(), n_probes_total);

    search_preassigned (n, x, k, list_ids.get(), list_dis.get(),
                        distances, labels, false);
    indexIVF_stats.search_time += getmillisecs() - t_start;
}
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
/** Searches n queries given their pre-computed coarse assignment.
 *
 * @param keys        n * nprobe list numbers (negative = no list)
 * @param coarse_dis  n * nprobe matching coarse distances
 * @param distances   output, n * k
 * @param labels      output, n * k
 * @param store_pairs forwarded to get_InvertedListScanner; when true
 *                    the ids of the lists are not loaded
 * @param params      if non-null, overrides nprobe and max_codes
 *
 * parallel_mode 0 parallelizes over queries (and honors max_codes and
 * InterruptCallback). Mode 1 parallelizes over the probes of each query
 * and merges the per-thread heaps. Any other mode throws.
 *
 * Exceptions must not escape an OpenMP structured block (that terminates
 * the program). So the mode is validated before the parallel section,
 * and exceptions raised while scanning lists are captured and rethrown
 * on the calling thread after the section.
 */
void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
                                   const idx_t *keys,
                                   const float *coarse_dis ,
                                   float *distances, idx_t *labels,
                                   bool store_pairs,
                                   const IVFSearchParameters *params) const
{
    long nprobe = params ? params->nprobe : this->nprobe;
    long max_codes = params ? params->max_codes : this->max_codes;

    if (parallel_mode != 0 && parallel_mode != 1) {
        FAISS_THROW_FMT ("parallel_mode %d not supported\n",
                         parallel_mode);
    }

    size_t nlistv = 0, ndis = 0, nheap = 0;

    using HeapForIP = CMin<float, idx_t>;
    using HeapForL2 = CMax<float, idx_t>;

    // NOTE(review): read without synchronization by the worker threads
    // (benign flag, as in the original design)
    bool interrupt = false;

    // first exception raised inside the parallel section
    std::exception_ptr scan_exception;

    // don't start parallel section if single query
    bool do_parallel =
        parallel_mode == 0 ? n > 1 :
        parallel_mode == 1 ? nprobe > 1 :
        nprobe * n > 1;

#pragma omp parallel if(do_parallel) reduction(+: nlistv, ndis, nheap)
    {
        // NOTE(review): scanner creation is not guarded; an exception or a
        // null scanner here is still fatal.
        InvertedListScanner *scanner = get_InvertedListScanner(store_pairs);
        ScopeDeleter1<InvertedListScanner> del(scanner);

        // called from a catch handler: keep the first exception, ask the
        // other threads to skip the remaining work
        auto record_exception = [&] () {
#pragma omp critical (IndexIVF_search_preassigned_exception)
            {
                if (!scan_exception) {
                    scan_exception = std::current_exception ();
                }
                interrupt = true;
            }
        };

        /*****************************************************
         * Depending on parallel_mode, there are two possible ways
         * to organize the search. Here we define local functions
         * that are in common between the two
         ******************************************************/

        // initialize + reorder a result heap

        auto init_result = [&](float *simi, idx_t *idxi) {
            if (metric_type == METRIC_INNER_PRODUCT) {
                heap_heapify<HeapForIP> (k, simi, idxi);
            } else {
                heap_heapify<HeapForL2> (k, simi, idxi);
            }
        };

        auto reorder_result = [&] (float *simi, idx_t *idxi) {
            if (metric_type == METRIC_INNER_PRODUCT) {
                heap_reorder<HeapForIP> (k, simi, idxi);
            } else {
                heap_reorder<HeapForL2> (k, simi, idxi);
            }
        };

        // single list scan using the current scanner (with query
        // set properly) and storing results in simi and idxi
        auto scan_one_list = [&] (idx_t key, float coarse_dis_i,
                                  float *simi, idx_t *idxi) {

            if (key < 0) {
                // not enough centroids for multiprobe
                return (size_t)0;
            }
            FAISS_THROW_IF_NOT_FMT (key < (idx_t) nlist,
                                    "Invalid key=%lld nlist=%zu\n",
                                    static_cast<long long>(key), nlist);

            size_t list_size = invlists->list_size(key);

            // don't waste time on empty lists
            if (list_size == 0) {
                return (size_t)0;
            }

            scanner->set_list (key, coarse_dis_i);

            nlistv++;

            InvertedLists::ScopedCodes scodes (invlists, key);

            std::unique_ptr<InvertedLists::ScopedIds> sids;
            const Index::idx_t * ids = nullptr;

            if (!store_pairs)  {
                sids.reset (new InvertedLists::ScopedIds (invlists, key));
                ids = sids->get();
            }

            nheap += scanner->scan_codes (list_size, scodes.get(),
                                          ids, simi, idxi, k);

            return list_size;
        };

        /****************************************************
         * Actual loops, depending on parallel_mode
         ****************************************************/

        if (parallel_mode == 0) {

#pragma omp for
            for (idx_t i = 0; i < n; i++) {

                if (interrupt) {
                    continue;
                }

                try {
                    // loop over queries
                    scanner->set_query (x + i * d);
                    float * simi = distances + i * k;
                    idx_t * idxi = labels + i * k;

                    init_result (simi, idxi);

                    long nscan = 0;

                    // loop over probes
                    for (long ik = 0; ik < nprobe; ik++) {

                        nscan += scan_one_list (
                             keys [i * nprobe + ik],
                             coarse_dis[i * nprobe + ik],
                             simi, idxi
                        );

                        if (max_codes && nscan >= max_codes) {
                            break;
                        }
                    }

                    ndis += nscan;
                    reorder_result (simi, idxi);

                    if (InterruptCallback::is_interrupted ()) {
                        interrupt = true;
                    }
                } catch (...) {
                    record_exception ();
                }

            } // parallel for
        } else { // parallel_mode == 1 (validated above)
            std::vector <idx_t> local_idx (k);
            std::vector <float> local_dis (k);

            for (idx_t i = 0; i < n; i++) {
                scanner->set_query (x + i * d);
                init_result (local_dis.data(), local_idx.data());

#pragma omp for schedule(dynamic)
                for (long ik = 0; ik < nprobe; ik++) {
                    // every thread must still reach the worksharing and
                    // barrier constructs below, so errors are recorded
                    // here rather than propagated
                    try {
                        ndis += scan_one_list
                            (keys [i * nprobe + ik],
                             coarse_dis[i * nprobe + ik],
                             local_dis.data(), local_idx.data());
                    } catch (...) {
                        record_exception ();
                    }

                    // can't do the test on max_codes
                }
                // merge thread-local results

                float * simi = distances + i * k;
                idx_t * idxi = labels + i * k;
#pragma omp single
                init_result (simi, idxi);

#pragma omp barrier
#pragma omp critical
                {
                    if (metric_type == METRIC_INNER_PRODUCT) {
                        heap_addn<HeapForIP>
                            (k, simi, idxi,
                             local_dis.data(), local_idx.data(), k);
                    } else {
                        heap_addn<HeapForL2>
                            (k, simi, idxi,
                             local_dis.data(), local_idx.data(), k);
                    }
                }
#pragma omp barrier
#pragma omp single
                reorder_result (simi, idxi);
            }
        }
    } // parallel section

    if (scan_exception) {
        std::rethrow_exception (scan_exception);
    }

    if (interrupt) {
        FAISS_THROW_MSG ("computation interrupted");
    }

    indexIVF_stats.nq += n;
    indexIVF_stats.nlist += nlistv;
    indexIVF_stats.ndis += ndis;
    indexIVF_stats.nheap_updates += nheap;

}
|
|
477
|
+
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
// Range search: coarse-assign every query to its nprobe closest lists, then
// scan those lists for entries within `radius`. Both phases are timed into
// indexIVF_stats.
void IndexIVF::range_search (idx_t nx, const float *x, float radius,
                             RangeSearchResult *result) const
{
    const size_t n_assign = nx * nprobe;
    std::unique_ptr<idx_t[]> assigned_lists (new idx_t[n_assign]);
    std::unique_ptr<float []> assigned_dis (new float[n_assign]);

    // phase 1: coarse quantization
    double t_start = getmillisecs();
    quantizer->search (nx, x, nprobe, assigned_dis.get (), assigned_lists.get ());
    indexIVF_stats.quantization_time += getmillisecs() - t_start;

    // phase 2: list scanning (prefetch is counted as search time)
    t_start = getmillisecs();
    invlists->prefetch_lists (assigned_lists.get(), n_assign);
    range_search_preassigned (nx, x, radius,
                              assigned_lists.get (), assigned_dis.get (),
                              result);
    indexIVF_stats.search_time += getmillisecs() - t_start;
}
|
|
499
|
+
|
|
500
|
+
// Range search over pre-assigned inverted lists.
//   keys[i * nprobe + ik]       : list to scan for query i, probe ik
//                                 (negative keys are skipped)
//   coarse_dis[i * nprobe + ik] : matching coarse distance, handed to the scanner
// Results for all queries are written into `result`. parallel_mode selects the
// OpenMP strategy: 0 = over queries, 1 = over probes of one query,
// 2 = over (query, probe) pairs.
void IndexIVF::range_search_preassigned (
         idx_t nx, const float *x, float radius,
         const idx_t *keys, const float *coarse_dis,
         RangeSearchResult *result) const
{
    // Validate inputs *before* entering the OpenMP parallel region: an
    // exception escaping a parallel region calls std::terminate instead of
    // reaching the caller.
    if (parallel_mode != 0 && parallel_mode != 1 && parallel_mode != 2) {
        FAISS_THROW_FMT ("parallel_mode %d not supported\n", parallel_mode);
    }
    const size_t n_keys = static_cast<size_t>(nx) * nprobe;
    for (size_t iik = 0; iik < n_keys; iik++) {
        const idx_t key = keys[iik];
        // idx_t is 64-bit but not necessarily `long`: use portable specifiers
        FAISS_THROW_IF_NOT_FMT (
              key < (idx_t) nlist,
              "Invalid key=%lld at ik=%zu nlist=%zu\n",
              static_cast<long long>(key),
              static_cast<size_t>(iik % nprobe),
              static_cast<size_t>(nlist));
    }

    size_t nlistv = 0, ndis = 0;
    bool store_pairs = false;

    std::vector<RangeSearchPartialResult *> all_pres (omp_get_max_threads());

#pragma omp parallel reduction(+: nlistv, ndis)
    {
        RangeSearchPartialResult pres(result);
        std::unique_ptr<InvertedListScanner> scanner
            (get_InvertedListScanner(store_pairs));
        // NOTE(review): this check still runs inside the parallel region
        FAISS_THROW_IF_NOT (scanner.get ());
        all_pres[omp_get_thread_num()] = &pres;

        // prepare the list scanning function

        auto scan_list_func = [&](size_t i, size_t ik, RangeQueryResult &qres) {

            idx_t key = keys[i * nprobe + ik];  /* select the list */
            // negative keys are skipped; key < nlist was validated above
            if (key < 0) return;
            const size_t list_size = invlists->list_size(key);

            if (list_size == 0) return;

            InvertedLists::ScopedCodes scodes (invlists, key);
            InvertedLists::ScopedIds ids (invlists, key);

            scanner->set_list (key, coarse_dis[i * nprobe + ik]);
            nlistv++;
            ndis += list_size;
            scanner->scan_codes_range (list_size, scodes.get(),
                                       ids.get(), radius, qres);
        };

        if (parallel_mode == 0) {

#pragma omp for
            for (size_t i = 0; i < nx; i++) {
                scanner->set_query (x + i * d);

                RangeQueryResult & qres = pres.new_result (i);

                for (size_t ik = 0; ik < nprobe; ik++) {
                    scan_list_func (i, ik, qres);
                }

            }

        } else if (parallel_mode == 1) {

            for (size_t i = 0; i < nx; i++) {
                scanner->set_query (x + i * d);

                RangeQueryResult & qres = pres.new_result (i);

#pragma omp for schedule(dynamic)
                for (size_t ik = 0; ik < nprobe; ik++) {
                    scan_list_func (i, ik, qres);
                }
            }
        } else {
            // parallel_mode == 2 (the only remaining value, validated above)
            RangeQueryResult *qres = nullptr;

#pragma omp for schedule(dynamic)
            for (size_t iik = 0; iik < nx * nprobe; iik++) {
                size_t i = iik / nprobe;
                size_t ik = iik % nprobe;
                if (qres == nullptr || qres->qno != i) {
                    FAISS_ASSERT (!qres || i > qres->qno);
                    qres = &pres.new_result (i);
                    scanner->set_query (x + i * d);
                }
                scan_list_func (i, ik, *qres);
            }
        }
        if (parallel_mode == 0) {
            pres.finalize ();
        } else {
            // modes 1 and 2 spread one query over several threads: merge
            // every thread's partial results once all threads are done
#pragma omp barrier
#pragma omp single
            RangeSearchPartialResult::merge (all_pres, false);
#pragma omp barrier

        }
    }
    indexIVF_stats.nq += nx;
    indexIVF_stats.nlist += nlistv;
    indexIVF_stats.ndis += ndis;
}
|
|
601
|
+
|
|
602
|
+
|
|
603
|
+
// Generic IndexIVF has no code-specific list scanner and returns nullptr.
// Search paths such as range_search_preassigned reject a null scanner, so
// index types that support scanning presumably override this.
InvertedListScanner *IndexIVF::get_InvertedListScanner (
        bool /*store_pairs*/) const
{
    InvertedListScanner *no_scanner = nullptr;
    return no_scanner;
}
|
|
608
|
+
|
|
609
|
+
// Reconstruct the stored vector with id `key` into `recons` via the direct
// map. Throws if the direct map is not populated or `key` is out of range.
void IndexIVF::reconstruct (idx_t key, float* recons) const
{
    FAISS_THROW_IF_NOT_MSG (direct_map.size() == ntotal,
                            "direct map is not initialized");
    FAISS_THROW_IF_NOT_MSG (key >= 0 && key < direct_map.size(),
                            "invalid key");

    // each entry packs the list number in the high 32 bits and the offset
    // inside that list in the low 32 bits
    const auto packed = direct_map[key];
    const idx_t list_no = packed >> 32;
    const idx_t offset = packed & 0xffffffff;
    reconstruct_from_offset (list_no, offset, recons);
}
|
|
619
|
+
|
|
620
|
+
|
|
621
|
+
// Reconstruct the vectors with ids in [i0, i0 + ni) into `recons`
// (row id - i0, d floats per row). Every inverted list is scanned and only
// entries whose id falls in the range are decoded.
void IndexIVF::reconstruct_n (idx_t i0, idx_t ni, float* recons) const
{
    FAISS_THROW_IF_NOT (ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));

    const idx_t id_end = i0 + ni;   // exclusive upper bound
    for (idx_t list_no = 0; list_no < nlist; list_no++) {
        const size_t n_entries = invlists->list_size (list_no);
        ScopedIds list_ids (invlists, list_no);

        for (idx_t offset = 0; offset < n_entries; offset++) {
            const idx_t id = list_ids[offset];
            if (id >= i0 && id < id_end) {
                reconstruct_from_offset (list_no, offset,
                                         recons + (id - i0) * d);
            }
        }
    }
}
|
|
640
|
+
|
|
641
|
+
|
|
642
|
+
/* standalone codec interface */
|
|
643
|
+
size_t IndexIVF::sa_code_size () const
|
|
644
|
+
{
|
|
645
|
+
size_t coarse_size = coarse_code_size();
|
|
646
|
+
return code_size + coarse_size;
|
|
647
|
+
}
|
|
648
|
+
|
|
649
|
+
// Standalone encoding of n vectors into `bytes`. Each vector is assigned to
// one inverted list by the quantizer and then encoded. The trailing `true`
// presumably asks encode_vectors to also emit the list number, matching
// sa_code_size() = coarse + fine (confirm in the header).
void IndexIVF::sa_encode (idx_t n, const float *x,
                          uint8_t *bytes) const
{
    FAISS_THROW_IF_NOT (is_trained);
    std::unique_ptr<int64_t []> list_nos (new int64_t [n]);
    quantizer->assign (n, x, list_nos.get());
    encode_vectors (n, x, list_nos.get(), bytes, true);
}
|
|
657
|
+
|
|
658
|
+
|
|
659
|
+
// k-NN search that also reconstructs each result vector into `recons`
// (n * k rows of d floats). Missing results (label < 0) are filled with an
// all-ones bit pattern, which reads as NaN floats.
void IndexIVF::search_and_reconstruct (idx_t n, const float *x, idx_t k,
                                       float *distances, idx_t *labels,
                                       float *recons) const
{
    std::unique_ptr<idx_t []> idx (new idx_t [n * nprobe]);
    std::unique_ptr<float []> coarse_dis (new float [n * nprobe]);

    quantizer->search (n, x, nprobe, coarse_dis.get(), idx.get());

    invlists->prefetch_lists (idx.get(), n * nprobe);

    // search_preassigned() with `store_pairs` enabled to obtain the list_no
    // and offset into `codes` for reconstruction
    search_preassigned (n, x, k, idx.get(), coarse_dis.get(),
                        distances, labels, true /* store_pairs */);
    for (idx_t i = 0; i < n; ++i) {
        for (idx_t j = 0; j < k; ++j) {
            idx_t ij = i * k + j;
            idx_t key = labels[ij];
            float* reconstructed = recons + ij * d;
            if (key < 0) {
                // Fill with NaNs
                memset(reconstructed, -1, sizeof(*reconstructed) * d);
            } else {
                // The label holds list_no in the high 32 bits and the offset
                // in the low 32 bits. Keep both 64-bit: an `int` offset would
                // wrap negative for offsets >= 2^31.
                idx_t list_no = key >> 32;
                idx_t offset = key & 0xffffffff;

                // Update label to the actual id
                labels[ij] = invlists->get_single_id (list_no, offset);

                reconstruct_from_offset (list_no, offset, reconstructed);
            }
        }
    }
}
|
|
696
|
+
|
|
697
|
+
void IndexIVF::reconstruct_from_offset(
|
|
698
|
+
int64_t /*list_no*/,
|
|
699
|
+
int64_t /*offset*/,
|
|
700
|
+
float* /*recons*/) const {
|
|
701
|
+
FAISS_THROW_MSG ("reconstruct_from_offset not implemented");
|
|
702
|
+
}
|
|
703
|
+
|
|
704
|
+
void IndexIVF::reset ()
|
|
705
|
+
{
|
|
706
|
+
direct_map.clear ();
|
|
707
|
+
invlists->reset ();
|
|
708
|
+
ntotal = 0;
|
|
709
|
+
}
|
|
710
|
+
|
|
711
|
+
|
|
712
|
+
size_t IndexIVF::remove_ids (const IDSelector & sel)
|
|
713
|
+
{
|
|
714
|
+
FAISS_THROW_IF_NOT_MSG (!maintain_direct_map,
|
|
715
|
+
"direct map remove not implemented");
|
|
716
|
+
|
|
717
|
+
std::vector<idx_t> toremove(nlist);
|
|
718
|
+
|
|
719
|
+
#pragma omp parallel for
|
|
720
|
+
for (idx_t i = 0; i < nlist; i++) {
|
|
721
|
+
idx_t l0 = invlists->list_size (i), l = l0, j = 0;
|
|
722
|
+
ScopedIds idsi (invlists, i);
|
|
723
|
+
while (j < l) {
|
|
724
|
+
if (sel.is_member (idsi[j])) {
|
|
725
|
+
l--;
|
|
726
|
+
invlists->update_entry (
|
|
727
|
+
i, j,
|
|
728
|
+
invlists->get_single_id (i, l),
|
|
729
|
+
ScopedCodes (invlists, i, l).get());
|
|
730
|
+
} else {
|
|
731
|
+
j++;
|
|
732
|
+
}
|
|
733
|
+
}
|
|
734
|
+
toremove[i] = l0 - l;
|
|
735
|
+
}
|
|
736
|
+
// this will not run well in parallel on ondisk because of possible shrinks
|
|
737
|
+
size_t nremove = 0;
|
|
738
|
+
for (idx_t i = 0; i < nlist; i++) {
|
|
739
|
+
if (toremove[i] > 0) {
|
|
740
|
+
nremove += toremove[i];
|
|
741
|
+
invlists->resize(
|
|
742
|
+
i, invlists->list_size(i) - toremove[i]);
|
|
743
|
+
}
|
|
744
|
+
}
|
|
745
|
+
ntotal -= nremove;
|
|
746
|
+
return nremove;
|
|
747
|
+
}
|
|
748
|
+
|
|
749
|
+
|
|
750
|
+
|
|
751
|
+
|
|
752
|
+
// Two-stage training: first the coarse (level-1) quantizer via train_q1,
// then any residual model via train_residual. The index is marked trained
// only after both stages return.
void IndexIVF::train (idx_t n, const float *x)
{
    if (verbose) {
        printf ("Training level-1 quantizer\n");
    }
    train_q1 (n, x, verbose, metric_type);

    if (verbose) {
        printf ("Training IVF residual\n");
    }
    train_residual (n, x);

    is_trained = true;
}
|
|
766
|
+
|
|
767
|
+
// Default residual training does nothing; when verbose it only reports that.
void IndexIVF::train_residual(idx_t /*n*/, const float* /*x*/)
{
    if (!verbose) {
        return;
    }
    printf("IndexIVF: no residual training\n");
}
|
|
772
|
+
|
|
773
|
+
|
|
774
|
+
void IndexIVF::check_compatible_for_merge (const IndexIVF &other) const
|
|
775
|
+
{
|
|
776
|
+
// minimal sanity checks
|
|
777
|
+
FAISS_THROW_IF_NOT (other.d == d);
|
|
778
|
+
FAISS_THROW_IF_NOT (other.nlist == nlist);
|
|
779
|
+
FAISS_THROW_IF_NOT (other.code_size == code_size);
|
|
780
|
+
FAISS_THROW_IF_NOT_MSG (typeid (*this) == typeid (other),
|
|
781
|
+
"can only merge indexes of the same type");
|
|
782
|
+
}
|
|
783
|
+
|
|
784
|
+
|
|
785
|
+
// Append all entries of `other` to this index. `add_id` is forwarded to
// InvertedLists::merge_from (presumably an offset added to other's ids;
// confirm there). Afterwards other.ntotal is 0 and the count is ours.
// Not supported when either index maintains a direct map.
void IndexIVF::merge_from (IndexIVF &other, idx_t add_id)
{
    check_compatible_for_merge (other);
    FAISS_THROW_IF_NOT_MSG ((!maintain_direct_map &&
                             !other.maintain_direct_map),
                            "direct map copy not implemented");

    invlists->merge_from (other.invlists, add_id);

    const idx_t transferred = other.ntotal;
    ntotal += transferred;
    other.ntotal = 0;
}
|
|
797
|
+
|
|
798
|
+
|
|
799
|
+
// Replace the inverted lists with `il` (may be nullptr). If `own` is true,
// this index deletes `il` on destruction or on the next replacement.
// The previous lists are deleted only if this index owned them.
void IndexIVF::replace_invlists (InvertedLists *il, bool own)
{
    // Validate first. The old code deleted the current lists before this
    // check, so a throw left `invlists` dangling while own_invlists stayed
    // true, and the destructor then deleted it a second time.
    if (il) {
        FAISS_THROW_IF_NOT (il->nlist == nlist &&
                            il->code_size == code_size);
    }
    // FAISS_THROW_IF_NOT (ntotal == 0);

    // Re-installing the same object must not free it.
    if (own_invlists && invlists != il) {
        delete invlists;
    }
    invlists = il;
    own_invlists = own;
}
|
|
812
|
+
|
|
813
|
+
|
|
814
|
+
// Copy a subset of this index's entries into `other` (same nlist and
// code_size, no direct map). subset_type selects the entries:
//   0: ids in [a1, a2)
//   1: ids with id % a1 == a2
//   2: in each list, the slice between that list's proportional shares of
//      a1 and a2 (cumulative list sizes scaled by a1/ntotal and a2/ntotal)
// other.ntotal is increased by the number of entries actually added.
void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type,
                               idx_t a1, idx_t a2) const
{

    FAISS_THROW_IF_NOT (nlist == other.nlist);
    FAISS_THROW_IF_NOT (code_size == other.code_size);
    FAISS_THROW_IF_NOT (!other.maintain_direct_map);
    FAISS_THROW_IF_NOT_FMT (
          subset_type == 0 || subset_type == 1 || subset_type == 2,
          "subset type %d not implemented", subset_type);
    // subset_type 1 evaluates id % a1: a zero modulus is undefined behavior
    FAISS_THROW_IF_NOT_MSG (subset_type != 1 || a1 != 0,
                            "subset type 1 requires a1 != 0");
    // subset_type 2: out-of-range a1/a2 make the per-list bounds underflow
    // or run past the end of the list
    FAISS_THROW_IF_NOT_MSG (
          subset_type != 2 || (0 <= a1 && a1 <= a2 && a2 <= ntotal),
          "subset type 2 requires 0 <= a1 <= a2 <= ntotal");

    if (subset_type == 2 && ntotal == 0) {
        // nothing to copy; the proportional split would divide by zero
        return;
    }

    size_t accu_n = 0;
    size_t accu_a1 = 0;
    size_t accu_a2 = 0;

    InvertedLists *oivf = other.invlists;

    for (idx_t list_no = 0; list_no < nlist; list_no++) {
        size_t n = invlists->list_size (list_no);
        ScopedIds ids_in (invlists, list_no);

        if (subset_type == 0) {
            for (idx_t i = 0; i < n; i++) {
                idx_t id = ids_in[i];
                if (a1 <= id && id < a2) {
                    oivf->add_entry (list_no,
                                     invlists->get_single_id (list_no, i),
                                     ScopedCodes (invlists, list_no, i).get());
                    other.ntotal++;
                }
            }
        } else if (subset_type == 1) {
            for (idx_t i = 0; i < n; i++) {
                idx_t id = ids_in[i];
                if (id % a1 == a2) {
                    oivf->add_entry (list_no,
                                     invlists->get_single_id (list_no, i),
                                     ScopedCodes (invlists, list_no, i).get());
                    other.ntotal++;
                }
            }
        } else if (subset_type == 2) {
            // see what is allocated to a1 and to a2
            size_t next_accu_n = accu_n + n;
            size_t next_accu_a1 = next_accu_n * a1 / ntotal;
            size_t i1 = next_accu_a1 - accu_a1;
            size_t next_accu_a2 = next_accu_n * a2 / ntotal;
            size_t i2 = next_accu_a2 - accu_a2;

            // The two shares are rounded independently, so i1 > i2 can occur
            // for a list even when a1 <= a2. Copy nothing in that case rather
            // than adding the wrapped-around (i2 - i1) to other.ntotal.
            if (i1 < i2) {
                for (idx_t i = i1; i < i2; i++) {
                    oivf->add_entry (list_no,
                                     invlists->get_single_id (list_no, i),
                                     ScopedCodes (invlists, list_no, i).get());
                }
                other.ntotal += i2 - i1;
            }

            accu_a1 = next_accu_a1;
            accu_a2 = next_accu_a2;
        }
        accu_n += n;
    }
    FAISS_ASSERT(accu_n == ntotal);

}
|
|
878
|
+
|
|
879
|
+
|
|
880
|
+
|
|
881
|
+
|
|
882
|
+
// Free the inverted lists only if this index owns them (own_invlists,
// as set by replace_invlists).
IndexIVF::~IndexIVF()
{
    if (!own_invlists) {
        return;
    }
    delete invlists;
}
|
|
888
|
+
|
|
889
|
+
|
|
890
|
+
void IndexIVFStats::reset()
|
|
891
|
+
{
|
|
892
|
+
memset ((void*)this, 0, sizeof (*this));
|
|
893
|
+
}
|
|
894
|
+
|
|
895
|
+
|
|
896
|
+
IndexIVFStats indexIVF_stats;
|
|
897
|
+
|
|
898
|
+
void InvertedListScanner::scan_codes_range (size_t ,
|
|
899
|
+
const uint8_t *,
|
|
900
|
+
const idx_t *,
|
|
901
|
+
float ,
|
|
902
|
+
RangeQueryResult &) const
|
|
903
|
+
{
|
|
904
|
+
FAISS_THROW_MSG ("scan_codes_range not implemented");
|
|
905
|
+
}
|
|
906
|
+
|
|
907
|
+
|
|
908
|
+
|
|
909
|
+
} // namespace faiss
|