svmkit 0.6.2 → 0.6.3

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 9655d7d990f37468c79de9713e55f74c6134b00c1cda9471832097c678cb6ded
4
- data.tar.gz: fb294e6256d16272e80c2ade2f5b223dd792b14c7fd39caf1de8a9500a3ab55e
3
+ metadata.gz: 8f6c3a9b6df704497579a44ed1649baeff28364d8992cc781babb03ee1fe836a
4
+ data.tar.gz: 99502858d1cecbe65efe5ee342cc6719278f227b1305c75d5305d3362d00e295
5
5
  SHA512:
6
- metadata.gz: f7386071fe57df51bd8223d4945dc67069769464892bd64b972b5c1ba26cd206b7b67d50e600f34d79a3bff9f19803c0fdae06dd92fdf8f6ef87f1d5e982cf2d
7
- data.tar.gz: 473a1233e0109672b80b8bf17933366276b0a81fab3c75699fb1a07f92923d66cc3f064e0a2df36e493bf1554aa4c34d567d60607d1802e10ab949eafe1187d3
6
+ metadata.gz: 86bf9a6ce8de82c7b51c9f7ff3451c9834ca73a84abb48394769e65bf80ebf8b00c565f27b69882719ad81c5b33a27af0156cc1b3cc2949f530eeacf11560718
7
+ data.tar.gz: 14a4e5a9becd0b41f43ba1a4cdee545f12f8d18f339feae789593c527025a9c6c0bb50d7471c5ffefffbe1ba96dc46e409aaa25e781ef52830adac476531ce58
data/HISTORY.md CHANGED
@@ -1,3 +1,6 @@
1
+ # 0.6.3
2
+ - Fix bug in setting the random seed and max_features parameters of Random Forest estimators.
3
+
1
4
  # 0.6.2
2
5
  - Refactor decision tree classes for improving performance.
3
6
 
@@ -85,22 +85,22 @@ module SVMKit
85
85
  SVMKit::Validation.check_sample_label_size(x, y)
86
86
  # Initialize some variables.
87
87
  n_samples, n_features = x.shape
88
- @params[:max_features] = n_features unless @params[:max_features].is_a?(Integer)
89
- @params[:max_features] = [[1, @params[:max_features]].max, Math.sqrt(n_features).to_i].min
88
+ @params[:max_features] = Math.sqrt(n_features).to_i unless @params[:max_features].is_a?(Integer)
89
+ @params[:max_features] = [[1, @params[:max_features]].max, n_features].min
90
90
  @classes = Numo::Int32.asarray(y.to_a.uniq.sort)
91
+ @feature_importances = Numo::DFloat.zeros(n_features)
91
92
  # Construct forest.
92
- @estimators = Array.new(@params[:n_estimators]) do |_n|
93
+ @estimators = Array.new(@params[:n_estimators]) do
93
94
  tree = Tree::DecisionTreeClassifier.new(
94
95
  criterion: @params[:criterion], max_depth: @params[:max_depth],
95
96
  max_leaf_nodes: @params[:max_leaf_nodes], min_samples_leaf: @params[:min_samples_leaf],
96
- max_features: @params[:max_features], random_seed: @params[:random_seed]
97
+ max_features: @params[:max_features], random_seed: @rng.rand(int_max)
97
98
  )
98
99
  bootstrap_ids = Array.new(n_samples) { @rng.rand(0...n_samples) }
99
100
  tree.fit(x[bootstrap_ids, true], y[bootstrap_ids])
101
+ @feature_importances += tree.feature_importances
102
+ tree
100
103
  end
101
- # Calculate feature importances.
102
- @feature_importances = Numo::DFloat.zeros(n_features)
103
- @estimators.each { |tree| @feature_importances += tree.feature_importances }
104
104
  @feature_importances /= @feature_importances.sum
105
105
  self
106
106
  end
@@ -157,8 +157,11 @@ module SVMKit
157
157
  # Dump marshal data.
158
158
  # @return [Hash] The marshal data about RandomForestClassifier.
159
159
  def marshal_dump
160
- { params: @params, estimators: @estimators, classes: @classes,
161
- feature_importances: @feature_importances, rng: @rng }
160
+ { params: @params,
161
+ estimators: @estimators,
162
+ classes: @classes,
163
+ feature_importances: @feature_importances,
164
+ rng: @rng }
162
165
  end
163
166
 
164
167
  # Load marshal data.
@@ -171,6 +174,12 @@ module SVMKit
171
174
  @rng = obj[:rng]
172
175
  nil
173
176
  end
177
+
178
+ private
179
+
180
+ def int_max
181
+ @int_max ||= 2**([42].pack('i').size * 16 - 2) - 1
182
+ end
174
183
  end
175
184
  end
176
185
  end
@@ -80,21 +80,22 @@ module SVMKit
80
80
  check_sample_tvalue_size(x, y)
81
81
  # Initialize some variables.
82
82
  n_samples, n_features = x.shape
83
- @params[:max_features] ||= n_features
84
- @params[:max_features] = [[1, @params[:max_features]].max, Math.sqrt(n_features).to_i].min
83
+ @params[:max_features] = Math.sqrt(n_features).to_i unless @params[:max_features].is_a?(Integer)
84
+ @params[:max_features] = [[1, @params[:max_features]].max, n_features].min
85
+ @feature_importances = Numo::DFloat.zeros(n_features)
85
86
  single_target = y.shape[1].nil?
86
87
  # Construct forest.
87
- @estimators = Array.new(@params[:n_estimators]) do |_n|
88
+ @estimators = Array.new(@params[:n_estimators]) do
88
89
  tree = Tree::DecisionTreeRegressor.new(
89
90
  criterion: @params[:criterion], max_depth: @params[:max_depth],
90
91
  max_leaf_nodes: @params[:max_leaf_nodes], min_samples_leaf: @params[:min_samples_leaf],
91
- max_features: @params[:max_features], random_seed: @params[:random_seed]
92
+ max_features: @params[:max_features], random_seed: @rng.rand(int_max)
92
93
  )
93
94
  bootstrap_ids = Array.new(n_samples) { @rng.rand(0...n_samples) }
94
95
  tree.fit(x[bootstrap_ids, true], single_target ? y[bootstrap_ids] : y[bootstrap_ids, true])
96
+ @feature_importances += tree.feature_importances
97
+ tree
95
98
  end
96
- # Calculate feature importances.
97
- @feature_importances = @estimators.map(&:feature_importances).reduce(&:+)
98
99
  @feature_importances /= @feature_importances.sum
99
100
  self
100
101
  end
@@ -135,6 +136,12 @@ module SVMKit
135
136
  @rng = obj[:rng]
136
137
  nil
137
138
  end
139
+
140
+ private
141
+
142
+ def int_max
143
+ @int_max ||= 2**([42].pack('i').size * 16 - 2) - 1
144
+ end
138
145
  end
139
146
  end
140
147
  end
@@ -3,5 +3,5 @@
3
3
  # SVMKit is a machine learning library in Ruby.
4
4
  module SVMKit
5
5
  # @!visibility private
6
- VERSION = '0.6.2'.freeze
6
+ VERSION = '0.6.3'.freeze
7
7
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: svmkit
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.6.2
4
+ version: 0.6.3
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2018-09-17 00:00:00.000000000 Z
11
+ date: 2018-11-25 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray