faiss 0.3.0 → 0.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +1 -2
- data/vendor/faiss/faiss/Clustering.cpp +39 -22
- data/vendor/faiss/faiss/Clustering.h +40 -21
- data/vendor/faiss/faiss/IVFlib.cpp +26 -12
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +40 -10
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
- data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
- data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
- data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
- data/vendor/faiss/faiss/IndexHNSW.h +62 -49
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
- data/vendor/faiss/faiss/IndexIVF.h +46 -6
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
- data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
- data/vendor/faiss/faiss/IndexLattice.h +3 -22
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
- data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +11 -11
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/MetricType.h +7 -2
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
- data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
- data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
- data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
- data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
- data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
- data/vendor/faiss/faiss/impl/HNSW.h +52 -30
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
- data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
- data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
- data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
- data/vendor/faiss/faiss/impl/io.cpp +23 -15
- data/vendor/faiss/faiss/impl/io.h +4 -4
- data/vendor/faiss/faiss/impl/io_macros.h +6 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
- data/vendor/faiss/faiss/index_factory.cpp +41 -20
- data/vendor/faiss/faiss/index_io.h +12 -5
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
- data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/Heap.h +105 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
- data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/bf16.h +36 -0
- data/vendor/faiss/faiss/utils/distances.cpp +147 -123
- data/vendor/faiss/faiss/utils/distances.h +86 -9
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/random.cpp +43 -0
- data/vendor/faiss/faiss/utils/random.h +25 -0
- data/vendor/faiss/faiss/utils/simdlib.h +10 -1
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
- data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
- data/vendor/faiss/faiss/utils/utils.cpp +120 -7
- data/vendor/faiss/faiss/utils/utils.h +60 -20
- metadata +23 -4
- data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
|
@@ -7,8 +7,8 @@
|
|
|
7
7
|
|
|
8
8
|
#include <faiss/IndexFastScan.h>
|
|
9
9
|
|
|
10
|
-
#include <limits.h>
|
|
11
10
|
#include <cassert>
|
|
11
|
+
#include <climits>
|
|
12
12
|
#include <memory>
|
|
13
13
|
|
|
14
14
|
#include <omp.h>
|
|
@@ -37,22 +37,22 @@ inline size_t roundup(size_t a, size_t b) {
|
|
|
37
37
|
|
|
38
38
|
void IndexFastScan::init_fastscan(
|
|
39
39
|
int d,
|
|
40
|
-
size_t
|
|
41
|
-
size_t
|
|
40
|
+
size_t M_2,
|
|
41
|
+
size_t nbits_2,
|
|
42
42
|
MetricType metric,
|
|
43
43
|
int bbs) {
|
|
44
|
-
FAISS_THROW_IF_NOT(
|
|
44
|
+
FAISS_THROW_IF_NOT(nbits_2 == 4);
|
|
45
45
|
FAISS_THROW_IF_NOT(bbs % 32 == 0);
|
|
46
46
|
this->d = d;
|
|
47
|
-
this->M =
|
|
48
|
-
this->nbits =
|
|
47
|
+
this->M = M_2;
|
|
48
|
+
this->nbits = nbits_2;
|
|
49
49
|
this->metric_type = metric;
|
|
50
50
|
this->bbs = bbs;
|
|
51
|
-
ksub = (1 <<
|
|
51
|
+
ksub = (1 << nbits_2);
|
|
52
52
|
|
|
53
|
-
code_size = (
|
|
53
|
+
code_size = (M_2 * nbits_2 + 7) / 8;
|
|
54
54
|
ntotal = ntotal2 = 0;
|
|
55
|
-
M2 = roundup(
|
|
55
|
+
M2 = roundup(M_2, 2);
|
|
56
56
|
is_trained = false;
|
|
57
57
|
}
|
|
58
58
|
|
|
@@ -158,7 +158,7 @@ void IndexFastScan::merge_from(Index& otherIndex, idx_t add_id) {
|
|
|
158
158
|
|
|
159
159
|
namespace {
|
|
160
160
|
|
|
161
|
-
template <class C, typename dis_t
|
|
161
|
+
template <class C, typename dis_t>
|
|
162
162
|
void estimators_from_tables_generic(
|
|
163
163
|
const IndexFastScan& index,
|
|
164
164
|
const uint8_t* codes,
|
|
@@ -167,23 +167,27 @@ void estimators_from_tables_generic(
|
|
|
167
167
|
size_t k,
|
|
168
168
|
typename C::T* heap_dis,
|
|
169
169
|
int64_t* heap_ids,
|
|
170
|
-
const
|
|
170
|
+
const NormTableScaler* scaler) {
|
|
171
171
|
using accu_t = typename C::T;
|
|
172
172
|
|
|
173
173
|
for (size_t j = 0; j < ncodes; ++j) {
|
|
174
174
|
BitstringReader bsr(codes + j * index.code_size, index.code_size);
|
|
175
175
|
accu_t dis = 0;
|
|
176
176
|
const dis_t* dt = dis_table;
|
|
177
|
-
|
|
177
|
+
int nscale = scaler ? scaler->nscale : 0;
|
|
178
|
+
|
|
179
|
+
for (size_t m = 0; m < index.M - nscale; m++) {
|
|
178
180
|
uint64_t c = bsr.read(index.nbits);
|
|
179
181
|
dis += dt[c];
|
|
180
182
|
dt += index.ksub;
|
|
181
183
|
}
|
|
182
184
|
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
185
|
+
if (nscale) {
|
|
186
|
+
for (size_t m = 0; m < nscale; m++) {
|
|
187
|
+
uint64_t c = bsr.read(index.nbits);
|
|
188
|
+
dis += scaler->scale_one(dt[c]);
|
|
189
|
+
dt += index.ksub;
|
|
190
|
+
}
|
|
187
191
|
}
|
|
188
192
|
|
|
189
193
|
if (C::cmp(heap_dis[0], dis)) {
|
|
@@ -193,6 +197,28 @@ void estimators_from_tables_generic(
|
|
|
193
197
|
}
|
|
194
198
|
}
|
|
195
199
|
|
|
200
|
+
template <class C>
|
|
201
|
+
ResultHandlerCompare<C, false>* make_knn_handler(
|
|
202
|
+
int impl,
|
|
203
|
+
idx_t n,
|
|
204
|
+
idx_t k,
|
|
205
|
+
size_t ntotal,
|
|
206
|
+
float* distances,
|
|
207
|
+
idx_t* labels,
|
|
208
|
+
const IDSelector* sel = nullptr) {
|
|
209
|
+
using HeapHC = HeapHandler<C, false>;
|
|
210
|
+
using ReservoirHC = ReservoirHandler<C, false>;
|
|
211
|
+
using SingleResultHC = SingleResultHandler<C, false>;
|
|
212
|
+
|
|
213
|
+
if (k == 1) {
|
|
214
|
+
return new SingleResultHC(n, ntotal, distances, labels, sel);
|
|
215
|
+
} else if (impl % 2 == 0) {
|
|
216
|
+
return new HeapHC(n, ntotal, k, distances, labels, sel);
|
|
217
|
+
} else /* if (impl % 2 == 1) */ {
|
|
218
|
+
return new ReservoirHC(n, ntotal, k, 2 * k, distances, labels, sel);
|
|
219
|
+
}
|
|
220
|
+
}
|
|
221
|
+
|
|
196
222
|
} // anonymous namespace
|
|
197
223
|
|
|
198
224
|
using namespace quantize_lut;
|
|
@@ -241,22 +267,21 @@ void IndexFastScan::search(
|
|
|
241
267
|
!params, "search params not supported for this index");
|
|
242
268
|
FAISS_THROW_IF_NOT(k > 0);
|
|
243
269
|
|
|
244
|
-
DummyScaler scaler;
|
|
245
270
|
if (metric_type == METRIC_L2) {
|
|
246
|
-
search_dispatch_implem<true>(n, x, k, distances, labels,
|
|
271
|
+
search_dispatch_implem<true>(n, x, k, distances, labels, nullptr);
|
|
247
272
|
} else {
|
|
248
|
-
search_dispatch_implem<false>(n, x, k, distances, labels,
|
|
273
|
+
search_dispatch_implem<false>(n, x, k, distances, labels, nullptr);
|
|
249
274
|
}
|
|
250
275
|
}
|
|
251
276
|
|
|
252
|
-
template <bool is_max
|
|
277
|
+
template <bool is_max>
|
|
253
278
|
void IndexFastScan::search_dispatch_implem(
|
|
254
279
|
idx_t n,
|
|
255
280
|
const float* x,
|
|
256
281
|
idx_t k,
|
|
257
282
|
float* distances,
|
|
258
283
|
idx_t* labels,
|
|
259
|
-
const
|
|
284
|
+
const NormTableScaler* scaler) const {
|
|
260
285
|
using Cfloat = typename std::conditional<
|
|
261
286
|
is_max,
|
|
262
287
|
CMax<float, int64_t>,
|
|
@@ -319,14 +344,14 @@ void IndexFastScan::search_dispatch_implem(
|
|
|
319
344
|
}
|
|
320
345
|
}
|
|
321
346
|
|
|
322
|
-
template <class Cfloat
|
|
347
|
+
template <class Cfloat>
|
|
323
348
|
void IndexFastScan::search_implem_234(
|
|
324
349
|
idx_t n,
|
|
325
350
|
const float* x,
|
|
326
351
|
idx_t k,
|
|
327
352
|
float* distances,
|
|
328
353
|
idx_t* labels,
|
|
329
|
-
const
|
|
354
|
+
const NormTableScaler* scaler) const {
|
|
330
355
|
FAISS_THROW_IF_NOT(implem == 2 || implem == 3 || implem == 4);
|
|
331
356
|
|
|
332
357
|
const size_t dim12 = ksub * M;
|
|
@@ -378,7 +403,7 @@ void IndexFastScan::search_implem_234(
|
|
|
378
403
|
}
|
|
379
404
|
}
|
|
380
405
|
|
|
381
|
-
template <class C
|
|
406
|
+
template <class C>
|
|
382
407
|
void IndexFastScan::search_implem_12(
|
|
383
408
|
idx_t n,
|
|
384
409
|
const float* x,
|
|
@@ -386,7 +411,8 @@ void IndexFastScan::search_implem_12(
|
|
|
386
411
|
float* distances,
|
|
387
412
|
idx_t* labels,
|
|
388
413
|
int impl,
|
|
389
|
-
const
|
|
414
|
+
const NormTableScaler* scaler) const {
|
|
415
|
+
using RH = ResultHandlerCompare<C, false>;
|
|
390
416
|
FAISS_THROW_IF_NOT(bbs == 32);
|
|
391
417
|
|
|
392
418
|
// handle qbs2 blocking by recursive call
|
|
@@ -432,63 +458,31 @@ void IndexFastScan::search_implem_12(
|
|
|
432
458
|
pq4_pack_LUT_qbs(qbs, M2, quantized_dis_tables.get(), LUT.get());
|
|
433
459
|
FAISS_THROW_IF_NOT(LUT_nq == n);
|
|
434
460
|
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
} else {
|
|
440
|
-
handler.disable = bool(skip & 2);
|
|
441
|
-
pq4_accumulate_loop_qbs(
|
|
442
|
-
qbs, ntotal2, M2, codes.get(), LUT.get(), handler, scaler);
|
|
443
|
-
}
|
|
444
|
-
|
|
445
|
-
handler.to_flat_arrays(distances, labels, normalizers.get());
|
|
446
|
-
|
|
447
|
-
} else if (impl == 12) {
|
|
448
|
-
std::vector<uint16_t> tmp_dis(n * k);
|
|
449
|
-
std::vector<int32_t> tmp_ids(n * k);
|
|
450
|
-
|
|
451
|
-
if (skip & 4) {
|
|
452
|
-
// skip
|
|
453
|
-
} else {
|
|
454
|
-
HeapHandler<C> handler(
|
|
455
|
-
n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
|
|
456
|
-
handler.disable = bool(skip & 2);
|
|
457
|
-
|
|
458
|
-
pq4_accumulate_loop_qbs(
|
|
459
|
-
qbs, ntotal2, M2, codes.get(), LUT.get(), handler, scaler);
|
|
460
|
-
|
|
461
|
-
if (!(skip & 8)) {
|
|
462
|
-
handler.to_flat_arrays(distances, labels, normalizers.get());
|
|
463
|
-
}
|
|
464
|
-
}
|
|
465
|
-
|
|
466
|
-
} else { // impl == 13
|
|
467
|
-
|
|
468
|
-
ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
|
|
469
|
-
handler.disable = bool(skip & 2);
|
|
461
|
+
std::unique_ptr<RH> handler(
|
|
462
|
+
make_knn_handler<C>(impl, n, k, ntotal, distances, labels));
|
|
463
|
+
handler->disable = bool(skip & 2);
|
|
464
|
+
handler->normalizers = normalizers.get();
|
|
470
465
|
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
FastScan_stats.t3 += handler.times[3];
|
|
466
|
+
if (skip & 4) {
|
|
467
|
+
// pass
|
|
468
|
+
} else {
|
|
469
|
+
pq4_accumulate_loop_qbs(
|
|
470
|
+
qbs,
|
|
471
|
+
ntotal2,
|
|
472
|
+
M2,
|
|
473
|
+
codes.get(),
|
|
474
|
+
LUT.get(),
|
|
475
|
+
*handler.get(),
|
|
476
|
+
scaler);
|
|
477
|
+
}
|
|
478
|
+
if (!(skip & 8)) {
|
|
479
|
+
handler->end();
|
|
486
480
|
}
|
|
487
481
|
}
|
|
488
482
|
|
|
489
483
|
FastScanStats FastScan_stats;
|
|
490
484
|
|
|
491
|
-
template <class C
|
|
485
|
+
template <class C>
|
|
492
486
|
void IndexFastScan::search_implem_14(
|
|
493
487
|
idx_t n,
|
|
494
488
|
const float* x,
|
|
@@ -496,7 +490,8 @@ void IndexFastScan::search_implem_14(
|
|
|
496
490
|
float* distances,
|
|
497
491
|
idx_t* labels,
|
|
498
492
|
int impl,
|
|
499
|
-
const
|
|
493
|
+
const NormTableScaler* scaler) const {
|
|
494
|
+
using RH = ResultHandlerCompare<C, false>;
|
|
500
495
|
FAISS_THROW_IF_NOT(bbs % 32 == 0);
|
|
501
496
|
|
|
502
497
|
int qbs2 = qbs == 0 ? 4 : qbs;
|
|
@@ -531,90 +526,44 @@ void IndexFastScan::search_implem_14(
|
|
|
531
526
|
AlignedTable<uint8_t> LUT(n * dim12);
|
|
532
527
|
pq4_pack_LUT(n, M2, quantized_dis_tables.get(), LUT.get());
|
|
533
528
|
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
} else {
|
|
539
|
-
handler.disable = bool(skip & 2);
|
|
540
|
-
pq4_accumulate_loop(
|
|
541
|
-
n,
|
|
542
|
-
ntotal2,
|
|
543
|
-
bbs,
|
|
544
|
-
M2,
|
|
545
|
-
codes.get(),
|
|
546
|
-
LUT.get(),
|
|
547
|
-
handler,
|
|
548
|
-
scaler);
|
|
549
|
-
}
|
|
550
|
-
handler.to_flat_arrays(distances, labels, normalizers.get());
|
|
551
|
-
|
|
552
|
-
} else if (impl == 14) {
|
|
553
|
-
std::vector<uint16_t> tmp_dis(n * k);
|
|
554
|
-
std::vector<int32_t> tmp_ids(n * k);
|
|
555
|
-
|
|
556
|
-
if (skip & 4) {
|
|
557
|
-
// skip
|
|
558
|
-
} else if (k > 1) {
|
|
559
|
-
HeapHandler<C> handler(
|
|
560
|
-
n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
|
|
561
|
-
handler.disable = bool(skip & 2);
|
|
562
|
-
|
|
563
|
-
pq4_accumulate_loop(
|
|
564
|
-
n,
|
|
565
|
-
ntotal2,
|
|
566
|
-
bbs,
|
|
567
|
-
M2,
|
|
568
|
-
codes.get(),
|
|
569
|
-
LUT.get(),
|
|
570
|
-
handler,
|
|
571
|
-
scaler);
|
|
572
|
-
|
|
573
|
-
if (!(skip & 8)) {
|
|
574
|
-
handler.to_flat_arrays(distances, labels, normalizers.get());
|
|
575
|
-
}
|
|
576
|
-
}
|
|
577
|
-
|
|
578
|
-
} else { // impl == 15
|
|
529
|
+
std::unique_ptr<RH> handler(
|
|
530
|
+
make_knn_handler<C>(impl, n, k, ntotal, distances, labels));
|
|
531
|
+
handler->disable = bool(skip & 2);
|
|
532
|
+
handler->normalizers = normalizers.get();
|
|
579
533
|
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
}
|
|
596
|
-
|
|
597
|
-
if (!(skip & 8)) {
|
|
598
|
-
handler.to_flat_arrays(distances, labels, normalizers.get());
|
|
599
|
-
}
|
|
534
|
+
if (skip & 4) {
|
|
535
|
+
// pass
|
|
536
|
+
} else {
|
|
537
|
+
pq4_accumulate_loop(
|
|
538
|
+
n,
|
|
539
|
+
ntotal2,
|
|
540
|
+
bbs,
|
|
541
|
+
M2,
|
|
542
|
+
codes.get(),
|
|
543
|
+
LUT.get(),
|
|
544
|
+
*handler.get(),
|
|
545
|
+
scaler);
|
|
546
|
+
}
|
|
547
|
+
if (!(skip & 8)) {
|
|
548
|
+
handler->end();
|
|
600
549
|
}
|
|
601
550
|
}
|
|
602
551
|
|
|
603
|
-
template void IndexFastScan::search_dispatch_implem<true
|
|
552
|
+
template void IndexFastScan::search_dispatch_implem<true>(
|
|
604
553
|
idx_t n,
|
|
605
554
|
const float* x,
|
|
606
555
|
idx_t k,
|
|
607
556
|
float* distances,
|
|
608
557
|
idx_t* labels,
|
|
609
|
-
const NormTableScaler
|
|
558
|
+
const NormTableScaler* scaler) const;
|
|
610
559
|
|
|
611
|
-
template void IndexFastScan::search_dispatch_implem<false
|
|
560
|
+
template void IndexFastScan::search_dispatch_implem<false>(
|
|
612
561
|
idx_t n,
|
|
613
562
|
const float* x,
|
|
614
563
|
idx_t k,
|
|
615
564
|
float* distances,
|
|
616
565
|
idx_t* labels,
|
|
617
|
-
const NormTableScaler
|
|
566
|
+
const NormTableScaler* scaler) const;
|
|
618
567
|
|
|
619
568
|
void IndexFastScan::reconstruct(idx_t key, float* recons) const {
|
|
620
569
|
std::vector<uint8_t> code(code_size, 0);
|
|
@@ -13,6 +13,7 @@
|
|
|
13
13
|
namespace faiss {
|
|
14
14
|
|
|
15
15
|
struct CodePacker;
|
|
16
|
+
struct NormTableScaler;
|
|
16
17
|
|
|
17
18
|
/** Fast scan version of IndexPQ and IndexAQ. Works for 4-bit PQ and AQ for now.
|
|
18
19
|
*
|
|
@@ -87,25 +88,25 @@ struct IndexFastScan : Index {
|
|
|
87
88
|
uint8_t* lut,
|
|
88
89
|
float* normalizers) const;
|
|
89
90
|
|
|
90
|
-
template <bool is_max
|
|
91
|
+
template <bool is_max>
|
|
91
92
|
void search_dispatch_implem(
|
|
92
93
|
idx_t n,
|
|
93
94
|
const float* x,
|
|
94
95
|
idx_t k,
|
|
95
96
|
float* distances,
|
|
96
97
|
idx_t* labels,
|
|
97
|
-
const
|
|
98
|
+
const NormTableScaler* scaler) const;
|
|
98
99
|
|
|
99
|
-
template <class Cfloat
|
|
100
|
+
template <class Cfloat>
|
|
100
101
|
void search_implem_234(
|
|
101
102
|
idx_t n,
|
|
102
103
|
const float* x,
|
|
103
104
|
idx_t k,
|
|
104
105
|
float* distances,
|
|
105
106
|
idx_t* labels,
|
|
106
|
-
const
|
|
107
|
+
const NormTableScaler* scaler) const;
|
|
107
108
|
|
|
108
|
-
template <class C
|
|
109
|
+
template <class C>
|
|
109
110
|
void search_implem_12(
|
|
110
111
|
idx_t n,
|
|
111
112
|
const float* x,
|
|
@@ -113,9 +114,9 @@ struct IndexFastScan : Index {
|
|
|
113
114
|
float* distances,
|
|
114
115
|
idx_t* labels,
|
|
115
116
|
int impl,
|
|
116
|
-
const
|
|
117
|
+
const NormTableScaler* scaler) const;
|
|
117
118
|
|
|
118
|
-
template <class C
|
|
119
|
+
template <class C>
|
|
119
120
|
void search_implem_14(
|
|
120
121
|
idx_t n,
|
|
121
122
|
const float* x,
|
|
@@ -123,7 +124,7 @@ struct IndexFastScan : Index {
|
|
|
123
124
|
float* distances,
|
|
124
125
|
idx_t* labels,
|
|
125
126
|
int impl,
|
|
126
|
-
const
|
|
127
|
+
const NormTableScaler* scaler) const;
|
|
127
128
|
|
|
128
129
|
void reconstruct(idx_t key, float* recons) const override;
|
|
129
130
|
size_t remove_ids(const IDSelector& sel) override;
|
|
@@ -14,6 +14,7 @@
|
|
|
14
14
|
#include <faiss/utils/Heap.h>
|
|
15
15
|
#include <faiss/utils/distances.h>
|
|
16
16
|
#include <faiss/utils/extra_distances.h>
|
|
17
|
+
#include <faiss/utils/prefetch.h>
|
|
17
18
|
#include <faiss/utils/sorting.h>
|
|
18
19
|
#include <faiss/utils/utils.h>
|
|
19
20
|
#include <cstring>
|
|
@@ -40,15 +41,19 @@ void IndexFlat::search(
|
|
|
40
41
|
} else if (metric_type == METRIC_L2) {
|
|
41
42
|
float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
|
|
42
43
|
knn_L2sqr(x, get_xb(), d, n, ntotal, &res, nullptr, sel);
|
|
43
|
-
} else if (is_similarity_metric(metric_type)) {
|
|
44
|
-
float_minheap_array_t res = {size_t(n), size_t(k), labels, distances};
|
|
45
|
-
knn_extra_metrics(
|
|
46
|
-
x, get_xb(), d, n, ntotal, metric_type, metric_arg, &res);
|
|
47
44
|
} else {
|
|
48
|
-
FAISS_THROW_IF_NOT(!sel);
|
|
49
|
-
float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
|
|
45
|
+
FAISS_THROW_IF_NOT(!sel); // TODO implement with selector
|
|
50
46
|
knn_extra_metrics(
|
|
51
|
-
x,
|
|
47
|
+
x,
|
|
48
|
+
get_xb(),
|
|
49
|
+
d,
|
|
50
|
+
n,
|
|
51
|
+
ntotal,
|
|
52
|
+
metric_type,
|
|
53
|
+
metric_arg,
|
|
54
|
+
k,
|
|
55
|
+
distances,
|
|
56
|
+
labels);
|
|
52
57
|
}
|
|
53
58
|
}
|
|
54
59
|
|
|
@@ -122,6 +127,39 @@ struct FlatL2Dis : FlatCodesDistanceComputer {
|
|
|
122
127
|
void set_query(const float* x) override {
|
|
123
128
|
q = x;
|
|
124
129
|
}
|
|
130
|
+
|
|
131
|
+
// compute four distances
|
|
132
|
+
void distances_batch_4(
|
|
133
|
+
const idx_t idx0,
|
|
134
|
+
const idx_t idx1,
|
|
135
|
+
const idx_t idx2,
|
|
136
|
+
const idx_t idx3,
|
|
137
|
+
float& dis0,
|
|
138
|
+
float& dis1,
|
|
139
|
+
float& dis2,
|
|
140
|
+
float& dis3) final override {
|
|
141
|
+
ndis += 4;
|
|
142
|
+
|
|
143
|
+
// compute first, assign next
|
|
144
|
+
const float* __restrict y0 =
|
|
145
|
+
reinterpret_cast<const float*>(codes + idx0 * code_size);
|
|
146
|
+
const float* __restrict y1 =
|
|
147
|
+
reinterpret_cast<const float*>(codes + idx1 * code_size);
|
|
148
|
+
const float* __restrict y2 =
|
|
149
|
+
reinterpret_cast<const float*>(codes + idx2 * code_size);
|
|
150
|
+
const float* __restrict y3 =
|
|
151
|
+
reinterpret_cast<const float*>(codes + idx3 * code_size);
|
|
152
|
+
|
|
153
|
+
float dp0 = 0;
|
|
154
|
+
float dp1 = 0;
|
|
155
|
+
float dp2 = 0;
|
|
156
|
+
float dp3 = 0;
|
|
157
|
+
fvec_L2sqr_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
|
|
158
|
+
dis0 = dp0;
|
|
159
|
+
dis1 = dp1;
|
|
160
|
+
dis2 = dp2;
|
|
161
|
+
dis3 = dp3;
|
|
162
|
+
}
|
|
125
163
|
};
|
|
126
164
|
|
|
127
165
|
struct FlatIPDis : FlatCodesDistanceComputer {
|
|
@@ -131,13 +169,13 @@ struct FlatIPDis : FlatCodesDistanceComputer {
|
|
|
131
169
|
const float* b;
|
|
132
170
|
size_t ndis;
|
|
133
171
|
|
|
134
|
-
float symmetric_dis(idx_t i, idx_t j) override {
|
|
172
|
+
float symmetric_dis(idx_t i, idx_t j) final override {
|
|
135
173
|
return fvec_inner_product(b + j * d, b + i * d, d);
|
|
136
174
|
}
|
|
137
175
|
|
|
138
|
-
float distance_to_code(const uint8_t* code) final {
|
|
176
|
+
float distance_to_code(const uint8_t* code) final override {
|
|
139
177
|
ndis++;
|
|
140
|
-
return fvec_inner_product(q, (float*)code, d);
|
|
178
|
+
return fvec_inner_product(q, (const float*)code, d);
|
|
141
179
|
}
|
|
142
180
|
|
|
143
181
|
explicit FlatIPDis(const IndexFlat& storage, const float* q = nullptr)
|
|
@@ -153,6 +191,39 @@ struct FlatIPDis : FlatCodesDistanceComputer {
|
|
|
153
191
|
void set_query(const float* x) override {
|
|
154
192
|
q = x;
|
|
155
193
|
}
|
|
194
|
+
|
|
195
|
+
// compute four distances
|
|
196
|
+
void distances_batch_4(
|
|
197
|
+
const idx_t idx0,
|
|
198
|
+
const idx_t idx1,
|
|
199
|
+
const idx_t idx2,
|
|
200
|
+
const idx_t idx3,
|
|
201
|
+
float& dis0,
|
|
202
|
+
float& dis1,
|
|
203
|
+
float& dis2,
|
|
204
|
+
float& dis3) final override {
|
|
205
|
+
ndis += 4;
|
|
206
|
+
|
|
207
|
+
// compute first, assign next
|
|
208
|
+
const float* __restrict y0 =
|
|
209
|
+
reinterpret_cast<const float*>(codes + idx0 * code_size);
|
|
210
|
+
const float* __restrict y1 =
|
|
211
|
+
reinterpret_cast<const float*>(codes + idx1 * code_size);
|
|
212
|
+
const float* __restrict y2 =
|
|
213
|
+
reinterpret_cast<const float*>(codes + idx2 * code_size);
|
|
214
|
+
const float* __restrict y3 =
|
|
215
|
+
reinterpret_cast<const float*>(codes + idx3 * code_size);
|
|
216
|
+
|
|
217
|
+
float dp0 = 0;
|
|
218
|
+
float dp1 = 0;
|
|
219
|
+
float dp2 = 0;
|
|
220
|
+
float dp3 = 0;
|
|
221
|
+
fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
|
|
222
|
+
dis0 = dp0;
|
|
223
|
+
dis1 = dp1;
|
|
224
|
+
dis2 = dp2;
|
|
225
|
+
dis3 = dp3;
|
|
226
|
+
}
|
|
156
227
|
};
|
|
157
228
|
|
|
158
229
|
} // namespace
|
|
@@ -184,6 +255,131 @@ void IndexFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
|
|
|
184
255
|
}
|
|
185
256
|
}
|
|
186
257
|
|
|
258
|
+
/***************************************************
|
|
259
|
+
* IndexFlatL2
|
|
260
|
+
***************************************************/
|
|
261
|
+
|
|
262
|
+
namespace {
|
|
263
|
+
struct FlatL2WithNormsDis : FlatCodesDistanceComputer {
|
|
264
|
+
size_t d;
|
|
265
|
+
idx_t nb;
|
|
266
|
+
const float* q;
|
|
267
|
+
const float* b;
|
|
268
|
+
size_t ndis;
|
|
269
|
+
|
|
270
|
+
const float* l2norms;
|
|
271
|
+
float query_l2norm;
|
|
272
|
+
|
|
273
|
+
float distance_to_code(const uint8_t* code) final override {
|
|
274
|
+
ndis++;
|
|
275
|
+
return fvec_L2sqr(q, (float*)code, d);
|
|
276
|
+
}
|
|
277
|
+
|
|
278
|
+
float operator()(const idx_t i) final override {
|
|
279
|
+
const float* __restrict y =
|
|
280
|
+
reinterpret_cast<const float*>(codes + i * code_size);
|
|
281
|
+
|
|
282
|
+
prefetch_L2(l2norms + i);
|
|
283
|
+
const float dp0 = fvec_inner_product(q, y, d);
|
|
284
|
+
return query_l2norm + l2norms[i] - 2 * dp0;
|
|
285
|
+
}
|
|
286
|
+
|
|
287
|
+
float symmetric_dis(idx_t i, idx_t j) final override {
|
|
288
|
+
const float* __restrict yi =
|
|
289
|
+
reinterpret_cast<const float*>(codes + i * code_size);
|
|
290
|
+
const float* __restrict yj =
|
|
291
|
+
reinterpret_cast<const float*>(codes + j * code_size);
|
|
292
|
+
|
|
293
|
+
prefetch_L2(l2norms + i);
|
|
294
|
+
prefetch_L2(l2norms + j);
|
|
295
|
+
const float dp0 = fvec_inner_product(yi, yj, d);
|
|
296
|
+
return l2norms[i] + l2norms[j] - 2 * dp0;
|
|
297
|
+
}
|
|
298
|
+
|
|
299
|
+
explicit FlatL2WithNormsDis(
|
|
300
|
+
const IndexFlatL2& storage,
|
|
301
|
+
const float* q = nullptr)
|
|
302
|
+
: FlatCodesDistanceComputer(
|
|
303
|
+
storage.codes.data(),
|
|
304
|
+
storage.code_size),
|
|
305
|
+
d(storage.d),
|
|
306
|
+
nb(storage.ntotal),
|
|
307
|
+
q(q),
|
|
308
|
+
b(storage.get_xb()),
|
|
309
|
+
ndis(0),
|
|
310
|
+
l2norms(storage.cached_l2norms.data()),
|
|
311
|
+
query_l2norm(0) {}
|
|
312
|
+
|
|
313
|
+
void set_query(const float* x) override {
|
|
314
|
+
q = x;
|
|
315
|
+
query_l2norm = fvec_norm_L2sqr(q, d);
|
|
316
|
+
}
|
|
317
|
+
|
|
318
|
+
// compute four distances
|
|
319
|
+
void distances_batch_4(
|
|
320
|
+
const idx_t idx0,
|
|
321
|
+
const idx_t idx1,
|
|
322
|
+
const idx_t idx2,
|
|
323
|
+
const idx_t idx3,
|
|
324
|
+
float& dis0,
|
|
325
|
+
float& dis1,
|
|
326
|
+
float& dis2,
|
|
327
|
+
float& dis3) final override {
|
|
328
|
+
ndis += 4;
|
|
329
|
+
|
|
330
|
+
// compute first, assign next
|
|
331
|
+
const float* __restrict y0 =
|
|
332
|
+
reinterpret_cast<const float*>(codes + idx0 * code_size);
|
|
333
|
+
const float* __restrict y1 =
|
|
334
|
+
reinterpret_cast<const float*>(codes + idx1 * code_size);
|
|
335
|
+
const float* __restrict y2 =
|
|
336
|
+
reinterpret_cast<const float*>(codes + idx2 * code_size);
|
|
337
|
+
const float* __restrict y3 =
|
|
338
|
+
reinterpret_cast<const float*>(codes + idx3 * code_size);
|
|
339
|
+
|
|
340
|
+
prefetch_L2(l2norms + idx0);
|
|
341
|
+
prefetch_L2(l2norms + idx1);
|
|
342
|
+
prefetch_L2(l2norms + idx2);
|
|
343
|
+
prefetch_L2(l2norms + idx3);
|
|
344
|
+
|
|
345
|
+
float dp0 = 0;
|
|
346
|
+
float dp1 = 0;
|
|
347
|
+
float dp2 = 0;
|
|
348
|
+
float dp3 = 0;
|
|
349
|
+
fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
|
|
350
|
+
dis0 = query_l2norm + l2norms[idx0] - 2 * dp0;
|
|
351
|
+
dis1 = query_l2norm + l2norms[idx1] - 2 * dp1;
|
|
352
|
+
dis2 = query_l2norm + l2norms[idx2] - 2 * dp2;
|
|
353
|
+
dis3 = query_l2norm + l2norms[idx3] - 2 * dp3;
|
|
354
|
+
}
|
|
355
|
+
};
|
|
356
|
+
|
|
357
|
+
} // namespace
|
|
358
|
+
|
|
359
|
+
void IndexFlatL2::sync_l2norms() {
|
|
360
|
+
cached_l2norms.resize(ntotal);
|
|
361
|
+
fvec_norms_L2sqr(
|
|
362
|
+
cached_l2norms.data(),
|
|
363
|
+
reinterpret_cast<const float*>(codes.data()),
|
|
364
|
+
d,
|
|
365
|
+
ntotal);
|
|
366
|
+
}
|
|
367
|
+
|
|
368
|
+
void IndexFlatL2::clear_l2norms() {
|
|
369
|
+
cached_l2norms.clear();
|
|
370
|
+
cached_l2norms.shrink_to_fit();
|
|
371
|
+
}
|
|
372
|
+
|
|
373
|
+
FlatCodesDistanceComputer* IndexFlatL2::get_FlatCodesDistanceComputer() const {
|
|
374
|
+
if (metric_type == METRIC_L2) {
|
|
375
|
+
if (!cached_l2norms.empty()) {
|
|
376
|
+
return new FlatL2WithNormsDis(*this);
|
|
377
|
+
}
|
|
378
|
+
}
|
|
379
|
+
|
|
380
|
+
return IndexFlat::get_FlatCodesDistanceComputer();
|
|
381
|
+
}
|
|
382
|
+
|
|
187
383
|
/***************************************************
|
|
188
384
|
* IndexFlat1D
|
|
189
385
|
***************************************************/
|