rumale-svm 0.5.1 → 0.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: f0000e7a25f4aaf36f645a7860d13bed3ca06ea7d99bc15469cd7c9ea0a57288
4
- data.tar.gz: 0db6f53cabe1a84493d20bf9e0822057dca4ec92042bdefb6fda15f794aba9aa
3
+ metadata.gz: ec0ef314d44123fa222655004b099652d244c7fb9bea73569a4fe701974e43b1
4
+ data.tar.gz: ccaca04e1b3fe7691467a26a678f8978bc2f77e5a1b2e094b1af08923779588b
5
5
  SHA512:
6
- metadata.gz: af0433aacb57daa0b9bf81dfe961f71ef71751215749201c18c1c81aa592d55050e4ed6d65605809ff6c36153cbbeb45e200181675dec23e13f884b70185ccab
7
- data.tar.gz: c97c4a38ccdfabd869e6f5196eae73f32371d7bd12ff8e9612ede2cbfe89bb3bc563505429590ef65d928ba84d76cc06cf9c5f9eee0e79b161cbfd470da80cc0
6
+ metadata.gz: 44b1ab01d131133ef5d9fb32ae4dfaff5ca1187e45c5a1b4f4ec3b198f515020baf85b5bb4a1bb110ae6e85d74f521af86661b4741a63828bd49288a138c9513
7
+ data.tar.gz: af83ab927ed5fbd630880b51fa2a0bf25c9aaad066a603d1ce6bf178c873ed669461eb46cdb23003b888799f11a5114fdf5405288d8129783c9495ec74799fd5
data/CHANGELOG.md CHANGED
@@ -1,3 +1,10 @@
1
+ # 0.7.0
2
+ - Support for probabilistic outputs with Rumale::SVM::OneClassSVM.
3
+ - Update numo-libsvm dependency to v2.1 or higher.
4
+
5
+ # 0.6.0
6
+ - Update numo-libsvm and numo-liblinear dependencies to v2.0 or higher.
7
+
1
8
  # 0.5.1
2
9
  - Refactor specs and config files.
3
10
 
@@ -24,15 +24,16 @@ module Rumale
24
24
  # @param gamma [Float] The gamma parameter in rbf/poly/sigmoid kernel function.
25
25
  # @param coef0 [Float] The coefficient in poly/sigmoid kernel function.
26
26
  # @param shrinking [Boolean] The flag indicating whether to use the shrinking heuristics.
27
+ # @param probability [Boolean] The flag indicating whether to train the parameter for probability estimation.
27
28
  # @param cache_size [Float] The cache memory size in MB.
28
29
  # @param tol [Float] The tolerance of termination criterion.
29
30
  # @param verbose [Boolean] The flag indicating whether to output learning process message
30
31
  # @param random_seed [Integer/Nil] The seed value using to initialize the random generator.
31
32
  def initialize(nu: 1.0, kernel: 'rbf', degree: 3, gamma: 1.0, coef0: 0.0,
32
- shrinking: true, cache_size: 200.0, tol: 1e-3, verbose: false, random_seed: nil)
33
+ shrinking: true, probability: true, cache_size: 200.0, tol: 1e-3, verbose: false, random_seed: nil)
33
34
  check_params_numeric(nu: nu, degree: degree, gamma: gamma, coef0: coef0, cache_size: cache_size, tol: tol)
34
35
  check_params_string(kernel: kernel)
35
- check_params_boolean(shrinking: shrinking, verbose: verbose)
36
+ check_params_boolean(shrinking: shrinking, probability: probability, verbose: verbose)
36
37
  check_params_numeric_or_nil(random_seed: random_seed)
37
38
  @params = {}
38
39
  @params[:nu] = nu.to_f
@@ -41,6 +42,7 @@ module Rumale
41
42
  @params[:gamma] = gamma.to_f
42
43
  @params[:coef0] = coef0.to_f
43
44
  @params[:shrinking] = shrinking
45
+ @params[:probability] = probability
44
46
  @params[:cache_size] = cache_size.to_f
45
47
  @params[:tol] = tol.to_f
46
48
  @params[:verbose] = verbose
@@ -82,6 +84,19 @@ module Rumale
82
84
  Numo::Int32.cast(Numo::Libsvm.predict(x, libsvm_params, @model))
83
85
  end
84
86
 
87
+ # Predict class probability for samples.
88
+ # This method works correctly only if the probability parameter is true.
89
+ #
90
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
91
+ # If the kernel is 'precomputed', the shape of x must be [n_samples, n_training_samples].
92
+ # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
93
+ def predict_proba(x)
94
+ raise "#{self.class.name}\##{__method__} expects to be called after training the model with the fit method." unless trained?
95
+ raise "#{self.class.name}\##{__method__} expects to be called after training the probablity parameters." unless trained_probs?
96
+ x = check_convert_sample_array(x)
97
+ Numo::Libsvm.predict_proba(x, libsvm_params, @model)
98
+ end
99
+
85
100
  # Dump marshal data.
86
101
  # @return [Hash] The marshal data about SVC.
87
102
  def marshal_dump
@@ -150,6 +165,10 @@ module Rumale
150
165
  def trained?
151
166
  !@model.nil?
152
167
  end
168
+
169
+ def trained_probs?
170
+ @model[:prob_density_marks].is_a?(Numo::NArray)
171
+ end
153
172
  end
154
173
  end
155
174
  end
@@ -4,7 +4,7 @@
4
4
  module Rumale
5
5
  # This module consists of Rumale interfaces for support vector machine algorithms with LIBSVM and LIBLINEAR.
6
6
  module SVM
7
- # The version of Rumale-SVM you are using.
8
- VERSION = '0.5.1'
7
+ # The version of Rumale::SVM you are using.
8
+ VERSION = '0.7.0'
9
9
  end
10
10
  end
@@ -10,6 +10,7 @@ module Rumale
10
10
  def fit: (Numo::DFloat x, ?untyped? _y) -> OneClassSVM
11
11
  def decision_function: (Numo::DFloat x) -> Numo::DFloat
12
12
  def predict: (Numo::DFloat x) -> Numo::Int32
13
+ def predict_proba: (Numo::DFloat x) -> Numo::DFloat
13
14
  def marshal_dump: () -> { params: Hash[Symbol, untyped], model: untyped }
14
15
  def marshal_load: (Hash[Symbol, untyped] obj) -> void
15
16
  def support: () -> Numo::Int32
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rumale-svm
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.5.1
4
+ version: 0.7.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2022-05-11 00:00:00.000000000 Z
11
+ date: 2022-10-15 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-liblinear
@@ -16,28 +16,28 @@ dependencies:
16
16
  requirements:
17
17
  - - "~>"
18
18
  - !ruby/object:Gem::Version
19
- version: '1.1'
19
+ version: '2.0'
20
20
  type: :runtime
21
21
  prerelease: false
22
22
  version_requirements: !ruby/object:Gem::Requirement
23
23
  requirements:
24
24
  - - "~>"
25
25
  - !ruby/object:Gem::Version
26
- version: '1.1'
26
+ version: '2.0'
27
27
  - !ruby/object:Gem::Dependency
28
28
  name: numo-libsvm
29
29
  requirement: !ruby/object:Gem::Requirement
30
30
  requirements:
31
31
  - - "~>"
32
32
  - !ruby/object:Gem::Version
33
- version: '1.0'
33
+ version: '2.1'
34
34
  type: :runtime
35
35
  prerelease: false
36
36
  version_requirements: !ruby/object:Gem::Requirement
37
37
  requirements:
38
38
  - - "~>"
39
39
  - !ruby/object:Gem::Version
40
- version: '1.0'
40
+ version: '2.1'
41
41
  - !ruby/object:Gem::Dependency
42
42
  name: rumale
43
43
  requirement: !ruby/object:Gem::Requirement
@@ -58,7 +58,7 @@ dependencies:
58
58
  - - "<"
59
59
  - !ruby/object:Gem::Version
60
60
  version: '0.24'
61
- description: 'Rumale-SVM provides support vector machine algorithms of LIBSVM and
61
+ description: 'Rumale::SVM provides support vector machine algorithms of LIBSVM and
62
62
  LIBLINEAR with Rumale interface.
63
63
 
64
64
  '
@@ -116,9 +116,9 @@ required_rubygems_version: !ruby/object:Gem::Requirement
116
116
  - !ruby/object:Gem::Version
117
117
  version: '0'
118
118
  requirements: []
119
- rubygems_version: 3.2.33
119
+ rubygems_version: 3.3.7
120
120
  signing_key:
121
121
  specification_version: 4
122
- summary: Rumale-SVM provides support vector machine algorithms of LIBSVM and LIBLINEAR
122
+ summary: Rumale::SVM provides support vector machine algorithms of LIBSVM and LIBLINEAR
123
123
  with Rumale interface.
124
124
  test_files: []