svmkit 0.2.8 → 0.2.9
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +5 -5
- data/.gitignore +4 -0
- data/.rubocop.yml +10 -1
- data/.rubocop_todo.yml +51 -10
- data/Gemfile +1 -1
- data/HISTORY.md +43 -33
- data/lib/svmkit.rb +4 -0
- data/lib/svmkit/base/classifier.rb +1 -0
- data/lib/svmkit/ensemble/random_forest_classifier.rb +5 -2
- data/lib/svmkit/evaluation_measure/log_loss.rb +44 -0
- data/lib/svmkit/kernel_approximation/rbf.rb +1 -1
- data/lib/svmkit/kernel_machine/kernel_svc.rb +40 -2
- data/lib/svmkit/linear_model/logistic_regression.rb +3 -1
- data/lib/svmkit/linear_model/svc.rb +46 -7
- data/lib/svmkit/model_selection/cross_validation.rb +9 -1
- data/lib/svmkit/model_selection/k_fold.rb +1 -1
- data/lib/svmkit/model_selection/stratified_k_fold.rb +3 -2
- data/lib/svmkit/multiclass/one_vs_rest_classifier.rb +1 -0
- data/lib/svmkit/naive_bayes/naive_bayes.rb +5 -0
- data/lib/svmkit/nearest_neighbors/k_neighbors_classifier.rb +2 -0
- data/lib/svmkit/polynomial_model/factorization_machine_classifier.rb +4 -1
- data/lib/svmkit/preprocessing/label_encoder.rb +94 -0
- data/lib/svmkit/preprocessing/one_hot_encoder.rb +98 -0
- data/lib/svmkit/probabilistic_output.rb +112 -0
- data/lib/svmkit/tree/decision_tree_classifier.rb +80 -10
- data/lib/svmkit/validation.rb +12 -0
- data/lib/svmkit/version.rb +1 -1
- data/svmkit.gemspec +4 -6
- metadata +18 -14
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
|
-
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
2
|
+
SHA256:
|
3
|
+
metadata.gz: 533508a3afd82d2bae3ddea3a5669f6d389688155d44649fd3eafaaff8207e0f
|
4
|
+
data.tar.gz: 43ff09b3bab72b68bc7a6b3740902be64508496337a4cde61057d33b91d0f349
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: e1c1bed8269d3c768d75bd8a5e731b5d2da689ef7a235a70c5ea87090aac79889c9fe0a004eca73c3015aae42d068f44b2b1e3a61a03b641607b2909441513b6
|
7
|
+
data.tar.gz: 80a18ca4ec7eb2740148829024f0625c835f24b771bb321168d0cc3233d8e152257b5515355d99a968dc25a670f9a69f3e30b42bf190757206a64bbcd2babcd6
|
data/.gitignore
CHANGED
data/.rubocop.yml
CHANGED
@@ -9,7 +9,7 @@ Documentation:
|
|
9
9
|
Enabled: false
|
10
10
|
|
11
11
|
Metrics/LineLength:
|
12
|
-
Max:
|
12
|
+
Max: 140
|
13
13
|
IgnoredPatterns: ['(\A|\s)#']
|
14
14
|
|
15
15
|
Metrics/ModuleLength:
|
@@ -21,6 +21,9 @@ Metrics/ClassLength:
|
|
21
21
|
Metrics/MethodLength:
|
22
22
|
Max: 40
|
23
23
|
|
24
|
+
Metrics/AbcSize:
|
25
|
+
Max: 60
|
26
|
+
|
24
27
|
Metrics/BlockLength:
|
25
28
|
Exclude:
|
26
29
|
- 'spec/**/*'
|
@@ -30,3 +33,9 @@ ParameterLists:
|
|
30
33
|
|
31
34
|
Security/MarshalLoad:
|
32
35
|
Enabled: false
|
36
|
+
|
37
|
+
Naming/UncommunicativeMethodParamName:
|
38
|
+
Enabled: false
|
39
|
+
|
40
|
+
Style/FormatStringToken:
|
41
|
+
Enabled: false
|
data/.rubocop_todo.yml
CHANGED
@@ -1,18 +1,59 @@
|
|
1
1
|
# This configuration was generated by
|
2
2
|
# `rubocop --auto-gen-config`
|
3
|
-
# on 2018-
|
3
|
+
# on 2018-04-14 20:44:19 +0900 using RuboCop version 0.54.0.
|
4
4
|
# The point is for the user to remove these configuration records
|
5
5
|
# one by one as the offenses are removed from the code base.
|
6
6
|
# Note that changes in the inspected code, or installation of new
|
7
7
|
# versions of RuboCop, may require this file to be generated again.
|
8
8
|
|
9
|
-
# Offense count:
|
10
|
-
|
11
|
-
|
9
|
+
# Offense count: 1
|
10
|
+
# Configuration parameters: Include.
|
11
|
+
# Include: **/*.gemspec
|
12
|
+
Gemspec/RequiredRubyVersion:
|
13
|
+
Exclude:
|
14
|
+
- 'svmkit.gemspec'
|
12
15
|
|
13
|
-
# Offense count:
|
14
|
-
#
|
15
|
-
#
|
16
|
-
|
17
|
-
|
18
|
-
|
16
|
+
# Offense count: 3
|
17
|
+
# Cop supports --auto-correct.
|
18
|
+
# Configuration parameters: EnforcedStyle.
|
19
|
+
# SupportedStyles: auto_detection, squiggly, active_support, powerpack, unindent
|
20
|
+
Layout/IndentHeredoc:
|
21
|
+
Exclude:
|
22
|
+
- 'svmkit.gemspec'
|
23
|
+
|
24
|
+
# Offense count: 1
|
25
|
+
# Configuration parameters: CountComments, ExcludedMethods.
|
26
|
+
Metrics/BlockLength:
|
27
|
+
Max: 30
|
28
|
+
|
29
|
+
# Offense count: 1
|
30
|
+
Metrics/CyclomaticComplexity:
|
31
|
+
Max: 12
|
32
|
+
|
33
|
+
# Offense count: 1
|
34
|
+
Metrics/PerceivedComplexity:
|
35
|
+
Max: 12
|
36
|
+
|
37
|
+
# Offense count: 1
|
38
|
+
# Cop supports --auto-correct.
|
39
|
+
Style/Encoding:
|
40
|
+
Exclude:
|
41
|
+
- 'svmkit.gemspec'
|
42
|
+
|
43
|
+
# Offense count: 1
|
44
|
+
# Cop supports --auto-correct.
|
45
|
+
# Configuration parameters: EnforcedStyle, UseHashRocketsWithSymbolValues, PreferHashRocketsForNonAlnumEndingSymbols.
|
46
|
+
# SupportedStyles: ruby19, hash_rockets, no_mixed_keys, ruby19_no_mixed_keys
|
47
|
+
Style/HashSyntax:
|
48
|
+
Exclude:
|
49
|
+
- 'Rakefile'
|
50
|
+
|
51
|
+
# Offense count: 6
|
52
|
+
# Cop supports --auto-correct.
|
53
|
+
# Configuration parameters: EnforcedStyle, ConsistentQuotesInMultiline.
|
54
|
+
# SupportedStyles: single_quotes, double_quotes
|
55
|
+
Style/StringLiterals:
|
56
|
+
Exclude:
|
57
|
+
- 'Gemfile'
|
58
|
+
- 'Rakefile'
|
59
|
+
- 'bin/console'
|
data/Gemfile
CHANGED
data/HISTORY.md
CHANGED
@@ -1,59 +1,69 @@
|
|
1
|
+
# 0.2.9
|
2
|
+
- Add predict_proba method to SVC and KernelSVC.
|
3
|
+
- Add class for evaluating logarithmic loss.
|
4
|
+
- Add classes for Label- and One-Hot- encoding.
|
5
|
+
- Add some validator.
|
6
|
+
- Fix bug on training data score calculation of cross validation.
|
7
|
+
- Fix fit method of SVC for performance.
|
8
|
+
- Fix criterion calculation on Decision Tree for performance.
|
9
|
+
- Fix data structure of Decision Tree for performance.
|
10
|
+
|
1
11
|
# 0.2.8
|
2
|
-
-
|
3
|
-
-
|
4
|
-
-
|
12
|
+
- Fix bug on gradient calculation of Logistic Regression.
|
13
|
+
- Fix to change accessor of params of estimators to read only.
|
14
|
+
- Add parameter validation.
|
5
15
|
|
6
16
|
# 0.2.7
|
7
|
-
-
|
17
|
+
- Fix to support multiclass classifiction into LinearSVC, LogisticRegression, KernelSVC, and FactorizationMachineClassifier
|
8
18
|
|
9
19
|
# 0.2.6
|
10
|
-
-
|
11
|
-
-
|
12
|
-
-
|
13
|
-
-
|
14
|
-
-
|
20
|
+
- Add class for Decision Tree classifier.
|
21
|
+
- Add class for Random Forest classifier.
|
22
|
+
- Fix to use frozen string literal.
|
23
|
+
- Refactor marshal dump method on some classes.
|
24
|
+
- Introduce Coveralls to confirm test coverage.
|
15
25
|
|
16
26
|
# 0.2.5
|
17
|
-
-
|
18
|
-
-
|
19
|
-
-
|
27
|
+
- Add classes for Naive Bayes classifier.
|
28
|
+
- Fix decision function method on Logistic Regression class.
|
29
|
+
- Fix method visibility on RBF kernel approximation class.
|
20
30
|
|
21
31
|
# 0.2.4
|
22
|
-
-
|
23
|
-
-
|
24
|
-
-
|
32
|
+
- Add class for Factorization Machine classifier.
|
33
|
+
- Add classes for evaluation measures.
|
34
|
+
- Fix the method for prediction of class probability in Logistic Regression.
|
25
35
|
|
26
36
|
# 0.2.3
|
27
|
-
-
|
28
|
-
-
|
29
|
-
-
|
37
|
+
- Add class for cross validation.
|
38
|
+
- Add specs for base modules.
|
39
|
+
- Fix validation of the number of splits when a negative label is given.
|
30
40
|
|
31
41
|
# 0.2.2
|
32
|
-
-
|
42
|
+
- Add data splitter classes for K-fold cross validation.
|
33
43
|
|
34
44
|
# 0.2.1
|
35
|
-
-
|
45
|
+
- Add class for K-nearest neighbors classifier.
|
36
46
|
|
37
47
|
# 0.2.0
|
38
48
|
- Migrated the linear algebra library to Numo::NArray.
|
39
|
-
-
|
49
|
+
- Add module for loading and saving libsvm format file.
|
40
50
|
|
41
51
|
# 0.1.3
|
42
|
-
-
|
43
|
-
-
|
52
|
+
- Add class for Kernel Support Vector Machine with Pegasos algorithm.
|
53
|
+
- Add module for calculating pairwise kernel fuctions and euclidean distances.
|
44
54
|
|
45
55
|
# 0.1.2
|
46
|
-
-
|
47
|
-
-
|
56
|
+
- Add the function learning a model with bias term to the PegasosSVC and LogisticRegression classes.
|
57
|
+
- Rewrite the document with yard notation.
|
48
58
|
|
49
59
|
# 0.1.1
|
50
|
-
-
|
51
|
-
-
|
60
|
+
- Add class for Logistic Regression with SGD optimization.
|
61
|
+
- Fix some mistakes on the document.
|
52
62
|
|
53
63
|
# 0.1.0
|
54
|
-
-
|
55
|
-
-
|
56
|
-
-
|
57
|
-
-
|
58
|
-
-
|
59
|
-
-
|
64
|
+
- Add basic classes.
|
65
|
+
- Add an utility module.
|
66
|
+
- Add class for RBF kernel approximation.
|
67
|
+
- Add class for Support Vector Machine with Pegasos alogrithm.
|
68
|
+
- Add class that performs mutlclass classification with one-vs.-rest strategy.
|
69
|
+
- Add classes for preprocessing such as min-max scaling, standardization, and L2 normalization.
|
data/lib/svmkit.rb
CHANGED
@@ -6,6 +6,7 @@ require 'svmkit/version'
|
|
6
6
|
require 'svmkit/validation'
|
7
7
|
require 'svmkit/pairwise_metric'
|
8
8
|
require 'svmkit/dataset'
|
9
|
+
require 'svmkit/probabilistic_output'
|
9
10
|
require 'svmkit/base/base_estimator'
|
10
11
|
require 'svmkit/base/classifier'
|
11
12
|
require 'svmkit/base/transformer'
|
@@ -24,6 +25,8 @@ require 'svmkit/ensemble/random_forest_classifier'
|
|
24
25
|
require 'svmkit/preprocessing/l2_normalizer'
|
25
26
|
require 'svmkit/preprocessing/min_max_scaler'
|
26
27
|
require 'svmkit/preprocessing/standard_scaler'
|
28
|
+
require 'svmkit/preprocessing/label_encoder'
|
29
|
+
require 'svmkit/preprocessing/one_hot_encoder'
|
27
30
|
require 'svmkit/model_selection/k_fold'
|
28
31
|
require 'svmkit/model_selection/stratified_k_fold'
|
29
32
|
require 'svmkit/model_selection/cross_validation'
|
@@ -31,3 +34,4 @@ require 'svmkit/evaluation_measure/accuracy'
|
|
31
34
|
require 'svmkit/evaluation_measure/precision'
|
32
35
|
require 'svmkit/evaluation_measure/recall'
|
33
36
|
require 'svmkit/evaluation_measure/f_score'
|
37
|
+
require 'svmkit/evaluation_measure/log_loss'
|
@@ -22,6 +22,7 @@ module SVMKit
|
|
22
22
|
def score(x, y)
|
23
23
|
SVMKit::Validation.check_sample_array(x)
|
24
24
|
SVMKit::Validation.check_label_array(y)
|
25
|
+
SVMKit::Validation.check_sample_label_size(x, y)
|
25
26
|
evaluator = SVMKit::EvaluationMeasure::Accuracy.new
|
26
27
|
evaluator.score(y, predict(x))
|
27
28
|
end
|
@@ -51,10 +51,12 @@ module SVMKit
|
|
51
51
|
def initialize(n_estimators: 10, criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1,
|
52
52
|
max_features: nil, random_seed: nil)
|
53
53
|
SVMKit::Validation.check_params_type_or_nil(Integer, max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
|
54
|
-
|
54
|
+
max_features: max_features, random_seed: random_seed)
|
55
55
|
SVMKit::Validation.check_params_integer(n_estimators: n_estimators, min_samples_leaf: min_samples_leaf)
|
56
56
|
SVMKit::Validation.check_params_string(criterion: criterion)
|
57
|
-
|
57
|
+
SVMKit::Validation.check_params_positive(n_estimators: n_estimators, max_depth: max_depth,
|
58
|
+
max_leaf_nodes: max_leaf_nodes, min_samples_leaf: min_samples_leaf,
|
59
|
+
max_features: max_features)
|
58
60
|
@params = {}
|
59
61
|
@params[:n_estimators] = n_estimators
|
60
62
|
@params[:criterion] = criterion
|
@@ -78,6 +80,7 @@ module SVMKit
|
|
78
80
|
def fit(x, y)
|
79
81
|
SVMKit::Validation.check_sample_array(x)
|
80
82
|
SVMKit::Validation.check_label_array(y)
|
83
|
+
SVMKit::Validation.check_sample_label_size(x, y)
|
81
84
|
# Initialize some variables.
|
82
85
|
n_samples, n_features = x.shape
|
83
86
|
@params[:max_features] = n_features unless @params[:max_features].is_a?(Integer)
|
@@ -0,0 +1,44 @@
|
|
1
|
+
# frozen_string_literal: true
|
2
|
+
|
3
|
+
require 'svmkit/base/evaluator'
|
4
|
+
|
5
|
+
module SVMKit
|
6
|
+
module EvaluationMeasure
|
7
|
+
# LogLoss is a class that calculates the logarithmic loss of predicted class probability.
|
8
|
+
#
|
9
|
+
# @example
|
10
|
+
# evaluator = SVMKit::EvaluationMeasure::LogLoss.new
|
11
|
+
# puts evaluator.score(ground_truth, predicted)
|
12
|
+
class LogLoss
|
13
|
+
include Base::Evaluator
|
14
|
+
|
15
|
+
# Claculate mean logarithmic loss.
|
16
|
+
# If both y_true and y_pred are array (both shapes are [n_samples]), this method calculates
|
17
|
+
# mean logarithmic loss for binary classification.
|
18
|
+
#
|
19
|
+
# @param y_true [Numo::Int32] (shape: [n_samples]) Ground truth labels.
|
20
|
+
# @param y_pred [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted class probability.
|
21
|
+
# @param eps [Float] A small value close to zero to avoid outputting infinity in logarithmic calcuation.
|
22
|
+
# @return [Float] mean logarithmic loss
|
23
|
+
def score(y_true, y_pred, eps = 1e-15)
|
24
|
+
SVMKit::Validation.check_params_type(Numo::Int32, y_true: y_true)
|
25
|
+
SVMKit::Validation.check_params_type(Numo::DFloat, y_pred: y_pred)
|
26
|
+
|
27
|
+
n_samples, n_classes = y_pred.shape
|
28
|
+
clipped_p = y_pred.clip(eps, 1 - eps)
|
29
|
+
|
30
|
+
log_loss = if n_classes.nil?
|
31
|
+
negative_label = y_true.to_a.uniq.sort.first
|
32
|
+
bin_y_true = Numo::DFloat.cast(y_true.ne(negative_label))
|
33
|
+
-(bin_y_true * Numo::NMath.log(clipped_p) + (1 - bin_y_true) * Numo::NMath.log(1 - clipped_p))
|
34
|
+
else
|
35
|
+
encoder = SVMKit::Preprocessing::OneHotEncoder.new
|
36
|
+
encoded_y_true = encoder.fit_transform(y_true)
|
37
|
+
clipped_p /= clipped_p.sum(1).expand_dims(1)
|
38
|
+
-(encoded_y_true * Numo::NMath.log(clipped_p)).sum(1)
|
39
|
+
end
|
40
|
+
log_loss.sum / n_samples
|
41
|
+
end
|
42
|
+
end
|
43
|
+
end
|
44
|
+
end
|
@@ -40,7 +40,7 @@ module SVMKit
|
|
40
40
|
SVMKit::Validation.check_params_float(gamma: gamma)
|
41
41
|
SVMKit::Validation.check_params_integer(n_components: n_components)
|
42
42
|
SVMKit::Validation.check_params_type_or_nil(Integer, random_seed: random_seed)
|
43
|
-
|
43
|
+
SVMKit::Validation.check_params_positive(gamma: gamma, n_components: n_components)
|
44
44
|
@params = {}
|
45
45
|
@params[:gamma] = gamma
|
46
46
|
@params[:n_components] = n_components
|
@@ -40,18 +40,22 @@ module SVMKit
|
|
40
40
|
#
|
41
41
|
# @param reg_param [Float] The regularization parameter.
|
42
42
|
# @param max_iter [Integer] The maximum number of iterations.
|
43
|
+
# @param probability [Boolean] The flag indicating whether to perform probability estimation.
|
43
44
|
# @param random_seed [Integer] The seed value using to initialize the random generator.
|
44
|
-
def initialize(reg_param: 1.0, max_iter: 1000, random_seed: nil)
|
45
|
+
def initialize(reg_param: 1.0, max_iter: 1000, probability: false, random_seed: nil)
|
45
46
|
SVMKit::Validation.check_params_float(reg_param: reg_param)
|
46
47
|
SVMKit::Validation.check_params_integer(max_iter: max_iter)
|
48
|
+
SVMKit::Validation.check_params_boolean(probability: probability)
|
47
49
|
SVMKit::Validation.check_params_type_or_nil(Integer, random_seed: random_seed)
|
48
|
-
|
50
|
+
SVMKit::Validation.check_params_positive(reg_param: reg_param, max_iter: max_iter)
|
49
51
|
@params = {}
|
50
52
|
@params[:reg_param] = reg_param
|
51
53
|
@params[:max_iter] = max_iter
|
54
|
+
@params[:probability] = probability
|
52
55
|
@params[:random_seed] = random_seed
|
53
56
|
@params[:random_seed] ||= srand
|
54
57
|
@weight_vec = nil
|
58
|
+
@prob_param = nil
|
55
59
|
@classes = nil
|
56
60
|
@rng = Random.new(@params[:random_seed])
|
57
61
|
end
|
@@ -65,6 +69,7 @@ module SVMKit
|
|
65
69
|
def fit(x, y)
|
66
70
|
SVMKit::Validation.check_sample_array(x)
|
67
71
|
SVMKit::Validation.check_label_array(y)
|
72
|
+
SVMKit::Validation.check_sample_label_size(x, y)
|
68
73
|
|
69
74
|
@classes = Numo::Int32[*y.to_a.uniq.sort]
|
70
75
|
n_classes = @classes.size
|
@@ -72,14 +77,25 @@ module SVMKit
|
|
72
77
|
|
73
78
|
if n_classes > 2
|
74
79
|
@weight_vec = Numo::DFloat.zeros(n_classes, n_features)
|
80
|
+
@prob_param = Numo::DFloat.zeros(n_classes, 2)
|
75
81
|
n_classes.times do |n|
|
76
82
|
bin_y = Numo::Int32.cast(y.eq(@classes[n])) * 2 - 1
|
77
83
|
@weight_vec[n, true] = binary_fit(x, bin_y)
|
84
|
+
@prob_param[n, true] = if @params[:probability]
|
85
|
+
SVMKit::ProbabilisticOutput.fit_sigmoid(x.dot(@weight_vec[n, true].transpose), bin_y)
|
86
|
+
else
|
87
|
+
Numo::DFloat[1, 0]
|
88
|
+
end
|
78
89
|
end
|
79
90
|
else
|
80
91
|
negative_label = y.to_a.uniq.sort.first
|
81
92
|
bin_y = Numo::Int32.cast(y.ne(negative_label)) * 2 - 1
|
82
93
|
@weight_vec = binary_fit(x, bin_y)
|
94
|
+
@prob_param = if @params[:probability]
|
95
|
+
SVMKit::ProbabilisticOutput.fit_sigmoid(x.dot(@weight_vec.transpose), bin_y)
|
96
|
+
else
|
97
|
+
Numo::DFloat[1, 0]
|
98
|
+
end
|
83
99
|
end
|
84
100
|
|
85
101
|
self
|
@@ -111,11 +127,32 @@ module SVMKit
|
|
111
127
|
Numo::Int32.asarray(Array.new(n_samples) { |n| @classes[decision_values[n, true].max_index] })
|
112
128
|
end
|
113
129
|
|
130
|
+
# Predict probability for samples.
|
131
|
+
#
|
132
|
+
# @param x [Numo::DFloat] (shape: [n_testing_samples, n_training_samples])
|
133
|
+
# The kernel matrix between testing samples and training samples to predict the labels.
|
134
|
+
# @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
|
135
|
+
def predict_proba(x)
|
136
|
+
SVMKit::Validation.check_sample_array(x)
|
137
|
+
|
138
|
+
if @classes.size > 2
|
139
|
+
probs = 1.0 / (Numo::NMath.exp(@prob_param[true, 0] * decision_function(x) + @prob_param[true, 1]) + 1.0)
|
140
|
+
return (probs.transpose / probs.sum(axis: 1)).transpose
|
141
|
+
end
|
142
|
+
|
143
|
+
n_samples, = x.shape
|
144
|
+
probs = Numo::DFloat.zeros(n_samples, 2)
|
145
|
+
probs[true, 1] = 1.0 / (Numo::NMath.exp(@prob_param[0] * decision_function(x) + @prob_param[1]) + 1.0)
|
146
|
+
probs[true, 0] = 1.0 - probs[true, 1]
|
147
|
+
probs
|
148
|
+
end
|
149
|
+
|
114
150
|
# Dump marshal data.
|
115
151
|
# @return [Hash] The marshal data about KernelSVC.
|
116
152
|
def marshal_dump
|
117
153
|
{ params: @params,
|
118
154
|
weight_vec: @weight_vec,
|
155
|
+
prob_param: @prob_param,
|
119
156
|
classes: @classes,
|
120
157
|
rng: @rng }
|
121
158
|
end
|
@@ -125,6 +162,7 @@ module SVMKit
|
|
125
162
|
def marshal_load(obj)
|
126
163
|
@params = obj[:params]
|
127
164
|
@weight_vec = obj[:weight_vec]
|
165
|
+
@prob_param = obj[:prob_param]
|
128
166
|
@classes = obj[:classes]
|
129
167
|
@rng = obj[:rng]
|
130
168
|
nil
|
@@ -54,7 +54,8 @@ module SVMKit
|
|
54
54
|
SVMKit::Validation.check_params_integer(max_iter: max_iter, batch_size: batch_size)
|
55
55
|
SVMKit::Validation.check_params_boolean(fit_bias: fit_bias, normalize: normalize)
|
56
56
|
SVMKit::Validation.check_params_type_or_nil(Integer, random_seed: random_seed)
|
57
|
-
|
57
|
+
SVMKit::Validation.check_params_positive(reg_param: reg_param, bias_scale: bias_scale, max_iter: max_iter,
|
58
|
+
batch_size: batch_size)
|
58
59
|
@params = {}
|
59
60
|
@params[:reg_param] = reg_param
|
60
61
|
@params[:fit_bias] = fit_bias
|
@@ -78,6 +79,7 @@ module SVMKit
|
|
78
79
|
def fit(x, y)
|
79
80
|
SVMKit::Validation.check_sample_array(x)
|
80
81
|
SVMKit::Validation.check_label_array(y)
|
82
|
+
SVMKit::Validation.check_sample_label_size(x, y)
|
81
83
|
|
82
84
|
@classes = Numo::Int32[*y.to_a.uniq.sort]
|
83
85
|
n_classes = @classes.size
|