rumale 0.8.4 → 0.9.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/.gitignore +1 -0
- data/CHANGELOG.md +5 -0
- data/Rakefile +9 -1
- data/ext/rumale/extconf.rb +3 -0
- data/ext/rumale/rumale.c +418 -0
- data/ext/rumale/rumale.h +9 -0
- data/lib/rumale.rb +2 -0
- data/lib/rumale/tree/base_decision_tree.rb +10 -18
- data/lib/rumale/tree/decision_tree_classifier.rb +15 -9
- data/lib/rumale/tree/decision_tree_regressor.rb +13 -8
- data/lib/rumale/version.rb +1 -1
- data/rumale.gemspec +2 -0
- metadata +21 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA1:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: a2dfbc60c9d47e741fc91497f8c58ade390e6c8f
|
4
|
+
data.tar.gz: d4cbc26e0d81fbe0de5e83d785cc836e9a5b2099
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 7f2b4b8ba5d7511215a2e850add19f0942cbff4157a8373eba1950c0eac9fcd0e44925d3a88b2a709c0308ef4c03cca44c501b710f4a22dc4dd573e6866d94dc
|
7
|
+
data.tar.gz: 4630710eef59af88274e9a411a6ad12de7e4a616280f8fc94d185e24c7bc667bf8c1f662425c64cf05f6ec9accd914ac32e1039688d09629b920329ad85354c8
|
data/.gitignore
CHANGED
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,8 @@
|
|
1
|
+
# 0.9.0
|
2
|
+
## Breaking changes
|
3
|
+
- Introduce a Ruby C extension to improve performance.
|
4
|
+
- Change decision tree estimators to find split points using the C extension module.
|
5
|
+
|
1
6
|
# 0.8.4
|
2
7
|
- Remove an unused parameter from Nadam.
|
3
8
|
- Fix the condition for stopping tree growth in decision tree estimators.
|
data/Rakefile
CHANGED
@@ -3,4 +3,12 @@ require 'rspec/core/rake_task'
|
|
3
3
|
|
4
4
|
RSpec::Core::RakeTask.new(:spec)
|
5
5
|
|
6
|
-
|
6
|
+
require 'rake/extensiontask'

# Make sure the native extension is compiled before the gem is built.
task build: :compile

# Compile ext/rumale into lib/rumale so that `require 'rumale/rumale'` finds it.
Rake::ExtensionTask.new('rumale') do |ext|
  ext.lib_dir = 'lib/rumale'
end

# Default task: clean build artifacts, recompile the extension, then run the specs.
task default: %i[clobber compile spec]
|
data/ext/rumale/rumale.c
ADDED
@@ -0,0 +1,418 @@
|
|
1
|
+
#include "rumale.h"
|
2
|
+
|
3
|
+
VALUE
|
4
|
+
create_zero_vector(const long n_dimensions)
|
5
|
+
{
|
6
|
+
long i;
|
7
|
+
VALUE vec = rb_ary_new2(n_dimensions);
|
8
|
+
|
9
|
+
for (i = 0; i < n_dimensions; i++) {
|
10
|
+
rb_ary_store(vec, i, DBL2NUM(0));
|
11
|
+
}
|
12
|
+
|
13
|
+
return vec;
|
14
|
+
}
|
15
|
+
|
16
|
+
double
|
17
|
+
calc_gini_coef(VALUE histogram, const long n_elements)
|
18
|
+
{
|
19
|
+
long i;
|
20
|
+
double el;
|
21
|
+
double gini = 0.0;
|
22
|
+
const long n_classes = RARRAY_LEN(histogram);
|
23
|
+
|
24
|
+
for (i = 0; i < n_classes; i++) {
|
25
|
+
el = NUM2DBL(rb_ary_entry(histogram, i)) / n_elements;
|
26
|
+
gini += el * el;
|
27
|
+
}
|
28
|
+
|
29
|
+
return 1.0 - gini;
|
30
|
+
}
|
31
|
+
|
32
|
+
double
|
33
|
+
calc_entropy(VALUE histogram, const long n_elements)
|
34
|
+
{
|
35
|
+
long i;
|
36
|
+
double el;
|
37
|
+
double entropy = 0.0;
|
38
|
+
const long n_classes = RARRAY_LEN(histogram);
|
39
|
+
|
40
|
+
for (i = 0; i < n_classes; i++) {
|
41
|
+
el = NUM2DBL(rb_ary_entry(histogram, i)) / n_elements;
|
42
|
+
entropy += el * log(el + 1.0);
|
43
|
+
}
|
44
|
+
|
45
|
+
return -entropy;
|
46
|
+
}
|
47
|
+
|
48
|
+
VALUE
|
49
|
+
calc_mean_vec(VALUE sum_vec, const long n_elements)
|
50
|
+
{
|
51
|
+
long i;
|
52
|
+
const long n_dimensions = RARRAY_LEN(sum_vec);
|
53
|
+
VALUE mean_vec = rb_ary_new2(n_dimensions);
|
54
|
+
|
55
|
+
for (i = 0; i < n_dimensions; i++) {
|
56
|
+
rb_ary_store(mean_vec, i, DBL2NUM(NUM2DBL(rb_ary_entry(sum_vec, i)) / n_elements));
|
57
|
+
}
|
58
|
+
|
59
|
+
return mean_vec;
|
60
|
+
}
|
61
|
+
|
62
|
+
double
|
63
|
+
calc_vec_mae(VALUE vec_a, VALUE vec_b)
|
64
|
+
{
|
65
|
+
long i;
|
66
|
+
const long n_dimensions = RARRAY_LEN(vec_a);
|
67
|
+
double sum = 0.0;
|
68
|
+
double diff;
|
69
|
+
|
70
|
+
for (i = 0; i < n_dimensions; i++) {
|
71
|
+
diff = NUM2DBL(rb_ary_entry(vec_a, i)) - NUM2DBL(rb_ary_entry(vec_b, i));
|
72
|
+
sum += fabs(diff);
|
73
|
+
}
|
74
|
+
|
75
|
+
return sum / n_dimensions;
|
76
|
+
}
|
77
|
+
|
78
|
+
double
|
79
|
+
calc_vec_mse(VALUE vec_a, VALUE vec_b)
|
80
|
+
{
|
81
|
+
long i;
|
82
|
+
const long n_dimensions = RARRAY_LEN(vec_a);
|
83
|
+
double sum = 0.0;
|
84
|
+
double diff;
|
85
|
+
|
86
|
+
for (i = 0; i < n_dimensions; i++) {
|
87
|
+
diff = NUM2DBL(rb_ary_entry(vec_a, i)) - NUM2DBL(rb_ary_entry(vec_b, i));
|
88
|
+
sum += diff * diff;
|
89
|
+
}
|
90
|
+
|
91
|
+
return sum / n_dimensions;
|
92
|
+
}
|
93
|
+
|
94
|
+
double
|
95
|
+
calc_mae(VALUE target_vecs, VALUE sum_vec)
|
96
|
+
{
|
97
|
+
long i;
|
98
|
+
const long n_elements = RARRAY_LEN(target_vecs);
|
99
|
+
double sum = 0.0;
|
100
|
+
VALUE mean_vec = calc_mean_vec(sum_vec, n_elements);
|
101
|
+
|
102
|
+
for (i = 0; i < n_elements; i++) {
|
103
|
+
sum += calc_vec_mae(rb_ary_entry(target_vecs, i), mean_vec);
|
104
|
+
}
|
105
|
+
|
106
|
+
return sum / n_elements;
|
107
|
+
}
|
108
|
+
|
109
|
+
double
|
110
|
+
calc_mse(VALUE target_vecs, VALUE sum_vec)
|
111
|
+
{
|
112
|
+
long i;
|
113
|
+
const long n_elements = RARRAY_LEN(target_vecs);
|
114
|
+
double sum = 0.0;
|
115
|
+
VALUE mean_vec = calc_mean_vec(sum_vec, n_elements);
|
116
|
+
|
117
|
+
for (i = 0; i < n_elements; i++) {
|
118
|
+
sum += calc_vec_mse(rb_ary_entry(target_vecs, i), mean_vec);
|
119
|
+
}
|
120
|
+
|
121
|
+
return sum / n_elements;
|
122
|
+
}
|
123
|
+
|
124
|
+
double
|
125
|
+
calc_impurity_cls(VALUE criterion, VALUE histogram, const long n_elements)
|
126
|
+
{
|
127
|
+
if (strcmp(StringValuePtr(criterion), "entropy") == 0) {
|
128
|
+
return calc_entropy(histogram, n_elements);
|
129
|
+
}
|
130
|
+
return calc_gini_coef(histogram, n_elements);
|
131
|
+
}
|
132
|
+
|
133
|
+
double
|
134
|
+
calc_impurity_reg(VALUE criterion, VALUE target_vecs, VALUE sum_vec)
|
135
|
+
{
|
136
|
+
if (strcmp(StringValuePtr(criterion), "mae") == 0) {
|
137
|
+
return calc_mae(target_vecs, sum_vec);
|
138
|
+
}
|
139
|
+
return calc_mse(target_vecs, sum_vec);
|
140
|
+
}
|
141
|
+
|
142
|
+
void
|
143
|
+
increment_histogram(VALUE histogram, const long bin_id)
|
144
|
+
{
|
145
|
+
const double updated = NUM2DBL(rb_ary_entry(histogram, bin_id)) + 1;
|
146
|
+
rb_ary_store(histogram, bin_id, DBL2NUM(updated));
|
147
|
+
}
|
148
|
+
|
149
|
+
void
|
150
|
+
decrement_histogram(VALUE histogram, const long bin_id)
|
151
|
+
{
|
152
|
+
const double updated = NUM2DBL(rb_ary_entry(histogram, bin_id)) - 1;
|
153
|
+
rb_ary_store(histogram, bin_id, DBL2NUM(updated));
|
154
|
+
}
|
155
|
+
|
156
|
+
void
|
157
|
+
add_sum_vec(VALUE sum_vec, VALUE target)
|
158
|
+
{
|
159
|
+
long i;
|
160
|
+
const long n_dimensions = RARRAY_LEN(sum_vec);
|
161
|
+
double el;
|
162
|
+
|
163
|
+
for (i = 0; i < n_dimensions; i++) {
|
164
|
+
el = NUM2DBL(rb_ary_entry(sum_vec, i)) + NUM2DBL(rb_ary_entry(target, i));
|
165
|
+
rb_ary_store(sum_vec, i, DBL2NUM(el));
|
166
|
+
}
|
167
|
+
}
|
168
|
+
|
169
|
+
void
|
170
|
+
sub_sum_vec(VALUE sum_vec, VALUE target)
|
171
|
+
{
|
172
|
+
long i;
|
173
|
+
const long n_dimensions = RARRAY_LEN(sum_vec);
|
174
|
+
double el;
|
175
|
+
|
176
|
+
for (i = 0; i < n_dimensions; i++) {
|
177
|
+
el = NUM2DBL(rb_ary_entry(sum_vec, i)) - NUM2DBL(rb_ary_entry(target, i));
|
178
|
+
rb_ary_store(sum_vec, i, DBL2NUM(el));
|
179
|
+
}
|
180
|
+
}
|
181
|
+
|
182
|
+
/**
|
183
|
+
* @!visibility private
|
184
|
+
* Find for split point with maximum information gain.
|
185
|
+
*
|
186
|
+
* @overload find_split_params(criterion, impurity, sorted_features, sorted_labels, uniqed_features, n_classes) -> Array<Float>
|
187
|
+
*
|
188
|
+
* @param criterion [String] The function to evaluate spliting point. Supported criteria are 'gini' and 'entropy'.
|
189
|
+
* @param impurity [Float] The impurity of whole dataset.
|
190
|
+
* @param sorted_features [Numo::DFloat] (shape: [n_samples]) The feature values sorted in ascending order.
|
191
|
+
* @param sorted_labels [Numo::Int32] (shape: [n_labels]) The labels sorted according to feature values.
|
192
|
+
* @param uniqed_features [Numo::DFloat] (shape: [n_uniqed_features]) The unique feature values.
|
193
|
+
* @param n_classes [Integer] The number of classes.
|
194
|
+
* @return [Float] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
|
195
|
+
*/
|
196
|
+
static VALUE
|
197
|
+
find_split_params_cls(VALUE self, VALUE criterion, VALUE whole_impurity, VALUE sorted_f, VALUE sorted_y, VALUE uniqed_f, VALUE n_classes_)
|
198
|
+
{
|
199
|
+
long i;
|
200
|
+
long curr_pos;
|
201
|
+
long next_pos;
|
202
|
+
long n_l_elements;
|
203
|
+
long n_r_elements;
|
204
|
+
const long n_classes = NUM2LONG(n_classes_);
|
205
|
+
const long n_elements = RARRAY_LEN(sorted_f);
|
206
|
+
const long n_uniq_elements = RARRAY_LEN(uniqed_f);
|
207
|
+
const double w_impurity = NUM2DBL(whole_impurity);
|
208
|
+
double l_impurity;
|
209
|
+
double r_impurity;
|
210
|
+
double gain;
|
211
|
+
double curr_el;
|
212
|
+
double next_el;
|
213
|
+
VALUE l_histogram = create_zero_vector(n_classes);
|
214
|
+
VALUE r_histogram = create_zero_vector(n_classes);
|
215
|
+
VALUE opt_params = rb_ary_new2(4);
|
216
|
+
|
217
|
+
/* Initialize optimal parameters. */
|
218
|
+
rb_ary_store(opt_params, 0, DBL2NUM(0)); /* left impurity */
|
219
|
+
rb_ary_store(opt_params, 1, DBL2NUM(w_impurity)); /* right impurity */
|
220
|
+
rb_ary_store(opt_params, 2, rb_ary_entry(uniqed_f, 0)); /* threshold */
|
221
|
+
rb_ary_store(opt_params, 3, DBL2NUM(0)); /* gain */
|
222
|
+
|
223
|
+
/* Initialize child node variables. */
|
224
|
+
n_l_elements = 0;
|
225
|
+
n_r_elements = n_elements;
|
226
|
+
for (i = 0; i < n_elements; i++) {
|
227
|
+
increment_histogram(r_histogram, NUM2LONG(rb_ary_entry(sorted_y, i)));
|
228
|
+
}
|
229
|
+
|
230
|
+
/* Find optimal parameters. */
|
231
|
+
for (curr_pos = 0, next_pos = 0; curr_pos < n_uniq_elements - 1; curr_pos++) {
|
232
|
+
/* Find new split point. */
|
233
|
+
curr_el = NUM2DBL(rb_ary_entry(uniqed_f, curr_pos));
|
234
|
+
next_el = NUM2DBL(rb_ary_entry(sorted_f, next_pos));
|
235
|
+
while (next_pos < n_elements && next_el <= curr_el) {
|
236
|
+
increment_histogram(l_histogram, NUM2LONG(rb_ary_entry(sorted_y, next_pos)));
|
237
|
+
n_l_elements++;
|
238
|
+
decrement_histogram(r_histogram, NUM2LONG(rb_ary_entry(sorted_y, next_pos)));
|
239
|
+
n_r_elements--;
|
240
|
+
next_el = NUM2DBL(rb_ary_entry(sorted_f, ++next_pos));
|
241
|
+
}
|
242
|
+
/* Calculate gain of new split. */
|
243
|
+
l_impurity = calc_impurity_cls(criterion, l_histogram, n_l_elements);
|
244
|
+
r_impurity = calc_impurity_cls(criterion, r_histogram, n_r_elements);
|
245
|
+
gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / n_elements;
|
246
|
+
/* Update optimal parameters. */
|
247
|
+
if (gain > NUM2DBL(rb_ary_entry(opt_params, 3))) {
|
248
|
+
rb_ary_store(opt_params, 0, DBL2NUM(l_impurity));
|
249
|
+
rb_ary_store(opt_params, 1, DBL2NUM(r_impurity));
|
250
|
+
rb_ary_store(opt_params, 2, DBL2NUM(0.5 * (curr_el + next_el)));
|
251
|
+
rb_ary_store(opt_params, 3, DBL2NUM(gain));
|
252
|
+
}
|
253
|
+
}
|
254
|
+
|
255
|
+
return opt_params;
|
256
|
+
}
|
257
|
+
|
258
|
+
/**
|
259
|
+
* @!visibility private
|
260
|
+
* Find for split point with maximum information gain.
|
261
|
+
*
|
262
|
+
* @overload find_split_params(criterion, impurity, sorted_features, sorted_targets, uniqed_features) -> Array<Float>
|
263
|
+
*
|
264
|
+
* @param criterion [String] The function to evaluate spliting point. Supported criteria are 'mae' and 'mse'.
|
265
|
+
* @param impurity [Float] The impurity of whole dataset.
|
266
|
+
* @param sorted_features [Numo::DFloat] (shape: [n_samples]) The feature values sorted in ascending order.
|
267
|
+
* @param sorted_targets [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values sorted according to feature values.
|
268
|
+
* @param uniqed_features [Numo::DFloat] (shape: [n_uniqed_features]) The unique feature values.
|
269
|
+
* @return [Float] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
|
270
|
+
*/
|
271
|
+
static VALUE
|
272
|
+
find_split_params_reg(VALUE self, VALUE criterion, VALUE whole_impurity, VALUE sorted_f, VALUE sorted_y, VALUE uniqed_f)
|
273
|
+
{
|
274
|
+
long i;
|
275
|
+
long curr_pos;
|
276
|
+
long next_pos;
|
277
|
+
long n_l_elements;
|
278
|
+
long n_r_elements;
|
279
|
+
const long n_elements = RARRAY_LEN(sorted_f);
|
280
|
+
const long n_uniq_elements = RARRAY_LEN(uniqed_f);
|
281
|
+
const long n_dimensions = RARRAY_LEN(rb_ary_entry(sorted_y, 0));
|
282
|
+
const double w_impurity = NUM2DBL(whole_impurity);
|
283
|
+
double l_impurity;
|
284
|
+
double r_impurity;
|
285
|
+
double gain;
|
286
|
+
double curr_el;
|
287
|
+
double next_el;
|
288
|
+
VALUE l_sum_vec = create_zero_vector(n_dimensions);
|
289
|
+
VALUE r_sum_vec = create_zero_vector(n_dimensions);
|
290
|
+
VALUE l_target_vecs = rb_ary_new();
|
291
|
+
VALUE r_target_vecs = rb_ary_new();
|
292
|
+
VALUE target;
|
293
|
+
VALUE opt_params = rb_ary_new2(4);
|
294
|
+
|
295
|
+
/* Initialize optimal parameters. */
|
296
|
+
rb_ary_store(opt_params, 0, DBL2NUM(0)); /* left impurity */
|
297
|
+
rb_ary_store(opt_params, 1, DBL2NUM(w_impurity)); /* right impurity */
|
298
|
+
rb_ary_store(opt_params, 2, rb_ary_entry(uniqed_f, 0)); /* threshold */
|
299
|
+
rb_ary_store(opt_params, 3, DBL2NUM(0)); /* gain */
|
300
|
+
|
301
|
+
/* Initialize child node variables. */
|
302
|
+
n_l_elements = 0;
|
303
|
+
n_r_elements = n_elements;
|
304
|
+
for (i = 0; i < n_elements; i++) {
|
305
|
+
target = rb_ary_entry(sorted_y, i);
|
306
|
+
add_sum_vec(r_sum_vec, target);
|
307
|
+
rb_ary_push(r_target_vecs, target);
|
308
|
+
}
|
309
|
+
|
310
|
+
/* Find optimal parameters. */
|
311
|
+
for (curr_pos = 0, next_pos = 0; curr_pos < n_uniq_elements - 1; curr_pos++) {
|
312
|
+
/* Find new split point. */
|
313
|
+
curr_el = NUM2DBL(rb_ary_entry(uniqed_f, curr_pos));
|
314
|
+
next_el = NUM2DBL(rb_ary_entry(sorted_f, next_pos));
|
315
|
+
while (next_pos < n_elements && next_el <= curr_el) {
|
316
|
+
target = rb_ary_entry(sorted_y, next_pos);
|
317
|
+
add_sum_vec(l_sum_vec, target);
|
318
|
+
rb_ary_push(l_target_vecs, target);
|
319
|
+
n_l_elements++;
|
320
|
+
sub_sum_vec(r_sum_vec, target);
|
321
|
+
rb_ary_shift(r_target_vecs);
|
322
|
+
n_r_elements--;
|
323
|
+
next_el = NUM2DBL(rb_ary_entry(sorted_f, ++next_pos));
|
324
|
+
}
|
325
|
+
/* Calculate gain of new split. */
|
326
|
+
l_impurity = calc_impurity_reg(criterion, l_target_vecs, l_sum_vec);
|
327
|
+
r_impurity = calc_impurity_reg(criterion, r_target_vecs, r_sum_vec);
|
328
|
+
gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / n_elements;
|
329
|
+
/* Update optimal parameters. */
|
330
|
+
if (gain > NUM2DBL(rb_ary_entry(opt_params, 3))) {
|
331
|
+
rb_ary_store(opt_params, 0, DBL2NUM(l_impurity));
|
332
|
+
rb_ary_store(opt_params, 1, DBL2NUM(r_impurity));
|
333
|
+
rb_ary_store(opt_params, 2, DBL2NUM(0.5 * (curr_el + next_el)));
|
334
|
+
rb_ary_store(opt_params, 3, DBL2NUM(gain));
|
335
|
+
}
|
336
|
+
}
|
337
|
+
|
338
|
+
return opt_params;
|
339
|
+
}
|
340
|
+
|
341
|
+
/**
|
342
|
+
* @!visibility private
|
343
|
+
* Calculate impurity based on criterion.
|
344
|
+
*
|
345
|
+
* @overload node_impurity(criterion, y, n_classes) -> Float
|
346
|
+
*
|
347
|
+
* @param criterion [String] The function to calculate impurity. Supported criteria are 'gini' and 'entropy'.
|
348
|
+
* @param y [Numo::Int32] (shape: [n_samples]) The labels.
|
349
|
+
* @param n_classes [Integer] The number of classes.
|
350
|
+
* @return [Float] impurity
|
351
|
+
*/
|
352
|
+
static VALUE
|
353
|
+
node_impurity_cls(VALUE self, VALUE criterion, VALUE y, VALUE n_classes)
|
354
|
+
{
|
355
|
+
long i;
|
356
|
+
const long n_elements = RARRAY_LEN(y);
|
357
|
+
VALUE histogram = create_zero_vector(NUM2LONG(n_classes));
|
358
|
+
|
359
|
+
for (i = 0; i < n_elements; i++) {
|
360
|
+
increment_histogram(histogram, NUM2LONG(rb_ary_entry(y, i)));
|
361
|
+
}
|
362
|
+
|
363
|
+
return DBL2NUM(calc_impurity_cls(criterion, histogram, n_elements));
|
364
|
+
}
|
365
|
+
|
366
|
+
/**
|
367
|
+
* @!visibility private
|
368
|
+
* Calculate impurity based on criterion.
|
369
|
+
*
|
370
|
+
* @overload node_impurity(criterion, y) -> Float
|
371
|
+
*
|
372
|
+
* @param criterion [String] The function to calculate impurity. Supported criteria are 'mae' and 'mse'.
|
373
|
+
* @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The taget values.
|
374
|
+
* @return [Float] impurity
|
375
|
+
*/
|
376
|
+
static VALUE
|
377
|
+
node_impurity_reg(VALUE self, VALUE criterion, VALUE y)
|
378
|
+
{
|
379
|
+
long i;
|
380
|
+
const long n_elements = RARRAY_LEN(y);
|
381
|
+
const long n_dimensions = RARRAY_LEN(rb_ary_entry(y, 0));
|
382
|
+
VALUE sum_vec = create_zero_vector(n_dimensions);
|
383
|
+
VALUE target_vecs = rb_ary_new();
|
384
|
+
VALUE target;
|
385
|
+
|
386
|
+
for (i = 0; i < n_elements; i++) {
|
387
|
+
target = rb_ary_entry(y, i);
|
388
|
+
add_sum_vec(sum_vec, target);
|
389
|
+
rb_ary_push(target_vecs, target);
|
390
|
+
}
|
391
|
+
|
392
|
+
return DBL2NUM(calc_impurity_reg(criterion, target_vecs, sum_vec));
|
393
|
+
}
|
394
|
+
|
395
|
+
void Init_rumale(void)
|
396
|
+
{
|
397
|
+
VALUE mRumale = rb_define_module("Rumale");
|
398
|
+
VALUE mTree = rb_define_module_under(mRumale, "Tree");
|
399
|
+
/**
|
400
|
+
* Document-module: Rumale::Tree::ExtDecisionTreeClassifier
|
401
|
+
* @!visibility private
|
402
|
+
* The mixin module consisting of extension method for DecisionTreeClassifier class.
|
403
|
+
* This module is used internally.
|
404
|
+
*/
|
405
|
+
VALUE mExtDTreeCls = rb_define_module_under(mTree, "ExtDecisionTreeClassifier");
|
406
|
+
/**
|
407
|
+
* Document-module: Rumale::Tree::ExtDecisionTreeRegressor
|
408
|
+
* @!visibility private
|
409
|
+
* The mixin module consisting of extension method for DecisionTreeRegressor class.
|
410
|
+
* This module is used internally.
|
411
|
+
*/
|
412
|
+
VALUE mExtDTreeReg = rb_define_module_under(mTree, "ExtDecisionTreeRegressor");
|
413
|
+
|
414
|
+
rb_define_method(mExtDTreeCls, "find_split_params", find_split_params_cls, 6);
|
415
|
+
rb_define_method(mExtDTreeReg, "find_split_params", find_split_params_reg, 5);
|
416
|
+
rb_define_method(mExtDTreeCls, "node_impurity", node_impurity_cls, 3);
|
417
|
+
rb_define_method(mExtDTreeReg, "node_impurity", node_impurity_reg, 2);
|
418
|
+
}
|
data/ext/rumale/rumale.h
ADDED
data/lib/rumale.rb
CHANGED
@@ -86,13 +86,16 @@ module Rumale
|
|
86
86
|
return put_leaf(node, y) if stop_growing?(y)
|
87
87
|
|
88
88
|
# calculate optimal parameters.
|
89
|
-
feature_id,
|
90
|
-
|
89
|
+
feature_id, left_ids, right_ids, left_imp, right_imp, threshold, gain = rand_ids(n_features).map do |fid|
|
90
|
+
ft = x[true, fid]
|
91
|
+
limp, rimp, th, ga = best_split(ft, y, whole_impurity)
|
92
|
+
[fid, ft.le(th).where, ft.gt(th).where, limp, rimp, th, ga]
|
93
|
+
end.max_by(&:last)
|
91
94
|
|
92
95
|
return put_leaf(node, y) if gain.nil? || gain.zero?
|
93
96
|
|
94
|
-
node.left = grow_node(depth + 1, x[left_ids, true], y[left_ids, true],
|
95
|
-
node.right = grow_node(depth + 1, x[right_ids, true], y[right_ids, true],
|
97
|
+
node.left = grow_node(depth + 1, x[left_ids, true], y[left_ids, true], left_imp)
|
98
|
+
node.right = grow_node(depth + 1, x[right_ids, true], y[right_ids, true], right_imp)
|
96
99
|
|
97
100
|
return put_leaf(node, y) if node.left.nil? && node.right.nil?
|
98
101
|
|
@@ -114,22 +117,11 @@ module Rumale
|
|
114
117
|
[*0...n].sample(@params[:max_features], random: @rng)
|
115
118
|
end
|
116
119
|
|
117
|
-
def best_split(
|
118
|
-
|
119
|
-
features.to_a.uniq.sort.each_cons(2).map do |l, r|
|
120
|
-
threshold = 0.5 * (l + r)
|
121
|
-
left_ids = features.le(threshold).where
|
122
|
-
right_ids = features.gt(threshold).where
|
123
|
-
left_impurity = impurity(targets[left_ids, true])
|
124
|
-
right_impurity = impurity(targets[right_ids, true])
|
125
|
-
gain = whole_impurity -
|
126
|
-
left_impurity * left_ids.size.fdiv(n_samples) -
|
127
|
-
right_impurity * right_ids.size.fdiv(n_samples)
|
128
|
-
[threshold, left_ids, right_ids, left_impurity, right_impurity, gain]
|
129
|
-
end.max_by(&:last)
|
120
|
+
# Abstract hook: concrete decision trees must override this to search one
# feature column for its best split.
#
# @raise [NotImplementedError] always, unless overridden.
def best_split(_features, _y, _impurity)
  message = "#{__method__} has to be implemented in #{self.class}."
  raise NotImplementedError, message
end
|
131
123
|
|
132
|
-
def impurity(
|
124
|
+
# Abstract hook: concrete decision trees must override this to compute the
# impurity of a node's targets.
#
# @raise [NotImplementedError] always, unless overridden.
def impurity(_y)
  message = "#{__method__} has to be implemented in #{self.class}."
  raise NotImplementedError, message
end
|
135
127
|
|
@@ -1,5 +1,6 @@
|
|
1
1
|
# frozen_string_literal: true
|
2
2
|
|
3
|
+
require 'rumale/rumale'
|
3
4
|
require 'rumale/tree/base_decision_tree'
|
4
5
|
require 'rumale/base/classifier'
|
5
6
|
|
@@ -16,6 +17,7 @@ module Rumale
|
|
16
17
|
#
|
17
18
|
class DecisionTreeClassifier < BaseDecisionTree
|
18
19
|
include Base::Classifier
|
20
|
+
include ExtDecisionTreeClassifier
|
19
21
|
|
20
22
|
# Return the class labels.
|
21
23
|
# @return [Numo::Int32] (size: n_classes)
|
@@ -39,7 +41,7 @@ module Rumale
|
|
39
41
|
|
40
42
|
# Create a new classifier with decision tree algorithm.
|
41
43
|
#
|
42
|
-
# @param criterion [String] The function to
|
44
|
+
# @param criterion [String] The function to evaluate the splitting point. Supported criteria are 'gini' and 'entropy'.
|
43
45
|
# @param max_depth [Integer] The maximum depth of the tree.
|
44
46
|
# If nil is given, decision tree grows without concern for depth.
|
45
47
|
# @param max_leaf_nodes [Integer] The maximum number of leaves on decision tree.
|
@@ -89,7 +91,7 @@ module Rumale
|
|
89
91
|
# @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
|
90
92
|
def predict(x)
|
91
93
|
check_sample_array(x)
|
92
|
-
@leaf_labels[apply(x)]
|
94
|
+
@leaf_labels[apply(x)].dup
|
93
95
|
end
|
94
96
|
|
95
97
|
# Predict probability for samples.
|
@@ -138,7 +140,7 @@ module Rumale
|
|
138
140
|
end
|
139
141
|
|
140
142
|
def stop_growing?(y)
|
141
|
-
y.
|
143
|
+
y[true, 0].to_a.uniq.size == 1
|
142
144
|
end
|
143
145
|
|
144
146
|
def put_leaf(node, y)
|
@@ -150,13 +152,17 @@ module Rumale
|
|
150
152
|
node
|
151
153
|
end
|
152
154
|
|
155
|
+
# Search one feature column for the threshold with the largest gain.
# Feature values and labels are reordered by ascending feature value,
# converted to plain Ruby Arrays, and handed to the C extension's
# find_split_params.
#
# @param features [Numo::DFloat] (shape: [n_samples]) values of one feature.
# @param y [Numo::Int32] (shape: [n_samples, 1]) labels; the single-column layout
#   is inferred from `y[order, true]` + flatten — NOTE(review): confirm against fit.
# @param whole_impurity [Float] impurity of the node being split.
# @return [Array<Float>] [left impurity, right impurity, threshold, gain].
def best_split(features, y, whole_impurity)
  sort_ids = features.sort_index
  feature_vals = features[sort_ids].to_a
  label_vals = y[sort_ids, true].to_a.flatten
  find_split_params(@params[:criterion], whole_impurity,
                    feature_vals, label_vals, feature_vals.uniq, @classes.size)
end
|
162
|
+
|
153
163
|
def impurity(y)
|
154
|
-
|
155
|
-
|
156
|
-
-(posterior_probs * Numo::NMath.log(posterior_probs + 1)).sum
|
157
|
-
else
|
158
|
-
1.0 - (posterior_probs * posterior_probs).sum
|
159
|
-
end
|
164
|
+
n_classes = @classes.size
|
165
|
+
node_impurity(@params[:criterion], y[true, 0].to_a, n_classes)
|
160
166
|
end
|
161
167
|
end
|
162
168
|
end
|
@@ -1,5 +1,6 @@
|
|
1
1
|
# frozen_string_literal: true
|
2
2
|
|
3
|
+
require 'rumale/rumale'
|
3
4
|
require 'rumale/tree/base_decision_tree'
|
4
5
|
require 'rumale/base/regressor'
|
5
6
|
|
@@ -16,6 +17,7 @@ module Rumale
|
|
16
17
|
#
|
17
18
|
class DecisionTreeRegressor < BaseDecisionTree
|
18
19
|
include Base::Regressor
|
20
|
+
include ExtDecisionTreeRegressor
|
19
21
|
|
20
22
|
# Return the importance for each feature.
|
21
23
|
# @return [Numo::DFloat] (size: n_features)
|
@@ -35,7 +37,7 @@ module Rumale
|
|
35
37
|
|
36
38
|
# Create a new regressor with decision tree algorithm.
|
37
39
|
#
|
38
|
-
# @param criterion [String] The function to
|
40
|
+
# @param criterion [String] The function to evaluate the splitting point. Supported criteria are 'mae' and 'mse'.
|
39
41
|
# @param max_depth [Integer] The maximum depth of the tree.
|
40
42
|
# If nil is given, decision tree grows without concern for depth.
|
41
43
|
# @param max_leaf_nodes [Integer] The maximum number of leaves on decision tree.
|
@@ -84,7 +86,7 @@ module Rumale
|
|
84
86
|
# @return [Numo::DFloat] (shape: [n_samples, n_outputs]) Predicted values per sample.
|
85
87
|
def predict(x)
|
86
88
|
check_sample_array(x)
|
87
|
-
@leaf_values.shape[1].nil? ? @leaf_values[apply(x)] : @leaf_values[apply(x), true]
|
89
|
+
@leaf_values.shape[1].nil? ? @leaf_values[apply(x)].dup : @leaf_values[apply(x), true].dup
|
88
90
|
end
|
89
91
|
|
90
92
|
# Dump marshal data.
|
@@ -123,12 +125,15 @@ module Rumale
|
|
123
125
|
node
|
124
126
|
end
|
125
127
|
|
126
|
-
def
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
128
|
+
# Search one feature column for the threshold with the largest gain.
# Feature values and target rows are reordered by ascending feature value,
# converted to plain Ruby Arrays, and handed to the C extension's
# find_split_params.
#
# @param features [Numo::DFloat] (shape: [n_samples]) values of one feature.
# @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) target values.
# @param whole_impurity [Float] impurity of the node being split.
# @return [Array<Float>] [left impurity, right impurity, threshold, gain].
def best_split(features, y, whole_impurity)
  sort_ids = features.sort_index
  feature_vals = features[sort_ids].to_a
  target_rows = y[sort_ids, true].to_a
  find_split_params(@params[:criterion], whole_impurity,
                    feature_vals, target_rows, feature_vals.uniq)
end
|
134
|
+
|
135
|
+
# Impurity of a node's target values, computed by the C extension's
# node_impurity with the configured criterion.
#
# @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) target values.
# @return [Float] impurity
def impurity(y)
  target_rows = y.to_a
  node_impurity(@params[:criterion], target_rows)
end
|
133
138
|
end
|
134
139
|
end
|
data/lib/rumale/version.rb
CHANGED
data/rumale.gemspec
CHANGED
@@ -29,6 +29,7 @@ MSG
|
|
29
29
|
spec.bindir = 'exe'
|
30
30
|
spec.executables = spec.files.grep(%r{^exe/}) { |f| File.basename(f) }
|
31
31
|
spec.require_paths = ['lib']
|
32
|
+
spec.extensions = ['ext/rumale/extconf.rb']
|
32
33
|
|
33
34
|
spec.required_ruby_version = '>= 2.3'
|
34
35
|
|
@@ -37,5 +38,6 @@ MSG
|
|
37
38
|
spec.add_development_dependency 'bundler', '>= 1.16'
|
38
39
|
spec.add_development_dependency 'coveralls', '~> 0.8'
|
39
40
|
spec.add_development_dependency 'rake', '~> 12.0'
|
41
|
+
spec.add_development_dependency 'rake-compiler'
|
40
42
|
spec.add_development_dependency 'rspec', '~> 3.0'
|
41
43
|
end
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rumale
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.9.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-04-
|
11
|
+
date: 2019-04-22 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -66,6 +66,20 @@ dependencies:
|
|
66
66
|
- - "~>"
|
67
67
|
- !ruby/object:Gem::Version
|
68
68
|
version: '12.0'
|
69
|
+
- !ruby/object:Gem::Dependency
|
70
|
+
name: rake-compiler
|
71
|
+
requirement: !ruby/object:Gem::Requirement
|
72
|
+
requirements:
|
73
|
+
- - ">="
|
74
|
+
- !ruby/object:Gem::Version
|
75
|
+
version: '0'
|
76
|
+
type: :development
|
77
|
+
prerelease: false
|
78
|
+
version_requirements: !ruby/object:Gem::Requirement
|
79
|
+
requirements:
|
80
|
+
- - ">="
|
81
|
+
- !ruby/object:Gem::Version
|
82
|
+
version: '0'
|
69
83
|
- !ruby/object:Gem::Dependency
|
70
84
|
name: rspec
|
71
85
|
requirement: !ruby/object:Gem::Requirement
|
@@ -90,7 +104,8 @@ description: |
|
|
90
104
|
email:
|
91
105
|
- yoshoku@outlook.com
|
92
106
|
executables: []
|
93
|
-
extensions:
|
107
|
+
extensions:
|
108
|
+
- ext/rumale/extconf.rb
|
94
109
|
extra_rdoc_files: []
|
95
110
|
files:
|
96
111
|
- ".coveralls.yml"
|
@@ -107,6 +122,9 @@ files:
|
|
107
122
|
- Rakefile
|
108
123
|
- bin/console
|
109
124
|
- bin/setup
|
125
|
+
- ext/rumale/extconf.rb
|
126
|
+
- ext/rumale/rumale.c
|
127
|
+
- ext/rumale/rumale.h
|
110
128
|
- lib/rumale.rb
|
111
129
|
- lib/rumale/base/base_estimator.rb
|
112
130
|
- lib/rumale/base/classifier.rb
|