faiss 0.1.3 → 0.2.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (199)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +25 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +16 -4
  5. data/ext/faiss/ext.cpp +12 -308
  6. data/ext/faiss/extconf.rb +6 -3
  7. data/ext/faiss/index.cpp +189 -0
  8. data/ext/faiss/index_binary.cpp +75 -0
  9. data/ext/faiss/kmeans.cpp +40 -0
  10. data/ext/faiss/numo.hpp +867 -0
  11. data/ext/faiss/pca_matrix.cpp +33 -0
  12. data/ext/faiss/product_quantizer.cpp +53 -0
  13. data/ext/faiss/utils.cpp +13 -0
  14. data/ext/faiss/utils.h +5 -0
  15. data/lib/faiss.rb +0 -5
  16. data/lib/faiss/version.rb +1 -1
  17. data/vendor/faiss/faiss/AutoTune.cpp +36 -33
  18. data/vendor/faiss/faiss/AutoTune.h +6 -3
  19. data/vendor/faiss/faiss/Clustering.cpp +16 -12
  20. data/vendor/faiss/faiss/Index.cpp +3 -4
  21. data/vendor/faiss/faiss/Index.h +3 -3
  22. data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
  23. data/vendor/faiss/faiss/IndexBinary.h +1 -1
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
  26. data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
  27. data/vendor/faiss/faiss/IndexFlat.h +0 -51
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
  29. data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
  30. data/vendor/faiss/faiss/IndexIVF.h +22 -15
  31. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
  32. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
  33. data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
  34. data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
  35. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
  37. data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
  38. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
  39. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
  41. data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
  42. data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
  43. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
  44. data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
  45. data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
  46. data/vendor/faiss/faiss/IndexRefine.h +73 -0
  47. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
  48. data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
  49. data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
  50. data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
  51. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
  52. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
  53. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
  54. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
  55. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
  56. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
  57. data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
  58. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
  59. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
  60. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
  61. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
  62. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
  63. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
  64. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
  65. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
  66. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
  67. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
  68. data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
  69. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
  70. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
  71. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
  72. data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
  73. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
  74. data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
  75. data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
  76. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
  77. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
  78. data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
  79. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
  80. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
  81. data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
  82. data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
  83. data/vendor/faiss/faiss/impl/io.cpp +33 -2
  84. data/vendor/faiss/faiss/impl/io.h +7 -2
  85. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
  86. data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
  87. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
  88. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
  89. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
  90. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
  91. data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
  92. data/vendor/faiss/faiss/index_factory.cpp +112 -7
  93. data/vendor/faiss/faiss/index_io.h +1 -48
  94. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
  95. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
  96. data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
  97. data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
  98. data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
  99. data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
  100. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
  101. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
  102. data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
  103. data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
  104. data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
  105. data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
  106. data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
  107. data/vendor/faiss/faiss/utils/Heap.h +61 -50
  108. data/vendor/faiss/faiss/utils/distances.cpp +164 -319
  109. data/vendor/faiss/faiss/utils/distances.h +28 -20
  110. data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
  111. data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
  112. data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
  113. data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
  114. data/vendor/faiss/faiss/utils/hamming.h +2 -7
  115. data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
  116. data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
  117. data/vendor/faiss/faiss/utils/partitioning.h +69 -0
  118. data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
  119. data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
  120. data/vendor/faiss/faiss/utils/simdlib.h +31 -0
  121. data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
  122. data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
  123. metadata +54 -149
  124. data/lib/faiss/index.rb +0 -20
  125. data/lib/faiss/index_binary.rb +0 -20
  126. data/lib/faiss/kmeans.rb +0 -15
  127. data/lib/faiss/pca_matrix.rb +0 -15
  128. data/lib/faiss/product_quantizer.rb +0 -22
  129. data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
  130. data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
  131. data/vendor/faiss/c_api/AutoTune_c.h +0 -66
  132. data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
  133. data/vendor/faiss/c_api/Clustering_c.h +0 -123
  134. data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
  135. data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
  136. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
  137. data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
  138. data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
  139. data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
  140. data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
  141. data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
  142. data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
  143. data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
  144. data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
  145. data/vendor/faiss/c_api/IndexShards_c.h +0 -39
  146. data/vendor/faiss/c_api/Index_c.cpp +0 -105
  147. data/vendor/faiss/c_api/Index_c.h +0 -183
  148. data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
  149. data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
  150. data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
  151. data/vendor/faiss/c_api/clone_index_c.h +0 -32
  152. data/vendor/faiss/c_api/error_c.h +0 -42
  153. data/vendor/faiss/c_api/error_impl.cpp +0 -27
  154. data/vendor/faiss/c_api/error_impl.h +0 -16
  155. data/vendor/faiss/c_api/faiss_c.h +0 -58
  156. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
  157. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
  158. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
  159. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
  160. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
  161. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
  162. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
  163. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
  164. data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
  165. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
  166. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
  167. data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
  168. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
  169. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
  170. data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
  171. data/vendor/faiss/c_api/index_factory_c.h +0 -30
  172. data/vendor/faiss/c_api/index_io_c.cpp +0 -42
  173. data/vendor/faiss/c_api/index_io_c.h +0 -50
  174. data/vendor/faiss/c_api/macros_impl.h +0 -110
  175. data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
  176. data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
  177. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
  178. data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
  179. data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
  180. data/vendor/faiss/misc/test_blas.cpp +0 -87
  181. data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
  182. data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
  183. data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
  184. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
  185. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
  186. data/vendor/faiss/tests/test_merge.cpp +0 -260
  187. data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
  188. data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
  189. data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
  190. data/vendor/faiss/tests/test_params_override.cpp +0 -236
  191. data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
  192. data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
  193. data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
  194. data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
  195. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
  196. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
  197. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
  198. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
  199. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -120,7 +120,7 @@ struct IndexBinary {
120
120
  * @param x input vectors to search, size n * d / 8
121
121
  * @param labels output labels of the NNs, size n*k
122
122
  */
123
- void assign(idx_t n, const uint8_t *x, idx_t *labels, idx_t k = 1);
123
+ void assign(idx_t n, const uint8_t *x, idx_t *labels, idx_t k = 1) const;
124
124
 
125
125
  /// Removes all elements from the database.
126
126
  virtual void reset() = 0;
@@ -18,16 +18,7 @@
18
18
 
19
19
  #include <faiss/impl/AuxIndexStructures.h>
20
20
  #include <faiss/impl/FaissAssert.h>
21
-
22
- #ifdef _MSC_VER
23
- #include <intrin.h>
24
-
25
- static inline int __builtin_ctzll(uint64_t x) {
26
- unsigned long ret;
27
- _BitScanForward64(&ret, x);
28
- return (int)ret;
29
- }
30
- #endif // _MSC_VER
21
+ #include <faiss/impl/platform_macros.h>
31
22
 
32
23
  namespace faiss {
33
24
 
@@ -145,8 +136,7 @@ struct KnnSearchResults {
145
136
 
146
137
  inline void add (float dis, idx_t id) {
147
138
  if (dis < heap_sim[0]) {
148
- heap_pop<C> (k, heap_sim, heap_ids);
149
- heap_push<C> (k, heap_sim, heap_ids, dis, id);
139
+ heap_replace_top<C> (k, heap_sim, heap_ids, dis, id);
150
140
  }
151
141
  }
152
142
 
@@ -319,9 +319,8 @@ struct IVFBinaryScannerL2: BinaryInvertedListScanner {
319
319
  for (size_t j = 0; j < n; j++) {
320
320
  uint32_t dis = hc.hamming (codes);
321
321
  if (dis < simi[0]) {
322
- heap_pop<C> (k, simi, idxi);
323
322
  idx_t id = store_pairs ? lo_build(list_no, j) : ids[j];
324
- heap_push<C> (k, simi, idxi, dis, id);
323
+ heap_replace_top<C> (k, simi, idxi, dis, id);
325
324
  nup++;
326
325
  }
327
326
  codes += code_size;
@@ -226,155 +226,7 @@ void IndexFlat::sa_decode (idx_t n, const uint8_t *bytes, float *x) const
226
226
 
227
227
 
228
228
 
229
- /***************************************************
230
- * IndexFlatL2BaseShift
231
- ***************************************************/
232
-
233
- IndexFlatL2BaseShift::IndexFlatL2BaseShift (idx_t d, size_t nshift, const float *shift):
234
- IndexFlatL2 (d), shift (nshift)
235
- {
236
- memcpy (this->shift.data(), shift, sizeof(float) * nshift);
237
- }
238
229
 
239
- void IndexFlatL2BaseShift::search (
240
- idx_t n,
241
- const float *x,
242
- idx_t k,
243
- float *distances,
244
- idx_t *labels) const
245
- {
246
- FAISS_THROW_IF_NOT (shift.size() == ntotal);
247
-
248
- float_maxheap_array_t res = {
249
- size_t(n), size_t(k), labels, distances};
250
- knn_L2sqr_base_shift (x, xb.data(), d, n, ntotal, &res, shift.data());
251
- }
252
-
253
-
254
-
255
- /***************************************************
256
- * IndexRefineFlat
257
- ***************************************************/
258
-
259
- IndexRefineFlat::IndexRefineFlat (Index *base_index):
260
- Index (base_index->d, base_index->metric_type),
261
- refine_index (base_index->d, base_index->metric_type),
262
- base_index (base_index), own_fields (false),
263
- k_factor (1)
264
- {
265
- is_trained = base_index->is_trained;
266
- FAISS_THROW_IF_NOT_MSG (base_index->ntotal == 0,
267
- "base_index should be empty in the beginning");
268
- }
269
-
270
- IndexRefineFlat::IndexRefineFlat () {
271
- base_index = nullptr;
272
- own_fields = false;
273
- k_factor = 1;
274
- }
275
-
276
-
277
- void IndexRefineFlat::train (idx_t n, const float *x)
278
- {
279
- base_index->train (n, x);
280
- is_trained = true;
281
- }
282
-
283
- void IndexRefineFlat::add (idx_t n, const float *x) {
284
- FAISS_THROW_IF_NOT (is_trained);
285
- base_index->add (n, x);
286
- refine_index.add (n, x);
287
- ntotal = refine_index.ntotal;
288
- }
289
-
290
- void IndexRefineFlat::reset ()
291
- {
292
- base_index->reset ();
293
- refine_index.reset ();
294
- ntotal = 0;
295
- }
296
-
297
- namespace {
298
- typedef faiss::Index::idx_t idx_t;
299
-
300
- template<class C>
301
- static void reorder_2_heaps (
302
- idx_t n,
303
- idx_t k, idx_t *labels, float *distances,
304
- idx_t k_base, const idx_t *base_labels, const float *base_distances)
305
- {
306
- #pragma omp parallel for
307
- for (idx_t i = 0; i < n; i++) {
308
- idx_t *idxo = labels + i * k;
309
- float *diso = distances + i * k;
310
- const idx_t *idxi = base_labels + i * k_base;
311
- const float *disi = base_distances + i * k_base;
312
-
313
- heap_heapify<C> (k, diso, idxo, disi, idxi, k);
314
- if (k_base != k) { // add remaining elements
315
- heap_addn<C> (k, diso, idxo, disi + k, idxi + k, k_base - k);
316
- }
317
- heap_reorder<C> (k, diso, idxo);
318
- }
319
- }
320
-
321
-
322
- }
323
-
324
-
325
- void IndexRefineFlat::search (
326
- idx_t n, const float *x, idx_t k,
327
- float *distances, idx_t *labels) const
328
- {
329
- FAISS_THROW_IF_NOT (is_trained);
330
- idx_t k_base = idx_t (k * k_factor);
331
- idx_t * base_labels = labels;
332
- float * base_distances = distances;
333
- ScopeDeleter<idx_t> del1;
334
- ScopeDeleter<float> del2;
335
-
336
-
337
- if (k != k_base) {
338
- base_labels = new idx_t [n * k_base];
339
- del1.set (base_labels);
340
- base_distances = new float [n * k_base];
341
- del2.set (base_distances);
342
- }
343
-
344
- base_index->search (n, x, k_base, base_distances, base_labels);
345
-
346
- for (int i = 0; i < n * k_base; i++)
347
- assert (base_labels[i] >= -1 &&
348
- base_labels[i] < ntotal);
349
-
350
- // compute refined distances
351
- refine_index.compute_distance_subset (
352
- n, x, k_base, base_distances, base_labels);
353
-
354
- // sort and store result
355
- if (metric_type == METRIC_L2) {
356
- typedef CMax <float, idx_t> C;
357
- reorder_2_heaps<C> (
358
- n, k, labels, distances,
359
- k_base, base_labels, base_distances);
360
-
361
- } else if (metric_type == METRIC_INNER_PRODUCT) {
362
- typedef CMin <float, idx_t> C;
363
- reorder_2_heaps<C> (
364
- n, k, labels, distances,
365
- k_base, base_labels, base_distances);
366
- } else {
367
- FAISS_THROW_MSG("Metric type not supported");
368
- }
369
-
370
- }
371
-
372
-
373
-
374
- IndexRefineFlat::~IndexRefineFlat ()
375
- {
376
- if (own_fields) delete base_index;
377
- }
378
230
 
379
231
  /***************************************************
380
232
  * IndexFlat1D
@@ -93,57 +93,6 @@ struct IndexFlatL2:IndexFlat {
93
93
  };
94
94
 
95
95
 
96
- // same as an IndexFlatL2 but a value is subtracted from each distance
97
- struct IndexFlatL2BaseShift: IndexFlatL2 {
98
- std::vector<float> shift;
99
-
100
- IndexFlatL2BaseShift (idx_t d, size_t nshift, const float *shift);
101
-
102
- void search(
103
- idx_t n,
104
- const float* x,
105
- idx_t k,
106
- float* distances,
107
- idx_t* labels) const override;
108
- };
109
-
110
-
111
- /** Index that queries in a base_index (a fast one) and refines the
112
- * results with an exact search, hopefully improving the results.
113
- */
114
- struct IndexRefineFlat: Index {
115
-
116
- /// storage for full vectors
117
- IndexFlat refine_index;
118
-
119
- /// faster index to pre-select the vectors that should be filtered
120
- Index *base_index;
121
- bool own_fields; ///< should the base index be deallocated?
122
-
123
- /// factor between k requested in search and the k requested from
124
- /// the base_index (should be >= 1)
125
- float k_factor;
126
-
127
- explicit IndexRefineFlat (Index *base_index);
128
-
129
- IndexRefineFlat ();
130
-
131
- void train(idx_t n, const float* x) override;
132
-
133
- void add(idx_t n, const float* x) override;
134
-
135
- void reset() override;
136
-
137
- void search(
138
- idx_t n,
139
- const float* x,
140
- idx_t k,
141
- float* distances,
142
- idx_t* labels) const override;
143
-
144
- ~IndexRefineFlat() override;
145
- };
146
-
147
96
 
148
97
  /// optimized version for 1D "vectors".
149
98
  struct IndexFlat1D:IndexFlatL2 {
@@ -277,7 +277,7 @@ IndexHNSW::~IndexHNSW() {
277
277
  void IndexHNSW::train(idx_t n, const float* x)
278
278
  {
279
279
  FAISS_THROW_IF_NOT_MSG(storage,
280
- "Please use IndexHSNWFlat (or variants) instead of IndexHNSW directly");
280
+ "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
281
281
  // hnsw structure does not require training
282
282
  storage->train (n, x);
283
283
  is_trained = true;
@@ -288,7 +288,7 @@ void IndexHNSW::search (idx_t n, const float *x, idx_t k,
288
288
 
289
289
  {
290
290
  FAISS_THROW_IF_NOT_MSG(storage,
291
- "Please use IndexHSNWFlat (or variants) instead of IndexHNSW directly");
291
+ "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
292
292
  size_t n1 = 0, n2 = 0, n3 = 0, ndis = 0, nreorder = 0;
293
293
 
294
294
  idx_t check_period = InterruptCallback::get_period_hint (
@@ -352,7 +352,7 @@ void IndexHNSW::search (idx_t n, const float *x, idx_t k,
352
352
  void IndexHNSW::add(idx_t n, const float *x)
353
353
  {
354
354
  FAISS_THROW_IF_NOT_MSG(storage,
355
- "Please use IndexHSNWFlat (or variants) instead of IndexHNSW directly");
355
+ "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
356
356
  FAISS_THROW_IF_NOT(is_trained);
357
357
  int n0 = ntotal;
358
358
  storage->add(n, x);
@@ -1003,8 +1003,7 @@ int search_from_candidates_2(const HNSW & hnsw,
1003
1003
  if (nres < k) {
1004
1004
  faiss::maxheap_push (++nres, D, I, d, v1);
1005
1005
  } else if (d < D[0]) {
1006
- faiss::maxheap_pop (nres--, D, I);
1007
- faiss::maxheap_push (++nres, D, I, d, v1);
1006
+ faiss::maxheap_replace_top (nres, D, I, d, v1);
1008
1007
  }
1009
1008
  }
1010
1009
  vt.visited[v1] = vt.visno + 1;
@@ -88,12 +88,19 @@ void Level1Quantizer::train_q1 (size_t n, const float *x, bool verbose, MetricTy
88
88
  }
89
89
  quantizer->is_trained = true;
90
90
  } else if (quantizer_trains_alone == 2) {
91
- if (verbose)
91
+ if (verbose) {
92
92
  printf (
93
93
  "Training L2 quantizer on %zd vectors in %zdD%s\n",
94
94
  n, d,
95
95
  clustering_index ? "(user provided index)" : "");
96
- FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
96
+ }
97
+ // also accept spherical centroids because in that case
98
+ // L2 and IP are equivalent
99
+ FAISS_THROW_IF_NOT (
100
+ metric_type == METRIC_L2 ||
101
+ (metric_type == METRIC_INNER_PRODUCT && cp.spherical)
102
+ );
103
+
97
104
  Clustering clus (d, nlist, cp);
98
105
  if (!clustering_index) {
99
106
  IndexFlatL2 assigner (d);
@@ -263,23 +270,76 @@ void IndexIVF::set_direct_map_type (DirectMap::Type type)
263
270
  direct_map.set_type (type, invlists, ntotal);
264
271
  }
265
272
 
266
-
273
+ /** It is a sad fact of software that a conceptually simple function like this
274
+ * becomes very complex when you factor in several ways of parallelizing +
275
+ * interrupt/error handling + collecting stats + min/max collection. The
276
+ * codepath that is used 95% of the time is the one for parallel_mode = 0 */
267
277
  void IndexIVF::search (idx_t n, const float *x, idx_t k,
268
278
  float *distances, idx_t *labels) const
269
279
  {
270
- std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
271
- std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
272
280
 
273
- double t0 = getmillisecs();
274
- quantizer->search (n, x, nprobe, coarse_dis.get(), idx.get());
275
- indexIVF_stats.quantization_time += getmillisecs() - t0;
276
281
 
277
- t0 = getmillisecs();
278
- invlists->prefetch_lists (idx.get(), n * nprobe);
282
+ // search function for a subset of queries
283
+ auto sub_search_func = [this, k]
284
+ (idx_t n, const float *x, float *distances, idx_t *labels,
285
+ IndexIVFStats *ivf_stats) {
286
+
287
+ std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
288
+ std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
289
+
290
+ double t0 = getmillisecs();
291
+ quantizer->search (n, x, nprobe, coarse_dis.get(), idx.get());
292
+
293
+ double t1 = getmillisecs();
294
+ invlists->prefetch_lists (idx.get(), n * nprobe);
295
+
296
+ search_preassigned (n, x, k, idx.get(), coarse_dis.get(),
297
+ distances, labels, false, nullptr, ivf_stats);
298
+ double t2 = getmillisecs();
299
+ ivf_stats->quantization_time += t1 - t0;
300
+ ivf_stats->search_time += t2 - t0;
301
+ };
302
+
303
+
304
+ if ((parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT) == 0) {
305
+ int nt = std::min(omp_get_max_threads(), int(n));
306
+ std::vector<IndexIVFStats> stats(nt);
307
+ std::mutex exception_mutex;
308
+ std::string exception_string;
309
+
310
+ #pragma omp parallel for if (nt > 1)
311
+ for(idx_t slice = 0; slice < nt; slice++) {
312
+ IndexIVFStats local_stats;
313
+ idx_t i0 = n * slice / nt;
314
+ idx_t i1 = n * (slice + 1) / nt;
315
+ if (i1 > i0) {
316
+ try {
317
+ sub_search_func(
318
+ i1 - i0, x + i0 * d,
319
+ distances + i0 * k, labels + i0 * k,
320
+ &stats[slice]
321
+ );
322
+ } catch(const std::exception & e) {
323
+ std::lock_guard<std::mutex> lock(exception_mutex);
324
+ exception_string = e.what();
325
+ }
326
+ }
327
+ }
328
+
329
+ if (!exception_string.empty()) {
330
+ FAISS_THROW_MSG (exception_string.c_str());
331
+ }
332
+
333
+ // collect stats
334
+ for(idx_t slice = 0; slice < nt; slice++) {
335
+ indexIVF_stats.add(stats[slice]);
336
+ }
337
+ } else {
338
+ // handle parallelization at level below (or don't run in parallel at all)
339
+ sub_search_func(n, x, distances, labels, &indexIVF_stats);
340
+ }
341
+
279
342
 
280
- search_preassigned (n, x, k, idx.get(), coarse_dis.get(),
281
- distances, labels, false);
282
- indexIVF_stats.search_time += getmillisecs() - t0;
283
343
  }
284
344
 
285
345
 
@@ -288,7 +348,8 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
288
348
  const float *coarse_dis ,
289
349
  float *distances, idx_t *labels,
290
350
  bool store_pairs,
291
- const IVFSearchParameters *params) const
351
+ const IVFSearchParameters *params,
352
+ IndexIVFStats *ivf_stats) const
292
353
  {
293
354
  long nprobe = params ? params->nprobe : this->nprobe;
294
355
  long max_codes = params ? params->max_codes : this->max_codes;
@@ -305,13 +366,12 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
305
366
  int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
306
367
  bool do_heap_init = !(this->parallel_mode & PARALLEL_MODE_NO_HEAP_INIT);
307
368
 
308
- // don't start parallel section if single query
309
369
  bool do_parallel = omp_get_max_threads() >= 2 && (
310
- pmode == 0 ? n > 1 :
370
+ pmode == 0 ? false :
371
+ pmode == 3 ? n > 1 :
311
372
  pmode == 1 ? nprobe > 1 :
312
373
  nprobe * n > 1);
313
374
 
314
-
315
375
  #pragma omp parallel if(do_parallel) reduction(+: nlistv, ndis, nheap)
316
376
  {
317
377
  InvertedListScanner *scanner = get_InvertedListScanner(store_pairs);
@@ -409,7 +469,7 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
409
469
  * Actual loops, depending on parallel_mode
410
470
  ****************************************************/
411
471
 
412
- if (pmode == 0) {
472
+ if (pmode == 0 || pmode == 3) {
413
473
 
414
474
  #pragma omp for
415
475
  for (idx_t i = 0; i < n; i++) {
@@ -527,11 +587,12 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
527
587
  }
528
588
  }
529
589
 
530
- indexIVF_stats.nq += n;
531
- indexIVF_stats.nlist += nlistv;
532
- indexIVF_stats.ndis += ndis;
533
- indexIVF_stats.nheap_updates += nheap;
534
-
590
+ if (ivf_stats) {
591
+ ivf_stats->nq += n;
592
+ ivf_stats->nlist += nlistv;
593
+ ivf_stats->ndis += ndis;
594
+ ivf_stats->nheap_updates += nheap;
595
+ }
535
596
  }
536
597
 
537
598
 
@@ -551,7 +612,7 @@ void IndexIVF::range_search (idx_t nx, const float *x, float radius,
551
612
  invlists->prefetch_lists (keys.get(), nx * nprobe);
552
613
 
553
614
  range_search_preassigned (nx, x, radius, keys.get (), coarse_dis.get (),
554
- result);
615
+ result, false, nullptr, &indexIVF_stats);
555
616
 
556
617
  indexIVF_stats.search_time += getmillisecs() - t0;
557
618
  }
@@ -561,7 +622,8 @@ void IndexIVF::range_search_preassigned (
561
622
  const idx_t *keys, const float *coarse_dis,
562
623
  RangeSearchResult *result,
563
624
  bool store_pairs,
564
- const IVFSearchParameters *params) const
625
+ const IVFSearchParameters *params,
626
+ IndexIVFStats *stats) const
565
627
  {
566
628
  long nprobe = params ? params->nprobe : this->nprobe;
567
629
  long max_codes = params ? params->max_codes : this->max_codes;
@@ -574,7 +636,15 @@ void IndexIVF::range_search_preassigned (
574
636
 
575
637
  std::vector<RangeSearchPartialResult *> all_pres (omp_get_max_threads());
576
638
 
577
- #pragma omp parallel reduction(+: nlistv, ndis)
639
+ int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
640
+ // don't start parallel section if single query
641
+ bool do_parallel = omp_get_max_threads() >= 2 && (
642
+ pmode == 3 ? false :
643
+ pmode == 0 ? nx > 1 :
644
+ pmode == 1 ? nprobe > 1 :
645
+ nprobe * nx > 1);
646
+
647
+ #pragma omp parallel if(do_parallel) reduction(+: nlistv, ndis)
578
648
  {
579
649
  RangeSearchPartialResult pres(result);
580
650
  std::unique_ptr<InvertedListScanner> scanner
@@ -680,9 +750,11 @@ void IndexIVF::range_search_preassigned (
680
750
  }
681
751
  }
682
752
 
683
- indexIVF_stats.nq += nx;
684
- indexIVF_stats.nlist += nlistv;
685
- indexIVF_stats.ndis += ndis;
753
+ if (stats) {
754
+ stats->nq += nx;
755
+ stats->nlist += nlistv;
756
+ stats->ndis += ndis;
757
+ }
686
758
  }
687
759
 
688
760
 
@@ -879,11 +951,15 @@ void IndexIVF::replace_invlists (InvertedLists *il, bool own)
879
951
  {
880
952
  if (own_invlists) {
881
953
  delete invlists;
954
+ invlists = nullptr;
882
955
  }
883
956
  // FAISS_THROW_IF_NOT (ntotal == 0);
884
957
  if (il) {
885
- FAISS_THROW_IF_NOT (il->nlist == nlist &&
886
- il->code_size == code_size);
958
+ FAISS_THROW_IF_NOT (il->nlist == nlist);
959
+ FAISS_THROW_IF_NOT (
960
+ il->code_size == code_size ||
961
+ il->code_size == InvertedLists::INVALID_CODE_SIZE
962
+ );
887
963
  }
888
964
  invlists = il;
889
965
  own_invlists = own;
@@ -971,6 +1047,17 @@ void IndexIVFStats::reset()
971
1047
  memset ((void*)this, 0, sizeof (*this));
972
1048
  }
973
1049
 
1050
+ void IndexIVFStats::add (const IndexIVFStats & other)
1051
+ {
1052
+ nq += other.nq;
1053
+ nlist += other.nlist;
1054
+ ndis += other.ndis;
1055
+ nheap_updates += other.nheap_updates;
1056
+ quantization_time += other.quantization_time;
1057
+ search_time += other.search_time;
1058
+
1059
+ }
1060
+
974
1061
 
975
1062
  IndexIVFStats indexIVF_stats;
976
1063