faiss 0.5.2 → 0.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/LICENSE.txt +1 -1
- data/ext/faiss/ext.cpp +1 -1
- data/ext/faiss/extconf.rb +5 -6
- data/ext/faiss/index_binary.cpp +76 -17
- data/ext/faiss/{index.cpp → index_rb.cpp} +108 -35
- data/ext/faiss/kmeans.cpp +12 -9
- data/ext/faiss/numo.hpp +11 -9
- data/ext/faiss/pca_matrix.cpp +10 -8
- data/ext/faiss/product_quantizer.cpp +14 -12
- data/ext/faiss/{utils.cpp → utils_rb.cpp} +10 -3
- data/ext/faiss/{utils.h → utils_rb.h} +6 -0
- data/lib/faiss/version.rb +1 -1
- data/lib/faiss.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +130 -11
- data/vendor/faiss/faiss/AutoTune.h +14 -1
- data/vendor/faiss/faiss/Clustering.cpp +59 -10
- data/vendor/faiss/faiss/Clustering.h +12 -0
- data/vendor/faiss/faiss/IVFlib.cpp +31 -28
- data/vendor/faiss/faiss/Index.cpp +20 -8
- data/vendor/faiss/faiss/Index.h +25 -3
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
- data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
- data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
- data/vendor/faiss/faiss/IndexFastScan.h +10 -1
- data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
- data/vendor/faiss/faiss/IndexFlat.h +16 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
- data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
- data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
- data/vendor/faiss/faiss/IndexHNSW.h +14 -12
- data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
- data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
- data/vendor/faiss/faiss/IndexIVF.h +14 -4
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
- data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
- data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
- data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
- data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
- data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
- data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
- data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
- data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
- data/vendor/faiss/faiss/IndexNSG.h +0 -2
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
- data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
- data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
- data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
- data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
- data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
- data/vendor/faiss/faiss/IndexShards.cpp +3 -4
- data/vendor/faiss/faiss/MetricType.h +16 -0
- data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
- data/vendor/faiss/faiss/VectorTransform.h +23 -0
- data/vendor/faiss/faiss/clone_index.cpp +7 -4
- data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
- data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
- data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
- data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
- data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
- data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
- data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
- data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
- data/vendor/faiss/faiss/impl/HNSW.h +8 -6
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
- data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
- data/vendor/faiss/faiss/impl/NSG.h +17 -7
- data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
- data/vendor/faiss/faiss/impl/Panorama.h +22 -6
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
- data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
- data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
- data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
- data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
- data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
- data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
- data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
- data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
- data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
- data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
- data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
- data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
- data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
- data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
- data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
- data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
- data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
- data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
- data/vendor/faiss/faiss/index_factory.cpp +35 -16
- data/vendor/faiss/faiss/index_io.h +29 -3
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
- data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
- data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
- data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
- data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
- data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
- data/vendor/faiss/faiss/utils/Heap.h +21 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
- data/vendor/faiss/faiss/utils/distances.cpp +141 -23
- data/vendor/faiss/faiss/utils/distances.h +98 -0
- data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
- data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
- data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
- data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
- data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
- data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
- data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
- data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
- data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
- data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
- data/vendor/faiss/faiss/utils/utils.cpp +16 -9
- metadata +47 -18
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
- /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
|
@@ -8,13 +8,11 @@
|
|
|
8
8
|
#include <faiss/IndexAdditiveQuantizer.h>
|
|
9
9
|
|
|
10
10
|
#include <algorithm>
|
|
11
|
-
#include <cmath>
|
|
12
11
|
#include <cstring>
|
|
13
12
|
|
|
14
13
|
#include <faiss/impl/FaissAssert.h>
|
|
15
14
|
#include <faiss/impl/ResidualQuantizer.h>
|
|
16
15
|
#include <faiss/impl/ResultHandler.h>
|
|
17
|
-
#include <faiss/utils/distances.h>
|
|
18
16
|
#include <faiss/utils/extra_distances.h>
|
|
19
17
|
|
|
20
18
|
namespace faiss {
|
|
@@ -189,17 +187,14 @@ void search_with_LUT(
|
|
|
189
187
|
FlatCodesDistanceComputer* IndexAdditiveQuantizer::
|
|
190
188
|
get_FlatCodesDistanceComputer() const {
|
|
191
189
|
if (aq->search_type == AdditiveQuantizer::ST_decompress) {
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
} else {
|
|
201
|
-
FAISS_THROW_MSG("unsupported metric");
|
|
202
|
-
}
|
|
190
|
+
return with_VectorDistance(
|
|
191
|
+
d,
|
|
192
|
+
metric_type,
|
|
193
|
+
metric_arg,
|
|
194
|
+
[&](auto vd) -> FlatCodesDistanceComputer* {
|
|
195
|
+
return new AQDistanceComputerDecompress<decltype(vd)>(
|
|
196
|
+
*this, vd);
|
|
197
|
+
});
|
|
203
198
|
} else {
|
|
204
199
|
if (metric_type == METRIC_INNER_PRODUCT) {
|
|
205
200
|
return new AQDistanceComputerLUT<
|
|
@@ -242,17 +237,17 @@ void IndexAdditiveQuantizer::search(
|
|
|
242
237
|
!params, "search params not supported for this index");
|
|
243
238
|
|
|
244
239
|
if (aq->search_type == AdditiveQuantizer::ST_decompress) {
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
}
|
|
240
|
+
with_VectorDistance(d, metric_type, metric_arg, [&](auto vd) {
|
|
241
|
+
if constexpr (decltype(vd)::is_similarity) {
|
|
242
|
+
HeapBlockResultHandler<CMin<float, idx_t>> rh(
|
|
243
|
+
n, distances, labels, k);
|
|
244
|
+
search_with_decompress(*this, x, vd, rh);
|
|
245
|
+
} else {
|
|
246
|
+
HeapBlockResultHandler<CMax<float, idx_t>> rh(
|
|
247
|
+
n, distances, labels, k);
|
|
248
|
+
search_with_decompress(*this, x, vd, rh);
|
|
249
|
+
}
|
|
250
|
+
});
|
|
256
251
|
} else {
|
|
257
252
|
if (metric_type == METRIC_INNER_PRODUCT) {
|
|
258
253
|
HeapBlockResultHandler<CMin<float, idx_t>> rh(
|
|
@@ -22,6 +22,7 @@
|
|
|
22
22
|
#include <faiss/impl/DistanceComputer.h>
|
|
23
23
|
#include <faiss/impl/FaissAssert.h>
|
|
24
24
|
#include <faiss/impl/ResultHandler.h>
|
|
25
|
+
#include <faiss/impl/VisitedTable.h>
|
|
25
26
|
#include <faiss/utils/Heap.h>
|
|
26
27
|
#include <faiss/utils/hamming.h>
|
|
27
28
|
#include <faiss/utils/random.h>
|
|
@@ -205,10 +206,14 @@ void IndexBinaryHNSW::search(
|
|
|
205
206
|
idx_t k,
|
|
206
207
|
int32_t* distances,
|
|
207
208
|
idx_t* labels,
|
|
208
|
-
const SearchParameters*
|
|
209
|
-
FAISS_THROW_IF_NOT_MSG(
|
|
210
|
-
!params, "search params not supported for this index");
|
|
209
|
+
const SearchParameters* params_in) const {
|
|
211
210
|
FAISS_THROW_IF_NOT(k > 0);
|
|
211
|
+
const SearchParametersHNSW* params = nullptr;
|
|
212
|
+
if (params_in) {
|
|
213
|
+
params = dynamic_cast<const SearchParametersHNSW*>(params_in);
|
|
214
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
215
|
+
params, "IndexBinaryHNSW params have incorrect type");
|
|
216
|
+
}
|
|
212
217
|
|
|
213
218
|
// we use the buffer for distances as float but convert them back
|
|
214
219
|
// to int in the end
|
|
@@ -231,7 +236,7 @@ void IndexBinaryHNSW::search(
|
|
|
231
236
|
// as the index parameter. This state does not get used in the
|
|
232
237
|
// search function, as it is merely there to to enable Panorama
|
|
233
238
|
// execution for IndexHNSWFlatPanorama.
|
|
234
|
-
hnsw.search(*dis, nullptr, res, vt);
|
|
239
|
+
hnsw.search(*dis, nullptr, res, vt, params_in);
|
|
235
240
|
res.end();
|
|
236
241
|
}
|
|
237
242
|
}
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
#include <cstdio>
|
|
15
15
|
|
|
16
16
|
#include <algorithm>
|
|
17
|
+
#include <limits>
|
|
17
18
|
#include <memory>
|
|
18
19
|
|
|
19
20
|
#include <faiss/IndexFlat.h>
|
|
@@ -120,25 +121,46 @@ void IndexBinaryIVF::search(
|
|
|
120
121
|
idx_t k,
|
|
121
122
|
int32_t* distances,
|
|
122
123
|
idx_t* labels,
|
|
123
|
-
const SearchParameters*
|
|
124
|
-
FAISS_THROW_IF_NOT_MSG(
|
|
125
|
-
!params, "search params not supported for this index");
|
|
124
|
+
const SearchParameters* params_in) const {
|
|
126
125
|
FAISS_THROW_IF_NOT(k > 0);
|
|
126
|
+
const IVFSearchParameters* params = nullptr;
|
|
127
|
+
if (params_in) {
|
|
128
|
+
params = dynamic_cast<const IVFSearchParameters*>(params_in);
|
|
129
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
130
|
+
params, "IndexBinaryIVF params have incorrect type");
|
|
131
|
+
FAISS_THROW_IF_MSG(
|
|
132
|
+
params->sel, "IDSelector is not supported for IndexBinaryIVF");
|
|
133
|
+
}
|
|
134
|
+
const size_t nprobe =
|
|
135
|
+
std::min(nlist, params ? params->nprobe : this->nprobe);
|
|
127
136
|
FAISS_THROW_IF_NOT(nprobe > 0);
|
|
128
137
|
|
|
129
|
-
|
|
130
|
-
std::unique_ptr<
|
|
131
|
-
std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe_2]);
|
|
138
|
+
std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
|
|
139
|
+
std::unique_ptr<int32_t[]> coarse_dis(new int32_t[n * nprobe]);
|
|
132
140
|
|
|
133
141
|
double t0 = getmillisecs();
|
|
134
|
-
quantizer->search(
|
|
142
|
+
quantizer->search(
|
|
143
|
+
n,
|
|
144
|
+
x,
|
|
145
|
+
nprobe,
|
|
146
|
+
coarse_dis.get(),
|
|
147
|
+
idx.get(),
|
|
148
|
+
params ? params->quantizer_params : nullptr);
|
|
135
149
|
indexIVF_stats.quantization_time += getmillisecs() - t0;
|
|
136
150
|
|
|
137
151
|
t0 = getmillisecs();
|
|
138
|
-
invlists->prefetch_lists(idx.get(), n *
|
|
152
|
+
invlists->prefetch_lists(idx.get(), n * nprobe);
|
|
139
153
|
|
|
140
154
|
search_preassigned(
|
|
141
|
-
n,
|
|
155
|
+
n,
|
|
156
|
+
x,
|
|
157
|
+
k,
|
|
158
|
+
idx.get(),
|
|
159
|
+
coarse_dis.get(),
|
|
160
|
+
distances,
|
|
161
|
+
labels,
|
|
162
|
+
false,
|
|
163
|
+
params);
|
|
142
164
|
indexIVF_stats.search_time += getmillisecs() - t0;
|
|
143
165
|
}
|
|
144
166
|
|
|
@@ -389,6 +411,10 @@ void search_knn_hamming_heap(
|
|
|
389
411
|
idx_t nprobe = params ? params->nprobe : ivf->nprobe;
|
|
390
412
|
nprobe = std::min((idx_t)ivf->nlist, nprobe);
|
|
391
413
|
idx_t max_codes = params ? params->max_codes : ivf->max_codes;
|
|
414
|
+
const idx_t unlimited_list_size = std::numeric_limits<idx_t>::max();
|
|
415
|
+
if (max_codes == 0) {
|
|
416
|
+
max_codes = unlimited_list_size;
|
|
417
|
+
}
|
|
392
418
|
MetricType metric_type = ivf->metric_type;
|
|
393
419
|
|
|
394
420
|
// almost verbatim copy from IndexIVF::search_preassigned
|
|
@@ -437,6 +463,10 @@ void search_knn_hamming_heap(
|
|
|
437
463
|
nlistv++;
|
|
438
464
|
|
|
439
465
|
size_t list_size = ivf->invlists->list_size(key);
|
|
466
|
+
size_t list_size_max = max_codes - nscan;
|
|
467
|
+
if (list_size > list_size_max) {
|
|
468
|
+
list_size = list_size_max;
|
|
469
|
+
}
|
|
440
470
|
InvertedLists::ScopedCodes scodes(ivf->invlists, key);
|
|
441
471
|
std::unique_ptr<InvertedLists::ScopedIds> sids;
|
|
442
472
|
const idx_t* ids = nullptr;
|
|
@@ -451,7 +481,7 @@ void search_knn_hamming_heap(
|
|
|
451
481
|
list_size, scodes.get(), ids, simi, idxi, k);
|
|
452
482
|
|
|
453
483
|
nscan += list_size;
|
|
454
|
-
if (
|
|
484
|
+
if (nscan >= max_codes) {
|
|
455
485
|
break;
|
|
456
486
|
}
|
|
457
487
|
}
|
|
@@ -525,6 +555,10 @@ void search_knn_hamming_count(
|
|
|
525
555
|
|
|
526
556
|
nlistv++;
|
|
527
557
|
size_t list_size = ivf->invlists->list_size(key);
|
|
558
|
+
size_t list_size_max = max_codes - nscan;
|
|
559
|
+
if (list_size > list_size_max) {
|
|
560
|
+
list_size = list_size_max;
|
|
561
|
+
}
|
|
528
562
|
InvertedLists::ScopedCodes scodes(ivf->invlists, key);
|
|
529
563
|
const uint8_t* list_vecs = scodes.get();
|
|
530
564
|
const idx_t* ids =
|
|
@@ -541,7 +575,7 @@ void search_knn_hamming_count(
|
|
|
541
575
|
}
|
|
542
576
|
|
|
543
577
|
nscan += list_size;
|
|
544
|
-
if (
|
|
578
|
+
if (nscan >= max_codes) {
|
|
545
579
|
break;
|
|
546
580
|
}
|
|
547
581
|
}
|
|
@@ -20,11 +20,8 @@
|
|
|
20
20
|
#include <faiss/impl/pq4_fast_scan.h>
|
|
21
21
|
#include <faiss/impl/simd_result_handlers.h>
|
|
22
22
|
#include <faiss/utils/hamming.h>
|
|
23
|
-
#include <faiss/utils/utils.h>
|
|
24
|
-
|
|
25
|
-
#include <faiss/impl/pq4_fast_scan.h>
|
|
26
|
-
#include <faiss/impl/simd_result_handlers.h>
|
|
27
23
|
#include <faiss/utils/quantize_lut.h>
|
|
24
|
+
#include <faiss/utils/utils.h>
|
|
28
25
|
|
|
29
26
|
namespace faiss {
|
|
30
27
|
|
|
@@ -84,7 +81,8 @@ void IndexFastScan::add(idx_t n, const float* x) {
|
|
|
84
81
|
compute_codes(tmp_codes.get(), n, x);
|
|
85
82
|
|
|
86
83
|
ntotal2 = roundup(ntotal + n, bbs);
|
|
87
|
-
size_t
|
|
84
|
+
size_t n_blocks = ntotal2 / bbs;
|
|
85
|
+
size_t new_size = n_blocks * get_block_stride();
|
|
88
86
|
size_t old_size = codes.size();
|
|
89
87
|
if (new_size > old_size) {
|
|
90
88
|
codes.resize(new_size);
|
|
@@ -92,7 +90,15 @@ void IndexFastScan::add(idx_t n, const float* x) {
|
|
|
92
90
|
}
|
|
93
91
|
|
|
94
92
|
pq4_pack_codes_range(
|
|
95
|
-
tmp_codes.get(),
|
|
93
|
+
tmp_codes.get(),
|
|
94
|
+
M,
|
|
95
|
+
ntotal,
|
|
96
|
+
ntotal + n,
|
|
97
|
+
bbs,
|
|
98
|
+
M2,
|
|
99
|
+
codes.get(),
|
|
100
|
+
0,
|
|
101
|
+
get_block_stride());
|
|
96
102
|
|
|
97
103
|
ntotal += n;
|
|
98
104
|
}
|
|
@@ -101,17 +107,25 @@ CodePacker* IndexFastScan::get_CodePacker() const {
|
|
|
101
107
|
return new CodePackerPQ4(M, bbs);
|
|
102
108
|
}
|
|
103
109
|
|
|
110
|
+
size_t IndexFastScan::get_block_stride() const {
|
|
111
|
+
std::unique_ptr<CodePacker> packer(get_CodePacker());
|
|
112
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
113
|
+
packer->nvec == static_cast<size_t>(bbs),
|
|
114
|
+
"CodePacker must pack bbs vectors per block for fast-scan");
|
|
115
|
+
return packer->block_size;
|
|
116
|
+
}
|
|
117
|
+
|
|
104
118
|
size_t IndexFastScan::remove_ids(const IDSelector& sel) {
|
|
105
119
|
idx_t j = 0;
|
|
106
120
|
std::vector<uint8_t> buffer(code_size);
|
|
107
|
-
|
|
121
|
+
std::unique_ptr<CodePacker> packer(get_CodePacker());
|
|
108
122
|
for (idx_t i = 0; i < ntotal; i++) {
|
|
109
123
|
if (sel.is_member(i)) {
|
|
110
124
|
// should be removed
|
|
111
125
|
} else {
|
|
112
126
|
if (i > j) {
|
|
113
|
-
packer
|
|
114
|
-
packer
|
|
127
|
+
packer->unpack_1(codes.data(), i, buffer.data());
|
|
128
|
+
packer->pack_1(buffer.data(), j, codes.data());
|
|
115
129
|
}
|
|
116
130
|
j++;
|
|
117
131
|
}
|
|
@@ -120,8 +134,7 @@ size_t IndexFastScan::remove_ids(const IDSelector& sel) {
|
|
|
120
134
|
if (nremove > 0) {
|
|
121
135
|
ntotal = j;
|
|
122
136
|
ntotal2 = roundup(ntotal, bbs);
|
|
123
|
-
|
|
124
|
-
codes.resize(new_size);
|
|
137
|
+
codes.resize(ntotal2 / bbs * get_block_stride());
|
|
125
138
|
}
|
|
126
139
|
return nremove;
|
|
127
140
|
}
|
|
@@ -143,13 +156,14 @@ void IndexFastScan::merge_from(Index& otherIndex, idx_t add_id) {
|
|
|
143
156
|
check_compatible_for_merge(otherIndex);
|
|
144
157
|
IndexFastScan* other = static_cast<IndexFastScan*>(&otherIndex);
|
|
145
158
|
ntotal2 = roundup(ntotal + other->ntotal, bbs);
|
|
146
|
-
codes.resize(ntotal2
|
|
159
|
+
codes.resize(ntotal2 / bbs * get_block_stride());
|
|
147
160
|
std::vector<uint8_t> buffer(code_size);
|
|
148
|
-
|
|
161
|
+
std::unique_ptr<CodePacker> packer(get_CodePacker());
|
|
162
|
+
std::unique_ptr<CodePacker> other_packer(other->get_CodePacker());
|
|
149
163
|
|
|
150
164
|
for (int i = 0; i < other->ntotal; i++) {
|
|
151
|
-
|
|
152
|
-
packer
|
|
165
|
+
other_packer->unpack_1(other->codes.data(), i, buffer.data());
|
|
166
|
+
packer->pack_1(buffer.data(), ntotal + i, codes.data());
|
|
153
167
|
}
|
|
154
168
|
ntotal += other->ntotal;
|
|
155
169
|
other->reset();
|
|
@@ -531,7 +545,8 @@ void IndexFastScan::search_implem_12(
|
|
|
531
545
|
codes.get(),
|
|
532
546
|
LUT.get(),
|
|
533
547
|
*handler.get(),
|
|
534
|
-
context.norm_scaler
|
|
548
|
+
context.norm_scaler,
|
|
549
|
+
get_block_stride());
|
|
535
550
|
}
|
|
536
551
|
if (!(skip & 8)) {
|
|
537
552
|
handler->end();
|
|
@@ -614,7 +629,8 @@ void IndexFastScan::search_implem_14(
|
|
|
614
629
|
codes.get(),
|
|
615
630
|
LUT.get(),
|
|
616
631
|
*handler.get(),
|
|
617
|
-
context.norm_scaler
|
|
632
|
+
context.norm_scaler,
|
|
633
|
+
get_block_stride());
|
|
618
634
|
}
|
|
619
635
|
if (!(skip & 8)) {
|
|
620
636
|
handler->end();
|
|
@@ -639,11 +655,8 @@ template void IndexFastScan::search_dispatch_implem<false>(
|
|
|
639
655
|
|
|
640
656
|
void IndexFastScan::reconstruct(idx_t key, float* recons) const {
|
|
641
657
|
std::vector<uint8_t> code(code_size, 0);
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
uint8_t c = pq4_get_packed_element(codes.data(), bbs, M2, key, m);
|
|
645
|
-
bsw.write(c, nbits);
|
|
646
|
-
}
|
|
658
|
+
std::unique_ptr<CodePacker> packer(get_CodePacker());
|
|
659
|
+
packer->unpack_1(codes.data(), key, code.data());
|
|
647
660
|
sa_decode(1, code.data(), recons);
|
|
648
661
|
}
|
|
649
662
|
|
|
@@ -214,7 +214,16 @@ struct IndexFastScan : Index {
|
|
|
214
214
|
*
|
|
215
215
|
* @return pointer to the code packer
|
|
216
216
|
*/
|
|
217
|
-
CodePacker* get_CodePacker() const;
|
|
217
|
+
virtual CodePacker* get_CodePacker() const;
|
|
218
|
+
|
|
219
|
+
/** Get stride in bytes between consecutive SIMD blocks.
|
|
220
|
+
*
|
|
221
|
+
* Derived from get_CodePacker()->block_size so that there is a
|
|
222
|
+
* single source of truth for the block layout.
|
|
223
|
+
*
|
|
224
|
+
* @return stride in bytes
|
|
225
|
+
*/
|
|
226
|
+
size_t get_block_stride() const;
|
|
218
227
|
|
|
219
228
|
/** Merge another index into this one
|
|
220
229
|
*
|