faiss 0.3.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
@@ -14,40 +14,86 @@
14
14
  #include <faiss/utils/Heap.h>
15
15
  #include <faiss/utils/simdlib.h>
16
16
 
17
+ #include <faiss/impl/FaissAssert.h>
18
+ #include <faiss/impl/ResultHandler.h>
17
19
  #include <faiss/impl/platform_macros.h>
18
20
  #include <faiss/utils/AlignedTable.h>
19
21
  #include <faiss/utils/partitioning.h>
20
22
 
21
23
  /** This file contains callbacks for kernels that compute distances.
22
- *
23
- * The SIMDResultHandler object is intended to be templated and inlined.
24
- * Methods:
25
- * - handle(): called when 32 distances are computed and provided in two
26
- * simd16uint16. (q, b) indicate which entry it is in the block.
27
- * - set_block_origin(): set the sub-matrix that is being computed
28
24
  */
29
25
 
30
26
  namespace faiss {
31
27
 
28
+ struct SIMDResultHandler {
29
+ // used to dispatch templates
30
+ bool is_CMax = false;
31
+ uint8_t sizeof_ids = 0;
32
+ bool with_fields = false;
33
+
34
+ /** called when 32 distances are computed and provided in two
35
+ * simd16uint16. (q, b) indicate which entry it is in the block. */
36
+ virtual void handle(
37
+ size_t q,
38
+ size_t b,
39
+ simd16uint16 d0,
40
+ simd16uint16 d1) = 0;
41
+
42
+ /// set the sub-matrix that is being computed
43
+ virtual void set_block_origin(size_t i0, size_t j0) = 0;
44
+
45
+ virtual ~SIMDResultHandler() {}
46
+ };
47
+
48
+ /* Result handler that will return float resutls eventually */
49
+ struct SIMDResultHandlerToFloat : SIMDResultHandler {
50
+ size_t nq; // number of queries
51
+ size_t ntotal; // ignore excess elements after ntotal
52
+
53
+ /// these fields are used mainly for the IVF variants (with_id_map=true)
54
+ const idx_t* id_map = nullptr; // map offset in invlist to vector id
55
+ const int* q_map = nullptr; // map q to global query
56
+ const uint16_t* dbias =
57
+ nullptr; // table of biases to add to each query (for IVF L2 search)
58
+ const float* normalizers = nullptr; // size 2 * nq, to convert
59
+
60
+ SIMDResultHandlerToFloat(size_t nq, size_t ntotal)
61
+ : nq(nq), ntotal(ntotal) {}
62
+
63
+ virtual void begin(const float* norms) {
64
+ normalizers = norms;
65
+ }
66
+
67
+ // called at end of search to convert int16 distances to float, before
68
+ // normalizers are deallocated
69
+ virtual void end() {
70
+ normalizers = nullptr;
71
+ }
72
+ };
73
+
74
+ FAISS_API extern bool simd_result_handlers_accept_virtual;
75
+
32
76
  namespace simd_result_handlers {
33
77
 
34
- /** Dummy structure that just computes a checksum on results
78
+ /** Dummy structure that just computes a chqecksum on results
35
79
  * (to avoid the computation to be optimized away) */
36
- struct DummyResultHandler {
80
+ struct DummyResultHandler : SIMDResultHandler {
37
81
  size_t cs = 0;
38
82
 
39
- void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
83
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
40
84
  cs += q * 123 + b * 789 + d0.get_scalar_0() + d1.get_scalar_0();
41
85
  }
42
86
 
43
- void set_block_origin(size_t, size_t) {}
87
+ void set_block_origin(size_t, size_t) final {}
88
+
89
+ ~DummyResultHandler() {}
44
90
  };
45
91
 
46
92
  /** memorize results in a nq-by-nb matrix.
47
93
  *
48
94
  * j0 is the current upper-left block of the matrix
49
95
  */
50
- struct StoreResultHandler {
96
+ struct StoreResultHandler : SIMDResultHandler {
51
97
  uint16_t* data;
52
98
  size_t ld; // total number of columns
53
99
  size_t i0 = 0;
@@ -55,32 +101,32 @@ struct StoreResultHandler {
55
101
 
56
102
  StoreResultHandler(uint16_t* data, size_t ld) : data(data), ld(ld) {}
57
103
 
58
- void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
104
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
59
105
  size_t ofs = (q + i0) * ld + j0 + b * 32;
60
106
  d0.store(data + ofs);
61
107
  d1.store(data + ofs + 16);
62
108
  }
63
109
 
64
- void set_block_origin(size_t i0, size_t j0) {
65
- this->i0 = i0;
66
- this->j0 = j0;
110
+ void set_block_origin(size_t i0_in, size_t j0_in) final {
111
+ this->i0 = i0_in;
112
+ this->j0 = j0_in;
67
113
  }
68
114
  };
69
115
 
70
116
  /** stores results in fixed-size matrix. */
71
117
  template <int NQ, int BB>
72
- struct FixedStorageHandler {
118
+ struct FixedStorageHandler : SIMDResultHandler {
73
119
  simd16uint16 dis[NQ][BB];
74
120
  int i0 = 0;
75
121
 
76
- void handle(int q, int b, simd16uint16 d0, simd16uint16 d1) {
122
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
77
123
  dis[q + i0][2 * b] = d0;
78
124
  dis[q + i0][2 * b + 1] = d1;
79
125
  }
80
126
 
81
- void set_block_origin(size_t i0, size_t j0) {
82
- this->i0 = i0;
83
- assert(j0 == 0);
127
+ void set_block_origin(size_t i0_in, size_t j0_in) final {
128
+ this->i0 = i0_in;
129
+ assert(j0_in == 0);
84
130
  }
85
131
 
86
132
  template <class OtherResultHandler>
@@ -91,30 +137,29 @@ struct FixedStorageHandler {
91
137
  }
92
138
  }
93
139
  }
140
+ virtual ~FixedStorageHandler() {}
94
141
  };
95
142
 
96
- /** Record origin of current block */
143
+ /** Result handler that compares distances to check if they need to be kept */
97
144
  template <class C, bool with_id_map>
98
- struct SIMDResultHandler {
145
+ struct ResultHandlerCompare : SIMDResultHandlerToFloat {
99
146
  using TI = typename C::TI;
100
147
 
101
148
  bool disable = false;
102
149
 
103
150
  int64_t i0 = 0; // query origin
104
151
  int64_t j0 = 0; // db origin
105
- size_t ntotal; // ignore excess elements after ntotal
106
-
107
- /// these fields are used mainly for the IVF variants (with_id_map=true)
108
- const TI* id_map; // map offset in invlist to vector id
109
- const int* q_map; // map q to global query
110
- const uint16_t* dbias; // table of biases to add to each query
111
152
 
112
- explicit SIMDResultHandler(size_t ntotal)
113
- : ntotal(ntotal), id_map(nullptr), q_map(nullptr), dbias(nullptr) {}
153
+ ResultHandlerCompare(size_t nq, size_t ntotal)
154
+ : SIMDResultHandlerToFloat(nq, ntotal) {
155
+ this->is_CMax = C::is_max;
156
+ this->sizeof_ids = sizeof(typename C::TI);
157
+ this->with_fields = with_id_map;
158
+ }
114
159
 
115
- void set_block_origin(size_t i0, size_t j0) {
116
- this->i0 = i0;
117
- this->j0 = j0;
160
+ void set_block_origin(size_t i0_in, size_t j0_in) final {
161
+ this->i0 = i0_in;
162
+ this->j0 = j0_in;
118
163
  }
119
164
 
120
165
  // adjust handler data for IVF.
@@ -172,43 +217,37 @@ struct SIMDResultHandler {
172
217
  return lt_mask;
173
218
  }
174
219
 
175
- virtual void to_flat_arrays(
176
- float* distances,
177
- int64_t* labels,
178
- const float* normalizers = nullptr) = 0;
179
-
180
- virtual ~SIMDResultHandler() {}
220
+ virtual ~ResultHandlerCompare() {}
181
221
  };
182
222
 
183
223
  /** Special version for k=1 */
184
224
  template <class C, bool with_id_map = false>
185
- struct SingleResultHandler : SIMDResultHandler<C, with_id_map> {
225
+ struct SingleResultHandler : ResultHandlerCompare<C, with_id_map> {
186
226
  using T = typename C::T;
187
227
  using TI = typename C::TI;
228
+ using RHC = ResultHandlerCompare<C, with_id_map>;
229
+ using RHC::normalizers;
188
230
 
189
- struct Result {
190
- T val;
191
- TI id;
192
- };
193
- std::vector<Result> results;
231
+ std::vector<int16_t> idis;
232
+ float* dis;
233
+ int64_t* ids;
194
234
 
195
- SingleResultHandler(size_t nq, size_t ntotal)
196
- : SIMDResultHandler<C, with_id_map>(ntotal), results(nq) {
235
+ SingleResultHandler(size_t nq, size_t ntotal, float* dis, int64_t* ids)
236
+ : RHC(nq, ntotal), idis(nq), dis(dis), ids(ids) {
197
237
  for (int i = 0; i < nq; i++) {
198
- Result res = {C::neutral(), -1};
199
- results[i] = res;
238
+ ids[i] = -1;
239
+ idis[i] = C::neutral();
200
240
  }
201
241
  }
202
242
 
203
- void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
243
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
204
244
  if (this->disable) {
205
245
  return;
206
246
  }
207
247
 
208
248
  this->adjust_with_origin(q, d0, d1);
209
249
 
210
- Result& res = results[q];
211
- uint32_t lt_mask = this->get_lt_mask(res.val, b, d0, d1);
250
+ uint32_t lt_mask = this->get_lt_mask(idis[q], b, d0, d1);
212
251
  if (!lt_mask) {
213
252
  return;
214
253
  }
@@ -221,70 +260,61 @@ struct SingleResultHandler : SIMDResultHandler<C, with_id_map> {
221
260
  // find first non-zero
222
261
  int j = __builtin_ctz(lt_mask);
223
262
  lt_mask -= 1 << j;
224
- T dis = d32tab[j];
225
- if (C::cmp(res.val, dis)) {
226
- res.val = dis;
227
- res.id = this->adjust_id(b, j);
263
+ T d = d32tab[j];
264
+ if (C::cmp(idis[q], d)) {
265
+ idis[q] = d;
266
+ ids[q] = this->adjust_id(b, j);
228
267
  }
229
268
  }
230
269
  }
231
270
 
232
- void to_flat_arrays(
233
- float* distances,
234
- int64_t* labels,
235
- const float* normalizers = nullptr) override {
236
- for (int q = 0; q < results.size(); q++) {
271
+ void end() {
272
+ for (int q = 0; q < this->nq; q++) {
237
273
  if (!normalizers) {
238
- distances[q] = results[q].val;
274
+ dis[q] = idis[q];
239
275
  } else {
240
276
  float one_a = 1 / normalizers[2 * q];
241
277
  float b = normalizers[2 * q + 1];
242
- distances[q] = b + results[q].val * one_a;
278
+ dis[q] = b + idis[q] * one_a;
243
279
  }
244
- labels[q] = results[q].id;
245
280
  }
246
281
  }
247
282
  };
248
283
 
249
284
  /** Structure that collects results in a min- or max-heap */
250
285
  template <class C, bool with_id_map = false>
251
- struct HeapHandler : SIMDResultHandler<C, with_id_map> {
286
+ struct HeapHandler : ResultHandlerCompare<C, with_id_map> {
252
287
  using T = typename C::T;
253
288
  using TI = typename C::TI;
289
+ using RHC = ResultHandlerCompare<C, with_id_map>;
290
+ using RHC::normalizers;
254
291
 
255
- int nq;
256
- T* heap_dis_tab;
257
- TI* heap_ids_tab;
292
+ std::vector<uint16_t> idis;
293
+ std::vector<TI> iids;
294
+ float* dis;
295
+ int64_t* ids;
258
296
 
259
297
  int64_t k; // number of results to keep
260
298
 
261
- HeapHandler(
262
- int nq,
263
- T* heap_dis_tab,
264
- TI* heap_ids_tab,
265
- size_t k,
266
- size_t ntotal)
267
- : SIMDResultHandler<C, with_id_map>(ntotal),
268
- nq(nq),
269
- heap_dis_tab(heap_dis_tab),
270
- heap_ids_tab(heap_ids_tab),
299
+ HeapHandler(size_t nq, size_t ntotal, int64_t k, float* dis, int64_t* ids)
300
+ : RHC(nq, ntotal),
301
+ idis(nq * k),
302
+ iids(nq * k),
303
+ dis(dis),
304
+ ids(ids),
271
305
  k(k) {
272
- for (int q = 0; q < nq; q++) {
273
- T* heap_dis_in = heap_dis_tab + q * k;
274
- TI* heap_ids_in = heap_ids_tab + q * k;
275
- heap_heapify<C>(k, heap_dis_in, heap_ids_in);
276
- }
306
+ heap_heapify<C>(k * nq, idis.data(), iids.data());
277
307
  }
278
308
 
279
- void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
309
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
280
310
  if (this->disable) {
281
311
  return;
282
312
  }
283
313
 
284
314
  this->adjust_with_origin(q, d0, d1);
285
315
 
286
- T* heap_dis = heap_dis_tab + q * k;
287
- TI* heap_ids = heap_ids_tab + q * k;
316
+ T* heap_dis = idis.data() + q * k;
317
+ TI* heap_ids = iids.data() + q * k;
288
318
 
289
319
  uint16_t cur_thresh =
290
320
  heap_dis[0] < 65536 ? (uint16_t)(heap_dis[0]) : 0xffff;
@@ -313,16 +343,13 @@ struct HeapHandler : SIMDResultHandler<C, with_id_map> {
313
343
  }
314
344
  }
315
345
 
316
- void to_flat_arrays(
317
- float* distances,
318
- int64_t* labels,
319
- const float* normalizers = nullptr) override {
320
- for (int q = 0; q < nq; q++) {
321
- T* heap_dis_in = heap_dis_tab + q * k;
322
- TI* heap_ids_in = heap_ids_tab + q * k;
346
+ void end() override {
347
+ for (int q = 0; q < this->nq; q++) {
348
+ T* heap_dis_in = idis.data() + q * k;
349
+ TI* heap_ids_in = iids.data() + q * k;
323
350
  heap_reorder<C>(k, heap_dis_in, heap_ids_in);
324
- int64_t* heap_ids = labels + q * k;
325
- float* heap_dis = distances + q * k;
351
+ float* heap_dis = dis + q * k;
352
+ int64_t* heap_ids = ids + q * k;
326
353
 
327
354
  float one_a = 1.0, b = 0.0;
328
355
  if (normalizers) {
@@ -330,8 +357,8 @@ struct HeapHandler : SIMDResultHandler<C, with_id_map> {
330
357
  b = normalizers[2 * q + 1];
331
358
  }
332
359
  for (int j = 0; j < k; j++) {
333
- heap_ids[j] = heap_ids_in[j];
334
360
  heap_dis[j] = heap_dis_in[j] * one_a + b;
361
+ heap_ids[j] = heap_ids_in[j];
335
362
  }
336
363
  }
337
364
  }
@@ -342,114 +369,45 @@ struct HeapHandler : SIMDResultHandler<C, with_id_map> {
342
369
  * Results are stored when they are below the threshold until the capacity is
343
370
  * reached. Then a partition sort is used to update the threshold. */
344
371
 
345
- namespace {
346
-
347
- uint64_t get_cy() {
348
- #ifdef MICRO_BENCHMARK
349
- uint32_t high, low;
350
- asm volatile("rdtsc \n\t" : "=a"(low), "=d"(high));
351
- return ((uint64_t)high << 32) | (low);
352
- #else
353
- return 0;
354
- #endif
355
- }
356
-
357
- } // anonymous namespace
358
-
359
- template <class C>
360
- struct ReservoirTopN {
361
- using T = typename C::T;
362
- using TI = typename C::TI;
363
-
364
- T* vals;
365
- TI* ids;
366
-
367
- size_t i; // number of stored elements
368
- size_t n; // number of requested elements
369
- size_t capacity; // size of storage
370
- size_t cycles = 0;
371
-
372
- T threshold; // current threshold
373
-
374
- ReservoirTopN(size_t n, size_t capacity, T* vals, TI* ids)
375
- : vals(vals), ids(ids), i(0), n(n), capacity(capacity) {
376
- assert(n < capacity);
377
- threshold = C::neutral();
378
- }
379
-
380
- void add(T val, TI id) {
381
- if (C::cmp(threshold, val)) {
382
- if (i == capacity) {
383
- shrink_fuzzy();
384
- }
385
- vals[i] = val;
386
- ids[i] = id;
387
- i++;
388
- }
389
- }
390
-
391
- /// shrink number of stored elements to n
392
- void shrink_xx() {
393
- uint64_t t0 = get_cy();
394
- qselect(vals, ids, i, n);
395
- i = n; // forget all elements above i = n
396
- threshold = C::Crev::neutral();
397
- for (size_t j = 0; j < n; j++) {
398
- if (C::cmp(vals[j], threshold)) {
399
- threshold = vals[j];
400
- }
401
- }
402
- cycles += get_cy() - t0;
403
- }
404
-
405
- void shrink() {
406
- uint64_t t0 = get_cy();
407
- threshold = partition<C>(vals, ids, i, n);
408
- i = n;
409
- cycles += get_cy() - t0;
410
- }
411
-
412
- void shrink_fuzzy() {
413
- uint64_t t0 = get_cy();
414
- assert(i == capacity);
415
- threshold = partition_fuzzy<C>(
416
- vals, ids, capacity, n, (capacity + n) / 2, &i);
417
- cycles += get_cy() - t0;
418
- }
419
- };
420
-
421
372
  /** Handler built from several ReservoirTopN (one per query) */
422
373
  template <class C, bool with_id_map = false>
423
- struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
374
+ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
424
375
  using T = typename C::T;
425
376
  using TI = typename C::TI;
377
+ using RHC = ResultHandlerCompare<C, with_id_map>;
378
+ using RHC::normalizers;
426
379
 
427
380
  size_t capacity; // rounded up to multiple of 16
381
+
382
+ // where the final results will be written
383
+ float* dis;
384
+ int64_t* ids;
385
+
428
386
  std::vector<TI> all_ids;
429
387
  AlignedTable<T> all_vals;
430
-
431
388
  std::vector<ReservoirTopN<C>> reservoirs;
432
389
 
433
- uint64_t times[4];
434
-
435
- ReservoirHandler(size_t nq, size_t ntotal, size_t n, size_t capacity_in)
436
- : SIMDResultHandler<C, with_id_map>(ntotal),
437
- capacity((capacity_in + 15) & ~15),
438
- all_ids(nq * capacity),
439
- all_vals(nq * capacity) {
390
+ ReservoirHandler(
391
+ size_t nq,
392
+ size_t ntotal,
393
+ size_t k,
394
+ size_t cap,
395
+ float* dis,
396
+ int64_t* ids)
397
+ : RHC(nq, ntotal), capacity((cap + 15) & ~15), dis(dis), ids(ids) {
440
398
  assert(capacity % 16 == 0);
441
- for (size_t i = 0; i < nq; i++) {
399
+ all_ids.resize(nq * capacity);
400
+ all_vals.resize(nq * capacity);
401
+ for (size_t q = 0; q < nq; q++) {
442
402
  reservoirs.emplace_back(
443
- n,
403
+ k,
444
404
  capacity,
445
- all_vals.get() + i * capacity,
446
- all_ids.data() + i * capacity);
405
+ all_vals.get() + q * capacity,
406
+ all_ids.data() + q * capacity);
447
407
  }
448
- times[0] = times[1] = times[2] = times[3] = 0;
449
408
  }
450
409
 
451
- void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
452
- uint64_t t0 = get_cy();
410
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
453
411
  if (this->disable) {
454
412
  return;
455
413
  }
@@ -457,8 +415,6 @@ struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
457
415
 
458
416
  ReservoirTopN<C>& res = reservoirs[q];
459
417
  uint32_t lt_mask = this->get_lt_mask(res.threshold, b, d0, d1);
460
- uint64_t t1 = get_cy();
461
- times[0] += t1 - t0;
462
418
 
463
419
  if (!lt_mask) {
464
420
  return;
@@ -474,20 +430,14 @@ struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
474
430
  T dis = d32tab[j];
475
431
  res.add(dis, this->adjust_id(b, j));
476
432
  }
477
- times[1] += get_cy() - t1;
478
433
  }
479
434
 
480
- void to_flat_arrays(
481
- float* distances,
482
- int64_t* labels,
483
- const float* normalizers = nullptr) override {
435
+ void end() override {
484
436
  using Cf = typename std::conditional<
485
437
  C::is_max,
486
438
  CMax<float, int64_t>,
487
439
  CMin<float, int64_t>>::type;
488
440
 
489
- uint64_t t0 = get_cy();
490
- uint64_t t3 = 0;
491
441
  std::vector<int> perm(reservoirs[0].n);
492
442
  for (int q = 0; q < reservoirs.size(); q++) {
493
443
  ReservoirTopN<C>& res = reservoirs[q];
@@ -496,8 +446,8 @@ struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
496
446
  if (res.i > res.n) {
497
447
  res.shrink();
498
448
  }
499
- int64_t* heap_ids = labels + q * n;
500
- float* heap_dis = distances + q * n;
449
+ int64_t* heap_ids = ids + q * n;
450
+ float* heap_dis = dis + q * n;
501
451
 
502
452
  float one_a = 1.0, b = 0.0;
503
453
  if (normalizers) {
@@ -518,14 +468,236 @@ struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
518
468
 
519
469
  // possibly add empty results
520
470
  heap_heapify<Cf>(n - res.i, heap_dis + res.i, heap_ids + res.i);
471
+ }
472
+ }
473
+ };
474
+
475
+ /** Result hanlder for range search. The difficulty is that the range distances
476
+ * have to be scaled using the scaler.
477
+ */
478
+
479
+ template <class C, bool with_id_map = false>
480
+ struct RangeHandler : ResultHandlerCompare<C, with_id_map> {
481
+ using T = typename C::T;
482
+ using TI = typename C::TI;
483
+ using RHC = ResultHandlerCompare<C, with_id_map>;
484
+ using RHC::normalizers;
485
+ using RHC::nq;
486
+
487
+ RangeSearchResult& rres;
488
+ float radius;
489
+ std::vector<uint16_t> thresholds;
490
+ std::vector<size_t> n_per_query;
491
+ size_t q0 = 0;
492
+
493
+ // we cannot use the RangeSearchPartialResult interface because queries can
494
+ // be performed by batches
495
+ struct Triplet {
496
+ idx_t q;
497
+ idx_t b;
498
+ uint16_t dis;
499
+ };
500
+ std::vector<Triplet> triplets;
501
+
502
+ RangeHandler(RangeSearchResult& rres, float radius, size_t ntotal)
503
+ : RHC(rres.nq, ntotal), rres(rres), radius(radius) {
504
+ thresholds.resize(nq);
505
+ n_per_query.resize(nq + 1);
506
+ }
507
+
508
+ virtual void begin(const float* norms) {
509
+ normalizers = norms;
510
+ for (int q = 0; q < nq; ++q) {
511
+ thresholds[q] =
512
+ normalizers[2 * q] * (radius - normalizers[2 * q + 1]);
513
+ }
514
+ }
515
+
516
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
517
+ if (this->disable) {
518
+ return;
519
+ }
520
+ this->adjust_with_origin(q, d0, d1);
521
+
522
+ uint32_t lt_mask = this->get_lt_mask(thresholds[q], b, d0, d1);
523
+
524
+ if (!lt_mask) {
525
+ return;
526
+ }
527
+ ALIGNED(32) uint16_t d32tab[32];
528
+ d0.store(d32tab);
529
+ d1.store(d32tab + 16);
530
+
531
+ while (lt_mask) {
532
+ // find first non-zero
533
+ int j = __builtin_ctz(lt_mask);
534
+ lt_mask -= 1 << j;
535
+ T dis = d32tab[j];
536
+ n_per_query[q]++;
537
+ triplets.push_back({idx_t(q + q0), this->adjust_id(b, j), dis});
538
+ }
539
+ }
521
540
 
522
- t3 += res.cycles;
541
+ void end() override {
542
+ memcpy(rres.lims, n_per_query.data(), sizeof(n_per_query[0]) * nq);
543
+ rres.do_allocation();
544
+ for (auto it = triplets.begin(); it != triplets.end(); ++it) {
545
+ size_t& l = rres.lims[it->q];
546
+ rres.distances[l] = it->dis;
547
+ rres.labels[l] = it->b;
548
+ l++;
549
+ }
550
+ memmove(rres.lims + 1, rres.lims, sizeof(*rres.lims) * rres.nq);
551
+ rres.lims[0] = 0;
552
+
553
+ for (int q = 0; q < nq; q++) {
554
+ float one_a = 1 / normalizers[2 * q];
555
+ float b = normalizers[2 * q + 1];
556
+ for (size_t i = rres.lims[q]; i < rres.lims[q + 1]; i++) {
557
+ rres.distances[i] = rres.distances[i] * one_a + b;
558
+ }
523
559
  }
524
- times[2] += get_cy() - t0;
525
- times[3] += t3;
526
560
  }
527
561
  };
528
562
 
563
+ #ifndef SWIG
564
+
565
+ // handler for a subset of queries
566
+ template <class C, bool with_id_map = false>
567
+ struct PartialRangeHandler : RangeHandler<C, with_id_map> {
568
+ using T = typename C::T;
569
+ using TI = typename C::TI;
570
+ using RHC = RangeHandler<C, with_id_map>;
571
+ using RHC::normalizers;
572
+ using RHC::nq, RHC::q0, RHC::triplets, RHC::n_per_query;
573
+
574
+ RangeSearchPartialResult& pres;
575
+
576
+ PartialRangeHandler(
577
+ RangeSearchPartialResult& pres,
578
+ float radius,
579
+ size_t ntotal,
580
+ size_t q0,
581
+ size_t q1)
582
+ : RangeHandler<C, with_id_map>(*pres.res, radius, ntotal),
583
+ pres(pres) {
584
+ nq = q1 - q0;
585
+ this->q0 = q0;
586
+ }
587
+
588
+ // shift left n_per_query
589
+ void shift_n_per_query() {
590
+ memmove(n_per_query.data() + 1,
591
+ n_per_query.data(),
592
+ nq * sizeof(n_per_query[0]));
593
+ n_per_query[0] = 0;
594
+ }
595
+
596
+ // commit to partial result instead of full RangeResult
597
+ void end() override {
598
+ std::vector<typename RHC::Triplet> sorted_triplets(triplets.size());
599
+ for (int q = 0; q < nq; q++) {
600
+ n_per_query[q + 1] += n_per_query[q];
601
+ }
602
+ shift_n_per_query();
603
+
604
+ for (size_t i = 0; i < triplets.size(); i++) {
605
+ sorted_triplets[n_per_query[triplets[i].q - q0]++] = triplets[i];
606
+ }
607
+ shift_n_per_query();
608
+
609
+ size_t* lims = n_per_query.data();
610
+
611
+ for (int q = 0; q < nq; q++) {
612
+ float one_a = 1 / normalizers[2 * q];
613
+ float b = normalizers[2 * q + 1];
614
+ RangeQueryResult& qres = pres.new_result(q + q0);
615
+ for (size_t i = lims[q]; i < lims[q + 1]; i++) {
616
+ qres.add(
617
+ sorted_triplets[i].dis * one_a + b,
618
+ sorted_triplets[i].b);
619
+ }
620
+ }
621
+ }
622
+ };
623
+
624
+ #endif
625
+
626
+ /********************************************************************************
627
+ * Dynamic dispatching function. The consumer should have a templatized method f
628
+ * that will be replaced with the actual SIMDResultHandler that is determined
629
+ * dynamically.
630
+ */
631
+
632
+ template <class C, bool W, class Consumer, class... Types>
633
+ void dispatch_SIMDResultHanlder_fixedCW(
634
+ SIMDResultHandler& res,
635
+ Consumer& consumer,
636
+ Types... args) {
637
+ if (auto resh = dynamic_cast<SingleResultHandler<C, W>*>(&res)) {
638
+ consumer.template f<SingleResultHandler<C, W>>(*resh, args...);
639
+ } else if (auto resh = dynamic_cast<HeapHandler<C, W>*>(&res)) {
640
+ consumer.template f<HeapHandler<C, W>>(*resh, args...);
641
+ } else if (auto resh = dynamic_cast<ReservoirHandler<C, W>*>(&res)) {
642
+ consumer.template f<ReservoirHandler<C, W>>(*resh, args...);
643
+ } else { // generic handler -- will not be inlined
644
+ FAISS_THROW_IF_NOT_FMT(
645
+ simd_result_handlers_accept_virtual,
646
+ "Running vitrual handler for %s",
647
+ typeid(res).name());
648
+ consumer.template f<SIMDResultHandler>(res, args...);
649
+ }
650
+ }
651
+
652
+ template <class C, class Consumer, class... Types>
653
+ void dispatch_SIMDResultHanlder_fixedC(
654
+ SIMDResultHandler& res,
655
+ Consumer& consumer,
656
+ Types... args) {
657
+ if (res.with_fields) {
658
+ dispatch_SIMDResultHanlder_fixedCW<C, true>(res, consumer, args...);
659
+ } else {
660
+ dispatch_SIMDResultHanlder_fixedCW<C, false>(res, consumer, args...);
661
+ }
662
+ }
663
+
664
+ template <class Consumer, class... Types>
665
+ void dispatch_SIMDResultHanlder(
666
+ SIMDResultHandler& res,
667
+ Consumer& consumer,
668
+ Types... args) {
669
+ if (res.sizeof_ids == 0) {
670
+ if (auto resh = dynamic_cast<StoreResultHandler*>(&res)) {
671
+ consumer.template f<StoreResultHandler>(*resh, args...);
672
+ } else if (auto resh = dynamic_cast<DummyResultHandler*>(&res)) {
673
+ consumer.template f<DummyResultHandler>(*resh, args...);
674
+ } else { // generic path
675
+ FAISS_THROW_IF_NOT_FMT(
676
+ simd_result_handlers_accept_virtual,
677
+ "Running vitrual handler for %s",
678
+ typeid(res).name());
679
+ consumer.template f<SIMDResultHandler>(res, args...);
680
+ }
681
+ } else if (res.sizeof_ids == sizeof(int)) {
682
+ if (res.is_CMax) {
683
+ dispatch_SIMDResultHanlder_fixedC<CMax<uint16_t, int>>(
684
+ res, consumer, args...);
685
+ } else {
686
+ dispatch_SIMDResultHanlder_fixedC<CMin<uint16_t, int>>(
687
+ res, consumer, args...);
688
+ }
689
+ } else if (res.sizeof_ids == sizeof(int64_t)) {
690
+ if (res.is_CMax) {
691
+ dispatch_SIMDResultHanlder_fixedC<CMax<uint16_t, int64_t>>(
692
+ res, consumer, args...);
693
+ } else {
694
+ dispatch_SIMDResultHanlder_fixedC<CMin<uint16_t, int64_t>>(
695
+ res, consumer, args...);
696
+ }
697
+ } else {
698
+ FAISS_THROW_FMT("Unknown id size %d", res.sizeof_ids);
699
+ }
700
+ }
529
701
  } // namespace simd_result_handlers
530
702
 
531
703
  } // namespace faiss