faiss 0.2.7 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (172) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/lib/faiss.rb +1 -1
  11. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  12. data/vendor/faiss/faiss/AutoTune.h +0 -1
  13. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  14. data/vendor/faiss/faiss/Clustering.h +31 -21
  15. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  16. data/vendor/faiss/faiss/Index.cpp +1 -1
  17. data/vendor/faiss/faiss/Index.h +20 -5
  18. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  21. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  22. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  23. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  34. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  38. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  59. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  60. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  61. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  62. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  63. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  64. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  65. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  66. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  67. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  69. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  70. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  71. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  72. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  73. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  74. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  75. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  76. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  77. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  78. data/vendor/faiss/faiss/clone_index.h +3 -0
  79. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  80. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  81. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  82. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  90. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  92. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  93. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  97. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  98. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  99. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  101. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  103. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  104. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  105. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  106. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  107. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  108. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  109. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  110. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  111. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  112. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  113. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  115. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  118. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  119. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  121. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  124. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  125. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  126. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  127. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  128. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  129. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  133. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  135. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  136. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  137. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  138. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  139. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  140. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  141. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  142. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  143. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  144. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  145. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  146. data/vendor/faiss/faiss/utils/distances.h +81 -4
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  148. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  150. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  152. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  153. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  154. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  155. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  156. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  157. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  158. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  159. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  160. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  161. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  162. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  163. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  164. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  165. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  166. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  167. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  168. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  169. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  170. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  171. data/vendor/faiss/faiss/utils/utils.h +57 -20
  172. metadata +11 -4
@@ -14,40 +14,86 @@
14
14
  #include <faiss/utils/Heap.h>
15
15
  #include <faiss/utils/simdlib.h>
16
16
 
17
+ #include <faiss/impl/FaissAssert.h>
18
+ #include <faiss/impl/ResultHandler.h>
17
19
  #include <faiss/impl/platform_macros.h>
18
20
  #include <faiss/utils/AlignedTable.h>
19
21
  #include <faiss/utils/partitioning.h>
20
22
 
21
23
  /** This file contains callbacks for kernels that compute distances.
22
- *
23
- * The SIMDResultHandler object is intended to be templated and inlined.
24
- * Methods:
25
- * - handle(): called when 32 distances are computed and provided in two
26
- * simd16uint16. (q, b) indicate which entry it is in the block.
27
- * - set_block_origin(): set the sub-matrix that is being computed
28
24
  */
29
25
 
30
26
  namespace faiss {
31
27
 
28
+ struct SIMDResultHandler {
29
+ // used to dispatch templates
30
+ bool is_CMax = false;
31
+ uint8_t sizeof_ids = 0;
32
+ bool with_fields = false;
33
+
34
+ /** called when 32 distances are computed and provided in two
35
+ * simd16uint16. (q, b) indicate which entry it is in the block. */
36
+ virtual void handle(
37
+ size_t q,
38
+ size_t b,
39
+ simd16uint16 d0,
40
+ simd16uint16 d1) = 0;
41
+
42
+ /// set the sub-matrix that is being computed
43
+ virtual void set_block_origin(size_t i0, size_t j0) = 0;
44
+
45
+ virtual ~SIMDResultHandler() {}
46
+ };
47
+
48
+ /* Result handler that will return float results eventually */
49
+ struct SIMDResultHandlerToFloat : SIMDResultHandler {
50
+ size_t nq; // number of queries
51
+ size_t ntotal; // ignore excess elements after ntotal
52
+
53
+ /// these fields are used mainly for the IVF variants (with_id_map=true)
54
+ const idx_t* id_map = nullptr; // map offset in invlist to vector id
55
+ const int* q_map = nullptr; // map q to global query
56
+ const uint16_t* dbias =
57
+ nullptr; // table of biases to add to each query (for IVF L2 search)
58
+ const float* normalizers = nullptr; // size 2 * nq, to convert
59
+
60
+ SIMDResultHandlerToFloat(size_t nq, size_t ntotal)
61
+ : nq(nq), ntotal(ntotal) {}
62
+
63
+ virtual void begin(const float* norms) {
64
+ normalizers = norms;
65
+ }
66
+
67
+ // called at end of search to convert int16 distances to float, before
68
+ // normalizers are deallocated
69
+ virtual void end() {
70
+ normalizers = nullptr;
71
+ }
72
+ };
73
+
74
+ FAISS_API extern bool simd_result_handlers_accept_virtual;
75
+
32
76
  namespace simd_result_handlers {
33
77
 
34
- /** Dummy structure that just computes a checksum on results
78
+ /** Dummy structure that just computes a checksum on results
35
79
  * (to avoid the computation to be optimized away) */
36
- struct DummyResultHandler {
80
+ struct DummyResultHandler : SIMDResultHandler {
37
81
  size_t cs = 0;
38
82
 
39
- void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
83
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
40
84
  cs += q * 123 + b * 789 + d0.get_scalar_0() + d1.get_scalar_0();
41
85
  }
42
86
 
43
- void set_block_origin(size_t, size_t) {}
87
+ void set_block_origin(size_t, size_t) final {}
88
+
89
+ ~DummyResultHandler() {}
44
90
  };
45
91
 
46
92
  /** memorize results in a nq-by-nb matrix.
47
93
  *
48
94
  * j0 is the current upper-left block of the matrix
49
95
  */
50
- struct StoreResultHandler {
96
+ struct StoreResultHandler : SIMDResultHandler {
51
97
  uint16_t* data;
52
98
  size_t ld; // total number of columns
53
99
  size_t i0 = 0;
@@ -55,32 +101,32 @@ struct StoreResultHandler {
55
101
 
56
102
  StoreResultHandler(uint16_t* data, size_t ld) : data(data), ld(ld) {}
57
103
 
58
- void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
104
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
59
105
  size_t ofs = (q + i0) * ld + j0 + b * 32;
60
106
  d0.store(data + ofs);
61
107
  d1.store(data + ofs + 16);
62
108
  }
63
109
 
64
- void set_block_origin(size_t i0, size_t j0) {
65
- this->i0 = i0;
66
- this->j0 = j0;
110
+ void set_block_origin(size_t i0_in, size_t j0_in) final {
111
+ this->i0 = i0_in;
112
+ this->j0 = j0_in;
67
113
  }
68
114
  };
69
115
 
70
116
  /** stores results in fixed-size matrix. */
71
117
  template <int NQ, int BB>
72
- struct FixedStorageHandler {
118
+ struct FixedStorageHandler : SIMDResultHandler {
73
119
  simd16uint16 dis[NQ][BB];
74
120
  int i0 = 0;
75
121
 
76
- void handle(int q, int b, simd16uint16 d0, simd16uint16 d1) {
122
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
77
123
  dis[q + i0][2 * b] = d0;
78
124
  dis[q + i0][2 * b + 1] = d1;
79
125
  }
80
126
 
81
- void set_block_origin(size_t i0, size_t j0) {
82
- this->i0 = i0;
83
- assert(j0 == 0);
127
+ void set_block_origin(size_t i0_in, size_t j0_in) final {
128
+ this->i0 = i0_in;
129
+ assert(j0_in == 0);
84
130
  }
85
131
 
86
132
  template <class OtherResultHandler>
@@ -91,30 +137,29 @@ struct FixedStorageHandler {
91
137
  }
92
138
  }
93
139
  }
140
+ virtual ~FixedStorageHandler() {}
94
141
  };
95
142
 
96
- /** Record origin of current block */
143
+ /** Result handler that compares distances to check if they need to be kept */
97
144
  template <class C, bool with_id_map>
98
- struct SIMDResultHandler {
145
+ struct ResultHandlerCompare : SIMDResultHandlerToFloat {
99
146
  using TI = typename C::TI;
100
147
 
101
148
  bool disable = false;
102
149
 
103
150
  int64_t i0 = 0; // query origin
104
151
  int64_t j0 = 0; // db origin
105
- size_t ntotal; // ignore excess elements after ntotal
106
-
107
- /// these fields are used mainly for the IVF variants (with_id_map=true)
108
- const TI* id_map; // map offset in invlist to vector id
109
- const int* q_map; // map q to global query
110
- const uint16_t* dbias; // table of biases to add to each query
111
152
 
112
- explicit SIMDResultHandler(size_t ntotal)
113
- : ntotal(ntotal), id_map(nullptr), q_map(nullptr), dbias(nullptr) {}
153
+ ResultHandlerCompare(size_t nq, size_t ntotal)
154
+ : SIMDResultHandlerToFloat(nq, ntotal) {
155
+ this->is_CMax = C::is_max;
156
+ this->sizeof_ids = sizeof(typename C::TI);
157
+ this->with_fields = with_id_map;
158
+ }
114
159
 
115
- void set_block_origin(size_t i0, size_t j0) {
116
- this->i0 = i0;
117
- this->j0 = j0;
160
+ void set_block_origin(size_t i0_in, size_t j0_in) final {
161
+ this->i0 = i0_in;
162
+ this->j0 = j0_in;
118
163
  }
119
164
 
120
165
  // adjust handler data for IVF.
@@ -172,43 +217,37 @@ struct SIMDResultHandler {
172
217
  return lt_mask;
173
218
  }
174
219
 
175
- virtual void to_flat_arrays(
176
- float* distances,
177
- int64_t* labels,
178
- const float* normalizers = nullptr) = 0;
179
-
180
- virtual ~SIMDResultHandler() {}
220
+ virtual ~ResultHandlerCompare() {}
181
221
  };
182
222
 
183
223
  /** Special version for k=1 */
184
224
  template <class C, bool with_id_map = false>
185
- struct SingleResultHandler : SIMDResultHandler<C, with_id_map> {
225
+ struct SingleResultHandler : ResultHandlerCompare<C, with_id_map> {
186
226
  using T = typename C::T;
187
227
  using TI = typename C::TI;
228
+ using RHC = ResultHandlerCompare<C, with_id_map>;
229
+ using RHC::normalizers;
188
230
 
189
- struct Result {
190
- T val;
191
- TI id;
192
- };
193
- std::vector<Result> results;
231
+ std::vector<int16_t> idis;
232
+ float* dis;
233
+ int64_t* ids;
194
234
 
195
- SingleResultHandler(size_t nq, size_t ntotal)
196
- : SIMDResultHandler<C, with_id_map>(ntotal), results(nq) {
235
+ SingleResultHandler(size_t nq, size_t ntotal, float* dis, int64_t* ids)
236
+ : RHC(nq, ntotal), idis(nq), dis(dis), ids(ids) {
197
237
  for (int i = 0; i < nq; i++) {
198
- Result res = {C::neutral(), -1};
199
- results[i] = res;
238
+ ids[i] = -1;
239
+ idis[i] = C::neutral();
200
240
  }
201
241
  }
202
242
 
203
- void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
243
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
204
244
  if (this->disable) {
205
245
  return;
206
246
  }
207
247
 
208
248
  this->adjust_with_origin(q, d0, d1);
209
249
 
210
- Result& res = results[q];
211
- uint32_t lt_mask = this->get_lt_mask(res.val, b, d0, d1);
250
+ uint32_t lt_mask = this->get_lt_mask(idis[q], b, d0, d1);
212
251
  if (!lt_mask) {
213
252
  return;
214
253
  }
@@ -221,70 +260,61 @@ struct SingleResultHandler : SIMDResultHandler<C, with_id_map> {
221
260
  // find first non-zero
222
261
  int j = __builtin_ctz(lt_mask);
223
262
  lt_mask -= 1 << j;
224
- T dis = d32tab[j];
225
- if (C::cmp(res.val, dis)) {
226
- res.val = dis;
227
- res.id = this->adjust_id(b, j);
263
+ T d = d32tab[j];
264
+ if (C::cmp(idis[q], d)) {
265
+ idis[q] = d;
266
+ ids[q] = this->adjust_id(b, j);
228
267
  }
229
268
  }
230
269
  }
231
270
 
232
- void to_flat_arrays(
233
- float* distances,
234
- int64_t* labels,
235
- const float* normalizers = nullptr) override {
236
- for (int q = 0; q < results.size(); q++) {
271
+ void end() {
272
+ for (int q = 0; q < this->nq; q++) {
237
273
  if (!normalizers) {
238
- distances[q] = results[q].val;
274
+ dis[q] = idis[q];
239
275
  } else {
240
276
  float one_a = 1 / normalizers[2 * q];
241
277
  float b = normalizers[2 * q + 1];
242
- distances[q] = b + results[q].val * one_a;
278
+ dis[q] = b + idis[q] * one_a;
243
279
  }
244
- labels[q] = results[q].id;
245
280
  }
246
281
  }
247
282
  };
248
283
 
249
284
  /** Structure that collects results in a min- or max-heap */
250
285
  template <class C, bool with_id_map = false>
251
- struct HeapHandler : SIMDResultHandler<C, with_id_map> {
286
+ struct HeapHandler : ResultHandlerCompare<C, with_id_map> {
252
287
  using T = typename C::T;
253
288
  using TI = typename C::TI;
289
+ using RHC = ResultHandlerCompare<C, with_id_map>;
290
+ using RHC::normalizers;
254
291
 
255
- int nq;
256
- T* heap_dis_tab;
257
- TI* heap_ids_tab;
292
+ std::vector<uint16_t> idis;
293
+ std::vector<TI> iids;
294
+ float* dis;
295
+ int64_t* ids;
258
296
 
259
297
  int64_t k; // number of results to keep
260
298
 
261
- HeapHandler(
262
- int nq,
263
- T* heap_dis_tab,
264
- TI* heap_ids_tab,
265
- size_t k,
266
- size_t ntotal)
267
- : SIMDResultHandler<C, with_id_map>(ntotal),
268
- nq(nq),
269
- heap_dis_tab(heap_dis_tab),
270
- heap_ids_tab(heap_ids_tab),
299
+ HeapHandler(size_t nq, size_t ntotal, int64_t k, float* dis, int64_t* ids)
300
+ : RHC(nq, ntotal),
301
+ idis(nq * k),
302
+ iids(nq * k),
303
+ dis(dis),
304
+ ids(ids),
271
305
  k(k) {
272
- for (int q = 0; q < nq; q++) {
273
- T* heap_dis_in = heap_dis_tab + q * k;
274
- TI* heap_ids_in = heap_ids_tab + q * k;
275
- heap_heapify<C>(k, heap_dis_in, heap_ids_in);
276
- }
306
+ heap_heapify<C>(k * nq, idis.data(), iids.data());
277
307
  }
278
308
 
279
- void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
309
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
280
310
  if (this->disable) {
281
311
  return;
282
312
  }
283
313
 
284
314
  this->adjust_with_origin(q, d0, d1);
285
315
 
286
- T* heap_dis = heap_dis_tab + q * k;
287
- TI* heap_ids = heap_ids_tab + q * k;
316
+ T* heap_dis = idis.data() + q * k;
317
+ TI* heap_ids = iids.data() + q * k;
288
318
 
289
319
  uint16_t cur_thresh =
290
320
  heap_dis[0] < 65536 ? (uint16_t)(heap_dis[0]) : 0xffff;
@@ -313,16 +343,13 @@ struct HeapHandler : SIMDResultHandler<C, with_id_map> {
313
343
  }
314
344
  }
315
345
 
316
- void to_flat_arrays(
317
- float* distances,
318
- int64_t* labels,
319
- const float* normalizers = nullptr) override {
320
- for (int q = 0; q < nq; q++) {
321
- T* heap_dis_in = heap_dis_tab + q * k;
322
- TI* heap_ids_in = heap_ids_tab + q * k;
346
+ void end() override {
347
+ for (int q = 0; q < this->nq; q++) {
348
+ T* heap_dis_in = idis.data() + q * k;
349
+ TI* heap_ids_in = iids.data() + q * k;
323
350
  heap_reorder<C>(k, heap_dis_in, heap_ids_in);
324
- int64_t* heap_ids = labels + q * k;
325
- float* heap_dis = distances + q * k;
351
+ float* heap_dis = dis + q * k;
352
+ int64_t* heap_ids = ids + q * k;
326
353
 
327
354
  float one_a = 1.0, b = 0.0;
328
355
  if (normalizers) {
@@ -330,8 +357,8 @@ struct HeapHandler : SIMDResultHandler<C, with_id_map> {
330
357
  b = normalizers[2 * q + 1];
331
358
  }
332
359
  for (int j = 0; j < k; j++) {
333
- heap_ids[j] = heap_ids_in[j];
334
360
  heap_dis[j] = heap_dis_in[j] * one_a + b;
361
+ heap_ids[j] = heap_ids_in[j];
335
362
  }
336
363
  }
337
364
  }
@@ -342,114 +369,45 @@ struct HeapHandler : SIMDResultHandler<C, with_id_map> {
342
369
  * Results are stored when they are below the threshold until the capacity is
343
370
  * reached. Then a partition sort is used to update the threshold. */
344
371
 
345
- namespace {
346
-
347
- uint64_t get_cy() {
348
- #ifdef MICRO_BENCHMARK
349
- uint32_t high, low;
350
- asm volatile("rdtsc \n\t" : "=a"(low), "=d"(high));
351
- return ((uint64_t)high << 32) | (low);
352
- #else
353
- return 0;
354
- #endif
355
- }
356
-
357
- } // anonymous namespace
358
-
359
- template <class C>
360
- struct ReservoirTopN {
361
- using T = typename C::T;
362
- using TI = typename C::TI;
363
-
364
- T* vals;
365
- TI* ids;
366
-
367
- size_t i; // number of stored elements
368
- size_t n; // number of requested elements
369
- size_t capacity; // size of storage
370
- size_t cycles = 0;
371
-
372
- T threshold; // current threshold
373
-
374
- ReservoirTopN(size_t n, size_t capacity, T* vals, TI* ids)
375
- : vals(vals), ids(ids), i(0), n(n), capacity(capacity) {
376
- assert(n < capacity);
377
- threshold = C::neutral();
378
- }
379
-
380
- void add(T val, TI id) {
381
- if (C::cmp(threshold, val)) {
382
- if (i == capacity) {
383
- shrink_fuzzy();
384
- }
385
- vals[i] = val;
386
- ids[i] = id;
387
- i++;
388
- }
389
- }
390
-
391
- /// shrink number of stored elements to n
392
- void shrink_xx() {
393
- uint64_t t0 = get_cy();
394
- qselect(vals, ids, i, n);
395
- i = n; // forget all elements above i = n
396
- threshold = C::Crev::neutral();
397
- for (size_t j = 0; j < n; j++) {
398
- if (C::cmp(vals[j], threshold)) {
399
- threshold = vals[j];
400
- }
401
- }
402
- cycles += get_cy() - t0;
403
- }
404
-
405
- void shrink() {
406
- uint64_t t0 = get_cy();
407
- threshold = partition<C>(vals, ids, i, n);
408
- i = n;
409
- cycles += get_cy() - t0;
410
- }
411
-
412
- void shrink_fuzzy() {
413
- uint64_t t0 = get_cy();
414
- assert(i == capacity);
415
- threshold = partition_fuzzy<C>(
416
- vals, ids, capacity, n, (capacity + n) / 2, &i);
417
- cycles += get_cy() - t0;
418
- }
419
- };
420
-
421
372
  /** Handler built from several ReservoirTopN (one per query) */
422
373
  template <class C, bool with_id_map = false>
423
- struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
374
+ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
424
375
  using T = typename C::T;
425
376
  using TI = typename C::TI;
377
+ using RHC = ResultHandlerCompare<C, with_id_map>;
378
+ using RHC::normalizers;
426
379
 
427
380
  size_t capacity; // rounded up to multiple of 16
381
+
382
+ // where the final results will be written
383
+ float* dis;
384
+ int64_t* ids;
385
+
428
386
  std::vector<TI> all_ids;
429
387
  AlignedTable<T> all_vals;
430
-
431
388
  std::vector<ReservoirTopN<C>> reservoirs;
432
389
 
433
- uint64_t times[4];
434
-
435
- ReservoirHandler(size_t nq, size_t ntotal, size_t n, size_t capacity_in)
436
- : SIMDResultHandler<C, with_id_map>(ntotal),
437
- capacity((capacity_in + 15) & ~15),
438
- all_ids(nq * capacity),
439
- all_vals(nq * capacity) {
390
+ ReservoirHandler(
391
+ size_t nq,
392
+ size_t ntotal,
393
+ size_t k,
394
+ size_t cap,
395
+ float* dis,
396
+ int64_t* ids)
397
+ : RHC(nq, ntotal), capacity((cap + 15) & ~15), dis(dis), ids(ids) {
440
398
  assert(capacity % 16 == 0);
441
- for (size_t i = 0; i < nq; i++) {
399
+ all_ids.resize(nq * capacity);
400
+ all_vals.resize(nq * capacity);
401
+ for (size_t q = 0; q < nq; q++) {
442
402
  reservoirs.emplace_back(
443
- n,
403
+ k,
444
404
  capacity,
445
- all_vals.get() + i * capacity,
446
- all_ids.data() + i * capacity);
405
+ all_vals.get() + q * capacity,
406
+ all_ids.data() + q * capacity);
447
407
  }
448
- times[0] = times[1] = times[2] = times[3] = 0;
449
408
  }
450
409
 
451
- void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
452
- uint64_t t0 = get_cy();
410
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
453
411
  if (this->disable) {
454
412
  return;
455
413
  }
@@ -457,8 +415,6 @@ struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
457
415
 
458
416
  ReservoirTopN<C>& res = reservoirs[q];
459
417
  uint32_t lt_mask = this->get_lt_mask(res.threshold, b, d0, d1);
460
- uint64_t t1 = get_cy();
461
- times[0] += t1 - t0;
462
418
 
463
419
  if (!lt_mask) {
464
420
  return;
@@ -474,20 +430,14 @@ struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
474
430
  T dis = d32tab[j];
475
431
  res.add(dis, this->adjust_id(b, j));
476
432
  }
477
- times[1] += get_cy() - t1;
478
433
  }
479
434
 
480
- void to_flat_arrays(
481
- float* distances,
482
- int64_t* labels,
483
- const float* normalizers = nullptr) override {
435
+ void end() override {
484
436
  using Cf = typename std::conditional<
485
437
  C::is_max,
486
438
  CMax<float, int64_t>,
487
439
  CMin<float, int64_t>>::type;
488
440
 
489
- uint64_t t0 = get_cy();
490
- uint64_t t3 = 0;
491
441
  std::vector<int> perm(reservoirs[0].n);
492
442
  for (int q = 0; q < reservoirs.size(); q++) {
493
443
  ReservoirTopN<C>& res = reservoirs[q];
@@ -496,8 +446,8 @@ struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
496
446
  if (res.i > res.n) {
497
447
  res.shrink();
498
448
  }
499
- int64_t* heap_ids = labels + q * n;
500
- float* heap_dis = distances + q * n;
449
+ int64_t* heap_ids = ids + q * n;
450
+ float* heap_dis = dis + q * n;
501
451
 
502
452
  float one_a = 1.0, b = 0.0;
503
453
  if (normalizers) {
@@ -518,14 +468,236 @@ struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
518
468
 
519
469
  // possibly add empty results
520
470
  heap_heapify<Cf>(n - res.i, heap_dis + res.i, heap_ids + res.i);
471
+ }
472
+ }
473
+ };
474
+
475
+ /** Result handler for range search. The difficulty is that the range distances
476
+ * have to be scaled using the scaler.
477
+ */
478
+
479
+ template <class C, bool with_id_map = false>
480
+ struct RangeHandler : ResultHandlerCompare<C, with_id_map> {
481
+ using T = typename C::T;
482
+ using TI = typename C::TI;
483
+ using RHC = ResultHandlerCompare<C, with_id_map>;
484
+ using RHC::normalizers;
485
+ using RHC::nq;
486
+
487
+ RangeSearchResult& rres;
488
+ float radius;
489
+ std::vector<uint16_t> thresholds;
490
+ std::vector<size_t> n_per_query;
491
+ size_t q0 = 0;
492
+
493
+ // we cannot use the RangeSearchPartialResult interface because queries can
494
+ // be performed by batches
495
+ struct Triplet {
496
+ idx_t q;
497
+ idx_t b;
498
+ uint16_t dis;
499
+ };
500
+ std::vector<Triplet> triplets;
501
+
502
+ RangeHandler(RangeSearchResult& rres, float radius, size_t ntotal)
503
+ : RHC(rres.nq, ntotal), rres(rres), radius(radius) {
504
+ thresholds.resize(nq);
505
+ n_per_query.resize(nq + 1);
506
+ }
507
+
508
+ virtual void begin(const float* norms) {
509
+ normalizers = norms;
510
+ for (int q = 0; q < nq; ++q) {
511
+ thresholds[q] =
512
+ normalizers[2 * q] * (radius - normalizers[2 * q + 1]);
513
+ }
514
+ }
515
+
516
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
517
+ if (this->disable) {
518
+ return;
519
+ }
520
+ this->adjust_with_origin(q, d0, d1);
521
+
522
+ uint32_t lt_mask = this->get_lt_mask(thresholds[q], b, d0, d1);
523
+
524
+ if (!lt_mask) {
525
+ return;
526
+ }
527
+ ALIGNED(32) uint16_t d32tab[32];
528
+ d0.store(d32tab);
529
+ d1.store(d32tab + 16);
530
+
531
+ while (lt_mask) {
532
+ // find first non-zero
533
+ int j = __builtin_ctz(lt_mask);
534
+ lt_mask -= 1 << j;
535
+ T dis = d32tab[j];
536
+ n_per_query[q]++;
537
+ triplets.push_back({idx_t(q + q0), this->adjust_id(b, j), dis});
538
+ }
539
+ }
521
540
 
522
- t3 += res.cycles;
541
+ void end() override {
542
+ memcpy(rres.lims, n_per_query.data(), sizeof(n_per_query[0]) * nq);
543
+ rres.do_allocation();
544
+ for (auto it = triplets.begin(); it != triplets.end(); ++it) {
545
+ size_t& l = rres.lims[it->q];
546
+ rres.distances[l] = it->dis;
547
+ rres.labels[l] = it->b;
548
+ l++;
549
+ }
550
+ memmove(rres.lims + 1, rres.lims, sizeof(*rres.lims) * rres.nq);
551
+ rres.lims[0] = 0;
552
+
553
+ for (int q = 0; q < nq; q++) {
554
+ float one_a = 1 / normalizers[2 * q];
555
+ float b = normalizers[2 * q + 1];
556
+ for (size_t i = rres.lims[q]; i < rres.lims[q + 1]; i++) {
557
+ rres.distances[i] = rres.distances[i] * one_a + b;
558
+ }
523
559
  }
524
- times[2] += get_cy() - t0;
525
- times[3] += t3;
526
560
  }
527
561
  };
528
562
 
563
+ #ifndef SWIG
564
+
565
+ // handler for a subset of queries
566
+ template <class C, bool with_id_map = false>
567
+ struct PartialRangeHandler : RangeHandler<C, with_id_map> {
568
+ using T = typename C::T;
569
+ using TI = typename C::TI;
570
+ using RHC = RangeHandler<C, with_id_map>;
571
+ using RHC::normalizers;
572
+ using RHC::nq, RHC::q0, RHC::triplets, RHC::n_per_query;
573
+
574
+ RangeSearchPartialResult& pres;
575
+
576
+ PartialRangeHandler(
577
+ RangeSearchPartialResult& pres,
578
+ float radius,
579
+ size_t ntotal,
580
+ size_t q0,
581
+ size_t q1)
582
+ : RangeHandler<C, with_id_map>(*pres.res, radius, ntotal),
583
+ pres(pres) {
584
+ nq = q1 - q0;
585
+ this->q0 = q0;
586
+ }
587
+
588
+ // shift n_per_query right by one slot and reset entry 0 to 0
589
+ void shift_n_per_query() {
590
+ memmove(n_per_query.data() + 1,
591
+ n_per_query.data(),
592
+ nq * sizeof(n_per_query[0]));
593
+ n_per_query[0] = 0;
594
+ }
595
+
596
+ // commit to partial result instead of full RangeResult
597
+ void end() override {
598
+ std::vector<typename RHC::Triplet> sorted_triplets(triplets.size());
599
+ for (int q = 0; q < nq; q++) {
600
+ n_per_query[q + 1] += n_per_query[q];
601
+ }
602
+ shift_n_per_query();
603
+
604
+ for (size_t i = 0; i < triplets.size(); i++) {
605
+ sorted_triplets[n_per_query[triplets[i].q - q0]++] = triplets[i];
606
+ }
607
+ shift_n_per_query();
608
+
609
+ size_t* lims = n_per_query.data();
610
+
611
+ for (int q = 0; q < nq; q++) {
612
+ float one_a = 1 / normalizers[2 * q];
613
+ float b = normalizers[2 * q + 1];
614
+ RangeQueryResult& qres = pres.new_result(q + q0);
615
+ for (size_t i = lims[q]; i < lims[q + 1]; i++) {
616
+ qres.add(
617
+ sorted_triplets[i].dis * one_a + b,
618
+ sorted_triplets[i].b);
619
+ }
620
+ }
621
+ }
622
+ };
623
+
624
+ #endif
625
+
626
+ /********************************************************************************
627
+ * Dynamic dispatching function. The consumer should have a templatized method f
628
+ * that will be replaced with the actual SIMDResultHandler that is determined
629
+ * dynamically.
630
+ */
631
+
632
+ template <class C, bool W, class Consumer, class... Types>
633
+ void dispatch_SIMDResultHanlder_fixedCW(
634
+ SIMDResultHandler& res,
635
+ Consumer& consumer,
636
+ Types... args) {
637
+ if (auto resh = dynamic_cast<SingleResultHandler<C, W>*>(&res)) {
638
+ consumer.template f<SingleResultHandler<C, W>>(*resh, args...);
639
+ } else if (auto resh = dynamic_cast<HeapHandler<C, W>*>(&res)) {
640
+ consumer.template f<HeapHandler<C, W>>(*resh, args...);
641
+ } else if (auto resh = dynamic_cast<ReservoirHandler<C, W>*>(&res)) {
642
+ consumer.template f<ReservoirHandler<C, W>>(*resh, args...);
643
+ } else { // generic handler -- will not be inlined
644
+ FAISS_THROW_IF_NOT_FMT(
645
+ simd_result_handlers_accept_virtual,
646
+ "Running vitrual handler for %s",
647
+ typeid(res).name());
648
+ consumer.template f<SIMDResultHandler>(res, args...);
649
+ }
650
+ }
651
+
652
+ template <class C, class Consumer, class... Types>
653
+ void dispatch_SIMDResultHanlder_fixedC(
654
+ SIMDResultHandler& res,
655
+ Consumer& consumer,
656
+ Types... args) {
657
+ if (res.with_fields) {
658
+ dispatch_SIMDResultHanlder_fixedCW<C, true>(res, consumer, args...);
659
+ } else {
660
+ dispatch_SIMDResultHanlder_fixedCW<C, false>(res, consumer, args...);
661
+ }
662
+ }
663
+
664
+ template <class Consumer, class... Types>
665
+ void dispatch_SIMDResultHanlder(
666
+ SIMDResultHandler& res,
667
+ Consumer& consumer,
668
+ Types... args) {
669
+ if (res.sizeof_ids == 0) {
670
+ if (auto resh = dynamic_cast<StoreResultHandler*>(&res)) {
671
+ consumer.template f<StoreResultHandler>(*resh, args...);
672
+ } else if (auto resh = dynamic_cast<DummyResultHandler*>(&res)) {
673
+ consumer.template f<DummyResultHandler>(*resh, args...);
674
+ } else { // generic path
675
+ FAISS_THROW_IF_NOT_FMT(
676
+ simd_result_handlers_accept_virtual,
677
+ "Running vitrual handler for %s",
678
+ typeid(res).name());
679
+ consumer.template f<SIMDResultHandler>(res, args...);
680
+ }
681
+ } else if (res.sizeof_ids == sizeof(int)) {
682
+ if (res.is_CMax) {
683
+ dispatch_SIMDResultHanlder_fixedC<CMax<uint16_t, int>>(
684
+ res, consumer, args...);
685
+ } else {
686
+ dispatch_SIMDResultHanlder_fixedC<CMin<uint16_t, int>>(
687
+ res, consumer, args...);
688
+ }
689
+ } else if (res.sizeof_ids == sizeof(int64_t)) {
690
+ if (res.is_CMax) {
691
+ dispatch_SIMDResultHanlder_fixedC<CMax<uint16_t, int64_t>>(
692
+ res, consumer, args...);
693
+ } else {
694
+ dispatch_SIMDResultHanlder_fixedC<CMin<uint16_t, int64_t>>(
695
+ res, consumer, args...);
696
+ }
697
+ } else {
698
+ FAISS_THROW_FMT("Unknown id size %d", res.sizeof_ids);
699
+ }
700
+ }
529
701
  } // namespace simd_result_handlers
530
702
 
531
703
  } // namespace faiss