rumale 0.22.5 → 0.23.0
Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 0f66b61c7ae0d76fdccc2b95c8f1ae3ea181c1622024a0dd93a15dea17e9a632
|
4
|
+
data.tar.gz: 1bf6934f79110b4bc59528bc6d656250f6e3fcb7834c9bf262e53f42fae78b69
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: da5c9c6463d3fbefc2d48628b053379cfc7e71195780dcbceddefadd0211f15de509911a96e722d464640e9da1107be8298354f0df464ba63ca7e385632adcc8
|
7
|
+
data.tar.gz: a54bce60f8d9c0f65a4ea2899d900b7fd7069181d5e76461dc48820314c603c3069323d96ceb8851d57910802b2c8f361a2c1c5f7345159736622cd23145fc87
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,11 @@
|
|
1
|
+
# 0.23.0
|
2
|
+
## Breaking change
|
3
|
+
- Change the automatically selected solver from 'sgd' to 'lbfgs' in
|
4
|
+
[LinearRegression](https://yoshoku.github.io/rumale/doc/Rumale/LinearModel/LinearRegression.html) and
|
5
|
+
[Ridge](https://yoshoku.github.io/rumale/doc/Rumale/LinearModel/Ridge.html).
|
6
|
+
- When 'auto' is given to the solver parameter, these estimators select the 'svd' solver if Numo::Linalg is loaded.
|
7
|
+
Otherwise, they select the 'lbfgs' solver.
|
8
|
+
|
1
9
|
# 0.22.5
|
2
10
|
- Add transformer class for calculating kernel matrix.
|
3
11
|
- [KernelCalculator](https://yoshoku.github.io/rumale/doc/Rumale/Preprocessing/KernelCalculator.html)
|
@@ -1,5 +1,7 @@
|
|
1
1
|
# frozen_string_literal: true
|
2
2
|
|
3
|
+
require 'lbfgsb'
|
4
|
+
|
3
5
|
require 'rumale/linear_model/base_sgd'
|
4
6
|
require 'rumale/base/regressor'
|
5
7
|
|
@@ -58,7 +60,7 @@ module Rumale
|
|
58
60
|
# @param tol [Float] The tolerance of loss for terminating optimization.
|
59
61
|
# If solver is 'svd', this parameter is ignored.
|
60
62
|
# @param solver [String] The algorithm to calculate weights. ('auto', 'sgd', 'svd' or 'lbfgs').
|
61
|
-
# 'auto' chooses the 'svd' solver if Numo::Linalg is loaded. Otherwise, it chooses the '
|
63
|
+
# 'auto' chooses the 'svd' solver if Numo::Linalg is loaded. Otherwise, it chooses the 'lbfgs' solver.
|
62
64
|
# 'sgd' uses the stochastic gradient descent optimization.
|
63
65
|
# 'svd' performs singular value decomposition of samples.
|
64
66
|
# 'lbfgs' uses the L-BFGS method for optimization.
|
@@ -82,9 +84,9 @@ module Rumale
|
|
82
84
|
super()
|
83
85
|
@params.merge!(method(:initialize).parameters.map { |_t, arg| [arg, binding.local_variable_get(arg)] }.to_h)
|
84
86
|
@params[:solver] = if solver == 'auto'
|
85
|
-
enable_linalg?(warning: false) ? 'svd' : '
|
87
|
+
enable_linalg?(warning: false) ? 'svd' : 'lbfgs'
|
86
88
|
else
|
87
|
-
solver.match?(/^svd$|^sgd$|^lbfgs$/) ? solver : '
|
89
|
+
solver.match?(/^svd$|^sgd$|^lbfgs$/) ? solver : 'lbfgs'
|
88
90
|
end
|
89
91
|
@params[:decay] ||= @params[:learning_rate]
|
90
92
|
@params[:random_seed] ||= srand
|
@@ -61,7 +61,7 @@ module Rumale
|
|
61
61
|
# @param tol [Float] The tolerance of loss for terminating optimization.
|
62
62
|
# If solver is 'svd', this parameter is ignored.
|
63
63
|
# @param solver [String] The algorithm to calculate weights. ('auto', 'sgd', 'svd', or 'lbfgs').
|
64
|
-
# 'auto' chooses the 'svd' solver if Numo::Linalg is loaded. Otherwise, it chooses the '
|
64
|
+
# 'auto' chooses the 'svd' solver if Numo::Linalg is loaded. Otherwise, it chooses the 'lbfgs' solver.
|
65
65
|
# 'sgd' uses the stochastic gradient descent optimization.
|
66
66
|
# 'svd' performs singular value decomposition of samples.
|
67
67
|
# 'lbfgs' uses the L-BFGS method for optimization.
|
@@ -87,9 +87,9 @@ module Rumale
|
|
87
87
|
super()
|
88
88
|
@params.merge!(method(:initialize).parameters.map { |_t, arg| [arg, binding.local_variable_get(arg)] }.to_h)
|
89
89
|
@params[:solver] = if solver == 'auto'
|
90
|
-
enable_linalg?(warning: false) ? 'svd' : '
|
90
|
+
enable_linalg?(warning: false) ? 'svd' : 'lbfgs'
|
91
91
|
else
|
92
|
-
solver.match?(/^svd$|^sgd$|^lbfgs$/) ? solver : '
|
92
|
+
solver.match?(/^svd$|^sgd$|^lbfgs$/) ? solver : 'lbfgs'
|
93
93
|
end
|
94
94
|
@params[:decay] ||= @params[:reg_param] * @params[:learning_rate]
|
95
95
|
@params[:random_seed] ||= srand
|
@@ -160,15 +160,15 @@ module Rumale
|
|
160
160
|
grid = [grid] if grid.is_a?(Hash)
|
161
161
|
grid.each do |h|
|
162
162
|
raise TypeError, 'Expect class of elements in param_grid to be Hash' unless h.is_a?(Hash)
|
163
|
-
raise TypeError, 'Expect class of parameter values in param_grid to be Array' unless h.values.all?
|
163
|
+
raise TypeError, 'Expect class of parameter values in param_grid to be Array' unless h.values.all?(Array)
|
164
164
|
end
|
165
165
|
grid
|
166
166
|
end
|
167
167
|
|
168
168
|
def param_combinations
|
169
169
|
@param_combinations ||= @params[:param_grid].map do |prm|
|
170
|
-
x =
|
171
|
-
x[0].product(*x[1...x.size]).map
|
170
|
+
x = prm.sort.to_h.map { |k, v| [k].product(v) }
|
171
|
+
x[0].product(*x[1...x.size]).map(&:to_h)
|
172
172
|
end
|
173
173
|
end
|
174
174
|
|
data/lib/rumale/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: rumale
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.23.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- yoshoku
|
8
8
|
autorequire:
|
9
9
|
bindir: exe
|
10
10
|
cert_chain: []
|
11
|
-
date: 2021-
|
11
|
+
date: 2021-04-04 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|