svmkit 0.6.2 → 0.6.3

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 9655d7d990f37468c79de9713e55f74c6134b00c1cda9471832097c678cb6ded
4
- data.tar.gz: fb294e6256d16272e80c2ade2f5b223dd792b14c7fd39caf1de8a9500a3ab55e
3
+ metadata.gz: 8f6c3a9b6df704497579a44ed1649baeff28364d8992cc781babb03ee1fe836a
4
+ data.tar.gz: 99502858d1cecbe65efe5ee342cc6719278f227b1305c75d5305d3362d00e295
5
5
  SHA512:
6
- metadata.gz: f7386071fe57df51bd8223d4945dc67069769464892bd64b972b5c1ba26cd206b7b67d50e600f34d79a3bff9f19803c0fdae06dd92fdf8f6ef87f1d5e982cf2d
7
- data.tar.gz: 473a1233e0109672b80b8bf17933366276b0a81fab3c75699fb1a07f92923d66cc3f064e0a2df36e493bf1554aa4c34d567d60607d1802e10ab949eafe1187d3
6
+ metadata.gz: 86bf9a6ce8de82c7b51c9f7ff3451c9834ca73a84abb48394769e65bf80ebf8b00c565f27b69882719ad81c5b33a27af0156cc1b3cc2949f530eeacf11560718
7
+ data.tar.gz: 14a4e5a9becd0b41f43ba1a4cdee545f12f8d18f339feae789593c527025a9c6c0bb50d7471c5ffefffbe1ba96dc46e409aaa25e781ef52830adac476531ce58
data/HISTORY.md CHANGED
@@ -1,3 +1,6 @@
1
+ # 0.6.3
2
+ - Fix bug on setting random seed and max_features parameter of Random Forest estimators.
3
+
1
4
  # 0.6.2
2
5
  - Refactor decision tree classes for improving performance.
3
6
 
@@ -85,22 +85,22 @@ module SVMKit
85
85
  SVMKit::Validation.check_sample_label_size(x, y)
86
86
  # Initialize some variables.
87
87
  n_samples, n_features = x.shape
88
- @params[:max_features] = n_features unless @params[:max_features].is_a?(Integer)
89
- @params[:max_features] = [[1, @params[:max_features]].max, Math.sqrt(n_features).to_i].min
88
+ @params[:max_features] = Math.sqrt(n_features).to_i unless @params[:max_features].is_a?(Integer)
89
+ @params[:max_features] = [[1, @params[:max_features]].max, n_features].min
90
90
  @classes = Numo::Int32.asarray(y.to_a.uniq.sort)
91
+ @feature_importances = Numo::DFloat.zeros(n_features)
91
92
  # Construct forest.
92
- @estimators = Array.new(@params[:n_estimators]) do |_n|
93
+ @estimators = Array.new(@params[:n_estimators]) do
93
94
  tree = Tree::DecisionTreeClassifier.new(
94
95
  criterion: @params[:criterion], max_depth: @params[:max_depth],
95
96
  max_leaf_nodes: @params[:max_leaf_nodes], min_samples_leaf: @params[:min_samples_leaf],
96
- max_features: @params[:max_features], random_seed: @params[:random_seed]
97
+ max_features: @params[:max_features], random_seed: @rng.rand(int_max)
97
98
  )
98
99
  bootstrap_ids = Array.new(n_samples) { @rng.rand(0...n_samples) }
99
100
  tree.fit(x[bootstrap_ids, true], y[bootstrap_ids])
101
+ @feature_importances += tree.feature_importances
102
+ tree
100
103
  end
101
- # Calculate feature importances.
102
- @feature_importances = Numo::DFloat.zeros(n_features)
103
- @estimators.each { |tree| @feature_importances += tree.feature_importances }
104
104
  @feature_importances /= @feature_importances.sum
105
105
  self
106
106
  end
@@ -157,8 +157,11 @@ module SVMKit
157
157
  # Dump marshal data.
158
158
  # @return [Hash] The marshal data about RandomForestClassifier.
159
159
  def marshal_dump
160
- { params: @params, estimators: @estimators, classes: @classes,
161
- feature_importances: @feature_importances, rng: @rng }
160
+ { params: @params,
161
+ estimators: @estimators,
162
+ classes: @classes,
163
+ feature_importances: @feature_importances,
164
+ rng: @rng }
162
165
  end
163
166
 
164
167
  # Load marshal data.
@@ -171,6 +174,12 @@ module SVMKit
171
174
  @rng = obj[:rng]
172
175
  nil
173
176
  end
177
+
178
+ private
179
+
180
+ def int_max
181
+ @int_max ||= 2**([42].pack('i').size * 16 - 2) - 1
182
+ end
174
183
  end
175
184
  end
176
185
  end
@@ -80,21 +80,22 @@ module SVMKit
80
80
  check_sample_tvalue_size(x, y)
81
81
  # Initialize some variables.
82
82
  n_samples, n_features = x.shape
83
- @params[:max_features] ||= n_features
84
- @params[:max_features] = [[1, @params[:max_features]].max, Math.sqrt(n_features).to_i].min
83
+ @params[:max_features] = Math.sqrt(n_features).to_i unless @params[:max_features].is_a?(Integer)
84
+ @params[:max_features] = [[1, @params[:max_features]].max, n_features].min
85
+ @feature_importances = Numo::DFloat.zeros(n_features)
85
86
  single_target = y.shape[1].nil?
86
87
  # Construct forest.
87
- @estimators = Array.new(@params[:n_estimators]) do |_n|
88
+ @estimators = Array.new(@params[:n_estimators]) do
88
89
  tree = Tree::DecisionTreeRegressor.new(
89
90
  criterion: @params[:criterion], max_depth: @params[:max_depth],
90
91
  max_leaf_nodes: @params[:max_leaf_nodes], min_samples_leaf: @params[:min_samples_leaf],
91
- max_features: @params[:max_features], random_seed: @params[:random_seed]
92
+ max_features: @params[:max_features], random_seed: @rng.rand(int_max)
92
93
  )
93
94
  bootstrap_ids = Array.new(n_samples) { @rng.rand(0...n_samples) }
94
95
  tree.fit(x[bootstrap_ids, true], single_target ? y[bootstrap_ids] : y[bootstrap_ids, true])
96
+ @feature_importances += tree.feature_importances
97
+ tree
95
98
  end
96
- # Calculate feature importances.
97
- @feature_importances = @estimators.map(&:feature_importances).reduce(&:+)
98
99
  @feature_importances /= @feature_importances.sum
99
100
  self
100
101
  end
@@ -135,6 +136,12 @@ module SVMKit
135
136
  @rng = obj[:rng]
136
137
  nil
137
138
  end
139
+
140
+ private
141
+
142
+ def int_max
143
+ @int_max ||= 2**([42].pack('i').size * 16 - 2) - 1
144
+ end
138
145
  end
139
146
  end
140
147
  end
@@ -3,5 +3,5 @@
3
3
  # SVMKit is a machine learning library in Ruby.
4
4
  module SVMKit
5
5
  # @!visibility private
6
- VERSION = '0.6.2'.freeze
6
+ VERSION = '0.6.3'.freeze
7
7
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: svmkit
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.6.2
4
+ version: 0.6.3
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2018-09-17 00:00:00.000000000 Z
11
+ date: 2018-11-25 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray