rumale-tree 0.26.0 → 0.28.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/rumale/tree/ext.cpp +39 -0
- data/ext/rumale/tree/ext.hpp +550 -0
- data/ext/rumale/tree/extconf.rb +2 -0
- data/lib/rumale/tree/base_decision_tree.rb +1 -5
- data/lib/rumale/tree/decision_tree_regressor.rb +6 -3
- data/lib/rumale/tree/version.rb +1 -1
- metadata +7 -7
- data/ext/rumale/tree/ext.c +0 -575
- data/ext/rumale/tree/ext.h +0 -12
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: a437c3b879dd51d2e823f1851ca8b350d0441947fb6cefe727de83db78b4e6d9
|
4
|
+
data.tar.gz: 519665f2baea649ec31c9d5529aa05c67a880daa8ab60149433a53483f630663
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: aa68d3e2fbe99ef4c8d15a10f64098abfbbcdce9047053967e578edad688fa08a03a1ac4e4d23dbd5e67a4a23370057aa9462107bb961eb4d3f12782ce0c0f18
|
7
|
+
data.tar.gz: 6a6e6f167e31d56051768a881a4db0da7ef02f36977cb3969e8160bad385757153b1f6be32dd364ed3f67ed4672dd0c1bfd14cc80e0b033e00dd282d757955a9
|
data/ext/rumale/tree/ext.cpp
ADDED
@@ -0,0 +1,39 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) 2022-2023 Atsushi Tatsuma
|
3
|
+
* All rights reserved.
|
4
|
+
*
|
5
|
+
* Redistribution and use in source and binary forms, with or without
|
6
|
+
* modification, are permitted provided that the following conditions are met:
|
7
|
+
*
|
8
|
+
* * Redistributions of source code must retain the above copyright notice, this
|
9
|
+
* list of conditions and the following disclaimer.
|
10
|
+
*
|
11
|
+
* * Redistributions in binary form must reproduce the above copyright notice,
|
12
|
+
* this list of conditions and the following disclaimer in the documentation
|
13
|
+
* and/or other materials provided with the distribution.
|
14
|
+
*
|
15
|
+
* * Neither the name of the copyright holder nor the names of its
|
16
|
+
* contributors may be used to endorse or promote products derived from
|
17
|
+
* this software without specific prior written permission.
|
18
|
+
*
|
19
|
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
20
|
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
21
|
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
22
|
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
23
|
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
24
|
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
25
|
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
26
|
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
27
|
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
28
|
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
29
|
+
*/
|
30
|
+
|
31
|
+
#include "ext.hpp"
|
32
|
+
|
33
|
+
extern "C" void Init_ext(void) {
|
34
|
+
VALUE rb_mRumale = rb_define_module("Rumale");
|
35
|
+
VALUE rb_mTree = rb_define_module_under(rb_mRumale, "Tree");
|
36
|
+
ExtDecisionTreeClassifier::define_module(rb_mTree);
|
37
|
+
ExtDecisionTreeRegressor::define_module(rb_mTree);
|
38
|
+
ExtGradientTreeRegressor::define_module(rb_mTree);
|
39
|
+
}
|
data/ext/rumale/tree/ext.hpp
ADDED
@@ -0,0 +1,550 @@
|
|
1
|
+
/**
|
2
|
+
* Copyright (c) 2022-2023 Atsushi Tatsuma
|
3
|
+
* All rights reserved.
|
4
|
+
*
|
5
|
+
* Redistribution and use in source and binary forms, with or without
|
6
|
+
* modification, are permitted provided that the following conditions are met:
|
7
|
+
*
|
8
|
+
* * Redistributions of source code must retain the above copyright notice, this
|
9
|
+
* list of conditions and the following disclaimer.
|
10
|
+
*
|
11
|
+
* * Redistributions in binary form must reproduce the above copyright notice,
|
12
|
+
* this list of conditions and the following disclaimer in the documentation
|
13
|
+
* and/or other materials provided with the distribution.
|
14
|
+
*
|
15
|
+
* * Neither the name of the copyright holder nor the names of its
|
16
|
+
* contributors may be used to endorse or promote products derived from
|
17
|
+
* this software without specific prior written permission.
|
18
|
+
*
|
19
|
+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
20
|
+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
21
|
+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
22
|
+
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
23
|
+
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
24
|
+
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
25
|
+
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
26
|
+
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
27
|
+
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
28
|
+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
29
|
+
*/
|
30
|
+
|
31
|
+
#ifndef RUMALE_TREE_EXT_HPP
|
32
|
+
#define RUMALE_TREE_EXT_HPP 1
|
33
|
+
|
34
|
+
#include <cmath>
|
35
|
+
#include <limits>
|
36
|
+
#include <string>
|
37
|
+
#include <vector>
|
38
|
+
|
39
|
+
#include <ruby.h>
|
40
|
+
|
41
|
+
#include <numo/narray.h>
|
42
|
+
#include <numo/template.h>
|
43
|
+
|
44
|
+
/**
|
45
|
+
* @!visibility private
|
46
|
+
* Document-module: Rumale::Tree::ExtDecisionTreeClassifier
|
47
|
+
* The mixin module consisting of extension method for DecisionTreeClassifier class.
|
48
|
+
* This module is used internally.
|
49
|
+
*/
|
50
|
+
class ExtDecisionTreeClassifier {
|
51
|
+
public:
|
52
|
+
static void define_module(VALUE& outer) {
|
53
|
+
VALUE rb_mExtDTreeCls = rb_define_module_under(outer, "ExtDecisionTreeClassifier");
|
54
|
+
rb_define_private_method(rb_mExtDTreeCls, "find_split_params", find_split_params_, 6);
|
55
|
+
rb_define_private_method(rb_mExtDTreeCls, "node_impurity", node_impurity_, 3);
|
56
|
+
rb_define_private_method(rb_mExtDTreeCls, "stop_growing?", check_same_label_, 1);
|
57
|
+
}
|
58
|
+
|
59
|
+
private:
|
60
|
+
static double calc_impurity_(const std::string& criterion, const std::vector<size_t>& histogram, const size_t& n_elements, const size_t& n_classes) {
|
61
|
+
double impurity = 0.0;
|
62
|
+
if (criterion == "entropy") {
|
63
|
+
double entropy = 0.0;
|
64
|
+
for (size_t i = 0; i < n_classes; i++) {
|
65
|
+
const double el = static_cast<double>(histogram[i]) / static_cast<double>(n_elements);
|
66
|
+
entropy += el * std::log(el + 1.0);
|
67
|
+
}
|
68
|
+
impurity = -entropy;
|
69
|
+
} else {
|
70
|
+
double gini = 0.0;
|
71
|
+
for (size_t i = 0; i < n_classes; i++) {
|
72
|
+
const double el = static_cast<double>(histogram[i]) / static_cast<double>(n_elements);
|
73
|
+
gini += el * el;
|
74
|
+
}
|
75
|
+
impurity = 1.0 - gini;
|
76
|
+
}
|
77
|
+
return impurity;
|
78
|
+
}
|
79
|
+
|
80
|
+
/**
|
81
|
+
* @!visibility private
|
82
|
+
* Find for split point with maximum information gain.
|
83
|
+
*
|
84
|
+
* @overload find_split_params(criterion, impurity, order, features, labels, n_classes) -> Array<Float>
|
85
|
+
*
|
86
|
+
* @param criterion [String] The function to evaluate spliting point. Supported criteria are 'gini' and 'entropy'.
|
87
|
+
* @param impurity [Float] The impurity of whole dataset.
|
88
|
+
* @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
|
89
|
+
* @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
|
90
|
+
* @param labels [Numo::Int32] (shape: [n_elements]) The labels.
|
91
|
+
* @param n_classes [Integer] The number of classes.
|
92
|
+
* @return [Array<Float>] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
|
93
|
+
*/
|
94
|
+
|
95
|
+
struct FindSplitParamsOpts_ {
|
96
|
+
std::string criterion;
|
97
|
+
size_t n_classes;
|
98
|
+
double impurity;
|
99
|
+
};
|
100
|
+
|
101
|
+
static void iter_find_split_params_(na_loop_t const* lp) {
|
102
|
+
// Obtain iteration variables.
|
103
|
+
const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
|
104
|
+
const size_t n_elements = NDL_SHAPE(lp, 0)[0];
|
105
|
+
const double* f = (double*)NDL_PTR(lp, 1);
|
106
|
+
const int32_t* y = (int32_t*)NDL_PTR(lp, 2);
|
107
|
+
const std::string criterion = ((FindSplitParamsOpts_*)lp->opt_ptr)->criterion;
|
108
|
+
const size_t n_classes = ((FindSplitParamsOpts_*)lp->opt_ptr)->n_classes;
|
109
|
+
const double w_impurity = ((FindSplitParamsOpts_*)lp->opt_ptr)->impurity;
|
110
|
+
|
111
|
+
// Initialize output optimal parameters.
|
112
|
+
double* params = (double*)NDL_PTR(lp, 3);
|
113
|
+
params[0] = 0.0; // left impurity
|
114
|
+
params[1] = w_impurity; // right impurity
|
115
|
+
params[2] = f[o[0]]; // threshold
|
116
|
+
params[3] = 0.0; // gain
|
117
|
+
|
118
|
+
// Initialize child node variables.
|
119
|
+
std::vector<size_t> r_histogram(n_classes, 0);
|
120
|
+
for (size_t i = 0; i < n_elements; i++) r_histogram[y[o[i]]] += 1;
|
121
|
+
|
122
|
+
// Find optimal parameters.
|
123
|
+
size_t curr_pos = 0;
|
124
|
+
size_t next_pos = 0;
|
125
|
+
size_t n_l_elements = 0;
|
126
|
+
size_t n_r_elements = n_elements;
|
127
|
+
double curr_el = f[o[0]];
|
128
|
+
const double last_el = f[o[n_elements - 1]];
|
129
|
+
std::vector<size_t> l_histogram(n_classes, 0);
|
130
|
+
while (curr_pos < n_elements && curr_el != last_el) {
|
131
|
+
double next_el = f[o[next_pos]];
|
132
|
+
while (next_pos < n_elements && next_el == curr_el) {
|
133
|
+
l_histogram[y[o[next_pos]]] += 1;
|
134
|
+
n_l_elements++;
|
135
|
+
r_histogram[y[o[next_pos]]] -= 1;
|
136
|
+
n_r_elements--;
|
137
|
+
next_pos++;
|
138
|
+
next_el = f[o[next_pos]];
|
139
|
+
}
|
140
|
+
// Calculate gain of new split.
|
141
|
+
const double l_impurity = calc_impurity_(criterion, l_histogram, n_l_elements, n_classes);
|
142
|
+
const double r_impurity = calc_impurity_(criterion, r_histogram, n_r_elements, n_classes);
|
143
|
+
const double gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / static_cast<double>(n_elements);
|
144
|
+
// Update optimal parameters.
|
145
|
+
if (gain > params[3]) {
|
146
|
+
params[0] = l_impurity;
|
147
|
+
params[1] = r_impurity;
|
148
|
+
params[2] = 0.5 * (curr_el + next_el);
|
149
|
+
params[3] = gain;
|
150
|
+
}
|
151
|
+
if (next_pos == n_elements) break;
|
152
|
+
curr_pos = next_pos;
|
153
|
+
curr_el = f[o[curr_pos]];
|
154
|
+
}
|
155
|
+
}
|
156
|
+
|
157
|
+
static VALUE find_split_params_(VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE labels, VALUE n_classes) {
|
158
|
+
ndfunc_arg_in_t ain[3] = { { numo_cInt32, 1 }, { numo_cDFloat, 1 }, { numo_cInt32, 1 } };
|
159
|
+
size_t out_shape[1] = { 4 };
|
160
|
+
ndfunc_arg_out_t aout[1] = { { numo_cDFloat, 1, out_shape } };
|
161
|
+
ndfunc_t ndf = { (na_iter_func_t)iter_find_split_params_, NO_LOOP, 3, 1, ain, aout };
|
162
|
+
FindSplitParamsOpts_ opts = { std::string(StringValueCStr(criterion)), NUM2SIZET(n_classes), NUM2DBL(impurity) };
|
163
|
+
VALUE params = na_ndloop3(&ndf, &opts, 3, order, features, labels);
|
164
|
+
RB_GC_GUARD(criterion);
|
165
|
+
return params;
|
166
|
+
}
|
167
|
+
|
168
|
+
/**
|
169
|
+
* @!visibility private
|
170
|
+
* Calculate impurity based on criterion.
|
171
|
+
*
|
172
|
+
* @overload node_impurity(criterion, y, n_classes) -> Float
|
173
|
+
*
|
174
|
+
* @param criterion [String] The function to calculate impurity. Supported criteria are 'gini' and 'entropy'.
|
175
|
+
* @param y [Numo::Int32] (shape: [n_samples]) The labels.
|
176
|
+
* @param n_classes [Integer] The number of classes.
|
177
|
+
* @return [Float] impurity
|
178
|
+
*/
|
179
|
+
|
180
|
+
struct NodeImpurityOpts_ {
|
181
|
+
std::string criterion;
|
182
|
+
size_t n_classes;
|
183
|
+
};
|
184
|
+
|
185
|
+
static void iter_node_impurity_(na_loop_t const* lp) {
|
186
|
+
const int32_t* y = (int32_t*)NDL_PTR(lp, 0);
|
187
|
+
const size_t n_elements = NDL_SHAPE(lp, 0)[0];
|
188
|
+
const std::string criterion = ((NodeImpurityOpts_*)lp->opt_ptr)->criterion;
|
189
|
+
const size_t n_classes = ((NodeImpurityOpts_*)lp->opt_ptr)->n_classes;
|
190
|
+
double* ret = (double*)NDL_PTR(lp, 1);
|
191
|
+
std::vector<size_t> histogram(n_classes, 0);
|
192
|
+
for (size_t i = 0; i < n_elements; i++) histogram[y[i]] += 1;
|
193
|
+
*ret = calc_impurity_(criterion, histogram, n_elements, n_classes);
|
194
|
+
}
|
195
|
+
|
196
|
+
static VALUE node_impurity_(VALUE self, VALUE criterion, VALUE y, VALUE n_classes) {
|
197
|
+
ndfunc_arg_in_t ain[1] = { { numo_cInt32, 1 } };
|
198
|
+
ndfunc_arg_out_t aout[1] = { { numo_cDFloat, 0 } };
|
199
|
+
ndfunc_t ndf = { (na_iter_func_t)iter_node_impurity_, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
|
200
|
+
NodeImpurityOpts_ opts = { std::string(StringValueCStr(criterion)), NUM2SIZET(n_classes) };
|
201
|
+
VALUE ret = na_ndloop3(&ndf, &opts, 1, y);
|
202
|
+
RB_GC_GUARD(criterion);
|
203
|
+
return ret;
|
204
|
+
}
|
205
|
+
|
206
|
+
/**
|
207
|
+
* @!visibility private
|
208
|
+
* Check all elements have the same value.
|
209
|
+
*
|
210
|
+
* @overload check_same_label(y) -> Boolean
|
211
|
+
*
|
212
|
+
* @param y [Numo::Int32] (shape: [n_samples]) The labels.
|
213
|
+
* @return [Boolean]
|
214
|
+
*/
|
215
|
+
|
216
|
+
static void iter_check_same_label_(na_loop_t const* lp) {
|
217
|
+
const int32_t* y = (int32_t*)NDL_PTR(lp, 0);
|
218
|
+
const size_t n_elements = NDL_SHAPE(lp, 0)[0];
|
219
|
+
VALUE* ret = (VALUE*)NDL_PTR(lp, 1);
|
220
|
+
*ret = Qtrue;
|
221
|
+
if (n_elements > 0) {
|
222
|
+
int32_t label = y[0];
|
223
|
+
for (size_t i = 0; i < n_elements; i++) {
|
224
|
+
if (y[i] != label) {
|
225
|
+
*ret = Qfalse;
|
226
|
+
break;
|
227
|
+
}
|
228
|
+
}
|
229
|
+
}
|
230
|
+
}
|
231
|
+
|
232
|
+
static VALUE check_same_label_(VALUE self, VALUE y) {
|
233
|
+
ndfunc_arg_in_t ain[1] = { { numo_cInt32, 1 } };
|
234
|
+
ndfunc_arg_out_t aout[1] = { { numo_cRObject, 0 } };
|
235
|
+
ndfunc_t ndf = { (na_iter_func_t)iter_check_same_label_, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
|
236
|
+
return na_ndloop(&ndf, 1, y);
|
237
|
+
}
|
238
|
+
};
|
239
|
+
|
240
|
+
/**
|
241
|
+
* @!visibility private
|
242
|
+
* Document-module: Rumale::Tree::ExtDecisionTreeRegressor
|
243
|
+
* The mixin module consisting of extension method for DecisionTreeRegressor class.
|
244
|
+
* This module is used internally.
|
245
|
+
*/
|
246
|
+
class ExtDecisionTreeRegressor {
|
247
|
+
public:
|
248
|
+
static void define_module(VALUE& outer) {
|
249
|
+
VALUE rb_mExtDTreeReg = rb_define_module_under(outer, "ExtDecisionTreeRegressor");
|
250
|
+
rb_define_private_method(rb_mExtDTreeReg, "find_split_params", find_split_params_, 5);
|
251
|
+
rb_define_private_method(rb_mExtDTreeReg, "node_impurity", node_impurity_, 2);
|
252
|
+
rb_define_private_method(rb_mExtDTreeReg, "stop_growing?", check_same_value_, 1);
|
253
|
+
}
|
254
|
+
|
255
|
+
private:
|
256
|
+
static double calc_impurity_(const std::string& criterion, const int32_t* order, const double* vecs, const double* mean_vec,
|
257
|
+
const size_t& n_elements, const size_t& n_outputs, const size_t& order_offset) {
|
258
|
+
const bool is_mae = criterion == "mae";
|
259
|
+
double sum_err = 0.0;
|
260
|
+
for (size_t i = 0; i < n_elements; i++) {
|
261
|
+
double err = 0.0;
|
262
|
+
for (size_t j = 0; j < n_outputs; j++) {
|
263
|
+
const double el = vecs[order[order_offset + i] * n_outputs + j] - mean_vec[j];
|
264
|
+
err += is_mae ? std::fabs(el) : el * el;
|
265
|
+
}
|
266
|
+
err /= static_cast<double>(n_outputs);
|
267
|
+
sum_err += err;
|
268
|
+
}
|
269
|
+
const double impurity = sum_err / static_cast<double>(n_elements);
|
270
|
+
return impurity;
|
271
|
+
}
|
272
|
+
|
273
|
+
/**
|
274
|
+
* @!visibility private
|
275
|
+
* Find for split point with maximum information gain.
|
276
|
+
*
|
277
|
+
* @overload find_split_params(criterion, impurity, order, features, targets) -> Array<Float>
|
278
|
+
*
|
279
|
+
* @param criterion [String] The function to evaluate spliting point. Supported criteria are 'mae' and 'mse'.
|
280
|
+
* @param impurity [Float] The impurity of whole dataset.
|
281
|
+
* @param order [Numo::Int32] (shape: [n_samples]) The element indices sorted according to feature values in ascending order.
|
282
|
+
* @param features [Numo::DFloat] (shape: [n_samples]) The feature values.
|
283
|
+
* @param targets [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values.
|
284
|
+
* @return [Array<Float>] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
|
285
|
+
*/
|
286
|
+
|
287
|
+
struct FindSplitParamsOpts_ {
|
288
|
+
std::string criterion;
|
289
|
+
double impurity;
|
290
|
+
};
|
291
|
+
|
292
|
+
static void iter_find_split_params_(na_loop_t const* lp) {
|
293
|
+
// Obtain iteration variables.
|
294
|
+
const int32_t* order = (int32_t*)NDL_PTR(lp, 0);
|
295
|
+
const size_t n_elements = NDL_SHAPE(lp, 0)[0];
|
296
|
+
const double* f = (double*)NDL_PTR(lp, 1);
|
297
|
+
const double* y = (double*)NDL_PTR(lp, 2);
|
298
|
+
const size_t n_outputs = NDL_SHAPE(lp, 2)[1];
|
299
|
+
const std::string criterion = ((FindSplitParamsOpts_*)lp->opt_ptr)->criterion;
|
300
|
+
const double w_impurity = ((FindSplitParamsOpts_*)lp->opt_ptr)->impurity;
|
301
|
+
|
302
|
+
// Initialize optimal parameters.
|
303
|
+
double* params = (double*)NDL_PTR(lp, 3);
|
304
|
+
params[0] = 0.0; // left impurity
|
305
|
+
params[1] = w_impurity; // right impurity
|
306
|
+
params[2] = f[order[0]]; // threshold
|
307
|
+
params[3] = 0.0; // gain
|
308
|
+
|
309
|
+
// Initialize child node variables.
|
310
|
+
std::vector<double> l_sum_y(n_outputs, 0);
|
311
|
+
std::vector<double> r_sum_y(n_outputs, 0);
|
312
|
+
for (size_t i = 0; i < n_elements; i++) {
|
313
|
+
for (size_t j = 0; j < n_outputs; j++) {
|
314
|
+
r_sum_y[j] += y[order[i] * n_outputs + j];
|
315
|
+
}
|
316
|
+
}
|
317
|
+
|
318
|
+
// Find optimal parameters.
|
319
|
+
size_t curr_pos = 0;
|
320
|
+
size_t next_pos = 0;
|
321
|
+
size_t n_l_elements = 0;
|
322
|
+
size_t n_r_elements = n_elements;
|
323
|
+
std::vector<double> l_mean_y(n_outputs, 0);
|
324
|
+
std::vector<double> r_mean_y(n_outputs, 0);
|
325
|
+
double curr_el = f[order[0]];
|
326
|
+
const double last_el = f[order[n_elements - 1]];
|
327
|
+
while (curr_pos < n_elements && curr_el != last_el) {
|
328
|
+
double next_el = f[order[next_pos]];
|
329
|
+
while (next_pos < n_elements && next_el == curr_el) {
|
330
|
+
for (size_t j = 0; j < n_outputs; j++) {
|
331
|
+
l_sum_y[j] += y[order[next_pos] * n_outputs + j];
|
332
|
+
r_sum_y[j] -= y[order[next_pos] * n_outputs + j];
|
333
|
+
}
|
334
|
+
n_l_elements++;
|
335
|
+
n_r_elements--;
|
336
|
+
next_pos++;
|
337
|
+
next_el = f[order[next_pos]];
|
338
|
+
}
|
339
|
+
// Calculate gain of new split.
|
340
|
+
for (size_t j = 0; j < n_outputs; j++) {
|
341
|
+
l_mean_y[j] = l_sum_y[j] / static_cast<double>(n_l_elements);
|
342
|
+
r_mean_y[j] = r_sum_y[j] / static_cast<double>(n_r_elements);
|
343
|
+
}
|
344
|
+
const double l_impurity = calc_impurity_(criterion, order, y, l_mean_y.data(), n_l_elements, n_outputs, 0);
|
345
|
+
const double r_impurity = calc_impurity_(criterion, order, y, r_mean_y.data(), n_r_elements, n_outputs, next_pos);
|
346
|
+
const double gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / static_cast<double>(n_elements);
|
347
|
+
// Update optimal parameters.
|
348
|
+
if (gain > params[3]) {
|
349
|
+
params[0] = l_impurity;
|
350
|
+
params[1] = r_impurity;
|
351
|
+
params[2] = 0.5 * (curr_el + next_el);
|
352
|
+
params[3] = gain;
|
353
|
+
}
|
354
|
+
if (next_pos == n_elements) break;
|
355
|
+
curr_pos = next_pos;
|
356
|
+
curr_el = f[order[curr_pos]];
|
357
|
+
}
|
358
|
+
}
|
359
|
+
|
360
|
+
static VALUE find_split_params_(VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE targets) {
|
361
|
+
ndfunc_arg_in_t ain[3] = { { numo_cInt32, 1 }, { numo_cDFloat, 1 }, { numo_cDFloat, 2 } };
|
362
|
+
size_t out_shape[1] = { 4 };
|
363
|
+
ndfunc_arg_out_t aout[1] = { { numo_cDFloat, 1, out_shape } };
|
364
|
+
ndfunc_t ndf = { (na_iter_func_t)iter_find_split_params_, NO_LOOP, 3, 1, ain, aout };
|
365
|
+
FindSplitParamsOpts_ opts = { std::string(StringValueCStr(criterion)), NUM2DBL(impurity) };
|
366
|
+
VALUE params = na_ndloop3(&ndf, &opts, 3, order, features, targets);
|
367
|
+
RB_GC_GUARD(criterion);
|
368
|
+
return params;
|
369
|
+
}
|
370
|
+
|
371
|
+
/**
|
372
|
+
* @!visibility private
|
373
|
+
* Calculate impurity based on criterion.
|
374
|
+
*
|
375
|
+
* @overload node_impurity(criterion, y) -> Float
|
376
|
+
*
|
377
|
+
* @param criterion [String] The function to calculate impurity. Supported criteria are 'mae' and 'mse'.
|
378
|
+
* @param y [Array<Float>] (shape: [n_samples, n_outputs]) The taget values.
|
379
|
+
* @return [Float] impurity
|
380
|
+
*/
|
381
|
+
|
382
|
+
struct NodeImpurityOpts_ {
|
383
|
+
std::string criterion;
|
384
|
+
};
|
385
|
+
|
386
|
+
static void iter_node_impurity_(na_loop_t const* lp) {
|
387
|
+
const double* y = (double*)NDL_PTR(lp, 0);
|
388
|
+
const size_t n_elements = NDL_SHAPE(lp, 0)[0];
|
389
|
+
const size_t n_outputs = NDL_SHAPE(lp, 0)[1];
|
390
|
+
const std::string criterion = ((NodeImpurityOpts_*)lp->opt_ptr)->criterion;
|
391
|
+
|
392
|
+
std::vector<int32_t> order(n_elements);
|
393
|
+
std::vector<double> mean_y(n_outputs, 0);
|
394
|
+
for (size_t i = 0; i < n_elements; i++) {
|
395
|
+
order[i] = static_cast<int32_t>(i);
|
396
|
+
for (size_t j = 0; j < n_outputs; j++) {
|
397
|
+
mean_y[j] += y[i * n_outputs + j];
|
398
|
+
}
|
399
|
+
}
|
400
|
+
for (size_t j = 0; j < n_outputs; j++) {
|
401
|
+
mean_y[j] /= static_cast<double>(n_elements);
|
402
|
+
}
|
403
|
+
|
404
|
+
double* ret = (double*)NDL_PTR(lp, 1);
|
405
|
+
*ret = calc_impurity_(criterion, order.data(), y, mean_y.data(), n_elements, n_outputs, 0);
|
406
|
+
}
|
407
|
+
|
408
|
+
static VALUE node_impurity_(VALUE self, VALUE criterion, VALUE y) {
|
409
|
+
ndfunc_arg_in_t ain[1] = { { numo_cDFloat, 2 } };
|
410
|
+
ndfunc_arg_out_t aout[1] = { { numo_cDFloat, 0 } };
|
411
|
+
ndfunc_t ndf = { (na_iter_func_t)iter_node_impurity_, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
|
412
|
+
NodeImpurityOpts_ opts = { std::string(StringValueCStr(criterion)) };
|
413
|
+
VALUE ret = na_ndloop3(&ndf, &opts, 1, y);
|
414
|
+
RB_GC_GUARD(criterion);
|
415
|
+
return ret;
|
416
|
+
}
|
417
|
+
|
418
|
+
/**
|
419
|
+
* @!visibility private
|
420
|
+
* Check all elements have the same value/vector.
|
421
|
+
*
|
422
|
+
* @overload check_same_value(y) -> Boolean
|
423
|
+
*
|
424
|
+
* @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values.
|
425
|
+
* @return [Boolean]
|
426
|
+
*/
|
427
|
+
|
428
|
+
static void iter_check_same_value_(na_loop_t const* lp) {
|
429
|
+
const double* y = (double*)NDL_PTR(lp, 0);
|
430
|
+
const size_t n_elements = NDL_SHAPE(lp, 0)[0];
|
431
|
+
const size_t n_outputs = NDL_SHAPE(lp, 0)[1];
|
432
|
+
VALUE* ret = (VALUE*)NDL_PTR(lp, 1);
|
433
|
+
const double eps = std::numeric_limits<double>::epsilon();
|
434
|
+
*ret = Qtrue;
|
435
|
+
if (n_elements > 0) {
|
436
|
+
for (size_t i = 1; i < n_elements; i++) {
|
437
|
+
for (size_t j = 0; j < n_outputs; j++) {
|
438
|
+
if (std::abs(y[i * n_outputs + j] - y[j]) > eps) {
|
439
|
+
*ret = Qfalse;
|
440
|
+
break;
|
441
|
+
}
|
442
|
+
}
|
443
|
+
if (*ret == Qfalse) break;
|
444
|
+
}
|
445
|
+
}
|
446
|
+
}
|
447
|
+
|
448
|
+
static VALUE check_same_value_(VALUE self, VALUE y) {
|
449
|
+
ndfunc_arg_in_t ain[1] = { { numo_cDFloat, 2 } };
|
450
|
+
ndfunc_arg_out_t aout[1] = { { numo_cRObject, 0 } };
|
451
|
+
ndfunc_t ndf = { (na_iter_func_t)iter_check_same_value_, NO_LOOP | NDF_EXTRACT, 1, 1, ain, aout };
|
452
|
+
return na_ndloop(&ndf, 1, y);
|
453
|
+
}
|
454
|
+
};
|
455
|
+
|
456
|
+
/**
|
457
|
+
* @!visibility private
|
458
|
+
* Document-module: Rumale::Tree::ExtGradientTreeRegressor
|
459
|
+
* The mixin module consisting of extension method for GradientTreeRegressor class.
|
460
|
+
* This module is used internally.
|
461
|
+
*/
|
462
|
+
class ExtGradientTreeRegressor {
|
463
|
+
public:
|
464
|
+
static void define_module(VALUE& outer) {
|
465
|
+
VALUE rb_mExtGTreeReg = rb_define_module_under(outer, "ExtGradientTreeRegressor");
|
466
|
+
rb_define_private_method(rb_mExtGTreeReg, "find_split_params", find_split_params_, 7);
|
467
|
+
}
|
468
|
+
|
469
|
+
private:
|
470
|
+
/**
|
471
|
+
* @!visibility private
|
472
|
+
* Find for split point with maximum information gain.
|
473
|
+
*
|
474
|
+
* @overload find_split_params(order, features, gradients, hessians, sum_gradient, sum_hessian, reg_lambda) -> Array<Float>
|
475
|
+
* @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
|
476
|
+
* @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
|
477
|
+
* @param gradients [Numo::DFloat] (shape: [n_elements]) The gradient values.
|
478
|
+
* @param hessians [Numo::DFloat] (shape: [n_elements]) The hessian values.
|
479
|
+
* @param sum_gradient [Float] The sum of gradient values.
|
480
|
+
* @param sum_hessian [Float] The sum of hessian values.
|
481
|
+
* @param reg_lambda [Float] The L2 regularization term on weight.
|
482
|
+
* @return [Array<Float>] The array consists of optimal parameters including threshold and gain.
|
483
|
+
*/
|
484
|
+
|
485
|
+
struct FindSplitParamsOpts_ {
|
486
|
+
double sum_gradient;
|
487
|
+
double sum_hessian;
|
488
|
+
double reg_lambda;
|
489
|
+
};
|
490
|
+
|
491
|
+
static void iter_find_split_params_(na_loop_t const* lp) {
|
492
|
+
// Obtain iteration variables.
|
493
|
+
const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
|
494
|
+
const size_t n_elements = NDL_SHAPE(lp, 0)[0];
|
495
|
+
const double* f = (double*)NDL_PTR(lp, 1);
|
496
|
+
const double* g = (double*)NDL_PTR(lp, 2);
|
497
|
+
const double* h = (double*)NDL_PTR(lp, 3);
|
498
|
+
const double s_grad = ((FindSplitParamsOpts_*)lp->opt_ptr)->sum_gradient;
|
499
|
+
const double s_hess = ((FindSplitParamsOpts_*)lp->opt_ptr)->sum_hessian;
|
500
|
+
const double reg_lambda = ((FindSplitParamsOpts_*)lp->opt_ptr)->reg_lambda;
|
501
|
+
|
502
|
+
// Find optimal parameters.
|
503
|
+
size_t curr_pos = 0;
|
504
|
+
size_t next_pos = 0;
|
505
|
+
double curr_el = f[o[0]];
|
506
|
+
const double last_el = f[o[n_elements - 1]];
|
507
|
+
double l_grad = 0.0;
|
508
|
+
double l_hess = 0.0;
|
509
|
+
double threshold = curr_el;
|
510
|
+
double gain_max = 0.0;
|
511
|
+
while (curr_pos < n_elements && curr_el != last_el) {
|
512
|
+
double next_el = f[o[next_pos]];
|
513
|
+
while (next_pos < n_elements && next_el == curr_el) {
|
514
|
+
l_grad += g[o[next_pos]];
|
515
|
+
l_hess += h[o[next_pos]];
|
516
|
+
next_pos++;
|
517
|
+
next_el = f[o[next_pos]];
|
518
|
+
}
|
519
|
+
// Calculate gain of new split.
|
520
|
+
const double r_grad = s_grad - l_grad;
|
521
|
+
const double r_hess = s_hess - l_hess;
|
522
|
+
const double gain = (l_grad * l_grad) / (l_hess + reg_lambda) + (r_grad * r_grad) / (r_hess + reg_lambda) - (s_grad * s_grad) / (s_hess + reg_lambda);
|
523
|
+
// Update optimal parameters.
|
524
|
+
if (gain > gain_max) {
|
525
|
+
threshold = 0.5 * (curr_el + next_el);
|
526
|
+
gain_max = gain;
|
527
|
+
}
|
528
|
+
if (next_pos == n_elements) break;
|
529
|
+
curr_pos = next_pos;
|
530
|
+
curr_el = f[o[curr_pos]];
|
531
|
+
}
|
532
|
+
|
533
|
+
double* params = (double*)NDL_PTR(lp, 4);
|
534
|
+
params[0] = threshold;
|
535
|
+
params[1] = gain_max;
|
536
|
+
}
|
537
|
+
|
538
|
+
static VALUE find_split_params_(VALUE self, VALUE order, VALUE features, VALUE gradients, VALUE hessians,
|
539
|
+
VALUE sum_gradient, VALUE sum_hessian, VALUE reg_lambda) {
|
540
|
+
ndfunc_arg_in_t ain[4] = { { numo_cInt32, 1 }, { numo_cDFloat, 1 }, { numo_cDFloat, 1 }, { numo_cDFloat, 1 } };
|
541
|
+
size_t out_shape[1] = { 2 };
|
542
|
+
ndfunc_arg_out_t aout[1] = { { numo_cDFloat, 1, out_shape } };
|
543
|
+
ndfunc_t ndf = { (na_iter_func_t)iter_find_split_params_, NO_LOOP, 4, 1, ain, aout };
|
544
|
+
FindSplitParamsOpts_ opts = { NUM2DBL(sum_gradient), NUM2DBL(sum_hessian), NUM2DBL(reg_lambda) };
|
545
|
+
VALUE params = na_ndloop3(&ndf, &opts, 4, order, features, gradients, hessians);
|
546
|
+
return params;
|
547
|
+
}
|
548
|
+
};
|
549
|
+
|
550
|
+
#endif /* RUMALE_TREE_EXT_HPP */
|
data/ext/rumale/tree/extconf.rb
CHANGED
data/lib/rumale/tree/base_decision_tree.rb
CHANGED
@@ -63,11 +63,7 @@ module Rumale
|
|
63
63
|
end
|
64
64
|
|
65
65
|
def build_tree(x, y)
|
66
|
-
|
67
|
-
@feature_ids = Array.new(x.shape[1]) { |v| v }
|
68
|
-
@tree = grow_node(0, x, y, impurity(y))
|
69
|
-
@feature_ids = nil
|
70
|
-
nil
|
66
|
+
raise NotImplementedError, "#{__method__} has to be implemented in #{self.class}."
|
71
67
|
end
|
72
68
|
|
73
69
|
def grow_node(depth, x, y, impurity) # rubocop:disable Metrics/AbcSize, Metrics/PerceivedComplexity
|
data/lib/rumale/tree/decision_tree_regressor.rb
CHANGED
@@ -68,6 +68,7 @@ module Rumale
|
|
68
68
|
@params[:max_features] = [@params[:max_features], n_features].min
|
69
69
|
@n_leaves = 0
|
70
70
|
@leaf_values = []
|
71
|
+
@feature_ids = Array.new(x.shape[1]) { |v| v }
|
71
72
|
@sub_rng = @rng.dup
|
72
73
|
build_tree(x, y)
|
73
74
|
eval_importance(n_samples, n_features)
|
@@ -88,8 +89,10 @@ module Rumale
|
|
88
89
|
|
89
90
|
private
|
90
91
|
|
91
|
-
def
|
92
|
-
y.
|
92
|
+
def build_tree(x, y)
|
93
|
+
y = y.expand_dims(1).dup if y.shape[1].nil?
|
94
|
+
@tree = grow_node(0, x, y, impurity(y))
|
95
|
+
nil
|
93
96
|
end
|
94
97
|
|
95
98
|
def put_leaf(node, y)
|
@@ -106,7 +109,7 @@ module Rumale
|
|
106
109
|
end
|
107
110
|
|
108
111
|
def impurity(y)
|
109
|
-
node_impurity(@params[:criterion], y
|
112
|
+
node_impurity(@params[:criterion], y)
|
110
113
|
end
|
111
114
|
end
|
112
115
|
end
|
data/lib/rumale/tree/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rumale-tree
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.26.0
|
4
|
+
version: 0.28.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2023-
|
11
|
+
date: 2023-11-12 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -30,14 +30,14 @@ dependencies:
|
|
30
30
|
requirements:
|
31
31
|
- - "~>"
|
32
32
|
- !ruby/object:Gem::Version
|
33
|
-
version: 0.
|
33
|
+
version: 0.28.0
|
34
34
|
type: :runtime
|
35
35
|
prerelease: false
|
36
36
|
version_requirements: !ruby/object:Gem::Requirement
|
37
37
|
requirements:
|
38
38
|
- - "~>"
|
39
39
|
- !ruby/object:Gem::Version
|
40
|
-
version: 0.
|
40
|
+
version: 0.28.0
|
41
41
|
description: Rumale::Tree provides classifier and regression based on decision tree
|
42
42
|
algorithms with Rumale interface.
|
43
43
|
email:
|
@@ -49,8 +49,8 @@ extra_rdoc_files: []
|
|
49
49
|
files:
|
50
50
|
- LICENSE.txt
|
51
51
|
- README.md
|
52
|
-
- ext/rumale/tree/ext.c
|
53
|
-
- ext/rumale/tree/ext.h
|
52
|
+
- ext/rumale/tree/ext.cpp
|
53
|
+
- ext/rumale/tree/ext.hpp
|
54
54
|
- ext/rumale/tree/extconf.rb
|
55
55
|
- lib/rumale/tree.rb
|
56
56
|
- lib/rumale/tree/base_decision_tree.rb
|
@@ -85,7 +85,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
85
85
|
- !ruby/object:Gem::Version
|
86
86
|
version: '0'
|
87
87
|
requirements: []
|
88
|
-
rubygems_version: 3.
|
88
|
+
rubygems_version: 3.4.20
|
89
89
|
signing_key:
|
90
90
|
specification_version: 4
|
91
91
|
summary: Rumale::Tree provides classifier and regression based on decision tree algorithms
|
data/ext/rumale/tree/ext.c
DELETED
@@ -1,575 +0,0 @@
|
|
1
|
-
#include "ext.h"
|
2
|
-
|
3
|
-
double* alloc_dbl_array(const long n_dimensions) {
|
4
|
-
double* arr = ALLOC_N(double, n_dimensions);
|
5
|
-
memset(arr, 0, n_dimensions * sizeof(double));
|
6
|
-
return arr;
|
7
|
-
}
|
8
|
-
|
9
|
-
double calc_gini_coef(double* histogram, const long n_elements, const long n_classes) {
|
10
|
-
double gini = 0.0;
|
11
|
-
|
12
|
-
for (long i = 0; i < n_classes; i++) {
|
13
|
-
double el = histogram[i] / n_elements;
|
14
|
-
gini += el * el;
|
15
|
-
}
|
16
|
-
|
17
|
-
return 1.0 - gini;
|
18
|
-
}
|
19
|
-
|
20
|
-
double calc_entropy(double* histogram, const long n_elements, const long n_classes) {
|
21
|
-
double entropy = 0.0;
|
22
|
-
|
23
|
-
for (long i = 0; i < n_classes; i++) {
|
24
|
-
double el = histogram[i] / n_elements;
|
25
|
-
entropy += el * log(el + 1.0);
|
26
|
-
}
|
27
|
-
|
28
|
-
return -entropy;
|
29
|
-
}
|
30
|
-
|
31
|
-
VALUE
|
32
|
-
calc_mean_vec(double* sum_vec, const long n_dimensions, const long n_elements) {
|
33
|
-
VALUE mean_vec = rb_ary_new2(n_dimensions);
|
34
|
-
|
35
|
-
for (long i = 0; i < n_dimensions; i++) {
|
36
|
-
rb_ary_store(mean_vec, i, DBL2NUM(sum_vec[i] / n_elements));
|
37
|
-
}
|
38
|
-
|
39
|
-
return mean_vec;
|
40
|
-
}
|
41
|
-
|
42
|
-
double calc_vec_mae(VALUE vec_a, VALUE vec_b) {
|
43
|
-
const long n_dimensions = RARRAY_LEN(vec_a);
|
44
|
-
double sum = 0.0;
|
45
|
-
|
46
|
-
for (long i = 0; i < n_dimensions; i++) {
|
47
|
-
double diff = NUM2DBL(rb_ary_entry(vec_a, i)) - NUM2DBL(rb_ary_entry(vec_b, i));
|
48
|
-
sum += fabs(diff);
|
49
|
-
}
|
50
|
-
|
51
|
-
return sum / n_dimensions;
|
52
|
-
}
|
53
|
-
|
54
|
-
double calc_vec_mse(VALUE vec_a, VALUE vec_b) {
|
55
|
-
const long n_dimensions = RARRAY_LEN(vec_a);
|
56
|
-
double sum = 0.0;
|
57
|
-
|
58
|
-
for (long i = 0; i < n_dimensions; i++) {
|
59
|
-
double diff = NUM2DBL(rb_ary_entry(vec_a, i)) - NUM2DBL(rb_ary_entry(vec_b, i));
|
60
|
-
sum += diff * diff;
|
61
|
-
}
|
62
|
-
|
63
|
-
return sum / n_dimensions;
|
64
|
-
}
|
65
|
-
|
66
|
-
double calc_mae(VALUE target_vecs, VALUE mean_vec) {
|
67
|
-
const long n_elements = RARRAY_LEN(target_vecs);
|
68
|
-
double sum = 0.0;
|
69
|
-
|
70
|
-
for (long i = 0; i < n_elements; i++) {
|
71
|
-
sum += calc_vec_mae(rb_ary_entry(target_vecs, i), mean_vec);
|
72
|
-
}
|
73
|
-
|
74
|
-
return sum / n_elements;
|
75
|
-
}
|
76
|
-
|
77
|
-
double calc_mse(VALUE target_vecs, VALUE mean_vec) {
|
78
|
-
const long n_elements = RARRAY_LEN(target_vecs);
|
79
|
-
double sum = 0.0;
|
80
|
-
|
81
|
-
for (long i = 0; i < n_elements; i++) {
|
82
|
-
sum += calc_vec_mse(rb_ary_entry(target_vecs, i), mean_vec);
|
83
|
-
}
|
84
|
-
|
85
|
-
return sum / n_elements;
|
86
|
-
}
|
87
|
-
|
88
|
-
double calc_impurity_cls(const char* criterion, double* histogram, const long n_elements, const long n_classes) {
|
89
|
-
if (strcmp(criterion, "entropy") == 0) {
|
90
|
-
return calc_entropy(histogram, n_elements, n_classes);
|
91
|
-
}
|
92
|
-
return calc_gini_coef(histogram, n_elements, n_classes);
|
93
|
-
}
|
94
|
-
|
95
|
-
double calc_impurity_reg(const char* criterion, VALUE target_vecs, double* sum_vec) {
|
96
|
-
const long n_elements = RARRAY_LEN(target_vecs);
|
97
|
-
const long n_dimensions = RARRAY_LEN(rb_ary_entry(target_vecs, 0));
|
98
|
-
VALUE mean_vec = calc_mean_vec(sum_vec, n_dimensions, n_elements);
|
99
|
-
|
100
|
-
if (strcmp(criterion, "mae") == 0) {
|
101
|
-
return calc_mae(target_vecs, mean_vec);
|
102
|
-
}
|
103
|
-
return calc_mse(target_vecs, mean_vec);
|
104
|
-
}
|
105
|
-
|
106
|
-
void add_sum_vec(double* sum_vec, VALUE target) {
|
107
|
-
const long n_dimensions = RARRAY_LEN(target);
|
108
|
-
|
109
|
-
for (long i = 0; i < n_dimensions; i++) {
|
110
|
-
sum_vec[i] += NUM2DBL(rb_ary_entry(target, i));
|
111
|
-
}
|
112
|
-
}
|
113
|
-
|
114
|
-
void sub_sum_vec(double* sum_vec, VALUE target) {
|
115
|
-
const long n_dimensions = RARRAY_LEN(target);
|
116
|
-
|
117
|
-
for (long i = 0; i < n_dimensions; i++) {
|
118
|
-
sum_vec[i] -= NUM2DBL(rb_ary_entry(target, i));
|
119
|
-
}
|
120
|
-
}
|
121
|
-
|
122
|
-
/**
|
123
|
-
* @!visibility private
|
124
|
-
*/
|
125
|
-
typedef struct {
|
126
|
-
char* criterion;
|
127
|
-
long n_classes;
|
128
|
-
double impurity;
|
129
|
-
} split_opts_cls;
|
130
|
-
|
131
|
-
/**
|
132
|
-
* @!visibility private
|
133
|
-
*/
|
134
|
-
static void iter_find_split_params_cls(na_loop_t const* lp) {
|
135
|
-
const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
|
136
|
-
const double* f = (double*)NDL_PTR(lp, 1);
|
137
|
-
const int32_t* y = (int32_t*)NDL_PTR(lp, 2);
|
138
|
-
const long n_elements = NDL_SHAPE(lp, 0)[0];
|
139
|
-
const char* criterion = ((split_opts_cls*)lp->opt_ptr)->criterion;
|
140
|
-
const long n_classes = ((split_opts_cls*)lp->opt_ptr)->n_classes;
|
141
|
-
const double w_impurity = ((split_opts_cls*)lp->opt_ptr)->impurity;
|
142
|
-
double* params = (double*)NDL_PTR(lp, 3);
|
143
|
-
long curr_pos = 0;
|
144
|
-
long next_pos = 0;
|
145
|
-
long n_l_elements = 0;
|
146
|
-
long n_r_elements = n_elements;
|
147
|
-
double curr_el = f[o[0]];
|
148
|
-
double last_el = f[o[n_elements - 1]];
|
149
|
-
double next_el;
|
150
|
-
double l_impurity;
|
151
|
-
double r_impurity;
|
152
|
-
double gain;
|
153
|
-
double* l_histogram = alloc_dbl_array(n_classes);
|
154
|
-
double* r_histogram = alloc_dbl_array(n_classes);
|
155
|
-
|
156
|
-
/* Initialize optimal parameters. */
|
157
|
-
params[0] = 0.0; /* left impurity */
|
158
|
-
params[1] = w_impurity; /* right impurity */
|
159
|
-
params[2] = curr_el; /* threshold */
|
160
|
-
params[3] = 0.0; /* gain */
|
161
|
-
|
162
|
-
/* Initialize child node variables. */
|
163
|
-
for (long i = 0; i < n_elements; i++) {
|
164
|
-
r_histogram[y[o[i]]] += 1.0;
|
165
|
-
}
|
166
|
-
|
167
|
-
/* Find optimal parameters. */
|
168
|
-
while (curr_pos < n_elements && curr_el != last_el) {
|
169
|
-
next_el = f[o[next_pos]];
|
170
|
-
while (next_pos < n_elements && next_el == curr_el) {
|
171
|
-
l_histogram[y[o[next_pos]]] += 1;
|
172
|
-
n_l_elements++;
|
173
|
-
r_histogram[y[o[next_pos]]] -= 1;
|
174
|
-
n_r_elements--;
|
175
|
-
next_pos++;
|
176
|
-
next_el = f[o[next_pos]];
|
177
|
-
}
|
178
|
-
/* Calculate gain of new split. */
|
179
|
-
l_impurity = calc_impurity_cls(criterion, l_histogram, n_l_elements, n_classes);
|
180
|
-
r_impurity = calc_impurity_cls(criterion, r_histogram, n_r_elements, n_classes);
|
181
|
-
gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / n_elements;
|
182
|
-
/* Update optimal parameters. */
|
183
|
-
if (gain > params[3]) {
|
184
|
-
params[0] = l_impurity;
|
185
|
-
params[1] = r_impurity;
|
186
|
-
params[2] = 0.5 * (curr_el + next_el);
|
187
|
-
params[3] = gain;
|
188
|
-
}
|
189
|
-
if (next_pos == n_elements) break;
|
190
|
-
curr_pos = next_pos;
|
191
|
-
curr_el = f[o[curr_pos]];
|
192
|
-
}
|
193
|
-
|
194
|
-
xfree(l_histogram);
|
195
|
-
xfree(r_histogram);
|
196
|
-
}
|
197
|
-
|
198
|
-
/**
|
199
|
-
* @!visibility private
|
200
|
-
* Find for split point with maximum information gain.
|
201
|
-
*
|
202
|
-
* @overload find_split_params(criterion, impurity, order, features, labels, n_classes) -> Array<Float>
|
203
|
-
*
|
204
|
-
* @param criterion [String] The function to evaluate spliting point. Supported criteria are 'gini' and 'entropy'.
|
205
|
-
* @param impurity [Float] The impurity of whole dataset.
|
206
|
-
* @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
|
207
|
-
* @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
|
208
|
-
* @param labels [Numo::Int32] (shape: [n_elements]) The labels.
|
209
|
-
* @param n_classes [Integer] The number of classes.
|
210
|
-
* @return [Array<Float>] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
|
211
|
-
*/
|
212
|
-
static VALUE find_split_params_cls(VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE labels,
|
213
|
-
VALUE n_classes) {
|
214
|
-
ndfunc_arg_in_t ain[3] = {{numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cInt32, 1}};
|
215
|
-
size_t out_shape[1] = {4};
|
216
|
-
ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 1, out_shape}};
|
217
|
-
ndfunc_t ndf = {(na_iter_func_t)iter_find_split_params_cls, NO_LOOP, 3, 1, ain, aout};
|
218
|
-
split_opts_cls opts = {StringValuePtr(criterion), NUM2LONG(n_classes), NUM2DBL(impurity)};
|
219
|
-
VALUE params = na_ndloop3(&ndf, &opts, 3, order, features, labels);
|
220
|
-
VALUE results = rb_ary_new2(4);
|
221
|
-
double* params_ptr = (double*)na_get_pointer_for_read(params);
|
222
|
-
rb_ary_store(results, 0, DBL2NUM(params_ptr[0]));
|
223
|
-
rb_ary_store(results, 1, DBL2NUM(params_ptr[1]));
|
224
|
-
rb_ary_store(results, 2, DBL2NUM(params_ptr[2]));
|
225
|
-
rb_ary_store(results, 3, DBL2NUM(params_ptr[3]));
|
226
|
-
RB_GC_GUARD(params);
|
227
|
-
RB_GC_GUARD(criterion);
|
228
|
-
return results;
|
229
|
-
}
|
230
|
-
|
231
|
-
/**
|
232
|
-
* @!visibility private
|
233
|
-
*/
|
234
|
-
typedef struct {
|
235
|
-
char* criterion;
|
236
|
-
double impurity;
|
237
|
-
} split_opts_reg;
|
238
|
-
|
239
|
-
/**
|
240
|
-
* @!visibility private
|
241
|
-
*/
|
242
|
-
static void iter_find_split_params_reg(na_loop_t const* lp) {
|
243
|
-
const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
|
244
|
-
const double* f = (double*)NDL_PTR(lp, 1);
|
245
|
-
const double* y = (double*)NDL_PTR(lp, 2);
|
246
|
-
const long n_elements = NDL_SHAPE(lp, 0)[0];
|
247
|
-
const long n_outputs = NDL_SHAPE(lp, 2)[1];
|
248
|
-
const char* criterion = ((split_opts_reg*)lp->opt_ptr)->criterion;
|
249
|
-
const double w_impurity = ((split_opts_reg*)lp->opt_ptr)->impurity;
|
250
|
-
double* params = (double*)NDL_PTR(lp, 3);
|
251
|
-
long curr_pos = 0;
|
252
|
-
long next_pos = 0;
|
253
|
-
long n_l_elements = 0;
|
254
|
-
long n_r_elements = n_elements;
|
255
|
-
double curr_el = f[o[0]];
|
256
|
-
double last_el = f[o[n_elements - 1]];
|
257
|
-
double next_el;
|
258
|
-
double l_impurity;
|
259
|
-
double r_impurity;
|
260
|
-
double gain;
|
261
|
-
double* l_sum_vec = alloc_dbl_array(n_outputs);
|
262
|
-
double* r_sum_vec = alloc_dbl_array(n_outputs);
|
263
|
-
double target_var;
|
264
|
-
VALUE l_target_vecs = rb_ary_new();
|
265
|
-
VALUE r_target_vecs = rb_ary_new();
|
266
|
-
VALUE target;
|
267
|
-
|
268
|
-
/* Initialize optimal parameters. */
|
269
|
-
params[0] = 0.0; /* left impurity */
|
270
|
-
params[1] = w_impurity; /* right impurity */
|
271
|
-
params[2] = curr_el; /* threshold */
|
272
|
-
params[3] = 0.0; /* gain */
|
273
|
-
|
274
|
-
/* Initialize child node variables. */
|
275
|
-
for (long i = 0; i < n_elements; i++) {
|
276
|
-
target = rb_ary_new2(n_outputs);
|
277
|
-
for (long j = 0; j < n_outputs; j++) {
|
278
|
-
target_var = y[o[i] * n_outputs + j];
|
279
|
-
rb_ary_store(target, j, DBL2NUM(target_var));
|
280
|
-
r_sum_vec[j] += target_var;
|
281
|
-
}
|
282
|
-
rb_ary_push(r_target_vecs, target);
|
283
|
-
}
|
284
|
-
|
285
|
-
/* Find optimal parameters. */
|
286
|
-
while (curr_pos < n_elements && curr_el != last_el) {
|
287
|
-
next_el = f[o[next_pos]];
|
288
|
-
while (next_pos < n_elements && next_el == curr_el) {
|
289
|
-
target = rb_ary_shift(r_target_vecs);
|
290
|
-
n_r_elements--;
|
291
|
-
sub_sum_vec(r_sum_vec, target);
|
292
|
-
rb_ary_push(l_target_vecs, target);
|
293
|
-
n_l_elements++;
|
294
|
-
add_sum_vec(l_sum_vec, target);
|
295
|
-
next_pos++;
|
296
|
-
next_el = f[o[next_pos]];
|
297
|
-
}
|
298
|
-
/* Calculate gain of new split. */
|
299
|
-
l_impurity = calc_impurity_reg(criterion, l_target_vecs, l_sum_vec);
|
300
|
-
r_impurity = calc_impurity_reg(criterion, r_target_vecs, r_sum_vec);
|
301
|
-
gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / n_elements;
|
302
|
-
/* Update optimal parameters. */
|
303
|
-
if (gain > params[3]) {
|
304
|
-
params[0] = l_impurity;
|
305
|
-
params[1] = r_impurity;
|
306
|
-
params[2] = 0.5 * (curr_el + next_el);
|
307
|
-
params[3] = gain;
|
308
|
-
}
|
309
|
-
if (next_pos == n_elements) break;
|
310
|
-
curr_pos = next_pos;
|
311
|
-
curr_el = f[o[curr_pos]];
|
312
|
-
}
|
313
|
-
|
314
|
-
xfree(l_sum_vec);
|
315
|
-
xfree(r_sum_vec);
|
316
|
-
}
|
317
|
-
|
318
|
-
/**
|
319
|
-
* @!visibility private
|
320
|
-
* Find for split point with maximum information gain.
|
321
|
-
*
|
322
|
-
* @overload find_split_params(criterion, impurity, order, features, targets) -> Array<Float>
|
323
|
-
*
|
324
|
-
* @param criterion [String] The function to evaluate spliting point. Supported criteria are 'mae' and 'mse'.
|
325
|
-
* @param impurity [Float] The impurity of whole dataset.
|
326
|
-
* @param order [Numo::Int32] (shape: [n_samples]) The element indices sorted according to feature values in ascending order.
|
327
|
-
* @param features [Numo::DFloat] (shape: [n_samples]) The feature values.
|
328
|
-
* @param targets [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values.
|
329
|
-
* @return [Array<Float>] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
|
330
|
-
*/
|
331
|
-
static VALUE find_split_params_reg(VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE targets) {
|
332
|
-
ndfunc_arg_in_t ain[3] = {{numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 2}};
|
333
|
-
size_t out_shape[1] = {4};
|
334
|
-
ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 1, out_shape}};
|
335
|
-
ndfunc_t ndf = {(na_iter_func_t)iter_find_split_params_reg, NO_LOOP, 3, 1, ain, aout};
|
336
|
-
split_opts_reg opts = {StringValuePtr(criterion), NUM2DBL(impurity)};
|
337
|
-
VALUE params = na_ndloop3(&ndf, &opts, 3, order, features, targets);
|
338
|
-
VALUE results = rb_ary_new2(4);
|
339
|
-
double* params_ptr = (double*)na_get_pointer_for_read(params);
|
340
|
-
rb_ary_store(results, 0, DBL2NUM(params_ptr[0]));
|
341
|
-
rb_ary_store(results, 1, DBL2NUM(params_ptr[1]));
|
342
|
-
rb_ary_store(results, 2, DBL2NUM(params_ptr[2]));
|
343
|
-
rb_ary_store(results, 3, DBL2NUM(params_ptr[3]));
|
344
|
-
RB_GC_GUARD(params);
|
345
|
-
RB_GC_GUARD(criterion);
|
346
|
-
return results;
|
347
|
-
}
|
348
|
-
|
349
|
-
/**
|
350
|
-
* @!visibility private
|
351
|
-
*/
|
352
|
-
static void iter_find_split_params_grad_reg(na_loop_t const* lp) {
|
353
|
-
const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
|
354
|
-
const double* f = (double*)NDL_PTR(lp, 1);
|
355
|
-
const double* g = (double*)NDL_PTR(lp, 2);
|
356
|
-
const double* h = (double*)NDL_PTR(lp, 3);
|
357
|
-
const double s_grad = ((double*)lp->opt_ptr)[0];
|
358
|
-
const double s_hess = ((double*)lp->opt_ptr)[1];
|
359
|
-
const double reg_lambda = ((double*)lp->opt_ptr)[2];
|
360
|
-
const long n_elements = NDL_SHAPE(lp, 0)[0];
|
361
|
-
double* params = (double*)NDL_PTR(lp, 4);
|
362
|
-
long curr_pos = 0;
|
363
|
-
long next_pos = 0;
|
364
|
-
double curr_el = f[o[0]];
|
365
|
-
double last_el = f[o[n_elements - 1]];
|
366
|
-
double next_el;
|
367
|
-
double l_grad = 0.0;
|
368
|
-
double l_hess = 0.0;
|
369
|
-
double r_grad;
|
370
|
-
double r_hess;
|
371
|
-
double threshold = curr_el;
|
372
|
-
double gain_max = 0.0;
|
373
|
-
double gain;
|
374
|
-
|
375
|
-
/* Find optimal parameters. */
|
376
|
-
while (curr_pos < n_elements && curr_el != last_el) {
|
377
|
-
next_el = f[o[next_pos]];
|
378
|
-
while (next_pos < n_elements && next_el == curr_el) {
|
379
|
-
l_grad += g[o[next_pos]];
|
380
|
-
l_hess += h[o[next_pos]];
|
381
|
-
next_pos++;
|
382
|
-
next_el = f[o[next_pos]];
|
383
|
-
}
|
384
|
-
/* Calculate gain of new split. */
|
385
|
-
r_grad = s_grad - l_grad;
|
386
|
-
r_hess = s_hess - l_hess;
|
387
|
-
gain = (l_grad * l_grad) / (l_hess + reg_lambda) + (r_grad * r_grad) / (r_hess + reg_lambda) -
|
388
|
-
(s_grad * s_grad) / (s_hess + reg_lambda);
|
389
|
-
/* Update optimal parameters. */
|
390
|
-
if (gain > gain_max) {
|
391
|
-
threshold = 0.5 * (curr_el + next_el);
|
392
|
-
gain_max = gain;
|
393
|
-
}
|
394
|
-
if (next_pos == n_elements) {
|
395
|
-
break;
|
396
|
-
}
|
397
|
-
curr_pos = next_pos;
|
398
|
-
curr_el = f[o[curr_pos]];
|
399
|
-
}
|
400
|
-
|
401
|
-
params[0] = threshold;
|
402
|
-
params[1] = gain_max;
|
403
|
-
}
|
404
|
-
|
405
|
-
/**
|
406
|
-
* @!visibility private
|
407
|
-
* Find for split point with maximum information gain.
|
408
|
-
*
|
409
|
-
* @overload find_split_params(order, features, gradients, hessians, sum_gradient, sum_hessian, reg_lambda) -> Array<Float>
|
410
|
-
* @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
|
411
|
-
* @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
|
412
|
-
* @param gradients [Numo::DFloat] (shape: [n_elements]) The gradient values.
|
413
|
-
* @param hessians [Numo::DFloat] (shape: [n_elements]) The hessian values.
|
414
|
-
* @param sum_gradient [Float] The sum of gradient values.
|
415
|
-
* @param sum_hessian [Float] The sum of hessian values.
|
416
|
-
* @param reg_lambda [Float] The L2 regularization term on weight.
|
417
|
-
* @return [Array<Float>] The array consists of optimal parameters including threshold and gain.
|
418
|
-
*/
|
419
|
-
static VALUE find_split_params_grad_reg(VALUE self, VALUE order, VALUE features, VALUE gradients, VALUE hessians,
|
420
|
-
VALUE sum_gradient, VALUE sum_hessian, VALUE reg_lambda) {
|
421
|
-
ndfunc_arg_in_t ain[4] = {{numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 1}};
|
422
|
-
size_t out_shape[1] = {2};
|
423
|
-
ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 1, out_shape}};
|
424
|
-
ndfunc_t ndf = {(na_iter_func_t)iter_find_split_params_grad_reg, NO_LOOP, 4, 1, ain, aout};
|
425
|
-
double opts[3] = {NUM2DBL(sum_gradient), NUM2DBL(sum_hessian), NUM2DBL(reg_lambda)};
|
426
|
-
VALUE params = na_ndloop3(&ndf, opts, 4, order, features, gradients, hessians);
|
427
|
-
VALUE results = rb_ary_new2(2);
|
428
|
-
double* params_ptr = (double*)na_get_pointer_for_read(params);
|
429
|
-
rb_ary_store(results, 0, DBL2NUM(params_ptr[0]));
|
430
|
-
rb_ary_store(results, 1, DBL2NUM(params_ptr[1]));
|
431
|
-
RB_GC_GUARD(params);
|
432
|
-
return results;
|
433
|
-
}
|
434
|
-
|
435
|
-
/**
|
436
|
-
* @!visibility private
|
437
|
-
*/
|
438
|
-
typedef struct {
|
439
|
-
char* criterion;
|
440
|
-
long n_classes;
|
441
|
-
} node_impurity_cls_opts;
|
442
|
-
|
443
|
-
/**
|
444
|
-
* @!visibility private
|
445
|
-
*/
|
446
|
-
static void iter_node_impurity_cls(na_loop_t const* lp) {
|
447
|
-
const int32_t* y = (int32_t*)NDL_PTR(lp, 0);
|
448
|
-
const char* criterion = ((node_impurity_cls_opts*)lp->opt_ptr)->criterion;
|
449
|
-
const long n_classes = ((node_impurity_cls_opts*)lp->opt_ptr)->n_classes;
|
450
|
-
const long n_elements = NDL_SHAPE(lp, 0)[0];
|
451
|
-
double* ret = (double*)NDL_PTR(lp, 1);
|
452
|
-
double* histogram = alloc_dbl_array(n_classes);
|
453
|
-
for (long i = 0; i < n_elements; i++) histogram[y[i]] += 1;
|
454
|
-
*ret = calc_impurity_cls(criterion, histogram, n_elements, n_classes);
|
455
|
-
xfree(histogram);
|
456
|
-
}
|
457
|
-
|
458
|
-
/**
|
459
|
-
* @!visibility private
|
460
|
-
* Calculate impurity based on criterion.
|
461
|
-
*
|
462
|
-
* @overload node_impurity(criterion, y, n_classes) -> Float
|
463
|
-
*
|
464
|
-
* @param criterion [String] The function to calculate impurity. Supported criteria are 'gini' and 'entropy'.
|
465
|
-
* @param y [Numo::Int32] (shape: [n_samples]) The labels.
|
466
|
-
* @param n_classes [Integer] The number of classes.
|
467
|
-
* @return [Float] impurity
|
468
|
-
*/
|
469
|
-
static VALUE node_impurity_cls(VALUE self, VALUE criterion, VALUE y, VALUE n_classes) {
|
470
|
-
ndfunc_arg_in_t ain[1] = {{numo_cInt32, 1}};
|
471
|
-
ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 0}};
|
472
|
-
ndfunc_t ndf = {(na_iter_func_t)iter_node_impurity_cls, NDF_EXTRACT, 1, 1, ain, aout};
|
473
|
-
node_impurity_cls_opts opts = {StringValuePtr(criterion), NUM2LONG(n_classes)};
|
474
|
-
VALUE ret = na_ndloop3(&ndf, &opts, 1, y);
|
475
|
-
RB_GC_GUARD(criterion);
|
476
|
-
return ret;
|
477
|
-
}
|
478
|
-
|
479
|
-
/**
|
480
|
-
* @!visibility private
|
481
|
-
*/
|
482
|
-
static void iter_check_same_label(na_loop_t const* lp) {
|
483
|
-
const int32_t* y = (int32_t*)NDL_PTR(lp, 0);
|
484
|
-
const long n_elements = NDL_SHAPE(lp, 0)[0];
|
485
|
-
int32_t* ret = (int32_t*)NDL_PTR(lp, 1);
|
486
|
-
*ret = 1;
|
487
|
-
if (n_elements > 0) {
|
488
|
-
int32_t label = y[0];
|
489
|
-
for (long i = 0; i < n_elements; i++) {
|
490
|
-
if (y[i] != label) {
|
491
|
-
*ret = 0;
|
492
|
-
break;
|
493
|
-
}
|
494
|
-
}
|
495
|
-
}
|
496
|
-
}
|
497
|
-
|
498
|
-
/**
|
499
|
-
* @!visibility private
|
500
|
-
* Check all elements have the save value.
|
501
|
-
*
|
502
|
-
* @overload check_same_label(y) -> Boolean
|
503
|
-
*
|
504
|
-
* @param y [Numo::Int32] (shape: [n_samples]) The labels.
|
505
|
-
* @return [Boolean]
|
506
|
-
*/
|
507
|
-
static VALUE check_same_label(VALUE self, VALUE y) {
|
508
|
-
ndfunc_arg_in_t ain[1] = {{numo_cInt32, 1}};
|
509
|
-
ndfunc_arg_out_t aout[1] = {{numo_cInt32, 0}};
|
510
|
-
ndfunc_t ndf = {(na_iter_func_t)iter_check_same_label, NDF_EXTRACT, 1, 1, ain, aout};
|
511
|
-
VALUE ret = na_ndloop(&ndf, 1, y);
|
512
|
-
return (NUM2INT(ret) == 1 ? Qtrue : Qfalse);
|
513
|
-
}
|
514
|
-
|
515
|
-
/**
|
516
|
-
* @!visibility private
|
517
|
-
* Calculate impurity based on criterion.
|
518
|
-
*
|
519
|
-
* @overload node_impurity(criterion, y) -> Float
|
520
|
-
*
|
521
|
-
* @param criterion [String] The function to calculate impurity. Supported criteria are 'mae' and 'mse'.
|
522
|
-
* @param y [Array<Float>] (shape: [n_samples, n_outputs]) The taget values.
|
523
|
-
* @return [Float] impurity
|
524
|
-
*/
|
525
|
-
static VALUE node_impurity_reg(VALUE self, VALUE criterion, VALUE y) {
|
526
|
-
const long n_elements = RARRAY_LEN(y);
|
527
|
-
const long n_outputs = RARRAY_LEN(rb_ary_entry(y, 0));
|
528
|
-
double* sum_vec = alloc_dbl_array(n_outputs);
|
529
|
-
VALUE target_vecs = rb_ary_new();
|
530
|
-
|
531
|
-
for (long i = 0; i < n_elements; i++) {
|
532
|
-
VALUE target = rb_ary_entry(y, i);
|
533
|
-
add_sum_vec(sum_vec, target);
|
534
|
-
rb_ary_push(target_vecs, target);
|
535
|
-
}
|
536
|
-
|
537
|
-
VALUE ret = DBL2NUM(calc_impurity_reg(StringValuePtr(criterion), target_vecs, sum_vec));
|
538
|
-
xfree(sum_vec);
|
539
|
-
RB_GC_GUARD(criterion);
|
540
|
-
return ret;
|
541
|
-
}
|
542
|
-
|
543
|
-
void Init_ext(void) {
|
544
|
-
VALUE rb_mRumale = rb_define_module("Rumale");
|
545
|
-
VALUE rb_mTree = rb_define_module_under(rb_mRumale, "Tree");
|
546
|
-
|
547
|
-
/**
|
548
|
-
* Document-module: Rumale::Tree::ExtDecisionTreeClassifier
|
549
|
-
* @!visibility private
|
550
|
-
* The mixin module consisting of extension method for DecisionTreeClassifier class.
|
551
|
-
* This module is used internally.
|
552
|
-
*/
|
553
|
-
VALUE rb_mExtDTreeCls = rb_define_module_under(rb_mTree, "ExtDecisionTreeClassifier");
|
554
|
-
/**
|
555
|
-
* Document-module: Rumale::Tree::ExtDecisionTreeRegressor
|
556
|
-
* @!visibility private
|
557
|
-
* The mixin module consisting of extension method for DecisionTreeRegressor class.
|
558
|
-
* This module is used internally.
|
559
|
-
*/
|
560
|
-
VALUE rb_mExtDTreeReg = rb_define_module_under(rb_mTree, "ExtDecisionTreeRegressor");
|
561
|
-
/**
|
562
|
-
* Document-module: Rumale::Tree::ExtGradientTreeRegressor
|
563
|
-
* @!visibility private
|
564
|
-
* The mixin module consisting of extension method for GradientTreeRegressor class.
|
565
|
-
* This module is used internally.
|
566
|
-
*/
|
567
|
-
VALUE rb_mExtGTreeReg = rb_define_module_under(rb_mTree, "ExtGradientTreeRegressor");
|
568
|
-
|
569
|
-
rb_define_private_method(rb_mExtDTreeCls, "find_split_params", find_split_params_cls, 6);
|
570
|
-
rb_define_private_method(rb_mExtDTreeReg, "find_split_params", find_split_params_reg, 5);
|
571
|
-
rb_define_private_method(rb_mExtGTreeReg, "find_split_params", find_split_params_grad_reg, 7);
|
572
|
-
rb_define_private_method(rb_mExtDTreeCls, "node_impurity", node_impurity_cls, 3);
|
573
|
-
rb_define_private_method(rb_mExtDTreeCls, "stop_growing?", check_same_label, 1);
|
574
|
-
rb_define_private_method(rb_mExtDTreeReg, "node_impurity", node_impurity_reg, 2);
|
575
|
-
}
|