faiss 0.1.7 → 0.2.3

Sign up to get free protection for your applications and to get access to all the features.
Files changed (219) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +18 -0
  3. data/README.md +7 -7
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +8 -2
  6. data/ext/faiss/index.cpp +102 -69
  7. data/ext/faiss/index_binary.cpp +24 -30
  8. data/ext/faiss/kmeans.cpp +20 -16
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +13 -14
  11. data/ext/faiss/product_quantizer.cpp +23 -24
  12. data/ext/faiss/utils.cpp +10 -37
  13. data/ext/faiss/utils.h +2 -13
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +0 -5
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +26 -12
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
@@ -8,6 +8,7 @@
8
8
  // -*- c++ -*-
9
9
 
10
10
  #include <faiss/Clustering.h>
11
+ #include <faiss/VectorTransform.h>
11
12
  #include <faiss/impl/AuxIndexStructures.h>
12
13
 
13
14
  #include <cinttypes>
@@ -17,100 +18,100 @@
17
18
 
18
19
  #include <omp.h>
19
20
 
20
- #include <faiss/utils/utils.h>
21
- #include <faiss/utils/random.h>
22
- #include <faiss/utils/distances.h>
23
- #include <faiss/impl/FaissAssert.h>
24
21
  #include <faiss/IndexFlat.h>
22
+ #include <faiss/impl/FaissAssert.h>
23
+ #include <faiss/utils/distances.h>
24
+ #include <faiss/utils/random.h>
25
+ #include <faiss/utils/utils.h>
25
26
 
26
27
  namespace faiss {
27
28
 
28
- ClusteringParameters::ClusteringParameters ():
29
- niter(25),
30
- nredo(1),
31
- verbose(false),
32
- spherical(false),
33
- int_centroids(false),
34
- update_index(false),
35
- frozen_centroids(false),
36
- min_points_per_centroid(39),
37
- max_points_per_centroid(256),
38
- seed(1234),
39
- decode_block_size(32768)
40
- {}
29
+ ClusteringParameters::ClusteringParameters()
30
+ : niter(25),
31
+ nredo(1),
32
+ verbose(false),
33
+ spherical(false),
34
+ int_centroids(false),
35
+ update_index(false),
36
+ frozen_centroids(false),
37
+ min_points_per_centroid(39),
38
+ max_points_per_centroid(256),
39
+ seed(1234),
40
+ decode_block_size(32768) {}
41
41
  // 39 corresponds to 10000 / 256 -> to avoid warnings on PQ tests with randu10k
42
42
 
43
+ Clustering::Clustering(int d, int k) : d(d), k(k) {}
43
44
 
44
- Clustering::Clustering (int d, int k):
45
- d(d), k(k) {}
46
-
47
- Clustering::Clustering (int d, int k, const ClusteringParameters &cp):
48
- ClusteringParameters (cp), d(d), k(k) {}
45
+ Clustering::Clustering(int d, int k, const ClusteringParameters& cp)
46
+ : ClusteringParameters(cp), d(d), k(k) {}
49
47
 
50
-
51
-
52
- static double imbalance_factor (int n, int k, int64_t *assign) {
48
+ static double imbalance_factor(int n, int k, int64_t* assign) {
53
49
  std::vector<int> hist(k, 0);
54
50
  for (int i = 0; i < n; i++)
55
51
  hist[assign[i]]++;
56
52
 
57
53
  double tot = 0, uf = 0;
58
54
 
59
- for (int i = 0 ; i < k ; i++) {
55
+ for (int i = 0; i < k; i++) {
60
56
  tot += hist[i];
61
- uf += hist[i] * (double) hist[i];
57
+ uf += hist[i] * (double)hist[i];
62
58
  }
63
59
  uf = uf * k / (tot * tot);
64
60
 
65
61
  return uf;
66
62
  }
67
63
 
68
- void Clustering::post_process_centroids ()
69
- {
70
-
64
+ void Clustering::post_process_centroids() {
71
65
  if (spherical) {
72
- fvec_renorm_L2 (d, k, centroids.data());
66
+ fvec_renorm_L2(d, k, centroids.data());
73
67
  }
74
68
 
75
69
  if (int_centroids) {
76
70
  for (size_t i = 0; i < centroids.size(); i++)
77
- centroids[i] = roundf (centroids[i]);
71
+ centroids[i] = roundf(centroids[i]);
78
72
  }
79
73
  }
80
74
 
81
-
82
- void Clustering::train (idx_t nx, const float *x_in, Index & index,
83
- const float *weights) {
84
- train_encoded (nx, reinterpret_cast<const uint8_t *>(x_in), nullptr,
85
- index, weights);
75
+ void Clustering::train(
76
+ idx_t nx,
77
+ const float* x_in,
78
+ Index& index,
79
+ const float* weights) {
80
+ train_encoded(
81
+ nx,
82
+ reinterpret_cast<const uint8_t*>(x_in),
83
+ nullptr,
84
+ index,
85
+ weights);
86
86
  }
87
87
 
88
-
89
88
  namespace {
90
89
 
91
90
  using idx_t = Clustering::idx_t;
92
91
 
93
92
  idx_t subsample_training_set(
94
- const Clustering &clus, idx_t nx, const uint8_t *x,
95
- size_t line_size, const float * weights,
96
- uint8_t **x_out,
97
- float **weights_out
98
- )
99
- {
93
+ const Clustering& clus,
94
+ idx_t nx,
95
+ const uint8_t* x,
96
+ size_t line_size,
97
+ const float* weights,
98
+ uint8_t** x_out,
99
+ float** weights_out) {
100
100
  if (clus.verbose) {
101
101
  printf("Sampling a subset of %zd / %" PRId64 " for training\n",
102
- clus.k * clus.max_points_per_centroid, nx);
102
+ clus.k * clus.max_points_per_centroid,
103
+ nx);
103
104
  }
104
- std::vector<int> perm (nx);
105
- rand_perm (perm.data (), nx, clus.seed);
105
+ std::vector<int> perm(nx);
106
+ rand_perm(perm.data(), nx, clus.seed);
106
107
  nx = clus.k * clus.max_points_per_centroid;
107
- uint8_t * x_new = new uint8_t [nx * line_size];
108
+ uint8_t* x_new = new uint8_t[nx * line_size];
108
109
  *x_out = x_new;
109
110
  for (idx_t i = 0; i < nx; i++) {
110
- memcpy (x_new + i * line_size, x + perm[i] * line_size, line_size);
111
+ memcpy(x_new + i * line_size, x + perm[i] * line_size, line_size);
111
112
  }
112
113
  if (weights) {
113
- float *weights_new = new float[nx];
114
+ float* weights_new = new float[nx];
114
115
  for (idx_t i = 0; i < nx; i++) {
115
116
  weights_new[i] = weights[perm[i]];
116
117
  }
@@ -134,20 +135,23 @@ idx_t subsample_training_set(
134
135
  *
135
136
  */
136
137
 
137
- void compute_centroids (size_t d, size_t k, size_t n,
138
- size_t k_frozen,
139
- const uint8_t * x, const Index *codec,
140
- const int64_t * assign,
141
- const float * weights,
142
- float * hassign,
143
- float * centroids)
144
- {
138
+ void compute_centroids(
139
+ size_t d,
140
+ size_t k,
141
+ size_t n,
142
+ size_t k_frozen,
143
+ const uint8_t* x,
144
+ const Index* codec,
145
+ const int64_t* assign,
146
+ const float* weights,
147
+ float* hassign,
148
+ float* centroids) {
145
149
  k -= k_frozen;
146
150
  centroids += k_frozen * d;
147
151
 
148
- memset (centroids, 0, sizeof(*centroids) * d * k);
152
+ memset(centroids, 0, sizeof(*centroids) * d * k);
149
153
 
150
- size_t line_size = codec ? codec->sa_code_size() : d * sizeof (float);
154
+ size_t line_size = codec ? codec->sa_code_size() : d * sizeof(float);
151
155
 
152
156
  #pragma omp parallel
153
157
  {
@@ -157,20 +161,20 @@ void compute_centroids (size_t d, size_t k, size_t n,
157
161
  // this thread is taking care of centroids c0:c1
158
162
  size_t c0 = (k * rank) / nt;
159
163
  size_t c1 = (k * (rank + 1)) / nt;
160
- std::vector<float> decode_buffer (d);
164
+ std::vector<float> decode_buffer(d);
161
165
 
162
166
  for (size_t i = 0; i < n; i++) {
163
167
  int64_t ci = assign[i];
164
- assert (ci >= 0 && ci < k + k_frozen);
168
+ assert(ci >= 0 && ci < k + k_frozen);
165
169
  ci -= k_frozen;
166
- if (ci >= c0 && ci < c1) {
167
- float * c = centroids + ci * d;
168
- const float * xi;
170
+ if (ci >= c0 && ci < c1) {
171
+ float* c = centroids + ci * d;
172
+ const float* xi;
169
173
  if (!codec) {
170
174
  xi = reinterpret_cast<const float*>(x + i * line_size);
171
175
  } else {
172
- float *xif = decode_buffer.data();
173
- codec->sa_decode (1, x + i * line_size, xif);
176
+ float* xif = decode_buffer.data();
177
+ codec->sa_decode(1, x + i * line_size, xif);
174
178
  xi = xif;
175
179
  }
176
180
  if (weights) {
@@ -187,7 +191,6 @@ void compute_centroids (size_t d, size_t k, size_t n,
187
191
  }
188
192
  }
189
193
  }
190
-
191
194
  }
192
195
 
193
196
  #pragma omp parallel for
@@ -196,12 +199,11 @@ void compute_centroids (size_t d, size_t k, size_t n,
196
199
  continue;
197
200
  }
198
201
  float norm = 1 / hassign[ci];
199
- float * c = centroids + ci * d;
202
+ float* c = centroids + ci * d;
200
203
  for (size_t j = 0; j < d; j++) {
201
204
  c[j] *= norm;
202
205
  }
203
206
  }
204
-
205
207
  }
206
208
 
207
209
  // a bit above machine epsilon for float16
@@ -214,29 +216,33 @@ void compute_centroids (size_t d, size_t k, size_t n,
214
216
  *
215
217
  * @return nb of spliting operations (larger is worse)
216
218
  */
217
- int split_clusters (size_t d, size_t k, size_t n,
218
- size_t k_frozen,
219
- float * hassign,
220
- float * centroids)
221
- {
219
+ int split_clusters(
220
+ size_t d,
221
+ size_t k,
222
+ size_t n,
223
+ size_t k_frozen,
224
+ float* hassign,
225
+ float* centroids) {
222
226
  k -= k_frozen;
223
227
  centroids += k_frozen * d;
224
228
 
225
229
  /* Take care of void clusters */
226
230
  size_t nsplit = 0;
227
- RandomGenerator rng (1234);
231
+ RandomGenerator rng(1234);
228
232
  for (size_t ci = 0; ci < k; ci++) {
229
233
  if (hassign[ci] == 0) { /* need to redefine a centroid */
230
234
  size_t cj;
231
235
  for (cj = 0; 1; cj = (cj + 1) % k) {
232
236
  /* probability to pick this cluster for split */
233
- float p = (hassign[cj] - 1.0) / (float) (n - k);
234
- float r = rng.rand_float ();
237
+ float p = (hassign[cj] - 1.0) / (float)(n - k);
238
+ float r = rng.rand_float();
235
239
  if (r < p) {
236
240
  break; /* found our cluster to be split */
237
241
  }
238
242
  }
239
- memcpy (centroids+ci*d, centroids+cj*d, sizeof(*centroids) * d);
243
+ memcpy(centroids + ci * d,
244
+ centroids + cj * d,
245
+ sizeof(*centroids) * d);
240
246
 
241
247
  /* small symmetric pertubation */
242
248
  for (size_t j = 0; j < d; j++) {
@@ -257,30 +263,35 @@ int split_clusters (size_t d, size_t k, size_t n,
257
263
  }
258
264
 
259
265
  return nsplit;
260
-
261
266
  }
262
267
 
263
-
264
-
265
- };
266
-
267
-
268
- void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
269
- const Index * codec, Index & index,
270
- const float *weights) {
271
-
272
-
273
- FAISS_THROW_IF_NOT_FMT (nx >= k,
274
- "Number of training points (%" PRId64 ") should be at least "
275
- "as large as number of clusters (%zd)", nx, k);
276
-
277
- FAISS_THROW_IF_NOT_FMT ((!codec || codec->d == d),
278
- "Codec dimension %d not the same as data dimension %d",
279
- int(codec->d), int(d));
280
-
281
- FAISS_THROW_IF_NOT_FMT (index.d == d,
268
+ }; // namespace
269
+
270
+ void Clustering::train_encoded(
271
+ idx_t nx,
272
+ const uint8_t* x_in,
273
+ const Index* codec,
274
+ Index& index,
275
+ const float* weights) {
276
+ FAISS_THROW_IF_NOT_FMT(
277
+ nx >= k,
278
+ "Number of training points (%" PRId64
279
+ ") should be at least "
280
+ "as large as number of clusters (%zd)",
281
+ nx,
282
+ k);
283
+
284
+ FAISS_THROW_IF_NOT_FMT(
285
+ (!codec || codec->d == d),
286
+ "Codec dimension %d not the same as data dimension %d",
287
+ int(codec->d),
288
+ int(d));
289
+
290
+ FAISS_THROW_IF_NOT_FMT(
291
+ index.d == d,
282
292
  "Index dimension %d not the same as data dimension %d",
283
- int(index.d), int(d));
293
+ int(index.d),
294
+ int(d));
284
295
 
285
296
  double t0 = getmillisecs();
286
297
 
@@ -288,67 +299,78 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
288
299
  // Check for NaNs in input data. Normally it is the user's
289
300
  // responsibility, but it may spare us some hard-to-debug
290
301
  // reports.
291
- const float *x = reinterpret_cast<const float *>(x_in);
302
+ const float* x = reinterpret_cast<const float*>(x_in);
292
303
  for (size_t i = 0; i < nx * d; i++) {
293
- FAISS_THROW_IF_NOT_MSG (std::isfinite (x[i]),
294
- "input contains NaN's or Inf's");
304
+ FAISS_THROW_IF_NOT_MSG(
305
+ std::isfinite(x[i]), "input contains NaN's or Inf's");
295
306
  }
296
307
  }
297
308
 
298
- const uint8_t *x = x_in;
299
- std::unique_ptr<uint8_t []> del1;
300
- std::unique_ptr<float []> del3;
309
+ const uint8_t* x = x_in;
310
+ std::unique_ptr<uint8_t[]> del1;
311
+ std::unique_ptr<float[]> del3;
301
312
  size_t line_size = codec ? codec->sa_code_size() : sizeof(float) * d;
302
313
 
303
314
  if (nx > k * max_points_per_centroid) {
304
- uint8_t *x_new;
305
- float *weights_new;
306
- nx = subsample_training_set (*this, nx, x, line_size, weights,
307
- &x_new, &weights_new);
308
- del1.reset (x_new); x = x_new;
309
- del3.reset (weights_new); weights = weights_new;
315
+ uint8_t* x_new;
316
+ float* weights_new;
317
+ nx = subsample_training_set(
318
+ *this, nx, x, line_size, weights, &x_new, &weights_new);
319
+ del1.reset(x_new);
320
+ x = x_new;
321
+ del3.reset(weights_new);
322
+ weights = weights_new;
310
323
  } else if (nx < k * min_points_per_centroid) {
311
- fprintf (stderr,
312
- "WARNING clustering %" PRId64 " points to %zd centroids: "
313
- "please provide at least %" PRId64 " training points\n",
314
- nx, k, idx_t(k) * min_points_per_centroid);
324
+ fprintf(stderr,
325
+ "WARNING clustering %" PRId64
326
+ " points to %zd centroids: "
327
+ "please provide at least %" PRId64 " training points\n",
328
+ nx,
329
+ k,
330
+ idx_t(k) * min_points_per_centroid);
315
331
  }
316
332
 
317
333
  if (nx == k) {
318
334
  // this is a corner case, just copy training set to clusters
319
335
  if (verbose) {
320
- printf("Number of training points (%" PRId64 ") same as number of "
321
- "clusters, just copying\n", nx);
336
+ printf("Number of training points (%" PRId64
337
+ ") same as number of "
338
+ "clusters, just copying\n",
339
+ nx);
322
340
  }
323
- centroids.resize (d * k);
341
+ centroids.resize(d * k);
324
342
  if (!codec) {
325
- memcpy (centroids.data(), x_in, sizeof (float) * d * k);
343
+ memcpy(centroids.data(), x_in, sizeof(float) * d * k);
326
344
  } else {
327
- codec->sa_decode (nx, x_in, centroids.data());
345
+ codec->sa_decode(nx, x_in, centroids.data());
328
346
  }
329
347
 
330
348
  // one fake iteration...
331
- ClusteringIterationStats stats = { 0.0, 0.0, 0.0, 1.0, 0 };
332
- iteration_stats.push_back (stats);
349
+ ClusteringIterationStats stats = {0.0, 0.0, 0.0, 1.0, 0};
350
+ iteration_stats.push_back(stats);
333
351
 
334
352
  index.reset();
335
353
  index.add(k, centroids.data());
336
354
  return;
337
355
  }
338
356
 
339
-
340
357
  if (verbose) {
341
- printf("Clustering %" PRId64 " points in %zdD to %zd clusters, "
358
+ printf("Clustering %" PRId64
359
+ " points in %zdD to %zd clusters, "
342
360
  "redo %d times, %d iterations\n",
343
- nx, d, k, nredo, niter);
361
+ nx,
362
+ d,
363
+ k,
364
+ nredo,
365
+ niter);
344
366
  if (codec) {
345
367
  printf("Input data encoded in %zd bytes per vector\n",
346
- codec->sa_code_size ());
368
+ codec->sa_code_size());
347
369
  }
348
370
  }
349
371
 
350
- std::unique_ptr<idx_t []> assign(new idx_t[nx]);
351
- std::unique_ptr<float []> dis(new float[nx]);
372
+ std::unique_ptr<idx_t[]> assign(new idx_t[nx]);
373
+ std::unique_ptr<float[]> dis(new float[nx]);
352
374
 
353
375
  // remember best iteration for redo
354
376
  bool lower_is_better = index.metric_type != METRIC_INNER_PRODUCT;
@@ -358,52 +380,49 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
358
380
 
359
381
  // support input centroids
360
382
 
361
- FAISS_THROW_IF_NOT_MSG (
362
- centroids.size() % d == 0,
363
- "size of provided input centroids not a multiple of dimension"
364
- );
383
+ FAISS_THROW_IF_NOT_MSG(
384
+ centroids.size() % d == 0,
385
+ "size of provided input centroids not a multiple of dimension");
365
386
 
366
387
  size_t n_input_centroids = centroids.size() / d;
367
388
 
368
389
  if (verbose && n_input_centroids > 0) {
369
- printf (" Using %zd centroids provided as input (%sfrozen)\n",
370
- n_input_centroids, frozen_centroids ? "" : "not ");
390
+ printf(" Using %zd centroids provided as input (%sfrozen)\n",
391
+ n_input_centroids,
392
+ frozen_centroids ? "" : "not ");
371
393
  }
372
394
 
373
395
  double t_search_tot = 0;
374
396
  if (verbose) {
375
- printf(" Preprocessing in %.2f s\n",
376
- (getmillisecs() - t0) / 1000.);
397
+ printf(" Preprocessing in %.2f s\n", (getmillisecs() - t0) / 1000.);
377
398
  }
378
399
  t0 = getmillisecs();
379
400
 
380
401
  // temporary buffer to decode vectors during the optimization
381
- std::vector<float> decode_buffer
382
- (codec ? d * decode_block_size : 0);
402
+ std::vector<float> decode_buffer(codec ? d * decode_block_size : 0);
383
403
 
384
404
  for (int redo = 0; redo < nredo; redo++) {
385
-
386
405
  if (verbose && nredo > 1) {
387
406
  printf("Outer iteration %d / %d\n", redo, nredo);
388
407
  }
389
408
 
390
409
  // initialize (remaining) centroids with random points from the dataset
391
- centroids.resize (d * k);
392
- std::vector<int> perm (nx);
410
+ centroids.resize(d * k);
411
+ std::vector<int> perm(nx);
393
412
 
394
- rand_perm (perm.data(), nx, seed + 1 + redo * 15486557L);
413
+ rand_perm(perm.data(), nx, seed + 1 + redo * 15486557L);
395
414
 
396
415
  if (!codec) {
397
- for (int i = n_input_centroids; i < k ; i++) {
398
- memcpy (&centroids[i * d], x + perm[i] * line_size, line_size);
416
+ for (int i = n_input_centroids; i < k; i++) {
417
+ memcpy(&centroids[i * d], x + perm[i] * line_size, line_size);
399
418
  }
400
419
  } else {
401
- for (int i = n_input_centroids; i < k ; i++) {
402
- codec->sa_decode (1, x + perm[i] * line_size, &centroids[i * d]);
420
+ for (int i = n_input_centroids; i < k; i++) {
421
+ codec->sa_decode(1, x + perm[i] * line_size, &centroids[i * d]);
403
422
  }
404
423
  }
405
424
 
406
- post_process_centroids ();
425
+ post_process_centroids();
407
426
 
408
427
  // prepare the index
409
428
 
@@ -412,10 +431,10 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
412
431
  }
413
432
 
414
433
  if (!index.is_trained) {
415
- index.train (k, centroids.data());
434
+ index.train(k, centroids.data());
416
435
  }
417
436
 
418
- index.add (k, centroids.data());
437
+ index.add(k, centroids.data());
419
438
 
420
439
  // k-means iterations
421
440
 
@@ -424,18 +443,28 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
424
443
  double t0s = getmillisecs();
425
444
 
426
445
  if (!codec) {
427
- index.search (nx, reinterpret_cast<const float *>(x), 1,
428
- dis.get(), assign.get());
446
+ index.search(
447
+ nx,
448
+ reinterpret_cast<const float*>(x),
449
+ 1,
450
+ dis.get(),
451
+ assign.get());
429
452
  } else {
430
453
  // search by blocks of decode_block_size vectors
431
- size_t code_size = codec->sa_code_size ();
454
+ size_t code_size = codec->sa_code_size();
432
455
  for (size_t i0 = 0; i0 < nx; i0 += decode_block_size) {
433
456
  size_t i1 = i0 + decode_block_size;
434
- if (i1 > nx) { i1 = nx; }
435
- codec->sa_decode (i1 - i0, x + code_size * i0,
436
- decode_buffer.data ());
437
- index.search (i1 - i0, decode_buffer.data (), 1,
438
- dis.get() + i0, assign.get() + i0);
457
+ if (i1 > nx) {
458
+ i1 = nx;
459
+ }
460
+ codec->sa_decode(
461
+ i1 - i0, x + code_size * i0, decode_buffer.data());
462
+ index.search(
463
+ i1 - i0,
464
+ decode_buffer.data(),
465
+ 1,
466
+ dis.get() + i0,
467
+ assign.get() + i0);
439
468
  }
440
469
  }
441
470
 
@@ -449,61 +478,71 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
449
478
  }
450
479
 
451
480
  // update the centroids
452
- std::vector<float> hassign (k);
481
+ std::vector<float> hassign(k);
453
482
 
454
483
  size_t k_frozen = frozen_centroids ? n_input_centroids : 0;
455
- compute_centroids (
456
- d, k, nx, k_frozen,
457
- x, codec, assign.get(), weights,
458
- hassign.data(), centroids.data()
459
- );
460
-
461
- int nsplit = split_clusters (
462
- d, k, nx, k_frozen,
463
- hassign.data(), centroids.data()
464
- );
484
+ compute_centroids(
485
+ d,
486
+ k,
487
+ nx,
488
+ k_frozen,
489
+ x,
490
+ codec,
491
+ assign.get(),
492
+ weights,
493
+ hassign.data(),
494
+ centroids.data());
495
+
496
+ int nsplit = split_clusters(
497
+ d, k, nx, k_frozen, hassign.data(), centroids.data());
465
498
 
466
499
  // collect statistics
467
- ClusteringIterationStats stats =
468
- { obj, (getmillisecs() - t0) / 1000.0,
469
- t_search_tot / 1000,
470
- imbalance_factor (nx, k, assign.get()),
471
- nsplit };
500
+ ClusteringIterationStats stats = {
501
+ obj,
502
+ (getmillisecs() - t0) / 1000.0,
503
+ t_search_tot / 1000,
504
+ imbalance_factor(nx, k, assign.get()),
505
+ nsplit};
472
506
  iteration_stats.push_back(stats);
473
507
 
474
508
  if (verbose) {
475
- printf (" Iteration %d (%.2f s, search %.2f s): "
476
- "objective=%g imbalance=%.3f nsplit=%d \r",
477
- i, stats.time, stats.time_search, stats.obj,
478
- stats.imbalance_factor, nsplit);
479
- fflush (stdout);
509
+ printf(" Iteration %d (%.2f s, search %.2f s): "
510
+ "objective=%g imbalance=%.3f nsplit=%d \r",
511
+ i,
512
+ stats.time,
513
+ stats.time_search,
514
+ stats.obj,
515
+ stats.imbalance_factor,
516
+ nsplit);
517
+ fflush(stdout);
480
518
  }
481
519
 
482
- post_process_centroids ();
520
+ post_process_centroids();
483
521
 
484
522
  // add centroids to index for the next iteration (or for output)
485
523
 
486
- index.reset ();
524
+ index.reset();
487
525
  if (update_index) {
488
- index.train (k, centroids.data());
526
+ index.train(k, centroids.data());
489
527
  }
490
528
 
491
- index.add (k, centroids.data());
492
- InterruptCallback::check ();
529
+ index.add(k, centroids.data());
530
+ InterruptCallback::check();
493
531
  }
494
532
 
495
- if (verbose) printf("\n");
533
+ if (verbose)
534
+ printf("\n");
496
535
  if (nredo > 1) {
497
536
  if ((lower_is_better && obj < best_obj) ||
498
537
  (!lower_is_better && obj > best_obj)) {
499
538
  if (verbose) {
500
- printf ("Objective improved: keep new clusters\n");
539
+ printf("Objective improved: keep new clusters\n");
501
540
  }
502
541
  best_centroids = centroids;
503
542
  best_iteration_stats = iteration_stats;
504
543
  best_obj = obj;
505
544
  }
506
- index.reset ();
545
+ index.reset();
507
546
  }
508
547
  }
509
548
  if (nredo > 1) {
@@ -512,20 +551,120 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
512
551
  index.reset();
513
552
  index.add(k, best_centroids.data());
514
553
  }
515
-
516
554
  }
517
555
 
518
- float kmeans_clustering (size_t d, size_t n, size_t k,
519
- const float *x,
520
- float *centroids)
521
- {
522
- Clustering clus (d, k);
556
+ float kmeans_clustering(
557
+ size_t d,
558
+ size_t n,
559
+ size_t k,
560
+ const float* x,
561
+ float* centroids) {
562
+ Clustering clus(d, k);
523
563
  clus.verbose = d * n * k > (1L << 30);
524
564
  // display logs if > 1Gflop per iteration
525
- IndexFlatL2 index (d);
526
- clus.train (n, x, index);
565
+ IndexFlatL2 index(d);
566
+ clus.train(n, x, index);
527
567
  memcpy(centroids, clus.centroids.data(), sizeof(*centroids) * d * k);
528
568
  return clus.iteration_stats.back().obj;
529
569
  }
530
570
 
571
+ /******************************************************************************
572
+ * ProgressiveDimClustering implementation
573
+ ******************************************************************************/
574
+
575
+ ProgressiveDimClusteringParameters::ProgressiveDimClusteringParameters() {
576
+ progressive_dim_steps = 10;
577
+ apply_pca = true; // seems a good idea to do this by default
578
+ niter = 10; // reduce nb of iterations per step
579
+ }
580
+
581
+ Index* ProgressiveDimIndexFactory::operator()(int dim) {
582
+ return new IndexFlatL2(dim);
583
+ }
584
+
585
+ ProgressiveDimClustering::ProgressiveDimClustering(int d, int k) : d(d), k(k) {}
586
+
587
+ ProgressiveDimClustering::ProgressiveDimClustering(
588
+ int d,
589
+ int k,
590
+ const ProgressiveDimClusteringParameters& cp)
591
+ : ProgressiveDimClusteringParameters(cp), d(d), k(k) {}
592
+
593
+ namespace {
594
+
595
+ using idx_t = Index::idx_t;
596
+
597
+ void copy_columns(idx_t n, idx_t d1, const float* src, idx_t d2, float* dest) {
598
+ idx_t d = std::min(d1, d2);
599
+ for (idx_t i = 0; i < n; i++) {
600
+ memcpy(dest, src, sizeof(float) * d);
601
+ src += d1;
602
+ dest += d2;
603
+ }
604
+ }
605
+
606
+ }; // namespace
607
+
608
+ void ProgressiveDimClustering::train(
609
+ idx_t n,
610
+ const float* x,
611
+ ProgressiveDimIndexFactory& factory) {
612
+ int d_prev = 0;
613
+
614
+ PCAMatrix pca(d, d);
615
+
616
+ std::vector<float> xbuf;
617
+ if (apply_pca) {
618
+ if (verbose) {
619
+ printf("Training PCA transform\n");
620
+ }
621
+ pca.train(n, x);
622
+ if (verbose) {
623
+ printf("Apply PCA\n");
624
+ }
625
+ xbuf.resize(n * d);
626
+ pca.apply_noalloc(n, x, xbuf.data());
627
+ x = xbuf.data();
628
+ }
629
+
630
+ for (int iter = 0; iter < progressive_dim_steps; iter++) {
631
+ int di = int(pow(d, (1. + iter) / progressive_dim_steps));
632
+ if (verbose) {
633
+ printf("Progressive dim step %d: cluster in dimension %d\n",
634
+ iter,
635
+ di);
636
+ }
637
+ std::unique_ptr<Index> clustering_index(factory(di));
638
+
639
+ Clustering clus(di, k, *this);
640
+ if (d_prev > 0) {
641
+ // copy warm-start centroids (padded with 0s)
642
+ clus.centroids.resize(k * di);
643
+ copy_columns(
644
+ k, d_prev, centroids.data(), di, clus.centroids.data());
645
+ }
646
+ std::vector<float> xsub(n * di);
647
+ copy_columns(n, d, x, di, xsub.data());
648
+
649
+ clus.train(n, xsub.data(), *clustering_index.get());
650
+
651
+ centroids = clus.centroids;
652
+ iteration_stats.insert(
653
+ iteration_stats.end(),
654
+ clus.iteration_stats.begin(),
655
+ clus.iteration_stats.end());
656
+
657
+ d_prev = di;
658
+ }
659
+
660
+ if (apply_pca) {
661
+ if (verbose) {
662
+ printf("Revert PCA transform on centroids\n");
663
+ }
664
+ std::vector<float> cent_transformed(d * k);
665
+ pca.reverse_transform(k, centroids.data(), cent_transformed.data());
666
+ cent_transformed.swap(centroids);
667
+ }
668
+ }
669
+
531
670
  } // namespace faiss