faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -8,6 +8,7 @@
8
8
  // -*- c++ -*-
9
9
 
10
10
  #include <faiss/Clustering.h>
11
+ #include <faiss/VectorTransform.h>
11
12
  #include <faiss/impl/AuxIndexStructures.h>
12
13
 
13
14
  #include <cinttypes>
@@ -17,100 +18,101 @@
17
18
 
18
19
  #include <omp.h>
19
20
 
20
- #include <faiss/utils/utils.h>
21
- #include <faiss/utils/random.h>
22
- #include <faiss/utils/distances.h>
23
- #include <faiss/impl/FaissAssert.h>
24
21
  #include <faiss/IndexFlat.h>
22
+ #include <faiss/impl/FaissAssert.h>
23
+ #include <faiss/impl/kmeans1d.h>
24
+ #include <faiss/utils/distances.h>
25
+ #include <faiss/utils/random.h>
26
+ #include <faiss/utils/utils.h>
25
27
 
26
28
  namespace faiss {
27
29
 
28
- ClusteringParameters::ClusteringParameters ():
29
- niter(25),
30
- nredo(1),
31
- verbose(false),
32
- spherical(false),
33
- int_centroids(false),
34
- update_index(false),
35
- frozen_centroids(false),
36
- min_points_per_centroid(39),
37
- max_points_per_centroid(256),
38
- seed(1234),
39
- decode_block_size(32768)
40
- {}
30
+ ClusteringParameters::ClusteringParameters()
31
+ : niter(25),
32
+ nredo(1),
33
+ verbose(false),
34
+ spherical(false),
35
+ int_centroids(false),
36
+ update_index(false),
37
+ frozen_centroids(false),
38
+ min_points_per_centroid(39),
39
+ max_points_per_centroid(256),
40
+ seed(1234),
41
+ decode_block_size(32768) {}
41
42
  // 39 corresponds to 10000 / 256 -> to avoid warnings on PQ tests with randu10k
42
43
 
44
+ Clustering::Clustering(int d, int k) : d(d), k(k) {}
43
45
 
44
- Clustering::Clustering (int d, int k):
45
- d(d), k(k) {}
46
-
47
- Clustering::Clustering (int d, int k, const ClusteringParameters &cp):
48
- ClusteringParameters (cp), d(d), k(k) {}
46
+ Clustering::Clustering(int d, int k, const ClusteringParameters& cp)
47
+ : ClusteringParameters(cp), d(d), k(k) {}
49
48
 
50
-
51
-
52
- static double imbalance_factor (int n, int k, int64_t *assign) {
49
+ static double imbalance_factor(int n, int k, int64_t* assign) {
53
50
  std::vector<int> hist(k, 0);
54
51
  for (int i = 0; i < n; i++)
55
52
  hist[assign[i]]++;
56
53
 
57
54
  double tot = 0, uf = 0;
58
55
 
59
- for (int i = 0 ; i < k ; i++) {
56
+ for (int i = 0; i < k; i++) {
60
57
  tot += hist[i];
61
- uf += hist[i] * (double) hist[i];
58
+ uf += hist[i] * (double)hist[i];
62
59
  }
63
60
  uf = uf * k / (tot * tot);
64
61
 
65
62
  return uf;
66
63
  }
67
64
 
68
- void Clustering::post_process_centroids ()
69
- {
70
-
65
+ void Clustering::post_process_centroids() {
71
66
  if (spherical) {
72
- fvec_renorm_L2 (d, k, centroids.data());
67
+ fvec_renorm_L2(d, k, centroids.data());
73
68
  }
74
69
 
75
70
  if (int_centroids) {
76
71
  for (size_t i = 0; i < centroids.size(); i++)
77
- centroids[i] = roundf (centroids[i]);
72
+ centroids[i] = roundf(centroids[i]);
78
73
  }
79
74
  }
80
75
 
81
-
82
- void Clustering::train (idx_t nx, const float *x_in, Index & index,
83
- const float *weights) {
84
- train_encoded (nx, reinterpret_cast<const uint8_t *>(x_in), nullptr,
85
- index, weights);
76
+ void Clustering::train(
77
+ idx_t nx,
78
+ const float* x_in,
79
+ Index& index,
80
+ const float* weights) {
81
+ train_encoded(
82
+ nx,
83
+ reinterpret_cast<const uint8_t*>(x_in),
84
+ nullptr,
85
+ index,
86
+ weights);
86
87
  }
87
88
 
88
-
89
89
  namespace {
90
90
 
91
91
  using idx_t = Clustering::idx_t;
92
92
 
93
93
  idx_t subsample_training_set(
94
- const Clustering &clus, idx_t nx, const uint8_t *x,
95
- size_t line_size, const float * weights,
96
- uint8_t **x_out,
97
- float **weights_out
98
- )
99
- {
94
+ const Clustering& clus,
95
+ idx_t nx,
96
+ const uint8_t* x,
97
+ size_t line_size,
98
+ const float* weights,
99
+ uint8_t** x_out,
100
+ float** weights_out) {
100
101
  if (clus.verbose) {
101
102
  printf("Sampling a subset of %zd / %" PRId64 " for training\n",
102
- clus.k * clus.max_points_per_centroid, nx);
103
+ clus.k * clus.max_points_per_centroid,
104
+ nx);
103
105
  }
104
- std::vector<int> perm (nx);
105
- rand_perm (perm.data (), nx, clus.seed);
106
+ std::vector<int> perm(nx);
107
+ rand_perm(perm.data(), nx, clus.seed);
106
108
  nx = clus.k * clus.max_points_per_centroid;
107
- uint8_t * x_new = new uint8_t [nx * line_size];
109
+ uint8_t* x_new = new uint8_t[nx * line_size];
108
110
  *x_out = x_new;
109
111
  for (idx_t i = 0; i < nx; i++) {
110
- memcpy (x_new + i * line_size, x + perm[i] * line_size, line_size);
112
+ memcpy(x_new + i * line_size, x + perm[i] * line_size, line_size);
111
113
  }
112
114
  if (weights) {
113
- float *weights_new = new float[nx];
115
+ float* weights_new = new float[nx];
114
116
  for (idx_t i = 0; i < nx; i++) {
115
117
  weights_new[i] = weights[perm[i]];
116
118
  }
@@ -134,20 +136,23 @@ idx_t subsample_training_set(
134
136
  *
135
137
  */
136
138
 
137
- void compute_centroids (size_t d, size_t k, size_t n,
138
- size_t k_frozen,
139
- const uint8_t * x, const Index *codec,
140
- const int64_t * assign,
141
- const float * weights,
142
- float * hassign,
143
- float * centroids)
144
- {
139
+ void compute_centroids(
140
+ size_t d,
141
+ size_t k,
142
+ size_t n,
143
+ size_t k_frozen,
144
+ const uint8_t* x,
145
+ const Index* codec,
146
+ const int64_t* assign,
147
+ const float* weights,
148
+ float* hassign,
149
+ float* centroids) {
145
150
  k -= k_frozen;
146
151
  centroids += k_frozen * d;
147
152
 
148
- memset (centroids, 0, sizeof(*centroids) * d * k);
153
+ memset(centroids, 0, sizeof(*centroids) * d * k);
149
154
 
150
- size_t line_size = codec ? codec->sa_code_size() : d * sizeof (float);
155
+ size_t line_size = codec ? codec->sa_code_size() : d * sizeof(float);
151
156
 
152
157
  #pragma omp parallel
153
158
  {
@@ -157,20 +162,20 @@ void compute_centroids (size_t d, size_t k, size_t n,
157
162
  // this thread is taking care of centroids c0:c1
158
163
  size_t c0 = (k * rank) / nt;
159
164
  size_t c1 = (k * (rank + 1)) / nt;
160
- std::vector<float> decode_buffer (d);
165
+ std::vector<float> decode_buffer(d);
161
166
 
162
167
  for (size_t i = 0; i < n; i++) {
163
168
  int64_t ci = assign[i];
164
- assert (ci >= 0 && ci < k + k_frozen);
169
+ assert(ci >= 0 && ci < k + k_frozen);
165
170
  ci -= k_frozen;
166
- if (ci >= c0 && ci < c1) {
167
- float * c = centroids + ci * d;
168
- const float * xi;
171
+ if (ci >= c0 && ci < c1) {
172
+ float* c = centroids + ci * d;
173
+ const float* xi;
169
174
  if (!codec) {
170
175
  xi = reinterpret_cast<const float*>(x + i * line_size);
171
176
  } else {
172
- float *xif = decode_buffer.data();
173
- codec->sa_decode (1, x + i * line_size, xif);
177
+ float* xif = decode_buffer.data();
178
+ codec->sa_decode(1, x + i * line_size, xif);
174
179
  xi = xif;
175
180
  }
176
181
  if (weights) {
@@ -187,7 +192,6 @@ void compute_centroids (size_t d, size_t k, size_t n,
187
192
  }
188
193
  }
189
194
  }
190
-
191
195
  }
192
196
 
193
197
  #pragma omp parallel for
@@ -196,12 +200,11 @@ void compute_centroids (size_t d, size_t k, size_t n,
196
200
  continue;
197
201
  }
198
202
  float norm = 1 / hassign[ci];
199
- float * c = centroids + ci * d;
203
+ float* c = centroids + ci * d;
200
204
  for (size_t j = 0; j < d; j++) {
201
205
  c[j] *= norm;
202
206
  }
203
207
  }
204
-
205
208
  }
206
209
 
207
210
  // a bit above machine epsilon for float16
@@ -214,29 +217,33 @@ void compute_centroids (size_t d, size_t k, size_t n,
214
217
  *
215
218
  * @return nb of spliting operations (larger is worse)
216
219
  */
217
- int split_clusters (size_t d, size_t k, size_t n,
218
- size_t k_frozen,
219
- float * hassign,
220
- float * centroids)
221
- {
220
+ int split_clusters(
221
+ size_t d,
222
+ size_t k,
223
+ size_t n,
224
+ size_t k_frozen,
225
+ float* hassign,
226
+ float* centroids) {
222
227
  k -= k_frozen;
223
228
  centroids += k_frozen * d;
224
229
 
225
230
  /* Take care of void clusters */
226
231
  size_t nsplit = 0;
227
- RandomGenerator rng (1234);
232
+ RandomGenerator rng(1234);
228
233
  for (size_t ci = 0; ci < k; ci++) {
229
234
  if (hassign[ci] == 0) { /* need to redefine a centroid */
230
235
  size_t cj;
231
236
  for (cj = 0; 1; cj = (cj + 1) % k) {
232
237
  /* probability to pick this cluster for split */
233
- float p = (hassign[cj] - 1.0) / (float) (n - k);
234
- float r = rng.rand_float ();
238
+ float p = (hassign[cj] - 1.0) / (float)(n - k);
239
+ float r = rng.rand_float();
235
240
  if (r < p) {
236
241
  break; /* found our cluster to be split */
237
242
  }
238
243
  }
239
- memcpy (centroids+ci*d, centroids+cj*d, sizeof(*centroids) * d);
244
+ memcpy(centroids + ci * d,
245
+ centroids + cj * d,
246
+ sizeof(*centroids) * d);
240
247
 
241
248
  /* small symmetric pertubation */
242
249
  for (size_t j = 0; j < d; j++) {
@@ -257,30 +264,35 @@ int split_clusters (size_t d, size_t k, size_t n,
257
264
  }
258
265
 
259
266
  return nsplit;
260
-
261
267
  }
262
268
 
263
-
264
-
265
- };
266
-
267
-
268
- void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
269
- const Index * codec, Index & index,
270
- const float *weights) {
271
-
272
-
273
- FAISS_THROW_IF_NOT_FMT (nx >= k,
274
- "Number of training points (%" PRId64 ") should be at least "
275
- "as large as number of clusters (%zd)", nx, k);
276
-
277
- FAISS_THROW_IF_NOT_FMT ((!codec || codec->d == d),
278
- "Codec dimension %d not the same as data dimension %d",
279
- int(codec->d), int(d));
280
-
281
- FAISS_THROW_IF_NOT_FMT (index.d == d,
269
+ }; // namespace
270
+
271
+ void Clustering::train_encoded(
272
+ idx_t nx,
273
+ const uint8_t* x_in,
274
+ const Index* codec,
275
+ Index& index,
276
+ const float* weights) {
277
+ FAISS_THROW_IF_NOT_FMT(
278
+ nx >= k,
279
+ "Number of training points (%" PRId64
280
+ ") should be at least "
281
+ "as large as number of clusters (%zd)",
282
+ nx,
283
+ k);
284
+
285
+ FAISS_THROW_IF_NOT_FMT(
286
+ (!codec || codec->d == d),
287
+ "Codec dimension %d not the same as data dimension %d",
288
+ int(codec->d),
289
+ int(d));
290
+
291
+ FAISS_THROW_IF_NOT_FMT(
292
+ index.d == d,
282
293
  "Index dimension %d not the same as data dimension %d",
283
- int(index.d), int(d));
294
+ int(index.d),
295
+ int(d));
284
296
 
285
297
  double t0 = getmillisecs();
286
298
 
@@ -288,67 +300,78 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
288
300
  // Check for NaNs in input data. Normally it is the user's
289
301
  // responsibility, but it may spare us some hard-to-debug
290
302
  // reports.
291
- const float *x = reinterpret_cast<const float *>(x_in);
303
+ const float* x = reinterpret_cast<const float*>(x_in);
292
304
  for (size_t i = 0; i < nx * d; i++) {
293
- FAISS_THROW_IF_NOT_MSG (std::isfinite (x[i]),
294
- "input contains NaN's or Inf's");
305
+ FAISS_THROW_IF_NOT_MSG(
306
+ std::isfinite(x[i]), "input contains NaN's or Inf's");
295
307
  }
296
308
  }
297
309
 
298
- const uint8_t *x = x_in;
299
- std::unique_ptr<uint8_t []> del1;
300
- std::unique_ptr<float []> del3;
310
+ const uint8_t* x = x_in;
311
+ std::unique_ptr<uint8_t[]> del1;
312
+ std::unique_ptr<float[]> del3;
301
313
  size_t line_size = codec ? codec->sa_code_size() : sizeof(float) * d;
302
314
 
303
315
  if (nx > k * max_points_per_centroid) {
304
- uint8_t *x_new;
305
- float *weights_new;
306
- nx = subsample_training_set (*this, nx, x, line_size, weights,
307
- &x_new, &weights_new);
308
- del1.reset (x_new); x = x_new;
309
- del3.reset (weights_new); weights = weights_new;
316
+ uint8_t* x_new;
317
+ float* weights_new;
318
+ nx = subsample_training_set(
319
+ *this, nx, x, line_size, weights, &x_new, &weights_new);
320
+ del1.reset(x_new);
321
+ x = x_new;
322
+ del3.reset(weights_new);
323
+ weights = weights_new;
310
324
  } else if (nx < k * min_points_per_centroid) {
311
- fprintf (stderr,
312
- "WARNING clustering %" PRId64 " points to %zd centroids: "
313
- "please provide at least %" PRId64 " training points\n",
314
- nx, k, idx_t(k) * min_points_per_centroid);
325
+ fprintf(stderr,
326
+ "WARNING clustering %" PRId64
327
+ " points to %zd centroids: "
328
+ "please provide at least %" PRId64 " training points\n",
329
+ nx,
330
+ k,
331
+ idx_t(k) * min_points_per_centroid);
315
332
  }
316
333
 
317
334
  if (nx == k) {
318
335
  // this is a corner case, just copy training set to clusters
319
336
  if (verbose) {
320
- printf("Number of training points (%" PRId64 ") same as number of "
321
- "clusters, just copying\n", nx);
337
+ printf("Number of training points (%" PRId64
338
+ ") same as number of "
339
+ "clusters, just copying\n",
340
+ nx);
322
341
  }
323
- centroids.resize (d * k);
342
+ centroids.resize(d * k);
324
343
  if (!codec) {
325
- memcpy (centroids.data(), x_in, sizeof (float) * d * k);
344
+ memcpy(centroids.data(), x_in, sizeof(float) * d * k);
326
345
  } else {
327
- codec->sa_decode (nx, x_in, centroids.data());
346
+ codec->sa_decode(nx, x_in, centroids.data());
328
347
  }
329
348
 
330
349
  // one fake iteration...
331
- ClusteringIterationStats stats = { 0.0, 0.0, 0.0, 1.0, 0 };
332
- iteration_stats.push_back (stats);
350
+ ClusteringIterationStats stats = {0.0, 0.0, 0.0, 1.0, 0};
351
+ iteration_stats.push_back(stats);
333
352
 
334
353
  index.reset();
335
354
  index.add(k, centroids.data());
336
355
  return;
337
356
  }
338
357
 
339
-
340
358
  if (verbose) {
341
- printf("Clustering %" PRId64 " points in %zdD to %zd clusters, "
359
+ printf("Clustering %" PRId64
360
+ " points in %zdD to %zd clusters, "
342
361
  "redo %d times, %d iterations\n",
343
- nx, d, k, nredo, niter);
362
+ nx,
363
+ d,
364
+ k,
365
+ nredo,
366
+ niter);
344
367
  if (codec) {
345
368
  printf("Input data encoded in %zd bytes per vector\n",
346
- codec->sa_code_size ());
369
+ codec->sa_code_size());
347
370
  }
348
371
  }
349
372
 
350
- std::unique_ptr<idx_t []> assign(new idx_t[nx]);
351
- std::unique_ptr<float []> dis(new float[nx]);
373
+ std::unique_ptr<idx_t[]> assign(new idx_t[nx]);
374
+ std::unique_ptr<float[]> dis(new float[nx]);
352
375
 
353
376
  // remember best iteration for redo
354
377
  bool lower_is_better = index.metric_type != METRIC_INNER_PRODUCT;
@@ -358,52 +381,49 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
358
381
 
359
382
  // support input centroids
360
383
 
361
- FAISS_THROW_IF_NOT_MSG (
362
- centroids.size() % d == 0,
363
- "size of provided input centroids not a multiple of dimension"
364
- );
384
+ FAISS_THROW_IF_NOT_MSG(
385
+ centroids.size() % d == 0,
386
+ "size of provided input centroids not a multiple of dimension");
365
387
 
366
388
  size_t n_input_centroids = centroids.size() / d;
367
389
 
368
390
  if (verbose && n_input_centroids > 0) {
369
- printf (" Using %zd centroids provided as input (%sfrozen)\n",
370
- n_input_centroids, frozen_centroids ? "" : "not ");
391
+ printf(" Using %zd centroids provided as input (%sfrozen)\n",
392
+ n_input_centroids,
393
+ frozen_centroids ? "" : "not ");
371
394
  }
372
395
 
373
396
  double t_search_tot = 0;
374
397
  if (verbose) {
375
- printf(" Preprocessing in %.2f s\n",
376
- (getmillisecs() - t0) / 1000.);
398
+ printf(" Preprocessing in %.2f s\n", (getmillisecs() - t0) / 1000.);
377
399
  }
378
400
  t0 = getmillisecs();
379
401
 
380
402
  // temporary buffer to decode vectors during the optimization
381
- std::vector<float> decode_buffer
382
- (codec ? d * decode_block_size : 0);
403
+ std::vector<float> decode_buffer(codec ? d * decode_block_size : 0);
383
404
 
384
405
  for (int redo = 0; redo < nredo; redo++) {
385
-
386
406
  if (verbose && nredo > 1) {
387
407
  printf("Outer iteration %d / %d\n", redo, nredo);
388
408
  }
389
409
 
390
410
  // initialize (remaining) centroids with random points from the dataset
391
- centroids.resize (d * k);
392
- std::vector<int> perm (nx);
411
+ centroids.resize(d * k);
412
+ std::vector<int> perm(nx);
393
413
 
394
- rand_perm (perm.data(), nx, seed + 1 + redo * 15486557L);
414
+ rand_perm(perm.data(), nx, seed + 1 + redo * 15486557L);
395
415
 
396
416
  if (!codec) {
397
- for (int i = n_input_centroids; i < k ; i++) {
398
- memcpy (&centroids[i * d], x + perm[i] * line_size, line_size);
417
+ for (int i = n_input_centroids; i < k; i++) {
418
+ memcpy(&centroids[i * d], x + perm[i] * line_size, line_size);
399
419
  }
400
420
  } else {
401
- for (int i = n_input_centroids; i < k ; i++) {
402
- codec->sa_decode (1, x + perm[i] * line_size, &centroids[i * d]);
421
+ for (int i = n_input_centroids; i < k; i++) {
422
+ codec->sa_decode(1, x + perm[i] * line_size, &centroids[i * d]);
403
423
  }
404
424
  }
405
425
 
406
- post_process_centroids ();
426
+ post_process_centroids();
407
427
 
408
428
  // prepare the index
409
429
 
@@ -412,10 +432,10 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
412
432
  }
413
433
 
414
434
  if (!index.is_trained) {
415
- index.train (k, centroids.data());
435
+ index.train(k, centroids.data());
416
436
  }
417
437
 
418
- index.add (k, centroids.data());
438
+ index.add(k, centroids.data());
419
439
 
420
440
  // k-means iterations
421
441
 
@@ -424,18 +444,28 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
424
444
  double t0s = getmillisecs();
425
445
 
426
446
  if (!codec) {
427
- index.search (nx, reinterpret_cast<const float *>(x), 1,
428
- dis.get(), assign.get());
447
+ index.search(
448
+ nx,
449
+ reinterpret_cast<const float*>(x),
450
+ 1,
451
+ dis.get(),
452
+ assign.get());
429
453
  } else {
430
454
  // search by blocks of decode_block_size vectors
431
- size_t code_size = codec->sa_code_size ();
455
+ size_t code_size = codec->sa_code_size();
432
456
  for (size_t i0 = 0; i0 < nx; i0 += decode_block_size) {
433
457
  size_t i1 = i0 + decode_block_size;
434
- if (i1 > nx) { i1 = nx; }
435
- codec->sa_decode (i1 - i0, x + code_size * i0,
436
- decode_buffer.data ());
437
- index.search (i1 - i0, decode_buffer.data (), 1,
438
- dis.get() + i0, assign.get() + i0);
458
+ if (i1 > nx) {
459
+ i1 = nx;
460
+ }
461
+ codec->sa_decode(
462
+ i1 - i0, x + code_size * i0, decode_buffer.data());
463
+ index.search(
464
+ i1 - i0,
465
+ decode_buffer.data(),
466
+ 1,
467
+ dis.get() + i0,
468
+ assign.get() + i0);
439
469
  }
440
470
  }
441
471
 
@@ -449,61 +479,71 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
449
479
  }
450
480
 
451
481
  // update the centroids
452
- std::vector<float> hassign (k);
482
+ std::vector<float> hassign(k);
453
483
 
454
484
  size_t k_frozen = frozen_centroids ? n_input_centroids : 0;
455
- compute_centroids (
456
- d, k, nx, k_frozen,
457
- x, codec, assign.get(), weights,
458
- hassign.data(), centroids.data()
459
- );
460
-
461
- int nsplit = split_clusters (
462
- d, k, nx, k_frozen,
463
- hassign.data(), centroids.data()
464
- );
485
+ compute_centroids(
486
+ d,
487
+ k,
488
+ nx,
489
+ k_frozen,
490
+ x,
491
+ codec,
492
+ assign.get(),
493
+ weights,
494
+ hassign.data(),
495
+ centroids.data());
496
+
497
+ int nsplit = split_clusters(
498
+ d, k, nx, k_frozen, hassign.data(), centroids.data());
465
499
 
466
500
  // collect statistics
467
- ClusteringIterationStats stats =
468
- { obj, (getmillisecs() - t0) / 1000.0,
469
- t_search_tot / 1000,
470
- imbalance_factor (nx, k, assign.get()),
471
- nsplit };
501
+ ClusteringIterationStats stats = {
502
+ obj,
503
+ (getmillisecs() - t0) / 1000.0,
504
+ t_search_tot / 1000,
505
+ imbalance_factor(nx, k, assign.get()),
506
+ nsplit};
472
507
  iteration_stats.push_back(stats);
473
508
 
474
509
  if (verbose) {
475
- printf (" Iteration %d (%.2f s, search %.2f s): "
476
- "objective=%g imbalance=%.3f nsplit=%d \r",
477
- i, stats.time, stats.time_search, stats.obj,
478
- stats.imbalance_factor, nsplit);
479
- fflush (stdout);
510
+ printf(" Iteration %d (%.2f s, search %.2f s): "
511
+ "objective=%g imbalance=%.3f nsplit=%d \r",
512
+ i,
513
+ stats.time,
514
+ stats.time_search,
515
+ stats.obj,
516
+ stats.imbalance_factor,
517
+ nsplit);
518
+ fflush(stdout);
480
519
  }
481
520
 
482
- post_process_centroids ();
521
+ post_process_centroids();
483
522
 
484
523
  // add centroids to index for the next iteration (or for output)
485
524
 
486
- index.reset ();
525
+ index.reset();
487
526
  if (update_index) {
488
- index.train (k, centroids.data());
527
+ index.train(k, centroids.data());
489
528
  }
490
529
 
491
- index.add (k, centroids.data());
492
- InterruptCallback::check ();
530
+ index.add(k, centroids.data());
531
+ InterruptCallback::check();
493
532
  }
494
533
 
495
- if (verbose) printf("\n");
534
+ if (verbose)
535
+ printf("\n");
496
536
  if (nredo > 1) {
497
537
  if ((lower_is_better && obj < best_obj) ||
498
538
  (!lower_is_better && obj > best_obj)) {
499
539
  if (verbose) {
500
- printf ("Objective improved: keep new clusters\n");
540
+ printf("Objective improved: keep new clusters\n");
501
541
  }
502
542
  best_centroids = centroids;
503
543
  best_iteration_stats = iteration_stats;
504
544
  best_obj = obj;
505
545
  }
506
- index.reset ();
546
+ index.reset();
507
547
  }
508
548
  }
509
549
  if (nredo > 1) {
@@ -512,20 +552,151 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
512
552
  index.reset();
513
553
  index.add(k, best_centroids.data());
514
554
  }
555
+ }
556
+
557
+ Clustering1D::Clustering1D(int k) : Clustering(1, k) {}
558
+
559
+ Clustering1D::Clustering1D(int k, const ClusteringParameters& cp)
560
+ : Clustering(1, k, cp) {}
561
+
562
+ void Clustering1D::train_exact(idx_t n, const float* x) {
563
+ const float* xt = x;
564
+
565
+ std::unique_ptr<uint8_t[]> del;
566
+ if (n > k * max_points_per_centroid) {
567
+ uint8_t* x_new;
568
+ float* weights_new;
569
+ n = subsample_training_set(
570
+ *this,
571
+ n,
572
+ (uint8_t*)x,
573
+ sizeof(float) * d,
574
+ nullptr,
575
+ &x_new,
576
+ &weights_new);
577
+ del.reset(x_new);
578
+ xt = (float*)x_new;
579
+ }
515
580
 
581
+ centroids.resize(k);
582
+ double uf = kmeans1d(xt, n, k, centroids.data());
583
+
584
+ ClusteringIterationStats stats = {0.0, 0.0, 0.0, uf, 0};
585
+ iteration_stats.push_back(stats);
516
586
  }
517
587
 
518
- float kmeans_clustering (size_t d, size_t n, size_t k,
519
- const float *x,
520
- float *centroids)
521
- {
522
- Clustering clus (d, k);
588
+ float kmeans_clustering(
589
+ size_t d,
590
+ size_t n,
591
+ size_t k,
592
+ const float* x,
593
+ float* centroids) {
594
+ Clustering clus(d, k);
523
595
  clus.verbose = d * n * k > (1L << 30);
524
596
  // display logs if > 1Gflop per iteration
525
- IndexFlatL2 index (d);
526
- clus.train (n, x, index);
597
+ IndexFlatL2 index(d);
598
+ clus.train(n, x, index);
527
599
  memcpy(centroids, clus.centroids.data(), sizeof(*centroids) * d * k);
528
600
  return clus.iteration_stats.back().obj;
529
601
  }
530
602
 
603
+ /******************************************************************************
604
+ * ProgressiveDimClustering implementation
605
+ ******************************************************************************/
606
+
607
+ ProgressiveDimClusteringParameters::ProgressiveDimClusteringParameters() {
608
+ progressive_dim_steps = 10;
609
+ apply_pca = true; // seems a good idea to do this by default
610
+ niter = 10; // reduce nb of iterations per step
611
+ }
612
+
613
+ Index* ProgressiveDimIndexFactory::operator()(int dim) {
614
+ return new IndexFlatL2(dim);
615
+ }
616
+
617
+ ProgressiveDimClustering::ProgressiveDimClustering(int d, int k) : d(d), k(k) {}
618
+
619
+ ProgressiveDimClustering::ProgressiveDimClustering(
620
+ int d,
621
+ int k,
622
+ const ProgressiveDimClusteringParameters& cp)
623
+ : ProgressiveDimClusteringParameters(cp), d(d), k(k) {}
624
+
625
+ namespace {
626
+
627
+ using idx_t = Index::idx_t;
628
+
629
+ void copy_columns(idx_t n, idx_t d1, const float* src, idx_t d2, float* dest) {
630
+ idx_t d = std::min(d1, d2);
631
+ for (idx_t i = 0; i < n; i++) {
632
+ memcpy(dest, src, sizeof(float) * d);
633
+ src += d1;
634
+ dest += d2;
635
+ }
636
+ }
637
+
638
+ }; // namespace
639
+
640
+ void ProgressiveDimClustering::train(
641
+ idx_t n,
642
+ const float* x,
643
+ ProgressiveDimIndexFactory& factory) {
644
+ int d_prev = 0;
645
+
646
+ PCAMatrix pca(d, d);
647
+
648
+ std::vector<float> xbuf;
649
+ if (apply_pca) {
650
+ if (verbose) {
651
+ printf("Training PCA transform\n");
652
+ }
653
+ pca.train(n, x);
654
+ if (verbose) {
655
+ printf("Apply PCA\n");
656
+ }
657
+ xbuf.resize(n * d);
658
+ pca.apply_noalloc(n, x, xbuf.data());
659
+ x = xbuf.data();
660
+ }
661
+
662
+ for (int iter = 0; iter < progressive_dim_steps; iter++) {
663
+ int di = int(pow(d, (1. + iter) / progressive_dim_steps));
664
+ if (verbose) {
665
+ printf("Progressive dim step %d: cluster in dimension %d\n",
666
+ iter,
667
+ di);
668
+ }
669
+ std::unique_ptr<Index> clustering_index(factory(di));
670
+
671
+ Clustering clus(di, k, *this);
672
+ if (d_prev > 0) {
673
+ // copy warm-start centroids (padded with 0s)
674
+ clus.centroids.resize(k * di);
675
+ copy_columns(
676
+ k, d_prev, centroids.data(), di, clus.centroids.data());
677
+ }
678
+ std::vector<float> xsub(n * di);
679
+ copy_columns(n, d, x, di, xsub.data());
680
+
681
+ clus.train(n, xsub.data(), *clustering_index.get());
682
+
683
+ centroids = clus.centroids;
684
+ iteration_stats.insert(
685
+ iteration_stats.end(),
686
+ clus.iteration_stats.begin(),
687
+ clus.iteration_stats.end());
688
+
689
+ d_prev = di;
690
+ }
691
+
692
+ if (apply_pca) {
693
+ if (verbose) {
694
+ printf("Revert PCA transform on centroids\n");
695
+ }
696
+ std::vector<float> cent_transformed(d * k);
697
+ pca.reverse_transform(k, centroids.data(), cent_transformed.data());
698
+ cent_transformed.swap(centroids);
699
+ }
700
+ }
701
+
531
702
  } // namespace faiss