rumale 0.23.3 → 0.24.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/LICENSE.txt +5 -1
- data/README.md +3 -288
- data/lib/rumale/version.rb +1 -1
- data/lib/rumale.rb +20 -131
- metadata +252 -150
- data/CHANGELOG.md +0 -643
- data/CODE_OF_CONDUCT.md +0 -74
- data/ext/rumale/extconf.rb +0 -37
- data/ext/rumale/rumaleext.c +0 -545
- data/ext/rumale/rumaleext.h +0 -12
- data/lib/rumale/base/base_estimator.rb +0 -49
- data/lib/rumale/base/classifier.rb +0 -36
- data/lib/rumale/base/cluster_analyzer.rb +0 -31
- data/lib/rumale/base/evaluator.rb +0 -17
- data/lib/rumale/base/regressor.rb +0 -36
- data/lib/rumale/base/splitter.rb +0 -21
- data/lib/rumale/base/transformer.rb +0 -22
- data/lib/rumale/clustering/dbscan.rb +0 -123
- data/lib/rumale/clustering/gaussian_mixture.rb +0 -218
- data/lib/rumale/clustering/hdbscan.rb +0 -291
- data/lib/rumale/clustering/k_means.rb +0 -122
- data/lib/rumale/clustering/k_medoids.rb +0 -141
- data/lib/rumale/clustering/mini_batch_k_means.rb +0 -139
- data/lib/rumale/clustering/power_iteration.rb +0 -127
- data/lib/rumale/clustering/single_linkage.rb +0 -203
- data/lib/rumale/clustering/snn.rb +0 -76
- data/lib/rumale/clustering/spectral_clustering.rb +0 -115
- data/lib/rumale/dataset.rb +0 -246
- data/lib/rumale/decomposition/factor_analysis.rb +0 -150
- data/lib/rumale/decomposition/fast_ica.rb +0 -188
- data/lib/rumale/decomposition/nmf.rb +0 -124
- data/lib/rumale/decomposition/pca.rb +0 -159
- data/lib/rumale/ensemble/ada_boost_classifier.rb +0 -179
- data/lib/rumale/ensemble/ada_boost_regressor.rb +0 -160
- data/lib/rumale/ensemble/extra_trees_classifier.rb +0 -139
- data/lib/rumale/ensemble/extra_trees_regressor.rb +0 -125
- data/lib/rumale/ensemble/gradient_boosting_classifier.rb +0 -306
- data/lib/rumale/ensemble/gradient_boosting_regressor.rb +0 -237
- data/lib/rumale/ensemble/random_forest_classifier.rb +0 -189
- data/lib/rumale/ensemble/random_forest_regressor.rb +0 -153
- data/lib/rumale/ensemble/stacking_classifier.rb +0 -215
- data/lib/rumale/ensemble/stacking_regressor.rb +0 -163
- data/lib/rumale/ensemble/voting_classifier.rb +0 -126
- data/lib/rumale/ensemble/voting_regressor.rb +0 -82
- data/lib/rumale/evaluation_measure/accuracy.rb +0 -29
- data/lib/rumale/evaluation_measure/adjusted_rand_score.rb +0 -74
- data/lib/rumale/evaluation_measure/calinski_harabasz_score.rb +0 -56
- data/lib/rumale/evaluation_measure/davies_bouldin_score.rb +0 -53
- data/lib/rumale/evaluation_measure/explained_variance_score.rb +0 -39
- data/lib/rumale/evaluation_measure/f_score.rb +0 -50
- data/lib/rumale/evaluation_measure/function.rb +0 -147
- data/lib/rumale/evaluation_measure/log_loss.rb +0 -45
- data/lib/rumale/evaluation_measure/mean_absolute_error.rb +0 -29
- data/lib/rumale/evaluation_measure/mean_squared_error.rb +0 -29
- data/lib/rumale/evaluation_measure/mean_squared_log_error.rb +0 -29
- data/lib/rumale/evaluation_measure/median_absolute_error.rb +0 -30
- data/lib/rumale/evaluation_measure/mutual_information.rb +0 -49
- data/lib/rumale/evaluation_measure/normalized_mutual_information.rb +0 -53
- data/lib/rumale/evaluation_measure/precision.rb +0 -50
- data/lib/rumale/evaluation_measure/precision_recall.rb +0 -96
- data/lib/rumale/evaluation_measure/purity.rb +0 -40
- data/lib/rumale/evaluation_measure/r2_score.rb +0 -43
- data/lib/rumale/evaluation_measure/recall.rb +0 -50
- data/lib/rumale/evaluation_measure/roc_auc.rb +0 -130
- data/lib/rumale/evaluation_measure/silhouette_score.rb +0 -82
- data/lib/rumale/feature_extraction/feature_hasher.rb +0 -110
- data/lib/rumale/feature_extraction/hash_vectorizer.rb +0 -155
- data/lib/rumale/feature_extraction/tfidf_transformer.rb +0 -113
- data/lib/rumale/kernel_approximation/nystroem.rb +0 -126
- data/lib/rumale/kernel_approximation/rbf.rb +0 -102
- data/lib/rumale/kernel_machine/kernel_fda.rb +0 -120
- data/lib/rumale/kernel_machine/kernel_pca.rb +0 -97
- data/lib/rumale/kernel_machine/kernel_ridge.rb +0 -82
- data/lib/rumale/kernel_machine/kernel_ridge_classifier.rb +0 -92
- data/lib/rumale/kernel_machine/kernel_svc.rb +0 -193
- data/lib/rumale/linear_model/base_sgd.rb +0 -285
- data/lib/rumale/linear_model/elastic_net.rb +0 -119
- data/lib/rumale/linear_model/lasso.rb +0 -115
- data/lib/rumale/linear_model/linear_regression.rb +0 -201
- data/lib/rumale/linear_model/logistic_regression.rb +0 -275
- data/lib/rumale/linear_model/nnls.rb +0 -137
- data/lib/rumale/linear_model/ridge.rb +0 -209
- data/lib/rumale/linear_model/svc.rb +0 -213
- data/lib/rumale/linear_model/svr.rb +0 -132
- data/lib/rumale/manifold/mds.rb +0 -155
- data/lib/rumale/manifold/tsne.rb +0 -222
- data/lib/rumale/metric_learning/fisher_discriminant_analysis.rb +0 -113
- data/lib/rumale/metric_learning/mlkr.rb +0 -161
- data/lib/rumale/metric_learning/neighbourhood_component_analysis.rb +0 -167
- data/lib/rumale/model_selection/cross_validation.rb +0 -125
- data/lib/rumale/model_selection/function.rb +0 -42
- data/lib/rumale/model_selection/grid_search_cv.rb +0 -225
- data/lib/rumale/model_selection/group_k_fold.rb +0 -93
- data/lib/rumale/model_selection/group_shuffle_split.rb +0 -115
- data/lib/rumale/model_selection/k_fold.rb +0 -81
- data/lib/rumale/model_selection/shuffle_split.rb +0 -90
- data/lib/rumale/model_selection/stratified_k_fold.rb +0 -99
- data/lib/rumale/model_selection/stratified_shuffle_split.rb +0 -118
- data/lib/rumale/model_selection/time_series_split.rb +0 -91
- data/lib/rumale/multiclass/one_vs_rest_classifier.rb +0 -83
- data/lib/rumale/naive_bayes/base_naive_bayes.rb +0 -47
- data/lib/rumale/naive_bayes/bernoulli_nb.rb +0 -82
- data/lib/rumale/naive_bayes/complement_nb.rb +0 -85
- data/lib/rumale/naive_bayes/gaussian_nb.rb +0 -69
- data/lib/rumale/naive_bayes/multinomial_nb.rb +0 -74
- data/lib/rumale/naive_bayes/negation_nb.rb +0 -71
- data/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb +0 -133
- data/lib/rumale/nearest_neighbors/k_neighbors_regressor.rb +0 -108
- data/lib/rumale/nearest_neighbors/vp_tree.rb +0 -132
- data/lib/rumale/neural_network/adam.rb +0 -56
- data/lib/rumale/neural_network/base_mlp.rb +0 -248
- data/lib/rumale/neural_network/mlp_classifier.rb +0 -120
- data/lib/rumale/neural_network/mlp_regressor.rb +0 -90
- data/lib/rumale/pairwise_metric.rb +0 -152
- data/lib/rumale/pipeline/feature_union.rb +0 -69
- data/lib/rumale/pipeline/pipeline.rb +0 -175
- data/lib/rumale/preprocessing/bin_discretizer.rb +0 -93
- data/lib/rumale/preprocessing/binarizer.rb +0 -60
- data/lib/rumale/preprocessing/kernel_calculator.rb +0 -92
- data/lib/rumale/preprocessing/l1_normalizer.rb +0 -62
- data/lib/rumale/preprocessing/l2_normalizer.rb +0 -63
- data/lib/rumale/preprocessing/label_binarizer.rb +0 -89
- data/lib/rumale/preprocessing/label_encoder.rb +0 -79
- data/lib/rumale/preprocessing/max_abs_scaler.rb +0 -61
- data/lib/rumale/preprocessing/max_normalizer.rb +0 -62
- data/lib/rumale/preprocessing/min_max_scaler.rb +0 -76
- data/lib/rumale/preprocessing/one_hot_encoder.rb +0 -100
- data/lib/rumale/preprocessing/ordinal_encoder.rb +0 -109
- data/lib/rumale/preprocessing/polynomial_features.rb +0 -109
- data/lib/rumale/preprocessing/standard_scaler.rb +0 -71
- data/lib/rumale/probabilistic_output.rb +0 -114
- data/lib/rumale/tree/base_decision_tree.rb +0 -150
- data/lib/rumale/tree/decision_tree_classifier.rb +0 -150
- data/lib/rumale/tree/decision_tree_regressor.rb +0 -116
- data/lib/rumale/tree/extra_tree_classifier.rb +0 -107
- data/lib/rumale/tree/extra_tree_regressor.rb +0 -94
- data/lib/rumale/tree/gradient_tree_regressor.rb +0 -202
- data/lib/rumale/tree/node.rb +0 -39
- data/lib/rumale/utils.rb +0 -42
- data/lib/rumale/validation.rb +0 -128
- data/lib/rumale/values.rb +0 -13
|
@@ -1,115 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/base/splitter'
|
|
4
|
-
|
|
5
|
-
module Rumale
  module ModelSelection
    # GroupShuffleSplit builds train/test index sets for random permutation cross-validation
    # in which whole groups (not individual samples) are randomly assigned to each side.
    #
    # @example
    #   cv = Rumale::ModelSelection::GroupShuffleSplit.new(n_splits: 2, test_size: 0.2, random_seed: 1)
    #   x = Numo::DFloat.new(8, 2).rand
    #   groups = Numo::Int32[1, 1, 1, 2, 2, 3, 3, 3]
    #   cv.split(x, nil, groups).each do |train_ids, test_ids|
    #     puts '---'
    #     pp train_ids
    #     pp test_ids
    #   end
    #
    #   # ---
    #   # [0, 1, 2, 5, 6, 7]
    #   # [3, 4]
    #   # ---
    #   # [3, 4, 5, 6, 7]
    #   # [0, 1, 2]
    #
    class GroupShuffleSplit
      include Base::Splitter

      # Return the number of folds.
      # @return [Integer]
      attr_reader :n_splits

      # Return the random generator for shuffling the dataset.
      # @return [Random]
      attr_reader :rng

      # Create a new data splitter for random permutation cross validation with given group labels.
      #
      # @param n_splits [Integer] The number of splits to generate.
      # @param test_size [Float] The ratio of the number of groups used for test data (rounded up).
      # @param train_size [Float/Nil] The ratio of the number of groups used for train data (rounded down).
      #   If nil, every group not drawn for the test data is used for training.
      # @param random_seed [Integer] The seed value used to initialize the random generator.
      def initialize(n_splits: 5, test_size: 0.2, train_size: nil, random_seed: nil)
        check_params_numeric(n_splits: n_splits, test_size: test_size)
        check_params_numeric_or_nil(train_size: train_size, random_seed: random_seed)
        check_params_positive(n_splits: n_splits)
        check_params_positive(test_size: test_size)
        check_params_positive(train_size: train_size) unless train_size.nil?
        @n_splits = n_splits
        @test_size = test_size
        @train_size = train_size
        # Without an explicit seed, fall back to the value returned by Kernel#srand.
        @random_seed = random_seed || srand
        @rng = Random.new(@random_seed)
      end

      # Generate train and test data indices by randomly selecting group labels.
      #
      # @overload split(x, y, groups) -> Array
      #   @param x [Numo::DFloat] (shape: [n_samples, n_features])
      #     The dataset to be used to generate data indices for random permutation cross validation.
      #   @param y [Numo::Int32] (shape: [n_samples])
      #     This argument exists to unify the interface between the K-fold methods, it is not used in the method.
      #   @param groups [Numo::Int32] (shape: [n_samples])
      #     The group labels to be used to generate data indices for random permutation cross validation.
      #   @raise [RangeError] If the test/train sizes resolve to an unusable number of groups.
      #   @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
      def split(x, _y, groups)
        x = check_convert_sample_array(x)
        groups = check_convert_label_array(groups)
        check_sample_label_size(x, groups)

        group_labels = groups.to_a.uniq.sort
        n_groups = group_labels.size
        n_test_groups = (@test_size * n_groups).ceil.to_i
        n_train_groups = if @train_size.nil?
                           n_groups - n_test_groups
                         else
                           (@train_size * n_groups).floor.to_i
                         end
        assert_valid_group_counts(n_groups, n_test_groups, n_train_groups)

        # Work on a copy so that repeated calls to split yield the same splits.
        sub_rng = @rng.dup
        Array.new(@n_splits) do
          test_labels = group_labels.sample(n_test_groups, random: sub_rng)
          remaining_labels = group_labels - test_labels
          train_labels = @train_size.nil? ? remaining_labels : remaining_labels.sample(n_train_groups, random: sub_rng)
          [in1d(groups, train_labels).where.to_a, in1d(groups, test_labels).where.to_a]
        end
      end

      private

      # Raise RangeError unless the test and train group counts each lie in 1..n_groups
      # and together do not exceed n_groups.
      def assert_valid_group_counts(n_groups, n_test_groups, n_train_groups)
        unless n_test_groups.between?(1, n_groups)
          raise RangeError,
                'The number of groups in test split must be not less than 1 and not more than the number of groups.'
        end
        unless n_train_groups.between?(1, n_groups)
          raise RangeError,
                'The number of groups in train split must be not less than 1 and not more than the number of groups.'
        end
        return if n_test_groups + n_train_groups <= n_groups

        raise RangeError,
              'The total number of groups in test split and train split must be not more than the number of groups.'
      end

      # Return a Numo::Bit mask that is set wherever an element of a equals any value in b.
      def in1d(a, b)
        b.reduce(Numo::Bit.zeros(a.shape[0])) { |mask, value| mask | a.eq(value) }
      end
    end
  end
end
|
|
@@ -1,81 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/base/splitter'
|
|
4
|
-
|
|
5
|
-
module Rumale
  # This module consists of the classes for model validation techniques.
  module ModelSelection
    # KFold produces the train/test index sets for K-fold cross-validation.
    # The samples are divided into n_splits consecutive folds (after an optional shuffle).
    # Each fold serves once as the test set, and the remaining folds form the training set.
    #
    # @example
    #   kf = Rumale::ModelSelection::KFold.new(n_splits: 3, shuffle: true, random_seed: 1)
    #   kf.split(samples, labels).each do |train_ids, test_ids|
    #     train_samples = samples[train_ids, true]
    #     test_samples = samples[test_ids, true]
    #     ...
    #   end
    #
    class KFold
      include Base::Splitter

      # Return the number of folds.
      # @return [Integer]
      attr_reader :n_splits

      # Return the flag indicating whether to shuffle the dataset.
      # @return [Boolean]
      attr_reader :shuffle

      # Return the random generator for shuffling the dataset.
      # @return [Random]
      attr_reader :rng

      # Create a new data splitter for K-fold cross validation.
      #
      # @param n_splits [Integer] The number of folds.
      # @param shuffle [Boolean] Whether to shuffle the sample indices before dividing them into folds.
      # @param random_seed [Integer] The seed value used to initialize the random generator.
      def initialize(n_splits: 3, shuffle: false, random_seed: nil)
        check_params_numeric(n_splits: n_splits)
        check_params_boolean(shuffle: shuffle)
        check_params_numeric_or_nil(random_seed: random_seed)
        check_params_positive(n_splits: n_splits)
        @n_splits = n_splits
        @shuffle = shuffle
        # Without an explicit seed, fall back to the value returned by Kernel#srand.
        @random_seed = random_seed || srand
        @rng = Random.new(@random_seed)
      end

      # Generate data indices for K-fold cross validation.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features])
      #   The dataset to be used to generate data indices for K-fold cross validation.
      # @param y [Object] Ignored; accepted so that all splitters share the split(x, y) interface.
      # @raise [ArgumentError] If n_splits is not in 2..n_samples.
      # @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
      def split(x, _y = nil)
        x = check_convert_sample_array(x)
        n_samples = x.shape[0]
        unless @n_splits.between?(2, n_samples)
          raise ArgumentError,
                'The value of n_splits must be not less than 2 and not more than the number of samples.'
        end

        # Work on a copy so that repeated calls to split yield the same folds.
        sub_rng = @rng.dup
        pending_ids = Array(0...n_samples)
        pending_ids.shuffle!(random: sub_rng) if @shuffle

        # The leading (n_samples % n_splits) folds each receive one extra sample.
        base_size = n_samples / @n_splits
        n_larger_folds = n_samples % @n_splits
        folds = Array.new(@n_splits) do |k|
          pending_ids.shift(k < n_larger_folds ? base_size + 1 : base_size)
        end

        folds.each_with_index.map do |test_fold, k|
          [folds.reject.with_index { |_, j| j == k }.flatten, test_fold]
        end
      end
    end
  end
end
|
|
@@ -1,90 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/base/splitter'
|
|
4
|
-
|
|
5
|
-
module Rumale
  module ModelSelection
    # ShuffleSplit produces train/test index sets for random permutation cross-validation.
    # Each split independently draws a fresh random test subset of the samples.
    #
    # @example
    #   ss = Rumale::ModelSelection::ShuffleSplit.new(n_splits: 3, test_size: 0.2, random_seed: 1)
    #   ss.split(samples, labels).each do |train_ids, test_ids|
    #     train_samples = samples[train_ids, true]
    #     test_samples = samples[test_ids, true]
    #     ...
    #   end
    #
    class ShuffleSplit
      include Base::Splitter

      # Return the number of folds.
      # @return [Integer]
      attr_reader :n_splits

      # Return the random generator for shuffling the dataset.
      # @return [Random]
      attr_reader :rng

      # Create a new data splitter for random permutation cross validation.
      #
      # @param n_splits [Integer] The number of splits to generate.
      # @param test_size [Float] The ratio of the number of samples used for test data (rounded up).
      # @param train_size [Float/Nil] The ratio of the number of samples used for train data (rounded down).
      #   If nil, every sample not drawn for the test data is used for training.
      # @param random_seed [Integer] The seed value used to initialize the random generator.
      def initialize(n_splits: 3, test_size: 0.1, train_size: nil, random_seed: nil)
        check_params_numeric(n_splits: n_splits, test_size: test_size)
        check_params_numeric_or_nil(train_size: train_size, random_seed: random_seed)
        check_params_positive(n_splits: n_splits)
        check_params_positive(test_size: test_size)
        check_params_positive(train_size: train_size) unless train_size.nil?
        @n_splits = n_splits
        @test_size = test_size
        @train_size = train_size
        # Without an explicit seed, fall back to the value returned by Kernel#srand.
        @random_seed = random_seed || srand
        @rng = Random.new(@random_seed)
      end

      # Generate data indices for random permutation cross validation.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features])
      #   The dataset to be used to generate data indices for random permutation cross validation.
      # @param y [Object] Ignored; accepted so that all splitters share the split(x, y) interface.
      # @raise [ArgumentError] If n_splits is not in 1..n_samples.
      # @raise [RangeError] If the test/train sizes resolve to an unusable number of samples.
      # @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
      def split(x, _y = nil)
        x = check_convert_sample_array(x)
        n_samples = x.shape[0]
        n_test_samples = (@test_size * n_samples).ceil.to_i
        n_train_samples = if @train_size.nil?
                            n_samples - n_test_samples
                          else
                            (@train_size * n_samples).floor.to_i
                          end
        assert_valid_split_sizes(n_samples, n_test_samples, n_train_samples)

        # Work on a copy so that repeated calls to split yield the same splits.
        sub_rng = @rng.dup
        all_ids = Array(0...n_samples)
        Array.new(@n_splits) do
          test_ids = all_ids.sample(n_test_samples, random: sub_rng)
          rest_ids = all_ids - test_ids
          train_ids = @train_size.nil? ? rest_ids : rest_ids.sample(n_train_samples, random: sub_rng)
          [train_ids, test_ids]
        end
      end

      private

      # Raise unless n_splits and the resolved test/train sample counts are usable for n_samples.
      def assert_valid_split_sizes(n_samples, n_test_samples, n_train_samples)
        unless @n_splits.between?(1, n_samples)
          raise ArgumentError,
                'The value of n_splits must be not less than 1 and not more than the number of samples.'
        end
        unless n_test_samples.between?(1, n_samples)
          raise RangeError,
                'The number of samples in test split must be not less than 1 and not more than the number of samples.'
        end
        unless n_train_samples.between?(1, n_samples)
          raise RangeError,
                'The number of samples in train split must be not less than 1 and not more than the number of samples.'
        end
        return if n_test_samples + n_train_samples <= n_samples

        raise RangeError,
              'The total number of samples in test split and train split must be not more than the number of samples.'
      end
    end
  end
end
|
|
@@ -1,99 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/base/splitter'
|
|
4
|
-
|
|
5
|
-
module Rumale
  module ModelSelection
    # StratifiedKFold produces the train/test index sets for K-fold cross-validation.
    # The samples of each class are divided among the folds separately, so every fold keeps
    # roughly the same class proportions as the whole dataset.
    #
    # @example
    #   kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 3, shuffle: true, random_seed: 1)
    #   kf.split(samples, labels).each do |train_ids, test_ids|
    #     train_samples = samples[train_ids, true]
    #     test_samples = samples[test_ids, true]
    #     ...
    #   end
    #
    class StratifiedKFold
      include Base::Splitter

      # Return the number of folds.
      # @return [Integer]
      attr_reader :n_splits

      # Return the flag indicating whether to shuffle the dataset.
      # @return [Boolean]
      attr_reader :shuffle

      # Return the random generator for shuffling the dataset.
      # @return [Random]
      attr_reader :rng

      # Create a new data splitter for stratified K-fold cross validation.
      #
      # @param n_splits [Integer] The number of folds.
      # @param shuffle [Boolean] Whether to shuffle each class's sample indices before dividing them into folds.
      # @param random_seed [Integer] The seed value used to initialize the random generator.
      def initialize(n_splits: 3, shuffle: false, random_seed: nil)
        check_params_numeric(n_splits: n_splits)
        check_params_boolean(shuffle: shuffle)
        check_params_numeric_or_nil(random_seed: random_seed)
        check_params_positive(n_splits: n_splits)
        @n_splits = n_splits
        @shuffle = shuffle
        # Without an explicit seed, fall back to the value returned by Kernel#srand.
        @random_seed = random_seed || srand
        @rng = Random.new(@random_seed)
      end

      # Generate data indices for stratified K-fold cross validation.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features])
      #   The dataset; only its number of samples is checked against y, the feature values are not used.
      # @param y [Numo::Int32] (shape: [n_samples])
      #   The labels to be used to generate data indices for stratified K-fold cross validation.
      # @raise [ArgumentError] If n_splits is not in 2..(number of samples) for every class.
      # @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
      def split(x, y)
        x = check_convert_sample_array(x)
        y = check_convert_label_array(y)
        check_sample_label_size(x, y)
        unless valid_n_splits?(y)
          raise ArgumentError,
                'The value of n_splits must be not less than 2 and not more than the number of samples in each class.'
        end

        # Work on a copy so that repeated calls to split yield the same folds.
        sub_rng = @rng.dup
        # Per class (in order of first appearance), the class's sample indices divided into folds.
        class_folds = y.to_a.uniq.map { |label| fold_sets(y, label, sub_rng) }
        Array.new(@n_splits) { |fold_id| train_test_sets(class_folds, fold_id) }
      end

      private

      # True when every class has at least n_splits samples (and n_splits is at least 2).
      def valid_n_splits?(y)
        y.to_a.uniq.all? { |label| @n_splits.between?(2, y.eq(label).where.size) }
      end

      # Divide the sample indices of one class into n_splits consecutive folds;
      # the leading (size % n_splits) folds each take one extra index.
      def fold_sets(y, label, sub_rng)
        pending_ids = y.eq(label).where.to_a
        pending_ids.shuffle!(random: sub_rng) if @shuffle
        n_ids = pending_ids.size
        base_size = n_ids / @n_splits
        n_larger_folds = n_ids % @n_splits
        Array.new(@n_splits) do |k|
          pending_ids.shift(k < n_larger_folds ? base_size + 1 : base_size)
        end
      end

      # Merge the per-class folds into one [train_ids, test_ids] pair, using fold fold_id of
      # every class as the test part and all other folds as the train part.
      def train_test_sets(fold_sets_each_class, fold_id)
        per_class_pairs = fold_sets_each_class.map do |folds|
          train_part = folds.reject.with_index { |_, k| k == fold_id }.flatten
          [train_part, folds[fold_id].flatten]
        end
        per_class_pairs.transpose.map(&:flatten)
      end
    end
  end
end
|
|
@@ -1,118 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/base/splitter'
|
|
4
|
-
|
|
5
|
-
module Rumale
  module ModelSelection
    # StratifiedShuffleSplit is a class that generates the set of data indices for random permutation cross-validation.
    # The proportion of the number of samples in each class will be almost equal for each fold.
    #
    # @example
    #   ss = Rumale::ModelSelection::StratifiedShuffleSplit.new(n_splits: 3, test_size: 0.2, random_seed: 1)
    #   ss.split(samples, labels).each do |train_ids, test_ids|
    #     train_samples = samples[train_ids, true]
    #     test_samples = samples[test_ids, true]
    #     ...
    #   end
    #
    class StratifiedShuffleSplit
      include Base::Splitter

      # Return the number of folds.
      # @return [Integer]
      attr_reader :n_splits

      # Return the random generator for shuffling the dataset.
      # @return [Random]
      attr_reader :rng

      # Create a new data splitter for random permutation cross validation.
      #
      # @param n_splits [Integer] The number of folds.
      # @param test_size [Float] The ratio of number of samples for test data.
      #   For each class, ceil(test_size * n_class_samples) samples are drawn for testing.
      # @param train_size [Float/Nil] The ratio of number of samples for train data.
      #   For each class, floor(train_size * n_class_samples) samples are drawn for training.
      #   If nil, all samples of the class that were not drawn for testing are used for training.
      # @param random_seed [Integer] The seed value used to initialize the random generator.
      def initialize(n_splits: 3, test_size: 0.1, train_size: nil, random_seed: nil)
        check_params_numeric(n_splits: n_splits, test_size: test_size)
        check_params_numeric_or_nil(train_size: train_size, random_seed: random_seed)
        check_params_positive(n_splits: n_splits)
        check_params_positive(test_size: test_size)
        check_params_positive(train_size: train_size) unless train_size.nil?
        @n_splits = n_splits
        @test_size = test_size
        @train_size = train_size
        # Without an explicit seed, fall back to the value returned by Kernel#srand.
        @random_seed = random_seed || srand
        @rng = Random.new(@random_seed)
      end

      # Generate data indices for stratified random permutation cross validation.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features])
      #   The dataset; only its number of samples is checked against y, the feature values are not used.
      # @param y [Numo::Int32] (shape: [n_samples])
      #   The labels to be used to generate data indices for stratified random permutation cross validation.
      # @raise [ArgumentError] If n_splits is not in 1..(number of samples) for every class.
      # @raise [RangeError] If, for some class, the test or train split would be empty or larger than the class,
      #   or the test and train splits together would need more samples than the class has.
      # @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
      def split(x, y)
        x = check_convert_sample_array(x)
        y = check_convert_label_array(y)
        check_sample_label_size(x, y)
        # Per class (in order of first appearance): its sample indices and the exact numbers of
        # test and train samples drawn from it in every split. The checks below validate these
        # exact counts, so they agree with what the sampling loop actually does.
        class_plans = y.to_a.uniq.map { |label| class_plan(y.eq(label).where.to_a) }
        # Check the number of samples in each class.
        unless valid_n_splits?(class_plans.map { |sample_ids, _, _| sample_ids.size })
          raise ArgumentError,
                'The value of n_splits must be not less than 1 and not more than the number of samples in each class.'
        end
        unless class_plans.all? { |sample_ids, n_test, _| n_test.between?(1, sample_ids.size) }
          raise RangeError,
                'The number of samples in test split must be not less than 1 and not more than the number of samples in each class.'
        end
        unless class_plans.all? { |sample_ids, _, n_train| n_train.between?(1, sample_ids.size) }
          raise RangeError,
                'The number of samples in train split must be not less than 1 and not more than the number of samples in each class.'
        end
        # Compare the counts that are actually drawn. Checking floor((train_size + test_size) * n) instead
        # would accept e.g. test_size: 0.55, train_size: 0.5 for a 10-sample class (6 + 5 > 10), after which
        # Array#sample silently returns fewer train samples than requested.
        unless class_plans.all? { |sample_ids, n_test, n_train| n_test + n_train <= sample_ids.size }
          raise RangeError,
                'The total number of samples in test split and train split must be not more than the number of samples in each class.'
        end
        # Work on a copy so that repeated calls to split yield the same splits.
        sub_rng = @rng.dup
        # Returns array consisting of the training and testing ids for each fold.
        Array.new(@n_splits) do
          train_ids = []
          test_ids = []
          class_plans.each do |sample_ids, n_test, n_train|
            class_test_ids = sample_ids.sample(n_test, random: sub_rng)
            # Classes are disjoint, so removing only this class's test ids is sufficient.
            rest_ids = sample_ids - class_test_ids
            test_ids.concat(class_test_ids)
            train_ids.concat(@train_size.nil? ? rest_ids : rest_ids.sample(n_train, random: sub_rng))
          end
          [train_ids, test_ids]
        end
      end

      private

      # True when every class size in n_samples_each_class is at least n_splits (and n_splits is at least 1).
      def valid_n_splits?(n_samples_each_class)
        n_samples_each_class.all? { |n_samples| @n_splits.between?(1, n_samples) }
      end

      # Return [sample_ids, n_test_samples, n_train_samples] for the sample indices of one class.
      # Test counts are rounded up and explicit train counts are rounded down; with train_size nil,
      # the train count is simply whatever remains after the test draw.
      def class_plan(sample_ids)
        n_samples = sample_ids.size
        n_test_samples = (@test_size * n_samples).ceil.to_i
        n_train_samples = if @train_size.nil?
                            n_samples - n_test_samples
                          else
                            (@train_size * n_samples).floor.to_i
                          end
        [sample_ids, n_test_samples, n_train_samples]
      end
    end
  end
end
|
|
@@ -1,91 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/base/splitter'
|
|
4
|
-
|
|
5
|
-
module Rumale
|
|
6
|
-
module ModelSelection
|
|
7
|
-
# TimeSeriesSplit is a class that generates the set of data indices for time series cross-validation.
|
|
8
|
-
# It is assumed that the dataset given are already ordered by time information.
|
|
9
|
-
#
|
|
10
|
-
# @example
|
|
11
|
-
# cv = Rumale::ModelSelection::TimeSeriesSplit.new(n_splits: 5)
|
|
12
|
-
# x = Numo::DFloat.new(6, 2).rand
|
|
13
|
-
# cv.split(x, nil).each do |train_ids, test_ids|
|
|
14
|
-
# puts '---'
|
|
15
|
-
# pp train_ids
|
|
16
|
-
# pp test_ids
|
|
17
|
-
# end
|
|
18
|
-
#
|
|
19
|
-
# # ---
|
|
20
|
-
# # [0]
|
|
21
|
-
# # [1]
|
|
22
|
-
# # ---
|
|
23
|
-
# # [0, 1]
|
|
24
|
-
# # [2]
|
|
25
|
-
# # ---
|
|
26
|
-
# # [0, 1, 2]
|
|
27
|
-
# # [3]
|
|
28
|
-
# # ---
|
|
29
|
-
# # [0, 1, 2, 3]
|
|
30
|
-
# # [4]
|
|
31
|
-
# # ---
|
|
32
|
-
# # [0, 1, 2, 3, 4]
|
|
33
|
-
# # [5]
|
|
34
|
-
#
|
|
35
|
-
class TimeSeriesSplit
|
|
36
|
-
include Base::Splitter
|
|
37
|
-
|
|
38
|
-
# Return the number of splits.
|
|
39
|
-
# @return [Integer]
|
|
40
|
-
attr_reader :n_splits
|
|
41
|
-
|
|
42
|
-
# Return the maximum number of training samples in a split.
|
|
43
|
-
# @return [Integer/Nil]
|
|
44
|
-
attr_reader :max_train_size
|
|
45
|
-
|
|
46
|
-
# Create a new data splitter for time series cross-validation.
|
|
47
|
-
#
|
|
48
|
-
# @param n_splits [Integer] The number of splits.
|
|
49
|
-
# @param max_train_size [Integer/Nil] The maximum number of training samples in a split.
|
|
50
|
-
def initialize(n_splits: 5, max_train_size: nil)
|
|
51
|
-
check_params_numeric(n_splits: n_splits)
|
|
52
|
-
check_params_numeric_or_nil(max_train_size: max_train_size)
|
|
53
|
-
@n_splits = n_splits
|
|
54
|
-
@max_train_size = max_train_size
|
|
55
|
-
end
|
|
56
|
-
|
|
57
|
-
# Generate data indices for time series cross-validation.
|
|
58
|
-
#
|
|
59
|
-
# @overload split(x, y) -> Array
|
|
60
|
-
# @param x [Numo::DFloat] (shape: [n_samples, n_features])
|
|
61
|
-
# The dataset to be used to generate data indices for time series cross-validation.
|
|
62
|
-
# It is expected that the data will be ordered by time information.
|
|
63
|
-
# @param y [Numo::Int32] (shape: [n_samples])
|
|
64
|
-
# This argument exists to unify the interface between the K-fold methods, it is not used in the method.
|
|
65
|
-
# @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
|
|
66
|
-
def split(x, _y)
|
|
67
|
-
x = check_convert_sample_array(x)
|
|
68
|
-
|
|
69
|
-
n_samples = x.shape[0]
|
|
70
|
-
unless (@n_splits + 1).between?(2, n_samples)
|
|
71
|
-
raise ArgumentError,
|
|
72
|
-
'The number of folds (n_splits + 1) must be not less than 2 and not more than the number of samples.'
|
|
73
|
-
end
|
|
74
|
-
|
|
75
|
-
test_size = n_samples / (@n_splits + 1)
|
|
76
|
-
offset = test_size + n_samples % (@n_splits + 1)
|
|
77
|
-
|
|
78
|
-
Array.new(@n_splits) do |n|
|
|
79
|
-
start = offset * (n + 1)
|
|
80
|
-
train_ids = if !@max_train_size.nil? && @max_train_size < test_size
|
|
81
|
-
Array((start - @max_train_size)...start)
|
|
82
|
-
else
|
|
83
|
-
Array(0...start)
|
|
84
|
-
end
|
|
85
|
-
test_ids = Array(start...(start + test_size))
|
|
86
|
-
[train_ids, test_ids]
|
|
87
|
-
end
|
|
88
|
-
end
|
|
89
|
-
end
|
|
90
|
-
end
|
|
91
|
-
end
|
|
@@ -1,83 +0,0 @@
|
|
|
1
|
-
# frozen_string_literal: true
|
|
2
|
-
|
|
3
|
-
require 'rumale/base/base_estimator'
|
|
4
|
-
require 'rumale/base/classifier'
|
|
5
|
-
|
|
6
|
-
module Rumale
|
|
7
|
-
# This module consists of the classes that implement multi-class classification strategy.
|
|
8
|
-
module Multiclass
|
|
9
|
-
# @note
|
|
10
|
-
# All classifier in Rumale support multi-class classifiction since version 0.2.7.
|
|
11
|
-
# There is no need to explicitly use this class for multiclass classifiction.
|
|
12
|
-
#
|
|
13
|
-
# OneVsRestClassifier is a class that implements One-vs-Rest (OvR) strategy for multi-class classification.
|
|
14
|
-
#
|
|
15
|
-
# @example
|
|
16
|
-
# base_estimator = Rumale::LinearModel::LogisticRegression.new
|
|
17
|
-
# estimator = Rumale::Multiclass::OneVsRestClassifier.new(estimator: base_estimator)
|
|
18
|
-
# estimator.fit(training_samples, training_labels)
|
|
19
|
-
# results = estimator.predict(testing_samples)
|
|
20
|
-
class OneVsRestClassifier
|
|
21
|
-
include Base::BaseEstimator
|
|
22
|
-
include Base::Classifier
|
|
23
|
-
|
|
24
|
-
# Return the set of estimators.
|
|
25
|
-
# @return [Array<Classifier>]
|
|
26
|
-
attr_reader :estimators
|
|
27
|
-
|
|
28
|
-
# Return the class labels.
|
|
29
|
-
# @return [Numo::Int32] (shape: [n_classes])
|
|
30
|
-
attr_reader :classes
|
|
31
|
-
|
|
32
|
-
# Create a new multi-class classifier with the one-vs-rest startegy.
|
|
33
|
-
#
|
|
34
|
-
# @param estimator [Classifier] The (binary) classifier for construction a multi-class classifier.
|
|
35
|
-
def initialize(estimator: nil)
|
|
36
|
-
check_params_type(Rumale::Base::BaseEstimator, estimator: estimator)
|
|
37
|
-
@params = {}
|
|
38
|
-
@params[:estimator] = estimator
|
|
39
|
-
@estimators = nil
|
|
40
|
-
@classes = nil
|
|
41
|
-
end
|
|
42
|
-
|
|
43
|
-
# Fit the model with given training data.
|
|
44
|
-
#
|
|
45
|
-
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
|
|
46
|
-
# @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
|
|
47
|
-
# @return [OneVsRestClassifier] The learned classifier itself.
|
|
48
|
-
def fit(x, y)
|
|
49
|
-
x = check_convert_sample_array(x)
|
|
50
|
-
y = check_convert_label_array(y)
|
|
51
|
-
check_sample_label_size(x, y)
|
|
52
|
-
y_arr = y.to_a
|
|
53
|
-
@classes = Numo::Int32.asarray(y_arr.uniq.sort)
|
|
54
|
-
@estimators = @classes.to_a.map do |label|
|
|
55
|
-
bin_y = Numo::Int32.asarray(y_arr.map { |l| l == label ? 1 : -1 })
|
|
56
|
-
@params[:estimator].dup.fit(x, bin_y)
|
|
57
|
-
end
|
|
58
|
-
self
|
|
59
|
-
end
|
|
60
|
-
|
|
61
|
-
# Calculate confidence scores for samples.
|
|
62
|
-
#
|
|
63
|
-
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
|
|
64
|
-
# @return [Numo::DFloat] (shape: [n_samples, n_classes]) Confidence scores per sample for each class.
|
|
65
|
-
def decision_function(x)
|
|
66
|
-
x = check_convert_sample_array(x)
|
|
67
|
-
n_classes = @classes.size
|
|
68
|
-
Numo::DFloat.asarray(Array.new(n_classes) { |m| @estimators[m].decision_function(x).to_a }).transpose
|
|
69
|
-
end
|
|
70
|
-
|
|
71
|
-
# Predict class labels for samples.
|
|
72
|
-
#
|
|
73
|
-
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
|
|
74
|
-
# @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
|
|
75
|
-
def predict(x)
|
|
76
|
-
x = check_convert_sample_array(x)
|
|
77
|
-
n_samples, = x.shape
|
|
78
|
-
decision_values = decision_function(x)
|
|
79
|
-
Numo::Int32.asarray(Array.new(n_samples) { |n| @classes[decision_values[n, true].max_index] })
|
|
80
|
-
end
|
|
81
|
-
end
|
|
82
|
-
end
|
|
83
|
-
end
|