faiss 0.2.3 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (189)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +23 -21
  5. data/ext/faiss/extconf.rb +11 -0
  6. data/ext/faiss/index.cpp +4 -4
  7. data/ext/faiss/index_binary.cpp +6 -6
  8. data/ext/faiss/product_quantizer.cpp +4 -4
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  11. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  12. data/vendor/faiss/faiss/Clustering.h +14 -0
  13. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  14. data/vendor/faiss/faiss/IVFlib.h +26 -2
  15. data/vendor/faiss/faiss/Index.cpp +36 -3
  16. data/vendor/faiss/faiss/Index.h +43 -6
  17. data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
  18. data/vendor/faiss/faiss/Index2Layer.h +8 -17
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  23. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  24. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  25. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  26. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  28. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  31. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  32. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  33. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  34. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  36. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  37. data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
  38. data/vendor/faiss/faiss/IndexFlat.h +16 -19
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  42. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  44. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  45. data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
  46. data/vendor/faiss/faiss/IndexIVF.h +59 -22
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  50. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  51. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  52. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  53. data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
  54. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  55. data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
  56. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
  57. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  59. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  60. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
  62. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
  63. data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
  64. data/vendor/faiss/faiss/IndexLSH.h +4 -16
  65. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  66. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  67. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
  68. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  69. data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
  70. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  71. data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
  72. data/vendor/faiss/faiss/IndexPQ.h +21 -22
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  76. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
  78. data/vendor/faiss/faiss/IndexRefine.h +14 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
  85. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  86. data/vendor/faiss/faiss/IndexShards.h +2 -1
  87. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  88. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  89. data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
  90. data/vendor/faiss/faiss/VectorTransform.h +25 -4
  91. data/vendor/faiss/faiss/clone_index.cpp +26 -3
  92. data/vendor/faiss/faiss/clone_index.h +3 -0
  93. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  94. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  95. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  101. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  102. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  104. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  105. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
  106. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  107. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  108. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  110. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  111. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  112. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  113. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  114. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  115. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  116. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  117. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  120. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  121. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  122. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  123. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  124. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
  125. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
  126. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  127. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  128. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  129. data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
  130. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  131. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  132. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  133. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
  134. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
  135. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  136. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  137. data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
  138. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  139. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  140. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  141. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  142. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  143. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  144. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
  145. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
  146. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  147. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
  148. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  149. data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
  150. data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
  151. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  152. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  153. data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
  154. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  159. data/vendor/faiss/faiss/index_factory.cpp +772 -412
  160. data/vendor/faiss/faiss/index_factory.h +3 -0
  161. data/vendor/faiss/faiss/index_io.h +5 -0
  162. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  163. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  164. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  165. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  166. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  167. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  168. data/vendor/faiss/faiss/utils/distances.cpp +384 -58
  169. data/vendor/faiss/faiss/utils/distances.h +149 -18
  170. data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
  171. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  172. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  173. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  174. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  175. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  176. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  177. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  178. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  179. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  180. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  181. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  182. data/vendor/faiss/faiss/utils/random.h +5 -0
  183. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  184. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  185. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  186. data/vendor/faiss/faiss/utils/utils.h +1 -1
  187. metadata +46 -5
  188. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
  189. data/vendor/faiss/faiss/IndexResidual.h +0 -152
@@ -70,10 +70,11 @@ bool getTensorCoreSupport(int device);
70
70
  /// Equivalent to getTensorCoreSupport(getCurrentDevice())
71
71
  bool getTensorCoreSupportCurrentDevice();
72
72
 
73
- /// Returns the maximum k-selection value supported based on the CUDA SDK that
74
- /// we were compiled with. .cu files can use DeviceDefs.cuh, but this is for
75
- /// non-CUDA files
76
- int getMaxKSelection();
73
+ /// Returns the amount of currently available memory on the given device
74
+ size_t getFreeMemory(int device);
75
+
76
+ /// Equivalent to getFreeMemory(getCurrentDevice())
77
+ size_t getFreeMemoryCurrentDevice();
77
78
 
78
79
  /// RAII object to set the current device, and restore the previous
79
80
  /// device upon destruction
@@ -8,7 +8,6 @@
8
8
  // -*- c++ -*-
9
9
 
10
10
  #include <faiss/impl/AdditiveQuantizer.h>
11
- #include <faiss/impl/FaissAssert.h>
12
11
 
13
12
  #include <cstddef>
14
13
  #include <cstdio>
@@ -18,9 +17,13 @@
18
17
 
19
18
  #include <algorithm>
20
19
 
20
+ #include <faiss/Clustering.h>
21
+ #include <faiss/impl/FaissAssert.h>
22
+ #include <faiss/impl/LocalSearchQuantizer.h>
23
+ #include <faiss/impl/ResidualQuantizer.h>
21
24
  #include <faiss/utils/Heap.h>
22
25
  #include <faiss/utils/distances.h>
23
- #include <faiss/utils/hamming.h> // BitstringWriter
26
+ #include <faiss/utils/hamming.h>
24
27
  #include <faiss/utils/utils.h>
25
28
 
26
29
  extern "C" {
@@ -42,51 +45,211 @@ int sgemm_(
42
45
  FINTEGER* ldc);
43
46
  }
44
47
 
45
- namespace {
46
-
47
- // c and a and b can overlap
48
- void fvec_add(size_t d, const float* a, const float* b, float* c) {
49
- for (size_t i = 0; i < d; i++) {
50
- c[i] = a[i] + b[i];
51
- }
52
- }
48
+ namespace faiss {
53
49
 
54
- void fvec_add(size_t d, const float* a, float b, float* c) {
55
- for (size_t i = 0; i < d; i++) {
56
- c[i] = a[i] + b;
57
- }
50
+ AdditiveQuantizer::AdditiveQuantizer(
51
+ size_t d,
52
+ const std::vector<size_t>& nbits,
53
+ Search_type_t search_type)
54
+ : Quantizer(d),
55
+ M(nbits.size()),
56
+ nbits(nbits),
57
+ verbose(false),
58
+ is_trained(false),
59
+ max_mem_distances(5 * (size_t(1) << 30)), // 5 GiB
60
+ search_type(search_type) {
61
+ norm_max = norm_min = NAN;
62
+ tot_bits = 0;
63
+ total_codebook_size = 0;
64
+ only_8bit = false;
65
+ set_derived_values();
58
66
  }
59
67
 
60
- } // namespace
61
-
62
- namespace faiss {
68
+ AdditiveQuantizer::AdditiveQuantizer()
69
+ : AdditiveQuantizer(0, std::vector<size_t>()) {}
63
70
 
64
71
  void AdditiveQuantizer::set_derived_values() {
65
72
  tot_bits = 0;
66
- is_byte_aligned = true;
73
+ only_8bit = true;
67
74
  codebook_offsets.resize(M + 1, 0);
68
75
  for (int i = 0; i < M; i++) {
69
76
  int nbit = nbits[i];
70
77
  size_t k = 1 << nbit;
71
78
  codebook_offsets[i + 1] = codebook_offsets[i] + k;
72
79
  tot_bits += nbit;
73
- if (nbit % 8 != 0) {
74
- is_byte_aligned = false;
80
+ if (nbit != 0) {
81
+ only_8bit = false;
75
82
  }
76
83
  }
77
84
  total_codebook_size = codebook_offsets[M];
85
+ switch (search_type) {
86
+ case ST_norm_float:
87
+ norm_bits = 32;
88
+ break;
89
+ case ST_norm_qint8:
90
+ case ST_norm_cqint8:
91
+ case ST_norm_lsq2x4:
92
+ case ST_norm_rq2x4:
93
+ norm_bits = 8;
94
+ break;
95
+ case ST_norm_qint4:
96
+ case ST_norm_cqint4:
97
+ norm_bits = 4;
98
+ break;
99
+ case ST_decompress:
100
+ case ST_LUT_nonorm:
101
+ case ST_norm_from_LUT:
102
+ default:
103
+ norm_bits = 0;
104
+ break;
105
+ }
106
+ tot_bits += norm_bits;
107
+
78
108
  // convert bits to bytes
79
109
  code_size = (tot_bits + 7) / 8;
80
110
  }
81
111
 
112
+ void AdditiveQuantizer::train_norm(size_t n, const float* norms) {
113
+ norm_min = HUGE_VALF;
114
+ norm_max = -HUGE_VALF;
115
+ for (idx_t i = 0; i < n; i++) {
116
+ if (norms[i] < norm_min) {
117
+ norm_min = norms[i];
118
+ }
119
+ if (norms[i] > norm_max) {
120
+ norm_max = norms[i];
121
+ }
122
+ }
123
+
124
+ if (search_type == ST_norm_cqint8 || search_type == ST_norm_cqint4) {
125
+ size_t k = (1 << 8);
126
+ if (search_type == ST_norm_cqint4) {
127
+ k = (1 << 4);
128
+ }
129
+ Clustering1D clus(k);
130
+ clus.train_exact(n, norms);
131
+ qnorm.add(clus.k, clus.centroids.data());
132
+ } else if (search_type == ST_norm_lsq2x4 || search_type == ST_norm_rq2x4) {
133
+ std::unique_ptr<AdditiveQuantizer> aq;
134
+ if (search_type == ST_norm_lsq2x4) {
135
+ aq.reset(new LocalSearchQuantizer(1, 2, 4));
136
+ } else {
137
+ aq.reset(new ResidualQuantizer(1, 2, 4));
138
+ }
139
+
140
+ aq->train(n, norms);
141
+ // flatten aq codebooks
142
+ std::vector<float> flat_codebooks(1 << 8);
143
+ FAISS_THROW_IF_NOT(aq->codebooks.size() == 32);
144
+
145
+ // save norm tables for 4-bit fastscan search
146
+ norm_tabs = aq->codebooks;
147
+
148
+ // assume big endian
149
+ const float* c = norm_tabs.data();
150
+ for (size_t i = 0; i < 16; i++) {
151
+ for (size_t j = 0; j < 16; j++) {
152
+ flat_codebooks[i * 16 + j] = c[j] + c[16 + i];
153
+ }
154
+ }
155
+
156
+ qnorm.reset();
157
+ qnorm.add(1 << 8, flat_codebooks.data());
158
+ FAISS_THROW_IF_NOT(qnorm.ntotal == (1 << 8));
159
+ }
160
+ }
161
+
162
+ namespace {
163
+
164
+ // TODO
165
+ // https://stackoverflow.com/questions/31631224/hacks-for-clamping-integer-to-0-255-and-doubles-to-0-0-1-0
166
+
167
+ uint8_t encode_qint8(float x, float amin, float amax) {
168
+ float x1 = (x - amin) / (amax - amin) * 256;
169
+ int32_t xi = int32_t(floor(x1));
170
+
171
+ return xi < 0 ? 0 : xi > 255 ? 255 : xi;
172
+ }
173
+
174
+ uint8_t encode_qint4(float x, float amin, float amax) {
175
+ float x1 = (x - amin) / (amax - amin) * 16;
176
+ int32_t xi = int32_t(floor(x1));
177
+
178
+ return xi < 0 ? 0 : xi > 15 ? 15 : xi;
179
+ }
180
+
181
+ float decode_qint8(uint8_t i, float amin, float amax) {
182
+ return (i + 0.5) / 256 * (amax - amin) + amin;
183
+ }
184
+
185
+ float decode_qint4(uint8_t i, float amin, float amax) {
186
+ return (i + 0.5) / 16 * (amax - amin) + amin;
187
+ }
188
+
189
+ } // anonymous namespace
190
+
191
+ uint32_t AdditiveQuantizer::encode_qcint(float x) const {
192
+ idx_t id;
193
+ qnorm.assign(1, &x, &id, 1);
194
+ return uint32_t(id);
195
+ }
196
+
197
+ float AdditiveQuantizer::decode_qcint(uint32_t c) const {
198
+ return qnorm.get_xb()[c];
199
+ }
200
+
201
+ uint64_t AdditiveQuantizer::encode_norm(float norm) const {
202
+ switch (search_type) {
203
+ case ST_norm_float:
204
+ uint32_t inorm;
205
+ memcpy(&inorm, &norm, 4);
206
+ return inorm;
207
+ case ST_norm_qint8:
208
+ return encode_qint8(norm, norm_min, norm_max);
209
+ case ST_norm_qint4:
210
+ return encode_qint4(norm, norm_min, norm_max);
211
+ case ST_norm_lsq2x4:
212
+ case ST_norm_rq2x4:
213
+ case ST_norm_cqint8:
214
+ return encode_qcint(norm);
215
+ case ST_norm_cqint4:
216
+ return encode_qcint(norm);
217
+ case ST_decompress:
218
+ case ST_LUT_nonorm:
219
+ case ST_norm_from_LUT:
220
+ default:
221
+ return 0;
222
+ }
223
+ }
224
+
82
225
  void AdditiveQuantizer::pack_codes(
83
226
  size_t n,
84
227
  const int32_t* codes,
85
228
  uint8_t* packed_codes,
86
- int64_t ld_codes) const {
229
+ int64_t ld_codes,
230
+ const float* norms,
231
+ const float* centroids) const {
87
232
  if (ld_codes == -1) {
88
233
  ld_codes = M;
89
234
  }
235
+ std::vector<float> norm_buf;
236
+ if (search_type == ST_norm_float || search_type == ST_norm_qint4 ||
237
+ search_type == ST_norm_qint8 || search_type == ST_norm_cqint8 ||
238
+ search_type == ST_norm_cqint4 || search_type == ST_norm_lsq2x4 ||
239
+ search_type == ST_norm_rq2x4) {
240
+ if (centroids != nullptr || !norms) {
241
+ norm_buf.resize(n);
242
+ std::vector<float> x_recons(n * d);
243
+ decode_unpacked(codes, x_recons.data(), n, ld_codes);
244
+
245
+ if (centroids != nullptr) {
246
+ // x = x + c
247
+ fvec_add(n * d, x_recons.data(), centroids, x_recons.data());
248
+ }
249
+ fvec_norms_L2sqr(norm_buf.data(), x_recons.data(), d, n);
250
+ norms = norm_buf.data();
251
+ }
252
+ }
90
253
  #pragma omp parallel for if (n > 1000)
91
254
  for (int64_t i = 0; i < n; i++) {
92
255
  const int32_t* codes1 = codes + i * ld_codes;
@@ -94,6 +257,9 @@ void AdditiveQuantizer::pack_codes(
94
257
  for (int m = 0; m < M; m++) {
95
258
  bsw.write(codes1[m], nbits[m]);
96
259
  }
260
+ if (norm_bits != 0) {
261
+ bsw.write(encode_norm(norms[i]), norm_bits);
262
+ }
97
263
  }
98
264
  }
99
265
 
@@ -118,10 +284,39 @@ void AdditiveQuantizer::decode(const uint8_t* code, float* x, size_t n) const {
118
284
  }
119
285
  }
120
286
 
287
+ void AdditiveQuantizer::decode_unpacked(
288
+ const int32_t* code,
289
+ float* x,
290
+ size_t n,
291
+ int64_t ld_codes) const {
292
+ FAISS_THROW_IF_NOT_MSG(
293
+ is_trained, "The additive quantizer is not trained yet.");
294
+
295
+ if (ld_codes == -1) {
296
+ ld_codes = M;
297
+ }
298
+
299
+ // standard additive quantizer decoding
300
+ #pragma omp parallel for if (n > 1000)
301
+ for (int64_t i = 0; i < n; i++) {
302
+ const int32_t* codesi = code + i * ld_codes;
303
+ float* xi = x + i * d;
304
+ for (int m = 0; m < M; m++) {
305
+ int idx = codesi[m];
306
+ const float* c = codebooks.data() + d * (codebook_offsets[m] + idx);
307
+ if (m == 0) {
308
+ memcpy(xi, c, sizeof(*x) * d);
309
+ } else {
310
+ fvec_add(d, xi, c, xi);
311
+ }
312
+ }
313
+ }
314
+ }
315
+
121
316
  AdditiveQuantizer::~AdditiveQuantizer() {}
122
317
 
123
318
  /****************************************************************************
124
- * Support for fast distance computations and search with additive quantizer
319
+ * Support for fast distance computations in centroids
125
320
  ****************************************************************************/
126
321
 
127
322
  void AdditiveQuantizer::compute_centroid_norms(float* norms) const {
@@ -151,28 +346,33 @@ void AdditiveQuantizer::decode_64bit(idx_t bits, float* xi) const {
151
346
  }
152
347
  }
153
348
 
154
- void AdditiveQuantizer::compute_LUT(size_t n, const float* xq, float* LUT)
155
- const {
349
+ void AdditiveQuantizer::compute_LUT(
350
+ size_t n,
351
+ const float* xq,
352
+ float* LUT,
353
+ float alpha,
354
+ long ld_lut) const {
156
355
  // in all cases, it is large matrix multiplication
157
356
 
158
357
  FINTEGER ncenti = total_codebook_size;
159
358
  FINTEGER di = d;
160
359
  FINTEGER nqi = n;
161
- float one = 1, zero = 0;
360
+ FINTEGER ldc = ld_lut > 0 ? ld_lut : ncenti;
361
+ float zero = 0;
162
362
 
163
363
  sgemm_("Transposed",
164
364
  "Not transposed",
165
365
  &ncenti,
166
366
  &nqi,
167
367
  &di,
168
- &one,
368
+ &alpha,
169
369
  codebooks.data(),
170
370
  &di,
171
371
  xq,
172
372
  &di,
173
373
  &zero,
174
374
  LUT,
175
- &ncenti);
375
+ &ldc);
176
376
  }
177
377
 
178
378
  namespace {
@@ -201,7 +401,7 @@ void compute_inner_prod_with_LUT(
201
401
 
202
402
  } // anonymous namespace
203
403
 
204
- void AdditiveQuantizer::knn_exact_inner_product(
404
+ void AdditiveQuantizer::knn_centroids_inner_product(
205
405
  idx_t n,
206
406
  const float* xq,
207
407
  idx_t k,
@@ -227,7 +427,7 @@ void AdditiveQuantizer::knn_exact_inner_product(
227
427
  }
228
428
  }
229
429
 
230
- void AdditiveQuantizer::knn_exact_L2(
430
+ void AdditiveQuantizer::knn_centroids_L2(
231
431
  idx_t n,
232
432
  const float* xq,
233
433
  idx_t k,
@@ -267,4 +467,106 @@ void AdditiveQuantizer::knn_exact_L2(
267
467
  }
268
468
  }
269
469
 
470
+ /****************************************************************************
471
+ * Support for fast distance computations in codes
472
+ ****************************************************************************/
473
+
474
+ namespace {
475
+
476
+ float accumulate_IPs(
477
+ const AdditiveQuantizer& aq,
478
+ BitstringReader& bs,
479
+ const uint8_t* codes,
480
+ const float* LUT) {
481
+ float accu = 0;
482
+ for (int m = 0; m < aq.M; m++) {
483
+ size_t nbit = aq.nbits[m];
484
+ int idx = bs.read(nbit);
485
+ accu += LUT[idx];
486
+ LUT += (uint64_t)1 << nbit;
487
+ }
488
+ return accu;
489
+ }
490
+
491
+ } // anonymous namespace
492
+
493
+ template <>
494
+ float AdditiveQuantizer::
495
+ compute_1_distance_LUT<true, AdditiveQuantizer::ST_LUT_nonorm>(
496
+ const uint8_t* codes,
497
+ const float* LUT) const {
498
+ BitstringReader bs(codes, code_size);
499
+ return accumulate_IPs(*this, bs, codes, LUT);
500
+ }
501
+
502
+ template <>
503
+ float AdditiveQuantizer::
504
+ compute_1_distance_LUT<false, AdditiveQuantizer::ST_LUT_nonorm>(
505
+ const uint8_t* codes,
506
+ const float* LUT) const {
507
+ BitstringReader bs(codes, code_size);
508
+ return -accumulate_IPs(*this, bs, codes, LUT);
509
+ }
510
+
511
+ template <>
512
+ float AdditiveQuantizer::
513
+ compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_float>(
514
+ const uint8_t* codes,
515
+ const float* LUT) const {
516
+ BitstringReader bs(codes, code_size);
517
+ float accu = accumulate_IPs(*this, bs, codes, LUT);
518
+ uint32_t norm_i = bs.read(32);
519
+ float norm2;
520
+ memcpy(&norm2, &norm_i, 4);
521
+ return norm2 - 2 * accu;
522
+ }
523
+
524
+ template <>
525
+ float AdditiveQuantizer::
526
+ compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_cqint8>(
527
+ const uint8_t* codes,
528
+ const float* LUT) const {
529
+ BitstringReader bs(codes, code_size);
530
+ float accu = accumulate_IPs(*this, bs, codes, LUT);
531
+ uint32_t norm_i = bs.read(8);
532
+ float norm2 = decode_qcint(norm_i);
533
+ return norm2 - 2 * accu;
534
+ }
535
+
536
+ template <>
537
+ float AdditiveQuantizer::
538
+ compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_cqint4>(
539
+ const uint8_t* codes,
540
+ const float* LUT) const {
541
+ BitstringReader bs(codes, code_size);
542
+ float accu = accumulate_IPs(*this, bs, codes, LUT);
543
+ uint32_t norm_i = bs.read(4);
544
+ float norm2 = decode_qcint(norm_i);
545
+ return norm2 - 2 * accu;
546
+ }
547
+
548
+ template <>
549
+ float AdditiveQuantizer::
550
+ compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_qint8>(
551
+ const uint8_t* codes,
552
+ const float* LUT) const {
553
+ BitstringReader bs(codes, code_size);
554
+ float accu = accumulate_IPs(*this, bs, codes, LUT);
555
+ uint32_t norm_i = bs.read(8);
556
+ float norm2 = decode_qint8(norm_i, norm_min, norm_max);
557
+ return norm2 - 2 * accu;
558
+ }
559
+
560
+ template <>
561
+ float AdditiveQuantizer::
562
+ compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_qint4>(
563
+ const uint8_t* codes,
564
+ const float* LUT) const {
565
+ BitstringReader bs(codes, code_size);
566
+ float accu = accumulate_IPs(*this, bs, codes, LUT);
567
+ uint32_t norm_i = bs.read(4);
568
+ float norm2 = decode_qint4(norm_i, norm_min, norm_max);
569
+ return norm2 - 2 * accu;
570
+ }
571
+
270
572
  } // namespace faiss
@@ -11,6 +11,8 @@
11
11
  #include <vector>
12
12
 
13
13
  #include <faiss/Index.h>
14
+ #include <faiss/IndexFlat.h>
15
+ #include <faiss/impl/Quantizer.h>
14
16
 
15
17
  namespace faiss {
16
18
 
@@ -20,58 +22,140 @@ namespace faiss {
20
22
  * concatenation of M sub-vectors, additive quantizers sum M sub-vectors
21
23
  * to get the decoded vector.
22
24
  */
23
- struct AdditiveQuantizer {
24
- size_t d; ///< size of the input vectors
25
+ struct AdditiveQuantizer : Quantizer {
25
26
  size_t M; ///< number of codebooks
26
27
  std::vector<size_t> nbits; ///< bits for each step
27
28
  std::vector<float> codebooks; ///< codebooks
28
29
 
29
30
  // derived values
30
- std::vector<size_t> codebook_offsets;
31
- size_t code_size; ///< code size in bytes
32
- size_t tot_bits; ///< total number of bits
31
+ std::vector<uint64_t> codebook_offsets;
32
+ size_t tot_bits; ///< total number of bits (indexes + norms)
33
+ size_t norm_bits; ///< bits allocated for the norms
33
34
  size_t total_codebook_size; ///< size of the codebook in vectors
34
- bool is_byte_aligned;
35
+ bool only_8bit; ///< are all nbits = 8 (use faster decoder)
35
36
 
36
37
  bool verbose; ///< verbose during training?
37
38
  bool is_trained; ///< is trained or not
38
39
 
40
+ IndexFlat1D qnorm; ///< store and search norms
41
+ std::vector<float> norm_tabs; ///< store norms of codebook entries for 4-bit
42
+ ///< fastscan search
43
+
44
+ /// norms and distance matrixes with beam search can get large, so use this
45
+ /// to control for the amount of memory that can be allocated
46
+ size_t max_mem_distances;
47
+
48
+ /// encode a norm into norm_bits bits
49
+ uint64_t encode_norm(float norm) const;
50
+
51
+ uint32_t encode_qcint(
52
+ float x) const; ///< encode norm by non-uniform scalar quantization
53
+
54
+ float decode_qcint(uint32_t c)
55
+ const; ///< decode norm by non-uniform scalar quantization
56
+
57
+ /// Encodes how search is performed and how vectors are encoded
58
+ enum Search_type_t {
59
+ ST_decompress, ///< decompress database vector
60
+ ST_LUT_nonorm, ///< use a LUT, don't include norms (OK for IP or
61
+ ///< normalized vectors)
62
+ ST_norm_from_LUT, ///< compute the norms from the look-up tables (cost
63
+ ///< is in O(M^2))
64
+ ST_norm_float, ///< use a LUT, and store float32 norm with the vectors
65
+ ST_norm_qint8, ///< use a LUT, and store 8bit-quantized norm
66
+ ST_norm_qint4,
67
+ ST_norm_cqint8, ///< use a LUT, and store non-uniform quantized norm
68
+ ST_norm_cqint4,
69
+
70
+ ST_norm_lsq2x4, ///< use a 2x4 bits lsq as norm quantizer (for fast
71
+ ///< scan)
72
+ ST_norm_rq2x4, ///< use a 2x4 bits rq as norm quantizer (for fast scan)
73
+ };
74
+
75
+ AdditiveQuantizer(
76
+ size_t d,
77
+ const std::vector<size_t>& nbits,
78
+ Search_type_t search_type = ST_decompress);
79
+
80
+ AdditiveQuantizer();
81
+
39
82
  ///< compute derived values when d, M and nbits have been set
40
83
  void set_derived_values();
41
84
 
42
- ///< Train the additive quantizer
43
- virtual void train(size_t n, const float* x) = 0;
85
+ ///< Train the norm quantizer
86
+ void train_norm(size_t n, const float* norms);
87
+
88
+ void compute_codes(const float* x, uint8_t* codes, size_t n)
89
+ const override {
90
+ compute_codes_add_centroids(x, codes, n);
91
+ }
44
92
 
45
93
  /** Encode a set of vectors
46
94
  *
47
95
  * @param x vectors to encode, size n * d
48
96
  * @param codes output codes, size n * code_size
97
+ * @param centroids centroids to be added to x, size n * d
49
98
  */
50
- virtual void compute_codes(const float* x, uint8_t* codes, size_t n)
51
- const = 0;
99
+ virtual void compute_codes_add_centroids(
100
+ const float* x,
101
+ uint8_t* codes,
102
+ size_t n,
103
+ const float* centroids = nullptr) const = 0;
52
104
 
53
105
  /** pack a series of code to bit-compact format
54
106
  *
55
- * @param codes codes to be packed, size n * code_size
107
+ * @param codes codes to be packed, size n * code_size
56
108
  * @param packed_codes output bit-compact codes
57
- * @param ld_codes leading dimension of codes
109
+ * @param ld_codes leading dimension of codes
110
+ * @param norms norms of the vectors (size n). Will be computed if
111
+ * needed but not provided
112
+ * @param centroids centroids to be added to x, size n * d
58
113
  */
59
114
  void pack_codes(
60
115
  size_t n,
61
116
  const int32_t* codes,
62
117
  uint8_t* packed_codes,
63
- int64_t ld_codes = -1) const;
118
+ int64_t ld_codes = -1,
119
+ const float* norms = nullptr,
120
+ const float* centroids = nullptr) const;
64
121
 
65
122
  /** Decode a set of vectors
66
123
  *
67
124
  * @param codes codes to decode, size n * code_size
68
125
  * @param x output vectors, size n * d
69
126
  */
70
- void decode(const uint8_t* codes, float* x, size_t n) const;
127
+ void decode(const uint8_t* codes, float* x, size_t n) const override;
128
+
129
+ /** Decode a set of vectors in non-packed format
130
+ *
131
+ * @param codes codes to decode, size n * ld_codes
132
+ * @param x output vectors, size n * d
133
+ */
134
+ virtual void decode_unpacked(
135
+ const int32_t* codes,
136
+ float* x,
137
+ size_t n,
138
+ int64_t ld_codes = -1) const;
71
139
 
72
140
  /****************************************************************************
73
- * Support for exhaustive distance computations with the centroids.
74
- * Hence, the number of elements that can be enumerated is not too large.
141
+ * Search functions in an external set of codes.
142
+ ****************************************************************************/
143
+
144
+ /// Also determines what's in the codes
145
+ Search_type_t search_type;
146
+
147
+ /// min/max for quantization of norms
148
+ float norm_min, norm_max;
149
+
150
+ template <bool is_IP, Search_type_t effective_search_type>
151
+ float compute_1_distance_LUT(const uint8_t* codes, const float* LUT) const;
152
+
153
+ /*
154
+ float compute_1_L2sqr(const uint8_t* codes, const float* LUT);
155
+ */
156
+ /****************************************************************************
157
+ * Support for exhaustive distance computations with all the centroids.
158
+ * Hence, the number of these centroids should not be too large.
75
159
  ****************************************************************************/
76
160
  using idx_t = Index::idx_t;
77
161
 
@@ -83,11 +167,18 @@ struct AdditiveQuantizer {
83
167
  *
84
168
  * @param xq query vector, size (n, d)
85
169
  * @param LUT look-up table, size (n, total_codebook_size)
170
+ * @param alpha compute alpha * inner-product
171
+ * @param ld_lut leading dimension of LUT
86
172
  */
87
- void compute_LUT(size_t n, const float* xq, float* LUT) const;
173
+ virtual void compute_LUT(
174
+ size_t n,
175
+ const float* xq,
176
+ float* LUT,
177
+ float alpha = 1.0f,
178
+ long ld_lut = -1) const;
88
179
 
89
180
  /// exact IP search
90
- void knn_exact_inner_product(
181
+ void knn_centroids_inner_product(
91
182
  idx_t n,
92
183
  const float* xq,
93
184
  idx_t k,
@@ -101,7 +192,7 @@ struct AdditiveQuantizer {
101
192
  void compute_centroid_norms(float* norms) const;
102
193
 
103
194
  /** Exact L2 search, with precomputed norms */
104
- void knn_exact_L2(
195
+ void knn_centroids_L2(
105
196
  idx_t n,
106
197
  const float* xq,
107
198
  idx_t k,