faiss 0.1.1 → 0.1.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/LICENSE.txt +18 -18
  4. data/README.md +1 -1
  5. data/lib/faiss/version.rb +1 -1
  6. data/vendor/faiss/Clustering.cpp +318 -53
  7. data/vendor/faiss/Clustering.h +39 -11
  8. data/vendor/faiss/DirectMap.cpp +267 -0
  9. data/vendor/faiss/DirectMap.h +120 -0
  10. data/vendor/faiss/IVFlib.cpp +24 -4
  11. data/vendor/faiss/IVFlib.h +4 -0
  12. data/vendor/faiss/Index.h +5 -24
  13. data/vendor/faiss/Index2Layer.cpp +0 -1
  14. data/vendor/faiss/IndexBinary.h +7 -3
  15. data/vendor/faiss/IndexBinaryFlat.cpp +5 -0
  16. data/vendor/faiss/IndexBinaryFlat.h +3 -0
  17. data/vendor/faiss/IndexBinaryHash.cpp +492 -0
  18. data/vendor/faiss/IndexBinaryHash.h +116 -0
  19. data/vendor/faiss/IndexBinaryIVF.cpp +160 -107
  20. data/vendor/faiss/IndexBinaryIVF.h +14 -4
  21. data/vendor/faiss/IndexFlat.h +2 -1
  22. data/vendor/faiss/IndexHNSW.cpp +68 -16
  23. data/vendor/faiss/IndexHNSW.h +3 -3
  24. data/vendor/faiss/IndexIVF.cpp +72 -76
  25. data/vendor/faiss/IndexIVF.h +24 -5
  26. data/vendor/faiss/IndexIVFFlat.cpp +19 -54
  27. data/vendor/faiss/IndexIVFFlat.h +1 -11
  28. data/vendor/faiss/IndexIVFPQ.cpp +49 -26
  29. data/vendor/faiss/IndexIVFPQ.h +9 -10
  30. data/vendor/faiss/IndexIVFPQR.cpp +2 -2
  31. data/vendor/faiss/IndexIVFSpectralHash.cpp +2 -2
  32. data/vendor/faiss/IndexLSH.h +4 -1
  33. data/vendor/faiss/IndexPreTransform.cpp +0 -1
  34. data/vendor/faiss/IndexScalarQuantizer.cpp +8 -1
  35. data/vendor/faiss/InvertedLists.cpp +0 -2
  36. data/vendor/faiss/MetaIndexes.cpp +0 -1
  37. data/vendor/faiss/MetricType.h +36 -0
  38. data/vendor/faiss/c_api/Clustering_c.cpp +13 -7
  39. data/vendor/faiss/c_api/Clustering_c.h +11 -5
  40. data/vendor/faiss/c_api/IndexIVF_c.cpp +7 -0
  41. data/vendor/faiss/c_api/IndexIVF_c.h +7 -0
  42. data/vendor/faiss/c_api/IndexPreTransform_c.cpp +21 -0
  43. data/vendor/faiss/c_api/IndexPreTransform_c.h +32 -0
  44. data/vendor/faiss/demos/demo_weighted_kmeans.cpp +185 -0
  45. data/vendor/faiss/gpu/GpuCloner.cpp +4 -0
  46. data/vendor/faiss/gpu/GpuClonerOptions.cpp +1 -1
  47. data/vendor/faiss/gpu/GpuDistance.h +93 -0
  48. data/vendor/faiss/gpu/GpuIndex.h +7 -0
  49. data/vendor/faiss/gpu/GpuIndexFlat.h +0 -10
  50. data/vendor/faiss/gpu/GpuIndexIVF.h +1 -0
  51. data/vendor/faiss/gpu/StandardGpuResources.cpp +8 -0
  52. data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +49 -27
  53. data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +110 -2
  54. data/vendor/faiss/gpu/utils/DeviceUtils.h +6 -0
  55. data/vendor/faiss/impl/AuxIndexStructures.cpp +17 -0
  56. data/vendor/faiss/impl/AuxIndexStructures.h +14 -3
  57. data/vendor/faiss/impl/HNSW.cpp +0 -1
  58. data/vendor/faiss/impl/PolysemousTraining.h +5 -5
  59. data/vendor/faiss/impl/ProductQuantizer-inl.h +138 -0
  60. data/vendor/faiss/impl/ProductQuantizer.cpp +1 -113
  61. data/vendor/faiss/impl/ProductQuantizer.h +42 -47
  62. data/vendor/faiss/impl/index_read.cpp +103 -7
  63. data/vendor/faiss/impl/index_write.cpp +101 -5
  64. data/vendor/faiss/impl/io.cpp +111 -1
  65. data/vendor/faiss/impl/io.h +38 -0
  66. data/vendor/faiss/index_factory.cpp +0 -1
  67. data/vendor/faiss/tests/test_merge.cpp +0 -1
  68. data/vendor/faiss/tests/test_pq_encoding.cpp +6 -6
  69. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +1 -0
  70. data/vendor/faiss/utils/distances.cpp +4 -5
  71. data/vendor/faiss/utils/distances_simd.cpp +0 -1
  72. data/vendor/faiss/utils/hamming.cpp +85 -3
  73. data/vendor/faiss/utils/hamming.h +20 -0
  74. data/vendor/faiss/utils/utils.cpp +0 -96
  75. data/vendor/faiss/utils/utils.h +0 -15
  76. metadata +11 -3
  77. data/lib/faiss/ext.bundle +0 -0
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: a0369e5dda330b1490e48a88863baa01df9cadfa570078892cec439f82efaad1
4
- data.tar.gz: bb7d89fa17f782e8163b114b520b8c2c082cf37661b4b6fc4593460dc5958484
3
+ metadata.gz: 61a0b7a7d20933b60a9e0e213016b77f10eae5bb86ecf6825f5f5661f31f5d7d
4
+ data.tar.gz: 967378ee774a35e3a639b1d902648ebf4177ffb7ed43e7228e8432f284f397c1
5
5
  SHA512:
6
- metadata.gz: 0a9f1515d142d11c688f1a9cdbcf9af0c36fa3fc98b240f236554b1067cf2daad1cefa377d18d236674b8fc1b94d64a3acc070c2528c47b68f4d231f29b7648d
7
- data.tar.gz: ae02808dbda4831c7165c987b77c72f9d436bba94e3e28d372f69ceee18fb4971c6cc99a2bf7ce9bd9a9a6e4befcd372515a728bf379294ecc870f2c58f85eb2
6
+ metadata.gz: 32747a4d4a3d40f15e9802280d894b2270d4b78ac0a10859442d0fd3c7ae27a55032a92e072756cb4046964c9d53afcc9586ac954e2b3cf63d057d0a3652e5a8
7
+ data.tar.gz: ca5005286253b7dea1546160ffb00c4b91a8b926512fbcfb7db594171435249e88408c982dc53fed4d90f87e68979bd9c2a2c1975f94495ca77c9ac878b22c1c
@@ -1,3 +1,7 @@
1
+ ## 0.1.2 (2020-08-17)
2
+
3
+ - Updated Faiss to 1.6.3
4
+
1
5
  ## 0.1.1 (2020-03-09)
2
6
 
3
7
  - Vendored library
@@ -1,22 +1,22 @@
1
- Copyright (c) 2020 Andrew Kane
2
-
3
1
  MIT License
4
2
 
5
- Permission is hereby granted, free of charge, to any person obtaining
6
- a copy of this software and associated documentation files (the
7
- "Software"), to deal in the Software without restriction, including
8
- without limitation the rights to use, copy, modify, merge, publish,
9
- distribute, sublicense, and/or sell copies of the Software, and to
10
- permit persons to whom the Software is furnished to do so, subject to
11
- the following conditions:
3
+ Copyright (c) Facebook, Inc. and its affiliates.
4
+ Copyright (c) 2020 Andrew Kane
5
+
6
+ Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ of this software and associated documentation files (the "Software"), to deal
8
+ in the Software without restriction, including without limitation the rights
9
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ copies of the Software, and to permit persons to whom the Software is
11
+ furnished to do so, subject to the following conditions:
12
12
 
13
- The above copyright notice and this permission notice shall be
14
- included in all copies or substantial portions of the Software.
13
+ The above copyright notice and this permission notice shall be included in all
14
+ copies or substantial portions of the Software.
15
15
 
16
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17
- EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18
- MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
19
- NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
20
- LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
21
- OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
22
- WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
16
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ SOFTWARE.
data/README.md CHANGED
@@ -24,7 +24,7 @@ Add this line to your application’s Gemfile:
24
24
  gem 'faiss'
25
25
  ```
26
26
 
27
- Faiss is not available for Windows yet
27
+ Faiss is not available for Windows
28
28
 
29
29
  ## Getting Started
30
30
 
@@ -1,3 +1,3 @@
1
1
  module Faiss
2
- VERSION = "0.1.1"
2
+ VERSION = "0.1.2"
3
3
  end
@@ -10,11 +10,12 @@
10
10
  #include <faiss/Clustering.h>
11
11
  #include <faiss/impl/AuxIndexStructures.h>
12
12
 
13
-
14
13
  #include <cmath>
15
14
  #include <cstdio>
16
15
  #include <cstring>
17
16
 
17
+ #include <omp.h>
18
+
18
19
  #include <faiss/utils/utils.h>
19
20
  #include <faiss/utils/random.h>
20
21
  #include <faiss/utils/distances.h>
@@ -33,7 +34,8 @@ ClusteringParameters::ClusteringParameters ():
33
34
  frozen_centroids(false),
34
35
  min_points_per_centroid(39),
35
36
  max_points_per_centroid(256),
36
- seed(1234)
37
+ seed(1234),
38
+ decode_block_size(32768)
37
39
  {}
38
40
  // 39 corresponds to 10000 / 256 -> to avoid warnings on PQ tests with randu10k
39
41
 
@@ -76,35 +78,233 @@ void Clustering::post_process_centroids ()
76
78
  }
77
79
 
78
80
 
79
- void Clustering::train (idx_t nx, const float *x_in, Index & index) {
81
+ void Clustering::train (idx_t nx, const float *x_in, Index & index,
82
+ const float *weights) {
83
+ train_encoded (nx, reinterpret_cast<const uint8_t *>(x_in), nullptr,
84
+ index, weights);
85
+ }
86
+
87
+
88
+ namespace {
89
+
90
+ using idx_t = Clustering::idx_t;
91
+
92
+ idx_t subsample_training_set(
93
+ const Clustering &clus, idx_t nx, const uint8_t *x,
94
+ size_t line_size, const float * weights,
95
+ uint8_t **x_out,
96
+ float **weights_out
97
+ )
98
+ {
99
+ if (clus.verbose) {
100
+ printf("Sampling a subset of %ld / %ld for training\n",
101
+ clus.k * clus.max_points_per_centroid, nx);
102
+ }
103
+ std::vector<int> perm (nx);
104
+ rand_perm (perm.data (), nx, clus.seed);
105
+ nx = clus.k * clus.max_points_per_centroid;
106
+ uint8_t * x_new = new uint8_t [nx * line_size];
107
+ *x_out = x_new;
108
+ for (idx_t i = 0; i < nx; i++) {
109
+ memcpy (x_new + i * line_size, x + perm[i] * line_size, line_size);
110
+ }
111
+ if (weights) {
112
+ float *weights_new = new float[nx];
113
+ for (idx_t i = 0; i < nx; i++) {
114
+ weights_new[i] = weights[perm[i]];
115
+ }
116
+ *weights_out = weights_new;
117
+ } else {
118
+ *weights_out = nullptr;
119
+ }
120
+ return nx;
121
+ }
122
+
123
+ /** compute centroids as (weighted) sum of training points
124
+ *
125
+ * @param x training vectors, size n * code_size (from codec)
126
+ * @param codec how to decode the vectors (if NULL then cast to float*)
127
+ * @param weights per-training vector weight, size n (or NULL)
128
+ * @param assign nearest centroid for each training vector, size n
129
+ * @param k_frozen do not update the k_frozen first centroids
130
+ * @param centroids centroid vectors (output only), size k * d
131
+ * @param hassign histogram of assignments per centroid (size k),
132
+ * should be 0 on input
133
+ *
134
+ */
135
+
136
+ void compute_centroids (size_t d, size_t k, size_t n,
137
+ size_t k_frozen,
138
+ const uint8_t * x, const Index *codec,
139
+ const int64_t * assign,
140
+ const float * weights,
141
+ float * hassign,
142
+ float * centroids)
143
+ {
144
+ k -= k_frozen;
145
+ centroids += k_frozen * d;
146
+
147
+ memset (centroids, 0, sizeof(*centroids) * d * k);
148
+
149
+ size_t line_size = codec ? codec->sa_code_size() : d * sizeof (float);
150
+
151
+ #pragma omp parallel
152
+ {
153
+ int nt = omp_get_num_threads();
154
+ int rank = omp_get_thread_num();
155
+
156
+ // this thread is taking care of centroids c0:c1
157
+ size_t c0 = (k * rank) / nt;
158
+ size_t c1 = (k * (rank + 1)) / nt;
159
+ std::vector<float> decode_buffer (d);
160
+
161
+ for (size_t i = 0; i < n; i++) {
162
+ int64_t ci = assign[i];
163
+ assert (ci >= 0 && ci < k + k_frozen);
164
+ ci -= k_frozen;
165
+ if (ci >= c0 && ci < c1) {
166
+ float * c = centroids + ci * d;
167
+ const float * xi;
168
+ if (!codec) {
169
+ xi = reinterpret_cast<const float*>(x + i * line_size);
170
+ } else {
171
+ float *xif = decode_buffer.data();
172
+ codec->sa_decode (1, x + i * line_size, xif);
173
+ xi = xif;
174
+ }
175
+ if (weights) {
176
+ float w = weights[i];
177
+ hassign[ci] += w;
178
+ for (size_t j = 0; j < d; j++) {
179
+ c[j] += xi[j] * w;
180
+ }
181
+ } else {
182
+ hassign[ci] += 1.0;
183
+ for (size_t j = 0; j < d; j++) {
184
+ c[j] += xi[j];
185
+ }
186
+ }
187
+ }
188
+ }
189
+
190
+ }
191
+
192
+ #pragma omp parallel for
193
+ for (size_t ci = 0; ci < k; ci++) {
194
+ if (hassign[ci] == 0) {
195
+ continue;
196
+ }
197
+ float norm = 1 / hassign[ci];
198
+ float * c = centroids + ci * d;
199
+ for (size_t j = 0; j < d; j++) {
200
+ c[j] *= norm;
201
+ }
202
+ }
203
+
204
+ }
205
+
206
+ // a bit above machine epsilon for float16
207
+ #define EPS (1 / 1024.)
208
+
209
+ /** Handle empty clusters by splitting larger ones.
210
+ *
211
+ * It works by slightly changing the centroids to make 2 clusters from
212
+ * a single one. Takes the same arguments as compute_centroids.
213
+ *
214
+ * @return nb of splitting operations (larger is worse)
215
+ */
216
+ int split_clusters (size_t d, size_t k, size_t n,
217
+ size_t k_frozen,
218
+ float * hassign,
219
+ float * centroids)
220
+ {
221
+ k -= k_frozen;
222
+ centroids += k_frozen * d;
223
+
224
+ /* Take care of void clusters */
225
+ size_t nsplit = 0;
226
+ RandomGenerator rng (1234);
227
+ for (size_t ci = 0; ci < k; ci++) {
228
+ if (hassign[ci] == 0) { /* need to redefine a centroid */
229
+ size_t cj;
230
+ for (cj = 0; 1; cj = (cj + 1) % k) {
231
+ /* probability to pick this cluster for split */
232
+ float p = (hassign[cj] - 1.0) / (float) (n - k);
233
+ float r = rng.rand_float ();
234
+ if (r < p) {
235
+ break; /* found our cluster to be split */
236
+ }
237
+ }
238
+ memcpy (centroids+ci*d, centroids+cj*d, sizeof(*centroids) * d);
239
+
240
+ /* small symmetric perturbation */
241
+ for (size_t j = 0; j < d; j++) {
242
+ if (j % 2 == 0) {
243
+ centroids[ci * d + j] *= 1 + EPS;
244
+ centroids[cj * d + j] *= 1 - EPS;
245
+ } else {
246
+ centroids[ci * d + j] *= 1 - EPS;
247
+ centroids[cj * d + j] *= 1 + EPS;
248
+ }
249
+ }
250
+
251
+ /* assume even split of the cluster */
252
+ hassign[ci] = hassign[cj] / 2;
253
+ hassign[cj] -= hassign[ci];
254
+ nsplit++;
255
+ }
256
+ }
257
+
258
+ return nsplit;
259
+
260
+ }
261
+
262
+
263
+
264
+ };
265
+
266
+
267
+ void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
268
+ const Index * codec, Index & index,
269
+ const float *weights) {
270
+
80
271
  FAISS_THROW_IF_NOT_FMT (nx >= k,
81
272
  "Number of training points (%ld) should be at least "
82
273
  "as large as number of clusters (%ld)", nx, k);
83
274
 
275
+ FAISS_THROW_IF_NOT_FMT ((!codec || codec->d == d),
276
+ "Codec dimension %d not the same as data dimension %d",
277
+ int(codec->d), int(d));
278
+
279
+ FAISS_THROW_IF_NOT_FMT (index.d == d,
280
+ "Index dimension %d not the same as data dimension %d",
281
+ int(index.d), int(d));
282
+
84
283
  double t0 = getmillisecs();
85
284
 
86
- // yes it is the user's responsibility, but it may spare us some
87
- // hard-to-debug reports.
88
- for (size_t i = 0; i < nx * d; i++) {
89
- FAISS_THROW_IF_NOT_MSG (finite (x_in[i]),
90
- "input contains NaN's or Inf's");
285
+ if (!codec) {
286
+ // Check for NaNs in input data. Normally it is the user's
287
+ // responsibility, but it may spare us some hard-to-debug
288
+ // reports.
289
+ const float *x = reinterpret_cast<const float *>(x_in);
290
+ for (size_t i = 0; i < nx * d; i++) {
291
+ FAISS_THROW_IF_NOT_MSG (finite (x[i]),
292
+ "input contains NaN's or Inf's");
293
+ }
91
294
  }
92
295
 
93
- const float *x = x_in;
94
- ScopeDeleter<float> del1;
296
+ const uint8_t *x = x_in;
297
+ std::unique_ptr<uint8_t []> del1;
298
+ std::unique_ptr<float []> del3;
299
+ size_t line_size = codec ? codec->sa_code_size() : sizeof(float) * d;
95
300
 
96
301
  if (nx > k * max_points_per_centroid) {
97
- if (verbose)
98
- printf("Sampling a subset of %ld / %ld for training\n",
99
- k * max_points_per_centroid, nx);
100
- std::vector<int> perm (nx);
101
- rand_perm (perm.data (), nx, seed);
102
- nx = k * max_points_per_centroid;
103
- float * x_new = new float [nx * d];
104
- for (idx_t i = 0; i < nx; i++)
105
- memcpy (x_new + i * d, x + perm[i] * d, sizeof(x_new[0]) * d);
106
- x = x_new;
107
- del1.set (x);
302
+ uint8_t *x_new;
303
+ float *weights_new;
304
+ nx = subsample_training_set (*this, nx, x, line_size, weights,
305
+ &x_new, &weights_new);
306
+ del1.reset (x_new); x = x_new;
307
+ del3.reset (weights_new); weights = weights_new;
108
308
  } else if (nx < k * min_points_per_centroid) {
109
309
  fprintf (stderr,
110
310
  "WARNING clustering %ld points to %ld centroids: "
@@ -112,41 +312,53 @@ void Clustering::train (idx_t nx, const float *x_in, Index & index) {
112
312
  nx, k, idx_t(k) * min_points_per_centroid);
113
313
  }
114
314
 
115
-
116
315
  if (nx == k) {
316
+ // this is a corner case, just copy training set to clusters
117
317
  if (verbose) {
118
318
  printf("Number of training points (%ld) same as number of "
119
319
  "clusters, just copying\n", nx);
120
320
  }
121
- // this is a corner case, just copy training set to clusters
122
321
  centroids.resize (d * k);
123
- memcpy (centroids.data(), x_in, sizeof (*x_in) * d * k);
322
+ if (!codec) {
323
+ memcpy (centroids.data(), x_in, sizeof (float) * d * k);
324
+ } else {
325
+ codec->sa_decode (nx, x_in, centroids.data());
326
+ }
327
+
328
+ // one fake iteration...
329
+ ClusteringIterationStats stats = { 0.0, 0.0, 0.0, 1.0, 0 };
330
+ iteration_stats.push_back (stats);
331
+
124
332
  index.reset();
125
- index.add(k, x_in);
333
+ index.add(k, centroids.data());
126
334
  return;
127
335
  }
128
336
 
129
337
 
130
- if (verbose)
338
+ if (verbose) {
131
339
  printf("Clustering %d points in %ldD to %ld clusters, "
132
340
  "redo %d times, %d iterations\n",
133
341
  int(nx), d, k, nredo, niter);
342
+ if (codec) {
343
+ printf("Input data encoded in %ld bytes per vector\n",
344
+ codec->sa_code_size ());
345
+ }
346
+ }
134
347
 
135
- idx_t * assign = new idx_t[nx];
136
- ScopeDeleter<idx_t> del (assign);
137
- float * dis = new float[nx];
138
- ScopeDeleter<float> del2(dis);
348
+ std::unique_ptr<idx_t []> assign(new idx_t[nx]);
349
+ std::unique_ptr<float []> dis(new float[nx]);
139
350
 
140
- // for redo
351
+ // remember best iteration for redo
141
352
  float best_err = HUGE_VALF;
142
- std::vector<float> best_obj;
353
+ std::vector<ClusteringIterationStats> best_obj;
143
354
  std::vector<float> best_centroids;
144
355
 
145
356
  // support input centroids
146
357
 
147
358
  FAISS_THROW_IF_NOT_MSG (
148
359
  centroids.size() % d == 0,
149
- "size of provided input centroids not a multiple of dimension");
360
+ "size of provided input centroids not a multiple of dimension"
361
+ );
150
362
 
151
363
  size_t n_input_centroids = centroids.size() / d;
152
364
 
@@ -162,23 +374,36 @@ void Clustering::train (idx_t nx, const float *x_in, Index & index) {
162
374
  }
163
375
  t0 = getmillisecs();
164
376
 
377
+ // temporary buffer to decode vectors during the optimization
378
+ std::vector<float> decode_buffer
379
+ (codec ? d * decode_block_size : 0);
380
+
165
381
  for (int redo = 0; redo < nredo; redo++) {
166
382
 
167
383
  if (verbose && nredo > 1) {
168
384
  printf("Outer iteration %d / %d\n", redo, nredo);
169
385
  }
170
386
 
171
- // initialize remaining centroids with random points from the dataset
387
+ // initialize (remaining) centroids with random points from the dataset
172
388
  centroids.resize (d * k);
173
389
  std::vector<int> perm (nx);
174
390
 
175
391
  rand_perm (perm.data(), nx, seed + 1 + redo * 15486557L);
176
- for (int i = n_input_centroids; i < k ; i++)
177
- memcpy (&centroids[i * d], x + perm[i] * d,
178
- d * sizeof (float));
392
+
393
+ if (!codec) {
394
+ for (int i = n_input_centroids; i < k ; i++) {
395
+ memcpy (&centroids[i * d], x + perm[i] * line_size, line_size);
396
+ }
397
+ } else {
398
+ for (int i = n_input_centroids; i < k ; i++) {
399
+ codec->sa_decode (1, x + perm[i] * line_size, &centroids[i * d]);
400
+ }
401
+ }
179
402
 
180
403
  post_process_centroids ();
181
404
 
405
+ // prepare the index
406
+
182
407
  if (index.ntotal != 0) {
183
408
  index.reset();
184
409
  }
@@ -188,49 +413,89 @@ void Clustering::train (idx_t nx, const float *x_in, Index & index) {
188
413
  }
189
414
 
190
415
  index.add (k, centroids.data());
416
+
417
+ // k-means iterations
418
+
191
419
  float err = 0;
192
420
  for (int i = 0; i < niter; i++) {
193
421
  double t0s = getmillisecs();
194
- index.search (nx, x, 1, dis, assign);
422
+
423
+ if (!codec) {
424
+ index.search (nx, reinterpret_cast<const float *>(x), 1,
425
+ dis.get(), assign.get());
426
+ } else {
427
+ // search by blocks of decode_block_size vectors
428
+ size_t code_size = codec->sa_code_size ();
429
+ for (size_t i0 = 0; i0 < nx; i0 += decode_block_size) {
430
+ size_t i1 = i0 + decode_block_size;
431
+ if (i1 > nx) { i1 = nx; }
432
+ codec->sa_decode (i1 - i0, x + code_size * i0,
433
+ decode_buffer.data ());
434
+ index.search (i1 - i0, decode_buffer.data (), 1,
435
+ dis.get() + i0, assign.get() + i0);
436
+ }
437
+ }
438
+
195
439
  InterruptCallback::check();
196
440
  t_search_tot += getmillisecs() - t0s;
197
441
 
442
+ // accumulate error
198
443
  err = 0;
199
- for (int j = 0; j < nx; j++)
444
+ for (int j = 0; j < nx; j++) {
200
445
  err += dis[j];
201
- obj.push_back (err);
446
+ }
447
+
448
+ // update the centroids
449
+ std::vector<float> hassign (k);
202
450
 
203
- int nsplit = km_update_centroids (
204
- x, centroids.data(),
205
- assign, d, k, nx, frozen_centroids ? n_input_centroids : 0);
451
+ size_t k_frozen = frozen_centroids ? n_input_centroids : 0;
452
+ compute_centroids (
453
+ d, k, nx, k_frozen,
454
+ x, codec, assign.get(), weights,
455
+ hassign.data(), centroids.data()
456
+ );
457
+
458
+ int nsplit = split_clusters (
459
+ d, k, nx, k_frozen,
460
+ hassign.data(), centroids.data()
461
+ );
462
+
463
+ // collect statistics
464
+ ClusteringIterationStats stats =
465
+ { err, (getmillisecs() - t0) / 1000.0,
466
+ t_search_tot / 1000, imbalance_factor (nx, k, assign.get()),
467
+ nsplit };
468
+ iteration_stats.push_back(stats);
206
469
 
207
470
  if (verbose) {
208
471
  printf (" Iteration %d (%.2f s, search %.2f s): "
209
472
  "objective=%g imbalance=%.3f nsplit=%d \r",
210
- i, (getmillisecs() - t0) / 1000.0,
211
- t_search_tot / 1000,
212
- err, imbalance_factor (nx, k, assign),
213
- nsplit);
473
+ i, stats.time, stats.time_search, stats.obj,
474
+ stats.imbalance_factor, nsplit);
214
475
  fflush (stdout);
215
476
  }
216
477
 
217
478
  post_process_centroids ();
218
479
 
480
+ // add centroids to index for the next iteration (or for output)
481
+
219
482
  index.reset ();
220
- if (update_index)
483
+ if (update_index) {
221
484
  index.train (k, centroids.data());
485
+ }
222
486
 
223
- assert (index.ntotal == 0);
224
487
  index.add (k, centroids.data());
225
488
  InterruptCallback::check ();
226
489
  }
490
+
227
491
  if (verbose) printf("\n");
228
492
  if (nredo > 1) {
229
493
  if (err < best_err) {
230
- if (verbose)
494
+ if (verbose) {
231
495
  printf ("Objective improved: keep new clusters\n");
496
+ }
232
497
  best_centroids = centroids;
233
- best_obj = obj;
498
+ best_obj = iteration_stats;
234
499
  best_err = err;
235
500
  }
236
501
  index.reset ();
@@ -238,7 +503,7 @@ void Clustering::train (idx_t nx, const float *x_in, Index & index) {
238
503
  }
239
504
  if (nredo > 1) {
240
505
  centroids = best_centroids;
241
- obj = best_obj;
506
+ iteration_stats = best_obj;
242
507
  index.reset();
243
508
  index.add(k, best_centroids.data());
244
509
  }
@@ -255,7 +520,7 @@ float kmeans_clustering (size_t d, size_t n, size_t k,
255
520
  IndexFlatL2 index (d);
256
521
  clus.train (n, x, index);
257
522
  memcpy(centroids, clus.centroids.data(), sizeof(*centroids) * d * k);
258
- return clus.obj.back();
523
+ return clus.iteration_stats.back().obj;
259
524
  }
260
525
 
261
526
  } // namespace faiss