faiss 0.1.5 → 0.2.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +24 -0
- data/README.md +12 -0
- data/ext/faiss/ext.cpp +1 -1
- data/ext/faiss/extconf.rb +6 -2
- data/ext/faiss/index.cpp +114 -43
- data/ext/faiss/index_binary.cpp +24 -30
- data/ext/faiss/kmeans.cpp +20 -16
- data/ext/faiss/numo.hpp +867 -0
- data/ext/faiss/pca_matrix.cpp +13 -14
- data/ext/faiss/product_quantizer.cpp +23 -24
- data/ext/faiss/utils.cpp +10 -37
- data/ext/faiss/utils.h +2 -13
- data/lib/faiss.rb +0 -5
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +334 -195
- data/vendor/faiss/faiss/Clustering.h +88 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
- data/vendor/faiss/faiss/Index2Layer.h +22 -22
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
- data/vendor/faiss/faiss/IndexFlat.h +35 -46
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
- data/vendor/faiss/faiss/IndexIVF.h +146 -113
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
- data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
- data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
- data/vendor/faiss/faiss/IndexLSH.h +21 -26
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
- data/vendor/faiss/faiss/IndexPQ.h +64 -67
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
- data/vendor/faiss/faiss/IndexRefine.h +22 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
- data/vendor/faiss/faiss/IndexResidual.h +152 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
- data/vendor/faiss/faiss/VectorTransform.h +61 -89
- data/vendor/faiss/faiss/clone_index.cpp +77 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
- data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
- data/vendor/faiss/faiss/impl/io.cpp +75 -94
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +40 -29
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +269 -218
- data/vendor/faiss/faiss/index_factory.h +6 -7
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +301 -310
- data/vendor/faiss/faiss/utils/distances.h +133 -118
- data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +53 -48
- metadata +24 -10
- data/lib/faiss/index.rb +0 -20
- data/lib/faiss/index_binary.rb +0 -20
- data/lib/faiss/kmeans.rb +0 -15
- data/lib/faiss/pca_matrix.rb +0 -15
- data/lib/faiss/product_quantizer.rb +0 -22
The rendered diff that follows covers the vendored FAISS sources, starting with `data/vendor/faiss/faiss/utils/distances.h`. Nearly every hunk is the upstream clang-format pass: the `// -*- c++ -*-` marker is dropped, `#include <faiss/impl/platform_macros.h>` now precedes `#include <faiss/utils/Heap.h>`, and the old multi-line `name (const float * x, ...)` prototypes are collapsed into `name(const float* x, ...)` form. A "Templatized versions of distance functions" banner is also added before the closing brace. The declarations as they stand on the 0.2.2 side of the hunks (`...` marks context elided between hunks):

```cpp
/* All distance functions for L2 and IP distances.
 * The actual functions are implemented in distances.cpp and distances_simd.cpp
 */

#pragma once

#include <stdint.h>

#include <faiss/impl/platform_macros.h>
#include <faiss/utils/Heap.h>

namespace faiss {

/*********************************************************
 * Optimized distance/norm/inner prod computations
 *********************************************************/

/// Squared L2 distance between two vectors
float fvec_L2sqr(const float* x, const float* y, size_t d);

/// inner product
float fvec_inner_product(const float* x, const float* y, size_t d);

/// L1 distance
float fvec_L1(const float* x, const float* y, size_t d);

/// infinity distance
float fvec_Linf(const float* x, const float* y, size_t d);

/** Compute pairwise distances between sets of vectors
 *
 * ...
 * @param dis output distances (size nq * nb)
 * @param ldq,ldb, ldd strides for the matrices
 */
void pairwise_L2sqr(
        int64_t d,
        int64_t nq,
        const float* xq,
        int64_t nb,
        const float* xb,
        float* dis,
        int64_t ldq = -1,
        int64_t ldb = -1,
        int64_t ldd = -1);

/* compute the inner product between nx vectors x and one y */
void fvec_inner_products_ny(
        float* ip, /* output inner product */
        const float* x,
        const float* y,
        size_t d,
        size_t ny);

/* compute ny square L2 distance bewteen x and a set of contiguous y vectors */
void fvec_L2sqr_ny(
        float* dis,
        const float* x,
        const float* y,
        size_t d,
        size_t ny);

/** squared norm of a vector */
float fvec_norm_L2sqr(const float* x, size_t d);

/** compute the L2 norms for a set of vectors
 *
 * @param norms output norms, size nx
 * @param x set of vectors, size nx * d
 */
void fvec_norms_L2(float* norms, const float* x, size_t d, size_t nx);

/// same as fvec_norms_L2, but computes squared norms
void fvec_norms_L2sqr(float* norms, const float* x, size_t d, size_t nx);

/* L2-renormalize a set of vector. Nothing done if the vector is 0-normed */
void fvec_renorm_L2(size_t d, size_t nx, float* x);

/* This function exists because the Torch counterpart is extremly slow
   (not multi-threaded + unexpected overhead even in single thread).
   It is here to implement the usual property |x-y|^2=|x|^2+|y|^2-2<x|y> */
void inner_product_to_L2sqr(
        float* dis,
        const float* nr1,
        const float* nr2,
        size_t n1,
        size_t n2);

/***************************************************************************
 * Compute a subset of distances
 ***************************************************************************/

/* compute the inner product between x and a subset y of ny vectors,
   whose indices are given by idy. */
void fvec_inner_products_by_idx(
        float* ip,
        const float* x,
        const float* y,
        const int64_t* ids,
        size_t d,
        size_t nx,
        size_t ny);

/* same but for a subset in y indexed by idsy (ny vectors in total) */
void fvec_L2sqr_by_idx(
        float* dis,
        const float* x,
        const float* y,
        const int64_t* ids, /* ids of y vecs */
        size_t d,
        size_t nx,
        size_t ny);

/** compute dis[j] = L2sqr(x[ix[j]], y[iy[j]]) forall j=0..n-1
 *
 * ...
 * @param iy size n
 * @param dis size n
 */
void pairwise_indexed_L2sqr(
        size_t d,
        size_t n,
        const float* x,
        const int64_t* ix,
        const float* y,
        const int64_t* iy,
        float* dis);

/* same for inner product */
void pairwise_indexed_inner_product(
        size_t d,
        size_t n,
        const float* x,
        const int64_t* ix,
        const float* y,
        const int64_t* iy,
        float* dis);

/***************************************************************************
 * KNN functions
 ***************************************************************************/

FAISS_API extern int distance_compute_min_k_reservoir;

/** ...
 * @param y database vectors, size ny * d
 * @param res result array, which also provides k. Sorted on output
 */
void knn_inner_product(
        const float* x,
        const float* y,
        size_t d,
        size_t nx,
        size_t ny,
        float_minheap_array_t* res);

/** Same as knn_inner_product, for the L2 distance
 * @param y_norm2 norms for the y vectors (nullptr or size ny)
 */
void knn_L2sqr(
        const float* x,
        const float* y,
        size_t d,
        size_t nx,
        size_t ny,
        float_maxheap_array_t* res,
        const float* y_norm2 = nullptr);

/* Find the nearest neighbors for nx queries in a set of ny vectors
 * indexed by ids. May be useful for re-ranking a pre-selected vector list
 */
void knn_inner_products_by_idx(
        const float* x,
        const float* y,
        const int64_t* ids,
        size_t d,
        size_t nx,
        size_t ny,
        float_minheap_array_t* res);

void knn_L2sqr_by_idx(
        const float* x,
        const float* y,
        const int64_t* ids,
        size_t d,
        size_t nx,
        size_t ny,
        float_maxheap_array_t* res);

/***************************************************************************
 * Range search
 ***************************************************************************/

/// Forward declaration, see AuxIndexStructures.h
struct RangeSearchResult;

/** ...
 * @param radius search radius around the x vectors
 * @param result result structure
 */
void range_search_L2sqr(
        const float* x,
        const float* y,
        size_t d,
        size_t nx,
        size_t ny,
        float radius,
        RangeSearchResult* result);

/// same as range_search_L2sqr for the inner product similarity
void range_search_inner_product(
        const float* x,
        const float* y,
        size_t d,
        size_t nx,
        size_t ny,
        float radius,
        RangeSearchResult* result);

/***************************************************************************
 * PQ tables computations
 ***************************************************************************/

/// specialized function for PQ2
void compute_PQ_dis_tables_dsub2(
        size_t d,
        size_t ksub,
        const float* centroids,
        size_t nx,
        const float* x,
        bool is_inner_product,
        float* dis_tables);

/***************************************************************************
 * Templatized versions of distance functions
 ***************************************************************************/

} // namespace faiss
```
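Not part of the diff: a minimal usage sketch of the reformatted declarations above, assuming the vendored headers are on the include path and the library is built and linked.

```cpp
#include <faiss/utils/distances.h>

#include <cstdio>
#include <vector>

int main() {
    const size_t d = 8;
    std::vector<float> x(d, 1.0f), y(d, 2.0f);

    // Squared L2 distance: sum_i (x[i] - y[i])^2 = 8 * 1 = 8.
    float l2 = faiss::fvec_L2sqr(x.data(), y.data(), d);

    // Inner product: sum_i x[i] * y[i] = 8 * 2 = 16.
    float ip = faiss::fvec_inner_product(x.data(), y.data(), d);

    std::printf("L2sqr=%f ip=%f\n", l2, ip);
}
```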
Next comes `data/vendor/faiss/faiss/utils/distances.cpp`. The includes are sorted (`<cassert>`, `<cmath>`, `<cstdio>`, `<cstring>`) and `<faiss/impl/platform_macros.h>` is added; the reference implementations are then reformatted without behavioral change:

```cpp
#include <faiss/utils/distances.h>

#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstring>

#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/platform_macros.h>
#include <faiss/utils/simdlib.h>

#ifdef __SSE3__
#include <immintrin.h>
#endif
// ...
#include <arm_neon.h>
#endif

namespace faiss {

#ifdef __AVX__
#define USE_AVX
#endif

/*********************************************************
 * Optimized distance computations
 *********************************************************/

/* Functions to compute:
   - L2 distance between 2 vectors
   - inner product between 2 vectors
   ... */

/*********************************************************
 * Reference implementations
 */

float fvec_L2sqr_ref(const float* x, const float* y, size_t d) {
    size_t i;
    float res = 0;
    for (i = 0; i < d; i++) {
        const float tmp = x[i] - y[i];
        res += tmp * tmp;
    }
    return res;
}

float fvec_L1_ref(const float* x, const float* y, size_t d) {
    size_t i;
    float res = 0;
    for (i = 0; i < d; i++) {
        // ... (loop body unchanged by the reformat)
    }
    return res;
}

float fvec_Linf_ref(const float* x, const float* y, size_t d) {
    size_t i;
    float res = 0;
    for (i = 0; i < d; i++) {
        res = fmax(res, fabs(x[i] - y[i]));
    }
    return res;
}

float fvec_inner_product_ref(const float* x, const float* y, size_t d) {
    size_t i;
    float res = 0;
    for (i = 0; i < d; i++)
        res += x[i] * y[i];
    return res;
}

float fvec_norm_L2sqr_ref(const float* x, size_t d) {
    size_t i;
    double res = 0;
    for (i = 0; i < d; i++)
        res += x[i] * x[i];
    return res;
}

void fvec_L2sqr_ny_ref(
        float* dis,
        const float* x,
        const float* y,
        size_t d,
        size_t ny) {
    for (size_t i = 0; i < ny; i++) {
        dis[i] = fvec_L2sqr(x, y, d);
        y += d;
    }
}

void fvec_inner_products_ny_ref(
        float* ip,
        const float* x,
        const float* y,
        size_t d,
        size_t ny) {
    // BLAS slower for the use cases here
    for (size_t i = 0; i < ny; i++) {
        ip[i] = fvec_inner_product(x, y, d);
        y += d;
    }
}
```
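As a side note, not part of the diff: the header comment above cites the identity |x-y|^2 = |x|^2 + |y|^2 - 2<x|y>, which is what `inner_product_to_L2sqr` relies on to turn precomputed norms and inner products into squared L2 distances. A minimal standalone numeric check of that identity (plain C++, no faiss dependency):

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> x = {1, 2, 3}, y = {4, 5, 6};
    float l2 = 0, nx = 0, ny = 0, ip = 0;
    for (std::size_t i = 0; i < x.size(); i++) {
        l2 += (x[i] - y[i]) * (x[i] - y[i]); // |x-y|^2
        nx += x[i] * x[i];                   // |x|^2
        ny += y[i] * y[i];                   // |y|^2
        ip += x[i] * y[i];                   // <x|y>
    }
    // Both sides print 27 for this input.
    std::printf("%f == %f\n", l2, nx + ny - 2 * ip);
}
```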
@@ -162,40 +141,38 @@ void fvec_inner_products_ny_ref (float * ip,
|
|
|
162
141
|
#ifdef __SSE3__
|
|
163
142
|
|
|
164
143
|
// reads 0 <= d < 4 floats as __m128
|
|
165
|
-
static inline __m128 masked_read
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
__attribute__((__aligned__(16))) float buf[4] = {0, 0, 0, 0};
|
|
144
|
+
static inline __m128 masked_read(int d, const float* x) {
|
|
145
|
+
assert(0 <= d && d < 4);
|
|
146
|
+
ALIGNED(16) float buf[4] = {0, 0, 0, 0};
|
|
169
147
|
switch (d) {
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
}
|
|
177
|
-
return _mm_load_ps
|
|
148
|
+
case 3:
|
|
149
|
+
buf[2] = x[2];
|
|
150
|
+
case 2:
|
|
151
|
+
buf[1] = x[1];
|
|
152
|
+
case 1:
|
|
153
|
+
buf[0] = x[0];
|
|
154
|
+
}
|
|
155
|
+
return _mm_load_ps(buf);
|
|
178
156
|
// cannot use AVX2 _mm_mask_set1_epi32
|
|
179
157
|
}
|
|
180
158
|
|
|
181
|
-
float fvec_norm_L2sqr
|
|
182
|
-
size_t d)
|
|
183
|
-
{
|
|
159
|
+
float fvec_norm_L2sqr(const float* x, size_t d) {
|
|
184
160
|
__m128 mx;
|
|
185
161
|
__m128 msum1 = _mm_setzero_ps();
|
|
186
162
|
|
|
187
163
|
while (d >= 4) {
|
|
188
|
-
mx = _mm_loadu_ps
|
|
189
|
-
|
|
164
|
+
mx = _mm_loadu_ps(x);
|
|
165
|
+
x += 4;
|
|
166
|
+
msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, mx));
|
|
190
167
|
d -= 4;
|
|
191
168
|
}
|
|
192
169
|
|
|
193
|
-
mx = masked_read
|
|
194
|
-
msum1 = _mm_add_ps
|
|
170
|
+
mx = masked_read(d, x);
|
|
171
|
+
msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, mx));
|
|
195
172
|
|
|
196
|
-
msum1 = _mm_hadd_ps
|
|
197
|
-
msum1 = _mm_hadd_ps
|
|
198
|
-
return
|
|
173
|
+
msum1 = _mm_hadd_ps(msum1, msum1);
|
|
174
|
+
msum1 = _mm_hadd_ps(msum1, msum1);
|
|
175
|
+
return _mm_cvtss_f32(msum1);
|
|
199
176
|
}
|
|
200
177
|
|
|
201
178
|
namespace {
|
|
@@ -204,586 +181,588 @@ namespace {
|
|
|
204
181
|
/// to compute L2 distances. ElementOp can then be used in the fvec_op_ny
|
|
205
182
|
/// functions below
|
|
206
183
|
struct ElementOpL2 {
|
|
207
|
-
|
|
208
|
-
static float op (float x, float y) {
|
|
184
|
+
static float op(float x, float y) {
|
|
209
185
|
float tmp = x - y;
|
|
210
186
|
return tmp * tmp;
|
|
211
187
|
}
|
|
212
188
|
|
|
213
|
-
static __m128 op
|
|
214
|
-
__m128 tmp = x
|
|
215
|
-
return tmp
|
|
189
|
+
static __m128 op(__m128 x, __m128 y) {
|
|
190
|
+
__m128 tmp = _mm_sub_ps(x, y);
|
|
191
|
+
return _mm_mul_ps(tmp, tmp);
|
|
216
192
|
}
|
|
217
|
-
|
|
218
193
|
};
|
|
219
194
|
|
|
220
195
|
/// Function that does a component-wise operation between x and y
|
|
221
196
|
/// to compute inner products
|
|
222
197
|
struct ElementOpIP {
|
|
223
|
-
|
|
224
|
-
static float op (float x, float y) {
|
|
198
|
+
static float op(float x, float y) {
|
|
225
199
|
return x * y;
|
|
226
200
|
}
|
|
227
201
|
|
|
228
|
-
static __m128 op
|
|
229
|
-
return x
|
|
202
|
+
static __m128 op(__m128 x, __m128 y) {
|
|
203
|
+
return _mm_mul_ps(x, y);
|
|
230
204
|
}
|
|
231
|
-
|
|
232
205
|
};
|
|
233
206
|
|
|
234
|
-
template<class ElementOp>
|
|
235
|
-
void fvec_op_ny_D1
|
|
236
|
-
const float * y, size_t ny)
|
|
237
|
-
{
|
|
207
|
+
template <class ElementOp>
|
|
208
|
+
void fvec_op_ny_D1(float* dis, const float* x, const float* y, size_t ny) {
|
|
238
209
|
float x0s = x[0];
|
|
239
|
-
__m128 x0 = _mm_set_ps
|
|
210
|
+
__m128 x0 = _mm_set_ps(x0s, x0s, x0s, x0s);
|
|
240
211
|
|
|
241
212
|
size_t i;
|
|
242
213
|
for (i = 0; i + 3 < ny; i += 4) {
|
|
243
|
-
__m128 accu = ElementOp::op(x0, _mm_loadu_ps
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
214
|
+
__m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
|
|
215
|
+
y += 4;
|
|
216
|
+
dis[i] = _mm_cvtss_f32(accu);
|
|
217
|
+
__m128 tmp = _mm_shuffle_ps(accu, accu, 1);
|
|
218
|
+
dis[i + 1] = _mm_cvtss_f32(tmp);
|
|
219
|
+
tmp = _mm_shuffle_ps(accu, accu, 2);
|
|
220
|
+
dis[i + 2] = _mm_cvtss_f32(tmp);
|
|
221
|
+
tmp = _mm_shuffle_ps(accu, accu, 3);
|
|
222
|
+
dis[i + 3] = _mm_cvtss_f32(tmp);
|
|
251
223
|
}
|
|
252
224
|
while (i < ny) { // handle non-multiple-of-4 case
|
|
253
225
|
dis[i++] = ElementOp::op(x0s, *y++);
|
|
254
226
|
}
|
|
255
227
|
}
|
|
256
228
|
|
|
257
|
-
template<class ElementOp>
|
|
258
|
-
void fvec_op_ny_D2
|
|
259
|
-
|
|
260
|
-
{
|
|
261
|
-
__m128 x0 = _mm_set_ps (x[1], x[0], x[1], x[0]);
|
|
229
|
+
template <class ElementOp>
|
|
230
|
+
void fvec_op_ny_D2(float* dis, const float* x, const float* y, size_t ny) {
|
|
231
|
+
__m128 x0 = _mm_set_ps(x[1], x[0], x[1], x[0]);
|
|
262
232
|
|
|
263
233
|
size_t i;
|
|
264
234
|
for (i = 0; i + 1 < ny; i += 2) {
|
|
265
|
-
__m128 accu = ElementOp::op(x0, _mm_loadu_ps
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
235
|
+
__m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
|
|
236
|
+
y += 4;
|
|
237
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
238
|
+
dis[i] = _mm_cvtss_f32(accu);
|
|
239
|
+
accu = _mm_shuffle_ps(accu, accu, 3);
|
|
240
|
+
dis[i + 1] = _mm_cvtss_f32(accu);
|
|
270
241
|
}
|
|
271
242
|
if (i < ny) { // handle odd case
|
|
272
243
|
dis[i] = ElementOp::op(x[0], y[0]) + ElementOp::op(x[1], y[1]);
|
|
273
244
|
}
|
|
274
245
|
}
|
|
275
246
|
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
template<class ElementOp>
|
|
279
|
-
void fvec_op_ny_D4 (float * dis, const float * x,
|
|
280
|
-
const float * y, size_t ny)
|
|
281
|
-
{
|
|
247
|
+
template <class ElementOp>
|
|
248
|
+
void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
|
|
282
249
|
__m128 x0 = _mm_loadu_ps(x);
|
|
283
250
|
|
|
284
251
|
for (size_t i = 0; i < ny; i++) {
|
|
285
|
-
__m128 accu = ElementOp::op(x0, _mm_loadu_ps
|
|
286
|
-
|
|
287
|
-
accu = _mm_hadd_ps
|
|
288
|
-
|
|
252
|
+
__m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
|
|
253
|
+
y += 4;
|
|
254
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
255
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
256
|
+
dis[i] = _mm_cvtss_f32(accu);
|
|
289
257
|
}
|
|
290
258
|
}
|
|
291
259
|
|
|
292
|
-
template<class ElementOp>
|
|
293
|
-
void fvec_op_ny_D8
|
|
294
|
-
const float * y, size_t ny)
|
|
295
|
-
{
|
|
260
|
+
template <class ElementOp>
|
|
261
|
+
void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) {
|
|
296
262
|
__m128 x0 = _mm_loadu_ps(x);
|
|
297
263
|
__m128 x1 = _mm_loadu_ps(x + 4);
|
|
298
264
|
|
|
299
265
|
for (size_t i = 0; i < ny; i++) {
|
|
300
|
-
__m128 accu = ElementOp::op(x0, _mm_loadu_ps
|
|
301
|
-
|
|
302
|
-
accu =
|
|
303
|
-
|
|
304
|
-
|
|
266
|
+
__m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
|
|
267
|
+
y += 4;
|
|
268
|
+
accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y)));
|
|
269
|
+
y += 4;
|
|
270
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
271
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
272
|
+
dis[i] = _mm_cvtss_f32(accu);
|
|
305
273
|
}
|
|
306
274
|
}
|
|
307
275
|
|
|
308
|
-
template<class ElementOp>
|
|
309
|
-
void fvec_op_ny_D12
|
|
310
|
-
const float * y, size_t ny)
|
|
311
|
-
{
|
|
276
|
+
template <class ElementOp>
|
|
277
|
+
void fvec_op_ny_D12(float* dis, const float* x, const float* y, size_t ny) {
|
|
312
278
|
__m128 x0 = _mm_loadu_ps(x);
|
|
313
279
|
__m128 x1 = _mm_loadu_ps(x + 4);
|
|
314
280
|
__m128 x2 = _mm_loadu_ps(x + 8);
|
|
315
281
|
|
|
316
282
|
for (size_t i = 0; i < ny; i++) {
|
|
317
|
-
__m128 accu = ElementOp::op(x0, _mm_loadu_ps
|
|
318
|
-
|
|
319
|
-
accu
|
|
320
|
-
|
|
321
|
-
accu =
|
|
322
|
-
|
|
283
|
+
__m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
|
|
284
|
+
y += 4;
|
|
285
|
+
accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y)));
|
|
286
|
+
y += 4;
|
|
287
|
+
accu = _mm_add_ps(accu, ElementOp::op(x2, _mm_loadu_ps(y)));
|
|
288
|
+
y += 4;
|
|
289
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
290
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
291
|
+
dis[i] = _mm_cvtss_f32(accu);
|
|
323
292
|
}
|
|
324
293
|
}
|
|
325
294
|
|
|
326
|
-
|
|
327
|
-
|
|
328
295
|
} // anonymous namespace
|
|
329
296
|
|
|
330
|
-
void fvec_L2sqr_ny
|
|
331
|
-
|
|
297
|
+
void fvec_L2sqr_ny(
|
|
298
|
+
float* dis,
|
|
299
|
+
const float* x,
|
|
300
|
+
const float* y,
|
|
301
|
+
size_t d,
|
|
302
|
+
size_t ny) {
|
|
332
303
|
// optimized for a few special cases
|
|
333
304
|
|
|
334
|
-
#define DISPATCH(dval)
|
|
335
|
-
case dval
|
|
336
|
-
fvec_op_ny_D
|
|
305
|
+
#define DISPATCH(dval) \
|
|
306
|
+
case dval: \
|
|
307
|
+
fvec_op_ny_D##dval<ElementOpL2>(dis, x, y, ny); \
|
|
337
308
|
return;
|
|
338
309
|
|
|
339
|
-
switch(d) {
|
|
310
|
+
switch (d) {
|
|
340
311
|
DISPATCH(1)
|
|
341
312
|
DISPATCH(2)
|
|
342
313
|
DISPATCH(4)
|
|
343
314
|
DISPATCH(8)
|
|
344
315
|
DISPATCH(12)
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
316
|
+
default:
|
|
317
|
+
fvec_L2sqr_ny_ref(dis, x, y, d, ny);
|
|
318
|
+
return;
|
|
348
319
|
}
|
|
349
320
|
#undef DISPATCH
|
|
350
|
-
|
|
351
321
|
}
|
|
352
322
|
|
|
353
|
-
void fvec_inner_products_ny
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
323
|
+
void fvec_inner_products_ny(
|
|
324
|
+
float* dis,
|
|
325
|
+
const float* x,
|
|
326
|
+
const float* y,
|
|
327
|
+
size_t d,
|
|
328
|
+
size_t ny) {
|
|
329
|
+
#define DISPATCH(dval) \
|
|
330
|
+
case dval: \
|
|
331
|
+
fvec_op_ny_D##dval<ElementOpIP>(dis, x, y, ny); \
|
|
359
332
|
return;
|
|
360
333
|
|
|
361
|
-
switch(d) {
|
|
334
|
+
switch (d) {
|
|
362
335
|
DISPATCH(1)
|
|
363
336
|
DISPATCH(2)
|
|
364
337
|
DISPATCH(4)
|
|
365
338
|
DISPATCH(8)
|
|
366
339
|
DISPATCH(12)
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
340
|
+
default:
|
|
341
|
+
fvec_inner_products_ny_ref(dis, x, y, d, ny);
|
|
342
|
+
return;
|
|
370
343
|
}
|
|
371
344
|
#undef DISPATCH
|
|
372
|
-
|
|
373
345
|
}
|
|
374
346
|
|
|
375
|
-
|
|
376
|
-
|
|
377
347
|
#endif
|
|
378
348
|
|
|
379
349
|
#ifdef USE_AVX
|
|
380
350
|
|
|
381
351
|
// reads 0 <= d < 8 floats as __m256
|
|
382
|
-
static inline __m256 masked_read_8
|
|
383
|
-
|
|
384
|
-
assert (0 <= d && d < 8);
|
|
352
|
+
static inline __m256 masked_read_8(int d, const float* x) {
|
|
353
|
+
assert(0 <= d && d < 8);
|
|
385
354
|
if (d < 4) {
|
|
386
|
-
__m256 res = _mm256_setzero_ps
|
|
387
|
-
res = _mm256_insertf128_ps
|
|
355
|
+
__m256 res = _mm256_setzero_ps();
|
|
356
|
+
res = _mm256_insertf128_ps(res, masked_read(d, x), 0);
|
|
388
357
|
return res;
|
|
389
358
|
} else {
|
|
390
|
-
__m256 res = _mm256_setzero_ps
|
|
391
|
-
res = _mm256_insertf128_ps
|
|
392
|
-
res = _mm256_insertf128_ps
|
|
359
|
+
__m256 res = _mm256_setzero_ps();
|
|
360
|
+
res = _mm256_insertf128_ps(res, _mm_loadu_ps(x), 0);
|
|
361
|
+
res = _mm256_insertf128_ps(res, masked_read(d - 4, x + 4), 1);
|
|
393
362
|
return res;
|
|
394
363
|
}
|
|
395
364
|
}
|
|
396
365
|
|
|
397
|
-
float fvec_inner_product
|
|
398
|
-
const float * y,
|
|
399
|
-
size_t d)
|
|
400
|
-
{
|
|
366
|
+
float fvec_inner_product(const float* x, const float* y, size_t d) {
|
|
401
367
|
__m256 msum1 = _mm256_setzero_ps();
|
|
402
368
|
|
|
403
369
|
while (d >= 8) {
|
|
404
|
-
__m256 mx = _mm256_loadu_ps
|
|
405
|
-
|
|
406
|
-
|
|
370
|
+
__m256 mx = _mm256_loadu_ps(x);
|
|
371
|
+
x += 8;
|
|
372
|
+
__m256 my = _mm256_loadu_ps(y);
|
|
373
|
+
y += 8;
|
|
374
|
+
msum1 = _mm256_add_ps(msum1, _mm256_mul_ps(mx, my));
|
|
407
375
|
d -= 8;
|
|
408
376
|
}
|
|
409
377
|
|
|
410
378
|
__m128 msum2 = _mm256_extractf128_ps(msum1, 1);
|
|
411
|
-
msum2
|
|
379
|
+
msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));
|
|
412
380
|
|
|
413
381
|
if (d >= 4) {
|
|
414
|
-
__m128 mx = _mm_loadu_ps
|
|
415
|
-
|
|
416
|
-
|
|
382
|
+
__m128 mx = _mm_loadu_ps(x);
|
|
383
|
+
x += 4;
|
|
384
|
+
__m128 my = _mm_loadu_ps(y);
|
|
385
|
+
y += 4;
|
|
386
|
+
msum2 = _mm_add_ps(msum2, _mm_mul_ps(mx, my));
|
|
417
387
|
d -= 4;
|
|
418
388
|
}
|
|
419
389
|
|
|
420
390
|
if (d > 0) {
|
|
421
|
-
__m128 mx = masked_read
|
|
422
|
-
__m128 my = masked_read
|
|
423
|
-
msum2 = _mm_add_ps
|
|
391
|
+
__m128 mx = masked_read(d, x);
|
|
392
|
+
__m128 my = masked_read(d, y);
|
|
393
|
+
msum2 = _mm_add_ps(msum2, _mm_mul_ps(mx, my));
|
|
424
394
|
}
|
|
425
395
|
|
|
426
|
-
msum2 = _mm_hadd_ps
|
|
427
|
-
msum2 = _mm_hadd_ps
|
|
428
|
-
return
|
|
396
|
+
msum2 = _mm_hadd_ps(msum2, msum2);
|
|
397
|
+
msum2 = _mm_hadd_ps(msum2, msum2);
|
|
398
|
+
return _mm_cvtss_f32(msum2);
|
|
429
399
|
}
|
|
430
400
|
|
|
431
|
-
float fvec_L2sqr
|
|
432
|
-
const float * y,
|
|
433
|
-
size_t d)
|
|
434
|
-
{
|
|
401
|
+
float fvec_L2sqr(const float* x, const float* y, size_t d) {
|
|
435
402
|
__m256 msum1 = _mm256_setzero_ps();
|
|
436
403
|
|
|
437
404
|
while (d >= 8) {
|
|
438
|
-
__m256 mx = _mm256_loadu_ps
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
405
|
+
__m256 mx = _mm256_loadu_ps(x);
|
|
406
|
+
x += 8;
|
|
407
|
+
__m256 my = _mm256_loadu_ps(y);
|
|
408
|
+
y += 8;
|
|
409
|
+
const __m256 a_m_b1 = _mm256_sub_ps(mx, my);
|
|
410
|
+
msum1 = _mm256_add_ps(msum1, _mm256_mul_ps(a_m_b1, a_m_b1));
|
|
442
411
|
d -= 8;
|
|
443
412
|
}
|
|
444
413
|
|
|
445
414
|
__m128 msum2 = _mm256_extractf128_ps(msum1, 1);
|
|
446
|
-
msum2
|
|
415
|
+
msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));
|
|
447
416
|
|
|
448
417
|
if (d >= 4) {
|
|
449
|
-
__m128 mx = _mm_loadu_ps
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
418
|
+
__m128 mx = _mm_loadu_ps(x);
|
|
419
|
+
x += 4;
|
|
420
|
+
__m128 my = _mm_loadu_ps(y);
|
|
421
|
+
y += 4;
|
|
422
|
+
const __m128 a_m_b1 = _mm_sub_ps(mx, my);
|
|
423
|
+
msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1));
|
|
453
424
|
d -= 4;
|
|
454
425
|
}
|
|
455
426
|
|
|
456
427
|
if (d > 0) {
|
|
457
|
-
__m128 mx = masked_read
|
|
458
|
-
__m128 my = masked_read
|
|
459
|
-
__m128 a_m_b1 = mx
|
|
460
|
-
msum2
|
|
428
|
+
__m128 mx = masked_read(d, x);
|
|
429
|
+
__m128 my = masked_read(d, y);
|
|
430
|
+
__m128 a_m_b1 = _mm_sub_ps(mx, my);
|
|
431
|
+
msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1));
|
|
461
432
|
}
|
|
462
433
|
|
|
463
|
-
msum2 = _mm_hadd_ps
|
|
464
|
-
msum2 = _mm_hadd_ps
|
|
465
|
-
return
|
|
434
|
+
msum2 = _mm_hadd_ps(msum2, msum2);
|
|
435
|
+
msum2 = _mm_hadd_ps(msum2, msum2);
|
|
436
|
+
return _mm_cvtss_f32(msum2);
|
|
466
437
|
}
|
|
467
438
|
|
|
468
|
-
float fvec_L1
|
|
469
|
-
{
|
|
439
|
+
float fvec_L1(const float* x, const float* y, size_t d) {
|
|
470
440
|
__m256 msum1 = _mm256_setzero_ps();
|
|
471
|
-
__m256 signmask =
|
|
441
|
+
__m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL));
|
|
472
442
|
|
|
473
443
|
while (d >= 8) {
|
|
474
|
-
__m256 mx = _mm256_loadu_ps
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
444
|
+
__m256 mx = _mm256_loadu_ps(x);
|
|
445
|
+
x += 8;
|
|
446
|
+
__m256 my = _mm256_loadu_ps(y);
|
|
447
|
+
y += 8;
|
|
448
|
+
const __m256 a_m_b = _mm256_sub_ps(mx, my);
|
|
449
|
+
msum1 = _mm256_add_ps(msum1, _mm256_and_ps(signmask, a_m_b));
|
|
478
450
|
d -= 8;
|
|
479
451
|
}
|
|
480
452
|
|
|
481
453
|
__m128 msum2 = _mm256_extractf128_ps(msum1, 1);
|
|
482
|
-
msum2
|
|
483
|
-
__m128 signmask2 =
|
|
454
|
+
msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));
|
|
455
|
+
__m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL));
|
|
484
456
|
|
|
485
457
|
if (d >= 4) {
|
|
486
|
-
__m128 mx = _mm_loadu_ps
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
458
|
+
__m128 mx = _mm_loadu_ps(x);
|
|
459
|
+
x += 4;
|
|
460
|
+
__m128 my = _mm_loadu_ps(y);
|
|
461
|
+
y += 4;
|
|
462
|
+
const __m128 a_m_b = _mm_sub_ps(mx, my);
|
|
463
|
+
msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b));
|
|
490
464
|
d -= 4;
|
|
491
465
|
}
|
|
492
466
|
|
|
493
467
|
if (d > 0) {
|
|
494
|
-
__m128 mx = masked_read
|
|
495
|
-
__m128 my = masked_read
|
|
496
|
-
__m128 a_m_b = mx
|
|
497
|
-
msum2
|
|
468
|
+
__m128 mx = masked_read(d, x);
|
|
469
|
+
__m128 my = masked_read(d, y);
|
|
470
|
+
__m128 a_m_b = _mm_sub_ps(mx, my);
|
|
471
|
+
msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b));
|
|
498
472
|
}
|
|
499
473
|
|
|
500
|
-
msum2 = _mm_hadd_ps
|
|
501
|
-
msum2 = _mm_hadd_ps
|
|
502
|
-
return
|
|
474
|
+
msum2 = _mm_hadd_ps(msum2, msum2);
|
|
475
|
+
msum2 = _mm_hadd_ps(msum2, msum2);
|
|
476
|
+
return _mm_cvtss_f32(msum2);
|
|
503
477
|
}
|
|
504
478
|
|
|
505
|
-
float fvec_Linf
|
|
506
|
-
{
|
|
479
|
+
float fvec_Linf(const float* x, const float* y, size_t d) {
|
|
507
480
|
__m256 msum1 = _mm256_setzero_ps();
|
|
508
|
-
__m256 signmask =
|
|
481
|
+
__m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL));
|
|
509
482
|
|
|
510
483
|
while (d >= 8) {
|
|
511
|
-
__m256 mx = _mm256_loadu_ps
|
|
512
|
-
|
|
513
|
-
|
|
484
|
+
__m256 mx = _mm256_loadu_ps(x);
|
|
485
|
+
x += 8;
|
|
486
|
+
__m256 my = _mm256_loadu_ps(y);
|
|
487
|
+
y += 8;
|
|
488
|
+
const __m256 a_m_b = _mm256_sub_ps(mx, my);
|
|
514
489
|
msum1 = _mm256_max_ps(msum1, _mm256_and_ps(signmask, a_m_b));
|
|
515
490
|
d -= 8;
|
|
516
491
|
}
|
|
517
492
|
|
|
518
493
|
__m128 msum2 = _mm256_extractf128_ps(msum1, 1);
|
|
519
|
-
msum2 = _mm_max_ps
|
|
520
|
-
__m128 signmask2 =
|
|
494
|
+
msum2 = _mm_max_ps(msum2, _mm256_extractf128_ps(msum1, 0));
|
|
495
|
+
__m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL));
|
|
521
496
|
|
|
522
497
|
if (d >= 4) {
|
|
523
|
-
__m128 mx = _mm_loadu_ps
|
|
524
|
-
|
|
525
|
-
|
|
498
|
+
__m128 mx = _mm_loadu_ps(x);
|
|
499
|
+
x += 4;
|
|
500
|
+
__m128 my = _mm_loadu_ps(y);
|
|
501
|
+
y += 4;
|
|
502
|
+
const __m128 a_m_b = _mm_sub_ps(mx, my);
|
|
526
503
|
msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b));
|
|
527
504
|
d -= 4;
|
|
528
505
|
}
|
|
529
506
|
|
|
530
507
|
if (d > 0) {
|
|
531
|
-
__m128 mx = masked_read
|
|
532
|
-
__m128 my = masked_read
|
|
533
|
-
__m128 a_m_b = mx
|
|
508
|
+
__m128 mx = masked_read(d, x);
|
|
509
|
+
__m128 my = masked_read(d, y);
|
|
510
|
+
__m128 a_m_b = _mm_sub_ps(mx, my);
|
|
534
511
|
msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b));
|
|
535
512
|
}
|
|
536
513
|
|
|
537
514
|
msum2 = _mm_max_ps(_mm_movehl_ps(msum2, msum2), msum2);
|
|
538
|
-
msum2 = _mm_max_ps(msum2, _mm_shuffle_ps
|
|
539
|
-
return
|
|
515
|
+
msum2 = _mm_max_ps(msum2, _mm_shuffle_ps(msum2, msum2, 1));
|
|
516
|
+
return _mm_cvtss_f32(msum2);
|
|
540
517
|
}
|
|
541
518
|
|
|
542
519
|
#elif defined(__SSE3__) // But not AVX
|
|
543
520
|
|
|
544
|
-
float fvec_L1
|
|
545
|
-
|
|
546
|
-
return fvec_L1_ref (x, y, d);
|
|
521
|
+
float fvec_L1(const float* x, const float* y, size_t d) {
|
|
522
|
+
return fvec_L1_ref(x, y, d);
|
|
547
523
|
}
|
|
548
524
|
|
|
549
|
-
float fvec_Linf
|
|
550
|
-
|
|
551
|
-
return fvec_Linf_ref (x, y, d);
|
|
525
|
+
float fvec_Linf(const float* x, const float* y, size_t d) {
|
|
526
|
+
return fvec_Linf_ref(x, y, d);
|
|
552
527
|
}
|
|
553
528
|
|
|
554
|
-
|
|
555
|
-
float fvec_L2sqr (const float * x,
|
|
556
|
-
const float * y,
|
|
557
|
-
size_t d)
|
|
558
|
-
{
|
|
529
|
+
float fvec_L2sqr(const float* x, const float* y, size_t d) {
|
|
559
530
|
__m128 msum1 = _mm_setzero_ps();
|
|
560
531
|
|
|
561
532
|
while (d >= 4) {
|
|
562
|
-
__m128 mx = _mm_loadu_ps
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
533
|
+
__m128 mx = _mm_loadu_ps(x);
|
|
534
|
+
x += 4;
|
|
535
|
+
__m128 my = _mm_loadu_ps(y);
|
|
536
|
+
y += 4;
|
|
537
|
+
const __m128 a_m_b1 = _mm_sub_ps(mx, my);
|
|
538
|
+
msum1 = _mm_add_ps(msum1, _mm_mul_ps(a_m_b1, a_m_b1));
|
|
566
539
|
d -= 4;
|
|
567
540
|
}
|
|
568
541
|
|
|
569
542
|
if (d > 0) {
|
|
570
543
|
// add the last 1, 2 or 3 values
|
|
571
|
-
__m128 mx = masked_read
|
|
572
|
-
__m128 my = masked_read
|
|
573
|
-
__m128 a_m_b1 = mx
|
|
574
|
-
msum1
|
|
544
|
+
__m128 mx = masked_read(d, x);
|
|
545
|
+
__m128 my = masked_read(d, y);
|
|
546
|
+
__m128 a_m_b1 = _mm_sub_ps(mx, my);
|
|
547
|
+
msum1 = _mm_add_ps(msum1, _mm_mul_ps(a_m_b1, a_m_b1));
|
|
575
548
|
}
|
|
576
549
|
|
|
577
|
-
msum1 = _mm_hadd_ps
|
|
578
|
-
msum1 = _mm_hadd_ps
|
|
579
|
-
return
|
|
550
|
+
msum1 = _mm_hadd_ps(msum1, msum1);
|
|
551
|
+
msum1 = _mm_hadd_ps(msum1, msum1);
|
|
552
|
+
return _mm_cvtss_f32(msum1);
|
|
580
553
|
}
|
|
581
554
|
|
|
582
|
-
|
|
583
|
-
float fvec_inner_product (const float * x,
|
|
584
|
-
const float * y,
|
|
585
|
-
size_t d)
|
|
586
|
-
{
|
|
555
|
+
float fvec_inner_product(const float* x, const float* y, size_t d) {
|
|
587
556
|
__m128 mx, my;
|
|
588
557
|
__m128 msum1 = _mm_setzero_ps();
|
|
589
558
|
|
|
590
559
|
while (d >= 4) {
|
|
591
|
-
mx = _mm_loadu_ps
|
|
592
|
-
|
|
593
|
-
|
|
560
|
+
mx = _mm_loadu_ps(x);
|
|
561
|
+
x += 4;
|
|
562
|
+
my = _mm_loadu_ps(y);
|
|
563
|
+
y += 4;
|
|
564
|
+
msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, my));
|
|
594
565
|
d -= 4;
|
|
595
566
|
}
|
|
596
567
|
|
|
597
568
|
// add the last 1, 2, or 3 values
|
|
598
|
-
mx = masked_read
|
|
599
|
-
my = masked_read
|
|
600
|
-
__m128 prod = _mm_mul_ps
|
|
569
|
+
mx = masked_read(d, x);
|
|
570
|
+
my = masked_read(d, y);
|
|
571
|
+
__m128 prod = _mm_mul_ps(mx, my);
|
|
601
572
|
|
|
602
|
-
msum1 = _mm_add_ps
|
|
573
|
+
msum1 = _mm_add_ps(msum1, prod);
|
|
603
574
|
|
|
604
|
-
msum1 = _mm_hadd_ps
|
|
605
|
-
msum1 = _mm_hadd_ps
|
|
606
|
-
return
|
|
575
|
+
msum1 = _mm_hadd_ps(msum1, msum1);
|
|
576
|
+
msum1 = _mm_hadd_ps(msum1, msum1);
|
|
577
|
+
return _mm_cvtss_f32(msum1);
|
|
607
578
|
}
|
|
608
579
|
|
|
609
580
|
#elif defined(__aarch64__)
|
|
610
581
|
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
{
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
582
|
+
float fvec_L2sqr(const float* x, const float* y, size_t d) {
|
|
583
|
+
float32x4_t accux4 = vdupq_n_f32(0);
|
|
584
|
+
const size_t d_simd = d - (d & 3);
|
|
585
|
+
size_t i;
|
|
586
|
+
for (i = 0; i < d_simd; i += 4) {
|
|
587
|
+
float32x4_t xi = vld1q_f32(x + i);
|
|
588
|
+
float32x4_t yi = vld1q_f32(y + i);
|
|
589
|
+
float32x4_t sq = vsubq_f32(xi, yi);
|
|
590
|
+
accux4 = vfmaq_f32(accux4, sq, sq);
|
|
591
|
+
}
|
|
592
|
+
float32x4_t accux2 = vpaddq_f32(accux4, accux4);
|
|
593
|
+
float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
|
|
594
|
+
for (; i < d; ++i) {
|
|
595
|
+
float32_t xi = x[i];
|
|
596
|
+
float32_t yi = y[i];
|
|
597
|
+
float32_t sq = xi - yi;
|
|
598
|
+
accux1 += sq * sq;
|
|
599
|
+
}
|
|
600
|
+
return accux1;
|
|
601
|
+
}
|
|
602
|
+
|
|
603
|
+
float fvec_inner_product(const float* x, const float* y, size_t d) {
|
|
604
|
+
float32x4_t accux4 = vdupq_n_f32(0);
|
|
605
|
+
const size_t d_simd = d - (d & 3);
|
|
606
|
+
size_t i;
|
|
607
|
+
for (i = 0; i < d_simd; i += 4) {
|
|
608
|
+
float32x4_t xi = vld1q_f32(x + i);
|
|
609
|
+
float32x4_t yi = vld1q_f32(y + i);
|
|
610
|
+
accux4 = vfmaq_f32(accux4, xi, yi);
|
|
623
611
|
}
|
|
624
|
-
float32x4_t
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
size_t d)
|
|
631
|
-
{
|
|
632
|
-
if (d & 3) return fvec_inner_product_ref (x, y, d);
|
|
633
|
-
float32x4_t accu = vdupq_n_f32 (0);
|
|
634
|
-
for (size_t i = 0; i < d; i += 4) {
|
|
635
|
-
float32x4_t xi = vld1q_f32 (x + i);
|
|
636
|
-
float32x4_t yi = vld1q_f32 (y + i);
|
|
637
|
-
accu = vfmaq_f32 (accu, xi, yi);
|
|
612
|
+
float32x4_t accux2 = vpaddq_f32(accux4, accux4);
|
|
613
|
+
float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
|
|
614
|
+
for (; i < d; ++i) {
|
|
615
|
+
float32_t xi = x[i];
|
|
616
|
+
float32_t yi = y[i];
|
|
617
|
+
accux1 += xi * yi;
|
|
638
618
|
}
|
|
639
|
-
|
|
640
|
-
return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
|
|
619
|
+
return accux1;
|
|
641
620
|
}
|
|
642
621
|
|
|
643
|
-
float fvec_norm_L2sqr
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
for (
|
|
648
|
-
float32x4_t xi = vld1q_f32
|
|
649
|
-
|
|
622
|
+
float fvec_norm_L2sqr(const float* x, size_t d) {
|
|
623
|
+
float32x4_t accux4 = vdupq_n_f32(0);
|
|
624
|
+
const size_t d_simd = d - (d & 3);
|
|
625
|
+
size_t i;
|
|
626
|
+
for (i = 0; i < d_simd; i += 4) {
|
|
627
|
+
float32x4_t xi = vld1q_f32(x + i);
|
|
628
|
+
accux4 = vfmaq_f32(accux4, xi, xi);
|
|
650
629
|
}
|
|
651
|
-
float32x4_t
|
|
652
|
-
|
|
630
|
+
float32x4_t accux2 = vpaddq_f32(accux4, accux4);
|
|
631
|
+
float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
|
|
632
|
+
for (; i < d; ++i) {
|
|
633
|
+
float32_t xi = x[i];
|
|
634
|
+
accux1 += xi * xi;
|
|
635
|
+
}
|
|
636
|
+
return accux1;
|
|
653
637
|
}
|
|
654
638
|
|
|
655
639
|
// not optimized for ARM
|
|
656
|
-
void fvec_L2sqr_ny
|
|
657
|
-
|
|
658
|
-
|
|
640
|
+
void fvec_L2sqr_ny(
|
|
641
|
+
float* dis,
|
|
642
|
+
const float* x,
|
|
643
|
+
const float* y,
|
|
644
|
+
size_t d,
|
|
645
|
+
size_t ny) {
|
|
646
|
+
fvec_L2sqr_ny_ref(dis, x, y, d, ny);
|
|
659
647
|
}
|
|
660
648
|
|
|
661
|
-
float fvec_L1
|
|
662
|
-
|
|
663
|
-
return fvec_L1_ref (x, y, d);
|
|
649
|
+
float fvec_L1(const float* x, const float* y, size_t d) {
|
|
650
|
+
return fvec_L1_ref(x, y, d);
|
|
664
651
|
}
|
|
665
652
|
|
|
666
|
-
float fvec_Linf
|
|
667
|
-
|
|
668
|
-
return fvec_Linf_ref (x, y, d);
|
|
653
|
+
float fvec_Linf(const float* x, const float* y, size_t d) {
|
|
654
|
+
return fvec_Linf_ref(x, y, d);
|
|
669
655
|
}
|
|
670
656
|
|
|
657
|
+
void fvec_inner_products_ny(
|
|
658
|
+
float* dis,
|
|
659
|
+
const float* x,
|
|
660
|
+
const float* y,
|
|
661
|
+
size_t d,
|
|
662
|
+
size_t ny) {
|
|
663
|
+
fvec_inner_products_ny_ref(dis, x, y, d, ny);
|
|
664
|
+
}
|
|
671
665
|
|
|
672
666
|
#else
|
|
673
667
|
// scalar implementation
|
|
674
668
|
|
|
675
|
-
float fvec_L2sqr
|
|
676
|
-
|
|
677
|
-
size_t d)
|
|
678
|
-
{
|
|
679
|
-
return fvec_L2sqr_ref (x, y, d);
|
|
669
|
+
float fvec_L2sqr(const float* x, const float* y, size_t d) {
|
|
670
|
+
return fvec_L2sqr_ref(x, y, d);
|
|
680
671
|
}
|
|
681
672
|
|
|
682
|
-
float fvec_L1
|
|
683
|
-
|
|
684
|
-
return fvec_L1_ref (x, y, d);
|
|
673
|
+
float fvec_L1(const float* x, const float* y, size_t d) {
|
|
674
|
+
return fvec_L1_ref(x, y, d);
|
|
685
675
|
}
|
|
686
676
|
|
|
687
|
-
float fvec_Linf
|
|
688
|
-
|
|
689
|
-
return fvec_Linf_ref (x, y, d);
|
|
677
|
+
float fvec_Linf(const float* x, const float* y, size_t d) {
|
|
678
|
+
return fvec_Linf_ref(x, y, d);
|
|
690
679
|
}
|
|
691
680
|
|
|
692
|
-
float fvec_inner_product
|
|
693
|
-
|
|
694
|
-
size_t d)
|
|
695
|
-
{
|
|
696
|
-
return fvec_inner_product_ref (x, y, d);
|
|
681
|
+
float fvec_inner_product(const float* x, const float* y, size_t d) {
|
|
682
|
+
return fvec_inner_product_ref(x, y, d);
|
|
697
683
|
}
|
|
698
684
|
|
|
699
|
-
float fvec_norm_L2sqr
|
|
700
|
-
|
|
701
|
-
return fvec_norm_L2sqr_ref (x, d);
|
|
685
|
+
float fvec_norm_L2sqr(const float* x, size_t d) {
|
|
686
|
+
return fvec_norm_L2sqr_ref(x, d);
|
|
702
687
|
}
|
|
703
688
|
|
|
704
|
-
void fvec_L2sqr_ny
|
|
705
|
-
|
|
706
|
-
|
|
689
|
+
void fvec_L2sqr_ny(
|
|
690
|
+
float* dis,
|
|
691
|
+
const float* x,
|
|
692
|
+
const float* y,
|
|
693
|
+
size_t d,
|
|
694
|
+
size_t ny) {
|
|
695
|
+
fvec_L2sqr_ny_ref(dis, x, y, d, ny);
|
|
707
696
|
}
|
|
708
697
|
|
|
709
|
-
void fvec_inner_products_ny
|
|
710
|
-
|
|
711
|
-
|
|
698
|
+
void fvec_inner_products_ny(
|
|
699
|
+
float* dis,
|
|
700
|
+
const float* x,
|
|
701
|
+
const float* y,
|
|
702
|
+
size_t d,
|
|
703
|
+
size_t ny) {
|
|
704
|
+
fvec_inner_products_ny_ref(dis, x, y, d, ny);
|
|
712
705
|
}
|
|
713
706
|
|
|
714
|
-
|
|
715
707
|
#endif

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
 /***************************************************************************
  * heavily optimized table computations
  ***************************************************************************/

-
-static inline void fvec_madd_ref (size_t n, const float *a,
-                                  float bf, const float *b, float *c) {
+static inline void fvec_madd_ref(
+        size_t n,
+        const float* a,
+        float bf,
+        const float* b,
+        float* c) {
     for (size_t i = 0; i < n; i++)
         c[i] = a[i] + bf * b[i];
 }

 #ifdef __SSE3__

-static inline void fvec_madd_sse (size_t n, const float *a,
-                                  float bf, const float *b, float *c) {
+static inline void fvec_madd_sse(
+        size_t n,
+        const float* a,
+        float bf,
+        const float* b,
+        float* c) {
     n >>= 2;
-    __m128 bf4 = _mm_set_ps1 (bf);
-    __m128 * a4 = (__m128 *) a;
-    __m128 * b4 = (__m128 *) b;
-    __m128 * c4 = (__m128 *) c;
+    __m128 bf4 = _mm_set_ps1(bf);
+    __m128* a4 = (__m128*)a;
+    __m128* b4 = (__m128*)b;
+    __m128* c4 = (__m128*)c;

     while (n--) {
-        *c4 = _mm_add_ps (*a4, _mm_mul_ps (bf4, *b4));
+        *c4 = _mm_add_ps(*a4, _mm_mul_ps(bf4, *b4));
         b4++;
         a4++;
         c4++;
     }
 }

-void fvec_madd (size_t n, const float *a,
-                float bf, const float *b, float *c)
-{
-    if ((n & 3) == 0 &&
-        ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
-        fvec_madd_sse (n, a, bf, b, c);
+void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
+    if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
+        fvec_madd_sse(n, a, bf, b, c);
     else
-        fvec_madd_ref (n, a, bf, b, c);
+        fvec_madd_ref(n, a, bf, b, c);
 }

 #else

-void fvec_madd (size_t n, const float *a,
-                float bf, const float *b, float *c)
-{
-    fvec_madd_ref (n, a, bf, b, c);
+void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
+    fvec_madd_ref(n, a, bf, b, c);
 }

 #endif
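`fvec_madd` computes `c[i] = a[i] + bf * b[i]`. The SSE variant processes four floats per `__m128`, so the dispatcher above only takes it when `n` is a multiple of 4 and all three buffers are 16-byte aligned. A hedged sketch of that predicate (the helper name is ours; the vendored code casts to `long`, `uintptr_t` is used here for portability):

```cpp
#include <cstddef>
#include <cstdint>

// Sketch of fvec_madd's dispatch test: OR-ing the three addresses and
// masking the low 4 bits checks 16-byte alignment of all of them at once.
bool can_use_sse_path(size_t n, const float* a, const float* b, const float* c) {
    const bool n_multiple_of_4 = (n & 3) == 0;
    const uintptr_t merged = reinterpret_cast<uintptr_t>(a) |
            reinterpret_cast<uintptr_t>(b) |
            reinterpret_cast<uintptr_t>(c);
    return n_multiple_of_4 && (merged & 15) == 0;
}
```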

-static inline int fvec_madd_and_argmin_ref (size_t n, const float *a,
-                                            float bf, const float *b, float *c) {
+static inline int fvec_madd_and_argmin_ref(
+        size_t n,
+        const float* a,
+        float bf,
+        const float* b,
+        float* c) {
     float vmin = 1e20;
     int imin = -1;

@@ -799,125 +778,100 @@ static inline int fvec_madd_and_argmin_ref (size_t n, const float *a,

 #ifdef __SSE3__

-static inline int fvec_madd_and_argmin_sse (
-        size_t n, const float *a,
-        float bf, const float *b, float *c) {
+static inline int fvec_madd_and_argmin_sse(
+        size_t n,
+        const float* a,
+        float bf,
+        const float* b,
+        float* c) {
     n >>= 2;
-    __m128 bf4 = _mm_set_ps1 (bf);
-    __m128 vmin4 = _mm_set_ps1 (1e20);
-    __m128i imin4 = _mm_set1_epi32 (-1);
-    __m128i idx4 = _mm_set_epi32 (3, 2, 1, 0);
-    __m128i inc4 = _mm_set1_epi32 (4);
-    __m128 * a4 = (__m128 *) a;
-    __m128 * b4 = (__m128 *) b;
-    __m128 * c4 = (__m128 *) c;
+    __m128 bf4 = _mm_set_ps1(bf);
+    __m128 vmin4 = _mm_set_ps1(1e20);
+    __m128i imin4 = _mm_set1_epi32(-1);
+    __m128i idx4 = _mm_set_epi32(3, 2, 1, 0);
+    __m128i inc4 = _mm_set1_epi32(4);
+    __m128* a4 = (__m128*)a;
+    __m128* b4 = (__m128*)b;
+    __m128* c4 = (__m128*)c;

     while (n--) {
-        __m128 vc4 = _mm_add_ps (*a4, _mm_mul_ps (bf4, *b4));
+        __m128 vc4 = _mm_add_ps(*a4, _mm_mul_ps(bf4, *b4));
         *c4 = vc4;
-        __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
+        __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4));
         // imin4 = _mm_blendv_epi8 (imin4, idx4, mask); // slower!

-        imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
-                              _mm_andnot_si128 (mask, imin4));
-        vmin4 = _mm_min_ps (vmin4, vc4);
+        imin4 = _mm_or_si128(
+                _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4));
+        vmin4 = _mm_min_ps(vmin4, vc4);
         b4++;
         a4++;
         c4++;
-        idx4 = _mm_add_epi32 (idx4, inc4);
+        idx4 = _mm_add_epi32(idx4, inc4);
     }

     // 4 values -> 2
     {
-        idx4 = _mm_shuffle_epi32 (imin4, 3 << 2 | 2);
-        __m128 vc4 = _mm_shuffle_ps (vmin4, vmin4, 3 << 2 | 2);
-        __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
-        imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
-                              _mm_andnot_si128 (mask, imin4));
-        vmin4 = _mm_min_ps (vmin4, vc4);
+        idx4 = _mm_shuffle_epi32(imin4, 3 << 2 | 2);
+        __m128 vc4 = _mm_shuffle_ps(vmin4, vmin4, 3 << 2 | 2);
+        __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4));
+        imin4 = _mm_or_si128(
+                _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4));
+        vmin4 = _mm_min_ps(vmin4, vc4);
     }
     // 2 values -> 1
     {
-        idx4 = _mm_shuffle_epi32 (imin4, 1);
-        __m128 vc4 = _mm_shuffle_ps (vmin4, vmin4, 1);
-        __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
-        imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
-                              _mm_andnot_si128 (mask, imin4));
+        idx4 = _mm_shuffle_epi32(imin4, 1);
+        __m128 vc4 = _mm_shuffle_ps(vmin4, vmin4, 1);
+        __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4));
+        imin4 = _mm_or_si128(
+                _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4));
         // vmin4 = _mm_min_ps (vmin4, vc4);
     }
-    return _mm_cvtsi128_si32 (imin4);
+    return _mm_cvtsi128_si32(imin4);
 }
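Besides writing `c`, this kernel tracks the argmin of the freshly computed values without branching: `mask` is all-ones in each lane where the new value beats the running minimum, so the AND/ANDNOT/OR combination selects `idx4` in those lanes and keeps `imin4` elsewhere; the two blocks after the loop then fold the four per-lane minima down to one. A scalar rendering of the per-lane select (illustrative helper, not vendored code):

```cpp
#include <cstdint>

// Scalar model of the branchless select done per SSE lane:
// mask is -1 (all bits set) when the new value is smaller, else 0,
// so (mask & idx) | (~mask & imin) picks the new index only then.
int32_t select_min_index(bool new_is_smaller, int32_t idx, int32_t imin) {
    const int32_t mask = new_is_smaller ? -1 : 0;
    return (mask & idx) | (~mask & imin);
}
```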

-int fvec_madd_and_argmin (size_t n, const float *a,
-                          float bf, const float *b, float *c)
-{
-    if ((n & 3) == 0 &&
-        ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
-        return fvec_madd_and_argmin_sse (n, a, bf, b, c);
+int fvec_madd_and_argmin(
+        size_t n,
+        const float* a,
+        float bf,
+        const float* b,
+        float* c) {
+    if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
+        return fvec_madd_and_argmin_sse(n, a, bf, b, c);
     else
-        return fvec_madd_and_argmin_ref (n, a, bf, b, c);
+        return fvec_madd_and_argmin_ref(n, a, bf, b, c);
 }

 #else

-int fvec_madd_and_argmin (size_t n, const float *a,
-                          float bf, const float *b, float *c)
-{
-    return fvec_madd_and_argmin_ref (n, a, bf, b, c);
+int fvec_madd_and_argmin(
+        size_t n,
+        const float* a,
+        float bf,
+        const float* b,
+        float* c) {
+    return fvec_madd_and_argmin_ref(n, a, bf, b, c);
 }

 #endif

-
 /***************************************************************************
  * PQ tables computations
  ***************************************************************************/

-#ifdef __AVX2__
-
 namespace {

-
-// get even float32's of a and b, interleaved
-simd8float32 geteven(simd8float32 a, simd8float32 b) {
-    return simd8float32(
-        _mm256_shuffle_ps(a.f, b.f, 0 << 0 | 2 << 2 | 0 << 4 | 2 << 6)
-    );
-}
-
-// get odd float32's of a and b, interleaved
-simd8float32 getodd(simd8float32 a, simd8float32 b) {
-    return simd8float32(
-        _mm256_shuffle_ps(a.f, b.f, 1 << 0 | 3 << 2 | 1 << 4 | 3 << 6)
-    );
-}
-
-// 3 cycles
-// if the lanes are a = [a0 a1] and b = [b0 b1], return [a0 b0]
-simd8float32 getlow128(simd8float32 a, simd8float32 b) {
-    return simd8float32(
-        _mm256_permute2f128_ps(a.f, b.f, 0 | 2 << 4)
-    );
-}
-
-simd8float32 gethigh128(simd8float32 a, simd8float32 b) {
-    return simd8float32(
-        _mm256_permute2f128_ps(a.f, b.f, 1 | 3 << 4)
-    );
-}
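The four lane helpers deleted here (`geteven`, `getodd`, `getlow128`, `gethigh128`) are removed from this file by this change. `geteven`/`getodd` pick the even- or odd-indexed floats of two 8-wide vectors, which the dsub = 2 kernel below uses to split interleaved (x0, x1) component pairs. Note that `_mm256_shuffle_ps` operates within each 128-bit lane, so the result is lane-interleaved rather than fully packed; a scalar model of `geteven`, for intuition only:

```cpp
// Scalar model of the removed geteven(): with shuffle control
// 0 | 2 << 2 | 0 << 4 | 2 << 6, each 128-bit lane of the result holds
// [a0, a2, b0, b2], i.e. the even floats of a and b, per lane.
void geteven_model(const float a[8], const float b[8], float out[8]) {
    for (int lane = 0; lane < 2; lane++) {
        const int o = 4 * lane; // base offset of this 128-bit lane
        out[o + 0] = a[o + 0];
        out[o + 1] = a[o + 2];
        out[o + 2] = b[o + 0];
        out[o + 3] = b[o + 2];
    }
}
```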
-
 /// compute the IP for dsub = 2 for 8 centroids and 4 sub-vectors at a time
-template<bool is_inner_product>
+template <bool is_inner_product>
 void pq2_8cents_table(
         const simd8float32 centroids[8],
         const simd8float32 x,
-        float *out, size_t ldo,
-        size_t nout = 4)
-{
+        float* out,
+        size_t ldo,
+        size_t nout = 4) {
     simd8float32 ips[4];

-    for(int i = 0; i < 4; i++) {
+    for (int i = 0; i < 4; i++) {
         simd8float32 p1, p2;
         if (is_inner_product) {
             p1 = x * centroids[2 * i];
@@ -941,21 +895,21 @@ void pq2_8cents_table(
     simd8float32 ip1 = getlow128(ip13a, ip13b);
     simd8float32 ip3 = gethigh128(ip13a, ip13b);

-    switch(nout) {
-    case 4:
-        ip3.storeu(out + 3 * ldo);
-    case 3:
-        ip2.storeu(out + 2 * ldo);
-    case 2:
-        ip1.storeu(out + 1 * ldo);
-    case 1:
-        ip0.storeu(out);
+    switch (nout) {
+        case 4:
+            ip3.storeu(out + 3 * ldo);
+        case 3:
+            ip2.storeu(out + 2 * ldo);
+        case 2:
+            ip1.storeu(out + 1 * ldo);
+        case 1:
+            ip0.storeu(out);
     }
 }

-simd8float32 load_simd8float32_partial(const float *x, int n) {
+simd8float32 load_simd8float32_partial(const float* x, int n) {
     ALIGNED(32) float tmp[8] = {0, 0, 0, 0, 0, 0, 0, 0};
-    float *wp = tmp;
+    float* wp = tmp;
     for (int i = 0; i < n; i++) {
         *wp++ = *x++;
     }
@@ -964,25 +918,23 @@ simd8float32 load_simd8float32_partial(const float *x, int n) {

 } // anonymous namespace

-
-
-
 void compute_PQ_dis_tables_dsub2(
-        size_t d, size_t ksub, const float *all_centroids,
-        size_t nx, const float * x,
+        size_t d,
+        size_t ksub,
+        const float* all_centroids,
+        size_t nx,
+        const float* x,
         bool is_inner_product,
-        float * dis_tables)
-{
+        float* dis_tables) {
     size_t M = d / 2;
     FAISS_THROW_IF_NOT(ksub % 8 == 0);

-    for(size_t m0 = 0; m0 < M; m0 += 4) {
+    for (size_t m0 = 0; m0 < M; m0 += 4) {
         int m1 = std::min(M, m0 + 4);
-        for(int k0 = 0; k0 < ksub; k0 += 8) {
-
+        for (int k0 = 0; k0 < ksub; k0 += 8) {
             simd8float32 centroids[8];
             for (int k = 0; k < 8; k++) {
-                float centroid[8] ALIGNED(32);
+                ALIGNED(32) float centroid[8];
                 size_t wp = 0;
                 size_t rp = (m0 * ksub + k + k0) * 2;
                 for (int m = m0; m < m1; m++) {
@@ -992,45 +944,33 @@ void compute_PQ_dis_tables_dsub2(
                 }
                 centroids[k] = simd8float32(centroid);
             }
-            for(size_t i = 0; i < nx; i++) {
+            for (size_t i = 0; i < nx; i++) {
                 simd8float32 xi;
                 if (m1 == m0 + 4) {
                     xi.loadu(x + i * d + m0 * 2);
                 } else {
-                    xi = load_simd8float32_partial(x + i * d + m0 * 2, 2 * (m1 - m0));
+                    xi = load_simd8float32_partial(
+                            x + i * d + m0 * 2, 2 * (m1 - m0));
                 }

-                if(is_inner_product) {
+                if (is_inner_product) {
                     pq2_8cents_table<true>(
-                        centroids, xi,
-                        dis_tables + (i * M + m0) * ksub + k0,
-                        ksub, m1 - m0
-                    );
+                            centroids,
+                            xi,
+                            dis_tables + (i * M + m0) * ksub + k0,
+                            ksub,
+                            m1 - m0);
                 } else {
                     pq2_8cents_table<false>(
-                        centroids, xi,
-                        dis_tables + (i * M + m0) * ksub + k0,
-                        ksub, m1 - m0
-                    );
+                            centroids,
+                            xi,
+                            dis_tables + (i * M + m0) * ksub + k0,
+                            ksub,
+                            m1 - m0);
                 }
             }
         }
     }
-
 }
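From the call sites above, the output layout is query-major: entry `(i, m, k)` of the distance tables sits at `dis_tables[(i * M + m) * ksub + k]`, with `M = d / 2` sub-quantizers of `ksub` centroids each. A hedged indexing helper making that explicit (the function name is ours):

```cpp
#include <cstddef>

// Addressing scheme implied by compute_PQ_dis_tables_dsub2's call sites:
// query i, sub-quantizer m, centroid k -> flat offset into dis_tables.
inline size_t dis_table_offset(size_t i, size_t m, size_t k, size_t M, size_t ksub) {
    return (i * M + m) * ksub + k;
}
```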

-#else
-
-void compute_PQ_dis_tables_dsub2(
-        size_t d, size_t ksub, const float *all_centroids,
-        size_t nx, const float * x,
-        bool is_inner_product,
-        float * dis_tables)
-{
-    FAISS_THROW_MSG("only implemented for AVX2");
-}
-
-#endif
-
-
 } // namespace faiss