faiss 0.2.0 → 0.2.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +16 -0
- data/LICENSE.txt +1 -1
- data/README.md +7 -7
- data/ext/faiss/extconf.rb +6 -3
- data/ext/faiss/numo.hpp +4 -4
- data/ext/faiss/utils.cpp +1 -1
- data/ext/faiss/utils.h +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +292 -291
- data/vendor/faiss/faiss/AutoTune.h +55 -56
- data/vendor/faiss/faiss/Clustering.cpp +365 -194
- data/vendor/faiss/faiss/Clustering.h +102 -35
- data/vendor/faiss/faiss/IVFlib.cpp +171 -195
- data/vendor/faiss/faiss/IVFlib.h +48 -51
- data/vendor/faiss/faiss/Index.cpp +85 -103
- data/vendor/faiss/faiss/Index.h +54 -48
- data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
- data/vendor/faiss/faiss/Index2Layer.h +22 -36
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
- data/vendor/faiss/faiss/IndexBinary.h +140 -132
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
- data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
- data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
- data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
- data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
- data/vendor/faiss/faiss/IndexFlat.h +42 -59
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
- data/vendor/faiss/faiss/IndexHNSW.h +57 -41
- data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
- data/vendor/faiss/faiss/IndexIVF.h +169 -118
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
- data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
- data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
- data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
- data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
- data/vendor/faiss/faiss/IndexLSH.h +20 -38
- data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
- data/vendor/faiss/faiss/IndexLattice.h +11 -16
- data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
- data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
- data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
- data/vendor/faiss/faiss/IndexNSG.h +85 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
- data/vendor/faiss/faiss/IndexPQ.h +64 -82
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
- data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
- data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
- data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
- data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
- data/vendor/faiss/faiss/IndexRefine.h +32 -23
- data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
- data/vendor/faiss/faiss/IndexReplicas.h +62 -56
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
- data/vendor/faiss/faiss/IndexShards.cpp +256 -240
- data/vendor/faiss/faiss/IndexShards.h +85 -73
- data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
- data/vendor/faiss/faiss/MatrixStats.h +7 -10
- data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
- data/vendor/faiss/faiss/MetaIndexes.h +40 -34
- data/vendor/faiss/faiss/MetricType.h +7 -7
- data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
- data/vendor/faiss/faiss/VectorTransform.h +64 -89
- data/vendor/faiss/faiss/clone_index.cpp +78 -73
- data/vendor/faiss/faiss/clone_index.h +4 -9
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
- data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
- data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
- data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
- data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
- data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
- data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
- data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
- data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
- data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
- data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
- data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
- data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
- data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
- data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
- data/vendor/faiss/faiss/impl/FaissException.h +41 -29
- data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
- data/vendor/faiss/faiss/impl/HNSW.h +179 -200
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
- data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
- data/vendor/faiss/faiss/impl/NSG.h +199 -0
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
- data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
- data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
- data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
- data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
- data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
- data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
- data/vendor/faiss/faiss/impl/io.cpp +76 -95
- data/vendor/faiss/faiss/impl/io.h +31 -41
- data/vendor/faiss/faiss/impl/io_macros.h +60 -29
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
- data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
- data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
- data/vendor/faiss/faiss/index_factory.cpp +619 -397
- data/vendor/faiss/faiss/index_factory.h +8 -6
- data/vendor/faiss/faiss/index_io.h +23 -26
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
- data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
- data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
- data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
- data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
- data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
- data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
- data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
- data/vendor/faiss/faiss/utils/Heap.h +186 -209
- data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
- data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
- data/vendor/faiss/faiss/utils/distances.cpp +305 -312
- data/vendor/faiss/faiss/utils/distances.h +170 -122
- data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
- data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
- data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
- data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
- data/vendor/faiss/faiss/utils/hamming.h +62 -85
- data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
- data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
- data/vendor/faiss/faiss/utils/partitioning.h +26 -21
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
- data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
- data/vendor/faiss/faiss/utils/random.cpp +39 -63
- data/vendor/faiss/faiss/utils/random.h +13 -16
- data/vendor/faiss/faiss/utils/simdlib.h +4 -2
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
- data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
- data/vendor/faiss/faiss/utils/utils.cpp +304 -287
- data/vendor/faiss/faiss/utils/utils.h +54 -49
- metadata +29 -4
|
@@ -9,13 +9,15 @@
|
|
|
9
9
|
|
|
10
10
|
#include <faiss/utils/distances.h>
|
|
11
11
|
|
|
12
|
-
#include <
|
|
12
|
+
#include <algorithm>
|
|
13
13
|
#include <cassert>
|
|
14
|
-
#include <cstring>
|
|
15
14
|
#include <cmath>
|
|
15
|
+
#include <cstdio>
|
|
16
|
+
#include <cstring>
|
|
16
17
|
|
|
17
|
-
#include <faiss/utils/simdlib.h>
|
|
18
18
|
#include <faiss/impl/FaissAssert.h>
|
|
19
|
+
#include <faiss/impl/platform_macros.h>
|
|
20
|
+
#include <faiss/utils/simdlib.h>
|
|
19
21
|
|
|
20
22
|
#ifdef __SSE3__
|
|
21
23
|
#include <immintrin.h>
|
|
@@ -25,19 +27,16 @@
|
|
|
25
27
|
#include <arm_neon.h>
|
|
26
28
|
#endif
|
|
27
29
|
|
|
28
|
-
|
|
29
30
|
namespace faiss {
|
|
30
31
|
|
|
31
32
|
#ifdef __AVX__
|
|
32
33
|
#define USE_AVX
|
|
33
34
|
#endif
|
|
34
35
|
|
|
35
|
-
|
|
36
36
|
/*********************************************************
|
|
37
37
|
* Optimized distance computations
|
|
38
38
|
*********************************************************/
|
|
39
39
|
|
|
40
|
-
|
|
41
40
|
/* Functions to compute:
|
|
42
41
|
- L2 distance between 2 vectors
|
|
43
42
|
- inner product between 2 vectors
|
|
@@ -53,29 +52,21 @@ namespace faiss {
|
|
|
53
52
|
|
|
54
53
|
*/
|
|
55
54
|
|
|
56
|
-
|
|
57
55
|
/*********************************************************
|
|
58
56
|
* Reference implementations
|
|
59
57
|
*/
|
|
60
58
|
|
|
61
|
-
|
|
62
|
-
float fvec_L2sqr_ref (const float * x,
|
|
63
|
-
const float * y,
|
|
64
|
-
size_t d)
|
|
65
|
-
{
|
|
59
|
+
float fvec_L2sqr_ref(const float* x, const float* y, size_t d) {
|
|
66
60
|
size_t i;
|
|
67
61
|
float res = 0;
|
|
68
62
|
for (i = 0; i < d; i++) {
|
|
69
63
|
const float tmp = x[i] - y[i];
|
|
70
|
-
|
|
64
|
+
res += tmp * tmp;
|
|
71
65
|
}
|
|
72
66
|
return res;
|
|
73
67
|
}
|
|
74
68
|
|
|
75
|
-
float fvec_L1_ref
|
|
76
|
-
const float * y,
|
|
77
|
-
size_t d)
|
|
78
|
-
{
|
|
69
|
+
float fvec_L1_ref(const float* x, const float* y, size_t d) {
|
|
79
70
|
size_t i;
|
|
80
71
|
float res = 0;
|
|
81
72
|
for (i = 0; i < d; i++) {
|
|
@@ -85,56 +76,49 @@ float fvec_L1_ref (const float * x,
|
|
|
85
76
|
return res;
|
|
86
77
|
}
|
|
87
78
|
|
|
88
|
-
float fvec_Linf_ref
|
|
89
|
-
const float * y,
|
|
90
|
-
size_t d)
|
|
91
|
-
{
|
|
79
|
+
float fvec_Linf_ref(const float* x, const float* y, size_t d) {
|
|
92
80
|
size_t i;
|
|
93
81
|
float res = 0;
|
|
94
82
|
for (i = 0; i < d; i++) {
|
|
95
|
-
|
|
83
|
+
res = fmax(res, fabs(x[i] - y[i]));
|
|
96
84
|
}
|
|
97
85
|
return res;
|
|
98
86
|
}
|
|
99
87
|
|
|
100
|
-
float fvec_inner_product_ref
|
|
101
|
-
const float * y,
|
|
102
|
-
size_t d)
|
|
103
|
-
{
|
|
88
|
+
float fvec_inner_product_ref(const float* x, const float* y, size_t d) {
|
|
104
89
|
size_t i;
|
|
105
90
|
float res = 0;
|
|
106
91
|
for (i = 0; i < d; i++)
|
|
107
|
-
|
|
92
|
+
res += x[i] * y[i];
|
|
108
93
|
return res;
|
|
109
94
|
}
|
|
110
95
|
|
|
111
|
-
float fvec_norm_L2sqr_ref
|
|
112
|
-
{
|
|
96
|
+
float fvec_norm_L2sqr_ref(const float* x, size_t d) {
|
|
113
97
|
size_t i;
|
|
114
98
|
double res = 0;
|
|
115
99
|
for (i = 0; i < d; i++)
|
|
116
|
-
|
|
100
|
+
res += x[i] * x[i];
|
|
117
101
|
return res;
|
|
118
102
|
}
|
|
119
103
|
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
{
|
|
104
|
+
void fvec_L2sqr_ny_ref(
|
|
105
|
+
float* dis,
|
|
106
|
+
const float* x,
|
|
107
|
+
const float* y,
|
|
108
|
+
size_t d,
|
|
109
|
+
size_t ny) {
|
|
126
110
|
for (size_t i = 0; i < ny; i++) {
|
|
127
|
-
dis[i] = fvec_L2sqr
|
|
111
|
+
dis[i] = fvec_L2sqr(x, y, d);
|
|
128
112
|
y += d;
|
|
129
113
|
}
|
|
130
114
|
}
|
|
131
115
|
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
{
|
|
116
|
+
void fvec_inner_products_ny_ref(
|
|
117
|
+
float* ip,
|
|
118
|
+
const float* x,
|
|
119
|
+
const float* y,
|
|
120
|
+
size_t d,
|
|
121
|
+
size_t ny) {
|
|
138
122
|
// BLAS slower for the use cases here
|
|
139
123
|
#if 0
|
|
140
124
|
{
|
|
@@ -146,15 +130,11 @@ void fvec_inner_products_ny_ref (float * ip,
|
|
|
146
130
|
}
|
|
147
131
|
#endif
|
|
148
132
|
for (size_t i = 0; i < ny; i++) {
|
|
149
|
-
ip[i] = fvec_inner_product
|
|
133
|
+
ip[i] = fvec_inner_product(x, y, d);
|
|
150
134
|
y += d;
|
|
151
135
|
}
|
|
152
136
|
}
|
|
153
137
|
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
138
|
/*********************************************************
|
|
159
139
|
* SSE and AVX implementations
|
|
160
140
|
*/
|
|
@@ -162,40 +142,38 @@ void fvec_inner_products_ny_ref (float * ip,
|
|
|
162
142
|
#ifdef __SSE3__
|
|
163
143
|
|
|
164
144
|
// reads 0 <= d < 4 floats as __m128
|
|
165
|
-
static inline __m128 masked_read
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
__attribute__((__aligned__(16))) float buf[4] = {0, 0, 0, 0};
|
|
145
|
+
static inline __m128 masked_read(int d, const float* x) {
|
|
146
|
+
assert(0 <= d && d < 4);
|
|
147
|
+
ALIGNED(16) float buf[4] = {0, 0, 0, 0};
|
|
169
148
|
switch (d) {
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
149
|
+
case 3:
|
|
150
|
+
buf[2] = x[2];
|
|
151
|
+
case 2:
|
|
152
|
+
buf[1] = x[1];
|
|
153
|
+
case 1:
|
|
154
|
+
buf[0] = x[0];
|
|
176
155
|
}
|
|
177
|
-
return _mm_load_ps
|
|
156
|
+
return _mm_load_ps(buf);
|
|
178
157
|
// cannot use AVX2 _mm_mask_set1_epi32
|
|
179
158
|
}
|
|
180
159
|
|
|
181
|
-
float fvec_norm_L2sqr
|
|
182
|
-
size_t d)
|
|
183
|
-
{
|
|
160
|
+
float fvec_norm_L2sqr(const float* x, size_t d) {
|
|
184
161
|
__m128 mx;
|
|
185
162
|
__m128 msum1 = _mm_setzero_ps();
|
|
186
163
|
|
|
187
164
|
while (d >= 4) {
|
|
188
|
-
mx = _mm_loadu_ps
|
|
189
|
-
|
|
165
|
+
mx = _mm_loadu_ps(x);
|
|
166
|
+
x += 4;
|
|
167
|
+
msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, mx));
|
|
190
168
|
d -= 4;
|
|
191
169
|
}
|
|
192
170
|
|
|
193
|
-
mx = masked_read
|
|
194
|
-
msum1 = _mm_add_ps
|
|
171
|
+
mx = masked_read(d, x);
|
|
172
|
+
msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, mx));
|
|
195
173
|
|
|
196
|
-
msum1 = _mm_hadd_ps
|
|
197
|
-
msum1 = _mm_hadd_ps
|
|
198
|
-
return
|
|
174
|
+
msum1 = _mm_hadd_ps(msum1, msum1);
|
|
175
|
+
msum1 = _mm_hadd_ps(msum1, msum1);
|
|
176
|
+
return _mm_cvtss_f32(msum1);
|
|
199
177
|
}
|
|
200
178
|
|
|
201
179
|
namespace {
|
|
@@ -204,586 +182,588 @@ namespace {
|
|
|
204
182
|
/// to compute L2 distances. ElementOp can then be used in the fvec_op_ny
|
|
205
183
|
/// functions below
|
|
206
184
|
struct ElementOpL2 {
|
|
207
|
-
|
|
208
|
-
static float op (float x, float y) {
|
|
185
|
+
static float op(float x, float y) {
|
|
209
186
|
float tmp = x - y;
|
|
210
187
|
return tmp * tmp;
|
|
211
188
|
}
|
|
212
189
|
|
|
213
|
-
static __m128 op
|
|
214
|
-
__m128 tmp = x
|
|
215
|
-
return tmp
|
|
190
|
+
static __m128 op(__m128 x, __m128 y) {
|
|
191
|
+
__m128 tmp = _mm_sub_ps(x, y);
|
|
192
|
+
return _mm_mul_ps(tmp, tmp);
|
|
216
193
|
}
|
|
217
|
-
|
|
218
194
|
};
|
|
219
195
|
|
|
220
196
|
/// Function that does a component-wise operation between x and y
|
|
221
197
|
/// to compute inner products
|
|
222
198
|
struct ElementOpIP {
|
|
223
|
-
|
|
224
|
-
static float op (float x, float y) {
|
|
199
|
+
static float op(float x, float y) {
|
|
225
200
|
return x * y;
|
|
226
201
|
}
|
|
227
202
|
|
|
228
|
-
static __m128 op
|
|
229
|
-
return x
|
|
203
|
+
static __m128 op(__m128 x, __m128 y) {
|
|
204
|
+
return _mm_mul_ps(x, y);
|
|
230
205
|
}
|
|
231
|
-
|
|
232
206
|
};
|
|
233
207
|
|
|
234
|
-
template<class ElementOp>
|
|
235
|
-
void fvec_op_ny_D1
|
|
236
|
-
const float * y, size_t ny)
|
|
237
|
-
{
|
|
208
|
+
template <class ElementOp>
|
|
209
|
+
void fvec_op_ny_D1(float* dis, const float* x, const float* y, size_t ny) {
|
|
238
210
|
float x0s = x[0];
|
|
239
|
-
__m128 x0 = _mm_set_ps
|
|
211
|
+
__m128 x0 = _mm_set_ps(x0s, x0s, x0s, x0s);
|
|
240
212
|
|
|
241
213
|
size_t i;
|
|
242
214
|
for (i = 0; i + 3 < ny; i += 4) {
|
|
243
|
-
__m128 accu = ElementOp::op(x0, _mm_loadu_ps
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
215
|
+
__m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
|
|
216
|
+
y += 4;
|
|
217
|
+
dis[i] = _mm_cvtss_f32(accu);
|
|
218
|
+
__m128 tmp = _mm_shuffle_ps(accu, accu, 1);
|
|
219
|
+
dis[i + 1] = _mm_cvtss_f32(tmp);
|
|
220
|
+
tmp = _mm_shuffle_ps(accu, accu, 2);
|
|
221
|
+
dis[i + 2] = _mm_cvtss_f32(tmp);
|
|
222
|
+
tmp = _mm_shuffle_ps(accu, accu, 3);
|
|
223
|
+
dis[i + 3] = _mm_cvtss_f32(tmp);
|
|
251
224
|
}
|
|
252
225
|
while (i < ny) { // handle non-multiple-of-4 case
|
|
253
226
|
dis[i++] = ElementOp::op(x0s, *y++);
|
|
254
227
|
}
|
|
255
228
|
}
|
|
256
229
|
|
|
257
|
-
template<class ElementOp>
|
|
258
|
-
void fvec_op_ny_D2
|
|
259
|
-
|
|
260
|
-
{
|
|
261
|
-
__m128 x0 = _mm_set_ps (x[1], x[0], x[1], x[0]);
|
|
230
|
+
template <class ElementOp>
|
|
231
|
+
void fvec_op_ny_D2(float* dis, const float* x, const float* y, size_t ny) {
|
|
232
|
+
__m128 x0 = _mm_set_ps(x[1], x[0], x[1], x[0]);
|
|
262
233
|
|
|
263
234
|
size_t i;
|
|
264
235
|
for (i = 0; i + 1 < ny; i += 2) {
|
|
265
|
-
__m128 accu = ElementOp::op(x0, _mm_loadu_ps
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
236
|
+
__m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
|
|
237
|
+
y += 4;
|
|
238
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
239
|
+
dis[i] = _mm_cvtss_f32(accu);
|
|
240
|
+
accu = _mm_shuffle_ps(accu, accu, 3);
|
|
241
|
+
dis[i + 1] = _mm_cvtss_f32(accu);
|
|
270
242
|
}
|
|
271
243
|
if (i < ny) { // handle odd case
|
|
272
244
|
dis[i] = ElementOp::op(x[0], y[0]) + ElementOp::op(x[1], y[1]);
|
|
273
245
|
}
|
|
274
246
|
}
|
|
275
247
|
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
template<class ElementOp>
|
|
279
|
-
void fvec_op_ny_D4 (float * dis, const float * x,
|
|
280
|
-
const float * y, size_t ny)
|
|
281
|
-
{
|
|
248
|
+
template <class ElementOp>
|
|
249
|
+
void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
|
|
282
250
|
__m128 x0 = _mm_loadu_ps(x);
|
|
283
251
|
|
|
284
252
|
for (size_t i = 0; i < ny; i++) {
|
|
285
|
-
__m128 accu = ElementOp::op(x0, _mm_loadu_ps
|
|
286
|
-
|
|
287
|
-
accu = _mm_hadd_ps
|
|
288
|
-
|
|
253
|
+
__m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
|
|
254
|
+
y += 4;
|
|
255
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
256
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
257
|
+
dis[i] = _mm_cvtss_f32(accu);
|
|
289
258
|
}
|
|
290
259
|
}
|
|
291
260
|
|
|
292
|
-
template<class ElementOp>
|
|
293
|
-
void fvec_op_ny_D8
|
|
294
|
-
const float * y, size_t ny)
|
|
295
|
-
{
|
|
261
|
+
template <class ElementOp>
|
|
262
|
+
void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) {
|
|
296
263
|
__m128 x0 = _mm_loadu_ps(x);
|
|
297
264
|
__m128 x1 = _mm_loadu_ps(x + 4);
|
|
298
265
|
|
|
299
266
|
for (size_t i = 0; i < ny; i++) {
|
|
300
|
-
__m128 accu = ElementOp::op(x0, _mm_loadu_ps
|
|
301
|
-
|
|
302
|
-
accu =
|
|
303
|
-
|
|
304
|
-
|
|
267
|
+
__m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
|
|
268
|
+
y += 4;
|
|
269
|
+
accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y)));
|
|
270
|
+
y += 4;
|
|
271
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
272
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
273
|
+
dis[i] = _mm_cvtss_f32(accu);
|
|
305
274
|
}
|
|
306
275
|
}
|
|
307
276
|
|
|
308
|
-
template<class ElementOp>
|
|
309
|
-
void fvec_op_ny_D12
|
|
310
|
-
const float * y, size_t ny)
|
|
311
|
-
{
|
|
277
|
+
template <class ElementOp>
|
|
278
|
+
void fvec_op_ny_D12(float* dis, const float* x, const float* y, size_t ny) {
|
|
312
279
|
__m128 x0 = _mm_loadu_ps(x);
|
|
313
280
|
__m128 x1 = _mm_loadu_ps(x + 4);
|
|
314
281
|
__m128 x2 = _mm_loadu_ps(x + 8);
|
|
315
282
|
|
|
316
283
|
for (size_t i = 0; i < ny; i++) {
|
|
317
|
-
__m128 accu = ElementOp::op(x0, _mm_loadu_ps
|
|
318
|
-
|
|
319
|
-
accu
|
|
320
|
-
|
|
321
|
-
accu =
|
|
322
|
-
|
|
284
|
+
__m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
|
|
285
|
+
y += 4;
|
|
286
|
+
accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y)));
|
|
287
|
+
y += 4;
|
|
288
|
+
accu = _mm_add_ps(accu, ElementOp::op(x2, _mm_loadu_ps(y)));
|
|
289
|
+
y += 4;
|
|
290
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
291
|
+
accu = _mm_hadd_ps(accu, accu);
|
|
292
|
+
dis[i] = _mm_cvtss_f32(accu);
|
|
323
293
|
}
|
|
324
294
|
}
|
|
325
295
|
|
|
326
|
-
|
|
327
|
-
|
|
328
296
|
} // anonymous namespace
|
|
329
297
|
|
|
330
|
-
void fvec_L2sqr_ny
|
|
331
|
-
|
|
298
|
+
void fvec_L2sqr_ny(
|
|
299
|
+
float* dis,
|
|
300
|
+
const float* x,
|
|
301
|
+
const float* y,
|
|
302
|
+
size_t d,
|
|
303
|
+
size_t ny) {
|
|
332
304
|
// optimized for a few special cases
|
|
333
305
|
|
|
334
|
-
#define DISPATCH(dval)
|
|
335
|
-
case dval
|
|
336
|
-
fvec_op_ny_D
|
|
306
|
+
#define DISPATCH(dval) \
|
|
307
|
+
case dval: \
|
|
308
|
+
fvec_op_ny_D##dval<ElementOpL2>(dis, x, y, ny); \
|
|
337
309
|
return;
|
|
338
310
|
|
|
339
|
-
switch(d) {
|
|
311
|
+
switch (d) {
|
|
340
312
|
DISPATCH(1)
|
|
341
313
|
DISPATCH(2)
|
|
342
314
|
DISPATCH(4)
|
|
343
315
|
DISPATCH(8)
|
|
344
316
|
DISPATCH(12)
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
317
|
+
default:
|
|
318
|
+
fvec_L2sqr_ny_ref(dis, x, y, d, ny);
|
|
319
|
+
return;
|
|
348
320
|
}
|
|
349
321
|
#undef DISPATCH
|
|
350
|
-
|
|
351
322
|
}
|
|
352
323
|
|
|
353
|
-
void fvec_inner_products_ny
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
324
|
+
void fvec_inner_products_ny(
|
|
325
|
+
float* dis,
|
|
326
|
+
const float* x,
|
|
327
|
+
const float* y,
|
|
328
|
+
size_t d,
|
|
329
|
+
size_t ny) {
|
|
330
|
+
#define DISPATCH(dval) \
|
|
331
|
+
case dval: \
|
|
332
|
+
fvec_op_ny_D##dval<ElementOpIP>(dis, x, y, ny); \
|
|
359
333
|
return;
|
|
360
334
|
|
|
361
|
-
switch(d) {
|
|
335
|
+
switch (d) {
|
|
362
336
|
DISPATCH(1)
|
|
363
337
|
DISPATCH(2)
|
|
364
338
|
DISPATCH(4)
|
|
365
339
|
DISPATCH(8)
|
|
366
340
|
DISPATCH(12)
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
341
|
+
default:
|
|
342
|
+
fvec_inner_products_ny_ref(dis, x, y, d, ny);
|
|
343
|
+
return;
|
|
370
344
|
}
|
|
371
345
|
#undef DISPATCH
|
|
372
|
-
|
|
373
346
|
}
|
|
374
347
|
|
|
375
|
-
|
|
376
|
-
|
|
377
348
|
#endif
|
|
378
349
|
|
|
379
350
|
#ifdef USE_AVX
|
|
380
351
|
|
|
381
352
|
// reads 0 <= d < 8 floats as __m256
|
|
382
|
-
static inline __m256 masked_read_8
|
|
383
|
-
|
|
384
|
-
assert (0 <= d && d < 8);
|
|
353
|
+
static inline __m256 masked_read_8(int d, const float* x) {
|
|
354
|
+
assert(0 <= d && d < 8);
|
|
385
355
|
if (d < 4) {
|
|
386
|
-
__m256 res = _mm256_setzero_ps
|
|
387
|
-
res = _mm256_insertf128_ps
|
|
356
|
+
__m256 res = _mm256_setzero_ps();
|
|
357
|
+
res = _mm256_insertf128_ps(res, masked_read(d, x), 0);
|
|
388
358
|
return res;
|
|
389
359
|
} else {
|
|
390
|
-
__m256 res = _mm256_setzero_ps
|
|
391
|
-
res = _mm256_insertf128_ps
|
|
392
|
-
res = _mm256_insertf128_ps
|
|
360
|
+
__m256 res = _mm256_setzero_ps();
|
|
361
|
+
res = _mm256_insertf128_ps(res, _mm_loadu_ps(x), 0);
|
|
362
|
+
res = _mm256_insertf128_ps(res, masked_read(d - 4, x + 4), 1);
|
|
393
363
|
return res;
|
|
394
364
|
}
|
|
395
365
|
}
|
|
396
366
|
|
|
397
|
-
float fvec_inner_product
|
|
398
|
-
const float * y,
|
|
399
|
-
size_t d)
|
|
400
|
-
{
|
|
367
|
+
float fvec_inner_product(const float* x, const float* y, size_t d) {
|
|
401
368
|
__m256 msum1 = _mm256_setzero_ps();
|
|
402
369
|
|
|
403
370
|
while (d >= 8) {
|
|
404
|
-
__m256 mx = _mm256_loadu_ps
|
|
405
|
-
|
|
406
|
-
|
|
371
|
+
__m256 mx = _mm256_loadu_ps(x);
|
|
372
|
+
x += 8;
|
|
373
|
+
__m256 my = _mm256_loadu_ps(y);
|
|
374
|
+
y += 8;
|
|
375
|
+
msum1 = _mm256_add_ps(msum1, _mm256_mul_ps(mx, my));
|
|
407
376
|
d -= 8;
|
|
408
377
|
}
|
|
409
378
|
|
|
410
379
|
__m128 msum2 = _mm256_extractf128_ps(msum1, 1);
|
|
411
|
-
msum2
|
|
380
|
+
msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));
|
|
412
381
|
|
|
413
382
|
if (d >= 4) {
|
|
414
|
-
__m128 mx = _mm_loadu_ps
|
|
415
|
-
|
|
416
|
-
|
|
383
|
+
__m128 mx = _mm_loadu_ps(x);
|
|
384
|
+
x += 4;
|
|
385
|
+
__m128 my = _mm_loadu_ps(y);
|
|
386
|
+
y += 4;
|
|
387
|
+
msum2 = _mm_add_ps(msum2, _mm_mul_ps(mx, my));
|
|
417
388
|
d -= 4;
|
|
418
389
|
}
|
|
419
390
|
|
|
420
391
|
if (d > 0) {
|
|
421
|
-
__m128 mx = masked_read
|
|
422
|
-
__m128 my = masked_read
|
|
423
|
-
msum2 = _mm_add_ps
|
|
392
|
+
__m128 mx = masked_read(d, x);
|
|
393
|
+
__m128 my = masked_read(d, y);
|
|
394
|
+
msum2 = _mm_add_ps(msum2, _mm_mul_ps(mx, my));
|
|
424
395
|
}
|
|
425
396
|
|
|
426
|
-
msum2 = _mm_hadd_ps
|
|
427
|
-
msum2 = _mm_hadd_ps
|
|
428
|
-
return
|
|
397
|
+
msum2 = _mm_hadd_ps(msum2, msum2);
|
|
398
|
+
msum2 = _mm_hadd_ps(msum2, msum2);
|
|
399
|
+
return _mm_cvtss_f32(msum2);
|
|
429
400
|
}
|
|
430
401
|
|
|
431
|
-
float fvec_L2sqr
|
|
432
|
-
const float * y,
|
|
433
|
-
size_t d)
|
|
434
|
-
{
|
|
402
|
+
float fvec_L2sqr(const float* x, const float* y, size_t d) {
|
|
435
403
|
__m256 msum1 = _mm256_setzero_ps();
|
|
436
404
|
|
|
437
405
|
while (d >= 8) {
|
|
438
|
-
__m256 mx = _mm256_loadu_ps
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
406
|
+
__m256 mx = _mm256_loadu_ps(x);
|
|
407
|
+
x += 8;
|
|
408
|
+
__m256 my = _mm256_loadu_ps(y);
|
|
409
|
+
y += 8;
|
|
410
|
+
const __m256 a_m_b1 = _mm256_sub_ps(mx, my);
|
|
411
|
+
msum1 = _mm256_add_ps(msum1, _mm256_mul_ps(a_m_b1, a_m_b1));
|
|
442
412
|
d -= 8;
|
|
443
413
|
}
|
|
444
414
|
|
|
445
415
|
__m128 msum2 = _mm256_extractf128_ps(msum1, 1);
|
|
446
|
-
msum2
|
|
416
|
+
msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));
|
|
447
417
|
|
|
448
418
|
if (d >= 4) {
|
|
449
|
-
__m128 mx = _mm_loadu_ps
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
419
|
+
__m128 mx = _mm_loadu_ps(x);
|
|
420
|
+
x += 4;
|
|
421
|
+
__m128 my = _mm_loadu_ps(y);
|
|
422
|
+
y += 4;
|
|
423
|
+
const __m128 a_m_b1 = _mm_sub_ps(mx, my);
|
|
424
|
+
msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1));
|
|
453
425
|
d -= 4;
|
|
454
426
|
}
|
|
455
427
|
|
|
456
428
|
if (d > 0) {
|
|
457
|
-
__m128 mx = masked_read
|
|
458
|
-
__m128 my = masked_read
|
|
459
|
-
__m128 a_m_b1 = mx
|
|
460
|
-
msum2
|
|
429
|
+
__m128 mx = masked_read(d, x);
|
|
430
|
+
__m128 my = masked_read(d, y);
|
|
431
|
+
__m128 a_m_b1 = _mm_sub_ps(mx, my);
|
|
432
|
+
msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1));
|
|
461
433
|
}
|
|
462
434
|
|
|
463
|
-
msum2 = _mm_hadd_ps
|
|
464
|
-
msum2 = _mm_hadd_ps
|
|
465
|
-
return
|
|
435
|
+
msum2 = _mm_hadd_ps(msum2, msum2);
|
|
436
|
+
msum2 = _mm_hadd_ps(msum2, msum2);
|
|
437
|
+
return _mm_cvtss_f32(msum2);
|
|
466
438
|
}
|
|
467
439
|
|
|
468
|
-
float fvec_L1
|
|
469
|
-
{
|
|
440
|
+
float fvec_L1(const float* x, const float* y, size_t d) {
|
|
470
441
|
__m256 msum1 = _mm256_setzero_ps();
|
|
471
|
-
__m256 signmask =
|
|
442
|
+
__m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL));
|
|
472
443
|
|
|
473
444
|
while (d >= 8) {
|
|
474
|
-
__m256 mx = _mm256_loadu_ps
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
445
|
+
__m256 mx = _mm256_loadu_ps(x);
|
|
446
|
+
x += 8;
|
|
447
|
+
__m256 my = _mm256_loadu_ps(y);
|
|
448
|
+
y += 8;
|
|
449
|
+
const __m256 a_m_b = _mm256_sub_ps(mx, my);
|
|
450
|
+
msum1 = _mm256_add_ps(msum1, _mm256_and_ps(signmask, a_m_b));
|
|
478
451
|
d -= 8;
|
|
479
452
|
}
|
|
480
453
|
|
|
481
454
|
__m128 msum2 = _mm256_extractf128_ps(msum1, 1);
|
|
482
|
-
msum2
|
|
483
|
-
__m128 signmask2 =
|
|
455
|
+
msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));
|
|
456
|
+
__m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL));
|
|
484
457
|
|
|
485
458
|
if (d >= 4) {
|
|
486
|
-
__m128 mx = _mm_loadu_ps
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
459
|
+
__m128 mx = _mm_loadu_ps(x);
|
|
460
|
+
x += 4;
|
|
461
|
+
__m128 my = _mm_loadu_ps(y);
|
|
462
|
+
y += 4;
|
|
463
|
+
const __m128 a_m_b = _mm_sub_ps(mx, my);
|
|
464
|
+
msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b));
|
|
490
465
|
d -= 4;
|
|
491
466
|
}
|
|
492
467
|
|
|
493
468
|
if (d > 0) {
|
|
494
|
-
__m128 mx = masked_read
|
|
495
|
-
__m128 my = masked_read
|
|
496
|
-
__m128 a_m_b = mx
|
|
497
|
-
msum2
|
|
469
|
+
__m128 mx = masked_read(d, x);
|
|
470
|
+
__m128 my = masked_read(d, y);
|
|
471
|
+
__m128 a_m_b = _mm_sub_ps(mx, my);
|
|
472
|
+
msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b));
|
|
498
473
|
}
|
|
499
474
|
|
|
500
|
-
msum2 = _mm_hadd_ps
|
|
501
|
-
msum2 = _mm_hadd_ps
|
|
502
|
-
return
|
|
475
|
+
msum2 = _mm_hadd_ps(msum2, msum2);
|
|
476
|
+
msum2 = _mm_hadd_ps(msum2, msum2);
|
|
477
|
+
return _mm_cvtss_f32(msum2);
|
|
503
478
|
}
|
|
504
479
|
|
|
505
|
-
float fvec_Linf
|
|
506
|
-
{
|
|
480
|
+
float fvec_Linf(const float* x, const float* y, size_t d) {
|
|
507
481
|
__m256 msum1 = _mm256_setzero_ps();
|
|
508
|
-
__m256 signmask =
|
|
482
|
+
__m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL));
|
|
509
483
|
|
|
510
484
|
while (d >= 8) {
|
|
511
|
-
__m256 mx = _mm256_loadu_ps
|
|
512
|
-
|
|
513
|
-
|
|
485
|
+
__m256 mx = _mm256_loadu_ps(x);
|
|
486
|
+
x += 8;
|
|
487
|
+
__m256 my = _mm256_loadu_ps(y);
|
|
488
|
+
y += 8;
|
|
489
|
+
const __m256 a_m_b = _mm256_sub_ps(mx, my);
|
|
514
490
|
msum1 = _mm256_max_ps(msum1, _mm256_and_ps(signmask, a_m_b));
|
|
515
491
|
d -= 8;
|
|
516
492
|
}
|
|
517
493
|
|
|
518
494
|
__m128 msum2 = _mm256_extractf128_ps(msum1, 1);
|
|
519
|
-
msum2 = _mm_max_ps
|
|
520
|
-
__m128 signmask2 =
|
|
495
|
+
msum2 = _mm_max_ps(msum2, _mm256_extractf128_ps(msum1, 0));
|
|
496
|
+
__m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL));
|
|
521
497
|
|
|
522
498
|
if (d >= 4) {
|
|
523
|
-
__m128 mx = _mm_loadu_ps
|
|
524
|
-
|
|
525
|
-
|
|
499
|
+
__m128 mx = _mm_loadu_ps(x);
|
|
500
|
+
x += 4;
|
|
501
|
+
__m128 my = _mm_loadu_ps(y);
|
|
502
|
+
y += 4;
|
|
503
|
+
const __m128 a_m_b = _mm_sub_ps(mx, my);
|
|
526
504
|
msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b));
|
|
527
505
|
d -= 4;
|
|
528
506
|
}
|
|
529
507
|
|
|
530
508
|
if (d > 0) {
|
|
531
|
-
__m128 mx = masked_read
|
|
532
|
-
__m128 my = masked_read
|
|
533
|
-
__m128 a_m_b = mx
|
|
509
|
+
__m128 mx = masked_read(d, x);
|
|
510
|
+
__m128 my = masked_read(d, y);
|
|
511
|
+
__m128 a_m_b = _mm_sub_ps(mx, my);
|
|
534
512
|
msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b));
|
|
535
513
|
}
|
|
536
514
|
|
|
537
515
|
msum2 = _mm_max_ps(_mm_movehl_ps(msum2, msum2), msum2);
|
|
538
|
-
msum2 = _mm_max_ps(msum2, _mm_shuffle_ps
|
|
539
|
-
return
|
|
516
|
+
msum2 = _mm_max_ps(msum2, _mm_shuffle_ps(msum2, msum2, 1));
|
|
517
|
+
return _mm_cvtss_f32(msum2);
|
|
540
518
|
}
|
|
541
519
|
|
|
542
520
|
#elif defined(__SSE3__) // But not AVX
|
|
543
521
|
|
|
544
|
-
float fvec_L1
|
|
545
|
-
|
|
546
|
-
return fvec_L1_ref (x, y, d);
|
|
522
|
+
float fvec_L1(const float* x, const float* y, size_t d) {
|
|
523
|
+
return fvec_L1_ref(x, y, d);
|
|
547
524
|
}
|
|
548
525
|
|
|
549
|
-
float fvec_Linf
|
|
550
|
-
|
|
551
|
-
return fvec_Linf_ref (x, y, d);
|
|
526
|
+
float fvec_Linf(const float* x, const float* y, size_t d) {
|
|
527
|
+
return fvec_Linf_ref(x, y, d);
|
|
552
528
|
}
|
|
553
529
|
|
|
554
|
-
|
|
555
|
-
float fvec_L2sqr (const float * x,
|
|
556
|
-
const float * y,
|
|
557
|
-
size_t d)
|
|
558
|
-
{
|
|
530
|
+
float fvec_L2sqr(const float* x, const float* y, size_t d) {
|
|
559
531
|
__m128 msum1 = _mm_setzero_ps();
|
|
560
532
|
|
|
561
533
|
while (d >= 4) {
|
|
562
|
-
__m128 mx = _mm_loadu_ps
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
534
|
+
__m128 mx = _mm_loadu_ps(x);
|
|
535
|
+
x += 4;
|
|
536
|
+
__m128 my = _mm_loadu_ps(y);
|
|
537
|
+
y += 4;
|
|
538
|
+
const __m128 a_m_b1 = _mm_sub_ps(mx, my);
|
|
539
|
+
msum1 = _mm_add_ps(msum1, _mm_mul_ps(a_m_b1, a_m_b1));
|
|
566
540
|
d -= 4;
|
|
567
541
|
}
|
|
568
542
|
|
|
569
543
|
if (d > 0) {
|
|
570
544
|
// add the last 1, 2 or 3 values
|
|
571
|
-
__m128 mx = masked_read
|
|
572
|
-
__m128 my = masked_read
|
|
573
|
-
__m128 a_m_b1 = mx
|
|
574
|
-
msum1
|
|
545
|
+
__m128 mx = masked_read(d, x);
|
|
546
|
+
__m128 my = masked_read(d, y);
|
|
547
|
+
__m128 a_m_b1 = _mm_sub_ps(mx, my);
|
|
548
|
+
msum1 = _mm_add_ps(msum1, _mm_mul_ps(a_m_b1, a_m_b1));
|
|
575
549
|
}
|
|
576
550
|
|
|
577
|
-
msum1 = _mm_hadd_ps
|
|
578
|
-
msum1 = _mm_hadd_ps
|
|
579
|
-
return
|
|
551
|
+
msum1 = _mm_hadd_ps(msum1, msum1);
|
|
552
|
+
msum1 = _mm_hadd_ps(msum1, msum1);
|
|
553
|
+
return _mm_cvtss_f32(msum1);
|
|
580
554
|
}
|
|
581
555
|
|
|
582
|
-
|
|
583
|
-
float fvec_inner_product (const float * x,
|
|
584
|
-
const float * y,
|
|
585
|
-
size_t d)
|
|
586
|
-
{
|
|
556
|
+
float fvec_inner_product(const float* x, const float* y, size_t d) {
|
|
587
557
|
__m128 mx, my;
|
|
588
558
|
__m128 msum1 = _mm_setzero_ps();
|
|
589
559
|
|
|
590
560
|
while (d >= 4) {
|
|
591
|
-
mx = _mm_loadu_ps
|
|
592
|
-
|
|
593
|
-
|
|
561
|
+
mx = _mm_loadu_ps(x);
|
|
562
|
+
x += 4;
|
|
563
|
+
my = _mm_loadu_ps(y);
|
|
564
|
+
y += 4;
|
|
565
|
+
msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, my));
|
|
594
566
|
d -= 4;
|
|
595
567
|
}
|
|
596
568
|
|
|
597
569
|
// add the last 1, 2, or 3 values
|
|
598
|
-
mx = masked_read
|
|
599
|
-
my = masked_read
|
|
600
|
-
__m128 prod = _mm_mul_ps
|
|
570
|
+
mx = masked_read(d, x);
|
|
571
|
+
my = masked_read(d, y);
|
|
572
|
+
__m128 prod = _mm_mul_ps(mx, my);
|
|
601
573
|
|
|
602
|
-
msum1 = _mm_add_ps
|
|
574
|
+
msum1 = _mm_add_ps(msum1, prod);
|
|
603
575
|
|
|
604
|
-
msum1 = _mm_hadd_ps
|
|
605
|
-
msum1 = _mm_hadd_ps
|
|
606
|
-
return
|
|
576
|
+
msum1 = _mm_hadd_ps(msum1, msum1);
|
|
577
|
+
msum1 = _mm_hadd_ps(msum1, msum1);
|
|
578
|
+
return _mm_cvtss_f32(msum1);
|
|
607
579
|
}
|
|
608
580
|
|
|
609
581
|
#elif defined(__aarch64__)
|
|
610
582
|
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
{
|
|
616
|
-
|
|
617
|
-
|
|
618
|
-
|
|
619
|
-
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
583
|
+
float fvec_L2sqr(const float* x, const float* y, size_t d) {
|
|
584
|
+
float32x4_t accux4 = vdupq_n_f32(0);
|
|
585
|
+
const size_t d_simd = d - (d & 3);
|
|
586
|
+
size_t i;
|
|
587
|
+
for (i = 0; i < d_simd; i += 4) {
|
|
588
|
+
float32x4_t xi = vld1q_f32(x + i);
|
|
589
|
+
float32x4_t yi = vld1q_f32(y + i);
|
|
590
|
+
float32x4_t sq = vsubq_f32(xi, yi);
|
|
591
|
+
accux4 = vfmaq_f32(accux4, sq, sq);
|
|
592
|
+
}
|
|
593
|
+
float32x4_t accux2 = vpaddq_f32(accux4, accux4);
|
|
594
|
+
float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
|
|
595
|
+
for (; i < d; ++i) {
|
|
596
|
+
float32_t xi = x[i];
|
|
597
|
+
float32_t yi = y[i];
|
|
598
|
+
float32_t sq = xi - yi;
|
|
599
|
+
accux1 += sq * sq;
|
|
623
600
|
}
|
|
624
|
-
|
|
625
|
-
return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
|
|
601
|
+
return accux1;
|
|
626
602
|
}
|
|
627
603
|
|
|
628
|
-
float fvec_inner_product
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
float32x4_t yi = vld1q_f32 (y + i);
|
|
637
|
-
accu = vfmaq_f32 (accu, xi, yi);
|
|
604
|
+
float fvec_inner_product(const float* x, const float* y, size_t d) {
|
|
605
|
+
float32x4_t accux4 = vdupq_n_f32(0);
|
|
606
|
+
const size_t d_simd = d - (d & 3);
|
|
607
|
+
size_t i;
|
|
608
|
+
for (i = 0; i < d_simd; i += 4) {
|
|
609
|
+
float32x4_t xi = vld1q_f32(x + i);
|
|
610
|
+
float32x4_t yi = vld1q_f32(y + i);
|
|
611
|
+
accux4 = vfmaq_f32(accux4, xi, yi);
|
|
638
612
|
}
|
|
639
|
-
float32x4_t
|
|
640
|
-
|
|
613
|
+
float32x4_t accux2 = vpaddq_f32(accux4, accux4);
|
|
614
|
+
float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
|
|
615
|
+
for (; i < d; ++i) {
|
|
616
|
+
float32_t xi = x[i];
|
|
617
|
+
float32_t yi = y[i];
|
|
618
|
+
accux1 += xi * yi;
|
|
619
|
+
}
|
|
620
|
+
return accux1;
|
|
641
621
|
}
|
|
642
622
|
|
|
643
|
-
float fvec_norm_L2sqr
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
for (
|
|
648
|
-
float32x4_t xi = vld1q_f32
|
|
649
|
-
|
|
623
|
+
float fvec_norm_L2sqr(const float* x, size_t d) {
|
|
624
|
+
float32x4_t accux4 = vdupq_n_f32(0);
|
|
625
|
+
const size_t d_simd = d - (d & 3);
|
|
626
|
+
size_t i;
|
|
627
|
+
for (i = 0; i < d_simd; i += 4) {
|
|
628
|
+
float32x4_t xi = vld1q_f32(x + i);
|
|
629
|
+
accux4 = vfmaq_f32(accux4, xi, xi);
|
|
630
|
+
}
|
|
631
|
+
float32x4_t accux2 = vpaddq_f32(accux4, accux4);
|
|
632
|
+
float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
|
|
633
|
+
for (; i < d; ++i) {
|
|
634
|
+
float32_t xi = x[i];
|
|
635
|
+
accux1 += xi * xi;
|
|
650
636
|
}
|
|
651
|
-
|
|
652
|
-
return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
|
|
637
|
+
return accux1;
|
|
653
638
|
}
|
|
654
639
|
|
|
655
640
|
// not optimized for ARM
|
|
656
|
-
void fvec_L2sqr_ny
|
|
657
|
-
|
|
658
|
-
|
|
641
|
+
void fvec_L2sqr_ny(
|
|
642
|
+
float* dis,
|
|
643
|
+
const float* x,
|
|
644
|
+
const float* y,
|
|
645
|
+
size_t d,
|
|
646
|
+
size_t ny) {
|
|
647
|
+
fvec_L2sqr_ny_ref(dis, x, y, d, ny);
|
|
659
648
|
}
|
|
660
649
|
|
|
661
|
-
float fvec_L1
|
|
662
|
-
|
|
663
|
-
return fvec_L1_ref (x, y, d);
|
|
650
|
+
float fvec_L1(const float* x, const float* y, size_t d) {
|
|
651
|
+
return fvec_L1_ref(x, y, d);
|
|
664
652
|
}
|
|
665
653
|
|
|
666
|
-
float fvec_Linf
|
|
667
|
-
|
|
668
|
-
return fvec_Linf_ref (x, y, d);
|
|
654
|
+
float fvec_Linf(const float* x, const float* y, size_t d) {
|
|
655
|
+
return fvec_Linf_ref(x, y, d);
|
|
669
656
|
}
|
|
670
657
|
|
|
658
|
+
void fvec_inner_products_ny(
|
|
659
|
+
float* dis,
|
|
660
|
+
const float* x,
|
|
661
|
+
const float* y,
|
|
662
|
+
size_t d,
|
|
663
|
+
size_t ny) {
|
|
664
|
+
fvec_inner_products_ny_ref(dis, x, y, d, ny);
|
|
665
|
+
}
|
|
671
666
|
|
|
672
667
|
#else
|
|
673
668
|
// scalar implementation
|
|
674
669
|
|
|
675
|
-
float fvec_L2sqr
|
|
676
|
-
|
|
677
|
-
size_t d)
|
|
678
|
-
{
|
|
679
|
-
return fvec_L2sqr_ref (x, y, d);
|
|
670
|
+
float fvec_L2sqr(const float* x, const float* y, size_t d) {
|
|
671
|
+
return fvec_L2sqr_ref(x, y, d);
|
|
680
672
|
}
|
|
681
673
|
|
|
682
|
-
float fvec_L1
|
|
683
|
-
|
|
684
|
-
return fvec_L1_ref (x, y, d);
|
|
674
|
+
float fvec_L1(const float* x, const float* y, size_t d) {
|
|
675
|
+
return fvec_L1_ref(x, y, d);
|
|
685
676
|
}
|
|
686
677
|
|
|
687
|
-
float fvec_Linf
|
|
688
|
-
|
|
689
|
-
return fvec_Linf_ref (x, y, d);
|
|
678
|
+
float fvec_Linf(const float* x, const float* y, size_t d) {
|
|
679
|
+
return fvec_Linf_ref(x, y, d);
|
|
690
680
|
}
|
|
691
681
|
|
|
692
|
-
float fvec_inner_product
|
|
693
|
-
|
|
694
|
-
size_t d)
|
|
695
|
-
{
|
|
696
|
-
return fvec_inner_product_ref (x, y, d);
|
|
682
|
+
float fvec_inner_product(const float* x, const float* y, size_t d) {
|
|
683
|
+
return fvec_inner_product_ref(x, y, d);
|
|
697
684
|
}
|
|
698
685
|
|
|
699
|
-
float fvec_norm_L2sqr
|
|
700
|
-
|
|
701
|
-
return fvec_norm_L2sqr_ref (x, d);
|
|
686
|
+
float fvec_norm_L2sqr(const float* x, size_t d) {
|
|
687
|
+
return fvec_norm_L2sqr_ref(x, d);
|
|
702
688
|
}
|
|
703
689
|
|
|
704
|
-
void fvec_L2sqr_ny
|
|
705
|
-
|
|
706
|
-
|
|
690
|
+
void fvec_L2sqr_ny(
|
|
691
|
+
float* dis,
|
|
692
|
+
const float* x,
|
|
693
|
+
const float* y,
|
|
694
|
+
size_t d,
|
|
695
|
+
size_t ny) {
|
|
696
|
+
fvec_L2sqr_ny_ref(dis, x, y, d, ny);
|
|
707
697
|
}
|
|
708
698
|
|
|
709
|
-
void fvec_inner_products_ny
|
|
710
|
-
|
|
711
|
-
|
|
699
|
+
void fvec_inner_products_ny(
|
|
700
|
+
float* dis,
|
|
701
|
+
const float* x,
|
|
702
|
+
const float* y,
|
|
703
|
+
size_t d,
|
|
704
|
+
size_t ny) {
|
|
705
|
+
fvec_inner_products_ny_ref(dis, x, y, d, ny);
|
|
712
706
|
}
|
|
713
707
|
|
|
714
|
-
|
|
715
708
|
#endif
|
|
716
709
|
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
|
|
721
|
-
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
710
|
/***************************************************************************
|
|
737
711
|
* heavily optimized table computations
|
|
738
712
|
***************************************************************************/
|
|
739
713
|
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
714
|
+
static inline void fvec_madd_ref(
|
|
715
|
+
size_t n,
|
|
716
|
+
const float* a,
|
|
717
|
+
float bf,
|
|
718
|
+
const float* b,
|
|
719
|
+
float* c) {
|
|
743
720
|
for (size_t i = 0; i < n; i++)
|
|
744
721
|
c[i] = a[i] + bf * b[i];
|
|
745
722
|
}
|
|
746
723
|
|
|
747
724
|
#ifdef __SSE3__
|
|
748
725
|
|
|
749
|
-
static inline void fvec_madd_sse
|
|
750
|
-
|
|
726
|
+
static inline void fvec_madd_sse(
|
|
727
|
+
size_t n,
|
|
728
|
+
const float* a,
|
|
729
|
+
float bf,
|
|
730
|
+
const float* b,
|
|
731
|
+
float* c) {
|
|
751
732
|
n >>= 2;
|
|
752
|
-
__m128 bf4 = _mm_set_ps1
|
|
753
|
-
__m128
|
|
754
|
-
__m128
|
|
755
|
-
__m128
|
|
733
|
+
__m128 bf4 = _mm_set_ps1(bf);
|
|
734
|
+
__m128* a4 = (__m128*)a;
|
|
735
|
+
__m128* b4 = (__m128*)b;
|
|
736
|
+
__m128* c4 = (__m128*)c;
|
|
756
737
|
|
|
757
738
|
while (n--) {
|
|
758
|
-
*c4 = _mm_add_ps
|
|
739
|
+
*c4 = _mm_add_ps(*a4, _mm_mul_ps(bf4, *b4));
|
|
759
740
|
b4++;
|
|
760
741
|
a4++;
|
|
761
742
|
c4++;
|
|
762
743
|
}
|
|
763
744
|
}
|
|
764
745
|
|
|
765
|
-
void fvec_madd
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
if ((n & 3) == 0 &&
|
|
769
|
-
((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
|
|
770
|
-
fvec_madd_sse (n, a, bf, b, c);
|
|
746
|
+
void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
|
|
747
|
+
if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
|
|
748
|
+
fvec_madd_sse(n, a, bf, b, c);
|
|
771
749
|
else
|
|
772
|
-
fvec_madd_ref
|
|
750
|
+
fvec_madd_ref(n, a, bf, b, c);
|
|
773
751
|
}
|
|
774
752
|
|
|
775
753
|
#else
|
|
776
754
|
|
|
777
|
-
void fvec_madd
|
|
778
|
-
|
|
779
|
-
{
|
|
780
|
-
fvec_madd_ref (n, a, bf, b, c);
|
|
755
|
+
void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
|
|
756
|
+
fvec_madd_ref(n, a, bf, b, c);
|
|
781
757
|
}
|
|
782
758
|
|
|
783
759
|
#endif
|
|
784
760
|
|
|
785
|
-
static inline int fvec_madd_and_argmin_ref
|
|
786
|
-
|
|
761
|
+
static inline int fvec_madd_and_argmin_ref(
|
|
762
|
+
size_t n,
|
|
763
|
+
const float* a,
|
|
764
|
+
float bf,
|
|
765
|
+
const float* b,
|
|
766
|
+
float* c) {
|
|
787
767
|
float vmin = 1e20;
|
|
788
768
|
int imin = -1;
|
|
789
769
|
|
|
@@ -799,125 +779,100 @@ static inline int fvec_madd_and_argmin_ref (size_t n, const float *a,
|
|
|
799
779
|
|
|
800
780
|
#ifdef __SSE3__
|
|
801
781
|
|
|
802
|
-
static inline int fvec_madd_and_argmin_sse
|
|
803
|
-
size_t n,
|
|
804
|
-
|
|
782
|
+
static inline int fvec_madd_and_argmin_sse(
|
|
783
|
+
size_t n,
|
|
784
|
+
const float* a,
|
|
785
|
+
float bf,
|
|
786
|
+
const float* b,
|
|
787
|
+
float* c) {
|
|
805
788
|
n >>= 2;
|
|
806
|
-
__m128 bf4 = _mm_set_ps1
|
|
807
|
-
__m128 vmin4 = _mm_set_ps1
|
|
808
|
-
__m128i imin4 = _mm_set1_epi32
|
|
809
|
-
__m128i idx4 = _mm_set_epi32
|
|
810
|
-
__m128i inc4 = _mm_set1_epi32
|
|
811
|
-
__m128
|
|
812
|
-
__m128
|
|
813
|
-
__m128
|
|
789
|
+
__m128 bf4 = _mm_set_ps1(bf);
|
|
790
|
+
__m128 vmin4 = _mm_set_ps1(1e20);
|
|
791
|
+
__m128i imin4 = _mm_set1_epi32(-1);
|
|
792
|
+
__m128i idx4 = _mm_set_epi32(3, 2, 1, 0);
|
|
793
|
+
__m128i inc4 = _mm_set1_epi32(4);
|
|
794
|
+
__m128* a4 = (__m128*)a;
|
|
795
|
+
__m128* b4 = (__m128*)b;
|
|
796
|
+
__m128* c4 = (__m128*)c;
|
|
814
797
|
|
|
815
798
|
while (n--) {
|
|
816
|
-
__m128 vc4 = _mm_add_ps
|
|
799
|
+
__m128 vc4 = _mm_add_ps(*a4, _mm_mul_ps(bf4, *b4));
|
|
817
800
|
*c4 = vc4;
|
|
818
|
-
__m128i mask = (
|
|
801
|
+
__m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4));
|
|
819
802
|
// imin4 = _mm_blendv_epi8 (imin4, idx4, mask); // slower!
|
|
820
803
|
|
|
821
|
-
imin4 = _mm_or_si128
|
|
822
|
-
|
|
823
|
-
vmin4 = _mm_min_ps
|
|
804
|
+
imin4 = _mm_or_si128(
|
|
805
|
+
_mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4));
|
|
806
|
+
vmin4 = _mm_min_ps(vmin4, vc4);
|
|
824
807
|
b4++;
|
|
825
808
|
a4++;
|
|
826
809
|
c4++;
|
|
827
|
-
idx4 = _mm_add_epi32
|
|
810
|
+
idx4 = _mm_add_epi32(idx4, inc4);
|
|
828
811
|
}
|
|
829
812
|
|
|
830
813
|
// 4 values -> 2
|
|
831
814
|
{
|
|
832
|
-
idx4 = _mm_shuffle_epi32
|
|
833
|
-
__m128 vc4 = _mm_shuffle_ps
|
|
834
|
-
__m128i mask = (
|
|
835
|
-
imin4 = _mm_or_si128
|
|
836
|
-
|
|
837
|
-
vmin4 = _mm_min_ps
|
|
815
|
+
idx4 = _mm_shuffle_epi32(imin4, 3 << 2 | 2);
|
|
816
|
+
__m128 vc4 = _mm_shuffle_ps(vmin4, vmin4, 3 << 2 | 2);
|
|
817
|
+
__m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4));
|
|
818
|
+
imin4 = _mm_or_si128(
|
|
819
|
+
_mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4));
|
|
820
|
+
vmin4 = _mm_min_ps(vmin4, vc4);
|
|
838
821
|
}
|
|
839
822
|
// 2 values -> 1
|
|
840
823
|
{
|
|
841
|
-
idx4 = _mm_shuffle_epi32
|
|
842
|
-
__m128 vc4 = _mm_shuffle_ps
|
|
843
|
-
__m128i mask = (
|
|
844
|
-
imin4 = _mm_or_si128
|
|
845
|
-
|
|
824
|
+
idx4 = _mm_shuffle_epi32(imin4, 1);
|
|
825
|
+
__m128 vc4 = _mm_shuffle_ps(vmin4, vmin4, 1);
|
|
826
|
+
__m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4));
|
|
827
|
+
imin4 = _mm_or_si128(
|
|
828
|
+
_mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4));
|
|
846
829
|
// vmin4 = _mm_min_ps (vmin4, vc4);
|
|
847
830
|
}
|
|
848
|
-
return _mm_cvtsi128_si32
|
|
831
|
+
return _mm_cvtsi128_si32(imin4);
|
|
849
832
|
}
|
|
850
833
|
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
834
|
+
int fvec_madd_and_argmin(
|
|
835
|
+
size_t n,
|
|
836
|
+
const float* a,
|
|
837
|
+
float bf,
|
|
838
|
+
const float* b,
|
|
839
|
+
float* c) {
|
|
840
|
+
if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
|
|
841
|
+
return fvec_madd_and_argmin_sse(n, a, bf, b, c);
|
|
858
842
|
else
|
|
859
|
-
return fvec_madd_and_argmin_ref
|
|
843
|
+
return fvec_madd_and_argmin_ref(n, a, bf, b, c);
|
|
860
844
|
}
|
|
861
845
|
|
|
862
846
|
#else
|
|
863
847
|
|
|
864
|
-
int fvec_madd_and_argmin
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
|
|
848
|
+
int fvec_madd_and_argmin(
|
|
849
|
+
size_t n,
|
|
850
|
+
const float* a,
|
|
851
|
+
float bf,
|
|
852
|
+
const float* b,
|
|
853
|
+
float* c) {
|
|
854
|
+
return fvec_madd_and_argmin_ref(n, a, bf, b, c);
|
|
868
855
|
}
|
|
869
856
|
|
|
870
857
|
#endif
|
|
871
858
|
|
|
872
|
-
|
|
873
859
|
/***************************************************************************
|
|
874
860
|
* PQ tables computations
|
|
875
861
|
***************************************************************************/
|
|
876
862
|
|
|
877
|
-
#ifdef __AVX2__
|
|
878
|
-
|
|
879
863
|
namespace {
|
|
880
864
|
|
|
881
|
-
|
|
882
|
-
// get even float32's of a and b, interleaved
|
|
883
|
-
simd8float32 geteven(simd8float32 a, simd8float32 b) {
|
|
884
|
-
return simd8float32(
|
|
885
|
-
_mm256_shuffle_ps(a.f, b.f, 0 << 0 | 2 << 2 | 0 << 4 | 2 << 6)
|
|
886
|
-
);
|
|
887
|
-
}
|
|
888
|
-
|
|
889
|
-
// get odd float32's of a and b, interleaved
|
|
890
|
-
simd8float32 getodd(simd8float32 a, simd8float32 b) {
|
|
891
|
-
return simd8float32(
|
|
892
|
-
_mm256_shuffle_ps(a.f, b.f, 1 << 0 | 3 << 2 | 1 << 4 | 3 << 6)
|
|
893
|
-
);
|
|
894
|
-
}
|
|
895
|
-
|
|
896
|
-
// 3 cycles
|
|
897
|
-
// if the lanes are a = [a0 a1] and b = [b0 b1], return [a0 b0]
|
|
898
|
-
simd8float32 getlow128(simd8float32 a, simd8float32 b) {
|
|
899
|
-
return simd8float32(
|
|
900
|
-
_mm256_permute2f128_ps(a.f, b.f, 0 | 2 << 4)
|
|
901
|
-
);
|
|
902
|
-
}
|
|
903
|
-
|
|
904
|
-
simd8float32 gethigh128(simd8float32 a, simd8float32 b) {
|
|
905
|
-
return simd8float32(
|
|
906
|
-
_mm256_permute2f128_ps(a.f, b.f, 1 | 3 << 4)
|
|
907
|
-
);
|
|
908
|
-
}
|
|
909
|
-
|
|
910
865
|
/// compute the IP for dsub = 2 for 8 centroids and 4 sub-vectors at a time
|
|
911
|
-
template<bool is_inner_product>
|
|
866
|
+
template <bool is_inner_product>
|
|
912
867
|
void pq2_8cents_table(
|
|
913
868
|
const simd8float32 centroids[8],
|
|
914
869
|
const simd8float32 x,
|
|
915
|
-
float
|
|
916
|
-
|
|
917
|
-
|
|
870
|
+
float* out,
|
|
871
|
+
size_t ldo,
|
|
872
|
+
size_t nout = 4) {
|
|
918
873
|
simd8float32 ips[4];
|
|
919
874
|
|
|
920
|
-
for(int i = 0; i < 4; i++) {
|
|
875
|
+
for (int i = 0; i < 4; i++) {
|
|
921
876
|
simd8float32 p1, p2;
|
|
922
877
|
if (is_inner_product) {
|
|
923
878
|
p1 = x * centroids[2 * i];
|
|
@@ -941,21 +896,21 @@ void pq2_8cents_table(
|
|
|
941
896
|
simd8float32 ip1 = getlow128(ip13a, ip13b);
|
|
942
897
|
simd8float32 ip3 = gethigh128(ip13a, ip13b);
|
|
943
898
|
|
|
944
|
-
switch(nout) {
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
|
|
899
|
+
switch (nout) {
|
|
900
|
+
case 4:
|
|
901
|
+
ip3.storeu(out + 3 * ldo);
|
|
902
|
+
case 3:
|
|
903
|
+
ip2.storeu(out + 2 * ldo);
|
|
904
|
+
case 2:
|
|
905
|
+
ip1.storeu(out + 1 * ldo);
|
|
906
|
+
case 1:
|
|
907
|
+
ip0.storeu(out);
|
|
953
908
|
}
|
|
954
909
|
}
|
|
955
910
|
|
|
956
|
-
simd8float32 load_simd8float32_partial(const float
|
|
911
|
+
simd8float32 load_simd8float32_partial(const float* x, int n) {
|
|
957
912
|
ALIGNED(32) float tmp[8] = {0, 0, 0, 0, 0, 0, 0, 0};
|
|
958
|
-
float
|
|
913
|
+
float* wp = tmp;
|
|
959
914
|
for (int i = 0; i < n; i++) {
|
|
960
915
|
*wp++ = *x++;
|
|
961
916
|
}
|
|
@@ -964,25 +919,23 @@ simd8float32 load_simd8float32_partial(const float *x, int n) {
|
|
|
964
919
|
|
|
965
920
|
} // anonymous namespace
|
|
966
921
|
|
|
967
|
-
|
|
968
|
-
|
|
969
|
-
|
|
970
922
|
void compute_PQ_dis_tables_dsub2(
|
|
971
|
-
size_t d,
|
|
972
|
-
size_t
|
|
923
|
+
size_t d,
|
|
924
|
+
size_t ksub,
|
|
925
|
+
const float* all_centroids,
|
|
926
|
+
size_t nx,
|
|
927
|
+
const float* x,
|
|
973
928
|
bool is_inner_product,
|
|
974
|
-
float
|
|
975
|
-
{
|
|
929
|
+
float* dis_tables) {
|
|
976
930
|
size_t M = d / 2;
|
|
977
931
|
FAISS_THROW_IF_NOT(ksub % 8 == 0);
|
|
978
932
|
|
|
979
|
-
for(size_t m0 = 0; m0 < M; m0 += 4) {
|
|
933
|
+
for (size_t m0 = 0; m0 < M; m0 += 4) {
|
|
980
934
|
int m1 = std::min(M, m0 + 4);
|
|
981
|
-
for(int k0 = 0; k0 < ksub; k0 += 8) {
|
|
982
|
-
|
|
935
|
+
for (int k0 = 0; k0 < ksub; k0 += 8) {
|
|
983
936
|
simd8float32 centroids[8];
|
|
984
937
|
for (int k = 0; k < 8; k++) {
|
|
985
|
-
float centroid[8]
|
|
938
|
+
ALIGNED(32) float centroid[8];
|
|
986
939
|
size_t wp = 0;
|
|
987
940
|
size_t rp = (m0 * ksub + k + k0) * 2;
|
|
988
941
|
for (int m = m0; m < m1; m++) {
|
|
@@ -992,45 +945,82 @@ void compute_PQ_dis_tables_dsub2(
|
|
|
992
945
|
}
|
|
993
946
|
centroids[k] = simd8float32(centroid);
|
|
994
947
|
}
|
|
995
|
-
for(size_t i = 0; i < nx; i++) {
|
|
948
|
+
for (size_t i = 0; i < nx; i++) {
|
|
996
949
|
simd8float32 xi;
|
|
997
950
|
if (m1 == m0 + 4) {
|
|
998
951
|
xi.loadu(x + i * d + m0 * 2);
|
|
999
952
|
} else {
|
|
1000
|
-
xi = load_simd8float32_partial(
|
|
953
|
+
xi = load_simd8float32_partial(
|
|
954
|
+
x + i * d + m0 * 2, 2 * (m1 - m0));
|
|
1001
955
|
}
|
|
1002
956
|
|
|
1003
|
-
if(is_inner_product) {
|
|
957
|
+
if (is_inner_product) {
|
|
1004
958
|
pq2_8cents_table<true>(
|
|
1005
|
-
|
|
1006
|
-
|
|
1007
|
-
|
|
1008
|
-
|
|
959
|
+
centroids,
|
|
960
|
+
xi,
|
|
961
|
+
dis_tables + (i * M + m0) * ksub + k0,
|
|
962
|
+
ksub,
|
|
963
|
+
m1 - m0);
|
|
1009
964
|
} else {
|
|
1010
965
|
pq2_8cents_table<false>(
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
966
|
+
centroids,
|
|
967
|
+
xi,
|
|
968
|
+
dis_tables + (i * M + m0) * ksub + k0,
|
|
969
|
+
ksub,
|
|
970
|
+
m1 - m0);
|
|
1015
971
|
}
|
|
1016
972
|
}
|
|
1017
973
|
}
|
|
1018
974
|
}
|
|
1019
|
-
|
|
1020
975
|
}
|
|
1021
976
|
|
|
1022
|
-
|
|
977
|
+
/*********************************************************
|
|
978
|
+
* Vector to vector functions
|
|
979
|
+
*********************************************************/
|
|
1023
980
|
|
|
1024
|
-
void
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
|
|
1030
|
-
|
|
981
|
+
void fvec_sub(size_t d, const float* a, const float* b, float* c) {
|
|
982
|
+
size_t i;
|
|
983
|
+
for (i = 0; i + 7 < d; i += 8) {
|
|
984
|
+
simd8float32 ci, ai, bi;
|
|
985
|
+
ai.loadu(a + i);
|
|
986
|
+
bi.loadu(b + i);
|
|
987
|
+
ci = ai - bi;
|
|
988
|
+
ci.storeu(c + i);
|
|
989
|
+
}
|
|
990
|
+
// finish non-multiple of 8 remainder
|
|
991
|
+
for (; i < d; i++) {
|
|
992
|
+
c[i] = a[i] - b[i];
|
|
993
|
+
}
|
|
1031
994
|
}
|
|
1032
995
|
|
|
1033
|
-
|
|
996
|
+
void fvec_add(size_t d, const float* a, const float* b, float* c) {
|
|
997
|
+
size_t i;
|
|
998
|
+
for (i = 0; i + 7 < d; i += 8) {
|
|
999
|
+
simd8float32 ci, ai, bi;
|
|
1000
|
+
ai.loadu(a + i);
|
|
1001
|
+
bi.loadu(b + i);
|
|
1002
|
+
ci = ai + bi;
|
|
1003
|
+
ci.storeu(c + i);
|
|
1004
|
+
}
|
|
1005
|
+
// finish non-multiple of 8 remainder
|
|
1006
|
+
for (; i < d; i++) {
|
|
1007
|
+
c[i] = a[i] + b[i];
|
|
1008
|
+
}
|
|
1009
|
+
}
|
|
1034
1010
|
|
|
1011
|
+
void fvec_add(size_t d, const float* a, float b, float* c) {
|
|
1012
|
+
size_t i;
|
|
1013
|
+
simd8float32 bv(b);
|
|
1014
|
+
for (i = 0; i + 7 < d; i += 8) {
|
|
1015
|
+
simd8float32 ci, ai, bi;
|
|
1016
|
+
ai.loadu(a + i);
|
|
1017
|
+
ci = ai + bv;
|
|
1018
|
+
ci.storeu(c + i);
|
|
1019
|
+
}
|
|
1020
|
+
// finish non-multiple of 8 remainder
|
|
1021
|
+
for (; i < d; i++) {
|
|
1022
|
+
c[i] = a[i] + b;
|
|
1023
|
+
}
|
|
1024
|
+
}
|
|
1035
1025
|
|
|
1036
1026
|
} // namespace faiss
|