faiss 0.1.0 → 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +103 -3
  4. data/ext/faiss/ext.cpp +99 -32
  5. data/ext/faiss/extconf.rb +12 -2
  6. data/lib/faiss/ext.bundle +0 -0
  7. data/lib/faiss/index.rb +3 -3
  8. data/lib/faiss/index_binary.rb +3 -3
  9. data/lib/faiss/kmeans.rb +1 -1
  10. data/lib/faiss/pca_matrix.rb +2 -2
  11. data/lib/faiss/product_quantizer.rb +3 -3
  12. data/lib/faiss/version.rb +1 -1
  13. data/vendor/faiss/AutoTune.cpp +719 -0
  14. data/vendor/faiss/AutoTune.h +212 -0
  15. data/vendor/faiss/Clustering.cpp +261 -0
  16. data/vendor/faiss/Clustering.h +101 -0
  17. data/vendor/faiss/IVFlib.cpp +339 -0
  18. data/vendor/faiss/IVFlib.h +132 -0
  19. data/vendor/faiss/Index.cpp +171 -0
  20. data/vendor/faiss/Index.h +261 -0
  21. data/vendor/faiss/Index2Layer.cpp +437 -0
  22. data/vendor/faiss/Index2Layer.h +85 -0
  23. data/vendor/faiss/IndexBinary.cpp +77 -0
  24. data/vendor/faiss/IndexBinary.h +163 -0
  25. data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
  26. data/vendor/faiss/IndexBinaryFlat.h +54 -0
  27. data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
  28. data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
  29. data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
  30. data/vendor/faiss/IndexBinaryHNSW.h +56 -0
  31. data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
  32. data/vendor/faiss/IndexBinaryIVF.h +211 -0
  33. data/vendor/faiss/IndexFlat.cpp +508 -0
  34. data/vendor/faiss/IndexFlat.h +175 -0
  35. data/vendor/faiss/IndexHNSW.cpp +1090 -0
  36. data/vendor/faiss/IndexHNSW.h +170 -0
  37. data/vendor/faiss/IndexIVF.cpp +909 -0
  38. data/vendor/faiss/IndexIVF.h +353 -0
  39. data/vendor/faiss/IndexIVFFlat.cpp +502 -0
  40. data/vendor/faiss/IndexIVFFlat.h +118 -0
  41. data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
  42. data/vendor/faiss/IndexIVFPQ.h +161 -0
  43. data/vendor/faiss/IndexIVFPQR.cpp +219 -0
  44. data/vendor/faiss/IndexIVFPQR.h +65 -0
  45. data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
  46. data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
  47. data/vendor/faiss/IndexLSH.cpp +225 -0
  48. data/vendor/faiss/IndexLSH.h +87 -0
  49. data/vendor/faiss/IndexLattice.cpp +143 -0
  50. data/vendor/faiss/IndexLattice.h +68 -0
  51. data/vendor/faiss/IndexPQ.cpp +1188 -0
  52. data/vendor/faiss/IndexPQ.h +199 -0
  53. data/vendor/faiss/IndexPreTransform.cpp +288 -0
  54. data/vendor/faiss/IndexPreTransform.h +91 -0
  55. data/vendor/faiss/IndexReplicas.cpp +123 -0
  56. data/vendor/faiss/IndexReplicas.h +76 -0
  57. data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
  58. data/vendor/faiss/IndexScalarQuantizer.h +127 -0
  59. data/vendor/faiss/IndexShards.cpp +317 -0
  60. data/vendor/faiss/IndexShards.h +100 -0
  61. data/vendor/faiss/InvertedLists.cpp +623 -0
  62. data/vendor/faiss/InvertedLists.h +334 -0
  63. data/vendor/faiss/LICENSE +21 -0
  64. data/vendor/faiss/MatrixStats.cpp +252 -0
  65. data/vendor/faiss/MatrixStats.h +62 -0
  66. data/vendor/faiss/MetaIndexes.cpp +351 -0
  67. data/vendor/faiss/MetaIndexes.h +126 -0
  68. data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
  69. data/vendor/faiss/OnDiskInvertedLists.h +127 -0
  70. data/vendor/faiss/VectorTransform.cpp +1157 -0
  71. data/vendor/faiss/VectorTransform.h +322 -0
  72. data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
  73. data/vendor/faiss/c_api/AutoTune_c.h +64 -0
  74. data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
  75. data/vendor/faiss/c_api/Clustering_c.h +117 -0
  76. data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
  77. data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
  78. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
  79. data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
  80. data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
  81. data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
  82. data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
  83. data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
  84. data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
  85. data/vendor/faiss/c_api/IndexShards_c.h +42 -0
  86. data/vendor/faiss/c_api/Index_c.cpp +105 -0
  87. data/vendor/faiss/c_api/Index_c.h +183 -0
  88. data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
  89. data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
  90. data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
  91. data/vendor/faiss/c_api/clone_index_c.h +32 -0
  92. data/vendor/faiss/c_api/error_c.h +42 -0
  93. data/vendor/faiss/c_api/error_impl.cpp +27 -0
  94. data/vendor/faiss/c_api/error_impl.h +16 -0
  95. data/vendor/faiss/c_api/faiss_c.h +58 -0
  96. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
  97. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
  98. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
  99. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
  100. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
  101. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
  102. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
  103. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
  104. data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
  105. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
  106. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
  107. data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
  108. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
  109. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
  110. data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
  111. data/vendor/faiss/c_api/index_factory_c.h +30 -0
  112. data/vendor/faiss/c_api/index_io_c.cpp +42 -0
  113. data/vendor/faiss/c_api/index_io_c.h +50 -0
  114. data/vendor/faiss/c_api/macros_impl.h +110 -0
  115. data/vendor/faiss/clone_index.cpp +147 -0
  116. data/vendor/faiss/clone_index.h +38 -0
  117. data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
  118. data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
  119. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
  120. data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
  121. data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
  122. data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
  123. data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
  124. data/vendor/faiss/gpu/GpuCloner.h +82 -0
  125. data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
  126. data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
  127. data/vendor/faiss/gpu/GpuDistance.h +52 -0
  128. data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
  129. data/vendor/faiss/gpu/GpuIndex.h +148 -0
  130. data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
  131. data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
  132. data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
  133. data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
  134. data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
  135. data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
  136. data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
  137. data/vendor/faiss/gpu/GpuResources.cpp +52 -0
  138. data/vendor/faiss/gpu/GpuResources.h +73 -0
  139. data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
  140. data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
  141. data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
  142. data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
  143. data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
  144. data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
  145. data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
  146. data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
  147. data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
  148. data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
  149. data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
  150. data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
  151. data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
  152. data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
  153. data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
  154. data/vendor/faiss/gpu/test/TestUtils.h +93 -0
  155. data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
  156. data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
  157. data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
  158. data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
  159. data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
  160. data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
  161. data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
  162. data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
  163. data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
  164. data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
  165. data/vendor/faiss/gpu/utils/Timer.h +52 -0
  166. data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
  167. data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
  168. data/vendor/faiss/impl/FaissAssert.h +95 -0
  169. data/vendor/faiss/impl/FaissException.cpp +66 -0
  170. data/vendor/faiss/impl/FaissException.h +71 -0
  171. data/vendor/faiss/impl/HNSW.cpp +818 -0
  172. data/vendor/faiss/impl/HNSW.h +275 -0
  173. data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
  174. data/vendor/faiss/impl/PolysemousTraining.h +158 -0
  175. data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
  176. data/vendor/faiss/impl/ProductQuantizer.h +242 -0
  177. data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
  178. data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
  179. data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
  180. data/vendor/faiss/impl/ThreadedIndex.h +80 -0
  181. data/vendor/faiss/impl/index_read.cpp +793 -0
  182. data/vendor/faiss/impl/index_write.cpp +558 -0
  183. data/vendor/faiss/impl/io.cpp +142 -0
  184. data/vendor/faiss/impl/io.h +98 -0
  185. data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
  186. data/vendor/faiss/impl/lattice_Zn.h +199 -0
  187. data/vendor/faiss/index_factory.cpp +392 -0
  188. data/vendor/faiss/index_factory.h +25 -0
  189. data/vendor/faiss/index_io.h +75 -0
  190. data/vendor/faiss/misc/test_blas.cpp +84 -0
  191. data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
  192. data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
  193. data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
  194. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
  195. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
  196. data/vendor/faiss/tests/test_merge.cpp +258 -0
  197. data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
  198. data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
  199. data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
  200. data/vendor/faiss/tests/test_params_override.cpp +231 -0
  201. data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
  202. data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
  203. data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
  204. data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
  205. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
  206. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
  207. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
  208. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
  209. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
  210. data/vendor/faiss/utils/Heap.cpp +122 -0
  211. data/vendor/faiss/utils/Heap.h +495 -0
  212. data/vendor/faiss/utils/WorkerThread.cpp +126 -0
  213. data/vendor/faiss/utils/WorkerThread.h +61 -0
  214. data/vendor/faiss/utils/distances.cpp +765 -0
  215. data/vendor/faiss/utils/distances.h +243 -0
  216. data/vendor/faiss/utils/distances_simd.cpp +809 -0
  217. data/vendor/faiss/utils/extra_distances.cpp +336 -0
  218. data/vendor/faiss/utils/extra_distances.h +54 -0
  219. data/vendor/faiss/utils/hamming-inl.h +472 -0
  220. data/vendor/faiss/utils/hamming.cpp +792 -0
  221. data/vendor/faiss/utils/hamming.h +220 -0
  222. data/vendor/faiss/utils/random.cpp +192 -0
  223. data/vendor/faiss/utils/random.h +60 -0
  224. data/vendor/faiss/utils/utils.cpp +783 -0
  225. data/vendor/faiss/utils/utils.h +181 -0
  226. metadata +216 -2
@@ -0,0 +1,212 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #ifndef FAISS_AUTO_TUNE_H
11
+ #define FAISS_AUTO_TUNE_H
12
+
13
+ #include <vector>
14
+ #include <unordered_map>
15
+ #include <stdint.h>
16
+
17
+ #include <faiss/Index.h>
18
+ #include <faiss/IndexBinary.h>
19
+
20
+ namespace faiss {
21
+
22
+
23
+ /**
24
+ * Evaluation criterion. Returns a performance measure in [0,1],
25
+ * higher is better.
26
+ */
27
+ struct AutoTuneCriterion {
28
+ typedef Index::idx_t idx_t;
29
+ idx_t nq; ///< nb of queries this criterion is evaluated on
30
+ idx_t nnn; ///< nb of NNs that the query should request
31
+ idx_t gt_nnn; ///< nb of GT NNs required to evaluate crterion
32
+
33
+ std::vector<float> gt_D; ///< Ground-truth distances (size nq * gt_nnn)
34
+ std::vector<idx_t> gt_I; ///< Ground-truth indexes (size nq * gt_nnn)
35
+
36
+ AutoTuneCriterion (idx_t nq, idx_t nnn);
37
+
38
+ /** Intitializes the gt_D and gt_I vectors. Must be called before evaluating
39
+ *
40
+ * @param gt_D_in size nq * gt_nnn
41
+ * @param gt_I_in size nq * gt_nnn
42
+ */
43
+ void set_groundtruth (int gt_nnn, const float *gt_D_in,
44
+ const idx_t *gt_I_in);
45
+
46
+ /** Evaluate the criterion.
47
+ *
48
+ * @param D size nq * nnn
49
+ * @param I size nq * nnn
50
+ * @return the criterion, between 0 and 1. Larger is better.
51
+ */
52
+ virtual double evaluate (const float *D, const idx_t *I) const = 0;
53
+
54
+ virtual ~AutoTuneCriterion () {}
55
+
56
+ };
57
+
58
+ struct OneRecallAtRCriterion: AutoTuneCriterion {
59
+
60
+ idx_t R;
61
+
62
+ OneRecallAtRCriterion (idx_t nq, idx_t R);
63
+
64
+ double evaluate(const float* D, const idx_t* I) const override;
65
+
66
+ ~OneRecallAtRCriterion() override {}
67
+ };
68
+
69
+
70
+ struct IntersectionCriterion: AutoTuneCriterion {
71
+
72
+ idx_t R;
73
+
74
+ IntersectionCriterion (idx_t nq, idx_t R);
75
+
76
+ double evaluate(const float* D, const idx_t* I) const override;
77
+
78
+ ~IntersectionCriterion() override {}
79
+ };
80
+
81
/**
 * One experimental result. An operating point is a (perf, t, key)
 * triplet, where higher perf and lower t is better. The key field is an
 * arbitrary identifier for the operating point.
 */
struct OperatingPoint {
    double perf;      ///< performance measure (output of a Criterion)
    double t;         ///< corresponding execution time (ms)
    std::string key;  ///< key that identifies this operating point
    int64_t cno;      ///< integer identifier
};
93
+
94
+ struct OperatingPoints {
95
+ /// all operating points
96
+ std::vector<OperatingPoint> all_pts;
97
+
98
+ /// optimal operating points, sorted by perf
99
+ std::vector<OperatingPoint> optimal_pts;
100
+
101
+ // begins with a single operating point: t=0, perf=0
102
+ OperatingPoints ();
103
+
104
+ /// add operating points from other to this, with a prefix to the keys
105
+ int merge_with (const OperatingPoints &other,
106
+ const std::string & prefix = "");
107
+
108
+ void clear ();
109
+
110
+ /// add a performance measure. Return whether it is an optimal point
111
+ bool add (double perf, double t, const std::string & key, size_t cno = 0);
112
+
113
+ /// get time required to obtain a given performance measure
114
+ double t_for_perf (double perf) const;
115
+
116
+ /// easy-to-read output
117
+ void display (bool only_optimal = true) const;
118
+
119
+ /// output to a format easy to digest by gnuplot
120
+ void all_to_gnuplot (const char *fname) const;
121
+ void optimal_to_gnuplot (const char *fname) const;
122
+
123
+ };
124
+
125
/// A named tunable parameter together with its candidate values, sorted
/// from least to most expensive/accurate.
struct ParameterRange {
    std::string name;            ///< parameter name
    std::vector<double> values;  ///< candidate values, in increasing cost/accuracy order
};
130
+
131
+ /** Uses a-priori knowledge on the Faiss indexes to extract tunable parameters.
132
+ */
133
+ struct ParameterSpace {
134
+ /// all tunable parameters
135
+ std::vector<ParameterRange> parameter_ranges;
136
+
137
+ // exploration parameters
138
+
139
+ /// verbosity during exploration
140
+ int verbose;
141
+
142
+ /// nb of experiments during optimization (0 = try all combinations)
143
+ int n_experiments;
144
+
145
+ /// maximum number of queries to submit at a time.
146
+ size_t batchsize;
147
+
148
+ /// use multithreading over batches (useful to benchmark
149
+ /// independent single-searches)
150
+ bool thread_over_batches;
151
+
152
+ /// run tests several times until they reach at least this
153
+ /// duration (to avoid jittering in MT mode)
154
+ double min_test_duration;
155
+
156
+ ParameterSpace ();
157
+
158
+ /// nb of combinations, = product of values sizes
159
+ size_t n_combinations () const;
160
+
161
+ /// returns whether combinations c1 >= c2 in the tuple sense
162
+ bool combination_ge (size_t c1, size_t c2) const;
163
+
164
+ /// get string representation of the combination
165
+ std::string combination_name (size_t cno) const;
166
+
167
+ /// print a description on stdout
168
+ void display () const;
169
+
170
+ /// add a new parameter (or return it if it exists)
171
+ ParameterRange &add_range(const char * name);
172
+
173
+ /// initialize with reasonable parameters for the index
174
+ virtual void initialize (const Index * index);
175
+
176
+ /// set a combination of parameters on an index
177
+ void set_index_parameters (Index *index, size_t cno) const;
178
+
179
+ /// set a combination of parameters described by a string
180
+ void set_index_parameters (Index *index, const char *param_string) const;
181
+
182
+ /// set one of the parameters
183
+ virtual void set_index_parameter (
184
+ Index * index, const std::string & name, double val) const;
185
+
186
+ /** find an upper bound on the performance and a lower bound on t
187
+ * for configuration cno given another operating point op */
188
+ void update_bounds (size_t cno, const OperatingPoint & op,
189
+ double *upper_bound_perf,
190
+ double *lower_bound_t) const;
191
+
192
+ /** explore operating points
193
+ * @param index index to run on
194
+ * @param xq query vectors (size nq * index.d)
195
+ * @param crit selection criterion
196
+ * @param ops resulting operating points
197
+ */
198
+ void explore (Index *index,
199
+ size_t nq, const float *xq,
200
+ const AutoTuneCriterion & crit,
201
+ OperatingPoints * ops) const;
202
+
203
+ virtual ~ParameterSpace () {}
204
+ };
205
+
206
+
207
+
208
+ } // namespace faiss
209
+
210
+
211
+
212
+ #endif
@@ -0,0 +1,261 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/Clustering.h>
11
+ #include <faiss/impl/AuxIndexStructures.h>
12
+
13
+
14
+ #include <cmath>
15
+ #include <cstdio>
16
+ #include <cstring>
17
+
18
+ #include <faiss/utils/utils.h>
19
+ #include <faiss/utils/random.h>
20
+ #include <faiss/utils/distances.h>
21
+ #include <faiss/impl/FaissAssert.h>
22
+ #include <faiss/IndexFlat.h>
23
+
24
+ namespace faiss {
25
+
26
+ ClusteringParameters::ClusteringParameters ():
27
+ niter(25),
28
+ nredo(1),
29
+ verbose(false),
30
+ spherical(false),
31
+ int_centroids(false),
32
+ update_index(false),
33
+ frozen_centroids(false),
34
+ min_points_per_centroid(39),
35
+ max_points_per_centroid(256),
36
+ seed(1234)
37
+ {}
38
+ // 39 corresponds to 10000 / 256 -> to avoid warnings on PQ tests with randu10k
39
+
40
+
41
+ Clustering::Clustering (int d, int k):
42
+ d(d), k(k) {}
43
+
44
+ Clustering::Clustering (int d, int k, const ClusteringParameters &cp):
45
+ ClusteringParameters (cp), d(d), k(k) {}
46
+
47
+
48
+
49
/** Imbalance factor of a cluster assignment.
 *
 * Builds the histogram of cluster sizes and returns
 * k * sum(size_i^2) / (sum(size_i))^2. This is 1 when all k clusters
 * have the same size and k when all points fall into a single cluster.
 *
 * @param n       number of assigned points
 * @param k       number of clusters
 * @param assign  cluster label of each point (size n)
 * @return        the imbalance factor; NaN when no point has a label
 *                in [0, k) (the ratio is then 0 / 0)
 */
static double imbalance_factor (int n, int k, const int64_t *assign) {
    std::vector<int> hist (k, 0);
    for (int i = 0; i < n; i++) {
        const int64_t label = assign[i];
        // A label outside [0, k) would index past the histogram; such
        // points are left out of the statistic instead.
        if (label >= 0 && label < k) {
            hist[label]++;
        }
    }

    double tot = 0, uf = 0;
    for (int count : hist) {
        tot += count;
        uf += count * (double) count;
    }

    return uf * k / (tot * tot);
}
64
+
65
+ void Clustering::post_process_centroids ()
66
+ {
67
+
68
+ if (spherical) {
69
+ fvec_renorm_L2 (d, k, centroids.data());
70
+ }
71
+
72
+ if (int_centroids) {
73
+ for (size_t i = 0; i < centroids.size(); i++)
74
+ centroids[i] = roundf (centroids[i]);
75
+ }
76
+ }
77
+
78
+
79
+ void Clustering::train (idx_t nx, const float *x_in, Index & index) {
80
+ FAISS_THROW_IF_NOT_FMT (nx >= k,
81
+ "Number of training points (%ld) should be at least "
82
+ "as large as number of clusters (%ld)", nx, k);
83
+
84
+ double t0 = getmillisecs();
85
+
86
+ // yes it is the user's responsibility, but it may spare us some
87
+ // hard-to-debug reports.
88
+ for (size_t i = 0; i < nx * d; i++) {
89
+ FAISS_THROW_IF_NOT_MSG (finite (x_in[i]),
90
+ "input contains NaN's or Inf's");
91
+ }
92
+
93
+ const float *x = x_in;
94
+ ScopeDeleter<float> del1;
95
+
96
+ if (nx > k * max_points_per_centroid) {
97
+ if (verbose)
98
+ printf("Sampling a subset of %ld / %ld for training\n",
99
+ k * max_points_per_centroid, nx);
100
+ std::vector<int> perm (nx);
101
+ rand_perm (perm.data (), nx, seed);
102
+ nx = k * max_points_per_centroid;
103
+ float * x_new = new float [nx * d];
104
+ for (idx_t i = 0; i < nx; i++)
105
+ memcpy (x_new + i * d, x + perm[i] * d, sizeof(x_new[0]) * d);
106
+ x = x_new;
107
+ del1.set (x);
108
+ } else if (nx < k * min_points_per_centroid) {
109
+ fprintf (stderr,
110
+ "WARNING clustering %ld points to %ld centroids: "
111
+ "please provide at least %ld training points\n",
112
+ nx, k, idx_t(k) * min_points_per_centroid);
113
+ }
114
+
115
+
116
+ if (nx == k) {
117
+ if (verbose) {
118
+ printf("Number of training points (%ld) same as number of "
119
+ "clusters, just copying\n", nx);
120
+ }
121
+ // this is a corner case, just copy training set to clusters
122
+ centroids.resize (d * k);
123
+ memcpy (centroids.data(), x_in, sizeof (*x_in) * d * k);
124
+ index.reset();
125
+ index.add(k, x_in);
126
+ return;
127
+ }
128
+
129
+
130
+ if (verbose)
131
+ printf("Clustering %d points in %ldD to %ld clusters, "
132
+ "redo %d times, %d iterations\n",
133
+ int(nx), d, k, nredo, niter);
134
+
135
+ idx_t * assign = new idx_t[nx];
136
+ ScopeDeleter<idx_t> del (assign);
137
+ float * dis = new float[nx];
138
+ ScopeDeleter<float> del2(dis);
139
+
140
+ // for redo
141
+ float best_err = HUGE_VALF;
142
+ std::vector<float> best_obj;
143
+ std::vector<float> best_centroids;
144
+
145
+ // support input centroids
146
+
147
+ FAISS_THROW_IF_NOT_MSG (
148
+ centroids.size() % d == 0,
149
+ "size of provided input centroids not a multiple of dimension");
150
+
151
+ size_t n_input_centroids = centroids.size() / d;
152
+
153
+ if (verbose && n_input_centroids > 0) {
154
+ printf (" Using %zd centroids provided as input (%sfrozen)\n",
155
+ n_input_centroids, frozen_centroids ? "" : "not ");
156
+ }
157
+
158
+ double t_search_tot = 0;
159
+ if (verbose) {
160
+ printf(" Preprocessing in %.2f s\n",
161
+ (getmillisecs() - t0) / 1000.);
162
+ }
163
+ t0 = getmillisecs();
164
+
165
+ for (int redo = 0; redo < nredo; redo++) {
166
+
167
+ if (verbose && nredo > 1) {
168
+ printf("Outer iteration %d / %d\n", redo, nredo);
169
+ }
170
+
171
+ // initialize remaining centroids with random points from the dataset
172
+ centroids.resize (d * k);
173
+ std::vector<int> perm (nx);
174
+
175
+ rand_perm (perm.data(), nx, seed + 1 + redo * 15486557L);
176
+ for (int i = n_input_centroids; i < k ; i++)
177
+ memcpy (&centroids[i * d], x + perm[i] * d,
178
+ d * sizeof (float));
179
+
180
+ post_process_centroids ();
181
+
182
+ if (index.ntotal != 0) {
183
+ index.reset();
184
+ }
185
+
186
+ if (!index.is_trained) {
187
+ index.train (k, centroids.data());
188
+ }
189
+
190
+ index.add (k, centroids.data());
191
+ float err = 0;
192
+ for (int i = 0; i < niter; i++) {
193
+ double t0s = getmillisecs();
194
+ index.search (nx, x, 1, dis, assign);
195
+ InterruptCallback::check();
196
+ t_search_tot += getmillisecs() - t0s;
197
+
198
+ err = 0;
199
+ for (int j = 0; j < nx; j++)
200
+ err += dis[j];
201
+ obj.push_back (err);
202
+
203
+ int nsplit = km_update_centroids (
204
+ x, centroids.data(),
205
+ assign, d, k, nx, frozen_centroids ? n_input_centroids : 0);
206
+
207
+ if (verbose) {
208
+ printf (" Iteration %d (%.2f s, search %.2f s): "
209
+ "objective=%g imbalance=%.3f nsplit=%d \r",
210
+ i, (getmillisecs() - t0) / 1000.0,
211
+ t_search_tot / 1000,
212
+ err, imbalance_factor (nx, k, assign),
213
+ nsplit);
214
+ fflush (stdout);
215
+ }
216
+
217
+ post_process_centroids ();
218
+
219
+ index.reset ();
220
+ if (update_index)
221
+ index.train (k, centroids.data());
222
+
223
+ assert (index.ntotal == 0);
224
+ index.add (k, centroids.data());
225
+ InterruptCallback::check ();
226
+ }
227
+ if (verbose) printf("\n");
228
+ if (nredo > 1) {
229
+ if (err < best_err) {
230
+ if (verbose)
231
+ printf ("Objective improved: keep new clusters\n");
232
+ best_centroids = centroids;
233
+ best_obj = obj;
234
+ best_err = err;
235
+ }
236
+ index.reset ();
237
+ }
238
+ }
239
+ if (nredo > 1) {
240
+ centroids = best_centroids;
241
+ obj = best_obj;
242
+ index.reset();
243
+ index.add(k, best_centroids.data());
244
+ }
245
+
246
+ }
247
+
248
+ float kmeans_clustering (size_t d, size_t n, size_t k,
249
+ const float *x,
250
+ float *centroids)
251
+ {
252
+ Clustering clus (d, k);
253
+ clus.verbose = d * n * k > (1L << 30);
254
+ // display logs if > 1Gflop per iteration
255
+ IndexFlatL2 index (d);
256
+ clus.train (n, x, index);
257
+ memcpy(centroids, clus.centroids.data(), sizeof(*centroids) * d * k);
258
+ return clus.obj.back();
259
+ }
260
+
261
+ } // namespace faiss