faiss 0.3.4 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +11 -8
  5. data/vendor/faiss/faiss/Clustering.cpp +0 -16
  6. data/vendor/faiss/faiss/IVFlib.cpp +213 -0
  7. data/vendor/faiss/faiss/IVFlib.h +42 -0
  8. data/vendor/faiss/faiss/Index.h +1 -1
  9. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -7
  10. data/vendor/faiss/faiss/IndexBinaryFlat.h +2 -1
  11. data/vendor/faiss/faiss/IndexFlatCodes.cpp +1 -1
  12. data/vendor/faiss/faiss/IndexFlatCodes.h +4 -2
  13. data/vendor/faiss/faiss/IndexHNSW.cpp +13 -20
  14. data/vendor/faiss/faiss/IndexHNSW.h +1 -1
  15. data/vendor/faiss/faiss/IndexIVF.cpp +20 -3
  16. data/vendor/faiss/faiss/IndexIVF.h +5 -2
  17. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -1
  18. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +2 -1
  19. data/vendor/faiss/faiss/IndexIVFFlat.cpp +2 -1
  20. data/vendor/faiss/faiss/IndexIVFFlat.h +2 -1
  21. data/vendor/faiss/faiss/IndexIVFPQ.cpp +2 -1
  22. data/vendor/faiss/faiss/IndexIVFPQ.h +2 -1
  23. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +277 -0
  24. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +70 -0
  25. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +2 -1
  26. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  27. data/vendor/faiss/faiss/IndexRaBitQ.cpp +148 -0
  28. data/vendor/faiss/faiss/IndexRaBitQ.h +65 -0
  29. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -1
  30. data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -1
  31. data/vendor/faiss/faiss/clone_index.cpp +38 -3
  32. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +19 -0
  33. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +4 -11
  34. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +2 -1
  35. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +13 -3
  36. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +1 -1
  37. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +1 -1
  38. data/vendor/faiss/faiss/gpu/test/TestGpuIcmEncoder.cpp +112 -0
  39. data/vendor/faiss/faiss/impl/HNSW.cpp +35 -13
  40. data/vendor/faiss/faiss/impl/HNSW.h +5 -4
  41. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  42. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +519 -0
  43. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +78 -0
  44. data/vendor/faiss/faiss/impl/ResultHandler.h +2 -2
  45. data/vendor/faiss/faiss/impl/code_distance/code_distance-sve.h +3 -4
  46. data/vendor/faiss/faiss/impl/index_read.cpp +220 -25
  47. data/vendor/faiss/faiss/impl/index_write.cpp +29 -0
  48. data/vendor/faiss/faiss/impl/io.h +2 -2
  49. data/vendor/faiss/faiss/impl/io_macros.h +2 -0
  50. data/vendor/faiss/faiss/impl/mapped_io.cpp +313 -0
  51. data/vendor/faiss/faiss/impl/mapped_io.h +51 -0
  52. data/vendor/faiss/faiss/impl/maybe_owned_vector.h +316 -0
  53. data/vendor/faiss/faiss/impl/platform_macros.h +7 -3
  54. data/vendor/faiss/faiss/impl/simd_result_handlers.h +1 -1
  55. data/vendor/faiss/faiss/impl/zerocopy_io.cpp +67 -0
  56. data/vendor/faiss/faiss/impl/zerocopy_io.h +32 -0
  57. data/vendor/faiss/faiss/index_factory.cpp +16 -5
  58. data/vendor/faiss/faiss/index_io.h +4 -0
  59. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +3 -3
  60. data/vendor/faiss/faiss/invlists/InvertedLists.h +5 -3
  61. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +3 -3
  62. data/vendor/faiss/faiss/python/python_callbacks.cpp +24 -0
  63. data/vendor/faiss/faiss/python/python_callbacks.h +22 -0
  64. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +30 -12
  65. data/vendor/faiss/faiss/utils/hamming.cpp +45 -21
  66. data/vendor/faiss/faiss/utils/hamming.h +7 -3
  67. data/vendor/faiss/faiss/utils/hamming_distance/avx512-inl.h +1 -1
  68. data/vendor/faiss/faiss/utils/utils.cpp +4 -4
  69. data/vendor/faiss/faiss/utils/utils.h +3 -3
  70. metadata +16 -4
@@ -0,0 +1,519 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/impl/RaBitQuantizer.h>
9
+
10
+ #include <algorithm>
11
+ #include <cmath>
12
+ #include <cstring>
13
+ #include <limits>
14
+ #include <memory>
15
+ #include <vector>
16
+
17
+ #include <faiss/impl/FaissAssert.h>
18
+ #include <faiss/utils/distances.h>
19
+
20
+ namespace faiss {
21
+
22
+ struct FactorsData { // per-vector factors stored right after the (d + 7) / 8 sign-bit bytes of each code
23
+ // ||or - c||^2 - ((metric==IP) ? ||or||^2 : 0)
24
+ float or_minus_c_l2sqr = 0;
25
+ float dp_multiplier = 0; // ||or - c|| times the guarded inverse of the normalized dot product (set in compute_codes_core)
26
+ };
27
+
28
+ struct QueryFactorsData { // per-query constants filled by set_query() and read by distance_to_code()
29
+ float c1 = 0; // multiplies the query/code dot product
30
+ float c2 = 0; // multiplies the number of set bits of the code
31
+ float c34 = 0; // constant term subtracted from the estimate
32
+
33
+ float qr_to_c_L2sqr = 0; // squared L2 distance from the query to the centroid (or to zero when there is no centroid)
34
+ float qr_norm_L2sqr = 0; // squared L2 norm of the query; only written for METRIC_INNER_PRODUCT
35
+ };
36
+
37
+ static size_t get_code_size(const size_t d) { // bytes per encoded vector for dimensionality d
38
+ return (d + 7) / 8 + sizeof(FactorsData); // one bit per dimension rounded up to whole bytes, then the factors
39
+ }
40
+
41
+ RaBitQuantizer::RaBitQuantizer(size_t d, MetricType metric) // forwards d and the derived code size to Quantizer
42
+ : Quantizer(d, get_code_size(d)), metric_type{metric} {}
43
+
44
+ void RaBitQuantizer::train(size_t n, const float* x) {
45
+ // intentionally empty: neither n nor x is used, nothing is learned here
46
+ }
47
+
48
+ void RaBitQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
49
+ const {
50
+ compute_codes_core(x, codes, n, centroid); // encode against the quantizer's own centroid member
51
+ }
52
+
53
+ void RaBitQuantizer::compute_codes_core(
54
+ const float* x,
55
+ uint8_t* codes,
56
+ size_t n,
57
+ const float* centroid_in) const { // encodes n vectors read from x (d floats each) into codes (code_size bytes each)
58
+ FAISS_ASSERT(codes != nullptr);
59
+ FAISS_ASSERT(x != nullptr);
60
+ FAISS_ASSERT(
61
+ (metric_type == MetricType::METRIC_L2 ||
62
+ metric_type == MetricType::METRIC_INNER_PRODUCT));
63
+
64
+ if (n == 0) {
65
+ return;
66
+ }
67
+
68
+ // 1 / sqrt(d), with a fallback of 1 for d == 0 to avoid a division by zero
69
+ const float inv_d_sqrt = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
70
+
71
+ // compute codes; every iteration writes only its own code_size-byte slot
72
+ #pragma omp parallel for if (n > 1000)
73
+ for (int64_t i = 0; i < n; i++) {
74
+ // ||or - c||^2
75
+ float norm_L2sqr = 0;
76
+ // ||or||^2, which is equal to ||P(or)||^2 and ||P^(-1)(or)||^2
77
+ float or_L2sqr = 0;
78
+ // sum of |or - c| over all dimensions; normalized below
79
+ float dp_oO = 0;
80
+
81
+ // the code: (d + 7) / 8 sign-bit bytes followed by a FactorsData
82
+ uint8_t* code = codes + i * code_size;
83
+ FactorsData* fac = reinterpret_cast<FactorsData*>(code + (d + 7) / 8); // NOTE(review): this address is not 4-byte aligned for every d (d = 8 gives offset 1) - confirm unaligned float access is acceptable on all targets
84
+
85
+ // zero the whole slot (bits and factors) so that bits can be OR-ed in
86
+ if (code != nullptr) {
87
+ memset(code, 0, code_size);
88
+ }
89
+
90
+ for (size_t j = 0; j < d; j++) {
91
+ const float or_minus_c = x[i * d + j] -
92
+ ((centroid_in == nullptr) ? 0 : centroid_in[j]); // a null centroid is treated as all zeros
93
+ norm_L2sqr += or_minus_c * or_minus_c;
94
+ or_L2sqr += x[i * d + j] * x[i * d + j];
95
+
96
+ const bool xb = (or_minus_c > 0); // the stored bit: 1 for strictly positive, 0 otherwise
97
+
98
+ dp_oO += xb ? or_minus_c : (-or_minus_c);
99
+
100
+ // store the output data
101
+ if (code != nullptr) {
102
+ if (xb) {
103
+ // set bit (j % 8) of byte (j / 8)
104
+ code[j / 8] |= (1 << (j % 8));
105
+ }
106
+ }
107
+ }
108
+
109
+ // compute factors
110
+
111
+ // compute the inverse norm; falls back to 1 when the norm is (almost) zero
112
+ const float inv_norm_L2 =
113
+ (std::abs(norm_L2sqr) < std::numeric_limits<float>::epsilon())
114
+ ? 1.0f
115
+ : (1.0f / std::sqrt(norm_L2sqr));
116
+ dp_oO *= inv_norm_L2;
117
+ dp_oO *= inv_d_sqrt;
118
+
119
+ const float inv_dp_oO =
120
+ (std::abs(dp_oO) < std::numeric_limits<float>::epsilon())
121
+ ? 1.0f
122
+ : (1.0f / dp_oO); // same guard: never divide by an (almost) zero dot product
123
+
124
+ fac->or_minus_c_l2sqr = norm_L2sqr;
125
+ if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
126
+ fac->or_minus_c_l2sqr -= or_L2sqr; // for IP the stored value is ||or - c||^2 - ||or||^2
127
+ }
128
+
129
+ fac->dp_multiplier = inv_dp_oO * std::sqrt(norm_L2sqr);
130
+ }
131
+ }
132
+
133
+ void RaBitQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
134
+ decode_core(codes, x, n, centroid); // decode against the quantizer's own centroid member
135
+ }
136
+
137
+ void RaBitQuantizer::decode_core(
138
+ const uint8_t* codes,
139
+ float* x,
140
+ size_t n,
141
+ const float* centroid_in) const { // reconstructs n vectors (d floats each, written to x) from n codes
142
+ FAISS_ASSERT(codes != nullptr);
143
+ FAISS_ASSERT(x != nullptr);
144
+
145
+ const float inv_d_sqrt = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d)); // 1 / sqrt(d), 1 when d == 0
146
+
147
+ #pragma omp parallel for if (n > 1000)
148
+ for (int64_t i = 0; i < n; i++) {
149
+ const uint8_t* code = codes + i * code_size;
150
+
151
+ // split the code into parts: sign bits first, factors after them
152
+ const uint8_t* binary_data = code;
153
+ const FactorsData* fac =
154
+ reinterpret_cast<const FactorsData*>(code + (d + 7) / 8);
155
+
156
+ // rebuild every dimension from its sign bit
157
+ for (size_t j = 0; j < d; j++) {
158
+ // extract j-th bit
159
+ const uint8_t masker = (1 << (j % 8));
160
+ const float bit = ((binary_data[j / 8] & masker) == masker) ? 1 : 0;
161
+
162
+ // (bit - 0.5) * 2 maps the bit to -1 / +1; scale it and add the centroid back
163
+ x[i * d + j] = (bit - 0.5f) * fac->dp_multiplier * 2 * inv_d_sqrt +
164
+ ((centroid_in == nullptr) ? 0 : centroid_in[j]);
165
+ }
166
+ }
167
+ }
168
+
169
+ struct RaBitDistanceComputer : FlatCodesDistanceComputer { // state shared by the two query-specific computers below
170
+ // dimensionality
171
+ size_t d = 0;
172
+ // a centroid to use; not owned, nullptr is treated as an all-zero centroid
173
+ const float* centroid = nullptr;
174
+
175
+ // the metric; the derived computers assert METRIC_L2 or METRIC_INNER_PRODUCT
176
+ MetricType metric_type = MetricType::METRIC_L2;
177
+
178
+ RaBitDistanceComputer();
179
+
180
+ float symmetric_dis(idx_t i, idx_t j) override; // always throws, see the definition
181
+ };
182
+
183
+ RaBitDistanceComputer::RaBitDistanceComputer() = default; // members keep their in-class initializers
184
+
185
+ float RaBitDistanceComputer::symmetric_dis(idx_t i, idx_t j) { // code-to-code distance is not supported
186
+ FAISS_THROW_MSG("Not implemented");
187
+ }
188
+
189
+ struct RaBitDistanceComputerNotQ : RaBitDistanceComputer { // computer that keeps the query in float (created for qb == 0)
190
+ // the rotated query (qr - c); resized to d by set_query()
191
+ std::vector<float> rotated_q;
192
+ // some additional numbers for the query
193
+ QueryFactorsData query_fac;
194
+
195
+ RaBitDistanceComputerNotQ();
196
+
197
+ float distance_to_code(const uint8_t* code) override;
198
+
199
+ void set_query(const float* x) override;
200
+ };
201
+
202
+ RaBitDistanceComputerNotQ::RaBitDistanceComputerNotQ() = default; // members keep their in-class initializers
203
+
204
+ float RaBitDistanceComputerNotQ::distance_to_code(const uint8_t* code) { // estimates the distance from the current query to one code
205
+ FAISS_ASSERT(code != nullptr);
206
+ FAISS_ASSERT(
207
+ (metric_type == MetricType::METRIC_L2 ||
208
+ metric_type == MetricType::METRIC_INNER_PRODUCT));
209
+
210
+ // split the code into parts: sign bits first, factors after them
211
+ const uint8_t* binary_data = code;
212
+ const FactorsData* fac =
213
+ reinterpret_cast<const FactorsData*>(code + (d + 7) / 8);
214
+
215
+ // this is the baseline code
216
+ //
217
+ // compute <q,o> using floats
218
+ float dot_qo = 0;
219
+ // It was a deliberate decision (after the discussion) not to pre-cache
220
+ // the sum of all bits, just in order to reduce the overhead per vector.
221
+ uint64_t sum_q = 0;
222
+ for (size_t i = 0; i < d; i++) {
223
+ // extract i-th bit
224
+ const uint8_t masker = (1 << (i % 8));
225
+ const bool b_bit = ((binary_data[i / 8] & masker) == masker);
226
+
227
+ // accumulate dp: only dimensions whose bit is set contribute
228
+ dot_qo += (b_bit) ? rotated_q[i] : 0;
229
+ // accumulate sum-of-bits (number of set bits in the code)
230
+ sum_q += (b_bit) ? 1 : 0;
231
+ }
232
+
233
+ float final_dot = 0;
234
+ // dot-product itself
235
+ final_dot += query_fac.c1 * dot_qo;
236
+ // normalizer coefficients (set_query() of this class sets c2 to 0)
237
+ final_dot += query_fac.c2 * sum_q;
238
+ // normalizer coefficients
239
+ final_dot -= query_fac.c34;
240
+
241
+ // this is ||or - c||^2 - (IP ? ||or||^2 : 0)
242
+ const float or_c_l2sqr = fac->or_minus_c_l2sqr;
243
+
244
+ // pre_dist = ||or - c||^2 + ||qr - c||^2 -
245
+ // 2 * ||or - c|| * ||qr - c|| * <q,o> - (IP ? ||or||^2 : 0)
246
+ const float pre_dist = or_c_l2sqr + query_fac.qr_to_c_L2sqr -
247
+ 2 * fac->dp_multiplier * final_dot;
248
+
249
+ if (metric_type == MetricType::METRIC_L2) {
250
+ // ||or - q||^ 2
251
+ return pre_dist;
252
+ } else {
253
+ // metric == MetricType::METRIC_INNER_PRODUCT
254
+
255
+ // this is ||q||^2
256
+ const float query_norm_sqr = query_fac.qr_norm_L2sqr;
257
+
258
+ // -2 * (or, q) = (||or - q||^2 - ||q||^2 - ||or||^2), so this returns (or, q)
259
+ return -0.5f * (pre_dist - query_norm_sqr);
260
+ }
261
+ }
262
+
263
+ void RaBitDistanceComputerNotQ::set_query(const float* x) { // stores the centered query and its constants; x holds d floats
264
+ FAISS_ASSERT(x != nullptr);
265
+ FAISS_ASSERT(
266
+ (metric_type == MetricType::METRIC_L2 ||
267
+ metric_type == MetricType::METRIC_INNER_PRODUCT));
268
+
269
+ // compute the squared distance from the query to the centroid
270
+ if (centroid != nullptr) {
271
+ query_fac.qr_to_c_L2sqr = fvec_L2sqr(x, centroid, d);
272
+ } else {
273
+ query_fac.qr_to_c_L2sqr = fvec_norm_L2sqr(x, d);
274
+ }
275
+
276
+ // subtract c, obtain P^(-1)(qr - c)
277
+ rotated_q.resize(d);
278
+ for (size_t i = 0; i < d; i++) {
279
+ rotated_q[i] = x[i] - ((centroid == nullptr) ? 0 : centroid[i]);
280
+ }
281
+
282
+ // despite its name, inv_d holds 1 / sqrt(d) (1 when d == 0)
283
+ const float inv_d = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
284
+
285
+ // do not quantize the query; only the sum of its centered components is needed
286
+ float sum_q = 0;
287
+ for (size_t i = 0; i < d; i++) {
288
+ sum_q += rotated_q[i];
289
+ }
290
+
291
+ query_fac.c1 = 2 * inv_d;
292
+ query_fac.c2 = 0; // no per-bit correction term for the float query
293
+ query_fac.c34 = sum_q * inv_d;
294
+
295
+ if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
296
+ // ||q||^2 is only read by the inner-product branch of distance_to_code()
297
+ query_fac.qr_norm_L2sqr = fvec_norm_L2sqr(x, d);
298
+ }
299
+ }
300
+
301
+ // distance computer that scalar-quantizes the query (created for qb > 0)
302
+ struct RaBitDistanceComputerQ : RaBitDistanceComputer {
303
+ // the rotated and quantized query (qr - c), one byte per dimension
304
+ std::vector<uint8_t> rotated_qq;
305
+ // we're using the proposed relayout-ed scheme from 3.3 that allows
306
+ // using popcounts for computing the distance: qb bit planes of (d + 7) / 8 bytes each.
307
+ std::vector<uint8_t> rearranged_rotated_qq;
308
+ // some additional numbers for the query
309
+ QueryFactorsData query_fac;
310
+
311
+ // the number of bits for SQ quantization of the query (qb > 0); NOTE(review): values are stored in uint8_t, so qb > 8 looks unsupported - confirm
312
+ uint8_t qb = 8;
313
+ // the smallest value divisible by 8 that is not smaller than dim
314
+ size_t popcount_aligned_dim = 0;
315
+
316
+ RaBitDistanceComputerQ();
317
+
318
+ float distance_to_code(const uint8_t* code) override;
319
+
320
+ void set_query(const float* x) override;
321
+ };
322
+
323
+ RaBitDistanceComputerQ::RaBitDistanceComputerQ() = default; // members keep their in-class initializers (qb defaults to 8)
324
+
325
+ float RaBitDistanceComputerQ::distance_to_code(const uint8_t* code) { // estimates the distance from the quantized query to one code
326
+ FAISS_ASSERT(code != nullptr);
327
+ FAISS_ASSERT(
328
+ (metric_type == MetricType::METRIC_L2 ||
329
+ metric_type == MetricType::METRIC_INNER_PRODUCT));
330
+
331
+ // split the code into parts: sign bits first, factors after them
332
+ const uint8_t* binary_data = code;
333
+ const FactorsData* fac =
334
+ reinterpret_cast<const FactorsData*>(code + (d + 7) / 8);
335
+
336
+ // // this is the baseline code
337
+ // //
338
+ // // compute <q,o> using integers
339
+ // size_t dot_qo = 0;
340
+ // for (size_t i = 0; i < d; i++) {
341
+ // // extract i-th bit
342
+ // const uint8_t masker = (1 << (i % 8));
343
+ // const uint8_t bit = ((binary_data[i / 8] & masker) == masker) ? 1 :
344
+ // 0;
345
+ //
346
+ // // accumulate dp
347
+ // dot_qo += bit * rotated_qq[i];
348
+ // }
349
+
350
+ // this is the scheme for popcount: di_8b bytes of bits, of which the first di_64b are handled 8 bytes at a time
351
+ const size_t di_8b = (d + 7) / 8;
352
+ const size_t di_64b = (di_8b / 8) * 8;
353
+
354
+ uint64_t dot_qo = 0;
355
+ for (size_t j = 0; j < qb; j++) {
356
+ const uint8_t* query_j = rearranged_rotated_qq.data() + j * di_8b; // bit plane j of the quantized query
357
+
358
+ // process 64-bit popcounts
359
+ uint64_t count_dot = 0;
360
+ for (size_t i = 0; i < di_64b; i += 8) {
361
+ const auto qv = *(const uint64_t*)(query_j + i); // NOTE(review): type-punned load from a byte pointer; alignment is not established here - confirm
362
+ const auto yv = *(const uint64_t*)(binary_data + i);
363
+ count_dot += __builtin_popcountll(qv & yv);
364
+ }
365
+
366
+ // process leftovers one byte at a time
367
+ for (size_t i = di_64b; i < di_8b; i++) {
368
+ const auto qv = *(query_j + i);
369
+ const auto yv = *(binary_data + i);
370
+ count_dot += __builtin_popcount(qv & yv);
371
+ }
372
+
373
+ dot_qo += (count_dot << j); // bit plane j carries weight 2^j
374
+ }
375
+
376
+ // It was a deliberate decision (after the discussion) not to pre-cache
377
+ // the sum of all bits, just in order to reduce the overhead per vector.
378
+ uint64_t sum_q = 0;
379
+ {
380
+ // process 64-bit popcounts
381
+ for (size_t i = 0; i < di_64b; i += 8) {
382
+ const auto yv = *(const uint64_t*)(binary_data + i);
383
+ sum_q += __builtin_popcountll(yv);
384
+ }
385
+
386
+ // process leftovers one byte at a time
387
+ for (size_t i = di_64b; i < di_8b; i++) {
388
+ const auto yv = *(binary_data + i);
389
+ sum_q += __builtin_popcount(yv);
390
+ }
391
+ }
392
+
393
+ float final_dot = 0;
394
+ // dot-product itself
395
+ final_dot += query_fac.c1 * dot_qo;
396
+ // normalizer coefficients (scaled by the number of set bits in the code)
397
+ final_dot += query_fac.c2 * sum_q;
398
+ // normalizer coefficients
399
+ final_dot -= query_fac.c34;
400
+
401
+ // this is ||or - c||^2 - (IP ? ||or||^2 : 0)
402
+ const float or_c_l2sqr = fac->or_minus_c_l2sqr;
403
+
404
+ // pre_dist = ||or - c||^2 + ||qr - c||^2 -
405
+ // 2 * ||or - c|| * ||qr - c|| * <q,o> - (IP ? ||or||^2 : 0)
406
+ const float pre_dist = or_c_l2sqr + query_fac.qr_to_c_L2sqr -
407
+ 2 * fac->dp_multiplier * final_dot;
408
+
409
+ if (metric_type == MetricType::METRIC_L2) {
410
+ // ||or - q||^ 2
411
+ return pre_dist;
412
+ } else {
413
+ // metric == MetricType::METRIC_INNER_PRODUCT
414
+
415
+ // this is ||q||^2
416
+ const float query_norm_sqr = query_fac.qr_norm_L2sqr;
417
+
418
+ // -2 * (or, q) = (||or - q||^2 - ||q||^2 - ||or||^2), so this returns (or, q)
419
+ return -0.5f * (pre_dist - query_norm_sqr);
420
+ }
421
+ }
422
+
423
+ void RaBitDistanceComputerQ::set_query(const float* x) { // centers, scalar-quantizes and bit-plane-packs the query; x holds d floats
424
+ FAISS_ASSERT(x != nullptr);
425
+ FAISS_ASSERT(
426
+ (metric_type == MetricType::METRIC_L2 ||
427
+ metric_type == MetricType::METRIC_INNER_PRODUCT));
428
+
429
+ // compute the squared distance from the query to the centroid
430
+ if (centroid != nullptr) {
431
+ query_fac.qr_to_c_L2sqr = fvec_L2sqr(x, centroid, d);
432
+ } else {
433
+ query_fac.qr_to_c_L2sqr = fvec_norm_L2sqr(x, d);
434
+ }
435
+
436
+ // allocate space: one quantized byte per dimension
437
+ rotated_qq.resize(d);
438
+
439
+ // center the query (a null centroid is treated as all zeros)
440
+ std::vector<float> rotated_q(d);
441
+ for (size_t i = 0; i < d; i++) {
442
+ rotated_q[i] = x[i] - ((centroid == nullptr) ? 0 : centroid[i]);
443
+ }
444
+
445
+ // despite its name, inv_d holds 1 / sqrt(d) (1 when d == 0)
446
+ const float inv_d = (d == 0) ? 1.0f : (1.0f / std::sqrt((float)d));
447
+
448
+ // quantize the query. compute min and max
449
+ float v_min = std::numeric_limits<float>::max();
450
+ float v_max = std::numeric_limits<float>::lowest();
451
+ for (size_t i = 0; i < d; i++) {
452
+ const float v_q = rotated_q[i];
453
+ v_min = std::min(v_min, v_q);
454
+ v_max = std::max(v_max, v_q);
455
+ }
456
+
457
+ const float pow_2_qb = 1 << qb;
458
+
459
+ const float delta = (v_max - v_min) / (pow_2_qb - 1);
460
+ const float inv_delta = (delta > 0) ? (1.0f / delta) : 0.0f; // a constant centered query gives delta == 0; 1 / 0 would turn every level into NaN
461
+
462
+ size_t sum_qq = 0;
463
+ for (size_t i = 0; i < d; i++) {
464
+ const float v_q = rotated_q[i];
465
+
466
+ // a default non-randomized SQ
467
+ const int v_qq = std::round((v_q - v_min) * inv_delta);
468
+
469
+ rotated_qq[i] = std::min(255, std::max(0, v_qq));
470
+ sum_qq += rotated_qq[i]; // sum the clamped level that is actually stored, keeping c34 consistent with the bit planes
471
+ }
472
+
473
+ // rearrange the query vector into qb bit planes of `offset` bytes each
474
+ popcount_aligned_dim = ((d + 7) / 8) * 8;
475
+ size_t offset = (d + 7) / 8;
476
+
477
+ rearranged_rotated_qq.resize(offset * qb);
478
+ std::fill(rearranged_rotated_qq.begin(), rearranged_rotated_qq.end(), 0);
479
+
480
+ for (size_t idim = 0; idim < d; idim++) {
481
+ for (size_t iv = 0; iv < qb; iv++) {
482
+ const bool bit = ((rotated_qq[idim] & (1 << iv)) != 0);
483
+ rearranged_rotated_qq[iv * offset + idim / 8] |=
484
+ bit ? (1 << (idim % 8)) : 0; // plane iv uses the same bit position as the code bit of idim
485
+ }
486
+ }
487
+
488
+ query_fac.c1 = 2 * delta * inv_d;
489
+ query_fac.c2 = 2 * v_min * inv_d;
490
+ query_fac.c34 = inv_d * (delta * sum_qq + d * v_min);
491
+
492
+ if (metric_type == MetricType::METRIC_INNER_PRODUCT) {
493
+ // ||q||^2 is only read by the inner-product branch of distance_to_code()
494
+ query_fac.qr_norm_L2sqr = fvec_norm_L2sqr(x, d);
495
+ }
496
+ }
497
+
498
+ FlatCodesDistanceComputer* RaBitQuantizer::get_distance_computer(
499
+ uint8_t qb,
500
+ const float* centroid_in) const { // builds a computer configured from this quantizer; the returned raw pointer is released from a unique_ptr
501
+ if (qb == 0) { // qb == 0: the query stays in float
502
+ auto dc = std::make_unique<RaBitDistanceComputerNotQ>();
503
+ dc->metric_type = metric_type;
504
+ dc->d = d;
505
+ dc->centroid = centroid_in; // the pointer is stored, not copied
506
+
507
+ return dc.release();
508
+ } else { // qb > 0: the query is scalar-quantized to qb bits
509
+ auto dc = std::make_unique<RaBitDistanceComputerQ>();
510
+ dc->metric_type = metric_type;
511
+ dc->d = d;
512
+ dc->centroid = centroid_in; // the pointer is stored, not copied
513
+ dc->qb = qb;
514
+
515
+ return dc.release();
516
+ }
517
+ }
518
+
519
+ } // namespace faiss
@@ -0,0 +1,78 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <cstddef>
11
+ #include <cstdint>
12
+
13
+ #include <faiss/MetricType.h>
14
+ #include <faiss/impl/DistanceComputer.h>
15
+ #include <faiss/impl/Quantizer.h>
16
+
17
+ namespace faiss {
18
+
19
+ // the reference implementation of https://arxiv.org/pdf/2405.12497
20
+ // Jianyang Gao, Cheng Long, "RaBitQ: Quantizing High-Dimensional Vectors
21
+ // with a Theoretical Error Bound for Approximate Nearest Neighbor Search".
22
+ //
23
+ // It is assumed that the Random Matrix Rotation is performed externally.
24
+ struct RaBitQuantizer : Quantizer {
25
+ // all RaBitQ operations are performed against a centroid, which needs
26
+ // to be provided externally (!). A nullptr value implies that the centroid
27
+ // consists of zero values.
28
+ // This is the default value that can be customized using XYZ_core() calls.
29
+ // Such a customization is needed for IVF calls.
30
+ //
31
+ // This particular pointer will NOT be serialized.
32
+ float* centroid = nullptr;
33
+
34
+ // RaBitQ code computations are independent of the metric. But it is needed
35
+ // to store some additional fp32 constants together with a quantized code.
36
+ // A decision was made to make this quantizer as space efficient as
37
+ // possible. Thus, a quantizer has to introduce a metric.
38
+ MetricType metric_type = MetricType::METRIC_L2;
39
+
40
+ RaBitQuantizer(size_t d = 0, MetricType metric = MetricType::METRIC_L2);
41
+
42
+ void train(size_t n, const float* x) override;
43
+
44
+ // every vector is expected to take (d + 7) / 8 + sizeof(FactorsData) bytes.
45
+ void compute_codes(const float* x, uint8_t* codes, size_t n) const override;
46
+
47
+ void compute_codes_core(
48
+ const float* x,
49
+ uint8_t* codes,
50
+ size_t n,
51
+ const float* centroid_in) const;
52
+
53
+ // The decode output is heavily geared towards maintaining the IP, not L2.
54
+ // This means that the reconstructed codes may be less accurate than one may
55
+ // expect, if one computes an L2 distance between a reconstructed code and
56
+ // the corresponding original vector.
57
+ // But the value of the dot product between a query and the original vector
58
+ // might be very close to the value of the dot product between a query and
59
+ // the reconstructed code.
60
+ // Basically, it seems to be related to the distributions of values, not
61
+ // values.
62
+ void decode(const uint8_t* codes, float* x, size_t n) const override;
63
+
64
+ void decode_core(
65
+ const uint8_t* codes,
66
+ float* x,
67
+ size_t n,
68
+ const float* centroid_in) const;
69
+
70
+ // returns the distance computer.
71
+ // specify qb = 0 to get a DC that does not quantize a query
72
+ // specify qb > 0 to have an SQ qb-bit query
73
+ FlatCodesDistanceComputer* get_distance_computer(
74
+ uint8_t qb,
75
+ const float* centroid_in = nullptr) const;
76
+ };
77
+
78
+ } // namespace faiss
@@ -534,7 +534,7 @@ struct RangeSearchBlockResultHandler : BlockResultHandler<C, use_sel> {
534
534
  try {
535
535
  // finalize the partial result
536
536
  pres.finalize();
537
- } catch (const faiss::FaissException& e) {
537
+ } catch ([[maybe_unused]] const faiss::FaissException& e) {
538
538
  // Do nothing if allocation fails in finalizing partial results.
539
539
  #ifndef NDEBUG
540
540
  std::cerr << e.what() << std::endl;
@@ -598,7 +598,7 @@ struct RangeSearchBlockResultHandler : BlockResultHandler<C, use_sel> {
598
598
  if (partial_results.size() > 0) {
599
599
  RangeSearchPartialResult::merge(partial_results);
600
600
  }
601
- } catch (const faiss::FaissException& e) {
601
+ } catch ([[maybe_unused]] const faiss::FaissException& e) {
602
602
  // Do nothing if allocation fails in merge.
603
603
  #ifndef NDEBUG
604
604
  std::cerr << e.what() << std::endl;
@@ -14,6 +14,7 @@
14
14
  #include <tuple>
15
15
  #include <type_traits>
16
16
 
17
+ #include <faiss/impl/ProductQuantizer.h>
17
18
  #include <faiss/impl/code_distance/code_distance-generic.h>
18
19
 
19
20
  namespace faiss {
@@ -48,7 +49,7 @@ static inline void distance_codes_kernel(
48
49
  partialSum = svadd_f32_m(pg, partialSum, collected);
49
50
  }
50
51
 
51
- static float distance_single_code_sve_for_small_m(
52
+ static inline float distance_single_code_sve_for_small_m(
52
53
  // the product quantizer
53
54
  const size_t M,
54
55
  // precomputed distances, layout (M, ksub)
@@ -196,7 +197,7 @@ distance_four_codes_sve(
196
197
  result3);
197
198
  }
198
199
 
199
- static void distance_four_codes_sve_for_small_m(
200
+ static inline void distance_four_codes_sve_for_small_m(
200
201
  // the product quantizer
201
202
  const size_t M,
202
203
  // precomputed distances, layout (M, ksub)
@@ -217,8 +218,6 @@ static void distance_four_codes_sve_for_small_m(
217
218
 
218
219
  const auto offsets_0 = svindex_u32(0, static_cast<uint32_t>(ksub));
219
220
 
220
- const auto quad_lanes = svcntw();
221
-
222
221
  // loop
223
222
  const auto pg = svwhilelt_b32_u64(0, M);
224
223