faiss 0.2.0 → 0.2.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/LICENSE.txt +1 -1
- data/README.md +7 -7
- data/ext/faiss/extconf.rb +6 -3
- data/ext/faiss/numo.hpp +4 -4
- data/ext/faiss/utils.cpp +1 -1
- data/ext/faiss/utils.h +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +365 -194
- data/vendor/faiss/faiss/Clustering.h +102 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
- data/vendor/faiss/faiss/Index2Layer.h +22 -36
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
- data/vendor/faiss/faiss/IndexFlat.h +42 -59
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
- data/vendor/faiss/faiss/IndexIVF.h +169 -118
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
- data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
- data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
- data/vendor/faiss/faiss/IndexLSH.h +20 -38
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
- data/vendor/faiss/faiss/IndexPQ.h +64 -82
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
- data/vendor/faiss/faiss/IndexRefine.h +32 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
- data/vendor/faiss/faiss/VectorTransform.h +64 -89
- data/vendor/faiss/faiss/clone_index.cpp +78 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
- data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
- data/vendor/faiss/faiss/impl/io.cpp +76 -95
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +60 -29
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +619 -397
- data/vendor/faiss/faiss/index_factory.h +8 -6
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +305 -312
- data/vendor/faiss/faiss/utils/distances.h +170 -122
- data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +54 -49
- metadata +29 -4
|
@@ -10,10 +10,10 @@
|
|
|
10
10
|
#include <faiss/utils/distances.h>
|
|
11
11
|
|
|
12
12
|
#include <algorithm>
|
|
13
|
-
#include <cstdio>
|
|
14
13
|
#include <cassert>
|
|
15
|
-
#include <cstring>
|
|
16
14
|
#include <cmath>
|
|
15
|
+
#include <cstdio>
|
|
16
|
+
#include <cstring>
|
|
17
17
|
|
|
18
18
|
#include <omp.h>
|
|
19
19
|
|
|
@@ -21,186 +21,153 @@
|
|
|
21
21
|
#include <faiss/impl/FaissAssert.h>
|
|
22
22
|
#include <faiss/impl/ResultHandler.h>
|
|
23
23
|
|
|
24
|
-
|
|
25
|
-
|
|
26
24
|
#ifndef FINTEGER
|
|
27
25
|
#define FINTEGER long
|
|
28
26
|
#endif
|
|
29
27
|
|
|
30
|
-
|
|
31
28
|
extern "C" {
|
|
32
29
|
|
|
33
30
|
/* declare BLAS functions, see http://www.netlib.org/clapack/cblas/ */
|
|
34
31
|
|
|
35
|
-
int sgemm_
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
32
|
+
int sgemm_(
|
|
33
|
+
const char* transa,
|
|
34
|
+
const char* transb,
|
|
35
|
+
FINTEGER* m,
|
|
36
|
+
FINTEGER* n,
|
|
37
|
+
FINTEGER* k,
|
|
38
|
+
const float* alpha,
|
|
39
|
+
const float* a,
|
|
40
|
+
FINTEGER* lda,
|
|
41
|
+
const float* b,
|
|
42
|
+
FINTEGER* ldb,
|
|
43
|
+
float* beta,
|
|
44
|
+
float* c,
|
|
45
|
+
FINTEGER* ldc);
|
|
41
46
|
}
|
|
42
47
|
|
|
43
|
-
|
|
44
48
|
namespace faiss {
|
|
45
49
|
|
|
46
|
-
|
|
47
|
-
|
|
48
50
|
/***************************************************************************
|
|
49
51
|
* Matrix/vector ops
|
|
50
52
|
***************************************************************************/
|
|
51
53
|
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
54
|
/* Compute the L2 norm of a set of nx vectors */
|
|
56
|
-
void fvec_norms_L2
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
55
|
+
void fvec_norms_L2(
|
|
56
|
+
float* __restrict nr,
|
|
57
|
+
const float* __restrict x,
|
|
58
|
+
size_t d,
|
|
59
|
+
size_t nx) {
|
|
61
60
|
#pragma omp parallel for
|
|
62
61
|
for (int64_t i = 0; i < nx; i++) {
|
|
63
|
-
nr[i] = sqrtf
|
|
62
|
+
nr[i] = sqrtf(fvec_norm_L2sqr(x + i * d, d));
|
|
64
63
|
}
|
|
65
64
|
}
|
|
66
65
|
|
|
67
|
-
void fvec_norms_L2sqr
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
66
|
+
void fvec_norms_L2sqr(
|
|
67
|
+
float* __restrict nr,
|
|
68
|
+
const float* __restrict x,
|
|
69
|
+
size_t d,
|
|
70
|
+
size_t nx) {
|
|
71
71
|
#pragma omp parallel for
|
|
72
72
|
for (int64_t i = 0; i < nx; i++)
|
|
73
|
-
nr[i] = fvec_norm_L2sqr
|
|
73
|
+
nr[i] = fvec_norm_L2sqr(x + i * d, d);
|
|
74
74
|
}
|
|
75
75
|
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
void fvec_renorm_L2 (size_t d, size_t nx, float * __restrict x)
|
|
79
|
-
{
|
|
76
|
+
void fvec_renorm_L2(size_t d, size_t nx, float* __restrict x) {
|
|
80
77
|
#pragma omp parallel for
|
|
81
78
|
for (int64_t i = 0; i < nx; i++) {
|
|
82
|
-
float
|
|
79
|
+
float* __restrict xi = x + i * d;
|
|
83
80
|
|
|
84
|
-
float nr = fvec_norm_L2sqr
|
|
81
|
+
float nr = fvec_norm_L2sqr(xi, d);
|
|
85
82
|
|
|
86
83
|
if (nr > 0) {
|
|
87
84
|
size_t j;
|
|
88
|
-
const float inv_nr = 1.0 / sqrtf
|
|
85
|
+
const float inv_nr = 1.0 / sqrtf(nr);
|
|
89
86
|
for (j = 0; j < d; j++)
|
|
90
87
|
xi[j] *= inv_nr;
|
|
91
88
|
}
|
|
92
89
|
}
|
|
93
90
|
}
|
|
94
91
|
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
92
|
/***************************************************************************
|
|
107
93
|
* KNN functions
|
|
108
94
|
***************************************************************************/
|
|
109
95
|
|
|
110
96
|
namespace {
|
|
111
97
|
|
|
112
|
-
|
|
113
|
-
|
|
114
98
|
/* Find the nearest neighbors for nx queries in a set of ny vectors */
|
|
115
|
-
template<class ResultHandler>
|
|
116
|
-
void exhaustive_inner_product_seq
|
|
117
|
-
const float
|
|
118
|
-
const float
|
|
119
|
-
size_t d,
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
check_period *= omp_get_max_threads();
|
|
125
|
-
|
|
99
|
+
template <class ResultHandler>
|
|
100
|
+
void exhaustive_inner_product_seq(
|
|
101
|
+
const float* x,
|
|
102
|
+
const float* y,
|
|
103
|
+
size_t d,
|
|
104
|
+
size_t nx,
|
|
105
|
+
size_t ny,
|
|
106
|
+
ResultHandler& res) {
|
|
126
107
|
using SingleResultHandler = typename ResultHandler::SingleResultHandler;
|
|
108
|
+
int nt = std::min(int(nx), omp_get_max_threads());
|
|
127
109
|
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
#pragma omp parallel
|
|
132
|
-
{
|
|
133
|
-
SingleResultHandler resi(res);
|
|
110
|
+
#pragma omp parallel num_threads(nt)
|
|
111
|
+
{
|
|
112
|
+
SingleResultHandler resi(res);
|
|
134
113
|
#pragma omp for
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
114
|
+
for (int64_t i = 0; i < nx; i++) {
|
|
115
|
+
const float* x_i = x + i * d;
|
|
116
|
+
const float* y_j = y;
|
|
138
117
|
|
|
139
|
-
|
|
118
|
+
resi.begin(i);
|
|
140
119
|
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
}
|
|
146
|
-
resi.end();
|
|
120
|
+
for (size_t j = 0; j < ny; j++) {
|
|
121
|
+
float ip = fvec_inner_product(x_i, y_j, d);
|
|
122
|
+
resi.add_result(ip, j);
|
|
123
|
+
y_j += d;
|
|
147
124
|
}
|
|
125
|
+
resi.end();
|
|
148
126
|
}
|
|
149
|
-
InterruptCallback::check ();
|
|
150
127
|
}
|
|
151
|
-
|
|
152
128
|
}
|
|
153
129
|
|
|
154
|
-
template<class ResultHandler>
|
|
155
|
-
void exhaustive_L2sqr_seq
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
size_t check_period = InterruptCallback::get_period_hint (ny * d);
|
|
163
|
-
check_period *= omp_get_max_threads();
|
|
130
|
+
template <class ResultHandler>
|
|
131
|
+
void exhaustive_L2sqr_seq(
|
|
132
|
+
const float* x,
|
|
133
|
+
const float* y,
|
|
134
|
+
size_t d,
|
|
135
|
+
size_t nx,
|
|
136
|
+
size_t ny,
|
|
137
|
+
ResultHandler& res) {
|
|
164
138
|
using SingleResultHandler = typename ResultHandler::SingleResultHandler;
|
|
139
|
+
int nt = std::min(int(nx), omp_get_max_threads());
|
|
165
140
|
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
#pragma omp parallel
|
|
170
|
-
{
|
|
171
|
-
SingleResultHandler resi(res);
|
|
141
|
+
#pragma omp parallel num_threads(nt)
|
|
142
|
+
{
|
|
143
|
+
SingleResultHandler resi(res);
|
|
172
144
|
#pragma omp for
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
}
|
|
182
|
-
resi.end();
|
|
145
|
+
for (int64_t i = 0; i < nx; i++) {
|
|
146
|
+
const float* x_i = x + i * d;
|
|
147
|
+
const float* y_j = y;
|
|
148
|
+
resi.begin(i);
|
|
149
|
+
for (size_t j = 0; j < ny; j++) {
|
|
150
|
+
float disij = fvec_L2sqr(x_i, y_j, d);
|
|
151
|
+
resi.add_result(disij, j);
|
|
152
|
+
y_j += d;
|
|
183
153
|
}
|
|
154
|
+
resi.end();
|
|
184
155
|
}
|
|
185
|
-
InterruptCallback::check ();
|
|
186
156
|
}
|
|
187
|
-
|
|
188
|
-
};
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
157
|
+
}
|
|
193
158
|
|
|
194
159
|
/** Find the nearest neighbors for nx queries in a set of ny vectors */
|
|
195
|
-
template<class ResultHandler>
|
|
196
|
-
void exhaustive_inner_product_blas
|
|
197
|
-
const float
|
|
198
|
-
const float
|
|
199
|
-
size_t d,
|
|
200
|
-
|
|
201
|
-
|
|
160
|
+
template <class ResultHandler>
|
|
161
|
+
void exhaustive_inner_product_blas(
|
|
162
|
+
const float* x,
|
|
163
|
+
const float* y,
|
|
164
|
+
size_t d,
|
|
165
|
+
size_t nx,
|
|
166
|
+
size_t ny,
|
|
167
|
+
ResultHandler& res) {
|
|
202
168
|
// BLAS does not like empty matrices
|
|
203
|
-
if (nx == 0 || ny == 0)
|
|
169
|
+
if (nx == 0 || ny == 0)
|
|
170
|
+
return;
|
|
204
171
|
|
|
205
172
|
/* block sizes */
|
|
206
173
|
const size_t bs_x = distance_compute_blas_query_bs;
|
|
@@ -209,86 +176,105 @@ void exhaustive_inner_product_blas (
|
|
|
209
176
|
|
|
210
177
|
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
|
211
178
|
size_t i1 = i0 + bs_x;
|
|
212
|
-
if(i1 > nx)
|
|
179
|
+
if (i1 > nx)
|
|
180
|
+
i1 = nx;
|
|
213
181
|
|
|
214
182
|
res.begin_multiple(i0, i1);
|
|
215
183
|
|
|
216
184
|
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
|
217
185
|
size_t j1 = j0 + bs_y;
|
|
218
|
-
if (j1 > ny)
|
|
186
|
+
if (j1 > ny)
|
|
187
|
+
j1 = ny;
|
|
219
188
|
/* compute the actual dot products */
|
|
220
189
|
{
|
|
221
190
|
float one = 1, zero = 0;
|
|
222
191
|
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
|
|
223
|
-
sgemm_
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
192
|
+
sgemm_("Transpose",
|
|
193
|
+
"Not transpose",
|
|
194
|
+
&nyi,
|
|
195
|
+
&nxi,
|
|
196
|
+
&di,
|
|
197
|
+
&one,
|
|
198
|
+
y + j0 * d,
|
|
199
|
+
&di,
|
|
200
|
+
x + i0 * d,
|
|
201
|
+
&di,
|
|
202
|
+
&zero,
|
|
203
|
+
ip_block.get(),
|
|
204
|
+
&nyi);
|
|
227
205
|
}
|
|
228
206
|
|
|
229
207
|
res.add_results(j0, j1, ip_block.get());
|
|
230
|
-
|
|
231
208
|
}
|
|
232
209
|
res.end_multiple();
|
|
233
|
-
InterruptCallback::check
|
|
234
|
-
|
|
210
|
+
InterruptCallback::check();
|
|
235
211
|
}
|
|
236
212
|
}
|
|
237
213
|
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
214
|
// distance correction is an operator that can be applied to transform
|
|
242
215
|
// the distances
|
|
243
|
-
template<class ResultHandler>
|
|
244
|
-
void exhaustive_L2sqr_blas
|
|
245
|
-
const float
|
|
246
|
-
const float
|
|
247
|
-
size_t d,
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
216
|
+
template <class ResultHandler>
|
|
217
|
+
void exhaustive_L2sqr_blas(
|
|
218
|
+
const float* x,
|
|
219
|
+
const float* y,
|
|
220
|
+
size_t d,
|
|
221
|
+
size_t nx,
|
|
222
|
+
size_t ny,
|
|
223
|
+
ResultHandler& res,
|
|
224
|
+
const float* y_norms = nullptr) {
|
|
251
225
|
// BLAS does not like empty matrices
|
|
252
|
-
if (nx == 0 || ny == 0)
|
|
226
|
+
if (nx == 0 || ny == 0)
|
|
227
|
+
return;
|
|
253
228
|
|
|
254
229
|
/* block sizes */
|
|
255
230
|
const size_t bs_x = distance_compute_blas_query_bs;
|
|
256
231
|
const size_t bs_y = distance_compute_blas_database_bs;
|
|
257
232
|
// const size_t bs_x = 16, bs_y = 16;
|
|
258
|
-
std::unique_ptr<float
|
|
259
|
-
std::unique_ptr<float
|
|
260
|
-
std::unique_ptr<float
|
|
233
|
+
std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);
|
|
234
|
+
std::unique_ptr<float[]> x_norms(new float[nx]);
|
|
235
|
+
std::unique_ptr<float[]> del2;
|
|
261
236
|
|
|
262
|
-
fvec_norms_L2sqr
|
|
237
|
+
fvec_norms_L2sqr(x_norms.get(), x, d, nx);
|
|
263
238
|
|
|
264
239
|
if (!y_norms) {
|
|
265
|
-
float
|
|
240
|
+
float* y_norms2 = new float[ny];
|
|
266
241
|
del2.reset(y_norms2);
|
|
267
|
-
fvec_norms_L2sqr
|
|
242
|
+
fvec_norms_L2sqr(y_norms2, y, d, ny);
|
|
268
243
|
y_norms = y_norms2;
|
|
269
244
|
}
|
|
270
245
|
|
|
271
246
|
for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
|
|
272
247
|
size_t i1 = i0 + bs_x;
|
|
273
|
-
if(i1 > nx)
|
|
248
|
+
if (i1 > nx)
|
|
249
|
+
i1 = nx;
|
|
274
250
|
|
|
275
251
|
res.begin_multiple(i0, i1);
|
|
276
252
|
|
|
277
253
|
for (size_t j0 = 0; j0 < ny; j0 += bs_y) {
|
|
278
254
|
size_t j1 = j0 + bs_y;
|
|
279
|
-
if (j1 > ny)
|
|
255
|
+
if (j1 > ny)
|
|
256
|
+
j1 = ny;
|
|
280
257
|
/* compute the actual dot products */
|
|
281
258
|
{
|
|
282
259
|
float one = 1, zero = 0;
|
|
283
260
|
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
|
|
284
|
-
sgemm_
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
261
|
+
sgemm_("Transpose",
|
|
262
|
+
"Not transpose",
|
|
263
|
+
&nyi,
|
|
264
|
+
&nxi,
|
|
265
|
+
&di,
|
|
266
|
+
&one,
|
|
267
|
+
y + j0 * d,
|
|
268
|
+
&di,
|
|
269
|
+
x + i0 * d,
|
|
270
|
+
&di,
|
|
271
|
+
&zero,
|
|
272
|
+
ip_block.get(),
|
|
273
|
+
&nyi);
|
|
288
274
|
}
|
|
289
|
-
|
|
275
|
+
#pragma omp parallel for
|
|
290
276
|
for (int64_t i = i0; i < i1; i++) {
|
|
291
|
-
float
|
|
277
|
+
float* ip_line = ip_block.get() + (i - i0) * (j1 - j0);
|
|
292
278
|
|
|
293
279
|
for (size_t j = j0; j < j1; j++) {
|
|
294
280
|
float ip = *ip_line;
|
|
@@ -296,7 +282,8 @@ void exhaustive_L2sqr_blas (
|
|
|
296
282
|
|
|
297
283
|
// negative values can occur for identical vectors
|
|
298
284
|
// due to roundoff errors
|
|
299
|
-
if (dis < 0)
|
|
285
|
+
if (dis < 0)
|
|
286
|
+
dis = 0;
|
|
300
287
|
|
|
301
288
|
*ip_line = dis;
|
|
302
289
|
ip_line++;
|
|
@@ -305,18 +292,12 @@ void exhaustive_L2sqr_blas (
|
|
|
305
292
|
res.add_results(j0, j1, ip_block.get());
|
|
306
293
|
}
|
|
307
294
|
res.end_multiple();
|
|
308
|
-
InterruptCallback::check
|
|
295
|
+
InterruptCallback::check();
|
|
309
296
|
}
|
|
310
297
|
}
|
|
311
298
|
|
|
312
|
-
|
|
313
|
-
|
|
314
299
|
} // anonymous namespace
|
|
315
300
|
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
301
|
/*******************************************************
|
|
321
302
|
* KNN driver functions
|
|
322
303
|
*******************************************************/
|
|
@@ -326,268 +307,275 @@ int distance_compute_blas_query_bs = 4096;
|
|
|
326
307
|
int distance_compute_blas_database_bs = 1024;
|
|
327
308
|
int distance_compute_min_k_reservoir = 100;
|
|
328
309
|
|
|
329
|
-
void knn_inner_product
|
|
330
|
-
const float
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
310
|
+
void knn_inner_product(
|
|
311
|
+
const float* x,
|
|
312
|
+
const float* y,
|
|
313
|
+
size_t d,
|
|
314
|
+
size_t nx,
|
|
315
|
+
size_t ny,
|
|
316
|
+
float_minheap_array_t* ha) {
|
|
334
317
|
if (ha->k < distance_compute_min_k_reservoir) {
|
|
335
318
|
HeapResultHandler<CMin<float, int64_t>> res(
|
|
336
|
-
|
|
319
|
+
ha->nh, ha->val, ha->ids, ha->k);
|
|
337
320
|
if (nx < distance_compute_blas_threshold) {
|
|
338
|
-
exhaustive_inner_product_seq
|
|
321
|
+
exhaustive_inner_product_seq(x, y, d, nx, ny, res);
|
|
339
322
|
} else {
|
|
340
|
-
exhaustive_inner_product_blas
|
|
323
|
+
exhaustive_inner_product_blas(x, y, d, nx, ny, res);
|
|
341
324
|
}
|
|
342
325
|
} else {
|
|
343
326
|
ReservoirResultHandler<CMin<float, int64_t>> res(
|
|
344
|
-
|
|
327
|
+
ha->nh, ha->val, ha->ids, ha->k);
|
|
345
328
|
if (nx < distance_compute_blas_threshold) {
|
|
346
|
-
exhaustive_inner_product_seq
|
|
329
|
+
exhaustive_inner_product_seq(x, y, d, nx, ny, res);
|
|
347
330
|
} else {
|
|
348
|
-
exhaustive_inner_product_blas
|
|
331
|
+
exhaustive_inner_product_blas(x, y, d, nx, ny, res);
|
|
349
332
|
}
|
|
350
333
|
}
|
|
351
334
|
}
|
|
352
335
|
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
const float *y_norm2
|
|
362
|
-
) {
|
|
363
|
-
|
|
336
|
+
void knn_L2sqr(
|
|
337
|
+
const float* x,
|
|
338
|
+
const float* y,
|
|
339
|
+
size_t d,
|
|
340
|
+
size_t nx,
|
|
341
|
+
size_t ny,
|
|
342
|
+
float_maxheap_array_t* ha,
|
|
343
|
+
const float* y_norm2) {
|
|
364
344
|
if (ha->k < distance_compute_min_k_reservoir) {
|
|
365
345
|
HeapResultHandler<CMax<float, int64_t>> res(
|
|
366
|
-
|
|
346
|
+
ha->nh, ha->val, ha->ids, ha->k);
|
|
367
347
|
|
|
368
348
|
if (nx < distance_compute_blas_threshold) {
|
|
369
|
-
exhaustive_L2sqr_seq
|
|
349
|
+
exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
|
|
370
350
|
} else {
|
|
371
|
-
exhaustive_L2sqr_blas
|
|
351
|
+
exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
|
|
372
352
|
}
|
|
373
353
|
} else {
|
|
374
354
|
ReservoirResultHandler<CMax<float, int64_t>> res(
|
|
375
|
-
|
|
355
|
+
ha->nh, ha->val, ha->ids, ha->k);
|
|
376
356
|
if (nx < distance_compute_blas_threshold) {
|
|
377
|
-
exhaustive_L2sqr_seq
|
|
357
|
+
exhaustive_L2sqr_seq(x, y, d, nx, ny, res);
|
|
378
358
|
} else {
|
|
379
|
-
exhaustive_L2sqr_blas
|
|
359
|
+
exhaustive_L2sqr_blas(x, y, d, nx, ny, res, y_norm2);
|
|
380
360
|
}
|
|
381
361
|
}
|
|
382
362
|
}
|
|
383
363
|
|
|
384
|
-
|
|
385
364
|
/***************************************************************************
|
|
386
365
|
* Range search
|
|
387
366
|
***************************************************************************/
|
|
388
367
|
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
size_t d, size_t nx, size_t ny,
|
|
368
|
+
void range_search_L2sqr(
|
|
369
|
+
const float* x,
|
|
370
|
+
const float* y,
|
|
371
|
+
size_t d,
|
|
372
|
+
size_t nx,
|
|
373
|
+
size_t ny,
|
|
396
374
|
float radius,
|
|
397
|
-
RangeSearchResult
|
|
398
|
-
{
|
|
375
|
+
RangeSearchResult* res) {
|
|
399
376
|
RangeSearchResultHandler<CMax<float, int64_t>> resh(res, radius);
|
|
400
377
|
if (nx < distance_compute_blas_threshold) {
|
|
401
|
-
exhaustive_L2sqr_seq
|
|
378
|
+
exhaustive_L2sqr_seq(x, y, d, nx, ny, resh);
|
|
402
379
|
} else {
|
|
403
|
-
exhaustive_L2sqr_blas
|
|
380
|
+
exhaustive_L2sqr_blas(x, y, d, nx, ny, resh);
|
|
404
381
|
}
|
|
405
382
|
}
|
|
406
383
|
|
|
407
|
-
void range_search_inner_product
|
|
408
|
-
const float
|
|
409
|
-
const float
|
|
410
|
-
size_t d,
|
|
384
|
+
void range_search_inner_product(
|
|
385
|
+
const float* x,
|
|
386
|
+
const float* y,
|
|
387
|
+
size_t d,
|
|
388
|
+
size_t nx,
|
|
389
|
+
size_t ny,
|
|
411
390
|
float radius,
|
|
412
|
-
RangeSearchResult
|
|
413
|
-
{
|
|
414
|
-
|
|
391
|
+
RangeSearchResult* res) {
|
|
415
392
|
RangeSearchResultHandler<CMin<float, int64_t>> resh(res, radius);
|
|
416
393
|
if (nx < distance_compute_blas_threshold) {
|
|
417
|
-
exhaustive_inner_product_seq
|
|
394
|
+
exhaustive_inner_product_seq(x, y, d, nx, ny, resh);
|
|
418
395
|
} else {
|
|
419
|
-
exhaustive_inner_product_blas
|
|
396
|
+
exhaustive_inner_product_blas(x, y, d, nx, ny, resh);
|
|
420
397
|
}
|
|
421
398
|
}
|
|
422
399
|
|
|
423
|
-
|
|
424
400
|
/***************************************************************************
|
|
425
401
|
* compute a subset of distances
|
|
426
402
|
***************************************************************************/
|
|
427
403
|
|
|
428
404
|
/* compute the inner product between x and a subset y of ny vectors,
|
|
429
405
|
whose indices are given by idy. */
|
|
430
|
-
void fvec_inner_products_by_idx
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
406
|
+
void fvec_inner_products_by_idx(
|
|
407
|
+
float* __restrict ip,
|
|
408
|
+
const float* x,
|
|
409
|
+
const float* y,
|
|
410
|
+
const int64_t* __restrict ids, /* for y vecs */
|
|
411
|
+
size_t d,
|
|
412
|
+
size_t nx,
|
|
413
|
+
size_t ny) {
|
|
436
414
|
#pragma omp parallel for
|
|
437
415
|
for (int64_t j = 0; j < nx; j++) {
|
|
438
|
-
const int64_t
|
|
439
|
-
const float
|
|
440
|
-
float
|
|
416
|
+
const int64_t* __restrict idsj = ids + j * ny;
|
|
417
|
+
const float* xj = x + j * d;
|
|
418
|
+
float* __restrict ipj = ip + j * ny;
|
|
441
419
|
for (size_t i = 0; i < ny; i++) {
|
|
442
420
|
if (idsj[i] < 0)
|
|
443
421
|
continue;
|
|
444
|
-
ipj[i] = fvec_inner_product
|
|
422
|
+
ipj[i] = fvec_inner_product(xj, y + d * idsj[i], d);
|
|
445
423
|
}
|
|
446
424
|
}
|
|
447
425
|
}
|
|
448
426
|
|
|
449
|
-
|
|
450
|
-
|
|
451
427
|
/* compute the inner product between x and a subset y of ny vectors,
|
|
452
428
|
whose indices are given by idy. */
|
|
453
|
-
void fvec_L2sqr_by_idx
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
429
|
+
void fvec_L2sqr_by_idx(
|
|
430
|
+
float* __restrict dis,
|
|
431
|
+
const float* x,
|
|
432
|
+
const float* y,
|
|
433
|
+
const int64_t* __restrict ids, /* ids of y vecs */
|
|
434
|
+
size_t d,
|
|
435
|
+
size_t nx,
|
|
436
|
+
size_t ny) {
|
|
459
437
|
#pragma omp parallel for
|
|
460
438
|
for (int64_t j = 0; j < nx; j++) {
|
|
461
|
-
const int64_t
|
|
462
|
-
const float
|
|
463
|
-
float
|
|
439
|
+
const int64_t* __restrict idsj = ids + j * ny;
|
|
440
|
+
const float* xj = x + j * d;
|
|
441
|
+
float* __restrict disj = dis + j * ny;
|
|
464
442
|
for (size_t i = 0; i < ny; i++) {
|
|
465
443
|
if (idsj[i] < 0)
|
|
466
444
|
continue;
|
|
467
|
-
disj[i] = fvec_L2sqr
|
|
445
|
+
disj[i] = fvec_L2sqr(xj, y + d * idsj[i], d);
|
|
468
446
|
}
|
|
469
447
|
}
|
|
470
448
|
}
|
|
471
449
|
|
|
472
|
-
void pairwise_indexed_L2sqr
|
|
473
|
-
size_t d,
|
|
474
|
-
|
|
475
|
-
const float
|
|
476
|
-
|
|
477
|
-
|
|
450
|
+
void pairwise_indexed_L2sqr(
|
|
451
|
+
size_t d,
|
|
452
|
+
size_t n,
|
|
453
|
+
const float* x,
|
|
454
|
+
const int64_t* ix,
|
|
455
|
+
const float* y,
|
|
456
|
+
const int64_t* iy,
|
|
457
|
+
float* dis) {
|
|
478
458
|
#pragma omp parallel for
|
|
479
459
|
for (int64_t j = 0; j < n; j++) {
|
|
480
460
|
if (ix[j] >= 0 && iy[j] >= 0) {
|
|
481
|
-
dis[j] = fvec_L2sqr
|
|
461
|
+
dis[j] = fvec_L2sqr(x + d * ix[j], y + d * iy[j], d);
|
|
482
462
|
}
|
|
483
463
|
}
|
|
484
464
|
}
|
|
485
465
|
|
|
486
|
-
void pairwise_indexed_inner_product
|
|
487
|
-
size_t d,
|
|
488
|
-
|
|
489
|
-
const float
|
|
490
|
-
|
|
491
|
-
|
|
466
|
+
void pairwise_indexed_inner_product(
|
|
467
|
+
size_t d,
|
|
468
|
+
size_t n,
|
|
469
|
+
const float* x,
|
|
470
|
+
const int64_t* ix,
|
|
471
|
+
const float* y,
|
|
472
|
+
const int64_t* iy,
|
|
473
|
+
float* dis) {
|
|
492
474
|
#pragma omp parallel for
|
|
493
475
|
for (int64_t j = 0; j < n; j++) {
|
|
494
476
|
if (ix[j] >= 0 && iy[j] >= 0) {
|
|
495
|
-
dis[j] = fvec_inner_product
|
|
477
|
+
dis[j] = fvec_inner_product(x + d * ix[j], y + d * iy[j], d);
|
|
496
478
|
}
|
|
497
479
|
}
|
|
498
480
|
}
|
|
499
481
|
|
|
500
|
-
|
|
501
482
|
/* Find the nearest neighbors for nx queries in a set of ny vectors
|
|
502
483
|
indexed by ids. May be useful for re-ranking a pre-selected vector list */
|
|
503
|
-
void knn_inner_products_by_idx
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
484
|
+
void knn_inner_products_by_idx(
|
|
485
|
+
const float* x,
|
|
486
|
+
const float* y,
|
|
487
|
+
const int64_t* ids,
|
|
488
|
+
size_t d,
|
|
489
|
+
size_t nx,
|
|
490
|
+
size_t ny,
|
|
491
|
+
float_minheap_array_t* res) {
|
|
509
492
|
size_t k = res->k;
|
|
510
493
|
|
|
511
494
|
#pragma omp parallel for
|
|
512
495
|
for (int64_t i = 0; i < nx; i++) {
|
|
513
|
-
const float
|
|
514
|
-
const int64_t
|
|
496
|
+
const float* x_ = x + i * d;
|
|
497
|
+
const int64_t* idsi = ids + i * ny;
|
|
515
498
|
size_t j;
|
|
516
|
-
float
|
|
517
|
-
int64_t
|
|
518
|
-
minheap_heapify
|
|
499
|
+
float* __restrict simi = res->get_val(i);
|
|
500
|
+
int64_t* __restrict idxi = res->get_ids(i);
|
|
501
|
+
minheap_heapify(k, simi, idxi);
|
|
519
502
|
|
|
520
503
|
for (j = 0; j < ny; j++) {
|
|
521
|
-
if (idsi[j] < 0)
|
|
522
|
-
|
|
504
|
+
if (idsi[j] < 0)
|
|
505
|
+
break;
|
|
506
|
+
float ip = fvec_inner_product(x_, y + d * idsi[j], d);
|
|
523
507
|
|
|
524
508
|
if (ip > simi[0]) {
|
|
525
|
-
minheap_replace_top
|
|
509
|
+
minheap_replace_top(k, simi, idxi, ip, idsi[j]);
|
|
526
510
|
}
|
|
527
511
|
}
|
|
528
|
-
minheap_reorder
|
|
512
|
+
minheap_reorder(k, simi, idxi);
|
|
529
513
|
}
|
|
530
|
-
|
|
531
514
|
}
|
|
532
515
|
|
|
533
|
-
void knn_L2sqr_by_idx
|
|
534
|
-
|
|
535
|
-
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
516
|
+
void knn_L2sqr_by_idx(
|
|
517
|
+
const float* x,
|
|
518
|
+
const float* y,
|
|
519
|
+
const int64_t* __restrict ids,
|
|
520
|
+
size_t d,
|
|
521
|
+
size_t nx,
|
|
522
|
+
size_t ny,
|
|
523
|
+
float_maxheap_array_t* res) {
|
|
539
524
|
size_t k = res->k;
|
|
540
525
|
|
|
541
526
|
#pragma omp parallel for
|
|
542
527
|
for (int64_t i = 0; i < nx; i++) {
|
|
543
|
-
const float
|
|
544
|
-
const int64_t
|
|
545
|
-
float
|
|
546
|
-
int64_t
|
|
547
|
-
maxheap_heapify
|
|
528
|
+
const float* x_ = x + i * d;
|
|
529
|
+
const int64_t* __restrict idsi = ids + i * ny;
|
|
530
|
+
float* __restrict simi = res->get_val(i);
|
|
531
|
+
int64_t* __restrict idxi = res->get_ids(i);
|
|
532
|
+
maxheap_heapify(res->k, simi, idxi);
|
|
548
533
|
for (size_t j = 0; j < ny; j++) {
|
|
549
|
-
float disij = fvec_L2sqr
|
|
534
|
+
float disij = fvec_L2sqr(x_, y + d * idsi[j], d);
|
|
550
535
|
|
|
551
536
|
if (disij < simi[0]) {
|
|
552
|
-
maxheap_replace_top
|
|
537
|
+
maxheap_replace_top(k, simi, idxi, disij, idsi[j]);
|
|
553
538
|
}
|
|
554
539
|
}
|
|
555
|
-
maxheap_reorder
|
|
540
|
+
maxheap_reorder(res->k, simi, idxi);
|
|
556
541
|
}
|
|
557
|
-
|
|
558
542
|
}
|
|
559
543
|
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
|
|
568
|
-
|
|
569
|
-
{
|
|
570
|
-
if (nq == 0 || nb == 0)
|
|
571
|
-
|
|
572
|
-
if (
|
|
573
|
-
|
|
544
|
+
void pairwise_L2sqr(
|
|
545
|
+
int64_t d,
|
|
546
|
+
int64_t nq,
|
|
547
|
+
const float* xq,
|
|
548
|
+
int64_t nb,
|
|
549
|
+
const float* xb,
|
|
550
|
+
float* dis,
|
|
551
|
+
int64_t ldq,
|
|
552
|
+
int64_t ldb,
|
|
553
|
+
int64_t ldd) {
|
|
554
|
+
if (nq == 0 || nb == 0)
|
|
555
|
+
return;
|
|
556
|
+
if (ldq == -1)
|
|
557
|
+
ldq = d;
|
|
558
|
+
if (ldb == -1)
|
|
559
|
+
ldb = d;
|
|
560
|
+
if (ldd == -1)
|
|
561
|
+
ldd = nb;
|
|
574
562
|
|
|
575
563
|
// store in beginning of distance matrix to avoid malloc
|
|
576
|
-
float
|
|
564
|
+
float* b_norms = dis;
|
|
577
565
|
|
|
578
566
|
#pragma omp parallel for
|
|
579
567
|
for (int64_t i = 0; i < nb; i++)
|
|
580
|
-
b_norms
|
|
568
|
+
b_norms[i] = fvec_norm_L2sqr(xb + i * ldb, d);
|
|
581
569
|
|
|
582
570
|
#pragma omp parallel for
|
|
583
571
|
for (int64_t i = 1; i < nq; i++) {
|
|
584
|
-
float q_norm = fvec_norm_L2sqr
|
|
572
|
+
float q_norm = fvec_norm_L2sqr(xq + i * ldq, d);
|
|
585
573
|
for (int64_t j = 0; j < nb; j++)
|
|
586
|
-
dis[i * ldd + j] = q_norm + b_norms
|
|
574
|
+
dis[i * ldd + j] = q_norm + b_norms[j];
|
|
587
575
|
}
|
|
588
576
|
|
|
589
577
|
{
|
|
590
|
-
float q_norm = fvec_norm_L2sqr
|
|
578
|
+
float q_norm = fvec_norm_L2sqr(xq, d);
|
|
591
579
|
for (int64_t j = 0; j < nb; j++)
|
|
592
580
|
dis[j] += q_norm;
|
|
593
581
|
}
|
|
@@ -596,22 +584,28 @@ void pairwise_L2sqr (int64_t d,
|
|
|
596
584
|
FINTEGER nbi = nb, nqi = nq, di = d, ldqi = ldq, ldbi = ldb, lddi = ldd;
|
|
597
585
|
float one = 1.0, minus_2 = -2.0;
|
|
598
586
|
|
|
599
|
-
sgemm_
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
587
|
+
sgemm_("Transposed",
|
|
588
|
+
"Not transposed",
|
|
589
|
+
&nbi,
|
|
590
|
+
&nqi,
|
|
591
|
+
&di,
|
|
592
|
+
&minus_2,
|
|
593
|
+
xb,
|
|
594
|
+
&ldbi,
|
|
595
|
+
xq,
|
|
596
|
+
&ldqi,
|
|
597
|
+
&one,
|
|
598
|
+
dis,
|
|
599
|
+
&lddi);
|
|
605
600
|
}
|
|
606
|
-
|
|
607
601
|
}
|
|
608
602
|
|
|
609
|
-
void inner_product_to_L2sqr(
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
603
|
+
void inner_product_to_L2sqr(
|
|
604
|
+
float* __restrict dis,
|
|
605
|
+
const float* nr1,
|
|
606
|
+
const float* nr2,
|
|
607
|
+
size_t n1,
|
|
608
|
+
size_t n2) {
|
|
615
609
|
#pragma omp parallel for
|
|
616
610
|
for (int64_t j = 0; j < n1; j++) {
|
|
617
611
|
float* disj = dis + j * n2;
|
|
@@ -620,5 +614,4 @@ void inner_product_to_L2sqr(float* __restrict dis,
|
|
|
620
614
|
}
|
|
621
615
|
}
|
|
622
616
|
|
|
623
|
-
|
|
624
617
|
} // namespace faiss
|