faiss 0.3.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (171)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
@@ -65,14 +65,6 @@ struct ScalarQuantizer : Quantizer {
65
65
 
66
66
  void train(size_t n, const float* x) override;
67
67
 
68
- /// Used by an IVF index to train based on the residuals
69
- void train_residual(
70
- size_t n,
71
- const float* x,
72
- Index* quantizer,
73
- bool by_residual,
74
- bool verbose);
75
-
76
68
  /** Encode a set of vectors
77
69
  *
78
70
  * @param x vectors to encode, size n * d
@@ -13,25 +13,218 @@
13
13
 
14
14
  #include <type_traits>
15
15
 
16
+ #include <faiss/impl/ProductQuantizer.h>
16
17
  #include <faiss/impl/code_distance/code_distance-generic.h>
17
18
 
18
19
  namespace {
19
20
 
21
+ inline float horizontal_sum(const __m128 v) {
22
+ const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2));
23
+ const __m128 v1 = _mm_add_ps(v, v0);
24
+ __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1));
25
+ const __m128 v3 = _mm_add_ps(v1, v2);
26
+ return _mm_cvtss_f32(v3);
27
+ }
28
+
20
29
  // Computes a horizontal sum over an __m256 register
21
- inline float horizontal_sum(const __m256 reg) {
22
- const __m256 h0 = _mm256_hadd_ps(reg, reg);
23
- const __m256 h1 = _mm256_hadd_ps(h0, h0);
30
+ inline float horizontal_sum(const __m256 v) {
31
+ const __m128 v0 =
32
+ _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
33
+ return horizontal_sum(v0);
34
+ }
35
+
36
+ // processes a single code for M=4, ksub=256, nbits=8
37
+ float inline distance_single_code_avx2_pqdecoder8_m4(
38
+ // precomputed distances, layout (4, 256)
39
+ const float* sim_table,
40
+ const uint8_t* code) {
41
+ float result = 0;
42
+
43
+ const float* tab = sim_table;
44
+ constexpr size_t ksub = 1 << 8;
45
+
46
+ const __m128i vksub = _mm_set1_epi32(ksub);
47
+ __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3);
48
+ offsets_0 = _mm_mullo_epi32(offsets_0, vksub);
49
+
50
+ // accumulators of partial sums
51
+ __m128 partialSum;
52
+
53
+ // load 4 uint8 values
54
+ const __m128i mm1 = _mm_cvtsi32_si128(*((const int32_t*)code));
55
+ {
56
+ // convert uint8 values (low part of __m128i) to int32
57
+ // values
58
+ const __m128i idx1 = _mm_cvtepu8_epi32(mm1);
24
59
 
25
- // extract high and low __m128 regs from __m256
26
- const __m128 h2 = _mm256_extractf128_ps(h1, 1);
27
- const __m128 h3 = _mm256_castps256_ps128(h1);
60
+ // add offsets
61
+ const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0);
28
62
 
29
- // get a final hsum into all 4 regs
30
- const __m128 h4 = _mm_add_ss(h2, h3);
63
+ // gather 4 values, similar to 4 operations of tab[idx]
64
+ __m128 collected =
65
+ _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float));
31
66
 
32
- // extract f[0] from __m128
33
- const float hsum = _mm_cvtss_f32(h4);
34
- return hsum;
67
+ // collect partial sums
68
+ partialSum = collected;
69
+ }
70
+
71
+ // horizontal sum for partialSum
72
+ result = horizontal_sum(partialSum);
73
+ return result;
74
+ }
75
+
76
+ // processes a single code for M=8, ksub=256, nbits=8
77
+ float inline distance_single_code_avx2_pqdecoder8_m8(
78
+ // precomputed distances, layout (8, 256)
79
+ const float* sim_table,
80
+ const uint8_t* code) {
81
+ float result = 0;
82
+
83
+ const float* tab = sim_table;
84
+ constexpr size_t ksub = 1 << 8;
85
+
86
+ const __m256i vksub = _mm256_set1_epi32(ksub);
87
+ __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
88
+ offsets_0 = _mm256_mullo_epi32(offsets_0, vksub);
89
+
90
+ // accumulators of partial sums
91
+ __m256 partialSum;
92
+
93
+ // load 8 uint8 values
94
+ const __m128i mm1 = _mm_loadu_si64((const __m128i_u*)code);
95
+ {
96
+ // convert uint8 values (low part of __m128i) to int32
97
+ // values
98
+ const __m256i idx1 = _mm256_cvtepu8_epi32(mm1);
99
+
100
+ // add offsets
101
+ const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0);
102
+
103
+ // gather 8 values, similar to 8 operations of tab[idx]
104
+ __m256 collected =
105
+ _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float));
106
+
107
+ // collect partial sums
108
+ partialSum = collected;
109
+ }
110
+
111
+ // horizontal sum for partialSum
112
+ result = horizontal_sum(partialSum);
113
+ return result;
114
+ }
115
+
116
+ // processes four codes for M=4, ksub=256, nbits=8
117
+ inline void distance_four_codes_avx2_pqdecoder8_m4(
118
+ // precomputed distances, layout (4, 256)
119
+ const float* sim_table,
120
+ // codes
121
+ const uint8_t* __restrict code0,
122
+ const uint8_t* __restrict code1,
123
+ const uint8_t* __restrict code2,
124
+ const uint8_t* __restrict code3,
125
+ // computed distances
126
+ float& result0,
127
+ float& result1,
128
+ float& result2,
129
+ float& result3) {
130
+ constexpr intptr_t N = 4;
131
+
132
+ const float* tab = sim_table;
133
+ constexpr size_t ksub = 1 << 8;
134
+
135
+ // process 4 values
136
+ const __m128i vksub = _mm_set1_epi32(ksub);
137
+ __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3);
138
+ offsets_0 = _mm_mullo_epi32(offsets_0, vksub);
139
+
140
+ // accumulators of partial sums
141
+ __m128 partialSums[N];
142
+
143
+ // load 4 uint8 values
144
+ __m128i mm1[N];
145
+ mm1[0] = _mm_cvtsi32_si128(*((const int32_t*)code0));
146
+ mm1[1] = _mm_cvtsi32_si128(*((const int32_t*)code1));
147
+ mm1[2] = _mm_cvtsi32_si128(*((const int32_t*)code2));
148
+ mm1[3] = _mm_cvtsi32_si128(*((const int32_t*)code3));
149
+
150
+ for (intptr_t j = 0; j < N; j++) {
151
+ // convert uint8 values (low part of __m128i) to int32
152
+ // values
153
+ const __m128i idx1 = _mm_cvtepu8_epi32(mm1[j]);
154
+
155
+ // add offsets
156
+ const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0);
157
+
158
+ // gather 4 values, similar to 4 operations of tab[idx]
159
+ __m128 collected =
160
+ _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float));
161
+
162
+ // collect partial sums
163
+ partialSums[j] = collected;
164
+ }
165
+
166
+ // horizontal sum for partialSum
167
+ result0 = horizontal_sum(partialSums[0]);
168
+ result1 = horizontal_sum(partialSums[1]);
169
+ result2 = horizontal_sum(partialSums[2]);
170
+ result3 = horizontal_sum(partialSums[3]);
171
+ }
172
+
173
+ // processes four codes for M=8, ksub=256, nbits=8
174
+ inline void distance_four_codes_avx2_pqdecoder8_m8(
175
+ // precomputed distances, layout (8, 256)
176
+ const float* sim_table,
177
+ // codes
178
+ const uint8_t* __restrict code0,
179
+ const uint8_t* __restrict code1,
180
+ const uint8_t* __restrict code2,
181
+ const uint8_t* __restrict code3,
182
+ // computed distances
183
+ float& result0,
184
+ float& result1,
185
+ float& result2,
186
+ float& result3) {
187
+ constexpr intptr_t N = 4;
188
+
189
+ const float* tab = sim_table;
190
+ constexpr size_t ksub = 1 << 8;
191
+
192
+ // process 8 values
193
+ const __m256i vksub = _mm256_set1_epi32(ksub);
194
+ __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
195
+ offsets_0 = _mm256_mullo_epi32(offsets_0, vksub);
196
+
197
+ // accumulators of partial sums
198
+ __m256 partialSums[N];
199
+
200
+ // load 8 uint8 values
201
+ __m128i mm1[N];
202
+ mm1[0] = _mm_loadu_si64((const __m128i_u*)code0);
203
+ mm1[1] = _mm_loadu_si64((const __m128i_u*)code1);
204
+ mm1[2] = _mm_loadu_si64((const __m128i_u*)code2);
205
+ mm1[3] = _mm_loadu_si64((const __m128i_u*)code3);
206
+
207
+ for (intptr_t j = 0; j < N; j++) {
208
+ // convert uint8 values (low part of __m128i) to int32
209
+ // values
210
+ const __m256i idx1 = _mm256_cvtepu8_epi32(mm1[j]);
211
+
212
+ // add offsets
213
+ const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0);
214
+
215
+ // gather 8 values, similar to 8 operations of tab[idx]
216
+ __m256 collected =
217
+ _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float));
218
+
219
+ // collect partial sums
220
+ partialSums[j] = collected;
221
+ }
222
+
223
+ // horizontal sum for partialSum
224
+ result0 = horizontal_sum(partialSums[0]);
225
+ result1 = horizontal_sum(partialSums[1]);
226
+ result2 = horizontal_sum(partialSums[2]);
227
+ result3 = horizontal_sum(partialSums[3]);
35
228
  }
36
229
 
37
230
  } // namespace
@@ -41,36 +234,48 @@ namespace faiss {
41
234
  template <typename PQDecoderT>
42
235
  typename std::enable_if<!std::is_same<PQDecoderT, PQDecoder8>::value, float>::
43
236
  type inline distance_single_code_avx2(
44
- // the product quantizer
45
- const ProductQuantizer& pq,
237
+ // number of subquantizers
238
+ const size_t M,
239
+ // number of bits per quantization index
240
+ const size_t nbits,
46
241
  // precomputed distances, layout (M, ksub)
47
242
  const float* sim_table,
48
243
  const uint8_t* code) {
49
244
  // default implementation
50
- return distance_single_code_generic<PQDecoderT>(pq, sim_table, code);
245
+ return distance_single_code_generic<PQDecoderT>(M, nbits, sim_table, code);
51
246
  }
52
247
 
53
248
  template <typename PQDecoderT>
54
249
  typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, float>::
55
250
  type inline distance_single_code_avx2(
56
- // the product quantizer
57
- const ProductQuantizer& pq,
251
+ // number of subquantizers
252
+ const size_t M,
253
+ // number of bits per quantization index
254
+ const size_t nbits,
58
255
  // precomputed distances, layout (M, ksub)
59
256
  const float* sim_table,
60
257
  const uint8_t* code) {
258
+ if (M == 4) {
259
+ return distance_single_code_avx2_pqdecoder8_m4(sim_table, code);
260
+ }
261
+ if (M == 8) {
262
+ return distance_single_code_avx2_pqdecoder8_m8(sim_table, code);
263
+ }
264
+
61
265
  float result = 0;
266
+ constexpr size_t ksub = 1 << 8;
62
267
 
63
268
  size_t m = 0;
64
- const size_t pqM16 = pq.M / 16;
269
+ const size_t pqM16 = M / 16;
65
270
 
66
271
  const float* tab = sim_table;
67
272
 
68
273
  if (pqM16 > 0) {
69
274
  // process 16 values per loop
70
275
 
71
- const __m256i ksub = _mm256_set1_epi32(pq.ksub);
276
+ const __m256i vksub = _mm256_set1_epi32(ksub);
72
277
  __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
73
- offsets_0 = _mm256_mullo_epi32(offsets_0, ksub);
278
+ offsets_0 = _mm256_mullo_epi32(offsets_0, vksub);
74
279
 
75
280
  // accumulators of partial sums
76
281
  __m256 partialSum = _mm256_setzero_ps();
@@ -91,7 +296,7 @@ typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, float>::
91
296
  // gather 8 values, similar to 8 operations of tab[idx]
92
297
  __m256 collected = _mm256_i32gather_ps(
93
298
  tab, indices_to_read_from, sizeof(float));
94
- tab += pq.ksub * 8;
299
+ tab += ksub * 8;
95
300
 
96
301
  // collect partial sums
97
302
  partialSum = _mm256_add_ps(partialSum, collected);
@@ -111,7 +316,7 @@ typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, float>::
111
316
  // gather 8 values, similar to 8 operations of tab[idx]
112
317
  __m256 collected = _mm256_i32gather_ps(
113
318
  tab, indices_to_read_from, sizeof(float));
114
- tab += pq.ksub * 8;
319
+ tab += ksub * 8;
115
320
 
116
321
  // collect partial sums
117
322
  partialSum = _mm256_add_ps(partialSum, collected);
@@ -123,13 +328,13 @@ typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, float>::
123
328
  }
124
329
 
125
330
  //
126
- if (m < pq.M) {
331
+ if (m < M) {
127
332
  // process leftovers
128
- PQDecoder8 decoder(code + m, pq.nbits);
333
+ PQDecoder8 decoder(code + m, nbits);
129
334
 
130
- for (; m < pq.M; m++) {
335
+ for (; m < M; m++) {
131
336
  result += tab[decoder.decode()];
132
- tab += pq.ksub;
337
+ tab += ksub;
133
338
  }
134
339
  }
135
340
 
@@ -140,8 +345,10 @@ template <typename PQDecoderT>
140
345
  typename std::enable_if<!std::is_same<PQDecoderT, PQDecoder8>::value, void>::
141
346
  type
142
347
  distance_four_codes_avx2(
143
- // the product quantizer
144
- const ProductQuantizer& pq,
348
+ // number of subquantizers
349
+ const size_t M,
350
+ // number of bits per quantization index
351
+ const size_t nbits,
145
352
  // precomputed distances, layout (M, ksub)
146
353
  const float* sim_table,
147
354
  // codes
@@ -155,7 +362,8 @@ typename std::enable_if<!std::is_same<PQDecoderT, PQDecoder8>::value, void>::
155
362
  float& result2,
156
363
  float& result3) {
157
364
  distance_four_codes_generic<PQDecoderT>(
158
- pq,
365
+ M,
366
+ nbits,
159
367
  sim_table,
160
368
  code0,
161
369
  code1,
@@ -171,8 +379,10 @@ typename std::enable_if<!std::is_same<PQDecoderT, PQDecoder8>::value, void>::
171
379
  template <typename PQDecoderT>
172
380
  typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, void>::type
173
381
  distance_four_codes_avx2(
174
- // the product quantizer
175
- const ProductQuantizer& pq,
382
+ // number of subquantizers
383
+ const size_t M,
384
+ // number of bits per quantization index
385
+ const size_t nbits,
176
386
  // precomputed distances, layout (M, ksub)
177
387
  const float* sim_table,
178
388
  // codes
@@ -185,13 +395,41 @@ distance_four_codes_avx2(
185
395
  float& result1,
186
396
  float& result2,
187
397
  float& result3) {
398
+ if (M == 4) {
399
+ distance_four_codes_avx2_pqdecoder8_m4(
400
+ sim_table,
401
+ code0,
402
+ code1,
403
+ code2,
404
+ code3,
405
+ result0,
406
+ result1,
407
+ result2,
408
+ result3);
409
+ return;
410
+ }
411
+ if (M == 8) {
412
+ distance_four_codes_avx2_pqdecoder8_m8(
413
+ sim_table,
414
+ code0,
415
+ code1,
416
+ code2,
417
+ code3,
418
+ result0,
419
+ result1,
420
+ result2,
421
+ result3);
422
+ return;
423
+ }
424
+
188
425
  result0 = 0;
189
426
  result1 = 0;
190
427
  result2 = 0;
191
428
  result3 = 0;
429
+ constexpr size_t ksub = 1 << 8;
192
430
 
193
431
  size_t m = 0;
194
- const size_t pqM16 = pq.M / 16;
432
+ const size_t pqM16 = M / 16;
195
433
 
196
434
  constexpr intptr_t N = 4;
197
435
 
@@ -199,9 +437,9 @@ distance_four_codes_avx2(
199
437
 
200
438
  if (pqM16 > 0) {
201
439
  // process 16 values per loop
202
- const __m256i ksub = _mm256_set1_epi32(pq.ksub);
440
+ const __m256i vksub = _mm256_set1_epi32(ksub);
203
441
  __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
204
- offsets_0 = _mm256_mullo_epi32(offsets_0, ksub);
442
+ offsets_0 = _mm256_mullo_epi32(offsets_0, vksub);
205
443
 
206
444
  // accumulators of partial sums
207
445
  __m256 partialSums[N];
@@ -235,7 +473,7 @@ distance_four_codes_avx2(
235
473
  // collect partial sums
236
474
  partialSums[j] = _mm256_add_ps(partialSums[j], collected);
237
475
  }
238
- tab += pq.ksub * 8;
476
+ tab += ksub * 8;
239
477
 
240
478
  // process next 8 codes
241
479
  for (intptr_t j = 0; j < N; j++) {
@@ -259,7 +497,7 @@ distance_four_codes_avx2(
259
497
  partialSums[j] = _mm256_add_ps(partialSums[j], collected);
260
498
  }
261
499
 
262
- tab += pq.ksub * 8;
500
+ tab += ksub * 8;
263
501
  }
264
502
 
265
503
  // horizontal sum for partialSum
@@ -270,18 +508,18 @@ distance_four_codes_avx2(
270
508
  }
271
509
 
272
510
  //
273
- if (m < pq.M) {
511
+ if (m < M) {
274
512
  // process leftovers
275
- PQDecoder8 decoder0(code0 + m, pq.nbits);
276
- PQDecoder8 decoder1(code1 + m, pq.nbits);
277
- PQDecoder8 decoder2(code2 + m, pq.nbits);
278
- PQDecoder8 decoder3(code3 + m, pq.nbits);
279
- for (; m < pq.M; m++) {
513
+ PQDecoder8 decoder0(code0 + m, nbits);
514
+ PQDecoder8 decoder1(code1 + m, nbits);
515
+ PQDecoder8 decoder2(code2 + m, nbits);
516
+ PQDecoder8 decoder3(code3 + m, nbits);
517
+ for (; m < M; m++) {
280
518
  result0 += tab[decoder0.decode()];
281
519
  result1 += tab[decoder1.decode()];
282
520
  result2 += tab[decoder2.decode()];
283
521
  result3 += tab[decoder3.decode()];
284
- tab += pq.ksub;
522
+ tab += ksub;
285
523
  }
286
524
  }
287
525
  }
@@ -7,27 +7,31 @@
7
7
 
8
8
  #pragma once
9
9
 
10
- #include <faiss/impl/ProductQuantizer.h>
10
+ #include <cstddef>
11
+ #include <cstdint>
11
12
 
12
13
  namespace faiss {
13
14
 
14
15
  /// Returns the distance to a single code.
15
16
  template <typename PQDecoderT>
16
17
  inline float distance_single_code_generic(
17
- // the product quantizer
18
- const ProductQuantizer& pq,
18
+ // number of subquantizers
19
+ const size_t M,
20
+ // number of bits per quantization index
21
+ const size_t nbits,
19
22
  // precomputed distances, layout (M, ksub)
20
23
  const float* sim_table,
21
24
  // the code
22
25
  const uint8_t* code) {
23
- PQDecoderT decoder(code, pq.nbits);
26
+ PQDecoderT decoder(code, nbits);
27
+ const size_t ksub = 1 << nbits;
24
28
 
25
29
  const float* tab = sim_table;
26
30
  float result = 0;
27
31
 
28
- for (size_t m = 0; m < pq.M; m++) {
32
+ for (size_t m = 0; m < M; m++) {
29
33
  result += tab[decoder.decode()];
30
- tab += pq.ksub;
34
+ tab += ksub;
31
35
  }
32
36
 
33
37
  return result;
@@ -37,8 +41,10 @@ inline float distance_single_code_generic(
37
41
  /// General-purpose version.
38
42
  template <typename PQDecoderT>
39
43
  inline void distance_four_codes_generic(
40
- // the product quantizer
41
- const ProductQuantizer& pq,
44
+ // number of subquantizers
45
+ const size_t M,
46
+ // number of bits per quantization index
47
+ const size_t nbits,
42
48
  // precomputed distances, layout (M, ksub)
43
49
  const float* sim_table,
44
50
  // codes
@@ -51,10 +57,11 @@ inline void distance_four_codes_generic(
51
57
  float& result1,
52
58
  float& result2,
53
59
  float& result3) {
54
- PQDecoderT decoder0(code0, pq.nbits);
55
- PQDecoderT decoder1(code1, pq.nbits);
56
- PQDecoderT decoder2(code2, pq.nbits);
57
- PQDecoderT decoder3(code3, pq.nbits);
60
+ PQDecoderT decoder0(code0, nbits);
61
+ PQDecoderT decoder1(code1, nbits);
62
+ PQDecoderT decoder2(code2, nbits);
63
+ PQDecoderT decoder3(code3, nbits);
64
+ const size_t ksub = 1 << nbits;
58
65
 
59
66
  const float* tab = sim_table;
60
67
  result0 = 0;
@@ -62,12 +69,12 @@ inline void distance_four_codes_generic(
62
69
  result2 = 0;
63
70
  result3 = 0;
64
71
 
65
- for (size_t m = 0; m < pq.M; m++) {
72
+ for (size_t m = 0; m < M; m++) {
66
73
  result0 += tab[decoder0.decode()];
67
74
  result1 += tab[decoder1.decode()];
68
75
  result2 += tab[decoder2.decode()];
69
76
  result3 += tab[decoder3.decode()];
70
- tab += pq.ksub;
77
+ tab += ksub;
71
78
  }
72
79
  }
73
80
 
@@ -32,19 +32,23 @@ namespace faiss {
32
32
 
33
33
  template <typename PQDecoderT>
34
34
  inline float distance_single_code(
35
- // the product quantizer
36
- const ProductQuantizer& pq,
35
+ // number of subquantizers
36
+ const size_t M,
37
+ // number of bits per quantization index
38
+ const size_t nbits,
37
39
  // precomputed distances, layout (M, ksub)
38
40
  const float* sim_table,
39
41
  // the code
40
42
  const uint8_t* code) {
41
- return distance_single_code_avx2<PQDecoderT>(pq, sim_table, code);
43
+ return distance_single_code_avx2<PQDecoderT>(M, nbits, sim_table, code);
42
44
  }
43
45
 
44
46
  template <typename PQDecoderT>
45
47
  inline void distance_four_codes(
46
- // the product quantizer
47
- const ProductQuantizer& pq,
48
+ // number of subquantizers
49
+ const size_t M,
50
+ // number of bits per quantization index
51
+ const size_t nbits,
48
52
  // precomputed distances, layout (M, ksub)
49
53
  const float* sim_table,
50
54
  // codes
@@ -58,7 +62,8 @@ inline void distance_four_codes(
58
62
  float& result2,
59
63
  float& result3) {
60
64
  distance_four_codes_avx2<PQDecoderT>(
61
- pq,
65
+ M,
66
+ nbits,
62
67
  sim_table,
63
68
  code0,
64
69
  code1,
@@ -80,19 +85,23 @@ namespace faiss {
80
85
 
81
86
  template <typename PQDecoderT>
82
87
  inline float distance_single_code(
83
- // the product quantizer
84
- const ProductQuantizer& pq,
88
+ // number of subquantizers
89
+ const size_t M,
90
+ // number of bits per quantization index
91
+ const size_t nbits,
85
92
  // precomputed distances, layout (M, ksub)
86
93
  const float* sim_table,
87
94
  // the code
88
95
  const uint8_t* code) {
89
- return distance_single_code_generic<PQDecoderT>(pq, sim_table, code);
96
+ return distance_single_code_generic<PQDecoderT>(M, nbits, sim_table, code);
90
97
  }
91
98
 
92
99
  template <typename PQDecoderT>
93
100
  inline void distance_four_codes(
94
- // the product quantizer
95
- const ProductQuantizer& pq,
101
+ // number of subquantizers
102
+ const size_t M,
103
+ // number of bits per quantization index
104
+ const size_t nbits,
96
105
  // precomputed distances, layout (M, ksub)
97
106
  const float* sim_table,
98
107
  // codes
@@ -106,7 +115,8 @@ inline void distance_four_codes(
106
115
  float& result2,
107
116
  float& result3) {
108
117
  distance_four_codes_generic<PQDecoderT>(
109
- pq,
118
+ M,
119
+ nbits,
110
120
  sim_table,
111
121
  code0,
112
122
  code1,