faiss 0.1.7 → 0.2.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +18 -0
  3. data/README.md +7 -7
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +8 -2
  6. data/ext/faiss/index.cpp +102 -69
  7. data/ext/faiss/index_binary.cpp +24 -30
  8. data/ext/faiss/kmeans.cpp +20 -16
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +13 -14
  11. data/ext/faiss/product_quantizer.cpp +23 -24
  12. data/ext/faiss/utils.cpp +10 -37
  13. data/ext/faiss/utils.h +2 -13
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +0 -5
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +26 -12
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
@@ -0,0 +1,291 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/IndexResidual.h>
9
+
10
+ #include <algorithm>
11
+ #include <cmath>
12
+ #include <cstring>
13
+
14
+ #include <faiss/impl/FaissAssert.h>
15
+ #include <faiss/impl/ResultHandler.h>
16
+ #include <faiss/utils/distances.h>
17
+ #include <faiss/utils/extra_distances.h>
18
+ #include <faiss/utils/utils.h>
19
+
20
+ namespace faiss {
21
+
22
+ /**************************************************************************************
23
+ * IndexResidual
24
+ **************************************************************************************/
25
+
26
+ IndexResidual::IndexResidual(
27
+ int d, ///< dimensionality of the input vectors
28
+ size_t M, ///< number of subquantizers
29
+ size_t nbits, ///< number of bit per subvector index
30
+ MetricType metric,
31
+ Search_type_t search_type_in)
32
+ : Index(d, metric), rq(d, M, nbits), search_type(ST_decompress) {
33
+ is_trained = false;
34
+ norm_max = norm_min = NAN;
35
+ set_search_type(search_type_in);
36
+ }
37
+
38
+ IndexResidual::IndexResidual(
39
+ int d,
40
+ const std::vector<size_t>& nbits,
41
+ MetricType metric,
42
+ Search_type_t search_type_in)
43
+ : Index(d, metric), rq(d, nbits), search_type(ST_decompress) {
44
+ is_trained = false;
45
+ norm_max = norm_min = NAN;
46
+ set_search_type(search_type_in);
47
+ }
48
+
49
+ IndexResidual::IndexResidual() : IndexResidual(0, 0, 0) {}
50
+
51
+ void IndexResidual::set_search_type(Search_type_t new_search_type) {
52
+ int norm_bits = new_search_type == ST_norm_float ? 32
53
+ : new_search_type == ST_norm_qint8 ? 8
54
+ : 0;
55
+
56
+ FAISS_THROW_IF_NOT(ntotal == 0);
57
+
58
+ search_type = new_search_type;
59
+ code_size = (rq.tot_bits + norm_bits + 7) / 8;
60
+ }
61
+
62
+ void IndexResidual::train(idx_t n, const float* x) {
63
+ rq.train(n, x);
64
+
65
+ std::vector<float> norms(n);
66
+ fvec_norms_L2sqr(norms.data(), x, d, n);
67
+
68
+ norm_min = HUGE_VALF;
69
+ norm_max = -HUGE_VALF;
70
+ for (idx_t i = 0; i < n; i++) {
71
+ if (norms[i] < norm_min) {
72
+ norm_min = norms[i];
73
+ }
74
+ if (norms[i] > norm_min) {
75
+ norm_max = norms[i];
76
+ }
77
+ }
78
+
79
+ is_trained = true;
80
+ }
81
+
82
+ void IndexResidual::add(idx_t n, const float* x) {
83
+ FAISS_THROW_IF_NOT(is_trained);
84
+ codes.resize((n + ntotal) * rq.code_size);
85
+ if (search_type == ST_decompress || search_type == ST_LUT_nonorm) {
86
+ rq.compute_codes(x, &codes[ntotal * rq.code_size], n);
87
+ } else {
88
+ // should compute codes + compute and quantize norms
89
+ FAISS_THROW_MSG("not implemented");
90
+ }
91
+ ntotal += n;
92
+ }
93
+
94
+ namespace {
95
+
96
+ template <class VectorDistance, class ResultHandler>
97
+ void search_with_decompress(
98
+ const IndexResidual& ir,
99
+ const float* xq,
100
+ VectorDistance& vd,
101
+ ResultHandler& res) {
102
+ const uint8_t* codes = ir.codes.data();
103
+ size_t ntotal = ir.ntotal;
104
+ size_t code_size = ir.code_size;
105
+
106
+ using SingleResultHandler = typename ResultHandler::SingleResultHandler;
107
+
108
+ #pragma omp parallel for
109
+ for (int64_t q = 0; q < res.nq; q++) {
110
+ SingleResultHandler resi(res);
111
+ resi.begin(q);
112
+ std::vector<float> tmp(ir.d);
113
+ const float* x = xq + ir.d * q;
114
+ for (size_t i = 0; i < ntotal; i++) {
115
+ ir.rq.decode(codes + i * code_size, tmp.data(), 1);
116
+ float dis = vd(x, tmp.data());
117
+ resi.add_result(dis, i);
118
+ }
119
+ resi.end();
120
+ }
121
+ }
122
+
123
+ } // anonymous namespace
124
+
125
+ void IndexResidual::search(
126
+ idx_t n,
127
+ const float* x,
128
+ idx_t k,
129
+ float* distances,
130
+ idx_t* labels) const {
131
+ if (search_type == ST_decompress) {
132
+ if (metric_type == METRIC_L2) {
133
+ using VD = VectorDistance<METRIC_L2>;
134
+ VD vd = {size_t(d), metric_arg};
135
+ HeapResultHandler<VD::C> rh(n, distances, labels, k);
136
+ search_with_decompress(*this, x, vd, rh);
137
+ } else if (metric_type == METRIC_INNER_PRODUCT) {
138
+ using VD = VectorDistance<METRIC_INNER_PRODUCT>;
139
+ VD vd = {size_t(d), metric_arg};
140
+ HeapResultHandler<VD::C> rh(n, distances, labels, k);
141
+ search_with_decompress(*this, x, vd, rh);
142
+ }
143
+ } else {
144
+ FAISS_THROW_MSG("not implemented");
145
+ }
146
+ }
147
+
148
+ void IndexResidual::reset() {
149
+ codes.clear();
150
+ ntotal = 0;
151
+ }
152
+
153
+ size_t IndexResidual::sa_code_size() const {
154
+ return code_size;
155
+ }
156
+
157
+ void IndexResidual::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
158
+ return rq.compute_codes(x, bytes, n);
159
+ }
160
+
161
+ void IndexResidual::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
162
+ return rq.decode(bytes, x, n);
163
+ }
164
+
165
+ /**************************************************************************************
166
+ * ResidualCoarseQuantizer
167
+ **************************************************************************************/
168
+
169
+ ResidualCoarseQuantizer::ResidualCoarseQuantizer(
170
+ int d, ///< dimensionality of the input vectors
171
+ size_t M, ///< number of subquantizers
172
+ size_t nbits, ///< number of bit per subvector index
173
+ MetricType metric)
174
+ : Index(d, metric), rq(d, M, nbits), beam_factor(4.0) {
175
+ FAISS_THROW_IF_NOT(rq.tot_bits <= 63);
176
+ is_trained = false;
177
+ }
178
+
179
+ ResidualCoarseQuantizer::ResidualCoarseQuantizer(
180
+ int d,
181
+ const std::vector<size_t>& nbits,
182
+ MetricType metric)
183
+ : Index(d, metric), rq(d, nbits), beam_factor(4.0) {
184
+ FAISS_THROW_IF_NOT(rq.tot_bits <= 63);
185
+ is_trained = false;
186
+ }
187
+
188
+ ResidualCoarseQuantizer::ResidualCoarseQuantizer() {}
189
+
190
+ void ResidualCoarseQuantizer::train(idx_t n, const float* x) {
191
+ rq.train(n, x);
192
+ is_trained = true;
193
+ ntotal = (idx_t)1 << rq.tot_bits;
194
+ }
195
+
196
+ void ResidualCoarseQuantizer::add(idx_t, const float*) {
197
+ FAISS_THROW_MSG("not applicable");
198
+ }
199
+
200
+ void ResidualCoarseQuantizer::set_beam_factor(float new_beam_factor) {
201
+ centroid_norms.resize(0);
202
+ beam_factor = new_beam_factor;
203
+ if (new_beam_factor > 0) {
204
+ FAISS_THROW_IF_NOT(new_beam_factor >= 1.0);
205
+ return;
206
+ }
207
+
208
+ if (metric_type == METRIC_L2) {
209
+ centroid_norms.resize((size_t)1 << rq.tot_bits);
210
+ rq.compute_centroid_norms(centroid_norms.data());
211
+ }
212
+ }
213
+
214
+ void ResidualCoarseQuantizer::search(
215
+ idx_t n,
216
+ const float* x,
217
+ idx_t k,
218
+ float* distances,
219
+ idx_t* labels) const {
220
+ if (beam_factor < 0) {
221
+ if (metric_type == METRIC_INNER_PRODUCT) {
222
+ rq.knn_exact_inner_product(n, x, k, distances, labels);
223
+ } else if (metric_type == METRIC_L2) {
224
+ FAISS_THROW_IF_NOT(centroid_norms.size() == ntotal);
225
+ rq.knn_exact_L2(n, x, k, distances, labels, centroid_norms.data());
226
+ }
227
+ return;
228
+ }
229
+
230
+ int beam_size = int(k * beam_factor);
231
+
232
+ size_t memory_per_point = rq.memory_per_point(beam_size);
233
+
234
+ /*
235
+
236
+ printf("mem per point %ld n=%d max_mem_distance=%ld mem_kb=%zd\n",
237
+ memory_per_point, int(n), rq.max_mem_distances, get_mem_usage_kb());
238
+ */
239
+ if (n > 1 && memory_per_point * n > rq.max_mem_distances) {
240
+ // then split queries to reduce temp memory
241
+ idx_t bs = rq.max_mem_distances / memory_per_point;
242
+ if (bs == 0) {
243
+ bs = 1; // otherwise we can't do much
244
+ }
245
+ if (verbose) {
246
+ printf("ResidualCoarseQuantizer::search: run %d searches in batches of size %d\n",
247
+ int(n),
248
+ int(bs));
249
+ }
250
+ for (idx_t i0 = 0; i0 < n; i0 += bs) {
251
+ idx_t i1 = std::min(n, i0 + bs);
252
+ search(i1 - i0, x + i0 * d, k, distances + i0 * k, labels + i0 * k);
253
+ InterruptCallback::check();
254
+ }
255
+ return;
256
+ }
257
+
258
+ std::vector<int32_t> codes(beam_size * rq.M * n);
259
+ std::vector<float> beam_distances(n * beam_size);
260
+
261
+ rq.refine_beam(
262
+ n, 1, x, beam_size, codes.data(), nullptr, beam_distances.data());
263
+
264
+ #pragma omp parallel for if (n > 4000)
265
+ for (idx_t i = 0; i < n; i++) {
266
+ memcpy(distances + i * k,
267
+ beam_distances.data() + beam_size * i,
268
+ k * sizeof(distances[0]));
269
+
270
+ const int32_t* codes_i = codes.data() + beam_size * i * rq.M;
271
+ for (idx_t j = 0; j < k; j++) {
272
+ idx_t l = 0;
273
+ int shift = 0;
274
+ for (int m = 0; m < rq.M; m++) {
275
+ l |= (*codes_i++) << shift;
276
+ shift += rq.nbits[m];
277
+ }
278
+ labels[i * k + j] = l;
279
+ }
280
+ }
281
+ }
282
+
283
+ void ResidualCoarseQuantizer::reconstruct(idx_t key, float* recons) const {
284
+ rq.decode_64bit(key, recons);
285
+ }
286
+
287
+ void ResidualCoarseQuantizer::reset() {
288
+ FAISS_THROW_MSG("not applicable");
289
+ }
290
+
291
+ } // namespace faiss
@@ -0,0 +1,152 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #ifndef FAISS_INDEX_RESIDUAL_H
9
+ #define FAISS_INDEX_RESIDUAL_H
10
+
11
+ #include <stdint.h>
12
+
13
+ #include <vector>
14
+
15
+ #include <faiss/Index.h>
16
+ #include <faiss/impl/ResidualQuantizer.h>
17
+ #include <faiss/impl/platform_macros.h>
18
+
19
+ namespace faiss {
20
+
21
+ /** Index based on a residual quantizer. Stored vectors are
22
+ * approximated by residual quantization codes.
23
+ * Can also be used as a codec
24
+ */
25
+ struct IndexResidual : Index {
26
+ /// The residual quantizer used to encode the vectors
27
+ ResidualQuantizer rq;
28
+
29
+ enum Search_type_t {
30
+ ST_decompress, ///< decompress database vector
31
+ ST_LUT_nonorm, ///< use a LUT, don't include norms (OK for IP or
32
+ ///< normalized vectors)
33
+ ST_norm_float, ///< use a LUT, and store float32 norm with the vectors
34
+ ST_norm_qint8, ///< use a LUT, and store 8bit-quantized norm
35
+ };
36
+ Search_type_t search_type;
37
+
38
+ /// min/max for quantization of norms
39
+ float norm_min, norm_max;
40
+
41
+ /// size of residual quantizer codes + norms
42
+ size_t code_size;
43
+
44
+ /// Codes. Size ntotal * rq.code_size
45
+ std::vector<uint8_t> codes;
46
+
47
+ /** Constructor.
48
+ *
49
+ * @param d dimensionality of the input vectors
50
+ * @param M number of subquantizers
51
+ * @param nbits number of bit per subvector index
52
+ */
53
+ IndexResidual(
54
+ int d, ///< dimensionality of the input vectors
55
+ size_t M, ///< number of subquantizers
56
+ size_t nbits, ///< number of bit per subvector index
57
+ MetricType metric = METRIC_L2,
58
+ Search_type_t search_type = ST_decompress);
59
+
60
+ IndexResidual(
61
+ int d,
62
+ const std::vector<size_t>& nbits,
63
+ MetricType metric = METRIC_L2,
64
+ Search_type_t search_type = ST_decompress);
65
+
66
+ IndexResidual();
67
+
68
+ /// set search type and update parameters
69
+ void set_search_type(Search_type_t search_type);
70
+
71
+ void train(idx_t n, const float* x) override;
72
+
73
+ void add(idx_t n, const float* x) override;
74
+
75
+ /// not implemented
76
+ void search(
77
+ idx_t n,
78
+ const float* x,
79
+ idx_t k,
80
+ float* distances,
81
+ idx_t* labels) const override;
82
+
83
+ void reset() override;
84
+
85
+ /* The standalone codec interface */
86
+ size_t sa_code_size() const override;
87
+
88
+ void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
89
+
90
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
91
+
92
+ // DistanceComputer* get_distance_computer() const override;
93
+ };
94
+
95
+ /** A "virtual" index where the elements are the residual quantizer centroids.
96
+ *
97
+ * Intended for use as a coarse quantizer in an IndexIVF.
98
+ */
99
+ struct ResidualCoarseQuantizer : Index {
100
+ /// The residual quantizer used to encode the vectors
101
+ ResidualQuantizer rq;
102
+
103
+ /// factor between the beam size and the search k
104
+ /// if negative, use exact search-to-centroid
105
+ float beam_factor;
106
+
107
+ /// norms of centroids, useful for knn-search
108
+ std::vector<float> centroid_norms;
109
+
110
+ /// computes centroid norms if required
111
+ void set_beam_factor(float new_beam_factor);
112
+
113
+ /** Constructor.
114
+ *
115
+ * @param d dimensionality of the input vectors
116
+ * @param M number of subquantizers
117
+ * @param nbits number of bit per subvector index
118
+ */
119
+ ResidualCoarseQuantizer(
120
+ int d, ///< dimensionality of the input vectors
121
+ size_t M, ///< number of subquantizers
122
+ size_t nbits, ///< number of bit per subvector index
123
+ MetricType metric = METRIC_L2);
124
+
125
+ ResidualCoarseQuantizer(
126
+ int d,
127
+ const std::vector<size_t>& nbits,
128
+ MetricType metric = METRIC_L2);
129
+
130
+ ResidualCoarseQuantizer();
131
+
132
+ void train(idx_t n, const float* x) override;
133
+
134
+ /// N/A
135
+ void add(idx_t n, const float* x) override;
136
+
137
+ void search(
138
+ idx_t n,
139
+ const float* x,
140
+ idx_t k,
141
+ float* distances,
142
+ idx_t* labels) const override;
143
+
144
+ void reconstruct(idx_t key, float* recons) const override;
145
+
146
+ /// N/A
147
+ void reset() override;
148
+ };
149
+
150
+ } // namespace faiss
151
+
152
+ #endif