faiss 0.1.5 → 0.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219):
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +24 -0
  3. data/README.md +12 -0
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +6 -2
  6. data/ext/faiss/index.cpp +114 -43
  7. data/ext/faiss/index_binary.cpp +24 -30
  8. data/ext/faiss/kmeans.cpp +20 -16
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +13 -14
  11. data/ext/faiss/product_quantizer.cpp +23 -24
  12. data/ext/faiss/utils.cpp +10 -37
  13. data/ext/faiss/utils.h +2 -13
  14. data/lib/faiss.rb +0 -5
  15. data/lib/faiss/version.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +24 -10
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
@@ -0,0 +1,291 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/IndexResidual.h>
9
+
10
+ #include <algorithm>
11
+ #include <cmath>
12
+ #include <cstring>
13
+
14
+ #include <faiss/impl/FaissAssert.h>
15
+ #include <faiss/impl/ResultHandler.h>
16
+ #include <faiss/utils/distances.h>
17
+ #include <faiss/utils/extra_distances.h>
18
+ #include <faiss/utils/utils.h>
19
+
20
+ namespace faiss {
21
+
22
+ /**************************************************************************************
23
+ * IndexResidual
24
+ **************************************************************************************/
25
+
26
+ IndexResidual::IndexResidual(
27
+ int d, ///< dimensionality of the input vectors
28
+ size_t M, ///< number of subquantizers
29
+ size_t nbits, ///< number of bit per subvector index
30
+ MetricType metric,
31
+ Search_type_t search_type_in)
32
+ : Index(d, metric), rq(d, M, nbits), search_type(ST_decompress) {
33
+ is_trained = false;
34
+ norm_max = norm_min = NAN;
35
+ set_search_type(search_type_in);
36
+ }
37
+
38
+ IndexResidual::IndexResidual(
39
+ int d,
40
+ const std::vector<size_t>& nbits,
41
+ MetricType metric,
42
+ Search_type_t search_type_in)
43
+ : Index(d, metric), rq(d, nbits), search_type(ST_decompress) {
44
+ is_trained = false;
45
+ norm_max = norm_min = NAN;
46
+ set_search_type(search_type_in);
47
+ }
48
+
49
+ IndexResidual::IndexResidual() : IndexResidual(0, 0, 0) {}
50
+
51
+ void IndexResidual::set_search_type(Search_type_t new_search_type) {
52
+ int norm_bits = new_search_type == ST_norm_float ? 32
53
+ : new_search_type == ST_norm_qint8 ? 8
54
+ : 0;
55
+
56
+ FAISS_THROW_IF_NOT(ntotal == 0);
57
+
58
+ search_type = new_search_type;
59
+ code_size = (rq.tot_bits + norm_bits + 7) / 8;
60
+ }
61
+
62
+ void IndexResidual::train(idx_t n, const float* x) {
63
+ rq.train(n, x);
64
+
65
+ std::vector<float> norms(n);
66
+ fvec_norms_L2sqr(norms.data(), x, d, n);
67
+
68
+ norm_min = HUGE_VALF;
69
+ norm_max = -HUGE_VALF;
70
+ for (idx_t i = 0; i < n; i++) {
71
+ if (norms[i] < norm_min) {
72
+ norm_min = norms[i];
73
+ }
74
+ if (norms[i] > norm_min) {
75
+ norm_max = norms[i];
76
+ }
77
+ }
78
+
79
+ is_trained = true;
80
+ }
81
+
82
+ void IndexResidual::add(idx_t n, const float* x) {
83
+ FAISS_THROW_IF_NOT(is_trained);
84
+ codes.resize((n + ntotal) * rq.code_size);
85
+ if (search_type == ST_decompress || search_type == ST_LUT_nonorm) {
86
+ rq.compute_codes(x, &codes[ntotal * rq.code_size], n);
87
+ } else {
88
+ // should compute codes + compute and quantize norms
89
+ FAISS_THROW_MSG("not implemented");
90
+ }
91
+ ntotal += n;
92
+ }
93
+
94
+ namespace {
95
+
96
+ template <class VectorDistance, class ResultHandler>
97
+ void search_with_decompress(
98
+ const IndexResidual& ir,
99
+ const float* xq,
100
+ VectorDistance& vd,
101
+ ResultHandler& res) {
102
+ const uint8_t* codes = ir.codes.data();
103
+ size_t ntotal = ir.ntotal;
104
+ size_t code_size = ir.code_size;
105
+
106
+ using SingleResultHandler = typename ResultHandler::SingleResultHandler;
107
+
108
+ #pragma omp parallel for
109
+ for (int64_t q = 0; q < res.nq; q++) {
110
+ SingleResultHandler resi(res);
111
+ resi.begin(q);
112
+ std::vector<float> tmp(ir.d);
113
+ const float* x = xq + ir.d * q;
114
+ for (size_t i = 0; i < ntotal; i++) {
115
+ ir.rq.decode(codes + i * code_size, tmp.data(), 1);
116
+ float dis = vd(x, tmp.data());
117
+ resi.add_result(dis, i);
118
+ }
119
+ resi.end();
120
+ }
121
+ }
122
+
123
+ } // anonymous namespace
124
+
125
+ void IndexResidual::search(
126
+ idx_t n,
127
+ const float* x,
128
+ idx_t k,
129
+ float* distances,
130
+ idx_t* labels) const {
131
+ if (search_type == ST_decompress) {
132
+ if (metric_type == METRIC_L2) {
133
+ using VD = VectorDistance<METRIC_L2>;
134
+ VD vd = {size_t(d), metric_arg};
135
+ HeapResultHandler<VD::C> rh(n, distances, labels, k);
136
+ search_with_decompress(*this, x, vd, rh);
137
+ } else if (metric_type == METRIC_INNER_PRODUCT) {
138
+ using VD = VectorDistance<METRIC_INNER_PRODUCT>;
139
+ VD vd = {size_t(d), metric_arg};
140
+ HeapResultHandler<VD::C> rh(n, distances, labels, k);
141
+ search_with_decompress(*this, x, vd, rh);
142
+ }
143
+ } else {
144
+ FAISS_THROW_MSG("not implemented");
145
+ }
146
+ }
147
+
148
+ void IndexResidual::reset() {
149
+ codes.clear();
150
+ ntotal = 0;
151
+ }
152
+
153
+ size_t IndexResidual::sa_code_size() const {
154
+ return code_size;
155
+ }
156
+
157
+ void IndexResidual::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
158
+ return rq.compute_codes(x, bytes, n);
159
+ }
160
+
161
+ void IndexResidual::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
162
+ return rq.decode(bytes, x, n);
163
+ }
164
+
165
+ /**************************************************************************************
166
+ * ResidualCoarseQuantizer
167
+ **************************************************************************************/
168
+
169
+ ResidualCoarseQuantizer::ResidualCoarseQuantizer(
170
+ int d, ///< dimensionality of the input vectors
171
+ size_t M, ///< number of subquantizers
172
+ size_t nbits, ///< number of bit per subvector index
173
+ MetricType metric)
174
+ : Index(d, metric), rq(d, M, nbits), beam_factor(4.0) {
175
+ FAISS_THROW_IF_NOT(rq.tot_bits <= 63);
176
+ is_trained = false;
177
+ }
178
+
179
+ ResidualCoarseQuantizer::ResidualCoarseQuantizer(
180
+ int d,
181
+ const std::vector<size_t>& nbits,
182
+ MetricType metric)
183
+ : Index(d, metric), rq(d, nbits), beam_factor(4.0) {
184
+ FAISS_THROW_IF_NOT(rq.tot_bits <= 63);
185
+ is_trained = false;
186
+ }
187
+
188
+ ResidualCoarseQuantizer::ResidualCoarseQuantizer() {}
189
+
190
+ void ResidualCoarseQuantizer::train(idx_t n, const float* x) {
191
+ rq.train(n, x);
192
+ is_trained = true;
193
+ ntotal = (idx_t)1 << rq.tot_bits;
194
+ }
195
+
196
+ void ResidualCoarseQuantizer::add(idx_t, const float*) {
197
+ FAISS_THROW_MSG("not applicable");
198
+ }
199
+
200
+ void ResidualCoarseQuantizer::set_beam_factor(float new_beam_factor) {
201
+ centroid_norms.resize(0);
202
+ beam_factor = new_beam_factor;
203
+ if (new_beam_factor > 0) {
204
+ FAISS_THROW_IF_NOT(new_beam_factor >= 1.0);
205
+ return;
206
+ }
207
+
208
+ if (metric_type == METRIC_L2) {
209
+ centroid_norms.resize((size_t)1 << rq.tot_bits);
210
+ rq.compute_centroid_norms(centroid_norms.data());
211
+ }
212
+ }
213
+
214
+ void ResidualCoarseQuantizer::search(
215
+ idx_t n,
216
+ const float* x,
217
+ idx_t k,
218
+ float* distances,
219
+ idx_t* labels) const {
220
+ if (beam_factor < 0) {
221
+ if (metric_type == METRIC_INNER_PRODUCT) {
222
+ rq.knn_exact_inner_product(n, x, k, distances, labels);
223
+ } else if (metric_type == METRIC_L2) {
224
+ FAISS_THROW_IF_NOT(centroid_norms.size() == ntotal);
225
+ rq.knn_exact_L2(n, x, k, distances, labels, centroid_norms.data());
226
+ }
227
+ return;
228
+ }
229
+
230
+ int beam_size = int(k * beam_factor);
231
+
232
+ size_t memory_per_point = rq.memory_per_point(beam_size);
233
+
234
+ /*
235
+
236
+ printf("mem per point %ld n=%d max_mem_distance=%ld mem_kb=%zd\n",
237
+ memory_per_point, int(n), rq.max_mem_distances, get_mem_usage_kb());
238
+ */
239
+ if (n > 1 && memory_per_point * n > rq.max_mem_distances) {
240
+ // then split queries to reduce temp memory
241
+ idx_t bs = rq.max_mem_distances / memory_per_point;
242
+ if (bs == 0) {
243
+ bs = 1; // otherwise we can't do much
244
+ }
245
+ if (verbose) {
246
+ printf("ResidualCoarseQuantizer::search: run %d searches in batches of size %d\n",
247
+ int(n),
248
+ int(bs));
249
+ }
250
+ for (idx_t i0 = 0; i0 < n; i0 += bs) {
251
+ idx_t i1 = std::min(n, i0 + bs);
252
+ search(i1 - i0, x + i0 * d, k, distances + i0 * k, labels + i0 * k);
253
+ InterruptCallback::check();
254
+ }
255
+ return;
256
+ }
257
+
258
+ std::vector<int32_t> codes(beam_size * rq.M * n);
259
+ std::vector<float> beam_distances(n * beam_size);
260
+
261
+ rq.refine_beam(
262
+ n, 1, x, beam_size, codes.data(), nullptr, beam_distances.data());
263
+
264
+ #pragma omp parallel for if (n > 4000)
265
+ for (idx_t i = 0; i < n; i++) {
266
+ memcpy(distances + i * k,
267
+ beam_distances.data() + beam_size * i,
268
+ k * sizeof(distances[0]));
269
+
270
+ const int32_t* codes_i = codes.data() + beam_size * i * rq.M;
271
+ for (idx_t j = 0; j < k; j++) {
272
+ idx_t l = 0;
273
+ int shift = 0;
274
+ for (int m = 0; m < rq.M; m++) {
275
+ l |= (*codes_i++) << shift;
276
+ shift += rq.nbits[m];
277
+ }
278
+ labels[i * k + j] = l;
279
+ }
280
+ }
281
+ }
282
+
283
+ void ResidualCoarseQuantizer::reconstruct(idx_t key, float* recons) const {
284
+ rq.decode_64bit(key, recons);
285
+ }
286
+
287
+ void ResidualCoarseQuantizer::reset() {
288
+ FAISS_THROW_MSG("not applicable");
289
+ }
290
+
291
+ } // namespace faiss
@@ -0,0 +1,152 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #ifndef FAISS_INDEX_RESIDUAL_H
9
+ #define FAISS_INDEX_RESIDUAL_H
10
+
11
+ #include <stdint.h>
12
+
13
+ #include <vector>
14
+
15
+ #include <faiss/Index.h>
16
+ #include <faiss/impl/ResidualQuantizer.h>
17
+ #include <faiss/impl/platform_macros.h>
18
+
19
+ namespace faiss {
20
+
21
+ /** Index based on a residual quantizer. Stored vectors are
22
+ * approximated by residual quantization codes.
23
+ * Can also be used as a codec
24
+ */
25
+ struct IndexResidual : Index {
26
+ /// The residual quantizer used to encode the vectors
27
+ ResidualQuantizer rq;
28
+
29
+ enum Search_type_t {
30
+ ST_decompress, ///< decompress database vector
31
+ ST_LUT_nonorm, ///< use a LUT, don't include norms (OK for IP or
32
+ ///< normalized vectors)
33
+ ST_norm_float, ///< use a LUT, and store float32 norm with the vectors
34
+ ST_norm_qint8, ///< use a LUT, and store 8bit-quantized norm
35
+ };
36
+ Search_type_t search_type;
37
+
38
+ /// min/max for quantization of norms
39
+ float norm_min, norm_max;
40
+
41
+ /// size of residual quantizer codes + norms
42
+ size_t code_size;
43
+
44
+ /// Codes. Size ntotal * rq.code_size
45
+ std::vector<uint8_t> codes;
46
+
47
+ /** Constructor.
48
+ *
49
+ * @param d dimensionality of the input vectors
50
+ * @param M number of subquantizers
51
+ * @param nbits number of bit per subvector index
52
+ */
53
+ IndexResidual(
54
+ int d, ///< dimensionality of the input vectors
55
+ size_t M, ///< number of subquantizers
56
+ size_t nbits, ///< number of bit per subvector index
57
+ MetricType metric = METRIC_L2,
58
+ Search_type_t search_type = ST_decompress);
59
+
60
+ IndexResidual(
61
+ int d,
62
+ const std::vector<size_t>& nbits,
63
+ MetricType metric = METRIC_L2,
64
+ Search_type_t search_type = ST_decompress);
65
+
66
+ IndexResidual();
67
+
68
+ /// set search type and update parameters
69
+ void set_search_type(Search_type_t search_type);
70
+
71
+ void train(idx_t n, const float* x) override;
72
+
73
+ void add(idx_t n, const float* x) override;
74
+
75
+ /// not implemented
76
+ void search(
77
+ idx_t n,
78
+ const float* x,
79
+ idx_t k,
80
+ float* distances,
81
+ idx_t* labels) const override;
82
+
83
+ void reset() override;
84
+
85
+ /* The standalone codec interface */
86
+ size_t sa_code_size() const override;
87
+
88
+ void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
89
+
90
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
91
+
92
+ // DistanceComputer* get_distance_computer() const override;
93
+ };
94
+
95
+ /** A "virtual" index where the elements are the residual quantizer centroids.
96
+ *
97
+ * Intended for use as a coarse quantizer in an IndexIVF.
98
+ */
99
+ struct ResidualCoarseQuantizer : Index {
100
+ /// The residual quantizer used to encode the vectors
101
+ ResidualQuantizer rq;
102
+
103
+ /// factor between the beam size and the search k
104
+ /// if negative, use exact search-to-centroid
105
+ float beam_factor;
106
+
107
+ /// norms of centroids, useful for knn-search
108
+ std::vector<float> centroid_norms;
109
+
110
+ /// computes centroid norms if required
111
+ void set_beam_factor(float new_beam_factor);
112
+
113
+ /** Constructor.
114
+ *
115
+ * @param d dimensionality of the input vectors
116
+ * @param M number of subquantizers
117
+ * @param nbits number of bit per subvector index
118
+ */
119
+ ResidualCoarseQuantizer(
120
+ int d, ///< dimensionality of the input vectors
121
+ size_t M, ///< number of subquantizers
122
+ size_t nbits, ///< number of bit per subvector index
123
+ MetricType metric = METRIC_L2);
124
+
125
+ ResidualCoarseQuantizer(
126
+ int d,
127
+ const std::vector<size_t>& nbits,
128
+ MetricType metric = METRIC_L2);
129
+
130
+ ResidualCoarseQuantizer();
131
+
132
+ void train(idx_t n, const float* x) override;
133
+
134
+ /// N/A
135
+ void add(idx_t n, const float* x) override;
136
+
137
+ void search(
138
+ idx_t n,
139
+ const float* x,
140
+ idx_t k,
141
+ float* distances,
142
+ idx_t* labels) const override;
143
+
144
+ void reconstruct(idx_t key, float* recons) const override;
145
+
146
+ /// N/A
147
+ void reset() override;
148
+ };
149
+
150
+ } // namespace faiss
151
+
152
+ #endif
@@ -9,231 +9,207 @@
9
9
 
10
10
  #include <faiss/IndexScalarQuantizer.h>
11
11
 
12
- #include <cstdio>
13
12
  #include <algorithm>
13
+ #include <cstdio>
14
14
 
15
15
  #include <omp.h>
16
16
 
17
- #include <faiss/utils/utils.h>
18
- #include <faiss/impl/FaissAssert.h>
19
17
  #include <faiss/impl/AuxIndexStructures.h>
18
+ #include <faiss/impl/FaissAssert.h>
20
19
  #include <faiss/impl/ScalarQuantizer.h>
20
+ #include <faiss/utils/utils.h>
21
21
 
22
22
  namespace faiss {
23
23
 
24
-
25
-
26
24
  /*******************************************************************
27
25
  * IndexScalarQuantizer implementation
28
26
  ********************************************************************/
29
27
 
30
- IndexScalarQuantizer::IndexScalarQuantizer
31
- (int d, ScalarQuantizer::QuantizerType qtype,
32
- MetricType metric):
33
- Index(d, metric),
34
- sq (d, qtype)
35
- {
36
- is_trained =
37
- qtype == ScalarQuantizer::QT_fp16 ||
38
- qtype == ScalarQuantizer::QT_8bit_direct;
28
+ IndexScalarQuantizer::IndexScalarQuantizer(
29
+ int d,
30
+ ScalarQuantizer::QuantizerType qtype,
31
+ MetricType metric)
32
+ : Index(d, metric), sq(d, qtype) {
33
+ is_trained = qtype == ScalarQuantizer::QT_fp16 ||
34
+ qtype == ScalarQuantizer::QT_8bit_direct;
39
35
  code_size = sq.code_size;
40
36
  }
41
37
 
38
+ IndexScalarQuantizer::IndexScalarQuantizer()
39
+ : IndexScalarQuantizer(0, ScalarQuantizer::QT_8bit) {}
42
40
 
43
- IndexScalarQuantizer::IndexScalarQuantizer ():
44
- IndexScalarQuantizer(0, ScalarQuantizer::QT_8bit)
45
- {}
46
-
47
- void IndexScalarQuantizer::train(idx_t n, const float* x)
48
- {
41
+ void IndexScalarQuantizer::train(idx_t n, const float* x) {
49
42
  sq.train(n, x);
50
43
  is_trained = true;
51
44
  }
52
45
 
53
- void IndexScalarQuantizer::add(idx_t n, const float* x)
54
- {
55
- FAISS_THROW_IF_NOT (is_trained);
56
- codes.resize ((n + ntotal) * code_size);
57
- sq.compute_codes (x, &codes[ntotal * code_size], n);
46
+ void IndexScalarQuantizer::add(idx_t n, const float* x) {
47
+ FAISS_THROW_IF_NOT(is_trained);
48
+ codes.resize((n + ntotal) * code_size);
49
+ sq.compute_codes(x, &codes[ntotal * code_size], n);
58
50
  ntotal += n;
59
51
  }
60
52
 
61
-
62
53
  void IndexScalarQuantizer::search(
63
54
  idx_t n,
64
55
  const float* x,
65
56
  idx_t k,
66
57
  float* distances,
67
- idx_t* labels) const
68
- {
69
- FAISS_THROW_IF_NOT (is_trained);
70
- FAISS_THROW_IF_NOT (metric_type == METRIC_L2 ||
71
- metric_type == METRIC_INNER_PRODUCT);
58
+ idx_t* labels) const {
59
+ FAISS_THROW_IF_NOT(k > 0);
60
+
61
+ FAISS_THROW_IF_NOT(is_trained);
62
+ FAISS_THROW_IF_NOT(
63
+ metric_type == METRIC_L2 || metric_type == METRIC_INNER_PRODUCT);
72
64
 
73
65
  #pragma omp parallel
74
66
  {
75
- InvertedListScanner* scanner = sq.select_InvertedListScanner
76
- (metric_type, nullptr, true);
67
+ InvertedListScanner* scanner =
68
+ sq.select_InvertedListScanner(metric_type, nullptr, true);
77
69
  ScopeDeleter1<InvertedListScanner> del(scanner);
78
70
 
79
71
  #pragma omp for
80
72
  for (idx_t i = 0; i < n; i++) {
81
- float * D = distances + k * i;
82
- idx_t * I = labels + k * i;
73
+ float* D = distances + k * i;
74
+ idx_t* I = labels + k * i;
83
75
  // re-order heap
84
76
  if (metric_type == METRIC_L2) {
85
- maxheap_heapify (k, D, I);
77
+ maxheap_heapify(k, D, I);
86
78
  } else {
87
- minheap_heapify (k, D, I);
79
+ minheap_heapify(k, D, I);
88
80
  }
89
- scanner->set_query (x + i * d);
90
- scanner->scan_codes (ntotal, codes.data(),
91
- nullptr, D, I, k);
81
+ scanner->set_query(x + i * d);
82
+ scanner->scan_codes(ntotal, codes.data(), nullptr, D, I, k);
92
83
 
93
84
  // re-order heap
94
85
  if (metric_type == METRIC_L2) {
95
- maxheap_reorder (k, D, I);
86
+ maxheap_reorder(k, D, I);
96
87
  } else {
97
- minheap_reorder (k, D, I);
88
+ minheap_reorder(k, D, I);
98
89
  }
99
90
  }
100
91
  }
101
-
102
92
  }
103
93
 
104
-
105
- DistanceComputer *IndexScalarQuantizer::get_distance_computer () const
106
- {
107
- ScalarQuantizer::SQDistanceComputer *dc =
108
- sq.get_distance_computer (metric_type);
94
+ DistanceComputer* IndexScalarQuantizer::get_distance_computer() const {
95
+ ScalarQuantizer::SQDistanceComputer* dc =
96
+ sq.get_distance_computer(metric_type);
109
97
  dc->code_size = sq.code_size;
110
98
  dc->codes = codes.data();
111
99
  return dc;
112
100
  }
113
101
 
114
-
115
- void IndexScalarQuantizer::reset()
116
- {
102
+ void IndexScalarQuantizer::reset() {
117
103
  codes.clear();
118
104
  ntotal = 0;
119
105
  }
120
106
 
121
- void IndexScalarQuantizer::reconstruct_n(
122
- idx_t i0, idx_t ni, float* recons) const
123
- {
124
- std::unique_ptr<ScalarQuantizer::Quantizer> squant(sq.select_quantizer ());
107
+ void IndexScalarQuantizer::reconstruct_n(idx_t i0, idx_t ni, float* recons)
108
+ const {
109
+ std::unique_ptr<ScalarQuantizer::Quantizer> squant(sq.select_quantizer());
125
110
  for (size_t i = 0; i < ni; i++) {
126
111
  squant->decode_vector(&codes[(i + i0) * code_size], recons + i * d);
127
112
  }
128
113
  }
129
114
 
130
- void IndexScalarQuantizer::reconstruct(idx_t key, float* recons) const
131
- {
115
+ void IndexScalarQuantizer::reconstruct(idx_t key, float* recons) const {
132
116
  reconstruct_n(key, 1, recons);
133
117
  }
134
118
 
135
119
  /* Codec interface */
136
- size_t IndexScalarQuantizer::sa_code_size () const
137
- {
120
+ size_t IndexScalarQuantizer::sa_code_size() const {
138
121
  return sq.code_size;
139
122
  }
140
123
 
141
- void IndexScalarQuantizer::sa_encode (idx_t n, const float *x,
142
- uint8_t *bytes) const
143
- {
144
- FAISS_THROW_IF_NOT (is_trained);
145
- sq.compute_codes (x, bytes, n);
124
+ void IndexScalarQuantizer::sa_encode(idx_t n, const float* x, uint8_t* bytes)
125
+ const {
126
+ FAISS_THROW_IF_NOT(is_trained);
127
+ sq.compute_codes(x, bytes, n);
146
128
  }
147
129
 
148
- void IndexScalarQuantizer::sa_decode (idx_t n, const uint8_t *bytes,
149
- float *x) const
150
- {
151
- FAISS_THROW_IF_NOT (is_trained);
130
+ void IndexScalarQuantizer::sa_decode(idx_t n, const uint8_t* bytes, float* x)
131
+ const {
132
+ FAISS_THROW_IF_NOT(is_trained);
152
133
  sq.decode(bytes, x, n);
153
134
  }
154
135
 
155
-
156
-
157
136
  /*******************************************************************
158
137
  * IndexIVFScalarQuantizer implementation
159
138
  ********************************************************************/
160
139
 
161
- IndexIVFScalarQuantizer::IndexIVFScalarQuantizer (
162
- Index *quantizer, size_t d, size_t nlist,
163
- ScalarQuantizer::QuantizerType qtype,
164
- MetricType metric, bool encode_residual)
165
- : IndexIVF(quantizer, d, nlist, 0, metric),
166
- sq(d, qtype),
167
- by_residual(encode_residual)
168
- {
140
+ IndexIVFScalarQuantizer::IndexIVFScalarQuantizer(
141
+ Index* quantizer,
142
+ size_t d,
143
+ size_t nlist,
144
+ ScalarQuantizer::QuantizerType qtype,
145
+ MetricType metric,
146
+ bool encode_residual)
147
+ : IndexIVF(quantizer, d, nlist, 0, metric),
148
+ sq(d, qtype),
149
+ by_residual(encode_residual) {
169
150
  code_size = sq.code_size;
170
151
  // was not known at construction time
171
152
  invlists->code_size = code_size;
172
153
  is_trained = false;
173
154
  }
174
155
 
175
- IndexIVFScalarQuantizer::IndexIVFScalarQuantizer ():
176
- IndexIVF(),
177
- by_residual(true)
178
- {
179
- }
156
+ IndexIVFScalarQuantizer::IndexIVFScalarQuantizer()
157
+ : IndexIVF(), by_residual(true) {}
180
158
 
181
- void IndexIVFScalarQuantizer::train_residual (idx_t n, const float *x)
182
- {
159
+ void IndexIVFScalarQuantizer::train_residual(idx_t n, const float* x) {
183
160
  sq.train_residual(n, x, quantizer, by_residual, verbose);
184
161
  }
185
162
 
186
- void IndexIVFScalarQuantizer::encode_vectors(idx_t n, const float* x,
187
- const idx_t *list_nos,
188
- uint8_t * codes,
189
- bool include_listnos) const
190
- {
191
- std::unique_ptr<ScalarQuantizer::Quantizer> squant (sq.select_quantizer ());
192
- size_t coarse_size = include_listnos ? coarse_code_size () : 0;
163
+ void IndexIVFScalarQuantizer::encode_vectors(
164
+ idx_t n,
165
+ const float* x,
166
+ const idx_t* list_nos,
167
+ uint8_t* codes,
168
+ bool include_listnos) const {
169
+ std::unique_ptr<ScalarQuantizer::Quantizer> squant(sq.select_quantizer());
170
+ size_t coarse_size = include_listnos ? coarse_code_size() : 0;
193
171
  memset(codes, 0, (code_size + coarse_size) * n);
194
172
 
195
- #pragma omp parallel if(n > 1000)
173
+ #pragma omp parallel if (n > 1000)
196
174
  {
197
- std::vector<float> residual (d);
175
+ std::vector<float> residual(d);
198
176
 
199
177
  #pragma omp for
200
178
  for (idx_t i = 0; i < n; i++) {
201
- int64_t list_no = list_nos [i];
179
+ int64_t list_no = list_nos[i];
202
180
  if (list_no >= 0) {
203
- const float *xi = x + i * d;
204
- uint8_t *code = codes + i * (code_size + coarse_size);
181
+ const float* xi = x + i * d;
182
+ uint8_t* code = codes + i * (code_size + coarse_size);
205
183
  if (by_residual) {
206
- quantizer->compute_residual (
207
- xi, residual.data(), list_no);
208
- xi = residual.data ();
184
+ quantizer->compute_residual(xi, residual.data(), list_no);
185
+ xi = residual.data();
209
186
  }
210
187
  if (coarse_size) {
211
- encode_listno (list_no, code);
188
+ encode_listno(list_no, code);
212
189
  }
213
- squant->encode_vector (xi, code + coarse_size);
190
+ squant->encode_vector(xi, code + coarse_size);
214
191
  }
215
192
  }
216
193
  }
217
194
  }
218
195
 
219
- void IndexIVFScalarQuantizer::sa_decode (idx_t n, const uint8_t *codes,
220
- float *x) const
221
- {
222
- std::unique_ptr<ScalarQuantizer::Quantizer> squant (sq.select_quantizer ());
223
- size_t coarse_size = coarse_code_size ();
196
+ void IndexIVFScalarQuantizer::sa_decode(idx_t n, const uint8_t* codes, float* x)
197
+ const {
198
+ std::unique_ptr<ScalarQuantizer::Quantizer> squant(sq.select_quantizer());
199
+ size_t coarse_size = coarse_code_size();
224
200
 
225
- #pragma omp parallel if(n > 1000)
201
+ #pragma omp parallel if (n > 1000)
226
202
  {
227
- std::vector<float> residual (d);
203
+ std::vector<float> residual(d);
228
204
 
229
205
  #pragma omp for
230
206
  for (idx_t i = 0; i < n; i++) {
231
- const uint8_t *code = codes + i * (code_size + coarse_size);
232
- int64_t list_no = decode_listno (code);
233
- float *xi = x + i * d;
234
- squant->decode_vector (code + coarse_size, xi);
207
+ const uint8_t* code = codes + i * (code_size + coarse_size);
208
+ int64_t list_no = decode_listno(code);
209
+ float* xi = x + i * d;
210
+ squant->decode_vector(code + coarse_size, xi);
235
211
  if (by_residual) {
236
- quantizer->reconstruct (list_no, residual.data());
212
+ quantizer->reconstruct(list_no, residual.data());
237
213
  for (size_t j = 0; j < d; j++) {
238
214
  xi[j] += residual[j];
239
215
  }
@@ -242,83 +218,72 @@ void IndexIVFScalarQuantizer::sa_decode (idx_t n, const uint8_t *codes,
242
218
  }
243
219
  }
244
220
 
221
+ void IndexIVFScalarQuantizer::add_core(
222
+ idx_t n,
223
+ const float* x,
224
+ const idx_t* xids,
225
+ const idx_t* coarse_idx) {
226
+ FAISS_THROW_IF_NOT(is_trained);
245
227
 
246
-
247
- void IndexIVFScalarQuantizer::add_with_ids
248
- (idx_t n, const float * x, const idx_t *xids)
249
- {
250
- FAISS_THROW_IF_NOT (is_trained);
251
- std::unique_ptr<int64_t []> idx (new int64_t [n]);
252
- quantizer->assign (n, x, idx.get());
253
228
  size_t nadd = 0;
254
- std::unique_ptr<ScalarQuantizer::Quantizer> squant(sq.select_quantizer ());
229
+ std::unique_ptr<ScalarQuantizer::Quantizer> squant(sq.select_quantizer());
255
230
 
256
- DirectMapAdd dm_add (direct_map, n, xids);
231
+ DirectMapAdd dm_add(direct_map, n, xids);
257
232
 
258
- #pragma omp parallel reduction(+: nadd)
233
+ #pragma omp parallel reduction(+ : nadd)
259
234
  {
260
- std::vector<float> residual (d);
261
- std::vector<uint8_t> one_code (code_size);
235
+ std::vector<float> residual(d);
236
+ std::vector<uint8_t> one_code(code_size);
262
237
  int nt = omp_get_num_threads();
263
238
  int rank = omp_get_thread_num();
264
239
 
265
240
  // each thread takes care of a subset of lists
266
241
  for (size_t i = 0; i < n; i++) {
267
- int64_t list_no = idx [i];
242
+ int64_t list_no = coarse_idx[i];
268
243
  if (list_no >= 0 && list_no % nt == rank) {
269
244
  int64_t id = xids ? xids[i] : ntotal + i;
270
245
 
271
- const float * xi = x + i * d;
246
+ const float* xi = x + i * d;
272
247
  if (by_residual) {
273
- quantizer->compute_residual (xi, residual.data(), list_no);
248
+ quantizer->compute_residual(xi, residual.data(), list_no);
274
249
  xi = residual.data();
275
250
  }
276
251
 
277
- memset (one_code.data(), 0, code_size);
278
- squant->encode_vector (xi, one_code.data());
252
+ memset(one_code.data(), 0, code_size);
253
+ squant->encode_vector(xi, one_code.data());
279
254
 
280
- size_t ofs = invlists->add_entry (list_no, id, one_code.data());
255
+ size_t ofs = invlists->add_entry(list_no, id, one_code.data());
281
256
 
282
- dm_add.add (i, list_no, ofs);
257
+ dm_add.add(i, list_no, ofs);
283
258
  nadd++;
284
259
 
285
260
  } else if (rank == 0 && list_no == -1) {
286
- dm_add.add (i, -1, 0);
261
+ dm_add.add(i, -1, 0);
287
262
  }
288
263
  }
289
264
  }
290
265
 
291
-
292
266
  ntotal += n;
293
267
  }
294
268
 
295
-
296
-
297
-
298
-
299
- InvertedListScanner* IndexIVFScalarQuantizer::get_InvertedListScanner
300
- (bool store_pairs) const
301
- {
302
- return sq.select_InvertedListScanner (metric_type, quantizer, store_pairs,
303
- by_residual);
269
+ InvertedListScanner* IndexIVFScalarQuantizer::get_InvertedListScanner(
270
+ bool store_pairs) const {
271
+ return sq.select_InvertedListScanner(
272
+ metric_type, quantizer, store_pairs, by_residual);
304
273
  }
305
274
 
306
-
307
- void IndexIVFScalarQuantizer::reconstruct_from_offset (int64_t list_no,
308
- int64_t offset,
309
- float* recons) const
310
- {
275
+ void IndexIVFScalarQuantizer::reconstruct_from_offset(
276
+ int64_t list_no,
277
+ int64_t offset,
278
+ float* recons) const {
311
279
  std::vector<float> centroid(d);
312
- quantizer->reconstruct (list_no, centroid.data());
280
+ quantizer->reconstruct(list_no, centroid.data());
313
281
 
314
- const uint8_t* code = invlists->get_single_code (list_no, offset);
315
- sq.decode (code, recons, 1);
282
+ const uint8_t* code = invlists->get_single_code(list_no, offset);
283
+ sq.decode(code, recons, 1);
316
284
  for (int i = 0; i < d; ++i) {
317
285
  recons[i] += centroid[i];
318
286
  }
319
287
  }
320
288
 
321
-
322
-
323
-
324
289
  } // namespace faiss