clusterkit 0.1.0 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/Cargo.lock +3236 -0
- data/README.md +227 -7
- data/docs/KNOWN_ISSUES.md +5 -5
- data/docs/RUST_ERROR_HANDLING.md +6 -6
- data/docs/assets/clusterkit-wide.png +0 -0
- data/docs/assets/clusterkit.png +0 -0
- data/docs/assets/visualization.png +0 -0
- data/ext/clusterkit/Cargo.toml +5 -4
- data/ext/clusterkit/extconf.rb +9 -1
- data/ext/clusterkit/src/clustering/hdbscan_wrapper.rs +27 -62
- data/ext/clusterkit/src/clustering.rs +68 -114
- data/ext/clusterkit/src/embedder.rs +48 -131
- data/ext/clusterkit/src/hnsw.rs +579 -0
- data/ext/clusterkit/src/lib.rs +7 -5
- data/ext/clusterkit/src/svd.rs +35 -58
- data/ext/clusterkit/src/utils.rs +159 -9
- data/lib/clusterkit/clustering/hdbscan.rb +4 -17
- data/lib/clusterkit/clustering.rb +4 -23
- data/lib/clusterkit/data_validator.rb +132 -0
- data/lib/clusterkit/dimensionality/pca.rb +12 -12
- data/lib/clusterkit/dimensionality/svd.rb +47 -16
- data/lib/clusterkit/dimensionality/umap.rb +7 -40
- data/lib/clusterkit/hnsw.rb +251 -0
- data/lib/clusterkit/version.rb +1 -1
- data/lib/clusterkit.rb +2 -1
- metadata +40 -20
- data/clusterkit.gemspec +0 -45
data/ext/clusterkit/src/svd.rs
CHANGED
|
@@ -1,112 +1,89 @@
|
|
|
1
|
-
use magnus::{function, prelude::*, Error, Value, RArray,
|
|
1
|
+
use magnus::{function, prelude::*, Error, Value, RArray, Ruby};
|
|
2
2
|
use annembed::tools::svdapprox::{SvdApprox, RangeApproxMode, RangeRank, MatRepr};
|
|
3
|
-
use
|
|
3
|
+
use crate::utils::ruby_array_to_ndarray2;
|
|
4
4
|
|
|
5
5
|
pub fn init(parent: &magnus::RModule) -> Result<(), Error> {
|
|
6
6
|
let svd_module = parent.define_module("SVD")?;
|
|
7
|
-
|
|
7
|
+
|
|
8
8
|
svd_module.define_singleton_method(
|
|
9
9
|
"randomized_svd_rust",
|
|
10
10
|
function!(randomized_svd, 3),
|
|
11
11
|
)?;
|
|
12
|
-
|
|
12
|
+
|
|
13
13
|
Ok(())
|
|
14
14
|
}
|
|
15
15
|
|
|
16
16
|
fn randomized_svd(matrix: Value, k: usize, n_iter: usize) -> Result<RArray, Error> {
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
let
|
|
22
|
-
|
|
23
|
-
let n_cols = first_row.len();
|
|
24
|
-
|
|
25
|
-
if n_rows == 0 || n_cols == 0 {
|
|
26
|
-
return Err(Error::new(
|
|
27
|
-
magnus::exception::arg_error(),
|
|
28
|
-
"Matrix cannot be empty",
|
|
29
|
-
));
|
|
30
|
-
}
|
|
31
|
-
|
|
17
|
+
let ruby = Ruby::get().unwrap();
|
|
18
|
+
|
|
19
|
+
// Convert Ruby array to ndarray using shared helper
|
|
20
|
+
let matrix_data = ruby_array_to_ndarray2(matrix)?;
|
|
21
|
+
let (n_rows, n_cols) = matrix_data.dim();
|
|
22
|
+
|
|
32
23
|
if k > n_rows.min(n_cols) {
|
|
33
24
|
return Err(Error::new(
|
|
34
|
-
|
|
25
|
+
ruby.exception_arg_error(),
|
|
35
26
|
format!("k ({}) cannot be larger than min(rows, cols) = {}", k, n_rows.min(n_cols)),
|
|
36
27
|
));
|
|
37
28
|
}
|
|
38
|
-
|
|
39
|
-
// Convert to ndarray Array2
|
|
40
|
-
let mut matrix_data = Array2::<f64>::zeros((n_rows, n_cols));
|
|
41
|
-
for i in 0..n_rows {
|
|
42
|
-
let row: RArray = rarray.entry(i as isize)?;
|
|
43
|
-
for j in 0..n_cols {
|
|
44
|
-
let val: f64 = row.entry(j as isize)?;
|
|
45
|
-
matrix_data[[i, j]] = val;
|
|
46
|
-
}
|
|
47
|
-
}
|
|
48
|
-
|
|
29
|
+
|
|
49
30
|
// Create MatRepr for the full matrix
|
|
50
31
|
let mat_repr = MatRepr::from_array2(matrix_data.clone());
|
|
51
|
-
|
|
32
|
+
|
|
52
33
|
// Create SvdApprox instance
|
|
53
34
|
let mut svd_approx = SvdApprox::new(&mat_repr);
|
|
54
|
-
|
|
35
|
+
|
|
55
36
|
// Set up parameters for randomized SVD
|
|
56
|
-
// Use RANK mode to specify the desired rank
|
|
57
37
|
let params = RangeApproxMode::RANK(RangeRank::new(k, n_iter));
|
|
58
|
-
|
|
38
|
+
|
|
59
39
|
// Perform SVD
|
|
60
40
|
let svd_result = svd_approx.direct_svd(params)
|
|
61
|
-
.map_err(|e| Error::new(
|
|
62
|
-
|
|
63
|
-
// Extract U, S, V from the result
|
|
41
|
+
.map_err(|e| Error::new(ruby.exception_runtime_error(), e))?;
|
|
42
|
+
|
|
43
|
+
// Extract U, S, V from the result
|
|
64
44
|
let u_matrix = svd_result.u.ok_or_else(|| {
|
|
65
|
-
Error::new(
|
|
45
|
+
Error::new(ruby.exception_runtime_error(), "No U matrix in SVD result")
|
|
66
46
|
})?;
|
|
67
|
-
|
|
47
|
+
|
|
68
48
|
let s_values = svd_result.s.ok_or_else(|| {
|
|
69
|
-
Error::new(
|
|
49
|
+
Error::new(ruby.exception_runtime_error(), "No S values in SVD result")
|
|
70
50
|
})?;
|
|
71
|
-
|
|
51
|
+
|
|
72
52
|
let vt_matrix = svd_result.vt.ok_or_else(|| {
|
|
73
|
-
Error::new(
|
|
53
|
+
Error::new(ruby.exception_runtime_error(), "No V^T matrix in SVD result")
|
|
74
54
|
})?;
|
|
75
|
-
|
|
55
|
+
|
|
76
56
|
// Convert results to Ruby arrays
|
|
77
|
-
|
|
78
|
-
let u_ruby = RArray::new();
|
|
57
|
+
let u_ruby = ruby.ary_new();
|
|
79
58
|
let u_shape = u_matrix.shape();
|
|
80
59
|
for i in 0..u_shape[0] {
|
|
81
|
-
let row =
|
|
60
|
+
let row = ruby.ary_new();
|
|
82
61
|
for j in 0..u_shape[1] {
|
|
83
62
|
row.push(u_matrix[[i, j]])?;
|
|
84
63
|
}
|
|
85
64
|
u_ruby.push(row)?;
|
|
86
65
|
}
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
let s_ruby = RArray::new();
|
|
66
|
+
|
|
67
|
+
let s_ruby = ruby.ary_new();
|
|
90
68
|
for val in s_values.iter() {
|
|
91
69
|
s_ruby.push(*val)?;
|
|
92
70
|
}
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
let v_ruby = RArray::new();
|
|
71
|
+
|
|
72
|
+
let v_ruby = ruby.ary_new();
|
|
96
73
|
let vt_shape = vt_matrix.shape();
|
|
97
74
|
for i in 0..vt_shape[0] {
|
|
98
|
-
let row =
|
|
75
|
+
let row = ruby.ary_new();
|
|
99
76
|
for j in 0..vt_shape[1] {
|
|
100
77
|
row.push(vt_matrix[[i, j]])?;
|
|
101
78
|
}
|
|
102
79
|
v_ruby.push(row)?;
|
|
103
80
|
}
|
|
104
|
-
|
|
81
|
+
|
|
105
82
|
// Return [U, S, V^T] as a Ruby array
|
|
106
|
-
let result =
|
|
83
|
+
let result = ruby.ary_new();
|
|
107
84
|
result.push(u_ruby)?;
|
|
108
85
|
result.push(s_ruby)?;
|
|
109
86
|
result.push(v_ruby)?;
|
|
110
|
-
|
|
87
|
+
|
|
111
88
|
Ok(result)
|
|
112
|
-
}
|
|
89
|
+
}
|
data/ext/clusterkit/src/utils.rs
CHANGED
|
@@ -1,33 +1,183 @@
|
|
|
1
|
-
use magnus::{function, prelude::*, Error, Value};
|
|
1
|
+
use magnus::{function, prelude::*, Error, Value, RArray, TryConvert, Float, Integer, Ruby};
|
|
2
|
+
use ndarray::Array2;
|
|
2
3
|
|
|
3
4
|
pub fn init(parent: &magnus::RModule) -> Result<(), Error> {
|
|
4
5
|
let utils_module = parent.define_module("Utils")?;
|
|
5
|
-
|
|
6
|
+
|
|
6
7
|
utils_module.define_singleton_method(
|
|
7
8
|
"estimate_intrinsic_dimension_rust",
|
|
8
9
|
function!(estimate_intrinsic_dimension, 2),
|
|
9
10
|
)?;
|
|
10
|
-
|
|
11
|
+
|
|
11
12
|
utils_module.define_singleton_method(
|
|
12
13
|
"estimate_hubness_rust",
|
|
13
14
|
function!(estimate_hubness, 1),
|
|
14
15
|
)?;
|
|
15
|
-
|
|
16
|
+
|
|
16
17
|
Ok(())
|
|
17
18
|
}
|
|
18
19
|
|
|
19
20
|
fn estimate_intrinsic_dimension(_data: Value, _k_neighbors: usize) -> Result<f64, Error> {
|
|
20
|
-
|
|
21
|
+
let ruby = Ruby::get().unwrap();
|
|
21
22
|
Err(Error::new(
|
|
22
|
-
|
|
23
|
+
ruby.exception_not_imp_error(),
|
|
23
24
|
"Dimension estimation not implemented yet",
|
|
24
25
|
))
|
|
25
26
|
}
|
|
26
27
|
|
|
27
28
|
fn estimate_hubness(_data: Value) -> Result<Value, Error> {
|
|
28
|
-
|
|
29
|
+
let ruby = Ruby::get().unwrap();
|
|
29
30
|
Err(Error::new(
|
|
30
|
-
|
|
31
|
+
ruby.exception_not_imp_error(),
|
|
31
32
|
"Hubness estimation not implemented yet",
|
|
32
33
|
))
|
|
33
|
-
}
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
/// Convert Ruby 2D array to ndarray Array2<f64>
|
|
37
|
+
/// Handles validation and provides consistent error messages
|
|
38
|
+
pub fn ruby_array_to_ndarray2(data: Value) -> Result<Array2<f64>, Error> {
|
|
39
|
+
let ruby = Ruby::get().unwrap();
|
|
40
|
+
let rarray: RArray = TryConvert::try_convert(data)?;
|
|
41
|
+
let n_samples = rarray.len();
|
|
42
|
+
|
|
43
|
+
if n_samples == 0 {
|
|
44
|
+
return Err(Error::new(
|
|
45
|
+
ruby.exception_arg_error(),
|
|
46
|
+
"Data cannot be empty",
|
|
47
|
+
));
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
// Get dimensions from first row
|
|
51
|
+
let first_row: RArray = rarray.entry::<RArray>(0)?;
|
|
52
|
+
let n_features = first_row.len();
|
|
53
|
+
|
|
54
|
+
if n_features == 0 {
|
|
55
|
+
return Err(Error::new(
|
|
56
|
+
ruby.exception_arg_error(),
|
|
57
|
+
"Data rows cannot be empty",
|
|
58
|
+
));
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
// Create ndarray and populate
|
|
62
|
+
let mut data_array = Array2::<f64>::zeros((n_samples, n_features));
|
|
63
|
+
for i in 0..n_samples {
|
|
64
|
+
let row: RArray = rarray.entry(i as isize)?;
|
|
65
|
+
|
|
66
|
+
// Validate row length consistency
|
|
67
|
+
if row.len() != n_features {
|
|
68
|
+
return Err(Error::new(
|
|
69
|
+
ruby.exception_arg_error(),
|
|
70
|
+
format!("Row {} has {} elements, expected {}", i, row.len(), n_features),
|
|
71
|
+
));
|
|
72
|
+
}
|
|
73
|
+
|
|
74
|
+
for j in 0..n_features {
|
|
75
|
+
let val: f64 = row.entry(j as isize)?;
|
|
76
|
+
data_array[[i, j]] = val;
|
|
77
|
+
}
|
|
78
|
+
}
|
|
79
|
+
|
|
80
|
+
Ok(data_array)
|
|
81
|
+
}
|
|
82
|
+
|
|
83
|
+
/// Convert Ruby 2D array to Vec<Vec<f64>>
|
|
84
|
+
/// Handles validation and provides consistent error messages
|
|
85
|
+
pub fn ruby_array_to_vec_vec_f64(data: Value) -> Result<Vec<Vec<f64>>, Error> {
|
|
86
|
+
let ruby = Ruby::get().unwrap();
|
|
87
|
+
let rarray: RArray = TryConvert::try_convert(data)?;
|
|
88
|
+
let n_samples = rarray.len();
|
|
89
|
+
|
|
90
|
+
if n_samples == 0 {
|
|
91
|
+
return Err(Error::new(
|
|
92
|
+
ruby.exception_arg_error(),
|
|
93
|
+
"Data cannot be empty",
|
|
94
|
+
));
|
|
95
|
+
}
|
|
96
|
+
|
|
97
|
+
let mut data_vec: Vec<Vec<f64>> = Vec::with_capacity(n_samples);
|
|
98
|
+
let mut expected_features: Option<usize> = None;
|
|
99
|
+
|
|
100
|
+
for i in 0..n_samples {
|
|
101
|
+
let row: RArray = rarray.entry(i as isize)?;
|
|
102
|
+
let n_features = row.len();
|
|
103
|
+
|
|
104
|
+
// Check row length consistency
|
|
105
|
+
match expected_features {
|
|
106
|
+
Some(expected) => {
|
|
107
|
+
if n_features != expected {
|
|
108
|
+
return Err(Error::new(
|
|
109
|
+
ruby.exception_arg_error(),
|
|
110
|
+
format!("Row {} has {} elements, expected {}", i, n_features, expected),
|
|
111
|
+
));
|
|
112
|
+
}
|
|
113
|
+
}
|
|
114
|
+
None => expected_features = Some(n_features),
|
|
115
|
+
}
|
|
116
|
+
|
|
117
|
+
let mut row_vec: Vec<f64> = Vec::with_capacity(n_features);
|
|
118
|
+
for j in 0..n_features {
|
|
119
|
+
let val: f64 = row.entry(j as isize)?;
|
|
120
|
+
row_vec.push(val);
|
|
121
|
+
}
|
|
122
|
+
data_vec.push(row_vec);
|
|
123
|
+
}
|
|
124
|
+
|
|
125
|
+
Ok(data_vec)
|
|
126
|
+
}
|
|
127
|
+
|
|
128
|
+
/// Convert Ruby 2D array to Vec<Vec<f32>>
|
|
129
|
+
/// For algorithms that require f32 precision (like UMAP)
|
|
130
|
+
pub fn ruby_array_to_vec_vec_f32(data: Value) -> Result<Vec<Vec<f32>>, Error> {
|
|
131
|
+
let ruby = Ruby::get().unwrap();
|
|
132
|
+
let rarray: RArray = TryConvert::try_convert(data)?;
|
|
133
|
+
let array_len = rarray.len();
|
|
134
|
+
|
|
135
|
+
if array_len == 0 {
|
|
136
|
+
return Err(Error::new(
|
|
137
|
+
ruby.exception_arg_error(),
|
|
138
|
+
"Input data cannot be empty",
|
|
139
|
+
));
|
|
140
|
+
}
|
|
141
|
+
|
|
142
|
+
let mut rust_data: Vec<Vec<f32>> = Vec::with_capacity(array_len);
|
|
143
|
+
|
|
144
|
+
for i in 0..array_len {
|
|
145
|
+
let row = rarray.entry::<Value>(i as isize)?;
|
|
146
|
+
let row_array = RArray::try_convert(row).map_err(|_| {
|
|
147
|
+
Error::new(
|
|
148
|
+
ruby.exception_type_error(),
|
|
149
|
+
"Expected array of arrays (2D array)",
|
|
150
|
+
)
|
|
151
|
+
})?;
|
|
152
|
+
|
|
153
|
+
let mut rust_row: Vec<f32> = Vec::new();
|
|
154
|
+
let row_len = row_array.len();
|
|
155
|
+
|
|
156
|
+
for j in 0..row_len {
|
|
157
|
+
let val = row_array.entry::<Value>(j as isize)?;
|
|
158
|
+
let float_val = if let Ok(f) = Float::try_convert(val) {
|
|
159
|
+
f.to_f64() as f32
|
|
160
|
+
} else if let Ok(i) = Integer::try_convert(val) {
|
|
161
|
+
i.to_i64()? as f32
|
|
162
|
+
} else {
|
|
163
|
+
return Err(Error::new(
|
|
164
|
+
ruby.exception_type_error(),
|
|
165
|
+
"All values must be numeric",
|
|
166
|
+
));
|
|
167
|
+
};
|
|
168
|
+
rust_row.push(float_val);
|
|
169
|
+
}
|
|
170
|
+
|
|
171
|
+
// Validate row length consistency
|
|
172
|
+
if !rust_data.is_empty() && rust_row.len() != rust_data[0].len() {
|
|
173
|
+
return Err(Error::new(
|
|
174
|
+
ruby.exception_arg_error(),
|
|
175
|
+
"All rows must have the same length",
|
|
176
|
+
));
|
|
177
|
+
}
|
|
178
|
+
|
|
179
|
+
rust_data.push(rust_row);
|
|
180
|
+
}
|
|
181
|
+
|
|
182
|
+
Ok(rust_data)
|
|
183
|
+
}
|
|
@@ -1,5 +1,7 @@
|
|
|
1
1
|
# frozen_string_literal: true
|
|
2
2
|
|
|
3
|
+
require_relative '../data_validator'
|
|
4
|
+
|
|
3
5
|
module ClusterKit
|
|
4
6
|
module Clustering
|
|
5
7
|
# HDBSCAN clustering algorithm - matching KMeans API pattern
|
|
@@ -128,23 +130,8 @@ module ClusterKit
|
|
|
128
130
|
private
|
|
129
131
|
|
|
130
132
|
def validate_data(data)
  # Delegate to the shared clustering validator so HDBSCAN and KMeans reject
  # malformed input with identical messages; NaN/Infinity are not checked here.
  DataValidator.validate_clustering(
    data,
    check_finite: false
  )
end
|
|
149
136
|
end
|
|
150
137
|
|
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
require_relative 'clusterkit'
|
|
4
4
|
require_relative 'clustering/hdbscan'
|
|
5
|
+
require_relative 'data_validator'
|
|
5
6
|
|
|
6
7
|
module ClusterKit
|
|
7
8
|
# Module for clustering algorithms
|
|
@@ -28,11 +29,8 @@ module ClusterKit
|
|
|
28
29
|
def fit(data)
|
|
29
30
|
validate_data(data)
|
|
30
31
|
|
|
31
|
-
#
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
# Call Rust implementation
|
|
35
|
-
@labels, @centroids, @inertia = Clustering.kmeans_rust(data, @k, @max_iter)
|
|
32
|
+
# Call Rust implementation with optional seed
|
|
33
|
+
@labels, @centroids, @inertia = Clustering.kmeans_rust(data, @k, @max_iter, @random_seed)
|
|
36
34
|
@fitted = true
|
|
37
35
|
|
|
38
36
|
self
|
|
@@ -132,24 +130,7 @@ module ClusterKit
|
|
|
132
130
|
private
|
|
133
131
|
|
|
134
132
|
def validate_data(data)
  # Shared clustering validation (array, non-empty, 2D, equal row lengths,
  # numeric elements); finite-value checking is explicitly disabled.
  DataValidator.validate_clustering(
    data,
    check_finite: false
  )
end
|
|
154
135
|
end
|
|
155
136
|
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
3
|
+
module ClusterKit
  # Shared data validation methods for all algorithms
  module DataValidator
    class << self
      # Validate basic data structure and types: a non-empty Array whose first
      # element is itself an Array.
      # @param data [Array] Data to validate
      # @raise [ArgumentError] If data structure is invalid
      def validate_basic_structure(data)
        raise ArgumentError, "Input must be an array" unless data.is_a?(Array)
        raise ArgumentError, "Input cannot be empty" if data.empty?

        first_row = data.first
        raise ArgumentError, "Input must be a 2D array (array of arrays)" unless first_row.is_a?(Array)
      end

      # Validate row consistency: every row is an Array with the same length as
      # the first row. An empty +data+ has no rows to compare and passes.
      # @param data [Array] 2D array to validate
      # @raise [ArgumentError] If a row is not an array or rows have different lengths
      def validate_row_consistency(data)
        row_length = nil

        data.each_with_index do |row, i|
          raise ArgumentError, "Row #{i} is not an array" unless row.is_a?(Array)

          # The expected width is taken only after row 0 is known to be an Array,
          # so a non-array first row raises ArgumentError instead of NoMethodError.
          row_length ||= row.length
          if row.length != row_length
            raise ArgumentError, "All rows must have the same length (row #{i} has #{row.length} elements, expected #{row_length})"
          end
        end
      end

      # Validate that all elements are numeric
      # @param data [Array] 2D array to validate
      # @raise [ArgumentError] If any element is not numeric
      def validate_numeric_types(data)
        data.each_with_index do |row, i|
          row.each_with_index do |val, j|
            unless val.is_a?(Numeric)
              raise ArgumentError, "Element at position [#{i}, #{j}] is not numeric"
            end
          end
        end
      end

      # Validate finite values (no NaN or Infinite)
      # @param data [Array] 2D array to validate
      # @raise [ArgumentError] If any float is NaN or Infinite
      def validate_finite_values(data)
        data.each_with_index do |row, i|
          row.each_with_index do |val, j|
            # Only check for NaN/Infinite on floats
            if val.is_a?(Float) && (val.nan? || val.infinite?)
              raise ArgumentError, "Element at position [#{i}, #{j}] is NaN or Infinite"
            end
          end
        end
      end

      # Standard validation for most algorithms
      # @param data [Array] 2D array to validate
      # @param check_finite [Boolean] Whether to check for NaN/Infinite values
      # @raise [ArgumentError] If data is invalid
      def validate_standard(data, check_finite: true)
        validate_basic_structure(data)
        validate_row_consistency(data)
        validate_numeric_types(data)
        validate_finite_values(data) if check_finite
      end

      # Validation for clustering algorithms (KMeans, HDBSCAN) with specific error messages
      # @param data [Array] 2D array to validate
      # @param check_finite [Boolean] Whether to check for NaN/Infinite values
      # @raise [ArgumentError] If data is invalid
      def validate_clustering(data, check_finite: false)
        raise ArgumentError, "Data must be an array" unless data.is_a?(Array)
        raise ArgumentError, "Data cannot be empty" if data.empty?
        raise ArgumentError, "Data must be 2D array" unless data.first.is_a?(Array)

        validate_row_consistency(data)
        validate_numeric_types(data)
        validate_finite_values(data) if check_finite
      end

      # Validation for PCA: identical checks and messages to clustering
      # validation, without finite-value checks.
      # @param data [Array] 2D array to validate
      # @raise [ArgumentError] If data is invalid
      def validate_pca(data)
        validate_clustering(data, check_finite: false)
      end

      # Get data statistics for warnings/error context
      # @param data [Array] 2D array
      # @return [Hash] :n_samples, :n_features and :data_range. It also has
      #   :min_value and :max_value when at least one comparable value exists.
      def data_statistics(data)
        return { n_samples: 0, n_features: 0, data_range: 0.0 } if data.empty?

        n_samples = data.size
        n_features = data.first&.size || 0

        # Calculate data range for warnings
        min_val = Float::INFINITY
        max_val = -Float::INFINITY

        data.each do |row|
          row.each do |val|
            val_f = val.to_f
            min_val = val_f if val_f < min_val
            max_val = val_f if val_f > max_val
          end
        end

        # No comparable value was seen (all rows empty, or every value NaN).
        # Report a zero range instead of -Infinity.
        if min_val > max_val
          return { n_samples: n_samples, n_features: n_features, data_range: 0.0 }
        end

        {
          n_samples: n_samples,
          n_features: n_features,
          data_range: max_val - min_val,
          min_value: min_val,
          max_value: max_val
        }
      end
    end
  end
end
|
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
require_relative '../clusterkit'
|
|
4
4
|
require_relative 'svd'
|
|
5
|
+
require_relative '../data_validator'
|
|
5
6
|
|
|
6
7
|
module ClusterKit
|
|
7
8
|
module Dimensionality
|
|
@@ -30,7 +31,7 @@ module ClusterKit
|
|
|
30
31
|
|
|
31
32
|
# Perform SVD on centered data
|
|
32
33
|
# U contains the transformed data, S contains singular values, VT contains components
|
|
33
|
-
u, s, vt =
|
|
34
|
+
u, s, vt = perform_svd(centered_data)
|
|
34
35
|
|
|
35
36
|
# Store the principal components (eigenvectors)
|
|
36
37
|
@components = vt # Shape: (n_components, n_features)
|
|
@@ -76,7 +77,7 @@ module ClusterKit
|
|
|
76
77
|
centered_data = center_data(data, @mean)
|
|
77
78
|
|
|
78
79
|
# Perform SVD on centered data
|
|
79
|
-
u, s, vt =
|
|
80
|
+
u, s, vt = perform_svd(centered_data)
|
|
80
81
|
|
|
81
82
|
# Store the principal components (eigenvectors)
|
|
82
83
|
@components = vt
|
|
@@ -166,17 +167,10 @@ module ClusterKit
|
|
|
166
167
|
private
|
|
167
168
|
|
|
168
169
|
def validate_data(data)
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
raise ArgumentError, "Data must be 2D array" unless data.first.is_a?(Array)
|
|
172
|
-
|
|
173
|
-
# Check all rows have same length
|
|
174
|
-
row_length = data.first.length
|
|
175
|
-
unless data.all? { |row| row.is_a?(Array) && row.length == row_length }
|
|
176
|
-
raise ArgumentError, "All rows must have the same length"
|
|
177
|
-
end
|
|
170
|
+
# Use shared validation for common checks
|
|
171
|
+
DataValidator.validate_pca(data)
|
|
178
172
|
|
|
179
|
-
#
|
|
173
|
+
# PCA-specific validations
|
|
180
174
|
if data.size < @n_components
|
|
181
175
|
raise ArgumentError, "n_components (#{@n_components}) cannot be larger than n_samples (#{data.size})"
|
|
182
176
|
end
|
|
@@ -237,6 +231,12 @@ module ClusterKit
|
|
|
237
231
|
|
|
238
232
|
transformed
|
|
239
233
|
end
|
|
234
|
+
|
|
235
|
+
# Single point of SVD invocation used by both fit and fit_transform, so the two
# always request @n_components components with the same n_iter setting.
def perform_svd(centered_data)
  iterations = 5
  SVD.randomized_svd(centered_data, @n_components, n_iter: iterations)
end
|
|
240
240
|
end
|
|
241
241
|
|
|
242
242
|
# Module-level convenience method
|