faiss 0.1.1 → 0.1.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/LICENSE.txt +18 -18
- data/README.md +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/Clustering.cpp +318 -53
- data/vendor/faiss/Clustering.h +39 -11
- data/vendor/faiss/DirectMap.cpp +267 -0
- data/vendor/faiss/DirectMap.h +120 -0
- data/vendor/faiss/IVFlib.cpp +24 -4
- data/vendor/faiss/IVFlib.h +4 -0
- data/vendor/faiss/Index.h +5 -24
- data/vendor/faiss/Index2Layer.cpp +0 -1
- data/vendor/faiss/IndexBinary.h +7 -3
- data/vendor/faiss/IndexBinaryFlat.cpp +5 -0
- data/vendor/faiss/IndexBinaryFlat.h +3 -0
- data/vendor/faiss/IndexBinaryHash.cpp +492 -0
- data/vendor/faiss/IndexBinaryHash.h +116 -0
- data/vendor/faiss/IndexBinaryIVF.cpp +160 -107
- data/vendor/faiss/IndexBinaryIVF.h +14 -4
- data/vendor/faiss/IndexFlat.h +2 -1
- data/vendor/faiss/IndexHNSW.cpp +68 -16
- data/vendor/faiss/IndexHNSW.h +3 -3
- data/vendor/faiss/IndexIVF.cpp +72 -76
- data/vendor/faiss/IndexIVF.h +24 -5
- data/vendor/faiss/IndexIVFFlat.cpp +19 -54
- data/vendor/faiss/IndexIVFFlat.h +1 -11
- data/vendor/faiss/IndexIVFPQ.cpp +49 -26
- data/vendor/faiss/IndexIVFPQ.h +9 -10
- data/vendor/faiss/IndexIVFPQR.cpp +2 -2
- data/vendor/faiss/IndexIVFSpectralHash.cpp +2 -2
- data/vendor/faiss/IndexLSH.h +4 -1
- data/vendor/faiss/IndexPreTransform.cpp +0 -1
- data/vendor/faiss/IndexScalarQuantizer.cpp +8 -1
- data/vendor/faiss/InvertedLists.cpp +0 -2
- data/vendor/faiss/MetaIndexes.cpp +0 -1
- data/vendor/faiss/MetricType.h +36 -0
- data/vendor/faiss/c_api/Clustering_c.cpp +13 -7
- data/vendor/faiss/c_api/Clustering_c.h +11 -5
- data/vendor/faiss/c_api/IndexIVF_c.cpp +7 -0
- data/vendor/faiss/c_api/IndexIVF_c.h +7 -0
- data/vendor/faiss/c_api/IndexPreTransform_c.cpp +21 -0
- data/vendor/faiss/c_api/IndexPreTransform_c.h +32 -0
- data/vendor/faiss/demos/demo_weighted_kmeans.cpp +185 -0
- data/vendor/faiss/gpu/GpuCloner.cpp +4 -0
- data/vendor/faiss/gpu/GpuClonerOptions.cpp +1 -1
- data/vendor/faiss/gpu/GpuDistance.h +93 -0
- data/vendor/faiss/gpu/GpuIndex.h +7 -0
- data/vendor/faiss/gpu/GpuIndexFlat.h +0 -10
- data/vendor/faiss/gpu/GpuIndexIVF.h +1 -0
- data/vendor/faiss/gpu/StandardGpuResources.cpp +8 -0
- data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +49 -27
- data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +110 -2
- data/vendor/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/impl/AuxIndexStructures.cpp +17 -0
- data/vendor/faiss/impl/AuxIndexStructures.h +14 -3
- data/vendor/faiss/impl/HNSW.cpp +0 -1
- data/vendor/faiss/impl/PolysemousTraining.h +5 -5
- data/vendor/faiss/impl/ProductQuantizer-inl.h +138 -0
- data/vendor/faiss/impl/ProductQuantizer.cpp +1 -113
- data/vendor/faiss/impl/ProductQuantizer.h +42 -47
- data/vendor/faiss/impl/index_read.cpp +103 -7
- data/vendor/faiss/impl/index_write.cpp +101 -5
- data/vendor/faiss/impl/io.cpp +111 -1
- data/vendor/faiss/impl/io.h +38 -0
- data/vendor/faiss/index_factory.cpp +0 -1
- data/vendor/faiss/tests/test_merge.cpp +0 -1
- data/vendor/faiss/tests/test_pq_encoding.cpp +6 -6
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +1 -0
- data/vendor/faiss/utils/distances.cpp +4 -5
- data/vendor/faiss/utils/distances_simd.cpp +0 -1
- data/vendor/faiss/utils/hamming.cpp +85 -3
- data/vendor/faiss/utils/hamming.h +20 -0
- data/vendor/faiss/utils/utils.cpp +0 -96
- data/vendor/faiss/utils/utils.h +0 -15
- metadata +11 -3
- data/lib/faiss/ext.bundle +0 -0
@@ -144,8 +144,8 @@ void IndexIVFPQR::search_preassigned (idx_t n, const float *x, idx_t k,
|
|
144
144
|
|
145
145
|
if (sl == -1) continue;
|
146
146
|
|
147
|
-
int list_no = sl
|
148
|
-
int ofs = sl
|
147
|
+
int list_no = lo_listno(sl);
|
148
|
+
int ofs = lo_offset(sl);
|
149
149
|
|
150
150
|
assert (list_no >= 0 && list_no < nlist);
|
151
151
|
assert (ofs >= 0 && ofs < invlists->list_size (list_no));
|
@@ -270,7 +270,7 @@ struct IVFScanner: InvertedListScanner {
|
|
270
270
|
|
271
271
|
if (dis < simi [0]) {
|
272
272
|
maxheap_pop (k, simi, idxi);
|
273
|
-
int64_t id = store_pairs ? (list_no
|
273
|
+
int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
|
274
274
|
maxheap_push (k, simi, idxi, dis, id);
|
275
275
|
nup++;
|
276
276
|
}
|
@@ -288,7 +288,7 @@ struct IVFScanner: InvertedListScanner {
|
|
288
288
|
for (size_t j = 0; j < list_size; j++) {
|
289
289
|
float dis = hc.hamming (codes);
|
290
290
|
if (dis < radius) {
|
291
|
-
int64_t id = store_pairs ? (list_no
|
291
|
+
int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
|
292
292
|
res.add (dis, id);
|
293
293
|
}
|
294
294
|
codes += code_size;
|
data/vendor/faiss/IndexLSH.h
CHANGED
@@ -69,7 +69,10 @@ struct IndexLSH:Index {
|
|
69
69
|
|
70
70
|
IndexLSH ();
|
71
71
|
|
72
|
-
/* standalone codec interface
|
72
|
+
/* standalone codec interface.
|
73
|
+
*
|
74
|
+
* The vectors are decoded to +/- 1 (not 0, 1) */
|
75
|
+
|
73
76
|
size_t sa_code_size () const override;
|
74
77
|
|
75
78
|
void sa_encode (idx_t n, const float *x,
|
@@ -253,6 +253,8 @@ void IndexIVFScalarQuantizer::add_with_ids
|
|
253
253
|
size_t nadd = 0;
|
254
254
|
std::unique_ptr<ScalarQuantizer::Quantizer> squant(sq.select_quantizer ());
|
255
255
|
|
256
|
+
DirectMapAdd dm_add (direct_map, n, xids);
|
257
|
+
|
256
258
|
#pragma omp parallel reduction(+: nadd)
|
257
259
|
{
|
258
260
|
std::vector<float> residual (d);
|
@@ -275,13 +277,18 @@ void IndexIVFScalarQuantizer::add_with_ids
|
|
275
277
|
memset (one_code.data(), 0, code_size);
|
276
278
|
squant->encode_vector (xi, one_code.data());
|
277
279
|
|
278
|
-
invlists->add_entry (list_no, id, one_code.data());
|
280
|
+
size_t ofs = invlists->add_entry (list_no, id, one_code.data());
|
279
281
|
|
282
|
+
dm_add.add (i, list_no, ofs);
|
280
283
|
nadd++;
|
281
284
|
|
285
|
+
} else if (rank == 0 && list_no == -1) {
|
286
|
+
dm_add.add (i, -1, 0);
|
282
287
|
}
|
283
288
|
}
|
284
289
|
}
|
290
|
+
|
291
|
+
|
285
292
|
ntotal += n;
|
286
293
|
}
|
287
294
|
|
@@ -0,0 +1,36 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
// -*- c++ -*-
|
9
|
+
|
10
|
+
#ifndef FAISS_METRIC_TYPE_H
#define FAISS_METRIC_TYPE_H

namespace faiss {

/// The metric space for vector comparison for Faiss indices and algorithms.
///
/// Most algorithms support both inner product and L2, with the flat
/// (brute-force) indices supporting additional metric types for vector
/// comparison.
///
/// The numeric values are fixed explicitly (0, 1, then sequential; the
/// scipy-style metrics start at 20). NOTE(review): these values may be
/// persisted or exchanged across language bindings — confirm before
/// renumbering or inserting new enumerators.
enum MetricType {
    METRIC_INNER_PRODUCT = 0, ///< maximum inner product search
    METRIC_L2 = 1,            ///< squared L2 search
    METRIC_L1,                ///< L1 (aka cityblock)
    METRIC_Linf,              ///< infinity distance
    METRIC_Lp,                ///< L_p distance, p is given by a faiss::Index
                              ///< metric_arg

    /// some additional metrics defined in scipy.spatial.distance
    METRIC_Canberra = 20,
    METRIC_BrayCurtis,
    METRIC_JensenShannon,
};

}  // namespace faiss

#endif
|
@@ -19,6 +19,7 @@ extern "C" {
|
|
19
19
|
using faiss::Clustering;
|
20
20
|
using faiss::ClusteringParameters;
|
21
21
|
using faiss::Index;
|
22
|
+
using faiss::ClusteringIterationStats;
|
22
23
|
|
23
24
|
DEFINE_GETTER(Clustering, int, niter)
|
24
25
|
DEFINE_GETTER(Clustering, int, nredo)
|
@@ -38,6 +39,12 @@ DEFINE_GETTER(Clustering, size_t, d)
|
|
38
39
|
/// getter for k
|
39
40
|
DEFINE_GETTER(Clustering, size_t, k)
|
40
41
|
|
42
|
+
DEFINE_GETTER(ClusteringIterationStats, float, obj)
|
43
|
+
DEFINE_GETTER(ClusteringIterationStats, double, time)
|
44
|
+
DEFINE_GETTER(ClusteringIterationStats, double, time_search)
|
45
|
+
DEFINE_GETTER(ClusteringIterationStats, double, imbalance_factor)
|
46
|
+
DEFINE_GETTER(ClusteringIterationStats, int, nsplit)
|
47
|
+
|
41
48
|
void faiss_ClusteringParameters_init(FaissClusteringParameters* params) {
|
42
49
|
ClusteringParameters d;
|
43
50
|
params->frozen_centroids = d.frozen_centroids;
|
@@ -78,13 +85,12 @@ void faiss_Clustering_centroids(
|
|
78
85
|
}
|
79
86
|
}
|
80
87
|
|
81
|
-
/// getter for
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
*obj = v.data();
|
88
|
+
/// getter for iteration stats
|
89
|
+
void faiss_Clustering_iteration_stats(
|
90
|
+
FaissClustering* clustering, FaissClusteringIterationStats** iteration_stats, size_t* size) {
|
91
|
+
std::vector<ClusteringIterationStats>& v = reinterpret_cast<Clustering*>(clustering)->iteration_stats;
|
92
|
+
if (iteration_stats) {
|
93
|
+
*iteration_stats = reinterpret_cast<FaissClusteringIterationStats*>(v.data());
|
88
94
|
}
|
89
95
|
if (size) {
|
90
96
|
*size = v.size();
|
@@ -47,7 +47,7 @@ void faiss_ClusteringParameters_init(FaissClusteringParameters* params);
|
|
47
47
|
* points to the centroids. Therefore, at each iteration the centroids
|
48
48
|
* are added to the index.
|
49
49
|
*
|
50
|
-
* On output, the
|
50
|
+
* On output, the centroids table is set to the latest version
|
51
51
|
* of the centroids and they are also added to the index. If the
|
52
52
|
* centroids table it is not empty on input, it is also used for
|
53
53
|
* initialization.
|
@@ -75,14 +75,20 @@ FAISS_DECLARE_GETTER(Clustering, size_t, d)
|
|
75
75
|
/// getter for k
|
76
76
|
FAISS_DECLARE_GETTER(Clustering, size_t, k)
|
77
77
|
|
78
|
+
FAISS_DECLARE_CLASS(ClusteringIterationStats)
|
79
|
+
FAISS_DECLARE_GETTER(ClusteringIterationStats, float, obj)
|
80
|
+
FAISS_DECLARE_GETTER(ClusteringIterationStats, double, time)
|
81
|
+
FAISS_DECLARE_GETTER(ClusteringIterationStats, double, time_search)
|
82
|
+
FAISS_DECLARE_GETTER(ClusteringIterationStats, double, imbalance_factor)
|
83
|
+
FAISS_DECLARE_GETTER(ClusteringIterationStats, int, nsplit)
|
84
|
+
|
78
85
|
/// getter for centroids (size = k * d)
|
79
86
|
void faiss_Clustering_centroids(
|
80
87
|
FaissClustering* clustering, float** centroids, size_t* size);
|
81
88
|
|
82
|
-
/// getter for
|
83
|
-
|
84
|
-
|
85
|
-
FaissClustering* clustering, float** obj, size_t* size);
|
89
|
+
/// getter for iteration stats
///
/// If `iteration_stats` is non-NULL, `*iteration_stats` is set to point at
/// the clustering object's own array of per-iteration statistics (no copy
/// is made, the memory stays owned by `clustering`). If `size` is non-NULL,
/// `*size` receives the number of entries in that array.
/// NOTE(review): the returned pointer presumably becomes invalid once the
/// clustering is retrained or freed — confirm before caching it.
void faiss_Clustering_iteration_stats(
    FaissClustering* clustering, FaissClusteringIterationStats** iteration_stats, size_t* size);
|
86
92
|
|
87
93
|
/// the only mandatory parameters are k and d
|
88
94
|
int faiss_Clustering_new(FaissClustering** p_clustering, int d, int k);
|
@@ -87,6 +87,13 @@ void faiss_IndexIVF_print_stats (const FaissIndexIVF* index) {
|
|
87
87
|
reinterpret_cast<const IndexIVF*>(index)->invlists->print_stats();
|
88
88
|
}
|
89
89
|
|
90
|
+
/// get inverted lists ids
|
91
|
+
void faiss_IndexIVF_invlists_get_ids (const FaissIndexIVF* index, size_t list_no, idx_t* invlist) {
|
92
|
+
const idx_t* list = reinterpret_cast<const IndexIVF*>(index)->invlists->get_ids(list_no);
|
93
|
+
size_t list_size = reinterpret_cast<const IndexIVF*>(index)->get_list_size(list_no);
|
94
|
+
memcpy(invlist, list, list_size*sizeof(idx_t));
|
95
|
+
}
|
96
|
+
|
90
97
|
void faiss_IndexIVFStats_reset(FaissIndexIVFStats* stats) {
|
91
98
|
reinterpret_cast<IndexIVFStats*>(stats)->reset();
|
92
99
|
}
|
@@ -114,6 +114,13 @@ double faiss_IndexIVF_imbalance_factor (const FaissIndexIVF* index);
|
|
114
114
|
/// display some stats about the inverted lists of the index
|
115
115
|
void faiss_IndexIVF_print_stats (const FaissIndexIVF* index);
|
116
116
|
|
117
|
+
/// Get the IDs in an inverted list. IDs are written to `invlist`, which must be large enough
/// to accommodate the full list.
///
/// @param index   the IVF index to read from
/// @param list_no the list ID
/// @param invlist output pointer to a slice of memory, at least as long as the list's size
/// @see faiss_IndexIVF_get_list_size
void faiss_IndexIVF_invlists_get_ids (const FaissIndexIVF* index, size_t list_no, idx_t* invlist);
|
117
124
|
|
118
125
|
typedef struct FaissIndexIVFStats {
|
119
126
|
size_t nq; // nb of queries run
|
@@ -0,0 +1,21 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
// Copyright 2004-present Facebook. All Rights Reserved.
|
9
|
+
// -*- c++ -*-
|
10
|
+
|
11
|
+
#include "IndexPreTransform_c.h"
|
12
|
+
#include "IndexPreTransform.h"
|
13
|
+
#include "macros_impl.h"
|
14
|
+
|
15
|
+
// C-API bindings for faiss::IndexPreTransform.
// NOTE(review): assumes each DEFINE_* macro (macros_impl.h) expands to the
// definition of the function declared by the matching FAISS_DECLARE_* macro
// in IndexPreTransform_c.h; those declarations sit inside extern "C", so the
// definitions here inherit C linkage without their own extern "C" block.
using faiss::Index;
using faiss::IndexPreTransform;

DEFINE_DESTRUCTOR(IndexPreTransform)
DEFINE_INDEX_DOWNCAST(IndexPreTransform)

// Getter for the wrapped `index` member. The header declares it with the
// plain FAISS_DECLARE_GETTER; the _PERMISSIVE variant is presumably needed
// because the C++ member type (faiss::Index*) differs from the C-side
// return type (FaissIndex*) — TODO confirm against macros_impl.h.
DEFINE_GETTER_PERMISSIVE(IndexPreTransform, FaissIndex*, index)
|
@@ -0,0 +1,32 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
// Copyright 2004-present Facebook. All Rights Reserved.
|
9
|
+
// -*- c -*-
|
10
|
+
|
11
|
+
#ifndef FAISS_INDEX_PRETRANSFORM_C_H
#define FAISS_INDEX_PRETRANSFORM_C_H

#include "faiss_c.h"
#include "Index_c.h"

#ifdef __cplusplus
extern "C" {
#endif

/// C-API surface for faiss::IndexPreTransform: an opaque handle type,
/// its destructor, a downcast from a generic FaissIndex, and a getter for
/// the wrapped `index` member. The exact function names and signatures
/// are produced by the FAISS_DECLARE_* macros from faiss_c.h — see there.
FAISS_DECLARE_CLASS(IndexPreTransform)
FAISS_DECLARE_DESTRUCTOR(IndexPreTransform)
FAISS_DECLARE_INDEX_DOWNCAST(IndexPreTransform)

/// getter for the index wrapped by the pre-transform (`index` member)
FAISS_DECLARE_GETTER(IndexPreTransform, FaissIndex*, index)

#ifdef __cplusplus
}
#endif


#endif
|
@@ -0,0 +1,185 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <stdexcept>
#include <vector>

#include <faiss/Clustering.h>
#include <faiss/utils/random.h>
#include <faiss/utils/distances.h>
#include <faiss/IndexFlat.h>
#include <faiss/IndexHNSW.h>
|
16
|
+
|
17
|
+
|
18
|
+
namespace {
|
19
|
+
|
20
|
+
|
21
|
+
enum WeightedKMeansType {
|
22
|
+
WKMT_FlatL2,
|
23
|
+
WKMT_FlatIP,
|
24
|
+
WKMT_FlatIP_spherical,
|
25
|
+
WKMT_HNSW,
|
26
|
+
};
|
27
|
+
|
28
|
+
|
29
|
+
float weighted_kmeans_clustering (size_t d, size_t n, size_t k,
|
30
|
+
const float *input,
|
31
|
+
const float *weights,
|
32
|
+
float *centroids,
|
33
|
+
WeightedKMeansType index_num)
|
34
|
+
{
|
35
|
+
using namespace faiss;
|
36
|
+
Clustering clus (d, k);
|
37
|
+
clus.verbose = true;
|
38
|
+
|
39
|
+
std::unique_ptr<Index> index;
|
40
|
+
|
41
|
+
switch (index_num) {
|
42
|
+
case WKMT_FlatL2:
|
43
|
+
index.reset(new IndexFlatL2 (d));
|
44
|
+
break;
|
45
|
+
case WKMT_FlatIP:
|
46
|
+
index.reset(new IndexFlatIP (d));
|
47
|
+
break;
|
48
|
+
case WKMT_FlatIP_spherical:
|
49
|
+
index.reset(new IndexFlatIP (d));
|
50
|
+
clus.spherical = true;
|
51
|
+
break;
|
52
|
+
case WKMT_HNSW:
|
53
|
+
IndexHNSWFlat *ihnsw = new IndexHNSWFlat (d, 32);
|
54
|
+
ihnsw->hnsw.efSearch = 128;
|
55
|
+
index.reset(ihnsw);
|
56
|
+
break;
|
57
|
+
}
|
58
|
+
|
59
|
+
clus.train(n, input, *index.get(), weights);
|
60
|
+
// on output the index contains the centroids.
|
61
|
+
memcpy(centroids, clus.centroids.data(), sizeof(*centroids) * d * k);
|
62
|
+
return clus.iteration_stats.back().obj;
|
63
|
+
}
|
64
|
+
|
65
|
+
|
66
|
+
int d = 32;
|
67
|
+
float sigma = 0.1;
|
68
|
+
|
69
|
+
#define BIGTEST
|
70
|
+
|
71
|
+
#ifdef BIGTEST
|
72
|
+
// the production setup = setting of https://fb.quip.com/CWgnAAYbwtgs
|
73
|
+
int nc = 200000;
|
74
|
+
int n_big = 4;
|
75
|
+
int n_small = 2;
|
76
|
+
#else
|
77
|
+
int nc = 5;
|
78
|
+
int n_big = 100;
|
79
|
+
int n_small = 10;
|
80
|
+
#endif
|
81
|
+
|
82
|
+
int n; // number of training points
|
83
|
+
|
84
|
+
void generate_trainset (std::vector<float> & ccent,
|
85
|
+
std::vector<float> & x,
|
86
|
+
std::vector<float> & weights)
|
87
|
+
{
|
88
|
+
// same sampling as test_build_blocks.py test_weighted
|
89
|
+
|
90
|
+
ccent.resize (d * 2 * nc);
|
91
|
+
faiss::float_randn (ccent.data(), d * 2 * nc, 123);
|
92
|
+
faiss::fvec_renorm_L2 (d, 2 * nc, ccent.data());
|
93
|
+
n = nc * n_big + nc * n_small;
|
94
|
+
x.resize(d * n);
|
95
|
+
weights.resize(n);
|
96
|
+
faiss::float_randn (x.data(), x.size(), 1234);
|
97
|
+
|
98
|
+
float *xi = x.data();
|
99
|
+
float *w = weights.data();
|
100
|
+
for (int ci = 0; ci < nc * 2; ci++) { // loop over centroids
|
101
|
+
int np = ci < nc ? n_big : n_small; // nb of points around this centroid
|
102
|
+
for (int i = 0; i < np; i++) {
|
103
|
+
for (int j = 0; j < d; j++) {
|
104
|
+
xi[j] = xi[j] * sigma + ccent[ci * d + j];
|
105
|
+
}
|
106
|
+
*w++ = ci < nc ? 0.1 : 10;
|
107
|
+
xi += d;
|
108
|
+
}
|
109
|
+
}
|
110
|
+
}
|
111
|
+
|
112
|
+
}
|
113
|
+
|
114
|
+
|
115
|
+
int main(int argc, char **argv) {
|
116
|
+
std::vector<float> ccent;
|
117
|
+
std::vector<float> x;
|
118
|
+
std::vector<float> weights;
|
119
|
+
|
120
|
+
printf("generate training set\n");
|
121
|
+
generate_trainset(ccent, x, weights);
|
122
|
+
|
123
|
+
std::vector<float> centroids;
|
124
|
+
centroids.resize(nc * d);
|
125
|
+
|
126
|
+
int the_index_num = -1;
|
127
|
+
int the_with_weights = -1;
|
128
|
+
|
129
|
+
if (argc == 3) {
|
130
|
+
the_index_num = atoi(argv[1]);
|
131
|
+
the_with_weights = atoi(argv[2]);
|
132
|
+
}
|
133
|
+
|
134
|
+
|
135
|
+
for (int index_num = WKMT_FlatL2;
|
136
|
+
index_num <= WKMT_HNSW;
|
137
|
+
index_num++) {
|
138
|
+
|
139
|
+
if (the_index_num >= 0 && index_num != the_index_num) {
|
140
|
+
continue;
|
141
|
+
}
|
142
|
+
|
143
|
+
for (int with_weights = 0; with_weights <= 1; with_weights++) {
|
144
|
+
if (the_with_weights >= 0 && with_weights != the_with_weights) {
|
145
|
+
continue;
|
146
|
+
}
|
147
|
+
|
148
|
+
printf("=================== index_num=%d Run %s weights\n",
|
149
|
+
index_num, with_weights ? "with" : "without");
|
150
|
+
|
151
|
+
weighted_kmeans_clustering (
|
152
|
+
d, n, nc, x.data(),
|
153
|
+
with_weights ? weights.data() : nullptr,
|
154
|
+
centroids.data(), (WeightedKMeansType)index_num
|
155
|
+
);
|
156
|
+
|
157
|
+
{ // compute distance of points to centroids
|
158
|
+
faiss::IndexFlatL2 cent_index(d);
|
159
|
+
cent_index.add(nc, centroids.data());
|
160
|
+
std::vector<float> dis (n);
|
161
|
+
std::vector<faiss::Index::idx_t> idx (n);
|
162
|
+
|
163
|
+
cent_index.search (nc * 2, ccent.data(), 1,
|
164
|
+
dis.data(), idx.data());
|
165
|
+
|
166
|
+
float dis1 = 0, dis2 = 0;
|
167
|
+
for (int i = 0; i < nc ; i++) {
|
168
|
+
dis1 += dis[i];
|
169
|
+
}
|
170
|
+
printf("average distance of points from big clusters: %g\n",
|
171
|
+
dis1 / nc);
|
172
|
+
|
173
|
+
for (int i = 0; i < nc ; i++) {
|
174
|
+
dis2 += dis[i + nc];
|
175
|
+
}
|
176
|
+
|
177
|
+
printf("average distance of points from small clusters: %g\n",
|
178
|
+
dis2 / nc);
|
179
|
+
|
180
|
+
}
|
181
|
+
|
182
|
+
}
|
183
|
+
}
|
184
|
+
return 0;
|
185
|
+
}
|