faiss 0.1.1 → 0.1.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/LICENSE.txt +18 -18
- data/README.md +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/Clustering.cpp +318 -53
- data/vendor/faiss/Clustering.h +39 -11
- data/vendor/faiss/DirectMap.cpp +267 -0
- data/vendor/faiss/DirectMap.h +120 -0
- data/vendor/faiss/IVFlib.cpp +24 -4
- data/vendor/faiss/IVFlib.h +4 -0
- data/vendor/faiss/Index.h +5 -24
- data/vendor/faiss/Index2Layer.cpp +0 -1
- data/vendor/faiss/IndexBinary.h +7 -3
- data/vendor/faiss/IndexBinaryFlat.cpp +5 -0
- data/vendor/faiss/IndexBinaryFlat.h +3 -0
- data/vendor/faiss/IndexBinaryHash.cpp +492 -0
- data/vendor/faiss/IndexBinaryHash.h +116 -0
- data/vendor/faiss/IndexBinaryIVF.cpp +160 -107
- data/vendor/faiss/IndexBinaryIVF.h +14 -4
- data/vendor/faiss/IndexFlat.h +2 -1
- data/vendor/faiss/IndexHNSW.cpp +68 -16
- data/vendor/faiss/IndexHNSW.h +3 -3
- data/vendor/faiss/IndexIVF.cpp +72 -76
- data/vendor/faiss/IndexIVF.h +24 -5
- data/vendor/faiss/IndexIVFFlat.cpp +19 -54
- data/vendor/faiss/IndexIVFFlat.h +1 -11
- data/vendor/faiss/IndexIVFPQ.cpp +49 -26
- data/vendor/faiss/IndexIVFPQ.h +9 -10
- data/vendor/faiss/IndexIVFPQR.cpp +2 -2
- data/vendor/faiss/IndexIVFSpectralHash.cpp +2 -2
- data/vendor/faiss/IndexLSH.h +4 -1
- data/vendor/faiss/IndexPreTransform.cpp +0 -1
- data/vendor/faiss/IndexScalarQuantizer.cpp +8 -1
- data/vendor/faiss/InvertedLists.cpp +0 -2
- data/vendor/faiss/MetaIndexes.cpp +0 -1
- data/vendor/faiss/MetricType.h +36 -0
- data/vendor/faiss/c_api/Clustering_c.cpp +13 -7
- data/vendor/faiss/c_api/Clustering_c.h +11 -5
- data/vendor/faiss/c_api/IndexIVF_c.cpp +7 -0
- data/vendor/faiss/c_api/IndexIVF_c.h +7 -0
- data/vendor/faiss/c_api/IndexPreTransform_c.cpp +21 -0
- data/vendor/faiss/c_api/IndexPreTransform_c.h +32 -0
- data/vendor/faiss/demos/demo_weighted_kmeans.cpp +185 -0
- data/vendor/faiss/gpu/GpuCloner.cpp +4 -0
- data/vendor/faiss/gpu/GpuClonerOptions.cpp +1 -1
- data/vendor/faiss/gpu/GpuDistance.h +93 -0
- data/vendor/faiss/gpu/GpuIndex.h +7 -0
- data/vendor/faiss/gpu/GpuIndexFlat.h +0 -10
- data/vendor/faiss/gpu/GpuIndexIVF.h +1 -0
- data/vendor/faiss/gpu/StandardGpuResources.cpp +8 -0
- data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +49 -27
- data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +110 -2
- data/vendor/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/impl/AuxIndexStructures.cpp +17 -0
- data/vendor/faiss/impl/AuxIndexStructures.h +14 -3
- data/vendor/faiss/impl/HNSW.cpp +0 -1
- data/vendor/faiss/impl/PolysemousTraining.h +5 -5
- data/vendor/faiss/impl/ProductQuantizer-inl.h +138 -0
- data/vendor/faiss/impl/ProductQuantizer.cpp +1 -113
- data/vendor/faiss/impl/ProductQuantizer.h +42 -47
- data/vendor/faiss/impl/index_read.cpp +103 -7
- data/vendor/faiss/impl/index_write.cpp +101 -5
- data/vendor/faiss/impl/io.cpp +111 -1
- data/vendor/faiss/impl/io.h +38 -0
- data/vendor/faiss/index_factory.cpp +0 -1
- data/vendor/faiss/tests/test_merge.cpp +0 -1
- data/vendor/faiss/tests/test_pq_encoding.cpp +6 -6
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +1 -0
- data/vendor/faiss/utils/distances.cpp +4 -5
- data/vendor/faiss/utils/distances_simd.cpp +0 -1
- data/vendor/faiss/utils/hamming.cpp +85 -3
- data/vendor/faiss/utils/hamming.h +20 -0
- data/vendor/faiss/utils/utils.cpp +0 -96
- data/vendor/faiss/utils/utils.h +0 -15
- metadata +11 -3
- data/lib/faiss/ext.bundle +0 -0
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 61a0b7a7d20933b60a9e0e213016b77f10eae5bb86ecf6825f5f5661f31f5d7d
+  data.tar.gz: 967378ee774a35e3a639b1d902648ebf4177ffb7ed43e7228e8432f284f397c1
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 32747a4d4a3d40f15e9802280d894b2270d4b78ac0a10859442d0fd3c7ae27a55032a92e072756cb4046964c9d53afcc9586ac954e2b3cf63d057d0a3652e5a8
+  data.tar.gz: ca5005286253b7dea1546160ffb00c4b91a8b926512fbcfb7db594171435249e88408c982dc53fed4d90f87e68979bd9c2a2c1975f94495ca77c9ac878b22c1c
data/CHANGELOG.md
CHANGED
data/LICENSE.txt
CHANGED
@@ -1,22 +1,22 @@
-Copyright (c) 2020 Andrew Kane
-
 MIT License
 
-
-
-
-
-
-
-
+Copyright (c) Facebook, Inc. and its affiliates.
+Copyright (c) 2020 Andrew Kane
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
 
-The above copyright notice and this permission notice shall be
-
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
 
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-
-
-
-
-
-OF
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
data/README.md
CHANGED
data/lib/faiss/version.rb
CHANGED
data/vendor/faiss/Clustering.cpp
CHANGED
@@ -10,11 +10,12 @@
 #include <faiss/Clustering.h>
 #include <faiss/impl/AuxIndexStructures.h>
 
-
 #include <cmath>
 #include <cstdio>
 #include <cstring>
 
+#include <omp.h>
+
 #include <faiss/utils/utils.h>
 #include <faiss/utils/random.h>
 #include <faiss/utils/distances.h>
@@ -33,7 +34,8 @@ ClusteringParameters::ClusteringParameters ():
     frozen_centroids(false),
     min_points_per_centroid(39),
     max_points_per_centroid(256),
-    seed(1234)
+    seed(1234),
+    decode_block_size(32768)
 {}
 // 39 corresponds to 10000 / 256 -> to avoid warnings on PQ tests with randu10k
 
@@ -76,35 +78,233 @@ void Clustering::post_process_centroids ()
 }
 
 
-void Clustering::train (idx_t nx, const float *x_in, Index & index) {
+void Clustering::train (idx_t nx, const float *x_in, Index & index,
+                        const float *weights) {
+    train_encoded (nx, reinterpret_cast<const uint8_t *>(x_in), nullptr,
+                   index, weights);
+}
+
+
+namespace {
+
+using idx_t = Clustering::idx_t;
+
+idx_t subsample_training_set(
+        const Clustering &clus, idx_t nx, const uint8_t *x,
+        size_t line_size, const float * weights,
+        uint8_t **x_out,
+        float **weights_out
+)
+{
+    if (clus.verbose) {
+        printf("Sampling a subset of %ld / %ld for training\n",
+               clus.k * clus.max_points_per_centroid, nx);
+    }
+    std::vector<int> perm (nx);
+    rand_perm (perm.data (), nx, clus.seed);
+    nx = clus.k * clus.max_points_per_centroid;
+    uint8_t * x_new = new uint8_t [nx * line_size];
+    *x_out = x_new;
+    for (idx_t i = 0; i < nx; i++) {
+        memcpy (x_new + i * line_size, x + perm[i] * line_size, line_size);
+    }
+    if (weights) {
+        float *weights_new = new float[nx];
+        for (idx_t i = 0; i < nx; i++) {
+            weights_new[i] = weights[perm[i]];
+        }
+        *weights_out = weights_new;
+    } else {
+        *weights_out = nullptr;
+    }
+    return nx;
+}
+
+/** compute centroids as (weighted) sum of training points
+ *
+ * @param x            training vectors, size n * code_size (from codec)
+ * @param codec        how to decode the vectors (if NULL then cast to float*)
+ * @param weights      per-training vector weight, size n (or NULL)
+ * @param assign       nearest centroid for each training vector, size n
+ * @param k_frozen     do not update the k_frozen first centroids
+ * @param centroids    centroid vectors (output only), size k * d
+ * @param hassign      histogram of assignments per centroid (size k),
+ *                     should be 0 on input
+ *
+ */
+
+void compute_centroids (size_t d, size_t k, size_t n,
+                        size_t k_frozen,
+                        const uint8_t * x, const Index *codec,
+                        const int64_t * assign,
+                        const float * weights,
+                        float * hassign,
+                        float * centroids)
+{
+    k -= k_frozen;
+    centroids += k_frozen * d;
+
+    memset (centroids, 0, sizeof(*centroids) * d * k);
+
+    size_t line_size = codec ? codec->sa_code_size() : d * sizeof (float);
+
+#pragma omp parallel
+    {
+        int nt = omp_get_num_threads();
+        int rank = omp_get_thread_num();
+
+        // this thread is taking care of centroids c0:c1
+        size_t c0 = (k * rank) / nt;
+        size_t c1 = (k * (rank + 1)) / nt;
+        std::vector<float> decode_buffer (d);
+
+        for (size_t i = 0; i < n; i++) {
+            int64_t ci = assign[i];
+            assert (ci >= 0 && ci < k + k_frozen);
+            ci -= k_frozen;
+            if (ci >= c0 && ci < c1) {
+                float * c = centroids + ci * d;
+                const float * xi;
+                if (!codec) {
+                    xi = reinterpret_cast<const float*>(x + i * line_size);
+                } else {
+                    float *xif = decode_buffer.data();
+                    codec->sa_decode (1, x + i * line_size, xif);
+                    xi = xif;
+                }
+                if (weights) {
+                    float w = weights[i];
+                    hassign[ci] += w;
+                    for (size_t j = 0; j < d; j++) {
+                        c[j] += xi[j] * w;
+                    }
+                } else {
+                    hassign[ci] += 1.0;
+                    for (size_t j = 0; j < d; j++) {
+                        c[j] += xi[j];
+                    }
+                }
+            }
+        }
+
+    }
+
+#pragma omp parallel for
+    for (size_t ci = 0; ci < k; ci++) {
+        if (hassign[ci] == 0) {
+            continue;
+        }
+        float norm = 1 / hassign[ci];
+        float * c = centroids + ci * d;
+        for (size_t j = 0; j < d; j++) {
+            c[j] *= norm;
+        }
+    }
+
+}
+
+// a bit above machine epsilon for float16
+#define EPS (1 / 1024.)
+
+/** Handle empty clusters by splitting larger ones.
+ *
+ * It works by slightly changing the centroids to make 2 clusters from
+ * a single one. Takes the same arguements as compute_centroids.
+ *
+ * @return nb of spliting operations (larger is worse)
+ */
+int split_clusters (size_t d, size_t k, size_t n,
+                    size_t k_frozen,
+                    float * hassign,
+                    float * centroids)
+{
+    k -= k_frozen;
+    centroids += k_frozen * d;
+
+    /* Take care of void clusters */
+    size_t nsplit = 0;
+    RandomGenerator rng (1234);
+    for (size_t ci = 0; ci < k; ci++) {
+        if (hassign[ci] == 0) { /* need to redefine a centroid */
+            size_t cj;
+            for (cj = 0; 1; cj = (cj + 1) % k) {
+                /* probability to pick this cluster for split */
+                float p = (hassign[cj] - 1.0) / (float) (n - k);
+                float r = rng.rand_float ();
+                if (r < p) {
+                    break; /* found our cluster to be split */
+                }
+            }
+            memcpy (centroids+ci*d, centroids+cj*d, sizeof(*centroids) * d);
+
+            /* small symmetric pertubation */
+            for (size_t j = 0; j < d; j++) {
+                if (j % 2 == 0) {
+                    centroids[ci * d + j] *= 1 + EPS;
+                    centroids[cj * d + j] *= 1 - EPS;
+                } else {
+                    centroids[ci * d + j] *= 1 - EPS;
+                    centroids[cj * d + j] *= 1 + EPS;
+                }
+            }
+
+            /* assume even split of the cluster */
+            hassign[ci] = hassign[cj] / 2;
+            hassign[cj] -= hassign[ci];
+            nsplit++;
+        }
+    }
+
+    return nsplit;
+
+}
+
+
+
+};
+
+
+void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
+                                const Index * codec, Index & index,
+                                const float *weights) {
+
     FAISS_THROW_IF_NOT_FMT (nx >= k,
             "Number of training points (%ld) should be at least "
             "as large as number of clusters (%ld)", nx, k);
 
+    FAISS_THROW_IF_NOT_FMT ((!codec || codec->d == d),
+            "Codec dimension %d not the same as data dimension %d",
+            int(codec->d), int(d));
+
+    FAISS_THROW_IF_NOT_FMT (index.d == d,
+            "Index dimension %d not the same as data dimension %d",
+            int(index.d), int(d));
+
     double t0 = getmillisecs();
 
-
-
-
-
-
+    if (!codec) {
+        // Check for NaNs in input data. Normally it is the user's
+        // responsibility, but it may spare us some hard-to-debug
+        // reports.
+        const float *x = reinterpret_cast<const float *>(x_in);
+        for (size_t i = 0; i < nx * d; i++) {
+            FAISS_THROW_IF_NOT_MSG (finite (x[i]),
+                                    "input contains NaN's or Inf's");
+        }
     }
 
-    const
-
+    const uint8_t *x = x_in;
+    std::unique_ptr<uint8_t []> del1;
+    std::unique_ptr<float []> del3;
+    size_t line_size = codec ? codec->sa_code_size() : sizeof(float) * d;
 
     if (nx > k * max_points_per_centroid) {
-
-
-
-
-
-
-        float * x_new = new float [nx * d];
-        for (idx_t i = 0; i < nx; i++)
-            memcpy (x_new + i * d, x + perm[i] * d, sizeof(x_new[0]) * d);
-        x = x_new;
-        del1.set (x);
+        uint8_t *x_new;
+        float *weights_new;
+        nx = subsample_training_set (*this, nx, x, line_size, weights,
+                                     &x_new, &weights_new);
+        del1.reset (x_new); x = x_new;
+        del3.reset (weights_new); weights = weights_new;
     } else if (nx < k * min_points_per_centroid) {
         fprintf (stderr,
                  "WARNING clustering %ld points to %ld centroids: "
@@ -112,41 +312,53 @@ void Clustering::train (idx_t nx, const float *x_in, Index & index) {
                  nx, k, idx_t(k) * min_points_per_centroid);
     }
 
-
     if (nx == k) {
+        // this is a corner case, just copy training set to clusters
         if (verbose) {
             printf("Number of training points (%ld) same as number of "
                    "clusters, just copying\n", nx);
         }
-        // this is a corner case, just copy training set to clusters
         centroids.resize (d * k);
-
+        if (!codec) {
+            memcpy (centroids.data(), x_in, sizeof (float) * d * k);
+        } else {
+            codec->sa_decode (nx, x_in, centroids.data());
+        }
+
+        // one fake iteration...
+        ClusteringIterationStats stats = { 0.0, 0.0, 0.0, 1.0, 0 };
+        iteration_stats.push_back (stats);
+
         index.reset();
-        index.add(k,
+        index.add(k, centroids.data());
         return;
     }
 
 
-    if (verbose)
+    if (verbose) {
         printf("Clustering %d points in %ldD to %ld clusters, "
                "redo %d times, %d iterations\n",
               int(nx), d, k, nredo, niter);
+        if (codec) {
+            printf("Input data encoded in %ld bytes per vector\n",
+                   codec->sa_code_size ());
+        }
+    }
 
-    idx_t
-
-    float * dis = new float[nx];
-    ScopeDeleter<float> del2(dis);
+    std::unique_ptr<idx_t []> assign(new idx_t[nx]);
+    std::unique_ptr<float []> dis(new float[nx]);
 
-    // for redo
+    // remember best iteration for redo
     float best_err = HUGE_VALF;
-    std::vector<
+    std::vector<ClusteringIterationStats> best_obj;
     std::vector<float> best_centroids;
 
     // support input centroids
 
     FAISS_THROW_IF_NOT_MSG (
        centroids.size() % d == 0,
-       "size of provided input centroids not a multiple of dimension");
+       "size of provided input centroids not a multiple of dimension"
+    );
 
     size_t n_input_centroids = centroids.size() / d;
 
@@ -162,23 +374,36 @@ void Clustering::train (idx_t nx, const float *x_in, Index & index) {
     }
     t0 = getmillisecs();
 
+    // temporary buffer to decode vectors during the optimization
+    std::vector<float> decode_buffer
+        (codec ? d * decode_block_size : 0);
+
     for (int redo = 0; redo < nredo; redo++) {
 
         if (verbose && nredo > 1) {
             printf("Outer iteration %d / %d\n", redo, nredo);
         }
 
-        // initialize remaining centroids with random points from the dataset
+        // initialize (remaining) centroids with random points from the dataset
         centroids.resize (d * k);
         std::vector<int> perm (nx);
 
         rand_perm (perm.data(), nx, seed + 1 + redo * 15486557L);
-
-
-
+
+        if (!codec) {
+            for (int i = n_input_centroids; i < k ; i++) {
+                memcpy (&centroids[i * d], x + perm[i] * line_size, line_size);
+            }
+        } else {
+            for (int i = n_input_centroids; i < k ; i++) {
+                codec->sa_decode (1, x + perm[i] * line_size, &centroids[i * d]);
+            }
+        }
 
         post_process_centroids ();
 
+        // prepare the index
+
         if (index.ntotal != 0) {
             index.reset();
         }
@@ -188,49 +413,89 @@ void Clustering::train (idx_t nx, const float *x_in, Index & index) {
         }
 
         index.add (k, centroids.data());
+
+        // k-means iterations
+
         float err = 0;
         for (int i = 0; i < niter; i++) {
             double t0s = getmillisecs();
-
+
+            if (!codec) {
+                index.search (nx, reinterpret_cast<const float *>(x), 1,
+                              dis.get(), assign.get());
+            } else {
+                // search by blocks of decode_block_size vectors
+                size_t code_size = codec->sa_code_size ();
+                for (size_t i0 = 0; i0 < nx; i0 += decode_block_size) {
+                    size_t i1 = i0 + decode_block_size;
+                    if (i1 > nx) { i1 = nx; }
+                    codec->sa_decode (i1 - i0, x + code_size * i0,
+                                      decode_buffer.data ());
+                    index.search (i1 - i0, decode_buffer.data (), 1,
+                                  dis.get() + i0, assign.get() + i0);
+                }
+            }
+
             InterruptCallback::check();
             t_search_tot += getmillisecs() - t0s;
 
+            // accumulate error
             err = 0;
-            for (int j = 0; j < nx; j++)
+            for (int j = 0; j < nx; j++) {
                 err += dis[j];
-
+            }
+
+            // update the centroids
+            std::vector<float> hassign (k);
 
-
-
-
+            size_t k_frozen = frozen_centroids ? n_input_centroids : 0;
+            compute_centroids (
+                d, k, nx, k_frozen,
+                x, codec, assign.get(), weights,
+                hassign.data(), centroids.data()
+            );
+
+            int nsplit = split_clusters (
+                d, k, nx, k_frozen,
+                hassign.data(), centroids.data()
+            );
+
+            // collect statistics
+            ClusteringIterationStats stats =
+                { err, (getmillisecs() - t0) / 1000.0,
+                  t_search_tot / 1000, imbalance_factor (nx, k, assign.get()),
+                  nsplit };
+            iteration_stats.push_back(stats);
 
             if (verbose) {
                 printf ("  Iteration %d (%.2f s, search %.2f s): "
                         "objective=%g imbalance=%.3f nsplit=%d \r",
-                        i,
-
-                        err, imbalance_factor (nx, k, assign),
-                        nsplit);
+                        i, stats.time, stats.time_search, stats.obj,
+                        stats.imbalance_factor, nsplit);
                 fflush (stdout);
             }
 
             post_process_centroids ();
 
+            // add centroids to index for the next iteration (or for output)
+
             index.reset ();
-            if (update_index)
+            if (update_index) {
                 index.train (k, centroids.data());
+            }
 
-            assert (index.ntotal == 0);
             index.add (k, centroids.data());
             InterruptCallback::check ();
         }
+
         if (verbose) printf("\n");
         if (nredo > 1) {
             if (err < best_err) {
-                if (verbose)
+                if (verbose) {
                     printf ("Objective improved: keep new clusters\n");
+                }
                 best_centroids = centroids;
-                best_obj =
+                best_obj = iteration_stats;
                 best_err = err;
             }
             index.reset ();
@@ -238,7 +503,7 @@ void Clustering::train (idx_t nx, const float *x_in, Index & index) {
     }
     if (nredo > 1) {
         centroids = best_centroids;
-
+        iteration_stats = best_obj;
         index.reset();
         index.add(k, best_centroids.data());
     }
@@ -255,7 +520,7 @@ float kmeans_clustering (size_t d, size_t n, size_t k,
     IndexFlatL2 index (d);
     clus.train (n, x, index);
     memcpy(centroids, clus.centroids.data(), sizeof(*centroids) * d * k);
-    return clus.
+    return clus.iteration_stats.back().obj;
 }
 
 } // namespace faiss
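Note on the Clustering.cpp changes above: Clustering::train now takes an optional per-point weights array and forwards to the new Clustering::train_encoded (which can also cluster codec-compressed vectors), and per-iteration statistics are recorded in iteration_stats rather than a plain objective vector. A minimal sketch of the weighted path, using only signatures visible in this diff; the helper name weighted_kmeans and the data/weights buffers are illustrative, not part of the package:

#include <vector>

#include <faiss/Clustering.h>
#include <faiss/IndexFlat.h>

// Sketch: cluster n d-dimensional float vectors into k centroids, giving each
// training point an importance weight; pass nullptr to keep the unweighted behaviour.
std::vector<float> weighted_kmeans(int d, int k, size_t n,
                                   const float* data, const float* weights) {
    faiss::Clustering clus(d, k);
    clus.verbose = true;                  // prints the per-iteration stats shown above
    faiss::IndexFlatL2 index(d);          // assignment index, as in kmeans_clustering
    clus.train(n, data, index, weights);  // weights argument added in this release
    // clus.iteration_stats.back().obj now holds the final objective
    return clus.centroids;                // k * d centroid coordinates
}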