svmkit 0.6.1 → 0.6.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 1d52bf496a762b096a5f6dbeec278a1cae8079b53d6c91cc13c07dca7a799fde
4
- data.tar.gz: e5ca2fed307b82e88dfe816691a4715d62a3187c1cad71421a48bea65037b19c
3
+ metadata.gz: 9655d7d990f37468c79de9713e55f74c6134b00c1cda9471832097c678cb6ded
4
+ data.tar.gz: fb294e6256d16272e80c2ade2f5b223dd792b14c7fd39caf1de8a9500a3ab55e
5
5
  SHA512:
6
- metadata.gz: 620370c119300f3f419550609444eba4aa34561a954e8ec26cf6a0d3522cd32cabf1f6875092de5ab0dd202ebf7b772c1d6d6421cd05d90cfeeeeadea3cd0565
7
- data.tar.gz: a0d8b5a7b91c4a8e2ffb4312a8082096a2e4fbd411e37bd36752ac19a66f6b9accf22be15894d6a88421b7296bde90ead33b55604085d139e5faec64b97f0f55
6
+ metadata.gz: f7386071fe57df51bd8223d4945dc67069769464892bd64b972b5c1ba26cd206b7b67d50e600f34d79a3bff9f19803c0fdae06dd92fdf8f6ef87f1d5e982cf2d
7
+ data.tar.gz: 473a1233e0109672b80b8bf17933366276b0a81fab3c75699fb1a07f92923d66cc3f064e0a2df36e493bf1554aa4c34d567d60607d1802e10ab949eafe1187d3
data/.rubocop.yml CHANGED
@@ -42,3 +42,6 @@ Style/FormatStringToken:
42
42
 
43
43
  Style/NumericLiterals:
44
44
  Enabled: false
45
+
46
+ Layout/EmptyLineAfterGuardClause:
47
+ Enabled: false
data/HISTORY.md CHANGED
@@ -1,7 +1,10 @@
1
+ # 0.6.2
2
+ - Refactor decision tree classes for improving performance.
3
+
1
4
  # 0.6.1
2
5
  - Add abstract class for linear estimator with stochastic gradient descent.
3
6
  - Refactor linear estimators to use linear estimator abstract class.
4
- - Refactor decistion tree classes to avoid unneeded type conversion.
7
+ - Refactor decision tree classes to avoid unneeded type conversion.
5
8
 
6
9
  # 0.6.0
7
10
  - Add class for Principal Component Analysis.
@@ -79,7 +79,7 @@ module SVMKit
79
79
 
80
80
  def split_weight(weight)
81
81
  if @params[:fit_bias]
82
- [weight[0...-1], weight[-1]]
82
+ [weight[0...-1].dup, weight[-1]]
83
83
  else
84
84
  [weight, 0.0]
85
85
  end
@@ -253,7 +253,7 @@ module SVMKit
253
253
  end
254
254
 
255
255
  def split_weight_vec_bias(weight_vec)
256
- weights = weight_vec[0...-1]
256
+ weights = weight_vec[0...-1].dup
257
257
  bias = weight_vec[-1]
258
258
  [weights, bias]
259
259
  end
@@ -185,7 +185,7 @@ module SVMKit
185
185
  end
186
186
 
187
187
  def split_weight_vec_bias(weight_vec)
188
- weights = weight_vec[0...-1]
188
+ weights = weight_vec[0...-1].dup
189
189
  bias = weight_vec[-1]
190
190
  [weights, bias]
191
191
  end
@@ -91,8 +91,9 @@ module SVMKit
91
91
  n_samples, n_features = x.shape
92
92
  @params[:max_features] = n_features if @params[:max_features].nil?
93
93
  @params[:max_features] = [@params[:max_features], n_features].min
94
- @classes = Numo::Int32.asarray(y.to_a.uniq.sort)
95
- build_tree(x, y)
94
+ uniq_y = y.to_a.uniq.sort
95
+ @classes = Numo::Int32.asarray(uniq_y)
96
+ build_tree(x, y.map { |v| uniq_y.index(v) })
96
97
  eval_importance(n_samples, n_features)
97
98
  self
98
99
  end
@@ -174,36 +175,35 @@ module SVMKit
174
175
  def build_tree(x, y)
175
176
  @n_leaves = 0
176
177
  @leaf_labels = []
177
- @tree = grow_node(0, x, y)
178
+ @tree = grow_node(0, x, y, impurity(y))
178
179
  @leaf_labels = Numo::Int32[*@leaf_labels]
179
180
  nil
180
181
  end
181
182
 
182
- def grow_node(depth, x, y)
183
- if @params[:max_leaf_nodes].is_a?(Integer)
183
+ def grow_node(depth, x, y, whole_impurity)
184
+ unless @params[:max_leaf_nodes].nil?
184
185
  return nil if @n_leaves >= @params[:max_leaf_nodes]
185
186
  end
186
187
 
187
188
  n_samples, n_features = x.shape
188
- if @params[:min_samples_leaf].is_a?(Integer)
189
- return nil if n_samples <= @params[:min_samples_leaf]
190
- end
189
+ return nil if n_samples <= @params[:min_samples_leaf]
191
190
 
192
- node = Node.new(depth: depth, impurity: impurity(y), n_samples: n_samples)
191
+ node = Node.new(depth: depth, impurity: whole_impurity, n_samples: n_samples)
193
192
 
194
193
  return put_leaf(node, y) if y.to_a.uniq.size == 1
195
194
 
196
- if @params[:max_depth].is_a?(Integer)
195
+ unless @params[:max_depth].nil?
197
196
  return put_leaf(node, y) if depth == @params[:max_depth]
198
197
  end
199
198
 
200
- feature_id, threshold, left_ids, right_ids, max_gain =
201
- rand_ids(n_features).map { |f_id| [f_id, *best_split(x[true, f_id], y)] }.max_by(&:last)
202
- return put_leaf(node, y) if max_gain.nil?
203
- return put_leaf(node, y) if max_gain.zero?
199
+ feature_id, threshold, left_ids, right_ids, left_impurity, right_impurity, gain =
200
+ rand_ids(n_features).map { |f_id| [f_id, *best_split(x[true, f_id], y, whole_impurity)] }.max_by(&:last)
201
+
202
+ return put_leaf(node, y) if gain.nil? || gain.zero?
203
+
204
+ node.left = grow_node(depth + 1, x[left_ids, true], y[left_ids], left_impurity)
205
+ node.right = grow_node(depth + 1, x[right_ids, true], y[right_ids], right_impurity)
204
206
 
205
- node.left = grow_node(depth + 1, x[left_ids, true], y[left_ids])
206
- node.right = grow_node(depth + 1, x[right_ids, true], y[right_ids])
207
207
  return put_leaf(node, y) if node.left.nil? && node.right.nil?
208
208
 
209
209
  node.feature_id = feature_id
@@ -213,7 +213,7 @@ module SVMKit
213
213
  end
214
214
 
215
215
  def put_leaf(node, y)
216
- node.probs = Numo::DFloat.cast(@classes.map { |c| y.eq(c).count_true }) / node.n_samples
216
+ node.probs = y.bincount(minlength: @classes.size) / node.n_samples.to_f
217
217
  node.leaf = true
218
218
  node.leaf_id = @n_leaves
219
219
  @n_leaves += 1
@@ -225,27 +225,23 @@ module SVMKit
225
225
  [*0...n].sample(@params[:max_features], random: @rng)
226
226
  end
227
227
 
228
- def best_split(features, labels)
228
+ def best_split(features, labels, whole_impurity)
229
+ n_samples = labels.size
229
230
  features.to_a.uniq.sort.each_cons(2).map do |l, r|
230
231
  threshold = 0.5 * (l + r)
231
- left_ids, right_ids = splited_ids(features, threshold)
232
- [threshold, left_ids, right_ids, gain(labels, labels[left_ids], labels[right_ids])]
232
+ left_ids = features.le(threshold).where
233
+ right_ids = features.gt(threshold).where
234
+ left_impurity = impurity(labels[left_ids])
235
+ right_impurity = impurity(labels[right_ids])
236
+ gain = whole_impurity -
237
+ left_impurity * left_ids.size.fdiv(n_samples) -
238
+ right_impurity * right_ids.size.fdiv(n_samples)
239
+ [threshold, left_ids, right_ids, left_impurity, right_impurity, gain]
233
240
  end.max_by(&:last)
234
241
  end
235
242
 
236
- def splited_ids(features, threshold)
237
- [features.le(threshold).where, features.gt(threshold).where]
238
- end
239
-
240
- def gain(labels, labels_left, labels_right)
241
- prob_left = labels_left.size.fdiv(labels.size)
242
- prob_right = labels_right.size.fdiv(labels.size)
243
- impurity(labels) - prob_left * impurity(labels_left) - prob_right * impurity(labels_right)
244
- end
245
-
246
243
  def impurity(labels)
247
- cls = labels.to_a.uniq.sort
248
- cls.size == 1 ? 0.0 : send(@criterion, Numo::DFloat[*(cls.map { |c| labels.eq(c).count_true.fdiv(labels.size) })])
244
+ send(@criterion, labels.bincount / labels.size.to_f)
249
245
  end
250
246
 
251
247
  def gini(posterior_probs)
@@ -253,7 +249,7 @@ module SVMKit
253
249
  end
254
250
 
255
251
  def entropy(posterior_probs)
256
- -(posterior_probs * Numo::NMath.log(posterior_probs)).sum
252
+ -(posterior_probs * Numo::NMath.log(posterior_probs + 1)).sum
257
253
  end
258
254
 
259
255
  def eval_importance(n_samples, n_features)
@@ -269,7 +265,8 @@ module SVMKit
269
265
  return nil if node.leaf
270
266
  return nil if node.left.nil? || node.right.nil?
271
267
  gain = node.n_samples * node.impurity -
272
- node.left.n_samples * node.left.impurity - node.right.n_samples * node.right.impurity
268
+ node.left.n_samples * node.left.impurity -
269
+ node.right.n_samples * node.right.impurity
273
270
  @feature_importances[node.feature_id] += gain
274
271
  eval_importance_at_node(node.left)
275
272
  eval_importance_at_node(node.right)
@@ -151,12 +151,12 @@ module SVMKit
151
151
  def build_tree(x, y)
152
152
  @n_leaves = 0
153
153
  @leaf_values = []
154
- @tree = grow_node(0, x, y)
154
+ @tree = grow_node(0, x, y, impurity(y))
155
155
  @leaf_values = Numo::DFloat.cast(@leaf_values)
156
156
  nil
157
157
  end
158
158
 
159
- def grow_node(depth, x, y)
159
+ def grow_node(depth, x, y, whole_impurity)
160
160
  unless @params[:max_leaf_nodes].nil?
161
161
  return nil if @n_leaves >= @params[:max_leaf_nodes]
162
162
  end
@@ -164,7 +164,7 @@ module SVMKit
164
164
  n_samples, n_features = x.shape
165
165
  return nil if n_samples <= @params[:min_samples_leaf]
166
166
 
167
- node = Node.new(depth: depth, impurity: impurity(y), n_samples: n_samples)
167
+ node = Node.new(depth: depth, impurity: whole_impurity, n_samples: n_samples)
168
168
 
169
169
  return put_leaf(node, y) if (y - y.mean(0)).sum.abs.zero?
170
170
 
@@ -172,12 +172,14 @@ module SVMKit
172
172
  return put_leaf(node, y) if depth == @params[:max_depth]
173
173
  end
174
174
 
175
- feature_id, threshold, left_ids, right_ids, max_gain =
176
- rand_ids(n_features).map { |f_id| [f_id, *best_split(x[true, f_id], y)] }.max_by(&:last)
177
- return put_leaf(node, y) if max_gain.nil? || max_gain.zero?
175
+ feature_id, threshold, left_ids, right_ids, left_impurity, right_impurity, gain =
176
+ rand_ids(n_features).map { |f_id| [f_id, *best_split(x[true, f_id], y, whole_impurity)] }.max_by(&:last)
177
+
178
+ return put_leaf(node, y) if gain.nil? || gain.zero?
179
+
180
+ node.left = grow_node(depth + 1, x[left_ids, true], y[left_ids, true], left_impurity)
181
+ node.right = grow_node(depth + 1, x[right_ids, true], y[right_ids, true], right_impurity)
178
182
 
179
- node.left = grow_node(depth + 1, x[left_ids, true], y[left_ids, true])
180
- node.right = grow_node(depth + 1, x[right_ids, true], y[right_ids, true])
181
183
  return put_leaf(node, y) if node.left.nil? && node.right.nil?
182
184
 
183
185
  node.feature_id = feature_id
@@ -199,24 +201,21 @@ module SVMKit
199
201
  [*0...n].sample(@params[:max_features], random: @rng)
200
202
  end
201
203
 
202
- def best_split(features, values)
204
+ def best_split(features, values, whole_impurity)
205
+ n_samples = values.shape[0]
203
206
  features.to_a.uniq.sort.each_cons(2).map do |l, r|
204
207
  threshold = 0.5 * (l + r)
205
- left_ids, right_ids = splited_ids(features, threshold)
206
- [threshold, left_ids, right_ids, gain(values, values[left_ids], values[right_ids])]
208
+ left_ids = features.le(threshold).where
209
+ right_ids = features.gt(threshold).where
210
+ left_impurity = impurity(values[left_ids, true])
211
+ right_impurity = impurity(values[right_ids, true])
212
+ gain = whole_impurity -
213
+ left_impurity * left_ids.size.fdiv(n_samples) -
214
+ right_impurity * right_ids.size.fdiv(n_samples)
215
+ [threshold, left_ids, right_ids, left_impurity, right_impurity, gain]
207
216
  end.max_by(&:last)
208
217
  end
209
218
 
210
- def splited_ids(features, threshold)
211
- [features.le(threshold).where, features.gt(threshold).where]
212
- end
213
-
214
- def gain(values, values_left, values_right)
215
- prob_left = values_left.shape[0].fdiv(values.shape[0])
216
- prob_right = values_right.shape[0].fdiv(values.shape[0])
217
- impurity(values) - prob_left * impurity(values_left) - prob_right * impurity(values_right)
218
- end
219
-
220
219
  def impurity(values)
221
220
  send(@criterion, values)
222
221
  end
@@ -3,5 +3,5 @@
3
3
  # SVMKit is a machine learning library in Ruby.
4
4
  module SVMKit
5
5
  # @!visibility private
6
- VERSION = '0.6.1'.freeze
6
+ VERSION = '0.6.2'.freeze
7
7
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: svmkit
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.6.1
4
+ version: 0.6.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2018-09-10 00:00:00.000000000 Z
11
+ date: 2018-09-17 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray