rumale-svm 0.6.0 → 0.7.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/lib/rumale/svm/one_class_svm.rb +21 -2
- data/lib/rumale/svm/version.rb +1 -1
- data/sig/rumale/svm/one_class_svm.rbs +1 -0
- metadata +5 -5
    
        checksums.yaml
    CHANGED
    
    | @@ -1,7 +1,7 @@ | |
| 1 1 | 
             
            ---
         | 
| 2 2 | 
             
            SHA256:
         | 
| 3 | 
            -
              metadata.gz:  | 
| 4 | 
            -
              data.tar.gz:  | 
| 3 | 
            +
              metadata.gz: ec0ef314d44123fa222655004b099652d244c7fb9bea73569a4fe701974e43b1
         | 
| 4 | 
            +
              data.tar.gz: ccaca04e1b3fe7691467a26a678f8978bc2f77e5a1b2e094b1af08923779588b
         | 
| 5 5 | 
             
            SHA512:
         | 
| 6 | 
            -
              metadata.gz:  | 
| 7 | 
            -
              data.tar.gz:  | 
| 6 | 
            +
              metadata.gz: 44b1ab01d131133ef5d9fb32ae4dfaff5ca1187e45c5a1b4f4ec3b198f515020baf85b5bb4a1bb110ae6e85d74f521af86661b4741a63828bd49288a138c9513
         | 
| 7 | 
            +
              data.tar.gz: af83ab927ed5fbd630880b51fa2a0bf25c9aaad066a603d1ce6bf178c873ed669461eb46cdb23003b888799f11a5114fdf5405288d8129783c9495ec74799fd5
         | 
    
        data/lib/rumale/svm/one_class_svm.rb
    CHANGED
    
    
| @@ -24,15 +24,16 @@ module Rumale | |
| 24 24 | 
             
                  # @param gamma [Float] The gamma parameter in rbf/poly/sigmoid kernel function.
         | 
| 25 25 | 
             
                  # @param coef0 [Float] The coefficient in poly/sigmoid kernel function.
         | 
| 26 26 | 
             
                  # @param shrinking [Boolean] The flag indicating whether to use the shrinking heuristics.
         | 
| 27 | 
            +
                  # @param probability [Boolean] The flag indicating whether to train the parameter for probability estimation.
         | 
| 27 28 | 
             
                  # @param cache_size [Float] The cache memory size in MB.
         | 
| 28 29 | 
             
                  # @param tol [Float] The tolerance of termination criterion.
         | 
| 29 30 | 
             
                  # @param verbose [Boolean] The flag indicating whether to output learning process message
         | 
| 30 31 | 
             
                  # @param random_seed [Integer/Nil] The seed value using to initialize the random generator.
         | 
| 31 32 | 
             
                  def initialize(nu: 1.0, kernel: 'rbf', degree: 3, gamma: 1.0, coef0: 0.0,
         | 
| 32 | 
            -
                                 shrinking: true, cache_size: 200.0, tol: 1e-3, verbose: false, random_seed: nil)
         | 
| 33 | 
            +
                                 shrinking: true, probability: true, cache_size: 200.0, tol: 1e-3, verbose: false, random_seed: nil)
         | 
| 33 34 | 
             
                    check_params_numeric(nu: nu, degree: degree, gamma: gamma, coef0: coef0, cache_size: cache_size, tol: tol)
         | 
| 34 35 | 
             
                    check_params_string(kernel: kernel)
         | 
| 35 | 
            -
                    check_params_boolean(shrinking: shrinking, verbose: verbose)
         | 
| 36 | 
            +
                    check_params_boolean(shrinking: shrinking, probability: probability, verbose: verbose)
         | 
| 36 37 | 
             
                    check_params_numeric_or_nil(random_seed: random_seed)
         | 
| 37 38 | 
             
                    @params = {}
         | 
| 38 39 | 
             
                    @params[:nu] = nu.to_f
         | 
| @@ -41,6 +42,7 @@ module Rumale | |
| 41 42 | 
             
                    @params[:gamma] = gamma.to_f
         | 
| 42 43 | 
             
                    @params[:coef0] = coef0.to_f
         | 
| 43 44 | 
             
                    @params[:shrinking] = shrinking
         | 
| 45 | 
            +
                    @params[:probability] = probability
         | 
| 44 46 | 
             
                    @params[:cache_size] = cache_size.to_f
         | 
| 45 47 | 
             
                    @params[:tol] = tol.to_f
         | 
| 46 48 | 
             
                    @params[:verbose] = verbose
         | 
| @@ -82,6 +84,19 @@ module Rumale | |
| 82 84 | 
             
                    Numo::Int32.cast(Numo::Libsvm.predict(x, libsvm_params, @model))
         | 
| 83 85 | 
             
                  end
         | 
| 84 86 |  | 
| 87 | 
            +
                  # Predict class probability for samples.
         | 
| 88 | 
            +
                  # This method works correctly only if the probability parameter is true.
         | 
| 89 | 
            +
                  #
         | 
| 90 | 
            +
                  # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probailities.
         | 
| 91 | 
            +
                  #   If the kernel is 'precomputed', the shape of x must be [n_samples, n_training_samples].
         | 
| 92 | 
            +
                  # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
         | 
| 93 | 
            +
                  def predict_proba(x)
         | 
| 94 | 
            +
                    raise "#{self.class.name}\##{__method__} expects to be called after training the model with the fit method." unless trained?
         | 
| 95 | 
            +
                    raise "#{self.class.name}\##{__method__} expects to be called after training the probablity parameters." unless trained_probs?
         | 
| 96 | 
            +
                    x = check_convert_sample_array(x)
         | 
| 97 | 
            +
                    Numo::Libsvm.predict_proba(x, libsvm_params, @model)
         | 
| 98 | 
            +
                  end
         | 
| 99 | 
            +
             | 
| 85 100 | 
             
                  # Dump marshal data.
         | 
| 86 101 | 
             
                  # @return [Hash] The marshal data about SVC.
         | 
| 87 102 | 
             
                  def marshal_dump
         | 
| @@ -150,6 +165,10 @@ module Rumale | |
| 150 165 | 
             
                  def trained?
         | 
| 151 166 | 
             
                    !@model.nil?
         | 
| 152 167 | 
             
                  end
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                  def trained_probs?
         | 
| 170 | 
            +
                    @model[:prob_density_marks].is_a?(Numo::NArray)
         | 
| 171 | 
            +
                  end
         | 
| 153 172 | 
             
                end
         | 
| 154 173 | 
             
              end
         | 
| 155 174 | 
             
            end
         | 
    
        data/sig/rumale/svm/one_class_svm.rbs
    CHANGED
    
    
| @@ -10,6 +10,7 @@ module Rumale | |
| 10 10 | 
             
                  def fit: (Numo::DFloat x, ?untyped? _y) -> OneClassSVM
         | 
| 11 11 | 
             
                  def decision_function: (Numo::DFloat x) -> Numo::DFloat
         | 
| 12 12 | 
             
                  def predict: (Numo::DFloat x) -> Numo::Int32
         | 
| 13 | 
            +
                  def predict_proba: (Numo::DFloat x) -> Numo::DFloat
         | 
| 13 14 | 
             
                  def marshal_dump: () -> { params: Hash[Symbol, untyped], model: untyped }
         | 
| 14 15 | 
             
                  def marshal_load: (Hash[Symbol, untyped] obj) -> void
         | 
| 15 16 | 
             
                  def support: () -> Numo::Int32
         | 
    
        metadata
    CHANGED
    
    | @@ -1,14 +1,14 @@ | |
| 1 1 | 
             
            --- !ruby/object:Gem::Specification
         | 
| 2 2 | 
             
            name: rumale-svm
         | 
| 3 3 | 
             
            version: !ruby/object:Gem::Version
         | 
| 4 | 
            -
              version: 0.6.0
| 4 | 
            +
              version: 0.7.0
         | 
| 5 5 | 
             
            platform: ruby
         | 
| 6 6 | 
             
            authors:
         | 
| 7 7 | 
             
            - yoshoku
         | 
| 8 8 | 
             
            autorequire:
         | 
| 9 9 | 
             
            bindir: exe
         | 
| 10 10 | 
             
            cert_chain: []
         | 
| 11 | 
            -
            date: 2022- | 
| 11 | 
            +
            date: 2022-10-15 00:00:00.000000000 Z
         | 
| 12 12 | 
             
            dependencies:
         | 
| 13 13 | 
             
            - !ruby/object:Gem::Dependency
         | 
| 14 14 | 
             
              name: numo-liblinear
         | 
| @@ -30,14 +30,14 @@ dependencies: | |
| 30 30 | 
             
                requirements:
         | 
| 31 31 | 
             
                - - "~>"
         | 
| 32 32 | 
             
                  - !ruby/object:Gem::Version
         | 
| 33 | 
            -
                    version: '2. | 
| 33 | 
            +
                    version: '2.1'
         | 
| 34 34 | 
             
              type: :runtime
         | 
| 35 35 | 
             
              prerelease: false
         | 
| 36 36 | 
             
              version_requirements: !ruby/object:Gem::Requirement
         | 
| 37 37 | 
             
                requirements:
         | 
| 38 38 | 
             
                - - "~>"
         | 
| 39 39 | 
             
                  - !ruby/object:Gem::Version
         | 
| 40 | 
            -
                    version: '2. | 
| 40 | 
            +
                    version: '2.1'
         | 
| 41 41 | 
             
            - !ruby/object:Gem::Dependency
         | 
| 42 42 | 
             
              name: rumale
         | 
| 43 43 | 
             
              requirement: !ruby/object:Gem::Requirement
         | 
| @@ -116,7 +116,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement | |
| 116 116 | 
             
                - !ruby/object:Gem::Version
         | 
| 117 117 | 
             
                  version: '0'
         | 
| 118 118 | 
             
            requirements: []
         | 
| 119 | 
            -
            rubygems_version: 3. | 
| 119 | 
            +
            rubygems_version: 3.3.7
         | 
| 120 120 | 
             
            signing_key:
         | 
| 121 121 | 
             
            specification_version: 4
         | 
| 122 122 | 
             
            summary: Rumale::SVM provides support vector machine algorithms of LIBSVM and LIBLINEAR
         |