ruby-dnn 1.2.3 → 1.3.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/examples/dcgan/dcgan.rb +1 -1
- data/examples/iris_example.rb +17 -41
- data/examples/iris_example_unused_model.rb +57 -0
- data/examples/vae.rb +1 -1
- data/lib/dnn/core/callbacks.rb +18 -8
- data/lib/dnn/core/iterator.rb +20 -4
- data/lib/dnn/core/layers/rnn_layers.rb +20 -24
- data/lib/dnn/core/models.rb +474 -149
- data/lib/dnn/core/savers.rb +4 -12
- data/lib/dnn/core/utils.rb +14 -0
- data/lib/dnn/datasets/iris.rb +5 -1
- data/lib/dnn/version.rb +1 -1
- data/lib/dnn.rb +32 -26
- metadata +3 -2
    
        data/lib/dnn/core/savers.rb
    CHANGED
    
    | @@ -14,15 +14,13 @@ module DNN | |
| 14 14 | 
             
                    load_bin(File.binread(file_name))
         | 
| 15 15 | 
             
                  end
         | 
| 16 16 |  | 
| 17 | 
            -
                  private
         | 
| 18 | 
            -
             | 
| 19 17 | 
             
                  def load_bin(bin)
         | 
| 20 18 | 
             
                    raise NotImplementedError, "Class '#{self.class.name}' has implement method 'load_bin'"
         | 
| 21 19 | 
             
                  end
         | 
| 22 20 | 
             
                end
         | 
| 23 21 |  | 
| 24 22 | 
             
                class MarshalLoader < Loader
         | 
| 25 | 
            -
                   | 
| 23 | 
            +
                  def load_bin(bin)
         | 
| 26 24 | 
             
                    data = Marshal.load(Zlib::Inflate.inflate(bin))
         | 
| 27 25 | 
             
                    unless @model.class.name == data[:class]
         | 
| 28 26 | 
             
                      raise DNNError, "Class name is mismatch. Target model is #{@model.class.name}. But loading model is #{data[:class]}."
         | 
| @@ -38,8 +36,6 @@ module DNN | |
| 38 36 | 
             
                end
         | 
| 39 37 |  | 
| 40 38 | 
             
                class JSONLoader < Loader
         | 
| 41 | 
            -
                  private
         | 
| 42 | 
            -
             | 
| 43 39 | 
             
                  def load_bin(bin)
         | 
| 44 40 | 
             
                    data = JSON.parse(bin, symbolize_names: true)
         | 
| 45 41 | 
             
                    unless @model.class.name == data[:class]
         | 
| @@ -48,7 +44,7 @@ module DNN | |
| 48 44 | 
             
                    set_all_params_base64_data(data[:params])
         | 
| 49 45 | 
             
                  end
         | 
| 50 46 |  | 
| 51 | 
            -
                  def set_all_params_base64_data(params_data)
         | 
| 47 | 
            +
                  private def set_all_params_base64_data(params_data)
         | 
| 52 48 | 
             
                    @model.trainable_layers.each.with_index do |layer, i|
         | 
| 53 49 | 
             
                      params_data[i].each do |(key, (shape, base64_data))|
         | 
| 54 50 | 
             
                        bin = Base64.decode64(base64_data)
         | 
| @@ -79,8 +75,6 @@ module DNN | |
| 79 75 | 
             
                    end
         | 
| 80 76 | 
             
                  end
         | 
| 81 77 |  | 
| 82 | 
            -
                  private
         | 
| 83 | 
            -
             | 
| 84 78 | 
             
                  def dump_bin
         | 
| 85 79 | 
             
                    raise NotImplementedError, "Class '#{self.class.name}' has implement method 'dump_bin'"
         | 
| 86 80 | 
             
                  end
         | 
| @@ -92,7 +86,7 @@ module DNN | |
| 92 86 | 
             
                    @include_model = include_model
         | 
| 93 87 | 
             
                  end
         | 
| 94 88 |  | 
| 95 | 
            -
                   | 
| 89 | 
            +
                  def dump_bin
         | 
| 96 90 | 
             
                    params_data = @model.get_all_params_data
         | 
| 97 91 | 
             
                    if @include_model
         | 
| 98 92 | 
             
                      @model.clean_layers
         | 
| @@ -110,14 +104,12 @@ module DNN | |
| 110 104 | 
             
                end
         | 
| 111 105 |  | 
| 112 106 | 
             
                class JSONSaver < Saver
         | 
| 113 | 
            -
                  private
         | 
| 114 | 
            -
             | 
| 115 107 | 
             
                  def dump_bin
         | 
| 116 108 | 
             
                    data = { version: VERSION, class: @model.class.name, params: get_all_params_base64_data }
         | 
| 117 109 | 
             
                    JSON.dump(data)
         | 
| 118 110 | 
             
                  end
         | 
| 119 111 |  | 
| 120 | 
            -
                  def get_all_params_base64_data
         | 
| 112 | 
            +
                  private def get_all_params_base64_data
         | 
| 121 113 | 
             
                    @model.trainable_layers.map do |layer|
         | 
| 122 114 | 
             
                      layer.get_params.to_h do |key, param|
         | 
| 123 115 | 
             
                        base64_data = Base64.encode64(param.data.to_binary)
         | 
    
        data/lib/dnn/core/utils.rb
    CHANGED
    
    | @@ -39,6 +39,20 @@ module DNN | |
| 39 39 | 
             
                  Losses::SoftmaxCrossEntropy.softmax(x)
         | 
| 40 40 | 
             
                end
         | 
| 41 41 |  | 
| 42 | 
            +
                # Check training or evaluate input data type.
         | 
| 43 | 
            +
                def self.check_input_data_type(data_name, data, expected_type)
         | 
| 44 | 
            +
                  if !data.is_a?(expected_type) && !data.is_a?(Array)
         | 
| 45 | 
            +
                    raise TypeError, "#{data_name}:#{data.class.name} is not an instance of #{expected_type.name} class or Array class."
         | 
| 46 | 
            +
                  end
         | 
| 47 | 
            +
                  if data.is_a?(Array)
         | 
| 48 | 
            +
                    data.each.with_index do |v, i|
         | 
| 49 | 
            +
                      unless v.is_a?(expected_type)
         | 
| 50 | 
            +
                        raise TypeError, "#{data_name}[#{i}]:#{v.class.name} is not an instance of #{expected_type.name} class."
         | 
| 51 | 
            +
                      end
         | 
| 52 | 
            +
                    end
         | 
| 53 | 
            +
                  end
         | 
| 54 | 
            +
                end
         | 
| 55 | 
            +
             | 
| 42 56 | 
             
                # Perform numerical differentiation.
         | 
| 43 57 | 
             
                def self.numerical_grad(x, func)
         | 
| 44 58 | 
             
                  (func.(x + 1e-7) - func.(x)) / 1e-7
         | 
    
        data/lib/dnn/datasets/iris.rb
    CHANGED
    
    | @@ -41,7 +41,11 @@ module DNN | |
| 41 41 | 
             
                    end
         | 
| 42 42 | 
             
                  end
         | 
| 43 43 | 
             
                  if shuffle
         | 
| 44 | 
            -
                     | 
| 44 | 
            +
                    if RUBY_VERSION.split(".")[0].to_i >= 3
         | 
| 45 | 
            +
                      orig_seed = Random.seed
         | 
| 46 | 
            +
                    else
         | 
| 47 | 
            +
                      orig_seed = Random::DEFAULT.seed
         | 
| 48 | 
            +
                    end
         | 
| 45 49 | 
             
                    srand(shuffle_seed)
         | 
| 46 50 | 
             
                    indexs = (0...csv_array.length).to_a.shuffle
         | 
| 47 51 | 
             
                    x[indexs, true] = x
         | 
    
        data/lib/dnn/version.rb
    CHANGED
    
    
    
        data/lib/dnn.rb
    CHANGED
    
    | @@ -1,4 +1,8 @@ | |
| 1 | 
            -
             | 
| 1 | 
            +
            if RUBY_PLATFORM == "wasm32-wasi"
         | 
| 2 | 
            +
              require "narray.so"
         | 
| 3 | 
            +
            else
         | 
| 4 | 
            +
              require "numo/narray"
         | 
| 5 | 
            +
            end
         | 
| 2 6 |  | 
| 3 7 | 
             
            module DNN
         | 
| 4 8 | 
             
              if ENV["RUBY_DNN_USE_CUMO"] == "ENABLE"
         | 
| @@ -27,28 +31,30 @@ module DNN | |
| 27 31 | 
             
              end
         | 
| 28 32 | 
             
            end
         | 
| 29 33 |  | 
| 30 | 
            -
             | 
| 31 | 
            -
            require_relative "dnn/ | 
| 32 | 
            -
            require_relative "dnn/core/ | 
| 33 | 
            -
            require_relative "dnn/core/ | 
| 34 | 
            -
            require_relative "dnn/core/ | 
| 35 | 
            -
            require_relative "dnn/core/ | 
| 36 | 
            -
            require_relative "dnn/core/ | 
| 37 | 
            -
            require_relative "dnn/core/ | 
| 38 | 
            -
            require_relative "dnn/core/ | 
| 39 | 
            -
            require_relative "dnn/core/ | 
| 40 | 
            -
            require_relative "dnn/core/layers/ | 
| 41 | 
            -
            require_relative "dnn/core/layers/ | 
| 42 | 
            -
            require_relative "dnn/core/layers/ | 
| 43 | 
            -
            require_relative "dnn/core/layers/ | 
| 44 | 
            -
            require_relative "dnn/core/layers/ | 
| 45 | 
            -
            require_relative "dnn/core/layers/ | 
| 46 | 
            -
            require_relative "dnn/core/layers/ | 
| 47 | 
            -
            require_relative "dnn/core/layers/ | 
| 48 | 
            -
            require_relative "dnn/core/ | 
| 49 | 
            -
            require_relative "dnn/core/ | 
| 50 | 
            -
            require_relative "dnn/core/ | 
| 51 | 
            -
            require_relative "dnn/core/ | 
| 52 | 
            -
            require_relative "dnn/core/ | 
| 53 | 
            -
            require_relative "dnn/core/ | 
| 54 | 
            -
            require_relative "dnn/core/ | 
| 34 | 
            +
            if RUBY_PLATFORM != "wasm32-wasi"
         | 
| 35 | 
            +
              require_relative "dnn/version"
         | 
| 36 | 
            +
              require_relative "dnn/core/monkey_patch"
         | 
| 37 | 
            +
              require_relative "dnn/core/error"
         | 
| 38 | 
            +
              require_relative "dnn/core/global"
         | 
| 39 | 
            +
              require_relative "dnn/core/tensor"
         | 
| 40 | 
            +
              require_relative "dnn/core/param"
         | 
| 41 | 
            +
              require_relative "dnn/core/link"
         | 
| 42 | 
            +
              require_relative "dnn/core/iterator"
         | 
| 43 | 
            +
              require_relative "dnn/core/models"
         | 
| 44 | 
            +
              require_relative "dnn/core/layers/basic_layers"
         | 
| 45 | 
            +
              require_relative "dnn/core/layers/normalizations"
         | 
| 46 | 
            +
              require_relative "dnn/core/layers/activations"
         | 
| 47 | 
            +
              require_relative "dnn/core/layers/merge_layers"
         | 
| 48 | 
            +
              require_relative "dnn/core/layers/split_layers"
         | 
| 49 | 
            +
              require_relative "dnn/core/layers/cnn_layers"
         | 
| 50 | 
            +
              require_relative "dnn/core/layers/embedding"
         | 
| 51 | 
            +
              require_relative "dnn/core/layers/rnn_layers"
         | 
| 52 | 
            +
              require_relative "dnn/core/layers/math_layers"
         | 
| 53 | 
            +
              require_relative "dnn/core/optimizers"
         | 
| 54 | 
            +
              require_relative "dnn/core/losses"
         | 
| 55 | 
            +
              require_relative "dnn/core/initializers"
         | 
| 56 | 
            +
              require_relative "dnn/core/regularizers"
         | 
| 57 | 
            +
              require_relative "dnn/core/callbacks"
         | 
| 58 | 
            +
              require_relative "dnn/core/savers"
         | 
| 59 | 
            +
              require_relative "dnn/core/utils"
         | 
| 60 | 
            +
            end
         | 
    
        metadata
    CHANGED
    
    | @@ -1,14 +1,14 @@ | |
| 1 1 | 
             
            --- !ruby/object:Gem::Specification
         | 
| 2 2 | 
             
            name: ruby-dnn
         | 
| 3 3 | 
             
            version: !ruby/object:Gem::Version
         | 
| 4 | 
            -
              version: 1. | 
| 4 | 
            +
              version: 1.3.0
         | 
| 5 5 | 
             
            platform: ruby
         | 
| 6 6 | 
             
            authors:
         | 
| 7 7 | 
             
            - unagiootoro
         | 
| 8 8 | 
             
            autorequire:
         | 
| 9 9 | 
             
            bindir: exe
         | 
| 10 10 | 
             
            cert_chain: []
         | 
| 11 | 
            -
            date: 2023-03- | 
| 11 | 
            +
            date: 2023-03-19 00:00:00.000000000 Z
         | 
| 12 12 | 
             
            dependencies:
         | 
| 13 13 | 
             
            - !ruby/object:Gem::Dependency
         | 
| 14 14 | 
             
              name: numo-narray
         | 
| @@ -139,6 +139,7 @@ files: | |
| 139 139 | 
             
            - examples/dcgan/imgen.rb
         | 
| 140 140 | 
             
            - examples/dcgan/train.rb
         | 
| 141 141 | 
             
            - examples/iris_example.rb
         | 
| 142 | 
            +
            - examples/iris_example_unused_model.rb
         | 
| 142 143 | 
             
            - examples/judge-number/README.md
         | 
| 143 144 | 
             
            - examples/judge-number/capture.PNG
         | 
| 144 145 | 
             
            - examples/judge-number/convnet8.rb
         |