faiss 0.5.2 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +5 -6
  6. data/ext/faiss/index_binary.cpp +76 -17
  7. data/ext/faiss/{index.cpp → index_rb.cpp} +108 -35
  8. data/ext/faiss/kmeans.cpp +12 -9
  9. data/ext/faiss/numo.hpp +11 -9
  10. data/ext/faiss/pca_matrix.cpp +10 -8
  11. data/ext/faiss/product_quantizer.cpp +14 -12
  12. data/ext/faiss/{utils.cpp → utils_rb.cpp} +10 -3
  13. data/ext/faiss/{utils.h → utils_rb.h} +6 -0
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +130 -11
  17. data/vendor/faiss/faiss/AutoTune.h +14 -1
  18. data/vendor/faiss/faiss/Clustering.cpp +59 -10
  19. data/vendor/faiss/faiss/Clustering.h +12 -0
  20. data/vendor/faiss/faiss/IVFlib.cpp +31 -28
  21. data/vendor/faiss/faiss/Index.cpp +20 -8
  22. data/vendor/faiss/faiss/Index.h +25 -3
  23. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
  24. data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
  25. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
  28. data/vendor/faiss/faiss/IndexFastScan.h +10 -1
  29. data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
  30. data/vendor/faiss/faiss/IndexFlat.h +16 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
  34. data/vendor/faiss/faiss/IndexHNSW.h +14 -12
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
  36. data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
  37. data/vendor/faiss/faiss/IndexIVF.h +14 -4
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
  40. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
  41. data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
  42. data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
  43. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
  44. data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
  45. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
  46. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
  47. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
  48. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
  49. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
  50. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
  51. data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
  52. data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
  53. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
  54. data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
  55. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  56. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
  57. data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
  58. data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
  59. data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
  60. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
  61. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
  62. data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
  63. data/vendor/faiss/faiss/IndexShards.cpp +3 -4
  64. data/vendor/faiss/faiss/MetricType.h +16 -0
  65. data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
  66. data/vendor/faiss/faiss/VectorTransform.h +23 -0
  67. data/vendor/faiss/faiss/clone_index.cpp +7 -4
  68. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
  69. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  70. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
  71. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
  72. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  73. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  74. data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
  75. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  76. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  77. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  78. data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
  79. data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
  80. data/vendor/faiss/faiss/impl/HNSW.h +8 -6
  81. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
  82. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  83. data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
  84. data/vendor/faiss/faiss/impl/NSG.h +17 -7
  85. data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
  86. data/vendor/faiss/faiss/impl/Panorama.h +22 -6
  87. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
  88. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
  89. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
  90. data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
  91. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
  92. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  93. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  94. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  95. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
  96. data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
  97. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
  98. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
  99. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  100. data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
  101. data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
  102. data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
  103. data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
  104. data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
  105. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
  106. data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
  107. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
  108. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
  109. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
  110. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
  111. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
  112. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
  113. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
  114. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
  115. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
  116. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
  117. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  118. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
  119. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
  120. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
  121. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  122. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
  123. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
  124. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
  125. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
  126. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
  127. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
  128. data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
  129. data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
  130. data/vendor/faiss/faiss/index_factory.cpp +35 -16
  131. data/vendor/faiss/faiss/index_io.h +29 -3
  132. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
  133. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  134. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  135. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  136. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
  137. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
  138. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
  139. data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
  140. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  141. data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
  142. data/vendor/faiss/faiss/utils/distances.cpp +141 -23
  143. data/vendor/faiss/faiss/utils/distances.h +98 -0
  144. data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
  145. data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
  146. data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
  147. data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
  148. data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
  149. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
  150. data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
  151. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  152. data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
  153. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
  154. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
  155. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
  156. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
  157. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
  158. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
  159. data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
  160. data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
  161. data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
  162. data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
  163. data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
  164. data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
  165. data/vendor/faiss/faiss/utils/utils.cpp +16 -9
  166. metadata +47 -18
  167. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  168. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  169. /data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -0,0 +1,367 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/impl/ClusteringInitialization.h>
9
+
10
+ #include <algorithm>
11
+ #include <chrono>
12
+ #include <cstring>
13
+ #include <limits>
14
+ #include <random>
15
+ #include <unordered_set>
16
+ #include <vector>
17
+
18
+ #include <faiss/impl/FaissAssert.h>
19
+ #include <faiss/utils/distances_dispatch.h>
20
+ #include <faiss/utils/random.h>
21
+
22
+ namespace faiss {
23
+
24
+ namespace {
25
+
26
/// Turn the user-facing seed into a value for std::mt19937_64.
/// Non-negative seeds are used verbatim, which makes runs reproducible.
/// A negative seed means "no fixed seed" and is replaced by the current
/// tick count of std::chrono::high_resolution_clock.
uint64_t get_seed(int64_t seed) {
    if (seed < 0) {
        const auto ticks = std::chrono::high_resolution_clock::now()
                                   .time_since_epoch()
                                   .count();
        return static_cast<uint64_t>(ticks);
    }
    return static_cast<uint64_t>(seed);
}
34
+
35
+ /// Compute distance from point idx to its nearest centroid.
36
+ /// Optionally checks both primary and secondary centroid sets.
37
+ float distance_to_nearest_centroid(
38
+ size_t d,
39
+ size_t n_centroids,
40
+ const float* x,
41
+ size_t idx,
42
+ const float* centroids,
43
+ size_t n_existing_centroids = 0,
44
+ const float* existing_centroids = nullptr) {
45
+ if (n_centroids == 0 && n_existing_centroids == 0) {
46
+ return std::numeric_limits<float>::infinity();
47
+ }
48
+
49
+ const float* point = x + idx * d;
50
+ float min_dist = std::numeric_limits<float>::max();
51
+
52
+ auto check_centroids = [&]<SIMDLevel SL>() {
53
+ // Check primary centroids
54
+ for (size_t c = 0; c < n_centroids; c++) {
55
+ float dist = fvec_L2sqr<SL>(point, centroids + c * d, d);
56
+ min_dist = std::min(min_dist, dist);
57
+ }
58
+
59
+ // Check existing centroids if provided
60
+ for (size_t c = 0; c < n_existing_centroids; c++) {
61
+ float dist = fvec_L2sqr<SL>(point, existing_centroids + c * d, d);
62
+ min_dist = std::min(min_dist, dist);
63
+ }
64
+ };
65
+ with_simd_level(check_centroids);
66
+ return min_dist;
67
+ }
68
+
69
/// Outcome of seeding the distance array used by D² sampling.
struct InitDistancesResult {
    /// First centroid slot that still has to be filled.
    /// It is 0 when distances were seeded from caller-supplied centroids,
    /// and 1 when slot 0 was filled with a randomly drawn point.
    size_t first_new_centroid_idx = 0;
    /// Sum of the per-point squared distances.
    double sum_d2 = 0.0;
    /// Index of the point copied into slot 0.
    /// Only meaningful when first_new_centroid_idx == 1.
    size_t first_selected_idx = 0;
};
75
+
76
+ /// Initialize distance array for D² sampling.
77
+ /// If existing centroids are provided, computes distances to them.
78
+ /// Otherwise, selects first centroid randomly and computes distances to it.
79
+ /// Returns first_new_centroid_idx (0 if existing, 1 if random first),
80
+ /// sum of squared distances, and the first selected index (if applicable).
81
+ InitDistancesResult init_distances_for_d2_sampling(
82
+ size_t d,
83
+ size_t n,
84
+ const float* x,
85
+ float* centroids,
86
+ size_t n_existing_centroids,
87
+ const float* existing_centroids,
88
+ std::vector<double>& distances,
89
+ std::mt19937_64& rng) {
90
+ double sum_d2 = 0.0;
91
+ size_t first_selected_idx = 0;
92
+
93
+ if (n_existing_centroids > 0 && existing_centroids != nullptr) {
94
+ // Compute distances to nearest existing centroid
95
+ for (size_t i = 0; i < n; i++) {
96
+ distances[i] = distance_to_nearest_centroid(
97
+ d, n_existing_centroids, x, i, existing_centroids);
98
+ sum_d2 += distances[i];
99
+ }
100
+ return {0, sum_d2, 0};
101
+ } else {
102
+ // Select first centroid randomly
103
+ std::uniform_int_distribution<size_t> uniform_dist(0, n - 1);
104
+ first_selected_idx = uniform_dist(rng);
105
+ std::memcpy(centroids, x + first_selected_idx * d, d * sizeof(float));
106
+
107
+ // Compute distances to first centroid
108
+ with_simd_level([&]<SIMDLevel SL>() {
109
+ for (size_t i = 0; i < n; i++) {
110
+ distances[i] = fvec_L2sqr<SL>(x + i * d, centroids, d);
111
+ sum_d2 += distances[i];
112
+ }
113
+ });
114
+ return {1, sum_d2, first_selected_idx};
115
+ }
116
+ }
117
+
118
/// Draw an index i with probability proportional to its weight, given the
/// inclusive prefix sums `q_cumsum` of non-negative weights.
///
/// Entry i owns the half-open interval [q_cumsum[i-1], q_cumsum[i]), so
/// entries with zero weight are never returned.
///
/// Returns 0 for an empty input. Falls back to a uniform draw over all
/// entries when the total weight is not positive.
size_t sample_from_cumsum(
        const std::vector<double>& q_cumsum,
        std::mt19937_64& rng) {
    const size_t n = q_cumsum.size();
    if (n == 0) {
        return 0;
    }

    const double total = q_cumsum[n - 1];
    if (total <= 0) {
        // Fallback to uniform sampling if all weights are zero
        std::uniform_int_distribution<size_t> uniform(0, n - 1);
        return uniform(rng);
    }

    std::uniform_real_distribution<double> dist(0.0, total);
    const double r = dist(rng);

    // upper_bound (first prefix sum strictly greater than r) maps r to the
    // entry whose half-open interval contains it. lower_bound would return
    // a leading zero-weight entry when r == 0. In k-means++ such an entry
    // is an already-selected centroid, i.e. a duplicate.
    auto it = std::upper_bound(q_cumsum.begin(), q_cumsum.end(), r);
    if (it == q_cumsum.end()) {
        // Floating-point rounding can make r reach `total`. In that case,
        // pick the last entry with positive weight.
        it = std::lower_bound(q_cumsum.begin(), q_cumsum.end(), total);
    }
    return static_cast<size_t>(std::distance(q_cumsum.begin(), it));
}
142
+
143
+ } // namespace
144
+
145
+ ClusteringInitialization::ClusteringInitialization(size_t d, size_t k)
146
+ : d(d), k(k) {}
147
+
148
+ void ClusteringInitialization::init_centroids(
149
+ size_t n,
150
+ const float* x,
151
+ float* centroids,
152
+ size_t n_existing_centroids,
153
+ const float* existing_centroids) const {
154
+ FAISS_THROW_IF_NOT_FMT(
155
+ n >= k,
156
+ "Number of points (%zu) must be >= number of centroids (%zu)",
157
+ n,
158
+ k);
159
+ FAISS_THROW_IF_NOT(d > 0);
160
+ FAISS_THROW_IF_NOT(x != nullptr);
161
+ FAISS_THROW_IF_NOT(centroids != nullptr);
162
+ FAISS_THROW_IF_NOT(
163
+ n_existing_centroids == 0 || existing_centroids != nullptr);
164
+
165
+ switch (method) {
166
+ case ClusteringInitMethod::RANDOM:
167
+ init_random(n, x, centroids);
168
+ break;
169
+ case ClusteringInitMethod::KMEANS_PLUS_PLUS:
170
+ init_kmeans_plus_plus(
171
+ n, x, centroids, n_existing_centroids, existing_centroids);
172
+ break;
173
+ case ClusteringInitMethod::AFK_MC2:
174
+ init_afkmc2(
175
+ n, x, centroids, n_existing_centroids, existing_centroids);
176
+ break;
177
+ default:
178
+ FAISS_THROW_MSG("Unknown initialization method");
179
+ }
180
+ }
181
+
182
+ void ClusteringInitialization::init_random(
183
+ size_t n,
184
+ const float* x,
185
+ float* centroids) const {
186
+ // Use rand_perm for backward compatibility with Clustering.cpp
187
+ // This ensures the same random sequence as the original implementation
188
+ std::vector<int> perm(n);
189
+ rand_perm(perm.data(), n, seed);
190
+
191
+ // Copy selected points to centroids
192
+ for (size_t i = 0; i < k; i++) {
193
+ std::memcpy(centroids + i * d, x + perm[i] * d, d * sizeof(float));
194
+ }
195
+ }
196
+
197
+ void ClusteringInitialization::init_kmeans_plus_plus(
198
+ size_t n,
199
+ const float* x,
200
+ float* centroids,
201
+ size_t n_existing_centroids,
202
+ const float* existing_centroids) const {
203
+ std::mt19937_64 rng(get_seed(seed));
204
+
205
+ std::vector<double> min_distances(n);
206
+ auto result = init_distances_for_d2_sampling(
207
+ d,
208
+ n,
209
+ x,
210
+ centroids,
211
+ n_existing_centroids,
212
+ existing_centroids,
213
+ min_distances,
214
+ rng);
215
+
216
+ if (result.first_new_centroid_idx == 1 && k == 1) {
217
+ return;
218
+ }
219
+
220
+ // Reusable buffer for cumulative sum
221
+ std::vector<double> cumsum(n);
222
+
223
+ // Select remaining centroids using D² sampling
224
+ for (size_t c = result.first_new_centroid_idx; c < k; c++) {
225
+ // Compute cumulative sum
226
+ cumsum[0] = min_distances[0];
227
+ for (size_t i = 1; i < n; i++) {
228
+ cumsum[i] = cumsum[i - 1] + min_distances[i];
229
+ }
230
+
231
+ // Sample using precomputed cumsum
232
+ size_t next_idx = sample_from_cumsum(cumsum, rng);
233
+
234
+ float* new_centroid = centroids + c * d;
235
+ std::memcpy(new_centroid, x + next_idx * d, d * sizeof(float));
236
+
237
+ // Update min distances incrementally
238
+ with_simd_level([&]<SIMDLevel SL>() {
239
+ for (size_t i = 0; i < n; i++) {
240
+ double dist = fvec_L2sqr<SL>(x + i * d, new_centroid, d);
241
+ min_distances[i] = std::min(min_distances[i], dist);
242
+ }
243
+ });
244
+ }
245
+ }
246
+
247
+ void ClusteringInitialization::init_afkmc2(
248
+ size_t n,
249
+ const float* x,
250
+ float* centroids,
251
+ size_t n_existing_centroids,
252
+ const float* existing_centroids) const {
253
+ // AFK-MC² (Assumption-Free K-MC²) algorithm:
254
+ // Reference: Bachem et al., "Fast and Provably Good Seedings for
255
+ // k-Means"
256
+
257
+ std::mt19937_64 rng(get_seed(seed));
258
+ std::uniform_real_distribution<double> uniform_01(0.0, 1.0);
259
+
260
+ // Track selected centroids to prevent duplicates
261
+ std::unordered_set<size_t> selected_centroids;
262
+
263
+ // Compute proposal distribution q(x)
264
+ // If existing centroids: base q on distance to nearest existing
265
+ // centroid Otherwise: select first centroid randomly and base q on
266
+ // it
267
+ std::vector<double> dist_to_nearest(n);
268
+ auto result = init_distances_for_d2_sampling(
269
+ d,
270
+ n,
271
+ x,
272
+ centroids,
273
+ n_existing_centroids,
274
+ existing_centroids,
275
+ dist_to_nearest,
276
+ rng);
277
+
278
+ if (result.first_new_centroid_idx == 1) {
279
+ selected_centroids.insert(result.first_selected_idx);
280
+ if (k == 1) {
281
+ return;
282
+ }
283
+ }
284
+
285
+ // Compute q(x) and cumulative sum in a single pass
286
+ std::vector<double> q(n);
287
+ std::vector<double> q_cumsum(n);
288
+ double uniform_term = 0.5 / static_cast<double>(n);
289
+
290
+ for (size_t i = 0; i < n; i++) {
291
+ double d2_term = (result.sum_d2 > 0)
292
+ ? 0.5 * dist_to_nearest[i] / result.sum_d2
293
+ : 0.0;
294
+ q[i] = d2_term + uniform_term;
295
+ q_cumsum[i] = (i > 0 ? q_cumsum[i - 1] : 0.0) + q[i];
296
+ }
297
+
298
+ // Main loop: Select remaining centroids using MCMC
299
+ for (size_t c = result.first_new_centroid_idx; c < k; c++) {
300
+ // Sample initial candidate from proposal distribution q, skip
301
+ // duplicates
302
+ size_t current_idx;
303
+ do {
304
+ current_idx = sample_from_cumsum(q_cumsum, rng);
305
+ } while (selected_centroids.count(current_idx) > 0);
306
+
307
+ // Compute distance to nearest centroid (existing + newly
308
+ // selected)
309
+ double current_dist = distance_to_nearest_centroid(
310
+ d,
311
+ c,
312
+ x,
313
+ current_idx,
314
+ centroids,
315
+ n_existing_centroids,
316
+ existing_centroids);
317
+ double current_q = q[current_idx];
318
+
319
+ // Run Markov chain
320
+ for (size_t m = 0; m < afkmc2_chain_length; m++) {
321
+ // Sample proposal from q
322
+ size_t proposed_idx = sample_from_cumsum(q_cumsum, rng);
323
+
324
+ // Skip duplicates before expensive distance computation
325
+ if (selected_centroids.count(proposed_idx) > 0) {
326
+ continue;
327
+ }
328
+
329
+ // Compute distance to nearest centroid (existing + newly
330
+ // selected)
331
+ double proposed_dist = distance_to_nearest_centroid(
332
+ d,
333
+ c,
334
+ x,
335
+ proposed_idx,
336
+ centroids,
337
+ n_existing_centroids,
338
+ existing_centroids);
339
+ double proposed_q = q[proposed_idx];
340
+
341
+ // Metropolis-Hastings acceptance ratio:
342
+ // accept = min(1, d(y,C)² · q(x) / (d(x,C)² · q(y)))
343
+ double acceptance_prob = 0.0;
344
+ if (current_dist <= 0) {
345
+ // Current point is a centroid (distance = 0), never
346
+ // leave
347
+ acceptance_prob = 0.0;
348
+ } else if (proposed_q > 0) {
349
+ double numerator = proposed_dist * current_q;
350
+ double denominator = current_dist * proposed_q;
351
+ acceptance_prob = std::min(1.0, numerator / denominator);
352
+ }
353
+
354
+ if (uniform_01(rng) < acceptance_prob) {
355
+ current_idx = proposed_idx;
356
+ current_dist = proposed_dist;
357
+ current_q = proposed_q;
358
+ }
359
+ }
360
+
361
+ // Use final chain state as new centroid
362
+ selected_centroids.insert(current_idx);
363
+ std::memcpy(centroids + c * d, x + current_idx * d, d * sizeof(float));
364
+ }
365
+ }
366
+
367
+ } // namespace faiss
@@ -0,0 +1,107 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <cstddef>
11
+ #include <cstdint>
12
+
13
+ namespace faiss {
14
+
15
/// Initialization methods for k-means clustering centroids.
/// The underlying type is uint8_t; values follow declaration order
/// (RANDOM = 0, KMEANS_PLUS_PLUS = 1, AFK_MC2 = 2).
enum class ClusteringInitMethod : uint8_t {
    /// Random sampling: pick the points at the first k entries of a random
    /// permutation (rand_perm) of the n indices.
    /// Pre-existing centroids are not taken into account.
    /// Time complexity: O(n + kd). The permutation of all n indices is
    /// materialized before k rows are copied.
    RANDOM,

    /// k-means++: select centroids with probability proportional to D(x)²,
    /// where D(x) is the distance to the nearest existing centroid.
    /// Reference: Arthur, D., & Vassilvitskii, S. (2006). k-means++:
    /// The advantages of careful seeding. Stanford.
    /// Time complexity: O(nkd)
    KMEANS_PLUS_PLUS,

    /// AFK-MC²: Assumption-Free K-MC² using Markov Chain Monte Carlo.
    /// Provides theoretical guarantees without assumptions on data
    /// distribution.
    /// The proposal distribution gives weight 1/2 to D²-sampling and 1/2
    /// to uniform sampling. D² is measured with respect to the pre-existing
    /// centroids, or to a randomly drawn first center when there are none.
    /// Reference: Bachem, O., Lucic, M., Hassani, H., & Krause, A. (2016).
    /// Fast and provably good seedings for k-means. Advances in neural
    /// information processing systems, 29.
    /// Time complexity: O(nd) preprocessing + O(mk²d) main loop, where m is
    /// the chain length.
    AFK_MC2
};
39
+
40
+ /// Centroid initialization for k-means clustering.
41
+ ///
42
+ /// This class provides different algorithms for selecting initial centroids
43
+ /// before running k-means iterations. Good initialization can significantly
44
+ /// improve clustering quality and convergence speed.
45
+ ///
46
+ /// Example usage:
47
+ /// @code
48
+ /// ClusteringInitialization init(128, 1000); // d=128, k=1000
49
+ /// init.method = ClusteringInitMethod::KMEANS_PLUS_PLUS;
50
+ /// init.seed = 42;
51
+ ///
52
+ /// std::vector<float> centroids(128 * 1000);
53
+ /// init.init_centroids(n, x, centroids.data());
54
+ /// @endcode
55
+ struct ClusteringInitialization {
56
+ size_t d; ///< vector dimension
57
+ size_t k; ///< number of centroids to initialize
58
+
59
+ /// Initialization method to use
60
+ ClusteringInitMethod method = ClusteringInitMethod::RANDOM;
61
+
62
+ /// Random seed.
63
+ int64_t seed = 1234;
64
+
65
+ /// Chain length for AFK-MC² (only used when method = AFK_MC2).
66
+ /// Longer chains give better approximation to k-means++ but are slower.
67
+ uint16_t afkmc2_chain_length = 50;
68
+
69
+ ClusteringInitialization(size_t d, size_t k);
70
+
71
+ /// Initialize k centroids from n input vectors.
72
+ ///
73
+ /// @param n number of input vectors
74
+ /// @param x input vectors, size (n, d), row-major
75
+ /// @param centroids output centroids, size (k, d), row-major
76
+ /// @param n_existing_centroids number of pre-existing centroids to
77
+ /// consider
78
+ /// when computing distances (for k-means++ and
79
+ /// AFK-MC²). These centroids are not modified.
80
+ /// @param existing_centroids pre-existing centroids, size
81
+ /// (n_existing_centroids, d), row-major.
82
+ /// New centroids will be selected to be far
83
+ /// from these existing ones.
84
+ void init_centroids(
85
+ size_t n,
86
+ const float* x,
87
+ float* centroids,
88
+ size_t n_existing_centroids = 0,
89
+ const float* existing_centroids = nullptr) const;
90
+
91
+ private:
92
+ void init_random(size_t n, const float* x, float* centroids) const;
93
+ void init_kmeans_plus_plus(
94
+ size_t n,
95
+ const float* x,
96
+ float* centroids,
97
+ size_t n_existing_centroids,
98
+ const float* existing_centroids) const;
99
+ void init_afkmc2(
100
+ size_t n,
101
+ const float* x,
102
+ float* centroids,
103
+ size_t n_existing_centroids,
104
+ const float* existing_centroids) const;
105
+ };
106
+
107
+ } // namespace faiss
@@ -64,4 +64,8 @@ void CodePackerFlat::unpack_1(
64
64
  unpack_all(block, flat_code);
65
65
  }
66
66
 
67
+ CodePacker* CodePackerFlat::clone() const {
68
+ return new CodePackerFlat(*this);
69
+ }
70
+
67
71
  } // namespace faiss
@@ -18,9 +18,13 @@ namespace faiss {
18
18
  * the "fast_scan" indexes on CPU and for some GPU kernels.
19
19
  */
20
20
  struct CodePacker {
21
- size_t code_size; // input code size in bytes
22
- size_t nvec; // number of vectors per block
23
- size_t block_size; // size of one block in bytes (>= code_size * nvec)
21
+ size_t code_size = 0; // input code size in bytes
22
+ size_t nvec = 0; // number of vectors per block
23
+ size_t block_size = 0; // size of one block in bytes (>= code_size * nvec)
24
+
25
+ CodePacker() = default;
26
+ CodePacker(const CodePacker&) = default;
27
+ CodePacker& operator=(const CodePacker&) = default;
24
28
 
25
29
  // pack a single code to a block
26
30
  virtual void pack_1(
@@ -52,6 +56,8 @@ struct CodePacker {
52
56
  // * code_size)
53
57
  ) const;
54
58
 
59
+ virtual CodePacker* clone() const = 0;
60
+
55
61
  virtual ~CodePacker() {}
56
62
  };
57
63
 
@@ -59,6 +65,8 @@ struct CodePacker {
59
65
  struct CodePackerFlat : CodePacker {
60
66
  explicit CodePackerFlat(size_t code_size);
61
67
 
68
+ CodePacker* clone() const final;
69
+
62
70
  void pack_1(const uint8_t* flat_code, size_t offset, uint8_t* block)
63
71
  const final;
64
72
  void unpack_1(const uint8_t* block, size_t offset, uint8_t* flat_code)
@@ -0,0 +1,83 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/impl/CodePackerRaBitQ.h>
9
+ #include <faiss/impl/pq4_fast_scan.h>
10
+
11
+ #include <cstring>
12
+
13
+ namespace faiss {
14
+
15
+ CodePackerRaBitQ::CodePackerRaBitQ(
16
+ size_t nsq,
17
+ size_t bbs,
18
+ size_t aux_per_vector) {
19
+ this->nsq = nsq;
20
+ this->aux_size_per_vec = aux_per_vector;
21
+ nvec = bbs;
22
+ const size_t pq4_bytes = (nsq * 4 + 7) / 8;
23
+ // code_size covers PQ4 codes + auxiliary data so that callers
24
+ // (BlockInvertedLists::remove_ids, add_entries, etc.) allocate
25
+ // buffers large enough and pack_1/unpack_1 transfer everything.
26
+ code_size = pq4_bytes + aux_per_vector;
27
+ // block_size = PQ4 packed codes + auxiliary data region
28
+ block_size = ((nsq + 1) / 2) * bbs + aux_per_vector * bbs;
29
+ }
30
+
31
+ void CodePackerRaBitQ::pack_1(
32
+ const uint8_t* flat_code,
33
+ size_t offset,
34
+ uint8_t* block) const {
35
+ const size_t bbs = nvec;
36
+ const size_t pq4_bytes = (nsq * 4 + 7) / 8;
37
+ if (offset >= nvec) {
38
+ block += (offset / nvec) * block_size;
39
+ offset = offset % nvec;
40
+ }
41
+ for (size_t i = 0; i < pq4_bytes; i++) {
42
+ uint8_t code = flat_code[i];
43
+ pq4_set_packed_element(block, code & 15, bbs, nsq, offset, 2 * i);
44
+ pq4_set_packed_element(block, code >> 4, bbs, nsq, offset, 2 * i + 1);
45
+ }
46
+ // Pack auxiliary data (factors, ex-codes) into the block aux region
47
+ if (aux_size_per_vec > 0) {
48
+ const size_t packed_block_size = ((nsq + 1) / 2) * bbs;
49
+ uint8_t* dst = block + packed_block_size + offset * aux_size_per_vec;
50
+ memcpy(dst, flat_code + pq4_bytes, aux_size_per_vec);
51
+ }
52
+ }
53
+
54
+ void CodePackerRaBitQ::unpack_1(
55
+ const uint8_t* block,
56
+ size_t offset,
57
+ uint8_t* flat_code) const {
58
+ const size_t bbs = nvec;
59
+ const size_t pq4_bytes = (nsq * 4 + 7) / 8;
60
+ if (offset >= nvec) {
61
+ block += (offset / nvec) * block_size;
62
+ offset = offset % nvec;
63
+ }
64
+ for (size_t i = 0; i < pq4_bytes; i++) {
65
+ uint8_t code0, code1;
66
+ code0 = pq4_get_packed_element(block, bbs, nsq, offset, 2 * i);
67
+ code1 = pq4_get_packed_element(block, bbs, nsq, offset, 2 * i + 1);
68
+ flat_code[i] = code0 | (code1 << 4);
69
+ }
70
+ // Unpack auxiliary data from the block aux region
71
+ if (aux_size_per_vec > 0) {
72
+ const size_t packed_block_size = ((nsq + 1) / 2) * bbs;
73
+ const uint8_t* src =
74
+ block + packed_block_size + offset * aux_size_per_vec;
75
+ memcpy(flat_code + pq4_bytes, src, aux_size_per_vec);
76
+ }
77
+ }
78
+
79
+ CodePacker* CodePackerRaBitQ::clone() const {
80
+ return new CodePackerRaBitQ(*this);
81
+ }
82
+
83
+ } // namespace faiss
@@ -0,0 +1,47 @@
1
+ /*
2
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <cstddef>
11
+ #include <cstdint>
12
+
13
+ #include <faiss/impl/CodePacker.h>
14
+
15
+ namespace faiss {
16
+
17
+ /** CodePacker for RaBitQ that allocates enlarged blocks.
18
+ *
19
+ * Each block contains the standard PQ4 packed codes region (bbs * nsq / 2
20
+ * bytes) followed by an auxiliary data region for per-vector factors.
21
+ * The pack_1/unpack_1 operations transfer BOTH the PQ4 codes and the
22
+ * auxiliary data, so callers such as BlockInvertedLists::remove_ids()
23
+ * and add_entries() automatically preserve auxiliary data.
24
+ *
25
+ * code_size = PQ4 flat bytes + aux_size_per_vec, which must match the
26
+ * buffer sizes allocated by callers (e.g. the index's code_size field).
27
+ */
28
+ struct CodePackerRaBitQ : CodePacker {
29
+ size_t nsq;
30
+ size_t aux_size_per_vec;
31
+
32
+ /** Construct a RaBitQ code packer.
33
+ * @param nsq number of sub-quantizers (M2)
34
+ * @param bbs block size (number of vectors per block)
35
+ * @param aux_per_vector bytes of auxiliary data per vector
36
+ */
37
+ CodePackerRaBitQ(size_t nsq, size_t bbs, size_t aux_per_vector);
38
+
39
+ CodePacker* clone() const final;
40
+
41
+ void pack_1(const uint8_t* flat_code, size_t offset, uint8_t* block)
42
+ const final;
43
+ void unpack_1(const uint8_t* block, size_t offset, uint8_t* flat_code)
44
+ const final;
45
+ };
46
+
47
+ } // namespace faiss