svmkit 0.2.2 → 0.2.3

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA1:
3
- metadata.gz: c7c326db290b1847234f890914fe5d670a4b1d36
4
- data.tar.gz: ee5a624c92bf6b35edcccf4df469f9f552c2174d
3
+ metadata.gz: 6271a50754a13199f7c3c12c6f1b9e2a0a2075d5
4
+ data.tar.gz: ecdc84a2987f22d49ad1b8435397862771f96f37
5
5
  SHA512:
6
- metadata.gz: 9256fc3d36e6247fae44ac1a14672eaf9b3ba414176b48c592be5aa8631d232dbaddba9f3884198e0ba751616a3017ad461da2fb7ec40ef26b9ab2b2417aadf5
7
- data.tar.gz: e198dcbe0c7e782162a31e7131b961f619e76bf509b43b596299412ba5bb5ea16ea02e21bb2a79909d70bb230c81484083df94ec65db6e770e4e5adc712da174
6
+ metadata.gz: 9974eb62cd19ebca32ca92cafbcbf2e34a978d41c0aa9bb0765c001d56963189aa43f71806fa5a2492a4b3e50bcd939a018aa784198f5aef08e2159682682dd2
7
+ data.tar.gz: eee0f089449a71f79576165aa083c41fe1d9812fcc251c31c2a5c7cf06e6dfd4293dddd6685fa6285c8d89a3744bb3a72bd9db6f36b7cbc586b7a71800c1fb78
data/.travis.yml CHANGED
@@ -3,8 +3,9 @@ os: linux
3
3
  dist: trusty
4
4
  language: ruby
5
5
  rvm:
6
- - 2.2.9
7
- - 2.3.6
8
- - 2.4.3
6
+ - 2.2
7
+ - 2.3
8
+ - 2.4
9
+ - 2.5
9
10
  before_install:
10
11
  - gem install --no-document bundler -v '~> 1.16'
data/HISTORY.md CHANGED
@@ -1,5 +1,10 @@
1
+ # 0.2.3
2
+ - Added class for cross validation.
3
+ - Added specs for base modules.
4
+ - Fixed validation of the number of splits when a negative label is given.
5
+
1
6
  # 0.2.2
2
- - Added classes for K-fold cross validation.
7
+ - Added data splitter classes for K-fold cross validation.
3
8
 
4
9
  # 0.2.1
5
10
  - Added class for K-nearest neighbors classifier.
data/README.md CHANGED
@@ -66,6 +66,27 @@ transformed = transformer.transform(normalized)
66
66
  puts(sprintf("Accuracy: %.1f%%", 100.0 * classifier.score(transformed, labels)))
67
67
  ```
68
68
 
69
+ 5-fold cross-validation:
70
+
71
+ ```ruby
72
+ require 'svmkit'
73
+
74
+ samples, labels = SVMKit::Dataset.load_libsvm_file('pendigits')
75
+
76
+ kernel_svc =
77
+ SVMKit::KernelMachine::KernelSVC.new(reg_param: 1.0, max_iter: 1000, random_seed: 1)
78
+ ovr_kernel_svc = SVMKit::Multiclass::OneVsRestClassifier.new(estimator: kernel_svc)
79
+
80
+ kf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 5, shuffle: true, random_seed: 1)
81
+ cv = SVMKit::ModelSelection::CrossValidation.new(estimator: ovr_kernel_svc, splitter: kf)
82
+
83
+ kernel_mat = SVMKit::PairwiseMetric::rbf_kernel(samples, nil, 0.005)
84
+ report = cv.perform(kernel_mat, labels)
85
+
86
+ mean_accuracy = report[:test_score].inject(:+) / kf.n_splits
87
+ puts(sprintf("Mean Accuracy: %.1f%%", 100.0 * mean_accuracy))
88
+ ```
89
+
69
90
  ## Development
70
91
 
71
92
  After checking out the repo, run `bin/setup` to install dependencies. Then, run `rake spec` to run the tests. You can also run `bin/console` for an interactive prompt that will allow you to experiment.
data/lib/svmkit.rb CHANGED
@@ -19,3 +19,4 @@ require 'svmkit/preprocessing/min_max_scaler'
19
19
  require 'svmkit/preprocessing/standard_scaler'
20
20
  require 'svmkit/model_selection/k_fold'
21
21
  require 'svmkit/model_selection/stratified_k_fold'
22
+ require 'svmkit/model_selection/cross_validation'
@@ -9,7 +9,7 @@ module SVMKit
9
9
 
10
10
  # An abstract method for splitting dataset.
11
11
  def split
12
- raise NoImplementedError, "#{__method__} has to be implemented in #{self.class}."
12
+ raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
13
13
  end
14
14
  end
15
15
  end
@@ -68,7 +68,7 @@ module SVMKit
68
68
  weight_vec[target_id] += 1.0 if func < 1.0
69
69
  end
70
70
  # Store the learned model.
71
- @weight_vec = weight_vec * Numo::DFloat.asarray(bin_y)
71
+ @weight_vec = weight_vec * Numo::DFloat[*bin_y]
72
72
  self
73
73
  end
74
74
 
@@ -78,7 +78,7 @@ module SVMKit
78
78
  # The kernel matrix between testing samples and training samples to compute the scores.
79
79
  # @return [Numo::DFloat] (shape: [n_testing_samples]) Confidence score per sample.
80
80
  def decision_function(x)
81
- @weight_vec.dot(x.transpose)
81
+ x.dot(@weight_vec)
82
82
  end
83
83
 
84
84
  # Predict class labels for samples.
@@ -74,7 +74,7 @@ module SVMKit
74
74
  end
75
75
  # Initialize some variables.
76
76
  n_samples, n_features = samples.shape
77
- rand_ids = [*0..n_samples - 1].shuffle(random: @rng)
77
+ rand_ids = [*0...n_samples].shuffle(random: @rng)
78
78
  weight_vec = Numo::DFloat.zeros(n_features)
79
79
  # Start optimization.
80
80
  @params[:max_iter].times do |t|
@@ -70,7 +70,7 @@ module SVMKit
70
70
  end
71
71
  # Initialize some variables.
72
72
  n_samples, n_features = samples.shape
73
- rand_ids = [*0..n_samples - 1].shuffle(random: @rng)
73
+ rand_ids = [*0...n_samples].shuffle(random: @rng)
74
74
  weight_vec = Numo::DFloat.zeros(n_features)
75
75
  # Start optimization.
76
76
  @params[:max_iter].times do |t|
@@ -0,0 +1,82 @@
1
+ require 'svmkit/base/splitter'
2
+
3
module SVMKit
  # This module consists of the classes for model validation techniques.
  module ModelSelection
    # CrossValidation is a class that evaluates a given classifier with cross-validation method.
    #
    # @example
    #   svc = SVMKit::LinearModel::SVC.new
    #   kf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 5)
    #   cv = SVMKit::ModelSelection::CrossValidation.new(estimator: svc, splitter: kf)
    #   report = cv.perform(samples, labels)
    #   mean_test_score = report[:test_score].inject(:+) / kf.n_splits
    #
    class CrossValidation
      # Return the classifier of which performance is evaluated.
      # @return [Classifier]
      attr_reader :estimator

      # Return the splitter that divides dataset.
      # @return [Splitter]
      attr_reader :splitter

      # Return the flag indicating whether to calculate the score of training dataset.
      # @return [Boolean]
      attr_reader :return_train_score

      # Create a new evaluator with cross-validation method.
      #
      # @param estimator [Classifier] The classifier of which performance is evaluated.
      # @param splitter [Splitter] The splitter that divides dataset to training and testing dataset.
      # @param return_train_score [Boolean] The flag indicating whether to calculate the score of training dataset.
      def initialize(estimator: nil, splitter: nil, return_train_score: false)
        @estimator = estimator
        @splitter = splitter
        @return_train_score = return_train_score
      end

      # Perform the evaluation of given classifier with cross-validation method.
      #
      # If the estimator is a kernel machine (directly, or wrapped in a Multiclass classifier),
      # x is expected to be a precomputed kernel matrix of shape [n_samples, n_samples].
      # Its columns are then restricted to the training samples of each split.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features])
      #   The dataset to be used to evaluate the classifier.
      # @param y [Numo::Int32] (shape: [n_samples])
      #   The labels to be used to evaluate the classifier.
      # @return [Hash] The report summarizing the results of cross-validation.
      #   * :fit_time (Array<Float>) The calculation times (in seconds) of fitting the estimator for each split.
      #   * :test_score (Array<Float>) The scores of testing dataset for each split.
      #   * :train_score (Array<Float>) The scores of training dataset for each split. This option is nil if
      #     the return_train_score is false.
      def perform(x, y)
        # Initialize the report of cross validation.
        report = { test_score: [], train_score: nil, fit_time: [] }
        report[:train_score] = [] if @return_train_score
        # The estimator does not change between splits, so decide once how to slice the columns.
        use_kernel = kernel_machine?
        # Evaluate the estimator on each split.
        @splitter.split(x, y).each do |train_ids, test_ids|
          # A kernel matrix needs its columns restricted to the training samples.
          # Otherwise `true` keeps every feature column (Numo::NArray indexing).
          feature_ids = use_kernel ? train_ids : true
          train_x = x[train_ids, feature_ids]
          train_y = y[train_ids]
          test_x = x[test_ids, feature_ids]
          test_y = y[test_ids]
          # Fit the estimator, timing it with a monotonic clock in fractional seconds.
          # Time.now.to_i truncated to whole seconds, so most fits were reported as 0.
          start_time = Process.clock_gettime(Process::CLOCK_MONOTONIC)
          @estimator.fit(train_x, train_y)
          report[:fit_time].push(Process.clock_gettime(Process::CLOCK_MONOTONIC) - start_time)
          # Calculate scores and prepare the report.
          report[:test_score].push(@estimator.score(test_x, test_y))
          report[:train_score].push(@estimator.score(train_x, train_y)) if @return_train_score
        end
        report
      end

      private

      # Whether the evaluated estimator (or, for a Multiclass wrapper, its base estimator)
      # is a class under a namespace containing 'KernelMachine'.
      def kernel_machine?
        class_name = @estimator.class.to_s
        class_name = @estimator.params[:estimator].class.to_s if class_name.include?('Multiclass')
        class_name.include?('KernelMachine')
      end
    end
  end
end
@@ -16,7 +16,7 @@ module SVMKit
16
16
  class KFold
17
17
  include Base::Splitter
18
18
 
19
- # Return the proportion of the test set to the dataset.
19
+ # Return the flag indicating whether to shuffle the dataset.
20
20
  # @return [Boolean]
21
21
  attr_reader :shuffle
22
22
 
@@ -16,7 +16,7 @@ module SVMKit
16
16
  class StratifiedKFold
17
17
  include Base::Splitter
18
18
 
19
- # Return the proportion of the test set to the dataset.
19
+ # Return the flag indicating whether to shuffle the dataset.
20
20
  # @return [Boolean]
21
21
  attr_reader :shuffle
22
22
 
@@ -47,7 +47,7 @@ module SVMKit
47
47
  # @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
48
48
  def split(x, y) # rubocop:disable Lint/UnusedMethodArgument
49
49
  # Check the number of samples in each class.
50
- unless y.bincount.to_a.all? { |n_samples| @n_splits.between?(2, n_samples) }
50
+ unless valid_n_splits?(y)
51
51
  raise ArgumentError,
52
52
  'The value of n_splits must be not less than 2 and not more than the number of samples in each class.'
53
53
  end
@@ -59,6 +59,10 @@ module SVMKit
59
59
 
60
60
  private
61
61
 
62
# Tell whether n_splits is usable for the given labels: it must be at least 2 and
# must not exceed the number of samples carrying any single distinct label.
#
# @param y [Numo::Int32] (shape: [n_samples]) The labels; presumably integer class ids.
# @return [Boolean] true when every distinct label has at least n_splits samples and n_splits >= 2.
def valid_n_splits?(y)
  y.to_a.uniq.all? do |label|
    class_size = y.eq(label).where.size
    @n_splits.between?(2, class_size)
  end
end
65
+
62
66
  def fold_sets(y, label)
63
67
  sample_ids = y.eq(label).where.to_a
64
68
  sample_ids.shuffle!(random: @rng) if @shuffle
@@ -1,5 +1,5 @@
1
1
  # SVMKit is an experimental library of machine learning in Ruby.
2
2
  module SVMKit
3
3
  # @!visibility private
4
- VERSION = '0.2.2'.freeze
4
+ VERSION = '0.2.3'.freeze
5
5
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: svmkit
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.2
4
+ version: 0.2.3
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2018-01-13 00:00:00.000000000 Z
11
+ date: 2018-01-15 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -112,6 +112,7 @@ files:
112
112
  - lib/svmkit/kernel_machine/kernel_svc.rb
113
113
  - lib/svmkit/linear_model/logistic_regression.rb
114
114
  - lib/svmkit/linear_model/svc.rb
115
+ - lib/svmkit/model_selection/cross_validation.rb
115
116
  - lib/svmkit/model_selection/k_fold.rb
116
117
  - lib/svmkit/model_selection/stratified_k_fold.rb
117
118
  - lib/svmkit/multiclass/one_vs_rest_classifier.rb