faiss 0.2.3 → 0.2.4

This diff reflects the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (63)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  6. data/vendor/faiss/faiss/Clustering.h +14 -0
  7. data/vendor/faiss/faiss/Index.h +1 -1
  8. data/vendor/faiss/faiss/Index2Layer.cpp +19 -92
  9. data/vendor/faiss/faiss/Index2Layer.h +2 -16
  10. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  11. data/vendor/faiss/faiss/{IndexResidual.h → IndexAdditiveQuantizer.h} +101 -58
  12. data/vendor/faiss/faiss/IndexFlat.cpp +22 -52
  13. data/vendor/faiss/faiss/IndexFlat.h +9 -15
  14. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  15. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  16. data/vendor/faiss/faiss/IndexIVF.cpp +79 -7
  17. data/vendor/faiss/faiss/IndexIVF.h +25 -7
  18. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  19. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  20. data/vendor/faiss/faiss/IndexIVFFlat.cpp +9 -12
  21. data/vendor/faiss/faiss/IndexIVFPQ.cpp +5 -4
  22. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
  23. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +60 -39
  24. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +21 -6
  25. data/vendor/faiss/faiss/IndexLSH.cpp +4 -30
  26. data/vendor/faiss/faiss/IndexLSH.h +2 -15
  27. data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -2
  28. data/vendor/faiss/faiss/IndexNSG.cpp +0 -2
  29. data/vendor/faiss/faiss/IndexPQ.cpp +2 -51
  30. data/vendor/faiss/faiss/IndexPQ.h +2 -17
  31. data/vendor/faiss/faiss/IndexRefine.cpp +28 -0
  32. data/vendor/faiss/faiss/IndexRefine.h +10 -0
  33. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -28
  34. data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -16
  35. data/vendor/faiss/faiss/VectorTransform.cpp +2 -1
  36. data/vendor/faiss/faiss/VectorTransform.h +3 -0
  37. data/vendor/faiss/faiss/clone_index.cpp +3 -2
  38. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -2
  39. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  40. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +257 -24
  41. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +69 -9
  42. data/vendor/faiss/faiss/impl/HNSW.cpp +10 -5
  43. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +393 -210
  44. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +100 -28
  45. data/vendor/faiss/faiss/impl/NSG.cpp +0 -3
  46. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  47. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +357 -47
  48. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +65 -7
  49. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +12 -19
  50. data/vendor/faiss/faiss/impl/index_read.cpp +102 -19
  51. data/vendor/faiss/faiss/impl/index_write.cpp +66 -16
  52. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  53. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  54. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  55. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  56. data/vendor/faiss/faiss/index_factory.cpp +585 -414
  57. data/vendor/faiss/faiss/index_factory.h +3 -0
  58. data/vendor/faiss/faiss/utils/distances.cpp +4 -2
  59. data/vendor/faiss/faiss/utils/distances.h +36 -3
  60. data/vendor/faiss/faiss/utils/distances_simd.cpp +50 -0
  61. data/vendor/faiss/faiss/utils/utils.h +1 -1
  62. metadata +12 -5
  63. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
data/vendor/faiss/faiss/VectorTransform.cpp
@@ -357,6 +357,7 @@ PCAMatrix::PCAMatrix(
     is_trained = false;
     max_points_per_d = 1000;
     balanced_bins = 0;
+    epsilon = 0;
 }
 
 namespace {
@@ -620,7 +621,7 @@ void PCAMatrix::prepare_Ab() {
     if (eigen_power != 0) {
         float* ai = A.data();
         for (int i = 0; i < d_out; i++) {
-            float factor = pow(eigenvalues[i], eigen_power);
+            float factor = pow(eigenvalues[i] + epsilon, eigen_power);
             for (int j = 0; j < d_in; j++)
                 *ai++ *= factor;
         }
data/vendor/faiss/faiss/VectorTransform.h
@@ -129,6 +129,9 @@ struct PCAMatrix : LinearTransform {
      */
     float eigen_power;
 
+    /// value added to eigenvalues to avoid division by 0 when whitening
+    float epsilon;
+
     /// random rotation after PCA
     bool random_rotation;
 
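Note: the new epsilon field matters when whitening, since eigen_power = -0.5 divides each output component by sqrt(eigenvalue), and a near-zero eigenvalue makes the pow() factor blow up. A minimal usage sketch (constructor arguments per the usual PCAMatrix signature; the 1e-6 value is an arbitrary illustration, not a default from this release):

    #include <faiss/VectorTransform.h>

    // whitening PCA: 128 dims in, 64 out, eigen_power = -0.5
    faiss::PCAMatrix pca(128, 64, -0.5);
    pca.epsilon = 1e-6f; // added to each eigenvalue before pow(), new in 0.2.4
    // pca.train(n, x); then pca.apply(n, x) as usual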
data/vendor/faiss/faiss/clone_index.cpp
@@ -15,6 +15,7 @@
 #include <faiss/impl/FaissAssert.h>
 
 #include <faiss/Index2Layer.h>
+#include <faiss/IndexAdditiveQuantizer.h>
 #include <faiss/IndexFlat.h>
 #include <faiss/IndexHNSW.h>
 #include <faiss/IndexIVF.h>
@@ -27,7 +28,6 @@
 #include <faiss/IndexNSG.h>
 #include <faiss/IndexPQ.h>
 #include <faiss/IndexPreTransform.h>
-#include <faiss/IndexResidual.h>
 #include <faiss/IndexScalarQuantizer.h>
 #include <faiss/MetaIndexes.h>
 #include <faiss/VectorTransform.h>
@@ -80,9 +80,10 @@ Index* Cloner::clone_Index(const Index* index) {
     TRYCLONE(IndexFlatIP, index)
     TRYCLONE(IndexFlat, index)
     TRYCLONE(IndexLattice, index)
-    TRYCLONE(IndexResidual, index)
+    TRYCLONE(IndexResidualQuantizer, index)
     TRYCLONE(IndexScalarQuantizer, index)
     TRYCLONE(MultiIndexQuantizer, index)
+    TRYCLONE(ResidualCoarseQuantizer, index)
     if (const IndexIVF* ivf = dynamic_cast<const IndexIVF*>(index)) {
         IndexIVF* res = clone_IndexIVF(ivf);
         if (ivf->invlists == nullptr) {
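These cloner edits track the quantizer renames in this release (see the IndexResidual.h → IndexAdditiveQuantizer.h rename and the deletion of IndexResidual.cpp in the file list): IndexResidual becomes IndexResidualQuantizer, and ResidualCoarseQuantizer is now clonable as well.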
data/vendor/faiss/faiss/gpu/GpuCloner.cpp
@@ -40,7 +40,7 @@ void ToCPUCloner::merge_index(Index* dst, Index* src, bool successive_ids) {
         auto ifl2 = dynamic_cast<const IndexFlat*>(src);
         FAISS_ASSERT(ifl2);
         FAISS_ASSERT(successive_ids);
-        ifl->add(ifl2->ntotal, ifl2->xb.data());
+        ifl->add(ifl2->ntotal, ifl2->get_xb());
     } else if (auto ifl = dynamic_cast<IndexIVFFlat*>(dst)) {
         auto ifl2 = dynamic_cast<IndexIVFFlat*>(src);
         FAISS_ASSERT(ifl2);
@@ -329,7 +329,7 @@ Index* ToGpuClonerMultiple::clone_Index_to_shards(const Index* index) {
             if (index->ntotal > 0) {
                 long i0 = index->ntotal * i / n;
                 long i1 = index->ntotal * (i + 1) / n;
-                shards[i]->add(i1 - i0, index_flat->xb.data() + i0 * index->d);
+                shards[i]->add(i1 - i0, index_flat->get_xb() + i0 * index->d);
             }
         }
     }
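Both GpuCloner call sites stop reading IndexFlat::xb directly and go through get_xb(); this release moves flat storage into the new IndexFlatCodes base class (IndexFlatCodes.{h,cpp} in the file list), so the raw float array is no longer a public vector member. A minimal sketch of the new access pattern:

    #include <faiss/IndexFlat.h>
    #include <cstdio>

    // get_xb() returns the same contiguous ntotal * d float array
    // that the old xb.data() exposed
    void print_first_vector(const faiss::IndexFlat& flat) {
        if (flat.ntotal == 0)
            return;
        const float* xb = flat.get_xb();
        for (int j = 0; j < flat.d; j++)
            std::printf("%g ", xb[j]);
        std::printf("\n");
    }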
data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h (new file)
@@ -0,0 +1,60 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <faiss/impl/LocalSearchQuantizer.h>
+
+#include <memory>
+
+namespace faiss {
+namespace gpu {
+
+class GpuResourcesProvider;
+struct IcmEncoderShards;
+
+/** Perform LSQ encoding on GPU.
+ *
+ * Split input vectors to different devices and call IcmEncoderImpl::encode
+ * to encode them
+ */
+class GpuIcmEncoder : public lsq::IcmEncoder {
+   public:
+    GpuIcmEncoder(
+            const LocalSearchQuantizer* lsq,
+            const std::vector<GpuResourcesProvider*>& provs,
+            const std::vector<int>& devices);
+
+    ~GpuIcmEncoder();
+
+    GpuIcmEncoder(const GpuIcmEncoder&) = delete;
+    GpuIcmEncoder& operator=(const GpuIcmEncoder&) = delete;
+
+    void set_binary_term() override;
+
+    void encode(
+            int32_t* codes,
+            const float* x,
+            std::mt19937& gen,
+            size_t n,
+            size_t ils_iters) const override;
+
+   private:
+    std::unique_ptr<IcmEncoderShards> shards;
+};
+
+struct GpuIcmEncoderFactory : public lsq::IcmEncoderFactory {
+    explicit GpuIcmEncoderFactory(int ngpus = 1);
+
+    lsq::IcmEncoder* get(const LocalSearchQuantizer* lsq) override;
+
+    std::vector<GpuResourcesProvider*> provs;
+    std::vector<int> devices;
+};
+
+} // namespace gpu
+} // namespace faiss
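For context, this factory plugs GPU ICM encoding into LocalSearchQuantizer. A minimal sketch, assuming the hook is the icm_encoder_factory member that LocalSearchQuantizer exposes for lsq::IcmEncoderFactory (verify against LocalSearchQuantizer.h in this release before relying on it):

    #include <faiss/gpu/GpuIcmEncoder.h>
    #include <faiss/impl/LocalSearchQuantizer.h>

    // 64-dim vectors, 8 codebooks of 2^8 centroids each
    faiss::LocalSearchQuantizer lsq(64, 8, 8);
    // route the iterated conditional modes (ICM) encoding step to 2 GPUs
    lsq.icm_encoder_factory = new faiss::gpu::GpuIcmEncoderFactory(2);
    // lsq.train(n, x); lsq.compute_codes(x, codes, n); as usual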
data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp
@@ -8,7 +8,6 @@
 // -*- c++ -*-
 
 #include <faiss/impl/AdditiveQuantizer.h>
-#include <faiss/impl/FaissAssert.h>
 
 #include <cstddef>
 #include <cstdio>
@@ -18,9 +17,10 @@
 
 #include <algorithm>
 
+#include <faiss/impl/FaissAssert.h>
 #include <faiss/utils/Heap.h>
 #include <faiss/utils/distances.h>
-#include <faiss/utils/hamming.h> // BitstringWriter
+#include <faiss/utils/hamming.h>
 #include <faiss/utils/utils.h>
 
 extern "C" {
@@ -42,51 +42,125 @@ int sgemm_(
         FINTEGER* ldc);
 }
 
-namespace {
-
-// c and a and b can overlap
-void fvec_add(size_t d, const float* a, const float* b, float* c) {
-    for (size_t i = 0; i < d; i++) {
-        c[i] = a[i] + b[i];
-    }
-}
+namespace faiss {
 
-void fvec_add(size_t d, const float* a, float b, float* c) {
-    for (size_t i = 0; i < d; i++) {
-        c[i] = a[i] + b;
-    }
+AdditiveQuantizer::AdditiveQuantizer(
+        size_t d,
+        const std::vector<size_t>& nbits,
+        Search_type_t search_type)
+        : d(d),
+          M(nbits.size()),
+          nbits(nbits),
+          verbose(false),
+          is_trained(false),
+          search_type(search_type) {
+    norm_max = norm_min = NAN;
+    code_size = 0;
+    tot_bits = 0;
+    total_codebook_size = 0;
+    only_8bit = false;
+    set_derived_values();
 }
 
-} // namespace
-
-namespace faiss {
+AdditiveQuantizer::AdditiveQuantizer()
+        : AdditiveQuantizer(0, std::vector<size_t>()) {}
 
 void AdditiveQuantizer::set_derived_values() {
     tot_bits = 0;
-    is_byte_aligned = true;
+    only_8bit = true;
     codebook_offsets.resize(M + 1, 0);
     for (int i = 0; i < M; i++) {
         int nbit = nbits[i];
         size_t k = 1 << nbit;
         codebook_offsets[i + 1] = codebook_offsets[i] + k;
         tot_bits += nbit;
-        if (nbit % 8 != 0) {
-            is_byte_aligned = false;
+        if (nbit != 8) {
+            only_8bit = false;
         }
     }
     total_codebook_size = codebook_offsets[M];
+    switch (search_type) {
+        case ST_decompress:
+        case ST_LUT_nonorm:
+        case ST_norm_from_LUT:
+            break; // nothing to add
+        case ST_norm_float:
+            tot_bits += 32;
+            break;
+        case ST_norm_qint8:
+        case ST_norm_cqint8:
+            tot_bits += 8;
+            break;
+        case ST_norm_qint4:
+        case ST_norm_cqint4:
+            tot_bits += 4;
+            break;
+    }
+
     // convert bits to bytes
     code_size = (tot_bits + 7) / 8;
 }
 
+namespace {
+
+// TODO
+// https://stackoverflow.com/questions/31631224/hacks-for-clamping-integer-to-0-255-and-doubles-to-0-0-1-0
+
+uint8_t encode_qint8(float x, float amin, float amax) {
+    float x1 = (x - amin) / (amax - amin) * 256;
+    int32_t xi = int32_t(floor(x1));
+
+    return xi < 0 ? 0 : xi > 255 ? 255 : xi;
+}
+
+uint8_t encode_qint4(float x, float amin, float amax) {
+    float x1 = (x - amin) / (amax - amin) * 16;
+    int32_t xi = int32_t(floor(x1));
+
+    return xi < 0 ? 0 : xi > 15 ? 15 : xi;
+}
+
+float decode_qint8(uint8_t i, float amin, float amax) {
+    return (i + 0.5) / 256 * (amax - amin) + amin;
+}
+
+float decode_qint4(uint8_t i, float amin, float amax) {
+    return (i + 0.5) / 16 * (amax - amin) + amin;
+}
+
+} // anonymous namespace
+
+uint32_t AdditiveQuantizer::encode_qcint(float x) const {
+    idx_t id;
+    qnorm.assign(idx_t(1), &x, &id, idx_t(1));
+    return uint32_t(id);
+}
+
+float AdditiveQuantizer::decode_qcint(uint32_t c) const {
+    return qnorm.get_xb()[c];
+}
+
 void AdditiveQuantizer::pack_codes(
         size_t n,
         const int32_t* codes,
         uint8_t* packed_codes,
-        int64_t ld_codes) const {
+        int64_t ld_codes,
+        const float* norms) const {
     if (ld_codes == -1) {
         ld_codes = M;
     }
+    std::vector<float> norm_buf;
+    if (search_type == ST_norm_float || search_type == ST_norm_qint4 ||
+        search_type == ST_norm_qint8 || search_type == ST_norm_cqint8 ||
+        search_type == ST_norm_cqint4) {
+        if (!norms) {
+            norm_buf.resize(n);
+            std::vector<float> x_recons(n * d);
+            decode_unpacked(codes, x_recons.data(), n, ld_codes);
+            fvec_norms_L2sqr(norm_buf.data(), x_recons.data(), d, n);
+            norms = norm_buf.data();
+        }
+    }
 #pragma omp parallel for if (n > 1000)
     for (int64_t i = 0; i < n; i++) {
         const int32_t* codes1 = codes + i * ld_codes;
@@ -94,6 +168,35 @@ void AdditiveQuantizer::pack_codes(
         for (int m = 0; m < M; m++) {
             bsw.write(codes1[m], nbits[m]);
         }
+        switch (search_type) {
+            case ST_decompress:
+            case ST_LUT_nonorm:
+            case ST_norm_from_LUT:
+                break;
+            case ST_norm_float:
+                bsw.write(*(uint32_t*)&norms[i], 32);
+                break;
+            case ST_norm_qint8: {
+                uint8_t b = encode_qint8(norms[i], norm_min, norm_max);
+                bsw.write(b, 8);
+                break;
+            }
+            case ST_norm_qint4: {
+                uint8_t b = encode_qint4(norms[i], norm_min, norm_max);
+                bsw.write(b, 4);
+                break;
+            }
+            case ST_norm_cqint8: {
+                uint32_t b = encode_qcint(norms[i]);
+                bsw.write(b, 8);
+                break;
+            }
+            case ST_norm_cqint4: {
+                uint32_t b = encode_qcint(norms[i]);
+                bsw.write(b, 4);
+                break;
+            }
+        }
     }
 }
 
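A worked example of the appended norm field: with M = 4 codebooks of 8 bits each, tot_bits = 32; ST_norm_qint8 adds 8 bits, so code_size = (40 + 7) / 8 = 5 bytes per vector. The qint8 round trip from the helpers above, taking norm_min = 0, norm_max = 10 and a norm of 4.3:

    encode_qint8: x1 = (4.3 - 0) / (10 - 0) * 256 = 110.08, stored byte = 110
    decode_qint8: (110 + 0.5) / 256 * 10 ≈ 4.316

Decoding lands on the bin center, so the worst-case error is half a bin, i.e. (norm_max - norm_min) / 512 ≈ 0.02 here.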
@@ -118,10 +221,39 @@ void AdditiveQuantizer::decode(const uint8_t* code, float* x, size_t n) const {
     }
 }
 
+void AdditiveQuantizer::decode_unpacked(
+        const int32_t* code,
+        float* x,
+        size_t n,
+        int64_t ld_codes) const {
+    FAISS_THROW_IF_NOT_MSG(
+            is_trained, "The additive quantizer is not trained yet.");
+
+    if (ld_codes == -1) {
+        ld_codes = M;
+    }
+
+    // standard additive quantizer decoding
+#pragma omp parallel for if (n > 1000)
+    for (int64_t i = 0; i < n; i++) {
+        const int32_t* codesi = code + i * ld_codes;
+        float* xi = x + i * d;
+        for (int m = 0; m < M; m++) {
+            int idx = codesi[m];
+            const float* c = codebooks.data() + d * (codebook_offsets[m] + idx);
+            if (m == 0) {
+                memcpy(xi, c, sizeof(*x) * d);
+            } else {
+                fvec_add(d, xi, c, xi);
+            }
+        }
+    }
+}
+
 AdditiveQuantizer::~AdditiveQuantizer() {}
 
 /****************************************************************************
- * Support for fast distance computations and search with additive quantizer
+ * Support for fast distance computations in centroids
  ****************************************************************************/
 
 void AdditiveQuantizer::compute_centroid_norms(float* norms) const {
@@ -201,7 +333,7 @@ void compute_inner_prod_with_LUT(
 
 } // anonymous namespace
 
-void AdditiveQuantizer::knn_exact_inner_product(
+void AdditiveQuantizer::knn_centroids_inner_product(
         idx_t n,
         const float* xq,
         idx_t k,
@@ -227,7 +359,7 @@ void AdditiveQuantizer::knn_exact_inner_product(
     }
 }
 
-void AdditiveQuantizer::knn_exact_L2(
+void AdditiveQuantizer::knn_centroids_L2(
         idx_t n,
         const float* xq,
         idx_t k,
@@ -267,4 +399,105 @@ void AdditiveQuantizer::knn_exact_L2(
     }
 }
 
+/****************************************************************************
+ * Support for fast distance computations in codes
+ ****************************************************************************/
+
+namespace {
+
+float accumulate_IPs(
+        const AdditiveQuantizer& aq,
+        BitstringReader& bs,
+        const uint8_t* codes,
+        const float* LUT) {
+    float accu = 0;
+    for (int m = 0; m < aq.M; m++) {
+        size_t nbit = aq.nbits[m];
+        int idx = bs.read(nbit);
+        accu += LUT[idx];
+        LUT += (uint64_t)1 << nbit;
+    }
+    return accu;
+}
+
+} // anonymous namespace
+
+template <>
+float AdditiveQuantizer::
+        compute_1_distance_LUT<true, AdditiveQuantizer::ST_LUT_nonorm>(
+                const uint8_t* codes,
+                const float* LUT) const {
+    BitstringReader bs(codes, code_size);
+    return accumulate_IPs(*this, bs, codes, LUT);
+}
+
+template <>
+float AdditiveQuantizer::
+        compute_1_distance_LUT<false, AdditiveQuantizer::ST_LUT_nonorm>(
+                const uint8_t* codes,
+                const float* LUT) const {
+    BitstringReader bs(codes, code_size);
+    return -accumulate_IPs(*this, bs, codes, LUT);
+}
+
+template <>
+float AdditiveQuantizer::
+        compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_float>(
+                const uint8_t* codes,
+                const float* LUT) const {
+    BitstringReader bs(codes, code_size);
+    float accu = accumulate_IPs(*this, bs, codes, LUT);
+    uint32_t norm_i = bs.read(32);
+    float norm2 = *(float*)&norm_i;
+    return norm2 - 2 * accu;
+}
+
+template <>
+float AdditiveQuantizer::
+        compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_cqint8>(
+                const uint8_t* codes,
+                const float* LUT) const {
+    BitstringReader bs(codes, code_size);
+    float accu = accumulate_IPs(*this, bs, codes, LUT);
+    uint32_t norm_i = bs.read(8);
+    float norm2 = decode_qcint(norm_i);
+    return norm2 - 2 * accu;
+}
+
+template <>
+float AdditiveQuantizer::
+        compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_cqint4>(
+                const uint8_t* codes,
+                const float* LUT) const {
+    BitstringReader bs(codes, code_size);
+    float accu = accumulate_IPs(*this, bs, codes, LUT);
+    uint32_t norm_i = bs.read(4);
+    float norm2 = decode_qcint(norm_i);
+    return norm2 - 2 * accu;
+}
+
+template <>
+float AdditiveQuantizer::
+        compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_qint8>(
+                const uint8_t* codes,
+                const float* LUT) const {
+    BitstringReader bs(codes, code_size);
+    float accu = accumulate_IPs(*this, bs, codes, LUT);
+    uint32_t norm_i = bs.read(8);
+    float norm2 = decode_qint8(norm_i, norm_min, norm_max);
+    return norm2 - 2 * accu;
+}
+
+template <>
+float AdditiveQuantizer::
+        compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_qint4>(
+                const uint8_t* codes,
+                const float* LUT) const {
+    BitstringReader bs(codes, code_size);
+    float accu = accumulate_IPs(*this, bs, codes, LUT);
+    uint32_t norm_i = bs.read(4);
+    float norm2 = decode_qint4(norm_i, norm_min, norm_max);
+    return norm2 - 2 * accu;
+}
+
 } // namespace faiss
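All the ST_norm_* specializations implement the same identity: for L2 search, ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>, and the query norm ||x||^2 is constant across database vectors, so it can be dropped for ranking. accumulate_IPs sums <x, y> from the per-codebook look-up tables, and the trailing bits of each code supply ||y||^2 (exact float32, or a qint/cqint approximation), hence the shared return norm2 - 2 * accu.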
data/vendor/faiss/faiss/impl/AdditiveQuantizer.h
@@ -11,6 +11,7 @@
 #include <vector>
 
 #include <faiss/Index.h>
+#include <faiss/IndexFlat.h>
 
 namespace faiss {
 
@@ -27,15 +28,44 @@ struct AdditiveQuantizer {
     std::vector<float> codebooks; ///< codebooks
 
     // derived values
-    std::vector<size_t> codebook_offsets;
+    std::vector<uint64_t> codebook_offsets;
     size_t code_size;           ///< code size in bytes
     size_t tot_bits;            ///< total number of bits
     size_t total_codebook_size; ///< size of the codebook in vectors
-    bool is_byte_aligned;
+    bool only_8bit;             ///< are all nbits = 8 (use faster decoder)
 
     bool verbose;    ///< verbose during training?
     bool is_trained; ///< is trained or not
 
+    IndexFlat1D qnorm; ///< store and search norms
+
+    uint32_t encode_qcint(
+            float x) const; ///< encode norm by non-uniform scalar quantization
+
+    float decode_qcint(uint32_t c)
+            const; ///< decode norm by non-uniform scalar quantization
+
+    /// Encodes how search is performed and how vectors are encoded
+    enum Search_type_t {
+        ST_decompress,    ///< decompress database vector
+        ST_LUT_nonorm,    ///< use a LUT, don't include norms (OK for IP or
+                          ///< normalized vectors)
+        ST_norm_from_LUT, ///< compute the norms from the look-up tables (cost
+                          ///< is in O(M^2))
+        ST_norm_float,  ///< use a LUT, and store float32 norm with the vectors
+        ST_norm_qint8,  ///< use a LUT, and store 8bit-quantized norm
+        ST_norm_qint4,
+        ST_norm_cqint8, ///< use a LUT, and store non-uniform quantized norm
+        ST_norm_cqint4,
+    };
+
+    AdditiveQuantizer(
+            size_t d,
+            const std::vector<size_t>& nbits,
+            Search_type_t search_type = ST_decompress);
+
+    AdditiveQuantizer();
+
     ///< compute derived values when d, M and nbits have been set
     void set_derived_values();
 
@@ -52,15 +82,18 @@ struct AdditiveQuantizer {
 
     /** pack a series of code to bit-compact format
      *
-     * @param codes codes to be packed, size n * code_size
+     * @param codes        codes to be packed, size n * code_size
      * @param packed_codes output bit-compact codes
-     * @param ld_codes leading dimension of codes
+     * @param ld_codes     leading dimension of codes
+     * @param norms        norms of the vectors (size n). Will be computed if
+     *                     needed but not provided
      */
     void pack_codes(
             size_t n,
             const int32_t* codes,
             uint8_t* packed_codes,
-            int64_t ld_codes = -1) const;
+            int64_t ld_codes = -1,
+            const float* norms = nullptr) const;
 
     /** Decode a set of vectors
      *
@@ -69,9 +102,36 @@ struct AdditiveQuantizer {
      */
     void decode(const uint8_t* codes, float* x, size_t n) const;
 
+    /** Decode a set of vectors in non-packed format
+     *
+     * @param codes codes to decode, size n * ld_codes
+     * @param x     output vectors, size n * d
+     */
+    void decode_unpacked(
+            const int32_t* codes,
+            float* x,
+            size_t n,
+            int64_t ld_codes = -1) const;
+
+    /****************************************************************************
+     * Search functions in an external set of codes.
+     ****************************************************************************/
+
+    /// Also determines what's in the codes
+    Search_type_t search_type;
+
+    /// min/max for quantization of norms
+    float norm_min, norm_max;
+
+    template <bool is_IP, Search_type_t effective_search_type>
+    float compute_1_distance_LUT(const uint8_t* codes, const float* LUT) const;
+
+    /*
+    float compute_1_L2sqr(const uint8_t* codes, const float* LUT);
+    */
     /****************************************************************************
-     * Support for exhaustive distance computations with the centroids.
-     * Hence, the number of elements that can be enumerated is not too large.
+     * Support for exhaustive distance computations with all the centroids.
+     * Hence, the number of these centroids should not be too large.
      ****************************************************************************/
     using idx_t = Index::idx_t;
 
@@ -87,7 +147,7 @@ struct AdditiveQuantizer {
     void compute_LUT(size_t n, const float* xq, float* LUT) const;
 
     /// exact IP search
-    void knn_exact_inner_product(
+    void knn_centroids_inner_product(
             idx_t n,
             const float* xq,
             idx_t k,
@@ -101,7 +161,7 @@ struct AdditiveQuantizer {
     void compute_centroid_norms(float* norms) const;
 
     /** Exact L2 search, with precomputed norms */
-    void knn_exact_L2(
+    void knn_centroids_L2(
             idx_t n,
             const float* xq,
             idx_t k,
data/vendor/faiss/faiss/impl/HNSW.cpp
@@ -434,17 +434,22 @@ void HNSW::add_links_starting_from(
 
     ::faiss::shrink_neighbor_list(ptdis, link_targets, M);
 
+    std::vector<storage_idx_t> neighbors;
+    neighbors.reserve(link_targets.size());
     while (!link_targets.empty()) {
-        int other_id = link_targets.top().id;
+        storage_idx_t other_id = link_targets.top().id;
+        add_link(*this, ptdis, pt_id, other_id, level);
+        neighbors.push_back(other_id);
+        link_targets.pop();
+    }
 
+    omp_unset_lock(&locks[pt_id]);
+    for (storage_idx_t other_id : neighbors) {
         omp_set_lock(&locks[other_id]);
         add_link(*this, ptdis, other_id, pt_id, level);
         omp_unset_lock(&locks[other_id]);
-
-        add_link(*this, ptdis, pt_id, other_id, level);
-
-        link_targets.pop();
     }
+    omp_set_lock(&locks[pt_id]);
 }
 
 /**************************************************************
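The restructured loop is a locking fix as much as a cleanup. Previously the thread acquired locks[other_id] while still holding locks[pt_id] (held by the caller), so two threads linking toward each other could each hold the lock the other needed. Now all forward links from pt_id are added first, the pt_id lock is released, the reverse links take one neighbor lock at a time, and the pt_id lock is re-acquired at the end for the caller; no thread ever holds two node locks at once. The change also widens other_id from int to storage_idx_t.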