rumale 0.20.2 → 0.22.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
-   metadata.gz: 5d8c93acbf38fbd07e5df224010abbdd4269a6ce3bbf8112a0eba652a606785d
-   data.tar.gz: e7cb00a802420854835c92f011425f3054bfcc1052bf7b3664da1f95834ef435
+   metadata.gz: 703a6895f4218ca45c5d5ae5e86559b077cf1be213d4939eb1e9ab94eac4621d
+   data.tar.gz: 5862466e565d1e6030c35494b5028ae980a47d373e90050c62266055fcecd374
  SHA512:
-   metadata.gz: f95fdd89b84dad02e516ee0479b1cddfb101cb96de897b6e7fa3fba546272a243cff5cfe954cb51942ec1ab23cf3028b183db86b52fab00a35d15be7eee5bf92
-   data.tar.gz: e5f6235e88dd47b9002a2154cabd2c1e64afb6cbb5b0745b411c7e5559351e925c9db8ec332724e301b83215662b3582e79a9e997f0338846514b234dabf1fc3
+   metadata.gz: 988d55c681a102e0c65b9133c6aeafc049e33755955f959d6e6046f5601dd192af881424355a2b373ed2e7a5a16b74236698aef5372e09584b10fe28d1b7bc21
+   data.tar.gz: adc58efa3b46d9fc1a87ddb2a4df32472507d61f21a3a0eb07026068cc5e41af166fb0a0f8ae23f1b23aec649b22835a50edbed79d35255e8cc231b82b31eb8c
@@ -0,0 +1,23 @@
+ name: build
+
+ on: [push, pull_request]
+
+ jobs:
+   build:
+     runs-on: ubuntu-latest
+     strategy:
+       matrix:
+         ruby: [ '2.5', '2.6', '2.7' ]
+     steps:
+     - uses: actions/checkout@v2
+     - name: Install BLAS and LAPACK
+       run: sudo apt-get install -y libopenblas-dev liblapacke-dev
+     - name: Set up Ruby ${{ matrix.ruby }}
+       uses: actions/setup-ruby@v1
+       with:
+         ruby-version: ${{ matrix.ruby }}
+     - name: Build and test with Rake
+       run: |
+         gem install bundler
+         bundle install --jobs 4 --retry 3
+         bundle exec rake
@@ -1,5 +1,6 @@
  require:
    - rubocop-performance
+   - rubocop-rake
    - rubocop-rspec
 
  AllCops:
@@ -20,6 +21,9 @@ Layout/LineLength:
    Max: 145
    IgnoredPatterns: ['(\A|\s)#']
 
+ Lint/ConstantDefinitionInBlock:
+   Enabled: false
+
  Lint/MissingSuper:
    Enabled: false
 
@@ -70,6 +74,9 @@ Style/StringConcatenation:
  RSpec/MultipleExpectations:
    Enabled: false
 
+ RSpec/MultipleMemoizedHelpers:
+   Max: 25
+
  RSpec/NestedGroups:
    Max: 4
 
@@ -81,3 +88,6 @@ RSpec/InstanceVariable:
 
  RSpec/LeakyConstantDeclaration:
    Enabled: false
+
+ Performance/Sum:
+   Enabled: false
@@ -1,3 +1,29 @@
+ # 0.22.2
+ - Add classifier and regressor classes for stacking method.
+   - [StackingClassifier](https://yoshoku.github.io/rumale/doc/Rumale/Ensemble/StackingClassifier.html)
+   - [StackingRegressor](https://yoshoku.github.io/rumale/doc/Rumale/Ensemble/StackingRegressor.html)
+ - Refactor some codes with Rubocop.
+
+ # 0.22.1
+ - Add transformer class for [MLKR](https://yoshoku.github.io/rumale/doc/Rumale/MetricLearning/MLKR.html), that implements Metric Learning for Kernel Regression.
+ - Refactor NeighbourhoodComponentAnalysis.
+ - Update API documentation.
+
+ # 0.22.0
+ ## Breaking change
+ - Add lbfgsb.rb gem to runtime dependencies. Rumale uses lbfgsb gem for optimization.
+   This eliminates the need to require the mopti gem when using [NeighbourhoodComponentAnalysis](https://yoshoku.github.io/rumale/doc/Rumale/MetricLearning/NeighbourhoodComponentAnalysis.html).
+ - Add lbfgs solver to [LogisticRegression](https://yoshoku.github.io/rumale/doc/Rumale/LinearModel/LogisticRegression.html) and make it the default solver.
+
+ # 0.21.0
+ ## Breaking change
+ - Change the default value of max_iter argument on LinearModel estimators to 1000.
+
+ # 0.20.3
+ - Fix to use automatic solver of PCA in NeighbourhoodComponentAnalysis.
+ - Refactor some codes with Rubocop.
+ - Update README.
+
  # 0.20.2
  - Add cross-validator class for time-series data.
    - [TimeSeriesSplit](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/TimeSeriesSplit.html)
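The 0.22.1 entry adds an MLKR transformer. As MLKR is usually formulated, it learns a projection matrix A that minimizes the leave-one-out error of Gaussian-kernel regression in the projected space. The following is a minimal pure-Ruby sketch of that objective under this assumption; it is not Rumale's implementation, it leaves out the gradient and the optimizer, and the helper names `project`, `squared_distance`, and `mlkr_loss` are hypothetical.

```ruby
# Sketch of the MLKR objective: leave-one-out Gaussian-kernel regression error
# after projecting samples with a matrix A (given here as an array of rows).
# Pure Ruby for illustration; hypothetical helpers, not Rumale's implementation.

# Multiply the projection matrix by one sample vector.
def project(a, x)
  a.map { |row| row.zip(x).sum { |w, v| w * v } }
end

def squared_distance(u, v)
  u.zip(v).sum { |p, q| (p - q)**2 }
end

# Sum over samples of (y_i - y_hat_i)^2, where y_hat_i is the kernel-weighted
# average of the other targets with weights exp(-||A x_i - A x_j||^2).
def mlkr_loss(a, xs, ys)
  zs = xs.map { |x| project(a, x) }
  loss = 0.0
  zs.each_with_index do |zi, i|
    num = 0.0
    den = 0.0
    zs.each_with_index do |zj, j|
      next if i == j # leave sample i out of its own prediction

      k = Math.exp(-squared_distance(zi, zj))
      num += k * ys[j]
      den += k
    end
    loss += (ys[i] - num / den)**2
  end
  loss
end

xs = [[0.0], [0.1], [5.0], [5.1]]
ys = [0.0, 0.0, 1.0, 1.0]
puts mlkr_loss([[0.0]], xs, ys) # A = 0: every prediction is the mean of the other targets
puts mlkr_loss([[1.0]], xs, ys) # A = I: predictions follow each sample's close neighbour
```

With A = 0 every kernel weight is 1, so each prediction is the mean of the other three targets and the loss on this toy data is 16/9; projecting with A = I lets each sample borrow its close neighbour's target, which drives the loss towards zero. A learned A would be found by minimizing this quantity.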
data/Gemfile CHANGED
@@ -3,11 +3,14 @@ source 'https://rubygems.org'
  # Specify your gem's dependencies in rumale.gemspec
  gemspec
 
- gem 'coveralls', '~> 0.8'
  gem 'mmh3', '>= 1.0'
- gem 'mopti', '>= 0.1.0'
  gem 'numo-linalg', '>= 0.1.4'
  gem 'parallel', '>= 1.17.0'
  gem 'rake', '~> 12.0'
  gem 'rake-compiler', '~> 1.0'
  gem 'rspec', '~> 3.0'
+ gem 'rubocop', '~> 1.0'
+ gem 'rubocop-performance', '~> 1.8'
+ gem 'rubocop-rake', '~> 0.5'
+ gem 'rubocop-rspec', '~> 2.0'
+ gem 'simplecov', '~> 0.19'
data/README.md CHANGED
@@ -2,10 +2,9 @@
 
  ![Rumale](https://dl.dropboxusercontent.com/s/joxruk2720ur66o/rumale_header_400.png)
 
- [![Build Status](https://travis-ci.org/yoshoku/rumale.svg?branch=master)](https://travis-ci.org/yoshoku/rumale)
- [![Coverage Status](https://coveralls.io/repos/github/yoshoku/rumale/badge.svg?branch=master)](https://coveralls.io/github/yoshoku/rumale?branch=master)
+ [![Build Status](https://github.com/yoshoku/rumale/workflows/build/badge.svg)](https://github.com/yoshoku/rumale/actions?query=workflow%3Abuild)
  [![Gem Version](https://badge.fury.io/rb/rumale.svg)](https://badge.fury.io/rb/rumale)
- [![BSD 2-Clause License](https://img.shields.io/badge/License-BSD%202--Clause-orange.svg)](https://github.com/yoshoku/rumale/blob/master/LICENSE.txt)
+ [![BSD 2-Clause License](https://img.shields.io/badge/License-BSD%202--Clause-orange.svg)](https://github.com/yoshoku/rumale/blob/main/LICENSE.txt)
  [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://yoshoku.github.io/rumale/doc/)
 
  Rumale (**Ru**by **ma**chine **le**arning) is a machine learning library in Ruby.
@@ -114,10 +113,10 @@ require 'rumale'
  samples, labels = Rumale::Dataset.load_libsvm_file('pendigits')
 
  # Define the estimator to be evaluated.
- lr = Rumale::LinearModel::LogisticRegression.new(learning_rate: 0.00001, reg_param: 0.0001, random_seed: 1)
+ lr = Rumale::LinearModel::LogisticRegression.new
 
  # Define the evaluation measure, splitting strategy, and cross validation.
- ev = Rumale::EvaluationMeasure::LogLoss.new
+ ev = Rumale::EvaluationMeasure::Accuracy.new
  kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 5, shuffle: true, random_seed: 1)
  cv = Rumale::ModelSelection::CrossValidation.new(estimator: lr, splitter: kf, evaluator: ev)
 
@@ -125,15 +124,15 @@ cv = Rumale::ModelSelection::CrossValidation.new(estimator: lr, splitter: kf, ev
  report = cv.perform(samples, labels)
 
  # Output result.
- mean_logloss = report[:test_score].inject(:+) / kf.n_splits
- puts("5-CV mean log-loss: %.3f" % mean_logloss)
+ mean_accuracy = report[:test_score].sum / kf.n_splits
+ puts "5-CV mean accuracy: %.1f%%" % (100.0 * mean_accuracy)
  ```
 
  Execution of the above scripts results in the following.
 
  ```bash
  $ ruby cross_validation.rb
- 5-CV mean log-loss: 0.355
+ 5-CV mean accuracy: 95.4%
  ```
 
  ### Example 3. Pipeline
@@ -144,10 +143,10 @@ require 'rumale'
  # Load dataset.
  samples, labels = Rumale::Dataset.load_libsvm_file('pendigits')
 
- # Construct pipeline with kernel approximation and SVC.
- rbf = Rumale::KernelApproximation::RBF.new(gamma: 0.0001, n_components: 800, random_seed: 1)
- svc = Rumale::LinearModel::SVC.new(reg_param: 0.0001, random_seed: 1)
- pipeline = Rumale::Pipeline::Pipeline.new(steps: { trns: rbf, clsf: svc })
+ # Construct pipeline with kernel approximation and LogisticRegression.
+ rbf = Rumale::KernelApproximation::RBF.new(gamma: 1e-4, n_components: 800, random_seed: 1)
+ lr = Rumale::LinearModel::LogisticRegression.new(reg_param: 1e-3)
+ pipeline = Rumale::Pipeline::Pipeline.new(steps: { trns: rbf, clsf: lr })
 
  # Define the splitting strategy and cross validation.
  kf = Rumale::ModelSelection::StratifiedKFold.new(n_splits: 5, shuffle: true, random_seed: 1)
@@ -157,7 +156,7 @@ cv = Rumale::ModelSelection::CrossValidation.new(estimator: pipeline, splitter:
  report = cv.perform(samples, labels)
 
  # Output result.
- mean_accuracy = report[:test_score].inject(:+) / kf.n_splits
+ mean_accuracy = report[:test_score].sum / kf.n_splits
  puts("5-CV mean accuracy: %.1f %%" % (mean_accuracy * 100.0))
  ```
 
@@ -228,6 +227,10 @@ When -1 is given to n_jobs parameter, all processors are used.
  estimator = Rumale::Ensemble::RandomForestClassifier.new(n_jobs: -1, random_seed: 1)
  ```
 
+ ## Novelties
+
+ * [Rumale SHOP](https://suzuri.jp/yoshoku)
+
  ## Contributing
 
  Bug reports and pull requests are welcome on GitHub at https://github.com/yoshoku/rumale.
@@ -241,4 +244,4 @@ The gem is available as open source under the terms of the [BSD 2-clause License
  ## Code of Conduct
 
  Everyone interacting in the Rumale project’s codebases, issue trackers,
- chat rooms and mailing lists is expected to follow the [code of conduct](https://github.com/yoshoku/Rumale/blob/master/CODE_OF_CONDUCT.md).
+ chat rooms and mailing lists is expected to follow the [code of conduct](https://github.com/yoshoku/Rumale/blob/main/CODE_OF_CONDUCT.md).
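In the README's cross-validation example, `report[:test_score]` holds one score per fold and the reported figure is their sum divided by `kf.n_splits`. The plain-Ruby sketch below mimics that flow on toy data. It assumes a simple round-robin stratification and a 1-nearest-neighbour classifier in place of Rumale's `StratifiedKFold` and `LogisticRegression`, and it does not shuffle (the README passes `shuffle: true`); `stratified_folds`, `predict_1nn`, and `cross_validate` are hypothetical helpers.

```ruby
# Stratified k-fold cross validation in plain Ruby (sketch).
# stratified_folds, predict_1nn and cross_validate are hypothetical helpers.

# Deal the indices of each class round-robin into n_splits folds so that
# every fold keeps roughly the class proportions of the whole dataset.
def stratified_folds(labels, n_splits)
  folds = Array.new(n_splits) { [] }
  labels.each_index.group_by { |i| labels[i] }.each_value do |ids|
    ids.each_with_index { |id, k| folds[k % n_splits] << id }
  end
  folds.map(&:sort)
end

# Toy estimator: 1-nearest-neighbour classifier on one-dimensional samples.
def predict_1nn(train_x, train_y, x)
  nearest = train_x.each_index.min_by { |i| (train_x[i] - x).abs }
  train_y[nearest]
end

# Return one accuracy score per fold, like report[:test_score] in the README.
def cross_validate(xs, ys, n_splits)
  all_ids = (0...xs.size).to_a
  stratified_folds(ys, n_splits).map do |test_ids|
    train_ids = all_ids - test_ids
    train_x = train_ids.map { |i| xs[i] }
    train_y = train_ids.map { |i| ys[i] }
    correct = test_ids.count { |i| predict_1nn(train_x, train_y, xs[i]) == ys[i] }
    correct.fdiv(test_ids.size)
  end
end

xs = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 5.0, 5.2, 5.4, 5.6, 5.8, 6.0]
ys = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
scores = cross_validate(xs, ys, 3)
mean_accuracy = scores.sum / scores.size # same averaging as the README example
puts format('3-CV mean accuracy: %.1f%%', 100.0 * mean_accuracy)
```

Each of the three folds receives two samples of each class, and because the two classes are far apart every fold scores 1.0, so the script prints `3-CV mean accuracy: 100.0%`.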
@@ -59,6 +59,8 @@ require 'rumale/ensemble/random_forest_classifier'
  require 'rumale/ensemble/random_forest_regressor'
  require 'rumale/ensemble/extra_trees_classifier'
  require 'rumale/ensemble/extra_trees_regressor'
+ require 'rumale/ensemble/stacking_classifier'
+ require 'rumale/ensemble/stacking_regressor'
  require 'rumale/clustering/k_means'
  require 'rumale/clustering/mini_batch_k_means'
  require 'rumale/clustering/k_medoids'
@@ -77,6 +79,7 @@ require 'rumale/manifold/tsne'
  require 'rumale/manifold/mds'
  require 'rumale/metric_learning/fisher_discriminant_analysis'
  require 'rumale/metric_learning/neighbourhood_component_analysis'
+ require 'rumale/metric_learning/mlkr'
  require 'rumale/neural_network/adam'
  require 'rumale/neural_network/base_mlp'
  require 'rumale/neural_network/mlp_regressor'
@@ -51,7 +51,7 @@ module Rumale
        # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be used for cluster analysis.
        #   If the metric is 'precomputed', x must be a square distance matrix (shape: [n_samples, n_samples]).
        # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
-       def fit_predict(x)
+       def fit_predict(x) # rubocop:disable Lint/UselessMethodDefinition
          super
        end
 
@@ -59,7 +59,7 @@ module Rumale
          @params[:solver] = if solver == 'auto'
                               load_linalg? ? 'evd' : 'fpt'
                             else
-                              solver != 'evd' ? 'fpt' : 'evd'
+                              solver != 'evd' ? 'fpt' : 'evd' # rubocop:disable Style/NegatedIfElseCondition
                             end
          @params[:n_components] = n_components
          @params[:max_iter] = max_iter
@@ -0,0 +1,214 @@
+ # frozen_string_literal: true
+
+ require 'rumale/base/base_estimator'
+ require 'rumale/base/classifier'
+
+ module Rumale
+   module Ensemble
+     # StackingClassifier is a class that implements a classifier with the stacking method.
+     #
+     # @example
+     #   estimators = {
+     #     lgr: Rumale::LinearModel::LogisticRegression.new(reg_param: 1e-2, random_seed: 1),
+     #     mlp: Rumale::NeuralNetwork::MLPClassifier.new(hidden_units: [256], random_seed: 1),
+     #     rnd: Rumale::Ensemble::RandomForestClassifier.new(random_seed: 1)
+     #   }
+     #   meta_estimator = Rumale::LinearModel::LogisticRegression.new(random_seed: 1)
+     #   classifier = Rumale::Ensemble::StackingClassifier.new(
+     #     estimators: estimators, meta_estimator: meta_estimator, random_seed: 1
+     #   )
+     #   classifier.fit(training_samples, training_labels)
+     #   results = classifier.predict(testing_samples)
+     #
+     # *Reference*
+     # - Zhou, Z-H., "Ensemble Methods - Foundations and Algorithms," CRC Press Taylor and Francis Group, Chapman and Hall/CRC, 2012.
+     class StackingClassifier
+       include Base::BaseEstimator
+       include Base::Classifier
+
+       # Return the base classifiers.
+       # @return [Hash<Symbol,Classifier>]
+       attr_reader :estimators
+
+       # Return the meta classifier.
+       # @return [Classifier]
+       attr_reader :meta_estimator
+
+       # Return the class labels.
+       # @return [Numo::Int32] (size: n_classes)
+       attr_reader :classes
+
+       # Return the method used by each base classifier.
+       # @return [Hash<Symbol,Symbol>]
+       attr_reader :stack_method
+
+       # Create a new classifier with stacking method.
+       #
+       # @param estimators [Hash<Symbol,Classifier>] The base classifiers for extracting meta features.
+       # @param meta_estimator [Classifier/Nil] The meta classifier that predicts class label.
+       #   If nil is given, LogisticRegression is used.
+       # @param n_splits [Integer] The number of folds for the stratified k-fold cross validation used to extract meta features in the training phase.
+       # @param shuffle [Boolean] The flag indicating whether to shuffle the dataset on cross validation.
+       # @param stack_method [String] The name of the base classifier method used for meta feature extraction.
+       #   If 'auto' is given, it searches for the first callable method in the order 'predict_proba', 'decision_function', and 'predict'
+       #   on each classifier.
+       # @param passthrough [Boolean] The flag indicating whether to concatenate the original features and meta features when training the meta classifier.
+       # @param random_seed [Integer/Nil] The seed value used to initialize the random generator on cross validation.
+       def initialize(estimators:, meta_estimator: nil, n_splits: 5, shuffle: true, stack_method: 'auto', passthrough: false, random_seed: nil)
+         check_params_type(Hash, estimators: estimators)
+         check_params_numeric(n_splits: n_splits)
+         check_params_string(stack_method: stack_method)
+         check_params_boolean(shuffle: shuffle, passthrough: passthrough)
+         check_params_numeric_or_nil(random_seed: random_seed)
+         @estimators = estimators
+         @meta_estimator = meta_estimator || Rumale::LinearModel::LogisticRegression.new
+         @classes = nil
+         @stack_method = nil
+         @output_size = nil
+         @params = {}
+         @params[:n_splits] = n_splits
+         @params[:shuffle] = shuffle
+         @params[:stack_method] = stack_method
+         @params[:passthrough] = passthrough
+         @params[:random_seed] = random_seed || srand
+       end
+
+       # Fit the model with given training data.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
+       # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
+       # @return [StackingClassifier] The learned classifier itself.
+       def fit(x, y)
+         x = check_convert_sample_array(x)
+         y = check_convert_label_array(y)
+         check_sample_label_size(x, y)
+
+         n_samples, n_features = x.shape
+
+         @encoder = Rumale::Preprocessing::LabelEncoder.new
+         y_encoded = @encoder.fit_transform(y)
+         @classes = Numo::NArray[*@encoder.classes]
+
+         # training base classifiers with all training data.
+         @estimators.each_key { |name| @estimators[name].fit(x, y_encoded) }
+
+         # detecting the feature extraction method and its output size for each base classifier.
+         @stack_method = detect_stack_method
+         @output_size = detect_output_size(n_features)
+
+         # extracting meta features with base classifiers.
+         n_components = @output_size.values.inject(:+)
+         z = Numo::DFloat.zeros(n_samples, n_components)
+
+         kf = Rumale::ModelSelection::StratifiedKFold.new(
+           n_splits: @params[:n_splits], shuffle: @params[:shuffle], random_seed: @params[:random_seed]
+         )
+
+         kf.split(x, y_encoded).each do |train_ids, valid_ids|
+           x_train = x[train_ids, true]
+           y_train = y_encoded[train_ids]
+           x_valid = x[valid_ids, true]
+           f_start = 0
+           @estimators.each_key do |name|
+             est_fold = Marshal.load(Marshal.dump(@estimators[name]))
+             f_last = f_start + @output_size[name]
+             f_position = @output_size[name] == 1 ? f_start : f_start...f_last
+             z[valid_ids, f_position] = est_fold.fit(x_train, y_train).public_send(@stack_method[name], x_valid)
+             f_start = f_last
+           end
+         end
+
+         # concatenating original features.
+         z = Numo::NArray.hstack([z, x]) if @params[:passthrough]
+
+         # training meta classifier.
+         @meta_estimator.fit(z, y_encoded)
+
+         self
+       end
+
+       # Calculate confidence scores for samples.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
+       # @return [Numo::DFloat] (shape: [n_samples, n_classes]) The confidence score per sample.
+       def decision_function(x)
+         x = check_convert_sample_array(x)
+         z = transform(x)
+         @meta_estimator.decision_function(z)
+       end
+
+       # Predict class labels for samples.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
+       # @return [Numo::Int32] (shape: [n_samples]) The predicted class label per sample.
+       def predict(x)
+         x = check_convert_sample_array(x)
+         z = transform(x)
+         Numo::Int32.cast(@encoder.inverse_transform(@meta_estimator.predict(z)))
+       end
+
+       # Predict probability for samples.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
+       # @return [Numo::DFloat] (shape: [n_samples, n_classes]) The predicted probability of each class per sample.
+       def predict_proba(x)
+         x = check_convert_sample_array(x)
+         z = transform(x)
+         @meta_estimator.predict_proba(z)
+       end
+
+       # Transform the given data with the learned model.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be transformed with the learned model.
+       # @return [Numo::DFloat] (shape: [n_samples, n_components]) The meta features for samples.
+       def transform(x)
+         x = check_convert_sample_array(x)
+         n_samples = x.shape[0]
+         n_components = @output_size.values.inject(:+)
+         z = Numo::DFloat.zeros(n_samples, n_components)
+         f_start = 0
+         @estimators.each_key do |name|
+           f_last = f_start + @output_size[name]
+           f_position = @output_size[name] == 1 ? f_start : f_start...f_last
+           z[true, f_position] = @estimators[name].public_send(@stack_method[name], x)
+           f_start = f_last
+         end
+         z = Numo::NArray.hstack([z, x]) if @params[:passthrough]
+         z
+       end
+
+       # Fit the model with training data, and then transform them with the learned model.
+       #
+       # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
+       # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
+       # @return [Numo::DFloat] (shape: [n_samples, n_components]) The meta features for training data.
+       def fit_transform(x, y)
+         x = check_convert_sample_array(x)
+         y = check_convert_label_array(y)
+         fit(x, y).transform(x)
+       end
+
+       private
+
+       STACK_METHODS = %i[predict_proba decision_function predict].freeze
+
+       private_constant :STACK_METHODS
+
+       def detect_stack_method
+         if @params[:stack_method] == 'auto'
+           @estimators.each_key.with_object({}) { |name, obj| obj[name] = STACK_METHODS.detect { |m| @estimators[name].respond_to?(m) } }
+         else
+           @estimators.each_key.with_object({}) { |name, obj| obj[name] = @params[:stack_method].to_sym }
+         end
+       end
+
+       def detect_output_size(n_features)
+         x_dummy = Numo::DFloat.new(2, n_features).rand
+         @estimators.each_key.with_object({}) do |name, obj|
+           output_dummy = @estimators[name].public_send(@stack_method[name], x_dummy)
+           obj[name] = output_dummy.ndim == 1 ? 1 : output_dummy.shape[1]
+         end
+       end
+     end
+   end
+ end
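The `fit` method of `StackingClassifier` builds out-of-fold meta features: each base classifier is deep-copied with `Marshal`, fitted on the training part of every stratified fold, and its `predict_proba`, `decision_function`, or `predict` output on the held-out part fills that fold's rows before the meta classifier is trained on them. The sketch below illustrates the same idea in plain Ruby. It is a toy under stated assumptions, not Rumale's code: labels are 0 or 1, folds are taken by index modulo `n_splits` rather than stratified, `passthrough` is omitted, and `ThresholdClassifier`, `NearestCentroid`, `out_of_fold_features`, and `stack_transform` are hypothetical names.

```ruby
# Stacking with out-of-fold meta features in plain Ruby (illustrative sketch).
# Every class and helper below is hypothetical; labels are assumed to be 0 or 1.

STACK_METHODS = %i[predict_proba decision_function predict].freeze

# 'auto' rule: pick the first method in STACK_METHODS that the estimator defines.
def detect_stack_method(estimator)
  STACK_METHODS.detect { |m| estimator.respond_to?(m) }
end

# Base classifier that thresholds one feature; it only implements predict.
class ThresholdClassifier
  def initialize(feature)
    @feature = feature
  end

  def fit(xs, ys)
    means = [0, 1].map do |c|
      vals = xs.each_index.select { |i| ys[i] == c }.map { |i| xs[i][@feature] }
      vals.sum.fdiv(vals.size)
    end
    @threshold = means.sum / 2.0 # midpoint between the two class means
    @flip = means[1] < means[0]  # true when class 1 lies below the threshold
    self
  end

  def predict(xs)
    xs.map { |x| (x[@feature] > @threshold) ^ @flip ? 1 : 0 }
  end
end

# Nearest-centroid classifier; implements decision_function and predict.
class NearestCentroid
  def fit(xs, ys)
    @centroids = [0, 1].map do |c|
      rows = xs.each_index.select { |i| ys[i] == c }.map { |i| xs[i] }
      rows.transpose.map { |col| col.sum.fdiv(col.size) }
    end
    self
  end

  # Positive score: the sample is closer to the class-1 centroid.
  def decision_function(xs)
    xs.map { |x| distance(x, @centroids[0]) - distance(x, @centroids[1]) }
  end

  def predict(xs)
    decision_function(xs).map { |s| s.positive? ? 1 : 0 }
  end

  private

  def distance(u, v)
    Math.sqrt(u.zip(v).sum { |a, b| (a - b)**2 })
  end
end

# Row i of the result holds the outputs of base-estimator copies fitted on the
# folds that do not contain sample i (fold k = indices with i % n_splits == k).
def out_of_fold_features(estimators, stack_method, xs, ys, n_splits)
  meta = Array.new(xs.size) { [] }
  n_splits.times do |k|
    valid, train = xs.each_index.partition { |i| i % n_splits == k }
    estimators.each do |name, est|
      fold_est = Marshal.load(Marshal.dump(est)) # independent copy per fold
      fold_est.fit(train.map { |i| xs[i] }, train.map { |i| ys[i] })
      out = fold_est.public_send(stack_method[name], valid.map { |i| xs[i] })
      valid.each_with_index { |i, r| meta[i] << out[r] }
    end
  end
  meta
end

# At prediction time, meta features come from the base estimators fitted on all data.
def stack_transform(estimators, stack_method, xs)
  estimators.map { |name, est| est.public_send(stack_method[name], xs) }.transpose
end

xs = [[0.0, 1.0], [0.5, 0.8], [1.0, 1.2], [0.2, 0.9], [4.0, 5.0], [4.5, 4.8], [5.0, 5.2], [4.2, 4.9]]
ys = [0, 0, 0, 0, 1, 1, 1, 1]
estimators = { thr: ThresholdClassifier.new(0), cen: NearestCentroid.new }
stack_method = estimators.transform_values { |est| detect_stack_method(est) }

meta = out_of_fold_features(estimators, stack_method, xs, ys, 2)
meta_estimator = NearestCentroid.new.fit(meta, ys)
estimators.each_value { |est| est.fit(xs, ys) }
preds = meta_estimator.predict(stack_transform(estimators, stack_method, [[0.3, 1.0], [4.8, 5.1]]))
p stack_method
p preds
```

Because `ThresholdClassifier` only defines `predict` while `NearestCentroid` defines `decision_function` but not `predict_proba`, the 'auto' rule selects a different method for each base estimator, much as `detect_stack_method` does in the class above. The script prints the chosen methods and then `[0, 1]` for the two query points.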