faiss 0.1.1 → 0.1.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/LICENSE.txt +18 -18
- data/README.md +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/Clustering.cpp +318 -53
- data/vendor/faiss/Clustering.h +39 -11
- data/vendor/faiss/DirectMap.cpp +267 -0
- data/vendor/faiss/DirectMap.h +120 -0
- data/vendor/faiss/IVFlib.cpp +24 -4
- data/vendor/faiss/IVFlib.h +4 -0
- data/vendor/faiss/Index.h +5 -24
- data/vendor/faiss/Index2Layer.cpp +0 -1
- data/vendor/faiss/IndexBinary.h +7 -3
- data/vendor/faiss/IndexBinaryFlat.cpp +5 -0
- data/vendor/faiss/IndexBinaryFlat.h +3 -0
- data/vendor/faiss/IndexBinaryHash.cpp +492 -0
- data/vendor/faiss/IndexBinaryHash.h +116 -0
- data/vendor/faiss/IndexBinaryIVF.cpp +160 -107
- data/vendor/faiss/IndexBinaryIVF.h +14 -4
- data/vendor/faiss/IndexFlat.h +2 -1
- data/vendor/faiss/IndexHNSW.cpp +68 -16
- data/vendor/faiss/IndexHNSW.h +3 -3
- data/vendor/faiss/IndexIVF.cpp +72 -76
- data/vendor/faiss/IndexIVF.h +24 -5
- data/vendor/faiss/IndexIVFFlat.cpp +19 -54
- data/vendor/faiss/IndexIVFFlat.h +1 -11
- data/vendor/faiss/IndexIVFPQ.cpp +49 -26
- data/vendor/faiss/IndexIVFPQ.h +9 -10
- data/vendor/faiss/IndexIVFPQR.cpp +2 -2
- data/vendor/faiss/IndexIVFSpectralHash.cpp +2 -2
- data/vendor/faiss/IndexLSH.h +4 -1
- data/vendor/faiss/IndexPreTransform.cpp +0 -1
- data/vendor/faiss/IndexScalarQuantizer.cpp +8 -1
- data/vendor/faiss/InvertedLists.cpp +0 -2
- data/vendor/faiss/MetaIndexes.cpp +0 -1
- data/vendor/faiss/MetricType.h +36 -0
- data/vendor/faiss/c_api/Clustering_c.cpp +13 -7
- data/vendor/faiss/c_api/Clustering_c.h +11 -5
- data/vendor/faiss/c_api/IndexIVF_c.cpp +7 -0
- data/vendor/faiss/c_api/IndexIVF_c.h +7 -0
- data/vendor/faiss/c_api/IndexPreTransform_c.cpp +21 -0
- data/vendor/faiss/c_api/IndexPreTransform_c.h +32 -0
- data/vendor/faiss/demos/demo_weighted_kmeans.cpp +185 -0
- data/vendor/faiss/gpu/GpuCloner.cpp +4 -0
- data/vendor/faiss/gpu/GpuClonerOptions.cpp +1 -1
- data/vendor/faiss/gpu/GpuDistance.h +93 -0
- data/vendor/faiss/gpu/GpuIndex.h +7 -0
- data/vendor/faiss/gpu/GpuIndexFlat.h +0 -10
- data/vendor/faiss/gpu/GpuIndexIVF.h +1 -0
- data/vendor/faiss/gpu/StandardGpuResources.cpp +8 -0
- data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +49 -27
- data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +110 -2
- data/vendor/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/impl/AuxIndexStructures.cpp +17 -0
- data/vendor/faiss/impl/AuxIndexStructures.h +14 -3
- data/vendor/faiss/impl/HNSW.cpp +0 -1
- data/vendor/faiss/impl/PolysemousTraining.h +5 -5
- data/vendor/faiss/impl/ProductQuantizer-inl.h +138 -0
- data/vendor/faiss/impl/ProductQuantizer.cpp +1 -113
- data/vendor/faiss/impl/ProductQuantizer.h +42 -47
- data/vendor/faiss/impl/index_read.cpp +103 -7
- data/vendor/faiss/impl/index_write.cpp +101 -5
- data/vendor/faiss/impl/io.cpp +111 -1
- data/vendor/faiss/impl/io.h +38 -0
- data/vendor/faiss/index_factory.cpp +0 -1
- data/vendor/faiss/tests/test_merge.cpp +0 -1
- data/vendor/faiss/tests/test_pq_encoding.cpp +6 -6
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +1 -0
- data/vendor/faiss/utils/distances.cpp +4 -5
- data/vendor/faiss/utils/distances_simd.cpp +0 -1
- data/vendor/faiss/utils/hamming.cpp +85 -3
- data/vendor/faiss/utils/hamming.h +20 -0
- data/vendor/faiss/utils/utils.cpp +0 -96
- data/vendor/faiss/utils/utils.h +0 -15
- metadata +11 -3
- data/lib/faiss/ext.bundle +0 -0
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 61a0b7a7d20933b60a9e0e213016b77f10eae5bb86ecf6825f5f5661f31f5d7d
|
4
|
+
data.tar.gz: 967378ee774a35e3a639b1d902648ebf4177ffb7ed43e7228e8432f284f397c1
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 32747a4d4a3d40f15e9802280d894b2270d4b78ac0a10859442d0fd3c7ae27a55032a92e072756cb4046964c9d53afcc9586ac954e2b3cf63d057d0a3652e5a8
|
7
|
+
data.tar.gz: ca5005286253b7dea1546160ffb00c4b91a8b926512fbcfb7db594171435249e88408c982dc53fed4d90f87e68979bd9c2a2c1975f94495ca77c9ac878b22c1c
|
data/CHANGELOG.md
CHANGED
data/LICENSE.txt
CHANGED
@@ -1,22 +1,22 @@
|
|
1
|
-
Copyright (c) 2020 Andrew Kane
|
2
|
-
|
3
1
|
MIT License
|
4
2
|
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
3
|
+
Copyright (c) Facebook, Inc. and its affiliates.
|
4
|
+
Copyright (c) 2020 Andrew Kane
|
5
|
+
|
6
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
7
|
+
of this software and associated documentation files (the "Software"), to deal
|
8
|
+
in the Software without restriction, including without limitation the rights
|
9
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
10
|
+
copies of the Software, and to permit persons to whom the Software is
|
11
|
+
furnished to do so, subject to the following conditions:
|
12
12
|
|
13
|
-
The above copyright notice and this permission notice shall be
|
14
|
-
|
13
|
+
The above copyright notice and this permission notice shall be included in all
|
14
|
+
copies or substantial portions of the Software.
|
15
15
|
|
16
|
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
OF
|
22
|
-
|
16
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
17
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
18
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
19
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
20
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
21
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
22
|
+
SOFTWARE.
|
data/README.md
CHANGED
data/lib/faiss/version.rb
CHANGED
data/vendor/faiss/Clustering.cpp
CHANGED
@@ -10,11 +10,12 @@
|
|
10
10
|
#include <faiss/Clustering.h>
|
11
11
|
#include <faiss/impl/AuxIndexStructures.h>
|
12
12
|
|
13
|
-
|
14
13
|
#include <cmath>
|
15
14
|
#include <cstdio>
|
16
15
|
#include <cstring>
|
17
16
|
|
17
|
+
#include <omp.h>
|
18
|
+
|
18
19
|
#include <faiss/utils/utils.h>
|
19
20
|
#include <faiss/utils/random.h>
|
20
21
|
#include <faiss/utils/distances.h>
|
@@ -33,7 +34,8 @@ ClusteringParameters::ClusteringParameters ():
|
|
33
34
|
frozen_centroids(false),
|
34
35
|
min_points_per_centroid(39),
|
35
36
|
max_points_per_centroid(256),
|
36
|
-
seed(1234)
|
37
|
+
seed(1234),
|
38
|
+
decode_block_size(32768)
|
37
39
|
{}
|
38
40
|
// 39 corresponds to 10000 / 256 -> to avoid warnings on PQ tests with randu10k
|
39
41
|
|
@@ -76,35 +78,233 @@ void Clustering::post_process_centroids ()
|
|
76
78
|
}
|
77
79
|
|
78
80
|
|
79
|
-
void Clustering::train (idx_t nx, const float *x_in, Index & index
|
81
|
+
void Clustering::train (idx_t nx, const float *x_in, Index & index,
|
82
|
+
const float *weights) {
|
83
|
+
train_encoded (nx, reinterpret_cast<const uint8_t *>(x_in), nullptr,
|
84
|
+
index, weights);
|
85
|
+
}
|
86
|
+
|
87
|
+
|
88
|
+
namespace {
|
89
|
+
|
90
|
+
using idx_t = Clustering::idx_t;
|
91
|
+
|
92
|
+
idx_t subsample_training_set(
|
93
|
+
const Clustering &clus, idx_t nx, const uint8_t *x,
|
94
|
+
size_t line_size, const float * weights,
|
95
|
+
uint8_t **x_out,
|
96
|
+
float **weights_out
|
97
|
+
)
|
98
|
+
{
|
99
|
+
if (clus.verbose) {
|
100
|
+
printf("Sampling a subset of %ld / %ld for training\n",
|
101
|
+
clus.k * clus.max_points_per_centroid, nx);
|
102
|
+
}
|
103
|
+
std::vector<int> perm (nx);
|
104
|
+
rand_perm (perm.data (), nx, clus.seed);
|
105
|
+
nx = clus.k * clus.max_points_per_centroid;
|
106
|
+
uint8_t * x_new = new uint8_t [nx * line_size];
|
107
|
+
*x_out = x_new;
|
108
|
+
for (idx_t i = 0; i < nx; i++) {
|
109
|
+
memcpy (x_new + i * line_size, x + perm[i] * line_size, line_size);
|
110
|
+
}
|
111
|
+
if (weights) {
|
112
|
+
float *weights_new = new float[nx];
|
113
|
+
for (idx_t i = 0; i < nx; i++) {
|
114
|
+
weights_new[i] = weights[perm[i]];
|
115
|
+
}
|
116
|
+
*weights_out = weights_new;
|
117
|
+
} else {
|
118
|
+
*weights_out = nullptr;
|
119
|
+
}
|
120
|
+
return nx;
|
121
|
+
}
|
122
|
+
|
123
|
+
/** compute centroids as (weighted) sum of training points
|
124
|
+
*
|
125
|
+
* @param x training vectors, size n * code_size (from codec)
|
126
|
+
* @param codec how to decode the vectors (if NULL then cast to float*)
|
127
|
+
* @param weights per-training vector weight, size n (or NULL)
|
128
|
+
* @param assign nearest centroid for each training vector, size n
|
129
|
+
* @param k_frozen do not update the k_frozen first centroids
|
130
|
+
* @param centroids centroid vectors (output only), size k * d
|
131
|
+
* @param hassign histogram of assignments per centroid (size k),
|
132
|
+
* should be 0 on input
|
133
|
+
*
|
134
|
+
*/
|
135
|
+
|
136
|
+
void compute_centroids (size_t d, size_t k, size_t n,
|
137
|
+
size_t k_frozen,
|
138
|
+
const uint8_t * x, const Index *codec,
|
139
|
+
const int64_t * assign,
|
140
|
+
const float * weights,
|
141
|
+
float * hassign,
|
142
|
+
float * centroids)
|
143
|
+
{
|
144
|
+
k -= k_frozen;
|
145
|
+
centroids += k_frozen * d;
|
146
|
+
|
147
|
+
memset (centroids, 0, sizeof(*centroids) * d * k);
|
148
|
+
|
149
|
+
size_t line_size = codec ? codec->sa_code_size() : d * sizeof (float);
|
150
|
+
|
151
|
+
#pragma omp parallel
|
152
|
+
{
|
153
|
+
int nt = omp_get_num_threads();
|
154
|
+
int rank = omp_get_thread_num();
|
155
|
+
|
156
|
+
// this thread is taking care of centroids c0:c1
|
157
|
+
size_t c0 = (k * rank) / nt;
|
158
|
+
size_t c1 = (k * (rank + 1)) / nt;
|
159
|
+
std::vector<float> decode_buffer (d);
|
160
|
+
|
161
|
+
for (size_t i = 0; i < n; i++) {
|
162
|
+
int64_t ci = assign[i];
|
163
|
+
assert (ci >= 0 && ci < k + k_frozen);
|
164
|
+
ci -= k_frozen;
|
165
|
+
if (ci >= c0 && ci < c1) {
|
166
|
+
float * c = centroids + ci * d;
|
167
|
+
const float * xi;
|
168
|
+
if (!codec) {
|
169
|
+
xi = reinterpret_cast<const float*>(x + i * line_size);
|
170
|
+
} else {
|
171
|
+
float *xif = decode_buffer.data();
|
172
|
+
codec->sa_decode (1, x + i * line_size, xif);
|
173
|
+
xi = xif;
|
174
|
+
}
|
175
|
+
if (weights) {
|
176
|
+
float w = weights[i];
|
177
|
+
hassign[ci] += w;
|
178
|
+
for (size_t j = 0; j < d; j++) {
|
179
|
+
c[j] += xi[j] * w;
|
180
|
+
}
|
181
|
+
} else {
|
182
|
+
hassign[ci] += 1.0;
|
183
|
+
for (size_t j = 0; j < d; j++) {
|
184
|
+
c[j] += xi[j];
|
185
|
+
}
|
186
|
+
}
|
187
|
+
}
|
188
|
+
}
|
189
|
+
|
190
|
+
}
|
191
|
+
|
192
|
+
#pragma omp parallel for
|
193
|
+
for (size_t ci = 0; ci < k; ci++) {
|
194
|
+
if (hassign[ci] == 0) {
|
195
|
+
continue;
|
196
|
+
}
|
197
|
+
float norm = 1 / hassign[ci];
|
198
|
+
float * c = centroids + ci * d;
|
199
|
+
for (size_t j = 0; j < d; j++) {
|
200
|
+
c[j] *= norm;
|
201
|
+
}
|
202
|
+
}
|
203
|
+
|
204
|
+
}
|
205
|
+
|
206
|
+
// a bit above machine epsilon for float16
|
207
|
+
#define EPS (1 / 1024.)
|
208
|
+
|
209
|
+
/** Handle empty clusters by splitting larger ones.
|
210
|
+
*
|
211
|
+
* It works by slightly changing the centroids to make 2 clusters from
|
212
|
+
* a single one. Takes the same arguements as compute_centroids.
|
213
|
+
*
|
214
|
+
* @return nb of spliting operations (larger is worse)
|
215
|
+
*/
|
216
|
+
int split_clusters (size_t d, size_t k, size_t n,
|
217
|
+
size_t k_frozen,
|
218
|
+
float * hassign,
|
219
|
+
float * centroids)
|
220
|
+
{
|
221
|
+
k -= k_frozen;
|
222
|
+
centroids += k_frozen * d;
|
223
|
+
|
224
|
+
/* Take care of void clusters */
|
225
|
+
size_t nsplit = 0;
|
226
|
+
RandomGenerator rng (1234);
|
227
|
+
for (size_t ci = 0; ci < k; ci++) {
|
228
|
+
if (hassign[ci] == 0) { /* need to redefine a centroid */
|
229
|
+
size_t cj;
|
230
|
+
for (cj = 0; 1; cj = (cj + 1) % k) {
|
231
|
+
/* probability to pick this cluster for split */
|
232
|
+
float p = (hassign[cj] - 1.0) / (float) (n - k);
|
233
|
+
float r = rng.rand_float ();
|
234
|
+
if (r < p) {
|
235
|
+
break; /* found our cluster to be split */
|
236
|
+
}
|
237
|
+
}
|
238
|
+
memcpy (centroids+ci*d, centroids+cj*d, sizeof(*centroids) * d);
|
239
|
+
|
240
|
+
/* small symmetric pertubation */
|
241
|
+
for (size_t j = 0; j < d; j++) {
|
242
|
+
if (j % 2 == 0) {
|
243
|
+
centroids[ci * d + j] *= 1 + EPS;
|
244
|
+
centroids[cj * d + j] *= 1 - EPS;
|
245
|
+
} else {
|
246
|
+
centroids[ci * d + j] *= 1 - EPS;
|
247
|
+
centroids[cj * d + j] *= 1 + EPS;
|
248
|
+
}
|
249
|
+
}
|
250
|
+
|
251
|
+
/* assume even split of the cluster */
|
252
|
+
hassign[ci] = hassign[cj] / 2;
|
253
|
+
hassign[cj] -= hassign[ci];
|
254
|
+
nsplit++;
|
255
|
+
}
|
256
|
+
}
|
257
|
+
|
258
|
+
return nsplit;
|
259
|
+
|
260
|
+
}
|
261
|
+
|
262
|
+
|
263
|
+
|
264
|
+
};
|
265
|
+
|
266
|
+
|
267
|
+
void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
|
268
|
+
const Index * codec, Index & index,
|
269
|
+
const float *weights) {
|
270
|
+
|
80
271
|
FAISS_THROW_IF_NOT_FMT (nx >= k,
|
81
272
|
"Number of training points (%ld) should be at least "
|
82
273
|
"as large as number of clusters (%ld)", nx, k);
|
83
274
|
|
275
|
+
FAISS_THROW_IF_NOT_FMT ((!codec || codec->d == d),
|
276
|
+
"Codec dimension %d not the same as data dimension %d",
|
277
|
+
int(codec->d), int(d));
|
278
|
+
|
279
|
+
FAISS_THROW_IF_NOT_FMT (index.d == d,
|
280
|
+
"Index dimension %d not the same as data dimension %d",
|
281
|
+
int(index.d), int(d));
|
282
|
+
|
84
283
|
double t0 = getmillisecs();
|
85
284
|
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
285
|
+
if (!codec) {
|
286
|
+
// Check for NaNs in input data. Normally it is the user's
|
287
|
+
// responsibility, but it may spare us some hard-to-debug
|
288
|
+
// reports.
|
289
|
+
const float *x = reinterpret_cast<const float *>(x_in);
|
290
|
+
for (size_t i = 0; i < nx * d; i++) {
|
291
|
+
FAISS_THROW_IF_NOT_MSG (finite (x[i]),
|
292
|
+
"input contains NaN's or Inf's");
|
293
|
+
}
|
91
294
|
}
|
92
295
|
|
93
|
-
const
|
94
|
-
|
296
|
+
const uint8_t *x = x_in;
|
297
|
+
std::unique_ptr<uint8_t []> del1;
|
298
|
+
std::unique_ptr<float []> del3;
|
299
|
+
size_t line_size = codec ? codec->sa_code_size() : sizeof(float) * d;
|
95
300
|
|
96
301
|
if (nx > k * max_points_per_centroid) {
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
float * x_new = new float [nx * d];
|
104
|
-
for (idx_t i = 0; i < nx; i++)
|
105
|
-
memcpy (x_new + i * d, x + perm[i] * d, sizeof(x_new[0]) * d);
|
106
|
-
x = x_new;
|
107
|
-
del1.set (x);
|
302
|
+
uint8_t *x_new;
|
303
|
+
float *weights_new;
|
304
|
+
nx = subsample_training_set (*this, nx, x, line_size, weights,
|
305
|
+
&x_new, &weights_new);
|
306
|
+
del1.reset (x_new); x = x_new;
|
307
|
+
del3.reset (weights_new); weights = weights_new;
|
108
308
|
} else if (nx < k * min_points_per_centroid) {
|
109
309
|
fprintf (stderr,
|
110
310
|
"WARNING clustering %ld points to %ld centroids: "
|
@@ -112,41 +312,53 @@ void Clustering::train (idx_t nx, const float *x_in, Index & index) {
|
|
112
312
|
nx, k, idx_t(k) * min_points_per_centroid);
|
113
313
|
}
|
114
314
|
|
115
|
-
|
116
315
|
if (nx == k) {
|
316
|
+
// this is a corner case, just copy training set to clusters
|
117
317
|
if (verbose) {
|
118
318
|
printf("Number of training points (%ld) same as number of "
|
119
319
|
"clusters, just copying\n", nx);
|
120
320
|
}
|
121
|
-
// this is a corner case, just copy training set to clusters
|
122
321
|
centroids.resize (d * k);
|
123
|
-
|
322
|
+
if (!codec) {
|
323
|
+
memcpy (centroids.data(), x_in, sizeof (float) * d * k);
|
324
|
+
} else {
|
325
|
+
codec->sa_decode (nx, x_in, centroids.data());
|
326
|
+
}
|
327
|
+
|
328
|
+
// one fake iteration...
|
329
|
+
ClusteringIterationStats stats = { 0.0, 0.0, 0.0, 1.0, 0 };
|
330
|
+
iteration_stats.push_back (stats);
|
331
|
+
|
124
332
|
index.reset();
|
125
|
-
index.add(k,
|
333
|
+
index.add(k, centroids.data());
|
126
334
|
return;
|
127
335
|
}
|
128
336
|
|
129
337
|
|
130
|
-
if (verbose)
|
338
|
+
if (verbose) {
|
131
339
|
printf("Clustering %d points in %ldD to %ld clusters, "
|
132
340
|
"redo %d times, %d iterations\n",
|
133
341
|
int(nx), d, k, nredo, niter);
|
342
|
+
if (codec) {
|
343
|
+
printf("Input data encoded in %ld bytes per vector\n",
|
344
|
+
codec->sa_code_size ());
|
345
|
+
}
|
346
|
+
}
|
134
347
|
|
135
|
-
idx_t
|
136
|
-
|
137
|
-
float * dis = new float[nx];
|
138
|
-
ScopeDeleter<float> del2(dis);
|
348
|
+
std::unique_ptr<idx_t []> assign(new idx_t[nx]);
|
349
|
+
std::unique_ptr<float []> dis(new float[nx]);
|
139
350
|
|
140
|
-
// for redo
|
351
|
+
// remember best iteration for redo
|
141
352
|
float best_err = HUGE_VALF;
|
142
|
-
std::vector<
|
353
|
+
std::vector<ClusteringIterationStats> best_obj;
|
143
354
|
std::vector<float> best_centroids;
|
144
355
|
|
145
356
|
// support input centroids
|
146
357
|
|
147
358
|
FAISS_THROW_IF_NOT_MSG (
|
148
359
|
centroids.size() % d == 0,
|
149
|
-
"size of provided input centroids not a multiple of dimension"
|
360
|
+
"size of provided input centroids not a multiple of dimension"
|
361
|
+
);
|
150
362
|
|
151
363
|
size_t n_input_centroids = centroids.size() / d;
|
152
364
|
|
@@ -162,23 +374,36 @@ void Clustering::train (idx_t nx, const float *x_in, Index & index) {
|
|
162
374
|
}
|
163
375
|
t0 = getmillisecs();
|
164
376
|
|
377
|
+
// temporary buffer to decode vectors during the optimization
|
378
|
+
std::vector<float> decode_buffer
|
379
|
+
(codec ? d * decode_block_size : 0);
|
380
|
+
|
165
381
|
for (int redo = 0; redo < nredo; redo++) {
|
166
382
|
|
167
383
|
if (verbose && nredo > 1) {
|
168
384
|
printf("Outer iteration %d / %d\n", redo, nredo);
|
169
385
|
}
|
170
386
|
|
171
|
-
// initialize remaining centroids with random points from the dataset
|
387
|
+
// initialize (remaining) centroids with random points from the dataset
|
172
388
|
centroids.resize (d * k);
|
173
389
|
std::vector<int> perm (nx);
|
174
390
|
|
175
391
|
rand_perm (perm.data(), nx, seed + 1 + redo * 15486557L);
|
176
|
-
|
177
|
-
|
178
|
-
|
392
|
+
|
393
|
+
if (!codec) {
|
394
|
+
for (int i = n_input_centroids; i < k ; i++) {
|
395
|
+
memcpy (&centroids[i * d], x + perm[i] * line_size, line_size);
|
396
|
+
}
|
397
|
+
} else {
|
398
|
+
for (int i = n_input_centroids; i < k ; i++) {
|
399
|
+
codec->sa_decode (1, x + perm[i] * line_size, &centroids[i * d]);
|
400
|
+
}
|
401
|
+
}
|
179
402
|
|
180
403
|
post_process_centroids ();
|
181
404
|
|
405
|
+
// prepare the index
|
406
|
+
|
182
407
|
if (index.ntotal != 0) {
|
183
408
|
index.reset();
|
184
409
|
}
|
@@ -188,49 +413,89 @@ void Clustering::train (idx_t nx, const float *x_in, Index & index) {
|
|
188
413
|
}
|
189
414
|
|
190
415
|
index.add (k, centroids.data());
|
416
|
+
|
417
|
+
// k-means iterations
|
418
|
+
|
191
419
|
float err = 0;
|
192
420
|
for (int i = 0; i < niter; i++) {
|
193
421
|
double t0s = getmillisecs();
|
194
|
-
|
422
|
+
|
423
|
+
if (!codec) {
|
424
|
+
index.search (nx, reinterpret_cast<const float *>(x), 1,
|
425
|
+
dis.get(), assign.get());
|
426
|
+
} else {
|
427
|
+
// search by blocks of decode_block_size vectors
|
428
|
+
size_t code_size = codec->sa_code_size ();
|
429
|
+
for (size_t i0 = 0; i0 < nx; i0 += decode_block_size) {
|
430
|
+
size_t i1 = i0 + decode_block_size;
|
431
|
+
if (i1 > nx) { i1 = nx; }
|
432
|
+
codec->sa_decode (i1 - i0, x + code_size * i0,
|
433
|
+
decode_buffer.data ());
|
434
|
+
index.search (i1 - i0, decode_buffer.data (), 1,
|
435
|
+
dis.get() + i0, assign.get() + i0);
|
436
|
+
}
|
437
|
+
}
|
438
|
+
|
195
439
|
InterruptCallback::check();
|
196
440
|
t_search_tot += getmillisecs() - t0s;
|
197
441
|
|
442
|
+
// accumulate error
|
198
443
|
err = 0;
|
199
|
-
for (int j = 0; j < nx; j++)
|
444
|
+
for (int j = 0; j < nx; j++) {
|
200
445
|
err += dis[j];
|
201
|
-
|
446
|
+
}
|
447
|
+
|
448
|
+
// update the centroids
|
449
|
+
std::vector<float> hassign (k);
|
202
450
|
|
203
|
-
|
204
|
-
|
205
|
-
|
451
|
+
size_t k_frozen = frozen_centroids ? n_input_centroids : 0;
|
452
|
+
compute_centroids (
|
453
|
+
d, k, nx, k_frozen,
|
454
|
+
x, codec, assign.get(), weights,
|
455
|
+
hassign.data(), centroids.data()
|
456
|
+
);
|
457
|
+
|
458
|
+
int nsplit = split_clusters (
|
459
|
+
d, k, nx, k_frozen,
|
460
|
+
hassign.data(), centroids.data()
|
461
|
+
);
|
462
|
+
|
463
|
+
// collect statistics
|
464
|
+
ClusteringIterationStats stats =
|
465
|
+
{ err, (getmillisecs() - t0) / 1000.0,
|
466
|
+
t_search_tot / 1000, imbalance_factor (nx, k, assign.get()),
|
467
|
+
nsplit };
|
468
|
+
iteration_stats.push_back(stats);
|
206
469
|
|
207
470
|
if (verbose) {
|
208
471
|
printf (" Iteration %d (%.2f s, search %.2f s): "
|
209
472
|
"objective=%g imbalance=%.3f nsplit=%d \r",
|
210
|
-
i,
|
211
|
-
|
212
|
-
err, imbalance_factor (nx, k, assign),
|
213
|
-
nsplit);
|
473
|
+
i, stats.time, stats.time_search, stats.obj,
|
474
|
+
stats.imbalance_factor, nsplit);
|
214
475
|
fflush (stdout);
|
215
476
|
}
|
216
477
|
|
217
478
|
post_process_centroids ();
|
218
479
|
|
480
|
+
// add centroids to index for the next iteration (or for output)
|
481
|
+
|
219
482
|
index.reset ();
|
220
|
-
if (update_index)
|
483
|
+
if (update_index) {
|
221
484
|
index.train (k, centroids.data());
|
485
|
+
}
|
222
486
|
|
223
|
-
assert (index.ntotal == 0);
|
224
487
|
index.add (k, centroids.data());
|
225
488
|
InterruptCallback::check ();
|
226
489
|
}
|
490
|
+
|
227
491
|
if (verbose) printf("\n");
|
228
492
|
if (nredo > 1) {
|
229
493
|
if (err < best_err) {
|
230
|
-
if (verbose)
|
494
|
+
if (verbose) {
|
231
495
|
printf ("Objective improved: keep new clusters\n");
|
496
|
+
}
|
232
497
|
best_centroids = centroids;
|
233
|
-
best_obj =
|
498
|
+
best_obj = iteration_stats;
|
234
499
|
best_err = err;
|
235
500
|
}
|
236
501
|
index.reset ();
|
@@ -238,7 +503,7 @@ void Clustering::train (idx_t nx, const float *x_in, Index & index) {
|
|
238
503
|
}
|
239
504
|
if (nredo > 1) {
|
240
505
|
centroids = best_centroids;
|
241
|
-
|
506
|
+
iteration_stats = best_obj;
|
242
507
|
index.reset();
|
243
508
|
index.add(k, best_centroids.data());
|
244
509
|
}
|
@@ -255,7 +520,7 @@ float kmeans_clustering (size_t d, size_t n, size_t k,
|
|
255
520
|
IndexFlatL2 index (d);
|
256
521
|
clus.train (n, x, index);
|
257
522
|
memcpy(centroids, clus.centroids.data(), sizeof(*centroids) * d * k);
|
258
|
-
return clus.
|
523
|
+
return clus.iteration_stats.back().obj;
|
259
524
|
}
|
260
525
|
|
261
526
|
} // namespace faiss
|