faiss 0.2.0 → 0.2.4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -0,0 +1,407 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // quiet the noise
9
+ // clang-format off
10
+
11
+ #include <faiss/IndexAdditiveQuantizer.h>
12
+
13
+ #include <algorithm>
14
+ #include <cmath>
15
+ #include <cstring>
16
+
17
+ #include <faiss/impl/FaissAssert.h>
18
+ #include <faiss/impl/ResidualQuantizer.h>
19
+ #include <faiss/impl/ResultHandler.h>
20
+ #include <faiss/utils/distances.h>
21
+ #include <faiss/utils/extra_distances.h>
22
+ #include <faiss/utils/utils.h>
23
+
24
+
25
+ namespace faiss {
26
+
27
+ /**************************************************************************************
28
+ * IndexAdditiveQuantizer
29
+ **************************************************************************************/
30
+
31
+ IndexAdditiveQuantizer::IndexAdditiveQuantizer(
32
+ idx_t d,
33
+ AdditiveQuantizer* aq,
34
+ MetricType metric):
35
+ IndexFlatCodes(aq->code_size, d, metric), aq(aq)
36
+ {
37
+ FAISS_THROW_IF_NOT(metric == METRIC_INNER_PRODUCT || metric == METRIC_L2);
38
+ }
39
+
40
+
41
+ namespace {
42
+
43
+ template <class VectorDistance, class ResultHandler>
44
+ void search_with_decompress(
45
+ const IndexAdditiveQuantizer& ir,
46
+ const float* xq,
47
+ VectorDistance& vd,
48
+ ResultHandler& res) {
49
+ const uint8_t* codes = ir.codes.data();
50
+ size_t ntotal = ir.ntotal;
51
+ size_t code_size = ir.code_size;
52
+ const AdditiveQuantizer *aq = ir.aq;
53
+
54
+ using SingleResultHandler = typename ResultHandler::SingleResultHandler;
55
+
56
+ #pragma omp parallel for if(res.nq > 100)
57
+ for (int64_t q = 0; q < res.nq; q++) {
58
+ SingleResultHandler resi(res);
59
+ resi.begin(q);
60
+ std::vector<float> tmp(ir.d);
61
+ const float* x = xq + ir.d * q;
62
+ for (size_t i = 0; i < ntotal; i++) {
63
+ aq->decode(codes + i * code_size, tmp.data(), 1);
64
+ float dis = vd(x, tmp.data());
65
+ resi.add_result(dis, i);
66
+ }
67
+ resi.end();
68
+ }
69
+ }
70
+
71
+ template<bool is_IP, AdditiveQuantizer::Search_type_t st, class ResultHandler>
72
+ void search_with_LUT(
73
+ const IndexAdditiveQuantizer& ir,
74
+ const float* xq,
75
+ ResultHandler& res)
76
+ {
77
+ const AdditiveQuantizer & aq = *ir.aq;
78
+ const uint8_t* codes = ir.codes.data();
79
+ size_t ntotal = ir.ntotal;
80
+ size_t code_size = aq.code_size;
81
+ size_t nq = res.nq;
82
+ size_t d = ir.d;
83
+
84
+ using SingleResultHandler = typename ResultHandler::SingleResultHandler;
85
+ std::unique_ptr<float []> LUT(new float[nq * aq.total_codebook_size]);
86
+
87
+ aq.compute_LUT(nq, xq, LUT.get());
88
+
89
+ #pragma omp parallel for if(nq > 100)
90
+ for (int64_t q = 0; q < nq; q++) {
91
+ SingleResultHandler resi(res);
92
+ resi.begin(q);
93
+ std::vector<float> tmp(aq.d);
94
+ const float *LUT_q = LUT.get() + aq.total_codebook_size * q;
95
+ float bias = 0;
96
+ if (!is_IP) { // the LUT function returns ||y||^2 - 2 * <x, y>, need to add ||x||^2
97
+ bias = fvec_norm_L2sqr(xq + q * d, d);
98
+ }
99
+ for (size_t i = 0; i < ntotal; i++) {
100
+ float dis = aq.compute_1_distance_LUT<is_IP, st>(
101
+ codes + i * code_size,
102
+ LUT_q
103
+ );
104
+ resi.add_result(dis + bias, i);
105
+ }
106
+ resi.end();
107
+ }
108
+
109
+ }
110
+
111
+
112
+ } // anonymous namespace
113
+
114
+ void IndexAdditiveQuantizer::search(
115
+ idx_t n,
116
+ const float* x,
117
+ idx_t k,
118
+ float* distances,
119
+ idx_t* labels) const {
120
+ if (aq->search_type == AdditiveQuantizer::ST_decompress) {
121
+ if (metric_type == METRIC_L2) {
122
+ using VD = VectorDistance<METRIC_L2>;
123
+ VD vd = {size_t(d), metric_arg};
124
+ HeapResultHandler<VD::C> rh(n, distances, labels, k);
125
+ search_with_decompress(*this, x, vd, rh);
126
+ } else if (metric_type == METRIC_INNER_PRODUCT) {
127
+ using VD = VectorDistance<METRIC_INNER_PRODUCT>;
128
+ VD vd = {size_t(d), metric_arg};
129
+ HeapResultHandler<VD::C> rh(n, distances, labels, k);
130
+ search_with_decompress(*this, x, vd, rh);
131
+ }
132
+ } else {
133
+ if (metric_type == METRIC_INNER_PRODUCT) {
134
+ HeapResultHandler<CMin<float, idx_t> > rh(n, distances, labels, k);
135
+ search_with_LUT<true, AdditiveQuantizer::ST_LUT_nonorm> (*this, x, rh);
136
+ } else {
137
+ HeapResultHandler<CMax<float, idx_t> > rh(n, distances, labels, k);
138
+
139
+ if (aq->search_type == AdditiveQuantizer::ST_norm_float) {
140
+ search_with_LUT<false, AdditiveQuantizer::ST_norm_float> (*this, x, rh);
141
+ } else if (aq->search_type == AdditiveQuantizer::ST_LUT_nonorm) {
142
+ search_with_LUT<false, AdditiveQuantizer::ST_norm_float> (*this, x, rh);
143
+ } else if (aq->search_type == AdditiveQuantizer::ST_norm_qint8) {
144
+ search_with_LUT<false, AdditiveQuantizer::ST_norm_qint8> (*this, x, rh);
145
+ } else if (aq->search_type == AdditiveQuantizer::ST_norm_qint4) {
146
+ search_with_LUT<false, AdditiveQuantizer::ST_norm_qint4> (*this, x, rh);
147
+ } else if (aq->search_type == AdditiveQuantizer::ST_norm_cqint8) {
148
+ search_with_LUT<false, AdditiveQuantizer::ST_norm_cqint8> (*this, x, rh);
149
+ } else if (aq->search_type == AdditiveQuantizer::ST_norm_cqint4) {
150
+ search_with_LUT<false, AdditiveQuantizer::ST_norm_cqint4> (*this, x, rh);
151
+ } else {
152
+ FAISS_THROW_FMT("search type %d not supported", aq->search_type);
153
+ }
154
+ }
155
+
156
+ }
157
+ }
158
+
159
+ void IndexAdditiveQuantizer::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
160
+ return aq->compute_codes(x, bytes, n);
161
+ }
162
+
163
+ void IndexAdditiveQuantizer::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
164
+ return aq->decode(bytes, x, n);
165
+ }
166
+
167
+
168
+
169
+
170
+ /**************************************************************************************
171
+ * IndexResidualQuantizer
172
+ **************************************************************************************/
173
+
174
+ IndexResidualQuantizer::IndexResidualQuantizer(
175
+ int d, ///< dimensionality of the input vectors
176
+ size_t M, ///< number of subquantizers
177
+ size_t nbits, ///< number of bit per subvector index
178
+ MetricType metric,
179
+ Search_type_t search_type)
180
+ : IndexResidualQuantizer(d, std::vector<size_t>(M, nbits), metric, search_type) {
181
+ }
182
+
183
+ IndexResidualQuantizer::IndexResidualQuantizer(
184
+ int d,
185
+ const std::vector<size_t>& nbits,
186
+ MetricType metric,
187
+ Search_type_t search_type)
188
+ : IndexAdditiveQuantizer(d, &rq, metric), rq(d, nbits, search_type) {
189
+ code_size = rq.code_size;
190
+ is_trained = false;
191
+ }
192
+
193
+ IndexResidualQuantizer::IndexResidualQuantizer() : IndexResidualQuantizer(0, 0, 0) {}
194
+
195
+ void IndexResidualQuantizer::train(idx_t n, const float* x) {
196
+ rq.train(n, x);
197
+ is_trained = true;
198
+ }
199
+
200
+
201
+ /**************************************************************************************
202
+ * IndexLocalSearchQuantizer
203
+ **************************************************************************************/
204
+
205
+ IndexLocalSearchQuantizer::IndexLocalSearchQuantizer(
206
+ int d,
207
+ size_t M, ///< number of subquantizers
208
+ size_t nbits, ///< number of bit per subvector index
209
+ MetricType metric,
210
+ Search_type_t search_type)
211
+ : IndexAdditiveQuantizer(d, &lsq, metric), lsq(d, M, nbits, search_type) {
212
+ code_size = lsq.code_size;
213
+ is_trained = false;
214
+ }
215
+
216
+ IndexLocalSearchQuantizer::IndexLocalSearchQuantizer() : IndexLocalSearchQuantizer(0, 0, 0) {}
217
+
218
+ void IndexLocalSearchQuantizer::train(idx_t n, const float* x) {
219
+ lsq.train(n, x);
220
+ is_trained = true;
221
+ }
222
+
223
+ /**************************************************************************************
224
+ * AdditiveCoarseQuantizer
225
+ **************************************************************************************/
226
+
227
+ AdditiveCoarseQuantizer::AdditiveCoarseQuantizer(
228
+ idx_t d,
229
+ AdditiveQuantizer* aq,
230
+ MetricType metric):
231
+ Index(d, metric), aq(aq)
232
+ {}
233
+
234
+ void AdditiveCoarseQuantizer::add(idx_t, const float*) {
235
+ FAISS_THROW_MSG("not applicable");
236
+ }
237
+
238
+ void AdditiveCoarseQuantizer::reconstruct(idx_t key, float* recons) const {
239
+ aq->decode_64bit(key, recons);
240
+ }
241
+
242
+ void AdditiveCoarseQuantizer::reset() {
243
+ FAISS_THROW_MSG("not applicable");
244
+ }
245
+
246
+
247
+ void AdditiveCoarseQuantizer::train(idx_t n, const float* x) {
248
+ if (verbose) {
249
+ printf("AdditiveCoarseQuantizer::train: training on %zd vectors\n", size_t(n));
250
+ }
251
+ aq->train(n, x);
252
+ is_trained = true;
253
+ ntotal = (idx_t)1 << aq->tot_bits;
254
+
255
+ if (metric_type == METRIC_L2) {
256
+ if (verbose) {
257
+ printf("AdditiveCoarseQuantizer::train: computing centroid norms for %zd centroids\n", size_t(ntotal));
258
+ }
259
+ // this is not necessary for the residualcoarsequantizer when
260
+ // using beam search. We'll see if the memory overhead is too high
261
+ centroid_norms.resize(ntotal);
262
+ aq->compute_centroid_norms(centroid_norms.data());
263
+ }
264
+ }
265
+
266
+ void AdditiveCoarseQuantizer::search(
267
+ idx_t n,
268
+ const float* x,
269
+ idx_t k,
270
+ float* distances,
271
+ idx_t* labels) const {
272
+ if (metric_type == METRIC_INNER_PRODUCT) {
273
+ aq->knn_centroids_inner_product(n, x, k, distances, labels);
274
+ } else if (metric_type == METRIC_L2) {
275
+ FAISS_THROW_IF_NOT(centroid_norms.size() == ntotal);
276
+ aq->knn_centroids_L2(
277
+ n, x, k, distances, labels, centroid_norms.data());
278
+ }
279
+ }
280
+
281
+ /**************************************************************************************
282
+ * ResidualCoarseQuantizer
283
+ **************************************************************************************/
284
+
285
+ ResidualCoarseQuantizer::ResidualCoarseQuantizer(
286
+ int d, ///< dimensionality of the input vectors
287
+ const std::vector<size_t>& nbits,
288
+ MetricType metric)
289
+ : AdditiveCoarseQuantizer(d, &rq, metric), rq(d, nbits), beam_factor(4.0) {
290
+ FAISS_THROW_IF_NOT(rq.tot_bits <= 63);
291
+ is_trained = false;
292
+ }
293
+
294
+ ResidualCoarseQuantizer::ResidualCoarseQuantizer(
295
+ int d,
296
+ size_t M, ///< number of subquantizers
297
+ size_t nbits, ///< number of bit per subvector index
298
+ MetricType metric)
299
+ : ResidualCoarseQuantizer(d, std::vector<size_t>(M, nbits), metric) {}
300
+
301
+ ResidualCoarseQuantizer::ResidualCoarseQuantizer(): ResidualCoarseQuantizer(0, 0, 0) {}
302
+
303
+
304
+
305
+ void ResidualCoarseQuantizer::set_beam_factor(float new_beam_factor) {
306
+ beam_factor = new_beam_factor;
307
+ if (new_beam_factor > 0) {
308
+ FAISS_THROW_IF_NOT(new_beam_factor >= 1.0);
309
+ return;
310
+ } else if (metric_type == METRIC_L2 && ntotal != centroid_norms.size()) {
311
+ if (verbose) {
312
+ printf("AdditiveCoarseQuantizer::train: computing centroid norms for %zd centroids\n", size_t(ntotal));
313
+ }
314
+ centroid_norms.resize(ntotal);
315
+ aq->compute_centroid_norms(centroid_norms.data());
316
+ }
317
+ }
318
+
319
+ void ResidualCoarseQuantizer::search(
320
+ idx_t n,
321
+ const float* x,
322
+ idx_t k,
323
+ float* distances,
324
+ idx_t* labels) const {
325
+ if (beam_factor < 0) {
326
+ AdditiveCoarseQuantizer::search(n, x, k, distances, labels);
327
+ return;
328
+ }
329
+
330
+ int beam_size = int(k * beam_factor);
331
+ if (beam_size > ntotal) {
332
+ beam_size = ntotal;
333
+ }
334
+ size_t memory_per_point = rq.memory_per_point(beam_size);
335
+
336
+ /*
337
+
338
+ printf("mem per point %ld n=%d max_mem_distance=%ld mem_kb=%zd\n",
339
+ memory_per_point, int(n), rq.max_mem_distances, get_mem_usage_kb());
340
+ */
341
+ if (n > 1 && memory_per_point * n > rq.max_mem_distances) {
342
+ // then split queries to reduce temp memory
343
+ idx_t bs = rq.max_mem_distances / memory_per_point;
344
+ if (bs == 0) {
345
+ bs = 1; // otherwise we can't do much
346
+ }
347
+ if (verbose) {
348
+ printf("ResidualCoarseQuantizer::search: run %d searches in batches of size %d\n",
349
+ int(n),
350
+ int(bs));
351
+ }
352
+ for (idx_t i0 = 0; i0 < n; i0 += bs) {
353
+ idx_t i1 = std::min(n, i0 + bs);
354
+ search(i1 - i0, x + i0 * d, k, distances + i0 * k, labels + i0 * k);
355
+ InterruptCallback::check();
356
+ }
357
+ return;
358
+ }
359
+
360
+ std::vector<int32_t> codes(beam_size * rq.M * n);
361
+ std::vector<float> beam_distances(n * beam_size);
362
+
363
+ rq.refine_beam(
364
+ n, 1, x, beam_size, codes.data(), nullptr, beam_distances.data());
365
+
366
+ #pragma omp parallel for if (n > 4000)
367
+ for (idx_t i = 0; i < n; i++) {
368
+ memcpy(distances + i * k,
369
+ beam_distances.data() + beam_size * i,
370
+ k * sizeof(distances[0]));
371
+
372
+ const int32_t* codes_i = codes.data() + beam_size * i * rq.M;
373
+ for (idx_t j = 0; j < k; j++) {
374
+ idx_t l = 0;
375
+ int shift = 0;
376
+ for (int m = 0; m < rq.M; m++) {
377
+ l |= (*codes_i++) << shift;
378
+ shift += rq.nbits[m];
379
+ }
380
+ labels[i * k + j] = l;
381
+ }
382
+ }
383
+ }
384
+
385
+ /**************************************************************************************
386
+ * LocalSearchCoarseQuantizer
387
+ **************************************************************************************/
388
+
389
+ LocalSearchCoarseQuantizer::LocalSearchCoarseQuantizer(
390
+ int d, ///< dimensionality of the input vectors
391
+ size_t M, ///< number of subquantizers
392
+ size_t nbits, ///< number of bit per subvector index
393
+ MetricType metric)
394
+ : AdditiveCoarseQuantizer(d, &lsq, metric), lsq(d, M, nbits) {
395
+ FAISS_THROW_IF_NOT(lsq.tot_bits <= 63);
396
+ is_trained = false;
397
+ }
398
+
399
+
400
+ LocalSearchCoarseQuantizer::LocalSearchCoarseQuantizer() {
401
+ aq = &lsq;
402
+ }
403
+
404
+
405
+
406
+
407
+ } // namespace faiss
@@ -0,0 +1,195 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #ifndef FAISS_INDEX_ADDITIVE_QUANTIZER_H
9
+ #define FAISS_INDEX_ADDITIVE_QUANTIZER_H
10
+
11
+ #include <faiss/impl/AdditiveQuantizer.h>
12
+
13
+ #include <cstdint>
14
+ #include <vector>
15
+
16
+ #include <faiss/IndexFlatCodes.h>
17
+ #include <faiss/impl/LocalSearchQuantizer.h>
18
+ #include <faiss/impl/ResidualQuantizer.h>
19
+ #include <faiss/impl/platform_macros.h>
20
+
21
+ namespace faiss {
22
+
23
+ /// Abstract class for additive quantizers. The search functions are in common.
24
+ struct IndexAdditiveQuantizer : IndexFlatCodes {
25
+ // the quantizer, this points to the relevant field in the inheriting
26
+ // classes
27
+ AdditiveQuantizer* aq;
28
+ using Search_type_t = AdditiveQuantizer::Search_type_t;
29
+
30
+ explicit IndexAdditiveQuantizer(
31
+ idx_t d = 0,
32
+ AdditiveQuantizer* aq = nullptr,
33
+ MetricType metric = METRIC_L2);
34
+
35
+ void search(
36
+ idx_t n,
37
+ const float* x,
38
+ idx_t k,
39
+ float* distances,
40
+ idx_t* labels) const override;
41
+
42
+ /* The standalone codec interface */
43
+ void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
44
+
45
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
46
+ };
47
+
48
+ /** Index based on a residual quantizer. Stored vectors are
49
+ * approximated by residual quantization codes.
50
+ * Can also be used as a codec
51
+ */
52
+ struct IndexResidualQuantizer : IndexAdditiveQuantizer {
53
+ /// The residual quantizer used to encode the vectors
54
+ ResidualQuantizer rq;
55
+
56
+ /** Constructor.
57
+ *
58
+ * @param d dimensionality of the input vectors
59
+ * @param M number of subquantizers
60
+ * @param nbits number of bit per subvector index
61
+ */
62
+ IndexResidualQuantizer(
63
+ int d, ///< dimensionality of the input vectors
64
+ size_t M, ///< number of subquantizers
65
+ size_t nbits, ///< number of bit per subvector index
66
+ MetricType metric = METRIC_L2,
67
+ Search_type_t search_type = AdditiveQuantizer::ST_decompress);
68
+
69
+ IndexResidualQuantizer(
70
+ int d,
71
+ const std::vector<size_t>& nbits,
72
+ MetricType metric = METRIC_L2,
73
+ Search_type_t search_type = AdditiveQuantizer::ST_decompress);
74
+
75
+ IndexResidualQuantizer();
76
+
77
+ void train(idx_t n, const float* x) override;
78
+ };
79
+
80
+ struct IndexLocalSearchQuantizer : IndexAdditiveQuantizer {
81
+ LocalSearchQuantizer lsq;
82
+
83
+ /** Constructor.
84
+ *
85
+ * @param d dimensionality of the input vectors
86
+ * @param M number of subquantizers
87
+ * @param nbits number of bit per subvector index
88
+ */
89
+ IndexLocalSearchQuantizer(
90
+ int d, ///< dimensionality of the input vectors
91
+ size_t M, ///< number of subquantizers
92
+ size_t nbits, ///< number of bit per subvector index
93
+ MetricType metric = METRIC_L2,
94
+ Search_type_t search_type = AdditiveQuantizer::ST_decompress);
95
+
96
+ IndexLocalSearchQuantizer();
97
+
98
+ void train(idx_t n, const float* x) override;
99
+ };
100
+
101
+ /** A "virtual" index where the elements are the residual quantizer centroids.
102
+ *
103
+ * Intended for use as a coarse quantizer in an IndexIVF.
104
+ */
105
+ struct AdditiveCoarseQuantizer : Index {
106
+ AdditiveQuantizer* aq;
107
+
108
+ explicit AdditiveCoarseQuantizer(
109
+ idx_t d = 0,
110
+ AdditiveQuantizer* aq = nullptr,
111
+ MetricType metric = METRIC_L2);
112
+
113
+ /// norms of centroids, useful for knn-search
114
+ std::vector<float> centroid_norms;
115
+
116
+ /// N/A
117
+ void add(idx_t n, const float* x) override;
118
+
119
+ void search(
120
+ idx_t n,
121
+ const float* x,
122
+ idx_t k,
123
+ float* distances,
124
+ idx_t* labels) const override;
125
+
126
+ void reconstruct(idx_t key, float* recons) const override;
127
+ void train(idx_t n, const float* x) override;
128
+
129
+ /// N/A
130
+ void reset() override;
131
+ };
132
+
133
+ /** The ResidualCoarseQuantizer is a bit specialized compared to the
134
+ * default AdditiveCoarseQuantizer because it can use a beam search
135
+ * at search time (slow but may be useful for very large vocabularies) */
136
+ struct ResidualCoarseQuantizer : AdditiveCoarseQuantizer {
137
+ /// The residual quantizer used to encode the vectors
138
+ ResidualQuantizer rq;
139
+
140
+ /// factor between the beam size and the search k
141
+ /// if negative, use exact search-to-centroid
142
+ float beam_factor;
143
+
144
+ /// computes centroid norms if required
145
+ void set_beam_factor(float new_beam_factor);
146
+
147
+ /** Constructor.
148
+ *
149
+ * @param d dimensionality of the input vectors
150
+ * @param M number of subquantizers
151
+ * @param nbits number of bit per subvector index
152
+ */
153
+ ResidualCoarseQuantizer(
154
+ int d, ///< dimensionality of the input vectors
155
+ size_t M, ///< number of subquantizers
156
+ size_t nbits, ///< number of bit per subvector index
157
+ MetricType metric = METRIC_L2);
158
+
159
+ ResidualCoarseQuantizer(
160
+ int d,
161
+ const std::vector<size_t>& nbits,
162
+ MetricType metric = METRIC_L2);
163
+
164
+ void search(
165
+ idx_t n,
166
+ const float* x,
167
+ idx_t k,
168
+ float* distances,
169
+ idx_t* labels) const override;
170
+
171
+ ResidualCoarseQuantizer();
172
+ };
173
+
174
+ struct LocalSearchCoarseQuantizer : AdditiveCoarseQuantizer {
175
+ /// The residual quantizer used to encode the vectors
176
+ LocalSearchQuantizer lsq;
177
+
178
+ /** Constructor.
179
+ *
180
+ * @param d dimensionality of the input vectors
181
+ * @param M number of subquantizers
182
+ * @param nbits number of bit per subvector index
183
+ */
184
+ LocalSearchCoarseQuantizer(
185
+ int d, ///< dimensionality of the input vectors
186
+ size_t M, ///< number of subquantizers
187
+ size_t nbits, ///< number of bit per subvector index
188
+ MetricType metric = METRIC_L2);
189
+
190
+ LocalSearchCoarseQuantizer();
191
+ };
192
+
193
+ } // namespace faiss
194
+
195
+ #endif