faiss 0.1.5 → 0.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +24 -0
  3. data/README.md +12 -0
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +6 -2
  6. data/ext/faiss/index.cpp +114 -43
  7. data/ext/faiss/index_binary.cpp +24 -30
  8. data/ext/faiss/kmeans.cpp +20 -16
  9. data/ext/faiss/numo.hpp +867 -0
  10. data/ext/faiss/pca_matrix.cpp +13 -14
  11. data/ext/faiss/product_quantizer.cpp +23 -24
  12. data/ext/faiss/utils.cpp +10 -37
  13. data/ext/faiss/utils.h +2 -13
  14. data/lib/faiss.rb +0 -5
  15. data/lib/faiss/version.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  17. data/vendor/faiss/faiss/AutoTune.h +55 -56
  18. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  19. data/vendor/faiss/faiss/Clustering.h +88 -35
  20. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  21. data/vendor/faiss/faiss/IVFlib.h +48 -51
  22. data/vendor/faiss/faiss/Index.cpp +85 -103
  23. data/vendor/faiss/faiss/Index.h +54 -48
  24. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  25. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  26. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  27. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  28. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  29. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  30. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  31. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  32. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  33. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  34. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  35. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  36. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  37. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  38. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  39. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  40. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  41. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  42. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  43. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  54. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  55. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  69. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  73. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  76. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  77. data/vendor/faiss/faiss/IndexShards.h +85 -73
  78. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  79. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  81. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  82. data/vendor/faiss/faiss/MetricType.h +7 -7
  83. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  84. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  85. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  86. data/vendor/faiss/faiss/clone_index.h +4 -9
  87. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  88. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  89. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  90. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  91. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  96. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  102. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  103. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  104. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  106. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  108. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  110. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  112. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  113. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  114. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  115. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  116. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  121. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  122. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  124. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  125. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  126. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  128. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  129. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  130. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  131. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  133. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  135. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  136. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  137. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  138. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  139. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  140. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  142. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  144. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  145. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  146. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  148. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  149. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  151. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  153. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  154. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  156. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  157. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  158. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  159. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  160. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  161. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  162. data/vendor/faiss/faiss/impl/io.h +31 -41
  163. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  164. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  165. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  166. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  167. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  171. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  172. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  173. data/vendor/faiss/faiss/index_factory.h +6 -7
  174. data/vendor/faiss/faiss/index_io.h +23 -26
  175. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  177. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  178. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  179. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  180. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  181. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  183. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  185. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  186. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  187. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  188. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  189. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  190. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  191. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  192. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  193. data/vendor/faiss/faiss/utils/distances.h +133 -118
  194. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  195. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  196. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  197. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  198. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  199. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  200. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  201. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  202. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  203. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  204. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  205. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  206. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  207. data/vendor/faiss/faiss/utils/random.h +13 -16
  208. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  209. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  210. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  211. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  212. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  213. data/vendor/faiss/faiss/utils/utils.h +53 -48
  214. metadata +24 -10
  215. data/lib/faiss/index.rb +0 -20
  216. data/lib/faiss/index_binary.rb +0 -20
  217. data/lib/faiss/kmeans.rb +0 -15
  218. data/lib/faiss/pca_matrix.rb +0 -15
  219. data/lib/faiss/product_quantizer.rb +0 -22
@@ -10,38 +10,39 @@
10
10
  #ifndef FAISS_AUTO_TUNE_H
11
11
  #define FAISS_AUTO_TUNE_H
12
12
 
13
- #include <vector>
14
- #include <unordered_map>
15
13
  #include <stdint.h>
14
+ #include <unordered_map>
15
+ #include <vector>
16
16
 
17
17
  #include <faiss/Index.h>
18
18
  #include <faiss/IndexBinary.h>
19
19
 
20
20
  namespace faiss {
21
21
 
22
-
23
22
  /**
24
23
  * Evaluation criterion. Returns a performance measure in [0,1],
25
24
  * higher is better.
26
25
  */
27
26
  struct AutoTuneCriterion {
28
27
  typedef Index::idx_t idx_t;
29
- idx_t nq; ///< nb of queries this criterion is evaluated on
30
- idx_t nnn; ///< nb of NNs that the query should request
31
- idx_t gt_nnn; ///< nb of GT NNs required to evaluate crterion
28
+ idx_t nq; ///< nb of queries this criterion is evaluated on
29
+ idx_t nnn; ///< nb of NNs that the query should request
30
+ idx_t gt_nnn; ///< nb of GT NNs required to evaluate criterion
32
31
 
33
- std::vector<float> gt_D; ///< Ground-truth distances (size nq * gt_nnn)
34
- std::vector<idx_t> gt_I; ///< Ground-truth indexes (size nq * gt_nnn)
32
+ std::vector<float> gt_D; ///< Ground-truth distances (size nq * gt_nnn)
33
+ std::vector<idx_t> gt_I; ///< Ground-truth indexes (size nq * gt_nnn)
35
34
 
36
- AutoTuneCriterion (idx_t nq, idx_t nnn);
35
+ AutoTuneCriterion(idx_t nq, idx_t nnn);
37
36
 
38
37
  /** Initializes the gt_D and gt_I vectors. Must be called before evaluating
39
38
  *
40
39
  * @param gt_D_in size nq * gt_nnn
41
40
  * @param gt_I_in size nq * gt_nnn
42
41
  */
43
- void set_groundtruth (int gt_nnn, const float *gt_D_in,
44
- const idx_t *gt_I_in);
42
+ void set_groundtruth(
43
+ int gt_nnn,
44
+ const float* gt_D_in,
45
+ const idx_t* gt_I_in);
45
46
 
46
47
  /** Evaluate the criterion.
47
48
  *
@@ -49,29 +50,25 @@ struct AutoTuneCriterion {
49
50
  * @param I size nq * nnn
50
51
  * @return the criterion, between 0 and 1. Larger is better.
51
52
  */
52
- virtual double evaluate (const float *D, const idx_t *I) const = 0;
53
-
54
- virtual ~AutoTuneCriterion () {}
53
+ virtual double evaluate(const float* D, const idx_t* I) const = 0;
55
54
 
55
+ virtual ~AutoTuneCriterion() {}
56
56
  };
57
57
 
58
- struct OneRecallAtRCriterion: AutoTuneCriterion {
59
-
58
+ struct OneRecallAtRCriterion : AutoTuneCriterion {
60
59
  idx_t R;
61
60
 
62
- OneRecallAtRCriterion (idx_t nq, idx_t R);
61
+ OneRecallAtRCriterion(idx_t nq, idx_t R);
63
62
 
64
63
  double evaluate(const float* D, const idx_t* I) const override;
65
64
 
66
65
  ~OneRecallAtRCriterion() override {}
67
66
  };
68
67
 
69
-
70
- struct IntersectionCriterion: AutoTuneCriterion {
71
-
68
+ struct IntersectionCriterion : AutoTuneCriterion {
72
69
  idx_t R;
73
70
 
74
- IntersectionCriterion (idx_t nq, idx_t R);
71
+ IntersectionCriterion(idx_t nq, idx_t R);
75
72
 
76
73
  double evaluate(const float* D, const idx_t* I) const override;
77
74
 
@@ -91,7 +88,7 @@ struct OperatingPoint {
91
88
  double perf; ///< performance measure (output of a Criterion)
92
89
  double t; ///< corresponding execution time (ms)
93
90
  std::string key; ///< key that identifies this op pt
94
- int64_t cno; ///< integer identifer
91
+ int64_t cno; ///< integer identifer
95
92
  };
96
93
 
97
94
  struct OperatingPoints {
@@ -102,27 +99,27 @@ struct OperatingPoints {
102
99
  std::vector<OperatingPoint> optimal_pts;
103
100
 
104
101
  // begins with a single operating point: t=0, perf=0
105
- OperatingPoints ();
102
+ OperatingPoints();
106
103
 
107
104
  /// add operating points from other to this, with a prefix to the keys
108
- int merge_with (const OperatingPoints &other,
109
- const std::string & prefix = "");
105
+ int merge_with(
106
+ const OperatingPoints& other,
107
+ const std::string& prefix = "");
110
108
 
111
- void clear ();
109
+ void clear();
112
110
 
113
111
  /// add a performance measure. Return whether it is an optimal point
114
- bool add (double perf, double t, const std::string & key, size_t cno = 0);
112
+ bool add(double perf, double t, const std::string& key, size_t cno = 0);
115
113
 
116
114
  /// get time required to obtain a given performance measure
117
- double t_for_perf (double perf) const;
115
+ double t_for_perf(double perf) const;
118
116
 
119
117
  /// easy-to-read output
120
- void display (bool only_optimal = true) const;
118
+ void display(bool only_optimal = true) const;
121
119
 
122
120
  /// output to a format easy to digest by gnuplot
123
- void all_to_gnuplot (const char *fname) const;
124
- void optimal_to_gnuplot (const char *fname) const;
125
-
121
+ void all_to_gnuplot(const char* fname) const;
122
+ void optimal_to_gnuplot(const char* fname) const;
126
123
  };
127
124
 
128
125
  /// possible values of a parameter, sorted from least to most expensive/accurate
@@ -156,41 +153,45 @@ struct ParameterSpace {
156
153
  /// duration (to avoid jittering in MT mode)
157
154
  double min_test_duration;
158
155
 
159
- ParameterSpace ();
156
+ ParameterSpace();
160
157
 
161
158
  /// nb of combinations, = product of values sizes
162
- size_t n_combinations () const;
159
+ size_t n_combinations() const;
163
160
 
164
161
  /// returns whether combinations c1 >= c2 in the tuple sense
165
- bool combination_ge (size_t c1, size_t c2) const;
162
+ bool combination_ge(size_t c1, size_t c2) const;
166
163
 
167
164
  /// get string representation of the combination
168
- std::string combination_name (size_t cno) const;
165
+ std::string combination_name(size_t cno) const;
169
166
 
170
167
  /// print a description on stdout
171
- void display () const;
168
+ void display() const;
172
169
 
173
170
  /// add a new parameter (or return it if it exists)
174
- ParameterRange &add_range(const std::string & name);
171
+ ParameterRange& add_range(const std::string& name);
175
172
 
176
173
  /// initialize with reasonable parameters for the index
177
- virtual void initialize (const Index * index);
174
+ virtual void initialize(const Index* index);
178
175
 
179
176
  /// set a combination of parameters on an index
180
- void set_index_parameters (Index *index, size_t cno) const;
177
+ void set_index_parameters(Index* index, size_t cno) const;
181
178
 
182
179
  /// set a combination of parameters described by a string
183
- void set_index_parameters (Index *index, const char *param_string) const;
180
+ void set_index_parameters(Index* index, const char* param_string) const;
184
181
 
185
182
  /// set one of the parameters, returns whether setting was successful
186
- virtual void set_index_parameter (
187
- Index * index, const std::string & name, double val) const;
183
+ virtual void set_index_parameter(
184
+ Index* index,
185
+ const std::string& name,
186
+ double val) const;
188
187
 
189
188
  /** find an upper bound on the performance and a lower bound on t
190
189
  * for configuration cno given another operating point op */
191
- void update_bounds (size_t cno, const OperatingPoint & op,
192
- double *upper_bound_perf,
193
- double *lower_bound_t) const;
190
+ void update_bounds(
191
+ size_t cno,
192
+ const OperatingPoint& op,
193
+ double* upper_bound_perf,
194
+ double* lower_bound_t) const;
194
195
 
195
196
  /** explore operating points
196
197
  * @param index index to run on
@@ -198,18 +199,16 @@ struct ParameterSpace {
198
199
  * @param crit selection criterion
199
200
  * @param ops resulting operating points
200
201
  */
201
- void explore (Index *index,
202
- size_t nq, const float *xq,
203
- const AutoTuneCriterion & crit,
204
- OperatingPoints * ops) const;
205
-
206
- virtual ~ParameterSpace () {}
202
+ void explore(
203
+ Index* index,
204
+ size_t nq,
205
+ const float* xq,
206
+ const AutoTuneCriterion& crit,
207
+ OperatingPoints* ops) const;
208
+
209
+ virtual ~ParameterSpace() {}
207
210
  };
208
211
 
209
-
210
-
211
212
  } // namespace faiss
212
213
 
213
-
214
-
215
214
  #endif
@@ -8,6 +8,7 @@
8
8
  // -*- c++ -*-
9
9
 
10
10
  #include <faiss/Clustering.h>
11
+ #include <faiss/VectorTransform.h>
11
12
  #include <faiss/impl/AuxIndexStructures.h>
12
13
 
13
14
  #include <cinttypes>
@@ -17,100 +18,100 @@
17
18
 
18
19
  #include <omp.h>
19
20
 
20
- #include <faiss/utils/utils.h>
21
- #include <faiss/utils/random.h>
22
- #include <faiss/utils/distances.h>
23
- #include <faiss/impl/FaissAssert.h>
24
21
  #include <faiss/IndexFlat.h>
22
+ #include <faiss/impl/FaissAssert.h>
23
+ #include <faiss/utils/distances.h>
24
+ #include <faiss/utils/random.h>
25
+ #include <faiss/utils/utils.h>
25
26
 
26
27
  namespace faiss {
27
28
 
28
- ClusteringParameters::ClusteringParameters ():
29
- niter(25),
30
- nredo(1),
31
- verbose(false),
32
- spherical(false),
33
- int_centroids(false),
34
- update_index(false),
35
- frozen_centroids(false),
36
- min_points_per_centroid(39),
37
- max_points_per_centroid(256),
38
- seed(1234),
39
- decode_block_size(32768)
40
- {}
29
+ ClusteringParameters::ClusteringParameters()
30
+ : niter(25),
31
+ nredo(1),
32
+ verbose(false),
33
+ spherical(false),
34
+ int_centroids(false),
35
+ update_index(false),
36
+ frozen_centroids(false),
37
+ min_points_per_centroid(39),
38
+ max_points_per_centroid(256),
39
+ seed(1234),
40
+ decode_block_size(32768) {}
41
41
  // 39 corresponds to 10000 / 256 -> to avoid warnings on PQ tests with randu10k
42
42
 
43
+ Clustering::Clustering(int d, int k) : d(d), k(k) {}
43
44
 
44
- Clustering::Clustering (int d, int k):
45
- d(d), k(k) {}
46
-
47
- Clustering::Clustering (int d, int k, const ClusteringParameters &cp):
48
- ClusteringParameters (cp), d(d), k(k) {}
45
+ Clustering::Clustering(int d, int k, const ClusteringParameters& cp)
46
+ : ClusteringParameters(cp), d(d), k(k) {}
49
47
 
50
-
51
-
52
- static double imbalance_factor (int n, int k, int64_t *assign) {
48
+ static double imbalance_factor(int n, int k, int64_t* assign) {
53
49
  std::vector<int> hist(k, 0);
54
50
  for (int i = 0; i < n; i++)
55
51
  hist[assign[i]]++;
56
52
 
57
53
  double tot = 0, uf = 0;
58
54
 
59
- for (int i = 0 ; i < k ; i++) {
55
+ for (int i = 0; i < k; i++) {
60
56
  tot += hist[i];
61
- uf += hist[i] * (double) hist[i];
57
+ uf += hist[i] * (double)hist[i];
62
58
  }
63
59
  uf = uf * k / (tot * tot);
64
60
 
65
61
  return uf;
66
62
  }
67
63
 
68
- void Clustering::post_process_centroids ()
69
- {
70
-
64
+ void Clustering::post_process_centroids() {
71
65
  if (spherical) {
72
- fvec_renorm_L2 (d, k, centroids.data());
66
+ fvec_renorm_L2(d, k, centroids.data());
73
67
  }
74
68
 
75
69
  if (int_centroids) {
76
70
  for (size_t i = 0; i < centroids.size(); i++)
77
- centroids[i] = roundf (centroids[i]);
71
+ centroids[i] = roundf(centroids[i]);
78
72
  }
79
73
  }
80
74
 
81
-
82
- void Clustering::train (idx_t nx, const float *x_in, Index & index,
83
- const float *weights) {
84
- train_encoded (nx, reinterpret_cast<const uint8_t *>(x_in), nullptr,
85
- index, weights);
75
+ void Clustering::train(
76
+ idx_t nx,
77
+ const float* x_in,
78
+ Index& index,
79
+ const float* weights) {
80
+ train_encoded(
81
+ nx,
82
+ reinterpret_cast<const uint8_t*>(x_in),
83
+ nullptr,
84
+ index,
85
+ weights);
86
86
  }
87
87
 
88
-
89
88
  namespace {
90
89
 
91
90
  using idx_t = Clustering::idx_t;
92
91
 
93
92
  idx_t subsample_training_set(
94
- const Clustering &clus, idx_t nx, const uint8_t *x,
95
- size_t line_size, const float * weights,
96
- uint8_t **x_out,
97
- float **weights_out
98
- )
99
- {
93
+ const Clustering& clus,
94
+ idx_t nx,
95
+ const uint8_t* x,
96
+ size_t line_size,
97
+ const float* weights,
98
+ uint8_t** x_out,
99
+ float** weights_out) {
100
100
  if (clus.verbose) {
101
101
  printf("Sampling a subset of %zd / %" PRId64 " for training\n",
102
- clus.k * clus.max_points_per_centroid, nx);
102
+ clus.k * clus.max_points_per_centroid,
103
+ nx);
103
104
  }
104
- std::vector<int> perm (nx);
105
- rand_perm (perm.data (), nx, clus.seed);
105
+ std::vector<int> perm(nx);
106
+ rand_perm(perm.data(), nx, clus.seed);
106
107
  nx = clus.k * clus.max_points_per_centroid;
107
- uint8_t * x_new = new uint8_t [nx * line_size];
108
+ uint8_t* x_new = new uint8_t[nx * line_size];
108
109
  *x_out = x_new;
109
110
  for (idx_t i = 0; i < nx; i++) {
110
- memcpy (x_new + i * line_size, x + perm[i] * line_size, line_size);
111
+ memcpy(x_new + i * line_size, x + perm[i] * line_size, line_size);
111
112
  }
112
113
  if (weights) {
113
- float *weights_new = new float[nx];
114
+ float* weights_new = new float[nx];
114
115
  for (idx_t i = 0; i < nx; i++) {
115
116
  weights_new[i] = weights[perm[i]];
116
117
  }
@@ -134,20 +135,23 @@ idx_t subsample_training_set(
134
135
  *
135
136
  */
136
137
 
137
- void compute_centroids (size_t d, size_t k, size_t n,
138
- size_t k_frozen,
139
- const uint8_t * x, const Index *codec,
140
- const int64_t * assign,
141
- const float * weights,
142
- float * hassign,
143
- float * centroids)
144
- {
138
+ void compute_centroids(
139
+ size_t d,
140
+ size_t k,
141
+ size_t n,
142
+ size_t k_frozen,
143
+ const uint8_t* x,
144
+ const Index* codec,
145
+ const int64_t* assign,
146
+ const float* weights,
147
+ float* hassign,
148
+ float* centroids) {
145
149
  k -= k_frozen;
146
150
  centroids += k_frozen * d;
147
151
 
148
- memset (centroids, 0, sizeof(*centroids) * d * k);
152
+ memset(centroids, 0, sizeof(*centroids) * d * k);
149
153
 
150
- size_t line_size = codec ? codec->sa_code_size() : d * sizeof (float);
154
+ size_t line_size = codec ? codec->sa_code_size() : d * sizeof(float);
151
155
 
152
156
  #pragma omp parallel
153
157
  {
@@ -157,20 +161,20 @@ void compute_centroids (size_t d, size_t k, size_t n,
157
161
  // this thread is taking care of centroids c0:c1
158
162
  size_t c0 = (k * rank) / nt;
159
163
  size_t c1 = (k * (rank + 1)) / nt;
160
- std::vector<float> decode_buffer (d);
164
+ std::vector<float> decode_buffer(d);
161
165
 
162
166
  for (size_t i = 0; i < n; i++) {
163
167
  int64_t ci = assign[i];
164
- assert (ci >= 0 && ci < k + k_frozen);
168
+ assert(ci >= 0 && ci < k + k_frozen);
165
169
  ci -= k_frozen;
166
- if (ci >= c0 && ci < c1) {
167
- float * c = centroids + ci * d;
168
- const float * xi;
170
+ if (ci >= c0 && ci < c1) {
171
+ float* c = centroids + ci * d;
172
+ const float* xi;
169
173
  if (!codec) {
170
174
  xi = reinterpret_cast<const float*>(x + i * line_size);
171
175
  } else {
172
- float *xif = decode_buffer.data();
173
- codec->sa_decode (1, x + i * line_size, xif);
176
+ float* xif = decode_buffer.data();
177
+ codec->sa_decode(1, x + i * line_size, xif);
174
178
  xi = xif;
175
179
  }
176
180
  if (weights) {
@@ -187,7 +191,6 @@ void compute_centroids (size_t d, size_t k, size_t n,
187
191
  }
188
192
  }
189
193
  }
190
-
191
194
  }
192
195
 
193
196
  #pragma omp parallel for
@@ -196,12 +199,11 @@ void compute_centroids (size_t d, size_t k, size_t n,
196
199
  continue;
197
200
  }
198
201
  float norm = 1 / hassign[ci];
199
- float * c = centroids + ci * d;
202
+ float* c = centroids + ci * d;
200
203
  for (size_t j = 0; j < d; j++) {
201
204
  c[j] *= norm;
202
205
  }
203
206
  }
204
-
205
207
  }
206
208
 
207
209
  // a bit above machine epsilon for float16
@@ -214,29 +216,33 @@ void compute_centroids (size_t d, size_t k, size_t n,
214
216
  *
215
217
  * @return nb of splitting operations (larger is worse)
216
218
  */
217
- int split_clusters (size_t d, size_t k, size_t n,
218
- size_t k_frozen,
219
- float * hassign,
220
- float * centroids)
221
- {
219
+ int split_clusters(
220
+ size_t d,
221
+ size_t k,
222
+ size_t n,
223
+ size_t k_frozen,
224
+ float* hassign,
225
+ float* centroids) {
222
226
  k -= k_frozen;
223
227
  centroids += k_frozen * d;
224
228
 
225
229
  /* Take care of void clusters */
226
230
  size_t nsplit = 0;
227
- RandomGenerator rng (1234);
231
+ RandomGenerator rng(1234);
228
232
  for (size_t ci = 0; ci < k; ci++) {
229
233
  if (hassign[ci] == 0) { /* need to redefine a centroid */
230
234
  size_t cj;
231
235
  for (cj = 0; 1; cj = (cj + 1) % k) {
232
236
  /* probability to pick this cluster for split */
233
- float p = (hassign[cj] - 1.0) / (float) (n - k);
234
- float r = rng.rand_float ();
237
+ float p = (hassign[cj] - 1.0) / (float)(n - k);
238
+ float r = rng.rand_float();
235
239
  if (r < p) {
236
240
  break; /* found our cluster to be split */
237
241
  }
238
242
  }
239
- memcpy (centroids+ci*d, centroids+cj*d, sizeof(*centroids) * d);
243
+ memcpy(centroids + ci * d,
244
+ centroids + cj * d,
245
+ sizeof(*centroids) * d);
240
246
 
241
247
  /* small symmetric perturbation */
242
248
  for (size_t j = 0; j < d; j++) {
@@ -257,30 +263,35 @@ int split_clusters (size_t d, size_t k, size_t n,
257
263
  }
258
264
 
259
265
  return nsplit;
260
-
261
266
  }
262
267
 
263
-
264
-
265
- };
266
-
267
-
268
- void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
269
- const Index * codec, Index & index,
270
- const float *weights) {
271
-
272
-
273
- FAISS_THROW_IF_NOT_FMT (nx >= k,
274
- "Number of training points (%" PRId64 ") should be at least "
275
- "as large as number of clusters (%zd)", nx, k);
276
-
277
- FAISS_THROW_IF_NOT_FMT ((!codec || codec->d == d),
278
- "Codec dimension %d not the same as data dimension %d",
279
- int(codec->d), int(d));
280
-
281
- FAISS_THROW_IF_NOT_FMT (index.d == d,
268
+ }; // namespace
269
+
270
+ void Clustering::train_encoded(
271
+ idx_t nx,
272
+ const uint8_t* x_in,
273
+ const Index* codec,
274
+ Index& index,
275
+ const float* weights) {
276
+ FAISS_THROW_IF_NOT_FMT(
277
+ nx >= k,
278
+ "Number of training points (%" PRId64
279
+ ") should be at least "
280
+ "as large as number of clusters (%zd)",
281
+ nx,
282
+ k);
283
+
284
+ FAISS_THROW_IF_NOT_FMT(
285
+ (!codec || codec->d == d),
286
+ "Codec dimension %d not the same as data dimension %d",
287
+ int(codec->d),
288
+ int(d));
289
+
290
+ FAISS_THROW_IF_NOT_FMT(
291
+ index.d == d,
282
292
  "Index dimension %d not the same as data dimension %d",
283
- int(index.d), int(d));
293
+ int(index.d),
294
+ int(d));
284
295
 
285
296
  double t0 = getmillisecs();
286
297
 
@@ -288,67 +299,78 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
288
299
  // Check for NaNs in input data. Normally it is the user's
289
300
  // responsibility, but it may spare us some hard-to-debug
290
301
  // reports.
291
- const float *x = reinterpret_cast<const float *>(x_in);
302
+ const float* x = reinterpret_cast<const float*>(x_in);
292
303
  for (size_t i = 0; i < nx * d; i++) {
293
- FAISS_THROW_IF_NOT_MSG (std::isfinite (x[i]),
294
- "input contains NaN's or Inf's");
304
+ FAISS_THROW_IF_NOT_MSG(
305
+ std::isfinite(x[i]), "input contains NaN's or Inf's");
295
306
  }
296
307
  }
297
308
 
298
- const uint8_t *x = x_in;
299
- std::unique_ptr<uint8_t []> del1;
300
- std::unique_ptr<float []> del3;
309
+ const uint8_t* x = x_in;
310
+ std::unique_ptr<uint8_t[]> del1;
311
+ std::unique_ptr<float[]> del3;
301
312
  size_t line_size = codec ? codec->sa_code_size() : sizeof(float) * d;
302
313
 
303
314
  if (nx > k * max_points_per_centroid) {
304
- uint8_t *x_new;
305
- float *weights_new;
306
- nx = subsample_training_set (*this, nx, x, line_size, weights,
307
- &x_new, &weights_new);
308
- del1.reset (x_new); x = x_new;
309
- del3.reset (weights_new); weights = weights_new;
315
+ uint8_t* x_new;
316
+ float* weights_new;
317
+ nx = subsample_training_set(
318
+ *this, nx, x, line_size, weights, &x_new, &weights_new);
319
+ del1.reset(x_new);
320
+ x = x_new;
321
+ del3.reset(weights_new);
322
+ weights = weights_new;
310
323
  } else if (nx < k * min_points_per_centroid) {
311
- fprintf (stderr,
312
- "WARNING clustering %" PRId64 " points to %zd centroids: "
313
- "please provide at least %" PRId64 " training points\n",
314
- nx, k, idx_t(k) * min_points_per_centroid);
324
+ fprintf(stderr,
325
+ "WARNING clustering %" PRId64
326
+ " points to %zd centroids: "
327
+ "please provide at least %" PRId64 " training points\n",
328
+ nx,
329
+ k,
330
+ idx_t(k) * min_points_per_centroid);
315
331
  }
316
332
 
317
333
  if (nx == k) {
318
334
  // this is a corner case, just copy training set to clusters
319
335
  if (verbose) {
320
- printf("Number of training points (%" PRId64 ") same as number of "
321
- "clusters, just copying\n", nx);
336
+ printf("Number of training points (%" PRId64
337
+ ") same as number of "
338
+ "clusters, just copying\n",
339
+ nx);
322
340
  }
323
- centroids.resize (d * k);
341
+ centroids.resize(d * k);
324
342
  if (!codec) {
325
- memcpy (centroids.data(), x_in, sizeof (float) * d * k);
343
+ memcpy(centroids.data(), x_in, sizeof(float) * d * k);
326
344
  } else {
327
- codec->sa_decode (nx, x_in, centroids.data());
345
+ codec->sa_decode(nx, x_in, centroids.data());
328
346
  }
329
347
 
330
348
  // one fake iteration...
331
- ClusteringIterationStats stats = { 0.0, 0.0, 0.0, 1.0, 0 };
332
- iteration_stats.push_back (stats);
349
+ ClusteringIterationStats stats = {0.0, 0.0, 0.0, 1.0, 0};
350
+ iteration_stats.push_back(stats);
333
351
 
334
352
  index.reset();
335
353
  index.add(k, centroids.data());
336
354
  return;
337
355
  }
338
356
 
339
-
340
357
  if (verbose) {
341
- printf("Clustering %" PRId64 " points in %zdD to %zd clusters, "
358
+ printf("Clustering %" PRId64
359
+ " points in %zdD to %zd clusters, "
342
360
  "redo %d times, %d iterations\n",
343
- nx, d, k, nredo, niter);
361
+ nx,
362
+ d,
363
+ k,
364
+ nredo,
365
+ niter);
344
366
  if (codec) {
345
367
  printf("Input data encoded in %zd bytes per vector\n",
346
- codec->sa_code_size ());
368
+ codec->sa_code_size());
347
369
  }
348
370
  }
349
371
 
350
- std::unique_ptr<idx_t []> assign(new idx_t[nx]);
351
- std::unique_ptr<float []> dis(new float[nx]);
372
+ std::unique_ptr<idx_t[]> assign(new idx_t[nx]);
373
+ std::unique_ptr<float[]> dis(new float[nx]);
352
374
 
353
375
  // remember best iteration for redo
354
376
  bool lower_is_better = index.metric_type != METRIC_INNER_PRODUCT;
@@ -358,52 +380,49 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
358
380
 
359
381
  // support input centroids
360
382
 
361
- FAISS_THROW_IF_NOT_MSG (
362
- centroids.size() % d == 0,
363
- "size of provided input centroids not a multiple of dimension"
364
- );
383
+ FAISS_THROW_IF_NOT_MSG(
384
+ centroids.size() % d == 0,
385
+ "size of provided input centroids not a multiple of dimension");
365
386
 
366
387
  size_t n_input_centroids = centroids.size() / d;
367
388
 
368
389
  if (verbose && n_input_centroids > 0) {
369
- printf (" Using %zd centroids provided as input (%sfrozen)\n",
370
- n_input_centroids, frozen_centroids ? "" : "not ");
390
+ printf(" Using %zd centroids provided as input (%sfrozen)\n",
391
+ n_input_centroids,
392
+ frozen_centroids ? "" : "not ");
371
393
  }
372
394
 
373
395
  double t_search_tot = 0;
374
396
  if (verbose) {
375
- printf(" Preprocessing in %.2f s\n",
376
- (getmillisecs() - t0) / 1000.);
397
+ printf(" Preprocessing in %.2f s\n", (getmillisecs() - t0) / 1000.);
377
398
  }
378
399
  t0 = getmillisecs();
379
400
 
380
401
  // temporary buffer to decode vectors during the optimization
381
- std::vector<float> decode_buffer
382
- (codec ? d * decode_block_size : 0);
402
+ std::vector<float> decode_buffer(codec ? d * decode_block_size : 0);
383
403
 
384
404
  for (int redo = 0; redo < nredo; redo++) {
385
-
386
405
  if (verbose && nredo > 1) {
387
406
  printf("Outer iteration %d / %d\n", redo, nredo);
388
407
  }
389
408
 
390
409
  // initialize (remaining) centroids with random points from the dataset
391
- centroids.resize (d * k);
392
- std::vector<int> perm (nx);
410
+ centroids.resize(d * k);
411
+ std::vector<int> perm(nx);
393
412
 
394
- rand_perm (perm.data(), nx, seed + 1 + redo * 15486557L);
413
+ rand_perm(perm.data(), nx, seed + 1 + redo * 15486557L);
395
414
 
396
415
  if (!codec) {
397
- for (int i = n_input_centroids; i < k ; i++) {
398
- memcpy (&centroids[i * d], x + perm[i] * line_size, line_size);
416
+ for (int i = n_input_centroids; i < k; i++) {
417
+ memcpy(&centroids[i * d], x + perm[i] * line_size, line_size);
399
418
  }
400
419
  } else {
401
- for (int i = n_input_centroids; i < k ; i++) {
402
- codec->sa_decode (1, x + perm[i] * line_size, &centroids[i * d]);
420
+ for (int i = n_input_centroids; i < k; i++) {
421
+ codec->sa_decode(1, x + perm[i] * line_size, &centroids[i * d]);
403
422
  }
404
423
  }
405
424
 
406
- post_process_centroids ();
425
+ post_process_centroids();
407
426
 
408
427
  // prepare the index
409
428
 
@@ -412,10 +431,10 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
412
431
  }
413
432
 
414
433
  if (!index.is_trained) {
415
- index.train (k, centroids.data());
434
+ index.train(k, centroids.data());
416
435
  }
417
436
 
418
- index.add (k, centroids.data());
437
+ index.add(k, centroids.data());
419
438
 
420
439
  // k-means iterations
421
440
 
@@ -424,18 +443,28 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
424
443
  double t0s = getmillisecs();
425
444
 
426
445
  if (!codec) {
427
- index.search (nx, reinterpret_cast<const float *>(x), 1,
428
- dis.get(), assign.get());
446
+ index.search(
447
+ nx,
448
+ reinterpret_cast<const float*>(x),
449
+ 1,
450
+ dis.get(),
451
+ assign.get());
429
452
  } else {
430
453
  // search by blocks of decode_block_size vectors
431
- size_t code_size = codec->sa_code_size ();
454
+ size_t code_size = codec->sa_code_size();
432
455
  for (size_t i0 = 0; i0 < nx; i0 += decode_block_size) {
433
456
  size_t i1 = i0 + decode_block_size;
434
- if (i1 > nx) { i1 = nx; }
435
- codec->sa_decode (i1 - i0, x + code_size * i0,
436
- decode_buffer.data ());
437
- index.search (i1 - i0, decode_buffer.data (), 1,
438
- dis.get() + i0, assign.get() + i0);
457
+ if (i1 > nx) {
458
+ i1 = nx;
459
+ }
460
+ codec->sa_decode(
461
+ i1 - i0, x + code_size * i0, decode_buffer.data());
462
+ index.search(
463
+ i1 - i0,
464
+ decode_buffer.data(),
465
+ 1,
466
+ dis.get() + i0,
467
+ assign.get() + i0);
439
468
  }
440
469
  }
441
470
 
@@ -449,61 +478,71 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
449
478
  }
450
479
 
451
480
  // update the centroids
452
- std::vector<float> hassign (k);
481
+ std::vector<float> hassign(k);
453
482
 
454
483
  size_t k_frozen = frozen_centroids ? n_input_centroids : 0;
455
- compute_centroids (
456
- d, k, nx, k_frozen,
457
- x, codec, assign.get(), weights,
458
- hassign.data(), centroids.data()
459
- );
460
-
461
- int nsplit = split_clusters (
462
- d, k, nx, k_frozen,
463
- hassign.data(), centroids.data()
464
- );
484
+ compute_centroids(
485
+ d,
486
+ k,
487
+ nx,
488
+ k_frozen,
489
+ x,
490
+ codec,
491
+ assign.get(),
492
+ weights,
493
+ hassign.data(),
494
+ centroids.data());
495
+
496
+ int nsplit = split_clusters(
497
+ d, k, nx, k_frozen, hassign.data(), centroids.data());
465
498
 
466
499
  // collect statistics
467
- ClusteringIterationStats stats =
468
- { obj, (getmillisecs() - t0) / 1000.0,
469
- t_search_tot / 1000,
470
- imbalance_factor (nx, k, assign.get()),
471
- nsplit };
500
+ ClusteringIterationStats stats = {
501
+ obj,
502
+ (getmillisecs() - t0) / 1000.0,
503
+ t_search_tot / 1000,
504
+ imbalance_factor(nx, k, assign.get()),
505
+ nsplit};
472
506
  iteration_stats.push_back(stats);
473
507
 
474
508
  if (verbose) {
475
- printf (" Iteration %d (%.2f s, search %.2f s): "
476
- "objective=%g imbalance=%.3f nsplit=%d \r",
477
- i, stats.time, stats.time_search, stats.obj,
478
- stats.imbalance_factor, nsplit);
479
- fflush (stdout);
509
+ printf(" Iteration %d (%.2f s, search %.2f s): "
510
+ "objective=%g imbalance=%.3f nsplit=%d \r",
511
+ i,
512
+ stats.time,
513
+ stats.time_search,
514
+ stats.obj,
515
+ stats.imbalance_factor,
516
+ nsplit);
517
+ fflush(stdout);
480
518
  }
481
519
 
482
- post_process_centroids ();
520
+ post_process_centroids();
483
521
 
484
522
  // add centroids to index for the next iteration (or for output)
485
523
 
486
- index.reset ();
524
+ index.reset();
487
525
  if (update_index) {
488
- index.train (k, centroids.data());
526
+ index.train(k, centroids.data());
489
527
  }
490
528
 
491
- index.add (k, centroids.data());
492
- InterruptCallback::check ();
529
+ index.add(k, centroids.data());
530
+ InterruptCallback::check();
493
531
  }
494
532
 
495
- if (verbose) printf("\n");
533
+ if (verbose)
534
+ printf("\n");
496
535
  if (nredo > 1) {
497
536
  if ((lower_is_better && obj < best_obj) ||
498
537
  (!lower_is_better && obj > best_obj)) {
499
538
  if (verbose) {
500
- printf ("Objective improved: keep new clusters\n");
539
+ printf("Objective improved: keep new clusters\n");
501
540
  }
502
541
  best_centroids = centroids;
503
542
  best_iteration_stats = iteration_stats;
504
543
  best_obj = obj;
505
544
  }
506
- index.reset ();
545
+ index.reset();
507
546
  }
508
547
  }
509
548
  if (nredo > 1) {
@@ -512,20 +551,120 @@ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
512
551
  index.reset();
513
552
  index.add(k, best_centroids.data());
514
553
  }
515
-
516
554
  }
517
555
 
518
- float kmeans_clustering (size_t d, size_t n, size_t k,
519
- const float *x,
520
- float *centroids)
521
- {
522
- Clustering clus (d, k);
556
+ float kmeans_clustering(
557
+ size_t d,
558
+ size_t n,
559
+ size_t k,
560
+ const float* x,
561
+ float* centroids) {
562
+ Clustering clus(d, k);
523
563
  clus.verbose = d * n * k > (1L << 30);
524
564
  // display logs if > 1Gflop per iteration
525
- IndexFlatL2 index (d);
526
- clus.train (n, x, index);
565
+ IndexFlatL2 index(d);
566
+ clus.train(n, x, index);
527
567
  memcpy(centroids, clus.centroids.data(), sizeof(*centroids) * d * k);
528
568
  return clus.iteration_stats.back().obj;
529
569
  }
530
570
 
571
+ /******************************************************************************
572
+ * ProgressiveDimClustering implementation
573
+ ******************************************************************************/
574
+
575
+ ProgressiveDimClusteringParameters::ProgressiveDimClusteringParameters() {
576
+ progressive_dim_steps = 10;
577
+ apply_pca = true; // seems a good idea to do this by default
578
+ niter = 10; // reduce nb of iterations per step
579
+ }
580
+
581
+ Index* ProgressiveDimIndexFactory::operator()(int dim) {
582
+ return new IndexFlatL2(dim);
583
+ }
584
+
585
+ ProgressiveDimClustering::ProgressiveDimClustering(int d, int k) : d(d), k(k) {}
586
+
587
+ ProgressiveDimClustering::ProgressiveDimClustering(
588
+ int d,
589
+ int k,
590
+ const ProgressiveDimClusteringParameters& cp)
591
+ : ProgressiveDimClusteringParameters(cp), d(d), k(k) {}
592
+
593
+ namespace {
594
+
595
+ using idx_t = Index::idx_t;
596
+
597
+ void copy_columns(idx_t n, idx_t d1, const float* src, idx_t d2, float* dest) {
598
+ idx_t d = std::min(d1, d2);
599
+ for (idx_t i = 0; i < n; i++) {
600
+ memcpy(dest, src, sizeof(float) * d);
601
+ src += d1;
602
+ dest += d2;
603
+ }
604
+ }
605
+
606
+ }; // namespace
607
+
608
+ void ProgressiveDimClustering::train(
609
+ idx_t n,
610
+ const float* x,
611
+ ProgressiveDimIndexFactory& factory) {
612
+ int d_prev = 0;
613
+
614
+ PCAMatrix pca(d, d);
615
+
616
+ std::vector<float> xbuf;
617
+ if (apply_pca) {
618
+ if (verbose) {
619
+ printf("Training PCA transform\n");
620
+ }
621
+ pca.train(n, x);
622
+ if (verbose) {
623
+ printf("Apply PCA\n");
624
+ }
625
+ xbuf.resize(n * d);
626
+ pca.apply_noalloc(n, x, xbuf.data());
627
+ x = xbuf.data();
628
+ }
629
+
630
+ for (int iter = 0; iter < progressive_dim_steps; iter++) {
631
+ int di = int(pow(d, (1. + iter) / progressive_dim_steps));
632
+ if (verbose) {
633
+ printf("Progressive dim step %d: cluster in dimension %d\n",
634
+ iter,
635
+ di);
636
+ }
637
+ std::unique_ptr<Index> clustering_index(factory(di));
638
+
639
+ Clustering clus(di, k, *this);
640
+ if (d_prev > 0) {
641
+ // copy warm-start centroids (padded with 0s)
642
+ clus.centroids.resize(k * di);
643
+ copy_columns(
644
+ k, d_prev, centroids.data(), di, clus.centroids.data());
645
+ }
646
+ std::vector<float> xsub(n * di);
647
+ copy_columns(n, d, x, di, xsub.data());
648
+
649
+ clus.train(n, xsub.data(), *clustering_index.get());
650
+
651
+ centroids = clus.centroids;
652
+ iteration_stats.insert(
653
+ iteration_stats.end(),
654
+ clus.iteration_stats.begin(),
655
+ clus.iteration_stats.end());
656
+
657
+ d_prev = di;
658
+ }
659
+
660
+ if (apply_pca) {
661
+ if (verbose) {
662
+ printf("Revert PCA transform on centroids\n");
663
+ }
664
+ std::vector<float> cent_transformed(d * k);
665
+ pca.reverse_transform(k, centroids.data(), cent_transformed.data());
666
+ cent_transformed.swap(centroids);
667
+ }
668
+ }
669
+
531
670
  } // namespace faiss