rumale-tree 0.24.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: 5f7d402d2787e61a76bc2d22d19da29e1b8c1bd7ecfcee07621152a3c6e262ec
4
+ data.tar.gz: e5578f0c6fa00b775b9bc652a2721913bd5a767fbc01d9c70e585715f9693935
5
+ SHA512:
6
+ metadata.gz: 959421ae6500dfd8ddde1f19e19dcca84b3b53d6473be79c3b789176554d43c68c90869c49edabf8bfab726e88e596ffadfbd180ff309d234d1176dc99bc2168
7
+ data.tar.gz: a2558f3ce63f88deef4cf82c562a87914c06b83381ffbc9911b0778e01a13efb9a2147d1a6dad7a97e9e31299c93906d9c9959075dc2a96e81d58ed3f7c6e00d
data/LICENSE.txt ADDED
@@ -0,0 +1,27 @@
1
+ Copyright (c) 2022 Atsushi Tatsuma
2
+ All rights reserved.
3
+
4
+ Redistribution and use in source and binary forms, with or without
5
+ modification, are permitted provided that the following conditions are met:
6
+
7
+ * Redistributions of source code must retain the above copyright notice, this
8
+ list of conditions and the following disclaimer.
9
+
10
+ * Redistributions in binary form must reproduce the above copyright notice,
11
+ this list of conditions and the following disclaimer in the documentation
12
+ and/or other materials provided with the distribution.
13
+
14
+ * Neither the name of the copyright holder nor the names of its
15
+ contributors may be used to endorse or promote products derived from
16
+ this software without specific prior written permission.
17
+
18
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
data/README.md ADDED
@@ -0,0 +1,33 @@
1
+ # Rumale::Tree
2
+
3
+ [![Gem Version](https://badge.fury.io/rb/rumale-tree.svg)](https://badge.fury.io/rb/rumale-tree)
4
+ [![BSD 3-Clause License](https://img.shields.io/badge/License-BSD%203--Clause-orange.svg)](https://github.com/yoshoku/rumale/blob/main/rumale-tree/LICENSE.txt)
5
+ [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://yoshoku.github.io/rumale/doc/Rumale/Tree.html)
6
+
7
+ Rumale is a machine learning library in Ruby.
8
+ Rumale::Tree provides classifiers and regressors based on decision tree algorithms
9
+ with the Rumale interface.
10
+
11
+ ## Installation
12
+
13
+ Add this line to your application's Gemfile:
14
+
15
+ ```ruby
16
+ gem 'rumale-tree'
17
+ ```
18
+
19
+ And then execute:
20
+
21
+ $ bundle install
22
+
23
+ Or install it yourself as:
24
+
25
+ $ gem install rumale-tree
26
+
27
+ ## Documentation
28
+
29
+ - [Rumale API Documentation - Tree](https://yoshoku.github.io/rumale/doc/Rumale/Tree.html)
30
+
31
+ ## License
32
+
33
+ The gem is available as open source under the terms of the [BSD-3-Clause License](https://opensource.org/licenses/BSD-3-Clause).
@@ -0,0 +1,575 @@
1
+ #include "ext.h"
2
+
3
+ double* alloc_dbl_array(const long n_dimensions) {
4
+ double* arr = ALLOC_N(double, n_dimensions);
5
+ memset(arr, 0, n_dimensions * sizeof(double));
6
+ return arr;
7
+ }
8
+
9
/*
 * Gini coefficient of a class histogram: 1 - sum_k (histogram[k] / n_elements)^2.
 *
 * histogram  - per-class counts, n_classes entries.
 * n_elements - divisor applied to every count.
 * n_classes  - number of histogram entries to read.
 */
double calc_gini_coef(double* histogram, const long n_elements, const long n_classes) {
  double sq_sum = 0.0;
  long k = 0;

  while (k < n_classes) {
    const double ratio = histogram[k] / n_elements;
    sq_sum += ratio * ratio;
    k++;
  }

  return 1.0 - sq_sum;
}
19
+
20
+ double calc_entropy(double* histogram, const long n_elements, const long n_classes) {
21
+ double entropy = 0.0;
22
+
23
+ for (long i = 0; i < n_classes; i++) {
24
+ double el = histogram[i] / n_elements;
25
+ entropy += el * log(el + 1.0);
26
+ }
27
+
28
+ return -entropy;
29
+ }
30
+
31
+ VALUE
32
+ calc_mean_vec(double* sum_vec, const long n_dimensions, const long n_elements) {
33
+ VALUE mean_vec = rb_ary_new2(n_dimensions);
34
+
35
+ for (long i = 0; i < n_dimensions; i++) {
36
+ rb_ary_store(mean_vec, i, DBL2NUM(sum_vec[i] / n_elements));
37
+ }
38
+
39
+ return mean_vec;
40
+ }
41
+
42
+ double calc_vec_mae(VALUE vec_a, VALUE vec_b) {
43
+ const long n_dimensions = RARRAY_LEN(vec_a);
44
+ double sum = 0.0;
45
+
46
+ for (long i = 0; i < n_dimensions; i++) {
47
+ double diff = NUM2DBL(rb_ary_entry(vec_a, i)) - NUM2DBL(rb_ary_entry(vec_b, i));
48
+ sum += fabs(diff);
49
+ }
50
+
51
+ return sum / n_dimensions;
52
+ }
53
+
54
+ double calc_vec_mse(VALUE vec_a, VALUE vec_b) {
55
+ const long n_dimensions = RARRAY_LEN(vec_a);
56
+ double sum = 0.0;
57
+
58
+ for (long i = 0; i < n_dimensions; i++) {
59
+ double diff = NUM2DBL(rb_ary_entry(vec_a, i)) - NUM2DBL(rb_ary_entry(vec_b, i));
60
+ sum += diff * diff;
61
+ }
62
+
63
+ return sum / n_dimensions;
64
+ }
65
+
66
+ double calc_mae(VALUE target_vecs, VALUE mean_vec) {
67
+ const long n_elements = RARRAY_LEN(target_vecs);
68
+ double sum = 0.0;
69
+
70
+ for (long i = 0; i < n_elements; i++) {
71
+ sum += calc_vec_mae(rb_ary_entry(target_vecs, i), mean_vec);
72
+ }
73
+
74
+ return sum / n_elements;
75
+ }
76
+
77
+ double calc_mse(VALUE target_vecs, VALUE mean_vec) {
78
+ const long n_elements = RARRAY_LEN(target_vecs);
79
+ double sum = 0.0;
80
+
81
+ for (long i = 0; i < n_elements; i++) {
82
+ sum += calc_vec_mse(rb_ary_entry(target_vecs, i), mean_vec);
83
+ }
84
+
85
+ return sum / n_elements;
86
+ }
87
+
88
/*
 * Classification impurity dispatcher: "entropy" selects calc_entropy,
 * any other criterion string falls back to calc_gini_coef.
 */
double calc_impurity_cls(const char* criterion, double* histogram, const long n_elements, const long n_classes) {
  const int use_entropy = strcmp(criterion, "entropy") == 0;

  return use_entropy ? calc_entropy(histogram, n_elements, n_classes)
                     : calc_gini_coef(histogram, n_elements, n_classes);
}
94
+
95
+ double calc_impurity_reg(const char* criterion, VALUE target_vecs, double* sum_vec) {
96
+ const long n_elements = RARRAY_LEN(target_vecs);
97
+ const long n_dimensions = RARRAY_LEN(rb_ary_entry(target_vecs, 0));
98
+ VALUE mean_vec = calc_mean_vec(sum_vec, n_dimensions, n_elements);
99
+
100
+ if (strcmp(criterion, "mae") == 0) {
101
+ return calc_mae(target_vecs, mean_vec);
102
+ }
103
+ return calc_mse(target_vecs, mean_vec);
104
+ }
105
+
106
+ void add_sum_vec(double* sum_vec, VALUE target) {
107
+ const long n_dimensions = RARRAY_LEN(target);
108
+
109
+ for (long i = 0; i < n_dimensions; i++) {
110
+ sum_vec[i] += NUM2DBL(rb_ary_entry(target, i));
111
+ }
112
+ }
113
+
114
+ void sub_sum_vec(double* sum_vec, VALUE target) {
115
+ const long n_dimensions = RARRAY_LEN(target);
116
+
117
+ for (long i = 0; i < n_dimensions; i++) {
118
+ sum_vec[i] -= NUM2DBL(rb_ary_entry(target, i));
119
+ }
120
+ }
121
+
122
/**
 * @!visibility private
 * Options handed to iter_find_split_params_cls through the ndloop option pointer.
 * Field order matters: the struct is filled with a positional initializer.
 */
typedef struct split_opts_cls {
  char* criterion; /* impurity criterion name, compared against "entropy" */
  long n_classes;  /* number of histogram bins */
  double impurity; /* impurity of the node before splitting */
} split_opts_cls;
130
+
131
+ /**
132
+ * @!visibility private
133
+ */
134
+ static void iter_find_split_params_cls(na_loop_t const* lp) {
135
+ const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
136
+ const double* f = (double*)NDL_PTR(lp, 1);
137
+ const int32_t* y = (int32_t*)NDL_PTR(lp, 2);
138
+ const long n_elements = NDL_SHAPE(lp, 0)[0];
139
+ const char* criterion = ((split_opts_cls*)lp->opt_ptr)->criterion;
140
+ const long n_classes = ((split_opts_cls*)lp->opt_ptr)->n_classes;
141
+ const double w_impurity = ((split_opts_cls*)lp->opt_ptr)->impurity;
142
+ double* params = (double*)NDL_PTR(lp, 3);
143
+ long curr_pos = 0;
144
+ long next_pos = 0;
145
+ long n_l_elements = 0;
146
+ long n_r_elements = n_elements;
147
+ double curr_el = f[o[0]];
148
+ double last_el = f[o[n_elements - 1]];
149
+ double next_el;
150
+ double l_impurity;
151
+ double r_impurity;
152
+ double gain;
153
+ double* l_histogram = alloc_dbl_array(n_classes);
154
+ double* r_histogram = alloc_dbl_array(n_classes);
155
+
156
+ /* Initialize optimal parameters. */
157
+ params[0] = 0.0; /* left impurity */
158
+ params[1] = w_impurity; /* right impurity */
159
+ params[2] = curr_el; /* threshold */
160
+ params[3] = 0.0; /* gain */
161
+
162
+ /* Initialize child node variables. */
163
+ for (long i = 0; i < n_elements; i++) {
164
+ r_histogram[y[o[i]]] += 1.0;
165
+ }
166
+
167
+ /* Find optimal parameters. */
168
+ while (curr_pos < n_elements && curr_el != last_el) {
169
+ next_el = f[o[next_pos]];
170
+ while (next_pos < n_elements && next_el == curr_el) {
171
+ l_histogram[y[o[next_pos]]] += 1;
172
+ n_l_elements++;
173
+ r_histogram[y[o[next_pos]]] -= 1;
174
+ n_r_elements--;
175
+ next_pos++;
176
+ next_el = f[o[next_pos]];
177
+ }
178
+ /* Calculate gain of new split. */
179
+ l_impurity = calc_impurity_cls(criterion, l_histogram, n_l_elements, n_classes);
180
+ r_impurity = calc_impurity_cls(criterion, r_histogram, n_r_elements, n_classes);
181
+ gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / n_elements;
182
+ /* Update optimal parameters. */
183
+ if (gain > params[3]) {
184
+ params[0] = l_impurity;
185
+ params[1] = r_impurity;
186
+ params[2] = 0.5 * (curr_el + next_el);
187
+ params[3] = gain;
188
+ }
189
+ if (next_pos == n_elements) break;
190
+ curr_pos = next_pos;
191
+ curr_el = f[o[curr_pos]];
192
+ }
193
+
194
+ xfree(l_histogram);
195
+ xfree(r_histogram);
196
+ }
197
+
198
+ /**
199
+ * @!visibility private
200
+ * Find for split point with maximum information gain.
201
+ *
202
+ * @overload find_split_params(criterion, impurity, order, features, labels, n_classes) -> Array<Float>
203
+ *
204
+ * @param criterion [String] The function to evaluate spliting point. Supported criteria are 'gini' and 'entropy'.
205
+ * @param impurity [Float] The impurity of whole dataset.
206
+ * @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
207
+ * @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
208
+ * @param labels [Numo::Int32] (shape: [n_elements]) The labels.
209
+ * @param n_classes [Integer] The number of classes.
210
+ * @return [Array<Float>] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
211
+ */
212
+ static VALUE find_split_params_cls(VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE labels,
213
+ VALUE n_classes) {
214
+ ndfunc_arg_in_t ain[3] = {{numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cInt32, 1}};
215
+ size_t out_shape[1] = {4};
216
+ ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 1, out_shape}};
217
+ ndfunc_t ndf = {(na_iter_func_t)iter_find_split_params_cls, NO_LOOP, 3, 1, ain, aout};
218
+ split_opts_cls opts = {StringValuePtr(criterion), NUM2LONG(n_classes), NUM2DBL(impurity)};
219
+ VALUE params = na_ndloop3(&ndf, &opts, 3, order, features, labels);
220
+ VALUE results = rb_ary_new2(4);
221
+ double* params_ptr = (double*)na_get_pointer_for_read(params);
222
+ rb_ary_store(results, 0, DBL2NUM(params_ptr[0]));
223
+ rb_ary_store(results, 1, DBL2NUM(params_ptr[1]));
224
+ rb_ary_store(results, 2, DBL2NUM(params_ptr[2]));
225
+ rb_ary_store(results, 3, DBL2NUM(params_ptr[3]));
226
+ RB_GC_GUARD(params);
227
+ RB_GC_GUARD(criterion);
228
+ return results;
229
+ }
230
+
231
/**
 * @!visibility private
 * Options handed to iter_find_split_params_reg through the ndloop option pointer.
 * Field order matters: the struct is filled with a positional initializer.
 */
typedef struct split_opts_reg {
  char* criterion; /* impurity criterion name, compared against "mae" */
  double impurity; /* impurity of the node before splitting */
} split_opts_reg;
238
+
239
+ /**
240
+ * @!visibility private
241
+ */
242
+ static void iter_find_split_params_reg(na_loop_t const* lp) {
243
+ const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
244
+ const double* f = (double*)NDL_PTR(lp, 1);
245
+ const double* y = (double*)NDL_PTR(lp, 2);
246
+ const long n_elements = NDL_SHAPE(lp, 0)[0];
247
+ const long n_outputs = NDL_SHAPE(lp, 2)[1];
248
+ const char* criterion = ((split_opts_reg*)lp->opt_ptr)->criterion;
249
+ const double w_impurity = ((split_opts_reg*)lp->opt_ptr)->impurity;
250
+ double* params = (double*)NDL_PTR(lp, 3);
251
+ long curr_pos = 0;
252
+ long next_pos = 0;
253
+ long n_l_elements = 0;
254
+ long n_r_elements = n_elements;
255
+ double curr_el = f[o[0]];
256
+ double last_el = f[o[n_elements - 1]];
257
+ double next_el;
258
+ double l_impurity;
259
+ double r_impurity;
260
+ double gain;
261
+ double* l_sum_vec = alloc_dbl_array(n_outputs);
262
+ double* r_sum_vec = alloc_dbl_array(n_outputs);
263
+ double target_var;
264
+ VALUE l_target_vecs = rb_ary_new();
265
+ VALUE r_target_vecs = rb_ary_new();
266
+ VALUE target;
267
+
268
+ /* Initialize optimal parameters. */
269
+ params[0] = 0.0; /* left impurity */
270
+ params[1] = w_impurity; /* right impurity */
271
+ params[2] = curr_el; /* threshold */
272
+ params[3] = 0.0; /* gain */
273
+
274
+ /* Initialize child node variables. */
275
+ for (long i = 0; i < n_elements; i++) {
276
+ target = rb_ary_new2(n_outputs);
277
+ for (long j = 0; j < n_outputs; j++) {
278
+ target_var = y[o[i] * n_outputs + j];
279
+ rb_ary_store(target, j, DBL2NUM(target_var));
280
+ r_sum_vec[j] += target_var;
281
+ }
282
+ rb_ary_push(r_target_vecs, target);
283
+ }
284
+
285
+ /* Find optimal parameters. */
286
+ while (curr_pos < n_elements && curr_el != last_el) {
287
+ next_el = f[o[next_pos]];
288
+ while (next_pos < n_elements && next_el == curr_el) {
289
+ target = rb_ary_shift(r_target_vecs);
290
+ n_r_elements--;
291
+ sub_sum_vec(r_sum_vec, target);
292
+ rb_ary_push(l_target_vecs, target);
293
+ n_l_elements++;
294
+ add_sum_vec(l_sum_vec, target);
295
+ next_pos++;
296
+ next_el = f[o[next_pos]];
297
+ }
298
+ /* Calculate gain of new split. */
299
+ l_impurity = calc_impurity_reg(criterion, l_target_vecs, l_sum_vec);
300
+ r_impurity = calc_impurity_reg(criterion, r_target_vecs, r_sum_vec);
301
+ gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / n_elements;
302
+ /* Update optimal parameters. */
303
+ if (gain > params[3]) {
304
+ params[0] = l_impurity;
305
+ params[1] = r_impurity;
306
+ params[2] = 0.5 * (curr_el + next_el);
307
+ params[3] = gain;
308
+ }
309
+ if (next_pos == n_elements) break;
310
+ curr_pos = next_pos;
311
+ curr_el = f[o[curr_pos]];
312
+ }
313
+
314
+ xfree(l_sum_vec);
315
+ xfree(r_sum_vec);
316
+ }
317
+
318
+ /**
319
+ * @!visibility private
320
+ * Find for split point with maximum information gain.
321
+ *
322
+ * @overload find_split_params(criterion, impurity, order, features, targets) -> Array<Float>
323
+ *
324
+ * @param criterion [String] The function to evaluate spliting point. Supported criteria are 'mae' and 'mse'.
325
+ * @param impurity [Float] The impurity of whole dataset.
326
+ * @param order [Numo::Int32] (shape: [n_samples]) The element indices sorted according to feature values in ascending order.
327
+ * @param features [Numo::DFloat] (shape: [n_samples]) The feature values.
328
+ * @param targets [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values.
329
+ * @return [Array<Float>] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
330
+ */
331
+ static VALUE find_split_params_reg(VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE targets) {
332
+ ndfunc_arg_in_t ain[3] = {{numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 2}};
333
+ size_t out_shape[1] = {4};
334
+ ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 1, out_shape}};
335
+ ndfunc_t ndf = {(na_iter_func_t)iter_find_split_params_reg, NO_LOOP, 3, 1, ain, aout};
336
+ split_opts_reg opts = {StringValuePtr(criterion), NUM2DBL(impurity)};
337
+ VALUE params = na_ndloop3(&ndf, &opts, 3, order, features, targets);
338
+ VALUE results = rb_ary_new2(4);
339
+ double* params_ptr = (double*)na_get_pointer_for_read(params);
340
+ rb_ary_store(results, 0, DBL2NUM(params_ptr[0]));
341
+ rb_ary_store(results, 1, DBL2NUM(params_ptr[1]));
342
+ rb_ary_store(results, 2, DBL2NUM(params_ptr[2]));
343
+ rb_ary_store(results, 3, DBL2NUM(params_ptr[3]));
344
+ RB_GC_GUARD(params);
345
+ RB_GC_GUARD(criterion);
346
+ return results;
347
+ }
348
+
349
+ /**
350
+ * @!visibility private
351
+ */
352
+ static void iter_find_split_params_grad_reg(na_loop_t const* lp) {
353
+ const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
354
+ const double* f = (double*)NDL_PTR(lp, 1);
355
+ const double* g = (double*)NDL_PTR(lp, 2);
356
+ const double* h = (double*)NDL_PTR(lp, 3);
357
+ const double s_grad = ((double*)lp->opt_ptr)[0];
358
+ const double s_hess = ((double*)lp->opt_ptr)[1];
359
+ const double reg_lambda = ((double*)lp->opt_ptr)[2];
360
+ const long n_elements = NDL_SHAPE(lp, 0)[0];
361
+ double* params = (double*)NDL_PTR(lp, 4);
362
+ long curr_pos = 0;
363
+ long next_pos = 0;
364
+ double curr_el = f[o[0]];
365
+ double last_el = f[o[n_elements - 1]];
366
+ double next_el;
367
+ double l_grad = 0.0;
368
+ double l_hess = 0.0;
369
+ double r_grad;
370
+ double r_hess;
371
+ double threshold = curr_el;
372
+ double gain_max = 0.0;
373
+ double gain;
374
+
375
+ /* Find optimal parameters. */
376
+ while (curr_pos < n_elements && curr_el != last_el) {
377
+ next_el = f[o[next_pos]];
378
+ while (next_pos < n_elements && next_el == curr_el) {
379
+ l_grad += g[o[next_pos]];
380
+ l_hess += h[o[next_pos]];
381
+ next_pos++;
382
+ next_el = f[o[next_pos]];
383
+ }
384
+ /* Calculate gain of new split. */
385
+ r_grad = s_grad - l_grad;
386
+ r_hess = s_hess - l_hess;
387
+ gain = (l_grad * l_grad) / (l_hess + reg_lambda) + (r_grad * r_grad) / (r_hess + reg_lambda) -
388
+ (s_grad * s_grad) / (s_hess + reg_lambda);
389
+ /* Update optimal parameters. */
390
+ if (gain > gain_max) {
391
+ threshold = 0.5 * (curr_el + next_el);
392
+ gain_max = gain;
393
+ }
394
+ if (next_pos == n_elements) {
395
+ break;
396
+ }
397
+ curr_pos = next_pos;
398
+ curr_el = f[o[curr_pos]];
399
+ }
400
+
401
+ params[0] = threshold;
402
+ params[1] = gain_max;
403
+ }
404
+
405
+ /**
406
+ * @!visibility private
407
+ * Find for split point with maximum information gain.
408
+ *
409
+ * @overload find_split_params(order, features, gradients, hessians, sum_gradient, sum_hessian, reg_lambda) -> Array<Float>
410
+ * @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
411
+ * @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
412
+ * @param gradients [Numo::DFloat] (shape: [n_elements]) The gradient values.
413
+ * @param hessians [Numo::DFloat] (shape: [n_elements]) The hessian values.
414
+ * @param sum_gradient [Float] The sum of gradient values.
415
+ * @param sum_hessian [Float] The sum of hessian values.
416
+ * @param reg_lambda [Float] The L2 regularization term on weight.
417
+ * @return [Array<Float>] The array consists of optimal parameters including threshold and gain.
418
+ */
419
+ static VALUE find_split_params_grad_reg(VALUE self, VALUE order, VALUE features, VALUE gradients, VALUE hessians,
420
+ VALUE sum_gradient, VALUE sum_hessian, VALUE reg_lambda) {
421
+ ndfunc_arg_in_t ain[4] = {{numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 1}};
422
+ size_t out_shape[1] = {2};
423
+ ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 1, out_shape}};
424
+ ndfunc_t ndf = {(na_iter_func_t)iter_find_split_params_grad_reg, NO_LOOP, 4, 1, ain, aout};
425
+ double opts[3] = {NUM2DBL(sum_gradient), NUM2DBL(sum_hessian), NUM2DBL(reg_lambda)};
426
+ VALUE params = na_ndloop3(&ndf, opts, 4, order, features, gradients, hessians);
427
+ VALUE results = rb_ary_new2(2);
428
+ double* params_ptr = (double*)na_get_pointer_for_read(params);
429
+ rb_ary_store(results, 0, DBL2NUM(params_ptr[0]));
430
+ rb_ary_store(results, 1, DBL2NUM(params_ptr[1]));
431
+ RB_GC_GUARD(params);
432
+ return results;
433
+ }
434
+
435
/**
 * @!visibility private
 * Options handed to iter_node_impurity_cls through the ndloop option pointer.
 * Field order matters: the struct is filled with a positional initializer.
 */
typedef struct node_impurity_cls_opts {
  char* criterion; /* impurity criterion name, compared against "entropy" */
  long n_classes;  /* number of histogram bins */
} node_impurity_cls_opts;
442
+
443
+ /**
444
+ * @!visibility private
445
+ */
446
+ static void iter_node_impurity_cls(na_loop_t const* lp) {
447
+ const int32_t* y = (int32_t*)NDL_PTR(lp, 0);
448
+ const char* criterion = ((node_impurity_cls_opts*)lp->opt_ptr)->criterion;
449
+ const long n_classes = ((node_impurity_cls_opts*)lp->opt_ptr)->n_classes;
450
+ const long n_elements = NDL_SHAPE(lp, 0)[0];
451
+ double* ret = (double*)NDL_PTR(lp, 1);
452
+ double* histogram = alloc_dbl_array(n_classes);
453
+ for (long i = 0; i < n_elements; i++) histogram[y[i]] += 1;
454
+ *ret = calc_impurity_cls(criterion, histogram, n_elements, n_classes);
455
+ xfree(histogram);
456
+ }
457
+
458
+ /**
459
+ * @!visibility private
460
+ * Calculate impurity based on criterion.
461
+ *
462
+ * @overload node_impurity(criterion, y, n_classes) -> Float
463
+ *
464
+ * @param criterion [String] The function to calculate impurity. Supported criteria are 'gini' and 'entropy'.
465
+ * @param y [Numo::Int32] (shape: [n_samples]) The labels.
466
+ * @param n_classes [Integer] The number of classes.
467
+ * @return [Float] impurity
468
+ */
469
+ static VALUE node_impurity_cls(VALUE self, VALUE criterion, VALUE y, VALUE n_classes) {
470
+ ndfunc_arg_in_t ain[1] = {{numo_cInt32, 1}};
471
+ ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 0}};
472
+ ndfunc_t ndf = {(na_iter_func_t)iter_node_impurity_cls, NDF_EXTRACT, 1, 1, ain, aout};
473
+ node_impurity_cls_opts opts = {StringValuePtr(criterion), NUM2LONG(n_classes)};
474
+ VALUE ret = na_ndloop3(&ndf, &opts, 1, y);
475
+ RB_GC_GUARD(criterion);
476
+ return ret;
477
+ }
478
+
479
+ /**
480
+ * @!visibility private
481
+ */
482
+ static void iter_check_same_label(na_loop_t const* lp) {
483
+ const int32_t* y = (int32_t*)NDL_PTR(lp, 0);
484
+ const long n_elements = NDL_SHAPE(lp, 0)[0];
485
+ int32_t* ret = (int32_t*)NDL_PTR(lp, 1);
486
+ *ret = 1;
487
+ if (n_elements > 0) {
488
+ int32_t label = y[0];
489
+ for (long i = 0; i < n_elements; i++) {
490
+ if (y[i] != label) {
491
+ *ret = 0;
492
+ break;
493
+ }
494
+ }
495
+ }
496
+ }
497
+
498
+ /**
499
+ * @!visibility private
500
+ * Check all elements have the save value.
501
+ *
502
+ * @overload check_same_label(y) -> Boolean
503
+ *
504
+ * @param y [Numo::Int32] (shape: [n_samples]) The labels.
505
+ * @return [Boolean]
506
+ */
507
+ static VALUE check_same_label(VALUE self, VALUE y) {
508
+ ndfunc_arg_in_t ain[1] = {{numo_cInt32, 1}};
509
+ ndfunc_arg_out_t aout[1] = {{numo_cInt32, 0}};
510
+ ndfunc_t ndf = {(na_iter_func_t)iter_check_same_label, NDF_EXTRACT, 1, 1, ain, aout};
511
+ VALUE ret = na_ndloop(&ndf, 1, y);
512
+ return (NUM2INT(ret) == 1 ? Qtrue : Qfalse);
513
+ }
514
+
515
+ /**
516
+ * @!visibility private
517
+ * Calculate impurity based on criterion.
518
+ *
519
+ * @overload node_impurity(criterion, y) -> Float
520
+ *
521
+ * @param criterion [String] The function to calculate impurity. Supported criteria are 'mae' and 'mse'.
522
+ * @param y [Array<Float>] (shape: [n_samples, n_outputs]) The taget values.
523
+ * @return [Float] impurity
524
+ */
525
+ static VALUE node_impurity_reg(VALUE self, VALUE criterion, VALUE y) {
526
+ const long n_elements = RARRAY_LEN(y);
527
+ const long n_outputs = RARRAY_LEN(rb_ary_entry(y, 0));
528
+ double* sum_vec = alloc_dbl_array(n_outputs);
529
+ VALUE target_vecs = rb_ary_new();
530
+
531
+ for (long i = 0; i < n_elements; i++) {
532
+ VALUE target = rb_ary_entry(y, i);
533
+ add_sum_vec(sum_vec, target);
534
+ rb_ary_push(target_vecs, target);
535
+ }
536
+
537
+ VALUE ret = DBL2NUM(calc_impurity_reg(StringValuePtr(criterion), target_vecs, sum_vec));
538
+ xfree(sum_vec);
539
+ RB_GC_GUARD(criterion);
540
+ return ret;
541
+ }
542
+
543
+ void Init_ext(void) {
544
+ VALUE rb_mRumale = rb_define_module("Rumale");
545
+ VALUE rb_mTree = rb_define_module_under(rb_mRumale, "Tree");
546
+
547
+ /**
548
+ * Document-module: Rumale::Tree::ExtDecisionTreeClassifier
549
+ * @!visibility private
550
+ * The mixin module consisting of extension method for DecisionTreeClassifier class.
551
+ * This module is used internally.
552
+ */
553
+ VALUE rb_mExtDTreeCls = rb_define_module_under(rb_mTree, "ExtDecisionTreeClassifier");
554
+ /**
555
+ * Document-module: Rumale::Tree::ExtDecisionTreeRegressor
556
+ * @!visibility private
557
+ * The mixin module consisting of extension method for DecisionTreeRegressor class.
558
+ * This module is used internally.
559
+ */
560
+ VALUE rb_mExtDTreeReg = rb_define_module_under(rb_mTree, "ExtDecisionTreeRegressor");
561
+ /**
562
+ * Document-module: Rumale::Tree::ExtGradientTreeRegressor
563
+ * @!visibility private
564
+ * The mixin module consisting of extension method for GradientTreeRegressor class.
565
+ * This module is used internally.
566
+ */
567
+ VALUE rb_mExtGTreeReg = rb_define_module_under(rb_mTree, "ExtGradientTreeRegressor");
568
+
569
+ rb_define_private_method(rb_mExtDTreeCls, "find_split_params", find_split_params_cls, 6);
570
+ rb_define_private_method(rb_mExtDTreeReg, "find_split_params", find_split_params_reg, 5);
571
+ rb_define_private_method(rb_mExtGTreeReg, "find_split_params", find_split_params_grad_reg, 7);
572
+ rb_define_private_method(rb_mExtDTreeCls, "node_impurity", node_impurity_cls, 3);
573
+ rb_define_private_method(rb_mExtDTreeCls, "stop_growing?", check_same_label, 1);
574
+ rb_define_private_method(rb_mExtDTreeReg, "node_impurity", node_impurity_reg, 2);
575
+ }
@@ -0,0 +1,12 @@
1
+ #ifndef RUMALE_TREE_EXT_H
2
+ #define RUMALE_TREE_EXT_H 1
3
+
4
+ #include <math.h>
5
+ #include <string.h>
6
+
7
+ #include <ruby.h>
8
+
9
+ #include <numo/narray.h>
10
+ #include <numo/template.h>
11
+
12
+ #endif /* RUMALE_TREE_EXT_H */
@@ -0,0 +1,32 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'mkmf'
4
+ require 'numo/narray'
5
+
6
# Put the first load-path directory that ships numo/narray.h in front of the include flags.
narray_inc_dir = $LOAD_PATH.find { |dir| File.exist?(File.join(dir, 'numo/numo/narray.h')) }
$INCFLAGS = "-I#{narray_inc_dir}/numo #{$INCFLAGS}" unless narray_inc_dir.nil?

abort 'numo/narray.h not found.' unless have_header('numo/narray.h')

# On Windows-like platforms the extension links against libnarray explicitly.
if RUBY_PLATFORM.match?(/mswin|cygwin|mingw/)
  narray_lib_dir = $LOAD_PATH.find { |dir| File.exist?(File.join(dir, 'numo/libnarray.a')) }
  $LDFLAGS = "-L#{narray_lib_dir}/numo #{$LDFLAGS}" unless narray_lib_dir.nil?

  abort 'libnarray.a not found.' unless have_library('narray', 'nary_new')
end

# On macOS with Ruby 3.1 or later, add the dynamic_lookup linker flag when the linker accepts it.
on_darwin = RUBY_PLATFORM.match?(/darwin/)
ruby_3_1_or_later = Gem::Version.new(RUBY_VERSION) >= Gem::Version.new('3.1.0')
if on_darwin && ruby_3_1_or_later && try_link('int main(void){return 0;}', '-Wl,-undefined,dynamic_lookup')
  $LDFLAGS << ' -Wl,-undefined,dynamic_lookup'
end

create_makefile('rumale/tree/ext')