rumale-pipeline 0.24.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: cd3ec61c4a23cd7022d5be945e4bd22dadbdb883e6800a7d5e2026b05e6c5cf0
4
+ data.tar.gz: a6be2df1a40c169dd46a5b587df6314ab438c507c5bd721d802a422dfcd66bea
5
+ SHA512:
6
+ metadata.gz: 8ebc9359645c3e8300e6eb797c9ecbb2da7f99b77522e44d8619562ad3ffcdde681e1ed9256fb7fe5882d36a6c67ccf6c6f6386ee5f0f68a801cc16208305845
7
+ data.tar.gz: 14d7861bb9e31aba0924f021f320dc1c0d44e7880f9bdd03e465f3b3bc9d26132bce259317a96f3644f1c9bf753b6896e9fcd6727882c8ad3d01078eea62f842
data/LICENSE.txt ADDED
@@ -0,0 +1,27 @@
1
+ Copyright (c) 2022 Atsushi Tatsuma
2
+ All rights reserved.
3
+
4
+ Redistribution and use in source and binary forms, with or without
5
+ modification, are permitted provided that the following conditions are met:
6
+
7
+ * Redistributions of source code must retain the above copyright notice, this
8
+ list of conditions and the following disclaimer.
9
+
10
+ * Redistributions in binary form must reproduce the above copyright notice,
11
+ this list of conditions and the following disclaimer in the documentation
12
+ and/or other materials provided with the distribution.
13
+
14
+ * Neither the name of the copyright holder nor the names of its
15
+ contributors may be used to endorse or promote products derived from
16
+ this software without specific prior written permission.
17
+
18
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
data/README.md ADDED
@@ -0,0 +1,33 @@
1
+ # Rumale::Pipeline
2
+
3
+ [![Gem Version](https://badge.fury.io/rb/rumale-pipeline.svg)](https://badge.fury.io/rb/rumale-pipeline)
4
+ [![BSD 3-Clause License](https://img.shields.io/badge/License-BSD%203--Clause-orange.svg)](https://github.com/yoshoku/rumale/blob/main/rumale-pipeline/LICENSE.txt)
5
+ [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://yoshoku.github.io/rumale/doc/Rumale/Pipeline.html)
6
+
7
+ Rumale is a machine learning library in Ruby.
8
+ Rumale::Pipeline provides classes for chaining transformers and estimators
9
+ with Rumale interface.
10
+
11
+ ## Installation
12
+
13
+ Add this line to your application's Gemfile:
14
+
15
+ ```ruby
16
+ gem 'rumale-pipeline'
17
+ ```
18
+
19
+ And then execute:
20
+
21
+ $ bundle install
22
+
23
+ Or install it yourself as:
24
+
25
+ $ gem install rumale-pipeline
26
+
27
+ ## Documentation
28
+
29
+ - [Rumale API Documentation - Pipeline](https://yoshoku.github.io/rumale/doc/Rumale/Pipeline.html)
30
+
31
+ ## License
32
+
33
+ The gem is available as open source under the terms of the [BSD-3-Clause License](https://opensource.org/licenses/BSD-3-Clause).
@@ -0,0 +1,69 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/estimator'
4
+
5
+ module Rumale
6
+ module Pipeline
7
+ # FeatureUnion is a class that concatenates the results of multiple transformers.
8
+ #
9
+ # @example
10
+ # require 'rumale/kernel_approximation/rbf'
11
+ # require 'rumale/decomposition/pca'
12
+ # require 'rumale/pipeline/feature_union'
13
+ #
14
+ # fu = Rumale::Pipeline::FeatureUnion.new(
15
+ # transformers: {
16
+ # 'rbf': Rumale::KernelApproximation::RBF.new(gamma: 1.0, n_components: 96, random_seed: 1),
17
+ # 'pca': Rumale::Decomposition::PCA.new(n_components: 32)
18
+ # }
19
+ # )
20
+ # fu.fit(training_samples, training_labels)
21
+ # results = fu.transform(testing_samples)
22
+ #
23
+ # # > p results.shape[1]
24
+ # # > 128
25
+ #
26
+ class FeatureUnion < ::Rumale::Base::Estimator
27
+ # Return the transformers
28
+ # @return [Hash]
29
+ attr_reader :transformers
30
+
31
+ # Create a new feature union.
32
+ #
33
+ # @param transformers [Hash] List of transformers. The order of transforms follows the insertion order of hash keys.
34
+ def initialize(transformers:)
35
+ super()
36
+ @params = {}
37
+ @transformers = transformers
38
+ end
39
+
40
+ # Fit the model with given training data.
41
+ #
42
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the transformers.
43
+ # @param y [Numo::NArray/Nil] (shape: [n_samples, n_outputs]) The target values or labels to be used for fitting the transformers.
44
+ # @return [FeatureUnion] The learned feature union itself.
45
+ def fit(x, y = nil)
46
+ @transformers.each { |_k, t| t.fit(x, y) }
47
+ self
48
+ end
49
+
50
+ # Fit the model with training data, and then transform them with the learned model.
51
+ #
52
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the transformers.
53
+ # @param y [Numo::NArray/Nil] (shape: [n_samples, n_outputs]) The target values or labels to be used for fitting the transformers.
54
+ # @return [Numo::DFloat] (shape: [n_samples, sum_n_components]) The transformed and concatenated data.
55
+ def fit_transform(x, y = nil)
56
+ fit(x, y).transform(x)
57
+ end
58
+
59
+ # Transform the given data with the learned model.
60
+ #
61
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The data to be transformed with the learned transformers.
62
+ # @return [Numo::DFloat] (shape: [n_samples, sum_n_components]) The transformed and concatenated data.
63
+ def transform(x)
64
+ z = @transformers.values.map { |t| t.transform(x) }
65
+ Numo::NArray.hstack(z)
66
+ end
67
+ end
68
+ end
69
+ end
@@ -0,0 +1,175 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'rumale/base/estimator'
4
+
5
+ module Rumale
6
+ # Module implements utilities of pipeline that consists of a chain of transformers and estimators.
7
+ module Pipeline
8
+ # Pipeline is a class that implements the function to perform the transformers and estimators sequentially.
9
+ #
10
+ # @example
11
+ # require 'rumale/kernel_approximation/rbf'
12
+ # require 'rumale/linear_model/svc'
13
+ # require 'rumale/pipeline/pipeline'
14
+ #
15
+ # rbf = Rumale::KernelApproximation::RBF.new(gamma: 1.0, n_components: 128, random_seed: 1)
16
+ # svc = Rumale::LinearModel::SVC.new(reg_param: 1.0, fit_bias: true, max_iter: 5000, random_seed: 1)
17
+ # pipeline = Rumale::Pipeline::Pipeline.new(steps: { trs: rbf, est: svc })
18
+ # pipeline.fit(training_samples, training_labels)
19
+ # results = pipeline.predict(testing_samples)
20
+ #
21
+ class Pipeline < ::Rumale::Base::Estimator
22
+ # Return the steps.
23
+ # @return [Hash]
24
+ attr_reader :steps
25
+
26
+ # Create a new pipeline.
27
+ #
28
+ # @param steps [Hash] List of transformers and estimators. The order of transforms follows the insertion order of hash keys.
29
+ # The last entry is considered an estimator.
30
+ def initialize(steps:)
31
+ super()
32
+ validate_steps(steps)
33
+ @params = {}
34
+ @steps = steps
35
+ end
36
+
37
+ # Fit the model with given training data.
38
+ #
39
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be transformed and used for fitting the model.
40
+ # @param y [Numo::NArray] (shape: [n_samples, n_outputs]) The target values or labels to be used for fitting the model.
41
+ # @return [Pipeline] The learned pipeline itself.
42
+ def fit(x, y)
43
+ trans_x = apply_transforms(x, y, fit: true)
44
+ last_estimator&.fit(trans_x, y)
45
+ self
46
+ end
47
+
48
+ # Call the fit_predict method of last estimator after applying all transforms.
49
+ #
50
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be transformed and used for fitting the model.
51
+ # @param y [Numo::NArray] (shape: [n_samples, n_outputs], default: nil) The target values or labels to be used for fitting the model.
52
+ # @return [Numo::NArray] The predicted results by last estimator.
53
+ def fit_predict(x, y = nil)
54
+ trans_x = apply_transforms(x, y, fit: true)
55
+ last_estimator.fit_predict(trans_x)
56
+ end
57
+
58
+ # Call the fit_transform method of last estimator after applying all transforms.
59
+ #
60
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be transformed and used for fitting the model.
61
+ # @param y [Numo::NArray] (shape: [n_samples, n_outputs], default: nil) The target values or labels to be used for fitting the model.
62
+ # @return [Numo::NArray] The predicted results by last estimator.
63
+ def fit_transform(x, y = nil)
64
+ trans_x = apply_transforms(x, y, fit: true)
65
+ last_estimator.fit_transform(trans_x, y)
66
+ end
67
+
68
+ # Call the decision_function method of last estimator after applying all transforms.
69
+ #
70
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
71
+ # @return [Numo::DFloat] (shape: [n_samples]) Confidence score per sample.
72
+ def decision_function(x)
73
+ trans_x = apply_transforms(x)
74
+ last_estimator.decision_function(trans_x)
75
+ end
76
+
77
+ # Call the predict method of last estimator after applying all transforms.
78
+ #
79
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to obtain prediction result.
80
+ # @return [Numo::NArray] The predicted results by last estimator.
81
+ def predict(x)
82
+ trans_x = apply_transforms(x)
83
+ last_estimator.predict(trans_x)
84
+ end
85
+
86
+ # Call the predict_log_proba method of last estimator after applying all transforms.
87
+ #
88
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the log-probabilities.
89
+ # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted log-probability of each class per sample.
90
+ def predict_log_proba(x)
91
+ trans_x = apply_transforms(x)
92
+ last_estimator.predict_log_proba(trans_x)
93
+ end
94
+
95
+ # Call the predict_proba method of last estimator after applying all transforms.
96
+ #
97
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
98
+ # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
99
+ def predict_proba(x)
100
+ trans_x = apply_transforms(x)
101
+ last_estimator.predict_proba(trans_x)
102
+ end
103
+
104
+ # Call the transform method of last estimator after applying all transforms.
105
+ #
106
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be transformed.
107
+ # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed samples.
108
+ def transform(x)
109
+ trans_x = apply_transforms(x)
110
+ last_estimator.nil? ? trans_x : last_estimator.transform(trans_x)
111
+ end
112
+
113
+ # Call the inverse_transform method in reverse order.
114
+ #
115
+ # @param z [Numo::DFloat] (shape: [n_samples, n_components]) The transformed samples to be restored into original space.
116
+ # @return [Numo::DFloat] (shape: [n_samples, n_features]) The restored samples.
117
+ def inverse_transform(z)
118
+ itrans_z = z
119
+ @steps.keys.reverse_each do |name|
120
+ transformer = @steps[name]
121
+ next if transformer.nil?
122
+
123
+ itrans_z = transformer.inverse_transform(itrans_z)
124
+ end
125
+ itrans_z
126
+ end
127
+
128
+ # Call the score method of last estimator after applying all transforms.
129
+ #
130
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) Testing data.
131
+ # @param y [Numo::NArray] (shape: [n_samples, n_outputs]) True target values or labels for testing data.
132
+ # @return [Float] The score of last estimator
133
+ def score(x, y)
134
+ trans_x = apply_transforms(x)
135
+ last_estimator.score(trans_x, y)
136
+ end
137
+
138
+ private
139
+
140
+ def validate_steps(steps)
141
+ steps.keys[0...-1].each do |name|
142
+ transformer = steps[name]
143
+ next if transformer.nil? || (transformer.class.method_defined?(:fit) && transformer.class.method_defined?(:transform))
144
+
145
+ raise TypeError,
146
+ 'Class of intermediate step in pipeline should be implemented fit and transform methods: ' \
147
+ "#{name} => #{transformer.class}"
148
+ end
149
+
150
+ estimator = steps[steps.keys.last]
151
+ unless estimator.nil? || estimator.class.method_defined?(:fit) # rubocop:disable Style/GuardClause
152
+ raise TypeError,
153
+ 'Class of last step in pipeline should be implemented fit method: ' \
154
+ "#{steps.keys.last} => #{estimator.class}"
155
+ end
156
+ end
157
+
158
+ def apply_transforms(x, y = nil, fit: false)
159
+ trans_x = x
160
+ @steps.keys[0...-1].each do |name|
161
+ transformer = @steps[name]
162
+ next if transformer.nil?
163
+
164
+ transformer.fit(trans_x, y) if fit
165
+ trans_x = transformer.transform(trans_x)
166
+ end
167
+ trans_x
168
+ end
169
+
170
+ def last_estimator
171
+ @steps[@steps.keys.last]
172
+ end
173
+ end
174
+ end
175
+ end
@@ -0,0 +1,10 @@
1
+ # frozen_string_literal: true
2
+
3
+ # Rumale is a machine learning library in Ruby.
4
+ module Rumale
5
+ # Module implements utilities of pipeline that consists of a chain of transformers and estimators.
6
+ module Pipeline
7
+ # @!visibility private
8
+ VERSION = '0.24.0'
9
+ end
10
+ end
@@ -0,0 +1,7 @@
1
+ # frozen_string_literal: true
2
+
3
+ require 'numo/narray'
4
+
5
+ require_relative 'pipeline/feature_union'
6
+ require_relative 'pipeline/pipeline'
7
+ require_relative 'pipeline/version'
metadata ADDED
@@ -0,0 +1,84 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: rumale-pipeline
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.24.0
5
+ platform: ruby
6
+ authors:
7
+ - yoshoku
8
+ autorequire:
9
+ bindir: exe
10
+ cert_chain: []
11
+ date: 2022-12-31 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: numo-narray
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - ">="
18
+ - !ruby/object:Gem::Version
19
+ version: 0.9.1
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - ">="
25
+ - !ruby/object:Gem::Version
26
+ version: 0.9.1
27
+ - !ruby/object:Gem::Dependency
28
+ name: rumale-core
29
+ requirement: !ruby/object:Gem::Requirement
30
+ requirements:
31
+ - - "~>"
32
+ - !ruby/object:Gem::Version
33
+ version: 0.24.0
34
+ type: :runtime
35
+ prerelease: false
36
+ version_requirements: !ruby/object:Gem::Requirement
37
+ requirements:
38
+ - - "~>"
39
+ - !ruby/object:Gem::Version
40
+ version: 0.24.0
41
+ description: Rumale::Pipeline provides classes for chaining transformers and estimators
42
+ with Rumale interface.
43
+ email:
44
+ - yoshoku@outlook.com
45
+ executables: []
46
+ extensions: []
47
+ extra_rdoc_files: []
48
+ files:
49
+ - LICENSE.txt
50
+ - README.md
51
+ - lib/rumale/pipeline.rb
52
+ - lib/rumale/pipeline/feature_union.rb
53
+ - lib/rumale/pipeline/pipeline.rb
54
+ - lib/rumale/pipeline/version.rb
55
+ homepage: https://github.com/yoshoku/rumale
56
+ licenses:
57
+ - BSD-3-Clause
58
+ metadata:
59
+ homepage_uri: https://github.com/yoshoku/rumale
60
+ source_code_uri: https://github.com/yoshoku/rumale/tree/main/rumale-pipeline
61
+ changelog_uri: https://github.com/yoshoku/rumale/blob/main/CHANGELOG.md
62
+ documentation_uri: https://yoshoku.github.io/rumale/doc/
63
+ rubygems_mfa_required: 'true'
64
+ post_install_message:
65
+ rdoc_options: []
66
+ require_paths:
67
+ - lib
68
+ required_ruby_version: !ruby/object:Gem::Requirement
69
+ requirements:
70
+ - - ">="
71
+ - !ruby/object:Gem::Version
72
+ version: '0'
73
+ required_rubygems_version: !ruby/object:Gem::Requirement
74
+ requirements:
75
+ - - ">="
76
+ - !ruby/object:Gem::Version
77
+ version: '0'
78
+ requirements: []
79
+ rubygems_version: 3.3.26
80
+ signing_key:
81
+ specification_version: 4
82
+ summary: Rumale::Pipeline provides classes for chaining transformers and estimators
83
+ with Rumale interface.
84
+ test_files: []