svmkit 0.2.8 → 0.2.9
- checksums.yaml +5 -5
- data/.gitignore +4 -0
- data/.rubocop.yml +10 -1
- data/.rubocop_todo.yml +51 -10
- data/Gemfile +1 -1
- data/HISTORY.md +43 -33
- data/lib/svmkit.rb +4 -0
- data/lib/svmkit/base/classifier.rb +1 -0
- data/lib/svmkit/ensemble/random_forest_classifier.rb +5 -2
- data/lib/svmkit/evaluation_measure/log_loss.rb +44 -0
- data/lib/svmkit/kernel_approximation/rbf.rb +1 -1
- data/lib/svmkit/kernel_machine/kernel_svc.rb +40 -2
- data/lib/svmkit/linear_model/logistic_regression.rb +3 -1
- data/lib/svmkit/linear_model/svc.rb +46 -7
- data/lib/svmkit/model_selection/cross_validation.rb +9 -1
- data/lib/svmkit/model_selection/k_fold.rb +1 -1
- data/lib/svmkit/model_selection/stratified_k_fold.rb +3 -2
- data/lib/svmkit/multiclass/one_vs_rest_classifier.rb +1 -0
- data/lib/svmkit/naive_bayes/naive_bayes.rb +5 -0
- data/lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb +2 -0
- data/lib/svmkit/polynomial_model/factorization_machine_classifier.rb +4 -1
- data/lib/svmkit/preprocessing/label_encoder.rb +94 -0
- data/lib/svmkit/preprocessing/one_hot_encoder.rb +98 -0
- data/lib/svmkit/probabilistic_output.rb +112 -0
- data/lib/svmkit/tree/decision_tree_classifier.rb +80 -10
- data/lib/svmkit/validation.rb +12 -0
- data/lib/svmkit/version.rb +1 -1
- data/svmkit.gemspec +4 -6
- metadata +18 -14
data/lib/svmkit/linear_model/svc.rb

```diff
@@ -45,26 +45,30 @@ module SVMKit
       # @param bias_scale [Float] The scale of the bias term.
       # @param max_iter [Integer] The maximum number of iterations.
       # @param batch_size [Integer] The size of the mini batches.
+      # @param probability [Boolean] The flag indicating whether to perform probability estimation.
       # @param normalize [Boolean] The flag indicating whether to normalize the weight vector.
       # @param random_seed [Integer] The seed value using to initialize the random generator.
       def initialize(reg_param: 1.0, fit_bias: false, bias_scale: 1.0,
-                     max_iter: 100, batch_size: 50, normalize: true, random_seed: nil)
+                     max_iter: 100, batch_size: 50, probability: false, normalize: true, random_seed: nil)
         SVMKit::Validation.check_params_float(reg_param: reg_param, bias_scale: bias_scale)
         SVMKit::Validation.check_params_integer(max_iter: max_iter, batch_size: batch_size)
-        SVMKit::Validation.check_params_boolean(fit_bias: fit_bias, normalize: normalize)
+        SVMKit::Validation.check_params_boolean(fit_bias: fit_bias, probability: probability, normalize: normalize)
         SVMKit::Validation.check_params_type_or_nil(Integer, random_seed: random_seed)
-
+        SVMKit::Validation.check_params_positive(reg_param: reg_param, bias_scale: bias_scale, max_iter: max_iter,
+                                                 batch_size: batch_size)
         @params = {}
         @params[:reg_param] = reg_param
         @params[:fit_bias] = fit_bias
         @params[:bias_scale] = bias_scale
         @params[:max_iter] = max_iter
         @params[:batch_size] = batch_size
+        @params[:probability] = probability
         @params[:normalize] = normalize
         @params[:random_seed] = random_seed
         @params[:random_seed] ||= srand
         @weight_vec = nil
         @bias_term = nil
+        @prob_param = nil
         @classes = nil
         @rng = Random.new(@params[:random_seed])
       end
@@ -77,6 +81,7 @@ module SVMKit
       def fit(x, y)
         SVMKit::Validation.check_sample_array(x)
         SVMKit::Validation.check_label_array(y)
+        SVMKit::Validation.check_sample_label_size(x, y)

         @classes = Numo::Int32[*y.to_a.uniq.sort]
         n_classes = @classes.size
@@ -85,16 +90,27 @@ module SVMKit
         if n_classes > 2
           @weight_vec = Numo::DFloat.zeros(n_classes, n_features)
           @bias_term = Numo::DFloat.zeros(n_classes)
+          @prob_param = Numo::DFloat.zeros(n_classes, 2)
           n_classes.times do |n|
             bin_y = Numo::Int32.cast(y.eq(@classes[n])) * 2 - 1
             weight, bias = binary_fit(x, bin_y)
             @weight_vec[n, true] = weight
             @bias_term[n] = bias
+            @prob_param[n, true] = if @params[:probability]
+                                     SVMKit::ProbabilisticOutput.fit_sigmoid(x.dot(weight.transpose) + bias, bin_y)
+                                   else
+                                     Numo::DFloat[1, 0]
+                                   end
           end
         else
           negative_label = y.to_a.uniq.sort.first
           bin_y = Numo::Int32.cast(y.ne(negative_label)) * 2 - 1
           @weight_vec, @bias_term = binary_fit(x, bin_y)
+          @prob_param = if @params[:probability]
+                          SVMKit::ProbabilisticOutput.fit_sigmoid(x.dot(@weight_vec.transpose) + @bias_term, bin_y)
+                        else
+                          Numo::DFloat[1, 0]
+                        end
         end

         self
@@ -124,12 +140,32 @@ module SVMKit
         Numo::Int32.asarray(Array.new(n_samples) { |n| @classes[decision_values[n, true].max_index] })
       end

+      # Predict probability for samples.
+      #
+      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
+      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
+      def predict_proba(x)
+        SVMKit::Validation.check_sample_array(x)
+
+        if @classes.size > 2
+          probs = 1.0 / (Numo::NMath.exp(@prob_param[true, 0] * decision_function(x) + @prob_param[true, 1]) + 1.0)
+          return (probs.transpose / probs.sum(axis: 1)).transpose
+        end
+
+        n_samples, = x.shape
+        probs = Numo::DFloat.zeros(n_samples, 2)
+        probs[true, 1] = 1.0 / (Numo::NMath.exp(@prob_param[0] * decision_function(x) + @prob_param[1]) + 1.0)
+        probs[true, 0] = 1.0 - probs[true, 1]
+        probs
+      end
+
       # Dump marshal data.
       # @return [Hash] The marshal data about SVC.
       def marshal_dump
         { params: @params,
           weight_vec: @weight_vec,
           bias_term: @bias_term,
+          prob_param: @prob_param,
           classes: @classes,
           rng: @rng }
       end
@@ -140,6 +176,7 @@ module SVMKit
         @params = obj[:params]
         @weight_vec = obj[:weight_vec]
         @bias_term = obj[:bias_term]
+        @prob_param = obj[:prob_param]
         @classes = obj[:classes]
         @rng = obj[:rng]
         nil
@@ -159,11 +196,13 @@ module SVMKit
           # random sampling
           subset_ids = rand_ids.shift(@params[:batch_size])
           rand_ids.concat(subset_ids)
-
-
-
+          sub_samples = samples[subset_ids, true]
+          sub_bin_y = bin_y[subset_ids]
+          target_ids = (sub_samples.dot(weight_vec.transpose) * sub_bin_y).lt(1.0).where
+          n_targets = target_ids.size
+          next if n_targets.zero?
           # update the weight vector.
-          mean_vec =
+          mean_vec = sub_samples[target_ids, true].transpose.dot(sub_bin_y[target_ids]) / n_targets
           weight_vec -= learning_rate(t) * (@params[:reg_param] * weight_vec - mean_vec)
           # scale the weight vector.
           normalize_weight_vec(weight_vec) if @params[:normalize]
```
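Taken together, the svc.rb changes add opt-in Platt scaling: with `probability: true`, `fit` also trains sigmoid parameters on the decision values via `SVMKit::ProbabilisticOutput.fit_sigmoid`, and the new `predict_proba` maps decision values through that sigmoid. A minimal usage sketch (the toy data below is invented for illustration, not taken from the gem's docs):

```ruby
require 'svmkit'

# Toy binary dataset, invented for illustration.
samples = Numo::DFloat[[0.0, 0.1], [0.2, 0.0], [1.0, 0.9], [0.8, 1.1]]
labels  = Numo::Int32[0, 0, 1, 1]

# probability: true makes fit() also estimate the sigmoid parameters.
estimator = SVMKit::LinearModel::SVC.new(probability: true, max_iter: 100, random_seed: 1)
estimator.fit(samples, labels)

# predict_proba returns an [n_samples, n_classes] array of class probabilities.
probs = estimator.predict_proba(samples)
```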
data/lib/svmkit/model_selection/cross_validation.rb

```diff
@@ -62,6 +62,7 @@ module SVMKit
       def perform(x, y)
         SVMKit::Validation.check_sample_array(x)
         SVMKit::Validation.check_label_array(y)
+        SVMKit::Validation.check_sample_label_size(x, y)
         # Initialize the report of cross validation.
         report = { test_score: [], train_score: nil, fit_time: [] }
         report[:train_score] = [] if @return_train_score
@@ -81,9 +82,12 @@ module SVMKit
         if @evaluator.nil?
           report[:test_score].push(@estimator.score(test_x, test_y))
           report[:train_score].push(@estimator.score(train_x, train_y)) if @return_train_score
+        elsif log_loss?
+          report[:test_score].push(@evaluator.score(test_y, @estimator.predict_proba(test_x)))
+          report[:train_score].push(@evaluator.score(train_y, @estimator.predict_proba(train_x))) if @return_train_score
         else
           report[:test_score].push(@evaluator.score(test_y, @estimator.predict(test_x)))
-          report[:train_score].push(@
+          report[:train_score].push(@evaluator.score(train_y, @estimator.predict(train_x))) if @return_train_score
         end
       end
       report
@@ -96,6 +100,10 @@ module SVMKit
         class_name = @estimator.params[:estimator].class.to_s if class_name.include?('Multiclass')
         class_name.include?('KernelMachine')
       end
+
+      def log_loss?
+        @evaluator.is_a?(SVMKit::EvaluationMeasure::LogLoss)
+      end
     end
   end
 end
```
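With the new `log_loss?` branch, `perform` scores with `predict_proba` instead of `predict` whenever the evaluator is the new `SVMKit::EvaluationMeasure::LogLoss`. A sketch of the intended wiring, assuming the 0.2.x keyword-argument constructors (`estimator:`, `splitter:`, `evaluator:`) and invented data:

```ruby
require 'svmkit'

# Invented data: 20 samples, two balanced classes.
samples = Numo::DFloat.new(20, 2).rand
labels  = Numo::Int32[*([0] * 10 + [1] * 10)]

svc = SVMKit::LinearModel::SVC.new(probability: true, random_seed: 1)
kf  = SVMKit::ModelSelection::StratifiedKFold.new(n_splits: 2, shuffle: true, random_seed: 1)
cv  = SVMKit::ModelSelection::CrossValidation.new(estimator: svc, splitter: kf,
                                                  evaluator: SVMKit::EvaluationMeasure::LogLoss.new)

report = cv.perform(samples, labels)
# report[:test_score] then holds the per-fold log loss instead of accuracy.
```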
data/lib/svmkit/model_selection/k_fold.rb

```diff
@@ -35,7 +35,7 @@ module SVMKit
       SVMKit::Validation.check_params_integer(n_splits: n_splits)
       SVMKit::Validation.check_params_boolean(shuffle: shuffle)
       SVMKit::Validation.check_params_type_or_nil(Integer, random_seed: random_seed)
-
+      SVMKit::Validation.check_params_positive(n_splits: n_splits)
       @n_splits = n_splits
       @shuffle = shuffle
       @random_seed = random_seed
```
data/lib/svmkit/model_selection/stratified_k_fold.rb

```diff
@@ -35,7 +35,7 @@ module SVMKit
       SVMKit::Validation.check_params_integer(n_splits: n_splits)
       SVMKit::Validation.check_params_boolean(shuffle: shuffle)
       SVMKit::Validation.check_params_type_or_nil(Integer, random_seed: random_seed)
-
+      SVMKit::Validation.check_params_positive(n_splits: n_splits)
       @n_splits = n_splits
       @shuffle = shuffle
       @random_seed = random_seed
@@ -51,9 +51,10 @@ module SVMKit
     # @param y [Numo::Int32] (shape: [n_samples])
     #   The labels to be used to generate data indices for stratified K-fold cross validation.
     # @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
-    def split(x, y)
+    def split(x, y)
       SVMKit::Validation.check_sample_array(x)
       SVMKit::Validation.check_label_array(y)
+      SVMKit::Validation.check_sample_label_size(x, y)
       # Check the number of samples in each class.
       unless valid_n_splits?(y)
         raise ArgumentError,
```
data/lib/svmkit/multiclass/one_vs_rest_classifier.rb

```diff
@@ -48,6 +48,7 @@ module SVMKit
       def fit(x, y)
         SVMKit::Validation.check_sample_array(x)
         SVMKit::Validation.check_label_array(y)
+        SVMKit::Validation.check_sample_label_size(x, y)
         y_arr = y.to_a
         @classes = Numo::Int32.asarray(y_arr.uniq.sort)
         @estimators = @classes.to_a.map do |label|
```
data/lib/svmkit/naive_bayes/naive_bayes.rb

```diff
@@ -80,6 +80,7 @@ module SVMKit
       def fit(x, y)
         SVMKit::Validation.check_sample_array(x)
         SVMKit::Validation.check_label_array(y)
+        SVMKit::Validation.check_sample_label_size(x, y)
         n_samples, = x.shape
         @classes = Numo::Int32[*y.to_a.uniq.sort]
         @class_priors = Numo::DFloat[*@classes.to_a.map { |l| y.eq(l).count / n_samples.to_f }]
@@ -154,6 +155,7 @@ module SVMKit
       # @param smoothing_param [Float] The Laplace smoothing parameter.
       def initialize(smoothing_param: 1.0)
         SVMKit::Validation.check_params_float(smoothing_param: smoothing_param)
+        SVMKit::Validation.check_params_positive(smoothing_param: smoothing_param)
         @params = {}
         @params[:smoothing_param] = smoothing_param
       end
@@ -167,6 +169,7 @@ module SVMKit
       def fit(x, y)
         SVMKit::Validation.check_sample_array(x)
         SVMKit::Validation.check_label_array(y)
+        SVMKit::Validation.check_sample_label_size(x, y)
         n_samples, = x.shape
         @classes = Numo::Int32[*y.to_a.uniq.sort]
         @class_priors = Numo::DFloat[*@classes.to_a.map { |l| y.eq(l).count / n_samples.to_f }]
@@ -241,6 +244,7 @@ module SVMKit
       # @param bin_threshold [Float] The threshold for binarizing of features.
       def initialize(smoothing_param: 1.0, bin_threshold: 0.0)
         SVMKit::Validation.check_params_float(smoothing_param: smoothing_param, bin_threshold: bin_threshold)
+        SVMKit::Validation.check_params_positive(smoothing_param: smoothing_param)
         @params = {}
         @params[:smoothing_param] = smoothing_param
         @params[:bin_threshold] = bin_threshold
@@ -255,6 +259,7 @@ module SVMKit
       def fit(x, y)
         SVMKit::Validation.check_sample_array(x)
         SVMKit::Validation.check_label_array(y)
+        SVMKit::Validation.check_sample_label_size(x, y)
         n_samples, = x.shape
         bin_x = Numo::DFloat[*x.gt(@params[:bin_threshold])]
         @classes = Numo::Int32[*y.to_a.uniq.sort]
```
data/lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb

```diff
@@ -36,6 +36,7 @@ module SVMKit
       # @param n_neighbors [Integer] The number of neighbors.
       def initialize(n_neighbors: 5)
         SVMKit::Validation.check_params_integer(n_neighbors: n_neighbors)
+        SVMKit::Validation.check_params_positive(n_neighbors: n_neighbors)
         @params = {}
         @params[:n_neighbors] = n_neighbors
         @prototypes = nil
@@ -51,6 +52,7 @@ module SVMKit
       def fit(x, y)
         SVMKit::Validation.check_sample_array(x)
         SVMKit::Validation.check_label_array(y)
+        SVMKit::Validation.check_sample_label_size(x, y)
         @prototypes = Numo::DFloat.asarray(x.to_a)
         @labels = Numo::Int32.asarray(y.to_a)
         @classes = Numo::Int32.asarray(y.to_a.uniq.sort)
```
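The `check_sample_label_size` call that now appears in every `fit` above guards against mismatched sample and label arrays. Presumably, in keeping with the other validators in svmkit/validation.rb, it raises an `ArgumentError` when the sizes differ; a sketch of that failure mode:

```ruby
require 'svmkit'

samples = Numo::DFloat.new(10, 2).rand
labels  = Numo::Int32.new(9).seq # deliberately one label short

estimator = SVMKit::NearestNeighbors::KNeighborsClassifier.new(n_neighbors: 3)
begin
  estimator.fit(samples, labels)
rescue ArgumentError => e
  # The size mismatch is now reported up front, before any training work.
  puts e.message
end
```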
data/lib/svmkit/polynomial_model/factorization_machine_classifier.rb

```diff
@@ -63,7 +63,9 @@ module SVMKit
       SVMKit::Validation.check_params_integer(n_factors: n_factors, max_iter: max_iter, batch_size: batch_size)
       SVMKit::Validation.check_params_string(loss: loss)
       SVMKit::Validation.check_params_type_or_nil(Integer, random_seed: random_seed)
-
+      SVMKit::Validation.check_params_positive(n_factors: n_factors, reg_param_bias: reg_param_bias,
+                                               reg_param_weight: reg_param_weight, reg_param_factor: reg_param_factor,
+                                               max_iter: max_iter, batch_size: batch_size)
       @params = {}
       @params[:n_factors] = n_factors
       @params[:loss] = loss
@@ -90,6 +92,7 @@ module SVMKit
       def fit(x, y)
         SVMKit::Validation.check_sample_array(x)
         SVMKit::Validation.check_label_array(y)
+        SVMKit::Validation.check_sample_label_size(x, y)

         @classes = Numo::Int32[*y.to_a.uniq.sort]
         n_classes = @classes.size
```
data/lib/svmkit/preprocessing/label_encoder.rb (new file, +94 lines)

```ruby
# frozen_string_literal: true

require 'svmkit/base/base_estimator'
require 'svmkit/base/transformer'

module SVMKit
  module Preprocessing
    # Encode labels to values between 0 and n_classes - 1.
    #
    # @example
    #   encoder = SVMKit::Preprocessing::LabelEncoder.new
    #   labels = Numo::Int32[1, 8, 8, 15, 0]
    #   encoded_labels = encoder.fit_transform(labels)
    #   # > pp encoded_labels
    #   # Numo::Int32#shape=[5]
    #   # [1, 2, 2, 3, 0]
    #   decoded_labels = encoder.inverse_transform(encoded_labels)
    #   # > pp decoded_labels
    #   # [1, 8, 8, 15, 0]
    class LabelEncoder
      include Base::BaseEstimator
      include Base::Transformer

      # Return the class labels.
      # @return [Array] (size: [n_classes])
      attr_reader :classes

      # Create a new encoder for encoding labels to values between 0 and n_classes - 1.
      def initialize
        @params = {}
        @classes = nil
      end

      # Fit label-encoder to labels.
      #
      # @overload fit(x) -> LabelEncoder
      #
      # @param x [Array] (shape: [n_samples]) The labels to fit label-encoder.
      # @return [LabelEncoder]
      def fit(x, _y = nil)
        x = x.to_a if x.is_a?(Numo::NArray)
        SVMKit::Validation.check_params_type(Array, x: x)
        @classes = x.sort.uniq
        self
      end

      # Fit label-encoder to labels, then return encoded labels.
      #
      # @overload fit_transform(x) -> Numo::DFloat
      #
      # @param x [Array] (shape: [n_samples]) The labels to fit label-encoder.
      # @return [Numo::Int32] The encoded labels.
      def fit_transform(x, _y = nil)
        x = x.to_a if x.is_a?(Numo::NArray)
        SVMKit::Validation.check_params_type(Array, x: x)
        fit(x).transform(x)
      end

      # Encode labels.
      #
      # @param x [Array] (shape: [n_samples]) The labels to be encoded.
      # @return [Numo::Int32] The encoded labels.
      def transform(x)
        x = x.to_a if x.is_a?(Numo::NArray)
        SVMKit::Validation.check_params_type(Array, x: x)
        Numo::Int32[*(x.map { |v| @classes.index(v) })]
      end

      # Decode encoded labels.
      #
      # @param x [Numo::Int32] (shape: [n_samples]) The labels to be decoded.
      # @return [Array] The decoded labels.
      def inverse_transform(x)
        SVMKit::Validation.check_label_array(x)
        x.to_a.map { |n| @classes[n] }
      end

      # Dump marshal data.
      # @return [Hash] The marshal data about LabelEncoder.
      def marshal_dump
        { params: @params,
          classes: @classes }
      end

      # Load marshal data.
      # @return [nil]
      def marshal_load(obj)
        @params = obj[:params]
        @classes = obj[:classes]
        nil
      end
    end
  end
end
```
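Like the other SVMKit estimators, `LabelEncoder` implements the `marshal_dump`/`marshal_load` hooks, so a fitted encoder round-trips through Ruby's standard `Marshal`:

```ruby
require 'svmkit'

encoder = SVMKit::Preprocessing::LabelEncoder.new
encoder.fit([3, 3, 8, 8, 1]) # fitted classes: [1, 3, 8]

# Serialize and restore; the fitted classes survive the round trip.
restored = Marshal.load(Marshal.dump(encoder))
restored.transform([8, 1, 3]) # => Numo::Int32[2, 0, 1]
```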
data/lib/svmkit/preprocessing/one_hot_encoder.rb (new file, +98 lines)

```ruby
# frozen_string_literal: true

require 'svmkit/base/base_estimator'
require 'svmkit/base/transformer'

module SVMKit
  module Preprocessing
    # Encode categorical integer features to one-hot-vectors.
    #
    # @example
    #   encoder = SVMKit::Preprocessing::OneHotEncoder.new
    #   labels = Numo::Int32[0, 0, 2, 3, 2, 1]
    #   one_hot_vectors = encoder.fit_transform(labels)
    #   # > pp one_hot_vectors
    #   # Numo::DFloat#shape=[6, 4]
    #   # [[1, 0, 0, 0],
    #   #  [1, 0, 0, 0],
    #   #  [0, 0, 1, 0],
    #   #  [0, 0, 0, 1],
    #   #  [0, 0, 1, 0],
    #   #  [0, 1, 0, 0]]
    class OneHotEncoder
      include Base::BaseEstimator
      include Base::Transformer

      # Return the maximum values for each feature.
      # @return [Numo::Int32] (shape: [n_features])
      attr_reader :n_values

      # Return the indices to feature ranges.
      # @return [Numo::Int32] (shape: [n_features + 1])
      attr_reader :feature_indices

      # Create a new encoder for encoding categorical integer features to one-hot-vectors.
      def initialize
        @params = {}
        @n_values = nil
        @feature_indices = nil
      end

      # Fit one-hot-encoder to samples.
      #
      # @overload fit(x) -> OneHotEncoder
      #
      # @param x [Numo::Int32] (shape: [n_samples, n_features]) The samples to fit one-hot-encoder.
      # @return [OneHotEncoder]
      def fit(x, _y = nil)
        SVMKit::Validation.check_params_type(Numo::Int32, x: x)
        @n_values = x.max(0) + 1
        @feature_indices = Numo::Int32.hstack([[0], @n_values]).cumsum
        self
      end

      # Fit one-hot-encoder to samples, then encode samples into one-hot-vectors.
      #
      # @overload fit_transform(x) -> Numo::DFloat
      #
      # @param x [Numo::Int32] (shape: [n_samples, n_features]) The samples to encode into one-hot-vectors.
      # @return [Numo::DFloat] The one-hot-vectors.
      def fit_transform(x, _y = nil)
        SVMKit::Validation.check_params_type(Numo::Int32, x: x)
        fit(x).transform(x)
      end

      # Encode samples into one-hot-vectors.
      #
      # @param x [Numo::Int32] (shape: [n_samples, n_features]) The samples to encode into one-hot-vectors.
      # @return [Numo::DFloat] The one-hot-vectors.
      def transform(x)
        SVMKit::Validation.check_params_type(Numo::Int32, x: x)
        n_samples, n_features = x.shape
        n_features = 1 if n_features.nil?
        column_indices = (x + @feature_indices[0...-1]).flatten.to_a
        row_indices = Numo::Int32.new(n_samples).seq.repeat(n_features).to_a
        codes = Numo::DFloat.zeros(n_samples, @feature_indices[-1])
        row_indices.zip(column_indices).each { |r, c| codes[r, c] = 1.0 }
        codes
      end

      # Dump marshal data.
      # @return [Hash] The marshal data about OneHotEncoder.
      def marshal_dump
        { params: @params,
          n_values: @n_values,
          feature_indices: @feature_indices }
      end

      # Load marshal data.
      # @return [nil]
      def marshal_load(obj)
        @params = obj[:params]
        @n_values = obj[:n_values]
        @feature_indices = obj[:feature_indices]
        nil
      end
    end
  end
end
```
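The two new preprocessors compose naturally: `LabelEncoder` maps arbitrary category values onto 0...n_classes - 1, and `OneHotEncoder` expands those integer codes into indicator vectors. A sketch with invented category values:

```ruby
require 'svmkit'

raw = [10, 40, 40, 20, 10] # invented categorical values

# Step 1: map raw categories onto 0..n_classes-1.
label_encoder = SVMKit::Preprocessing::LabelEncoder.new
encoded = label_encoder.fit_transform(raw) # => Numo::Int32[0, 2, 2, 1, 0]

# Step 2: expand the integer codes into one-hot vectors.
one_hot_encoder = SVMKit::Preprocessing::OneHotEncoder.new
vectors = one_hot_encoder.fit_transform(encoded) # => Numo::DFloat, shape [5, 3]
```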