faiss 0.3.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
@@ -65,42 +65,65 @@ using SQDistanceComputer = ScalarQuantizer::SQDistanceComputer;
65
65
  */
66
66
 
67
67
  struct Codec8bit {
68
- static void encode_component(float x, uint8_t* code, int i) {
68
+ static FAISS_ALWAYS_INLINE void encode_component(
69
+ float x,
70
+ uint8_t* code,
71
+ int i) {
69
72
  code[i] = (int)(255 * x);
70
73
  }
71
74
 
72
- static float decode_component(const uint8_t* code, int i) {
75
+ static FAISS_ALWAYS_INLINE float decode_component(
76
+ const uint8_t* code,
77
+ int i) {
73
78
  return (code[i] + 0.5f) / 255.0f;
74
79
  }
75
80
 
76
81
  #ifdef __AVX2__
77
- static __m256 decode_8_components(const uint8_t* code, int i) {
78
- uint64_t c8 = *(uint64_t*)(code + i);
79
- __m128i c4lo = _mm_cvtepu8_epi32(_mm_set1_epi32(c8));
80
- __m128i c4hi = _mm_cvtepu8_epi32(_mm_set1_epi32(c8 >> 32));
81
- // __m256i i8 = _mm256_set_m128i(c4lo, c4hi);
82
- __m256i i8 = _mm256_castsi128_si256(c4lo);
83
- i8 = _mm256_insertf128_si256(i8, c4hi, 1);
84
- __m256 f8 = _mm256_cvtepi32_ps(i8);
85
- __m256 half = _mm256_set1_ps(0.5f);
86
- f8 = _mm256_add_ps(f8, half);
87
- __m256 one_255 = _mm256_set1_ps(1.f / 255.f);
88
- return _mm256_mul_ps(f8, one_255);
82
+ static FAISS_ALWAYS_INLINE __m256
83
+ decode_8_components(const uint8_t* code, int i) {
84
+ const uint64_t c8 = *(uint64_t*)(code + i);
85
+
86
+ const __m128i i8 = _mm_set1_epi64x(c8);
87
+ const __m256i i32 = _mm256_cvtepu8_epi32(i8);
88
+ const __m256 f8 = _mm256_cvtepi32_ps(i32);
89
+ const __m256 half_one_255 = _mm256_set1_ps(0.5f / 255.f);
90
+ const __m256 one_255 = _mm256_set1_ps(1.f / 255.f);
91
+ return _mm256_fmadd_ps(f8, one_255, half_one_255);
92
+ }
93
+ #endif
94
+
95
+ #ifdef __aarch64__
96
+ static FAISS_ALWAYS_INLINE float32x4x2_t
97
+ decode_8_components(const uint8_t* code, int i) {
98
+ float32_t result[8] = {};
99
+ for (size_t j = 0; j < 8; j++) {
100
+ result[j] = decode_component(code, i + j);
101
+ }
102
+ float32x4_t res1 = vld1q_f32(result);
103
+ float32x4_t res2 = vld1q_f32(result + 4);
104
+ float32x4x2_t res = vzipq_f32(res1, res2);
105
+ return vuzpq_f32(res.val[0], res.val[1]);
89
106
  }
90
107
  #endif
91
108
  };
92
109
 
93
110
  struct Codec4bit {
94
- static void encode_component(float x, uint8_t* code, int i) {
111
+ static FAISS_ALWAYS_INLINE void encode_component(
112
+ float x,
113
+ uint8_t* code,
114
+ int i) {
95
115
  code[i / 2] |= (int)(x * 15.0) << ((i & 1) << 2);
96
116
  }
97
117
 
98
- static float decode_component(const uint8_t* code, int i) {
118
+ static FAISS_ALWAYS_INLINE float decode_component(
119
+ const uint8_t* code,
120
+ int i) {
99
121
  return (((code[i / 2] >> ((i & 1) << 2)) & 0xf) + 0.5f) / 15.0f;
100
122
  }
101
123
 
102
124
  #ifdef __AVX2__
103
- static __m256 decode_8_components(const uint8_t* code, int i) {
125
+ static FAISS_ALWAYS_INLINE __m256
126
+ decode_8_components(const uint8_t* code, int i) {
104
127
  uint32_t c4 = *(uint32_t*)(code + (i >> 1));
105
128
  uint32_t mask = 0x0f0f0f0f;
106
129
  uint32_t c4ev = c4 & mask;
@@ -120,10 +143,27 @@ struct Codec4bit {
120
143
  return _mm256_mul_ps(f8, one_255);
121
144
  }
122
145
  #endif
146
+
147
+ #ifdef __aarch64__
148
+ static FAISS_ALWAYS_INLINE float32x4x2_t
149
+ decode_8_components(const uint8_t* code, int i) {
150
+ float32_t result[8] = {};
151
+ for (size_t j = 0; j < 8; j++) {
152
+ result[j] = decode_component(code, i + j);
153
+ }
154
+ float32x4_t res1 = vld1q_f32(result);
155
+ float32x4_t res2 = vld1q_f32(result + 4);
156
+ float32x4x2_t res = vzipq_f32(res1, res2);
157
+ return vuzpq_f32(res.val[0], res.val[1]);
158
+ }
159
+ #endif
123
160
  };
124
161
 
125
162
  struct Codec6bit {
126
- static void encode_component(float x, uint8_t* code, int i) {
163
+ static FAISS_ALWAYS_INLINE void encode_component(
164
+ float x,
165
+ uint8_t* code,
166
+ int i) {
127
167
  int bits = (int)(x * 63.0);
128
168
  code += (i >> 2) * 3;
129
169
  switch (i & 3) {
@@ -144,7 +184,9 @@ struct Codec6bit {
144
184
  }
145
185
  }
146
186
 
147
- static float decode_component(const uint8_t* code, int i) {
187
+ static FAISS_ALWAYS_INLINE float decode_component(
188
+ const uint8_t* code,
189
+ int i) {
148
190
  uint8_t bits;
149
191
  code += (i >> 2) * 3;
150
192
  switch (i & 3) {
@@ -170,7 +212,7 @@ struct Codec6bit {
170
212
 
171
213
  /* Load 6 bytes that represent 8 6-bit values, return them as a
172
214
  * 8*32 bit vector register */
173
- static __m256i load6(const uint16_t* code16) {
215
+ static FAISS_ALWAYS_INLINE __m256i load6(const uint16_t* code16) {
174
216
  const __m128i perm = _mm_set_epi8(
175
217
  -1, 5, 5, 4, 4, 3, -1, 3, -1, 2, 2, 1, 1, 0, -1, 0);
176
218
  const __m256i shifts = _mm256_set_epi32(2, 4, 6, 0, 2, 4, 6, 0);
@@ -189,18 +231,45 @@ struct Codec6bit {
189
231
  return c5;
190
232
  }
191
233
 
192
- static __m256 decode_8_components(const uint8_t* code, int i) {
234
+ static FAISS_ALWAYS_INLINE __m256
235
+ decode_8_components(const uint8_t* code, int i) {
236
+ // // Faster code for Intel CPUs or AMD Zen3+, just keeping it here
237
+ // // for the reference, maybe, it becomes used oned day.
238
+ // const uint16_t* data16 = (const uint16_t*)(code + (i >> 2) * 3);
239
+ // const uint32_t* data32 = (const uint32_t*)data16;
240
+ // const uint64_t val = *data32 + ((uint64_t)data16[2] << 32);
241
+ // const uint64_t vext = _pdep_u64(val, 0x3F3F3F3F3F3F3F3FULL);
242
+ // const __m128i i8 = _mm_set1_epi64x(vext);
243
+ // const __m256i i32 = _mm256_cvtepi8_epi32(i8);
244
+ // const __m256 f8 = _mm256_cvtepi32_ps(i32);
245
+ // const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f);
246
+ // const __m256 one_255 = _mm256_set1_ps(1.f / 63.f);
247
+ // return _mm256_fmadd_ps(f8, one_255, half_one_255);
248
+
193
249
  __m256i i8 = load6((const uint16_t*)(code + (i >> 2) * 3));
194
250
  __m256 f8 = _mm256_cvtepi32_ps(i8);
195
251
  // this could also be done with bit manipulations but it is
196
252
  // not obviously faster
197
- __m256 half = _mm256_set1_ps(0.5f);
198
- f8 = _mm256_add_ps(f8, half);
199
- __m256 one_63 = _mm256_set1_ps(1.f / 63.f);
200
- return _mm256_mul_ps(f8, one_63);
253
+ const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f);
254
+ const __m256 one_255 = _mm256_set1_ps(1.f / 63.f);
255
+ return _mm256_fmadd_ps(f8, one_255, half_one_255);
201
256
  }
202
257
 
203
258
  #endif
259
+
260
+ #ifdef __aarch64__
261
+ static FAISS_ALWAYS_INLINE float32x4x2_t
262
+ decode_8_components(const uint8_t* code, int i) {
263
+ float32_t result[8] = {};
264
+ for (size_t j = 0; j < 8; j++) {
265
+ result[j] = decode_component(code, i + j);
266
+ }
267
+ float32x4_t res1 = vld1q_f32(result);
268
+ float32x4_t res2 = vld1q_f32(result + 4);
269
+ float32x4x2_t res = vzipq_f32(res1, res2);
270
+ return vuzpq_f32(res.val[0], res.val[1]);
271
+ }
272
+ #endif
204
273
  };
205
274
 
206
275
  /*******************************************************************
@@ -242,7 +311,8 @@ struct QuantizerTemplate<Codec, true, 1> : ScalarQuantizer::SQuantizer {
242
311
  }
243
312
  }
244
313
 
245
- float reconstruct_component(const uint8_t* code, int i) const {
314
+ FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
315
+ const {
246
316
  float xi = Codec::decode_component(code, i);
247
317
  return vmin + xi * vdiff;
248
318
  }
@@ -255,11 +325,36 @@ struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
255
325
  QuantizerTemplate(size_t d, const std::vector<float>& trained)
256
326
  : QuantizerTemplate<Codec, true, 1>(d, trained) {}
257
327
 
258
- __m256 reconstruct_8_components(const uint8_t* code, int i) const {
328
+ FAISS_ALWAYS_INLINE __m256
329
+ reconstruct_8_components(const uint8_t* code, int i) const {
259
330
  __m256 xi = Codec::decode_8_components(code, i);
260
- return _mm256_add_ps(
261
- _mm256_set1_ps(this->vmin),
262
- _mm256_mul_ps(xi, _mm256_set1_ps(this->vdiff)));
331
+ return _mm256_fmadd_ps(
332
+ xi, _mm256_set1_ps(this->vdiff), _mm256_set1_ps(this->vmin));
333
+ }
334
+ };
335
+
336
+ #endif
337
+
338
+ #ifdef __aarch64__
339
+
340
+ template <class Codec>
341
+ struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
342
+ QuantizerTemplate(size_t d, const std::vector<float>& trained)
343
+ : QuantizerTemplate<Codec, true, 1>(d, trained) {}
344
+
345
+ FAISS_ALWAYS_INLINE float32x4x2_t
346
+ reconstruct_8_components(const uint8_t* code, int i) const {
347
+ float32x4x2_t xi = Codec::decode_8_components(code, i);
348
+ float32x4x2_t res = vzipq_f32(
349
+ vfmaq_f32(
350
+ vdupq_n_f32(this->vmin),
351
+ xi.val[0],
352
+ vdupq_n_f32(this->vdiff)),
353
+ vfmaq_f32(
354
+ vdupq_n_f32(this->vmin),
355
+ xi.val[1],
356
+ vdupq_n_f32(this->vdiff)));
357
+ return vuzpq_f32(res.val[0], res.val[1]);
263
358
  }
264
359
  };
265
360
 
@@ -296,7 +391,8 @@ struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::SQuantizer {
296
391
  }
297
392
  }
298
393
 
299
- float reconstruct_component(const uint8_t* code, int i) const {
394
+ FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
395
+ const {
300
396
  float xi = Codec::decode_component(code, i);
301
397
  return vmin[i] + xi * vdiff[i];
302
398
  }
@@ -309,11 +405,36 @@ struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
309
405
  QuantizerTemplate(size_t d, const std::vector<float>& trained)
310
406
  : QuantizerTemplate<Codec, false, 1>(d, trained) {}
311
407
 
312
- __m256 reconstruct_8_components(const uint8_t* code, int i) const {
408
+ FAISS_ALWAYS_INLINE __m256
409
+ reconstruct_8_components(const uint8_t* code, int i) const {
313
410
  __m256 xi = Codec::decode_8_components(code, i);
314
- return _mm256_add_ps(
315
- _mm256_loadu_ps(this->vmin + i),
316
- _mm256_mul_ps(xi, _mm256_loadu_ps(this->vdiff + i)));
411
+ return _mm256_fmadd_ps(
412
+ xi,
413
+ _mm256_loadu_ps(this->vdiff + i),
414
+ _mm256_loadu_ps(this->vmin + i));
415
+ }
416
+ };
417
+
418
+ #endif
419
+
420
+ #ifdef __aarch64__
421
+
422
+ template <class Codec>
423
+ struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
424
+ QuantizerTemplate(size_t d, const std::vector<float>& trained)
425
+ : QuantizerTemplate<Codec, false, 1>(d, trained) {}
426
+
427
+ FAISS_ALWAYS_INLINE float32x4x2_t
428
+ reconstruct_8_components(const uint8_t* code, int i) const {
429
+ float32x4x2_t xi = Codec::decode_8_components(code, i);
430
+
431
+ float32x4x2_t vmin_8 = vld1q_f32_x2(this->vmin + i);
432
+ float32x4x2_t vdiff_8 = vld1q_f32_x2(this->vdiff + i);
433
+
434
+ float32x4x2_t res = vzipq_f32(
435
+ vfmaq_f32(vmin_8.val[0], xi.val[0], vdiff_8.val[0]),
436
+ vfmaq_f32(vmin_8.val[1], xi.val[1], vdiff_8.val[1]));
437
+ return vuzpq_f32(res.val[0], res.val[1]);
317
438
  }
318
439
  };
319
440
 
@@ -344,7 +465,8 @@ struct QuantizerFP16<1> : ScalarQuantizer::SQuantizer {
344
465
  }
345
466
  }
346
467
 
347
- float reconstruct_component(const uint8_t* code, int i) const {
468
+ FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
469
+ const {
348
470
  return decode_fp16(((uint16_t*)code)[i]);
349
471
  }
350
472
  };
@@ -356,7 +478,8 @@ struct QuantizerFP16<8> : QuantizerFP16<1> {
356
478
  QuantizerFP16(size_t d, const std::vector<float>& trained)
357
479
  : QuantizerFP16<1>(d, trained) {}
358
480
 
359
- __m256 reconstruct_8_components(const uint8_t* code, int i) const {
481
+ FAISS_ALWAYS_INLINE __m256
482
+ reconstruct_8_components(const uint8_t* code, int i) const {
360
483
  __m128i codei = _mm_loadu_si128((const __m128i*)(code + 2 * i));
361
484
  return _mm256_cvtph_ps(codei);
362
485
  }
@@ -364,6 +487,23 @@ struct QuantizerFP16<8> : QuantizerFP16<1> {
364
487
 
365
488
  #endif
366
489
 
490
+ #ifdef __aarch64__
491
+
492
+ template <>
493
+ struct QuantizerFP16<8> : QuantizerFP16<1> {
494
+ QuantizerFP16(size_t d, const std::vector<float>& trained)
495
+ : QuantizerFP16<1>(d, trained) {}
496
+
497
+ FAISS_ALWAYS_INLINE float32x4x2_t
498
+ reconstruct_8_components(const uint8_t* code, int i) const {
499
+ uint16x4x2_t codei = vld2_u16((const uint16_t*)(code + 2 * i));
500
+ return vzipq_f32(
501
+ vcvt_f32_f16(vreinterpret_f16_u16(codei.val[0])),
502
+ vcvt_f32_f16(vreinterpret_f16_u16(codei.val[1])));
503
+ }
504
+ };
505
+ #endif
506
+
367
507
  /*******************************************************************
368
508
  * 8bit_direct quantizer
369
509
  *******************************************************************/
@@ -390,7 +530,8 @@ struct Quantizer8bitDirect<1> : ScalarQuantizer::SQuantizer {
390
530
  }
391
531
  }
392
532
 
393
- float reconstruct_component(const uint8_t* code, int i) const {
533
+ FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
534
+ const {
394
535
  return code[i];
395
536
  }
396
537
  };
@@ -402,7 +543,8 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
402
543
  Quantizer8bitDirect(size_t d, const std::vector<float>& trained)
403
544
  : Quantizer8bitDirect<1>(d, trained) {}
404
545
 
405
- __m256 reconstruct_8_components(const uint8_t* code, int i) const {
546
+ FAISS_ALWAYS_INLINE __m256
547
+ reconstruct_8_components(const uint8_t* code, int i) const {
406
548
  __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8
407
549
  __m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * int32
408
550
  return _mm256_cvtepi32_ps(y8); // 8 * float32
@@ -411,6 +553,28 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
411
553
 
412
554
  #endif
413
555
 
556
+ #ifdef __aarch64__
557
+
558
+ template <>
559
+ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
560
+ Quantizer8bitDirect(size_t d, const std::vector<float>& trained)
561
+ : Quantizer8bitDirect<1>(d, trained) {}
562
+
563
+ FAISS_ALWAYS_INLINE float32x4x2_t
564
+ reconstruct_8_components(const uint8_t* code, int i) const {
565
+ float32_t result[8] = {};
566
+ for (size_t j = 0; j < 8; j++) {
567
+ result[j] = code[i + j];
568
+ }
569
+ float32x4_t res1 = vld1q_f32(result);
570
+ float32x4_t res2 = vld1q_f32(result + 4);
571
+ float32x4x2_t res = vzipq_f32(res1, res2);
572
+ return vuzpq_f32(res.val[0], res.val[1]);
573
+ }
574
+ };
575
+
576
+ #endif
577
+
414
578
  template <int SIMDWIDTH>
415
579
  ScalarQuantizer::SQuantizer* select_quantizer_1(
416
580
  QuantizerType qtype,
@@ -486,7 +650,7 @@ void train_Uniform(
486
650
  } else if (rs == ScalarQuantizer::RS_quantiles) {
487
651
  std::vector<float> x_copy(n);
488
652
  memcpy(x_copy.data(), x, n * sizeof(*x));
489
- // TODO just do a qucikselect
653
+ // TODO just do a quickselect
490
654
  std::sort(x_copy.begin(), x_copy.end());
491
655
  int o = int(rs_arg * n);
492
656
  if (o < 0)
@@ -632,22 +796,22 @@ struct SimilarityL2<1> {
632
796
 
633
797
  float accu;
634
798
 
635
- void begin() {
799
+ FAISS_ALWAYS_INLINE void begin() {
636
800
  accu = 0;
637
801
  yi = y;
638
802
  }
639
803
 
640
- void add_component(float x) {
804
+ FAISS_ALWAYS_INLINE void add_component(float x) {
641
805
  float tmp = *yi++ - x;
642
806
  accu += tmp * tmp;
643
807
  }
644
808
 
645
- void add_component_2(float x1, float x2) {
809
+ FAISS_ALWAYS_INLINE void add_component_2(float x1, float x2) {
646
810
  float tmp = x1 - x2;
647
811
  accu += tmp * tmp;
648
812
  }
649
813
 
650
- float result() {
814
+ FAISS_ALWAYS_INLINE float result() {
651
815
  return accu;
652
816
  }
653
817
  };
@@ -663,34 +827,89 @@ struct SimilarityL2<8> {
663
827
  explicit SimilarityL2(const float* y) : y(y) {}
664
828
  __m256 accu8;
665
829
 
666
- void begin_8() {
830
+ FAISS_ALWAYS_INLINE void begin_8() {
667
831
  accu8 = _mm256_setzero_ps();
668
832
  yi = y;
669
833
  }
670
834
 
671
- void add_8_components(__m256 x) {
835
+ FAISS_ALWAYS_INLINE void add_8_components(__m256 x) {
672
836
  __m256 yiv = _mm256_loadu_ps(yi);
673
837
  yi += 8;
674
838
  __m256 tmp = _mm256_sub_ps(yiv, x);
675
- accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(tmp, tmp));
839
+ accu8 = _mm256_fmadd_ps(tmp, tmp, accu8);
676
840
  }
677
841
 
678
- void add_8_components_2(__m256 x, __m256 y) {
679
- __m256 tmp = _mm256_sub_ps(y, x);
680
- accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(tmp, tmp));
842
+ FAISS_ALWAYS_INLINE void add_8_components_2(__m256 x, __m256 y_2) {
843
+ __m256 tmp = _mm256_sub_ps(y_2, x);
844
+ accu8 = _mm256_fmadd_ps(tmp, tmp, accu8);
681
845
  }
682
846
 
683
- float result_8() {
684
- __m256 sum = _mm256_hadd_ps(accu8, accu8);
685
- __m256 sum2 = _mm256_hadd_ps(sum, sum);
686
- // now add the 0th and 4th component
687
- return _mm_cvtss_f32(_mm256_castps256_ps128(sum2)) +
688
- _mm_cvtss_f32(_mm256_extractf128_ps(sum2, 1));
847
+ FAISS_ALWAYS_INLINE float result_8() {
848
+ const __m128 sum = _mm_add_ps(
849
+ _mm256_castps256_ps128(accu8), _mm256_extractf128_ps(accu8, 1));
850
+ const __m128 v0 = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2));
851
+ const __m128 v1 = _mm_add_ps(sum, v0);
852
+ __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1));
853
+ const __m128 v3 = _mm_add_ps(v1, v2);
854
+ return _mm_cvtss_f32(v3);
689
855
  }
690
856
  };
691
857
 
692
858
  #endif
693
859
 
860
+ #ifdef __aarch64__
861
+ template <>
862
+ struct SimilarityL2<8> {
863
+ static constexpr int simdwidth = 8;
864
+ static constexpr MetricType metric_type = METRIC_L2;
865
+
866
+ const float *y, *yi;
867
+ explicit SimilarityL2(const float* y) : y(y) {}
868
+ float32x4x2_t accu8;
869
+
870
+ FAISS_ALWAYS_INLINE void begin_8() {
871
+ accu8 = vzipq_f32(vdupq_n_f32(0.0f), vdupq_n_f32(0.0f));
872
+ yi = y;
873
+ }
874
+
875
+ FAISS_ALWAYS_INLINE void add_8_components(float32x4x2_t x) {
876
+ float32x4x2_t yiv = vld1q_f32_x2(yi);
877
+ yi += 8;
878
+
879
+ float32x4_t sub0 = vsubq_f32(yiv.val[0], x.val[0]);
880
+ float32x4_t sub1 = vsubq_f32(yiv.val[1], x.val[1]);
881
+
882
+ float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0);
883
+ float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1);
884
+
885
+ float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
886
+ accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
887
+ }
888
+
889
+ FAISS_ALWAYS_INLINE void add_8_components_2(
890
+ float32x4x2_t x,
891
+ float32x4x2_t y) {
892
+ float32x4_t sub0 = vsubq_f32(y.val[0], x.val[0]);
893
+ float32x4_t sub1 = vsubq_f32(y.val[1], x.val[1]);
894
+
895
+ float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0);
896
+ float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1);
897
+
898
+ float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
899
+ accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
900
+ }
901
+
902
+ FAISS_ALWAYS_INLINE float result_8() {
903
+ float32x4_t sum_0 = vpaddq_f32(accu8.val[0], accu8.val[0]);
904
+ float32x4_t sum_1 = vpaddq_f32(accu8.val[1], accu8.val[1]);
905
+
906
+ float32x4_t sum2_0 = vpaddq_f32(sum_0, sum_0);
907
+ float32x4_t sum2_1 = vpaddq_f32(sum_1, sum_1);
908
+ return vgetq_lane_f32(sum2_0, 0) + vgetq_lane_f32(sum2_1, 0);
909
+ }
910
+ };
911
+ #endif
912
+
694
913
  template <int SIMDWIDTH>
695
914
  struct SimilarityIP {};
696
915
 
@@ -704,20 +923,20 @@ struct SimilarityIP<1> {
704
923
 
705
924
  explicit SimilarityIP(const float* y) : y(y) {}
706
925
 
707
- void begin() {
926
+ FAISS_ALWAYS_INLINE void begin() {
708
927
  accu = 0;
709
928
  yi = y;
710
929
  }
711
930
 
712
- void add_component(float x) {
931
+ FAISS_ALWAYS_INLINE void add_component(float x) {
713
932
  accu += *yi++ * x;
714
933
  }
715
934
 
716
- void add_component_2(float x1, float x2) {
935
+ FAISS_ALWAYS_INLINE void add_component_2(float x1, float x2) {
717
936
  accu += x1 * x2;
718
937
  }
719
938
 
720
- float result() {
939
+ FAISS_ALWAYS_INLINE float result() {
721
940
  return accu;
722
941
  }
723
942
  };
@@ -737,27 +956,79 @@ struct SimilarityIP<8> {
737
956
 
738
957
  __m256 accu8;
739
958
 
740
- void begin_8() {
959
+ FAISS_ALWAYS_INLINE void begin_8() {
741
960
  accu8 = _mm256_setzero_ps();
742
961
  yi = y;
743
962
  }
744
963
 
745
- void add_8_components(__m256 x) {
964
+ FAISS_ALWAYS_INLINE void add_8_components(__m256 x) {
746
965
  __m256 yiv = _mm256_loadu_ps(yi);
747
966
  yi += 8;
748
- accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(yiv, x));
967
+ accu8 = _mm256_fmadd_ps(yiv, x, accu8);
968
+ }
969
+
970
+ FAISS_ALWAYS_INLINE void add_8_components_2(__m256 x1, __m256 x2) {
971
+ accu8 = _mm256_fmadd_ps(x1, x2, accu8);
972
+ }
973
+
974
+ FAISS_ALWAYS_INLINE float result_8() {
975
+ const __m128 sum = _mm_add_ps(
976
+ _mm256_castps256_ps128(accu8), _mm256_extractf128_ps(accu8, 1));
977
+ const __m128 v0 = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2));
978
+ const __m128 v1 = _mm_add_ps(sum, v0);
979
+ __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1));
980
+ const __m128 v3 = _mm_add_ps(v1, v2);
981
+ return _mm_cvtss_f32(v3);
982
+ }
983
+ };
984
+ #endif
985
+
986
+ #ifdef __aarch64__
987
+
988
+ template <>
989
+ struct SimilarityIP<8> {
990
+ static constexpr int simdwidth = 8;
991
+ static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
992
+
993
+ const float *y, *yi;
994
+
995
+ explicit SimilarityIP(const float* y) : y(y) {}
996
+ float32x4x2_t accu8;
997
+
998
+ FAISS_ALWAYS_INLINE void begin_8() {
999
+ accu8 = vzipq_f32(vdupq_n_f32(0.0f), vdupq_n_f32(0.0f));
1000
+ yi = y;
1001
+ }
1002
+
1003
+ FAISS_ALWAYS_INLINE void add_8_components(float32x4x2_t x) {
1004
+ float32x4x2_t yiv = vld1q_f32_x2(yi);
1005
+ yi += 8;
1006
+
1007
+ float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], yiv.val[0], x.val[0]);
1008
+ float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], yiv.val[1], x.val[1]);
1009
+ float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
1010
+ accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
749
1011
  }
750
1012
 
751
- void add_8_components_2(__m256 x1, __m256 x2) {
752
- accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(x1, x2));
1013
+ FAISS_ALWAYS_INLINE void add_8_components_2(
1014
+ float32x4x2_t x1,
1015
+ float32x4x2_t x2) {
1016
+ float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], x1.val[0], x2.val[0]);
1017
+ float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], x1.val[1], x2.val[1]);
1018
+ float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
1019
+ accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
753
1020
  }
754
1021
 
755
- float result_8() {
756
- __m256 sum = _mm256_hadd_ps(accu8, accu8);
757
- __m256 sum2 = _mm256_hadd_ps(sum, sum);
758
- // now add the 0th and 4th component
759
- return _mm_cvtss_f32(_mm256_castps256_ps128(sum2)) +
760
- _mm_cvtss_f32(_mm256_extractf128_ps(sum2, 1));
1022
+ FAISS_ALWAYS_INLINE float result_8() {
1023
+ float32x4x2_t sum_tmp = vzipq_f32(
1024
+ vpaddq_f32(accu8.val[0], accu8.val[0]),
1025
+ vpaddq_f32(accu8.val[1], accu8.val[1]));
1026
+ float32x4x2_t sum = vuzpq_f32(sum_tmp.val[0], sum_tmp.val[1]);
1027
+ float32x4x2_t sum2_tmp = vzipq_f32(
1028
+ vpaddq_f32(sum.val[0], sum.val[0]),
1029
+ vpaddq_f32(sum.val[1], sum.val[1]));
1030
+ float32x4x2_t sum2 = vuzpq_f32(sum2_tmp.val[0], sum2_tmp.val[1]);
1031
+ return vgetq_lane_f32(sum2.val[0], 0) + vgetq_lane_f32(sum2.val[1], 0);
761
1032
  }
762
1033
  };
763
1034
  #endif
@@ -864,6 +1135,53 @@ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
864
1135
 
865
1136
  #endif
866
1137
 
1138
+ #ifdef __aarch64__
1139
+
1140
+ template <class Quantizer, class Similarity>
1141
+ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
1142
+ using Sim = Similarity;
1143
+
1144
+ Quantizer quant;
1145
+
1146
+ DCTemplate(size_t d, const std::vector<float>& trained)
1147
+ : quant(d, trained) {}
1148
+ float compute_distance(const float* x, const uint8_t* code) const {
1149
+ Similarity sim(x);
1150
+ sim.begin_8();
1151
+ for (size_t i = 0; i < quant.d; i += 8) {
1152
+ float32x4x2_t xi = quant.reconstruct_8_components(code, i);
1153
+ sim.add_8_components(xi);
1154
+ }
1155
+ return sim.result_8();
1156
+ }
1157
+
1158
+ float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
1159
+ const {
1160
+ Similarity sim(nullptr);
1161
+ sim.begin_8();
1162
+ for (size_t i = 0; i < quant.d; i += 8) {
1163
+ float32x4x2_t x1 = quant.reconstruct_8_components(code1, i);
1164
+ float32x4x2_t x2 = quant.reconstruct_8_components(code2, i);
1165
+ sim.add_8_components_2(x1, x2);
1166
+ }
1167
+ return sim.result_8();
1168
+ }
1169
+
1170
+ void set_query(const float* x) final {
1171
+ q = x;
1172
+ }
1173
+
1174
+ float symmetric_dis(idx_t i, idx_t j) override {
1175
+ return compute_code_distance(
1176
+ codes + i * code_size, codes + j * code_size);
1177
+ }
1178
+
1179
+ float query_to_code(const uint8_t* code) const final {
1180
+ return compute_distance(q, code);
1181
+ }
1182
+ };
1183
+ #endif
1184
+
867
1185
  /*******************************************************************
868
1186
  * DistanceComputerByte: computes distances in the integer domain
869
1187
  *******************************************************************/
@@ -980,6 +1298,54 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
980
1298
 
981
1299
  #endif
982
1300
 
1301
+ #ifdef __aarch64__
1302
+
1303
+ template <class Similarity>
1304
+ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
1305
+ using Sim = Similarity;
1306
+
1307
+ int d;
1308
+ std::vector<uint8_t> tmp;
1309
+
1310
+ DistanceComputerByte(int d, const std::vector<float>&) : d(d), tmp(d) {}
1311
+
1312
+ int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
1313
+ const {
1314
+ int accu = 0;
1315
+ for (int i = 0; i < d; i++) {
1316
+ if (Sim::metric_type == METRIC_INNER_PRODUCT) {
1317
+ accu += int(code1[i]) * code2[i];
1318
+ } else {
1319
+ int diff = int(code1[i]) - code2[i];
1320
+ accu += diff * diff;
1321
+ }
1322
+ }
1323
+ return accu;
1324
+ }
1325
+
1326
+ void set_query(const float* x) final {
1327
+ for (int i = 0; i < d; i++) {
1328
+ tmp[i] = int(x[i]);
1329
+ }
1330
+ }
1331
+
1332
+ int compute_distance(const float* x, const uint8_t* code) {
1333
+ set_query(x);
1334
+ return compute_code_distance(tmp.data(), code);
1335
+ }
1336
+
1337
+ float symmetric_dis(idx_t i, idx_t j) override {
1338
+ return compute_code_distance(
1339
+ codes + i * code_size, codes + j * code_size);
1340
+ }
1341
+
1342
+ float query_to_code(const uint8_t* code) const final {
1343
+ return compute_code_distance(tmp.data(), code);
1344
+ }
1345
+ };
1346
+
1347
+ #endif
1348
+
983
1349
  /*******************************************************************
984
1350
  * select_distance_computer: runtime selection of template
985
1351
  * specialization
@@ -1115,34 +1481,8 @@ void ScalarQuantizer::train(size_t n, const float* x) {
1115
1481
  }
1116
1482
  }
1117
1483
 
1118
- void ScalarQuantizer::train_residual(
1119
- size_t n,
1120
- const float* x,
1121
- Index* quantizer,
1122
- bool by_residual,
1123
- bool verbose) {
1124
- const float* x_in = x;
1125
-
1126
- // 100k points more than enough
1127
- x = fvecs_maybe_subsample(d, (size_t*)&n, 100000, x, verbose, 1234);
1128
-
1129
- ScopeDeleter<float> del_x(x_in == x ? nullptr : x);
1130
-
1131
- if (by_residual) {
1132
- std::vector<idx_t> idx(n);
1133
- quantizer->assign(n, x, idx.data());
1134
-
1135
- std::vector<float> residuals(n * d);
1136
- quantizer->compute_residual_n(n, x, residuals.data(), idx.data());
1137
-
1138
- train(n, residuals.data());
1139
- } else {
1140
- train(n, x);
1141
- }
1142
- }
1143
-
1144
1484
  ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer() const {
1145
- #ifdef USE_F16C
1485
+ #if defined(USE_F16C) || defined(__aarch64__)
1146
1486
  if (d % 8 == 0) {
1147
1487
  return select_quantizer_1<8>(qtype, d, trained);
1148
1488
  } else
@@ -1173,7 +1513,7 @@ void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
1173
1513
  SQDistanceComputer* ScalarQuantizer::get_distance_computer(
1174
1514
  MetricType metric) const {
1175
1515
  FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
1176
- #ifdef USE_F16C
1516
+ #if defined(USE_F16C) || defined(__aarch64__)
1177
1517
  if (d % 8 == 0) {
1178
1518
  if (metric == METRIC_L2) {
1179
1519
  return select_distance_computer<SimilarityL2<8>>(qtype, d, trained);
@@ -1204,7 +1544,6 @@ template <class DCClass, int use_sel>
1204
1544
  struct IVFSQScannerIP : InvertedListScanner {
1205
1545
  DCClass dc;
1206
1546
  bool by_residual;
1207
- const IDSelector* sel;
1208
1547
 
1209
1548
  float accu0; /// added to all distances
1210
1549
 
@@ -1215,9 +1554,11 @@ struct IVFSQScannerIP : InvertedListScanner {
1215
1554
  bool store_pairs,
1216
1555
  const IDSelector* sel,
1217
1556
  bool by_residual)
1218
- : dc(d, trained), by_residual(by_residual), sel(sel), accu0(0) {
1557
+ : dc(d, trained), by_residual(by_residual), accu0(0) {
1219
1558
  this->store_pairs = store_pairs;
1559
+ this->sel = sel;
1220
1560
  this->code_size = code_size;
1561
+ this->keep_max = true;
1221
1562
  }
1222
1563
 
1223
1564
  void set_query(const float* query) override {
@@ -1288,7 +1629,6 @@ struct IVFSQScannerL2 : InvertedListScanner {
1288
1629
 
1289
1630
  bool by_residual;
1290
1631
  const Index* quantizer;
1291
- const IDSelector* sel;
1292
1632
  const float* x; /// current query
1293
1633
 
1294
1634
  std::vector<float> tmp;
@@ -1304,10 +1644,10 @@ struct IVFSQScannerL2 : InvertedListScanner {
1304
1644
  : dc(d, trained),
1305
1645
  by_residual(by_residual),
1306
1646
  quantizer(quantizer),
1307
- sel(sel),
1308
1647
  x(nullptr),
1309
1648
  tmp(d) {
1310
1649
  this->store_pairs = store_pairs;
1650
+ this->sel = sel;
1311
1651
  this->code_size = code_size;
1312
1652
  }
1313
1653
 
@@ -1509,7 +1849,7 @@ InvertedListScanner* ScalarQuantizer::select_InvertedListScanner(
1509
1849
  bool store_pairs,
1510
1850
  const IDSelector* sel,
1511
1851
  bool by_residual) const {
1512
- #ifdef USE_F16C
1852
+ #if defined(USE_F16C) || defined(__aarch64__)
1513
1853
  if (d % 8 == 0) {
1514
1854
  return sel0_InvertedListScanner<8>(
1515
1855
  mt, this, quantizer, store_pairs, sel, by_residual);