faiss 0.1.3 → 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (199) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +25 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +16 -4
  5. data/ext/faiss/ext.cpp +12 -308
  6. data/ext/faiss/extconf.rb +6 -3
  7. data/ext/faiss/index.cpp +189 -0
  8. data/ext/faiss/index_binary.cpp +75 -0
  9. data/ext/faiss/kmeans.cpp +40 -0
  10. data/ext/faiss/numo.hpp +867 -0
  11. data/ext/faiss/pca_matrix.cpp +33 -0
  12. data/ext/faiss/product_quantizer.cpp +53 -0
  13. data/ext/faiss/utils.cpp +13 -0
  14. data/ext/faiss/utils.h +5 -0
  15. data/lib/faiss.rb +0 -5
  16. data/lib/faiss/version.rb +1 -1
  17. data/vendor/faiss/faiss/AutoTune.cpp +36 -33
  18. data/vendor/faiss/faiss/AutoTune.h +6 -3
  19. data/vendor/faiss/faiss/Clustering.cpp +16 -12
  20. data/vendor/faiss/faiss/Index.cpp +3 -4
  21. data/vendor/faiss/faiss/Index.h +3 -3
  22. data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
  23. data/vendor/faiss/faiss/IndexBinary.h +1 -1
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
  26. data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
  27. data/vendor/faiss/faiss/IndexFlat.h +0 -51
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
  29. data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
  30. data/vendor/faiss/faiss/IndexIVF.h +22 -15
  31. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
  32. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
  33. data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
  34. data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
  35. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
  37. data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
  38. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
  39. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
  41. data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
  42. data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
  43. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
  44. data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
  45. data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
  46. data/vendor/faiss/faiss/IndexRefine.h +73 -0
  47. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
  48. data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
  49. data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
  50. data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
  51. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
  52. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
  53. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
  54. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
  55. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
  56. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
  57. data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
  58. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
  59. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
  60. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
  61. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
  62. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
  63. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
  64. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
  65. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
  66. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
  67. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
  68. data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
  69. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
  70. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
  71. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
  72. data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
  73. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
  74. data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
  75. data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
  76. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
  77. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
  78. data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
  79. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
  80. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
  81. data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
  82. data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
  83. data/vendor/faiss/faiss/impl/io.cpp +33 -2
  84. data/vendor/faiss/faiss/impl/io.h +7 -2
  85. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
  86. data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
  87. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
  88. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
  89. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
  90. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
  91. data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
  92. data/vendor/faiss/faiss/index_factory.cpp +112 -7
  93. data/vendor/faiss/faiss/index_io.h +1 -48
  94. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
  95. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
  96. data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
  97. data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
  98. data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
  99. data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
  100. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
  101. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
  102. data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
  103. data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
  104. data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
  105. data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
  106. data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
  107. data/vendor/faiss/faiss/utils/Heap.h +61 -50
  108. data/vendor/faiss/faiss/utils/distances.cpp +164 -319
  109. data/vendor/faiss/faiss/utils/distances.h +28 -20
  110. data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
  111. data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
  112. data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
  113. data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
  114. data/vendor/faiss/faiss/utils/hamming.h +2 -7
  115. data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
  116. data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
  117. data/vendor/faiss/faiss/utils/partitioning.h +69 -0
  118. data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
  119. data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
  120. data/vendor/faiss/faiss/utils/simdlib.h +31 -0
  121. data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
  122. data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
  123. metadata +54 -149
  124. data/lib/faiss/index.rb +0 -20
  125. data/lib/faiss/index_binary.rb +0 -20
  126. data/lib/faiss/kmeans.rb +0 -15
  127. data/lib/faiss/pca_matrix.rb +0 -15
  128. data/lib/faiss/product_quantizer.rb +0 -22
  129. data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
  130. data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
  131. data/vendor/faiss/c_api/AutoTune_c.h +0 -66
  132. data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
  133. data/vendor/faiss/c_api/Clustering_c.h +0 -123
  134. data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
  135. data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
  136. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
  137. data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
  138. data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
  139. data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
  140. data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
  141. data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
  142. data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
  143. data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
  144. data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
  145. data/vendor/faiss/c_api/IndexShards_c.h +0 -39
  146. data/vendor/faiss/c_api/Index_c.cpp +0 -105
  147. data/vendor/faiss/c_api/Index_c.h +0 -183
  148. data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
  149. data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
  150. data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
  151. data/vendor/faiss/c_api/clone_index_c.h +0 -32
  152. data/vendor/faiss/c_api/error_c.h +0 -42
  153. data/vendor/faiss/c_api/error_impl.cpp +0 -27
  154. data/vendor/faiss/c_api/error_impl.h +0 -16
  155. data/vendor/faiss/c_api/faiss_c.h +0 -58
  156. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
  157. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
  158. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
  159. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
  160. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
  161. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
  162. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
  163. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
  164. data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
  165. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
  166. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
  167. data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
  168. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
  169. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
  170. data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
  171. data/vendor/faiss/c_api/index_factory_c.h +0 -30
  172. data/vendor/faiss/c_api/index_io_c.cpp +0 -42
  173. data/vendor/faiss/c_api/index_io_c.h +0 -50
  174. data/vendor/faiss/c_api/macros_impl.h +0 -110
  175. data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
  176. data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
  177. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
  178. data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
  179. data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
  180. data/vendor/faiss/misc/test_blas.cpp +0 -87
  181. data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
  182. data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
  183. data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
  184. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
  185. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
  186. data/vendor/faiss/tests/test_merge.cpp +0 -260
  187. data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
  188. data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
  189. data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
  190. data/vendor/faiss/tests/test_params_override.cpp +0 -236
  191. data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
  192. data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
  193. data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
  194. data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
  195. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
  196. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
  197. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
  198. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
  199. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -156,6 +156,14 @@ void pairwise_indexed_inner_product (
156
156
  // threshold on nx above which we switch to BLAS to compute distances
157
157
  FAISS_API extern int distance_compute_blas_threshold;
158
158
 
159
+ // block sizes for BLAS distance computations
160
+ FAISS_API extern int distance_compute_blas_query_bs;
161
+ FAISS_API extern int distance_compute_blas_database_bs;
162
+
163
+ // above this number of results we switch to a reservoir to collect results
164
+ // rather than a heap
165
+ FAISS_API extern int distance_compute_min_k_reservoir;
166
+
159
167
  /** Return the k nearest neighors of each of the nx vectors x among the ny
160
168
  * vector y, w.r.t to max inner product
161
169
  *
@@ -169,27 +177,17 @@ void knn_inner_product (
169
177
  size_t d, size_t nx, size_t ny,
170
178
  float_minheap_array_t * res);
171
179
 
172
- /** Same as knn_inner_product, for the L2 distance */
180
+ /** Same as knn_inner_product, for the L2 distance
181
+ * @param y_norm2 norms for the y vectors (nullptr or size ny)
182
+ */
173
183
  void knn_L2sqr (
174
184
  const float * x,
175
185
  const float * y,
176
186
  size_t d, size_t nx, size_t ny,
177
- float_maxheap_array_t * res);
187
+ float_maxheap_array_t * res,
188
+ const float *y_norm2 = nullptr);
178
189
 
179
190
 
180
-
181
- /** same as knn_L2sqr, but base_shift[bno] is subtracted to all
182
- * computed distances.
183
- *
184
- * @param base_shift size ny
185
- */
186
- void knn_L2sqr_base_shift (
187
- const float * x,
188
- const float * y,
189
- size_t d, size_t nx, size_t ny,
190
- float_maxheap_array_t * res,
191
- const float *base_shift);
192
-
193
191
  /* Find the nearest neighbors for nx queries in a set of ny vectors
194
192
  * indexed by ids. May be useful for re-ranking a pre-selected vector list
195
193
  */
@@ -200,11 +198,12 @@ void knn_inner_products_by_idx (
200
198
  size_t d, size_t nx, size_t ny,
201
199
  float_minheap_array_t * res);
202
200
 
203
- void knn_L2sqr_by_idx (const float * x,
204
- const float * y,
205
- const int64_t * ids,
206
- size_t d, size_t nx, size_t ny,
207
- float_maxheap_array_t * res);
201
+ void knn_L2sqr_by_idx (
202
+ const float * x,
203
+ const float * y,
204
+ const int64_t * ids,
205
+ size_t d, size_t nx, size_t ny,
206
+ float_maxheap_array_t * res);
208
207
 
209
208
  /***************************************************************************
210
209
  * Range search
@@ -239,6 +238,15 @@ void range_search_inner_product (
239
238
  RangeSearchResult *result);
240
239
 
241
240
 
241
+ /***************************************************************************
242
+ * PQ tables computations
243
+ ***************************************************************************/
242
244
 
245
+ /// specialized function for PQ2
246
+ void compute_PQ_dis_tables_dsub2(
247
+ size_t d, size_t ksub, const float *centroids,
248
+ size_t nx, const float * x,
249
+ bool is_inner_product,
250
+ float * dis_tables);
243
251
 
244
252
  } // namespace faiss
@@ -14,6 +14,9 @@
14
14
  #include <cstring>
15
15
  #include <cmath>
16
16
 
17
+ #include <faiss/utils/simdlib.h>
18
+ #include <faiss/impl/FaissAssert.h>
19
+
17
20
  #ifdef __SSE3__
18
21
  #include <immintrin.h>
19
22
  #endif
@@ -127,6 +130,29 @@ void fvec_L2sqr_ny_ref (float * dis,
127
130
  }
128
131
 
129
132
 
133
+ void fvec_inner_products_ny_ref (float * ip,
134
+ const float * x,
135
+ const float * y,
136
+ size_t d, size_t ny)
137
+ {
138
+ // BLAS slower for the use cases here
139
+ #if 0
140
+ {
141
+ FINTEGER di = d;
142
+ FINTEGER nyi = ny;
143
+ float one = 1.0, zero = 0.0;
144
+ FINTEGER onei = 1;
145
+ sgemv_ ("T", &di, &nyi, &one, y, &di, x, &onei, &zero, ip, &onei);
146
+ }
147
+ #endif
148
+ for (size_t i = 0; i < ny; i++) {
149
+ ip[i] = fvec_inner_product (x, y, d);
150
+ y += d;
151
+ }
152
+ }
153
+
154
+
155
+
130
156
 
131
157
 
132
158
  /*********************************************************
@@ -174,12 +200,39 @@ float fvec_norm_L2sqr (const float * x,
174
200
 
175
201
  namespace {
176
202
 
177
- float sqr (float x) {
178
- return x * x;
179
- }
203
+ /// Function that does a component-wise operation between x and y
204
+ /// to compute L2 distances. ElementOp can then be used in the fvec_op_ny
205
+ /// functions below
206
+ struct ElementOpL2 {
207
+
208
+ static float op (float x, float y) {
209
+ float tmp = x - y;
210
+ return tmp * tmp;
211
+ }
212
+
213
+ static __m128 op (__m128 x, __m128 y) {
214
+ __m128 tmp = x - y;
215
+ return tmp * tmp;
216
+ }
217
+
218
+ };
180
219
 
220
+ /// Function that does a component-wise operation between x and y
221
+ /// to compute inner products
222
+ struct ElementOpIP {
181
223
 
182
- void fvec_L2sqr_ny_D1 (float * dis, const float * x,
224
+ static float op (float x, float y) {
225
+ return x * y;
226
+ }
227
+
228
+ static __m128 op (__m128 x, __m128 y) {
229
+ return x * y;
230
+ }
231
+
232
+ };
233
+
234
+ template<class ElementOp>
235
+ void fvec_op_ny_D1 (float * dis, const float * x,
183
236
  const float * y, size_t ny)
184
237
  {
185
238
  float x0s = x[0];
@@ -187,11 +240,9 @@ void fvec_L2sqr_ny_D1 (float * dis, const float * x,
187
240
 
188
241
  size_t i;
189
242
  for (i = 0; i + 3 < ny; i += 4) {
190
- __m128 tmp, accu;
191
- tmp = x0 - _mm_loadu_ps (y); y += 4;
192
- accu = tmp * tmp;
243
+ __m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
193
244
  dis[i] = _mm_cvtss_f32 (accu);
194
- tmp = _mm_shuffle_ps (accu, accu, 1);
245
+ __m128 tmp = _mm_shuffle_ps (accu, accu, 1);
195
246
  dis[i + 1] = _mm_cvtss_f32 (tmp);
196
247
  tmp = _mm_shuffle_ps (accu, accu, 2);
197
248
  dis[i + 2] = _mm_cvtss_f32 (tmp);
@@ -199,69 +250,63 @@ void fvec_L2sqr_ny_D1 (float * dis, const float * x,
199
250
  dis[i + 3] = _mm_cvtss_f32 (tmp);
200
251
  }
201
252
  while (i < ny) { // handle non-multiple-of-4 case
202
- dis[i++] = sqr(x0s - *y++);
253
+ dis[i++] = ElementOp::op(x0s, *y++);
203
254
  }
204
255
  }
205
256
 
206
-
207
- void fvec_L2sqr_ny_D2 (float * dis, const float * x,
257
+ template<class ElementOp>
258
+ void fvec_op_ny_D2 (float * dis, const float * x,
208
259
  const float * y, size_t ny)
209
260
  {
210
261
  __m128 x0 = _mm_set_ps (x[1], x[0], x[1], x[0]);
211
262
 
212
263
  size_t i;
213
264
  for (i = 0; i + 1 < ny; i += 2) {
214
- __m128 tmp, accu;
215
- tmp = x0 - _mm_loadu_ps (y); y += 4;
216
- accu = tmp * tmp;
265
+ __m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
217
266
  accu = _mm_hadd_ps (accu, accu);
218
267
  dis[i] = _mm_cvtss_f32 (accu);
219
268
  accu = _mm_shuffle_ps (accu, accu, 3);
220
269
  dis[i + 1] = _mm_cvtss_f32 (accu);
221
270
  }
222
271
  if (i < ny) { // handle odd case
223
- dis[i] = sqr(x[0] - y[0]) + sqr(x[1] - y[1]);
272
+ dis[i] = ElementOp::op(x[0], y[0]) + ElementOp::op(x[1], y[1]);
224
273
  }
225
274
  }
226
275
 
227
276
 
228
277
 
229
- void fvec_L2sqr_ny_D4 (float * dis, const float * x,
278
+ template<class ElementOp>
279
+ void fvec_op_ny_D4 (float * dis, const float * x,
230
280
  const float * y, size_t ny)
231
281
  {
232
282
  __m128 x0 = _mm_loadu_ps(x);
233
283
 
234
284
  for (size_t i = 0; i < ny; i++) {
235
- __m128 tmp, accu;
236
- tmp = x0 - _mm_loadu_ps (y); y += 4;
237
- accu = tmp * tmp;
285
+ __m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
238
286
  accu = _mm_hadd_ps (accu, accu);
239
287
  accu = _mm_hadd_ps (accu, accu);
240
288
  dis[i] = _mm_cvtss_f32 (accu);
241
289
  }
242
290
  }
243
291
 
244
-
245
- void fvec_L2sqr_ny_D8 (float * dis, const float * x,
292
+ template<class ElementOp>
293
+ void fvec_op_ny_D8 (float * dis, const float * x,
246
294
  const float * y, size_t ny)
247
295
  {
248
296
  __m128 x0 = _mm_loadu_ps(x);
249
297
  __m128 x1 = _mm_loadu_ps(x + 4);
250
298
 
251
299
  for (size_t i = 0; i < ny; i++) {
252
- __m128 tmp, accu;
253
- tmp = x0 - _mm_loadu_ps (y); y += 4;
254
- accu = tmp * tmp;
255
- tmp = x1 - _mm_loadu_ps (y); y += 4;
256
- accu += tmp * tmp;
300
+ __m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
301
+ accu += ElementOp::op(x1, _mm_loadu_ps (y)); y += 4;
257
302
  accu = _mm_hadd_ps (accu, accu);
258
303
  accu = _mm_hadd_ps (accu, accu);
259
304
  dis[i] = _mm_cvtss_f32 (accu);
260
305
  }
261
306
  }
262
307
 
263
-
264
- void fvec_L2sqr_ny_D12 (float * dis, const float * x,
308
+ template<class ElementOp>
309
+ void fvec_op_ny_D12 (float * dis, const float * x,
265
310
  const float * y, size_t ny)
266
311
  {
267
312
  __m128 x0 = _mm_loadu_ps(x);
@@ -269,13 +314,9 @@ void fvec_L2sqr_ny_D12 (float * dis, const float * x,
269
314
  __m128 x2 = _mm_loadu_ps(x + 8);
270
315
 
271
316
  for (size_t i = 0; i < ny; i++) {
272
- __m128 tmp, accu;
273
- tmp = x0 - _mm_loadu_ps (y); y += 4;
274
- accu = tmp * tmp;
275
- tmp = x1 - _mm_loadu_ps (y); y += 4;
276
- accu += tmp * tmp;
277
- tmp = x2 - _mm_loadu_ps (y); y += 4;
278
- accu += tmp * tmp;
317
+ __m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
318
+ accu += ElementOp::op(x1, _mm_loadu_ps (y)); y += 4;
319
+ accu += ElementOp::op(x2, _mm_loadu_ps (y)); y += 4;
279
320
  accu = _mm_hadd_ps (accu, accu);
280
321
  accu = _mm_hadd_ps (accu, accu);
281
322
  dis[i] = _mm_cvtss_f32 (accu);
@@ -283,31 +324,52 @@ void fvec_L2sqr_ny_D12 (float * dis, const float * x,
283
324
  }
284
325
 
285
326
 
327
+
286
328
  } // anonymous namespace
287
329
 
288
330
  void fvec_L2sqr_ny (float * dis, const float * x,
289
331
  const float * y, size_t d, size_t ny) {
290
332
  // optimized for a few special cases
291
- switch(d) {
292
- case 1:
293
- fvec_L2sqr_ny_D1 (dis, x, y, ny);
294
- return;
295
- case 2:
296
- fvec_L2sqr_ny_D2 (dis, x, y, ny);
297
- return;
298
- case 4:
299
- fvec_L2sqr_ny_D4 (dis, x, y, ny);
333
+
334
+ #define DISPATCH(dval) \
335
+ case dval:\
336
+ fvec_op_ny_D ## dval <ElementOpL2> (dis, x, y, ny); \
300
337
  return;
301
- case 8:
302
- fvec_L2sqr_ny_D8 (dis, x, y, ny);
338
+
339
+ switch(d) {
340
+ DISPATCH(1)
341
+ DISPATCH(2)
342
+ DISPATCH(4)
343
+ DISPATCH(8)
344
+ DISPATCH(12)
345
+ default:
346
+ fvec_L2sqr_ny_ref (dis, x, y, d, ny);
303
347
  return;
304
- case 12:
305
- fvec_L2sqr_ny_D12 (dis, x, y, ny);
348
+ }
349
+ #undef DISPATCH
350
+
351
+ }
352
+
353
+ void fvec_inner_products_ny (float * dis, const float * x,
354
+ const float * y, size_t d, size_t ny) {
355
+
356
+ #define DISPATCH(dval) \
357
+ case dval:\
358
+ fvec_op_ny_D ## dval <ElementOpIP> (dis, x, y, ny); \
306
359
  return;
360
+
361
+ switch(d) {
362
+ DISPATCH(1)
363
+ DISPATCH(2)
364
+ DISPATCH(4)
365
+ DISPATCH(8)
366
+ DISPATCH(12)
307
367
  default:
308
- fvec_L2sqr_ny_ref (dis, x, y, d, ny);
368
+ fvec_inner_products_ny_ref (dis, x, y, d, ny);
309
369
  return;
310
370
  }
371
+ #undef DISPATCH
372
+
311
373
  }
312
374
 
313
375
 
@@ -644,6 +706,11 @@ void fvec_L2sqr_ny (float * dis, const float * x,
644
706
  fvec_L2sqr_ny_ref (dis, x, y, d, ny);
645
707
  }
646
708
 
709
+ void fvec_inner_products_ny (float * dis, const float * x,
710
+ const float * y, size_t d, size_t ny) {
711
+ fvec_inner_products_ny_ref (dis, x, y, d, ny);
712
+ }
713
+
647
714
 
648
715
  #endif
649
716
 
@@ -803,6 +870,167 @@ int fvec_madd_and_argmin (size_t n, const float *a,
803
870
  #endif
804
871
 
805
872
 
873
+ /***************************************************************************
874
+ * PQ tables computations
875
+ ***************************************************************************/
876
+
877
+ #ifdef __AVX2__
878
+
879
+ namespace {
880
+
881
+
882
+ // get even float32's of a and b, interleaved
883
+ simd8float32 geteven(simd8float32 a, simd8float32 b) {
884
+ return simd8float32(
885
+ _mm256_shuffle_ps(a.f, b.f, 0 << 0 | 2 << 2 | 0 << 4 | 2 << 6)
886
+ );
887
+ }
888
+
889
+ // get odd float32's of a and b, interleaved
890
+ simd8float32 getodd(simd8float32 a, simd8float32 b) {
891
+ return simd8float32(
892
+ _mm256_shuffle_ps(a.f, b.f, 1 << 0 | 3 << 2 | 1 << 4 | 3 << 6)
893
+ );
894
+ }
895
+
896
+ // 3 cycles
897
+ // if the lanes are a = [a0 a1] and b = [b0 b1], return [a0 b0]
898
+ simd8float32 getlow128(simd8float32 a, simd8float32 b) {
899
+ return simd8float32(
900
+ _mm256_permute2f128_ps(a.f, b.f, 0 | 2 << 4)
901
+ );
902
+ }
903
+
904
+ simd8float32 gethigh128(simd8float32 a, simd8float32 b) {
905
+ return simd8float32(
906
+ _mm256_permute2f128_ps(a.f, b.f, 1 | 3 << 4)
907
+ );
908
+ }
909
+
910
+ /// compute the IP for dsub = 2 for 8 centroids and 4 sub-vectors at a time
911
+ template<bool is_inner_product>
912
+ void pq2_8cents_table(
913
+ const simd8float32 centroids[8],
914
+ const simd8float32 x,
915
+ float *out, size_t ldo, size_t nout = 4
916
+ ) {
917
+
918
+ simd8float32 ips[4];
919
+
920
+ for(int i = 0; i < 4; i++) {
921
+ simd8float32 p1, p2;
922
+ if (is_inner_product) {
923
+ p1 = x * centroids[2 * i];
924
+ p2 = x * centroids[2 * i + 1];
925
+ } else {
926
+ p1 = (x - centroids[2 * i]);
927
+ p1 = p1 * p1;
928
+ p2 = (x - centroids[2 * i + 1]);
929
+ p2 = p2 * p2;
930
+ }
931
+ ips[i] = hadd(p1, p2);
932
+ }
933
+
934
+ simd8float32 ip02a = geteven(ips[0], ips[1]);
935
+ simd8float32 ip02b = geteven(ips[2], ips[3]);
936
+ simd8float32 ip0 = getlow128(ip02a, ip02b);
937
+ simd8float32 ip2 = gethigh128(ip02a, ip02b);
938
+
939
+ simd8float32 ip13a = getodd(ips[0], ips[1]);
940
+ simd8float32 ip13b = getodd(ips[2], ips[3]);
941
+ simd8float32 ip1 = getlow128(ip13a, ip13b);
942
+ simd8float32 ip3 = gethigh128(ip13a, ip13b);
943
+
944
+ switch(nout) {
945
+ case 4:
946
+ ip3.storeu(out + 3 * ldo);
947
+ case 3:
948
+ ip2.storeu(out + 2 * ldo);
949
+ case 2:
950
+ ip1.storeu(out + 1 * ldo);
951
+ case 1:
952
+ ip0.storeu(out);
953
+ }
954
+ }
955
+
956
+ simd8float32 load_simd8float32_partial(const float *x, int n) {
957
+ ALIGNED(32) float tmp[8] = {0, 0, 0, 0, 0, 0, 0, 0};
958
+ float *wp = tmp;
959
+ for (int i = 0; i < n; i++) {
960
+ *wp++ = *x++;
961
+ }
962
+ return simd8float32(tmp);
963
+ }
964
+
965
+ } // anonymous namespace
966
+
967
+
968
+
969
+
970
+ void compute_PQ_dis_tables_dsub2(
971
+ size_t d, size_t ksub, const float *all_centroids,
972
+ size_t nx, const float * x,
973
+ bool is_inner_product,
974
+ float * dis_tables)
975
+ {
976
+ size_t M = d / 2;
977
+ FAISS_THROW_IF_NOT(ksub % 8 == 0);
978
+
979
+ for(size_t m0 = 0; m0 < M; m0 += 4) {
980
+ int m1 = std::min(M, m0 + 4);
981
+ for(int k0 = 0; k0 < ksub; k0 += 8) {
982
+
983
+ simd8float32 centroids[8];
984
+ for (int k = 0; k < 8; k++) {
985
+ float centroid[8] __attribute__((aligned(32)));
986
+ size_t wp = 0;
987
+ size_t rp = (m0 * ksub + k + k0) * 2;
988
+ for (int m = m0; m < m1; m++) {
989
+ centroid[wp++] = all_centroids[rp];
990
+ centroid[wp++] = all_centroids[rp + 1];
991
+ rp += 2 * ksub;
992
+ }
993
+ centroids[k] = simd8float32(centroid);
994
+ }
995
+ for(size_t i = 0; i < nx; i++) {
996
+ simd8float32 xi;
997
+ if (m1 == m0 + 4) {
998
+ xi.loadu(x + i * d + m0 * 2);
999
+ } else {
1000
+ xi = load_simd8float32_partial(x + i * d + m0 * 2, 2 * (m1 - m0));
1001
+ }
1002
+
1003
+ if(is_inner_product) {
1004
+ pq2_8cents_table<true>(
1005
+ centroids, xi,
1006
+ dis_tables + (i * M + m0) * ksub + k0,
1007
+ ksub, m1 - m0
1008
+ );
1009
+ } else {
1010
+ pq2_8cents_table<false>(
1011
+ centroids, xi,
1012
+ dis_tables + (i * M + m0) * ksub + k0,
1013
+ ksub, m1 - m0
1014
+ );
1015
+ }
1016
+ }
1017
+ }
1018
+ }
1019
+
1020
+ }
1021
+
1022
+ #else
1023
+
1024
+ void compute_PQ_dis_tables_dsub2(
1025
+ size_t d, size_t ksub, const float *all_centroids,
1026
+ size_t nx, const float * x,
1027
+ bool is_inner_product,
1028
+ float * dis_tables)
1029
+ {
1030
+ FAISS_THROW_MSG("only implemented for AVX2");
1031
+ }
1032
+
1033
+ #endif
806
1034
 
807
1035
 
808
1036
  } // namespace faiss