rumale 0.8.4 → 0.9.0

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA1:
- metadata.gz: 8e895ca9462569e1ec2b3b9c1cb985aebd7d4b19
- data.tar.gz: 9374969bd955cada6ba54ec6edf1d7bc174a5102
+ metadata.gz: a2dfbc60c9d47e741fc91497f8c58ade390e6c8f
+ data.tar.gz: d4cbc26e0d81fbe0de5e83d785cc836e9a5b2099
  SHA512:
- metadata.gz: 7e9eadf3404e74ee887007a1fb1df4e2b933f10394c1a61afdd8400492afcc4fbb40479c453983d2349db5d5c7cd52ab280a9ef28aa26d91b19cb41253bdb233
- data.tar.gz: 674bf164a3f1be2971fa2f882999ed34a22773d15674d6c2c24da66af888be855ffd3daa1d66a08ee7ec4bc99c87e864ffe61ebb7d6356dfae78a0a20e97c3a9
+ metadata.gz: 7f2b4b8ba5d7511215a2e850add19f0942cbff4157a8373eba1950c0eac9fcd0e44925d3a88b2a709c0308ef4c03cca44c501b710f4a22dc4dd573e6866d94dc
+ data.tar.gz: 4630710eef59af88274e9a411a6ad12de7e4a616280f8fc94d185e24c7bc667bf8c1f662425c64cf05f6ec9accd914ac32e1039688d09629b920329ad85354c8
data/.gitignore CHANGED
@@ -12,6 +12,7 @@
  .rspec_status
 
  *.swp
+ *.bundle
  .DS_Store
  .ruby-version
  /spec/dump_dbl.t
data/CHANGELOG.md CHANGED
@@ -1,3 +1,8 @@
+ # 0.9.0
+ ## Breaking changes
+ - Decide to introduce Ruby extensions for improving performance.
+ - Fix to find split point on decision tree estimators using extension modules.
+
  # 0.8.4
  - Remove unused parameter on Nadam.
  - Fix condition to stop growing tree about decision tree.
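In practical terms, the first breaking change means the gem now compiles a native C extension at install time (so a working C toolchain becomes a requirement), and the second moves the split-point search of the decision tree estimators into that extension. A quick sanity check after upgrading (a sketch; the module names come from the extension added later in this diff):

  require 'rumale'

  puts Rumale::VERSION  # => "0.9.0"
  # The tree estimators now mix in the native helper modules:
  puts Rumale::Tree::DecisionTreeClassifier.include?(Rumale::Tree::ExtDecisionTreeClassifier)  # => true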
data/Rakefile CHANGED
@@ -3,4 +3,12 @@ require 'rspec/core/rake_task'
 
  RSpec::Core::RakeTask.new(:spec)
 
- task :default => :spec
+ require 'rake/extensiontask'
+
+ task :build => :compile
+
+ Rake::ExtensionTask.new('rumale') do |ext|
+ ext.lib_dir = 'lib/rumale'
+ end
+
+ task :default => [:clobber, :compile, :spec]
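Rake::ExtensionTask comes from the rake-compiler gem (added as a development dependency further down in this diff); it compiles the sources under ext/rumale and, because lib_dir is set to 'lib/rumale', copies the built shared library to where `require 'rumale/rumale'` can find it. A typical local workflow under that setup (a sketch, not part of the diff):

  # bundle install            # installs rake-compiler with the other development dependencies
  # bundle exec rake compile  # builds ext/rumale and copies rumale.so / rumale.bundle into lib/rumale
  # bundle exec rake          # the new :default task runs clobber, compile, then the RSpec suite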
data/ext/rumale/extconf.rb ADDED
@@ -0,0 +1,3 @@
+ require 'mkmf'
+
+ create_makefile('rumale/rumale')
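mkmf's create_makefile('rumale/rumale') generates a Makefile whose product is a library named rumale installed under a rumale/ directory, which is what makes the single require added to lib/rumale.rb below resolve to the compiled binary (a sketch, assuming a default mkmf setup):

  # Loads lib/rumale/rumale.so (or .bundle on macOS) once it has been built;
  # this defines Rumale::Tree::ExtDecisionTreeClassifier and ExtDecisionTreeRegressor.
  require 'rumale/rumale'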
data/ext/rumale/rumale.c ADDED
@@ -0,0 +1,418 @@
+ #include "rumale.h"
+
+ VALUE
+ create_zero_vector(const long n_dimensions)
+ {
+ long i;
+ VALUE vec = rb_ary_new2(n_dimensions);
+
+ for (i = 0; i < n_dimensions; i++) {
+ rb_ary_store(vec, i, DBL2NUM(0));
+ }
+
+ return vec;
+ }
+
+ double
+ calc_gini_coef(VALUE histogram, const long n_elements)
+ {
+ long i;
+ double el;
+ double gini = 0.0;
+ const long n_classes = RARRAY_LEN(histogram);
+
+ for (i = 0; i < n_classes; i++) {
+ el = NUM2DBL(rb_ary_entry(histogram, i)) / n_elements;
+ gini += el * el;
+ }
+
+ return 1.0 - gini;
+ }
+
+ double
+ calc_entropy(VALUE histogram, const long n_elements)
+ {
+ long i;
+ double el;
+ double entropy = 0.0;
+ const long n_classes = RARRAY_LEN(histogram);
+
+ for (i = 0; i < n_classes; i++) {
+ el = NUM2DBL(rb_ary_entry(histogram, i)) / n_elements;
+ entropy += el * log(el + 1.0);
+ }
+
+ return -entropy;
+ }
+
+ VALUE
+ calc_mean_vec(VALUE sum_vec, const long n_elements)
+ {
+ long i;
+ const long n_dimensions = RARRAY_LEN(sum_vec);
+ VALUE mean_vec = rb_ary_new2(n_dimensions);
+
+ for (i = 0; i < n_dimensions; i++) {
+ rb_ary_store(mean_vec, i, DBL2NUM(NUM2DBL(rb_ary_entry(sum_vec, i)) / n_elements));
+ }
+
+ return mean_vec;
+ }
+
+ double
+ calc_vec_mae(VALUE vec_a, VALUE vec_b)
+ {
+ long i;
+ const long n_dimensions = RARRAY_LEN(vec_a);
+ double sum = 0.0;
+ double diff;
+
+ for (i = 0; i < n_dimensions; i++) {
+ diff = NUM2DBL(rb_ary_entry(vec_a, i)) - NUM2DBL(rb_ary_entry(vec_b, i));
+ sum += fabs(diff);
+ }
+
+ return sum / n_dimensions;
+ }
+
+ double
+ calc_vec_mse(VALUE vec_a, VALUE vec_b)
+ {
+ long i;
+ const long n_dimensions = RARRAY_LEN(vec_a);
+ double sum = 0.0;
+ double diff;
+
+ for (i = 0; i < n_dimensions; i++) {
+ diff = NUM2DBL(rb_ary_entry(vec_a, i)) - NUM2DBL(rb_ary_entry(vec_b, i));
+ sum += diff * diff;
+ }
+
+ return sum / n_dimensions;
+ }
+
+ double
+ calc_mae(VALUE target_vecs, VALUE sum_vec)
+ {
+ long i;
+ const long n_elements = RARRAY_LEN(target_vecs);
+ double sum = 0.0;
+ VALUE mean_vec = calc_mean_vec(sum_vec, n_elements);
+
+ for (i = 0; i < n_elements; i++) {
+ sum += calc_vec_mae(rb_ary_entry(target_vecs, i), mean_vec);
+ }
+
+ return sum / n_elements;
+ }
+
+ double
+ calc_mse(VALUE target_vecs, VALUE sum_vec)
+ {
+ long i;
+ const long n_elements = RARRAY_LEN(target_vecs);
+ double sum = 0.0;
+ VALUE mean_vec = calc_mean_vec(sum_vec, n_elements);
+
+ for (i = 0; i < n_elements; i++) {
+ sum += calc_vec_mse(rb_ary_entry(target_vecs, i), mean_vec);
+ }
+
+ return sum / n_elements;
+ }
+
+ double
+ calc_impurity_cls(VALUE criterion, VALUE histogram, const long n_elements)
+ {
+ if (strcmp(StringValuePtr(criterion), "entropy") == 0) {
+ return calc_entropy(histogram, n_elements);
+ }
+ return calc_gini_coef(histogram, n_elements);
+ }
+
+ double
+ calc_impurity_reg(VALUE criterion, VALUE target_vecs, VALUE sum_vec)
+ {
+ if (strcmp(StringValuePtr(criterion), "mae") == 0) {
+ return calc_mae(target_vecs, sum_vec);
+ }
+ return calc_mse(target_vecs, sum_vec);
+ }
+
+ void
+ increment_histogram(VALUE histogram, const long bin_id)
+ {
+ const double updated = NUM2DBL(rb_ary_entry(histogram, bin_id)) + 1;
+ rb_ary_store(histogram, bin_id, DBL2NUM(updated));
+ }
+
+ void
+ decrement_histogram(VALUE histogram, const long bin_id)
+ {
+ const double updated = NUM2DBL(rb_ary_entry(histogram, bin_id)) - 1;
+ rb_ary_store(histogram, bin_id, DBL2NUM(updated));
+ }
+
+ void
+ add_sum_vec(VALUE sum_vec, VALUE target)
+ {
+ long i;
+ const long n_dimensions = RARRAY_LEN(sum_vec);
+ double el;
+
+ for (i = 0; i < n_dimensions; i++) {
+ el = NUM2DBL(rb_ary_entry(sum_vec, i)) + NUM2DBL(rb_ary_entry(target, i));
+ rb_ary_store(sum_vec, i, DBL2NUM(el));
+ }
+ }
+
+ void
+ sub_sum_vec(VALUE sum_vec, VALUE target)
+ {
+ long i;
+ const long n_dimensions = RARRAY_LEN(sum_vec);
+ double el;
+
+ for (i = 0; i < n_dimensions; i++) {
+ el = NUM2DBL(rb_ary_entry(sum_vec, i)) - NUM2DBL(rb_ary_entry(target, i));
+ rb_ary_store(sum_vec, i, DBL2NUM(el));
+ }
+ }
+
+ /**
+ * @!visibility private
+ * Find for split point with maximum information gain.
+ *
+ * @overload find_split_params(criterion, impurity, sorted_features, sorted_labels, uniqed_features, n_classes) -> Array<Float>
+ *
+ * @param criterion [String] The function to evaluate spliting point. Supported criteria are 'gini' and 'entropy'.
+ * @param impurity [Float] The impurity of whole dataset.
+ * @param sorted_features [Numo::DFloat] (shape: [n_samples]) The feature values sorted in ascending order.
+ * @param sorted_labels [Numo::Int32] (shape: [n_labels]) The labels sorted according to feature values.
+ * @param uniqed_features [Numo::DFloat] (shape: [n_uniqed_features]) The unique feature values.
+ * @param n_classes [Integer] The number of classes.
+ * @return [Float] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
+ */
+ static VALUE
+ find_split_params_cls(VALUE self, VALUE criterion, VALUE whole_impurity, VALUE sorted_f, VALUE sorted_y, VALUE uniqed_f, VALUE n_classes_)
+ {
+ long i;
+ long curr_pos;
+ long next_pos;
+ long n_l_elements;
+ long n_r_elements;
+ const long n_classes = NUM2LONG(n_classes_);
+ const long n_elements = RARRAY_LEN(sorted_f);
+ const long n_uniq_elements = RARRAY_LEN(uniqed_f);
+ const double w_impurity = NUM2DBL(whole_impurity);
+ double l_impurity;
+ double r_impurity;
+ double gain;
+ double curr_el;
+ double next_el;
+ VALUE l_histogram = create_zero_vector(n_classes);
+ VALUE r_histogram = create_zero_vector(n_classes);
+ VALUE opt_params = rb_ary_new2(4);
+
+ /* Initialize optimal parameters. */
+ rb_ary_store(opt_params, 0, DBL2NUM(0)); /* left impurity */
+ rb_ary_store(opt_params, 1, DBL2NUM(w_impurity)); /* right impurity */
+ rb_ary_store(opt_params, 2, rb_ary_entry(uniqed_f, 0)); /* threshold */
+ rb_ary_store(opt_params, 3, DBL2NUM(0)); /* gain */
+
+ /* Initialize child node variables. */
+ n_l_elements = 0;
+ n_r_elements = n_elements;
+ for (i = 0; i < n_elements; i++) {
+ increment_histogram(r_histogram, NUM2LONG(rb_ary_entry(sorted_y, i)));
+ }
+
+ /* Find optimal parameters. */
+ for (curr_pos = 0, next_pos = 0; curr_pos < n_uniq_elements - 1; curr_pos++) {
+ /* Find new split point. */
+ curr_el = NUM2DBL(rb_ary_entry(uniqed_f, curr_pos));
+ next_el = NUM2DBL(rb_ary_entry(sorted_f, next_pos));
+ while (next_pos < n_elements && next_el <= curr_el) {
+ increment_histogram(l_histogram, NUM2LONG(rb_ary_entry(sorted_y, next_pos)));
+ n_l_elements++;
+ decrement_histogram(r_histogram, NUM2LONG(rb_ary_entry(sorted_y, next_pos)));
+ n_r_elements--;
+ next_el = NUM2DBL(rb_ary_entry(sorted_f, ++next_pos));
+ }
+ /* Calculate gain of new split. */
+ l_impurity = calc_impurity_cls(criterion, l_histogram, n_l_elements);
+ r_impurity = calc_impurity_cls(criterion, r_histogram, n_r_elements);
+ gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / n_elements;
+ /* Update optimal parameters. */
+ if (gain > NUM2DBL(rb_ary_entry(opt_params, 3))) {
+ rb_ary_store(opt_params, 0, DBL2NUM(l_impurity));
+ rb_ary_store(opt_params, 1, DBL2NUM(r_impurity));
+ rb_ary_store(opt_params, 2, DBL2NUM(0.5 * (curr_el + next_el)));
+ rb_ary_store(opt_params, 3, DBL2NUM(gain));
+ }
+ }
+
+ return opt_params;
+ }
+
+ /**
+ * @!visibility private
+ * Find for split point with maximum information gain.
+ *
+ * @overload find_split_params(criterion, impurity, sorted_features, sorted_targets, uniqed_features) -> Array<Float>
+ *
+ * @param criterion [String] The function to evaluate spliting point. Supported criteria are 'mae' and 'mse'.
+ * @param impurity [Float] The impurity of whole dataset.
+ * @param sorted_features [Numo::DFloat] (shape: [n_samples]) The feature values sorted in ascending order.
+ * @param sorted_targets [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values sorted according to feature values.
+ * @param uniqed_features [Numo::DFloat] (shape: [n_uniqed_features]) The unique feature values.
+ * @return [Float] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
+ */
+ static VALUE
+ find_split_params_reg(VALUE self, VALUE criterion, VALUE whole_impurity, VALUE sorted_f, VALUE sorted_y, VALUE uniqed_f)
+ {
+ long i;
+ long curr_pos;
+ long next_pos;
+ long n_l_elements;
+ long n_r_elements;
+ const long n_elements = RARRAY_LEN(sorted_f);
+ const long n_uniq_elements = RARRAY_LEN(uniqed_f);
+ const long n_dimensions = RARRAY_LEN(rb_ary_entry(sorted_y, 0));
+ const double w_impurity = NUM2DBL(whole_impurity);
+ double l_impurity;
+ double r_impurity;
+ double gain;
+ double curr_el;
+ double next_el;
+ VALUE l_sum_vec = create_zero_vector(n_dimensions);
+ VALUE r_sum_vec = create_zero_vector(n_dimensions);
+ VALUE l_target_vecs = rb_ary_new();
+ VALUE r_target_vecs = rb_ary_new();
+ VALUE target;
+ VALUE opt_params = rb_ary_new2(4);
+
+ /* Initialize optimal parameters. */
+ rb_ary_store(opt_params, 0, DBL2NUM(0)); /* left impurity */
+ rb_ary_store(opt_params, 1, DBL2NUM(w_impurity)); /* right impurity */
+ rb_ary_store(opt_params, 2, rb_ary_entry(uniqed_f, 0)); /* threshold */
+ rb_ary_store(opt_params, 3, DBL2NUM(0)); /* gain */
+
+ /* Initialize child node variables. */
+ n_l_elements = 0;
+ n_r_elements = n_elements;
+ for (i = 0; i < n_elements; i++) {
+ target = rb_ary_entry(sorted_y, i);
+ add_sum_vec(r_sum_vec, target);
+ rb_ary_push(r_target_vecs, target);
+ }
+
+ /* Find optimal parameters. */
+ for (curr_pos = 0, next_pos = 0; curr_pos < n_uniq_elements - 1; curr_pos++) {
+ /* Find new split point. */
+ curr_el = NUM2DBL(rb_ary_entry(uniqed_f, curr_pos));
+ next_el = NUM2DBL(rb_ary_entry(sorted_f, next_pos));
+ while (next_pos < n_elements && next_el <= curr_el) {
+ target = rb_ary_entry(sorted_y, next_pos);
+ add_sum_vec(l_sum_vec, target);
+ rb_ary_push(l_target_vecs, target);
+ n_l_elements++;
+ sub_sum_vec(r_sum_vec, target);
+ rb_ary_shift(r_target_vecs);
+ n_r_elements--;
+ next_el = NUM2DBL(rb_ary_entry(sorted_f, ++next_pos));
+ }
+ /* Calculate gain of new split. */
+ l_impurity = calc_impurity_reg(criterion, l_target_vecs, l_sum_vec);
+ r_impurity = calc_impurity_reg(criterion, r_target_vecs, r_sum_vec);
+ gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / n_elements;
+ /* Update optimal parameters. */
+ if (gain > NUM2DBL(rb_ary_entry(opt_params, 3))) {
+ rb_ary_store(opt_params, 0, DBL2NUM(l_impurity));
+ rb_ary_store(opt_params, 1, DBL2NUM(r_impurity));
+ rb_ary_store(opt_params, 2, DBL2NUM(0.5 * (curr_el + next_el)));
+ rb_ary_store(opt_params, 3, DBL2NUM(gain));
+ }
+ }
+
+ return opt_params;
+ }
+
+ /**
+ * @!visibility private
+ * Calculate impurity based on criterion.
+ *
+ * @overload node_impurity(criterion, y, n_classes) -> Float
+ *
+ * @param criterion [String] The function to calculate impurity. Supported criteria are 'gini' and 'entropy'.
+ * @param y [Numo::Int32] (shape: [n_samples]) The labels.
+ * @param n_classes [Integer] The number of classes.
+ * @return [Float] impurity
+ */
+ static VALUE
+ node_impurity_cls(VALUE self, VALUE criterion, VALUE y, VALUE n_classes)
+ {
+ long i;
+ const long n_elements = RARRAY_LEN(y);
+ VALUE histogram = create_zero_vector(NUM2LONG(n_classes));
+
+ for (i = 0; i < n_elements; i++) {
+ increment_histogram(histogram, NUM2LONG(rb_ary_entry(y, i)));
+ }
+
+ return DBL2NUM(calc_impurity_cls(criterion, histogram, n_elements));
+ }
+
+ /**
+ * @!visibility private
+ * Calculate impurity based on criterion.
+ *
+ * @overload node_impurity(criterion, y) -> Float
+ *
+ * @param criterion [String] The function to calculate impurity. Supported criteria are 'mae' and 'mse'.
+ * @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The taget values.
+ * @return [Float] impurity
+ */
+ static VALUE
+ node_impurity_reg(VALUE self, VALUE criterion, VALUE y)
+ {
+ long i;
+ const long n_elements = RARRAY_LEN(y);
+ const long n_dimensions = RARRAY_LEN(rb_ary_entry(y, 0));
+ VALUE sum_vec = create_zero_vector(n_dimensions);
+ VALUE target_vecs = rb_ary_new();
+ VALUE target;
+
+ for (i = 0; i < n_elements; i++) {
+ target = rb_ary_entry(y, i);
+ add_sum_vec(sum_vec, target);
+ rb_ary_push(target_vecs, target);
+ }
+
+ return DBL2NUM(calc_impurity_reg(criterion, target_vecs, sum_vec));
+ }
+
+ void Init_rumale(void)
+ {
+ VALUE mRumale = rb_define_module("Rumale");
+ VALUE mTree = rb_define_module_under(mRumale, "Tree");
+ /**
+ * Document-module: Rumale::Tree::ExtDecisionTreeClassifier
+ * @!visibility private
+ * The mixin module consisting of extension method for DecisionTreeClassifier class.
+ * This module is used internally.
+ */
+ VALUE mExtDTreeCls = rb_define_module_under(mTree, "ExtDecisionTreeClassifier");
+ /**
+ * Document-module: Rumale::Tree::ExtDecisionTreeRegressor
+ * @!visibility private
+ * The mixin module consisting of extension method for DecisionTreeRegressor class.
+ * This module is used internally.
+ */
+ VALUE mExtDTreeReg = rb_define_module_under(mTree, "ExtDecisionTreeRegressor");
+
+ rb_define_method(mExtDTreeCls, "find_split_params", find_split_params_cls, 6);
+ rb_define_method(mExtDTreeReg, "find_split_params", find_split_params_reg, 5);
+ rb_define_method(mExtDTreeCls, "node_impurity", node_impurity_cls, 3);
+ rb_define_method(mExtDTreeReg, "node_impurity", node_impurity_reg, 2);
+ }
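For reference, the quantities these helpers compute, for a node with n samples, d outputs, class proportions p_c, per-output target means ȳ_j, and a candidate split into children of sizes n_L and n_R (candidate thresholds are midpoints of consecutive unique feature values, 0.5 * (curr_el + next_el)). Note that calc_entropy uses log(p_c + 1) rather than the textbook log p_c, mirroring the 0.8.x pure-Ruby implementation, so it is a smoothed variant of Shannon entropy:

  \mathrm{gini} = 1 - \sum_{c} p_c^2, \qquad \mathrm{entropy} = -\sum_{c} p_c \log(p_c + 1)

  \mathrm{mse} = \frac{1}{n} \sum_{i=1}^{n} \frac{1}{d} \sum_{j=1}^{d} (y_{ij} - \bar{y}_j)^2, \qquad \mathrm{mae} = \frac{1}{n} \sum_{i=1}^{n} \frac{1}{d} \sum_{j=1}^{d} \lvert y_{ij} - \bar{y}_j \rvert

  \mathrm{gain} = I_{\mathrm{parent}} - \frac{n_L I_L + n_R I_R}{n}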
data/ext/rumale/rumale.h ADDED
@@ -0,0 +1,9 @@
+ #ifndef RUMALE_H
+ #define RUMALE_H 1
+
+ #include <math.h>
+ #include <string.h>
+
+ #include "ruby.h"
+
+ #endif /* RUMALE_H */
data/lib/rumale.rb CHANGED
@@ -2,6 +2,8 @@
 
  require 'numo/narray'
 
+ require 'rumale/rumale'
+
  require 'rumale/version'
  require 'rumale/validation'
  require 'rumale/values'
data/lib/rumale/tree/base_decision_tree.rb CHANGED
@@ -86,13 +86,16 @@ module Rumale
  return put_leaf(node, y) if stop_growing?(y)
 
  # calculate optimal parameters.
- feature_id, threshold, left_ids, right_ids, left_impurity, right_impurity, gain =
- rand_ids(n_features).map { |f_id| [f_id, *best_split(x[true, f_id], y, whole_impurity)] }.max_by(&:last)
+ feature_id, left_ids, right_ids, left_imp, right_imp, threshold, gain = rand_ids(n_features).map do |fid|
+ ft = x[true, fid]
+ limp, rimp, th, ga = best_split(ft, y, whole_impurity)
+ [fid, ft.le(th).where, ft.gt(th).where, limp, rimp, th, ga]
+ end.max_by(&:last)
 
  return put_leaf(node, y) if gain.nil? || gain.zero?
 
- node.left = grow_node(depth + 1, x[left_ids, true], y[left_ids, true], left_impurity)
- node.right = grow_node(depth + 1, x[right_ids, true], y[right_ids, true], right_impurity)
+ node.left = grow_node(depth + 1, x[left_ids, true], y[left_ids, true], left_imp)
+ node.right = grow_node(depth + 1, x[right_ids, true], y[right_ids, true], right_imp)
 
  return put_leaf(node, y) if node.left.nil? && node.right.nil?
 
@@ -114,22 +117,11 @@ module Rumale
  [*0...n].sample(@params[:max_features], random: @rng)
  end
 
- def best_split(features, targets, whole_impurity)
- n_samples = targets.shape[0]
- features.to_a.uniq.sort.each_cons(2).map do |l, r|
- threshold = 0.5 * (l + r)
- left_ids = features.le(threshold).where
- right_ids = features.gt(threshold).where
- left_impurity = impurity(targets[left_ids, true])
- right_impurity = impurity(targets[right_ids, true])
- gain = whole_impurity -
- left_impurity * left_ids.size.fdiv(n_samples) -
- right_impurity * right_ids.size.fdiv(n_samples)
- [threshold, left_ids, right_ids, left_impurity, right_impurity, gain]
- end.max_by(&:last)
+ def best_split(_features, _y, _impurity)
+ raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
  end
 
- def impurity(_targets)
+ def impurity(_y)
  raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
  end
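grow_node now expects each subclass's best_split to return [left_impurity, right_impurity, threshold, gain] and derives the left/right sample indices itself with Numo boolean masks, instead of receiving precomputed index arrays from best_split as before. A small sketch of that masking step on a toy feature column (values made up for illustration):

  require 'numo/narray'

  ft = Numo::DFloat[0.2, 1.5, 0.7, 2.0]  # one feature column, x[true, fid]
  th = 1.0                               # threshold returned by best_split

  left_ids  = ft.le(th).where  # => Numo::Int32[0, 2] -- samples routed to the left child
  right_ids = ft.gt(th).where  # => Numo::Int32[1, 3] -- samples routed to the right child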
data/lib/rumale/tree/decision_tree_classifier.rb CHANGED
@@ -1,5 +1,6 @@
  # frozen_string_literal: true
 
+ require 'rumale/rumale'
  require 'rumale/tree/base_decision_tree'
  require 'rumale/base/classifier'
 
@@ -16,6 +17,7 @@ module Rumale
  #
  class DecisionTreeClassifier < BaseDecisionTree
  include Base::Classifier
+ include ExtDecisionTreeClassifier
 
  # Return the class labels.
  # @return [Numo::Int32] (size: n_classes)
@@ -39,7 +41,7 @@
 
  # Create a new classifier with decision tree algorithm.
  #
- # @param criterion [String] The function to evalue spliting point. Supported criteria are 'gini' and 'entropy'.
+ # @param criterion [String] The function to evaluate spliting point. Supported criteria are 'gini' and 'entropy'.
  # @param max_depth [Integer] The maximum depth of the tree.
  # If nil is given, decision tree grows without concern for depth.
  # @param max_leaf_nodes [Integer] The maximum number of leaves on decision tree.
@@ -89,7 +91,7 @@
  # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
  def predict(x)
  check_sample_array(x)
- @leaf_labels[apply(x)]
+ @leaf_labels[apply(x)].dup
  end
 
  # Predict probability for samples.
@@ -138,7 +140,7 @@
  end
 
  def stop_growing?(y)
- y.flatten.to_a.uniq.size == 1
+ y[true, 0].to_a.uniq.size == 1
  end
 
  def put_leaf(node, y)
@@ -150,13 +152,17 @@
  node
  end
 
+ def best_split(features, y, whole_impurity)
+ order = features.sort_index
+ sorted_f = features[order].to_a
+ sorted_y = y[order, true].to_a.flatten
+ n_classes = @classes.size
+ find_split_params(@params[:criterion], whole_impurity, sorted_f, sorted_y, sorted_f.uniq, n_classes)
+ end
+
  def impurity(y)
- posterior_probs = y.flatten.bincount / y.size.to_f
- if @params[:criterion] == 'entropy'
- -(posterior_probs * Numo::NMath.log(posterior_probs + 1)).sum
- else
- 1.0 - (posterior_probs * posterior_probs).sum
- end
+ n_classes = @classes.size
+ node_impurity(@params[:criterion], y[true, 0].to_a, n_classes)
  end
  end
  end
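End to end, nothing changes for callers: the classifier keeps its fit/predict interface, and best_split/impurity now delegate to the native find_split_params/node_impurity mixed in above. A minimal usage sketch on toy data (random_seed is assumed here from the wider Rumale API; it is not shown in this hunk):

  require 'rumale'

  x = Numo::DFloat[[0.1, 0.9], [0.2, 0.8], [0.9, 0.1], [0.8, 0.2]]
  y = Numo::Int32[0, 0, 1, 1]

  tree = Rumale::Tree::DecisionTreeClassifier.new(criterion: 'gini', max_depth: 2, random_seed: 1)
  tree.fit(x, y)
  p tree.predict(x)        # => Numo::Int32[0, 0, 1, 1] (labels, now returned as a defensive copy)
  p tree.predict_proba(x)  # per-class probabilities for each sample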
data/lib/rumale/tree/decision_tree_regressor.rb CHANGED
@@ -1,5 +1,6 @@
  # frozen_string_literal: true
 
+ require 'rumale/rumale'
  require 'rumale/tree/base_decision_tree'
  require 'rumale/base/regressor'
 
@@ -16,6 +17,7 @@ module Rumale
  #
  class DecisionTreeRegressor < BaseDecisionTree
  include Base::Regressor
+ include ExtDecisionTreeRegressor
 
  # Return the importance for each feature.
  # @return [Numo::DFloat] (size: n_features)
@@ -35,7 +37,7 @@
 
  # Create a new regressor with decision tree algorithm.
  #
- # @param criterion [String] The function to evalue spliting point. Supported criteria are 'mae' and 'mse'.
+ # @param criterion [String] The function to evaluate spliting point. Supported criteria are 'mae' and 'mse'.
  # @param max_depth [Integer] The maximum depth of the tree.
  # If nil is given, decision tree grows without concern for depth.
  # @param max_leaf_nodes [Integer] The maximum number of leaves on decision tree.
@@ -84,7 +86,7 @@
  # @return [Numo::DFloat] (shape: [n_samples, n_outputs]) Predicted values per sample.
  def predict(x)
  check_sample_array(x)
- @leaf_values.shape[1].nil? ? @leaf_values[apply(x)] : @leaf_values[apply(x), true]
+ @leaf_values.shape[1].nil? ? @leaf_values[apply(x)].dup : @leaf_values[apply(x), true].dup
  end
 
  # Dump marshal data.
@@ -123,12 +125,15 @@
  node
  end
 
- def impurity(values)
- if @params[:criterion] == 'mae'
- (values - values.mean(0)).abs.mean
- else
- ((values - values.mean(0))**2).mean
- end
+ def best_split(features, y, whole_impurity)
+ order = features.sort_index
+ sorted_f = features[order].to_a
+ sorted_y = y[order, true].to_a
+ find_split_params(@params[:criterion], whole_impurity, sorted_f, sorted_y, sorted_f.uniq)
+ end
+
+ def impurity(y)
+ node_impurity(@params[:criterion], y.to_a)
  end
  end
  end
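The regressor mirrors the classifier: best_split sorts the feature column, hands the correspondingly sorted targets to the native find_split_params, and impurity delegates to node_impurity with the 'mae' or 'mse' criterion. A matching usage sketch (toy single-output data; random_seed assumed as above):

  require 'rumale'

  x = Numo::DFloat[[0.1], [0.2], [0.9], [1.0]]
  y = Numo::DFloat[1.0, 1.1, 2.0, 2.1]

  reg = Rumale::Tree::DecisionTreeRegressor.new(criterion: 'mse', max_depth: 2, random_seed: 1)
  reg.fit(x, y)
  p reg.predict(x)  # => Numo::DFloat of per-sample predictions (also returned as a copy now)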
data/lib/rumale/version.rb CHANGED
@@ -3,5 +3,5 @@
  # Rumale is a machine learning library in Ruby.
  module Rumale
  # The version of Rumale you are using.
- VERSION = '0.8.4'
+ VERSION = '0.9.0'
  end
data/rumale.gemspec CHANGED
@@ -29,6 +29,7 @@ MSG
  spec.bindir = 'exe'
  spec.executables = spec.files.grep(%r{^exe/}) { |f| File.basename(f) }
  spec.require_paths = ['lib']
+ spec.extensions = ['ext/rumale/extconf.rb']
 
  spec.required_ruby_version = '>= 2.3'
 
@@ -37,5 +38,6 @@ MSG
  spec.add_development_dependency 'bundler', '>= 1.16'
  spec.add_development_dependency 'coveralls', '~> 0.8'
  spec.add_development_dependency 'rake', '~> 12.0'
+ spec.add_development_dependency 'rake-compiler'
  spec.add_development_dependency 'rspec', '~> 3.0'
  end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: rumale
  version: !ruby/object:Gem::Version
- version: 0.8.4
+ version: 0.9.0
  platform: ruby
  authors:
  - yoshoku
  autorequire:
  bindir: exe
  cert_chain: []
- date: 2019-04-20 00:00:00.000000000 Z
+ date: 2019-04-22 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
  name: numo-narray
@@ -66,6 +66,20 @@ dependencies:
  - - "~>"
  - !ruby/object:Gem::Version
  version: '12.0'
+ - !ruby/object:Gem::Dependency
+ name: rake-compiler
+ requirement: !ruby/object:Gem::Requirement
+ requirements:
+ - - ">="
+ - !ruby/object:Gem::Version
+ version: '0'
+ type: :development
+ prerelease: false
+ version_requirements: !ruby/object:Gem::Requirement
+ requirements:
+ - - ">="
+ - !ruby/object:Gem::Version
+ version: '0'
  - !ruby/object:Gem::Dependency
  name: rspec
  requirement: !ruby/object:Gem::Requirement
@@ -90,7 +104,8 @@ description: |
  email:
  - yoshoku@outlook.com
  executables: []
- extensions: []
+ extensions:
+ - ext/rumale/extconf.rb
  extra_rdoc_files: []
  files:
  - ".coveralls.yml"
@@ -107,6 +122,9 @@ files:
  - Rakefile
  - bin/console
  - bin/setup
+ - ext/rumale/extconf.rb
+ - ext/rumale/rumale.c
+ - ext/rumale/rumale.h
  - lib/rumale.rb
  - lib/rumale/base/base_estimator.rb
  - lib/rumale/base/classifier.rb