rumale 0.8.0

Files changed (85)
  1. checksums.yaml +7 -0
  2. data/.coveralls.yml +1 -0
  3. data/.gitignore +20 -0
  4. data/.rspec +3 -0
  5. data/.rubocop.yml +47 -0
  6. data/.rubocop_todo.yml +58 -0
  7. data/.travis.yml +13 -0
  8. data/CHANGELOG.md +2 -0
  9. data/CODE_OF_CONDUCT.md +74 -0
  10. data/Gemfile +4 -0
  11. data/LICENSE.txt +23 -0
  12. data/README.md +175 -0
  13. data/Rakefile +6 -0
  14. data/bin/console +14 -0
  15. data/bin/setup +8 -0
  16. data/lib/rumale.rb +70 -0
  17. data/lib/rumale/base/base_estimator.rb +13 -0
  18. data/lib/rumale/base/classifier.rb +36 -0
  19. data/lib/rumale/base/cluster_analyzer.rb +31 -0
  20. data/lib/rumale/base/evaluator.rb +17 -0
  21. data/lib/rumale/base/regressor.rb +36 -0
  22. data/lib/rumale/base/splitter.rb +21 -0
  23. data/lib/rumale/base/transformer.rb +22 -0
  24. data/lib/rumale/clustering/dbscan.rb +125 -0
  25. data/lib/rumale/clustering/k_means.rb +138 -0
  26. data/lib/rumale/dataset.rb +110 -0
  27. data/lib/rumale/decomposition/nmf.rb +141 -0
  28. data/lib/rumale/decomposition/pca.rb +148 -0
  29. data/lib/rumale/ensemble/ada_boost_classifier.rb +196 -0
  30. data/lib/rumale/ensemble/ada_boost_regressor.rb +178 -0
  31. data/lib/rumale/ensemble/random_forest_classifier.rb +180 -0
  32. data/lib/rumale/ensemble/random_forest_regressor.rb +141 -0
  33. data/lib/rumale/evaluation_measure/accuracy.rb +29 -0
  34. data/lib/rumale/evaluation_measure/f_score.rb +50 -0
  35. data/lib/rumale/evaluation_measure/log_loss.rb +45 -0
  36. data/lib/rumale/evaluation_measure/mean_absolute_error.rb +29 -0
  37. data/lib/rumale/evaluation_measure/mean_squared_error.rb +29 -0
  38. data/lib/rumale/evaluation_measure/normalized_mutual_information.rb +62 -0
  39. data/lib/rumale/evaluation_measure/precision.rb +50 -0
  40. data/lib/rumale/evaluation_measure/precision_recall.rb +91 -0
  41. data/lib/rumale/evaluation_measure/purity.rb +40 -0
  42. data/lib/rumale/evaluation_measure/r2_score.rb +43 -0
  43. data/lib/rumale/evaluation_measure/recall.rb +50 -0
  44. data/lib/rumale/kernel_approximation/rbf.rb +121 -0
  45. data/lib/rumale/kernel_machine/kernel_svc.rb +193 -0
  46. data/lib/rumale/linear_model/base_linear_model.rb +89 -0
  47. data/lib/rumale/linear_model/lasso.rb +136 -0
  48. data/lib/rumale/linear_model/linear_regression.rb +110 -0
  49. data/lib/rumale/linear_model/logistic_regression.rb +159 -0
  50. data/lib/rumale/linear_model/ridge.rb +110 -0
  51. data/lib/rumale/linear_model/svc.rb +183 -0
  52. data/lib/rumale/linear_model/svr.rb +122 -0
  53. data/lib/rumale/model_selection/cross_validation.rb +123 -0
  54. data/lib/rumale/model_selection/grid_search_cv.rb +247 -0
  55. data/lib/rumale/model_selection/k_fold.rb +76 -0
  56. data/lib/rumale/model_selection/stratified_k_fold.rb +94 -0
  57. data/lib/rumale/multiclass/one_vs_rest_classifier.rb +100 -0
  58. data/lib/rumale/naive_bayes/naive_bayes.rb +315 -0
  59. data/lib/rumale/nearest_neighbors/k_neighbors_classifier.rb +111 -0
  60. data/lib/rumale/nearest_neighbors/k_neighbors_regressor.rb +93 -0
  61. data/lib/rumale/optimizer/nadam.rb +90 -0
  62. data/lib/rumale/optimizer/rmsprop.rb +69 -0
  63. data/lib/rumale/optimizer/sgd.rb +65 -0
  64. data/lib/rumale/optimizer/yellow_fin.rb +144 -0
  65. data/lib/rumale/pairwise_metric.rb +91 -0
  66. data/lib/rumale/pipeline/pipeline.rb +197 -0
  67. data/lib/rumale/polynomial_model/base_factorization_machine.rb +99 -0
  68. data/lib/rumale/polynomial_model/factorization_machine_classifier.rb +197 -0
  69. data/lib/rumale/polynomial_model/factorization_machine_regressor.rb +131 -0
  70. data/lib/rumale/preprocessing/l2_normalizer.rb +62 -0
  71. data/lib/rumale/preprocessing/label_encoder.rb +94 -0
  72. data/lib/rumale/preprocessing/min_max_scaler.rb +92 -0
  73. data/lib/rumale/preprocessing/one_hot_encoder.rb +98 -0
  74. data/lib/rumale/preprocessing/standard_scaler.rb +86 -0
  75. data/lib/rumale/probabilistic_output.rb +112 -0
  76. data/lib/rumale/tree/base_decision_tree.rb +153 -0
  77. data/lib/rumale/tree/decision_tree_classifier.rb +163 -0
  78. data/lib/rumale/tree/decision_tree_regressor.rb +135 -0
  79. data/lib/rumale/tree/node.rb +70 -0
  80. data/lib/rumale/utils.rb +37 -0
  81. data/lib/rumale/validation.rb +79 -0
  82. data/lib/rumale/values.rb +13 -0
  83. data/lib/rumale/version.rb +6 -0
  84. data/rumale.gemspec +41 -0
  85. metadata +204 -0
@@ -0,0 +1,76 @@
+# frozen_string_literal: true
+
+require 'rumale/base/splitter'
+
+module Rumale
+  # This module consists of the classes for model validation techniques.
+  module ModelSelection
+    # KFold is a class that generates the set of data indices for K-fold cross-validation.
+    #
+    # @example
+    #   kf = Rumale::ModelSelection::KFold.new(n_splits: 3, shuffle: true, random_seed: 1)
+    #   kf.split(samples, labels).each do |train_ids, test_ids|
+    #     train_samples = samples[train_ids, true]
+    #     test_samples = samples[test_ids, true]
+    #     ...
+    #   end
+    #
+    class KFold
+      include Base::Splitter
+
+      # Return the flag indicating whether to shuffle the dataset.
+      # @return [Boolean]
+      attr_reader :shuffle
+
+      # Return the random generator for shuffling the dataset.
+      # @return [Random]
+      attr_reader :rng
+
+      # Create a new data splitter for K-fold cross validation.
+      #
+      # @param n_splits [Integer] The number of folds.
+      # @param shuffle [Boolean] The flag indicating whether to shuffle the dataset.
+      # @param random_seed [Integer] The seed value used to initialize the random generator.
+      def initialize(n_splits: 3, shuffle: false, random_seed: nil)
+        check_params_integer(n_splits: n_splits)
+        check_params_boolean(shuffle: shuffle)
+        check_params_type_or_nil(Integer, random_seed: random_seed)
+        check_params_positive(n_splits: n_splits)
+        @n_splits = n_splits
+        @shuffle = shuffle
+        @random_seed = random_seed
+        @random_seed ||= srand
+        @rng = Random.new(@random_seed)
+      end
+
+      # Generate data indices for K-fold cross validation.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features])
+      #   The dataset to be used to generate data indices for K-fold cross validation.
+      # @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
+      def split(x, _y = nil)
+        check_sample_array(x)
+        # Initialize and check some variables.
+        n_samples, = x.shape
+        unless @n_splits.between?(2, n_samples)
+          raise ArgumentError,
+                'The value of n_splits must be not less than 2 and not more than the number of samples.'
+        end
+        # Splits dataset ids to each fold.
+        dataset_ids = [*0...n_samples]
+        dataset_ids.shuffle!(random: @rng) if @shuffle
+        fold_sets = Array.new(@n_splits) do |n|
+          n_fold_samples = n_samples / @n_splits
+          n_fold_samples += 1 if n < n_samples % @n_splits
+          dataset_ids.shift(n_fold_samples)
+        end
+        # Returns array consisting of the training and testing ids for each fold.
+        Array.new(@n_splits) do |n|
+          train_ids = fold_sets.select.with_index { |_, id| id != n }.flatten
+          test_ids = fold_sets[n]
+          [train_ids, test_ids]
+        end
+      end
+    end
+  end
+end
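
As a quick illustration of the fold-size arithmetic in KFold#split above (each fold gets n_samples / n_splits ids, and the first n_samples % n_splits folds get one extra), here is a minimal sketch with a synthetic dataset; the sample values are arbitrary placeholders.

    require 'numo/narray'
    require 'rumale'

    # 10 synthetic samples with 2 features; values are arbitrary.
    samples = Numo::DFloat.new(10, 2).rand

    kf = Rumale::ModelSelection::KFold.new(n_splits: 3, shuffle: true, random_seed: 1)
    kf.split(samples).each_with_index do |(train_ids, test_ids), n|
      # 10 / 3 = 3 ids per fold, and the remainder (10 % 3 = 1) is
      # assigned to the first fold, so the test folds have 4, 3, 3 ids.
      puts "fold #{n}: train=#{train_ids.size} test=#{test_ids.size}"
    end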
@@ -0,0 +1,94 @@
+# frozen_string_literal: true
+
+require 'rumale/base/splitter'
+
+module Rumale
+  module ModelSelection
+    # StratifiedKFold is a class that generates the set of data indices for stratified K-fold cross-validation.
+    # The proportion of the number of samples in each class will be almost equal for each fold.
+    #
+    # @example
+    #   kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 3, shuffle: true, random_seed: 1)
+    #   kf.split(samples, labels).each do |train_ids, test_ids|
+    #     train_samples = samples[train_ids, true]
+    #     test_samples = samples[test_ids, true]
+    #     ...
+    #   end
+    #
+    class StratifiedKFold
+      include Base::Splitter
+
+      # Return the flag indicating whether to shuffle the dataset.
+      # @return [Boolean]
+      attr_reader :shuffle
+
+      # Return the random generator for shuffling the dataset.
+      # @return [Random]
+      attr_reader :rng
+
+      # Create a new data splitter for stratified K-fold cross validation.
+      #
+      # @param n_splits [Integer] The number of folds.
+      # @param shuffle [Boolean] The flag indicating whether to shuffle the dataset.
+      # @param random_seed [Integer] The seed value used to initialize the random generator.
+      def initialize(n_splits: 3, shuffle: false, random_seed: nil)
+        check_params_integer(n_splits: n_splits)
+        check_params_boolean(shuffle: shuffle)
+        check_params_type_or_nil(Integer, random_seed: random_seed)
+        check_params_positive(n_splits: n_splits)
+        @n_splits = n_splits
+        @shuffle = shuffle
+        @random_seed = random_seed
+        @random_seed ||= srand
+        @rng = Random.new(@random_seed)
+      end
+
+      # Generate data indices for stratified K-fold cross validation.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features])
+      #   The dataset to be used to generate data indices for stratified K-fold cross validation.
+      #   This argument exists to unify the interface between the K-fold methods; it is not used in the method.
+      # @param y [Numo::Int32] (shape: [n_samples])
+      #   The labels to be used to generate data indices for stratified K-fold cross validation.
+      # @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
+      def split(x, y)
+        check_sample_array(x)
+        check_label_array(y)
+        check_sample_label_size(x, y)
+        # Check the number of samples in each class.
+        unless valid_n_splits?(y)
+          raise ArgumentError,
+                'The value of n_splits must be not less than 2 and not more than the number of samples in each class.'
+        end
+        # Splits dataset ids of each class to each fold.
+        fold_sets_each_class = y.to_a.uniq.map { |label| fold_sets(y, label) }
+        # Returns array consisting of the training and testing ids for each fold.
+        Array.new(@n_splits) { |fold_id| train_test_sets(fold_sets_each_class, fold_id) }
+      end
+
+      private
+
+      def valid_n_splits?(y)
+        y.to_a.uniq.map { |label| y.eq(label).where.size }.all? { |n_samples| @n_splits.between?(2, n_samples) }
+      end
+
+      def fold_sets(y, label)
+        sample_ids = y.eq(label).where.to_a
+        sample_ids.shuffle!(random: @rng) if @shuffle
+        n_samples = sample_ids.size
+        Array.new(@n_splits) do |n|
+          n_fold_samples = n_samples / @n_splits
+          n_fold_samples += 1 if n < n_samples % @n_splits
+          sample_ids.shift(n_fold_samples)
+        end
+      end
+
+      def train_test_sets(fold_sets_each_class, fold_id)
+        train_test_sets_each_class = fold_sets_each_class.map do |folds|
+          folds.partition.with_index { |_, id| id != fold_id }.map(&:flatten)
+        end
+        train_test_sets_each_class.transpose.map(&:flatten)
+      end
+    end
+  end
+end
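
To see the stratification at work, the following sketch uses synthetic, imbalanced labels and shows that each test fold keeps roughly the overall 2:1 class ratio, because fold_sets partitions the ids of each class separately.

    require 'numo/narray'
    require 'rumale'

    # 9 synthetic samples: 6 of class 0 and 3 of class 1.
    samples = Numo::DFloat.new(9, 2).rand
    labels  = Numo::Int32[0, 0, 0, 0, 0, 0, 1, 1, 1]

    skf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 3, shuffle: true, random_seed: 1)
    skf.split(samples, labels).each do |train_ids, test_ids|
      # Each test fold draws ids from every class, so it contains
      # 2 samples of class 0 and 1 sample of class 1.
      puts "test fold labels: #{labels[test_ids].to_a.inspect}"
    end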
@@ -0,0 +1,100 @@
+# frozen_string_literal: true
+
+require 'rumale/base/base_estimator'
+require 'rumale/base/classifier'
+
+module Rumale
+  # This module consists of the classes that implement multi-class classification strategies.
+  module Multiclass
+    # @note
+    #   All classifiers in Rumale support multi-class classification since version 0.2.7.
+    #   There is no need to explicitly use this class for multi-class classification.
+    #
+    # OneVsRestClassifier is a class that implements the One-vs-Rest (OvR) strategy for multi-class classification.
+    #
+    # @example
+    #   base_estimator = Rumale::LinearModel::LogisticRegression.new
+    #   estimator = Rumale::Multiclass::OneVsRestClassifier.new(estimator: base_estimator)
+    #   estimator.fit(training_samples, training_labels)
+    #   results = estimator.predict(testing_samples)
+    class OneVsRestClassifier
+      include Base::BaseEstimator
+      include Base::Classifier
+
+      # Return the set of estimators.
+      # @return [Array<Classifier>]
+      attr_reader :estimators
+
+      # Return the class labels.
+      # @return [Numo::Int32] (shape: [n_classes])
+      attr_reader :classes
+
+      # Create a new multi-class classifier with the one-vs-rest strategy.
+      #
+      # @param estimator [Classifier] The (binary) classifier for constructing a multi-class classifier.
+      def initialize(estimator: nil)
+        check_params_type(Rumale::Base::BaseEstimator, estimator: estimator)
+        @params = {}
+        @params[:estimator] = estimator
+        @estimators = nil
+        @classes = nil
+      end
+
+      # Fit the model with given training data.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
+      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
+      # @return [OneVsRestClassifier] The learned classifier itself.
+      def fit(x, y)
+        check_sample_array(x)
+        check_label_array(y)
+        check_sample_label_size(x, y)
+        y_arr = y.to_a
+        @classes = Numo::Int32.asarray(y_arr.uniq.sort)
+        @estimators = @classes.to_a.map do |label|
+          bin_y = Numo::Int32.asarray(y_arr.map { |l| l == label ? 1 : -1 })
+          @params[:estimator].dup.fit(x, bin_y)
+        end
+        self
+      end
+
+      # Calculate confidence scores for samples.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
+      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Confidence scores per sample for each class.
+      def decision_function(x)
+        check_sample_array(x)
+        n_classes = @classes.size
+        Numo::DFloat.asarray(Array.new(n_classes) { |m| @estimators[m].decision_function(x).to_a }).transpose
+      end
+
+      # Predict class labels for samples.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
+      # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
+      def predict(x)
+        check_sample_array(x)
+        n_samples, = x.shape
+        decision_values = decision_function(x)
+        Numo::Int32.asarray(Array.new(n_samples) { |n| @classes[decision_values[n, true].max_index] })
+      end
+
+      # Dump marshal data.
+      # @return [Hash] The marshal data about OneVsRestClassifier.
+      def marshal_dump
+        { params: @params,
+          classes: @classes,
+          estimators: @estimators.map { |e| Marshal.dump(e) } }
+      end
+
+      # Load marshal data.
+      # @return [nil]
+      def marshal_load(obj)
+        @params = obj[:params]
+        @classes = obj[:classes]
+        @estimators = obj[:estimators].map { |e| Marshal.load(e) }
+        nil
+      end
+    end
+  end
+end
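
The note above says this class is now redundant for Rumale's own classifiers, but it still illustrates the OvR mechanics: fit trains one binary copy of the base estimator per class, and predict takes the argmax over the per-class decision values. A minimal sketch with synthetic three-class data (labels and values are arbitrary placeholders):

    require 'numo/narray'
    require 'rumale'

    x = Numo::DFloat.new(9, 2).rand
    y = Numo::Int32[0, 0, 0, 1, 1, 1, 2, 2, 2]

    base = Rumale::LinearModel::LogisticRegression.new(random_seed: 1)
    ovr = Rumale::Multiclass::OneVsRestClassifier.new(estimator: base)
    ovr.fit(x, y)

    puts ovr.estimators.size         # => 3, one binary classifier per class
    puts ovr.predict(x).to_a.inspect # argmax over the decision_function columns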
@@ -0,0 +1,315 @@
+# frozen_string_literal: true
+
+require 'rumale/base/base_estimator'
+require 'rumale/base/classifier'
+
+module Rumale
+  # This module consists of the classes that implement naive Bayes models.
+  module NaiveBayes
+    # BaseNaiveBayes is a class that has methods for common processes of naive Bayes classifiers.
+    class BaseNaiveBayes
+      include Base::BaseEstimator
+      include Base::Classifier
+
+      # Predict class labels for samples.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
+      # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
+      def predict(x)
+        check_sample_array(x)
+        n_samples = x.shape.first
+        decision_values = decision_function(x)
+        Numo::Int32.asarray(Array.new(n_samples) { |n| @classes[decision_values[n, true].max_index] })
+      end
+
+      # Predict log-probability for samples.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the log-probabilities.
+      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted log-probability of each class per sample.
+      def predict_log_proba(x)
+        check_sample_array(x)
+        n_samples, = x.shape
+        log_likelihoods = decision_function(x)
+        log_likelihoods - Numo::NMath.log(Numo::NMath.exp(log_likelihoods).sum(1)).reshape(n_samples, 1)
+      end
+
+      # Predict probability for samples.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
+      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
+      def predict_proba(x)
+        check_sample_array(x)
+        Numo::NMath.exp(predict_log_proba(x)).abs
+      end
+    end
+
+    # GaussianNB is a class that implements the Gaussian Naive Bayes classifier.
+    #
+    # @example
+    #   estimator = Rumale::NaiveBayes::GaussianNB.new
+    #   estimator.fit(training_samples, training_labels)
+    #   results = estimator.predict(testing_samples)
+    class GaussianNB < BaseNaiveBayes
+      # Return the class labels.
+      # @return [Numo::Int32] (size: n_classes)
+      attr_reader :classes
+
+      # Return the prior probabilities of the classes.
+      # @return [Numo::DFloat] (shape: [n_classes])
+      attr_reader :class_priors
+
+      # Return the mean vectors of the classes.
+      # @return [Numo::DFloat] (shape: [n_classes, n_features])
+      attr_reader :means
+
+      # Return the variance vectors of the classes.
+      # @return [Numo::DFloat] (shape: [n_classes, n_features])
+      attr_reader :variances
+
+      # Create a new classifier with Gaussian Naive Bayes.
+      def initialize
+        @params = {}
+      end
+
+      # Fit the model with given training data.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
+      # @param y [Numo::Int32] (shape: [n_samples]) The categorical variables (e.g. labels)
+      #   to be used for fitting the model.
+      # @return [GaussianNB] The learned classifier itself.
+      def fit(x, y)
+        check_sample_array(x)
+        check_label_array(y)
+        check_sample_label_size(x, y)
+        n_samples, = x.shape
+        @classes = Numo::Int32[*y.to_a.uniq.sort]
+        @class_priors = Numo::DFloat[*@classes.to_a.map { |l| y.eq(l).count / n_samples.to_f }]
+        @means = Numo::DFloat[*@classes.to_a.map { |l| x[y.eq(l).where, true].mean(0) }]
+        @variances = Numo::DFloat[*@classes.to_a.map { |l| x[y.eq(l).where, true].var(0) }]
+        self
+      end
+
+      # Calculate confidence scores for samples.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
+      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Confidence scores per sample for each class.
+      def decision_function(x)
+        check_sample_array(x)
+        n_classes = @classes.size
+        log_likelihoods = Array.new(n_classes) do |l|
+          Math.log(@class_priors[l]) - 0.5 * (
+            Numo::NMath.log(2.0 * Math::PI * @variances[l, true]) +
+            ((x - @means[l, true])**2 / @variances[l, true])).sum(1)
+        end
+        Numo::DFloat[*log_likelihoods].transpose
+      end
+
+      # Dump marshal data.
+      #
+      # @return [Hash] The marshal data about GaussianNB.
+      def marshal_dump
+        { params: @params,
+          classes: @classes,
+          class_priors: @class_priors,
+          means: @means,
+          variances: @variances }
+      end
+
+      # Load marshal data.
+      #
+      # @return [nil]
+      def marshal_load(obj)
+        @params = obj[:params]
+        @classes = obj[:classes]
+        @class_priors = obj[:class_priors]
+        @means = obj[:means]
+        @variances = obj[:variances]
+        nil
+      end
+    end
+
+    # MultinomialNB is a class that implements the Multinomial Naive Bayes classifier.
+    #
+    # @example
+    #   estimator = Rumale::NaiveBayes::MultinomialNB.new(smoothing_param: 1.0)
+    #   estimator.fit(training_samples, training_labels)
+    #   results = estimator.predict(testing_samples)
+    #
+    # *Reference*
+    # - C. D. Manning, P. Raghavan, and H. Schutze, "Introduction to Information Retrieval," Cambridge University Press, 2008.
+    class MultinomialNB < BaseNaiveBayes
+      # Return the class labels.
+      # @return [Numo::Int32] (size: n_classes)
+      attr_reader :classes
+
+      # Return the prior probabilities of the classes.
+      # @return [Numo::DFloat] (shape: [n_classes])
+      attr_reader :class_priors
+
+      # Return the conditional probabilities for features of each class.
+      # @return [Numo::DFloat] (shape: [n_classes, n_features])
+      attr_reader :feature_probs
+
+      # Create a new classifier with Multinomial Naive Bayes.
+      #
+      # @param smoothing_param [Float] The Laplace smoothing parameter.
+      def initialize(smoothing_param: 1.0)
+        check_params_float(smoothing_param: smoothing_param)
+        check_params_positive(smoothing_param: smoothing_param)
+        @params = {}
+        @params[:smoothing_param] = smoothing_param
+      end
+
+      # Fit the model with given training data.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
+      # @param y [Numo::Int32] (shape: [n_samples]) The categorical variables (e.g. labels)
+      #   to be used for fitting the model.
+      # @return [MultinomialNB] The learned classifier itself.
+      def fit(x, y)
+        check_sample_array(x)
+        check_label_array(y)
+        check_sample_label_size(x, y)
+        n_samples, = x.shape
+        @classes = Numo::Int32[*y.to_a.uniq.sort]
+        @class_priors = Numo::DFloat[*@classes.to_a.map { |l| y.eq(l).count / n_samples.to_f }]
+        count_features = Numo::DFloat[*@classes.to_a.map { |l| x[y.eq(l).where, true].sum(0) }]
+        count_features += @params[:smoothing_param]
+        n_classes = @classes.size
+        @feature_probs = count_features / count_features.sum(1).reshape(n_classes, 1)
+        self
+      end
+
+      # Calculate confidence scores for samples.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
+      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Confidence scores per sample for each class.
+      def decision_function(x)
+        check_sample_array(x)
+        n_classes = @classes.size
+        bin_x = x.gt(0)
+        log_likelihoods = Array.new(n_classes) do |l|
+          Math.log(@class_priors[l]) + (Numo::DFloat[*bin_x] * Numo::NMath.log(@feature_probs[l, true])).sum(1)
+        end
+        Numo::DFloat[*log_likelihoods].transpose
+      end
+
+      # Dump marshal data.
+      #
+      # @return [Hash] The marshal data about MultinomialNB.
+      def marshal_dump
+        { params: @params,
+          classes: @classes,
+          class_priors: @class_priors,
+          feature_probs: @feature_probs }
+      end
+
+      # Load marshal data.
+      #
+      # @return [nil]
+      def marshal_load(obj)
+        @params = obj[:params]
+        @classes = obj[:classes]
+        @class_priors = obj[:class_priors]
+        @feature_probs = obj[:feature_probs]
+        nil
+      end
+    end
+
+    # BernoulliNB is a class that implements the Bernoulli Naive Bayes classifier.
+    #
+    # @example
+    #   estimator = Rumale::NaiveBayes::BernoulliNB.new(smoothing_param: 1.0, bin_threshold: 0.0)
+    #   estimator.fit(training_samples, training_labels)
+    #   results = estimator.predict(testing_samples)
+    #
+    # *Reference*
+    # - C. D. Manning, P. Raghavan, and H. Schutze, "Introduction to Information Retrieval," Cambridge University Press, 2008.
+    class BernoulliNB < BaseNaiveBayes
+      # Return the class labels.
+      # @return [Numo::Int32] (size: n_classes)
+      attr_reader :classes
+
+      # Return the prior probabilities of the classes.
+      # @return [Numo::DFloat] (shape: [n_classes])
+      attr_reader :class_priors
+
+      # Return the conditional probabilities for features of each class.
+      # @return [Numo::DFloat] (shape: [n_classes, n_features])
+      attr_reader :feature_probs
+
+      # Create a new classifier with Bernoulli Naive Bayes.
+      #
+      # @param smoothing_param [Float] The Laplace smoothing parameter.
+      # @param bin_threshold [Float] The threshold for binarizing the features.
+      def initialize(smoothing_param: 1.0, bin_threshold: 0.0)
+        check_params_float(smoothing_param: smoothing_param, bin_threshold: bin_threshold)
+        check_params_positive(smoothing_param: smoothing_param)
+        @params = {}
+        @params[:smoothing_param] = smoothing_param
+        @params[:bin_threshold] = bin_threshold
+      end
+
+      # Fit the model with given training data.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
+      # @param y [Numo::Int32] (shape: [n_samples]) The categorical variables (e.g. labels)
+      #   to be used for fitting the model.
+      # @return [BernoulliNB] The learned classifier itself.
+      def fit(x, y)
+        check_sample_array(x)
+        check_label_array(y)
+        check_sample_label_size(x, y)
+        n_samples, = x.shape
+        bin_x = Numo::DFloat[*x.gt(@params[:bin_threshold])]
+        @classes = Numo::Int32[*y.to_a.uniq.sort]
+        n_samples_each_class = Numo::DFloat[*@classes.to_a.map { |l| y.eq(l).count.to_f }]
+        @class_priors = n_samples_each_class / n_samples
+        count_features = Numo::DFloat[*@classes.to_a.map { |l| bin_x[y.eq(l).where, true].sum(0) }]
+        count_features += @params[:smoothing_param]
+        n_samples_each_class += 2.0 * @params[:smoothing_param]
+        n_classes = @classes.size
+        @feature_probs = count_features / n_samples_each_class.reshape(n_classes, 1)
+        self
+      end
+
+      # Calculate confidence scores for samples.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
+      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Confidence scores per sample for each class.
+      def decision_function(x)
+        check_sample_array(x)
+        n_classes = @classes.size
+        bin_x = Numo::DFloat[*x.gt(@params[:bin_threshold])]
+        not_bin_x = Numo::DFloat[*x.le(@params[:bin_threshold])]
+        log_likelihoods = Array.new(n_classes) do |l|
+          Math.log(@class_priors[l]) + (
+            (Numo::DFloat[*bin_x] * Numo::NMath.log(@feature_probs[l, true])).sum(1) +
+            (Numo::DFloat[*not_bin_x] * Numo::NMath.log(1.0 - @feature_probs[l, true])).sum(1))
+        end
+        Numo::DFloat[*log_likelihoods].transpose
+      end
+
+      # Dump marshal data.
+      #
+      # @return [Hash] The marshal data about BernoulliNB.
+      def marshal_dump
+        { params: @params,
+          classes: @classes,
+          class_priors: @class_priors,
+          feature_probs: @feature_probs }
+      end
+
+      # Load marshal data.
+      #
+      # @return [nil]
+      def marshal_load(obj)
+        @params = obj[:params]
+        @classes = obj[:classes]
+        @class_priors = obj[:class_priors]
+        @feature_probs = obj[:feature_probs]
+        nil
+      end
+    end
+  end
+end
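
As a sanity check on GaussianNB#decision_function above, this sketch (synthetic data, arbitrary values) recomputes the score of one sample for one class by hand: the log prior plus the summed Gaussian log-density, which should match the method's output up to floating-point error.

    require 'numo/narray'
    require 'rumale'

    x = Numo::DFloat.new(8, 2).rand
    y = Numo::Int32[0, 0, 0, 0, 1, 1, 1, 1]

    nb = Rumale::NaiveBayes::GaussianNB.new.fit(x, y)

    # Score of sample 0 for class 0:
    #   log P(c) - 0.5 * sum_d [ log(2 * pi * var_d) + (x_d - mean_d)**2 / var_d ]
    sample = x[0, true]
    manual = Math.log(nb.class_priors[0]) -
             0.5 * (Numo::NMath.log(2.0 * Math::PI * nb.variances[0, true]) +
                    (sample - nb.means[0, true])**2 / nb.variances[0, true]).sum

    puts manual
    puts nb.decision_function(x)[0, 0] # should match the value above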