rumale 0.23.3 → 0.24.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (142)
  1. checksums.yaml +4 -4
  2. data/LICENSE.txt +5 -1
  3. data/README.md +3 -288
  4. data/lib/rumale/version.rb +1 -1
  5. data/lib/rumale.rb +20 -131
  6. metadata +252 -150
  7. data/CHANGELOG.md +0 -643
  8. data/CODE_OF_CONDUCT.md +0 -74
  9. data/ext/rumale/extconf.rb +0 -37
  10. data/ext/rumale/rumaleext.c +0 -545
  11. data/ext/rumale/rumaleext.h +0 -12
  12. data/lib/rumale/base/base_estimator.rb +0 -49
  13. data/lib/rumale/base/classifier.rb +0 -36
  14. data/lib/rumale/base/cluster_analyzer.rb +0 -31
  15. data/lib/rumale/base/evaluator.rb +0 -17
  16. data/lib/rumale/base/regressor.rb +0 -36
  17. data/lib/rumale/base/splitter.rb +0 -21
  18. data/lib/rumale/base/transformer.rb +0 -22
  19. data/lib/rumale/clustering/dbscan.rb +0 -123
  20. data/lib/rumale/clustering/gaussian_mixture.rb +0 -218
  21. data/lib/rumale/clustering/hdbscan.rb +0 -291
  22. data/lib/rumale/clustering/k_means.rb +0 -122
  23. data/lib/rumale/clustering/k_medoids.rb +0 -141
  24. data/lib/rumale/clustering/mini_batch_k_means.rb +0 -139
  25. data/lib/rumale/clustering/power_iteration.rb +0 -127
  26. data/lib/rumale/clustering/single_linkage.rb +0 -203
  27. data/lib/rumale/clustering/snn.rb +0 -76
  28. data/lib/rumale/clustering/spectral_clustering.rb +0 -115
  29. data/lib/rumale/dataset.rb +0 -246
  30. data/lib/rumale/decomposition/factor_analysis.rb +0 -150
  31. data/lib/rumale/decomposition/fast_ica.rb +0 -188
  32. data/lib/rumale/decomposition/nmf.rb +0 -124
  33. data/lib/rumale/decomposition/pca.rb +0 -159
  34. data/lib/rumale/ensemble/ada_boost_classifier.rb +0 -179
  35. data/lib/rumale/ensemble/ada_boost_regressor.rb +0 -160
  36. data/lib/rumale/ensemble/extra_trees_classifier.rb +0 -139
  37. data/lib/rumale/ensemble/extra_trees_regressor.rb +0 -125
  38. data/lib/rumale/ensemble/gradient_boosting_classifier.rb +0 -306
  39. data/lib/rumale/ensemble/gradient_boosting_regressor.rb +0 -237
  40. data/lib/rumale/ensemble/random_forest_classifier.rb +0 -189
  41. data/lib/rumale/ensemble/random_forest_regressor.rb +0 -153
  42. data/lib/rumale/ensemble/stacking_classifier.rb +0 -215
  43. data/lib/rumale/ensemble/stacking_regressor.rb +0 -163
  44. data/lib/rumale/ensemble/voting_classifier.rb +0 -126
  45. data/lib/rumale/ensemble/voting_regressor.rb +0 -82
  46. data/lib/rumale/evaluation_measure/accuracy.rb +0 -29
  47. data/lib/rumale/evaluation_measure/adjusted_rand_score.rb +0 -74
  48. data/lib/rumale/evaluation_measure/calinski_harabasz_score.rb +0 -56
  49. data/lib/rumale/evaluation_measure/davies_bouldin_score.rb +0 -53
  50. data/lib/rumale/evaluation_measure/explained_variance_score.rb +0 -39
  51. data/lib/rumale/evaluation_measure/f_score.rb +0 -50
  52. data/lib/rumale/evaluation_measure/function.rb +0 -147
  53. data/lib/rumale/evaluation_measure/log_loss.rb +0 -45
  54. data/lib/rumale/evaluation_measure/mean_absolute_error.rb +0 -29
  55. data/lib/rumale/evaluation_measure/mean_squared_error.rb +0 -29
  56. data/lib/rumale/evaluation_measure/mean_squared_log_error.rb +0 -29
  57. data/lib/rumale/evaluation_measure/median_absolute_error.rb +0 -30
  58. data/lib/rumale/evaluation_measure/mutual_information.rb +0 -49
  59. data/lib/rumale/evaluation_measure/normalized_mutual_information.rb +0 -53
  60. data/lib/rumale/evaluation_measure/precision.rb +0 -50
  61. data/lib/rumale/evaluation_measure/precision_recall.rb +0 -96
  62. data/lib/rumale/evaluation_measure/purity.rb +0 -40
  63. data/lib/rumale/evaluation_measure/r2_score.rb +0 -43
  64. data/lib/rumale/evaluation_measure/recall.rb +0 -50
  65. data/lib/rumale/evaluation_measure/roc_auc.rb +0 -130
  66. data/lib/rumale/evaluation_measure/silhouette_score.rb +0 -82
  67. data/lib/rumale/feature_extraction/feature_hasher.rb +0 -110
  68. data/lib/rumale/feature_extraction/hash_vectorizer.rb +0 -155
  69. data/lib/rumale/feature_extraction/tfidf_transformer.rb +0 -113
  70. data/lib/rumale/kernel_approximation/nystroem.rb +0 -126
  71. data/lib/rumale/kernel_approximation/rbf.rb +0 -102
  72. data/lib/rumale/kernel_machine/kernel_fda.rb +0 -120
  73. data/lib/rumale/kernel_machine/kernel_pca.rb +0 -97
  74. data/lib/rumale/kernel_machine/kernel_ridge.rb +0 -82
  75. data/lib/rumale/kernel_machine/kernel_ridge_classifier.rb +0 -92
  76. data/lib/rumale/kernel_machine/kernel_svc.rb +0 -193
  77. data/lib/rumale/linear_model/base_sgd.rb +0 -285
  78. data/lib/rumale/linear_model/elastic_net.rb +0 -119
  79. data/lib/rumale/linear_model/lasso.rb +0 -115
  80. data/lib/rumale/linear_model/linear_regression.rb +0 -201
  81. data/lib/rumale/linear_model/logistic_regression.rb +0 -275
  82. data/lib/rumale/linear_model/nnls.rb +0 -137
  83. data/lib/rumale/linear_model/ridge.rb +0 -209
  84. data/lib/rumale/linear_model/svc.rb +0 -213
  85. data/lib/rumale/linear_model/svr.rb +0 -132
  86. data/lib/rumale/manifold/mds.rb +0 -155
  87. data/lib/rumale/manifold/tsne.rb +0 -222
  88. data/lib/rumale/metric_learning/fisher_discriminant_analysis.rb +0 -113
  89. data/lib/rumale/metric_learning/mlkr.rb +0 -161
  90. data/lib/rumale/metric_learning/neighbourhood_component_analysis.rb +0 -167
  91. data/lib/rumale/model_selection/cross_validation.rb +0 -125
  92. data/lib/rumale/model_selection/function.rb +0 -42
  93. data/lib/rumale/model_selection/grid_search_cv.rb +0 -225
  94. data/lib/rumale/model_selection/group_k_fold.rb +0 -93
  95. data/lib/rumale/model_selection/group_shuffle_split.rb +0 -115
  96. data/lib/rumale/model_selection/k_fold.rb +0 -81
  97. data/lib/rumale/model_selection/shuffle_split.rb +0 -90
  98. data/lib/rumale/model_selection/stratified_k_fold.rb +0 -99
  99. data/lib/rumale/model_selection/stratified_shuffle_split.rb +0 -118
  100. data/lib/rumale/model_selection/time_series_split.rb +0 -91
  101. data/lib/rumale/multiclass/one_vs_rest_classifier.rb +0 -83
  102. data/lib/rumale/naive_bayes/base_naive_bayes.rb +0 -47
  103. data/lib/rumale/naive_bayes/bernoulli_nb.rb +0 -82
  104. data/lib/rumale/naive_bayes/complement_nb.rb +0 -85
  105. data/lib/rumale/naive_bayes/gaussian_nb.rb +0 -69
  106. data/lib/rumale/naive_bayes/multinomial_nb.rb +0 -74
  107. data/lib/rumale/naive_bayes/negation_nb.rb +0 -71
  108. data/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb +0 -133
  109. data/lib/rumale/nearest_neighbors/k_neighbors_regressor.rb +0 -108
  110. data/lib/rumale/nearest_neighbors/vp_tree.rb +0 -132
  111. data/lib/rumale/neural_network/adam.rb +0 -56
  112. data/lib/rumale/neural_network/base_mlp.rb +0 -248
  113. data/lib/rumale/neural_network/mlp_classifier.rb +0 -120
  114. data/lib/rumale/neural_network/mlp_regressor.rb +0 -90
  115. data/lib/rumale/pairwise_metric.rb +0 -152
  116. data/lib/rumale/pipeline/feature_union.rb +0 -69
  117. data/lib/rumale/pipeline/pipeline.rb +0 -175
  118. data/lib/rumale/preprocessing/bin_discretizer.rb +0 -93
  119. data/lib/rumale/preprocessing/binarizer.rb +0 -60
  120. data/lib/rumale/preprocessing/kernel_calculator.rb +0 -92
  121. data/lib/rumale/preprocessing/l1_normalizer.rb +0 -62
  122. data/lib/rumale/preprocessing/l2_normalizer.rb +0 -63
  123. data/lib/rumale/preprocessing/label_binarizer.rb +0 -89
  124. data/lib/rumale/preprocessing/label_encoder.rb +0 -79
  125. data/lib/rumale/preprocessing/max_abs_scaler.rb +0 -61
  126. data/lib/rumale/preprocessing/max_normalizer.rb +0 -62
  127. data/lib/rumale/preprocessing/min_max_scaler.rb +0 -76
  128. data/lib/rumale/preprocessing/one_hot_encoder.rb +0 -100
  129. data/lib/rumale/preprocessing/ordinal_encoder.rb +0 -109
  130. data/lib/rumale/preprocessing/polynomial_features.rb +0 -109
  131. data/lib/rumale/preprocessing/standard_scaler.rb +0 -71
  132. data/lib/rumale/probabilistic_output.rb +0 -114
  133. data/lib/rumale/tree/base_decision_tree.rb +0 -150
  134. data/lib/rumale/tree/decision_tree_classifier.rb +0 -150
  135. data/lib/rumale/tree/decision_tree_regressor.rb +0 -116
  136. data/lib/rumale/tree/extra_tree_classifier.rb +0 -107
  137. data/lib/rumale/tree/extra_tree_regressor.rb +0 -94
  138. data/lib/rumale/tree/gradient_tree_regressor.rb +0 -202
  139. data/lib/rumale/tree/node.rb +0 -39
  140. data/lib/rumale/utils.rb +0 -42
  141. data/lib/rumale/validation.rb +0 -128
  142. data/lib/rumale/values.rb +0 -13
data/CODE_OF_CONDUCT.md DELETED
@@ -1,74 +0,0 @@
1
- # Contributor Covenant Code of Conduct
2
-
3
- ## Our Pledge
4
-
5
- In the interest of fostering an open and welcoming environment, we as
6
- contributors and maintainers pledge to making participation in our project and
7
- our community a harassment-free experience for everyone, regardless of age, body
8
- size, disability, ethnicity, gender identity and expression, level of experience,
9
- nationality, personal appearance, race, religion, or sexual identity and
10
- orientation.
11
-
12
- ## Our Standards
13
-
14
- Examples of behavior that contributes to creating a positive environment
15
- include:
16
-
17
- * Using welcoming and inclusive language
18
- * Being respectful of differing viewpoints and experiences
19
- * Gracefully accepting constructive criticism
20
- * Focusing on what is best for the community
21
- * Showing empathy towards other community members
22
-
23
- Examples of unacceptable behavior by participants include:
24
-
25
- * The use of sexualized language or imagery and unwelcome sexual attention or
26
- advances
27
- * Trolling, insulting/derogatory comments, and personal or political attacks
28
- * Public or private harassment
29
- * Publishing others' private information, such as a physical or electronic
30
- address, without explicit permission
31
- * Other conduct which could reasonably be considered inappropriate in a
32
- professional setting
33
-
34
- ## Our Responsibilities
35
-
36
- Project maintainers are responsible for clarifying the standards of acceptable
37
- behavior and are expected to take appropriate and fair corrective action in
38
- response to any instances of unacceptable behavior.
39
-
40
- Project maintainers have the right and responsibility to remove, edit, or
41
- reject comments, commits, code, wiki edits, issues, and other contributions
42
- that are not aligned to this Code of Conduct, or to ban temporarily or
43
- permanently any contributor for other behaviors that they deem inappropriate,
44
- threatening, offensive, or harmful.
45
-
46
- ## Scope
47
-
48
- This Code of Conduct applies both within project spaces and in public spaces
49
- when an individual is representing the project or its community. Examples of
50
- representing a project or community include using an official project e-mail
51
- address, posting via an official social media account, or acting as an appointed
52
- representative at an online or offline event. Representation of a project may be
53
- further defined and clarified by project maintainers.
54
-
55
- ## Enforcement
56
-
57
- Instances of abusive, harassing, or otherwise unacceptable behavior may be
58
- reported by contacting the project team at yoshoku@outlook.com. All
59
- complaints will be reviewed and investigated and will result in a response that
60
- is deemed necessary and appropriate to the circumstances. The project team is
61
- obligated to maintain confidentiality with regard to the reporter of an incident.
62
- Further details of specific enforcement policies may be posted separately.
63
-
64
- Project maintainers who do not follow or enforce the Code of Conduct in good
65
- faith may face temporary or permanent repercussions as determined by other
66
- members of the project's leadership.
67
-
68
- ## Attribution
69
-
70
- This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71
- available at [http://contributor-covenant.org/version/1/4][version]
72
-
73
- [homepage]: http://contributor-covenant.org
74
- [version]: http://contributor-covenant.org/version/1/4/
data/ext/rumale/extconf.rb DELETED
@@ -1,37 +0,0 @@
1
- # frozen_string_literal: true
2
-
3
- require 'mkmf'
4
- require 'numo/narray'
5
-
6
- $LOAD_PATH.each do |lp|
7
- if File.exist?(File.join(lp, 'numo/numo/narray.h'))
8
- $INCFLAGS = "-I#{lp}/numo #{$INCFLAGS}"
9
- break
10
- end
11
- end
12
-
13
- unless have_header('numo/narray.h')
14
- puts 'numo/narray.h not found.'
15
- exit(1)
16
- end
17
-
18
- if RUBY_PLATFORM =~ /mswin|cygwin|mingw/
19
- $LOAD_PATH.each do |lp|
20
- if File.exist?(File.join(lp, 'numo/libnarray.a'))
21
- $LDFLAGS = "-L#{lp}/numo #{$LDFLAGS}"
22
- break
23
- end
24
- end
25
- unless have_library('narray', 'nary_new')
26
- puts 'libnarray.a not found.'
27
- exit(1)
28
- end
29
- end
30
-
31
- if RUBY_PLATFORM.match?(/darwin/) && Gem::Version.new('3.1.0') <= Gem::Version.new(RUBY_VERSION)
32
- if try_link('int main(void){return 0;}', '-Wl,-undefined,dynamic_lookup')
33
- $LDFLAGS << ' -Wl,-undefined,dynamic_lookup'
34
- end
35
- end
36
-
37
- create_makefile('rumale/rumaleext')
data/ext/rumale/rumaleext.c DELETED
@@ -1,545 +0,0 @@
1
- #include "rumaleext.h"
2
-
3
- double* alloc_dbl_array(const long n_dimensions) {
4
- double* arr = ALLOC_N(double, n_dimensions);
5
- memset(arr, 0, n_dimensions * sizeof(double));
6
- return arr;
7
- }
8
-
9
- double calc_gini_coef(double* histogram, const long n_elements, const long n_classes) {
10
- long i;
11
- double el;
12
- double gini = 0.0;
13
-
14
- for (i = 0; i < n_classes; i++) {
15
- el = histogram[i] / n_elements;
16
- gini += el * el;
17
- }
18
-
19
- return 1.0 - gini;
20
- }
21
-
22
- double calc_entropy(double* histogram, const long n_elements, const long n_classes) {
23
- long i;
24
- double el;
25
- double entropy = 0.0;
26
-
27
- for (i = 0; i < n_classes; i++) {
28
- el = histogram[i] / n_elements;
29
- entropy += el * log(el + 1.0);
30
- }
31
-
32
- return -entropy;
33
- }
34
-
35
- VALUE
36
- calc_mean_vec(double* sum_vec, const long n_dimensions, const long n_elements) {
37
- long i;
38
- VALUE mean_vec = rb_ary_new2(n_dimensions);
39
-
40
- for (i = 0; i < n_dimensions; i++) {
41
- rb_ary_store(mean_vec, i, DBL2NUM(sum_vec[i] / n_elements));
42
- }
43
-
44
- return mean_vec;
45
- }
46
-
47
- double calc_vec_mae(VALUE vec_a, VALUE vec_b) {
48
- long i;
49
- const long n_dimensions = RARRAY_LEN(vec_a);
50
- double sum = 0.0;
51
- double diff;
52
-
53
- for (i = 0; i < n_dimensions; i++) {
54
- diff = NUM2DBL(rb_ary_entry(vec_a, i)) - NUM2DBL(rb_ary_entry(vec_b, i));
55
- sum += fabs(diff);
56
- }
57
-
58
- return sum / n_dimensions;
59
- }
60
-
61
- double calc_vec_mse(VALUE vec_a, VALUE vec_b) {
62
- long i;
63
- const long n_dimensions = RARRAY_LEN(vec_a);
64
- double sum = 0.0;
65
- double diff;
66
-
67
- for (i = 0; i < n_dimensions; i++) {
68
- diff = NUM2DBL(rb_ary_entry(vec_a, i)) - NUM2DBL(rb_ary_entry(vec_b, i));
69
- sum += diff * diff;
70
- }
71
-
72
- return sum / n_dimensions;
73
- }
74
-
75
- double calc_mae(VALUE target_vecs, VALUE mean_vec) {
76
- long i;
77
- const long n_elements = RARRAY_LEN(target_vecs);
78
- double sum = 0.0;
79
-
80
- for (i = 0; i < n_elements; i++) {
81
- sum += calc_vec_mae(rb_ary_entry(target_vecs, i), mean_vec);
82
- }
83
-
84
- return sum / n_elements;
85
- }
86
-
87
- double calc_mse(VALUE target_vecs, VALUE mean_vec) {
88
- long i;
89
- const long n_elements = RARRAY_LEN(target_vecs);
90
- double sum = 0.0;
91
-
92
- for (i = 0; i < n_elements; i++) {
93
- sum += calc_vec_mse(rb_ary_entry(target_vecs, i), mean_vec);
94
- }
95
-
96
- return sum / n_elements;
97
- }
98
-
99
- double calc_impurity_cls(const char* criterion, double* histogram, const long n_elements, const long n_classes) {
100
- if (strcmp(criterion, "entropy") == 0) {
101
- return calc_entropy(histogram, n_elements, n_classes);
102
- }
103
- return calc_gini_coef(histogram, n_elements, n_classes);
104
- }
105
-
106
- double calc_impurity_reg(const char* criterion, VALUE target_vecs, double* sum_vec) {
107
- const long n_elements = RARRAY_LEN(target_vecs);
108
- const long n_dimensions = RARRAY_LEN(rb_ary_entry(target_vecs, 0));
109
- VALUE mean_vec = calc_mean_vec(sum_vec, n_dimensions, n_elements);
110
-
111
- if (strcmp(criterion, "mae") == 0) {
112
- return calc_mae(target_vecs, mean_vec);
113
- }
114
- return calc_mse(target_vecs, mean_vec);
115
- }
116
-
117
- void add_sum_vec(double* sum_vec, VALUE target) {
118
- long i;
119
- const long n_dimensions = RARRAY_LEN(target);
120
-
121
- for (i = 0; i < n_dimensions; i++) {
122
- sum_vec[i] += NUM2DBL(rb_ary_entry(target, i));
123
- }
124
- }
125
-
126
- void sub_sum_vec(double* sum_vec, VALUE target) {
127
- long i;
128
- const long n_dimensions = RARRAY_LEN(target);
129
-
130
- for (i = 0; i < n_dimensions; i++) {
131
- sum_vec[i] -= NUM2DBL(rb_ary_entry(target, i));
132
- }
133
- }
134
-
135
- /**
136
- * @!visibility private
137
- */
138
- typedef struct {
139
- char* criterion;
140
- long n_classes;
141
- double impurity;
142
- } split_opts_cls;
143
- /**
144
- * @!visibility private
145
- */
146
- static void iter_find_split_params_cls(na_loop_t const* lp) {
147
- const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
148
- const double* f = (double*)NDL_PTR(lp, 1);
149
- const int32_t* y = (int32_t*)NDL_PTR(lp, 2);
150
- const long n_elements = NDL_SHAPE(lp, 0)[0];
151
- const char* criterion = ((split_opts_cls*)lp->opt_ptr)->criterion;
152
- const long n_classes = ((split_opts_cls*)lp->opt_ptr)->n_classes;
153
- const double w_impurity = ((split_opts_cls*)lp->opt_ptr)->impurity;
154
- double* params = (double*)NDL_PTR(lp, 3);
155
- long i;
156
- long curr_pos = 0;
157
- long next_pos = 0;
158
- long n_l_elements = 0;
159
- long n_r_elements = n_elements;
160
- double curr_el = f[o[0]];
161
- double last_el = f[o[n_elements - 1]];
162
- double next_el;
163
- double l_impurity;
164
- double r_impurity;
165
- double gain;
166
- double* l_histogram = alloc_dbl_array(n_classes);
167
- double* r_histogram = alloc_dbl_array(n_classes);
168
-
169
- /* Initialize optimal parameters. */
170
- params[0] = 0.0; /* left impurity */
171
- params[1] = w_impurity; /* right impurity */
172
- params[2] = curr_el; /* threshold */
173
- params[3] = 0.0; /* gain */
174
-
175
- /* Initialize child node variables. */
176
- for (i = 0; i < n_elements; i++) {
177
- r_histogram[y[o[i]]] += 1.0;
178
- }
179
-
180
- /* Find optimal parameters. */
181
- while (curr_pos < n_elements && curr_el != last_el) {
182
- next_el = f[o[next_pos]];
183
- while (next_pos < n_elements && next_el == curr_el) {
184
- l_histogram[y[o[next_pos]]] += 1;
185
- n_l_elements++;
186
- r_histogram[y[o[next_pos]]] -= 1;
187
- n_r_elements--;
188
- next_pos++;
189
- next_el = f[o[next_pos]];
190
- }
191
- /* Calculate gain of new split. */
192
- l_impurity = calc_impurity_cls(criterion, l_histogram, n_l_elements, n_classes);
193
- r_impurity = calc_impurity_cls(criterion, r_histogram, n_r_elements, n_classes);
194
- gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / n_elements;
195
- /* Update optimal parameters. */
196
- if (gain > params[3]) {
197
- params[0] = l_impurity;
198
- params[1] = r_impurity;
199
- params[2] = 0.5 * (curr_el + next_el);
200
- params[3] = gain;
201
- }
202
- if (next_pos == n_elements) break;
203
- curr_pos = next_pos;
204
- curr_el = f[o[curr_pos]];
205
- }
206
-
207
- xfree(l_histogram);
208
- xfree(r_histogram);
209
- }
210
- /**
211
- * @!visibility private
212
- * Find for split point with maximum information gain.
213
- *
214
- * @overload find_split_params(criterion, impurity, order, features, labels, n_classes) -> Array<Float>
215
- *
216
- * @param criterion [String] The function to evaluate spliting point. Supported criteria are 'gini' and 'entropy'.
217
- * @param impurity [Float] The impurity of whole dataset.
218
- * @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
219
- * @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
220
- * @param labels [Numo::Int32] (shape: [n_elements]) The labels.
221
- * @param n_classes [Integer] The number of classes.
222
- * @return [Array<Float>] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
223
- */
224
- static VALUE find_split_params_cls(VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE labels,
225
- VALUE n_classes) {
226
- ndfunc_arg_in_t ain[3] = {{numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cInt32, 1}};
227
- size_t out_shape[1] = {4};
228
- ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 1, out_shape}};
229
- ndfunc_t ndf = {(na_iter_func_t)iter_find_split_params_cls, NO_LOOP, 3, 1, ain, aout};
230
- split_opts_cls opts = {StringValuePtr(criterion), NUM2LONG(n_classes), NUM2DBL(impurity)};
231
- VALUE params = na_ndloop3(&ndf, &opts, 3, order, features, labels);
232
- VALUE results = rb_ary_new2(4);
233
- double* params_ptr = (double*)na_get_pointer_for_read(params);
234
- rb_ary_store(results, 0, DBL2NUM(params_ptr[0]));
235
- rb_ary_store(results, 1, DBL2NUM(params_ptr[1]));
236
- rb_ary_store(results, 2, DBL2NUM(params_ptr[2]));
237
- rb_ary_store(results, 3, DBL2NUM(params_ptr[3]));
238
- RB_GC_GUARD(params);
239
- RB_GC_GUARD(criterion);
240
- return results;
241
- }
242
-
243
- /**
244
- * @!visibility private
245
- */
246
- typedef struct {
247
- char* criterion;
248
- double impurity;
249
- } split_opts_reg;
250
- /**
251
- * @!visibility private
252
- */
253
- static void iter_find_split_params_reg(na_loop_t const* lp) {
254
- const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
255
- const double* f = (double*)NDL_PTR(lp, 1);
256
- const double* y = (double*)NDL_PTR(lp, 2);
257
- const long n_elements = NDL_SHAPE(lp, 0)[0];
258
- const long n_outputs = NDL_SHAPE(lp, 2)[1];
259
- const char* criterion = ((split_opts_reg*)lp->opt_ptr)->criterion;
260
- const double w_impurity = ((split_opts_reg*)lp->opt_ptr)->impurity;
261
- double* params = (double*)NDL_PTR(lp, 3);
262
- long i, j;
263
- long curr_pos = 0;
264
- long next_pos = 0;
265
- long n_l_elements = 0;
266
- long n_r_elements = n_elements;
267
- double curr_el = f[o[0]];
268
- double last_el = f[o[n_elements - 1]];
269
- double next_el;
270
- double l_impurity;
271
- double r_impurity;
272
- double gain;
273
- double* l_sum_vec = alloc_dbl_array(n_outputs);
274
- double* r_sum_vec = alloc_dbl_array(n_outputs);
275
- double target_var;
276
- VALUE l_target_vecs = rb_ary_new();
277
- VALUE r_target_vecs = rb_ary_new();
278
- VALUE target;
279
-
280
- /* Initialize optimal parameters. */
281
- params[0] = 0.0; /* left impurity */
282
- params[1] = w_impurity; /* right impurity */
283
- params[2] = curr_el; /* threshold */
284
- params[3] = 0.0; /* gain */
285
-
286
- /* Initialize child node variables. */
287
- for (i = 0; i < n_elements; i++) {
288
- target = rb_ary_new2(n_outputs);
289
- for (j = 0; j < n_outputs; j++) {
290
- target_var = y[o[i] * n_outputs + j];
291
- rb_ary_store(target, j, DBL2NUM(target_var));
292
- r_sum_vec[j] += target_var;
293
- }
294
- rb_ary_push(r_target_vecs, target);
295
- }
296
-
297
- /* Find optimal parameters. */
298
- while (curr_pos < n_elements && curr_el != last_el) {
299
- next_el = f[o[next_pos]];
300
- while (next_pos < n_elements && next_el == curr_el) {
301
- target = rb_ary_shift(r_target_vecs);
302
- n_r_elements--;
303
- sub_sum_vec(r_sum_vec, target);
304
- rb_ary_push(l_target_vecs, target);
305
- n_l_elements++;
306
- add_sum_vec(l_sum_vec, target);
307
- next_pos++;
308
- next_el = f[o[next_pos]];
309
- }
310
- /* Calculate gain of new split. */
311
- l_impurity = calc_impurity_reg(criterion, l_target_vecs, l_sum_vec);
312
- r_impurity = calc_impurity_reg(criterion, r_target_vecs, r_sum_vec);
313
- gain = w_impurity - (n_l_elements * l_impurity + n_r_elements * r_impurity) / n_elements;
314
- /* Update optimal parameters. */
315
- if (gain > params[3]) {
316
- params[0] = l_impurity;
317
- params[1] = r_impurity;
318
- params[2] = 0.5 * (curr_el + next_el);
319
- params[3] = gain;
320
- }
321
- if (next_pos == n_elements) break;
322
- curr_pos = next_pos;
323
- curr_el = f[o[curr_pos]];
324
- }
325
-
326
- xfree(l_sum_vec);
327
- xfree(r_sum_vec);
328
- }
329
- /**
330
- * @!visibility private
331
- * Find for split point with maximum information gain.
332
- *
333
- * @overload find_split_params(criterion, impurity, order, features, targets) -> Array<Float>
334
- *
335
- * @param criterion [String] The function to evaluate spliting point. Supported criteria are 'mae' and 'mse'.
336
- * @param impurity [Float] The impurity of whole dataset.
337
- * @param order [Numo::Int32] (shape: [n_samples]) The element indices sorted according to feature values in ascending order.
338
- * @param features [Numo::DFloat] (shape: [n_samples]) The feature values.
339
- * @param targets [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values.
340
- * @return [Array<Float>] The array consists of optimal parameters including impurities of child nodes, threshold, and gain.
341
- */
342
- static VALUE find_split_params_reg(VALUE self, VALUE criterion, VALUE impurity, VALUE order, VALUE features, VALUE targets) {
343
- ndfunc_arg_in_t ain[3] = {{numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 2}};
344
- size_t out_shape[1] = {4};
345
- ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 1, out_shape}};
346
- ndfunc_t ndf = {(na_iter_func_t)iter_find_split_params_reg, NO_LOOP, 3, 1, ain, aout};
347
- split_opts_reg opts = {StringValuePtr(criterion), NUM2DBL(impurity)};
348
- VALUE params = na_ndloop3(&ndf, &opts, 3, order, features, targets);
349
- VALUE results = rb_ary_new2(4);
350
- double* params_ptr = (double*)na_get_pointer_for_read(params);
351
- rb_ary_store(results, 0, DBL2NUM(params_ptr[0]));
352
- rb_ary_store(results, 1, DBL2NUM(params_ptr[1]));
353
- rb_ary_store(results, 2, DBL2NUM(params_ptr[2]));
354
- rb_ary_store(results, 3, DBL2NUM(params_ptr[3]));
355
- RB_GC_GUARD(params);
356
- RB_GC_GUARD(criterion);
357
- return results;
358
- }
359
-
360
- /**
361
- * @!visibility private
362
- */
363
- static void iter_find_split_params_grad_reg(na_loop_t const* lp) {
364
- const int32_t* o = (int32_t*)NDL_PTR(lp, 0);
365
- const double* f = (double*)NDL_PTR(lp, 1);
366
- const double* g = (double*)NDL_PTR(lp, 2);
367
- const double* h = (double*)NDL_PTR(lp, 3);
368
- const double s_grad = ((double*)lp->opt_ptr)[0];
369
- const double s_hess = ((double*)lp->opt_ptr)[1];
370
- const double reg_lambda = ((double*)lp->opt_ptr)[2];
371
- const long n_elements = NDL_SHAPE(lp, 0)[0];
372
- double* params = (double*)NDL_PTR(lp, 4);
373
- long curr_pos = 0;
374
- long next_pos = 0;
375
- double curr_el = f[o[0]];
376
- double last_el = f[o[n_elements - 1]];
377
- double next_el;
378
- double l_grad = 0.0;
379
- double l_hess = 0.0;
380
- double r_grad;
381
- double r_hess;
382
- double threshold = curr_el;
383
- double gain_max = 0.0;
384
- double gain;
385
-
386
- /* Find optimal parameters. */
387
- while (curr_pos < n_elements && curr_el != last_el) {
388
- next_el = f[o[next_pos]];
389
- while (next_pos < n_elements && next_el == curr_el) {
390
- l_grad += g[o[next_pos]];
391
- l_hess += h[o[next_pos]];
392
- next_pos++;
393
- next_el = f[o[next_pos]];
394
- }
395
- /* Calculate gain of new split. */
396
- r_grad = s_grad - l_grad;
397
- r_hess = s_hess - l_hess;
398
- gain = (l_grad * l_grad) / (l_hess + reg_lambda) + (r_grad * r_grad) / (r_hess + reg_lambda) -
399
- (s_grad * s_grad) / (s_hess + reg_lambda);
400
- /* Update optimal parameters. */
401
- if (gain > gain_max) {
402
- threshold = 0.5 * (curr_el + next_el);
403
- gain_max = gain;
404
- }
405
- if (next_pos == n_elements) {
406
- break;
407
- }
408
- curr_pos = next_pos;
409
- curr_el = f[o[curr_pos]];
410
- }
411
-
412
- params[0] = threshold;
413
- params[1] = gain_max;
414
- }
415
-
416
- /**
417
- * @!visibility private
418
- * Find for split point with maximum information gain.
419
- *
420
- * @overload find_split_params(order, features, gradients, hessians, sum_gradient, sum_hessian, reg_lambda) -> Array<Float>
421
- * @param order [Numo::Int32] (shape: [n_elements]) The element indices sorted according to feature values.
422
- * @param features [Numo::DFloat] (shape: [n_elements]) The feature values.
423
- * @param gradients [Numo::DFloat] (shape: [n_elements]) The gradient values.
424
- * @param hessians [Numo::DFloat] (shape: [n_elements]) The hessian values.
425
- * @param sum_gradient [Float] The sum of gradient values.
426
- * @param sum_hessian [Float] The sum of hessian values.
427
- * @param reg_lambda [Float] The L2 regularization term on weight.
428
- * @return [Array<Float>] The array consists of optimal parameters including threshold and gain.
429
- */
430
- static VALUE find_split_params_grad_reg(VALUE self, VALUE order, VALUE features, VALUE gradients, VALUE hessians,
431
- VALUE sum_gradient, VALUE sum_hessian, VALUE reg_lambda) {
432
- ndfunc_arg_in_t ain[4] = {{numo_cInt32, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 1}, {numo_cDFloat, 1}};
433
- size_t out_shape[1] = {2};
434
- ndfunc_arg_out_t aout[1] = {{numo_cDFloat, 1, out_shape}};
435
- ndfunc_t ndf = {(na_iter_func_t)iter_find_split_params_grad_reg, NO_LOOP, 4, 1, ain, aout};
436
- double opts[3] = {NUM2DBL(sum_gradient), NUM2DBL(sum_hessian), NUM2DBL(reg_lambda)};
437
- VALUE params = na_ndloop3(&ndf, opts, 4, order, features, gradients, hessians);
438
- VALUE results = rb_ary_new2(2);
439
- double* params_ptr = (double*)na_get_pointer_for_read(params);
440
- rb_ary_store(results, 0, DBL2NUM(params_ptr[0]));
441
- rb_ary_store(results, 1, DBL2NUM(params_ptr[1]));
442
- RB_GC_GUARD(params);
443
- return results;
444
- }
445
-
446
- /**
447
- * @!visibility private
448
- * Calculate impurity based on criterion.
449
- *
450
- * @overload node_impurity(criterion, y, n_classes) -> Float
451
- *
452
- * @param criterion [String] The function to calculate impurity. Supported criteria are 'gini' and 'entropy'.
453
- * @param y_nary [Numo::Int32] (shape: [n_samples]) The labels.
454
- * @param n_elements_ [Integer] The number of elements.
455
- * @param n_classes_ [Integer] The number of classes.
456
- * @return [Float] impurity
457
- */
458
- static VALUE node_impurity_cls(VALUE self, VALUE criterion, VALUE y_nary, VALUE n_elements_, VALUE n_classes_) {
459
- long i;
460
- const long n_classes = NUM2LONG(n_classes_);
461
- const long n_elements = NUM2LONG(n_elements_);
462
- const int32_t* y = (int32_t*)na_get_pointer_for_read(y_nary);
463
- double* histogram = alloc_dbl_array(n_classes);
464
- VALUE ret;
465
-
466
- for (i = 0; i < n_elements; i++) {
467
- histogram[y[i]] += 1;
468
- }
469
-
470
- ret = DBL2NUM(calc_impurity_cls(StringValuePtr(criterion), histogram, n_elements, n_classes));
471
-
472
- xfree(histogram);
473
-
474
- RB_GC_GUARD(y_nary);
475
- RB_GC_GUARD(criterion);
476
-
477
- return ret;
478
- }
479
-
480
- /**
481
- * @!visibility private
482
- * Calculate impurity based on criterion.
483
- *
484
- * @overload node_impurity(criterion, y) -> Float
485
- *
486
- * @param criterion [String] The function to calculate impurity. Supported criteria are 'mae' and 'mse'.
487
- * @param y [Array<Float>] (shape: [n_samples, n_outputs]) The taget values.
488
- * @return [Float] impurity
489
- */
490
- static VALUE node_impurity_reg(VALUE self, VALUE criterion, VALUE y) {
491
- long i;
492
- const long n_elements = RARRAY_LEN(y);
493
- const long n_outputs = RARRAY_LEN(rb_ary_entry(y, 0));
494
- double* sum_vec = alloc_dbl_array(n_outputs);
495
- VALUE target_vecs = rb_ary_new();
496
- VALUE target;
497
- VALUE ret;
498
-
499
- for (i = 0; i < n_elements; i++) {
500
- target = rb_ary_entry(y, i);
501
- add_sum_vec(sum_vec, target);
502
- rb_ary_push(target_vecs, target);
503
- }
504
-
505
- ret = DBL2NUM(calc_impurity_reg(StringValuePtr(criterion), target_vecs, sum_vec));
506
-
507
- xfree(sum_vec);
508
-
509
- RB_GC_GUARD(criterion);
510
-
511
- return ret;
512
- }
513
-
514
- void Init_rumaleext(void) {
515
- VALUE mRumale = rb_define_module("Rumale");
516
- VALUE mTree = rb_define_module_under(mRumale, "Tree");
517
-
518
- /**
519
- * Document-module: Rumale::Tree::ExtDecisionTreeClassifier
520
- * @!visibility private
521
- * The mixin module consisting of extension method for DecisionTreeClassifier class.
522
- * This module is used internally.
523
- */
524
- VALUE mExtDTreeCls = rb_define_module_under(mTree, "ExtDecisionTreeClassifier");
525
- /**
526
- * Document-module: Rumale::Tree::ExtDecisionTreeRegressor
527
- * @!visibility private
528
- * The mixin module consisting of extension method for DecisionTreeRegressor class.
529
- * This module is used internally.
530
- */
531
- VALUE mExtDTreeReg = rb_define_module_under(mTree, "ExtDecisionTreeRegressor");
532
- /**
533
- * Document-module: Rumale::Tree::ExtGradientTreeRegressor
534
- * @!visibility private
535
- * The mixin module consisting of extension method for GradientTreeRegressor class.
536
- * This module is used internally.
537
- */
538
- VALUE mExtGTreeReg = rb_define_module_under(mTree, "ExtGradientTreeRegressor");
539
-
540
- rb_define_private_method(mExtDTreeCls, "find_split_params", find_split_params_cls, 6);
541
- rb_define_private_method(mExtDTreeReg, "find_split_params", find_split_params_reg, 5);
542
- rb_define_private_method(mExtGTreeReg, "find_split_params", find_split_params_grad_reg, 7);
543
- rb_define_private_method(mExtDTreeCls, "node_impurity", node_impurity_cls, 4);
544
- rb_define_private_method(mExtDTreeReg, "node_impurity", node_impurity_reg, 2);
545
- }
data/ext/rumale/rumaleext.h DELETED
@@ -1,12 +0,0 @@
1
- #ifndef RUMALEEXT_H
2
- #define RUMALEEXT_H 1
3
-
4
- #include <math.h>
5
- #include <string.h>
6
-
7
- #include <ruby.h>
8
-
9
- #include <numo/narray.h>
10
- #include <numo/template.h>
11
-
12
- #endif /* RUMALEEXT_H */