faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -9,13 +9,15 @@
9
9
 
10
10
  #include <faiss/utils/distances.h>
11
11
 
12
- #include <cstdio>
12
+ #include <algorithm>
13
13
  #include <cassert>
14
- #include <cstring>
15
14
  #include <cmath>
15
+ #include <cstdio>
16
+ #include <cstring>
16
17
 
17
- #include <faiss/utils/simdlib.h>
18
18
  #include <faiss/impl/FaissAssert.h>
19
+ #include <faiss/impl/platform_macros.h>
20
+ #include <faiss/utils/simdlib.h>
19
21
 
20
22
  #ifdef __SSE3__
21
23
  #include <immintrin.h>
@@ -25,19 +27,16 @@
25
27
  #include <arm_neon.h>
26
28
  #endif
27
29
 
28
-
29
30
  namespace faiss {
30
31
 
31
32
  #ifdef __AVX__
32
33
  #define USE_AVX
33
34
  #endif
34
35
 
35
-
36
36
  /*********************************************************
37
37
  * Optimized distance computations
38
38
  *********************************************************/
39
39
 
40
-
41
40
  /* Functions to compute:
42
41
  - L2 distance between 2 vectors
43
42
  - inner product between 2 vectors
@@ -53,29 +52,21 @@ namespace faiss {
53
52
 
54
53
  */
55
54
 
56
-
57
55
  /*********************************************************
58
56
  * Reference implementations
59
57
  */
60
58
 
61
-
62
- float fvec_L2sqr_ref (const float * x,
63
- const float * y,
64
- size_t d)
65
- {
59
+ float fvec_L2sqr_ref(const float* x, const float* y, size_t d) {
66
60
  size_t i;
67
61
  float res = 0;
68
62
  for (i = 0; i < d; i++) {
69
63
  const float tmp = x[i] - y[i];
70
- res += tmp * tmp;
64
+ res += tmp * tmp;
71
65
  }
72
66
  return res;
73
67
  }
74
68
 
75
- float fvec_L1_ref (const float * x,
76
- const float * y,
77
- size_t d)
78
- {
69
+ float fvec_L1_ref(const float* x, const float* y, size_t d) {
79
70
  size_t i;
80
71
  float res = 0;
81
72
  for (i = 0; i < d; i++) {
@@ -85,56 +76,49 @@ float fvec_L1_ref (const float * x,
85
76
  return res;
86
77
  }
87
78
 
88
- float fvec_Linf_ref (const float * x,
89
- const float * y,
90
- size_t d)
91
- {
79
+ float fvec_Linf_ref(const float* x, const float* y, size_t d) {
92
80
  size_t i;
93
81
  float res = 0;
94
82
  for (i = 0; i < d; i++) {
95
- res = fmax(res, fabs(x[i] - y[i]));
83
+ res = fmax(res, fabs(x[i] - y[i]));
96
84
  }
97
85
  return res;
98
86
  }
99
87
 
100
- float fvec_inner_product_ref (const float * x,
101
- const float * y,
102
- size_t d)
103
- {
88
+ float fvec_inner_product_ref(const float* x, const float* y, size_t d) {
104
89
  size_t i;
105
90
  float res = 0;
106
91
  for (i = 0; i < d; i++)
107
- res += x[i] * y[i];
92
+ res += x[i] * y[i];
108
93
  return res;
109
94
  }
110
95
 
111
- float fvec_norm_L2sqr_ref (const float *x, size_t d)
112
- {
96
+ float fvec_norm_L2sqr_ref(const float* x, size_t d) {
113
97
  size_t i;
114
98
  double res = 0;
115
99
  for (i = 0; i < d; i++)
116
- res += x[i] * x[i];
100
+ res += x[i] * x[i];
117
101
  return res;
118
102
  }
119
103
 
120
-
121
- void fvec_L2sqr_ny_ref (float * dis,
122
- const float * x,
123
- const float * y,
124
- size_t d, size_t ny)
125
- {
104
+ void fvec_L2sqr_ny_ref(
105
+ float* dis,
106
+ const float* x,
107
+ const float* y,
108
+ size_t d,
109
+ size_t ny) {
126
110
  for (size_t i = 0; i < ny; i++) {
127
- dis[i] = fvec_L2sqr (x, y, d);
111
+ dis[i] = fvec_L2sqr(x, y, d);
128
112
  y += d;
129
113
  }
130
114
  }
131
115
 
132
-
133
- void fvec_inner_products_ny_ref (float * ip,
134
- const float * x,
135
- const float * y,
136
- size_t d, size_t ny)
137
- {
116
+ void fvec_inner_products_ny_ref(
117
+ float* ip,
118
+ const float* x,
119
+ const float* y,
120
+ size_t d,
121
+ size_t ny) {
138
122
  // BLAS slower for the use cases here
139
123
  #if 0
140
124
  {
@@ -146,15 +130,11 @@ void fvec_inner_products_ny_ref (float * ip,
146
130
  }
147
131
  #endif
148
132
  for (size_t i = 0; i < ny; i++) {
149
- ip[i] = fvec_inner_product (x, y, d);
133
+ ip[i] = fvec_inner_product(x, y, d);
150
134
  y += d;
151
135
  }
152
136
  }
153
137
 
154
-
155
-
156
-
157
-
158
138
  /*********************************************************
159
139
  * SSE and AVX implementations
160
140
  */
@@ -162,40 +142,38 @@ void fvec_inner_products_ny_ref (float * ip,
162
142
  #ifdef __SSE3__
163
143
 
164
144
  // reads 0 <= d < 4 floats as __m128
165
- static inline __m128 masked_read (int d, const float *x)
166
- {
167
- assert (0 <= d && d < 4);
168
- __attribute__((__aligned__(16))) float buf[4] = {0, 0, 0, 0};
145
+ static inline __m128 masked_read(int d, const float* x) {
146
+ assert(0 <= d && d < 4);
147
+ ALIGNED(16) float buf[4] = {0, 0, 0, 0};
169
148
  switch (d) {
170
- case 3:
171
- buf[2] = x[2];
172
- case 2:
173
- buf[1] = x[1];
174
- case 1:
175
- buf[0] = x[0];
149
+ case 3:
150
+ buf[2] = x[2];
151
+ case 2:
152
+ buf[1] = x[1];
153
+ case 1:
154
+ buf[0] = x[0];
176
155
  }
177
- return _mm_load_ps (buf);
156
+ return _mm_load_ps(buf);
178
157
  // cannot use AVX2 _mm_mask_set1_epi32
179
158
  }
180
159
 
181
- float fvec_norm_L2sqr (const float * x,
182
- size_t d)
183
- {
160
+ float fvec_norm_L2sqr(const float* x, size_t d) {
184
161
  __m128 mx;
185
162
  __m128 msum1 = _mm_setzero_ps();
186
163
 
187
164
  while (d >= 4) {
188
- mx = _mm_loadu_ps (x); x += 4;
189
- msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, mx));
165
+ mx = _mm_loadu_ps(x);
166
+ x += 4;
167
+ msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, mx));
190
168
  d -= 4;
191
169
  }
192
170
 
193
- mx = masked_read (d, x);
194
- msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, mx));
171
+ mx = masked_read(d, x);
172
+ msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, mx));
195
173
 
196
- msum1 = _mm_hadd_ps (msum1, msum1);
197
- msum1 = _mm_hadd_ps (msum1, msum1);
198
- return _mm_cvtss_f32 (msum1);
174
+ msum1 = _mm_hadd_ps(msum1, msum1);
175
+ msum1 = _mm_hadd_ps(msum1, msum1);
176
+ return _mm_cvtss_f32(msum1);
199
177
  }
200
178
 
201
179
  namespace {
@@ -204,586 +182,588 @@ namespace {
204
182
  /// to compute L2 distances. ElementOp can then be used in the fvec_op_ny
205
183
  /// functions below
206
184
  struct ElementOpL2 {
207
-
208
- static float op (float x, float y) {
185
+ static float op(float x, float y) {
209
186
  float tmp = x - y;
210
187
  return tmp * tmp;
211
188
  }
212
189
 
213
- static __m128 op (__m128 x, __m128 y) {
214
- __m128 tmp = x - y;
215
- return tmp * tmp;
190
+ static __m128 op(__m128 x, __m128 y) {
191
+ __m128 tmp = _mm_sub_ps(x, y);
192
+ return _mm_mul_ps(tmp, tmp);
216
193
  }
217
-
218
194
  };
219
195
 
220
196
  /// Function that does a component-wise operation between x and y
221
197
  /// to compute inner products
222
198
  struct ElementOpIP {
223
-
224
- static float op (float x, float y) {
199
+ static float op(float x, float y) {
225
200
  return x * y;
226
201
  }
227
202
 
228
- static __m128 op (__m128 x, __m128 y) {
229
- return x * y;
203
+ static __m128 op(__m128 x, __m128 y) {
204
+ return _mm_mul_ps(x, y);
230
205
  }
231
-
232
206
  };
233
207
 
234
- template<class ElementOp>
235
- void fvec_op_ny_D1 (float * dis, const float * x,
236
- const float * y, size_t ny)
237
- {
208
+ template <class ElementOp>
209
+ void fvec_op_ny_D1(float* dis, const float* x, const float* y, size_t ny) {
238
210
  float x0s = x[0];
239
- __m128 x0 = _mm_set_ps (x0s, x0s, x0s, x0s);
211
+ __m128 x0 = _mm_set_ps(x0s, x0s, x0s, x0s);
240
212
 
241
213
  size_t i;
242
214
  for (i = 0; i + 3 < ny; i += 4) {
243
- __m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
244
- dis[i] = _mm_cvtss_f32 (accu);
245
- __m128 tmp = _mm_shuffle_ps (accu, accu, 1);
246
- dis[i + 1] = _mm_cvtss_f32 (tmp);
247
- tmp = _mm_shuffle_ps (accu, accu, 2);
248
- dis[i + 2] = _mm_cvtss_f32 (tmp);
249
- tmp = _mm_shuffle_ps (accu, accu, 3);
250
- dis[i + 3] = _mm_cvtss_f32 (tmp);
215
+ __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
216
+ y += 4;
217
+ dis[i] = _mm_cvtss_f32(accu);
218
+ __m128 tmp = _mm_shuffle_ps(accu, accu, 1);
219
+ dis[i + 1] = _mm_cvtss_f32(tmp);
220
+ tmp = _mm_shuffle_ps(accu, accu, 2);
221
+ dis[i + 2] = _mm_cvtss_f32(tmp);
222
+ tmp = _mm_shuffle_ps(accu, accu, 3);
223
+ dis[i + 3] = _mm_cvtss_f32(tmp);
251
224
  }
252
225
  while (i < ny) { // handle non-multiple-of-4 case
253
226
  dis[i++] = ElementOp::op(x0s, *y++);
254
227
  }
255
228
  }
256
229
 
257
- template<class ElementOp>
258
- void fvec_op_ny_D2 (float * dis, const float * x,
259
- const float * y, size_t ny)
260
- {
261
- __m128 x0 = _mm_set_ps (x[1], x[0], x[1], x[0]);
230
+ template <class ElementOp>
231
+ void fvec_op_ny_D2(float* dis, const float* x, const float* y, size_t ny) {
232
+ __m128 x0 = _mm_set_ps(x[1], x[0], x[1], x[0]);
262
233
 
263
234
  size_t i;
264
235
  for (i = 0; i + 1 < ny; i += 2) {
265
- __m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
266
- accu = _mm_hadd_ps (accu, accu);
267
- dis[i] = _mm_cvtss_f32 (accu);
268
- accu = _mm_shuffle_ps (accu, accu, 3);
269
- dis[i + 1] = _mm_cvtss_f32 (accu);
236
+ __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
237
+ y += 4;
238
+ accu = _mm_hadd_ps(accu, accu);
239
+ dis[i] = _mm_cvtss_f32(accu);
240
+ accu = _mm_shuffle_ps(accu, accu, 3);
241
+ dis[i + 1] = _mm_cvtss_f32(accu);
270
242
  }
271
243
  if (i < ny) { // handle odd case
272
244
  dis[i] = ElementOp::op(x[0], y[0]) + ElementOp::op(x[1], y[1]);
273
245
  }
274
246
  }
275
247
 
276
-
277
-
278
- template<class ElementOp>
279
- void fvec_op_ny_D4 (float * dis, const float * x,
280
- const float * y, size_t ny)
281
- {
248
+ template <class ElementOp>
249
+ void fvec_op_ny_D4(float* dis, const float* x, const float* y, size_t ny) {
282
250
  __m128 x0 = _mm_loadu_ps(x);
283
251
 
284
252
  for (size_t i = 0; i < ny; i++) {
285
- __m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
286
- accu = _mm_hadd_ps (accu, accu);
287
- accu = _mm_hadd_ps (accu, accu);
288
- dis[i] = _mm_cvtss_f32 (accu);
253
+ __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
254
+ y += 4;
255
+ accu = _mm_hadd_ps(accu, accu);
256
+ accu = _mm_hadd_ps(accu, accu);
257
+ dis[i] = _mm_cvtss_f32(accu);
289
258
  }
290
259
  }
291
260
 
292
- template<class ElementOp>
293
- void fvec_op_ny_D8 (float * dis, const float * x,
294
- const float * y, size_t ny)
295
- {
261
+ template <class ElementOp>
262
+ void fvec_op_ny_D8(float* dis, const float* x, const float* y, size_t ny) {
296
263
  __m128 x0 = _mm_loadu_ps(x);
297
264
  __m128 x1 = _mm_loadu_ps(x + 4);
298
265
 
299
266
  for (size_t i = 0; i < ny; i++) {
300
- __m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
301
- accu += ElementOp::op(x1, _mm_loadu_ps (y)); y += 4;
302
- accu = _mm_hadd_ps (accu, accu);
303
- accu = _mm_hadd_ps (accu, accu);
304
- dis[i] = _mm_cvtss_f32 (accu);
267
+ __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
268
+ y += 4;
269
+ accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y)));
270
+ y += 4;
271
+ accu = _mm_hadd_ps(accu, accu);
272
+ accu = _mm_hadd_ps(accu, accu);
273
+ dis[i] = _mm_cvtss_f32(accu);
305
274
  }
306
275
  }
307
276
 
308
- template<class ElementOp>
309
- void fvec_op_ny_D12 (float * dis, const float * x,
310
- const float * y, size_t ny)
311
- {
277
+ template <class ElementOp>
278
+ void fvec_op_ny_D12(float* dis, const float* x, const float* y, size_t ny) {
312
279
  __m128 x0 = _mm_loadu_ps(x);
313
280
  __m128 x1 = _mm_loadu_ps(x + 4);
314
281
  __m128 x2 = _mm_loadu_ps(x + 8);
315
282
 
316
283
  for (size_t i = 0; i < ny; i++) {
317
- __m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
318
- accu += ElementOp::op(x1, _mm_loadu_ps (y)); y += 4;
319
- accu += ElementOp::op(x2, _mm_loadu_ps (y)); y += 4;
320
- accu = _mm_hadd_ps (accu, accu);
321
- accu = _mm_hadd_ps (accu, accu);
322
- dis[i] = _mm_cvtss_f32 (accu);
284
+ __m128 accu = ElementOp::op(x0, _mm_loadu_ps(y));
285
+ y += 4;
286
+ accu = _mm_add_ps(accu, ElementOp::op(x1, _mm_loadu_ps(y)));
287
+ y += 4;
288
+ accu = _mm_add_ps(accu, ElementOp::op(x2, _mm_loadu_ps(y)));
289
+ y += 4;
290
+ accu = _mm_hadd_ps(accu, accu);
291
+ accu = _mm_hadd_ps(accu, accu);
292
+ dis[i] = _mm_cvtss_f32(accu);
323
293
  }
324
294
  }
325
295
 
326
-
327
-
328
296
  } // anonymous namespace
329
297
 
330
- void fvec_L2sqr_ny (float * dis, const float * x,
331
- const float * y, size_t d, size_t ny) {
298
+ void fvec_L2sqr_ny(
299
+ float* dis,
300
+ const float* x,
301
+ const float* y,
302
+ size_t d,
303
+ size_t ny) {
332
304
  // optimized for a few special cases
333
305
 
334
- #define DISPATCH(dval) \
335
- case dval:\
336
- fvec_op_ny_D ## dval <ElementOpL2> (dis, x, y, ny); \
306
+ #define DISPATCH(dval) \
307
+ case dval: \
308
+ fvec_op_ny_D##dval<ElementOpL2>(dis, x, y, ny); \
337
309
  return;
338
310
 
339
- switch(d) {
311
+ switch (d) {
340
312
  DISPATCH(1)
341
313
  DISPATCH(2)
342
314
  DISPATCH(4)
343
315
  DISPATCH(8)
344
316
  DISPATCH(12)
345
- default:
346
- fvec_L2sqr_ny_ref (dis, x, y, d, ny);
347
- return;
317
+ default:
318
+ fvec_L2sqr_ny_ref(dis, x, y, d, ny);
319
+ return;
348
320
  }
349
321
  #undef DISPATCH
350
-
351
322
  }
352
323
 
353
- void fvec_inner_products_ny (float * dis, const float * x,
354
- const float * y, size_t d, size_t ny) {
355
-
356
- #define DISPATCH(dval) \
357
- case dval:\
358
- fvec_op_ny_D ## dval <ElementOpIP> (dis, x, y, ny); \
324
+ void fvec_inner_products_ny(
325
+ float* dis,
326
+ const float* x,
327
+ const float* y,
328
+ size_t d,
329
+ size_t ny) {
330
+ #define DISPATCH(dval) \
331
+ case dval: \
332
+ fvec_op_ny_D##dval<ElementOpIP>(dis, x, y, ny); \
359
333
  return;
360
334
 
361
- switch(d) {
335
+ switch (d) {
362
336
  DISPATCH(1)
363
337
  DISPATCH(2)
364
338
  DISPATCH(4)
365
339
  DISPATCH(8)
366
340
  DISPATCH(12)
367
- default:
368
- fvec_inner_products_ny_ref (dis, x, y, d, ny);
369
- return;
341
+ default:
342
+ fvec_inner_products_ny_ref(dis, x, y, d, ny);
343
+ return;
370
344
  }
371
345
  #undef DISPATCH
372
-
373
346
  }
374
347
 
375
-
376
-
377
348
  #endif
378
349
 
379
350
  #ifdef USE_AVX
380
351
 
381
352
  // reads 0 <= d < 8 floats as __m256
382
- static inline __m256 masked_read_8 (int d, const float *x)
383
- {
384
- assert (0 <= d && d < 8);
353
+ static inline __m256 masked_read_8(int d, const float* x) {
354
+ assert(0 <= d && d < 8);
385
355
  if (d < 4) {
386
- __m256 res = _mm256_setzero_ps ();
387
- res = _mm256_insertf128_ps (res, masked_read (d, x), 0);
356
+ __m256 res = _mm256_setzero_ps();
357
+ res = _mm256_insertf128_ps(res, masked_read(d, x), 0);
388
358
  return res;
389
359
  } else {
390
- __m256 res = _mm256_setzero_ps ();
391
- res = _mm256_insertf128_ps (res, _mm_loadu_ps (x), 0);
392
- res = _mm256_insertf128_ps (res, masked_read (d - 4, x + 4), 1);
360
+ __m256 res = _mm256_setzero_ps();
361
+ res = _mm256_insertf128_ps(res, _mm_loadu_ps(x), 0);
362
+ res = _mm256_insertf128_ps(res, masked_read(d - 4, x + 4), 1);
393
363
  return res;
394
364
  }
395
365
  }
396
366
 
397
- float fvec_inner_product (const float * x,
398
- const float * y,
399
- size_t d)
400
- {
367
+ float fvec_inner_product(const float* x, const float* y, size_t d) {
401
368
  __m256 msum1 = _mm256_setzero_ps();
402
369
 
403
370
  while (d >= 8) {
404
- __m256 mx = _mm256_loadu_ps (x); x += 8;
405
- __m256 my = _mm256_loadu_ps (y); y += 8;
406
- msum1 = _mm256_add_ps (msum1, _mm256_mul_ps (mx, my));
371
+ __m256 mx = _mm256_loadu_ps(x);
372
+ x += 8;
373
+ __m256 my = _mm256_loadu_ps(y);
374
+ y += 8;
375
+ msum1 = _mm256_add_ps(msum1, _mm256_mul_ps(mx, my));
407
376
  d -= 8;
408
377
  }
409
378
 
410
379
  __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
411
- msum2 += _mm256_extractf128_ps(msum1, 0);
380
+ msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));
412
381
 
413
382
  if (d >= 4) {
414
- __m128 mx = _mm_loadu_ps (x); x += 4;
415
- __m128 my = _mm_loadu_ps (y); y += 4;
416
- msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my));
383
+ __m128 mx = _mm_loadu_ps(x);
384
+ x += 4;
385
+ __m128 my = _mm_loadu_ps(y);
386
+ y += 4;
387
+ msum2 = _mm_add_ps(msum2, _mm_mul_ps(mx, my));
417
388
  d -= 4;
418
389
  }
419
390
 
420
391
  if (d > 0) {
421
- __m128 mx = masked_read (d, x);
422
- __m128 my = masked_read (d, y);
423
- msum2 = _mm_add_ps (msum2, _mm_mul_ps (mx, my));
392
+ __m128 mx = masked_read(d, x);
393
+ __m128 my = masked_read(d, y);
394
+ msum2 = _mm_add_ps(msum2, _mm_mul_ps(mx, my));
424
395
  }
425
396
 
426
- msum2 = _mm_hadd_ps (msum2, msum2);
427
- msum2 = _mm_hadd_ps (msum2, msum2);
428
- return _mm_cvtss_f32 (msum2);
397
+ msum2 = _mm_hadd_ps(msum2, msum2);
398
+ msum2 = _mm_hadd_ps(msum2, msum2);
399
+ return _mm_cvtss_f32(msum2);
429
400
  }
430
401
 
431
- float fvec_L2sqr (const float * x,
432
- const float * y,
433
- size_t d)
434
- {
402
+ float fvec_L2sqr(const float* x, const float* y, size_t d) {
435
403
  __m256 msum1 = _mm256_setzero_ps();
436
404
 
437
405
  while (d >= 8) {
438
- __m256 mx = _mm256_loadu_ps (x); x += 8;
439
- __m256 my = _mm256_loadu_ps (y); y += 8;
440
- const __m256 a_m_b1 = mx - my;
441
- msum1 += a_m_b1 * a_m_b1;
406
+ __m256 mx = _mm256_loadu_ps(x);
407
+ x += 8;
408
+ __m256 my = _mm256_loadu_ps(y);
409
+ y += 8;
410
+ const __m256 a_m_b1 = _mm256_sub_ps(mx, my);
411
+ msum1 = _mm256_add_ps(msum1, _mm256_mul_ps(a_m_b1, a_m_b1));
442
412
  d -= 8;
443
413
  }
444
414
 
445
415
  __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
446
- msum2 += _mm256_extractf128_ps(msum1, 0);
416
+ msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));
447
417
 
448
418
  if (d >= 4) {
449
- __m128 mx = _mm_loadu_ps (x); x += 4;
450
- __m128 my = _mm_loadu_ps (y); y += 4;
451
- const __m128 a_m_b1 = mx - my;
452
- msum2 += a_m_b1 * a_m_b1;
419
+ __m128 mx = _mm_loadu_ps(x);
420
+ x += 4;
421
+ __m128 my = _mm_loadu_ps(y);
422
+ y += 4;
423
+ const __m128 a_m_b1 = _mm_sub_ps(mx, my);
424
+ msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1));
453
425
  d -= 4;
454
426
  }
455
427
 
456
428
  if (d > 0) {
457
- __m128 mx = masked_read (d, x);
458
- __m128 my = masked_read (d, y);
459
- __m128 a_m_b1 = mx - my;
460
- msum2 += a_m_b1 * a_m_b1;
429
+ __m128 mx = masked_read(d, x);
430
+ __m128 my = masked_read(d, y);
431
+ __m128 a_m_b1 = _mm_sub_ps(mx, my);
432
+ msum2 = _mm_add_ps(msum2, _mm_mul_ps(a_m_b1, a_m_b1));
461
433
  }
462
434
 
463
- msum2 = _mm_hadd_ps (msum2, msum2);
464
- msum2 = _mm_hadd_ps (msum2, msum2);
465
- return _mm_cvtss_f32 (msum2);
435
+ msum2 = _mm_hadd_ps(msum2, msum2);
436
+ msum2 = _mm_hadd_ps(msum2, msum2);
437
+ return _mm_cvtss_f32(msum2);
466
438
  }
467
439
 
468
- float fvec_L1 (const float * x, const float * y, size_t d)
469
- {
440
+ float fvec_L1(const float* x, const float* y, size_t d) {
470
441
  __m256 msum1 = _mm256_setzero_ps();
471
- __m256 signmask = __m256(_mm256_set1_epi32 (0x7fffffffUL));
442
+ __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL));
472
443
 
473
444
  while (d >= 8) {
474
- __m256 mx = _mm256_loadu_ps (x); x += 8;
475
- __m256 my = _mm256_loadu_ps (y); y += 8;
476
- const __m256 a_m_b = mx - my;
477
- msum1 += _mm256_and_ps(signmask, a_m_b);
445
+ __m256 mx = _mm256_loadu_ps(x);
446
+ x += 8;
447
+ __m256 my = _mm256_loadu_ps(y);
448
+ y += 8;
449
+ const __m256 a_m_b = _mm256_sub_ps(mx, my);
450
+ msum1 = _mm256_add_ps(msum1, _mm256_and_ps(signmask, a_m_b));
478
451
  d -= 8;
479
452
  }
480
453
 
481
454
  __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
482
- msum2 += _mm256_extractf128_ps(msum1, 0);
483
- __m128 signmask2 = __m128(_mm_set1_epi32 (0x7fffffffUL));
455
+ msum2 = _mm_add_ps(msum2, _mm256_extractf128_ps(msum1, 0));
456
+ __m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL));
484
457
 
485
458
  if (d >= 4) {
486
- __m128 mx = _mm_loadu_ps (x); x += 4;
487
- __m128 my = _mm_loadu_ps (y); y += 4;
488
- const __m128 a_m_b = mx - my;
489
- msum2 += _mm_and_ps(signmask2, a_m_b);
459
+ __m128 mx = _mm_loadu_ps(x);
460
+ x += 4;
461
+ __m128 my = _mm_loadu_ps(y);
462
+ y += 4;
463
+ const __m128 a_m_b = _mm_sub_ps(mx, my);
464
+ msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b));
490
465
  d -= 4;
491
466
  }
492
467
 
493
468
  if (d > 0) {
494
- __m128 mx = masked_read (d, x);
495
- __m128 my = masked_read (d, y);
496
- __m128 a_m_b = mx - my;
497
- msum2 += _mm_and_ps(signmask2, a_m_b);
469
+ __m128 mx = masked_read(d, x);
470
+ __m128 my = masked_read(d, y);
471
+ __m128 a_m_b = _mm_sub_ps(mx, my);
472
+ msum2 = _mm_add_ps(msum2, _mm_and_ps(signmask2, a_m_b));
498
473
  }
499
474
 
500
- msum2 = _mm_hadd_ps (msum2, msum2);
501
- msum2 = _mm_hadd_ps (msum2, msum2);
502
- return _mm_cvtss_f32 (msum2);
475
+ msum2 = _mm_hadd_ps(msum2, msum2);
476
+ msum2 = _mm_hadd_ps(msum2, msum2);
477
+ return _mm_cvtss_f32(msum2);
503
478
  }
504
479
 
505
- float fvec_Linf (const float * x, const float * y, size_t d)
506
- {
480
+ float fvec_Linf(const float* x, const float* y, size_t d) {
507
481
  __m256 msum1 = _mm256_setzero_ps();
508
- __m256 signmask = __m256(_mm256_set1_epi32 (0x7fffffffUL));
482
+ __m256 signmask = _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffffUL));
509
483
 
510
484
  while (d >= 8) {
511
- __m256 mx = _mm256_loadu_ps (x); x += 8;
512
- __m256 my = _mm256_loadu_ps (y); y += 8;
513
- const __m256 a_m_b = mx - my;
485
+ __m256 mx = _mm256_loadu_ps(x);
486
+ x += 8;
487
+ __m256 my = _mm256_loadu_ps(y);
488
+ y += 8;
489
+ const __m256 a_m_b = _mm256_sub_ps(mx, my);
514
490
  msum1 = _mm256_max_ps(msum1, _mm256_and_ps(signmask, a_m_b));
515
491
  d -= 8;
516
492
  }
517
493
 
518
494
  __m128 msum2 = _mm256_extractf128_ps(msum1, 1);
519
- msum2 = _mm_max_ps (msum2, _mm256_extractf128_ps(msum1, 0));
520
- __m128 signmask2 = __m128(_mm_set1_epi32 (0x7fffffffUL));
495
+ msum2 = _mm_max_ps(msum2, _mm256_extractf128_ps(msum1, 0));
496
+ __m128 signmask2 = _mm_castsi128_ps(_mm_set1_epi32(0x7fffffffUL));
521
497
 
522
498
  if (d >= 4) {
523
- __m128 mx = _mm_loadu_ps (x); x += 4;
524
- __m128 my = _mm_loadu_ps (y); y += 4;
525
- const __m128 a_m_b = mx - my;
499
+ __m128 mx = _mm_loadu_ps(x);
500
+ x += 4;
501
+ __m128 my = _mm_loadu_ps(y);
502
+ y += 4;
503
+ const __m128 a_m_b = _mm_sub_ps(mx, my);
526
504
  msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b));
527
505
  d -= 4;
528
506
  }
529
507
 
530
508
  if (d > 0) {
531
- __m128 mx = masked_read (d, x);
532
- __m128 my = masked_read (d, y);
533
- __m128 a_m_b = mx - my;
509
+ __m128 mx = masked_read(d, x);
510
+ __m128 my = masked_read(d, y);
511
+ __m128 a_m_b = _mm_sub_ps(mx, my);
534
512
  msum2 = _mm_max_ps(msum2, _mm_and_ps(signmask2, a_m_b));
535
513
  }
536
514
 
537
515
  msum2 = _mm_max_ps(_mm_movehl_ps(msum2, msum2), msum2);
538
- msum2 = _mm_max_ps(msum2, _mm_shuffle_ps (msum2, msum2, 1));
539
- return _mm_cvtss_f32 (msum2);
516
+ msum2 = _mm_max_ps(msum2, _mm_shuffle_ps(msum2, msum2, 1));
517
+ return _mm_cvtss_f32(msum2);
540
518
  }
541
519
 
542
520
  #elif defined(__SSE3__) // But not AVX
543
521
 
544
- float fvec_L1 (const float * x, const float * y, size_t d)
545
- {
546
- return fvec_L1_ref (x, y, d);
522
+ float fvec_L1(const float* x, const float* y, size_t d) {
523
+ return fvec_L1_ref(x, y, d);
547
524
  }
548
525
 
549
- float fvec_Linf (const float * x, const float * y, size_t d)
550
- {
551
- return fvec_Linf_ref (x, y, d);
526
+ float fvec_Linf(const float* x, const float* y, size_t d) {
527
+ return fvec_Linf_ref(x, y, d);
552
528
  }
553
529
 
554
-
555
- float fvec_L2sqr (const float * x,
556
- const float * y,
557
- size_t d)
558
- {
530
+ float fvec_L2sqr(const float* x, const float* y, size_t d) {
559
531
  __m128 msum1 = _mm_setzero_ps();
560
532
 
561
533
  while (d >= 4) {
562
- __m128 mx = _mm_loadu_ps (x); x += 4;
563
- __m128 my = _mm_loadu_ps (y); y += 4;
564
- const __m128 a_m_b1 = mx - my;
565
- msum1 += a_m_b1 * a_m_b1;
534
+ __m128 mx = _mm_loadu_ps(x);
535
+ x += 4;
536
+ __m128 my = _mm_loadu_ps(y);
537
+ y += 4;
538
+ const __m128 a_m_b1 = _mm_sub_ps(mx, my);
539
+ msum1 = _mm_add_ps(msum1, _mm_mul_ps(a_m_b1, a_m_b1));
566
540
  d -= 4;
567
541
  }
568
542
 
569
543
  if (d > 0) {
570
544
  // add the last 1, 2 or 3 values
571
- __m128 mx = masked_read (d, x);
572
- __m128 my = masked_read (d, y);
573
- __m128 a_m_b1 = mx - my;
574
- msum1 += a_m_b1 * a_m_b1;
545
+ __m128 mx = masked_read(d, x);
546
+ __m128 my = masked_read(d, y);
547
+ __m128 a_m_b1 = _mm_sub_ps(mx, my);
548
+ msum1 = _mm_add_ps(msum1, _mm_mul_ps(a_m_b1, a_m_b1));
575
549
  }
576
550
 
577
- msum1 = _mm_hadd_ps (msum1, msum1);
578
- msum1 = _mm_hadd_ps (msum1, msum1);
579
- return _mm_cvtss_f32 (msum1);
551
+ msum1 = _mm_hadd_ps(msum1, msum1);
552
+ msum1 = _mm_hadd_ps(msum1, msum1);
553
+ return _mm_cvtss_f32(msum1);
580
554
  }
581
555
 
582
-
583
- float fvec_inner_product (const float * x,
584
- const float * y,
585
- size_t d)
586
- {
556
+ float fvec_inner_product(const float* x, const float* y, size_t d) {
587
557
  __m128 mx, my;
588
558
  __m128 msum1 = _mm_setzero_ps();
589
559
 
590
560
  while (d >= 4) {
591
- mx = _mm_loadu_ps (x); x += 4;
592
- my = _mm_loadu_ps (y); y += 4;
593
- msum1 = _mm_add_ps (msum1, _mm_mul_ps (mx, my));
561
+ mx = _mm_loadu_ps(x);
562
+ x += 4;
563
+ my = _mm_loadu_ps(y);
564
+ y += 4;
565
+ msum1 = _mm_add_ps(msum1, _mm_mul_ps(mx, my));
594
566
  d -= 4;
595
567
  }
596
568
 
597
569
  // add the last 1, 2, or 3 values
598
- mx = masked_read (d, x);
599
- my = masked_read (d, y);
600
- __m128 prod = _mm_mul_ps (mx, my);
570
+ mx = masked_read(d, x);
571
+ my = masked_read(d, y);
572
+ __m128 prod = _mm_mul_ps(mx, my);
601
573
 
602
- msum1 = _mm_add_ps (msum1, prod);
574
+ msum1 = _mm_add_ps(msum1, prod);
603
575
 
604
- msum1 = _mm_hadd_ps (msum1, msum1);
605
- msum1 = _mm_hadd_ps (msum1, msum1);
606
- return _mm_cvtss_f32 (msum1);
576
+ msum1 = _mm_hadd_ps(msum1, msum1);
577
+ msum1 = _mm_hadd_ps(msum1, msum1);
578
+ return _mm_cvtss_f32(msum1);
607
579
  }
608
580
 
609
581
  #elif defined(__aarch64__)
610
582
 
611
-
612
- float fvec_L2sqr (const float * x,
613
- const float * y,
614
- size_t d)
615
- {
616
- if (d & 3) return fvec_L2sqr_ref (x, y, d);
617
- float32x4_t accu = vdupq_n_f32 (0);
618
- for (size_t i = 0; i < d; i += 4) {
619
- float32x4_t xi = vld1q_f32 (x + i);
620
- float32x4_t yi = vld1q_f32 (y + i);
621
- float32x4_t sq = vsubq_f32 (xi, yi);
622
- accu = vfmaq_f32 (accu, sq, sq);
583
+ float fvec_L2sqr(const float* x, const float* y, size_t d) {
584
+ float32x4_t accux4 = vdupq_n_f32(0);
585
+ const size_t d_simd = d - (d & 3);
586
+ size_t i;
587
+ for (i = 0; i < d_simd; i += 4) {
588
+ float32x4_t xi = vld1q_f32(x + i);
589
+ float32x4_t yi = vld1q_f32(y + i);
590
+ float32x4_t sq = vsubq_f32(xi, yi);
591
+ accux4 = vfmaq_f32(accux4, sq, sq);
592
+ }
593
+ float32x4_t accux2 = vpaddq_f32(accux4, accux4);
594
+ float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
595
+ for (; i < d; ++i) {
596
+ float32_t xi = x[i];
597
+ float32_t yi = y[i];
598
+ float32_t sq = xi - yi;
599
+ accux1 += sq * sq;
623
600
  }
624
- float32x4_t a2 = vpaddq_f32 (accu, accu);
625
- return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
601
+ return accux1;
626
602
  }
627
603
 
628
- float fvec_inner_product (const float * x,
629
- const float * y,
630
- size_t d)
631
- {
632
- if (d & 3) return fvec_inner_product_ref (x, y, d);
633
- float32x4_t accu = vdupq_n_f32 (0);
634
- for (size_t i = 0; i < d; i += 4) {
635
- float32x4_t xi = vld1q_f32 (x + i);
636
- float32x4_t yi = vld1q_f32 (y + i);
637
- accu = vfmaq_f32 (accu, xi, yi);
604
+ float fvec_inner_product(const float* x, const float* y, size_t d) {
605
+ float32x4_t accux4 = vdupq_n_f32(0);
606
+ const size_t d_simd = d - (d & 3);
607
+ size_t i;
608
+ for (i = 0; i < d_simd; i += 4) {
609
+ float32x4_t xi = vld1q_f32(x + i);
610
+ float32x4_t yi = vld1q_f32(y + i);
611
+ accux4 = vfmaq_f32(accux4, xi, yi);
638
612
  }
639
- float32x4_t a2 = vpaddq_f32 (accu, accu);
640
- return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
613
+ float32x4_t accux2 = vpaddq_f32(accux4, accux4);
614
+ float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
615
+ for (; i < d; ++i) {
616
+ float32_t xi = x[i];
617
+ float32_t yi = y[i];
618
+ accux1 += xi * yi;
619
+ }
620
+ return accux1;
641
621
  }
642
622
 
643
- float fvec_norm_L2sqr (const float *x, size_t d)
644
- {
645
- if (d & 3) return fvec_norm_L2sqr_ref (x, d);
646
- float32x4_t accu = vdupq_n_f32 (0);
647
- for (size_t i = 0; i < d; i += 4) {
648
- float32x4_t xi = vld1q_f32 (x + i);
649
- accu = vfmaq_f32 (accu, xi, xi);
623
+ float fvec_norm_L2sqr(const float* x, size_t d) {
624
+ float32x4_t accux4 = vdupq_n_f32(0);
625
+ const size_t d_simd = d - (d & 3);
626
+ size_t i;
627
+ for (i = 0; i < d_simd; i += 4) {
628
+ float32x4_t xi = vld1q_f32(x + i);
629
+ accux4 = vfmaq_f32(accux4, xi, xi);
630
+ }
631
+ float32x4_t accux2 = vpaddq_f32(accux4, accux4);
632
+ float32_t accux1 = vdups_laneq_f32(accux2, 0) + vdups_laneq_f32(accux2, 1);
633
+ for (; i < d; ++i) {
634
+ float32_t xi = x[i];
635
+ accux1 += xi * xi;
650
636
  }
651
- float32x4_t a2 = vpaddq_f32 (accu, accu);
652
- return vdups_laneq_f32 (a2, 0) + vdups_laneq_f32 (a2, 1);
637
+ return accux1;
653
638
  }
654
639
 
655
640
  // not optimized for ARM
656
- void fvec_L2sqr_ny (float * dis, const float * x,
657
- const float * y, size_t d, size_t ny) {
658
- fvec_L2sqr_ny_ref (dis, x, y, d, ny);
641
+ void fvec_L2sqr_ny(
642
+ float* dis,
643
+ const float* x,
644
+ const float* y,
645
+ size_t d,
646
+ size_t ny) {
647
+ fvec_L2sqr_ny_ref(dis, x, y, d, ny);
659
648
  }
660
649
 
661
- float fvec_L1 (const float * x, const float * y, size_t d)
662
- {
663
- return fvec_L1_ref (x, y, d);
650
+ float fvec_L1(const float* x, const float* y, size_t d) {
651
+ return fvec_L1_ref(x, y, d);
664
652
  }
665
653
 
666
- float fvec_Linf (const float * x, const float * y, size_t d)
667
- {
668
- return fvec_Linf_ref (x, y, d);
654
+ float fvec_Linf(const float* x, const float* y, size_t d) {
655
+ return fvec_Linf_ref(x, y, d);
669
656
  }
670
657
 
658
+ void fvec_inner_products_ny(
659
+ float* dis,
660
+ const float* x,
661
+ const float* y,
662
+ size_t d,
663
+ size_t ny) {
664
+ fvec_inner_products_ny_ref(dis, x, y, d, ny);
665
+ }
671
666
 
672
667
  #else
673
668
  // scalar implementation
674
669
 
675
- float fvec_L2sqr (const float * x,
676
- const float * y,
677
- size_t d)
678
- {
679
- return fvec_L2sqr_ref (x, y, d);
670
+ float fvec_L2sqr(const float* x, const float* y, size_t d) {
671
+ return fvec_L2sqr_ref(x, y, d);
680
672
  }
681
673
 
682
- float fvec_L1 (const float * x, const float * y, size_t d)
683
- {
684
- return fvec_L1_ref (x, y, d);
674
+ float fvec_L1(const float* x, const float* y, size_t d) {
675
+ return fvec_L1_ref(x, y, d);
685
676
  }
686
677
 
687
- float fvec_Linf (const float * x, const float * y, size_t d)
688
- {
689
- return fvec_Linf_ref (x, y, d);
678
+ float fvec_Linf(const float* x, const float* y, size_t d) {
679
+ return fvec_Linf_ref(x, y, d);
690
680
  }
691
681
 
692
- float fvec_inner_product (const float * x,
693
- const float * y,
694
- size_t d)
695
- {
696
- return fvec_inner_product_ref (x, y, d);
682
+ float fvec_inner_product(const float* x, const float* y, size_t d) {
683
+ return fvec_inner_product_ref(x, y, d);
697
684
  }
698
685
 
699
- float fvec_norm_L2sqr (const float *x, size_t d)
700
- {
701
- return fvec_norm_L2sqr_ref (x, d);
686
+ float fvec_norm_L2sqr(const float* x, size_t d) {
687
+ return fvec_norm_L2sqr_ref(x, d);
702
688
  }
703
689
 
704
- void fvec_L2sqr_ny (float * dis, const float * x,
705
- const float * y, size_t d, size_t ny) {
706
- fvec_L2sqr_ny_ref (dis, x, y, d, ny);
690
+ void fvec_L2sqr_ny(
691
+ float* dis,
692
+ const float* x,
693
+ const float* y,
694
+ size_t d,
695
+ size_t ny) {
696
+ fvec_L2sqr_ny_ref(dis, x, y, d, ny);
707
697
  }
708
698
 
709
- void fvec_inner_products_ny (float * dis, const float * x,
710
- const float * y, size_t d, size_t ny) {
711
- fvec_inner_products_ny_ref (dis, x, y, d, ny);
699
+ void fvec_inner_products_ny(
700
+ float* dis,
701
+ const float* x,
702
+ const float* y,
703
+ size_t d,
704
+ size_t ny) {
705
+ fvec_inner_products_ny_ref(dis, x, y, d, ny);
712
706
  }
713
707
 
714
-
715
708
  #endif
716
709
 
717
-
718
-
719
-
720
-
721
-
722
-
723
-
724
-
725
-
726
-
727
-
728
-
729
-
730
-
731
-
732
-
733
-
734
-
735
-
736
710
  /***************************************************************************
737
711
  * heavily optimized table computations
738
712
  ***************************************************************************/
739
713
 
740
-
741
- static inline void fvec_madd_ref (size_t n, const float *a,
742
- float bf, const float *b, float *c) {
714
+ static inline void fvec_madd_ref(
715
+ size_t n,
716
+ const float* a,
717
+ float bf,
718
+ const float* b,
719
+ float* c) {
743
720
  for (size_t i = 0; i < n; i++)
744
721
  c[i] = a[i] + bf * b[i];
745
722
  }
746
723
 
747
724
  #ifdef __SSE3__
748
725
 
749
- static inline void fvec_madd_sse (size_t n, const float *a,
750
- float bf, const float *b, float *c) {
726
+ static inline void fvec_madd_sse(
727
+ size_t n,
728
+ const float* a,
729
+ float bf,
730
+ const float* b,
731
+ float* c) {
751
732
  n >>= 2;
752
- __m128 bf4 = _mm_set_ps1 (bf);
753
- __m128 * a4 = (__m128*)a;
754
- __m128 * b4 = (__m128*)b;
755
- __m128 * c4 = (__m128*)c;
733
+ __m128 bf4 = _mm_set_ps1(bf);
734
+ __m128* a4 = (__m128*)a;
735
+ __m128* b4 = (__m128*)b;
736
+ __m128* c4 = (__m128*)c;
756
737
 
757
738
  while (n--) {
758
- *c4 = _mm_add_ps (*a4, _mm_mul_ps (bf4, *b4));
739
+ *c4 = _mm_add_ps(*a4, _mm_mul_ps(bf4, *b4));
759
740
  b4++;
760
741
  a4++;
761
742
  c4++;
762
743
  }
763
744
  }
764
745
 
765
- void fvec_madd (size_t n, const float *a,
766
- float bf, const float *b, float *c)
767
- {
768
- if ((n & 3) == 0 &&
769
- ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
770
- fvec_madd_sse (n, a, bf, b, c);
746
+ void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
747
+ if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
748
+ fvec_madd_sse(n, a, bf, b, c);
771
749
  else
772
- fvec_madd_ref (n, a, bf, b, c);
750
+ fvec_madd_ref(n, a, bf, b, c);
773
751
  }
774
752
 
775
753
  #else
776
754
 
777
- void fvec_madd (size_t n, const float *a,
778
- float bf, const float *b, float *c)
779
- {
780
- fvec_madd_ref (n, a, bf, b, c);
755
+ void fvec_madd(size_t n, const float* a, float bf, const float* b, float* c) {
756
+ fvec_madd_ref(n, a, bf, b, c);
781
757
  }
782
758
 
783
759
  #endif
784
760
 
785
- static inline int fvec_madd_and_argmin_ref (size_t n, const float *a,
786
- float bf, const float *b, float *c) {
761
+ static inline int fvec_madd_and_argmin_ref(
762
+ size_t n,
763
+ const float* a,
764
+ float bf,
765
+ const float* b,
766
+ float* c) {
787
767
  float vmin = 1e20;
788
768
  int imin = -1;
789
769
 
@@ -799,125 +779,100 @@ static inline int fvec_madd_and_argmin_ref (size_t n, const float *a,
799
779
 
800
780
  #ifdef __SSE3__
801
781
 
802
- static inline int fvec_madd_and_argmin_sse (
803
- size_t n, const float *a,
804
- float bf, const float *b, float *c) {
782
+ static inline int fvec_madd_and_argmin_sse(
783
+ size_t n,
784
+ const float* a,
785
+ float bf,
786
+ const float* b,
787
+ float* c) {
805
788
  n >>= 2;
806
- __m128 bf4 = _mm_set_ps1 (bf);
807
- __m128 vmin4 = _mm_set_ps1 (1e20);
808
- __m128i imin4 = _mm_set1_epi32 (-1);
809
- __m128i idx4 = _mm_set_epi32 (3, 2, 1, 0);
810
- __m128i inc4 = _mm_set1_epi32 (4);
811
- __m128 * a4 = (__m128*)a;
812
- __m128 * b4 = (__m128*)b;
813
- __m128 * c4 = (__m128*)c;
789
+ __m128 bf4 = _mm_set_ps1(bf);
790
+ __m128 vmin4 = _mm_set_ps1(1e20);
791
+ __m128i imin4 = _mm_set1_epi32(-1);
792
+ __m128i idx4 = _mm_set_epi32(3, 2, 1, 0);
793
+ __m128i inc4 = _mm_set1_epi32(4);
794
+ __m128* a4 = (__m128*)a;
795
+ __m128* b4 = (__m128*)b;
796
+ __m128* c4 = (__m128*)c;
814
797
 
815
798
  while (n--) {
816
- __m128 vc4 = _mm_add_ps (*a4, _mm_mul_ps (bf4, *b4));
799
+ __m128 vc4 = _mm_add_ps(*a4, _mm_mul_ps(bf4, *b4));
817
800
  *c4 = vc4;
818
- __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
801
+ __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4));
819
802
  // imin4 = _mm_blendv_epi8 (imin4, idx4, mask); // slower!
820
803
 
821
- imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
822
- _mm_andnot_si128 (mask, imin4));
823
- vmin4 = _mm_min_ps (vmin4, vc4);
804
+ imin4 = _mm_or_si128(
805
+ _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4));
806
+ vmin4 = _mm_min_ps(vmin4, vc4);
824
807
  b4++;
825
808
  a4++;
826
809
  c4++;
827
- idx4 = _mm_add_epi32 (idx4, inc4);
810
+ idx4 = _mm_add_epi32(idx4, inc4);
828
811
  }
829
812
 
830
813
  // 4 values -> 2
831
814
  {
832
- idx4 = _mm_shuffle_epi32 (imin4, 3 << 2 | 2);
833
- __m128 vc4 = _mm_shuffle_ps (vmin4, vmin4, 3 << 2 | 2);
834
- __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
835
- imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
836
- _mm_andnot_si128 (mask, imin4));
837
- vmin4 = _mm_min_ps (vmin4, vc4);
815
+ idx4 = _mm_shuffle_epi32(imin4, 3 << 2 | 2);
816
+ __m128 vc4 = _mm_shuffle_ps(vmin4, vmin4, 3 << 2 | 2);
817
+ __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4));
818
+ imin4 = _mm_or_si128(
819
+ _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4));
820
+ vmin4 = _mm_min_ps(vmin4, vc4);
838
821
  }
839
822
  // 2 values -> 1
840
823
  {
841
- idx4 = _mm_shuffle_epi32 (imin4, 1);
842
- __m128 vc4 = _mm_shuffle_ps (vmin4, vmin4, 1);
843
- __m128i mask = (__m128i)_mm_cmpgt_ps (vmin4, vc4);
844
- imin4 = _mm_or_si128 (_mm_and_si128 (mask, idx4),
845
- _mm_andnot_si128 (mask, imin4));
824
+ idx4 = _mm_shuffle_epi32(imin4, 1);
825
+ __m128 vc4 = _mm_shuffle_ps(vmin4, vmin4, 1);
826
+ __m128i mask = _mm_castps_si128(_mm_cmpgt_ps(vmin4, vc4));
827
+ imin4 = _mm_or_si128(
828
+ _mm_and_si128(mask, idx4), _mm_andnot_si128(mask, imin4));
846
829
  // vmin4 = _mm_min_ps (vmin4, vc4);
847
830
  }
848
- return _mm_cvtsi128_si32 (imin4);
831
+ return _mm_cvtsi128_si32(imin4);
849
832
  }
850
833
 
851
-
852
- int fvec_madd_and_argmin (size_t n, const float *a,
853
- float bf, const float *b, float *c)
854
- {
855
- if ((n & 3) == 0 &&
856
- ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
857
- return fvec_madd_and_argmin_sse (n, a, bf, b, c);
834
+ int fvec_madd_and_argmin(
835
+ size_t n,
836
+ const float* a,
837
+ float bf,
838
+ const float* b,
839
+ float* c) {
840
+ if ((n & 3) == 0 && ((((long)a) | ((long)b) | ((long)c)) & 15) == 0)
841
+ return fvec_madd_and_argmin_sse(n, a, bf, b, c);
858
842
  else
859
- return fvec_madd_and_argmin_ref (n, a, bf, b, c);
843
+ return fvec_madd_and_argmin_ref(n, a, bf, b, c);
860
844
  }
861
845
 
862
846
  #else
863
847
 
864
- int fvec_madd_and_argmin (size_t n, const float *a,
865
- float bf, const float *b, float *c)
866
- {
867
- return fvec_madd_and_argmin_ref (n, a, bf, b, c);
848
+ int fvec_madd_and_argmin(
849
+ size_t n,
850
+ const float* a,
851
+ float bf,
852
+ const float* b,
853
+ float* c) {
854
+ return fvec_madd_and_argmin_ref(n, a, bf, b, c);
868
855
  }
869
856
 
870
857
  #endif
871
858
 
872
-
873
859
  /***************************************************************************
874
860
  * PQ tables computations
875
861
  ***************************************************************************/
876
862
 
877
- #ifdef __AVX2__
878
-
879
863
  namespace {
880
864
 
881
-
882
- // get even float32's of a and b, interleaved
883
- simd8float32 geteven(simd8float32 a, simd8float32 b) {
884
- return simd8float32(
885
- _mm256_shuffle_ps(a.f, b.f, 0 << 0 | 2 << 2 | 0 << 4 | 2 << 6)
886
- );
887
- }
888
-
889
- // get odd float32's of a and b, interleaved
890
- simd8float32 getodd(simd8float32 a, simd8float32 b) {
891
- return simd8float32(
892
- _mm256_shuffle_ps(a.f, b.f, 1 << 0 | 3 << 2 | 1 << 4 | 3 << 6)
893
- );
894
- }
895
-
896
- // 3 cycles
897
- // if the lanes are a = [a0 a1] and b = [b0 b1], return [a0 b0]
898
- simd8float32 getlow128(simd8float32 a, simd8float32 b) {
899
- return simd8float32(
900
- _mm256_permute2f128_ps(a.f, b.f, 0 | 2 << 4)
901
- );
902
- }
903
-
904
- simd8float32 gethigh128(simd8float32 a, simd8float32 b) {
905
- return simd8float32(
906
- _mm256_permute2f128_ps(a.f, b.f, 1 | 3 << 4)
907
- );
908
- }
909
-
910
865
  /// compute the IP for dsub = 2 for 8 centroids and 4 sub-vectors at a time
911
- template<bool is_inner_product>
866
+ template <bool is_inner_product>
912
867
  void pq2_8cents_table(
913
868
  const simd8float32 centroids[8],
914
869
  const simd8float32 x,
915
- float *out, size_t ldo, size_t nout = 4
916
- ) {
917
-
870
+ float* out,
871
+ size_t ldo,
872
+ size_t nout = 4) {
918
873
  simd8float32 ips[4];
919
874
 
920
- for(int i = 0; i < 4; i++) {
875
+ for (int i = 0; i < 4; i++) {
921
876
  simd8float32 p1, p2;
922
877
  if (is_inner_product) {
923
878
  p1 = x * centroids[2 * i];
@@ -941,21 +896,21 @@ void pq2_8cents_table(
941
896
  simd8float32 ip1 = getlow128(ip13a, ip13b);
942
897
  simd8float32 ip3 = gethigh128(ip13a, ip13b);
943
898
 
944
- switch(nout) {
945
- case 4:
946
- ip3.storeu(out + 3 * ldo);
947
- case 3:
948
- ip2.storeu(out + 2 * ldo);
949
- case 2:
950
- ip1.storeu(out + 1 * ldo);
951
- case 1:
952
- ip0.storeu(out);
899
+ switch (nout) {
900
+ case 4:
901
+ ip3.storeu(out + 3 * ldo);
902
+ case 3:
903
+ ip2.storeu(out + 2 * ldo);
904
+ case 2:
905
+ ip1.storeu(out + 1 * ldo);
906
+ case 1:
907
+ ip0.storeu(out);
953
908
  }
954
909
  }
955
910
 
956
- simd8float32 load_simd8float32_partial(const float *x, int n) {
911
+ simd8float32 load_simd8float32_partial(const float* x, int n) {
957
912
  ALIGNED(32) float tmp[8] = {0, 0, 0, 0, 0, 0, 0, 0};
958
- float *wp = tmp;
913
+ float* wp = tmp;
959
914
  for (int i = 0; i < n; i++) {
960
915
  *wp++ = *x++;
961
916
  }
@@ -964,25 +919,23 @@ simd8float32 load_simd8float32_partial(const float *x, int n) {
964
919
 
965
920
  } // anonymous namespace
966
921
 
967
-
968
-
969
-
970
922
  void compute_PQ_dis_tables_dsub2(
971
- size_t d, size_t ksub, const float *all_centroids,
972
- size_t nx, const float * x,
923
+ size_t d,
924
+ size_t ksub,
925
+ const float* all_centroids,
926
+ size_t nx,
927
+ const float* x,
973
928
  bool is_inner_product,
974
- float * dis_tables)
975
- {
929
+ float* dis_tables) {
976
930
  size_t M = d / 2;
977
931
  FAISS_THROW_IF_NOT(ksub % 8 == 0);
978
932
 
979
- for(size_t m0 = 0; m0 < M; m0 += 4) {
933
+ for (size_t m0 = 0; m0 < M; m0 += 4) {
980
934
  int m1 = std::min(M, m0 + 4);
981
- for(int k0 = 0; k0 < ksub; k0 += 8) {
982
-
935
+ for (int k0 = 0; k0 < ksub; k0 += 8) {
983
936
  simd8float32 centroids[8];
984
937
  for (int k = 0; k < 8; k++) {
985
- float centroid[8] __attribute__((aligned(32)));
938
+ ALIGNED(32) float centroid[8];
986
939
  size_t wp = 0;
987
940
  size_t rp = (m0 * ksub + k + k0) * 2;
988
941
  for (int m = m0; m < m1; m++) {
@@ -992,45 +945,82 @@ void compute_PQ_dis_tables_dsub2(
992
945
  }
993
946
  centroids[k] = simd8float32(centroid);
994
947
  }
995
- for(size_t i = 0; i < nx; i++) {
948
+ for (size_t i = 0; i < nx; i++) {
996
949
  simd8float32 xi;
997
950
  if (m1 == m0 + 4) {
998
951
  xi.loadu(x + i * d + m0 * 2);
999
952
  } else {
1000
- xi = load_simd8float32_partial(x + i * d + m0 * 2, 2 * (m1 - m0));
953
+ xi = load_simd8float32_partial(
954
+ x + i * d + m0 * 2, 2 * (m1 - m0));
1001
955
  }
1002
956
 
1003
- if(is_inner_product) {
957
+ if (is_inner_product) {
1004
958
  pq2_8cents_table<true>(
1005
- centroids, xi,
1006
- dis_tables + (i * M + m0) * ksub + k0,
1007
- ksub, m1 - m0
1008
- );
959
+ centroids,
960
+ xi,
961
+ dis_tables + (i * M + m0) * ksub + k0,
962
+ ksub,
963
+ m1 - m0);
1009
964
  } else {
1010
965
  pq2_8cents_table<false>(
1011
- centroids, xi,
1012
- dis_tables + (i * M + m0) * ksub + k0,
1013
- ksub, m1 - m0
1014
- );
966
+ centroids,
967
+ xi,
968
+ dis_tables + (i * M + m0) * ksub + k0,
969
+ ksub,
970
+ m1 - m0);
1015
971
  }
1016
972
  }
1017
973
  }
1018
974
  }
1019
-
1020
975
  }
1021
976
 
1022
- #else
977
+ /*********************************************************
978
+ * Vector to vector functions
979
+ *********************************************************/
1023
980
 
1024
- void compute_PQ_dis_tables_dsub2(
1025
- size_t d, size_t ksub, const float *all_centroids,
1026
- size_t nx, const float * x,
1027
- bool is_inner_product,
1028
- float * dis_tables)
1029
- {
1030
- FAISS_THROW_MSG("only implemented for AVX2");
981
+ void fvec_sub(size_t d, const float* a, const float* b, float* c) {
982
+ size_t i;
983
+ for (i = 0; i + 7 < d; i += 8) {
984
+ simd8float32 ci, ai, bi;
985
+ ai.loadu(a + i);
986
+ bi.loadu(b + i);
987
+ ci = ai - bi;
988
+ ci.storeu(c + i);
989
+ }
990
+ // finish non-multiple of 8 remainder
991
+ for (; i < d; i++) {
992
+ c[i] = a[i] - b[i];
993
+ }
1031
994
  }
1032
995
 
1033
- #endif
996
+ void fvec_add(size_t d, const float* a, const float* b, float* c) {
997
+ size_t i;
998
+ for (i = 0; i + 7 < d; i += 8) {
999
+ simd8float32 ci, ai, bi;
1000
+ ai.loadu(a + i);
1001
+ bi.loadu(b + i);
1002
+ ci = ai + bi;
1003
+ ci.storeu(c + i);
1004
+ }
1005
+ // finish non-multiple of 8 remainder
1006
+ for (; i < d; i++) {
1007
+ c[i] = a[i] + b[i];
1008
+ }
1009
+ }
1034
1010
 
1011
+ void fvec_add(size_t d, const float* a, float b, float* c) {
1012
+ size_t i;
1013
+ simd8float32 bv(b);
1014
+ for (i = 0; i + 7 < d; i += 8) {
1015
+ simd8float32 ci, ai, bi;
1016
+ ai.loadu(a + i);
1017
+ ci = ai + bv;
1018
+ ci.storeu(c + i);
1019
+ }
1020
+ // finish non-multiple of 8 remainder
1021
+ for (; i < d; i++) {
1022
+ c[i] = a[i] + b;
1023
+ }
1024
+ }
1035
1025
 
1036
1026
  } // namespace faiss