rumale 0.18.7 → 0.20.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/.rubocop.yml +66 -1
- data/CHANGELOG.md +46 -0
- data/Gemfile +2 -0
- data/README.md +5 -36
- data/lib/rumale.rb +5 -10
- data/lib/rumale/clustering/hdbscan.rb +1 -1
- data/lib/rumale/clustering/k_means.rb +1 -1
- data/lib/rumale/clustering/k_medoids.rb +1 -1
- data/lib/rumale/clustering/mini_batch_k_means.rb +139 -0
- data/lib/rumale/dataset.rb +3 -3
- data/lib/rumale/decomposition/pca.rb +23 -5
- data/lib/rumale/feature_extraction/feature_hasher.rb +14 -1
- data/lib/rumale/feature_extraction/tfidf_transformer.rb +113 -0
- data/lib/rumale/kernel_approximation/nystroem.rb +1 -1
- data/lib/rumale/kernel_machine/kernel_svc.rb +1 -1
- data/lib/rumale/linear_model/base_sgd.rb +1 -1
- data/lib/rumale/metric_learning/neighbourhood_component_analysis.rb +13 -1
- data/lib/rumale/model_selection/cross_validation.rb +3 -2
- data/lib/rumale/model_selection/k_fold.rb +1 -1
- data/lib/rumale/model_selection/shuffle_split.rb +1 -1
- data/lib/rumale/multiclass/one_vs_rest_classifier.rb +2 -2
- data/lib/rumale/nearest_neighbors/vp_tree.rb +1 -1
- data/lib/rumale/neural_network/adam.rb +1 -1
- data/lib/rumale/neural_network/base_mlp.rb +1 -1
- data/lib/rumale/preprocessing/binarizer.rb +60 -0
- data/lib/rumale/preprocessing/l1_normalizer.rb +62 -0
- data/lib/rumale/preprocessing/l2_normalizer.rb +2 -1
- data/lib/rumale/preprocessing/max_normalizer.rb +62 -0
- data/lib/rumale/version.rb +1 -1
- data/rumale.gemspec +1 -3
- metadata +11 -44
- data/lib/rumale/linear_model/base_linear_model.rb +0 -101
- data/lib/rumale/optimizer/ada_grad.rb +0 -39
- data/lib/rumale/optimizer/adam.rb +0 -53
- data/lib/rumale/optimizer/nadam.rb +0 -62
- data/lib/rumale/optimizer/rmsprop.rb +0 -47
- data/lib/rumale/optimizer/sgd.rb +0 -43
- data/lib/rumale/optimizer/yellow_fin.rb +0 -101
- data/lib/rumale/polynomial_model/base_factorization_machine.rb +0 -121
- data/lib/rumale/polynomial_model/factorization_machine_classifier.rb +0 -215
- data/lib/rumale/polynomial_model/factorization_machine_regressor.rb +0 -129
data/lib/rumale/dataset.rb
CHANGED
@@ -81,7 +81,7 @@ module Rumale
|
|
81
81
|
y = Numo::Int32.hstack([Numo::Int32.zeros(n_samples_out), Numo::Int32.ones(n_samples_in)])
|
82
82
|
# shuffle data indices.
|
83
83
|
if shuffle
|
84
|
-
rand_ids =
|
84
|
+
rand_ids = Array(0...n_samples).shuffle(random: rng.dup)
|
85
85
|
x = x[rand_ids, true].dup
|
86
86
|
y = y[rand_ids].dup
|
87
87
|
end
|
@@ -118,7 +118,7 @@ module Rumale
|
|
118
118
|
y = Numo::Int32.hstack([Numo::Int32.zeros(n_samples_out), Numo::Int32.ones(n_samples_in)])
|
119
119
|
# shuffle data indices.
|
120
120
|
if shuffle
|
121
|
-
rand_ids =
|
121
|
+
rand_ids = Array(0...n_samples).shuffle(random: rng.dup)
|
122
122
|
x = x[rand_ids, true].dup
|
123
123
|
y = y[rand_ids].dup
|
124
124
|
end
|
@@ -173,7 +173,7 @@ module Rumale
|
|
173
173
|
end
|
174
174
|
# shuffle data.
|
175
175
|
if shuffle
|
176
|
-
rand_ids =
|
176
|
+
rand_ids = Array(0...n_samples).shuffle(random: rng.dup)
|
177
177
|
x = x[rand_ids, true].dup
|
178
178
|
y = y[rand_ids].dup
|
179
179
|
end
|
@@ -9,7 +9,7 @@ module Rumale
|
|
9
9
|
# PCA is a class that implements Principal Component Analysis.
|
10
10
|
#
|
11
11
|
# @example
|
12
|
-
# decomposer = Rumale::Decomposition::PCA.new(n_components: 2)
|
12
|
+
# decomposer = Rumale::Decomposition::PCA.new(n_components: 2, solver: 'fpt')
|
13
13
|
# representaion = decomposer.fit_transform(samples)
|
14
14
|
#
|
15
15
|
# # If Numo::Linalg is installed, you can specify 'evd' for the solver option.
|
@@ -17,6 +17,11 @@ module Rumale
|
|
17
17
|
# decomposer = Rumale::Decomposition::PCA.new(n_components: 2, solver: 'evd')
|
18
18
|
# representaion = decomposer.fit_transform(samples)
|
19
19
|
#
|
20
|
+
# # If Numo::Linalg is loaded and the solver option is not given,
|
21
|
+
# # the solver option is chosen as 'evd' automatically.
|
22
|
+
# decomposer = Rumale::Decomposition::PCA.new(n_components: 2)
|
23
|
+
# representaion = decomposer.fit_transform(samples)
|
24
|
+
#
|
20
25
|
# *Reference*
|
21
26
|
# - Sharma, A., and Paliwal, K K., "Fast principal component analysis using fixed-point algorithm," Pattern Recognition Letters, 28, pp. 1151--1155, 2007.
|
22
27
|
class PCA
|
@@ -38,18 +43,24 @@ module Rumale
|
|
38
43
|
# Create a new transformer with PCA.
|
39
44
|
#
|
40
45
|
# @param n_components [Integer] The number of principal components.
|
41
|
-
# @param solver [String] The algorithm for the optimization ('fpt' or 'evd').
|
42
|
-
# '
|
46
|
+
# @param solver [String] The algorithm for the optimization ('auto', 'fpt' or 'evd').
|
47
|
+
# 'auto' chooses the 'evd' solver if Numo::Linalg is loaded. Otherwise, it chooses the 'fpt' solver.
|
48
|
+
# 'fpt' uses the fixed-point algorithm.
|
49
|
+
# 'evd' performs eigenvalue decomposition of the covariance matrix of samples.
|
43
50
|
# @param max_iter [Integer] The maximum number of iterations. If solver = 'evd', this parameter is ignored.
|
44
51
|
# @param tol [Float] The tolerance of termination criterion. If solver = 'evd', this parameter is ignored.
|
45
52
|
# @param random_seed [Integer] The seed value used to initialize the random generator.
|
46
|
-
def initialize(n_components: 2, solver: '
|
53
|
+
def initialize(n_components: 2, solver: 'auto', max_iter: 100, tol: 1.0e-4, random_seed: nil)
|
47
54
|
check_params_numeric(n_components: n_components, max_iter: max_iter, tol: tol)
|
48
55
|
check_params_string(solver: solver)
|
49
56
|
check_params_numeric_or_nil(random_seed: random_seed)
|
50
57
|
check_params_positive(n_components: n_components, max_iter: max_iter, tol: tol)
|
51
58
|
@params = {}
|
52
|
-
@params[:solver] = solver
|
59
|
+
@params[:solver] = if solver == 'auto'
|
60
|
+
load_linalg? ? 'evd' : 'fpt'
|
61
|
+
else
|
62
|
+
solver != 'evd' ? 'fpt' : 'evd'
|
63
|
+
end
|
53
64
|
@params[:n_components] = n_components
|
54
65
|
@params[:max_iter] = max_iter
|
55
66
|
@params[:tol] = tol
|
@@ -128,6 +139,13 @@ module Rumale
|
|
128
139
|
|
129
140
|
private
|
130
141
|
|
142
|
+
def load_linalg?
|
143
|
+
return false if defined?(Numo::Linalg).nil?
|
144
|
+
return false if Numo::Linalg::VERSION < '0.1.4'
|
145
|
+
|
146
|
+
true
|
147
|
+
end
|
148
|
+
|
131
149
|
def orthogonalize(pcvec)
|
132
150
|
unless @components.nil?
|
133
151
|
delta = @components.dot(pcvec) * @components.transpose
|
@@ -1,6 +1,5 @@
|
|
1
1
|
# frozen_string_literal: true
|
2
2
|
|
3
|
-
require 'mmh3'
|
4
3
|
require 'rumale/base/base_estimator'
|
5
4
|
require 'rumale/base/transformer'
|
6
5
|
|
@@ -11,11 +10,15 @@ module Rumale
|
|
11
10
|
# This encoder employs signed 32-bit Murmurhash3 as the hash function.
|
12
11
|
#
|
13
12
|
# @example
|
13
|
+
# require 'mmh3'
|
14
|
+
# require 'rumale'
|
15
|
+
#
|
14
16
|
# encoder = Rumale::FeatureExtraction::FeatureHasher.new(n_features: 10)
|
15
17
|
# x = encoder.transform([
|
16
18
|
# { dog: 1, cat: 2, elephant: 4 },
|
17
19
|
# { dog: 2, run: 5 }
|
18
20
|
# ])
|
21
|
+
#
|
19
22
|
# # > pp x
|
20
23
|
# # Numo::DFloat#shape=[2,10]
|
21
24
|
# # [[0, 0, -4, -1, 0, 0, 0, 0, 0, 2],
|
@@ -62,6 +65,8 @@ module Rumale
|
|
62
65
|
# @param x [Array<Hash>] (shape: [n_samples]) The array of hash consisting of feature names and values.
|
63
66
|
# @return [Numo::DFloat] (shape: [n_samples, n_features]) The encoded sample array.
|
64
67
|
def transform(x)
|
68
|
+
raise 'FeatureHasher#transform requires Mmh3 but that is not loaded.' unless enable_mmh3?
|
69
|
+
|
65
70
|
x = [x] unless x.is_a?(Array)
|
66
71
|
n_samples = x.size
|
67
72
|
|
@@ -85,6 +90,14 @@ module Rumale
|
|
85
90
|
|
86
91
|
private
|
87
92
|
|
93
|
+
def enable_mmh3?
|
94
|
+
if defined?(Mmh3).nil?
|
95
|
+
warn('FeatureHasher#transform requires Mmh3 but that is not loaded. You should intall and load mmh3 gem in advance.')
|
96
|
+
return false
|
97
|
+
end
|
98
|
+
true
|
99
|
+
end
|
100
|
+
|
88
101
|
def n_features
|
89
102
|
@params[:n_features]
|
90
103
|
end
|
@@ -0,0 +1,113 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/base_estimator'
|
4
|
+
require 'rumale/base/transformer'
|
5
|
+
require 'rumale/preprocessing/l1_normalizer'
|
6
|
+
require 'rumale/preprocessing/l2_normalizer'
|
7
|
+
|
8
|
+
module Rumale
|
9
|
+
module FeatureExtraction
|
10
|
+
# Transform sample matrix with term frequency (tf) to a normalized tf-idf (inverse document frequency) representation.
|
11
|
+
#
|
12
|
+
# @example
|
13
|
+
# encoder = Rumale::FeatureExtraction::HashVectorizer.new
|
14
|
+
# x = encoder.fit_transform([
|
15
|
+
# { foo: 1, bar: 2 },
|
16
|
+
# { foo: 3, baz: 1 }
|
17
|
+
# ])
|
18
|
+
#
|
19
|
+
# # > pp x
|
20
|
+
# # Numo::DFloat#shape=[2,3]
|
21
|
+
# # [[2, 0, 1],
|
22
|
+
# # [0, 1, 3]]
|
23
|
+
#
|
24
|
+
# transformer = Rumale::FeatureExtraction::TfidfTransformer.new
|
25
|
+
# x_tfidf = transformer.fit_transform(x)
|
26
|
+
#
|
27
|
+
# # > pp x_tfidf
|
28
|
+
# # Numo::DFloat#shape=[2,3]
|
29
|
+
# # [[0.959056, 0, 0.283217],
|
30
|
+
# # [0, 0.491506, 0.870874]]
|
31
|
+
#
|
32
|
+
# *Reference*
|
33
|
+
# - Manning, C D., Raghavan, P., and Schutze, H., "Introduction to Information Retrieval," Cambridge University Press., 2008.
|
34
|
+
class TfidfTransformer
|
35
|
+
include Base::BaseEstimator
|
36
|
+
include Base::Transformer
|
37
|
+
|
38
|
+
# Return the vector consisting of inverse document frequencies.
|
39
|
+
# @return [Numo::DFloat] (shape: [n_features])
|
40
|
+
attr_reader :idf
|
41
|
+
|
42
|
+
# Create a new transformer for converting tf vectors to tf-idf vectors.
|
43
|
+
#
|
44
|
+
# @param norm [String] The normalization method to be used ('l1', 'l2' or 'none').
|
45
|
+
# @param use_idf [Boolean] The flag indicating whether to use inverse document frequency weighting.
|
46
|
+
# @param smooth_idf [Boolean] The flag indicating whether to apply idf smoothing by log((n_samples + 1) / (df + 1)) + 1.
|
47
|
+
# @param sublinear_tf [Boolean] The flag indicating whether to perform sublinear tf scaling by 1 + log(tf).
|
48
|
+
def initialize(norm: 'l2', use_idf: true, smooth_idf: false, sublinear_tf: false)
|
49
|
+
check_params_string(norm: norm)
|
50
|
+
check_params_boolean(use_idf: use_idf, smooth_idf: smooth_idf, sublinear_tf: sublinear_tf)
|
51
|
+
@params = {}
|
52
|
+
@params[:norm] = norm
|
53
|
+
@params[:use_idf] = use_idf
|
54
|
+
@params[:smooth_idf] = smooth_idf
|
55
|
+
@params[:sublinear_tf] = sublinear_tf
|
56
|
+
@idf = nil
|
57
|
+
end
|
58
|
+
|
59
|
+
# Calculate the inverse document frequency for weighting.
|
60
|
+
#
|
61
|
+
# @overload fit(x) -> TfidfTransformer
|
62
|
+
#
|
63
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to calculate the idf values.
|
64
|
+
# @return [TfidfTransformer]
|
65
|
+
def fit(x, _y = nil)
|
66
|
+
return self unless @params[:use_idf]
|
67
|
+
|
68
|
+
x = check_convert_sample_array(x)
|
69
|
+
|
70
|
+
n_samples = x.shape[0]
|
71
|
+
df = x.class.cast(x.gt(0.0).count(0))
|
72
|
+
|
73
|
+
if @params[:smooth_idf]
|
74
|
+
df += 1
|
75
|
+
n_samples += 1
|
76
|
+
end
|
77
|
+
|
78
|
+
@idf = Numo::NMath.log(n_samples / df) + 1
|
79
|
+
|
80
|
+
self
|
81
|
+
end
|
82
|
+
|
83
|
+
# Calculate the idf values, and then transform samples to the tf-idf representation.
|
84
|
+
#
|
85
|
+
# @overload fit_transform(x) -> Numo::DFloat
|
86
|
+
#
|
87
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to calculate idf and be transformed to tf-idf representation.
|
88
|
+
# @return [Numo::DFloat] The transformed samples.
|
89
|
+
def fit_transform(x, _y = nil)
|
90
|
+
fit(x).transform(x)
|
91
|
+
end
|
92
|
+
|
93
|
+
# Perform transforming the given samples to the tf-idf representation.
|
94
|
+
#
|
95
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be transformed.
|
96
|
+
# @return [Numo::DFloat] The transformed samples.
|
97
|
+
def transform(x)
|
98
|
+
x = check_convert_sample_array(x)
|
99
|
+
z = x.dup
|
100
|
+
|
101
|
+
z[z.ne(0)] = Numo::NMath.log(z[z.ne(0)]) + 1 if @params[:sublinear_tf]
|
102
|
+
z *= @idf if @params[:use_idf]
|
103
|
+
case @params[:norm]
|
104
|
+
when 'l2'
|
105
|
+
z = Rumale::Preprocessing::L2Normalizer.new.fit_transform(z)
|
106
|
+
when 'l1'
|
107
|
+
z = Rumale::Preprocessing::L1Normalizer.new.fit_transform(z)
|
108
|
+
end
|
109
|
+
z
|
110
|
+
end
|
111
|
+
end
|
112
|
+
end
|
113
|
+
end
|
@@ -69,7 +69,7 @@ module Rumale
|
|
69
69
|
n_components = [1, [@params[:n_components], n_samples].min].max
|
70
70
|
|
71
71
|
# random sampling.
|
72
|
-
@component_indices = Numo::Int32.cast(
|
72
|
+
@component_indices = Numo::Int32.cast(Array(0...n_samples).shuffle(random: sub_rng)[0...n_components])
|
73
73
|
@components = x[@component_indices, true]
|
74
74
|
|
75
75
|
# calculate normalizing factor.
|
@@ -172,7 +172,7 @@ module Rumale
|
|
172
172
|
# Start optimization.
|
173
173
|
@params[:max_iter].times do |t|
|
174
174
|
# random sampling
|
175
|
-
rand_ids =
|
175
|
+
rand_ids = Array(0...n_training_samples).shuffle(random: sub_rng) if rand_ids.empty?
|
176
176
|
target_id = rand_ids.shift
|
177
177
|
# update the weight vector
|
178
178
|
func = (weight_vec * bin_y).dot(x[target_id, true].transpose).to_f
|
@@ -209,7 +209,7 @@ module Rumale
|
|
209
209
|
l1_penalty = LinearModel::Penalty::L1Penalty.new(reg_param: l1_reg_param) if apply_l1_penalty?
|
210
210
|
# Optimization.
|
211
211
|
@params[:max_iter].times do |t|
|
212
|
-
sample_ids =
|
212
|
+
sample_ids = Array(0...n_samples)
|
213
213
|
sample_ids.shuffle!(random: sub_rng)
|
214
214
|
until (subset_ids = sample_ids.shift(@params[:batch_size])).empty?
|
215
215
|
# sampling
|
@@ -2,13 +2,15 @@
|
|
2
2
|
|
3
3
|
require 'rumale/base/base_estimator'
|
4
4
|
require 'rumale/base/transformer'
|
5
|
-
require 'mopti/scaled_conjugate_gradient'
|
6
5
|
|
7
6
|
module Rumale
|
8
7
|
module MetricLearning
|
9
8
|
# NeighbourhoodComponentAnalysis is a class that implements Neighbourhood Component Analysis.
|
10
9
|
#
|
11
10
|
# @example
|
11
|
+
# require 'mopti'
|
12
|
+
# require 'rumale'
|
13
|
+
#
|
12
14
|
# transformer = Rumale::MetricLearning::NeighbourhoodComponentAnalysis.new
|
13
15
|
# transformer.fit(training_samples, traininig_labels)
|
14
16
|
# low_samples = transformer.transform(testing_samples)
|
@@ -63,6 +65,8 @@ module Rumale
|
|
63
65
|
# @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
|
64
66
|
# @return [NeighbourhoodComponentAnalysis] The learned transformer itself.
|
65
67
|
def fit(x, y)
|
68
|
+
raise 'NeighbourhoodComponentAnalysis#fit requires Mopti but that is not loaded.' unless enable_mopti?
|
69
|
+
|
66
70
|
x = check_convert_sample_array(x)
|
67
71
|
y = check_convert_label_array(y)
|
68
72
|
check_sample_label_size(x, y)
|
@@ -98,6 +102,14 @@ module Rumale
|
|
98
102
|
|
99
103
|
private
|
100
104
|
|
105
|
+
def enable_mopti?
|
106
|
+
if defined?(Mopti).nil?
|
107
|
+
warn('NeighbourhoodComponentAnalysis#fit requires Mopti but that is not loaded. You should intall and load mopti gem in advance.')
|
108
|
+
return false
|
109
|
+
end
|
110
|
+
true
|
111
|
+
end
|
112
|
+
|
101
113
|
def init_components(x, n_features, n_components)
|
102
114
|
if @params[:init] == 'pca'
|
103
115
|
pca = Rumale::Decomposition::PCA.new(n_components: n_components, solver: 'evd')
|
@@ -69,10 +69,11 @@ module Rumale
|
|
69
69
|
# the return_train_score is false.
|
70
70
|
def perform(x, y)
|
71
71
|
x = check_convert_sample_array(x)
|
72
|
-
|
72
|
+
case @estimator
|
73
|
+
when Rumale::Base::Classifier
|
73
74
|
y = check_convert_label_array(y)
|
74
75
|
check_sample_label_size(x, y)
|
75
|
-
|
76
|
+
when Rumale::Base::Regressor
|
76
77
|
y = check_convert_tvalue_array(y)
|
77
78
|
check_sample_tvalue_size(x, y)
|
78
79
|
else
|
@@ -62,7 +62,7 @@ module Rumale
|
|
62
62
|
end
|
63
63
|
sub_rng = @rng.dup
|
64
64
|
# Splits dataset ids to each fold.
|
65
|
-
dataset_ids =
|
65
|
+
dataset_ids = Array(0...n_samples)
|
66
66
|
dataset_ids.shuffle!(random: sub_rng) if @shuffle
|
67
67
|
fold_sets = Array.new(@n_splits) do |n|
|
68
68
|
n_fold_samples = n_samples / @n_splits
|
@@ -74,7 +74,7 @@ module Rumale
|
|
74
74
|
end
|
75
75
|
sub_rng = @rng.dup
|
76
76
|
# Returns array consisting of the training and testing ids for each fold.
|
77
|
-
dataset_ids =
|
77
|
+
dataset_ids = Array(0...n_samples)
|
78
78
|
Array.new(@n_splits) do
|
79
79
|
test_ids = dataset_ids.sample(n_test_samples, random: sub_rng)
|
80
80
|
train_ids = if @train_size.nil?
|
@@ -1,7 +1,7 @@
|
|
1
1
|
# frozen_string_literal: true
|
2
2
|
|
3
|
-
require 'rumale/base/base_estimator
|
4
|
-
require 'rumale/base/classifier
|
3
|
+
require 'rumale/base/base_estimator'
|
4
|
+
require 'rumale/base/classifier'
|
5
5
|
|
6
6
|
module Rumale
|
7
7
|
# This module consists of the classes that implement multi-class classification strategy.
|
@@ -30,7 +30,7 @@ module Rumale
|
|
30
30
|
@params = {}
|
31
31
|
@params[:min_samples_leaf] = min_samples_leaf
|
32
32
|
@data = x
|
33
|
-
@tree = build_tree(Numo::Int32.cast(
|
33
|
+
@tree = build_tree(Numo::Int32.cast(Array(0...@data.shape[0])))
|
34
34
|
end
|
35
35
|
|
36
36
|
# Search k-nearest neighbors of given query point.
|
@@ -32,7 +32,7 @@ module Rumale
|
|
32
32
|
end
|
33
33
|
|
34
34
|
# @!visibility private
|
35
|
-
# Calculate the updated weight with
|
35
|
+
# Calculate the updated weight with Adam adaptive learning rate.
|
36
36
|
#
|
37
37
|
# @param weight [Numo::DFloat] (shape: [n_features]) The weight to be updated.
|
38
38
|
# @param gradient [Numo::DFloat] (shape: [n_features]) The gradient for updating the weight.
|
@@ -222,7 +222,7 @@ module Rumale
|
|
222
222
|
n_samples = x.shape[0]
|
223
223
|
|
224
224
|
@params[:max_iter].times do |t|
|
225
|
-
sample_ids =
|
225
|
+
sample_ids = Array(0...n_samples)
|
226
226
|
sample_ids.shuffle!(random: srng)
|
227
227
|
until (subset_ids = sample_ids.shift(@params[:batch_size])).empty?
|
228
228
|
# random sampling
|
@@ -0,0 +1,60 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'rumale/base/base_estimator'
|
4
|
+
require 'rumale/base/transformer'
|
5
|
+
|
6
|
+
module Rumale
|
7
|
+
module Preprocessing
|
8
|
+
# Binarize samples according to a threshold.
|
9
|
+
#
|
10
|
+
# @example
|
11
|
+
# binarizer = Rumale::Preprocessing::Binarizer.new
|
12
|
+
# x = Numo::DFloat[[-1.2, 3.2], [2.4, -0.5], [4.5, 0.8]]
|
13
|
+
# b = binarizer.transform(x)
|
14
|
+
# p b
|
15
|
+
#
|
16
|
+
# # Numo::DFloat#shape=[3, 2]
|
17
|
+
# # [[0, 1],
|
18
|
+
# # [1, 0],
|
19
|
+
# # [1, 1]]
|
20
|
+
class Binarizer
|
21
|
+
include Base::BaseEstimator
|
22
|
+
include Base::Transformer
|
23
|
+
|
24
|
+
# Create a new transformer for binarization.
|
25
|
+
# @param threshold [Float] The threshold value for binarization.
|
26
|
+
def initialize(threshold: 0.0)
|
27
|
+
check_params_numeric(threshold: threshold)
|
28
|
+
@params = { threshold: threshold }
|
29
|
+
end
|
30
|
+
|
31
|
+
# This method does nothing and returns the object itself.
|
32
|
+
# This method exists for compatibility with other transformers.
|
33
|
+
#
|
34
|
+
# @overload fit() -> Binarizer
|
35
|
+
#
|
36
|
+
# @return [Binarizer]
|
37
|
+
def fit(_x = nil, _y = nil)
|
38
|
+
self
|
39
|
+
end
|
40
|
+
|
41
|
+
# Binarize each sample.
|
42
|
+
#
|
43
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be binarized.
|
44
|
+
# @return [Numo::DFloat] The binarized samples.
|
45
|
+
def transform(x)
|
46
|
+
x = check_convert_sample_array(x)
|
47
|
+
x.class.cast(x.gt(@params[:threshold]))
|
48
|
+
end
|
49
|
+
|
50
|
+
# The output of this method is the same as that of the transform method.
|
51
|
+
# This method exists for compatibility with other transformers.
|
52
|
+
#
|
53
|
+
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be binarized.
|
54
|
+
# @return [Numo::DFloat] The binarized samples.
|
55
|
+
def fit_transform(x, _y = nil)
|
56
|
+
fit(x).transform(x)
|
57
|
+
end
|
58
|
+
end
|
59
|
+
end
|
60
|
+
end
|