rumale 0.23.1 → 0.23.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: e3c7b4dd3b452f96f88f368a9b279fc67dd7e2fd0033f7a06247e052252de18f
4
- data.tar.gz: 88913193c9a6d33cd16cdd45b6a22bf94c072f6ebcb141571dcaef2a0f7aec71
3
+ metadata.gz: 4564c37af7744bc4fe14dec5c5fc1e236687c3a241d2e17ef2d89f1c57056af9
4
+ data.tar.gz: 6f70d79a10b890bbd127f60f1c7f26934fcd88f71458af8839ac049b7a07efc8
5
5
  SHA512:
6
- metadata.gz: e6f824f82415c8dfca7448505a2743bd94a89e9d0575e1c8edf6cdd37bd81af991ab9ed0c4970ed4572d2296c43d5c69331b669be9f4c5a60f9b900b7d220744
7
- data.tar.gz: bfebdfc2110f159c2aa0b3cd00b33455e1cfc38bc7bdce36be98e3a21b6138ffbc0299eae7f8b4a913629611d168095dcbc0d9e560ede89463963b2284d95689
6
+ metadata.gz: 5671a08ac8e9881f51896c4478ce5f4b54457c83d9b7194623febfd1859123cda5947c0d344aa551686c2c964359e9bdbd5ad13e9c921d2a3393a76717c00093
7
+ data.tar.gz: bb022827e8ca9d939addb9cfdd9b5fa5b643cd56150a84f41a224dde0c75992badbf792f77d06194943f015572beb5bafdd3c84e43efd98be5cc53beb9347ab0
data/CHANGELOG.md CHANGED
@@ -1,3 +1,10 @@
1
+ # 0.23.2
2
+ The Rumale project will be rebooted with version 0.24.0.
3
+ This version is probably the last release of the series starting with version 0.8.0.
4
+
5
+ - Refactor some code and configs.
6
+ - Deprecate VPTree class.
7
+
1
8
  # 0.23.1
2
9
  - Fix all estimators to return inference results in a contiguous narray.
3
10
  - Fix to use until statement instead of recursive call on apply methods of tree estimators.
data/LICENSE.txt CHANGED
@@ -1,4 +1,4 @@
1
- Copyright (c) 2017-2021 Atsushi Tatsuma
1
+ Copyright (c) 2017-2022 Atsushi Tatsuma
2
2
  All rights reserved.
3
3
 
4
4
  Redistribution and use in source and binary forms, with or without
data/README.md CHANGED
@@ -1,9 +1,10 @@
1
1
  # Rumale
2
2
 
3
+ **This project is suspended due to the author's health. It will resume once the author recovers.**
4
+
3
5
  ![Rumale](https://dl.dropboxusercontent.com/s/joxruk2720ur66o/rumale_header_400.png)
4
6
 
5
7
  [![Build Status](https://github.com/yoshoku/rumale/actions/workflows/build.yml/badge.svg)](https://github.com/yoshoku/rumale/actions/workflows/build.yml)
6
- [![Coverage Status](https://coveralls.io/repos/github/yoshoku/rumale/badge.svg?branch=main)](https://coveralls.io/github/yoshoku/rumale?branch=main)
7
8
  [![Gem Version](https://badge.fury.io/rb/rumale.svg)](https://badge.fury.io/rb/rumale)
8
9
  [![BSD 2-Clause License](https://img.shields.io/badge/License-BSD%202--Clause-orange.svg)](https://github.com/yoshoku/rumale/blob/main/LICENSE.txt)
9
10
  [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://yoshoku.github.io/rumale/doc/)
@@ -189,6 +190,12 @@ Ubuntu:
189
190
  $ sudo apt-get install libopenblas-dev liblapacke-dev
190
191
  ```
191
192
 
193
+ Fedora:
194
+
195
+ ```bash
196
+ $ sudo dnf install openblas-devel lapack-devel
197
+ ```
198
+
192
199
  Windows (MSYS2):
193
200
 
194
201
  ```bash
@@ -226,6 +233,12 @@ Ubuntu:
226
233
  $ sudo apt-get install gcc gfortran make
227
234
  ```
228
235
 
236
+ Fedora:
237
+
238
+ ```bash
239
+ $ sudo dnf install gcc gcc-gfortran make
240
+ ```
241
+
229
242
  Install Numo::OpenBLAS gem.
230
243
 
231
244
  ```bash
@@ -239,6 +252,25 @@ require 'numo/openblas'
239
252
  require 'rumale'
240
253
  ```
241
254
 
255
+ ### Numo::BLIS
256
+ [Numo::BLIS](https://github.com/yoshoku/numo-blis) downloads and builds BLIS during installation
257
+ and uses it as the backend library for Numo::Linalg.
258
+ Like OpenBLAS, BLIS is a high-performance BLAS implementation,
259
+ so using it can be expected to speed up processing in Rumale.
260
+
261
+ Install Numo::BLIS gem.
262
+
263
+ ```bash
264
+ $ gem install numo-blis
265
+ ```
266
+
267
+ Load Numo::BLIS gem instead of Numo::Linalg.
268
+
269
+ ```ruby
270
+ require 'numo/blis'
271
+ require 'rumale'
272
+ ```
273
+
242
274
  ### Parallel
243
275
  Several estimators in Rumale support parallel processing.
244
276
  Parallel processing in Rumale is realized by [Parallel](https://github.com/grosser/parallel) gem,
@@ -1,9 +1,545 @@
1
1
  #include "rumaleext.h"
2
2
 
3
- VALUE mRumale;
3
+ double* alloc_dbl_array(const long n_dimensions) {
4
+ double* arr = ALLOC_N(double, n_dimensions);
5
+ memset(arr, 0, n_dimensions * sizeof(double));
6
+ return arr;
7
+ }
8
+
9
/*
 * Gini impurity of a class histogram: 1 - sum_k (histogram[k] / n_elements)^2.
 *
 * histogram  : per-class counts, n_classes entries
 * n_elements : total used to normalize each count into a proportion
 * n_classes  : number of histogram entries to read
 */
double calc_gini_coef(double* histogram, const long n_elements, const long n_classes) {
  double sum_sq = 0.0;
  long k = 0;

  while (k < n_classes) {
    const double ratio = histogram[k] / n_elements;
    sum_sq += ratio * ratio;
    k++;
  }

  return 1.0 - sum_sq;
}
21
+
22
+ double calc_entropy(double* histogram, const long n_elements, const long n_classes) {
23
+ long i;
24
+ double el;
25
+ double entropy = 0.0;
26
+
27
+ for (i = 0; i < n_classes; i++) {
28
+ el = histogram[i] / n_elements;
29
+ entropy += el * log(el + 1.0);
30
+ }
31
+
32
+ return -entropy;
33
+ }
34
+
35
+ VALUE
36
+ calc_mean_vec(double* sum_vec, const long n_dimensions, const long n_elements) {
37
+ long i;
38
+ VALUE mean_vec = rb_ary_new2(n_dimensions);
39
+
40
+ for (i = 0; i < n_dimensions; i++) {
41
+ rb_ary_store(mean_vec, i, DBL2NUM(sum_vec[i] / n_elements));
42
+ }
43
+
44
+ return mean_vec;
45
+ }
46
+
47
+ double calc_vec_mae(VALUE vec_a, VALUE vec_b) {
48
+ long i;
49
+ const long n_dimensions = RARRAY_LEN(vec_a);
50
+ double sum = 0.0;
51
+ double diff;
52
+
53
+ for (i = 0; i < n_dimensions; i++) {
54
+ diff = NUM2DBL(rb_ary_entry(vec_a, i)) - NUM2DBL(rb_ary_entry(vec_b, i));
55
+ sum += fabs(diff);
56
+ }
57
+
58
+ return sum / n_dimensions;
59
+ }
60
+
61
+ double calc_vec_mse(VALUE vec_a, VALUE vec_b) {
62
+ long i;
63
+ const long n_dimensions = RARRAY_LEN(vec_a);
64
+ double sum = 0.0;
65
+ double diff;
66
+
67
+ for (i = 0; i < n_dimensions; i++) {
68
+ diff = NUM2DBL(rb_ary_entry(vec_a, i)) - NUM2DBL(rb_ary_entry(vec_b, i));
69
+ sum += diff * diff;
70
+ }
71
+
72
+ return sum / n_dimensions;
73
+ }
74
+
75
+ double calc_mae(VALUE target_vecs, VALUE mean_vec) {
76
+ long i;
77
+ const long n_elements = RARRAY_LEN(target_vecs);
78
+ double sum = 0.0;
79
+
80
+ for (i = 0; i < n_elements; i++) {
81
+ sum += calc_vec_mae(rb_ary_entry(target_vecs, i), mean_vec);
82
+ }
83
+
84
+ return sum / n_elements;
85
+ }
86
+
87
+ double calc_mse(VALUE target_vecs, VALUE mean_vec) {
88
+ long i;
89
+ const long n_elements = RARRAY_LEN(target_vecs);
90
+ double sum = 0.0;
91
+
92
+ for (i = 0; i < n_elements; i++) {
93
+ sum += calc_vec_mse(rb_ary_entry(target_vecs, i), mean_vec);
94
+ }
95
+
96
+ return sum / n_elements;
97
+ }
98
+
99
/*
 * Classification impurity of a histogram.
 * "entropy" selects calc_entropy; any other criterion string falls back to calc_gini_coef.
 */
double calc_impurity_cls(const char* criterion, double* histogram, const long n_elements, const long n_classes) {
  const int use_entropy = strcmp(criterion, "entropy") == 0;

  return use_entropy ? calc_entropy(histogram, n_elements, n_classes)
                     : calc_gini_coef(histogram, n_elements, n_classes);
}
105
+
106
+ double calc_impurity_reg(const char* criterion, VALUE target_vecs, double* sum_vec) {
107
+ const long n_elements = RARRAY_LEN(target_vecs);
108
+ const long n_dimensions = RARRAY_LEN(rb_ary_entry(target_vecs, 0));
109
+ VALUE mean_vec = calc_mean_vec(sum_vec, n_dimensions, n_elements);
110
+
111
+ if (strcmp(criterion, "mae") == 0) {
112
+ return calc_mae(target_vecs, mean_vec);
113
+ }
114
+ return calc_mse(target_vecs, mean_vec);
115
+ }
116
+
117
+ void add_sum_vec(double* sum_vec, VALUE target) {
118
+ long i;
119
+ const long n_dimensions = RARRAY_LEN(target);
120
+
121
+ for (i = 0; i < n_dimensions; i++) {
122
+ sum_vec[i] += NUM2DBL(rb_ary_entry(target, i));
123
+ }
124
+ }
125
+
126
+ void sub_sum_vec(double* sum_vec, VALUE target) {
127
+ long i;
128
+ const long n_dimensions = RARRAY_LEN(target);
129
+
130
+ for (i = 0; i < n_dimensions; i++) {
131
+ sum_vec[i] -= NUM2DBL(rb_ary_entry(target, i));
132
+ }
133
+ }
134
+
135
/**
 * @!visibility private
 * Options handed to iter_find_split_params_cls through the opt_ptr of na_ndloop3.
 * Field order matters: callers fill it with a positional initializer.
 */
typedef struct split_opts_cls_s {
  char* criterion; /* criterion name forwarded to calc_impurity_cls */
  long n_classes;  /* number of histogram bins */
  double impurity; /* impurity of the node being split */
} split_opts_cls;
143
+ /**
144
+ * @!visibility private
145
+ */
146
+ static void iter_find_split_params_cls(na_loop_t const* lp) {
147
+ const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
148
+ const double* f = (double*)NDL_PTR(lp, 1);
149
+ const int32_t* y = (int32_t*)NDL_PTR(lp, 2);
150
+ const long n_elements = NDL_SHAPE(lp, 0)[0];
151
+ const char* criterion = ((split_opts_cls*)lp->opt_ptr)->criterion;
152
+ const long n_classes = ((split_opts_cls*)lp->opt_ptr)->n_classes;
153
+ const double w_impurity = ((split_opts_cls*)lp->opt_ptr)->impurity;
154
+ double* params = (double*)NDL_PTR(lp, 3);
155
+ long i;
156
+ long curr_pos = 0;
157
+ long next_pos = 0;
158
+ long n_l_elements = 0;
159
+ long n_r_elements = n_elements;
160
+ double curr_el = f[o[0]];
161
+ double last_el = f[o[n_elements - 1]];
162
+ double next_el;
163
+ double l_impurity;
164
+ double r_impurity;
165
+ double gain;
166
+ double* l_histogram = alloc_dbl_array(n_classes);
167
+ double* r_histogram = alloc_dbl_array(n_classes);
168
+
169
+ /* Initialize optimal parameters. */
170
+ params[0] = 0.0; /* left impurity */
171
+ params[1] = w_impurity; /* right impurity */
172
+ params[2] = curr_el; /* threshold */
173
+ params[3] = 0.0; /* gain */
174
+
175
+ /* Initialize child node variables. */
176
+ for (i = 0; i < n_elements; i++) {
177
+ r_histogram[y[o[i]]] += 1.0;
178
+ }
179
+
180
+ /* Find optimal parameters. */
181
+ while (curr_pos < n_elements && curr_el != last_el) {
182
+ next_el = f[o[next_pos]];
183
+ while (next_pos < n_elements && next_el == curr_el) {
184
+ l_histogram[y[o[next_pos]]] += 1;
185
+ n_l_elements++;
186
+ r_histogram[y[o[next_pos]]] -= 1;
187
+ n_r_elements--;
188
+ next_pos++;
189
+ next_el = f[o[next_pos]];
190
+ }
191
+ /* Calculate gain of new split. */
192
+ l_impurity = calc_impurity_cls(criterion, l_histogram, n_l_elements, n_classes);
193
+ r_impurity = calc_impurity_cls(criterion, r_histogram, n_r_elements, n_classes);
194
+ gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / n_elements;
195
+ /* Update optimal parameters. */
196
+ if (gain > params[3]) {
197
+ params[0] = l_impurity;
198
+ params[1] = r_impurity;
199
+ params[2] = 0.5 * (curr_el + next_el);
200
+ params[3] = gain;
201
+ }
202
+ if (next_pos == n_elements) break;
203
+ curr_pos = next_pos;
204
+ curr_el = f[o[curr_pos]];
205
+ }
206
+
207
+ xfree(l_histogram);
208
+ xfree(r_histogram);
209
+ }
210
+ /**
211
+ * @!visibility private
212
+ * Find for split point with maximum information gain.
213
+ *
214
+ * @overload find_split_params(criterion, impurity, order, features, labels, n_classes) -> Array<Float>
215
+ *
216
+ * @param criterion [String] The function to evaluate spliting point. Supported criteria are 'gini' and 'entropy'.
217
+ * @param impurity [Float] The impurity of whole dataset.
218
+ * @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
219
+ * @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
220
+ * @param labels [Numo::Int32] (shape: [n_elements]) The labels.
221
+ * @param n_classes [Integer] The number of classes.
222
+ * @return [Array<Float>] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
223
+ */
224
+ static VALUE find_split_params_cls(VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE labels,
225
+ VALUE n_classes) {
226
+ ndfunc_arg_in_t ain[3] = {{numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cInt32, 1}};
227
+ size_t out_shape[1] = {4};
228
+ ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 1, out_shape}};
229
+ ndfunc_t ndf = {(na_iter_func_t)iter_find_split_params_cls, NO_LOOP, 3, 1, ain, aout};
230
+ split_opts_cls opts = {StringValuePtr(criterion), NUM2LONG(n_classes), NUM2DBL(impurity)};
231
+ VALUE params = na_ndloop3(&ndf, &opts, 3, order, features, labels);
232
+ VALUE results = rb_ary_new2(4);
233
+ double* params_ptr = (double*)na_get_pointer_for_read(params);
234
+ rb_ary_store(results, 0, DBL2NUM(params_ptr[0]));
235
+ rb_ary_store(results, 1, DBL2NUM(params_ptr[1]));
236
+ rb_ary_store(results, 2, DBL2NUM(params_ptr[2]));
237
+ rb_ary_store(results, 3, DBL2NUM(params_ptr[3]));
238
+ RB_GC_GUARD(params);
239
+ RB_GC_GUARD(criterion);
240
+ return results;
241
+ }
242
+
243
/**
 * @!visibility private
 * Options handed to iter_find_split_params_reg through the opt_ptr of na_ndloop3.
 * Field order matters: callers fill it with a positional initializer.
 */
typedef struct split_opts_reg_s {
  char* criterion; /* criterion name forwarded to calc_impurity_reg */
  double impurity; /* impurity of the node being split */
} split_opts_reg;
250
+ /**
251
+ * @!visibility private
252
+ */
253
+ static void iter_find_split_params_reg(na_loop_t const* lp) {
254
+ const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
255
+ const double* f = (double*)NDL_PTR(lp, 1);
256
+ const double* y = (double*)NDL_PTR(lp, 2);
257
+ const long n_elements = NDL_SHAPE(lp, 0)[0];
258
+ const long n_outputs = NDL_SHAPE(lp, 2)[1];
259
+ const char* criterion = ((split_opts_reg*)lp->opt_ptr)->criterion;
260
+ const double w_impurity = ((split_opts_reg*)lp->opt_ptr)->impurity;
261
+ double* params = (double*)NDL_PTR(lp, 3);
262
+ long i, j;
263
+ long curr_pos = 0;
264
+ long next_pos = 0;
265
+ long n_l_elements = 0;
266
+ long n_r_elements = n_elements;
267
+ double curr_el = f[o[0]];
268
+ double last_el = f[o[n_elements - 1]];
269
+ double next_el;
270
+ double l_impurity;
271
+ double r_impurity;
272
+ double gain;
273
+ double* l_sum_vec = alloc_dbl_array(n_outputs);
274
+ double* r_sum_vec = alloc_dbl_array(n_outputs);
275
+ double target_var;
276
+ VALUE l_target_vecs = rb_ary_new();
277
+ VALUE r_target_vecs = rb_ary_new();
278
+ VALUE target;
279
+
280
+ /* Initialize optimal parameters. */
281
+ params[0] = 0.0; /* left impurity */
282
+ params[1] = w_impurity; /* right impurity */
283
+ params[2] = curr_el; /* threshold */
284
+ params[3] = 0.0; /* gain */
285
+
286
+ /* Initialize child node variables. */
287
+ for (i = 0; i < n_elements; i++) {
288
+ target = rb_ary_new2(n_outputs);
289
+ for (j = 0; j < n_outputs; j++) {
290
+ target_var = y[o[i] * n_outputs + j];
291
+ rb_ary_store(target, j, DBL2NUM(target_var));
292
+ r_sum_vec[j] += target_var;
293
+ }
294
+ rb_ary_push(r_target_vecs, target);
295
+ }
296
+
297
+ /* Find optimal parameters. */
298
+ while (curr_pos < n_elements && curr_el != last_el) {
299
+ next_el = f[o[next_pos]];
300
+ while (next_pos < n_elements && next_el == curr_el) {
301
+ target = rb_ary_shift(r_target_vecs);
302
+ n_r_elements--;
303
+ sub_sum_vec(r_sum_vec, target);
304
+ rb_ary_push(l_target_vecs, target);
305
+ n_l_elements++;
306
+ add_sum_vec(l_sum_vec, target);
307
+ next_pos++;
308
+ next_el = f[o[next_pos]];
309
+ }
310
+ /* Calculate gain of new split. */
311
+ l_impurity = calc_impurity_reg(criterion, l_target_vecs, l_sum_vec);
312
+ r_impurity = calc_impurity_reg(criterion, r_target_vecs, r_sum_vec);
313
+ gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / n_elements;
314
+ /* Update optimal parameters. */
315
+ if (gain > params[3]) {
316
+ params[0] = l_impurity;
317
+ params[1] = r_impurity;
318
+ params[2] = 0.5 * (curr_el + next_el);
319
+ params[3] = gain;
320
+ }
321
+ if (next_pos == n_elements) break;
322
+ curr_pos = next_pos;
323
+ curr_el = f[o[curr_pos]];
324
+ }
325
+
326
+ xfree(l_sum_vec);
327
+ xfree(r_sum_vec);
328
+ }
329
+ /**
330
+ * @!visibility private
331
+ * Find for split point with maximum information gain.
332
+ *
333
+ * @overload find_split_params(criterion, impurity, order, features, targets) -> Array<Float>
334
+ *
335
+ * @param criterion [String] The function to evaluate spliting point. Supported criteria are 'mae' and 'mse'.
336
+ * @param impurity [Float] The impurity of whole dataset.
337
+ * @param order [Numo::Int32] (shape: [n_samples]) The element indices sorted according to feature values in ascending order.
338
+ * @param features [Numo::DFloat] (shape: [n_samples]) The feature values.
339
+ * @param targets [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values.
340
+ * @return [Array<Float>] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
341
+ */
342
+ static VALUE find_split_params_reg(VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE targets) {
343
+ ndfunc_arg_in_t ain[3] = {{numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 2}};
344
+ size_t out_shape[1] = {4};
345
+ ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 1, out_shape}};
346
+ ndfunc_t ndf = {(na_iter_func_t)iter_find_split_params_reg, NO_LOOP, 3, 1, ain, aout};
347
+ split_opts_reg opts = {StringValuePtr(criterion), NUM2DBL(impurity)};
348
+ VALUE params = na_ndloop3(&ndf, &opts, 3, order, features, targets);
349
+ VALUE results = rb_ary_new2(4);
350
+ double* params_ptr = (double*)na_get_pointer_for_read(params);
351
+ rb_ary_store(results, 0, DBL2NUM(params_ptr[0]));
352
+ rb_ary_store(results, 1, DBL2NUM(params_ptr[1]));
353
+ rb_ary_store(results, 2, DBL2NUM(params_ptr[2]));
354
+ rb_ary_store(results, 3, DBL2NUM(params_ptr[3]));
355
+ RB_GC_GUARD(params);
356
+ RB_GC_GUARD(criterion);
357
+ return results;
358
+ }
359
+
360
+ /**
361
+ * @!visibility private
362
+ */
363
+ static void iter_find_split_params_grad_reg(na_loop_t const* lp) {
364
+ const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
365
+ const double* f = (double*)NDL_PTR(lp, 1);
366
+ const double* g = (double*)NDL_PTR(lp, 2);
367
+ const double* h = (double*)NDL_PTR(lp, 3);
368
+ const double s_grad = ((double*)lp->opt_ptr)[0];
369
+ const double s_hess = ((double*)lp->opt_ptr)[1];
370
+ const double reg_lambda = ((double*)lp->opt_ptr)[2];
371
+ const long n_elements = NDL_SHAPE(lp, 0)[0];
372
+ double* params = (double*)NDL_PTR(lp, 4);
373
+ long curr_pos = 0;
374
+ long next_pos = 0;
375
+ double curr_el = f[o[0]];
376
+ double last_el = f[o[n_elements - 1]];
377
+ double next_el;
378
+ double l_grad = 0.0;
379
+ double l_hess = 0.0;
380
+ double r_grad;
381
+ double r_hess;
382
+ double threshold = curr_el;
383
+ double gain_max = 0.0;
384
+ double gain;
385
+
386
+ /* Find optimal parameters. */
387
+ while (curr_pos < n_elements && curr_el != last_el) {
388
+ next_el = f[o[next_pos]];
389
+ while (next_pos < n_elements && next_el == curr_el) {
390
+ l_grad += g[o[next_pos]];
391
+ l_hess += h[o[next_pos]];
392
+ next_pos++;
393
+ next_el = f[o[next_pos]];
394
+ }
395
+ /* Calculate gain of new split. */
396
+ r_grad = s_grad - l_grad;
397
+ r_hess = s_hess - l_hess;
398
+ gain = (l_grad * l_grad) / (l_hess + reg_lambda) + (r_grad * r_grad) / (r_hess + reg_lambda) -
399
+ (s_grad * s_grad) / (s_hess + reg_lambda);
400
+ /* Update optimal parameters. */
401
+ if (gain > gain_max) {
402
+ threshold = 0.5 * (curr_el + next_el);
403
+ gain_max = gain;
404
+ }
405
+ if (next_pos == n_elements) {
406
+ break;
407
+ }
408
+ curr_pos = next_pos;
409
+ curr_el = f[o[curr_pos]];
410
+ }
411
+
412
+ params[0] = threshold;
413
+ params[1] = gain_max;
414
+ }
415
+
416
+ /**
417
+ * @!visibility private
418
+ * Find for split point with maximum information gain.
419
+ *
420
+ * @overload find_split_params(order, features, gradients, hessians, sum_gradient, sum_hessian, reg_lambda) -> Array<Float>
421
+ * @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
422
+ * @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
423
+ * @param gradients [Numo::DFloat] (shape: [n_elements]) The gradient values.
424
+ * @param hessians [Numo::DFloat] (shape: [n_elements]) The hessian values.
425
+ * @param sum_gradient [Float] The sum of gradient values.
426
+ * @param sum_hessian [Float] The sum of hessian values.
427
+ * @param reg_lambda [Float] The L2 regularization term on weight.
428
+ * @return [Array<Float>] The array consists of optimal parameters including threshold and gain.
429
+ */
430
+ static VALUE find_split_params_grad_reg(VALUE self, VALUE order, VALUE features, VALUE gradients, VALUE hessians,
431
+ VALUE sum_gradient, VALUE sum_hessian, VALUE reg_lambda) {
432
+ ndfunc_arg_in_t ain[4] = {{numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 1}};
433
+ size_t out_shape[1] = {2};
434
+ ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 1, out_shape}};
435
+ ndfunc_t ndf = {(na_iter_func_t)iter_find_split_params_grad_reg, NO_LOOP, 4, 1, ain, aout};
436
+ double opts[3] = {NUM2DBL(sum_gradient), NUM2DBL(sum_hessian), NUM2DBL(reg_lambda)};
437
+ VALUE params = na_ndloop3(&ndf, opts, 4, order, features, gradients, hessians);
438
+ VALUE results = rb_ary_new2(2);
439
+ double* params_ptr = (double*)na_get_pointer_for_read(params);
440
+ rb_ary_store(results, 0, DBL2NUM(params_ptr[0]));
441
+ rb_ary_store(results, 1, DBL2NUM(params_ptr[1]));
442
+ RB_GC_GUARD(params);
443
+ return results;
444
+ }
445
+
446
+ /**
447
+ * @!visibility private
448
+ * Calculate impurity based on criterion.
449
+ *
450
+ * @overload node_impurity(criterion, y, n_classes) -> Float
451
+ *
452
+ * @param criterion [String] The function to calculate impurity. Supported criteria are 'gini' and 'entropy'.
453
+ * @param y_nary [Numo::Int32] (shape: [n_samples]) The labels.
454
+ * @param n_elements_ [Integer] The number of elements.
455
+ * @param n_classes_ [Integer] The number of classes.
456
+ * @return [Float] impurity
457
+ */
458
+ static VALUE node_impurity_cls(VALUE self, VALUE criterion, VALUE y_nary, VALUE n_elements_, VALUE n_classes_) {
459
+ long i;
460
+ const long n_classes = NUM2LONG(n_classes_);
461
+ const long n_elements = NUM2LONG(n_elements_);
462
+ const int32_t* y = (int32_t*)na_get_pointer_for_read(y_nary);
463
+ double* histogram = alloc_dbl_array(n_classes);
464
+ VALUE ret;
465
+
466
+ for (i = 0; i < n_elements; i++) {
467
+ histogram[y[i]] += 1;
468
+ }
469
+
470
+ ret = DBL2NUM(calc_impurity_cls(StringValuePtr(criterion), histogram, n_elements, n_classes));
471
+
472
+ xfree(histogram);
473
+
474
+ RB_GC_GUARD(y_nary);
475
+ RB_GC_GUARD(criterion);
476
+
477
+ return ret;
478
+ }
479
+
480
+ /**
481
+ * @!visibility private
482
+ * Calculate impurity based on criterion.
483
+ *
484
+ * @overload node_impurity(criterion, y) -> Float
485
+ *
486
+ * @param criterion [String] The function to calculate impurity. Supported criteria are 'mae' and 'mse'.
487
+ * @param y [Array<Float>] (shape: [n_samples, n_outputs]) The taget values.
488
+ * @return [Float] impurity
489
+ */
490
+ static VALUE node_impurity_reg(VALUE self, VALUE criterion, VALUE y) {
491
+ long i;
492
+ const long n_elements = RARRAY_LEN(y);
493
+ const long n_outputs = RARRAY_LEN(rb_ary_entry(y, 0));
494
+ double* sum_vec = alloc_dbl_array(n_outputs);
495
+ VALUE target_vecs = rb_ary_new();
496
+ VALUE target;
497
+ VALUE ret;
498
+
499
+ for (i = 0; i < n_elements; i++) {
500
+ target = rb_ary_entry(y, i);
501
+ add_sum_vec(sum_vec, target);
502
+ rb_ary_push(target_vecs, target);
503
+ }
504
+
505
+ ret = DBL2NUM(calc_impurity_reg(StringValuePtr(criterion), target_vecs, sum_vec));
506
+
507
+ xfree(sum_vec);
508
+
509
+ RB_GC_GUARD(criterion);
510
+
511
+ return ret;
512
+ }
4
513
 
5
514
  void Init_rumaleext(void) {
6
- mRumale = rb_define_module("Rumale");
515
+ VALUE mRumale = rb_define_module("Rumale");
516
+ VALUE mTree = rb_define_module_under(mRumale, "Tree");
517
+
518
+ /**
519
+ * Document-module: Rumale::Tree::ExtDecisionTreeClassifier
520
+ * @!visibility private
521
+ * The mixin module consisting of extension method for DecisionTreeClassifier class.
522
+ * This module is used internally.
523
+ */
524
+ VALUE mExtDTreeCls = rb_define_module_under(mTree, "ExtDecisionTreeClassifier");
525
+ /**
526
+ * Document-module: Rumale::Tree::ExtDecisionTreeRegressor
527
+ * @!visibility private
528
+ * The mixin module consisting of extension method for DecisionTreeRegressor class.
529
+ * This module is used internally.
530
+ */
531
+ VALUE mExtDTreeReg = rb_define_module_under(mTree, "ExtDecisionTreeRegressor");
532
+ /**
533
+ * Document-module: Rumale::Tree::ExtGradientTreeRegressor
534
+ * @!visibility private
535
+ * The mixin module consisting of extension method for GradientTreeRegressor class.
536
+ * This module is used internally.
537
+ */
538
+ VALUE mExtGTreeReg = rb_define_module_under(mTree, "ExtGradientTreeRegressor");
7
539
 
8
- init_tree_module();
540
+ rb_define_private_method(mExtDTreeCls, "find_split_params", find_split_params_cls, 6);
541
+ rb_define_private_method(mExtDTreeReg, "find_split_params", find_split_params_reg, 5);
542
+ rb_define_private_method(mExtGTreeReg, "find_split_params", find_split_params_grad_reg, 7);
543
+ rb_define_private_method(mExtDTreeCls, "node_impurity", node_impurity_cls, 4);
544
+ rb_define_private_method(mExtDTreeReg, "node_impurity", node_impurity_reg, 2);
9
545
  }
@@ -1,8 +1,12 @@
1
- #ifndef RUMALE_H
2
- #define RUMALE_H 1
1
+ #ifndef RUMALEEXT_H
2
+ #define RUMALEEXT_H 1
3
+
4
+ #include <math.h>
5
+ #include <string.h>
3
6
 
4
7
  #include <ruby.h>
5
8
 
6
- #include "tree.h"
9
+ #include <numo/narray.h>
10
+ #include <numo/template.h>
7
11
 
8
12
  #endif /* RUMALEEXT_H */