rumale 0.18.7 → 0.20.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.rubocop.yml +66 -1
- data/CHANGELOG.md +46 -0
- data/Gemfile +2 -0
- data/README.md +5 -36
- data/lib/rumale.rb +5 -10
- data/lib/rumale/clustering/hdbscan.rb +1 -1
- data/lib/rumale/clustering/k_means.rb +1 -1
- data/lib/rumale/clustering/k_medoids.rb +1 -1
- data/lib/rumale/clustering/mini_batch_k_means.rb +139 -0
- data/lib/rumale/dataset.rb +3 -3
- data/lib/rumale/decomposition/pca.rb +23 -5
- data/lib/rumale/feature_extraction/feature_hasher.rb +14 -1
- data/lib/rumale/feature_extraction/tfidf_transformer.rb +113 -0
- data/lib/rumale/kernel_approximation/nystroem.rb +1 -1
- data/lib/rumale/kernel_machine/kernel_svc.rb +1 -1
- data/lib/rumale/linear_model/base_sgd.rb +1 -1
- data/lib/rumale/metric_learning/neighbourhood_component_analysis.rb +13 -1
- data/lib/rumale/model_selection/cross_validation.rb +3 -2
- data/lib/rumale/model_selection/k_fold.rb +1 -1
- data/lib/rumale/model_selection/shuffle_split.rb +1 -1
- data/lib/rumale/multiclass/one_vs_rest_classifier.rb +2 -2
- data/lib/rumale/nearest_neighbors/vp_tree.rb +1 -1
- data/lib/rumale/neural_network/adam.rb +1 -1
- data/lib/rumale/neural_network/base_mlp.rb +1 -1
- data/lib/rumale/preprocessing/binarizer.rb +60 -0
- data/lib/rumale/preprocessing/l1_normalizer.rb +62 -0
- data/lib/rumale/preprocessing/l2_normalizer.rb +2 -1
- data/lib/rumale/preprocessing/max_normalizer.rb +62 -0
- data/lib/rumale/version.rb +1 -1
- data/rumale.gemspec +1 -3
- metadata +11 -44
- data/lib/rumale/linear_model/base_linear_model.rb +0 -101
- data/lib/rumale/optimizer/ada_grad.rb +0 -39
- data/lib/rumale/optimizer/adam.rb +0 -53
- data/lib/rumale/optimizer/nadam.rb +0 -62
- data/lib/rumale/optimizer/rmsprop.rb +0 -47
- data/lib/rumale/optimizer/sgd.rb +0 -43
- data/lib/rumale/optimizer/yellow_fin.rb +0 -101
- data/lib/rumale/polynomial_model/base_factorization_machine.rb +0 -121
- data/lib/rumale/polynomial_model/factorization_machine_classifier.rb +0 -215
- data/lib/rumale/polynomial_model/factorization_machine_regressor.rb +0 -129
data/lib/rumale/dataset.rb
CHANGED
@@ -81,7 +81,7 @@ module Rumale
       y = Numo::Int32.hstack([Numo::Int32.zeros(n_samples_out), Numo::Int32.ones(n_samples_in)])
       # shuffle data indices.
       if shuffle
-        rand_ids =
+        rand_ids = Array(0...n_samples).shuffle(random: rng.dup)
         x = x[rand_ids, true].dup
         y = y[rand_ids].dup
       end
@@ -118,7 +118,7 @@ module Rumale
       y = Numo::Int32.hstack([Numo::Int32.zeros(n_samples_out), Numo::Int32.ones(n_samples_in)])
       # shuffle data indices.
       if shuffle
-        rand_ids =
+        rand_ids = Array(0...n_samples).shuffle(random: rng.dup)
         x = x[rand_ids, true].dup
         y = y[rand_ids].dup
       end
@@ -173,7 +173,7 @@ module Rumale
       end
       # shuffle data.
       if shuffle
-        rand_ids =
+        rand_ids = Array(0...n_samples).shuffle(random: rng.dup)
        x = x[rand_ids, true].dup
        y = y[rand_ids].dup
      end
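The same `rand_ids = Array(0...n_samples).shuffle(random: rng.dup)` replacement recurs throughout this release. A minimal standalone sketch (not part of the gem) of why the `rng.dup` form gives reproducible shuffles:

    # Duplicating the Random instance keeps the index shuffle deterministic
    # for a given seed while leaving the caller's generator untouched.
    rng = Random.new(42)
    n_samples = 5

    rand_ids = Array(0...n_samples).shuffle(random: rng.dup)
    p rand_ids
    p Array(0...n_samples).shuffle(random: rng.dup) # identical permutation: rng itself was never advanced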
data/lib/rumale/decomposition/pca.rb
CHANGED
@@ -9,7 +9,7 @@ module Rumale
     # PCA is a class that implements Principal Component Analysis.
     #
     # @example
-    #   decomposer = Rumale::Decomposition::PCA.new(n_components: 2)
+    #   decomposer = Rumale::Decomposition::PCA.new(n_components: 2, solver: 'fpt')
     #   representaion = decomposer.fit_transform(samples)
     #
     #   # If Numo::Linalg is installed, you can specify 'evd' for the solver option.
@@ -17,6 +17,11 @@ module Rumale
     #   decomposer = Rumale::Decomposition::PCA.new(n_components: 2, solver: 'evd')
     #   representaion = decomposer.fit_transform(samples)
     #
+    #   # If Numo::Linalg is loaded and the solver option is not given,
+    #   # the solver option is choosen 'evd' automatically.
+    #   decomposer = Rumale::Decomposition::PCA.new(n_components: 2)
+    #   representaion = decomposer.fit_transform(samples)
+    #
     # *Reference*
     # - Sharma, A., and Paliwal, K K., "Fast principal component analysis using fixed-point algorithm," Pattern Recognition Letters, 28, pp. 1151--1155, 2007.
     class PCA
@@ -38,18 +43,24 @@ module Rumale
       # Create a new transformer with PCA.
       #
       # @param n_components [Integer] The number of principal components.
-      # @param solver [String] The algorithm for the optimization ('fpt' or 'evd').
-      #   '
+      # @param solver [String] The algorithm for the optimization ('auto', 'fpt' or 'evd').
+      #   'auto' chooses the 'evd' solver if Numo::Linalg is loaded. Otherwise, it chooses the 'fpt' solver.
+      #   'fpt' uses the fixed-point algorithm.
+      #   'evd' performs eigen value decomposition of the covariance matrix of samples.
       # @param max_iter [Integer] The maximum number of iterations. If solver = 'evd', this parameter is ignored.
       # @param tol [Float] The tolerance of termination criterion. If solver = 'evd', this parameter is ignored.
       # @param random_seed [Integer] The seed value using to initialize the random generator.
-      def initialize(n_components: 2, solver: '
+      def initialize(n_components: 2, solver: 'auto', max_iter: 100, tol: 1.0e-4, random_seed: nil)
         check_params_numeric(n_components: n_components, max_iter: max_iter, tol: tol)
         check_params_string(solver: solver)
         check_params_numeric_or_nil(random_seed: random_seed)
         check_params_positive(n_components: n_components, max_iter: max_iter, tol: tol)
         @params = {}
-        @params[:solver] = solver
+        @params[:solver] = if solver == 'auto'
+                             load_linalg? ? 'evd' : 'fpt'
+                           else
+                             solver != 'evd' ? 'fpt' : 'evd'
+                           end
         @params[:n_components] = n_components
         @params[:max_iter] = max_iter
         @params[:tol] = tol
@@ -128,6 +139,13 @@ module Rumale

       private

+      def load_linalg?
+        return false if defined?(Numo::Linalg).nil?
+        return false if Numo::Linalg::VERSION < '0.1.4'
+
+        true
+      end
+
       def orthogonalize(pcvec)
         unless @components.nil?
           delta = @components.dot(pcvec) * @components.transpose
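With this hunk, PCA defaults to solver: 'auto': it picks 'evd' when Numo::Linalg (>= 0.1.4) is loaded and falls back to the fixed-point solver otherwise. A hedged usage sketch; the `numo/linalg/autoloader` require path and the random input matrix are assumptions, not taken from the diff:

    require 'rumale'
    begin
      require 'numo/linalg/autoloader' # optional; when present, 'auto' resolves to 'evd'
    rescue LoadError
      # without Numo::Linalg, 'auto' silently resolves to the fixed-point ('fpt') solver
    end

    x = Numo::DFloat.new(100, 5).rand

    decomposer = Rumale::Decomposition::PCA.new(n_components: 2) # solver: 'auto' is now the default
    z = decomposer.fit_transform(x)
    p z.shape # => [100, 2]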
data/lib/rumale/feature_extraction/feature_hasher.rb
CHANGED
@@ -1,6 +1,5 @@
 # frozen_string_literal: true

-require 'mmh3'
 require 'rumale/base/base_estimator'
 require 'rumale/base/transformer'

@@ -11,11 +10,15 @@ module Rumale
     # This encoder employs signed 32-bit Murmurhash3 as the hash function.
     #
     # @example
+    #   require 'mmh3'
+    #   require 'rumale'
+    #
     #   encoder = Rumale::FeatureExtraction::FeatureHasher.new(n_features: 10)
     #   x = encoder.transform([
     #     { dog: 1, cat: 2, elephant: 4 },
     #     { dog: 2, run: 5 }
     #   ])
+    #
     #   # > pp x
     #   # Numo::DFloat#shape=[2,10]
     #   # [[0, 0, -4, -1, 0, 0, 0, 0, 0, 2],
@@ -62,6 +65,8 @@ module Rumale
       # @param x [Array<Hash>] (shape: [n_samples]) The array of hash consisting of feature names and values.
       # @return [Numo::DFloat] (shape: [n_samples, n_features]) The encoded sample array.
       def transform(x)
+        raise 'FeatureHasher#transform requires Mmh3 but that is not loaded.' unless enable_mmh3?
+
         x = [x] unless x.is_a?(Array)
         n_samples = x.size

@@ -85,6 +90,14 @@ module Rumale

       private

+      def enable_mmh3?
+        if defined?(Mmh3).nil?
+          warn('FeatureHasher#transform requires Mmh3 but that is not loaded. You should intall and load mmh3 gem in advance.')
+          return false
+        end
+        true
+      end
+
       def n_features
         @params[:n_features]
       end
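mmh3 is no longer required by rumale itself; callers must require it before using FeatureHasher#transform. A hypothetical session showing the new guard (the rescue and message echo are illustrative, not part of the gem):

    require 'rumale' # note: no `require 'mmh3'`

    encoder = Rumale::FeatureExtraction::FeatureHasher.new(n_features: 10)
    begin
      encoder.transform([{ dog: 1, cat: 2 }])
    rescue RuntimeError => e
      warn e.message # "FeatureHasher#transform requires Mmh3 but that is not loaded."
    end

    # With `require 'mmh3'` issued beforehand, the same call returns a
    # Numo::DFloat of shape [n_samples, 10], as in the @example above.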
data/lib/rumale/feature_extraction/tfidf_transformer.rb
ADDED
@@ -0,0 +1,113 @@
# frozen_string_literal: true

require 'rumale/base/base_estimator'
require 'rumale/base/transformer'
require 'rumale/preprocessing/l1_normalizer'
require 'rumale/preprocessing/l2_normalizer'

module Rumale
  module FeatureExtraction
    # Transform sample matrix with term frequecy (tf) to a normalized tf-idf (inverse document frequency) reprensentation.
    #
    # @example
    #   encoder = Rumale::FeatureExtraction::HashVectorizer.new
    #   x = encoder.fit_transform([
    #     { foo: 1, bar: 2 },
    #     { foo: 3, baz: 1 }
    #   ])
    #
    #   # > pp x
    #   # Numo::DFloat#shape=[2,3]
    #   # [[2, 0, 1],
    #   #  [0, 1, 3]]
    #
    #   transformer = Rumale::FeatureExtraction::TfidfTransformer.new
    #   x_tfidf = transformer.fit_transform(x)
    #
    #   # > pp x_tfidf
    #   # Numo::DFloat#shape=[2,3]
    #   # [[0.959056, 0, 0.283217],
    #   #  [0, 0.491506, 0.870874]]
    #
    # *Reference*
    # - Manning, C D., Raghavan, P., and Schutze, H., "Introduction to Information Retrieval," Cambridge University Press., 2008.
    class TfidfTransformer
      include Base::BaseEstimator
      include Base::Transformer

      # Return the vector consists of inverse document frequency.
      # @return [Numo::DFloat] (shape: [n_features])
      attr_reader :idf

      # Create a new transfomer for converting tf vectors to tf-idf vectors.
      #
      # @param norm [String] The normalization method to be used ('l1', 'l2' and 'none').
      # @param use_idf [Boolean] The flag indicating whether to use inverse document frequency weighting.
      # @param smooth_idf [Boolean] The flag indicating whether to apply idf smoothing by log((n_samples + 1) / (df + 1)) + 1.
      # @param sublinear_tf [Boolean] The flag indicating whether to perform subliner tf scaling by 1 + log(tf).
      def initialize(norm: 'l2', use_idf: true, smooth_idf: false, sublinear_tf: false)
        check_params_string(norm: norm)
        check_params_boolean(use_idf: use_idf, smooth_idf: smooth_idf, sublinear_tf: sublinear_tf)
        @params = {}
        @params[:norm] = norm
        @params[:use_idf] = use_idf
        @params[:smooth_idf] = smooth_idf
        @params[:sublinear_tf] = sublinear_tf
        @idf = nil
      end

      # Calculate the inverse document frequency for weighting.
      #
      # @overload fit(x) -> TfidfTransformer
      #
      #   @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to calculate the idf values.
      # @return [TfidfTransformer]
      def fit(x, _y = nil)
        return self unless @params[:use_idf]

        x = check_convert_sample_array(x)

        n_samples = x.shape[0]
        df = x.class.cast(x.gt(0.0).count(0))

        if @params[:smooth_idf]
          df += 1
          n_samples += 1
        end

        @idf = Numo::NMath.log(n_samples / df) + 1

        self
      end

      # Calculate the idf values, and then transfrom samples to the tf-idf representation.
      #
      # @overload fit_transform(x) -> Numo::DFloat
      #
      #   @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to calculate idf and be transformed to tf-idf representation.
      # @return [Numo::DFloat] The transformed samples.
      def fit_transform(x, _y = nil)
        fit(x).transform(x)
      end

      # Perform transforming the given samples to the tf-idf representation.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be transformed.
      # @return [Numo::DFloat] The transformed samples.
      def transform(x)
        x = check_convert_sample_array(x)
        z = x.dup

        z[z.ne(0)] = Numo::NMath.log(z[z.ne(0)]) + 1 if @params[:sublinear_tf]
        z *= @idf if @params[:use_idf]
        case @params[:norm]
        when 'l2'
          z = Rumale::Preprocessing::L2Normalizer.new.fit_transform(z)
        when 'l1'
          z = Rumale::Preprocessing::L1Normalizer.new.fit_transform(z)
        end
        z
      end
    end
  end
end
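A hedged sketch combining the constructor options documented in the new class; the toy term-frequency matrix and the `sum(1)` call are illustrative assumptions, not taken from the diff:

    require 'rumale'

    tf = Numo::DFloat[[2, 0, 1], [0, 1, 3]]

    transformer = Rumale::FeatureExtraction::TfidfTransformer.new(
      norm: 'l1',          # L1-normalize each row instead of the default 'l2'
      smooth_idf: true,    # idf = log((n_samples + 1) / (df + 1)) + 1
      sublinear_tf: true   # replace non-zero tf with 1 + log(tf) before weighting
    )
    x_tfidf = transformer.fit_transform(tf)
    p transformer.idf   # per-feature inverse document frequencies
    p x_tfidf.sum(1)    # each row sums to 1 under the 'l1' norm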
data/lib/rumale/kernel_approximation/nystroem.rb
CHANGED
@@ -69,7 +69,7 @@ module Rumale
       n_components = [1, [@params[:n_components], n_samples].min].max

       # random sampling.
-      @component_indices = Numo::Int32.cast(
+      @component_indices = Numo::Int32.cast(Array(0...n_samples).shuffle(random: sub_rng)[0...n_components])
       @components = x[@component_indices, true]

       # calculate normalizing factor.
data/lib/rumale/kernel_machine/kernel_svc.rb
CHANGED
@@ -172,7 +172,7 @@ module Rumale
       # Start optimization.
       @params[:max_iter].times do |t|
         # random sampling
-        rand_ids =
+        rand_ids = Array(0...n_training_samples).shuffle(random: sub_rng) if rand_ids.empty?
         target_id = rand_ids.shift
         # update the weight vector
         func = (weight_vec * bin_y).dot(x[target_id, true].transpose).to_f
data/lib/rumale/linear_model/base_sgd.rb
CHANGED
@@ -209,7 +209,7 @@ module Rumale
       l1_penalty = LinearModel::Penalty::L1Penalty.new(reg_param: l1_reg_param) if apply_l1_penalty?
       # Optimization.
       @params[:max_iter].times do |t|
-        sample_ids =
+        sample_ids = Array(0...n_samples)
         sample_ids.shuffle!(random: sub_rng)
         until (subset_ids = sample_ids.shift(@params[:batch_size])).empty?
           # sampling
data/lib/rumale/metric_learning/neighbourhood_component_analysis.rb
CHANGED
@@ -2,13 +2,15 @@

 require 'rumale/base/base_estimator'
 require 'rumale/base/transformer'
-require 'mopti/scaled_conjugate_gradient'

 module Rumale
   module MetricLearning
     # NeighbourhoodComponentAnalysis is a class that implements Neighbourhood Component Analysis.
     #
     # @example
+    #   require 'mopti'
+    #   require 'rumale'
+    #
     #   transformer = Rumale::MetricLearning::NeighbourhoodComponentAnalysis.new
     #   transformer.fit(training_samples, traininig_labels)
     #   low_samples = transformer.transform(testing_samples)
@@ -63,6 +65,8 @@ module Rumale
       # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
       # @return [NeighbourhoodComponentAnalysis] The learned classifier itself.
       def fit(x, y)
+        raise 'NeighbourhoodComponentAnalysis#fit requires Mopti but that is not loaded.' unless enable_mopti?
+
         x = check_convert_sample_array(x)
         y = check_convert_label_array(y)
         check_sample_label_size(x, y)
@@ -98,6 +102,14 @@ module Rumale

       private

+      def enable_mopti?
+        if defined?(Mopti).nil?
+          warn('NeighbourhoodComponentAnalysis#fit requires Mopti but that is not loaded. You should intall and load mopti gem in advance.')
+          return false
+        end
+        true
+      end
+
       def init_components(x, n_features, n_components)
         if @params[:init] == 'pca'
           pca = Rumale::Decomposition::PCA.new(n_components: n_components, solver: 'evd')
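Like mmh3 above, mopti is now an optional dependency that the caller must require. A hypothetical session hitting the new guard; the dummy samples and labels are illustrative only:

    require 'rumale' # note: no `require 'mopti'`

    x = Numo::DFloat.new(10, 4).rand
    y = Numo::Int32.cast([0, 1, 0, 1, 0, 1, 0, 1, 0, 1])

    transformer = Rumale::MetricLearning::NeighbourhoodComponentAnalysis.new
    begin
      transformer.fit(x, y)
    rescue RuntimeError => e
      warn e.message # "NeighbourhoodComponentAnalysis#fit requires Mopti but that is not loaded."
    end
    # Requiring 'mopti' beforehand lets the same fit call proceed as in the @example.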
data/lib/rumale/model_selection/cross_validation.rb
CHANGED
@@ -69,10 +69,11 @@ module Rumale
     # the return_train_score is false.
     def perform(x, y)
       x = check_convert_sample_array(x)
-
+      case @estimator
+      when Rumale::Base::Classifier
         y = check_convert_label_array(y)
         check_sample_label_size(x, y)
-
+      when Rumale::Base::Regressor
         y = check_convert_tvalue_array(y)
         check_sample_tvalue_size(x, y)
       else
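CrossValidation#perform now dispatches on whether the estimator is a classifier or a regressor before validating y. A hedged sketch of the regressor branch; the constructor keywords, Ridge, KFold, the :test_score key, and the random data are assumptions not shown in this diff:

    require 'rumale'

    x = Numo::DFloat.new(100, 4).rand
    y = x.sum(1) # continuous targets, so the Rumale::Base::Regressor branch applies

    cv = Rumale::ModelSelection::CrossValidation.new(
      estimator: Rumale::LinearModel::Ridge.new,
      splitter: Rumale::ModelSelection::KFold.new(n_splits: 5)
    )
    report = cv.perform(x, y)
    p report[:test_score] # one score per fold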
data/lib/rumale/model_selection/k_fold.rb
CHANGED
@@ -62,7 +62,7 @@ module Rumale
       end
       sub_rng = @rng.dup
       # Splits dataset ids to each fold.
-      dataset_ids =
+      dataset_ids = Array(0...n_samples)
       dataset_ids.shuffle!(random: sub_rng) if @shuffle
       fold_sets = Array.new(@n_splits) do |n|
         n_fold_samples = n_samples / @n_splits
data/lib/rumale/model_selection/shuffle_split.rb
CHANGED
@@ -74,7 +74,7 @@ module Rumale
       end
       sub_rng = @rng.dup
       # Returns array consisting of the training and testing ids for each fold.
-      dataset_ids =
+      dataset_ids = Array(0...n_samples)
       Array.new(@n_splits) do
         test_ids = dataset_ids.sample(n_test_samples, random: sub_rng)
         train_ids = if @train_size.nil?
data/lib/rumale/multiclass/one_vs_rest_classifier.rb
CHANGED
@@ -1,7 +1,7 @@
 # frozen_string_literal: true

-require 'rumale/base/base_estimator
-require 'rumale/base/classifier
+require 'rumale/base/base_estimator'
+require 'rumale/base/classifier'

 module Rumale
   # This module consists of the classes that implement multi-class classification strategy.
data/lib/rumale/nearest_neighbors/vp_tree.rb
CHANGED
@@ -30,7 +30,7 @@ module Rumale
       @params = {}
       @params[:min_samples_leaf] = min_samples_leaf
       @data = x
-      @tree = build_tree(Numo::Int32.cast(
+      @tree = build_tree(Numo::Int32.cast(Array(0...@data.shape[0])))
     end

     # Search k-nearest neighbors of given query point.
data/lib/rumale/neural_network/adam.rb
CHANGED
@@ -32,7 +32,7 @@ module Rumale
     end

     # @!visibility private
-    # Calculate the updated weight with
+    # Calculate the updated weight with Adam adaptive learning rate.
     #
     # @param weight [Numo::DFloat] (shape: [n_features]) The weight to be updated.
     # @param gradient [Numo::DFloat] (shape: [n_features]) The gradient for updating the weight.
data/lib/rumale/neural_network/base_mlp.rb
CHANGED
@@ -222,7 +222,7 @@ module Rumale
       n_samples = x.shape[0]

       @params[:max_iter].times do |t|
-        sample_ids =
+        sample_ids = Array(0...n_samples)
         sample_ids.shuffle!(random: srng)
         until (subset_ids = sample_ids.shift(@params[:batch_size])).empty?
           # random sampling
data/lib/rumale/preprocessing/binarizer.rb
ADDED
@@ -0,0 +1,60 @@
# frozen_string_literal: true

require 'rumale/base/base_estimator'
require 'rumale/base/transformer'

module Rumale
  module Preprocessing
    # Binarize samples according to a threshold
    #
    # @example
    #   binarizer = Rumale::Preprocessing::Binarizer.new
    #   x = Numo::DFloat[[-1.2, 3.2], [2.4, -0.5], [4.5, 0.8]]
    #   b = binarizer.transform(x)
    #   p b
    #
    #   # Numo::DFloat#shape=[3, 2]
    #   # [[0, 1],
    #   #  [1, 0],
    #   #  [1, 1]]
    class Binarizer
      include Base::BaseEstimator
      include Base::Transformer

      # Create a new transformer for binarization.
      # @param threshold [Float] The threshold value for binarization.
      def initialize(threshold: 0.0)
        check_params_numeric(threshold: threshold)
        @params = { threshold: threshold }
      end

      # This method does nothing and returns the object itself.
      # For compatibility with other transformer, this method exists.
      #
      # @overload fit() -> Binarizer
      #
      # @return [Binarizer]
      def fit(_x = nil, _y = nil)
        self
      end

      # Binarize each sample.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be binarized.
      # @return [Numo::DFloat] The binarized samples.
      def transform(x)
        x = check_convert_sample_array(x)
        x.class.cast(x.gt(@params[:threshold]))
      end

      # The output of this method is the same as that of the transform method.
      # For compatibility with other transformer, this method exists.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be binarized.
      # @return [Numo::DFloat] The binarized samples.
      def fit_transform(x, _y = nil)
        fit(x).transform(x)
      end
    end
  end
end
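Usage sketch of the new Binarizer with a non-default threshold; the input is the same toy matrix as the @example above, only the threshold value differs:

    require 'rumale'

    x = Numo::DFloat[[-1.2, 3.2], [2.4, -0.5], [4.5, 0.8]]

    binarizer = Rumale::Preprocessing::Binarizer.new(threshold: 1.0)
    p binarizer.fit_transform(x)
    # values strictly greater than 1.0 become 1, everything else 0:
    # [[0, 1], [1, 0], [1, 0]]

Because fit is a no-op kept only for interface compatibility, the class drops into any workflow that calls fit/transform or fit_transform uniformly on its transformers.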