faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -0,0 +1,407 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // quiet the noise
9
+ // clang-format off
10
+
11
+ #include <faiss/IndexAdditiveQuantizer.h>
12
+
13
+ #include <algorithm>
14
+ #include <cmath>
15
+ #include <cstring>
16
+
17
+ #include <faiss/impl/FaissAssert.h>
18
+ #include <faiss/impl/ResidualQuantizer.h>
19
+ #include <faiss/impl/ResultHandler.h>
20
+ #include <faiss/utils/distances.h>
21
+ #include <faiss/utils/extra_distances.h>
22
+ #include <faiss/utils/utils.h>
23
+
24
+
25
+ namespace faiss {
26
+
27
+ /**************************************************************************************
28
+ * IndexAdditiveQuantizer
29
+ **************************************************************************************/
30
+
31
+ IndexAdditiveQuantizer::IndexAdditiveQuantizer(
32
+ idx_t d,
33
+ AdditiveQuantizer* aq,
34
+ MetricType metric):
35
+ IndexFlatCodes(aq->code_size, d, metric), aq(aq)
36
+ {
37
+ FAISS_THROW_IF_NOT(metric == METRIC_INNER_PRODUCT || metric == METRIC_L2);
38
+ }
39
+
40
+
41
+ namespace {
42
+
43
+ template <class VectorDistance, class ResultHandler>
44
+ void search_with_decompress(
45
+ const IndexAdditiveQuantizer& ir,
46
+ const float* xq,
47
+ VectorDistance& vd,
48
+ ResultHandler& res) {
49
+ const uint8_t* codes = ir.codes.data();
50
+ size_t ntotal = ir.ntotal;
51
+ size_t code_size = ir.code_size;
52
+ const AdditiveQuantizer *aq = ir.aq;
53
+
54
+ using SingleResultHandler = typename ResultHandler::SingleResultHandler;
55
+
56
+ #pragma omp parallel for if(res.nq > 100)
57
+ for (int64_t q = 0; q < res.nq; q++) {
58
+ SingleResultHandler resi(res);
59
+ resi.begin(q);
60
+ std::vector<float> tmp(ir.d);
61
+ const float* x = xq + ir.d * q;
62
+ for (size_t i = 0; i < ntotal; i++) {
63
+ aq->decode(codes + i * code_size, tmp.data(), 1);
64
+ float dis = vd(x, tmp.data());
65
+ resi.add_result(dis, i);
66
+ }
67
+ resi.end();
68
+ }
69
+ }
70
+
71
+ template<bool is_IP, AdditiveQuantizer::Search_type_t st, class ResultHandler>
72
+ void search_with_LUT(
73
+ const IndexAdditiveQuantizer& ir,
74
+ const float* xq,
75
+ ResultHandler& res)
76
+ {
77
+ const AdditiveQuantizer & aq = *ir.aq;
78
+ const uint8_t* codes = ir.codes.data();
79
+ size_t ntotal = ir.ntotal;
80
+ size_t code_size = aq.code_size;
81
+ size_t nq = res.nq;
82
+ size_t d = ir.d;
83
+
84
+ using SingleResultHandler = typename ResultHandler::SingleResultHandler;
85
+ std::unique_ptr<float []> LUT(new float[nq * aq.total_codebook_size]);
86
+
87
+ aq.compute_LUT(nq, xq, LUT.get());
88
+
89
+ #pragma omp parallel for if(nq > 100)
90
+ for (int64_t q = 0; q < nq; q++) {
91
+ SingleResultHandler resi(res);
92
+ resi.begin(q);
93
+ std::vector<float> tmp(aq.d);
94
+ const float *LUT_q = LUT.get() + aq.total_codebook_size * q;
95
+ float bias = 0;
96
+ if (!is_IP) { // the LUT function returns ||y||^2 - 2 * <x, y>, need to add ||x||^2
97
+ bias = fvec_norm_L2sqr(xq + q * d, d);
98
+ }
99
+ for (size_t i = 0; i < ntotal; i++) {
100
+ float dis = aq.compute_1_distance_LUT<is_IP, st>(
101
+ codes + i * code_size,
102
+ LUT_q
103
+ );
104
+ resi.add_result(dis + bias, i);
105
+ }
106
+ resi.end();
107
+ }
108
+
109
+ }
110
+
111
+
112
+ } // anonymous namespace
113
+
114
+ void IndexAdditiveQuantizer::search(
115
+ idx_t n,
116
+ const float* x,
117
+ idx_t k,
118
+ float* distances,
119
+ idx_t* labels) const {
120
+ if (aq->search_type == AdditiveQuantizer::ST_decompress) {
121
+ if (metric_type == METRIC_L2) {
122
+ using VD = VectorDistance<METRIC_L2>;
123
+ VD vd = {size_t(d), metric_arg};
124
+ HeapResultHandler<VD::C> rh(n, distances, labels, k);
125
+ search_with_decompress(*this, x, vd, rh);
126
+ } else if (metric_type == METRIC_INNER_PRODUCT) {
127
+ using VD = VectorDistance<METRIC_INNER_PRODUCT>;
128
+ VD vd = {size_t(d), metric_arg};
129
+ HeapResultHandler<VD::C> rh(n, distances, labels, k);
130
+ search_with_decompress(*this, x, vd, rh);
131
+ }
132
+ } else {
133
+ if (metric_type == METRIC_INNER_PRODUCT) {
134
+ HeapResultHandler<CMin<float, idx_t> > rh(n, distances, labels, k);
135
+ search_with_LUT<true, AdditiveQuantizer::ST_LUT_nonorm> (*this, x, rh);
136
+ } else {
137
+ HeapResultHandler<CMax<float, idx_t> > rh(n, distances, labels, k);
138
+
139
+ if (aq->search_type == AdditiveQuantizer::ST_norm_float) {
140
+ search_with_LUT<false, AdditiveQuantizer::ST_norm_float> (*this, x, rh);
141
+ } else if (aq->search_type == AdditiveQuantizer::ST_LUT_nonorm) {
142
+ search_with_LUT<false, AdditiveQuantizer::ST_norm_float> (*this, x, rh);
143
+ } else if (aq->search_type == AdditiveQuantizer::ST_norm_qint8) {
144
+ search_with_LUT<false, AdditiveQuantizer::ST_norm_qint8> (*this, x, rh);
145
+ } else if (aq->search_type == AdditiveQuantizer::ST_norm_qint4) {
146
+ search_with_LUT<false, AdditiveQuantizer::ST_norm_qint4> (*this, x, rh);
147
+ } else if (aq->search_type == AdditiveQuantizer::ST_norm_cqint8) {
148
+ search_with_LUT<false, AdditiveQuantizer::ST_norm_cqint8> (*this, x, rh);
149
+ } else if (aq->search_type == AdditiveQuantizer::ST_norm_cqint4) {
150
+ search_with_LUT<false, AdditiveQuantizer::ST_norm_cqint4> (*this, x, rh);
151
+ } else {
152
+ FAISS_THROW_FMT("search type %d not supported", aq->search_type);
153
+ }
154
+ }
155
+
156
+ }
157
+ }
158
+
159
+ void IndexAdditiveQuantizer::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
160
+ return aq->compute_codes(x, bytes, n);
161
+ }
162
+
163
+ void IndexAdditiveQuantizer::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
164
+ return aq->decode(bytes, x, n);
165
+ }
166
+
167
+
168
+
169
+
170
+ /**************************************************************************************
171
+ * IndexResidualQuantizer
172
+ **************************************************************************************/
173
+
174
+ IndexResidualQuantizer::IndexResidualQuantizer(
175
+ int d, ///< dimensionality of the input vectors
176
+ size_t M, ///< number of subquantizers
177
+ size_t nbits, ///< number of bit per subvector index
178
+ MetricType metric,
179
+ Search_type_t search_type)
180
+ : IndexResidualQuantizer(d, std::vector<size_t>(M, nbits), metric, search_type) {
181
+ }
182
+
183
+ IndexResidualQuantizer::IndexResidualQuantizer(
184
+ int d,
185
+ const std::vector<size_t>& nbits,
186
+ MetricType metric,
187
+ Search_type_t search_type)
188
+ : IndexAdditiveQuantizer(d, &rq, metric), rq(d, nbits, search_type) {
189
+ code_size = rq.code_size;
190
+ is_trained = false;
191
+ }
192
+
193
+ IndexResidualQuantizer::IndexResidualQuantizer() : IndexResidualQuantizer(0, 0, 0) {}
194
+
195
+ void IndexResidualQuantizer::train(idx_t n, const float* x) {
196
+ rq.train(n, x);
197
+ is_trained = true;
198
+ }
199
+
200
+
201
+ /**************************************************************************************
202
+ * IndexLocalSearchQuantizer
203
+ **************************************************************************************/
204
+
205
+ IndexLocalSearchQuantizer::IndexLocalSearchQuantizer(
206
+ int d,
207
+ size_t M, ///< number of subquantizers
208
+ size_t nbits, ///< number of bit per subvector index
209
+ MetricType metric,
210
+ Search_type_t search_type)
211
+ : IndexAdditiveQuantizer(d, &lsq, metric), lsq(d, M, nbits, search_type) {
212
+ code_size = lsq.code_size;
213
+ is_trained = false;
214
+ }
215
+
216
+ IndexLocalSearchQuantizer::IndexLocalSearchQuantizer() : IndexLocalSearchQuantizer(0, 0, 0) {}
217
+
218
+ void IndexLocalSearchQuantizer::train(idx_t n, const float* x) {
219
+ lsq.train(n, x);
220
+ is_trained = true;
221
+ }
222
+
223
+ /**************************************************************************************
224
+ * AdditiveCoarseQuantizer
225
+ **************************************************************************************/
226
+
227
+ AdditiveCoarseQuantizer::AdditiveCoarseQuantizer(
228
+ idx_t d,
229
+ AdditiveQuantizer* aq,
230
+ MetricType metric):
231
+ Index(d, metric), aq(aq)
232
+ {}
233
+
234
+ void AdditiveCoarseQuantizer::add(idx_t, const float*) {
235
+ FAISS_THROW_MSG("not applicable");
236
+ }
237
+
238
+ void AdditiveCoarseQuantizer::reconstruct(idx_t key, float* recons) const {
239
+ aq->decode_64bit(key, recons);
240
+ }
241
+
242
+ void AdditiveCoarseQuantizer::reset() {
243
+ FAISS_THROW_MSG("not applicable");
244
+ }
245
+
246
+
247
+ void AdditiveCoarseQuantizer::train(idx_t n, const float* x) {
248
+ if (verbose) {
249
+ printf("AdditiveCoarseQuantizer::train: training on %zd vectors\n", size_t(n));
250
+ }
251
+ aq->train(n, x);
252
+ is_trained = true;
253
+ ntotal = (idx_t)1 << aq->tot_bits;
254
+
255
+ if (metric_type == METRIC_L2) {
256
+ if (verbose) {
257
+ printf("AdditiveCoarseQuantizer::train: computing centroid norms for %zd centroids\n", size_t(ntotal));
258
+ }
259
+ // this is not necessary for the residualcoarsequantizer when
260
+ // using beam search. We'll see if the memory overhead is too high
261
+ centroid_norms.resize(ntotal);
262
+ aq->compute_centroid_norms(centroid_norms.data());
263
+ }
264
+ }
265
+
266
+ void AdditiveCoarseQuantizer::search(
267
+ idx_t n,
268
+ const float* x,
269
+ idx_t k,
270
+ float* distances,
271
+ idx_t* labels) const {
272
+ if (metric_type == METRIC_INNER_PRODUCT) {
273
+ aq->knn_centroids_inner_product(n, x, k, distances, labels);
274
+ } else if (metric_type == METRIC_L2) {
275
+ FAISS_THROW_IF_NOT(centroid_norms.size() == ntotal);
276
+ aq->knn_centroids_L2(
277
+ n, x, k, distances, labels, centroid_norms.data());
278
+ }
279
+ }
280
+
281
+ /**************************************************************************************
282
+ * ResidualCoarseQuantizer
283
+ **************************************************************************************/
284
+
285
+ ResidualCoarseQuantizer::ResidualCoarseQuantizer(
286
+ int d, ///< dimensionality of the input vectors
287
+ const std::vector<size_t>& nbits,
288
+ MetricType metric)
289
+ : AdditiveCoarseQuantizer(d, &rq, metric), rq(d, nbits), beam_factor(4.0) {
290
+ FAISS_THROW_IF_NOT(rq.tot_bits <= 63);
291
+ is_trained = false;
292
+ }
293
+
294
+ ResidualCoarseQuantizer::ResidualCoarseQuantizer(
295
+ int d,
296
+ size_t M, ///< number of subquantizers
297
+ size_t nbits, ///< number of bit per subvector index
298
+ MetricType metric)
299
+ : ResidualCoarseQuantizer(d, std::vector<size_t>(M, nbits), metric) {}
300
+
301
+ ResidualCoarseQuantizer::ResidualCoarseQuantizer(): ResidualCoarseQuantizer(0, 0, 0) {}
302
+
303
+
304
+
305
+ void ResidualCoarseQuantizer::set_beam_factor(float new_beam_factor) {
306
+ beam_factor = new_beam_factor;
307
+ if (new_beam_factor > 0) {
308
+ FAISS_THROW_IF_NOT(new_beam_factor >= 1.0);
309
+ return;
310
+ } else if (metric_type == METRIC_L2 && ntotal != centroid_norms.size()) {
311
+ if (verbose) {
312
+ printf("AdditiveCoarseQuantizer::train: computing centroid norms for %zd centroids\n", size_t(ntotal));
313
+ }
314
+ centroid_norms.resize(ntotal);
315
+ aq->compute_centroid_norms(centroid_norms.data());
316
+ }
317
+ }
318
+
319
+ void ResidualCoarseQuantizer::search(
320
+ idx_t n,
321
+ const float* x,
322
+ idx_t k,
323
+ float* distances,
324
+ idx_t* labels) const {
325
+ if (beam_factor < 0) {
326
+ AdditiveCoarseQuantizer::search(n, x, k, distances, labels);
327
+ return;
328
+ }
329
+
330
+ int beam_size = int(k * beam_factor);
331
+ if (beam_size > ntotal) {
332
+ beam_size = ntotal;
333
+ }
334
+ size_t memory_per_point = rq.memory_per_point(beam_size);
335
+
336
+ /*
337
+
338
+ printf("mem per point %ld n=%d max_mem_distance=%ld mem_kb=%zd\n",
339
+ memory_per_point, int(n), rq.max_mem_distances, get_mem_usage_kb());
340
+ */
341
+ if (n > 1 && memory_per_point * n > rq.max_mem_distances) {
342
+ // then split queries to reduce temp memory
343
+ idx_t bs = rq.max_mem_distances / memory_per_point;
344
+ if (bs == 0) {
345
+ bs = 1; // otherwise we can't do much
346
+ }
347
+ if (verbose) {
348
+ printf("ResidualCoarseQuantizer::search: run %d searches in batches of size %d\n",
349
+ int(n),
350
+ int(bs));
351
+ }
352
+ for (idx_t i0 = 0; i0 < n; i0 += bs) {
353
+ idx_t i1 = std::min(n, i0 + bs);
354
+ search(i1 - i0, x + i0 * d, k, distances + i0 * k, labels + i0 * k);
355
+ InterruptCallback::check();
356
+ }
357
+ return;
358
+ }
359
+
360
+ std::vector<int32_t> codes(beam_size * rq.M * n);
361
+ std::vector<float> beam_distances(n * beam_size);
362
+
363
+ rq.refine_beam(
364
+ n, 1, x, beam_size, codes.data(), nullptr, beam_distances.data());
365
+
366
+ #pragma omp parallel for if (n > 4000)
367
+ for (idx_t i = 0; i < n; i++) {
368
+ memcpy(distances + i * k,
369
+ beam_distances.data() + beam_size * i,
370
+ k * sizeof(distances[0]));
371
+
372
+ const int32_t* codes_i = codes.data() + beam_size * i * rq.M;
373
+ for (idx_t j = 0; j < k; j++) {
374
+ idx_t l = 0;
375
+ int shift = 0;
376
+ for (int m = 0; m < rq.M; m++) {
377
+ l |= (*codes_i++) << shift;
378
+ shift += rq.nbits[m];
379
+ }
380
+ labels[i * k + j] = l;
381
+ }
382
+ }
383
+ }
384
+
385
+ /**************************************************************************************
386
+ * LocalSearchCoarseQuantizer
387
+ **************************************************************************************/
388
+
389
+ LocalSearchCoarseQuantizer::LocalSearchCoarseQuantizer(
390
+ int d, ///< dimensionality of the input vectors
391
+ size_t M, ///< number of subquantizers
392
+ size_t nbits, ///< number of bit per subvector index
393
+ MetricType metric)
394
+ : AdditiveCoarseQuantizer(d, &lsq, metric), lsq(d, M, nbits) {
395
+ FAISS_THROW_IF_NOT(lsq.tot_bits <= 63);
396
+ is_trained = false;
397
+ }
398
+
399
+
400
+ LocalSearchCoarseQuantizer::LocalSearchCoarseQuantizer() {
401
+ aq = &lsq;
402
+ }
403
+
404
+
405
+
406
+
407
+ } // namespace faiss
@@ -0,0 +1,195 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #ifndef FAISS_INDEX_ADDITIVE_QUANTIZER_H
9
+ #define FAISS_INDEX_ADDITIVE_QUANTIZER_H
10
+
11
+ #include <faiss/impl/AdditiveQuantizer.h>
12
+
13
+ #include <cstdint>
14
+ #include <vector>
15
+
16
+ #include <faiss/IndexFlatCodes.h>
17
+ #include <faiss/impl/LocalSearchQuantizer.h>
18
+ #include <faiss/impl/ResidualQuantizer.h>
19
+ #include <faiss/impl/platform_macros.h>
20
+
21
+ namespace faiss {
22
+
23
+ /// Abstract class for additive quantizers. The search functions are in common.
24
+ struct IndexAdditiveQuantizer : IndexFlatCodes {
25
+ // the quantizer, this points to the relevant field in the inheriting
26
+ // classes
27
+ AdditiveQuantizer* aq;
28
+ using Search_type_t = AdditiveQuantizer::Search_type_t;
29
+
30
+ explicit IndexAdditiveQuantizer(
31
+ idx_t d = 0,
32
+ AdditiveQuantizer* aq = nullptr,
33
+ MetricType metric = METRIC_L2);
34
+
35
+ void search(
36
+ idx_t n,
37
+ const float* x,
38
+ idx_t k,
39
+ float* distances,
40
+ idx_t* labels) const override;
41
+
42
+ /* The standalone codec interface */
43
+ void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
44
+
45
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
46
+ };
47
+
48
+ /** Index based on a residual quantizer. Stored vectors are
49
+ * approximated by residual quantization codes.
50
+ * Can also be used as a codec
51
+ */
52
+ struct IndexResidualQuantizer : IndexAdditiveQuantizer {
53
+ /// The residual quantizer used to encode the vectors
54
+ ResidualQuantizer rq;
55
+
56
+ /** Constructor.
57
+ *
58
+ * @param d dimensionality of the input vectors
59
+ * @param M number of subquantizers
60
+ * @param nbits number of bit per subvector index
61
+ */
62
+ IndexResidualQuantizer(
63
+ int d, ///< dimensionality of the input vectors
64
+ size_t M, ///< number of subquantizers
65
+ size_t nbits, ///< number of bit per subvector index
66
+ MetricType metric = METRIC_L2,
67
+ Search_type_t search_type = AdditiveQuantizer::ST_decompress);
68
+
69
+ IndexResidualQuantizer(
70
+ int d,
71
+ const std::vector<size_t>& nbits,
72
+ MetricType metric = METRIC_L2,
73
+ Search_type_t search_type = AdditiveQuantizer::ST_decompress);
74
+
75
+ IndexResidualQuantizer();
76
+
77
+ void train(idx_t n, const float* x) override;
78
+ };
79
+
80
+ struct IndexLocalSearchQuantizer : IndexAdditiveQuantizer {
81
+ LocalSearchQuantizer lsq;
82
+
83
+ /** Constructor.
84
+ *
85
+ * @param d dimensionality of the input vectors
86
+ * @param M number of subquantizers
87
+ * @param nbits number of bit per subvector index
88
+ */
89
+ IndexLocalSearchQuantizer(
90
+ int d, ///< dimensionality of the input vectors
91
+ size_t M, ///< number of subquantizers
92
+ size_t nbits, ///< number of bit per subvector index
93
+ MetricType metric = METRIC_L2,
94
+ Search_type_t search_type = AdditiveQuantizer::ST_decompress);
95
+
96
+ IndexLocalSearchQuantizer();
97
+
98
+ void train(idx_t n, const float* x) override;
99
+ };
100
+
101
+ /** A "virtual" index where the elements are the residual quantizer centroids.
102
+ *
103
+ * Intended for use as a coarse quantizer in an IndexIVF.
104
+ */
105
+ struct AdditiveCoarseQuantizer : Index {
106
+ AdditiveQuantizer* aq;
107
+
108
+ explicit AdditiveCoarseQuantizer(
109
+ idx_t d = 0,
110
+ AdditiveQuantizer* aq = nullptr,
111
+ MetricType metric = METRIC_L2);
112
+
113
+ /// norms of centroids, useful for knn-search
114
+ std::vector<float> centroid_norms;
115
+
116
+ /// N/A
117
+ void add(idx_t n, const float* x) override;
118
+
119
+ void search(
120
+ idx_t n,
121
+ const float* x,
122
+ idx_t k,
123
+ float* distances,
124
+ idx_t* labels) const override;
125
+
126
+ void reconstruct(idx_t key, float* recons) const override;
127
+ void train(idx_t n, const float* x) override;
128
+
129
+ /// N/A
130
+ void reset() override;
131
+ };
132
+
133
+ /** The ResidualCoarseQuantizer is a bit specialized compared to the
134
+ * default AdditiveCoarseQuantizer because it can use a beam search
135
+ * at search time (slow but may be useful for very large vocabularies) */
136
+ struct ResidualCoarseQuantizer : AdditiveCoarseQuantizer {
137
+ /// The residual quantizer used to encode the vectors
138
+ ResidualQuantizer rq;
139
+
140
+ /// factor between the beam size and the search k
141
+ /// if negative, use exact search-to-centroid
142
+ float beam_factor;
143
+
144
+ /// computes centroid norms if required
145
+ void set_beam_factor(float new_beam_factor);
146
+
147
+ /** Constructor.
148
+ *
149
+ * @param d dimensionality of the input vectors
150
+ * @param M number of subquantizers
151
+ * @param nbits number of bit per subvector index
152
+ */
153
+ ResidualCoarseQuantizer(
154
+ int d, ///< dimensionality of the input vectors
155
+ size_t M, ///< number of subquantizers
156
+ size_t nbits, ///< number of bit per subvector index
157
+ MetricType metric = METRIC_L2);
158
+
159
+ ResidualCoarseQuantizer(
160
+ int d,
161
+ const std::vector<size_t>& nbits,
162
+ MetricType metric = METRIC_L2);
163
+
164
+ void search(
165
+ idx_t n,
166
+ const float* x,
167
+ idx_t k,
168
+ float* distances,
169
+ idx_t* labels) const override;
170
+
171
+ ResidualCoarseQuantizer();
172
+ };
173
+
174
+ struct LocalSearchCoarseQuantizer : AdditiveCoarseQuantizer {
175
+ /// The residual quantizer used to encode the vectors
176
+ LocalSearchQuantizer lsq;
177
+
178
+ /** Constructor.
179
+ *
180
+ * @param d dimensionality of the input vectors
181
+ * @param M number of subquantizers
182
+ * @param nbits number of bit per subvector index
183
+ */
184
+ LocalSearchCoarseQuantizer(
185
+ int d, ///< dimensionality of the input vectors
186
+ size_t M, ///< number of subquantizers
187
+ size_t nbits, ///< number of bit per subvector index
188
+ MetricType metric = METRIC_L2);
189
+
190
+ LocalSearchCoarseQuantizer();
191
+ };
192
+
193
+ } // namespace faiss
194
+
195
+ #endif