svmkit 0.7.3 → 0.8.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +0 -9
- data/.rspec +1 -0
- data/.travis.yml +4 -12
- data/LICENSE.txt +1 -1
- data/README.md +11 -13
- data/lib/svmkit.rb +3 -66
- data/svmkit.gemspec +12 -7
- metadata +16 -81
- data/.coveralls.yml +0 -1
- data/.rubocop.yml +0 -47
- data/.rubocop_todo.yml +0 -58
- data/HISTORY.md +0 -168
- data/lib/svmkit/base/base_estimator.rb +0 -13
- data/lib/svmkit/base/classifier.rb +0 -34
- data/lib/svmkit/base/cluster_analyzer.rb +0 -29
- data/lib/svmkit/base/evaluator.rb +0 -13
- data/lib/svmkit/base/regressor.rb +0 -34
- data/lib/svmkit/base/splitter.rb +0 -17
- data/lib/svmkit/base/transformer.rb +0 -18
- data/lib/svmkit/clustering/dbscan.rb +0 -127
- data/lib/svmkit/clustering/k_means.rb +0 -140
- data/lib/svmkit/dataset.rb +0 -109
- data/lib/svmkit/decomposition/nmf.rb +0 -147
- data/lib/svmkit/decomposition/pca.rb +0 -150
- data/lib/svmkit/ensemble/ada_boost_classifier.rb +0 -198
- data/lib/svmkit/ensemble/ada_boost_regressor.rb +0 -180
- data/lib/svmkit/ensemble/random_forest_classifier.rb +0 -182
- data/lib/svmkit/ensemble/random_forest_regressor.rb +0 -143
- data/lib/svmkit/evaluation_measure/accuracy.rb +0 -30
- data/lib/svmkit/evaluation_measure/f_score.rb +0 -51
- data/lib/svmkit/evaluation_measure/log_loss.rb +0 -46
- data/lib/svmkit/evaluation_measure/mean_absolute_error.rb +0 -30
- data/lib/svmkit/evaluation_measure/mean_squared_error.rb +0 -30
- data/lib/svmkit/evaluation_measure/normalized_mutual_information.rb +0 -63
- data/lib/svmkit/evaluation_measure/precision.rb +0 -51
- data/lib/svmkit/evaluation_measure/precision_recall.rb +0 -91
- data/lib/svmkit/evaluation_measure/purity.rb +0 -41
- data/lib/svmkit/evaluation_measure/r2_score.rb +0 -44
- data/lib/svmkit/evaluation_measure/recall.rb +0 -51
- data/lib/svmkit/kernel_approximation/rbf.rb +0 -136
- data/lib/svmkit/kernel_machine/kernel_svc.rb +0 -194
- data/lib/svmkit/linear_model/lasso.rb +0 -138
- data/lib/svmkit/linear_model/linear_regression.rb +0 -112
- data/lib/svmkit/linear_model/logistic_regression.rb +0 -161
- data/lib/svmkit/linear_model/ridge.rb +0 -112
- data/lib/svmkit/linear_model/sgd_linear_estimator.rb +0 -89
- data/lib/svmkit/linear_model/svc.rb +0 -184
- data/lib/svmkit/linear_model/svr.rb +0 -123
- data/lib/svmkit/model_selection/cross_validation.rb +0 -121
- data/lib/svmkit/model_selection/grid_search_cv.rb +0 -247
- data/lib/svmkit/model_selection/k_fold.rb +0 -77
- data/lib/svmkit/model_selection/stratified_k_fold.rb +0 -95
- data/lib/svmkit/multiclass/one_vs_rest_classifier.rb +0 -101
- data/lib/svmkit/naive_bayes/naive_bayes.rb +0 -316
- data/lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb +0 -112
- data/lib/svmkit/nearest_neighbors/k_neighbors_regressor.rb +0 -94
- data/lib/svmkit/optimizer/nadam.rb +0 -90
- data/lib/svmkit/optimizer/rmsprop.rb +0 -69
- data/lib/svmkit/optimizer/sgd.rb +0 -65
- data/lib/svmkit/optimizer/yellow_fin.rb +0 -144
- data/lib/svmkit/pairwise_metric.rb +0 -91
- data/lib/svmkit/pipeline/pipeline.rb +0 -197
- data/lib/svmkit/polynomial_model/factorization_machine_classifier.rb +0 -262
- data/lib/svmkit/polynomial_model/factorization_machine_regressor.rb +0 -194
- data/lib/svmkit/preprocessing/l2_normalizer.rb +0 -63
- data/lib/svmkit/preprocessing/label_encoder.rb +0 -95
- data/lib/svmkit/preprocessing/min_max_scaler.rb +0 -93
- data/lib/svmkit/preprocessing/one_hot_encoder.rb +0 -99
- data/lib/svmkit/preprocessing/standard_scaler.rb +0 -87
- data/lib/svmkit/probabilistic_output.rb +0 -112
- data/lib/svmkit/tree/decision_tree_classifier.rb +0 -276
- data/lib/svmkit/tree/decision_tree_regressor.rb +0 -251
- data/lib/svmkit/tree/node.rb +0 -70
- data/lib/svmkit/utils.rb +0 -22
- data/lib/svmkit/validation.rb +0 -79
- data/lib/svmkit/values.rb +0 -13
- data/lib/svmkit/version.rb +0 -7
data/.rubocop_todo.yml
DELETED
@@ -1,58 +0,0 @@
-# This configuration was generated by
-# `rubocop --auto-gen-config`
-# on 2018-06-10 12:21:53 +0900 using RuboCop version 0.57.1.
-# The point is for the user to remove these configuration records
-# one by one as the offenses are removed from the code base.
-# Note that changes in the inspected code, or installation of new
-# versions of RuboCop, may require this file to be generated again.
-
-# Offense count: 2
-# Cop supports --auto-correct.
-Layout/ClosingHeredocIndentation:
-  Exclude:
-    - 'svmkit.gemspec'
-
-# Offense count: 2
-# Cop supports --auto-correct.
-# Configuration parameters: EnforcedStyle.
-# SupportedStyles: auto_detection, squiggly, active_support, powerpack, unindent
-Layout/IndentHeredoc:
-  Exclude:
-    - 'svmkit.gemspec'
-
-# Offense count: 1
-# Cop supports --auto-correct.
-Layout/LeadingBlankLines:
-  Exclude:
-    - 'svmkit.gemspec'
-
-# Offense count: 1
-# Configuration parameters: CountComments, ExcludedMethods.
-Metrics/BlockLength:
-  Max: 29
-
-# Offense count: 3
-Metrics/CyclomaticComplexity:
-  Max: 12
-
-# Offense count: 3
-Metrics/PerceivedComplexity:
-  Max: 13
-
-# Offense count: 1
-# Cop supports --auto-correct.
-# Configuration parameters: EnforcedStyle, UseHashRocketsWithSymbolValues, PreferHashRocketsForNonAlnumEndingSymbols.
-# SupportedStyles: ruby19, hash_rockets, no_mixed_keys, ruby19_no_mixed_keys
-Style/HashSyntax:
-  Exclude:
-    - 'Rakefile'
-
-# Offense count: 6
-# Cop supports --auto-correct.
-# Configuration parameters: EnforcedStyle, ConsistentQuotesInMultiline.
-# SupportedStyles: single_quotes, double_quotes
-Style/StringLiterals:
-  Exclude:
-    - 'Gemfile'
-    - 'Rakefile'
-    - 'bin/console'
data/HISTORY.md
DELETED
@@ -1,168 +0,0 @@
-# 0.7.3
-- Add class for grid search performing hyperparameter optimization.
-- Add argument validations to Pipeline.
-
-# 0.7.2
-- Add class for Pipeline that constructs chain of transformers and estimators.
-- Fix some typos on document ([#1](https://github.com/yoshoku/SVMKit/pull/1)).
-
-# 0.7.1
-- Fix to use CSV class in parsing libsvm format file.
-- Refactor ensemble estimators.
-
-# 0.7.0
-- Add class for AdaBoost classifier.
-- Add class for AdaBoost regressor.
-
-# 0.6.3
-- Fix bug on setting random seed and max_features parameter of Random Forest estimators.
-
-# 0.6.2
-- Refactor decision tree classes for improving performance.
-
-# 0.6.1
-- Add abstract class for linear estimator with stochastic gradient descent.
-- Refactor linear estimators to use linear esitmator abstract class.
-- Refactor decision tree classes to avoid unneeded type conversion.
-
-# 0.6.0
-- Add class for Principal Component Analysis.
-- Add class for Non-negative Matrix Factorization.
-
-# 0.5.2
-- Add class for DBSCAN clustering.
-
-# 0.5.1
-- Fix bug on class probability calculation of DecisionTreeClassifier.
-
-# 0.5.0
-- Add class for K-Means clustering.
-- Add class for evaluating purity.
-- Add class for evaluating normalized mutual information.
-
-# 0.4.1
-- Add class for linear regressor.
-- Add class for SGD optimizer.
-- Add class for RMSProp optimizer.
-- Add class for YellowFin optimizer.
-- Fix to be able to select optimizer on estimators of LineaModel and PolynomialModel.
-
-# 0.4.0
-## Breaking changes
-
-SVMKit introduces optimizer algorithm that calculates learning rates adaptively
-on each iteration of stochastic gradient descent (SGD).
-While Pegasos SGD runs fast, it sometimes fails to optimize complicated models
-like Factorization Machine.
-To solve this problem, in version 0.3.3, SVMKit introduced optimization with RMSProp on
-FactorizationMachineRegressor, Ridge and Lasso.
-This attempt realized stable optimization of those estimators.
-Following the success of the attempt, author decided to use modern optimizer algorithms
-with all SGD optimizations in SVMKit.
-Through some preliminary experiments, author implemented Nadam as the default optimizer.
-SVMKit plans to add other optimizer algorithms sequentially, so that users can select them.
-
-- Fix to use Nadam for optimization on SVC, SVR, LogisticRegression, Ridge, Lasso, and Factorization Machine estimators.
-- Combine reg_param_weight and reg_param_bias parameters on Factorization Machine estimators into the unified parameter named reg_param_linear.
-- Remove init_std paramter on Factorization Machine estimators.
-- Remove learning_rate, decay, and momentum parameters on Ridge, Lasso, and FactorizationMachineRegressor.
-- Remove normalize parameter on SVC, SVR, and LogisticRegression.
-
-# 0.3.3
-- Add class for Ridge regressor.
-- Add class for Lasso regressor.
-- Fix bug on gradient calculation of FactorizationMachineRegressor.
-- Fix some documents.
-
-# 0.3.2
-- Add class for Factorization Machine regressor.
-- Add class for Decision Tree regressor.
-- Add class for Random Forest regressor.
-- Fix to support loading and dumping libsvm file with multi-target variables.
-- Fix to require DecisionTreeClassifier on RandomForestClassifier.
-- Fix some mistakes on document.
-
-# 0.3.1
-- Fix bug on decision function calculation of FactorizationMachineClassifier.
-- Fix bug on weight updating process of KernelSVC.
-
-# 0.3.0
-- Add class for Support Vector Regression.
-- Add class for K-Nearest Neighbor Regression.
-- Add class for evaluating coefficient of determination.
-- Add class for evaluating mean squared error.
-- Add class for evaluating mean absolute error.
-- Fix to use min method instead of sort and first methods.
-- Fix cross validation class to be able to use for regression problem.
-- Fix some typos on document.
-- Rename spec filename for Factorization Machine classifier.
-
-# 0.2.9
-- Add predict_proba method to SVC and KernelSVC.
-- Add class for evaluating logarithmic loss.
-- Add classes for Label- and One-Hot- encoding.
-- Add some validator.
-- Fix bug on training data score calculation of cross validation.
-- Fix fit method of SVC for performance.
-- Fix criterion calculation on Decision Tree for performance.
-- Fix data structure of Decision Tree for performance.
-
-# 0.2.8
-- Fix bug on gradient calculation of Logistic Regression.
-- Fix to change accessor of params of estimators to read only.
-- Add parameter validation.
-
-# 0.2.7
-- Fix to support multiclass classifiction into LinearSVC, LogisticRegression, KernelSVC, and FactorizationMachineClassifier
-
-# 0.2.6
-- Add class for Decision Tree classifier.
-- Add class for Random Forest classifier.
-- Fix to use frozen string literal.
-- Refactor marshal dump method on some classes.
-- Introduce Coveralls to confirm test coverage.
-
-# 0.2.5
-- Add classes for Naive Bayes classifier.
-- Fix decision function method on Logistic Regression class.
-- Fix method visibility on RBF kernel approximation class.
-
-# 0.2.4
-- Add class for Factorization Machine classifier.
-- Add classes for evaluation measures.
-- Fix the method for prediction of class probability in Logistic Regression.
-
-# 0.2.3
-- Add class for cross validation.
-- Add specs for base modules.
-- Fix validation of the number of splits when a negative label is given.
-
-# 0.2.2
-- Add data splitter classes for K-fold cross validation.
-
-# 0.2.1
-- Add class for K-nearest neighbors classifier.
-
-# 0.2.0
-- Migrated the linear algebra library to Numo::NArray.
-- Add module for loading and saving libsvm format file.
-
-# 0.1.3
-- Add class for Kernel Support Vector Machine with Pegasos algorithm.
-- Add module for calculating pairwise kernel fuctions and euclidean distances.
-
-# 0.1.2
-- Add the function learning a model with bias term to the PegasosSVC and LogisticRegression classes.
-- Rewrite the document with yard notation.
-
-# 0.1.1
-- Add class for Logistic Regression with SGD optimization.
-- Fix some mistakes on the document.
-
-# 0.1.0
-- Add basic classes.
-- Add an utility module.
-- Add class for RBF kernel approximation.
-- Add class for Support Vector Machine with Pegasos alogrithm.
-- Add class that performs mutlclass classification with one-vs.-rest strategy.
-- Add classes for preprocessing such as min-max scaling, standardization, and L2 normalization.
data/lib/svmkit/base/base_estimator.rb
DELETED
@@ -1,13 +0,0 @@
-# frozen_string_literal: true
-
-module SVMKit
-  # This module consists of basic mix-in classes.
-  module Base
-    # Base module for all estimators in SVMKit.
-    module BaseEstimator
-      # Return parameters about an estimator.
-      # @return [Hash]
-      attr_reader :params
-    end
-  end
-end
data/lib/svmkit/base/classifier.rb
DELETED
@@ -1,34 +0,0 @@
-# frozen_string_literal: true
-
-require 'svmkit/validation'
-require 'svmkit/evaluation_measure/accuracy'
-
-module SVMKit
-  module Base
-    # Module for all classifiers in SVMKit.
-    module Classifier
-      # An abstract method for fitting a model.
-      def fit
-        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-      end
-
-      # An abstract method for predicting labels.
-      def predict
-        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-      end
-
-      # Calculate the mean accuracy of the given testing data.
-      #
-      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
-      # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
-      # @return [Float] Mean accuracy
-      def score(x, y)
-        SVMKit::Validation.check_sample_array(x)
-        SVMKit::Validation.check_label_array(y)
-        SVMKit::Validation.check_sample_label_size(x, y)
-        evaluator = SVMKit::EvaluationMeasure::Accuracy.new
-        evaluator.score(y, predict(x))
-      end
-    end
-  end
-end
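For context, every concrete classifier in the 0.7.x line picked up `score` through this mix-in. A minimal usage sketch, assuming the 0.7.x `LinearModel::SVC` constructor keywords and toy Numo arrays invented for illustration:

```ruby
require 'svmkit' # a pre-0.8 release that still bundles these classes

# Tiny, linearly separable data in the types Classifier#score validates
# (Numo::DFloat samples, Numo::Int32 labels).
x = Numo::DFloat[[1.0, 1.0], [1.2, 0.8], [-1.0, -1.0], [-0.8, -1.2]]
y = Numo::Int32[1, 1, -1, -1]

svc = SVMKit::LinearModel::SVC.new(reg_param: 1.0, random_seed: 1)
svc.fit(x, y)
# score calls predict(x) and hands the result to EvaluationMeasure::Accuracy,
# so it returns the mean accuracy as a Float (1.0 expected on this toy set).
puts svc.score(x, y)
```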
data/lib/svmkit/base/cluster_analyzer.rb
DELETED
@@ -1,29 +0,0 @@
-# frozen_string_literal: true
-
-require 'svmkit/validation'
-require 'svmkit/evaluation_measure/purity'
-
-module SVMKit
-  module Base
-    # Module for all clustering algorithms in SVMKit.
-    module ClusterAnalyzer
-      # An abstract method for analyzing clusters and predicting cluster indices.
-      def fit_predict
-        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-      end
-
-      # Calculate purity of clustering result.
-      #
-      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
-      # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
-      # @return [Float] Purity
-      def score(x, y)
-        SVMKit::Validation.check_sample_array(x)
-        SVMKit::Validation.check_label_array(y)
-        SVMKit::Validation.check_sample_label_size(x, y)
-        evaluator = SVMKit::EvaluationMeasure::Purity.new
-        evaluator.score(y, fit_predict(x))
-      end
-    end
-  end
-end
data/lib/svmkit/base/evaluator.rb
DELETED
@@ -1,13 +0,0 @@
-# frozen_string_literal: true
-
-module SVMKit
-  module Base
-    # Module for all evaluation measures in SVMKit.
-    module Evaluator
-      # An abstract method for evaluation of model.
-      def score
-        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-      end
-    end
-  end
-end
data/lib/svmkit/base/regressor.rb
DELETED
@@ -1,34 +0,0 @@
-# frozen_string_literal: true
-
-require 'svmkit/validation'
-require 'svmkit/evaluation_measure/r2_score'
-
-module SVMKit
-  module Base
-    # Module for all regressors in SVMKit.
-    module Regressor
-      # An abstract method for fitting a model.
-      def fit
-        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-      end
-
-      # An abstract method for predicting labels.
-      def predict
-        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-      end
-
-      # Calculate the coefficient of determination for the given testing data.
-      #
-      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
-      # @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) Target values for testing data.
-      # @return [Float] Coefficient of determination
-      def score(x, y)
-        SVMKit::Validation.check_sample_array(x)
-        SVMKit::Validation.check_tvalue_array(y)
-        SVMKit::Validation.check_sample_tvalue_size(x, y)
-        evaluator = SVMKit::EvaluationMeasure::R2Score.new
-        evaluator.score(y, predict(x))
-      end
-    end
-  end
-end
data/lib/svmkit/base/splitter.rb
DELETED
@@ -1,17 +0,0 @@
-# frozen_string_literal: true
-
-module SVMKit
-  module Base
-    # Module for all validation methods in SVMKit.
-    module Splitter
-      # Return the number of splits.
-      # @return [Integer]
-      attr_reader :n_splits
-
-      # An abstract method for splitting dataset.
-      def split
-        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-      end
-    end
-  end
-end
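For context, the concrete splitters removed in this release (k_fold.rb and stratified_k_fold.rb in the file list above) implemented this interface. A rough usage sketch, assuming the 0.7.x KFold keywords (`n_splits:`, `shuffle:`, `random_seed:`) and that `split` returns an array of `[training_ids, testing_ids]` pairs; the data is invented:

```ruby
require 'svmkit' # a pre-0.8 release

x = Numo::DFloat.new(10, 2).rand

kf = SVMKit::ModelSelection::KFold.new(n_splits: 5, shuffle: true, random_seed: 1)
# Splitter exposes n_splits via the attr_reader above and declares #split,
# which concrete splitters override to yield index pairs per fold.
kf.split(x).each do |train_ids, test_ids|
  puts "train: #{train_ids.size}, test: #{test_ids.size}"
end
```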
data/lib/svmkit/base/transformer.rb
DELETED
@@ -1,18 +0,0 @@
-# frozen_string_literal: true
-
-module SVMKit
-  module Base
-    # Module for all transfomers in SVMKit.
-    module Transformer
-      # An abstract method for fitting a model.
-      def fit
-        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-      end
-
-      # An abstract method for fitting a model and transforming given data.
-      def fit_transform
-        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-      end
-    end
-  end
-end
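As a reference point, the preprocessing classes deleted above (e.g. min_max_scaler.rb) implemented this interface, with `fit_transform` as the one-shot fit-plus-transform entry point. A minimal sketch, assuming the 0.7.x MinMaxScaler defaults and toy data invented for illustration:

```ruby
require 'svmkit' # a pre-0.8 release

x = Numo::DFloat[[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]

# MinMaxScaler mixes in Transformer, so it responds to both fit and fit_transform.
scaler = SVMKit::Preprocessing::MinMaxScaler.new
x_scaled = scaler.fit_transform(x)
p x_scaled # each feature column rescaled into the default [0, 1] range
```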
data/lib/svmkit/clustering/dbscan.rb
DELETED
@@ -1,127 +0,0 @@
-# frozen_string_literal: true
-
-require 'svmkit/validation'
-require 'svmkit/base/base_estimator'
-require 'svmkit/base/cluster_analyzer'
-require 'svmkit/pairwise_metric'
-
-module SVMKit
-  module Clustering
-    # DBSCAN is a class that implements DBSCAN cluster analysis.
-    # The current implementation uses the Euclidean distance for analyzing the clusters.
-    #
-    # @example
-    #   analyzer = SVMKit::Clustering::DBSCAN.new(eps: 0.5, min_samples: 5)
-    #   cluster_labels = analyzer.fit_predict(samples)
-    #
-    # *Reference*
-    # - M. Ester, H-P. Kriegel, J. Sander, and X. Xu, "A density-based algorithm for discovering clusters in large spatial databases with noise," Proc. KDD' 96, pp. 266--231, 1996.
-    class DBSCAN
-      include Base::BaseEstimator
-      include Base::ClusterAnalyzer
-      include Validation
-
-      # Return the core sample indices.
-      # @return [Numo::Int32] (shape: [n_core_samples])
-      attr_reader :core_sample_ids
-
-      # Return the cluster labels. The negative cluster label indicates that the point is noise.
-      # @return [Numo::Int32] (shape: [n_samples])
-      attr_reader :labels
-
-      # Create a new cluster analyzer with DBSCAN method.
-      #
-      # @param eps [Float] The radius of neighborhood.
-      # @param min_samples [Integer] The number of neighbor samples to be used for the criterion whether a point is a core point.
-      def initialize(eps: 0.5, min_samples: 5)
-        check_params_float(eps: eps)
-        check_params_integer(min_samples: min_samples)
-        @params = {}
-        @params[:eps] = eps
-        @params[:min_samples] = min_samples
-        @core_sample_ids = nil
-        @labels = nil
-      end
-
-      # Analysis clusters with given training data.
-      #
-      # @overload fit(x) -> DBSCAN
-      #
-      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
-      # @return [DBSCAN] The learned cluster analyzer itself.
-      def fit(x, _y = nil)
-        check_sample_array(x)
-        partial_fit(x)
-        self
-      end
-
-      # Analysis clusters and assign samples to clusters.
-      #
-      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
-      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
-      def fit_predict(x)
-        check_sample_array(x)
-        partial_fit(x)
-        labels
-      end
-
-      # Dump marshal data.
-      # @return [Hash] The marshal data.
-      def marshal_dump
-        { params: @params,
-          core_sample_ids: @core_sample_ids,
-          labels: @labels }
-      end
-
-      # Load marshal data.
-      # @return [nil]
-      def marshal_load(obj)
-        @params = obj[:params]
-        @core_sample_ids = obj[:core_sample_ids]
-        @labels = obj[:labels]
-        nil
-      end
-
-      private
-
-      def partial_fit(x)
-        cluster_id = 0
-        n_samples = x.shape[0]
-        @core_sample_ids = []
-        @labels = Numo::Int32.zeros(n_samples) - 2
-        n_samples.times do |q|
-          next if @labels[q] >= -1
-          cluster_id += 1 if expand_cluster(x, q, cluster_id)
-        end
-        @core_sample_ids = Numo::Int32[*@core_sample_ids.flatten]
-        nil
-      end
-
-      def expand_cluster(x, query_id, cluster_id)
-        target_ids = region_query(x[query_id, true], x)
-        if target_ids.size < @params[:min_samples]
-          @labels[query_id] = -1
-          false
-        else
-          @labels[target_ids] = cluster_id
-          @core_sample_ids.push(target_ids.dup)
-          target_ids.delete(query_id)
-          while (m = target_ids.shift)
-            neighbor_ids = region_query(x[m, true], x)
-            next if neighbor_ids.size < @params[:min_samples]
-            neighbor_ids.each do |n|
-              target_ids.push(n) if @labels[n] < -1
-              @labels[n] = cluster_id if @labels[n] <= -1
-            end
-          end
-          true
-        end
-      end
-
-      def region_query(query, targets)
-        distance_arr = PairwiseMetric.euclidean_distance(query.expand_dims(0), targets)[0, true]
-        distance_arr.lt(@params[:eps]).where.to_a
-      end
-    end
-  end
-end
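For reference, the class documents its own usage in the @example above; the sketch below fleshes it out with invented data (min_samples is lowered to 2 so the six toy points form two clusters instead of all being flagged as noise under the defaults):

```ruby
require 'svmkit' # a pre-0.8 release that still ships Clustering::DBSCAN

# Two dense, well-separated blobs of three points each.
samples = Numo::DFloat[[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
                       [5.0, 5.0], [5.1, 5.0], [5.0, 5.1]]

analyzer = SVMKit::Clustering::DBSCAN.new(eps: 0.5, min_samples: 2)
cluster_labels = analyzer.fit_predict(samples)
p cluster_labels            # two non-negative cluster ids; noise points would get a negative label
p analyzer.core_sample_ids  # indices of the core samples found during fitting
```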