rumale-tree 0.24.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: 5f7d402d2787e61a76bc2d22d19da29e1b8c1bd7ecfcee07621152a3c6e262ec
4
+ data.tar.gz: e5578f0c6fa00b775b9bc652a2721913bd5a767fbc01d9c70e585715f9693935
5
+ SHA512:
6
+ metadata.gz: 959421ae6500dfd8ddde1f19e19dcca84b3b53d6473be79c3b789176554d43c68c90869c49edabf8bfab726e88e596ffadfbd180ff309d234d1176dc99bc2168
7
+ data.tar.gz: a2558f3ce63f88deef4cf82c562a87914c06b83381ffbc9911b0778e01a13efb9a2147d1a6dad7a97e9e31299c93906d9c9959075dc2a96e81d58ed3f7c6e00d
data/LICENSE.txt ADDED
@@ -0,0 +1,27 @@
1
+ Copyright (c) 2022 Atsushi Tatsuma
2
+ All rights reserved.
3
+
4
+ Redistribution and use in source and binary forms, with or without
5
+ modification, are permitted provided that the following conditions are met:
6
+
7
+ * Redistributions of source code must retain the above copyright notice, this
8
+ list of conditions and the following disclaimer.
9
+
10
+ * Redistributions in binary form must reproduce the above copyright notice,
11
+ this list of conditions and the following disclaimer in the documentation
12
+ and/or other materials provided with the distribution.
13
+
14
+ * Neither the name of the copyright holder nor the names of its
15
+ contributors may be used to endorse or promote products derived from
16
+ this software without specific prior written permission.
17
+
18
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
data/README.md ADDED
@@ -0,0 +1,33 @@
1
+ # Rumale::Tree
2
+
3
+ [![Gem Version](https://badge.fury.io/rb/rumale-tree.svg)](https://badge.fury.io/rb/rumale-tree)
4
+ [![BSD 3-Clause License](https://img.shields.io/badge/License-BSD%203--Clause-orange.svg)](https://github.com/yoshoku/rumale/blob/main/rumale-tree/LICENSE.txt)
5
+ [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://yoshoku.github.io/rumale/doc/Rumale/Tree.html)
6
+
7
+ Rumale is a machine learning library in Ruby.
8
+ Rumale::Tree provides classifiers and regressors based on decision tree algorithms
9
+ with the Rumale interface.
10
+
11
+ ## Installation
12
+
13
+ Add this line to your application's Gemfile:
14
+
15
+ ```ruby
16
+ gem 'rumale-tree'
17
+ ```
18
+
19
+ And then execute:
20
+
21
+ $ bundle install
22
+
23
+ Or install it yourself as:
24
+
25
+ $ gem install rumale-tree
26
+
27
+ ## Documentation
28
+
29
+ - [Rumale API Documentation - Tree](https://yoshoku.github.io/rumale/doc/Rumale/Tree.html)
30
+
31
+ ## License
32
+
33
+ The gem is available as open source under the terms of the [BSD-3-Clause License](https://opensource.org/licenses/BSD-3-Clause).
@@ -0,0 +1,575 @@
1
+ #include "ext.h"
2
+
3
+ double* alloc_dbl_array(const long n_dimensions) {
4
+ double* arr = ALLOC_N(double, n_dimensions);
5
+ memset(arr, 0, n_dimensions * sizeof(double));
6
+ return arr;
7
+ }
8
+
9
/*
 * Gini impurity of a class histogram: 1 - sum_k (count_k / n_elements)^2.
 *
 * histogram  : per-class counts (n_classes entries).
 * n_elements : total number of elements the histogram was built from.
 * n_classes  : number of entries in histogram.
 */
double calc_gini_coef(double* histogram, const long n_elements, const long n_classes) {
  double sum_sq_prob = 0.0;
  long k = 0;

  while (k < n_classes) {
    const double prob = histogram[k] / n_elements;
    sum_sq_prob += prob * prob;
    k++;
  }

  return 1.0 - sum_sq_prob;
}
19
+
20
+ double calc_entropy(double* histogram, const long n_elements, const long n_classes) {
21
+ double entropy = 0.0;
22
+
23
+ for (long i = 0; i < n_classes; i++) {
24
+ double el = histogram[i] / n_elements;
25
+ entropy += el * log(el + 1.0);
26
+ }
27
+
28
+ return -entropy;
29
+ }
30
+
31
+ VALUE
32
+ calc_mean_vec(double* sum_vec, const long n_dimensions, const long n_elements) {
33
+ VALUE mean_vec = rb_ary_new2(n_dimensions);
34
+
35
+ for (long i = 0; i < n_dimensions; i++) {
36
+ rb_ary_store(mean_vec, i, DBL2NUM(sum_vec[i] / n_elements));
37
+ }
38
+
39
+ return mean_vec;
40
+ }
41
+
42
+ double calc_vec_mae(VALUE vec_a, VALUE vec_b) {
43
+ const long n_dimensions = RARRAY_LEN(vec_a);
44
+ double sum = 0.0;
45
+
46
+ for (long i = 0; i < n_dimensions; i++) {
47
+ double diff = NUM2DBL(rb_ary_entry(vec_a, i)) - NUM2DBL(rb_ary_entry(vec_b, i));
48
+ sum += fabs(diff);
49
+ }
50
+
51
+ return sum / n_dimensions;
52
+ }
53
+
54
+ double calc_vec_mse(VALUE vec_a, VALUE vec_b) {
55
+ const long n_dimensions = RARRAY_LEN(vec_a);
56
+ double sum = 0.0;
57
+
58
+ for (long i = 0; i < n_dimensions; i++) {
59
+ double diff = NUM2DBL(rb_ary_entry(vec_a, i)) - NUM2DBL(rb_ary_entry(vec_b, i));
60
+ sum += diff * diff;
61
+ }
62
+
63
+ return sum / n_dimensions;
64
+ }
65
+
66
+ double calc_mae(VALUE target_vecs, VALUE mean_vec) {
67
+ const long n_elements = RARRAY_LEN(target_vecs);
68
+ double sum = 0.0;
69
+
70
+ for (long i = 0; i < n_elements; i++) {
71
+ sum += calc_vec_mae(rb_ary_entry(target_vecs, i), mean_vec);
72
+ }
73
+
74
+ return sum / n_elements;
75
+ }
76
+
77
+ double calc_mse(VALUE target_vecs, VALUE mean_vec) {
78
+ const long n_elements = RARRAY_LEN(target_vecs);
79
+ double sum = 0.0;
80
+
81
+ for (long i = 0; i < n_elements; i++) {
82
+ sum += calc_vec_mse(rb_ary_entry(target_vecs, i), mean_vec);
83
+ }
84
+
85
+ return sum / n_elements;
86
+ }
87
+
88
/*
 * Classification impurity of a class histogram.
 * "entropy" selects calc_entropy; any other criterion name falls back to the
 * Gini coefficient.
 */
double calc_impurity_cls(const char* criterion, double* histogram, const long n_elements, const long n_classes) {
  const int use_entropy = (strcmp(criterion, "entropy") == 0);

  return use_entropy ? calc_entropy(histogram, n_elements, n_classes)
                     : calc_gini_coef(histogram, n_elements, n_classes);
}
94
+
95
+ double calc_impurity_reg(const char* criterion, VALUE target_vecs, double* sum_vec) {
96
+ const long n_elements = RARRAY_LEN(target_vecs);
97
+ const long n_dimensions = RARRAY_LEN(rb_ary_entry(target_vecs, 0));
98
+ VALUE mean_vec = calc_mean_vec(sum_vec, n_dimensions, n_elements);
99
+
100
+ if (strcmp(criterion, "mae") == 0) {
101
+ return calc_mae(target_vecs, mean_vec);
102
+ }
103
+ return calc_mse(target_vecs, mean_vec);
104
+ }
105
+
106
+ void add_sum_vec(double* sum_vec, VALUE target) {
107
+ const long n_dimensions = RARRAY_LEN(target);
108
+
109
+ for (long i = 0; i < n_dimensions; i++) {
110
+ sum_vec[i] += NUM2DBL(rb_ary_entry(target, i));
111
+ }
112
+ }
113
+
114
+ void sub_sum_vec(double* sum_vec, VALUE target) {
115
+ const long n_dimensions = RARRAY_LEN(target);
116
+
117
+ for (long i = 0; i < n_dimensions; i++) {
118
+ sum_vec[i] -= NUM2DBL(rb_ary_entry(target, i));
119
+ }
120
+ }
121
+
122
/**
 * @!visibility private
 * Options passed through na_ndloop3 to the classification split iterator.
 */
typedef struct {
  char* criterion; /* impurity criterion name ("gini" or "entropy") */
  long n_classes;  /* number of histogram bins (classes) */
  double impurity; /* impurity of the node being split */
} split_opts_cls;
130
+
131
+ /**
132
+ * @!visibility private
133
+ */
134
+ static void iter_find_split_params_cls(na_loop_t const* lp) {
135
+ const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
136
+ const double* f = (double*)NDL_PTR(lp, 1);
137
+ const int32_t* y = (int32_t*)NDL_PTR(lp, 2);
138
+ const long n_elements = NDL_SHAPE(lp, 0)[0];
139
+ const char* criterion = ((split_opts_cls*)lp->opt_ptr)->criterion;
140
+ const long n_classes = ((split_opts_cls*)lp->opt_ptr)->n_classes;
141
+ const double w_impurity = ((split_opts_cls*)lp->opt_ptr)->impurity;
142
+ double* params = (double*)NDL_PTR(lp, 3);
143
+ long curr_pos = 0;
144
+ long next_pos = 0;
145
+ long n_l_elements = 0;
146
+ long n_r_elements = n_elements;
147
+ double curr_el = f[o[0]];
148
+ double last_el = f[o[n_elements - 1]];
149
+ double next_el;
150
+ double l_impurity;
151
+ double r_impurity;
152
+ double gain;
153
+ double* l_histogram = alloc_dbl_array(n_classes);
154
+ double* r_histogram = alloc_dbl_array(n_classes);
155
+
156
+ /* Initialize optimal parameters. */
157
+ params[0] = 0.0; /* left impurity */
158
+ params[1] = w_impurity; /* right impurity */
159
+ params[2] = curr_el; /* threshold */
160
+ params[3] = 0.0; /* gain */
161
+
162
+ /* Initialize child node variables. */
163
+ for (long i = 0; i < n_elements; i++) {
164
+ r_histogram[y[o[i]]] += 1.0;
165
+ }
166
+
167
+ /* Find optimal parameters. */
168
+ while (curr_pos < n_elements && curr_el != last_el) {
169
+ next_el = f[o[next_pos]];
170
+ while (next_pos < n_elements && next_el == curr_el) {
171
+ l_histogram[y[o[next_pos]]] += 1;
172
+ n_l_elements++;
173
+ r_histogram[y[o[next_pos]]] -= 1;
174
+ n_r_elements--;
175
+ next_pos++;
176
+ next_el = f[o[next_pos]];
177
+ }
178
+ /* Calculate gain of new split. */
179
+ l_impurity = calc_impurity_cls(criterion, l_histogram, n_l_elements, n_classes);
180
+ r_impurity = calc_impurity_cls(criterion, r_histogram, n_r_elements, n_classes);
181
+ gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / n_elements;
182
+ /* Update optimal parameters. */
183
+ if (gain > params[3]) {
184
+ params[0] = l_impurity;
185
+ params[1] = r_impurity;
186
+ params[2] = 0.5 * (curr_el + next_el);
187
+ params[3] = gain;
188
+ }
189
+ if (next_pos == n_elements) break;
190
+ curr_pos = next_pos;
191
+ curr_el = f[o[curr_pos]];
192
+ }
193
+
194
+ xfree(l_histogram);
195
+ xfree(r_histogram);
196
+ }
197
+
198
+ /**
199
+ * @!visibility private
200
+ * Find for split point with maximum information gain.
201
+ *
202
+ * @overload find_split_params(criterion, impurity, order, features, labels, n_classes) -> Array<Float>
203
+ *
204
+ * @param criterion [String] The function to evaluate spliting point. Supported criteria are 'gini' and 'entropy'.
205
+ * @param impurity [Float] The impurity of whole dataset.
206
+ * @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
207
+ * @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
208
+ * @param labels [Numo::Int32] (shape: [n_elements]) The labels.
209
+ * @param n_classes [Integer] The number of classes.
210
+ * @return [Array<Float>] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
211
+ */
212
+ static VALUE find_split_params_cls(VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE labels,
213
+ VALUE n_classes) {
214
+ ndfunc_arg_in_t ain[3] = {{numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cInt32, 1}};
215
+ size_t out_shape[1] = {4};
216
+ ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 1, out_shape}};
217
+ ndfunc_t ndf = {(na_iter_func_t)iter_find_split_params_cls, NO_LOOP, 3, 1, ain, aout};
218
+ split_opts_cls opts = {StringValuePtr(criterion), NUM2LONG(n_classes), NUM2DBL(impurity)};
219
+ VALUE params = na_ndloop3(&ndf, &opts, 3, order, features, labels);
220
+ VALUE results = rb_ary_new2(4);
221
+ double* params_ptr = (double*)na_get_pointer_for_read(params);
222
+ rb_ary_store(results, 0, DBL2NUM(params_ptr[0]));
223
+ rb_ary_store(results, 1, DBL2NUM(params_ptr[1]));
224
+ rb_ary_store(results, 2, DBL2NUM(params_ptr[2]));
225
+ rb_ary_store(results, 3, DBL2NUM(params_ptr[3]));
226
+ RB_GC_GUARD(params);
227
+ RB_GC_GUARD(criterion);
228
+ return results;
229
+ }
230
+
231
/**
 * @!visibility private
 * Options passed through na_ndloop3 to the regression split iterator.
 */
typedef struct {
  char* criterion; /* impurity criterion name ("mae" or "mse") */
  double impurity; /* impurity of the node being split */
} split_opts_reg;
238
+
239
+ /**
240
+ * @!visibility private
241
+ */
242
+ static void iter_find_split_params_reg(na_loop_t const* lp) {
243
+ const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
244
+ const double* f = (double*)NDL_PTR(lp, 1);
245
+ const double* y = (double*)NDL_PTR(lp, 2);
246
+ const long n_elements = NDL_SHAPE(lp, 0)[0];
247
+ const long n_outputs = NDL_SHAPE(lp, 2)[1];
248
+ const char* criterion = ((split_opts_reg*)lp->opt_ptr)->criterion;
249
+ const double w_impurity = ((split_opts_reg*)lp->opt_ptr)->impurity;
250
+ double* params = (double*)NDL_PTR(lp, 3);
251
+ long curr_pos = 0;
252
+ long next_pos = 0;
253
+ long n_l_elements = 0;
254
+ long n_r_elements = n_elements;
255
+ double curr_el = f[o[0]];
256
+ double last_el = f[o[n_elements - 1]];
257
+ double next_el;
258
+ double l_impurity;
259
+ double r_impurity;
260
+ double gain;
261
+ double* l_sum_vec = alloc_dbl_array(n_outputs);
262
+ double* r_sum_vec = alloc_dbl_array(n_outputs);
263
+ double target_var;
264
+ VALUE l_target_vecs = rb_ary_new();
265
+ VALUE r_target_vecs = rb_ary_new();
266
+ VALUE target;
267
+
268
+ /* Initialize optimal parameters. */
269
+ params[0] = 0.0; /* left impurity */
270
+ params[1] = w_impurity; /* right impurity */
271
+ params[2] = curr_el; /* threshold */
272
+ params[3] = 0.0; /* gain */
273
+
274
+ /* Initialize child node variables. */
275
+ for (long i = 0; i < n_elements; i++) {
276
+ target = rb_ary_new2(n_outputs);
277
+ for (long j = 0; j < n_outputs; j++) {
278
+ target_var = y[o[i] * n_outputs + j];
279
+ rb_ary_store(target, j, DBL2NUM(target_var));
280
+ r_sum_vec[j] += target_var;
281
+ }
282
+ rb_ary_push(r_target_vecs, target);
283
+ }
284
+
285
+ /* Find optimal parameters. */
286
+ while (curr_pos < n_elements && curr_el != last_el) {
287
+ next_el = f[o[next_pos]];
288
+ while (next_pos < n_elements && next_el == curr_el) {
289
+ target = rb_ary_shift(r_target_vecs);
290
+ n_r_elements--;
291
+ sub_sum_vec(r_sum_vec, target);
292
+ rb_ary_push(l_target_vecs, target);
293
+ n_l_elements++;
294
+ add_sum_vec(l_sum_vec, target);
295
+ next_pos++;
296
+ next_el = f[o[next_pos]];
297
+ }
298
+ /* Calculate gain of new split. */
299
+ l_impurity = calc_impurity_reg(criterion, l_target_vecs, l_sum_vec);
300
+ r_impurity = calc_impurity_reg(criterion, r_target_vecs, r_sum_vec);
301
+ gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / n_elements;
302
+ /* Update optimal parameters. */
303
+ if (gain > params[3]) {
304
+ params[0] = l_impurity;
305
+ params[1] = r_impurity;
306
+ params[2] = 0.5 * (curr_el + next_el);
307
+ params[3] = gain;
308
+ }
309
+ if (next_pos == n_elements) break;
310
+ curr_pos = next_pos;
311
+ curr_el = f[o[curr_pos]];
312
+ }
313
+
314
+ xfree(l_sum_vec);
315
+ xfree(r_sum_vec);
316
+ }
317
+
318
+ /**
319
+ * @!visibility private
320
+ * Find for split point with maximum information gain.
321
+ *
322
+ * @overload find_split_params(criterion, impurity, order, features, targets) -> Array<Float>
323
+ *
324
+ * @param criterion [String] The function to evaluate spliting point. Supported criteria are 'mae' and 'mse'.
325
+ * @param impurity [Float] The impurity of whole dataset.
326
+ * @param order [Numo::Int32] (shape: [n_samples]) The element indices sorted according to feature values in ascending order.
327
+ * @param features [Numo::DFloat] (shape: [n_samples]) The feature values.
328
+ * @param targets [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values.
329
+ * @return [Array<Float>] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
330
+ */
331
+ static VALUE find_split_params_reg(VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE targets) {
332
+ ndfunc_arg_in_t ain[3] = {{numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 2}};
333
+ size_t out_shape[1] = {4};
334
+ ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 1, out_shape}};
335
+ ndfunc_t ndf = {(na_iter_func_t)iter_find_split_params_reg, NO_LOOP, 3, 1, ain, aout};
336
+ split_opts_reg opts = {StringValuePtr(criterion), NUM2DBL(impurity)};
337
+ VALUE params = na_ndloop3(&ndf, &opts, 3, order, features, targets);
338
+ VALUE results = rb_ary_new2(4);
339
+ double* params_ptr = (double*)na_get_pointer_for_read(params);
340
+ rb_ary_store(results, 0, DBL2NUM(params_ptr[0]));
341
+ rb_ary_store(results, 1, DBL2NUM(params_ptr[1]));
342
+ rb_ary_store(results, 2, DBL2NUM(params_ptr[2]));
343
+ rb_ary_store(results, 3, DBL2NUM(params_ptr[3]));
344
+ RB_GC_GUARD(params);
345
+ RB_GC_GUARD(criterion);
346
+ return results;
347
+ }
348
+
349
+ /**
350
+ * @!visibility private
351
+ */
352
+ static void iter_find_split_params_grad_reg(na_loop_t const* lp) {
353
+ const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
354
+ const double* f = (double*)NDL_PTR(lp, 1);
355
+ const double* g = (double*)NDL_PTR(lp, 2);
356
+ const double* h = (double*)NDL_PTR(lp, 3);
357
+ const double s_grad = ((double*)lp->opt_ptr)[0];
358
+ const double s_hess = ((double*)lp->opt_ptr)[1];
359
+ const double reg_lambda = ((double*)lp->opt_ptr)[2];
360
+ const long n_elements = NDL_SHAPE(lp, 0)[0];
361
+ double* params = (double*)NDL_PTR(lp, 4);
362
+ long curr_pos = 0;
363
+ long next_pos = 0;
364
+ double curr_el = f[o[0]];
365
+ double last_el = f[o[n_elements - 1]];
366
+ double next_el;
367
+ double l_grad = 0.0;
368
+ double l_hess = 0.0;
369
+ double r_grad;
370
+ double r_hess;
371
+ double threshold = curr_el;
372
+ double gain_max = 0.0;
373
+ double gain;
374
+
375
+ /* Find optimal parameters. */
376
+ while (curr_pos < n_elements && curr_el != last_el) {
377
+ next_el = f[o[next_pos]];
378
+ while (next_pos < n_elements && next_el == curr_el) {
379
+ l_grad += g[o[next_pos]];
380
+ l_hess += h[o[next_pos]];
381
+ next_pos++;
382
+ next_el = f[o[next_pos]];
383
+ }
384
+ /* Calculate gain of new split. */
385
+ r_grad = s_grad - l_grad;
386
+ r_hess = s_hess - l_hess;
387
+ gain = (l_grad * l_grad) / (l_hess + reg_lambda) + (r_grad * r_grad) / (r_hess + reg_lambda) -
388
+ (s_grad * s_grad) / (s_hess + reg_lambda);
389
+ /* Update optimal parameters. */
390
+ if (gain > gain_max) {
391
+ threshold = 0.5 * (curr_el + next_el);
392
+ gain_max = gain;
393
+ }
394
+ if (next_pos == n_elements) {
395
+ break;
396
+ }
397
+ curr_pos = next_pos;
398
+ curr_el = f[o[curr_pos]];
399
+ }
400
+
401
+ params[0] = threshold;
402
+ params[1] = gain_max;
403
+ }
404
+
405
+ /**
406
+ * @!visibility private
407
+ * Find for split point with maximum information gain.
408
+ *
409
+ * @overload find_split_params(order, features, gradients, hessians, sum_gradient, sum_hessian, reg_lambda) -> Array<Float>
410
+ * @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
411
+ * @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
412
+ * @param gradients [Numo::DFloat] (shape: [n_elements]) The gradient values.
413
+ * @param hessians [Numo::DFloat] (shape: [n_elements]) The hessian values.
414
+ * @param sum_gradient [Float] The sum of gradient values.
415
+ * @param sum_hessian [Float] The sum of hessian values.
416
+ * @param reg_lambda [Float] The L2 regularization term on weight.
417
+ * @return [Array<Float>] The array consists of optimal parameters including threshold and gain.
418
+ */
419
+ static VALUE find_split_params_grad_reg(VALUE self, VALUE order, VALUE features, VALUE gradients, VALUE hessians,
420
+ VALUE sum_gradient, VALUE sum_hessian, VALUE reg_lambda) {
421
+ ndfunc_arg_in_t ain[4] = {{numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 1}};
422
+ size_t out_shape[1] = {2};
423
+ ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 1, out_shape}};
424
+ ndfunc_t ndf = {(na_iter_func_t)iter_find_split_params_grad_reg, NO_LOOP, 4, 1, ain, aout};
425
+ double opts[3] = {NUM2DBL(sum_gradient), NUM2DBL(sum_hessian), NUM2DBL(reg_lambda)};
426
+ VALUE params = na_ndloop3(&ndf, opts, 4, order, features, gradients, hessians);
427
+ VALUE results = rb_ary_new2(2);
428
+ double* params_ptr = (double*)na_get_pointer_for_read(params);
429
+ rb_ary_store(results, 0, DBL2NUM(params_ptr[0]));
430
+ rb_ary_store(results, 1, DBL2NUM(params_ptr[1]));
431
+ RB_GC_GUARD(params);
432
+ return results;
433
+ }
434
+
435
/**
 * @!visibility private
 * Options passed through na_ndloop3 to the node impurity iterator.
 */
typedef struct {
  char* criterion; /* impurity criterion name ("gini" or "entropy") */
  long n_classes;  /* number of histogram bins (classes) */
} node_impurity_cls_opts;
442
+
443
+ /**
444
+ * @!visibility private
445
+ */
446
+ static void iter_node_impurity_cls(na_loop_t const* lp) {
447
+ const int32_t* y = (int32_t*)NDL_PTR(lp, 0);
448
+ const char* criterion = ((node_impurity_cls_opts*)lp->opt_ptr)->criterion;
449
+ const long n_classes = ((node_impurity_cls_opts*)lp->opt_ptr)->n_classes;
450
+ const long n_elements = NDL_SHAPE(lp, 0)[0];
451
+ double* ret = (double*)NDL_PTR(lp, 1);
452
+ double* histogram = alloc_dbl_array(n_classes);
453
+ for (long i = 0; i < n_elements; i++) histogram[y[i]] += 1;
454
+ *ret = calc_impurity_cls(criterion, histogram, n_elements, n_classes);
455
+ xfree(histogram);
456
+ }
457
+
458
+ /**
459
+ * @!visibility private
460
+ * Calculate impurity based on criterion.
461
+ *
462
+ * @overload node_impurity(criterion, y, n_classes) -> Float
463
+ *
464
+ * @param criterion [String] The function to calculate impurity. Supported criteria are 'gini' and 'entropy'.
465
+ * @param y [Numo::Int32] (shape: [n_samples]) The labels.
466
+ * @param n_classes [Integer] The number of classes.
467
+ * @return [Float] impurity
468
+ */
469
+ static VALUE node_impurity_cls(VALUE self, VALUE criterion, VALUE y, VALUE n_classes) {
470
+ ndfunc_arg_in_t ain[1] = {{numo_cInt32, 1}};
471
+ ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 0}};
472
+ ndfunc_t ndf = {(na_iter_func_t)iter_node_impurity_cls, NDF_EXTRACT, 1, 1, ain, aout};
473
+ node_impurity_cls_opts opts = {StringValuePtr(criterion), NUM2LONG(n_classes)};
474
+ VALUE ret = na_ndloop3(&ndf, &opts, 1, y);
475
+ RB_GC_GUARD(criterion);
476
+ return ret;
477
+ }
478
+
479
+ /**
480
+ * @!visibility private
481
+ */
482
+ static void iter_check_same_label(na_loop_t const* lp) {
483
+ const int32_t* y = (int32_t*)NDL_PTR(lp, 0);
484
+ const long n_elements = NDL_SHAPE(lp, 0)[0];
485
+ int32_t* ret = (int32_t*)NDL_PTR(lp, 1);
486
+ *ret = 1;
487
+ if (n_elements > 0) {
488
+ int32_t label = y[0];
489
+ for (long i = 0; i < n_elements; i++) {
490
+ if (y[i] != label) {
491
+ *ret = 0;
492
+ break;
493
+ }
494
+ }
495
+ }
496
+ }
497
+
498
+ /**
499
+ * @!visibility private
500
+ * Check all elements have the save value.
501
+ *
502
+ * @overload check_same_label(y) -> Boolean
503
+ *
504
+ * @param y [Numo::Int32] (shape: [n_samples]) The labels.
505
+ * @return [Boolean]
506
+ */
507
+ static VALUE check_same_label(VALUE self, VALUE y) {
508
+ ndfunc_arg_in_t ain[1] = {{numo_cInt32, 1}};
509
+ ndfunc_arg_out_t aout[1] = {{numo_cInt32, 0}};
510
+ ndfunc_t ndf = {(na_iter_func_t)iter_check_same_label, NDF_EXTRACT, 1, 1, ain, aout};
511
+ VALUE ret = na_ndloop(&ndf, 1, y);
512
+ return (NUM2INT(ret) == 1 ? Qtrue : Qfalse);
513
+ }
514
+
515
+ /**
516
+ * @!visibility private
517
+ * Calculate impurity based on criterion.
518
+ *
519
+ * @overload node_impurity(criterion, y) -> Float
520
+ *
521
+ * @param criterion [String] The function to calculate impurity. Supported criteria are 'mae' and 'mse'.
522
+ * @param y [Array<Float>] (shape: [n_samples, n_outputs]) The taget values.
523
+ * @return [Float] impurity
524
+ */
525
+ static VALUE node_impurity_reg(VALUE self, VALUE criterion, VALUE y) {
526
+ const long n_elements = RARRAY_LEN(y);
527
+ const long n_outputs = RARRAY_LEN(rb_ary_entry(y, 0));
528
+ double* sum_vec = alloc_dbl_array(n_outputs);
529
+ VALUE target_vecs = rb_ary_new();
530
+
531
+ for (long i = 0; i < n_elements; i++) {
532
+ VALUE target = rb_ary_entry(y, i);
533
+ add_sum_vec(sum_vec, target);
534
+ rb_ary_push(target_vecs, target);
535
+ }
536
+
537
+ VALUE ret = DBL2NUM(calc_impurity_reg(StringValuePtr(criterion), target_vecs, sum_vec));
538
+ xfree(sum_vec);
539
+ RB_GC_GUARD(criterion);
540
+ return ret;
541
+ }
542
+
543
+ void Init_ext(void) {
544
+ VALUE rb_mRumale = rb_define_module("Rumale");
545
+ VALUE rb_mTree = rb_define_module_under(rb_mRumale, "Tree");
546
+
547
+ /**
548
+ * Document-module: Rumale::Tree::ExtDecisionTreeClassifier
549
+ * @!visibility private
550
+ * The mixin module consisting of extension method for DecisionTreeClassifier class.
551
+ * This module is used internally.
552
+ */
553
+ VALUE rb_mExtDTreeCls = rb_define_module_under(rb_mTree, "ExtDecisionTreeClassifier");
554
+ /**
555
+ * Document-module: Rumale::Tree::ExtDecisionTreeRegressor
556
+ * @!visibility private
557
+ * The mixin module consisting of extension method for DecisionTreeRegressor class.
558
+ * This module is used internally.
559
+ */
560
+ VALUE rb_mExtDTreeReg = rb_define_module_under(rb_mTree, "ExtDecisionTreeRegressor");
561
+ /**
562
+ * Document-module: Rumale::Tree::ExtGradientTreeRegressor
563
+ * @!visibility private
564
+ * The mixin module consisting of extension method for GradientTreeRegressor class.
565
+ * This module is used internally.
566
+ */
567
+ VALUE rb_mExtGTreeReg = rb_define_module_under(rb_mTree, "ExtGradientTreeRegressor");
568
+
569
+ rb_define_private_method(rb_mExtDTreeCls, "find_split_params", find_split_params_cls, 6);
570
+ rb_define_private_method(rb_mExtDTreeReg, "find_split_params", find_split_params_reg, 5);
571
+ rb_define_private_method(rb_mExtGTreeReg, "find_split_params", find_split_params_grad_reg, 7);
572
+ rb_define_private_method(rb_mExtDTreeCls, "node_impurity", node_impurity_cls, 3);
573
+ rb_define_private_method(rb_mExtDTreeCls, "stop_growing?", check_same_label, 1);
574
+ rb_define_private_method(rb_mExtDTreeReg, "node_impurity", node_impurity_reg, 2);
575
+ }
@@ -0,0 +1,12 @@
1
+ #ifndef RUMALE_TREE_EXT_H
2
+ #define RUMALE_TREE_EXT_H 1
3
+
4
+ #include <math.h>
5
+ #include <string.h>
6
+
7
+ #include <ruby.h>
8
+
9
+ #include <numo/narray.h>
10
+ #include <numo/template.h>
11
+
12
+ #endif /* RUMALE_TREE_EXT_H */
@@ -0,0 +1,32 @@
1
# frozen_string_literal: true

require 'mkmf'
require 'numo/narray'

# Put the directory that ships numo/narray.h on the compiler include path.
narray_include_dir = $LOAD_PATH.find { |dir| File.exist?(File.join(dir, 'numo/numo/narray.h')) }
$INCFLAGS = "-I#{narray_include_dir}/numo #{$INCFLAGS}" if narray_include_dir

abort 'numo/narray.h not found.' unless have_header('numo/narray.h')

# Windows toolchains link against libnarray.a explicitly.
if RUBY_PLATFORM.match?(/mswin|cygwin|mingw/)
  narray_lib_dir = $LOAD_PATH.find { |dir| File.exist?(File.join(dir, 'numo/libnarray.a')) }
  $LDFLAGS = "-L#{narray_lib_dir}/numo #{$LDFLAGS}" if narray_lib_dir

  abort 'libnarray.a not found.' unless have_library('narray', 'nary_new')
end

# On macOS with Ruby >= 3.1, allow undefined symbols to be resolved at load time
# when the linker accepts the flag.
if RUBY_PLATFORM.match?(/darwin/) &&
   Gem::Version.new(RUBY_VERSION) >= Gem::Version.new('3.1.0') &&
   try_link('int main(void){return 0;}', '-Wl,-undefined,dynamic_lookup')
  $LDFLAGS << ' -Wl,-undefined,dynamic_lookup'
end

create_makefile('rumale/tree/ext')