rumale 0.20.0 → 0.22.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (40) hide show
  1. checksums.yaml +4 -4
  2. data/.github/workflows/build.yml +23 -0
  3. data/.rubocop.yml +15 -95
  4. data/CHANGELOG.md +28 -0
  5. data/Gemfile +4 -2
  6. data/README.md +5 -2
  7. data/lib/rumale.rb +3 -0
  8. data/lib/rumale/clustering/hdbscan.rb +2 -2
  9. data/lib/rumale/clustering/snn.rb +1 -1
  10. data/lib/rumale/dataset.rb +1 -1
  11. data/lib/rumale/decomposition/nmf.rb +2 -2
  12. data/lib/rumale/ensemble/random_forest_classifier.rb +1 -1
  13. data/lib/rumale/ensemble/random_forest_regressor.rb +1 -1
  14. data/lib/rumale/evaluation_measure/roc_auc.rb +3 -0
  15. data/lib/rumale/feature_extraction/feature_hasher.rb +1 -1
  16. data/lib/rumale/feature_extraction/hash_vectorizer.rb +1 -1
  17. data/lib/rumale/linear_model/base_sgd.rb +1 -1
  18. data/lib/rumale/linear_model/elastic_net.rb +2 -2
  19. data/lib/rumale/linear_model/lasso.rb +2 -2
  20. data/lib/rumale/linear_model/linear_regression.rb +2 -2
  21. data/lib/rumale/linear_model/logistic_regression.rb +123 -35
  22. data/lib/rumale/linear_model/ridge.rb +2 -2
  23. data/lib/rumale/linear_model/svc.rb +2 -2
  24. data/lib/rumale/linear_model/svr.rb +2 -2
  25. data/lib/rumale/manifold/tsne.rb +1 -1
  26. data/lib/rumale/metric_learning/neighbourhood_component_analysis.rb +13 -45
  27. data/lib/rumale/model_selection/group_k_fold.rb +93 -0
  28. data/lib/rumale/model_selection/group_shuffle_split.rb +115 -0
  29. data/lib/rumale/model_selection/shuffle_split.rb +4 -4
  30. data/lib/rumale/model_selection/stratified_k_fold.rb +1 -1
  31. data/lib/rumale/model_selection/stratified_shuffle_split.rb +13 -9
  32. data/lib/rumale/model_selection/time_series_split.rb +91 -0
  33. data/lib/rumale/pipeline/pipeline.rb +1 -1
  34. data/lib/rumale/probabilistic_output.rb +1 -1
  35. data/lib/rumale/tree/base_decision_tree.rb +2 -9
  36. data/lib/rumale/tree/gradient_tree_regressor.rb +3 -10
  37. data/lib/rumale/version.rb +1 -1
  38. data/rumale.gemspec +1 -0
  39. metadata +21 -4
  40. data/.coveralls.yml +0 -1
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 358515f8785eb3de2e6571a957ca76cece6b774bb022c1a0951c92d44ab422b4
4
- data.tar.gz: '0289b7eb382cd3300845412af0fd43626f4f827bb719083c879b574e3ab37eb0'
3
+ metadata.gz: 4e2f68b3182ada73537901e7bc74bddd100aff75264f9147c88d8240fb624e29
4
+ data.tar.gz: e2639a55fc84d1399b925f65b3a56b38f2ae3150dd15ab8556120af28d408cae
5
5
  SHA512:
6
- metadata.gz: f03fc0f27f99ed4acea3fb7d7bf34017c1dbf923b20dabc9a78d6d44f0b151bc9dc78ba24d122f81607a43fd1852e398a603b75b87656a2f79109f87c0db0d98
7
- data.tar.gz: 69f6b8892f6bfb4c43706513245c3fba687dcb6a347c1c5185a70d5e45a024b2848a019bfae48726e1f49212878e8d6d67c811ec5f4a990fdbb3a2841efdfe9b
6
+ metadata.gz: 91ffcbade578bbb9c6a5d87a54ebd89a2b5990eb70835e7a5549afe78541dbfeafe3af50833725bee751fa89c059484970e5add7ebf8adee3e25bc000fbe3778
7
+ data.tar.gz: 2ee2b1448a486581ef98561f65bc3446b2e161c89a3a12bd6cd78867350e26151bc0b350bd431902d21f6979493ab2d01a6ee81b55c1099f631aa84c84a704e6
@@ -0,0 +1,23 @@
1
+ name: build
2
+
3
+ on: [push]
4
+
5
+ jobs:
6
+ build:
7
+ runs-on: ubuntu-latest
8
+ strategy:
9
+ matrix:
10
+ ruby: [ '2.5', '2.6', '2.7' ]
11
+ steps:
12
+ - uses: actions/checkout@v2
13
+ - name: Install BLAS and LAPACK
14
+ run: sudo apt-get install -y libopenblas-dev liblapacke-dev
15
+ - name: Set up Ruby ${{ matrix.ruby }}
16
+ uses: actions/setup-ruby@v1
17
+ with:
18
+ ruby-version: ${{ matrix.ruby }}
19
+ - name: Build and test with Rake
20
+ run: |
21
+ gem install bundler
22
+ bundle install --jobs 4 --retry 3
23
+ bundle exec rake
@@ -3,6 +3,7 @@ require:
3
3
  - rubocop-rspec
4
4
 
5
5
  AllCops:
6
+ NewCops: enable
6
7
  TargetRubyVersion: 2.5
7
8
  DisplayCopNames: true
8
9
  DisplayStyleGuide: true
@@ -15,34 +16,16 @@ AllCops:
15
16
  Style/Documentation:
16
17
  Enabled: false
17
18
 
18
- Style/HashEachMethods:
19
- Enabled: true
20
-
21
- Style/HashTransformKeys:
22
- Enabled: true
23
-
24
- Style/HashTransformValues:
25
- Enabled: true
26
-
27
- Lint/DeprecatedOpenSSLConstant:
28
- Enabled: true
29
-
30
- Lint/DuplicateElsifCondition:
31
- Enabled: true
32
-
33
- Lint/MixedRegexpCaptureTypes:
34
- Enabled: true
35
-
36
- Lint/RaiseException:
37
- Enabled: true
38
-
39
- Lint/StructNewOverride:
40
- Enabled: true
41
-
42
19
  Layout/LineLength:
43
20
  Max: 145
44
21
  IgnoredPatterns: ['(\A|\s)#']
45
22
 
23
+ Lint/ConstantDefinitionInBlock:
24
+ Enabled: false
25
+
26
+ Lint/MissingSuper:
27
+ Enabled: false
28
+
46
29
  Metrics/ModuleLength:
47
30
  Max: 200
48
31
 
@@ -78,87 +61,21 @@ Naming/MethodParameterName:
78
61
  Naming/ConstantName:
79
62
  Enabled: false
80
63
 
81
- Style/AccessorGrouping:
82
- Enabled: true
83
-
84
- Style/ArrayCoercion:
85
- Enabled: true
86
-
87
- Style/BisectedAttrAccessor:
88
- Enabled: true
89
-
90
- Style/CaseLikeIf:
91
- Enabled: true
92
-
93
- Style/ExponentialNotation:
94
- Enabled: true
95
-
96
64
  Style/FormatStringToken:
97
65
  Enabled: false
98
66
 
99
- Style/HashAsLastArrayItem:
100
- Enabled: true
101
-
102
- Style/HashLikeCase:
103
- Enabled: true
104
-
105
67
  Style/NumericLiterals:
106
68
  Enabled: false
107
69
 
108
- Style/RedundantAssignment:
109
- Enabled: true
110
-
111
- Style/RedundantFetchBlock:
112
- Enabled: true
113
-
114
- Style/RedundantFileExtensionInRequire:
115
- Enabled: true
116
-
117
- Style/RedundantRegexpCharacterClass:
118
- Enabled: true
119
-
120
- Style/RedundantRegexpEscape:
121
- Enabled: true
122
-
123
- Style/SlicingWithRange:
124
- Enabled: true
125
-
126
- Layout/EmptyLineAfterGuardClause:
127
- Enabled: true
128
-
129
- Layout/EmptyLinesAroundAttributeAccessor:
130
- Enabled: true
131
-
132
- Layout/SpaceAroundMethodCallOperator:
133
- Enabled: true
134
-
135
- Performance/AncestorsInclude:
136
- Enabled: true
137
-
138
- Performance/BigDecimalWithNumericArgument:
139
- Enabled: true
140
-
141
- Performance/RedundantSortBlock:
142
- Enabled: true
143
-
144
- Performance/RedundantStringChars:
145
- Enabled: true
146
-
147
- Performance/ReverseFirst:
148
- Enabled: true
149
-
150
- Performance/SortReverse:
151
- Enabled: true
152
-
153
- Performance/Squeeze:
154
- Enabled: true
155
-
156
- Performance/StringInclude:
157
- Enabled: true
70
+ Style/StringConcatenation:
71
+ Enabled: false
158
72
 
159
73
  RSpec/MultipleExpectations:
160
74
  Enabled: false
161
75
 
76
+ RSpec/MultipleMemoizedHelpers:
77
+ Max: 25
78
+
162
79
  RSpec/NestedGroups:
163
80
  Max: 4
164
81
 
@@ -170,3 +87,6 @@ RSpec/InstanceVariable:
170
87
 
171
88
  RSpec/LeakyConstantDeclaration:
172
89
  Enabled: false
90
+
91
+ Performance/Sum:
92
+ Enabled: false
@@ -1,3 +1,31 @@
1
+ # 0.22.0
2
+ ## Breaking change
3
+ - Add lbfgsb.rb gem to runtime dependencies. Rumale uses lbfgsb gem for optimization.
4
+ This eliminates the need to require the mopti gem when using [NeighbourhoodComponentAnalysis](https://yoshoku.github.io/rumale/doc/Rumale/MetricLearning/NeighbourhoodComponentAnalysis.html).
5
+ - Add lbfgs solver to [LogisticRegression](https://yoshoku.github.io/rumale/doc/Rumale/LinearModel/LogisticRegression.html) and make it the default solver.
6
+
7
+ # 0.21.0
8
+ ## Breaking change
9
+ - Change the default value of max_iter argument on LinearModel estimators to 1000.
10
+
11
+ # 0.20.3
12
+ - Fix to use automatic solver of PCA in NeighbourhoodComponentAnalysis.
13
+ - Refactor some code with RuboCop.
14
+ - Update README.
15
+
16
+ # 0.20.2
17
+ - Add cross-validator class for time-series data.
18
+ - [TimeSeriesSplit](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/TimeSeriesSplit.html)
19
+
20
+ # 0.20.1
21
+ - Add cross-validator classes that split data according to group labels.
22
+ - [GroupKFold](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/GroupKFold.html)
23
+ - [GroupShuffleSplit](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/GroupShuffleSplit.html)
24
+ - Fix the handling of fractional numbers of samples in the shuffle split cross-validator classes.
25
+ - [ShuffleSplit](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/ShuffleSplit.html)
26
+ - [StratifiedShuffleSplit](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/StratifiedShuffleSplit.html)
27
+ - Refactor some code with RuboCop.
28
+
1
29
  # 0.20.0
2
30
  ## Breaking changes
3
31
  - Delete deprecated estimators such as PolynomialModel, Optimizer, and BaseLinearModel.
data/Gemfile CHANGED
@@ -3,11 +3,13 @@ source 'https://rubygems.org'
3
3
  # Specify your gem's dependencies in rumale.gemspec
4
4
  gemspec
5
5
 
6
- gem 'coveralls', '~> 0.8'
7
6
  gem 'mmh3', '>= 1.0'
8
- gem 'mopti', '>= 0.1.0'
9
7
  gem 'numo-linalg', '>= 0.1.4'
10
8
  gem 'parallel', '>= 1.17.0'
11
9
  gem 'rake', '~> 12.0'
12
10
  gem 'rake-compiler', '~> 1.0'
13
11
  gem 'rspec', '~> 3.0'
12
+ gem 'rubocop', '~> 0.91'
13
+ gem 'rubocop-performance', '~> 1.8'
14
+ gem 'rubocop-rspec', '~> 1.43'
15
+ gem 'simplecov', '~> 0.19'
data/README.md CHANGED
@@ -2,8 +2,7 @@
2
2
 
3
3
  ![Rumale](https://dl.dropboxusercontent.com/s/joxruk2720ur66o/rumale_header_400.png)
4
4
 
5
- [![Build Status](https://travis-ci.org/yoshoku/rumale.svg?branch=master)](https://travis-ci.org/yoshoku/rumale)
6
- [![Coverage Status](https://coveralls.io/repos/github/yoshoku/rumale/badge.svg?branch=master)](https://coveralls.io/github/yoshoku/rumale?branch=master)
5
+ [![Build Status](https://github.com/yoshoku/rumale/workflows/build/badge.svg)](https://github.com/yoshoku/rumale/actions?query=workflow%3Abuild)
7
6
  [![Gem Version](https://badge.fury.io/rb/rumale.svg)](https://badge.fury.io/rb/rumale)
8
7
  [![BSD 2-Clause License](https://img.shields.io/badge/License-BSD%202--Clause-orange.svg)](https://github.com/yoshoku/rumale/blob/master/LICENSE.txt)
9
8
  [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://yoshoku.github.io/rumale/doc/)
@@ -228,6 +227,10 @@ When -1 is given to n_jobs parameter, all processors are used.
228
227
  estimator = Rumale::Ensemble::RandomForestClassifier.new(n_jobs: -1, random_seed: 1)
229
228
  ```
230
229
 
230
+ ## Novelties
231
+
232
+ * [Rumale SHOP](https://suzuri.jp/yoshoku)
233
+
231
234
  ## Contributing
232
235
 
233
236
  Bug reports and pull requests are welcome on GitHub at https://github.com/yoshoku/rumale.
@@ -98,9 +98,12 @@ require 'rumale/preprocessing/ordinal_encoder'
98
98
  require 'rumale/preprocessing/binarizer'
99
99
  require 'rumale/preprocessing/polynomial_features'
100
100
  require 'rumale/model_selection/k_fold'
101
+ require 'rumale/model_selection/group_k_fold'
101
102
  require 'rumale/model_selection/stratified_k_fold'
102
103
  require 'rumale/model_selection/shuffle_split'
104
+ require 'rumale/model_selection/group_shuffle_split'
103
105
  require 'rumale/model_selection/stratified_shuffle_split'
106
+ require 'rumale/model_selection/time_series_split'
104
107
  require 'rumale/model_selection/cross_validation'
105
108
  require 'rumale/model_selection/grid_search_cv'
106
109
  require 'rumale/model_selection/function'
@@ -136,7 +136,7 @@ module Rumale
136
136
  res
137
137
  end
138
138
 
139
- # rubocop:disable Metrics/AbcSize, Metrics/MethodLength, Metrics/PerceivedComplexity
139
+ # rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity
140
140
  def condense_tree(hierarchy, min_cluster_size)
141
141
  n_edges = hierarchy.size
142
142
  root = 2 * n_edges
@@ -265,7 +265,7 @@ module Rumale
265
265
  end
266
266
  res
267
267
  end
268
- # rubocop:enable Metrics/AbcSize, Metrics/MethodLength, Metrics/PerceivedComplexity
268
+ # rubocop:enable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity
269
269
  end
270
270
  end
271
271
  end
@@ -51,7 +51,7 @@ module Rumale
51
51
  # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be used for cluster analysis.
52
52
  # If the metric is 'precomputed', x must be a square distance matrix (shape: [n_samples, n_samples]).
53
53
  # @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
54
- def fit_predict(x)
54
+ def fit_predict(x) # rubocop:disable Lint/UselessMethodDefinition
55
55
  super
56
56
  end
57
57
 
@@ -225,7 +225,7 @@ module Rumale
225
225
  line = dump_label(label, label_type.to_s)
226
226
  ftvec.to_a.each_with_index do |val, n|
227
227
  idx = n + (zero_based == false ? 1 : 0)
228
- line += format(" %d:#{value_type}", idx, val) if val != 0.0
228
+ line += format(" %d:#{value_type}", idx, val) if val != 0
229
229
  end
230
230
  line
231
231
  end
@@ -77,7 +77,7 @@ module Rumale
77
77
  # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data.
78
78
  def transform(x)
79
79
  x = check_convert_sample_array(x)
80
- partial_fit(x, false)
80
+ partial_fit(x, update_comps: false)
81
81
  end
82
82
 
83
83
  # Inverse transform the given transformed data with the learned model.
@@ -91,7 +91,7 @@ module Rumale
91
91
 
92
92
  private
93
93
 
94
- def partial_fit(x, update_comps = true)
94
+ def partial_fit(x, update_comps: true)
95
95
  # initialize some variables.
96
96
  n_samples, n_features = x.shape
97
97
  scale = Math.sqrt(x.mean / @params[:n_components])
@@ -85,7 +85,7 @@ module Rumale
85
85
  # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
86
86
  # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
87
87
  # @return [RandomForestClassifier] The learned classifier itself.
88
- def fit(x, y)
88
+ def fit(x, y) # rubocop:disable Metrics/AbcSize
89
89
  x = check_convert_sample_array(x)
90
90
  y = check_convert_label_array(y)
91
91
  check_sample_label_size(x, y)
@@ -79,7 +79,7 @@ module Rumale
79
79
  # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
80
80
  # @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values to be used for fitting the model.
81
81
  # @return [RandomForestRegressor] The learned regressor itself.
82
- def fit(x, y)
82
+ def fit(x, y) # rubocop:disable Metrics/AbcSize
83
83
  x = check_convert_sample_array(x)
84
84
  y = check_convert_tvalue_array(y)
85
85
  check_sample_tvalue_size(x, y)
@@ -75,9 +75,12 @@ module Rumale
75
75
  false_pos, true_pos, thresholds = binary_roc_curve(y_true, y_score, pos_label)
76
76
 
77
77
  if true_pos.size.zero? || false_pos[0] != 0 || true_pos[0] != 0
78
+ # NOTE: Numo::NArray#insert is not a destructive method.
79
+ # rubocop:disable Style/RedundantSelfAssignment
78
80
  true_pos = true_pos.insert(0, 0)
79
81
  false_pos = false_pos.insert(0, 0)
80
82
  thresholds = thresholds.insert(0, thresholds[0] + 1)
83
+ # rubocop:enable Style/RedundantSelfAssignment
81
84
  end
82
85
 
83
86
  tpr = true_pos / true_pos[-1].to_f
@@ -67,7 +67,7 @@ module Rumale
67
67
  def transform(x)
68
68
  raise 'FeatureHasher#transform requires Mmh3 but that is not loaded.' unless enable_mmh3?
69
69
 
70
- x = [x] unless x.is_a?(Array)
70
+ x = [x] unless x.is_a?(Array) # rubocop:disable Style/ArrayCoercion
71
71
  n_samples = x.size
72
72
 
73
73
  z = Numo::DFloat.zeros(n_samples, n_features)
@@ -99,7 +99,7 @@ module Rumale
99
99
  # @param x [Array<Hash>] (shape: [n_samples]) The array of hash consisting of feature names and values.
100
100
  # @return [Numo::DFloat] (shape: [n_samples, n_features]) The encoded sample array.
101
101
  def transform(x)
102
- x = [x] unless x.is_a?(Array)
102
+ x = [x] unless x.is_a?(Array) # rubocop:disable Style/ArrayCoercion
103
103
  n_samples = x.size
104
104
  n_features = @vocabulary.size
105
105
  z = Numo::DFloat.zeros(n_samples, n_features)
@@ -171,7 +171,7 @@ module Rumale
171
171
  @params[:fit_bias] = true
172
172
  @params[:reg_param] = 0.0
173
173
  @params[:l1_ratio] = 0.0
174
- @params[:max_iter] = 200
174
+ @params[:max_iter] = 1000
175
175
  @params[:batch_size] = 50
176
176
  @params[:tol] = 0.0001
177
177
  @params[:verbose] = false
@@ -10,7 +10,7 @@ module Rumale
10
10
  #
11
11
  # @example
12
12
  # estimator =
13
- # Rumale::LinearModel::ElasticNet.new(reg_param: 0.1, l1_ratio: 0.5, max_iter: 200, batch_size: 50, random_seed: 1)
13
+ # Rumale::LinearModel::ElasticNet.new(reg_param: 0.1, l1_ratio: 0.5, max_iter: 1000, batch_size: 50, random_seed: 1)
14
14
  # estimator.fit(training_samples, training_values)
15
15
  # results = estimator.predict(testing_samples)
16
16
  #
@@ -59,7 +59,7 @@ module Rumale
59
59
  # @param random_seed [Integer] The seed value using to initialize the random generator.
60
60
  def initialize(learning_rate: 0.01, decay: nil, momentum: 0.9,
61
61
  reg_param: 1.0, l1_ratio: 0.5, fit_bias: true, bias_scale: 1.0,
62
- max_iter: 200, batch_size: 50, tol: 1e-4,
62
+ max_iter: 1000, batch_size: 50, tol: 1e-4,
63
63
  n_jobs: nil, verbose: false, random_seed: nil)
64
64
  check_params_numeric(learning_rate: learning_rate, momentum: momentum,
65
65
  reg_param: reg_param, l1_ratio: l1_ratio, bias_scale: bias_scale,
@@ -10,7 +10,7 @@ module Rumale
10
10
  #
11
11
  # @example
12
12
  # estimator =
13
- # Rumale::LinearModel::Lasso.new(reg_param: 0.1, max_iter: 500, batch_size: 20, random_seed: 1)
13
+ # Rumale::LinearModel::Lasso.new(reg_param: 0.1, max_iter: 1000, batch_size: 20, random_seed: 1)
14
14
  # estimator.fit(training_samples, training_values)
15
15
  # results = estimator.predict(testing_samples)
16
16
  #
@@ -55,7 +55,7 @@ module Rumale
55
55
  # @param random_seed [Integer] The seed value using to initialize the random generator.
56
56
  def initialize(learning_rate: 0.01, decay: nil, momentum: 0.9,
57
57
  reg_param: 1.0, fit_bias: true, bias_scale: 1.0,
58
- max_iter: 200, batch_size: 50, tol: 1e-4,
58
+ max_iter: 1000, batch_size: 50, tol: 1e-4,
59
59
  n_jobs: nil, verbose: false, random_seed: nil)
60
60
  check_params_numeric(learning_rate: learning_rate, momentum: momentum,
61
61
  reg_param: reg_param, bias_scale: bias_scale,
@@ -10,7 +10,7 @@ module Rumale
10
10
  #
11
11
  # @example
12
12
  # estimator =
13
- # Rumale::LinearModel::LinearRegression.new(max_iter: 500, batch_size: 20, random_seed: 1)
13
+ # Rumale::LinearModel::LinearRegression.new(max_iter: 1000, batch_size: 20, random_seed: 1)
14
14
  # estimator.fit(training_samples, training_values)
15
15
  # results = estimator.predict(testing_samples)
16
16
  #
@@ -68,7 +68,7 @@ module Rumale
68
68
  # If solver = 'svd', this parameter is ignored.
69
69
  # @param random_seed [Integer] The seed value using to initialize the random generator.
70
70
  def initialize(learning_rate: 0.01, decay: nil, momentum: 0.9,
71
- fit_bias: true, bias_scale: 1.0, max_iter: 200, batch_size: 50, tol: 1e-4,
71
+ fit_bias: true, bias_scale: 1.0, max_iter: 1000, batch_size: 50, tol: 1e-4,
72
72
  solver: 'auto',
73
73
  n_jobs: nil, verbose: false, random_seed: nil)
74
74
  check_params_numeric(learning_rate: learning_rate, momentum: momentum,
@@ -1,21 +1,24 @@
1
1
  # frozen_string_literal: true
2
2
 
3
- require 'rumale/linear_model/base_sgd'
3
+ require 'lbfgsb'
4
4
  require 'rumale/base/classifier'
5
+ require 'rumale/linear_model/base_sgd'
6
+ require 'rumale/preprocessing/label_binarizer'
5
7
 
6
8
  module Rumale
7
9
  module LinearModel
8
- # LogisticRegression is a class that implements Logistic Regression
9
- # with stochastic gradient descent optimization.
10
- # For multiclass classification problem, it uses one-vs-the-rest strategy.
10
+ # LogisticRegression is a class that implements Logistic Regression.
11
+ # In multiclass classification problem, it uses one-vs-the-rest strategy for the sgd solver
12
+ # and multinomial logistic regression for the lbfgs solver.
11
13
  #
12
- # Rumale::SVM provides Logistic Regression based on LIBLINEAR.
13
- # If you prefer execution speed, you should use Rumale::SVM::LogisticRegression.
14
- # https://github.com/yoshoku/rumale-svm
14
+ # @note
15
+ # Rumale::SVM provides Logistic Regression based on LIBLINEAR.
16
+ # If you prefer execution speed, you should use Rumale::SVM::LogisticRegression.
17
+ # https://github.com/yoshoku/rumale-svm
15
18
  #
16
19
  # @example
17
20
  # estimator =
18
- # Rumale::LinearModel::LogisticRegression.new(reg_param: 1.0, max_iter: 200, batch_size: 50, random_seed: 1)
21
+ # Rumale::LinearModel::LogisticRegression.new(reg_param: 1.0, random_seed: 1)
19
22
  # estimator.fit(training_samples, training_labels)
20
23
  # results = estimator.predict(testing_samples)
21
24
  #
@@ -42,19 +45,24 @@ module Rumale
42
45
  # @return [Random]
43
46
  attr_reader :rng
44
47
 
45
- # Create a new classifier with Logisitc Regression by the SGD optimization.
48
+ # Create a new classifier with Logistic Regression.
46
49
  #
47
50
  # @param learning_rate [Float] The initial value of learning rate.
48
51
  # The learning rate decreases as the iteration proceeds according to the equation: learning_rate / (1 + decay * t).
52
+ # If solver = 'lbfgs', this parameter is ignored.
49
53
  # @param decay [Float] The smoothing parameter for decreasing learning rate as the iteration proceeds.
50
54
  # If nil is given, the decay sets to 'reg_param * learning_rate'.
55
+ # If solver = 'lbfgs', this parameter is ignored.
51
56
  # @param momentum [Float] The momentum factor.
57
+ # If solver = 'lbfgs', this parameter is ignored.
52
58
  # @param penalty [String] The regularization type to be used ('l1', 'l2', and 'elasticnet').
59
+ # If solver = 'lbfgs', only 'l2' can be selected for this parameter.
53
60
  # @param l1_ratio [Float] The elastic-net type regularization mixing parameter.
54
61
  # If penalty set to 'l2' or 'l1', this parameter is ignored.
55
62
  # If l1_ratio = 1, the regularization is similar to Lasso.
56
63
  # If l1_ratio = 0, the regularization is similar to Ridge.
57
64
  # If 0 < l1_ratio < 1, the regularization is a combination of L1 and L2.
65
+ # If solver = 'lbfgs', this parameter is ignored.
58
66
  # @param reg_param [Float] The regularization parameter.
59
67
  # @param fit_bias [Boolean] The flag indicating whether to fit the bias term.
60
68
  # @param bias_scale [Float] The scale of the bias term.
@@ -62,28 +70,38 @@ module Rumale
62
70
  # @param max_iter [Integer] The maximum number of epochs that indicates
63
71
  # how many times the whole data is given to the training process.
64
72
  # @param batch_size [Integer] The size of the mini batches.
73
+ # If solver = 'lbfgs', this parameter is ignored.
65
74
  # @param tol [Float] The tolerance of loss for terminating optimization.
75
+ # If solver = 'lbfgs', this value is given as tol / Lbfgsb::DBL_EPSILON to the factr argument of Lbfgsb.minimize method.
76
+ # @param solver [String] The algorithm for optimization. ('lbfgs' or 'sgd').
77
+ # 'lbfgs' uses the L-BFGS with lbfgs.rb gem.
78
+ # 'sgd' uses the stochastic gradient descent optimization.
66
79
  # @param n_jobs [Integer] The number of jobs for running the fit and predict methods in parallel.
67
80
  # If nil is given, the methods do not execute in parallel.
68
81
  # If zero or less is given, it becomes equal to the number of processors.
69
- # This parameter is ignored if the Parallel gem is not loaded.
82
+ # This parameter is ignored if the Parallel gem is not loaded or the solver is 'lbfgs'.
70
83
  # @param verbose [Boolean] The flag indicating whether to output loss during iteration.
84
+ # If solver = 'lbfgs' and true is given, 'iterate.dat' file is generated by lbfgsb.rb.
71
85
  # @param random_seed [Integer] The seed value using to initialize the random generator.
72
86
  def initialize(learning_rate: 0.01, decay: nil, momentum: 0.9,
73
87
  penalty: 'l2', reg_param: 1.0, l1_ratio: 0.5,
74
88
  fit_bias: true, bias_scale: 1.0,
75
- max_iter: 200, batch_size: 50, tol: 1e-4,
89
+ max_iter: 1000, batch_size: 50, tol: 1e-4,
90
+ solver: 'lbfgs',
76
91
  n_jobs: nil, verbose: false, random_seed: nil)
77
92
  check_params_numeric(learning_rate: learning_rate, momentum: momentum,
78
93
  reg_param: reg_param, l1_ratio: l1_ratio, bias_scale: bias_scale,
79
94
  max_iter: max_iter, batch_size: batch_size, tol: tol)
80
95
  check_params_boolean(fit_bias: fit_bias, verbose: verbose)
81
- check_params_string(penalty: penalty)
96
+ check_params_string(solver: solver, penalty: penalty)
82
97
  check_params_numeric_or_nil(decay: decay, n_jobs: n_jobs, random_seed: random_seed)
83
98
  check_params_positive(learning_rate: learning_rate, reg_param: reg_param,
84
99
  bias_scale: bias_scale, max_iter: max_iter, batch_size: batch_size)
100
+ raise ArgumentError, "The 'lbfgs' solver supports only 'l2' penalties." if solver == 'lbfgs' && penalty != 'l2'
101
+
85
102
  super()
86
103
  @params.merge!(method(:initialize).parameters.map { |_t, arg| [arg, binding.local_variable_get(arg)] }.to_h)
104
+ @params[:solver] = solver == 'sgd' ? 'sgd' : 'lbfgs'
87
105
  @params[:decay] ||= @params[:reg_param] * @params[:learning_rate]
88
106
  @params[:random_seed] ||= srand
89
107
  @rng = Random.new(@params[:random_seed])
@@ -105,30 +123,10 @@ module Rumale
105
123
  check_sample_label_size(x, y)
106
124
 
107
125
  @classes = Numo::Int32[*y.to_a.uniq.sort]
108
-
109
- if multiclass_problem?
110
- n_classes = @classes.size
111
- n_features = x.shape[1]
112
- @weight_vec = Numo::DFloat.zeros(n_classes, n_features)
113
- @bias_term = Numo::DFloat.zeros(n_classes)
114
- if enable_parallel?
115
- # :nocov:
116
- models = parallel_map(n_classes) do |n|
117
- bin_y = Numo::Int32.cast(y.eq(@classes[n])) * 2 - 1
118
- partial_fit(x, bin_y)
119
- end
120
- # :nocov:
121
- n_classes.times { |n| @weight_vec[n, true], @bias_term[n] = models[n] }
122
- else
123
- n_classes.times do |n|
124
- bin_y = Numo::Int32.cast(y.eq(@classes[n])) * 2 - 1
125
- @weight_vec[n, true], @bias_term[n] = partial_fit(x, bin_y)
126
- end
127
- end
126
+ if @params[:solver] == 'sgd'
127
+ fit_sgd(x, y)
128
128
  else
129
- negative_label = @classes[0]
130
- bin_y = Numo::Int32.cast(y.ne(negative_label)) * 2 - 1
131
- @weight_vec, @bias_term = partial_fit(x, bin_y)
129
+ fit_lbfgs(x, y)
132
130
  end
133
131
 
134
132
  self
@@ -182,6 +180,96 @@ module Rumale
182
180
  def multiclass_problem?
183
181
  @classes.size > 2
184
182
  end
183
+
184
+ def fit_lbfgs(base_x, base_y)
185
+ if multiclass_problem?
186
+ fnc = proc do |w, x, y, a|
187
+ n_features = x.shape[1]
188
+ n_classes = y.shape[1]
189
+ z = x.dot(w.reshape(n_classes, n_features).transpose)
190
+ # logsumexp and softmax
191
+ z_max = z.max(-1).expand_dims(-1).dup
192
+ z_max[~z_max.isfinite] = 0.0
193
+ lgsexp = Numo::NMath.log(Numo::NMath.exp(z - z_max).sum(-1)).expand_dims(-1) + z_max
194
+ t = z - lgsexp
195
+ sftmax = Numo::NMath.exp(t)
196
+ # loss and gradient
197
+ loss = -(y * t).sum + 0.5 * a * w.dot(w)
198
+ grad = (sftmax - y).transpose.dot(x).flatten.dup + a * w
199
+ [loss, grad]
200
+ end
201
+
202
+ base_x = expand_feature(base_x) if fit_bias?
203
+ encoder = Rumale::Preprocessing::LabelBinarizer.new
204
+ onehot_y = encoder.fit_transform(base_y)
205
+ n_classes = @classes.size
206
+ n_features = base_x.shape[1]
207
+ w_init = Numo::DFloat.zeros(n_classes * n_features)
208
+
209
+ verbose = @params[:verbose] ? 1 : -1
210
+ res = Lbfgsb.minimize(
211
+ fnc: fnc, jcb: true, x_init: w_init, args: [base_x, onehot_y, @params[:reg_param]],
212
+ maxiter: @params[:max_iter], factr: @params[:tol] / Lbfgsb::DBL_EPSILON, verbose: verbose
213
+ )
214
+
215
+ if fit_bias?
216
+ weight = res[:x].reshape(n_classes, n_features)
217
+ @weight_vec = weight[true, 0...-1].dup
218
+ @bias_term = weight[true, -1].dup
219
+ else
220
+ @weight_vec = res[:x].reshape(n_classes, n_features)
221
+ @bias_term = Numo::DFloat.zeros(n_classes)
222
+ end
223
+ else
224
+ fnc = proc do |w, x, y, a|
225
+ z = 1 + Numo::NMath.exp(-y * x.dot(w))
226
+ loss = Numo::NMath.log(z).sum + 0.5 * a * w.dot(w)
227
+ grad = (y / z - y).dot(x) + a * w
228
+ [loss, grad]
229
+ end
230
+
231
+ base_x = expand_feature(base_x) if fit_bias?
232
+ negative_label = @classes[0]
233
+ bin_y = Numo::Int32.cast(base_y.ne(negative_label)) * 2 - 1
234
+ n_features = base_x.shape[1]
235
+ w_init = Numo::DFloat.zeros(n_features)
236
+
237
+ verbose = @params[:verbose] ? 1 : -1
238
+ res = Lbfgsb.minimize(
239
+ fnc: fnc, jcb: true, x_init: w_init, args: [base_x, bin_y, @params[:reg_param]],
240
+ maxiter: @params[:max_iter], factr: @params[:tol] / Lbfgsb::DBL_EPSILON, verbose: verbose
241
+ )
242
+
243
+ @weight_vec, @bias_term = split_weight(res[:x])
244
+ end
245
+ end
246
+
247
+ def fit_sgd(x, y)
248
+ if multiclass_problem?
249
+ n_classes = @classes.size
250
+ n_features = x.shape[1]
251
+ @weight_vec = Numo::DFloat.zeros(n_classes, n_features)
252
+ @bias_term = Numo::DFloat.zeros(n_classes)
253
+ if enable_parallel?
254
+ # :nocov:
255
+ models = parallel_map(n_classes) do |n|
256
+ bin_y = Numo::Int32.cast(y.eq(@classes[n])) * 2 - 1
257
+ partial_fit(x, bin_y)
258
+ end
259
+ # :nocov:
260
+ n_classes.times { |n| @weight_vec[n, true], @bias_term[n] = models[n] }
261
+ else
262
+ n_classes.times do |n|
263
+ bin_y = Numo::Int32.cast(y.eq(@classes[n])) * 2 - 1
264
+ @weight_vec[n, true], @bias_term[n] = partial_fit(x, bin_y)
265
+ end
266
+ end
267
+ else
268
+ negative_label = @classes[0]
269
+ bin_y = Numo::Int32.cast(y.ne(negative_label)) * 2 - 1
270
+ @weight_vec, @bias_term = partial_fit(x, bin_y)
271
+ end
272
+ end
185
273
  end
186
274
  end
187
275
  end