faiss 0.1.7 → 0.2.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +18 -0
  3. data/README.md +7 -7
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +8 -2
  6. data/ext/faiss/index.cpp +102 -69
  7. data/ext/faiss/index_binary.cpp +24 -30
  8. data/ext/faiss/kmeans.cpp +20 -16
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +13 -14
  11. data/ext/faiss/product_quantizer.cpp +23 -24
  12. data/ext/faiss/utils.cpp +10 -37
  13. data/ext/faiss/utils.h +2 -13
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +0 -5
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +26 -12
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
@@ -8,6 +8,7 @@
8
8
  // -*- c++ -*-
9
9
 
10
10
  #include <faiss/Clustering.h>
11
+ #include <faiss/VectorTransform.h>
11
12
  #include <faiss/impl/AuxIndexStructures.h>
12
13
 
13
14
  #include <cinttypes>
@@ -17,100 +18,100 @@
17
18
 
18
19
  #include <omp.h>
19
20
 
20
- #include <faiss/utils/utils.h>
21
- #include <faiss/utils/random.h>
22
- #include <faiss/utils/distances.h>
23
- #include <faiss/impl/FaissAssert.h>
24
21
  #include <faiss/IndexFlat.h>
22
+ #include <faiss/impl/FaissAssert.h>
23
+ #include <faiss/utils/distances.h>
24
+ #include <faiss/utils/random.h>
25
+ #include <faiss/utils/utils.h>
25
26
 
26
27
  namespace faiss {
27
28
 
28
- ClusteringParameters::ClusteringParameters ():
29
- niter(25),
30
- nredo(1),
31
- verbose(false),
32
- spherical(false),
33
- int_centroids(false),
34
- update_index(false),
35
- frozen_centroids(false),
36
- min_points_per_centroid(39),
37
- max_points_per_centroid(256),
38
- seed(1234),
39
- decode_block_size(32768)
40
- {}
29
+ ClusteringParameters::ClusteringParameters()
30
+ : niter(25),
31
+ nredo(1),
32
+ verbose(false),
33
+ spherical(false),
34
+ int_centroids(false),
35
+ update_index(false),
36
+ frozen_centroids(false),
37
+ min_points_per_centroid(39),
38
+ max_points_per_centroid(256),
39
+ seed(1234),
40
+ decode_block_size(32768) {}
41
41
  // 39 corresponds to 10000 / 256 -> to avoid warnings on PQ tests with randu10k
42
42
 
43
+ Clustering::Clustering(int d, int k) : d(d), k(k) {}
43
44
 
44
- Clustering::Clustering (int d, int k):
45
- d(d), k(k) {}
46
-
47
- Clustering::Clustering (int d, int k, const ClusteringParameters &cp):
48
- ClusteringParameters (cp), d(d), k(k) {}
45
+ Clustering::Clustering(int d, int k, const ClusteringParameters& cp)
46
+ : ClusteringParameters(cp), d(d), k(k) {}
49
47
 
50
-
51
-
52
- static double imbalance_factor (int n, int k, int64_t *assign) {
48
+ static double imbalance_factor(int n, int k, int64_t* assign) {
53
49
  std::vector<int> hist(k, 0);
54
50
  for (int i = 0; i < n; i++)
55
51
  hist[assign[i]]++;
56
52
 
57
53
  double tot = 0, uf = 0;
58
54
 
59
- for (int i = 0 ; i < k ; i++) {
55
+ for (int i = 0; i < k; i++) {
60
56
  tot += hist[i];
61
- uf += hist[i] * (double) hist[i];
57
+ uf += hist[i] * (double)hist[i];
62
58
  }
63
59
  uf = uf * k / (tot * tot);
64
60
 
65
61
  return uf;
66
62
  }
67
63
 
68
- void Clustering::post_process_centroids ()
69
- {
70
-
64
+ void Clustering::post_process_centroids() {
71
65
  if (spherical) {
72
- fvec_renorm_L2 (d, k, centroids.data());
66
+ fvec_renorm_L2(d, k, centroids.data());
73
67
  }
74
68
 
75
69
  if (int_centroids) {
76
70
  for (size_t i = 0; i < centroids.size(); i++)
77
- centroids[i] = roundf (centroids[i]);
71
+ centroids[i] = roundf(centroids[i]);
78
72
  }
79
73
  }
80
74
 
81
-
82
- void Clustering::train (idx_t nx, const float *x_in, Index & index,
83
- const float *weights) {
84
- train_encoded (nx, reinterpret_cast<const uint8_t *>(x_in), nullptr,
85
- index, weights);
75
+ void Clustering::train(
76
+ idx_t nx,
77
+ const float* x_in,
78
+ Index& index,
79
+ const float* weights) {
80
+ train_encoded(
81
+ nx,
82
+ reinterpret_cast<const uint8_t*>(x_in),
83
+ nullptr,
84
+ index,
85
+ weights);
86
86
  }
87
87
 
88
-
89
88
  namespace {
90
89
 
91
90
  using idx_t = Clustering::idx_t;
92
91
 
93
92
  idx_t subsample_training_set(
94
- const Clustering &clus, idx_t nx, const uint8_t *x,
95
- size_t line_size, const float * weights,
96
- uint8_t **x_out,
97
- float **weights_out
98
- )
99
- {
93
+ const Clustering& clus,
94
+ idx_t nx,
95
+ const uint8_t* x,
96
+ size_t line_size,
97
+ const float* weights,
98
+ uint8_t** x_out,
99
+ float** weights_out) {
100
100
  if (clus.verbose) {
101
101
  printf("Sampling a subset of %zd / %" PRId64 " for training\n",
102
- clus.k * clus.max_points_per_centroid, nx);
102
+ clus.k * clus.max_points_per_centroid,
103
+ nx);
103
104
  }
104
- std::vector<int> perm (nx);
105
- rand_perm (perm.data (), nx, clus.seed);
105
+ std::vector<int> perm(nx);
106
+ rand_perm(perm.data(), nx, clus.seed);
106
107
  nx = clus.k * clus.max_points_per_centroid;
107
- uint8_t * x_new = new uint8_t [nx * line_size];
108
+ uint8_t* x_new = new uint8_t[nx * line_size];
108
109
  *x_out = x_new;
109
110
  for (idx_t i = 0; i < nx; i++) {
110
- memcpy (x_new + i * line_size, x + perm[i] * line_size, line_size);
111
+ memcpy(x_new + i * line_size, x + perm[i] * line_size, line_size);
111
112
  }
112
113
  if (weights) {
113
- float *weights_new = new float[nx];
114
+ float* weights_new = new float[nx];
114
115
  for (idx_t i = 0; i < nx; i++) {
115
116
  weights_new[i] = weights[perm[i]];
116
117
  }
@@ -134,20 +135,23 @@ idx_t subsample_training_set(
134
135
  *
135
136
  */
136
137
 
137
- void compute_centroids (size_t d, size_t k, size_t n,
138
- size_t k_frozen,
139
- const uint8_t * x, const Index *codec,
140
- const int64_t * assign,
141
- const float * weights,
142
- float * hassign,
143
- float * centroids)
144
- {
138
+ void compute_centroids(
139
+ size_t d,
140
+ size_t k,
141
+ size_t n,
142
+ size_t k_frozen,
143
+ const uint8_t* x,
144
+ const Index* codec,
145
+ const int64_t* assign,
146
+ const float* weights,
147
+ float* hassign,
148
+ float* centroids) {
145
149
  k -= k_frozen;
146
150
  centroids += k_frozen * d;
147
151
 
148
- memset (centroids, 0, sizeof(*centroids) * d * k);
152
+ memset(centroids, 0, sizeof(*centroids) * d * k);
149
153
 
150
- size_t line_size = codec ? codec->sa_code_size() : d * sizeof (float);
154
+ size_t line_size = codec ? codec->sa_code_size() : d * sizeof(float);
151
155
 
152
156
  #pragma omp parallel
153
157
  {
@@ -157,20 +161,20 @@ void compute_centroids (size_t d, size_t k, size_t n,
157
161
  // this thread is taking care of centroids c0:c1
158
162
  size_t c0 = (k * rank) / nt;
159
163
  size_t c1 = (k * (rank + 1)) / nt;
160
- std::vector<float> decode_buffer (d);
164
+ std::vector<float> decode_buffer(d);
161
165
 
162
166
  for (size_t i = 0; i < n; i++) {
163
167
  int64_t ci = assign[i];
164
- assert (ci >= 0 && ci < k + k_frozen);
168
+ assert(ci >= 0 && ci < k + k_frozen);
165
169
  ci -= k_frozen;
166
- if (ci >= c0 && ci < c1) {
167
- float * c = centroids + ci * d;
168
- const float * xi;
170
+ if (ci >= c0 && ci < c1) {
171
+ float* c = centroids + ci * d;
172
+ const float* xi;
169
173
  if (!codec) {
170
174
  xi = reinterpret_cast<const float*>(x + i * line_size);
171
175
  } else {
172
- float *xif = decode_buffer.data();
173
- codec->sa_decode (1, x + i * line_size, xif);
176
+ float* xif = decode_buffer.data();
177
+ codec->sa_decode(1, x + i * line_size, xif);
174
178
  xi = xif;
175
179
  }
176
180
  if (weights) {
@@ -187,7 +191,6 @@ void compute_centroids (size_t d, size_t k, size_t n,
187
191
  }
188
192
  }
189
193
  }
190
-
191
194
  }
192
195
 
193
196
  #pragma omp parallel for
@@ -196,12 +199,11 @@ void compute_centroids (size_t d, size_t k, size_t n,
196
199
  continue;
197
200
  }
198
201
  float norm = 1 / hassign[ci];
199
- float * c = centroids + ci * d;
202
+ float* c = centroids + ci * d;
200
203
  for (size_t j = 0; j < d; j++) {
201
204
  c[j] *= norm;
202
205
  }
203
206
  }
204
-
205
207
  }
206
208
 
207
209
  // a bit above machine epsilon for float16
@@ -214,29 +216,33 @@ void compute_centroids (size_t d, size_t k, size_t n,
214
216
  *
215
217
  * @return nb of spliting operations (larger is worse)
216
218
  */
217
- int split_clusters (size_t d, size_t k, size_t n,
218
- size_t k_frozen,
219
- float * hassign,
220
- float * centroids)
221
- {
219
+ int split_clusters(
220
+ size_t d,
221
+ size_t k,
222
+ size_t n,
223
+ size_t k_frozen,
224
+ float* hassign,
225
+ float* centroids) {
222
226
  k -= k_frozen;
223
227
  centroids += k_frozen * d;
224
228
 
225
229
  /* Take care of void clusters */
226
230
  size_t nsplit = 0;
227
- RandomGenerator rng (1234);
231
+ RandomGenerator rng(1234);
228
232
  for (size_t ci = 0; ci < k; ci++) {
229
233
  if (hassign[ci] == 0) { /* need to redefine a centroid */
230
234
  size_t cj;
231
235
  for (cj = 0; 1; cj = (cj + 1) % k) {
232
236
  /* probability to pick this cluster for split */
233
- float p = (hassign[cj] - 1.0) / (float) (n - k);
234
- float r = rng.rand_float ();
237
+ float p = (hassign[cj] - 1.0) / (float)(n - k);
238
+ float r = rng.rand_float();
235
239
  if (r < p) {
236
240
  break; /* found our cluster to be split */
237
241
  }
238
242
  }
239
- memcpy (centroids+ci*d, centroids+cj*d, sizeof(*centroids) * d);
243
+ memcpy(centroids + ci * d,
244
+ centroids + cj * d,
245
+ sizeof(*centroids) * d);
240
246
 
241
247
  /* small symmetric pertubation */
242
248
  for (size_t j = 0; j < d; j++) {
@@ -257,30 +263,35 @@ int split_clusters (size_t d, size_t k, size_t n,
257
263
  }
258
264
 
259
265
  return nsplit;
260
-
261
266
  }
262
267
 
263
-
264
-
265
- };
266
-
267
-
268
- void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
269
- const Index * codec, Index & index,
270
- const float *weights) {
271
-
272
-
273
- FAISS_THROW_IF_NOT_FMT (nx >= k,
274
- "Number of training points (%" PRId64 ") should be at least "
275
- "as large as number of clusters (%zd)", nx, k);
276
-
277
- FAISS_THROW_IF_NOT_FMT ((!codec || codec->d == d),
278
- "Codec dimension %d not the same as data dimension %d",
279
- int(codec->d), int(d));
280
-
281
- FAISS_THROW_IF_NOT_FMT (index.d == d,
268
+ }; // namespace
269
+
270
+ void Clustering::train_encoded(
271
+ idx_t nx,
272
+ const uint8_t* x_in,
273
+ const Index* codec,
274
+ Index& index,
275
+ const float* weights) {
276
+ FAISS_THROW_IF_NOT_FMT(
277
+ nx >= k,
278
+ "Number of training points (%" PRId64
279
+ ") should be at least "
280
+ "as large as number of clusters (%zd)",
281
+ nx,
282
+ k);
283
+
284
+ FAISS_THROW_IF_NOT_FMT(
285
+ (!codec || codec->d == d),
286
+ "Codec dimension %d not the same as data dimension %d",
287
+ int(codec->d),
288
+ int(d));
289
+
290
+ FAISS_THROW_IF_NOT_FMT(
291
+ index.d == d,
282
292
  "Index dimension %d not the same as data dimension %d",
283
- int(index.d), int(d));
293
+ int(index.d),
294
+ int(d));
284
295
 
285
296
  double t0 = getmillisecs();
286
297
 
@@ -288,67 +299,78 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
288
299
  // Check for NaNs in input data. Normally it is the user's
289
300
  // responsibility, but it may spare us some hard-to-debug
290
301
  // reports.
291
- const float *x = reinterpret_cast<const float *>(x_in);
302
+ const float* x = reinterpret_cast<const float*>(x_in);
292
303
  for (size_t i = 0; i < nx * d; i++) {
293
- FAISS_THROW_IF_NOT_MSG (std::isfinite (x[i]),
294
- "input contains NaN's or Inf's");
304
+ FAISS_THROW_IF_NOT_MSG(
305
+ std::isfinite(x[i]), "input contains NaN's or Inf's");
295
306
  }
296
307
  }
297
308
 
298
- const uint8_t *x = x_in;
299
- std::unique_ptr<uint8_t []> del1;
300
- std::unique_ptr<float []> del3;
309
+ const uint8_t* x = x_in;
310
+ std::unique_ptr<uint8_t[]> del1;
311
+ std::unique_ptr<float[]> del3;
301
312
  size_t line_size = codec ? codec->sa_code_size() : sizeof(float) * d;
302
313
 
303
314
  if (nx > k * max_points_per_centroid) {
304
- uint8_t *x_new;
305
- float *weights_new;
306
- nx = subsample_training_set (*this, nx, x, line_size, weights,
307
- &x_new, &weights_new);
308
- del1.reset (x_new); x = x_new;
309
- del3.reset (weights_new); weights = weights_new;
315
+ uint8_t* x_new;
316
+ float* weights_new;
317
+ nx = subsample_training_set(
318
+ *this, nx, x, line_size, weights, &x_new, &weights_new);
319
+ del1.reset(x_new);
320
+ x = x_new;
321
+ del3.reset(weights_new);
322
+ weights = weights_new;
310
323
  } else if (nx < k * min_points_per_centroid) {
311
- fprintf (stderr,
312
- "WARNING clustering %" PRId64 " points to %zd centroids: "
313
- "please provide at least %" PRId64 " training points\n",
314
- nx, k, idx_t(k) * min_points_per_centroid);
324
+ fprintf(stderr,
325
+ "WARNING clustering %" PRId64
326
+ " points to %zd centroids: "
327
+ "please provide at least %" PRId64 " training points\n",
328
+ nx,
329
+ k,
330
+ idx_t(k) * min_points_per_centroid);
315
331
  }
316
332
 
317
333
  if (nx == k) {
318
334
  // this is a corner case, just copy training set to clusters
319
335
  if (verbose) {
320
- printf("Number of training points (%" PRId64 ") same as number of "
321
- "clusters, just copying\n", nx);
336
+ printf("Number of training points (%" PRId64
337
+ ") same as number of "
338
+ "clusters, just copying\n",
339
+ nx);
322
340
  }
323
- centroids.resize (d * k);
341
+ centroids.resize(d * k);
324
342
  if (!codec) {
325
- memcpy (centroids.data(), x_in, sizeof (float) * d * k);
343
+ memcpy(centroids.data(), x_in, sizeof(float) * d * k);
326
344
  } else {
327
- codec->sa_decode (nx, x_in, centroids.data());
345
+ codec->sa_decode(nx, x_in, centroids.data());
328
346
  }
329
347
 
330
348
  // one fake iteration...
331
- ClusteringIterationStats stats = { 0.0, 0.0, 0.0, 1.0, 0 };
332
- iteration_stats.push_back (stats);
349
+ ClusteringIterationStats stats = {0.0, 0.0, 0.0, 1.0, 0};
350
+ iteration_stats.push_back(stats);
333
351
 
334
352
  index.reset();
335
353
  index.add(k, centroids.data());
336
354
  return;
337
355
  }
338
356
 
339
-
340
357
  if (verbose) {
341
- printf("Clustering %" PRId64 " points in %zdD to %zd clusters, "
358
+ printf("Clustering %" PRId64
359
+ " points in %zdD to %zd clusters, "
342
360
  "redo %d times, %d iterations\n",
343
- nx, d, k, nredo, niter);
361
+ nx,
362
+ d,
363
+ k,
364
+ nredo,
365
+ niter);
344
366
  if (codec) {
345
367
  printf("Input data encoded in %zd bytes per vector\n",
346
- codec->sa_code_size ());
368
+ codec->sa_code_size());
347
369
  }
348
370
  }
349
371
 
350
- std::unique_ptr<idx_t []> assign(new idx_t[nx]);
351
- std::unique_ptr<float []> dis(new float[nx]);
372
+ std::unique_ptr<idx_t[]> assign(new idx_t[nx]);
373
+ std::unique_ptr<float[]> dis(new float[nx]);
352
374
 
353
375
  // remember best iteration for redo
354
376
  bool lower_is_better = index.metric_type != METRIC_INNER_PRODUCT;
@@ -358,52 +380,49 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
358
380
 
359
381
  // support input centroids
360
382
 
361
- FAISS_THROW_IF_NOT_MSG (
362
- centroids.size() % d == 0,
363
- "size of provided input centroids not a multiple of dimension"
364
- );
383
+ FAISS_THROW_IF_NOT_MSG(
384
+ centroids.size() % d == 0,
385
+ "size of provided input centroids not a multiple of dimension");
365
386
 
366
387
  size_t n_input_centroids = centroids.size() / d;
367
388
 
368
389
  if (verbose && n_input_centroids > 0) {
369
- printf (" Using %zd centroids provided as input (%sfrozen)\n",
370
- n_input_centroids, frozen_centroids ? "" : "not ");
390
+ printf(" Using %zd centroids provided as input (%sfrozen)\n",
391
+ n_input_centroids,
392
+ frozen_centroids ? "" : "not ");
371
393
  }
372
394
 
373
395
  double t_search_tot = 0;
374
396
  if (verbose) {
375
- printf(" Preprocessing in %.2f s\n",
376
- (getmillisecs() - t0) / 1000.);
397
+ printf(" Preprocessing in %.2f s\n", (getmillisecs() - t0) / 1000.);
377
398
  }
378
399
  t0 = getmillisecs();
379
400
 
380
401
  // temporary buffer to decode vectors during the optimization
381
- std::vector<float> decode_buffer
382
- (codec ? d * decode_block_size : 0);
402
+ std::vector<float> decode_buffer(codec ? d * decode_block_size : 0);
383
403
 
384
404
  for (int redo = 0; redo < nredo; redo++) {
385
-
386
405
  if (verbose && nredo > 1) {
387
406
  printf("Outer iteration %d / %d\n", redo, nredo);
388
407
  }
389
408
 
390
409
  // initialize (remaining) centroids with random points from the dataset
391
- centroids.resize (d * k);
392
- std::vector<int> perm (nx);
410
+ centroids.resize(d * k);
411
+ std::vector<int> perm(nx);
393
412
 
394
- rand_perm (perm.data(), nx, seed + 1 + redo * 15486557L);
413
+ rand_perm(perm.data(), nx, seed + 1 + redo * 15486557L);
395
414
 
396
415
  if (!codec) {
397
- for (int i = n_input_centroids; i < k ; i++) {
398
- memcpy (&centroids[i * d], x + perm[i] * line_size, line_size);
416
+ for (int i = n_input_centroids; i < k; i++) {
417
+ memcpy(&centroids[i * d], x + perm[i] * line_size, line_size);
399
418
  }
400
419
  } else {
401
- for (int i = n_input_centroids; i < k ; i++) {
402
- codec->sa_decode (1, x + perm[i] * line_size, &centroids[i * d]);
420
+ for (int i = n_input_centroids; i < k; i++) {
421
+ codec->sa_decode(1, x + perm[i] * line_size, &centroids[i * d]);
403
422
  }
404
423
  }
405
424
 
406
- post_process_centroids ();
425
+ post_process_centroids();
407
426
 
408
427
  // prepare the index
409
428
 
@@ -412,10 +431,10 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
412
431
  }
413
432
 
414
433
  if (!index.is_trained) {
415
- index.train (k, centroids.data());
434
+ index.train(k, centroids.data());
416
435
  }
417
436
 
418
- index.add (k, centroids.data());
437
+ index.add(k, centroids.data());
419
438
 
420
439
  // k-means iterations
421
440
 
@@ -424,18 +443,28 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
424
443
  double t0s = getmillisecs();
425
444
 
426
445
  if (!codec) {
427
- index.search (nx, reinterpret_cast<const float *>(x), 1,
428
- dis.get(), assign.get());
446
+ index.search(
447
+ nx,
448
+ reinterpret_cast<const float*>(x),
449
+ 1,
450
+ dis.get(),
451
+ assign.get());
429
452
  } else {
430
453
  // search by blocks of decode_block_size vectors
431
- size_t code_size = codec->sa_code_size ();
454
+ size_t code_size = codec->sa_code_size();
432
455
  for (size_t i0 = 0; i0 < nx; i0 += decode_block_size) {
433
456
  size_t i1 = i0 + decode_block_size;
434
- if (i1 > nx) { i1 = nx; }
435
- codec->sa_decode (i1 - i0, x + code_size * i0,
436
- decode_buffer.data ());
437
- index.search (i1 - i0, decode_buffer.data (), 1,
438
- dis.get() + i0, assign.get() + i0);
457
+ if (i1 > nx) {
458
+ i1 = nx;
459
+ }
460
+ codec->sa_decode(
461
+ i1 - i0, x + code_size * i0, decode_buffer.data());
462
+ index.search(
463
+ i1 - i0,
464
+ decode_buffer.data(),
465
+ 1,
466
+ dis.get() + i0,
467
+ assign.get() + i0);
439
468
  }
440
469
  }
441
470
 
@@ -449,61 +478,71 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
449
478
  }
450
479
 
451
480
  // update the centroids
452
- std::vector<float> hassign (k);
481
+ std::vector<float> hassign(k);
453
482
 
454
483
  size_t k_frozen = frozen_centroids ? n_input_centroids : 0;
455
- compute_centroids (
456
- d, k, nx, k_frozen,
457
- x, codec, assign.get(), weights,
458
- hassign.data(), centroids.data()
459
- );
460
-
461
- int nsplit = split_clusters (
462
- d, k, nx, k_frozen,
463
- hassign.data(), centroids.data()
464
- );
484
+ compute_centroids(
485
+ d,
486
+ k,
487
+ nx,
488
+ k_frozen,
489
+ x,
490
+ codec,
491
+ assign.get(),
492
+ weights,
493
+ hassign.data(),
494
+ centroids.data());
495
+
496
+ int nsplit = split_clusters(
497
+ d, k, nx, k_frozen, hassign.data(), centroids.data());
465
498
 
466
499
  // collect statistics
467
- ClusteringIterationStats stats =
468
- { obj, (getmillisecs() - t0) / 1000.0,
469
- t_search_tot / 1000,
470
- imbalance_factor (nx, k, assign.get()),
471
- nsplit };
500
+ ClusteringIterationStats stats = {
501
+ obj,
502
+ (getmillisecs() - t0) / 1000.0,
503
+ t_search_tot / 1000,
504
+ imbalance_factor(nx, k, assign.get()),
505
+ nsplit};
472
506
  iteration_stats.push_back(stats);
473
507
 
474
508
  if (verbose) {
475
- printf (" Iteration %d (%.2f s, search %.2f s): "
476
- "objective=%g imbalance=%.3f nsplit=%d \r",
477
- i, stats.time, stats.time_search, stats.obj,
478
- stats.imbalance_factor, nsplit);
479
- fflush (stdout);
509
+ printf(" Iteration %d (%.2f s, search %.2f s): "
510
+ "objective=%g imbalance=%.3f nsplit=%d \r",
511
+ i,
512
+ stats.time,
513
+ stats.time_search,
514
+ stats.obj,
515
+ stats.imbalance_factor,
516
+ nsplit);
517
+ fflush(stdout);
480
518
  }
481
519
 
482
- post_process_centroids ();
520
+ post_process_centroids();
483
521
 
484
522
  // add centroids to index for the next iteration (or for output)
485
523
 
486
- index.reset ();
524
+ index.reset();
487
525
  if (update_index) {
488
- index.train (k, centroids.data());
526
+ index.train(k, centroids.data());
489
527
  }
490
528
 
491
- index.add (k, centroids.data());
492
- InterruptCallback::check ();
529
+ index.add(k, centroids.data());
530
+ InterruptCallback::check();
493
531
  }
494
532
 
495
- if (verbose) printf("\n");
533
+ if (verbose)
534
+ printf("\n");
496
535
  if (nredo > 1) {
497
536
  if ((lower_is_better && obj < best_obj) ||
498
537
  (!lower_is_better && obj > best_obj)) {
499
538
  if (verbose) {
500
- printf ("Objective improved: keep new clusters\n");
539
+ printf("Objective improved: keep new clusters\n");
501
540
  }
502
541
  best_centroids = centroids;
503
542
  best_iteration_stats = iteration_stats;
504
543
  best_obj = obj;
505
544
  }
506
- index.reset ();
545
+ index.reset();
507
546
  }
508
547
  }
509
548
  if (nredo > 1) {
@@ -512,20 +551,120 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
512
551
  index.reset();
513
552
  index.add(k, best_centroids.data());
514
553
  }
515
-
516
554
  }
517
555
 
518
- float kmeans_clustering (size_t d, size_t n, size_t k,
519
- const float *x,
520
- float *centroids)
521
- {
522
- Clustering clus (d, k);
556
+ float kmeans_clustering(
557
+ size_t d,
558
+ size_t n,
559
+ size_t k,
560
+ const float* x,
561
+ float* centroids) {
562
+ Clustering clus(d, k);
523
563
  clus.verbose = d * n * k > (1L << 30);
524
564
  // display logs if > 1Gflop per iteration
525
- IndexFlatL2 index (d);
526
- clus.train (n, x, index);
565
+ IndexFlatL2 index(d);
566
+ clus.train(n, x, index);
527
567
  memcpy(centroids, clus.centroids.data(), sizeof(*centroids) * d * k);
528
568
  return clus.iteration_stats.back().obj;
529
569
  }
530
570
 
571
+ /******************************************************************************
572
+ * ProgressiveDimClustering implementation
573
+ ******************************************************************************/
574
+
575
+ ProgressiveDimClusteringParameters::ProgressiveDimClusteringParameters() {
576
+ progressive_dim_steps = 10;
577
+ apply_pca = true; // seems a good idea to do this by default
578
+ niter = 10; // reduce nb of iterations per step
579
+ }
580
+
581
+ Index* ProgressiveDimIndexFactory::operator()(int dim) {
582
+ return new IndexFlatL2(dim);
583
+ }
584
+
585
+ ProgressiveDimClustering::ProgressiveDimClustering(int d, int k) : d(d), k(k) {}
586
+
587
+ ProgressiveDimClustering::ProgressiveDimClustering(
588
+ int d,
589
+ int k,
590
+ const ProgressiveDimClusteringParameters& cp)
591
+ : ProgressiveDimClusteringParameters(cp), d(d), k(k) {}
592
+
593
+ namespace {
594
+
595
+ using idx_t = Index::idx_t;
596
+
597
+ void copy_columns(idx_t n, idx_t d1, const float* src, idx_t d2, float* dest) {
598
+ idx_t d = std::min(d1, d2);
599
+ for (idx_t i = 0; i < n; i++) {
600
+ memcpy(dest, src, sizeof(float) * d);
601
+ src += d1;
602
+ dest += d2;
603
+ }
604
+ }
605
+
606
+ }; // namespace
607
+
608
+ void ProgressiveDimClustering::train(
609
+ idx_t n,
610
+ const float* x,
611
+ ProgressiveDimIndexFactory& factory) {
612
+ int d_prev = 0;
613
+
614
+ PCAMatrix pca(d, d);
615
+
616
+ std::vector<float> xbuf;
617
+ if (apply_pca) {
618
+ if (verbose) {
619
+ printf("Training PCA transform\n");
620
+ }
621
+ pca.train(n, x);
622
+ if (verbose) {
623
+ printf("Apply PCA\n");
624
+ }
625
+ xbuf.resize(n * d);
626
+ pca.apply_noalloc(n, x, xbuf.data());
627
+ x = xbuf.data();
628
+ }
629
+
630
+ for (int iter = 0; iter < progressive_dim_steps; iter++) {
631
+ int di = int(pow(d, (1. + iter) / progressive_dim_steps));
632
+ if (verbose) {
633
+ printf("Progressive dim step %d: cluster in dimension %d\n",
634
+ iter,
635
+ di);
636
+ }
637
+ std::unique_ptr<Index> clustering_index(factory(di));
638
+
639
+ Clustering clus(di, k, *this);
640
+ if (d_prev > 0) {
641
+ // copy warm-start centroids (padded with 0s)
642
+ clus.centroids.resize(k * di);
643
+ copy_columns(
644
+ k, d_prev, centroids.data(), di, clus.centroids.data());
645
+ }
646
+ std::vector<float> xsub(n * di);
647
+ copy_columns(n, d, x, di, xsub.data());
648
+
649
+ clus.train(n, xsub.data(), *clustering_index.get());
650
+
651
+ centroids = clus.centroids;
652
+ iteration_stats.insert(
653
+ iteration_stats.end(),
654
+ clus.iteration_stats.begin(),
655
+ clus.iteration_stats.end());
656
+
657
+ d_prev = di;
658
+ }
659
+
660
+ if (apply_pca) {
661
+ if (verbose) {
662
+ printf("Revert PCA transform on centroids\n");
663
+ }
664
+ std::vector<float> cent_transformed(d * k);
665
+ pca.reverse_transform(k, centroids.data(), cent_transformed.data());
666
+ cent_transformed.swap(centroids);
667
+ }
668
+ }
669
+
531
670
  } // namespace faiss