faiss 0.2.0 → 0.2.4

Files changed (215)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp (new file)
@@ -0,0 +1,758 @@
+ /**
+  * Copyright (c) Facebook, Inc. and its affiliates.
+  *
+  * This source code is licensed under the MIT license found in the
+  * LICENSE file in the root directory of this source tree.
+  */
+
+ // -*- c++ -*-
+
+ #include <faiss/impl/ResidualQuantizer.h>
+
+ #include <algorithm>
+ #include <cstddef>
+ #include <cstdio>
+ #include <cstring>
+ #include <memory>
+
+ #include <faiss/impl/FaissAssert.h>
+ #include <faiss/impl/ResidualQuantizer.h>
+ #include <faiss/utils/utils.h>
+
+ #include <faiss/Clustering.h>
+ #include <faiss/IndexFlat.h>
+ #include <faiss/VectorTransform.h>
+ #include <faiss/impl/AuxIndexStructures.h>
+ #include <faiss/impl/FaissAssert.h>
+ #include <faiss/utils/Heap.h>
+ #include <faiss/utils/distances.h>
+ #include <faiss/utils/hamming.h>
+ #include <faiss/utils/simdlib.h>
+ #include <faiss/utils/utils.h>
+
+ extern "C" {
+
+ // general matrix multiplication
+ int sgemm_(
+         const char* transa,
+         const char* transb,
+         FINTEGER* m,
+         FINTEGER* n,
+         FINTEGER* k,
+         const float* alpha,
+         const float* a,
+         FINTEGER* lda,
+         const float* b,
+         FINTEGER* ldb,
+         float* beta,
+         float* c,
+         FINTEGER* ldc);
+ }
+
+ namespace faiss {
+
+ ResidualQuantizer::ResidualQuantizer()
+         : train_type(Train_progressive_dim),
+           max_beam_size(5),
+           use_beam_LUT(0),
+           max_mem_distances(5 * (size_t(1) << 30)), // 5 GiB
+           assign_index_factory(nullptr) {
+     d = 0;
+     M = 0;
+     verbose = false;
+ }
+
+ ResidualQuantizer::ResidualQuantizer(
+         size_t d,
+         const std::vector<size_t>& nbits,
+         Search_type_t search_type)
+         : ResidualQuantizer() {
+     this->search_type = search_type;
+     this->d = d;
+     M = nbits.size();
+     this->nbits = nbits;
+     set_derived_values();
+ }
+
+ ResidualQuantizer::ResidualQuantizer(
+         size_t d,
+         size_t M,
+         size_t nbits,
+         Search_type_t search_type)
+         : ResidualQuantizer(d, std::vector<size_t>(M, nbits), search_type) {}
+
+ void beam_search_encode_step(
+         size_t d,
+         size_t K,
+         const float* cent, /// size (K, d)
+         size_t n,
+         size_t beam_size,
+         const float* residuals, /// size (n, beam_size, d)
+         size_t m,
+         const int32_t* codes, /// size (n, beam_size, m)
+         size_t new_beam_size,
+         int32_t* new_codes, /// size (n, new_beam_size, m + 1)
+         float* new_residuals, /// size (n, new_beam_size, d)
+         float* new_distances, /// size (n, new_beam_size)
+         Index* assign_index) {
+     // we have to fill in the whole output matrix
+     FAISS_THROW_IF_NOT(new_beam_size <= beam_size * K);
+
+     using idx_t = Index::idx_t;
+
+     std::vector<float> cent_distances;
+     std::vector<idx_t> cent_ids;
+
+     if (assign_index) {
+         // search beam_size distances per query
+         FAISS_THROW_IF_NOT(assign_index->d == d);
+         cent_distances.resize(n * beam_size * new_beam_size);
+         cent_ids.resize(n * beam_size * new_beam_size);
+         if (assign_index->ntotal != 0) {
+             // then we assume the codebooks are already added to the index
+             FAISS_THROW_IF_NOT(assign_index->ntotal == K);
+         } else {
+             assign_index->add(K, cent);
+         }
+
+         // printf("beam_search_encode_step -- mem usage %zd\n",
+         // get_mem_usage_kb());
+         assign_index->search(
+                 n * beam_size,
+                 residuals,
+                 new_beam_size,
+                 cent_distances.data(),
+                 cent_ids.data());
+     } else {
+         // do one big distance computation
+         cent_distances.resize(n * beam_size * K);
+         pairwise_L2sqr(
+                 d, n * beam_size, residuals, K, cent, cent_distances.data());
+     }
+     InterruptCallback::check();
+
+ #pragma omp parallel for if (n > 100)
+     for (int64_t i = 0; i < n; i++) {
+         const int32_t* codes_i = codes + i * m * beam_size;
+         int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
+         const float* residuals_i = residuals + i * d * beam_size;
+         float* new_residuals_i = new_residuals + i * d * new_beam_size;
+
+         float* new_distances_i = new_distances + i * new_beam_size;
+         using C = CMax<float, int>;
+
+         if (assign_index) {
+             const float* cent_distances_i =
+                     cent_distances.data() + i * beam_size * new_beam_size;
+             const idx_t* cent_ids_i =
+                     cent_ids.data() + i * beam_size * new_beam_size;
+
+             // here we could be a tad more efficient by merging sorted arrays
+             for (int i = 0; i < new_beam_size; i++) {
+                 new_distances_i[i] = C::neutral();
+             }
+             std::vector<int> perm(new_beam_size, -1);
+             heap_addn<C>(
+                     new_beam_size,
+                     new_distances_i,
+                     perm.data(),
+                     cent_distances_i,
+                     nullptr,
+                     beam_size * new_beam_size);
+             heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
+
+             for (int j = 0; j < new_beam_size; j++) {
+                 int js = perm[j] / new_beam_size;
+                 int ls = cent_ids_i[perm[j]];
+                 if (m > 0) {
+                     memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
+                 }
+                 new_codes_i[m] = ls;
+                 new_codes_i += m + 1;
+                 fvec_sub(
+                         d,
+                         residuals_i + js * d,
+                         cent + ls * d,
+                         new_residuals_i);
+                 new_residuals_i += d;
+             }
+
+         } else {
+             const float* cent_distances_i =
+                     cent_distances.data() + i * beam_size * K;
+             // then we have to select the best results
+             for (int i = 0; i < new_beam_size; i++) {
+                 new_distances_i[i] = C::neutral();
+             }
+             std::vector<int> perm(new_beam_size, -1);
+             heap_addn<C>(
+                     new_beam_size,
+                     new_distances_i,
+                     perm.data(),
+                     cent_distances_i,
+                     nullptr,
+                     beam_size * K);
+             heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
+
+             for (int j = 0; j < new_beam_size; j++) {
+                 int js = perm[j] / K;
+                 int ls = perm[j] % K;
+                 if (m > 0) {
+                     memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
+                 }
+                 new_codes_i[m] = ls;
+                 new_codes_i += m + 1;
+                 fvec_sub(
+                         d,
+                         residuals_i + js * d,
+                         cent + ls * d,
+                         new_residuals_i);
+                 new_residuals_i += d;
+             }
+         }
+     }
+ }
+
+ void ResidualQuantizer::train(size_t n, const float* x) {
+     codebooks.resize(d * codebook_offsets.back());
+
+     if (verbose) {
+         printf("Training ResidualQuantizer, with %zd steps on %zd %zdD vectors\n",
+                M,
+                n,
+                size_t(d));
+     }
+
+     int cur_beam_size = 1;
+     std::vector<float> residuals(x, x + n * d);
+     std::vector<int32_t> codes;
+     std::vector<float> distances;
+     double t0 = getmillisecs();
+     double clustering_time = 0;
+
+     for (int m = 0; m < M; m++) {
+         int K = 1 << nbits[m];
+
+         // on which residuals to train
+         std::vector<float>& train_residuals = residuals;
+         std::vector<float> residuals1;
+         if (train_type & Train_top_beam) {
+             residuals1.resize(n * d);
+             for (size_t j = 0; j < n; j++) {
+                 memcpy(residuals1.data() + j * d,
+                        residuals.data() + j * d * cur_beam_size,
+                        sizeof(residuals[0]) * d);
+             }
+             train_residuals = residuals1;
+         }
+         train_type_t tt = train_type_t(train_type & 1023);
+
+         std::vector<float> codebooks;
+         float obj = 0;
+
+         std::unique_ptr<Index> assign_index;
+         if (assign_index_factory) {
+             assign_index.reset((*assign_index_factory)(d));
+         } else {
+             assign_index.reset(new IndexFlatL2(d));
+         }
+
+         double t1 = getmillisecs();
+
+         if (tt == Train_default) {
+             Clustering clus(d, K, cp);
+             clus.train(
+                     train_residuals.size() / d,
+                     train_residuals.data(),
+                     *assign_index.get());
+             codebooks.swap(clus.centroids);
+             assign_index->reset();
+             obj = clus.iteration_stats.back().obj;
+         } else if (tt == Train_progressive_dim) {
+             ProgressiveDimClustering clus(d, K, cp);
+             ProgressiveDimIndexFactory default_fac;
+             clus.train(
+                     train_residuals.size() / d,
+                     train_residuals.data(),
+                     assign_index_factory ? *assign_index_factory : default_fac);
+             codebooks.swap(clus.centroids);
+             obj = clus.iteration_stats.back().obj;
+         } else {
+             FAISS_THROW_MSG("train type not supported");
+         }
+         clustering_time += (getmillisecs() - t1) / 1000;
+
+         memcpy(this->codebooks.data() + codebook_offsets[m] * d,
+                codebooks.data(),
+                codebooks.size() * sizeof(codebooks[0]));
+
+         // quantize using the new codebooks
+
+         int new_beam_size = std::min(cur_beam_size * K, max_beam_size);
+         std::vector<int32_t> new_codes(n * new_beam_size * (m + 1));
+         std::vector<float> new_residuals(n * new_beam_size * d);
+         std::vector<float> new_distances(n * new_beam_size);
+
+         size_t bs;
+         { // determine batch size
+             size_t mem = memory_per_point();
+             if (n > 1 && mem * n > max_mem_distances) {
+                 // then split queries to reduce temp memory
+                 bs = std::max(max_mem_distances / mem, size_t(1));
+             } else {
+                 bs = n;
+             }
+         }
+
+         for (size_t i0 = 0; i0 < n; i0 += bs) {
+             size_t i1 = std::min(i0 + bs, n);
+
+             /* printf("i0: %ld i1: %ld K %d ntotal assign index %ld\n",
+                 i0, i1, K, assign_index->ntotal); */
+
+             beam_search_encode_step(
+                     d,
+                     K,
+                     codebooks.data(),
+                     i1 - i0,
+                     cur_beam_size,
+                     residuals.data() + i0 * cur_beam_size * d,
+                     m,
+                     codes.data() + i0 * cur_beam_size * m,
+                     new_beam_size,
+                     new_codes.data() + i0 * new_beam_size * (m + 1),
+                     new_residuals.data() + i0 * new_beam_size * d,
+                     new_distances.data() + i0 * new_beam_size,
+                     assign_index.get());
+         }
+         codes.swap(new_codes);
+         residuals.swap(new_residuals);
+         distances.swap(new_distances);
+
+         float sum_distances = 0;
+         for (int j = 0; j < distances.size(); j++) {
+             sum_distances += distances[j];
+         }
+
+         if (verbose) {
+             printf("[%.3f s, %.3f s clustering] train stage %d, %d bits, kmeans objective %g, "
+                    "total distance %g, beam_size %d->%d (batch size %zd)\n",
+                    (getmillisecs() - t0) / 1000,
+                    clustering_time,
+                    m,
+                    int(nbits[m]),
+                    obj,
+                    sum_distances,
+                    cur_beam_size,
+                    new_beam_size,
+                    bs);
+         }
+         cur_beam_size = new_beam_size;
+     }
+
+     // find min and max norms
+     std::vector<float> norms(n);
+
+     for (size_t i = 0; i < n; i++) {
+         norms[i] = fvec_L2sqr(
+                 x + i * d, residuals.data() + i * cur_beam_size * d, d);
+     }
+
+     // fvec_norms_L2sqr(norms.data(), x, d, n);
+
+     norm_min = HUGE_VALF;
+     norm_max = -HUGE_VALF;
+     for (idx_t i = 0; i < n; i++) {
+         if (norms[i] < norm_min) {
+             norm_min = norms[i];
+         }
+         if (norms[i] > norm_max) {
+             norm_max = norms[i];
+         }
+     }
+
+     if (search_type == ST_norm_cqint8 || search_type == ST_norm_cqint4) {
+         size_t k = (1 << 8);
+         if (search_type == ST_norm_cqint4) {
+             k = (1 << 4);
+         }
+         Clustering1D clus(k);
+         clus.train_exact(n, norms.data());
+         qnorm.add(clus.k, clus.centroids.data());
+     }
+
+     is_trained = true;
+
+     if (!(train_type & Skip_codebook_tables)) {
+         compute_codebook_tables();
+     }
+ }
+
+ size_t ResidualQuantizer::memory_per_point(int beam_size) const {
+     if (beam_size < 0) {
+         beam_size = max_beam_size;
+     }
+     size_t mem;
+     mem = beam_size * d * 2 * sizeof(float); // size for 2 beams at a time
+     mem += beam_size * beam_size *
+             (sizeof(float) +
+              sizeof(Index::idx_t)); // size for 1 beam search result
+     return mem;
+ }
+
+ void ResidualQuantizer::compute_codes(
+         const float* x,
+         uint8_t* codes_out,
+         size_t n) const {
+     FAISS_THROW_IF_NOT_MSG(is_trained, "RQ is not trained yet.");
+
+     size_t mem = memory_per_point();
+     if (n > 1 && mem * n > max_mem_distances) {
+         // then split queries to reduce temp memory
+         size_t bs = max_mem_distances / mem;
+         if (bs == 0) {
+             bs = 1; // otherwise we can't do much
+         }
+         for (size_t i0 = 0; i0 < n; i0 += bs) {
+             size_t i1 = std::min(n, i0 + bs);
+             compute_codes(x + i0 * d, codes_out + i0 * code_size, i1 - i0);
+         }
+         return;
+     }
+
+     std::vector<int32_t> codes(max_beam_size * M * n);
+     std::vector<float> norms;
+     std::vector<float> distances(max_beam_size * n);
+
+     if (use_beam_LUT == 0) {
+         std::vector<float> residuals(max_beam_size * n * d);
+
+         refine_beam(
+                 n,
+                 1,
+                 x,
+                 max_beam_size,
+                 codes.data(),
+                 residuals.data(),
+                 distances.data());
+
+         if (search_type == ST_norm_float || search_type == ST_norm_qint8 ||
+             search_type == ST_norm_qint4) {
+             norms.resize(n);
+             // recover the norms of reconstruction as
+             // || original_vector - residual ||^2
+             for (size_t i = 0; i < n; i++) {
+                 norms[i] = fvec_L2sqr(
+                         x + i * d, residuals.data() + i * max_beam_size * d, d);
+             }
+         }
+     } else if (use_beam_LUT == 1) {
+         FAISS_THROW_IF_NOT_MSG(
+                 codebook_cross_products.size() ==
+                         total_codebook_size * total_codebook_size,
+                 "call compute_codebook_tables first");
+
+         std::vector<float> query_norms(n);
+         fvec_norms_L2sqr(query_norms.data(), x, d, n);
+
+         std::vector<float> query_cp(n * total_codebook_size);
+         {
+             FINTEGER ti = total_codebook_size, di = d, ni = n;
+             float zero = 0, one = 1;
+             sgemm_("Transposed",
+                    "Not transposed",
+                    &ti,
+                    &ni,
+                    &di,
+                    &one,
+                    codebooks.data(),
+                    &di,
+                    x,
+                    &di,
+                    &zero,
+                    query_cp.data(),
+                    &ti);
+         }
+
+         refine_beam_LUT(
+                 n,
+                 query_norms.data(),
+                 query_cp.data(),
+                 max_beam_size,
+                 codes.data(),
+                 distances.data());
+     }
+     // pack only the first code of the beam (hence the ld_codes=M *
+     // max_beam_size)
+     pack_codes(
+             n,
+             codes.data(),
+             codes_out,
+             M * max_beam_size,
+             norms.size() > 0 ? norms.data() : nullptr);
+ }
+
+ void ResidualQuantizer::refine_beam(
+         size_t n,
+         size_t beam_size,
+         const float* x,
+         int out_beam_size,
+         int32_t* out_codes,
+         float* out_residuals,
+         float* out_distances) const {
+     int cur_beam_size = beam_size;
+
+     std::vector<float> residuals(x, x + n * d * beam_size);
+     std::vector<int32_t> codes;
+     std::vector<float> distances;
+     double t0 = getmillisecs();
+
+     std::unique_ptr<Index> assign_index;
+     if (assign_index_factory) {
+         assign_index.reset((*assign_index_factory)(d));
+     } else {
+         assign_index.reset(new IndexFlatL2(d));
+     }
+
+     for (int m = 0; m < M; m++) {
+         int K = 1 << nbits[m];
+
+         const float* codebooks_m =
+                 this->codebooks.data() + codebook_offsets[m] * d;
+
+         int new_beam_size = std::min(cur_beam_size * K, out_beam_size);
+
+         std::vector<int32_t> new_codes(n * new_beam_size * (m + 1));
+         std::vector<float> new_residuals(n * new_beam_size * d);
+         distances.resize(n * new_beam_size);
+
+         beam_search_encode_step(
+                 d,
+                 K,
+                 codebooks_m,
+                 n,
+                 cur_beam_size,
+                 residuals.data(),
+                 m,
+                 codes.data(),
+                 new_beam_size,
+                 new_codes.data(),
+                 new_residuals.data(),
+                 distances.data(),
+                 assign_index.get());
+
+         assign_index->reset();
+
+         codes.swap(new_codes);
+         residuals.swap(new_residuals);
+
+         cur_beam_size = new_beam_size;
+
+         if (verbose) {
+             float sum_distances = 0;
+             for (int j = 0; j < distances.size(); j++) {
+                 sum_distances += distances[j];
+             }
+             printf("[%.3f s] encode stage %d, %d bits, "
+                    "total error %g, beam_size %d\n",
+                    (getmillisecs() - t0) / 1000,
+                    m,
+                    int(nbits[m]),
+                    sum_distances,
+                    cur_beam_size);
+         }
+     }
+
+     if (out_codes) {
+         memcpy(out_codes, codes.data(), codes.size() * sizeof(codes[0]));
+     }
+     if (out_residuals) {
+         memcpy(out_residuals,
+                residuals.data(),
+                residuals.size() * sizeof(residuals[0]));
+     }
+     if (out_distances) {
+         memcpy(out_distances,
+                distances.data(),
+                distances.size() * sizeof(distances[0]));
+     }
+ }
+
+ /*******************************************************************
+  * Functions using the dot products between codebook entries
+  *******************************************************************/
+
+ void ResidualQuantizer::compute_codebook_tables() {
+     codebook_cross_products.resize(total_codebook_size * total_codebook_size);
+     cent_norms.resize(total_codebook_size);
+     // strictly speaking we could use ssyrk
+     {
+         FINTEGER ni = total_codebook_size;
+         FINTEGER di = d;
+         float zero = 0, one = 1;
+         sgemm_("Transposed",
+                "Not transposed",
+                &ni,
+                &ni,
+                &di,
+                &one,
+                codebooks.data(),
+                &di,
+                codebooks.data(),
+                &di,
+                &zero,
+                codebook_cross_products.data(),
+                &ni);
+     }
+     for (size_t i = 0; i < total_codebook_size; i++) {
+         cent_norms[i] = codebook_cross_products[i + i * total_codebook_size];
+     }
+ }
+
+ void beam_search_encode_step_tab(
+         size_t K,
+         size_t n,
+         size_t beam_size, // input sizes
+         const float* codebook_cross_norms, // size K * ldc
+         size_t ldc, // >= K
+         const uint64_t* codebook_offsets, // m
+         const float* query_cp, // size n * ldqc
+         size_t ldqc, // >= K
+         const float* cent_norms_i, // size K
+         size_t m,
+         const int32_t* codes, // n * beam_size * m
+         const float* distances, // n * beam_size
+         size_t new_beam_size,
+         int32_t* new_codes, // n * new_beam_size * (m + 1)
+         float* new_distances) // n * new_beam_size
+ {
+     FAISS_THROW_IF_NOT(ldc >= K);
+
+ #pragma omp parallel for if (n > 100)
+     for (int64_t i = 0; i < n; i++) {
+         std::vector<float> cent_distances(beam_size * K);
+         std::vector<float> cd_common(K);
+
+         const int32_t* codes_i = codes + i * m * beam_size;
+         const float* query_cp_i = query_cp + i * ldqc;
+         const float* distances_i = distances + i * beam_size;
+
+         for (size_t k = 0; k < K; k++) {
+             cd_common[k] = cent_norms_i[k] - 2 * query_cp_i[k];
+         }
+
+         for (size_t b = 0; b < beam_size; b++) {
+             std::vector<float> dp(K);
+
+             for (size_t m1 = 0; m1 < m; m1++) {
+                 size_t c = codes_i[b * m + m1];
+                 const float* cb =
+                         &codebook_cross_norms[(codebook_offsets[m1] + c) * ldc];
+                 fvec_add(K, cb, dp.data(), dp.data());
+             }
+
+             for (size_t k = 0; k < K; k++) {
+                 cent_distances[b * K + k] =
+                         distances_i[b] + cd_common[k] + 2 * dp[k];
+             }
+         }
+
+         using C = CMax<float, int>;
+         int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
+         float* new_distances_i = new_distances + i * new_beam_size;
+
+         const float* cent_distances_i = cent_distances.data();
+
+         // then we have to select the best results
+         for (int i = 0; i < new_beam_size; i++) {
+             new_distances_i[i] = C::neutral();
+         }
+         std::vector<int> perm(new_beam_size, -1);
+         heap_addn<C>(
+                 new_beam_size,
+                 new_distances_i,
+                 perm.data(),
+                 cent_distances_i,
+                 nullptr,
+                 beam_size * K);
+         heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
+
+         for (int j = 0; j < new_beam_size; j++) {
+             int js = perm[j] / K;
+             int ls = perm[j] % K;
+             if (m > 0) {
+                 memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
+             }
+             new_codes_i[m] = ls;
+             new_codes_i += m + 1;
+         }
+     }
+ }
+
+ void ResidualQuantizer::refine_beam_LUT(
+         size_t n,
+         const float* query_norms, // size n
+         const float* query_cp, //
+         int out_beam_size,
+         int32_t* out_codes,
+         float* out_distances) const {
+     int beam_size = 1;
+
+     std::vector<int32_t> codes;
+     std::vector<float> distances(query_norms, query_norms + n);
+     double t0 = getmillisecs();
+
+     for (int m = 0; m < M; m++) {
+         int K = 1 << nbits[m];
+
+         int new_beam_size = std::min(beam_size * K, out_beam_size);
+         std::vector<int32_t> new_codes(n * new_beam_size * (m + 1));
+         std::vector<float> new_distances(n * new_beam_size);
+
+         beam_search_encode_step_tab(
+                 K,
+                 n,
+                 beam_size,
+                 codebook_cross_products.data() + codebook_offsets[m],
+                 total_codebook_size,
+                 codebook_offsets.data(),
+                 query_cp + codebook_offsets[m],
+                 total_codebook_size,
+                 cent_norms.data() + codebook_offsets[m],
+                 m,
+                 codes.data(),
+                 distances.data(),
+                 new_beam_size,
+                 new_codes.data(),
+                 new_distances.data());
+
+         codes.swap(new_codes);
+         distances.swap(new_distances);
+         beam_size = new_beam_size;
+
+         if (verbose) {
+             float sum_distances = 0;
+             for (int j = 0; j < distances.size(); j++) {
+                 sum_distances += distances[j];
+             }
+             printf("[%.3f s] encode stage %d, %d bits, "
+                    "total error %g, beam_size %d\n",
+                    (getmillisecs() - t0) / 1000,
+                    m,
+                    int(nbits[m]),
+                    sum_distances,
+                    beam_size);
+         }
+     }
+
+     if (out_codes) {
+         memcpy(out_codes, codes.data(), codes.size() * sizeof(codes[0]));
+     }
+     if (out_distances) {
+         memcpy(out_distances,
+                distances.data(),
+                distances.size() * sizeof(distances[0]));
+     }
+ }
+
+ } // namespace faiss
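
For orientation, here is a minimal standalone sketch of how the quantizer added above is driven: the three-argument constructor, train(), and compute_codes() declared in this file. This is not part of the diff; it assumes the vendored faiss headers match upstream at this version (including a default search_type argument on the constructor) and uses synthetic data.

// Minimal usage sketch for the new ResidualQuantizer (assumes the vendored
// faiss headers; synthetic data; default search_type taken from the header).
#include <faiss/impl/ResidualQuantizer.h>

#include <cstdint>
#include <random>
#include <vector>

int main() {
    size_t d = 32;    // vector dimension
    size_t M = 4;     // number of residual codebooks (encoding steps)
    size_t nbits = 8; // bits per step -> K = 256 centroids per codebook
    size_t n = 10000; // number of training vectors

    // synthetic training data
    std::vector<float> xb(n * d);
    std::mt19937 rng(123);
    std::normal_distribution<float> distrib;
    for (float& v : xb) {
        v = distrib(rng);
    }

    // train the M codebooks, then encode each vector into code_size bytes
    faiss::ResidualQuantizer rq(d, M, nbits);
    rq.verbose = true; // prints the per-stage log emitted by train() above
    rq.train(n, xb.data());

    std::vector<uint8_t> codes(n * rq.code_size);
    rq.compute_codes(xb.data(), codes.data(), n);
    return 0;
}

Each vector is encoded greedily, stage by stage, against the residual left by the previous stages; max_beam_size (default 5 in this file) controls how many candidate encodings are kept per vector at each stage.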