faiss 0.1.3 → 0.1.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +1 -1
  6. data/lib/faiss/version.rb +1 -1
  7. data/vendor/faiss/faiss/AutoTune.cpp +36 -33
  8. data/vendor/faiss/faiss/AutoTune.h +6 -3
  9. data/vendor/faiss/faiss/Clustering.cpp +16 -12
  10. data/vendor/faiss/faiss/Index.cpp +3 -4
  11. data/vendor/faiss/faiss/Index.h +3 -3
  12. data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
  13. data/vendor/faiss/faiss/IndexBinary.h +1 -1
  14. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
  15. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
  16. data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
  17. data/vendor/faiss/faiss/IndexFlat.h +0 -51
  18. data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
  19. data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
  20. data/vendor/faiss/faiss/IndexIVF.h +22 -15
  21. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
  22. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
  23. data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
  24. data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
  25. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
  26. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
  27. data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
  28. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
  29. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
  30. data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
  31. data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
  32. data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
  33. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
  34. data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
  35. data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
  36. data/vendor/faiss/faiss/IndexRefine.h +73 -0
  37. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
  38. data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
  39. data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
  40. data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
  41. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
  42. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
  43. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
  44. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
  45. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
  46. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
  47. data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
  48. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
  49. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
  50. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
  51. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
  52. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
  53. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
  54. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
  55. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
  56. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
  57. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
  58. data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
  59. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
  60. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
  61. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
  62. data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
  63. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
  64. data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
  65. data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
  66. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
  67. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
  68. data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
  69. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
  70. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
  71. data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
  72. data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
  73. data/vendor/faiss/faiss/impl/io.cpp +33 -2
  74. data/vendor/faiss/faiss/impl/io.h +7 -2
  75. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
  76. data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
  77. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
  78. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
  79. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
  80. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
  81. data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
  82. data/vendor/faiss/faiss/index_factory.cpp +112 -7
  83. data/vendor/faiss/faiss/index_io.h +1 -48
  84. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
  85. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
  86. data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
  87. data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
  88. data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
  89. data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
  90. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
  91. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
  92. data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
  93. data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
  94. data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
  95. data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
  96. data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
  97. data/vendor/faiss/faiss/utils/Heap.h +61 -50
  98. data/vendor/faiss/faiss/utils/distances.cpp +164 -319
  99. data/vendor/faiss/faiss/utils/distances.h +28 -20
  100. data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
  101. data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
  102. data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
  103. data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
  104. data/vendor/faiss/faiss/utils/hamming.h +2 -7
  105. data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
  106. data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
  107. data/vendor/faiss/faiss/utils/partitioning.h +69 -0
  108. data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
  109. data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
  110. data/vendor/faiss/faiss/utils/simdlib.h +31 -0
  111. data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
  112. data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
  113. metadata +43 -141
  114. data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
  115. data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
  116. data/vendor/faiss/c_api/AutoTune_c.h +0 -66
  117. data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
  118. data/vendor/faiss/c_api/Clustering_c.h +0 -123
  119. data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
  120. data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
  121. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
  122. data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
  123. data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
  124. data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
  125. data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
  126. data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
  127. data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
  128. data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
  129. data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
  130. data/vendor/faiss/c_api/IndexShards_c.h +0 -39
  131. data/vendor/faiss/c_api/Index_c.cpp +0 -105
  132. data/vendor/faiss/c_api/Index_c.h +0 -183
  133. data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
  134. data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
  135. data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
  136. data/vendor/faiss/c_api/clone_index_c.h +0 -32
  137. data/vendor/faiss/c_api/error_c.h +0 -42
  138. data/vendor/faiss/c_api/error_impl.cpp +0 -27
  139. data/vendor/faiss/c_api/error_impl.h +0 -16
  140. data/vendor/faiss/c_api/faiss_c.h +0 -58
  141. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
  142. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
  143. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
  144. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
  145. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
  146. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
  147. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
  148. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
  149. data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
  150. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
  151. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
  152. data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
  153. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
  154. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
  155. data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
  156. data/vendor/faiss/c_api/index_factory_c.h +0 -30
  157. data/vendor/faiss/c_api/index_io_c.cpp +0 -42
  158. data/vendor/faiss/c_api/index_io_c.h +0 -50
  159. data/vendor/faiss/c_api/macros_impl.h +0 -110
  160. data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
  161. data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
  162. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
  163. data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
  164. data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
  165. data/vendor/faiss/misc/test_blas.cpp +0 -87
  166. data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
  167. data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
  168. data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
  169. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
  170. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
  171. data/vendor/faiss/tests/test_merge.cpp +0 -260
  172. data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
  173. data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
  174. data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
  175. data/vendor/faiss/tests/test_params_override.cpp +0 -236
  176. data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
  177. data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
  178. data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
  179. data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
  180. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
  181. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
  182. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
  183. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
  184. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -156,6 +156,14 @@ void pairwise_indexed_inner_product (
156
156
  // threshold on nx above which we switch to BLAS to compute distances
157
157
  FAISS_API extern int distance_compute_blas_threshold;
158
158
 
159
+ // block sizes for BLAS distance computations
160
+ FAISS_API extern int distance_compute_blas_query_bs;
161
+ FAISS_API extern int distance_compute_blas_database_bs;
162
+
163
+ // above this number of results we switch to a reservoir to collect results
164
+ // rather than a heap
165
+ FAISS_API extern int distance_compute_min_k_reservoir;
166
+
159
167
  /** Return the k nearest neighors of each of the nx vectors x among the ny
160
168
  * vector y, w.r.t to max inner product
161
169
  *
@@ -169,27 +177,17 @@ void knn_inner_product (
169
177
  size_t d, size_t nx, size_t ny,
170
178
  float_minheap_array_t * res);
171
179
 
172
- /** Same as knn_inner_product, for the L2 distance */
180
+ /** Same as knn_inner_product, for the L2 distance
181
+ * @param y_norm2 norms for the y vectors (nullptr or size ny)
182
+ */
173
183
  void knn_L2sqr (
174
184
  const float * x,
175
185
  const float * y,
176
186
  size_t d, size_t nx, size_t ny,
177
- float_maxheap_array_t * res);
187
+ float_maxheap_array_t * res,
188
+ const float *y_norm2 = nullptr);
178
189
 
179
190
 
180
-
181
- /** same as knn_L2sqr, but base_shift[bno] is subtracted to all
182
- * computed distances.
183
- *
184
- * @param base_shift size ny
185
- */
186
- void knn_L2sqr_base_shift (
187
- const float * x,
188
- const float * y,
189
- size_t d, size_t nx, size_t ny,
190
- float_maxheap_array_t * res,
191
- const float *base_shift);
192
-
193
191
  /* Find the nearest neighbors for nx queries in a set of ny vectors
194
192
  * indexed by ids. May be useful for re-ranking a pre-selected vector list
195
193
  */
@@ -200,11 +198,12 @@ void knn_inner_products_by_idx (
200
198
  size_t d, size_t nx, size_t ny,
201
199
  float_minheap_array_t * res);
202
200
 
203
- void knn_L2sqr_by_idx (const float * x,
204
- const float * y,
205
- const int64_t * ids,
206
- size_t d, size_t nx, size_t ny,
207
- float_maxheap_array_t * res);
201
+ void knn_L2sqr_by_idx (
202
+ const float * x,
203
+ const float * y,
204
+ const int64_t * ids,
205
+ size_t d, size_t nx, size_t ny,
206
+ float_maxheap_array_t * res);
208
207
 
209
208
  /***************************************************************************
210
209
  * Range search
@@ -239,6 +238,15 @@ void range_search_inner_product (
239
238
  RangeSearchResult *result);
240
239
 
241
240
 
241
+ /***************************************************************************
242
+ * PQ tables computations
243
+ ***************************************************************************/
242
244
 
245
+ /// specialized function for PQ2
246
+ void compute_PQ_dis_tables_dsub2(
247
+ size_t d, size_t ksub, const float *centroids,
248
+ size_t nx, const float * x,
249
+ bool is_inner_product,
250
+ float * dis_tables);
243
251
 
244
252
  } // namespace faiss
@@ -14,6 +14,9 @@
14
14
  #include <cstring>
15
15
  #include <cmath>
16
16
 
17
+ #include <faiss/utils/simdlib.h>
18
+ #include <faiss/impl/FaissAssert.h>
19
+
17
20
  #ifdef __SSE3__
18
21
  #include <immintrin.h>
19
22
  #endif
@@ -127,6 +130,29 @@ void fvec_L2sqr_ny_ref (float * dis,
127
130
  }
128
131
 
129
132
 
133
+ void fvec_inner_products_ny_ref (float * ip,
134
+ const float * x,
135
+ const float * y,
136
+ size_t d, size_t ny)
137
+ {
138
+ // BLAS slower for the use cases here
139
+ #if 0
140
+ {
141
+ FINTEGER di = d;
142
+ FINTEGER nyi = ny;
143
+ float one = 1.0, zero = 0.0;
144
+ FINTEGER onei = 1;
145
+ sgemv_ ("T", &di, &nyi, &one, y, &di, x, &onei, &zero, ip, &onei);
146
+ }
147
+ #endif
148
+ for (size_t i = 0; i < ny; i++) {
149
+ ip[i] = fvec_inner_product (x, y, d);
150
+ y += d;
151
+ }
152
+ }
153
+
154
+
155
+
130
156
 
131
157
 
132
158
  /*********************************************************
@@ -174,12 +200,39 @@ float fvec_norm_L2sqr (const float * x,
174
200
 
175
201
  namespace {
176
202
 
177
- float sqr (float x) {
178
- return x * x;
179
- }
203
+ /// Function that does a component-wise operation between x and y
204
+ /// to compute L2 distances. ElementOp can then be used in the fvec_op_ny
205
+ /// functions below
206
+ struct ElementOpL2 {
207
+
208
+ static float op (float x, float y) {
209
+ float tmp = x - y;
210
+ return tmp * tmp;
211
+ }
212
+
213
+ static __m128 op (__m128 x, __m128 y) {
214
+ __m128 tmp = x - y;
215
+ return tmp * tmp;
216
+ }
217
+
218
+ };
180
219
 
220
+ /// Function that does a component-wise operation between x and y
221
+ /// to compute inner products
222
+ struct ElementOpIP {
181
223
 
182
- void fvec_L2sqr_ny_D1 (float * dis, const float * x,
224
+ static float op (float x, float y) {
225
+ return x * y;
226
+ }
227
+
228
+ static __m128 op (__m128 x, __m128 y) {
229
+ return x * y;
230
+ }
231
+
232
+ };
233
+
234
+ template<class ElementOp>
235
+ void fvec_op_ny_D1 (float * dis, const float * x,
183
236
  const float * y, size_t ny)
184
237
  {
185
238
  float x0s = x[0];
@@ -187,11 +240,9 @@ void fvec_L2sqr_ny_D1 (float * dis, const float * x,
187
240
 
188
241
  size_t i;
189
242
  for (i = 0; i + 3 < ny; i += 4) {
190
- __m128 tmp, accu;
191
- tmp = x0 - _mm_loadu_ps (y); y += 4;
192
- accu = tmp * tmp;
243
+ __m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
193
244
  dis[i] = _mm_cvtss_f32 (accu);
194
- tmp = _mm_shuffle_ps (accu, accu, 1);
245
+ __m128 tmp = _mm_shuffle_ps (accu, accu, 1);
195
246
  dis[i + 1] = _mm_cvtss_f32 (tmp);
196
247
  tmp = _mm_shuffle_ps (accu, accu, 2);
197
248
  dis[i + 2] = _mm_cvtss_f32 (tmp);
@@ -199,69 +250,63 @@ void fvec_L2sqr_ny_D1 (float * dis, const float * x,
199
250
  dis[i + 3] = _mm_cvtss_f32 (tmp);
200
251
  }
201
252
  while (i < ny) { // handle non-multiple-of-4 case
202
- dis[i++] = sqr(x0s - *y++);
253
+ dis[i++] = ElementOp::op(x0s, *y++);
203
254
  }
204
255
  }
205
256
 
206
-
207
- void fvec_L2sqr_ny_D2 (float * dis, const float * x,
257
+ template<class ElementOp>
258
+ void fvec_op_ny_D2 (float * dis, const float * x,
208
259
  const float * y, size_t ny)
209
260
  {
210
261
  __m128 x0 = _mm_set_ps (x[1], x[0], x[1], x[0]);
211
262
 
212
263
  size_t i;
213
264
  for (i = 0; i + 1 < ny; i += 2) {
214
- __m128 tmp, accu;
215
- tmp = x0 - _mm_loadu_ps (y); y += 4;
216
- accu = tmp * tmp;
265
+ __m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
217
266
  accu = _mm_hadd_ps (accu, accu);
218
267
  dis[i] = _mm_cvtss_f32 (accu);
219
268
  accu = _mm_shuffle_ps (accu, accu, 3);
220
269
  dis[i + 1] = _mm_cvtss_f32 (accu);
221
270
  }
222
271
  if (i < ny) { // handle odd case
223
- dis[i] = sqr(x[0] - y[0]) + sqr(x[1] - y[1]);
272
+ dis[i] = ElementOp::op(x[0], y[0]) + ElementOp::op(x[1], y[1]);
224
273
  }
225
274
  }
226
275
 
227
276
 
228
277
 
229
- void fvec_L2sqr_ny_D4 (float * dis, const float * x,
278
+ template<class ElementOp>
279
+ void fvec_op_ny_D4 (float * dis, const float * x,
230
280
  const float * y, size_t ny)
231
281
  {
232
282
  __m128 x0 = _mm_loadu_ps(x);
233
283
 
234
284
  for (size_t i = 0; i < ny; i++) {
235
- __m128 tmp, accu;
236
- tmp = x0 - _mm_loadu_ps (y); y += 4;
237
- accu = tmp * tmp;
285
+ __m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
238
286
  accu = _mm_hadd_ps (accu, accu);
239
287
  accu = _mm_hadd_ps (accu, accu);
240
288
  dis[i] = _mm_cvtss_f32 (accu);
241
289
  }
242
290
  }
243
291
 
244
-
245
- void fvec_L2sqr_ny_D8 (float * dis, const float * x,
292
+ template<class ElementOp>
293
+ void fvec_op_ny_D8 (float * dis, const float * x,
246
294
  const float * y, size_t ny)
247
295
  {
248
296
  __m128 x0 = _mm_loadu_ps(x);
249
297
  __m128 x1 = _mm_loadu_ps(x + 4);
250
298
 
251
299
  for (size_t i = 0; i < ny; i++) {
252
- __m128 tmp, accu;
253
- tmp = x0 - _mm_loadu_ps (y); y += 4;
254
- accu = tmp * tmp;
255
- tmp = x1 - _mm_loadu_ps (y); y += 4;
256
- accu += tmp * tmp;
300
+ __m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
301
+ accu += ElementOp::op(x1, _mm_loadu_ps (y)); y += 4;
257
302
  accu = _mm_hadd_ps (accu, accu);
258
303
  accu = _mm_hadd_ps (accu, accu);
259
304
  dis[i] = _mm_cvtss_f32 (accu);
260
305
  }
261
306
  }
262
307
 
263
-
264
- void fvec_L2sqr_ny_D12 (float * dis, const float * x,
308
+ template<class ElementOp>
309
+ void fvec_op_ny_D12 (float * dis, const float * x,
265
310
  const float * y, size_t ny)
266
311
  {
267
312
  __m128 x0 = _mm_loadu_ps(x);
@@ -269,13 +314,9 @@ void fvec_L2sqr_ny_D12 (float * dis, const float * x,
269
314
  __m128 x2 = _mm_loadu_ps(x + 8);
270
315
 
271
316
  for (size_t i = 0; i < ny; i++) {
272
- __m128 tmp, accu;
273
- tmp = x0 - _mm_loadu_ps (y); y += 4;
274
- accu = tmp * tmp;
275
- tmp = x1 - _mm_loadu_ps (y); y += 4;
276
- accu += tmp * tmp;
277
- tmp = x2 - _mm_loadu_ps (y); y += 4;
278
- accu += tmp * tmp;
317
+ __m128 accu = ElementOp::op(x0, _mm_loadu_ps (y)); y += 4;
318
+ accu += ElementOp::op(x1, _mm_loadu_ps (y)); y += 4;
319
+ accu += ElementOp::op(x2, _mm_loadu_ps (y)); y += 4;
279
320
  accu = _mm_hadd_ps (accu, accu);
280
321
  accu = _mm_hadd_ps (accu, accu);
281
322
  dis[i] = _mm_cvtss_f32 (accu);
@@ -283,31 +324,52 @@ void fvec_L2sqr_ny_D12 (float * dis, const float * x,
283
324
  }
284
325
 
285
326
 
327
+
286
328
  } // anonymous namespace
287
329
 
288
330
  void fvec_L2sqr_ny (float * dis, const float * x,
289
331
  const float * y, size_t d, size_t ny) {
290
332
  // optimized for a few special cases
291
- switch(d) {
292
- case 1:
293
- fvec_L2sqr_ny_D1 (dis, x, y, ny);
294
- return;
295
- case 2:
296
- fvec_L2sqr_ny_D2 (dis, x, y, ny);
297
- return;
298
- case 4:
299
- fvec_L2sqr_ny_D4 (dis, x, y, ny);
333
+
334
+ #define DISPATCH(dval) \
335
+ case dval:\
336
+ fvec_op_ny_D ## dval <ElementOpL2> (dis, x, y, ny); \
300
337
  return;
301
- case 8:
302
- fvec_L2sqr_ny_D8 (dis, x, y, ny);
338
+
339
+ switch(d) {
340
+ DISPATCH(1)
341
+ DISPATCH(2)
342
+ DISPATCH(4)
343
+ DISPATCH(8)
344
+ DISPATCH(12)
345
+ default:
346
+ fvec_L2sqr_ny_ref (dis, x, y, d, ny);
303
347
  return;
304
- case 12:
305
- fvec_L2sqr_ny_D12 (dis, x, y, ny);
348
+ }
349
+ #undef DISPATCH
350
+
351
+ }
352
+
353
+ void fvec_inner_products_ny (float * dis, const float * x,
354
+ const float * y, size_t d, size_t ny) {
355
+
356
+ #define DISPATCH(dval) \
357
+ case dval:\
358
+ fvec_op_ny_D ## dval <ElementOpIP> (dis, x, y, ny); \
306
359
  return;
360
+
361
+ switch(d) {
362
+ DISPATCH(1)
363
+ DISPATCH(2)
364
+ DISPATCH(4)
365
+ DISPATCH(8)
366
+ DISPATCH(12)
307
367
  default:
308
- fvec_L2sqr_ny_ref (dis, x, y, d, ny);
368
+ fvec_inner_products_ny_ref (dis, x, y, d, ny);
309
369
  return;
310
370
  }
371
+ #undef DISPATCH
372
+
311
373
  }
312
374
 
313
375
 
@@ -644,6 +706,11 @@ void fvec_L2sqr_ny (float * dis, const float * x,
644
706
  fvec_L2sqr_ny_ref (dis, x, y, d, ny);
645
707
  }
646
708
 
709
+ void fvec_inner_products_ny (float * dis, const float * x,
710
+ const float * y, size_t d, size_t ny) {
711
+ fvec_inner_products_ny_ref (dis, x, y, d, ny);
712
+ }
713
+
647
714
 
648
715
  #endif
649
716
 
@@ -803,6 +870,167 @@ int fvec_madd_and_argmin (size_t n, const float *a,
803
870
  #endif
804
871
 
805
872
 
873
+ /***************************************************************************
874
+ * PQ tables computations
875
+ ***************************************************************************/
876
+
877
+ #ifdef __AVX2__
878
+
879
+ namespace {
880
+
881
+
882
+ // get even float32's of a and b, interleaved
883
+ simd8float32 geteven(simd8float32 a, simd8float32 b) {
884
+ return simd8float32(
885
+ _mm256_shuffle_ps(a.f, b.f, 0 << 0 | 2 << 2 | 0 << 4 | 2 << 6)
886
+ );
887
+ }
888
+
889
+ // get odd float32's of a and b, interleaved
890
+ simd8float32 getodd(simd8float32 a, simd8float32 b) {
891
+ return simd8float32(
892
+ _mm256_shuffle_ps(a.f, b.f, 1 << 0 | 3 << 2 | 1 << 4 | 3 << 6)
893
+ );
894
+ }
895
+
896
+ // 3 cycles
897
+ // if the lanes are a = [a0 a1] and b = [b0 b1], return [a0 b0]
898
+ simd8float32 getlow128(simd8float32 a, simd8float32 b) {
899
+ return simd8float32(
900
+ _mm256_permute2f128_ps(a.f, b.f, 0 | 2 << 4)
901
+ );
902
+ }
903
+
904
+ simd8float32 gethigh128(simd8float32 a, simd8float32 b) {
905
+ return simd8float32(
906
+ _mm256_permute2f128_ps(a.f, b.f, 1 | 3 << 4)
907
+ );
908
+ }
909
+
910
+ /// compute the IP for dsub = 2 for 8 centroids and 4 sub-vectors at a time
911
+ template<bool is_inner_product>
912
+ void pq2_8cents_table(
913
+ const simd8float32 centroids[8],
914
+ const simd8float32 x,
915
+ float *out, size_t ldo, size_t nout = 4
916
+ ) {
917
+
918
+ simd8float32 ips[4];
919
+
920
+ for(int i = 0; i < 4; i++) {
921
+ simd8float32 p1, p2;
922
+ if (is_inner_product) {
923
+ p1 = x * centroids[2 * i];
924
+ p2 = x * centroids[2 * i + 1];
925
+ } else {
926
+ p1 = (x - centroids[2 * i]);
927
+ p1 = p1 * p1;
928
+ p2 = (x - centroids[2 * i + 1]);
929
+ p2 = p2 * p2;
930
+ }
931
+ ips[i] = hadd(p1, p2);
932
+ }
933
+
934
+ simd8float32 ip02a = geteven(ips[0], ips[1]);
935
+ simd8float32 ip02b = geteven(ips[2], ips[3]);
936
+ simd8float32 ip0 = getlow128(ip02a, ip02b);
937
+ simd8float32 ip2 = gethigh128(ip02a, ip02b);
938
+
939
+ simd8float32 ip13a = getodd(ips[0], ips[1]);
940
+ simd8float32 ip13b = getodd(ips[2], ips[3]);
941
+ simd8float32 ip1 = getlow128(ip13a, ip13b);
942
+ simd8float32 ip3 = gethigh128(ip13a, ip13b);
943
+
944
+ switch(nout) {
945
+ case 4:
946
+ ip3.storeu(out + 3 * ldo);
947
+ case 3:
948
+ ip2.storeu(out + 2 * ldo);
949
+ case 2:
950
+ ip1.storeu(out + 1 * ldo);
951
+ case 1:
952
+ ip0.storeu(out);
953
+ }
954
+ }
955
+
956
+ simd8float32 load_simd8float32_partial(const float *x, int n) {
957
+ ALIGNED(32) float tmp[8] = {0, 0, 0, 0, 0, 0, 0, 0};
958
+ float *wp = tmp;
959
+ for (int i = 0; i < n; i++) {
960
+ *wp++ = *x++;
961
+ }
962
+ return simd8float32(tmp);
963
+ }
964
+
965
+ } // anonymous namespace
966
+
967
+
968
+
969
+
970
+ void compute_PQ_dis_tables_dsub2(
971
+ size_t d, size_t ksub, const float *all_centroids,
972
+ size_t nx, const float * x,
973
+ bool is_inner_product,
974
+ float * dis_tables)
975
+ {
976
+ size_t M = d / 2;
977
+ FAISS_THROW_IF_NOT(ksub % 8 == 0);
978
+
979
+ for(size_t m0 = 0; m0 < M; m0 += 4) {
980
+ int m1 = std::min(M, m0 + 4);
981
+ for(int k0 = 0; k0 < ksub; k0 += 8) {
982
+
983
+ simd8float32 centroids[8];
984
+ for (int k = 0; k < 8; k++) {
985
+ float centroid[8] __attribute__((aligned(32)));
986
+ size_t wp = 0;
987
+ size_t rp = (m0 * ksub + k + k0) * 2;
988
+ for (int m = m0; m < m1; m++) {
989
+ centroid[wp++] = all_centroids[rp];
990
+ centroid[wp++] = all_centroids[rp + 1];
991
+ rp += 2 * ksub;
992
+ }
993
+ centroids[k] = simd8float32(centroid);
994
+ }
995
+ for(size_t i = 0; i < nx; i++) {
996
+ simd8float32 xi;
997
+ if (m1 == m0 + 4) {
998
+ xi.loadu(x + i * d + m0 * 2);
999
+ } else {
1000
+ xi = load_simd8float32_partial(x + i * d + m0 * 2, 2 * (m1 - m0));
1001
+ }
1002
+
1003
+ if(is_inner_product) {
1004
+ pq2_8cents_table<true>(
1005
+ centroids, xi,
1006
+ dis_tables + (i * M + m0) * ksub + k0,
1007
+ ksub, m1 - m0
1008
+ );
1009
+ } else {
1010
+ pq2_8cents_table<false>(
1011
+ centroids, xi,
1012
+ dis_tables + (i * M + m0) * ksub + k0,
1013
+ ksub, m1 - m0
1014
+ );
1015
+ }
1016
+ }
1017
+ }
1018
+ }
1019
+
1020
+ }
1021
+
1022
+ #else
1023
+
1024
+ void compute_PQ_dis_tables_dsub2(
1025
+ size_t d, size_t ksub, const float *all_centroids,
1026
+ size_t nx, const float * x,
1027
+ bool is_inner_product,
1028
+ float * dis_tables)
1029
+ {
1030
+ FAISS_THROW_MSG("only implemented for AVX2");
1031
+ }
1032
+
1033
+ #endif
806
1034
 
807
1035
 
808
1036
  } // namespace faiss