svmkit 0.2.8 → 0.2.9
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +5 -5
- data/.gitignore +4 -0
- data/.rubocop.yml +10 -1
- data/.rubocop_todo.yml +51 -10
- data/Gemfile +1 -1
- data/HISTORY.md +43 -33
- data/lib/svmkit.rb +4 -0
- data/lib/svmkit/base/classifier.rb +1 -0
- data/lib/svmkit/ensemble/random_forest_classifier.rb +5 -2
- data/lib/svmkit/evaluation_measure/log_loss.rb +44 -0
- data/lib/svmkit/kernel_approximation/rbf.rb +1 -1
- data/lib/svmkit/kernel_machine/kernel_svc.rb +40 -2
- data/lib/svmkit/linear_model/logistic_regression.rb +3 -1
- data/lib/svmkit/linear_model/svc.rb +46 -7
- data/lib/svmkit/model_selection/cross_validation.rb +9 -1
- data/lib/svmkit/model_selection/k_fold.rb +1 -1
- data/lib/svmkit/model_selection/stratified_k_fold.rb +3 -2
- data/lib/svmkit/multiclass/one_vs_rest_classifier.rb +1 -0
- data/lib/svmkit/naive_bayes/naive_bayes.rb +5 -0
- data/lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb +2 -0
- data/lib/svmkit/polynomial_model/factorization_machine_classifier.rb +4 -1
- data/lib/svmkit/preprocessing/label_encoder.rb +94 -0
- data/lib/svmkit/preprocessing/one_hot_encoder.rb +98 -0
- data/lib/svmkit/probabilistic_output.rb +112 -0
- data/lib/svmkit/tree/decision_tree_classifier.rb +80 -10
- data/lib/svmkit/validation.rb +12 -0
- data/lib/svmkit/version.rb +1 -1
- data/svmkit.gemspec +4 -6
- metadata +18 -14
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
-
-  metadata.gz:
-  data.tar.gz:
+SHA256:
+  metadata.gz: 533508a3afd82d2bae3ddea3a5669f6d389688155d44649fd3eafaaff8207e0f
+  data.tar.gz: 43ff09b3bab72b68bc7a6b3740902be64508496337a4cde61057d33b91d0f349
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: e1c1bed8269d3c768d75bd8a5e731b5d2da689ef7a235a70c5ea87090aac79889c9fe0a004eca73c3015aae42d068f44b2b1e3a61a03b641607b2909441513b6
+  data.tar.gz: 80a18ca4ec7eb2740148829024f0625c835f24b771bb321168d0cc3233d8e152257b5515355d99a968dc25a670f9a69f3e30b42bf190757206a64bbcd2babcd6
data/.gitignore
CHANGED
data/.rubocop.yml
CHANGED
@@ -9,7 +9,7 @@ Documentation:
   Enabled: false
 
 Metrics/LineLength:
-  Max:
+  Max: 140
   IgnoredPatterns: ['(\A|\s)#']
 
 Metrics/ModuleLength:
@@ -21,6 +21,9 @@ Metrics/ClassLength:
 Metrics/MethodLength:
   Max: 40
 
+Metrics/AbcSize:
+  Max: 60
+
 Metrics/BlockLength:
   Exclude:
     - 'spec/**/*'
@@ -30,3 +33,9 @@ ParameterLists:
 
 Security/MarshalLoad:
   Enabled: false
+
+Naming/UncommunicativeMethodParamName:
+  Enabled: false
+
+Style/FormatStringToken:
+  Enabled: false
data/.rubocop_todo.yml
CHANGED
@@ -1,18 +1,59 @@
 # This configuration was generated by
 # `rubocop --auto-gen-config`
-# on 2018-
+# on 2018-04-14 20:44:19 +0900 using RuboCop version 0.54.0.
 # The point is for the user to remove these configuration records
 # one by one as the offenses are removed from the code base.
 # Note that changes in the inspected code, or installation of new
 # versions of RuboCop, may require this file to be generated again.
 
-# Offense count:
-
-
+# Offense count: 1
+# Configuration parameters: Include.
+# Include: **/*.gemspec
+Gemspec/RequiredRubyVersion:
+  Exclude:
+    - 'svmkit.gemspec'
 
-# Offense count:
-#
-#
-
-
-
+# Offense count: 3
+# Cop supports --auto-correct.
+# Configuration parameters: EnforcedStyle.
+# SupportedStyles: auto_detection, squiggly, active_support, powerpack, unindent
+Layout/IndentHeredoc:
+  Exclude:
+    - 'svmkit.gemspec'
+
+# Offense count: 1
+# Configuration parameters: CountComments, ExcludedMethods.
+Metrics/BlockLength:
+  Max: 30
+
+# Offense count: 1
+Metrics/CyclomaticComplexity:
+  Max: 12
+
+# Offense count: 1
+Metrics/PerceivedComplexity:
+  Max: 12
+
+# Offense count: 1
+# Cop supports --auto-correct.
+Style/Encoding:
+  Exclude:
+    - 'svmkit.gemspec'
+
+# Offense count: 1
+# Cop supports --auto-correct.
+# Configuration parameters: EnforcedStyle, UseHashRocketsWithSymbolValues, PreferHashRocketsForNonAlnumEndingSymbols.
+# SupportedStyles: ruby19, hash_rockets, no_mixed_keys, ruby19_no_mixed_keys
+Style/HashSyntax:
+  Exclude:
+    - 'Rakefile'
+
+# Offense count: 6
+# Cop supports --auto-correct.
+# Configuration parameters: EnforcedStyle, ConsistentQuotesInMultiline.
+# SupportedStyles: single_quotes, double_quotes
+Style/StringLiterals:
+  Exclude:
+    - 'Gemfile'
+    - 'Rakefile'
+    - 'bin/console'
data/Gemfile
CHANGED
data/HISTORY.md
CHANGED
@@ -1,59 +1,69 @@
+# 0.2.9
+- Add predict_proba method to SVC and KernelSVC.
+- Add class for evaluating logarithmic loss.
+- Add classes for Label- and One-Hot- encoding.
+- Add some validator.
+- Fix bug on training data score calculation of cross validation.
+- Fix fit method of SVC for performance.
+- Fix criterion calculation on Decision Tree for performance.
+- Fix data structure of Decision Tree for performance.
+
 # 0.2.8
-
-
-
+- Fix bug on gradient calculation of Logistic Regression.
+- Fix to change accessor of params of estimators to read only.
+- Add parameter validation.
 
 # 0.2.7
-
+- Fix to support multiclass classifiction into LinearSVC, LogisticRegression, KernelSVC, and FactorizationMachineClassifier
 
 # 0.2.6
-
-
-
-
-
+- Add class for Decision Tree classifier.
+- Add class for Random Forest classifier.
+- Fix to use frozen string literal.
+- Refactor marshal dump method on some classes.
+- Introduce Coveralls to confirm test coverage.
 
 # 0.2.5
-
-
-
+- Add classes for Naive Bayes classifier.
+- Fix decision function method on Logistic Regression class.
+- Fix method visibility on RBF kernel approximation class.
 
 # 0.2.4
-
-
-
+- Add class for Factorization Machine classifier.
+- Add classes for evaluation measures.
+- Fix the method for prediction of class probability in Logistic Regression.
 
 # 0.2.3
-
-
-
+- Add class for cross validation.
+- Add specs for base modules.
+- Fix validation of the number of splits when a negative label is given.
 
 # 0.2.2
-
+- Add data splitter classes for K-fold cross validation.
 
 # 0.2.1
-
+- Add class for K-nearest neighbors classifier.
 
 # 0.2.0
 - Migrated the linear algebra library to Numo::NArray.
-
+- Add module for loading and saving libsvm format file.
 
 # 0.1.3
-
-
+- Add class for Kernel Support Vector Machine with Pegasos algorithm.
+- Add module for calculating pairwise kernel fuctions and euclidean distances.
 
 # 0.1.2
-
-
+- Add the function learning a model with bias term to the PegasosSVC and LogisticRegression classes.
+- Rewrite the document with yard notation.
 
 # 0.1.1
-
-
+- Add class for Logistic Regression with SGD optimization.
+- Fix some mistakes on the document.
 
 # 0.1.0
-
-
-
-
-
-
+- Add basic classes.
+- Add an utility module.
+- Add class for RBF kernel approximation.
+- Add class for Support Vector Machine with Pegasos alogrithm.
+- Add class that performs mutlclass classification with one-vs.-rest strategy.
+- Add classes for preprocessing such as min-max scaling, standardization, and L2 normalization.
data/lib/svmkit.rb
CHANGED
@@ -6,6 +6,7 @@ require 'svmkit/version'
 require 'svmkit/validation'
 require 'svmkit/pairwise_metric'
 require 'svmkit/dataset'
+require 'svmkit/probabilistic_output'
 require 'svmkit/base/base_estimator'
 require 'svmkit/base/classifier'
 require 'svmkit/base/transformer'
@@ -24,6 +25,8 @@ require 'svmkit/ensemble/random_forest_classifier'
 require 'svmkit/preprocessing/l2_normalizer'
 require 'svmkit/preprocessing/min_max_scaler'
 require 'svmkit/preprocessing/standard_scaler'
+require 'svmkit/preprocessing/label_encoder'
+require 'svmkit/preprocessing/one_hot_encoder'
 require 'svmkit/model_selection/k_fold'
 require 'svmkit/model_selection/stratified_k_fold'
 require 'svmkit/model_selection/cross_validation'
@@ -31,3 +34,4 @@ require 'svmkit/evaluation_measure/accuracy'
 require 'svmkit/evaluation_measure/precision'
 require 'svmkit/evaluation_measure/recall'
 require 'svmkit/evaluation_measure/f_score'
+require 'svmkit/evaluation_measure/log_loss'
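The two preprocessing classes registered above, SVMKit::Preprocessing::LabelEncoder and SVMKit::Preprocessing::OneHotEncoder, are new in this release but their files are not expanded in this excerpt. A minimal usage sketch, assuming they follow the gem's usual transformer interface (fit_transform); the input and return types in the comments are assumptions, not taken from the diff:

    require 'svmkit'

    # LabelEncoder is assumed to accept a plain Ruby Array of labels and to
    # return ordinally encoded labels as a Numo::Int32 array.
    encoder = SVMKit::Preprocessing::LabelEncoder.new
    encoded = encoder.fit_transform([3, 5, 5, 3, 9])
    # e.g. => Numo::Int32[0, 1, 1, 0, 2]

    # OneHotEncoder#fit_transform takes a Numo::Int32 label vector, as seen in the
    # new LogLoss class below, and is assumed to return a one-hot matrix with one
    # column per class.
    one_hot = SVMKit::Preprocessing::OneHotEncoder.new.fit_transform(encoded)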
data/lib/svmkit/base/classifier.rb
CHANGED
@@ -22,6 +22,7 @@ module SVMKit
      def score(x, y)
        SVMKit::Validation.check_sample_array(x)
        SVMKit::Validation.check_label_array(y)
+       SVMKit::Validation.check_sample_label_size(x, y)
        evaluator = SVMKit::EvaluationMeasure::Accuracy.new
        evaluator.score(y, predict(x))
      end
data/lib/svmkit/ensemble/random_forest_classifier.rb
CHANGED
@@ -51,10 +51,12 @@ module SVMKit
      def initialize(n_estimators: 10, criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1,
                     max_features: nil, random_seed: nil)
        SVMKit::Validation.check_params_type_or_nil(Integer, max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
-
+                                                            max_features: max_features, random_seed: random_seed)
        SVMKit::Validation.check_params_integer(n_estimators: n_estimators, min_samples_leaf: min_samples_leaf)
        SVMKit::Validation.check_params_string(criterion: criterion)
-
+       SVMKit::Validation.check_params_positive(n_estimators: n_estimators, max_depth: max_depth,
+                                                max_leaf_nodes: max_leaf_nodes, min_samples_leaf: min_samples_leaf,
+                                                max_features: max_features)
        @params = {}
        @params[:n_estimators] = n_estimators
        @params[:criterion] = criterion
@@ -78,6 +80,7 @@ module SVMKit
      def fit(x, y)
        SVMKit::Validation.check_sample_array(x)
        SVMKit::Validation.check_label_array(y)
+       SVMKit::Validation.check_sample_label_size(x, y)
        # Initialize some variables.
        n_samples, n_features = x.shape
        @params[:max_features] = n_features unless @params[:max_features].is_a?(Integer)
data/lib/svmkit/evaluation_measure/log_loss.rb
ADDED
@@ -0,0 +1,44 @@
+# frozen_string_literal: true
+
+require 'svmkit/base/evaluator'
+
+module SVMKit
+  module EvaluationMeasure
+    # LogLoss is a class that calculates the logarithmic loss of predicted class probability.
+    #
+    # @example
+    #   evaluator = SVMKit::EvaluationMeasure::LogLoss.new
+    #   puts evaluator.score(ground_truth, predicted)
+    class LogLoss
+      include Base::Evaluator
+
+      # Claculate mean logarithmic loss.
+      # If both y_true and y_pred are array (both shapes are [n_samples]), this method calculates
+      # mean logarithmic loss for binary classification.
+      #
+      # @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
+      # @param y_pred [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted class probability.
+      # @param eps [Float] A small value close to zero to avoid outputting infinity in logarithmic calcuation.
+      # @return [Float] mean logarithmic loss
+      def score(y_true, y_pred, eps = 1e-15)
+        SVMKit::Validation.check_params_type(Numo::Int32, y_true: y_true)
+        SVMKit::Validation.check_params_type(Numo::DFloat, y_pred: y_pred)
+
+        n_samples, n_classes = y_pred.shape
+        clipped_p = y_pred.clip(eps, 1 - eps)
+
+        log_loss = if n_classes.nil?
+                     negative_label = y_true.to_a.uniq.sort.first
+                     bin_y_true = Numo::DFloat.cast(y_true.ne(negative_label))
+                     -(bin_y_true * Numo::NMath.log(clipped_p) + (1 - bin_y_true) * Numo::NMath.log(1 - clipped_p))
+                   else
+                     encoder = SVMKit::Preprocessing::OneHotEncoder.new
+                     encoded_y_true = encoder.fit_transform(y_true)
+                     clipped_p /= clipped_p.sum(1).expand_dims(1)
+                     -(encoded_y_true * Numo::NMath.log(clipped_p)).sum(1)
+                   end
+        log_loss.sum / n_samples
+      end
+    end
+  end
+end
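The new evaluator mirrors the score interface of the existing evaluation measures. A minimal, self-contained sketch covering both branches of the method above (a probability matrix for the general case, a 1-D vector of positive-class probabilities for the binary case):

    require 'svmkit'

    # Hand-made ground truth and predicted probabilities for a 3-sample, 2-class problem.
    ground_truth = Numo::Int32[1, 0, 1]
    predicted = Numo::DFloat[[0.2, 0.8], [0.7, 0.3], [0.4, 0.6]]

    evaluator = SVMKit::EvaluationMeasure::LogLoss.new
    puts evaluator.score(ground_truth, predicted)

    # For a binary problem, a 1-D array of positive-class probabilities is also
    # accepted; n_classes is then nil and the binary branch is used.
    puts evaluator.score(ground_truth, Numo::DFloat[0.8, 0.3, 0.6])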
data/lib/svmkit/kernel_approximation/rbf.rb
CHANGED
@@ -40,7 +40,7 @@ module SVMKit
        SVMKit::Validation.check_params_float(gamma: gamma)
        SVMKit::Validation.check_params_integer(n_components: n_components)
        SVMKit::Validation.check_params_type_or_nil(Integer, random_seed: random_seed)
-
+       SVMKit::Validation.check_params_positive(gamma: gamma, n_components: n_components)
        @params = {}
        @params[:gamma] = gamma
        @params[:n_components] = n_components
data/lib/svmkit/kernel_machine/kernel_svc.rb
CHANGED
@@ -40,18 +40,22 @@ module SVMKit
      #
      # @param reg_param [Float] The regularization parameter.
      # @param max_iter [Integer] The maximum number of iterations.
+     # @param probability [Boolean] The flag indicating whether to perform probability estimation.
      # @param random_seed [Integer] The seed value using to initialize the random generator.
-     def initialize(reg_param: 1.0, max_iter: 1000, random_seed: nil)
+     def initialize(reg_param: 1.0, max_iter: 1000, probability: false, random_seed: nil)
        SVMKit::Validation.check_params_float(reg_param: reg_param)
        SVMKit::Validation.check_params_integer(max_iter: max_iter)
+       SVMKit::Validation.check_params_boolean(probability: probability)
        SVMKit::Validation.check_params_type_or_nil(Integer, random_seed: random_seed)
-
+       SVMKit::Validation.check_params_positive(reg_param: reg_param, max_iter: max_iter)
        @params = {}
        @params[:reg_param] = reg_param
        @params[:max_iter] = max_iter
+       @params[:probability] = probability
        @params[:random_seed] = random_seed
        @params[:random_seed] ||= srand
        @weight_vec = nil
+       @prob_param = nil
        @classes = nil
        @rng = Random.new(@params[:random_seed])
      end
@@ -65,6 +69,7 @@ module SVMKit
      def fit(x, y)
        SVMKit::Validation.check_sample_array(x)
        SVMKit::Validation.check_label_array(y)
+       SVMKit::Validation.check_sample_label_size(x, y)
 
        @classes = Numo::Int32[*y.to_a.uniq.sort]
        n_classes = @classes.size
@@ -72,14 +77,25 @@ module SVMKit
 
        if n_classes > 2
          @weight_vec = Numo::DFloat.zeros(n_classes, n_features)
+         @prob_param = Numo::DFloat.zeros(n_classes, 2)
          n_classes.times do |n|
            bin_y = Numo::Int32.cast(y.eq(@classes[n])) * 2 - 1
            @weight_vec[n, true] = binary_fit(x, bin_y)
+           @prob_param[n, true] = if @params[:probability]
+                                    SVMKit::ProbabilisticOutput.fit_sigmoid(x.dot(@weight_vec[n, true].transpose), bin_y)
+                                  else
+                                    Numo::DFloat[1, 0]
+                                  end
          end
        else
          negative_label = y.to_a.uniq.sort.first
          bin_y = Numo::Int32.cast(y.ne(negative_label)) * 2 - 1
          @weight_vec = binary_fit(x, bin_y)
+         @prob_param = if @params[:probability]
+                         SVMKit::ProbabilisticOutput.fit_sigmoid(x.dot(@weight_vec.transpose), bin_y)
+                       else
+                         Numo::DFloat[1, 0]
+                       end
        end
 
        self
@@ -111,11 +127,32 @@ module SVMKit
        Numo::Int32.asarray(Array.new(n_samples) { |n| @classes[decision_values[n, true].max_index] })
      end
 
+     # Predict probability for samples.
+     #
+     # @param x [Numo::DFloat] (shape: [n_testing_samples, n_training_samples])
+     #   The kernel matrix between testing samples and training samples to predict the labels.
+     # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
+     def predict_proba(x)
+       SVMKit::Validation.check_sample_array(x)
+
+       if @classes.size > 2
+         probs = 1.0 / (Numo::NMath.exp(@prob_param[true, 0] * decision_function(x) + @prob_param[true, 1]) + 1.0)
+         return (probs.transpose / probs.sum(axis: 1)).transpose
+       end
+
+       n_samples, = x.shape
+       probs = Numo::DFloat.zeros(n_samples, 2)
+       probs[true, 1] = 1.0 / (Numo::NMath.exp(@prob_param[0] * decision_function(x) + @prob_param[1]) + 1.0)
+       probs[true, 0] = 1.0 - probs[true, 1]
+       probs
+     end
+
      # Dump marshal data.
      # @return [Hash] The marshal data about KernelSVC.
      def marshal_dump
        { params: @params,
          weight_vec: @weight_vec,
+         prob_param: @prob_param,
          classes: @classes,
          rng: @rng }
      end
@@ -125,6 +162,7 @@ module SVMKit
      def marshal_load(obj)
        @params = obj[:params]
        @weight_vec = obj[:weight_vec]
+       @prob_param = obj[:prob_param]
        @classes = obj[:classes]
        @rng = obj[:rng]
        nil
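With the new probability flag, a fitted KernelSVC can return class probabilities through predict_proba. A minimal sketch on toy data; the SVMKit::PairwiseMetric.rbf_kernel call and its (x, y, gamma) argument order are assumptions about the existing pairwise-metric module, which is not part of this diff:

    require 'svmkit'

    # Toy data: 20 two-dimensional samples in two clusters, labeled 0 and 1.
    samples = Numo::DFloat.new(20, 2).rand
    samples[10...20, true] += 3.0
    labels = Numo::Int32.zeros(20)
    labels[10...20] = 1

    # KernelSVC is trained on a precomputed kernel matrix.
    kernel_mat = SVMKit::PairwiseMetric.rbf_kernel(samples, nil, 2.0)
    estimator = SVMKit::KernelMachine::KernelSVC.new(reg_param: 1.0, max_iter: 1000,
                                                     probability: true, random_seed: 1)
    estimator.fit(kernel_mat, labels)

    # Prediction needs the kernel matrix between new samples and the training
    # samples; here the training samples themselves are reused for brevity.
    probs = estimator.predict_proba(SVMKit::PairwiseMetric.rbf_kernel(samples, samples, 2.0))
    # => Numo::DFloat of shape [20, 2]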
data/lib/svmkit/linear_model/logistic_regression.rb
CHANGED
@@ -54,7 +54,8 @@ module SVMKit
        SVMKit::Validation.check_params_integer(max_iter: max_iter, batch_size: batch_size)
        SVMKit::Validation.check_params_boolean(fit_bias: fit_bias, normalize: normalize)
        SVMKit::Validation.check_params_type_or_nil(Integer, random_seed: random_seed)
-
+       SVMKit::Validation.check_params_positive(reg_param: reg_param, bias_scale: bias_scale, max_iter: max_iter,
+                                                batch_size: batch_size)
        @params = {}
        @params[:reg_param] = reg_param
        @params[:fit_bias] = fit_bias
@@ -78,6 +79,7 @@ module SVMKit
      def fit(x, y)
        SVMKit::Validation.check_sample_array(x)
        SVMKit::Validation.check_label_array(y)
+       SVMKit::Validation.check_sample_label_size(x, y)
 
        @classes = Numo::Int32[*y.to_a.uniq.sort]
        n_classes = @classes.size