rumale-tree 0.26.0 → 0.28.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: eae93f4d3753b16336ffed4b0a467cf09ddb33b0b89005a83ed5088ea2b362ef
4
- data.tar.gz: 73d6c88f3653b5d2c70ef884fb7f5ef5a68e91d79abd19570476c13f02c434c6
3
+ metadata.gz: a437c3b879dd51d2e823f1851ca8b350d0441947fb6cefe727de83db78b4e6d9
4
+ data.tar.gz: 519665f2baea649ec31c9d5529aa05c67a880daa8ab60149433a53483f630663
5
5
  SHA512:
6
- metadata.gz: 6a5fa902ab481cc7fe7bdbc38e2e4da2f57c4b712d47fef6daabfd757d263949e19d12c5324a7ff0b0372e656a435908d3049380aef4f241c7ff39fc8914bf44
7
- data.tar.gz: 7a5e3b465549928e01143949e4001d7825882a687591adb91f82f7b753f7527c4fa2e5787ad5af263777855f3e8a3e5101386fdc361da3940f7f6dce0933c5bb
6
+ metadata.gz: aa68d3e2fbe99ef4c8d15a10f64098abfbbcdce9047053967e578edad688fa08a03a1ac4e4d23dbd5e67a4a23370057aa9462107bb961eb4d3f12782ce0c0f18
7
+ data.tar.gz: 6a6e6f167e31d56051768a881a4db0da7ef02f36977cb3969e8160bad385757153b1f6be32dd364ed3f67ed4672dd0c1bfd14cc80e0b033e00dd282d757955a9
@@ -0,0 +1,39 @@
1
+ /**
2
+ * Copyright (c) 2022-2023 Atsushi Tatsuma
3
+ * All rights reserved.
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * * Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * * Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * * Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ */
30
+
31
+ #include "ext.hpp"
32
+
33
+ extern "C" void Init_ext(void) {
34
+ VALUE rb_mRumale = rb_define_module("Rumale");
35
+ VALUE rb_mTree = rb_define_module_under(rb_mRumale, "Tree");
36
+ ExtDecisionTreeClassifier::define_module(rb_mTree);
37
+ ExtDecisionTreeRegressor::define_module(rb_mTree);
38
+ ExtGradientTreeRegressor::define_module(rb_mTree);
39
+ }
@@ -0,0 +1,550 @@
1
+ /**
2
+ * Copyright (c) 2022-2023 Atsushi Tatsuma
3
+ * All rights reserved.
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * * Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * * Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * * Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ */
30
+
31
+ #ifndef RUMALE_TREE_EXT_HPP
32
+ #define RUMALE_TREE_EXT_HPP 1
33
+
34
+ #include <cmath>
35
+ #include <limits>
36
+ #include <string>
37
+ #include <vector>
38
+
39
+ #include <ruby.h>
40
+
41
+ #include <numo/narray.h>
42
+ #include <numo/template.h>
43
+
44
+ /**
45
+ * @!visibility private
46
+ * Document-module: Rumale::Tree::ExtDecisionTreeClassifier
47
+ * The mixin module consisting of extension methods for the DecisionTreeClassifier class.
48
+ * This module is used internally.
49
+ */
50
+ class ExtDecisionTreeClassifier {
51
+ public:
52
+ static void define_module(VALUE& outer) {
53
+ VALUE rb_mExtDTreeCls = rb_define_module_under(outer, "ExtDecisionTreeClassifier");
54
+ rb_define_private_method(rb_mExtDTreeCls, "find_split_params", find_split_params_, 6);
55
+ rb_define_private_method(rb_mExtDTreeCls, "node_impurity", node_impurity_, 3);
56
+ rb_define_private_method(rb_mExtDTreeCls, "stop_growing?", check_same_label_, 1);
57
+ }
58
+
59
+ private:
60
+ static double calc_impurity_(const std::string& criterion, const std::vector<size_t>& histogram, const size_t& n_elements, const size_t& n_classes) {
61
+ double impurity = 0.0;
62
+ if (criterion == "entropy") {
63
+ double entropy = 0.0;
64
+ for (size_t i = 0; i < n_classes; i++) {
65
+ const double el = static_cast<double>(histogram[i]) / static_cast<double>(n_elements);
66
+ entropy += el * std::log(el + 1.0);
67
+ }
68
+ impurity = -entropy;
69
+ } else {
70
+ double gini = 0.0;
71
+ for (size_t i = 0; i < n_classes; i++) {
72
+ const double el = static_cast<double>(histogram[i]) / static_cast<double>(n_elements);
73
+ gini += el * el;
74
+ }
75
+ impurity = 1.0 - gini;
76
+ }
77
+ return impurity;
78
+ }
79
+
80
+ /**
81
+ * @!visibility private
82
+ * Find the split point with maximum information gain.
83
+ *
84
+ * @overload find_split_params(criterion, impurity, order, features, labels, n_classes) -> Array<Float>
85
+ *
86
+ * @param criterion [String] The function used to evaluate a splitting point. Supported criteria are 'gini' and 'entropy'.
87
+ * @param impurity [Float] The impurity of whole dataset.
88
+ * @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
89
+ * @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
90
+ * @param labels [Numo::Int32] (shape: [n_elements]) The labels.
91
+ * @param n_classes [Integer] The number of classes.
92
+ * @return [Array<Float>] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
93
+ */
94
+
95
+ struct FindSplitParamsOpts_ {
96
+ std::string criterion;
97
+ size_t n_classes;
98
+ double impurity;
99
+ };
100
+
101
+ static void iter_find_split_params_(na_loop_t const* lp) {
102
+ // Obtain iteration variables.
103
+ const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
104
+ const size_t n_elements = NDL_SHAPE(lp, 0)[0];
105
+ const double* f = (double*)NDL_PTR(lp, 1);
106
+ const int32_t* y = (int32_t*)NDL_PTR(lp, 2);
107
+ const std::string criterion = ((FindSplitParamsOpts_*)lp->opt_ptr)->criterion;
108
+ const size_t n_classes = ((FindSplitParamsOpts_*)lp->opt_ptr)->n_classes;
109
+ const double w_impurity = ((FindSplitParamsOpts_*)lp->opt_ptr)->impurity;
110
+
111
+ // Initialize output optimal parameters.
112
+ double* params = (double*)NDL_PTR(lp, 3);
113
+ params[0] = 0.0; // left impurity
114
+ params[1] = w_impurity; // right impurity
115
+ params[2] = f[o[0]]; // threshold
116
+ params[3] = 0.0; // gain
117
+
118
+ // Initialize child node variables.
119
+ std::vector<size_t> r_histogram(n_classes, 0);
120
+ for (size_t i = 0; i < n_elements; i++) r_histogram[y[o[i]]] += 1;
121
+
122
+ // Find optimal parameters.
123
+ size_t curr_pos = 0;
124
+ size_t next_pos = 0;
125
+ size_t n_l_elements = 0;
126
+ size_t n_r_elements = n_elements;
127
+ double curr_el = f[o[0]];
128
+ const double last_el = f[o[n_elements - 1]];
129
+ std::vector<size_t> l_histogram(n_classes, 0);
130
+ while (curr_pos < n_elements && curr_el != last_el) {
131
+ double next_el = f[o[next_pos]];
132
+ while (next_pos < n_elements && next_el == curr_el) {
133
+ l_histogram[y[o[next_pos]]] += 1;
134
+ n_l_elements++;
135
+ r_histogram[y[o[next_pos]]] -= 1;
136
+ n_r_elements--;
137
+ next_pos++;
138
+ next_el = f[o[next_pos]];
139
+ }
140
+ // Calculate gain of new split.
141
+ const double l_impurity = calc_impurity_(criterion, l_histogram, n_l_elements, n_classes);
142
+ const double r_impurity = calc_impurity_(criterion, r_histogram, n_r_elements, n_classes);
143
+ const double gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / static_cast<double>(n_elements);
144
+ // Update optimal parameters.
145
+ if (gain > params[3]) {
146
+ params[0] = l_impurity;
147
+ params[1] = r_impurity;
148
+ params[2] = 0.5 * (curr_el + next_el);
149
+ params[3] = gain;
150
+ }
151
+ if (next_pos == n_elements) break;
152
+ curr_pos = next_pos;
153
+ curr_el = f[o[curr_pos]];
154
+ }
155
+ }
156
+
157
+ static VALUE find_split_params_(VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE labels, VALUE n_classes) {
158
+ ndfunc_arg_in_t ain[3] = { { numo_cInt32, 1 }, { numo_cDFloat, 1 }, { numo_cInt32, 1 } };
159
+ size_t out_shape[1] = { 4 };
160
+ ndfunc_arg_out_t aout[1] = { { numo_cDFloat, 1, out_shape } };
161
+ ndfunc_t ndf = { (na_iter_func_t)iter_find_split_params_, NO_LOOP, 3, 1, ain, aout };
162
+ FindSplitParamsOpts_ opts = { std::string(StringValueCStr(criterion)), NUM2SIZET(n_classes), NUM2DBL(impurity) };
163
+ VALUE params = na_ndloop3(&ndf, &opts, 3, order, features, labels);
164
+ RB_GC_GUARD(criterion);
165
+ return params;
166
+ }
167
+
168
+ /**
169
+ * @!visibility private
170
+ * Calculate impurity based on criterion.
171
+ *
172
+ * @overload node_impurity(criterion, y, n_classes) -> Float
173
+ *
174
+ * @param criterion [String] The function to calculate impurity. Supported criteria are 'gini' and 'entropy'.
175
+ * @param y [Numo::Int32] (shape: [n_samples]) The labels.
176
+ * @param n_classes [Integer] The number of classes.
177
+ * @return [Float] impurity
178
+ */
179
+
180
+ struct NodeImpurityOpts_ {
181
+ std::string criterion;
182
+ size_t n_classes;
183
+ };
184
+
185
+ static void iter_node_impurity_(na_loop_t const* lp) {
186
+ const int32_t* y = (int32_t*)NDL_PTR(lp, 0);
187
+ const size_t n_elements = NDL_SHAPE(lp, 0)[0];
188
+ const std::string criterion = ((NodeImpurityOpts_*)lp->opt_ptr)->criterion;
189
+ const size_t n_classes = ((NodeImpurityOpts_*)lp->opt_ptr)->n_classes;
190
+ double* ret = (double*)NDL_PTR(lp, 1);
191
+ std::vector<size_t> histogram(n_classes, 0);
192
+ for (size_t i = 0; i < n_elements; i++) histogram[y[i]] += 1;
193
+ *ret = calc_impurity_(criterion, histogram, n_elements, n_classes);
194
+ }
195
+
196
+ static VALUE node_impurity_(VALUE self, VALUE criterion, VALUE y, VALUE n_classes) {
197
+ ndfunc_arg_in_t ain[1] = { { numo_cInt32, 1 } };
198
+ ndfunc_arg_out_t aout[1] = { { numo_cDFloat, 0 } };
199
+ ndfunc_t ndf = { (na_iter_func_t)iter_node_impurity_, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
200
+ NodeImpurityOpts_ opts = { std::string(StringValueCStr(criterion)), NUM2SIZET(n_classes) };
201
+ VALUE ret = na_ndloop3(&ndf, &opts, 1, y);
202
+ RB_GC_GUARD(criterion);
203
+ return ret;
204
+ }
205
+
206
+ /**
207
+ * @!visibility private
208
+ * Check all elements have the same value.
209
+ *
210
+ * @overload stop_growing?(y) -> Boolean
211
+ *
212
+ * @param y [Numo::Int32] (shape: [n_samples]) The labels.
213
+ * @return [Boolean]
214
+ */
215
+
216
+ static void iter_check_same_label_(na_loop_t const* lp) {
217
+ const int32_t* y = (int32_t*)NDL_PTR(lp, 0);
218
+ const size_t n_elements = NDL_SHAPE(lp, 0)[0];
219
+ VALUE* ret = (VALUE*)NDL_PTR(lp, 1);
220
+ *ret = Qtrue;
221
+ if (n_elements > 0) {
222
+ int32_t label = y[0];
223
+ for (size_t i = 0; i < n_elements; i++) {
224
+ if (y[i] != label) {
225
+ *ret = Qfalse;
226
+ break;
227
+ }
228
+ }
229
+ }
230
+ }
231
+
232
+ static VALUE check_same_label_(VALUE self, VALUE y) {
233
+ ndfunc_arg_in_t ain[1] = { { numo_cInt32, 1 } };
234
+ ndfunc_arg_out_t aout[1] = { { numo_cRObject, 0 } };
235
+ ndfunc_t ndf = { (na_iter_func_t)iter_check_same_label_, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
236
+ return na_ndloop(&ndf, 1, y);
237
+ }
238
+ };
239
+
240
+ /**
241
+ * @!visibility private
242
+ * Document-module: Rumale::Tree::ExtDecisionTreeRegressor
243
+ * The mixin module consisting of extension methods for the DecisionTreeRegressor class.
244
+ * This module is used internally.
245
+ */
246
+ class ExtDecisionTreeRegressor {
247
+ public:
248
+ static void define_module(VALUE& outer) {
249
+ VALUE rb_mExtDTreeReg = rb_define_module_under(outer, "ExtDecisionTreeRegressor");
250
+ rb_define_private_method(rb_mExtDTreeReg, "find_split_params", find_split_params_, 5);
251
+ rb_define_private_method(rb_mExtDTreeReg, "node_impurity", node_impurity_, 2);
252
+ rb_define_private_method(rb_mExtDTreeReg, "stop_growing?", check_same_value_, 1);
253
+ }
254
+
255
+ private:
256
+ static double calc_impurity_(const std::string& criterion, const int32_t* order, const double* vecs, const double* mean_vec,
257
+ const size_t& n_elements, const size_t& n_outputs, const size_t& order_offset) {
258
+ const bool is_mae = criterion == "mae";
259
+ double sum_err = 0.0;
260
+ for (size_t i = 0; i < n_elements; i++) {
261
+ double err = 0.0;
262
+ for (size_t j = 0; j < n_outputs; j++) {
263
+ const double el = vecs[order[order_offset + i] * n_outputs + j] - mean_vec[j];
264
+ err += is_mae ? std::fabs(el) : el * el;
265
+ }
266
+ err /= static_cast<double>(n_outputs);
267
+ sum_err += err;
268
+ }
269
+ const double impurity = sum_err / static_cast<double>(n_elements);
270
+ return impurity;
271
+ }
272
+
273
+ /**
274
+ * @!visibility private
275
+ * Find the split point with maximum information gain.
276
+ *
277
+ * @overload find_split_params(criterion, impurity, order, features, targets) -> Array<Float>
278
+ *
279
+ * @param criterion [String] The function used to evaluate a splitting point. Supported criteria are 'mae' and 'mse'.
280
+ * @param impurity [Float] The impurity of whole dataset.
281
+ * @param order [Numo::Int32] (shape: [n_samples]) The element indices sorted according to feature values in ascending order.
282
+ * @param features [Numo::DFloat] (shape: [n_samples]) The feature values.
283
+ * @param targets [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values.
284
+ * @return [Array<Float>] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
285
+ */
286
+
287
+ struct FindSplitParamsOpts_ {
288
+ std::string criterion;
289
+ double impurity;
290
+ };
291
+
292
+ static void iter_find_split_params_(na_loop_t const* lp) {
293
+ // Obtain iteration variables.
294
+ const int32_t* order = (int32_t*)NDL_PTR(lp, 0);
295
+ const size_t n_elements = NDL_SHAPE(lp, 0)[0];
296
+ const double* f = (double*)NDL_PTR(lp, 1);
297
+ const double* y = (double*)NDL_PTR(lp, 2);
298
+ const size_t n_outputs = NDL_SHAPE(lp, 2)[1];
299
+ const std::string criterion = ((FindSplitParamsOpts_*)lp->opt_ptr)->criterion;
300
+ const double w_impurity = ((FindSplitParamsOpts_*)lp->opt_ptr)->impurity;
301
+
302
+ // Initialize optimal parameters.
303
+ double* params = (double*)NDL_PTR(lp, 3);
304
+ params[0] = 0.0; // left impurity
305
+ params[1] = w_impurity; // right impurity
306
+ params[2] = f[order[0]]; // threshold
307
+ params[3] = 0.0; // gain
308
+
309
+ // Initialize child node variables.
310
+ std::vector<double> l_sum_y(n_outputs, 0);
311
+ std::vector<double> r_sum_y(n_outputs, 0);
312
+ for (size_t i = 0; i < n_elements; i++) {
313
+ for (size_t j = 0; j < n_outputs; j++) {
314
+ r_sum_y[j] += y[order[i] * n_outputs + j];
315
+ }
316
+ }
317
+
318
+ // Find optimal parameters.
319
+ size_t curr_pos = 0;
320
+ size_t next_pos = 0;
321
+ size_t n_l_elements = 0;
322
+ size_t n_r_elements = n_elements;
323
+ std::vector<double> l_mean_y(n_outputs, 0);
324
+ std::vector<double> r_mean_y(n_outputs, 0);
325
+ double curr_el = f[order[0]];
326
+ const double last_el = f[order[n_elements - 1]];
327
+ while (curr_pos < n_elements && curr_el != last_el) {
328
+ double next_el = f[order[next_pos]];
329
+ while (next_pos < n_elements && next_el == curr_el) {
330
+ for (size_t j = 0; j < n_outputs; j++) {
331
+ l_sum_y[j] += y[order[next_pos] * n_outputs + j];
332
+ r_sum_y[j] -= y[order[next_pos] * n_outputs + j];
333
+ }
334
+ n_l_elements++;
335
+ n_r_elements--;
336
+ next_pos++;
337
+ next_el = f[order[next_pos]];
338
+ }
339
+ // Calculate gain of new split.
340
+ for (size_t j = 0; j < n_outputs; j++) {
341
+ l_mean_y[j] = l_sum_y[j] / static_cast<double>(n_l_elements);
342
+ r_mean_y[j] = r_sum_y[j] / static_cast<double>(n_r_elements);
343
+ }
344
+ const double l_impurity = calc_impurity_(criterion, order, y, l_mean_y.data(), n_l_elements, n_outputs, 0);
345
+ const double r_impurity = calc_impurity_(criterion, order, y, r_mean_y.data(), n_r_elements, n_outputs, next_pos);
346
+ const double gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / static_cast<double>(n_elements);
347
+ // Update optimal parameters.
348
+ if (gain > params[3]) {
349
+ params[0] = l_impurity;
350
+ params[1] = r_impurity;
351
+ params[2] = 0.5 * (curr_el + next_el);
352
+ params[3] = gain;
353
+ }
354
+ if (next_pos == n_elements) break;
355
+ curr_pos = next_pos;
356
+ curr_el = f[order[curr_pos]];
357
+ }
358
+ }
359
+
360
+ static VALUE find_split_params_(VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE targets) {
361
+ ndfunc_arg_in_t ain[3] = { { numo_cInt32, 1 }, { numo_cDFloat, 1 }, { numo_cDFloat, 2 } };
362
+ size_t out_shape[1] = { 4 };
363
+ ndfunc_arg_out_t aout[1] = { { numo_cDFloat, 1, out_shape } };
364
+ ndfunc_t ndf = { (na_iter_func_t)iter_find_split_params_, NO_LOOP, 3, 1, ain, aout };
365
+ FindSplitParamsOpts_ opts = { std::string(StringValueCStr(criterion)), NUM2DBL(impurity) };
366
+ VALUE params = na_ndloop3(&ndf, &opts, 3, order, features, targets);
367
+ RB_GC_GUARD(criterion);
368
+ return params;
369
+ }
370
+
371
+ /**
372
+ * @!visibility private
373
+ * Calculate impurity based on criterion.
374
+ *
375
+ * @overload node_impurity(criterion, y) -> Float
376
+ *
377
+ * @param criterion [String] The function to calculate impurity. Supported criteria are 'mae' and 'mse'.
378
+ * @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values.
379
+ * @return [Float] impurity
380
+ */
381
+
382
+ struct NodeImpurityOpts_ {
383
+ std::string criterion;
384
+ };
385
+
386
+ static void iter_node_impurity_(na_loop_t const* lp) {
387
+ const double* y = (double*)NDL_PTR(lp, 0);
388
+ const size_t n_elements = NDL_SHAPE(lp, 0)[0];
389
+ const size_t n_outputs = NDL_SHAPE(lp, 0)[1];
390
+ const std::string criterion = ((NodeImpurityOpts_*)lp->opt_ptr)->criterion;
391
+
392
+ std::vector<int32_t> order(n_elements);
393
+ std::vector<double> mean_y(n_outputs, 0);
394
+ for (size_t i = 0; i < n_elements; i++) {
395
+ order[i] = static_cast<int32_t>(i);
396
+ for (size_t j = 0; j < n_outputs; j++) {
397
+ mean_y[j] += y[i * n_outputs + j];
398
+ }
399
+ }
400
+ for (size_t j = 0; j < n_outputs; j++) {
401
+ mean_y[j] /= static_cast<double>(n_elements);
402
+ }
403
+
404
+ double* ret = (double*)NDL_PTR(lp, 1);
405
+ *ret = calc_impurity_(criterion, order.data(), y, mean_y.data(), n_elements, n_outputs, 0);
406
+ }
407
+
408
+ static VALUE node_impurity_(VALUE self, VALUE criterion, VALUE y) {
409
+ ndfunc_arg_in_t ain[1] = { { numo_cDFloat, 2 } };
410
+ ndfunc_arg_out_t aout[1] = { { numo_cDFloat, 0 } };
411
+ ndfunc_t ndf = { (na_iter_func_t)iter_node_impurity_, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
412
+ NodeImpurityOpts_ opts = { std::string(StringValueCStr(criterion)) };
413
+ VALUE ret = na_ndloop3(&ndf, &opts, 1, y);
414
+ RB_GC_GUARD(criterion);
415
+ return ret;
416
+ }
417
+
418
+ /**
419
+ * @!visibility private
420
+ * Check all elements have the same value/vector.
421
+ *
422
+ * @overload stop_growing?(y) -> Boolean
423
+ *
424
+ * @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values.
425
+ * @return [Boolean]
426
+ */
427
+
428
+ static void iter_check_same_value_(na_loop_t const* lp) {
429
+ const double* y = (double*)NDL_PTR(lp, 0);
430
+ const size_t n_elements = NDL_SHAPE(lp, 0)[0];
431
+ const size_t n_outputs = NDL_SHAPE(lp, 0)[1];
432
+ VALUE* ret = (VALUE*)NDL_PTR(lp, 1);
433
+ const double eps = std::numeric_limits<double>::epsilon();
434
+ *ret = Qtrue;
435
+ if (n_elements > 0) {
436
+ for (size_t i = 1; i < n_elements; i++) {
437
+ for (size_t j = 0; j < n_outputs; j++) {
438
+ if (std::abs(y[i * n_outputs + j] - y[j]) > eps) {
439
+ *ret = Qfalse;
440
+ break;
441
+ }
442
+ }
443
+ if (*ret == Qfalse) break;
444
+ }
445
+ }
446
+ }
447
+
448
+ static VALUE check_same_value_(VALUE self, VALUE y) {
449
+ ndfunc_arg_in_t ain[1] = { { numo_cDFloat, 2 } };
450
+ ndfunc_arg_out_t aout[1] = { { numo_cRObject, 0 } };
451
+ ndfunc_t ndf = { (na_iter_func_t)iter_check_same_value_, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
452
+ return na_ndloop(&ndf, 1, y);
453
+ }
454
+ };
455
+
456
+ /**
457
+ * @!visibility private
458
+ * Document-module: Rumale::Tree::ExtGradientTreeRegressor
459
+ * The mixin module consisting of the extension method for the GradientTreeRegressor class.
460
+ * This module is used internally.
461
+ */
462
+ class ExtGradientTreeRegressor {
463
+ public:
464
+ static void define_module(VALUE& outer) {
465
+ VALUE rb_mExtGTreeReg = rb_define_module_under(outer, "ExtGradientTreeRegressor");
466
+ rb_define_private_method(rb_mExtGTreeReg, "find_split_params", find_split_params_, 7);
467
+ }
468
+
469
+ private:
470
+ /**
471
+ * @!visibility private
472
+ * Find the split point with maximum information gain.
473
+ *
474
+ * @overload find_split_params(order, features, gradients, hessians, sum_gradient, sum_hessian, reg_lambda) -> Array<Float>
475
+ * @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
476
+ * @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
477
+ * @param gradients [Numo::DFloat] (shape: [n_elements]) The gradient values.
478
+ * @param hessians [Numo::DFloat] (shape: [n_elements]) The hessian values.
479
+ * @param sum_gradient [Float] The sum of gradient values.
480
+ * @param sum_hessian [Float] The sum of hessian values.
481
+ * @param reg_lambda [Float] The L2 regularization term on weight.
482
+ * @return [Array<Float>] The array consists of optimal parameters including threshold and gain.
483
+ */
484
+
485
+ struct FindSplitParamsOpts_ {
486
+ double sum_gradient;
487
+ double sum_hessian;
488
+ double reg_lambda;
489
+ };
490
+
491
+ static void iter_find_split_params_(na_loop_t const* lp) {
492
+ // Obtain iteration variables.
493
+ const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
494
+ const size_t n_elements = NDL_SHAPE(lp, 0)[0];
495
+ const double* f = (double*)NDL_PTR(lp, 1);
496
+ const double* g = (double*)NDL_PTR(lp, 2);
497
+ const double* h = (double*)NDL_PTR(lp, 3);
498
+ const double s_grad = ((FindSplitParamsOpts_*)lp->opt_ptr)->sum_gradient;
499
+ const double s_hess = ((FindSplitParamsOpts_*)lp->opt_ptr)->sum_hessian;
500
+ const double reg_lambda = ((FindSplitParamsOpts_*)lp->opt_ptr)->reg_lambda;
501
+
502
+ // Find optimal parameters.
503
+ size_t curr_pos = 0;
504
+ size_t next_pos = 0;
505
+ double curr_el = f[o[0]];
506
+ const double last_el = f[o[n_elements - 1]];
507
+ double l_grad = 0.0;
508
+ double l_hess = 0.0;
509
+ double threshold = curr_el;
510
+ double gain_max = 0.0;
511
+ while (curr_pos < n_elements && curr_el != last_el) {
512
+ double next_el = f[o[next_pos]];
513
+ while (next_pos < n_elements && next_el == curr_el) {
514
+ l_grad += g[o[next_pos]];
515
+ l_hess += h[o[next_pos]];
516
+ next_pos++;
517
+ next_el = f[o[next_pos]];
518
+ }
519
+ // Calculate gain of new split.
520
+ const double r_grad = s_grad - l_grad;
521
+ const double r_hess = s_hess - l_hess;
522
+ const double gain = (l_grad * l_grad) / (l_hess + reg_lambda) + (r_grad * r_grad) / (r_hess + reg_lambda) - (s_grad * s_grad) / (s_hess + reg_lambda);
523
+ // Update optimal parameters.
524
+ if (gain > gain_max) {
525
+ threshold = 0.5 * (curr_el + next_el);
526
+ gain_max = gain;
527
+ }
528
+ if (next_pos == n_elements) break;
529
+ curr_pos = next_pos;
530
+ curr_el = f[o[curr_pos]];
531
+ }
532
+
533
+ double* params = (double*)NDL_PTR(lp, 4);
534
+ params[0] = threshold;
535
+ params[1] = gain_max;
536
+ }
537
+
538
+ static VALUE find_split_params_(VALUE self, VALUE order, VALUE features, VALUE gradients, VALUE hessians,
539
+ VALUE sum_gradient, VALUE sum_hessian, VALUE reg_lambda) {
540
+ ndfunc_arg_in_t ain[4] = { { numo_cInt32, 1 }, { numo_cDFloat, 1 }, { numo_cDFloat, 1 }, { numo_cDFloat, 1 } };
541
+ size_t out_shape[1] = { 2 };
542
+ ndfunc_arg_out_t aout[1] = { { numo_cDFloat, 1, out_shape } };
543
+ ndfunc_t ndf = { (na_iter_func_t)iter_find_split_params_, NO_LOOP, 4, 1, ain, aout };
544
+ FindSplitParamsOpts_ opts = { NUM2DBL(sum_gradient), NUM2DBL(sum_hessian), NUM2DBL(reg_lambda) };
545
+ VALUE params = na_ndloop3(&ndf, &opts, 4, order, features, gradients, hessians);
546
+ return params;
547
+ }
548
+ };
549
+
550
+ #endif /* RUMALE_TREE_EXT_HPP */
@@ -3,6 +3,8 @@
3
3
  require 'mkmf'
4
4
  require 'numo/narray'
5
5
 
6
+ abort 'libstdc++ is not found.' unless have_library('stdc++')
7
+
6
8
  $LOAD_PATH.each do |lp|
7
9
  if File.exist?(File.join(lp, 'numo/numo/narray.h'))
8
10
  $INCFLAGS = "-I#{lp}/numo #{$INCFLAGS}"
@@ -63,11 +63,7 @@ module Rumale
63
63
  end
64
64
 
65
65
  def build_tree(x, y)
66
- y = y.expand_dims(1).dup if y.shape[1].nil?
67
- @feature_ids = Array.new(x.shape[1]) { |v| v }
68
- @tree = grow_node(0, x, y, impurity(y))
69
- @feature_ids = nil
70
- nil
66
+ raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
71
67
  end
72
68
 
73
69
  def grow_node(depth, x, y, impurity) # rubocop:disable Metrics/AbcSize, Metrics/PerceivedComplexity
@@ -68,6 +68,7 @@ module Rumale
68
68
  @params[:max_features] = [@params[:max_features], n_features].min
69
69
  @n_leaves = 0
70
70
  @leaf_values = []
71
+ @feature_ids = Array.new(x.shape[1]) { |v| v }
71
72
  @sub_rng = @rng.dup
72
73
  build_tree(x, y)
73
74
  eval_importance(n_samples, n_features)
@@ -88,8 +89,10 @@ module Rumale
88
89
 
89
90
  private
90
91
 
91
- def stop_growing?(y)
92
- y.to_a.uniq.size == 1
92
+ def build_tree(x, y)
93
+ y = y.expand_dims(1).dup if y.shape[1].nil?
94
+ @tree = grow_node(0, x, y, impurity(y))
95
+ nil
93
96
  end
94
97
 
95
98
  def put_leaf(node, y)
@@ -106,7 +109,7 @@ module Rumale
106
109
  end
107
110
 
108
111
  def impurity(y)
109
- node_impurity(@params[:criterion], y.to_a)
112
+ node_impurity(@params[:criterion], y)
110
113
  end
111
114
  end
112
115
  end
@@ -5,6 +5,6 @@ module Rumale
5
5
  # This module consists of the classes that implement tree models.
6
6
  module Tree
7
7
  # @!visibility private
8
- VERSION = '0.26.0'
8
+ VERSION = '0.28.0'
9
9
  end
10
10
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rumale-tree
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.26.0
4
+ version: 0.28.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-02-19 00:00:00.000000000 Z
11
+ date: 2023-11-12 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -30,14 +30,14 @@ dependencies:
30
30
  requirements:
31
31
  - - "~>"
32
32
  - !ruby/object:Gem::Version
33
- version: 0.26.0
33
+ version: 0.28.0
34
34
  type: :runtime
35
35
  prerelease: false
36
36
  version_requirements: !ruby/object:Gem::Requirement
37
37
  requirements:
38
38
  - - "~>"
39
39
  - !ruby/object:Gem::Version
40
- version: 0.26.0
40
+ version: 0.28.0
41
41
  description: Rumale::Tree provides classifier and regression based on decision tree
42
42
  algorithms with Rumale interface.
43
43
  email:
@@ -49,8 +49,8 @@ extra_rdoc_files: []
49
49
  files:
50
50
  - LICENSE.txt
51
51
  - README.md
52
- - ext/rumale/tree/ext.c
53
- - ext/rumale/tree/ext.h
52
+ - ext/rumale/tree/ext.cpp
53
+ - ext/rumale/tree/ext.hpp
54
54
  - ext/rumale/tree/extconf.rb
55
55
  - lib/rumale/tree.rb
56
56
  - lib/rumale/tree/base_decision_tree.rb
@@ -85,7 +85,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
85
85
  - !ruby/object:Gem::Version
86
86
  version: '0'
87
87
  requirements: []
88
- rubygems_version: 3.3.26
88
+ rubygems_version: 3.4.20
89
89
  signing_key:
90
90
  specification_version: 4
91
91
  summary: Rumale::Tree provides classifier and regression based on decision tree algorithms
@@ -1,575 +0,0 @@
1
- #include "ext.h"
2
-
3
- double* alloc_dbl_array(const long n_dimensions) {
4
- double* arr = ALLOC_N(double, n_dimensions);
5
- memset(arr, 0, n_dimensions * sizeof(double));
6
- return arr;
7
- }
8
-
9
- double calc_gini_coef(double* histogram, const long n_elements, const long n_classes) {
10
- double gini = 0.0;
11
-
12
- for (long i = 0; i < n_classes; i++) {
13
- double el = histogram[i] / n_elements;
14
- gini += el * el;
15
- }
16
-
17
- return 1.0 - gini;
18
- }
19
-
20
- double calc_entropy(double* histogram, const long n_elements, const long n_classes) {
21
- double entropy = 0.0;
22
-
23
- for (long i = 0; i < n_classes; i++) {
24
- double el = histogram[i] / n_elements;
25
- entropy += el * log(el + 1.0);
26
- }
27
-
28
- return -entropy;
29
- }
30
-
31
- VALUE
32
- calc_mean_vec(double* sum_vec, const long n_dimensions, const long n_elements) {
33
- VALUE mean_vec = rb_ary_new2(n_dimensions);
34
-
35
- for (long i = 0; i < n_dimensions; i++) {
36
- rb_ary_store(mean_vec, i, DBL2NUM(sum_vec[i] / n_elements));
37
- }
38
-
39
- return mean_vec;
40
- }
41
-
42
- double calc_vec_mae(VALUE vec_a, VALUE vec_b) {
43
- const long n_dimensions = RARRAY_LEN(vec_a);
44
- double sum = 0.0;
45
-
46
- for (long i = 0; i < n_dimensions; i++) {
47
- double diff = NUM2DBL(rb_ary_entry(vec_a, i)) - NUM2DBL(rb_ary_entry(vec_b, i));
48
- sum += fabs(diff);
49
- }
50
-
51
- return sum / n_dimensions;
52
- }
53
-
54
- double calc_vec_mse(VALUE vec_a, VALUE vec_b) {
55
- const long n_dimensions = RARRAY_LEN(vec_a);
56
- double sum = 0.0;
57
-
58
- for (long i = 0; i < n_dimensions; i++) {
59
- double diff = NUM2DBL(rb_ary_entry(vec_a, i)) - NUM2DBL(rb_ary_entry(vec_b, i));
60
- sum += diff * diff;
61
- }
62
-
63
- return sum / n_dimensions;
64
- }
65
-
66
- double calc_mae(VALUE target_vecs, VALUE mean_vec) {
67
- const long n_elements = RARRAY_LEN(target_vecs);
68
- double sum = 0.0;
69
-
70
- for (long i = 0; i < n_elements; i++) {
71
- sum += calc_vec_mae(rb_ary_entry(target_vecs, i), mean_vec);
72
- }
73
-
74
- return sum / n_elements;
75
- }
76
-
77
- double calc_mse(VALUE target_vecs, VALUE mean_vec) {
78
- const long n_elements = RARRAY_LEN(target_vecs);
79
- double sum = 0.0;
80
-
81
- for (long i = 0; i < n_elements; i++) {
82
- sum += calc_vec_mse(rb_ary_entry(target_vecs, i), mean_vec);
83
- }
84
-
85
- return sum / n_elements;
86
- }
87
-
88
- double calc_impurity_cls(const char* criterion, double* histogram, const long n_elements, const long n_classes) {
89
- if (strcmp(criterion, "entropy") == 0) {
90
- return calc_entropy(histogram, n_elements, n_classes);
91
- }
92
- return calc_gini_coef(histogram, n_elements, n_classes);
93
- }
94
-
95
- double calc_impurity_reg(const char* criterion, VALUE target_vecs, double* sum_vec) {
96
- const long n_elements = RARRAY_LEN(target_vecs);
97
- const long n_dimensions = RARRAY_LEN(rb_ary_entry(target_vecs, 0));
98
- VALUE mean_vec = calc_mean_vec(sum_vec, n_dimensions, n_elements);
99
-
100
- if (strcmp(criterion, "mae") == 0) {
101
- return calc_mae(target_vecs, mean_vec);
102
- }
103
- return calc_mse(target_vecs, mean_vec);
104
- }
105
-
106
- void add_sum_vec(double* sum_vec, VALUE target) {
107
- const long n_dimensions = RARRAY_LEN(target);
108
-
109
- for (long i = 0; i < n_dimensions; i++) {
110
- sum_vec[i] += NUM2DBL(rb_ary_entry(target, i));
111
- }
112
- }
113
-
114
- void sub_sum_vec(double* sum_vec, VALUE target) {
115
- const long n_dimensions = RARRAY_LEN(target);
116
-
117
- for (long i = 0; i < n_dimensions; i++) {
118
- sum_vec[i] -= NUM2DBL(rb_ary_entry(target, i));
119
- }
120
- }
121
-
122
- /**
123
- * @!visibility private
124
- */
125
- typedef struct {
126
- char* criterion;
127
- long n_classes;
128
- double impurity;
129
- } split_opts_cls;
130
-
131
- /**
132
- * @!visibility private
133
- */
134
- static void iter_find_split_params_cls(na_loop_t const* lp) {
135
- const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
136
- const double* f = (double*)NDL_PTR(lp, 1);
137
- const int32_t* y = (int32_t*)NDL_PTR(lp, 2);
138
- const long n_elements = NDL_SHAPE(lp, 0)[0];
139
- const char* criterion = ((split_opts_cls*)lp->opt_ptr)->criterion;
140
- const long n_classes = ((split_opts_cls*)lp->opt_ptr)->n_classes;
141
- const double w_impurity = ((split_opts_cls*)lp->opt_ptr)->impurity;
142
- double* params = (double*)NDL_PTR(lp, 3);
143
- long curr_pos = 0;
144
- long next_pos = 0;
145
- long n_l_elements = 0;
146
- long n_r_elements = n_elements;
147
- double curr_el = f[o[0]];
148
- double last_el = f[o[n_elements - 1]];
149
- double next_el;
150
- double l_impurity;
151
- double r_impurity;
152
- double gain;
153
- double* l_histogram = alloc_dbl_array(n_classes);
154
- double* r_histogram = alloc_dbl_array(n_classes);
155
-
156
- /* Initialize optimal parameters. */
157
- params[0] = 0.0; /* left impurity */
158
- params[1] = w_impurity; /* right impurity */
159
- params[2] = curr_el; /* threshold */
160
- params[3] = 0.0; /* gain */
161
-
162
- /* Initialize child node variables. */
163
- for (long i = 0; i < n_elements; i++) {
164
- r_histogram[y[o[i]]] += 1.0;
165
- }
166
-
167
- /* Find optimal parameters. */
168
- while (curr_pos < n_elements && curr_el != last_el) {
169
- next_el = f[o[next_pos]];
170
- while (next_pos < n_elements && next_el == curr_el) {
171
- l_histogram[y[o[next_pos]]] += 1;
172
- n_l_elements++;
173
- r_histogram[y[o[next_pos]]] -= 1;
174
- n_r_elements--;
175
- next_pos++;
176
- next_el = f[o[next_pos]];
177
- }
178
- /* Calculate gain of new split. */
179
- l_impurity = calc_impurity_cls(criterion, l_histogram, n_l_elements, n_classes);
180
- r_impurity = calc_impurity_cls(criterion, r_histogram, n_r_elements, n_classes);
181
- gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / n_elements;
182
- /* Update optimal parameters. */
183
- if (gain > params[3]) {
184
- params[0] = l_impurity;
185
- params[1] = r_impurity;
186
- params[2] = 0.5 * (curr_el + next_el);
187
- params[3] = gain;
188
- }
189
- if (next_pos == n_elements) break;
190
- curr_pos = next_pos;
191
- curr_el = f[o[curr_pos]];
192
- }
193
-
194
- xfree(l_histogram);
195
- xfree(r_histogram);
196
- }
197
-
198
- /**
199
- * @!visibility private
200
- * Find for split point with maximum information gain.
201
- *
202
- * @overload find_split_params(criterion, impurity, order, features, labels, n_classes) -> Array<Float>
203
- *
204
- * @param criterion [String] The function to evaluate spliting point. Supported criteria are 'gini' and 'entropy'.
205
- * @param impurity [Float] The impurity of whole dataset.
206
- * @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
207
- * @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
208
- * @param labels [Numo::Int32] (shape: [n_elements]) The labels.
209
- * @param n_classes [Integer] The number of classes.
210
- * @return [Array<Float>] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
211
- */
212
- static VALUE find_split_params_cls(VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE labels,
213
- VALUE n_classes) {
214
- ndfunc_arg_in_t ain[3] = {{numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cInt32, 1}};
215
- size_t out_shape[1] = {4};
216
- ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 1, out_shape}};
217
- ndfunc_t ndf = {(na_iter_func_t)iter_find_split_params_cls, NO_LOOP, 3, 1, ain, aout};
218
- split_opts_cls opts = {StringValuePtr(criterion), NUM2LONG(n_classes), NUM2DBL(impurity)};
219
- VALUE params = na_ndloop3(&ndf, &opts, 3, order, features, labels);
220
- VALUE results = rb_ary_new2(4);
221
- double* params_ptr = (double*)na_get_pointer_for_read(params);
222
- rb_ary_store(results, 0, DBL2NUM(params_ptr[0]));
223
- rb_ary_store(results, 1, DBL2NUM(params_ptr[1]));
224
- rb_ary_store(results, 2, DBL2NUM(params_ptr[2]));
225
- rb_ary_store(results, 3, DBL2NUM(params_ptr[3]));
226
- RB_GC_GUARD(params);
227
- RB_GC_GUARD(criterion);
228
- return results;
229
- }
230
-
231
- /**
232
- * @!visibility private
233
- */
234
- typedef struct {
235
- char* criterion;
236
- double impurity;
237
- } split_opts_reg;
238
-
239
- /**
240
- * @!visibility private
241
- */
242
- static void iter_find_split_params_reg(na_loop_t const* lp) {
243
- const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
244
- const double* f = (double*)NDL_PTR(lp, 1);
245
- const double* y = (double*)NDL_PTR(lp, 2);
246
- const long n_elements = NDL_SHAPE(lp, 0)[0];
247
- const long n_outputs = NDL_SHAPE(lp, 2)[1];
248
- const char* criterion = ((split_opts_reg*)lp->opt_ptr)->criterion;
249
- const double w_impurity = ((split_opts_reg*)lp->opt_ptr)->impurity;
250
- double* params = (double*)NDL_PTR(lp, 3);
251
- long curr_pos = 0;
252
- long next_pos = 0;
253
- long n_l_elements = 0;
254
- long n_r_elements = n_elements;
255
- double curr_el = f[o[0]];
256
- double last_el = f[o[n_elements - 1]];
257
- double next_el;
258
- double l_impurity;
259
- double r_impurity;
260
- double gain;
261
- double* l_sum_vec = alloc_dbl_array(n_outputs);
262
- double* r_sum_vec = alloc_dbl_array(n_outputs);
263
- double target_var;
264
- VALUE l_target_vecs = rb_ary_new();
265
- VALUE r_target_vecs = rb_ary_new();
266
- VALUE target;
267
-
268
- /* Initialize optimal parameters. */
269
- params[0] = 0.0; /* left impurity */
270
- params[1] = w_impurity; /* right impurity */
271
- params[2] = curr_el; /* threshold */
272
- params[3] = 0.0; /* gain */
273
-
274
- /* Initialize child node variables. */
275
- for (long i = 0; i < n_elements; i++) {
276
- target = rb_ary_new2(n_outputs);
277
- for (long j = 0; j < n_outputs; j++) {
278
- target_var = y[o[i] * n_outputs + j];
279
- rb_ary_store(target, j, DBL2NUM(target_var));
280
- r_sum_vec[j] += target_var;
281
- }
282
- rb_ary_push(r_target_vecs, target);
283
- }
284
-
285
- /* Find optimal parameters. */
286
- while (curr_pos < n_elements && curr_el != last_el) {
287
- next_el = f[o[next_pos]];
288
- while (next_pos < n_elements && next_el == curr_el) {
289
- target = rb_ary_shift(r_target_vecs);
290
- n_r_elements--;
291
- sub_sum_vec(r_sum_vec, target);
292
- rb_ary_push(l_target_vecs, target);
293
- n_l_elements++;
294
- add_sum_vec(l_sum_vec, target);
295
- next_pos++;
296
- next_el = f[o[next_pos]];
297
- }
298
- /* Calculate gain of new split. */
299
- l_impurity = calc_impurity_reg(criterion, l_target_vecs, l_sum_vec);
300
- r_impurity = calc_impurity_reg(criterion, r_target_vecs, r_sum_vec);
301
- gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / n_elements;
302
- /* Update optimal parameters. */
303
- if (gain > params[3]) {
304
- params[0] = l_impurity;
305
- params[1] = r_impurity;
306
- params[2] = 0.5 * (curr_el + next_el);
307
- params[3] = gain;
308
- }
309
- if (next_pos == n_elements) break;
310
- curr_pos = next_pos;
311
- curr_el = f[o[curr_pos]];
312
- }
313
-
314
- xfree(l_sum_vec);
315
- xfree(r_sum_vec);
316
- }
317
-
318
- /**
319
- * @!visibility private
320
- * Find for split point with maximum information gain.
321
- *
322
- * @overload find_split_params(criterion, impurity, order, features, targets) -> Array<Float>
323
- *
324
- * @param criterion [String] The function to evaluate spliting point. Supported criteria are 'mae' and 'mse'.
325
- * @param impurity [Float] The impurity of whole dataset.
326
- * @param order [Numo::Int32] (shape: [n_samples]) The element indices sorted according to feature values in ascending order.
327
- * @param features [Numo::DFloat] (shape: [n_samples]) The feature values.
328
- * @param targets [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values.
329
- * @return [Array<Float>] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
330
- */
331
- static VALUE find_split_params_reg(VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE targets) {
332
- ndfunc_arg_in_t ain[3] = {{numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 2}};
333
- size_t out_shape[1] = {4};
334
- ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 1, out_shape}};
335
- ndfunc_t ndf = {(na_iter_func_t)iter_find_split_params_reg, NO_LOOP, 3, 1, ain, aout};
336
- split_opts_reg opts = {StringValuePtr(criterion), NUM2DBL(impurity)};
337
- VALUE params = na_ndloop3(&ndf, &opts, 3, order, features, targets);
338
- VALUE results = rb_ary_new2(4);
339
- double* params_ptr = (double*)na_get_pointer_for_read(params);
340
- rb_ary_store(results, 0, DBL2NUM(params_ptr[0]));
341
- rb_ary_store(results, 1, DBL2NUM(params_ptr[1]));
342
- rb_ary_store(results, 2, DBL2NUM(params_ptr[2]));
343
- rb_ary_store(results, 3, DBL2NUM(params_ptr[3]));
344
- RB_GC_GUARD(params);
345
- RB_GC_GUARD(criterion);
346
- return results;
347
- }
348
-
349
- /**
350
- * @!visibility private
351
- */
352
- static void iter_find_split_params_grad_reg(na_loop_t const* lp) {
353
- const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
354
- const double* f = (double*)NDL_PTR(lp, 1);
355
- const double* g = (double*)NDL_PTR(lp, 2);
356
- const double* h = (double*)NDL_PTR(lp, 3);
357
- const double s_grad = ((double*)lp->opt_ptr)[0];
358
- const double s_hess = ((double*)lp->opt_ptr)[1];
359
- const double reg_lambda = ((double*)lp->opt_ptr)[2];
360
- const long n_elements = NDL_SHAPE(lp, 0)[0];
361
- double* params = (double*)NDL_PTR(lp, 4);
362
- long curr_pos = 0;
363
- long next_pos = 0;
364
- double curr_el = f[o[0]];
365
- double last_el = f[o[n_elements - 1]];
366
- double next_el;
367
- double l_grad = 0.0;
368
- double l_hess = 0.0;
369
- double r_grad;
370
- double r_hess;
371
- double threshold = curr_el;
372
- double gain_max = 0.0;
373
- double gain;
374
-
375
- /* Find optimal parameters. */
376
- while (curr_pos < n_elements && curr_el != last_el) {
377
- next_el = f[o[next_pos]];
378
- while (next_pos < n_elements && next_el == curr_el) {
379
- l_grad += g[o[next_pos]];
380
- l_hess += h[o[next_pos]];
381
- next_pos++;
382
- next_el = f[o[next_pos]];
383
- }
384
- /* Calculate gain of new split. */
385
- r_grad = s_grad - l_grad;
386
- r_hess = s_hess - l_hess;
387
- gain = (l_grad * l_grad) / (l_hess + reg_lambda) + (r_grad * r_grad) / (r_hess + reg_lambda) -
388
- (s_grad * s_grad) / (s_hess + reg_lambda);
389
- /* Update optimal parameters. */
390
- if (gain > gain_max) {
391
- threshold = 0.5 * (curr_el + next_el);
392
- gain_max = gain;
393
- }
394
- if (next_pos == n_elements) {
395
- break;
396
- }
397
- curr_pos = next_pos;
398
- curr_el = f[o[curr_pos]];
399
- }
400
-
401
- params[0] = threshold;
402
- params[1] = gain_max;
403
- }
404
-
405
- /**
406
- * @!visibility private
407
- * Find for split point with maximum information gain.
408
- *
409
- * @overload find_split_params(order, features, gradients, hessians, sum_gradient, sum_hessian, reg_lambda) -> Array<Float>
410
- * @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
411
- * @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
412
- * @param gradients [Numo::DFloat] (shape: [n_elements]) The gradient values.
413
- * @param hessians [Numo::DFloat] (shape: [n_elements]) The hessian values.
414
- * @param sum_gradient [Float] The sum of gradient values.
415
- * @param sum_hessian [Float] The sum of hessian values.
416
- * @param reg_lambda [Float] The L2 regularization term on weight.
417
- * @return [Array<Float>] The array consists of optimal parameters including threshold and gain.
418
- */
419
- static VALUE find_split_params_grad_reg(VALUE self, VALUE order, VALUE features, VALUE gradients, VALUE hessians,
420
- VALUE sum_gradient, VALUE sum_hessian, VALUE reg_lambda) {
421
- ndfunc_arg_in_t ain[4] = {{numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 1}};
422
- size_t out_shape[1] = {2};
423
- ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 1, out_shape}};
424
- ndfunc_t ndf = {(na_iter_func_t)iter_find_split_params_grad_reg, NO_LOOP, 4, 1, ain, aout};
425
- double opts[3] = {NUM2DBL(sum_gradient), NUM2DBL(sum_hessian), NUM2DBL(reg_lambda)};
426
- VALUE params = na_ndloop3(&ndf, opts, 4, order, features, gradients, hessians);
427
- VALUE results = rb_ary_new2(2);
428
- double* params_ptr = (double*)na_get_pointer_for_read(params);
429
- rb_ary_store(results, 0, DBL2NUM(params_ptr[0]));
430
- rb_ary_store(results, 1, DBL2NUM(params_ptr[1]));
431
- RB_GC_GUARD(params);
432
- return results;
433
- }
434
-
435
- /**
436
- * @!visibility private
437
- */
438
- typedef struct {
439
- char* criterion;
440
- long n_classes;
441
- } node_impurity_cls_opts;
442
-
443
- /**
444
- * @!visibility private
445
- */
446
- static void iter_node_impurity_cls(na_loop_t const* lp) {
447
- const int32_t* y = (int32_t*)NDL_PTR(lp, 0);
448
- const char* criterion = ((node_impurity_cls_opts*)lp->opt_ptr)->criterion;
449
- const long n_classes = ((node_impurity_cls_opts*)lp->opt_ptr)->n_classes;
450
- const long n_elements = NDL_SHAPE(lp, 0)[0];
451
- double* ret = (double*)NDL_PTR(lp, 1);
452
- double* histogram = alloc_dbl_array(n_classes);
453
- for (long i = 0; i < n_elements; i++) histogram[y[i]] += 1;
454
- *ret = calc_impurity_cls(criterion, histogram, n_elements, n_classes);
455
- xfree(histogram);
456
- }
457
-
458
- /**
459
- * @!visibility private
460
- * Calculate impurity based on criterion.
461
- *
462
- * @overload node_impurity(criterion, y, n_classes) -> Float
463
- *
464
- * @param criterion [String] The function to calculate impurity. Supported criteria are 'gini' and 'entropy'.
465
- * @param y [Numo::Int32] (shape: [n_samples]) The labels.
466
- * @param n_classes [Integer] The number of classes.
467
- * @return [Float] impurity
468
- */
469
- static VALUE node_impurity_cls(VALUE self, VALUE criterion, VALUE y, VALUE n_classes) {
470
- ndfunc_arg_in_t ain[1] = {{numo_cInt32, 1}};
471
- ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 0}};
472
- ndfunc_t ndf = {(na_iter_func_t)iter_node_impurity_cls, NDF_EXTRACT, 1, 1, ain, aout};
473
- node_impurity_cls_opts opts = {StringValuePtr(criterion), NUM2LONG(n_classes)};
474
- VALUE ret = na_ndloop3(&ndf, &opts, 1, y);
475
- RB_GC_GUARD(criterion);
476
- return ret;
477
- }
478
-
479
- /**
480
- * @!visibility private
481
- */
482
- static void iter_check_same_label(na_loop_t const* lp) {
483
- const int32_t* y = (int32_t*)NDL_PTR(lp, 0);
484
- const long n_elements = NDL_SHAPE(lp, 0)[0];
485
- int32_t* ret = (int32_t*)NDL_PTR(lp, 1);
486
- *ret = 1;
487
- if (n_elements > 0) {
488
- int32_t label = y[0];
489
- for (long i = 0; i < n_elements; i++) {
490
- if (y[i] != label) {
491
- *ret = 0;
492
- break;
493
- }
494
- }
495
- }
496
- }
497
-
498
- /**
499
- * @!visibility private
500
- * Check all elements have the save value.
501
- *
502
- * @overload check_same_label(y) -> Boolean
503
- *
504
- * @param y [Numo::Int32] (shape: [n_samples]) The labels.
505
- * @return [Boolean]
506
- */
507
- static VALUE check_same_label(VALUE self, VALUE y) {
508
- ndfunc_arg_in_t ain[1] = {{numo_cInt32, 1}};
509
- ndfunc_arg_out_t aout[1] = {{numo_cInt32, 0}};
510
- ndfunc_t ndf = {(na_iter_func_t)iter_check_same_label, NDF_EXTRACT, 1, 1, ain, aout};
511
- VALUE ret = na_ndloop(&ndf, 1, y);
512
- return (NUM2INT(ret) == 1 ? Qtrue : Qfalse);
513
- }
514
-
515
- /**
516
- * @!visibility private
517
- * Calculate impurity based on criterion.
518
- *
519
- * @overload node_impurity(criterion, y) -> Float
520
- *
521
- * @param criterion [String] The function to calculate impurity. Supported criteria are 'mae' and 'mse'.
522
- * @param y [Array<Float>] (shape: [n_samples, n_outputs]) The taget values.
523
- * @return [Float] impurity
524
- */
525
- static VALUE node_impurity_reg(VALUE self, VALUE criterion, VALUE y) {
526
- const long n_elements = RARRAY_LEN(y);
527
- const long n_outputs = RARRAY_LEN(rb_ary_entry(y, 0));
528
- double* sum_vec = alloc_dbl_array(n_outputs);
529
- VALUE target_vecs = rb_ary_new();
530
-
531
- for (long i = 0; i < n_elements; i++) {
532
- VALUE target = rb_ary_entry(y, i);
533
- add_sum_vec(sum_vec, target);
534
- rb_ary_push(target_vecs, target);
535
- }
536
-
537
- VALUE ret = DBL2NUM(calc_impurity_reg(StringValuePtr(criterion), target_vecs, sum_vec));
538
- xfree(sum_vec);
539
- RB_GC_GUARD(criterion);
540
- return ret;
541
- }
542
-
543
- void Init_ext(void) {
544
- VALUE rb_mRumale = rb_define_module("Rumale");
545
- VALUE rb_mTree = rb_define_module_under(rb_mRumale, "Tree");
546
-
547
- /**
548
- * Document-module: Rumale::Tree::ExtDecisionTreeClassifier
549
- * @!visibility private
550
- * The mixin module consisting of extension method for DecisionTreeClassifier class.
551
- * This module is used internally.
552
- */
553
- VALUE rb_mExtDTreeCls = rb_define_module_under(rb_mTree, "ExtDecisionTreeClassifier");
554
- /**
555
- * Document-module: Rumale::Tree::ExtDecisionTreeRegressor
556
- * @!visibility private
557
- * The mixin module consisting of extension method for DecisionTreeRegressor class.
558
- * This module is used internally.
559
- */
560
- VALUE rb_mExtDTreeReg = rb_define_module_under(rb_mTree, "ExtDecisionTreeRegressor");
561
- /**
562
- * Document-module: Rumale::Tree::ExtGradientTreeRegressor
563
- * @!visibility private
564
- * The mixin module consisting of extension method for GradientTreeRegressor class.
565
- * This module is used internally.
566
- */
567
- VALUE rb_mExtGTreeReg = rb_define_module_under(rb_mTree, "ExtGradientTreeRegressor");
568
-
569
- rb_define_private_method(rb_mExtDTreeCls, "find_split_params", find_split_params_cls, 6);
570
- rb_define_private_method(rb_mExtDTreeReg, "find_split_params", find_split_params_reg, 5);
571
- rb_define_private_method(rb_mExtGTreeReg, "find_split_params", find_split_params_grad_reg, 7);
572
- rb_define_private_method(rb_mExtDTreeCls, "node_impurity", node_impurity_cls, 3);
573
- rb_define_private_method(rb_mExtDTreeCls, "stop_growing?", check_same_label, 1);
574
- rb_define_private_method(rb_mExtDTreeReg, "node_impurity", node_impurity_reg, 2);
575
- }
@@ -1,12 +0,0 @@
1
- #ifndef RUMALE_TREE_EXT_H
2
- #define RUMALE_TREE_EXT_H 1
3
-
4
- #include <math.h>
5
- #include <string.h>
6
-
7
- #include <ruby.h>
8
-
9
- #include <numo/narray.h>
10
- #include <numo/template.h>
11
-
12
- #endif /* RUMALE_TREE_EXT_H */