faiss 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +1 -2
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -22
  13. data/vendor/faiss/faiss/Clustering.h +40 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +26 -12
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +40 -10
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
  34. data/vendor/faiss/faiss/IndexHNSW.h +62 -49
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
  38. data/vendor/faiss/faiss/IndexIVF.h +46 -6
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  59. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  60. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
  61. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  62. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  63. data/vendor/faiss/faiss/IndexNSG.h +11 -11
  64. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  65. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  66. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  67. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  68. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  69. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  70. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  71. data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
  72. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  73. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  76. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  77. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  78. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  79. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  81. data/vendor/faiss/faiss/MetricType.h +7 -2
  82. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  83. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  84. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  85. data/vendor/faiss/faiss/clone_index.h +3 -0
  86. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  87. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  88. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  89. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  90. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
  91. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
  96. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  107. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
  108. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  110. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  111. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
  112. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  113. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  114. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  115. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  116. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  117. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  118. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  119. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  121. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
  123. data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  125. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  126. data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
  127. data/vendor/faiss/faiss/impl/HNSW.h +52 -30
  128. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
  130. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
  132. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  133. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  134. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
  138. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
  144. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
  145. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  146. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  147. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  148. data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
  149. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  150. data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
  151. data/vendor/faiss/faiss/impl/io.cpp +23 -15
  152. data/vendor/faiss/faiss/impl/io.h +4 -4
  153. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  154. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  155. data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  159. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
  160. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  161. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  162. data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
  163. data/vendor/faiss/faiss/index_factory.cpp +41 -20
  164. data/vendor/faiss/faiss/index_io.h +12 -5
  165. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  166. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  167. data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
  168. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
  169. data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
  170. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  172. data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
  173. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  174. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  175. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  176. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  177. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  178. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  179. data/vendor/faiss/faiss/utils/distances.cpp +147 -123
  180. data/vendor/faiss/faiss/utils/distances.h +86 -9
  181. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  182. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  183. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  184. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  185. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  186. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  187. data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
  188. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  189. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  190. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  191. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  192. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  193. data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
  194. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  195. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  196. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  197. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
  198. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
  199. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  200. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  201. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  202. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  203. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  204. data/vendor/faiss/faiss/utils/random.h +25 -0
  205. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  206. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  207. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  208. data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
  209. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  210. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  211. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  212. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +120 -7
  214. data/vendor/faiss/faiss/utils/utils.h +60 -20
  215. metadata +23 -4
  216. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -14,40 +14,87 @@
14
14
  #include <faiss/utils/Heap.h>
15
15
  #include <faiss/utils/simdlib.h>
16
16
 
17
+ #include <faiss/impl/FaissAssert.h>
18
+ #include <faiss/impl/IDSelector.h>
19
+ #include <faiss/impl/ResultHandler.h>
17
20
  #include <faiss/impl/platform_macros.h>
18
21
  #include <faiss/utils/AlignedTable.h>
19
22
  #include <faiss/utils/partitioning.h>
20
23
 
21
24
  /** This file contains callbacks for kernels that compute distances.
22
- *
23
- * The SIMDResultHandler object is intended to be templated and inlined.
24
- * Methods:
25
- * - handle(): called when 32 distances are computed and provided in two
26
- * simd16uint16. (q, b) indicate which entry it is in the block.
27
- * - set_block_origin(): set the sub-matrix that is being computed
28
25
  */
29
26
 
30
27
  namespace faiss {
31
28
 
29
+ struct SIMDResultHandler {
30
+ // used to dispatch templates
31
+ bool is_CMax = false;
32
+ uint8_t sizeof_ids = 0;
33
+ bool with_fields = false;
34
+
35
+ /** called when 32 distances are computed and provided in two
36
+ * simd16uint16. (q, b) indicate which entry it is in the block. */
37
+ virtual void handle(
38
+ size_t q,
39
+ size_t b,
40
+ simd16uint16 d0,
41
+ simd16uint16 d1) = 0;
42
+
43
+ /// set the sub-matrix that is being computed
44
+ virtual void set_block_origin(size_t i0, size_t j0) = 0;
45
+
46
+ virtual ~SIMDResultHandler() {}
47
+ };
48
+
49
+ /* Result handler that will return float results eventually */
50
+ struct SIMDResultHandlerToFloat : SIMDResultHandler {
51
+ size_t nq; // number of queries
52
+ size_t ntotal; // ignore excess elements after ntotal
53
+
54
+ /// these fields are used mainly for the IVF variants (with_id_map=true)
55
+ const idx_t* id_map = nullptr; // map offset in invlist to vector id
56
+ const int* q_map = nullptr; // map q to global query
57
+ const uint16_t* dbias =
58
+ nullptr; // table of biases to add to each query (for IVF L2 search)
59
+ const float* normalizers = nullptr; // size 2 * nq, to convert
60
+
61
+ SIMDResultHandlerToFloat(size_t nq, size_t ntotal)
62
+ : nq(nq), ntotal(ntotal) {}
63
+
64
+ virtual void begin(const float* norms) {
65
+ normalizers = norms;
66
+ }
67
+
68
+ // called at end of search to convert int16 distances to float, before
69
+ // normalizers are deallocated
70
+ virtual void end() {
71
+ normalizers = nullptr;
72
+ }
73
+ };
74
+
75
+ FAISS_API extern bool simd_result_handlers_accept_virtual;
76
+
32
77
  namespace simd_result_handlers {
33
78
 
34
- /** Dummy structure that just computes a checksum on results
79
+ /** Dummy structure that just computes a chqecksum on results
35
80
  * (to avoid the computation to be optimized away) */
36
- struct DummyResultHandler {
81
+ struct DummyResultHandler : SIMDResultHandler {
37
82
  size_t cs = 0;
38
83
 
39
- void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
84
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
40
85
  cs += q * 123 + b * 789 + d0.get_scalar_0() + d1.get_scalar_0();
41
86
  }
42
87
 
43
- void set_block_origin(size_t, size_t) {}
88
+ void set_block_origin(size_t, size_t) final {}
89
+
90
+ ~DummyResultHandler() {}
44
91
  };
45
92
 
46
93
  /** memorize results in a nq-by-nb matrix.
47
94
  *
48
95
  * j0 is the current upper-left block of the matrix
49
96
  */
50
- struct StoreResultHandler {
97
+ struct StoreResultHandler : SIMDResultHandler {
51
98
  uint16_t* data;
52
99
  size_t ld; // total number of columns
53
100
  size_t i0 = 0;
@@ -55,32 +102,32 @@ struct StoreResultHandler {
55
102
 
56
103
  StoreResultHandler(uint16_t* data, size_t ld) : data(data), ld(ld) {}
57
104
 
58
- void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
105
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
59
106
  size_t ofs = (q + i0) * ld + j0 + b * 32;
60
107
  d0.store(data + ofs);
61
108
  d1.store(data + ofs + 16);
62
109
  }
63
110
 
64
- void set_block_origin(size_t i0, size_t j0) {
65
- this->i0 = i0;
66
- this->j0 = j0;
111
+ void set_block_origin(size_t i0_in, size_t j0_in) final {
112
+ this->i0 = i0_in;
113
+ this->j0 = j0_in;
67
114
  }
68
115
  };
69
116
 
70
117
  /** stores results in fixed-size matrix. */
71
118
  template <int NQ, int BB>
72
- struct FixedStorageHandler {
119
+ struct FixedStorageHandler : SIMDResultHandler {
73
120
  simd16uint16 dis[NQ][BB];
74
121
  int i0 = 0;
75
122
 
76
- void handle(int q, int b, simd16uint16 d0, simd16uint16 d1) {
123
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
77
124
  dis[q + i0][2 * b] = d0;
78
125
  dis[q + i0][2 * b + 1] = d1;
79
126
  }
80
127
 
81
- void set_block_origin(size_t i0, size_t j0) {
82
- this->i0 = i0;
83
- assert(j0 == 0);
128
+ void set_block_origin(size_t i0_in, size_t j0_in) final {
129
+ this->i0 = i0_in;
130
+ assert(j0_in == 0);
84
131
  }
85
132
 
86
133
  template <class OtherResultHandler>
@@ -91,30 +138,32 @@ struct FixedStorageHandler {
91
138
  }
92
139
  }
93
140
  }
141
+
142
+ virtual ~FixedStorageHandler() {}
94
143
  };
95
144
 
96
- /** Record origin of current block */
145
+ /** Result handler that compares distances to check if they need to be kept */
97
146
  template <class C, bool with_id_map>
98
- struct SIMDResultHandler {
147
+ struct ResultHandlerCompare : SIMDResultHandlerToFloat {
99
148
  using TI = typename C::TI;
100
149
 
101
150
  bool disable = false;
102
151
 
103
152
  int64_t i0 = 0; // query origin
104
153
  int64_t j0 = 0; // db origin
105
- size_t ntotal; // ignore excess elements after ntotal
106
154
 
107
- /// these fields are used mainly for the IVF variants (with_id_map=true)
108
- const TI* id_map; // map offset in invlist to vector id
109
- const int* q_map; // map q to global query
110
- const uint16_t* dbias; // table of biases to add to each query
155
+ const IDSelector* sel;
111
156
 
112
- explicit SIMDResultHandler(size_t ntotal)
113
- : ntotal(ntotal), id_map(nullptr), q_map(nullptr), dbias(nullptr) {}
157
+ ResultHandlerCompare(size_t nq, size_t ntotal, const IDSelector* sel_in)
158
+ : SIMDResultHandlerToFloat(nq, ntotal), sel{sel_in} {
159
+ this->is_CMax = C::is_max;
160
+ this->sizeof_ids = sizeof(typename C::TI);
161
+ this->with_fields = with_id_map;
162
+ }
114
163
 
115
- void set_block_origin(size_t i0, size_t j0) {
116
- this->i0 = i0;
117
- this->j0 = j0;
164
+ void set_block_origin(size_t i0_in, size_t j0_in) final {
165
+ this->i0 = i0_in;
166
+ this->j0 = j0_in;
118
167
  }
119
168
 
120
169
  // adjust handler data for IVF.
@@ -172,43 +221,42 @@ struct SIMDResultHandler {
172
221
  return lt_mask;
173
222
  }
174
223
 
175
- virtual void to_flat_arrays(
176
- float* distances,
177
- int64_t* labels,
178
- const float* normalizers = nullptr) = 0;
179
-
180
- virtual ~SIMDResultHandler() {}
224
+ virtual ~ResultHandlerCompare() {}
181
225
  };
182
226
 
183
227
  /** Special version for k=1 */
184
228
  template <class C, bool with_id_map = false>
185
- struct SingleResultHandler : SIMDResultHandler<C, with_id_map> {
229
+ struct SingleResultHandler : ResultHandlerCompare<C, with_id_map> {
186
230
  using T = typename C::T;
187
231
  using TI = typename C::TI;
188
-
189
- struct Result {
190
- T val;
191
- TI id;
192
- };
193
- std::vector<Result> results;
194
-
195
- SingleResultHandler(size_t nq, size_t ntotal)
196
- : SIMDResultHandler<C, with_id_map>(ntotal), results(nq) {
197
- for (int i = 0; i < nq; i++) {
198
- Result res = {C::neutral(), -1};
199
- results[i] = res;
232
+ using RHC = ResultHandlerCompare<C, with_id_map>;
233
+ using RHC::normalizers;
234
+
235
+ std::vector<int16_t> idis;
236
+ float* dis;
237
+ int64_t* ids;
238
+
239
+ SingleResultHandler(
240
+ size_t nq,
241
+ size_t ntotal,
242
+ float* dis,
243
+ int64_t* ids,
244
+ const IDSelector* sel_in)
245
+ : RHC(nq, ntotal, sel_in), idis(nq), dis(dis), ids(ids) {
246
+ for (size_t i = 0; i < nq; i++) {
247
+ ids[i] = -1;
248
+ idis[i] = C::neutral();
200
249
  }
201
250
  }
202
251
 
203
- void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
252
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
204
253
  if (this->disable) {
205
254
  return;
206
255
  }
207
256
 
208
257
  this->adjust_with_origin(q, d0, d1);
209
258
 
210
- Result& res = results[q];
211
- uint32_t lt_mask = this->get_lt_mask(res.val, b, d0, d1);
259
+ uint32_t lt_mask = this->get_lt_mask(idis[q], b, d0, d1);
212
260
  if (!lt_mask) {
213
261
  return;
214
262
  }
@@ -217,74 +265,87 @@ struct SingleResultHandler : SIMDResultHandler<C, with_id_map> {
217
265
  d0.store(d32tab);
218
266
  d1.store(d32tab + 16);
219
267
 
220
- while (lt_mask) {
221
- // find first non-zero
222
- int j = __builtin_ctz(lt_mask);
223
- lt_mask -= 1 << j;
224
- T dis = d32tab[j];
225
- if (C::cmp(res.val, dis)) {
226
- res.val = dis;
227
- res.id = this->adjust_id(b, j);
268
+ if (this->sel != nullptr) {
269
+ while (lt_mask) {
270
+ // find first non-zero
271
+ int j = __builtin_ctz(lt_mask);
272
+ auto real_idx = this->adjust_id(b, j);
273
+ lt_mask -= 1 << j;
274
+ if (this->sel->is_member(real_idx)) {
275
+ T d = d32tab[j];
276
+ if (C::cmp(idis[q], d)) {
277
+ idis[q] = d;
278
+ ids[q] = real_idx;
279
+ }
280
+ }
281
+ }
282
+ } else {
283
+ while (lt_mask) {
284
+ // find first non-zero
285
+ int j = __builtin_ctz(lt_mask);
286
+ lt_mask -= 1 << j;
287
+ T d = d32tab[j];
288
+ if (C::cmp(idis[q], d)) {
289
+ idis[q] = d;
290
+ ids[q] = this->adjust_id(b, j);
291
+ }
228
292
  }
229
293
  }
230
294
  }
231
295
 
232
- void to_flat_arrays(
233
- float* distances,
234
- int64_t* labels,
235
- const float* normalizers = nullptr) override {
236
- for (int q = 0; q < results.size(); q++) {
296
+ void end() {
297
+ for (size_t q = 0; q < this->nq; q++) {
237
298
  if (!normalizers) {
238
- distances[q] = results[q].val;
299
+ dis[q] = idis[q];
239
300
  } else {
240
301
  float one_a = 1 / normalizers[2 * q];
241
302
  float b = normalizers[2 * q + 1];
242
- distances[q] = b + results[q].val * one_a;
303
+ dis[q] = b + idis[q] * one_a;
243
304
  }
244
- labels[q] = results[q].id;
245
305
  }
246
306
  }
247
307
  };
248
308
 
249
309
  /** Structure that collects results in a min- or max-heap */
250
310
  template <class C, bool with_id_map = false>
251
- struct HeapHandler : SIMDResultHandler<C, with_id_map> {
311
+ struct HeapHandler : ResultHandlerCompare<C, with_id_map> {
252
312
  using T = typename C::T;
253
313
  using TI = typename C::TI;
314
+ using RHC = ResultHandlerCompare<C, with_id_map>;
315
+ using RHC::normalizers;
254
316
 
255
- int nq;
256
- T* heap_dis_tab;
257
- TI* heap_ids_tab;
317
+ std::vector<uint16_t> idis;
318
+ std::vector<TI> iids;
319
+ float* dis;
320
+ int64_t* ids;
258
321
 
259
322
  int64_t k; // number of results to keep
260
323
 
261
324
  HeapHandler(
262
- int nq,
263
- T* heap_dis_tab,
264
- TI* heap_ids_tab,
265
- size_t k,
266
- size_t ntotal)
267
- : SIMDResultHandler<C, with_id_map>(ntotal),
268
- nq(nq),
269
- heap_dis_tab(heap_dis_tab),
270
- heap_ids_tab(heap_ids_tab),
325
+ size_t nq,
326
+ size_t ntotal,
327
+ int64_t k,
328
+ float* dis,
329
+ int64_t* ids,
330
+ const IDSelector* sel_in)
331
+ : RHC(nq, ntotal, sel_in),
332
+ idis(nq * k),
333
+ iids(nq * k),
334
+ dis(dis),
335
+ ids(ids),
271
336
  k(k) {
272
- for (int q = 0; q < nq; q++) {
273
- T* heap_dis_in = heap_dis_tab + q * k;
274
- TI* heap_ids_in = heap_ids_tab + q * k;
275
- heap_heapify<C>(k, heap_dis_in, heap_ids_in);
276
- }
337
+ heap_heapify<C>(k * nq, idis.data(), iids.data());
277
338
  }
278
339
 
279
- void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
340
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
280
341
  if (this->disable) {
281
342
  return;
282
343
  }
283
344
 
284
345
  this->adjust_with_origin(q, d0, d1);
285
346
 
286
- T* heap_dis = heap_dis_tab + q * k;
287
- TI* heap_ids = heap_ids_tab + q * k;
347
+ T* heap_dis = idis.data() + q * k;
348
+ TI* heap_ids = iids.data() + q * k;
288
349
 
289
350
  uint16_t cur_thresh =
290
351
  heap_dis[0] < 65536 ? (uint16_t)(heap_dis[0]) : 0xffff;
@@ -300,29 +361,41 @@ struct HeapHandler : SIMDResultHandler<C, with_id_map> {
300
361
  d0.store(d32tab);
301
362
  d1.store(d32tab + 16);
302
363
 
303
- while (lt_mask) {
304
- // find first non-zero
305
- int j = __builtin_ctz(lt_mask);
306
- lt_mask -= 1 << j;
307
- T dis = d32tab[j];
308
- if (C::cmp(heap_dis[0], dis)) {
309
- int64_t idx = this->adjust_id(b, j);
310
- heap_pop<C>(k, heap_dis, heap_ids);
311
- heap_push<C>(k, heap_dis, heap_ids, dis, idx);
364
+ if (this->sel != nullptr) {
365
+ while (lt_mask) {
366
+ // find first non-zero
367
+ int j = __builtin_ctz(lt_mask);
368
+ auto real_idx = this->adjust_id(b, j);
369
+ lt_mask -= 1 << j;
370
+ if (this->sel->is_member(real_idx)) {
371
+ T dis = d32tab[j];
372
+ if (C::cmp(heap_dis[0], dis)) {
373
+ heap_replace_top<C>(
374
+ k, heap_dis, heap_ids, dis, real_idx);
375
+ }
376
+ }
377
+ }
378
+ } else {
379
+ while (lt_mask) {
380
+ // find first non-zero
381
+ int j = __builtin_ctz(lt_mask);
382
+ lt_mask -= 1 << j;
383
+ T dis = d32tab[j];
384
+ if (C::cmp(heap_dis[0], dis)) {
385
+ int64_t idx = this->adjust_id(b, j);
386
+ heap_replace_top<C>(k, heap_dis, heap_ids, dis, idx);
387
+ }
312
388
  }
313
389
  }
314
390
  }
315
391
 
316
- void to_flat_arrays(
317
- float* distances,
318
- int64_t* labels,
319
- const float* normalizers = nullptr) override {
320
- for (int q = 0; q < nq; q++) {
321
- T* heap_dis_in = heap_dis_tab + q * k;
322
- TI* heap_ids_in = heap_ids_tab + q * k;
392
+ void end() override {
393
+ for (size_t q = 0; q < this->nq; q++) {
394
+ T* heap_dis_in = idis.data() + q * k;
395
+ TI* heap_ids_in = iids.data() + q * k;
323
396
  heap_reorder<C>(k, heap_dis_in, heap_ids_in);
324
- int64_t* heap_ids = labels + q * k;
325
- float* heap_dis = distances + q * k;
397
+ float* heap_dis = dis + q * k;
398
+ int64_t* heap_ids = ids + q * k;
326
399
 
327
400
  float one_a = 1.0, b = 0.0;
328
401
  if (normalizers) {
@@ -330,8 +403,8 @@ struct HeapHandler : SIMDResultHandler<C, with_id_map> {
330
403
  b = normalizers[2 * q + 1];
331
404
  }
332
405
  for (int j = 0; j < k; j++) {
333
- heap_ids[j] = heap_ids_in[j];
334
406
  heap_dis[j] = heap_dis_in[j] * one_a + b;
407
+ heap_ids[j] = heap_ids_in[j];
335
408
  }
336
409
  }
337
410
  }
@@ -342,114 +415,49 @@ struct HeapHandler : SIMDResultHandler<C, with_id_map> {
342
415
  * Results are stored when they are below the threshold until the capacity is
343
416
  * reached. Then a partition sort is used to update the threshold. */
344
417
 
345
- namespace {
346
-
347
- uint64_t get_cy() {
348
- #ifdef MICRO_BENCHMARK
349
- uint32_t high, low;
350
- asm volatile("rdtsc \n\t" : "=a"(low), "=d"(high));
351
- return ((uint64_t)high << 32) | (low);
352
- #else
353
- return 0;
354
- #endif
355
- }
356
-
357
- } // anonymous namespace
358
-
359
- template <class C>
360
- struct ReservoirTopN {
361
- using T = typename C::T;
362
- using TI = typename C::TI;
363
-
364
- T* vals;
365
- TI* ids;
366
-
367
- size_t i; // number of stored elements
368
- size_t n; // number of requested elements
369
- size_t capacity; // size of storage
370
- size_t cycles = 0;
371
-
372
- T threshold; // current threshold
373
-
374
- ReservoirTopN(size_t n, size_t capacity, T* vals, TI* ids)
375
- : vals(vals), ids(ids), i(0), n(n), capacity(capacity) {
376
- assert(n < capacity);
377
- threshold = C::neutral();
378
- }
379
-
380
- void add(T val, TI id) {
381
- if (C::cmp(threshold, val)) {
382
- if (i == capacity) {
383
- shrink_fuzzy();
384
- }
385
- vals[i] = val;
386
- ids[i] = id;
387
- i++;
388
- }
389
- }
390
-
391
- /// shrink number of stored elements to n
392
- void shrink_xx() {
393
- uint64_t t0 = get_cy();
394
- qselect(vals, ids, i, n);
395
- i = n; // forget all elements above i = n
396
- threshold = C::Crev::neutral();
397
- for (size_t j = 0; j < n; j++) {
398
- if (C::cmp(vals[j], threshold)) {
399
- threshold = vals[j];
400
- }
401
- }
402
- cycles += get_cy() - t0;
403
- }
404
-
405
- void shrink() {
406
- uint64_t t0 = get_cy();
407
- threshold = partition<C>(vals, ids, i, n);
408
- i = n;
409
- cycles += get_cy() - t0;
410
- }
411
-
412
- void shrink_fuzzy() {
413
- uint64_t t0 = get_cy();
414
- assert(i == capacity);
415
- threshold = partition_fuzzy<C>(
416
- vals, ids, capacity, n, (capacity + n) / 2, &i);
417
- cycles += get_cy() - t0;
418
- }
419
- };
420
-
421
418
  /** Handler built from several ReservoirTopN (one per query) */
422
419
  template <class C, bool with_id_map = false>
423
- struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
420
+ struct ReservoirHandler : ResultHandlerCompare<C, with_id_map> {
424
421
  using T = typename C::T;
425
422
  using TI = typename C::TI;
423
+ using RHC = ResultHandlerCompare<C, with_id_map>;
424
+ using RHC::normalizers;
426
425
 
427
426
  size_t capacity; // rounded up to multiple of 16
427
+
428
+ // where the final results will be written
429
+ float* dis;
430
+ int64_t* ids;
431
+
428
432
  std::vector<TI> all_ids;
429
433
  AlignedTable<T> all_vals;
430
-
431
434
  std::vector<ReservoirTopN<C>> reservoirs;
432
435
 
433
- uint64_t times[4];
434
-
435
- ReservoirHandler(size_t nq, size_t ntotal, size_t n, size_t capacity_in)
436
- : SIMDResultHandler<C, with_id_map>(ntotal),
437
- capacity((capacity_in + 15) & ~15),
438
- all_ids(nq * capacity),
439
- all_vals(nq * capacity) {
436
+ ReservoirHandler(
437
+ size_t nq,
438
+ size_t ntotal,
439
+ size_t k,
440
+ size_t cap,
441
+ float* dis,
442
+ int64_t* ids,
443
+ const IDSelector* sel_in)
444
+ : RHC(nq, ntotal, sel_in),
445
+ capacity((cap + 15) & ~15),
446
+ dis(dis),
447
+ ids(ids) {
440
448
  assert(capacity % 16 == 0);
441
- for (size_t i = 0; i < nq; i++) {
449
+ all_ids.resize(nq * capacity);
450
+ all_vals.resize(nq * capacity);
451
+ for (size_t q = 0; q < nq; q++) {
442
452
  reservoirs.emplace_back(
443
- n,
453
+ k,
444
454
  capacity,
445
- all_vals.get() + i * capacity,
446
- all_ids.data() + i * capacity);
455
+ all_vals.get() + q * capacity,
456
+ all_ids.data() + q * capacity);
447
457
  }
448
- times[0] = times[1] = times[2] = times[3] = 0;
449
458
  }
450
459
 
451
- void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
452
- uint64_t t0 = get_cy();
460
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
453
461
  if (this->disable) {
454
462
  return;
455
463
  }
@@ -457,8 +465,6 @@ struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
457
465
 
458
466
  ReservoirTopN<C>& res = reservoirs[q];
459
467
  uint32_t lt_mask = this->get_lt_mask(res.threshold, b, d0, d1);
460
- uint64_t t1 = get_cy();
461
- times[0] += t1 - t0;
462
468
 
463
469
  if (!lt_mask) {
464
470
  return;
@@ -467,65 +473,315 @@ struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
467
473
  d0.store(d32tab);
468
474
  d1.store(d32tab + 16);
469
475
 
470
- while (lt_mask) {
471
- // find first non-zero
472
- int j = __builtin_ctz(lt_mask);
473
- lt_mask -= 1 << j;
474
- T dis = d32tab[j];
475
- res.add(dis, this->adjust_id(b, j));
476
+ if (this->sel != nullptr) {
477
+ while (lt_mask) {
478
+ // find first non-zero
479
+ int j = __builtin_ctz(lt_mask);
480
+ auto real_idx = this->adjust_id(b, j);
481
+ lt_mask -= 1 << j;
482
+ if (this->sel->is_member(real_idx)) {
483
+ T dis = d32tab[j];
484
+ res.add(dis, real_idx);
485
+ }
486
+ }
487
+ } else {
488
+ while (lt_mask) {
489
+ // find first non-zero
490
+ int j = __builtin_ctz(lt_mask);
491
+ lt_mask -= 1 << j;
492
+ T dis = d32tab[j];
493
+ res.add(dis, this->adjust_id(b, j));
494
+ }
476
495
  }
477
- times[1] += get_cy() - t1;
478
496
  }
479
497
 
480
- void to_flat_arrays(
481
- float* distances,
482
- int64_t* labels,
483
- const float* normalizers = nullptr) override {
498
+ void end() override {
484
499
  using Cf = typename std::conditional<
485
500
  C::is_max,
486
501
  CMax<float, int64_t>,
487
502
  CMin<float, int64_t>>::type;
488
503
 
489
- uint64_t t0 = get_cy();
490
- uint64_t t3 = 0;
491
504
  std::vector<int> perm(reservoirs[0].n);
492
- for (int q = 0; q < reservoirs.size(); q++) {
505
+ for (size_t q = 0; q < reservoirs.size(); q++) {
493
506
  ReservoirTopN<C>& res = reservoirs[q];
494
507
  size_t n = res.n;
495
508
 
496
509
  if (res.i > res.n) {
497
510
  res.shrink();
498
511
  }
499
- int64_t* heap_ids = labels + q * n;
500
- float* heap_dis = distances + q * n;
512
+ int64_t* heap_ids = ids + q * n;
513
+ float* heap_dis = dis + q * n;
501
514
 
502
515
  float one_a = 1.0, b = 0.0;
503
516
  if (normalizers) {
504
517
  one_a = 1 / normalizers[2 * q];
505
518
  b = normalizers[2 * q + 1];
506
519
  }
507
- for (int i = 0; i < res.i; i++) {
520
+ for (size_t i = 0; i < res.i; i++) {
508
521
  perm[i] = i;
509
522
  }
510
523
  // indirect sort of result arrays
511
524
  std::sort(perm.begin(), perm.begin() + res.i, [&res](int i, int j) {
512
525
  return C::cmp(res.vals[j], res.vals[i]);
513
526
  });
514
- for (int i = 0; i < res.i; i++) {
527
+ for (size_t i = 0; i < res.i; i++) {
515
528
  heap_dis[i] = res.vals[perm[i]] * one_a + b;
516
529
  heap_ids[i] = res.ids[perm[i]];
517
530
  }
518
531
 
519
532
  // possibly add empty results
520
533
  heap_heapify<Cf>(n - res.i, heap_dis + res.i, heap_ids + res.i);
534
+ }
535
+ }
536
+ };
537
+
538
+ /** Result handler for range search. The difficulty is that the range distances
539
+ * have to be scaled using the scaler.
540
+ */
541
+
542
+ template <class C, bool with_id_map = false>
543
+ struct RangeHandler : ResultHandlerCompare<C, with_id_map> {
544
+ using T = typename C::T;
545
+ using TI = typename C::TI;
546
+ using RHC = ResultHandlerCompare<C, with_id_map>;
547
+ using RHC::normalizers;
548
+ using RHC::nq;
549
+
550
+ RangeSearchResult& rres;
551
+ float radius;
552
+ std::vector<uint16_t> thresholds;
553
+ std::vector<size_t> n_per_query;
554
+ size_t q0 = 0;
555
+
556
+ // we cannot use the RangeSearchPartialResult interface because queries can
557
+ // be performed in batches
558
+ struct Triplet {
559
+ idx_t q;
560
+ idx_t b;
561
+ uint16_t dis;
562
+ };
563
+ std::vector<Triplet> triplets;
564
+
565
+ RangeHandler(
566
+ RangeSearchResult& rres,
567
+ float radius,
568
+ size_t ntotal,
569
+ const IDSelector* sel_in)
570
+ : RHC(rres.nq, ntotal, sel_in), rres(rres), radius(radius) {
571
+ thresholds.resize(nq);
572
+ n_per_query.resize(nq + 1);
573
+ }
574
+
575
+ virtual void begin(const float* norms) override {
576
+ normalizers = norms;
577
+ for (int q = 0; q < nq; ++q) {
578
+ thresholds[q] =
579
+ normalizers[2 * q] * (radius - normalizers[2 * q + 1]);
580
+ }
581
+ }
582
+
583
+ void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final {
584
+ if (this->disable) {
585
+ return;
586
+ }
587
+ this->adjust_with_origin(q, d0, d1);
521
588
 
522
- t3 += res.cycles;
589
+ uint32_t lt_mask = this->get_lt_mask(thresholds[q], b, d0, d1);
590
+
591
+ if (!lt_mask) {
592
+ return;
593
+ }
594
+ ALIGNED(32) uint16_t d32tab[32];
595
+ d0.store(d32tab);
596
+ d1.store(d32tab + 16);
597
+
598
+ if (this->sel != nullptr) {
599
+ while (lt_mask) {
600
+ // find first non-zero
601
+ int j = __builtin_ctz(lt_mask);
602
+ lt_mask -= 1 << j;
603
+
604
+ auto real_idx = this->adjust_id(b, j);
605
+ if (this->sel->is_member(real_idx)) {
606
+ T dis = d32tab[j];
607
+ n_per_query[q]++;
608
+ triplets.push_back({idx_t(q + q0), real_idx, dis});
609
+ }
610
+ }
611
+ } else {
612
+ while (lt_mask) {
613
+ // find first non-zero
614
+ int j = __builtin_ctz(lt_mask);
615
+ lt_mask -= 1 << j;
616
+ T dis = d32tab[j];
617
+ n_per_query[q]++;
618
+ triplets.push_back({idx_t(q + q0), this->adjust_id(b, j), dis});
619
+ }
620
+ }
621
+ }
622
+
623
+ void end() override {
624
+ memcpy(rres.lims, n_per_query.data(), sizeof(n_per_query[0]) * nq);
625
+ rres.do_allocation();
626
+ for (auto it = triplets.begin(); it != triplets.end(); ++it) {
627
+ size_t& l = rres.lims[it->q];
628
+ rres.distances[l] = it->dis;
629
+ rres.labels[l] = it->b;
630
+ l++;
631
+ }
632
+ memmove(rres.lims + 1, rres.lims, sizeof(*rres.lims) * rres.nq);
633
+ rres.lims[0] = 0;
634
+
635
+ for (int q = 0; q < nq; q++) {
636
+ float one_a = 1 / normalizers[2 * q];
637
+ float b = normalizers[2 * q + 1];
638
+ for (size_t i = rres.lims[q]; i < rres.lims[q + 1]; i++) {
639
+ rres.distances[i] = rres.distances[i] * one_a + b;
640
+ }
523
641
  }
524
- times[2] += get_cy() - t0;
525
- times[3] += t3;
526
642
  }
527
643
  };
528
644
 
645
+ #ifndef SWIG
646
+
647
+ // handler for a subset of queries
648
+ template <class C, bool with_id_map = false>
649
+ struct PartialRangeHandler : RangeHandler<C, with_id_map> {
650
+ using T = typename C::T;
651
+ using TI = typename C::TI;
652
+ using RHC = RangeHandler<C, with_id_map>;
653
+ using RHC::normalizers;
654
+ using RHC::nq, RHC::q0, RHC::triplets, RHC::n_per_query;
655
+
656
+ RangeSearchPartialResult& pres;
657
+
658
+ PartialRangeHandler(
659
+ RangeSearchPartialResult& pres,
660
+ float radius,
661
+ size_t ntotal,
662
+ size_t q0,
663
+ size_t q1,
664
+ const IDSelector* sel_in)
665
+ : RangeHandler<C, with_id_map>(*pres.res, radius, ntotal, sel_in),
666
+ pres(pres) {
667
+ nq = q1 - q0;
668
+ this->q0 = q0;
669
+ }
670
+
671
+ // shift left n_per_query
672
+ void shift_n_per_query() {
673
+ memmove(n_per_query.data() + 1,
674
+ n_per_query.data(),
675
+ nq * sizeof(n_per_query[0]));
676
+ n_per_query[0] = 0;
677
+ }
678
+
679
+ // commit to partial result instead of full RangeResult
680
+ void end() override {
681
+ std::vector<typename RHC::Triplet> sorted_triplets(triplets.size());
682
+ for (int q = 0; q < nq; q++) {
683
+ n_per_query[q + 1] += n_per_query[q];
684
+ }
685
+ shift_n_per_query();
686
+
687
+ for (size_t i = 0; i < triplets.size(); i++) {
688
+ sorted_triplets[n_per_query[triplets[i].q - q0]++] = triplets[i];
689
+ }
690
+ shift_n_per_query();
691
+
692
+ size_t* lims = n_per_query.data();
693
+
694
+ for (int q = 0; q < nq; q++) {
695
+ float one_a = 1 / normalizers[2 * q];
696
+ float b = normalizers[2 * q + 1];
697
+ RangeQueryResult& qres = pres.new_result(q + q0);
698
+ for (size_t i = lims[q]; i < lims[q + 1]; i++) {
699
+ qres.add(
700
+ sorted_triplets[i].dis * one_a + b,
701
+ sorted_triplets[i].b);
702
+ }
703
+ }
704
+ }
705
+ };
706
+
707
+ #endif
708
+
709
+ /********************************************************************************
710
+ * Dynamic dispatching function. The consumer should have a templatized method f
711
+ * that is instantiated with the concrete SIMDResultHandler type, which is
712
+ * determined dynamically at run time.
713
+ */
714
+
715
+ template <class C, bool W, class Consumer, class... Types>
716
+ void dispatch_SIMDResultHandler_fixedCW(
717
+ SIMDResultHandler& res,
718
+ Consumer& consumer,
719
+ Types... args) {
720
+ if (auto resh = dynamic_cast<SingleResultHandler<C, W>*>(&res)) {
721
+ consumer.template f<SingleResultHandler<C, W>>(*resh, args...);
722
+ } else if (auto resh = dynamic_cast<HeapHandler<C, W>*>(&res)) {
723
+ consumer.template f<HeapHandler<C, W>>(*resh, args...);
724
+ } else if (auto resh = dynamic_cast<ReservoirHandler<C, W>*>(&res)) {
725
+ consumer.template f<ReservoirHandler<C, W>>(*resh, args...);
726
+ } else { // generic handler -- will not be inlined
727
+ FAISS_THROW_IF_NOT_FMT(
728
+ simd_result_handlers_accept_virtual,
729
+ "Running vitrual handler for %s",
730
+ typeid(res).name());
731
+ consumer.template f<SIMDResultHandler>(res, args...);
732
+ }
733
+ }
734
+
735
+ template <class C, class Consumer, class... Types>
736
+ void dispatch_SIMDResultHandler_fixedC(
737
+ SIMDResultHandler& res,
738
+ Consumer& consumer,
739
+ Types... args) {
740
+ if (res.with_fields) {
741
+ dispatch_SIMDResultHandler_fixedCW<C, true>(res, consumer, args...);
742
+ } else {
743
+ dispatch_SIMDResultHandler_fixedCW<C, false>(res, consumer, args...);
744
+ }
745
+ }
746
+
747
+ template <class Consumer, class... Types>
748
+ void dispatch_SIMDResultHandler(
749
+ SIMDResultHandler& res,
750
+ Consumer& consumer,
751
+ Types... args) {
752
+ if (res.sizeof_ids == 0) {
753
+ if (auto resh = dynamic_cast<StoreResultHandler*>(&res)) {
754
+ consumer.template f<StoreResultHandler>(*resh, args...);
755
+ } else if (auto resh = dynamic_cast<DummyResultHandler*>(&res)) {
756
+ consumer.template f<DummyResultHandler>(*resh, args...);
757
+ } else { // generic path
758
+ FAISS_THROW_IF_NOT_FMT(
759
+ simd_result_handlers_accept_virtual,
760
+ "Running vitrual handler for %s",
761
+ typeid(res).name());
762
+ consumer.template f<SIMDResultHandler>(res, args...);
763
+ }
764
+ } else if (res.sizeof_ids == sizeof(int)) {
765
+ if (res.is_CMax) {
766
+ dispatch_SIMDResultHandler_fixedC<CMax<uint16_t, int>>(
767
+ res, consumer, args...);
768
+ } else {
769
+ dispatch_SIMDResultHandler_fixedC<CMin<uint16_t, int>>(
770
+ res, consumer, args...);
771
+ }
772
+ } else if (res.sizeof_ids == sizeof(int64_t)) {
773
+ if (res.is_CMax) {
774
+ dispatch_SIMDResultHandler_fixedC<CMax<uint16_t, int64_t>>(
775
+ res, consumer, args...);
776
+ } else {
777
+ dispatch_SIMDResultHandler_fixedC<CMin<uint16_t, int64_t>>(
778
+ res, consumer, args...);
779
+ }
780
+ } else {
781
+ FAISS_THROW_FMT("Unknown id size %d", res.sizeof_ids);
782
+ }
783
+ }
784
+
529
785
  } // namespace simd_result_handlers
530
786
 
531
787
  } // namespace faiss