faiss 0.1.0 → 0.1.1

Files changed (226)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +103 -3
  4. data/ext/faiss/ext.cpp +99 -32
  5. data/ext/faiss/extconf.rb +12 -2
  6. data/lib/faiss/ext.bundle +0 -0
  7. data/lib/faiss/index.rb +3 -3
  8. data/lib/faiss/index_binary.rb +3 -3
  9. data/lib/faiss/kmeans.rb +1 -1
  10. data/lib/faiss/pca_matrix.rb +2 -2
  11. data/lib/faiss/product_quantizer.rb +3 -3
  12. data/lib/faiss/version.rb +1 -1
  13. data/vendor/faiss/AutoTune.cpp +719 -0
  14. data/vendor/faiss/AutoTune.h +212 -0
  15. data/vendor/faiss/Clustering.cpp +261 -0
  16. data/vendor/faiss/Clustering.h +101 -0
  17. data/vendor/faiss/IVFlib.cpp +339 -0
  18. data/vendor/faiss/IVFlib.h +132 -0
  19. data/vendor/faiss/Index.cpp +171 -0
  20. data/vendor/faiss/Index.h +261 -0
  21. data/vendor/faiss/Index2Layer.cpp +437 -0
  22. data/vendor/faiss/Index2Layer.h +85 -0
  23. data/vendor/faiss/IndexBinary.cpp +77 -0
  24. data/vendor/faiss/IndexBinary.h +163 -0
  25. data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
  26. data/vendor/faiss/IndexBinaryFlat.h +54 -0
  27. data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
  28. data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
  29. data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
  30. data/vendor/faiss/IndexBinaryHNSW.h +56 -0
  31. data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
  32. data/vendor/faiss/IndexBinaryIVF.h +211 -0
  33. data/vendor/faiss/IndexFlat.cpp +508 -0
  34. data/vendor/faiss/IndexFlat.h +175 -0
  35. data/vendor/faiss/IndexHNSW.cpp +1090 -0
  36. data/vendor/faiss/IndexHNSW.h +170 -0
  37. data/vendor/faiss/IndexIVF.cpp +909 -0
  38. data/vendor/faiss/IndexIVF.h +353 -0
  39. data/vendor/faiss/IndexIVFFlat.cpp +502 -0
  40. data/vendor/faiss/IndexIVFFlat.h +118 -0
  41. data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
  42. data/vendor/faiss/IndexIVFPQ.h +161 -0
  43. data/vendor/faiss/IndexIVFPQR.cpp +219 -0
  44. data/vendor/faiss/IndexIVFPQR.h +65 -0
  45. data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
  46. data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
  47. data/vendor/faiss/IndexLSH.cpp +225 -0
  48. data/vendor/faiss/IndexLSH.h +87 -0
  49. data/vendor/faiss/IndexLattice.cpp +143 -0
  50. data/vendor/faiss/IndexLattice.h +68 -0
  51. data/vendor/faiss/IndexPQ.cpp +1188 -0
  52. data/vendor/faiss/IndexPQ.h +199 -0
  53. data/vendor/faiss/IndexPreTransform.cpp +288 -0
  54. data/vendor/faiss/IndexPreTransform.h +91 -0
  55. data/vendor/faiss/IndexReplicas.cpp +123 -0
  56. data/vendor/faiss/IndexReplicas.h +76 -0
  57. data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
  58. data/vendor/faiss/IndexScalarQuantizer.h +127 -0
  59. data/vendor/faiss/IndexShards.cpp +317 -0
  60. data/vendor/faiss/IndexShards.h +100 -0
  61. data/vendor/faiss/InvertedLists.cpp +623 -0
  62. data/vendor/faiss/InvertedLists.h +334 -0
  63. data/vendor/faiss/LICENSE +21 -0
  64. data/vendor/faiss/MatrixStats.cpp +252 -0
  65. data/vendor/faiss/MatrixStats.h +62 -0
  66. data/vendor/faiss/MetaIndexes.cpp +351 -0
  67. data/vendor/faiss/MetaIndexes.h +126 -0
  68. data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
  69. data/vendor/faiss/OnDiskInvertedLists.h +127 -0
  70. data/vendor/faiss/VectorTransform.cpp +1157 -0
  71. data/vendor/faiss/VectorTransform.h +322 -0
  72. data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
  73. data/vendor/faiss/c_api/AutoTune_c.h +64 -0
  74. data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
  75. data/vendor/faiss/c_api/Clustering_c.h +117 -0
  76. data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
  77. data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
  78. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
  79. data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
  80. data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
  81. data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
  82. data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
  83. data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
  84. data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
  85. data/vendor/faiss/c_api/IndexShards_c.h +42 -0
  86. data/vendor/faiss/c_api/Index_c.cpp +105 -0
  87. data/vendor/faiss/c_api/Index_c.h +183 -0
  88. data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
  89. data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
  90. data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
  91. data/vendor/faiss/c_api/clone_index_c.h +32 -0
  92. data/vendor/faiss/c_api/error_c.h +42 -0
  93. data/vendor/faiss/c_api/error_impl.cpp +27 -0
  94. data/vendor/faiss/c_api/error_impl.h +16 -0
  95. data/vendor/faiss/c_api/faiss_c.h +58 -0
  96. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
  97. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
  98. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
  99. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
  100. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
  101. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
  102. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
  103. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
  104. data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
  105. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
  106. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
  107. data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
  108. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
  109. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
  110. data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
  111. data/vendor/faiss/c_api/index_factory_c.h +30 -0
  112. data/vendor/faiss/c_api/index_io_c.cpp +42 -0
  113. data/vendor/faiss/c_api/index_io_c.h +50 -0
  114. data/vendor/faiss/c_api/macros_impl.h +110 -0
  115. data/vendor/faiss/clone_index.cpp +147 -0
  116. data/vendor/faiss/clone_index.h +38 -0
  117. data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
  118. data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
  119. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
  120. data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
  121. data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
  122. data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
  123. data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
  124. data/vendor/faiss/gpu/GpuCloner.h +82 -0
  125. data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
  126. data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
  127. data/vendor/faiss/gpu/GpuDistance.h +52 -0
  128. data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
  129. data/vendor/faiss/gpu/GpuIndex.h +148 -0
  130. data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
  131. data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
  132. data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
  133. data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
  134. data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
  135. data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
  136. data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
  137. data/vendor/faiss/gpu/GpuResources.cpp +52 -0
  138. data/vendor/faiss/gpu/GpuResources.h +73 -0
  139. data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
  140. data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
  141. data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
  142. data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
  143. data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
  144. data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
  145. data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
  146. data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
  147. data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
  148. data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
  149. data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
  150. data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
  151. data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
  152. data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
  153. data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
  154. data/vendor/faiss/gpu/test/TestUtils.h +93 -0
  155. data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
  156. data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
  157. data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
  158. data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
  159. data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
  160. data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
  161. data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
  162. data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
  163. data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
  164. data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
  165. data/vendor/faiss/gpu/utils/Timer.h +52 -0
  166. data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
  167. data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
  168. data/vendor/faiss/impl/FaissAssert.h +95 -0
  169. data/vendor/faiss/impl/FaissException.cpp +66 -0
  170. data/vendor/faiss/impl/FaissException.h +71 -0
  171. data/vendor/faiss/impl/HNSW.cpp +818 -0
  172. data/vendor/faiss/impl/HNSW.h +275 -0
  173. data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
  174. data/vendor/faiss/impl/PolysemousTraining.h +158 -0
  175. data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
  176. data/vendor/faiss/impl/ProductQuantizer.h +242 -0
  177. data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
  178. data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
  179. data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
  180. data/vendor/faiss/impl/ThreadedIndex.h +80 -0
  181. data/vendor/faiss/impl/index_read.cpp +793 -0
  182. data/vendor/faiss/impl/index_write.cpp +558 -0
  183. data/vendor/faiss/impl/io.cpp +142 -0
  184. data/vendor/faiss/impl/io.h +98 -0
  185. data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
  186. data/vendor/faiss/impl/lattice_Zn.h +199 -0
  187. data/vendor/faiss/index_factory.cpp +392 -0
  188. data/vendor/faiss/index_factory.h +25 -0
  189. data/vendor/faiss/index_io.h +75 -0
  190. data/vendor/faiss/misc/test_blas.cpp +84 -0
  191. data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
  192. data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
  193. data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
  194. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
  195. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
  196. data/vendor/faiss/tests/test_merge.cpp +258 -0
  197. data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
  198. data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
  199. data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
  200. data/vendor/faiss/tests/test_params_override.cpp +231 -0
  201. data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
  202. data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
  203. data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
  204. data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
  205. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
  206. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
  207. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
  208. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
  209. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
  210. data/vendor/faiss/utils/Heap.cpp +122 -0
  211. data/vendor/faiss/utils/Heap.h +495 -0
  212. data/vendor/faiss/utils/WorkerThread.cpp +126 -0
  213. data/vendor/faiss/utils/WorkerThread.h +61 -0
  214. data/vendor/faiss/utils/distances.cpp +765 -0
  215. data/vendor/faiss/utils/distances.h +243 -0
  216. data/vendor/faiss/utils/distances_simd.cpp +809 -0
  217. data/vendor/faiss/utils/extra_distances.cpp +336 -0
  218. data/vendor/faiss/utils/extra_distances.h +54 -0
  219. data/vendor/faiss/utils/hamming-inl.h +472 -0
  220. data/vendor/faiss/utils/hamming.cpp +792 -0
  221. data/vendor/faiss/utils/hamming.h +220 -0
  222. data/vendor/faiss/utils/random.cpp +192 -0
  223. data/vendor/faiss/utils/random.h +60 -0
  224. data/vendor/faiss/utils/utils.cpp +783 -0
  225. data/vendor/faiss/utils/utils.h +181 -0
  226. metadata +216 -2
--- /dev/null
+++ b/data/vendor/faiss/utils/distances.h
@@ -0,0 +1,243 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// -*- c++ -*-
+
+/* All distance functions for L2 and IP distances.
+ * The actual functions are implemented in distances.cpp and distances_simd.cpp */
+
+#pragma once
+
+#include <stdint.h>
+
+#include <faiss/utils/Heap.h>
+
+
+namespace faiss {
+
+/*********************************************************
+ * Optimized distance/norm/inner prod computations
+ *********************************************************/
+
+
+/// Squared L2 distance between two vectors
+float fvec_L2sqr (
+        const float * x,
+        const float * y,
+        size_t d);
+
+/// inner product
+float fvec_inner_product (
+        const float * x,
+        const float * y,
+        size_t d);
+
+/// L1 distance
+float fvec_L1 (
+        const float * x,
+        const float * y,
+        size_t d);
+
+/// L-infinity distance
+float fvec_Linf (
+        const float * x,
+        const float * y,
+        size_t d);
+
+
+/** Compute pairwise distances between sets of vectors
+ *
+ * @param d    dimension of the vectors
+ * @param nq   nb of query vectors
+ * @param nb   nb of database vectors
+ * @param xq   query vectors (size nq * d)
+ * @param xb   database vectors (size nb * d)
+ * @param dis  output distances (size nq * nb)
+ * @param ldq, ldb, ldd  strides for the matrices
+ */
+void pairwise_L2sqr (int64_t d,
+                     int64_t nq, const float *xq,
+                     int64_t nb, const float *xb,
+                     float *dis,
+                     int64_t ldq = -1, int64_t ldb = -1, int64_t ldd = -1);
+
+/* compute the inner products between a vector x and ny contiguous y vectors */
+void fvec_inner_products_ny (
+        float * ip,         /* output inner products */
+        const float * x,
+        const float * y,
+        size_t d, size_t ny);
+
+/* compute ny squared L2 distances between x and a set of contiguous y vectors */
+void fvec_L2sqr_ny (
+        float * dis,
+        const float * x,
+        const float * y,
+        size_t d, size_t ny);
+
+
+/** squared norm of a vector */
+float fvec_norm_L2sqr (const float * x,
+                       size_t d);
+
+/** compute the L2 norms for a set of vectors
+ *
+ * @param ip  output norms, size nx
+ * @param x   set of vectors, size nx * d
+ */
+void fvec_norms_L2 (float * ip, const float * x, size_t d, size_t nx);
+
+/// same as fvec_norms_L2, but computes squared norms
+void fvec_norms_L2sqr (float * ip, const float * x, size_t d, size_t nx);
+
+/* L2-renormalize a set of vectors. Nothing is done if a vector is 0-normed */
+void fvec_renorm_L2 (size_t d, size_t nx, float * x);
+
+
+/* This function exists because the Torch counterpart is extremely slow
+   (not multi-threaded + unexpected overhead even in single thread).
+   It is here to implement the usual property |x-y|^2 = |x|^2 + |y|^2 - 2<x|y> */
+void inner_product_to_L2sqr (float * dis,
+                             const float * nr1,
+                             const float * nr2,
+                             size_t n1, size_t n2);
+
+/***************************************************************************
+ * Compute a subset of distances
+ ***************************************************************************/
+
+/* compute the inner product between x and a subset y of ny vectors
+   whose indices are given by ids. */
+void fvec_inner_products_by_idx (
+        float * ip,
+        const float * x,
+        const float * y,
+        const int64_t *ids,
+        size_t d, size_t nx, size_t ny);
+
+/* same, but for a subset of y indexed by ids (ny vectors in total) */
+void fvec_L2sqr_by_idx (
+        float * dis,
+        const float * x,
+        const float * y,
+        const int64_t *ids, /* ids of y vecs */
+        size_t d, size_t nx, size_t ny);
+
+
+/** compute dis[j] = L2sqr(x[ix[j]], y[iy[j]]) for all j = 0..n-1
+ *
+ * @param x   size (max(ix) + 1, d)
+ * @param y   size (max(iy) + 1, d)
+ * @param ix  size n
+ * @param iy  size n
+ * @param dis size n
+ */
+void pairwise_indexed_L2sqr (
+        size_t d, size_t n,
+        const float * x, const int64_t *ix,
+        const float * y, const int64_t *iy,
+        float *dis);
+
+/* same for inner product */
+void pairwise_indexed_inner_product (
+        size_t d, size_t n,
+        const float * x, const int64_t *ix,
+        const float * y, const int64_t *iy,
+        float *dis);
+
+/***************************************************************************
+ * KNN functions
+ ***************************************************************************/
+
+// threshold on nx above which we switch to BLAS to compute distances
+extern int distance_compute_blas_threshold;
+
+/** Return the k nearest neighbors of each of the nx vectors x among the ny
+ *  vectors y, w.r.t. max inner product
+ *
+ * @param x    query vectors, size nx * d
+ * @param y    database vectors, size ny * d
+ * @param res  result array, which also provides k. Sorted on output
+ */
+void knn_inner_product (
+        const float * x,
+        const float * y,
+        size_t d, size_t nx, size_t ny,
+        float_minheap_array_t * res);
+
+/** Same as knn_inner_product, for the L2 distance */
+void knn_L2sqr (
+        const float * x,
+        const float * y,
+        size_t d, size_t nx, size_t ny,
+        float_maxheap_array_t * res);
+
+
+/** same as knn_L2sqr, but base_shift[bno] is subtracted from all
+ *  computed distances.
+ *
+ * @param base_shift  size ny
+ */
+void knn_L2sqr_base_shift (
+        const float * x,
+        const float * y,
+        size_t d, size_t nx, size_t ny,
+        float_maxheap_array_t * res,
+        const float *base_shift);
+
+/* Find the nearest neighbors for nx queries in a set of ny vectors
+ * indexed by ids. May be useful for re-ranking a pre-selected vector list */
+void knn_inner_products_by_idx (
+        const float * x,
+        const float * y,
+        const int64_t * ids,
+        size_t d, size_t nx, size_t ny,
+        float_minheap_array_t * res);
+
+void knn_L2sqr_by_idx (const float * x,
+                       const float * y,
+                       const int64_t * ids,
+                       size_t d, size_t nx, size_t ny,
+                       float_maxheap_array_t * res);
+
+/***************************************************************************
+ * Range search
+ ***************************************************************************/
+
+
+/// Forward declaration, see AuxIndexStructures.h
+struct RangeSearchResult;
+
+/** Return the database vectors y that lie within a given radius of each of
+ *  the nx query vectors x, w.r.t. the L2 distance
+ *
+ * @param x      query vectors, size nx * d
+ * @param y      database vectors, size ny * d
+ * @param radius search radius around the x vectors
+ * @param result result structure
+ */
+void range_search_L2sqr (
+        const float * x,
+        const float * y,
+        size_t d, size_t nx, size_t ny,
+        float radius,
+        RangeSearchResult *result);
+
+/// same as range_search_L2sqr for the inner product similarity
+void range_search_inner_product (
+        const float * x,
+        const float * y,
+        size_t d, size_t nx, size_t ny,
+        float radius,
+        RangeSearchResult *result);
+
+
+} // namespace faiss
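
For orientation, here is a minimal C++ sketch of how the declarations above are typically driven. It is not part of the gem: the buffer sizes and fill values are invented for illustration, and float_maxheap_array_t comes from <faiss/utils/Heap.h> (a struct of {nh, k, ids, val} that both supplies k and receives the results).

    // Hypothetical usage sketch: brute-force k-NN over nb database vectors.
    #include <faiss/utils/distances.h>
    #include <faiss/utils/Heap.h>

    #include <cstdint>
    #include <vector>

    int main() {
        size_t d = 8, nb = 1000, nq = 4, k = 5;
        std::vector<float> xb(nb * d, 0.5f);   // database vectors (made up)
        std::vector<float> xq(nq * d, 0.25f);  // query vectors (made up)

        // one pair of vectors -> one squared L2 distance
        float dist = faiss::fvec_L2sqr(xq.data(), xb.data(), d);
        (void) dist;

        // k-NN: the heap array carries k and receives ids/distances,
        // sorted on output as documented above
        std::vector<int64_t> labels(nq * k);
        std::vector<float> distances(nq * k);
        faiss::float_maxheap_array_t res = {nq, k, labels.data(), distances.data()};
        faiss::knn_L2sqr(xq.data(), xb.data(), d, nq, nb, &res);
        return 0;
    }
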
--- /dev/null
+++ b/data/vendor/faiss/utils/distances_simd.cpp
@@ -0,0 +1,809 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// -*- c++ -*-
+
+#include <faiss/utils/distances.h>
+
+#include <cstdio>
+#include <cassert>
+#include <cstring>
+#include <cmath>
+
+#ifdef __SSE__
+#include <immintrin.h>
+#endif
+
+#ifdef __aarch64__
+#include <arm_neon.h>
+#endif
+
+#include <omp.h>
+
+namespace faiss {
+
+#ifdef __AVX__
+#define USE_AVX
+#endif
+
+
+/*********************************************************
+ * Optimized distance computations
+ *********************************************************/
+
+
+/* Functions to compute:
+   - L2 distance between 2 vectors
+   - inner product between 2 vectors
+   - L2 norm of a vector
+
+   The functions should probably not be invoked when a large number of
+   vectors is processed in batch (in which case matrix multiplication
+   is faster), but may be useful for comparing vectors isolated in
+   memory.
+
+   Works with any vectors of any dimension, even unaligned (in which
+   case they are slower).
+
+*/
+
+
+/*********************************************************
+ * Reference implementations
+ */
+
+
+float fvec_L2sqr_ref (const float * x,
+                      const float * y,
+                      size_t d)
+{
+    size_t i;
+    float res = 0;
+    for (i = 0; i < d; i++) {
+        const float tmp = x[i] - y[i];
+        res += tmp * tmp;
+    }
+    return res;
+}
+
+float fvec_L1_ref (const float * x,
+                   const float * y,
+                   size_t d)
+{
+    size_t i;
+    float res = 0;
+    for (i = 0; i < d; i++) {
+        const float tmp = x[i] - y[i];
+        res += fabs(tmp);
+    }
+    return res;
+}
+
+float fvec_Linf_ref (const float * x,
+                     const float * y,
+                     size_t d)
+{
+    size_t i;
+    float res = 0;
+    for (i = 0; i < d; i++) {
+        res = fmax(res, fabs(x[i] - y[i]));
+    }
+    return res;
+}
+
+float fvec_inner_product_ref (const float * x,
+                              const float * y,
+                              size_t d)
+{
+    size_t i;
+    float res = 0;
+    for (i = 0; i < d; i++)
+        res += x[i] * y[i];
+    return res;
+}
+
+float fvec_norm_L2sqr_ref (const float *x, size_t d)
+{
+    size_t i;
+    double res = 0;
+    for (i = 0; i < d; i++)
+        res += x[i] * x[i];
+    return res;
+}
+
+
+void fvec_L2sqr_ny_ref (float * dis,
+                        const float * x,
+                        const float * y,
+                        size_t d, size_t ny)
+{
+    for (size_t i = 0; i < ny; i++) {
+        dis[i] = fvec_L2sqr (x, y, d);
+        y += d;
+    }
+}
+
+
+
+/*********************************************************
+ * SSE and AVX implementations
+ */
+
+#ifdef __SSE__
+
+// reads 0 <= d < 4 floats as __m128
+static inline __m128 masked_read (int d, const float *x)
+{
+    assert (0 <= d && d < 4);
+    __attribute__((__aligned__(16))) float buf[4] = {0, 0, 0, 0};
+    switch (d) {
+      case 3:
+        buf[2] = x[2];
+      case 2:
+        buf[1] = x[1];
+      case 1:
+        buf[0] = x[0];
+    }
+    return _mm_load_ps (buf);
+    // cannot use AVX2 _mm_mask_set1_epi32
+}
+
+float fvec_norm_L2sqr (const float * x,
+                       size_t d)
+{
+    __m128 mx;
+    __m128 msum1 = _mm_setzero_ps();
+
+    while (d >= 4) {
+        mx = _mm_loadu_ps (x); x += 4;
+        msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, mx));
+        d -= 4;
+    }
+
+    mx = masked_read (d, x);
+    msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, mx));
+
+    msum1 = _mm_hadd_ps (msum1, msum1);
+    msum1 = _mm_hadd_ps (msum1, msum1);
+    return _mm_cvtss_f32 (msum1);
+}
+
+namespace {
+
+float sqr (float x) {
+    return x * x;
+}
+
+
+void fvec_L2sqr_ny_D1 (float * dis, const float * x,
+                       const float * y, size_t ny)
+{
+    float x0s = x[0];
+    __m128 x0 = _mm_set_ps (x0s, x0s, x0s, x0s);
+
+    size_t i;
+    for (i = 0; i + 3 < ny; i += 4) {
+        __m128 tmp, accu;
+        tmp = x0 - _mm_loadu_ps (y); y += 4;
+        accu = tmp * tmp;
+        dis[i] = _mm_cvtss_f32 (accu);
+        tmp = _mm_shuffle_ps (accu, accu, 1);
+        dis[i + 1] = _mm_cvtss_f32 (tmp);
+        tmp = _mm_shuffle_ps (accu, accu, 2);
+        dis[i + 2] = _mm_cvtss_f32 (tmp);
+        tmp = _mm_shuffle_ps (accu, accu, 3);
+        dis[i + 3] = _mm_cvtss_f32 (tmp);
+    }
+    while (i < ny) { // handle non-multiple-of-4 case
+        dis[i++] = sqr(x0s - *y++);
+    }
+}
+
+
+void fvec_L2sqr_ny_D2 (float * dis, const float * x,
+                       const float * y, size_t ny)
+{
+    __m128 x0 = _mm_set_ps (x[1], x[0], x[1], x[0]);
+
+    size_t i;
+    for (i = 0; i + 1 < ny; i += 2) {
+        __m128 tmp, accu;
+        tmp = x0 - _mm_loadu_ps (y); y += 4;
+        accu = tmp * tmp;
+        accu = _mm_hadd_ps (accu, accu);
+        dis[i] = _mm_cvtss_f32 (accu);
+        accu = _mm_shuffle_ps (accu, accu, 3);
+        dis[i + 1] = _mm_cvtss_f32 (accu);
+    }
+    if (i < ny) { // handle odd case
+        dis[i] = sqr(x[0] - y[0]) + sqr(x[1] - y[1]);
+    }
+}
+
+
+void fvec_L2sqr_ny_D4 (float * dis, const float * x,
+                       const float * y, size_t ny)
+{
+    __m128 x0 = _mm_loadu_ps(x);
+
+    for (size_t i = 0; i < ny; i++) {
+        __m128 tmp, accu;
+        tmp = x0 - _mm_loadu_ps (y); y += 4;
+        accu = tmp * tmp;
+        accu = _mm_hadd_ps (accu, accu);
+        accu = _mm_hadd_ps (accu, accu);
+        dis[i] = _mm_cvtss_f32 (accu);
+    }
+}
+
+
+void fvec_L2sqr_ny_D8 (float * dis, const float * x,
+                       const float * y, size_t ny)
+{
+    __m128 x0 = _mm_loadu_ps(x);
+    __m128 x1 = _mm_loadu_ps(x + 4);
+
+    for (size_t i = 0; i < ny; i++) {
+        __m128 tmp, accu;
+        tmp = x0 - _mm_loadu_ps (y); y += 4;
+        accu = tmp * tmp;
+        tmp = x1 - _mm_loadu_ps (y); y += 4;
+        accu += tmp * tmp;
+        accu = _mm_hadd_ps (accu, accu);
+        accu = _mm_hadd_ps (accu, accu);
+        dis[i] = _mm_cvtss_f32 (accu);
+    }
+}
+
+
+void fvec_L2sqr_ny_D12 (float * dis, const float * x,
+                        const float * y, size_t ny)
+{
+    __m128 x0 = _mm_loadu_ps(x);
+    __m128 x1 = _mm_loadu_ps(x + 4);
+    __m128 x2 = _mm_loadu_ps(x + 8);
+
+    for (size_t i = 0; i < ny; i++) {
+        __m128 tmp, accu;
+        tmp = x0 - _mm_loadu_ps (y); y += 4;
+        accu = tmp * tmp;
+        tmp = x1 - _mm_loadu_ps (y); y += 4;
+        accu += tmp * tmp;
+        tmp = x2 - _mm_loadu_ps (y); y += 4;
+        accu += tmp * tmp;
+        accu = _mm_hadd_ps (accu, accu);
+        accu = _mm_hadd_ps (accu, accu);
+        dis[i] = _mm_cvtss_f32 (accu);
+    }
+}
+
+
+} // anonymous namespace
+
+void fvec_L2sqr_ny (float * dis, const float * x,
+                    const float * y, size_t d, size_t ny) {
+    // optimized for a few special cases
+    switch(d) {
+      case 1:
+        fvec_L2sqr_ny_D1 (dis, x, y, ny);
+        return;
+      case 2:
+        fvec_L2sqr_ny_D2 (dis, x, y, ny);
+        return;
+      case 4:
+        fvec_L2sqr_ny_D4 (dis, x, y, ny);
+        return;
+      case 8:
+        fvec_L2sqr_ny_D8 (dis, x, y, ny);
+        return;
+      case 12:
+        fvec_L2sqr_ny_D12 (dis, x, y, ny);
+        return;
+      default:
+        fvec_L2sqr_ny_ref (dis, x, y, d, ny);
+        return;
+    }
+}
+
+
+#endif
+
+#ifdef USE_AVX
+
+// reads 0 <= d < 8 floats as __m256
+static inline __m256 masked_read_8 (int d, const float *x)
+{
+    assert (0 <= d && d < 8);
+    if (d < 4) {
+        __m256 res = _mm256_setzero_ps ();
+        res = _mm256_insertf128_ps (res, masked_read (d, x), 0);
+        return res;
+    } else {
+        __m256 res = _mm256_setzero_ps ();
+        res = _mm256_insertf128_ps (res, _mm_loadu_ps (x), 0);
+        res = _mm256_insertf128_ps (res, masked_read (d - 4, x + 4), 1);
+        return res;
+    }
+}
+
+float fvec_inner_product (const float * x,
+                          const float * y,
+                          size_t d)
+{
+    __m256 msum1 = _mm256_setzero_ps();
+
+    while (d >= 8) {
+        __m256 mx = _mm256_loadu_ps (x); x += 8;
+        __m256 my = _mm256_loadu_ps (y); y += 8;
+        msum1 = _mm256_add_ps (msum1, _mm256_mul_ps (mx, my));
+        d -= 8;
+    }
+
+    __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
+    msum2 += _mm256_extractf128_ps(msum1, 0);
+
+    if (d >= 4) {
+        __m128 mx = _mm_loadu_ps (x); x += 4;
+        __m128 my = _mm_loadu_ps (y); y += 4;
+        msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my));
+        d -= 4;
+    }
+
+    if (d > 0) {
+        __m128 mx = masked_read (d, x);
+        __m128 my = masked_read (d, y);
+        msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my));
+    }
+
+    msum2 = _mm_hadd_ps (msum2, msum2);
+    msum2 = _mm_hadd_ps (msum2, msum2);
+    return _mm_cvtss_f32 (msum2);
+}
+
+float fvec_L2sqr (const float * x,
+                  const float * y,
+                  size_t d)
+{
+    __m256 msum1 = _mm256_setzero_ps();
+
+    while (d >= 8) {
+        __m256 mx = _mm256_loadu_ps (x); x += 8;
+        __m256 my = _mm256_loadu_ps (y); y += 8;
+        const __m256 a_m_b1 = mx - my;
+        msum1 += a_m_b1 * a_m_b1;
+        d -= 8;
+    }
+
+    __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
+    msum2 += _mm256_extractf128_ps(msum1, 0);
+
+    if (d >= 4) {
+        __m128 mx = _mm_loadu_ps (x); x += 4;
+        __m128 my = _mm_loadu_ps (y); y += 4;
+        const __m128 a_m_b1 = mx - my;
+        msum2 += a_m_b1 * a_m_b1;
+        d -= 4;
+    }
+
+    if (d > 0) {
+        __m128 mx = masked_read (d, x);
+        __m128 my = masked_read (d, y);
+        __m128 a_m_b1 = mx - my;
+        msum2 += a_m_b1 * a_m_b1;
+    }
+
+    msum2 = _mm_hadd_ps (msum2, msum2);
+    msum2 = _mm_hadd_ps (msum2, msum2);
+    return _mm_cvtss_f32 (msum2);
+}
+
+float fvec_L1 (const float * x, const float * y, size_t d)
+{
+    __m256 msum1 = _mm256_setzero_ps();
+    __m256 signmask = __m256(_mm256_set1_epi32 (0x7fffffffUL));
+
+    while (d >= 8) {
+        __m256 mx = _mm256_loadu_ps (x); x += 8;
+        __m256 my = _mm256_loadu_ps (y); y += 8;
+        const __m256 a_m_b = mx - my;
+        msum1 += _mm256_and_ps(signmask, a_m_b);
+        d -= 8;
+    }
+
+    __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
+    msum2 += _mm256_extractf128_ps(msum1, 0);
+    __m128 signmask2 = __m128(_mm_set1_epi32 (0x7fffffffUL));
+
+    if (d >= 4) {
+        __m128 mx = _mm_loadu_ps (x); x += 4;
+        __m128 my = _mm_loadu_ps (y); y += 4;
+        const __m128 a_m_b = mx - my;
+        msum2 += _mm_and_ps(signmask2, a_m_b);
+        d -= 4;
+    }
+
+    if (d > 0) {
+        __m128 mx = masked_read (d, x);
+        __m128 my = masked_read (d, y);
+        __m128 a_m_b = mx - my;
+        msum2 += _mm_and_ps(signmask2, a_m_b);
+    }
+
+    msum2 = _mm_hadd_ps (msum2, msum2);
+    msum2 = _mm_hadd_ps (msum2, msum2);
+    return _mm_cvtss_f32 (msum2);
+}
+
+float fvec_Linf (const float * x, const float * y, size_t d)
+{
+    __m256 msum1 = _mm256_setzero_ps();
+    __m256 signmask = __m256(_mm256_set1_epi32 (0x7fffffffUL));
+
+    while (d >= 8) {
+        __m256 mx = _mm256_loadu_ps (x); x += 8;
+        __m256 my = _mm256_loadu_ps (y); y += 8;
+        const __m256 a_m_b = mx - my;
+        msum1 = _mm256_max_ps(msum1, _mm256_and_ps(signmask, a_m_b));
+        d -= 8;
+    }
+
+    __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
+    msum2 = _mm_max_ps (msum2, _mm256_extractf128_ps(msum1, 0));
+    __m128 signmask2 = __m128(_mm_set1_epi32 (0x7fffffffUL));
+
+    if (d >= 4) {
+        __m128 mx = _mm_loadu_ps (x); x += 4;
+        __m128 my = _mm_loadu_ps (y); y += 4;
+        const __m128 a_m_b = mx - my;
+        msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b));
+        d -= 4;
+    }
+
+    if (d > 0) {
+        __m128 mx = masked_read (d, x);
+        __m128 my = masked_read (d, y);
+        __m128 a_m_b = mx - my;
+        msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b));
+    }
+
+    msum2 = _mm_max_ps(_mm_movehl_ps(msum2, msum2), msum2);
+    msum2 = _mm_max_ps(msum2, _mm_shuffle_ps (msum2, msum2, 1));
+    return _mm_cvtss_f32 (msum2);
+}
+
+#elif defined(__SSE__) // But not AVX
+
+float fvec_L1 (const float * x, const float * y, size_t d)
+{
+    return fvec_L1_ref (x, y, d);
+}
+
+float fvec_Linf (const float * x, const float * y, size_t d)
+{
+    return fvec_Linf_ref (x, y, d);
+}
+
+
+float fvec_L2sqr (const float * x,
+                  const float * y,
+                  size_t d)
+{
+    __m128 msum1 = _mm_setzero_ps();
+
+    while (d >= 4) {
+        __m128 mx = _mm_loadu_ps (x); x += 4;
+        __m128 my = _mm_loadu_ps (y); y += 4;
+        const __m128 a_m_b1 = mx - my;
+        msum1 += a_m_b1 * a_m_b1;
+        d -= 4;
+    }
+
+    if (d > 0) {
+        // add the last 1, 2 or 3 values
+        __m128 mx = masked_read (d, x);
+        __m128 my = masked_read (d, y);
+        __m128 a_m_b1 = mx - my;
+        msum1 += a_m_b1 * a_m_b1;
+    }
+
+    msum1 = _mm_hadd_ps (msum1, msum1);
+    msum1 = _mm_hadd_ps (msum1, msum1);
+    return _mm_cvtss_f32 (msum1);
+}
+
+
+float fvec_inner_product (const float * x,
+                          const float * y,
+                          size_t d)
+{
+    __m128 mx, my;
+    __m128 msum1 = _mm_setzero_ps();
+
+    while (d >= 4) {
+        mx = _mm_loadu_ps (x); x += 4;
+        my = _mm_loadu_ps (y); y += 4;
+        msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, my));
+        d -= 4;
+    }
+
+    // add the last 1, 2, or 3 values
+    mx = masked_read (d, x);
+    my = masked_read (d, y);
+    __m128 prod = _mm_mul_ps (mx, my);
+
+    msum1 = _mm_add_ps (msum1, prod);
+
+    msum1 = _mm_hadd_ps (msum1, msum1);
+    msum1 = _mm_hadd_ps (msum1, msum1);
+    return _mm_cvtss_f32 (msum1);
+}
+
+#elif defined(__aarch64__)
+
+
+float fvec_L2sqr (const float * x,
+                  const float * y,
+                  size_t d)
+{
+    if (d & 3) return fvec_L2sqr_ref (x, y, d);
+    float32x4_t accu = vdupq_n_f32 (0);
+    for (size_t i = 0; i < d; i += 4) {
+        float32x4_t xi = vld1q_f32 (x + i);
+        float32x4_t yi = vld1q_f32 (y + i);
+        float32x4_t sq = vsubq_f32 (xi, yi);
+        accu = vfmaq_f32 (accu, sq, sq);
+    }
+    float32x4_t a2 = vpaddq_f32 (accu, accu);
+    return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
+}
+
+float fvec_inner_product (const float * x,
+                          const float * y,
+                          size_t d)
+{
+    if (d & 3) return fvec_inner_product_ref (x, y, d);
+    float32x4_t accu = vdupq_n_f32 (0);
+    for (size_t i = 0; i < d; i += 4) {
+        float32x4_t xi = vld1q_f32 (x + i);
+        float32x4_t yi = vld1q_f32 (y + i);
+        accu = vfmaq_f32 (accu, xi, yi);
+    }
+    float32x4_t a2 = vpaddq_f32 (accu, accu);
+    return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
+}
+
+float fvec_norm_L2sqr (const float *x, size_t d)
+{
+    if (d & 3) return fvec_norm_L2sqr_ref (x, d);
+    float32x4_t accu = vdupq_n_f32 (0);
+    for (size_t i = 0; i < d; i += 4) {
+        float32x4_t xi = vld1q_f32 (x + i);
+        accu = vfmaq_f32 (accu, xi, xi);
+    }
+    float32x4_t a2 = vpaddq_f32 (accu, accu);
+    return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
+}
+
+// not optimized for ARM
+void fvec_L2sqr_ny (float * dis, const float * x,
+                    const float * y, size_t d, size_t ny) {
+    fvec_L2sqr_ny_ref (dis, x, y, d, ny);
+}
+
+float fvec_L1 (const float * x, const float * y, size_t d)
+{
+    return fvec_L1_ref (x, y, d);
+}
+
+float fvec_Linf (const float * x, const float * y, size_t d)
+{
+    return fvec_Linf_ref (x, y, d);
+}
+
+
+#else
+// scalar implementation
+
+float fvec_L2sqr (const float * x,
+                  const float * y,
+                  size_t d)
+{
+    return fvec_L2sqr_ref (x, y, d);
+}
+
+float fvec_L1 (const float * x, const float * y, size_t d)
+{
+    return fvec_L1_ref (x, y, d);
+}
+
+float fvec_Linf (const float * x, const float * y, size_t d)
+{
+    return fvec_Linf_ref (x, y, d);
+}
+
+float fvec_inner_product (const float * x,
+                          const float * y,
+                          size_t d)
+{
+    return fvec_inner_product_ref (x, y, d);
+}
+
+float fvec_norm_L2sqr (const float *x, size_t d)
+{
+    return fvec_norm_L2sqr_ref (x, d);
+}
+
+void fvec_L2sqr_ny (float * dis, const float * x,
+                    const float * y, size_t d, size_t ny) {
+    fvec_L2sqr_ny_ref (dis, x, y, d, ny);
+}
+
+
+#endif
+
+
+/***************************************************************************
+ * heavily optimized table computations
+ ***************************************************************************/
+
+
+static inline void fvec_madd_ref (size_t n, const float *a,
+                                  float bf, const float *b, float *c) {
+    for (size_t i = 0; i < n; i++)
+        c[i] = a[i] + bf * b[i];
+}
+
+#ifdef __SSE__
+
+static inline void fvec_madd_sse (size_t n, const float *a,
+                                  float bf, const float *b, float *c) {
+    n >>= 2;
+    __m128 bf4 = _mm_set_ps1 (bf);
+    __m128 * a4 = (__m128*)a;
+    __m128 * b4 = (__m128*)b;
+    __m128 * c4 = (__m128*)c;
+
+    while (n--) {
+        *c4 = _mm_add_ps (*a4, _mm_mul_ps (bf4, *b4));
+        b4++;
+        a4++;
+        c4++;
+    }
+}
+
+void fvec_madd (size_t n, const float *a,
+                float bf, const float *b, float *c)
+{
+    if ((n & 3) == 0 &&
+        ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
+        fvec_madd_sse (n, a, bf, b, c);
+    else
+        fvec_madd_ref (n, a, bf, b, c);
+}
+
+#else
+
+void fvec_madd (size_t n, const float *a,
+                float bf, const float *b, float *c)
+{
+    fvec_madd_ref (n, a, bf, b, c);
+}
+
+#endif
+
+static inline int fvec_madd_and_argmin_ref (size_t n, const float *a,
+                                            float bf, const float *b, float *c) {
+    float vmin = 1e20;
+    int imin = -1;
+
+    for (size_t i = 0; i < n; i++) {
+        c[i] = a[i] + bf * b[i];
+        if (c[i] < vmin) {
+            vmin = c[i];
+            imin = i;
+        }
+    }
+    return imin;
+}
+
+#ifdef __SSE__
+
+static inline int fvec_madd_and_argmin_sse (
+        size_t n, const float *a,
+        float bf, const float *b, float *c) {
+    n >>= 2;
+    __m128 bf4 = _mm_set_ps1 (bf);
+    __m128 vmin4 = _mm_set_ps1 (1e20);
+    __m128i imin4 = _mm_set1_epi32 (-1);
+    __m128i idx4 = _mm_set_epi32 (3, 2, 1, 0);
+    __m128i inc4 = _mm_set1_epi32 (4);
+    __m128 * a4 = (__m128*)a;
+    __m128 * b4 = (__m128*)b;
+    __m128 * c4 = (__m128*)c;
+
+    while (n--) {
+        __m128 vc4 = _mm_add_ps (*a4, _mm_mul_ps (bf4, *b4));
+        *c4 = vc4;
+        __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
+        // imin4 = _mm_blendv_epi8 (imin4, idx4, mask); // slower!
+
+        imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
+                              _mm_andnot_si128 (mask, imin4));
+        vmin4 = _mm_min_ps (vmin4, vc4);
+        b4++;
+        a4++;
+        c4++;
+        idx4 = _mm_add_epi32 (idx4, inc4);
+    }
+
+    // 4 values -> 2
+    {
+        idx4 = _mm_shuffle_epi32 (imin4, 3 << 2 | 2);
+        __m128 vc4 = _mm_shuffle_ps (vmin4, vmin4, 3 << 2 | 2);
+        __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
+        imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
+                              _mm_andnot_si128 (mask, imin4));
+        vmin4 = _mm_min_ps (vmin4, vc4);
+    }
+    // 2 values -> 1
+    {
+        idx4 = _mm_shuffle_epi32 (imin4, 1);
+        __m128 vc4 = _mm_shuffle_ps (vmin4, vmin4, 1);
+        __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
+        imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
+                              _mm_andnot_si128 (mask, imin4));
+        // vmin4 = _mm_min_ps (vmin4, vc4);
+    }
+    return _mm_cvtsi128_si32 (imin4);
+}
+
+
+int fvec_madd_and_argmin (size_t n, const float *a,
+                          float bf, const float *b, float *c)
+{
+    if ((n & 3) == 0 &&
+        ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
+        return fvec_madd_and_argmin_sse (n, a, bf, b, c);
+    else
+        return fvec_madd_and_argmin_ref (n, a, bf, b, c);
+}
+
+#else
+
+int fvec_madd_and_argmin (size_t n, const float *a,
+                          float bf, const float *b, float *c)
+{
+    return fvec_madd_and_argmin_ref (n, a, bf, b, c);
+}
+
+#endif
+
+
+} // namespace faiss
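
The SIMD kernels above all share one shape: a main loop that consumes 4 (SSE) or 8 (AVX) floats per iteration, a masked read that zero-pads the final 1 to 3 leftover floats, and a horizontal reduction of the vector accumulator down to a scalar. The following standalone sketch of that pattern was written for this page and is not code from the gem; it assumes a GCC-compatible compiler on x86 with SSE3 available for _mm_hadd_ps.

    // Illustration of the loop / zero-padded tail / horizontal-sum pattern
    // used by the kernels above (not part of the gem).
    #include <immintrin.h>
    #include <cassert>
    #include <cstdio>

    // Load the final 0 <= d < 4 floats into a zero-padded __m128.
    static inline __m128 load_tail(int d, const float* x) {
        assert(0 <= d && d < 4);
        __attribute__((__aligned__(16))) float buf[4] = {0, 0, 0, 0};
        for (int i = 0; i < d; i++) buf[i] = x[i];
        return _mm_load_ps(buf);
    }

    float dot(const float* x, const float* y, size_t d) {
        __m128 acc = _mm_setzero_ps();
        while (d >= 4) {  // main loop: 4 floats per iteration
            acc = _mm_add_ps(acc, _mm_mul_ps(_mm_loadu_ps(x), _mm_loadu_ps(y)));
            x += 4; y += 4; d -= 4;
        }
        // tail: the zero-padded lanes contribute 0 to the sum
        acc = _mm_add_ps(acc, _mm_mul_ps(load_tail(d, x), load_tail(d, y)));
        acc = _mm_hadd_ps(acc, acc);  // horizontal reduction: 4 -> 2 -> 1
        acc = _mm_hadd_ps(acc, acc);
        return _mm_cvtss_f32(acc);
    }

    int main() {
        float a[6] = {1, 2, 3, 4, 5, 6};
        float b[6] = {6, 5, 4, 3, 2, 1};
        printf("%f\n", dot(a, b, 6)); // 1*6 + 2*5 + 3*4 + 4*3 + 5*2 + 6*1 = 56
        return 0;
    }

The zero-padding makes the tail safe for sums and sums of squares (the padded lanes add nothing), which is exactly why fvec_L2sqr and fvec_inner_product above can reuse masked_read for their last few elements.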