faiss 0.2.3 → 0.2.4

Files changed (63)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  6. data/vendor/faiss/faiss/Clustering.h +14 -0
  7. data/vendor/faiss/faiss/Index.h +1 -1
  8. data/vendor/faiss/faiss/Index2Layer.cpp +19 -92
  9. data/vendor/faiss/faiss/Index2Layer.h +2 -16
  10. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  11. data/vendor/faiss/faiss/{IndexResidual.h → IndexAdditiveQuantizer.h} +101 -58
  12. data/vendor/faiss/faiss/IndexFlat.cpp +22 -52
  13. data/vendor/faiss/faiss/IndexFlat.h +9 -15
  14. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  15. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  16. data/vendor/faiss/faiss/IndexIVF.cpp +79 -7
  17. data/vendor/faiss/faiss/IndexIVF.h +25 -7
  18. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  19. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  20. data/vendor/faiss/faiss/IndexIVFFlat.cpp +9 -12
  21. data/vendor/faiss/faiss/IndexIVFPQ.cpp +5 -4
  22. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
  23. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +60 -39
  24. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +21 -6
  25. data/vendor/faiss/faiss/IndexLSH.cpp +4 -30
  26. data/vendor/faiss/faiss/IndexLSH.h +2 -15
  27. data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -2
  28. data/vendor/faiss/faiss/IndexNSG.cpp +0 -2
  29. data/vendor/faiss/faiss/IndexPQ.cpp +2 -51
  30. data/vendor/faiss/faiss/IndexPQ.h +2 -17
  31. data/vendor/faiss/faiss/IndexRefine.cpp +28 -0
  32. data/vendor/faiss/faiss/IndexRefine.h +10 -0
  33. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -28
  34. data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -16
  35. data/vendor/faiss/faiss/VectorTransform.cpp +2 -1
  36. data/vendor/faiss/faiss/VectorTransform.h +3 -0
  37. data/vendor/faiss/faiss/clone_index.cpp +3 -2
  38. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -2
  39. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  40. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +257 -24
  41. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +69 -9
  42. data/vendor/faiss/faiss/impl/HNSW.cpp +10 -5
  43. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +393 -210
  44. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +100 -28
  45. data/vendor/faiss/faiss/impl/NSG.cpp +0 -3
  46. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  47. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +357 -47
  48. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +65 -7
  49. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +12 -19
  50. data/vendor/faiss/faiss/impl/index_read.cpp +102 -19
  51. data/vendor/faiss/faiss/impl/index_write.cpp +66 -16
  52. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  53. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  54. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  55. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  56. data/vendor/faiss/faiss/index_factory.cpp +585 -414
  57. data/vendor/faiss/faiss/index_factory.h +3 -0
  58. data/vendor/faiss/faiss/utils/distances.cpp +4 -2
  59. data/vendor/faiss/faiss/utils/distances.h +36 -3
  60. data/vendor/faiss/faiss/utils/distances_simd.cpp +50 -0
  61. data/vendor/faiss/faiss/utils/utils.h +1 -1
  62. metadata +12 -5
  63. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
data/vendor/faiss/faiss/VectorTransform.cpp

@@ -357,6 +357,7 @@ PCAMatrix::PCAMatrix(
  is_trained = false;
  max_points_per_d = 1000;
  balanced_bins = 0;
+ epsilon = 0;
  }

  namespace {
@@ -620,7 +621,7 @@ void PCAMatrix::prepare_Ab() {
  if (eigen_power != 0) {
  float* ai = A.data();
  for (int i = 0; i < d_out; i++) {
- float factor = pow(eigenvalues[i], eigen_power);
+ float factor = pow(eigenvalues[i] + epsilon, eigen_power);
  for (int j = 0; j < d_in; j++)
  *ai++ *= factor;
  }

data/vendor/faiss/faiss/VectorTransform.h

@@ -129,6 +129,9 @@ struct PCAMatrix : LinearTransform {
  */
  float eigen_power;

+ /// value added to eigenvalues to avoid division by 0 when whitening
+ float epsilon;
+
  /// random rotation after PCA
  bool random_rotation;
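
The effect of the new epsilon field is easiest to see with whitening, where eigen_power is negative: a zero (or near-zero) eigenvalue used to send the scaling factor to infinity. A minimal standalone sketch of the arithmetic (plain C++, not faiss API; the non-zero epsilon below is only illustrative, the new field defaults to 0):

#include <cmath>
#include <cstdio>

int main() {
    float eigen_power = -0.5f;   // whitening uses a negative exponent
    float epsilon = 1e-6f;       // hypothetical value; faiss keeps the default at 0
    float eigenvalues[] = {4.0f, 1.0f, 0.0f};
    for (float ev : eigenvalues) {
        float factor_old = std::pow(ev, eigen_power);           // inf when ev == 0
        float factor_new = std::pow(ev + epsilon, eigen_power); // stays finite
        std::printf("ev=%g old=%g new=%g\n", ev, factor_old, factor_new);
    }
    return 0;
}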

data/vendor/faiss/faiss/clone_index.cpp

@@ -15,6 +15,7 @@
  #include <faiss/impl/FaissAssert.h>

  #include <faiss/Index2Layer.h>
+ #include <faiss/IndexAdditiveQuantizer.h>
  #include <faiss/IndexFlat.h>
  #include <faiss/IndexHNSW.h>
  #include <faiss/IndexIVF.h>
@@ -27,7 +28,6 @@
  #include <faiss/IndexNSG.h>
  #include <faiss/IndexPQ.h>
  #include <faiss/IndexPreTransform.h>
- #include <faiss/IndexResidual.h>
  #include <faiss/IndexScalarQuantizer.h>
  #include <faiss/MetaIndexes.h>
  #include <faiss/VectorTransform.h>
@@ -80,9 +80,10 @@ Index* Cloner::clone_Index(const Index* index) {
  TRYCLONE(IndexFlatIP, index)
  TRYCLONE(IndexFlat, index)
  TRYCLONE(IndexLattice, index)
- TRYCLONE(IndexResidual, index)
+ TRYCLONE(IndexResidualQuantizer, index)
  TRYCLONE(IndexScalarQuantizer, index)
  TRYCLONE(MultiIndexQuantizer, index)
+ TRYCLONE(ResidualCoarseQuantizer, index)
  if (const IndexIVF* ivf = dynamic_cast<const IndexIVF*>(index)) {
  IndexIVF* res = clone_IndexIVF(ivf);
  if (ivf->invlists == nullptr) {

data/vendor/faiss/faiss/gpu/GpuCloner.cpp

@@ -40,7 +40,7 @@ void ToCPUCloner::merge_index(Index* dst, Index* src, bool successive_ids) {
  auto ifl2 = dynamic_cast<const IndexFlat*>(src);
  FAISS_ASSERT(ifl2);
  FAISS_ASSERT(successive_ids);
- ifl->add(ifl2->ntotal, ifl2->xb.data());
+ ifl->add(ifl2->ntotal, ifl2->get_xb());
  } else if (auto ifl = dynamic_cast<IndexIVFFlat*>(dst)) {
  auto ifl2 = dynamic_cast<IndexIVFFlat*>(src);
  FAISS_ASSERT(ifl2);
@@ -329,7 +329,7 @@ Index* ToGpuClonerMultiple::clone_Index_to_shards(const Index* index) {
  if (index->ntotal > 0) {
  long i0 = index->ntotal * i / n;
  long i1 = index->ntotal * (i + 1) / n;
- shards[i]->add(i1 - i0, index_flat->xb.data() + i0 * index->d);
+ shards[i]->add(i1 - i0, index_flat->get_xb() + i0 * index->d);
  }
  }
  }
data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h (new file)

@@ -0,0 +1,60 @@
+ /**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+ #pragma once
+
+ #include <faiss/impl/LocalSearchQuantizer.h>
+
+ #include <memory>
+
+ namespace faiss {
+ namespace gpu {
+
+ class GpuResourcesProvider;
+ struct IcmEncoderShards;
+
+ /** Perform LSQ encoding on GPU.
+ *
+ * Split input vectors to different devices and call IcmEncoderImpl::encode
+ * to encode them
+ */
+ class GpuIcmEncoder : public lsq::IcmEncoder {
+ public:
+ GpuIcmEncoder(
+ const LocalSearchQuantizer* lsq,
+ const std::vector<GpuResourcesProvider*>& provs,
+ const std::vector<int>& devices);
+
+ ~GpuIcmEncoder();
+
+ GpuIcmEncoder(const GpuIcmEncoder&) = delete;
+ GpuIcmEncoder& operator=(const GpuIcmEncoder&) = delete;
+
+ void set_binary_term() override;
+
+ void encode(
+ int32_t* codes,
+ const float* x,
+ std::mt19937& gen,
+ size_t n,
+ size_t ils_iters) const override;
+
+ private:
+ std::unique_ptr<IcmEncoderShards> shards;
+ };
+
+ struct GpuIcmEncoderFactory : public lsq::IcmEncoderFactory {
+ explicit GpuIcmEncoderFactory(int ngpus = 1);
+
+ lsq::IcmEncoder* get(const LocalSearchQuantizer* lsq) override;
+
+ std::vector<GpuResourcesProvider*> provs;
+ std::vector<int> devices;
+ };
+
+ } // namespace gpu
+ } // namespace faiss
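
The header above only declares the GPU encoder; the intended wiring is through the icm_encoder_factory hook that this release adds to LocalSearchQuantizer, so the iterated conditional modes (ICM) encoding step of LSQ runs on one or more GPUs. A hedged usage sketch (it assumes the icm_encoder_factory member declared in impl/LocalSearchQuantizer.h of this version; this call pattern is not shown in the diff itself):

#include <faiss/gpu/GpuIcmEncoder.h>
#include <faiss/impl/LocalSearchQuantizer.h>

void use_gpu_icm_encoder(faiss::LocalSearchQuantizer& lsq) {
    // Route LSQ's ICM encoding to 2 GPUs; the factory builds a GpuIcmEncoder
    // that shards the input vectors across the devices.
    lsq.icm_encoder_factory = new faiss::gpu::GpuIcmEncoderFactory(/*ngpus=*/2);
}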

data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp

@@ -8,7 +8,6 @@
  // -*- c++ -*-

  #include <faiss/impl/AdditiveQuantizer.h>
- #include <faiss/impl/FaissAssert.h>

  #include <cstddef>
  #include <cstdio>
@@ -18,9 +17,10 @@

  #include <algorithm>

+ #include <faiss/impl/FaissAssert.h>
  #include <faiss/utils/Heap.h>
  #include <faiss/utils/distances.h>
- #include <faiss/utils/hamming.h> // BitstringWriter
+ #include <faiss/utils/hamming.h>
  #include <faiss/utils/utils.h>

  extern "C" {
@@ -42,51 +42,125 @@ int sgemm_(
  FINTEGER* ldc);
  }

- namespace {
-
- // c and a and b can overlap
- void fvec_add(size_t d, const float* a, const float* b, float* c) {
- for (size_t i = 0; i < d; i++) {
- c[i] = a[i] + b[i];
- }
- }
+ namespace faiss {

- void fvec_add(size_t d, const float* a, float b, float* c) {
- for (size_t i = 0; i < d; i++) {
- c[i] = a[i] + b;
- }
+ AdditiveQuantizer::AdditiveQuantizer(
+ size_t d,
+ const std::vector<size_t>& nbits,
+ Search_type_t search_type)
+ : d(d),
+ M(nbits.size()),
+ nbits(nbits),
+ verbose(false),
+ is_trained(false),
+ search_type(search_type) {
+ norm_max = norm_min = NAN;
+ code_size = 0;
+ tot_bits = 0;
+ total_codebook_size = 0;
+ only_8bit = false;
+ set_derived_values();
  }

- } // namespace
-
- namespace faiss {
+ AdditiveQuantizer::AdditiveQuantizer()
+ : AdditiveQuantizer(0, std::vector<size_t>()) {}

  void AdditiveQuantizer::set_derived_values() {
  tot_bits = 0;
- is_byte_aligned = true;
+ only_8bit = true;
  codebook_offsets.resize(M + 1, 0);
  for (int i = 0; i < M; i++) {
  int nbit = nbits[i];
  size_t k = 1 << nbit;
  codebook_offsets[i + 1] = codebook_offsets[i] + k;
  tot_bits += nbit;
- if (nbit % 8 != 0) {
- is_byte_aligned = false;
+ if (nbit != 0) {
+ only_8bit = false;
  }
  }
  total_codebook_size = codebook_offsets[M];
+ switch (search_type) {
+ case ST_decompress:
+ case ST_LUT_nonorm:
+ case ST_norm_from_LUT:
+ break; // nothing to add
+ case ST_norm_float:
+ tot_bits += 32;
+ break;
+ case ST_norm_qint8:
+ case ST_norm_cqint8:
+ tot_bits += 8;
+ break;
+ case ST_norm_qint4:
+ case ST_norm_cqint4:
+ tot_bits += 4;
+ break;
+ }
+
  // convert bits to bytes
  code_size = (tot_bits + 7) / 8;
  }

+ namespace {
+
+ // TODO
+ // https://stackoverflow.com/questions/31631224/hacks-for-clamping-integer-to-0-255-and-doubles-to-0-0-1-0
+
+ uint8_t encode_qint8(float x, float amin, float amax) {
+ float x1 = (x - amin) / (amax - amin) * 256;
+ int32_t xi = int32_t(floor(x1));
+
+ return xi < 0 ? 0 : xi > 255 ? 255 : xi;
+ }
+
+ uint8_t encode_qint4(float x, float amin, float amax) {
+ float x1 = (x - amin) / (amax - amin) * 16;
+ int32_t xi = int32_t(floor(x1));
+
+ return xi < 0 ? 0 : xi > 15 ? 15 : xi;
+ }
+
+ float decode_qint8(uint8_t i, float amin, float amax) {
+ return (i + 0.5) / 256 * (amax - amin) + amin;
+ }
+
+ float decode_qint4(uint8_t i, float amin, float amax) {
+ return (i + 0.5) / 16 * (amax - amin) + amin;
+ }
+
+ } // anonymous namespace
+
+ uint32_t AdditiveQuantizer::encode_qcint(float x) const {
+ idx_t id;
+ qnorm.assign(idx_t(1), &x, &id, idx_t(1));
+ return uint32_t(id);
+ }
+
+ float AdditiveQuantizer::decode_qcint(uint32_t c) const {
+ return qnorm.get_xb()[c];
+ }
+
  void AdditiveQuantizer::pack_codes(
  size_t n,
  const int32_t* codes,
  uint8_t* packed_codes,
- int64_t ld_codes) const {
+ int64_t ld_codes,
+ const float* norms) const {
  if (ld_codes == -1) {
  ld_codes = M;
  }
+ std::vector<float> norm_buf;
+ if (search_type == ST_norm_float || search_type == ST_norm_qint4 ||
+ search_type == ST_norm_qint8 || search_type == ST_norm_cqint8 ||
+ search_type == ST_norm_cqint4) {
+ if (!norms) {
+ norm_buf.resize(n);
+ std::vector<float> x_recons(n * d);
+ decode_unpacked(codes, x_recons.data(), n, ld_codes);
+ fvec_norms_L2sqr(norm_buf.data(), x_recons.data(), d, n);
+ norms = norm_buf.data();
+ }
+ }
  #pragma omp parallel for if (n > 1000)
  for (int64_t i = 0; i < n; i++) {
  const int32_t* codes1 = codes + i * ld_codes;
@@ -94,6 +168,35 @@ void AdditiveQuantizer::pack_codes(
  for (int m = 0; m < M; m++) {
  bsw.write(codes1[m], nbits[m]);
  }
+ switch (search_type) {
+ case ST_decompress:
+ case ST_LUT_nonorm:
+ case ST_norm_from_LUT:
+ break;
+ case ST_norm_float:
+ bsw.write(*(uint32_t*)&norms[i], 32);
+ break;
+ case ST_norm_qint8: {
+ uint8_t b = encode_qint8(norms[i], norm_min, norm_max);
+ bsw.write(b, 8);
+ break;
+ }
+ case ST_norm_qint4: {
+ uint8_t b = encode_qint4(norms[i], norm_min, norm_max);
+ bsw.write(b, 4);
+ break;
+ }
+ case ST_norm_cqint8: {
+ uint32_t b = encode_qcint(norms[i]);
+ bsw.write(b, 8);
+ break;
+ }
+ case ST_norm_cqint4: {
+ uint32_t b = encode_qcint(norms[i]);
+ bsw.write(b, 4);
+ break;
+ }
+ }
  }
  }
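
The 8-bit uniform norm quantizer used above maps a norm into [norm_min, norm_max] and reconstructs at the bin center. A standalone roundtrip of that arithmetic (the helpers are copied from the anonymous namespace above so the snippet compiles on its own; the min/max/norm values are made up):

#include <cmath>
#include <cstdint>
#include <cstdio>

static uint8_t encode_qint8(float x, float amin, float amax) {
    float x1 = (x - amin) / (amax - amin) * 256;
    int32_t xi = int32_t(std::floor(x1));
    return xi < 0 ? 0 : xi > 255 ? 255 : xi; // clamp to [0, 255]
}

static float decode_qint8(uint8_t i, float amin, float amax) {
    return (i + 0.5f) / 256 * (amax - amin) + amin; // bin center
}

int main() {
    float norm_min = 0.0f, norm_max = 100.0f; // bounds would come from training; made up here
    float norm = 42.3f;
    uint8_t code = encode_qint8(norm, norm_min, norm_max); // -> 108
    float back = decode_qint8(code, norm_min, norm_max);   // -> ~42.4
    std::printf("norm=%.2f code=%u decoded=%.2f\n", norm, code, back);
    return 0;
}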

@@ -118,10 +221,39 @@ void AdditiveQuantizer::decode(const uint8_t* code, float* x, size_t n) const {
  }
  }

+ void AdditiveQuantizer::decode_unpacked(
+ const int32_t* code,
+ float* x,
+ size_t n,
+ int64_t ld_codes) const {
+ FAISS_THROW_IF_NOT_MSG(
+ is_trained, "The additive quantizer is not trained yet.");
+
+ if (ld_codes == -1) {
+ ld_codes = M;
+ }
+
+ // standard additive quantizer decoding
+ #pragma omp parallel for if (n > 1000)
+ for (int64_t i = 0; i < n; i++) {
+ const int32_t* codesi = code + i * ld_codes;
+ float* xi = x + i * d;
+ for (int m = 0; m < M; m++) {
+ int idx = codesi[m];
+ const float* c = codebooks.data() + d * (codebook_offsets[m] + idx);
+ if (m == 0) {
+ memcpy(xi, c, sizeof(*x) * d);
+ } else {
+ fvec_add(d, xi, c, xi);
+ }
+ }
+ }
+ }
+
  AdditiveQuantizer::~AdditiveQuantizer() {}

  /****************************************************************************
- * Support for fast distance computations and search with additive quantizer
+ * Support for fast distance computations in centroids
  ****************************************************************************/

  void AdditiveQuantizer::compute_centroid_norms(float* norms) const {
@@ -201,7 +333,7 @@ void compute_inner_prod_with_LUT(

  } // anonymous namespace

- void AdditiveQuantizer::knn_exact_inner_product(
+ void AdditiveQuantizer::knn_centroids_inner_product(
  idx_t n,
  const float* xq,
  idx_t k,
@@ -227,7 +359,7 @@ void AdditiveQuantizer::knn_exact_inner_product(
  }
  }

- void AdditiveQuantizer::knn_exact_L2(
+ void AdditiveQuantizer::knn_centroids_L2(
  idx_t n,
  const float* xq,
  idx_t k,
@@ -267,4 +399,105 @@ void AdditiveQuantizer::knn_exact_L2(
  }
  }

+ /****************************************************************************
+ * Support for fast distance computations in codes
+ ****************************************************************************/
+
+ namespace {
+
+ float accumulate_IPs(
+ const AdditiveQuantizer& aq,
+ BitstringReader& bs,
+ const uint8_t* codes,
+ const float* LUT) {
+ float accu = 0;
+ for (int m = 0; m < aq.M; m++) {
+ size_t nbit = aq.nbits[m];
+ int idx = bs.read(nbit);
+ accu += LUT[idx];
+ LUT += (uint64_t)1 << nbit;
+ }
+ return accu;
+ }
+
+ } // anonymous namespace
+
+ template <>
+ float AdditiveQuantizer::
+ compute_1_distance_LUT<true, AdditiveQuantizer::ST_LUT_nonorm>(
+ const uint8_t* codes,
+ const float* LUT) const {
+ BitstringReader bs(codes, code_size);
+ return accumulate_IPs(*this, bs, codes, LUT);
+ }
+
+ template <>
+ float AdditiveQuantizer::
+ compute_1_distance_LUT<false, AdditiveQuantizer::ST_LUT_nonorm>(
+ const uint8_t* codes,
+ const float* LUT) const {
+ BitstringReader bs(codes, code_size);
+ return -accumulate_IPs(*this, bs, codes, LUT);
+ }
+
+ template <>
+ float AdditiveQuantizer::
+ compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_float>(
+ const uint8_t* codes,
+ const float* LUT) const {
+ BitstringReader bs(codes, code_size);
+ float accu = accumulate_IPs(*this, bs, codes, LUT);
+ uint32_t norm_i = bs.read(32);
+ float norm2 = *(float*)&norm_i;
+ return norm2 - 2 * accu;
+ }
+
+ template <>
+ float AdditiveQuantizer::
+ compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_cqint8>(
+ const uint8_t* codes,
+ const float* LUT) const {
+ BitstringReader bs(codes, code_size);
+ float accu = accumulate_IPs(*this, bs, codes, LUT);
+ uint32_t norm_i = bs.read(8);
+ float norm2 = decode_qcint(norm_i);
+ return norm2 - 2 * accu;
+ }
+
+ template <>
+ float AdditiveQuantizer::
+ compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_cqint4>(
+ const uint8_t* codes,
+ const float* LUT) const {
+ BitstringReader bs(codes, code_size);
+ float accu = accumulate_IPs(*this, bs, codes, LUT);
+ uint32_t norm_i = bs.read(4);
+ float norm2 = decode_qcint(norm_i);
+ return norm2 - 2 * accu;
+ }
+
+ template <>
+ float AdditiveQuantizer::
+ compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_qint8>(
+ const uint8_t* codes,
+ const float* LUT) const {
+ BitstringReader bs(codes, code_size);
+ float accu = accumulate_IPs(*this, bs, codes, LUT);
+ uint32_t norm_i = bs.read(8);
+ float norm2 = decode_qint8(norm_i, norm_min, norm_max);
+ return norm2 - 2 * accu;
+ }
+
+ template <>
+ float AdditiveQuantizer::
+ compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_qint4>(
+ const uint8_t* codes,
+ const float* LUT) const {
+ BitstringReader bs(codes, code_size);
+ float accu = accumulate_IPs(*this, bs, codes, LUT);
+ uint32_t norm_i = bs.read(4);
+ float norm2 = decode_qint4(norm_i, norm_min, norm_max);
+ return norm2 - 2 * accu;
+ }
+
  } // namespace faiss
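
What the new specializations compute, in one equation: accumulate_IPs sums one look-up-table entry per codebook, which (assuming compute_LUT fills the tables with inner products between the query x and the codebook entries, as the compute_inner_prod_with_LUT helper above suggests) yields the inner product <x, y> for an encoded vector y. The L2 variants then rely on the decomposition

    ||x - y||^2 = ||x||^2 - 2 * <x, y> + ||y||^2

where ||y||^2 is read back from the norm stored with the codes (float32 or quantized) and the query-only term ||x||^2 is dropped because it does not affect the ranking, which is why every L2 specialization returns norm2 - 2 * accu.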

data/vendor/faiss/faiss/impl/AdditiveQuantizer.h

@@ -11,6 +11,7 @@
  #include <vector>

  #include <faiss/Index.h>
+ #include <faiss/IndexFlat.h>

  namespace faiss {

@@ -27,15 +28,44 @@ struct AdditiveQuantizer {
  std::vector<float> codebooks; ///< codebooks

  // derived values
- std::vector<size_t> codebook_offsets;
+ std::vector<uint64_t> codebook_offsets;
  size_t code_size; ///< code size in bytes
  size_t tot_bits; ///< total number of bits
  size_t total_codebook_size; ///< size of the codebook in vectors
- bool is_byte_aligned;
+ bool only_8bit; ///< are all nbits = 8 (use faster decoder)

  bool verbose; ///< verbose during training?
  bool is_trained; ///< is trained or not

+ IndexFlat1D qnorm; ///< store and search norms
+
+ uint32_t encode_qcint(
+ float x) const; ///< encode norm by non-uniform scalar quantization
+
+ float decode_qcint(uint32_t c)
+ const; ///< decode norm by non-uniform scalar quantization
+
+ /// Encodes how search is performed and how vectors are encoded
+ enum Search_type_t {
+ ST_decompress, ///< decompress database vector
+ ST_LUT_nonorm, ///< use a LUT, don't include norms (OK for IP or
+ ///< normalized vectors)
+ ST_norm_from_LUT, ///< compute the norms from the look-up tables (cost
+ ///< is in O(M^2))
+ ST_norm_float, ///< use a LUT, and store float32 norm with the vectors
+ ST_norm_qint8, ///< use a LUT, and store 8bit-quantized norm
+ ST_norm_qint4,
+ ST_norm_cqint8, ///< use a LUT, and store non-uniform quantized norm
+ ST_norm_cqint4,
+ };
+
+ AdditiveQuantizer(
+ size_t d,
+ const std::vector<size_t>& nbits,
+ Search_type_t search_type = ST_decompress);
+
+ AdditiveQuantizer();
+
  ///< compute derived values when d, M and nbits have been set
  void set_derived_values();
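
A quick illustration of how the new search types feed into code_size via set_derived_values (the four 8-bit codebooks below are an example configuration, not taken from the diff):

#include <cstdio>
#include <vector>

int main() {
    std::vector<size_t> nbits = {8, 8, 8, 8}; // M = 4 codebooks of 8 bits each
    size_t tot_bits = 0;
    for (size_t b : nbits) tot_bits += b;     // 32 bits of additive codes
    tot_bits += 8;                            // ST_norm_qint8 appends an 8-bit norm
    size_t code_size = (tot_bits + 7) / 8;    // rounded up to bytes -> 5
    std::printf("code_size = %zu bytes\n", code_size);
    return 0;
}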

@@ -52,15 +82,18 @@ struct AdditiveQuantizer {

  /** pack a series of code to bit-compact format
  *
- * @param codes codes to be packed, size n * code_size
+ * @param codes codes to be packed, size n * code_size
  * @param packed_codes output bit-compact codes
- * @param ld_codes leading dimension of codes
+ * @param ld_codes leading dimension of codes
+ * @param norms norms of the vectors (size n). Will be computed if
+ * needed but not provided
  */
  void pack_codes(
  size_t n,
  const int32_t* codes,
  uint8_t* packed_codes,
- int64_t ld_codes = -1) const;
+ int64_t ld_codes = -1,
+ const float* norms = nullptr) const;

  /** Decode a set of vectors
  *
@@ -69,9 +102,36 @@ struct AdditiveQuantizer {
  */
  void decode(const uint8_t* codes, float* x, size_t n) const;

+ /** Decode a set of vectors in non-packed format
+ *
+ * @param codes codes to decode, size n * ld_codes
+ * @param x output vectors, size n * d
+ */
+ void decode_unpacked(
+ const int32_t* codes,
+ float* x,
+ size_t n,
+ int64_t ld_codes = -1) const;
+
+ /****************************************************************************
+ * Search functions in an external set of codes.
+ ****************************************************************************/
+
+ /// Also determines what's in the codes
+ Search_type_t search_type;
+
+ /// min/max for quantization of norms
+ float norm_min, norm_max;
+
+ template <bool is_IP, Search_type_t effective_search_type>
+ float compute_1_distance_LUT(const uint8_t* codes, const float* LUT) const;
+
+ /*
+ float compute_1_L2sqr(const uint8_t* codes, const float* LUT);
+ */
  /****************************************************************************
- * Support for exhaustive distance computations with the centroids.
- * Hence, the number of elements that can be enumerated is not too large.
+ * Support for exhaustive distance computations with all the centroids.
+ * Hence, the number of these centroids should not be too large.
  ****************************************************************************/
  using idx_t = Index::idx_t;

@@ -87,7 +147,7 @@ struct AdditiveQuantizer {
  void compute_LUT(size_t n, const float* xq, float* LUT) const;

  /// exact IP search
- void knn_exact_inner_product(
+ void knn_centroids_inner_product(
  idx_t n,
  const float* xq,
  idx_t k,
@@ -101,7 +161,7 @@ struct AdditiveQuantizer {
  void compute_centroid_norms(float* norms) const;

  /** Exact L2 search, with precomputed norms */
- void knn_exact_L2(
+ void knn_centroids_L2(
  idx_t n,
  const float* xq,
  idx_t k,
data/vendor/faiss/faiss/impl/HNSW.cpp

@@ -434,17 +434,22 @@ void HNSW::add_links_starting_from(

  ::faiss::shrink_neighbor_list(ptdis, link_targets, M);

+ std::vector<storage_idx_t> neighbors;
+ neighbors.reserve(link_targets.size());
  while (!link_targets.empty()) {
- int other_id = link_targets.top().id;
+ storage_idx_t other_id = link_targets.top().id;
+ add_link(*this, ptdis, pt_id, other_id, level);
+ neighbors.push_back(other_id);
+ link_targets.pop();
+ }

+ omp_unset_lock(&locks[pt_id]);
+ for (storage_idx_t other_id : neighbors) {
  omp_set_lock(&locks[other_id]);
  add_link(*this, ptdis, other_id, pt_id, level);
  omp_unset_lock(&locks[other_id]);
-
- add_link(*this, ptdis, pt_id, other_id, level);
-
- link_targets.pop();
  }
+ omp_set_lock(&locks[pt_id]);
  }

  /**************************************************************