faiss 0.3.0 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (171) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
@@ -5,9 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // quiet the noise
9
- // clang-format off
10
-
11
8
  #include <faiss/IndexAdditiveQuantizer.h>
12
9
 
13
10
  #include <algorithm>
@@ -21,7 +18,6 @@
21
18
  #include <faiss/utils/extra_distances.h>
22
19
  #include <faiss/utils/utils.h>
23
20
 
24
-
25
21
  namespace faiss {
26
22
 
27
23
  /**************************************************************************************
@@ -29,15 +25,13 @@ namespace faiss {
29
25
  **************************************************************************************/
30
26
 
31
27
  IndexAdditiveQuantizer::IndexAdditiveQuantizer(
32
- idx_t d,
33
- AdditiveQuantizer* aq,
34
- MetricType metric):
35
- IndexFlatCodes(aq->code_size, d, metric), aq(aq)
36
- {
28
+ idx_t d,
29
+ AdditiveQuantizer* aq,
30
+ MetricType metric)
31
+ : IndexFlatCodes(aq->code_size, d, metric), aq(aq) {
37
32
  FAISS_THROW_IF_NOT(metric == METRIC_INNER_PRODUCT || metric == METRIC_L2);
38
33
  }
39
34
 
40
-
41
35
  namespace {
42
36
 
43
37
  /************************************************************
@@ -45,21 +39,22 @@ namespace {
45
39
  ************************************************************/
46
40
 
47
41
  template <class VectorDistance>
48
- struct AQDistanceComputerDecompress: FlatCodesDistanceComputer {
42
+ struct AQDistanceComputerDecompress : FlatCodesDistanceComputer {
49
43
  std::vector<float> tmp;
50
- const AdditiveQuantizer & aq;
44
+ const AdditiveQuantizer& aq;
51
45
  VectorDistance vd;
52
46
  size_t d;
53
47
 
54
- AQDistanceComputerDecompress(const IndexAdditiveQuantizer &iaq, VectorDistance vd):
55
- FlatCodesDistanceComputer(iaq.codes.data(), iaq.code_size),
56
- tmp(iaq.d * 2),
57
- aq(*iaq.aq),
58
- vd(vd),
59
- d(iaq.d)
60
- {}
48
+ AQDistanceComputerDecompress(
49
+ const IndexAdditiveQuantizer& iaq,
50
+ VectorDistance vd)
51
+ : FlatCodesDistanceComputer(iaq.codes.data(), iaq.code_size),
52
+ tmp(iaq.d * 2),
53
+ aq(*iaq.aq),
54
+ vd(vd),
55
+ d(iaq.d) {}
61
56
 
62
- const float *q;
57
+ const float* q;
63
58
  void set_query(const float* x) final {
64
59
  q = x;
65
60
  }
@@ -70,27 +65,25 @@ struct AQDistanceComputerDecompress: FlatCodesDistanceComputer {
70
65
  return vd(tmp.data(), tmp.data() + d);
71
66
  }
72
67
 
73
- float distance_to_code(const uint8_t *code) final {
68
+ float distance_to_code(const uint8_t* code) final {
74
69
  aq.decode(code, tmp.data(), 1);
75
70
  return vd(q, tmp.data());
76
71
  }
77
72
 
78
- virtual ~AQDistanceComputerDecompress() {}
73
+ virtual ~AQDistanceComputerDecompress() = default;
79
74
  };
80
75
 
81
-
82
- template<bool is_IP, AdditiveQuantizer::Search_type_t st>
83
- struct AQDistanceComputerLUT: FlatCodesDistanceComputer {
76
+ template <bool is_IP, AdditiveQuantizer::Search_type_t st>
77
+ struct AQDistanceComputerLUT : FlatCodesDistanceComputer {
84
78
  std::vector<float> LUT;
85
- const AdditiveQuantizer & aq;
79
+ const AdditiveQuantizer& aq;
86
80
  size_t d;
87
81
 
88
- explicit AQDistanceComputerLUT(const IndexAdditiveQuantizer &iaq):
89
- FlatCodesDistanceComputer(iaq.codes.data(), iaq.code_size),
90
- LUT(iaq.aq->total_codebook_size + iaq.d * 2),
91
- aq(*iaq.aq),
92
- d(iaq.d)
93
- {}
82
+ explicit AQDistanceComputerLUT(const IndexAdditiveQuantizer& iaq)
83
+ : FlatCodesDistanceComputer(iaq.codes.data(), iaq.code_size),
84
+ LUT(iaq.aq->total_codebook_size + iaq.d * 2),
85
+ aq(*iaq.aq),
86
+ d(iaq.d) {}
94
87
 
95
88
  float bias;
96
89
  void set_query(const float* x) final {
@@ -104,40 +97,38 @@ struct AQDistanceComputerLUT: FlatCodesDistanceComputer {
104
97
  }
105
98
 
106
99
  float symmetric_dis(idx_t i, idx_t j) final {
107
- float *tmp = LUT.data();
100
+ float* tmp = LUT.data();
108
101
  aq.decode(codes + i * d, tmp, 1);
109
102
  aq.decode(codes + j * d, tmp + d, 1);
110
103
  return fvec_L2sqr(tmp, tmp + d, d);
111
104
  }
112
105
 
113
- float distance_to_code(const uint8_t *code) final {
106
+ float distance_to_code(const uint8_t* code) final {
114
107
  return bias + aq.compute_1_distance_LUT<is_IP, st>(code, LUT.data());
115
108
  }
116
109
 
117
- virtual ~AQDistanceComputerLUT() {}
110
+ virtual ~AQDistanceComputerLUT() = default;
118
111
  };
119
112
 
120
-
121
-
122
113
  /************************************************************
123
114
  * scanning implementation for search
124
115
  ************************************************************/
125
116
 
126
-
127
- template <class VectorDistance, class ResultHandler>
117
+ template <class VectorDistance, class BlockResultHandler>
128
118
  void search_with_decompress(
129
119
  const IndexAdditiveQuantizer& ir,
130
120
  const float* xq,
131
121
  VectorDistance& vd,
132
- ResultHandler& res) {
122
+ BlockResultHandler& res) {
133
123
  const uint8_t* codes = ir.codes.data();
134
124
  size_t ntotal = ir.ntotal;
135
125
  size_t code_size = ir.code_size;
136
- const AdditiveQuantizer *aq = ir.aq;
126
+ const AdditiveQuantizer* aq = ir.aq;
137
127
 
138
- using SingleResultHandler = typename ResultHandler::SingleResultHandler;
128
+ using SingleResultHandler =
129
+ typename BlockResultHandler::SingleResultHandler;
139
130
 
140
- #pragma omp parallel for if(res.nq > 100)
131
+ #pragma omp parallel for if (res.nq > 100)
141
132
  for (int64_t q = 0; q < res.nq; q++) {
142
133
  SingleResultHandler resi(res);
143
134
  resi.begin(q);
@@ -152,52 +143,51 @@ void search_with_decompress(
152
143
  }
153
144
  }
154
145
 
155
- template<bool is_IP, AdditiveQuantizer::Search_type_t st, class ResultHandler>
146
+ template <
147
+ bool is_IP,
148
+ AdditiveQuantizer::Search_type_t st,
149
+ class BlockResultHandler>
156
150
  void search_with_LUT(
157
151
  const IndexAdditiveQuantizer& ir,
158
152
  const float* xq,
159
- ResultHandler& res)
160
- {
161
- const AdditiveQuantizer & aq = *ir.aq;
153
+ BlockResultHandler& res) {
154
+ const AdditiveQuantizer& aq = *ir.aq;
162
155
  const uint8_t* codes = ir.codes.data();
163
156
  size_t ntotal = ir.ntotal;
164
157
  size_t code_size = aq.code_size;
165
158
  size_t nq = res.nq;
166
159
  size_t d = ir.d;
167
160
 
168
- using SingleResultHandler = typename ResultHandler::SingleResultHandler;
169
- std::unique_ptr<float []> LUT(new float[nq * aq.total_codebook_size]);
161
+ using SingleResultHandler =
162
+ typename BlockResultHandler::SingleResultHandler;
163
+ std::unique_ptr<float[]> LUT(new float[nq * aq.total_codebook_size]);
170
164
 
171
165
  aq.compute_LUT(nq, xq, LUT.get());
172
166
 
173
- #pragma omp parallel for if(nq > 100)
167
+ #pragma omp parallel for if (nq > 100)
174
168
  for (int64_t q = 0; q < nq; q++) {
175
169
  SingleResultHandler resi(res);
176
170
  resi.begin(q);
177
171
  std::vector<float> tmp(aq.d);
178
- const float *LUT_q = LUT.get() + aq.total_codebook_size * q;
172
+ const float* LUT_q = LUT.get() + aq.total_codebook_size * q;
179
173
  float bias = 0;
180
- if (!is_IP) { // the LUT function returns ||y||^2 - 2 * <x, y>, need to add ||x||^2
174
+ if (!is_IP) { // the LUT function returns ||y||^2 - 2 * <x, y>, need to
175
+ // add ||x||^2
181
176
  bias = fvec_norm_L2sqr(xq + q * d, d);
182
177
  }
183
178
  for (size_t i = 0; i < ntotal; i++) {
184
179
  float dis = aq.compute_1_distance_LUT<is_IP, st>(
185
- codes + i * code_size,
186
- LUT_q
187
- );
180
+ codes + i * code_size, LUT_q);
188
181
  resi.add_result(dis + bias, i);
189
182
  }
190
183
  resi.end();
191
184
  }
192
-
193
185
  }
194
186
 
195
-
196
187
  } // anonymous namespace
197
188
 
198
-
199
- FlatCodesDistanceComputer * IndexAdditiveQuantizer::get_FlatCodesDistanceComputer() const {
200
-
189
+ FlatCodesDistanceComputer* IndexAdditiveQuantizer::
190
+ get_FlatCodesDistanceComputer() const {
201
191
  if (aq->search_type == AdditiveQuantizer::ST_decompress) {
202
192
  if (metric_type == METRIC_L2) {
203
193
  using VD = VectorDistance<METRIC_L2>;
@@ -212,34 +202,36 @@ FlatCodesDistanceComputer * IndexAdditiveQuantizer::get_FlatCodesDistanceCompute
212
202
  }
213
203
  } else {
214
204
  if (metric_type == METRIC_INNER_PRODUCT) {
215
- return new AQDistanceComputerLUT<true, AdditiveQuantizer::ST_LUT_nonorm>(*this);
205
+ return new AQDistanceComputerLUT<
206
+ true,
207
+ AdditiveQuantizer::ST_LUT_nonorm>(*this);
216
208
  } else {
217
- switch(aq->search_type) {
218
- #define DISPATCH(st) \
219
- case AdditiveQuantizer::st: \
220
- return new AQDistanceComputerLUT<false, AdditiveQuantizer::st> (*this);\
221
- break;
222
- DISPATCH(ST_norm_float)
223
- DISPATCH(ST_LUT_nonorm)
224
- DISPATCH(ST_norm_qint8)
225
- DISPATCH(ST_norm_qint4)
226
- DISPATCH(ST_norm_cqint4)
227
- case AdditiveQuantizer::ST_norm_cqint8:
228
- case AdditiveQuantizer::ST_norm_lsq2x4:
229
- case AdditiveQuantizer::ST_norm_rq2x4:
230
- return new AQDistanceComputerLUT<false, AdditiveQuantizer::ST_norm_cqint8> (*this);\
231
- break;
209
+ switch (aq->search_type) {
210
+ #define DISPATCH(st) \
211
+ case AdditiveQuantizer::st: \
212
+ return new AQDistanceComputerLUT<false, AdditiveQuantizer::st>(*this); \
213
+ break;
214
+ DISPATCH(ST_norm_float)
215
+ DISPATCH(ST_LUT_nonorm)
216
+ DISPATCH(ST_norm_qint8)
217
+ DISPATCH(ST_norm_qint4)
218
+ DISPATCH(ST_norm_cqint4)
219
+ case AdditiveQuantizer::ST_norm_cqint8:
220
+ case AdditiveQuantizer::ST_norm_lsq2x4:
221
+ case AdditiveQuantizer::ST_norm_rq2x4:
222
+ return new AQDistanceComputerLUT<
223
+ false,
224
+ AdditiveQuantizer::ST_norm_cqint8>(*this);
225
+ break;
232
226
  #undef DISPATCH
233
- default:
234
- FAISS_THROW_FMT("search type %d not supported", aq->search_type);
227
+ default:
228
+ FAISS_THROW_FMT(
229
+ "search type %d not supported", aq->search_type);
235
230
  }
236
231
  }
237
232
  }
238
233
  }
239
234
 
240
-
241
-
242
-
243
235
  void IndexAdditiveQuantizer::search(
244
236
  idx_t n,
245
237
  const float* x,
@@ -247,62 +239,65 @@ void IndexAdditiveQuantizer::search(
247
239
  float* distances,
248
240
  idx_t* labels,
249
241
  const SearchParameters* params) const {
250
-
251
- FAISS_THROW_IF_NOT_MSG(!params, "search params not supported for this index");
242
+ FAISS_THROW_IF_NOT_MSG(
243
+ !params, "search params not supported for this index");
252
244
 
253
245
  if (aq->search_type == AdditiveQuantizer::ST_decompress) {
254
246
  if (metric_type == METRIC_L2) {
255
247
  using VD = VectorDistance<METRIC_L2>;
256
248
  VD vd = {size_t(d), metric_arg};
257
- HeapResultHandler<VD::C> rh(n, distances, labels, k);
249
+ HeapBlockResultHandler<VD::C> rh(n, distances, labels, k);
258
250
  search_with_decompress(*this, x, vd, rh);
259
251
  } else if (metric_type == METRIC_INNER_PRODUCT) {
260
252
  using VD = VectorDistance<METRIC_INNER_PRODUCT>;
261
253
  VD vd = {size_t(d), metric_arg};
262
- HeapResultHandler<VD::C> rh(n, distances, labels, k);
254
+ HeapBlockResultHandler<VD::C> rh(n, distances, labels, k);
263
255
  search_with_decompress(*this, x, vd, rh);
264
256
  }
265
257
  } else {
266
258
  if (metric_type == METRIC_INNER_PRODUCT) {
267
- HeapResultHandler<CMin<float, idx_t> > rh(n, distances, labels, k);
268
- search_with_LUT<true, AdditiveQuantizer::ST_LUT_nonorm> (*this, x, rh);
259
+ HeapBlockResultHandler<CMin<float, idx_t>> rh(
260
+ n, distances, labels, k);
261
+ search_with_LUT<true, AdditiveQuantizer::ST_LUT_nonorm>(
262
+ *this, x, rh);
269
263
  } else {
270
- HeapResultHandler<CMax<float, idx_t> > rh(n, distances, labels, k);
271
- switch(aq->search_type) {
272
- #define DISPATCH(st) \
273
- case AdditiveQuantizer::st: \
274
- search_with_LUT<false, AdditiveQuantizer::st> (*this, x, rh);\
275
- break;
276
- DISPATCH(ST_norm_float)
277
- DISPATCH(ST_LUT_nonorm)
278
- DISPATCH(ST_norm_qint8)
279
- DISPATCH(ST_norm_qint4)
280
- DISPATCH(ST_norm_cqint4)
281
- case AdditiveQuantizer::ST_norm_cqint8:
282
- case AdditiveQuantizer::ST_norm_lsq2x4:
283
- case AdditiveQuantizer::ST_norm_rq2x4:
284
- search_with_LUT<false, AdditiveQuantizer::ST_norm_cqint8> (*this, x, rh);
285
- break;
264
+ HeapBlockResultHandler<CMax<float, idx_t>> rh(
265
+ n, distances, labels, k);
266
+ switch (aq->search_type) {
267
+ #define DISPATCH(st) \
268
+ case AdditiveQuantizer::st: \
269
+ search_with_LUT<false, AdditiveQuantizer::st>(*this, x, rh); \
270
+ break;
271
+ DISPATCH(ST_norm_float)
272
+ DISPATCH(ST_LUT_nonorm)
273
+ DISPATCH(ST_norm_qint8)
274
+ DISPATCH(ST_norm_qint4)
275
+ DISPATCH(ST_norm_cqint4)
276
+ case AdditiveQuantizer::ST_norm_cqint8:
277
+ case AdditiveQuantizer::ST_norm_lsq2x4:
278
+ case AdditiveQuantizer::ST_norm_rq2x4:
279
+ search_with_LUT<false, AdditiveQuantizer::ST_norm_cqint8>(
280
+ *this, x, rh);
281
+ break;
286
282
  #undef DISPATCH
287
- default:
288
- FAISS_THROW_FMT("search type %d not supported", aq->search_type);
283
+ default:
284
+ FAISS_THROW_FMT(
285
+ "search type %d not supported", aq->search_type);
289
286
  }
290
287
  }
291
-
292
288
  }
293
289
  }
294
290
 
295
- void IndexAdditiveQuantizer::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
291
+ void IndexAdditiveQuantizer::sa_encode(idx_t n, const float* x, uint8_t* bytes)
292
+ const {
296
293
  return aq->compute_codes(x, bytes, n);
297
294
  }
298
295
 
299
- void IndexAdditiveQuantizer::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
296
+ void IndexAdditiveQuantizer::sa_decode(idx_t n, const uint8_t* bytes, float* x)
297
+ const {
300
298
  return aq->decode(bytes, x, n);
301
299
  }
302
300
 
303
-
304
-
305
-
306
301
  /**************************************************************************************
307
302
  * IndexResidualQuantizer
308
303
  **************************************************************************************/
@@ -313,8 +308,11 @@ IndexResidualQuantizer::IndexResidualQuantizer(
313
308
  size_t nbits, ///< number of bit per subvector index
314
309
  MetricType metric,
315
310
  Search_type_t search_type)
316
- : IndexResidualQuantizer(d, std::vector<size_t>(M, nbits), metric, search_type) {
317
- }
311
+ : IndexResidualQuantizer(
312
+ d,
313
+ std::vector<size_t>(M, nbits),
314
+ metric,
315
+ search_type) {}
318
316
 
319
317
  IndexResidualQuantizer::IndexResidualQuantizer(
320
318
  int d,
@@ -326,14 +324,14 @@ IndexResidualQuantizer::IndexResidualQuantizer(
326
324
  is_trained = false;
327
325
  }
328
326
 
329
- IndexResidualQuantizer::IndexResidualQuantizer() : IndexResidualQuantizer(0, 0, 0) {}
327
+ IndexResidualQuantizer::IndexResidualQuantizer()
328
+ : IndexResidualQuantizer(0, 0, 0) {}
330
329
 
331
330
  void IndexResidualQuantizer::train(idx_t n, const float* x) {
332
331
  rq.train(n, x);
333
332
  is_trained = true;
334
333
  }
335
334
 
336
-
337
335
  /**************************************************************************************
338
336
  * IndexLocalSearchQuantizer
339
337
  **************************************************************************************/
@@ -344,31 +342,33 @@ IndexLocalSearchQuantizer::IndexLocalSearchQuantizer(
344
342
  size_t nbits, ///< number of bit per subvector index
345
343
  MetricType metric,
346
344
  Search_type_t search_type)
347
- : IndexAdditiveQuantizer(d, &lsq, metric), lsq(d, M, nbits, search_type) {
345
+ : IndexAdditiveQuantizer(d, &lsq, metric),
346
+ lsq(d, M, nbits, search_type) {
348
347
  code_size = lsq.code_size;
349
348
  is_trained = false;
350
349
  }
351
350
 
352
- IndexLocalSearchQuantizer::IndexLocalSearchQuantizer() : IndexLocalSearchQuantizer(0, 0, 0) {}
351
+ IndexLocalSearchQuantizer::IndexLocalSearchQuantizer()
352
+ : IndexLocalSearchQuantizer(0, 0, 0) {}
353
353
 
354
354
  void IndexLocalSearchQuantizer::train(idx_t n, const float* x) {
355
355
  lsq.train(n, x);
356
356
  is_trained = true;
357
357
  }
358
358
 
359
-
360
359
  /**************************************************************************************
361
360
  * IndexProductResidualQuantizer
362
361
  **************************************************************************************/
363
362
 
364
363
  IndexProductResidualQuantizer::IndexProductResidualQuantizer(
365
- int d, ///< dimensionality of the input vectors
364
+ int d, ///< dimensionality of the input vectors
366
365
  size_t nsplits, ///< number of residual quantizers
367
- size_t Msub, ///< number of subquantizers per RQ
368
- size_t nbits, ///< number of bit per subvector index
366
+ size_t Msub, ///< number of subquantizers per RQ
367
+ size_t nbits, ///< number of bit per subvector index
369
368
  MetricType metric,
370
369
  Search_type_t search_type)
371
- : IndexAdditiveQuantizer(d, &prq, metric), prq(d, nsplits, Msub, nbits, search_type) {
370
+ : IndexAdditiveQuantizer(d, &prq, metric),
371
+ prq(d, nsplits, Msub, nbits, search_type) {
372
372
  code_size = prq.code_size;
373
373
  is_trained = false;
374
374
  }
@@ -381,19 +381,19 @@ void IndexProductResidualQuantizer::train(idx_t n, const float* x) {
381
381
  is_trained = true;
382
382
  }
383
383
 
384
-
385
384
  /**************************************************************************************
386
385
  * IndexProductLocalSearchQuantizer
387
386
  **************************************************************************************/
388
387
 
389
388
  IndexProductLocalSearchQuantizer::IndexProductLocalSearchQuantizer(
390
- int d, ///< dimensionality of the input vectors
389
+ int d, ///< dimensionality of the input vectors
391
390
  size_t nsplits, ///< number of local search quantizers
392
- size_t Msub, ///< number of subquantizers per LSQ
393
- size_t nbits, ///< number of bit per subvector index
391
+ size_t Msub, ///< number of subquantizers per LSQ
392
+ size_t nbits, ///< number of bit per subvector index
394
393
  MetricType metric,
395
394
  Search_type_t search_type)
396
- : IndexAdditiveQuantizer(d, &plsq, metric), plsq(d, nsplits, Msub, nbits, search_type) {
395
+ : IndexAdditiveQuantizer(d, &plsq, metric),
396
+ plsq(d, nsplits, Msub, nbits, search_type) {
397
397
  code_size = plsq.code_size;
398
398
  is_trained = false;
399
399
  }
@@ -406,17 +406,15 @@ void IndexProductLocalSearchQuantizer::train(idx_t n, const float* x) {
406
406
  is_trained = true;
407
407
  }
408
408
 
409
-
410
409
  /**************************************************************************************
411
410
  * AdditiveCoarseQuantizer
412
411
  **************************************************************************************/
413
412
 
414
413
  AdditiveCoarseQuantizer::AdditiveCoarseQuantizer(
415
- idx_t d,
416
- AdditiveQuantizer* aq,
417
- MetricType metric):
418
- Index(d, metric), aq(aq)
419
- {}
414
+ idx_t d,
415
+ AdditiveQuantizer* aq,
416
+ MetricType metric)
417
+ : Index(d, metric), aq(aq) {}
420
418
 
421
419
  void AdditiveCoarseQuantizer::add(idx_t, const float*) {
422
420
  FAISS_THROW_MSG("not applicable");
@@ -430,17 +428,16 @@ void AdditiveCoarseQuantizer::reset() {
430
428
  FAISS_THROW_MSG("not applicable");
431
429
  }
432
430
 
433
-
434
431
  void AdditiveCoarseQuantizer::train(idx_t n, const float* x) {
435
432
  if (verbose) {
436
- printf("AdditiveCoarseQuantizer::train: training on %zd vectors\n", size_t(n));
433
+ printf("AdditiveCoarseQuantizer::train: training on %zd vectors\n",
434
+ size_t(n));
437
435
  }
438
436
  size_t norms_size = sizeof(float) << aq->tot_bits;
439
437
 
440
- FAISS_THROW_IF_NOT_MSG (
441
- norms_size <= aq->max_mem_distances,
442
- "the RCQ norms matrix will become too large, please reduce the number of quantization steps"
443
- );
438
+ FAISS_THROW_IF_NOT_MSG(
439
+ norms_size <= aq->max_mem_distances,
440
+ "the RCQ norms matrix will become too large, please reduce the number of quantization steps");
444
441
 
445
442
  aq->train(n, x);
446
443
  is_trained = true;
@@ -448,7 +445,8 @@ void AdditiveCoarseQuantizer::train(idx_t n, const float* x) {
448
445
 
449
446
  if (metric_type == METRIC_L2) {
450
447
  if (verbose) {
451
- printf("AdditiveCoarseQuantizer::train: computing centroid norms for %zd centroids\n", size_t(ntotal));
448
+ printf("AdditiveCoarseQuantizer::train: computing centroid norms for %zd centroids\n",
449
+ size_t(ntotal));
452
450
  }
453
451
  // this is not necessary for the residualcoarsequantizer when
454
452
  // using beam search. We'll see if the memory overhead is too high
@@ -463,16 +461,15 @@ void AdditiveCoarseQuantizer::search(
463
461
  idx_t k,
464
462
  float* distances,
465
463
  idx_t* labels,
466
- const SearchParameters * params) const {
467
-
468
- FAISS_THROW_IF_NOT_MSG(!params, "search params not supported for this index");
464
+ const SearchParameters* params) const {
465
+ FAISS_THROW_IF_NOT_MSG(
466
+ !params, "search params not supported for this index");
469
467
 
470
468
  if (metric_type == METRIC_INNER_PRODUCT) {
471
469
  aq->knn_centroids_inner_product(n, x, k, distances, labels);
472
470
  } else if (metric_type == METRIC_L2) {
473
471
  FAISS_THROW_IF_NOT(centroid_norms.size() == ntotal);
474
- aq->knn_centroids_L2(
475
- n, x, k, distances, labels, centroid_norms.data());
472
+ aq->knn_centroids_L2(n, x, k, distances, labels, centroid_norms.data());
476
473
  }
477
474
  }
478
475
 
@@ -481,7 +478,7 @@ void AdditiveCoarseQuantizer::search(
481
478
  **************************************************************************************/
482
479
 
483
480
  ResidualCoarseQuantizer::ResidualCoarseQuantizer(
484
- int d, ///< dimensionality of the input vectors
481
+ int d, ///< dimensionality of the input vectors
485
482
  const std::vector<size_t>& nbits,
486
483
  MetricType metric)
487
484
  : AdditiveCoarseQuantizer(d, &rq, metric), rq(d, nbits) {
@@ -496,21 +493,30 @@ ResidualCoarseQuantizer::ResidualCoarseQuantizer(
496
493
  MetricType metric)
497
494
  : ResidualCoarseQuantizer(d, std::vector<size_t>(M, nbits), metric) {}
498
495
 
499
- ResidualCoarseQuantizer::ResidualCoarseQuantizer(): ResidualCoarseQuantizer(0, 0, 0) {}
500
-
501
-
496
+ ResidualCoarseQuantizer::ResidualCoarseQuantizer()
497
+ : ResidualCoarseQuantizer(0, 0, 0) {}
502
498
 
503
499
  void ResidualCoarseQuantizer::set_beam_factor(float new_beam_factor) {
504
500
  beam_factor = new_beam_factor;
505
501
  if (new_beam_factor > 0) {
506
502
  FAISS_THROW_IF_NOT(new_beam_factor >= 1.0);
503
+ if (rq.codebook_cross_products.size() == 0) {
504
+ rq.compute_codebook_tables();
505
+ }
507
506
  return;
508
- } else if (metric_type == METRIC_L2 && ntotal != centroid_norms.size()) {
509
- if (verbose) {
510
- printf("AdditiveCoarseQuantizer::train: computing centroid norms for %zd centroids\n", size_t(ntotal));
507
+ } else {
508
+ // new_beam_factor = -1: exhaustive computation.
509
+ // Does not use the cross_products
510
+ rq.codebook_cross_products.resize(0);
511
+ // but the centroid norms are necessary!
512
+ if (metric_type == METRIC_L2 && ntotal != centroid_norms.size()) {
513
+ if (verbose) {
514
+ printf("AdditiveCoarseQuantizer::train: computing centroid norms for %zd centroids\n",
515
+ size_t(ntotal));
516
+ }
517
+ centroid_norms.resize(ntotal);
518
+ aq->compute_centroid_norms(centroid_norms.data());
511
519
  }
512
- centroid_norms.resize(ntotal);
513
- aq->compute_centroid_norms(centroid_norms.data());
514
520
  }
515
521
  }
516
522
 
@@ -520,13 +526,15 @@ void ResidualCoarseQuantizer::search(
520
526
  idx_t k,
521
527
  float* distances,
522
528
  idx_t* labels,
523
- const SearchParameters * params_in
524
- ) const {
525
-
529
+ const SearchParameters* params_in) const {
526
530
  float beam_factor = this->beam_factor;
527
531
  if (params_in) {
528
- auto params = dynamic_cast<const SearchParametersResidualCoarseQuantizer*>(params_in);
529
- FAISS_THROW_IF_NOT_MSG(params, "need SearchParametersResidualCoarseQuantizer parameters");
532
+ auto params =
533
+ dynamic_cast<const SearchParametersResidualCoarseQuantizer*>(
534
+ params_in);
535
+ FAISS_THROW_IF_NOT_MSG(
536
+ params,
537
+ "need SearchParametersResidualCoarseQuantizer parameters");
530
538
  beam_factor = params->beam_factor;
531
539
  }
532
540
 
@@ -559,7 +567,12 @@ void ResidualCoarseQuantizer::search(
559
567
  }
560
568
  for (idx_t i0 = 0; i0 < n; i0 += bs) {
561
569
  idx_t i1 = std::min(n, i0 + bs);
562
- search(i1 - i0, x + i0 * d, k, distances + i0 * k, labels + i0 * k);
570
+ search(i1 - i0,
571
+ x + i0 * d,
572
+ k,
573
+ distances + i0 * k,
574
+ labels + i0 * k,
575
+ params_in);
563
576
  InterruptCallback::check();
564
577
  }
565
578
  return;
@@ -571,6 +584,7 @@ void ResidualCoarseQuantizer::search(
571
584
  rq.refine_beam(
572
585
  n, 1, x, beam_size, codes.data(), nullptr, beam_distances.data());
573
586
 
587
+ // pack int32 table
574
588
  #pragma omp parallel for if (n > 4000)
575
589
  for (idx_t i = 0; i < n; i++) {
576
590
  memcpy(distances + i * k,
@@ -590,7 +604,8 @@ void ResidualCoarseQuantizer::search(
590
604
  }
591
605
  }
592
606
 
593
- void ResidualCoarseQuantizer::initialize_from(const ResidualCoarseQuantizer &other) {
607
+ void ResidualCoarseQuantizer::initialize_from(
608
+ const ResidualCoarseQuantizer& other) {
594
609
  FAISS_THROW_IF_NOT(rq.M <= other.rq.M);
595
610
  rq.initialize_from(other.rq);
596
611
  set_beam_factor(other.beam_factor);
@@ -598,7 +613,6 @@ void ResidualCoarseQuantizer::initialize_from(const ResidualCoarseQuantizer &oth
598
613
  ntotal = (idx_t)1 << aq->tot_bits;
599
614
  }
600
615
 
601
-
602
616
  /**************************************************************************************
603
617
  * LocalSearchCoarseQuantizer
604
618
  **************************************************************************************/
@@ -613,12 +627,8 @@ LocalSearchCoarseQuantizer::LocalSearchCoarseQuantizer(
613
627
  is_trained = false;
614
628
  }
615
629
 
616
-
617
630
  LocalSearchCoarseQuantizer::LocalSearchCoarseQuantizer() {
618
631
  aq = &lsq;
619
632
  }
620
633
 
621
-
622
-
623
-
624
634
  } // namespace faiss