faiss 0.1.1 → 0.1.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/LICENSE.txt +18 -18
- data/README.md +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/Clustering.cpp +318 -53
- data/vendor/faiss/Clustering.h +39 -11
- data/vendor/faiss/DirectMap.cpp +267 -0
- data/vendor/faiss/DirectMap.h +120 -0
- data/vendor/faiss/IVFlib.cpp +24 -4
- data/vendor/faiss/IVFlib.h +4 -0
- data/vendor/faiss/Index.h +5 -24
- data/vendor/faiss/Index2Layer.cpp +0 -1
- data/vendor/faiss/IndexBinary.h +7 -3
- data/vendor/faiss/IndexBinaryFlat.cpp +5 -0
- data/vendor/faiss/IndexBinaryFlat.h +3 -0
- data/vendor/faiss/IndexBinaryHash.cpp +492 -0
- data/vendor/faiss/IndexBinaryHash.h +116 -0
- data/vendor/faiss/IndexBinaryIVF.cpp +160 -107
- data/vendor/faiss/IndexBinaryIVF.h +14 -4
- data/vendor/faiss/IndexFlat.h +2 -1
- data/vendor/faiss/IndexHNSW.cpp +68 -16
- data/vendor/faiss/IndexHNSW.h +3 -3
- data/vendor/faiss/IndexIVF.cpp +72 -76
- data/vendor/faiss/IndexIVF.h +24 -5
- data/vendor/faiss/IndexIVFFlat.cpp +19 -54
- data/vendor/faiss/IndexIVFFlat.h +1 -11
- data/vendor/faiss/IndexIVFPQ.cpp +49 -26
- data/vendor/faiss/IndexIVFPQ.h +9 -10
- data/vendor/faiss/IndexIVFPQR.cpp +2 -2
- data/vendor/faiss/IndexIVFSpectralHash.cpp +2 -2
- data/vendor/faiss/IndexLSH.h +4 -1
- data/vendor/faiss/IndexPreTransform.cpp +0 -1
- data/vendor/faiss/IndexScalarQuantizer.cpp +8 -1
- data/vendor/faiss/InvertedLists.cpp +0 -2
- data/vendor/faiss/MetaIndexes.cpp +0 -1
- data/vendor/faiss/MetricType.h +36 -0
- data/vendor/faiss/c_api/Clustering_c.cpp +13 -7
- data/vendor/faiss/c_api/Clustering_c.h +11 -5
- data/vendor/faiss/c_api/IndexIVF_c.cpp +7 -0
- data/vendor/faiss/c_api/IndexIVF_c.h +7 -0
- data/vendor/faiss/c_api/IndexPreTransform_c.cpp +21 -0
- data/vendor/faiss/c_api/IndexPreTransform_c.h +32 -0
- data/vendor/faiss/demos/demo_weighted_kmeans.cpp +185 -0
- data/vendor/faiss/gpu/GpuCloner.cpp +4 -0
- data/vendor/faiss/gpu/GpuClonerOptions.cpp +1 -1
- data/vendor/faiss/gpu/GpuDistance.h +93 -0
- data/vendor/faiss/gpu/GpuIndex.h +7 -0
- data/vendor/faiss/gpu/GpuIndexFlat.h +0 -10
- data/vendor/faiss/gpu/GpuIndexIVF.h +1 -0
- data/vendor/faiss/gpu/StandardGpuResources.cpp +8 -0
- data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +49 -27
- data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +110 -2
- data/vendor/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/impl/AuxIndexStructures.cpp +17 -0
- data/vendor/faiss/impl/AuxIndexStructures.h +14 -3
- data/vendor/faiss/impl/HNSW.cpp +0 -1
- data/vendor/faiss/impl/PolysemousTraining.h +5 -5
- data/vendor/faiss/impl/ProductQuantizer-inl.h +138 -0
- data/vendor/faiss/impl/ProductQuantizer.cpp +1 -113
- data/vendor/faiss/impl/ProductQuantizer.h +42 -47
- data/vendor/faiss/impl/index_read.cpp +103 -7
- data/vendor/faiss/impl/index_write.cpp +101 -5
- data/vendor/faiss/impl/io.cpp +111 -1
- data/vendor/faiss/impl/io.h +38 -0
- data/vendor/faiss/index_factory.cpp +0 -1
- data/vendor/faiss/tests/test_merge.cpp +0 -1
- data/vendor/faiss/tests/test_pq_encoding.cpp +6 -6
- data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +1 -0
- data/vendor/faiss/utils/distances.cpp +4 -5
- data/vendor/faiss/utils/distances_simd.cpp +0 -1
- data/vendor/faiss/utils/hamming.cpp +85 -3
- data/vendor/faiss/utils/hamming.h +20 -0
- data/vendor/faiss/utils/utils.cpp +0 -96
- data/vendor/faiss/utils/utils.h +0 -15
- metadata +11 -3
- data/lib/faiss/ext.bundle +0 -0
@@ -144,8 +144,8 @@ void IndexIVFPQR::search_preassigned (idx_t n, const float *x, idx_t k,
|
|
144
144
|
|
145
145
|
if (sl == -1) continue;
|
146
146
|
|
147
|
-
int list_no = sl >> 32;
|
148
|
-
int ofs = sl & 0xffffffff;
|
147
|
+
int list_no = lo_listno(sl);
|
148
|
+
int ofs = lo_offset(sl);
|
149
149
|
|
150
150
|
assert (list_no >= 0 && list_no < nlist);
|
151
151
|
assert (ofs >= 0 && ofs < invlists->list_size (list_no));
|
@@ -270,7 +270,7 @@ struct IVFScanner: InvertedListScanner {
|
|
270
270
|
|
271
271
|
if (dis < simi [0]) {
|
272
272
|
maxheap_pop (k, simi, idxi);
|
273
|
-
int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
|
273
|
+
int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
|
274
274
|
maxheap_push (k, simi, idxi, dis, id);
|
275
275
|
nup++;
|
276
276
|
}
|
@@ -288,7 +288,7 @@ struct IVFScanner: InvertedListScanner {
|
|
288
288
|
for (size_t j = 0; j < list_size; j++) {
|
289
289
|
float dis = hc.hamming (codes);
|
290
290
|
if (dis < radius) {
|
291
|
-
int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
|
291
|
+
int64_t id = store_pairs ? lo_build (list_no, j) : ids[j];
|
292
292
|
res.add (dis, id);
|
293
293
|
}
|
294
294
|
codes += code_size;
|
data/vendor/faiss/IndexLSH.h
CHANGED
@@ -69,7 +69,10 @@ struct IndexLSH:Index {
|
|
69
69
|
|
70
70
|
IndexLSH ();
|
71
71
|
|
72
|
-
/* standalone codec interface */
|
72
|
+
/* standalone codec interface.
|
73
|
+
*
|
74
|
+
* The vectors are decoded to +/- 1 (not 0, 1) */
|
75
|
+
|
73
76
|
size_t sa_code_size () const override;
|
74
77
|
|
75
78
|
void sa_encode (idx_t n, const float *x,
|
@@ -253,6 +253,8 @@ void IndexIVFScalarQuantizer::add_with_ids
|
|
253
253
|
size_t nadd = 0;
|
254
254
|
std::unique_ptr<ScalarQuantizer::Quantizer> squant(sq.select_quantizer ());
|
255
255
|
|
256
|
+
DirectMapAdd dm_add (direct_map, n, xids);
|
257
|
+
|
256
258
|
#pragma omp parallel reduction(+: nadd)
|
257
259
|
{
|
258
260
|
std::vector<float> residual (d);
|
@@ -275,13 +277,18 @@ void IndexIVFScalarQuantizer::add_with_ids
|
|
275
277
|
memset (one_code.data(), 0, code_size);
|
276
278
|
squant->encode_vector (xi, one_code.data());
|
277
279
|
|
278
|
-
invlists->add_entry (list_no, id, one_code.data());
|
280
|
+
size_t ofs = invlists->add_entry (list_no, id, one_code.data());
|
279
281
|
|
282
|
+
dm_add.add (i, list_no, ofs);
|
280
283
|
nadd++;
|
281
284
|
|
285
|
+
} else if (rank == 0 && list_no == -1) {
|
286
|
+
dm_add.add (i, -1, 0);
|
282
287
|
}
|
283
288
|
}
|
284
289
|
}
|
290
|
+
|
291
|
+
|
285
292
|
ntotal += n;
|
286
293
|
}
|
287
294
|
|
@@ -0,0 +1,36 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
// -*- c++ -*-
|
9
|
+
|
10
|
+
#ifndef FAISS_METRIC_TYPE_H
|
11
|
+
#define FAISS_METRIC_TYPE_H
|
12
|
+
|
13
|
+
namespace faiss {
|
14
|
+
|
15
|
+
/// The metric space for vector comparison for Faiss indices and algorithms.
|
16
|
+
///
|
17
|
+
/// Most algorithms support both inner product and L2, with the flat
|
18
|
+
/// (brute-force) indices supporting additional metric types for vector
|
19
|
+
/// comparison.
|
20
|
+
enum MetricType {
|
21
|
+
METRIC_INNER_PRODUCT = 0, ///< maximum inner product search
|
22
|
+
METRIC_L2 = 1, ///< squared L2 search
|
23
|
+
METRIC_L1, ///< L1 (aka cityblock)
|
24
|
+
METRIC_Linf, ///< infinity distance
|
25
|
+
METRIC_Lp, ///< L_p distance, p is given by a faiss::Index
|
26
|
+
/// metric_arg
|
27
|
+
|
28
|
+
/// some additional metrics defined in scipy.spatial.distance
|
29
|
+
METRIC_Canberra = 20,
|
30
|
+
METRIC_BrayCurtis,
|
31
|
+
METRIC_JensenShannon,
|
32
|
+
};
|
33
|
+
|
34
|
+
}
|
35
|
+
|
36
|
+
#endif
|
@@ -19,6 +19,7 @@ extern "C" {
|
|
19
19
|
using faiss::Clustering;
|
20
20
|
using faiss::ClusteringParameters;
|
21
21
|
using faiss::Index;
|
22
|
+
using faiss::ClusteringIterationStats;
|
22
23
|
|
23
24
|
DEFINE_GETTER(Clustering, int, niter)
|
24
25
|
DEFINE_GETTER(Clustering, int, nredo)
|
@@ -38,6 +39,12 @@ DEFINE_GETTER(Clustering, size_t, d)
|
|
38
39
|
/// getter for k
|
39
40
|
DEFINE_GETTER(Clustering, size_t, k)
|
40
41
|
|
42
|
+
DEFINE_GETTER(ClusteringIterationStats, float, obj)
|
43
|
+
DEFINE_GETTER(ClusteringIterationStats, double, time)
|
44
|
+
DEFINE_GETTER(ClusteringIterationStats, double, time_search)
|
45
|
+
DEFINE_GETTER(ClusteringIterationStats, double, imbalance_factor)
|
46
|
+
DEFINE_GETTER(ClusteringIterationStats, int, nsplit)
|
47
|
+
|
41
48
|
void faiss_ClusteringParameters_init(FaissClusteringParameters* params) {
|
42
49
|
ClusteringParameters d;
|
43
50
|
params->frozen_centroids = d.frozen_centroids;
|
@@ -78,13 +85,12 @@ void faiss_Clustering_centroids(
|
|
78
85
|
}
|
79
86
|
}
|
80
87
|
|
81
|
-
/// getter for
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
*obj = v.data();
|
88
|
+
/// getter for iteration stats
|
89
|
+
void faiss_Clustering_iteration_stats(
|
90
|
+
FaissClustering* clustering, FaissClusteringIterationStats** iteration_stats, size_t* size) {
|
91
|
+
std::vector<ClusteringIterationStats>& v = reinterpret_cast<Clustering*>(clustering)->iteration_stats;
|
92
|
+
if (iteration_stats) {
|
93
|
+
*iteration_stats = reinterpret_cast<FaissClusteringIterationStats*>(v.data());
|
88
94
|
}
|
89
95
|
if (size) {
|
90
96
|
*size = v.size();
|
@@ -47,7 +47,7 @@ void faiss_ClusteringParameters_init(FaissClusteringParameters* params);
|
|
47
47
|
* points to the centroids. Therefore, at each iteration the centroids
|
48
48
|
* are added to the index.
|
49
49
|
*
|
50
|
-
* On output, the
|
50
|
+
* On output, the centroids table is set to the latest version
|
51
51
|
* of the centroids and they are also added to the index. If the
|
52
52
|
* centroids table it is not empty on input, it is also used for
|
53
53
|
* initialization.
|
@@ -75,14 +75,20 @@ FAISS_DECLARE_GETTER(Clustering, size_t, d)
|
|
75
75
|
/// getter for k
|
76
76
|
FAISS_DECLARE_GETTER(Clustering, size_t, k)
|
77
77
|
|
78
|
+
FAISS_DECLARE_CLASS(ClusteringIterationStats)
|
79
|
+
FAISS_DECLARE_GETTER(ClusteringIterationStats, float, obj)
|
80
|
+
FAISS_DECLARE_GETTER(ClusteringIterationStats, double, time)
|
81
|
+
FAISS_DECLARE_GETTER(ClusteringIterationStats, double, time_search)
|
82
|
+
FAISS_DECLARE_GETTER(ClusteringIterationStats, double, imbalance_factor)
|
83
|
+
FAISS_DECLARE_GETTER(ClusteringIterationStats, int, nsplit)
|
84
|
+
|
78
85
|
/// getter for centroids (size = k * d)
|
79
86
|
void faiss_Clustering_centroids(
|
80
87
|
FaissClustering* clustering, float** centroids, size_t* size);
|
81
88
|
|
82
|
-
/// getter for
|
83
|
-
|
84
|
-
|
85
|
-
FaissClustering* clustering, float** obj, size_t* size);
|
89
|
+
/// getter for iteration stats
|
90
|
+
void faiss_Clustering_iteration_stats(
|
91
|
+
FaissClustering* clustering, FaissClusteringIterationStats** iteration_stats, size_t* size);
|
86
92
|
|
87
93
|
/// the only mandatory parameters are k and d
|
88
94
|
int faiss_Clustering_new(FaissClustering** p_clustering, int d, int k);
|
@@ -87,6 +87,13 @@ void faiss_IndexIVF_print_stats (const FaissIndexIVF* index) {
|
|
87
87
|
reinterpret_cast<const IndexIVF*>(index)->invlists->print_stats();
|
88
88
|
}
|
89
89
|
|
90
|
+
/// get inverted lists ids
|
91
|
+
void faiss_IndexIVF_invlists_get_ids (const FaissIndexIVF* index, size_t list_no, idx_t* invlist) {
|
92
|
+
const idx_t* list = reinterpret_cast<const IndexIVF*>(index)->invlists->get_ids(list_no);
|
93
|
+
size_t list_size = reinterpret_cast<const IndexIVF*>(index)->get_list_size(list_no);
|
94
|
+
memcpy(invlist, list, list_size*sizeof(idx_t));
|
95
|
+
}
|
96
|
+
|
90
97
|
void faiss_IndexIVFStats_reset(FaissIndexIVFStats* stats) {
|
91
98
|
reinterpret_cast<IndexIVFStats*>(stats)->reset();
|
92
99
|
}
|
@@ -114,6 +114,13 @@ double faiss_IndexIVF_imbalance_factor (const FaissIndexIVF* index);
|
|
114
114
|
/// display some stats about the inverted lists of the index
|
115
115
|
void faiss_IndexIVF_print_stats (const FaissIndexIVF* index);
|
116
116
|
|
117
|
+
/// Get the IDs in an inverted list. IDs are written to `invlist`, which must be large enough
|
118
|
+
//// to accommodate the full list.
|
119
|
+
///
|
120
|
+
/// @param list_no the list ID
|
121
|
+
/// @param invlist output pointer to a slice of memory, at least as long as the list's size
|
122
|
+
/// @see faiss_IndexIVF_get_list_size(size_t)
|
123
|
+
void faiss_IndexIVF_invlists_get_ids (const FaissIndexIVF* index, size_t list_no, idx_t* invlist);
|
117
124
|
|
118
125
|
typedef struct FaissIndexIVFStats {
|
119
126
|
size_t nq; // nb of queries run
|
@@ -0,0 +1,21 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
// Copyright 2004-present Facebook. All Rights Reserved.
|
9
|
+
// -*- c++ -*-
|
10
|
+
|
11
|
+
#include "IndexPreTransform_c.h"
|
12
|
+
#include "IndexPreTransform.h"
|
13
|
+
#include "macros_impl.h"
|
14
|
+
|
15
|
+
using faiss::Index;
|
16
|
+
using faiss::IndexPreTransform;
|
17
|
+
|
18
|
+
DEFINE_DESTRUCTOR(IndexPreTransform)
|
19
|
+
DEFINE_INDEX_DOWNCAST(IndexPreTransform)
|
20
|
+
|
21
|
+
DEFINE_GETTER_PERMISSIVE(IndexPreTransform, FaissIndex*, index)
|
@@ -0,0 +1,32 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
// Copyright 2004-present Facebook. All Rights Reserved.
|
9
|
+
// -*- c -*-
|
10
|
+
|
11
|
+
#ifndef FAISS_INDEX_PRETRANSFORM_C_H
|
12
|
+
#define FAISS_INDEX_PRETRANSFORM_C_H
|
13
|
+
|
14
|
+
#include "faiss_c.h"
|
15
|
+
#include "Index_c.h"
|
16
|
+
|
17
|
+
#ifdef __cplusplus
|
18
|
+
extern "C" {
|
19
|
+
#endif
|
20
|
+
|
21
|
+
FAISS_DECLARE_CLASS(IndexPreTransform)
|
22
|
+
FAISS_DECLARE_DESTRUCTOR(IndexPreTransform)
|
23
|
+
FAISS_DECLARE_INDEX_DOWNCAST(IndexPreTransform)
|
24
|
+
|
25
|
+
FAISS_DECLARE_GETTER(IndexPreTransform, FaissIndex*, index)
|
26
|
+
|
27
|
+
#ifdef __cplusplus
|
28
|
+
}
|
29
|
+
#endif
|
30
|
+
|
31
|
+
|
32
|
+
#endif
|
@@ -0,0 +1,185 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
3
|
+
*
|
4
|
+
* This source code is licensed under the MIT license found in the
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
6
|
+
*/
|
7
|
+
|
8
|
+
#include <cstdio>
|
9
|
+
#include <cstdlib>
|
10
|
+
|
11
|
+
#include <faiss/Clustering.h>
|
12
|
+
#include <faiss/utils/random.h>
|
13
|
+
#include <faiss/utils/distances.h>
|
14
|
+
#include <faiss/IndexFlat.h>
|
15
|
+
#include <faiss/IndexHNSW.h>
|
16
|
+
|
17
|
+
|
18
|
+
namespace {
|
19
|
+
|
20
|
+
|
21
|
+
enum WeightedKMeansType {
|
22
|
+
WKMT_FlatL2,
|
23
|
+
WKMT_FlatIP,
|
24
|
+
WKMT_FlatIP_spherical,
|
25
|
+
WKMT_HNSW,
|
26
|
+
};
|
27
|
+
|
28
|
+
|
29
|
+
float weighted_kmeans_clustering (size_t d, size_t n, size_t k,
|
30
|
+
const float *input,
|
31
|
+
const float *weights,
|
32
|
+
float *centroids,
|
33
|
+
WeightedKMeansType index_num)
|
34
|
+
{
|
35
|
+
using namespace faiss;
|
36
|
+
Clustering clus (d, k);
|
37
|
+
clus.verbose = true;
|
38
|
+
|
39
|
+
std::unique_ptr<Index> index;
|
40
|
+
|
41
|
+
switch (index_num) {
|
42
|
+
case WKMT_FlatL2:
|
43
|
+
index.reset(new IndexFlatL2 (d));
|
44
|
+
break;
|
45
|
+
case WKMT_FlatIP:
|
46
|
+
index.reset(new IndexFlatIP (d));
|
47
|
+
break;
|
48
|
+
case WKMT_FlatIP_spherical:
|
49
|
+
index.reset(new IndexFlatIP (d));
|
50
|
+
clus.spherical = true;
|
51
|
+
break;
|
52
|
+
case WKMT_HNSW:
|
53
|
+
IndexHNSWFlat *ihnsw = new IndexHNSWFlat (d, 32);
|
54
|
+
ihnsw->hnsw.efSearch = 128;
|
55
|
+
index.reset(ihnsw);
|
56
|
+
break;
|
57
|
+
}
|
58
|
+
|
59
|
+
clus.train(n, input, *index.get(), weights);
|
60
|
+
// on output the index contains the centroids.
|
61
|
+
memcpy(centroids, clus.centroids.data(), sizeof(*centroids) * d * k);
|
62
|
+
return clus.iteration_stats.back().obj;
|
63
|
+
}
|
64
|
+
|
65
|
+
|
66
|
+
int d = 32;
|
67
|
+
float sigma = 0.1;
|
68
|
+
|
69
|
+
#define BIGTEST
|
70
|
+
|
71
|
+
#ifdef BIGTEST
|
72
|
+
// the production setup = setting of https://fb.quip.com/CWgnAAYbwtgs
|
73
|
+
int nc = 200000;
|
74
|
+
int n_big = 4;
|
75
|
+
int n_small = 2;
|
76
|
+
#else
|
77
|
+
int nc = 5;
|
78
|
+
int n_big = 100;
|
79
|
+
int n_small = 10;
|
80
|
+
#endif
|
81
|
+
|
82
|
+
int n; // number of training points
|
83
|
+
|
84
|
+
void generate_trainset (std::vector<float> & ccent,
|
85
|
+
std::vector<float> & x,
|
86
|
+
std::vector<float> & weights)
|
87
|
+
{
|
88
|
+
// same sampling as test_build_blocks.py test_weighted
|
89
|
+
|
90
|
+
ccent.resize (d * 2 * nc);
|
91
|
+
faiss::float_randn (ccent.data(), d * 2 * nc, 123);
|
92
|
+
faiss::fvec_renorm_L2 (d, 2 * nc, ccent.data());
|
93
|
+
n = nc * n_big + nc * n_small;
|
94
|
+
x.resize(d * n);
|
95
|
+
weights.resize(n);
|
96
|
+
faiss::float_randn (x.data(), x.size(), 1234);
|
97
|
+
|
98
|
+
float *xi = x.data();
|
99
|
+
float *w = weights.data();
|
100
|
+
for (int ci = 0; ci < nc * 2; ci++) { // loop over centroids
|
101
|
+
int np = ci < nc ? n_big : n_small; // nb of points around this centroid
|
102
|
+
for (int i = 0; i < np; i++) {
|
103
|
+
for (int j = 0; j < d; j++) {
|
104
|
+
xi[j] = xi[j] * sigma + ccent[ci * d + j];
|
105
|
+
}
|
106
|
+
*w++ = ci < nc ? 0.1 : 10;
|
107
|
+
xi += d;
|
108
|
+
}
|
109
|
+
}
|
110
|
+
}
|
111
|
+
|
112
|
+
}
|
113
|
+
|
114
|
+
|
115
|
+
int main(int argc, char **argv) {
|
116
|
+
std::vector<float> ccent;
|
117
|
+
std::vector<float> x;
|
118
|
+
std::vector<float> weights;
|
119
|
+
|
120
|
+
printf("generate training set\n");
|
121
|
+
generate_trainset(ccent, x, weights);
|
122
|
+
|
123
|
+
std::vector<float> centroids;
|
124
|
+
centroids.resize(nc * d);
|
125
|
+
|
126
|
+
int the_index_num = -1;
|
127
|
+
int the_with_weights = -1;
|
128
|
+
|
129
|
+
if (argc == 3) {
|
130
|
+
the_index_num = atoi(argv[1]);
|
131
|
+
the_with_weights = atoi(argv[2]);
|
132
|
+
}
|
133
|
+
|
134
|
+
|
135
|
+
for (int index_num = WKMT_FlatL2;
|
136
|
+
index_num <= WKMT_HNSW;
|
137
|
+
index_num++) {
|
138
|
+
|
139
|
+
if (the_index_num >= 0 && index_num != the_index_num) {
|
140
|
+
continue;
|
141
|
+
}
|
142
|
+
|
143
|
+
for (int with_weights = 0; with_weights <= 1; with_weights++) {
|
144
|
+
if (the_with_weights >= 0 && with_weights != the_with_weights) {
|
145
|
+
continue;
|
146
|
+
}
|
147
|
+
|
148
|
+
printf("=================== index_num=%d Run %s weights\n",
|
149
|
+
index_num, with_weights ? "with" : "without");
|
150
|
+
|
151
|
+
weighted_kmeans_clustering (
|
152
|
+
d, n, nc, x.data(),
|
153
|
+
with_weights ? weights.data() : nullptr,
|
154
|
+
centroids.data(), (WeightedKMeansType)index_num
|
155
|
+
);
|
156
|
+
|
157
|
+
{ // compute distance of points to centroids
|
158
|
+
faiss::IndexFlatL2 cent_index(d);
|
159
|
+
cent_index.add(nc, centroids.data());
|
160
|
+
std::vector<float> dis (n);
|
161
|
+
std::vector<faiss::Index::idx_t> idx (n);
|
162
|
+
|
163
|
+
cent_index.search (nc * 2, ccent.data(), 1,
|
164
|
+
dis.data(), idx.data());
|
165
|
+
|
166
|
+
float dis1 = 0, dis2 = 0;
|
167
|
+
for (int i = 0; i < nc ; i++) {
|
168
|
+
dis1 += dis[i];
|
169
|
+
}
|
170
|
+
printf("average distance of points from big clusters: %g\n",
|
171
|
+
dis1 / nc);
|
172
|
+
|
173
|
+
for (int i = 0; i < nc ; i++) {
|
174
|
+
dis2 += dis[i + nc];
|
175
|
+
}
|
176
|
+
|
177
|
+
printf("average distance of points from small clusters: %g\n",
|
178
|
+
dis2 / nc);
|
179
|
+
|
180
|
+
}
|
181
|
+
|
182
|
+
}
|
183
|
+
}
|
184
|
+
return 0;
|
185
|
+
}
|