rumale-svm 0.5.1 → 0.7.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: f0000e7a25f4aaf36f645a7860d13bed3ca06ea7d99bc15469cd7c9ea0a57288
4
- data.tar.gz: 0db6f53cabe1a84493d20bf9e0822057dca4ec92042bdefb6fda15f794aba9aa
3
+ metadata.gz: ec0ef314d44123fa222655004b099652d244c7fb9bea73569a4fe701974e43b1
4
+ data.tar.gz: ccaca04e1b3fe7691467a26a678f8978bc2f77e5a1b2e094b1af08923779588b
5
5
  SHA512:
6
- metadata.gz: af0433aacb57daa0b9bf81dfe961f71ef71751215749201c18c1c81aa592d55050e4ed6d65605809ff6c36153cbbeb45e200181675dec23e13f884b70185ccab
7
- data.tar.gz: c97c4a38ccdfabd869e6f5196eae73f32371d7bd12ff8e9612ede2cbfe89bb3bc563505429590ef65d928ba84d76cc06cf9c5f9eee0e79b161cbfd470da80cc0
6
+ metadata.gz: 44b1ab01d131133ef5d9fb32ae4dfaff5ca1187e45c5a1b4f4ec3b198f515020baf85b5bb4a1bb110ae6e85d74f521af86661b4741a63828bd49288a138c9513
7
+ data.tar.gz: af83ab927ed5fbd630880b51fa2a0bf25c9aaad066a603d1ce6bf178c873ed669461eb46cdb23003b888799f11a5114fdf5405288d8129783c9495ec74799fd5
data/CHANGELOG.md CHANGED
@@ -1,3 +1,10 @@
1
+ # 0.7.0
2
+ - Support for probabilistic outputs with Rumale::SVM::OneClassSVM.
3
+ - Update numo-libsvm dependency to v2.1 or higher.
4
+
5
+ # 0.6.0
6
+ - Update numo-libsvm and numo-liblinear dependencies to v2.0 or higher.
7
+
1
8
  # 0.5.1
2
9
  - Refactor specs and config files.
3
10
 
@@ -24,15 +24,16 @@ module Rumale
24
24
  # @param gamma [Float] The gamma parameter in rbf/poly/sigmoid kernel function.
25
25
  # @param coef0 [Float] The coefficient in poly/sigmoid kernel function.
26
26
  # @param shrinking [Boolean] The flag indicating whether to use the shrinking heuristics.
27
+ # @param probability [Boolean] The flag indicating whether to train the parameter for probability estimation.
27
28
  # @param cache_size [Float] The cache memory size in MB.
28
29
  # @param tol [Float] The tolerance of termination criterion.
29
30
  # @param verbose [Boolean] The flag indicating whether to output learning process message
30
31
  # @param random_seed [Integer/Nil] The seed value used to initialize the random generator.
31
32
  def initialize(nu: 1.0, kernel: 'rbf', degree: 3, gamma: 1.0, coef0: 0.0,
32
- shrinking: true, cache_size: 200.0, tol: 1e-3, verbose: false, random_seed: nil)
33
+ shrinking: true, probability: true, cache_size: 200.0, tol: 1e-3, verbose: false, random_seed: nil)
33
34
  check_params_numeric(nu: nu, degree: degree, gamma: gamma, coef0: coef0, cache_size: cache_size, tol: tol)
34
35
  check_params_string(kernel: kernel)
35
- check_params_boolean(shrinking: shrinking, verbose: verbose)
36
+ check_params_boolean(shrinking: shrinking, probability: probability, verbose: verbose)
36
37
  check_params_numeric_or_nil(random_seed: random_seed)
37
38
  @params = {}
38
39
  @params[:nu] = nu.to_f
@@ -41,6 +42,7 @@ module Rumale
41
42
  @params[:gamma] = gamma.to_f
42
43
  @params[:coef0] = coef0.to_f
43
44
  @params[:shrinking] = shrinking
45
+ @params[:probability] = probability
44
46
  @params[:cache_size] = cache_size.to_f
45
47
  @params[:tol] = tol.to_f
46
48
  @params[:verbose] = verbose
@@ -82,6 +84,19 @@ module Rumale
82
84
  Numo::Int32.cast(Numo::Libsvm.predict(x, libsvm_params, @model))
83
85
  end
84
86
 
87
+ # Predict class probability for samples.
88
+ # This method works correctly only if the probability parameter is true.
89
+ #
90
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
91
+ # If the kernel is 'precomputed', the shape of x must be [n_samples, n_training_samples].
92
+ # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
93
+ def predict_proba(x)
94
+ raise "#{self.class.name}\##{__method__} expects to be called after training the model with the fit method." unless trained?
95
+ raise "#{self.class.name}\##{__method__} expects to be called after training the probablity parameters." unless trained_probs?
96
+ x = check_convert_sample_array(x)
97
+ Numo::Libsvm.predict_proba(x, libsvm_params, @model)
98
+ end
99
+
85
100
  # Dump marshal data.
86
101
  # @return [Hash] The marshal data about OneClassSVM.
87
102
  def marshal_dump
@@ -150,6 +165,10 @@ module Rumale
150
165
  def trained?
151
166
  !@model.nil?
152
167
  end
168
+
169
+ def trained_probs?
170
+ @model[:prob_density_marks].is_a?(Numo::NArray)
171
+ end
153
172
  end
154
173
  end
155
174
  end
@@ -4,7 +4,7 @@
4
4
  module Rumale
5
5
  # This module consists of Rumale interfaces for support vector machine algorithms with LIBSVM and LIBLINEAR.
6
6
  module SVM
7
- # The version of Rumale-SVM you are using.
8
- VERSION = '0.5.1'
7
+ # The version of Rumale::SVM you are using.
8
+ VERSION = '0.7.0'
9
9
  end
10
10
  end
@@ -10,6 +10,7 @@ module Rumale
10
10
  def fit: (Numo::DFloat x, ?untyped? _y) -> OneClassSVM
11
11
  def decision_function: (Numo::DFloat x) -> Numo::DFloat
12
12
  def predict: (Numo::DFloat x) -> Numo::Int32
13
+ def predict_proba: (Numo::DFloat x) -> Numo::DFloat
13
14
  def marshal_dump: () -> { params: Hash[Symbol, untyped], model: untyped }
14
15
  def marshal_load: (Hash[Symbol, untyped] obj) -> void
15
16
  def support: () -> Numo::Int32
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rumale-svm
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.5.1
4
+ version: 0.7.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2022-05-11 00:00:00.000000000 Z
11
+ date: 2022-10-15 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-liblinear
@@ -16,28 +16,28 @@ dependencies:
16
16
  requirements:
17
17
  - - "~>"
18
18
  - !ruby/object:Gem::Version
19
- version: '1.1'
19
+ version: '2.0'
20
20
  type: :runtime
21
21
  prerelease: false
22
22
  version_requirements: !ruby/object:Gem::Requirement
23
23
  requirements:
24
24
  - - "~>"
25
25
  - !ruby/object:Gem::Version
26
- version: '1.1'
26
+ version: '2.0'
27
27
  - !ruby/object:Gem::Dependency
28
28
  name: numo-libsvm
29
29
  requirement: !ruby/object:Gem::Requirement
30
30
  requirements:
31
31
  - - "~>"
32
32
  - !ruby/object:Gem::Version
33
- version: '1.0'
33
+ version: '2.1'
34
34
  type: :runtime
35
35
  prerelease: false
36
36
  version_requirements: !ruby/object:Gem::Requirement
37
37
  requirements:
38
38
  - - "~>"
39
39
  - !ruby/object:Gem::Version
40
- version: '1.0'
40
+ version: '2.1'
41
41
  - !ruby/object:Gem::Dependency
42
42
  name: rumale
43
43
  requirement: !ruby/object:Gem::Requirement
@@ -58,7 +58,7 @@ dependencies:
58
58
  - - "<"
59
59
  - !ruby/object:Gem::Version
60
60
  version: '0.24'
61
- description: 'Rumale-SVM provides support vector machine algorithms of LIBSVM and
61
+ description: 'Rumale::SVM provides support vector machine algorithms of LIBSVM and
62
62
  LIBLINEAR with Rumale interface.
63
63
 
64
64
  '
@@ -116,9 +116,9 @@ required_rubygems_version: !ruby/object:Gem::Requirement
116
116
  - !ruby/object:Gem::Version
117
117
  version: '0'
118
118
  requirements: []
119
- rubygems_version: 3.2.33
119
+ rubygems_version: 3.3.7
120
120
  signing_key:
121
121
  specification_version: 4
122
- summary: Rumale-SVM provides support vector machine algorithms of LIBSVM and LIBLINEAR
122
+ summary: Rumale::SVM provides support vector machine algorithms of LIBSVM and LIBLINEAR
123
123
  with Rumale interface.
124
124
  test_files: []