rumale-tree 0.27.0 → 0.28.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: a71957b21116e3224657fe98ca031bb49eac076f87f5a57fb16cd2ad5ecc2744
4
- data.tar.gz: a514b1afe06b53d8ad31215be9a26b97ac895f5c081de238a873aeaac48516cd
3
+ metadata.gz: a437c3b879dd51d2e823f1851ca8b350d0441947fb6cefe727de83db78b4e6d9
4
+ data.tar.gz: 519665f2baea649ec31c9d5529aa05c67a880daa8ab60149433a53483f630663
5
5
  SHA512:
6
- metadata.gz: 824e108abe1fdc1c7c15f4c9298fe6f9ee6189e37374049a4b5a7c9154021e89de8ea7bb87759fdb0d45867ea678ccc5f8c4911298dc103ee7829f8cdfd91062
7
- data.tar.gz: ab7a7fce4a106979a150605ebc383cbe66bc7733064e9158218629f3ef6d405908dd4d195c0ed423b09d9c55071a0b19373e9582edc58d4f8b323b02df317fa0
6
+ metadata.gz: aa68d3e2fbe99ef4c8d15a10f64098abfbbcdce9047053967e578edad688fa08a03a1ac4e4d23dbd5e67a4a23370057aa9462107bb961eb4d3f12782ce0c0f18
7
+ data.tar.gz: 6a6e6f167e31d56051768a881a4db0da7ef02f36977cb3969e8160bad385757153b1f6be32dd364ed3f67ed4672dd0c1bfd14cc80e0b033e00dd282d757955a9
@@ -0,0 +1,39 @@
1
+ /**
2
+ * Copyright (c) 2022-2023 Atsushi Tatsuma
3
+ * All rights reserved.
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * * Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * * Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * * Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ */
30
+
31
+ #include "ext.hpp"
32
+
33
+ extern "C" void Init_ext(void) {
34
+ VALUE rb_mRumale = rb_define_module("Rumale");
35
+ VALUE rb_mTree = rb_define_module_under(rb_mRumale, "Tree");
36
+ ExtDecisionTreeClassifier::define_module(rb_mTree);
37
+ ExtDecisionTreeRegressor::define_module(rb_mTree);
38
+ ExtGradientTreeRegressor::define_module(rb_mTree);
39
+ }
@@ -0,0 +1,550 @@
1
+ /**
2
+ * Copyright (c) 2022-2023 Atsushi Tatsuma
3
+ * All rights reserved.
4
+ *
5
+ * Redistribution and use in source and binary forms, with or without
6
+ * modification, are permitted provided that the following conditions are met:
7
+ *
8
+ * * Redistributions of source code must retain the above copyright notice, this
9
+ * list of conditions and the following disclaimer.
10
+ *
11
+ * * Redistributions in binary form must reproduce the above copyright notice,
12
+ * this list of conditions and the following disclaimer in the documentation
13
+ * and/or other materials provided with the distribution.
14
+ *
15
+ * * Neither the name of the copyright holder nor the names of its
16
+ * contributors may be used to endorse or promote products derived from
17
+ * this software without specific prior written permission.
18
+ *
19
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
+ */
30
+
31
+ #ifndef RUMALE_TREE_EXT_HPP
32
+ #define RUMALE_TREE_EXT_HPP 1
33
+
34
+ #include <cmath>
35
+ #include <limits>
36
+ #include <string>
37
+ #include <vector>
38
+
39
+ #include <ruby.h>
40
+
41
+ #include <numo/narray.h>
42
+ #include <numo/template.h>
43
+
44
+ /**
45
+ * @!visibility private
46
+ * Document-module: Rumale::Tree::ExtDecisionTreeClassifier
47
+ * The mixin module consisting of extension method for DecisionTreeClassifier class.
48
+ * This module is used internally.
49
+ */
50
+ class ExtDecisionTreeClassifier {
51
+ public:
52
+ static void define_module(VALUE& outer) {
53
+ VALUE rb_mExtDTreeCls = rb_define_module_under(outer, "ExtDecisionTreeClassifier");
54
+ rb_define_private_method(rb_mExtDTreeCls, "find_split_params", find_split_params_, 6);
55
+ rb_define_private_method(rb_mExtDTreeCls, "node_impurity", node_impurity_, 3);
56
+ rb_define_private_method(rb_mExtDTreeCls, "stop_growing?", check_same_label_, 1);
57
+ }
58
+
59
+ private:
60
+ static double calc_impurity_(const std::string& criterion, const std::vector<size_t>& histogram, const size_t& n_elements, const size_t& n_classes) {
61
+ double impurity = 0.0;
62
+ if (criterion == "entropy") {
63
+ double entropy = 0.0;
64
+ for (size_t i = 0; i < n_classes; i++) {
65
+ const double el = static_cast<double>(histogram[i]) / static_cast<double>(n_elements);
66
+ entropy += el * std::log(el + 1.0);
67
+ }
68
+ impurity = -entropy;
69
+ } else {
70
+ double gini = 0.0;
71
+ for (size_t i = 0; i < n_classes; i++) {
72
+ const double el = static_cast<double>(histogram[i]) / static_cast<double>(n_elements);
73
+ gini += el * el;
74
+ }
75
+ impurity = 1.0 - gini;
76
+ }
77
+ return impurity;
78
+ }
79
+
80
+ /**
81
+ * @!visibility private
82
+ * Find for split point with maximum information gain.
83
+ *
84
+ * @overload find_split_params(criterion, impurity, order, features, labels, n_classes) -> Array<Float>
85
+ *
86
+ * @param criterion [String] The function to evaluate spliting point. Supported criteria are 'gini' and 'entropy'.
87
+ * @param impurity [Float] The impurity of whole dataset.
88
+ * @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
89
+ * @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
90
+ * @param labels [Numo::Int32] (shape: [n_elements]) The labels.
91
+ * @param n_classes [Integer] The number of classes.
92
+ * @return [Array<Float>] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
93
+ */
94
+
95
+ struct FindSplitParamsOpts_ {
96
+ std::string criterion;
97
+ size_t n_classes;
98
+ double impurity;
99
+ };
100
+
101
+ static void iter_find_split_params_(na_loop_t const* lp) {
102
+ // Obtain iteration variables.
103
+ const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
104
+ const size_t n_elements = NDL_SHAPE(lp, 0)[0];
105
+ const double* f = (double*)NDL_PTR(lp, 1);
106
+ const int32_t* y = (int32_t*)NDL_PTR(lp, 2);
107
+ const std::string criterion = ((FindSplitParamsOpts_*)lp->opt_ptr)->criterion;
108
+ const size_t n_classes = ((FindSplitParamsOpts_*)lp->opt_ptr)->n_classes;
109
+ const double w_impurity = ((FindSplitParamsOpts_*)lp->opt_ptr)->impurity;
110
+
111
+ // Initialize output optimal parameters.
112
+ double* params = (double*)NDL_PTR(lp, 3);
113
+ params[0] = 0.0; // left impurity
114
+ params[1] = w_impurity; // right impurity
115
+ params[2] = f[o[0]]; // threshold
116
+ params[3] = 0.0; // gain
117
+
118
+ // Initialize child node variables.
119
+ std::vector<size_t> r_histogram(n_classes, 0);
120
+ for (size_t i = 0; i < n_elements; i++) r_histogram[y[o[i]]] += 1;
121
+
122
+ // Find optimal parameters.
123
+ size_t curr_pos = 0;
124
+ size_t next_pos = 0;
125
+ size_t n_l_elements = 0;
126
+ size_t n_r_elements = n_elements;
127
+ double curr_el = f[o[0]];
128
+ const double last_el = f[o[n_elements - 1]];
129
+ std::vector<size_t> l_histogram(n_classes, 0);
130
+ while (curr_pos < n_elements && curr_el != last_el) {
131
+ double next_el = f[o[next_pos]];
132
+ while (next_pos < n_elements && next_el == curr_el) {
133
+ l_histogram[y[o[next_pos]]] += 1;
134
+ n_l_elements++;
135
+ r_histogram[y[o[next_pos]]] -= 1;
136
+ n_r_elements--;
137
+ next_pos++;
138
+ next_el = f[o[next_pos]];
139
+ }
140
+ // Calculate gain of new split.
141
+ const double l_impurity = calc_impurity_(criterion, l_histogram, n_l_elements, n_classes);
142
+ const double r_impurity = calc_impurity_(criterion, r_histogram, n_r_elements, n_classes);
143
+ const double gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / static_cast<double>(n_elements);
144
+ // Update optimal parameters.
145
+ if (gain > params[3]) {
146
+ params[0] = l_impurity;
147
+ params[1] = r_impurity;
148
+ params[2] = 0.5 * (curr_el + next_el);
149
+ params[3] = gain;
150
+ }
151
+ if (next_pos == n_elements) break;
152
+ curr_pos = next_pos;
153
+ curr_el = f[o[curr_pos]];
154
+ }
155
+ }
156
+
157
+ static VALUE find_split_params_(VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE labels, VALUE n_classes) {
158
+ ndfunc_arg_in_t ain[3] = { { numo_cInt32, 1 }, { numo_cDFloat, 1 }, { numo_cInt32, 1 } };
159
+ size_t out_shape[1] = { 4 };
160
+ ndfunc_arg_out_t aout[1] = { { numo_cDFloat, 1, out_shape } };
161
+ ndfunc_t ndf = { (na_iter_func_t)iter_find_split_params_, NO_LOOP, 3, 1, ain, aout };
162
+ FindSplitParamsOpts_ opts = { std::string(StringValueCStr(criterion)), NUM2SIZET(n_classes), NUM2DBL(impurity) };
163
+ VALUE params = na_ndloop3(&ndf, &opts, 3, order, features, labels);
164
+ RB_GC_GUARD(criterion);
165
+ return params;
166
+ }
167
+
168
+ /**
169
+ * @!visibility private
170
+ * Calculate impurity based on criterion.
171
+ *
172
+ * @overload node_impurity(criterion, y, n_classes) -> Float
173
+ *
174
+ * @param criterion [String] The function to calculate impurity. Supported criteria are 'gini' and 'entropy'.
175
+ * @param y [Numo::Int32] (shape: [n_samples]) The labels.
176
+ * @param n_classes [Integer] The number of classes.
177
+ * @return [Float] impurity
178
+ */
179
+
180
+ struct NodeImpurityOpts_ {
181
+ std::string criterion;
182
+ size_t n_classes;
183
+ };
184
+
185
+ static void iter_node_impurity_(na_loop_t const* lp) {
186
+ const int32_t* y = (int32_t*)NDL_PTR(lp, 0);
187
+ const size_t n_elements = NDL_SHAPE(lp, 0)[0];
188
+ const std::string criterion = ((NodeImpurityOpts_*)lp->opt_ptr)->criterion;
189
+ const size_t n_classes = ((NodeImpurityOpts_*)lp->opt_ptr)->n_classes;
190
+ double* ret = (double*)NDL_PTR(lp, 1);
191
+ std::vector<size_t> histogram(n_classes, 0);
192
+ for (size_t i = 0; i < n_elements; i++) histogram[y[i]] += 1;
193
+ *ret = calc_impurity_(criterion, histogram, n_elements, n_classes);
194
+ }
195
+
196
+ static VALUE node_impurity_(VALUE self, VALUE criterion, VALUE y, VALUE n_classes) {
197
+ ndfunc_arg_in_t ain[1] = { { numo_cInt32, 1 } };
198
+ ndfunc_arg_out_t aout[1] = { { numo_cDFloat, 0 } };
199
+ ndfunc_t ndf = { (na_iter_func_t)iter_node_impurity_, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
200
+ NodeImpurityOpts_ opts = { std::string(StringValueCStr(criterion)), NUM2SIZET(n_classes) };
201
+ VALUE ret = na_ndloop3(&ndf, &opts, 1, y);
202
+ RB_GC_GUARD(criterion);
203
+ return ret;
204
+ }
205
+
206
+ /**
207
+ * @!visibility private
208
+ * Check all elements have the same value.
209
+ *
210
+ * @overload check_same_label(y) -> Boolean
211
+ *
212
+ * @param y [Numo::Int32] (shape: [n_samples]) The labels.
213
+ * @return [Boolean]
214
+ */
215
+
216
+ static void iter_check_same_label_(na_loop_t const* lp) {
217
+ const int32_t* y = (int32_t*)NDL_PTR(lp, 0);
218
+ const size_t n_elements = NDL_SHAPE(lp, 0)[0];
219
+ VALUE* ret = (VALUE*)NDL_PTR(lp, 1);
220
+ *ret = Qtrue;
221
+ if (n_elements > 0) {
222
+ int32_t label = y[0];
223
+ for (size_t i = 0; i < n_elements; i++) {
224
+ if (y[i] != label) {
225
+ *ret = Qfalse;
226
+ break;
227
+ }
228
+ }
229
+ }
230
+ }
231
+
232
+ static VALUE check_same_label_(VALUE self, VALUE y) {
233
+ ndfunc_arg_in_t ain[1] = { { numo_cInt32, 1 } };
234
+ ndfunc_arg_out_t aout[1] = { { numo_cRObject, 0 } };
235
+ ndfunc_t ndf = { (na_iter_func_t)iter_check_same_label_, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
236
+ return na_ndloop(&ndf, 1, y);
237
+ }
238
+ };
239
+
240
+ /**
241
+ * @!visibility private
242
+ * Document-module: Rumale::Tree::ExtDecisionTreeRegressor
243
+ * The mixin module consisting of extension method for DecisionTreeRegressor class.
244
+ * This module is used internally.
245
+ */
246
+ class ExtDecisionTreeRegressor {
247
+ public:
248
+ static void define_module(VALUE& outer) {
249
+ VALUE rb_mExtDTreeReg = rb_define_module_under(outer, "ExtDecisionTreeRegressor");
250
+ rb_define_private_method(rb_mExtDTreeReg, "find_split_params", find_split_params_, 5);
251
+ rb_define_private_method(rb_mExtDTreeReg, "node_impurity", node_impurity_, 2);
252
+ rb_define_private_method(rb_mExtDTreeReg, "stop_growing?", check_same_value_, 1);
253
+ }
254
+
255
+ private:
256
+ static double calc_impurity_(const std::string& criterion, const int32_t* order, const double* vecs, const double* mean_vec,
257
+ const size_t& n_elements, const size_t& n_outputs, const size_t& order_offset) {
258
+ const bool is_mae = criterion == "mae";
259
+ double sum_err = 0.0;
260
+ for (size_t i = 0; i < n_elements; i++) {
261
+ double err = 0.0;
262
+ for (size_t j = 0; j < n_outputs; j++) {
263
+ const double el = vecs[order[order_offset + i] * n_outputs + j] - mean_vec[j];
264
+ err += is_mae ? std::fabs(el) : el * el;
265
+ }
266
+ err /= static_cast<double>(n_outputs);
267
+ sum_err += err;
268
+ }
269
+ const double impurity = sum_err / static_cast<double>(n_elements);
270
+ return impurity;
271
+ }
272
+
273
+ /**
274
+ * @!visibility private
275
+ * Find for split point with maximum information gain.
276
+ *
277
+ * @overload find_split_params(criterion, impurity, order, features, targets) -> Array<Float>
278
+ *
279
+ * @param criterion [String] The function to evaluate spliting point. Supported criteria are 'mae' and 'mse'.
280
+ * @param impurity [Float] The impurity of whole dataset.
281
+ * @param order [Numo::Int32] (shape: [n_samples]) The element indices sorted according to feature values in ascending order.
282
+ * @param features [Numo::DFloat] (shape: [n_samples]) The feature values.
283
+ * @param targets [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values.
284
+ * @return [Array<Float>] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
285
+ */
286
+
287
+ struct FindSplitParamsOpts_ {
288
+ std::string criterion;
289
+ double impurity;
290
+ };
291
+
292
+ static void iter_find_split_params_(na_loop_t const* lp) {
293
+ // Obtain iteration variables.
294
+ const int32_t* order = (int32_t*)NDL_PTR(lp, 0);
295
+ const size_t n_elements = NDL_SHAPE(lp, 0)[0];
296
+ const double* f = (double*)NDL_PTR(lp, 1);
297
+ const double* y = (double*)NDL_PTR(lp, 2);
298
+ const size_t n_outputs = NDL_SHAPE(lp, 2)[1];
299
+ const std::string criterion = ((FindSplitParamsOpts_*)lp->opt_ptr)->criterion;
300
+ const double w_impurity = ((FindSplitParamsOpts_*)lp->opt_ptr)->impurity;
301
+
302
+ // Initialize optimal parameters.
303
+ double* params = (double*)NDL_PTR(lp, 3);
304
+ params[0] = 0.0; // left impurity
305
+ params[1] = w_impurity; // right impurity
306
+ params[2] = f[order[0]]; // threshold
307
+ params[3] = 0.0; // gain
308
+
309
+ // Initialize child node variables.
310
+ std::vector<double> l_sum_y(n_outputs, 0);
311
+ std::vector<double> r_sum_y(n_outputs, 0);
312
+ for (size_t i = 0; i < n_elements; i++) {
313
+ for (size_t j = 0; j < n_outputs; j++) {
314
+ r_sum_y[j] += y[order[i] * n_outputs + j];
315
+ }
316
+ }
317
+
318
+ // Find optimal parameters.
319
+ size_t curr_pos = 0;
320
+ size_t next_pos = 0;
321
+ size_t n_l_elements = 0;
322
+ size_t n_r_elements = n_elements;
323
+ std::vector<double> l_mean_y(n_outputs, 0);
324
+ std::vector<double> r_mean_y(n_outputs, 0);
325
+ double curr_el = f[order[0]];
326
+ const double last_el = f[order[n_elements - 1]];
327
+ while (curr_pos < n_elements && curr_el != last_el) {
328
+ double next_el = f[order[next_pos]];
329
+ while (next_pos < n_elements && next_el == curr_el) {
330
+ for (size_t j = 0; j < n_outputs; j++) {
331
+ l_sum_y[j] += y[order[next_pos] * n_outputs + j];
332
+ r_sum_y[j] -= y[order[next_pos] * n_outputs + j];
333
+ }
334
+ n_l_elements++;
335
+ n_r_elements--;
336
+ next_pos++;
337
+ next_el = f[order[next_pos]];
338
+ }
339
+ // Calculate gain of new split.
340
+ for (size_t j = 0; j < n_outputs; j++) {
341
+ l_mean_y[j] = l_sum_y[j] / static_cast<double>(n_l_elements);
342
+ r_mean_y[j] = r_sum_y[j] / static_cast<double>(n_r_elements);
343
+ }
344
+ const double l_impurity = calc_impurity_(criterion, order, y, l_mean_y.data(), n_l_elements, n_outputs, 0);
345
+ const double r_impurity = calc_impurity_(criterion, order, y, r_mean_y.data(), n_r_elements, n_outputs, next_pos);
346
+ const double gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / static_cast<double>(n_elements);
347
+ // Update optimal parameters.
348
+ if (gain > params[3]) {
349
+ params[0] = l_impurity;
350
+ params[1] = r_impurity;
351
+ params[2] = 0.5 * (curr_el + next_el);
352
+ params[3] = gain;
353
+ }
354
+ if (next_pos == n_elements) break;
355
+ curr_pos = next_pos;
356
+ curr_el = f[order[curr_pos]];
357
+ }
358
+ }
359
+
360
+ static VALUE find_split_params_(VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE targets) {
361
+ ndfunc_arg_in_t ain[3] = { { numo_cInt32, 1 }, { numo_cDFloat, 1 }, { numo_cDFloat, 2 } };
362
+ size_t out_shape[1] = { 4 };
363
+ ndfunc_arg_out_t aout[1] = { { numo_cDFloat, 1, out_shape } };
364
+ ndfunc_t ndf = { (na_iter_func_t)iter_find_split_params_, NO_LOOP, 3, 1, ain, aout };
365
+ FindSplitParamsOpts_ opts = { std::string(StringValueCStr(criterion)), NUM2DBL(impurity) };
366
+ VALUE params = na_ndloop3(&ndf, &opts, 3, order, features, targets);
367
+ RB_GC_GUARD(criterion);
368
+ return params;
369
+ }
370
+
371
+ /**
372
+ * @!visibility private
373
+ * Calculate impurity based on criterion.
374
+ *
375
+ * @overload node_impurity(criterion, y) -> Float
376
+ *
377
+ * @param criterion [String] The function to calculate impurity. Supported criteria are 'mae' and 'mse'.
378
+ * @param y [Array<Float>] (shape: [n_samples, n_outputs]) The taget values.
379
+ * @return [Float] impurity
380
+ */
381
+
382
+ struct NodeImpurityOpts_ {
383
+ std::string criterion;
384
+ };
385
+
386
+ static void iter_node_impurity_(na_loop_t const* lp) {
387
+ const double* y = (double*)NDL_PTR(lp, 0);
388
+ const size_t n_elements = NDL_SHAPE(lp, 0)[0];
389
+ const size_t n_outputs = NDL_SHAPE(lp, 0)[1];
390
+ const std::string criterion = ((NodeImpurityOpts_*)lp->opt_ptr)->criterion;
391
+
392
+ std::vector<int32_t> order(n_elements);
393
+ std::vector<double> mean_y(n_outputs, 0);
394
+ for (size_t i = 0; i < n_elements; i++) {
395
+ order[i] = static_cast<int32_t>(i);
396
+ for (size_t j = 0; j < n_outputs; j++) {
397
+ mean_y[j] += y[i * n_outputs + j];
398
+ }
399
+ }
400
+ for (size_t j = 0; j < n_outputs; j++) {
401
+ mean_y[j] /= static_cast<double>(n_elements);
402
+ }
403
+
404
+ double* ret = (double*)NDL_PTR(lp, 1);
405
+ *ret = calc_impurity_(criterion, order.data(), y, mean_y.data(), n_elements, n_outputs, 0);
406
+ }
407
+
408
+ static VALUE node_impurity_(VALUE self, VALUE criterion, VALUE y) {
409
+ ndfunc_arg_in_t ain[1] = { { numo_cDFloat, 2 } };
410
+ ndfunc_arg_out_t aout[1] = { { numo_cDFloat, 0 } };
411
+ ndfunc_t ndf = { (na_iter_func_t)iter_node_impurity_, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
412
+ NodeImpurityOpts_ opts = { std::string(StringValueCStr(criterion)) };
413
+ VALUE ret = na_ndloop3(&ndf, &opts, 1, y);
414
+ RB_GC_GUARD(criterion);
415
+ return ret;
416
+ }
417
+
418
+ /**
419
+ * @!visibility private
420
+ * Check all elements have the same value/vector.
421
+ *
422
+ * @overload check_same_value(y) -> Boolean
423
+ *
424
+ * @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values.
425
+ * @return [Boolean]
426
+ */
427
+
428
+ static void iter_check_same_value_(na_loop_t const* lp) {
429
+ const double* y = (double*)NDL_PTR(lp, 0);
430
+ const size_t n_elements = NDL_SHAPE(lp, 0)[0];
431
+ const size_t n_outputs = NDL_SHAPE(lp, 0)[1];
432
+ VALUE* ret = (VALUE*)NDL_PTR(lp, 1);
433
+ const double eps = std::numeric_limits<double>::epsilon();
434
+ *ret = Qtrue;
435
+ if (n_elements > 0) {
436
+ for (size_t i = 1; i < n_elements; i++) {
437
+ for (size_t j = 0; j < n_outputs; j++) {
438
+ if (std::abs(y[i * n_outputs + j] - y[j]) > eps) {
439
+ *ret = Qfalse;
440
+ break;
441
+ }
442
+ }
443
+ if (*ret == Qfalse) break;
444
+ }
445
+ }
446
+ }
447
+
448
+ static VALUE check_same_value_(VALUE self, VALUE y) {
449
+ ndfunc_arg_in_t ain[1] = { { numo_cDFloat, 2 } };
450
+ ndfunc_arg_out_t aout[1] = { { numo_cRObject, 0 } };
451
+ ndfunc_t ndf = { (na_iter_func_t)iter_check_same_value_, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
452
+ return na_ndloop(&ndf, 1, y);
453
+ }
454
+ };
455
+
456
+ /**
457
+ * @!visibility private
458
+ * Document-module: Rumale::Tree::ExtGradientTreeRegressor
459
+ * The mixin module consisting of extension method for GradientTreeRegressor class.
460
+ * This module is used internally.
461
+ */
462
+ class ExtGradientTreeRegressor {
463
+ public:
464
+ static void define_module(VALUE& outer) {
465
+ VALUE rb_mExtGTreeReg = rb_define_module_under(outer, "ExtGradientTreeRegressor");
466
+ rb_define_private_method(rb_mExtGTreeReg, "find_split_params", find_split_params_, 7);
467
+ }
468
+
469
+ private:
470
+ /**
471
+ * @!visibility private
472
+ * Find for split point with maximum information gain.
473
+ *
474
+ * @overload find_split_params(order, features, gradients, hessians, sum_gradient, sum_hessian, reg_lambda) -> Array<Float>
475
+ * @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
476
+ * @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
477
+ * @param gradients [Numo::DFloat] (shape: [n_elements]) The gradient values.
478
+ * @param hessians [Numo::DFloat] (shape: [n_elements]) The hessian values.
479
+ * @param sum_gradient [Float] The sum of gradient values.
480
+ * @param sum_hessian [Float] The sum of hessian values.
481
+ * @param reg_lambda [Float] The L2 regularization term on weight.
482
+ * @return [Array<Float>] The array consists of optimal parameters including threshold and gain.
483
+ */
484
+
485
+ struct FindSplitParamsOpts_ {
486
+ double sum_gradient;
487
+ double sum_hessian;
488
+ double reg_lambda;
489
+ };
490
+
491
+ static void iter_find_split_params_(na_loop_t const* lp) {
492
+ // Obtain iteration variables.
493
+ const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
494
+ const size_t n_elements = NDL_SHAPE(lp, 0)[0];
495
+ const double* f = (double*)NDL_PTR(lp, 1);
496
+ const double* g = (double*)NDL_PTR(lp, 2);
497
+ const double* h = (double*)NDL_PTR(lp, 3);
498
+ const double s_grad = ((FindSplitParamsOpts_*)lp->opt_ptr)->sum_gradient;
499
+ const double s_hess = ((FindSplitParamsOpts_*)lp->opt_ptr)->sum_hessian;
500
+ const double reg_lambda = ((FindSplitParamsOpts_*)lp->opt_ptr)->reg_lambda;
501
+
502
+ // Find optimal parameters.
503
+ size_t curr_pos = 0;
504
+ size_t next_pos = 0;
505
+ double curr_el = f[o[0]];
506
+ const double last_el = f[o[n_elements - 1]];
507
+ double l_grad = 0.0;
508
+ double l_hess = 0.0;
509
+ double threshold = curr_el;
510
+ double gain_max = 0.0;
511
+ while (curr_pos < n_elements && curr_el != last_el) {
512
+ double next_el = f[o[next_pos]];
513
+ while (next_pos < n_elements && next_el == curr_el) {
514
+ l_grad += g[o[next_pos]];
515
+ l_hess += h[o[next_pos]];
516
+ next_pos++;
517
+ next_el = f[o[next_pos]];
518
+ }
519
+ // Calculate gain of new split.
520
+ const double r_grad = s_grad - l_grad;
521
+ const double r_hess = s_hess - l_hess;
522
+ const double gain = (l_grad * l_grad) / (l_hess + reg_lambda) + (r_grad * r_grad) / (r_hess + reg_lambda) - (s_grad * s_grad) / (s_hess + reg_lambda);
523
+ // Update optimal parameters.
524
+ if (gain > gain_max) {
525
+ threshold = 0.5 * (curr_el + next_el);
526
+ gain_max = gain;
527
+ }
528
+ if (next_pos == n_elements) break;
529
+ curr_pos = next_pos;
530
+ curr_el = f[o[curr_pos]];
531
+ }
532
+
533
+ double* params = (double*)NDL_PTR(lp, 4);
534
+ params[0] = threshold;
535
+ params[1] = gain_max;
536
+ }
537
+
538
+ static VALUE find_split_params_(VALUE self, VALUE order, VALUE features, VALUE gradients, VALUE hessians,
539
+ VALUE sum_gradient, VALUE sum_hessian, VALUE reg_lambda) {
540
+ ndfunc_arg_in_t ain[4] = { { numo_cInt32, 1 }, { numo_cDFloat, 1 }, { numo_cDFloat, 1 }, { numo_cDFloat, 1 } };
541
+ size_t out_shape[1] = { 2 };
542
+ ndfunc_arg_out_t aout[1] = { { numo_cDFloat, 1, out_shape } };
543
+ ndfunc_t ndf = { (na_iter_func_t)iter_find_split_params_, NO_LOOP, 4, 1, ain, aout };
544
+ FindSplitParamsOpts_ opts = { NUM2DBL(sum_gradient), NUM2DBL(sum_hessian), NUM2DBL(reg_lambda) };
545
+ VALUE params = na_ndloop3(&ndf, &opts, 4, order, features, gradients, hessians);
546
+ return params;
547
+ }
548
+ };
549
+
550
+ #endif /* RUMALE_TREE_EXT_HPP */
@@ -3,6 +3,8 @@
3
3
  require 'mkmf'
4
4
  require 'numo/narray'
5
5
 
6
+ abort 'libstdc++ is not found.' unless have_library('stdc++')
7
+
6
8
  $LOAD_PATH.each do |lp|
7
9
  if File.exist?(File.join(lp, 'numo/numo/narray.h'))
8
10
  $INCFLAGS = "-I#{lp}/numo #{$INCFLAGS}"
@@ -63,11 +63,7 @@ module Rumale
63
63
  end
64
64
 
65
65
  def build_tree(x, y)
66
- y = y.expand_dims(1).dup if y.shape[1].nil?
67
- @feature_ids = Array.new(x.shape[1]) { |v| v }
68
- @tree = grow_node(0, x, y, impurity(y))
69
- @feature_ids = nil
70
- nil
66
+ raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
71
67
  end
72
68
 
73
69
  def grow_node(depth, x, y, impurity) # rubocop:disable Metrics/AbcSize, Metrics/PerceivedComplexity
@@ -68,6 +68,7 @@ module Rumale
68
68
  @params[:max_features] = [@params[:max_features], n_features].min
69
69
  @n_leaves = 0
70
70
  @leaf_values = []
71
+ @feature_ids = Array.new(x.shape[1]) { |v| v }
71
72
  @sub_rng = @rng.dup
72
73
  build_tree(x, y)
73
74
  eval_importance(n_samples, n_features)
@@ -88,8 +89,10 @@ module Rumale
88
89
 
89
90
  private
90
91
 
91
- def stop_growing?(y)
92
- y.to_a.uniq.size == 1
92
+ def build_tree(x, y)
93
+ y = y.expand_dims(1).dup if y.shape[1].nil?
94
+ @tree = grow_node(0, x, y, impurity(y))
95
+ nil
93
96
  end
94
97
 
95
98
  def put_leaf(node, y)
@@ -106,7 +109,7 @@ module Rumale
106
109
  end
107
110
 
108
111
  def impurity(y)
109
- node_impurity(@params[:criterion], y.to_a)
112
+ node_impurity(@params[:criterion], y)
110
113
  end
111
114
  end
112
115
  end
@@ -5,6 +5,6 @@ module Rumale
5
5
  # This module consists of the classes that implement tree models.
6
6
  module Tree
7
7
  # @!visibility private
8
- VERSION = '0.27.0'
8
+ VERSION = '0.28.0'
9
9
  end
10
10
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rumale-tree
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.27.0
4
+ version: 0.28.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2023-08-26 00:00:00.000000000 Z
11
+ date: 2023-11-12 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -30,14 +30,14 @@ dependencies:
30
30
  requirements:
31
31
  - - "~>"
32
32
  - !ruby/object:Gem::Version
33
- version: 0.27.0
33
+ version: 0.28.0
34
34
  type: :runtime
35
35
  prerelease: false
36
36
  version_requirements: !ruby/object:Gem::Requirement
37
37
  requirements:
38
38
  - - "~>"
39
39
  - !ruby/object:Gem::Version
40
- version: 0.27.0
40
+ version: 0.28.0
41
41
  description: Rumale::Tree provides classifier and regression based on decision tree
42
42
  algorithms with Rumale interface.
43
43
  email:
@@ -49,8 +49,8 @@ extra_rdoc_files: []
49
49
  files:
50
50
  - LICENSE.txt
51
51
  - README.md
52
- - ext/rumale/tree/ext.c
53
- - ext/rumale/tree/ext.h
52
+ - ext/rumale/tree/ext.cpp
53
+ - ext/rumale/tree/ext.hpp
54
54
  - ext/rumale/tree/extconf.rb
55
55
  - lib/rumale/tree.rb
56
56
  - lib/rumale/tree/base_decision_tree.rb
@@ -85,7 +85,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
85
85
  - !ruby/object:Gem::Version
86
86
  version: '0'
87
87
  requirements: []
88
- rubygems_version: 3.3.26
88
+ rubygems_version: 3.4.20
89
89
  signing_key:
90
90
  specification_version: 4
91
91
  summary: Rumale::Tree provides classifier and regression based on decision tree algorithms
@@ -1,556 +0,0 @@
1
- #include "ext.h"
2
-
3
- double* alloc_dbl_array(const long n_dimensions) {
4
- double* arr = ALLOC_N(double, n_dimensions);
5
- memset(arr, 0, n_dimensions * sizeof(double));
6
- return arr;
7
- }
8
-
9
/*
 * Gini impurity of a class histogram: 1 - sum_k (histogram[k] / n_elements)^2.
 *
 * histogram  : per-class counts, n_classes entries.
 * n_elements : total count the proportions are taken relative to.
 * n_classes  : number of histogram bins to read.
 */
double calc_gini_coef(double* histogram, const long n_elements, const long n_classes) {
  double sum_sq = 0.0;
  long k = 0;
  while (k < n_classes) {
    const double p = histogram[k] / n_elements;
    sum_sq += p * p;
    k++;
  }
  return 1.0 - sum_sq;
}
19
-
20
- double calc_entropy(double* histogram, const long n_elements, const long n_classes) {
21
- double entropy = 0.0;
22
-
23
- for (long i = 0; i < n_classes; i++) {
24
- double el = histogram[i] / n_elements;
25
- entropy += el * log(el + 1.0);
26
- }
27
-
28
- return -entropy;
29
- }
30
-
31
- VALUE
32
- calc_mean_vec(double* sum_vec, const long n_dimensions, const long n_elements) {
33
- VALUE mean_vec = rb_ary_new2(n_dimensions);
34
-
35
- for (long i = 0; i < n_dimensions; i++) {
36
- rb_ary_store(mean_vec, i, DBL2NUM(sum_vec[i] / n_elements));
37
- }
38
-
39
- return mean_vec;
40
- }
41
-
42
- double calc_vec_mae(VALUE vec_a, VALUE vec_b) {
43
- const long n_dimensions = RARRAY_LEN(vec_a);
44
- double sum = 0.0;
45
-
46
- for (long i = 0; i < n_dimensions; i++) {
47
- double diff = NUM2DBL(rb_ary_entry(vec_a, i)) - NUM2DBL(rb_ary_entry(vec_b, i));
48
- sum += fabs(diff);
49
- }
50
-
51
- return sum / n_dimensions;
52
- }
53
-
54
- double calc_vec_mse(VALUE vec_a, VALUE vec_b) {
55
- const long n_dimensions = RARRAY_LEN(vec_a);
56
- double sum = 0.0;
57
-
58
- for (long i = 0; i < n_dimensions; i++) {
59
- double diff = NUM2DBL(rb_ary_entry(vec_a, i)) - NUM2DBL(rb_ary_entry(vec_b, i));
60
- sum += diff * diff;
61
- }
62
-
63
- return sum / n_dimensions;
64
- }
65
-
66
- double calc_mae(VALUE target_vecs, VALUE mean_vec) {
67
- const long n_elements = RARRAY_LEN(target_vecs);
68
- double sum = 0.0;
69
-
70
- for (long i = 0; i < n_elements; i++) {
71
- sum += calc_vec_mae(rb_ary_entry(target_vecs, i), mean_vec);
72
- }
73
-
74
- return sum / n_elements;
75
- }
76
-
77
- double calc_mse(VALUE target_vecs, VALUE mean_vec) {
78
- const long n_elements = RARRAY_LEN(target_vecs);
79
- double sum = 0.0;
80
-
81
- for (long i = 0; i < n_elements; i++) {
82
- sum += calc_vec_mse(rb_ary_entry(target_vecs, i), mean_vec);
83
- }
84
-
85
- return sum / n_elements;
86
- }
87
-
88
/*
 * Classification impurity of a class histogram. Uses calc_entropy when
 * criterion is exactly "entropy" and calc_gini_coef for any other string.
 */
double calc_impurity_cls(const char* criterion, double* histogram, const long n_elements, const long n_classes) {
  const int use_entropy = (strcmp(criterion, "entropy") == 0);
  return use_entropy ? calc_entropy(histogram, n_elements, n_classes)
                     : calc_gini_coef(histogram, n_elements, n_classes);
}
94
-
95
- double calc_impurity_reg(const char* criterion, VALUE target_vecs, double* sum_vec) {
96
- const long n_elements = RARRAY_LEN(target_vecs);
97
- const long n_dimensions = RARRAY_LEN(rb_ary_entry(target_vecs, 0));
98
- VALUE mean_vec = calc_mean_vec(sum_vec, n_dimensions, n_elements);
99
-
100
- if (strcmp(criterion, "mae") == 0) {
101
- return calc_mae(target_vecs, mean_vec);
102
- }
103
- return calc_mse(target_vecs, mean_vec);
104
- }
105
-
106
- void add_sum_vec(double* sum_vec, VALUE target) {
107
- const long n_dimensions = RARRAY_LEN(target);
108
-
109
- for (long i = 0; i < n_dimensions; i++) {
110
- sum_vec[i] += NUM2DBL(rb_ary_entry(target, i));
111
- }
112
- }
113
-
114
- void sub_sum_vec(double* sum_vec, VALUE target) {
115
- const long n_dimensions = RARRAY_LEN(target);
116
-
117
- for (long i = 0; i < n_dimensions; i++) {
118
- sum_vec[i] -= NUM2DBL(rb_ary_entry(target, i));
119
- }
120
- }
121
-
122
/**
 * @!visibility private
 * Options handed through na_ndloop3 to iter_find_split_params_cls.
 */
typedef struct {
  const char* criterion; /* impurity criterion name; only read, never modified */
  long n_classes;        /* number of histogram bins (labels must lie in [0, n_classes)) */
  double impurity;       /* impurity of the whole node being split */
} split_opts_cls;
130
-
131
- /**
132
- * @!visibility private
133
- */
134
- static void iter_find_split_params_cls(na_loop_t const* lp) {
135
- const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
136
- const double* f = (double*)NDL_PTR(lp, 1);
137
- const int32_t* y = (int32_t*)NDL_PTR(lp, 2);
138
- const long n_elements = NDL_SHAPE(lp, 0)[0];
139
- const char* criterion = ((split_opts_cls*)lp->opt_ptr)->criterion;
140
- const long n_classes = ((split_opts_cls*)lp->opt_ptr)->n_classes;
141
- const double w_impurity = ((split_opts_cls*)lp->opt_ptr)->impurity;
142
- double* params = (double*)NDL_PTR(lp, 3);
143
- long curr_pos = 0;
144
- long next_pos = 0;
145
- long n_l_elements = 0;
146
- long n_r_elements = n_elements;
147
- double curr_el = f[o[0]];
148
- double last_el = f[o[n_elements - 1]];
149
- double next_el;
150
- double l_impurity;
151
- double r_impurity;
152
- double gain;
153
- double* l_histogram = alloc_dbl_array(n_classes);
154
- double* r_histogram = alloc_dbl_array(n_classes);
155
-
156
- /* Initialize optimal parameters. */
157
- params[0] = 0.0; /* left impurity */
158
- params[1] = w_impurity; /* right impurity */
159
- params[2] = curr_el; /* threshold */
160
- params[3] = 0.0; /* gain */
161
-
162
- /* Initialize child node variables. */
163
- for (long i = 0; i < n_elements; i++) {
164
- r_histogram[y[o[i]]] += 1.0;
165
- }
166
-
167
- /* Find optimal parameters. */
168
- while (curr_pos < n_elements && curr_el != last_el) {
169
- next_el = f[o[next_pos]];
170
- while (next_pos < n_elements && next_el == curr_el) {
171
- l_histogram[y[o[next_pos]]] += 1;
172
- n_l_elements++;
173
- r_histogram[y[o[next_pos]]] -= 1;
174
- n_r_elements--;
175
- next_pos++;
176
- next_el = f[o[next_pos]];
177
- }
178
- /* Calculate gain of new split. */
179
- l_impurity = calc_impurity_cls(criterion, l_histogram, n_l_elements, n_classes);
180
- r_impurity = calc_impurity_cls(criterion, r_histogram, n_r_elements, n_classes);
181
- gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / n_elements;
182
- /* Update optimal parameters. */
183
- if (gain > params[3]) {
184
- params[0] = l_impurity;
185
- params[1] = r_impurity;
186
- params[2] = 0.5 * (curr_el + next_el);
187
- params[3] = gain;
188
- }
189
- if (next_pos == n_elements) break;
190
- curr_pos = next_pos;
191
- curr_el = f[o[curr_pos]];
192
- }
193
-
194
- xfree(l_histogram);
195
- xfree(r_histogram);
196
- }
197
-
198
- /**
199
- * @!visibility private
200
- * Find for split point with maximum information gain.
201
- *
202
- * @overload find_split_params(criterion, impurity, order, features, labels, n_classes) -> Array<Float>
203
- *
204
- * @param criterion [String] The function to evaluate spliting point. Supported criteria are 'gini' and 'entropy'.
205
- * @param impurity [Float] The impurity of whole dataset.
206
- * @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
207
- * @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
208
- * @param labels [Numo::Int32] (shape: [n_elements]) The labels.
209
- * @param n_classes [Integer] The number of classes.
210
- * @return [Array<Float>] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
211
- */
212
- static VALUE find_split_params_cls(VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE labels,
213
- VALUE n_classes) {
214
- ndfunc_arg_in_t ain[3] = {{numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cInt32, 1}};
215
- size_t out_shape[1] = {4};
216
- ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 1, out_shape}};
217
- ndfunc_t ndf = {(na_iter_func_t)iter_find_split_params_cls, NO_LOOP, 3, 1, ain, aout};
218
- split_opts_cls opts = {StringValuePtr(criterion), NUM2LONG(n_classes), NUM2DBL(impurity)};
219
- VALUE params = na_ndloop3(&ndf, &opts, 3, order, features, labels);
220
- RB_GC_GUARD(criterion);
221
- return params;
222
- }
223
-
224
/**
 * @!visibility private
 * Options handed through na_ndloop3 to iter_find_split_params_reg.
 */
typedef struct {
  const char* criterion; /* impurity criterion name; only read, never modified */
  double impurity;       /* impurity of the whole node being split */
} split_opts_reg;
231
-
232
- /**
233
- * @!visibility private
234
- */
235
- static void iter_find_split_params_reg(na_loop_t const* lp) {
236
- const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
237
- const double* f = (double*)NDL_PTR(lp, 1);
238
- const double* y = (double*)NDL_PTR(lp, 2);
239
- const long n_elements = NDL_SHAPE(lp, 0)[0];
240
- const long n_outputs = NDL_SHAPE(lp, 2)[1];
241
- const char* criterion = ((split_opts_reg*)lp->opt_ptr)->criterion;
242
- const double w_impurity = ((split_opts_reg*)lp->opt_ptr)->impurity;
243
- double* params = (double*)NDL_PTR(lp, 3);
244
- long curr_pos = 0;
245
- long next_pos = 0;
246
- long n_l_elements = 0;
247
- long n_r_elements = n_elements;
248
- double curr_el = f[o[0]];
249
- double last_el = f[o[n_elements - 1]];
250
- double next_el;
251
- double l_impurity;
252
- double r_impurity;
253
- double gain;
254
- double* l_sum_vec = alloc_dbl_array(n_outputs);
255
- double* r_sum_vec = alloc_dbl_array(n_outputs);
256
- double target_var;
257
- VALUE l_target_vecs = rb_ary_new();
258
- VALUE r_target_vecs = rb_ary_new();
259
- VALUE target;
260
-
261
- /* Initialize optimal parameters. */
262
- params[0] = 0.0; /* left impurity */
263
- params[1] = w_impurity; /* right impurity */
264
- params[2] = curr_el; /* threshold */
265
- params[3] = 0.0; /* gain */
266
-
267
- /* Initialize child node variables. */
268
- for (long i = 0; i < n_elements; i++) {
269
- target = rb_ary_new2(n_outputs);
270
- for (long j = 0; j < n_outputs; j++) {
271
- target_var = y[o[i] * n_outputs + j];
272
- rb_ary_store(target, j, DBL2NUM(target_var));
273
- r_sum_vec[j] += target_var;
274
- }
275
- rb_ary_push(r_target_vecs, target);
276
- }
277
-
278
- /* Find optimal parameters. */
279
- while (curr_pos < n_elements && curr_el != last_el) {
280
- next_el = f[o[next_pos]];
281
- while (next_pos < n_elements && next_el == curr_el) {
282
- target = rb_ary_shift(r_target_vecs);
283
- n_r_elements--;
284
- sub_sum_vec(r_sum_vec, target);
285
- rb_ary_push(l_target_vecs, target);
286
- n_l_elements++;
287
- add_sum_vec(l_sum_vec, target);
288
- next_pos++;
289
- next_el = f[o[next_pos]];
290
- }
291
- /* Calculate gain of new split. */
292
- l_impurity = calc_impurity_reg(criterion, l_target_vecs, l_sum_vec);
293
- r_impurity = calc_impurity_reg(criterion, r_target_vecs, r_sum_vec);
294
- gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / n_elements;
295
- /* Update optimal parameters. */
296
- if (gain > params[3]) {
297
- params[0] = l_impurity;
298
- params[1] = r_impurity;
299
- params[2] = 0.5 * (curr_el + next_el);
300
- params[3] = gain;
301
- }
302
- if (next_pos == n_elements) break;
303
- curr_pos = next_pos;
304
- curr_el = f[o[curr_pos]];
305
- }
306
-
307
- xfree(l_sum_vec);
308
- xfree(r_sum_vec);
309
- }
310
-
311
- /**
312
- * @!visibility private
313
- * Find for split point with maximum information gain.
314
- *
315
- * @overload find_split_params(criterion, impurity, order, features, targets) -> Array<Float>
316
- *
317
- * @param criterion [String] The function to evaluate spliting point. Supported criteria are 'mae' and 'mse'.
318
- * @param impurity [Float] The impurity of whole dataset.
319
- * @param order [Numo::Int32] (shape: [n_samples]) The element indices sorted according to feature values in ascending order.
320
- * @param features [Numo::DFloat] (shape: [n_samples]) The feature values.
321
- * @param targets [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values.
322
- * @return [Array<Float>] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
323
- */
324
- static VALUE find_split_params_reg(VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE targets) {
325
- ndfunc_arg_in_t ain[3] = {{numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 2}};
326
- size_t out_shape[1] = {4};
327
- ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 1, out_shape}};
328
- ndfunc_t ndf = {(na_iter_func_t)iter_find_split_params_reg, NO_LOOP, 3, 1, ain, aout};
329
- split_opts_reg opts = {StringValuePtr(criterion), NUM2DBL(impurity)};
330
- VALUE params = na_ndloop3(&ndf, &opts, 3, order, features, targets);
331
- RB_GC_GUARD(criterion);
332
- return params;
333
- }
334
-
335
- /**
336
- * @!visibility private
337
- */
338
- static void iter_find_split_params_grad_reg(na_loop_t const* lp) {
339
- const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
340
- const double* f = (double*)NDL_PTR(lp, 1);
341
- const double* g = (double*)NDL_PTR(lp, 2);
342
- const double* h = (double*)NDL_PTR(lp, 3);
343
- const double s_grad = ((double*)lp->opt_ptr)[0];
344
- const double s_hess = ((double*)lp->opt_ptr)[1];
345
- const double reg_lambda = ((double*)lp->opt_ptr)[2];
346
- const long n_elements = NDL_SHAPE(lp, 0)[0];
347
- double* params = (double*)NDL_PTR(lp, 4);
348
- long curr_pos = 0;
349
- long next_pos = 0;
350
- double curr_el = f[o[0]];
351
- double last_el = f[o[n_elements - 1]];
352
- double next_el;
353
- double l_grad = 0.0;
354
- double l_hess = 0.0;
355
- double r_grad;
356
- double r_hess;
357
- double threshold = curr_el;
358
- double gain_max = 0.0;
359
- double gain;
360
-
361
- /* Find optimal parameters. */
362
- while (curr_pos < n_elements && curr_el != last_el) {
363
- next_el = f[o[next_pos]];
364
- while (next_pos < n_elements && next_el == curr_el) {
365
- l_grad += g[o[next_pos]];
366
- l_hess += h[o[next_pos]];
367
- next_pos++;
368
- next_el = f[o[next_pos]];
369
- }
370
- /* Calculate gain of new split. */
371
- r_grad = s_grad - l_grad;
372
- r_hess = s_hess - l_hess;
373
- gain = (l_grad * l_grad) / (l_hess + reg_lambda) + (r_grad * r_grad) / (r_hess + reg_lambda) -
374
- (s_grad * s_grad) / (s_hess + reg_lambda);
375
- /* Update optimal parameters. */
376
- if (gain > gain_max) {
377
- threshold = 0.5 * (curr_el + next_el);
378
- gain_max = gain;
379
- }
380
- if (next_pos == n_elements) {
381
- break;
382
- }
383
- curr_pos = next_pos;
384
- curr_el = f[o[curr_pos]];
385
- }
386
-
387
- params[0] = threshold;
388
- params[1] = gain_max;
389
- }
390
-
391
- /**
392
- * @!visibility private
393
- * Find for split point with maximum information gain.
394
- *
395
- * @overload find_split_params(order, features, gradients, hessians, sum_gradient, sum_hessian, reg_lambda) -> Array<Float>
396
- * @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
397
- * @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
398
- * @param gradients [Numo::DFloat] (shape: [n_elements]) The gradient values.
399
- * @param hessians [Numo::DFloat] (shape: [n_elements]) The hessian values.
400
- * @param sum_gradient [Float] The sum of gradient values.
401
- * @param sum_hessian [Float] The sum of hessian values.
402
- * @param reg_lambda [Float] The L2 regularization term on weight.
403
- * @return [Array<Float>] The array consists of optimal parameters including threshold and gain.
404
- */
405
- static VALUE find_split_params_grad_reg(VALUE self, VALUE order, VALUE features, VALUE gradients, VALUE hessians,
406
- VALUE sum_gradient, VALUE sum_hessian, VALUE reg_lambda) {
407
- ndfunc_arg_in_t ain[4] = {{numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 1}};
408
- size_t out_shape[1] = {2};
409
- ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 1, out_shape}};
410
- ndfunc_t ndf = {(na_iter_func_t)iter_find_split_params_grad_reg, NO_LOOP, 4, 1, ain, aout};
411
- double opts[3] = {NUM2DBL(sum_gradient), NUM2DBL(sum_hessian), NUM2DBL(reg_lambda)};
412
- VALUE params = na_ndloop3(&ndf, opts, 4, order, features, gradients, hessians);
413
- return params;
414
- }
415
-
416
/**
 * @!visibility private
 * Options handed through na_ndloop3 to iter_node_impurity_cls.
 */
typedef struct {
  const char* criterion; /* impurity criterion name; only read, never modified */
  long n_classes;        /* number of histogram bins (labels must lie in [0, n_classes)) */
} node_impurity_cls_opts;
423
-
424
- /**
425
- * @!visibility private
426
- */
427
- static void iter_node_impurity_cls(na_loop_t const* lp) {
428
- const int32_t* y = (int32_t*)NDL_PTR(lp, 0);
429
- const char* criterion = ((node_impurity_cls_opts*)lp->opt_ptr)->criterion;
430
- const long n_classes = ((node_impurity_cls_opts*)lp->opt_ptr)->n_classes;
431
- const long n_elements = NDL_SHAPE(lp, 0)[0];
432
- double* ret = (double*)NDL_PTR(lp, 1);
433
- double* histogram = alloc_dbl_array(n_classes);
434
- for (long i = 0; i < n_elements; i++) histogram[y[i]] += 1;
435
- *ret = calc_impurity_cls(criterion, histogram, n_elements, n_classes);
436
- xfree(histogram);
437
- }
438
-
439
- /**
440
- * @!visibility private
441
- * Calculate impurity based on criterion.
442
- *
443
- * @overload node_impurity(criterion, y, n_classes) -> Float
444
- *
445
- * @param criterion [String] The function to calculate impurity. Supported criteria are 'gini' and 'entropy'.
446
- * @param y [Numo::Int32] (shape: [n_samples]) The labels.
447
- * @param n_classes [Integer] The number of classes.
448
- * @return [Float] impurity
449
- */
450
- static VALUE node_impurity_cls(VALUE self, VALUE criterion, VALUE y, VALUE n_classes) {
451
- ndfunc_arg_in_t ain[1] = {{numo_cInt32, 1}};
452
- ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 0}};
453
- ndfunc_t ndf = {(na_iter_func_t)iter_node_impurity_cls, NDF_EXTRACT, 1, 1, ain, aout};
454
- node_impurity_cls_opts opts = {StringValuePtr(criterion), NUM2LONG(n_classes)};
455
- VALUE ret = na_ndloop3(&ndf, &opts, 1, y);
456
- RB_GC_GUARD(criterion);
457
- return ret;
458
- }
459
-
460
- /**
461
- * @!visibility private
462
- */
463
- static void iter_check_same_label(na_loop_t const* lp) {
464
- const int32_t* y = (int32_t*)NDL_PTR(lp, 0);
465
- const long n_elements = NDL_SHAPE(lp, 0)[0];
466
- int32_t* ret = (int32_t*)NDL_PTR(lp, 1);
467
- *ret = 1;
468
- if (n_elements > 0) {
469
- int32_t label = y[0];
470
- for (long i = 0; i < n_elements; i++) {
471
- if (y[i] != label) {
472
- *ret = 0;
473
- break;
474
- }
475
- }
476
- }
477
- }
478
-
479
- /**
480
- * @!visibility private
481
- * Check all elements have the save value.
482
- *
483
- * @overload check_same_label(y) -> Boolean
484
- *
485
- * @param y [Numo::Int32] (shape: [n_samples]) The labels.
486
- * @return [Boolean]
487
- */
488
- static VALUE check_same_label(VALUE self, VALUE y) {
489
- ndfunc_arg_in_t ain[1] = {{numo_cInt32, 1}};
490
- ndfunc_arg_out_t aout[1] = {{numo_cInt32, 0}};
491
- ndfunc_t ndf = {(na_iter_func_t)iter_check_same_label, NDF_EXTRACT, 1, 1, ain, aout};
492
- VALUE ret = na_ndloop(&ndf, 1, y);
493
- return (NUM2INT(ret) == 1 ? Qtrue : Qfalse);
494
- }
495
-
496
- /**
497
- * @!visibility private
498
- * Calculate impurity based on criterion.
499
- *
500
- * @overload node_impurity(criterion, y) -> Float
501
- *
502
- * @param criterion [String] The function to calculate impurity. Supported criteria are 'mae' and 'mse'.
503
- * @param y [Array<Float>] (shape: [n_samples, n_outputs]) The taget values.
504
- * @return [Float] impurity
505
- */
506
- static VALUE node_impurity_reg(VALUE self, VALUE criterion, VALUE y) {
507
- const long n_elements = RARRAY_LEN(y);
508
- const long n_outputs = RARRAY_LEN(rb_ary_entry(y, 0));
509
- double* sum_vec = alloc_dbl_array(n_outputs);
510
- VALUE target_vecs = rb_ary_new();
511
-
512
- for (long i = 0; i < n_elements; i++) {
513
- VALUE target = rb_ary_entry(y, i);
514
- add_sum_vec(sum_vec, target);
515
- rb_ary_push(target_vecs, target);
516
- }
517
-
518
- VALUE ret = DBL2NUM(calc_impurity_reg(StringValuePtr(criterion), target_vecs, sum_vec));
519
- xfree(sum_vec);
520
- RB_GC_GUARD(criterion);
521
- return ret;
522
- }
523
-
524
- void Init_ext(void) {
525
- VALUE rb_mRumale = rb_define_module("Rumale");
526
- VALUE rb_mTree = rb_define_module_under(rb_mRumale, "Tree");
527
-
528
- /**
529
- * Document-module: Rumale::Tree::ExtDecisionTreeClassifier
530
- * @!visibility private
531
- * The mixin module consisting of extension method for DecisionTreeClassifier class.
532
- * This module is used internally.
533
- */
534
- VALUE rb_mExtDTreeCls = rb_define_module_under(rb_mTree, "ExtDecisionTreeClassifier");
535
- /**
536
- * Document-module: Rumale::Tree::ExtDecisionTreeRegressor
537
- * @!visibility private
538
- * The mixin module consisting of extension method for DecisionTreeRegressor class.
539
- * This module is used internally.
540
- */
541
- VALUE rb_mExtDTreeReg = rb_define_module_under(rb_mTree, "ExtDecisionTreeRegressor");
542
- /**
543
- * Document-module: Rumale::Tree::ExtGradientTreeRegressor
544
- * @!visibility private
545
- * The mixin module consisting of extension method for GradientTreeRegressor class.
546
- * This module is used internally.
547
- */
548
- VALUE rb_mExtGTreeReg = rb_define_module_under(rb_mTree, "ExtGradientTreeRegressor");
549
-
550
- rb_define_private_method(rb_mExtDTreeCls, "find_split_params", find_split_params_cls, 6);
551
- rb_define_private_method(rb_mExtDTreeReg, "find_split_params", find_split_params_reg, 5);
552
- rb_define_private_method(rb_mExtGTreeReg, "find_split_params", find_split_params_grad_reg, 7);
553
- rb_define_private_method(rb_mExtDTreeCls, "node_impurity", node_impurity_cls, 3);
554
- rb_define_private_method(rb_mExtDTreeCls, "stop_growing?", check_same_label, 1);
555
- rb_define_private_method(rb_mExtDTreeReg, "node_impurity", node_impurity_reg, 2);
556
- }
@@ -1,12 +0,0 @@
1
- #ifndef RUMALE_TREE_EXT_H
2
- #define RUMALE_TREE_EXT_H 1
3
-
4
- #include <math.h>
5
- #include <string.h>
6
-
7
- #include <ruby.h>
8
-
9
- #include <numo/narray.h>
10
- #include <numo/template.h>
11
-
12
- #endif /* RUMALE_TREE_EXT_H */