faiss 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +1 -2
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -22
  13. data/vendor/faiss/faiss/Clustering.h +40 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +26 -12
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +40 -10
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
  34. data/vendor/faiss/faiss/IndexHNSW.h +62 -49
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
  38. data/vendor/faiss/faiss/IndexIVF.h +46 -6
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  59. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  60. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
  61. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  62. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  63. data/vendor/faiss/faiss/IndexNSG.h +11 -11
  64. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  65. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  66. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  67. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  68. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  69. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  70. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  71. data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
  72. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  73. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  76. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  77. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  78. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  79. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  81. data/vendor/faiss/faiss/MetricType.h +7 -2
  82. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  83. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  84. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  85. data/vendor/faiss/faiss/clone_index.h +3 -0
  86. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  87. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  88. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  89. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  90. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
  91. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
  96. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  107. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
  108. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  110. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  111. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
  112. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  113. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  114. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  115. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  116. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  117. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  118. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  119. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  121. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
  123. data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  125. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  126. data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
  127. data/vendor/faiss/faiss/impl/HNSW.h +52 -30
  128. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
  130. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
  132. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  133. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  134. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
  138. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
  144. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
  145. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  146. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  147. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  148. data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
  149. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  150. data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
  151. data/vendor/faiss/faiss/impl/io.cpp +23 -15
  152. data/vendor/faiss/faiss/impl/io.h +4 -4
  153. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  154. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  155. data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  159. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
  160. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  161. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  162. data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
  163. data/vendor/faiss/faiss/index_factory.cpp +41 -20
  164. data/vendor/faiss/faiss/index_io.h +12 -5
  165. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  166. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  167. data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
  168. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
  169. data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
  170. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  172. data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
  173. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  174. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  175. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  176. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  177. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  178. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  179. data/vendor/faiss/faiss/utils/distances.cpp +147 -123
  180. data/vendor/faiss/faiss/utils/distances.h +86 -9
  181. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  182. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  183. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  184. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  185. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  186. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  187. data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
  188. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  189. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  190. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  191. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  192. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  193. data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
  194. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  195. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  196. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  197. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
  198. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
  199. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  200. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  201. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  202. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  203. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  204. data/vendor/faiss/faiss/utils/random.h +25 -0
  205. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  206. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  207. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  208. data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
  209. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  210. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  211. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  212. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +120 -7
  214. data/vendor/faiss/faiss/utils/utils.h +60 -20
  215. metadata +23 -4
  216. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -12,28 +12,196 @@
12
12
  #pragma once
13
13
 
14
14
  #include <faiss/impl/AuxIndexStructures.h>
15
+ #include <faiss/impl/FaissException.h>
16
+ #include <faiss/impl/IDSelector.h>
15
17
  #include <faiss/utils/Heap.h>
16
18
  #include <faiss/utils/partitioning.h>
17
19
 
20
+ #include <algorithm>
21
+ #include <iostream>
22
+
18
23
  namespace faiss {
19
24
 
20
25
  /*****************************************************************
21
- * Heap based result handler
26
+ * The classes below are intended to be used as template arguments.
27
+ * They handle results for batches of queries (size nq).
28
+ * They can be called in two ways:
29
+ * - by instantiating a SingleResultHandler that tracks results for a single
30
+ * query
31
+ * - with begin_multiple/add_results/end_multiple calls where a whole block of
32
+ * results is submitted
33
+ * All classes are templated on C, which defines whether the min or the max of
34
+ * results is to be kept, and on use_sel, so that the code paths with / without
35
+ * a selector can be separated at compile time.
22
36
  *****************************************************************/
23
37
 
38
+ template <class C, bool use_sel = false>
39
+ struct BlockResultHandler {
40
+ size_t nq; // number of queries for which we search
41
+ const IDSelector* sel;
42
+
43
+ explicit BlockResultHandler(size_t nq, const IDSelector* sel = nullptr)
44
+ : nq(nq), sel(sel) {
45
+ assert(!use_sel || sel);
46
+ }
47
+
48
+ // currently handled query range
49
+ size_t i0 = 0, i1 = 0;
50
+
51
+ // start collecting results for queries [i0, i1)
52
+ virtual void begin_multiple(size_t i0_2, size_t i1_2) {
53
+ this->i0 = i0_2;
54
+ this->i1 = i1_2;
55
+ }
56
+
57
+ // add results for queries [i0, i1) and database [j0, j1)
58
+ virtual void add_results(size_t, size_t, const typename C::T*) {}
59
+
60
+ // series of results for queries i0..i1 is done
61
+ virtual void end_multiple() {}
62
+
63
+ virtual ~BlockResultHandler() {}
64
+
65
+ bool is_in_selection(idx_t i) const {
66
+ return !use_sel || sel->is_member(i);
67
+ }
68
+ };
69
+
70
+ // handler for a single query
24
71
  template <class C>
25
- struct HeapResultHandler {
72
+ struct ResultHandler {
73
+ // if not better than threshold, then not necessary to call add_result
74
+ typename C::T threshold = C::neutral();
75
+
76
+ // return whether threshold was updated
77
+ virtual bool add_result(typename C::T dis, typename C::TI idx) = 0;
78
+
79
+ virtual ~ResultHandler() {}
80
+ };
81
+
82
+ /*****************************************************************
83
+ * Single best result handler.
84
+ * Tracks the only best result, thus avoiding storing
85
+ * some temporary data in memory.
86
+ *****************************************************************/
87
+
88
+ template <class C, bool use_sel = false>
89
+ struct Top1BlockResultHandler : BlockResultHandler<C, use_sel> {
90
+ using T = typename C::T;
91
+ using TI = typename C::TI;
92
+ using BlockResultHandler<C, use_sel>::i0;
93
+ using BlockResultHandler<C, use_sel>::i1;
94
+
95
+ // contains exactly nq elements
96
+ T* dis_tab;
97
+ // contains exactly nq elements
98
+ TI* ids_tab;
99
+
100
+ Top1BlockResultHandler(
101
+ size_t nq,
102
+ T* dis_tab,
103
+ TI* ids_tab,
104
+ const IDSelector* sel = nullptr)
105
+ : BlockResultHandler<C, use_sel>(nq, sel),
106
+ dis_tab(dis_tab),
107
+ ids_tab(ids_tab) {}
108
+
109
+ struct SingleResultHandler : ResultHandler<C> {
110
+ Top1BlockResultHandler& hr;
111
+ using ResultHandler<C>::threshold;
112
+
113
+ TI min_idx;
114
+ size_t current_idx = 0;
115
+
116
+ explicit SingleResultHandler(Top1BlockResultHandler& hr) : hr(hr) {}
117
+
118
+ /// begin results for query # i
119
+ void begin(const size_t current_idx_2) {
120
+ this->current_idx = current_idx_2;
121
+ threshold = C::neutral();
122
+ min_idx = -1;
123
+ }
124
+
125
+ /// add one result for query i
126
+ bool add_result(T dis, TI idx) final {
127
+ if (C::cmp(this->threshold, dis)) {
128
+ threshold = dis;
129
+ min_idx = idx;
130
+ return true;
131
+ }
132
+ return false;
133
+ }
134
+
135
+ /// series of results for query i is done
136
+ void end() {
137
+ hr.dis_tab[current_idx] = threshold;
138
+ hr.ids_tab[current_idx] = min_idx;
139
+ }
140
+ };
141
+
142
+ /// begin
143
+ void begin_multiple(size_t i0, size_t i1) final {
144
+ this->i0 = i0;
145
+ this->i1 = i1;
146
+
147
+ for (size_t i = i0; i < i1; i++) {
148
+ this->dis_tab[i] = C::neutral();
149
+ }
150
+ }
151
+
152
+ /// add results for query i0..i1 and j0..j1
153
+ void add_results(size_t j0, size_t j1, const T* dis_tab_2) final {
154
+ for (int64_t i = i0; i < i1; i++) {
155
+ const T* dis_tab_i = dis_tab_2 + (j1 - j0) * (i - i0) - j0;
156
+
157
+ auto& min_distance = this->dis_tab[i];
158
+ auto& min_index = this->ids_tab[i];
159
+
160
+ for (size_t j = j0; j < j1; j++) {
161
+ const T distance = dis_tab_i[j];
162
+
163
+ if (C::cmp(min_distance, distance)) {
164
+ min_distance = distance;
165
+ min_index = j;
166
+ }
167
+ }
168
+ }
169
+ }
170
+
171
+ void add_result(const size_t i, const T dis, const TI idx) {
172
+ auto& min_distance = this->dis_tab[i];
173
+ auto& min_index = this->ids_tab[i];
174
+
175
+ if (C::cmp(min_distance, dis)) {
176
+ min_distance = dis;
177
+ min_index = idx;
178
+ }
179
+ }
180
+ };
181
+
182
+ /*****************************************************************
183
+ * Heap based result handler
184
+ *****************************************************************/
185
+
186
+ template <class C, bool use_sel = false>
187
+ struct HeapBlockResultHandler : BlockResultHandler<C, use_sel> {
26
188
  using T = typename C::T;
27
189
  using TI = typename C::TI;
190
+ using BlockResultHandler<C, use_sel>::i0;
191
+ using BlockResultHandler<C, use_sel>::i1;
28
192
 
29
- int nq;
30
193
  T* heap_dis_tab;
31
194
  TI* heap_ids_tab;
32
195
 
33
196
  int64_t k; // number of results to keep
34
197
 
35
- HeapResultHandler(size_t nq, T* heap_dis_tab, TI* heap_ids_tab, size_t k)
36
- : nq(nq),
198
+ HeapBlockResultHandler(
199
+ size_t nq,
200
+ T* heap_dis_tab,
201
+ TI* heap_ids_tab,
202
+ size_t k,
203
+ const IDSelector* sel = nullptr)
204
+ : BlockResultHandler<C, use_sel>(nq, sel),
37
205
  heap_dis_tab(heap_dis_tab),
38
206
  heap_ids_tab(heap_ids_tab),
39
207
  k(k) {}
@@ -43,30 +211,33 @@ struct HeapResultHandler {
43
211
  * called from 1 thread)
44
212
  */
45
213
 
46
- struct SingleResultHandler {
47
- HeapResultHandler& hr;
214
+ struct SingleResultHandler : ResultHandler<C> {
215
+ HeapBlockResultHandler& hr;
216
+ using ResultHandler<C>::threshold;
48
217
  size_t k;
49
218
 
50
219
  T* heap_dis;
51
220
  TI* heap_ids;
52
- T thresh;
53
221
 
54
- SingleResultHandler(HeapResultHandler& hr) : hr(hr), k(hr.k) {}
222
+ explicit SingleResultHandler(HeapBlockResultHandler& hr)
223
+ : hr(hr), k(hr.k) {}
55
224
 
56
225
  /// begin results for query # i
57
226
  void begin(size_t i) {
58
227
  heap_dis = hr.heap_dis_tab + i * k;
59
228
  heap_ids = hr.heap_ids_tab + i * k;
60
229
  heap_heapify<C>(k, heap_dis, heap_ids);
61
- thresh = heap_dis[0];
230
+ threshold = heap_dis[0];
62
231
  }
63
232
 
64
233
  /// add one result for query i
65
- void add_result(T dis, TI idx) {
66
- if (C::cmp(heap_dis[0], dis)) {
234
+ bool add_result(T dis, TI idx) final {
235
+ if (C::cmp(threshold, dis)) {
67
236
  heap_replace_top<C>(k, heap_dis, heap_ids, dis, idx);
68
- thresh = heap_dis[0];
237
+ threshold = heap_dis[0];
238
+ return true;
69
239
  }
240
+ return false;
70
241
  }
71
242
 
72
243
  /// series of results for query i is done
@@ -79,19 +250,17 @@ struct HeapResultHandler {
79
250
  * API for multiple results (called from 1 thread)
80
251
  */
81
252
 
82
- size_t i0, i1;
83
-
84
253
  /// begin
85
- void begin_multiple(size_t i0, size_t i1) {
86
- this->i0 = i0;
87
- this->i1 = i1;
254
+ void begin_multiple(size_t i0_2, size_t i1_2) final {
255
+ this->i0 = i0_2;
256
+ this->i1 = i1_2;
88
257
  for (size_t i = i0; i < i1; i++) {
89
258
  heap_heapify<C>(k, heap_dis_tab + i * k, heap_ids_tab + i * k);
90
259
  }
91
260
  }
92
261
 
93
262
  /// add results for query i0..i1 and j0..j1
94
- void add_results(size_t j0, size_t j1, const T* dis_tab) {
263
+ void add_results(size_t j0, size_t j1, const T* dis_tab) final {
95
264
  #pragma omp parallel for
96
265
  for (int64_t i = i0; i < i1; i++) {
97
266
  T* heap_dis = heap_dis_tab + i * k;
@@ -109,7 +278,7 @@ struct HeapResultHandler {
109
278
  }
110
279
 
111
280
  /// series of results for queries i0..i1 is done
112
- void end_multiple() {
281
+ void end_multiple() final {
113
282
  // maybe parallel for
114
283
  for (size_t i = i0; i < i1; i++) {
115
284
  heap_reorder<C>(k, heap_dis_tab + i * k, heap_ids_tab + i * k);
@@ -128,9 +297,10 @@ struct HeapResultHandler {
128
297
 
129
298
  /// Reservoir for a single query
130
299
  template <class C>
131
- struct ReservoirTopN {
300
+ struct ReservoirTopN : ResultHandler<C> {
132
301
  using T = typename C::T;
133
302
  using TI = typename C::TI;
303
+ using ResultHandler<C>::threshold;
134
304
 
135
305
  T* vals;
136
306
  TI* ids;
@@ -139,8 +309,6 @@ struct ReservoirTopN {
139
309
  size_t n; // number of requested elements
140
310
  size_t capacity; // size of storage
141
311
 
142
- T threshold; // current threshold
143
-
144
312
  ReservoirTopN() {}
145
313
 
146
314
  ReservoirTopN(size_t n, size_t capacity, T* vals, TI* ids)
@@ -149,15 +317,22 @@ struct ReservoirTopN {
149
317
  threshold = C::neutral();
150
318
  }
151
319
 
152
- void add(T val, TI id) {
320
+ bool add_result(T val, TI id) final {
321
+ bool updated_threshold = false;
153
322
  if (C::cmp(threshold, val)) {
154
323
  if (i == capacity) {
155
324
  shrink_fuzzy();
325
+ updated_threshold = true;
156
326
  }
157
327
  vals[i] = val;
158
328
  ids[i] = id;
159
329
  i++;
160
330
  }
331
+ return updated_threshold;
332
+ }
333
+
334
+ void add(T val, TI id) {
335
+ add_result(val, id);
161
336
  }
162
337
 
163
338
  // reduce storage from capacity to anything
@@ -169,6 +344,11 @@ struct ReservoirTopN {
169
344
  vals, ids, capacity, n, (capacity + n) / 2, &i);
170
345
  }
171
346
 
347
+ void shrink() {
348
+ threshold = partition<C>(vals, ids, i, n);
349
+ i = n;
350
+ }
351
+
172
352
  void to_result(T* heap_dis, TI* heap_ids) const {
173
353
  for (int j = 0; j < std::min(i, n); j++) {
174
354
  heap_push<C>(j + 1, heap_dis, heap_ids, vals[j], ids[j]);
@@ -186,24 +366,26 @@ struct ReservoirTopN {
186
366
  }
187
367
  };
188
368
 
189
- template <class C>
190
- struct ReservoirResultHandler {
369
+ template <class C, bool use_sel = false>
370
+ struct ReservoirBlockResultHandler : BlockResultHandler<C, use_sel> {
191
371
  using T = typename C::T;
192
372
  using TI = typename C::TI;
373
+ using BlockResultHandler<C, use_sel>::i0;
374
+ using BlockResultHandler<C, use_sel>::i1;
193
375
 
194
- int nq;
195
376
  T* heap_dis_tab;
196
377
  TI* heap_ids_tab;
197
378
 
198
379
  int64_t k; // number of results to keep
199
380
  size_t capacity; // capacity of the reservoirs
200
381
 
201
- ReservoirResultHandler(
382
+ ReservoirBlockResultHandler(
202
383
  size_t nq,
203
384
  T* heap_dis_tab,
204
385
  TI* heap_ids_tab,
205
- size_t k)
206
- : nq(nq),
386
+ size_t k,
387
+ const IDSelector* sel = nullptr)
388
+ : BlockResultHandler<C, use_sel>(nq, sel),
207
389
  heap_dis_tab(heap_dis_tab),
208
390
  heap_ids_tab(heap_ids_tab),
209
391
  k(k) {
@@ -216,40 +398,34 @@ struct ReservoirResultHandler {
216
398
  * called from 1 thread)
217
399
  */
218
400
 
219
- struct SingleResultHandler {
220
- ReservoirResultHandler& hr;
401
+ struct SingleResultHandler : ReservoirTopN<C> {
402
+ ReservoirBlockResultHandler& hr;
221
403
 
222
404
  std::vector<T> reservoir_dis;
223
405
  std::vector<TI> reservoir_ids;
224
- ReservoirTopN<C> res1;
225
406
 
226
- SingleResultHandler(ReservoirResultHandler& hr)
227
- : hr(hr),
228
- reservoir_dis(hr.capacity),
229
- reservoir_ids(hr.capacity) {}
407
+ explicit SingleResultHandler(ReservoirBlockResultHandler& hr)
408
+ : ReservoirTopN<C>(hr.k, hr.capacity, nullptr, nullptr),
409
+ hr(hr) {}
230
410
 
231
- size_t i;
411
+ size_t qno;
232
412
 
233
413
  /// begin results for query # i
234
- void begin(size_t i) {
235
- res1 = ReservoirTopN<C>(
236
- hr.k,
237
- hr.capacity,
238
- reservoir_dis.data(),
239
- reservoir_ids.data());
240
- this->i = i;
414
+ void begin(size_t qno_2) {
415
+ reservoir_dis.resize(hr.capacity);
416
+ reservoir_ids.resize(hr.capacity);
417
+ this->vals = reservoir_dis.data();
418
+ this->ids = reservoir_ids.data();
419
+ this->i = 0; // size of reservoir
420
+ this->threshold = C::neutral();
421
+ this->qno = qno_2;
241
422
  }
242
423
 
243
- /// add one result for query i
244
- void add_result(T dis, TI idx) {
245
- res1.add(dis, idx);
246
- }
247
-
248
- /// series of results for query i is done
424
+ /// series of results for query qno is done
249
425
  void end() {
250
- T* heap_dis = hr.heap_dis_tab + i * hr.k;
251
- TI* heap_ids = hr.heap_ids_tab + i * hr.k;
252
- res1.to_result(heap_dis, heap_ids);
426
+ T* heap_dis = hr.heap_dis_tab + qno * hr.k;
427
+ TI* heap_ids = hr.heap_ids_tab + qno * hr.k;
428
+ this->to_result(heap_dis, heap_ids);
253
429
  }
254
430
  };
255
431
 
@@ -257,44 +433,41 @@ struct ReservoirResultHandler {
257
433
  * API for multiple results (called from 1 thread)
258
434
  */
259
435
 
260
- size_t i0, i1;
261
-
262
436
  std::vector<T> reservoir_dis;
263
437
  std::vector<TI> reservoir_ids;
264
438
  std::vector<ReservoirTopN<C>> reservoirs;
265
439
 
266
440
  /// begin
267
- void begin_multiple(size_t i0, size_t i1) {
268
- this->i0 = i0;
269
- this->i1 = i1;
441
+ void begin_multiple(size_t i0_2, size_t i1_2) {
442
+ this->i0 = i0_2;
443
+ this->i1 = i1_2;
270
444
  reservoir_dis.resize((i1 - i0) * capacity);
271
445
  reservoir_ids.resize((i1 - i0) * capacity);
272
446
  reservoirs.clear();
273
- for (size_t i = i0; i < i1; i++) {
447
+ for (size_t i = i0_2; i < i1_2; i++) {
274
448
  reservoirs.emplace_back(
275
449
  k,
276
450
  capacity,
277
- reservoir_dis.data() + (i - i0) * capacity,
278
- reservoir_ids.data() + (i - i0) * capacity);
451
+ reservoir_dis.data() + (i - i0_2) * capacity,
452
+ reservoir_ids.data() + (i - i0_2) * capacity);
279
453
  }
280
454
  }
281
455
 
282
456
  /// add results for query i0..i1 and j0..j1
283
457
  void add_results(size_t j0, size_t j1, const T* dis_tab) {
284
- // maybe parallel for
285
458
  #pragma omp parallel for
286
459
  for (int64_t i = i0; i < i1; i++) {
287
460
  ReservoirTopN<C>& reservoir = reservoirs[i - i0];
288
461
  const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0;
289
462
  for (size_t j = j0; j < j1; j++) {
290
463
  T dis = dis_tab_i[j];
291
- reservoir.add(dis, j);
464
+ reservoir.add_result(dis, j);
292
465
  }
293
466
  }
294
467
  }
295
468
 
296
469
  /// series of results for queries i0..i1 is done
297
- void end_multiple() {
470
+ void end_multiple() final {
298
471
  // maybe parallel for
299
472
  for (size_t i = i0; i < i1; i++) {
300
473
  reservoirs[i - i0].to_result(
@@ -307,30 +480,39 @@ struct ReservoirResultHandler {
307
480
  * Result handler for range searches
308
481
  *****************************************************************/
309
482
 
310
- template <class C>
311
- struct RangeSearchResultHandler {
483
+ template <class C, bool use_sel = false>
484
+ struct RangeSearchBlockResultHandler : BlockResultHandler<C, use_sel> {
312
485
  using T = typename C::T;
313
486
  using TI = typename C::TI;
487
+ using BlockResultHandler<C, use_sel>::i0;
488
+ using BlockResultHandler<C, use_sel>::i1;
314
489
 
315
490
  RangeSearchResult* res;
316
- float radius;
491
+ T radius;
317
492
 
318
- RangeSearchResultHandler(RangeSearchResult* res, float radius)
319
- : res(res), radius(radius) {}
493
+ RangeSearchBlockResultHandler(
494
+ RangeSearchResult* res,
495
+ float radius,
496
+ const IDSelector* sel = nullptr)
497
+ : BlockResultHandler<C, use_sel>(res->nq, sel),
498
+ res(res),
499
+ radius(radius) {}
320
500
 
321
501
  /******************************************************
322
502
  * API for 1 result at a time (each SingleResultHandler is
323
503
  * called from 1 thread)
324
504
  ******************************************************/
325
505
 
326
- struct SingleResultHandler {
506
+ struct SingleResultHandler : ResultHandler<C> {
327
507
  // almost the same interface as RangeSearchResultHandler
508
+ using ResultHandler<C>::threshold;
328
509
  RangeSearchPartialResult pres;
329
- float radius;
330
510
  RangeQueryResult* qr = nullptr;
331
511
 
332
- SingleResultHandler(RangeSearchResultHandler& rh)
333
- : pres(rh.res), radius(rh.radius) {}
512
+ explicit SingleResultHandler(RangeSearchBlockResultHandler& rh)
513
+ : pres(rh.res) {
514
+ threshold = rh.radius;
515
+ }
334
516
 
335
517
  /// begin results for query # i
336
518
  void begin(size_t i) {
@@ -338,17 +520,26 @@ struct RangeSearchResultHandler {
338
520
  }
339
521
 
340
522
  /// add one result for query i
341
- void add_result(T dis, TI idx) {
342
- if (C::cmp(radius, dis)) {
523
+ bool add_result(T dis, TI idx) final {
524
+ if (C::cmp(threshold, dis)) {
343
525
  qr->add(dis, idx);
344
526
  }
527
+ return false;
345
528
  }
346
529
 
347
530
  /// series of results for query i is done
348
531
  void end() {}
349
532
 
350
533
  ~SingleResultHandler() {
351
- pres.finalize();
534
+ try {
535
+ // finalize the partial result
536
+ pres.finalize();
537
+ } catch (const faiss::FaissException& e) {
538
+ // Do nothing if allocation fails in finalizing partial results.
539
+ #ifndef NDEBUG
540
+ std::cerr << e.what() << std::endl;
541
+ #endif
542
+ }
352
543
  }
353
544
  };
354
545
 
@@ -356,16 +547,14 @@ struct RangeSearchResultHandler {
356
547
  * API for multiple results (called from 1 thread)
357
548
  ******************************************************/
358
549
 
359
- size_t i0, i1;
360
-
361
550
  std::vector<RangeSearchPartialResult*> partial_results;
362
551
  std::vector<size_t> j0s;
363
552
  int pr = 0;
364
553
 
365
554
  /// begin
366
- void begin_multiple(size_t i0, size_t i1) {
367
- this->i0 = i0;
368
- this->i1 = i1;
555
+ void begin_multiple(size_t i0_2, size_t i1_2) {
556
+ this->i0 = i0_2;
557
+ this->i1 = i1_2;
369
558
  }
370
559
 
371
560
  /// add results for query i0..i1 and j0..j1
@@ -404,109 +593,95 @@ struct RangeSearchResultHandler {
404
593
  }
405
594
  }
406
595
 
407
- void end_multiple() {}
408
-
409
- ~RangeSearchResultHandler() {
410
- if (partial_results.size() > 0) {
411
- RangeSearchPartialResult::merge(partial_results);
596
+ ~RangeSearchBlockResultHandler() {
597
+ try {
598
+ if (partial_results.size() > 0) {
599
+ RangeSearchPartialResult::merge(partial_results);
600
+ }
601
+ } catch (const faiss::FaissException& e) {
602
+ // Do nothing if allocation fails in merge.
603
+ #ifndef NDEBUG
604
+ std::cerr << e.what() << std::endl;
605
+ #endif
412
606
  }
413
607
  }
414
608
  };
415
609
 
416
610
  /*****************************************************************
417
- * Single best result handler.
418
- * Tracks the only best result, thus avoiding storing
419
- * some temporary data in memory.
611
+ * Dispatcher function to choose the right knn result handler depending on k
420
612
  *****************************************************************/
421
613
 
422
- template <class C>
423
- struct SingleBestResultHandler {
424
- using T = typename C::T;
425
- using TI = typename C::TI;
426
-
427
- int nq;
428
- // contains exactly nq elements
429
- T* dis_tab;
430
- // contains exactly nq elements
431
- TI* ids_tab;
432
-
433
- SingleBestResultHandler(size_t nq, T* dis_tab, TI* ids_tab)
434
- : nq(nq), dis_tab(dis_tab), ids_tab(ids_tab) {}
435
-
436
- struct SingleResultHandler {
437
- SingleBestResultHandler& hr;
438
-
439
- T min_dis;
440
- TI min_idx;
441
- size_t current_idx = 0;
442
-
443
- SingleResultHandler(SingleBestResultHandler& hr) : hr(hr) {}
444
-
445
- /// begin results for query # i
446
- void begin(const size_t current_idx) {
447
- this->current_idx = current_idx;
448
- min_dis = HUGE_VALF;
449
- min_idx = 0;
450
- }
451
-
452
- /// add one result for query i
453
- void add_result(T dis, TI idx) {
454
- if (C::cmp(min_dis, dis)) {
455
- min_dis = dis;
456
- min_idx = idx;
457
- }
458
- }
614
+ // declared in distances.cpp
615
+ FAISS_API extern int distance_compute_min_k_reservoir;
616
+
617
+ template <class Consumer, class... Types>
618
+ typename Consumer::T dispatch_knn_ResultHandler(
619
+ size_t nx,
620
+ float* vals,
621
+ int64_t* ids,
622
+ size_t k,
623
+ MetricType metric,
624
+ const IDSelector* sel,
625
+ Consumer& consumer,
626
+ Types... args) {
627
+ #define DISPATCH_C_SEL(C, use_sel) \
628
+ if (k == 1) { \
629
+ Top1BlockResultHandler<C, use_sel> res(nx, vals, ids, sel); \
630
+ return consumer.template f<>(res, args...); \
631
+ } else if (k < distance_compute_min_k_reservoir) { \
632
+ HeapBlockResultHandler<C, use_sel> res(nx, vals, ids, k, sel); \
633
+ return consumer.template f<>(res, args...); \
634
+ } else { \
635
+ ReservoirBlockResultHandler<C, use_sel> res(nx, vals, ids, k, sel); \
636
+ return consumer.template f<>(res, args...); \
637
+ }
459
638
 
460
- /// series of results for query i is done
461
- void end() {
462
- hr.dis_tab[current_idx] = min_dis;
463
- hr.ids_tab[current_idx] = min_idx;
639
+ if (is_similarity_metric(metric)) {
640
+ using C = CMin<float, int64_t>;
641
+ if (sel) {
642
+ DISPATCH_C_SEL(C, true);
643
+ } else {
644
+ DISPATCH_C_SEL(C, false);
464
645
  }
465
- };
466
-
467
- size_t i0, i1;
468
-
469
- /// begin
470
- void begin_multiple(size_t i0, size_t i1) {
471
- this->i0 = i0;
472
- this->i1 = i1;
473
-
474
- for (size_t i = i0; i < i1; i++) {
475
- this->dis_tab[i] = HUGE_VALF;
646
+ } else {
647
+ using C = CMax<float, int64_t>;
648
+ if (sel) {
649
+ DISPATCH_C_SEL(C, true);
650
+ } else {
651
+ DISPATCH_C_SEL(C, false);
476
652
  }
477
653
  }
478
-
479
- /// add results for query i0..i1 and j0..j1
480
- void add_results(size_t j0, size_t j1, const T* dis_tab) {
481
- for (int64_t i = i0; i < i1; i++) {
482
- const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0;
483
-
484
- auto& min_distance = this->dis_tab[i];
485
- auto& min_index = this->ids_tab[i];
486
-
487
- for (size_t j = j0; j < j1; j++) {
488
- const T distance = dis_tab_i[j];
489
-
490
- if (C::cmp(min_distance, distance)) {
491
- min_distance = distance;
492
- min_index = j;
493
- }
494
- }
654
+ #undef DISPATCH_C_SEL
655
+ }
656
+
657
+ template <class Consumer, class... Types>
658
+ typename Consumer::T dispatch_range_ResultHandler(
659
+ RangeSearchResult* res,
660
+ float radius,
661
+ MetricType metric,
662
+ const IDSelector* sel,
663
+ Consumer& consumer,
664
+ Types... args) {
665
+ #define DISPATCH_C_SEL(C, use_sel) \
666
+ RangeSearchBlockResultHandler<C, use_sel> resb(res, radius, sel); \
667
+ return consumer.template f<>(resb, args...);
668
+
669
+ if (is_similarity_metric(metric)) {
670
+ using C = CMin<float, int64_t>;
671
+ if (sel) {
672
+ DISPATCH_C_SEL(C, true);
673
+ } else {
674
+ DISPATCH_C_SEL(C, false);
495
675
  }
496
- }
497
-
498
- void add_result(const size_t i, const T dis, const TI idx) {
499
- auto& min_distance = this->dis_tab[i];
500
- auto& min_index = this->ids_tab[i];
501
-
502
- if (C::cmp(min_distance, dis)) {
503
- min_distance = dis;
504
- min_index = idx;
676
+ } else {
677
+ using C = CMax<float, int64_t>;
678
+ if (sel) {
679
+ DISPATCH_C_SEL(C, true);
680
+ } else {
681
+ DISPATCH_C_SEL(C, false);
505
682
  }
506
683
  }
507
-
508
- /// series of results for queries i0..i1 is done
509
- void end_multiple() {}
510
- };
684
+ #undef DISPATCH_C_SEL
685
+ }
511
686
 
512
687
  } // namespace faiss