rumale 0.22.5 → 0.23.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: '058078489d3ff66d67432e1418ae786292c263e05e75b6703fb5a7e65e88bd46'
4
- data.tar.gz: bd7ed9b223e0cd0074ffdd3e521b01c195f82909013c93f4736ab338d5920c96
3
+ metadata.gz: 0f66b61c7ae0d76fdccc2b95c8f1ae3ea181c1622024a0dd93a15dea17e9a632
4
+ data.tar.gz: 1bf6934f79110b4bc59528bc6d656250f6e3fcb7834c9bf262e53f42fae78b69
5
5
  SHA512:
6
- metadata.gz: 79ce4715a503b1b5a618526832adad5912daac72af3e8f1892ff2df14b7695e546419d6f85e5ea735abd2ea06da649a763f96c82b95eee75341934fa65fce93e
7
- data.tar.gz: 5948583ec6c5ca10b320e09c447f9cc1e244dd3bfbf95800338dba2eb7e1b46a342d44020fa24c26a7161786feb63c82b0677654e4aa087c311337a606880a22
6
+ metadata.gz: da5c9c6463d3fbefc2d48628b053379cfc7e71195780dcbceddefadd0211f15de509911a96e722d464640e9da1107be8298354f0df464ba63ca7e385632adcc8
7
+ data.tar.gz: a54bce60f8d9c0f65a4ea2899d900b7fd7069181d5e76461dc48820314c603c3069323d96ceb8851d57910802b2c8f361a2c1c5f7345159736622cd23145fc87
data/CHANGELOG.md CHANGED
@@ -1,3 +1,11 @@
1
+ # 0.23.0
2
+ ## Breaking change
3
+ - Change automatically selected solver from sgd to lbfgs in
4
+ [LinearRegression](https://yoshoku.github.io/rumale/doc/Rumale/LinearModel/LinearRegression.html) and
5
+ [Ridge](https://yoshoku.github.io/rumale/doc/Rumale/LinearModel/Ridge.html).
6
+ - When 'auto' is given as the solver parameter, these estimators select the 'svd' solver if Numo::Linalg is loaded.
7
+ Otherwise, they select the 'lbfgs' solver.
8
+
1
9
  # 0.22.5
2
10
  - Add transformer class for calculating kernel matrix.
3
11
  - [KernelCalculator](https://yoshoku.github.io/rumale/doc/Rumale/Preprocessing/KernelCalculator.html)
@@ -1,5 +1,7 @@
1
1
  # frozen_string_literal: true
2
2
 
3
+ require 'lbfgsb'
4
+
3
5
  require 'rumale/linear_model/base_sgd'
4
6
  require 'rumale/base/regressor'
5
7
 
@@ -58,7 +60,7 @@ module Rumale
58
60
  # @param tol [Float] The tolerance of loss for terminating optimization.
59
61
  # If solver is 'svd', this parameter is ignored.
60
62
  # @param solver [String] The algorithm to calculate weights. ('auto', 'sgd', 'svd' or 'lbfgs').
61
- # 'auto' chooses the 'svd' solver if Numo::Linalg is loaded. Otherwise, it chooses the 'sgd' solver.
63
+ # 'auto' chooses the 'svd' solver if Numo::Linalg is loaded. Otherwise, it chooses the 'lbfgs' solver.
62
64
  # 'sgd' uses the stochastic gradient descent optimization.
63
65
  # 'svd' performs singular value decomposition of samples.
64
66
  # 'lbfgs' uses the L-BFGS method for optimization.
@@ -82,9 +84,9 @@ module Rumale
82
84
  super()
83
85
  @params.merge!(method(:initialize).parameters.map { |_t, arg| [arg, binding.local_variable_get(arg)] }.to_h)
84
86
  @params[:solver] = if solver == 'auto'
85
- enable_linalg?(warning: false) ? 'svd' : 'sgd'
87
+ enable_linalg?(warning: false) ? 'svd' : 'lbfgs'
86
88
  else
87
- solver.match?(/^svd$|^sgd$|^lbfgs$/) ? solver : 'sgd'
89
+ solver.match?(/^svd$|^sgd$|^lbfgs$/) ? solver : 'lbfgs'
88
90
  end
89
91
  @params[:decay] ||= @params[:learning_rate]
90
92
  @params[:random_seed] ||= srand
@@ -61,7 +61,7 @@ module Rumale
61
61
  # @param tol [Float] The tolerance of loss for terminating optimization.
62
62
  # If solver is 'svd', this parameter is ignored.
63
63
  # @param solver [String] The algorithm to calculate weights. ('auto', 'sgd', 'svd', or 'lbfgs').
64
- # 'auto' chooses the 'svd' solver if Numo::Linalg is loaded. Otherwise, it chooses the 'sgd' solver.
64
+ # 'auto' chooses the 'svd' solver if Numo::Linalg is loaded. Otherwise, it chooses the 'lbfgs' solver.
65
65
  # 'sgd' uses the stochastic gradient descent optimization.
66
66
  # 'svd' performs singular value decomposition of samples.
67
67
  # 'lbfgs' uses the L-BFGS method for optimization.
@@ -87,9 +87,9 @@ module Rumale
87
87
  super()
88
88
  @params.merge!(method(:initialize).parameters.map { |_t, arg| [arg, binding.local_variable_get(arg)] }.to_h)
89
89
  @params[:solver] = if solver == 'auto'
90
- enable_linalg?(warning: false) ? 'svd' : 'sgd'
90
+ enable_linalg?(warning: false) ? 'svd' : 'lbfgs'
91
91
  else
92
- solver.match?(/^svd$|^sgd$|^lbfgs$/) ? solver : 'sgd'
92
+ solver.match?(/^svd$|^sgd$|^lbfgs$/) ? solver : 'lbfgs'
93
93
  end
94
94
  @params[:decay] ||= @params[:reg_param] * @params[:learning_rate]
95
95
  @params[:random_seed] ||= srand
@@ -160,15 +160,15 @@ module Rumale
160
160
  grid = [grid] if grid.is_a?(Hash)
161
161
  grid.each do |h|
162
162
  raise TypeError, 'Expect class of elements in param_grid to be Hash' unless h.is_a?(Hash)
163
- raise TypeError, 'Expect class of parameter values in param_grid to be Array' unless h.values.all? { |v| v.is_a?(Array) }
163
+ raise TypeError, 'Expect class of parameter values in param_grid to be Array' unless h.values.all?(Array)
164
164
  end
165
165
  grid
166
166
  end
167
167
 
168
168
  def param_combinations
169
169
  @param_combinations ||= @params[:param_grid].map do |prm|
170
- x = Hash[prm.sort].map { |k, v| [k].product(v) }
171
- x[0].product(*x[1...x.size]).map { |v| Hash[v] }
170
+ x = prm.sort.to_h.map { |k, v| [k].product(v) }
171
+ x[0].product(*x[1...x.size]).map(&:to_h)
172
172
  end
173
173
  end
174
174
 
@@ -3,5 +3,5 @@
3
3
  # Rumale is a machine learning library in Ruby.
4
4
  module Rumale
5
5
  # The version of Rumale you are using.
6
- VERSION = '0.22.5'
6
+ VERSION = '0.23.0'
7
7
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rumale
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.22.5
4
+ version: 0.23.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2021-03-14 00:00:00.000000000 Z
11
+ date: 2021-04-04 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray