clusterkit 0.2.4 → 0.2.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Cargo.lock +11 -11
- data/ext/clusterkit/Cargo.toml +1 -1
- data/ext/clusterkit/src/clustering/hdbscan_wrapper.rs +23 -36
- data/ext/clusterkit/src/clustering.rs +47 -53
- data/ext/clusterkit/src/embedder.rs +44 -52
- data/ext/clusterkit/src/hnsw.rs +181 -215
- data/ext/clusterkit/src/lib.rs +5 -5
- data/ext/clusterkit/src/svd.rs +31 -33
- data/ext/clusterkit/src/utils.rs +24 -21
- data/lib/clusterkit/version.rb +1 -1
- data/lib/clusterkit.rb +1 -1
- metadata +17 -4
- data/clusterkit.gemspec +0 -45
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: 19c31f1a55e35724884b0005bd5e168c572c75b4cb3331a9634f7b00e5722f25
|
|
4
|
+
data.tar.gz: a7620b934689b61e4a88b555885f20eadc92f316cc8f09910fce59fc68ca7445
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: a0104f2e34d261e5d8da181f783f2c30e1127b8895ff2a1be64e4e7ec824e9fce9da5e08aebb0dd2e08d0ddca4b9e9b7c0d14d40335b9ccf25ffd014ba048812
|
|
7
|
+
data.tar.gz: d8c236c53f3e351d210c9de923bb4a77b82b33fe4c1caad4889a2251cfc92addc6980c8c621e7aac6408603d05e1ef529285ccc20d043726d0e7c6473a2fc926
|
data/Cargo.lock
CHANGED
|
@@ -1323,7 +1323,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
|
1323
1323
|
checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667"
|
|
1324
1324
|
dependencies = [
|
|
1325
1325
|
"cfg-if",
|
|
1326
|
-
"windows-targets 0.
|
|
1326
|
+
"windows-targets 0.48.5",
|
|
1327
1327
|
]
|
|
1328
1328
|
|
|
1329
1329
|
[[package]]
|
|
@@ -1382,9 +1382,9 @@ dependencies = [
|
|
|
1382
1382
|
|
|
1383
1383
|
[[package]]
|
|
1384
1384
|
name = "magnus"
|
|
1385
|
-
version = "0.
|
|
1385
|
+
version = "0.8.2"
|
|
1386
1386
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
1387
|
-
checksum = "
|
|
1387
|
+
checksum = "3b36a5b126bbe97eb0d02d07acfeb327036c6319fd816139a49824a83b7f9012"
|
|
1388
1388
|
dependencies = [
|
|
1389
1389
|
"magnus-macros",
|
|
1390
1390
|
"rb-sys",
|
|
@@ -1394,9 +1394,9 @@ dependencies = [
|
|
|
1394
1394
|
|
|
1395
1395
|
[[package]]
|
|
1396
1396
|
name = "magnus-macros"
|
|
1397
|
-
version = "0.
|
|
1397
|
+
version = "0.8.0"
|
|
1398
1398
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
1399
|
-
checksum = "
|
|
1399
|
+
checksum = "47607461fd8e1513cb4f2076c197d8092d921a1ea75bd08af97398f593751892"
|
|
1400
1400
|
dependencies = [
|
|
1401
1401
|
"proc-macro2",
|
|
1402
1402
|
"quote",
|
|
@@ -1970,18 +1970,18 @@ dependencies = [
|
|
|
1970
1970
|
|
|
1971
1971
|
[[package]]
|
|
1972
1972
|
name = "rb-sys"
|
|
1973
|
-
version = "0.9.
|
|
1973
|
+
version = "0.9.124"
|
|
1974
1974
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
1975
|
-
checksum = "
|
|
1975
|
+
checksum = "c85c4188462601e2aa1469def389c17228566f82ea72f137ed096f21591bc489"
|
|
1976
1976
|
dependencies = [
|
|
1977
1977
|
"rb-sys-build",
|
|
1978
1978
|
]
|
|
1979
1979
|
|
|
1980
1980
|
[[package]]
|
|
1981
1981
|
name = "rb-sys-build"
|
|
1982
|
-
version = "0.9.
|
|
1982
|
+
version = "0.9.124"
|
|
1983
1983
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
1984
|
-
checksum = "
|
|
1984
|
+
checksum = "568068db4102230882e6d4ae8de6632e224ca75fe5970f6e026a04e91ed635d3"
|
|
1985
1985
|
dependencies = [
|
|
1986
1986
|
"bindgen",
|
|
1987
1987
|
"lazy_static",
|
|
@@ -1994,9 +1994,9 @@ dependencies = [
|
|
|
1994
1994
|
|
|
1995
1995
|
[[package]]
|
|
1996
1996
|
name = "rb-sys-env"
|
|
1997
|
-
version = "0.
|
|
1997
|
+
version = "0.2.3"
|
|
1998
1998
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
1999
|
-
checksum = "
|
|
1999
|
+
checksum = "cca7ad6a7e21e72151d56fe2495a259b5670e204c3adac41ee7ef676ea08117a"
|
|
2000
2000
|
|
|
2001
2001
|
[[package]]
|
|
2002
2002
|
name = "redox_syscall"
|
data/ext/clusterkit/Cargo.toml
CHANGED
|
@@ -7,7 +7,7 @@ edition = "2021"
|
|
|
7
7
|
crate-type = ["cdylib"]
|
|
8
8
|
|
|
9
9
|
[dependencies]
|
|
10
|
-
magnus = { version = "0.
|
|
10
|
+
magnus = { version = "0.8", features = ["embed"] }
|
|
11
11
|
annembed = { git = "https://github.com/scientist-labs/annembed", tag = "clusterkit-0.1.1" }
|
|
12
12
|
hnsw_rs = { git = "https://github.com/scientist-labs/hnswlib-rs", tag = "clusterkit-0.1.0" }
|
|
13
13
|
hdbscan = "0.11"
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
use magnus::{function, prelude::*, Error, Value,
|
|
1
|
+
use magnus::{function, prelude::*, Error, Value, RHash, Ruby};
|
|
2
2
|
use hdbscan::{Hdbscan, HdbscanHyperParams};
|
|
3
3
|
use crate::utils::ruby_array_to_vec_vec_f64;
|
|
4
4
|
|
|
@@ -10,75 +10,62 @@ pub fn hdbscan_fit(
|
|
|
10
10
|
min_cluster_size: usize,
|
|
11
11
|
metric: String,
|
|
12
12
|
) -> Result<RHash, Error> {
|
|
13
|
+
let ruby = Ruby::get().unwrap();
|
|
14
|
+
|
|
13
15
|
// Convert Ruby array to Vec<Vec<f64>> using shared helper
|
|
14
16
|
let data_vec = ruby_array_to_vec_vec_f64(data)?;
|
|
15
17
|
let n_samples = data_vec.len();
|
|
16
|
-
|
|
17
|
-
// Note: hdbscan crate doesn't support custom metrics directly
|
|
18
|
-
// We'll use the default Euclidean distance for now
|
|
18
|
+
|
|
19
19
|
if metric != "euclidean" && metric != "l2" {
|
|
20
20
|
eprintln!("Warning: Current hdbscan version only supports Euclidean distance. Using Euclidean.");
|
|
21
21
|
}
|
|
22
|
-
|
|
22
|
+
|
|
23
23
|
// Adjust parameters to avoid index out of bounds errors
|
|
24
|
-
// The hdbscan crate has issues when min_samples >= n_samples
|
|
25
24
|
let adjusted_min_samples = min_samples.min(n_samples.saturating_sub(1)).max(1);
|
|
26
25
|
let adjusted_min_cluster_size = min_cluster_size.min(n_samples).max(2);
|
|
27
|
-
|
|
26
|
+
|
|
28
27
|
// Create hyperparameters
|
|
29
28
|
let hyper_params = HdbscanHyperParams::builder()
|
|
30
29
|
.min_cluster_size(adjusted_min_cluster_size)
|
|
31
30
|
.min_samples(adjusted_min_samples)
|
|
32
31
|
.build();
|
|
33
|
-
|
|
32
|
+
|
|
34
33
|
// Create HDBSCAN instance and run clustering
|
|
35
34
|
let clusterer = Hdbscan::new(&data_vec, hyper_params);
|
|
36
|
-
|
|
37
|
-
// Run the clustering algorithm - cluster() returns Result<Vec<i32>, HdbscanError>
|
|
35
|
+
|
|
38
36
|
let labels = clusterer.cluster().map_err(|e| {
|
|
39
37
|
Error::new(
|
|
40
|
-
|
|
38
|
+
ruby.exception_runtime_error(),
|
|
41
39
|
format!("HDBSCAN clustering failed: {:?}", e)
|
|
42
40
|
)
|
|
43
41
|
})?;
|
|
44
|
-
|
|
42
|
+
|
|
45
43
|
// Convert results to Ruby types
|
|
46
|
-
let
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
// Convert labels (i32 to Ruby Integer, -1 for noise)
|
|
50
|
-
let labels_array = RArray::new();
|
|
44
|
+
let result = ruby.hash_new();
|
|
45
|
+
|
|
46
|
+
let labels_array = ruby.ary_new();
|
|
51
47
|
for &label in labels.iter() {
|
|
52
|
-
labels_array.push(
|
|
53
|
-
ruby.eval(&format!("{}", label)).unwrap()
|
|
54
|
-
).unwrap())?;
|
|
48
|
+
labels_array.push(ruby.integer_from_i64(label as i64))?;
|
|
55
49
|
}
|
|
56
50
|
result.aset("labels", labels_array)?;
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
// since the basic hdbscan crate doesn't provide these
|
|
60
|
-
// In the future, we could calculate these ourselves or use a more advanced implementation
|
|
61
|
-
|
|
62
|
-
// Create probabilities array (all 1.0 for clustered points, 0.0 for noise)
|
|
63
|
-
let probs_array = RArray::new();
|
|
51
|
+
|
|
52
|
+
let probs_array = ruby.ary_new();
|
|
64
53
|
for &label in labels.iter() {
|
|
65
54
|
let prob = if label == -1 { 0.0 } else { 1.0 };
|
|
66
55
|
probs_array.push(prob)?;
|
|
67
56
|
}
|
|
68
57
|
result.aset("probabilities", probs_array)?;
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
let outlier_array = RArray::new();
|
|
58
|
+
|
|
59
|
+
let outlier_array = ruby.ary_new();
|
|
72
60
|
for &label in labels.iter() {
|
|
73
61
|
let score = if label == -1 { 1.0 } else { 0.0 };
|
|
74
62
|
outlier_array.push(score)?;
|
|
75
63
|
}
|
|
76
64
|
result.aset("outlier_scores", outlier_array)?;
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
let persistence_hash = RHash::new();
|
|
65
|
+
|
|
66
|
+
let persistence_hash = ruby.hash_new();
|
|
80
67
|
result.aset("cluster_persistence", persistence_hash)?;
|
|
81
|
-
|
|
68
|
+
|
|
82
69
|
Ok(result)
|
|
83
70
|
}
|
|
84
71
|
|
|
@@ -88,6 +75,6 @@ pub fn init(clustering_module: &magnus::RModule) -> Result<(), Error> {
|
|
|
88
75
|
"hdbscan_rust",
|
|
89
76
|
function!(hdbscan_fit, 4),
|
|
90
77
|
)?;
|
|
91
|
-
|
|
78
|
+
|
|
92
79
|
Ok(())
|
|
93
|
-
}
|
|
80
|
+
}
|
|
@@ -1,50 +1,52 @@
|
|
|
1
|
-
use magnus::{function, prelude::*, Error, Value, RArray,
|
|
1
|
+
use magnus::{function, prelude::*, Error, Value, RArray, Ruby};
|
|
2
2
|
use ndarray::{Array1, Array2, ArrayView1, Axis};
|
|
3
3
|
use rand::prelude::*;
|
|
4
4
|
use rand::rngs::StdRng;
|
|
5
5
|
use rand::SeedableRng;
|
|
6
|
-
use crate::utils::
|
|
6
|
+
use crate::utils::ruby_array_to_ndarray2;
|
|
7
7
|
|
|
8
8
|
mod hdbscan_wrapper;
|
|
9
9
|
|
|
10
10
|
pub fn init(parent: &magnus::RModule) -> Result<(), Error> {
|
|
11
11
|
let clustering_module = parent.define_module("Clustering")?;
|
|
12
|
-
|
|
12
|
+
|
|
13
13
|
clustering_module.define_singleton_method(
|
|
14
14
|
"kmeans_rust",
|
|
15
15
|
function!(kmeans, 4),
|
|
16
16
|
)?;
|
|
17
|
-
|
|
17
|
+
|
|
18
18
|
clustering_module.define_singleton_method(
|
|
19
19
|
"kmeans_predict_rust",
|
|
20
20
|
function!(kmeans_predict, 2),
|
|
21
21
|
)?;
|
|
22
|
-
|
|
22
|
+
|
|
23
23
|
// Initialize HDBSCAN functions
|
|
24
24
|
hdbscan_wrapper::init(&clustering_module)?;
|
|
25
|
-
|
|
25
|
+
|
|
26
26
|
Ok(())
|
|
27
27
|
}
|
|
28
28
|
|
|
29
29
|
/// Perform K-means clustering
|
|
30
30
|
/// Returns (labels, centroids, inertia)
|
|
31
31
|
fn kmeans(data: Value, k: usize, max_iter: usize, random_seed: Option<i64>) -> Result<(RArray, RArray, f64), Error> {
|
|
32
|
+
let ruby = Ruby::get().unwrap();
|
|
33
|
+
|
|
32
34
|
// Convert Ruby array to ndarray using shared helper
|
|
33
35
|
let data_array = ruby_array_to_ndarray2(data)?;
|
|
34
36
|
let (n_samples, n_features) = data_array.dim();
|
|
35
|
-
|
|
37
|
+
|
|
36
38
|
if k > n_samples {
|
|
37
39
|
return Err(Error::new(
|
|
38
|
-
|
|
40
|
+
ruby.exception_arg_error(),
|
|
39
41
|
format!("k ({}) cannot be larger than number of samples ({})", k, n_samples),
|
|
40
42
|
));
|
|
41
43
|
}
|
|
42
|
-
|
|
44
|
+
|
|
43
45
|
// Initialize centroids using K-means++
|
|
44
46
|
let mut centroids = kmeans_plusplus(&data_array, k, random_seed)?;
|
|
45
47
|
let mut labels = vec![0usize; n_samples];
|
|
46
48
|
let mut prev_labels = vec![0usize; n_samples];
|
|
47
|
-
|
|
49
|
+
|
|
48
50
|
// K-means iterations
|
|
49
51
|
for iteration in 0..max_iter {
|
|
50
52
|
// Assign points to nearest centroid
|
|
@@ -53,7 +55,7 @@ fn kmeans(data: Value, k: usize, max_iter: usize, random_seed: Option<i64>) -> R
|
|
|
53
55
|
let point = data_array.row(i);
|
|
54
56
|
let mut min_dist = f64::INFINITY;
|
|
55
57
|
let mut best_cluster = 0;
|
|
56
|
-
|
|
58
|
+
|
|
57
59
|
for (j, centroid) in centroids.axis_iter(Axis(0)).enumerate() {
|
|
58
60
|
let dist = euclidean_distance(&point, ¢roid);
|
|
59
61
|
if dist < min_dist {
|
|
@@ -61,38 +63,38 @@ fn kmeans(data: Value, k: usize, max_iter: usize, random_seed: Option<i64>) -> R
|
|
|
61
63
|
best_cluster = j;
|
|
62
64
|
}
|
|
63
65
|
}
|
|
64
|
-
|
|
66
|
+
|
|
65
67
|
if labels[i] != best_cluster {
|
|
66
68
|
changed = true;
|
|
67
69
|
}
|
|
68
70
|
labels[i] = best_cluster;
|
|
69
71
|
}
|
|
70
|
-
|
|
72
|
+
|
|
71
73
|
// Check for convergence
|
|
72
74
|
if !changed && iteration > 0 {
|
|
73
75
|
break;
|
|
74
76
|
}
|
|
75
|
-
|
|
77
|
+
|
|
76
78
|
// Update centroids
|
|
77
79
|
for j in 0..k {
|
|
78
80
|
let mut sum = Array1::<f64>::zeros(n_features);
|
|
79
81
|
let mut count = 0;
|
|
80
|
-
|
|
82
|
+
|
|
81
83
|
for i in 0..n_samples {
|
|
82
84
|
if labels[i] == j {
|
|
83
85
|
sum += &data_array.row(i);
|
|
84
86
|
count += 1;
|
|
85
87
|
}
|
|
86
88
|
}
|
|
87
|
-
|
|
89
|
+
|
|
88
90
|
if count > 0 {
|
|
89
91
|
centroids.row_mut(j).assign(&(sum / count as f64));
|
|
90
92
|
}
|
|
91
93
|
}
|
|
92
|
-
|
|
94
|
+
|
|
93
95
|
prev_labels.clone_from(&labels);
|
|
94
96
|
}
|
|
95
|
-
|
|
97
|
+
|
|
96
98
|
// Calculate inertia (sum of squared distances to nearest centroid)
|
|
97
99
|
let mut inertia = 0.0;
|
|
98
100
|
for i in 0..n_samples {
|
|
@@ -100,44 +102,43 @@ fn kmeans(data: Value, k: usize, max_iter: usize, random_seed: Option<i64>) -> R
|
|
|
100
102
|
let centroid = centroids.row(labels[i]);
|
|
101
103
|
inertia += euclidean_distance(&point, ¢roid).powi(2);
|
|
102
104
|
}
|
|
103
|
-
|
|
105
|
+
|
|
104
106
|
// Convert results to Ruby arrays
|
|
105
|
-
let
|
|
106
|
-
let labels_array = RArray::new();
|
|
107
|
+
let labels_array = ruby.ary_new();
|
|
107
108
|
for label in labels {
|
|
108
|
-
labels_array.push(
|
|
109
|
+
labels_array.push(ruby.integer_from_i64(label as i64))?;
|
|
109
110
|
}
|
|
110
|
-
|
|
111
|
-
let centroids_array =
|
|
111
|
+
|
|
112
|
+
let centroids_array = ruby.ary_new();
|
|
112
113
|
for i in 0..k {
|
|
113
|
-
let row_array =
|
|
114
|
+
let row_array = ruby.ary_new();
|
|
114
115
|
for j in 0..n_features {
|
|
115
116
|
row_array.push(centroids[[i, j]])?;
|
|
116
117
|
}
|
|
117
118
|
centroids_array.push(row_array)?;
|
|
118
119
|
}
|
|
119
|
-
|
|
120
|
+
|
|
120
121
|
Ok((labels_array, centroids_array, inertia))
|
|
121
122
|
}
|
|
122
123
|
|
|
123
124
|
/// Predict cluster labels for new data given centroids
|
|
124
125
|
fn kmeans_predict(data: Value, centroids: Value) -> Result<RArray, Error> {
|
|
126
|
+
let ruby = Ruby::get().unwrap();
|
|
127
|
+
|
|
125
128
|
// Convert inputs using shared helpers
|
|
126
129
|
let data_matrix = ruby_array_to_ndarray2(data)?;
|
|
127
130
|
let centroids_matrix = ruby_array_to_ndarray2(centroids)?;
|
|
128
|
-
|
|
131
|
+
|
|
129
132
|
let (n_samples, _) = data_matrix.dim();
|
|
130
|
-
|
|
131
|
-
|
|
133
|
+
|
|
132
134
|
// Predict labels
|
|
133
|
-
let
|
|
134
|
-
|
|
135
|
-
|
|
135
|
+
let labels_array = ruby.ary_new();
|
|
136
|
+
|
|
136
137
|
for i in 0..n_samples {
|
|
137
138
|
let point = data_matrix.row(i);
|
|
138
139
|
let mut min_dist = f64::INFINITY;
|
|
139
140
|
let mut best_cluster = 0;
|
|
140
|
-
|
|
141
|
+
|
|
141
142
|
for (j, centroid) in centroids_matrix.axis_iter(Axis(0)).enumerate() {
|
|
142
143
|
let dist = euclidean_distance(&point, ¢roid);
|
|
143
144
|
if dist < min_dist {
|
|
@@ -145,10 +146,10 @@ fn kmeans_predict(data: Value, centroids: Value) -> Result<RArray, Error> {
|
|
|
145
146
|
best_cluster = j;
|
|
146
147
|
}
|
|
147
148
|
}
|
|
148
|
-
|
|
149
|
-
labels_array.push(
|
|
149
|
+
|
|
150
|
+
labels_array.push(ruby.integer_from_i64(best_cluster as i64))?;
|
|
150
151
|
}
|
|
151
|
-
|
|
152
|
+
|
|
152
153
|
Ok(labels_array)
|
|
153
154
|
}
|
|
154
155
|
|
|
@@ -156,28 +157,26 @@ fn kmeans_predict(data: Value, centroids: Value) -> Result<RArray, Error> {
|
|
|
156
157
|
fn kmeans_plusplus(data: &Array2<f64>, k: usize, random_seed: Option<i64>) -> Result<Array2<f64>, Error> {
|
|
157
158
|
let n_samples = data.nrows();
|
|
158
159
|
let n_features = data.ncols();
|
|
159
|
-
|
|
160
|
+
|
|
160
161
|
// Use seeded RNG if seed is provided, otherwise use thread_rng
|
|
161
162
|
let mut rng: Box<dyn RngCore> = match random_seed {
|
|
162
163
|
Some(seed) => {
|
|
163
|
-
// Convert i64 to u64 for seeding (negative numbers wrap around)
|
|
164
164
|
let seed_u64 = seed as u64;
|
|
165
165
|
Box::new(StdRng::seed_from_u64(seed_u64))
|
|
166
166
|
},
|
|
167
167
|
None => Box::new(thread_rng()),
|
|
168
168
|
};
|
|
169
|
-
|
|
169
|
+
|
|
170
170
|
let mut centroids = Array2::<f64>::zeros((k, n_features));
|
|
171
|
-
|
|
171
|
+
|
|
172
172
|
// Choose first centroid randomly
|
|
173
173
|
let first_idx = rng.gen_range(0..n_samples);
|
|
174
174
|
centroids.row_mut(0).assign(&data.row(first_idx));
|
|
175
|
-
|
|
175
|
+
|
|
176
176
|
// Choose remaining centroids
|
|
177
177
|
for i in 1..k {
|
|
178
178
|
let mut distances = vec![f64::INFINITY; n_samples];
|
|
179
|
-
|
|
180
|
-
// Calculate distance to nearest centroid for each point
|
|
179
|
+
|
|
181
180
|
for j in 0..n_samples {
|
|
182
181
|
for c in 0..i {
|
|
183
182
|
let dist = euclidean_distance(&data.row(j), ¢roids.row(c));
|
|
@@ -186,25 +185,20 @@ fn kmeans_plusplus(data: &Array2<f64>, k: usize, random_seed: Option<i64>) -> Re
|
|
|
186
185
|
}
|
|
187
186
|
}
|
|
188
187
|
}
|
|
189
|
-
|
|
190
|
-
// Convert distances to probabilities
|
|
188
|
+
|
|
191
189
|
let total: f64 = distances.iter().map(|d| d * d).sum();
|
|
192
190
|
if total == 0.0 {
|
|
193
|
-
// All points are identical or we've selected duplicates
|
|
194
|
-
// Just use sequential points as centroids
|
|
195
191
|
if i < n_samples {
|
|
196
192
|
centroids.row_mut(i).assign(&data.row(i));
|
|
197
193
|
} else {
|
|
198
|
-
// Reuse first point if we run out
|
|
199
194
|
centroids.row_mut(i).assign(&data.row(0));
|
|
200
195
|
}
|
|
201
196
|
continue;
|
|
202
197
|
}
|
|
203
|
-
|
|
204
|
-
// Choose next centroid with probability proportional to squared distance
|
|
198
|
+
|
|
205
199
|
let mut cumsum = 0.0;
|
|
206
200
|
let rand_val: f64 = rng.gen::<f64>() * total;
|
|
207
|
-
|
|
201
|
+
|
|
208
202
|
for j in 0..n_samples {
|
|
209
203
|
cumsum += distances[j] * distances[j];
|
|
210
204
|
if cumsum >= rand_val {
|
|
@@ -213,7 +207,7 @@ fn kmeans_plusplus(data: &Array2<f64>, k: usize, random_seed: Option<i64>) -> Re
|
|
|
213
207
|
}
|
|
214
208
|
}
|
|
215
209
|
}
|
|
216
|
-
|
|
210
|
+
|
|
217
211
|
Ok(centroids)
|
|
218
212
|
}
|
|
219
213
|
|
|
@@ -224,4 +218,4 @@ fn euclidean_distance(a: &ArrayView1<f64>, b: &ArrayView1<f64>) -> f64 {
|
|
|
224
218
|
.map(|(x, y)| (x - y).powi(2))
|
|
225
219
|
.sum::<f64>()
|
|
226
220
|
.sqrt()
|
|
227
|
-
}
|
|
221
|
+
}
|