rumale-tree 0.26.0 → 0.28.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: eae93f4d3753b16336ffed4b0a467cf09ddb33b0b89005a83ed5088ea2b362ef
4
- data.tar.gz: 73d6c88f3653b5d2c70ef884fb7f5ef5a68e91d79abd19570476c13f02c434c6
3
+ metadata.gz: a437c3b879dd51d2e823f1851ca8b350d0441947fb6cefe727de83db78b4e6d9
4
+ data.tar.gz: 519665f2baea649ec31c9d5529aa05c67a880daa8ab60149433a53483f630663
5
5
  SHA512:
6
- metadata.gz: 6a5fa902ab481cc7fe7bdbc38e2e4da2f57c4b712d47fef6daabfd757d263949e19d12c5324a7ff0b0372e656a435908d3049380aef4f241c7ff39fc8914bf44
7
- data.tar.gz: 7a5e3b465549928e01143949e4001d7825882a687591adb91f82f7b753f7527c4fa2e5787ad5af263777855f3e8a3e5101386fdc361da3940f7f6dce0933c5bb
6
+ metadata.gz: aa68d3e2fbe99ef4c8d15a10f64098abfbbcdce9047053967e578edad688fa08a03a1ac4e4d23dbd5e67a4a23370057aa9462107bb961eb4d3f12782ce0c0f18
7
+ data.tar.gz: 6a6e6f167e31d56051768a881a4db0da7ef02f36977cb3969e8160bad385757153b1f6be32dd364ed3f67ed4672dd0c1bfd14cc80e0b033e00dd282d757955a9
@@ -0,0 +1,39 @@
1
+ /**
2
+ * Copyright (c) 2022-2023 Atsushi Tatsuma
3
+ * All rights reserved.
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * * Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * * Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * * Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ */
30
+
31
+ #include "ext.hpp"
32
+
33
+ extern "C" void Init_ext(void) {
34
+ VALUE rb_mRumale = rb_define_module("Rumale");
35
+ VALUE rb_mTree = rb_define_module_under(rb_mRumale, "Tree");
36
+ ExtDecisionTreeClassifier::define_module(rb_mTree);
37
+ ExtDecisionTreeRegressor::define_module(rb_mTree);
38
+ ExtGradientTreeRegressor::define_module(rb_mTree);
39
+ }
@@ -0,0 +1,550 @@
1
+ /**
2
+ * Copyright (c) 2022-2023 Atsushi Tatsuma
3
+ * All rights reserved.
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * * Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * * Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * * Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ */
30
+
31
+ #ifndef RUMALE_TREE_EXT_HPP
32
+ #define RUMALE_TREE_EXT_HPP 1
33
+
34
+ #include <cmath>
35
+ #include <limits>
36
+ #include <string>
37
+ #include <vector>
38
+
39
+ #include <ruby.h>
40
+
41
+ #include <numo/narray.h>
42
+ #include <numo/template.h>
43
+
44
+ /**
45
+ * @!visibility private
46
+ * Document-module: Rumale::Tree::ExtDecisionTreeClassifier
47
+ * The mixin module consisting of extension method for DecisionTreeClassifier class.
48
+ * This module is used internally.
49
+ */
50
+ class ExtDecisionTreeClassifier {
51
+ public:
52
+ static void define_module(VALUE& outer) {
53
+ VALUE rb_mExtDTreeCls = rb_define_module_under(outer, "ExtDecisionTreeClassifier");
54
+ rb_define_private_method(rb_mExtDTreeCls, "find_split_params", find_split_params_, 6);
55
+ rb_define_private_method(rb_mExtDTreeCls, "node_impurity", node_impurity_, 3);
56
+ rb_define_private_method(rb_mExtDTreeCls, "stop_growing?", check_same_label_, 1);
57
+ }
58
+
59
+ private:
60
+ static double calc_impurity_(const std::string& criterion, const std::vector<size_t>& histogram, const size_t& n_elements, const size_t& n_classes) {
61
+ double impurity = 0.0;
62
+ if (criterion == "entropy") {
63
+ double entropy = 0.0;
64
+ for (size_t i = 0; i < n_classes; i++) {
65
+ const double el = static_cast<double>(histogram[i]) / static_cast<double>(n_elements);
66
+ entropy += el * std::log(el + 1.0);
67
+ }
68
+ impurity = -entropy;
69
+ } else {
70
+ double gini = 0.0;
71
+ for (size_t i = 0; i < n_classes; i++) {
72
+ const double el = static_cast<double>(histogram[i]) / static_cast<double>(n_elements);
73
+ gini += el * el;
74
+ }
75
+ impurity = 1.0 - gini;
76
+ }
77
+ return impurity;
78
+ }
79
+
80
+ /**
81
+ * @!visibility private
82
+ * Find for split point with maximum information gain.
83
+ *
84
+ * @overload find_split_params(criterion, impurity, order, features, labels, n_classes) -> Array<Float>
85
+ *
86
+ * @param criterion [String] The function to evaluate spliting point. Supported criteria are 'gini' and 'entropy'.
87
+ * @param impurity [Float] The impurity of whole dataset.
88
+ * @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
89
+ * @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
90
+ * @param labels [Numo::Int32] (shape: [n_elements]) The labels.
91
+ * @param n_classes [Integer] The number of classes.
92
+ * @return [Array<Float>] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
93
+ */
94
+
95
+ struct FindSplitParamsOpts_ {
96
+ std::string criterion;
97
+ size_t n_classes;
98
+ double impurity;
99
+ };
100
+
101
+ static void iter_find_split_params_(na_loop_t const* lp) {
102
+ // Obtain iteration variables.
103
+ const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
104
+ const size_t n_elements = NDL_SHAPE(lp, 0)[0];
105
+ const double* f = (double*)NDL_PTR(lp, 1);
106
+ const int32_t* y = (int32_t*)NDL_PTR(lp, 2);
107
+ const std::string criterion = ((FindSplitParamsOpts_*)lp->opt_ptr)->criterion;
108
+ const size_t n_classes = ((FindSplitParamsOpts_*)lp->opt_ptr)->n_classes;
109
+ const double w_impurity = ((FindSplitParamsOpts_*)lp->opt_ptr)->impurity;
110
+
111
+ // Initialize output optimal parameters.
112
+ double* params = (double*)NDL_PTR(lp, 3);
113
+ params[0] = 0.0; // left impurity
114
+ params[1] = w_impurity; // right impurity
115
+ params[2] = f[o[0]]; // threshold
116
+ params[3] = 0.0; // gain
117
+
118
+ // Initialize child node variables.
119
+ std::vector<size_t> r_histogram(n_classes, 0);
120
+ for (size_t i = 0; i < n_elements; i++) r_histogram[y[o[i]]] += 1;
121
+
122
+ // Find optimal parameters.
123
+ size_t curr_pos = 0;
124
+ size_t next_pos = 0;
125
+ size_t n_l_elements = 0;
126
+ size_t n_r_elements = n_elements;
127
+ double curr_el = f[o[0]];
128
+ const double last_el = f[o[n_elements - 1]];
129
+ std::vector<size_t> l_histogram(n_classes, 0);
130
+ while (curr_pos < n_elements && curr_el != last_el) {
131
+ double next_el = f[o[next_pos]];
132
+ while (next_pos < n_elements && next_el == curr_el) {
133
+ l_histogram[y[o[next_pos]]] += 1;
134
+ n_l_elements++;
135
+ r_histogram[y[o[next_pos]]] -= 1;
136
+ n_r_elements--;
137
+ next_pos++;
138
+ next_el = f[o[next_pos]];
139
+ }
140
+ // Calculate gain of new split.
141
+ const double l_impurity = calc_impurity_(criterion, l_histogram, n_l_elements, n_classes);
142
+ const double r_impurity = calc_impurity_(criterion, r_histogram, n_r_elements, n_classes);
143
+ const double gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / static_cast<double>(n_elements);
144
+ // Update optimal parameters.
145
+ if (gain > params[3]) {
146
+ params[0] = l_impurity;
147
+ params[1] = r_impurity;
148
+ params[2] = 0.5 * (curr_el + next_el);
149
+ params[3] = gain;
150
+ }
151
+ if (next_pos == n_elements) break;
152
+ curr_pos = next_pos;
153
+ curr_el = f[o[curr_pos]];
154
+ }
155
+ }
156
+
157
+ static VALUE find_split_params_(VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE labels, VALUE n_classes) {
158
+ ndfunc_arg_in_t ain[3] = { { numo_cInt32, 1 }, { numo_cDFloat, 1 }, { numo_cInt32, 1 } };
159
+ size_t out_shape[1] = { 4 };
160
+ ndfunc_arg_out_t aout[1] = { { numo_cDFloat, 1, out_shape } };
161
+ ndfunc_t ndf = { (na_iter_func_t)iter_find_split_params_, NO_LOOP, 3, 1, ain, aout };
162
+ FindSplitParamsOpts_ opts = { std::string(StringValueCStr(criterion)), NUM2SIZET(n_classes), NUM2DBL(impurity) };
163
+ VALUE params = na_ndloop3(&ndf, &opts, 3, order, features, labels);
164
+ RB_GC_GUARD(criterion);
165
+ return params;
166
+ }
167
+
168
+ /**
169
+ * @!visibility private
170
+ * Calculate impurity based on criterion.
171
+ *
172
+ * @overload node_impurity(criterion, y, n_classes) -> Float
173
+ *
174
+ * @param criterion [String] The function to calculate impurity. Supported criteria are 'gini' and 'entropy'.
175
+ * @param y [Numo::Int32] (shape: [n_samples]) The labels.
176
+ * @param n_classes [Integer] The number of classes.
177
+ * @return [Float] impurity
178
+ */
179
+
180
+ struct NodeImpurityOpts_ {
181
+ std::string criterion;
182
+ size_t n_classes;
183
+ };
184
+
185
+ static void iter_node_impurity_(na_loop_t const* lp) {
186
+ const int32_t* y = (int32_t*)NDL_PTR(lp, 0);
187
+ const size_t n_elements = NDL_SHAPE(lp, 0)[0];
188
+ const std::string criterion = ((NodeImpurityOpts_*)lp->opt_ptr)->criterion;
189
+ const size_t n_classes = ((NodeImpurityOpts_*)lp->opt_ptr)->n_classes;
190
+ double* ret = (double*)NDL_PTR(lp, 1);
191
+ std::vector<size_t> histogram(n_classes, 0);
192
+ for (size_t i = 0; i < n_elements; i++) histogram[y[i]] += 1;
193
+ *ret = calc_impurity_(criterion, histogram, n_elements, n_classes);
194
+ }
195
+
196
+ static VALUE node_impurity_(VALUE self, VALUE criterion, VALUE y, VALUE n_classes) {
197
+ ndfunc_arg_in_t ain[1] = { { numo_cInt32, 1 } };
198
+ ndfunc_arg_out_t aout[1] = { { numo_cDFloat, 0 } };
199
+ ndfunc_t ndf = { (na_iter_func_t)iter_node_impurity_, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
200
+ NodeImpurityOpts_ opts = { std::string(StringValueCStr(criterion)), NUM2SIZET(n_classes) };
201
+ VALUE ret = na_ndloop3(&ndf, &opts, 1, y);
202
+ RB_GC_GUARD(criterion);
203
+ return ret;
204
+ }
205
+
206
+ /**
207
+ * @!visibility private
208
+ * Check all elements have the same value.
209
+ *
210
+ * @overload check_same_label(y) -> Boolean
211
+ *
212
+ * @param y [Numo::Int32] (shape: [n_samples]) The labels.
213
+ * @return [Boolean]
214
+ */
215
+
216
+ static void iter_check_same_label_(na_loop_t const* lp) {
217
+ const int32_t* y = (int32_t*)NDL_PTR(lp, 0);
218
+ const size_t n_elements = NDL_SHAPE(lp, 0)[0];
219
+ VALUE* ret = (VALUE*)NDL_PTR(lp, 1);
220
+ *ret = Qtrue;
221
+ if (n_elements > 0) {
222
+ int32_t label = y[0];
223
+ for (size_t i = 0; i < n_elements; i++) {
224
+ if (y[i] != label) {
225
+ *ret = Qfalse;
226
+ break;
227
+ }
228
+ }
229
+ }
230
+ }
231
+
232
+ static VALUE check_same_label_(VALUE self, VALUE y) {
233
+ ndfunc_arg_in_t ain[1] = { { numo_cInt32, 1 } };
234
+ ndfunc_arg_out_t aout[1] = { { numo_cRObject, 0 } };
235
+ ndfunc_t ndf = { (na_iter_func_t)iter_check_same_label_, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
236
+ return na_ndloop(&ndf, 1, y);
237
+ }
238
+ };
239
+
240
+ /**
241
+ * @!visibility private
242
+ * Document-module: Rumale::Tree::ExtDecisionTreeRegressor
243
+ * The mixin module consisting of extension method for DecisionTreeRegressor class.
244
+ * This module is used internally.
245
+ */
246
+ class ExtDecisionTreeRegressor {
247
+ public:
248
+ static void define_module(VALUE& outer) {
249
+ VALUE rb_mExtDTreeReg = rb_define_module_under(outer, "ExtDecisionTreeRegressor");
250
+ rb_define_private_method(rb_mExtDTreeReg, "find_split_params", find_split_params_, 5);
251
+ rb_define_private_method(rb_mExtDTreeReg, "node_impurity", node_impurity_, 2);
252
+ rb_define_private_method(rb_mExtDTreeReg, "stop_growing?", check_same_value_, 1);
253
+ }
254
+
255
+ private:
256
+ static double calc_impurity_(const std::string& criterion, const int32_t* order, const double* vecs, const double* mean_vec,
257
+ const size_t& n_elements, const size_t& n_outputs, const size_t& order_offset) {
258
+ const bool is_mae = criterion == "mae";
259
+ double sum_err = 0.0;
260
+ for (size_t i = 0; i < n_elements; i++) {
261
+ double err = 0.0;
262
+ for (size_t j = 0; j < n_outputs; j++) {
263
+ const double el = vecs[order[order_offset + i] * n_outputs + j] - mean_vec[j];
264
+ err += is_mae ? std::fabs(el) : el * el;
265
+ }
266
+ err /= static_cast<double>(n_outputs);
267
+ sum_err += err;
268
+ }
269
+ const double impurity = sum_err / static_cast<double>(n_elements);
270
+ return impurity;
271
+ }
272
+
273
+ /**
274
+ * @!visibility private
275
+ * Find for split point with maximum information gain.
276
+ *
277
+ * @overload find_split_params(criterion, impurity, order, features, targets) -> Array<Float>
278
+ *
279
+ * @param criterion [String] The function to evaluate spliting point. Supported criteria are 'mae' and 'mse'.
280
+ * @param impurity [Float] The impurity of whole dataset.
281
+ * @param order [Numo::Int32] (shape: [n_samples]) The element indices sorted according to feature values in ascending order.
282
+ * @param features [Numo::DFloat] (shape: [n_samples]) The feature values.
283
+ * @param targets [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values.
284
+ * @return [Array<Float>] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
285
+ */
286
+
287
+ struct FindSplitParamsOpts_ {
288
+ std::string criterion;
289
+ double impurity;
290
+ };
291
+
292
+ static void iter_find_split_params_(na_loop_t const* lp) {
293
+ // Obtain iteration variables.
294
+ const int32_t* order = (int32_t*)NDL_PTR(lp, 0);
295
+ const size_t n_elements = NDL_SHAPE(lp, 0)[0];
296
+ const double* f = (double*)NDL_PTR(lp, 1);
297
+ const double* y = (double*)NDL_PTR(lp, 2);
298
+ const size_t n_outputs = NDL_SHAPE(lp, 2)[1];
299
+ const std::string criterion = ((FindSplitParamsOpts_*)lp->opt_ptr)->criterion;
300
+ const double w_impurity = ((FindSplitParamsOpts_*)lp->opt_ptr)->impurity;
301
+
302
+ // Initialize optimal parameters.
303
+ double* params = (double*)NDL_PTR(lp, 3);
304
+ params[0] = 0.0; // left impurity
305
+ params[1] = w_impurity; // right impurity
306
+ params[2] = f[order[0]]; // threshold
307
+ params[3] = 0.0; // gain
308
+
309
+ // Initialize child node variables.
310
+ std::vector<double> l_sum_y(n_outputs, 0);
311
+ std::vector<double> r_sum_y(n_outputs, 0);
312
+ for (size_t i = 0; i < n_elements; i++) {
313
+ for (size_t j = 0; j < n_outputs; j++) {
314
+ r_sum_y[j] += y[order[i] * n_outputs + j];
315
+ }
316
+ }
317
+
318
+ // Find optimal parameters.
319
+ size_t curr_pos = 0;
320
+ size_t next_pos = 0;
321
+ size_t n_l_elements = 0;
322
+ size_t n_r_elements = n_elements;
323
+ std::vector<double> l_mean_y(n_outputs, 0);
324
+ std::vector<double> r_mean_y(n_outputs, 0);
325
+ double curr_el = f[order[0]];
326
+ const double last_el = f[order[n_elements - 1]];
327
+ while (curr_pos < n_elements && curr_el != last_el) {
328
+ double next_el = f[order[next_pos]];
329
+ while (next_pos < n_elements && next_el == curr_el) {
330
+ for (size_t j = 0; j < n_outputs; j++) {
331
+ l_sum_y[j] += y[order[next_pos] * n_outputs + j];
332
+ r_sum_y[j] -= y[order[next_pos] * n_outputs + j];
333
+ }
334
+ n_l_elements++;
335
+ n_r_elements--;
336
+ next_pos++;
337
+ next_el = f[order[next_pos]];
338
+ }
339
+ // Calculate gain of new split.
340
+ for (size_t j = 0; j < n_outputs; j++) {
341
+ l_mean_y[j] = l_sum_y[j] / static_cast<double>(n_l_elements);
342
+ r_mean_y[j] = r_sum_y[j] / static_cast<double>(n_r_elements);
343
+ }
344
+ const double l_impurity = calc_impurity_(criterion, order, y, l_mean_y.data(), n_l_elements, n_outputs, 0);
345
+ const double r_impurity = calc_impurity_(criterion, order, y, r_mean_y.data(), n_r_elements, n_outputs, next_pos);
346
+ const double gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / static_cast<double>(n_elements);
347
+ // Update optimal parameters.
348
+ if (gain > params[3]) {
349
+ params[0] = l_impurity;
350
+ params[1] = r_impurity;
351
+ params[2] = 0.5 * (curr_el + next_el);
352
+ params[3] = gain;
353
+ }
354
+ if (next_pos == n_elements) break;
355
+ curr_pos = next_pos;
356
+ curr_el = f[order[curr_pos]];
357
+ }
358
+ }
359
+
360
+ static VALUE find_split_params_(VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE targets) {
361
+ ndfunc_arg_in_t ain[3] = { { numo_cInt32, 1 }, { numo_cDFloat, 1 }, { numo_cDFloat, 2 } };
362
+ size_t out_shape[1] = { 4 };
363
+ ndfunc_arg_out_t aout[1] = { { numo_cDFloat, 1, out_shape } };
364
+ ndfunc_t ndf = { (na_iter_func_t)iter_find_split_params_, NO_LOOP, 3, 1, ain, aout };
365
+ FindSplitParamsOpts_ opts = { std::string(StringValueCStr(criterion)), NUM2DBL(impurity) };
366
+ VALUE params = na_ndloop3(&ndf, &opts, 3, order, features, targets);
367
+ RB_GC_GUARD(criterion);
368
+ return params;
369
+ }
370
+
371
+ /**
372
+ * @!visibility private
373
+ * Calculate impurity based on criterion.
374
+ *
375
+ * @overload node_impurity(criterion, y) -> Float
376
+ *
377
+ * @param criterion [String] The function to calculate impurity. Supported criteria are 'mae' and 'mse'.
378
+ * @param y [Array<Float>] (shape: [n_samples, n_outputs]) The taget values.
379
+ * @return [Float] impurity
380
+ */
381
+
382
+ struct NodeImpurityOpts_ {
383
+ std::string criterion;
384
+ };
385
+
386
+ static void iter_node_impurity_(na_loop_t const* lp) {
387
+ const double* y = (double*)NDL_PTR(lp, 0);
388
+ const size_t n_elements = NDL_SHAPE(lp, 0)[0];
389
+ const size_t n_outputs = NDL_SHAPE(lp, 0)[1];
390
+ const std::string criterion = ((NodeImpurityOpts_*)lp->opt_ptr)->criterion;
391
+
392
+ std::vector<int32_t> order(n_elements);
393
+ std::vector<double> mean_y(n_outputs, 0);
394
+ for (size_t i = 0; i < n_elements; i++) {
395
+ order[i] = static_cast<int32_t>(i);
396
+ for (size_t j = 0; j < n_outputs; j++) {
397
+ mean_y[j] += y[i * n_outputs + j];
398
+ }
399
+ }
400
+ for (size_t j = 0; j < n_outputs; j++) {
401
+ mean_y[j] /= static_cast<double>(n_elements);
402
+ }
403
+
404
+ double* ret = (double*)NDL_PTR(lp, 1);
405
+ *ret = calc_impurity_(criterion, order.data(), y, mean_y.data(), n_elements, n_outputs, 0);
406
+ }
407
+
408
+ static VALUE node_impurity_(VALUE self, VALUE criterion, VALUE y) {
409
+ ndfunc_arg_in_t ain[1] = { { numo_cDFloat, 2 } };
410
+ ndfunc_arg_out_t aout[1] = { { numo_cDFloat, 0 } };
411
+ ndfunc_t ndf = { (na_iter_func_t)iter_node_impurity_, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
412
+ NodeImpurityOpts_ opts = { std::string(StringValueCStr(criterion)) };
413
+ VALUE ret = na_ndloop3(&ndf, &opts, 1, y);
414
+ RB_GC_GUARD(criterion);
415
+ return ret;
416
+ }
417
+
418
+ /**
419
+ * @!visibility private
420
+ * Check all elements have the same value/vector.
421
+ *
422
+ * @overload check_same_value(y) -> Boolean
423
+ *
424
+ * @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values.
425
+ * @return [Boolean]
426
+ */
427
+
428
+ static void iter_check_same_value_(na_loop_t const* lp) {
429
+ const double* y = (double*)NDL_PTR(lp, 0);
430
+ const size_t n_elements = NDL_SHAPE(lp, 0)[0];
431
+ const size_t n_outputs = NDL_SHAPE(lp, 0)[1];
432
+ VALUE* ret = (VALUE*)NDL_PTR(lp, 1);
433
+ const double eps = std::numeric_limits<double>::epsilon();
434
+ *ret = Qtrue;
435
+ if (n_elements > 0) {
436
+ for (size_t i = 1; i < n_elements; i++) {
437
+ for (size_t j = 0; j < n_outputs; j++) {
438
+ if (std::abs(y[i * n_outputs + j] - y[j]) > eps) {
439
+ *ret = Qfalse;
440
+ break;
441
+ }
442
+ }
443
+ if (*ret == Qfalse) break;
444
+ }
445
+ }
446
+ }
447
+
448
+ static VALUE check_same_value_(VALUE self, VALUE y) {
449
+ ndfunc_arg_in_t ain[1] = { { numo_cDFloat, 2 } };
450
+ ndfunc_arg_out_t aout[1] = { { numo_cRObject, 0 } };
451
+ ndfunc_t ndf = { (na_iter_func_t)iter_check_same_value_, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
452
+ return na_ndloop(&ndf, 1, y);
453
+ }
454
+ };
455
+
456
+ /**
457
+ * @!visibility private
458
+ * Document-module: Rumale::Tree::ExtGradientTreeRegressor
459
+ * The mixin module consisting of extension method for GradientTreeRegressor class.
460
+ * This module is used internally.
461
+ */
462
+ class ExtGradientTreeRegressor {
463
+ public:
464
+ static void define_module(VALUE& outer) {
465
+ VALUE rb_mExtGTreeReg = rb_define_module_under(outer, "ExtGradientTreeRegressor");
466
+ rb_define_private_method(rb_mExtGTreeReg, "find_split_params", find_split_params_, 7);
467
+ }
468
+
469
+ private:
470
+ /**
471
+ * @!visibility private
472
+ * Find for split point with maximum information gain.
473
+ *
474
+ * @overload find_split_params(order, features, gradients, hessians, sum_gradient, sum_hessian, reg_lambda) -> Array<Float>
475
+ * @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
476
+ * @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
477
+ * @param gradients [Numo::DFloat] (shape: [n_elements]) The gradient values.
478
+ * @param hessians [Numo::DFloat] (shape: [n_elements]) The hessian values.
479
+ * @param sum_gradient [Float] The sum of gradient values.
480
+ * @param sum_hessian [Float] The sum of hessian values.
481
+ * @param reg_lambda [Float] The L2 regularization term on weight.
482
+ * @return [Array<Float>] The array consists of optimal parameters including threshold and gain.
483
+ */
484
+
485
+ struct FindSplitParamsOpts_ {
486
+ double sum_gradient;
487
+ double sum_hessian;
488
+ double reg_lambda;
489
+ };
490
+
491
+ static void iter_find_split_params_(na_loop_t const* lp) {
492
+ // Obtain iteration variables.
493
+ const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
494
+ const size_t n_elements = NDL_SHAPE(lp, 0)[0];
495
+ const double* f = (double*)NDL_PTR(lp, 1);
496
+ const double* g = (double*)NDL_PTR(lp, 2);
497
+ const double* h = (double*)NDL_PTR(lp, 3);
498
+ const double s_grad = ((FindSplitParamsOpts_*)lp->opt_ptr)->sum_gradient;
499
+ const double s_hess = ((FindSplitParamsOpts_*)lp->opt_ptr)->sum_hessian;
500
+ const double reg_lambda = ((FindSplitParamsOpts_*)lp->opt_ptr)->reg_lambda;
501
+
502
+ // Find optimal parameters.
503
+ size_t curr_pos = 0;
504
+ size_t next_pos = 0;
505
+ double curr_el = f[o[0]];
506
+ const double last_el = f[o[n_elements - 1]];
507
+ double l_grad = 0.0;
508
+ double l_hess = 0.0;
509
+ double threshold = curr_el;
510
+ double gain_max = 0.0;
511
+ while (curr_pos < n_elements && curr_el != last_el) {
512
+ double next_el = f[o[next_pos]];
513
+ while (next_pos < n_elements && next_el == curr_el) {
514
+ l_grad += g[o[next_pos]];
515
+ l_hess += h[o[next_pos]];
516
+ next_pos++;
517
+ next_el = f[o[next_pos]];
518
+ }
519
+ // Calculate gain of new split.
520
+ const double r_grad = s_grad - l_grad;
521
+ const double r_hess = s_hess - l_hess;
522
+ const double gain = (l_grad * l_grad) / (l_hess + reg_lambda) + (r_grad * r_grad) / (r_hess + reg_lambda) - (s_grad * s_grad) / (s_hess + reg_lambda);
523
+ // Update optimal parameters.
524
+ if (gain > gain_max) {
525
+ threshold = 0.5 * (curr_el + next_el);
526
+ gain_max = gain;
527
+ }
528
+ if (next_pos == n_elements) break;
529
+ curr_pos = next_pos;
530
+ curr_el = f[o[curr_pos]];
531
+ }
532
+
533
+ double* params = (double*)NDL_PTR(lp, 4);
534
+ params[0] = threshold;
535
+ params[1] = gain_max;
536
+ }
537
+
538
+ static VALUE find_split_params_(VALUE self, VALUE order, VALUE features, VALUE gradients, VALUE hessians,
539
+ VALUE sum_gradient, VALUE sum_hessian, VALUE reg_lambda) {
540
+ ndfunc_arg_in_t ain[4] = { { numo_cInt32, 1 }, { numo_cDFloat, 1 }, { numo_cDFloat, 1 }, { numo_cDFloat, 1 } };
541
+ size_t out_shape[1] = { 2 };
542
+ ndfunc_arg_out_t aout[1] = { { numo_cDFloat, 1, out_shape } };
543
+ ndfunc_t ndf = { (na_iter_func_t)iter_find_split_params_, NO_LOOP, 4, 1, ain, aout };
544
+ FindSplitParamsOpts_ opts = { NUM2DBL(sum_gradient), NUM2DBL(sum_hessian), NUM2DBL(reg_lambda) };
545
+ VALUE params = na_ndloop3(&ndf, &opts, 4, order, features, gradients, hessians);
546
+ return params;
547
+ }
548
+ };
549
+
550
+ #endif /* RUMALE_TREE_EXT_HPP */
@@ -3,6 +3,8 @@
3
3
  require 'mkmf'
4
4
  require 'numo/narray'
5
5
 
6
+ abort 'libstdc++ is not found.' unless have_library('stdc++')
7
+
6
8
  $LOAD_PATH.each do |lp|
7
9
  if File.exist?(File.join(lp, 'numo/numo/narray.h'))
8
10
  $INCFLAGS = "-I#{lp}/numo #{$INCFLAGS}"
@@ -63,11 +63,7 @@ module Rumale
63
63
  end
64
64
 
65
65
  def build_tree(x, y)
66
- y = y.expand_dims(1).dup if y.shape[1].nil?
67
- @feature_ids = Array.new(x.shape[1]) { |v| v }
68
- @tree = grow_node(0, x, y, impurity(y))
69
- @feature_ids = nil
70
- nil
66
+ raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
71
67
  end
72
68
 
73
69
  def grow_node(depth, x, y, impurity) # rubocop:disable Metrics/AbcSize, Metrics/PerceivedComplexity
@@ -68,6 +68,7 @@ module Rumale
68
68
  @params[:max_features] = [@params[:max_features], n_features].min
69
69
  @n_leaves = 0
70
70
  @leaf_values = []
71
+ @feature_ids = Array.new(x.shape[1]) { |v| v }
71
72
  @sub_rng = @rng.dup
72
73
  build_tree(x, y)
73
74
  eval_importance(n_samples, n_features)
@@ -88,8 +89,10 @@ module Rumale
88
89
 
89
90
  private
90
91
 
91
- def stop_growing?(y)
92
- y.to_a.uniq.size == 1
92
+ def build_tree(x, y)
93
+ y = y.expand_dims(1).dup if y.shape[1].nil?
94
+ @tree = grow_node(0, x, y, impurity(y))
95
+ nil
93
96
  end
94
97
 
95
98
  def put_leaf(node, y)
@@ -106,7 +109,7 @@ module Rumale
106
109
  end
107
110
 
108
111
  def impurity(y)
109
- node_impurity(@params[:criterion], y.to_a)
112
+ node_impurity(@params[:criterion], y)
110
113
  end
111
114
  end
112
115
  end
@@ -5,6 +5,6 @@ module Rumale
5
5
  # This module consists of the classes that implement tree models.
6
6
  module Tree
7
7
  # @!visibility private
8
- VERSION = '0.26.0'
8
+ VERSION = '0.28.0'
9
9
  end
10
10
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rumale-tree
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.26.0
4
+ version: 0.28.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-02-19 00:00:00.000000000 Z
11
+ date: 2023-11-12 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -30,14 +30,14 @@ dependencies:
30
30
  requirements:
31
31
  - - "~>"
32
32
  - !ruby/object:Gem::Version
33
- version: 0.26.0
33
+ version: 0.28.0
34
34
  type: :runtime
35
35
  prerelease: false
36
36
  version_requirements: !ruby/object:Gem::Requirement
37
37
  requirements:
38
38
  - - "~>"
39
39
  - !ruby/object:Gem::Version
40
- version: 0.26.0
40
+ version: 0.28.0
41
41
  description: Rumale::Tree provides classifier and regression based on decision tree
42
42
  algorithms with Rumale interface.
43
43
  email:
@@ -49,8 +49,8 @@ extra_rdoc_files: []
49
49
  files:
50
50
  - LICENSE.txt
51
51
  - README.md
52
- - ext/rumale/tree/ext.c
53
- - ext/rumale/tree/ext.h
52
+ - ext/rumale/tree/ext.cpp
53
+ - ext/rumale/tree/ext.hpp
54
54
  - ext/rumale/tree/extconf.rb
55
55
  - lib/rumale/tree.rb
56
56
  - lib/rumale/tree/base_decision_tree.rb
@@ -85,7 +85,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
85
85
  - !ruby/object:Gem::Version
86
86
  version: '0'
87
87
  requirements: []
88
- rubygems_version: 3.3.26
88
+ rubygems_version: 3.4.20
89
89
  signing_key:
90
90
  specification_version: 4
91
91
  summary: Rumale::Tree provides classifier and regression based on decision tree algorithms
@@ -1,575 +0,0 @@
1
- #include "ext.h"
2
-
3
- double* alloc_dbl_array(const long n_dimensions) {
4
- double* arr = ALLOC_N(double, n_dimensions);
5
- memset(arr, 0, n_dimensions * sizeof(double));
6
- return arr;
7
- }
8
-
9
- double calc_gini_coef(double* histogram, const long n_elements, const long n_classes) {
10
- double gini = 0.0;
11
-
12
- for (long i = 0; i < n_classes; i++) {
13
- double el = histogram[i] / n_elements;
14
- gini += el * el;
15
- }
16
-
17
- return 1.0 - gini;
18
- }
19
-
20
- double calc_entropy(double* histogram, const long n_elements, const long n_classes) {
21
- double entropy = 0.0;
22
-
23
- for (long i = 0; i < n_classes; i++) {
24
- double el = histogram[i] / n_elements;
25
- entropy += el * log(el + 1.0);
26
- }
27
-
28
- return -entropy;
29
- }
30
-
31
- VALUE
32
- calc_mean_vec(double* sum_vec, const long n_dimensions, const long n_elements) {
33
- VALUE mean_vec = rb_ary_new2(n_dimensions);
34
-
35
- for (long i = 0; i < n_dimensions; i++) {
36
- rb_ary_store(mean_vec, i, DBL2NUM(sum_vec[i] / n_elements));
37
- }
38
-
39
- return mean_vec;
40
- }
41
-
42
- double calc_vec_mae(VALUE vec_a, VALUE vec_b) {
43
- const long n_dimensions = RARRAY_LEN(vec_a);
44
- double sum = 0.0;
45
-
46
- for (long i = 0; i < n_dimensions; i++) {
47
- double diff = NUM2DBL(rb_ary_entry(vec_a, i)) - NUM2DBL(rb_ary_entry(vec_b, i));
48
- sum += fabs(diff);
49
- }
50
-
51
- return sum / n_dimensions;
52
- }
53
-
54
- double calc_vec_mse(VALUE vec_a, VALUE vec_b) {
55
- const long n_dimensions = RARRAY_LEN(vec_a);
56
- double sum = 0.0;
57
-
58
- for (long i = 0; i < n_dimensions; i++) {
59
- double diff = NUM2DBL(rb_ary_entry(vec_a, i)) - NUM2DBL(rb_ary_entry(vec_b, i));
60
- sum += diff * diff;
61
- }
62
-
63
- return sum / n_dimensions;
64
- }
65
-
66
- double calc_mae(VALUE target_vecs, VALUE mean_vec) {
67
- const long n_elements = RARRAY_LEN(target_vecs);
68
- double sum = 0.0;
69
-
70
- for (long i = 0; i < n_elements; i++) {
71
- sum += calc_vec_mae(rb_ary_entry(target_vecs, i), mean_vec);
72
- }
73
-
74
- return sum / n_elements;
75
- }
76
-
77
- double calc_mse(VALUE target_vecs, VALUE mean_vec) {
78
- const long n_elements = RARRAY_LEN(target_vecs);
79
- double sum = 0.0;
80
-
81
- for (long i = 0; i < n_elements; i++) {
82
- sum += calc_vec_mse(rb_ary_entry(target_vecs, i), mean_vec);
83
- }
84
-
85
- return sum / n_elements;
86
- }
87
-
88
- double calc_impurity_cls(const char* criterion, double* histogram, const long n_elements, const long n_classes) {
89
- if (strcmp(criterion, "entropy") == 0) {
90
- return calc_entropy(histogram, n_elements, n_classes);
91
- }
92
- return calc_gini_coef(histogram, n_elements, n_classes);
93
- }
94
-
95
- double calc_impurity_reg(const char* criterion, VALUE target_vecs, double* sum_vec) {
96
- const long n_elements = RARRAY_LEN(target_vecs);
97
- const long n_dimensions = RARRAY_LEN(rb_ary_entry(target_vecs, 0));
98
- VALUE mean_vec = calc_mean_vec(sum_vec, n_dimensions, n_elements);
99
-
100
- if (strcmp(criterion, "mae") == 0) {
101
- return calc_mae(target_vecs, mean_vec);
102
- }
103
- return calc_mse(target_vecs, mean_vec);
104
- }
105
-
106
- void add_sum_vec(double* sum_vec, VALUE target) {
107
- const long n_dimensions = RARRAY_LEN(target);
108
-
109
- for (long i = 0; i < n_dimensions; i++) {
110
- sum_vec[i] += NUM2DBL(rb_ary_entry(target, i));
111
- }
112
- }
113
-
114
- void sub_sum_vec(double* sum_vec, VALUE target) {
115
- const long n_dimensions = RARRAY_LEN(target);
116
-
117
- for (long i = 0; i < n_dimensions; i++) {
118
- sum_vec[i] -= NUM2DBL(rb_ary_entry(target, i));
119
- }
120
- }
121
-
122
- /**
123
- * @!visibility private
124
- */
125
- typedef struct {
126
- char* criterion;
127
- long n_classes;
128
- double impurity;
129
- } split_opts_cls;
130
-
131
- /**
132
- * @!visibility private
133
- */
134
- static void iter_find_split_params_cls(na_loop_t const* lp) {
135
- const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
136
- const double* f = (double*)NDL_PTR(lp, 1);
137
- const int32_t* y = (int32_t*)NDL_PTR(lp, 2);
138
- const long n_elements = NDL_SHAPE(lp, 0)[0];
139
- const char* criterion = ((split_opts_cls*)lp->opt_ptr)->criterion;
140
- const long n_classes = ((split_opts_cls*)lp->opt_ptr)->n_classes;
141
- const double w_impurity = ((split_opts_cls*)lp->opt_ptr)->impurity;
142
- double* params = (double*)NDL_PTR(lp, 3);
143
- long curr_pos = 0;
144
- long next_pos = 0;
145
- long n_l_elements = 0;
146
- long n_r_elements = n_elements;
147
- double curr_el = f[o[0]];
148
- double last_el = f[o[n_elements - 1]];
149
- double next_el;
150
- double l_impurity;
151
- double r_impurity;
152
- double gain;
153
- double* l_histogram = alloc_dbl_array(n_classes);
154
- double* r_histogram = alloc_dbl_array(n_classes);
155
-
156
- /* Initialize optimal parameters. */
157
- params[0] = 0.0; /* left impurity */
158
- params[1] = w_impurity; /* right impurity */
159
- params[2] = curr_el; /* threshold */
160
- params[3] = 0.0; /* gain */
161
-
162
- /* Initialize child node variables. */
163
- for (long i = 0; i < n_elements; i++) {
164
- r_histogram[y[o[i]]] += 1.0;
165
- }
166
-
167
- /* Find optimal parameters. */
168
- while (curr_pos < n_elements && curr_el != last_el) {
169
- next_el = f[o[next_pos]];
170
- while (next_pos < n_elements && next_el == curr_el) {
171
- l_histogram[y[o[next_pos]]] += 1;
172
- n_l_elements++;
173
- r_histogram[y[o[next_pos]]] -= 1;
174
- n_r_elements--;
175
- next_pos++;
176
- next_el = f[o[next_pos]];
177
- }
178
- /* Calculate gain of new split. */
179
- l_impurity = calc_impurity_cls(criterion, l_histogram, n_l_elements, n_classes);
180
- r_impurity = calc_impurity_cls(criterion, r_histogram, n_r_elements, n_classes);
181
- gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / n_elements;
182
- /* Update optimal parameters. */
183
- if (gain > params[3]) {
184
- params[0] = l_impurity;
185
- params[1] = r_impurity;
186
- params[2] = 0.5 * (curr_el + next_el);
187
- params[3] = gain;
188
- }
189
- if (next_pos == n_elements) break;
190
- curr_pos = next_pos;
191
- curr_el = f[o[curr_pos]];
192
- }
193
-
194
- xfree(l_histogram);
195
- xfree(r_histogram);
196
- }
197
-
198
- /**
199
- * @!visibility private
200
- * Find for split point with maximum information gain.
201
- *
202
- * @overload find_split_params(criterion, impurity, order, features, labels, n_classes) -> Array<Float>
203
- *
204
- * @param criterion [String] The function to evaluate spliting point. Supported criteria are 'gini' and 'entropy'.
205
- * @param impurity [Float] The impurity of whole dataset.
206
- * @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
207
- * @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
208
- * @param labels [Numo::Int32] (shape: [n_elements]) The labels.
209
- * @param n_classes [Integer] The number of classes.
210
- * @return [Array<Float>] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
211
- */
212
- static VALUE find_split_params_cls(VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE labels,
213
- VALUE n_classes) {
214
- ndfunc_arg_in_t ain[3] = {{numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cInt32, 1}};
215
- size_t out_shape[1] = {4};
216
- ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 1, out_shape}};
217
- ndfunc_t ndf = {(na_iter_func_t)iter_find_split_params_cls, NO_LOOP, 3, 1, ain, aout};
218
- split_opts_cls opts = {StringValuePtr(criterion), NUM2LONG(n_classes), NUM2DBL(impurity)};
219
- VALUE params = na_ndloop3(&ndf, &opts, 3, order, features, labels);
220
- VALUE results = rb_ary_new2(4);
221
- double* params_ptr = (double*)na_get_pointer_for_read(params);
222
- rb_ary_store(results, 0, DBL2NUM(params_ptr[0]));
223
- rb_ary_store(results, 1, DBL2NUM(params_ptr[1]));
224
- rb_ary_store(results, 2, DBL2NUM(params_ptr[2]));
225
- rb_ary_store(results, 3, DBL2NUM(params_ptr[3]));
226
- RB_GC_GUARD(params);
227
- RB_GC_GUARD(criterion);
228
- return results;
229
- }
230
-
231
- /**
232
- * @!visibility private
233
- */
234
- typedef struct {
235
- char* criterion;
236
- double impurity;
237
- } split_opts_reg;
238
-
239
- /**
240
- * @!visibility private
241
- */
242
- static void iter_find_split_params_reg(na_loop_t const* lp) {
243
- const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
244
- const double* f = (double*)NDL_PTR(lp, 1);
245
- const double* y = (double*)NDL_PTR(lp, 2);
246
- const long n_elements = NDL_SHAPE(lp, 0)[0];
247
- const long n_outputs = NDL_SHAPE(lp, 2)[1];
248
- const char* criterion = ((split_opts_reg*)lp->opt_ptr)->criterion;
249
- const double w_impurity = ((split_opts_reg*)lp->opt_ptr)->impurity;
250
- double* params = (double*)NDL_PTR(lp, 3);
251
- long curr_pos = 0;
252
- long next_pos = 0;
253
- long n_l_elements = 0;
254
- long n_r_elements = n_elements;
255
- double curr_el = f[o[0]];
256
- double last_el = f[o[n_elements - 1]];
257
- double next_el;
258
- double l_impurity;
259
- double r_impurity;
260
- double gain;
261
- double* l_sum_vec = alloc_dbl_array(n_outputs);
262
- double* r_sum_vec = alloc_dbl_array(n_outputs);
263
- double target_var;
264
- VALUE l_target_vecs = rb_ary_new();
265
- VALUE r_target_vecs = rb_ary_new();
266
- VALUE target;
267
-
268
- /* Initialize optimal parameters. */
269
- params[0] = 0.0; /* left impurity */
270
- params[1] = w_impurity; /* right impurity */
271
- params[2] = curr_el; /* threshold */
272
- params[3] = 0.0; /* gain */
273
-
274
- /* Initialize child node variables. */
275
- for (long i = 0; i < n_elements; i++) {
276
- target = rb_ary_new2(n_outputs);
277
- for (long j = 0; j < n_outputs; j++) {
278
- target_var = y[o[i] * n_outputs + j];
279
- rb_ary_store(target, j, DBL2NUM(target_var));
280
- r_sum_vec[j] += target_var;
281
- }
282
- rb_ary_push(r_target_vecs, target);
283
- }
284
-
285
- /* Find optimal parameters. */
286
- while (curr_pos < n_elements && curr_el != last_el) {
287
- next_el = f[o[next_pos]];
288
- while (next_pos < n_elements && next_el == curr_el) {
289
- target = rb_ary_shift(r_target_vecs);
290
- n_r_elements--;
291
- sub_sum_vec(r_sum_vec, target);
292
- rb_ary_push(l_target_vecs, target);
293
- n_l_elements++;
294
- add_sum_vec(l_sum_vec, target);
295
- next_pos++;
296
- next_el = f[o[next_pos]];
297
- }
298
- /* Calculate gain of new split. */
299
- l_impurity = calc_impurity_reg(criterion, l_target_vecs, l_sum_vec);
300
- r_impurity = calc_impurity_reg(criterion, r_target_vecs, r_sum_vec);
301
- gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / n_elements;
302
- /* Update optimal parameters. */
303
- if (gain > params[3]) {
304
- params[0] = l_impurity;
305
- params[1] = r_impurity;
306
- params[2] = 0.5 * (curr_el + next_el);
307
- params[3] = gain;
308
- }
309
- if (next_pos == n_elements) break;
310
- curr_pos = next_pos;
311
- curr_el = f[o[curr_pos]];
312
- }
313
-
314
- xfree(l_sum_vec);
315
- xfree(r_sum_vec);
316
- }
317
-
318
- /**
319
- * @!visibility private
320
- * Find for split point with maximum information gain.
321
- *
322
- * @overload find_split_params(criterion, impurity, order, features, targets) -> Array<Float>
323
- *
324
- * @param criterion [String] The function to evaluate spliting point. Supported criteria are 'mae' and 'mse'.
325
- * @param impurity [Float] The impurity of whole dataset.
326
- * @param order [Numo::Int32] (shape: [n_samples]) The element indices sorted according to feature values in ascending order.
327
- * @param features [Numo::DFloat] (shape: [n_samples]) The feature values.
328
- * @param targets [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values.
329
- * @return [Array<Float>] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
330
- */
331
- static VALUE find_split_params_reg(VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE targets) {
332
- ndfunc_arg_in_t ain[3] = {{numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 2}};
333
- size_t out_shape[1] = {4};
334
- ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 1, out_shape}};
335
- ndfunc_t ndf = {(na_iter_func_t)iter_find_split_params_reg, NO_LOOP, 3, 1, ain, aout};
336
- split_opts_reg opts = {StringValuePtr(criterion), NUM2DBL(impurity)};
337
- VALUE params = na_ndloop3(&ndf, &opts, 3, order, features, targets);
338
- VALUE results = rb_ary_new2(4);
339
- double* params_ptr = (double*)na_get_pointer_for_read(params);
340
- rb_ary_store(results, 0, DBL2NUM(params_ptr[0]));
341
- rb_ary_store(results, 1, DBL2NUM(params_ptr[1]));
342
- rb_ary_store(results, 2, DBL2NUM(params_ptr[2]));
343
- rb_ary_store(results, 3, DBL2NUM(params_ptr[3]));
344
- RB_GC_GUARD(params);
345
- RB_GC_GUARD(criterion);
346
- return results;
347
- }
348
-
349
- /**
350
- * @!visibility private
351
- */
352
- static void iter_find_split_params_grad_reg(na_loop_t const* lp) {
353
- const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
354
- const double* f = (double*)NDL_PTR(lp, 1);
355
- const double* g = (double*)NDL_PTR(lp, 2);
356
- const double* h = (double*)NDL_PTR(lp, 3);
357
- const double s_grad = ((double*)lp->opt_ptr)[0];
358
- const double s_hess = ((double*)lp->opt_ptr)[1];
359
- const double reg_lambda = ((double*)lp->opt_ptr)[2];
360
- const long n_elements = NDL_SHAPE(lp, 0)[0];
361
- double* params = (double*)NDL_PTR(lp, 4);
362
- long curr_pos = 0;
363
- long next_pos = 0;
364
- double curr_el = f[o[0]];
365
- double last_el = f[o[n_elements - 1]];
366
- double next_el;
367
- double l_grad = 0.0;
368
- double l_hess = 0.0;
369
- double r_grad;
370
- double r_hess;
371
- double threshold = curr_el;
372
- double gain_max = 0.0;
373
- double gain;
374
-
375
- /* Find optimal parameters. */
376
- while (curr_pos < n_elements && curr_el != last_el) {
377
- next_el = f[o[next_pos]];
378
- while (next_pos < n_elements && next_el == curr_el) {
379
- l_grad += g[o[next_pos]];
380
- l_hess += h[o[next_pos]];
381
- next_pos++;
382
- next_el = f[o[next_pos]];
383
- }
384
- /* Calculate gain of new split. */
385
- r_grad = s_grad - l_grad;
386
- r_hess = s_hess - l_hess;
387
- gain = (l_grad * l_grad) / (l_hess + reg_lambda) + (r_grad * r_grad) / (r_hess + reg_lambda) -
388
- (s_grad * s_grad) / (s_hess + reg_lambda);
389
- /* Update optimal parameters. */
390
- if (gain > gain_max) {
391
- threshold = 0.5 * (curr_el + next_el);
392
- gain_max = gain;
393
- }
394
- if (next_pos == n_elements) {
395
- break;
396
- }
397
- curr_pos = next_pos;
398
- curr_el = f[o[curr_pos]];
399
- }
400
-
401
- params[0] = threshold;
402
- params[1] = gain_max;
403
- }
404
-
405
- /**
406
- * @!visibility private
407
- * Find for split point with maximum information gain.
408
- *
409
- * @overload find_split_params(order, features, gradients, hessians, sum_gradient, sum_hessian, reg_lambda) -> Array<Float>
410
- * @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
411
- * @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
412
- * @param gradients [Numo::DFloat] (shape: [n_elements]) The gradient values.
413
- * @param hessians [Numo::DFloat] (shape: [n_elements]) The hessian values.
414
- * @param sum_gradient [Float] The sum of gradient values.
415
- * @param sum_hessian [Float] The sum of hessian values.
416
- * @param reg_lambda [Float] The L2 regularization term on weight.
417
- * @return [Array<Float>] The array consists of optimal parameters including threshold and gain.
418
- */
419
- static VALUE find_split_params_grad_reg(VALUE self, VALUE order, VALUE features, VALUE gradients, VALUE hessians,
420
- VALUE sum_gradient, VALUE sum_hessian, VALUE reg_lambda) {
421
- ndfunc_arg_in_t ain[4] = {{numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 1}};
422
- size_t out_shape[1] = {2};
423
- ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 1, out_shape}};
424
- ndfunc_t ndf = {(na_iter_func_t)iter_find_split_params_grad_reg, NO_LOOP, 4, 1, ain, aout};
425
- double opts[3] = {NUM2DBL(sum_gradient), NUM2DBL(sum_hessian), NUM2DBL(reg_lambda)};
426
- VALUE params = na_ndloop3(&ndf, opts, 4, order, features, gradients, hessians);
427
- VALUE results = rb_ary_new2(2);
428
- double* params_ptr = (double*)na_get_pointer_for_read(params);
429
- rb_ary_store(results, 0, DBL2NUM(params_ptr[0]));
430
- rb_ary_store(results, 1, DBL2NUM(params_ptr[1]));
431
- RB_GC_GUARD(params);
432
- return results;
433
- }
434
-
435
- /**
436
- * @!visibility private
437
- */
438
- typedef struct {
439
- char* criterion;
440
- long n_classes;
441
- } node_impurity_cls_opts;
442
-
443
- /**
444
- * @!visibility private
445
- */
446
- static void iter_node_impurity_cls(na_loop_t const* lp) {
447
- const int32_t* y = (int32_t*)NDL_PTR(lp, 0);
448
- const char* criterion = ((node_impurity_cls_opts*)lp->opt_ptr)->criterion;
449
- const long n_classes = ((node_impurity_cls_opts*)lp->opt_ptr)->n_classes;
450
- const long n_elements = NDL_SHAPE(lp, 0)[0];
451
- double* ret = (double*)NDL_PTR(lp, 1);
452
- double* histogram = alloc_dbl_array(n_classes);
453
- for (long i = 0; i < n_elements; i++) histogram[y[i]] += 1;
454
- *ret = calc_impurity_cls(criterion, histogram, n_elements, n_classes);
455
- xfree(histogram);
456
- }
457
-
458
- /**
459
- * @!visibility private
460
- * Calculate impurity based on criterion.
461
- *
462
- * @overload node_impurity(criterion, y, n_classes) -> Float
463
- *
464
- * @param criterion [String] The function to calculate impurity. Supported criteria are 'gini' and 'entropy'.
465
- * @param y [Numo::Int32] (shape: [n_samples]) The labels.
466
- * @param n_classes [Integer] The number of classes.
467
- * @return [Float] impurity
468
- */
469
- static VALUE node_impurity_cls(VALUE self, VALUE criterion, VALUE y, VALUE n_classes) {
470
- ndfunc_arg_in_t ain[1] = {{numo_cInt32, 1}};
471
- ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 0}};
472
- ndfunc_t ndf = {(na_iter_func_t)iter_node_impurity_cls, NDF_EXTRACT, 1, 1, ain, aout};
473
- node_impurity_cls_opts opts = {StringValuePtr(criterion), NUM2LONG(n_classes)};
474
- VALUE ret = na_ndloop3(&ndf, &opts, 1, y);
475
- RB_GC_GUARD(criterion);
476
- return ret;
477
- }
478
-
479
- /**
480
- * @!visibility private
481
- */
482
- static void iter_check_same_label(na_loop_t const* lp) {
483
- const int32_t* y = (int32_t*)NDL_PTR(lp, 0);
484
- const long n_elements = NDL_SHAPE(lp, 0)[0];
485
- int32_t* ret = (int32_t*)NDL_PTR(lp, 1);
486
- *ret = 1;
487
- if (n_elements > 0) {
488
- int32_t label = y[0];
489
- for (long i = 0; i < n_elements; i++) {
490
- if (y[i] != label) {
491
- *ret = 0;
492
- break;
493
- }
494
- }
495
- }
496
- }
497
-
498
- /**
499
- * @!visibility private
500
- * Check all elements have the save value.
501
- *
502
- * @overload check_same_label(y) -> Boolean
503
- *
504
- * @param y [Numo::Int32] (shape: [n_samples]) The labels.
505
- * @return [Boolean]
506
- */
507
- static VALUE check_same_label(VALUE self, VALUE y) {
508
- ndfunc_arg_in_t ain[1] = {{numo_cInt32, 1}};
509
- ndfunc_arg_out_t aout[1] = {{numo_cInt32, 0}};
510
- ndfunc_t ndf = {(na_iter_func_t)iter_check_same_label, NDF_EXTRACT, 1, 1, ain, aout};
511
- VALUE ret = na_ndloop(&ndf, 1, y);
512
- return (NUM2INT(ret) == 1 ? Qtrue : Qfalse);
513
- }
514
-
515
- /**
516
- * @!visibility private
517
- * Calculate impurity based on criterion.
518
- *
519
- * @overload node_impurity(criterion, y) -> Float
520
- *
521
- * @param criterion [String] The function to calculate impurity. Supported criteria are 'mae' and 'mse'.
522
- * @param y [Array<Float>] (shape: [n_samples, n_outputs]) The taget values.
523
- * @return [Float] impurity
524
- */
525
- static VALUE node_impurity_reg(VALUE self, VALUE criterion, VALUE y) {
526
- const long n_elements = RARRAY_LEN(y);
527
- const long n_outputs = RARRAY_LEN(rb_ary_entry(y, 0));
528
- double* sum_vec = alloc_dbl_array(n_outputs);
529
- VALUE target_vecs = rb_ary_new();
530
-
531
- for (long i = 0; i < n_elements; i++) {
532
- VALUE target = rb_ary_entry(y, i);
533
- add_sum_vec(sum_vec, target);
534
- rb_ary_push(target_vecs, target);
535
- }
536
-
537
- VALUE ret = DBL2NUM(calc_impurity_reg(StringValuePtr(criterion), target_vecs, sum_vec));
538
- xfree(sum_vec);
539
- RB_GC_GUARD(criterion);
540
- return ret;
541
- }
542
-
543
- void Init_ext(void) {
544
- VALUE rb_mRumale = rb_define_module("Rumale");
545
- VALUE rb_mTree = rb_define_module_under(rb_mRumale, "Tree");
546
-
547
- /**
548
- * Document-module: Rumale::Tree::ExtDecisionTreeClassifier
549
- * @!visibility private
550
- * The mixin module consisting of extension method for DecisionTreeClassifier class.
551
- * This module is used internally.
552
- */
553
- VALUE rb_mExtDTreeCls = rb_define_module_under(rb_mTree, "ExtDecisionTreeClassifier");
554
- /**
555
- * Document-module: Rumale::Tree::ExtDecisionTreeRegressor
556
- * @!visibility private
557
- * The mixin module consisting of extension method for DecisionTreeRegressor class.
558
- * This module is used internally.
559
- */
560
- VALUE rb_mExtDTreeReg = rb_define_module_under(rb_mTree, "ExtDecisionTreeRegressor");
561
- /**
562
- * Document-module: Rumale::Tree::ExtGradientTreeRegressor
563
- * @!visibility private
564
- * The mixin module consisting of extension method for GradientTreeRegressor class.
565
- * This module is used internally.
566
- */
567
- VALUE rb_mExtGTreeReg = rb_define_module_under(rb_mTree, "ExtGradientTreeRegressor");
568
-
569
- rb_define_private_method(rb_mExtDTreeCls, "find_split_params", find_split_params_cls, 6);
570
- rb_define_private_method(rb_mExtDTreeReg, "find_split_params", find_split_params_reg, 5);
571
- rb_define_private_method(rb_mExtGTreeReg, "find_split_params", find_split_params_grad_reg, 7);
572
- rb_define_private_method(rb_mExtDTreeCls, "node_impurity", node_impurity_cls, 3);
573
- rb_define_private_method(rb_mExtDTreeCls, "stop_growing?", check_same_label, 1);
574
- rb_define_private_method(rb_mExtDTreeReg, "node_impurity", node_impurity_reg, 2);
575
- }
@@ -1,12 +0,0 @@
1
- #ifndef RUMALE_TREE_EXT_H
2
- #define RUMALE_TREE_EXT_H 1
3
-
4
- #include <math.h>
5
- #include <string.h>
6
-
7
- #include <ruby.h>
8
-
9
- #include <numo/narray.h>
10
- #include <numo/template.h>
11
-
12
- #endif /* RUMALE_TREE_EXT_H */