faiss 0.2.0 → 0.2.4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (215) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -8,6 +8,7 @@
8
8
  // -*- c++ -*-
9
9
 
10
10
  #include <faiss/Clustering.h>
11
+ #include <faiss/VectorTransform.h>
11
12
  #include <faiss/impl/AuxIndexStructures.h>
12
13
 
13
14
  #include <cinttypes>
@@ -17,100 +18,101 @@
17
18
 
18
19
  #include <omp.h>
19
20
 
20
- #include <faiss/utils/utils.h>
21
- #include <faiss/utils/random.h>
22
- #include <faiss/utils/distances.h>
23
- #include <faiss/impl/FaissAssert.h>
24
21
  #include <faiss/IndexFlat.h>
22
+ #include <faiss/impl/FaissAssert.h>
23
+ #include <faiss/impl/kmeans1d.h>
24
+ #include <faiss/utils/distances.h>
25
+ #include <faiss/utils/random.h>
26
+ #include <faiss/utils/utils.h>
25
27
 
26
28
  namespace faiss {
27
29
 
28
- ClusteringParameters::ClusteringParameters ():
29
- niter(25),
30
- nredo(1),
31
- verbose(false),
32
- spherical(false),
33
- int_centroids(false),
34
- update_index(false),
35
- frozen_centroids(false),
36
- min_points_per_centroid(39),
37
- max_points_per_centroid(256),
38
- seed(1234),
39
- decode_block_size(32768)
40
- {}
30
+ ClusteringParameters::ClusteringParameters()
31
+ : niter(25),
32
+ nredo(1),
33
+ verbose(false),
34
+ spherical(false),
35
+ int_centroids(false),
36
+ update_index(false),
37
+ frozen_centroids(false),
38
+ min_points_per_centroid(39),
39
+ max_points_per_centroid(256),
40
+ seed(1234),
41
+ decode_block_size(32768) {}
41
42
  // 39 corresponds to 10000 / 256 -> to avoid warnings on PQ tests with randu10k
42
43
 
44
+ Clustering::Clustering(int d, int k) : d(d), k(k) {}
43
45
 
44
- Clustering::Clustering (int d, int k):
45
- d(d), k(k) {}
46
-
47
- Clustering::Clustering (int d, int k, const ClusteringParameters &cp):
48
- ClusteringParameters (cp), d(d), k(k) {}
46
+ Clustering::Clustering(int d, int k, const ClusteringParameters& cp)
47
+ : ClusteringParameters(cp), d(d), k(k) {}
49
48
 
50
-
51
-
52
- static double imbalance_factor (int n, int k, int64_t *assign) {
49
+ static double imbalance_factor(int n, int k, int64_t* assign) {
53
50
  std::vector<int> hist(k, 0);
54
51
  for (int i = 0; i < n; i++)
55
52
  hist[assign[i]]++;
56
53
 
57
54
  double tot = 0, uf = 0;
58
55
 
59
- for (int i = 0 ; i < k ; i++) {
56
+ for (int i = 0; i < k; i++) {
60
57
  tot += hist[i];
61
- uf += hist[i] * (double) hist[i];
58
+ uf += hist[i] * (double)hist[i];
62
59
  }
63
60
  uf = uf * k / (tot * tot);
64
61
 
65
62
  return uf;
66
63
  }
67
64
 
68
- void Clustering::post_process_centroids ()
69
- {
70
-
65
+ void Clustering::post_process_centroids() {
71
66
  if (spherical) {
72
- fvec_renorm_L2 (d, k, centroids.data());
67
+ fvec_renorm_L2(d, k, centroids.data());
73
68
  }
74
69
 
75
70
  if (int_centroids) {
76
71
  for (size_t i = 0; i < centroids.size(); i++)
77
- centroids[i] = roundf (centroids[i]);
72
+ centroids[i] = roundf(centroids[i]);
78
73
  }
79
74
  }
80
75
 
81
-
82
- void Clustering::train (idx_t nx, const float *x_in, Index & index,
83
- const float *weights) {
84
- train_encoded (nx, reinterpret_cast<const uint8_t *>(x_in), nullptr,
85
- index, weights);
76
+ void Clustering::train(
77
+ idx_t nx,
78
+ const float* x_in,
79
+ Index& index,
80
+ const float* weights) {
81
+ train_encoded(
82
+ nx,
83
+ reinterpret_cast<const uint8_t*>(x_in),
84
+ nullptr,
85
+ index,
86
+ weights);
86
87
  }
87
88
 
88
-
89
89
  namespace {
90
90
 
91
91
  using idx_t = Clustering::idx_t;
92
92
 
93
93
  idx_t subsample_training_set(
94
- const Clustering &clus, idx_t nx, const uint8_t *x,
95
- size_t line_size, const float * weights,
96
- uint8_t **x_out,
97
- float **weights_out
98
- )
99
- {
94
+ const Clustering& clus,
95
+ idx_t nx,
96
+ const uint8_t* x,
97
+ size_t line_size,
98
+ const float* weights,
99
+ uint8_t** x_out,
100
+ float** weights_out) {
100
101
  if (clus.verbose) {
101
102
  printf("Sampling a subset of %zd / %" PRId64 " for training\n",
102
- clus.k * clus.max_points_per_centroid, nx);
103
+ clus.k * clus.max_points_per_centroid,
104
+ nx);
103
105
  }
104
- std::vector<int> perm (nx);
105
- rand_perm (perm.data (), nx, clus.seed);
106
+ std::vector<int> perm(nx);
107
+ rand_perm(perm.data(), nx, clus.seed);
106
108
  nx = clus.k * clus.max_points_per_centroid;
107
- uint8_t * x_new = new uint8_t [nx * line_size];
109
+ uint8_t* x_new = new uint8_t[nx * line_size];
108
110
  *x_out = x_new;
109
111
  for (idx_t i = 0; i < nx; i++) {
110
- memcpy (x_new + i * line_size, x + perm[i] * line_size, line_size);
112
+ memcpy(x_new + i * line_size, x + perm[i] * line_size, line_size);
111
113
  }
112
114
  if (weights) {
113
- float *weights_new = new float[nx];
115
+ float* weights_new = new float[nx];
114
116
  for (idx_t i = 0; i < nx; i++) {
115
117
  weights_new[i] = weights[perm[i]];
116
118
  }
@@ -134,20 +136,23 @@ idx_t subsample_training_set(
134
136
  *
135
137
  */
136
138
 
137
- void compute_centroids (size_t d, size_t k, size_t n,
138
- size_t k_frozen,
139
- const uint8_t * x, const Index *codec,
140
- const int64_t * assign,
141
- const float * weights,
142
- float * hassign,
143
- float * centroids)
144
- {
139
+ void compute_centroids(
140
+ size_t d,
141
+ size_t k,
142
+ size_t n,
143
+ size_t k_frozen,
144
+ const uint8_t* x,
145
+ const Index* codec,
146
+ const int64_t* assign,
147
+ const float* weights,
148
+ float* hassign,
149
+ float* centroids) {
145
150
  k -= k_frozen;
146
151
  centroids += k_frozen * d;
147
152
 
148
- memset (centroids, 0, sizeof(*centroids) * d * k);
153
+ memset(centroids, 0, sizeof(*centroids) * d * k);
149
154
 
150
- size_t line_size = codec ? codec->sa_code_size() : d * sizeof (float);
155
+ size_t line_size = codec ? codec->sa_code_size() : d * sizeof(float);
151
156
 
152
157
  #pragma omp parallel
153
158
  {
@@ -157,20 +162,20 @@ void compute_centroids (size_t d, size_t k, size_t n,
157
162
  // this thread is taking care of centroids c0:c1
158
163
  size_t c0 = (k * rank) / nt;
159
164
  size_t c1 = (k * (rank + 1)) / nt;
160
- std::vector<float> decode_buffer (d);
165
+ std::vector<float> decode_buffer(d);
161
166
 
162
167
  for (size_t i = 0; i < n; i++) {
163
168
  int64_t ci = assign[i];
164
- assert (ci >= 0 && ci < k + k_frozen);
169
+ assert(ci >= 0 && ci < k + k_frozen);
165
170
  ci -= k_frozen;
166
- if (ci >= c0 && ci < c1) {
167
- float * c = centroids + ci * d;
168
- const float * xi;
171
+ if (ci >= c0 && ci < c1) {
172
+ float* c = centroids + ci * d;
173
+ const float* xi;
169
174
  if (!codec) {
170
175
  xi = reinterpret_cast<const float*>(x + i * line_size);
171
176
  } else {
172
- float *xif = decode_buffer.data();
173
- codec->sa_decode (1, x + i * line_size, xif);
177
+ float* xif = decode_buffer.data();
178
+ codec->sa_decode(1, x + i * line_size, xif);
174
179
  xi = xif;
175
180
  }
176
181
  if (weights) {
@@ -187,7 +192,6 @@ void compute_centroids (size_t d, size_t k, size_t n,
187
192
  }
188
193
  }
189
194
  }
190
-
191
195
  }
192
196
 
193
197
  #pragma omp parallel for
@@ -196,12 +200,11 @@ void compute_centroids (size_t d, size_t k, size_t n,
196
200
  continue;
197
201
  }
198
202
  float norm = 1 / hassign[ci];
199
- float * c = centroids + ci * d;
203
+ float* c = centroids + ci * d;
200
204
  for (size_t j = 0; j < d; j++) {
201
205
  c[j] *= norm;
202
206
  }
203
207
  }
204
-
205
208
  }
206
209
 
207
210
  // a bit above machine epsilon for float16
@@ -214,29 +217,33 @@ void compute_centroids (size_t d, size_t k, size_t n,
214
217
  *
215
218
  * @return nb of spliting operations (larger is worse)
216
219
  */
217
- int split_clusters (size_t d, size_t k, size_t n,
218
- size_t k_frozen,
219
- float * hassign,
220
- float * centroids)
221
- {
220
+ int split_clusters(
221
+ size_t d,
222
+ size_t k,
223
+ size_t n,
224
+ size_t k_frozen,
225
+ float* hassign,
226
+ float* centroids) {
222
227
  k -= k_frozen;
223
228
  centroids += k_frozen * d;
224
229
 
225
230
  /* Take care of void clusters */
226
231
  size_t nsplit = 0;
227
- RandomGenerator rng (1234);
232
+ RandomGenerator rng(1234);
228
233
  for (size_t ci = 0; ci < k; ci++) {
229
234
  if (hassign[ci] == 0) { /* need to redefine a centroid */
230
235
  size_t cj;
231
236
  for (cj = 0; 1; cj = (cj + 1) % k) {
232
237
  /* probability to pick this cluster for split */
233
- float p = (hassign[cj] - 1.0) / (float) (n - k);
234
- float r = rng.rand_float ();
238
+ float p = (hassign[cj] - 1.0) / (float)(n - k);
239
+ float r = rng.rand_float();
235
240
  if (r < p) {
236
241
  break; /* found our cluster to be split */
237
242
  }
238
243
  }
239
- memcpy (centroids+ci*d, centroids+cj*d, sizeof(*centroids) * d);
244
+ memcpy(centroids + ci * d,
245
+ centroids + cj * d,
246
+ sizeof(*centroids) * d);
240
247
 
241
248
  /* small symmetric pertubation */
242
249
  for (size_t j = 0; j < d; j++) {
@@ -257,30 +264,35 @@ int split_clusters (size_t d, size_t k, size_t n,
257
264
  }
258
265
 
259
266
  return nsplit;
260
-
261
267
  }
262
268
 
263
-
264
-
265
- };
266
-
267
-
268
- void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
269
- const Index * codec, Index & index,
270
- const float *weights) {
271
-
272
-
273
- FAISS_THROW_IF_NOT_FMT (nx >= k,
274
- "Number of training points (%" PRId64 ") should be at least "
275
- "as large as number of clusters (%zd)", nx, k);
276
-
277
- FAISS_THROW_IF_NOT_FMT ((!codec || codec->d == d),
278
- "Codec dimension %d not the same as data dimension %d",
279
- int(codec->d), int(d));
280
-
281
- FAISS_THROW_IF_NOT_FMT (index.d == d,
269
+ }; // namespace
270
+
271
+ void Clustering::train_encoded(
272
+ idx_t nx,
273
+ const uint8_t* x_in,
274
+ const Index* codec,
275
+ Index& index,
276
+ const float* weights) {
277
+ FAISS_THROW_IF_NOT_FMT(
278
+ nx >= k,
279
+ "Number of training points (%" PRId64
280
+ ") should be at least "
281
+ "as large as number of clusters (%zd)",
282
+ nx,
283
+ k);
284
+
285
+ FAISS_THROW_IF_NOT_FMT(
286
+ (!codec || codec->d == d),
287
+ "Codec dimension %d not the same as data dimension %d",
288
+ int(codec->d),
289
+ int(d));
290
+
291
+ FAISS_THROW_IF_NOT_FMT(
292
+ index.d == d,
282
293
  "Index dimension %d not the same as data dimension %d",
283
- int(index.d), int(d));
294
+ int(index.d),
295
+ int(d));
284
296
 
285
297
  double t0 = getmillisecs();
286
298
 
@@ -288,67 +300,78 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
288
300
  // Check for NaNs in input data. Normally it is the user's
289
301
  // responsibility, but it may spare us some hard-to-debug
290
302
  // reports.
291
- const float *x = reinterpret_cast<const float *>(x_in);
303
+ const float* x = reinterpret_cast<const float*>(x_in);
292
304
  for (size_t i = 0; i < nx * d; i++) {
293
- FAISS_THROW_IF_NOT_MSG (std::isfinite (x[i]),
294
- "input contains NaN's or Inf's");
305
+ FAISS_THROW_IF_NOT_MSG(
306
+ std::isfinite(x[i]), "input contains NaN's or Inf's");
295
307
  }
296
308
  }
297
309
 
298
- const uint8_t *x = x_in;
299
- std::unique_ptr<uint8_t []> del1;
300
- std::unique_ptr<float []> del3;
310
+ const uint8_t* x = x_in;
311
+ std::unique_ptr<uint8_t[]> del1;
312
+ std::unique_ptr<float[]> del3;
301
313
  size_t line_size = codec ? codec->sa_code_size() : sizeof(float) * d;
302
314
 
303
315
  if (nx > k * max_points_per_centroid) {
304
- uint8_t *x_new;
305
- float *weights_new;
306
- nx = subsample_training_set (*this, nx, x, line_size, weights,
307
- &x_new, &weights_new);
308
- del1.reset (x_new); x = x_new;
309
- del3.reset (weights_new); weights = weights_new;
316
+ uint8_t* x_new;
317
+ float* weights_new;
318
+ nx = subsample_training_set(
319
+ *this, nx, x, line_size, weights, &x_new, &weights_new);
320
+ del1.reset(x_new);
321
+ x = x_new;
322
+ del3.reset(weights_new);
323
+ weights = weights_new;
310
324
  } else if (nx < k * min_points_per_centroid) {
311
- fprintf (stderr,
312
- "WARNING clustering %" PRId64 " points to %zd centroids: "
313
- "please provide at least %" PRId64 " training points\n",
314
- nx, k, idx_t(k) * min_points_per_centroid);
325
+ fprintf(stderr,
326
+ "WARNING clustering %" PRId64
327
+ " points to %zd centroids: "
328
+ "please provide at least %" PRId64 " training points\n",
329
+ nx,
330
+ k,
331
+ idx_t(k) * min_points_per_centroid);
315
332
  }
316
333
 
317
334
  if (nx == k) {
318
335
  // this is a corner case, just copy training set to clusters
319
336
  if (verbose) {
320
- printf("Number of training points (%" PRId64 ") same as number of "
321
- "clusters, just copying\n", nx);
337
+ printf("Number of training points (%" PRId64
338
+ ") same as number of "
339
+ "clusters, just copying\n",
340
+ nx);
322
341
  }
323
- centroids.resize (d * k);
342
+ centroids.resize(d * k);
324
343
  if (!codec) {
325
- memcpy (centroids.data(), x_in, sizeof (float) * d * k);
344
+ memcpy(centroids.data(), x_in, sizeof(float) * d * k);
326
345
  } else {
327
- codec->sa_decode (nx, x_in, centroids.data());
346
+ codec->sa_decode(nx, x_in, centroids.data());
328
347
  }
329
348
 
330
349
  // one fake iteration...
331
- ClusteringIterationStats stats = { 0.0, 0.0, 0.0, 1.0, 0 };
332
- iteration_stats.push_back (stats);
350
+ ClusteringIterationStats stats = {0.0, 0.0, 0.0, 1.0, 0};
351
+ iteration_stats.push_back(stats);
333
352
 
334
353
  index.reset();
335
354
  index.add(k, centroids.data());
336
355
  return;
337
356
  }
338
357
 
339
-
340
358
  if (verbose) {
341
- printf("Clustering %" PRId64 " points in %zdD to %zd clusters, "
359
+ printf("Clustering %" PRId64
360
+ " points in %zdD to %zd clusters, "
342
361
  "redo %d times, %d iterations\n",
343
- nx, d, k, nredo, niter);
362
+ nx,
363
+ d,
364
+ k,
365
+ nredo,
366
+ niter);
344
367
  if (codec) {
345
368
  printf("Input data encoded in %zd bytes per vector\n",
346
- codec->sa_code_size ());
369
+ codec->sa_code_size());
347
370
  }
348
371
  }
349
372
 
350
- std::unique_ptr<idx_t []> assign(new idx_t[nx]);
351
- std::unique_ptr<float []> dis(new float[nx]);
373
+ std::unique_ptr<idx_t[]> assign(new idx_t[nx]);
374
+ std::unique_ptr<float[]> dis(new float[nx]);
352
375
 
353
376
  // remember best iteration for redo
354
377
  bool lower_is_better = index.metric_type != METRIC_INNER_PRODUCT;
@@ -358,52 +381,49 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
358
381
 
359
382
  // support input centroids
360
383
 
361
- FAISS_THROW_IF_NOT_MSG (
362
- centroids.size() % d == 0,
363
- "size of provided input centroids not a multiple of dimension"
364
- );
384
+ FAISS_THROW_IF_NOT_MSG(
385
+ centroids.size() % d == 0,
386
+ "size of provided input centroids not a multiple of dimension");
365
387
 
366
388
  size_t n_input_centroids = centroids.size() / d;
367
389
 
368
390
  if (verbose && n_input_centroids > 0) {
369
- printf (" Using %zd centroids provided as input (%sfrozen)\n",
370
- n_input_centroids, frozen_centroids ? "" : "not ");
391
+ printf(" Using %zd centroids provided as input (%sfrozen)\n",
392
+ n_input_centroids,
393
+ frozen_centroids ? "" : "not ");
371
394
  }
372
395
 
373
396
  double t_search_tot = 0;
374
397
  if (verbose) {
375
- printf(" Preprocessing in %.2f s\n",
376
- (getmillisecs() - t0) / 1000.);
398
+ printf(" Preprocessing in %.2f s\n", (getmillisecs() - t0) / 1000.);
377
399
  }
378
400
  t0 = getmillisecs();
379
401
 
380
402
  // temporary buffer to decode vectors during the optimization
381
- std::vector<float> decode_buffer
382
- (codec ? d * decode_block_size : 0);
403
+ std::vector<float> decode_buffer(codec ? d * decode_block_size : 0);
383
404
 
384
405
  for (int redo = 0; redo < nredo; redo++) {
385
-
386
406
  if (verbose && nredo > 1) {
387
407
  printf("Outer iteration %d / %d\n", redo, nredo);
388
408
  }
389
409
 
390
410
  // initialize (remaining) centroids with random points from the dataset
391
- centroids.resize (d * k);
392
- std::vector<int> perm (nx);
411
+ centroids.resize(d * k);
412
+ std::vector<int> perm(nx);
393
413
 
394
- rand_perm (perm.data(), nx, seed + 1 + redo * 15486557L);
414
+ rand_perm(perm.data(), nx, seed + 1 + redo * 15486557L);
395
415
 
396
416
  if (!codec) {
397
- for (int i = n_input_centroids; i < k ; i++) {
398
- memcpy (&centroids[i * d], x + perm[i] * line_size, line_size);
417
+ for (int i = n_input_centroids; i < k; i++) {
418
+ memcpy(&centroids[i * d], x + perm[i] * line_size, line_size);
399
419
  }
400
420
  } else {
401
- for (int i = n_input_centroids; i < k ; i++) {
402
- codec->sa_decode (1, x + perm[i] * line_size, &centroids[i * d]);
421
+ for (int i = n_input_centroids; i < k; i++) {
422
+ codec->sa_decode(1, x + perm[i] * line_size, &centroids[i * d]);
403
423
  }
404
424
  }
405
425
 
406
- post_process_centroids ();
426
+ post_process_centroids();
407
427
 
408
428
  // prepare the index
409
429
 
@@ -412,10 +432,10 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
412
432
  }
413
433
 
414
434
  if (!index.is_trained) {
415
- index.train (k, centroids.data());
435
+ index.train(k, centroids.data());
416
436
  }
417
437
 
418
- index.add (k, centroids.data());
438
+ index.add(k, centroids.data());
419
439
 
420
440
  // k-means iterations
421
441
 
@@ -424,18 +444,28 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
424
444
  double t0s = getmillisecs();
425
445
 
426
446
  if (!codec) {
427
- index.search (nx, reinterpret_cast<const float *>(x), 1,
428
- dis.get(), assign.get());
447
+ index.search(
448
+ nx,
449
+ reinterpret_cast<const float*>(x),
450
+ 1,
451
+ dis.get(),
452
+ assign.get());
429
453
  } else {
430
454
  // search by blocks of decode_block_size vectors
431
- size_t code_size = codec->sa_code_size ();
455
+ size_t code_size = codec->sa_code_size();
432
456
  for (size_t i0 = 0; i0 < nx; i0 += decode_block_size) {
433
457
  size_t i1 = i0 + decode_block_size;
434
- if (i1 > nx) { i1 = nx; }
435
- codec->sa_decode (i1 - i0, x + code_size * i0,
436
- decode_buffer.data ());
437
- index.search (i1 - i0, decode_buffer.data (), 1,
438
- dis.get() + i0, assign.get() + i0);
458
+ if (i1 > nx) {
459
+ i1 = nx;
460
+ }
461
+ codec->sa_decode(
462
+ i1 - i0, x + code_size * i0, decode_buffer.data());
463
+ index.search(
464
+ i1 - i0,
465
+ decode_buffer.data(),
466
+ 1,
467
+ dis.get() + i0,
468
+ assign.get() + i0);
439
469
  }
440
470
  }
441
471
 
@@ -449,61 +479,71 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
449
479
  }
450
480
 
451
481
  // update the centroids
452
- std::vector<float> hassign (k);
482
+ std::vector<float> hassign(k);
453
483
 
454
484
  size_t k_frozen = frozen_centroids ? n_input_centroids : 0;
455
- compute_centroids (
456
- d, k, nx, k_frozen,
457
- x, codec, assign.get(), weights,
458
- hassign.data(), centroids.data()
459
- );
460
-
461
- int nsplit = split_clusters (
462
- d, k, nx, k_frozen,
463
- hassign.data(), centroids.data()
464
- );
485
+ compute_centroids(
486
+ d,
487
+ k,
488
+ nx,
489
+ k_frozen,
490
+ x,
491
+ codec,
492
+ assign.get(),
493
+ weights,
494
+ hassign.data(),
495
+ centroids.data());
496
+
497
+ int nsplit = split_clusters(
498
+ d, k, nx, k_frozen, hassign.data(), centroids.data());
465
499
 
466
500
  // collect statistics
467
- ClusteringIterationStats stats =
468
- { obj, (getmillisecs() - t0) / 1000.0,
469
- t_search_tot / 1000,
470
- imbalance_factor (nx, k, assign.get()),
471
- nsplit };
501
+ ClusteringIterationStats stats = {
502
+ obj,
503
+ (getmillisecs() - t0) / 1000.0,
504
+ t_search_tot / 1000,
505
+ imbalance_factor(nx, k, assign.get()),
506
+ nsplit};
472
507
  iteration_stats.push_back(stats);
473
508
 
474
509
  if (verbose) {
475
- printf (" Iteration %d (%.2f s, search %.2f s): "
476
- "objective=%g imbalance=%.3f nsplit=%d \r",
477
- i, stats.time, stats.time_search, stats.obj,
478
- stats.imbalance_factor, nsplit);
479
- fflush (stdout);
510
+ printf(" Iteration %d (%.2f s, search %.2f s): "
511
+ "objective=%g imbalance=%.3f nsplit=%d \r",
512
+ i,
513
+ stats.time,
514
+ stats.time_search,
515
+ stats.obj,
516
+ stats.imbalance_factor,
517
+ nsplit);
518
+ fflush(stdout);
480
519
  }
481
520
 
482
- post_process_centroids ();
521
+ post_process_centroids();
483
522
 
484
523
  // add centroids to index for the next iteration (or for output)
485
524
 
486
- index.reset ();
525
+ index.reset();
487
526
  if (update_index) {
488
- index.train (k, centroids.data());
527
+ index.train(k, centroids.data());
489
528
  }
490
529
 
491
- index.add (k, centroids.data());
492
- InterruptCallback::check ();
530
+ index.add(k, centroids.data());
531
+ InterruptCallback::check();
493
532
  }
494
533
 
495
- if (verbose) printf("\n");
534
+ if (verbose)
535
+ printf("\n");
496
536
  if (nredo > 1) {
497
537
  if ((lower_is_better && obj < best_obj) ||
498
538
  (!lower_is_better && obj > best_obj)) {
499
539
  if (verbose) {
500
- printf ("Objective improved: keep new clusters\n");
540
+ printf("Objective improved: keep new clusters\n");
501
541
  }
502
542
  best_centroids = centroids;
503
543
  best_iteration_stats = iteration_stats;
504
544
  best_obj = obj;
505
545
  }
506
- index.reset ();
546
+ index.reset();
507
547
  }
508
548
  }
509
549
  if (nredo > 1) {
@@ -512,20 +552,151 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
512
552
  index.reset();
513
553
  index.add(k, best_centroids.data());
514
554
  }
555
+ }
556
+
557
+ Clustering1D::Clustering1D(int k) : Clustering(1, k) {}
558
+
559
+ Clustering1D::Clustering1D(int k, const ClusteringParameters& cp)
560
+ : Clustering(1, k, cp) {}
561
+
562
+ void Clustering1D::train_exact(idx_t n, const float* x) {
563
+ const float* xt = x;
564
+
565
+ std::unique_ptr<uint8_t[]> del;
566
+ if (n > k * max_points_per_centroid) {
567
+ uint8_t* x_new;
568
+ float* weights_new;
569
+ n = subsample_training_set(
570
+ *this,
571
+ n,
572
+ (uint8_t*)x,
573
+ sizeof(float) * d,
574
+ nullptr,
575
+ &x_new,
576
+ &weights_new);
577
+ del.reset(x_new);
578
+ xt = (float*)x_new;
579
+ }
515
580
 
581
+ centroids.resize(k);
582
+ double uf = kmeans1d(xt, n, k, centroids.data());
583
+
584
+ ClusteringIterationStats stats = {0.0, 0.0, 0.0, uf, 0};
585
+ iteration_stats.push_back(stats);
516
586
  }
517
587
 
518
- float kmeans_clustering (size_t d, size_t n, size_t k,
519
- const float *x,
520
- float *centroids)
521
- {
522
- Clustering clus (d, k);
588
+ float kmeans_clustering(
589
+ size_t d,
590
+ size_t n,
591
+ size_t k,
592
+ const float* x,
593
+ float* centroids) {
594
+ Clustering clus(d, k);
523
595
  clus.verbose = d * n * k > (1L << 30);
524
596
  // display logs if > 1Gflop per iteration
525
- IndexFlatL2 index (d);
526
- clus.train (n, x, index);
597
+ IndexFlatL2 index(d);
598
+ clus.train(n, x, index);
527
599
  memcpy(centroids, clus.centroids.data(), sizeof(*centroids) * d * k);
528
600
  return clus.iteration_stats.back().obj;
529
601
  }
530
602
 
603
+ /******************************************************************************
604
+ * ProgressiveDimClustering implementation
605
+ ******************************************************************************/
606
+
607
+ ProgressiveDimClusteringParameters::ProgressiveDimClusteringParameters() {
608
+ progressive_dim_steps = 10;
609
+ apply_pca = true; // seems a good idea to do this by default
610
+ niter = 10; // reduce nb of iterations per step
611
+ }
612
+
613
+ Index* ProgressiveDimIndexFactory::operator()(int dim) {
614
+ return new IndexFlatL2(dim);
615
+ }
616
+
617
+ ProgressiveDimClustering::ProgressiveDimClustering(int d, int k) : d(d), k(k) {}
618
+
619
+ ProgressiveDimClustering::ProgressiveDimClustering(
620
+ int d,
621
+ int k,
622
+ const ProgressiveDimClusteringParameters& cp)
623
+ : ProgressiveDimClusteringParameters(cp), d(d), k(k) {}
624
+
625
+ namespace {
626
+
627
+ using idx_t = Index::idx_t;
628
+
629
+ void copy_columns(idx_t n, idx_t d1, const float* src, idx_t d2, float* dest) {
630
+ idx_t d = std::min(d1, d2);
631
+ for (idx_t i = 0; i < n; i++) {
632
+ memcpy(dest, src, sizeof(float) * d);
633
+ src += d1;
634
+ dest += d2;
635
+ }
636
+ }
637
+
638
+ }; // namespace
639
+
640
+ void ProgressiveDimClustering::train(
641
+ idx_t n,
642
+ const float* x,
643
+ ProgressiveDimIndexFactory& factory) {
644
+ int d_prev = 0;
645
+
646
+ PCAMatrix pca(d, d);
647
+
648
+ std::vector<float> xbuf;
649
+ if (apply_pca) {
650
+ if (verbose) {
651
+ printf("Training PCA transform\n");
652
+ }
653
+ pca.train(n, x);
654
+ if (verbose) {
655
+ printf("Apply PCA\n");
656
+ }
657
+ xbuf.resize(n * d);
658
+ pca.apply_noalloc(n, x, xbuf.data());
659
+ x = xbuf.data();
660
+ }
661
+
662
+ for (int iter = 0; iter < progressive_dim_steps; iter++) {
663
+ int di = int(pow(d, (1. + iter) / progressive_dim_steps));
664
+ if (verbose) {
665
+ printf("Progressive dim step %d: cluster in dimension %d\n",
666
+ iter,
667
+ di);
668
+ }
669
+ std::unique_ptr<Index> clustering_index(factory(di));
670
+
671
+ Clustering clus(di, k, *this);
672
+ if (d_prev > 0) {
673
+ // copy warm-start centroids (padded with 0s)
674
+ clus.centroids.resize(k * di);
675
+ copy_columns(
676
+ k, d_prev, centroids.data(), di, clus.centroids.data());
677
+ }
678
+ std::vector<float> xsub(n * di);
679
+ copy_columns(n, d, x, di, xsub.data());
680
+
681
+ clus.train(n, xsub.data(), *clustering_index.get());
682
+
683
+ centroids = clus.centroids;
684
+ iteration_stats.insert(
685
+ iteration_stats.end(),
686
+ clus.iteration_stats.begin(),
687
+ clus.iteration_stats.end());
688
+
689
+ d_prev = di;
690
+ }
691
+
692
+ if (apply_pca) {
693
+ if (verbose) {
694
+ printf("Revert PCA transform on centroids\n");
695
+ }
696
+ std::vector<float> cent_transformed(d * k);
697
+ pca.reverse_transform(k, centroids.data(), cent_transformed.data());
698
+ cent_transformed.swap(centroids);
699
+ }
700
+ }
701
+
531
702
  } // namespace faiss