faiss 0.1.1 → 0.1.2

Files changed (77)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +18 -18
  4. data/README.md +1 -1
  5. data/lib/faiss/version.rb +1 -1
  6. data/vendor/faiss/Clustering.cpp +318 -53
  7. data/vendor/faiss/Clustering.h +39 -11
  8. data/vendor/faiss/DirectMap.cpp +267 -0
  9. data/vendor/faiss/DirectMap.h +120 -0
  10. data/vendor/faiss/IVFlib.cpp +24 -4
  11. data/vendor/faiss/IVFlib.h +4 -0
  12. data/vendor/faiss/Index.h +5 -24
  13. data/vendor/faiss/Index2Layer.cpp +0 -1
  14. data/vendor/faiss/IndexBinary.h +7 -3
  15. data/vendor/faiss/IndexBinaryFlat.cpp +5 -0
  16. data/vendor/faiss/IndexBinaryFlat.h +3 -0
  17. data/vendor/faiss/IndexBinaryHash.cpp +492 -0
  18. data/vendor/faiss/IndexBinaryHash.h +116 -0
  19. data/vendor/faiss/IndexBinaryIVF.cpp +160 -107
  20. data/vendor/faiss/IndexBinaryIVF.h +14 -4
  21. data/vendor/faiss/IndexFlat.h +2 -1
  22. data/vendor/faiss/IndexHNSW.cpp +68 -16
  23. data/vendor/faiss/IndexHNSW.h +3 -3
  24. data/vendor/faiss/IndexIVF.cpp +72 -76
  25. data/vendor/faiss/IndexIVF.h +24 -5
  26. data/vendor/faiss/IndexIVFFlat.cpp +19 -54
  27. data/vendor/faiss/IndexIVFFlat.h +1 -11
  28. data/vendor/faiss/IndexIVFPQ.cpp +49 -26
  29. data/vendor/faiss/IndexIVFPQ.h +9 -10
  30. data/vendor/faiss/IndexIVFPQR.cpp +2 -2
  31. data/vendor/faiss/IndexIVFSpectralHash.cpp +2 -2
  32. data/vendor/faiss/IndexLSH.h +4 -1
  33. data/vendor/faiss/IndexPreTransform.cpp +0 -1
  34. data/vendor/faiss/IndexScalarQuantizer.cpp +8 -1
  35. data/vendor/faiss/InvertedLists.cpp +0 -2
  36. data/vendor/faiss/MetaIndexes.cpp +0 -1
  37. data/vendor/faiss/MetricType.h +36 -0
  38. data/vendor/faiss/c_api/Clustering_c.cpp +13 -7
  39. data/vendor/faiss/c_api/Clustering_c.h +11 -5
  40. data/vendor/faiss/c_api/IndexIVF_c.cpp +7 -0
  41. data/vendor/faiss/c_api/IndexIVF_c.h +7 -0
  42. data/vendor/faiss/c_api/IndexPreTransform_c.cpp +21 -0
  43. data/vendor/faiss/c_api/IndexPreTransform_c.h +32 -0
  44. data/vendor/faiss/demos/demo_weighted_kmeans.cpp +185 -0
  45. data/vendor/faiss/gpu/GpuCloner.cpp +4 -0
  46. data/vendor/faiss/gpu/GpuClonerOptions.cpp +1 -1
  47. data/vendor/faiss/gpu/GpuDistance.h +93 -0
  48. data/vendor/faiss/gpu/GpuIndex.h +7 -0
  49. data/vendor/faiss/gpu/GpuIndexFlat.h +0 -10
  50. data/vendor/faiss/gpu/GpuIndexIVF.h +1 -0
  51. data/vendor/faiss/gpu/StandardGpuResources.cpp +8 -0
  52. data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +49 -27
  53. data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +110 -2
  54. data/vendor/faiss/gpu/utils/DeviceUtils.h +6 -0
  55. data/vendor/faiss/impl/AuxIndexStructures.cpp +17 -0
  56. data/vendor/faiss/impl/AuxIndexStructures.h +14 -3
  57. data/vendor/faiss/impl/HNSW.cpp +0 -1
  58. data/vendor/faiss/impl/PolysemousTraining.h +5 -5
  59. data/vendor/faiss/impl/ProductQuantizer-inl.h +138 -0
  60. data/vendor/faiss/impl/ProductQuantizer.cpp +1 -113
  61. data/vendor/faiss/impl/ProductQuantizer.h +42 -47
  62. data/vendor/faiss/impl/index_read.cpp +103 -7
  63. data/vendor/faiss/impl/index_write.cpp +101 -5
  64. data/vendor/faiss/impl/io.cpp +111 -1
  65. data/vendor/faiss/impl/io.h +38 -0
  66. data/vendor/faiss/index_factory.cpp +0 -1
  67. data/vendor/faiss/tests/test_merge.cpp +0 -1
  68. data/vendor/faiss/tests/test_pq_encoding.cpp +6 -6
  69. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +1 -0
  70. data/vendor/faiss/utils/distances.cpp +4 -5
  71. data/vendor/faiss/utils/distances_simd.cpp +0 -1
  72. data/vendor/faiss/utils/hamming.cpp +85 -3
  73. data/vendor/faiss/utils/hamming.h +20 -0
  74. data/vendor/faiss/utils/utils.cpp +0 -96
  75. data/vendor/faiss/utils/utils.h +0 -15
  76. metadata +11 -3
  77. data/lib/faiss/ext.bundle +0 -0
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: a0369e5dda330b1490e48a88863baa01df9cadfa570078892cec439f82efaad1
-  data.tar.gz: bb7d89fa17f782e8163b114b520b8c2c082cf37661b4b6fc4593460dc5958484
+  metadata.gz: 61a0b7a7d20933b60a9e0e213016b77f10eae5bb86ecf6825f5f5661f31f5d7d
+  data.tar.gz: 967378ee774a35e3a639b1d902648ebf4177ffb7ed43e7228e8432f284f397c1
 SHA512:
-  metadata.gz: 0a9f1515d142d11c688f1a9cdbcf9af0c36fa3fc98b240f236554b1067cf2daad1cefa377d18d236674b8fc1b94d64a3acc070c2528c47b68f4d231f29b7648d
-  data.tar.gz: ae02808dbda4831c7165c987b77c72f9d436bba94e3e28d372f69ceee18fb4971c6cc99a2bf7ce9bd9a9a6e4befcd372515a728bf379294ecc870f2c58f85eb2
+  metadata.gz: 32747a4d4a3d40f15e9802280d894b2270d4b78ac0a10859442d0fd3c7ae27a55032a92e072756cb4046964c9d53afcc9586ac954e2b3cf63d057d0a3652e5a8
+  data.tar.gz: ca5005286253b7dea1546160ffb00c4b91a8b926512fbcfb7db594171435249e88408c982dc53fed4d90f87e68979bd9c2a2c1975f94495ca77c9ac878b22c1c
data/CHANGELOG.md CHANGED
@@ -1,3 +1,7 @@
+## 0.1.2 (2020-08-17)
+
+- Updated Faiss to 1.6.3
+
 ## 0.1.1 (2020-03-09)
 
 - Vendored library
data/LICENSE.txt CHANGED
@@ -1,22 +1,22 @@
-Copyright (c) 2020 Andrew Kane
-
 MIT License
 
-Permission is hereby granted, free of charge, to any person obtaining
-a copy of this software and associated documentation files (the
-"Software"), to deal in the Software without restriction, including
-without limitation the rights to use, copy, modify, merge, publish,
-distribute, sublicense, and/or sell copies of the Software, and to
-permit persons to whom the Software is furnished to do so, subject to
-the following conditions:
+Copyright (c) Facebook, Inc. and its affiliates.
+Copyright (c) 2020 Andrew Kane
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
 
-The above copyright notice and this permission notice shall be
-included in all copies or substantial portions of the Software.
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
 
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
-NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
-LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
-OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
-WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
data/README.md CHANGED
@@ -24,7 +24,7 @@ Add this line to your application’s Gemfile:
 gem 'faiss'
 ```
 
-Faiss is not available for Windows yet
+Faiss is not available for Windows
 
 ## Getting Started
 
data/lib/faiss/version.rb CHANGED
@@ -1,3 +1,3 @@
 module Faiss
-  VERSION = "0.1.1"
+  VERSION = "0.1.2"
 end
data/vendor/faiss/Clustering.cpp CHANGED
@@ -10,11 +10,12 @@
 #include <faiss/Clustering.h>
 #include <faiss/impl/AuxIndexStructures.h>
 
-
 #include <cmath>
 #include <cstdio>
 #include <cstring>
 
+#include <omp.h>
+
 #include <faiss/utils/utils.h>
 #include <faiss/utils/random.h>
 #include <faiss/utils/distances.h>
@@ -33,7 +34,8 @@ ClusteringParameters::ClusteringParameters ():
     frozen_centroids(false),
     min_points_per_centroid(39),
     max_points_per_centroid(256),
-    seed(1234)
+    seed(1234),
+    decode_block_size(32768)
 {}
 // 39 corresponds to 10000 / 256 -> to avoid warnings on PQ tests with randu10k
 
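The hunk above adds a `decode_block_size` field (default 32768) to `ClusteringParameters`; it controls how many codec-encoded training vectors are decoded per batch during clustering. A minimal sketch of tuning it, assuming the vendored faiss 1.6.3 C++ headers are on the include path (the value 16384 is purely illustrative):

```cpp
#include <faiss/Clustering.h>

int main() {
    // Defaults now include decode_block_size = 32768 (see the hunk above).
    faiss::ClusteringParameters cp;
    cp.verbose = true;
    cp.decode_block_size = 16384;  // decode fewer vectors per batch (lower peak memory)
    // cp can then be passed to faiss::Clustering(d, k, cp) before training.
    return 0;
}
```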
@@ -76,35 +78,233 @@ void Clustering::post_process_centroids ()
 }
 
 
-void Clustering::train (idx_t nx, const float *x_in, Index & index) {
+void Clustering::train (idx_t nx, const float *x_in, Index & index,
+                        const float *weights) {
+    train_encoded (nx, reinterpret_cast<const uint8_t *>(x_in), nullptr,
+                   index, weights);
+}
+
+
+namespace {
+
+using idx_t = Clustering::idx_t;
+
+idx_t subsample_training_set(
+        const Clustering &clus, idx_t nx, const uint8_t *x,
+        size_t line_size, const float * weights,
+        uint8_t **x_out,
+        float **weights_out
+)
+{
+    if (clus.verbose) {
+        printf("Sampling a subset of %ld / %ld for training\n",
+               clus.k * clus.max_points_per_centroid, nx);
+    }
+    std::vector<int> perm (nx);
+    rand_perm (perm.data (), nx, clus.seed);
+    nx = clus.k * clus.max_points_per_centroid;
+    uint8_t * x_new = new uint8_t [nx * line_size];
+    *x_out = x_new;
+    for (idx_t i = 0; i < nx; i++) {
+        memcpy (x_new + i * line_size, x + perm[i] * line_size, line_size);
+    }
+    if (weights) {
+        float *weights_new = new float[nx];
+        for (idx_t i = 0; i < nx; i++) {
+            weights_new[i] = weights[perm[i]];
+        }
+        *weights_out = weights_new;
+    } else {
+        *weights_out = nullptr;
+    }
+    return nx;
+}
+
+/** compute centroids as (weighted) sum of training points
+ *
+ * @param x            training vectors, size n * code_size (from codec)
+ * @param codec        how to decode the vectors (if NULL then cast to float*)
+ * @param weights      per-training vector weight, size n (or NULL)
+ * @param assign       nearest centroid for each training vector, size n
+ * @param k_frozen     do not update the k_frozen first centroids
+ * @param centroids    centroid vectors (output only), size k * d
+ * @param hassign      histogram of assignments per centroid (size k),
+ *                     should be 0 on input
+ *
+ */
+
+void compute_centroids (size_t d, size_t k, size_t n,
+                        size_t k_frozen,
+                        const uint8_t * x, const Index *codec,
+                        const int64_t * assign,
+                        const float * weights,
+                        float * hassign,
+                        float * centroids)
+{
+    k -= k_frozen;
+    centroids += k_frozen * d;
+
+    memset (centroids, 0, sizeof(*centroids) * d * k);
+
+    size_t line_size = codec ? codec->sa_code_size() : d * sizeof (float);
+
+#pragma omp parallel
+    {
+        int nt = omp_get_num_threads();
+        int rank = omp_get_thread_num();
+
+        // this thread is taking care of centroids c0:c1
+        size_t c0 = (k * rank) / nt;
+        size_t c1 = (k * (rank + 1)) / nt;
+        std::vector<float> decode_buffer (d);
+
+        for (size_t i = 0; i < n; i++) {
+            int64_t ci = assign[i];
+            assert (ci >= 0 && ci < k + k_frozen);
+            ci -= k_frozen;
+            if (ci >= c0 && ci < c1) {
+                float * c = centroids + ci * d;
+                const float * xi;
+                if (!codec) {
+                    xi = reinterpret_cast<const float*>(x + i * line_size);
+                } else {
+                    float *xif = decode_buffer.data();
+                    codec->sa_decode (1, x + i * line_size, xif);
+                    xi = xif;
+                }
+                if (weights) {
+                    float w = weights[i];
+                    hassign[ci] += w;
+                    for (size_t j = 0; j < d; j++) {
+                        c[j] += xi[j] * w;
+                    }
+                } else {
+                    hassign[ci] += 1.0;
+                    for (size_t j = 0; j < d; j++) {
+                        c[j] += xi[j];
+                    }
+                }
+            }
+        }
+
+    }
+
+#pragma omp parallel for
+    for (size_t ci = 0; ci < k; ci++) {
+        if (hassign[ci] == 0) {
+            continue;
+        }
+        float norm = 1 / hassign[ci];
+        float * c = centroids + ci * d;
+        for (size_t j = 0; j < d; j++) {
+            c[j] *= norm;
+        }
+    }
+
+}
+
+// a bit above machine epsilon for float16
+#define EPS (1 / 1024.)
+
+/** Handle empty clusters by splitting larger ones.
+ *
+ * It works by slightly changing the centroids to make 2 clusters from
+ * a single one. Takes the same arguements as compute_centroids.
+ *
+ * @return nb of spliting operations (larger is worse)
+ */
+int split_clusters (size_t d, size_t k, size_t n,
+                    size_t k_frozen,
+                    float * hassign,
+                    float * centroids)
+{
+    k -= k_frozen;
+    centroids += k_frozen * d;
+
+    /* Take care of void clusters */
+    size_t nsplit = 0;
+    RandomGenerator rng (1234);
+    for (size_t ci = 0; ci < k; ci++) {
+        if (hassign[ci] == 0) { /* need to redefine a centroid */
+            size_t cj;
+            for (cj = 0; 1; cj = (cj + 1) % k) {
+                /* probability to pick this cluster for split */
+                float p = (hassign[cj] - 1.0) / (float) (n - k);
+                float r = rng.rand_float ();
+                if (r < p) {
+                    break; /* found our cluster to be split */
+                }
+            }
+            memcpy (centroids+ci*d, centroids+cj*d, sizeof(*centroids) * d);
+
+            /* small symmetric pertubation */
+            for (size_t j = 0; j < d; j++) {
+                if (j % 2 == 0) {
+                    centroids[ci * d + j] *= 1 + EPS;
+                    centroids[cj * d + j] *= 1 - EPS;
+                } else {
+                    centroids[ci * d + j] *= 1 - EPS;
+                    centroids[cj * d + j] *= 1 + EPS;
+                }
+            }
+
+            /* assume even split of the cluster */
+            hassign[ci] = hassign[cj] / 2;
+            hassign[cj] -= hassign[ci];
+            nsplit++;
+        }
+    }
+
+    return nsplit;
+
+}
+
+
+
+};
+
+
+void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
+                                const Index * codec, Index & index,
+                                const float *weights) {
+
     FAISS_THROW_IF_NOT_FMT (nx >= k,
             "Number of training points (%ld) should be at least "
             "as large as number of clusters (%ld)", nx, k);
 
+    FAISS_THROW_IF_NOT_FMT ((!codec || codec->d == d),
+            "Codec dimension %d not the same as data dimension %d",
+            int(codec->d), int(d));
+
+    FAISS_THROW_IF_NOT_FMT (index.d == d,
+            "Index dimension %d not the same as data dimension %d",
+            int(index.d), int(d));
+
     double t0 = getmillisecs();
 
-    // yes it is the user's responsibility, but it may spare us some
-    // hard-to-debug reports.
-    for (size_t i = 0; i < nx * d; i++) {
-      FAISS_THROW_IF_NOT_MSG (finite (x_in[i]),
-                              "input contains NaN's or Inf's");
+    if (!codec) {
+        // Check for NaNs in input data. Normally it is the user's
+        // responsibility, but it may spare us some hard-to-debug
+        // reports.
+        const float *x = reinterpret_cast<const float *>(x_in);
+        for (size_t i = 0; i < nx * d; i++) {
+            FAISS_THROW_IF_NOT_MSG (finite (x[i]),
+                                    "input contains NaN's or Inf's");
+        }
     }
 
-    const float *x = x_in;
-    ScopeDeleter<float> del1;
+    const uint8_t *x = x_in;
+    std::unique_ptr<uint8_t []> del1;
+    std::unique_ptr<float []> del3;
+    size_t line_size = codec ? codec->sa_code_size() : sizeof(float) * d;
 
     if (nx > k * max_points_per_centroid) {
-        if (verbose)
-            printf("Sampling a subset of %ld / %ld for training\n",
-                   k * max_points_per_centroid, nx);
-        std::vector<int> perm (nx);
-        rand_perm (perm.data (), nx, seed);
-        nx = k * max_points_per_centroid;
-        float * x_new = new float [nx * d];
-        for (idx_t i = 0; i < nx; i++)
-            memcpy (x_new + i * d, x + perm[i] * d, sizeof(x_new[0]) * d);
-        x = x_new;
-        del1.set (x);
+        uint8_t *x_new;
+        float *weights_new;
+        nx = subsample_training_set (*this, nx, x, line_size, weights,
+                                     &x_new, &weights_new);
+        del1.reset (x_new); x = x_new;
+        del3.reset (weights_new); weights = weights_new;
     } else if (nx < k * min_points_per_centroid) {
         fprintf (stderr,
                  "WARNING clustering %ld points to %ld centroids: "
@@ -112,41 +312,53 @@ void Clustering::train (idx_t nx, const float *x_in, Index & index) {
              nx, k, idx_t(k) * min_points_per_centroid);
     }
 
-
     if (nx == k) {
+        // this is a corner case, just copy training set to clusters
         if (verbose) {
            printf("Number of training points (%ld) same as number of "
                   "clusters, just copying\n", nx);
        }
-        // this is a corner case, just copy training set to clusters
         centroids.resize (d * k);
-        memcpy (centroids.data(), x_in, sizeof (*x_in) * d * k);
+        if (!codec) {
+            memcpy (centroids.data(), x_in, sizeof (float) * d * k);
+        } else {
+            codec->sa_decode (nx, x_in, centroids.data());
+        }
+
+        // one fake iteration...
+        ClusteringIterationStats stats = { 0.0, 0.0, 0.0, 1.0, 0 };
+        iteration_stats.push_back (stats);
+
         index.reset();
-        index.add(k, x_in);
+        index.add(k, centroids.data());
         return;
     }
 
 
-    if (verbose)
+    if (verbose) {
         printf("Clustering %d points in %ldD to %ld clusters, "
                "redo %d times, %d iterations\n",
                int(nx), d, k, nredo, niter);
+        if (codec) {
+            printf("Input data encoded in %ld bytes per vector\n",
+                   codec->sa_code_size ());
+        }
+    }
 
-    idx_t * assign = new idx_t[nx];
-    ScopeDeleter<idx_t> del (assign);
-    float * dis = new float[nx];
-    ScopeDeleter<float> del2(dis);
+    std::unique_ptr<idx_t []> assign(new idx_t[nx]);
+    std::unique_ptr<float []> dis(new float[nx]);
 
-    // for redo
+    // remember best iteration for redo
     float best_err = HUGE_VALF;
-    std::vector<float> best_obj;
+    std::vector<ClusteringIterationStats> best_obj;
     std::vector<float> best_centroids;
 
     // support input centroids
 
     FAISS_THROW_IF_NOT_MSG (
        centroids.size() % d == 0,
-       "size of provided input centroids not a multiple of dimension");
+       "size of provided input centroids not a multiple of dimension"
+    );
 
     size_t n_input_centroids = centroids.size() / d;
 
@@ -162,23 +374,36 @@ void Clustering::train (idx_t nx, const float *x_in, Index & index) {
     }
     t0 = getmillisecs();
 
+    // temporary buffer to decode vectors during the optimization
+    std::vector<float> decode_buffer
+        (codec ? d * decode_block_size : 0);
+
     for (int redo = 0; redo < nredo; redo++) {
 
         if (verbose && nredo > 1) {
             printf("Outer iteration %d / %d\n", redo, nredo);
         }
 
-        // initialize remaining centroids with random points from the dataset
+        // initialize (remaining) centroids with random points from the dataset
         centroids.resize (d * k);
         std::vector<int> perm (nx);
 
         rand_perm (perm.data(), nx, seed + 1 + redo * 15486557L);
-        for (int i = n_input_centroids; i < k ; i++)
-            memcpy (&centroids[i * d], x + perm[i] * d,
-                    d * sizeof (float));
+
+        if (!codec) {
+            for (int i = n_input_centroids; i < k ; i++) {
+                memcpy (&centroids[i * d], x + perm[i] * line_size, line_size);
+            }
+        } else {
+            for (int i = n_input_centroids; i < k ; i++) {
+                codec->sa_decode (1, x + perm[i] * line_size, &centroids[i * d]);
+            }
+        }
 
         post_process_centroids ();
 
+        // prepare the index
+
         if (index.ntotal != 0) {
             index.reset();
         }
@@ -188,49 +413,89 @@ void Clustering::train (idx_t nx, const float *x_in, Index & index) {
         }
 
         index.add (k, centroids.data());
+
+        // k-means iterations
+
         float err = 0;
         for (int i = 0; i < niter; i++) {
             double t0s = getmillisecs();
-            index.search (nx, x, 1, dis, assign);
+
+            if (!codec) {
+                index.search (nx, reinterpret_cast<const float *>(x), 1,
+                              dis.get(), assign.get());
+            } else {
+                // search by blocks of decode_block_size vectors
+                size_t code_size = codec->sa_code_size ();
+                for (size_t i0 = 0; i0 < nx; i0 += decode_block_size) {
+                    size_t i1 = i0 + decode_block_size;
+                    if (i1 > nx) { i1 = nx; }
+                    codec->sa_decode (i1 - i0, x + code_size * i0,
+                                      decode_buffer.data ());
+                    index.search (i1 - i0, decode_buffer.data (), 1,
+                                  dis.get() + i0, assign.get() + i0);
+                }
+            }
+
             InterruptCallback::check();
             t_search_tot += getmillisecs() - t0s;
 
+            // accumulate error
             err = 0;
-            for (int j = 0; j < nx; j++)
+            for (int j = 0; j < nx; j++) {
                 err += dis[j];
-            obj.push_back (err);
+            }
+
+            // update the centroids
+            std::vector<float> hassign (k);
 
-            int nsplit = km_update_centroids (
-                x, centroids.data(),
-                assign, d, k, nx, frozen_centroids ? n_input_centroids : 0);
+            size_t k_frozen = frozen_centroids ? n_input_centroids : 0;
+            compute_centroids (
+                d, k, nx, k_frozen,
+                x, codec, assign.get(), weights,
+                hassign.data(), centroids.data()
+            );
+
+            int nsplit = split_clusters (
+                d, k, nx, k_frozen,
+                hassign.data(), centroids.data()
+            );
+
+            // collect statistics
+            ClusteringIterationStats stats =
+                { err, (getmillisecs() - t0) / 1000.0,
+                  t_search_tot / 1000, imbalance_factor (nx, k, assign.get()),
+                  nsplit };
+            iteration_stats.push_back(stats);
 
             if (verbose) {
                 printf ("  Iteration %d (%.2f s, search %.2f s): "
                         "objective=%g imbalance=%.3f nsplit=%d       \r",
-                        i, (getmillisecs() - t0) / 1000.0,
-                        t_search_tot / 1000,
-                        err, imbalance_factor (nx, k, assign),
-                        nsplit);
+                        i, stats.time, stats.time_search, stats.obj,
+                        stats.imbalance_factor, nsplit);
                 fflush (stdout);
             }
 
             post_process_centroids ();
 
+            // add centroids to index for the next iteration (or for output)
+
             index.reset ();
-            if (update_index)
+            if (update_index) {
                 index.train (k, centroids.data());
+            }
 
-            assert (index.ntotal == 0);
             index.add (k, centroids.data());
             InterruptCallback::check ();
         }
+
         if (verbose) printf("\n");
         if (nredo > 1) {
             if (err < best_err) {
-                if (verbose)
+                if (verbose) {
                     printf ("Objective improved: keep new clusters\n");
+                }
                 best_centroids = centroids;
-                best_obj = obj;
+                best_obj = iteration_stats;
                 best_err = err;
             }
             index.reset ();
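The hunk above is where codec-compressed training data is handled: during the assignment step the codes are decoded `decode_block_size` vectors at a time and searched against the current centroids. A hedged sketch of driving this `train_encoded` path from application code, assuming the vendored faiss 1.6.3 C++ API (the scalar-quantizer codec and random data are illustrative only):

```cpp
#include <faiss/Clustering.h>
#include <faiss/IndexFlat.h>
#include <faiss/IndexScalarQuantizer.h>

#include <cstdint>
#include <cstdlib>
#include <vector>

int main() {
    int d = 64;                         // vector dimension
    size_t n = 10000, k = 100;          // training points and clusters

    std::vector<float> x(n * d);
    for (float &v : x) v = rand() / float(RAND_MAX);

    // Codec: any Index that implements sa_encode / sa_decode can hold the
    // training set in compressed form.
    faiss::IndexScalarQuantizer codec(d, faiss::ScalarQuantizer::QT_8bit);
    codec.train(n, x.data());
    std::vector<uint8_t> codes(n * codec.sa_code_size());
    codec.sa_encode(n, x.data(), codes.data());

    // Cluster directly on the codes; blocks of decode_block_size vectors are
    // decoded on the fly at each k-means iteration (see the hunk above).
    faiss::Clustering clus(d, k);
    faiss::IndexFlatL2 assigner(d);
    clus.train_encoded(n, codes.data(), &codec, assigner, /*weights=*/nullptr);
    return 0;
}
```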
@@ -238,7 +503,7 @@ void Clustering::train (idx_t nx, const float *x_in, Index & index) {
     }
     if (nredo > 1) {
         centroids = best_centroids;
-        obj = best_obj;
+        iteration_stats = best_obj;
         index.reset();
         index.add(k, best_centroids.data());
     }
@@ -255,7 +520,7 @@ float kmeans_clustering (size_t d, size_t n, size_t k,
     IndexFlatL2 index (d);
     clus.train (n, x, index);
     memcpy(centroids, clus.centroids.data(), sizeof(*centroids) * d * k);
-    return clus.obj.back();
+    return clus.iteration_stats.back().obj;
 }
 
 } // namespace faiss
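Taken together, these Clustering.cpp changes replace the old per-iteration `obj` vector with `ClusteringIterationStats` entries in `iteration_stats` and add an optional per-point `weights` argument to `Clustering::train`. A short sketch of the weighted path under the same assumptions as above (random data and weights are illustrative; this is the vendored C++ API, not necessarily exposed through the gem's Ruby interface):

```cpp
#include <faiss/Clustering.h>
#include <faiss/IndexFlat.h>

#include <cstdio>
#include <cstdlib>
#include <vector>

int main() {
    int d = 32;
    size_t n = 5000, k = 50;

    std::vector<float> x(n * d), w(n);
    for (float &v : x) v = rand() / float(RAND_MAX);
    for (float &v : w) v = 1.0f + rand() % 3;     // per-point weights

    faiss::Clustering clus(d, k);
    clus.verbose = true;

    faiss::IndexFlatL2 index(d);                  // assignment index
    clus.train(n, x.data(), index, w.data());     // new optional weights argument

    // The final objective now lives in iteration_stats rather than obj:
    std::printf("final objective: %g\n", clus.iteration_stats.back().obj);
    return 0;
}
```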