faiss 0.3.0 → 0.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +1 -2
- data/vendor/faiss/faiss/Clustering.cpp +39 -22
- data/vendor/faiss/faiss/Clustering.h +40 -21
- data/vendor/faiss/faiss/IVFlib.cpp +26 -12
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +40 -10
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
- data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
- data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
- data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
- data/vendor/faiss/faiss/IndexHNSW.h +62 -49
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
- data/vendor/faiss/faiss/IndexIVF.h +46 -6
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
- data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
- data/vendor/faiss/faiss/IndexLattice.h +3 -22
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
- data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +11 -11
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/MetricType.h +7 -2
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
- data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
- data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
- data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
- data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
- data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
- data/vendor/faiss/faiss/impl/HNSW.h +52 -30
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
- data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
- data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
- data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
- data/vendor/faiss/faiss/impl/io.cpp +23 -15
- data/vendor/faiss/faiss/impl/io.h +4 -4
- data/vendor/faiss/faiss/impl/io_macros.h +6 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
- data/vendor/faiss/faiss/index_factory.cpp +41 -20
- data/vendor/faiss/faiss/index_io.h +12 -5
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
- data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/Heap.h +105 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
- data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/bf16.h +36 -0
- data/vendor/faiss/faiss/utils/distances.cpp +147 -123
- data/vendor/faiss/faiss/utils/distances.h +86 -9
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/random.cpp +43 -0
- data/vendor/faiss/faiss/utils/random.h +25 -0
- data/vendor/faiss/faiss/utils/simdlib.h +10 -1
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
- data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
- data/vendor/faiss/faiss/utils/utils.cpp +120 -7
- data/vendor/faiss/faiss/utils/utils.h +60 -20
- metadata +23 -4
- data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
|
@@ -5,13 +5,12 @@
|
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
|
6
6
|
*/
|
|
7
7
|
|
|
8
|
-
// -*- c++ -*-
|
|
9
|
-
|
|
10
8
|
#include <faiss/utils/distances.h>
|
|
11
9
|
|
|
12
10
|
#include <algorithm>
|
|
13
11
|
#include <cassert>
|
|
14
12
|
#include <cmath>
|
|
13
|
+
#include <cstddef>
|
|
15
14
|
#include <cstdio>
|
|
16
15
|
#include <cstring>
|
|
17
16
|
|
|
@@ -64,7 +63,7 @@ void fvec_norms_L2(
|
|
|
64
63
|
const float* __restrict x,
|
|
65
64
|
size_t d,
|
|
66
65
|
size_t nx) {
|
|
67
|
-
#pragma omp parallel for
|
|
66
|
+
#pragma omp parallel for if (nx > 10000)
|
|
68
67
|
for (int64_t i = 0; i < nx; i++) {
|
|
69
68
|
nr[i] = sqrtf(fvec_norm_L2sqr(x + i * d, d));
|
|
70
69
|
}
|
|
@@ -75,24 +74,52 @@ void fvec_norms_L2sqr(
|
|
|
75
74
|
const float* __restrict x,
|
|
76
75
|
size_t d,
|
|
77
76
|
size_t nx) {
|
|
78
|
-
#pragma omp parallel for
|
|
77
|
+
#pragma omp parallel for if (nx > 10000)
|
|
79
78
|
for (int64_t i = 0; i < nx; i++)
|
|
80
79
|
nr[i] = fvec_norm_L2sqr(x + i * d, d);
|
|
81
80
|
}
|
|
82
81
|
|
|
83
|
-
|
|
84
|
-
|
|
82
|
+
// The following is a workaround to a problem
|
|
83
|
+
// in OpenMP in fbcode. The crash occurs
|
|
84
|
+
// inside OMP when IndexIVFSpectralHash::set_query()
|
|
85
|
+
// calls fvec_renorm_L2. set_query() is always
|
|
86
|
+
// calling this function with nx == 1, so even
|
|
87
|
+
// the omp version should run single threaded,
|
|
88
|
+
// as per the if condition of the omp pragma.
|
|
89
|
+
// Instead, the omp version crashes inside OMP.
|
|
90
|
+
// The workaround below is explicitly branching
|
|
91
|
+
// off to a codepath without omp.
|
|
92
|
+
|
|
93
|
+
#define FVEC_RENORM_L2_IMPL \
|
|
94
|
+
float* __restrict xi = x + i * d; \
|
|
95
|
+
\
|
|
96
|
+
float nr = fvec_norm_L2sqr(xi, d); \
|
|
97
|
+
\
|
|
98
|
+
if (nr > 0) { \
|
|
99
|
+
size_t j; \
|
|
100
|
+
const float inv_nr = 1.0 / sqrtf(nr); \
|
|
101
|
+
for (j = 0; j < d; j++) \
|
|
102
|
+
xi[j] *= inv_nr; \
|
|
103
|
+
}
|
|
104
|
+
|
|
105
|
+
void fvec_renorm_L2_noomp(size_t d, size_t nx, float* __restrict x) {
|
|
85
106
|
for (int64_t i = 0; i < nx; i++) {
|
|
86
|
-
|
|
107
|
+
FVEC_RENORM_L2_IMPL
|
|
108
|
+
}
|
|
109
|
+
}
|
|
87
110
|
|
|
88
|
-
|
|
111
|
+
void fvec_renorm_L2_omp(size_t d, size_t nx, float* __restrict x) {
|
|
112
|
+
#pragma omp parallel for if (nx > 10000)
|
|
113
|
+
for (int64_t i = 0; i < nx; i++) {
|
|
114
|
+
FVEC_RENORM_L2_IMPL
|
|
115
|
+
}
|
|
116
|
+
}
|
|
89
117
|
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
}
|
|
118
|
+
void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
|
|
119
|
+
if (nx <= 10000) {
|
|
120
|
+
fvec_renorm_L2_noomp(d, nx, x);
|
|
121
|
+
} else {
|
|
122
|
+
fvec_renorm_L2_omp(d, nx, x);
|
|
96
123
|
}
|
|
97
124
|
}
|
|
98
125
|
|
|
@@ -103,19 +130,17 @@ void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
|
|
|
103
130
|
namespace {
|
|
104
131
|
|
|
105
132
|
/* Find the nearest neighbors for nx queries in a set of ny vectors */
|
|
106
|
-
template <class
|
|
133
|
+
template <class BlockResultHandler>
|
|
107
134
|
void exhaustive_inner_product_seq(
|
|
108
135
|
const float* x,
|
|
109
136
|
const float* y,
|
|
110
137
|
size_t d,
|
|
111
138
|
size_t nx,
|
|
112
139
|
size_t ny,
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
int nt = std::min(int(nx), omp_get_max_threads());
|
|
117
|
-
|
|
118
|
-
FAISS_ASSERT(use_sel == (sel != nullptr));
|
|
140
|
+
BlockResultHandler& res) {
|
|
141
|
+
using SingleResultHandler =
|
|
142
|
+
typename BlockResultHandler::SingleResultHandler;
|
|
143
|
+
[[maybe_unused]] int nt = std::min(int(nx), omp_get_max_threads());
|
|
119
144
|
|
|
120
145
|
#pragma omp parallel num_threads(nt)
|
|
121
146
|
{
|
|
@@ -128,7 +153,7 @@ void exhaustive_inner_product_seq(
|
|
|
128
153
|
resi.begin(i);
|
|
129
154
|
|
|
130
155
|
for (size_t j = 0; j < ny; j++, y_j += d) {
|
|
131
|
-
if (
|
|
156
|
+
if (!res.is_in_selection(j)) {
|
|
132
157
|
continue;
|
|
133
158
|
}
|
|
134
159
|
float ip = fvec_inner_product(x_i, y_j, d);
|
|
@@ -139,19 +164,17 @@ void exhaustive_inner_product_seq(
|
|
|
139
164
|
}
|
|
140
165
|
}
|
|
141
166
|
|
|
142
|
-
template <class
|
|
167
|
+
template <class BlockResultHandler>
|
|
143
168
|
void exhaustive_L2sqr_seq(
|
|
144
169
|
const float* x,
|
|
145
170
|
const float* y,
|
|
146
171
|
size_t d,
|
|
147
172
|
size_t nx,
|
|
148
173
|
size_t ny,
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
int nt = std::min(int(nx), omp_get_max_threads());
|
|
153
|
-
|
|
154
|
-
FAISS_ASSERT(use_sel == (sel != nullptr));
|
|
174
|
+
BlockResultHandler& res) {
|
|
175
|
+
using SingleResultHandler =
|
|
176
|
+
typename BlockResultHandler::SingleResultHandler;
|
|
177
|
+
[[maybe_unused]] int nt = std::min(int(nx), omp_get_max_threads());
|
|
155
178
|
|
|
156
179
|
#pragma omp parallel num_threads(nt)
|
|
157
180
|
{
|
|
@@ -162,7 +185,7 @@ void exhaustive_L2sqr_seq(
|
|
|
162
185
|
const float* y_j = y;
|
|
163
186
|
resi.begin(i);
|
|
164
187
|
for (size_t j = 0; j < ny; j++, y_j += d) {
|
|
165
|
-
if (
|
|
188
|
+
if (!res.is_in_selection(j)) {
|
|
166
189
|
continue;
|
|
167
190
|
}
|
|
168
191
|
float disij = fvec_L2sqr(x_i, y_j, d);
|
|
@@ -174,14 +197,14 @@ void exhaustive_L2sqr_seq(
|
|
|
174
197
|
}
|
|
175
198
|
|
|
176
199
|
/** Find the nearest neighbors for nx queries in a set of ny vectors */
|
|
177
|
-
template <class
|
|
200
|
+
template <class BlockResultHandler>
|
|
178
201
|
void exhaustive_inner_product_blas(
|
|
179
202
|
const float* x,
|
|
180
203
|
const float* y,
|
|
181
204
|
size_t d,
|
|
182
205
|
size_t nx,
|
|
183
206
|
size_t ny,
|
|
184
|
-
|
|
207
|
+
BlockResultHandler& res) {
|
|
185
208
|
// BLAS does not like empty matrices
|
|
186
209
|
if (nx == 0 || ny == 0)
|
|
187
210
|
return;
|
|
@@ -230,14 +253,14 @@ void exhaustive_inner_product_blas(
|
|
|
230
253
|
|
|
231
254
|
// distance correction is an operator that can be applied to transform
|
|
232
255
|
// the distances
|
|
233
|
-
template <class
|
|
256
|
+
template <class BlockResultHandler>
|
|
234
257
|
void exhaustive_L2sqr_blas_default_impl(
|
|
235
258
|
const float* x,
|
|
236
259
|
const float* y,
|
|
237
260
|
size_t d,
|
|
238
261
|
size_t nx,
|
|
239
262
|
size_t ny,
|
|
240
|
-
|
|
263
|
+
BlockResultHandler& res,
|
|
241
264
|
const float* y_norms = nullptr) {
|
|
242
265
|
// BLAS does not like empty matrices
|
|
243
266
|
if (nx == 0 || ny == 0)
|
|
@@ -297,6 +320,9 @@ void exhaustive_L2sqr_blas_default_impl(
|
|
|
297
320
|
float ip = *ip_line;
|
|
298
321
|
float dis = x_norms[i] + y_norms[j] - 2 * ip;
|
|
299
322
|
|
|
323
|
+
if (!res.is_in_selection(j)) {
|
|
324
|
+
dis = HUGE_VALF;
|
|
325
|
+
}
|
|
300
326
|
// negative values can occur for identical vectors
|
|
301
327
|
// due to roundoff errors
|
|
302
328
|
if (dis < 0)
|
|
@@ -313,14 +339,14 @@ void exhaustive_L2sqr_blas_default_impl(
|
|
|
313
339
|
}
|
|
314
340
|
}
|
|
315
341
|
|
|
316
|
-
template <class
|
|
342
|
+
template <class BlockResultHandler>
|
|
317
343
|
void exhaustive_L2sqr_blas(
|
|
318
344
|
const float* x,
|
|
319
345
|
const float* y,
|
|
320
346
|
size_t d,
|
|
321
347
|
size_t nx,
|
|
322
348
|
size_t ny,
|
|
323
|
-
|
|
349
|
+
BlockResultHandler& res,
|
|
324
350
|
const float* y_norms = nullptr) {
|
|
325
351
|
exhaustive_L2sqr_blas_default_impl(x, y, d, nx, ny, res);
|
|
326
352
|
}
|
|
@@ -332,7 +358,7 @@ void exhaustive_L2sqr_blas_cmax_avx2(
|
|
|
332
358
|
size_t d,
|
|
333
359
|
size_t nx,
|
|
334
360
|
size_t ny,
|
|
335
|
-
|
|
361
|
+
Top1BlockResultHandler<CMax<float, int64_t>>& res,
|
|
336
362
|
const float* y_norms) {
|
|
337
363
|
// BLAS does not like empty matrices
|
|
338
364
|
if (nx == 0 || ny == 0)
|
|
@@ -388,8 +414,8 @@ void exhaustive_L2sqr_blas_cmax_avx2(
|
|
|
388
414
|
for (int64_t i = i0; i < i1; i++) {
|
|
389
415
|
float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
|
|
390
416
|
|
|
391
|
-
_mm_prefetch(ip_line, _MM_HINT_NTA);
|
|
392
|
-
_mm_prefetch(ip_line + 16, _MM_HINT_NTA);
|
|
417
|
+
_mm_prefetch((const char*)ip_line, _MM_HINT_NTA);
|
|
418
|
+
_mm_prefetch((const char*)(ip_line + 16), _MM_HINT_NTA);
|
|
393
419
|
|
|
394
420
|
// constant
|
|
395
421
|
const __m256 mul_minus2 = _mm256_set1_ps(-2);
|
|
@@ -416,8 +442,8 @@ void exhaustive_L2sqr_blas_cmax_avx2(
|
|
|
416
442
|
|
|
417
443
|
// process 16 elements per loop
|
|
418
444
|
for (; idx_j < (count / 16) * 16; idx_j += 16, ip_line += 16) {
|
|
419
|
-
_mm_prefetch(ip_line + 32, _MM_HINT_NTA);
|
|
420
|
-
_mm_prefetch(ip_line + 48, _MM_HINT_NTA);
|
|
445
|
+
_mm_prefetch((const char*)(ip_line + 32), _MM_HINT_NTA);
|
|
446
|
+
_mm_prefetch((const char*)(ip_line + 48), _MM_HINT_NTA);
|
|
421
447
|
|
|
422
448
|
// load values for norms
|
|
423
449
|
const __m256 y_norm_0 =
|
|
@@ -535,13 +561,13 @@ void exhaustive_L2sqr_blas_cmax_avx2(
|
|
|
535
561
|
|
|
536
562
|
// an override if only a single closest point is needed
|
|
537
563
|
template <>
|
|
538
|
-
void exhaustive_L2sqr_blas<
|
|
564
|
+
void exhaustive_L2sqr_blas<Top1BlockResultHandler<CMax<float, int64_t>>>(
|
|
539
565
|
const float* x,
|
|
540
566
|
const float* y,
|
|
541
567
|
size_t d,
|
|
542
568
|
size_t nx,
|
|
543
569
|
size_t ny,
|
|
544
|
-
|
|
570
|
+
Top1BlockResultHandler<CMax<float, int64_t>>& res,
|
|
545
571
|
const float* y_norms) {
|
|
546
572
|
#if defined(__AVX2__)
|
|
547
573
|
// use a faster fused kernel if available
|
|
@@ -562,34 +588,50 @@ void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>(
|
|
|
562
588
|
|
|
563
589
|
// run the default implementation
|
|
564
590
|
exhaustive_L2sqr_blas_default_impl<
|
|
565
|
-
|
|
591
|
+
Top1BlockResultHandler<CMax<float, int64_t>>>(
|
|
566
592
|
x, y, d, nx, ny, res, y_norms);
|
|
567
593
|
#else
|
|
568
594
|
// run the default implementation
|
|
569
595
|
exhaustive_L2sqr_blas_default_impl<
|
|
570
|
-
|
|
596
|
+
Top1BlockResultHandler<CMax<float, int64_t>>>(
|
|
571
597
|
x, y, d, nx, ny, res, y_norms);
|
|
572
598
|
#endif
|
|
573
599
|
}
|
|
574
600
|
|
|
575
|
-
|
|
576
|
-
void
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
} else {
|
|
590
|
-
exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
|
|
601
|
+
struct Run_search_inner_product {
|
|
602
|
+
using T = void;
|
|
603
|
+
template <class BlockResultHandler>
|
|
604
|
+
void f(BlockResultHandler& res,
|
|
605
|
+
const float* x,
|
|
606
|
+
const float* y,
|
|
607
|
+
size_t d,
|
|
608
|
+
size_t nx,
|
|
609
|
+
size_t ny) {
|
|
610
|
+
if (res.sel || nx < distance_compute_blas_threshold) {
|
|
611
|
+
exhaustive_inner_product_seq(x, y, d, nx, ny, res);
|
|
612
|
+
} else {
|
|
613
|
+
exhaustive_inner_product_blas(x, y, d, nx, ny, res);
|
|
614
|
+
}
|
|
591
615
|
}
|
|
592
|
-
}
|
|
616
|
+
};
|
|
617
|
+
|
|
618
|
+
struct Run_search_L2sqr {
|
|
619
|
+
using T = void;
|
|
620
|
+
template <class BlockResultHandler>
|
|
621
|
+
void f(BlockResultHandler& res,
|
|
622
|
+
const float* x,
|
|
623
|
+
const float* y,
|
|
624
|
+
size_t d,
|
|
625
|
+
size_t nx,
|
|
626
|
+
size_t ny,
|
|
627
|
+
const float* y_norm2) {
|
|
628
|
+
if (res.sel || nx < distance_compute_blas_threshold) {
|
|
629
|
+
exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
|
|
630
|
+
} else {
|
|
631
|
+
exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
|
|
632
|
+
}
|
|
633
|
+
}
|
|
634
|
+
};
|
|
593
635
|
|
|
594
636
|
} // anonymous namespace
|
|
595
637
|
|
|
@@ -609,7 +651,7 @@ void knn_inner_product(
|
|
|
609
651
|
size_t nx,
|
|
610
652
|
size_t ny,
|
|
611
653
|
size_t k,
|
|
612
|
-
float*
|
|
654
|
+
float* vals,
|
|
613
655
|
int64_t* ids,
|
|
614
656
|
const IDSelector* sel) {
|
|
615
657
|
int64_t imin = 0;
|
|
@@ -622,30 +664,14 @@ void knn_inner_product(
|
|
|
622
664
|
}
|
|
623
665
|
if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
|
|
624
666
|
knn_inner_products_by_idx(
|
|
625
|
-
x, y, sela->ids, d, nx, sela->n, k,
|
|
667
|
+
x, y, sela->ids, d, nx, ny, sela->n, k, vals, ids, 0);
|
|
626
668
|
return;
|
|
627
669
|
}
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
} else if (nx < distance_compute_blas_threshold) {
|
|
634
|
-
exhaustive_inner_product_seq(x, y, d, nx, ny, res);
|
|
635
|
-
} else {
|
|
636
|
-
exhaustive_inner_product_blas(x, y, d, nx, ny, res);
|
|
637
|
-
}
|
|
638
|
-
} else {
|
|
639
|
-
using RH = ReservoirResultHandler<CMin<float, int64_t>>;
|
|
640
|
-
RH res(nx, val, ids, k);
|
|
641
|
-
if (sel) {
|
|
642
|
-
exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, res, sel);
|
|
643
|
-
} else if (nx < distance_compute_blas_threshold) {
|
|
644
|
-
exhaustive_inner_product_seq(x, y, d, nx, ny, res, nullptr);
|
|
645
|
-
} else {
|
|
646
|
-
exhaustive_inner_product_blas(x, y, d, nx, ny, res);
|
|
647
|
-
}
|
|
648
|
-
}
|
|
670
|
+
|
|
671
|
+
Run_search_inner_product r;
|
|
672
|
+
dispatch_knn_ResultHandler(
|
|
673
|
+
nx, vals, ids, k, METRIC_INNER_PRODUCT, sel, r, x, y, d, nx, ny);
|
|
674
|
+
|
|
649
675
|
if (imin != 0) {
|
|
650
676
|
for (size_t i = 0; i < nx * k; i++) {
|
|
651
677
|
if (ids[i] >= 0) {
|
|
@@ -687,19 +713,14 @@ void knn_L2sqr(
|
|
|
687
713
|
sel = nullptr;
|
|
688
714
|
}
|
|
689
715
|
if (auto sela = dynamic_cast<const IDSelectorArray*>(sel)) {
|
|
690
|
-
knn_L2sqr_by_idx(x, y, sela->ids, d, nx, sela->n, k, vals, ids, 0);
|
|
716
|
+
knn_L2sqr_by_idx(x, y, sela->ids, d, nx, ny, sela->n, k, vals, ids, 0);
|
|
691
717
|
return;
|
|
692
718
|
}
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
|
|
699
|
-
} else {
|
|
700
|
-
ReservoirResultHandler<CMax<float, int64_t>> res(nx, vals, ids, k);
|
|
701
|
-
knn_L2sqr_select(x, y, d, nx, ny, res, y_norm2, sel);
|
|
702
|
-
}
|
|
719
|
+
|
|
720
|
+
Run_search_L2sqr r;
|
|
721
|
+
dispatch_knn_ResultHandler(
|
|
722
|
+
nx, vals, ids, k, METRIC_L2, sel, r, x, y, d, nx, ny, y_norm2);
|
|
723
|
+
|
|
703
724
|
if (imin != 0) {
|
|
704
725
|
for (size_t i = 0; i < nx * k; i++) {
|
|
705
726
|
if (ids[i] >= 0) {
|
|
@@ -726,6 +747,7 @@ void knn_L2sqr(
|
|
|
726
747
|
* Range search
|
|
727
748
|
***************************************************************************/
|
|
728
749
|
|
|
750
|
+
// TODO accept a y_norm2 as well
|
|
729
751
|
void range_search_L2sqr(
|
|
730
752
|
const float* x,
|
|
731
753
|
const float* y,
|
|
@@ -735,15 +757,9 @@ void range_search_L2sqr(
|
|
|
735
757
|
float radius,
|
|
736
758
|
RangeSearchResult* res,
|
|
737
759
|
const IDSelector* sel) {
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
exhaustive_L2sqr_seq<RH, true>(x, y, d, nx, ny, resh, sel);
|
|
742
|
-
} else if (nx < distance_compute_blas_threshold) {
|
|
743
|
-
exhaustive_L2sqr_seq(x, y, d, nx, ny, resh, sel);
|
|
744
|
-
} else {
|
|
745
|
-
exhaustive_L2sqr_blas(x, y, d, nx, ny, resh);
|
|
746
|
-
}
|
|
760
|
+
Run_search_L2sqr r;
|
|
761
|
+
dispatch_range_ResultHandler(
|
|
762
|
+
res, radius, METRIC_L2, sel, r, x, y, d, nx, ny, nullptr);
|
|
747
763
|
}
|
|
748
764
|
|
|
749
765
|
void range_search_inner_product(
|
|
@@ -755,15 +771,9 @@ void range_search_inner_product(
|
|
|
755
771
|
float radius,
|
|
756
772
|
RangeSearchResult* res,
|
|
757
773
|
const IDSelector* sel) {
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
exhaustive_inner_product_seq<RH, true>(x, y, d, nx, ny, resh, sel);
|
|
762
|
-
} else if (nx < distance_compute_blas_threshold) {
|
|
763
|
-
exhaustive_inner_product_seq(x, y, d, nx, ny, resh);
|
|
764
|
-
} else {
|
|
765
|
-
exhaustive_inner_product_blas(x, y, d, nx, ny, resh);
|
|
766
|
-
}
|
|
774
|
+
Run_search_inner_product r;
|
|
775
|
+
dispatch_range_ResultHandler(
|
|
776
|
+
res, radius, METRIC_INNER_PRODUCT, sel, r, x, y, d, nx, ny);
|
|
767
777
|
}
|
|
768
778
|
|
|
769
779
|
/***************************************************************************
|
|
@@ -786,9 +796,11 @@ void fvec_inner_products_by_idx(
|
|
|
786
796
|
const float* xj = x + j * d;
|
|
787
797
|
float* __restrict ipj = ip + j * ny;
|
|
788
798
|
for (size_t i = 0; i < ny; i++) {
|
|
789
|
-
if (idsj[i] < 0)
|
|
790
|
-
|
|
791
|
-
|
|
799
|
+
if (idsj[i] < 0) {
|
|
800
|
+
ipj[i] = -INFINITY;
|
|
801
|
+
} else {
|
|
802
|
+
ipj[i] = fvec_inner_product(xj, y + d * idsj[i], d);
|
|
803
|
+
}
|
|
792
804
|
}
|
|
793
805
|
}
|
|
794
806
|
}
|
|
@@ -809,9 +821,11 @@ void fvec_L2sqr_by_idx(
|
|
|
809
821
|
const float* xj = x + j * d;
|
|
810
822
|
float* __restrict disj = dis + j * ny;
|
|
811
823
|
for (size_t i = 0; i < ny; i++) {
|
|
812
|
-
if (idsj[i] < 0)
|
|
813
|
-
|
|
814
|
-
|
|
824
|
+
if (idsj[i] < 0) {
|
|
825
|
+
disj[i] = INFINITY;
|
|
826
|
+
} else {
|
|
827
|
+
disj[i] = fvec_L2sqr(xj, y + d * idsj[i], d);
|
|
828
|
+
}
|
|
815
829
|
}
|
|
816
830
|
}
|
|
817
831
|
}
|
|
@@ -828,6 +842,8 @@ void pairwise_indexed_L2sqr(
|
|
|
828
842
|
for (int64_t j = 0; j < n; j++) {
|
|
829
843
|
if (ix[j] >= 0 && iy[j] >= 0) {
|
|
830
844
|
dis[j] = fvec_L2sqr(x + d * ix[j], y + d * iy[j], d);
|
|
845
|
+
} else {
|
|
846
|
+
dis[j] = INFINITY;
|
|
831
847
|
}
|
|
832
848
|
}
|
|
833
849
|
}
|
|
@@ -844,6 +860,8 @@ void pairwise_indexed_inner_product(
|
|
|
844
860
|
for (int64_t j = 0; j < n; j++) {
|
|
845
861
|
if (ix[j] >= 0 && iy[j] >= 0) {
|
|
846
862
|
dis[j] = fvec_inner_product(x + d * ix[j], y + d * iy[j], d);
|
|
863
|
+
} else {
|
|
864
|
+
dis[j] = -INFINITY;
|
|
847
865
|
}
|
|
848
866
|
}
|
|
849
867
|
}
|
|
@@ -857,6 +875,7 @@ void knn_inner_products_by_idx(
|
|
|
857
875
|
size_t d,
|
|
858
876
|
size_t nx,
|
|
859
877
|
size_t ny,
|
|
878
|
+
size_t nsubset,
|
|
860
879
|
size_t k,
|
|
861
880
|
float* res_vals,
|
|
862
881
|
int64_t* res_ids,
|
|
@@ -874,9 +893,10 @@ void knn_inner_products_by_idx(
|
|
|
874
893
|
int64_t* __restrict idxi = res_ids + i * k;
|
|
875
894
|
minheap_heapify(k, simi, idxi);
|
|
876
895
|
|
|
877
|
-
for (j = 0; j <
|
|
878
|
-
if (idsi[j] < 0)
|
|
896
|
+
for (j = 0; j < nsubset; j++) {
|
|
897
|
+
if (idsi[j] < 0 || idsi[j] >= ny) {
|
|
879
898
|
break;
|
|
899
|
+
}
|
|
880
900
|
float ip = fvec_inner_product(x_, y + d * idsi[j], d);
|
|
881
901
|
|
|
882
902
|
if (ip > simi[0]) {
|
|
@@ -894,6 +914,7 @@ void knn_L2sqr_by_idx(
|
|
|
894
914
|
size_t d,
|
|
895
915
|
size_t nx,
|
|
896
916
|
size_t ny,
|
|
917
|
+
size_t nsubset,
|
|
897
918
|
size_t k,
|
|
898
919
|
float* res_vals,
|
|
899
920
|
int64_t* res_ids,
|
|
@@ -908,7 +929,10 @@ void knn_L2sqr_by_idx(
|
|
|
908
929
|
float* __restrict simi = res_vals + i * k;
|
|
909
930
|
int64_t* __restrict idxi = res_ids + i * k;
|
|
910
931
|
maxheap_heapify(k, simi, idxi);
|
|
911
|
-
for (size_t j = 0; j <
|
|
932
|
+
for (size_t j = 0; j < nsubset; j++) {
|
|
933
|
+
if (idsi[j] < 0 || idsi[j] >= ny) {
|
|
934
|
+
break;
|
|
935
|
+
}
|
|
912
936
|
float disij = fvec_L2sqr(x_, y + d * idsi[j], d);
|
|
913
937
|
|
|
914
938
|
if (disij < simi[0]) {
|