rumale 0.23.3 → 0.24.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (142) hide show
  1. checksums.yaml +4 -4
  2. data/LICENSE.txt +5 -1
  3. data/README.md +3 -288
  4. data/lib/rumale/version.rb +1 -1
  5. data/lib/rumale.rb +20 -131
  6. metadata +252 -150
  7. data/CHANGELOG.md +0 -643
  8. data/CODE_OF_CONDUCT.md +0 -74
  9. data/ext/rumale/extconf.rb +0 -37
  10. data/ext/rumale/rumaleext.c +0 -545
  11. data/ext/rumale/rumaleext.h +0 -12
  12. data/lib/rumale/base/base_estimator.rb +0 -49
  13. data/lib/rumale/base/classifier.rb +0 -36
  14. data/lib/rumale/base/cluster_analyzer.rb +0 -31
  15. data/lib/rumale/base/evaluator.rb +0 -17
  16. data/lib/rumale/base/regressor.rb +0 -36
  17. data/lib/rumale/base/splitter.rb +0 -21
  18. data/lib/rumale/base/transformer.rb +0 -22
  19. data/lib/rumale/clustering/dbscan.rb +0 -123
  20. data/lib/rumale/clustering/gaussian_mixture.rb +0 -218
  21. data/lib/rumale/clustering/hdbscan.rb +0 -291
  22. data/lib/rumale/clustering/k_means.rb +0 -122
  23. data/lib/rumale/clustering/k_medoids.rb +0 -141
  24. data/lib/rumale/clustering/mini_batch_k_means.rb +0 -139
  25. data/lib/rumale/clustering/power_iteration.rb +0 -127
  26. data/lib/rumale/clustering/single_linkage.rb +0 -203
  27. data/lib/rumale/clustering/snn.rb +0 -76
  28. data/lib/rumale/clustering/spectral_clustering.rb +0 -115
  29. data/lib/rumale/dataset.rb +0 -246
  30. data/lib/rumale/decomposition/factor_analysis.rb +0 -150
  31. data/lib/rumale/decomposition/fast_ica.rb +0 -188
  32. data/lib/rumale/decomposition/nmf.rb +0 -124
  33. data/lib/rumale/decomposition/pca.rb +0 -159
  34. data/lib/rumale/ensemble/ada_boost_classifier.rb +0 -179
  35. data/lib/rumale/ensemble/ada_boost_regressor.rb +0 -160
  36. data/lib/rumale/ensemble/extra_trees_classifier.rb +0 -139
  37. data/lib/rumale/ensemble/extra_trees_regressor.rb +0 -125
  38. data/lib/rumale/ensemble/gradient_boosting_classifier.rb +0 -306
  39. data/lib/rumale/ensemble/gradient_boosting_regressor.rb +0 -237
  40. data/lib/rumale/ensemble/random_forest_classifier.rb +0 -189
  41. data/lib/rumale/ensemble/random_forest_regressor.rb +0 -153
  42. data/lib/rumale/ensemble/stacking_classifier.rb +0 -215
  43. data/lib/rumale/ensemble/stacking_regressor.rb +0 -163
  44. data/lib/rumale/ensemble/voting_classifier.rb +0 -126
  45. data/lib/rumale/ensemble/voting_regressor.rb +0 -82
  46. data/lib/rumale/evaluation_measure/accuracy.rb +0 -29
  47. data/lib/rumale/evaluation_measure/adjusted_rand_score.rb +0 -74
  48. data/lib/rumale/evaluation_measure/calinski_harabasz_score.rb +0 -56
  49. data/lib/rumale/evaluation_measure/davies_bouldin_score.rb +0 -53
  50. data/lib/rumale/evaluation_measure/explained_variance_score.rb +0 -39
  51. data/lib/rumale/evaluation_measure/f_score.rb +0 -50
  52. data/lib/rumale/evaluation_measure/function.rb +0 -147
  53. data/lib/rumale/evaluation_measure/log_loss.rb +0 -45
  54. data/lib/rumale/evaluation_measure/mean_absolute_error.rb +0 -29
  55. data/lib/rumale/evaluation_measure/mean_squared_error.rb +0 -29
  56. data/lib/rumale/evaluation_measure/mean_squared_log_error.rb +0 -29
  57. data/lib/rumale/evaluation_measure/median_absolute_error.rb +0 -30
  58. data/lib/rumale/evaluation_measure/mutual_information.rb +0 -49
  59. data/lib/rumale/evaluation_measure/normalized_mutual_information.rb +0 -53
  60. data/lib/rumale/evaluation_measure/precision.rb +0 -50
  61. data/lib/rumale/evaluation_measure/precision_recall.rb +0 -96
  62. data/lib/rumale/evaluation_measure/purity.rb +0 -40
  63. data/lib/rumale/evaluation_measure/r2_score.rb +0 -43
  64. data/lib/rumale/evaluation_measure/recall.rb +0 -50
  65. data/lib/rumale/evaluation_measure/roc_auc.rb +0 -130
  66. data/lib/rumale/evaluation_measure/silhouette_score.rb +0 -82
  67. data/lib/rumale/feature_extraction/feature_hasher.rb +0 -110
  68. data/lib/rumale/feature_extraction/hash_vectorizer.rb +0 -155
  69. data/lib/rumale/feature_extraction/tfidf_transformer.rb +0 -113
  70. data/lib/rumale/kernel_approximation/nystroem.rb +0 -126
  71. data/lib/rumale/kernel_approximation/rbf.rb +0 -102
  72. data/lib/rumale/kernel_machine/kernel_fda.rb +0 -120
  73. data/lib/rumale/kernel_machine/kernel_pca.rb +0 -97
  74. data/lib/rumale/kernel_machine/kernel_ridge.rb +0 -82
  75. data/lib/rumale/kernel_machine/kernel_ridge_classifier.rb +0 -92
  76. data/lib/rumale/kernel_machine/kernel_svc.rb +0 -193
  77. data/lib/rumale/linear_model/base_sgd.rb +0 -285
  78. data/lib/rumale/linear_model/elastic_net.rb +0 -119
  79. data/lib/rumale/linear_model/lasso.rb +0 -115
  80. data/lib/rumale/linear_model/linear_regression.rb +0 -201
  81. data/lib/rumale/linear_model/logistic_regression.rb +0 -275
  82. data/lib/rumale/linear_model/nnls.rb +0 -137
  83. data/lib/rumale/linear_model/ridge.rb +0 -209
  84. data/lib/rumale/linear_model/svc.rb +0 -213
  85. data/lib/rumale/linear_model/svr.rb +0 -132
  86. data/lib/rumale/manifold/mds.rb +0 -155
  87. data/lib/rumale/manifold/tsne.rb +0 -222
  88. data/lib/rumale/metric_learning/fisher_discriminant_analysis.rb +0 -113
  89. data/lib/rumale/metric_learning/mlkr.rb +0 -161
  90. data/lib/rumale/metric_learning/neighbourhood_component_analysis.rb +0 -167
  91. data/lib/rumale/model_selection/cross_validation.rb +0 -125
  92. data/lib/rumale/model_selection/function.rb +0 -42
  93. data/lib/rumale/model_selection/grid_search_cv.rb +0 -225
  94. data/lib/rumale/model_selection/group_k_fold.rb +0 -93
  95. data/lib/rumale/model_selection/group_shuffle_split.rb +0 -115
  96. data/lib/rumale/model_selection/k_fold.rb +0 -81
  97. data/lib/rumale/model_selection/shuffle_split.rb +0 -90
  98. data/lib/rumale/model_selection/stratified_k_fold.rb +0 -99
  99. data/lib/rumale/model_selection/stratified_shuffle_split.rb +0 -118
  100. data/lib/rumale/model_selection/time_series_split.rb +0 -91
  101. data/lib/rumale/multiclass/one_vs_rest_classifier.rb +0 -83
  102. data/lib/rumale/naive_bayes/base_naive_bayes.rb +0 -47
  103. data/lib/rumale/naive_bayes/bernoulli_nb.rb +0 -82
  104. data/lib/rumale/naive_bayes/complement_nb.rb +0 -85
  105. data/lib/rumale/naive_bayes/gaussian_nb.rb +0 -69
  106. data/lib/rumale/naive_bayes/multinomial_nb.rb +0 -74
  107. data/lib/rumale/naive_bayes/negation_nb.rb +0 -71
  108. data/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb +0 -133
  109. data/lib/rumale/nearest_neighbors/k_neighbors_regressor.rb +0 -108
  110. data/lib/rumale/nearest_neighbors/vp_tree.rb +0 -132
  111. data/lib/rumale/neural_network/adam.rb +0 -56
  112. data/lib/rumale/neural_network/base_mlp.rb +0 -248
  113. data/lib/rumale/neural_network/mlp_classifier.rb +0 -120
  114. data/lib/rumale/neural_network/mlp_regressor.rb +0 -90
  115. data/lib/rumale/pairwise_metric.rb +0 -152
  116. data/lib/rumale/pipeline/feature_union.rb +0 -69
  117. data/lib/rumale/pipeline/pipeline.rb +0 -175
  118. data/lib/rumale/preprocessing/bin_discretizer.rb +0 -93
  119. data/lib/rumale/preprocessing/binarizer.rb +0 -60
  120. data/lib/rumale/preprocessing/kernel_calculator.rb +0 -92
  121. data/lib/rumale/preprocessing/l1_normalizer.rb +0 -62
  122. data/lib/rumale/preprocessing/l2_normalizer.rb +0 -63
  123. data/lib/rumale/preprocessing/label_binarizer.rb +0 -89
  124. data/lib/rumale/preprocessing/label_encoder.rb +0 -79
  125. data/lib/rumale/preprocessing/max_abs_scaler.rb +0 -61
  126. data/lib/rumale/preprocessing/max_normalizer.rb +0 -62
  127. data/lib/rumale/preprocessing/min_max_scaler.rb +0 -76
  128. data/lib/rumale/preprocessing/one_hot_encoder.rb +0 -100
  129. data/lib/rumale/preprocessing/ordinal_encoder.rb +0 -109
  130. data/lib/rumale/preprocessing/polynomial_features.rb +0 -109
  131. data/lib/rumale/preprocessing/standard_scaler.rb +0 -71
  132. data/lib/rumale/probabilistic_output.rb +0 -114
  133. data/lib/rumale/tree/base_decision_tree.rb +0 -150
  134. data/lib/rumale/tree/decision_tree_classifier.rb +0 -150
  135. data/lib/rumale/tree/decision_tree_regressor.rb +0 -116
  136. data/lib/rumale/tree/extra_tree_classifier.rb +0 -107
  137. data/lib/rumale/tree/extra_tree_regressor.rb +0 -94
  138. data/lib/rumale/tree/gradient_tree_regressor.rb +0 -202
  139. data/lib/rumale/tree/node.rb +0 -39
  140. data/lib/rumale/utils.rb +0 -42
  141. data/lib/rumale/validation.rb +0 -128
  142. data/lib/rumale/values.rb +0 -13
data/CODE_OF_CONDUCT.md DELETED
@@ -1,74 +0,0 @@
1
- # Contributor Covenant Code of Conduct
2
-
3
- ## Our Pledge
4
-
5
- In the interest of fostering an open and welcoming environment, we as
6
- contributors and maintainers pledge to making participation in our project and
7
- our community a harassment-free experience for everyone, regardless of age, body
8
- size, disability, ethnicity, gender identity and expression, level of experience,
9
- nationality, personal appearance, race, religion, or sexual identity and
10
- orientation.
11
-
12
- ## Our Standards
13
-
14
- Examples of behavior that contributes to creating a positive environment
15
- include:
16
-
17
- * Using welcoming and inclusive language
18
- * Being respectful of differing viewpoints and experiences
19
- * Gracefully accepting constructive criticism
20
- * Focusing on what is best for the community
21
- * Showing empathy towards other community members
22
-
23
- Examples of unacceptable behavior by participants include:
24
-
25
- * The use of sexualized language or imagery and unwelcome sexual attention or
26
- advances
27
- * Trolling, insulting/derogatory comments, and personal or political attacks
28
- * Public or private harassment
29
- * Publishing others' private information, such as a physical or electronic
30
- address, without explicit permission
31
- * Other conduct which could reasonably be considered inappropriate in a
32
- professional setting
33
-
34
- ## Our Responsibilities
35
-
36
- Project maintainers are responsible for clarifying the standards of acceptable
37
- behavior and are expected to take appropriate and fair corrective action in
38
- response to any instances of unacceptable behavior.
39
-
40
- Project maintainers have the right and responsibility to remove, edit, or
41
- reject comments, commits, code, wiki edits, issues, and other contributions
42
- that are not aligned to this Code of Conduct, or to ban temporarily or
43
- permanently any contributor for other behaviors that they deem inappropriate,
44
- threatening, offensive, or harmful.
45
-
46
- ## Scope
47
-
48
- This Code of Conduct applies both within project spaces and in public spaces
49
- when an individual is representing the project or its community. Examples of
50
- representing a project or community include using an official project e-mail
51
- address, posting via an official social media account, or acting as an appointed
52
- representative at an online or offline event. Representation of a project may be
53
- further defined and clarified by project maintainers.
54
-
55
- ## Enforcement
56
-
57
- Instances of abusive, harassing, or otherwise unacceptable behavior may be
58
- reported by contacting the project team at yoshoku@outlook.com. All
59
- complaints will be reviewed and investigated and will result in a response that
60
- is deemed necessary and appropriate to the circumstances. The project team is
61
- obligated to maintain confidentiality with regard to the reporter of an incident.
62
- Further details of specific enforcement policies may be posted separately.
63
-
64
- Project maintainers who do not follow or enforce the Code of Conduct in good
65
- faith may face temporary or permanent repercussions as determined by other
66
- members of the project's leadership.
67
-
68
- ## Attribution
69
-
70
- This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71
- available at [http://contributor-covenant.org/version/1/4][version]
72
-
73
- [homepage]: http://contributor-covenant.org
74
- [version]: http://contributor-covenant.org/version/1/4/
data/ext/rumale/extconf.rb DELETED
@@ -1,37 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'mkmf'
4
- require 'numo/narray'
5
-
6
- $LOAD_PATH.each do |lp|
7
- if File.exist?(File.join(lp, 'numo/numo/narray.h'))
8
- $INCFLAGS = "-I#{lp}/numo #{$INCFLAGS}"
9
- break
10
- end
11
- end
12
-
13
- unless have_header('numo/narray.h')
14
- puts 'numo/narray.h not found.'
15
- exit(1)
16
- end
17
-
18
- if RUBY_PLATFORM =~ /mswin|cygwin|mingw/
19
- $LOAD_PATH.each do |lp|
20
- if File.exist?(File.join(lp, 'numo/libnarray.a'))
21
- $LDFLAGS = "-L#{lp}/numo #{$LDFLAGS}"
22
- break
23
- end
24
- end
25
- unless have_library('narray', 'nary_new')
26
- puts 'libnarray.a not found.'
27
- exit(1)
28
- end
29
- end
30
-
31
- if RUBY_PLATFORM.match?(/darwin/) && Gem::Version.new('3.1.0') <= Gem::Version.new(RUBY_VERSION)
32
- if try_link('int main(void){return 0;}', '-Wl,-undefined,dynamic_lookup')
33
- $LDFLAGS << ' -Wl,-undefined,dynamic_lookup'
34
- end
35
- end
36
-
37
- create_makefile('rumale/rumaleext')
data/ext/rumale/rumaleext.c DELETED
@@ -1,545 +0,0 @@
1
- #include "rumaleext.h"
2
-
3
- double* alloc_dbl_array(const long n_dimensions) {
4
- double* arr = ALLOC_N(double, n_dimensions);
5
- memset(arr, 0, n_dimensions * sizeof(double));
6
- return arr;
7
- }
8
-
9
- double calc_gini_coef(double* histogram, const long n_elements, const long n_classes) {
10
- long i;
11
- double el;
12
- double gini = 0.0;
13
-
14
- for (i = 0; i < n_classes; i++) {
15
- el = histogram[i] / n_elements;
16
- gini += el * el;
17
- }
18
-
19
- return 1.0 - gini;
20
- }
21
-
22
- double calc_entropy(double* histogram, const long n_elements, const long n_classes) {
23
- long i;
24
- double el;
25
- double entropy = 0.0;
26
-
27
- for (i = 0; i < n_classes; i++) {
28
- el = histogram[i] / n_elements;
29
- entropy += el * log(el + 1.0);
30
- }
31
-
32
- return -entropy;
33
- }
34
-
35
- VALUE
36
- calc_mean_vec(double* sum_vec, const long n_dimensions, const long n_elements) {
37
- long i;
38
- VALUE mean_vec = rb_ary_new2(n_dimensions);
39
-
40
- for (i = 0; i < n_dimensions; i++) {
41
- rb_ary_store(mean_vec, i, DBL2NUM(sum_vec[i] / n_elements));
42
- }
43
-
44
- return mean_vec;
45
- }
46
-
47
- double calc_vec_mae(VALUE vec_a, VALUE vec_b) {
48
- long i;
49
- const long n_dimensions = RARRAY_LEN(vec_a);
50
- double sum = 0.0;
51
- double diff;
52
-
53
- for (i = 0; i < n_dimensions; i++) {
54
- diff = NUM2DBL(rb_ary_entry(vec_a, i)) - NUM2DBL(rb_ary_entry(vec_b, i));
55
- sum += fabs(diff);
56
- }
57
-
58
- return sum / n_dimensions;
59
- }
60
-
61
- double calc_vec_mse(VALUE vec_a, VALUE vec_b) {
62
- long i;
63
- const long n_dimensions = RARRAY_LEN(vec_a);
64
- double sum = 0.0;
65
- double diff;
66
-
67
- for (i = 0; i < n_dimensions; i++) {
68
- diff = NUM2DBL(rb_ary_entry(vec_a, i)) - NUM2DBL(rb_ary_entry(vec_b, i));
69
- sum += diff * diff;
70
- }
71
-
72
- return sum / n_dimensions;
73
- }
74
-
75
- double calc_mae(VALUE target_vecs, VALUE mean_vec) {
76
- long i;
77
- const long n_elements = RARRAY_LEN(target_vecs);
78
- double sum = 0.0;
79
-
80
- for (i = 0; i < n_elements; i++) {
81
- sum += calc_vec_mae(rb_ary_entry(target_vecs, i), mean_vec);
82
- }
83
-
84
- return sum / n_elements;
85
- }
86
-
87
- double calc_mse(VALUE target_vecs, VALUE mean_vec) {
88
- long i;
89
- const long n_elements = RARRAY_LEN(target_vecs);
90
- double sum = 0.0;
91
-
92
- for (i = 0; i < n_elements; i++) {
93
- sum += calc_vec_mse(rb_ary_entry(target_vecs, i), mean_vec);
94
- }
95
-
96
- return sum / n_elements;
97
- }
98
-
99
- double calc_impurity_cls(const char* criterion, double* histogram, const long n_elements, const long n_classes) {
100
- if (strcmp(criterion, "entropy") == 0) {
101
- return calc_entropy(histogram, n_elements, n_classes);
102
- }
103
- return calc_gini_coef(histogram, n_elements, n_classes);
104
- }
105
-
106
- double calc_impurity_reg(const char* criterion, VALUE target_vecs, double* sum_vec) {
107
- const long n_elements = RARRAY_LEN(target_vecs);
108
- const long n_dimensions = RARRAY_LEN(rb_ary_entry(target_vecs, 0));
109
- VALUE mean_vec = calc_mean_vec(sum_vec, n_dimensions, n_elements);
110
-
111
- if (strcmp(criterion, "mae") == 0) {
112
- return calc_mae(target_vecs, mean_vec);
113
- }
114
- return calc_mse(target_vecs, mean_vec);
115
- }
116
-
117
- void add_sum_vec(double* sum_vec, VALUE target) {
118
- long i;
119
- const long n_dimensions = RARRAY_LEN(target);
120
-
121
- for (i = 0; i < n_dimensions; i++) {
122
- sum_vec[i] += NUM2DBL(rb_ary_entry(target, i));
123
- }
124
- }
125
-
126
- void sub_sum_vec(double* sum_vec, VALUE target) {
127
- long i;
128
- const long n_dimensions = RARRAY_LEN(target);
129
-
130
- for (i = 0; i < n_dimensions; i++) {
131
- sum_vec[i] -= NUM2DBL(rb_ary_entry(target, i));
132
- }
133
- }
134
-
135
- /**
136
- * @!visibility private
137
- */
138
- typedef struct {
139
- char* criterion;
140
- long n_classes;
141
- double impurity;
142
- } split_opts_cls;
143
- /**
144
- * @!visibility private
145
- */
146
- static void iter_find_split_params_cls(na_loop_t const* lp) {
147
- const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
148
- const double* f = (double*)NDL_PTR(lp, 1);
149
- const int32_t* y = (int32_t*)NDL_PTR(lp, 2);
150
- const long n_elements = NDL_SHAPE(lp, 0)[0];
151
- const char* criterion = ((split_opts_cls*)lp->opt_ptr)->criterion;
152
- const long n_classes = ((split_opts_cls*)lp->opt_ptr)->n_classes;
153
- const double w_impurity = ((split_opts_cls*)lp->opt_ptr)->impurity;
154
- double* params = (double*)NDL_PTR(lp, 3);
155
- long i;
156
- long curr_pos = 0;
157
- long next_pos = 0;
158
- long n_l_elements = 0;
159
- long n_r_elements = n_elements;
160
- double curr_el = f[o[0]];
161
- double last_el = f[o[n_elements - 1]];
162
- double next_el;
163
- double l_impurity;
164
- double r_impurity;
165
- double gain;
166
- double* l_histogram = alloc_dbl_array(n_classes);
167
- double* r_histogram = alloc_dbl_array(n_classes);
168
-
169
- /* Initialize optimal parameters. */
170
- params[0] = 0.0; /* left impurity */
171
- params[1] = w_impurity; /* right impurity */
172
- params[2] = curr_el; /* threshold */
173
- params[3] = 0.0; /* gain */
174
-
175
- /* Initialize child node variables. */
176
- for (i = 0; i < n_elements; i++) {
177
- r_histogram[y[o[i]]] += 1.0;
178
- }
179
-
180
- /* Find optimal parameters. */
181
- while (curr_pos < n_elements && curr_el != last_el) {
182
- next_el = f[o[next_pos]];
183
- while (next_pos < n_elements && next_el == curr_el) {
184
- l_histogram[y[o[next_pos]]] += 1;
185
- n_l_elements++;
186
- r_histogram[y[o[next_pos]]] -= 1;
187
- n_r_elements--;
188
- next_pos++;
189
- next_el = f[o[next_pos]];
190
- }
191
- /* Calculate gain of new split. */
192
- l_impurity = calc_impurity_cls(criterion, l_histogram, n_l_elements, n_classes);
193
- r_impurity = calc_impurity_cls(criterion, r_histogram, n_r_elements, n_classes);
194
- gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / n_elements;
195
- /* Update optimal parameters. */
196
- if (gain > params[3]) {
197
- params[0] = l_impurity;
198
- params[1] = r_impurity;
199
- params[2] = 0.5 * (curr_el + next_el);
200
- params[3] = gain;
201
- }
202
- if (next_pos == n_elements) break;
203
- curr_pos = next_pos;
204
- curr_el = f[o[curr_pos]];
205
- }
206
-
207
- xfree(l_histogram);
208
- xfree(r_histogram);
209
- }
210
- /**
211
- * @!visibility private
212
- * Find the split point with maximum information gain.
213
- *
214
- * @overload find_split_params(criterion, impurity, order, features, labels, n_classes) -> Array<Float>
215
- *
216
- * @param criterion [String] The function to evaluate the splitting point. Supported criteria are 'gini' and 'entropy'.
217
- * @param impurity [Float] The impurity of whole dataset.
218
- * @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
219
- * @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
220
- * @param labels [Numo::Int32] (shape: [n_elements]) The labels.
221
- * @param n_classes [Integer] The number of classes.
222
- * @return [Array<Float>] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
223
- */
224
- static VALUE find_split_params_cls(VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE labels,
225
- VALUE n_classes) {
226
- ndfunc_arg_in_t ain[3] = {{numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cInt32, 1}};
227
- size_t out_shape[1] = {4};
228
- ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 1, out_shape}};
229
- ndfunc_t ndf = {(na_iter_func_t)iter_find_split_params_cls, NO_LOOP, 3, 1, ain, aout};
230
- split_opts_cls opts = {StringValuePtr(criterion), NUM2LONG(n_classes), NUM2DBL(impurity)};
231
- VALUE params = na_ndloop3(&ndf, &opts, 3, order, features, labels);
232
- VALUE results = rb_ary_new2(4);
233
- double* params_ptr = (double*)na_get_pointer_for_read(params);
234
- rb_ary_store(results, 0, DBL2NUM(params_ptr[0]));
235
- rb_ary_store(results, 1, DBL2NUM(params_ptr[1]));
236
- rb_ary_store(results, 2, DBL2NUM(params_ptr[2]));
237
- rb_ary_store(results, 3, DBL2NUM(params_ptr[3]));
238
- RB_GC_GUARD(params);
239
- RB_GC_GUARD(criterion);
240
- return results;
241
- }
242
-
243
- /**
244
- * @!visibility private
245
- */
246
- typedef struct {
247
- char* criterion;
248
- double impurity;
249
- } split_opts_reg;
250
- /**
251
- * @!visibility private
252
- */
253
- static void iter_find_split_params_reg(na_loop_t const* lp) {
254
- const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
255
- const double* f = (double*)NDL_PTR(lp, 1);
256
- const double* y = (double*)NDL_PTR(lp, 2);
257
- const long n_elements = NDL_SHAPE(lp, 0)[0];
258
- const long n_outputs = NDL_SHAPE(lp, 2)[1];
259
- const char* criterion = ((split_opts_reg*)lp->opt_ptr)->criterion;
260
- const double w_impurity = ((split_opts_reg*)lp->opt_ptr)->impurity;
261
- double* params = (double*)NDL_PTR(lp, 3);
262
- long i, j;
263
- long curr_pos = 0;
264
- long next_pos = 0;
265
- long n_l_elements = 0;
266
- long n_r_elements = n_elements;
267
- double curr_el = f[o[0]];
268
- double last_el = f[o[n_elements - 1]];
269
- double next_el;
270
- double l_impurity;
271
- double r_impurity;
272
- double gain;
273
- double* l_sum_vec = alloc_dbl_array(n_outputs);
274
- double* r_sum_vec = alloc_dbl_array(n_outputs);
275
- double target_var;
276
- VALUE l_target_vecs = rb_ary_new();
277
- VALUE r_target_vecs = rb_ary_new();
278
- VALUE target;
279
-
280
- /* Initialize optimal parameters. */
281
- params[0] = 0.0; /* left impurity */
282
- params[1] = w_impurity; /* right impurity */
283
- params[2] = curr_el; /* threshold */
284
- params[3] = 0.0; /* gain */
285
-
286
- /* Initialize child node variables. */
287
- for (i = 0; i < n_elements; i++) {
288
- target = rb_ary_new2(n_outputs);
289
- for (j = 0; j < n_outputs; j++) {
290
- target_var = y[o[i] * n_outputs + j];
291
- rb_ary_store(target, j, DBL2NUM(target_var));
292
- r_sum_vec[j] += target_var;
293
- }
294
- rb_ary_push(r_target_vecs, target);
295
- }
296
-
297
- /* Find optimal parameters. */
298
- while (curr_pos < n_elements && curr_el != last_el) {
299
- next_el = f[o[next_pos]];
300
- while (next_pos < n_elements && next_el == curr_el) {
301
- target = rb_ary_shift(r_target_vecs);
302
- n_r_elements--;
303
- sub_sum_vec(r_sum_vec, target);
304
- rb_ary_push(l_target_vecs, target);
305
- n_l_elements++;
306
- add_sum_vec(l_sum_vec, target);
307
- next_pos++;
308
- next_el = f[o[next_pos]];
309
- }
310
- /* Calculate gain of new split. */
311
- l_impurity = calc_impurity_reg(criterion, l_target_vecs, l_sum_vec);
312
- r_impurity = calc_impurity_reg(criterion, r_target_vecs, r_sum_vec);
313
- gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / n_elements;
314
- /* Update optimal parameters. */
315
- if (gain > params[3]) {
316
- params[0] = l_impurity;
317
- params[1] = r_impurity;
318
- params[2] = 0.5 * (curr_el + next_el);
319
- params[3] = gain;
320
- }
321
- if (next_pos == n_elements) break;
322
- curr_pos = next_pos;
323
- curr_el = f[o[curr_pos]];
324
- }
325
-
326
- xfree(l_sum_vec);
327
- xfree(r_sum_vec);
328
- }
329
- /**
330
- * @!visibility private
331
- * Find the split point with maximum information gain.
332
- *
333
- * @overload find_split_params(criterion, impurity, order, features, targets) -> Array<Float>
334
- *
335
- * @param criterion [String] The function to evaluate the splitting point. Supported criteria are 'mae' and 'mse'.
336
- * @param impurity [Float] The impurity of whole dataset.
337
- * @param order [Numo::Int32] (shape: [n_samples]) The element indices sorted according to feature values in ascending order.
338
- * @param features [Numo::DFloat] (shape: [n_samples]) The feature values.
339
- * @param targets [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values.
340
- * @return [Array<Float>] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
341
- */
342
- static VALUE find_split_params_reg(VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE targets) {
343
- ndfunc_arg_in_t ain[3] = {{numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 2}};
344
- size_t out_shape[1] = {4};
345
- ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 1, out_shape}};
346
- ndfunc_t ndf = {(na_iter_func_t)iter_find_split_params_reg, NO_LOOP, 3, 1, ain, aout};
347
- split_opts_reg opts = {StringValuePtr(criterion), NUM2DBL(impurity)};
348
- VALUE params = na_ndloop3(&ndf, &opts, 3, order, features, targets);
349
- VALUE results = rb_ary_new2(4);
350
- double* params_ptr = (double*)na_get_pointer_for_read(params);
351
- rb_ary_store(results, 0, DBL2NUM(params_ptr[0]));
352
- rb_ary_store(results, 1, DBL2NUM(params_ptr[1]));
353
- rb_ary_store(results, 2, DBL2NUM(params_ptr[2]));
354
- rb_ary_store(results, 3, DBL2NUM(params_ptr[3]));
355
- RB_GC_GUARD(params);
356
- RB_GC_GUARD(criterion);
357
- return results;
358
- }
359
-
360
- /**
361
- * @!visibility private
362
- */
363
- static void iter_find_split_params_grad_reg(na_loop_t const* lp) {
364
- const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
365
- const double* f = (double*)NDL_PTR(lp, 1);
366
- const double* g = (double*)NDL_PTR(lp, 2);
367
- const double* h = (double*)NDL_PTR(lp, 3);
368
- const double s_grad = ((double*)lp->opt_ptr)[0];
369
- const double s_hess = ((double*)lp->opt_ptr)[1];
370
- const double reg_lambda = ((double*)lp->opt_ptr)[2];
371
- const long n_elements = NDL_SHAPE(lp, 0)[0];
372
- double* params = (double*)NDL_PTR(lp, 4);
373
- long curr_pos = 0;
374
- long next_pos = 0;
375
- double curr_el = f[o[0]];
376
- double last_el = f[o[n_elements - 1]];
377
- double next_el;
378
- double l_grad = 0.0;
379
- double l_hess = 0.0;
380
- double r_grad;
381
- double r_hess;
382
- double threshold = curr_el;
383
- double gain_max = 0.0;
384
- double gain;
385
-
386
- /* Find optimal parameters. */
387
- while (curr_pos < n_elements && curr_el != last_el) {
388
- next_el = f[o[next_pos]];
389
- while (next_pos < n_elements && next_el == curr_el) {
390
- l_grad += g[o[next_pos]];
391
- l_hess += h[o[next_pos]];
392
- next_pos++;
393
- next_el = f[o[next_pos]];
394
- }
395
- /* Calculate gain of new split. */
396
- r_grad = s_grad - l_grad;
397
- r_hess = s_hess - l_hess;
398
- gain = (l_grad * l_grad) / (l_hess + reg_lambda) + (r_grad * r_grad) / (r_hess + reg_lambda) -
399
- (s_grad * s_grad) / (s_hess + reg_lambda);
400
- /* Update optimal parameters. */
401
- if (gain > gain_max) {
402
- threshold = 0.5 * (curr_el + next_el);
403
- gain_max = gain;
404
- }
405
- if (next_pos == n_elements) {
406
- break;
407
- }
408
- curr_pos = next_pos;
409
- curr_el = f[o[curr_pos]];
410
- }
411
-
412
- params[0] = threshold;
413
- params[1] = gain_max;
414
- }
415
-
416
- /**
417
- * @!visibility private
418
- * Find the split point with maximum information gain.
419
- *
420
- * @overload find_split_params(order, features, gradients, hessians, sum_gradient, sum_hessian, reg_lambda) -> Array<Float>
421
- * @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
422
- * @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
423
- * @param gradients [Numo::DFloat] (shape: [n_elements]) The gradient values.
424
- * @param hessians [Numo::DFloat] (shape: [n_elements]) The hessian values.
425
- * @param sum_gradient [Float] The sum of gradient values.
426
- * @param sum_hessian [Float] The sum of hessian values.
427
- * @param reg_lambda [Float] The L2 regularization term on weight.
428
- * @return [Array<Float>] The array consists of optimal parameters including threshold and gain.
429
- */
430
- static VALUE find_split_params_grad_reg(VALUE self, VALUE order, VALUE features, VALUE gradients, VALUE hessians,
431
- VALUE sum_gradient, VALUE sum_hessian, VALUE reg_lambda) {
432
- ndfunc_arg_in_t ain[4] = {{numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 1}};
433
- size_t out_shape[1] = {2};
434
- ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 1, out_shape}};
435
- ndfunc_t ndf = {(na_iter_func_t)iter_find_split_params_grad_reg, NO_LOOP, 4, 1, ain, aout};
436
- double opts[3] = {NUM2DBL(sum_gradient), NUM2DBL(sum_hessian), NUM2DBL(reg_lambda)};
437
- VALUE params = na_ndloop3(&ndf, opts, 4, order, features, gradients, hessians);
438
- VALUE results = rb_ary_new2(2);
439
- double* params_ptr = (double*)na_get_pointer_for_read(params);
440
- rb_ary_store(results, 0, DBL2NUM(params_ptr[0]));
441
- rb_ary_store(results, 1, DBL2NUM(params_ptr[1]));
442
- RB_GC_GUARD(params);
443
- return results;
444
- }
445
-
446
- /**
447
- * @!visibility private
448
- * Calculate impurity based on criterion.
449
- *
450
- * @overload node_impurity(criterion, y_nary, n_elements_, n_classes_) -> Float
451
- *
452
- * @param criterion [String] The function to calculate impurity. Supported criteria are 'gini' and 'entropy'.
453
- * @param y_nary [Numo::Int32] (shape: [n_samples]) The labels.
454
- * @param n_elements_ [Integer] The number of elements.
455
- * @param n_classes_ [Integer] The number of classes.
456
- * @return [Float] impurity
457
- */
458
- static VALUE node_impurity_cls(VALUE self, VALUE criterion, VALUE y_nary, VALUE n_elements_, VALUE n_classes_) {
459
- long i;
460
- const long n_classes = NUM2LONG(n_classes_);
461
- const long n_elements = NUM2LONG(n_elements_);
462
- const int32_t* y = (int32_t*)na_get_pointer_for_read(y_nary);
463
- double* histogram = alloc_dbl_array(n_classes);
464
- VALUE ret;
465
-
466
- for (i = 0; i < n_elements; i++) {
467
- histogram[y[i]] += 1;
468
- }
469
-
470
- ret = DBL2NUM(calc_impurity_cls(StringValuePtr(criterion), histogram, n_elements, n_classes));
471
-
472
- xfree(histogram);
473
-
474
- RB_GC_GUARD(y_nary);
475
- RB_GC_GUARD(criterion);
476
-
477
- return ret;
478
- }
479
-
480
- /**
481
- * @!visibility private
482
- * Calculate impurity based on criterion.
483
- *
484
- * @overload node_impurity(criterion, y) -> Float
485
- *
486
- * @param criterion [String] The function to calculate impurity. Supported criteria are 'mae' and 'mse'.
487
- * @param y [Array<Float>] (shape: [n_samples, n_outputs]) The target values.
488
- * @return [Float] impurity
489
- */
490
- static VALUE node_impurity_reg(VALUE self, VALUE criterion, VALUE y) {
491
- long i;
492
- const long n_elements = RARRAY_LEN(y);
493
- const long n_outputs = RARRAY_LEN(rb_ary_entry(y, 0));
494
- double* sum_vec = alloc_dbl_array(n_outputs);
495
- VALUE target_vecs = rb_ary_new();
496
- VALUE target;
497
- VALUE ret;
498
-
499
- for (i = 0; i < n_elements; i++) {
500
- target = rb_ary_entry(y, i);
501
- add_sum_vec(sum_vec, target);
502
- rb_ary_push(target_vecs, target);
503
- }
504
-
505
- ret = DBL2NUM(calc_impurity_reg(StringValuePtr(criterion), target_vecs, sum_vec));
506
-
507
- xfree(sum_vec);
508
-
509
- RB_GC_GUARD(criterion);
510
-
511
- return ret;
512
- }
513
-
514
- void Init_rumaleext(void) {
515
- VALUE mRumale = rb_define_module("Rumale");
516
- VALUE mTree = rb_define_module_under(mRumale, "Tree");
517
-
518
- /**
519
- * Document-module: Rumale::Tree::ExtDecisionTreeClassifier
520
- * @!visibility private
521
- * The mixin module consisting of extension method for DecisionTreeClassifier class.
522
- * This module is used internally.
523
- */
524
- VALUE mExtDTreeCls = rb_define_module_under(mTree, "ExtDecisionTreeClassifier");
525
- /**
526
- * Document-module: Rumale::Tree::ExtDecisionTreeRegressor
527
- * @!visibility private
528
- * The mixin module consisting of extension method for DecisionTreeRegressor class.
529
- * This module is used internally.
530
- */
531
- VALUE mExtDTreeReg = rb_define_module_under(mTree, "ExtDecisionTreeRegressor");
532
- /**
533
- * Document-module: Rumale::Tree::ExtGradientTreeRegressor
534
- * @!visibility private
535
- * The mixin module consisting of extension method for GradientTreeRegressor class.
536
- * This module is used internally.
537
- */
538
- VALUE mExtGTreeReg = rb_define_module_under(mTree, "ExtGradientTreeRegressor");
539
-
540
- rb_define_private_method(mExtDTreeCls, "find_split_params", find_split_params_cls, 6);
541
- rb_define_private_method(mExtDTreeReg, "find_split_params", find_split_params_reg, 5);
542
- rb_define_private_method(mExtGTreeReg, "find_split_params", find_split_params_grad_reg, 7);
543
- rb_define_private_method(mExtDTreeCls, "node_impurity", node_impurity_cls, 4);
544
- rb_define_private_method(mExtDTreeReg, "node_impurity", node_impurity_reg, 2);
545
- }
data/ext/rumale/rumaleext.h DELETED
@@ -1,12 +0,0 @@
1
- #ifndef RUMALEEXT_H
2
- #define RUMALEEXT_H 1
3
-
4
- #include <math.h>
5
- #include <string.h>
6
-
7
- #include <ruby.h>
8
-
9
- #include <numo/narray.h>
10
- #include <numo/template.h>
11
-
12
- #endif /* RUMALEEXT_H */