faiss 0.1.3 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (199)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +25 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +16 -4
  5. data/ext/faiss/ext.cpp +12 -308
  6. data/ext/faiss/extconf.rb +6 -3
  7. data/ext/faiss/index.cpp +189 -0
  8. data/ext/faiss/index_binary.cpp +75 -0
  9. data/ext/faiss/kmeans.cpp +40 -0
  10. data/ext/faiss/numo.hpp +867 -0
  11. data/ext/faiss/pca_matrix.cpp +33 -0
  12. data/ext/faiss/product_quantizer.cpp +53 -0
  13. data/ext/faiss/utils.cpp +13 -0
  14. data/ext/faiss/utils.h +5 -0
  15. data/lib/faiss.rb +0 -5
  16. data/lib/faiss/version.rb +1 -1
  17. data/vendor/faiss/faiss/AutoTune.cpp +36 -33
  18. data/vendor/faiss/faiss/AutoTune.h +6 -3
  19. data/vendor/faiss/faiss/Clustering.cpp +16 -12
  20. data/vendor/faiss/faiss/Index.cpp +3 -4
  21. data/vendor/faiss/faiss/Index.h +3 -3
  22. data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
  23. data/vendor/faiss/faiss/IndexBinary.h +1 -1
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
  26. data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
  27. data/vendor/faiss/faiss/IndexFlat.h +0 -51
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
  29. data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
  30. data/vendor/faiss/faiss/IndexIVF.h +22 -15
  31. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
  32. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
  33. data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
  34. data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
  35. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
  37. data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
  38. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
  39. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
  41. data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
  42. data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
  43. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
  44. data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
  45. data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
  46. data/vendor/faiss/faiss/IndexRefine.h +73 -0
  47. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
  48. data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
  49. data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
  50. data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
  51. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
  52. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
  53. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
  54. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
  55. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
  56. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
  57. data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
  58. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
  59. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
  60. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
  61. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
  62. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
  63. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
  64. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
  65. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
  66. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
  67. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
  68. data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
  69. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
  70. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
  71. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
  72. data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
  73. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
  74. data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
  75. data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
  76. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
  77. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
  78. data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
  79. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
  80. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
  81. data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
  82. data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
  83. data/vendor/faiss/faiss/impl/io.cpp +33 -2
  84. data/vendor/faiss/faiss/impl/io.h +7 -2
  85. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
  86. data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
  87. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
  88. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
  89. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
  90. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
  91. data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
  92. data/vendor/faiss/faiss/index_factory.cpp +112 -7
  93. data/vendor/faiss/faiss/index_io.h +1 -48
  94. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
  95. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
  96. data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
  97. data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
  98. data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
  99. data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
  100. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
  101. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
  102. data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
  103. data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
  104. data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
  105. data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
  106. data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
  107. data/vendor/faiss/faiss/utils/Heap.h +61 -50
  108. data/vendor/faiss/faiss/utils/distances.cpp +164 -319
  109. data/vendor/faiss/faiss/utils/distances.h +28 -20
  110. data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
  111. data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
  112. data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
  113. data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
  114. data/vendor/faiss/faiss/utils/hamming.h +2 -7
  115. data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
  116. data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
  117. data/vendor/faiss/faiss/utils/partitioning.h +69 -0
  118. data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
  119. data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
  120. data/vendor/faiss/faiss/utils/simdlib.h +31 -0
  121. data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
  122. data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
  123. metadata +54 -149
  124. data/lib/faiss/index.rb +0 -20
  125. data/lib/faiss/index_binary.rb +0 -20
  126. data/lib/faiss/kmeans.rb +0 -15
  127. data/lib/faiss/pca_matrix.rb +0 -15
  128. data/lib/faiss/product_quantizer.rb +0 -22
  129. data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
  130. data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
  131. data/vendor/faiss/c_api/AutoTune_c.h +0 -66
  132. data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
  133. data/vendor/faiss/c_api/Clustering_c.h +0 -123
  134. data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
  135. data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
  136. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
  137. data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
  138. data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
  139. data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
  140. data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
  141. data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
  142. data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
  143. data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
  144. data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
  145. data/vendor/faiss/c_api/IndexShards_c.h +0 -39
  146. data/vendor/faiss/c_api/Index_c.cpp +0 -105
  147. data/vendor/faiss/c_api/Index_c.h +0 -183
  148. data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
  149. data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
  150. data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
  151. data/vendor/faiss/c_api/clone_index_c.h +0 -32
  152. data/vendor/faiss/c_api/error_c.h +0 -42
  153. data/vendor/faiss/c_api/error_impl.cpp +0 -27
  154. data/vendor/faiss/c_api/error_impl.h +0 -16
  155. data/vendor/faiss/c_api/faiss_c.h +0 -58
  156. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
  157. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
  158. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
  159. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
  160. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
  161. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
  162. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
  163. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
  164. data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
  165. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
  166. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
  167. data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
  168. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
  169. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
  170. data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
  171. data/vendor/faiss/c_api/index_factory_c.h +0 -30
  172. data/vendor/faiss/c_api/index_io_c.cpp +0 -42
  173. data/vendor/faiss/c_api/index_io_c.h +0 -50
  174. data/vendor/faiss/c_api/macros_impl.h +0 -110
  175. data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
  176. data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
  177. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
  178. data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
  179. data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
  180. data/vendor/faiss/misc/test_blas.cpp +0 -87
  181. data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
  182. data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
  183. data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
  184. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
  185. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
  186. data/vendor/faiss/tests/test_merge.cpp +0 -260
  187. data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
  188. data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
  189. data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
  190. data/vendor/faiss/tests/test_params_override.cpp +0 -236
  191. data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
  192. data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
  193. data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
  194. data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
  195. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
  196. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
  197. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
  198. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
  199. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -120,7 +120,7 @@ struct IndexBinary {
120
120
  * @param x input vectors to search, size n * d / 8
121
121
  * @param labels output labels of the NNs, size n*k
122
122
  */
123
- void assign(idx_t n, const uint8_t *x, idx_t *labels, idx_t k = 1);
123
+ void assign(idx_t n, const uint8_t *x, idx_t *labels, idx_t k = 1) const;
124
124
 
125
125
  /// Removes all elements from the database.
126
126
  virtual void reset() = 0;
@@ -18,16 +18,7 @@
18
18
 
19
19
  #include <faiss/impl/AuxIndexStructures.h>
20
20
  #include <faiss/impl/FaissAssert.h>
21
-
22
- #ifdef _MSC_VER
23
- #include <intrin.h>
24
-
25
- static inline int __builtin_ctzll(uint64_t x) {
26
- unsigned long ret;
27
- _BitScanForward64(&ret, x);
28
- return (int)ret;
29
- }
30
- #endif // _MSC_VER
21
+ #include <faiss/impl/platform_macros.h>
31
22
 
32
23
  namespace faiss {
33
24
 
@@ -145,8 +136,7 @@ struct KnnSearchResults {
145
136
 
146
137
  inline void add (float dis, idx_t id) {
147
138
  if (dis < heap_sim[0]) {
148
- heap_pop<C> (k, heap_sim, heap_ids);
149
- heap_push<C> (k, heap_sim, heap_ids, dis, id);
139
+ heap_replace_top<C> (k, heap_sim, heap_ids, dis, id);
150
140
  }
151
141
  }
152
142
 
@@ -319,9 +319,8 @@ struct IVFBinaryScannerL2: BinaryInvertedListScanner {
319
319
  for (size_t j = 0; j < n; j++) {
320
320
  uint32_t dis = hc.hamming (codes);
321
321
  if (dis < simi[0]) {
322
- heap_pop<C> (k, simi, idxi);
323
322
  idx_t id = store_pairs ? lo_build(list_no, j) : ids[j];
324
- heap_push<C> (k, simi, idxi, dis, id);
323
+ heap_replace_top<C> (k, simi, idxi, dis, id);
325
324
  nup++;
326
325
  }
327
326
  codes += code_size;
@@ -226,155 +226,7 @@ void IndexFlat::sa_decode (idx_t n, const uint8_t *bytes, float *x) const
226
226
 
227
227
 
228
228
 
229
- /***************************************************
230
- * IndexFlatL2BaseShift
231
- ***************************************************/
232
-
233
- IndexFlatL2BaseShift::IndexFlatL2BaseShift (idx_t d, size_t nshift, const float *shift):
234
- IndexFlatL2 (d), shift (nshift)
235
- {
236
- memcpy (this->shift.data(), shift, sizeof(float) * nshift);
237
- }
238
229
 
239
- void IndexFlatL2BaseShift::search (
240
- idx_t n,
241
- const float *x,
242
- idx_t k,
243
- float *distances,
244
- idx_t *labels) const
245
- {
246
- FAISS_THROW_IF_NOT (shift.size() == ntotal);
247
-
248
- float_maxheap_array_t res = {
249
- size_t(n), size_t(k), labels, distances};
250
- knn_L2sqr_base_shift (x, xb.data(), d, n, ntotal, &res, shift.data());
251
- }
252
-
253
-
254
-
255
- /***************************************************
256
- * IndexRefineFlat
257
- ***************************************************/
258
-
259
- IndexRefineFlat::IndexRefineFlat (Index *base_index):
260
- Index (base_index->d, base_index->metric_type),
261
- refine_index (base_index->d, base_index->metric_type),
262
- base_index (base_index), own_fields (false),
263
- k_factor (1)
264
- {
265
- is_trained = base_index->is_trained;
266
- FAISS_THROW_IF_NOT_MSG (base_index->ntotal == 0,
267
- "base_index should be empty in the beginning");
268
- }
269
-
270
- IndexRefineFlat::IndexRefineFlat () {
271
- base_index = nullptr;
272
- own_fields = false;
273
- k_factor = 1;
274
- }
275
-
276
-
277
- void IndexRefineFlat::train (idx_t n, const float *x)
278
- {
279
- base_index->train (n, x);
280
- is_trained = true;
281
- }
282
-
283
- void IndexRefineFlat::add (idx_t n, const float *x) {
284
- FAISS_THROW_IF_NOT (is_trained);
285
- base_index->add (n, x);
286
- refine_index.add (n, x);
287
- ntotal = refine_index.ntotal;
288
- }
289
-
290
- void IndexRefineFlat::reset ()
291
- {
292
- base_index->reset ();
293
- refine_index.reset ();
294
- ntotal = 0;
295
- }
296
-
297
- namespace {
298
- typedef faiss::Index::idx_t idx_t;
299
-
300
- template<class C>
301
- static void reorder_2_heaps (
302
- idx_t n,
303
- idx_t k, idx_t *labels, float *distances,
304
- idx_t k_base, const idx_t *base_labels, const float *base_distances)
305
- {
306
- #pragma omp parallel for
307
- for (idx_t i = 0; i < n; i++) {
308
- idx_t *idxo = labels + i * k;
309
- float *diso = distances + i * k;
310
- const idx_t *idxi = base_labels + i * k_base;
311
- const float *disi = base_distances + i * k_base;
312
-
313
- heap_heapify<C> (k, diso, idxo, disi, idxi, k);
314
- if (k_base != k) { // add remaining elements
315
- heap_addn<C> (k, diso, idxo, disi + k, idxi + k, k_base - k);
316
- }
317
- heap_reorder<C> (k, diso, idxo);
318
- }
319
- }
320
-
321
-
322
- }
323
-
324
-
325
- void IndexRefineFlat::search (
326
- idx_t n, const float *x, idx_t k,
327
- float *distances, idx_t *labels) const
328
- {
329
- FAISS_THROW_IF_NOT (is_trained);
330
- idx_t k_base = idx_t (k * k_factor);
331
- idx_t * base_labels = labels;
332
- float * base_distances = distances;
333
- ScopeDeleter<idx_t> del1;
334
- ScopeDeleter<float> del2;
335
-
336
-
337
- if (k != k_base) {
338
- base_labels = new idx_t [n * k_base];
339
- del1.set (base_labels);
340
- base_distances = new float [n * k_base];
341
- del2.set (base_distances);
342
- }
343
-
344
- base_index->search (n, x, k_base, base_distances, base_labels);
345
-
346
- for (int i = 0; i < n * k_base; i++)
347
- assert (base_labels[i] >= -1 &&
348
- base_labels[i] < ntotal);
349
-
350
- // compute refined distances
351
- refine_index.compute_distance_subset (
352
- n, x, k_base, base_distances, base_labels);
353
-
354
- // sort and store result
355
- if (metric_type == METRIC_L2) {
356
- typedef CMax <float, idx_t> C;
357
- reorder_2_heaps<C> (
358
- n, k, labels, distances,
359
- k_base, base_labels, base_distances);
360
-
361
- } else if (metric_type == METRIC_INNER_PRODUCT) {
362
- typedef CMin <float, idx_t> C;
363
- reorder_2_heaps<C> (
364
- n, k, labels, distances,
365
- k_base, base_labels, base_distances);
366
- } else {
367
- FAISS_THROW_MSG("Metric type not supported");
368
- }
369
-
370
- }
371
-
372
-
373
-
374
- IndexRefineFlat::~IndexRefineFlat ()
375
- {
376
- if (own_fields) delete base_index;
377
- }
378
230
 
379
231
  /***************************************************
380
232
  * IndexFlat1D
@@ -93,57 +93,6 @@ struct IndexFlatL2:IndexFlat {
93
93
  };
94
94
 
95
95
 
96
- // same as an IndexFlatL2 but a value is subtracted from each distance
97
- struct IndexFlatL2BaseShift: IndexFlatL2 {
98
- std::vector<float> shift;
99
-
100
- IndexFlatL2BaseShift (idx_t d, size_t nshift, const float *shift);
101
-
102
- void search(
103
- idx_t n,
104
- const float* x,
105
- idx_t k,
106
- float* distances,
107
- idx_t* labels) const override;
108
- };
109
-
110
-
111
- /** Index that queries in a base_index (a fast one) and refines the
112
- * results with an exact search, hopefully improving the results.
113
- */
114
- struct IndexRefineFlat: Index {
115
-
116
- /// storage for full vectors
117
- IndexFlat refine_index;
118
-
119
- /// faster index to pre-select the vectors that should be filtered
120
- Index *base_index;
121
- bool own_fields; ///< should the base index be deallocated?
122
-
123
- /// factor between k requested in search and the k requested from
124
- /// the base_index (should be >= 1)
125
- float k_factor;
126
-
127
- explicit IndexRefineFlat (Index *base_index);
128
-
129
- IndexRefineFlat ();
130
-
131
- void train(idx_t n, const float* x) override;
132
-
133
- void add(idx_t n, const float* x) override;
134
-
135
- void reset() override;
136
-
137
- void search(
138
- idx_t n,
139
- const float* x,
140
- idx_t k,
141
- float* distances,
142
- idx_t* labels) const override;
143
-
144
- ~IndexRefineFlat() override;
145
- };
146
-
147
96
 
148
97
  /// optimized version for 1D "vectors".
149
98
  struct IndexFlat1D:IndexFlatL2 {
@@ -277,7 +277,7 @@ IndexHNSW::~IndexHNSW() {
277
277
  void IndexHNSW::train(idx_t n, const float* x)
278
278
  {
279
279
  FAISS_THROW_IF_NOT_MSG(storage,
280
- "Please use IndexHSNWFlat (or variants) instead of IndexHNSW directly");
280
+ "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
281
281
  // hnsw structure does not require training
282
282
  storage->train (n, x);
283
283
  is_trained = true;
@@ -288,7 +288,7 @@ void IndexHNSW::search (idx_t n, const float *x, idx_t k,
288
288
 
289
289
  {
290
290
  FAISS_THROW_IF_NOT_MSG(storage,
291
- "Please use IndexHSNWFlat (or variants) instead of IndexHNSW directly");
291
+ "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
292
292
  size_t n1 = 0, n2 = 0, n3 = 0, ndis = 0, nreorder = 0;
293
293
 
294
294
  idx_t check_period = InterruptCallback::get_period_hint (
@@ -352,7 +352,7 @@ void IndexHNSW::search (idx_t n, const float *x, idx_t k,
352
352
  void IndexHNSW::add(idx_t n, const float *x)
353
353
  {
354
354
  FAISS_THROW_IF_NOT_MSG(storage,
355
- "Please use IndexHSNWFlat (or variants) instead of IndexHNSW directly");
355
+ "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
356
356
  FAISS_THROW_IF_NOT(is_trained);
357
357
  int n0 = ntotal;
358
358
  storage->add(n, x);
@@ -1003,8 +1003,7 @@ int search_from_candidates_2(const HNSW & hnsw,
1003
1003
  if (nres < k) {
1004
1004
  faiss::maxheap_push (++nres, D, I, d, v1);
1005
1005
  } else if (d < D[0]) {
1006
- faiss::maxheap_pop (nres--, D, I);
1007
- faiss::maxheap_push (++nres, D, I, d, v1);
1006
+ faiss::maxheap_replace_top (nres, D, I, d, v1);
1008
1007
  }
1009
1008
  }
1010
1009
  vt.visited[v1] = vt.visno + 1;
@@ -88,12 +88,19 @@ void Level1Quantizer::train_q1 (size_t n, const float *x, bool verbose, MetricTy
88
88
  }
89
89
  quantizer->is_trained = true;
90
90
  } else if (quantizer_trains_alone == 2) {
91
- if (verbose)
91
+ if (verbose) {
92
92
  printf (
93
93
  "Training L2 quantizer on %zd vectors in %zdD%s\n",
94
94
  n, d,
95
95
  clustering_index ? "(user provided index)" : "");
96
- FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
96
+ }
97
+ // also accept spherical centroids because in that case
98
+ // L2 and IP are equivalent
99
+ FAISS_THROW_IF_NOT (
100
+ metric_type == METRIC_L2 ||
101
+ (metric_type == METRIC_INNER_PRODUCT && cp.spherical)
102
+ );
103
+
97
104
  Clustering clus (d, nlist, cp);
98
105
  if (!clustering_index) {
99
106
  IndexFlatL2 assigner (d);
@@ -263,23 +270,76 @@ void IndexIVF::set_direct_map_type (DirectMap::Type type)
263
270
  direct_map.set_type (type, invlists, ntotal);
264
271
  }
265
272
 
266
-
273
+ /** It is a sad fact of software that a conceptually simple function like this
274
+ * becomes very complex when you factor in several ways of parallelizing +
275
+ * interrupt/error handling + collecting stats + min/max collection. The
276
+ * codepath that is used 95% of time is the one for parallel_mode = 0 */
267
277
  void IndexIVF::search (idx_t n, const float *x, idx_t k,
268
278
  float *distances, idx_t *labels) const
269
279
  {
270
- std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
271
- std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
272
280
 
273
- double t0 = getmillisecs();
274
- quantizer->search (n, x, nprobe, coarse_dis.get(), idx.get());
275
- indexIVF_stats.quantization_time += getmillisecs() - t0;
276
281
 
277
- t0 = getmillisecs();
278
- invlists->prefetch_lists (idx.get(), n * nprobe);
282
+ // search function for a subset of queries
283
+ auto sub_search_func = [this, k]
284
+ (idx_t n, const float *x, float *distances, idx_t *labels,
285
+ IndexIVFStats *ivf_stats) {
286
+
287
+ std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
288
+ std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
289
+
290
+ double t0 = getmillisecs();
291
+ quantizer->search (n, x, nprobe, coarse_dis.get(), idx.get());
292
+
293
+ double t1 = getmillisecs();
294
+ invlists->prefetch_lists (idx.get(), n * nprobe);
295
+
296
+ search_preassigned (n, x, k, idx.get(), coarse_dis.get(),
297
+ distances, labels, false, nullptr, ivf_stats);
298
+ double t2 = getmillisecs();
299
+ ivf_stats->quantization_time += t1 - t0;
300
+ ivf_stats->search_time += t2 - t0;
301
+ };
302
+
303
+
304
+ if ((parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT) == 0) {
305
+ int nt = std::min(omp_get_max_threads(), int(n));
306
+ std::vector<IndexIVFStats> stats(nt);
307
+ std::mutex exception_mutex;
308
+ std::string exception_string;
309
+
310
+ #pragma omp parallel for if (nt > 1)
311
+ for(idx_t slice = 0; slice < nt; slice++) {
312
+ IndexIVFStats local_stats;
313
+ idx_t i0 = n * slice / nt;
314
+ idx_t i1 = n * (slice + 1) / nt;
315
+ if (i1 > i0) {
316
+ try {
317
+ sub_search_func(
318
+ i1 - i0, x + i0 * d,
319
+ distances + i0 * k, labels + i0 * k,
320
+ &stats[slice]
321
+ );
322
+ } catch(const std::exception & e) {
323
+ std::lock_guard<std::mutex> lock(exception_mutex);
324
+ exception_string = e.what();
325
+ }
326
+ }
327
+ }
328
+
329
+ if (!exception_string.empty()) {
330
+ FAISS_THROW_MSG (exception_string.c_str());
331
+ }
332
+
333
+ // collect stats
334
+ for(idx_t slice = 0; slice < nt; slice++) {
335
+ indexIVF_stats.add(stats[slice]);
336
+ }
337
+ } else {
338
+ // handle paralellization at level below (or don't run in parallel at all)
339
+ sub_search_func(n, x, distances, labels, &indexIVF_stats);
340
+ }
341
+
279
342
 
280
- search_preassigned (n, x, k, idx.get(), coarse_dis.get(),
281
- distances, labels, false);
282
- indexIVF_stats.search_time += getmillisecs() - t0;
283
343
  }
284
344
 
285
345
 
@@ -288,7 +348,8 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
288
348
  const float *coarse_dis ,
289
349
  float *distances, idx_t *labels,
290
350
  bool store_pairs,
291
- const IVFSearchParameters *params) const
351
+ const IVFSearchParameters *params,
352
+ IndexIVFStats *ivf_stats) const
292
353
  {
293
354
  long nprobe = params ? params->nprobe : this->nprobe;
294
355
  long max_codes = params ? params->max_codes : this->max_codes;
@@ -305,13 +366,12 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
305
366
  int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
306
367
  bool do_heap_init = !(this->parallel_mode & PARALLEL_MODE_NO_HEAP_INIT);
307
368
 
308
- // don't start parallel section if single query
309
369
  bool do_parallel = omp_get_max_threads() >= 2 && (
310
- pmode == 0 ? n > 1 :
370
+ pmode == 0 ? false :
371
+ pmode == 3 ? n > 1 :
311
372
  pmode == 1 ? nprobe > 1 :
312
373
  nprobe * n > 1);
313
374
 
314
-
315
375
  #pragma omp parallel if(do_parallel) reduction(+: nlistv, ndis, nheap)
316
376
  {
317
377
  InvertedListScanner *scanner = get_InvertedListScanner(store_pairs);
@@ -409,7 +469,7 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
409
469
  * Actual loops, depending on parallel_mode
410
470
  ****************************************************/
411
471
 
412
- if (pmode == 0) {
472
+ if (pmode == 0 || pmode == 3) {
413
473
 
414
474
  #pragma omp for
415
475
  for (idx_t i = 0; i < n; i++) {
@@ -527,11 +587,12 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
527
587
  }
528
588
  }
529
589
 
530
- indexIVF_stats.nq += n;
531
- indexIVF_stats.nlist += nlistv;
532
- indexIVF_stats.ndis += ndis;
533
- indexIVF_stats.nheap_updates += nheap;
534
-
590
+ if (ivf_stats) {
591
+ ivf_stats->nq += n;
592
+ ivf_stats->nlist += nlistv;
593
+ ivf_stats->ndis += ndis;
594
+ ivf_stats->nheap_updates += nheap;
595
+ }
535
596
  }
536
597
 
537
598
 
@@ -551,7 +612,7 @@ void IndexIVF::range_search (idx_t nx, const float *x, float radius,
551
612
  invlists->prefetch_lists (keys.get(), nx * nprobe);
552
613
 
553
614
  range_search_preassigned (nx, x, radius, keys.get (), coarse_dis.get (),
554
- result);
615
+ result, false, nullptr, &indexIVF_stats);
555
616
 
556
617
  indexIVF_stats.search_time += getmillisecs() - t0;
557
618
  }
@@ -561,7 +622,8 @@ void IndexIVF::range_search_preassigned (
561
622
  const idx_t *keys, const float *coarse_dis,
562
623
  RangeSearchResult *result,
563
624
  bool store_pairs,
564
- const IVFSearchParameters *params) const
625
+ const IVFSearchParameters *params,
626
+ IndexIVFStats *stats) const
565
627
  {
566
628
  long nprobe = params ? params->nprobe : this->nprobe;
567
629
  long max_codes = params ? params->max_codes : this->max_codes;
@@ -574,7 +636,15 @@ void IndexIVF::range_search_preassigned (
574
636
 
575
637
  std::vector<RangeSearchPartialResult *> all_pres (omp_get_max_threads());
576
638
 
577
- #pragma omp parallel reduction(+: nlistv, ndis)
639
+ int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
640
+ // don't start parallel section if single query
641
+ bool do_parallel = omp_get_max_threads() >= 2 && (
642
+ pmode == 3 ? false :
643
+ pmode == 0 ? nx > 1 :
644
+ pmode == 1 ? nprobe > 1 :
645
+ nprobe * nx > 1);
646
+
647
+ #pragma omp parallel if(do_parallel) reduction(+: nlistv, ndis)
578
648
  {
579
649
  RangeSearchPartialResult pres(result);
580
650
  std::unique_ptr<InvertedListScanner> scanner
@@ -680,9 +750,11 @@ void IndexIVF::range_search_preassigned (
680
750
  }
681
751
  }
682
752
 
683
- indexIVF_stats.nq += nx;
684
- indexIVF_stats.nlist += nlistv;
685
- indexIVF_stats.ndis += ndis;
753
+ if (stats) {
754
+ stats->nq += nx;
755
+ stats->nlist += nlistv;
756
+ stats->ndis += ndis;
757
+ }
686
758
  }
687
759
 
688
760
 
@@ -879,11 +951,15 @@ void IndexIVF::replace_invlists (InvertedLists *il, bool own)
879
951
  {
880
952
  if (own_invlists) {
881
953
  delete invlists;
954
+ invlists = nullptr;
882
955
  }
883
956
  // FAISS_THROW_IF_NOT (ntotal == 0);
884
957
  if (il) {
885
- FAISS_THROW_IF_NOT (il->nlist == nlist &&
886
- il->code_size == code_size);
958
+ FAISS_THROW_IF_NOT (il->nlist == nlist);
959
+ FAISS_THROW_IF_NOT (
960
+ il->code_size == code_size ||
961
+ il->code_size == InvertedLists::INVALID_CODE_SIZE
962
+ );
887
963
  }
888
964
  invlists = il;
889
965
  own_invlists = own;
@@ -971,6 +1047,17 @@ void IndexIVFStats::reset()
971
1047
  memset ((void*)this, 0, sizeof (*this));
972
1048
  }
973
1049
 
1050
+ void IndexIVFStats::add (const IndexIVFStats & other)
1051
+ {
1052
+ nq += other.nq;
1053
+ nlist += other.nlist;
1054
+ ndis += other.ndis;
1055
+ nheap_updates += other.nheap_updates;
1056
+ quantization_time += other.quantization_time;
1057
+ search_time += other.search_time;
1058
+
1059
+ }
1060
+
974
1061
 
975
1062
  IndexIVFStats indexIVF_stats;
976
1063