svmkit 0.2.2 → 0.2.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/.travis.yml +4 -3
- data/HISTORY.md +6 -1
- data/README.md +21 -0
- data/lib/svmkit.rb +1 -0
- data/lib/svmkit/base/splitter.rb +1 -1
- data/lib/svmkit/kernel_machine/kernel_svc.rb +2 -2
- data/lib/svmkit/linear_model/logistic_regression.rb +1 -1
- data/lib/svmkit/linear_model/svc.rb +1 -1
- data/lib/svmkit/model_selection/cross_validation.rb +82 -0
- data/lib/svmkit/model_selection/k_fold.rb +1 -1
- data/lib/svmkit/model_selection/stratified_k_fold.rb +6 -2
- data/lib/svmkit/version.rb +1 -1
- metadata +3 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA1:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 6271a50754a13199f7c3c12c6f1b9e2a0a2075d5
|
4
|
+
data.tar.gz: ecdc84a2987f22d49ad1b8435397862771f96f37
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 9974eb62cd19ebca32ca92cafbcbf2e34a978d41c0aa9bb0765c001d56963189aa43f71806fa5a2492a4b3e50bcd939a018aa784198f5aef08e2159682682dd2
|
7
|
+
data.tar.gz: eee0f089449a71f79576165aa083c41fe1d9812fcc251c31c2a5c7cf06e6dfd4293dddd6685fa6285c8d89a3744bb3a72bd9db6f36b7cbc586b7a71800c1fb78
|
data/.travis.yml
CHANGED
data/HISTORY.md
CHANGED
@@ -1,5 +1,10 @@
|
|
1
|
+
# 0.2.3
|
2
|
+
- Added class for cross validation.
|
3
|
+
- Added specs for base modules.
|
4
|
+
- Fixed validation of the number of splits when a negative label is given.
|
5
|
+
|
1
6
|
# 0.2.2
|
2
|
-
- Added classes for K-fold cross validation.
|
7
|
+
- Added data splitter classes for K-fold cross validation.
|
3
8
|
|
4
9
|
# 0.2.1
|
5
10
|
- Added class for K-nearest neighbors classifier.
|
data/README.md
CHANGED
@@ -66,6 +66,27 @@ transformed = transformer.transform(normalized)
|
|
66
66
|
puts(sprintf("Accuracy: %.1f%%", 100.0 * classifier.score(transformed, labels)))
|
67
67
|
```
|
68
68
|
|
69
|
+
5-fold cross-validation:
|
70
|
+
|
71
|
+
```ruby
|
72
|
+
require 'svmkit'
|
73
|
+
|
74
|
+
samples, labels = SVMKit::Dataset.load_libsvm_file('pendigits')
|
75
|
+
|
76
|
+
kernel_svc =
|
77
|
+
SVMKit::KernelMachine::KernelSVC.new(reg_param: 1.0, max_iter: 1000, random_seed: 1)
|
78
|
+
ovr_kernel_svc = SVMKit::Multiclass::OneVsRestClassifier.new(estimator: kernel_svc)
|
79
|
+
|
80
|
+
kf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 5, shuffle: true, random_seed: 1)
|
81
|
+
cv = SVMKit::ModelSelection::CrossValidation.new(estimator: ovr_kernel_svc, splitter: kf)
|
82
|
+
|
83
|
+
kernel_mat = SVMKit::PairwiseMetric::rbf_kernel(samples, nil, 0.005)
|
84
|
+
report = cv.perform(kernel_mat, labels)
|
85
|
+
|
86
|
+
mean_accuracy = report[:test_score].inject(:+) / kf.n_splits
|
87
|
+
puts(sprintf("Mean Accuracy: %.1f%%", 100.0 * mean_accuracy))
|
88
|
+
```
|
89
|
+
|
69
90
|
## Development
|
70
91
|
|
71
92
|
After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake spec` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
|
data/lib/svmkit.rb
CHANGED
data/lib/svmkit/base/splitter.rb
CHANGED
@@ -9,7 +9,7 @@ module SVMKit
|
|
9
9
|
|
10
10
|
# An abstract method for splitting dataset.
|
11
11
|
def split
|
12
|
-
raise
|
12
|
+
raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
|
13
13
|
end
|
14
14
|
end
|
15
15
|
end
|
@@ -68,7 +68,7 @@ module SVMKit
|
|
68
68
|
weight_vec[target_id] += 1.0 if func < 1.0
|
69
69
|
end
|
70
70
|
# Store the learned model.
|
71
|
-
@weight_vec = weight_vec * Numo::DFloat
|
71
|
+
@weight_vec = weight_vec * Numo::DFloat[*bin_y]
|
72
72
|
self
|
73
73
|
end
|
74
74
|
|
@@ -78,7 +78,7 @@ module SVMKit
|
|
78
78
|
# The kernel matrix between testing samples and training samples to compute the scores.
|
79
79
|
# @return [Numo::DFloat] (shape: [n_testing_samples]) Confidence score per sample.
|
80
80
|
def decision_function(x)
|
81
|
-
|
81
|
+
x.dot(@weight_vec)
|
82
82
|
end
|
83
83
|
|
84
84
|
# Predict class labels for samples.
|
@@ -74,7 +74,7 @@ module SVMKit
|
|
74
74
|
end
|
75
75
|
# Initialize some variables.
|
76
76
|
n_samples, n_features = samples.shape
|
77
|
-
rand_ids = [*0
|
77
|
+
rand_ids = [*0...n_samples].shuffle(random: @rng)
|
78
78
|
weight_vec = Numo::DFloat.zeros(n_features)
|
79
79
|
# Start optimization.
|
80
80
|
@params[:max_iter].times do |t|
|
@@ -70,7 +70,7 @@ module SVMKit
|
|
70
70
|
end
|
71
71
|
# Initialize some variables.
|
72
72
|
n_samples, n_features = samples.shape
|
73
|
-
rand_ids = [*0
|
73
|
+
rand_ids = [*0...n_samples].shuffle(random: @rng)
|
74
74
|
weight_vec = Numo::DFloat.zeros(n_features)
|
75
75
|
# Start optimization.
|
76
76
|
@params[:max_iter].times do |t|
|
@@ -0,0 +1,82 @@
|
|
1
|
+
require 'svmkit/base/splitter'
|
2
|
+
|
3
|
+
module SVMKit
|
4
|
+
# This module consists of the classes for model validation techniques.
|
5
|
+
module ModelSelection
|
6
|
+
# CrossValidation is a class that evaluates a given classifier with cross-validation method.
|
7
|
+
#
|
8
|
+
# @example
|
9
|
+
# svc = SVMKit::LinearModel::SVC.new
|
10
|
+
# kf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 5)
|
11
|
+
# cv = SVMKit::ModelSelection::CrossValidation.new(estimator: svc, splitter: kf)
|
12
|
+
# report = cv.perform(samples, labels)
|
13
|
+
# mean_test_score = report[:test_score].inject(:+) / kf.n_splits
|
14
|
+
#
|
15
|
+
class CrossValidation
|
16
|
+
# Return the classifier of which performance is evaluated.
|
17
|
+
# @return [Classifier]
|
18
|
+
attr_reader :estimator
|
19
|
+
|
20
|
+
# Return the splitter that divides dataset.
|
21
|
+
# @return [Splitter]
|
22
|
+
attr_reader :splitter
|
23
|
+
|
24
|
+
# Return the flag indicating whether to calculate the score of training dataset.
|
25
|
+
# @return [Boolean]
|
26
|
+
attr_reader :return_train_score
|
27
|
+
|
28
|
+
# Create a new evaluator with cross-validation method.
|
29
|
+
#
|
30
|
+
# @param estimator [Classifier] The classifier of which performance is evaluated.
|
31
|
+
# @param splitter [Splitter] The splitter that divides dataset to training and testing dataset.
|
32
|
+
# @param return_train_score [Boolean] The flag indicating whether to calculate the score of training dataset.
|
33
|
+
def initialize(estimator: nil, splitter: nil, return_train_score: false)
|
34
|
+
@estimator = estimator
|
35
|
+
@splitter = splitter
|
36
|
+
@return_train_score = return_train_score
|
37
|
+
end
|
38
|
+
|
39
|
+
# Perform the evaluation of given classifier with cross-validation method.
|
40
|
+
#
|
41
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features])
|
42
|
+
# The dataset to be used to evaluate the classifier.
|
43
|
+
# @param y [Numo::Int32] (shape: [n_samples])
|
44
|
+
# The labels to be used to evaluate the classifier.
|
45
|
+
# @return [Hash] The report summarizing the results of cross-validation.
|
46
|
+
# * :fit_time (Array<Float>) The calculation times of fitting the estimator for each split.
|
47
|
+
# * :test_score (Array<Float>) The scores of testing dataset for each split.
|
48
|
+
# * :train_score (Array<Float>) The scores of training dataset for each split. This option is nil if
|
49
|
+
# the return_train_score is false.
|
50
|
+
def perform(x, y)
|
51
|
+
# Initialize the report of cross validation.
|
52
|
+
report = {test_score: [], train_score: nil, fit_time: []}
|
53
|
+
report[:train_score] = [] if @return_train_score
|
54
|
+
# Evaluate the estimator on each split.
|
55
|
+
@splitter.split(x, y).each do |train_ids, test_ids|
|
56
|
+
# Split dataset into training and testing dataset.
|
57
|
+
feature_ids = !kernel_machine? || train_ids
|
58
|
+
train_x = x[train_ids, feature_ids]
|
59
|
+
train_y = y[train_ids]
|
60
|
+
test_x = x[test_ids, feature_ids]
|
61
|
+
test_y = y[test_ids]
|
62
|
+
# Fit the estimator.
|
63
|
+
start_time = Time.now.to_i
|
64
|
+
@estimator.fit(train_x, train_y)
|
65
|
+
# Calculate scores and prepare the report.
|
66
|
+
report[:fit_time].push(Time.now.to_i - start_time)
|
67
|
+
report[:test_score].push(@estimator.score(test_x, test_y))
|
68
|
+
report[:train_score].push(@estimator.score(train_x, train_y)) if @return_train_score
|
69
|
+
end
|
70
|
+
report
|
71
|
+
end
|
72
|
+
|
73
|
+
private
|
74
|
+
|
75
|
+
def kernel_machine?
|
76
|
+
class_name = @estimator.class.to_s
|
77
|
+
class_name = @estimator.params[:estimator].class.to_s if class_name.include?('Multiclass')
|
78
|
+
class_name.include?('KernelMachine')
|
79
|
+
end
|
80
|
+
end
|
81
|
+
end
|
82
|
+
end
|
@@ -16,7 +16,7 @@ module SVMKit
|
|
16
16
|
class StratifiedKFold
|
17
17
|
include Base::Splitter
|
18
18
|
|
19
|
-
# Return the
|
19
|
+
# Return the flag indicating whether to shuffle the dataset.
|
20
20
|
# @return [Boolean]
|
21
21
|
attr_reader :shuffle
|
22
22
|
|
@@ -47,7 +47,7 @@ module SVMKit
|
|
47
47
|
# @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
|
48
48
|
def split(x, y) # rubocop:disable Lint/UnusedMethodArgument
|
49
49
|
# Check the number of samples in each class.
|
50
|
-
unless
|
50
|
+
unless valid_n_splits?(y)
|
51
51
|
raise ArgumentError,
|
52
52
|
'The value of n_splits must be not less than 2 and not more than the number of samples in each class.'
|
53
53
|
end
|
@@ -59,6 +59,10 @@ module SVMKit
|
|
59
59
|
|
60
60
|
private
|
61
61
|
|
62
|
+
def valid_n_splits?(y)
|
63
|
+
y.to_a.uniq.map { |label| y.eq(label).where.size }.all? { |n_samples| @n_splits.between?(2, n_samples) }
|
64
|
+
end
|
65
|
+
|
62
66
|
def fold_sets(y, label)
|
63
67
|
sample_ids = y.eq(label).where.to_a
|
64
68
|
sample_ids.shuffle!(random: @rng) if @shuffle
|
data/lib/svmkit/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: svmkit
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.2
|
4
|
+
version: 0.2.3
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2018-01-
|
11
|
+
date: 2018-01-15 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -112,6 +112,7 @@ files:
|
|
112
112
|
- lib/svmkit/kernel_machine/kernel_svc.rb
|
113
113
|
- lib/svmkit/linear_model/logistic_regression.rb
|
114
114
|
- lib/svmkit/linear_model/svc.rb
|
115
|
+
- lib/svmkit/model_selection/cross_validation.rb
|
115
116
|
- lib/svmkit/model_selection/k_fold.rb
|
116
117
|
- lib/svmkit/model_selection/stratified_k_fold.rb
|
117
118
|
- lib/svmkit/multiclass/one_vs_rest_classifier.rb
|