faiss 0.3.0 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (171) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
@@ -65,14 +65,6 @@ struct ScalarQuantizer : Quantizer {
65
65
 
66
66
  void train(size_t n, const float* x) override;
67
67
 
68
- /// Used by an IVF index to train based on the residuals
69
- void train_residual(
70
- size_t n,
71
- const float* x,
72
- Index* quantizer,
73
- bool by_residual,
74
- bool verbose);
75
-
76
68
  /** Encode a set of vectors
77
69
  *
78
70
  * @param x vectors to encode, size n * d
@@ -13,25 +13,218 @@
13
13
 
14
14
  #include <type_traits>
15
15
 
16
+ #include <faiss/impl/ProductQuantizer.h>
16
17
  #include <faiss/impl/code_distance/code_distance-generic.h>
17
18
 
18
19
  namespace {
19
20
 
21
+ inline float horizontal_sum(const __m128 v) {
22
+ const __m128 v0 = _mm_shuffle_ps(v, v, _MM_SHUFFLE(0, 0, 3, 2));
23
+ const __m128 v1 = _mm_add_ps(v, v0);
24
+ __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1));
25
+ const __m128 v3 = _mm_add_ps(v1, v2);
26
+ return _mm_cvtss_f32(v3);
27
+ }
28
+
20
29
  // Computes a horizontal sum over an __m256 register
21
- inline float horizontal_sum(const __m256 reg) {
22
- const __m256 h0 = _mm256_hadd_ps(reg, reg);
23
- const __m256 h1 = _mm256_hadd_ps(h0, h0);
30
+ inline float horizontal_sum(const __m256 v) {
31
+ const __m128 v0 =
32
+ _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
33
+ return horizontal_sum(v0);
34
+ }
35
+
36
+ // processes a single code for M=4, ksub=256, nbits=8
37
+ float inline distance_single_code_avx2_pqdecoder8_m4(
38
+ // precomputed distances, layout (4, 256)
39
+ const float* sim_table,
40
+ const uint8_t* code) {
41
+ float result = 0;
42
+
43
+ const float* tab = sim_table;
44
+ constexpr size_t ksub = 1 << 8;
45
+
46
+ const __m128i vksub = _mm_set1_epi32(ksub);
47
+ __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3);
48
+ offsets_0 = _mm_mullo_epi32(offsets_0, vksub);
49
+
50
+ // accumulators of partial sums
51
+ __m128 partialSum;
52
+
53
+ // load 4 uint8 values
54
+ const __m128i mm1 = _mm_cvtsi32_si128(*((const int32_t*)code));
55
+ {
56
+ // convert uint8 values (low part of __m128i) to int32
57
+ // values
58
+ const __m128i idx1 = _mm_cvtepu8_epi32(mm1);
24
59
 
25
- // extract high and low __m128 regs from __m256
26
- const __m128 h2 = _mm256_extractf128_ps(h1, 1);
27
- const __m128 h3 = _mm256_castps256_ps128(h1);
60
+ // add offsets
61
+ const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0);
28
62
 
29
- // get a final hsum into all 4 regs
30
- const __m128 h4 = _mm_add_ss(h2, h3);
63
+ // gather 4 values, similar to 4 operations of tab[idx]
64
+ __m128 collected =
65
+ _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float));
31
66
 
32
- // extract f[0] from __m128
33
- const float hsum = _mm_cvtss_f32(h4);
34
- return hsum;
67
+ // collect partial sums
68
+ partialSum = collected;
69
+ }
70
+
71
+ // horizontal sum for partialSum
72
+ result = horizontal_sum(partialSum);
73
+ return result;
74
+ }
75
+
76
+ // processes a single code for M=8, ksub=256, nbits=8
77
+ float inline distance_single_code_avx2_pqdecoder8_m8(
78
+ // precomputed distances, layout (8, 256)
79
+ const float* sim_table,
80
+ const uint8_t* code) {
81
+ float result = 0;
82
+
83
+ const float* tab = sim_table;
84
+ constexpr size_t ksub = 1 << 8;
85
+
86
+ const __m256i vksub = _mm256_set1_epi32(ksub);
87
+ __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
88
+ offsets_0 = _mm256_mullo_epi32(offsets_0, vksub);
89
+
90
+ // accumulators of partial sums
91
+ __m256 partialSum;
92
+
93
+ // load 8 uint8 values
94
+ const __m128i mm1 = _mm_loadu_si64((const __m128i_u*)code);
95
+ {
96
+ // convert uint8 values (low part of __m128i) to int32
97
+ // values
98
+ const __m256i idx1 = _mm256_cvtepu8_epi32(mm1);
99
+
100
+ // add offsets
101
+ const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0);
102
+
103
+ // gather 8 values, similar to 8 operations of tab[idx]
104
+ __m256 collected =
105
+ _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float));
106
+
107
+ // collect partial sums
108
+ partialSum = collected;
109
+ }
110
+
111
+ // horizontal sum for partialSum
112
+ result = horizontal_sum(partialSum);
113
+ return result;
114
+ }
115
+
116
+ // processes four codes for M=4, ksub=256, nbits=8
117
+ inline void distance_four_codes_avx2_pqdecoder8_m4(
118
+ // precomputed distances, layout (4, 256)
119
+ const float* sim_table,
120
+ // codes
121
+ const uint8_t* __restrict code0,
122
+ const uint8_t* __restrict code1,
123
+ const uint8_t* __restrict code2,
124
+ const uint8_t* __restrict code3,
125
+ // computed distances
126
+ float& result0,
127
+ float& result1,
128
+ float& result2,
129
+ float& result3) {
130
+ constexpr intptr_t N = 4;
131
+
132
+ const float* tab = sim_table;
133
+ constexpr size_t ksub = 1 << 8;
134
+
135
+ // process 4 values
136
+ const __m128i vksub = _mm_set1_epi32(ksub);
137
+ __m128i offsets_0 = _mm_setr_epi32(0, 1, 2, 3);
138
+ offsets_0 = _mm_mullo_epi32(offsets_0, vksub);
139
+
140
+ // accumulators of partial sums
141
+ __m128 partialSums[N];
142
+
143
+ // load 4 uint8 values
144
+ __m128i mm1[N];
145
+ mm1[0] = _mm_cvtsi32_si128(*((const int32_t*)code0));
146
+ mm1[1] = _mm_cvtsi32_si128(*((const int32_t*)code1));
147
+ mm1[2] = _mm_cvtsi32_si128(*((const int32_t*)code2));
148
+ mm1[3] = _mm_cvtsi32_si128(*((const int32_t*)code3));
149
+
150
+ for (intptr_t j = 0; j < N; j++) {
151
+ // convert uint8 values (low part of __m128i) to int32
152
+ // values
153
+ const __m128i idx1 = _mm_cvtepu8_epi32(mm1[j]);
154
+
155
+ // add offsets
156
+ const __m128i indices_to_read_from = _mm_add_epi32(idx1, offsets_0);
157
+
158
+ // gather 4 values, similar to 4 operations of tab[idx]
159
+ __m128 collected =
160
+ _mm_i32gather_ps(tab, indices_to_read_from, sizeof(float));
161
+
162
+ // collect partial sums
163
+ partialSums[j] = collected;
164
+ }
165
+
166
+ // horizontal sum for partialSum
167
+ result0 = horizontal_sum(partialSums[0]);
168
+ result1 = horizontal_sum(partialSums[1]);
169
+ result2 = horizontal_sum(partialSums[2]);
170
+ result3 = horizontal_sum(partialSums[3]);
171
+ }
172
+
173
+ // processes four codes for M=8, ksub=256, nbits=8
174
+ inline void distance_four_codes_avx2_pqdecoder8_m8(
175
+ // precomputed distances, layout (8, 256)
176
+ const float* sim_table,
177
+ // codes
178
+ const uint8_t* __restrict code0,
179
+ const uint8_t* __restrict code1,
180
+ const uint8_t* __restrict code2,
181
+ const uint8_t* __restrict code3,
182
+ // computed distances
183
+ float& result0,
184
+ float& result1,
185
+ float& result2,
186
+ float& result3) {
187
+ constexpr intptr_t N = 4;
188
+
189
+ const float* tab = sim_table;
190
+ constexpr size_t ksub = 1 << 8;
191
+
192
+ // process 8 values
193
+ const __m256i vksub = _mm256_set1_epi32(ksub);
194
+ __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
195
+ offsets_0 = _mm256_mullo_epi32(offsets_0, vksub);
196
+
197
+ // accumulators of partial sums
198
+ __m256 partialSums[N];
199
+
200
+ // load 8 uint8 values
201
+ __m128i mm1[N];
202
+ mm1[0] = _mm_loadu_si64((const __m128i_u*)code0);
203
+ mm1[1] = _mm_loadu_si64((const __m128i_u*)code1);
204
+ mm1[2] = _mm_loadu_si64((const __m128i_u*)code2);
205
+ mm1[3] = _mm_loadu_si64((const __m128i_u*)code3);
206
+
207
+ for (intptr_t j = 0; j < N; j++) {
208
+ // convert uint8 values (low part of __m128i) to int32
209
+ // values
210
+ const __m256i idx1 = _mm256_cvtepu8_epi32(mm1[j]);
211
+
212
+ // add offsets
213
+ const __m256i indices_to_read_from = _mm256_add_epi32(idx1, offsets_0);
214
+
215
+ // gather 8 values, similar to 8 operations of tab[idx]
216
+ __m256 collected =
217
+ _mm256_i32gather_ps(tab, indices_to_read_from, sizeof(float));
218
+
219
+ // collect partial sums
220
+ partialSums[j] = collected;
221
+ }
222
+
223
+ // horizontal sum for partialSum
224
+ result0 = horizontal_sum(partialSums[0]);
225
+ result1 = horizontal_sum(partialSums[1]);
226
+ result2 = horizontal_sum(partialSums[2]);
227
+ result3 = horizontal_sum(partialSums[3]);
35
228
  }
36
229
 
37
230
  } // namespace
@@ -41,36 +234,48 @@ namespace faiss {
41
234
  template <typename PQDecoderT>
42
235
  typename std::enable_if<!std::is_same<PQDecoderT, PQDecoder8>::value, float>::
43
236
  type inline distance_single_code_avx2(
44
- // the product quantizer
45
- const ProductQuantizer& pq,
237
+ // number of subquantizers
238
+ const size_t M,
239
+ // number of bits per quantization index
240
+ const size_t nbits,
46
241
  // precomputed distances, layout (M, ksub)
47
242
  const float* sim_table,
48
243
  const uint8_t* code) {
49
244
  // default implementation
50
- return distance_single_code_generic<PQDecoderT>(pq, sim_table, code);
245
+ return distance_single_code_generic<PQDecoderT>(M, nbits, sim_table, code);
51
246
  }
52
247
 
53
248
  template <typename PQDecoderT>
54
249
  typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, float>::
55
250
  type inline distance_single_code_avx2(
56
- // the product quantizer
57
- const ProductQuantizer& pq,
251
+ // number of subquantizers
252
+ const size_t M,
253
+ // number of bits per quantization index
254
+ const size_t nbits,
58
255
  // precomputed distances, layout (M, ksub)
59
256
  const float* sim_table,
60
257
  const uint8_t* code) {
258
+ if (M == 4) {
259
+ return distance_single_code_avx2_pqdecoder8_m4(sim_table, code);
260
+ }
261
+ if (M == 8) {
262
+ return distance_single_code_avx2_pqdecoder8_m8(sim_table, code);
263
+ }
264
+
61
265
  float result = 0;
266
+ constexpr size_t ksub = 1 << 8;
62
267
 
63
268
  size_t m = 0;
64
- const size_t pqM16 = pq.M / 16;
269
+ const size_t pqM16 = M / 16;
65
270
 
66
271
  const float* tab = sim_table;
67
272
 
68
273
  if (pqM16 > 0) {
69
274
  // process 16 values per loop
70
275
 
71
- const __m256i ksub = _mm256_set1_epi32(pq.ksub);
276
+ const __m256i vksub = _mm256_set1_epi32(ksub);
72
277
  __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
73
- offsets_0 = _mm256_mullo_epi32(offsets_0, ksub);
278
+ offsets_0 = _mm256_mullo_epi32(offsets_0, vksub);
74
279
 
75
280
  // accumulators of partial sums
76
281
  __m256 partialSum = _mm256_setzero_ps();
@@ -91,7 +296,7 @@ typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, float>::
91
296
  // gather 8 values, similar to 8 operations of tab[idx]
92
297
  __m256 collected = _mm256_i32gather_ps(
93
298
  tab, indices_to_read_from, sizeof(float));
94
- tab += pq.ksub * 8;
299
+ tab += ksub * 8;
95
300
 
96
301
  // collect partial sums
97
302
  partialSum = _mm256_add_ps(partialSum, collected);
@@ -111,7 +316,7 @@ typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, float>::
111
316
  // gather 8 values, similar to 8 operations of tab[idx]
112
317
  __m256 collected = _mm256_i32gather_ps(
113
318
  tab, indices_to_read_from, sizeof(float));
114
- tab += pq.ksub * 8;
319
+ tab += ksub * 8;
115
320
 
116
321
  // collect partial sums
117
322
  partialSum = _mm256_add_ps(partialSum, collected);
@@ -123,13 +328,13 @@ typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, float>::
123
328
  }
124
329
 
125
330
  //
126
- if (m < pq.M) {
331
+ if (m < M) {
127
332
  // process leftovers
128
- PQDecoder8 decoder(code + m, pq.nbits);
333
+ PQDecoder8 decoder(code + m, nbits);
129
334
 
130
- for (; m < pq.M; m++) {
335
+ for (; m < M; m++) {
131
336
  result += tab[decoder.decode()];
132
- tab += pq.ksub;
337
+ tab += ksub;
133
338
  }
134
339
  }
135
340
 
@@ -140,8 +345,10 @@ template <typename PQDecoderT>
140
345
  typename std::enable_if<!std::is_same<PQDecoderT, PQDecoder8>::value, void>::
141
346
  type
142
347
  distance_four_codes_avx2(
143
- // the product quantizer
144
- const ProductQuantizer& pq,
348
+ // number of subquantizers
349
+ const size_t M,
350
+ // number of bits per quantization index
351
+ const size_t nbits,
145
352
  // precomputed distances, layout (M, ksub)
146
353
  const float* sim_table,
147
354
  // codes
@@ -155,7 +362,8 @@ typename std::enable_if<!std::is_same<PQDecoderT, PQDecoder8>::value, void>::
155
362
  float& result2,
156
363
  float& result3) {
157
364
  distance_four_codes_generic<PQDecoderT>(
158
- pq,
365
+ M,
366
+ nbits,
159
367
  sim_table,
160
368
  code0,
161
369
  code1,
@@ -171,8 +379,10 @@ typename std::enable_if<!std::is_same<PQDecoderT, PQDecoder8>::value, void>::
171
379
  template <typename PQDecoderT>
172
380
  typename std::enable_if<std::is_same<PQDecoderT, PQDecoder8>::value, void>::type
173
381
  distance_four_codes_avx2(
174
- // the product quantizer
175
- const ProductQuantizer& pq,
382
+ // number of subquantizers
383
+ const size_t M,
384
+ // number of bits per quantization index
385
+ const size_t nbits,
176
386
  // precomputed distances, layout (M, ksub)
177
387
  const float* sim_table,
178
388
  // codes
@@ -185,13 +395,41 @@ distance_four_codes_avx2(
185
395
  float& result1,
186
396
  float& result2,
187
397
  float& result3) {
398
+ if (M == 4) {
399
+ distance_four_codes_avx2_pqdecoder8_m4(
400
+ sim_table,
401
+ code0,
402
+ code1,
403
+ code2,
404
+ code3,
405
+ result0,
406
+ result1,
407
+ result2,
408
+ result3);
409
+ return;
410
+ }
411
+ if (M == 8) {
412
+ distance_four_codes_avx2_pqdecoder8_m8(
413
+ sim_table,
414
+ code0,
415
+ code1,
416
+ code2,
417
+ code3,
418
+ result0,
419
+ result1,
420
+ result2,
421
+ result3);
422
+ return;
423
+ }
424
+
188
425
  result0 = 0;
189
426
  result1 = 0;
190
427
  result2 = 0;
191
428
  result3 = 0;
429
+ constexpr size_t ksub = 1 << 8;
192
430
 
193
431
  size_t m = 0;
194
- const size_t pqM16 = pq.M / 16;
432
+ const size_t pqM16 = M / 16;
195
433
 
196
434
  constexpr intptr_t N = 4;
197
435
 
@@ -199,9 +437,9 @@ distance_four_codes_avx2(
199
437
 
200
438
  if (pqM16 > 0) {
201
439
  // process 16 values per loop
202
- const __m256i ksub = _mm256_set1_epi32(pq.ksub);
440
+ const __m256i vksub = _mm256_set1_epi32(ksub);
203
441
  __m256i offsets_0 = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
204
- offsets_0 = _mm256_mullo_epi32(offsets_0, ksub);
442
+ offsets_0 = _mm256_mullo_epi32(offsets_0, vksub);
205
443
 
206
444
  // accumulators of partial sums
207
445
  __m256 partialSums[N];
@@ -235,7 +473,7 @@ distance_four_codes_avx2(
235
473
  // collect partial sums
236
474
  partialSums[j] = _mm256_add_ps(partialSums[j], collected);
237
475
  }
238
- tab += pq.ksub * 8;
476
+ tab += ksub * 8;
239
477
 
240
478
  // process next 8 codes
241
479
  for (intptr_t j = 0; j < N; j++) {
@@ -259,7 +497,7 @@ distance_four_codes_avx2(
259
497
  partialSums[j] = _mm256_add_ps(partialSums[j], collected);
260
498
  }
261
499
 
262
- tab += pq.ksub * 8;
500
+ tab += ksub * 8;
263
501
  }
264
502
 
265
503
  // horizontal sum for partialSum
@@ -270,18 +508,18 @@ distance_four_codes_avx2(
270
508
  }
271
509
 
272
510
  //
273
- if (m < pq.M) {
511
+ if (m < M) {
274
512
  // process leftovers
275
- PQDecoder8 decoder0(code0 + m, pq.nbits);
276
- PQDecoder8 decoder1(code1 + m, pq.nbits);
277
- PQDecoder8 decoder2(code2 + m, pq.nbits);
278
- PQDecoder8 decoder3(code3 + m, pq.nbits);
279
- for (; m < pq.M; m++) {
513
+ PQDecoder8 decoder0(code0 + m, nbits);
514
+ PQDecoder8 decoder1(code1 + m, nbits);
515
+ PQDecoder8 decoder2(code2 + m, nbits);
516
+ PQDecoder8 decoder3(code3 + m, nbits);
517
+ for (; m < M; m++) {
280
518
  result0 += tab[decoder0.decode()];
281
519
  result1 += tab[decoder1.decode()];
282
520
  result2 += tab[decoder2.decode()];
283
521
  result3 += tab[decoder3.decode()];
284
- tab += pq.ksub;
522
+ tab += ksub;
285
523
  }
286
524
  }
287
525
  }
@@ -7,27 +7,31 @@
7
7
 
8
8
  #pragma once
9
9
 
10
- #include <faiss/impl/ProductQuantizer.h>
10
+ #include <cstddef>
11
+ #include <cstdint>
11
12
 
12
13
  namespace faiss {
13
14
 
14
15
  /// Returns the distance to a single code.
15
16
  template <typename PQDecoderT>
16
17
  inline float distance_single_code_generic(
17
- // the product quantizer
18
- const ProductQuantizer& pq,
18
+ // number of subquantizers
19
+ const size_t M,
20
+ // number of bits per quantization index
21
+ const size_t nbits,
19
22
  // precomputed distances, layout (M, ksub)
20
23
  const float* sim_table,
21
24
  // the code
22
25
  const uint8_t* code) {
23
- PQDecoderT decoder(code, pq.nbits);
26
+ PQDecoderT decoder(code, nbits);
27
+ const size_t ksub = 1 << nbits;
24
28
 
25
29
  const float* tab = sim_table;
26
30
  float result = 0;
27
31
 
28
- for (size_t m = 0; m < pq.M; m++) {
32
+ for (size_t m = 0; m < M; m++) {
29
33
  result += tab[decoder.decode()];
30
- tab += pq.ksub;
34
+ tab += ksub;
31
35
  }
32
36
 
33
37
  return result;
@@ -37,8 +41,10 @@ inline float distance_single_code_generic(
37
41
  /// General-purpose version.
38
42
  template <typename PQDecoderT>
39
43
  inline void distance_four_codes_generic(
40
- // the product quantizer
41
- const ProductQuantizer& pq,
44
+ // number of subquantizers
45
+ const size_t M,
46
+ // number of bits per quantization index
47
+ const size_t nbits,
42
48
  // precomputed distances, layout (M, ksub)
43
49
  const float* sim_table,
44
50
  // codes
@@ -51,10 +57,11 @@ inline void distance_four_codes_generic(
51
57
  float& result1,
52
58
  float& result2,
53
59
  float& result3) {
54
- PQDecoderT decoder0(code0, pq.nbits);
55
- PQDecoderT decoder1(code1, pq.nbits);
56
- PQDecoderT decoder2(code2, pq.nbits);
57
- PQDecoderT decoder3(code3, pq.nbits);
60
+ PQDecoderT decoder0(code0, nbits);
61
+ PQDecoderT decoder1(code1, nbits);
62
+ PQDecoderT decoder2(code2, nbits);
63
+ PQDecoderT decoder3(code3, nbits);
64
+ const size_t ksub = 1 << nbits;
58
65
 
59
66
  const float* tab = sim_table;
60
67
  result0 = 0;
@@ -62,12 +69,12 @@ inline void distance_four_codes_generic(
62
69
  result2 = 0;
63
70
  result3 = 0;
64
71
 
65
- for (size_t m = 0; m < pq.M; m++) {
72
+ for (size_t m = 0; m < M; m++) {
66
73
  result0 += tab[decoder0.decode()];
67
74
  result1 += tab[decoder1.decode()];
68
75
  result2 += tab[decoder2.decode()];
69
76
  result3 += tab[decoder3.decode()];
70
- tab += pq.ksub;
77
+ tab += ksub;
71
78
  }
72
79
  }
73
80
 
@@ -32,19 +32,23 @@ namespace faiss {
32
32
 
33
33
  template <typename PQDecoderT>
34
34
  inline float distance_single_code(
35
- // the product quantizer
36
- const ProductQuantizer& pq,
35
+ // number of subquantizers
36
+ const size_t M,
37
+ // number of bits per quantization index
38
+ const size_t nbits,
37
39
  // precomputed distances, layout (M, ksub)
38
40
  const float* sim_table,
39
41
  // the code
40
42
  const uint8_t* code) {
41
- return distance_single_code_avx2<PQDecoderT>(pq, sim_table, code);
43
+ return distance_single_code_avx2<PQDecoderT>(M, nbits, sim_table, code);
42
44
  }
43
45
 
44
46
  template <typename PQDecoderT>
45
47
  inline void distance_four_codes(
46
- // the product quantizer
47
- const ProductQuantizer& pq,
48
+ // number of subquantizers
49
+ const size_t M,
50
+ // number of bits per quantization index
51
+ const size_t nbits,
48
52
  // precomputed distances, layout (M, ksub)
49
53
  const float* sim_table,
50
54
  // codes
@@ -58,7 +62,8 @@ inline void distance_four_codes(
58
62
  float& result2,
59
63
  float& result3) {
60
64
  distance_four_codes_avx2<PQDecoderT>(
61
- pq,
65
+ M,
66
+ nbits,
62
67
  sim_table,
63
68
  code0,
64
69
  code1,
@@ -80,19 +85,23 @@ namespace faiss {
80
85
 
81
86
  template <typename PQDecoderT>
82
87
  inline float distance_single_code(
83
- // the product quantizer
84
- const ProductQuantizer& pq,
88
+ // number of subquantizers
89
+ const size_t M,
90
+ // number of bits per quantization index
91
+ const size_t nbits,
85
92
  // precomputed distances, layout (M, ksub)
86
93
  const float* sim_table,
87
94
  // the code
88
95
  const uint8_t* code) {
89
- return distance_single_code_generic<PQDecoderT>(pq, sim_table, code);
96
+ return distance_single_code_generic<PQDecoderT>(M, nbits, sim_table, code);
90
97
  }
91
98
 
92
99
  template <typename PQDecoderT>
93
100
  inline void distance_four_codes(
94
- // the product quantizer
95
- const ProductQuantizer& pq,
101
+ // number of subquantizers
102
+ const size_t M,
103
+ // number of bits per quantization index
104
+ const size_t nbits,
96
105
  // precomputed distances, layout (M, ksub)
97
106
  const float* sim_table,
98
107
  // codes
@@ -106,7 +115,8 @@ inline void distance_four_codes(
106
115
  float& result2,
107
116
  float& result3) {
108
117
  distance_four_codes_generic<PQDecoderT>(
109
- pq,
118
+ M,
119
+ nbits,
110
120
  sim_table,
111
121
  code0,
112
122
  code1,