rumale-svm 0.6.0 → 0.7.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: c8cb843c4d4245f90f57ee22722e8e899e8e5077e29773629001dccf9e243010
4
- data.tar.gz: b26af5c7c940562086ee2f7859315779db0e2be372c47da4e518b6fa77099d35
3
+ metadata.gz: ec0ef314d44123fa222655004b099652d244c7fb9bea73569a4fe701974e43b1
4
+ data.tar.gz: ccaca04e1b3fe7691467a26a678f8978bc2f77e5a1b2e094b1af08923779588b
5
5
  SHA512:
6
- metadata.gz: 0ca52d3c96333ee5ee60ec43b36c9d7a80687d9a3241b29d0872cbc1569eb45e22966f52cbde34de103c43c46d88d64a8685703db58c58c9828db2d4daf8b9c7
7
- data.tar.gz: ff2a02ec2891dcf3626af1033dd3b181b697e8a43155bb6923e3cda31ef2116581b77e1201e934d152d6667a7b51ccf943a498cccf9fb8617a438899a99b5cdf
6
+ metadata.gz: 44b1ab01d131133ef5d9fb32ae4dfaff5ca1187e45c5a1b4f4ec3b198f515020baf85b5bb4a1bb110ae6e85d74f521af86661b4741a63828bd49288a138c9513
7
+ data.tar.gz: af83ab927ed5fbd630880b51fa2a0bf25c9aaad066a603d1ce6bf178c873ed669461eb46cdb23003b888799f11a5114fdf5405288d8129783c9495ec74799fd5
data/CHANGELOG.md CHANGED
@@ -1,3 +1,7 @@
1
+ # 0.7.0
2
+ - Support for probabilistic outputs with Rumale::SVM::OneClassSVM.
3
+ - Update numo-libsvm dependency to v2.1 or higher.
4
+
1
5
  # 0.6.0
2
6
  - Update numo-libsvm and numo-liblinear dependencies to v2.0 or higher.
3
7
 
@@ -24,15 +24,16 @@ module Rumale
24
24
  # @param gamma [Float] The gamma parameter in rbf/poly/sigmoid kernel function.
25
25
  # @param coef0 [Float] The coefficient in poly/sigmoid kernel function.
26
26
  # @param shrinking [Boolean] The flag indicating whether to use the shrinking heuristics.
27
+ # @param probability [Boolean] The flag indicating whether to train the parameter for probability estimation.
27
28
  # @param cache_size [Float] The cache memory size in MB.
28
29
  # @param tol [Float] The tolerance of termination criterion.
29
30
  # @param verbose [Boolean] The flag indicating whether to output learning process message
30
31
  # @param random_seed [Integer/Nil] The seed value using to initialize the random generator.
31
32
  def initialize(nu: 1.0, kernel: 'rbf', degree: 3, gamma: 1.0, coef0: 0.0,
32
- shrinking: true, cache_size: 200.0, tol: 1e-3, verbose: false, random_seed: nil)
33
+ shrinking: true, probability: true, cache_size: 200.0, tol: 1e-3, verbose: false, random_seed: nil)
33
34
  check_params_numeric(nu: nu, degree: degree, gamma: gamma, coef0: coef0, cache_size: cache_size, tol: tol)
34
35
  check_params_string(kernel: kernel)
35
- check_params_boolean(shrinking: shrinking, verbose: verbose)
36
+ check_params_boolean(shrinking: shrinking, probability: probability, verbose: verbose)
36
37
  check_params_numeric_or_nil(random_seed: random_seed)
37
38
  @params = {}
38
39
  @params[:nu] = nu.to_f
@@ -41,6 +42,7 @@ module Rumale
41
42
  @params[:gamma] = gamma.to_f
42
43
  @params[:coef0] = coef0.to_f
43
44
  @params[:shrinking] = shrinking
45
+ @params[:probability] = probability
44
46
  @params[:cache_size] = cache_size.to_f
45
47
  @params[:tol] = tol.to_f
46
48
  @params[:verbose] = verbose
@@ -82,6 +84,19 @@ module Rumale
82
84
  Numo::Int32.cast(Numo::Libsvm.predict(x, libsvm_params, @model))
83
85
  end
84
86
 
87
+ # Predict class probability for samples.
88
+ # This method works correctly only if the probability parameter is true.
89
+ #
90
+ # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
91
+ # If the kernel is 'precomputed', the shape of x must be [n_samples, n_training_samples].
92
+ # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
93
+ def predict_proba(x)
94
+ raise "#{self.class.name}\##{__method__} expects to be called after training the model with the fit method." unless trained?
95
+ raise "#{self.class.name}\##{__method__} expects to be called after training the probablity parameters." unless trained_probs?
96
+ x = check_convert_sample_array(x)
97
+ Numo::Libsvm.predict_proba(x, libsvm_params, @model)
98
+ end
99
+
85
100
  # Dump marshal data.
86
101
  # @return [Hash] The marshal data about SVC.
87
102
  def marshal_dump
@@ -150,6 +165,10 @@ module Rumale
150
165
  def trained?
151
166
  !@model.nil?
152
167
  end
168
+
169
+ def trained_probs?
170
+ @model[:prob_density_marks].is_a?(Numo::NArray)
171
+ end
153
172
  end
154
173
  end
155
174
  end
@@ -5,6 +5,6 @@ module Rumale
5
5
  # This module consists of Rumale interfaces for support vector machine algorithms with LIBSVM and LIBLINEAR.
6
6
  module SVM
7
7
  # The version of Rumale::SVM you are using.
8
- VERSION = '0.6.0'
8
+ VERSION = '0.7.0'
9
9
  end
10
10
  end
@@ -10,6 +10,7 @@ module Rumale
10
10
  def fit: (Numo::DFloat x, ?untyped? _y) -> OneClassSVM
11
11
  def decision_function: (Numo::DFloat x) -> Numo::DFloat
12
12
  def predict: (Numo::DFloat x) -> Numo::Int32
13
+ def predict_proba: (Numo::DFloat x) -> Numo::DFloat
13
14
  def marshal_dump: () -> { params: Hash[Symbol, untyped], model: untyped }
14
15
  def marshal_load: (Hash[Symbol, untyped] obj) -> void
15
16
  def support: () -> Numo::Int32
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: rumale-svm
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.6.0
4
+ version: 0.7.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - yoshoku
8
8
  autorequire:
9
9
  bindir: exe
10
10
  cert_chain: []
11
- date: 2022-05-15 00:00:00.000000000 Z
11
+ date: 2022-10-15 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-liblinear
@@ -30,14 +30,14 @@ dependencies:
30
30
  requirements:
31
31
  - - "~>"
32
32
  - !ruby/object:Gem::Version
33
- version: '2.0'
33
+ version: '2.1'
34
34
  type: :runtime
35
35
  prerelease: false
36
36
  version_requirements: !ruby/object:Gem::Requirement
37
37
  requirements:
38
38
  - - "~>"
39
39
  - !ruby/object:Gem::Version
40
- version: '2.0'
40
+ version: '2.1'
41
41
  - !ruby/object:Gem::Dependency
42
42
  name: rumale
43
43
  requirement: !ruby/object:Gem::Requirement
@@ -116,7 +116,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
116
116
  - !ruby/object:Gem::Version
117
117
  version: '0'
118
118
  requirements: []
119
- rubygems_version: 3.2.33
119
+ rubygems_version: 3.3.7
120
120
  signing_key:
121
121
  specification_version: 4
122
122
  summary: Rumale::SVM provides support vector machine algorithms of LIBSVM and LIBLINEAR