svmkit 0.2.5 → 0.2.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. checksums.yaml +4 -4
  2. data/.coveralls.yml +1 -0
  3. data/HISTORY.md +7 -0
  4. data/README.md +2 -1
  5. data/lib/svmkit.rb +3 -0
  6. data/lib/svmkit/base/base_estimator.rb +1 -0
  7. data/lib/svmkit/base/classifier.rb +9 -3
  8. data/lib/svmkit/base/evaluator.rb +1 -0
  9. data/lib/svmkit/base/splitter.rb +1 -0
  10. data/lib/svmkit/base/transformer.rb +1 -0
  11. data/lib/svmkit/dataset.rb +2 -0
  12. data/lib/svmkit/ensemble/random_forest_classifier.rb +161 -0
  13. data/lib/svmkit/evaluation_measure/accuracy.rb +2 -0
  14. data/lib/svmkit/evaluation_measure/f_score.rb +2 -0
  15. data/lib/svmkit/evaluation_measure/precision.rb +2 -0
  16. data/lib/svmkit/evaluation_measure/precision_recall.rb +2 -0
  17. data/lib/svmkit/evaluation_measure/recall.rb +2 -0
  18. data/lib/svmkit/kernel_approximation/rbf.rb +2 -0
  19. data/lib/svmkit/kernel_machine/kernel_svc.rb +3 -3
  20. data/lib/svmkit/linear_model/logistic_regression.rb +2 -11
  21. data/lib/svmkit/linear_model/svc.rb +2 -11
  22. data/lib/svmkit/model_selection/cross_validation.rb +2 -0
  23. data/lib/svmkit/model_selection/k_fold.rb +2 -0
  24. data/lib/svmkit/model_selection/stratified_k_fold.rb +2 -0
  25. data/lib/svmkit/multiclass/one_vs_rest_classifier.rb +2 -11
  26. data/lib/svmkit/naive_bayes/naive_bayes.rb +5 -13
  27. data/lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb +3 -12
  28. data/lib/svmkit/pairwise_metric.rb +2 -0
  29. data/lib/svmkit/polynomial_model/factorization_machine_classifier.rb +3 -12
  30. data/lib/svmkit/preprocessing/l2_normalizer.rb +2 -0
  31. data/lib/svmkit/preprocessing/min_max_scaler.rb +2 -0
  32. data/lib/svmkit/preprocessing/standard_scaler.rb +2 -0
  33. data/lib/svmkit/tree/decision_tree_classifier.rb +260 -0
  34. data/lib/svmkit/version.rb +4 -2
  35. data/svmkit.gemspec +2 -2
  36. metadata +9 -6
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA1:
3
- metadata.gz: 42b8bbee820defc7646b422fa160ade5dd0ffddd
4
- data.tar.gz: 5ba44c2c18a02231646456ab1ff73fee409c50c0
3
+ metadata.gz: b36f6b299c47d1107d587aafeb7bb66531f1208c
4
+ data.tar.gz: 1bd382f3339b8fb08454493a45a2338020791b6c
5
5
  SHA512:
6
- metadata.gz: 23bdab14e55581e61b9050c88167dd74ebd5c086dfa2c37f57aaebe2110e8d7bac70df659ae34615c788d36cd03fac36663267ed07cf950564da7b1e496e9b59
7
- data.tar.gz: c36b3565bd731e1613ee63f83ad83f2021e8c2c255267df00a0872ec8f1d9f2e9ab73edd10d5f2a5bef6ca1af8712aef4220434bc680f1f5228bf9727c94b6e4
6
+ metadata.gz: 0676f2e9b3ef4ac9786f10ca976721e73d2cd918a9c939900281e36267ab14a3413c3a719d9504415f527e5d6163d5640ea2af186e023c3980327cb7c476afba
7
+ data.tar.gz: e8dcf72f7d1641903a4625bb23399deabf9a19931a5c00bd2c5077b525a2d0361b194b8c43dee1b7a365b25eafdabb460e0a4ae21f17afd5b401379671379463
@@ -0,0 +1 @@
1
+ service_name: travis-ci
data/HISTORY.md CHANGED
@@ -1,3 +1,10 @@
1
+ # 0.2.6
2
+ - Added class for Decision Tree classifier.
3
+ - Added class for Random Forest classifier.
4
+ - Fixed to use frozen string literal.
5
+ - Refactored marshal dump method on some classes.
6
+ - Introduced Coveralls to confirm test coverage.
7
+
1
8
  # 0.2.5
2
9
  - Added classes for Naive Bayes classifier.
3
10
  - Fixed decision function method on Logistic Regression class.
data/README.md CHANGED
@@ -1,13 +1,14 @@
1
1
  # SVMKit
2
2
 
3
3
  [![Build Status](https://travis-ci.org/yoshoku/SVMKit.svg?branch=master)](https://travis-ci.org/yoshoku/SVMKit)
4
+ [![Coverage Status](https://coveralls.io/repos/github/yoshoku/SVMKit/badge.svg?branch=master)](https://coveralls.io/github/yoshoku/SVMKit?branch=master)
4
5
  [![Gem Version](https://badge.fury.io/rb/svmkit.svg)](https://badge.fury.io/rb/svmkit)
5
6
  [![BSD 2-Clause License](https://img.shields.io/badge/License-BSD%202--Clause-orange.svg)](https://github.com/yoshoku/SVMKit/blob/master/LICENSE.txt)
6
7
 
7
8
  SVMKit is a machine learning library in Ruby.
8
9
  SVMKit provides machine learning algorithms with interfaces similar to Scikit-Learn in Python.
9
10
  SVMKit currently supports Linear / Kernel Support Vector Machine,
10
- Logistic Regression, Factorization Machine, Naive Bayes,
11
+ Logistic Regression, Factorization Machine, Naive Bayes, Decision Tree, Random Forest,
11
12
  K-nearest neighbor classifier, and cross-validation.
12
13
 
13
14
  ## Installation
@@ -1,3 +1,4 @@
1
+ # frozen_string_literal: true
1
2
 
2
3
  require 'numo/narray'
3
4
 
@@ -17,6 +18,8 @@ require 'svmkit/polynomial_model/factorization_machine_classifier'
17
18
  require 'svmkit/multiclass/one_vs_rest_classifier'
18
19
  require 'svmkit/nearest_neighbors/k_neighbors_classifier'
19
20
  require 'svmkit/naive_bayes/naive_bayes'
21
+ require 'svmkit/tree/decision_tree_classifier'
22
+ require 'svmkit/ensemble/random_forest_classifier'
20
23
  require 'svmkit/preprocessing/l2_normalizer'
21
24
  require 'svmkit/preprocessing/min_max_scaler'
22
25
  require 'svmkit/preprocessing/standard_scaler'
@@ -1,3 +1,4 @@
1
+ # frozen_string_literal: true
1
2
 
2
3
  module SVMKit
3
4
  # This module consists of basic mix-in classes.
@@ -1,3 +1,4 @@
1
+ # frozen_string_literal: true
1
2
 
2
3
  module SVMKit
3
4
  module Base
@@ -13,9 +14,14 @@ module SVMKit
13
14
  raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
14
15
  end
15
16
 
16
- # An abstract method for calculating classification accuracy.
17
- def score
18
- raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
17
+ # Calculate the mean accuracy of the given testing data.
18
+ #
19
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
20
+ # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
21
+ # @return [Float] Mean accuracy
22
+ def score(x, y)
23
+ evaluator = SVMKit::EvaluationMeasure::Accuracy.new
24
+ evaluator.score(y, predict(x))
19
25
  end
20
26
  end
21
27
  end
@@ -1,3 +1,4 @@
1
+ # frozen_string_literal: true
1
2
 
2
3
  module SVMKit
3
4
  module Base
@@ -1,3 +1,4 @@
1
+ # frozen_string_literal: true
1
2
 
2
3
  module SVMKit
3
4
  module Base
@@ -1,3 +1,4 @@
1
+ # frozen_string_literal: true
1
2
 
2
3
  module SVMKit
3
4
  module Base
@@ -1,3 +1,5 @@
1
+ # frozen_string_literal: true
2
+
1
3
  module SVMKit
2
4
  # Module for loading and saving a dataset file.
3
5
  module Dataset
@@ -0,0 +1,161 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'svmkit/base/base_estimator'
4
+ require 'svmkit/base/classifier'
5
+
6
+ module SVMKit
7
+ # This module consists of the classes that implement ensemble-based methods.
8
+ module Ensemble
9
+ # RandomForestClassifier is a class that implements random forest for classification.
10
+ #
11
+ # @example
12
+ # estimator =
13
+ # SVMKit::Ensemble::RandomForestClassifier.new(
14
+ # n_estimators: 10, criterion: 'gini', max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5, random_seed: 1)
15
+ # estimator.fit(training_samples, training_labels)
16
+ # results = estimator.predict(testing_samples)
17
+ #
18
+ class RandomForestClassifier
19
+ include Base::BaseEstimator
20
+ include Base::Classifier
21
+
22
+ # Return the set of estimators.
23
+ # @return [Array<DecisionTreeClassifier>]
24
+ attr_reader :estimators
25
+
26
+ # Return the class labels.
27
+ # @return [Numo::Int32] (size: n_classes)
28
+ attr_reader :classes
29
+
30
+ # Return the importance for each feature.
31
+ # @return [Numo::DFloat] (size: n_features)
32
+ attr_reader :feature_importances
33
+
34
+ # Return the random generator for performing bootstrap sampling of the training data.
35
+ # @return [Random]
36
+ attr_reader :rng
37
+
38
+ # Create a new classifier with random forest.
39
+ #
40
+ # @param n_estimators [Integer] The number of decision trees for constructing the random forest.
41
+ # @param criterion [String] The function to evaluate the splitting point. Supported criteria are 'gini' and 'entropy'.
42
+ # @param max_depth [Integer] The maximum depth of the tree.
43
+ # If nil is given, decision tree grows without concern for depth.
44
+ # @param max_leaf_nodes [Integer] The maximum number of leaves on decision tree.
45
+ # If nil is given, number of leaves is not limited.
46
+ # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
47
+ # @param max_features [Integer] The number of features to consider when searching optimal split point.
48
+ # If nil is given, or the given value exceeds it, Math.sqrt(n_features).to_i features are considered.
49
+ # @param random_seed [Integer] The seed value used to initialize the random generator.
50
+ # It is used to randomly determine the order of features when deciding the splitting point.
51
+ def initialize(n_estimators: 10, criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1,
52
+ max_features: nil, random_seed: nil)
53
+ @params = {}
54
+ @params[:n_estimators] = n_estimators
55
+ @params[:criterion] = criterion
56
+ @params[:max_depth] = max_depth
57
+ @params[:max_leaf_nodes] = max_leaf_nodes
58
+ @params[:min_samples_leaf] = min_samples_leaf
59
+ @params[:max_features] = max_features
60
+ @params[:random_seed] = random_seed
61
+ @params[:random_seed] ||= srand
62
+ @rng = Random.new(@params[:random_seed])
63
+ @estimators = nil
64
+ @classes = nil
65
+ @feature_importances = nil
66
+ end
67
+
68
+ # Fit the model with given training data.
69
+ #
70
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
71
+ # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
72
+ # @return [RandomForestClassifier] The learned classifier itself.
73
+ def fit(x, y)
74
+ # Initialize some variables.
75
+ n_samples, n_features = x.shape
76
+ @params[:max_features] = n_features unless @params[:max_features].is_a?(Integer)
77
+ @params[:max_features] = [[1, @params[:max_features]].max, Math.sqrt(n_features).to_i].min
78
+ @classes = Numo::Int32.asarray(y.to_a.uniq.sort)
79
+ # Construct forest.
80
+ @estimators = Array.new(@params[:n_estimators]) do |_n|
81
+ tree = Tree::DecisionTreeClassifier.new(
82
+ criterion: @params[:criterion], max_depth: @params[:max_depth],
83
+ max_leaf_nodes: @params[:max_leaf_nodes], min_samples_leaf: @params[:min_samples_leaf],
84
+ max_features: @params[:max_features], random_seed: @params[:random_seed]
85
+ )
86
+ bootstrap_ids = Array.new(n_samples) { @rng.rand(0...n_samples) }
87
+ tree.fit(x[bootstrap_ids, true], y[bootstrap_ids])
88
+ end
89
+ # Calculate feature importances.
90
+ @feature_importances = Numo::DFloat.zeros(n_features)
91
+ @estimators.each { |tree| @feature_importances += tree.feature_importances }
92
+ @feature_importances /= @feature_importances.sum
93
+ self
94
+ end
95
+
96
+ # Predict class labels for samples.
97
+ #
98
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
99
+ # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
100
+ def predict(x)
101
+ n_samples, = x.shape
102
+ n_classes = @classes.size
103
+ classes_arr = @classes.to_a
104
+ ballot_box = Numo::DFloat.zeros(n_samples, n_classes)
105
+ @estimators.each do |tree|
106
+ predicted = tree.predict(x)
107
+ n_samples.times do |n|
108
+ class_id = classes_arr.index(predicted[n])
109
+ ballot_box[n, class_id] += 1.0 unless class_id.nil?
110
+ end
111
+ end
112
+ Numo::Int32[*Array.new(n_samples) { |n| @classes[ballot_box[n, true].max_index] }]
113
+ end
114
+
115
+ # Predict probability for samples.
116
+ #
117
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
118
+ # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
119
+ def predict_proba(x)
120
+ n_samples, = x.shape
121
+ n_classes = @classes.size
122
+ classes_arr = @classes.to_a
123
+ ballot_box = Numo::DFloat.zeros(n_samples, n_classes)
124
+ @estimators.each do |tree|
125
+ probs = tree.predict_proba(x)
126
+ tree.classes.size.times do |n|
127
+ class_id = classes_arr.index(tree.classes[n])
128
+ ballot_box[true, class_id] += probs[true, n] unless class_id.nil?
129
+ end
130
+ end
131
+ (ballot_box.transpose / ballot_box.sum(axis: 1)).transpose
132
+ end
133
+
134
+ # Return the index of the leaf that each sample reached.
135
+ #
136
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
137
+ # @return [Numo::Int32] (shape: [n_samples, n_estimators]) Leaf index for sample.
138
+ def apply(x)
139
+ Numo::Int32[*Array.new(@params[:n_estimators]) { |n| @estimators[n].apply(x) }].transpose
140
+ end
141
+
142
+ # Dump marshal data.
143
+ # @return [Hash] The marshal data about RandomForestClassifier
144
+ def marshal_dump
145
+ { params: @params, estimators: @estimators, classes: @classes,
146
+ feature_importances: @feature_importances, rng: @rng }
147
+ end
148
+
149
+ # Load marshal data.
150
+ # @return [nil]
151
+ def marshal_load(obj)
152
+ @params = obj[:params]
153
+ @estimators = obj[:estimators]
154
+ @classes = obj[:classes]
155
+ @feature_importances = obj[:feature_importances]
156
+ @rng = obj[:rng]
157
+ nil
158
+ end
159
+ end
160
+ end
161
+ end
@@ -1,3 +1,5 @@
1
+ # frozen_string_literal: true
2
+
1
3
  require 'svmkit/base/evaluator'
2
4
 
3
5
  module SVMKit
@@ -1,3 +1,5 @@
1
+ # frozen_string_literal: true
2
+
1
3
  require 'svmkit/base/evaluator'
2
4
  require 'svmkit/evaluation_measure/precision_recall'
3
5
 
@@ -1,3 +1,5 @@
1
+ # frozen_string_literal: true
2
+
1
3
  require 'svmkit/base/evaluator'
2
4
  require 'svmkit/evaluation_measure/precision_recall'
3
5
 
@@ -1,3 +1,5 @@
1
+ # frozen_string_literal: true
2
+
1
3
  require 'svmkit/base/evaluator'
2
4
 
3
5
  module SVMKit
@@ -1,3 +1,5 @@
1
+ # frozen_string_literal: true
2
+
1
3
  require 'svmkit/base/evaluator'
2
4
  require 'svmkit/evaluation_measure/precision_recall'
3
5
 
@@ -1,3 +1,5 @@
1
+ # frozen_string_literal: true
2
+
1
3
  require 'svmkit/base/base_estimator'
2
4
  require 'svmkit/base/transformer'
3
5
 
@@ -1,3 +1,5 @@
1
+ # frozen_string_literal: true
2
+
1
3
  require 'svmkit/base/base_estimator'
2
4
  require 'svmkit/base/classifier'
3
5
 
@@ -97,9 +99,7 @@ module SVMKit
97
99
  # @param y [Numo::Int32] (shape: [n_testing_samples]) True labels for testing data.
98
100
  # @return [Float] Mean accuracy
99
101
  def score(x, y)
100
- p = predict(x)
101
- n_hits = (y.to_a.map.with_index { |l, n| l == p[n] ? 1 : 0 }).inject(:+)
102
- n_hits / y.size.to_f
102
+ super
103
103
  end
104
104
 
105
105
  # Dump marshal data.
@@ -1,3 +1,5 @@
1
+ # frozen_string_literal: true
2
+
1
3
  require 'svmkit/base/base_estimator'
2
4
  require 'svmkit/base/classifier'
3
5
 
@@ -135,17 +137,6 @@ module SVMKit
135
137
  proba
136
138
  end
137
139
 
138
- # Claculate the mean accuracy of the given testing data.
139
- #
140
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
141
- # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
142
- # @return [Float] Mean accuracy
143
- def score(x, y)
144
- p = predict(x)
145
- n_hits = (y.to_a.map.with_index { |l, n| l == p[n] ? 1 : 0 }).inject(:+)
146
- n_hits / y.size.to_f
147
- end
148
-
149
140
  # Dump marshal data.
150
141
  # @return [Hash] The marshal data about LogisticRegression.
151
142
  def marshal_dump
@@ -1,3 +1,5 @@
1
+ # frozen_string_literal: true
2
+
1
3
  require 'svmkit/base/base_estimator'
2
4
  require 'svmkit/base/classifier'
3
5
 
@@ -118,17 +120,6 @@ module SVMKit
118
120
  Numo::Int32.cast(decision_function(x).map { |v| v >= 0 ? 1 : -1 })
119
121
  end
120
122
 
121
- # Claculate the mean accuracy of the given testing data.
122
- #
123
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
124
- # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
125
- # @return [Float] Mean accuracy
126
- def score(x, y)
127
- p = predict(x)
128
- n_hits = (y.to_a.map.with_index { |l, n| l == p[n] ? 1 : 0 }).inject(:+)
129
- n_hits / y.size.to_f
130
- end
131
-
132
123
  # Dump marshal data.
133
124
  # @return [Hash] The marshal data about SVC.
134
125
  def marshal_dump
@@ -1,3 +1,5 @@
1
+ # frozen_string_literal: true
2
+
1
3
  require 'svmkit/base/splitter'
2
4
 
3
5
  module SVMKit
@@ -1,3 +1,5 @@
1
+ # frozen_string_literal: true
2
+
1
3
  require 'svmkit/base/splitter'
2
4
 
3
5
  module SVMKit
@@ -1,3 +1,5 @@
1
+ # frozen_string_literal: true
2
+
1
3
  require 'svmkit/base/splitter'
2
4
 
3
5
  module SVMKit
@@ -1,3 +1,5 @@
1
+ # frozen_string_literal: true
2
+
1
3
  require 'svmkit/base/base_estimator.rb'
2
4
  require 'svmkit/base/classifier.rb'
3
5
 
@@ -68,17 +70,6 @@ module SVMKit
68
70
  Numo::Int32.asarray(Array.new(n_samples) { |n| @classes[decision_values[n, true].max_index] })
69
71
  end
70
72
 
71
- # Claculate the mean accuracy of the given testing data.
72
- #
73
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
74
- # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
75
- # @return [Float] Mean accuracy
76
- def score(x, y)
77
- p = predict(x)
78
- n_hits = (y.to_a.map.with_index { |l, n| l == p[n] ? 1 : 0 }).inject(:+)
79
- n_hits / y.size.to_f
80
- end
81
-
82
73
  # Dump marshal data.
83
74
  # @return [Hash] The marshal data about OneVsRestClassifier.
84
75
  def marshal_dump
@@ -1,3 +1,5 @@
1
+ # frozen_string_literal: true
2
+
1
3
  require 'svmkit/base/base_estimator'
2
4
  require 'svmkit/base/classifier'
3
5
 
@@ -36,16 +38,6 @@ module SVMKit
36
38
  def predict_proba(x)
37
39
  Numo::NMath.exp(predict_log_proba(x)).abs
38
40
  end
39
-
40
- # Claculate the mean accuracy of the given testing data.
41
- #
42
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
43
- # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
44
- # @return [Float] Mean accuracy
45
- def score(x, y)
46
- evaluator = SVMKit::EvaluationMeasure::Accuracy.new
47
- evaluator.score(y, predict(x))
48
- end
49
41
  end
50
42
 
51
43
  # GaussianNB is a class that implements Gaussian Naive Bayes classifier.
@@ -109,7 +101,7 @@ module SVMKit
109
101
  #
110
102
  # @return [Hash] The marshal data about GaussianNB.
111
103
  def marshal_dump
112
- { params: params,
104
+ { params: @params,
113
105
  classes: @classes,
114
106
  class_priors: @class_priors,
115
107
  means: @means,
@@ -193,7 +185,7 @@ module SVMKit
193
185
  #
194
186
  # @return [Hash] The marshal data about MultinomialNB.
195
187
  def marshal_dump
196
- { params: params,
188
+ { params: @params,
197
189
  classes: @classes,
198
190
  class_priors: @class_priors,
199
191
  feature_probs: @feature_probs }
@@ -283,7 +275,7 @@ module SVMKit
283
275
  #
284
276
  # @return [Hash] The marshal data about BernoulliNB.
285
277
  def marshal_dump
286
- { params: params,
278
+ { params: @params,
287
279
  classes: @classes,
288
280
  class_priors: @class_priors,
289
281
  feature_probs: @feature_probs }
@@ -1,3 +1,5 @@
1
+ # frozen_string_literal: true
2
+
1
3
  require 'svmkit/base/base_estimator'
2
4
  require 'svmkit/base/classifier'
3
5
 
@@ -79,21 +81,10 @@ module SVMKit
79
81
  Numo::Int32.asarray(Array.new(n_samples) { |n| @classes[decision_values[n, true].max_index] })
80
82
  end
81
83
 
82
- # Claculate the mean accuracy of the given testing data.
83
- #
84
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
85
- # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
86
- # @return [Float] Mean accuracy
87
- def score(x, y)
88
- p = predict(x)
89
- n_hits = (y.to_a.map.with_index { |l, n| l == p[n] ? 1 : 0 }).inject(:+)
90
- n_hits / y.size.to_f
91
- end
92
-
93
84
  # Dump marshal data.
94
85
  # @return [Hash] The marshal data about KNeighborsClassifier.
95
86
  def marshal_dump
96
- { params: params,
87
+ { params: @params,
97
88
  prototypes: @prototypes,
98
89
  labels: @labels,
99
90
  classes: @classes }
@@ -1,3 +1,5 @@
1
+ # frozen_string_literal: true
2
+
1
3
  module SVMKit
2
4
  # Module for calculating pairwise distances, similarities, and kernels.
3
5
  module PairwiseMetric
@@ -1,8 +1,10 @@
1
+ # frozen_string_literal: true
2
+
1
3
  require 'svmkit/base/base_estimator'
2
4
  require 'svmkit/base/classifier'
3
5
 
4
6
  module SVMKit
5
- # This module consists of the classes that implemnt polynomial models.
7
+ # This module consists of the classes that implement polynomial models.
6
8
  module PolynomialModel
7
9
  # FactorizationMachineClassifier is a class that
8
10
  # implements Fatorization Machine for binary classification
@@ -136,17 +138,6 @@ module SVMKit
136
138
  proba
137
139
  end
138
140
 
139
- # Claculate the mean accuracy of the given testing data.
140
- #
141
- # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
142
- # @param y [Numo::Int32] (shape: [n_samples]) True labels for testing data.
143
- # @return [Float] Mean accuracy
144
- def score(x, y)
145
- p = predict(x)
146
- n_hits = (y.to_a.map.with_index { |l, n| l == p[n] ? 1 : 0 }).inject(:+)
147
- n_hits / y.size.to_f
148
- end
149
-
150
141
  # Dump marshal data.
151
142
  # @return [Hash] The marshal data about FactorizationMachineClassifier
152
143
  def marshal_dump
@@ -1,3 +1,5 @@
1
+ # frozen_string_literal: true
2
+
1
3
  require 'svmkit/base/base_estimator'
2
4
  require 'svmkit/base/transformer'
3
5
 
@@ -1,3 +1,5 @@
1
+ # frozen_string_literal: true
2
+
1
3
  require 'svmkit/base/base_estimator'
2
4
  require 'svmkit/base/transformer'
3
5
 
@@ -1,3 +1,5 @@
1
+ # frozen_string_literal: true
2
+
1
3
  require 'svmkit/base/base_estimator'
2
4
  require 'svmkit/base/transformer'
3
5
 
@@ -0,0 +1,260 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'svmkit/base/base_estimator'
4
+ require 'svmkit/base/classifier'
5
+ require 'ostruct'
6
+
7
+ module SVMKit
8
+ # This module consists of the classes that implement tree models.
9
+ module Tree
10
+ # DecisionTreeClassifier is a class that implements decision tree for classification.
11
+ #
12
+ # @example
13
+ # estimator =
14
+ # SVMKit::Tree::DecisionTreeClassifier.new(
15
+ # criterion: 'gini', max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5, random_seed: 1)
16
+ # estimator.fit(training_samples, training_labels)
17
+ # results = estimator.predict(testing_samples)
18
+ #
19
+ class DecisionTreeClassifier
20
+ include Base::BaseEstimator
21
+ include Base::Classifier
22
+
23
+ # Return the class labels.
24
+ # @return [Numo::Int32] (size: n_classes)
25
+ attr_reader :classes
26
+
27
+ # Return the importance for each feature.
28
+ # @return [Numo::DFloat] (size: n_features)
29
+ attr_reader :feature_importances
30
+
31
+ # Return the learned tree.
32
+ # @return [OpenStruct]
33
+ attr_reader :tree
34
+
35
+ # Return the random generator for randomly selecting the features examined at each split.
36
+ # @return [Random]
37
+ attr_reader :rng
38
+
39
+ # Return the labels assigned each leaf.
40
+ # @return [Numo::Int32] (size: n_leaves)
41
+ attr_reader :leaf_labels
42
+
43
+ # Create a new classifier with decision tree algorithm.
44
+ #
45
+ # @param criterion [String] The function to evaluate the splitting point. Supported criteria are 'gini' and 'entropy'.
46
+ # @param max_depth [Integer] The maximum depth of the tree.
47
+ # If nil is given, decision tree grows without concern for depth.
48
+ # @param max_leaf_nodes [Integer] The maximum number of leaves on decision tree.
49
+ # If nil is given, number of leaves is not limited.
50
+ # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
51
+ # @param max_features [Integer] The number of features to consider when searching optimal split point.
52
+ # If nil is given, split process considers all features.
53
+ # @param random_seed [Integer] The seed value used to initialize the random generator.
54
+ # It is used to randomly determine the order of features when deciding the splitting point.
55
+ def initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
56
+ random_seed: nil)
57
+ @params = {}
58
+ @params[:criterion] = criterion
59
+ @params[:max_depth] = max_depth
60
+ @params[:max_leaf_nodes] = max_leaf_nodes
61
+ @params[:min_samples_leaf] = min_samples_leaf
62
+ @params[:max_features] = max_features
63
+ @params[:random_seed] = random_seed
64
+ @params[:random_seed] ||= srand
65
+ @rng = Random.new(@params[:random_seed])
66
+ @tree = nil
67
+ @classes = nil
68
+ @feature_importances = nil
69
+ @n_leaves = nil
70
+ @leaf_labels = nil
71
+ end
72
+
73
+ # Fit the model with given training data.
74
+ #
75
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
76
+ # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
77
+ # @return [DecisionTreeClassifier] The learned classifier itself.
78
+ def fit(x, y)
79
+ n_samples, n_features = x.shape
80
+ @params[:max_features] = n_features unless @params[:max_features].is_a?(Integer)
81
+ @params[:max_features] = [[1, @params[:max_features]].max, n_features].min
82
+ @classes = Numo::Int32.asarray(y.to_a.uniq.sort)
83
+ build_tree(x, y)
84
+ eval_importance(n_samples, n_features)
85
+ self
86
+ end
87
+
88
+ # Predict class labels for samples.
89
+ #
90
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
91
+ # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
92
+ def predict(x)
93
+ @leaf_labels[apply(x)]
94
+ end
95
+
96
+ # Predict probability for samples.
97
+ #
98
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
99
+ # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
100
+ def predict_proba(x)
101
+ probs = Numo::DFloat[*(Array.new(x.shape[0]) { |n| predict_at_node(@tree, x[n, true]) })]
102
+ probs[true, @classes]
103
+ end
104
+
105
+ # Return the index of the leaf that each sample reached.
106
+ #
107
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
108
+ # @return [Numo::Int32] (shape: [n_samples]) Leaf index for sample.
109
+ def apply(x)
110
+ Numo::Int32[*(Array.new(x.shape[0]) { |n| apply_at_node(@tree, x[n, true]) })]
111
+ end
112
+
113
+ # Dump marshal data.
114
+ # @return [Hash] The marshal data about DecisionTreeClassifier
115
+ def marshal_dump
116
+ { params: @params, classes: @classes, tree: @tree,
117
+ feature_importances: @feature_importances, leaf_labels: @leaf_labels,
118
+ rng: @rng }
119
+ end
120
+
121
+ # Load marshal data.
122
+ # @return [nil]
123
+ def marshal_load(obj)
124
+ @params = obj[:params]
125
+ @classes = obj[:classes]
126
+ @tree = obj[:tree]
127
+ @feature_importances = obj[:feature_importances]
128
+ @leaf_labels = obj[:leaf_labels]
129
+ @rng = obj[:rng]
130
+ nil
131
+ end
132
+
133
+ private
134
+
135
+ def predict_at_node(node, sample)
136
+ return node.probs if node.leaf
137
+ branch_at_node('predict', node, sample)
138
+ end
139
+
140
+ def apply_at_node(node, sample)
141
+ return node.leaf_id if node.leaf
142
+ branch_at_node('apply', node, sample)
143
+ end
144
+
145
+ def branch_at_node(action, node, sample)
146
+ return send("#{action}_at_node", node.left, sample) if node.right.nil?
147
+ return send("#{action}_at_node", node.right, sample) if node.left.nil?
148
+ if sample[node.feature_id] <= node.threshold
149
+ send("#{action}_at_node", node.left, sample)
150
+ else
151
+ send("#{action}_at_node", node.right, sample)
152
+ end
153
+ end
154
+
155
+ def build_tree(x, y)
156
+ @n_leaves = 0
157
+ @leaf_labels = []
158
+ @tree = grow_node(0, x, y)
159
+ @leaf_labels = Numo::Int32[*@leaf_labels]
160
+ nil
161
+ end
162
+
163
+ def grow_node(depth, x, y)
164
+ if @params[:max_leaf_nodes].is_a?(Integer)
165
+ return nil if @n_leaves >= @params[:max_leaf_nodes]
166
+ end
167
+
168
+ n_samples, n_features = x.shape
169
+ if @params[:min_samples_leaf].is_a?(Integer)
170
+ return nil if n_samples <= @params[:min_samples_leaf]
171
+ end
172
+
173
+ node = OpenStruct.new(depth: depth, impurity: impurity(y), n_samples: n_samples)
174
+
175
+ return put_leaf(node, y) if y.to_a.uniq.size == 1
176
+
177
+ if @params[:max_depth].is_a?(Integer)
178
+ return put_leaf(node, y) if depth == @params[:max_depth]
179
+ end
180
+
181
+ feature_id, threshold, left_ids, right_ids, max_gain =
182
+ rand_ids(n_features).map { |f_id| [f_id, *best_split(x[true, f_id], y)] }.max_by(&:last)
183
+ return put_leaf(node, y) if max_gain.nil?
184
+ return put_leaf(node, y) if max_gain.zero?
185
+
186
+ node.left = grow_node(depth + 1, x[left_ids, true], y[left_ids])
187
+ node.right = grow_node(depth + 1, x[right_ids, true], y[right_ids])
188
+ return put_leaf(node, y) if node.left.nil? && node.right.nil?
189
+
190
+ node.feature_id = feature_id
191
+ node.threshold = threshold
192
+ node.leaf = false
193
+ node
194
+ end
195
+
196
+ def put_leaf(node, y)
197
+ node.probs = y.bincount(minlength: @classes.max + 1) / node.n_samples.to_f
198
+ node.leaf = true
199
+ node.leaf_id = @n_leaves
200
+ @n_leaves += 1
201
+ @leaf_labels.push(node.probs.max_index)
202
+ node
203
+ end
204
+
205
+ def rand_ids(n)
206
+ [*0...n].sample(@params[:max_features], random: @rng)
207
+ end
208
+
209
+ def best_split(features, labels)
210
+ features.to_a.uniq.sort.each_cons(2).map do |l, r|
211
+ threshold = 0.5 * (l + r)
212
+ left_ids, right_ids = splited_ids(features, threshold)
213
+ [threshold, left_ids, right_ids, gain(labels, labels[left_ids], labels[right_ids])]
214
+ end.max_by(&:last)
215
+ end
216
+
217
+ def splited_ids(features, threshold)
218
+ [features.le(threshold).where.to_a, features.gt(threshold).where.to_a]
219
+ end
220
+
221
+ def gain(labels, labels_left, labels_right)
222
+ prob_left = labels_left.size / labels.size.to_f
223
+ prob_right = labels_right.size / labels.size.to_f
224
+ impurity(labels) - prob_left * impurity(labels_left) - prob_right * impurity(labels_right)
225
+ end
226
+
227
+ def impurity(labels)
228
+ posterior_probs = labels.to_a.uniq.sort.map { |c| labels.eq(c).count / labels.size.to_f }
229
+ @params[:criterion] == 'entropy' ? entropy(posterior_probs) : gini(posterior_probs)
230
+ end
231
+
232
+ def gini(posterior_probs)
233
+ 1.0 - posterior_probs.map { |p| p**2 }.inject(:+)
234
+ end
235
+
236
+ def entropy(posterior_probs)
237
+ -posterior_probs.map { |p| p * Math.log(p) }.inject(:+)
238
+ end
239
+
240
+ def eval_importance(n_samples, n_features)
241
+ @feature_importances = Numo::DFloat.zeros(n_features)
242
+ eval_importance_at_node(@tree)
243
+ @feature_importances /= n_samples
244
+ normalizer = @feature_importances.sum
245
+ @feature_importances /= normalizer if normalizer > 0.0
246
+ nil
247
+ end
248
+
249
+ def eval_importance_at_node(node)
250
+ return nil if node.leaf
251
+ return nil if node.left.nil? || node.right.nil?
252
+ gain = node.n_samples * node.impurity -
253
+ node.left.n_samples * node.left.impurity - node.right.n_samples * node.right.impurity
254
+ @feature_importances[node.feature_id] += gain
255
+ eval_importance_at_node(node.left)
256
+ eval_importance_at_node(node.right)
257
+ end
258
+ end
259
+ end
260
+ end
@@ -1,5 +1,7 @@
1
- # SVMKit is an experimental library of machine learning in Ruby.
1
+ # frozen_string_literal: true
2
+
3
+ # SVMKit is a machine learning library in Ruby.
2
4
  module SVMKit
3
5
  # @!visibility private
4
- VERSION = '0.2.5'.freeze
6
+ VERSION = '0.2.6'
5
7
  end
@@ -18,7 +18,7 @@ MSG
18
18
  SVMKit is a machine learning library in Ruby.
19
19
  SVMKit provides machine learning algorithms with interfaces similar to Scikit-Learn in Python.
20
20
  SVMKit currently supports Linear / Kernel Support Vector Machine,
21
- Logistic Regression, Factorization Machine, Naive Bayes,
21
+ Logistic Regression, Factorization Machine, Naive Bayes, Decision Tree, Random Forest,
22
22
  K-nearest neighbor classifier, and cross-validation.
23
23
  MSG
24
24
  spec.homepage = 'https://github.com/yoshoku/svmkit'
@@ -38,7 +38,7 @@ MSG
38
38
  spec.add_development_dependency 'bundler', '~> 1.16'
39
39
  spec.add_development_dependency 'rake', '~> 12.0'
40
40
  spec.add_development_dependency 'rspec', '~> 3.0'
41
- spec.add_development_dependency 'simplecov', '~> 0.15'
41
+ spec.add_development_dependency 'coveralls', '~> 0.8'
42
42
 
43
43
  spec.post_install_message = <<MSG
44
44
  *************************************************************************
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: svmkit
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.5
4
+ version: 0.2.6
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2018-02-24 00:00:00.000000000 Z
11
+ date: 2018-03-11 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -67,24 +67,24 @@ dependencies:
67
67
  - !ruby/object:Gem::Version
68
68
  version: '3.0'
69
69
  - !ruby/object:Gem::Dependency
70
- name: simplecov
70
+ name: coveralls
71
71
  requirement: !ruby/object:Gem::Requirement
72
72
  requirements:
73
73
  - - "~>"
74
74
  - !ruby/object:Gem::Version
75
- version: '0.15'
75
+ version: '0.8'
76
76
  type: :development
77
77
  prerelease: false
78
78
  version_requirements: !ruby/object:Gem::Requirement
79
79
  requirements:
80
80
  - - "~>"
81
81
  - !ruby/object:Gem::Version
82
- version: '0.15'
82
+ version: '0.8'
83
83
  description: |
84
84
  SVMKit is a machine learning library in Ruby.
85
85
  SVMKit provides machine learning algorithms with interfaces similar to Scikit-Learn in Python.
86
86
  SVMKit currently supports Linear / Kernel Support Vector Machine,
87
- Logistic Regression, Factorization Machine, Naive Bayes,
87
+ Logistic Regression, Factorization Machine, Naive Bayes, Decision Tree, Random Forest,
88
88
  K-nearest neighbor classifier, and cross-validation.
89
89
  email:
90
90
  - yoshoku@outlook.com
@@ -92,6 +92,7 @@ executables: []
92
92
  extensions: []
93
93
  extra_rdoc_files: []
94
94
  files:
95
+ - ".coveralls.yml"
95
96
  - ".gitignore"
96
97
  - ".rspec"
97
98
  - ".rubocop.yml"
@@ -112,6 +113,7 @@ files:
112
113
  - lib/svmkit/base/splitter.rb
113
114
  - lib/svmkit/base/transformer.rb
114
115
  - lib/svmkit/dataset.rb
116
+ - lib/svmkit/ensemble/random_forest_classifier.rb
115
117
  - lib/svmkit/evaluation_measure/accuracy.rb
116
118
  - lib/svmkit/evaluation_measure/f_score.rb
117
119
  - lib/svmkit/evaluation_measure/precision.rb
@@ -132,6 +134,7 @@ files:
132
134
  - lib/svmkit/preprocessing/l2_normalizer.rb
133
135
  - lib/svmkit/preprocessing/min_max_scaler.rb
134
136
  - lib/svmkit/preprocessing/standard_scaler.rb
137
+ - lib/svmkit/tree/decision_tree_classifier.rb
135
138
  - lib/svmkit/version.rb
136
139
  - svmkit.gemspec
137
140
  homepage: https://github.com/yoshoku/svmkit