faiss 0.3.1 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.h +1 -1
  5. data/vendor/faiss/faiss/Clustering.cpp +35 -4
  6. data/vendor/faiss/faiss/Clustering.h +10 -1
  7. data/vendor/faiss/faiss/IVFlib.cpp +4 -1
  8. data/vendor/faiss/faiss/Index.h +21 -6
  9. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  10. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -1
  11. data/vendor/faiss/faiss/IndexFastScan.cpp +22 -4
  12. data/vendor/faiss/faiss/IndexFlat.cpp +11 -7
  13. data/vendor/faiss/faiss/IndexFlatCodes.cpp +159 -5
  14. data/vendor/faiss/faiss/IndexFlatCodes.h +20 -3
  15. data/vendor/faiss/faiss/IndexHNSW.cpp +143 -90
  16. data/vendor/faiss/faiss/IndexHNSW.h +52 -3
  17. data/vendor/faiss/faiss/IndexIVF.cpp +3 -3
  18. data/vendor/faiss/faiss/IndexIVF.h +9 -1
  19. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +15 -0
  20. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -0
  21. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +130 -57
  22. data/vendor/faiss/faiss/IndexIVFFastScan.h +14 -7
  23. data/vendor/faiss/faiss/IndexIVFPQ.cpp +1 -3
  24. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +21 -2
  25. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  26. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  27. data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -29
  28. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  29. data/vendor/faiss/faiss/IndexNSG.h +1 -1
  30. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  31. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  32. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  33. data/vendor/faiss/faiss/IndexRefine.cpp +5 -5
  34. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +3 -1
  35. data/vendor/faiss/faiss/MetricType.h +7 -2
  36. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  37. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  38. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  39. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  40. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +36 -4
  41. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +6 -0
  42. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  43. data/vendor/faiss/faiss/gpu/GpuIndex.h +2 -8
  44. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  45. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +6 -0
  46. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +2 -0
  47. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +25 -0
  48. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  49. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +6 -0
  50. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  51. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +65 -0
  52. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  53. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  54. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  55. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  56. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +25 -0
  57. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +9 -1
  58. data/vendor/faiss/faiss/impl/DistanceComputer.h +46 -0
  59. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  60. data/vendor/faiss/faiss/impl/HNSW.cpp +358 -190
  61. data/vendor/faiss/faiss/impl/HNSW.h +43 -22
  62. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +8 -8
  63. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  64. data/vendor/faiss/faiss/impl/NNDescent.cpp +13 -8
  65. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  66. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +1 -0
  67. data/vendor/faiss/faiss/impl/ProductQuantizer.h +5 -1
  68. data/vendor/faiss/faiss/impl/ResultHandler.h +151 -32
  69. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +719 -102
  70. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -0
  71. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +5 -0
  72. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  73. data/vendor/faiss/faiss/impl/index_read.cpp +29 -15
  74. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  75. data/vendor/faiss/faiss/impl/index_write.cpp +28 -10
  76. data/vendor/faiss/faiss/impl/io.cpp +13 -5
  77. data/vendor/faiss/faiss/impl/io.h +4 -4
  78. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  79. data/vendor/faiss/faiss/impl/platform_macros.h +22 -0
  80. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +11 -0
  81. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +1 -1
  82. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +448 -1
  83. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +5 -5
  84. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
  85. data/vendor/faiss/faiss/impl/simd_result_handlers.h +143 -59
  86. data/vendor/faiss/faiss/index_factory.cpp +31 -13
  87. data/vendor/faiss/faiss/index_io.h +12 -5
  88. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  89. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  90. data/vendor/faiss/faiss/invlists/DirectMap.cpp +9 -1
  91. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +55 -17
  92. data/vendor/faiss/faiss/invlists/InvertedLists.h +18 -9
  93. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +21 -6
  94. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  95. data/vendor/faiss/faiss/python/python_callbacks.cpp +3 -3
  96. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  97. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  98. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  99. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  100. data/vendor/faiss/faiss/utils/distances.cpp +58 -88
  101. data/vendor/faiss/faiss/utils/distances.h +5 -5
  102. data/vendor/faiss/faiss/utils/distances_simd.cpp +997 -9
  103. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  104. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  105. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  106. data/vendor/faiss/faiss/utils/hamming.cpp +1 -1
  107. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +4 -1
  108. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +2 -1
  109. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  110. data/vendor/faiss/faiss/utils/random.h +25 -0
  111. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  112. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  113. data/vendor/faiss/faiss/utils/simdlib_neon.h +5 -2
  114. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  115. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  116. data/vendor/faiss/faiss/utils/utils.cpp +10 -3
  117. data/vendor/faiss/faiss/utils/utils.h +3 -0
  118. metadata +16 -4
  119. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -46,6 +46,7 @@ struct ResultHandler;
46
46
  struct SearchParametersHNSW : SearchParameters {
47
47
  int efSearch = 16;
48
48
  bool check_relative_distance = true;
49
+ bool bounded_queue = true;
49
50
 
50
51
  ~SearchParametersHNSW() {}
51
52
  };
@@ -141,9 +142,6 @@ struct HNSW {
141
142
  /// enough?
142
143
  bool check_relative_distance = true;
143
144
 
144
- /// number of entry points in levels > 0.
145
- int upper_beam = 1;
146
-
147
145
  /// use bounded queue during exploration
148
146
  bool search_bounded_queue = true;
149
147
 
@@ -184,7 +182,8 @@ struct HNSW {
184
182
  float d_nearest,
185
183
  int level,
186
184
  omp_lock_t* locks,
187
- VisitedTable& vt);
185
+ VisitedTable& vt,
186
+ bool keep_max_size_level0 = false);
188
187
 
189
188
  /** add point pt_id on all levels <= pt_level and build the link
190
189
  * structure for them. */
@@ -193,7 +192,8 @@ struct HNSW {
193
192
  int pt_level,
194
193
  int pt_id,
195
194
  std::vector<omp_lock_t>& locks,
196
- VisitedTable& vt);
195
+ VisitedTable& vt,
196
+ bool keep_max_size_level0 = false);
197
197
 
198
198
  /// search interface for 1 point, single thread
199
199
  HNSWStats search(
@@ -211,7 +211,8 @@ struct HNSW {
211
211
  const float* nearest_d,
212
212
  int search_type,
213
213
  HNSWStats& search_stats,
214
- VisitedTable& vt) const;
214
+ VisitedTable& vt,
215
+ const SearchParametersHNSW* params = nullptr) const;
215
216
 
216
217
  void reset();
217
218
 
@@ -224,40 +225,60 @@ struct HNSW {
224
225
  DistanceComputer& qdis,
225
226
  std::priority_queue<NodeDistFarther>& input,
226
227
  std::vector<NodeDistFarther>& output,
227
- int max_size);
228
+ int max_size,
229
+ bool keep_max_size_level0 = false);
228
230
 
229
231
  void permute_entries(const idx_t* map);
230
232
  };
231
233
 
232
234
  struct HNSWStats {
233
- size_t n1, n2, n3;
234
- size_t ndis;
235
- size_t nreorder;
236
-
237
- HNSWStats(
238
- size_t n1 = 0,
239
- size_t n2 = 0,
240
- size_t n3 = 0,
241
- size_t ndis = 0,
242
- size_t nreorder = 0)
243
- : n1(n1), n2(n2), n3(n3), ndis(ndis), nreorder(nreorder) {}
235
+ size_t n1 = 0; /// number of vectors searched
236
+ size_t n2 =
237
+ 0; /// number of queries for which the candidate list is exhausted
238
+ size_t ndis = 0; /// number of distances computed
239
+ size_t nhops = 0; /// number of hops aka number of edges traversed
244
240
 
245
241
  void reset() {
246
- n1 = n2 = n3 = 0;
242
+ n1 = n2 = 0;
247
243
  ndis = 0;
248
- nreorder = 0;
244
+ nhops = 0;
249
245
  }
250
246
 
251
247
  void combine(const HNSWStats& other) {
252
248
  n1 += other.n1;
253
249
  n2 += other.n2;
254
- n3 += other.n3;
255
250
  ndis += other.ndis;
256
- nreorder += other.nreorder;
251
+ nhops += other.nhops;
257
252
  }
258
253
  };
259
254
 
260
255
  // global var that collects them all
261
256
  FAISS_API extern HNSWStats hnsw_stats;
262
257
 
258
+ int search_from_candidates(
259
+ const HNSW& hnsw,
260
+ DistanceComputer& qdis,
261
+ ResultHandler<HNSW::C>& res,
262
+ HNSW::MinimaxHeap& candidates,
263
+ VisitedTable& vt,
264
+ HNSWStats& stats,
265
+ int level,
266
+ int nres_in = 0,
267
+ const SearchParametersHNSW* params = nullptr);
268
+
269
+ HNSWStats greedy_update_nearest(
270
+ const HNSW& hnsw,
271
+ DistanceComputer& qdis,
272
+ int level,
273
+ HNSW::storage_idx_t& nearest,
274
+ float& d_nearest);
275
+
276
+ std::priority_queue<HNSW::Node> search_from_candidate_unbounded(
277
+ const HNSW& hnsw,
278
+ const HNSW::Node& node,
279
+ DistanceComputer& qdis,
280
+ int ef,
281
+ VisitedTable* vt,
282
+ HNSWStats& stats);
283
+
263
284
  } // namespace faiss
@@ -104,10 +104,10 @@ int dgemm_(
104
104
 
105
105
  namespace {
106
106
 
107
- void fmat_inverse(float* a, int n) {
108
- int info;
109
- int lwork = n * n;
110
- std::vector<int> ipiv(n);
107
+ void fmat_inverse(float* a, FINTEGER n) {
108
+ FINTEGER info;
109
+ FINTEGER lwork = n * n;
110
+ std::vector<FINTEGER> ipiv(n);
111
111
  std::vector<float> workspace(lwork);
112
112
 
113
113
  sgetrf_(&n, &n, a, &n, ipiv.data(), &info);
@@ -123,10 +123,10 @@ void dfvec_add(size_t d, const double* a, const float* b, double* c) {
123
123
  }
124
124
  }
125
125
 
126
- void dmat_inverse(double* a, int n) {
127
- int info;
128
- int lwork = n * n;
129
- std::vector<int> ipiv(n);
126
+ void dmat_inverse(double* a, FINTEGER n) {
127
+ FINTEGER info;
128
+ FINTEGER lwork = n * n;
129
+ std::vector<FINTEGER> ipiv(n);
130
130
  std::vector<double> workspace(lwork);
131
131
 
132
132
  dgetrf_(&n, &n, a, &n, ipiv.data(), &info);
@@ -38,6 +38,23 @@ struct DummyScaler {
38
38
  return simd16uint16(0);
39
39
  }
40
40
 
41
+ #ifdef __AVX512F__
42
+ inline simd64uint8 lookup(const simd64uint8&, const simd64uint8&) const {
43
+ FAISS_THROW_MSG("DummyScaler::lookup should not be called.");
44
+ return simd64uint8(0);
45
+ }
46
+
47
+ inline simd32uint16 scale_lo(const simd64uint8&) const {
48
+ FAISS_THROW_MSG("DummyScaler::scale_lo should not be called.");
49
+ return simd32uint16(0);
50
+ }
51
+
52
+ inline simd32uint16 scale_hi(const simd64uint8&) const {
53
+ FAISS_THROW_MSG("DummyScaler::scale_hi should not be called.");
54
+ return simd32uint16(0);
55
+ }
56
+ #endif
57
+
41
58
  template <class dist_t>
42
59
  inline dist_t scale_one(const dist_t&) const {
43
60
  FAISS_THROW_MSG("DummyScaler::scale_one should not be called.");
@@ -67,6 +84,23 @@ struct NormTableScaler {
67
84
  return (simd16uint16(res) >> 8) * scale_simd;
68
85
  }
69
86
 
87
+ #ifdef __AVX512F__
88
+ inline simd64uint8 lookup(const simd64uint8& lut, const simd64uint8& c)
89
+ const {
90
+ return lut.lookup_4_lanes(c);
91
+ }
92
+
93
+ inline simd32uint16 scale_lo(const simd64uint8& res) const {
94
+ auto scale_simd_wide = simd32uint16(scale_simd, scale_simd);
95
+ return simd32uint16(res) * scale_simd_wide;
96
+ }
97
+
98
+ inline simd32uint16 scale_hi(const simd64uint8& res) const {
99
+ auto scale_simd_wide = simd32uint16(scale_simd, scale_simd);
100
+ return (simd32uint16(res) >> 8) * scale_simd_wide;
101
+ }
102
+ #endif
103
+
70
104
  // for non-SIMD implem 2, 3, 4
71
105
  template <class dist_t>
72
106
  inline dist_t scale_one(const dist_t& x) const {
@@ -154,15 +154,20 @@ NNDescent::NNDescent(const int d, const int K) : K(K), d(d) {
154
154
  NNDescent::~NNDescent() {}
155
155
 
156
156
  void NNDescent::join(DistanceComputer& qdis) {
157
+ idx_t check_period = InterruptCallback::get_period_hint(d * search_L);
158
+ for (idx_t i0 = 0; i0 < (idx_t)ntotal; i0 += check_period) {
159
+ idx_t i1 = std::min(i0 + check_period, (idx_t)ntotal);
157
160
  #pragma omp parallel for default(shared) schedule(dynamic, 100)
158
- for (int n = 0; n < ntotal; n++) {
159
- graph[n].join([&](int i, int j) {
160
- if (i != j) {
161
- float dist = qdis.symmetric_dis(i, j);
162
- graph[i].insert(j, dist);
163
- graph[j].insert(i, dist);
164
- }
165
- });
161
+ for (idx_t n = i0; n < i1; n++) {
162
+ graph[n].join([&](int i, int j) {
163
+ if (i != j) {
164
+ float dist = qdis.symmetric_dis(i, j);
165
+ graph[i].insert(j, dist);
166
+ graph[j].insert(i, dist);
167
+ }
168
+ });
169
+ }
170
+ InterruptCallback::check();
166
171
  }
167
172
  }
168
173
 
@@ -25,35 +25,6 @@ namespace {
25
25
  // It needs to be smaller than 0
26
26
  constexpr int EMPTY_ID = -1;
27
27
 
28
- /* Wrap the distance computer into one that negates the
29
- distances. This makes supporting INNER_PRODUCE search easier */
30
-
31
- struct NegativeDistanceComputer : DistanceComputer {
32
- /// owned by this
33
- DistanceComputer* basedis;
34
-
35
- explicit NegativeDistanceComputer(DistanceComputer* basedis)
36
- : basedis(basedis) {}
37
-
38
- void set_query(const float* x) override {
39
- basedis->set_query(x);
40
- }
41
-
42
- /// compute distance of vector i to current query
43
- float operator()(idx_t i) override {
44
- return -(*basedis)(i);
45
- }
46
-
47
- /// compute distance between two stored vectors
48
- float symmetric_dis(idx_t i, idx_t j) override {
49
- return -basedis->symmetric_dis(i, j);
50
- }
51
-
52
- ~NegativeDistanceComputer() override {
53
- delete basedis;
54
- }
55
- };
56
-
57
28
  } // namespace
58
29
 
59
30
  DistanceComputer* storage_distance_computer(const Index* storage) {
@@ -61,6 +61,7 @@ void ProductQuantizer::set_derived_values() {
61
61
  "The dimension of the vector (d) should be a multiple of the number of subquantizers (M)");
62
62
  dsub = d / M;
63
63
  code_size = (nbits * M + 7) / 8;
64
+ FAISS_THROW_IF_MSG(nbits > 24, "nbits larger than 24 is not practical.");
64
65
  ksub = 1 << nbits;
65
66
  centroids.resize(d * ksub);
66
67
  verbose = false;
@@ -21,7 +21,11 @@
21
21
 
22
22
  namespace faiss {
23
23
 
24
- /** Product Quantizer. Implemented only for METRIC_L2 */
24
+ /** Product Quantizer.
25
+ * PQ is trained using k-means, minimizing the L2 distance to centroids.
26
+ * PQ supports L2 and Inner Product search, however the quantization error is
27
+ * biased towards L2 distance.
28
+ */
25
29
  struct ProductQuantizer : Quantizer {
26
30
  size_t M; ///< number of subquantizers
27
31
  size_t nbits; ///< number of bits per quantization index
@@ -12,9 +12,14 @@
12
12
  #pragma once
13
13
 
14
14
  #include <faiss/impl/AuxIndexStructures.h>
15
+ #include <faiss/impl/FaissException.h>
16
+ #include <faiss/impl/IDSelector.h>
15
17
  #include <faiss/utils/Heap.h>
16
18
  #include <faiss/utils/partitioning.h>
17
19
 
20
+ #include <algorithm>
21
+ #include <iostream>
22
+
18
23
  namespace faiss {
19
24
 
20
25
  /*****************************************************************
@@ -24,16 +29,21 @@ namespace faiss {
24
29
  * - by instantiating a SingleResultHandler that tracks results for a single
25
30
  * query
26
31
  * - with begin_multiple/add_results/end_multiple calls where a whole block of
27
- * resutls is submitted
32
+ * results is submitted
28
33
  * All classes are templated on C, which defines whether the min or the max of
29
- * results is to be kept.
34
+ * results is to be kept, and on sel, so that the codepaths for with / without
35
+ * selector can be separated at compile time.
30
36
  *****************************************************************/
31
37
 
32
- template <class C>
38
+ template <class C, bool use_sel = false>
33
39
  struct BlockResultHandler {
34
40
  size_t nq; // number of queries for which we search
41
+ const IDSelector* sel;
35
42
 
36
- explicit BlockResultHandler(size_t nq) : nq(nq) {}
43
+ explicit BlockResultHandler(size_t nq, const IDSelector* sel = nullptr)
44
+ : nq(nq), sel(sel) {
45
+ assert(!use_sel || sel);
46
+ }
37
47
 
38
48
  // currently handled query range
39
49
  size_t i0 = 0, i1 = 0;
@@ -51,13 +61,17 @@ struct BlockResultHandler {
51
61
  virtual void end_multiple() {}
52
62
 
53
63
  virtual ~BlockResultHandler() {}
64
+
65
+ bool is_in_selection(idx_t i) const {
66
+ return !use_sel || sel->is_member(i);
67
+ }
54
68
  };
55
69
 
56
70
  // handler for a single query
57
71
  template <class C>
58
72
  struct ResultHandler {
59
73
  // if not better than threshold, then not necessary to call add_result
60
- typename C::T threshold = 0;
74
+ typename C::T threshold = C::neutral();
61
75
 
62
76
  // return whether threshold was updated
63
77
  virtual bool add_result(typename C::T dis, typename C::TI idx) = 0;
@@ -71,20 +85,26 @@ struct ResultHandler {
71
85
  * some temporary data in memory.
72
86
  *****************************************************************/
73
87
 
74
- template <class C>
75
- struct Top1BlockResultHandler : BlockResultHandler<C> {
88
+ template <class C, bool use_sel = false>
89
+ struct Top1BlockResultHandler : BlockResultHandler<C, use_sel> {
76
90
  using T = typename C::T;
77
91
  using TI = typename C::TI;
78
- using BlockResultHandler<C>::i0;
79
- using BlockResultHandler<C>::i1;
92
+ using BlockResultHandler<C, use_sel>::i0;
93
+ using BlockResultHandler<C, use_sel>::i1;
80
94
 
81
95
  // contains exactly nq elements
82
96
  T* dis_tab;
83
97
  // contains exactly nq elements
84
98
  TI* ids_tab;
85
99
 
86
- Top1BlockResultHandler(size_t nq, T* dis_tab, TI* ids_tab)
87
- : BlockResultHandler<C>(nq), dis_tab(dis_tab), ids_tab(ids_tab) {}
100
+ Top1BlockResultHandler(
101
+ size_t nq,
102
+ T* dis_tab,
103
+ TI* ids_tab,
104
+ const IDSelector* sel = nullptr)
105
+ : BlockResultHandler<C, use_sel>(nq, sel),
106
+ dis_tab(dis_tab),
107
+ ids_tab(ids_tab) {}
88
108
 
89
109
  struct SingleResultHandler : ResultHandler<C> {
90
110
  Top1BlockResultHandler& hr;
@@ -163,12 +183,12 @@ struct Top1BlockResultHandler : BlockResultHandler<C> {
163
183
  * Heap based result handler
164
184
  *****************************************************************/
165
185
 
166
- template <class C>
167
- struct HeapBlockResultHandler : BlockResultHandler<C> {
186
+ template <class C, bool use_sel = false>
187
+ struct HeapBlockResultHandler : BlockResultHandler<C, use_sel> {
168
188
  using T = typename C::T;
169
189
  using TI = typename C::TI;
170
- using BlockResultHandler<C>::i0;
171
- using BlockResultHandler<C>::i1;
190
+ using BlockResultHandler<C, use_sel>::i0;
191
+ using BlockResultHandler<C, use_sel>::i1;
172
192
 
173
193
  T* heap_dis_tab;
174
194
  TI* heap_ids_tab;
@@ -179,8 +199,9 @@ struct HeapBlockResultHandler : BlockResultHandler<C> {
179
199
  size_t nq,
180
200
  T* heap_dis_tab,
181
201
  TI* heap_ids_tab,
182
- size_t k)
183
- : BlockResultHandler<C>(nq),
202
+ size_t k,
203
+ const IDSelector* sel = nullptr)
204
+ : BlockResultHandler<C, use_sel>(nq, sel),
184
205
  heap_dis_tab(heap_dis_tab),
185
206
  heap_ids_tab(heap_ids_tab),
186
207
  k(k) {}
@@ -345,12 +366,12 @@ struct ReservoirTopN : ResultHandler<C> {
345
366
  }
346
367
  };
347
368
 
348
- template <class C>
349
- struct ReservoirBlockResultHandler : BlockResultHandler<C> {
369
+ template <class C, bool use_sel = false>
370
+ struct ReservoirBlockResultHandler : BlockResultHandler<C, use_sel> {
350
371
  using T = typename C::T;
351
372
  using TI = typename C::TI;
352
- using BlockResultHandler<C>::i0;
353
- using BlockResultHandler<C>::i1;
373
+ using BlockResultHandler<C, use_sel>::i0;
374
+ using BlockResultHandler<C, use_sel>::i1;
354
375
 
355
376
  T* heap_dis_tab;
356
377
  TI* heap_ids_tab;
@@ -362,8 +383,9 @@ struct ReservoirBlockResultHandler : BlockResultHandler<C> {
362
383
  size_t nq,
363
384
  T* heap_dis_tab,
364
385
  TI* heap_ids_tab,
365
- size_t k)
366
- : BlockResultHandler<C>(nq),
386
+ size_t k,
387
+ const IDSelector* sel = nullptr)
388
+ : BlockResultHandler<C, use_sel>(nq, sel),
367
389
  heap_dis_tab(heap_dis_tab),
368
390
  heap_ids_tab(heap_ids_tab),
369
391
  k(k) {
@@ -458,18 +480,23 @@ struct ReservoirBlockResultHandler : BlockResultHandler<C> {
458
480
  * Result handler for range searches
459
481
  *****************************************************************/
460
482
 
461
- template <class C>
462
- struct RangeSearchBlockResultHandler : BlockResultHandler<C> {
483
+ template <class C, bool use_sel = false>
484
+ struct RangeSearchBlockResultHandler : BlockResultHandler<C, use_sel> {
463
485
  using T = typename C::T;
464
486
  using TI = typename C::TI;
465
- using BlockResultHandler<C>::i0;
466
- using BlockResultHandler<C>::i1;
487
+ using BlockResultHandler<C, use_sel>::i0;
488
+ using BlockResultHandler<C, use_sel>::i1;
467
489
 
468
490
  RangeSearchResult* res;
469
491
  T radius;
470
492
 
471
- RangeSearchBlockResultHandler(RangeSearchResult* res, float radius)
472
- : BlockResultHandler<C>(res->nq), res(res), radius(radius) {}
493
+ RangeSearchBlockResultHandler(
494
+ RangeSearchResult* res,
495
+ float radius,
496
+ const IDSelector* sel = nullptr)
497
+ : BlockResultHandler<C, use_sel>(res->nq, sel),
498
+ res(res),
499
+ radius(radius) {}
473
500
 
474
501
  /******************************************************
475
502
  * API for 1 result at a time (each SingleResultHandler is
@@ -504,7 +531,15 @@ struct RangeSearchBlockResultHandler : BlockResultHandler<C> {
504
531
  void end() {}
505
532
 
506
533
  ~SingleResultHandler() {
507
- pres.finalize();
534
+ try {
535
+ // finalize the partial result
536
+ pres.finalize();
537
+ } catch (const faiss::FaissException& e) {
538
+ // Do nothing if allocation fails in finalizing partial results.
539
+ #ifndef NDEBUG
540
+ std::cerr << e.what() << std::endl;
541
+ #endif
542
+ }
508
543
  }
509
544
  };
510
545
 
@@ -559,10 +594,94 @@ struct RangeSearchBlockResultHandler : BlockResultHandler<C> {
559
594
  }
560
595
 
561
596
  ~RangeSearchBlockResultHandler() {
562
- if (partial_results.size() > 0) {
563
- RangeSearchPartialResult::merge(partial_results);
597
+ try {
598
+ if (partial_results.size() > 0) {
599
+ RangeSearchPartialResult::merge(partial_results);
600
+ }
601
+ } catch (const faiss::FaissException& e) {
602
+ // Do nothing if allocation fails in merge.
603
+ #ifndef NDEBUG
604
+ std::cerr << e.what() << std::endl;
605
+ #endif
564
606
  }
565
607
  }
566
608
  };
567
609
 
610
+ /*****************************************************************
611
+ * Dispatcher function to choose the right knn result handler depending on k
612
+ *****************************************************************/
613
+
614
+ // declared in distances.cpp
615
+ FAISS_API extern int distance_compute_min_k_reservoir;
616
+
617
+ template <class Consumer, class... Types>
618
+ typename Consumer::T dispatch_knn_ResultHandler(
619
+ size_t nx,
620
+ float* vals,
621
+ int64_t* ids,
622
+ size_t k,
623
+ MetricType metric,
624
+ const IDSelector* sel,
625
+ Consumer& consumer,
626
+ Types... args) {
627
+ #define DISPATCH_C_SEL(C, use_sel) \
628
+ if (k == 1) { \
629
+ Top1BlockResultHandler<C, use_sel> res(nx, vals, ids, sel); \
630
+ return consumer.template f<>(res, args...); \
631
+ } else if (k < distance_compute_min_k_reservoir) { \
632
+ HeapBlockResultHandler<C, use_sel> res(nx, vals, ids, k, sel); \
633
+ return consumer.template f<>(res, args...); \
634
+ } else { \
635
+ ReservoirBlockResultHandler<C, use_sel> res(nx, vals, ids, k, sel); \
636
+ return consumer.template f<>(res, args...); \
637
+ }
638
+
639
+ if (is_similarity_metric(metric)) {
640
+ using C = CMin<float, int64_t>;
641
+ if (sel) {
642
+ DISPATCH_C_SEL(C, true);
643
+ } else {
644
+ DISPATCH_C_SEL(C, false);
645
+ }
646
+ } else {
647
+ using C = CMax<float, int64_t>;
648
+ if (sel) {
649
+ DISPATCH_C_SEL(C, true);
650
+ } else {
651
+ DISPATCH_C_SEL(C, false);
652
+ }
653
+ }
654
+ #undef DISPATCH_C_SEL
655
+ }
656
+
657
+ template <class Consumer, class... Types>
658
+ typename Consumer::T dispatch_range_ResultHandler(
659
+ RangeSearchResult* res,
660
+ float radius,
661
+ MetricType metric,
662
+ const IDSelector* sel,
663
+ Consumer& consumer,
664
+ Types... args) {
665
+ #define DISPATCH_C_SEL(C, use_sel) \
666
+ RangeSearchBlockResultHandler<C, use_sel> resb(res, radius, sel); \
667
+ return consumer.template f<>(resb, args...);
668
+
669
+ if (is_similarity_metric(metric)) {
670
+ using C = CMin<float, int64_t>;
671
+ if (sel) {
672
+ DISPATCH_C_SEL(C, true);
673
+ } else {
674
+ DISPATCH_C_SEL(C, false);
675
+ }
676
+ } else {
677
+ using C = CMax<float, int64_t>;
678
+ if (sel) {
679
+ DISPATCH_C_SEL(C, true);
680
+ } else {
681
+ DISPATCH_C_SEL(C, false);
682
+ }
683
+ }
684
+ #undef DISPATCH_C_SEL
685
+ }
686
+
568
687
  } // namespace faiss