faiss 0.1.3 → 0.1.4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (184) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +1 -1
  6. data/lib/faiss/version.rb +1 -1
  7. data/vendor/faiss/faiss/AutoTune.cpp +36 -33
  8. data/vendor/faiss/faiss/AutoTune.h +6 -3
  9. data/vendor/faiss/faiss/Clustering.cpp +16 -12
  10. data/vendor/faiss/faiss/Index.cpp +3 -4
  11. data/vendor/faiss/faiss/Index.h +3 -3
  12. data/vendor/faiss/faiss/IndexBinary.cpp +3 -4
  13. data/vendor/faiss/faiss/IndexBinary.h +1 -1
  14. data/vendor/faiss/faiss/IndexBinaryHash.cpp +2 -12
  15. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -2
  16. data/vendor/faiss/faiss/IndexFlat.cpp +0 -148
  17. data/vendor/faiss/faiss/IndexFlat.h +0 -51
  18. data/vendor/faiss/faiss/IndexHNSW.cpp +4 -5
  19. data/vendor/faiss/faiss/IndexIVF.cpp +118 -31
  20. data/vendor/faiss/faiss/IndexIVF.h +22 -15
  21. data/vendor/faiss/faiss/IndexIVFFlat.cpp +3 -3
  22. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
  23. data/vendor/faiss/faiss/IndexIVFPQ.cpp +39 -15
  24. data/vendor/faiss/faiss/IndexIVFPQ.h +25 -9
  25. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +1116 -0
  26. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +166 -0
  27. data/vendor/faiss/faiss/IndexIVFPQR.cpp +8 -9
  28. data/vendor/faiss/faiss/IndexIVFPQR.h +2 -1
  29. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +1 -2
  30. data/vendor/faiss/faiss/IndexPQ.cpp +34 -18
  31. data/vendor/faiss/faiss/IndexPQFastScan.cpp +536 -0
  32. data/vendor/faiss/faiss/IndexPQFastScan.h +111 -0
  33. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -0
  34. data/vendor/faiss/faiss/IndexPreTransform.h +2 -0
  35. data/vendor/faiss/faiss/IndexRefine.cpp +256 -0
  36. data/vendor/faiss/faiss/IndexRefine.h +73 -0
  37. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -2
  38. data/vendor/faiss/faiss/IndexScalarQuantizer.h +1 -1
  39. data/vendor/faiss/faiss/gpu/GpuDistance.h +1 -1
  40. data/vendor/faiss/faiss/gpu/GpuIndex.h +16 -9
  41. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +8 -1
  42. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -11
  43. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +19 -2
  44. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +28 -2
  45. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +24 -14
  46. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +29 -2
  47. data/vendor/faiss/faiss/gpu/GpuResources.h +4 -0
  48. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +60 -27
  49. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +28 -6
  50. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +547 -0
  51. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +51 -0
  52. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +3 -3
  53. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +3 -2
  54. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +274 -0
  55. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +7 -2
  56. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +5 -1
  57. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +231 -0
  58. data/vendor/faiss/faiss/gpu/test/TestUtils.h +33 -0
  59. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +1 -0
  60. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +6 -0
  61. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +5 -6
  62. data/vendor/faiss/faiss/gpu/utils/Timer.h +2 -2
  63. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +5 -4
  64. data/vendor/faiss/faiss/impl/HNSW.cpp +2 -4
  65. data/vendor/faiss/faiss/impl/PolysemousTraining.h +4 -4
  66. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +22 -12
  67. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -0
  68. data/vendor/faiss/faiss/impl/ResultHandler.h +452 -0
  69. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +29 -19
  70. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +6 -0
  71. data/vendor/faiss/faiss/impl/index_read.cpp +64 -96
  72. data/vendor/faiss/faiss/impl/index_write.cpp +34 -25
  73. data/vendor/faiss/faiss/impl/io.cpp +33 -2
  74. data/vendor/faiss/faiss/impl/io.h +7 -2
  75. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -15
  76. data/vendor/faiss/faiss/impl/platform_macros.h +44 -0
  77. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +272 -0
  78. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +169 -0
  79. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +180 -0
  80. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +354 -0
  81. data/vendor/faiss/faiss/impl/simd_result_handlers.h +559 -0
  82. data/vendor/faiss/faiss/index_factory.cpp +112 -7
  83. data/vendor/faiss/faiss/index_io.h +1 -48
  84. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +151 -0
  85. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +76 -0
  86. data/vendor/faiss/faiss/{DirectMap.cpp → invlists/DirectMap.cpp} +1 -1
  87. data/vendor/faiss/faiss/{DirectMap.h → invlists/DirectMap.h} +1 -1
  88. data/vendor/faiss/faiss/{InvertedLists.cpp → invlists/InvertedLists.cpp} +72 -1
  89. data/vendor/faiss/faiss/{InvertedLists.h → invlists/InvertedLists.h} +32 -1
  90. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +107 -0
  91. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +63 -0
  92. data/vendor/faiss/faiss/{OnDiskInvertedLists.cpp → invlists/OnDiskInvertedLists.cpp} +21 -6
  93. data/vendor/faiss/faiss/{OnDiskInvertedLists.h → invlists/OnDiskInvertedLists.h} +5 -2
  94. data/vendor/faiss/faiss/python/python_callbacks.h +8 -1
  95. data/vendor/faiss/faiss/utils/AlignedTable.h +141 -0
  96. data/vendor/faiss/faiss/utils/Heap.cpp +2 -4
  97. data/vendor/faiss/faiss/utils/Heap.h +61 -50
  98. data/vendor/faiss/faiss/utils/distances.cpp +164 -319
  99. data/vendor/faiss/faiss/utils/distances.h +28 -20
  100. data/vendor/faiss/faiss/utils/distances_simd.cpp +277 -49
  101. data/vendor/faiss/faiss/utils/extra_distances.cpp +1 -2
  102. data/vendor/faiss/faiss/utils/hamming-inl.h +4 -4
  103. data/vendor/faiss/faiss/utils/hamming.cpp +3 -6
  104. data/vendor/faiss/faiss/utils/hamming.h +2 -7
  105. data/vendor/faiss/faiss/utils/ordered_key_value.h +98 -0
  106. data/vendor/faiss/faiss/utils/partitioning.cpp +1256 -0
  107. data/vendor/faiss/faiss/utils/partitioning.h +69 -0
  108. data/vendor/faiss/faiss/utils/quantize_lut.cpp +277 -0
  109. data/vendor/faiss/faiss/utils/quantize_lut.h +80 -0
  110. data/vendor/faiss/faiss/utils/simdlib.h +31 -0
  111. data/vendor/faiss/faiss/utils/simdlib_avx2.h +461 -0
  112. data/vendor/faiss/faiss/utils/simdlib_emulated.h +589 -0
  113. metadata +43 -141
  114. data/vendor/faiss/benchs/bench_6bit_codec.cpp +0 -80
  115. data/vendor/faiss/c_api/AutoTune_c.cpp +0 -83
  116. data/vendor/faiss/c_api/AutoTune_c.h +0 -66
  117. data/vendor/faiss/c_api/Clustering_c.cpp +0 -145
  118. data/vendor/faiss/c_api/Clustering_c.h +0 -123
  119. data/vendor/faiss/c_api/IndexFlat_c.cpp +0 -140
  120. data/vendor/faiss/c_api/IndexFlat_c.h +0 -115
  121. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +0 -64
  122. data/vendor/faiss/c_api/IndexIVFFlat_c.h +0 -58
  123. data/vendor/faiss/c_api/IndexIVF_c.cpp +0 -99
  124. data/vendor/faiss/c_api/IndexIVF_c.h +0 -142
  125. data/vendor/faiss/c_api/IndexLSH_c.cpp +0 -37
  126. data/vendor/faiss/c_api/IndexLSH_c.h +0 -40
  127. data/vendor/faiss/c_api/IndexPreTransform_c.cpp +0 -21
  128. data/vendor/faiss/c_api/IndexPreTransform_c.h +0 -32
  129. data/vendor/faiss/c_api/IndexShards_c.cpp +0 -38
  130. data/vendor/faiss/c_api/IndexShards_c.h +0 -39
  131. data/vendor/faiss/c_api/Index_c.cpp +0 -105
  132. data/vendor/faiss/c_api/Index_c.h +0 -183
  133. data/vendor/faiss/c_api/MetaIndexes_c.cpp +0 -49
  134. data/vendor/faiss/c_api/MetaIndexes_c.h +0 -49
  135. data/vendor/faiss/c_api/clone_index_c.cpp +0 -23
  136. data/vendor/faiss/c_api/clone_index_c.h +0 -32
  137. data/vendor/faiss/c_api/error_c.h +0 -42
  138. data/vendor/faiss/c_api/error_impl.cpp +0 -27
  139. data/vendor/faiss/c_api/error_impl.h +0 -16
  140. data/vendor/faiss/c_api/faiss_c.h +0 -58
  141. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +0 -98
  142. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +0 -56
  143. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +0 -52
  144. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +0 -68
  145. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +0 -17
  146. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +0 -30
  147. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +0 -38
  148. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +0 -86
  149. data/vendor/faiss/c_api/gpu/GpuResources_c.h +0 -66
  150. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +0 -54
  151. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +0 -53
  152. data/vendor/faiss/c_api/gpu/macros_impl.h +0 -42
  153. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +0 -220
  154. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +0 -149
  155. data/vendor/faiss/c_api/index_factory_c.cpp +0 -26
  156. data/vendor/faiss/c_api/index_factory_c.h +0 -30
  157. data/vendor/faiss/c_api/index_io_c.cpp +0 -42
  158. data/vendor/faiss/c_api/index_io_c.h +0 -50
  159. data/vendor/faiss/c_api/macros_impl.h +0 -110
  160. data/vendor/faiss/demos/demo_imi_flat.cpp +0 -154
  161. data/vendor/faiss/demos/demo_imi_pq.cpp +0 -203
  162. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +0 -151
  163. data/vendor/faiss/demos/demo_sift1M.cpp +0 -252
  164. data/vendor/faiss/demos/demo_weighted_kmeans.cpp +0 -185
  165. data/vendor/faiss/misc/test_blas.cpp +0 -87
  166. data/vendor/faiss/tests/test_binary_flat.cpp +0 -62
  167. data/vendor/faiss/tests/test_dealloc_invlists.cpp +0 -188
  168. data/vendor/faiss/tests/test_ivfpq_codec.cpp +0 -70
  169. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +0 -100
  170. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +0 -573
  171. data/vendor/faiss/tests/test_merge.cpp +0 -260
  172. data/vendor/faiss/tests/test_omp_threads.cpp +0 -14
  173. data/vendor/faiss/tests/test_ondisk_ivf.cpp +0 -225
  174. data/vendor/faiss/tests/test_pairs_decoding.cpp +0 -193
  175. data/vendor/faiss/tests/test_params_override.cpp +0 -236
  176. data/vendor/faiss/tests/test_pq_encoding.cpp +0 -98
  177. data/vendor/faiss/tests/test_sliding_ivf.cpp +0 -246
  178. data/vendor/faiss/tests/test_threaded_index.cpp +0 -253
  179. data/vendor/faiss/tests/test_transfer_invlists.cpp +0 -159
  180. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +0 -104
  181. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +0 -85
  182. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +0 -98
  183. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +0 -122
  184. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +0 -104
@@ -93,57 +93,6 @@ struct IndexFlatL2:IndexFlat {
93
93
  };
94
94
 
95
95
 
96
- // same as an IndexFlatL2 but a value is subtracted from each distance
97
- struct IndexFlatL2BaseShift: IndexFlatL2 {
98
- std::vector<float> shift;
99
-
100
- IndexFlatL2BaseShift (idx_t d, size_t nshift, const float *shift);
101
-
102
- void search(
103
- idx_t n,
104
- const float* x,
105
- idx_t k,
106
- float* distances,
107
- idx_t* labels) const override;
108
- };
109
-
110
-
111
- /** Index that queries in a base_index (a fast one) and refines the
112
- * results with an exact search, hopefully improving the results.
113
- */
114
- struct IndexRefineFlat: Index {
115
-
116
- /// storage for full vectors
117
- IndexFlat refine_index;
118
-
119
- /// faster index to pre-select the vectors that should be filtered
120
- Index *base_index;
121
- bool own_fields; ///< should the base index be deallocated?
122
-
123
- /// factor between k requested in search and the k requested from
124
- /// the base_index (should be >= 1)
125
- float k_factor;
126
-
127
- explicit IndexRefineFlat (Index *base_index);
128
-
129
- IndexRefineFlat ();
130
-
131
- void train(idx_t n, const float* x) override;
132
-
133
- void add(idx_t n, const float* x) override;
134
-
135
- void reset() override;
136
-
137
- void search(
138
- idx_t n,
139
- const float* x,
140
- idx_t k,
141
- float* distances,
142
- idx_t* labels) const override;
143
-
144
- ~IndexRefineFlat() override;
145
- };
146
-
147
96
 
148
97
  /// optimized version for 1D "vectors".
149
98
  struct IndexFlat1D:IndexFlatL2 {
@@ -277,7 +277,7 @@ IndexHNSW::~IndexHNSW() {
277
277
  void IndexHNSW::train(idx_t n, const float* x)
278
278
  {
279
279
  FAISS_THROW_IF_NOT_MSG(storage,
280
- "Please use IndexHSNWFlat (or variants) instead of IndexHNSW directly");
280
+ "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
281
281
  // hnsw structure does not require training
282
282
  storage->train (n, x);
283
283
  is_trained = true;
@@ -288,7 +288,7 @@ void IndexHNSW::search (idx_t n, const float *x, idx_t k,
288
288
 
289
289
  {
290
290
  FAISS_THROW_IF_NOT_MSG(storage,
291
- "Please use IndexHSNWFlat (or variants) instead of IndexHNSW directly");
291
+ "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
292
292
  size_t n1 = 0, n2 = 0, n3 = 0, ndis = 0, nreorder = 0;
293
293
 
294
294
  idx_t check_period = InterruptCallback::get_period_hint (
@@ -352,7 +352,7 @@ void IndexHNSW::search (idx_t n, const float *x, idx_t k,
352
352
  void IndexHNSW::add(idx_t n, const float *x)
353
353
  {
354
354
  FAISS_THROW_IF_NOT_MSG(storage,
355
- "Please use IndexHSNWFlat (or variants) instead of IndexHNSW directly");
355
+ "Please use IndexHNSWFlat (or variants) instead of IndexHNSW directly");
356
356
  FAISS_THROW_IF_NOT(is_trained);
357
357
  int n0 = ntotal;
358
358
  storage->add(n, x);
@@ -1003,8 +1003,7 @@ int search_from_candidates_2(const HNSW & hnsw,
1003
1003
  if (nres < k) {
1004
1004
  faiss::maxheap_push (++nres, D, I, d, v1);
1005
1005
  } else if (d < D[0]) {
1006
- faiss::maxheap_pop (nres--, D, I);
1007
- faiss::maxheap_push (++nres, D, I, d, v1);
1006
+ faiss::maxheap_replace_top (nres, D, I, d, v1);
1008
1007
  }
1009
1008
  }
1010
1009
  vt.visited[v1] = vt.visno + 1;
@@ -88,12 +88,19 @@ void Level1Quantizer::train_q1 (size_t n, const float *x, bool verbose, MetricTy
88
88
  }
89
89
  quantizer->is_trained = true;
90
90
  } else if (quantizer_trains_alone == 2) {
91
- if (verbose)
91
+ if (verbose) {
92
92
  printf (
93
93
  "Training L2 quantizer on %zd vectors in %zdD%s\n",
94
94
  n, d,
95
95
  clustering_index ? "(user provided index)" : "");
96
- FAISS_THROW_IF_NOT (metric_type == METRIC_L2);
96
+ }
97
+ // also accept spherical centroids because in that case
98
+ // L2 and IP are equivalent
99
+ FAISS_THROW_IF_NOT (
100
+ metric_type == METRIC_L2 ||
101
+ (metric_type == METRIC_INNER_PRODUCT && cp.spherical)
102
+ );
103
+
97
104
  Clustering clus (d, nlist, cp);
98
105
  if (!clustering_index) {
99
106
  IndexFlatL2 assigner (d);
@@ -263,23 +270,76 @@ void IndexIVF::set_direct_map_type (DirectMap::Type type)
263
270
  direct_map.set_type (type, invlists, ntotal);
264
271
  }
265
272
 
266
-
273
+ /** It is a sad fact of software that a conceptually simple function like this
274
+ * becomes very complex when you factor in several ways of parallelizing +
275
+ * interrupt/error handling + collecting stats + min/max collection. The
276
+ * codepath that is used 95% of time is the one for parallel_mode = 0 */
267
277
  void IndexIVF::search (idx_t n, const float *x, idx_t k,
268
278
  float *distances, idx_t *labels) const
269
279
  {
270
- std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
271
- std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
272
280
 
273
- double t0 = getmillisecs();
274
- quantizer->search (n, x, nprobe, coarse_dis.get(), idx.get());
275
- indexIVF_stats.quantization_time += getmillisecs() - t0;
276
281
 
277
- t0 = getmillisecs();
278
- invlists->prefetch_lists (idx.get(), n * nprobe);
282
+ // search function for a subset of queries
283
+ auto sub_search_func = [this, k]
284
+ (idx_t n, const float *x, float *distances, idx_t *labels,
285
+ IndexIVFStats *ivf_stats) {
286
+
287
+ std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
288
+ std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
289
+
290
+ double t0 = getmillisecs();
291
+ quantizer->search (n, x, nprobe, coarse_dis.get(), idx.get());
292
+
293
+ double t1 = getmillisecs();
294
+ invlists->prefetch_lists (idx.get(), n * nprobe);
295
+
296
+ search_preassigned (n, x, k, idx.get(), coarse_dis.get(),
297
+ distances, labels, false, nullptr, ivf_stats);
298
+ double t2 = getmillisecs();
299
+ ivf_stats->quantization_time += t1 - t0;
300
+ ivf_stats->search_time += t2 - t0;
301
+ };
302
+
303
+
304
+ if ((parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT) == 0) {
305
+ int nt = std::min(omp_get_max_threads(), int(n));
306
+ std::vector<IndexIVFStats> stats(nt);
307
+ std::mutex exception_mutex;
308
+ std::string exception_string;
309
+
310
+ #pragma omp parallel for if (nt > 1)
311
+ for(idx_t slice = 0; slice < nt; slice++) {
312
+ IndexIVFStats local_stats;
313
+ idx_t i0 = n * slice / nt;
314
+ idx_t i1 = n * (slice + 1) / nt;
315
+ if (i1 > i0) {
316
+ try {
317
+ sub_search_func(
318
+ i1 - i0, x + i0 * d,
319
+ distances + i0 * k, labels + i0 * k,
320
+ &stats[slice]
321
+ );
322
+ } catch(const std::exception & e) {
323
+ std::lock_guard<std::mutex> lock(exception_mutex);
324
+ exception_string = e.what();
325
+ }
326
+ }
327
+ }
328
+
329
+ if (!exception_string.empty()) {
330
+ FAISS_THROW_MSG (exception_string.c_str());
331
+ }
332
+
333
+ // collect stats
334
+ for(idx_t slice = 0; slice < nt; slice++) {
335
+ indexIVF_stats.add(stats[slice]);
336
+ }
337
+ } else {
338
+ // handle paralellization at level below (or don't run in parallel at all)
339
+ sub_search_func(n, x, distances, labels, &indexIVF_stats);
340
+ }
341
+
279
342
 
280
- search_preassigned (n, x, k, idx.get(), coarse_dis.get(),
281
- distances, labels, false);
282
- indexIVF_stats.search_time += getmillisecs() - t0;
283
343
  }
284
344
 
285
345
 
@@ -288,7 +348,8 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
288
348
  const float *coarse_dis ,
289
349
  float *distances, idx_t *labels,
290
350
  bool store_pairs,
291
- const IVFSearchParameters *params) const
351
+ const IVFSearchParameters *params,
352
+ IndexIVFStats *ivf_stats) const
292
353
  {
293
354
  long nprobe = params ? params->nprobe : this->nprobe;
294
355
  long max_codes = params ? params->max_codes : this->max_codes;
@@ -305,13 +366,12 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
305
366
  int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
306
367
  bool do_heap_init = !(this->parallel_mode & PARALLEL_MODE_NO_HEAP_INIT);
307
368
 
308
- // don't start parallel section if single query
309
369
  bool do_parallel = omp_get_max_threads() >= 2 && (
310
- pmode == 0 ? n > 1 :
370
+ pmode == 0 ? false :
371
+ pmode == 3 ? n > 1 :
311
372
  pmode == 1 ? nprobe > 1 :
312
373
  nprobe * n > 1);
313
374
 
314
-
315
375
  #pragma omp parallel if(do_parallel) reduction(+: nlistv, ndis, nheap)
316
376
  {
317
377
  InvertedListScanner *scanner = get_InvertedListScanner(store_pairs);
@@ -409,7 +469,7 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
409
469
  * Actual loops, depending on parallel_mode
410
470
  ****************************************************/
411
471
 
412
- if (pmode == 0) {
472
+ if (pmode == 0 || pmode == 3) {
413
473
 
414
474
  #pragma omp for
415
475
  for (idx_t i = 0; i < n; i++) {
@@ -527,11 +587,12 @@ void IndexIVF::search_preassigned (idx_t n, const float *x, idx_t k,
527
587
  }
528
588
  }
529
589
 
530
- indexIVF_stats.nq += n;
531
- indexIVF_stats.nlist += nlistv;
532
- indexIVF_stats.ndis += ndis;
533
- indexIVF_stats.nheap_updates += nheap;
534
-
590
+ if (ivf_stats) {
591
+ ivf_stats->nq += n;
592
+ ivf_stats->nlist += nlistv;
593
+ ivf_stats->ndis += ndis;
594
+ ivf_stats->nheap_updates += nheap;
595
+ }
535
596
  }
536
597
 
537
598
 
@@ -551,7 +612,7 @@ void IndexIVF::range_search (idx_t nx, const float *x, float radius,
551
612
  invlists->prefetch_lists (keys.get(), nx * nprobe);
552
613
 
553
614
  range_search_preassigned (nx, x, radius, keys.get (), coarse_dis.get (),
554
- result);
615
+ result, false, nullptr, &indexIVF_stats);
555
616
 
556
617
  indexIVF_stats.search_time += getmillisecs() - t0;
557
618
  }
@@ -561,7 +622,8 @@ void IndexIVF::range_search_preassigned (
561
622
  const idx_t *keys, const float *coarse_dis,
562
623
  RangeSearchResult *result,
563
624
  bool store_pairs,
564
- const IVFSearchParameters *params) const
625
+ const IVFSearchParameters *params,
626
+ IndexIVFStats *stats) const
565
627
  {
566
628
  long nprobe = params ? params->nprobe : this->nprobe;
567
629
  long max_codes = params ? params->max_codes : this->max_codes;
@@ -574,7 +636,15 @@ void IndexIVF::range_search_preassigned (
574
636
 
575
637
  std::vector<RangeSearchPartialResult *> all_pres (omp_get_max_threads());
576
638
 
577
- #pragma omp parallel reduction(+: nlistv, ndis)
639
+ int pmode = this->parallel_mode & ~PARALLEL_MODE_NO_HEAP_INIT;
640
+ // don't start parallel section if single query
641
+ bool do_parallel = omp_get_max_threads() >= 2 && (
642
+ pmode == 3 ? false :
643
+ pmode == 0 ? nx > 1 :
644
+ pmode == 1 ? nprobe > 1 :
645
+ nprobe * nx > 1);
646
+
647
+ #pragma omp parallel if(do_parallel) reduction(+: nlistv, ndis)
578
648
  {
579
649
  RangeSearchPartialResult pres(result);
580
650
  std::unique_ptr<InvertedListScanner> scanner
@@ -680,9 +750,11 @@ void IndexIVF::range_search_preassigned (
680
750
  }
681
751
  }
682
752
 
683
- indexIVF_stats.nq += nx;
684
- indexIVF_stats.nlist += nlistv;
685
- indexIVF_stats.ndis += ndis;
753
+ if (stats) {
754
+ stats->nq += nx;
755
+ stats->nlist += nlistv;
756
+ stats->ndis += ndis;
757
+ }
686
758
  }
687
759
 
688
760
 
@@ -879,11 +951,15 @@ void IndexIVF::replace_invlists (InvertedLists *il, bool own)
879
951
  {
880
952
  if (own_invlists) {
881
953
  delete invlists;
954
+ invlists = nullptr;
882
955
  }
883
956
  // FAISS_THROW_IF_NOT (ntotal == 0);
884
957
  if (il) {
885
- FAISS_THROW_IF_NOT (il->nlist == nlist &&
886
- il->code_size == code_size);
958
+ FAISS_THROW_IF_NOT (il->nlist == nlist);
959
+ FAISS_THROW_IF_NOT (
960
+ il->code_size == code_size ||
961
+ il->code_size == InvertedLists::INVALID_CODE_SIZE
962
+ );
887
963
  }
888
964
  invlists = il;
889
965
  own_invlists = own;
@@ -971,6 +1047,17 @@ void IndexIVFStats::reset()
971
1047
  memset ((void*)this, 0, sizeof (*this));
972
1048
  }
973
1049
 
1050
+ void IndexIVFStats::add (const IndexIVFStats & other)
1051
+ {
1052
+ nq += other.nq;
1053
+ nlist += other.nlist;
1054
+ ndis += other.ndis;
1055
+ nheap_updates += other.nheap_updates;
1056
+ quantization_time += other.quantization_time;
1057
+ search_time += other.search_time;
1058
+
1059
+ }
1060
+
974
1061
 
975
1062
  IndexIVFStats indexIVF_stats;
976
1063
 
@@ -16,8 +16,8 @@
16
16
  #include <stdint.h>
17
17
 
18
18
  #include <faiss/Index.h>
19
- #include <faiss/InvertedLists.h>
20
- #include <faiss/DirectMap.h>
19
+ #include <faiss/invlists/InvertedLists.h>
20
+ #include <faiss/invlists/DirectMap.h>
21
21
  #include <faiss/Clustering.h>
22
22
  #include <faiss/impl/platform_macros.h>
23
23
  #include <faiss/utils/Heap.h>
@@ -76,6 +76,7 @@ struct IVFSearchParameters {
76
76
 
77
77
 
78
78
  struct InvertedListScanner;
79
+ struct IndexIVFStats;
79
80
 
80
81
  /** Index based on a inverted file (IVF)
81
82
  *
@@ -109,9 +110,10 @@ struct IndexIVF: Index, Level1Quantizer {
109
110
 
110
111
  /** Parallel mode determines how queries are parallelized with OpenMP
111
112
  *
112
- * 0 (default): parallelize over queries
113
+ * 0 (default): split over queries
113
114
  * 1: parallelize over inverted lists
114
115
  * 2: parallelize over both
116
+ * 3: split over queries with a finer granularity
115
117
  *
116
118
  * PARALLEL_MODE_NO_HEAP_INIT: binary or with the previous to
117
119
  * prevent the heap to be initialized and finalized
@@ -178,14 +180,16 @@ struct IndexIVF: Index, Level1Quantizer {
178
180
  * instead in upper/lower 32 bit of result,
179
181
  * instead of ids (used for reranking).
180
182
  * @param params used to override the object's search parameters
183
+ * @param stats search stats to be updated (can be null)
181
184
  */
182
- virtual void search_preassigned (idx_t n, const float *x, idx_t k,
183
- const idx_t *assign,
184
- const float *centroid_dis,
185
- float *distances, idx_t *labels,
186
- bool store_pairs,
187
- const IVFSearchParameters *params=nullptr
188
- ) const;
185
+ virtual void search_preassigned (
186
+ idx_t n, const float *x, idx_t k,
187
+ const idx_t *assign, const float *centroid_dis,
188
+ float *distances, idx_t *labels,
189
+ bool store_pairs,
190
+ const IVFSearchParameters *params=nullptr,
191
+ IndexIVFStats *stats=nullptr
192
+ ) const;
189
193
 
190
194
  /** assign the vectors, then call search_preassign */
191
195
  void search (idx_t n, const float *x, idx_t k,
@@ -194,11 +198,13 @@ struct IndexIVF: Index, Level1Quantizer {
194
198
  void range_search (idx_t n, const float* x, float radius,
195
199
  RangeSearchResult* result) const override;
196
200
 
197
- void range_search_preassigned(idx_t nx, const float *x, float radius,
198
- const idx_t *keys, const float *coarse_dis,
199
- RangeSearchResult *result,
200
- bool store_pairs=false,
201
- const IVFSearchParameters *params=nullptr) const;
201
+ void range_search_preassigned(
202
+ idx_t nx, const float *x, float radius,
203
+ const idx_t *keys, const float *coarse_dis,
204
+ RangeSearchResult *result,
205
+ bool store_pairs=false,
206
+ const IVFSearchParameters *params=nullptr,
207
+ IndexIVFStats *stats=nullptr) const;
202
208
 
203
209
  /// get a scanner for this index (store_pairs means ignore labels)
204
210
  virtual InvertedListScanner *get_InvertedListScanner (
@@ -364,6 +370,7 @@ struct IndexIVFStats {
364
370
 
365
371
  IndexIVFStats () {reset (); }
366
372
  void reset ();
373
+ void add (const IndexIVFStats & other);
367
374
  };
368
375
 
369
376
  // global var that collects them all
@@ -159,9 +159,8 @@ struct IVFFlatScanner: InvertedListScanner {
159
159
  float dis = metric == METRIC_INNER_PRODUCT ?
160
160
  fvec_inner_product (xi, yj, d) : fvec_L2sqr (xi, yj, d);
161
161
  if (C::cmp (simi[0], dis)) {
162
- heap_pop<C> (k, simi, idxi);
163
162
  int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
164
- heap_push<C> (k, simi, idxi, dis, id);
163
+ heap_replace_top<C> (k, simi, idxi, dis, id);
165
164
  nup++;
166
165
  }
167
166
  }
@@ -317,7 +316,8 @@ void IndexIVFFlatDedup::search_preassigned (
317
316
  const float *centroid_dis,
318
317
  float *distances, idx_t *labels,
319
318
  bool store_pairs,
320
- const IVFSearchParameters *params) const
319
+ const IVFSearchParameters *params,
320
+ IndexIVFStats *stats) const
321
321
  {
322
322
  FAISS_THROW_IF_NOT_MSG (
323
323
  !store_pairs, "store_pairs not supported in IVFDedup");
@@ -77,7 +77,8 @@ struct IndexIVFFlatDedup: IndexIVFFlat {
77
77
  const float *centroid_dis,
78
78
  float *distances, idx_t *labels,
79
79
  bool store_pairs,
80
- const IVFSearchParameters *params=nullptr
80
+ const IVFSearchParameters *params=nullptr,
81
+ IndexIVFStats *stats=nullptr
81
82
  ) const override;
82
83
 
83
84
  size_t remove_ids(const IDSelector& sel) override;