faiss 0.3.0 → 0.3.1

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (171)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
@@ -17,23 +17,170 @@
17
17
 
18
18
  namespace faiss {
19
19
 
20
+ /*****************************************************************
21
+ * The classes below are intended to be used as template arguments;
22
+ * they handle results for batches of queries (size nq).
23
+ * They can be called in two ways:
24
+ * - by instantiating a SingleResultHandler that tracks results for a single
25
+ * query
26
+ * - with begin_multiple/add_results/end_multiple calls where a whole block of
27
+ * results is submitted
28
+ * All classes are templated on C, which defines whether the min or the max of
29
+ * results is to be kept.
30
+ *****************************************************************/
31
+
32
+ template <class C>
33
+ struct BlockResultHandler { // base for handlers that receive results for blocks of queries
34
+ size_t nq; // number of queries for which we search
35
+
36
+ explicit BlockResultHandler(size_t nq) : nq(nq) {}
37
+
38
+ // currently handled query range [i0, i1), recorded by begin_multiple
39
+ size_t i0 = 0, i1 = 0;
40
+
41
+ // start collecting results for queries [i0, i1); the base version only records the range
42
+ virtual void begin_multiple(size_t i0_2, size_t i1_2) {
43
+ this->i0 = i0_2;
44
+ this->i1 = i1_2;
45
+ }
46
+
47
+ // add results for queries [i0, i1) and database [j0, j1); unnamed params are j0, j1 and the distance table; no-op by default
48
+ virtual void add_results(size_t, size_t, const typename C::T*) {}
49
+
50
+ // series of results for queries i0..i1 is done; no-op by default
51
+ virtual void end_multiple() {}
52
+
53
+ virtual ~BlockResultHandler() {} // virtual so derived handlers can be destroyed through a base pointer
54
+ };
55
+
56
+ // interface for handlers that receive results one at a time for a single query
57
+ template <class C>
58
+ struct ResultHandler {
59
+ // results that are not better than threshold (per C::cmp) can be skipped without calling add_result
60
+ typename C::T threshold = 0; // derived handlers overwrite this (e.g. with C::neutral(), the heap top, or the search radius)
61
+
62
+ // add one result; return whether threshold was updated
63
+ virtual bool add_result(typename C::T dis, typename C::TI idx) = 0;
64
+
65
+ virtual ~ResultHandler() {}
66
+ };
67
+
68
+ /*****************************************************************
69
+ * Single best result handler.
70
+ * Tracks only the best result for each query, thus avoiding storing
71
+ * temporary data in memory.
72
+ *****************************************************************/
73
+
74
+ template <class C>
75
+ struct Top1BlockResultHandler : BlockResultHandler<C> { // keeps only the best result (per C::cmp) for each query
76
+ using T = typename C::T;
77
+ using TI = typename C::TI;
78
+ using BlockResultHandler<C>::i0;
79
+ using BlockResultHandler<C>::i1;
80
+
81
+ // output distances: contains exactly nq elements (one per query)
82
+ T* dis_tab;
83
+ // output ids: contains exactly nq elements (one per query)
84
+ TI* ids_tab;
85
+
86
+ Top1BlockResultHandler(size_t nq, T* dis_tab, TI* ids_tab)
87
+ : BlockResultHandler<C>(nq), dis_tab(dis_tab), ids_tab(ids_tab) {}
88
+
89
+ struct SingleResultHandler : ResultHandler<C> {
90
+ Top1BlockResultHandler& hr;
91
+ using ResultHandler<C>::threshold;
92
+
93
+ TI min_idx;
94
+ size_t current_idx = 0;
95
+
96
+ explicit SingleResultHandler(Top1BlockResultHandler& hr) : hr(hr) {}
97
+
98
+ /// begin results for query # current_idx_2: threshold is reset to C::neutral() and min_idx to -1
99
+ void begin(const size_t current_idx_2) {
100
+ this->current_idx = current_idx_2;
101
+ threshold = C::neutral();
102
+ min_idx = -1;
103
+ }
104
+
105
+ /// add one result for the current query; returns true when it becomes the new best
106
+ bool add_result(T dis, TI idx) final {
107
+ if (C::cmp(this->threshold, dis)) {
108
+ threshold = dis;
109
+ min_idx = idx;
110
+ return true;
111
+ }
112
+ return false;
113
+ }
114
+
115
+ /// series of results for the current query is done: write the best distance/id into hr's tables
116
+ void end() {
117
+ hr.dis_tab[current_idx] = threshold;
118
+ hr.ids_tab[current_idx] = min_idx;
119
+ }
120
+ };
121
+
122
+ /// begin collecting results for queries [i0, i1): resets their dis_tab entries to C::neutral()
123
+ void begin_multiple(size_t i0, size_t i1) final {
124
+ this->i0 = i0;
125
+ this->i1 = i1;
126
+
127
+ for (size_t i = i0; i < i1; i++) { // NOTE(review): ids_tab[i] is not reset here (SingleResultHandler::begin uses -1); if no distance beats C::neutral(), ids_tab[i] keeps its previous contents
128
+ this->dis_tab[i] = C::neutral();
129
+ }
130
+ }
131
+
132
+ /// add results for queries i0..i1 and database j0..j1; dis_tab_2 holds one row of (j1 - j0) distances per query
133
+ void add_results(size_t j0, size_t j1, const T* dis_tab_2) final {
134
+ for (int64_t i = i0; i < i1; i++) { // NOTE(review): signed int64_t index compared against size_t i1
135
+ const T* dis_tab_i = dis_tab_2 + (j1 - j0) * (i - i0) - j0; // shifted by -j0 so the row is indexed with absolute database ids j
136
+
137
+ auto& min_distance = this->dis_tab[i];
138
+ auto& min_index = this->ids_tab[i];
139
+
140
+ for (size_t j = j0; j < j1; j++) {
141
+ const T distance = dis_tab_i[j];
142
+
143
+ if (C::cmp(min_distance, distance)) {
144
+ min_distance = distance;
145
+ min_index = j; // the database index j itself is stored as the id
146
+ }
147
+ }
148
+ }
149
+ }
150
+ /// add one result (dis, idx) directly to query i, keeping it only if it beats the current best
151
+ void add_result(const size_t i, const T dis, const TI idx) {
152
+ auto& min_distance = this->dis_tab[i];
153
+ auto& min_index = this->ids_tab[i];
154
+
155
+ if (C::cmp(min_distance, dis)) {
156
+ min_distance = dis;
157
+ min_index = idx;
158
+ }
159
+ }
160
+ };
161
+
20
162
  /*****************************************************************
21
163
  * Heap based result handler
22
164
  *****************************************************************/
23
165
 
24
166
  template <class C>
25
- struct HeapResultHandler {
167
+ struct HeapBlockResultHandler : BlockResultHandler<C> {
26
168
  using T = typename C::T;
27
169
  using TI = typename C::TI;
170
+ using BlockResultHandler<C>::i0;
171
+ using BlockResultHandler<C>::i1;
28
172
 
29
- int nq;
30
173
  T* heap_dis_tab;
31
174
  TI* heap_ids_tab;
32
175
 
33
176
  int64_t k; // number of results to keep
34
177
 
35
- HeapResultHandler(size_t nq, T* heap_dis_tab, TI* heap_ids_tab, size_t k)
36
- : nq(nq),
178
+ HeapBlockResultHandler(
179
+ size_t nq,
180
+ T* heap_dis_tab,
181
+ TI* heap_ids_tab,
182
+ size_t k)
183
+ : BlockResultHandler<C>(nq),
37
184
  heap_dis_tab(heap_dis_tab),
38
185
  heap_ids_tab(heap_ids_tab),
39
186
  k(k) {}
@@ -43,30 +190,33 @@ struct HeapResultHandler {
43
190
  * called from 1 thread)
44
191
  */
45
192
 
46
- struct SingleResultHandler {
47
- HeapResultHandler& hr;
193
+ struct SingleResultHandler : ResultHandler<C> {
194
+ HeapBlockResultHandler& hr;
195
+ using ResultHandler<C>::threshold;
48
196
  size_t k;
49
197
 
50
198
  T* heap_dis;
51
199
  TI* heap_ids;
52
- T thresh;
53
200
 
54
- SingleResultHandler(HeapResultHandler& hr) : hr(hr), k(hr.k) {}
201
+ explicit SingleResultHandler(HeapBlockResultHandler& hr)
202
+ : hr(hr), k(hr.k) {}
55
203
 
56
204
  /// begin results for query # i
57
205
  void begin(size_t i) {
58
206
  heap_dis = hr.heap_dis_tab + i * k;
59
207
  heap_ids = hr.heap_ids_tab + i * k;
60
208
  heap_heapify<C>(k, heap_dis, heap_ids);
61
- thresh = heap_dis[0];
209
+ threshold = heap_dis[0];
62
210
  }
63
211
 
64
212
  /// add one result for query i
65
- void add_result(T dis, TI idx) {
66
- if (C::cmp(heap_dis[0], dis)) {
213
+ bool add_result(T dis, TI idx) final {
214
+ if (C::cmp(threshold, dis)) {
67
215
  heap_replace_top<C>(k, heap_dis, heap_ids, dis, idx);
68
- thresh = heap_dis[0];
216
+ threshold = heap_dis[0];
217
+ return true;
69
218
  }
219
+ return false;
70
220
  }
71
221
 
72
222
  /// series of results for query i is done
@@ -79,19 +229,17 @@ struct HeapResultHandler {
79
229
  * API for multiple results (called from 1 thread)
80
230
  */
81
231
 
82
- size_t i0, i1;
83
-
84
232
  /// begin
85
- void begin_multiple(size_t i0, size_t i1) {
86
- this->i0 = i0;
87
- this->i1 = i1;
233
+ void begin_multiple(size_t i0_2, size_t i1_2) final {
234
+ this->i0 = i0_2;
235
+ this->i1 = i1_2;
88
236
  for (size_t i = i0; i < i1; i++) {
89
237
  heap_heapify<C>(k, heap_dis_tab + i * k, heap_ids_tab + i * k);
90
238
  }
91
239
  }
92
240
 
93
241
  /// add results for query i0..i1 and j0..j1
94
- void add_results(size_t j0, size_t j1, const T* dis_tab) {
242
+ void add_results(size_t j0, size_t j1, const T* dis_tab) final {
95
243
  #pragma omp parallel for
96
244
  for (int64_t i = i0; i < i1; i++) {
97
245
  T* heap_dis = heap_dis_tab + i * k;
@@ -109,7 +257,7 @@ struct HeapResultHandler {
109
257
  }
110
258
 
111
259
  /// series of results for queries i0..i1 is done
112
- void end_multiple() {
260
+ void end_multiple() final {
113
261
  // maybe parallel for
114
262
  for (size_t i = i0; i < i1; i++) {
115
263
  heap_reorder<C>(k, heap_dis_tab + i * k, heap_ids_tab + i * k);
@@ -128,9 +276,10 @@ struct HeapResultHandler {
128
276
 
129
277
  /// Reservoir for a single query
130
278
  template <class C>
131
- struct ReservoirTopN {
279
+ struct ReservoirTopN : ResultHandler<C> {
132
280
  using T = typename C::T;
133
281
  using TI = typename C::TI;
282
+ using ResultHandler<C>::threshold;
134
283
 
135
284
  T* vals;
136
285
  TI* ids;
@@ -139,8 +288,6 @@ struct ReservoirTopN {
139
288
  size_t n; // number of requested elements
140
289
  size_t capacity; // size of storage
141
290
 
142
- T threshold; // current threshold
143
-
144
291
  ReservoirTopN() {}
145
292
 
146
293
  ReservoirTopN(size_t n, size_t capacity, T* vals, TI* ids)
@@ -149,15 +296,22 @@ struct ReservoirTopN {
149
296
  threshold = C::neutral();
150
297
  }
151
298
 
152
- void add(T val, TI id) {
299
+ bool add_result(T val, TI id) final {
300
+ bool updated_threshold = false;
153
301
  if (C::cmp(threshold, val)) {
154
302
  if (i == capacity) {
155
303
  shrink_fuzzy();
304
+ updated_threshold = true;
156
305
  }
157
306
  vals[i] = val;
158
307
  ids[i] = id;
159
308
  i++;
160
309
  }
310
+ return updated_threshold;
311
+ }
312
+
313
+ void add(T val, TI id) {
314
+ add_result(val, id);
161
315
  }
162
316
 
163
317
  // reduce storage from capacity to anything
@@ -169,6 +323,11 @@ struct ReservoirTopN {
169
323
  vals, ids, capacity, n, (capacity + n) / 2, &i);
170
324
  }
171
325
 
326
+ void shrink() {
327
+ threshold = partition<C>(vals, ids, i, n);
328
+ i = n;
329
+ }
330
+
172
331
  void to_result(T* heap_dis, TI* heap_ids) const {
173
332
  for (int j = 0; j < std::min(i, n); j++) {
174
333
  heap_push<C>(j + 1, heap_dis, heap_ids, vals[j], ids[j]);
@@ -187,23 +346,24 @@ struct ReservoirTopN {
187
346
  };
188
347
 
189
348
  template <class C>
190
- struct ReservoirResultHandler {
349
+ struct ReservoirBlockResultHandler : BlockResultHandler<C> {
191
350
  using T = typename C::T;
192
351
  using TI = typename C::TI;
352
+ using BlockResultHandler<C>::i0;
353
+ using BlockResultHandler<C>::i1;
193
354
 
194
- int nq;
195
355
  T* heap_dis_tab;
196
356
  TI* heap_ids_tab;
197
357
 
198
358
  int64_t k; // number of results to keep
199
359
  size_t capacity; // capacity of the reservoirs
200
360
 
201
- ReservoirResultHandler(
361
+ ReservoirBlockResultHandler(
202
362
  size_t nq,
203
363
  T* heap_dis_tab,
204
364
  TI* heap_ids_tab,
205
365
  size_t k)
206
- : nq(nq),
366
+ : BlockResultHandler<C>(nq),
207
367
  heap_dis_tab(heap_dis_tab),
208
368
  heap_ids_tab(heap_ids_tab),
209
369
  k(k) {
@@ -216,40 +376,34 @@ struct ReservoirResultHandler {
216
376
  * called from 1 thread)
217
377
  */
218
378
 
219
- struct SingleResultHandler {
220
- ReservoirResultHandler& hr;
379
+ struct SingleResultHandler : ReservoirTopN<C> {
380
+ ReservoirBlockResultHandler& hr;
221
381
 
222
382
  std::vector<T> reservoir_dis;
223
383
  std::vector<TI> reservoir_ids;
224
- ReservoirTopN<C> res1;
225
384
 
226
- SingleResultHandler(ReservoirResultHandler& hr)
227
- : hr(hr),
228
- reservoir_dis(hr.capacity),
229
- reservoir_ids(hr.capacity) {}
385
+ explicit SingleResultHandler(ReservoirBlockResultHandler& hr)
386
+ : ReservoirTopN<C>(hr.k, hr.capacity, nullptr, nullptr),
387
+ hr(hr) {}
230
388
 
231
- size_t i;
389
+ size_t qno;
232
390
 
233
391
  /// begin results for query # i
234
- void begin(size_t i) {
235
- res1 = ReservoirTopN<C>(
236
- hr.k,
237
- hr.capacity,
238
- reservoir_dis.data(),
239
- reservoir_ids.data());
240
- this->i = i;
392
+ void begin(size_t qno_2) {
393
+ reservoir_dis.resize(hr.capacity);
394
+ reservoir_ids.resize(hr.capacity);
395
+ this->vals = reservoir_dis.data();
396
+ this->ids = reservoir_ids.data();
397
+ this->i = 0; // size of reservoir
398
+ this->threshold = C::neutral();
399
+ this->qno = qno_2;
241
400
  }
242
401
 
243
- /// add one result for query i
244
- void add_result(T dis, TI idx) {
245
- res1.add(dis, idx);
246
- }
247
-
248
- /// series of results for query i is done
402
+ /// series of results for query qno is done
249
403
  void end() {
250
- T* heap_dis = hr.heap_dis_tab + i * hr.k;
251
- TI* heap_ids = hr.heap_ids_tab + i * hr.k;
252
- res1.to_result(heap_dis, heap_ids);
404
+ T* heap_dis = hr.heap_dis_tab + qno * hr.k;
405
+ TI* heap_ids = hr.heap_ids_tab + qno * hr.k;
406
+ this->to_result(heap_dis, heap_ids);
253
407
  }
254
408
  };
255
409
 
@@ -257,44 +411,41 @@ struct ReservoirResultHandler {
257
411
  * API for multiple results (called from 1 thread)
258
412
  */
259
413
 
260
- size_t i0, i1;
261
-
262
414
  std::vector<T> reservoir_dis;
263
415
  std::vector<TI> reservoir_ids;
264
416
  std::vector<ReservoirTopN<C>> reservoirs;
265
417
 
266
418
  /// begin
267
- void begin_multiple(size_t i0, size_t i1) {
268
- this->i0 = i0;
269
- this->i1 = i1;
419
+ void begin_multiple(size_t i0_2, size_t i1_2) {
420
+ this->i0 = i0_2;
421
+ this->i1 = i1_2;
270
422
  reservoir_dis.resize((i1 - i0) * capacity);
271
423
  reservoir_ids.resize((i1 - i0) * capacity);
272
424
  reservoirs.clear();
273
- for (size_t i = i0; i < i1; i++) {
425
+ for (size_t i = i0_2; i < i1_2; i++) {
274
426
  reservoirs.emplace_back(
275
427
  k,
276
428
  capacity,
277
- reservoir_dis.data() + (i - i0) * capacity,
278
- reservoir_ids.data() + (i - i0) * capacity);
429
+ reservoir_dis.data() + (i - i0_2) * capacity,
430
+ reservoir_ids.data() + (i - i0_2) * capacity);
279
431
  }
280
432
  }
281
433
 
282
434
  /// add results for query i0..i1 and j0..j1
283
435
  void add_results(size_t j0, size_t j1, const T* dis_tab) {
284
- // maybe parallel for
285
436
  #pragma omp parallel for
286
437
  for (int64_t i = i0; i < i1; i++) {
287
438
  ReservoirTopN<C>& reservoir = reservoirs[i - i0];
288
439
  const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0;
289
440
  for (size_t j = j0; j < j1; j++) {
290
441
  T dis = dis_tab_i[j];
291
- reservoir.add(dis, j);
442
+ reservoir.add_result(dis, j);
292
443
  }
293
444
  }
294
445
  }
295
446
 
296
447
  /// series of results for queries i0..i1 is done
297
- void end_multiple() {
448
+ void end_multiple() final {
298
449
  // maybe parallel for
299
450
  for (size_t i = i0; i < i1; i++) {
300
451
  reservoirs[i - i0].to_result(
@@ -308,29 +459,33 @@ struct ReservoirResultHandler {
308
459
  *****************************************************************/
309
460
 
310
461
  template <class C>
311
- struct RangeSearchResultHandler {
462
+ struct RangeSearchBlockResultHandler : BlockResultHandler<C> {
312
463
  using T = typename C::T;
313
464
  using TI = typename C::TI;
465
+ using BlockResultHandler<C>::i0;
466
+ using BlockResultHandler<C>::i1;
314
467
 
315
468
  RangeSearchResult* res;
316
- float radius;
469
+ T radius;
317
470
 
318
- RangeSearchResultHandler(RangeSearchResult* res, float radius)
319
- : res(res), radius(radius) {}
471
+ RangeSearchBlockResultHandler(RangeSearchResult* res, float radius)
472
+ : BlockResultHandler<C>(res->nq), res(res), radius(radius) {}
320
473
 
321
474
  /******************************************************
322
475
  * API for 1 result at a time (each SingleResultHandler is
323
476
  * called from 1 thread)
324
477
  ******************************************************/
325
478
 
326
- struct SingleResultHandler {
479
+ struct SingleResultHandler : ResultHandler<C> {
327
480
  // almost the same interface as RangeSearchResultHandler
481
+ using ResultHandler<C>::threshold;
328
482
  RangeSearchPartialResult pres;
329
- float radius;
330
483
  RangeQueryResult* qr = nullptr;
331
484
 
332
- SingleResultHandler(RangeSearchResultHandler& rh)
333
- : pres(rh.res), radius(rh.radius) {}
485
+ explicit SingleResultHandler(RangeSearchBlockResultHandler& rh)
486
+ : pres(rh.res) {
487
+ threshold = rh.radius;
488
+ }
334
489
 
335
490
  /// begin results for query # i
336
491
  void begin(size_t i) {
@@ -338,10 +493,11 @@ struct RangeSearchResultHandler {
338
493
  }
339
494
 
340
495
  /// add one result for query i
341
- void add_result(T dis, TI idx) {
342
- if (C::cmp(radius, dis)) {
496
+ bool add_result(T dis, TI idx) final {
497
+ if (C::cmp(threshold, dis)) {
343
498
  qr->add(dis, idx);
344
499
  }
500
+ return false;
345
501
  }
346
502
 
347
503
  /// series of results for query i is done
@@ -356,16 +512,14 @@ struct RangeSearchResultHandler {
356
512
  * API for multiple results (called from 1 thread)
357
513
  ******************************************************/
358
514
 
359
- size_t i0, i1;
360
-
361
515
  std::vector<RangeSearchPartialResult*> partial_results;
362
516
  std::vector<size_t> j0s;
363
517
  int pr = 0;
364
518
 
365
519
  /// begin
366
- void begin_multiple(size_t i0, size_t i1) {
367
- this->i0 = i0;
368
- this->i1 = i1;
520
+ void begin_multiple(size_t i0_2, size_t i1_2) {
521
+ this->i0 = i0_2;
522
+ this->i1 = i1_2;
369
523
  }
370
524
 
371
525
  /// add results for query i0..i1 and j0..j1
@@ -404,109 +558,11 @@ struct RangeSearchResultHandler {
404
558
  }
405
559
  }
406
560
 
407
- void end_multiple() {}
408
-
409
- ~RangeSearchResultHandler() {
561
+ ~RangeSearchBlockResultHandler() {
410
562
  if (partial_results.size() > 0) {
411
563
  RangeSearchPartialResult::merge(partial_results);
412
564
  }
413
565
  }
414
566
  };
415
567
 
416
- /*****************************************************************
417
- * Single best result handler.
418
- * Tracks the only best result, thus avoiding storing
419
- * some temporary data in memory.
420
- *****************************************************************/
421
-
422
- template <class C>
423
- struct SingleBestResultHandler {
424
- using T = typename C::T;
425
- using TI = typename C::TI;
426
-
427
- int nq;
428
- // contains exactly nq elements
429
- T* dis_tab;
430
- // contains exactly nq elements
431
- TI* ids_tab;
432
-
433
- SingleBestResultHandler(size_t nq, T* dis_tab, TI* ids_tab)
434
- : nq(nq), dis_tab(dis_tab), ids_tab(ids_tab) {}
435
-
436
- struct SingleResultHandler {
437
- SingleBestResultHandler& hr;
438
-
439
- T min_dis;
440
- TI min_idx;
441
- size_t current_idx = 0;
442
-
443
- SingleResultHandler(SingleBestResultHandler& hr) : hr(hr) {}
444
-
445
- /// begin results for query # i
446
- void begin(const size_t current_idx) {
447
- this->current_idx = current_idx;
448
- min_dis = HUGE_VALF;
449
- min_idx = 0;
450
- }
451
-
452
- /// add one result for query i
453
- void add_result(T dis, TI idx) {
454
- if (C::cmp(min_dis, dis)) {
455
- min_dis = dis;
456
- min_idx = idx;
457
- }
458
- }
459
-
460
- /// series of results for query i is done
461
- void end() {
462
- hr.dis_tab[current_idx] = min_dis;
463
- hr.ids_tab[current_idx] = min_idx;
464
- }
465
- };
466
-
467
- size_t i0, i1;
468
-
469
- /// begin
470
- void begin_multiple(size_t i0, size_t i1) {
471
- this->i0 = i0;
472
- this->i1 = i1;
473
-
474
- for (size_t i = i0; i < i1; i++) {
475
- this->dis_tab[i] = HUGE_VALF;
476
- }
477
- }
478
-
479
- /// add results for query i0..i1 and j0..j1
480
- void add_results(size_t j0, size_t j1, const T* dis_tab) {
481
- for (int64_t i = i0; i < i1; i++) {
482
- const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0;
483
-
484
- auto& min_distance = this->dis_tab[i];
485
- auto& min_index = this->ids_tab[i];
486
-
487
- for (size_t j = j0; j < j1; j++) {
488
- const T distance = dis_tab_i[j];
489
-
490
- if (C::cmp(min_distance, distance)) {
491
- min_distance = distance;
492
- min_index = j;
493
- }
494
- }
495
- }
496
- }
497
-
498
- void add_result(const size_t i, const T dis, const TI idx) {
499
- auto& min_distance = this->dis_tab[i];
500
- auto& min_index = this->ids_tab[i];
501
-
502
- if (C::cmp(min_distance, dis)) {
503
- min_distance = dis;
504
- min_index = idx;
505
- }
506
- }
507
-
508
- /// series of results for queries i0..i1 is done
509
- void end_multiple() {}
510
- };
511
-
512
568
  } // namespace faiss