svmkit 0.7.3 → 0.8.1
- checksums.yaml +4 -4
- data/.gitignore +0 -9
- data/.rspec +1 -0
- data/.travis.yml +4 -12
- data/LICENSE.txt +1 -1
- data/README.md +11 -13
- data/lib/svmkit.rb +3 -66
- data/svmkit.gemspec +12 -7
- metadata +16 -81
- data/.coveralls.yml +0 -1
- data/.rubocop.yml +0 -47
- data/.rubocop_todo.yml +0 -58
- data/HISTORY.md +0 -168
- data/lib/svmkit/base/base_estimator.rb +0 -13
- data/lib/svmkit/base/classifier.rb +0 -34
- data/lib/svmkit/base/cluster_analyzer.rb +0 -29
- data/lib/svmkit/base/evaluator.rb +0 -13
- data/lib/svmkit/base/regressor.rb +0 -34
- data/lib/svmkit/base/splitter.rb +0 -17
- data/lib/svmkit/base/transformer.rb +0 -18
- data/lib/svmkit/clustering/dbscan.rb +0 -127
- data/lib/svmkit/clustering/k_means.rb +0 -140
- data/lib/svmkit/dataset.rb +0 -109
- data/lib/svmkit/decomposition/nmf.rb +0 -147
- data/lib/svmkit/decomposition/pca.rb +0 -150
- data/lib/svmkit/ensemble/ada_boost_classifier.rb +0 -198
- data/lib/svmkit/ensemble/ada_boost_regressor.rb +0 -180
- data/lib/svmkit/ensemble/random_forest_classifier.rb +0 -182
- data/lib/svmkit/ensemble/random_forest_regressor.rb +0 -143
- data/lib/svmkit/evaluation_measure/accuracy.rb +0 -30
- data/lib/svmkit/evaluation_measure/f_score.rb +0 -51
- data/lib/svmkit/evaluation_measure/log_loss.rb +0 -46
- data/lib/svmkit/evaluation_measure/mean_absolute_error.rb +0 -30
- data/lib/svmkit/evaluation_measure/mean_squared_error.rb +0 -30
- data/lib/svmkit/evaluation_measure/normalized_mutual_information.rb +0 -63
- data/lib/svmkit/evaluation_measure/precision.rb +0 -51
- data/lib/svmkit/evaluation_measure/precision_recall.rb +0 -91
- data/lib/svmkit/evaluation_measure/purity.rb +0 -41
- data/lib/svmkit/evaluation_measure/r2_score.rb +0 -44
- data/lib/svmkit/evaluation_measure/recall.rb +0 -51
- data/lib/svmkit/kernel_approximation/rbf.rb +0 -136
- data/lib/svmkit/kernel_machine/kernel_svc.rb +0 -194
- data/lib/svmkit/linear_model/lasso.rb +0 -138
- data/lib/svmkit/linear_model/linear_regression.rb +0 -112
- data/lib/svmkit/linear_model/logistic_regression.rb +0 -161
- data/lib/svmkit/linear_model/ridge.rb +0 -112
- data/lib/svmkit/linear_model/sgd_linear_estimator.rb +0 -89
- data/lib/svmkit/linear_model/svc.rb +0 -184
- data/lib/svmkit/linear_model/svr.rb +0 -123
- data/lib/svmkit/model_selection/cross_validation.rb +0 -121
- data/lib/svmkit/model_selection/grid_search_cv.rb +0 -247
- data/lib/svmkit/model_selection/k_fold.rb +0 -77
- data/lib/svmkit/model_selection/stratified_k_fold.rb +0 -95
- data/lib/svmkit/multiclass/one_vs_rest_classifier.rb +0 -101
- data/lib/svmkit/naive_bayes/naive_bayes.rb +0 -316
- data/lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb +0 -112
- data/lib/svmkit/nearest_neighbors/k_neighbors_regressor.rb +0 -94
- data/lib/svmkit/optimizer/nadam.rb +0 -90
- data/lib/svmkit/optimizer/rmsprop.rb +0 -69
- data/lib/svmkit/optimizer/sgd.rb +0 -65
- data/lib/svmkit/optimizer/yellow_fin.rb +0 -144
- data/lib/svmkit/pairwise_metric.rb +0 -91
- data/lib/svmkit/pipeline/pipeline.rb +0 -197
- data/lib/svmkit/polynomial_model/factorization_machine_classifier.rb +0 -262
- data/lib/svmkit/polynomial_model/factorization_machine_regressor.rb +0 -194
- data/lib/svmkit/preprocessing/l2_normalizer.rb +0 -63
- data/lib/svmkit/preprocessing/label_encoder.rb +0 -95
- data/lib/svmkit/preprocessing/min_max_scaler.rb +0 -93
- data/lib/svmkit/preprocessing/one_hot_encoder.rb +0 -99
- data/lib/svmkit/preprocessing/standard_scaler.rb +0 -87
- data/lib/svmkit/probabilistic_output.rb +0 -112
- data/lib/svmkit/tree/decision_tree_classifier.rb +0 -276
- data/lib/svmkit/tree/decision_tree_regressor.rb +0 -251
- data/lib/svmkit/tree/node.rb +0 -70
- data/lib/svmkit/utils.rb +0 -22
- data/lib/svmkit/validation.rb +0 -79
- data/lib/svmkit/values.rb +0 -13
- data/lib/svmkit/version.rb +0 -7
data/.rubocop_todo.yml
DELETED
@@ -1,58 +0,0 @@
-# This configuration was generated by
-# `rubocop --auto-gen-config`
-# on 2018-06-10 12:21:53 +0900 using RuboCop version 0.57.1.
-# The point is for the user to remove these configuration records
-# one by one as the offenses are removed from the code base.
-# Note that changes in the inspected code, or installation of new
-# versions of RuboCop, may require this file to be generated again.
-
-# Offense count: 2
-# Cop supports --auto-correct.
-Layout/ClosingHeredocIndentation:
-  Exclude:
-    - 'svmkit.gemspec'
-
-# Offense count: 2
-# Cop supports --auto-correct.
-# Configuration parameters: EnforcedStyle.
-# SupportedStyles: auto_detection, squiggly, active_support, powerpack, unindent
-Layout/IndentHeredoc:
-  Exclude:
-    - 'svmkit.gemspec'
-
-# Offense count: 1
-# Cop supports --auto-correct.
-Layout/LeadingBlankLines:
-  Exclude:
-    - 'svmkit.gemspec'
-
-# Offense count: 1
-# Configuration parameters: CountComments, ExcludedMethods.
-Metrics/BlockLength:
-  Max: 29
-
-# Offense count: 3
-Metrics/CyclomaticComplexity:
-  Max: 12
-
-# Offense count: 3
-Metrics/PerceivedComplexity:
-  Max: 13
-
-# Offense count: 1
-# Cop supports --auto-correct.
-# Configuration parameters: EnforcedStyle, UseHashRocketsWithSymbolValues, PreferHashRocketsForNonAlnumEndingSymbols.
-# SupportedStyles: ruby19, hash_rockets, no_mixed_keys, ruby19_no_mixed_keys
-Style/HashSyntax:
-  Exclude:
-    - 'Rakefile'
-
-# Offense count: 6
-# Cop supports --auto-correct.
-# Configuration parameters: EnforcedStyle, ConsistentQuotesInMultiline.
-# SupportedStyles: single_quotes, double_quotes
-Style/StringLiterals:
-  Exclude:
-    - 'Gemfile'
-    - 'Rakefile'
-    - 'bin/console'
data/HISTORY.md
DELETED
@@ -1,168 +0,0 @@
-# 0.7.3
-- Add class for grid search performing hyperparameter optimization.
-- Add argument validations to Pipeline.
-
-# 0.7.2
-- Add class for Pipeline that constructs chain of transformers and estimators.
-- Fix some typos on document ([#1](https://github.com/yoshoku/SVMKit/pull/1)).
-
-# 0.7.1
-- Fix to use CSV class in parsing libsvm format file.
-- Refactor ensemble estimators.
-
-# 0.7.0
-- Add class for AdaBoost classifier.
-- Add class for AdaBoost regressor.
-
-# 0.6.3
-- Fix bug on setting random seed and max_features parameter of Random Forest estimators.
-
-# 0.6.2
-- Refactor decision tree classes for improving performance.
-
-# 0.6.1
-- Add abstract class for linear estimator with stochastic gradient descent.
-- Refactor linear estimators to use linear esitmator abstract class.
-- Refactor decision tree classes to avoid unneeded type conversion.
-
-# 0.6.0
-- Add class for Principal Component Analysis.
-- Add class for Non-negative Matrix Factorization.
-
-# 0.5.2
-- Add class for DBSCAN clustering.
-
-# 0.5.1
-- Fix bug on class probability calculation of DecisionTreeClassifier.
-
-# 0.5.0
-- Add class for K-Means clustering.
-- Add class for evaluating purity.
-- Add class for evaluating normalized mutual information.
-
-# 0.4.1
-- Add class for linear regressor.
-- Add class for SGD optimizer.
-- Add class for RMSProp optimizer.
-- Add class for YellowFin optimizer.
-- Fix to be able to select optimizer on estimators of LineaModel and PolynomialModel.
-
-# 0.4.0
-## Breaking changes
-
-SVMKit introduces optimizer algorithm that calculates learning rates adaptively
-on each iteration of stochastic gradient descent (SGD).
-While Pegasos SGD runs fast, it sometimes fails to optimize complicated models
-like Factorization Machine.
-To solve this problem, in version 0.3.3, SVMKit introduced optimization with RMSProp on
-FactorizationMachineRegressor, Ridge and Lasso.
-This attempt realized stable optimization of those estimators.
-Following the success of the attempt, author decided to use modern optimizer algorithms
-with all SGD optimizations in SVMKit.
-Through some preliminary experiments, author implemented Nadam as the default optimizer.
-SVMKit plans to add other optimizer algorithms sequentially, so that users can select them.
-
-- Fix to use Nadam for optimization on SVC, SVR, LogisticRegression, Ridge, Lasso, and Factorization Machine estimators.
-- Combine reg_param_weight and reg_param_bias parameters on Factorization Machine estimators into the unified parameter named reg_param_linear.
-- Remove init_std paramter on Factorization Machine estimators.
-- Remove learning_rate, decay, and momentum parameters on Ridge, Lasso, and FactorizationMachineRegressor.
-- Remove normalize parameter on SVC, SVR, and LogisticRegression.
-
-# 0.3.3
-- Add class for Ridge regressor.
-- Add class for Lasso regressor.
-- Fix bug on gradient calculation of FactorizationMachineRegressor.
-- Fix some documents.
-
-# 0.3.2
-- Add class for Factorization Machine regressor.
-- Add class for Decision Tree regressor.
-- Add class for Random Forest regressor.
-- Fix to support loading and dumping libsvm file with multi-target variables.
-- Fix to require DecisionTreeClassifier on RandomForestClassifier.
-- Fix some mistakes on document.
-
-# 0.3.1
-- Fix bug on decision function calculation of FactorizationMachineClassifier.
-- Fix bug on weight updating process of KernelSVC.
-
-# 0.3.0
-- Add class for Support Vector Regression.
-- Add class for K-Nearest Neighbor Regression.
-- Add class for evaluating coefficient of determination.
-- Add class for evaluating mean squared error.
-- Add class for evaluating mean absolute error.
-- Fix to use min method instead of sort and first methods.
-- Fix cross validation class to be able to use for regression problem.
-- Fix some typos on document.
-- Rename spec filename for Factorization Machine classifier.
-
-# 0.2.9
-- Add predict_proba method to SVC and KernelSVC.
-- Add class for evaluating logarithmic loss.
-- Add classes for Label- and One-Hot- encoding.
-- Add some validator.
-- Fix bug on training data score calculation of cross validation.
-- Fix fit method of SVC for performance.
-- Fix criterion calculation on Decision Tree for performance.
-- Fix data structure of Decision Tree for performance.
-
-# 0.2.8
-- Fix bug on gradient calculation of Logistic Regression.
-- Fix to change accessor of params of estimators to read only.
-- Add parameter validation.
-
-# 0.2.7
-- Fix to support multiclass classifiction into LinearSVC, LogisticRegression, KernelSVC, and FactorizationMachineClassifier
-
-# 0.2.6
-- Add class for Decision Tree classifier.
-- Add class for Random Forest classifier.
-- Fix to use frozen string literal.
-- Refactor marshal dump method on some classes.
-- Introduce Coveralls to confirm test coverage.
-
-# 0.2.5
-- Add classes for Naive Bayes classifier.
-- Fix decision function method on Logistic Regression class.
-- Fix method visibility on RBF kernel approximation class.
-
-# 0.2.4
-- Add class for Factorization Machine classifier.
-- Add classes for evaluation measures.
-- Fix the method for prediction of class probability in Logistic Regression.
-
-# 0.2.3
-- Add class for cross validation.
-- Add specs for base modules.
-- Fix validation of the number of splits when a negative label is given.
-
-# 0.2.2
-- Add data splitter classes for K-fold cross validation.
-
-# 0.2.1
-- Add class for K-nearest neighbors classifier.
-
-# 0.2.0
-- Migrated the linear algebra library to Numo::NArray.
-- Add module for loading and saving libsvm format file.
-
-# 0.1.3
-- Add class for Kernel Support Vector Machine with Pegasos algorithm.
-- Add module for calculating pairwise kernel fuctions and euclidean distances.
-
-# 0.1.2
-- Add the function learning a model with bias term to the PegasosSVC and LogisticRegression classes.
-- Rewrite the document with yard notation.
-
-# 0.1.1
-- Add class for Logistic Regression with SGD optimization.
-- Fix some mistakes on the document.
-
-# 0.1.0
-- Add basic classes.
-- Add an utility module.
-- Add class for RBF kernel approximation.
-- Add class for Support Vector Machine with Pegasos alogrithm.
-- Add class that performs mutlclass classification with one-vs.-rest strategy.
-- Add classes for preprocessing such as min-max scaling, standardization, and L2 normalization.
data/lib/svmkit/base/base_estimator.rb
DELETED
@@ -1,13 +0,0 @@
-# frozen_string_literal: true
-
-module SVMKit
-  # This module consists of basic mix-in classes.
-  module Base
-    # Base module for all estimators in SVMKit.
-    module BaseEstimator
-      # Return parameters about an estimator.
-      # @return [Hash]
-      attr_reader :params
-    end
-  end
-end
data/lib/svmkit/base/classifier.rb
DELETED
@@ -1,34 +0,0 @@
-# frozen_string_literal: true
-
-require 'svmkit/validation'
-require 'svmkit/evaluation_measure/accuracy'
-
-module SVMKit
-  module Base
-    # Module for all classifiers in SVMKit.
-    module Classifier
-      # An abstract method for fitting a model.
-      def fit
-        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-      end
-
-      # An abstract method for predicting labels.
-      def predict
-        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-      end
-
-      # Calculate the mean accuracy of the given testing data.
-      #
-      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
-      # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
-      # @return [Float] Mean accuracy
-      def score(x, y)
-        SVMKit::Validation.check_sample_array(x)
-        SVMKit::Validation.check_label_array(y)
-        SVMKit::Validation.check_sample_label_size(x, y)
-        evaluator = SVMKit::EvaluationMeasure::Accuracy.new
-        evaluator.score(y, predict(x))
-      end
-    end
-  end
-end
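
For context on how the removed base modules were consumed: a 0.7.x estimator mixed in Base::BaseEstimator and Base::Classifier and supplied fit and predict, picking up #score (mean accuracy) from the module for free. The class below is a minimal illustrative sketch against the 0.7.x gem, not code from the repository; ThresholdClassifier and its one-feature threshold rule are made up.

require 'svmkit' # 0.7.x, where SVMKit::Base still exists

# Hypothetical toy classifier: labels a sample 1 when its first feature
# exceeds a fixed threshold. Mixing in Base::Classifier provides #score.
class ThresholdClassifier
  include SVMKit::Base::BaseEstimator
  include SVMKit::Base::Classifier

  def initialize(threshold: 0.0)
    @params = { threshold: threshold } # readable via the mixed-in attr_reader :params
  end

  def fit(_x, _y)
    self # nothing to learn for this fixed rule
  end

  def predict(x)
    # 1 where the first feature exceeds the threshold, 0 otherwise
    Numo::Int32.cast(x[true, 0].to_a.map { |v| v > @params[:threshold] ? 1 : 0 })
  end
end

# Usage: est = ThresholdClassifier.new(threshold: 1.5)
#        est.fit(x, y).score(x, y) # mean accuracy via the deleted Base::Classifier#score
#        (x is a Numo::DFloat sample matrix, y a Numo::Int32 label vector)
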
data/lib/svmkit/base/cluster_analyzer.rb
DELETED
@@ -1,29 +0,0 @@
-# frozen_string_literal: true
-
-require 'svmkit/validation'
-require 'svmkit/evaluation_measure/purity'
-
-module SVMKit
-  module Base
-    # Module for all clustering algorithms in SVMKit.
-    module ClusterAnalyzer
-      # An abstract method for analyzing clusters and predicting cluster indices.
-      def fit_predict
-        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-      end
-
-      # Calculate purity of clustering result.
-      #
-      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
-      # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
-      # @return [Float] Purity
-      def score(x, y)
-        SVMKit::Validation.check_sample_array(x)
-        SVMKit::Validation.check_label_array(y)
-        SVMKit::Validation.check_sample_label_size(x, y)
-        evaluator = SVMKit::EvaluationMeasure::Purity.new
-        evaluator.score(y, fit_predict(x))
-      end
-    end
-  end
-end
data/lib/svmkit/base/evaluator.rb
DELETED
@@ -1,13 +0,0 @@
-# frozen_string_literal: true
-
-module SVMKit
-  module Base
-    # Module for all evaluation measures in SVMKit.
-    module Evaluator
-      # An abstract method for evaluation of model.
-      def score
-        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-      end
-    end
-  end
-end
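
The Base::Evaluator contract above is small: include the module and implement #score. A hypothetical sketch of a custom evaluator follows; ZeroOneLoss is not a class from the gem and is shown only to illustrate the contract.

require 'svmkit' # 0.7.x, where SVMKit::Base::Evaluator still exists

# Hypothetical evaluator: fraction of mismatched labels.
class ZeroOneLoss
  include SVMKit::Base::Evaluator

  # @param y_true [Numo::Int32] ground-truth labels
  # @param y_pred [Numo::Int32] predicted labels
  # @return [Float] fraction of mismatched labels
  def score(y_true, y_pred)
    y_true.ne(y_pred).count.fdiv(y_true.size)
  end
end

# ZeroOneLoss.new.score(Numo::Int32[0, 1, 1], Numo::Int32[0, 1, 0]) # => 0.333...
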
data/lib/svmkit/base/regressor.rb
DELETED
@@ -1,34 +0,0 @@
-# frozen_string_literal: true
-
-require 'svmkit/validation'
-require 'svmkit/evaluation_measure/r2_score'
-
-module SVMKit
-  module Base
-    # Module for all regressors in SVMKit.
-    module Regressor
-      # An abstract method for fitting a model.
-      def fit
-        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-      end
-
-      # An abstract method for predicting labels.
-      def predict
-        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-      end
-
-      # Calculate the coefficient of determination for the given testing data.
-      #
-      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
-      # @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) Target values for testing data.
-      # @return [Float] Coefficient of determination
-      def score(x, y)
-        SVMKit::Validation.check_sample_array(x)
-        SVMKit::Validation.check_tvalue_array(y)
-        SVMKit::Validation.check_sample_tvalue_size(x, y)
-        evaluator = SVMKit::EvaluationMeasure::R2Score.new
-        evaluator.score(y, predict(x))
-      end
-    end
-  end
-end
data/lib/svmkit/base/splitter.rb
DELETED
@@ -1,17 +0,0 @@
-# frozen_string_literal: true
-
-module SVMKit
-  module Base
-    # Module for all validation methods in SVMKit.
-    module Splitter
-      # Return the number of splits.
-      # @return [Integer]
-      attr_reader :n_splits
-
-      # An abstract method for splitting dataset.
-      def split
-        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-      end
-    end
-  end
-end
data/lib/svmkit/base/transformer.rb
DELETED
@@ -1,18 +0,0 @@
-# frozen_string_literal: true
-
-module SVMKit
-  module Base
-    # Module for all transfomers in SVMKit.
-    module Transformer
-      # An abstract method for fitting a model.
-      def fit
-        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-      end
-
-      # An abstract method for fitting a model and transforming given data.
-      def fit_transform
-        raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
-      end
-    end
-  end
-end
data/lib/svmkit/clustering/dbscan.rb
DELETED
@@ -1,127 +0,0 @@
-# frozen_string_literal: true
-
-require 'svmkit/validation'
-require 'svmkit/base/base_estimator'
-require 'svmkit/base/cluster_analyzer'
-require 'svmkit/pairwise_metric'
-
-module SVMKit
-  module Clustering
-    # DBSCAN is a class that implements DBSCAN cluster analysis.
-    # The current implementation uses the Euclidean distance for analyzing the clusters.
-    #
-    # @example
-    #   analyzer = SVMKit::Clustering::DBSCAN.new(eps: 0.5, min_samples: 5)
-    #   cluster_labels = analyzer.fit_predict(samples)
-    #
-    # *Reference*
-    # - M. Ester, H-P. Kriegel, J. Sander, and X. Xu, "A density-based algorithm for discovering clusters in large spatial databases with noise," Proc. KDD' 96, pp. 266--231, 1996.
-    class DBSCAN
-      include Base::BaseEstimator
-      include Base::ClusterAnalyzer
-      include Validation
-
-      # Return the core sample indices.
-      # @return [Numo::Int32] (shape: [n_core_samples])
-      attr_reader :core_sample_ids
-
-      # Return the cluster labels. The negative cluster label indicates that the point is noise.
-      # @return [Numo::Int32] (shape: [n_samples])
-      attr_reader :labels
-
-      # Create a new cluster analyzer with DBSCAN method.
-      #
-      # @param eps [Float] The radius of neighborhood.
-      # @param min_samples [Integer] The number of neighbor samples to be used for the criterion whether a point is a core point.
-      def initialize(eps: 0.5, min_samples: 5)
-        check_params_float(eps: eps)
-        check_params_integer(min_samples: min_samples)
-        @params = {}
-        @params[:eps] = eps
-        @params[:min_samples] = min_samples
-        @core_sample_ids = nil
-        @labels = nil
-      end
-
-      # Analysis clusters with given training data.
-      #
-      # @overload fit(x) -> DBSCAN
-      #
-      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
-      # @return [DBSCAN] The learned cluster analyzer itself.
-      def fit(x, _y = nil)
-        check_sample_array(x)
-        partial_fit(x)
-        self
-      end
-
-      # Analysis clusters and assign samples to clusters.
-      #
-      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for cluster analysis.
-      # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
-      def fit_predict(x)
-        check_sample_array(x)
-        partial_fit(x)
-        labels
-      end
-
-      # Dump marshal data.
-      # @return [Hash] The marshal data.
-      def marshal_dump
-        { params: @params,
-          core_sample_ids: @core_sample_ids,
-          labels: @labels }
-      end
-
-      # Load marshal data.
-      # @return [nil]
-      def marshal_load(obj)
-        @params = obj[:params]
-        @core_sample_ids = obj[:core_sample_ids]
-        @labels = obj[:labels]
-        nil
-      end
-
-      private
-
-      def partial_fit(x)
-        cluster_id = 0
-        n_samples = x.shape[0]
-        @core_sample_ids = []
-        @labels = Numo::Int32.zeros(n_samples) - 2
-        n_samples.times do |q|
-          next if @labels[q] >= -1
-          cluster_id += 1 if expand_cluster(x, q, cluster_id)
-        end
-        @core_sample_ids = Numo::Int32[*@core_sample_ids.flatten]
-        nil
-      end
-
-      def expand_cluster(x, query_id, cluster_id)
-        target_ids = region_query(x[query_id, true], x)
-        if target_ids.size < @params[:min_samples]
-          @labels[query_id] = -1
-          false
-        else
-          @labels[target_ids] = cluster_id
-          @core_sample_ids.push(target_ids.dup)
-          target_ids.delete(query_id)
-          while (m = target_ids.shift)
-            neighbor_ids = region_query(x[m, true], x)
-            next if neighbor_ids.size < @params[:min_samples]
-            neighbor_ids.each do |n|
-              target_ids.push(n) if @labels[n] < -1
-              @labels[n] = cluster_id if @labels[n] <= -1
-            end
-          end
-          true
-        end
-      end
-
-      def region_query(query, targets)
-        distance_arr = PairwiseMetric.euclidean_distance(query.expand_dims(0), targets)[0, true]
-        distance_arr.lt(@params[:eps]).where.to_a
-      end
-    end
-  end
-end
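
To round out the DBSCAN listing above, a small usage sketch against the 0.7.x gem: the toy data and parameter values are made up, and only fit_predict, labels, and core_sample_ids come from the deleted class.

require 'svmkit' # 0.7.x, where SVMKit::Clustering::DBSCAN still exists

# Two tight toy clusters plus one isolated point (hypothetical data).
samples = Numo::DFloat[[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
                       [5.0, 5.0], [5.1, 5.0], [5.0, 5.1],
                       [9.0, 0.0]]

analyzer = SVMKit::Clustering::DBSCAN.new(eps: 0.5, min_samples: 2)
cluster_labels = analyzer.fit_predict(samples)
# cluster_labels holds one cluster id per sample; a negative id marks noise,
# so the isolated point at [9.0, 0.0] comes back labeled -1.
# analyzer.core_sample_ids lists the indices of the core points found.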