svmkit 0.7.2 → 0.7.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/HISTORY.md +5 -1
- data/lib/svmkit.rb +1 -0
- data/lib/svmkit/ensemble/ada_boost_classifier.rb +1 -1
- data/lib/svmkit/ensemble/ada_boost_regressor.rb +1 -1
- data/lib/svmkit/ensemble/random_forest_classifier.rb +1 -1
- data/lib/svmkit/ensemble/random_forest_regressor.rb +1 -1
- data/lib/svmkit/model_selection/grid_search_cv.rb +247 -0
- data/lib/svmkit/pipeline/pipeline.rb +11 -1
- data/lib/svmkit/utils.rb +1 -1
- data/lib/svmkit/version.rb +1 -1
- metadata +3 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA1:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: ca1916101dd6c77c5be1a157c2bfa8dafe9c543e
|
4
|
+
data.tar.gz: c4751b21fd3d0667bb7d378f8b524fc2f70069d9
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: db878c8b28e88649fed654b292358c11ec91369cd52ec03e01d06d053fbeb90ebff87248be628c8ab081fd820c1460bb3783448242531ef5f30b4b06337af87c
|
7
|
+
data.tar.gz: bfbfc580897a4a3161afa865cd14ac15a0a322cf80ae728f7b70c8d45cb41ff4566b1f1e748e6c1eb9412a4c401b737a569bf6b9af12a113a4ca4a6d75f8b9b8
|
data/HISTORY.md
CHANGED
@@ -1,6 +1,10 @@
|
|
1
|
+
# 0.7.3
|
2
|
+
- Add class for grid search performing hyperparameter optimization.
|
3
|
+
- Add argument validations to Pipeline.
|
4
|
+
|
1
5
|
# 0.7.2
|
2
6
|
- Add class for Pipeline that constructs chain of transformers and estimators.
|
3
|
-
- Fix some typos on document.
|
7
|
+
- Fix some typos on document ([#1](https://github.com/yoshoku/SVMKit/pull/1)).
|
4
8
|
|
5
9
|
# 0.7.1
|
6
10
|
- Fix to use CSV class in parsing libsvm format file.
|
data/lib/svmkit.rb
CHANGED
@@ -55,6 +55,7 @@ require 'svmkit/preprocessing/one_hot_encoder'
|
|
55
55
|
require 'svmkit/model_selection/k_fold'
|
56
56
|
require 'svmkit/model_selection/stratified_k_fold'
|
57
57
|
require 'svmkit/model_selection/cross_validation'
|
58
|
+
require 'svmkit/model_selection/grid_search_cv'
|
58
59
|
require 'svmkit/evaluation_measure/accuracy'
|
59
60
|
require 'svmkit/evaluation_measure/precision'
|
60
61
|
require 'svmkit/evaluation_measure/recall'
|
@@ -109,7 +109,7 @@ module SVMKit
|
|
109
109
|
tree = Tree::DecisionTreeClassifier.new(
|
110
110
|
criterion: @params[:criterion], max_depth: @params[:max_depth],
|
111
111
|
max_leaf_nodes: @params[:max_leaf_nodes], min_samples_leaf: @params[:min_samples_leaf],
|
112
|
-
max_features: @params[:max_features], random_seed: @rng.rand(SVMKit::Values
|
112
|
+
max_features: @params[:max_features], random_seed: @rng.rand(SVMKit::Values.int_max)
|
113
113
|
)
|
114
114
|
tree.fit(x[ids, true], y[ids])
|
115
115
|
# Calculate estimator error.
|
@@ -111,7 +111,7 @@ module SVMKit
|
|
111
111
|
tree = Tree::DecisionTreeRegressor.new(
|
112
112
|
criterion: @params[:criterion], max_depth: @params[:max_depth],
|
113
113
|
max_leaf_nodes: @params[:max_leaf_nodes], min_samples_leaf: @params[:min_samples_leaf],
|
114
|
-
max_features: @params[:max_features], random_seed: @rng.rand(SVMKit::Values
|
114
|
+
max_features: @params[:max_features], random_seed: @rng.rand(SVMKit::Values.int_max)
|
115
115
|
)
|
116
116
|
tree.fit(x[ids, true], y[ids])
|
117
117
|
p = tree.predict(x)
|
@@ -97,7 +97,7 @@ module SVMKit
|
|
97
97
|
tree = Tree::DecisionTreeClassifier.new(
|
98
98
|
criterion: @params[:criterion], max_depth: @params[:max_depth],
|
99
99
|
max_leaf_nodes: @params[:max_leaf_nodes], min_samples_leaf: @params[:min_samples_leaf],
|
100
|
-
max_features: @params[:max_features], random_seed: @rng.rand(SVMKit::Values
|
100
|
+
max_features: @params[:max_features], random_seed: @rng.rand(SVMKit::Values.int_max)
|
101
101
|
)
|
102
102
|
bootstrap_ids = Array.new(n_samples) { @rng.rand(0...n_samples) }
|
103
103
|
tree.fit(x[bootstrap_ids, true], y[bootstrap_ids])
|
@@ -91,7 +91,7 @@ module SVMKit
|
|
91
91
|
tree = Tree::DecisionTreeRegressor.new(
|
92
92
|
criterion: @params[:criterion], max_depth: @params[:max_depth],
|
93
93
|
max_leaf_nodes: @params[:max_leaf_nodes], min_samples_leaf: @params[:min_samples_leaf],
|
94
|
-
max_features: @params[:max_features], random_seed: @rng.rand(SVMKit::Values
|
94
|
+
max_features: @params[:max_features], random_seed: @rng.rand(SVMKit::Values.int_max)
|
95
95
|
)
|
96
96
|
bootstrap_ids = Array.new(n_samples) { @rng.rand(0...n_samples) }
|
97
97
|
tree.fit(x[bootstrap_ids, true], single_target ? y[bootstrap_ids] : y[bootstrap_ids, true])
|
@@ -0,0 +1,247 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'svmkit/validation'
|
4
|
+
require 'svmkit/base/base_estimator'
|
5
|
+
require 'svmkit/base/evaluator'
|
6
|
+
require 'svmkit/base/splitter'
|
7
|
+
require 'svmkit/pipeline/pipeline'
|
8
|
+
|
9
|
+
module SVMKit
  module ModelSelection
    # GridSearchCV is a class that performs hyperparameter optimization with grid search method.
    # Every combination of the given parameter values is evaluated with cross validation, and
    # the estimator configured with the best-scoring combination is then fitted on all given data.
    #
    # @example
    #   rfc = SVMKit::Ensemble::RandomForestClassifier.new(random_seed: 1)
    #   pg = { n_estimators: [5, 10], max_depth: [3, 5], max_leaf_nodes: [15, 31] }
    #   kf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 5)
    #   gs = SVMKit::ModelSelection::GridSearchCV.new(estimator: rfc, param_grid: pg, splitter: kf)
    #   gs.fit(samples, labels)
    #   p gs.cv_results
    #   p gs.best_params
    #
    # @example
    #   rbf = SVMKit::KernelApproximation::RBF.new(random_seed: 1)
    #   svc = SVMKit::LinearModel::SVC.new(random_seed: 1)
    #   pipe = SVMKit::Pipeline::Pipeline.new(steps: { rbf: rbf, svc: svc })
    #   pg = { rbf__gamma: [32.0, 1.0], rbf__n_components: [4, 128], svc__reg_param: [16.0, 0.1] }
    #   kf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 5)
    #   gs = SVMKit::ModelSelection::GridSearchCV.new(estimator: pipe, param_grid: pg, splitter: kf)
    #   gs.fit(samples, labels)
    #   p gs.cv_results
    #   p gs.best_params
    #
    class GridSearchCV
      include Base::BaseEstimator
      include Validation

      # Return the result of cross validation for each parameter.
      # The hash has the keys :mean_test_score, :std_test_score, :mean_train_score,
      # :std_train_score, :mean_fit_time, :std_fit_time and :params; each value is an array
      # holding one entry per evaluated parameter set, in evaluation order.
      # @return [Hash]
      attr_reader :cv_results

      # Return the best mean test score obtained by cross validation
      # (the maximum if greater_is_better is true, otherwise the minimum).
      # @return [Float]
      attr_reader :best_score

      # Return the best parameter set.
      # @return [Hash]
      attr_reader :best_params

      # Return the index of the best parameter set in the arrays of cv_results.
      # @return [Integer]
      attr_reader :best_index

      # Return the estimator learned with the best parameter.
      # @return [Estimator]
      attr_reader :best_estimator

      # Create a new grid search method.
      #
      # @param estimator [Classifier/Regressor] The estimator to be searched for optimal parameters with grid search method.
      # @param param_grid [Hash/Array<Hash>] The parameter sets are represented with a hash (or an array of hashes) that
      #   consists of parameter names as keys and arrays of parameter values as values.
      #   For a Pipeline estimator, parameter names are written as 'step name' + '__' + 'parameter name'.
      # @param splitter [Splitter] The splitter that divides dataset to training and testing dataset on cross validation.
      # @param evaluator [Evaluator] The evaluator that calculates score of estimator results on cross validation.
      #   If nil is given, the score method of estimator is used for evaluation.
      # @param greater_is_better [Boolean] The flag that indicates whether the estimator is better as
      #   evaluation score is larger.
      # @raise [TypeError] If param_grid is not a Hash or an Array of Hashes whose values are Arrays.
      # @raise [ArgumentError] If param_grid is an empty Array or contains an empty Array of values.
      def initialize(estimator: nil, param_grid: nil, splitter: nil, evaluator: nil, greater_is_better: true)
        check_params_type(SVMKit::Base::BaseEstimator, estimator: estimator)
        check_params_type(SVMKit::Base::Splitter, splitter: splitter)
        check_params_type_or_nil(SVMKit::Base::Evaluator, evaluator: evaluator)
        check_params_boolean(greater_is_better: greater_is_better)
        @params = {}
        @params[:param_grid] = valid_param_grid(param_grid)
        # Deep copies so later mutation by the caller does not affect the search.
        @params[:estimator] = Marshal.load(Marshal.dump(estimator))
        @params[:splitter] = Marshal.load(Marshal.dump(splitter))
        @params[:evaluator] = Marshal.load(Marshal.dump(evaluator))
        @params[:greater_is_better] = greater_is_better
        @cv_results = nil
        @best_score = nil
        @best_params = nil
        @best_index = nil
        @best_estimator = nil
      end

      # Fit the model with given training data and all sets of parameters.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::NArray] (shape: [n_samples, n_outputs]) The target values or labels to be used for fitting the model.
      # @return [GridSearchCV] The learned estimator with grid search.
      # @raise [ArgumentError] If a Pipeline parameter name is not of the form 'step__param' or names an unknown step.
      def fit(x, y)
        check_sample_array(x)

        init_attrs

        param_combinations.each do |prms|
          report = perform_cross_validation(x, y, prms)
          store_cv_result(prms, report)
        end

        find_best_params

        # Refit a fresh copy of the estimator on all given data with the winning parameters.
        @best_estimator = configurated_estimator(@best_params)
        @best_estimator.fit(x, y)
        self
      end

      # Call the decision_function method of learned estimator with the best parameter.
      # Must be called after #fit (best_estimator is nil until then).
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
      # @return [Numo::DFloat] (shape: [n_samples]) Confidence score per sample.
      def decision_function(x)
        check_sample_array(x)
        @best_estimator.decision_function(x)
      end

      # Call the predict method of learned estimator with the best parameter.
      # Must be called after #fit (best_estimator is nil until then).
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to obtain prediction result.
      # @return [Numo::NArray] Predicted results.
      def predict(x)
        check_sample_array(x)
        @best_estimator.predict(x)
      end

      # Call the predict_log_proba method of learned estimator with the best parameter.
      # Must be called after #fit (best_estimator is nil until then).
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the log-probabilities.
      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted log-probability of each class per sample.
      def predict_log_proba(x)
        check_sample_array(x)
        @best_estimator.predict_log_proba(x)
      end

      # Call the predict_proba method of learned estimator with the best parameter.
      # Must be called after #fit (best_estimator is nil until then).
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
      def predict_proba(x)
        check_sample_array(x)
        @best_estimator.predict_proba(x)
      end

      # Call the score method of learned estimator with the best parameter.
      # Must be called after #fit (best_estimator is nil until then).
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
      # @param y [Numo::NArray] (shape: [n_samples, n_outputs]) True target values or labels for testing data.
      # @return [Float] The score of estimator.
      def score(x, y)
        check_sample_array(x)
        @best_estimator.score(x, y)
      end

      # Dump marshal data.
      # @return [Hash] The marshal data about GridSearchCV.
      def marshal_dump
        { params: @params,
          cv_results: @cv_results,
          best_score: @best_score,
          best_params: @best_params,
          best_index: @best_index,
          best_estimator: @best_estimator }
      end

      # Load marshal data.
      # @return [nil]
      def marshal_load(obj)
        @params = obj[:params]
        @cv_results = obj[:cv_results]
        @best_score = obj[:best_score]
        @best_params = obj[:best_params]
        @best_index = obj[:best_index]
        @best_estimator = obj[:best_estimator]
        nil
      end

      private

      # Normalize param_grid into an Array of Hashes and validate its shape.
      # Empty grids and empty value lists are rejected here, so that fit does not fail
      # later without evaluating any parameter set.
      def valid_param_grid(grid)
        raise TypeError, 'Expect class of param_grid to be Hash or Array' unless grid.is_a?(Hash) || grid.is_a?(Array)
        grid = [grid] if grid.is_a?(Hash)
        raise ArgumentError, 'Expect param_grid not to be empty' if grid.empty?
        grid.each do |h|
          raise TypeError, 'Expect class of elements in param_grid to be Hash' unless h.is_a?(Hash)
          raise TypeError, 'Expect class of parameter values in param_grid to be Array' unless h.values.all? { |v| v.is_a?(Array) }
          raise ArgumentError, 'Expect parameter values in param_grid not to be empty' if h.values.any?(&:empty?)
        end
        grid
      end

      # Expand every hash of the grid into the Cartesian product of its values, as one flat list.
      # Keys are visited in sorted order, so the enumeration order is deterministic; a hash
      # without keys yields a single empty parameter set (the estimator's own settings).
      def param_combinations
        @params[:param_grid].flat_map do |prm|
          keys = prm.keys.sort
          value_lists = keys.map { |k| prm[k] }
          combos = value_lists.empty? ? [[]] : value_lists.first.product(*value_lists.drop(1))
          combos.map { |vals| keys.zip(vals).to_h }
        end
      end

      # Run cross validation for one parameter set; train scores are requested for cv_results.
      def perform_cross_validation(x, y, prms)
        est = configurated_estimator(prms)
        cv = CrossValidation.new(estimator: est, splitter: @params[:splitter],
                                 evaluator: @params[:evaluator], return_train_score: true)
        cv.perform(x, y)
      end

      # Return a deep copy of the base estimator with the given parameters assigned.
      # For a Pipeline, keys are split at the first '__' into step name and parameter name.
      def configurated_estimator(prms)
        estimator = Marshal.load(Marshal.dump(@params[:estimator]))
        if @params[:estimator].is_a?(SVMKit::Pipeline::Pipeline)
          prms.each do |k, v|
            est_name, prm_name = k.to_s.split('__', 2)
            if prm_name.nil? || prm_name.empty?
              raise ArgumentError, "Expect parameter name for Pipeline to be formatted as 'step__param': #{k}"
            end
            step = estimator.steps[est_name.to_sym]
            raise ArgumentError, "Expect Pipeline to have an estimator named '#{est_name}': #{k}" if step.nil?
            step.params[prm_name.to_sym] = v
          end
        else
          prms.each { |k, v| estimator.params[k] = v }
        end
        estimator
      end

      # Reset the search results before a new fit.
      def init_attrs
        @cv_results = %i[mean_test_score std_test_score
                         mean_train_score std_train_score
                         mean_fit_time std_fit_time params].map { |v| [v, []] }.to_h
        @best_score = nil
        @best_params = nil
        @best_index = nil
        @best_estimator = nil
      end

      # Append the fold statistics of one cross-validation report to cv_results.
      def store_cv_result(prms, report)
        test_scores = Numo::DFloat[*report[:test_score]]
        train_scores = Numo::DFloat[*report[:train_score]]
        fit_times = Numo::DFloat[*report[:fit_time]]
        @cv_results[:mean_test_score].push(test_scores.mean)
        @cv_results[:std_test_score].push(test_scores.stddev)
        @cv_results[:mean_train_score].push(train_scores.mean)
        @cv_results[:std_train_score].push(train_scores.stddev)
        @cv_results[:mean_fit_time].push(fit_times.mean)
        @cv_results[:std_fit_time].push(fit_times.stddev)
        @cv_results[:params].push(prms)
      end

      # Select the parameter set with the best mean test score; on ties the first one wins.
      def find_best_params
        @best_score = @params[:greater_is_better] ? @cv_results[:mean_test_score].max : @cv_results[:mean_test_score].min
        @best_index = @cv_results[:mean_test_score].index(@best_score)
        @best_params = @cv_results[:params][@best_index]
      end
    end
  end
end
|
@@ -40,6 +40,7 @@ module SVMKit
|
|
40
40
|
# @param y [Numo::NArray] (shape: [n_samples, n_outputs]) The target values or labels to be used for fitting the model.
|
41
41
|
# @return [Pipeline] The learned pipeline itself.
|
42
42
|
def fit(x, y)
|
43
|
+
check_sample_array(x)
|
43
44
|
trans_x = apply_transforms(x, y, fit: true)
|
44
45
|
last_estimator.fit(trans_x, y) unless last_estimator.nil?
|
45
46
|
self
|
@@ -51,6 +52,7 @@ module SVMKit
|
|
51
52
|
# @param y [Numo::NArray] (shape: [n_samples, n_outputs], default: nil) The target values or labels to be used for fitting the model.
|
52
53
|
# @return [Numo::NArray] The predicted results by last estimator.
|
53
54
|
def fit_predict(x, y = nil)
  check_sample_array(x)
  # Fit the intermediate steps, then let the final step fit and label the transformed samples.
  last_estimator.fit_predict(apply_transforms(x, y, fit: true))
end
|
@@ -61,6 +63,7 @@ module SVMKit
|
|
61
63
|
# @param y [Numo::NArray] (shape: [n_samples, n_outputs], default: nil) The target values or labels to be used for fitting the model.
|
62
64
|
# @return [Numo::NArray] The predicted results by last estimator.
|
63
65
|
def fit_transform(x, y = nil)
  check_sample_array(x)
  trans_x = apply_transforms(x, y, fit: true)
  # Mirror #transform and #fit, which tolerate a nil last step: in that case the
  # output of the preceding transformers is returned unchanged.
  last_estimator.nil? ? trans_x : last_estimator.fit_transform(trans_x, y)
end
|
@@ -70,6 +73,7 @@ module SVMKit
|
|
70
73
|
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
|
71
74
|
# @return [Numo::DFloat] (shape: [n_samples]) Confidence score per sample.
|
72
75
|
def decision_function(x)
  check_sample_array(x)
  # Scores come from the final step, computed on the transformed samples.
  last_estimator.decision_function(apply_transforms(x))
end
|
@@ -79,6 +83,7 @@ module SVMKit
|
|
79
83
|
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to obtain prediction result.
|
80
84
|
# @return [Numo::NArray] The predicted results by last estimator.
|
81
85
|
def predict(x)
  check_sample_array(x)
  # Predictions are made by the final step on the transformed samples.
  last_estimator.predict(apply_transforms(x))
end
|
@@ -88,6 +93,7 @@ module SVMKit
|
|
88
93
|
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the log-probailities.
|
89
94
|
# @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted log-probability of each class per sample.
|
90
95
|
def predict_log_proba(x)
  check_sample_array(x)
  # Log-probabilities come from the final step, computed on the transformed samples.
  last_estimator.predict_log_proba(apply_transforms(x))
end
|
@@ -97,6 +103,7 @@ module SVMKit
|
|
97
103
|
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probailities.
|
98
104
|
# @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
|
99
105
|
def predict_proba(x)
  check_sample_array(x)
  # Probabilities come from the final step, computed on the transformed samples.
  last_estimator.predict_proba(apply_transforms(x))
end
|
@@ -106,6 +113,7 @@ module SVMKit
|
|
106
113
|
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be transformed.
|
107
114
|
# @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed samples.
|
108
115
|
def transform(x)
  check_sample_array(x)
  transformed = apply_transforms(x)
  # Without a final step, the output of the intermediate transformers is the result.
  return transformed if last_estimator.nil?

  last_estimator.transform(transformed)
end
|
@@ -115,8 +123,9 @@ module SVMKit
|
|
115
123
|
# @param z [Numo::DFloat] (shape: [n_samples, n_components]) The transformed samples to be restored into original space.
|
116
124
|
# @return [Numo::DFloat] (shape: [n_samples, n_featuress]) The restored samples.
|
117
125
|
def inverse_transform(z)
|
126
|
+
check_sample_array(z)
|
118
127
|
itrans_z = z
|
119
|
-
@steps.keys.
|
128
|
+
@steps.keys.reverse_each do |name|
|
120
129
|
transformer = @steps[name]
|
121
130
|
next if transformer.nil?
|
122
131
|
itrans_z = transformer.inverse_transform(itrans_z)
|
@@ -130,6 +139,7 @@ module SVMKit
|
|
130
139
|
# @param y [Numo::NArray] (shape: [n_samples, n_outputs]) True target values or labels for testing data.
|
131
140
|
# @return [Float] The score of last estimator
|
132
141
|
def score(x, y)
  check_sample_array(x)
  # The final step scores the transformed samples against the untouched targets.
  last_estimator.score(apply_transforms(x), y)
end
|
data/lib/svmkit/utils.rb
CHANGED
data/lib/svmkit/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: svmkit
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.7.2
|
4
|
+
version: 0.7.3
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-
|
11
|
+
date: 2019-02-05 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -146,6 +146,7 @@ files:
|
|
146
146
|
- lib/svmkit/linear_model/svc.rb
|
147
147
|
- lib/svmkit/linear_model/svr.rb
|
148
148
|
- lib/svmkit/model_selection/cross_validation.rb
|
149
|
+
- lib/svmkit/model_selection/grid_search_cv.rb
|
149
150
|
- lib/svmkit/model_selection/k_fold.rb
|
150
151
|
- lib/svmkit/model_selection/stratified_k_fold.rb
|
151
152
|
- lib/svmkit/multiclass/one_vs_rest_classifier.rb
|