svmkit 0.2.2 → 0.2.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.travis.yml +4 -3
- data/HISTORY.md +6 -1
- data/README.md +21 -0
- data/lib/svmkit.rb +1 -0
- data/lib/svmkit/base/splitter.rb +1 -1
- data/lib/svmkit/kernel_machine/kernel_svc.rb +2 -2
- data/lib/svmkit/linear_model/logistic_regression.rb +1 -1
- data/lib/svmkit/linear_model/svc.rb +1 -1
- data/lib/svmkit/model_selection/cross_validation.rb +82 -0
- data/lib/svmkit/model_selection/k_fold.rb +1 -1
- data/lib/svmkit/model_selection/stratified_k_fold.rb +6 -2
- data/lib/svmkit/version.rb +1 -1
- metadata +3 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA1:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 6271a50754a13199f7c3c12c6f1b9e2a0a2075d5
|
4
|
+
data.tar.gz: ecdc84a2987f22d49ad1b8435397862771f96f37
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 9974eb62cd19ebca32ca92cafbcbf2e34a978d41c0aa9bb0765c001d56963189aa43f71806fa5a2492a4b3e50bcd939a018aa784198f5aef08e2159682682dd2
|
7
|
+
data.tar.gz: eee0f089449a71f79576165aa083c41fe1d9812fcc251c31c2a5c7cf06e6dfd4293dddd6685fa6285c8d89a3744bb3a72bd9db6f36b7cbc586b7a71800c1fb78
|
data/.travis.yml
CHANGED
data/HISTORY.md
CHANGED
@@ -1,5 +1,10 @@
|
|
1
|
+
# 0.2.3
|
2
|
+
- Added class for cross validation.
|
3
|
+
- Added specs for base modules.
|
4
|
+
- Fixed validation of the number of splits when a negative label is given.
|
5
|
+
|
1
6
|
# 0.2.2
|
2
|
-
- Added classes for K-fold cross validation.
|
7
|
+
- Added data splitter classes for K-fold cross validation.
|
3
8
|
|
4
9
|
# 0.2.1
|
5
10
|
- Added class for K-nearest neighbors classifier.
|
data/README.md
CHANGED
@@ -66,6 +66,27 @@ transformed = transformer.transform(normalized)
|
|
66
66
|
puts(sprintf("Accuracy: %.1f%%", 100.0 * classifier.score(transformed, labels)))
|
67
67
|
```
|
68
68
|
|
69
|
+
5-fold cross-validation:
|
70
|
+
|
71
|
+
```ruby
|
72
|
+
require 'svmkit'
|
73
|
+
|
74
|
+
samples, labels = SVMKit::Dataset.load_libsvm_file('pendigits')
|
75
|
+
|
76
|
+
kernel_svc =
|
77
|
+
SVMKit::KernelMachine::KernelSVC.new(reg_param: 1.0, max_iter: 1000, random_seed: 1)
|
78
|
+
ovr_kernel_svc = SVMKit::Multiclass::OneVsRestClassifier.new(estimator: kernel_svc)
|
79
|
+
|
80
|
+
kf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 5, shuffle: true, random_seed: 1)
|
81
|
+
cv = SVMKit::ModelSelection::CrossValidation.new(estimator: ovr_kernel_svc, splitter: kf)
|
82
|
+
|
83
|
+
kernel_mat = SVMKit::PairwiseMetric::rbf_kernel(samples, nil, 0.005)
|
84
|
+
report = cv.perform(kernel_mat, labels)
|
85
|
+
|
86
|
+
mean_accuracy = report[:test_score].inject(:+) / kf.n_splits
|
87
|
+
puts(sprintf("Mean Accuracy: %.1f%%", 100.0 * mean_accuracy))
|
88
|
+
```
|
89
|
+
|
69
90
|
## Development
|
70
91
|
|
71
92
|
After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake spec` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
|
data/lib/svmkit.rb
CHANGED
data/lib/svmkit/base/splitter.rb
CHANGED
@@ -9,7 +9,7 @@ module SVMKit
|
|
9
9
|
|
10
10
|
# An abstract method for splitting dataset.
|
11
11
|
def split
|
12
|
-
raise
|
12
|
+
raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
|
13
13
|
end
|
14
14
|
end
|
15
15
|
end
|
@@ -68,7 +68,7 @@ module SVMKit
|
|
68
68
|
weight_vec[target_id] += 1.0 if func < 1.0
|
69
69
|
end
|
70
70
|
# Store the learned model.
|
71
|
-
@weight_vec = weight_vec * Numo::DFloat
|
71
|
+
@weight_vec = weight_vec * Numo::DFloat[*bin_y]
|
72
72
|
self
|
73
73
|
end
|
74
74
|
|
@@ -78,7 +78,7 @@ module SVMKit
|
|
78
78
|
# The kernel matrix between testing samples and training samples to compute the scores.
|
79
79
|
# @return [Numo::DFloat] (shape: [n_testing_samples]) Confidence score per sample.
|
80
80
|
def decision_function(x)
|
81
|
-
|
81
|
+
x.dot(@weight_vec)
|
82
82
|
end
|
83
83
|
|
84
84
|
# Predict class labels for samples.
|
@@ -74,7 +74,7 @@ module SVMKit
|
|
74
74
|
end
|
75
75
|
# Initialize some variables.
|
76
76
|
n_samples, n_features = samples.shape
|
77
|
-
rand_ids = [*0
|
77
|
+
rand_ids = [*0...n_samples].shuffle(random: @rng)
|
78
78
|
weight_vec = Numo::DFloat.zeros(n_features)
|
79
79
|
# Start optimization.
|
80
80
|
@params[:max_iter].times do |t|
|
@@ -70,7 +70,7 @@ module SVMKit
|
|
70
70
|
end
|
71
71
|
# Initialize some variables.
|
72
72
|
n_samples, n_features = samples.shape
|
73
|
-
rand_ids = [*0
|
73
|
+
rand_ids = [*0...n_samples].shuffle(random: @rng)
|
74
74
|
weight_vec = Numo::DFloat.zeros(n_features)
|
75
75
|
# Start optimization.
|
76
76
|
@params[:max_iter].times do |t|
|
@@ -0,0 +1,82 @@
|
|
1
|
+
require 'svmkit/base/splitter'
|
2
|
+
|
3
|
+
module SVMKit
|
4
|
+
# This module consists of the classes for model validation techniques.
|
5
|
+
module ModelSelection
|
6
|
+
# CrossValidation is a class that evaluates a given classifier with cross-validation method.
|
7
|
+
#
|
8
|
+
# @example
|
9
|
+
# svc = SVMKit::LinearModel::SVC.new
|
10
|
+
# kf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 5)
|
11
|
+
# cv = SVMKit::ModelSelection::CrossValidation.new(estimator: svc, splitter: kf)
|
12
|
+
# report = cv.perform(samples, labels)
|
13
|
+
# mean_test_score = report[:test_score].inject(:+) / kf.n_splits
|
14
|
+
#
|
15
|
+
class CrossValidation
|
16
|
+
# Return the classifier whose performance is evaluated.
|
17
|
+
# @return [Classifier]
|
18
|
+
attr_reader :estimator
|
19
|
+
|
20
|
+
# Return the splitter that divides dataset.
|
21
|
+
# @return [Splitter]
|
22
|
+
attr_reader :splitter
|
23
|
+
|
24
|
+
# Return the flag indicating whether to calculate the score of training dataset.
|
25
|
+
# @return [Boolean]
|
26
|
+
attr_reader :return_train_score
|
27
|
+
|
28
|
+
# Create a new evaluator with cross-validation method.
|
29
|
+
#
|
30
|
+
# @param estimator [Classifier] The classifier whose performance is evaluated.
|
31
|
+
# @param splitter [Splitter] The splitter that divides dataset to training and testing dataset.
|
32
|
+
# @param return_train_score [Boolean] The flag indicating whether to calculate the score of training dataset.
|
33
|
+
def initialize(estimator: nil, splitter: nil, return_train_score: false)
|
34
|
+
@estimator = estimator
|
35
|
+
@splitter = splitter
|
36
|
+
@return_train_score = return_train_score
|
37
|
+
end
|
38
|
+
|
39
|
+
# Perform the evaluation of given classifier with cross-validation method.
|
40
|
+
#
|
41
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features])
|
42
|
+
# The dataset to be used to evaluate the classifier.
|
43
|
+
# @param y [Numo::Int32] (shape: [n_samples])
|
44
|
+
# The labels to be used to evaluate the classifier.
|
45
|
+
# @return [Hash] The report summarizing the results of cross-validation.
|
46
|
+
# * :fit_time (Array<Float>) The calculation times of fitting the estimator for each split.
|
47
|
+
# * :test_score (Array<Float>) The scores of testing dataset for each split.
|
48
|
+
# * :train_score (Array<Float>) The scores of training dataset for each split. This option is nil if
|
49
|
+
# the return_train_score is false.
|
50
|
+
def perform(x, y)
|
51
|
+
# Initialize the report of cross validation.
|
52
|
+
report = {test_score: [], train_score: nil, fit_time: []}
|
53
|
+
report[:train_score] = [] if @return_train_score
|
54
|
+
# Evaluate the estimator on each split.
|
55
|
+
@splitter.split(x, y).each do |train_ids, test_ids|
|
56
|
+
# Split dataset into training and testing dataset.
|
57
|
+
feature_ids = !kernel_machine? || train_ids
|
58
|
+
train_x = x[train_ids, feature_ids]
|
59
|
+
train_y = y[train_ids]
|
60
|
+
test_x = x[test_ids, feature_ids]
|
61
|
+
test_y = y[test_ids]
|
62
|
+
# Fit the estimator.
|
63
|
+
start_time = Time.now.to_i
|
64
|
+
@estimator.fit(train_x, train_y)
|
65
|
+
# Calculate scores and prepare the report.
|
66
|
+
report[:fit_time].push(Time.now.to_i - start_time)
|
67
|
+
report[:test_score].push(@estimator.score(test_x, test_y))
|
68
|
+
report[:train_score].push(@estimator.score(train_x, train_y)) if @return_train_score
|
69
|
+
end
|
70
|
+
report
|
71
|
+
end
|
72
|
+
|
73
|
+
private
|
74
|
+
|
75
|
+
def kernel_machine?
|
76
|
+
class_name = @estimator.class.to_s
|
77
|
+
class_name = @estimator.params[:estimator].class.to_s if class_name.include?('Multiclass')
|
78
|
+
class_name.include?('KernelMachine')
|
79
|
+
end
|
80
|
+
end
|
81
|
+
end
|
82
|
+
end
|
@@ -16,7 +16,7 @@ module SVMKit
|
|
16
16
|
class StratifiedKFold
|
17
17
|
include Base::Splitter
|
18
18
|
|
19
|
-
# Return the
|
19
|
+
# Return the flag indicating whether to shuffle the dataset.
|
20
20
|
# @return [Boolean]
|
21
21
|
attr_reader :shuffle
|
22
22
|
|
@@ -47,7 +47,7 @@ module SVMKit
|
|
47
47
|
# @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
|
48
48
|
def split(x, y) # rubocop:disable Lint/UnusedMethodArgument
|
49
49
|
# Check the number of samples in each class.
|
50
|
-
unless
|
50
|
+
unless valid_n_splits?(y)
|
51
51
|
raise ArgumentError,
|
52
52
|
'The value of n_splits must be not less than 2 and not more than the number of samples in each class.'
|
53
53
|
end
|
@@ -59,6 +59,10 @@ module SVMKit
|
|
59
59
|
|
60
60
|
private
|
61
61
|
|
62
|
+
def valid_n_splits?(y)
|
63
|
+
y.to_a.uniq.map { |label| y.eq(label).where.size }.all? { |n_samples| @n_splits.between?(2, n_samples) }
|
64
|
+
end
|
65
|
+
|
62
66
|
def fold_sets(y, label)
|
63
67
|
sample_ids = y.eq(label).where.to_a
|
64
68
|
sample_ids.shuffle!(random: @rng) if @shuffle
|
data/lib/svmkit/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: svmkit
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.
|
4
|
+
version: 0.2.3
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2018-01-
|
11
|
+
date: 2018-01-15 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -112,6 +112,7 @@ files:
|
|
112
112
|
- lib/svmkit/kernel_machine/kernel_svc.rb
|
113
113
|
- lib/svmkit/linear_model/logistic_regression.rb
|
114
114
|
- lib/svmkit/linear_model/svc.rb
|
115
|
+
- lib/svmkit/model_selection/cross_validation.rb
|
115
116
|
- lib/svmkit/model_selection/k_fold.rb
|
116
117
|
- lib/svmkit/model_selection/stratified_k_fold.rb
|
117
118
|
- lib/svmkit/multiclass/one_vs_rest_classifier.rb
|