svmkit 0.2.8 → 0.2.9
This diff compares the publicly released contents of the two package versions as they appear in their public registries, and is provided for informational purposes only.
- checksums.yaml +5 -5
- data/.gitignore +4 -0
- data/.rubocop.yml +10 -1
- data/.rubocop_todo.yml +51 -10
- data/Gemfile +1 -1
- data/HISTORY.md +43 -33
- data/lib/svmkit.rb +4 -0
- data/lib/svmkit/base/classifier.rb +1 -0
- data/lib/svmkit/ensemble/random_forest_classifier.rb +5 -2
- data/lib/svmkit/evaluation_measure/log_loss.rb +44 -0
- data/lib/svmkit/kernel_approximation/rbf.rb +1 -1
- data/lib/svmkit/kernel_machine/kernel_svc.rb +40 -2
- data/lib/svmkit/linear_model/logistic_regression.rb +3 -1
- data/lib/svmkit/linear_model/svc.rb +46 -7
- data/lib/svmkit/model_selection/cross_validation.rb +9 -1
- data/lib/svmkit/model_selection/k_fold.rb +1 -1
- data/lib/svmkit/model_selection/stratified_k_fold.rb +3 -2
- data/lib/svmkit/multiclass/one_vs_rest_classifier.rb +1 -0
- data/lib/svmkit/naive_bayes/naive_bayes.rb +5 -0
- data/lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb +2 -0
- data/lib/svmkit/polynomial_model/factorization_machine_classifier.rb +4 -1
- data/lib/svmkit/preprocessing/label_encoder.rb +94 -0
- data/lib/svmkit/preprocessing/one_hot_encoder.rb +98 -0
- data/lib/svmkit/probabilistic_output.rb +112 -0
- data/lib/svmkit/tree/decision_tree_classifier.rb +80 -10
- data/lib/svmkit/validation.rb +12 -0
- data/lib/svmkit/version.rb +1 -1
- data/svmkit.gemspec +4 -6
- metadata +18 -14
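
At a glance, 0.2.9 adds probability estimation to the SVM classifiers (a new `probability:` option and `predict_proba` method on `SVMKit::LinearModel::SVC`, with matching growth in `kernel_machine/kernel_svc.rb`, both backed by the new `SVMKit::ProbabilisticOutput` sigmoid-fitting module), a log-loss metric (`evaluation_measure/log_loss.rb`) wired into cross validation, new `LabelEncoder` and `OneHotEncoder` preprocessors, and stricter argument validation (`check_params_positive`, `check_sample_label_size`) across estimators and splitters. The hunks below cover these changes file by file.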
data/lib/svmkit/linear_model/svc.rb

```diff
@@ -45,26 +45,30 @@ module SVMKit
       # @param bias_scale [Float] The scale of the bias term.
       # @param max_iter [Integer] The maximum number of iterations.
       # @param batch_size [Integer] The size of the mini batches.
+      # @param probability [Boolean] The flag indicating whether to perform probability estimation.
       # @param normalize [Boolean] The flag indicating whether to normalize the weight vector.
       # @param random_seed [Integer] The seed value using to initialize the random generator.
       def initialize(reg_param: 1.0, fit_bias: false, bias_scale: 1.0,
-                     max_iter: 100, batch_size: 50, normalize: true, random_seed: nil)
+                     max_iter: 100, batch_size: 50, probability: false, normalize: true, random_seed: nil)
         SVMKit::Validation.check_params_float(reg_param: reg_param, bias_scale: bias_scale)
         SVMKit::Validation.check_params_integer(max_iter: max_iter, batch_size: batch_size)
-        SVMKit::Validation.check_params_boolean(fit_bias: fit_bias, normalize: normalize)
+        SVMKit::Validation.check_params_boolean(fit_bias: fit_bias, probability: probability, normalize: normalize)
         SVMKit::Validation.check_params_type_or_nil(Integer, random_seed: random_seed)
-
+        SVMKit::Validation.check_params_positive(reg_param: reg_param, bias_scale: bias_scale, max_iter: max_iter,
+                                                 batch_size: batch_size)
         @params = {}
         @params[:reg_param] = reg_param
         @params[:fit_bias] = fit_bias
         @params[:bias_scale] = bias_scale
         @params[:max_iter] = max_iter
         @params[:batch_size] = batch_size
+        @params[:probability] = probability
         @params[:normalize] = normalize
         @params[:random_seed] = random_seed
         @params[:random_seed] ||= srand
         @weight_vec = nil
         @bias_term = nil
+        @prob_param = nil
         @classes = nil
         @rng = Random.new(@params[:random_seed])
       end
```
```diff
@@ -77,6 +81,7 @@ module SVMKit
       def fit(x, y)
         SVMKit::Validation.check_sample_array(x)
         SVMKit::Validation.check_label_array(y)
+        SVMKit::Validation.check_sample_label_size(x, y)

         @classes = Numo::Int32[*y.to_a.uniq.sort]
         n_classes = @classes.size
```
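
Still in svc.rb: the new `check_sample_label_size` guard rejects mismatched sample and label arrays before any training work happens. A minimal sketch of the new failure mode (assuming the validator raises `ArgumentError` like svmkit's other checks; the `validation.rb` hunk itself is not shown in this diff):

```ruby
require 'svmkit'

samples = Numo::DFloat.new(5, 2).rand # 5 samples...
labels  = Numo::Int32[1, -1, 1]       # ...but only 3 labels

# Assumed behavior: fit now fails fast instead of training on inconsistent data.
SVMKit::LinearModel::SVC.new.fit(samples, labels) # => raises ArgumentError
```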
```diff
@@ -85,16 +90,27 @@
         if n_classes > 2
           @weight_vec = Numo::DFloat.zeros(n_classes, n_features)
           @bias_term = Numo::DFloat.zeros(n_classes)
+          @prob_param = Numo::DFloat.zeros(n_classes, 2)
           n_classes.times do |n|
             bin_y = Numo::Int32.cast(y.eq(@classes[n])) * 2 - 1
             weight, bias = binary_fit(x, bin_y)
             @weight_vec[n, true] = weight
             @bias_term[n] = bias
+            @prob_param[n, true] = if @params[:probability]
+                                     SVMKit::ProbabilisticOutput.fit_sigmoid(x.dot(weight.transpose) + bias, bin_y)
+                                   else
+                                     Numo::DFloat[1, 0]
+                                   end
           end
         else
           negative_label = y.to_a.uniq.sort.first
           bin_y = Numo::Int32.cast(y.ne(negative_label)) * 2 - 1
           @weight_vec, @bias_term = binary_fit(x, bin_y)
+          @prob_param = if @params[:probability]
+                          SVMKit::ProbabilisticOutput.fit_sigmoid(x.dot(@weight_vec.transpose) + @bias_term, bin_y)
+                        else
+                          Numo::DFloat[1, 0]
+                        end
         end

         self
```
```diff
@@ -124,12 +140,32 @@
         Numo::Int32.asarray(Array.new(n_samples) { |n| @classes[decision_values[n, true].max_index] })
       end

+      # Predict probability for samples.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probailities.
+      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
+      def predict_proba(x)
+        SVMKit::Validation.check_sample_array(x)
+
+        if @classes.size > 2
+          probs = 1.0 / (Numo::NMath.exp(@prob_param[true, 0] * decision_function(x) + @prob_param[true, 1]) + 1.0)
+          return (probs.transpose / probs.sum(axis: 1)).transpose
+        end
+
+        n_samples, = x.shape
+        probs = Numo::DFloat.zeros(n_samples, 2)
+        probs[true, 1] = 1.0 / (Numo::NMath.exp(@prob_param[0] * decision_function(x) + @prob_param[1]) + 1.0)
+        probs[true, 0] = 1.0 - probs[true, 1]
+        probs
+      end
+
       # Dump marshal data.
       # @return [Hash] The marshal data about SVC.
       def marshal_dump
         { params: @params,
           weight_vec: @weight_vec,
           bias_term: @bias_term,
+          prob_param: @prob_param,
           classes: @classes,
           rng: @rng }
       end
```
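
Taken together, the constructor flag and the two fit branches give SVC a scikit-learn-style `predict_proba`. A minimal usage sketch on toy data (the option and method names come straight from the hunks above):

```ruby
require 'svmkit'

# Two tiny, linearly separable classes.
x = Numo::DFloat[[-2.0, -1.9], [-1.6, -2.1], [1.8, 2.2], [2.1, 1.7]]
y = Numo::Int32[-1, -1, 1, 1]

# probability: true makes fit also run Platt scaling
# (SVMKit::ProbabilisticOutput.fit_sigmoid) on the decision values.
svc = SVMKit::LinearModel::SVC.new(probability: true, random_seed: 1)
svc.fit(x, y)

svc.predict_proba(x) # Numo::DFloat of shape [4, 2]; columns follow the sorted class labels
```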
```diff
@@ -140,6 +176,7 @@ module SVMKit
         @params = obj[:params]
         @weight_vec = obj[:weight_vec]
         @bias_term = obj[:bias_term]
+        @prob_param = obj[:prob_param]
         @classes = obj[:classes]
         @rng = obj[:rng]
         nil
```
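
Because `prob_param` is included in `marshal_dump` and restored in `marshal_load`, a probability-enabled model survives serialization (continuing the sketch above):

```ruby
restored = Marshal.load(Marshal.dump(svc)) # round-trips @prob_param along with the weights
restored.predict_proba(x)                  # same probabilities as before dumping
```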
```diff
@@ -159,11 +196,13 @@ module SVMKit
           # random sampling
           subset_ids = rand_ids.shift(@params[:batch_size])
           rand_ids.concat(subset_ids)
-
-
-
+          sub_samples = samples[subset_ids, true]
+          sub_bin_y = bin_y[subset_ids]
+          target_ids = (sub_samples.dot(weight_vec.transpose) * sub_bin_y).lt(1.0).where
+          n_targets = target_ids.size
+          next if n_targets.zero?
           # update the weight vector.
-          mean_vec =
+          mean_vec = sub_samples[target_ids, true].transpose.dot(sub_bin_y[target_ids]) / n_targets
           weight_vec -= learning_rate(t) * (@params[:reg_param] * weight_vec - mean_vec)
           # scale the weight vector.
           normalize_weight_vec(weight_vec) if @params[:normalize]
```
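
The rewritten mini-batch step is the usual Pegasos-style subgradient update, now vectorized over the whole batch. With the margin violators $V = \{i : y_i\, w^\top x_i < 1\}$, the update reads

$$
w \leftarrow w - \eta_t \Bigl( \lambda w - \frac{1}{|V|} \sum_{i \in V} y_i x_i \Bigr),
$$

which is exactly `weight_vec -= learning_rate(t) * (@params[:reg_param] * weight_vec - mean_vec)`, with `mean_vec` the violator mean of `y_i * x_i` computed by the new vectorized lines.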
data/lib/svmkit/model_selection/cross_validation.rb

```diff
@@ -62,6 +62,7 @@ module SVMKit
       def perform(x, y)
         SVMKit::Validation.check_sample_array(x)
         SVMKit::Validation.check_label_array(y)
+        SVMKit::Validation.check_sample_label_size(x, y)
         # Initialize the report of cross validation.
         report = { test_score: [], train_score: nil, fit_time: [] }
         report[:train_score] = [] if @return_train_score
@@ -81,9 +82,12 @@ module SVMKit
           if @evaluator.nil?
             report[:test_score].push(@estimator.score(test_x, test_y))
             report[:train_score].push(@estimator.score(train_x, train_y)) if @return_train_score
+          elsif log_loss?
+            report[:test_score].push(@evaluator.score(test_y, @estimator.predict_proba(test_x)))
+            report[:train_score].push(@evaluator.score(train_y, @estimator.predict_proba(train_x))) if @return_train_score
           else
             report[:test_score].push(@evaluator.score(test_y, @estimator.predict(test_x)))
-            report[:train_score].push(@
+            report[:train_score].push(@evaluator.score(train_y, @estimator.predict(train_x))) if @return_train_score
           end
         end
         report
@@ -96,6 +100,10 @@ module SVMKit
         class_name = @estimator.params[:estimator].class.to_s if class_name.include?('Multiclass')
         class_name.include?('KernelMachine')
       end
+
+      def log_loss?
+        @evaluator.is_a?(SVMKit::EvaluationMeasure::LogLoss)
+      end
     end
   end
 end
```
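
With the `log_loss?` branch in place, any estimator exposing `predict_proba` can be cross-validated on log loss directly. A sketch reusing the toy `x` and `y` from the SVC example above (the `CrossValidation` keyword arguments are inferred from the instance variables in the hunks, so treat the exact constructor signature as an assumption):

```ruby
svc = SVMKit::LinearModel::SVC.new(probability: true, random_seed: 1)
skf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 2, shuffle: true, random_seed: 1)

# Assumed keyword names: estimator, splitter, evaluator.
cv = SVMKit::ModelSelection::CrossValidation.new(estimator: svc, splitter: skf,
                                                 evaluator: SVMKit::EvaluationMeasure::LogLoss.new)
report = cv.perform(x, y) # test folds are now scored via predict_proba
report[:test_score].inject(:+) / report[:test_score].size # mean log loss
```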
data/lib/svmkit/model_selection/k_fold.rb

```diff
@@ -35,7 +35,7 @@ module SVMKit
         SVMKit::Validation.check_params_integer(n_splits: n_splits)
         SVMKit::Validation.check_params_boolean(shuffle: shuffle)
         SVMKit::Validation.check_params_type_or_nil(Integer, random_seed: random_seed)
-
+        SVMKit::Validation.check_params_positive(n_splits: n_splits)
         @n_splits = n_splits
         @shuffle = shuffle
         @random_seed = random_seed
```
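
The practical effect of this one-liner (and the identical one in stratified_k_fold.rb next): a non-positive `n_splits` now fails at construction time instead of silently yielding no folds. Assuming `check_params_positive` raises `ArgumentError` (its implementation in validation.rb is not shown here):

```ruby
SVMKit::ModelSelection::KFold.new(n_splits: 0) # => assumed to raise ArgumentError up front
```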
data/lib/svmkit/model_selection/stratified_k_fold.rb

```diff
@@ -35,7 +35,7 @@ module SVMKit
         SVMKit::Validation.check_params_integer(n_splits: n_splits)
         SVMKit::Validation.check_params_boolean(shuffle: shuffle)
         SVMKit::Validation.check_params_type_or_nil(Integer, random_seed: random_seed)
-
+        SVMKit::Validation.check_params_positive(n_splits: n_splits)
         @n_splits = n_splits
         @shuffle = shuffle
         @random_seed = random_seed
@@ -51,9 +51,10 @@ module SVMKit
       # @param y [Numo::Int32] (shape: [n_samples])
       #   The labels to be used to generate data indices for stratified K-fold cross validation.
       # @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
-      def split(x, y)
+      def split(x, y)
         SVMKit::Validation.check_sample_array(x)
         SVMKit::Validation.check_label_array(y)
+        SVMKit::Validation.check_sample_label_size(x, y)
         # Check the number of samples in each class.
         unless valid_n_splits?(y)
           raise ArgumentError,
```
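
For orientation, `split` returns one `[train_ids, test_ids]` pair per fold (per its docstring), so the newly guarded method is typically driven like this (sketch, again with the toy data from the SVC example):

```ruby
skf = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 2, shuffle: true, random_seed: 1)
skf.split(x, y).each do |train_ids, test_ids|
  model = SVMKit::LinearModel::SVC.new(random_seed: 1)
  model.fit(x[train_ids, true], y[train_ids])
  puts model.score(x[test_ids, true], y[test_ids])
end
```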
data/lib/svmkit/multiclass/one_vs_rest_classifier.rb

```diff
@@ -48,6 +48,7 @@ module SVMKit
       def fit(x, y)
         SVMKit::Validation.check_sample_array(x)
         SVMKit::Validation.check_label_array(y)
+        SVMKit::Validation.check_sample_label_size(x, y)
         y_arr = y.to_a
         @classes = Numo::Int32.asarray(y_arr.uniq.sort)
         @estimators = @classes.to_a.map do |label|
```
data/lib/svmkit/naive_bayes/naive_bayes.rb

```diff
@@ -80,6 +80,7 @@ module SVMKit
       def fit(x, y)
         SVMKit::Validation.check_sample_array(x)
         SVMKit::Validation.check_label_array(y)
+        SVMKit::Validation.check_sample_label_size(x, y)
         n_samples, = x.shape
         @classes = Numo::Int32[*y.to_a.uniq.sort]
         @class_priors = Numo::DFloat[*@classes.to_a.map { |l| y.eq(l).count / n_samples.to_f }]
@@ -154,6 +155,7 @@ module SVMKit
       # @param smoothing_param [Float] The Laplace smoothing parameter.
       def initialize(smoothing_param: 1.0)
         SVMKit::Validation.check_params_float(smoothing_param: smoothing_param)
+        SVMKit::Validation.check_params_positive(smoothing_param: smoothing_param)
         @params = {}
         @params[:smoothing_param] = smoothing_param
       end
@@ -167,6 +169,7 @@ module SVMKit
       def fit(x, y)
         SVMKit::Validation.check_sample_array(x)
         SVMKit::Validation.check_label_array(y)
+        SVMKit::Validation.check_sample_label_size(x, y)
         n_samples, = x.shape
         @classes = Numo::Int32[*y.to_a.uniq.sort]
         @class_priors = Numo::DFloat[*@classes.to_a.map { |l| y.eq(l).count / n_samples.to_f }]
@@ -241,6 +244,7 @@ module SVMKit
       # @param bin_threshold [Float] The threshold for binarizing of features.
       def initialize(smoothing_param: 1.0, bin_threshold: 0.0)
         SVMKit::Validation.check_params_float(smoothing_param: smoothing_param, bin_threshold: bin_threshold)
+        SVMKit::Validation.check_params_positive(smoothing_param: smoothing_param)
         @params = {}
         @params[:smoothing_param] = smoothing_param
         @params[:bin_threshold] = bin_threshold
@@ -255,6 +259,7 @@ module SVMKit
       def fit(x, y)
         SVMKit::Validation.check_sample_array(x)
         SVMKit::Validation.check_label_array(y)
+        SVMKit::Validation.check_sample_label_size(x, y)
         n_samples, = x.shape
         bin_x = Numo::DFloat[*x.gt(@params[:bin_threshold])]
         @classes = Numo::Int32[*y.to_a.uniq.sort]
```
data/lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb

```diff
@@ -36,6 +36,7 @@ module SVMKit
       # @param n_neighbors [Integer] The number of neighbors.
       def initialize(n_neighbors: 5)
         SVMKit::Validation.check_params_integer(n_neighbors: n_neighbors)
+        SVMKit::Validation.check_params_positive(n_neighbors: n_neighbors)
         @params = {}
         @params[:n_neighbors] = n_neighbors
         @prototypes = nil
@@ -51,6 +52,7 @@ module SVMKit
       def fit(x, y)
         SVMKit::Validation.check_sample_array(x)
         SVMKit::Validation.check_label_array(y)
+        SVMKit::Validation.check_sample_label_size(x, y)
         @prototypes = Numo::DFloat.asarray(x.to_a)
         @labels = Numo::Int32.asarray(y.to_a)
         @classes = Numo::Int32.asarray(y.to_a.uniq.sort)
```
data/lib/svmkit/polynomial_model/factorization_machine_classifier.rb

```diff
@@ -63,7 +63,9 @@ module SVMKit
         SVMKit::Validation.check_params_integer(n_factors: n_factors, max_iter: max_iter, batch_size: batch_size)
         SVMKit::Validation.check_params_string(loss: loss)
         SVMKit::Validation.check_params_type_or_nil(Integer, random_seed: random_seed)
-
+        SVMKit::Validation.check_params_positive(n_factors: n_factors, reg_param_bias: reg_param_bias,
+                                                 reg_param_weight: reg_param_weight, reg_param_factor: reg_param_factor,
+                                                 max_iter: max_iter, batch_size: batch_size)
         @params = {}
         @params[:n_factors] = n_factors
         @params[:loss] = loss
@@ -90,6 +92,7 @@ module SVMKit
       def fit(x, y)
         SVMKit::Validation.check_sample_array(x)
         SVMKit::Validation.check_label_array(y)
+        SVMKit::Validation.check_sample_label_size(x, y)

         @classes = Numo::Int32[*y.to_a.uniq.sort]
         n_classes = @classes.size
```
data/lib/svmkit/preprocessing/label_encoder.rb (new file)

```diff
@@ -0,0 +1,94 @@
+# frozen_string_literal: true
+
+require 'svmkit/base/base_estimator'
+require 'svmkit/base/transformer'
+
+module SVMKit
+  module Preprocessing
+    # Encode labels to values between 0 and n_classes - 1.
+    #
+    # @example
+    #   encoder = SVMKit::Preprocessing::LabelEncoder.new
+    #   labels = Numo::Int32[1, 8, 8, 15, 0]
+    #   encoded_labels = encoder.fit_transform(labels)
+    #   # > pp encoded_labels
+    #   # Numo::Int32#shape=[5]
+    #   # [1, 2, 2, 3, 0]
+    #   decoded_labels = encoder.inverse_transform(encoded_labels)
+    #   # > pp decoded_labels
+    #   # [1, 8, 8, 15, 0]
+    class LabelEncoder
+      include Base::BaseEstimator
+      include Base::Transformer
+
+      # Return the class labels.
+      # @return [Array] (size: [n_classes])
+      attr_reader :classes
+
+      # Create a new encoder for encoding labels to values between 0 and n_classes - 1.
+      def initialize
+        @params = {}
+        @classes = nil
+      end
+
+      # Fit label-encoder to labels.
+      #
+      # @overload fit(x) -> LabelEncoder
+      #
+      # @param x [Array] (shape: [n_samples]) The labels to fit label-encoder.
+      # @return [LabelEncoder]
+      def fit(x, _y = nil)
+        x = x.to_a if x.is_a?(Numo::NArray)
+        SVMKit::Validation.check_params_type(Array, x: x)
+        @classes = x.sort.uniq
+        self
+      end
+
+      # Fit label-encoder to labels, then return encoded labels.
+      #
+      # @overload fit_transform(x) -> Numo::DFloat
+      #
+      # @param x [Array] (shape: [n_samples]) The labels to fit label-encoder.
+      # @return [Numo::Int32] The encoded labels.
+      def fit_transform(x, _y = nil)
+        x = x.to_a if x.is_a?(Numo::NArray)
+        SVMKit::Validation.check_params_type(Array, x: x)
+        fit(x).transform(x)
+      end
+
+      # Encode labels.
+      #
+      # @param x [Array] (shape: [n_samples]) The labels to be encoded.
+      # @return [Numo::Int32] The encoded labels.
+      def transform(x)
+        x = x.to_a if x.is_a?(Numo::NArray)
+        SVMKit::Validation.check_params_type(Array, x: x)
+        Numo::Int32[*(x.map { |v| @classes.index(v) })]
+      end
+
+      # Decode encoded labels.
+      #
+      # @param x [Numo::Int32] (shape: [n_samples]) The labels to be decoded.
+      # @return [Array] The decoded labels.
+      def inverse_transform(x)
+        SVMKit::Validation.check_label_array(x)
+        x.to_a.map { |n| @classes[n] }
+      end
+
+      # Dump marshal data.
+      # @return [Hash] The marshal data about LabelEncoder
+      def marshal_dump
+        { params: @params,
+          classes: @classes }
+      end
+
+      # Load marshal data.
+      # @return [nil]
+      def marshal_load(obj)
+        @params = obj[:params]
+        @classes = obj[:classes]
+        nil
+      end
+    end
+  end
+end
```
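
Beyond the inline `@example`, the typical pattern is to encode labels before training and decode predictions afterwards. A sketch (values arbitrary; the results below follow directly from the code above):

```ruby
encoder = SVMKit::Preprocessing::LabelEncoder.new
encoded = encoder.fit_transform([10, 40, 40, 20]) # => Numo::Int32[0, 2, 2, 1]
# ...train any classifier on `encoded`, then map its predictions back:
encoder.inverse_transform(Numo::Int32[2, 0])      # => [40, 10]
```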
data/lib/svmkit/preprocessing/one_hot_encoder.rb (new file)

```diff
@@ -0,0 +1,98 @@
+# frozen_string_literal: true
+
+require 'svmkit/base/base_estimator'
+require 'svmkit/base/transformer'
+
+module SVMKit
+  module Preprocessing
+    # Encode categorical integer features to one-hot-vectors.
+    #
+    # @example
+    #   encoder = SVMKit::Preprocessing::OneHotEncoder.new
+    #   labels = Numo::Int32[0, 0, 2, 3, 2, 1]
+    #   one_hot_vectors = encoder.fit_transform(labels)
+    #   # > pp one_hot_vectors
+    #   # Numo::DFloat#shape[6, 4]
+    #   # [[1, 0, 0, 0],
+    #   #  [1, 0, 0, 0],
+    #   #  [0, 0, 1, 0],
+    #   #  [0, 0, 0, 1],
+    #   #  [0, 0, 1, 0],
+    #   #  [0, 1, 0, 0]]
+    class OneHotEncoder
+      include Base::BaseEstimator
+      include Base::Transformer
+
+      # Return the maximum values for each feature.
+      # @return [Numo::Int32] (shape: [n_features])
+      attr_reader :n_values
+
+      # Return the indices to feature ranges.
+      # @return [Numo::Int32] (shape: [n_features + 1])
+      attr_reader :feature_indices
+
+      # Create a new encoder for encoding categorical integer features to one-hot-vectors
+      def initialize
+        @params = {}
+        @n_values = nil
+        @feature_indices = nil
+      end
+
+      # Fit one-hot-encoder to samples.
+      #
+      # @overload fit(x) -> OneHotEncoder
+      #
+      # @param x [Numo::Int32] (shape: [n_samples, n_features]) The samples to fit one-hot-encoder.
+      # @return [OneHotEncoder]
+      def fit(x, _y = nil)
+        SVMKit::Validation.check_params_type(Numo::Int32, x: x)
+        @n_values = x.max(0) + 1
+        @feature_indices = Numo::Int32.hstack([[0], @n_values]).cumsum
+        self
+      end
+
+      # Fit one-hot-encoder to samples, then encode samples into one-hot-vectors
+      #
+      # @overload fit_transform(x) -> Numo::DFloat
+      #
+      # @param x [Numo::Int32] (shape: [n_samples, n_features]) The samples to encode into one-hot-vectors.
+      # @return [Numo::DFloat] The one-hot-vectors.
+      def fit_transform(x, _y = nil)
+        SVMKit::Validation.check_params_type(Numo::Int32, x: x)
+        fit(x).transform(x)
+      end
+
+      # Encode samples into one-hot-vectors.
+      #
+      # @param x [Numo::Int32] (shape: [n_samples, n_features]) The samples to encode into one-hot-vectors.
+      # @return [Numo::DFloat] The one-hot-vectors.
+      def transform(x)
+        SVMKit::Validation.check_params_type(Numo::Int32, x: x)
+        n_samples, n_features = x.shape
+        n_features = 1 if n_features.nil?
+        column_indices = (x + @feature_indices[0...-1]).flatten.to_a
+        row_indices = Numo::Int32.new(n_samples).seq.repeat(n_features).to_a
+        codes = Numo::DFloat.zeros(n_samples, @feature_indices[-1])
+        row_indices.zip(column_indices).each { |r, c| codes[r, c] = 1.0 }
+        codes
+      end
+
+      # Dump marshal data.
+      # @return [Hash] The marshal data about OneHotEncoder.
+      def marshal_dump
+        { params: @params,
+          n_values: @n_values,
+          feature_indices: @feature_indices }
+      end
+
+      # Load marshal data.
+      # @return [nil]
+      def marshal_load(obj)
+        @params = obj[:params]
+        @n_values = obj[:n_values]
+        @feature_indices = obj[:feature_indices]
+        nil
+      end
+    end
+  end
+end
```
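
The two new preprocessors also compose: `LabelEncoder` maps arbitrary category values onto 0-based integers, exactly the input `OneHotEncoder` expects. A closing sketch:

```ruby
raw     = [3, 3, 9, 7]
ints    = SVMKit::Preprocessing::LabelEncoder.new.fit_transform(raw) # Numo::Int32[0, 0, 2, 1]
one_hot = SVMKit::Preprocessing::OneHotEncoder.new.fit_transform(ints)
# => 4x3 Numo::DFloat: [[1,0,0], [1,0,0], [0,0,1], [0,1,0]]
```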