faiss 0.2.4 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -43,137 +43,12 @@ int sgemm_(
43
43
 
44
44
  namespace faiss {
45
45
 
46
- /* compute an estimator using look-up tables for typical values of M */
47
- template <typename CT, class C>
48
- void pq_estimators_from_tables_Mmul4(
49
- int M,
50
- const CT* codes,
51
- size_t ncodes,
52
- const float* __restrict dis_table,
53
- size_t ksub,
54
- size_t k,
55
- float* heap_dis,
56
- int64_t* heap_ids) {
57
- for (size_t j = 0; j < ncodes; j++) {
58
- float dis = 0;
59
- const float* dt = dis_table;
60
-
61
- for (size_t m = 0; m < M; m += 4) {
62
- float dism = 0;
63
- dism = dt[*codes++];
64
- dt += ksub;
65
- dism += dt[*codes++];
66
- dt += ksub;
67
- dism += dt[*codes++];
68
- dt += ksub;
69
- dism += dt[*codes++];
70
- dt += ksub;
71
- dis += dism;
72
- }
73
-
74
- if (C::cmp(heap_dis[0], dis)) {
75
- heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
76
- }
77
- }
78
- }
79
-
80
- template <typename CT, class C>
81
- void pq_estimators_from_tables_M4(
82
- const CT* codes,
83
- size_t ncodes,
84
- const float* __restrict dis_table,
85
- size_t ksub,
86
- size_t k,
87
- float* heap_dis,
88
- int64_t* heap_ids) {
89
- for (size_t j = 0; j < ncodes; j++) {
90
- float dis = 0;
91
- const float* dt = dis_table;
92
- dis = dt[*codes++];
93
- dt += ksub;
94
- dis += dt[*codes++];
95
- dt += ksub;
96
- dis += dt[*codes++];
97
- dt += ksub;
98
- dis += dt[*codes++];
99
-
100
- if (C::cmp(heap_dis[0], dis)) {
101
- heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
102
- }
103
- }
104
- }
105
-
106
- template <typename CT, class C>
107
- static inline void pq_estimators_from_tables(
108
- const ProductQuantizer& pq,
109
- const CT* codes,
110
- size_t ncodes,
111
- const float* dis_table,
112
- size_t k,
113
- float* heap_dis,
114
- int64_t* heap_ids) {
115
- if (pq.M == 4) {
116
- pq_estimators_from_tables_M4<CT, C>(
117
- codes, ncodes, dis_table, pq.ksub, k, heap_dis, heap_ids);
118
- return;
119
- }
120
-
121
- if (pq.M % 4 == 0) {
122
- pq_estimators_from_tables_Mmul4<CT, C>(
123
- pq.M, codes, ncodes, dis_table, pq.ksub, k, heap_dis, heap_ids);
124
- return;
125
- }
126
-
127
- /* Default is relatively slow */
128
- const size_t M = pq.M;
129
- const size_t ksub = pq.ksub;
130
- for (size_t j = 0; j < ncodes; j++) {
131
- float dis = 0;
132
- const float* __restrict dt = dis_table;
133
- for (int m = 0; m < M; m++) {
134
- dis += dt[*codes++];
135
- dt += ksub;
136
- }
137
- if (C::cmp(heap_dis[0], dis)) {
138
- heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
139
- }
140
- }
141
- }
142
-
143
- template <class C>
144
- static inline void pq_estimators_from_tables_generic(
145
- const ProductQuantizer& pq,
146
- size_t nbits,
147
- const uint8_t* codes,
148
- size_t ncodes,
149
- const float* dis_table,
150
- size_t k,
151
- float* heap_dis,
152
- int64_t* heap_ids) {
153
- const size_t M = pq.M;
154
- const size_t ksub = pq.ksub;
155
- for (size_t j = 0; j < ncodes; ++j) {
156
- PQDecoderGeneric decoder(codes + j * pq.code_size, nbits);
157
- float dis = 0;
158
- const float* __restrict dt = dis_table;
159
- for (size_t m = 0; m < M; m++) {
160
- uint64_t c = decoder.decode();
161
- dis += dt[c];
162
- dt += ksub;
163
- }
164
-
165
- if (C::cmp(heap_dis[0], dis)) {
166
- heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
167
- }
168
- }
169
- }
170
-
171
46
  /*********************************************
172
47
  * PQ implementation
173
48
  *********************************************/
174
49
 
175
50
  ProductQuantizer::ProductQuantizer(size_t d, size_t M, size_t nbits)
176
- : d(d), M(M), nbits(nbits), assign_index(nullptr) {
51
+ : Quantizer(d, 0), M(M), nbits(nbits), assign_index(nullptr) {
177
52
  set_derived_values();
178
53
  }
179
54
 
@@ -246,7 +121,7 @@ static void init_hypercube_pca(
246
121
  }
247
122
  }
248
123
 
249
- void ProductQuantizer::train(int n, const float* x) {
124
+ void ProductQuantizer::train(size_t n, const float* x) {
250
125
  if (train_type != Train_shared) {
251
126
  train_type_t final_train_type;
252
127
  final_train_type = train_type;
@@ -321,26 +196,66 @@ void ProductQuantizer::train(int n, const float* x) {
321
196
  template <class PQEncoder>
322
197
  void compute_code(const ProductQuantizer& pq, const float* x, uint8_t* code) {
323
198
  std::vector<float> distances(pq.ksub);
199
+
200
+ // It seems to be meaningless to allocate std::vector<float> distances.
201
+ // But it is done in order to cope the ineffectiveness of the way
202
+ // the compiler generates the code. Basically, doing something like
203
+ //
204
+ // size_t min_distance = HUGE_VALF;
205
+ // size_t idxm = 0;
206
+ // for (size_t i = 0; i < N; i++) {
207
+ // const float distance = compute_distance(x, y + i * d, d);
208
+ // if (distance < min_distance) {
209
+ // min_distance = distance;
210
+ // idxm = i;
211
+ // }
212
+ // }
213
+ //
214
+ // generates significantly more CPU instructions than the baseline
215
+ //
216
+ // std::vector<float> distances_cached(N);
217
+ // for (size_t i = 0; i < N; i++) {
218
+ // distances_cached[i] = compute_distance(x, y + i * d, d);
219
+ // }
220
+ // size_t min_distance = HUGE_VALF;
221
+ // size_t idxm = 0;
222
+ // for (size_t i = 0; i < N; i++) {
223
+ // const float distance = distances_cached[i];
224
+ // if (distance < min_distance) {
225
+ // min_distance = distance;
226
+ // idxm = i;
227
+ // }
228
+ // }
229
+ //
230
+ // So, the baseline is faster. This is because of the vectorization.
231
+ // I suppose that the branch predictor might affect the performance as well.
232
+ // So, the buffer is allocated, but it might be unused in
233
+ // manually optimized code. Let's hope that the compiler is smart enough to
234
+ // get rid of std::vector allocation in such a case.
235
+
324
236
  PQEncoder encoder(code, pq.nbits);
325
237
  for (size_t m = 0; m < pq.M; m++) {
326
- float mindis = 1e20;
327
- uint64_t idxm = 0;
328
238
  const float* xsub = x + m * pq.dsub;
329
239
 
330
- fvec_L2sqr_ny(
331
- distances.data(),
332
- xsub,
333
- pq.get_centroids(m, 0),
334
- pq.dsub,
335
- pq.ksub);
336
-
337
- /* Find best centroid */
338
- for (size_t i = 0; i < pq.ksub; i++) {
339
- float dis = distances[i];
340
- if (dis < mindis) {
341
- mindis = dis;
342
- idxm = i;
343
- }
240
+ uint64_t idxm = 0;
241
+ if (pq.transposed_centroids.empty()) {
242
+ // the regular version
243
+ idxm = fvec_L2sqr_ny_nearest(
244
+ distances.data(),
245
+ xsub,
246
+ pq.get_centroids(m, 0),
247
+ pq.dsub,
248
+ pq.ksub);
249
+ } else {
250
+ // transposed centroids are available, use'em
251
+ idxm = fvec_L2sqr_ny_nearest_y_transposed(
252
+ distances.data(),
253
+ xsub,
254
+ pq.transposed_centroids.data() + m * pq.ksub,
255
+ pq.centroids_sq_lengths.data() + m * pq.ksub,
256
+ pq.dsub,
257
+ pq.M * pq.ksub,
258
+ pq.ksub);
344
259
  }
345
260
 
346
261
  encoder.encode(idxm);
@@ -469,10 +384,13 @@ void ProductQuantizer::compute_codes_with_assign_index(
469
384
  }
470
385
  }
471
386
 
387
+ // block size used in ProductQuantizer::compute_codes
388
+ int product_quantizer_compute_codes_bs = 256 * 1024;
389
+
472
390
  void ProductQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
473
391
  const {
474
392
  // process by blocks to avoid using too much RAM
475
- size_t bs = 256 * 1024;
393
+ size_t bs = product_quantizer_compute_codes_bs;
476
394
  if (n > bs) {
477
395
  for (size_t i0 = 0; i0 < n; i0 += bs) {
478
396
  size_t i1 = std::min(i0 + bs, n);
@@ -606,8 +524,140 @@ void ProductQuantizer::compute_inner_prod_tables(
606
524
  }
607
525
  }
608
526
 
527
+ /**********************************************
528
+ * Templatized search functions
529
+ * The template class C indicates whether to keep the highest or smallest values
530
+ **********************************************/
531
+
532
+ namespace {
533
+
534
+ /* compute an estimator using look-up tables for typical values of M */
535
+ template <typename CT, class C>
536
+ void pq_estimators_from_tables_Mmul4(
537
+ int M,
538
+ const CT* codes,
539
+ size_t ncodes,
540
+ const float* __restrict dis_table,
541
+ size_t ksub,
542
+ size_t k,
543
+ float* heap_dis,
544
+ int64_t* heap_ids) {
545
+ for (size_t j = 0; j < ncodes; j++) {
546
+ float dis = 0;
547
+ const float* dt = dis_table;
548
+
549
+ for (size_t m = 0; m < M; m += 4) {
550
+ float dism = 0;
551
+ dism = dt[*codes++];
552
+ dt += ksub;
553
+ dism += dt[*codes++];
554
+ dt += ksub;
555
+ dism += dt[*codes++];
556
+ dt += ksub;
557
+ dism += dt[*codes++];
558
+ dt += ksub;
559
+ dis += dism;
560
+ }
561
+
562
+ if (C::cmp(heap_dis[0], dis)) {
563
+ heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
564
+ }
565
+ }
566
+ }
567
+
568
+ template <typename CT, class C>
569
+ void pq_estimators_from_tables_M4(
570
+ const CT* codes,
571
+ size_t ncodes,
572
+ const float* __restrict dis_table,
573
+ size_t ksub,
574
+ size_t k,
575
+ float* heap_dis,
576
+ int64_t* heap_ids) {
577
+ for (size_t j = 0; j < ncodes; j++) {
578
+ float dis = 0;
579
+ const float* dt = dis_table;
580
+ dis = dt[*codes++];
581
+ dt += ksub;
582
+ dis += dt[*codes++];
583
+ dt += ksub;
584
+ dis += dt[*codes++];
585
+ dt += ksub;
586
+ dis += dt[*codes++];
587
+
588
+ if (C::cmp(heap_dis[0], dis)) {
589
+ heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
590
+ }
591
+ }
592
+ }
593
+
594
+ template <typename CT, class C>
595
+ void pq_estimators_from_tables(
596
+ const ProductQuantizer& pq,
597
+ const CT* codes,
598
+ size_t ncodes,
599
+ const float* dis_table,
600
+ size_t k,
601
+ float* heap_dis,
602
+ int64_t* heap_ids) {
603
+ if (pq.M == 4) {
604
+ pq_estimators_from_tables_M4<CT, C>(
605
+ codes, ncodes, dis_table, pq.ksub, k, heap_dis, heap_ids);
606
+ return;
607
+ }
608
+
609
+ if (pq.M % 4 == 0) {
610
+ pq_estimators_from_tables_Mmul4<CT, C>(
611
+ pq.M, codes, ncodes, dis_table, pq.ksub, k, heap_dis, heap_ids);
612
+ return;
613
+ }
614
+
615
+ /* Default is relatively slow */
616
+ const size_t M = pq.M;
617
+ const size_t ksub = pq.ksub;
618
+ for (size_t j = 0; j < ncodes; j++) {
619
+ float dis = 0;
620
+ const float* __restrict dt = dis_table;
621
+ for (int m = 0; m < M; m++) {
622
+ dis += dt[*codes++];
623
+ dt += ksub;
624
+ }
625
+ if (C::cmp(heap_dis[0], dis)) {
626
+ heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
627
+ }
628
+ }
629
+ }
630
+
631
+ template <class C>
632
+ void pq_estimators_from_tables_generic(
633
+ const ProductQuantizer& pq,
634
+ size_t nbits,
635
+ const uint8_t* codes,
636
+ size_t ncodes,
637
+ const float* dis_table,
638
+ size_t k,
639
+ float* heap_dis,
640
+ int64_t* heap_ids) {
641
+ const size_t M = pq.M;
642
+ const size_t ksub = pq.ksub;
643
+ for (size_t j = 0; j < ncodes; ++j) {
644
+ PQDecoderGeneric decoder(codes + j * pq.code_size, nbits);
645
+ float dis = 0;
646
+ const float* __restrict dt = dis_table;
647
+ for (size_t m = 0; m < M; m++) {
648
+ uint64_t c = decoder.decode();
649
+ dis += dt[c];
650
+ dt += ksub;
651
+ }
652
+
653
+ if (C::cmp(heap_dis[0], dis)) {
654
+ heap_replace_top<C>(k, heap_dis, heap_ids, dis, j);
655
+ }
656
+ }
657
+ }
658
+
609
659
  template <class C>
610
- static void pq_knn_search_with_tables(
660
+ void pq_knn_search_with_tables(
611
661
  const ProductQuantizer& pq,
612
662
  size_t nbits,
613
663
  const float* dis_tables,
@@ -667,6 +717,8 @@ static void pq_knn_search_with_tables(
667
717
  }
668
718
  }
669
719
 
720
+ } // anonymous namespace
721
+
670
722
  void ProductQuantizer::search(
671
723
  const float* __restrict x,
672
724
  size_t nx,
@@ -781,4 +833,32 @@ void ProductQuantizer::search_sdc(
781
833
  }
782
834
  }
783
835
 
836
+ void ProductQuantizer::sync_transposed_centroids() {
837
+ transposed_centroids.resize(d * ksub);
838
+ centroids_sq_lengths.resize(ksub * M);
839
+
840
+ for (size_t mi = 0; mi < M; mi++) {
841
+ for (size_t ki = 0; ki < ksub; ki++) {
842
+ float sqlen = 0;
843
+
844
+ for (size_t di = 0; di < dsub; di++) {
845
+ const float q = centroids[(mi * ksub + ki) * dsub + di];
846
+
847
+ transposed_centroids[(di * M + mi) * ksub + ki] = q;
848
+ sqlen += q * q;
849
+ }
850
+
851
+ centroids_sq_lengths[mi * ksub + ki] = sqlen;
852
+ }
853
+ }
854
+ }
855
+
856
+ void ProductQuantizer::clear_transposed_centroids() {
857
+ transposed_centroids.clear();
858
+ transposed_centroids.shrink_to_fit();
859
+
860
+ centroids_sq_lengths.clear();
861
+ centroids_sq_lengths.shrink_to_fit();
862
+ }
863
+
784
864
  } // namespace faiss
@@ -15,23 +15,23 @@
15
15
  #include <vector>
16
16
 
17
17
  #include <faiss/Clustering.h>
18
+ #include <faiss/impl/Quantizer.h>
19
+ #include <faiss/impl/platform_macros.h>
18
20
  #include <faiss/utils/Heap.h>
19
21
 
20
22
  namespace faiss {
21
23
 
22
24
  /** Product Quantizer. Implemented only for METRIC_L2 */
23
- struct ProductQuantizer {
25
+ struct ProductQuantizer : Quantizer {
24
26
  using idx_t = Index::idx_t;
25
27
 
26
- size_t d; ///< size of the input vectors
27
28
  size_t M; ///< number of subquantizers
28
29
  size_t nbits; ///< number of bits per quantization index
29
30
 
30
31
  // values derived from the above
31
- size_t dsub; ///< dimensionality of each subvector
32
- size_t code_size; ///< bytes per indexed vector
33
- size_t ksub; ///< number of centroids for each subquantizer
34
- bool verbose; ///< verbose during training?
32
+ size_t dsub; ///< dimensionality of each subvector
33
+ size_t ksub; ///< number of centroids for each subquantizer
34
+ bool verbose; ///< verbose during training?
35
35
 
36
36
  /// initialization
37
37
  enum train_type_t {
@@ -49,9 +49,18 @@ struct ProductQuantizer {
49
49
  /// d / M)
50
50
  Index* assign_index;
51
51
 
52
- /// Centroid table, size M * ksub * dsub
52
+ /// Centroid table, size M * ksub * dsub.
53
+ /// Layout: (M, ksub, dsub)
53
54
  std::vector<float> centroids;
54
55
 
56
+ /// Transposed centroid table, size M * ksub * dsub.
57
+ /// Layout: (dsub, M, ksub)
58
+ std::vector<float> transposed_centroids;
59
+
60
+ /// Squared lengths of centroids, size M * ksub
61
+ /// Layout: (M, ksub)
62
+ std::vector<float> centroids_sq_lengths;
63
+
55
64
  /// return the centroids associated with subvector m
56
65
  float* get_centroids(size_t m, size_t i) {
57
66
  return &centroids[(m * ksub + i) * dsub];
@@ -62,7 +71,7 @@ struct ProductQuantizer {
62
71
 
63
72
  // Train the product quantizer on a set of points. A clustering
64
73
  // can be set on input to define non-default clustering parameters
65
- void train(int n, const float* x);
74
+ void train(size_t n, const float* x) override;
66
75
 
67
76
  ProductQuantizer(
68
77
  size_t d, /* dimensionality of the input vectors */
@@ -81,7 +90,7 @@ struct ProductQuantizer {
81
90
  void compute_code(const float* x, uint8_t* code) const;
82
91
 
83
92
  /// same as compute_code for several vectors
84
- void compute_codes(const float* x, uint8_t* codes, size_t n) const;
93
+ void compute_codes(const float* x, uint8_t* codes, size_t n) const override;
85
94
 
86
95
  /// speed up code assignment using assign_index
87
96
  /// (non-const because the index is changed)
@@ -92,7 +101,7 @@ struct ProductQuantizer {
92
101
 
93
102
  /// decode a vector from a given code (or n vectors if third argument)
94
103
  void decode(const uint8_t* code, float* x) const;
95
- void decode(const uint8_t* code, float* x, size_t n) const;
104
+ void decode(const uint8_t* code, float* x, size_t n) const override;
96
105
 
97
106
  /// If we happen to have the distance tables precomputed, this is
98
107
  /// more efficient to compute the codes.
@@ -165,8 +174,18 @@ struct ProductQuantizer {
165
174
  const size_t ncodes,
166
175
  float_maxheap_array_t* res,
167
176
  bool init_finalize_heap = true) const;
177
+
178
+ /// Sync transposed centroids with regular centroids. This call
179
+ /// is needed if centroids were edited directly.
180
+ void sync_transposed_centroids();
181
+
182
+ /// Clear transposed centroids table so ones are no longer used.
183
+ void clear_transposed_centroids();
168
184
  };
169
185
 
186
+ // block size used in ProductQuantizer::compute_codes
187
+ FAISS_API extern int product_quantizer_compute_codes_bs;
188
+
170
189
  /*************************************************
171
190
  * Objects to encode / decode strings of bits
172
191
  *************************************************/
@@ -0,0 +1,43 @@
1
+ // (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2
+
3
+ #pragma once
4
+
5
+ #include <stdint.h>
6
+
7
+ namespace faiss {
8
+
9
+ /** Product Quantizer. Implemented only for METRIC_L2 */
10
+ struct Quantizer {
11
+ using idx_t = Index::idx_t;
12
+
13
+ size_t d; ///< size of the input vectors
14
+ size_t code_size; ///< bytes per indexed vector
15
+
16
+ explicit Quantizer(size_t d = 0, size_t code_size = 0)
17
+ : d(d), code_size(code_size) {}
18
+
19
+ /** Train the quantizer
20
+ *
21
+ * @param x training vectors, size n * d
22
+ */
23
+ virtual void train(size_t n, const float* x) = 0;
24
+
25
+ /** Quantize a set of vectors
26
+ *
27
+ * @param x input vectors, size n * d
28
+ * @param codes output codes, size n * code_size
29
+ */
30
+ virtual void compute_codes(const float* x, uint8_t* codes, size_t n)
31
+ const = 0;
32
+
33
+ /** Decode a set of vectors
34
+ *
35
+ * @param codes input codes, size n * code_size
36
+ * @param x output vectors, size n * d
37
+ */
38
+ virtual void decode(const uint8_t* code, float* x, size_t n) const = 0;
39
+
40
+ virtual ~Quantizer() {}
41
+ };
42
+
43
+ } // namespace faiss