rumale 0.22.5 → 0.23.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: '058078489d3ff66d67432e1418ae786292c263e05e75b6703fb5a7e65e88bd46'
4
- data.tar.gz: bd7ed9b223e0cd0074ffdd3e521b01c195f82909013c93f4736ab338d5920c96
3
+ metadata.gz: 0f66b61c7ae0d76fdccc2b95c8f1ae3ea181c1622024a0dd93a15dea17e9a632
4
+ data.tar.gz: 1bf6934f79110b4bc59528bc6d656250f6e3fcb7834c9bf262e53f42fae78b69
5
5
  SHA512:
6
- metadata.gz: 79ce4715a503b1b5a618526832adad5912daac72af3e8f1892ff2df14b7695e546419d6f85e5ea735abd2ea06da649a763f96c82b95eee75341934fa65fce93e
7
- data.tar.gz: 5948583ec6c5ca10b320e09c447f9cc1e244dd3bfbf95800338dba2eb7e1b46a342d44020fa24c26a7161786feb63c82b0677654e4aa087c311337a606880a22
6
+ metadata.gz: da5c9c6463d3fbefc2d48628b053379cfc7e71195780dcbceddefadd0211f15de509911a96e722d464640e9da1107be8298354f0df464ba63ca7e385632adcc8
7
+ data.tar.gz: a54bce60f8d9c0f65a4ea2899d900b7fd7069181d5e76461dc48820314c603c3069323d96ceb8851d57910802b2c8f361a2c1c5f7345159736622cd23145fc87
data/CHANGELOG.md CHANGED
@@ -1,3 +1,11 @@
1
+ # 0.23.0
2
+ ## Breaking change
3
+ - Change the automatically selected solver from sgd to lbfgs in
4
+ [LinearRegression](https://yoshoku.github.io/rumale/doc/Rumale/LinearModel/LinearRegression.html) and
5
+ [Ridge](https://yoshoku.github.io/rumale/doc/Rumale/LinearModel/Ridge.html).
6
+ - When 'auto' is given as the solver parameter, these estimators select the 'svd' solver if Numo::Linalg is loaded.
7
+ Otherwise, they select the 'lbfgs' solver.
8
+
1
9
  # 0.22.5
2
10
  - Add transformer class for calculating kernel matrix.
3
11
  - [KernelCalculator](https://yoshoku.github.io/rumale/doc/Rumale/Preprocessing/KernelCalculator.html)
@@ -1,5 +1,7 @@
1
1
  # frozen_string_literal: true
2
2
 
3
+ require 'lbfgsb'
4
+
3
5
  require 'rumale/linear_model/base_sgd'
4
6
  require 'rumale/base/regressor'
5
7
 
@@ -58,7 +60,7 @@ module Rumale
58
60
  # @param tol [Float] The tolerance of loss for terminating optimization.
59
61
  # If solver is 'svd', this parameter is ignored.
60
62
  # @param solver [String] The algorithm to calculate weights. ('auto', 'sgd', 'svd' or 'lbfgs').
61
- # 'auto' chooses the 'svd' solver if Numo::Linalg is loaded. Otherwise, it chooses the 'sgd' solver.
63
+ # 'auto' chooses the 'svd' solver if Numo::Linalg is loaded. Otherwise, it chooses the 'lbfgs' solver.
62
64
  # 'sgd' uses the stochastic gradient descent optimization.
63
65
  # 'svd' performs singular value decomposition of samples.
64
66
  # 'lbfgs' uses the L-BFGS method for optimization.
@@ -82,9 +84,9 @@ module Rumale
82
84
  super()
83
85
  @params.merge!(method(:initialize).parameters.map { |_t, arg| [arg, binding.local_variable_get(arg)] }.to_h)
84
86
  @params[:solver] = if solver == 'auto'
85
- enable_linalg?(warning: false) ? 'svd' : 'sgd'
87
+ enable_linalg?(warning: false) ? 'svd' : 'lbfgs'
86
88
  else
87
- solver.match?(/^svd$|^sgd$|^lbfgs$/) ? solver : 'sgd'
89
+ solver.match?(/^svd$|^sgd$|^lbfgs$/) ? solver : 'lbfgs'
88
90
  end
89
91
  @params[:decay] ||= @params[:learning_rate]
90
92
  @params[:random_seed] ||= srand
@@ -61,7 +61,7 @@ module Rumale
61
61
  # @param tol [Float] The tolerance of loss for terminating optimization.
62
62
  # If solver is 'svd', this parameter is ignored.
63
63
  # @param solver [String] The algorithm to calculate weights. ('auto', 'sgd', 'svd', or 'lbfgs').
64
- # 'auto' chooses the 'svd' solver if Numo::Linalg is loaded. Otherwise, it chooses the 'sgd' solver.
64
+ # 'auto' chooses the 'svd' solver if Numo::Linalg is loaded. Otherwise, it chooses the 'lbfgs' solver.
65
65
  # 'sgd' uses the stochastic gradient descent optimization.
66
66
  # 'svd' performs singular value decomposition of samples.
67
67
  # 'lbfgs' uses the L-BFGS method for optimization.
@@ -87,9 +87,9 @@ module Rumale
87
87
  super()
88
88
  @params.merge!(method(:initialize).parameters.map { |_t, arg| [arg, binding.local_variable_get(arg)] }.to_h)
89
89
  @params[:solver] = if solver == 'auto'
90
- enable_linalg?(warning: false) ? 'svd' : 'sgd'
90
+ enable_linalg?(warning: false) ? 'svd' : 'lbfgs'
91
91
  else
92
- solver.match?(/^svd$|^sgd$|^lbfgs$/) ? solver : 'sgd'
92
+ solver.match?(/^svd$|^sgd$|^lbfgs$/) ? solver : 'lbfgs'
93
93
  end
94
94
  @params[:decay] ||= @params[:reg_param] * @params[:learning_rate]
95
95
  @params[:random_seed] ||= srand
@@ -160,15 +160,15 @@ module Rumale
160
160
  grid = [grid] if grid.is_a?(Hash)
161
161
  grid.each do |h|
162
162
  raise TypeError, 'Expect class of elements in param_grid to be Hash' unless h.is_a?(Hash)
163
- raise TypeError, 'Expect class of parameter values in param_grid to be Array' unless h.values.all? { |v| v.is_a?(Array) }
163
+ raise TypeError, 'Expect class of parameter values in param_grid to be Array' unless h.values.all?(Array)
164
164
  end
165
165
  grid
166
166
  end
167
167
 
168
168
  def param_combinations
169
169
  @param_combinations ||= @params[:param_grid].map do |prm|
170
- x = Hash[prm.sort].map { |k, v| [k].product(v) }
171
- x[0].product(*x[1...x.size]).map { |v| Hash[v] }
170
+ x = prm.sort.to_h.map { |k, v| [k].product(v) }
171
+ x[0].product(*x[1...x.size]).map(&:to_h)
172
172
  end
173
173
  end
174
174
 
@@ -3,5 +3,5 @@
3
3
  # Rumale is a machine learning library in Ruby.
4
4
  module Rumale
5
5
  # The version of Rumale you are using.
6
- VERSION = '0.22.5'
6
+ VERSION = '0.23.0'
7
7
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rumale
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.22.5
4
+ version: 0.23.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2021-03-14 00:00:00.000000000 Z
11
+ date: 2021-04-04 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray