faiss 0.2.0 → 0.2.1

Files changed (202)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
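
The excerpt below shows the changes to data/vendor/faiss/faiss/utils/distances.h and data/vendor/faiss/faiss/utils/distances_simd.cpp (files 181 and 182 in the list above). Most of the churn is mechanical reformatting of the vendored FAISS sources. The substantive changes visible here are: arithmetic operators applied directly to __m128/__m256 values (a GCC vector extension, e.g. mx - my) replaced with explicit intrinsics such as _mm_sub_ps and _mm_add_ps; __attribute__((__aligned__(16))) replaced by the ALIGNED(16) macro from faiss/impl/platform_macros.h; and the aarch64 NEON kernels rewritten to process a non-multiple-of-4 tail with a scalar loop instead of falling back to the reference implementation for such dimensions.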
data/vendor/faiss/faiss/utils/distances.h
@@ -5,49 +5,34 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-// -*- c++ -*-
-
 /* All distance functions for L2 and IP distances.
- * The actual functions are implemented in distances.cpp and distances_simd.cpp */
+ * The actual functions are implemented in distances.cpp and distances_simd.cpp
+ */
 
 #pragma once
 
 #include <stdint.h>
 
-#include <faiss/utils/Heap.h>
 #include <faiss/impl/platform_macros.h>
-
+#include <faiss/utils/Heap.h>
 
 namespace faiss {
 
-/*********************************************************
+/*********************************************************
  * Optimized distance/norm/inner prod computations
  *********************************************************/
 
-
 /// Squared L2 distance between two vectors
-float fvec_L2sqr (
-        const float * x,
-        const float * y,
-        size_t d);
+float fvec_L2sqr(const float* x, const float* y, size_t d);
 
 /// inner product
-float fvec_inner_product (
-        const float * x,
-        const float * y,
-        size_t d);
+float fvec_inner_product(const float* x, const float* y, size_t d);
 
 /// L1 distance
-float fvec_L1 (
-        const float * x,
-        const float * y,
-        size_t d);
-
-float fvec_Linf (
-        const float * x,
-        const float * y,
-        size_t d);
+float fvec_L1(const float* x, const float* y, size_t d);
 
+/// infinity distance
+float fvec_Linf(const float* x, const float* y, size_t d);
 
 /** Compute pairwise distances between sets of vectors
  *
@@ -59,74 +44,83 @@ float fvec_Linf (
  * @param dis   output distances (size nq * nb)
  * @param ldq,ldb, ldd strides for the matrices
  */
-void pairwise_L2sqr (int64_t d,
-                     int64_t nq, const float *xq,
-                     int64_t nb, const float *xb,
-                     float *dis,
-                     int64_t ldq = -1, int64_t ldb = -1, int64_t ldd = -1);
+void pairwise_L2sqr(
+        int64_t d,
+        int64_t nq,
+        const float* xq,
+        int64_t nb,
+        const float* xb,
+        float* dis,
+        int64_t ldq = -1,
+        int64_t ldb = -1,
+        int64_t ldd = -1);
 
 /* compute the inner product between nx vectors x and one y */
-void fvec_inner_products_ny (
-        float * ip,  /* output inner product */
-        const float * x,
-        const float * y,
-        size_t d, size_t ny);
+void fvec_inner_products_ny(
+        float* ip, /* output inner product */
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny);
 
 /* compute ny square L2 distance bewteen x and a set of contiguous y vectors */
-void fvec_L2sqr_ny (
-        float * dis,
-        const float * x,
-        const float * y,
-        size_t d, size_t ny);
-
+void fvec_L2sqr_ny(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny);
 
 /** squared norm of a vector */
-float fvec_norm_L2sqr (const float * x,
-                       size_t d);
+float fvec_norm_L2sqr(const float* x, size_t d);
 
 /** compute the L2 norms for a set of vectors
  *
- * @param  ip      output norms, size nx
+ * @param  norms   output norms, size nx
  * @param  x       set of vectors, size nx * d
  */
-void fvec_norms_L2 (float * ip, const float * x, size_t d, size_t nx);
+void fvec_norms_L2(float* norms, const float* x, size_t d, size_t nx);
 
-/// same as fvec_norms_L2, but computes square norms
-void fvec_norms_L2sqr (float * ip, const float * x, size_t d, size_t nx);
+/// same as fvec_norms_L2, but computes squared norms
+void fvec_norms_L2sqr(float* norms, const float* x, size_t d, size_t nx);
 
 /* L2-renormalize a set of vector. Nothing done if the vector is 0-normed */
-void fvec_renorm_L2 (size_t d, size_t nx, float * x);
-
+void fvec_renorm_L2(size_t d, size_t nx, float* x);
 
 /* This function exists because the Torch counterpart is extremly slow
    (not multi-threaded + unexpected overhead even in single thread).
    It is here to implement the usual property |x-y|^2=|x|^2+|y|^2-2<x|y> */
-void inner_product_to_L2sqr (float * dis,
-                             const float * nr1,
-                             const float * nr2,
-                             size_t n1, size_t n2);
+void inner_product_to_L2sqr(
+        float* dis,
+        const float* nr1,
+        const float* nr2,
+        size_t n1,
+        size_t n2);
 
 /***************************************************************************
  * Compute a subset of distances
 ***************************************************************************/
 
- /* compute the inner product between x and a subset y of ny vectors,
-    whose indices are given by idy. */
-void fvec_inner_products_by_idx (
-        float * ip,
-        const float * x,
-        const float * y,
-        const int64_t *ids,
-        size_t d, size_t nx, size_t ny);
+/* compute the inner product between x and a subset y of ny vectors,
+   whose indices are given by idy. */
+void fvec_inner_products_by_idx(
+        float* ip,
+        const float* x,
+        const float* y,
+        const int64_t* ids,
+        size_t d,
+        size_t nx,
+        size_t ny);
 
 /* same but for a subset in y indexed by idsy (ny vectors in total) */
-void fvec_L2sqr_by_idx (
-        float * dis,
-        const float * x,
-        const float * y,
-        const int64_t *ids, /* ids of y vecs */
-        size_t d, size_t nx, size_t ny);
-
+void fvec_L2sqr_by_idx(
+        float* dis,
+        const float* x,
+        const float* y,
+        const int64_t* ids, /* ids of y vecs */
+        size_t d,
+        size_t nx,
+        size_t ny);
 
 /** compute dis[j] = L2sqr(x[ix[j]], y[iy[j]]) forall j=0..n-1
  *
@@ -136,18 +130,24 @@ void fvec_L2sqr_by_idx (
  * @param iy  size n
  * @param dis size n
  */
-void pairwise_indexed_L2sqr (
-        size_t d, size_t n,
-        const float * x, const int64_t *ix,
-        const float * y, const int64_t *iy,
-        float *dis);
+void pairwise_indexed_L2sqr(
+        size_t d,
+        size_t n,
+        const float* x,
+        const int64_t* ix,
+        const float* y,
+        const int64_t* iy,
+        float* dis);
 
 /* same for inner product */
-void pairwise_indexed_inner_product (
-        size_t d, size_t n,
-        const float * x, const int64_t *ix,
-        const float * y, const int64_t *iy,
-        float *dis);
+void pairwise_indexed_inner_product(
+        size_t d,
+        size_t n,
+        const float* x,
+        const int64_t* ix,
+        const float* y,
+        const int64_t* iy,
+        float* dis);
 
 /***************************************************************************
  * KNN functions
@@ -171,46 +171,51 @@ FAISS_API extern int distance_compute_min_k_reservoir;
  * @param y      database vectors, size ny * d
  * @param res    result array, which also provides k. Sorted on output
  */
-void knn_inner_product (
-        const float * x,
-        const float * y,
-        size_t d, size_t nx, size_t ny,
-        float_minheap_array_t * res);
+void knn_inner_product(
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t nx,
+        size_t ny,
+        float_minheap_array_t* res);
 
 /** Same as knn_inner_product, for the L2 distance
  *  @param y_norm2   norms for the y vectors (nullptr or size ny)
  */
-void knn_L2sqr (
-        const float * x,
-        const float * y,
-        size_t d, size_t nx, size_t ny,
-        float_maxheap_array_t * res,
-        const float *y_norm2 = nullptr);
-
+void knn_L2sqr(
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t nx,
+        size_t ny,
+        float_maxheap_array_t* res,
+        const float* y_norm2 = nullptr);
 
 /* Find the nearest neighbors for nx queries in a set of ny vectors
  * indexed by ids. May be useful for re-ranking a pre-selected vector list
 */
-void knn_inner_products_by_idx (
-        const float * x,
-        const float * y,
-        const int64_t * ids,
-        size_t d, size_t nx, size_t ny,
-        float_minheap_array_t * res);
-
-void knn_L2sqr_by_idx (
-        const float * x,
-        const float * y,
-        const int64_t * ids,
-        size_t d, size_t nx, size_t ny,
-        float_maxheap_array_t * res);
+void knn_inner_products_by_idx(
+        const float* x,
+        const float* y,
+        const int64_t* ids,
+        size_t d,
+        size_t nx,
+        size_t ny,
+        float_minheap_array_t* res);
+
+void knn_L2sqr_by_idx(
+        const float* x,
+        const float* y,
+        const int64_t* ids,
+        size_t d,
+        size_t nx,
+        size_t ny,
+        float_maxheap_array_t* res);
 
 /***************************************************************************
  * Range search
 ***************************************************************************/
 
-
-
 /// Forward declaration, see AuxIndexStructures.h
 struct RangeSearchResult;
 
@@ -222,21 +227,24 @@ struct RangeSearchResult;
  * @param radius search radius around the x vectors
  * @param result result structure
 */
-void range_search_L2sqr (
-        const float * x,
-        const float * y,
-        size_t d, size_t nx, size_t ny,
+void range_search_L2sqr(
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t nx,
+        size_t ny,
         float radius,
-        RangeSearchResult *result);
+        RangeSearchResult* result);
 
 /// same as range_search_L2sqr for the inner product similarity
-void range_search_inner_product (
-        const float * x,
-        const float * y,
-        size_t d, size_t nx, size_t ny,
+void range_search_inner_product(
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t nx,
+        size_t ny,
         float radius,
-        RangeSearchResult *result);
-
+        RangeSearchResult* result);
 
 /***************************************************************************
  * PQ tables computations
@@ -244,9 +252,16 @@ void range_search_inner_product (
 
 /// specialized function for PQ2
 void compute_PQ_dis_tables_dsub2(
-        size_t d, size_t ksub, const float *centroids,
-        size_t nx, const float * x,
+        size_t d,
+        size_t ksub,
+        const float* centroids,
+        size_t nx,
+        const float* x,
         bool is_inner_product,
-        float * dis_tables);
+        float* dis_tables);
+
+/***************************************************************************
+ * Templatized versions of distance functions
+ ***************************************************************************/
 
 } // namespace faiss
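
For orientation, here is a minimal usage sketch of the two most common entry points declared above. It is not part of the gem: it assumes the FAISS headers are on the include path and uses the heap-array result convention from faiss/utils/Heap.h (aggregate order nh, k, ids, val), so treat it as illustrative rather than authoritative.

    #include <faiss/utils/distances.h>
    #include <cstdint>
    #include <vector>

    int main() {
        size_t d = 16, nx = 2, ny = 100, k = 5;
        std::vector<float> x(nx * d, 1.0f); // query vectors
        std::vector<float> y(ny * d, 0.5f); // database vectors

        // single-pair squared L2 distance
        float d2 = faiss::fvec_L2sqr(x.data(), y.data(), d);
        (void)d2;

        // k-NN search: res supplies k and the output buffers,
        // and is sorted on output (see the doc comment above)
        std::vector<int64_t> ids(nx * k);
        std::vector<float> vals(nx * k);
        faiss::float_maxheap_array_t res = {nx, k, ids.data(), vals.data()};
        faiss::knn_L2sqr(x.data(), y.data(), d, nx, ny, &res);
        return 0;
    }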
data/vendor/faiss/faiss/utils/distances_simd.cpp
@@ -9,13 +9,14 @@
 
 #include <faiss/utils/distances.h>
 
-#include <cstdio>
 #include <cassert>
-#include <cstring>
 #include <cmath>
+#include <cstdio>
+#include <cstring>
 
-#include <faiss/utils/simdlib.h>
 #include <faiss/impl/FaissAssert.h>
+#include <faiss/impl/platform_macros.h>
+#include <faiss/utils/simdlib.h>
 
 #ifdef __SSE3__
 #include <immintrin.h>
@@ -25,19 +26,16 @@
 #include <arm_neon.h>
 #endif
 
-
 namespace faiss {
 
 #ifdef __AVX__
 #define USE_AVX
 #endif
 
-
 /*********************************************************
  * Optimized distance computations
  *********************************************************/
 
-
 /* Functions to compute:
    - L2 distance between 2 vectors
    - inner product between 2 vectors
@@ -53,29 +51,21 @@ namespace faiss {
 
 */
 
-
 /*********************************************************
  * Reference implementations
  */
 
-
-float fvec_L2sqr_ref (const float * x,
-                     const float * y,
-                     size_t d)
-{
+float fvec_L2sqr_ref(const float* x, const float* y, size_t d) {
     size_t i;
     float res = 0;
     for (i = 0; i < d; i++) {
         const float tmp = x[i] - y[i];
-       res += tmp * tmp;
+        res += tmp * tmp;
     }
     return res;
 }
 
-float fvec_L1_ref (const float * x,
-                  const float * y,
-                  size_t d)
-{
+float fvec_L1_ref(const float* x, const float* y, size_t d) {
     size_t i;
     float res = 0;
     for (i = 0; i < d; i++) {
@@ -85,56 +75,49 @@ float fvec_L1_ref (const float * x,
     return res;
 }
 
-float fvec_Linf_ref (const float * x,
-                  const float * y,
-                  size_t d)
-{
+float fvec_Linf_ref(const float* x, const float* y, size_t d) {
     size_t i;
     float res = 0;
     for (i = 0; i < d; i++) {
-      res = fmax(res, fabs(x[i] - y[i]));
+        res = fmax(res, fabs(x[i] - y[i]));
     }
     return res;
 }
 
-float fvec_inner_product_ref (const float * x,
-                             const float * y,
-                             size_t d)
-{
+float fvec_inner_product_ref(const float* x, const float* y, size_t d) {
     size_t i;
     float res = 0;
    for (i = 0; i < d; i++)
-       res += x[i] * y[i];
+        res += x[i] * y[i];
     return res;
 }
 
-float fvec_norm_L2sqr_ref (const float *x, size_t d)
-{
+float fvec_norm_L2sqr_ref(const float* x, size_t d) {
     size_t i;
     double res = 0;
     for (i = 0; i < d; i++)
-       res += x[i] * x[i];
+        res += x[i] * x[i];
     return res;
 }
 
-
-void fvec_L2sqr_ny_ref (float * dis,
-                    const float * x,
-                    const float * y,
-                    size_t d, size_t ny)
-{
+void fvec_L2sqr_ny_ref(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny) {
     for (size_t i = 0; i < ny; i++) {
-        dis[i] = fvec_L2sqr (x, y, d);
+        dis[i] = fvec_L2sqr(x, y, d);
         y += d;
     }
 }
 
-
-void fvec_inner_products_ny_ref (float * ip,
-                    const float * x,
-                    const float * y,
-                    size_t d, size_t ny)
-{
+void fvec_inner_products_ny_ref(
+        float* ip,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny) {
     // BLAS slower for the use cases here
 #if 0
     {
@@ -146,15 +129,11 @@ void fvec_inner_products_ny_ref (float * ip,
     }
 #endif
     for (size_t i = 0; i < ny; i++) {
-        ip[i] = fvec_inner_product (x, y, d);
+        ip[i] = fvec_inner_product(x, y, d);
         y += d;
     }
 }
 
-
-
-
-
 /*********************************************************
  * SSE and AVX implementations
  */
@@ -162,40 +141,38 @@ void fvec_inner_products_ny_ref (float * ip,
 #ifdef __SSE3__
 
 // reads 0 <= d < 4 floats as __m128
-static inline __m128 masked_read (int d, const float *x)
-{
-    assert (0 <= d && d < 4);
-    __attribute__((__aligned__(16))) float buf[4] = {0, 0, 0, 0};
+static inline __m128 masked_read(int d, const float* x) {
+    assert(0 <= d && d < 4);
+    ALIGNED(16) float buf[4] = {0, 0, 0, 0};
     switch (d) {
-      case 3:
-        buf[2] = x[2];
-      case 2:
-        buf[1] = x[1];
-      case 1:
-        buf[0] = x[0];
-    }
-    return _mm_load_ps (buf);
+        case 3:
+            buf[2] = x[2];
+        case 2:
+            buf[1] = x[1];
+        case 1:
+            buf[0] = x[0];
+    }
+    return _mm_load_ps(buf);
     // cannot use AVX2 _mm_mask_set1_epi32
 }
 
-float fvec_norm_L2sqr (const float * x,
-                      size_t d)
-{
+float fvec_norm_L2sqr(const float* x, size_t d) {
     __m128 mx;
     __m128 msum1 = _mm_setzero_ps();
 
     while (d >= 4) {
-        mx = _mm_loadu_ps (x); x += 4;
-        msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, mx));
+        mx = _mm_loadu_ps(x);
+        x += 4;
+        msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, mx));
         d -= 4;
     }
 
-    mx = masked_read (d, x);
-    msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, mx));
+    mx = masked_read(d, x);
+    msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, mx));
 
-    msum1 = _mm_hadd_ps (msum1, msum1);
-    msum1 = _mm_hadd_ps (msum1, msum1);
-    return _mm_cvtss_f32 (msum1);
+    msum1 = _mm_hadd_ps(msum1, msum1);
+    msum1 = _mm_hadd_ps(msum1, msum1);
+    return _mm_cvtss_f32(msum1);
 }
 
 namespace {
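
A note on the masked_read helper in the hunk above: it loads the trailing 1 to 3 floats of a vector whose dimension is not a multiple of 4 without reading past the end of the buffer, by copying them into a zero-initialized 16-byte-aligned array and issuing one aligned load (the switch falls through deliberately). A hypothetical scalar equivalent of the same contract, for illustration only:

    // Scalar sketch of masked_read's contract (illustrative, not in FAISS):
    // produces 4 floats, the first d taken from x (0 <= d < 4), the rest zero.
    void masked_read_scalar(int d, const float* x, float out[4]) {
        for (int j = 0; j < 4; j++)
            out[j] = (j < d) ? x[j] : 0.0f;
    }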
@@ -204,586 +181,588 @@ namespace {
204
181
  /// to compute L2 distances. ElementOp can then be used in the fvec_op_ny
205
182
  /// functions below
206
183
  struct ElementOpL2 {
207
-
208
- static float op (float x, float y) {
184
+ static float op(float x, float y) {
209
185
  float tmp = x - y;
210
186
  return tmp * tmp;
211
187
  }
212
188
 
213
- static __m128 op (__m128 x, __m128 y) {
214
- __m128 tmp = x - y;
215
- return tmp * tmp;
189
+ static __m128 op(__m128 x, __m128 y) {
190
+ __m128 tmp = _mm_sub_ps(x, y);
191
+ return _mm_mul_ps(tmp, tmp);
216
192
  }
217
-
218
193
  };
219
194
 
220
195
  /// Function that does a component-wise operation between x and y
221
196
  /// to compute inner products
222
197
  struct ElementOpIP {
223
-
224
- static float op (float x, float y) {
198
+ static float op(float x, float y) {
225
199
  return x * y;
226
200
  }
227
201
 
228
- static __m128 op (__m128 x, __m128 y) {
229
- return x * y;
202
+ static __m128 op(__m128 x, __m128 y) {
203
+ return _mm_mul_ps(x, y);
230
204
  }
231
-
232
205
  };
233
206
 
234
- template<class ElementOp>
235
- void fvec_op_ny_D1 (float * dis, const float * x,
236
- const float * y, size_t ny)
237
- {
207
+ template <class ElementOp>
208
+ void fvec_op_ny_D1(float* dis, const float* x, const float* y, size_t ny) {
238
209
  float x0s = x[0];
239
- __m128 x0 = _mm_set_ps (x0s, x0s, x0s, x0s);
210
+ __m128 x0 = _mm_set_ps(x0s, x0s, x0s, x0s);
240
211
 
241
212
  size_t i;
242
213
  for (i = 0; i + 3 < ny; i += 4) {
243
- __m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
244
- dis[i] = _mm_cvtss_f32 (accu);
245
- __m128 tmp = _mm_shuffle_ps (accu, accu, 1);
246
- dis[i + 1] = _mm_cvtss_f32 (tmp);
247
- tmp = _mm_shuffle_ps (accu, accu, 2);
248
- dis[i + 2] = _mm_cvtss_f32 (tmp);
249
- tmp = _mm_shuffle_ps (accu, accu, 3);
250
- dis[i + 3] = _mm_cvtss_f32 (tmp);
214
+ __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
215
+ y += 4;
216
+ dis[i] = _mm_cvtss_f32(accu);
217
+ __m128 tmp = _mm_shuffle_ps(accu, accu, 1);
218
+ dis[i + 1] = _mm_cvtss_f32(tmp);
219
+ tmp = _mm_shuffle_ps(accu, accu, 2);
220
+ dis[i + 2] = _mm_cvtss_f32(tmp);
221
+ tmp = _mm_shuffle_ps(accu, accu, 3);
222
+ dis[i + 3] = _mm_cvtss_f32(tmp);
251
223
  }
252
224
  while (i < ny) { // handle non-multiple-of-4 case
253
225
  dis[i++] = ElementOp::op(x0s, *y++);
254
226
  }
255
227
  }
256
228
 
257
- template<class ElementOp>
258
- void fvec_op_ny_D2 (float * dis, const float * x,
259
- const float * y, size_t ny)
260
- {
261
- __m128 x0 = _mm_set_ps (x[1], x[0], x[1], x[0]);
229
+ template <class ElementOp>
230
+ void fvec_op_ny_D2(float* dis, const float* x, const float* y, size_t ny) {
231
+ __m128 x0 = _mm_set_ps(x[1], x[0], x[1], x[0]);
262
232
 
263
233
  size_t i;
264
234
  for (i = 0; i + 1 < ny; i += 2) {
265
- __m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
266
- accu = _mm_hadd_ps (accu, accu);
267
- dis[i] = _mm_cvtss_f32 (accu);
268
- accu = _mm_shuffle_ps (accu, accu, 3);
269
- dis[i + 1] = _mm_cvtss_f32 (accu);
235
+ __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
236
+ y += 4;
237
+ accu = _mm_hadd_ps(accu, accu);
238
+ dis[i] = _mm_cvtss_f32(accu);
239
+ accu = _mm_shuffle_ps(accu, accu, 3);
240
+ dis[i + 1] = _mm_cvtss_f32(accu);
270
241
  }
271
242
  if (i < ny) { // handle odd case
272
243
  dis[i] = ElementOp::op(x[0], y[0]) + ElementOp::op(x[1], y[1]);
273
244
  }
274
245
  }
275
246
 
276
-
277
-
278
- template<class ElementOp>
279
- void fvec_op_ny_D4 (float * dis, const float * x,
280
- const float * y, size_t ny)
281
- {
247
+ template <class ElementOp>
248
+ void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
282
249
  __m128 x0 = _mm_loadu_ps(x);
283
250
 
284
251
  for (size_t i = 0; i < ny; i++) {
285
- __m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
286
- accu = _mm_hadd_ps (accu, accu);
287
- accu = _mm_hadd_ps (accu, accu);
288
- dis[i] = _mm_cvtss_f32 (accu);
252
+ __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
253
+ y += 4;
254
+ accu = _mm_hadd_ps(accu, accu);
255
+ accu = _mm_hadd_ps(accu, accu);
256
+ dis[i] = _mm_cvtss_f32(accu);
289
257
  }
290
258
  }
291
259
 
292
- template<class ElementOp>
293
- void fvec_op_ny_D8 (float * dis, const float * x,
294
- const float * y, size_t ny)
295
- {
260
+ template <class ElementOp>
261
+ void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) {
296
262
  __m128 x0 = _mm_loadu_ps(x);
297
263
  __m128 x1 = _mm_loadu_ps(x + 4);
298
264
 
299
265
  for (size_t i = 0; i < ny; i++) {
300
- __m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
301
- accu += ElementOp::op(x1, _mm_loadu_ps (y)); y += 4;
302
- accu = _mm_hadd_ps (accu, accu);
303
- accu = _mm_hadd_ps (accu, accu);
304
- dis[i] = _mm_cvtss_f32 (accu);
266
+ __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
267
+ y += 4;
268
+ accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y)));
269
+ y += 4;
270
+ accu = _mm_hadd_ps(accu, accu);
271
+ accu = _mm_hadd_ps(accu, accu);
272
+ dis[i] = _mm_cvtss_f32(accu);
305
273
  }
306
274
  }
307
275
 
308
- template<class ElementOp>
309
- void fvec_op_ny_D12 (float * dis, const float * x,
310
- const float * y, size_t ny)
311
- {
276
+ template <class ElementOp>
277
+ void fvec_op_ny_D12(float* dis, const float* x, const float* y, size_t ny) {
312
278
  __m128 x0 = _mm_loadu_ps(x);
313
279
  __m128 x1 = _mm_loadu_ps(x + 4);
314
280
  __m128 x2 = _mm_loadu_ps(x + 8);
315
281
 
316
282
  for (size_t i = 0; i < ny; i++) {
317
- __m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
318
- accu += ElementOp::op(x1, _mm_loadu_ps (y)); y += 4;
319
- accu += ElementOp::op(x2, _mm_loadu_ps (y)); y += 4;
320
- accu = _mm_hadd_ps (accu, accu);
321
- accu = _mm_hadd_ps (accu, accu);
322
- dis[i] = _mm_cvtss_f32 (accu);
283
+ __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
284
+ y += 4;
285
+ accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y)));
286
+ y += 4;
287
+ accu = _mm_add_ps(accu, ElementOp::op(x2, _mm_loadu_ps(y)));
288
+ y += 4;
289
+ accu = _mm_hadd_ps(accu, accu);
290
+ accu = _mm_hadd_ps(accu, accu);
291
+ dis[i] = _mm_cvtss_f32(accu);
323
292
  }
324
293
  }
325
294
 
326
-
327
-
328
295
  } // anonymous namespace
329
296
 
330
- void fvec_L2sqr_ny (float * dis, const float * x,
331
- const float * y, size_t d, size_t ny) {
297
+ void fvec_L2sqr_ny(
298
+ float* dis,
299
+ const float* x,
300
+ const float* y,
301
+ size_t d,
302
+ size_t ny) {
332
303
  // optimized for a few special cases
333
304
 
334
- #define DISPATCH(dval) \
335
- case dval:\
336
- fvec_op_ny_D ## dval <ElementOpL2> (dis, x, y, ny); \
305
+ #define DISPATCH(dval) \
306
+ case dval: \
307
+ fvec_op_ny_D##dval<ElementOpL2>(dis, x, y, ny); \
337
308
  return;
338
309
 
339
- switch(d) {
310
+ switch (d) {
340
311
  DISPATCH(1)
341
312
  DISPATCH(2)
342
313
  DISPATCH(4)
343
314
  DISPATCH(8)
344
315
  DISPATCH(12)
345
- default:
346
- fvec_L2sqr_ny_ref (dis, x, y, d, ny);
347
- return;
316
+ default:
317
+ fvec_L2sqr_ny_ref(dis, x, y, d, ny);
318
+ return;
348
319
  }
349
320
  #undef DISPATCH
350
-
351
321
  }
352
322
 
353
- void fvec_inner_products_ny (float * dis, const float * x,
354
- const float * y, size_t d, size_t ny) {
355
-
356
- #define DISPATCH(dval) \
357
- case dval:\
358
- fvec_op_ny_D ## dval <ElementOpIP> (dis, x, y, ny); \
323
+ void fvec_inner_products_ny(
324
+ float* dis,
325
+ const float* x,
326
+ const float* y,
327
+ size_t d,
328
+ size_t ny) {
329
+ #define DISPATCH(dval) \
330
+ case dval: \
331
+ fvec_op_ny_D##dval<ElementOpIP>(dis, x, y, ny); \
359
332
  return;
360
333
 
361
- switch(d) {
334
+ switch (d) {
362
335
  DISPATCH(1)
363
336
  DISPATCH(2)
364
337
  DISPATCH(4)
365
338
  DISPATCH(8)
366
339
  DISPATCH(12)
367
- default:
368
- fvec_inner_products_ny_ref (dis, x, y, d, ny);
369
- return;
340
+ default:
341
+ fvec_inner_products_ny_ref(dis, x, y, d, ny);
342
+ return;
370
343
  }
371
344
  #undef DISPATCH
372
-
373
345
  }
374
346
 
375
-
376
-
377
347
  #endif
378
348
 
379
349
  #ifdef USE_AVX
380
350
 
381
351
  // reads 0 <= d < 8 floats as __m256
382
- static inline __m256 masked_read_8 (int d, const float *x)
383
- {
384
- assert (0 <= d && d < 8);
352
+ static inline __m256 masked_read_8(int d, const float* x) {
353
+ assert(0 <= d && d < 8);
385
354
  if (d < 4) {
386
- __m256 res = _mm256_setzero_ps ();
387
- res = _mm256_insertf128_ps (res, masked_read (d, x), 0);
355
+ __m256 res = _mm256_setzero_ps();
356
+ res = _mm256_insertf128_ps(res, masked_read(d, x), 0);
388
357
  return res;
389
358
  } else {
390
- __m256 res = _mm256_setzero_ps ();
391
- res = _mm256_insertf128_ps (res, _mm_loadu_ps (x), 0);
392
- res = _mm256_insertf128_ps (res, masked_read (d - 4, x + 4), 1);
359
+ __m256 res = _mm256_setzero_ps();
360
+ res = _mm256_insertf128_ps(res, _mm_loadu_ps(x), 0);
361
+ res = _mm256_insertf128_ps(res, masked_read(d - 4, x + 4), 1);
393
362
  return res;
394
363
  }
395
364
  }
396
365
 
397
- float fvec_inner_product (const float * x,
398
- const float * y,
399
- size_t d)
400
- {
366
+ float fvec_inner_product(const float* x, const float* y, size_t d) {
401
367
  __m256 msum1 = _mm256_setzero_ps();
402
368
 
403
369
  while (d >= 8) {
404
- __m256 mx = _mm256_loadu_ps (x); x += 8;
405
- __m256 my = _mm256_loadu_ps (y); y += 8;
406
- msum1 = _mm256_add_ps (msum1, _mm256_mul_ps (mx, my));
370
+ __m256 mx = _mm256_loadu_ps(x);
371
+ x += 8;
372
+ __m256 my = _mm256_loadu_ps(y);
373
+ y += 8;
374
+ msum1 = _mm256_add_ps(msum1, _mm256_mul_ps(mx, my));
407
375
  d -= 8;
408
376
  }
409
377
 
410
378
  __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
411
- msum2 += _mm256_extractf128_ps(msum1, 0);
379
+ msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));
412
380
 
413
381
  if (d >= 4) {
414
- __m128 mx = _mm_loadu_ps (x); x += 4;
415
- __m128 my = _mm_loadu_ps (y); y += 4;
416
- msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my));
382
+ __m128 mx = _mm_loadu_ps(x);
383
+ x += 4;
384
+ __m128 my = _mm_loadu_ps(y);
385
+ y += 4;
386
+ msum2 = _mm_add_ps(msum2, _mm_mul_ps(mx, my));
417
387
  d -= 4;
418
388
  }
419
389
 
420
390
  if (d > 0) {
421
- __m128 mx = masked_read (d, x);
422
- __m128 my = masked_read (d, y);
423
- msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my));
391
+ __m128 mx = masked_read(d, x);
392
+ __m128 my = masked_read(d, y);
393
+ msum2 = _mm_add_ps(msum2, _mm_mul_ps(mx, my));
424
394
  }
425
395
 
426
- msum2 = _mm_hadd_ps (msum2, msum2);
427
- msum2 = _mm_hadd_ps (msum2, msum2);
428
- return _mm_cvtss_f32 (msum2);
396
+ msum2 = _mm_hadd_ps(msum2, msum2);
397
+ msum2 = _mm_hadd_ps(msum2, msum2);
398
+ return _mm_cvtss_f32(msum2);
429
399
  }
430
400
 
431
- float fvec_L2sqr (const float * x,
432
- const float * y,
433
- size_t d)
434
- {
401
+ float fvec_L2sqr(const float* x, const float* y, size_t d) {
435
402
  __m256 msum1 = _mm256_setzero_ps();
436
403
 
437
404
  while (d >= 8) {
438
- __m256 mx = _mm256_loadu_ps (x); x += 8;
439
- __m256 my = _mm256_loadu_ps (y); y += 8;
440
- const __m256 a_m_b1 = mx - my;
441
- msum1 += a_m_b1 * a_m_b1;
405
+ __m256 mx = _mm256_loadu_ps(x);
406
+ x += 8;
407
+ __m256 my = _mm256_loadu_ps(y);
408
+ y += 8;
409
+ const __m256 a_m_b1 = _mm256_sub_ps(mx, my);
410
+ msum1 = _mm256_add_ps(msum1, _mm256_mul_ps(a_m_b1, a_m_b1));
442
411
  d -= 8;
443
412
  }
444
413
 
445
414
  __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
446
- msum2 += _mm256_extractf128_ps(msum1, 0);
415
+ msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));
447
416
 
448
417
  if (d >= 4) {
449
- __m128 mx = _mm_loadu_ps (x); x += 4;
450
- __m128 my = _mm_loadu_ps (y); y += 4;
451
- const __m128 a_m_b1 = mx - my;
452
- msum2 += a_m_b1 * a_m_b1;
418
+ __m128 mx = _mm_loadu_ps(x);
419
+ x += 4;
420
+ __m128 my = _mm_loadu_ps(y);
421
+ y += 4;
422
+ const __m128 a_m_b1 = _mm_sub_ps(mx, my);
423
+ msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1));
453
424
  d -= 4;
454
425
  }
455
426
 
456
427
  if (d > 0) {
457
- __m128 mx = masked_read (d, x);
458
- __m128 my = masked_read (d, y);
459
- __m128 a_m_b1 = mx - my;
460
- msum2 += a_m_b1 * a_m_b1;
428
+ __m128 mx = masked_read(d, x);
429
+ __m128 my = masked_read(d, y);
430
+ __m128 a_m_b1 = _mm_sub_ps(mx, my);
431
+ msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1));
461
432
  }
462
433
 
463
- msum2 = _mm_hadd_ps (msum2, msum2);
464
- msum2 = _mm_hadd_ps (msum2, msum2);
465
- return _mm_cvtss_f32 (msum2);
434
+ msum2 = _mm_hadd_ps(msum2, msum2);
435
+ msum2 = _mm_hadd_ps(msum2, msum2);
436
+ return _mm_cvtss_f32(msum2);
466
437
  }
467
438
 
468
- float fvec_L1 (const float * x, const float * y, size_t d)
469
- {
439
+ float fvec_L1(const float* x, const float* y, size_t d) {
470
440
  __m256 msum1 = _mm256_setzero_ps();
471
- __m256 signmask = __m256(_mm256_set1_epi32 (0x7fffffffUL));
441
+ __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL));
472
442
 
473
443
  while (d >= 8) {
474
- __m256 mx = _mm256_loadu_ps (x); x += 8;
475
- __m256 my = _mm256_loadu_ps (y); y += 8;
476
- const __m256 a_m_b = mx - my;
477
- msum1 += _mm256_and_ps(signmask, a_m_b);
444
+ __m256 mx = _mm256_loadu_ps(x);
445
+ x += 8;
446
+ __m256 my = _mm256_loadu_ps(y);
447
+ y += 8;
448
+ const __m256 a_m_b = _mm256_sub_ps(mx, my);
449
+ msum1 = _mm256_add_ps(msum1, _mm256_and_ps(signmask, a_m_b));
478
450
  d -= 8;
479
451
  }
480
452
 
481
453
  __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
482
- msum2 += _mm256_extractf128_ps(msum1, 0);
483
- __m128 signmask2 = __m128(_mm_set1_epi32 (0x7fffffffUL));
454
+ msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));
455
+ __m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL));
484
456
 
485
457
  if (d >= 4) {
486
- __m128 mx = _mm_loadu_ps (x); x += 4;
487
- __m128 my = _mm_loadu_ps (y); y += 4;
488
- const __m128 a_m_b = mx - my;
489
- msum2 += _mm_and_ps(signmask2, a_m_b);
458
+ __m128 mx = _mm_loadu_ps(x);
459
+ x += 4;
460
+ __m128 my = _mm_loadu_ps(y);
461
+ y += 4;
462
+ const __m128 a_m_b = _mm_sub_ps(mx, my);
463
+ msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b));
490
464
  d -= 4;
491
465
  }
492
466
 
493
467
  if (d > 0) {
494
- __m128 mx = masked_read (d, x);
495
- __m128 my = masked_read (d, y);
496
- __m128 a_m_b = mx - my;
497
- msum2 += _mm_and_ps(signmask2, a_m_b);
468
+ __m128 mx = masked_read(d, x);
469
+ __m128 my = masked_read(d, y);
470
+ __m128 a_m_b = _mm_sub_ps(mx, my);
471
+ msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b));
498
472
  }
499
473
 
500
- msum2 = _mm_hadd_ps (msum2, msum2);
501
- msum2 = _mm_hadd_ps (msum2, msum2);
502
- return _mm_cvtss_f32 (msum2);
474
+ msum2 = _mm_hadd_ps(msum2, msum2);
475
+ msum2 = _mm_hadd_ps(msum2, msum2);
476
+ return _mm_cvtss_f32(msum2);
503
477
  }
504
478
 
505
- float fvec_Linf (const float * x, const float * y, size_t d)
506
- {
479
+ float fvec_Linf(const float* x, const float* y, size_t d) {
507
480
  __m256 msum1 = _mm256_setzero_ps();
508
- __m256 signmask = __m256(_mm256_set1_epi32 (0x7fffffffUL));
481
+ __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL));
509
482
 
510
483
  while (d >= 8) {
511
- __m256 mx = _mm256_loadu_ps (x); x += 8;
512
- __m256 my = _mm256_loadu_ps (y); y += 8;
513
- const __m256 a_m_b = mx - my;
484
+ __m256 mx = _mm256_loadu_ps(x);
485
+ x += 8;
486
+ __m256 my = _mm256_loadu_ps(y);
487
+ y += 8;
488
+ const __m256 a_m_b = _mm256_sub_ps(mx, my);
514
489
  msum1 = _mm256_max_ps(msum1, _mm256_and_ps(signmask, a_m_b));
515
490
  d -= 8;
516
491
  }
517
492
 
518
493
  __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
519
- msum2 = _mm_max_ps (msum2, _mm256_extractf128_ps(msum1, 0));
520
- __m128 signmask2 = __m128(_mm_set1_epi32 (0x7fffffffUL));
494
+ msum2 = _mm_max_ps(msum2, _mm256_extractf128_ps(msum1, 0));
495
+ __m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL));
521
496
 
522
497
  if (d >= 4) {
523
- __m128 mx = _mm_loadu_ps (x); x += 4;
524
- __m128 my = _mm_loadu_ps (y); y += 4;
525
- const __m128 a_m_b = mx - my;
498
+ __m128 mx = _mm_loadu_ps(x);
499
+ x += 4;
500
+ __m128 my = _mm_loadu_ps(y);
501
+ y += 4;
502
+ const __m128 a_m_b = _mm_sub_ps(mx, my);
526
503
  msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b));
527
504
  d -= 4;
528
505
  }
529
506
 
530
507
  if (d > 0) {
531
- __m128 mx = masked_read (d, x);
532
- __m128 my = masked_read (d, y);
533
- __m128 a_m_b = mx - my;
508
+ __m128 mx = masked_read(d, x);
509
+ __m128 my = masked_read(d, y);
510
+ __m128 a_m_b = _mm_sub_ps(mx, my);
534
511
  msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b));
535
512
  }
536
513
 
537
514
  msum2 = _mm_max_ps(_mm_movehl_ps(msum2, msum2), msum2);
538
- msum2 = _mm_max_ps(msum2, _mm_shuffle_ps (msum2, msum2, 1));
539
- return _mm_cvtss_f32 (msum2);
515
+ msum2 = _mm_max_ps(msum2, _mm_shuffle_ps(msum2, msum2, 1));
516
+ return _mm_cvtss_f32(msum2);
540
517
  }
541
518
 
542
519
  #elif defined(__SSE3__) // But not AVX
543
520
 
544
- float fvec_L1 (const float * x, const float * y, size_t d)
545
- {
546
- return fvec_L1_ref (x, y, d);
521
+ float fvec_L1(const float* x, const float* y, size_t d) {
522
+ return fvec_L1_ref(x, y, d);
547
523
  }
548
524
 
549
- float fvec_Linf (const float * x, const float * y, size_t d)
550
- {
551
- return fvec_Linf_ref (x, y, d);
525
+ float fvec_Linf(const float* x, const float* y, size_t d) {
526
+ return fvec_Linf_ref(x, y, d);
552
527
  }
553
528
 
554
-
555
- float fvec_L2sqr (const float * x,
556
- const float * y,
557
- size_t d)
558
- {
529
+ float fvec_L2sqr(const float* x, const float* y, size_t d) {
559
530
  __m128 msum1 = _mm_setzero_ps();
560
531
 
561
532
  while (d >= 4) {
562
- __m128 mx = _mm_loadu_ps (x); x += 4;
563
- __m128 my = _mm_loadu_ps (y); y += 4;
564
- const __m128 a_m_b1 = mx - my;
565
- msum1 += a_m_b1 * a_m_b1;
533
+ __m128 mx = _mm_loadu_ps(x);
534
+ x += 4;
535
+ __m128 my = _mm_loadu_ps(y);
536
+ y += 4;
537
+ const __m128 a_m_b1 = _mm_sub_ps(mx, my);
538
+ msum1 = _mm_add_ps(msum1, _mm_mul_ps(a_m_b1, a_m_b1));
566
539
  d -= 4;
567
540
  }
568
541
 
569
542
  if (d > 0) {
570
543
  // add the last 1, 2 or 3 values
571
- __m128 mx = masked_read (d, x);
572
- __m128 my = masked_read (d, y);
573
- __m128 a_m_b1 = mx - my;
574
- msum1 += a_m_b1 * a_m_b1;
544
+ __m128 mx = masked_read(d, x);
545
+ __m128 my = masked_read(d, y);
546
+ __m128 a_m_b1 = _mm_sub_ps(mx, my);
547
+ msum1 = _mm_add_ps(msum1, _mm_mul_ps(a_m_b1, a_m_b1));
575
548
  }
576
549
 
577
- msum1 = _mm_hadd_ps (msum1, msum1);
578
- msum1 = _mm_hadd_ps (msum1, msum1);
579
- return _mm_cvtss_f32 (msum1);
550
+ msum1 = _mm_hadd_ps(msum1, msum1);
551
+ msum1 = _mm_hadd_ps(msum1, msum1);
552
+ return _mm_cvtss_f32(msum1);
580
553
  }
581
554
 
582
-
583
- float fvec_inner_product (const float * x,
584
- const float * y,
585
- size_t d)
586
- {
555
+ float fvec_inner_product(const float* x, const float* y, size_t d) {
587
556
  __m128 mx, my;
588
557
  __m128 msum1 = _mm_setzero_ps();
589
558
 
590
559
  while (d >= 4) {
591
- mx = _mm_loadu_ps (x); x += 4;
592
- my = _mm_loadu_ps (y); y += 4;
593
- msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, my));
560
+ mx = _mm_loadu_ps(x);
561
+ x += 4;
562
+ my = _mm_loadu_ps(y);
563
+ y += 4;
564
+ msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, my));
594
565
  d -= 4;
595
566
  }
596
567
 
597
568
  // add the last 1, 2, or 3 values
598
- mx = masked_read (d, x);
599
- my = masked_read (d, y);
600
- __m128 prod = _mm_mul_ps (mx, my);
569
+ mx = masked_read(d, x);
570
+ my = masked_read(d, y);
571
+ __m128 prod = _mm_mul_ps(mx, my);
601
572
 
602
- msum1 = _mm_add_ps (msum1, prod);
573
+ msum1 = _mm_add_ps(msum1, prod);
603
574
 
604
- msum1 = _mm_hadd_ps (msum1, msum1);
605
- msum1 = _mm_hadd_ps (msum1, msum1);
606
- return _mm_cvtss_f32 (msum1);
575
+ msum1 = _mm_hadd_ps(msum1, msum1);
576
+ msum1 = _mm_hadd_ps(msum1, msum1);
577
+ return _mm_cvtss_f32(msum1);
607
578
  }
608
579
 
609
580
  #elif defined(__aarch64__)
610
581
 
611
-
612
- float fvec_L2sqr (const float * x,
613
- const float * y,
614
- size_t d)
615
- {
616
- if (d & 3) return fvec_L2sqr_ref (x, y, d);
617
- float32x4_t accu = vdupq_n_f32 (0);
618
- for (size_t i = 0; i < d; i += 4) {
619
- float32x4_t xi = vld1q_f32 (x + i);
620
- float32x4_t yi = vld1q_f32 (y + i);
621
- float32x4_t sq = vsubq_f32 (xi, yi);
622
- accu = vfmaq_f32 (accu, sq, sq);
582
+ float fvec_L2sqr(const float* x, const float* y, size_t d) {
583
+ float32x4_t accux4 = vdupq_n_f32(0);
584
+ const size_t d_simd = d - (d & 3);
585
+ size_t i;
586
+ for (i = 0; i < d_simd; i += 4) {
587
+ float32x4_t xi = vld1q_f32(x + i);
588
+ float32x4_t yi = vld1q_f32(y + i);
589
+ float32x4_t sq = vsubq_f32(xi, yi);
590
+ accux4 = vfmaq_f32(accux4, sq, sq);
591
+ }
592
+ float32x4_t accux2 = vpaddq_f32(accux4, accux4);
593
+ float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
594
+ for (; i < d; ++i) {
595
+ float32_t xi = x[i];
596
+ float32_t yi = y[i];
597
+ float32_t sq = xi - yi;
598
+ accux1 += sq * sq;
599
+ }
600
+ return accux1;
601
+ }
602
+
603
+ float fvec_inner_product(const float* x, const float* y, size_t d) {
604
+ float32x4_t accux4 = vdupq_n_f32(0);
605
+ const size_t d_simd = d - (d & 3);
606
+ size_t i;
607
+ for (i = 0; i < d_simd; i += 4) {
608
+ float32x4_t xi = vld1q_f32(x + i);
609
+ float32x4_t yi = vld1q_f32(y + i);
610
+ accux4 = vfmaq_f32(accux4, xi, yi);
623
611
  }
624
- float32x4_t a2 = vpaddq_f32 (accu, accu);
625
- return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
626
- }
627
-
628
- float fvec_inner_product (const float * x,
629
- const float * y,
630
- size_t d)
631
- {
632
- if (d & 3) return fvec_inner_product_ref (x, y, d);
633
- float32x4_t accu = vdupq_n_f32 (0);
634
- for (size_t i = 0; i < d; i += 4) {
635
- float32x4_t xi = vld1q_f32 (x + i);
636
- float32x4_t yi = vld1q_f32 (y + i);
637
- accu = vfmaq_f32 (accu, xi, yi);
612
+ float32x4_t accux2 = vpaddq_f32(accux4, accux4);
613
+ float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
614
+ for (; i < d; ++i) {
615
+ float32_t xi = x[i];
616
+ float32_t yi = y[i];
617
+ accux1 += xi * yi;
638
618
  }
639
- float32x4_t a2 = vpaddq_f32 (accu, accu);
640
- return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
619
+ return accux1;
641
620
  }
642
621
 
643
- float fvec_norm_L2sqr (const float *x, size_t d)
644
- {
645
- if (d & 3) return fvec_norm_L2sqr_ref (x, d);
646
- float32x4_t accu = vdupq_n_f32 (0);
647
- for (size_t i = 0; i < d; i += 4) {
648
- float32x4_t xi = vld1q_f32 (x + i);
649
- accu = vfmaq_f32 (accu, xi, xi);
622
+ float fvec_norm_L2sqr(const float* x, size_t d) {
623
+ float32x4_t accux4 = vdupq_n_f32(0);
624
+ const size_t d_simd = d - (d & 3);
625
+ size_t i;
626
+ for (i = 0; i < d_simd; i += 4) {
627
+ float32x4_t xi = vld1q_f32(x + i);
628
+ accux4 = vfmaq_f32(accux4, xi, xi);
650
629
  }
651
- float32x4_t a2 = vpaddq_f32 (accu, accu);
652
- return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
630
+ float32x4_t accux2 = vpaddq_f32(accux4, accux4);
631
+ float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
632
+ for (; i < d; ++i) {
633
+ float32_t xi = x[i];
634
+ accux1 += xi * xi;
635
+ }
636
+ return accux1;
653
637
  }
654
638
 
655
639
  // not optimized for ARM
656
- void fvec_L2sqr_ny (float * dis, const float * x,
657
- const float * y, size_t d, size_t ny) {
658
- fvec_L2sqr_ny_ref (dis, x, y, d, ny);
640
+ void fvec_L2sqr_ny(
641
+ float* dis,
642
+ const float* x,
643
+ const float* y,
644
+ size_t d,
645
+ size_t ny) {
646
+ fvec_L2sqr_ny_ref(dis, x, y, d, ny);
659
647
  }
660
648
 
661
- float fvec_L1 (const float * x, const float * y, size_t d)
662
- {
663
- return fvec_L1_ref (x, y, d);
649
+ float fvec_L1(const float* x, const float* y, size_t d) {
650
+ return fvec_L1_ref(x, y, d);
664
651
  }
665
652
 
666
- float fvec_Linf (const float * x, const float * y, size_t d)
667
- {
668
- return fvec_Linf_ref (x, y, d);
653
+ float fvec_Linf(const float* x, const float* y, size_t d) {
654
+ return fvec_Linf_ref(x, y, d);
669
655
  }
670
656
 
657
+ void fvec_inner_products_ny(
658
+ float* dis,
659
+ const float* x,
660
+ const float* y,
661
+ size_t d,
662
+ size_t ny) {
663
+ fvec_inner_products_ny_ref(dis, x, y, d, ny);
664
+ }
671
665
 
672
666
  #else
673
667
  // scalar implementation
674
668
 
675
- float fvec_L2sqr (const float * x,
676
- const float * y,
677
- size_t d)
678
- {
679
- return fvec_L2sqr_ref (x, y, d);
669
+ float fvec_L2sqr(const float* x, const float* y, size_t d) {
670
+ return fvec_L2sqr_ref(x, y, d);
680
671
  }
681
672
 
682
- float fvec_L1 (const float * x, const float * y, size_t d)
683
- {
684
- return fvec_L1_ref (x, y, d);
673
+ float fvec_L1(const float* x, const float* y, size_t d) {
674
+ return fvec_L1_ref(x, y, d);
685
675
  }
686
676
 
687
- float fvec_Linf (const float * x, const float * y, size_t d)
688
- {
689
- return fvec_Linf_ref (x, y, d);
677
+ float fvec_Linf(const float* x, const float* y, size_t d) {
678
+ return fvec_Linf_ref(x, y, d);
690
679
  }
691
680
 
692
- float fvec_inner_product (const float * x,
693
- const float * y,
694
- size_t d)
695
- {
696
- return fvec_inner_product_ref (x, y, d);
681
+ float fvec_inner_product(const float* x, const float* y, size_t d) {
682
+ return fvec_inner_product_ref(x, y, d);
697
683
  }
698
684
 
699
- float fvec_norm_L2sqr (const float *x, size_t d)
700
- {
701
- return fvec_norm_L2sqr_ref (x, d);
685
+ float fvec_norm_L2sqr(const float* x, size_t d) {
686
+ return fvec_norm_L2sqr_ref(x, d);
702
687
  }
703
688
 
704
- void fvec_L2sqr_ny (float * dis, const float * x,
705
- const float * y, size_t d, size_t ny) {
706
- fvec_L2sqr_ny_ref (dis, x, y, d, ny);
689
+ void fvec_L2sqr_ny(
690
+ float* dis,
691
+ const float* x,
692
+ const float* y,
693
+ size_t d,
694
+ size_t ny) {
695
+ fvec_L2sqr_ny_ref(dis, x, y, d, ny);
707
696
  }
708
697
 
709
- void fvec_inner_products_ny (float * dis, const float * x,
710
- const float * y, size_t d, size_t ny) {
711
- fvec_inner_products_ny_ref (dis, x, y, d, ny);
698
+ void fvec_inner_products_ny(
699
+ float* dis,
700
+ const float* x,
701
+ const float* y,
702
+ size_t d,
703
+ size_t ny) {
704
+ fvec_inner_products_ny_ref(dis, x, y, d, ny);
712
705
  }
713
706
 
714
-
715
707
  #endif
716
708
 
717
-
718
-
719
-
720
-
721
-
722
-
723
-
724
-
725
-
726
-
727
-
728
-
729
-
730
-
731
-
732
-
733
-
734
-
735
-
736
709
 /***************************************************************************
  * heavily optimized table computations
  ***************************************************************************/
 
-
-static inline void fvec_madd_ref (size_t n, const float *a,
-                                  float bf, const float *b, float *c) {
+static inline void fvec_madd_ref(
+        size_t n,
+        const float* a,
+        float bf,
+        const float* b,
+        float* c) {
     for (size_t i = 0; i < n; i++)
         c[i] = a[i] + bf * b[i];
 }
 
 #ifdef __SSE3__
 
-static inline void fvec_madd_sse (size_t n, const float *a,
-                                  float bf, const float *b, float *c) {
+static inline void fvec_madd_sse(
+        size_t n,
+        const float* a,
+        float bf,
+        const float* b,
+        float* c) {
     n >>= 2;
-    __m128 bf4 = _mm_set_ps1 (bf);
-    __m128 * a4 = (__m128*)a;
-    __m128 * b4 = (__m128*)b;
-    __m128 * c4 = (__m128*)c;
+    __m128 bf4 = _mm_set_ps1(bf);
+    __m128* a4 = (__m128*)a;
+    __m128* b4 = (__m128*)b;
+    __m128* c4 = (__m128*)c;
 
     while (n--) {
-        *c4 = _mm_add_ps (*a4, _mm_mul_ps (bf4, *b4));
+        *c4 = _mm_add_ps(*a4, _mm_mul_ps(bf4, *b4));
         b4++;
         a4++;
         c4++;
     }
 }
 
-void fvec_madd (size_t n, const float *a,
-                float bf, const float *b, float *c)
-{
-    if ((n & 3) == 0 &&
-        ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
-        fvec_madd_sse (n, a, bf, b, c);
+void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
+    if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
+        fvec_madd_sse(n, a, bf, b, c);
     else
-        fvec_madd_ref (n, a, bf, b, c);
+        fvec_madd_ref(n, a, bf, b, c);
 }
 
 #else
 
-void fvec_madd (size_t n, const float *a,
-                float bf, const float *b, float *c)
-{
-    fvec_madd_ref (n, a, bf, b, c);
+void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
+    fvec_madd_ref(n, a, bf, b, c);
 }
 
 #endif
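
fvec_madd computes the fused update c[i] = a[i] + bf * b[i]. Note the dispatch above: the SSE path is taken only when n is a multiple of 4 and a, b, c are all 16-byte aligned; any other input falls back to the reference loop. A sketch:

    #include <faiss/utils/distances.h>
    #include <vector>

    int main() {
        size_t n = 8; // multiple of 4, so the SSE path is eligible
        std::vector<float> a(n, 1.0f), b(n, 4.0f), c(n);
        // c[i] = a[i] + 0.25f * b[i], i.e. 2.0f for every i
        faiss::fvec_madd(n, a.data(), 0.25f, b.data(), c.data());
        return 0;
    }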
 
-static inline int fvec_madd_and_argmin_ref (size_t n, const float *a,
-                                            float bf, const float *b, float *c) {
+static inline int fvec_madd_and_argmin_ref(
+        size_t n,
+        const float* a,
+        float bf,
+        const float* b,
+        float* c) {
     float vmin = 1e20;
     int imin = -1;
 
@@ -799,125 +778,100 @@ static inline int fvec_madd_and_argmin_ref (size_t n, const float *a,
 
 #ifdef __SSE3__
 
-static inline int fvec_madd_and_argmin_sse (
-        size_t n, const float *a,
-        float bf, const float *b, float *c) {
+static inline int fvec_madd_and_argmin_sse(
+        size_t n,
+        const float* a,
+        float bf,
+        const float* b,
+        float* c) {
     n >>= 2;
-    __m128 bf4 = _mm_set_ps1 (bf);
-    __m128 vmin4 = _mm_set_ps1 (1e20);
-    __m128i imin4 = _mm_set1_epi32 (-1);
-    __m128i idx4 = _mm_set_epi32 (3, 2, 1, 0);
-    __m128i inc4 = _mm_set1_epi32 (4);
-    __m128 * a4 = (__m128*)a;
-    __m128 * b4 = (__m128*)b;
-    __m128 * c4 = (__m128*)c;
+    __m128 bf4 = _mm_set_ps1(bf);
+    __m128 vmin4 = _mm_set_ps1(1e20);
+    __m128i imin4 = _mm_set1_epi32(-1);
+    __m128i idx4 = _mm_set_epi32(3, 2, 1, 0);
+    __m128i inc4 = _mm_set1_epi32(4);
+    __m128* a4 = (__m128*)a;
+    __m128* b4 = (__m128*)b;
+    __m128* c4 = (__m128*)c;
 
     while (n--) {
-        __m128 vc4 = _mm_add_ps (*a4, _mm_mul_ps (bf4, *b4));
+        __m128 vc4 = _mm_add_ps(*a4, _mm_mul_ps(bf4, *b4));
         *c4 = vc4;
-        __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
+        __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4));
         // imin4 = _mm_blendv_epi8 (imin4, idx4, mask); // slower!
 
-        imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
-                              _mm_andnot_si128 (mask, imin4));
-        vmin4 = _mm_min_ps (vmin4, vc4);
+        imin4 = _mm_or_si128(
+                _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4));
+        vmin4 = _mm_min_ps(vmin4, vc4);
         b4++;
         a4++;
         c4++;
-        idx4 = _mm_add_epi32 (idx4, inc4);
+        idx4 = _mm_add_epi32(idx4, inc4);
     }
 
     // 4 values -> 2
     {
-        idx4 = _mm_shuffle_epi32 (imin4, 3 << 2 | 2);
-        __m128 vc4 = _mm_shuffle_ps (vmin4, vmin4, 3 << 2 | 2);
-        __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
-        imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
-                              _mm_andnot_si128 (mask, imin4));
-        vmin4 = _mm_min_ps (vmin4, vc4);
+        idx4 = _mm_shuffle_epi32(imin4, 3 << 2 | 2);
+        __m128 vc4 = _mm_shuffle_ps(vmin4, vmin4, 3 << 2 | 2);
+        __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4));
+        imin4 = _mm_or_si128(
+                _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4));
+        vmin4 = _mm_min_ps(vmin4, vc4);
     }
     // 2 values -> 1
     {
-        idx4 = _mm_shuffle_epi32 (imin4, 1);
-        __m128 vc4 = _mm_shuffle_ps (vmin4, vmin4, 1);
-        __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
-        imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
-                              _mm_andnot_si128 (mask, imin4));
+        idx4 = _mm_shuffle_epi32(imin4, 1);
+        __m128 vc4 = _mm_shuffle_ps(vmin4, vmin4, 1);
+        __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4));
+        imin4 = _mm_or_si128(
+                _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4));
         // vmin4 = _mm_min_ps (vmin4, vc4);
     }
-    return _mm_cvtsi128_si32 (imin4);
+    return _mm_cvtsi128_si32(imin4);
 }
 
-
-int fvec_madd_and_argmin (size_t n, const float *a,
-                          float bf, const float *b, float *c)
-{
-    if ((n & 3) == 0 &&
-        ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
-        return fvec_madd_and_argmin_sse (n, a, bf, b, c);
+int fvec_madd_and_argmin(
+        size_t n,
+        const float* a,
+        float bf,
+        const float* b,
+        float* c) {
+    if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
+        return fvec_madd_and_argmin_sse(n, a, bf, b, c);
     else
-        return fvec_madd_and_argmin_ref (n, a, bf, b, c);
+        return fvec_madd_and_argmin_ref(n, a, bf, b, c);
 }
 
 #else
 
-int fvec_madd_and_argmin (size_t n, const float *a,
-                          float bf, const float *b, float *c)
-{
-    return fvec_madd_and_argmin_ref (n, a, bf, b, c);
+int fvec_madd_and_argmin(
+        size_t n,
+        const float* a,
+        float bf,
+        const float* b,
+        float* c) {
+    return fvec_madd_and_argmin_ref(n, a, bf, b, c);
 }
 
 #endif
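
fvec_madd_and_argmin performs the same multiply-add but additionally returns the index of the smallest entry of the updated c, saving a second pass when only the argmin is needed. A sketch:

    #include <faiss/utils/distances.h>
    #include <cstdio>
    #include <vector>

    int main() {
        size_t n = 4;
        std::vector<float> a = {3, 2, 1, 4};
        std::vector<float> b = {1, 1, 1, 1};
        std::vector<float> c(n);
        // c = a + 0.5 * b = {3.5, 2.5, 1.5, 4.5}; the minimum is at index 2
        int imin = faiss::fvec_madd_and_argmin(
                n, a.data(), 0.5f, b.data(), c.data());
        printf("imin=%d c[imin]=%g\n", imin, c[imin]);
        return 0;
    }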
 
-
 /***************************************************************************
  * PQ tables computations
  ***************************************************************************/
 
-#ifdef __AVX2__
-
 namespace {
 
-
-// get even float32's of a and b, interleaved
-simd8float32 geteven(simd8float32 a, simd8float32 b) {
-    return simd8float32(
-        _mm256_shuffle_ps(a.f, b.f, 0 << 0 | 2 << 2 | 0 << 4 | 2 << 6)
-    );
-}
-
-// get odd float32's of a and b, interleaved
-simd8float32 getodd(simd8float32 a, simd8float32 b) {
-    return simd8float32(
-        _mm256_shuffle_ps(a.f, b.f, 1 << 0 | 3 << 2 | 1 << 4 | 3 << 6)
-    );
-}
-
-// 3 cycles
-// if the lanes are a = [a0 a1] and b = [b0 b1], return [a0 b0]
-simd8float32 getlow128(simd8float32 a, simd8float32 b) {
-    return simd8float32(
-        _mm256_permute2f128_ps(a.f, b.f, 0 | 2 << 4)
-    );
-}
-
-simd8float32 gethigh128(simd8float32 a, simd8float32 b) {
-    return simd8float32(
-        _mm256_permute2f128_ps(a.f, b.f, 1 | 3 << 4)
-    );
-}
-
 /// compute the IP for dsub = 2 for 8 centroids and 4 sub-vectors at a time
-template<bool is_inner_product>
+template <bool is_inner_product>
 void pq2_8cents_table(
         const simd8float32 centroids[8],
         const simd8float32 x,
-        float *out, size_t ldo, size_t nout = 4
-) {
-
+        float* out,
+        size_t ldo,
+        size_t nout = 4) {
     simd8float32 ips[4];
 
-    for(int i = 0; i < 4; i++) {
+    for (int i = 0; i < 4; i++) {
         simd8float32 p1, p2;
         if (is_inner_product) {
             p1 = x * centroids[2 * i];
@@ -941,21 +895,21 @@ void pq2_8cents_table(
     simd8float32 ip1 = getlow128(ip13a, ip13b);
     simd8float32 ip3 = gethigh128(ip13a, ip13b);
 
-    switch(nout) {
-    case 4:
-        ip3.storeu(out + 3 * ldo);
-    case 3:
-        ip2.storeu(out + 2 * ldo);
-    case 2:
-        ip1.storeu(out + 1 * ldo);
-    case 1:
-        ip0.storeu(out);
+    switch (nout) {
+        case 4:
+            ip3.storeu(out + 3 * ldo);
+        case 3:
+            ip2.storeu(out + 2 * ldo);
+        case 2:
+            ip1.storeu(out + 1 * ldo);
+        case 1:
+            ip0.storeu(out);
     }
 }
 
-simd8float32 load_simd8float32_partial(const float *x, int n) {
+simd8float32 load_simd8float32_partial(const float* x, int n) {
     ALIGNED(32) float tmp[8] = {0, 0, 0, 0, 0, 0, 0, 0};
-    float *wp = tmp;
+    float* wp = tmp;
     for (int i = 0; i < n; i++) {
         *wp++ = *x++;
     }
@@ -964,25 +918,23 @@ simd8float32 load_simd8float32_partial(const float *x, int n) {
 
 } // anonymous namespace
 
-
-
-
 void compute_PQ_dis_tables_dsub2(
-        size_t d, size_t ksub, const float *all_centroids,
-        size_t nx, const float * x,
+        size_t d,
+        size_t ksub,
+        const float* all_centroids,
+        size_t nx,
+        const float* x,
         bool is_inner_product,
-        float * dis_tables)
-{
+        float* dis_tables) {
     size_t M = d / 2;
     FAISS_THROW_IF_NOT(ksub % 8 == 0);
 
-    for(size_t m0 = 0; m0 < M; m0 += 4) {
+    for (size_t m0 = 0; m0 < M; m0 += 4) {
         int m1 = std::min(M, m0 + 4);
-        for(int k0 = 0; k0 < ksub; k0 += 8) {
-
+        for (int k0 = 0; k0 < ksub; k0 += 8) {
             simd8float32 centroids[8];
             for (int k = 0; k < 8; k++) {
-                float centroid[8] __attribute__((aligned(32)));
+                ALIGNED(32) float centroid[8];
                 size_t wp = 0;
                 size_t rp = (m0 * ksub + k + k0) * 2;
                 for (int m = m0; m < m1; m++) {
@@ -992,45 +944,33 @@ void compute_PQ_dis_tables_dsub2(
                }
                centroids[k] = simd8float32(centroid);
            }
-            for(size_t i = 0; i < nx; i++) {
+            for (size_t i = 0; i < nx; i++) {
                simd8float32 xi;
                if (m1 == m0 + 4) {
                    xi.loadu(x + i * d + m0 * 2);
                } else {
-                    xi = load_simd8float32_partial(x + i * d + m0 * 2, 2 * (m1 - m0));
+                    xi = load_simd8float32_partial(
+                            x + i * d + m0 * 2, 2 * (m1 - m0));
                }
 
-                if(is_inner_product) {
+                if (is_inner_product) {
                    pq2_8cents_table<true>(
-                        centroids, xi,
-                        dis_tables + (i * M + m0) * ksub + k0,
-                        ksub, m1 - m0
-                    );
+                            centroids,
+                            xi,
+                            dis_tables + (i * M + m0) * ksub + k0,
+                            ksub,
+                            m1 - m0);
                } else {
                    pq2_8cents_table<false>(
-                        centroids, xi,
-                        dis_tables + (i * M + m0) * ksub + k0,
-                        ksub, m1 - m0
-                    );
+                            centroids,
+                            xi,
+                            dis_tables + (i * M + m0) * ksub + k0,
+                            ksub,
+                            m1 - m0);
                }
            }
        }
    }
-
 }
 
-#else
-
-void compute_PQ_dis_tables_dsub2(
-        size_t d, size_t ksub, const float *all_centroids,
-        size_t nx, const float * x,
-        bool is_inner_product,
-        float * dis_tables)
-{
-    FAISS_THROW_MSG("only implemented for AVX2");
-}
-
-#endif
-
-
 } // namespace faiss
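
With the __AVX2__ guard removed, compute_PQ_dis_tables_dsub2 is now compiled unconditionally through the simd8float32 abstraction (the geteven/getodd/getlow128/gethigh128 helpers deleted above presumably moved into the shared simdlib headers), so the "only implemented for AVX2" fallback could be dropped. The function fills dsub = 2 product-quantizer distance tables; the code above implies centroids laid out as [d/2][ksub][2] and ksub a multiple of 8. A call sketch, again assuming the declaration in faiss/utils/distances.h:

    #include <faiss/utils/distances.h>
    #include <vector>

    int main() {
        size_t d = 8, ksub = 8;  // dsub = 2, so M = d / 2 = 4 sub-quantizers
        size_t M = d / 2;
        std::vector<float> centroids(M * ksub * 2, 0.25f); // [M][ksub][2]
        size_t nx = 1;                                     // one query
        std::vector<float> x(nx * d, 1.0f);
        std::vector<float> tables(nx * M * ksub);
        // tables[(i * M + m) * ksub + k] holds the squared L2 distance
        // (or the inner product) between the m-th 2-d slice of x[i]
        // and centroid k of sub-quantizer m
        faiss::compute_PQ_dis_tables_dsub2(
                d, ksub, centroids.data(), nx, x.data(),
                /*is_inner_product=*/false, tables.data());
        return 0;
    }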