faiss 0.3.0 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (171) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
@@ -65,42 +65,65 @@ using SQDistanceComputer = ScalarQuantizer::SQDistanceComputer;
65
65
  */
66
66
 
67
67
  struct Codec8bit {
68
- static void encode_component(float x, uint8_t* code, int i) {
68
+ static FAISS_ALWAYS_INLINE void encode_component(
69
+ float x,
70
+ uint8_t* code,
71
+ int i) {
69
72
  code[i] = (int)(255 * x);
70
73
  }
71
74
 
72
- static float decode_component(const uint8_t* code, int i) {
75
+ static FAISS_ALWAYS_INLINE float decode_component(
76
+ const uint8_t* code,
77
+ int i) {
73
78
  return (code[i] + 0.5f) / 255.0f;
74
79
  }
75
80
 
76
81
  #ifdef __AVX2__
77
- static __m256 decode_8_components(const uint8_t* code, int i) {
78
- uint64_t c8 = *(uint64_t*)(code + i);
79
- __m128i c4lo = _mm_cvtepu8_epi32(_mm_set1_epi32(c8));
80
- __m128i c4hi = _mm_cvtepu8_epi32(_mm_set1_epi32(c8 >> 32));
81
- // __m256i i8 = _mm256_set_m128i(c4lo, c4hi);
82
- __m256i i8 = _mm256_castsi128_si256(c4lo);
83
- i8 = _mm256_insertf128_si256(i8, c4hi, 1);
84
- __m256 f8 = _mm256_cvtepi32_ps(i8);
85
- __m256 half = _mm256_set1_ps(0.5f);
86
- f8 = _mm256_add_ps(f8, half);
87
- __m256 one_255 = _mm256_set1_ps(1.f / 255.f);
88
- return _mm256_mul_ps(f8, one_255);
82
+ static FAISS_ALWAYS_INLINE __m256
83
+ decode_8_components(const uint8_t* code, int i) {
84
+ const uint64_t c8 = *(uint64_t*)(code + i);
85
+
86
+ const __m128i i8 = _mm_set1_epi64x(c8);
87
+ const __m256i i32 = _mm256_cvtepu8_epi32(i8);
88
+ const __m256 f8 = _mm256_cvtepi32_ps(i32);
89
+ const __m256 half_one_255 = _mm256_set1_ps(0.5f / 255.f);
90
+ const __m256 one_255 = _mm256_set1_ps(1.f / 255.f);
91
+ return _mm256_fmadd_ps(f8, one_255, half_one_255);
92
+ }
93
+ #endif
94
+
95
+ #ifdef __aarch64__
96
+ static FAISS_ALWAYS_INLINE float32x4x2_t
97
+ decode_8_components(const uint8_t* code, int i) {
98
+ float32_t result[8] = {};
99
+ for (size_t j = 0; j < 8; j++) {
100
+ result[j] = decode_component(code, i + j);
101
+ }
102
+ float32x4_t res1 = vld1q_f32(result);
103
+ float32x4_t res2 = vld1q_f32(result + 4);
104
+ float32x4x2_t res = vzipq_f32(res1, res2);
105
+ return vuzpq_f32(res.val[0], res.val[1]);
89
106
  }
90
107
  #endif
91
108
  };
92
109
 
93
110
  struct Codec4bit {
94
- static void encode_component(float x, uint8_t* code, int i) {
111
+ static FAISS_ALWAYS_INLINE void encode_component(
112
+ float x,
113
+ uint8_t* code,
114
+ int i) {
95
115
  code[i / 2] |= (int)(x * 15.0) << ((i & 1) << 2);
96
116
  }
97
117
 
98
- static float decode_component(const uint8_t* code, int i) {
118
+ static FAISS_ALWAYS_INLINE float decode_component(
119
+ const uint8_t* code,
120
+ int i) {
99
121
  return (((code[i / 2] >> ((i & 1) << 2)) & 0xf) + 0.5f) / 15.0f;
100
122
  }
101
123
 
102
124
  #ifdef __AVX2__
103
- static __m256 decode_8_components(const uint8_t* code, int i) {
125
+ static FAISS_ALWAYS_INLINE __m256
126
+ decode_8_components(const uint8_t* code, int i) {
104
127
  uint32_t c4 = *(uint32_t*)(code + (i >> 1));
105
128
  uint32_t mask = 0x0f0f0f0f;
106
129
  uint32_t c4ev = c4 & mask;
@@ -120,10 +143,27 @@ struct Codec4bit {
120
143
  return _mm256_mul_ps(f8, one_255);
121
144
  }
122
145
  #endif
146
+
147
+ #ifdef __aarch64__
148
+ static FAISS_ALWAYS_INLINE float32x4x2_t
149
+ decode_8_components(const uint8_t* code, int i) {
150
+ float32_t result[8] = {};
151
+ for (size_t j = 0; j < 8; j++) {
152
+ result[j] = decode_component(code, i + j);
153
+ }
154
+ float32x4_t res1 = vld1q_f32(result);
155
+ float32x4_t res2 = vld1q_f32(result + 4);
156
+ float32x4x2_t res = vzipq_f32(res1, res2);
157
+ return vuzpq_f32(res.val[0], res.val[1]);
158
+ }
159
+ #endif
123
160
  };
124
161
 
125
162
  struct Codec6bit {
126
- static void encode_component(float x, uint8_t* code, int i) {
163
+ static FAISS_ALWAYS_INLINE void encode_component(
164
+ float x,
165
+ uint8_t* code,
166
+ int i) {
127
167
  int bits = (int)(x * 63.0);
128
168
  code += (i >> 2) * 3;
129
169
  switch (i & 3) {
@@ -144,7 +184,9 @@ struct Codec6bit {
144
184
  }
145
185
  }
146
186
 
147
- static float decode_component(const uint8_t* code, int i) {
187
+ static FAISS_ALWAYS_INLINE float decode_component(
188
+ const uint8_t* code,
189
+ int i) {
148
190
  uint8_t bits;
149
191
  code += (i >> 2) * 3;
150
192
  switch (i & 3) {
@@ -170,7 +212,7 @@ struct Codec6bit {
170
212
 
171
213
  /* Load 6 bytes that represent 8 6-bit values, return them as a
172
214
  * 8*32 bit vector register */
173
- static __m256i load6(const uint16_t* code16) {
215
+ static FAISS_ALWAYS_INLINE __m256i load6(const uint16_t* code16) {
174
216
  const __m128i perm = _mm_set_epi8(
175
217
  -1, 5, 5, 4, 4, 3, -1, 3, -1, 2, 2, 1, 1, 0, -1, 0);
176
218
  const __m256i shifts = _mm256_set_epi32(2, 4, 6, 0, 2, 4, 6, 0);
@@ -189,18 +231,45 @@ struct Codec6bit {
189
231
  return c5;
190
232
  }
191
233
 
192
- static __m256 decode_8_components(const uint8_t* code, int i) {
234
+ static FAISS_ALWAYS_INLINE __m256
235
+ decode_8_components(const uint8_t* code, int i) {
236
+ // // Faster code for Intel CPUs or AMD Zen3+, just keeping it here
237
+ // // for the reference, maybe, it becomes used oned day.
238
+ // const uint16_t* data16 = (const uint16_t*)(code + (i >> 2) * 3);
239
+ // const uint32_t* data32 = (const uint32_t*)data16;
240
+ // const uint64_t val = *data32 + ((uint64_t)data16[2] << 32);
241
+ // const uint64_t vext = _pdep_u64(val, 0x3F3F3F3F3F3F3F3FULL);
242
+ // const __m128i i8 = _mm_set1_epi64x(vext);
243
+ // const __m256i i32 = _mm256_cvtepi8_epi32(i8);
244
+ // const __m256 f8 = _mm256_cvtepi32_ps(i32);
245
+ // const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f);
246
+ // const __m256 one_255 = _mm256_set1_ps(1.f / 63.f);
247
+ // return _mm256_fmadd_ps(f8, one_255, half_one_255);
248
+
193
249
  __m256i i8 = load6((const uint16_t*)(code + (i >> 2) * 3));
194
250
  __m256 f8 = _mm256_cvtepi32_ps(i8);
195
251
  // this could also be done with bit manipulations but it is
196
252
  // not obviously faster
197
- __m256 half = _mm256_set1_ps(0.5f);
198
- f8 = _mm256_add_ps(f8, half);
199
- __m256 one_63 = _mm256_set1_ps(1.f / 63.f);
200
- return _mm256_mul_ps(f8, one_63);
253
+ const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f);
254
+ const __m256 one_255 = _mm256_set1_ps(1.f / 63.f);
255
+ return _mm256_fmadd_ps(f8, one_255, half_one_255);
201
256
  }
202
257
 
203
258
  #endif
259
+
260
+ #ifdef __aarch64__
261
+ static FAISS_ALWAYS_INLINE float32x4x2_t
262
+ decode_8_components(const uint8_t* code, int i) {
263
+ float32_t result[8] = {};
264
+ for (size_t j = 0; j < 8; j++) {
265
+ result[j] = decode_component(code, i + j);
266
+ }
267
+ float32x4_t res1 = vld1q_f32(result);
268
+ float32x4_t res2 = vld1q_f32(result + 4);
269
+ float32x4x2_t res = vzipq_f32(res1, res2);
270
+ return vuzpq_f32(res.val[0], res.val[1]);
271
+ }
272
+ #endif
204
273
  };
205
274
 
206
275
  /*******************************************************************
@@ -242,7 +311,8 @@ struct QuantizerTemplate<Codec, true, 1> : ScalarQuantizer::SQuantizer {
242
311
  }
243
312
  }
244
313
 
245
- float reconstruct_component(const uint8_t* code, int i) const {
314
+ FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
315
+ const {
246
316
  float xi = Codec::decode_component(code, i);
247
317
  return vmin + xi * vdiff;
248
318
  }
@@ -255,11 +325,36 @@ struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
255
325
  QuantizerTemplate(size_t d, const std::vector<float>& trained)
256
326
  : QuantizerTemplate<Codec, true, 1>(d, trained) {}
257
327
 
258
- __m256 reconstruct_8_components(const uint8_t* code, int i) const {
328
+ FAISS_ALWAYS_INLINE __m256
329
+ reconstruct_8_components(const uint8_t* code, int i) const {
259
330
  __m256 xi = Codec::decode_8_components(code, i);
260
- return _mm256_add_ps(
261
- _mm256_set1_ps(this->vmin),
262
- _mm256_mul_ps(xi, _mm256_set1_ps(this->vdiff)));
331
+ return _mm256_fmadd_ps(
332
+ xi, _mm256_set1_ps(this->vdiff), _mm256_set1_ps(this->vmin));
333
+ }
334
+ };
335
+
336
+ #endif
337
+
338
+ #ifdef __aarch64__
339
+
340
+ template <class Codec>
341
+ struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
342
+ QuantizerTemplate(size_t d, const std::vector<float>& trained)
343
+ : QuantizerTemplate<Codec, true, 1>(d, trained) {}
344
+
345
+ FAISS_ALWAYS_INLINE float32x4x2_t
346
+ reconstruct_8_components(const uint8_t* code, int i) const {
347
+ float32x4x2_t xi = Codec::decode_8_components(code, i);
348
+ float32x4x2_t res = vzipq_f32(
349
+ vfmaq_f32(
350
+ vdupq_n_f32(this->vmin),
351
+ xi.val[0],
352
+ vdupq_n_f32(this->vdiff)),
353
+ vfmaq_f32(
354
+ vdupq_n_f32(this->vmin),
355
+ xi.val[1],
356
+ vdupq_n_f32(this->vdiff)));
357
+ return vuzpq_f32(res.val[0], res.val[1]);
263
358
  }
264
359
  };
265
360
 
@@ -296,7 +391,8 @@ struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::SQuantizer {
296
391
  }
297
392
  }
298
393
 
299
- float reconstruct_component(const uint8_t* code, int i) const {
394
+ FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
395
+ const {
300
396
  float xi = Codec::decode_component(code, i);
301
397
  return vmin[i] + xi * vdiff[i];
302
398
  }
@@ -309,11 +405,36 @@ struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
309
405
  QuantizerTemplate(size_t d, const std::vector<float>& trained)
310
406
  : QuantizerTemplate<Codec, false, 1>(d, trained) {}
311
407
 
312
- __m256 reconstruct_8_components(const uint8_t* code, int i) const {
408
+ FAISS_ALWAYS_INLINE __m256
409
+ reconstruct_8_components(const uint8_t* code, int i) const {
313
410
  __m256 xi = Codec::decode_8_components(code, i);
314
- return _mm256_add_ps(
315
- _mm256_loadu_ps(this->vmin + i),
316
- _mm256_mul_ps(xi, _mm256_loadu_ps(this->vdiff + i)));
411
+ return _mm256_fmadd_ps(
412
+ xi,
413
+ _mm256_loadu_ps(this->vdiff + i),
414
+ _mm256_loadu_ps(this->vmin + i));
415
+ }
416
+ };
417
+
418
+ #endif
419
+
420
+ #ifdef __aarch64__
421
+
422
+ template <class Codec>
423
+ struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
424
+ QuantizerTemplate(size_t d, const std::vector<float>& trained)
425
+ : QuantizerTemplate<Codec, false, 1>(d, trained) {}
426
+
427
+ FAISS_ALWAYS_INLINE float32x4x2_t
428
+ reconstruct_8_components(const uint8_t* code, int i) const {
429
+ float32x4x2_t xi = Codec::decode_8_components(code, i);
430
+
431
+ float32x4x2_t vmin_8 = vld1q_f32_x2(this->vmin + i);
432
+ float32x4x2_t vdiff_8 = vld1q_f32_x2(this->vdiff + i);
433
+
434
+ float32x4x2_t res = vzipq_f32(
435
+ vfmaq_f32(vmin_8.val[0], xi.val[0], vdiff_8.val[0]),
436
+ vfmaq_f32(vmin_8.val[1], xi.val[1], vdiff_8.val[1]));
437
+ return vuzpq_f32(res.val[0], res.val[1]);
317
438
  }
318
439
  };
319
440
 
@@ -344,7 +465,8 @@ struct QuantizerFP16<1> : ScalarQuantizer::SQuantizer {
344
465
  }
345
466
  }
346
467
 
347
- float reconstruct_component(const uint8_t* code, int i) const {
468
+ FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
469
+ const {
348
470
  return decode_fp16(((uint16_t*)code)[i]);
349
471
  }
350
472
  };
@@ -356,7 +478,8 @@ struct QuantizerFP16<8> : QuantizerFP16<1> {
356
478
  QuantizerFP16(size_t d, const std::vector<float>& trained)
357
479
  : QuantizerFP16<1>(d, trained) {}
358
480
 
359
- __m256 reconstruct_8_components(const uint8_t* code, int i) const {
481
+ FAISS_ALWAYS_INLINE __m256
482
+ reconstruct_8_components(const uint8_t* code, int i) const {
360
483
  __m128i codei = _mm_loadu_si128((const __m128i*)(code + 2 * i));
361
484
  return _mm256_cvtph_ps(codei);
362
485
  }
@@ -364,6 +487,23 @@ struct QuantizerFP16<8> : QuantizerFP16<1> {
364
487
 
365
488
  #endif
366
489
 
490
+ #ifdef __aarch64__
491
+
492
+ template <>
493
+ struct QuantizerFP16<8> : QuantizerFP16<1> {
494
+ QuantizerFP16(size_t d, const std::vector<float>& trained)
495
+ : QuantizerFP16<1>(d, trained) {}
496
+
497
+ FAISS_ALWAYS_INLINE float32x4x2_t
498
+ reconstruct_8_components(const uint8_t* code, int i) const {
499
+ uint16x4x2_t codei = vld2_u16((const uint16_t*)(code + 2 * i));
500
+ return vzipq_f32(
501
+ vcvt_f32_f16(vreinterpret_f16_u16(codei.val[0])),
502
+ vcvt_f32_f16(vreinterpret_f16_u16(codei.val[1])));
503
+ }
504
+ };
505
+ #endif
506
+
367
507
  /*******************************************************************
368
508
  * 8bit_direct quantizer
369
509
  *******************************************************************/
@@ -390,7 +530,8 @@ struct Quantizer8bitDirect<1> : ScalarQuantizer::SQuantizer {
390
530
  }
391
531
  }
392
532
 
393
- float reconstruct_component(const uint8_t* code, int i) const {
533
+ FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
534
+ const {
394
535
  return code[i];
395
536
  }
396
537
  };
@@ -402,7 +543,8 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
402
543
  Quantizer8bitDirect(size_t d, const std::vector<float>& trained)
403
544
  : Quantizer8bitDirect<1>(d, trained) {}
404
545
 
405
- __m256 reconstruct_8_components(const uint8_t* code, int i) const {
546
+ FAISS_ALWAYS_INLINE __m256
547
+ reconstruct_8_components(const uint8_t* code, int i) const {
406
548
  __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8
407
549
  __m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * int32
408
550
  return _mm256_cvtepi32_ps(y8); // 8 * float32
@@ -411,6 +553,28 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
411
553
 
412
554
  #endif
413
555
 
556
+ #ifdef __aarch64__
557
+
558
+ template <>
559
+ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
560
+ Quantizer8bitDirect(size_t d, const std::vector<float>& trained)
561
+ : Quantizer8bitDirect<1>(d, trained) {}
562
+
563
+ FAISS_ALWAYS_INLINE float32x4x2_t
564
+ reconstruct_8_components(const uint8_t* code, int i) const {
565
+ float32_t result[8] = {};
566
+ for (size_t j = 0; j < 8; j++) {
567
+ result[j] = code[i + j];
568
+ }
569
+ float32x4_t res1 = vld1q_f32(result);
570
+ float32x4_t res2 = vld1q_f32(result + 4);
571
+ float32x4x2_t res = vzipq_f32(res1, res2);
572
+ return vuzpq_f32(res.val[0], res.val[1]);
573
+ }
574
+ };
575
+
576
+ #endif
577
+
414
578
  template <int SIMDWIDTH>
415
579
  ScalarQuantizer::SQuantizer* select_quantizer_1(
416
580
  QuantizerType qtype,
@@ -486,7 +650,7 @@ void train_Uniform(
486
650
  } else if (rs == ScalarQuantizer::RS_quantiles) {
487
651
  std::vector<float> x_copy(n);
488
652
  memcpy(x_copy.data(), x, n * sizeof(*x));
489
- // TODO just do a qucikselect
653
+ // TODO just do a quickselect
490
654
  std::sort(x_copy.begin(), x_copy.end());
491
655
  int o = int(rs_arg * n);
492
656
  if (o < 0)
@@ -632,22 +796,22 @@ struct SimilarityL2<1> {
632
796
 
633
797
  float accu;
634
798
 
635
- void begin() {
799
+ FAISS_ALWAYS_INLINE void begin() {
636
800
  accu = 0;
637
801
  yi = y;
638
802
  }
639
803
 
640
- void add_component(float x) {
804
+ FAISS_ALWAYS_INLINE void add_component(float x) {
641
805
  float tmp = *yi++ - x;
642
806
  accu += tmp * tmp;
643
807
  }
644
808
 
645
- void add_component_2(float x1, float x2) {
809
+ FAISS_ALWAYS_INLINE void add_component_2(float x1, float x2) {
646
810
  float tmp = x1 - x2;
647
811
  accu += tmp * tmp;
648
812
  }
649
813
 
650
- float result() {
814
+ FAISS_ALWAYS_INLINE float result() {
651
815
  return accu;
652
816
  }
653
817
  };
@@ -663,34 +827,89 @@ struct SimilarityL2<8> {
663
827
  explicit SimilarityL2(const float* y) : y(y) {}
664
828
  __m256 accu8;
665
829
 
666
- void begin_8() {
830
+ FAISS_ALWAYS_INLINE void begin_8() {
667
831
  accu8 = _mm256_setzero_ps();
668
832
  yi = y;
669
833
  }
670
834
 
671
- void add_8_components(__m256 x) {
835
+ FAISS_ALWAYS_INLINE void add_8_components(__m256 x) {
672
836
  __m256 yiv = _mm256_loadu_ps(yi);
673
837
  yi += 8;
674
838
  __m256 tmp = _mm256_sub_ps(yiv, x);
675
- accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(tmp, tmp));
839
+ accu8 = _mm256_fmadd_ps(tmp, tmp, accu8);
676
840
  }
677
841
 
678
- void add_8_components_2(__m256 x, __m256 y) {
679
- __m256 tmp = _mm256_sub_ps(y, x);
680
- accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(tmp, tmp));
842
+ FAISS_ALWAYS_INLINE void add_8_components_2(__m256 x, __m256 y_2) {
843
+ __m256 tmp = _mm256_sub_ps(y_2, x);
844
+ accu8 = _mm256_fmadd_ps(tmp, tmp, accu8);
681
845
  }
682
846
 
683
- float result_8() {
684
- __m256 sum = _mm256_hadd_ps(accu8, accu8);
685
- __m256 sum2 = _mm256_hadd_ps(sum, sum);
686
- // now add the 0th and 4th component
687
- return _mm_cvtss_f32(_mm256_castps256_ps128(sum2)) +
688
- _mm_cvtss_f32(_mm256_extractf128_ps(sum2, 1));
847
+ FAISS_ALWAYS_INLINE float result_8() {
848
+ const __m128 sum = _mm_add_ps(
849
+ _mm256_castps256_ps128(accu8), _mm256_extractf128_ps(accu8, 1));
850
+ const __m128 v0 = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2));
851
+ const __m128 v1 = _mm_add_ps(sum, v0);
852
+ __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1));
853
+ const __m128 v3 = _mm_add_ps(v1, v2);
854
+ return _mm_cvtss_f32(v3);
689
855
  }
690
856
  };
691
857
 
692
858
  #endif
693
859
 
860
+ #ifdef __aarch64__
861
+ template <>
862
+ struct SimilarityL2<8> {
863
+ static constexpr int simdwidth = 8;
864
+ static constexpr MetricType metric_type = METRIC_L2;
865
+
866
+ const float *y, *yi;
867
+ explicit SimilarityL2(const float* y) : y(y) {}
868
+ float32x4x2_t accu8;
869
+
870
+ FAISS_ALWAYS_INLINE void begin_8() {
871
+ accu8 = vzipq_f32(vdupq_n_f32(0.0f), vdupq_n_f32(0.0f));
872
+ yi = y;
873
+ }
874
+
875
+ FAISS_ALWAYS_INLINE void add_8_components(float32x4x2_t x) {
876
+ float32x4x2_t yiv = vld1q_f32_x2(yi);
877
+ yi += 8;
878
+
879
+ float32x4_t sub0 = vsubq_f32(yiv.val[0], x.val[0]);
880
+ float32x4_t sub1 = vsubq_f32(yiv.val[1], x.val[1]);
881
+
882
+ float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0);
883
+ float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1);
884
+
885
+ float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
886
+ accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
887
+ }
888
+
889
+ FAISS_ALWAYS_INLINE void add_8_components_2(
890
+ float32x4x2_t x,
891
+ float32x4x2_t y) {
892
+ float32x4_t sub0 = vsubq_f32(y.val[0], x.val[0]);
893
+ float32x4_t sub1 = vsubq_f32(y.val[1], x.val[1]);
894
+
895
+ float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0);
896
+ float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1);
897
+
898
+ float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
899
+ accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
900
+ }
901
+
902
+ FAISS_ALWAYS_INLINE float result_8() {
903
+ float32x4_t sum_0 = vpaddq_f32(accu8.val[0], accu8.val[0]);
904
+ float32x4_t sum_1 = vpaddq_f32(accu8.val[1], accu8.val[1]);
905
+
906
+ float32x4_t sum2_0 = vpaddq_f32(sum_0, sum_0);
907
+ float32x4_t sum2_1 = vpaddq_f32(sum_1, sum_1);
908
+ return vgetq_lane_f32(sum2_0, 0) + vgetq_lane_f32(sum2_1, 0);
909
+ }
910
+ };
911
+ #endif
912
+
694
913
  template <int SIMDWIDTH>
695
914
  struct SimilarityIP {};
696
915
 
@@ -704,20 +923,20 @@ struct SimilarityIP<1> {
704
923
 
705
924
  explicit SimilarityIP(const float* y) : y(y) {}
706
925
 
707
- void begin() {
926
+ FAISS_ALWAYS_INLINE void begin() {
708
927
  accu = 0;
709
928
  yi = y;
710
929
  }
711
930
 
712
- void add_component(float x) {
931
+ FAISS_ALWAYS_INLINE void add_component(float x) {
713
932
  accu += *yi++ * x;
714
933
  }
715
934
 
716
- void add_component_2(float x1, float x2) {
935
+ FAISS_ALWAYS_INLINE void add_component_2(float x1, float x2) {
717
936
  accu += x1 * x2;
718
937
  }
719
938
 
720
- float result() {
939
+ FAISS_ALWAYS_INLINE float result() {
721
940
  return accu;
722
941
  }
723
942
  };
@@ -737,27 +956,79 @@ struct SimilarityIP<8> {
737
956
 
738
957
  __m256 accu8;
739
958
 
740
- void begin_8() {
959
+ FAISS_ALWAYS_INLINE void begin_8() {
741
960
  accu8 = _mm256_setzero_ps();
742
961
  yi = y;
743
962
  }
744
963
 
745
- void add_8_components(__m256 x) {
964
+ FAISS_ALWAYS_INLINE void add_8_components(__m256 x) {
746
965
  __m256 yiv = _mm256_loadu_ps(yi);
747
966
  yi += 8;
748
- accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(yiv, x));
967
+ accu8 = _mm256_fmadd_ps(yiv, x, accu8);
968
+ }
969
+
970
+ FAISS_ALWAYS_INLINE void add_8_components_2(__m256 x1, __m256 x2) {
971
+ accu8 = _mm256_fmadd_ps(x1, x2, accu8);
972
+ }
973
+
974
+ FAISS_ALWAYS_INLINE float result_8() {
975
+ const __m128 sum = _mm_add_ps(
976
+ _mm256_castps256_ps128(accu8), _mm256_extractf128_ps(accu8, 1));
977
+ const __m128 v0 = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2));
978
+ const __m128 v1 = _mm_add_ps(sum, v0);
979
+ __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1));
980
+ const __m128 v3 = _mm_add_ps(v1, v2);
981
+ return _mm_cvtss_f32(v3);
982
+ }
983
+ };
984
+ #endif
985
+
986
+ #ifdef __aarch64__
987
+
988
+ template <>
989
+ struct SimilarityIP<8> {
990
+ static constexpr int simdwidth = 8;
991
+ static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
992
+
993
+ const float *y, *yi;
994
+
995
+ explicit SimilarityIP(const float* y) : y(y) {}
996
+ float32x4x2_t accu8;
997
+
998
+ FAISS_ALWAYS_INLINE void begin_8() {
999
+ accu8 = vzipq_f32(vdupq_n_f32(0.0f), vdupq_n_f32(0.0f));
1000
+ yi = y;
1001
+ }
1002
+
1003
+ FAISS_ALWAYS_INLINE void add_8_components(float32x4x2_t x) {
1004
+ float32x4x2_t yiv = vld1q_f32_x2(yi);
1005
+ yi += 8;
1006
+
1007
+ float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], yiv.val[0], x.val[0]);
1008
+ float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], yiv.val[1], x.val[1]);
1009
+ float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
1010
+ accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
749
1011
  }
750
1012
 
751
- void add_8_components_2(__m256 x1, __m256 x2) {
752
- accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(x1, x2));
1013
+ FAISS_ALWAYS_INLINE void add_8_components_2(
1014
+ float32x4x2_t x1,
1015
+ float32x4x2_t x2) {
1016
+ float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], x1.val[0], x2.val[0]);
1017
+ float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], x1.val[1], x2.val[1]);
1018
+ float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
1019
+ accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
753
1020
  }
754
1021
 
755
- float result_8() {
756
- __m256 sum = _mm256_hadd_ps(accu8, accu8);
757
- __m256 sum2 = _mm256_hadd_ps(sum, sum);
758
- // now add the 0th and 4th component
759
- return _mm_cvtss_f32(_mm256_castps256_ps128(sum2)) +
760
- _mm_cvtss_f32(_mm256_extractf128_ps(sum2, 1));
1022
+ FAISS_ALWAYS_INLINE float result_8() {
1023
+ float32x4x2_t sum_tmp = vzipq_f32(
1024
+ vpaddq_f32(accu8.val[0], accu8.val[0]),
1025
+ vpaddq_f32(accu8.val[1], accu8.val[1]));
1026
+ float32x4x2_t sum = vuzpq_f32(sum_tmp.val[0], sum_tmp.val[1]);
1027
+ float32x4x2_t sum2_tmp = vzipq_f32(
1028
+ vpaddq_f32(sum.val[0], sum.val[0]),
1029
+ vpaddq_f32(sum.val[1], sum.val[1]));
1030
+ float32x4x2_t sum2 = vuzpq_f32(sum2_tmp.val[0], sum2_tmp.val[1]);
1031
+ return vgetq_lane_f32(sum2.val[0], 0) + vgetq_lane_f32(sum2.val[1], 0);
761
1032
  }
762
1033
  };
763
1034
  #endif
@@ -864,6 +1135,53 @@ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
864
1135
 
865
1136
  #endif
866
1137
 
1138
+ #ifdef __aarch64__
1139
+
1140
+ template <class Quantizer, class Similarity>
1141
+ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
1142
+ using Sim = Similarity;
1143
+
1144
+ Quantizer quant;
1145
+
1146
+ DCTemplate(size_t d, const std::vector<float>& trained)
1147
+ : quant(d, trained) {}
1148
+ float compute_distance(const float* x, const uint8_t* code) const {
1149
+ Similarity sim(x);
1150
+ sim.begin_8();
1151
+ for (size_t i = 0; i < quant.d; i += 8) {
1152
+ float32x4x2_t xi = quant.reconstruct_8_components(code, i);
1153
+ sim.add_8_components(xi);
1154
+ }
1155
+ return sim.result_8();
1156
+ }
1157
+
1158
+ float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
1159
+ const {
1160
+ Similarity sim(nullptr);
1161
+ sim.begin_8();
1162
+ for (size_t i = 0; i < quant.d; i += 8) {
1163
+ float32x4x2_t x1 = quant.reconstruct_8_components(code1, i);
1164
+ float32x4x2_t x2 = quant.reconstruct_8_components(code2, i);
1165
+ sim.add_8_components_2(x1, x2);
1166
+ }
1167
+ return sim.result_8();
1168
+ }
1169
+
1170
+ void set_query(const float* x) final {
1171
+ q = x;
1172
+ }
1173
+
1174
+ float symmetric_dis(idx_t i, idx_t j) override {
1175
+ return compute_code_distance(
1176
+ codes + i * code_size, codes + j * code_size);
1177
+ }
1178
+
1179
+ float query_to_code(const uint8_t* code) const final {
1180
+ return compute_distance(q, code);
1181
+ }
1182
+ };
1183
+ #endif
1184
+
867
1185
  /*******************************************************************
868
1186
  * DistanceComputerByte: computes distances in the integer domain
869
1187
  *******************************************************************/
@@ -980,6 +1298,54 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
980
1298
 
981
1299
  #endif
982
1300
 
1301
+ #ifdef __aarch64__
1302
+
1303
+ template <class Similarity>
1304
+ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
1305
+ using Sim = Similarity;
1306
+
1307
+ int d;
1308
+ std::vector<uint8_t> tmp;
1309
+
1310
+ DistanceComputerByte(int d, const std::vector<float>&) : d(d), tmp(d) {}
1311
+
1312
+ int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
1313
+ const {
1314
+ int accu = 0;
1315
+ for (int i = 0; i < d; i++) {
1316
+ if (Sim::metric_type == METRIC_INNER_PRODUCT) {
1317
+ accu += int(code1[i]) * code2[i];
1318
+ } else {
1319
+ int diff = int(code1[i]) - code2[i];
1320
+ accu += diff * diff;
1321
+ }
1322
+ }
1323
+ return accu;
1324
+ }
1325
+
1326
+ void set_query(const float* x) final {
1327
+ for (int i = 0; i < d; i++) {
1328
+ tmp[i] = int(x[i]);
1329
+ }
1330
+ }
1331
+
1332
+ int compute_distance(const float* x, const uint8_t* code) {
1333
+ set_query(x);
1334
+ return compute_code_distance(tmp.data(), code);
1335
+ }
1336
+
1337
+ float symmetric_dis(idx_t i, idx_t j) override {
1338
+ return compute_code_distance(
1339
+ codes + i * code_size, codes + j * code_size);
1340
+ }
1341
+
1342
+ float query_to_code(const uint8_t* code) const final {
1343
+ return compute_code_distance(tmp.data(), code);
1344
+ }
1345
+ };
1346
+
1347
+ #endif
1348
+
983
1349
  /*******************************************************************
984
1350
  * select_distance_computer: runtime selection of template
985
1351
  * specialization
@@ -1115,34 +1481,8 @@ void ScalarQuantizer::train(size_t n, const float* x) {
1115
1481
  }
1116
1482
  }
1117
1483
 
1118
- void ScalarQuantizer::train_residual(
1119
- size_t n,
1120
- const float* x,
1121
- Index* quantizer,
1122
- bool by_residual,
1123
- bool verbose) {
1124
- const float* x_in = x;
1125
-
1126
- // 100k points more than enough
1127
- x = fvecs_maybe_subsample(d, (size_t*)&n, 100000, x, verbose, 1234);
1128
-
1129
- ScopeDeleter<float> del_x(x_in == x ? nullptr : x);
1130
-
1131
- if (by_residual) {
1132
- std::vector<idx_t> idx(n);
1133
- quantizer->assign(n, x, idx.data());
1134
-
1135
- std::vector<float> residuals(n * d);
1136
- quantizer->compute_residual_n(n, x, residuals.data(), idx.data());
1137
-
1138
- train(n, residuals.data());
1139
- } else {
1140
- train(n, x);
1141
- }
1142
- }
1143
-
1144
1484
  ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer() const {
1145
- #ifdef USE_F16C
1485
+ #if defined(USE_F16C) || defined(__aarch64__)
1146
1486
  if (d % 8 == 0) {
1147
1487
  return select_quantizer_1<8>(qtype, d, trained);
1148
1488
  } else
@@ -1173,7 +1513,7 @@ void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
1173
1513
  SQDistanceComputer* ScalarQuantizer::get_distance_computer(
1174
1514
  MetricType metric) const {
1175
1515
  FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
1176
- #ifdef USE_F16C
1516
+ #if defined(USE_F16C) || defined(__aarch64__)
1177
1517
  if (d % 8 == 0) {
1178
1518
  if (metric == METRIC_L2) {
1179
1519
  return select_distance_computer<SimilarityL2<8>>(qtype, d, trained);
@@ -1204,7 +1544,6 @@ template <class DCClass, int use_sel>
1204
1544
  struct IVFSQScannerIP : InvertedListScanner {
1205
1545
  DCClass dc;
1206
1546
  bool by_residual;
1207
- const IDSelector* sel;
1208
1547
 
1209
1548
  float accu0; /// added to all distances
1210
1549
 
@@ -1215,9 +1554,11 @@ struct IVFSQScannerIP : InvertedListScanner {
1215
1554
  bool store_pairs,
1216
1555
  const IDSelector* sel,
1217
1556
  bool by_residual)
1218
- : dc(d, trained), by_residual(by_residual), sel(sel), accu0(0) {
1557
+ : dc(d, trained), by_residual(by_residual), accu0(0) {
1219
1558
  this->store_pairs = store_pairs;
1559
+ this->sel = sel;
1220
1560
  this->code_size = code_size;
1561
+ this->keep_max = true;
1221
1562
  }
1222
1563
 
1223
1564
  void set_query(const float* query) override {
@@ -1288,7 +1629,6 @@ struct IVFSQScannerL2 : InvertedListScanner {
1288
1629
 
1289
1630
  bool by_residual;
1290
1631
  const Index* quantizer;
1291
- const IDSelector* sel;
1292
1632
  const float* x; /// current query
1293
1633
 
1294
1634
  std::vector<float> tmp;
@@ -1304,10 +1644,10 @@ struct IVFSQScannerL2 : InvertedListScanner {
1304
1644
  : dc(d, trained),
1305
1645
  by_residual(by_residual),
1306
1646
  quantizer(quantizer),
1307
- sel(sel),
1308
1647
  x(nullptr),
1309
1648
  tmp(d) {
1310
1649
  this->store_pairs = store_pairs;
1650
+ this->sel = sel;
1311
1651
  this->code_size = code_size;
1312
1652
  }
1313
1653
 
@@ -1509,7 +1849,7 @@ InvertedListScanner* ScalarQuantizer::select_InvertedListScanner(
1509
1849
  bool store_pairs,
1510
1850
  const IDSelector* sel,
1511
1851
  bool by_residual) const {
1512
- #ifdef USE_F16C
1852
+ #if defined(USE_F16C) || defined(__aarch64__)
1513
1853
  if (d % 8 == 0) {
1514
1854
  return sel0_InvertedListScanner<8>(
1515
1855
  mt, this, quantizer, store_pairs, sel, by_residual);