faiss 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +1 -2
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -22
  13. data/vendor/faiss/faiss/Clustering.h +40 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +26 -12
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +40 -10
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
  34. data/vendor/faiss/faiss/IndexHNSW.h +62 -49
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
  38. data/vendor/faiss/faiss/IndexIVF.h +46 -6
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  59. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  60. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
  61. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  62. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  63. data/vendor/faiss/faiss/IndexNSG.h +11 -11
  64. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  65. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  66. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  67. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  68. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  69. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  70. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  71. data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
  72. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  73. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  76. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  77. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  78. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  79. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  81. data/vendor/faiss/faiss/MetricType.h +7 -2
  82. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  83. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  84. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  85. data/vendor/faiss/faiss/clone_index.h +3 -0
  86. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  87. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  88. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  89. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  90. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
  91. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
  96. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  107. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
  108. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  110. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  111. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
  112. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  113. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  114. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  115. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  116. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  117. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  118. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  119. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  121. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
  123. data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  125. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  126. data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
  127. data/vendor/faiss/faiss/impl/HNSW.h +52 -30
  128. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
  130. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
  132. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  133. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  134. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
  138. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
  144. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
  145. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  146. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  147. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  148. data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
  149. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  150. data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
  151. data/vendor/faiss/faiss/impl/io.cpp +23 -15
  152. data/vendor/faiss/faiss/impl/io.h +4 -4
  153. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  154. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  155. data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  159. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
  160. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  161. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  162. data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
  163. data/vendor/faiss/faiss/index_factory.cpp +41 -20
  164. data/vendor/faiss/faiss/index_io.h +12 -5
  165. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  166. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  167. data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
  168. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
  169. data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
  170. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  172. data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
  173. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  174. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  175. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  176. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  177. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  178. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  179. data/vendor/faiss/faiss/utils/distances.cpp +147 -123
  180. data/vendor/faiss/faiss/utils/distances.h +86 -9
  181. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  182. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  183. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  184. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  185. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  186. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  187. data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
  188. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  189. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  190. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  191. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  192. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  193. data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
  194. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  195. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  196. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  197. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
  198. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
  199. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  200. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  201. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  202. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  203. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  204. data/vendor/faiss/faiss/utils/random.h +25 -0
  205. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  206. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  207. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  208. data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
  209. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  210. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  211. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  212. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +120 -7
  214. data/vendor/faiss/faiss/utils/utils.h +60 -20
  215. metadata +23 -4
  216. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -14,6 +14,9 @@
14
14
 
15
15
  namespace faiss {
16
16
 
17
+ // declared in simd_result_handlers.h
18
+ bool simd_result_handlers_accept_virtual = true;
19
+
17
20
  using namespace simd_result_handlers;
18
21
 
19
22
  /************************************************************
@@ -28,6 +31,8 @@ namespace {
28
31
  * writes results in a ResultHandler
29
32
  */
30
33
 
34
+ #ifndef __AVX512F__
35
+
31
36
  template <int NQ, class ResultHandler, class Scaler>
32
37
  void kernel_accumulate_block(
33
38
  int nsq,
@@ -108,6 +113,451 @@ void kernel_accumulate_block(
108
113
  }
109
114
  }
110
115
 
116
+ #else
117
+
118
+ // a special version for NQ=1.
119
+ // Despite the function being large in the text form, it compiles to a very
120
+ // compact assembler code.
121
+ FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
122
+ template <class ResultHandler, class Scaler>
123
+ void kernel_accumulate_block_avx512_nq1(
124
+ int nsq,
125
+ const uint8_t* codes,
126
+ const uint8_t* LUT,
127
+ ResultHandler& res,
128
+ const Scaler& scaler) {
129
+ // NQ is kept in order to match the similarity to baseline function
130
+ constexpr int NQ = 1;
131
+ // distance accumulators. We can accept more for NQ=1
132
+ // layout: accu[q][b]: distance accumulator for vectors 32*b..32*b+15
133
+ simd32uint16 accu[NQ][4];
134
+ // layout: accu[q][b]: distance accumulator for vectors 32*b+16..32*b+31
135
+ simd32uint16 accu1[NQ][4];
136
+
137
+ for (int q = 0; q < NQ; q++) {
138
+ for (int b = 0; b < 4; b++) {
139
+ accu[q][b].clear();
140
+ accu1[q][b].clear();
141
+ }
142
+ }
143
+
144
+ // process "nsq - scaler.nscale" part
145
+ const int nsq_minus_nscale = nsq - scaler.nscale;
146
+ const int nsq_minus_nscale_8 = (nsq_minus_nscale / 8) * 8;
147
+ const int nsq_minus_nscale_4 = (nsq_minus_nscale / 4) * 4;
148
+
149
+ // process in chunks of 8
150
+ for (int sq = 0; sq < nsq_minus_nscale_8; sq += 8) {
151
+ // prefetch
152
+ simd64uint8 c(codes);
153
+ codes += 64;
154
+
155
+ simd64uint8 c1(codes);
156
+ codes += 64;
157
+
158
+ simd64uint8 mask(0xf);
159
+ // shift op does not exist for int8...
160
+ simd64uint8 chi = simd64uint8(simd32uint16(c) >> 4) & mask;
161
+ simd64uint8 clo = c & mask;
162
+
163
+ simd64uint8 c1hi = simd64uint8(simd32uint16(c1) >> 4) & mask;
164
+ simd64uint8 c1lo = c1 & mask;
165
+
166
+ for (int q = 0; q < NQ; q++) {
167
+ // load LUTs for 4 quantizers
168
+ simd64uint8 lut(LUT);
169
+ LUT += 64;
170
+
171
+ {
172
+ simd64uint8 res0 = lut.lookup_4_lanes(clo);
173
+ simd64uint8 res1 = lut.lookup_4_lanes(chi);
174
+
175
+ accu[q][0] += simd32uint16(res0);
176
+ accu[q][1] += simd32uint16(res0) >> 8;
177
+
178
+ accu[q][2] += simd32uint16(res1);
179
+ accu[q][3] += simd32uint16(res1) >> 8;
180
+ }
181
+ }
182
+
183
+ for (int q = 0; q < NQ; q++) {
184
+ // load LUTs for 4 quantizers
185
+ simd64uint8 lut(LUT);
186
+ LUT += 64;
187
+
188
+ {
189
+ simd64uint8 res0 = lut.lookup_4_lanes(c1lo);
190
+ simd64uint8 res1 = lut.lookup_4_lanes(c1hi);
191
+
192
+ accu1[q][0] += simd32uint16(res0);
193
+ accu1[q][1] += simd32uint16(res0) >> 8;
194
+
195
+ accu1[q][2] += simd32uint16(res1);
196
+ accu1[q][3] += simd32uint16(res1) >> 8;
197
+ }
198
+ }
199
+ }
200
+
201
+ // process leftovers: a single chunk of size 4
202
+ if (nsq_minus_nscale_8 != nsq_minus_nscale_4) {
203
+ // prefetch
204
+ simd64uint8 c(codes);
205
+ codes += 64;
206
+
207
+ simd64uint8 mask(0xf);
208
+ // shift op does not exist for int8...
209
+ simd64uint8 chi = simd64uint8(simd32uint16(c) >> 4) & mask;
210
+ simd64uint8 clo = c & mask;
211
+
212
+ for (int q = 0; q < NQ; q++) {
213
+ // load LUTs for 4 quantizers
214
+ simd64uint8 lut(LUT);
215
+ LUT += 64;
216
+
217
+ simd64uint8 res0 = lut.lookup_4_lanes(clo);
218
+ simd64uint8 res1 = lut.lookup_4_lanes(chi);
219
+
220
+ accu[q][0] += simd32uint16(res0);
221
+ accu[q][1] += simd32uint16(res0) >> 8;
222
+
223
+ accu[q][2] += simd32uint16(res1);
224
+ accu[q][3] += simd32uint16(res1) >> 8;
225
+ }
226
+ }
227
+
228
+ // process leftovers: a single chunk of size 2
229
+ if (nsq_minus_nscale_4 != nsq_minus_nscale) {
230
+ // prefetch
231
+ simd32uint8 c(codes);
232
+ codes += 32;
233
+
234
+ simd32uint8 mask(0xf);
235
+ // shift op does not exist for int8...
236
+ simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask;
237
+ simd32uint8 clo = c & mask;
238
+
239
+ for (int q = 0; q < NQ; q++) {
240
+ // load LUTs for 2 quantizers
241
+ simd32uint8 lut(LUT);
242
+ LUT += 32;
243
+
244
+ simd32uint8 res0 = lut.lookup_2_lanes(clo);
245
+ simd32uint8 res1 = lut.lookup_2_lanes(chi);
246
+
247
+ accu[q][0] += simd32uint16(simd16uint16(res0));
248
+ accu[q][1] += simd32uint16(simd16uint16(res0) >> 8);
249
+
250
+ accu[q][2] += simd32uint16(simd16uint16(res1));
251
+ accu[q][3] += simd32uint16(simd16uint16(res1) >> 8);
252
+ }
253
+ }
254
+
255
+ // process "sq" part
256
+ const int nscale = scaler.nscale;
257
+ const int nscale_8 = (nscale / 8) * 8;
258
+ const int nscale_4 = (nscale / 4) * 4;
259
+
260
+ // process in chunks of 8
261
+ for (int sq = 0; sq < nscale_8; sq += 8) {
262
+ // prefetch
263
+ simd64uint8 c(codes);
264
+ codes += 64;
265
+
266
+ simd64uint8 c1(codes);
267
+ codes += 64;
268
+
269
+ simd64uint8 mask(0xf);
270
+ // shift op does not exist for int8...
271
+ simd64uint8 chi = simd64uint8(simd32uint16(c) >> 4) & mask;
272
+ simd64uint8 clo = c & mask;
273
+
274
+ simd64uint8 c1hi = simd64uint8(simd32uint16(c1) >> 4) & mask;
275
+ simd64uint8 c1lo = c1 & mask;
276
+
277
+ for (int q = 0; q < NQ; q++) {
278
+ // load LUTs for 4 quantizers
279
+ simd64uint8 lut(LUT);
280
+ LUT += 64;
281
+
282
+ {
283
+ simd64uint8 res0 = scaler.lookup(lut, clo);
284
+ accu[q][0] += scaler.scale_lo(res0); // handle vectors 0..15
285
+ accu[q][1] += scaler.scale_hi(res0); // handle vectors 16..31
286
+
287
+ simd64uint8 res1 = scaler.lookup(lut, chi);
288
+ accu[q][2] += scaler.scale_lo(res1); // handle vectors 32..47
289
+ accu[q][3] += scaler.scale_hi(res1); // handle vectors 48..63
290
+ }
291
+ }
292
+
293
+ for (int q = 0; q < NQ; q++) {
294
+ // load LUTs for 4 quantizers
295
+ simd64uint8 lut(LUT);
296
+ LUT += 64;
297
+
298
+ {
299
+ simd64uint8 res0 = scaler.lookup(lut, c1lo);
300
+ accu1[q][0] += scaler.scale_lo(res0); // handle vectors 0..7
301
+ accu1[q][1] += scaler.scale_hi(res0); // handle vectors 8..15
302
+
303
+ simd64uint8 res1 = scaler.lookup(lut, c1hi);
304
+ accu1[q][2] += scaler.scale_lo(res1); // handle vectors 16..23
305
+ accu1[q][3] += scaler.scale_hi(res1); // handle vectors 24..31
306
+ }
307
+ }
308
+ }
309
+
310
+ // process leftovers: a single chunk of size 4
311
+ if (nscale_8 != nscale_4) {
312
+ // prefetch
313
+ simd64uint8 c(codes);
314
+ codes += 64;
315
+
316
+ simd64uint8 mask(0xf);
317
+ // shift op does not exist for int8...
318
+ simd64uint8 chi = simd64uint8(simd32uint16(c) >> 4) & mask;
319
+ simd64uint8 clo = c & mask;
320
+
321
+ for (int q = 0; q < NQ; q++) {
322
+ // load LUTs for 4 quantizers
323
+ simd64uint8 lut(LUT);
324
+ LUT += 64;
325
+
326
+ simd64uint8 res0 = scaler.lookup(lut, clo);
327
+ accu[q][0] += scaler.scale_lo(res0); // handle vectors 0..15
328
+ accu[q][1] += scaler.scale_hi(res0); // handle vectors 16..31
329
+
330
+ simd64uint8 res1 = scaler.lookup(lut, chi);
331
+ accu[q][2] += scaler.scale_lo(res1); // handle vectors 32..47
332
+ accu[q][3] += scaler.scale_hi(res1); // handle vectors 48..63
333
+ }
334
+ }
335
+
336
+ // process leftovers: a single chunk of size 2
337
+ if (nscale_4 != nscale) {
338
+ // prefetch
339
+ simd32uint8 c(codes);
340
+ codes += 32;
341
+
342
+ simd32uint8 mask(0xf);
343
+ // shift op does not exist for int8...
344
+ simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask;
345
+ simd32uint8 clo = c & mask;
346
+
347
+ for (int q = 0; q < NQ; q++) {
348
+ // load LUTs for 2 quantizers
349
+ simd32uint8 lut(LUT);
350
+ LUT += 32;
351
+
352
+ simd32uint8 res0 = scaler.lookup(lut, clo);
353
+ accu[q][0] +=
354
+ simd32uint16(scaler.scale_lo(res0)); // handle vectors 0..7
355
+ accu[q][1] +=
356
+ simd32uint16(scaler.scale_hi(res0)); // handle vectors 8..15
357
+
358
+ simd32uint8 res1 = scaler.lookup(lut, chi);
359
+ accu[q][2] += simd32uint16(
360
+ scaler.scale_lo(res1)); // handle vectors 16..23
361
+ accu[q][3] += simd32uint16(
362
+ scaler.scale_hi(res1)); // handle vectors 24..31
363
+ }
364
+ }
365
+
366
+ for (int q = 0; q < NQ; q++) {
367
+ for (int b = 0; b < 4; b++) {
368
+ accu[q][b] += accu1[q][b];
369
+ }
370
+ }
371
+
372
+ for (int q = 0; q < NQ; q++) {
373
+ accu[q][0] -= accu[q][1] << 8;
374
+ simd16uint16 dis0 = combine4x2(accu[q][0], accu[q][1]);
375
+ accu[q][2] -= accu[q][3] << 8;
376
+ simd16uint16 dis1 = combine4x2(accu[q][2], accu[q][3]);
377
+ res.handle(q, 0, dis0, dis1);
378
+ }
379
+ }
380
+
381
+ // general-purpose case
382
+ FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
383
+ template <int NQ, class ResultHandler, class Scaler>
384
+ void kernel_accumulate_block_avx512_nqx(
385
+ int nsq,
386
+ const uint8_t* codes,
387
+ const uint8_t* LUT,
388
+ ResultHandler& res,
389
+ const Scaler& scaler) {
390
+ // dummy alloc to keep the windows compiler happy
391
+ constexpr int NQA = NQ > 0 ? NQ : 1;
392
+ // distance accumulators
393
+ // layout: accu[q][b]: distance accumulator for vectors 8*b..8*b+7
394
+ simd32uint16 accu[NQA][4];
395
+
396
+ for (int q = 0; q < NQ; q++) {
397
+ for (int b = 0; b < 4; b++) {
398
+ accu[q][b].clear();
399
+ }
400
+ }
401
+
402
+ // process "nsq - scaler.nscale" part
403
+ const int nsq_minus_nscale = nsq - scaler.nscale;
404
+ const int nsq_minus_nscale_4 = (nsq_minus_nscale / 4) * 4;
405
+
406
+ // process in chunks of 8
407
+ for (int sq = 0; sq < nsq_minus_nscale_4; sq += 4) {
408
+ // prefetch
409
+ simd64uint8 c(codes);
410
+ codes += 64;
411
+
412
+ simd64uint8 mask(0xf);
413
+ // shift op does not exist for int8...
414
+ simd64uint8 chi = simd64uint8(simd32uint16(c) >> 4) & mask;
415
+ simd64uint8 clo = c & mask;
416
+
417
+ for (int q = 0; q < NQ; q++) {
418
+ // load LUTs for 4 quantizers
419
+ simd32uint8 lut_a(LUT);
420
+ simd32uint8 lut_b(LUT + NQ * 32);
421
+
422
+ simd64uint8 lut(lut_a, lut_b);
423
+ LUT += 32;
424
+
425
+ {
426
+ simd64uint8 res0 = lut.lookup_4_lanes(clo);
427
+ simd64uint8 res1 = lut.lookup_4_lanes(chi);
428
+
429
+ accu[q][0] += simd32uint16(res0);
430
+ accu[q][1] += simd32uint16(res0) >> 8;
431
+
432
+ accu[q][2] += simd32uint16(res1);
433
+ accu[q][3] += simd32uint16(res1) >> 8;
434
+ }
435
+ }
436
+
437
+ LUT += NQ * 32;
438
+ }
439
+
440
+ // process leftovers: a single chunk of size 2
441
+ if (nsq_minus_nscale_4 != nsq_minus_nscale) {
442
+ // prefetch
443
+ simd32uint8 c(codes);
444
+ codes += 32;
445
+
446
+ simd32uint8 mask(0xf);
447
+ // shift op does not exist for int8...
448
+ simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask;
449
+ simd32uint8 clo = c & mask;
450
+
451
+ for (int q = 0; q < NQ; q++) {
452
+ // load LUTs for 2 quantizers
453
+ simd32uint8 lut(LUT);
454
+ LUT += 32;
455
+
456
+ simd32uint8 res0 = lut.lookup_2_lanes(clo);
457
+ simd32uint8 res1 = lut.lookup_2_lanes(chi);
458
+
459
+ accu[q][0] += simd32uint16(simd16uint16(res0));
460
+ accu[q][1] += simd32uint16(simd16uint16(res0) >> 8);
461
+
462
+ accu[q][2] += simd32uint16(simd16uint16(res1));
463
+ accu[q][3] += simd32uint16(simd16uint16(res1) >> 8);
464
+ }
465
+ }
466
+
467
+ // process "sq" part
468
+ const int nscale = scaler.nscale;
469
+ const int nscale_4 = (nscale / 4) * 4;
470
+
471
+ // process in chunks of 4
472
+ for (int sq = 0; sq < nscale_4; sq += 4) {
473
+ // prefetch
474
+ simd64uint8 c(codes);
475
+ codes += 64;
476
+
477
+ simd64uint8 mask(0xf);
478
+ // shift op does not exist for int8...
479
+ simd64uint8 chi = simd64uint8(simd32uint16(c) >> 4) & mask;
480
+ simd64uint8 clo = c & mask;
481
+
482
+ for (int q = 0; q < NQ; q++) {
483
+ // load LUTs for 4 quantizers
484
+ simd32uint8 lut_a(LUT);
485
+ simd32uint8 lut_b(LUT + NQ * 32);
486
+
487
+ simd64uint8 lut(lut_a, lut_b);
488
+ LUT += 32;
489
+
490
+ {
491
+ simd64uint8 res0 = scaler.lookup(lut, clo);
492
+ accu[q][0] += scaler.scale_lo(res0); // handle vectors 0..7
493
+ accu[q][1] += scaler.scale_hi(res0); // handle vectors 8..15
494
+
495
+ simd64uint8 res1 = scaler.lookup(lut, chi);
496
+ accu[q][2] += scaler.scale_lo(res1); // handle vectors 16..23
497
+ accu[q][3] += scaler.scale_hi(res1); // handle vectors 24..31
498
+ }
499
+ }
500
+
501
+ LUT += NQ * 32;
502
+ }
503
+
504
+ // process leftovers: a single chunk of size 2
505
+ if (nscale_4 != nscale) {
506
+ // prefetch
507
+ simd32uint8 c(codes);
508
+ codes += 32;
509
+
510
+ simd32uint8 mask(0xf);
511
+ // shift op does not exist for int8...
512
+ simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask;
513
+ simd32uint8 clo = c & mask;
514
+
515
+ for (int q = 0; q < NQ; q++) {
516
+ // load LUTs for 2 quantizers
517
+ simd32uint8 lut(LUT);
518
+ LUT += 32;
519
+
520
+ simd32uint8 res0 = scaler.lookup(lut, clo);
521
+ accu[q][0] +=
522
+ simd32uint16(scaler.scale_lo(res0)); // handle vectors 0..7
523
+ accu[q][1] +=
524
+ simd32uint16(scaler.scale_hi(res0)); // handle vectors 8..15
525
+
526
+ simd32uint8 res1 = scaler.lookup(lut, chi);
527
+ accu[q][2] += simd32uint16(
528
+ scaler.scale_lo(res1)); // handle vectors 16..23
529
+ accu[q][3] += simd32uint16(
530
+ scaler.scale_hi(res1)); // handle vectors 24..31
531
+ }
532
+ }
533
+
534
+ for (int q = 0; q < NQ; q++) {
535
+ accu[q][0] -= accu[q][1] << 8;
536
+ simd16uint16 dis0 = combine4x2(accu[q][0], accu[q][1]);
537
+ accu[q][2] -= accu[q][3] << 8;
538
+ simd16uint16 dis1 = combine4x2(accu[q][2], accu[q][3]);
539
+ res.handle(q, 0, dis0, dis1);
540
+ }
541
+ }
542
+
543
+ template <int NQ, class ResultHandler, class Scaler>
544
+ void kernel_accumulate_block(
545
+ int nsq,
546
+ const uint8_t* codes,
547
+ const uint8_t* LUT,
548
+ ResultHandler& res,
549
+ const Scaler& scaler) {
550
+ if constexpr (NQ == 1) {
551
+ kernel_accumulate_block_avx512_nq1<ResultHandler, Scaler>(
552
+ nsq, codes, LUT, res, scaler);
553
+ } else {
554
+ kernel_accumulate_block_avx512_nqx<NQ, ResultHandler, Scaler>(
555
+ nsq, codes, LUT, res, scaler);
556
+ }
557
+ }
558
+
559
+ #endif
560
+
111
561
  // handle at most 4 blocks of queries
112
562
  template <int QBS, class ResultHandler, class Scaler>
113
563
  void accumulate_q_4step(
@@ -123,7 +573,7 @@ void accumulate_q_4step(
123
573
  constexpr int Q4 = (QBS >> 12) & 15;
124
574
  constexpr int SQ = Q1 + Q2 + Q3 + Q4;
125
575
 
126
- for (int64_t j0 = 0; j0 < ntotal2; j0 += 32) {
576
+ for (size_t j0 = 0; j0 < ntotal2; j0 += 32) {
127
577
  FixedStorageHandler<SQ, 2> res2;
128
578
  const uint8_t* LUT = LUT0;
129
579
  kernel_accumulate_block<Q1>(nsq, codes, LUT, res2, scaler);
@@ -156,7 +606,7 @@ void kernel_accumulate_block_loop(
156
606
  const uint8_t* LUT,
157
607
  ResultHandler& res,
158
608
  const Scaler& scaler) {
159
- for (int64_t j0 = 0; j0 < ntotal2; j0 += 32) {
609
+ for (size_t j0 = 0; j0 < ntotal2; j0 += 32) {
160
610
  res.set_block_origin(0, j0);
161
611
  kernel_accumulate_block<NQ, ResultHandler>(
162
612
  nsq, codes + j0 * nsq / 2, LUT, res, scaler);
@@ -194,10 +644,8 @@ void accumulate(
194
644
  #undef DISPATCH
195
645
  }
196
646
 
197
- } // namespace
198
-
199
647
  template <class ResultHandler, class Scaler>
200
- void pq4_accumulate_loop_qbs(
648
+ void pq4_accumulate_loop_qbs_fixed_scaler(
201
649
  int qbs,
202
650
  size_t ntotal2,
203
651
  int nsq,
@@ -243,7 +691,7 @@ void pq4_accumulate_loop_qbs(
243
691
 
244
692
  // default implementation where qbs is not known at compile time
245
693
 
246
- for (int64_t j0 = 0; j0 < ntotal2; j0 += 32) {
694
+ for (size_t j0 = 0; j0 < ntotal2; j0 += 32) {
247
695
  const uint8_t* LUT = LUT0;
248
696
  int qi = qbs;
249
697
  int i0 = 0;
@@ -272,49 +720,39 @@ void pq4_accumulate_loop_qbs(
272
720
  }
273
721
  }
274
722
 
275
- // explicit template instantiations
276
-
277
- #define INSTANTIATE_ACCUMULATE_Q(RH) \
278
- template void pq4_accumulate_loop_qbs<RH, DummyScaler>( \
279
- int, \
280
- size_t, \
281
- int, \
282
- const uint8_t*, \
283
- const uint8_t*, \
284
- RH&, \
285
- const DummyScaler&); \
286
- template void pq4_accumulate_loop_qbs<RH, NormTableScaler>( \
287
- int, \
288
- size_t, \
289
- int, \
290
- const uint8_t*, \
291
- const uint8_t*, \
292
- RH&, \
293
- const NormTableScaler&);
294
-
295
- using Csi = CMax<uint16_t, int>;
296
- INSTANTIATE_ACCUMULATE_Q(SingleResultHandler<Csi>)
297
- INSTANTIATE_ACCUMULATE_Q(HeapHandler<Csi>)
298
- INSTANTIATE_ACCUMULATE_Q(ReservoirHandler<Csi>)
299
- using Csi2 = CMin<uint16_t, int>;
300
- INSTANTIATE_ACCUMULATE_Q(SingleResultHandler<Csi2>)
301
- INSTANTIATE_ACCUMULATE_Q(HeapHandler<Csi2>)
302
- INSTANTIATE_ACCUMULATE_Q(ReservoirHandler<Csi2>)
303
-
304
- using Cfl = CMax<uint16_t, int64_t>;
305
- using HHCsl = HeapHandler<Cfl, true>;
306
- using RHCsl = ReservoirHandler<Cfl, true>;
307
- using SHCsl = SingleResultHandler<Cfl, true>;
308
- INSTANTIATE_ACCUMULATE_Q(HHCsl)
309
- INSTANTIATE_ACCUMULATE_Q(RHCsl)
310
- INSTANTIATE_ACCUMULATE_Q(SHCsl)
311
- using Cfl2 = CMin<uint16_t, int64_t>;
312
- using HHCsl2 = HeapHandler<Cfl2, true>;
313
- using RHCsl2 = ReservoirHandler<Cfl2, true>;
314
- using SHCsl2 = SingleResultHandler<Cfl2, true>;
315
- INSTANTIATE_ACCUMULATE_Q(HHCsl2)
316
- INSTANTIATE_ACCUMULATE_Q(RHCsl2)
317
- INSTANTIATE_ACCUMULATE_Q(SHCsl2)
723
+ struct Run_pq4_accumulate_loop_qbs {
724
+ template <class ResultHandler>
725
+ void f(ResultHandler& res,
726
+ int qbs,
727
+ size_t nb,
728
+ int nsq,
729
+ const uint8_t* codes,
730
+ const uint8_t* LUT,
731
+ const NormTableScaler* scaler) {
732
+ if (scaler) {
733
+ pq4_accumulate_loop_qbs_fixed_scaler(
734
+ qbs, nb, nsq, codes, LUT, res, *scaler);
735
+ } else {
736
+ DummyScaler dummy;
737
+ pq4_accumulate_loop_qbs_fixed_scaler(
738
+ qbs, nb, nsq, codes, LUT, res, dummy);
739
+ }
740
+ }
741
+ };
742
+
743
+ } // namespace
744
+
745
+ void pq4_accumulate_loop_qbs(
746
+ int qbs,
747
+ size_t nb,
748
+ int nsq,
749
+ const uint8_t* codes,
750
+ const uint8_t* LUT,
751
+ SIMDResultHandler& res,
752
+ const NormTableScaler* scaler) {
753
+ Run_pq4_accumulate_loop_qbs consumer;
754
+ dispatch_SIMDResultHandler(res, consumer, qbs, nb, nsq, codes, LUT, scaler);
755
+ }
318
756
 
319
757
  /***************************************************************
320
758
  * Packing functions