faiss 0.2.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (202)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
data/vendor/faiss/faiss/utils/distances.h

@@ -5,49 +5,34 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-// -*- c++ -*-
-
 /* All distance functions for L2 and IP distances.
- * The actual functions are implemented in distances.cpp and distances_simd.cpp */
+ * The actual functions are implemented in distances.cpp and distances_simd.cpp
+ */
 
 #pragma once
 
 #include <stdint.h>
 
-#include <faiss/utils/Heap.h>
 #include <faiss/impl/platform_macros.h>
-
+#include <faiss/utils/Heap.h>
 
 namespace faiss {
 
-/*********************************************************
+/*********************************************************
  * Optimized distance/norm/inner prod computations
  *********************************************************/
 
-
 /// Squared L2 distance between two vectors
-float fvec_L2sqr (
-        const float * x,
-        const float * y,
-        size_t d);
+float fvec_L2sqr(const float* x, const float* y, size_t d);
 
 /// inner product
-float fvec_inner_product (
-        const float * x,
-        const float * y,
-        size_t d);
+float fvec_inner_product(const float* x, const float* y, size_t d);
 
 /// L1 distance
-float fvec_L1 (
-        const float * x,
-        const float * y,
-        size_t d);
-
-float fvec_Linf (
-        const float * x,
-        const float * y,
-        size_t d);
+float fvec_L1(const float* x, const float* y, size_t d);
 
+/// infinity distance
+float fvec_Linf(const float* x, const float* y, size_t d);
 
 /** Compute pairwise distances between sets of vectors
  *
@@ -59,74 +44,83 @@ float fvec_Linf (
  * @param dis output distances (size nq * nb)
  * @param ldq,ldb, ldd strides for the matrices
  */
-void pairwise_L2sqr (int64_t d,
-                     int64_t nq, const float *xq,
-                     int64_t nb, const float *xb,
-                     float *dis,
-                     int64_t ldq = -1, int64_t ldb = -1, int64_t ldd = -1);
+void pairwise_L2sqr(
+        int64_t d,
+        int64_t nq,
+        const float* xq,
+        int64_t nb,
+        const float* xb,
+        float* dis,
+        int64_t ldq = -1,
+        int64_t ldb = -1,
+        int64_t ldd = -1);
 
 /* compute the inner product between nx vectors x and one y */
-void fvec_inner_products_ny (
-        float * ip, /* output inner product */
-        const float * x,
-        const float * y,
-        size_t d, size_t ny);
+void fvec_inner_products_ny(
+        float* ip, /* output inner product */
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny);
 
 /* compute ny square L2 distance bewteen x and a set of contiguous y vectors */
-void fvec_L2sqr_ny (
-        float * dis,
-        const float * x,
-        const float * y,
-        size_t d, size_t ny);
-
+void fvec_L2sqr_ny(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny);
 
 /** squared norm of a vector */
-float fvec_norm_L2sqr (const float * x,
-                       size_t d);
+float fvec_norm_L2sqr(const float* x, size_t d);
 
 /** compute the L2 norms for a set of vectors
  *
- * @param ip    output norms, size nx
+ * @param norms output norms, size nx
 * @param x     set of vectors, size nx * d
 */
-void fvec_norms_L2 (float * ip, const float * x, size_t d, size_t nx);
+void fvec_norms_L2(float* norms, const float* x, size_t d, size_t nx);
 
-/// same as fvec_norms_L2, but computes square norms
-void fvec_norms_L2sqr (float * ip, const float * x, size_t d, size_t nx);
+/// same as fvec_norms_L2, but computes squared norms
+void fvec_norms_L2sqr(float* norms, const float* x, size_t d, size_t nx);
 
 /* L2-renormalize a set of vector. Nothing done if the vector is 0-normed */
-void fvec_renorm_L2 (size_t d, size_t nx, float * x);
-
+void fvec_renorm_L2(size_t d, size_t nx, float* x);
 
 /* This function exists because the Torch counterpart is extremly slow
    (not multi-threaded + unexpected overhead even in single thread).
   It is here to implement the usual property |x-y|^2=|x|^2+|y|^2-2<x|y> */
-void inner_product_to_L2sqr (float * dis,
-                             const float * nr1,
-                             const float * nr2,
-                             size_t n1, size_t n2);
+void inner_product_to_L2sqr(
+        float* dis,
+        const float* nr1,
+        const float* nr2,
+        size_t n1,
+        size_t n2);
 
 /***************************************************************************
 * Compute a subset of distances
 ***************************************************************************/
 
-/* compute the inner product between x and a subset y of ny vectors,
-   whose indices are given by idy. */
-void fvec_inner_products_by_idx (
-        float * ip,
-        const float * x,
-        const float * y,
-        const int64_t *ids,
-        size_t d, size_t nx, size_t ny);
+/* compute the inner product between x and a subset y of ny vectors,
+   whose indices are given by idy. */
+void fvec_inner_products_by_idx(
+        float* ip,
+        const float* x,
+        const float* y,
+        const int64_t* ids,
+        size_t d,
+        size_t nx,
+        size_t ny);
 
 /* same but for a subset in y indexed by idsy (ny vectors in total) */
-void fvec_L2sqr_by_idx (
-        float * dis,
-        const float * x,
-        const float * y,
-        const int64_t *ids, /* ids of y vecs */
-        size_t d, size_t nx, size_t ny);
-
+void fvec_L2sqr_by_idx(
+        float* dis,
+        const float* x,
+        const float* y,
+        const int64_t* ids, /* ids of y vecs */
+        size_t d,
+        size_t nx,
+        size_t ny);
 
 /** compute dis[j] = L2sqr(x[ix[j]], y[iy[j]]) forall j=0..n-1
 *
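The comment on inner_product_to_L2sqr above relies on the identity |x-y|^2 = |x|^2 + |y|^2 - 2<x|y>. A minimal standalone sketch (a hypothetical test file, not part of this package) that checks the identity against the primitives declared in this header:

    // verify |x-y|^2 == |x|^2 + |y|^2 - 2<x|y> with the declared primitives
    #include <faiss/utils/distances.h>
    #include <cassert>
    #include <cmath>
    #include <cstddef>

    int main() {
        const size_t d = 3;
        float x[d] = {1.0f, 2.0f, 3.0f};
        float y[d] = {4.0f, 5.0f, 6.0f};
        float lhs = faiss::fvec_L2sqr(x, y, d); // 27
        float rhs = faiss::fvec_norm_L2sqr(x, d) +
                faiss::fvec_norm_L2sqr(y, d) -
                2 * faiss::fvec_inner_product(x, y, d); // 14 + 77 - 64 = 27
        assert(std::fabs(lhs - rhs) < 1e-4f);
        return 0;
    }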
@@ -136,18 +130,24 @@ void fvec_L2sqr_by_idx (
  * @param iy  size n
  * @param dis size n
  */
-void pairwise_indexed_L2sqr (
-        size_t d, size_t n,
-        const float * x, const int64_t *ix,
-        const float * y, const int64_t *iy,
-        float *dis);
+void pairwise_indexed_L2sqr(
+        size_t d,
+        size_t n,
+        const float* x,
+        const int64_t* ix,
+        const float* y,
+        const int64_t* iy,
+        float* dis);
 
 /* same for inner product */
-void pairwise_indexed_inner_product (
-        size_t d, size_t n,
-        const float * x, const int64_t *ix,
-        const float * y, const int64_t *iy,
-        float *dis);
+void pairwise_indexed_inner_product(
+        size_t d,
+        size_t n,
+        const float* x,
+        const int64_t* ix,
+        const float* y,
+        const int64_t* iy,
+        float* dis);
 
 /***************************************************************************
 * KNN functions
@@ -171,46 +171,51 @@ FAISS_API extern int distance_compute_min_k_reservoir;
  * @param y database vectors, size ny * d
  * @param res result array, which also provides k. Sorted on output
  */
-void knn_inner_product (
-        const float * x,
-        const float * y,
-        size_t d, size_t nx, size_t ny,
-        float_minheap_array_t * res);
+void knn_inner_product(
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t nx,
+        size_t ny,
+        float_minheap_array_t* res);
 
 /** Same as knn_inner_product, for the L2 distance
  * @param y_norm2 norms for the y vectors (nullptr or size ny)
  */
-void knn_L2sqr (
-        const float * x,
-        const float * y,
-        size_t d, size_t nx, size_t ny,
-        float_maxheap_array_t * res,
-        const float *y_norm2 = nullptr);
-
+void knn_L2sqr(
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t nx,
+        size_t ny,
+        float_maxheap_array_t* res,
+        const float* y_norm2 = nullptr);
 
 /* Find the nearest neighbors for nx queries in a set of ny vectors
 * indexed by ids. May be useful for re-ranking a pre-selected vector list
 */
-void knn_inner_products_by_idx (
-        const float * x,
-        const float * y,
-        const int64_t * ids,
-        size_t d, size_t nx, size_t ny,
-        float_minheap_array_t * res);
-
-void knn_L2sqr_by_idx (
-        const float * x,
-        const float * y,
-        const int64_t * ids,
-        size_t d, size_t nx, size_t ny,
-        float_maxheap_array_t * res);
+void knn_inner_products_by_idx(
+        const float* x,
+        const float* y,
+        const int64_t* ids,
+        size_t d,
+        size_t nx,
+        size_t ny,
+        float_minheap_array_t* res);
+
+void knn_L2sqr_by_idx(
+        const float* x,
+        const float* y,
+        const int64_t* ids,
+        size_t d,
+        size_t nx,
+        size_t ny,
+        float_maxheap_array_t* res);
 
 /***************************************************************************
 * Range search
 ***************************************************************************/
 
-
-
 /// Forward declaration, see AuxIndexStructures.h
 struct RangeSearchResult;
 
@@ -222,21 +227,24 @@ struct RangeSearchResult;
  * @param radius search radius around the x vectors
  * @param result result structure
  */
-void range_search_L2sqr (
-        const float * x,
-        const float * y,
-        size_t d, size_t nx, size_t ny,
+void range_search_L2sqr(
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t nx,
+        size_t ny,
         float radius,
-        RangeSearchResult *result);
+        RangeSearchResult* result);
 
 /// same as range_search_L2sqr for the inner product similarity
-void range_search_inner_product (
-        const float * x,
-        const float * y,
-        size_t d, size_t nx, size_t ny,
+void range_search_inner_product(
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t nx,
+        size_t ny,
         float radius,
-        RangeSearchResult *result);
-
+        RangeSearchResult* result);
 
 /***************************************************************************
 * PQ tables computations
@@ -244,9 +252,16 @@ void range_search_inner_product (
 
 /// specialized function for PQ2
 void compute_PQ_dis_tables_dsub2(
-        size_t d, size_t ksub, const float *centroids,
-        size_t nx, const float * x,
+        size_t d,
+        size_t ksub,
+        const float* centroids,
+        size_t nx,
+        const float* x,
         bool is_inner_product,
-        float * dis_tables);
+        float* dis_tables);
+
+/***************************************************************************
+ * Templatized versions of distance functions
+ ***************************************************************************/
 
 } // namespace faiss
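A brief usage sketch for the API above (standalone, assuming the faiss 1.7-era signatures vendored in this gem; float_maxheap_array_t comes from Heap.h and doubles as the output buffer, sorted on return):

    // brute-force k-NN over a flat array with knn_L2sqr
    #include <faiss/utils/Heap.h>
    #include <faiss/utils/distances.h>
    #include <cstdint>
    #include <vector>

    int main() {
        size_t d = 8, nx = 2, ny = 100, k = 5;
        std::vector<float> queries(nx * d, 0.5f);
        std::vector<float> base(ny * d, 0.25f);
        std::vector<int64_t> labels(nx * k);
        std::vector<float> distances(nx * k);
        // the heap array wraps the caller's label/distance buffers
        faiss::float_maxheap_array_t res = {nx, k, labels.data(), distances.data()};
        faiss::knn_L2sqr(queries.data(), base.data(), d, nx, ny, &res);
        return 0;
    }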
data/vendor/faiss/faiss/utils/distances.cpp

@@ -9,13 +9,14 @@
 
 #include <faiss/utils/distances.h>
 
-#include <cstdio>
 #include <cassert>
-#include <cstring>
 #include <cmath>
+#include <cstdio>
+#include <cstring>
 
-#include <faiss/utils/simdlib.h>
 #include <faiss/impl/FaissAssert.h>
+#include <faiss/impl/platform_macros.h>
+#include <faiss/utils/simdlib.h>
 
 #ifdef __SSE3__
 #include <immintrin.h>
@@ -25,19 +26,16 @@
 #include <arm_neon.h>
 #endif
 
-
 namespace faiss {
 
 #ifdef __AVX__
 #define USE_AVX
 #endif
 
-
 /*********************************************************
 * Optimized distance computations
 *********************************************************/
 
-
 /* Functions to compute:
    - L2 distance between 2 vectors
    - inner product between 2 vectors
@@ -53,29 +51,21 @@ namespace faiss {
 
 */
 
-
 /*********************************************************
 * Reference implementations
 */
 
-
-float fvec_L2sqr_ref (const float * x,
-                      const float * y,
-                      size_t d)
-{
+float fvec_L2sqr_ref(const float* x, const float* y, size_t d) {
     size_t i;
     float res = 0;
     for (i = 0; i < d; i++) {
         const float tmp = x[i] - y[i];
-       res += tmp * tmp;
+        res += tmp * tmp;
     }
     return res;
 }
 
-float fvec_L1_ref (const float * x,
-                   const float * y,
-                   size_t d)
-{
+float fvec_L1_ref(const float* x, const float* y, size_t d) {
     size_t i;
     float res = 0;
     for (i = 0; i < d; i++) {
@@ -85,56 +75,49 @@ float fvec_L1_ref (const float * x,
     return res;
 }
 
-float fvec_Linf_ref (const float * x,
-                     const float * y,
-                     size_t d)
-{
+float fvec_Linf_ref(const float* x, const float* y, size_t d) {
     size_t i;
     float res = 0;
     for (i = 0; i < d; i++) {
-       res = fmax(res, fabs(x[i] - y[i]));
+        res = fmax(res, fabs(x[i] - y[i]));
     }
     return res;
 }
 
-float fvec_inner_product_ref (const float * x,
-                              const float * y,
-                              size_t d)
-{
+float fvec_inner_product_ref(const float* x, const float* y, size_t d) {
     size_t i;
     float res = 0;
     for (i = 0; i < d; i++)
-       res += x[i] * y[i];
+        res += x[i] * y[i];
     return res;
 }
 
-float fvec_norm_L2sqr_ref (const float *x, size_t d)
-{
+float fvec_norm_L2sqr_ref(const float* x, size_t d) {
     size_t i;
     double res = 0;
     for (i = 0; i < d; i++)
-       res += x[i] * x[i];
+        res += x[i] * x[i];
     return res;
 }
 
-
-void fvec_L2sqr_ny_ref (float * dis,
-                        const float * x,
-                        const float * y,
-                        size_t d, size_t ny)
-{
+void fvec_L2sqr_ny_ref(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny) {
     for (size_t i = 0; i < ny; i++) {
-        dis[i] = fvec_L2sqr (x, y, d);
+        dis[i] = fvec_L2sqr(x, y, d);
         y += d;
     }
 }
 
-
-void fvec_inner_products_ny_ref (float * ip,
-                                 const float * x,
-                                 const float * y,
-                                 size_t d, size_t ny)
-{
+void fvec_inner_products_ny_ref(
+        float* ip,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny) {
     // BLAS slower for the use cases here
 #if 0
     {
@@ -146,15 +129,11 @@ void fvec_inner_products_ny_ref (float * ip,
     }
 #endif
     for (size_t i = 0; i < ny; i++) {
-        ip[i] = fvec_inner_product (x, y, d);
+        ip[i] = fvec_inner_product(x, y, d);
         y += d;
     }
 }
 
-
-
-
-
 /*********************************************************
 * SSE and AVX implementations
 */
@@ -162,40 +141,38 @@ void fvec_inner_products_ny_ref (float * ip,
 #ifdef __SSE3__
 
 // reads 0 <= d < 4 floats as __m128
-static inline __m128 masked_read (int d, const float *x)
-{
-    assert (0 <= d && d < 4);
-    __attribute__((__aligned__(16))) float buf[4] = {0, 0, 0, 0};
+static inline __m128 masked_read(int d, const float* x) {
+    assert(0 <= d && d < 4);
+    ALIGNED(16) float buf[4] = {0, 0, 0, 0};
     switch (d) {
-      case 3:
-        buf[2] = x[2];
-      case 2:
-        buf[1] = x[1];
-      case 1:
-        buf[0] = x[0];
-    }
-    return _mm_load_ps (buf);
+        case 3:
+            buf[2] = x[2];
+        case 2:
+            buf[1] = x[1];
+        case 1:
+            buf[0] = x[0];
+    }
+    return _mm_load_ps(buf);
     // cannot use AVX2 _mm_mask_set1_epi32
 }
 
-float fvec_norm_L2sqr (const float * x,
-                       size_t d)
-{
+float fvec_norm_L2sqr(const float* x, size_t d) {
     __m128 mx;
     __m128 msum1 = _mm_setzero_ps();
 
     while (d >= 4) {
-        mx = _mm_loadu_ps (x); x += 4;
-        msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, mx));
+        mx = _mm_loadu_ps(x);
+        x += 4;
+        msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, mx));
         d -= 4;
     }
 
-    mx = masked_read (d, x);
-    msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, mx));
+    mx = masked_read(d, x);
+    msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, mx));
 
-    msum1 = _mm_hadd_ps (msum1, msum1);
-    msum1 = _mm_hadd_ps (msum1, msum1);
-    return _mm_cvtss_f32 (msum1);
+    msum1 = _mm_hadd_ps(msum1, msum1);
+    msum1 = _mm_hadd_ps(msum1, msum1);
+    return _mm_cvtss_f32(msum1);
 }
 
 namespace {
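masked_read above copies the d < 4 trailing floats into a zeroed, 16-byte-aligned stack buffer, so the SIMD kernels never read past the end of an input whose length is not a multiple of the register width. A scalar sketch of the same tail strategy (hypothetical helper, not part of faiss):

    // sum of squares with a masked_read-style zero-padded tail
    #include <cstddef>
    #include <cstring>

    static float sum_squares(const float* x, size_t d) {
        float acc = 0;
        size_t i = 0;
        for (; i + 4 <= d; i += 4) // full 4-float blocks, as the SIMD loop does
            acc += x[i] * x[i] + x[i + 1] * x[i + 1] + x[i + 2] * x[i + 2] +
                    x[i + 3] * x[i + 3];
        float buf[4] = {0, 0, 0, 0}; // zero padding plays the role of masked_read
        std::memcpy(buf, x + i, (d - i) * sizeof(float));
        for (int j = 0; j < 4; j++)
            acc += buf[j] * buf[j];
        return acc;
    }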
@@ -204,586 +181,588 @@ namespace {
 /// to compute L2 distances. ElementOp can then be used in the fvec_op_ny
 /// functions below
 struct ElementOpL2 {
-
-    static float op (float x, float y) {
+    static float op(float x, float y) {
         float tmp = x - y;
         return tmp * tmp;
     }
 
-    static __m128 op (__m128 x, __m128 y) {
-        __m128 tmp = x - y;
-        return tmp * tmp;
+    static __m128 op(__m128 x, __m128 y) {
+        __m128 tmp = _mm_sub_ps(x, y);
+        return _mm_mul_ps(tmp, tmp);
     }
-
 };
 
 /// Function that does a component-wise operation between x and y
 /// to compute inner products
 struct ElementOpIP {
-
-    static float op (float x, float y) {
+    static float op(float x, float y) {
         return x * y;
     }
 
-    static __m128 op (__m128 x, __m128 y) {
-        return x * y;
+    static __m128 op(__m128 x, __m128 y) {
+        return _mm_mul_ps(x, y);
     }
-
 };
 
-template<class ElementOp>
-void fvec_op_ny_D1 (float * dis, const float * x,
-                    const float * y, size_t ny)
-{
+template <class ElementOp>
+void fvec_op_ny_D1(float* dis, const float* x, const float* y, size_t ny) {
     float x0s = x[0];
-    __m128 x0 = _mm_set_ps (x0s, x0s, x0s, x0s);
+    __m128 x0 = _mm_set_ps(x0s, x0s, x0s, x0s);
 
     size_t i;
     for (i = 0; i + 3 < ny; i += 4) {
-        __m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
-        dis[i] = _mm_cvtss_f32 (accu);
-        __m128 tmp = _mm_shuffle_ps (accu, accu, 1);
-        dis[i + 1] = _mm_cvtss_f32 (tmp);
-        tmp = _mm_shuffle_ps (accu, accu, 2);
-        dis[i + 2] = _mm_cvtss_f32 (tmp);
-        tmp = _mm_shuffle_ps (accu, accu, 3);
-        dis[i + 3] = _mm_cvtss_f32 (tmp);
+        __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
+        y += 4;
+        dis[i] = _mm_cvtss_f32(accu);
+        __m128 tmp = _mm_shuffle_ps(accu, accu, 1);
+        dis[i + 1] = _mm_cvtss_f32(tmp);
+        tmp = _mm_shuffle_ps(accu, accu, 2);
+        dis[i + 2] = _mm_cvtss_f32(tmp);
+        tmp = _mm_shuffle_ps(accu, accu, 3);
+        dis[i + 3] = _mm_cvtss_f32(tmp);
     }
     while (i < ny) { // handle non-multiple-of-4 case
         dis[i++] = ElementOp::op(x0s, *y++);
     }
 }
 
-template<class ElementOp>
-void fvec_op_ny_D2 (float * dis, const float * x,
-                    const float * y, size_t ny)
-{
-    __m128 x0 = _mm_set_ps (x[1], x[0], x[1], x[0]);
+template <class ElementOp>
+void fvec_op_ny_D2(float* dis, const float* x, const float* y, size_t ny) {
+    __m128 x0 = _mm_set_ps(x[1], x[0], x[1], x[0]);
 
     size_t i;
     for (i = 0; i + 1 < ny; i += 2) {
-        __m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
-        accu = _mm_hadd_ps (accu, accu);
-        dis[i] = _mm_cvtss_f32 (accu);
-        accu = _mm_shuffle_ps (accu, accu, 3);
-        dis[i + 1] = _mm_cvtss_f32 (accu);
+        __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
+        y += 4;
+        accu = _mm_hadd_ps(accu, accu);
+        dis[i] = _mm_cvtss_f32(accu);
+        accu = _mm_shuffle_ps(accu, accu, 3);
+        dis[i + 1] = _mm_cvtss_f32(accu);
     }
     if (i < ny) { // handle odd case
         dis[i] = ElementOp::op(x[0], y[0]) + ElementOp::op(x[1], y[1]);
     }
 }
 
-
-
-template<class ElementOp>
-void fvec_op_ny_D4 (float * dis, const float * x,
-                    const float * y, size_t ny)
-{
+template <class ElementOp>
+void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
     __m128 x0 = _mm_loadu_ps(x);
 
     for (size_t i = 0; i < ny; i++) {
-        __m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
-        accu = _mm_hadd_ps (accu, accu);
-        accu = _mm_hadd_ps (accu, accu);
-        dis[i] = _mm_cvtss_f32 (accu);
+        __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
+        y += 4;
+        accu = _mm_hadd_ps(accu, accu);
+        accu = _mm_hadd_ps(accu, accu);
+        dis[i] = _mm_cvtss_f32(accu);
     }
 }
 
-template<class ElementOp>
-void fvec_op_ny_D8 (float * dis, const float * x,
-                    const float * y, size_t ny)
-{
+template <class ElementOp>
+void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) {
     __m128 x0 = _mm_loadu_ps(x);
     __m128 x1 = _mm_loadu_ps(x + 4);
 
     for (size_t i = 0; i < ny; i++) {
-        __m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
-        accu += ElementOp::op(x1, _mm_loadu_ps (y)); y += 4;
-        accu = _mm_hadd_ps (accu, accu);
-        accu = _mm_hadd_ps (accu, accu);
-        dis[i] = _mm_cvtss_f32 (accu);
+        __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
+        y += 4;
+        accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y)));
+        y += 4;
+        accu = _mm_hadd_ps(accu, accu);
+        accu = _mm_hadd_ps(accu, accu);
+        dis[i] = _mm_cvtss_f32(accu);
     }
 }
 
-template<class ElementOp>
-void fvec_op_ny_D12 (float * dis, const float * x,
-                     const float * y, size_t ny)
-{
+template <class ElementOp>
+void fvec_op_ny_D12(float* dis, const float* x, const float* y, size_t ny) {
    __m128 x0 = _mm_loadu_ps(x);
    __m128 x1 = _mm_loadu_ps(x + 4);
    __m128 x2 = _mm_loadu_ps(x + 8);
 
    for (size_t i = 0; i < ny; i++) {
-        __m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
-        accu += ElementOp::op(x1, _mm_loadu_ps (y)); y += 4;
-        accu += ElementOp::op(x2, _mm_loadu_ps (y)); y += 4;
-        accu = _mm_hadd_ps (accu, accu);
-        accu = _mm_hadd_ps (accu, accu);
-        dis[i] = _mm_cvtss_f32 (accu);
+        __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
+        y += 4;
+        accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y)));
+        y += 4;
+        accu = _mm_add_ps(accu, ElementOp::op(x2, _mm_loadu_ps(y)));
+        y += 4;
+        accu = _mm_hadd_ps(accu, accu);
+        accu = _mm_hadd_ps(accu, accu);
+        dis[i] = _mm_cvtss_f32(accu);
    }
 }
 
-
-
 } // anonymous namespace
 
-void fvec_L2sqr_ny (float * dis, const float * x,
-                    const float * y, size_t d, size_t ny) {
+void fvec_L2sqr_ny(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny) {
     // optimized for a few special cases
 
-#define DISPATCH(dval) \
-    case dval:\
-        fvec_op_ny_D ## dval <ElementOpL2> (dis, x, y, ny); \
+#define DISPATCH(dval)                                  \
+    case dval:                                          \
+        fvec_op_ny_D##dval<ElementOpL2>(dis, x, y, ny); \
         return;
 
-    switch(d) {
+    switch (d) {
         DISPATCH(1)
         DISPATCH(2)
         DISPATCH(4)
        DISPATCH(8)
        DISPATCH(12)
-    default:
-        fvec_L2sqr_ny_ref (dis, x, y, d, ny);
-        return;
+        default:
+            fvec_L2sqr_ny_ref(dis, x, y, d, ny);
+            return;
    }
 #undef DISPATCH
-
 }
 
-void fvec_inner_products_ny (float * dis, const float * x,
-                             const float * y, size_t d, size_t ny) {
-
-#define DISPATCH(dval) \
-    case dval:\
-        fvec_op_ny_D ## dval <ElementOpIP> (dis, x, y, ny); \
+void fvec_inner_products_ny(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny) {
+#define DISPATCH(dval)                                  \
+    case dval:                                          \
+        fvec_op_ny_D##dval<ElementOpIP>(dis, x, y, ny); \
         return;
 
-    switch(d) {
+    switch (d) {
         DISPATCH(1)
         DISPATCH(2)
         DISPATCH(4)
         DISPATCH(8)
         DISPATCH(12)
-    default:
-        fvec_inner_products_ny_ref (dis, x, y, d, ny);
-        return;
+        default:
+            fvec_inner_products_ny_ref(dis, x, y, d, ny);
+            return;
    }
 #undef DISPATCH
-
 }
 
-
-
 #endif
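The DISPATCH macro above specializes fvec_L2sqr_ny and fvec_inner_products_ny for d = 1, 2, 4, 8 and 12, so those dimensions hit fully unrolled kernels while everything else falls back to the reference loop. The pattern, reduced to a self-contained sketch (l2sqr_fixed is a hypothetical stand-in for the fvec_op_ny_D* kernels):

    // dimension-specialized dispatch, the pattern behind DISPATCH above
    #include <cstddef>

    template <int D>
    static float l2sqr_fixed(const float* x, const float* y) {
        float acc = 0;
        for (int j = 0; j < D; j++) { // trip count known at compile time
            float t = x[j] - y[j];
            acc += t * t;
        }
        return acc;
    }

    float l2sqr_dispatch(const float* x, const float* y, size_t d) {
        switch (d) {
            case 4:
                return l2sqr_fixed<4>(x, y); // fully unrollable
            case 8:
                return l2sqr_fixed<8>(x, y);
            default: { // generic fallback, like fvec_L2sqr_ny_ref
                float acc = 0;
                for (size_t j = 0; j < d; j++) {
                    float t = x[j] - y[j];
                    acc += t * t;
                }
                return acc;
            }
        }
    }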
 
 #ifdef USE_AVX
 
 // reads 0 <= d < 8 floats as __m256
-static inline __m256 masked_read_8 (int d, const float *x)
-{
-    assert (0 <= d && d < 8);
+static inline __m256 masked_read_8(int d, const float* x) {
+    assert(0 <= d && d < 8);
     if (d < 4) {
-        __m256 res = _mm256_setzero_ps ();
-        res = _mm256_insertf128_ps (res, masked_read (d, x), 0);
+        __m256 res = _mm256_setzero_ps();
+        res = _mm256_insertf128_ps(res, masked_read(d, x), 0);
         return res;
     } else {
-        __m256 res = _mm256_setzero_ps ();
-        res = _mm256_insertf128_ps (res, _mm_loadu_ps (x), 0);
-        res = _mm256_insertf128_ps (res, masked_read (d - 4, x + 4), 1);
+        __m256 res = _mm256_setzero_ps();
+        res = _mm256_insertf128_ps(res, _mm_loadu_ps(x), 0);
+        res = _mm256_insertf128_ps(res, masked_read(d - 4, x + 4), 1);
         return res;
     }
 }
 
-float fvec_inner_product (const float * x,
-                          const float * y,
-                          size_t d)
-{
+float fvec_inner_product(const float* x, const float* y, size_t d) {
     __m256 msum1 = _mm256_setzero_ps();
 
     while (d >= 8) {
-        __m256 mx = _mm256_loadu_ps (x); x += 8;
-        __m256 my = _mm256_loadu_ps (y); y += 8;
-        msum1 = _mm256_add_ps (msum1, _mm256_mul_ps (mx, my));
+        __m256 mx = _mm256_loadu_ps(x);
+        x += 8;
+        __m256 my = _mm256_loadu_ps(y);
+        y += 8;
+        msum1 = _mm256_add_ps(msum1, _mm256_mul_ps(mx, my));
         d -= 8;
     }
 
     __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
-    msum2 += _mm256_extractf128_ps(msum1, 0);
+    msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));
 
     if (d >= 4) {
-        __m128 mx = _mm_loadu_ps (x); x += 4;
-        __m128 my = _mm_loadu_ps (y); y += 4;
-        msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my));
+        __m128 mx = _mm_loadu_ps(x);
+        x += 4;
+        __m128 my = _mm_loadu_ps(y);
+        y += 4;
+        msum2 = _mm_add_ps(msum2, _mm_mul_ps(mx, my));
         d -= 4;
     }
 
     if (d > 0) {
-        __m128 mx = masked_read (d, x);
-        __m128 my = masked_read (d, y);
-        msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my));
+        __m128 mx = masked_read(d, x);
+        __m128 my = masked_read(d, y);
+        msum2 = _mm_add_ps(msum2, _mm_mul_ps(mx, my));
     }
 
-    msum2 = _mm_hadd_ps (msum2, msum2);
-    msum2 = _mm_hadd_ps (msum2, msum2);
-    return _mm_cvtss_f32 (msum2);
+    msum2 = _mm_hadd_ps(msum2, msum2);
+    msum2 = _mm_hadd_ps(msum2, msum2);
+    return _mm_cvtss_f32(msum2);
 }
 
-float fvec_L2sqr (const float * x,
-                  const float * y,
-                  size_t d)
-{
+float fvec_L2sqr(const float* x, const float* y, size_t d) {
     __m256 msum1 = _mm256_setzero_ps();
 
     while (d >= 8) {
-        __m256 mx = _mm256_loadu_ps (x); x += 8;
-        __m256 my = _mm256_loadu_ps (y); y += 8;
-        const __m256 a_m_b1 = mx - my;
-        msum1 += a_m_b1 * a_m_b1;
+        __m256 mx = _mm256_loadu_ps(x);
+        x += 8;
+        __m256 my = _mm256_loadu_ps(y);
+        y += 8;
+        const __m256 a_m_b1 = _mm256_sub_ps(mx, my);
+        msum1 = _mm256_add_ps(msum1, _mm256_mul_ps(a_m_b1, a_m_b1));
         d -= 8;
     }
 
     __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
-    msum2 += _mm256_extractf128_ps(msum1, 0);
+    msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));
 
     if (d >= 4) {
-        __m128 mx = _mm_loadu_ps (x); x += 4;
-        __m128 my = _mm_loadu_ps (y); y += 4;
-        const __m128 a_m_b1 = mx - my;
-        msum2 += a_m_b1 * a_m_b1;
+        __m128 mx = _mm_loadu_ps(x);
+        x += 4;
+        __m128 my = _mm_loadu_ps(y);
+        y += 4;
+        const __m128 a_m_b1 = _mm_sub_ps(mx, my);
+        msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1));
         d -= 4;
     }
 
     if (d > 0) {
-        __m128 mx = masked_read (d, x);
-        __m128 my = masked_read (d, y);
-        __m128 a_m_b1 = mx - my;
-        msum2 += a_m_b1 * a_m_b1;
+        __m128 mx = masked_read(d, x);
+        __m128 my = masked_read(d, y);
+        __m128 a_m_b1 = _mm_sub_ps(mx, my);
+        msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1));
     }
 
-    msum2 = _mm_hadd_ps (msum2, msum2);
-    msum2 = _mm_hadd_ps (msum2, msum2);
-    return _mm_cvtss_f32 (msum2);
+    msum2 = _mm_hadd_ps(msum2, msum2);
+    msum2 = _mm_hadd_ps(msum2, msum2);
+    return _mm_cvtss_f32(msum2);
 }
 
-float fvec_L1 (const float * x, const float * y, size_t d)
-{
+float fvec_L1(const float* x, const float* y, size_t d) {
     __m256 msum1 = _mm256_setzero_ps();
-    __m256 signmask = __m256(_mm256_set1_epi32 (0x7fffffffUL));
+    __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL));
 
     while (d >= 8) {
-        __m256 mx = _mm256_loadu_ps (x); x += 8;
-        __m256 my = _mm256_loadu_ps (y); y += 8;
-        const __m256 a_m_b = mx - my;
-        msum1 += _mm256_and_ps(signmask, a_m_b);
+        __m256 mx = _mm256_loadu_ps(x);
+        x += 8;
+        __m256 my = _mm256_loadu_ps(y);
+        y += 8;
+        const __m256 a_m_b = _mm256_sub_ps(mx, my);
+        msum1 = _mm256_add_ps(msum1, _mm256_and_ps(signmask, a_m_b));
         d -= 8;
     }
 
     __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
-    msum2 += _mm256_extractf128_ps(msum1, 0);
-    __m128 signmask2 = __m128(_mm_set1_epi32 (0x7fffffffUL));
+    msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));
+    __m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL));
 
     if (d >= 4) {
-        __m128 mx = _mm_loadu_ps (x); x += 4;
-        __m128 my = _mm_loadu_ps (y); y += 4;
-        const __m128 a_m_b = mx - my;
-        msum2 += _mm_and_ps(signmask2, a_m_b);
+        __m128 mx = _mm_loadu_ps(x);
+        x += 4;
+        __m128 my = _mm_loadu_ps(y);
+        y += 4;
+        const __m128 a_m_b = _mm_sub_ps(mx, my);
+        msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b));
         d -= 4;
     }
 
     if (d > 0) {
-        __m128 mx = masked_read (d, x);
-        __m128 my = masked_read (d, y);
-        __m128 a_m_b = mx - my;
-        msum2 += _mm_and_ps(signmask2, a_m_b);
+        __m128 mx = masked_read(d, x);
+        __m128 my = masked_read(d, y);
+        __m128 a_m_b = _mm_sub_ps(mx, my);
+        msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b));
     }
 
-    msum2 = _mm_hadd_ps (msum2, msum2);
-    msum2 = _mm_hadd_ps (msum2, msum2);
-    return _mm_cvtss_f32 (msum2);
+    msum2 = _mm_hadd_ps(msum2, msum2);
+    msum2 = _mm_hadd_ps(msum2, msum2);
+    return _mm_cvtss_f32(msum2);
 }
 
-float fvec_Linf (const float * x, const float * y, size_t d)
-{
+float fvec_Linf(const float* x, const float* y, size_t d) {
     __m256 msum1 = _mm256_setzero_ps();
-    __m256 signmask = __m256(_mm256_set1_epi32 (0x7fffffffUL));
+    __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL));
 
     while (d >= 8) {
-        __m256 mx = _mm256_loadu_ps (x); x += 8;
-        __m256 my = _mm256_loadu_ps (y); y += 8;
-        const __m256 a_m_b = mx - my;
+        __m256 mx = _mm256_loadu_ps(x);
+        x += 8;
+        __m256 my = _mm256_loadu_ps(y);
+        y += 8;
+        const __m256 a_m_b = _mm256_sub_ps(mx, my);
         msum1 = _mm256_max_ps(msum1, _mm256_and_ps(signmask, a_m_b));
         d -= 8;
     }
 
     __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
-    msum2 = _mm_max_ps (msum2, _mm256_extractf128_ps(msum1, 0));
-    __m128 signmask2 = __m128(_mm_set1_epi32 (0x7fffffffUL));
+    msum2 = _mm_max_ps(msum2, _mm256_extractf128_ps(msum1, 0));
+    __m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL));
 
     if (d >= 4) {
-        __m128 mx = _mm_loadu_ps (x); x += 4;
-        __m128 my = _mm_loadu_ps (y); y += 4;
-        const __m128 a_m_b = mx - my;
+        __m128 mx = _mm_loadu_ps(x);
+        x += 4;
+        __m128 my = _mm_loadu_ps(y);
+        y += 4;
+        const __m128 a_m_b = _mm_sub_ps(mx, my);
         msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b));
         d -= 4;
     }
 
     if (d > 0) {
-        __m128 mx = masked_read (d, x);
-        __m128 my = masked_read (d, y);
-        __m128 a_m_b = mx - my;
+        __m128 mx = masked_read(d, x);
+        __m128 my = masked_read(d, y);
+        __m128 a_m_b = _mm_sub_ps(mx, my);
         msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b));
     }
 
     msum2 = _mm_max_ps(_mm_movehl_ps(msum2, msum2), msum2);
-    msum2 = _mm_max_ps(msum2, _mm_shuffle_ps (msum2, msum2, 1));
-    return _mm_cvtss_f32 (msum2);
+    msum2 = _mm_max_ps(msum2, _mm_shuffle_ps(msum2, msum2, 1));
+    return _mm_cvtss_f32(msum2);
 }
 
 #elif defined(__SSE3__) // But not AVX
 
-float fvec_L1 (const float * x, const float * y, size_t d)
-{
-    return fvec_L1_ref (x, y, d);
+float fvec_L1(const float* x, const float* y, size_t d) {
+    return fvec_L1_ref(x, y, d);
 }
 
-float fvec_Linf (const float * x, const float * y, size_t d)
-{
-    return fvec_Linf_ref (x, y, d);
+float fvec_Linf(const float* x, const float* y, size_t d) {
+    return fvec_Linf_ref(x, y, d);
 }
 
-
-float fvec_L2sqr (const float * x,
-                  const float * y,
-                  size_t d)
-{
+float fvec_L2sqr(const float* x, const float* y, size_t d) {
     __m128 msum1 = _mm_setzero_ps();
 
     while (d >= 4) {
-        __m128 mx = _mm_loadu_ps (x); x += 4;
-        __m128 my = _mm_loadu_ps (y); y += 4;
-        const __m128 a_m_b1 = mx - my;
-        msum1 += a_m_b1 * a_m_b1;
+        __m128 mx = _mm_loadu_ps(x);
+        x += 4;
+        __m128 my = _mm_loadu_ps(y);
+        y += 4;
+        const __m128 a_m_b1 = _mm_sub_ps(mx, my);
+        msum1 = _mm_add_ps(msum1, _mm_mul_ps(a_m_b1, a_m_b1));
         d -= 4;
     }
 
     if (d > 0) {
         // add the last 1, 2 or 3 values
-        __m128 mx = masked_read (d, x);
-        __m128 my = masked_read (d, y);
-        __m128 a_m_b1 = mx - my;
-        msum1 += a_m_b1 * a_m_b1;
+        __m128 mx = masked_read(d, x);
+        __m128 my = masked_read(d, y);
+        __m128 a_m_b1 = _mm_sub_ps(mx, my);
+        msum1 = _mm_add_ps(msum1, _mm_mul_ps(a_m_b1, a_m_b1));
     }
 
-    msum1 = _mm_hadd_ps (msum1, msum1);
-    msum1 = _mm_hadd_ps (msum1, msum1);
-    return _mm_cvtss_f32 (msum1);
+    msum1 = _mm_hadd_ps(msum1, msum1);
+    msum1 = _mm_hadd_ps(msum1, msum1);
+    return _mm_cvtss_f32(msum1);
 }
 
-
-float fvec_inner_product (const float * x,
-                          const float * y,
-                          size_t d)
-{
+float fvec_inner_product(const float* x, const float* y, size_t d) {
     __m128 mx, my;
     __m128 msum1 = _mm_setzero_ps();
 
     while (d >= 4) {
-        mx = _mm_loadu_ps (x); x += 4;
-        my = _mm_loadu_ps (y); y += 4;
-        msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, my));
+        mx = _mm_loadu_ps(x);
+        x += 4;
+        my = _mm_loadu_ps(y);
+        y += 4;
+        msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, my));
         d -= 4;
     }
 
     // add the last 1, 2, or 3 values
-    mx = masked_read (d, x);
-    my = masked_read (d, y);
-    __m128 prod = _mm_mul_ps (mx, my);
+    mx = masked_read(d, x);
+    my = masked_read(d, y);
+    __m128 prod = _mm_mul_ps(mx, my);
 
-    msum1 = _mm_add_ps (msum1, prod);
+    msum1 = _mm_add_ps(msum1, prod);
 
-    msum1 = _mm_hadd_ps (msum1, msum1);
-    msum1 = _mm_hadd_ps (msum1, msum1);
-    return _mm_cvtss_f32 (msum1);
+    msum1 = _mm_hadd_ps(msum1, msum1);
+    msum1 = _mm_hadd_ps(msum1, msum1);
+    return _mm_cvtss_f32(msum1);
 }
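In the AVX branch above, fvec_L1 and fvec_Linf take absolute values by AND-ing the difference with 0x7fffffff, which clears the IEEE-754 sign bit; this version only replaces the informal __m256(...) casts with the explicit _mm256_castsi256_ps / _mm_castsi128_ps intrinsics. A scalar sketch of the same trick (illustrative helper, not part of faiss):

    // absolute value by masking the IEEE-754 sign bit,
    // the scalar analogue of _mm256_and_ps(signmask, a_m_b) above
    #include <cstdint>
    #include <cstring>

    static float fabs_by_mask(float v) {
        uint32_t bits;
        std::memcpy(&bits, &v, sizeof bits); // type-pun safely
        bits &= 0x7fffffffu;                 // clear the sign bit
        std::memcpy(&v, &bits, sizeof v);
        return v;
    }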
 
 #elif defined(__aarch64__)
 
-
-float fvec_L2sqr (const float * x,
-                  const float * y,
-                  size_t d)
-{
-    if (d & 3) return fvec_L2sqr_ref (x, y, d);
-    float32x4_t accu = vdupq_n_f32 (0);
-    for (size_t i = 0; i < d; i += 4) {
-        float32x4_t xi = vld1q_f32 (x + i);
-        float32x4_t yi = vld1q_f32 (y + i);
-        float32x4_t sq = vsubq_f32 (xi, yi);
-        accu = vfmaq_f32 (accu, sq, sq);
-    }
-    float32x4_t a2 = vpaddq_f32 (accu, accu);
-    return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
-}
-
-float fvec_inner_product (const float * x,
-                          const float * y,
-                          size_t d)
-{
-    if (d & 3) return fvec_inner_product_ref (x, y, d);
-    float32x4_t accu = vdupq_n_f32 (0);
-    for (size_t i = 0; i < d; i += 4) {
-        float32x4_t xi = vld1q_f32 (x + i);
-        float32x4_t yi = vld1q_f32 (y + i);
-        accu = vfmaq_f32 (accu, xi, yi);
-    }
-    float32x4_t a2 = vpaddq_f32 (accu, accu);
-    return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
-}
-
-float fvec_norm_L2sqr (const float *x, size_t d)
-{
-    if (d & 3) return fvec_norm_L2sqr_ref (x, d);
-    float32x4_t accu = vdupq_n_f32 (0);
-    for (size_t i = 0; i < d; i += 4) {
-        float32x4_t xi = vld1q_f32 (x + i);
-        accu = vfmaq_f32 (accu, xi, xi);
-    }
-    float32x4_t a2 = vpaddq_f32 (accu, accu);
-    return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
+float fvec_L2sqr(const float* x, const float* y, size_t d) {
+    float32x4_t accux4 = vdupq_n_f32(0);
+    const size_t d_simd = d - (d & 3);
+    size_t i;
+    for (i = 0; i < d_simd; i += 4) {
+        float32x4_t xi = vld1q_f32(x + i);
+        float32x4_t yi = vld1q_f32(y + i);
+        float32x4_t sq = vsubq_f32(xi, yi);
+        accux4 = vfmaq_f32(accux4, sq, sq);
+    }
+    float32x4_t accux2 = vpaddq_f32(accux4, accux4);
+    float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
+    for (; i < d; ++i) {
+        float32_t xi = x[i];
+        float32_t yi = y[i];
+        float32_t sq = xi - yi;
+        accux1 += sq * sq;
+    }
+    return accux1;
+}
+
+float fvec_inner_product(const float* x, const float* y, size_t d) {
+    float32x4_t accux4 = vdupq_n_f32(0);
+    const size_t d_simd = d - (d & 3);
+    size_t i;
+    for (i = 0; i < d_simd; i += 4) {
+        float32x4_t xi = vld1q_f32(x + i);
+        float32x4_t yi = vld1q_f32(y + i);
+        accux4 = vfmaq_f32(accux4, xi, yi);
+    }
+    float32x4_t accux2 = vpaddq_f32(accux4, accux4);
+    float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
+    for (; i < d; ++i) {
+        float32_t xi = x[i];
+        float32_t yi = y[i];
+        accux1 += xi * yi;
+    }
+    return accux1;
+}
+
+float fvec_norm_L2sqr(const float* x, size_t d) {
+    float32x4_t accux4 = vdupq_n_f32(0);
+    const size_t d_simd = d - (d & 3);
+    size_t i;
+    for (i = 0; i < d_simd; i += 4) {
+        float32x4_t xi = vld1q_f32(x + i);
+        accux4 = vfmaq_f32(accux4, xi, xi);
+    }
+    float32x4_t accux2 = vpaddq_f32(accux4, accux4);
+    float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
+    for (; i < d; ++i) {
+        float32_t xi = x[i];
+        accux1 += xi * xi;
+    }
+    return accux1;
 }
 
 // not optimized for ARM
-void fvec_L2sqr_ny (float * dis, const float * x,
-                    const float * y, size_t d, size_t ny) {
-    fvec_L2sqr_ny_ref (dis, x, y, d, ny);
+void fvec_L2sqr_ny(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny) {
+    fvec_L2sqr_ny_ref(dis, x, y, d, ny);
 }
 
-float fvec_L1 (const float * x, const float * y, size_t d)
-{
-    return fvec_L1_ref (x, y, d);
+float fvec_L1(const float* x, const float* y, size_t d) {
+    return fvec_L1_ref(x, y, d);
 }
 
-float fvec_Linf (const float * x, const float * y, size_t d)
-{
-    return fvec_Linf_ref (x, y, d);
+float fvec_Linf(const float* x, const float* y, size_t d) {
+    return fvec_Linf_ref(x, y, d);
 }
 
+void fvec_inner_products_ny(
+        float* dis,
+        const float* x,
+        const float* y,
+        size_t d,
+        size_t ny) {
+    fvec_inner_products_ny_ref(dis, x, y, d, ny);
+}
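In the rewritten aarch64 kernels above, dimensions that are not a multiple of 4 no longer fall back to the scalar reference implementation: the NEON loop covers d_simd = d - (d & 3) lanes and a scalar tail finishes the rest, with vpaddq_f32 plus two lane extracts doing the horizontal sum. A minimal sketch of that reduction (aarch64-only; vgetq_lane_f32 stands in for the vdups_laneq_f32 seen above):

    // horizontal sum of a float32x4_t, as used by the NEON kernels
    #include <arm_neon.h>

    static float hsum_f32x4(float32x4_t v) {
        float32x4_t p = vpaddq_f32(v, v); // pairwise add: {a+b, c+d, a+b, c+d}
        return vgetq_lane_f32(p, 0) + vgetq_lane_f32(p, 1);
    }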
671
665
 
672
666
  #else
673
667
  // scalar implementation
674
668
 
675
- float fvec_L2sqr (const float * x,
676
- const float * y,
677
- size_t d)
678
- {
679
- return fvec_L2sqr_ref (x, y, d);
669
+ float fvec_L2sqr(const float* x, const float* y, size_t d) {
670
+ return fvec_L2sqr_ref(x, y, d);
680
671
  }
681
672
 
682
- float fvec_L1 (const float * x, const float * y, size_t d)
683
- {
684
- return fvec_L1_ref (x, y, d);
673
+ float fvec_L1(const float* x, const float* y, size_t d) {
674
+ return fvec_L1_ref(x, y, d);
685
675
  }
686
676
 
687
- float fvec_Linf (const float * x, const float * y, size_t d)
688
- {
689
- return fvec_Linf_ref (x, y, d);
677
+ float fvec_Linf(const float* x, const float* y, size_t d) {
678
+ return fvec_Linf_ref(x, y, d);
690
679
  }
691
680
 
692
- float fvec_inner_product (const float * x,
693
- const float * y,
694
- size_t d)
695
- {
696
- return fvec_inner_product_ref (x, y, d);
681
+ float fvec_inner_product(const float* x, const float* y, size_t d) {
682
+ return fvec_inner_product_ref(x, y, d);
697
683
  }
698
684
 
699
- float fvec_norm_L2sqr (const float *x, size_t d)
700
- {
701
- return fvec_norm_L2sqr_ref (x, d);
685
+ float fvec_norm_L2sqr(const float* x, size_t d) {
686
+ return fvec_norm_L2sqr_ref(x, d);
702
687
  }
703
688
 
704
- void fvec_L2sqr_ny (float * dis, const float * x,
705
- const float * y, size_t d, size_t ny) {
706
- fvec_L2sqr_ny_ref (dis, x, y, d, ny);
689
+ void fvec_L2sqr_ny(
690
+ float* dis,
691
+ const float* x,
692
+ const float* y,
693
+ size_t d,
694
+ size_t ny) {
695
+ fvec_L2sqr_ny_ref(dis, x, y, d, ny);
707
696
  }
708
697
 
709
- void fvec_inner_products_ny (float * dis, const float * x,
710
- const float * y, size_t d, size_t ny) {
711
- fvec_inner_products_ny_ref (dis, x, y, d, ny);
698
+ void fvec_inner_products_ny(
699
+ float* dis,
700
+ const float* x,
701
+ const float* y,
702
+ size_t d,
703
+ size_t ny) {
704
+ fvec_inner_products_ny_ref(dis, x, y, d, ny);
712
705
  }
713
706
 
714
-
715
707
  #endif
716
708
 
717
-
718
-
719
-
720
-
721
-
722
-
723
-
724
-
725
-
726
-
727
-
728
-
729
-
730
-
731
-
732
-
733
-
734
-
735
-
736
709
  /***************************************************************************
737
710
  * heavily optimized table computations
738
711
  ***************************************************************************/
739
712
 
740
-
-static inline void fvec_madd_ref (size_t n, const float *a,
-                                  float bf, const float *b, float *c) {
+static inline void fvec_madd_ref(
+        size_t n,
+        const float* a,
+        float bf,
+        const float* b,
+        float* c) {
     for (size_t i = 0; i < n; i++)
         c[i] = a[i] + bf * b[i];
 }
 
 #ifdef __SSE3__
 
-static inline void fvec_madd_sse (size_t n, const float *a,
-                                  float bf, const float *b, float *c) {
+static inline void fvec_madd_sse(
+        size_t n,
+        const float* a,
+        float bf,
+        const float* b,
+        float* c) {
     n >>= 2;
-    __m128 bf4 = _mm_set_ps1 (bf);
-    __m128 * a4 = (__m128*)a;
-    __m128 * b4 = (__m128*)b;
-    __m128 * c4 = (__m128*)c;
+    __m128 bf4 = _mm_set_ps1(bf);
+    __m128* a4 = (__m128*)a;
+    __m128* b4 = (__m128*)b;
+    __m128* c4 = (__m128*)c;
 
     while (n--) {
-        *c4 = _mm_add_ps (*a4, _mm_mul_ps (bf4, *b4));
+        *c4 = _mm_add_ps(*a4, _mm_mul_ps(bf4, *b4));
         b4++;
         a4++;
         c4++;
     }
 }
 
-void fvec_madd (size_t n, const float *a,
-                float bf, const float *b, float *c)
-{
-    if ((n & 3) == 0 &&
-        ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
-        fvec_madd_sse (n, a, bf, b, c);
+void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
+    if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
+        fvec_madd_sse(n, a, bf, b, c);
     else
-        fvec_madd_ref (n, a, bf, b, c);
+        fvec_madd_ref(n, a, bf, b, c);
 }
 
 #else
 
-void fvec_madd (size_t n, const float *a,
-                float bf, const float *b, float *c)
-{
-    fvec_madd_ref (n, a, bf, b, c);
+void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
+    fvec_madd_ref(n, a, bf, b, c);
 }
 
 #endif
 
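`fvec_madd` only takes the SSE path when `n` is a multiple of 4 (whole `__m128` lanes) and all three pointers are 16-byte aligned; otherwise it falls back to the reference loop. A standalone sketch of the same test, written with `std::uintptr_t` instead of the `long` casts used above (a portability adjustment of mine, not what the vendored code does):

    #include <cstddef>
    #include <cstdint>

    // True when the SSE fast path is safe: n fills whole 4-float registers
    // and a, b, c are all 16-byte aligned. OR-ing the addresses first lets
    // a single mask test cover all three pointers at once.
    bool sse_path_ok(size_t n, const float* a, const float* b, const float* c) {
        return (n & 3) == 0 &&
                ((reinterpret_cast<std::uintptr_t>(a) |
                  reinterpret_cast<std::uintptr_t>(b) |
                  reinterpret_cast<std::uintptr_t>(c)) & 15) == 0;
    }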
-static inline int fvec_madd_and_argmin_ref (size_t n, const float *a,
-                                            float bf, const float *b, float *c) {
+static inline int fvec_madd_and_argmin_ref(
+        size_t n,
+        const float* a,
+        float bf,
+        const float* b,
+        float* c) {
     float vmin = 1e20;
     int imin = -1;
 
@@ -799,125 +778,100 @@ static inline int fvec_madd_and_argmin_ref (size_t n, const float *a,
 
 #ifdef __SSE3__
 
-static inline int fvec_madd_and_argmin_sse (
-        size_t n, const float *a,
-        float bf, const float *b, float *c) {
+static inline int fvec_madd_and_argmin_sse(
+        size_t n,
+        const float* a,
+        float bf,
+        const float* b,
+        float* c) {
     n >>= 2;
-    __m128 bf4 = _mm_set_ps1 (bf);
-    __m128 vmin4 = _mm_set_ps1 (1e20);
-    __m128i imin4 = _mm_set1_epi32 (-1);
-    __m128i idx4 = _mm_set_epi32 (3, 2, 1, 0);
-    __m128i inc4 = _mm_set1_epi32 (4);
-    __m128 * a4 = (__m128*)a;
-    __m128 * b4 = (__m128*)b;
-    __m128 * c4 = (__m128*)c;
+    __m128 bf4 = _mm_set_ps1(bf);
+    __m128 vmin4 = _mm_set_ps1(1e20);
+    __m128i imin4 = _mm_set1_epi32(-1);
+    __m128i idx4 = _mm_set_epi32(3, 2, 1, 0);
+    __m128i inc4 = _mm_set1_epi32(4);
+    __m128* a4 = (__m128*)a;
+    __m128* b4 = (__m128*)b;
+    __m128* c4 = (__m128*)c;
 
     while (n--) {
-        __m128 vc4 = _mm_add_ps (*a4, _mm_mul_ps (bf4, *b4));
+        __m128 vc4 = _mm_add_ps(*a4, _mm_mul_ps(bf4, *b4));
         *c4 = vc4;
-        __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
+        __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4));
         // imin4 = _mm_blendv_epi8 (imin4, idx4, mask); // slower!
 
-        imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
-                              _mm_andnot_si128 (mask, imin4));
-        vmin4 = _mm_min_ps (vmin4, vc4);
+        imin4 = _mm_or_si128(
+                _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4));
+        vmin4 = _mm_min_ps(vmin4, vc4);
         b4++;
         a4++;
         c4++;
-        idx4 = _mm_add_epi32 (idx4, inc4);
+        idx4 = _mm_add_epi32(idx4, inc4);
     }
 
     // 4 values -> 2
     {
-        idx4 = _mm_shuffle_epi32 (imin4, 3 << 2 | 2);
-        __m128 vc4 = _mm_shuffle_ps (vmin4, vmin4, 3 << 2 | 2);
-        __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
-        imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
-                              _mm_andnot_si128 (mask, imin4));
-        vmin4 = _mm_min_ps (vmin4, vc4);
+        idx4 = _mm_shuffle_epi32(imin4, 3 << 2 | 2);
+        __m128 vc4 = _mm_shuffle_ps(vmin4, vmin4, 3 << 2 | 2);
+        __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4));
+        imin4 = _mm_or_si128(
+                _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4));
+        vmin4 = _mm_min_ps(vmin4, vc4);
     }
     // 2 values -> 1
     {
-        idx4 = _mm_shuffle_epi32 (imin4, 1);
-        __m128 vc4 = _mm_shuffle_ps (vmin4, vmin4, 1);
-        __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
-        imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
-                              _mm_andnot_si128 (mask, imin4));
+        idx4 = _mm_shuffle_epi32(imin4, 1);
+        __m128 vc4 = _mm_shuffle_ps(vmin4, vmin4, 1);
+        __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4));
+        imin4 = _mm_or_si128(
+                _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4));
         // vmin4 = _mm_min_ps (vmin4, vc4);
     }
-    return _mm_cvtsi128_si32 (imin4);
+    return _mm_cvtsi128_si32(imin4);
 }
 
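Beyond the reformatting, one substantive change in this hunk: the C-style register cast `(__m128i)_mm_cmpgt_ps(...)` is replaced by `_mm_castps_si128`, the intrinsic defined for exactly this reinterpretation. A minimal standalone illustration (my example, not taken from the diff):

    #include <emmintrin.h> // SSE2

    // _mm_castps_si128 reinterprets the 128-bit float register as integers
    // without generating any instruction; unlike the GCC-style C cast
    // (__m128i)v, it is also accepted by MSVC.
    __m128i mask_from_compare(__m128 a, __m128 b) {
        return _mm_castps_si128(_mm_cmpgt_ps(a, b));
    }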
-
-int fvec_madd_and_argmin (size_t n, const float *a,
-                          float bf, const float *b, float *c)
-{
-    if ((n & 3) == 0 &&
-        ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
-        return fvec_madd_and_argmin_sse (n, a, bf, b, c);
+int fvec_madd_and_argmin(
+        size_t n,
+        const float* a,
+        float bf,
+        const float* b,
+        float* c) {
+    if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
+        return fvec_madd_and_argmin_sse(n, a, bf, b, c);
     else
-        return fvec_madd_and_argmin_ref (n, a, bf, b, c);
+        return fvec_madd_and_argmin_ref(n, a, bf, b, c);
 }
 
 #else
 
-int fvec_madd_and_argmin (size_t n, const float *a,
-                          float bf, const float *b, float *c)
-{
-    return fvec_madd_and_argmin_ref (n, a, bf, b, c);
+int fvec_madd_and_argmin(
+        size_t n,
+        const float* a,
+        float bf,
+        const float* b,
+        float* c) {
+    return fvec_madd_and_argmin_ref(n, a, bf, b, c);
 }
 
 #endif
 
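As the reference implementation above shows, `fvec_madd_and_argmin` writes `c[i] = a[i] + bf * b[i]` and returns the index of the smallest resulting `c[i]`. A toy call with made-up values (assuming the declaration comes from `faiss/utils/distances.h`):

    #include <faiss/utils/distances.h>

    int demo() {
        float a[4] = {1, 2, 3, 4};
        float b[4] = {4, 3, 2, 1};
        float c[4];
        // c becomes {-3, -1, 1, 3}, so the returned argmin is 0.
        return faiss::fvec_madd_and_argmin(4, a, -1.0f, b, c);
    }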
-
 /***************************************************************************
  * PQ tables computations
  ***************************************************************************/
 
-#ifdef __AVX2__
-
 namespace {
 
-
-// get even float32's of a and b, interleaved
-simd8float32 geteven(simd8float32 a, simd8float32 b) {
-    return simd8float32(
-        _mm256_shuffle_ps(a.f, b.f, 0 << 0 | 2 << 2 | 0 << 4 | 2 << 6)
-    );
-}
-
-// get odd float32's of a and b, interleaved
-simd8float32 getodd(simd8float32 a, simd8float32 b) {
-    return simd8float32(
-        _mm256_shuffle_ps(a.f, b.f, 1 << 0 | 3 << 2 | 1 << 4 | 3 << 6)
-    );
-}
-
-// 3 cycles
-// if the lanes are a = [a0 a1] and b = [b0 b1], return [a0 b0]
-simd8float32 getlow128(simd8float32 a, simd8float32 b) {
-    return simd8float32(
-        _mm256_permute2f128_ps(a.f, b.f, 0 | 2 << 4)
-    );
-}
-
-simd8float32 gethigh128(simd8float32 a, simd8float32 b) {
-    return simd8float32(
-        _mm256_permute2f128_ps(a.f, b.f, 1 | 3 << 4)
-    );
-}
-
 /// compute the IP for dsub = 2 for 8 centroids and 4 sub-vectors at a time
-template<bool is_inner_product>
+template <bool is_inner_product>
 void pq2_8cents_table(
         const simd8float32 centroids[8],
         const simd8float32 x,
-        float *out, size_t ldo, size_t nout = 4
-) {
-
+        float* out,
+        size_t ldo,
+        size_t nout = 4) {
     simd8float32 ips[4];
 
-    for(int i = 0; i < 4; i++) {
+    for (int i = 0; i < 4; i++) {
         simd8float32 p1, p2;
         if (is_inner_product) {
             p1 = x * centroids[2 * i];
@@ -941,21 +895,21 @@ void pq2_8cents_table(
     simd8float32 ip1 = getlow128(ip13a, ip13b);
     simd8float32 ip3 = gethigh128(ip13a, ip13b);
 
-    switch(nout) {
-    case 4:
-        ip3.storeu(out + 3 * ldo);
-    case 3:
-        ip2.storeu(out + 2 * ldo);
-    case 2:
-        ip1.storeu(out + 1 * ldo);
-    case 1:
-        ip0.storeu(out);
+    switch (nout) {
+        case 4:
+            ip3.storeu(out + 3 * ldo);
+        case 3:
+            ip2.storeu(out + 2 * ldo);
+        case 2:
+            ip1.storeu(out + 1 * ldo);
+        case 1:
+            ip0.storeu(out);
     }
 }
 
-simd8float32 load_simd8float32_partial(const float *x, int n) {
+simd8float32 load_simd8float32_partial(const float* x, int n) {
     ALIGNED(32) float tmp[8] = {0, 0, 0, 0, 0, 0, 0, 0};
-    float *wp = tmp;
+    float* wp = tmp;
     for (int i = 0; i < n; i++) {
         *wp++ = *x++;
     }
@@ -964,25 +918,23 @@ simd8float32 load_simd8float32_partial(const float *x, int n) {
 
 } // anonymous namespace
 
-
-
-
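The `switch (nout)` in `pq2_8cents_table` above relies on deliberate case fallthrough: requesting `nout` rows stores rows `nout - 1` down to `0`, each 8 floats wide with stride `ldo`. A scalar analogue of the same pattern, with hypothetical names of my own (illustrative only):

    #include <cstddef>
    #include <cstring>

    // Falling through the cases stores exactly the first `nout` rows of a
    // 4 x 8 float block, row r going to out + r * ldo.
    void store_first_rows(int nout, float* out, size_t ldo, const float* rows) {
        switch (nout) {
            case 4:
                std::memcpy(out + 3 * ldo, rows + 24, 8 * sizeof(float));
                // fall through
            case 3:
                std::memcpy(out + 2 * ldo, rows + 16, 8 * sizeof(float));
                // fall through
            case 2:
                std::memcpy(out + 1 * ldo, rows + 8, 8 * sizeof(float));
                // fall through
            case 1:
                std::memcpy(out, rows, 8 * sizeof(float));
        }
    }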
 void compute_PQ_dis_tables_dsub2(
-        size_t d, size_t ksub, const float *all_centroids,
-        size_t nx, const float * x,
+        size_t d,
+        size_t ksub,
+        const float* all_centroids,
+        size_t nx,
+        const float* x,
         bool is_inner_product,
-        float * dis_tables)
-{
+        float* dis_tables) {
     size_t M = d / 2;
     FAISS_THROW_IF_NOT(ksub % 8 == 0);
 
-    for(size_t m0 = 0; m0 < M; m0 += 4) {
+    for (size_t m0 = 0; m0 < M; m0 += 4) {
         int m1 = std::min(M, m0 + 4);
-        for(int k0 = 0; k0 < ksub; k0 += 8) {
-
+        for (int k0 = 0; k0 < ksub; k0 += 8) {
             simd8float32 centroids[8];
             for (int k = 0; k < 8; k++) {
-                float centroid[8] __attribute__((aligned(32)));
+                ALIGNED(32) float centroid[8];
                 size_t wp = 0;
                 size_t rp = (m0 * ksub + k + k0) * 2;
                 for (int m = m0; m < m1; m++) {
@@ -992,45 +944,33 @@ void compute_PQ_dis_tables_dsub2(
                 }
                 centroids[k] = simd8float32(centroid);
             }
-            for(size_t i = 0; i < nx; i++) {
+            for (size_t i = 0; i < nx; i++) {
                 simd8float32 xi;
                 if (m1 == m0 + 4) {
                     xi.loadu(x + i * d + m0 * 2);
                 } else {
-                    xi = load_simd8float32_partial(x + i * d + m0 * 2, 2 * (m1 - m0));
+                    xi = load_simd8float32_partial(
+                            x + i * d + m0 * 2, 2 * (m1 - m0));
                 }
 
-                if(is_inner_product) {
+                if (is_inner_product) {
                     pq2_8cents_table<true>(
-                        centroids, xi,
-                        dis_tables + (i * M + m0) * ksub + k0,
-                        ksub, m1 - m0
-                    );
+                            centroids,
+                            xi,
+                            dis_tables + (i * M + m0) * ksub + k0,
+                            ksub,
+                            m1 - m0);
                 } else {
                     pq2_8cents_table<false>(
-                        centroids, xi,
-                        dis_tables + (i * M + m0) * ksub + k0,
-                        ksub, m1 - m0
-                    );
+                            centroids,
+                            xi,
+                            dis_tables + (i * M + m0) * ksub + k0,
+                            ksub,
+                            m1 - m0);
                 }
             }
         }
     }
-
 }
 
-#else
-
-void compute_PQ_dis_tables_dsub2(
-        size_t d, size_t ksub, const float *all_centroids,
-        size_t nx, const float * x,
-        bool is_inner_product,
-        float * dis_tables)
-{
-    FAISS_THROW_MSG("only implemented for AVX2");
-}
-
-#endif
-
-
  } // namespace faiss
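`compute_PQ_dis_tables_dsub2` fills, for each of the `nx` queries, an `M x ksub` table of distances (or inner products) between every 2-d query sub-vector and the `ksub` centroids of the corresponding sub-quantizer; the entry for query `i`, sub-quantizer `m`, centroid `k` lands at `dis_tables[(i * M + m) * ksub + k]`. A usage sketch under the constraints visible in the code (`d` even since dsub is fixed at 2, and `ksub % 8 == 0` is enforced); the sizes are illustrative and the declaration is assumed to come from `faiss/utils/distances.h`:

    #include <vector>
    #include <faiss/utils/distances.h>

    void sketch() {
        size_t d = 32, ksub = 256, nx = 10;
        size_t M = d / 2; // number of 2-d sub-quantizers
        std::vector<float> centroids(M * ksub * 2); // ksub 2-d centroids per sub-quantizer
        std::vector<float> x(nx * d);               // query vectors
        std::vector<float> dis_tables(nx * M * ksub);
        // ... fill centroids and x ...
        faiss::compute_PQ_dis_tables_dsub2(
                d, ksub, centroids.data(), nx, x.data(),
                /*is_inner_product=*/false, dis_tables.data());
    }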