rumale 0.8.4 → 0.9.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA1:
3
- metadata.gz: 8e895ca9462569e1ec2b3b9c1cb985aebd7d4b19
4
- data.tar.gz: 9374969bd955cada6ba54ec6edf1d7bc174a5102
3
+ metadata.gz: a2dfbc60c9d47e741fc91497f8c58ade390e6c8f
4
+ data.tar.gz: d4cbc26e0d81fbe0de5e83d785cc836e9a5b2099
5
5
  SHA512:
6
- metadata.gz: 7e9eadf3404e74ee887007a1fb1df4e2b933f10394c1a61afdd8400492afcc4fbb40479c453983d2349db5d5c7cd52ab280a9ef28aa26d91b19cb41253bdb233
7
- data.tar.gz: 674bf164a3f1be2971fa2f882999ed34a22773d15674d6c2c24da66af888be855ffd3daa1d66a08ee7ec4bc99c87e864ffe61ebb7d6356dfae78a0a20e97c3a9
6
+ metadata.gz: 7f2b4b8ba5d7511215a2e850add19f0942cbff4157a8373eba1950c0eac9fcd0e44925d3a88b2a709c0308ef4c03cca44c501b710f4a22dc4dd573e6866d94dc
7
+ data.tar.gz: 4630710eef59af88274e9a411a6ad12de7e4a616280f8fc94d185e24c7bc667bf8c1f662425c64cf05f6ec9accd914ac32e1039688d09629b920329ad85354c8
data/.gitignore CHANGED
@@ -12,6 +12,7 @@
12
12
  .rspec_status
13
13
 
14
14
  *.swp
15
+ *.bundle
15
16
  .DS_Store
16
17
  .ruby-version
17
18
  /spec/dump_dbl.t
@@ -1,3 +1,8 @@
1
+ # 0.9.0
2
+ ## Breaking changes
3
+ - Introduce Ruby extensions to improve performance.
4
+ - Change decision tree estimators to find the split point using extension modules.
5
+
1
6
  # 0.8.4
2
7
  - Remove unused parameter on Nadam.
3
8
  - Fix condition to stop growing tree about decision tree.
data/Rakefile CHANGED
@@ -3,4 +3,12 @@ require 'rspec/core/rake_task'
3
3
 
4
4
  RSpec::Core::RakeTask.new(:spec)
5
5
 
6
- task :default => :spec
6
+ require 'rake/extensiontask'
7
+
8
+ task :build => :compile
9
+
10
+ Rake::ExtensionTask.new('rumale') do |ext|
11
+ ext.lib_dir = 'lib/rumale'
12
+ end
13
+
14
+ task :default => [:clobber, :compile, :spec]
@@ -0,0 +1,3 @@
1
+ require 'mkmf'
2
+
3
+ create_makefile('rumale/rumale')
@@ -0,0 +1,418 @@
1
+ #include "rumale.h"
2
+
3
+ VALUE
4
+ create_zero_vector(const long n_dimensions)
5
+ {
6
+ long i;
7
+ VALUE vec = rb_ary_new2(n_dimensions);
8
+
9
+ for (i = 0; i < n_dimensions; i++) {
10
+ rb_ary_store(vec, i, DBL2NUM(0));
11
+ }
12
+
13
+ return vec;
14
+ }
15
+
16
+ double
17
+ calc_gini_coef(VALUE histogram, const long n_elements)
18
+ {
19
+ long i;
20
+ double el;
21
+ double gini = 0.0;
22
+ const long n_classes = RARRAY_LEN(histogram);
23
+
24
+ for (i = 0; i < n_classes; i++) {
25
+ el = NUM2DBL(rb_ary_entry(histogram, i)) / n_elements;
26
+ gini += el * el;
27
+ }
28
+
29
+ return 1.0 - gini;
30
+ }
31
+
32
+ double
33
+ calc_entropy(VALUE histogram, const long n_elements)
34
+ {
35
+ long i;
36
+ double el;
37
+ double entropy = 0.0;
38
+ const long n_classes = RARRAY_LEN(histogram);
39
+
40
+ for (i = 0; i < n_classes; i++) {
41
+ el = NUM2DBL(rb_ary_entry(histogram, i)) / n_elements;
42
+ entropy += el * log(el + 1.0);
43
+ }
44
+
45
+ return -entropy;
46
+ }
47
+
48
+ VALUE
49
+ calc_mean_vec(VALUE sum_vec, const long n_elements)
50
+ {
51
+ long i;
52
+ const long n_dimensions = RARRAY_LEN(sum_vec);
53
+ VALUE mean_vec = rb_ary_new2(n_dimensions);
54
+
55
+ for (i = 0; i < n_dimensions; i++) {
56
+ rb_ary_store(mean_vec, i, DBL2NUM(NUM2DBL(rb_ary_entry(sum_vec, i)) / n_elements));
57
+ }
58
+
59
+ return mean_vec;
60
+ }
61
+
62
+ double
63
+ calc_vec_mae(VALUE vec_a, VALUE vec_b)
64
+ {
65
+ long i;
66
+ const long n_dimensions = RARRAY_LEN(vec_a);
67
+ double sum = 0.0;
68
+ double diff;
69
+
70
+ for (i = 0; i < n_dimensions; i++) {
71
+ diff = NUM2DBL(rb_ary_entry(vec_a, i)) - NUM2DBL(rb_ary_entry(vec_b, i));
72
+ sum += fabs(diff);
73
+ }
74
+
75
+ return sum / n_dimensions;
76
+ }
77
+
78
+ double
79
+ calc_vec_mse(VALUE vec_a, VALUE vec_b)
80
+ {
81
+ long i;
82
+ const long n_dimensions = RARRAY_LEN(vec_a);
83
+ double sum = 0.0;
84
+ double diff;
85
+
86
+ for (i = 0; i < n_dimensions; i++) {
87
+ diff = NUM2DBL(rb_ary_entry(vec_a, i)) - NUM2DBL(rb_ary_entry(vec_b, i));
88
+ sum += diff * diff;
89
+ }
90
+
91
+ return sum / n_dimensions;
92
+ }
93
+
94
+ double
95
+ calc_mae(VALUE target_vecs, VALUE sum_vec)
96
+ {
97
+ long i;
98
+ const long n_elements = RARRAY_LEN(target_vecs);
99
+ double sum = 0.0;
100
+ VALUE mean_vec = calc_mean_vec(sum_vec, n_elements);
101
+
102
+ for (i = 0; i < n_elements; i++) {
103
+ sum += calc_vec_mae(rb_ary_entry(target_vecs, i), mean_vec);
104
+ }
105
+
106
+ return sum / n_elements;
107
+ }
108
+
109
+ double
110
+ calc_mse(VALUE target_vecs, VALUE sum_vec)
111
+ {
112
+ long i;
113
+ const long n_elements = RARRAY_LEN(target_vecs);
114
+ double sum = 0.0;
115
+ VALUE mean_vec = calc_mean_vec(sum_vec, n_elements);
116
+
117
+ for (i = 0; i < n_elements; i++) {
118
+ sum += calc_vec_mse(rb_ary_entry(target_vecs, i), mean_vec);
119
+ }
120
+
121
+ return sum / n_elements;
122
+ }
123
+
124
+ double
125
+ calc_impurity_cls(VALUE criterion, VALUE histogram, const long n_elements)
126
+ {
127
+ if (strcmp(StringValuePtr(criterion), "entropy") == 0) {
128
+ return calc_entropy(histogram, n_elements);
129
+ }
130
+ return calc_gini_coef(histogram, n_elements);
131
+ }
132
+
133
+ double
134
+ calc_impurity_reg(VALUE criterion, VALUE target_vecs, VALUE sum_vec)
135
+ {
136
+ if (strcmp(StringValuePtr(criterion), "mae") == 0) {
137
+ return calc_mae(target_vecs, sum_vec);
138
+ }
139
+ return calc_mse(target_vecs, sum_vec);
140
+ }
141
+
142
+ void
143
+ increment_histogram(VALUE histogram, const long bin_id)
144
+ {
145
+ const double updated = NUM2DBL(rb_ary_entry(histogram, bin_id)) + 1;
146
+ rb_ary_store(histogram, bin_id, DBL2NUM(updated));
147
+ }
148
+
149
+ void
150
+ decrement_histogram(VALUE histogram, const long bin_id)
151
+ {
152
+ const double updated = NUM2DBL(rb_ary_entry(histogram, bin_id)) - 1;
153
+ rb_ary_store(histogram, bin_id, DBL2NUM(updated));
154
+ }
155
+
156
+ void
157
+ add_sum_vec(VALUE sum_vec, VALUE target)
158
+ {
159
+ long i;
160
+ const long n_dimensions = RARRAY_LEN(sum_vec);
161
+ double el;
162
+
163
+ for (i = 0; i < n_dimensions; i++) {
164
+ el = NUM2DBL(rb_ary_entry(sum_vec, i)) + NUM2DBL(rb_ary_entry(target, i));
165
+ rb_ary_store(sum_vec, i, DBL2NUM(el));
166
+ }
167
+ }
168
+
169
+ void
170
+ sub_sum_vec(VALUE sum_vec, VALUE target)
171
+ {
172
+ long i;
173
+ const long n_dimensions = RARRAY_LEN(sum_vec);
174
+ double el;
175
+
176
+ for (i = 0; i < n_dimensions; i++) {
177
+ el = NUM2DBL(rb_ary_entry(sum_vec, i)) - NUM2DBL(rb_ary_entry(target, i));
178
+ rb_ary_store(sum_vec, i, DBL2NUM(el));
179
+ }
180
+ }
181
+
182
+ /**
183
+ * @!visibility private
184
+ * Find for split point with maximum information gain.
185
+ *
186
+ * @overload find_split_params(criterion, impurity, sorted_features, sorted_labels, uniqed_features, n_classes) -> Array<Float>
187
+ *
188
+ * @param criterion [String] The function to evaluate spliting point. Supported criteria are 'gini' and 'entropy'.
189
+ * @param impurity [Float] The impurity of whole dataset.
190
+ * @param sorted_features [Numo::DFloat] (shape: [n_samples]) The feature values sorted in ascending order.
191
+ * @param sorted_labels [Numo::Int32] (shape: [n_labels]) The labels sorted according to feature values.
192
+ * @param uniqed_features [Numo::DFloat] (shape: [n_uniqed_features]) The unique feature values.
193
+ * @param n_classes [Integer] The number of classes.
194
+ * @return [Float] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
195
+ */
196
+ static VALUE
197
+ find_split_params_cls(VALUE self, VALUE criterion, VALUE whole_impurity, VALUE sorted_f, VALUE sorted_y, VALUE uniqed_f, VALUE n_classes_)
198
+ {
199
+ long i;
200
+ long curr_pos;
201
+ long next_pos;
202
+ long n_l_elements;
203
+ long n_r_elements;
204
+ const long n_classes = NUM2LONG(n_classes_);
205
+ const long n_elements = RARRAY_LEN(sorted_f);
206
+ const long n_uniq_elements = RARRAY_LEN(uniqed_f);
207
+ const double w_impurity = NUM2DBL(whole_impurity);
208
+ double l_impurity;
209
+ double r_impurity;
210
+ double gain;
211
+ double curr_el;
212
+ double next_el;
213
+ VALUE l_histogram = create_zero_vector(n_classes);
214
+ VALUE r_histogram = create_zero_vector(n_classes);
215
+ VALUE opt_params = rb_ary_new2(4);
216
+
217
+ /* Initialize optimal parameters. */
218
+ rb_ary_store(opt_params, 0, DBL2NUM(0)); /* left impurity */
219
+ rb_ary_store(opt_params, 1, DBL2NUM(w_impurity)); /* right impurity */
220
+ rb_ary_store(opt_params, 2, rb_ary_entry(uniqed_f, 0)); /* threshold */
221
+ rb_ary_store(opt_params, 3, DBL2NUM(0)); /* gain */
222
+
223
+ /* Initialize child node variables. */
224
+ n_l_elements = 0;
225
+ n_r_elements = n_elements;
226
+ for (i = 0; i < n_elements; i++) {
227
+ increment_histogram(r_histogram, NUM2LONG(rb_ary_entry(sorted_y, i)));
228
+ }
229
+
230
+ /* Find optimal parameters. */
231
+ for (curr_pos = 0, next_pos = 0; curr_pos < n_uniq_elements - 1; curr_pos++) {
232
+ /* Find new split point. */
233
+ curr_el = NUM2DBL(rb_ary_entry(uniqed_f, curr_pos));
234
+ next_el = NUM2DBL(rb_ary_entry(sorted_f, next_pos));
235
+ while (next_pos < n_elements && next_el <= curr_el) {
236
+ increment_histogram(l_histogram, NUM2LONG(rb_ary_entry(sorted_y, next_pos)));
237
+ n_l_elements++;
238
+ decrement_histogram(r_histogram, NUM2LONG(rb_ary_entry(sorted_y, next_pos)));
239
+ n_r_elements--;
240
+ next_el = NUM2DBL(rb_ary_entry(sorted_f, ++next_pos));
241
+ }
242
+ /* Calculate gain of new split. */
243
+ l_impurity = calc_impurity_cls(criterion, l_histogram, n_l_elements);
244
+ r_impurity = calc_impurity_cls(criterion, r_histogram, n_r_elements);
245
+ gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / n_elements;
246
+ /* Update optimal parameters. */
247
+ if (gain > NUM2DBL(rb_ary_entry(opt_params, 3))) {
248
+ rb_ary_store(opt_params, 0, DBL2NUM(l_impurity));
249
+ rb_ary_store(opt_params, 1, DBL2NUM(r_impurity));
250
+ rb_ary_store(opt_params, 2, DBL2NUM(0.5 * (curr_el + next_el)));
251
+ rb_ary_store(opt_params, 3, DBL2NUM(gain));
252
+ }
253
+ }
254
+
255
+ return opt_params;
256
+ }
257
+
258
+ /**
259
+ * @!visibility private
260
+ * Find for split point with maximum information gain.
261
+ *
262
+ * @overload find_split_params(criterion, impurity, sorted_features, sorted_targets, uniqed_features) -> Array<Float>
263
+ *
264
+ * @param criterion [String] The function to evaluate spliting point. Supported criteria are 'mae' and 'mse'.
265
+ * @param impurity [Float] The impurity of whole dataset.
266
+ * @param sorted_features [Numo::DFloat] (shape: [n_samples]) The feature values sorted in ascending order.
267
+ * @param sorted_targets [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values sorted according to feature values.
268
+ * @param uniqed_features [Numo::DFloat] (shape: [n_uniqed_features]) The unique feature values.
269
+ * @return [Float] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
270
+ */
271
+ static VALUE
272
+ find_split_params_reg(VALUE self, VALUE criterion, VALUE whole_impurity, VALUE sorted_f, VALUE sorted_y, VALUE uniqed_f)
273
+ {
274
+ long i;
275
+ long curr_pos;
276
+ long next_pos;
277
+ long n_l_elements;
278
+ long n_r_elements;
279
+ const long n_elements = RARRAY_LEN(sorted_f);
280
+ const long n_uniq_elements = RARRAY_LEN(uniqed_f);
281
+ const long n_dimensions = RARRAY_LEN(rb_ary_entry(sorted_y, 0));
282
+ const double w_impurity = NUM2DBL(whole_impurity);
283
+ double l_impurity;
284
+ double r_impurity;
285
+ double gain;
286
+ double curr_el;
287
+ double next_el;
288
+ VALUE l_sum_vec = create_zero_vector(n_dimensions);
289
+ VALUE r_sum_vec = create_zero_vector(n_dimensions);
290
+ VALUE l_target_vecs = rb_ary_new();
291
+ VALUE r_target_vecs = rb_ary_new();
292
+ VALUE target;
293
+ VALUE opt_params = rb_ary_new2(4);
294
+
295
+ /* Initialize optimal parameters. */
296
+ rb_ary_store(opt_params, 0, DBL2NUM(0)); /* left impurity */
297
+ rb_ary_store(opt_params, 1, DBL2NUM(w_impurity)); /* right impurity */
298
+ rb_ary_store(opt_params, 2, rb_ary_entry(uniqed_f, 0)); /* threshold */
299
+ rb_ary_store(opt_params, 3, DBL2NUM(0)); /* gain */
300
+
301
+ /* Initialize child node variables. */
302
+ n_l_elements = 0;
303
+ n_r_elements = n_elements;
304
+ for (i = 0; i < n_elements; i++) {
305
+ target = rb_ary_entry(sorted_y, i);
306
+ add_sum_vec(r_sum_vec, target);
307
+ rb_ary_push(r_target_vecs, target);
308
+ }
309
+
310
+ /* Find optimal parameters. */
311
+ for (curr_pos = 0, next_pos = 0; curr_pos < n_uniq_elements - 1; curr_pos++) {
312
+ /* Find new split point. */
313
+ curr_el = NUM2DBL(rb_ary_entry(uniqed_f, curr_pos));
314
+ next_el = NUM2DBL(rb_ary_entry(sorted_f, next_pos));
315
+ while (next_pos < n_elements && next_el <= curr_el) {
316
+ target = rb_ary_entry(sorted_y, next_pos);
317
+ add_sum_vec(l_sum_vec, target);
318
+ rb_ary_push(l_target_vecs, target);
319
+ n_l_elements++;
320
+ sub_sum_vec(r_sum_vec, target);
321
+ rb_ary_shift(r_target_vecs);
322
+ n_r_elements--;
323
+ next_el = NUM2DBL(rb_ary_entry(sorted_f, ++next_pos));
324
+ }
325
+ /* Calculate gain of new split. */
326
+ l_impurity = calc_impurity_reg(criterion, l_target_vecs, l_sum_vec);
327
+ r_impurity = calc_impurity_reg(criterion, r_target_vecs, r_sum_vec);
328
+ gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / n_elements;
329
+ /* Update optimal parameters. */
330
+ if (gain > NUM2DBL(rb_ary_entry(opt_params, 3))) {
331
+ rb_ary_store(opt_params, 0, DBL2NUM(l_impurity));
332
+ rb_ary_store(opt_params, 1, DBL2NUM(r_impurity));
333
+ rb_ary_store(opt_params, 2, DBL2NUM(0.5 * (curr_el + next_el)));
334
+ rb_ary_store(opt_params, 3, DBL2NUM(gain));
335
+ }
336
+ }
337
+
338
+ return opt_params;
339
+ }
340
+
341
+ /**
342
+ * @!visibility private
343
+ * Calculate impurity based on criterion.
344
+ *
345
+ * @overload node_impurity(criterion, y, n_classes) -> Float
346
+ *
347
+ * @param criterion [String] The function to calculate impurity. Supported criteria are 'gini' and 'entropy'.
348
+ * @param y [Numo::Int32] (shape: [n_samples]) The labels.
349
+ * @param n_classes [Integer] The number of classes.
350
+ * @return [Float] impurity
351
+ */
352
+ static VALUE
353
+ node_impurity_cls(VALUE self, VALUE criterion, VALUE y, VALUE n_classes)
354
+ {
355
+ long i;
356
+ const long n_elements = RARRAY_LEN(y);
357
+ VALUE histogram = create_zero_vector(NUM2LONG(n_classes));
358
+
359
+ for (i = 0; i < n_elements; i++) {
360
+ increment_histogram(histogram, NUM2LONG(rb_ary_entry(y, i)));
361
+ }
362
+
363
+ return DBL2NUM(calc_impurity_cls(criterion, histogram, n_elements));
364
+ }
365
+
366
+ /**
367
+ * @!visibility private
368
+ * Calculate impurity based on criterion.
369
+ *
370
+ * @overload node_impurity(criterion, y) -> Float
371
+ *
372
+ * @param criterion [String] The function to calculate impurity. Supported criteria are 'mae' and 'mse'.
373
+ * @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The taget values.
374
+ * @return [Float] impurity
375
+ */
376
+ static VALUE
377
+ node_impurity_reg(VALUE self, VALUE criterion, VALUE y)
378
+ {
379
+ long i;
380
+ const long n_elements = RARRAY_LEN(y);
381
+ const long n_dimensions = RARRAY_LEN(rb_ary_entry(y, 0));
382
+ VALUE sum_vec = create_zero_vector(n_dimensions);
383
+ VALUE target_vecs = rb_ary_new();
384
+ VALUE target;
385
+
386
+ for (i = 0; i < n_elements; i++) {
387
+ target = rb_ary_entry(y, i);
388
+ add_sum_vec(sum_vec, target);
389
+ rb_ary_push(target_vecs, target);
390
+ }
391
+
392
+ return DBL2NUM(calc_impurity_reg(criterion, target_vecs, sum_vec));
393
+ }
394
+
395
+ void Init_rumale(void)
396
+ {
397
+ VALUE mRumale = rb_define_module("Rumale");
398
+ VALUE mTree = rb_define_module_under(mRumale, "Tree");
399
+ /**
400
+ * Document-module: Rumale::Tree::ExtDecisionTreeClassifier
401
+ * @!visibility private
402
+ * The mixin module consisting of extension method for DecisionTreeClassifier class.
403
+ * This module is used internally.
404
+ */
405
+ VALUE mExtDTreeCls = rb_define_module_under(mTree, "ExtDecisionTreeClassifier");
406
+ /**
407
+ * Document-module: Rumale::Tree::ExtDecisionTreeRegressor
408
+ * @!visibility private
409
+ * The mixin module consisting of extension method for DecisionTreeRegressor class.
410
+ * This module is used internally.
411
+ */
412
+ VALUE mExtDTreeReg = rb_define_module_under(mTree, "ExtDecisionTreeRegressor");
413
+
414
+ rb_define_method(mExtDTreeCls, "find_split_params", find_split_params_cls, 6);
415
+ rb_define_method(mExtDTreeReg, "find_split_params", find_split_params_reg, 5);
416
+ rb_define_method(mExtDTreeCls, "node_impurity", node_impurity_cls, 3);
417
+ rb_define_method(mExtDTreeReg, "node_impurity", node_impurity_reg, 2);
418
+ }
@@ -0,0 +1,9 @@
1
+ #ifndef RUMALE_H
2
+ #define RUMALE_H 1
3
+
4
+ #include <math.h>
5
+ #include <string.h>
6
+
7
+ #include "ruby.h"
8
+
9
+ #endif /* RUMALE_H */
@@ -2,6 +2,8 @@
2
2
 
3
3
  require 'numo/narray'
4
4
 
5
+ require 'rumale/rumale'
6
+
5
7
  require 'rumale/version'
6
8
  require 'rumale/validation'
7
9
  require 'rumale/values'
@@ -86,13 +86,16 @@ module Rumale
86
86
  return put_leaf(node, y) if stop_growing?(y)
87
87
 
88
88
  # calculate optimal parameters.
89
- feature_id, threshold, left_ids, right_ids, left_impurity, right_impurity, gain =
90
- rand_ids(n_features).map { |f_id| [f_id, *best_split(x[true, f_id], y, whole_impurity)] }.max_by(&:last)
89
+ feature_id, left_ids, right_ids, left_imp, right_imp, threshold, gain = rand_ids(n_features).map do |fid|
90
+ ft = x[true, fid]
91
+ limp, rimp, th, ga = best_split(ft, y, whole_impurity)
92
+ [fid, ft.le(th).where, ft.gt(th).where, limp, rimp, th, ga]
93
+ end.max_by(&:last)
91
94
 
92
95
  return put_leaf(node, y) if gain.nil? || gain.zero?
93
96
 
94
- node.left = grow_node(depth + 1, x[left_ids, true], y[left_ids, true], left_impurity)
95
- node.right = grow_node(depth + 1, x[right_ids, true], y[right_ids, true], right_impurity)
97
+ node.left = grow_node(depth + 1, x[left_ids, true], y[left_ids, true], left_imp)
98
+ node.right = grow_node(depth + 1, x[right_ids, true], y[right_ids, true], right_imp)
96
99
 
97
100
  return put_leaf(node, y) if node.left.nil? && node.right.nil?
98
101
 
@@ -114,22 +117,11 @@ module Rumale
114
117
  [*0...n].sample(@params[:max_features], random: @rng)
115
118
  end
116
119
 
117
- def best_split(features, targets, whole_impurity)
118
- n_samples = targets.shape[0]
119
- features.to_a.uniq.sort.each_cons(2).map do |l, r|
120
- threshold = 0.5 * (l + r)
121
- left_ids = features.le(threshold).where
122
- right_ids = features.gt(threshold).where
123
- left_impurity = impurity(targets[left_ids, true])
124
- right_impurity = impurity(targets[right_ids, true])
125
- gain = whole_impurity -
126
- left_impurity * left_ids.size.fdiv(n_samples) -
127
- right_impurity * right_ids.size.fdiv(n_samples)
128
- [threshold, left_ids, right_ids, left_impurity, right_impurity, gain]
129
- end.max_by(&:last)
120
+ def best_split(_features, _y, _impurity)
121
+ raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
130
122
  end
131
123
 
132
- def impurity(_targets)
124
+ def impurity(_y)
133
125
  raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
134
126
  end
135
127
 
@@ -1,5 +1,6 @@
1
1
  # frozen_string_literal: true
2
2
 
3
+ require 'rumale/rumale'
3
4
  require 'rumale/tree/base_decision_tree'
4
5
  require 'rumale/base/classifier'
5
6
 
@@ -16,6 +17,7 @@ module Rumale
16
17
  #
17
18
  class DecisionTreeClassifier < BaseDecisionTree
18
19
  include Base::Classifier
20
+ include ExtDecisionTreeClassifier
19
21
 
20
22
  # Return the class labels.
21
23
  # @return [Numo::Int32] (size: n_classes)
@@ -39,7 +41,7 @@ module Rumale
39
41
 
40
42
  # Create a new classifier with decision tree algorithm.
41
43
  #
42
- # @param criterion [String] The function to evalue spliting point. Supported criteria are 'gini' and 'entropy'.
44
+ # @param criterion [String] The function to evaluate splitting point. Supported criteria are 'gini' and 'entropy'.
43
45
  # @param max_depth [Integer] The maximum depth of the tree.
44
46
  # If nil is given, decision tree grows without concern for depth.
45
47
  # @param max_leaf_nodes [Integer] The maximum number of leaves on decision tree.
@@ -89,7 +91,7 @@ module Rumale
89
91
  # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
90
92
  def predict(x)
91
93
  check_sample_array(x)
92
- @leaf_labels[apply(x)]
94
+ @leaf_labels[apply(x)].dup
93
95
  end
94
96
 
95
97
  # Predict probability for samples.
@@ -138,7 +140,7 @@ module Rumale
138
140
  end
139
141
 
140
142
  def stop_growing?(y)
141
- y.flatten.to_a.uniq.size == 1
143
+ y[true, 0].to_a.uniq.size == 1
142
144
  end
143
145
 
144
146
  def put_leaf(node, y)
@@ -150,13 +152,17 @@ module Rumale
150
152
  node
151
153
  end
152
154
 
155
# Sort the samples by feature value and delegate the split search to
# find_split_params, passing plain Ruby arrays.
def best_split(features, y, whole_impurity)
  ascending = features.sort_index
  feature_values = features[ascending].to_a
  labels = y[ascending, true].to_a.flatten
  find_split_params(
    @params[:criterion], whole_impurity,
    feature_values, labels, feature_values.uniq, @classes.size
  )
end
162
+
153
163
  def impurity(y)
154
- posterior_probs = y.flatten.bincount / y.size.to_f
155
- if @params[:criterion] == 'entropy'
156
- -(posterior_probs * Numo::NMath.log(posterior_probs + 1)).sum
157
- else
158
- 1.0 - (posterior_probs * posterior_probs).sum
159
- end
164
+ n_classes = @classes.size
165
+ node_impurity(@params[:criterion], y[true, 0].to_a, n_classes)
160
166
  end
161
167
  end
162
168
  end
@@ -1,5 +1,6 @@
1
1
  # frozen_string_literal: true
2
2
 
3
+ require 'rumale/rumale'
3
4
  require 'rumale/tree/base_decision_tree'
4
5
  require 'rumale/base/regressor'
5
6
 
@@ -16,6 +17,7 @@ module Rumale
16
17
  #
17
18
  class DecisionTreeRegressor < BaseDecisionTree
18
19
  include Base::Regressor
20
+ include ExtDecisionTreeRegressor
19
21
 
20
22
  # Return the importance for each feature.
21
23
  # @return [Numo::DFloat] (size: n_features)
@@ -35,7 +37,7 @@ module Rumale
35
37
 
36
38
  # Create a new regressor with decision tree algorithm.
37
39
  #
38
- # @param criterion [String] The function to evalue spliting point. Supported criteria are 'mae' and 'mse'.
40
+ # @param criterion [String] The function to evaluate splitting point. Supported criteria are 'mae' and 'mse'.
39
41
  # @param max_depth [Integer] The maximum depth of the tree.
40
42
  # If nil is given, decision tree grows without concern for depth.
41
43
  # @param max_leaf_nodes [Integer] The maximum number of leaves on decision tree.
@@ -84,7 +86,7 @@ module Rumale
84
86
  # @return [Numo::DFloat] (shape: [n_samples, n_outputs]) Predicted values per sample.
85
87
  def predict(x)
86
88
  check_sample_array(x)
87
- @leaf_values.shape[1].nil? ? @leaf_values[apply(x)] : @leaf_values[apply(x), true]
89
+ @leaf_values.shape[1].nil? ? @leaf_values[apply(x)].dup : @leaf_values[apply(x), true].dup
88
90
  end
89
91
 
90
92
  # Dump marshal data.
@@ -123,12 +125,15 @@ module Rumale
123
125
  node
124
126
  end
125
127
 
126
- def impurity(values)
127
- if @params[:criterion] == 'mae'
128
- (values - values.mean(0)).abs.mean
129
- else
130
- ((values - values.mean(0))**2).mean
131
- end
128
# Sort the samples by feature value and delegate the split search to
# find_split_params, passing plain Ruby arrays.
def best_split(features, y, whole_impurity)
  ascending = features.sort_index
  feature_values = features[ascending].to_a
  target_rows = y[ascending, true].to_a
  find_split_params(
    @params[:criterion], whole_impurity,
    feature_values, target_rows, feature_values.uniq
  )
end
134
+
135
+ def impurity(y)
136
+ node_impurity(@params[:criterion], y.to_a)
132
137
  end
133
138
  end
134
139
  end
@@ -3,5 +3,5 @@
3
3
  # Rumale is a machine learning library in Ruby.
4
4
  module Rumale
5
5
  # The version of Rumale you are using.
6
- VERSION = '0.8.4'
6
+ VERSION = '0.9.0'
7
7
  end
@@ -29,6 +29,7 @@ MSG
29
29
  spec.bindir = 'exe'
30
30
  spec.executables = spec.files.grep(%r{^exe/}) { |f| File.basename(f) }
31
31
  spec.require_paths = ['lib']
32
+ spec.extensions = ['ext/rumale/extconf.rb']
32
33
 
33
34
  spec.required_ruby_version = '>= 2.3'
34
35
 
@@ -37,5 +38,6 @@ MSG
37
38
  spec.add_development_dependency 'bundler', '>= 1.16'
38
39
  spec.add_development_dependency 'coveralls', '~> 0.8'
39
40
  spec.add_development_dependency 'rake', '~> 12.0'
41
+ spec.add_development_dependency 'rake-compiler'
40
42
  spec.add_development_dependency 'rspec', '~> 3.0'
41
43
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rumale
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.8.4
4
+ version: 0.9.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2019-04-20 00:00:00.000000000 Z
11
+ date: 2019-04-22 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -66,6 +66,20 @@ dependencies:
66
66
  - - "~>"
67
67
  - !ruby/object:Gem::Version
68
68
  version: '12.0'
69
+ - !ruby/object:Gem::Dependency
70
+ name: rake-compiler
71
+ requirement: !ruby/object:Gem::Requirement
72
+ requirements:
73
+ - - ">="
74
+ - !ruby/object:Gem::Version
75
+ version: '0'
76
+ type: :development
77
+ prerelease: false
78
+ version_requirements: !ruby/object:Gem::Requirement
79
+ requirements:
80
+ - - ">="
81
+ - !ruby/object:Gem::Version
82
+ version: '0'
69
83
  - !ruby/object:Gem::Dependency
70
84
  name: rspec
71
85
  requirement: !ruby/object:Gem::Requirement
@@ -90,7 +104,8 @@ description: |
90
104
  email:
91
105
  - yoshoku@outlook.com
92
106
  executables: []
93
- extensions: []
107
+ extensions:
108
+ - ext/rumale/extconf.rb
94
109
  extra_rdoc_files: []
95
110
  files:
96
111
  - ".coveralls.yml"
@@ -107,6 +122,9 @@ files:
107
122
  - Rakefile
108
123
  - bin/console
109
124
  - bin/setup
125
+ - ext/rumale/extconf.rb
126
+ - ext/rumale/rumale.c
127
+ - ext/rumale/rumale.h
110
128
  - lib/rumale.rb
111
129
  - lib/rumale/base/base_estimator.rb
112
130
  - lib/rumale/base/classifier.rb