faiss 0.1.0 → 0.1.1
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +103 -3
- data/ext/faiss/ext.cpp +99 -32
- data/ext/faiss/extconf.rb +12 -2
- data/lib/faiss/ext.bundle +0 -0
- data/lib/faiss/index.rb +3 -3
- data/lib/faiss/index_binary.rb +3 -3
- data/lib/faiss/kmeans.rb +1 -1
- data/lib/faiss/pca_matrix.rb +2 -2
- data/lib/faiss/product_quantizer.rb +3 -3
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/AutoTune.cpp +719 -0
- data/vendor/faiss/AutoTune.h +212 -0
- data/vendor/faiss/Clustering.cpp +261 -0
- data/vendor/faiss/Clustering.h +101 -0
- data/vendor/faiss/IVFlib.cpp +339 -0
- data/vendor/faiss/IVFlib.h +132 -0
- data/vendor/faiss/Index.cpp +171 -0
- data/vendor/faiss/Index.h +261 -0
- data/vendor/faiss/Index2Layer.cpp +437 -0
- data/vendor/faiss/Index2Layer.h +85 -0
- data/vendor/faiss/IndexBinary.cpp +77 -0
- data/vendor/faiss/IndexBinary.h +163 -0
- data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
- data/vendor/faiss/IndexBinaryFlat.h +54 -0
- data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
- data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
- data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
- data/vendor/faiss/IndexBinaryHNSW.h +56 -0
- data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
- data/vendor/faiss/IndexBinaryIVF.h +211 -0
- data/vendor/faiss/IndexFlat.cpp +508 -0
- data/vendor/faiss/IndexFlat.h +175 -0
- data/vendor/faiss/IndexHNSW.cpp +1090 -0
- data/vendor/faiss/IndexHNSW.h +170 -0
- data/vendor/faiss/IndexIVF.cpp +909 -0
- data/vendor/faiss/IndexIVF.h +353 -0
- data/vendor/faiss/IndexIVFFlat.cpp +502 -0
- data/vendor/faiss/IndexIVFFlat.h +118 -0
- data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
- data/vendor/faiss/IndexIVFPQ.h +161 -0
- data/vendor/faiss/IndexIVFPQR.cpp +219 -0
- data/vendor/faiss/IndexIVFPQR.h +65 -0
- data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
- data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
- data/vendor/faiss/IndexLSH.cpp +225 -0
- data/vendor/faiss/IndexLSH.h +87 -0
- data/vendor/faiss/IndexLattice.cpp +143 -0
- data/vendor/faiss/IndexLattice.h +68 -0
- data/vendor/faiss/IndexPQ.cpp +1188 -0
- data/vendor/faiss/IndexPQ.h +199 -0
- data/vendor/faiss/IndexPreTransform.cpp +288 -0
- data/vendor/faiss/IndexPreTransform.h +91 -0
- data/vendor/faiss/IndexReplicas.cpp +123 -0
- data/vendor/faiss/IndexReplicas.h +76 -0
- data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
- data/vendor/faiss/IndexScalarQuantizer.h +127 -0
- data/vendor/faiss/IndexShards.cpp +317 -0
- data/vendor/faiss/IndexShards.h +100 -0
- data/vendor/faiss/InvertedLists.cpp +623 -0
- data/vendor/faiss/InvertedLists.h +334 -0
- data/vendor/faiss/LICENSE +21 -0
- data/vendor/faiss/MatrixStats.cpp +252 -0
- data/vendor/faiss/MatrixStats.h +62 -0
- data/vendor/faiss/MetaIndexes.cpp +351 -0
- data/vendor/faiss/MetaIndexes.h +126 -0
- data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
- data/vendor/faiss/OnDiskInvertedLists.h +127 -0
- data/vendor/faiss/VectorTransform.cpp +1157 -0
- data/vendor/faiss/VectorTransform.h +322 -0
- data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
- data/vendor/faiss/c_api/AutoTune_c.h +64 -0
- data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
- data/vendor/faiss/c_api/Clustering_c.h +117 -0
- data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
- data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
- data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
- data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
- data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
- data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
- data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
- data/vendor/faiss/c_api/IndexShards_c.h +42 -0
- data/vendor/faiss/c_api/Index_c.cpp +105 -0
- data/vendor/faiss/c_api/Index_c.h +183 -0
- data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
- data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
- data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
- data/vendor/faiss/c_api/clone_index_c.h +32 -0
- data/vendor/faiss/c_api/error_c.h +42 -0
- data/vendor/faiss/c_api/error_impl.cpp +27 -0
- data/vendor/faiss/c_api/error_impl.h +16 -0
- data/vendor/faiss/c_api/faiss_c.h +58 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
- data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
- data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
- data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
- data/vendor/faiss/c_api/index_factory_c.h +30 -0
- data/vendor/faiss/c_api/index_io_c.cpp +42 -0
- data/vendor/faiss/c_api/index_io_c.h +50 -0
- data/vendor/faiss/c_api/macros_impl.h +110 -0
- data/vendor/faiss/clone_index.cpp +147 -0
- data/vendor/faiss/clone_index.h +38 -0
- data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
- data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
- data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
- data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
- data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
- data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
- data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
- data/vendor/faiss/gpu/GpuCloner.h +82 -0
- data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
- data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
- data/vendor/faiss/gpu/GpuDistance.h +52 -0
- data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
- data/vendor/faiss/gpu/GpuIndex.h +148 -0
- data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
- data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
- data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
- data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
- data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
- data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
- data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
- data/vendor/faiss/gpu/GpuResources.cpp +52 -0
- data/vendor/faiss/gpu/GpuResources.h +73 -0
- data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
- data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
- data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
- data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
- data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
- data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
- data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
- data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
- data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
- data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
- data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
- data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
- data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
- data/vendor/faiss/gpu/test/TestUtils.h +93 -0
- data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
- data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
- data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
- data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
- data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
- data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
- data/vendor/faiss/gpu/utils/Timer.h +52 -0
- data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
- data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
- data/vendor/faiss/impl/FaissAssert.h +95 -0
- data/vendor/faiss/impl/FaissException.cpp +66 -0
- data/vendor/faiss/impl/FaissException.h +71 -0
- data/vendor/faiss/impl/HNSW.cpp +818 -0
- data/vendor/faiss/impl/HNSW.h +275 -0
- data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
- data/vendor/faiss/impl/PolysemousTraining.h +158 -0
- data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
- data/vendor/faiss/impl/ProductQuantizer.h +242 -0
- data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
- data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
- data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
- data/vendor/faiss/impl/ThreadedIndex.h +80 -0
- data/vendor/faiss/impl/index_read.cpp +793 -0
- data/vendor/faiss/impl/index_write.cpp +558 -0
- data/vendor/faiss/impl/io.cpp +142 -0
- data/vendor/faiss/impl/io.h +98 -0
- data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
- data/vendor/faiss/impl/lattice_Zn.h +199 -0
- data/vendor/faiss/index_factory.cpp +392 -0
- data/vendor/faiss/index_factory.h +25 -0
- data/vendor/faiss/index_io.h +75 -0
- data/vendor/faiss/misc/test_blas.cpp +84 -0
- data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
- data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
- data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
- data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
- data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
- data/vendor/faiss/tests/test_merge.cpp +258 -0
- data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
- data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
- data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
- data/vendor/faiss/tests/test_params_override.cpp +231 -0
- data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
- data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
- data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
- data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
- data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
- data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
- data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
- data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
- data/vendor/faiss/utils/Heap.cpp +122 -0
- data/vendor/faiss/utils/Heap.h +495 -0
- data/vendor/faiss/utils/WorkerThread.cpp +126 -0
- data/vendor/faiss/utils/WorkerThread.h +61 -0
- data/vendor/faiss/utils/distances.cpp +765 -0
- data/vendor/faiss/utils/distances.h +243 -0
- data/vendor/faiss/utils/distances_simd.cpp +809 -0
- data/vendor/faiss/utils/extra_distances.cpp +336 -0
- data/vendor/faiss/utils/extra_distances.h +54 -0
- data/vendor/faiss/utils/hamming-inl.h +472 -0
- data/vendor/faiss/utils/hamming.cpp +792 -0
- data/vendor/faiss/utils/hamming.h +220 -0
- data/vendor/faiss/utils/random.cpp +192 -0
- data/vendor/faiss/utils/random.h +60 -0
- data/vendor/faiss/utils/utils.cpp +783 -0
- data/vendor/faiss/utils/utils.h +181 -0
- metadata +216 -2
data/vendor/faiss/utils/distances.h (new file)
@@ -0,0 +1,243 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// -*- c++ -*-
+
+/* All distance functions for L2 and IP distances.
+ * The actual functions are implemented in distances.cpp and distances_simd.cpp */
+
+#pragma once
+
+#include <stdint.h>
+
+#include <faiss/utils/Heap.h>
+
+
+namespace faiss {
+
+/*********************************************************
+ * Optimized distance/norm/inner prod computations
+ *********************************************************/
+
+
+/// Squared L2 distance between two vectors
+float fvec_L2sqr (
+        const float * x,
+        const float * y,
+        size_t d);
+
+/// inner product
+float fvec_inner_product (
+        const float * x,
+        const float * y,
+        size_t d);
+
+/// L1 distance
+float fvec_L1 (
+        const float * x,
+        const float * y,
+        size_t d);
+
+float fvec_Linf (
+        const float * x,
+        const float * y,
+        size_t d);
+
+
+/** Compute pairwise distances between sets of vectors
+ *
+ * @param d    dimension of the vectors
+ * @param nq   nb of query vectors
+ * @param nb   nb of database vectors
+ * @param xq   query vectors (size nq * d)
+ * @param xb   database vectors (size nb * d)
+ * @param dis  output distances (size nq * nb)
+ * @param ldq, ldb, ldd  strides for the matrices
+ */
+void pairwise_L2sqr (int64_t d,
+                     int64_t nq, const float *xq,
+                     int64_t nb, const float *xb,
+                     float *dis,
+                     int64_t ldq = -1, int64_t ldb = -1, int64_t ldd = -1);
+
+/* compute ny inner products between vector x and a set of contiguous y vectors */
+void fvec_inner_products_ny (
+        float * ip,  /* output inner products */
+        const float * x,
+        const float * y,
+        size_t d, size_t ny);
+
+/* compute ny squared L2 distances between x and a set of contiguous y vectors */
+void fvec_L2sqr_ny (
+        float * dis,
+        const float * x,
+        const float * y,
+        size_t d, size_t ny);
+
+
+/** squared norm of a vector */
+float fvec_norm_L2sqr (const float * x,
+                       size_t d);
+
+/** compute the L2 norms for a set of vectors
+ *
+ * @param ip output norms, size nx
+ * @param x  set of vectors, size nx * d
+ */
+void fvec_norms_L2 (float * ip, const float * x, size_t d, size_t nx);
+
+/// same as fvec_norms_L2, but computes squared norms
+void fvec_norms_L2sqr (float * ip, const float * x, size_t d, size_t nx);
+
+/* L2-renormalize a set of vectors. Nothing done if the vector is 0-normed */
+void fvec_renorm_L2 (size_t d, size_t nx, float * x);
+
+
+/* This function exists because the Torch counterpart is extremely slow
+   (not multi-threaded + unexpected overhead even in single thread).
+   It is here to implement the usual property |x-y|^2 = |x|^2 + |y|^2 - 2<x|y> */
+void inner_product_to_L2sqr (float * dis,
+                             const float * nr1,
+                             const float * nr2,
+                             size_t n1, size_t n2);
+
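To make the identity above concrete, here is a minimal sketch (an illustration, not part of the vendored code) of what the comment describes: given precomputed squared norms and an n1 x n2 matrix of inner products, squared L2 distances follow without touching the vectors again. This is presumably what inner_product_to_L2sqr does in place on dis.

#include <cstddef>

// in: dis[i*n2+j] = <x_i, y_j>; out: dis[i*n2+j] = ||x_i - y_j||^2
void ip_to_L2sqr_sketch(float* dis,
                        const float* nr1,   // ||x_i||^2, size n1
                        const float* nr2,   // ||y_j||^2, size n2
                        size_t n1, size_t n2) {
    for (size_t i = 0; i < n1; i++)
        for (size_t j = 0; j < n2; j++)
            dis[i * n2 + j] = nr1[i] + nr2[j] - 2 * dis[i * n2 + j];
}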
+/***************************************************************************
+ * Compute a subset of distances
+ ***************************************************************************/
+
+/* compute the inner product between x and a subset y of ny vectors,
+   whose indices are given by idy. */
+void fvec_inner_products_by_idx (
+        float * ip,
+        const float * x,
+        const float * y,
+        const int64_t *ids,
+        size_t d, size_t nx, size_t ny);
+
+/* same but for a subset in y indexed by idsy (ny vectors in total) */
+void fvec_L2sqr_by_idx (
+        float * dis,
+        const float * x,
+        const float * y,
+        const int64_t *ids, /* ids of y vecs */
+        size_t d, size_t nx, size_t ny);
+
+
+/** compute dis[j] = L2sqr(x[ix[j]], y[iy[j]]) for all j = 0..n-1
+ *
+ * @param x   size (max(ix) + 1, d)
+ * @param y   size (max(iy) + 1, d)
+ * @param ix  size n
+ * @param iy  size n
+ * @param dis size n
+ */
+void pairwise_indexed_L2sqr (
+        size_t d, size_t n,
+        const float * x, const int64_t *ix,
+        const float * y, const int64_t *iy,
+        float *dis);
+
+/* same for inner product */
+void pairwise_indexed_inner_product (
+        size_t d, size_t n,
+        const float * x, const int64_t *ix,
+        const float * y, const int64_t *iy,
+        float *dis);
+
+/***************************************************************************
+ * KNN functions
+ ***************************************************************************/
+
+// threshold on nx above which we switch to BLAS to compute distances
+extern int distance_compute_blas_threshold;
+
+/** Return the k nearest neighbors of each of the nx vectors x among the ny
+ * vectors y, w.r.t. max inner product
+ *
+ * @param x   query vectors, size nx * d
+ * @param y   database vectors, size ny * d
+ * @param res result array, which also provides k. Sorted on output
+ */
+void knn_inner_product (
+        const float * x,
+        const float * y,
+        size_t d, size_t nx, size_t ny,
+        float_minheap_array_t * res);
+
+/** Same as knn_inner_product, for the L2 distance */
+void knn_L2sqr (
+        const float * x,
+        const float * y,
+        size_t d, size_t nx, size_t ny,
+        float_maxheap_array_t * res);
+
+
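A hedged usage sketch (not part of the vendored code) for knn_L2sqr: the caller allocates the result buffers and wraps them in a float_maxheap_array_t, which comes from faiss/utils/Heap.h in this same vendored tree; its fields are, in order, nh (number of heaps, i.e. queries), k (neighbors per query), ids, val.

#include <cstdint>
#include <vector>
#include <faiss/utils/distances.h>

void knn_demo(const float* xq, const float* xb,
              size_t d, size_t nq, size_t nb, size_t k) {
    std::vector<float>   distances(nq * k);
    std::vector<int64_t> labels(nq * k);
    faiss::float_maxheap_array_t res = {
        nq, k, labels.data(), distances.data()
    };
    faiss::knn_L2sqr(xq, xb, d, nq, nb, &res);
    // distances[q * k + j] / labels[q * k + j] now hold the j-th nearest
    // database vector for query q, sorted on output per the doc comment.
}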
+/** same as knn_L2sqr, but base_shift[bno] is subtracted from all
+ * computed distances.
+ *
+ * @param base_shift size ny
+ */
+void knn_L2sqr_base_shift (
+        const float * x,
+        const float * y,
+        size_t d, size_t nx, size_t ny,
+        float_maxheap_array_t * res,
+        const float *base_shift);
+
+/* Find the nearest neighbors for nx queries in a set of ny vectors
+ * indexed by ids. May be useful for re-ranking a pre-selected vector list
+ */
+void knn_inner_products_by_idx (
+        const float * x,
+        const float * y,
+        const int64_t * ids,
+        size_t d, size_t nx, size_t ny,
+        float_minheap_array_t * res);
+
+void knn_L2sqr_by_idx (const float * x,
+                       const float * y,
+                       const int64_t * ids,
+                       size_t d, size_t nx, size_t ny,
+                       float_maxheap_array_t * res);
+
+/***************************************************************************
+ * Range search
+ ***************************************************************************/
+
+
+
+/// Forward declaration, see AuxIndexStructures.h
+struct RangeSearchResult;
+
+/** Return the vectors y that lie within a given radius of each of the nx
+ * query vectors x, for the L2 distance
+ *
+ * @param x      query vectors, size nx * d
+ * @param y      database vectors, size ny * d
+ * @param radius search radius around the x vectors
+ * @param result result structure
+ */
+void range_search_L2sqr (
+        const float * x,
+        const float * y,
+        size_t d, size_t nx, size_t ny,
+        float radius,
+        RangeSearchResult *result);
+
+/// same as range_search_L2sqr for the inner product similarity
+void range_search_inner_product (
+        const float * x,
+        const float * y,
+        size_t d, size_t nx, size_t ny,
+        float radius,
+        RangeSearchResult *result);
+
+
+
+} // namespace faiss
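A hedged sketch (illustration, not part of the vendored code) of the range-search entry point. RangeSearchResult is only forward-declared above; its definition lives in faiss/impl/AuxIndexStructures.h, also in this gem's vendored tree. Because result sizes are not known in advance, matches come back in lims/labels/distances arrays owned by the structure, with query q's matches in the half-open range [lims[q], lims[q+1]).

#include <cstdio>
#include <faiss/utils/distances.h>
#include <faiss/impl/AuxIndexStructures.h>

void range_demo(const float* xq, const float* xb,
                size_t d, size_t nq, size_t nb, float radius) {
    faiss::RangeSearchResult res(nq);
    faiss::range_search_L2sqr(xq, xb, d, nq, nb, radius, &res);
    for (size_t q = 0; q < nq; q++) {
        for (size_t j = res.lims[q]; j < res.lims[q + 1]; j++) {
            printf("query %zu: id %lld dist %g\n",
                   q, (long long)res.labels[j], res.distances[j]);
        }
    }
}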
data/vendor/faiss/utils/distances_simd.cpp (new file)
@@ -0,0 +1,809 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// -*- c++ -*-
+
+#include <faiss/utils/distances.h>
+
+#include <cstdio>
+#include <cassert>
+#include <cstring>
+#include <cmath>
+
+#ifdef __SSE__
+#include <immintrin.h>
+#endif
+
+#ifdef __aarch64__
+#include <arm_neon.h>
+#endif
+
+#include <omp.h>
+
+namespace faiss {
+
+#ifdef __AVX__
+#define USE_AVX
+#endif
+
+
+/*********************************************************
+ * Optimized distance computations
+ *********************************************************/
+
+
+/* Functions to compute:
+   - L2 distance between 2 vectors
+   - inner product between 2 vectors
+   - L2 norm of a vector
+
+   The functions should probably not be invoked when a large number of
+   vectors are processed in batch (in which case matrix multiplication
+   is faster), but may be useful for comparing vectors isolated in
+   memory.
+
+   Works with any vectors of any dimension, even unaligned (in which
+   case they are slower).
+
+*/
+
+
+/*********************************************************
+ * Reference implementations
+ */
+
+
+float fvec_L2sqr_ref (const float * x,
+                      const float * y,
+                      size_t d)
+{
+    size_t i;
+    float res = 0;
+    for (i = 0; i < d; i++) {
+        const float tmp = x[i] - y[i];
+        res += tmp * tmp;
+    }
+    return res;
+}
+
+float fvec_L1_ref (const float * x,
+                   const float * y,
+                   size_t d)
+{
+    size_t i;
+    float res = 0;
+    for (i = 0; i < d; i++) {
+        const float tmp = x[i] - y[i];
+        res += fabs(tmp);
+    }
+    return res;
+}
+
+float fvec_Linf_ref (const float * x,
+                     const float * y,
+                     size_t d)
+{
+    size_t i;
+    float res = 0;
+    for (i = 0; i < d; i++) {
+        res = fmax(res, fabs(x[i] - y[i]));
+    }
+    return res;
+}
+
+float fvec_inner_product_ref (const float * x,
+                              const float * y,
+                              size_t d)
+{
+    size_t i;
+    float res = 0;
+    for (i = 0; i < d; i++)
+        res += x[i] * y[i];
+    return res;
+}
+
+float fvec_norm_L2sqr_ref (const float *x, size_t d)
+{
+    size_t i;
+    double res = 0;
+    for (i = 0; i < d; i++)
+        res += x[i] * x[i];
+    return res;
+}
+
+
+void fvec_L2sqr_ny_ref (float * dis,
+                        const float * x,
+                        const float * y,
+                        size_t d, size_t ny)
+{
+    for (size_t i = 0; i < ny; i++) {
+        dis[i] = fvec_L2sqr (x, y, d);
+        y += d;
+    }
+}
+
+
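As a quick sanity check of the reference semantics (a worked example, not from the vendored code): for x = (1, 2, 3) and y = (4, 6, 8), fvec_L2sqr_ref accumulates (1-4)^2 + (2-6)^2 + (3-8)^2 = 9 + 16 + 25 = 50, fvec_L1_ref gives 3 + 4 + 5 = 12, and fvec_Linf_ref gives max(3, 4, 5) = 5. Note that fvec_norm_L2sqr_ref accumulates in double before narrowing to float on return, unlike the other reference loops, which accumulate in float.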
+
+/*********************************************************
+ * SSE and AVX implementations
+ */
+
+#ifdef __SSE__
+
+// reads 0 <= d < 4 floats as __m128
+static inline __m128 masked_read (int d, const float *x)
+{
+    assert (0 <= d && d < 4);
+    __attribute__((__aligned__(16))) float buf[4] = {0, 0, 0, 0};
+    switch (d) {
+      case 3:
+        buf[2] = x[2];
+      case 2:
+        buf[1] = x[1];
+      case 1:
+        buf[0] = x[0];
+    }
+    return _mm_load_ps (buf);
+    // cannot use AVX2 _mm_mask_set1_epi32
+}
+
+float fvec_norm_L2sqr (const float * x,
+                       size_t d)
+{
+    __m128 mx;
+    __m128 msum1 = _mm_setzero_ps();
+
+    while (d >= 4) {
+        mx = _mm_loadu_ps (x); x += 4;
+        msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, mx));
+        d -= 4;
+    }
+
+    mx = masked_read (d, x);
+    msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, mx));
+
+    msum1 = _mm_hadd_ps (msum1, msum1);
+    msum1 = _mm_hadd_ps (msum1, msum1);
+    return _mm_cvtss_f32 (msum1);
+}
+
+namespace {
+
+float sqr (float x) {
+    return x * x;
+}
+
+
+void fvec_L2sqr_ny_D1 (float * dis, const float * x,
+                       const float * y, size_t ny)
+{
+    float x0s = x[0];
+    __m128 x0 = _mm_set_ps (x0s, x0s, x0s, x0s);
+
+    size_t i;
+    for (i = 0; i + 3 < ny; i += 4) {
+        __m128 tmp, accu;
+        tmp = x0 - _mm_loadu_ps (y); y += 4;
+        accu = tmp * tmp;
+        dis[i] = _mm_cvtss_f32 (accu);
+        tmp = _mm_shuffle_ps (accu, accu, 1);
+        dis[i + 1] = _mm_cvtss_f32 (tmp);
+        tmp = _mm_shuffle_ps (accu, accu, 2);
+        dis[i + 2] = _mm_cvtss_f32 (tmp);
+        tmp = _mm_shuffle_ps (accu, accu, 3);
+        dis[i + 3] = _mm_cvtss_f32 (tmp);
+    }
+    while (i < ny) { // handle non-multiple-of-4 case
+        dis[i++] = sqr(x0s - *y++);
+    }
+}
+
+
+void fvec_L2sqr_ny_D2 (float * dis, const float * x,
+                       const float * y, size_t ny)
+{
+    __m128 x0 = _mm_set_ps (x[1], x[0], x[1], x[0]);
+
+    size_t i;
+    for (i = 0; i + 1 < ny; i += 2) {
+        __m128 tmp, accu;
+        tmp = x0 - _mm_loadu_ps (y); y += 4;
+        accu = tmp * tmp;
+        accu = _mm_hadd_ps (accu, accu);
+        dis[i] = _mm_cvtss_f32 (accu);
+        accu = _mm_shuffle_ps (accu, accu, 3);
+        dis[i + 1] = _mm_cvtss_f32 (accu);
+    }
+    if (i < ny) { // handle odd case
+        dis[i] = sqr(x[0] - y[0]) + sqr(x[1] - y[1]);
+    }
+}
+
+
+
+void fvec_L2sqr_ny_D4 (float * dis, const float * x,
+                       const float * y, size_t ny)
+{
+    __m128 x0 = _mm_loadu_ps(x);
+
+    for (size_t i = 0; i < ny; i++) {
+        __m128 tmp, accu;
+        tmp = x0 - _mm_loadu_ps (y); y += 4;
+        accu = tmp * tmp;
+        accu = _mm_hadd_ps (accu, accu);
+        accu = _mm_hadd_ps (accu, accu);
+        dis[i] = _mm_cvtss_f32 (accu);
+    }
+}
+
+
+void fvec_L2sqr_ny_D8 (float * dis, const float * x,
+                       const float * y, size_t ny)
+{
+    __m128 x0 = _mm_loadu_ps(x);
+    __m128 x1 = _mm_loadu_ps(x + 4);
+
+    for (size_t i = 0; i < ny; i++) {
+        __m128 tmp, accu;
+        tmp = x0 - _mm_loadu_ps (y); y += 4;
+        accu = tmp * tmp;
+        tmp = x1 - _mm_loadu_ps (y); y += 4;
+        accu += tmp * tmp;
+        accu = _mm_hadd_ps (accu, accu);
+        accu = _mm_hadd_ps (accu, accu);
+        dis[i] = _mm_cvtss_f32 (accu);
+    }
+}
+
+
+void fvec_L2sqr_ny_D12 (float * dis, const float * x,
+                        const float * y, size_t ny)
+{
+    __m128 x0 = _mm_loadu_ps(x);
+    __m128 x1 = _mm_loadu_ps(x + 4);
+    __m128 x2 = _mm_loadu_ps(x + 8);
+
+    for (size_t i = 0; i < ny; i++) {
+        __m128 tmp, accu;
+        tmp = x0 - _mm_loadu_ps (y); y += 4;
+        accu = tmp * tmp;
+        tmp = x1 - _mm_loadu_ps (y); y += 4;
+        accu += tmp * tmp;
+        tmp = x2 - _mm_loadu_ps (y); y += 4;
+        accu += tmp * tmp;
+        accu = _mm_hadd_ps (accu, accu);
+        accu = _mm_hadd_ps (accu, accu);
+        dis[i] = _mm_cvtss_f32 (accu);
+    }
+}
+
+
+} // anonymous namespace
+
+void fvec_L2sqr_ny (float * dis, const float * x,
+                    const float * y, size_t d, size_t ny) {
+    // optimized for a few special cases
+    switch(d) {
+    case 1:
+        fvec_L2sqr_ny_D1 (dis, x, y, ny);
+        return;
+    case 2:
+        fvec_L2sqr_ny_D2 (dis, x, y, ny);
+        return;
+    case 4:
+        fvec_L2sqr_ny_D4 (dis, x, y, ny);
+        return;
+    case 8:
+        fvec_L2sqr_ny_D8 (dis, x, y, ny);
+        return;
+    case 12:
+        fvec_L2sqr_ny_D12 (dis, x, y, ny);
+        return;
+    default:
+        fvec_L2sqr_ny_ref (dis, x, y, d, ny);
+        return;
+    }
+}
+
+
+
+#endif
+
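A hedged usage sketch (not part of the vendored code) for fvec_L2sqr_ny: one query vector against ny contiguous database vectors. The specialized d = 1/2/4/8/12 paths above match common product-quantizer sub-vector sizes, which is plausibly the workload they were written for (distance tables of one sub-vector against all centroids of a codebook).

#include <vector>
#include <faiss/utils/distances.h>

void l2sqr_ny_demo() {
    const size_t d = 8, ksub = 256;          // sub-vector dim, centroids per codebook
    std::vector<float> x(d, 0.5f);           // one query sub-vector
    std::vector<float> centroids(ksub * d);  // ksub centroids, contiguous
    std::vector<float> table(ksub);          // output: squared L2 distances
    faiss::fvec_L2sqr_ny(table.data(), x.data(), centroids.data(), d, ksub);
}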
+#ifdef USE_AVX
+
+// reads 0 <= d < 8 floats as __m256
+static inline __m256 masked_read_8 (int d, const float *x)
+{
+    assert (0 <= d && d < 8);
+    if (d < 4) {
+        __m256 res = _mm256_setzero_ps ();
+        res = _mm256_insertf128_ps (res, masked_read (d, x), 0);
+        return res;
+    } else {
+        __m256 res = _mm256_setzero_ps ();
+        res = _mm256_insertf128_ps (res, _mm_loadu_ps (x), 0);
+        res = _mm256_insertf128_ps (res, masked_read (d - 4, x + 4), 1);
+        return res;
+    }
+}
+
+float fvec_inner_product (const float * x,
+                          const float * y,
+                          size_t d)
+{
+    __m256 msum1 = _mm256_setzero_ps();
+
+    while (d >= 8) {
+        __m256 mx = _mm256_loadu_ps (x); x += 8;
+        __m256 my = _mm256_loadu_ps (y); y += 8;
+        msum1 = _mm256_add_ps (msum1, _mm256_mul_ps (mx, my));
+        d -= 8;
+    }
+
+    __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
+    msum2 += _mm256_extractf128_ps(msum1, 0);
+
+    if (d >= 4) {
+        __m128 mx = _mm_loadu_ps (x); x += 4;
+        __m128 my = _mm_loadu_ps (y); y += 4;
+        msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my));
+        d -= 4;
+    }
+
+    if (d > 0) {
+        __m128 mx = masked_read (d, x);
+        __m128 my = masked_read (d, y);
+        msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my));
+    }
+
+    msum2 = _mm_hadd_ps (msum2, msum2);
+    msum2 = _mm_hadd_ps (msum2, msum2);
+    return _mm_cvtss_f32 (msum2);
+}
+
+float fvec_L2sqr (const float * x,
+                  const float * y,
+                  size_t d)
+{
+    __m256 msum1 = _mm256_setzero_ps();
+
+    while (d >= 8) {
+        __m256 mx = _mm256_loadu_ps (x); x += 8;
+        __m256 my = _mm256_loadu_ps (y); y += 8;
+        const __m256 a_m_b1 = mx - my;
+        msum1 += a_m_b1 * a_m_b1;
+        d -= 8;
+    }
+
+    __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
+    msum2 += _mm256_extractf128_ps(msum1, 0);
+
+    if (d >= 4) {
+        __m128 mx = _mm_loadu_ps (x); x += 4;
+        __m128 my = _mm_loadu_ps (y); y += 4;
+        const __m128 a_m_b1 = mx - my;
+        msum2 += a_m_b1 * a_m_b1;
+        d -= 4;
+    }
+
+    if (d > 0) {
+        __m128 mx = masked_read (d, x);
+        __m128 my = masked_read (d, y);
+        __m128 a_m_b1 = mx - my;
+        msum2 += a_m_b1 * a_m_b1;
+    }
+
+    msum2 = _mm_hadd_ps (msum2, msum2);
+    msum2 = _mm_hadd_ps (msum2, msum2);
+    return _mm_cvtss_f32 (msum2);
+}
+
+float fvec_L1 (const float * x, const float * y, size_t d)
+{
+    __m256 msum1 = _mm256_setzero_ps();
+    __m256 signmask = __m256(_mm256_set1_epi32 (0x7fffffffUL));
+
+    while (d >= 8) {
+        __m256 mx = _mm256_loadu_ps (x); x += 8;
+        __m256 my = _mm256_loadu_ps (y); y += 8;
+        const __m256 a_m_b = mx - my;
+        msum1 += _mm256_and_ps(signmask, a_m_b);
+        d -= 8;
+    }
+
+    __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
+    msum2 += _mm256_extractf128_ps(msum1, 0);
+    __m128 signmask2 = __m128(_mm_set1_epi32 (0x7fffffffUL));
+
+    if (d >= 4) {
+        __m128 mx = _mm_loadu_ps (x); x += 4;
+        __m128 my = _mm_loadu_ps (y); y += 4;
+        const __m128 a_m_b = mx - my;
+        msum2 += _mm_and_ps(signmask2, a_m_b);
+        d -= 4;
+    }
+
+    if (d > 0) {
+        __m128 mx = masked_read (d, x);
+        __m128 my = masked_read (d, y);
+        __m128 a_m_b = mx - my;
+        msum2 += _mm_and_ps(signmask2, a_m_b);
+    }
+
+    msum2 = _mm_hadd_ps (msum2, msum2);
+    msum2 = _mm_hadd_ps (msum2, msum2);
+    return _mm_cvtss_f32 (msum2);
+}
+
+float fvec_Linf (const float * x, const float * y, size_t d)
+{
+    __m256 msum1 = _mm256_setzero_ps();
+    __m256 signmask = __m256(_mm256_set1_epi32 (0x7fffffffUL));
+
+    while (d >= 8) {
+        __m256 mx = _mm256_loadu_ps (x); x += 8;
+        __m256 my = _mm256_loadu_ps (y); y += 8;
+        const __m256 a_m_b = mx - my;
+        msum1 = _mm256_max_ps(msum1, _mm256_and_ps(signmask, a_m_b));
+        d -= 8;
+    }
+
+    __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
+    msum2 = _mm_max_ps (msum2, _mm256_extractf128_ps(msum1, 0));
+    __m128 signmask2 = __m128(_mm_set1_epi32 (0x7fffffffUL));
+
+    if (d >= 4) {
+        __m128 mx = _mm_loadu_ps (x); x += 4;
+        __m128 my = _mm_loadu_ps (y); y += 4;
+        const __m128 a_m_b = mx - my;
+        msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b));
+        d -= 4;
+    }
+
+    if (d > 0) {
+        __m128 mx = masked_read (d, x);
+        __m128 my = masked_read (d, y);
+        __m128 a_m_b = mx - my;
+        msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b));
+    }
+
+    msum2 = _mm_max_ps(_mm_movehl_ps(msum2, msum2), msum2);
+    msum2 = _mm_max_ps(msum2, _mm_shuffle_ps (msum2, msum2, 1));
+    return _mm_cvtss_f32 (msum2);
+}
+
+#elif defined(__SSE__) // But not AVX
+
+float fvec_L1 (const float * x, const float * y, size_t d)
+{
+    return fvec_L1_ref (x, y, d);
+}
+
+float fvec_Linf (const float * x, const float * y, size_t d)
+{
+    return fvec_Linf_ref (x, y, d);
+}
+
+
+float fvec_L2sqr (const float * x,
+                  const float * y,
+                  size_t d)
+{
+    __m128 msum1 = _mm_setzero_ps();
+
+    while (d >= 4) {
+        __m128 mx = _mm_loadu_ps (x); x += 4;
+        __m128 my = _mm_loadu_ps (y); y += 4;
+        const __m128 a_m_b1 = mx - my;
+        msum1 += a_m_b1 * a_m_b1;
+        d -= 4;
+    }
+
+    if (d > 0) {
+        // add the last 1, 2 or 3 values
+        __m128 mx = masked_read (d, x);
+        __m128 my = masked_read (d, y);
+        __m128 a_m_b1 = mx - my;
+        msum1 += a_m_b1 * a_m_b1;
+    }
+
+    msum1 = _mm_hadd_ps (msum1, msum1);
+    msum1 = _mm_hadd_ps (msum1, msum1);
+    return _mm_cvtss_f32 (msum1);
+}
+
+
+float fvec_inner_product (const float * x,
+                          const float * y,
+                          size_t d)
+{
+    __m128 mx, my;
+    __m128 msum1 = _mm_setzero_ps();
+
+    while (d >= 4) {
+        mx = _mm_loadu_ps (x); x += 4;
+        my = _mm_loadu_ps (y); y += 4;
+        msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, my));
+        d -= 4;
+    }
+
+    // add the last 1, 2, or 3 values
+    mx = masked_read (d, x);
+    my = masked_read (d, y);
+    __m128 prod = _mm_mul_ps (mx, my);
+
+    msum1 = _mm_add_ps (msum1, prod);
+
+    msum1 = _mm_hadd_ps (msum1, msum1);
+    msum1 = _mm_hadd_ps (msum1, msum1);
+    return _mm_cvtss_f32 (msum1);
+}
+
+#elif defined(__aarch64__)
+
+
+float fvec_L2sqr (const float * x,
+                  const float * y,
+                  size_t d)
+{
+    if (d & 3) return fvec_L2sqr_ref (x, y, d);
+    float32x4_t accu = vdupq_n_f32 (0);
+    for (size_t i = 0; i < d; i += 4) {
+        float32x4_t xi = vld1q_f32 (x + i);
+        float32x4_t yi = vld1q_f32 (y + i);
+        float32x4_t sq = vsubq_f32 (xi, yi);
+        accu = vfmaq_f32 (accu, sq, sq);
+    }
+    float32x4_t a2 = vpaddq_f32 (accu, accu);
+    return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
+}
+
+float fvec_inner_product (const float * x,
+                          const float * y,
+                          size_t d)
+{
+    if (d & 3) return fvec_inner_product_ref (x, y, d);
+    float32x4_t accu = vdupq_n_f32 (0);
+    for (size_t i = 0; i < d; i += 4) {
+        float32x4_t xi = vld1q_f32 (x + i);
+        float32x4_t yi = vld1q_f32 (y + i);
+        accu = vfmaq_f32 (accu, xi, yi);
+    }
+    float32x4_t a2 = vpaddq_f32 (accu, accu);
+    return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
+}
+
+float fvec_norm_L2sqr (const float *x, size_t d)
+{
+    if (d & 3) return fvec_norm_L2sqr_ref (x, d);
+    float32x4_t accu = vdupq_n_f32 (0);
+    for (size_t i = 0; i < d; i += 4) {
+        float32x4_t xi = vld1q_f32 (x + i);
+        accu = vfmaq_f32 (accu, xi, xi);
+    }
+    float32x4_t a2 = vpaddq_f32 (accu, accu);
+    return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
+}
+
+// not optimized for ARM
+void fvec_L2sqr_ny (float * dis, const float * x,
+                    const float * y, size_t d, size_t ny) {
+    fvec_L2sqr_ny_ref (dis, x, y, d, ny);
+}
+
+float fvec_L1 (const float * x, const float * y, size_t d)
+{
+    return fvec_L1_ref (x, y, d);
+}
+
+float fvec_Linf (const float * x, const float * y, size_t d)
+{
+    return fvec_Linf_ref (x, y, d);
+}
+
+
+#else
+// scalar implementation
+
+float fvec_L2sqr (const float * x,
+                  const float * y,
+                  size_t d)
+{
+    return fvec_L2sqr_ref (x, y, d);
+}
+
+float fvec_L1 (const float * x, const float * y, size_t d)
+{
+    return fvec_L1_ref (x, y, d);
+}
+
+float fvec_Linf (const float * x, const float * y, size_t d)
+{
+    return fvec_Linf_ref (x, y, d);
+}
+
+float fvec_inner_product (const float * x,
+                          const float * y,
+                          size_t d)
+{
+    return fvec_inner_product_ref (x, y, d);
+}
+
+float fvec_norm_L2sqr (const float *x, size_t d)
+{
+    return fvec_norm_L2sqr_ref (x, d);
+}
+
+void fvec_L2sqr_ny (float * dis, const float * x,
+                    const float * y, size_t d, size_t ny) {
+    fvec_L2sqr_ny_ref (dis, x, y, d, ny);
+}
+
+
+#endif
+
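A hedged sketch (not part of the vendored code) of checking a SIMD path against its reference loop, which is the contract the AVX/SSE/NEON branches above must satisfy. The *_ref functions are defined with external linkage in this file but are not declared in the distances.h hunk above, so the test declares one itself.

#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>
#include <faiss/utils/distances.h>

namespace faiss {
float fvec_L2sqr_ref(const float* x, const float* y, size_t d);  // from this file
}

void check_L2sqr_paths() {
    for (size_t d = 1; d <= 37; d++) {  // odd sizes exercise the masked tail code
        std::vector<float> x(d), y(d);
        for (size_t i = 0; i < d; i++) { x[i] = 0.1f * i; y[i] = 1.0f - 0.05f * i; }
        float a = faiss::fvec_L2sqr(x.data(), y.data(), d);
        float b = faiss::fvec_L2sqr_ref(x.data(), y.data(), d);
        assert(std::fabs(a - b) < 1e-4f * (1 + b));  // FP reassociation tolerance
    }
}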
+/***************************************************************************
+ * heavily optimized table computations
+ ***************************************************************************/
+
+
+static inline void fvec_madd_ref (size_t n, const float *a,
+                                  float bf, const float *b, float *c) {
+    for (size_t i = 0; i < n; i++)
+        c[i] = a[i] + bf * b[i];
+}
+
+#ifdef __SSE__
+
+static inline void fvec_madd_sse (size_t n, const float *a,
+                                  float bf, const float *b, float *c) {
+    n >>= 2;
+    __m128 bf4 = _mm_set_ps1 (bf);
+    __m128 * a4 = (__m128*)a;
+    __m128 * b4 = (__m128*)b;
+    __m128 * c4 = (__m128*)c;
+
+    while (n--) {
+        *c4 = _mm_add_ps (*a4, _mm_mul_ps (bf4, *b4));
+        b4++;
+        a4++;
+        c4++;
+    }
+}
+
+void fvec_madd (size_t n, const float *a,
+                float bf, const float *b, float *c)
+{
+    if ((n & 3) == 0 &&
+        ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
+        fvec_madd_sse (n, a, bf, b, c);
+    else
+        fvec_madd_ref (n, a, bf, b, c);
+}
+
+#else
+
+void fvec_madd (size_t n, const float *a,
+                float bf, const float *b, float *c)
+{
+    fvec_madd_ref (n, a, bf, b, c);
+}
+
+#endif
+
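A hedged usage sketch (not part of the vendored code) for fvec_madd, which computes c[i] = a[i] + bf * b[i]. The SSE fast path above only triggers when n is a multiple of 4 and a, b, c are all 16-byte aligned; anything else falls back to the scalar reference loop. The declaration below mirrors the definition in this file, since the distances.h hunk shown earlier does not declare it.

#include <cstddef>

namespace faiss {
void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c);
}

void madd_demo() {
    // 16-byte aligned buffers and n % 4 == 0, so the SSE path is eligible
    alignas(16) float a[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    alignas(16) float b[8] = {8, 7, 6, 5, 4, 3, 2, 1};
    alignas(16) float c[8];
    faiss::fvec_madd(8, a, 0.5f, b, c);  // c = a + 0.5 * b
}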
+static inline int fvec_madd_and_argmin_ref (size_t n, const float *a,
+                                            float bf, const float *b, float *c) {
+    float vmin = 1e20;
+    int imin = -1;
+
+    for (size_t i = 0; i < n; i++) {
+        c[i] = a[i] + bf * b[i];
+        if (c[i] < vmin) {
+            vmin = c[i];
+            imin = i;
+        }
+    }
+    return imin;
+}
+
+#ifdef __SSE__
+
+static inline int fvec_madd_and_argmin_sse (
+        size_t n, const float *a,
+        float bf, const float *b, float *c) {
+    n >>= 2;
+    __m128 bf4 = _mm_set_ps1 (bf);
+    __m128 vmin4 = _mm_set_ps1 (1e20);
+    __m128i imin4 = _mm_set1_epi32 (-1);
+    __m128i idx4 = _mm_set_epi32 (3, 2, 1, 0);
+    __m128i inc4 = _mm_set1_epi32 (4);
+    __m128 * a4 = (__m128*)a;
+    __m128 * b4 = (__m128*)b;
+    __m128 * c4 = (__m128*)c;
+
+    while (n--) {
+        __m128 vc4 = _mm_add_ps (*a4, _mm_mul_ps (bf4, *b4));
+        *c4 = vc4;
+        __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
+        // imin4 = _mm_blendv_epi8 (imin4, idx4, mask); // slower!
+
+        imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
+                              _mm_andnot_si128 (mask, imin4));
+        vmin4 = _mm_min_ps (vmin4, vc4);
+        b4++;
+        a4++;
+        c4++;
+        idx4 = _mm_add_epi32 (idx4, inc4);
+    }
+
+    // 4 values -> 2
+    {
+        idx4 = _mm_shuffle_epi32 (imin4, 3 << 2 | 2);
+        __m128 vc4 = _mm_shuffle_ps (vmin4, vmin4, 3 << 2 | 2);
+        __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
+        imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
+                              _mm_andnot_si128 (mask, imin4));
+        vmin4 = _mm_min_ps (vmin4, vc4);
+    }
+    // 2 values -> 1
+    {
+        idx4 = _mm_shuffle_epi32 (imin4, 1);
+        __m128 vc4 = _mm_shuffle_ps (vmin4, vmin4, 1);
+        __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
+        imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
+                              _mm_andnot_si128 (mask, imin4));
+        // vmin4 = _mm_min_ps (vmin4, vc4);
+    }
+    return _mm_cvtsi128_si32 (imin4);
+}
+
+
+int fvec_madd_and_argmin (size_t n, const float *a,
+                          float bf, const float *b, float *c)
+{
+    if ((n & 3) == 0 &&
+        ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
+        return fvec_madd_and_argmin_sse (n, a, bf, b, c);
+    else
+        return fvec_madd_and_argmin_ref (n, a, bf, b, c);
+}
+
+#else
+
+int fvec_madd_and_argmin (size_t n, const float *a,
+                          float bf, const float *b, float *c)
+{
+    return fvec_madd_and_argmin_ref (n, a, bf, b, c);
+}
+
+#endif
+
+
+} // namespace faiss
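A hedged sketch (not part of the vendored code) for fvec_madd_and_argmin: the same update as fvec_madd, but it also returns the index of the smallest updated c[i]. The SSE version above tracks the running minimum and its lane index with compare/and/andnot masks, then reduces 4 lanes to 2 to 1. The declaration mirrors the definition above.

#include <cstddef>

namespace faiss {
int fvec_madd_and_argmin(size_t n, const float* a, float bf,
                         const float* b, float* c);
}

int madd_argmin_demo() {
    alignas(16) float a[4] = {4, 3, 2, 1};
    alignas(16) float b[4] = {0, 0, 0, 10};
    alignas(16) float c[4];
    // c becomes {4, 3, 2, 11}; the minimum is c[2] = 2, so this returns 2.
    return faiss::fvec_madd_and_argmin(4, a, 1.0f, b, c);
}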