rumale 0.8.4 → 0.9.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +1 -0
- data/CHANGELOG.md +5 -0
- data/Rakefile +9 -1
- data/ext/rumale/extconf.rb +3 -0
- data/ext/rumale/rumale.c +418 -0
- data/ext/rumale/rumale.h +9 -0
- data/lib/rumale.rb +2 -0
- data/lib/rumale/tree/base_decision_tree.rb +10 -18
- data/lib/rumale/tree/decision_tree_classifier.rb +15 -9
- data/lib/rumale/tree/decision_tree_regressor.rb +13 -8
- data/lib/rumale/version.rb +1 -1
- data/rumale.gemspec +2 -0
- metadata +21 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA1:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: a2dfbc60c9d47e741fc91497f8c58ade390e6c8f
|
4
|
+
data.tar.gz: d4cbc26e0d81fbe0de5e83d785cc836e9a5b2099
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 7f2b4b8ba5d7511215a2e850add19f0942cbff4157a8373eba1950c0eac9fcd0e44925d3a88b2a709c0308ef4c03cca44c501b710f4a22dc4dd573e6866d94dc
|
7
|
+
data.tar.gz: 4630710eef59af88274e9a411a6ad12de7e4a616280f8fc94d185e24c7bc667bf8c1f662425c64cf05f6ec9accd914ac32e1039688d09629b920329ad85354c8
|
data/.gitignore
CHANGED
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,8 @@
|
|
1
|
+
# 0.9.0
|
2
|
+
## Breaking changes
|
3
|
+
- Introduce Ruby extensions (written in C) to improve performance.
|
4
|
+
- Change decision tree estimators to find split points using the extension modules.
|
5
|
+
|
1
6
|
# 0.8.4
|
2
7
|
- Remove unused parameter on Nadam.
|
3
8
|
- Fix the condition for stopping tree growth in decision trees.
|
data/Rakefile
CHANGED
@@ -3,4 +3,12 @@ require 'rspec/core/rake_task'
|
|
3
3
|
|
4
4
|
RSpec::Core::RakeTask.new(:spec)

require 'rake/extensiontask'

# Make the `build` task depend on compiling the C extension.
task build: :compile

# Compile ext/rumale and place the built library under lib/rumale,
# which is where `require 'rumale/rumale'` looks for it.
Rake::ExtensionTask.new('rumale') do |extension|
  extension.lib_dir = 'lib/rumale'
end

# Default task: clobber, then compile the extension, then run the specs.
task default: %i[clobber compile spec]
|
data/ext/rumale/rumale.c
ADDED
@@ -0,0 +1,418 @@
|
|
1
|
+
#include "rumale.h"
|
2
|
+
|
3
|
+
VALUE
|
4
|
+
create_zero_vector(const long n_dimensions)
|
5
|
+
{
|
6
|
+
long i;
|
7
|
+
VALUE vec = rb_ary_new2(n_dimensions);
|
8
|
+
|
9
|
+
for (i = 0; i < n_dimensions; i++) {
|
10
|
+
rb_ary_store(vec, i, DBL2NUM(0));
|
11
|
+
}
|
12
|
+
|
13
|
+
return vec;
|
14
|
+
}
|
15
|
+
|
16
|
+
double
|
17
|
+
calc_gini_coef(VALUE histogram, const long n_elements)
|
18
|
+
{
|
19
|
+
long i;
|
20
|
+
double el;
|
21
|
+
double gini = 0.0;
|
22
|
+
const long n_classes = RARRAY_LEN(histogram);
|
23
|
+
|
24
|
+
for (i = 0; i < n_classes; i++) {
|
25
|
+
el = NUM2DBL(rb_ary_entry(histogram, i)) / n_elements;
|
26
|
+
gini += el * el;
|
27
|
+
}
|
28
|
+
|
29
|
+
return 1.0 - gini;
|
30
|
+
}
|
31
|
+
|
32
|
+
double
|
33
|
+
calc_entropy(VALUE histogram, const long n_elements)
|
34
|
+
{
|
35
|
+
long i;
|
36
|
+
double el;
|
37
|
+
double entropy = 0.0;
|
38
|
+
const long n_classes = RARRAY_LEN(histogram);
|
39
|
+
|
40
|
+
for (i = 0; i < n_classes; i++) {
|
41
|
+
el = NUM2DBL(rb_ary_entry(histogram, i)) / n_elements;
|
42
|
+
entropy += el * log(el + 1.0);
|
43
|
+
}
|
44
|
+
|
45
|
+
return -entropy;
|
46
|
+
}
|
47
|
+
|
48
|
+
VALUE
|
49
|
+
calc_mean_vec(VALUE sum_vec, const long n_elements)
|
50
|
+
{
|
51
|
+
long i;
|
52
|
+
const long n_dimensions = RARRAY_LEN(sum_vec);
|
53
|
+
VALUE mean_vec = rb_ary_new2(n_dimensions);
|
54
|
+
|
55
|
+
for (i = 0; i < n_dimensions; i++) {
|
56
|
+
rb_ary_store(mean_vec, i, DBL2NUM(NUM2DBL(rb_ary_entry(sum_vec, i)) / n_elements));
|
57
|
+
}
|
58
|
+
|
59
|
+
return mean_vec;
|
60
|
+
}
|
61
|
+
|
62
|
+
double
|
63
|
+
calc_vec_mae(VALUE vec_a, VALUE vec_b)
|
64
|
+
{
|
65
|
+
long i;
|
66
|
+
const long n_dimensions = RARRAY_LEN(vec_a);
|
67
|
+
double sum = 0.0;
|
68
|
+
double diff;
|
69
|
+
|
70
|
+
for (i = 0; i < n_dimensions; i++) {
|
71
|
+
diff = NUM2DBL(rb_ary_entry(vec_a, i)) - NUM2DBL(rb_ary_entry(vec_b, i));
|
72
|
+
sum += fabs(diff);
|
73
|
+
}
|
74
|
+
|
75
|
+
return sum / n_dimensions;
|
76
|
+
}
|
77
|
+
|
78
|
+
double
|
79
|
+
calc_vec_mse(VALUE vec_a, VALUE vec_b)
|
80
|
+
{
|
81
|
+
long i;
|
82
|
+
const long n_dimensions = RARRAY_LEN(vec_a);
|
83
|
+
double sum = 0.0;
|
84
|
+
double diff;
|
85
|
+
|
86
|
+
for (i = 0; i < n_dimensions; i++) {
|
87
|
+
diff = NUM2DBL(rb_ary_entry(vec_a, i)) - NUM2DBL(rb_ary_entry(vec_b, i));
|
88
|
+
sum += diff * diff;
|
89
|
+
}
|
90
|
+
|
91
|
+
return sum / n_dimensions;
|
92
|
+
}
|
93
|
+
|
94
|
+
double
|
95
|
+
calc_mae(VALUE target_vecs, VALUE sum_vec)
|
96
|
+
{
|
97
|
+
long i;
|
98
|
+
const long n_elements = RARRAY_LEN(target_vecs);
|
99
|
+
double sum = 0.0;
|
100
|
+
VALUE mean_vec = calc_mean_vec(sum_vec, n_elements);
|
101
|
+
|
102
|
+
for (i = 0; i < n_elements; i++) {
|
103
|
+
sum += calc_vec_mae(rb_ary_entry(target_vecs, i), mean_vec);
|
104
|
+
}
|
105
|
+
|
106
|
+
return sum / n_elements;
|
107
|
+
}
|
108
|
+
|
109
|
+
double
|
110
|
+
calc_mse(VALUE target_vecs, VALUE sum_vec)
|
111
|
+
{
|
112
|
+
long i;
|
113
|
+
const long n_elements = RARRAY_LEN(target_vecs);
|
114
|
+
double sum = 0.0;
|
115
|
+
VALUE mean_vec = calc_mean_vec(sum_vec, n_elements);
|
116
|
+
|
117
|
+
for (i = 0; i < n_elements; i++) {
|
118
|
+
sum += calc_vec_mse(rb_ary_entry(target_vecs, i), mean_vec);
|
119
|
+
}
|
120
|
+
|
121
|
+
return sum / n_elements;
|
122
|
+
}
|
123
|
+
|
124
|
+
double
|
125
|
+
calc_impurity_cls(VALUE criterion, VALUE histogram, const long n_elements)
|
126
|
+
{
|
127
|
+
if (strcmp(StringValuePtr(criterion), "entropy") == 0) {
|
128
|
+
return calc_entropy(histogram, n_elements);
|
129
|
+
}
|
130
|
+
return calc_gini_coef(histogram, n_elements);
|
131
|
+
}
|
132
|
+
|
133
|
+
double
|
134
|
+
calc_impurity_reg(VALUE criterion, VALUE target_vecs, VALUE sum_vec)
|
135
|
+
{
|
136
|
+
if (strcmp(StringValuePtr(criterion), "mae") == 0) {
|
137
|
+
return calc_mae(target_vecs, sum_vec);
|
138
|
+
}
|
139
|
+
return calc_mse(target_vecs, sum_vec);
|
140
|
+
}
|
141
|
+
|
142
|
+
void
|
143
|
+
increment_histogram(VALUE histogram, const long bin_id)
|
144
|
+
{
|
145
|
+
const double updated = NUM2DBL(rb_ary_entry(histogram, bin_id)) + 1;
|
146
|
+
rb_ary_store(histogram, bin_id, DBL2NUM(updated));
|
147
|
+
}
|
148
|
+
|
149
|
+
void
|
150
|
+
decrement_histogram(VALUE histogram, const long bin_id)
|
151
|
+
{
|
152
|
+
const double updated = NUM2DBL(rb_ary_entry(histogram, bin_id)) - 1;
|
153
|
+
rb_ary_store(histogram, bin_id, DBL2NUM(updated));
|
154
|
+
}
|
155
|
+
|
156
|
+
void
|
157
|
+
add_sum_vec(VALUE sum_vec, VALUE target)
|
158
|
+
{
|
159
|
+
long i;
|
160
|
+
const long n_dimensions = RARRAY_LEN(sum_vec);
|
161
|
+
double el;
|
162
|
+
|
163
|
+
for (i = 0; i < n_dimensions; i++) {
|
164
|
+
el = NUM2DBL(rb_ary_entry(sum_vec, i)) + NUM2DBL(rb_ary_entry(target, i));
|
165
|
+
rb_ary_store(sum_vec, i, DBL2NUM(el));
|
166
|
+
}
|
167
|
+
}
|
168
|
+
|
169
|
+
void
|
170
|
+
sub_sum_vec(VALUE sum_vec, VALUE target)
|
171
|
+
{
|
172
|
+
long i;
|
173
|
+
const long n_dimensions = RARRAY_LEN(sum_vec);
|
174
|
+
double el;
|
175
|
+
|
176
|
+
for (i = 0; i < n_dimensions; i++) {
|
177
|
+
el = NUM2DBL(rb_ary_entry(sum_vec, i)) - NUM2DBL(rb_ary_entry(target, i));
|
178
|
+
rb_ary_store(sum_vec, i, DBL2NUM(el));
|
179
|
+
}
|
180
|
+
}
|
181
|
+
|
182
|
+
/**
|
183
|
+
* @!visibility private
|
184
|
+
* Find for split point with maximum information gain.
|
185
|
+
*
|
186
|
+
* @overload find_split_params(criterion, impurity, sorted_features, sorted_labels, uniqed_features, n_classes) -> Array<Float>
|
187
|
+
*
|
188
|
+
* @param criterion [String] The function to evaluate spliting point. Supported criteria are 'gini' and 'entropy'.
|
189
|
+
* @param impurity [Float] The impurity of whole dataset.
|
190
|
+
* @param sorted_features [Numo::DFloat] (shape: [n_samples]) The feature values sorted in ascending order.
|
191
|
+
* @param sorted_labels [Numo::Int32] (shape: [n_labels]) The labels sorted according to feature values.
|
192
|
+
* @param uniqed_features [Numo::DFloat] (shape: [n_uniqed_features]) The unique feature values.
|
193
|
+
* @param n_classes [Integer] The number of classes.
|
194
|
+
* @return [Float] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
|
195
|
+
*/
|
196
|
+
static VALUE
|
197
|
+
find_split_params_cls(VALUE self, VALUE criterion, VALUE whole_impurity, VALUE sorted_f, VALUE sorted_y, VALUE uniqed_f, VALUE n_classes_)
|
198
|
+
{
|
199
|
+
long i;
|
200
|
+
long curr_pos;
|
201
|
+
long next_pos;
|
202
|
+
long n_l_elements;
|
203
|
+
long n_r_elements;
|
204
|
+
const long n_classes = NUM2LONG(n_classes_);
|
205
|
+
const long n_elements = RARRAY_LEN(sorted_f);
|
206
|
+
const long n_uniq_elements = RARRAY_LEN(uniqed_f);
|
207
|
+
const double w_impurity = NUM2DBL(whole_impurity);
|
208
|
+
double l_impurity;
|
209
|
+
double r_impurity;
|
210
|
+
double gain;
|
211
|
+
double curr_el;
|
212
|
+
double next_el;
|
213
|
+
VALUE l_histogram = create_zero_vector(n_classes);
|
214
|
+
VALUE r_histogram = create_zero_vector(n_classes);
|
215
|
+
VALUE opt_params = rb_ary_new2(4);
|
216
|
+
|
217
|
+
/* Initialize optimal parameters. */
|
218
|
+
rb_ary_store(opt_params, 0, DBL2NUM(0)); /* left impurity */
|
219
|
+
rb_ary_store(opt_params, 1, DBL2NUM(w_impurity)); /* right impurity */
|
220
|
+
rb_ary_store(opt_params, 2, rb_ary_entry(uniqed_f, 0)); /* threshold */
|
221
|
+
rb_ary_store(opt_params, 3, DBL2NUM(0)); /* gain */
|
222
|
+
|
223
|
+
/* Initialize child node variables. */
|
224
|
+
n_l_elements = 0;
|
225
|
+
n_r_elements = n_elements;
|
226
|
+
for (i = 0; i < n_elements; i++) {
|
227
|
+
increment_histogram(r_histogram, NUM2LONG(rb_ary_entry(sorted_y, i)));
|
228
|
+
}
|
229
|
+
|
230
|
+
/* Find optimal parameters. */
|
231
|
+
for (curr_pos = 0, next_pos = 0; curr_pos < n_uniq_elements - 1; curr_pos++) {
|
232
|
+
/* Find new split point. */
|
233
|
+
curr_el = NUM2DBL(rb_ary_entry(uniqed_f, curr_pos));
|
234
|
+
next_el = NUM2DBL(rb_ary_entry(sorted_f, next_pos));
|
235
|
+
while (next_pos < n_elements && next_el <= curr_el) {
|
236
|
+
increment_histogram(l_histogram, NUM2LONG(rb_ary_entry(sorted_y, next_pos)));
|
237
|
+
n_l_elements++;
|
238
|
+
decrement_histogram(r_histogram, NUM2LONG(rb_ary_entry(sorted_y, next_pos)));
|
239
|
+
n_r_elements--;
|
240
|
+
next_el = NUM2DBL(rb_ary_entry(sorted_f, ++next_pos));
|
241
|
+
}
|
242
|
+
/* Calculate gain of new split. */
|
243
|
+
l_impurity = calc_impurity_cls(criterion, l_histogram, n_l_elements);
|
244
|
+
r_impurity = calc_impurity_cls(criterion, r_histogram, n_r_elements);
|
245
|
+
gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / n_elements;
|
246
|
+
/* Update optimal parameters. */
|
247
|
+
if (gain > NUM2DBL(rb_ary_entry(opt_params, 3))) {
|
248
|
+
rb_ary_store(opt_params, 0, DBL2NUM(l_impurity));
|
249
|
+
rb_ary_store(opt_params, 1, DBL2NUM(r_impurity));
|
250
|
+
rb_ary_store(opt_params, 2, DBL2NUM(0.5 * (curr_el + next_el)));
|
251
|
+
rb_ary_store(opt_params, 3, DBL2NUM(gain));
|
252
|
+
}
|
253
|
+
}
|
254
|
+
|
255
|
+
return opt_params;
|
256
|
+
}
|
257
|
+
|
258
|
+
/**
|
259
|
+
* @!visibility private
|
260
|
+
* Find for split point with maximum information gain.
|
261
|
+
*
|
262
|
+
* @overload find_split_params(criterion, impurity, sorted_features, sorted_targets, uniqed_features) -> Array<Float>
|
263
|
+
*
|
264
|
+
* @param criterion [String] The function to evaluate spliting point. Supported criteria are 'mae' and 'mse'.
|
265
|
+
* @param impurity [Float] The impurity of whole dataset.
|
266
|
+
* @param sorted_features [Numo::DFloat] (shape: [n_samples]) The feature values sorted in ascending order.
|
267
|
+
* @param sorted_targets [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values sorted according to feature values.
|
268
|
+
* @param uniqed_features [Numo::DFloat] (shape: [n_uniqed_features]) The unique feature values.
|
269
|
+
* @return [Float] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
|
270
|
+
*/
|
271
|
+
static VALUE
|
272
|
+
find_split_params_reg(VALUE self, VALUE criterion, VALUE whole_impurity, VALUE sorted_f, VALUE sorted_y, VALUE uniqed_f)
|
273
|
+
{
|
274
|
+
long i;
|
275
|
+
long curr_pos;
|
276
|
+
long next_pos;
|
277
|
+
long n_l_elements;
|
278
|
+
long n_r_elements;
|
279
|
+
const long n_elements = RARRAY_LEN(sorted_f);
|
280
|
+
const long n_uniq_elements = RARRAY_LEN(uniqed_f);
|
281
|
+
const long n_dimensions = RARRAY_LEN(rb_ary_entry(sorted_y, 0));
|
282
|
+
const double w_impurity = NUM2DBL(whole_impurity);
|
283
|
+
double l_impurity;
|
284
|
+
double r_impurity;
|
285
|
+
double gain;
|
286
|
+
double curr_el;
|
287
|
+
double next_el;
|
288
|
+
VALUE l_sum_vec = create_zero_vector(n_dimensions);
|
289
|
+
VALUE r_sum_vec = create_zero_vector(n_dimensions);
|
290
|
+
VALUE l_target_vecs = rb_ary_new();
|
291
|
+
VALUE r_target_vecs = rb_ary_new();
|
292
|
+
VALUE target;
|
293
|
+
VALUE opt_params = rb_ary_new2(4);
|
294
|
+
|
295
|
+
/* Initialize optimal parameters. */
|
296
|
+
rb_ary_store(opt_params, 0, DBL2NUM(0)); /* left impurity */
|
297
|
+
rb_ary_store(opt_params, 1, DBL2NUM(w_impurity)); /* right impurity */
|
298
|
+
rb_ary_store(opt_params, 2, rb_ary_entry(uniqed_f, 0)); /* threshold */
|
299
|
+
rb_ary_store(opt_params, 3, DBL2NUM(0)); /* gain */
|
300
|
+
|
301
|
+
/* Initialize child node variables. */
|
302
|
+
n_l_elements = 0;
|
303
|
+
n_r_elements = n_elements;
|
304
|
+
for (i = 0; i < n_elements; i++) {
|
305
|
+
target = rb_ary_entry(sorted_y, i);
|
306
|
+
add_sum_vec(r_sum_vec, target);
|
307
|
+
rb_ary_push(r_target_vecs, target);
|
308
|
+
}
|
309
|
+
|
310
|
+
/* Find optimal parameters. */
|
311
|
+
for (curr_pos = 0, next_pos = 0; curr_pos < n_uniq_elements - 1; curr_pos++) {
|
312
|
+
/* Find new split point. */
|
313
|
+
curr_el = NUM2DBL(rb_ary_entry(uniqed_f, curr_pos));
|
314
|
+
next_el = NUM2DBL(rb_ary_entry(sorted_f, next_pos));
|
315
|
+
while (next_pos < n_elements && next_el <= curr_el) {
|
316
|
+
target = rb_ary_entry(sorted_y, next_pos);
|
317
|
+
add_sum_vec(l_sum_vec, target);
|
318
|
+
rb_ary_push(l_target_vecs, target);
|
319
|
+
n_l_elements++;
|
320
|
+
sub_sum_vec(r_sum_vec, target);
|
321
|
+
rb_ary_shift(r_target_vecs);
|
322
|
+
n_r_elements--;
|
323
|
+
next_el = NUM2DBL(rb_ary_entry(sorted_f, ++next_pos));
|
324
|
+
}
|
325
|
+
/* Calculate gain of new split. */
|
326
|
+
l_impurity = calc_impurity_reg(criterion, l_target_vecs, l_sum_vec);
|
327
|
+
r_impurity = calc_impurity_reg(criterion, r_target_vecs, r_sum_vec);
|
328
|
+
gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / n_elements;
|
329
|
+
/* Update optimal parameters. */
|
330
|
+
if (gain > NUM2DBL(rb_ary_entry(opt_params, 3))) {
|
331
|
+
rb_ary_store(opt_params, 0, DBL2NUM(l_impurity));
|
332
|
+
rb_ary_store(opt_params, 1, DBL2NUM(r_impurity));
|
333
|
+
rb_ary_store(opt_params, 2, DBL2NUM(0.5 * (curr_el + next_el)));
|
334
|
+
rb_ary_store(opt_params, 3, DBL2NUM(gain));
|
335
|
+
}
|
336
|
+
}
|
337
|
+
|
338
|
+
return opt_params;
|
339
|
+
}
|
340
|
+
|
341
|
+
/**
|
342
|
+
* @!visibility private
|
343
|
+
* Calculate impurity based on criterion.
|
344
|
+
*
|
345
|
+
* @overload node_impurity(criterion, y, n_classes) -> Float
|
346
|
+
*
|
347
|
+
* @param criterion [String] The function to calculate impurity. Supported criteria are 'gini' and 'entropy'.
|
348
|
+
* @param y [Numo::Int32] (shape: [n_samples]) The labels.
|
349
|
+
* @param n_classes [Integer] The number of classes.
|
350
|
+
* @return [Float] impurity
|
351
|
+
*/
|
352
|
+
static VALUE
|
353
|
+
node_impurity_cls(VALUE self, VALUE criterion, VALUE y, VALUE n_classes)
|
354
|
+
{
|
355
|
+
long i;
|
356
|
+
const long n_elements = RARRAY_LEN(y);
|
357
|
+
VALUE histogram = create_zero_vector(NUM2LONG(n_classes));
|
358
|
+
|
359
|
+
for (i = 0; i < n_elements; i++) {
|
360
|
+
increment_histogram(histogram, NUM2LONG(rb_ary_entry(y, i)));
|
361
|
+
}
|
362
|
+
|
363
|
+
return DBL2NUM(calc_impurity_cls(criterion, histogram, n_elements));
|
364
|
+
}
|
365
|
+
|
366
|
+
/**
|
367
|
+
* @!visibility private
|
368
|
+
* Calculate impurity based on criterion.
|
369
|
+
*
|
370
|
+
* @overload node_impurity(criterion, y) -> Float
|
371
|
+
*
|
372
|
+
* @param criterion [String] The function to calculate impurity. Supported criteria are 'mae' and 'mse'.
|
373
|
+
* @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The taget values.
|
374
|
+
* @return [Float] impurity
|
375
|
+
*/
|
376
|
+
static VALUE
|
377
|
+
node_impurity_reg(VALUE self, VALUE criterion, VALUE y)
|
378
|
+
{
|
379
|
+
long i;
|
380
|
+
const long n_elements = RARRAY_LEN(y);
|
381
|
+
const long n_dimensions = RARRAY_LEN(rb_ary_entry(y, 0));
|
382
|
+
VALUE sum_vec = create_zero_vector(n_dimensions);
|
383
|
+
VALUE target_vecs = rb_ary_new();
|
384
|
+
VALUE target;
|
385
|
+
|
386
|
+
for (i = 0; i < n_elements; i++) {
|
387
|
+
target = rb_ary_entry(y, i);
|
388
|
+
add_sum_vec(sum_vec, target);
|
389
|
+
rb_ary_push(target_vecs, target);
|
390
|
+
}
|
391
|
+
|
392
|
+
return DBL2NUM(calc_impurity_reg(criterion, target_vecs, sum_vec));
|
393
|
+
}
|
394
|
+
|
395
|
+
void Init_rumale(void)
|
396
|
+
{
|
397
|
+
VALUE mRumale = rb_define_module("Rumale");
|
398
|
+
VALUE mTree = rb_define_module_under(mRumale, "Tree");
|
399
|
+
/**
|
400
|
+
* Document-module: Rumale::Tree::ExtDecisionTreeClassifier
|
401
|
+
* @!visibility private
|
402
|
+
* The mixin module consisting of extension method for DecisionTreeClassifier class.
|
403
|
+
* This module is used internally.
|
404
|
+
*/
|
405
|
+
VALUE mExtDTreeCls = rb_define_module_under(mTree, "ExtDecisionTreeClassifier");
|
406
|
+
/**
|
407
|
+
* Document-module: Rumale::Tree::ExtDecisionTreeRegressor
|
408
|
+
* @!visibility private
|
409
|
+
* The mixin module consisting of extension method for DecisionTreeRegressor class.
|
410
|
+
* This module is used internally.
|
411
|
+
*/
|
412
|
+
VALUE mExtDTreeReg = rb_define_module_under(mTree, "ExtDecisionTreeRegressor");
|
413
|
+
|
414
|
+
rb_define_method(mExtDTreeCls, "find_split_params", find_split_params_cls, 6);
|
415
|
+
rb_define_method(mExtDTreeReg, "find_split_params", find_split_params_reg, 5);
|
416
|
+
rb_define_method(mExtDTreeCls, "node_impurity", node_impurity_cls, 3);
|
417
|
+
rb_define_method(mExtDTreeReg, "node_impurity", node_impurity_reg, 2);
|
418
|
+
}
|
data/ext/rumale/rumale.h
ADDED
data/lib/rumale.rb
CHANGED
@@ -86,13 +86,16 @@ module Rumale
|
|
86
86
|
return put_leaf(node, y) if stop_growing?(y)
|
87
87
|
|
88
88
|
# calculate optimal parameters.
|
89
|
-
feature_id,
|
90
|
-
|
89
|
+
feature_id, left_ids, right_ids, left_imp, right_imp, threshold, gain = rand_ids(n_features).map do |fid|
|
90
|
+
ft = x[true, fid]
|
91
|
+
limp, rimp, th, ga = best_split(ft, y, whole_impurity)
|
92
|
+
[fid, ft.le(th).where, ft.gt(th).where, limp, rimp, th, ga]
|
93
|
+
end.max_by(&:last)
|
91
94
|
|
92
95
|
return put_leaf(node, y) if gain.nil? || gain.zero?
|
93
96
|
|
94
|
-
node.left = grow_node(depth + 1, x[left_ids, true], y[left_ids, true],
|
95
|
-
node.right = grow_node(depth + 1, x[right_ids, true], y[right_ids, true],
|
97
|
+
node.left = grow_node(depth + 1, x[left_ids, true], y[left_ids, true], left_imp)
|
98
|
+
node.right = grow_node(depth + 1, x[right_ids, true], y[right_ids, true], right_imp)
|
96
99
|
|
97
100
|
return put_leaf(node, y) if node.left.nil? && node.right.nil?
|
98
101
|
|
@@ -114,22 +117,11 @@ module Rumale
|
|
114
117
|
[*0...n].sample(@params[:max_features], random: @rng)
|
115
118
|
end
|
116
119
|
|
117
|
-
def best_split(
|
118
|
-
|
119
|
-
features.to_a.uniq.sort.each_cons(2).map do |l, r|
|
120
|
-
threshold = 0.5 * (l + r)
|
121
|
-
left_ids = features.le(threshold).where
|
122
|
-
right_ids = features.gt(threshold).where
|
123
|
-
left_impurity = impurity(targets[left_ids, true])
|
124
|
-
right_impurity = impurity(targets[right_ids, true])
|
125
|
-
gain = whole_impurity -
|
126
|
-
left_impurity * left_ids.size.fdiv(n_samples) -
|
127
|
-
right_impurity * right_ids.size.fdiv(n_samples)
|
128
|
-
[threshold, left_ids, right_ids, left_impurity, right_impurity, gain]
|
129
|
-
end.max_by(&:last)
|
120
|
+
# Abstract hook: given one feature column, the targets and the node's impurity,
# subclasses return [left_impurity, right_impurity, threshold, gain].
def best_split(_features, _y, _impurity)
  message = "#{__method__} has to be implemented in #{self.class}."
  raise NotImplementedError, message
end
|
131
123
|
|
132
|
-
def impurity(
|
124
|
+
# Abstract hook: subclasses return the impurity of the given targets.
def impurity(_y)
  message = "#{__method__} has to be implemented in #{self.class}."
  raise NotImplementedError, message
end
|
135
127
|
|
@@ -1,5 +1,6 @@
|
|
1
1
|
# frozen_string_literal: true
|
2
2
|
|
3
|
+
require 'rumale/rumale'
|
3
4
|
require 'rumale/tree/base_decision_tree'
|
4
5
|
require 'rumale/base/classifier'
|
5
6
|
|
@@ -16,6 +17,7 @@ module Rumale
|
|
16
17
|
#
|
17
18
|
class DecisionTreeClassifier < BaseDecisionTree
|
18
19
|
include Base::Classifier
|
20
|
+
include ExtDecisionTreeClassifier
|
19
21
|
|
20
22
|
# Return the class labels.
|
21
23
|
# @return [Numo::Int32] (size: n_classes)
|
@@ -39,7 +41,7 @@ module Rumale
|
|
39
41
|
|
40
42
|
# Create a new classifier with decision tree algorithm.
|
41
43
|
#
|
42
|
-
# @param criterion [String] The function to
|
44
|
+
# @param criterion [String] The function to evaluate splitting points. Supported criteria are 'gini' and 'entropy'.
|
43
45
|
# @param max_depth [Integer] The maximum depth of the tree.
|
44
46
|
# If nil is given, decision tree grows without concern for depth.
|
45
47
|
# @param max_leaf_nodes [Integer] The maximum number of leaves on decision tree.
|
@@ -89,7 +91,7 @@ module Rumale
|
|
89
91
|
# @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
|
90
92
|
def predict(x)
  check_sample_array(x)
  leaf_ids = apply(x)
  # Return a copy of the indexed labels (presumably so the result does not
  # alias @leaf_labels).
  @leaf_labels[leaf_ids].dup
end
|
94
96
|
|
95
97
|
# Predict probability for samples.
|
@@ -138,7 +140,7 @@ module Rumale
|
|
138
140
|
end
|
139
141
|
|
140
142
|
# True when every sample in the node has the same label (column 0 of y).
def stop_growing?(y)
  labels = y[true, 0].to_a
  labels.uniq.length == 1
end
|
143
145
|
|
144
146
|
def put_leaf(node, y)
|
@@ -150,13 +152,17 @@ module Rumale
|
|
150
152
|
node
|
151
153
|
end
|
152
154
|
|
155
|
+
# Sort one feature column (carrying the labels along with it) and delegate the
# threshold search to the extension's find_split_params.
# Returns [left_impurity, right_impurity, threshold, gain].
def best_split(features, y, whole_impurity)
  order = features.sort_index
  sorted_features = features[order].to_a
  sorted_labels = y[order, true].to_a.flatten
  find_split_params(
    @params[:criterion], whole_impurity,
    sorted_features, sorted_labels, sorted_features.uniq,
    @classes.size
  )
end
|
162
|
+
|
153
163
|
# Impurity of a node's labels (column 0 of y) computed by the extension's node_impurity.
def impurity(y)
  labels = y[true, 0].to_a
  node_impurity(@params[:criterion], labels, @classes.size)
end
|
161
167
|
end
|
162
168
|
end
|
@@ -1,5 +1,6 @@
|
|
1
1
|
# frozen_string_literal: true
|
2
2
|
|
3
|
+
require 'rumale/rumale'
|
3
4
|
require 'rumale/tree/base_decision_tree'
|
4
5
|
require 'rumale/base/regressor'
|
5
6
|
|
@@ -16,6 +17,7 @@ module Rumale
|
|
16
17
|
#
|
17
18
|
class DecisionTreeRegressor < BaseDecisionTree
|
18
19
|
include Base::Regressor
|
20
|
+
include ExtDecisionTreeRegressor
|
19
21
|
|
20
22
|
# Return the importance for each feature.
|
21
23
|
# @return [Numo::DFloat] (size: n_features)
|
@@ -35,7 +37,7 @@ module Rumale
|
|
35
37
|
|
36
38
|
# Create a new regressor with decision tree algorithm.
|
37
39
|
#
|
38
|
-
# @param criterion [String] The function to
|
40
|
+
# @param criterion [String] The function to evaluate splitting points. Supported criteria are 'mae' and 'mse'.
|
39
41
|
# @param max_depth [Integer] The maximum depth of the tree.
|
40
42
|
# If nil is given, decision tree grows without concern for depth.
|
41
43
|
# @param max_leaf_nodes [Integer] The maximum number of leaves on decision tree.
|
@@ -84,7 +86,7 @@ module Rumale
|
|
84
86
|
# @return [Numo::DFloat] (shape: [n_samples, n_outputs]) Predicted values per sample.
|
85
87
|
def predict(x)
  check_sample_array(x)
  leaf_ids = apply(x)
  # Single-output models store a 1-D @leaf_values (shape[1] is nil); multi-output
  # models need the whole row. Either way, return a copy of the indexed values.
  if @leaf_values.shape[1].nil?
    @leaf_values[leaf_ids].dup
  else
    @leaf_values[leaf_ids, true].dup
  end
end
|
89
91
|
|
90
92
|
# Dump marshal data.
|
@@ -123,12 +125,15 @@ module Rumale
|
|
123
125
|
node
|
124
126
|
end
|
125
127
|
|
126
|
-
def
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
128
|
+
# Sort one feature column (carrying the target rows along with it) and delegate
# the threshold search to the extension's find_split_params.
# Returns [left_impurity, right_impurity, threshold, gain].
def best_split(features, y, whole_impurity)
  order = features.sort_index
  sorted_features = features[order].to_a
  sorted_targets = y[order, true].to_a
  find_split_params(
    @params[:criterion], whole_impurity,
    sorted_features, sorted_targets, sorted_features.uniq
  )
end
|
134
|
+
|
135
|
+
# Impurity of a node's target rows computed by the extension's node_impurity.
def impurity(y)
  criterion = @params[:criterion]
  targets = y.to_a
  node_impurity(criterion, targets)
end
|
133
138
|
end
|
134
139
|
end
|
data/lib/rumale/version.rb
CHANGED
data/rumale.gemspec
CHANGED
@@ -29,6 +29,7 @@ MSG
|
|
29
29
|
spec.bindir = 'exe'
|
30
30
|
spec.executables = spec.files.grep(%r{^exe/}) { |f| File.basename(f) }
|
31
31
|
spec.require_paths = ['lib']
|
32
|
+
spec.extensions = ['ext/rumale/extconf.rb']
|
32
33
|
|
33
34
|
spec.required_ruby_version = '>= 2.3'
|
34
35
|
|
@@ -37,5 +38,6 @@ MSG
|
|
37
38
|
spec.add_development_dependency 'bundler', '>= 1.16'
|
38
39
|
spec.add_development_dependency 'coveralls', '~> 0.8'
|
39
40
|
spec.add_development_dependency 'rake', '~> 12.0'
|
41
|
+
spec.add_development_dependency 'rake-compiler'
|
40
42
|
spec.add_development_dependency 'rspec', '~> 3.0'
|
41
43
|
end
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rumale
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.9.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-04-
|
11
|
+
date: 2019-04-22 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -66,6 +66,20 @@ dependencies:
|
|
66
66
|
- - "~>"
|
67
67
|
- !ruby/object:Gem::Version
|
68
68
|
version: '12.0'
|
69
|
+
- !ruby/object:Gem::Dependency
|
70
|
+
name: rake-compiler
|
71
|
+
requirement: !ruby/object:Gem::Requirement
|
72
|
+
requirements:
|
73
|
+
- - ">="
|
74
|
+
- !ruby/object:Gem::Version
|
75
|
+
version: '0'
|
76
|
+
type: :development
|
77
|
+
prerelease: false
|
78
|
+
version_requirements: !ruby/object:Gem::Requirement
|
79
|
+
requirements:
|
80
|
+
- - ">="
|
81
|
+
- !ruby/object:Gem::Version
|
82
|
+
version: '0'
|
69
83
|
- !ruby/object:Gem::Dependency
|
70
84
|
name: rspec
|
71
85
|
requirement: !ruby/object:Gem::Requirement
|
@@ -90,7 +104,8 @@ description: |
|
|
90
104
|
email:
|
91
105
|
- yoshoku@outlook.com
|
92
106
|
executables: []
|
93
|
-
extensions:
|
107
|
+
extensions:
|
108
|
+
- ext/rumale/extconf.rb
|
94
109
|
extra_rdoc_files: []
|
95
110
|
files:
|
96
111
|
- ".coveralls.yml"
|
@@ -107,6 +122,9 @@ files:
|
|
107
122
|
- Rakefile
|
108
123
|
- bin/console
|
109
124
|
- bin/setup
|
125
|
+
- ext/rumale/extconf.rb
|
126
|
+
- ext/rumale/rumale.c
|
127
|
+
- ext/rumale/rumale.h
|
110
128
|
- lib/rumale.rb
|
111
129
|
- lib/rumale/base/base_estimator.rb
|
112
130
|
- lib/rumale/base/classifier.rb
|