faiss 0.3.0 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (171) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
@@ -17,23 +17,170 @@
17
17
 
18
18
  namespace faiss {
19
19
 
20
+ /*****************************************************************
21
+ * The classes below are intended to be used as template arguments
22
+ * they handle results for batches of queries (size nq).
23
+ * They can be called in two ways:
24
+ * - by instantiating a SingleResultHandler that tracks results for a single
25
+ * query
26
+ * - with begin_multiple/add_results/end_multiple calls where a whole block of
27
+ * results is submitted
28
+ * All classes are templated on C, which defines whether the min or the max of
29
+ * results is to be kept.
30
+ *****************************************************************/
31
+
32
+ template <class C>
33
+ struct BlockResultHandler {
34
+ size_t nq; // number of queries for which we search
35
+
36
+ explicit BlockResultHandler(size_t nq) : nq(nq) {}
37
+
38
+ // currently handled query range
39
+ size_t i0 = 0, i1 = 0;
40
+
41
+ // start collecting results for queries [i0, i1)
42
+ virtual void begin_multiple(size_t i0_2, size_t i1_2) {
43
+ this->i0 = i0_2;
44
+ this->i1 = i1_2;
45
+ }
46
+
47
+ // add results for queries [i0, i1) and database [j0, j1)
48
+ virtual void add_results(size_t, size_t, const typename C::T*) {}
49
+
50
+ // series of results for queries i0..i1 is done
51
+ virtual void end_multiple() {}
52
+
53
+ virtual ~BlockResultHandler() {}
54
+ };
55
+
56
+ // handler for a single query
57
+ template <class C>
58
+ struct ResultHandler {
59
+ // if not better than threshold, then not necessary to call add_result
60
+ typename C::T threshold = 0;
61
+
62
+ // return whether threshold was updated
63
+ virtual bool add_result(typename C::T dis, typename C::TI idx) = 0;
64
+
65
+ virtual ~ResultHandler() {}
66
+ };
67
+
68
+ /*****************************************************************
69
+ * Single best result handler.
70
+ * Tracks the only best result, thus avoiding storing
71
+ * some temporary data in memory.
72
+ *****************************************************************/
73
+
74
+ template <class C>
75
+ struct Top1BlockResultHandler : BlockResultHandler<C> {
76
+ using T = typename C::T;
77
+ using TI = typename C::TI;
78
+ using BlockResultHandler<C>::i0;
79
+ using BlockResultHandler<C>::i1;
80
+
81
+ // contains exactly nq elements
82
+ T* dis_tab;
83
+ // contains exactly nq elements
84
+ TI* ids_tab;
85
+
86
+ Top1BlockResultHandler(size_t nq, T* dis_tab, TI* ids_tab)
87
+ : BlockResultHandler<C>(nq), dis_tab(dis_tab), ids_tab(ids_tab) {}
88
+
89
+ struct SingleResultHandler : ResultHandler<C> {
90
+ Top1BlockResultHandler& hr;
91
+ using ResultHandler<C>::threshold;
92
+
93
+ TI min_idx;
94
+ size_t current_idx = 0;
95
+
96
+ explicit SingleResultHandler(Top1BlockResultHandler& hr) : hr(hr) {}
97
+
98
+ /// begin results for query # i
99
+ void begin(const size_t current_idx_2) {
100
+ this->current_idx = current_idx_2;
101
+ threshold = C::neutral();
102
+ min_idx = -1;
103
+ }
104
+
105
+ /// add one result for query i
106
+ bool add_result(T dis, TI idx) final {
107
+ if (C::cmp(this->threshold, dis)) {
108
+ threshold = dis;
109
+ min_idx = idx;
110
+ return true;
111
+ }
112
+ return false;
113
+ }
114
+
115
+ /// series of results for query i is done
116
+ void end() {
117
+ hr.dis_tab[current_idx] = threshold;
118
+ hr.ids_tab[current_idx] = min_idx;
119
+ }
120
+ };
121
+
122
+ /// begin
123
+ void begin_multiple(size_t i0, size_t i1) final {
124
+ this->i0 = i0;
125
+ this->i1 = i1;
126
+
127
+ for (size_t i = i0; i < i1; i++) {
128
+ this->dis_tab[i] = C::neutral();
129
+ }
130
+ }
131
+
132
+ /// add results for query i0..i1 and j0..j1
133
+ void add_results(size_t j0, size_t j1, const T* dis_tab_2) final {
134
+ for (int64_t i = i0; i < i1; i++) {
135
+ const T* dis_tab_i = dis_tab_2 + (j1 - j0) * (i - i0) - j0;
136
+
137
+ auto& min_distance = this->dis_tab[i];
138
+ auto& min_index = this->ids_tab[i];
139
+
140
+ for (size_t j = j0; j < j1; j++) {
141
+ const T distance = dis_tab_i[j];
142
+
143
+ if (C::cmp(min_distance, distance)) {
144
+ min_distance = distance;
145
+ min_index = j;
146
+ }
147
+ }
148
+ }
149
+ }
150
+
151
+ void add_result(const size_t i, const T dis, const TI idx) {
152
+ auto& min_distance = this->dis_tab[i];
153
+ auto& min_index = this->ids_tab[i];
154
+
155
+ if (C::cmp(min_distance, dis)) {
156
+ min_distance = dis;
157
+ min_index = idx;
158
+ }
159
+ }
160
+ };
161
+
20
162
  /*****************************************************************
21
163
  * Heap based result handler
22
164
  *****************************************************************/
23
165
 
24
166
  template <class C>
25
- struct HeapResultHandler {
167
+ struct HeapBlockResultHandler : BlockResultHandler<C> {
26
168
  using T = typename C::T;
27
169
  using TI = typename C::TI;
170
+ using BlockResultHandler<C>::i0;
171
+ using BlockResultHandler<C>::i1;
28
172
 
29
- int nq;
30
173
  T* heap_dis_tab;
31
174
  TI* heap_ids_tab;
32
175
 
33
176
  int64_t k; // number of results to keep
34
177
 
35
- HeapResultHandler(size_t nq, T* heap_dis_tab, TI* heap_ids_tab, size_t k)
36
- : nq(nq),
178
+ HeapBlockResultHandler(
179
+ size_t nq,
180
+ T* heap_dis_tab,
181
+ TI* heap_ids_tab,
182
+ size_t k)
183
+ : BlockResultHandler<C>(nq),
37
184
  heap_dis_tab(heap_dis_tab),
38
185
  heap_ids_tab(heap_ids_tab),
39
186
  k(k) {}
@@ -43,30 +190,33 @@ struct HeapResultHandler {
43
190
  * called from 1 thread)
44
191
  */
45
192
 
46
- struct SingleResultHandler {
47
- HeapResultHandler& hr;
193
+ struct SingleResultHandler : ResultHandler<C> {
194
+ HeapBlockResultHandler& hr;
195
+ using ResultHandler<C>::threshold;
48
196
  size_t k;
49
197
 
50
198
  T* heap_dis;
51
199
  TI* heap_ids;
52
- T thresh;
53
200
 
54
- SingleResultHandler(HeapResultHandler& hr) : hr(hr), k(hr.k) {}
201
+ explicit SingleResultHandler(HeapBlockResultHandler& hr)
202
+ : hr(hr), k(hr.k) {}
55
203
 
56
204
  /// begin results for query # i
57
205
  void begin(size_t i) {
58
206
  heap_dis = hr.heap_dis_tab + i * k;
59
207
  heap_ids = hr.heap_ids_tab + i * k;
60
208
  heap_heapify<C>(k, heap_dis, heap_ids);
61
- thresh = heap_dis[0];
209
+ threshold = heap_dis[0];
62
210
  }
63
211
 
64
212
  /// add one result for query i
65
- void add_result(T dis, TI idx) {
66
- if (C::cmp(heap_dis[0], dis)) {
213
+ bool add_result(T dis, TI idx) final {
214
+ if (C::cmp(threshold, dis)) {
67
215
  heap_replace_top<C>(k, heap_dis, heap_ids, dis, idx);
68
- thresh = heap_dis[0];
216
+ threshold = heap_dis[0];
217
+ return true;
69
218
  }
219
+ return false;
70
220
  }
71
221
 
72
222
  /// series of results for query i is done
@@ -79,19 +229,17 @@ struct HeapResultHandler {
79
229
  * API for multiple results (called from 1 thread)
80
230
  */
81
231
 
82
- size_t i0, i1;
83
-
84
232
  /// begin
85
- void begin_multiple(size_t i0, size_t i1) {
86
- this->i0 = i0;
87
- this->i1 = i1;
233
+ void begin_multiple(size_t i0_2, size_t i1_2) final {
234
+ this->i0 = i0_2;
235
+ this->i1 = i1_2;
88
236
  for (size_t i = i0; i < i1; i++) {
89
237
  heap_heapify<C>(k, heap_dis_tab + i * k, heap_ids_tab + i * k);
90
238
  }
91
239
  }
92
240
 
93
241
  /// add results for query i0..i1 and j0..j1
94
- void add_results(size_t j0, size_t j1, const T* dis_tab) {
242
+ void add_results(size_t j0, size_t j1, const T* dis_tab) final {
95
243
  #pragma omp parallel for
96
244
  for (int64_t i = i0; i < i1; i++) {
97
245
  T* heap_dis = heap_dis_tab + i * k;
@@ -109,7 +257,7 @@ struct HeapResultHandler {
109
257
  }
110
258
 
111
259
  /// series of results for queries i0..i1 is done
112
- void end_multiple() {
260
+ void end_multiple() final {
113
261
  // maybe parallel for
114
262
  for (size_t i = i0; i < i1; i++) {
115
263
  heap_reorder<C>(k, heap_dis_tab + i * k, heap_ids_tab + i * k);
@@ -128,9 +276,10 @@ struct HeapResultHandler {
128
276
 
129
277
  /// Reservoir for a single query
130
278
  template <class C>
131
- struct ReservoirTopN {
279
+ struct ReservoirTopN : ResultHandler<C> {
132
280
  using T = typename C::T;
133
281
  using TI = typename C::TI;
282
+ using ResultHandler<C>::threshold;
134
283
 
135
284
  T* vals;
136
285
  TI* ids;
@@ -139,8 +288,6 @@ struct ReservoirTopN {
139
288
  size_t n; // number of requested elements
140
289
  size_t capacity; // size of storage
141
290
 
142
- T threshold; // current threshold
143
-
144
291
  ReservoirTopN() {}
145
292
 
146
293
  ReservoirTopN(size_t n, size_t capacity, T* vals, TI* ids)
@@ -149,15 +296,22 @@ struct ReservoirTopN {
149
296
  threshold = C::neutral();
150
297
  }
151
298
 
152
- void add(T val, TI id) {
299
+ bool add_result(T val, TI id) final {
300
+ bool updated_threshold = false;
153
301
  if (C::cmp(threshold, val)) {
154
302
  if (i == capacity) {
155
303
  shrink_fuzzy();
304
+ updated_threshold = true;
156
305
  }
157
306
  vals[i] = val;
158
307
  ids[i] = id;
159
308
  i++;
160
309
  }
310
+ return updated_threshold;
311
+ }
312
+
313
+ void add(T val, TI id) {
314
+ add_result(val, id);
161
315
  }
162
316
 
163
317
  // reduce storage from capacity to anything
@@ -169,6 +323,11 @@ struct ReservoirTopN {
169
323
  vals, ids, capacity, n, (capacity + n) / 2, &i);
170
324
  }
171
325
 
326
+ void shrink() {
327
+ threshold = partition<C>(vals, ids, i, n);
328
+ i = n;
329
+ }
330
+
172
331
  void to_result(T* heap_dis, TI* heap_ids) const {
173
332
  for (int j = 0; j < std::min(i, n); j++) {
174
333
  heap_push<C>(j + 1, heap_dis, heap_ids, vals[j], ids[j]);
@@ -187,23 +346,24 @@ struct ReservoirTopN {
187
346
  };
188
347
 
189
348
  template <class C>
190
- struct ReservoirResultHandler {
349
+ struct ReservoirBlockResultHandler : BlockResultHandler<C> {
191
350
  using T = typename C::T;
192
351
  using TI = typename C::TI;
352
+ using BlockResultHandler<C>::i0;
353
+ using BlockResultHandler<C>::i1;
193
354
 
194
- int nq;
195
355
  T* heap_dis_tab;
196
356
  TI* heap_ids_tab;
197
357
 
198
358
  int64_t k; // number of results to keep
199
359
  size_t capacity; // capacity of the reservoirs
200
360
 
201
- ReservoirResultHandler(
361
+ ReservoirBlockResultHandler(
202
362
  size_t nq,
203
363
  T* heap_dis_tab,
204
364
  TI* heap_ids_tab,
205
365
  size_t k)
206
- : nq(nq),
366
+ : BlockResultHandler<C>(nq),
207
367
  heap_dis_tab(heap_dis_tab),
208
368
  heap_ids_tab(heap_ids_tab),
209
369
  k(k) {
@@ -216,40 +376,34 @@ struct ReservoirResultHandler {
216
376
  * called from 1 thread)
217
377
  */
218
378
 
219
- struct SingleResultHandler {
220
- ReservoirResultHandler& hr;
379
+ struct SingleResultHandler : ReservoirTopN<C> {
380
+ ReservoirBlockResultHandler& hr;
221
381
 
222
382
  std::vector<T> reservoir_dis;
223
383
  std::vector<TI> reservoir_ids;
224
- ReservoirTopN<C> res1;
225
384
 
226
- SingleResultHandler(ReservoirResultHandler& hr)
227
- : hr(hr),
228
- reservoir_dis(hr.capacity),
229
- reservoir_ids(hr.capacity) {}
385
+ explicit SingleResultHandler(ReservoirBlockResultHandler& hr)
386
+ : ReservoirTopN<C>(hr.k, hr.capacity, nullptr, nullptr),
387
+ hr(hr) {}
230
388
 
231
- size_t i;
389
+ size_t qno;
232
390
 
233
391
  /// begin results for query # i
234
- void begin(size_t i) {
235
- res1 = ReservoirTopN<C>(
236
- hr.k,
237
- hr.capacity,
238
- reservoir_dis.data(),
239
- reservoir_ids.data());
240
- this->i = i;
392
+ void begin(size_t qno_2) {
393
+ reservoir_dis.resize(hr.capacity);
394
+ reservoir_ids.resize(hr.capacity);
395
+ this->vals = reservoir_dis.data();
396
+ this->ids = reservoir_ids.data();
397
+ this->i = 0; // size of reservoir
398
+ this->threshold = C::neutral();
399
+ this->qno = qno_2;
241
400
  }
242
401
 
243
- /// add one result for query i
244
- void add_result(T dis, TI idx) {
245
- res1.add(dis, idx);
246
- }
247
-
248
- /// series of results for query i is done
402
+ /// series of results for query qno is done
249
403
  void end() {
250
- T* heap_dis = hr.heap_dis_tab + i * hr.k;
251
- TI* heap_ids = hr.heap_ids_tab + i * hr.k;
252
- res1.to_result(heap_dis, heap_ids);
404
+ T* heap_dis = hr.heap_dis_tab + qno * hr.k;
405
+ TI* heap_ids = hr.heap_ids_tab + qno * hr.k;
406
+ this->to_result(heap_dis, heap_ids);
253
407
  }
254
408
  };
255
409
 
@@ -257,44 +411,41 @@ struct ReservoirResultHandler {
257
411
  * API for multiple results (called from 1 thread)
258
412
  */
259
413
 
260
- size_t i0, i1;
261
-
262
414
  std::vector<T> reservoir_dis;
263
415
  std::vector<TI> reservoir_ids;
264
416
  std::vector<ReservoirTopN<C>> reservoirs;
265
417
 
266
418
  /// begin
267
- void begin_multiple(size_t i0, size_t i1) {
268
- this->i0 = i0;
269
- this->i1 = i1;
419
+ void begin_multiple(size_t i0_2, size_t i1_2) {
420
+ this->i0 = i0_2;
421
+ this->i1 = i1_2;
270
422
  reservoir_dis.resize((i1 - i0) * capacity);
271
423
  reservoir_ids.resize((i1 - i0) * capacity);
272
424
  reservoirs.clear();
273
- for (size_t i = i0; i < i1; i++) {
425
+ for (size_t i = i0_2; i < i1_2; i++) {
274
426
  reservoirs.emplace_back(
275
427
  k,
276
428
  capacity,
277
- reservoir_dis.data() + (i - i0) * capacity,
278
- reservoir_ids.data() + (i - i0) * capacity);
429
+ reservoir_dis.data() + (i - i0_2) * capacity,
430
+ reservoir_ids.data() + (i - i0_2) * capacity);
279
431
  }
280
432
  }
281
433
 
282
434
  /// add results for query i0..i1 and j0..j1
283
435
  void add_results(size_t j0, size_t j1, const T* dis_tab) {
284
- // maybe parallel for
285
436
  #pragma omp parallel for
286
437
  for (int64_t i = i0; i < i1; i++) {
287
438
  ReservoirTopN<C>& reservoir = reservoirs[i - i0];
288
439
  const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0;
289
440
  for (size_t j = j0; j < j1; j++) {
290
441
  T dis = dis_tab_i[j];
291
- reservoir.add(dis, j);
442
+ reservoir.add_result(dis, j);
292
443
  }
293
444
  }
294
445
  }
295
446
 
296
447
  /// series of results for queries i0..i1 is done
297
- void end_multiple() {
448
+ void end_multiple() final {
298
449
  // maybe parallel for
299
450
  for (size_t i = i0; i < i1; i++) {
300
451
  reservoirs[i - i0].to_result(
@@ -308,29 +459,33 @@ struct ReservoirResultHandler {
308
459
  *****************************************************************/
309
460
 
310
461
  template <class C>
311
- struct RangeSearchResultHandler {
462
+ struct RangeSearchBlockResultHandler : BlockResultHandler<C> {
312
463
  using T = typename C::T;
313
464
  using TI = typename C::TI;
465
+ using BlockResultHandler<C>::i0;
466
+ using BlockResultHandler<C>::i1;
314
467
 
315
468
  RangeSearchResult* res;
316
- float radius;
469
+ T radius;
317
470
 
318
- RangeSearchResultHandler(RangeSearchResult* res, float radius)
319
- : res(res), radius(radius) {}
471
+ RangeSearchBlockResultHandler(RangeSearchResult* res, float radius)
472
+ : BlockResultHandler<C>(res->nq), res(res), radius(radius) {}
320
473
 
321
474
  /******************************************************
322
475
  * API for 1 result at a time (each SingleResultHandler is
323
476
  * called from 1 thread)
324
477
  ******************************************************/
325
478
 
326
- struct SingleResultHandler {
479
+ struct SingleResultHandler : ResultHandler<C> {
327
480
  // almost the same interface as RangeSearchResultHandler
481
+ using ResultHandler<C>::threshold;
328
482
  RangeSearchPartialResult pres;
329
- float radius;
330
483
  RangeQueryResult* qr = nullptr;
331
484
 
332
- SingleResultHandler(RangeSearchResultHandler& rh)
333
- : pres(rh.res), radius(rh.radius) {}
485
+ explicit SingleResultHandler(RangeSearchBlockResultHandler& rh)
486
+ : pres(rh.res) {
487
+ threshold = rh.radius;
488
+ }
334
489
 
335
490
  /// begin results for query # i
336
491
  void begin(size_t i) {
@@ -338,10 +493,11 @@ struct RangeSearchResultHandler {
338
493
  }
339
494
 
340
495
  /// add one result for query i
341
- void add_result(T dis, TI idx) {
342
- if (C::cmp(radius, dis)) {
496
+ bool add_result(T dis, TI idx) final {
497
+ if (C::cmp(threshold, dis)) {
343
498
  qr->add(dis, idx);
344
499
  }
500
+ return false;
345
501
  }
346
502
 
347
503
  /// series of results for query i is done
@@ -356,16 +512,14 @@ struct RangeSearchResultHandler {
356
512
  * API for multiple results (called from 1 thread)
357
513
  ******************************************************/
358
514
 
359
- size_t i0, i1;
360
-
361
515
  std::vector<RangeSearchPartialResult*> partial_results;
362
516
  std::vector<size_t> j0s;
363
517
  int pr = 0;
364
518
 
365
519
  /// begin
366
- void begin_multiple(size_t i0, size_t i1) {
367
- this->i0 = i0;
368
- this->i1 = i1;
520
+ void begin_multiple(size_t i0_2, size_t i1_2) {
521
+ this->i0 = i0_2;
522
+ this->i1 = i1_2;
369
523
  }
370
524
 
371
525
  /// add results for query i0..i1 and j0..j1
@@ -404,109 +558,11 @@ struct RangeSearchResultHandler {
404
558
  }
405
559
  }
406
560
 
407
- void end_multiple() {}
408
-
409
- ~RangeSearchResultHandler() {
561
+ ~RangeSearchBlockResultHandler() {
410
562
  if (partial_results.size() > 0) {
411
563
  RangeSearchPartialResult::merge(partial_results);
412
564
  }
413
565
  }
414
566
  };
415
567
 
416
- /*****************************************************************
417
- * Single best result handler.
418
- * Tracks the only best result, thus avoiding storing
419
- * some temporary data in memory.
420
- *****************************************************************/
421
-
422
- template <class C>
423
- struct SingleBestResultHandler {
424
- using T = typename C::T;
425
- using TI = typename C::TI;
426
-
427
- int nq;
428
- // contains exactly nq elements
429
- T* dis_tab;
430
- // contains exactly nq elements
431
- TI* ids_tab;
432
-
433
- SingleBestResultHandler(size_t nq, T* dis_tab, TI* ids_tab)
434
- : nq(nq), dis_tab(dis_tab), ids_tab(ids_tab) {}
435
-
436
- struct SingleResultHandler {
437
- SingleBestResultHandler& hr;
438
-
439
- T min_dis;
440
- TI min_idx;
441
- size_t current_idx = 0;
442
-
443
- SingleResultHandler(SingleBestResultHandler& hr) : hr(hr) {}
444
-
445
- /// begin results for query # i
446
- void begin(const size_t current_idx) {
447
- this->current_idx = current_idx;
448
- min_dis = HUGE_VALF;
449
- min_idx = 0;
450
- }
451
-
452
- /// add one result for query i
453
- void add_result(T dis, TI idx) {
454
- if (C::cmp(min_dis, dis)) {
455
- min_dis = dis;
456
- min_idx = idx;
457
- }
458
- }
459
-
460
- /// series of results for query i is done
461
- void end() {
462
- hr.dis_tab[current_idx] = min_dis;
463
- hr.ids_tab[current_idx] = min_idx;
464
- }
465
- };
466
-
467
- size_t i0, i1;
468
-
469
- /// begin
470
- void begin_multiple(size_t i0, size_t i1) {
471
- this->i0 = i0;
472
- this->i1 = i1;
473
-
474
- for (size_t i = i0; i < i1; i++) {
475
- this->dis_tab[i] = HUGE_VALF;
476
- }
477
- }
478
-
479
- /// add results for query i0..i1 and j0..j1
480
- void add_results(size_t j0, size_t j1, const T* dis_tab) {
481
- for (int64_t i = i0; i < i1; i++) {
482
- const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0;
483
-
484
- auto& min_distance = this->dis_tab[i];
485
- auto& min_index = this->ids_tab[i];
486
-
487
- for (size_t j = j0; j < j1; j++) {
488
- const T distance = dis_tab_i[j];
489
-
490
- if (C::cmp(min_distance, distance)) {
491
- min_distance = distance;
492
- min_index = j;
493
- }
494
- }
495
- }
496
- }
497
-
498
- void add_result(const size_t i, const T dis, const TI idx) {
499
- auto& min_distance = this->dis_tab[i];
500
- auto& min_index = this->ids_tab[i];
501
-
502
- if (C::cmp(min_distance, dis)) {
503
- min_distance = dis;
504
- min_index = idx;
505
- }
506
- }
507
-
508
- /// series of results for queries i0..i1 is done
509
- void end_multiple() {}
510
- };
511
-
512
568
  } // namespace faiss