rumale 0.12.5 → 0.12.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +3 -0
- data/README.md +1 -1
- data/ext/rumale/extconf.rb +8 -0
- data/ext/rumale/rumale.c +137 -82
- data/ext/rumale/rumale.h +2 -0
- data/lib/rumale/tree/decision_tree_classifier.rb +3 -4
- data/lib/rumale/tree/gradient_tree_regressor.rb +2 -6
- data/lib/rumale/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA1:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: f5ad94467c9d031744ab3cc7e0c29439a23ce562
|
|
4
|
+
data.tar.gz: 6b99e9c5846acef596ee6efef32fc355df73f856
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: b962432af5544227a33dce685f002807855bb51559fbe874414e26f3b13df031886bf294769282b3971eee75d4ad6941a630b8a9b81a591975ad7977fc111e43
|
|
7
|
+
data.tar.gz: e2c5e90412c9dfb1cb2ab6cf9f46593a93902807cbffe5b3fff43074074d3cc77ce9a03d9be8dd7c2e447579fb231c0548f57796ee04b189758e8890cc8e2a8f
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
|
@@ -6,7 +6,7 @@
|
|
|
6
6
|
[](https://coveralls.io/github/yoshoku/rumale?branch=master)
|
|
7
7
|
[](https://badge.fury.io/rb/rumale)
|
|
8
8
|
[](https://github.com/yoshoku/rumale/blob/master/LICENSE.txt)
|
|
9
|
-
[](https://www.rubydoc.info/gems/rumale/0.12.5)
|
|
9
|
+
[](https://www.rubydoc.info/gems/rumale/0.12.6)
|
|
10
10
|
|
|
11
11
|
Rumale (**Ru**by **ma**chine **le**arning) is a machine learning library in Ruby.
|
|
12
12
|
Rumale provides machine learning algorithms with interfaces similar to Scikit-Learn in Python.
|
data/ext/rumale/extconf.rb
CHANGED
|
@@ -1,5 +1,13 @@
|
|
|
1
1
|
# frozen_string_literal: true
|
|
2
2
|
|
|
3
3
|
require 'mkmf'
|
|
4
|
+
require 'numo/narray'
|
|
5
|
+
|
|
6
|
+
$LOAD_PATH.each do |lp|
|
|
7
|
+
if File.exist?(File.join(lp, 'numo/numo/narray.h'))
|
|
8
|
+
$INCFLAGS = "-I#{lp}/numo #{$INCFLAGS}"
|
|
9
|
+
break
|
|
10
|
+
end
|
|
11
|
+
end
|
|
4
12
|
|
|
5
13
|
create_makefile('rumale/rumale')
|
data/ext/rumale/rumale.c
CHANGED
|
@@ -122,9 +122,9 @@ calc_mse(VALUE target_vecs, VALUE sum_vec)
|
|
|
122
122
|
}
|
|
123
123
|
|
|
124
124
|
double
|
|
125
|
-
calc_impurity_cls(
|
|
125
|
+
calc_impurity_cls(const char* criterion, VALUE histogram, const long n_elements)
|
|
126
126
|
{
|
|
127
|
-
if (strcmp(
|
|
127
|
+
if (strcmp(criterion, "entropy") == 0) {
|
|
128
128
|
return calc_entropy(histogram, n_elements);
|
|
129
129
|
}
|
|
130
130
|
return calc_gini_coef(histogram, n_elements);
|
|
@@ -181,76 +181,107 @@ sub_sum_vec(VALUE sum_vec, VALUE target)
|
|
|
181
181
|
|
|
182
182
|
/**
|
|
183
183
|
* @!visibility private
|
|
184
|
-
* Find for split point with maximum information gain.
|
|
185
|
-
*
|
|
186
|
-
* @overload find_split_params(criterion, impurity, sorted_features, sorted_labels, n_classes) -> Array<Float>
|
|
187
|
-
*
|
|
188
|
-
* @param criterion [String] The function to evaluate spliting point. Supported criteria are 'gini' and 'entropy'.
|
|
189
|
-
* @param impurity [Float] The impurity of whole dataset.
|
|
190
|
-
* @param sorted_features [Numo::DFloat] (shape: [n_samples]) The feature values sorted in ascending order.
|
|
191
|
-
* @param sorted_labels [Numo::Int32] (shape: [n_labels]) The labels sorted according to feature values.
|
|
192
|
-
* @param n_classes [Integer] The number of classes.
|
|
193
|
-
* @return [Float] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
|
|
194
184
|
*/
|
|
195
|
-
|
|
196
|
-
|
|
185
|
+
typedef struct {
|
|
186
|
+
char* criterion;
|
|
187
|
+
long n_classes;
|
|
188
|
+
double impurity;
|
|
189
|
+
} split_opts_cls;
|
|
190
|
+
/**
|
|
191
|
+
* @!visibility private
|
|
192
|
+
*/
|
|
193
|
+
static void
|
|
194
|
+
iter_find_split_params_cls(na_loop_t const* lp)
|
|
197
195
|
{
|
|
198
|
-
const
|
|
199
|
-
const
|
|
200
|
-
const
|
|
201
|
-
long
|
|
196
|
+
const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
|
|
197
|
+
const double* f = (double*)NDL_PTR(lp, 1);
|
|
198
|
+
const int32_t* y = (int32_t*)NDL_PTR(lp, 2);
|
|
199
|
+
const long n_elements = NDL_SHAPE(lp, 0)[0];
|
|
200
|
+
const char* criterion = ((split_opts_cls*)lp->opt_ptr)->criterion;
|
|
201
|
+
const long n_classes = ((split_opts_cls*)lp->opt_ptr)->n_classes;
|
|
202
|
+
const double w_impurity = ((split_opts_cls*)lp->opt_ptr)->impurity;
|
|
203
|
+
double* params = (double*)NDL_PTR(lp, 3);
|
|
204
|
+
long i;
|
|
202
205
|
long curr_pos = 0;
|
|
203
206
|
long next_pos = 0;
|
|
204
207
|
long n_l_elements = 0;
|
|
205
208
|
long n_r_elements = n_elements;
|
|
206
|
-
double
|
|
207
|
-
double
|
|
209
|
+
double curr_el = f[o[0]];
|
|
210
|
+
double last_el = f[o[n_elements - 1]];
|
|
208
211
|
double next_el;
|
|
209
212
|
double l_impurity;
|
|
210
213
|
double r_impurity;
|
|
211
214
|
double gain;
|
|
212
215
|
VALUE l_histogram = create_zero_vector(n_classes);
|
|
213
216
|
VALUE r_histogram = create_zero_vector(n_classes);
|
|
214
|
-
VALUE opt_params = rb_ary_new2(4);
|
|
215
217
|
|
|
216
218
|
/* Initialize optimal parameters. */
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
219
|
+
params[0] = 0.0; /* left impurity */
|
|
220
|
+
params[1] = w_impurity; /* right impurity */
|
|
221
|
+
params[2] = curr_el; /* threshold */
|
|
222
|
+
params[3] = 0.0; /* gain */
|
|
221
223
|
|
|
222
224
|
/* Initialize child node variables. */
|
|
223
|
-
for (
|
|
224
|
-
increment_histogram(r_histogram,
|
|
225
|
+
for (i = 0; i < n_elements; i++) {
|
|
226
|
+
increment_histogram(r_histogram, y[o[i]]);
|
|
225
227
|
}
|
|
226
228
|
|
|
227
229
|
/* Find optimal parameters. */
|
|
228
230
|
while (curr_pos < n_elements && curr_el != last_el) {
|
|
229
|
-
next_el =
|
|
231
|
+
next_el = f[o[next_pos]];
|
|
230
232
|
while (next_pos < n_elements && next_el == curr_el) {
|
|
231
|
-
increment_histogram(l_histogram,
|
|
233
|
+
increment_histogram(l_histogram, y[o[next_pos]]);
|
|
232
234
|
n_l_elements++;
|
|
233
|
-
decrement_histogram(r_histogram,
|
|
235
|
+
decrement_histogram(r_histogram, y[o[next_pos]]);
|
|
234
236
|
n_r_elements--;
|
|
235
|
-
|
|
237
|
+
next_pos++;
|
|
238
|
+
next_el = f[o[next_pos]];
|
|
236
239
|
}
|
|
237
240
|
/* Calculate gain of new split. */
|
|
238
241
|
l_impurity = calc_impurity_cls(criterion, l_histogram, n_l_elements);
|
|
239
242
|
r_impurity = calc_impurity_cls(criterion, r_histogram, n_r_elements);
|
|
240
243
|
gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / n_elements;
|
|
241
244
|
/* Update optimal parameters. */
|
|
242
|
-
if (gain >
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
245
|
+
if (gain > params[3]) {
|
|
246
|
+
params[0] = l_impurity;
|
|
247
|
+
params[1] = r_impurity;
|
|
248
|
+
params[2] = 0.5 * (curr_el + next_el);
|
|
249
|
+
params[3] = gain;
|
|
247
250
|
}
|
|
248
251
|
if (next_pos == n_elements) break;
|
|
249
252
|
curr_pos = next_pos;
|
|
250
|
-
curr_el =
|
|
253
|
+
curr_el = f[o[curr_pos]];
|
|
251
254
|
}
|
|
252
|
-
|
|
253
|
-
|
|
255
|
+
}
|
|
256
|
+
/**
|
|
257
|
+
* @!visibility private
|
|
258
|
+
* Find for split point with maximum information gain.
|
|
259
|
+
*
|
|
260
|
+
* @overload find_split_params(criterion, impurity, order, features, labels, n_classes) -> Array<Float>
|
|
261
|
+
*
|
|
262
|
+
* @param criterion [String] The function to evaluate spliting point. Supported criteria are 'gini' and 'entropy'.
|
|
263
|
+
* @param impurity [Float] The impurity of whole dataset.
|
|
264
|
+
* @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
|
|
265
|
+
* @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
|
|
266
|
+
* @param labels [Numo::Int32] (shape: [n_elements]) The labels.
|
|
267
|
+
* @param n_classes [Integer] The number of classes.
|
|
268
|
+
* @return [Array<Float>] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
|
|
269
|
+
*/
|
|
270
|
+
static VALUE
|
|
271
|
+
find_split_params_cls(VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE labels, VALUE n_classes)
|
|
272
|
+
{
|
|
273
|
+
ndfunc_arg_in_t ain[3] = { {numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cInt32, 1} };
|
|
274
|
+
size_t out_shape[1] = { 4 };
|
|
275
|
+
ndfunc_arg_out_t aout[1] = { {numo_cDFloat, 1, out_shape} };
|
|
276
|
+
ndfunc_t ndf = { (na_iter_func_t)iter_find_split_params_cls, NO_LOOP, 3, 1, ain, aout };
|
|
277
|
+
split_opts_cls opts = { StringValuePtr(criterion), NUM2LONG(n_classes), NUM2DBL(impurity) };
|
|
278
|
+
VALUE params = na_ndloop3(&ndf, &opts, 3, order, features, labels);
|
|
279
|
+
VALUE results = rb_ary_new2(4);
|
|
280
|
+
rb_ary_store(results, 0, DBL2NUM(((double*)na_get_pointer_for_read(params))[0]));
|
|
281
|
+
rb_ary_store(results, 1, DBL2NUM(((double*)na_get_pointer_for_read(params))[1]));
|
|
282
|
+
rb_ary_store(results, 2, DBL2NUM(((double*)na_get_pointer_for_read(params))[2]));
|
|
283
|
+
rb_ary_store(results, 3, DBL2NUM(((double*)na_get_pointer_for_read(params))[3]));
|
|
284
|
+
return results;
|
|
254
285
|
}
|
|
255
286
|
|
|
256
287
|
/**
|
|
@@ -336,50 +367,40 @@ find_split_params_reg(VALUE self, VALUE criterion, VALUE whole_impurity, VALUE s
|
|
|
336
367
|
|
|
337
368
|
/**
|
|
338
369
|
* @!visibility private
|
|
339
|
-
* Find for split point with maximum information gain.
|
|
340
|
-
*
|
|
341
|
-
* @overload find_split_params(sorted_features, sorted_gradient, sorted_hessian, sum_gradient, sum_hessian) -> Array<Float>
|
|
342
|
-
*
|
|
343
|
-
* @param sorted_features [Array<Float>] (size: n_samples) The feature values sorted in ascending order.
|
|
344
|
-
* @param sorted_targets [Array<Float>] (size: n_samples) The target values sorted according to feature values.
|
|
345
|
-
* @param sorted_gradient [Array<Float>] (size: n_samples) The gradient values of loss function sorted according to feature values.
|
|
346
|
-
* @param sorted_hessian [Array<Float>] (size: n_samples) The hessian values of loss function sorted according to feature values.
|
|
347
|
-
* @param sum_gradient [Float] The sum of gradient values.
|
|
348
|
-
* @param sum_hessian [Float] The sum of hessian values.
|
|
349
|
-
* @param reg_lambda [Float] The L2 regularization term on weight.
|
|
350
|
-
* @return [Array<Float>] The array consists of optimal parameters including threshold and gain.
|
|
351
370
|
*/
|
|
352
|
-
static
|
|
353
|
-
|
|
354
|
-
(VALUE self, VALUE sorted_f, VALUE sorted_g, VALUE sorted_h, VALUE sum_g, VALUE sum_h, VALUE reg_l)
|
|
371
|
+
static void
|
|
372
|
+
iter_find_split_params_grad_reg(na_loop_t const* lp)
|
|
355
373
|
{
|
|
356
|
-
const
|
|
357
|
-
const double
|
|
358
|
-
const double
|
|
359
|
-
const double
|
|
374
|
+
const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
|
|
375
|
+
const double* f = (double*)NDL_PTR(lp, 1);
|
|
376
|
+
const double* g = (double*)NDL_PTR(lp, 2);
|
|
377
|
+
const double* h = (double*)NDL_PTR(lp, 3);
|
|
378
|
+
const double s_grad = ((double*)lp->opt_ptr)[0];
|
|
379
|
+
const double s_hess = ((double*)lp->opt_ptr)[1];
|
|
380
|
+
const double reg_lambda = ((double*)lp->opt_ptr)[2];
|
|
381
|
+
const long n_elements = NDL_SHAPE(lp, 0)[0];
|
|
382
|
+
double* params = (double*)NDL_PTR(lp, 4);
|
|
360
383
|
long curr_pos = 0;
|
|
361
384
|
long next_pos = 0;
|
|
362
|
-
double
|
|
363
|
-
double
|
|
385
|
+
double curr_el = f[o[0]];
|
|
386
|
+
double last_el = f[o[n_elements - 1]];
|
|
364
387
|
double next_el;
|
|
365
388
|
double l_grad = 0.0;
|
|
366
389
|
double l_hess = 0.0;
|
|
367
390
|
double r_grad;
|
|
368
391
|
double r_hess;
|
|
392
|
+
double threshold = curr_el;
|
|
393
|
+
double gain_max = 0.0;
|
|
369
394
|
double gain;
|
|
370
|
-
VALUE opt_params = rb_ary_new2(2);
|
|
371
|
-
|
|
372
|
-
/* Initialize optimal parameters. */
|
|
373
|
-
rb_ary_store(opt_params, 0, rb_ary_entry(sorted_f, 0)); /* threshold */
|
|
374
|
-
rb_ary_store(opt_params, 1, DBL2NUM(0)); /* gain */
|
|
375
395
|
|
|
376
396
|
/* Find optimal parameters. */
|
|
377
397
|
while (curr_pos < n_elements && curr_el != last_el) {
|
|
378
|
-
next_el =
|
|
398
|
+
next_el = f[o[next_pos]];
|
|
379
399
|
while (next_pos < n_elements && next_el == curr_el) {
|
|
380
|
-
l_grad +=
|
|
381
|
-
l_hess +=
|
|
382
|
-
|
|
400
|
+
l_grad += g[o[next_pos]];
|
|
401
|
+
l_hess += h[o[next_pos]];
|
|
402
|
+
next_pos++;
|
|
403
|
+
next_el = f[o[next_pos]];
|
|
383
404
|
}
|
|
384
405
|
/* Calculate gain of new split. */
|
|
385
406
|
r_grad = s_grad - l_grad;
|
|
@@ -388,16 +409,48 @@ find_split_params_grad_reg
|
|
|
388
409
|
(r_grad * r_grad) / (r_hess + reg_lambda) -
|
|
389
410
|
(s_grad * s_grad) / (s_hess + reg_lambda);
|
|
390
411
|
/* Update optimal parameters. */
|
|
391
|
-
if (gain >
|
|
392
|
-
|
|
393
|
-
|
|
412
|
+
if (gain > gain_max) {
|
|
413
|
+
threshold = 0.5 * (curr_el + next_el);
|
|
414
|
+
gain_max = gain;
|
|
394
415
|
}
|
|
395
416
|
if (next_pos == n_elements) break;
|
|
396
417
|
curr_pos = next_pos;
|
|
397
|
-
curr_el =
|
|
418
|
+
curr_el = f[o[curr_pos]];
|
|
398
419
|
}
|
|
399
420
|
|
|
400
|
-
|
|
421
|
+
params[0] = threshold;
|
|
422
|
+
params[1] = gain_max;
|
|
423
|
+
}
|
|
424
|
+
|
|
425
|
+
/**
|
|
426
|
+
* @!visibility private
|
|
427
|
+
* Find for split point with maximum information gain.
|
|
428
|
+
*
|
|
429
|
+
* @overload find_split_params(order, features, gradients, hessians, sum_gradient, sum_hessian, reg_lambda) -> Array<Float>
|
|
430
|
+
*
|
|
431
|
+
* @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
|
|
432
|
+
* @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
|
|
433
|
+
* @param gradients [Numo::DFloat] (shape: [n_elements]) The gradient values.
|
|
434
|
+
* @param hessians [Numo::DFloat] (shape: [n_elements]) The hessian values.
|
|
435
|
+
* @param sum_gradient [Float] The sum of gradient values.
|
|
436
|
+
* @param sum_hessian [Float] The sum of hessian values.
|
|
437
|
+
* @param reg_lambda [Float] The L2 regularization term on weight.
|
|
438
|
+
* @return [Array<Float>] The array consists of optimal parameters including threshold and gain.
|
|
439
|
+
*/
|
|
440
|
+
static VALUE
|
|
441
|
+
find_split_params_grad_reg
|
|
442
|
+
(VALUE self, VALUE order, VALUE features, VALUE gradients, VALUE hessians, VALUE sum_gradient, VALUE sum_hessian, VALUE reg_lambda)
|
|
443
|
+
{
|
|
444
|
+
ndfunc_arg_in_t ain[4] = { {numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 1} };
|
|
445
|
+
size_t out_shape[1] = { 2 };
|
|
446
|
+
ndfunc_arg_out_t aout[1] = { {numo_cDFloat, 1, out_shape} };
|
|
447
|
+
ndfunc_t ndf = { (na_iter_func_t)iter_find_split_params_grad_reg, NO_LOOP, 4, 1, ain, aout };
|
|
448
|
+
double opts[3] = { NUM2DBL(sum_gradient), NUM2DBL(sum_hessian), NUM2DBL(reg_lambda) };
|
|
449
|
+
VALUE params = na_ndloop3(&ndf, opts, 4, order, features, gradients, hessians);
|
|
450
|
+
VALUE results = rb_ary_new2(2);
|
|
451
|
+
rb_ary_store(results, 0, DBL2NUM(((double*)na_get_pointer_for_read(params))[0]));
|
|
452
|
+
rb_ary_store(results, 1, DBL2NUM(((double*)na_get_pointer_for_read(params))[1]));
|
|
453
|
+
return results;
|
|
401
454
|
}
|
|
402
455
|
|
|
403
456
|
/**
|
|
@@ -407,22 +460,24 @@ find_split_params_grad_reg
|
|
|
407
460
|
* @overload node_impurity(criterion, y, n_classes) -> Float
|
|
408
461
|
*
|
|
409
462
|
* @param criterion [String] The function to calculate impurity. Supported criteria are 'gini' and 'entropy'.
|
|
410
|
-
* @param
|
|
463
|
+
* @param y_nary [Numo::Int32] (shape: [n_samples]) The labels.
|
|
464
|
+
* @param n_elements_ [Integer] The number of elements.
|
|
411
465
|
* @param n_classes [Integer] The number of classes.
|
|
412
466
|
* @return [Float] impurity
|
|
413
467
|
*/
|
|
414
468
|
static VALUE
|
|
415
|
-
node_impurity_cls(VALUE self, VALUE criterion, VALUE
|
|
469
|
+
node_impurity_cls(VALUE self, VALUE criterion, VALUE y_nary, VALUE n_elements_, VALUE n_classes)
|
|
416
470
|
{
|
|
417
471
|
long i;
|
|
418
|
-
const long n_elements =
|
|
472
|
+
const long n_elements = NUM2LONG(n_elements_);
|
|
473
|
+
const int32_t* y = (int32_t*)na_get_pointer_for_read(y_nary);
|
|
419
474
|
VALUE histogram = create_zero_vector(NUM2LONG(n_classes));
|
|
420
475
|
|
|
421
476
|
for (i = 0; i < n_elements; i++) {
|
|
422
|
-
increment_histogram(histogram,
|
|
477
|
+
increment_histogram(histogram, y[i]);
|
|
423
478
|
}
|
|
424
479
|
|
|
425
|
-
return DBL2NUM(calc_impurity_cls(criterion, histogram, n_elements));
|
|
480
|
+
return DBL2NUM(calc_impurity_cls(StringValuePtr(criterion), histogram, n_elements));
|
|
426
481
|
}
|
|
427
482
|
|
|
428
483
|
/**
|
|
@@ -480,9 +535,9 @@ void Init_rumale(void)
|
|
|
480
535
|
*/
|
|
481
536
|
VALUE mExtGTreeReg = rb_define_module_under(mTree, "ExtGradientTreeRegressor");
|
|
482
537
|
|
|
483
|
-
rb_define_private_method(mExtDTreeCls, "find_split_params", find_split_params_cls,
|
|
538
|
+
rb_define_private_method(mExtDTreeCls, "find_split_params", find_split_params_cls, 6);
|
|
484
539
|
rb_define_private_method(mExtDTreeReg, "find_split_params", find_split_params_reg, 4);
|
|
485
|
-
rb_define_private_method(mExtGTreeReg, "find_split_params", find_split_params_grad_reg,
|
|
486
|
-
rb_define_private_method(mExtDTreeCls, "node_impurity", node_impurity_cls,
|
|
540
|
+
rb_define_private_method(mExtGTreeReg, "find_split_params", find_split_params_grad_reg, 7);
|
|
541
|
+
rb_define_private_method(mExtDTreeCls, "node_impurity", node_impurity_cls, 4);
|
|
487
542
|
rb_define_private_method(mExtDTreeReg, "node_impurity", node_impurity_reg, 2);
|
|
488
543
|
}
|
data/ext/rumale/rumale.h
CHANGED
data/lib/rumale/tree/decision_tree_classifier.rb
CHANGED
|
@@ -155,15 +155,14 @@ module Rumale
|
|
|
155
155
|
|
|
156
156
|
def best_split(features, y, whole_impurity)
|
|
157
157
|
order = features.sort_index
|
|
158
|
-
sorted_f = features[order].to_a
|
|
159
|
-
sorted_y = y[order, 0].to_a
|
|
160
158
|
n_classes = @classes.size
|
|
161
|
-
find_split_params(@params[:criterion], whole_impurity,
|
|
159
|
+
find_split_params(@params[:criterion], whole_impurity, order, features, y[true, 0], n_classes)
|
|
162
160
|
end
|
|
163
161
|
|
|
164
162
|
def impurity(y)
|
|
163
|
+
n_elements = y.shape[0]
|
|
165
164
|
n_classes = @classes.size
|
|
166
|
-
node_impurity(@params[:criterion], y[true, 0].
|
|
165
|
+
node_impurity(@params[:criterion], y[true, 0].dup, n_elements, n_classes)
|
|
167
166
|
end
|
|
168
167
|
end
|
|
169
168
|
end
|
data/lib/rumale/tree/gradient_tree_regressor.rb
CHANGED
|
@@ -214,12 +214,8 @@ module Rumale
|
|
|
214
214
|
node
|
|
215
215
|
end
|
|
216
216
|
|
|
217
|
-
def best_split(
|
|
218
|
-
|
|
219
|
-
sorted_f = features[order].to_a
|
|
220
|
-
sorted_g = g[order].to_a
|
|
221
|
-
sorted_h = h[order].to_a
|
|
222
|
-
find_split_params(sorted_f, sorted_g, sorted_h, sum_g, sum_h, @params[:reg_lambda])
|
|
217
|
+
def best_split(f, g, h, sum_g, sum_h)
|
|
218
|
+
find_split_params(f.sort_index, f, g, h, sum_g, sum_h, @params[:reg_lambda])
|
|
223
219
|
end
|
|
224
220
|
|
|
225
221
|
def rand_ids
|
data/lib/rumale/version.rb
CHANGED
metadata
CHANGED
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
|
2
2
|
name: rumale
|
|
3
3
|
version: !ruby/object:Gem::Version
|
|
4
|
-
version: 0.12.5
|
|
4
|
+
version: 0.12.6
|
|
5
5
|
platform: ruby
|
|
6
6
|
authors:
|
|
7
7
|
- yoshoku
|
|
8
8
|
autorequire:
|
|
9
9
|
bindir: exe
|
|
10
10
|
cert_chain: []
|
|
11
|
-
date: 2019-07-
|
|
11
|
+
date: 2019-07-13 00:00:00.000000000 Z
|
|
12
12
|
dependencies:
|
|
13
13
|
- !ruby/object:Gem::Dependency
|
|
14
14
|
name: numo-narray
|