torch-rb 0.2.5 → 0.2.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
 - data/CHANGELOG.md +6 -0
 - data/README.md +1 -1
 - data/ext/torch/ext.cpp +5 -0
 - data/lib/torch/native/function.rb +1 -0
 - data/lib/torch/native/generator.rb +5 -2
 - data/lib/torch/native/parser.rb +1 -1
 - data/lib/torch/nn/module.rb +1 -1
 - data/lib/torch/optim/rprop.rb +0 -3
 - data/lib/torch/tensor.rb +14 -4
 - data/lib/torch/version.rb +1 -1
 - metadata +2 -2
 
    
        checksums.yaml
    CHANGED
    
    | 
         @@ -1,7 +1,7 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            ---
         
     | 
| 
       2 
2 
     | 
    
         
             
            SHA256:
         
     | 
| 
       3 
     | 
    
         
            -
              metadata.gz:  
     | 
| 
       4 
     | 
    
         
            -
              data.tar.gz:  
     | 
| 
      
 3 
     | 
    
         
            +
              metadata.gz: da5c88539a2890933e44859af7c1acfe835405c89e82bc6a6fb2df37fdee141a
         
     | 
| 
      
 4 
     | 
    
         
            +
              data.tar.gz: 747ab48ba1b0ba16077ed31cf505f622a4120d18be9c0942ded39810095aa68e
         
     | 
| 
       5 
5 
     | 
    
         
             
            SHA512:
         
     | 
| 
       6 
     | 
    
         
            -
              metadata.gz:  
     | 
| 
       7 
     | 
    
         
            -
              data.tar.gz:  
     | 
| 
      
 6 
     | 
    
         
            +
              metadata.gz: d2bccf16e7af54d53affbc12030fd89f417fd060afa7057e97765bca00c24b7089f74b9e8aa4bab9180045e466649676f481425dc24ab29533b74138bb03e786
         
     | 
| 
      
 7 
     | 
    
         
            +
              data.tar.gz: 8eb49a743fedb220df4edc39d7d0c492c1827ce63f2c5d9a8d73495da63da5f6055674ca0bd60bb0b81be68db6e5c9300357de2c417e82e7af4fafd1ab6c7ca2
         
     | 
    
        data/CHANGELOG.md
    CHANGED
    
    
    
        data/README.md
    CHANGED
    
    
    
        data/ext/torch/ext.cpp
    CHANGED
    
    | 
         @@ -329,6 +329,11 @@ void Init_ext() 
     | 
|
| 
       329 
329 
     | 
    
         
             
                .define_method("numel", &torch::Tensor::numel)
         
     | 
| 
       330 
330 
     | 
    
         
             
                .define_method("element_size", &torch::Tensor::element_size)
         
     | 
| 
       331 
331 
     | 
    
         
             
                .define_method("requires_grad", &torch::Tensor::requires_grad)
         
     | 
| 
      
 332 
     | 
    
         
            +
                .define_method(
         
     | 
| 
      
 333 
     | 
    
         
            +
                  "contiguous?",
         
     | 
| 
      
 334 
     | 
    
         
            +
                  *[](Tensor& self) {
         
     | 
| 
      
 335 
     | 
    
         
            +
                    return self.is_contiguous();
         
     | 
| 
      
 336 
     | 
    
         
            +
                  })
         
     | 
| 
       332 
337 
     | 
    
         
             
                .define_method(
         
     | 
| 
       333 
338 
     | 
    
         
             
                  "addcmul!",
         
     | 
| 
       334 
339 
     | 
    
         
             
                  *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
         
     | 
| 
         @@ -18,7 +18,7 @@ module Torch 
     | 
|
| 
       18 
18 
     | 
    
         
             
                      functions = functions()
         
     | 
| 
       19 
19 
     | 
    
         | 
| 
       20 
20 
     | 
    
         
             
                      # skip functions
         
     | 
| 
       21 
     | 
    
         
            -
                      skip_args = ["bool[3]", "Dimname", " 
     | 
| 
      
 21 
     | 
    
         
            +
                      skip_args = ["bool[3]", "Dimname", "Layout", "Storage", "ConstQuantizerPtr"]
         
     | 
| 
       22 
22 
     | 
    
         | 
| 
       23 
23 
     | 
    
         
             
                      # remove functions
         
     | 
| 
       24 
24 
     | 
    
         
             
                      functions.reject! do |f|
         
     | 
| 
         @@ -31,7 +31,7 @@ module Torch 
     | 
|
| 
       31 
31 
     | 
    
         
             
                      todo_functions, functions =
         
     | 
| 
       32 
32 
     | 
    
         
             
                        functions.partition do |f|
         
     | 
| 
       33 
33 
     | 
    
         
             
                          f.args.any? do |a|
         
     | 
| 
       34 
     | 
    
         
            -
                            a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?"].include?(a[:type]) ||
         
     | 
| 
      
 34 
     | 
    
         
            +
                            a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
         
     | 
| 
       35 
35 
     | 
    
         
             
                            skip_args.any? { |sa| a[:type].include?(sa) } ||
         
     | 
| 
       36 
36 
     | 
    
         
             
                            # native_functions.yaml is missing size argument for normal
         
     | 
| 
       37 
37 
     | 
    
         
             
                            # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
         
     | 
| 
         @@ -112,6 +112,9 @@ void add_%{type}_functions(Module m) { 
     | 
|
| 
       112 
112 
     | 
    
         
             
                              "OptionalScalarType"
         
     | 
| 
       113 
113 
     | 
    
         
             
                            when "Tensor[]"
         
     | 
| 
       114 
114 
     | 
    
         
             
                              "TensorList"
         
     | 
| 
      
 115 
     | 
    
         
            +
                            when "Tensor?[]"
         
     | 
| 
      
 116 
     | 
    
         
            +
                              # TODO make optional
         
     | 
| 
      
 117 
     | 
    
         
            +
                              "TensorList"
         
     | 
| 
       115 
118 
     | 
    
         
             
                            when "int"
         
     | 
| 
       116 
119 
     | 
    
         
             
                              "int64_t"
         
     | 
| 
       117 
120 
     | 
    
         
             
                            when "float"
         
     | 
    
        data/lib/torch/native/parser.rb
    CHANGED
    
    
    
        data/lib/torch/nn/module.rb
    CHANGED
    
    | 
         @@ -145,7 +145,7 @@ module Torch 
     | 
|
| 
       145 
145 
     | 
    
         
             
                    params = {}
         
     | 
| 
       146 
146 
     | 
    
         
             
                    if recurse
         
     | 
| 
       147 
147 
     | 
    
         
             
                      named_children.each do |name, mod|
         
     | 
| 
       148 
     | 
    
         
            -
                        params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
         
     | 
| 
      
 148 
     | 
    
         
            +
                        params.merge!(mod.named_parameters(prefix: "#{prefix}#{name}.", recurse: recurse))
         
     | 
| 
       149 
149 
     | 
    
         
             
                      end
         
     | 
| 
       150 
150 
     | 
    
         
             
                    end
         
     | 
| 
       151 
151 
     | 
    
         
             
                    instance_variables.each do |name|
         
     | 
    
        data/lib/torch/optim/rprop.rb
    CHANGED
    
    
    
        data/lib/torch/tensor.rb
    CHANGED
    
    | 
         @@ -158,7 +158,8 @@ module Torch 
     | 
|
| 
       158 
158 
     | 
    
         
             
                  item <=> other
         
     | 
| 
       159 
159 
     | 
    
         
             
                end
         
     | 
| 
       160 
160 
     | 
    
         | 
| 
       161 
     | 
    
         
            -
                # based on python_variable_indexing.cpp
         
     | 
| 
      
 161 
     | 
    
         
            +
                # based on python_variable_indexing.cpp and
         
     | 
| 
      
 162 
     | 
    
         
            +
                # https://pytorch.org/cppdocs/notes/tensor_indexing.html
         
     | 
| 
       162 
163 
     | 
    
         
             
                def [](*indexes)
         
     | 
| 
       163 
164 
     | 
    
         
             
                  result = self
         
     | 
| 
       164 
165 
     | 
    
         
             
                  dim = 0
         
     | 
| 
         @@ -170,6 +171,8 @@ module Torch 
     | 
|
| 
       170 
171 
     | 
    
         
             
                      finish += 1 unless index.exclude_end?
         
     | 
| 
       171 
172 
     | 
    
         
             
                      result = result._slice_tensor(dim, index.begin, finish, 1)
         
     | 
| 
       172 
173 
     | 
    
         
             
                      dim += 1
         
     | 
| 
      
 174 
     | 
    
         
            +
                    elsif index.is_a?(Tensor)
         
     | 
| 
      
 175 
     | 
    
         
            +
                      result = result.index([index])
         
     | 
| 
       173 
176 
     | 
    
         
             
                    elsif index.nil?
         
     | 
| 
       174 
177 
     | 
    
         
             
                      result = result.unsqueeze(dim)
         
     | 
| 
       175 
178 
     | 
    
         
             
                      dim += 1
         
     | 
| 
         @@ -183,12 +186,12 @@ module Torch 
     | 
|
| 
       183 
186 
     | 
    
         
             
                  result
         
     | 
| 
       184 
187 
     | 
    
         
             
                end
         
     | 
| 
       185 
188 
     | 
    
         | 
| 
       186 
     | 
    
         
            -
                #  
     | 
| 
       187 
     | 
    
         
            -
                #  
     | 
| 
      
 189 
     | 
    
         
            +
                # based on python_variable_indexing.cpp and
         
     | 
| 
      
 190 
     | 
    
         
            +
                # https://pytorch.org/cppdocs/notes/tensor_indexing.html
         
     | 
| 
       188 
191 
     | 
    
         
             
                def []=(index, value)
         
     | 
| 
       189 
192 
     | 
    
         
             
                  raise ArgumentError, "Tensor does not support deleting items" if value.nil?
         
     | 
| 
       190 
193 
     | 
    
         | 
| 
       191 
     | 
    
         
            -
                  value = Torch.tensor(value) unless value.is_a?(Tensor)
         
     | 
| 
      
 194 
     | 
    
         
            +
                  value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
         
     | 
| 
       192 
195 
     | 
    
         | 
| 
       193 
196 
     | 
    
         
             
                  if index.is_a?(Numeric)
         
     | 
| 
       194 
197 
     | 
    
         
             
                    copy_to(_select_int(0, index), value)
         
     | 
| 
         @@ -196,6 +199,8 @@ module Torch 
     | 
|
| 
       196 
199 
     | 
    
         
             
                    finish = index.end
         
     | 
| 
       197 
200 
     | 
    
         
             
                    finish += 1 unless index.exclude_end?
         
     | 
| 
       198 
201 
     | 
    
         
             
                    copy_to(_slice_tensor(0, index.begin, finish, 1), value)
         
     | 
| 
      
 202 
     | 
    
         
            +
                  elsif index.is_a?(Tensor)
         
     | 
| 
      
 203 
     | 
    
         
            +
                    index_put!([index], value)
         
     | 
| 
       199 
204 
     | 
    
         
             
                  else
         
     | 
| 
       200 
205 
     | 
    
         
             
                    raise Error, "Unsupported index type: #{index.class.name}"
         
     | 
| 
       201 
206 
     | 
    
         
             
                  end
         
     | 
| 
         @@ -224,6 +229,11 @@ module Torch 
     | 
|
| 
       224 
229 
     | 
    
         
             
                  end
         
     | 
| 
       225 
230 
     | 
    
         
             
                end
         
     | 
| 
       226 
231 
     | 
    
         | 
| 
      
 232 
     | 
    
         
            +
                def clamp!(min, max)
         
     | 
| 
      
 233 
     | 
    
         
            +
                  _clamp_min_(min)
         
     | 
| 
      
 234 
     | 
    
         
            +
                  _clamp_max_(max)
         
     | 
| 
      
 235 
     | 
    
         
            +
                end
         
     | 
| 
      
 236 
     | 
    
         
            +
             
     | 
| 
       227 
237 
     | 
    
         
             
                private
         
     | 
| 
       228 
238 
     | 
    
         | 
| 
       229 
239 
     | 
    
         
             
                def copy_to(dst, src)
         
     | 
    
        data/lib/torch/version.rb
    CHANGED
    
    
    
        metadata
    CHANGED
    
    | 
         @@ -1,14 +1,14 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            --- !ruby/object:Gem::Specification
         
     | 
| 
       2 
2 
     | 
    
         
             
            name: torch-rb
         
     | 
| 
       3 
3 
     | 
    
         
             
            version: !ruby/object:Gem::Version
         
     | 
| 
       4 
     | 
    
         
            -
              version: 0.2. 
     | 
| 
      
 4 
     | 
    
         
            +
              version: 0.2.6
         
     | 
| 
       5 
5 
     | 
    
         
             
            platform: ruby
         
     | 
| 
       6 
6 
     | 
    
         
             
            authors:
         
     | 
| 
       7 
7 
     | 
    
         
             
            - Andrew Kane
         
     | 
| 
       8 
8 
     | 
    
         
             
            autorequire: 
         
     | 
| 
       9 
9 
     | 
    
         
             
            bindir: bin
         
     | 
| 
       10 
10 
     | 
    
         
             
            cert_chain: []
         
     | 
| 
       11 
     | 
    
         
            -
            date: 2020-06- 
     | 
| 
      
 11 
     | 
    
         
            +
            date: 2020-06-29 00:00:00.000000000 Z
         
     | 
| 
       12 
12 
     | 
    
         
             
            dependencies:
         
     | 
| 
       13 
13 
     | 
    
         
             
            - !ruby/object:Gem::Dependency
         
     | 
| 
       14 
14 
     | 
    
         
             
              name: rice
         
     |