rumale 0.19.3 → 0.21.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.rubocop.yml +15 -95
- data/CHANGELOG.md +27 -0
- data/Gemfile +3 -0
- data/README.md +4 -0
- data/lib/rumale.rb +3 -10
- data/lib/rumale/clustering/hdbscan.rb +2 -2
- data/lib/rumale/clustering/snn.rb +1 -1
- data/lib/rumale/dataset.rb +1 -1
- data/lib/rumale/decomposition/nmf.rb +2 -2
- data/lib/rumale/ensemble/random_forest_classifier.rb +1 -1
- data/lib/rumale/ensemble/random_forest_regressor.rb +1 -1
- data/lib/rumale/evaluation_measure/roc_auc.rb +3 -0
- data/lib/rumale/feature_extraction/feature_hasher.rb +1 -1
- data/lib/rumale/feature_extraction/hash_vectorizer.rb +1 -1
- data/lib/rumale/linear_model/base_sgd.rb +1 -1
- data/lib/rumale/linear_model/elastic_net.rb +2 -2
- data/lib/rumale/linear_model/lasso.rb +2 -2
- data/lib/rumale/linear_model/linear_regression.rb +2 -2
- data/lib/rumale/linear_model/logistic_regression.rb +2 -2
- data/lib/rumale/linear_model/ridge.rb +2 -2
- data/lib/rumale/linear_model/svc.rb +2 -2
- data/lib/rumale/linear_model/svr.rb +2 -2
- data/lib/rumale/manifold/tsne.rb +1 -1
- data/lib/rumale/metric_learning/neighbourhood_component_analysis.rb +1 -1
- data/lib/rumale/model_selection/group_k_fold.rb +93 -0
- data/lib/rumale/model_selection/group_shuffle_split.rb +115 -0
- data/lib/rumale/model_selection/shuffle_split.rb +4 -4
- data/lib/rumale/model_selection/stratified_k_fold.rb +1 -1
- data/lib/rumale/model_selection/stratified_shuffle_split.rb +13 -9
- data/lib/rumale/model_selection/time_series_split.rb +91 -0
- data/lib/rumale/pipeline/pipeline.rb +1 -1
- data/lib/rumale/probabilistic_output.rb +1 -1
- data/lib/rumale/tree/base_decision_tree.rb +2 -9
- data/lib/rumale/tree/gradient_tree_regressor.rb +3 -10
- data/lib/rumale/version.rb +1 -1
- metadata +5 -12
- data/lib/rumale/linear_model/base_linear_model.rb +0 -102
- data/lib/rumale/optimizer/ada_grad.rb +0 -42
- data/lib/rumale/optimizer/adam.rb +0 -56
- data/lib/rumale/optimizer/nadam.rb +0 -67
- data/lib/rumale/optimizer/rmsprop.rb +0 -50
- data/lib/rumale/optimizer/sgd.rb +0 -46
- data/lib/rumale/optimizer/yellow_fin.rb +0 -104
- data/lib/rumale/polynomial_model/base_factorization_machine.rb +0 -125
- data/lib/rumale/polynomial_model/factorization_machine_classifier.rb +0 -220
- data/lib/rumale/polynomial_model/factorization_machine_regressor.rb +0 -134
checksums.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
---
|
|
2
2
|
SHA256:
|
|
3
|
-
metadata.gz:
|
|
4
|
-
data.tar.gz:
|
|
3
|
+
metadata.gz: fd5ca16629a5258be9e577771dc8c6b42dbfdf3a60a4c43d4ee170cc17b72bea
|
|
4
|
+
data.tar.gz: 70b26bcf0b39bb8e716b9bbb4aba100c496152b1c4e71879105046440c7d8758
|
|
5
5
|
SHA512:
|
|
6
|
-
metadata.gz:
|
|
7
|
-
data.tar.gz:
|
|
6
|
+
metadata.gz: 645a6bda6e3601534c69f5ecfbd840c1d6c1ed7a5a3b8bd57995621a03d970cd02e9749a8a70be5af2678c029a26a5e6e1c32376a4514a64e96d6a9b4b12aa3e
|
|
7
|
+
data.tar.gz: 5904c64da9cc30cf0c288dfbeb3051bca3333e588e153cf19619c169d713e93edcb95e6902134e58c823d621ebad9e6a56310123c110b6d810033f5f96a40fbb
|
data/.rubocop.yml
CHANGED
|
@@ -3,6 +3,7 @@ require:
|
|
|
3
3
|
- rubocop-rspec
|
|
4
4
|
|
|
5
5
|
AllCops:
|
|
6
|
+
NewCops: enable
|
|
6
7
|
TargetRubyVersion: 2.5
|
|
7
8
|
DisplayCopNames: true
|
|
8
9
|
DisplayStyleGuide: true
|
|
@@ -15,34 +16,16 @@ AllCops:
|
|
|
15
16
|
Style/Documentation:
|
|
16
17
|
Enabled: false
|
|
17
18
|
|
|
18
|
-
Style/HashEachMethods:
|
|
19
|
-
Enabled: true
|
|
20
|
-
|
|
21
|
-
Style/HashTransformKeys:
|
|
22
|
-
Enabled: true
|
|
23
|
-
|
|
24
|
-
Style/HashTransformValues:
|
|
25
|
-
Enabled: true
|
|
26
|
-
|
|
27
|
-
Lint/DeprecatedOpenSSLConstant:
|
|
28
|
-
Enabled: true
|
|
29
|
-
|
|
30
|
-
Lint/DuplicateElsifCondition:
|
|
31
|
-
Enabled: true
|
|
32
|
-
|
|
33
|
-
Lint/MixedRegexpCaptureTypes:
|
|
34
|
-
Enabled: true
|
|
35
|
-
|
|
36
|
-
Lint/RaiseException:
|
|
37
|
-
Enabled: true
|
|
38
|
-
|
|
39
|
-
Lint/StructNewOverride:
|
|
40
|
-
Enabled: true
|
|
41
|
-
|
|
42
19
|
Layout/LineLength:
|
|
43
20
|
Max: 145
|
|
44
21
|
IgnoredPatterns: ['(\A|\s)#']
|
|
45
22
|
|
|
23
|
+
Lint/ConstantDefinitionInBlock:
|
|
24
|
+
Enabled: false
|
|
25
|
+
|
|
26
|
+
Lint/MissingSuper:
|
|
27
|
+
Enabled: false
|
|
28
|
+
|
|
46
29
|
Metrics/ModuleLength:
|
|
47
30
|
Max: 200
|
|
48
31
|
|
|
@@ -78,87 +61,21 @@ Naming/MethodParameterName:
|
|
|
78
61
|
Naming/ConstantName:
|
|
79
62
|
Enabled: false
|
|
80
63
|
|
|
81
|
-
Style/AccessorGrouping:
|
|
82
|
-
Enabled: true
|
|
83
|
-
|
|
84
|
-
Style/ArrayCoercion:
|
|
85
|
-
Enabled: true
|
|
86
|
-
|
|
87
|
-
Style/BisectedAttrAccessor:
|
|
88
|
-
Enabled: true
|
|
89
|
-
|
|
90
|
-
Style/CaseLikeIf:
|
|
91
|
-
Enabled: true
|
|
92
|
-
|
|
93
|
-
Style/ExponentialNotation:
|
|
94
|
-
Enabled: true
|
|
95
|
-
|
|
96
64
|
Style/FormatStringToken:
|
|
97
65
|
Enabled: false
|
|
98
66
|
|
|
99
|
-
Style/HashAsLastArrayItem:
|
|
100
|
-
Enabled: true
|
|
101
|
-
|
|
102
|
-
Style/HashLikeCase:
|
|
103
|
-
Enabled: true
|
|
104
|
-
|
|
105
67
|
Style/NumericLiterals:
|
|
106
68
|
Enabled: false
|
|
107
69
|
|
|
108
|
-
Style/
|
|
109
|
-
Enabled:
|
|
110
|
-
|
|
111
|
-
Style/RedundantFetchBlock:
|
|
112
|
-
Enabled: true
|
|
113
|
-
|
|
114
|
-
Style/RedundantFileExtensionInRequire:
|
|
115
|
-
Enabled: true
|
|
116
|
-
|
|
117
|
-
Style/RedundantRegexpCharacterClass:
|
|
118
|
-
Enabled: true
|
|
119
|
-
|
|
120
|
-
Style/RedundantRegexpEscape:
|
|
121
|
-
Enabled: true
|
|
122
|
-
|
|
123
|
-
Style/SlicingWithRange:
|
|
124
|
-
Enabled: true
|
|
125
|
-
|
|
126
|
-
Layout/EmptyLineAfterGuardClause:
|
|
127
|
-
Enabled: true
|
|
128
|
-
|
|
129
|
-
Layout/EmptyLinesAroundAttributeAccessor:
|
|
130
|
-
Enabled: true
|
|
131
|
-
|
|
132
|
-
Layout/SpaceAroundMethodCallOperator:
|
|
133
|
-
Enabled: true
|
|
134
|
-
|
|
135
|
-
Performance/AncestorsInclude:
|
|
136
|
-
Enabled: true
|
|
137
|
-
|
|
138
|
-
Performance/BigDecimalWithNumericArgument:
|
|
139
|
-
Enabled: true
|
|
140
|
-
|
|
141
|
-
Performance/RedundantSortBlock:
|
|
142
|
-
Enabled: true
|
|
143
|
-
|
|
144
|
-
Performance/RedundantStringChars:
|
|
145
|
-
Enabled: true
|
|
146
|
-
|
|
147
|
-
Performance/ReverseFirst:
|
|
148
|
-
Enabled: true
|
|
149
|
-
|
|
150
|
-
Performance/SortReverse:
|
|
151
|
-
Enabled: true
|
|
152
|
-
|
|
153
|
-
Performance/Squeeze:
|
|
154
|
-
Enabled: true
|
|
155
|
-
|
|
156
|
-
Performance/StringInclude:
|
|
157
|
-
Enabled: true
|
|
70
|
+
Style/StringConcatenation:
|
|
71
|
+
Enabled: false
|
|
158
72
|
|
|
159
73
|
RSpec/MultipleExpectations:
|
|
160
74
|
Enabled: false
|
|
161
75
|
|
|
76
|
+
RSpec/MultipleMemoizedHelpers:
|
|
77
|
+
Max: 25
|
|
78
|
+
|
|
162
79
|
RSpec/NestedGroups:
|
|
163
80
|
Max: 4
|
|
164
81
|
|
|
@@ -170,3 +87,6 @@ RSpec/InstanceVariable:
|
|
|
170
87
|
|
|
171
88
|
RSpec/LeakyConstantDeclaration:
|
|
172
89
|
Enabled: false
|
|
90
|
+
|
|
91
|
+
Performance/Sum:
|
|
92
|
+
Enabled: false
|
data/CHANGELOG.md
CHANGED
|
@@ -1,3 +1,29 @@
|
|
|
1
|
+
# 0.21.0
|
|
2
|
+
## Breaking change
|
|
3
|
+
- Change the default value of the max_iter argument of LinearModel estimators to 1000.
|
|
4
|
+
|
|
5
|
+
# 0.20.3
|
|
6
|
+
- Fix NeighbourhoodComponentAnalysis to use the automatic solver of PCA.
|
|
7
|
+
- Refactor some code with RuboCop.
|
|
8
|
+
- Update README.
|
|
9
|
+
|
|
10
|
+
# 0.20.2
|
|
11
|
+
- Add cross-validator class for time-series data.
|
|
12
|
+
- [TimeSeriesSplit](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/TimeSeriesSplit.html)
|
|
13
|
+
|
|
14
|
+
# 0.20.1
|
|
15
|
+
- Add cross-validator classes that split data according to group labels.
|
|
16
|
+
- [GroupKFold](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/GroupKFold.html)
|
|
17
|
+
- [GroupShuffleSplit](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/GroupShuffleSplit.html)
|
|
18
|
+
- Fix the handling of fractional sample counts in shuffle-split cross-validator classes.
|
|
19
|
+
- [ShuffleSplit](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/ShuffleSplit.html)
|
|
20
|
+
- [StratifiedShuffleSplit](https://yoshoku.github.io/rumale/doc/Rumale/ModelSelection/StratifiedShuffleSplit.html)
|
|
21
|
+
- Refactor some code with RuboCop.
|
|
22
|
+
|
|
23
|
+
# 0.20.0
|
|
24
|
+
## Breaking changes
|
|
25
|
+
- Delete deprecated estimators such as PolynomialModel, Optimizer, and BaseLinearModel.
|
|
26
|
+
|
|
1
27
|
# 0.19.3
|
|
2
28
|
- Add preprocessing class for [Binarizer](https://yoshoku.github.io/rumale/doc/Rumale/Preprocessing/Binarizer.html)
|
|
3
29
|
- Add preprocessing class for [MaxNormalizer](https://yoshoku.github.io/rumale/doc/Rumale/Preprocessing/MaxNormalizer.html)
|
|
@@ -13,6 +39,7 @@
|
|
|
13
39
|
- Fix some typos.
|
|
14
40
|
|
|
15
41
|
# 0.19.0
|
|
42
|
+
## Breaking changes
|
|
16
43
|
- Change the mmh3 and mopti gems to non-runtime dependencies.
|
|
17
44
|
- The mmh3 gem is used in [FeatureHasher](https://yoshoku.github.io/rumale/doc/Rumale/FeatureExtraction/FeatureHasher.html).
|
|
18
45
|
You only need to require mmh3 gem when using FeatureHasher.
|
data/Gemfile
CHANGED
data/README.md
CHANGED
|
@@ -228,6 +228,10 @@ When -1 is given to n_jobs parameter, all processors are used.
|
|
|
228
228
|
estimator = Rumale::Ensemble::RandomForestClassifier.new(n_jobs: -1, random_seed: 1)
|
|
229
229
|
```
|
|
230
230
|
|
|
231
|
+
## Novelties
|
|
232
|
+
|
|
233
|
+
* [Rumale SHOP](https://suzuri.jp/yoshoku)
|
|
234
|
+
|
|
231
235
|
## Contributing
|
|
232
236
|
|
|
233
237
|
Bug reports and pull requests are welcome on GitHub at https://github.com/yoshoku/rumale.
|
data/lib/rumale.rb
CHANGED
|
@@ -18,17 +18,10 @@ require 'rumale/base/cluster_analyzer'
|
|
|
18
18
|
require 'rumale/base/transformer'
|
|
19
19
|
require 'rumale/base/splitter'
|
|
20
20
|
require 'rumale/base/evaluator'
|
|
21
|
-
require 'rumale/optimizer/sgd'
|
|
22
|
-
require 'rumale/optimizer/ada_grad'
|
|
23
|
-
require 'rumale/optimizer/rmsprop'
|
|
24
|
-
require 'rumale/optimizer/adam'
|
|
25
|
-
require 'rumale/optimizer/nadam'
|
|
26
|
-
require 'rumale/optimizer/yellow_fin'
|
|
27
21
|
require 'rumale/pipeline/pipeline'
|
|
28
22
|
require 'rumale/pipeline/feature_union'
|
|
29
23
|
require 'rumale/kernel_approximation/rbf'
|
|
30
24
|
require 'rumale/kernel_approximation/nystroem'
|
|
31
|
-
require 'rumale/linear_model/base_linear_model'
|
|
32
25
|
require 'rumale/linear_model/base_sgd'
|
|
33
26
|
require 'rumale/linear_model/svc'
|
|
34
27
|
require 'rumale/linear_model/svr'
|
|
@@ -41,9 +34,6 @@ require 'rumale/kernel_machine/kernel_svc'
|
|
|
41
34
|
require 'rumale/kernel_machine/kernel_pca'
|
|
42
35
|
require 'rumale/kernel_machine/kernel_fda'
|
|
43
36
|
require 'rumale/kernel_machine/kernel_ridge'
|
|
44
|
-
require 'rumale/polynomial_model/base_factorization_machine'
|
|
45
|
-
require 'rumale/polynomial_model/factorization_machine_classifier'
|
|
46
|
-
require 'rumale/polynomial_model/factorization_machine_regressor'
|
|
47
37
|
require 'rumale/multiclass/one_vs_rest_classifier'
|
|
48
38
|
require 'rumale/nearest_neighbors/vp_tree'
|
|
49
39
|
require 'rumale/nearest_neighbors/k_neighbors_classifier'
|
|
@@ -108,9 +98,12 @@ require 'rumale/preprocessing/ordinal_encoder'
|
|
|
108
98
|
require 'rumale/preprocessing/binarizer'
|
|
109
99
|
require 'rumale/preprocessing/polynomial_features'
|
|
110
100
|
require 'rumale/model_selection/k_fold'
|
|
101
|
+
require 'rumale/model_selection/group_k_fold'
|
|
111
102
|
require 'rumale/model_selection/stratified_k_fold'
|
|
112
103
|
require 'rumale/model_selection/shuffle_split'
|
|
104
|
+
require 'rumale/model_selection/group_shuffle_split'
|
|
113
105
|
require 'rumale/model_selection/stratified_shuffle_split'
|
|
106
|
+
require 'rumale/model_selection/time_series_split'
|
|
114
107
|
require 'rumale/model_selection/cross_validation'
|
|
115
108
|
require 'rumale/model_selection/grid_search_cv'
|
|
116
109
|
require 'rumale/model_selection/function'
|
|
@@ -136,7 +136,7 @@ module Rumale
|
|
|
136
136
|
res
|
|
137
137
|
end
|
|
138
138
|
|
|
139
|
-
# rubocop:disable Metrics/AbcSize, Metrics/MethodLength, Metrics/PerceivedComplexity
|
|
139
|
+
# rubocop:disable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity
|
|
140
140
|
def condense_tree(hierarchy, min_cluster_size)
|
|
141
141
|
n_edges = hierarchy.size
|
|
142
142
|
root = 2 * n_edges
|
|
@@ -265,7 +265,7 @@ module Rumale
|
|
|
265
265
|
end
|
|
266
266
|
res
|
|
267
267
|
end
|
|
268
|
-
# rubocop:enable Metrics/AbcSize, Metrics/MethodLength, Metrics/PerceivedComplexity
|
|
268
|
+
# rubocop:enable Metrics/AbcSize, Metrics/CyclomaticComplexity, Metrics/MethodLength, Metrics/PerceivedComplexity
|
|
269
269
|
end
|
|
270
270
|
end
|
|
271
271
|
end
|
|
@@ -51,7 +51,7 @@ module Rumale
|
|
|
51
51
|
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to be used for cluster analysis.
|
|
52
52
|
# If the metric is 'precomputed', x must be a square distance matrix (shape: [n_samples, n_samples]).
|
|
53
53
|
# @return [Numo::Int32] (shape: [n_samples]) Predicted cluster label per sample.
|
|
54
|
-
def fit_predict(x)
|
|
54
|
+
def fit_predict(x) # rubocop:disable Lint/UselessMethodDefinition
|
|
55
55
|
super
|
|
56
56
|
end
|
|
57
57
|
|
data/lib/rumale/dataset.rb
CHANGED
|
@@ -225,7 +225,7 @@ module Rumale
|
|
|
225
225
|
line = dump_label(label, label_type.to_s)
|
|
226
226
|
ftvec.to_a.each_with_index do |val, n|
|
|
227
227
|
idx = n + (zero_based == false ? 1 : 0)
|
|
228
|
-
line += format(" %d:#{value_type}", idx, val) if val != 0
|
|
228
|
+
line += format(" %d:#{value_type}", idx, val) if val != 0
|
|
229
229
|
end
|
|
230
230
|
line
|
|
231
231
|
end
|
|
@@ -77,7 +77,7 @@ module Rumale
|
|
|
77
77
|
# @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data.
|
|
78
78
|
def transform(x)
|
|
79
79
|
x = check_convert_sample_array(x)
|
|
80
|
-
partial_fit(x, false)
|
|
80
|
+
partial_fit(x, update_comps: false)
|
|
81
81
|
end
|
|
82
82
|
|
|
83
83
|
# Inverse transform the given transformed data with the learned model.
|
|
@@ -91,7 +91,7 @@ module Rumale
|
|
|
91
91
|
|
|
92
92
|
private
|
|
93
93
|
|
|
94
|
-
def partial_fit(x, update_comps
|
|
94
|
+
def partial_fit(x, update_comps: true)
|
|
95
95
|
# initialize some variables.
|
|
96
96
|
n_samples, n_features = x.shape
|
|
97
97
|
scale = Math.sqrt(x.mean / @params[:n_components])
|
|
@@ -85,7 +85,7 @@ module Rumale
|
|
|
85
85
|
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
|
|
86
86
|
# @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
|
|
87
87
|
# @return [RandomForestClassifier] The learned classifier itself.
|
|
88
|
-
def fit(x, y)
|
|
88
|
+
def fit(x, y) # rubocop:disable Metrics/AbcSize
|
|
89
89
|
x = check_convert_sample_array(x)
|
|
90
90
|
y = check_convert_label_array(y)
|
|
91
91
|
check_sample_label_size(x, y)
|
|
@@ -79,7 +79,7 @@ module Rumale
|
|
|
79
79
|
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
|
|
80
80
|
# @param y [Numo::DFloat] (shape: [n_samples, n_outputs]) The target values to be used for fitting the model.
|
|
81
81
|
# @return [RandomForestRegressor] The learned regressor itself.
|
|
82
|
-
def fit(x, y)
|
|
82
|
+
def fit(x, y) # rubocop:disable Metrics/AbcSize
|
|
83
83
|
x = check_convert_sample_array(x)
|
|
84
84
|
y = check_convert_tvalue_array(y)
|
|
85
85
|
check_sample_tvalue_size(x, y)
|
|
@@ -75,9 +75,12 @@ module Rumale
|
|
|
75
75
|
false_pos, true_pos, thresholds = binary_roc_curve(y_true, y_score, pos_label)
|
|
76
76
|
|
|
77
77
|
if true_pos.size.zero? || false_pos[0] != 0 || true_pos[0] != 0
|
|
78
|
+
# NOTE: Numo::NArray#insert is not a destructive method.
|
|
79
|
+
# rubocop:disable Style/RedundantSelfAssignment
|
|
78
80
|
true_pos = true_pos.insert(0, 0)
|
|
79
81
|
false_pos = false_pos.insert(0, 0)
|
|
80
82
|
thresholds = thresholds.insert(0, thresholds[0] + 1)
|
|
83
|
+
# rubocop:enable Style/RedundantSelfAssignment
|
|
81
84
|
end
|
|
82
85
|
|
|
83
86
|
tpr = true_pos / true_pos[-1].to_f
|
|
@@ -67,7 +67,7 @@ module Rumale
|
|
|
67
67
|
def transform(x)
|
|
68
68
|
raise 'FeatureHasher#transform requires Mmh3 but that is not loaded.' unless enable_mmh3?
|
|
69
69
|
|
|
70
|
-
x = [x] unless x.is_a?(Array)
|
|
70
|
+
x = [x] unless x.is_a?(Array) # rubocop:disable Style/ArrayCoercion
|
|
71
71
|
n_samples = x.size
|
|
72
72
|
|
|
73
73
|
z = Numo::DFloat.zeros(n_samples, n_features)
|
|
@@ -99,7 +99,7 @@ module Rumale
|
|
|
99
99
|
# @param x [Array<Hash>] (shape: [n_samples]) The array of hash consisting of feature names and values.
|
|
100
100
|
# @return [Numo::DFloat] (shape: [n_samples, n_features]) The encoded sample array.
|
|
101
101
|
def transform(x)
|
|
102
|
-
x = [x] unless x.is_a?(Array)
|
|
102
|
+
x = [x] unless x.is_a?(Array) # rubocop:disable Style/ArrayCoercion
|
|
103
103
|
n_samples = x.size
|
|
104
104
|
n_features = @vocabulary.size
|
|
105
105
|
z = Numo::DFloat.zeros(n_samples, n_features)
|
|
@@ -10,7 +10,7 @@ module Rumale
|
|
|
10
10
|
#
|
|
11
11
|
# @example
|
|
12
12
|
# estimator =
|
|
13
|
-
# Rumale::LinearModel::ElasticNet.new(reg_param: 0.1, l1_ratio: 0.5, max_iter:
|
|
13
|
+
# Rumale::LinearModel::ElasticNet.new(reg_param: 0.1, l1_ratio: 0.5, max_iter: 1000, batch_size: 50, random_seed: 1)
|
|
14
14
|
# estimator.fit(training_samples, training_values)
|
|
15
15
|
# results = estimator.predict(testing_samples)
|
|
16
16
|
#
|
|
@@ -59,7 +59,7 @@ module Rumale
|
|
|
59
59
|
# @param random_seed [Integer] The seed value using to initialize the random generator.
|
|
60
60
|
def initialize(learning_rate: 0.01, decay: nil, momentum: 0.9,
|
|
61
61
|
reg_param: 1.0, l1_ratio: 0.5, fit_bias: true, bias_scale: 1.0,
|
|
62
|
-
max_iter:
|
|
62
|
+
max_iter: 1000, batch_size: 50, tol: 1e-4,
|
|
63
63
|
n_jobs: nil, verbose: false, random_seed: nil)
|
|
64
64
|
check_params_numeric(learning_rate: learning_rate, momentum: momentum,
|
|
65
65
|
reg_param: reg_param, l1_ratio: l1_ratio, bias_scale: bias_scale,
|
|
@@ -10,7 +10,7 @@ module Rumale
|
|
|
10
10
|
#
|
|
11
11
|
# @example
|
|
12
12
|
# estimator =
|
|
13
|
-
# Rumale::LinearModel::Lasso.new(reg_param: 0.1, max_iter:
|
|
13
|
+
# Rumale::LinearModel::Lasso.new(reg_param: 0.1, max_iter: 1000, batch_size: 20, random_seed: 1)
|
|
14
14
|
# estimator.fit(training_samples, training_values)
|
|
15
15
|
# results = estimator.predict(testing_samples)
|
|
16
16
|
#
|
|
@@ -55,7 +55,7 @@ module Rumale
|
|
|
55
55
|
# @param random_seed [Integer] The seed value using to initialize the random generator.
|
|
56
56
|
def initialize(learning_rate: 0.01, decay: nil, momentum: 0.9,
|
|
57
57
|
reg_param: 1.0, fit_bias: true, bias_scale: 1.0,
|
|
58
|
-
max_iter:
|
|
58
|
+
max_iter: 1000, batch_size: 50, tol: 1e-4,
|
|
59
59
|
n_jobs: nil, verbose: false, random_seed: nil)
|
|
60
60
|
check_params_numeric(learning_rate: learning_rate, momentum: momentum,
|
|
61
61
|
reg_param: reg_param, bias_scale: bias_scale,
|
|
@@ -10,7 +10,7 @@ module Rumale
|
|
|
10
10
|
#
|
|
11
11
|
# @example
|
|
12
12
|
# estimator =
|
|
13
|
-
# Rumale::LinearModel::LinearRegression.new(max_iter:
|
|
13
|
+
# Rumale::LinearModel::LinearRegression.new(max_iter: 1000, batch_size: 20, random_seed: 1)
|
|
14
14
|
# estimator.fit(training_samples, training_values)
|
|
15
15
|
# results = estimator.predict(testing_samples)
|
|
16
16
|
#
|
|
@@ -68,7 +68,7 @@ module Rumale
|
|
|
68
68
|
# If solver = 'svd', this parameter is ignored.
|
|
69
69
|
# @param random_seed [Integer] The seed value using to initialize the random generator.
|
|
70
70
|
def initialize(learning_rate: 0.01, decay: nil, momentum: 0.9,
|
|
71
|
-
fit_bias: true, bias_scale: 1.0, max_iter:
|
|
71
|
+
fit_bias: true, bias_scale: 1.0, max_iter: 1000, batch_size: 50, tol: 1e-4,
|
|
72
72
|
solver: 'auto',
|
|
73
73
|
n_jobs: nil, verbose: false, random_seed: nil)
|
|
74
74
|
check_params_numeric(learning_rate: learning_rate, momentum: momentum,
|
|
@@ -15,7 +15,7 @@ module Rumale
|
|
|
15
15
|
#
|
|
16
16
|
# @example
|
|
17
17
|
# estimator =
|
|
18
|
-
# Rumale::LinearModel::LogisticRegression.new(reg_param: 1.0, max_iter:
|
|
18
|
+
# Rumale::LinearModel::LogisticRegression.new(reg_param: 1.0, max_iter: 1000, batch_size: 50, random_seed: 1)
|
|
19
19
|
# estimator.fit(training_samples, training_labels)
|
|
20
20
|
# results = estimator.predict(testing_samples)
|
|
21
21
|
#
|
|
@@ -72,7 +72,7 @@ module Rumale
|
|
|
72
72
|
def initialize(learning_rate: 0.01, decay: nil, momentum: 0.9,
|
|
73
73
|
penalty: 'l2', reg_param: 1.0, l1_ratio: 0.5,
|
|
74
74
|
fit_bias: true, bias_scale: 1.0,
|
|
75
|
-
max_iter:
|
|
75
|
+
max_iter: 1000, batch_size: 50, tol: 1e-4,
|
|
76
76
|
n_jobs: nil, verbose: false, random_seed: nil)
|
|
77
77
|
check_params_numeric(learning_rate: learning_rate, momentum: momentum,
|
|
78
78
|
reg_param: reg_param, l1_ratio: l1_ratio, bias_scale: bias_scale,
|