faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -0,0 +1,316 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // quiet the noise
9
+ // XXclang-format off
10
+
11
+ #include <faiss/IndexIVFAdditiveQuantizer.h>
12
+
13
+ #include <algorithm>
14
+ #include <cmath>
15
+ #include <cstring>
16
+
17
+ #include <faiss/impl/FaissAssert.h>
18
+ #include <faiss/impl/ResidualQuantizer.h>
19
+ #include <faiss/impl/ResultHandler.h>
20
+ #include <faiss/utils/distances.h>
21
+ #include <faiss/utils/extra_distances.h>
22
+ #include <faiss/utils/utils.h>
23
+
24
+ namespace faiss {
25
+
26
+ /**************************************************************************************
27
+ * IndexIVFAdditiveQuantizer
28
+ **************************************************************************************/
29
+
30
+ IndexIVFAdditiveQuantizer::IndexIVFAdditiveQuantizer(
31
+ AdditiveQuantizer* aq,
32
+ Index* quantizer,
33
+ size_t d,
34
+ size_t nlist,
35
+ MetricType metric)
36
+ : IndexIVF(quantizer, d, nlist, 0, metric), aq(aq) {
37
+ by_residual = true;
38
+ }
39
+
40
+ IndexIVFAdditiveQuantizer::IndexIVFAdditiveQuantizer(AdditiveQuantizer* aq)
41
+ : IndexIVF(), aq(aq) {}
42
+
43
+ void IndexIVFAdditiveQuantizer::train_residual(idx_t n, const float* x) {
44
+ const float* x_in = x;
45
+
46
+ size_t max_train_points = 1024 * ((size_t)1 << aq->nbits[0]);
47
+
48
+ x = fvecs_maybe_subsample(
49
+ d, (size_t*)&n, max_train_points, x, verbose, 1234);
50
+ ScopeDeleter1<float> del_x(x_in == x ? nullptr : x);
51
+
52
+ if (by_residual) {
53
+ std::vector<Index::idx_t> idx(n);
54
+ quantizer->assign(n, x, idx.data());
55
+
56
+ std::vector<float> residuals(n * d);
57
+ quantizer->compute_residual_n(n, x, residuals.data(), idx.data());
58
+
59
+ aq->train(n, residuals.data());
60
+ } else {
61
+ aq->train(n, x);
62
+ }
63
+ }
64
+
65
+ void IndexIVFAdditiveQuantizer::encode_vectors(
66
+ idx_t n,
67
+ const float* x,
68
+ const idx_t* list_nos,
69
+ uint8_t* codes,
70
+ bool include_listnos) const {
71
+ FAISS_THROW_IF_NOT(is_trained);
72
+
73
+ // first encode then possibly add listnos
74
+
75
+ if (by_residual) {
76
+ // subtract centroids
77
+ std::vector<float> residuals(n * d);
78
+
79
+ #pragma omp parallel if (n > 10000)
80
+ for (idx_t i = 0; i < n; i++) {
81
+ quantizer->compute_residual(
82
+ x + i * d,
83
+ residuals.data() + i * d,
84
+ list_nos[i] >= 0 ? list_nos[i] : 0);
85
+ }
86
+ aq->compute_codes(residuals.data(), codes, n);
87
+ } else {
88
+ aq->compute_codes(x, codes, n);
89
+ }
90
+
91
+ if (include_listnos) {
92
+ // write back from the end, where there is enough space
93
+ size_t coarse_size = coarse_code_size();
94
+ for (idx_t i = n - 1; i >= 0; i--) {
95
+ uint8_t* code = codes + i * (code_size + coarse_size);
96
+ memmove(code + coarse_size, codes + i * code_size, code_size);
97
+ encode_listno(list_nos[i], code);
98
+ }
99
+ }
100
+ }
101
+
102
+ IndexIVFAdditiveQuantizer::~IndexIVFAdditiveQuantizer() {}
103
+
104
+ /*********************************************
105
+ * AQInvertedListScanner
106
+ *********************************************/
107
+
108
+ namespace {
109
+
110
+ using Search_type_t = AdditiveQuantizer::Search_type_t;
111
+
112
+ struct AQInvertedListScanner : InvertedListScanner {
113
+ const IndexIVFAdditiveQuantizer& ia;
114
+ const AdditiveQuantizer& aq;
115
+ std::vector<float> tmp;
116
+
117
+ AQInvertedListScanner(const IndexIVFAdditiveQuantizer& ia, bool store_pairs)
118
+ : ia(ia), aq(*ia.aq) {
119
+ this->store_pairs = store_pairs;
120
+ this->code_size = ia.code_size;
121
+ keep_max = ia.metric_type == METRIC_INNER_PRODUCT;
122
+ tmp.resize(ia.d);
123
+ }
124
+
125
+ const float* q0;
126
+
127
+ /// from now on we handle this query.
128
+ void set_query(const float* query_vector) override {
129
+ q0 = query_vector;
130
+ }
131
+
132
+ const float* q;
133
+ /// following codes come from this inverted list
134
+ void set_list(idx_t list_no, float coarse_dis) override {
135
+ if (ia.metric_type == METRIC_L2 && ia.by_residual) {
136
+ ia.quantizer->compute_residual(q0, tmp.data(), list_no);
137
+ q = tmp.data();
138
+ } else {
139
+ q = q0;
140
+ }
141
+ }
142
+
143
+ ~AQInvertedListScanner() {}
144
+ };
145
+
146
+ template <bool is_IP>
147
+ struct AQInvertedListScannerDecompress : AQInvertedListScanner {
148
+ AQInvertedListScannerDecompress(
149
+ const IndexIVFAdditiveQuantizer& ia,
150
+ bool store_pairs)
151
+ : AQInvertedListScanner(ia, store_pairs) {}
152
+
153
+ float coarse_dis = 0;
154
+
155
+ /// following codes come from this inverted list
156
+ void set_list(idx_t list_no, float coarse_dis) override {
157
+ AQInvertedListScanner::set_list(list_no, coarse_dis);
158
+ if (ia.by_residual) {
159
+ this->coarse_dis = coarse_dis;
160
+ }
161
+ }
162
+
163
+ /// compute a single query-to-code distance
164
+ float distance_to_code(const uint8_t* code) const final {
165
+ std::vector<float> b(aq.d);
166
+ aq.decode(code, b.data(), 1);
167
+ FAISS_ASSERT(q);
168
+ FAISS_ASSERT(b.data());
169
+
170
+ return is_IP ? coarse_dis + fvec_inner_product(q, b.data(), aq.d)
171
+ : fvec_L2sqr(q, b.data(), aq.d);
172
+ }
173
+
174
+ ~AQInvertedListScannerDecompress() override {}
175
+ };
176
+
177
+ template <bool is_IP, Search_type_t search_type>
178
+ struct AQInvertedListScannerLUT : AQInvertedListScanner {
179
+ std::vector<float> LUT, tmp;
180
+ float distance_bias;
181
+
182
+ AQInvertedListScannerLUT(
183
+ const IndexIVFAdditiveQuantizer& ia,
184
+ bool store_pairs)
185
+ : AQInvertedListScanner(ia, store_pairs) {
186
+ LUT.resize(aq.total_codebook_size);
187
+ tmp.resize(ia.d);
188
+ distance_bias = 0;
189
+ }
190
+
191
+ /// from now on we handle this query.
192
+ void set_query(const float* query_vector) override {
193
+ AQInvertedListScanner::set_query(query_vector);
194
+ if (!is_IP && !ia.by_residual) {
195
+ distance_bias = fvec_norm_L2sqr(query_vector, ia.d);
196
+ }
197
+ }
198
+
199
+ /// following codes come from this inverted list
200
+ void set_list(idx_t list_no, float coarse_dis) override {
201
+ AQInvertedListScanner::set_list(list_no, coarse_dis);
202
+ // TODO find a way to provide the nprobes together to do a matmul
203
+ // + precompute tables
204
+ aq.compute_LUT(1, q, LUT.data());
205
+
206
+ if (ia.by_residual) {
207
+ distance_bias = coarse_dis;
208
+ }
209
+ }
210
+
211
+ /// compute a single query-to-code distance
212
+ float distance_to_code(const uint8_t* code) const final {
213
+ return distance_bias +
214
+ aq.compute_1_distance_LUT<is_IP, search_type>(code, LUT.data());
215
+ }
216
+
217
+ ~AQInvertedListScannerLUT() override {}
218
+ };
219
+
220
+ } // anonymous namespace
221
+
222
+ InvertedListScanner* IndexIVFAdditiveQuantizer::get_InvertedListScanner(
223
+ bool store_pairs) const {
224
+ if (metric_type == METRIC_INNER_PRODUCT) {
225
+ if (aq->search_type == AdditiveQuantizer::ST_decompress) {
226
+ return new AQInvertedListScannerDecompress<true>(
227
+ *this, store_pairs);
228
+ } else {
229
+ return new AQInvertedListScannerLUT<
230
+ true,
231
+ AdditiveQuantizer::ST_LUT_nonorm>(*this, store_pairs);
232
+ }
233
+ } else {
234
+ switch (aq->search_type) {
235
+ case AdditiveQuantizer::ST_decompress:
236
+ return new AQInvertedListScannerDecompress<false>(
237
+ *this, store_pairs);
238
+ #define A(st) \
239
+ case AdditiveQuantizer::st: \
240
+ return new AQInvertedListScannerLUT<false, AdditiveQuantizer::st>( \
241
+ *this, store_pairs);
242
+ A(ST_LUT_nonorm)
243
+ // A(ST_norm_from_LUT)
244
+ A(ST_norm_float)
245
+ A(ST_norm_qint8)
246
+ A(ST_norm_qint4)
247
+ A(ST_norm_cqint8)
248
+ A(ST_norm_cqint4)
249
+ #undef A
250
+ default:
251
+ FAISS_THROW_FMT(
252
+ "search type %d not supported", aq->search_type);
253
+ }
254
+ }
255
+ }
256
+
257
+ /**************************************************************************************
258
+ * IndexIVFResidualQuantizer
259
+ **************************************************************************************/
260
+
261
+ IndexIVFResidualQuantizer::IndexIVFResidualQuantizer(
262
+ Index* quantizer,
263
+ size_t d,
264
+ size_t nlist,
265
+ const std::vector<size_t>& nbits,
266
+ MetricType metric,
267
+ Search_type_t search_type)
268
+ : IndexIVFAdditiveQuantizer(&rq, quantizer, d, nlist, metric),
269
+ rq(d, nbits, search_type) {
270
+ code_size = invlists->code_size = rq.code_size;
271
+ }
272
+
273
+ IndexIVFResidualQuantizer::IndexIVFResidualQuantizer()
274
+ : IndexIVFAdditiveQuantizer(&rq) {}
275
+
276
+ IndexIVFResidualQuantizer::IndexIVFResidualQuantizer(
277
+ Index* quantizer,
278
+ size_t d,
279
+ size_t nlist,
280
+ size_t M, /* number of subquantizers */
281
+ size_t nbits, /* number of bit per subvector index */
282
+ MetricType metric,
283
+ Search_type_t search_type)
284
+ : IndexIVFResidualQuantizer(
285
+ quantizer,
286
+ d,
287
+ nlist,
288
+ std::vector<size_t>(M, nbits),
289
+ metric,
290
+ search_type) {}
291
+
292
+ IndexIVFResidualQuantizer::~IndexIVFResidualQuantizer() {}
293
+
294
+ /**************************************************************************************
295
+ * IndexIVFLocalSearchQuantizer
296
+ **************************************************************************************/
297
+
298
+ IndexIVFLocalSearchQuantizer::IndexIVFLocalSearchQuantizer(
299
+ Index* quantizer,
300
+ size_t d,
301
+ size_t nlist,
302
+ size_t M, /* number of subquantizers */
303
+ size_t nbits, /* number of bit per subvector index */
304
+ MetricType metric,
305
+ Search_type_t search_type)
306
+ : IndexIVFAdditiveQuantizer(&lsq, quantizer, d, nlist, metric),
307
+ lsq(d, M, nbits, search_type) {
308
+ code_size = invlists->code_size = lsq.code_size;
309
+ }
310
+
311
+ IndexIVFLocalSearchQuantizer::IndexIVFLocalSearchQuantizer()
312
+ : IndexIVFAdditiveQuantizer(&lsq) {}
313
+
314
+ IndexIVFLocalSearchQuantizer::~IndexIVFLocalSearchQuantizer() {}
315
+
316
+ } // namespace faiss
@@ -0,0 +1,121 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #ifndef FAISS_INDEX_IVF_ADDITIVE_QUANTIZER_H
9
+ #define FAISS_INDEX_IVF_ADDITIVE_QUANTIZER_H
10
+
11
+ #include <faiss/impl/AdditiveQuantizer.h>
12
+
13
+ #include <cstdint>
14
+ #include <vector>
15
+
16
+ #include <faiss/IndexIVF.h>
17
+ #include <faiss/impl/LocalSearchQuantizer.h>
18
+ #include <faiss/impl/ResidualQuantizer.h>
19
+ #include <faiss/impl/platform_macros.h>
20
+
21
+ namespace faiss {
22
+
23
+ /// Abstract class for IVF additive quantizers.
24
+ /// The search functions are in common.
25
+ struct IndexIVFAdditiveQuantizer : IndexIVF {
26
+ // the quantizer
27
+ AdditiveQuantizer* aq;
28
+ bool by_residual = true;
29
+ int use_precomputed_table = 0; // for future use
30
+
31
+ using Search_type_t = AdditiveQuantizer::Search_type_t;
32
+
33
+ IndexIVFAdditiveQuantizer(
34
+ AdditiveQuantizer* aq,
35
+ Index* quantizer,
36
+ size_t d,
37
+ size_t nlist,
38
+ MetricType metric = METRIC_L2);
39
+
40
+ explicit IndexIVFAdditiveQuantizer(AdditiveQuantizer* aq);
41
+
42
+ void train_residual(idx_t n, const float* x) override;
43
+
44
+ void encode_vectors(
45
+ idx_t n,
46
+ const float* x,
47
+ const idx_t* list_nos,
48
+ uint8_t* codes,
49
+ bool include_listnos = false) const override;
50
+
51
+ InvertedListScanner* get_InvertedListScanner(
52
+ bool store_pairs) const override;
53
+
54
+ ~IndexIVFAdditiveQuantizer() override;
55
+ };
56
+
57
+ /** IndexIVF based on a residual quantizer. Stored vectors are
58
+ * approximated by residual quantization codes.
59
+ */
60
+ struct IndexIVFResidualQuantizer : IndexIVFAdditiveQuantizer {
61
+ /// The residual quantizer used to encode the vectors
62
+ ResidualQuantizer rq;
63
+
64
+ /** Constructor.
65
+ *
66
+ * @param d dimensionality of the input vectors
67
+ * @param M number of subquantizers
68
+ * @param nbits number of bit per subvector index
69
+ */
70
+ IndexIVFResidualQuantizer(
71
+ Index* quantizer,
72
+ size_t d,
73
+ size_t nlist,
74
+ const std::vector<size_t>& nbits,
75
+ MetricType metric = METRIC_L2,
76
+ Search_type_t search_type = AdditiveQuantizer::ST_decompress);
77
+
78
+ IndexIVFResidualQuantizer(
79
+ Index* quantizer,
80
+ size_t d,
81
+ size_t nlist,
82
+ size_t M, /* number of subquantizers */
83
+ size_t nbits, /* number of bit per subvector index */
84
+ MetricType metric = METRIC_L2,
85
+ Search_type_t search_type = AdditiveQuantizer::ST_decompress);
86
+
87
+ IndexIVFResidualQuantizer();
88
+
89
+ virtual ~IndexIVFResidualQuantizer();
90
+ };
91
+
92
+ /** IndexIVF based on a residual quantizer. Stored vectors are
93
+ * approximated by residual quantization codes.
94
+ */
95
+ struct IndexIVFLocalSearchQuantizer : IndexIVFAdditiveQuantizer {
96
+ /// The LSQ quantizer used to encode the vectors
97
+ LocalSearchQuantizer lsq;
98
+
99
+ /** Constructor.
100
+ *
101
+ * @param d dimensionality of the input vectors
102
+ * @param M number of subquantizers
103
+ * @param nbits number of bit per subvector index
104
+ */
105
+ IndexIVFLocalSearchQuantizer(
106
+ Index* quantizer,
107
+ size_t d,
108
+ size_t nlist,
109
+ size_t M, /* number of subquantizers */
110
+ size_t nbits, /* number of bit per subvector index */
111
+ MetricType metric = METRIC_L2,
112
+ Search_type_t search_type = AdditiveQuantizer::ST_decompress);
113
+
114
+ IndexIVFLocalSearchQuantizer();
115
+
116
+ virtual ~IndexIVFLocalSearchQuantizer();
117
+ };
118
+
119
+ } // namespace faiss
120
+
121
+ #endif