faiss 0.1.0 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +103 -3
- data/ext/faiss/ext.cpp +99 -32
- data/ext/faiss/extconf.rb +12 -2
- data/lib/faiss/ext.bundle +0 -0
- data/lib/faiss/index.rb +3 -3
- data/lib/faiss/index_binary.rb +3 -3
- data/lib/faiss/kmeans.rb +1 -1
- data/lib/faiss/pca_matrix.rb +2 -2
- data/lib/faiss/product_quantizer.rb +3 -3
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/AutoTune.cpp +719 -0
- data/vendor/faiss/AutoTune.h +212 -0
- data/vendor/faiss/Clustering.cpp +261 -0
- data/vendor/faiss/Clustering.h +101 -0
- data/vendor/faiss/IVFlib.cpp +339 -0
- data/vendor/faiss/IVFlib.h +132 -0
- data/vendor/faiss/Index.cpp +171 -0
- data/vendor/faiss/Index.h +261 -0
- data/vendor/faiss/Index2Layer.cpp +437 -0
- data/vendor/faiss/Index2Layer.h +85 -0
- data/vendor/faiss/IndexBinary.cpp +77 -0
- data/vendor/faiss/IndexBinary.h +163 -0
- data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
- data/vendor/faiss/IndexBinaryFlat.h +54 -0
- data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
- data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
- data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
- data/vendor/faiss/IndexBinaryHNSW.h +56 -0
- data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
- data/vendor/faiss/IndexBinaryIVF.h +211 -0
- data/vendor/faiss/IndexFlat.cpp +508 -0
- data/vendor/faiss/IndexFlat.h +175 -0
- data/vendor/faiss/IndexHNSW.cpp +1090 -0
- data/vendor/faiss/IndexHNSW.h +170 -0
- data/vendor/faiss/IndexIVF.cpp +909 -0
- data/vendor/faiss/IndexIVF.h +353 -0
- data/vendor/faiss/IndexIVFFlat.cpp +502 -0
- data/vendor/faiss/IndexIVFFlat.h +118 -0
- data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
- data/vendor/faiss/IndexIVFPQ.h +161 -0
- data/vendor/faiss/IndexIVFPQR.cpp +219 -0
- data/vendor/faiss/IndexIVFPQR.h +65 -0
- data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
- data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
- data/vendor/faiss/IndexLSH.cpp +225 -0
- data/vendor/faiss/IndexLSH.h +87 -0
- data/vendor/faiss/IndexLattice.cpp +143 -0
- data/vendor/faiss/IndexLattice.h +68 -0
- data/vendor/faiss/IndexPQ.cpp +1188 -0
- data/vendor/faiss/IndexPQ.h +199 -0
- data/vendor/faiss/IndexPreTransform.cpp +288 -0
- data/vendor/faiss/IndexPreTransform.h +91 -0
- data/vendor/faiss/IndexReplicas.cpp +123 -0
- data/vendor/faiss/IndexReplicas.h +76 -0
- data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
- data/vendor/faiss/IndexScalarQuantizer.h +127 -0
- data/vendor/faiss/IndexShards.cpp +317 -0
- data/vendor/faiss/IndexShards.h +100 -0
- data/vendor/faiss/InvertedLists.cpp +623 -0
- data/vendor/faiss/InvertedLists.h +334 -0
- data/vendor/faiss/LICENSE +21 -0
- data/vendor/faiss/MatrixStats.cpp +252 -0
- data/vendor/faiss/MatrixStats.h +62 -0
- data/vendor/faiss/MetaIndexes.cpp +351 -0
- data/vendor/faiss/MetaIndexes.h +126 -0
- data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
- data/vendor/faiss/OnDiskInvertedLists.h +127 -0
- data/vendor/faiss/VectorTransform.cpp +1157 -0
- data/vendor/faiss/VectorTransform.h +322 -0
- data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
- data/vendor/faiss/c_api/AutoTune_c.h +64 -0
- data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
- data/vendor/faiss/c_api/Clustering_c.h +117 -0
- data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
- data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
- data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
- data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
- data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
- data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
- data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
- data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
- data/vendor/faiss/c_api/IndexShards_c.h +42 -0
- data/vendor/faiss/c_api/Index_c.cpp +105 -0
- data/vendor/faiss/c_api/Index_c.h +183 -0
- data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
- data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
- data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
- data/vendor/faiss/c_api/clone_index_c.h +32 -0
- data/vendor/faiss/c_api/error_c.h +42 -0
- data/vendor/faiss/c_api/error_impl.cpp +27 -0
- data/vendor/faiss/c_api/error_impl.h +16 -0
- data/vendor/faiss/c_api/faiss_c.h +58 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
- data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
- data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
- data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
- data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
- data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
- data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
- data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
- data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
- data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
- data/vendor/faiss/c_api/index_factory_c.h +30 -0
- data/vendor/faiss/c_api/index_io_c.cpp +42 -0
- data/vendor/faiss/c_api/index_io_c.h +50 -0
- data/vendor/faiss/c_api/macros_impl.h +110 -0
- data/vendor/faiss/clone_index.cpp +147 -0
- data/vendor/faiss/clone_index.h +38 -0
- data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
- data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
- data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
- data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
- data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
- data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
- data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
- data/vendor/faiss/gpu/GpuCloner.h +82 -0
- data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
- data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
- data/vendor/faiss/gpu/GpuDistance.h +52 -0
- data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
- data/vendor/faiss/gpu/GpuIndex.h +148 -0
- data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
- data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
- data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
- data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
- data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
- data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
- data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
- data/vendor/faiss/gpu/GpuResources.cpp +52 -0
- data/vendor/faiss/gpu/GpuResources.h +73 -0
- data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
- data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
- data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
- data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
- data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
- data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
- data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
- data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
- data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
- data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
- data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
- data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
- data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
- data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
- data/vendor/faiss/gpu/test/TestUtils.h +93 -0
- data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
- data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
- data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
- data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
- data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
- data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
- data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
- data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
- data/vendor/faiss/gpu/utils/Timer.h +52 -0
- data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
- data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
- data/vendor/faiss/impl/FaissAssert.h +95 -0
- data/vendor/faiss/impl/FaissException.cpp +66 -0
- data/vendor/faiss/impl/FaissException.h +71 -0
- data/vendor/faiss/impl/HNSW.cpp +818 -0
- data/vendor/faiss/impl/HNSW.h +275 -0
- data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
- data/vendor/faiss/impl/PolysemousTraining.h +158 -0
- data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
- data/vendor/faiss/impl/ProductQuantizer.h +242 -0
- data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
- data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
- data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
- data/vendor/faiss/impl/ThreadedIndex.h +80 -0
- data/vendor/faiss/impl/index_read.cpp +793 -0
- data/vendor/faiss/impl/index_write.cpp +558 -0
- data/vendor/faiss/impl/io.cpp +142 -0
- data/vendor/faiss/impl/io.h +98 -0
- data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
- data/vendor/faiss/impl/lattice_Zn.h +199 -0
- data/vendor/faiss/index_factory.cpp +392 -0
- data/vendor/faiss/index_factory.h +25 -0
- data/vendor/faiss/index_io.h +75 -0
- data/vendor/faiss/misc/test_blas.cpp +84 -0
- data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
- data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
- data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
- data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
- data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
- data/vendor/faiss/tests/test_merge.cpp +258 -0
- data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
- data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
- data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
- data/vendor/faiss/tests/test_params_override.cpp +231 -0
- data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
- data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
- data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
- data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
- data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
- data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
- data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
- data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
- data/vendor/faiss/utils/Heap.cpp +122 -0
- data/vendor/faiss/utils/Heap.h +495 -0
- data/vendor/faiss/utils/WorkerThread.cpp +126 -0
- data/vendor/faiss/utils/WorkerThread.h +61 -0
- data/vendor/faiss/utils/distances.cpp +765 -0
- data/vendor/faiss/utils/distances.h +243 -0
- data/vendor/faiss/utils/distances_simd.cpp +809 -0
- data/vendor/faiss/utils/extra_distances.cpp +336 -0
- data/vendor/faiss/utils/extra_distances.h +54 -0
- data/vendor/faiss/utils/hamming-inl.h +472 -0
- data/vendor/faiss/utils/hamming.cpp +792 -0
- data/vendor/faiss/utils/hamming.h +220 -0
- data/vendor/faiss/utils/random.cpp +192 -0
- data/vendor/faiss/utils/random.h +60 -0
- data/vendor/faiss/utils/utils.cpp +783 -0
- data/vendor/faiss/utils/utils.h +181 -0
- metadata +216 -2
|
@@ -0,0 +1,243 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// -*- c++ -*-
|
|
9
|
+
|
|
10
|
+
/* All distance functions for L2 and IP distances.
|
|
11
|
+
* The actual functions are implemented in distances.cpp and distances_simd.cpp */
|
|
12
|
+
|
|
13
|
+
#pragma once
|
|
14
|
+
|
|
15
|
+
#include <stdint.h>
|
|
16
|
+
|
|
17
|
+
#include <faiss/utils/Heap.h>
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
namespace faiss {
|
|
21
|
+
|
|
22
|
+
/*********************************************************
|
|
23
|
+
* Optimized distance/norm/inner prod computations
|
|
24
|
+
*********************************************************/
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
/// Squared L2 distance between two vectors
|
|
28
|
+
float fvec_L2sqr (
|
|
29
|
+
const float * x,
|
|
30
|
+
const float * y,
|
|
31
|
+
size_t d);
|
|
32
|
+
|
|
33
|
+
/// inner product
|
|
34
|
+
float fvec_inner_product (
|
|
35
|
+
const float * x,
|
|
36
|
+
const float * y,
|
|
37
|
+
size_t d);
|
|
38
|
+
|
|
39
|
+
/// L1 distance
|
|
40
|
+
float fvec_L1 (
|
|
41
|
+
const float * x,
|
|
42
|
+
const float * y,
|
|
43
|
+
size_t d);
|
|
44
|
+
|
|
45
|
+
float fvec_Linf (
|
|
46
|
+
const float * x,
|
|
47
|
+
const float * y,
|
|
48
|
+
size_t d);
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
/** Compute pairwise distances between sets of vectors
|
|
52
|
+
*
|
|
53
|
+
* @param d dimension of the vectors
|
|
54
|
+
* @param nq nb of query vectors
|
|
55
|
+
* @param nb nb of database vectors
|
|
56
|
+
* @param xq query vectors (size nq * d)
|
|
57
|
+
* @param xb database vectors (size nb * d)
|
|
58
|
+
* @param dis output distances (size nq * nb)
|
|
59
|
+
* @param ldq,ldb, ldd strides for the matrices
|
|
60
|
+
*/
|
|
61
|
+
void pairwise_L2sqr (int64_t d,
|
|
62
|
+
int64_t nq, const float *xq,
|
|
63
|
+
int64_t nb, const float *xb,
|
|
64
|
+
float *dis,
|
|
65
|
+
int64_t ldq = -1, int64_t ldb = -1, int64_t ldd = -1);
|
|
66
|
+
|
|
67
|
+
/* compute the inner product between nx vectors x and one y */
|
|
68
|
+
void fvec_inner_products_ny (
|
|
69
|
+
float * ip, /* output inner product */
|
|
70
|
+
const float * x,
|
|
71
|
+
const float * y,
|
|
72
|
+
size_t d, size_t ny);
|
|
73
|
+
|
|
74
|
+
/* compute ny squared L2 distances between x and a set of contiguous y vectors */
|
|
75
|
+
void fvec_L2sqr_ny (
|
|
76
|
+
float * dis,
|
|
77
|
+
const float * x,
|
|
78
|
+
const float * y,
|
|
79
|
+
size_t d, size_t ny);
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
/** squared norm of a vector */
|
|
83
|
+
float fvec_norm_L2sqr (const float * x,
|
|
84
|
+
size_t d);
|
|
85
|
+
|
|
86
|
+
/** compute the L2 norms for a set of vectors
|
|
87
|
+
*
|
|
88
|
+
* @param ip output norms, size nx
|
|
89
|
+
* @param x set of vectors, size nx * d
|
|
90
|
+
*/
|
|
91
|
+
void fvec_norms_L2 (float * ip, const float * x, size_t d, size_t nx);
|
|
92
|
+
|
|
93
|
+
/// same as fvec_norms_L2, but computes square norms
|
|
94
|
+
void fvec_norms_L2sqr (float * ip, const float * x, size_t d, size_t nx);
|
|
95
|
+
|
|
96
|
+
/* L2-renormalize a set of vectors. Nothing is done if a vector is 0-normed */
|
|
97
|
+
void fvec_renorm_L2 (size_t d, size_t nx, float * x);
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
/* This function exists because the Torch counterpart is extremely slow
|
|
101
|
+
(not multi-threaded + unexpected overhead even in single thread).
|
|
102
|
+
It is here to implement the usual property |x-y|^2=|x|^2+|y|^2-2<x|y> */
|
|
103
|
+
void inner_product_to_L2sqr (float * dis,
|
|
104
|
+
const float * nr1,
|
|
105
|
+
const float * nr2,
|
|
106
|
+
size_t n1, size_t n2);
|
|
107
|
+
|
|
108
|
+
/***************************************************************************
|
|
109
|
+
* Compute a subset of distances
|
|
110
|
+
***************************************************************************/
|
|
111
|
+
|
|
112
|
+
/* compute the inner product between x and a subset y of ny vectors,
|
|
113
|
+
whose indices are given by idy. */
|
|
114
|
+
void fvec_inner_products_by_idx (
|
|
115
|
+
float * ip,
|
|
116
|
+
const float * x,
|
|
117
|
+
const float * y,
|
|
118
|
+
const int64_t *ids,
|
|
119
|
+
size_t d, size_t nx, size_t ny);
|
|
120
|
+
|
|
121
|
+
/* same, but computes squared L2 distances; the subset of y is indexed by ids (ny vectors in total) */
|
|
122
|
+
void fvec_L2sqr_by_idx (
|
|
123
|
+
float * dis,
|
|
124
|
+
const float * x,
|
|
125
|
+
const float * y,
|
|
126
|
+
const int64_t *ids, /* ids of y vecs */
|
|
127
|
+
size_t d, size_t nx, size_t ny);
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
/** compute dis[j] = L2sqr(x[ix[j]], y[iy[j]]) forall j=0..n-1
|
|
131
|
+
*
|
|
132
|
+
* @param x size (max(ix) + 1, d)
|
|
133
|
+
* @param y size (max(iy) + 1, d)
|
|
134
|
+
* @param ix size n
|
|
135
|
+
* @param iy size n
|
|
136
|
+
* @param dis size n
|
|
137
|
+
*/
|
|
138
|
+
void pairwise_indexed_L2sqr (
|
|
139
|
+
size_t d, size_t n,
|
|
140
|
+
const float * x, const int64_t *ix,
|
|
141
|
+
const float * y, const int64_t *iy,
|
|
142
|
+
float *dis);
|
|
143
|
+
|
|
144
|
+
/* same for inner product */
|
|
145
|
+
void pairwise_indexed_inner_product (
|
|
146
|
+
size_t d, size_t n,
|
|
147
|
+
const float * x, const int64_t *ix,
|
|
148
|
+
const float * y, const int64_t *iy,
|
|
149
|
+
float *dis);
|
|
150
|
+
|
|
151
|
+
/***************************************************************************
|
|
152
|
+
* KNN functions
|
|
153
|
+
***************************************************************************/
|
|
154
|
+
|
|
155
|
+
// threshold on nx above which we switch to BLAS to compute distances
|
|
156
|
+
extern int distance_compute_blas_threshold;
|
|
157
|
+
|
|
158
|
+
/** Return the k nearest neighbors of each of the nx vectors x among the ny
|
|
159
|
+
* vectors y, w.r.t. max inner product
|
|
160
|
+
*
|
|
161
|
+
* @param x query vectors, size nx * d
|
|
162
|
+
* @param y database vectors, size ny * d
|
|
163
|
+
* @param res result array, which also provides k. Sorted on output
|
|
164
|
+
*/
|
|
165
|
+
void knn_inner_product (
|
|
166
|
+
const float * x,
|
|
167
|
+
const float * y,
|
|
168
|
+
size_t d, size_t nx, size_t ny,
|
|
169
|
+
float_minheap_array_t * res);
|
|
170
|
+
|
|
171
|
+
/** Same as knn_inner_product, for the L2 distance */
|
|
172
|
+
void knn_L2sqr (
|
|
173
|
+
const float * x,
|
|
174
|
+
const float * y,
|
|
175
|
+
size_t d, size_t nx, size_t ny,
|
|
176
|
+
float_maxheap_array_t * res);
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
/** same as knn_L2sqr, but base_shift[bno] is subtracted from all
|
|
181
|
+
* computed distances.
|
|
182
|
+
*
|
|
183
|
+
* @param base_shift size ny
|
|
184
|
+
*/
|
|
185
|
+
void knn_L2sqr_base_shift (
|
|
186
|
+
const float * x,
|
|
187
|
+
const float * y,
|
|
188
|
+
size_t d, size_t nx, size_t ny,
|
|
189
|
+
float_maxheap_array_t * res,
|
|
190
|
+
const float *base_shift);
|
|
191
|
+
|
|
192
|
+
/* Find the nearest neighbors for nx queries in a set of ny vectors
|
|
193
|
+
* indexed by ids. May be useful for re-ranking a pre-selected vector list
|
|
194
|
+
*/
|
|
195
|
+
void knn_inner_products_by_idx (
|
|
196
|
+
const float * x,
|
|
197
|
+
const float * y,
|
|
198
|
+
const int64_t * ids,
|
|
199
|
+
size_t d, size_t nx, size_t ny,
|
|
200
|
+
float_minheap_array_t * res);
|
|
201
|
+
|
|
202
|
+
void knn_L2sqr_by_idx (const float * x,
|
|
203
|
+
const float * y,
|
|
204
|
+
const int64_t * ids,
|
|
205
|
+
size_t d, size_t nx, size_t ny,
|
|
206
|
+
float_maxheap_array_t * res);
|
|
207
|
+
|
|
208
|
+
/***************************************************************************
|
|
209
|
+
* Range search
|
|
210
|
+
***************************************************************************/
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
/// Forward declaration, see AuxIndexStructures.h
|
|
215
|
+
struct RangeSearchResult;
|
|
216
|
+
|
|
217
|
+
/** Range search: for each of the nx vectors x, collect the vectors among the ny
|
|
218
|
+
* vectors y that fall within the search radius, w.r.t. the squared L2 distance
|
|
219
|
+
*
|
|
220
|
+
* @param x query vectors, size nx * d
|
|
221
|
+
* @param y database vectors, size ny * d
|
|
222
|
+
* @param radius search radius around the x vectors
|
|
223
|
+
* @param result result structure
|
|
224
|
+
*/
|
|
225
|
+
void range_search_L2sqr (
|
|
226
|
+
const float * x,
|
|
227
|
+
const float * y,
|
|
228
|
+
size_t d, size_t nx, size_t ny,
|
|
229
|
+
float radius,
|
|
230
|
+
RangeSearchResult *result);
|
|
231
|
+
|
|
232
|
+
/// same as range_search_L2sqr for the inner product similarity
|
|
233
|
+
void range_search_inner_product (
|
|
234
|
+
const float * x,
|
|
235
|
+
const float * y,
|
|
236
|
+
size_t d, size_t nx, size_t ny,
|
|
237
|
+
float radius,
|
|
238
|
+
RangeSearchResult *result);
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
} // namespace faiss
|
|
@@ -0,0 +1,809 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
// -*- c++ -*-
|
|
9
|
+
|
|
10
|
+
#include <faiss/utils/distances.h>
|
|
11
|
+
|
|
12
|
+
#include <cstdio>
|
|
13
|
+
#include <cassert>
|
|
14
|
+
#include <cstring>
|
|
15
|
+
#include <cmath>
|
|
16
|
+
|
|
17
|
+
#ifdef __SSE__
|
|
18
|
+
#include <immintrin.h>
|
|
19
|
+
#endif
|
|
20
|
+
|
|
21
|
+
#ifdef __aarch64__
|
|
22
|
+
#include <arm_neon.h>
|
|
23
|
+
#endif
|
|
24
|
+
|
|
25
|
+
#include <omp.h>
|
|
26
|
+
|
|
27
|
+
namespace faiss {
|
|
28
|
+
|
|
29
|
+
#ifdef __AVX__
|
|
30
|
+
#define USE_AVX
|
|
31
|
+
#endif
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
/*********************************************************
|
|
35
|
+
* Optimized distance computations
|
|
36
|
+
*********************************************************/
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
/* Functions to compute:
|
|
40
|
+
- L2 distance between 2 vectors
|
|
41
|
+
- inner product between 2 vectors
|
|
42
|
+
- L2 norm of a vector
|
|
43
|
+
|
|
44
|
+
The functions should probably not be invoked when a large number of
|
|
45
|
+
vectors are to be processed in batch (in which case matrix multiplication
|
|
46
|
+
is faster), but may be useful for comparing vectors isolated in
|
|
47
|
+
memory.
|
|
48
|
+
|
|
49
|
+
Works with any vectors of any dimension, even unaligned (in which
|
|
50
|
+
case they are slower).
|
|
51
|
+
|
|
52
|
+
*/
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
/*********************************************************
|
|
56
|
+
* Reference implementations
|
|
57
|
+
*/
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
/// Scalar reference implementation: squared L2 distance between the
/// d-dimensional vectors x and y (sum of squared component differences).
float fvec_L2sqr_ref (const float * x,
                      const float * y,
                      size_t d)
{
    float acc = 0;
    const float * const x_end = x + d;
    while (x != x_end) {
        const float diff = *x++ - *y++;
        acc += diff * diff;
    }
    return acc;
}
|
|
72
|
+
|
|
73
|
+
/// Scalar reference implementation: L1 distance between the d-dimensional
/// vectors x and y (sum of absolute component differences).
float fvec_L1_ref (const float * x,
                   const float * y,
                   size_t d)
{
    float total = 0;
    for (size_t left = d; left > 0; --left, ++x, ++y) {
        const float delta = *x - *y;
        total += fabs(delta);
    }
    return total;
}
|
|
85
|
+
|
|
86
|
+
/// Scalar reference implementation: L-infinity distance between the
/// d-dimensional vectors x and y (largest absolute component difference).
float fvec_Linf_ref (const float * x,
                     const float * y,
                     size_t d)
{
    float largest = 0;
    const float * const x_end = x + d;
    for (; x != x_end; ++x, ++y) {
        largest = fmax(largest, fabs(*x - *y));
    }
    return largest;
}
|
|
97
|
+
|
|
98
|
+
/// Scalar reference implementation: inner product of the d-dimensional
/// vectors x and y.
float fvec_inner_product_ref (const float * x,
                              const float * y,
                              size_t d)
{
    float dot = 0;
    const float * const x_end = x + d;
    while (x != x_end) {
        dot += *x++ * *y++;
    }
    return dot;
}
|
|
108
|
+
|
|
109
|
+
/// Scalar reference implementation: squared L2 norm of the d-dimensional
/// vector x. The sum is accumulated in a double and converted to float on
/// return.
float fvec_norm_L2sqr_ref (const float *x, size_t d)
{
    double sum_sq = 0;
    for (size_t k = 0; k != d; ++k) {
        sum_sq += x[k] * x[k];
    }
    return sum_sq;
}
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
void fvec_L2sqr_ny_ref (float * dis,
|
|
120
|
+
const float * x,
|
|
121
|
+
const float * y,
|
|
122
|
+
size_t d, size_t ny)
|
|
123
|
+
{
|
|
124
|
+
for (size_t i = 0; i < ny; i++) {
|
|
125
|
+
dis[i] = fvec_L2sqr (x, y, d);
|
|
126
|
+
y += d;
|
|
127
|
+
}
|
|
128
|
+
}
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
/*********************************************************
|
|
134
|
+
* SSE and AVX implementations
|
|
135
|
+
*/
|
|
136
|
+
|
|
137
|
+
#ifdef __SSE__
|
|
138
|
+
|
|
139
|
+
// Loads the first d floats of x (0 <= d < 4) into the low lanes of an
// __m128; the remaining lanes are zero.
static inline __m128 masked_read (int d, const float *x)
{
    assert (0 <= d && d < 4);
    // Stage the values through a zeroed, 16-byte aligned scratch buffer so
    // that only x[0..d-1] is read (cannot use AVX2 _mm_mask_set1_epi32).
    __attribute__((__aligned__(16))) float lanes[4] = {0, 0, 0, 0};
    if (d == 3) {
        lanes[2] = x[2];
    }
    if (d == 2 || d == 3) {
        lanes[1] = x[1];
    }
    if (d >= 1 && d <= 3) {
        lanes[0] = x[0];
    }
    return _mm_load_ps (lanes);
}
|
|
155
|
+
|
|
156
|
+
float fvec_norm_L2sqr (const float * x,
|
|
157
|
+
size_t d)
|
|
158
|
+
{
|
|
159
|
+
__m128 mx;
|
|
160
|
+
__m128 msum1 = _mm_setzero_ps();
|
|
161
|
+
|
|
162
|
+
while (d >= 4) {
|
|
163
|
+
mx = _mm_loadu_ps (x); x += 4;
|
|
164
|
+
msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, mx));
|
|
165
|
+
d -= 4;
|
|
166
|
+
}
|
|
167
|
+
|
|
168
|
+
mx = masked_read (d, x);
|
|
169
|
+
msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, mx));
|
|
170
|
+
|
|
171
|
+
msum1 = _mm_hadd_ps (msum1, msum1);
|
|
172
|
+
msum1 = _mm_hadd_ps (msum1, msum1);
|
|
173
|
+
return _mm_cvtss_f32 (msum1);
|
|
174
|
+
}
|
|
175
|
+
|
|
176
|
+
namespace {
|
|
177
|
+
|
|
178
|
+
float sqr (float x) {
|
|
179
|
+
return x * x;
|
|
180
|
+
}
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
void fvec_L2sqr_ny_D1 (float * dis, const float * x,
|
|
184
|
+
const float * y, size_t ny)
|
|
185
|
+
{
|
|
186
|
+
float x0s = x[0];
|
|
187
|
+
__m128 x0 = _mm_set_ps (x0s, x0s, x0s, x0s);
|
|
188
|
+
|
|
189
|
+
size_t i;
|
|
190
|
+
for (i = 0; i + 3 < ny; i += 4) {
|
|
191
|
+
__m128 tmp, accu;
|
|
192
|
+
tmp = x0 - _mm_loadu_ps (y); y += 4;
|
|
193
|
+
accu = tmp * tmp;
|
|
194
|
+
dis[i] = _mm_cvtss_f32 (accu);
|
|
195
|
+
tmp = _mm_shuffle_ps (accu, accu, 1);
|
|
196
|
+
dis[i + 1] = _mm_cvtss_f32 (tmp);
|
|
197
|
+
tmp = _mm_shuffle_ps (accu, accu, 2);
|
|
198
|
+
dis[i + 2] = _mm_cvtss_f32 (tmp);
|
|
199
|
+
tmp = _mm_shuffle_ps (accu, accu, 3);
|
|
200
|
+
dis[i + 3] = _mm_cvtss_f32 (tmp);
|
|
201
|
+
}
|
|
202
|
+
while (i < ny) { // handle non-multiple-of-4 case
|
|
203
|
+
dis[i++] = sqr(x0s - *y++);
|
|
204
|
+
}
|
|
205
|
+
}
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
void fvec_L2sqr_ny_D2 (float * dis, const float * x,
|
|
209
|
+
const float * y, size_t ny)
|
|
210
|
+
{
|
|
211
|
+
__m128 x0 = _mm_set_ps (x[1], x[0], x[1], x[0]);
|
|
212
|
+
|
|
213
|
+
size_t i;
|
|
214
|
+
for (i = 0; i + 1 < ny; i += 2) {
|
|
215
|
+
__m128 tmp, accu;
|
|
216
|
+
tmp = x0 - _mm_loadu_ps (y); y += 4;
|
|
217
|
+
accu = tmp * tmp;
|
|
218
|
+
accu = _mm_hadd_ps (accu, accu);
|
|
219
|
+
dis[i] = _mm_cvtss_f32 (accu);
|
|
220
|
+
accu = _mm_shuffle_ps (accu, accu, 3);
|
|
221
|
+
dis[i + 1] = _mm_cvtss_f32 (accu);
|
|
222
|
+
}
|
|
223
|
+
if (i < ny) { // handle odd case
|
|
224
|
+
dis[i] = sqr(x[0] - y[0]) + sqr(x[1] - y[1]);
|
|
225
|
+
}
|
|
226
|
+
}
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
void fvec_L2sqr_ny_D4 (float * dis, const float * x,
|
|
231
|
+
const float * y, size_t ny)
|
|
232
|
+
{
|
|
233
|
+
__m128 x0 = _mm_loadu_ps(x);
|
|
234
|
+
|
|
235
|
+
for (size_t i = 0; i < ny; i++) {
|
|
236
|
+
__m128 tmp, accu;
|
|
237
|
+
tmp = x0 - _mm_loadu_ps (y); y += 4;
|
|
238
|
+
accu = tmp * tmp;
|
|
239
|
+
accu = _mm_hadd_ps (accu, accu);
|
|
240
|
+
accu = _mm_hadd_ps (accu, accu);
|
|
241
|
+
dis[i] = _mm_cvtss_f32 (accu);
|
|
242
|
+
}
|
|
243
|
+
}
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
void fvec_L2sqr_ny_D8 (float * dis, const float * x,
|
|
247
|
+
const float * y, size_t ny)
|
|
248
|
+
{
|
|
249
|
+
__m128 x0 = _mm_loadu_ps(x);
|
|
250
|
+
__m128 x1 = _mm_loadu_ps(x + 4);
|
|
251
|
+
|
|
252
|
+
for (size_t i = 0; i < ny; i++) {
|
|
253
|
+
__m128 tmp, accu;
|
|
254
|
+
tmp = x0 - _mm_loadu_ps (y); y += 4;
|
|
255
|
+
accu = tmp * tmp;
|
|
256
|
+
tmp = x1 - _mm_loadu_ps (y); y += 4;
|
|
257
|
+
accu += tmp * tmp;
|
|
258
|
+
accu = _mm_hadd_ps (accu, accu);
|
|
259
|
+
accu = _mm_hadd_ps (accu, accu);
|
|
260
|
+
dis[i] = _mm_cvtss_f32 (accu);
|
|
261
|
+
}
|
|
262
|
+
}
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
void fvec_L2sqr_ny_D12 (float * dis, const float * x,
|
|
266
|
+
const float * y, size_t ny)
|
|
267
|
+
{
|
|
268
|
+
__m128 x0 = _mm_loadu_ps(x);
|
|
269
|
+
__m128 x1 = _mm_loadu_ps(x + 4);
|
|
270
|
+
__m128 x2 = _mm_loadu_ps(x + 8);
|
|
271
|
+
|
|
272
|
+
for (size_t i = 0; i < ny; i++) {
|
|
273
|
+
__m128 tmp, accu;
|
|
274
|
+
tmp = x0 - _mm_loadu_ps (y); y += 4;
|
|
275
|
+
accu = tmp * tmp;
|
|
276
|
+
tmp = x1 - _mm_loadu_ps (y); y += 4;
|
|
277
|
+
accu += tmp * tmp;
|
|
278
|
+
tmp = x2 - _mm_loadu_ps (y); y += 4;
|
|
279
|
+
accu += tmp * tmp;
|
|
280
|
+
accu = _mm_hadd_ps (accu, accu);
|
|
281
|
+
accu = _mm_hadd_ps (accu, accu);
|
|
282
|
+
dis[i] = _mm_cvtss_f32 (accu);
|
|
283
|
+
}
|
|
284
|
+
}
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
} // anonymous namespace
|
|
288
|
+
|
|
289
|
+
void fvec_L2sqr_ny (float * dis, const float * x,
|
|
290
|
+
const float * y, size_t d, size_t ny) {
|
|
291
|
+
// optimized for a few special cases
|
|
292
|
+
switch(d) {
|
|
293
|
+
case 1:
|
|
294
|
+
fvec_L2sqr_ny_D1 (dis, x, y, ny);
|
|
295
|
+
return;
|
|
296
|
+
case 2:
|
|
297
|
+
fvec_L2sqr_ny_D2 (dis, x, y, ny);
|
|
298
|
+
return;
|
|
299
|
+
case 4:
|
|
300
|
+
fvec_L2sqr_ny_D4 (dis, x, y, ny);
|
|
301
|
+
return;
|
|
302
|
+
case 8:
|
|
303
|
+
fvec_L2sqr_ny_D8 (dis, x, y, ny);
|
|
304
|
+
return;
|
|
305
|
+
case 12:
|
|
306
|
+
fvec_L2sqr_ny_D12 (dis, x, y, ny);
|
|
307
|
+
return;
|
|
308
|
+
default:
|
|
309
|
+
fvec_L2sqr_ny_ref (dis, x, y, d, ny);
|
|
310
|
+
return;
|
|
311
|
+
}
|
|
312
|
+
}
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
#endif
|
|
317
|
+
|
|
318
|
+
#ifdef USE_AVX
|
|
319
|
+
|
|
320
|
+
// reads 0 <= d < 8 floats as __m256
|
|
321
|
+
static inline __m256 masked_read_8 (int d, const float *x)
|
|
322
|
+
{
|
|
323
|
+
assert (0 <= d && d < 8);
|
|
324
|
+
if (d < 4) {
|
|
325
|
+
__m256 res = _mm256_setzero_ps ();
|
|
326
|
+
res = _mm256_insertf128_ps (res, masked_read (d, x), 0);
|
|
327
|
+
return res;
|
|
328
|
+
} else {
|
|
329
|
+
__m256 res = _mm256_setzero_ps ();
|
|
330
|
+
res = _mm256_insertf128_ps (res, _mm_loadu_ps (x), 0);
|
|
331
|
+
res = _mm256_insertf128_ps (res, masked_read (d - 4, x + 4), 1);
|
|
332
|
+
return res;
|
|
333
|
+
}
|
|
334
|
+
}
|
|
335
|
+
|
|
336
|
+
float fvec_inner_product (const float * x,
|
|
337
|
+
const float * y,
|
|
338
|
+
size_t d)
|
|
339
|
+
{
|
|
340
|
+
__m256 msum1 = _mm256_setzero_ps();
|
|
341
|
+
|
|
342
|
+
while (d >= 8) {
|
|
343
|
+
__m256 mx = _mm256_loadu_ps (x); x += 8;
|
|
344
|
+
__m256 my = _mm256_loadu_ps (y); y += 8;
|
|
345
|
+
msum1 = _mm256_add_ps (msum1, _mm256_mul_ps (mx, my));
|
|
346
|
+
d -= 8;
|
|
347
|
+
}
|
|
348
|
+
|
|
349
|
+
__m128 msum2 = _mm256_extractf128_ps(msum1, 1);
|
|
350
|
+
msum2 += _mm256_extractf128_ps(msum1, 0);
|
|
351
|
+
|
|
352
|
+
if (d >= 4) {
|
|
353
|
+
__m128 mx = _mm_loadu_ps (x); x += 4;
|
|
354
|
+
__m128 my = _mm_loadu_ps (y); y += 4;
|
|
355
|
+
msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my));
|
|
356
|
+
d -= 4;
|
|
357
|
+
}
|
|
358
|
+
|
|
359
|
+
if (d > 0) {
|
|
360
|
+
__m128 mx = masked_read (d, x);
|
|
361
|
+
__m128 my = masked_read (d, y);
|
|
362
|
+
msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my));
|
|
363
|
+
}
|
|
364
|
+
|
|
365
|
+
msum2 = _mm_hadd_ps (msum2, msum2);
|
|
366
|
+
msum2 = _mm_hadd_ps (msum2, msum2);
|
|
367
|
+
return _mm_cvtss_f32 (msum2);
|
|
368
|
+
}
|
|
369
|
+
|
|
370
|
+
float fvec_L2sqr (const float * x,
|
|
371
|
+
const float * y,
|
|
372
|
+
size_t d)
|
|
373
|
+
{
|
|
374
|
+
__m256 msum1 = _mm256_setzero_ps();
|
|
375
|
+
|
|
376
|
+
while (d >= 8) {
|
|
377
|
+
__m256 mx = _mm256_loadu_ps (x); x += 8;
|
|
378
|
+
__m256 my = _mm256_loadu_ps (y); y += 8;
|
|
379
|
+
const __m256 a_m_b1 = mx - my;
|
|
380
|
+
msum1 += a_m_b1 * a_m_b1;
|
|
381
|
+
d -= 8;
|
|
382
|
+
}
|
|
383
|
+
|
|
384
|
+
__m128 msum2 = _mm256_extractf128_ps(msum1, 1);
|
|
385
|
+
msum2 += _mm256_extractf128_ps(msum1, 0);
|
|
386
|
+
|
|
387
|
+
if (d >= 4) {
|
|
388
|
+
__m128 mx = _mm_loadu_ps (x); x += 4;
|
|
389
|
+
__m128 my = _mm_loadu_ps (y); y += 4;
|
|
390
|
+
const __m128 a_m_b1 = mx - my;
|
|
391
|
+
msum2 += a_m_b1 * a_m_b1;
|
|
392
|
+
d -= 4;
|
|
393
|
+
}
|
|
394
|
+
|
|
395
|
+
if (d > 0) {
|
|
396
|
+
__m128 mx = masked_read (d, x);
|
|
397
|
+
__m128 my = masked_read (d, y);
|
|
398
|
+
__m128 a_m_b1 = mx - my;
|
|
399
|
+
msum2 += a_m_b1 * a_m_b1;
|
|
400
|
+
}
|
|
401
|
+
|
|
402
|
+
msum2 = _mm_hadd_ps (msum2, msum2);
|
|
403
|
+
msum2 = _mm_hadd_ps (msum2, msum2);
|
|
404
|
+
return _mm_cvtss_f32 (msum2);
|
|
405
|
+
}
|
|
406
|
+
|
|
407
|
+
float fvec_L1 (const float * x, const float * y, size_t d)
|
|
408
|
+
{
|
|
409
|
+
__m256 msum1 = _mm256_setzero_ps();
|
|
410
|
+
__m256 signmask = __m256(_mm256_set1_epi32 (0x7fffffffUL));
|
|
411
|
+
|
|
412
|
+
while (d >= 8) {
|
|
413
|
+
__m256 mx = _mm256_loadu_ps (x); x += 8;
|
|
414
|
+
__m256 my = _mm256_loadu_ps (y); y += 8;
|
|
415
|
+
const __m256 a_m_b = mx - my;
|
|
416
|
+
msum1 += _mm256_and_ps(signmask, a_m_b);
|
|
417
|
+
d -= 8;
|
|
418
|
+
}
|
|
419
|
+
|
|
420
|
+
__m128 msum2 = _mm256_extractf128_ps(msum1, 1);
|
|
421
|
+
msum2 += _mm256_extractf128_ps(msum1, 0);
|
|
422
|
+
__m128 signmask2 = __m128(_mm_set1_epi32 (0x7fffffffUL));
|
|
423
|
+
|
|
424
|
+
if (d >= 4) {
|
|
425
|
+
__m128 mx = _mm_loadu_ps (x); x += 4;
|
|
426
|
+
__m128 my = _mm_loadu_ps (y); y += 4;
|
|
427
|
+
const __m128 a_m_b = mx - my;
|
|
428
|
+
msum2 += _mm_and_ps(signmask2, a_m_b);
|
|
429
|
+
d -= 4;
|
|
430
|
+
}
|
|
431
|
+
|
|
432
|
+
if (d > 0) {
|
|
433
|
+
__m128 mx = masked_read (d, x);
|
|
434
|
+
__m128 my = masked_read (d, y);
|
|
435
|
+
__m128 a_m_b = mx - my;
|
|
436
|
+
msum2 += _mm_and_ps(signmask2, a_m_b);
|
|
437
|
+
}
|
|
438
|
+
|
|
439
|
+
msum2 = _mm_hadd_ps (msum2, msum2);
|
|
440
|
+
msum2 = _mm_hadd_ps (msum2, msum2);
|
|
441
|
+
return _mm_cvtss_f32 (msum2);
|
|
442
|
+
}
|
|
443
|
+
|
|
444
|
+
float fvec_Linf (const float * x, const float * y, size_t d)
|
|
445
|
+
{
|
|
446
|
+
__m256 msum1 = _mm256_setzero_ps();
|
|
447
|
+
__m256 signmask = __m256(_mm256_set1_epi32 (0x7fffffffUL));
|
|
448
|
+
|
|
449
|
+
while (d >= 8) {
|
|
450
|
+
__m256 mx = _mm256_loadu_ps (x); x += 8;
|
|
451
|
+
__m256 my = _mm256_loadu_ps (y); y += 8;
|
|
452
|
+
const __m256 a_m_b = mx - my;
|
|
453
|
+
msum1 = _mm256_max_ps(msum1, _mm256_and_ps(signmask, a_m_b));
|
|
454
|
+
d -= 8;
|
|
455
|
+
}
|
|
456
|
+
|
|
457
|
+
__m128 msum2 = _mm256_extractf128_ps(msum1, 1);
|
|
458
|
+
msum2 = _mm_max_ps (msum2, _mm256_extractf128_ps(msum1, 0));
|
|
459
|
+
__m128 signmask2 = __m128(_mm_set1_epi32 (0x7fffffffUL));
|
|
460
|
+
|
|
461
|
+
if (d >= 4) {
|
|
462
|
+
__m128 mx = _mm_loadu_ps (x); x += 4;
|
|
463
|
+
__m128 my = _mm_loadu_ps (y); y += 4;
|
|
464
|
+
const __m128 a_m_b = mx - my;
|
|
465
|
+
msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b));
|
|
466
|
+
d -= 4;
|
|
467
|
+
}
|
|
468
|
+
|
|
469
|
+
if (d > 0) {
|
|
470
|
+
__m128 mx = masked_read (d, x);
|
|
471
|
+
__m128 my = masked_read (d, y);
|
|
472
|
+
__m128 a_m_b = mx - my;
|
|
473
|
+
msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b));
|
|
474
|
+
}
|
|
475
|
+
|
|
476
|
+
msum2 = _mm_max_ps(_mm_movehl_ps(msum2, msum2), msum2);
|
|
477
|
+
msum2 = _mm_max_ps(msum2, _mm_shuffle_ps (msum2, msum2, 1));
|
|
478
|
+
return _mm_cvtss_f32 (msum2);
|
|
479
|
+
}
|
|
480
|
+
|
|
481
|
+
#elif defined(__SSE__) // But not AVX
|
|
482
|
+
|
|
483
|
+
float fvec_L1 (const float * x, const float * y, size_t d)
|
|
484
|
+
{
|
|
485
|
+
return fvec_L1_ref (x, y, d);
|
|
486
|
+
}
|
|
487
|
+
|
|
488
|
+
float fvec_Linf (const float * x, const float * y, size_t d)
|
|
489
|
+
{
|
|
490
|
+
return fvec_Linf_ref (x, y, d);
|
|
491
|
+
}
|
|
492
|
+
|
|
493
|
+
|
|
494
|
+
float fvec_L2sqr (const float * x,
|
|
495
|
+
const float * y,
|
|
496
|
+
size_t d)
|
|
497
|
+
{
|
|
498
|
+
__m128 msum1 = _mm_setzero_ps();
|
|
499
|
+
|
|
500
|
+
while (d >= 4) {
|
|
501
|
+
__m128 mx = _mm_loadu_ps (x); x += 4;
|
|
502
|
+
__m128 my = _mm_loadu_ps (y); y += 4;
|
|
503
|
+
const __m128 a_m_b1 = mx - my;
|
|
504
|
+
msum1 += a_m_b1 * a_m_b1;
|
|
505
|
+
d -= 4;
|
|
506
|
+
}
|
|
507
|
+
|
|
508
|
+
if (d > 0) {
|
|
509
|
+
// add the last 1, 2 or 3 values
|
|
510
|
+
__m128 mx = masked_read (d, x);
|
|
511
|
+
__m128 my = masked_read (d, y);
|
|
512
|
+
__m128 a_m_b1 = mx - my;
|
|
513
|
+
msum1 += a_m_b1 * a_m_b1;
|
|
514
|
+
}
|
|
515
|
+
|
|
516
|
+
msum1 = _mm_hadd_ps (msum1, msum1);
|
|
517
|
+
msum1 = _mm_hadd_ps (msum1, msum1);
|
|
518
|
+
return _mm_cvtss_f32 (msum1);
|
|
519
|
+
}
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
float fvec_inner_product (const float * x,
|
|
523
|
+
const float * y,
|
|
524
|
+
size_t d)
|
|
525
|
+
{
|
|
526
|
+
__m128 mx, my;
|
|
527
|
+
__m128 msum1 = _mm_setzero_ps();
|
|
528
|
+
|
|
529
|
+
while (d >= 4) {
|
|
530
|
+
mx = _mm_loadu_ps (x); x += 4;
|
|
531
|
+
my = _mm_loadu_ps (y); y += 4;
|
|
532
|
+
msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, my));
|
|
533
|
+
d -= 4;
|
|
534
|
+
}
|
|
535
|
+
|
|
536
|
+
// add the last 1, 2, or 3 values
|
|
537
|
+
mx = masked_read (d, x);
|
|
538
|
+
my = masked_read (d, y);
|
|
539
|
+
__m128 prod = _mm_mul_ps (mx, my);
|
|
540
|
+
|
|
541
|
+
msum1 = _mm_add_ps (msum1, prod);
|
|
542
|
+
|
|
543
|
+
msum1 = _mm_hadd_ps (msum1, msum1);
|
|
544
|
+
msum1 = _mm_hadd_ps (msum1, msum1);
|
|
545
|
+
return _mm_cvtss_f32 (msum1);
|
|
546
|
+
}
|
|
547
|
+
|
|
548
|
+
#elif defined(__aarch64__)
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
/**
 * Squared Euclidean distance between x and y (aarch64 NEON code path).
 *
 * The NEON loop only handles dimensions that are a multiple of 4; any
 * other d is forwarded to fvec_L2sqr_ref (defined elsewhere in this file).
 *
 * @param x  first vector, d floats
 * @param y  second vector, d floats
 * @param d  number of components
 * @return   sum over i of (x[i] - y[i])^2
 */
float fvec_L2sqr (const float * x,
                  const float * y,
                  size_t d)
{
    if ((d & 3) != 0) {
        return fvec_L2sqr_ref (x, y, d);
    }

    float32x4_t sum4 = vdupq_n_f32 (0);
    size_t ofs = 0;
    while (ofs < d) {
        const float32x4_t diff = vsubq_f32 (vld1q_f32 (x + ofs),
                                            vld1q_f32 (y + ofs));
        // fused multiply-add: sum4 += diff * diff
        sum4 = vfmaq_f32 (sum4, diff, diff);
        ofs += 4;
    }

    // pairwise add leaves (s0+s1) in lane 0 and (s2+s3) in lane 1
    const float32x4_t pairs = vpaddq_f32 (sum4, sum4);
    const float lo = vdups_laneq_f32 (pairs, 0);
    const float hi = vdups_laneq_f32 (pairs, 1);
    return lo + hi;
}
|
|
566
|
+
|
|
567
|
+
/**
 * Inner product of x and y (aarch64 NEON code path).
 *
 * Dimensions that are not a multiple of 4 are forwarded to
 * fvec_inner_product_ref (defined elsewhere in this file).
 *
 * @param x  first vector, d floats
 * @param y  second vector, d floats
 * @param d  number of components
 * @return   sum over i of x[i] * y[i]
 */
float fvec_inner_product (const float * x,
                          const float * y,
                          size_t d)
{
    if ((d & 3) != 0) {
        return fvec_inner_product_ref (x, y, d);
    }

    float32x4_t sum4 = vdupq_n_f32 (0);
    size_t ofs = 0;
    while (ofs < d) {
        const float32x4_t vx = vld1q_f32 (x + ofs);
        const float32x4_t vy = vld1q_f32 (y + ofs);
        // fused multiply-add: sum4 += vx * vy
        sum4 = vfmaq_f32 (sum4, vx, vy);
        ofs += 4;
    }

    // pairwise add leaves (s0+s1) in lane 0 and (s2+s3) in lane 1
    const float32x4_t pairs = vpaddq_f32 (sum4, sum4);
    const float lo = vdups_laneq_f32 (pairs, 0);
    const float hi = vdups_laneq_f32 (pairs, 1);
    return lo + hi;
}
|
|
581
|
+
|
|
582
|
+
/**
 * Squared L2 norm of x (aarch64 NEON code path).
 *
 * Dimensions that are not a multiple of 4 are forwarded to
 * fvec_norm_L2sqr_ref (defined elsewhere in this file).
 *
 * @param x  input vector, d floats
 * @param d  number of components
 * @return   sum over i of x[i]^2
 */
float fvec_norm_L2sqr (const float *x, size_t d)
{
    if ((d & 3) != 0) {
        return fvec_norm_L2sqr_ref (x, d);
    }

    float32x4_t sum4 = vdupq_n_f32 (0);
    size_t ofs = 0;
    while (ofs < d) {
        const float32x4_t vx = vld1q_f32 (x + ofs);
        // fused multiply-add: sum4 += vx * vx
        sum4 = vfmaq_f32 (sum4, vx, vx);
        ofs += 4;
    }

    // pairwise add leaves (s0+s1) in lane 0 and (s2+s3) in lane 1
    const float32x4_t pairs = vpaddq_f32 (sum4, sum4);
    const float lo = vdups_laneq_f32 (pairs, 0);
    const float hi = vdups_laneq_f32 (pairs, 1);
    return lo + hi;
}
|
|
593
|
+
|
|
594
|
+
// not optimized for ARM
|
|
595
|
+
void fvec_L2sqr_ny (float * dis, const float * x,
|
|
596
|
+
const float * y, size_t d, size_t ny) {
|
|
597
|
+
fvec_L2sqr_ny_ref (dis, x, y, d, ny);
|
|
598
|
+
}
|
|
599
|
+
|
|
600
|
+
float fvec_L1 (const float * x, const float * y, size_t d)
|
|
601
|
+
{
|
|
602
|
+
return fvec_L1_ref (x, y, d);
|
|
603
|
+
}
|
|
604
|
+
|
|
605
|
+
float fvec_Linf (const float * x, const float * y, size_t d)
|
|
606
|
+
{
|
|
607
|
+
return fvec_Linf_ref (x, y, d);
|
|
608
|
+
}
|
|
609
|
+
|
|
610
|
+
|
|
611
|
+
#else
|
|
612
|
+
// scalar implementation
|
|
613
|
+
|
|
614
|
+
float fvec_L2sqr (const float * x,
|
|
615
|
+
const float * y,
|
|
616
|
+
size_t d)
|
|
617
|
+
{
|
|
618
|
+
return fvec_L2sqr_ref (x, y, d);
|
|
619
|
+
}
|
|
620
|
+
|
|
621
|
+
float fvec_L1 (const float * x, const float * y, size_t d)
|
|
622
|
+
{
|
|
623
|
+
return fvec_L1_ref (x, y, d);
|
|
624
|
+
}
|
|
625
|
+
|
|
626
|
+
float fvec_Linf (const float * x, const float * y, size_t d)
|
|
627
|
+
{
|
|
628
|
+
return fvec_Linf_ref (x, y, d);
|
|
629
|
+
}
|
|
630
|
+
|
|
631
|
+
float fvec_inner_product (const float * x,
|
|
632
|
+
const float * y,
|
|
633
|
+
size_t d)
|
|
634
|
+
{
|
|
635
|
+
return fvec_inner_product_ref (x, y, d);
|
|
636
|
+
}
|
|
637
|
+
|
|
638
|
+
float fvec_norm_L2sqr (const float *x, size_t d)
|
|
639
|
+
{
|
|
640
|
+
return fvec_norm_L2sqr_ref (x, d);
|
|
641
|
+
}
|
|
642
|
+
|
|
643
|
+
/**
 * Scalar build of fvec_L2sqr_ny: forwards all arguments unchanged to
 * fvec_L2sqr_ny_ref (defined elsewhere in this file).
 */
void fvec_L2sqr_ny (float * dis,
                    const float * x,
                    const float * y,
                    size_t d,
                    size_t ny)
{
    fvec_L2sqr_ny_ref (dis, x, y, d, ny);
}
|
|
647
|
+
|
|
648
|
+
|
|
649
|
+
#endif
|
|
650
|
+
|
|
651
|
+
|
|
652
|
+
|
|
653
|
+
|
|
654
|
+
|
|
655
|
+
|
|
656
|
+
|
|
657
|
+
|
|
658
|
+
|
|
659
|
+
|
|
660
|
+
|
|
661
|
+
|
|
662
|
+
|
|
663
|
+
|
|
664
|
+
|
|
665
|
+
|
|
666
|
+
|
|
667
|
+
|
|
668
|
+
|
|
669
|
+
|
|
670
|
+
/***************************************************************************
|
|
671
|
+
* heavily optimized table computations
|
|
672
|
+
***************************************************************************/
|
|
673
|
+
|
|
674
|
+
|
|
675
|
+
/**
 * Reference (scalar) multiply-add: c[i] = a[i] + bf * b[i] for i in [0, n).
 *
 * @param n   number of elements in a, b and c
 * @param a   addend array
 * @param bf  scalar factor applied to b
 * @param b   array scaled by bf
 * @param c   output array, written element by element in increasing order
 */
static inline void fvec_madd_ref (size_t n, const float *a,
                                  float bf, const float *b, float *c) {
    size_t pos = 0;
    while (pos < n) {
        c[pos] = a[pos] + bf * b[pos];
        ++pos;
    }
}
|
|
680
|
+
|
|
681
|
+
#ifdef __SSE__
|
|
682
|
+
|
|
683
|
+
static inline void fvec_madd_sse (size_t n, const float *a,
|
|
684
|
+
float bf, const float *b, float *c) {
|
|
685
|
+
n >>= 2;
|
|
686
|
+
__m128 bf4 = _mm_set_ps1 (bf);
|
|
687
|
+
__m128 * a4 = (__m128*)a;
|
|
688
|
+
__m128 * b4 = (__m128*)b;
|
|
689
|
+
__m128 * c4 = (__m128*)c;
|
|
690
|
+
|
|
691
|
+
while (n--) {
|
|
692
|
+
*c4 = _mm_add_ps (*a4, _mm_mul_ps (bf4, *b4));
|
|
693
|
+
b4++;
|
|
694
|
+
a4++;
|
|
695
|
+
c4++;
|
|
696
|
+
}
|
|
697
|
+
}
|
|
698
|
+
|
|
699
|
+
void fvec_madd (size_t n, const float *a,
|
|
700
|
+
float bf, const float *b, float *c)
|
|
701
|
+
{
|
|
702
|
+
if ((n & 3) == 0 &&
|
|
703
|
+
((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
|
|
704
|
+
fvec_madd_sse (n, a, bf, b, c);
|
|
705
|
+
else
|
|
706
|
+
fvec_madd_ref (n, a, bf, b, c);
|
|
707
|
+
}
|
|
708
|
+
|
|
709
|
+
#else
|
|
710
|
+
|
|
711
|
+
void fvec_madd (size_t n, const float *a,
|
|
712
|
+
float bf, const float *b, float *c)
|
|
713
|
+
{
|
|
714
|
+
fvec_madd_ref (n, a, bf, b, c);
|
|
715
|
+
}
|
|
716
|
+
|
|
717
|
+
#endif
|
|
718
|
+
|
|
719
|
+
/**
 * Reference (scalar) multiply-add with argmin.
 *
 * Writes c[i] = a[i] + bf * b[i] for i in [0, n) and returns the index of
 * the smallest value written. The running minimum starts at 1e20, so the
 * result is -1 when n == 0 or when no c[i] is below 1e20. On ties the
 * first index wins because the comparison is strict.
 *
 * @param n   number of elements
 * @param a   addend array
 * @param bf  scalar factor applied to b
 * @param b   array scaled by bf
 * @param c   output array
 * @return    index of the minimum of c, or -1 as described above
 */
static inline int fvec_madd_and_argmin_ref (size_t n, const float *a,
                                            float bf, const float *b, float *c) {
    float best_val = 1e20;
    int best_idx = -1;

    for (size_t pos = 0; pos < n; ++pos) {
        c[pos] = a[pos] + bf * b[pos];
        if (!(c[pos] < best_val)) {
            continue;   // not a new minimum
        }
        best_val = c[pos];
        best_idx = pos;
    }
    return best_idx;
}
|
|
733
|
+
|
|
734
|
+
#ifdef __SSE__
|
|
735
|
+
|
|
736
|
+
static inline int fvec_madd_and_argmin_sse (
|
|
737
|
+
size_t n, const float *a,
|
|
738
|
+
float bf, const float *b, float *c) {
|
|
739
|
+
n >>= 2;
|
|
740
|
+
__m128 bf4 = _mm_set_ps1 (bf);
|
|
741
|
+
__m128 vmin4 = _mm_set_ps1 (1e20);
|
|
742
|
+
__m128i imin4 = _mm_set1_epi32 (-1);
|
|
743
|
+
__m128i idx4 = _mm_set_epi32 (3, 2, 1, 0);
|
|
744
|
+
__m128i inc4 = _mm_set1_epi32 (4);
|
|
745
|
+
__m128 * a4 = (__m128*)a;
|
|
746
|
+
__m128 * b4 = (__m128*)b;
|
|
747
|
+
__m128 * c4 = (__m128*)c;
|
|
748
|
+
|
|
749
|
+
while (n--) {
|
|
750
|
+
__m128 vc4 = _mm_add_ps (*a4, _mm_mul_ps (bf4, *b4));
|
|
751
|
+
*c4 = vc4;
|
|
752
|
+
__m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
|
|
753
|
+
// imin4 = _mm_blendv_epi8 (imin4, idx4, mask); // slower!
|
|
754
|
+
|
|
755
|
+
imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
|
|
756
|
+
_mm_andnot_si128 (mask, imin4));
|
|
757
|
+
vmin4 = _mm_min_ps (vmin4, vc4);
|
|
758
|
+
b4++;
|
|
759
|
+
a4++;
|
|
760
|
+
c4++;
|
|
761
|
+
idx4 = _mm_add_epi32 (idx4, inc4);
|
|
762
|
+
}
|
|
763
|
+
|
|
764
|
+
// 4 values -> 2
|
|
765
|
+
{
|
|
766
|
+
idx4 = _mm_shuffle_epi32 (imin4, 3 << 2 | 2);
|
|
767
|
+
__m128 vc4 = _mm_shuffle_ps (vmin4, vmin4, 3 << 2 | 2);
|
|
768
|
+
__m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
|
|
769
|
+
imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
|
|
770
|
+
_mm_andnot_si128 (mask, imin4));
|
|
771
|
+
vmin4 = _mm_min_ps (vmin4, vc4);
|
|
772
|
+
}
|
|
773
|
+
// 2 values -> 1
|
|
774
|
+
{
|
|
775
|
+
idx4 = _mm_shuffle_epi32 (imin4, 1);
|
|
776
|
+
__m128 vc4 = _mm_shuffle_ps (vmin4, vmin4, 1);
|
|
777
|
+
__m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
|
|
778
|
+
imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
|
|
779
|
+
_mm_andnot_si128 (mask, imin4));
|
|
780
|
+
// vmin4 = _mm_min_ps (vmin4, vc4);
|
|
781
|
+
}
|
|
782
|
+
return _mm_cvtsi128_si32 (imin4);
|
|
783
|
+
}
|
|
784
|
+
|
|
785
|
+
|
|
786
|
+
int fvec_madd_and_argmin (size_t n, const float *a,
|
|
787
|
+
float bf, const float *b, float *c)
|
|
788
|
+
{
|
|
789
|
+
if ((n & 3) == 0 &&
|
|
790
|
+
((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
|
|
791
|
+
return fvec_madd_and_argmin_sse (n, a, bf, b, c);
|
|
792
|
+
else
|
|
793
|
+
return fvec_madd_and_argmin_ref (n, a, bf, b, c);
|
|
794
|
+
}
|
|
795
|
+
|
|
796
|
+
#else
|
|
797
|
+
|
|
798
|
+
int fvec_madd_and_argmin (size_t n, const float *a,
|
|
799
|
+
float bf, const float *b, float *c)
|
|
800
|
+
{
|
|
801
|
+
return fvec_madd_and_argmin_ref (n, a, bf, b, c);
|
|
802
|
+
}
|
|
803
|
+
|
|
804
|
+
#endif
|
|
805
|
+
|
|
806
|
+
|
|
807
|
+
|
|
808
|
+
|
|
809
|
+
} // namespace faiss
|