torch-rb 0.3.6 → 0.3.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
 - data/CHANGELOG.md +9 -0
 - data/README.md +1 -0
 - data/ext/torch/ext.cpp +29 -9
 - data/ext/torch/extconf.rb +3 -0
 - data/ext/torch/templates.cpp +28 -0
 - data/ext/torch/templates.hpp +23 -34
 - data/lib/torch.rb +35 -0
 - data/lib/torch/native/dispatcher.rb +30 -8
 - data/lib/torch/native/function.rb +87 -6
 - data/lib/torch/native/generator.rb +28 -18
 - data/lib/torch/native/parser.rb +55 -86
 - data/lib/torch/nn/functional.rb +106 -0
 - data/lib/torch/nn/module.rb +4 -1
 - data/lib/torch/nn/upsample.rb +31 -0
 - data/lib/torch/tensor.rb +13 -7
 - data/lib/torch/version.rb +1 -1
 - metadata +3 -2
 
    
        checksums.yaml
    CHANGED
    
    | 
         @@ -1,7 +1,7 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            ---
         
     | 
| 
       2 
2 
     | 
    
         
             
            SHA256:
         
     | 
| 
       3 
     | 
    
         
            -
              metadata.gz:  
     | 
| 
       4 
     | 
    
         
            -
              data.tar.gz:  
     | 
| 
      
 3 
     | 
    
         
            +
              metadata.gz: 8a1852ee3d1ecc7a29c23259b8c328a95030a270b7c11f37f22049177898652e
         
     | 
| 
      
 4 
     | 
    
         
            +
              data.tar.gz: 56823f1815d3c0c4d5d5c01ef76d781b792b3e4e7c68c0332a149b883a54c7c8
         
     | 
| 
       5 
5 
     | 
    
         
             
            SHA512:
         
     | 
| 
       6 
     | 
    
         
            -
              metadata.gz:  
     | 
| 
       7 
     | 
    
         
            -
              data.tar.gz:  
     | 
| 
      
 6 
     | 
    
         
            +
              metadata.gz: bed15510cfeaa555d71f1e1f46ed8944893bd349a07c4316dcd63429fe76e13facd8794399ef97fc400d05796579f2e84822b62c98c71dc996e211ad04113ae2
         
     | 
| 
      
 7 
     | 
    
         
            +
              data.tar.gz: aa05e3645e363eda27274323cdb7fb316342074d1d5afe8f7ee6bfd9819da7883b43d084beb6b29011c631c04fdddc8e6789db41c7c84c53ba9ed152d3338b09
         
     | 
    
        data/CHANGELOG.md
    CHANGED
    
    | 
         @@ -1,3 +1,12 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            ## 0.3.7 (2020-09-22)
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            - Improved performance
         
     | 
| 
      
 4 
     | 
    
         
            +
            - Added `Upsample`
         
     | 
| 
      
 5 
     | 
    
         
            +
            - Added support for passing tensor class to `type` method
         
     | 
| 
      
 6 
     | 
    
         
            +
            - Fixed error with buffers on GPU
         
     | 
| 
      
 7 
     | 
    
         
            +
            - Fixed error with `new_full`
         
     | 
| 
      
 8 
     | 
    
         
            +
            - Fixed issue with `numo` method and non-contiguous tensors
         
     | 
| 
      
 9 
     | 
    
         
            +
             
     | 
| 
       1 
10 
     | 
    
         
             
            ## 0.3.6 (2020-09-17)
         
     | 
| 
       2 
11 
     | 
    
         | 
| 
       3 
12 
     | 
    
         
             
            - Added `inplace` option for leaky ReLU
         
     | 
    
        data/README.md
    CHANGED
    
    | 
         @@ -402,6 +402,7 @@ Here are a few full examples: 
     | 
|
| 
       402 
402 
     | 
    
         
             
            - [Image classification with MNIST](examples/mnist) ([日本語版](https://qiita.com/kojix2/items/c19c36dc1bf73ea93409))
         
     | 
| 
       403 
403 
     | 
    
         
             
            - [Collaborative filtering with MovieLens](examples/movielens)
         
     | 
| 
       404 
404 
     | 
    
         
             
            - [Sequence models and word embeddings](examples/nlp)
         
     | 
| 
      
 405 
     | 
    
         
            +
            - [Generative adversarial networks](examples/gan)
         
     | 
| 
       405 
406 
     | 
    
         | 
| 
       406 
407 
     | 
    
         
             
            ## LibTorch Installation
         
     | 
| 
       407 
408 
     | 
    
         | 
    
        data/ext/torch/ext.cpp
    CHANGED
    
    | 
         @@ -232,7 +232,7 @@ void Init_ext() 
     | 
|
| 
       232 
232 
     | 
    
         
             
                  })
         
     | 
| 
       233 
233 
     | 
    
         
             
                .define_singleton_method(
         
     | 
| 
       234 
234 
     | 
    
         
             
                  "_empty",
         
     | 
| 
       235 
     | 
    
         
            -
                  *[]( 
     | 
| 
      
 235 
     | 
    
         
            +
                  *[](std::vector<int64_t> size, const torch::TensorOptions &options) {
         
     | 
| 
       236 
236 
     | 
    
         
             
                    return torch::empty(size, options);
         
     | 
| 
       237 
237 
     | 
    
         
             
                  })
         
     | 
| 
       238 
238 
     | 
    
         
             
                .define_singleton_method(
         
     | 
| 
         @@ -242,7 +242,7 @@ void Init_ext() 
     | 
|
| 
       242 
242 
     | 
    
         
             
                  })
         
     | 
| 
       243 
243 
     | 
    
         
             
                .define_singleton_method(
         
     | 
| 
       244 
244 
     | 
    
         
             
                  "_full",
         
     | 
| 
       245 
     | 
    
         
            -
                  *[]( 
     | 
| 
      
 245 
     | 
    
         
            +
                  *[](std::vector<int64_t> size, Scalar fill_value, const torch::TensorOptions& options) {
         
     | 
| 
       246 
246 
     | 
    
         
             
                    return torch::full(size, fill_value, options);
         
     | 
| 
       247 
247 
     | 
    
         
             
                  })
         
     | 
| 
       248 
248 
     | 
    
         
             
                .define_singleton_method(
         
     | 
| 
         @@ -257,22 +257,22 @@ void Init_ext() 
     | 
|
| 
       257 
257 
     | 
    
         
             
                  })
         
     | 
| 
       258 
258 
     | 
    
         
             
                .define_singleton_method(
         
     | 
| 
       259 
259 
     | 
    
         
             
                  "_ones",
         
     | 
| 
       260 
     | 
    
         
            -
                  *[]( 
     | 
| 
      
 260 
     | 
    
         
            +
                  *[](std::vector<int64_t> size, const torch::TensorOptions &options) {
         
     | 
| 
       261 
261 
     | 
    
         
             
                    return torch::ones(size, options);
         
     | 
| 
       262 
262 
     | 
    
         
             
                  })
         
     | 
| 
       263 
263 
     | 
    
         
             
                .define_singleton_method(
         
     | 
| 
       264 
264 
     | 
    
         
             
                  "_rand",
         
     | 
| 
       265 
     | 
    
         
            -
                  *[]( 
     | 
| 
      
 265 
     | 
    
         
            +
                  *[](std::vector<int64_t> size, const torch::TensorOptions &options) {
         
     | 
| 
       266 
266 
     | 
    
         
             
                    return torch::rand(size, options);
         
     | 
| 
       267 
267 
     | 
    
         
             
                  })
         
     | 
| 
       268 
268 
     | 
    
         
             
                .define_singleton_method(
         
     | 
| 
       269 
269 
     | 
    
         
             
                  "_randint",
         
     | 
| 
       270 
     | 
    
         
            -
                  *[](int64_t low, int64_t high,  
     | 
| 
      
 270 
     | 
    
         
            +
                  *[](int64_t low, int64_t high, std::vector<int64_t> size, const torch::TensorOptions &options) {
         
     | 
| 
       271 
271 
     | 
    
         
             
                    return torch::randint(low, high, size, options);
         
     | 
| 
       272 
272 
     | 
    
         
             
                  })
         
     | 
| 
       273 
273 
     | 
    
         
             
                .define_singleton_method(
         
     | 
| 
       274 
274 
     | 
    
         
             
                  "_randn",
         
     | 
| 
       275 
     | 
    
         
            -
                  *[]( 
     | 
| 
      
 275 
     | 
    
         
            +
                  *[](std::vector<int64_t> size, const torch::TensorOptions &options) {
         
     | 
| 
       276 
276 
     | 
    
         
             
                    return torch::randn(size, options);
         
     | 
| 
       277 
277 
     | 
    
         
             
                  })
         
     | 
| 
       278 
278 
     | 
    
         
             
                .define_singleton_method(
         
     | 
| 
         @@ -282,7 +282,7 @@ void Init_ext() 
     | 
|
| 
       282 
282 
     | 
    
         
             
                  })
         
     | 
| 
       283 
283 
     | 
    
         
             
                .define_singleton_method(
         
     | 
| 
       284 
284 
     | 
    
         
             
                  "_zeros",
         
     | 
| 
       285 
     | 
    
         
            -
                  *[]( 
     | 
| 
      
 285 
     | 
    
         
            +
                  *[](std::vector<int64_t> size, const torch::TensorOptions &options) {
         
     | 
| 
       286 
286 
     | 
    
         
             
                    return torch::zeros(size, options);
         
     | 
| 
       287 
287 
     | 
    
         
             
                  })
         
     | 
| 
       288 
288 
     | 
    
         
             
                // begin operations
         
     | 
| 
         @@ -303,13 +303,13 @@ void Init_ext() 
     | 
|
| 
       303 
303 
     | 
    
         
             
                  })
         
     | 
| 
       304 
304 
     | 
    
         
             
                .define_singleton_method(
         
     | 
| 
       305 
305 
     | 
    
         
             
                  "_from_blob",
         
     | 
| 
       306 
     | 
    
         
            -
                  *[](String s,  
     | 
| 
      
 306 
     | 
    
         
            +
                  *[](String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
         
     | 
| 
       307 
307 
     | 
    
         
             
                    void *data = const_cast<char *>(s.c_str());
         
     | 
| 
       308 
308 
     | 
    
         
             
                    return torch::from_blob(data, size, options);
         
     | 
| 
       309 
309 
     | 
    
         
             
                  })
         
     | 
| 
       310 
310 
     | 
    
         
             
                .define_singleton_method(
         
     | 
| 
       311 
311 
     | 
    
         
             
                  "_tensor",
         
     | 
| 
       312 
     | 
    
         
            -
                  *[](Array a,  
     | 
| 
      
 312 
     | 
    
         
            +
                  *[](Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
         
     | 
| 
       313 
313 
     | 
    
         
             
                    auto dtype = options.dtype();
         
     | 
| 
       314 
314 
     | 
    
         
             
                    torch::Tensor t;
         
     | 
| 
       315 
315 
     | 
    
         
             
                    if (dtype == torch::kBool) {
         
     | 
| 
         @@ -342,6 +342,16 @@ void Init_ext() 
     | 
|
| 
       342 
342 
     | 
    
         
             
                .define_method("numel", &torch::Tensor::numel)
         
     | 
| 
       343 
343 
     | 
    
         
             
                .define_method("element_size", &torch::Tensor::element_size)
         
     | 
| 
       344 
344 
     | 
    
         
             
                .define_method("requires_grad", &torch::Tensor::requires_grad)
         
     | 
| 
      
 345 
     | 
    
         
            +
                // in C++ for performance
         
     | 
| 
      
 346 
     | 
    
         
            +
                .define_method(
         
     | 
| 
      
 347 
     | 
    
         
            +
                  "shape",
         
     | 
| 
      
 348 
     | 
    
         
            +
                  *[](Tensor& self) {
         
     | 
| 
      
 349 
     | 
    
         
            +
                    Array a;
         
     | 
| 
      
 350 
     | 
    
         
            +
                    for (auto &size : self.sizes()) {
         
     | 
| 
      
 351 
     | 
    
         
            +
                      a.push(size);
         
     | 
| 
      
 352 
     | 
    
         
            +
                    }
         
     | 
| 
      
 353 
     | 
    
         
            +
                    return a;
         
     | 
| 
      
 354 
     | 
    
         
            +
                  })
         
     | 
| 
       345 
355 
     | 
    
         
             
                .define_method(
         
     | 
| 
       346 
356 
     | 
    
         
             
                  "_index",
         
     | 
| 
       347 
357 
     | 
    
         
             
                  *[](Tensor& self, Array indices) {
         
     | 
| 
         @@ -420,9 +430,19 @@ void Init_ext() 
     | 
|
| 
       420 
430 
     | 
    
         
             
                      tensor = tensor.to(device);
         
     | 
| 
       421 
431 
     | 
    
         
             
                    }
         
     | 
| 
       422 
432 
     | 
    
         | 
| 
      
 433 
     | 
    
         
            +
                    if (!tensor.is_contiguous()) {
         
     | 
| 
      
 434 
     | 
    
         
            +
                      tensor = tensor.contiguous();
         
     | 
| 
      
 435 
     | 
    
         
            +
                    }
         
     | 
| 
      
 436 
     | 
    
         
            +
             
     | 
| 
       423 
437 
     | 
    
         
             
                    auto data_ptr = (const char *) tensor.data_ptr();
         
     | 
| 
       424 
438 
     | 
    
         
             
                    return std::string(data_ptr, tensor.numel() * tensor.element_size());
         
     | 
| 
       425 
439 
     | 
    
         
             
                  })
         
     | 
| 
      
 440 
     | 
    
         
            +
                // for TorchVision
         
     | 
| 
      
 441 
     | 
    
         
            +
                .define_method(
         
     | 
| 
      
 442 
     | 
    
         
            +
                  "_data_ptr",
         
     | 
| 
      
 443 
     | 
    
         
            +
                  *[](Tensor& self) {
         
     | 
| 
      
 444 
     | 
    
         
            +
                    return reinterpret_cast<uintptr_t>(self.data_ptr());
         
     | 
| 
      
 445 
     | 
    
         
            +
                  })
         
     | 
| 
       426 
446 
     | 
    
         
             
                // TODO figure out a better way to do this
         
     | 
| 
       427 
447 
     | 
    
         
             
                .define_method(
         
     | 
| 
       428 
448 
     | 
    
         
             
                  "_flat_data",
         
     | 
    
        data/ext/torch/extconf.rb
    CHANGED
    
    
    
        data/ext/torch/templates.cpp
    CHANGED
    
    | 
         @@ -2,6 +2,34 @@ 
     | 
|
| 
       2 
2 
     | 
    
         
             
            #include <rice/Object.hpp>
         
     | 
| 
       3 
3 
     | 
    
         
             
            #include "templates.hpp"
         
     | 
| 
       4 
4 
     | 
    
         | 
| 
      
 5 
     | 
    
         
            +
            Object wrap(bool x) {
         
     | 
| 
      
 6 
     | 
    
         
            +
              return to_ruby<bool>(x);
         
     | 
| 
      
 7 
     | 
    
         
            +
            }
         
     | 
| 
      
 8 
     | 
    
         
            +
             
     | 
| 
      
 9 
     | 
    
         
            +
            Object wrap(int64_t x) {
         
     | 
| 
      
 10 
     | 
    
         
            +
              return to_ruby<int64_t>(x);
         
     | 
| 
      
 11 
     | 
    
         
            +
            }
         
     | 
| 
      
 12 
     | 
    
         
            +
             
     | 
| 
      
 13 
     | 
    
         
            +
            Object wrap(double x) {
         
     | 
| 
      
 14 
     | 
    
         
            +
              return to_ruby<double>(x);
         
     | 
| 
      
 15 
     | 
    
         
            +
            }
         
     | 
| 
      
 16 
     | 
    
         
            +
             
     | 
| 
      
 17 
     | 
    
         
            +
            Object wrap(torch::Tensor x) {
         
     | 
| 
      
 18 
     | 
    
         
            +
              return to_ruby<torch::Tensor>(x);
         
     | 
| 
      
 19 
     | 
    
         
            +
            }
         
     | 
| 
      
 20 
     | 
    
         
            +
             
     | 
| 
      
 21 
     | 
    
         
            +
            Object wrap(torch::Scalar x) {
         
     | 
| 
      
 22 
     | 
    
         
            +
              return to_ruby<torch::Scalar>(x);
         
     | 
| 
      
 23 
     | 
    
         
            +
            }
         
     | 
| 
      
 24 
     | 
    
         
            +
             
     | 
| 
      
 25 
     | 
    
         
            +
            Object wrap(torch::ScalarType x) {
         
     | 
| 
      
 26 
     | 
    
         
            +
              return to_ruby<torch::ScalarType>(x);
         
     | 
| 
      
 27 
     | 
    
         
            +
            }
         
     | 
| 
      
 28 
     | 
    
         
            +
             
     | 
| 
      
 29 
     | 
    
         
            +
            Object wrap(torch::QScheme x) {
         
     | 
| 
      
 30 
     | 
    
         
            +
              return to_ruby<torch::QScheme>(x);
         
     | 
| 
      
 31 
     | 
    
         
            +
            }
         
     | 
| 
      
 32 
     | 
    
         
            +
             
     | 
| 
       5 
33 
     | 
    
         
             
            Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
         
     | 
| 
       6 
34 
     | 
    
         
             
              Array a;
         
     | 
| 
       7 
35 
     | 
    
         
             
              a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
         
     | 
    
        data/ext/torch/templates.hpp
    CHANGED
    
    | 
         @@ -13,49 +13,31 @@ using torch::Device; 
     | 
|
| 
       13 
13 
     | 
    
         
             
            using torch::Scalar;
         
     | 
| 
       14 
14 
     | 
    
         
             
            using torch::ScalarType;
         
     | 
| 
       15 
15 
     | 
    
         
             
            using torch::Tensor;
         
     | 
| 
       16 
     | 
    
         
            -
             
     | 
| 
       17 
     | 
    
         
            -
             
     | 
| 
       18 
     | 
    
         
            -
            // it doesn't own underlying data
         
     | 
| 
       19 
     | 
    
         
            -
            class IntArrayRef {
         
     | 
| 
       20 
     | 
    
         
            -
              std::vector<int64_t> vec;
         
     | 
| 
       21 
     | 
    
         
            -
              public:
         
     | 
| 
       22 
     | 
    
         
            -
                IntArrayRef(Object o) {
         
     | 
| 
       23 
     | 
    
         
            -
                  Array a = Array(o);
         
     | 
| 
       24 
     | 
    
         
            -
                  for (size_t i = 0; i < a.size(); i++) {
         
     | 
| 
       25 
     | 
    
         
            -
                    vec.push_back(from_ruby<int64_t>(a[i]));
         
     | 
| 
       26 
     | 
    
         
            -
                  }
         
     | 
| 
       27 
     | 
    
         
            -
                }
         
     | 
| 
       28 
     | 
    
         
            -
                operator torch::IntArrayRef() {
         
     | 
| 
       29 
     | 
    
         
            -
                  return torch::IntArrayRef(vec);
         
     | 
| 
       30 
     | 
    
         
            -
                }
         
     | 
| 
       31 
     | 
    
         
            -
            };
         
     | 
| 
      
 16 
     | 
    
         
            +
            using torch::IntArrayRef;
         
     | 
| 
      
 17 
     | 
    
         
            +
            using torch::TensorList;
         
     | 
| 
       32 
18 
     | 
    
         | 
| 
       33 
19 
     | 
    
         
             
            template<>
         
     | 
| 
       34 
20 
     | 
    
         
             
            inline
         
     | 
| 
       35 
     | 
    
         
            -
             
     | 
| 
      
 21 
     | 
    
         
            +
            std::vector<int64_t> from_ruby<std::vector<int64_t>>(Object x)
         
     | 
| 
       36 
22 
     | 
    
         
             
            {
         
     | 
| 
       37 
     | 
    
         
            -
               
     | 
| 
      
 23 
     | 
    
         
            +
              Array a = Array(x);
         
     | 
| 
      
 24 
     | 
    
         
            +
              std::vector<int64_t> vec(a.size());
         
     | 
| 
      
 25 
     | 
    
         
            +
              for (size_t i = 0; i < a.size(); i++) {
         
     | 
| 
      
 26 
     | 
    
         
            +
                vec[i] = from_ruby<int64_t>(a[i]);
         
     | 
| 
      
 27 
     | 
    
         
            +
              }
         
     | 
| 
      
 28 
     | 
    
         
            +
              return vec;
         
     | 
| 
       38 
29 
     | 
    
         
             
            }
         
     | 
| 
       39 
30 
     | 
    
         | 
| 
       40 
     | 
    
         
            -
            class TensorList {
         
     | 
| 
       41 
     | 
    
         
            -
              std::vector<torch::Tensor> vec;
         
     | 
| 
       42 
     | 
    
         
            -
              public:
         
     | 
| 
       43 
     | 
    
         
            -
                TensorList(Object o) {
         
     | 
| 
       44 
     | 
    
         
            -
                  Array a = Array(o);
         
     | 
| 
       45 
     | 
    
         
            -
                  for (size_t i = 0; i < a.size(); i++) {
         
     | 
| 
       46 
     | 
    
         
            -
                    vec.push_back(from_ruby<torch::Tensor>(a[i]));
         
     | 
| 
       47 
     | 
    
         
            -
                  }
         
     | 
| 
       48 
     | 
    
         
            -
                }
         
     | 
| 
       49 
     | 
    
         
            -
                operator torch::TensorList() {
         
     | 
| 
       50 
     | 
    
         
            -
                  return torch::TensorList(vec);
         
     | 
| 
       51 
     | 
    
         
            -
                }
         
     | 
| 
       52 
     | 
    
         
            -
            };
         
     | 
| 
       53 
     | 
    
         
            -
             
     | 
| 
       54 
31 
     | 
    
         
             
            template<>
         
     | 
| 
       55 
32 
     | 
    
         
             
            inline
         
     | 
| 
       56 
     | 
    
         
            -
             
     | 
| 
      
 33 
     | 
    
         
            +
            std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
         
     | 
| 
       57 
34 
     | 
    
         
             
            {
         
     | 
| 
       58 
     | 
    
         
            -
               
     | 
| 
      
 35 
     | 
    
         
            +
              Array a = Array(x);
         
     | 
| 
      
 36 
     | 
    
         
            +
              std::vector<Tensor> vec(a.size());
         
     | 
| 
      
 37 
     | 
    
         
            +
              for (size_t i = 0; i < a.size(); i++) {
         
     | 
| 
      
 38 
     | 
    
         
            +
                vec[i] = from_ruby<Tensor>(a[i]);
         
     | 
| 
      
 39 
     | 
    
         
            +
              }
         
     | 
| 
      
 40 
     | 
    
         
            +
              return vec;
         
     | 
| 
       59 
41 
     | 
    
         
             
            }
         
     | 
| 
       60 
42 
     | 
    
         | 
| 
       61 
43 
     | 
    
         
             
            class FanModeType {
         
     | 
| 
         @@ -242,6 +224,13 @@ torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x) 
     | 
|
| 
       242 
224 
     | 
    
         
             
              }
         
     | 
| 
       243 
225 
     | 
    
         
             
            }
         
     | 
| 
       244 
226 
     | 
    
         | 
| 
      
 227 
     | 
    
         
            +
            Object wrap(bool x);
         
     | 
| 
      
 228 
     | 
    
         
            +
            Object wrap(int64_t x);
         
     | 
| 
      
 229 
     | 
    
         
            +
            Object wrap(double x);
         
     | 
| 
      
 230 
     | 
    
         
            +
            Object wrap(torch::Tensor x);
         
     | 
| 
      
 231 
     | 
    
         
            +
            Object wrap(torch::Scalar x);
         
     | 
| 
      
 232 
     | 
    
         
            +
            Object wrap(torch::ScalarType x);
         
     | 
| 
      
 233 
     | 
    
         
            +
            Object wrap(torch::QScheme x);
         
     | 
| 
       245 
234 
     | 
    
         
             
            Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
         
     | 
| 
       246 
235 
     | 
    
         
             
            Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
         
     | 
| 
       247 
236 
     | 
    
         
             
            Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
         
     | 
    
        data/lib/torch.rb
    CHANGED
    
    | 
         @@ -174,6 +174,9 @@ require "torch/nn/smooth_l1_loss" 
     | 
|
| 
       174 
174 
     | 
    
         
             
            require "torch/nn/soft_margin_loss"
         
     | 
| 
       175 
175 
     | 
    
         
             
            require "torch/nn/triplet_margin_loss"
         
     | 
| 
       176 
176 
     | 
    
         | 
| 
      
 177 
     | 
    
         
            +
            # nn vision
         
     | 
| 
      
 178 
     | 
    
         
            +
            require "torch/nn/upsample"
         
     | 
| 
      
 179 
     | 
    
         
            +
             
     | 
| 
       177 
180 
     | 
    
         
             
            # nn other
         
     | 
| 
       178 
181 
     | 
    
         
             
            require "torch/nn/functional"
         
     | 
| 
       179 
182 
     | 
    
         
             
            require "torch/nn/init"
         
     | 
| 
         @@ -196,6 +199,32 @@ module Torch 
     | 
|
| 
       196 
199 
     | 
    
         
             
                end
         
     | 
| 
       197 
200 
     | 
    
         
             
              end
         
     | 
| 
       198 
201 
     | 
    
         | 
| 
      
 202 
     | 
    
         
            +
              # legacy
         
     | 
| 
      
 203 
     | 
    
         
            +
              # but may make it easier to port tutorials
         
     | 
| 
      
 204 
     | 
    
         
            +
              module Autograd
         
     | 
| 
      
 205 
     | 
    
         
            +
                class Variable
         
     | 
| 
      
 206 
     | 
    
         
            +
                  def self.new(x)
         
     | 
| 
      
 207 
     | 
    
         
            +
                    raise ArgumentError, "Variable data has to be a tensor, but got #{x.class.name}" unless x.is_a?(Tensor)
         
     | 
| 
      
 208 
     | 
    
         
            +
                    warn "[torch] The Variable API is deprecated. Use tensors with requires_grad: true instead."
         
     | 
| 
      
 209 
     | 
    
         
            +
                    x
         
     | 
| 
      
 210 
     | 
    
         
            +
                  end
         
     | 
| 
      
 211 
     | 
    
         
            +
                end
         
     | 
| 
      
 212 
     | 
    
         
            +
              end
         
     | 
| 
      
 213 
     | 
    
         
            +
             
     | 
| 
      
 214 
     | 
    
         
            +
              # TODO move to C++
         
     | 
| 
      
 215 
     | 
    
         
            +
              class ByteStorage
         
     | 
| 
      
 216 
     | 
    
         
            +
                # private
         
     | 
| 
      
 217 
     | 
    
         
            +
                attr_reader :bytes
         
     | 
| 
      
 218 
     | 
    
         
            +
             
     | 
| 
      
 219 
     | 
    
         
            +
                def initialize(bytes)
         
     | 
| 
      
 220 
     | 
    
         
            +
                  @bytes = bytes
         
     | 
| 
      
 221 
     | 
    
         
            +
                end
         
     | 
| 
      
 222 
     | 
    
         
            +
             
     | 
| 
      
 223 
     | 
    
         
            +
                def self.from_buffer(bytes)
         
     | 
| 
      
 224 
     | 
    
         
            +
                  new(bytes)
         
     | 
| 
      
 225 
     | 
    
         
            +
                end
         
     | 
| 
      
 226 
     | 
    
         
            +
              end
         
     | 
| 
      
 227 
     | 
    
         
            +
             
     | 
| 
       199 
228 
     | 
    
         
             
              # keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
         
     | 
| 
       200 
229 
     | 
    
         
             
              # values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
         
     | 
| 
       201 
230 
     | 
    
         
             
              DTYPE_TO_ENUM = {
         
     | 
| 
         @@ -224,18 +253,24 @@ module Torch 
     | 
|
| 
       224 
253 
     | 
    
         
             
              }
         
     | 
| 
       225 
254 
     | 
    
         
             
              ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
         
     | 
| 
       226 
255 
     | 
    
         | 
| 
      
 256 
     | 
    
         
            +
              TENSOR_TYPE_CLASSES = []
         
     | 
| 
      
 257 
     | 
    
         
            +
             
     | 
| 
       227 
258 
     | 
    
         
             
              def self._make_tensor_class(dtype, cuda = false)
         
     | 
| 
       228 
259 
     | 
    
         
             
                cls = Class.new
         
     | 
| 
       229 
260 
     | 
    
         
             
                device = cuda ? "cuda" : "cpu"
         
     | 
| 
       230 
261 
     | 
    
         
             
                cls.define_singleton_method("new") do |*args|
         
     | 
| 
       231 
262 
     | 
    
         
             
                  if args.size == 1 && args.first.is_a?(Tensor)
         
     | 
| 
       232 
263 
     | 
    
         
             
                    args.first.send(dtype).to(device)
         
     | 
| 
      
 264 
     | 
    
         
            +
                  elsif args.size == 1 && args.first.is_a?(ByteStorage) && dtype == :uint8
         
     | 
| 
      
 265 
     | 
    
         
            +
                    bytes = args.first.bytes
         
     | 
| 
      
 266 
     | 
    
         
            +
                    Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
         
     | 
| 
       233 
267 
     | 
    
         
             
                  elsif args.size == 1 && args.first.is_a?(Array)
         
     | 
| 
       234 
268 
     | 
    
         
             
                    Torch.tensor(args.first, dtype: dtype, device: device)
         
     | 
| 
       235 
269 
     | 
    
         
             
                  else
         
     | 
| 
       236 
270 
     | 
    
         
             
                    Torch.empty(*args, dtype: dtype, device: device)
         
     | 
| 
       237 
271 
     | 
    
         
             
                  end
         
     | 
| 
       238 
272 
     | 
    
         
             
                end
         
     | 
| 
      
 273 
     | 
    
         
            +
                TENSOR_TYPE_CLASSES << cls
         
     | 
| 
       239 
274 
     | 
    
         
             
                cls
         
     | 
| 
       240 
275 
     | 
    
         
             
              end
         
     | 
| 
       241 
276 
     | 
    
         | 
| 
         @@ -22,21 +22,43 @@ module Torch 
     | 
|
| 
       22 
22 
     | 
    
         
             
                    end
         
     | 
| 
       23 
23 
     | 
    
         | 
| 
       24 
24 
     | 
    
         
             
                    def bind_functions(context, def_method, functions)
         
     | 
| 
      
 25 
     | 
    
         
            +
                      instance_method = def_method == :define_method
         
     | 
| 
       25 
26 
     | 
    
         
             
                      functions.group_by(&:ruby_name).sort_by { |g, _| g }.each do |name, funcs|
         
     | 
| 
       26 
     | 
    
         
            -
                        if  
     | 
| 
      
 27 
     | 
    
         
            +
                        if instance_method
         
     | 
| 
       27 
28 
     | 
    
         
             
                          funcs.map! { |f| Function.new(f.function) }
         
     | 
| 
       28 
     | 
    
         
            -
                          funcs.each { |f| f.args.reject! { |a| a[:name] ==  
     | 
| 
      
 29 
     | 
    
         
            +
                          funcs.each { |f| f.args.reject! { |a| a[:name] == :self } }
         
     | 
| 
       29 
30 
     | 
    
         
             
                        end
         
     | 
| 
       30 
31 
     | 
    
         | 
| 
       31 
     | 
    
         
            -
                        defined =  
     | 
| 
      
 32 
     | 
    
         
            +
                        defined = instance_method ? context.method_defined?(name) : context.respond_to?(name)
         
     | 
| 
       32 
33 
     | 
    
         
             
                        next if defined && name != "clone"
         
     | 
| 
       33 
34 
     | 
    
         | 
| 
       34 
     | 
    
         
            -
                        parser  
     | 
| 
      
 35 
     | 
    
         
            +
                        # skip parser when possible for performance
         
     | 
| 
      
 36 
     | 
    
         
            +
                        if funcs.size == 1 && funcs.first.args.size == 0
         
     | 
| 
      
 37 
     | 
    
         
            +
                          # functions with no arguments
         
     | 
| 
      
 38 
     | 
    
         
            +
                          if instance_method
         
     | 
| 
      
 39 
     | 
    
         
            +
                            context.send(:alias_method, name, funcs.first.cpp_name)
         
     | 
| 
      
 40 
     | 
    
         
            +
                          else
         
     | 
| 
      
 41 
     | 
    
         
            +
                            context.singleton_class.send(:alias_method, name, funcs.first.cpp_name)
         
     | 
| 
      
 42 
     | 
    
         
            +
                          end
         
     | 
| 
      
 43 
     | 
    
         
            +
                        elsif funcs.size == 2 && funcs.map { |f| f.arg_types.values }.sort == [["Scalar"], ["Tensor"]]
         
     | 
| 
      
 44 
     | 
    
         
            +
                          # functions that take a tensor or scalar
         
     | 
| 
      
 45 
     | 
    
         
            +
                          scalar_name, tensor_name = funcs.sort_by { |f| f.arg_types.values }.map(&:cpp_name)
         
     | 
| 
      
 46 
     | 
    
         
            +
                          context.send(def_method, name) do |other|
         
     | 
| 
      
 47 
     | 
    
         
            +
                            case other
         
     | 
| 
      
 48 
     | 
    
         
            +
                            when Tensor
         
     | 
| 
      
 49 
     | 
    
         
            +
                              send(tensor_name, other)
         
     | 
| 
      
 50 
     | 
    
         
            +
                            else
         
     | 
| 
      
 51 
     | 
    
         
            +
                              send(scalar_name, other)
         
     | 
| 
      
 52 
     | 
    
         
            +
                            end
         
     | 
| 
      
 53 
     | 
    
         
            +
                          end
         
     | 
| 
      
 54 
     | 
    
         
            +
                        else
         
     | 
| 
      
 55 
     | 
    
         
            +
                          parser = Parser.new(funcs)
         
     | 
| 
       35 
56 
     | 
    
         | 
| 
       36 
     | 
    
         
            -
             
     | 
| 
       37 
     | 
    
         
            -
             
     | 
| 
       38 
     | 
    
         
            -
             
     | 
| 
       39 
     | 
    
         
            -
             
     | 
| 
      
 57 
     | 
    
         
            +
                          context.send(def_method, name) do |*args, **options|
         
     | 
| 
      
 58 
     | 
    
         
            +
                            result = parser.parse(args, options)
         
     | 
| 
      
 59 
     | 
    
         
            +
                            raise ArgumentError, result[:error] if result[:error]
         
     | 
| 
      
 60 
     | 
    
         
            +
                            send(result[:name], *result[:args])
         
     | 
| 
      
 61 
     | 
    
         
            +
                          end
         
     | 
| 
       40 
62 
     | 
    
         
             
                        end
         
     | 
| 
       41 
63 
     | 
    
         
             
                      end
         
     | 
| 
       42 
64 
     | 
    
         
             
                    end
         
     | 
| 
         @@ -6,9 +6,10 @@ module Torch 
     | 
|
| 
       6 
6 
     | 
    
         
             
                  def initialize(function)
         
     | 
| 
       7 
7 
     | 
    
         
             
                    @function = function
         
     | 
| 
       8 
8 
     | 
    
         | 
| 
       9 
     | 
    
         
            -
                     
     | 
| 
       10 
     | 
    
         
            -
                    @ 
     | 
| 
       11 
     | 
    
         
            -
                    @function["func"]. 
     | 
| 
      
 9 
     | 
    
         
            +
                    # note: don't modify function in-place
         
     | 
| 
      
 10 
     | 
    
         
            +
                    @tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
         
     | 
| 
      
 11 
     | 
    
         
            +
                    @tensor_options = @function["func"].include?(@tensor_options_str)
         
     | 
| 
      
 12 
     | 
    
         
            +
                    @out = out_size > 0 && base_name[-1] != "_"
         
     | 
| 
       12 
13 
     | 
    
         
             
                  end
         
     | 
| 
       13 
14 
     | 
    
         | 
| 
       14 
15 
     | 
    
         
             
                  def func
         
     | 
| 
         @@ -31,7 +32,7 @@ module Torch 
     | 
|
| 
       31 
32 
     | 
    
         
             
                    @args ||= begin
         
     | 
| 
       32 
33 
     | 
    
         
             
                      args = []
         
     | 
| 
       33 
34 
     | 
    
         
             
                      pos = true
         
     | 
| 
       34 
     | 
    
         
            -
                      args_str = func.split("(", 2).last.split(") ->").first
         
     | 
| 
      
 35 
     | 
    
         
            +
                      args_str = func.sub(@tensor_options_str, ")").split("(", 2).last.split(") ->").first
         
     | 
| 
       35 
36 
     | 
    
         
             
                      args_str.split(", ").each do |a|
         
     | 
| 
       36 
37 
     | 
    
         
             
                        if a == "*"
         
     | 
| 
       37 
38 
     | 
    
         
             
                          pos = false
         
     | 
| 
         @@ -72,12 +73,88 @@ module Torch 
     | 
|
| 
       72 
73 
     | 
    
         
             
                        next if t == "Generator?"
         
     | 
| 
       73 
74 
     | 
    
         
             
                        next if t == "MemoryFormat"
         
     | 
| 
       74 
75 
     | 
    
         
             
                        next if t == "MemoryFormat?"
         
     | 
| 
       75 
     | 
    
         
            -
                        args << {name: k, type: t, default: d, pos: pos, has_default: has_default}
         
     | 
| 
      
 76 
     | 
    
         
            +
                        args << {name: k.to_sym, type: t, default: d, pos: pos, has_default: has_default}
         
     | 
| 
       76 
77 
     | 
    
         
             
                      end
         
     | 
| 
       77 
78 
     | 
    
         
             
                      args
         
     | 
| 
       78 
79 
     | 
    
         
             
                    end
         
     | 
| 
       79 
80 
     | 
    
         
             
                  end
         
     | 
| 
       80 
81 
     | 
    
         | 
| 
      
 82 
     | 
    
         
            +
                  def arg_checkers
         
     | 
| 
      
 83 
     | 
    
         
            +
                    @arg_checkers ||= begin
         
     | 
| 
      
 84 
     | 
    
         
            +
                      checkers = {}
         
     | 
| 
      
 85 
     | 
    
         
            +
                      arg_types.each do |k, t|
         
     | 
| 
      
 86 
     | 
    
         
            +
                        checker =
         
     | 
| 
      
 87 
     | 
    
         
            +
                          case t
         
     | 
| 
      
 88 
     | 
    
         
            +
                          when "Tensor"
         
     | 
| 
      
 89 
     | 
    
         
            +
                            ->(v) { v.is_a?(Tensor) }
         
     | 
| 
      
 90 
     | 
    
         
            +
                          when "Tensor?"
         
     | 
| 
      
 91 
     | 
    
         
            +
                            ->(v) { v.nil? || v.is_a?(Tensor) }
         
     | 
| 
      
 92 
     | 
    
         
            +
                          when "Tensor[]", "Tensor?[]"
         
     | 
| 
      
 93 
     | 
    
         
            +
                            ->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) } }
         
     | 
| 
      
 94 
     | 
    
         
            +
                          when "int"
         
     | 
| 
      
 95 
     | 
    
         
            +
                            if k == :reduction
         
     | 
| 
      
 96 
     | 
    
         
            +
                              ->(v) { v.is_a?(String) }
         
     | 
| 
      
 97 
     | 
    
         
            +
                            else
         
     | 
| 
      
 98 
     | 
    
         
            +
                              ->(v) { v.is_a?(Integer) }
         
     | 
| 
      
 99 
     | 
    
         
            +
                            end
         
     | 
| 
      
 100 
     | 
    
         
            +
                          when "int?"
         
     | 
| 
      
 101 
     | 
    
         
            +
                            ->(v) { v.is_a?(Integer) || v.nil? }
         
     | 
| 
      
 102 
     | 
    
         
            +
                          when "float?"
         
     | 
| 
      
 103 
     | 
    
         
            +
                            ->(v) { v.is_a?(Numeric) || v.nil? }
         
     | 
| 
      
 104 
     | 
    
         
            +
                          when "bool?"
         
     | 
| 
      
 105 
     | 
    
         
            +
                            ->(v) { v == true || v == false || v.nil? }
         
     | 
| 
      
 106 
     | 
    
         
            +
                          when "float"
         
     | 
| 
      
 107 
     | 
    
         
            +
                            ->(v) { v.is_a?(Numeric) }
         
     | 
| 
      
 108 
     | 
    
         
            +
                          when /int\[.*\]/
         
     | 
| 
      
 109 
     | 
    
         
            +
                            ->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) } }
         
     | 
| 
      
 110 
     | 
    
         
            +
                          when "Scalar"
         
     | 
| 
      
 111 
     | 
    
         
            +
                            ->(v) { v.is_a?(Numeric) }
         
     | 
| 
      
 112 
     | 
    
         
            +
                          when "Scalar?"
         
     | 
| 
      
 113 
     | 
    
         
            +
                            ->(v) { v.is_a?(Numeric) || v.nil? }
         
     | 
| 
      
 114 
     | 
    
         
            +
                          when "ScalarType"
         
     | 
| 
      
 115 
     | 
    
         
            +
                            ->(v) { false } # not supported yet
         
     | 
| 
      
 116 
     | 
    
         
            +
                          when "ScalarType?"
         
     | 
| 
      
 117 
     | 
    
         
            +
                            ->(v) { v.nil? }
         
     | 
| 
      
 118 
     | 
    
         
            +
                          when "bool"
         
     | 
| 
      
 119 
     | 
    
         
            +
                            ->(v) { v == true || v == false }
         
     | 
| 
      
 120 
     | 
    
         
            +
                          when "str"
         
     | 
| 
      
 121 
     | 
    
         
            +
                            ->(v) { v.is_a?(String) }
         
     | 
| 
      
 122 
     | 
    
         
            +
                          else
         
     | 
| 
      
 123 
     | 
    
         
            +
                            raise Error, "Unknown argument type: #{t}. Please report a bug with #{@name}."
         
     | 
| 
      
 124 
     | 
    
         
            +
                          end
         
     | 
| 
      
 125 
     | 
    
         
            +
                        checkers[k] = checker
         
     | 
| 
      
 126 
     | 
    
         
            +
                      end
         
     | 
| 
      
 127 
     | 
    
         
            +
                      checkers
         
     | 
| 
      
 128 
     | 
    
         
            +
                    end
         
     | 
| 
      
 129 
     | 
    
         
            +
                  end
         
     | 
| 
      
 130 
     | 
    
         
            +
             
     | 
| 
      
 131 
     | 
    
         
            +
                  def int_array_lengths
         
     | 
| 
      
 132 
     | 
    
         
            +
                    @int_array_lengths ||= begin
         
     | 
| 
      
 133 
     | 
    
         
            +
                      ret = {}
         
     | 
| 
      
 134 
     | 
    
         
            +
                      arg_types.each do |k, t|
         
     | 
| 
      
 135 
     | 
    
         
            +
                        if t.match?(/\Aint\[.+\]\z/)
         
     | 
| 
      
 136 
     | 
    
         
            +
                          size = t[4..-2]
         
     | 
| 
      
 137 
     | 
    
         
            +
                          raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
         
     | 
| 
      
 138 
     | 
    
         
            +
                          ret[k] = size.to_i
         
     | 
| 
      
 139 
     | 
    
         
            +
                        end
         
     | 
| 
      
 140 
     | 
    
         
            +
                      end
         
     | 
| 
      
 141 
     | 
    
         
            +
                      ret
         
     | 
| 
      
 142 
     | 
    
         
            +
                    end
         
     | 
| 
      
 143 
     | 
    
         
            +
                  end
         
     | 
| 
      
 144 
     | 
    
         
            +
             
     | 
| 
      
 145 
     | 
    
         
            +
                  def arg_names
         
     | 
| 
      
 146 
     | 
    
         
            +
                    @arg_names ||= args.map { |a| a[:name] }
         
     | 
| 
      
 147 
     | 
    
         
            +
                  end
         
     | 
| 
      
 148 
     | 
    
         
            +
             
     | 
| 
      
 149 
     | 
    
         
            +
                  def arg_types
         
     | 
| 
      
 150 
     | 
    
         
            +
                    @arg_types ||= args.map { |a| [a[:name], a[:type].split("(").first] }.to_h
         
     | 
| 
      
 151 
     | 
    
         
            +
                  end
         
     | 
| 
      
 152 
     | 
    
         
            +
             
     | 
| 
      
 153 
     | 
    
         
            +
                  def arg_defaults
         
     | 
| 
      
 154 
     | 
    
         
            +
                    # TODO find out why can't use select here
         
     | 
| 
      
 155 
     | 
    
         
            +
                    @arg_defaults ||= args.map { |a| [a[:name], a[:default]] }.to_h
         
     | 
| 
      
 156 
     | 
    
         
            +
                  end
         
     | 
| 
      
 157 
     | 
    
         
            +
             
     | 
| 
       81 
158 
     | 
    
         
             
                  def out_size
         
     | 
| 
       82 
159 
     | 
    
         
             
                    @out_size ||= func.split("->").last.count("!")
         
     | 
| 
       83 
160 
     | 
    
         
             
                  end
         
     | 
| 
         @@ -90,8 +167,12 @@ module Torch 
     | 
|
| 
       90 
167 
     | 
    
         
             
                    @ret_array ||= func.split("->").last.include?('[]')
         
     | 
| 
       91 
168 
     | 
    
         
             
                  end
         
     | 
| 
       92 
169 
     | 
    
         | 
| 
      
 170 
     | 
    
         
            +
                  def ret_void?
         
     | 
| 
      
 171 
     | 
    
         
            +
                    func.split("->").last.strip == "()"
         
     | 
| 
      
 172 
     | 
    
         
            +
                  end
         
     | 
| 
      
 173 
     | 
    
         
            +
             
     | 
| 
       93 
174 
     | 
    
         
             
                  def out?
         
     | 
| 
       94 
     | 
    
         
            -
                     
     | 
| 
      
 175 
     | 
    
         
            +
                    @out
         
     | 
| 
       95 
176 
     | 
    
         
             
                  end
         
     | 
| 
       96 
177 
     | 
    
         | 
| 
       97 
178 
     | 
    
         
             
                  def ruby_name
         
     | 
| 
         @@ -72,16 +72,18 @@ void add_%{type}_functions(Module m); 
     | 
|
| 
       72 
72 
     | 
    
         
             
            #include <rice/Module.hpp>
         
     | 
| 
       73 
73 
     | 
    
         
             
            #include "templates.hpp"
         
     | 
| 
       74 
74 
     | 
    
         | 
| 
      
 75 
     | 
    
         
            +
            %{functions}
         
     | 
| 
      
 76 
     | 
    
         
            +
             
     | 
| 
       75 
77 
     | 
    
         
             
            void add_%{type}_functions(Module m) {
         
     | 
| 
       76 
     | 
    
         
            -
               
     | 
| 
       77 
     | 
    
         
            -
              %{functions};
         
     | 
| 
      
 78 
     | 
    
         
            +
              %{add_functions}
         
     | 
| 
       78 
79 
     | 
    
         
             
            }
         
     | 
| 
       79 
80 
     | 
    
         
             
                    TEMPLATE
         
     | 
| 
       80 
81 
     | 
    
         | 
| 
       81 
82 
     | 
    
         
             
                      cpp_defs = []
         
     | 
| 
      
 83 
     | 
    
         
            +
                      add_defs = []
         
     | 
| 
       82 
84 
     | 
    
         
             
                      functions.sort_by(&:cpp_name).each do |func|
         
     | 
| 
       83 
85 
     | 
    
         
             
                        fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
         
     | 
| 
       84 
     | 
    
         
            -
                        fargs << {name:  
     | 
| 
      
 86 
     | 
    
         
            +
                        fargs << {name: :options, type: "TensorOptions"} if func.tensor_options
         
     | 
| 
       85 
87 
     | 
    
         | 
| 
       86 
88 
     | 
    
         
             
                        cpp_args = []
         
     | 
| 
       87 
89 
     | 
    
         
             
                        fargs.each do |a|
         
     | 
| 
         @@ -94,11 +96,9 @@ void add_%{type}_functions(Module m) { 
     | 
|
| 
       94 
96 
     | 
    
         
             
                              "OptionalTensor"
         
     | 
| 
       95 
97 
     | 
    
         
             
                            when "ScalarType?"
         
     | 
| 
       96 
98 
     | 
    
         
             
                              "torch::optional<ScalarType>"
         
     | 
| 
       97 
     | 
    
         
            -
                            when "Tensor[]"
         
     | 
| 
       98 
     | 
    
         
            -
                              "TensorList"
         
     | 
| 
       99 
     | 
    
         
            -
                            when "Tensor?[]"
         
     | 
| 
      
 99 
     | 
    
         
            +
                            when "Tensor[]", "Tensor?[]"
         
     | 
| 
       100 
100 
     | 
    
         
             
                              # TODO make optional
         
     | 
| 
       101 
     | 
    
         
            -
                              " 
     | 
| 
      
 101 
     | 
    
         
            +
                              "std::vector<Tensor>"
         
     | 
| 
       102 
102 
     | 
    
         
             
                            when "int"
         
     | 
| 
       103 
103 
     | 
    
         
             
                              "int64_t"
         
     | 
| 
       104 
104 
     | 
    
         
             
                            when "int?"
         
     | 
| 
         @@ -112,43 +112,53 @@ void add_%{type}_functions(Module m) { 
     | 
|
| 
       112 
112 
     | 
    
         
             
                            when "float"
         
     | 
| 
       113 
113 
     | 
    
         
             
                              "double"
         
     | 
| 
       114 
114 
     | 
    
         
             
                            when /\Aint\[/
         
     | 
| 
       115 
     | 
    
         
            -
                              " 
     | 
| 
      
 115 
     | 
    
         
            +
                              "std::vector<int64_t>"
         
     | 
| 
       116 
116 
     | 
    
         
             
                            when /Tensor\(\S!?\)/
         
     | 
| 
       117 
117 
     | 
    
         
             
                              "Tensor &"
         
     | 
| 
       118 
118 
     | 
    
         
             
                            when "str"
         
     | 
| 
       119 
119 
     | 
    
         
             
                              "std::string"
         
     | 
| 
       120 
120 
     | 
    
         
             
                            when "TensorOptions"
         
     | 
| 
       121 
121 
     | 
    
         
             
                              "const torch::TensorOptions &"
         
     | 
| 
       122 
     | 
    
         
            -
                             
     | 
| 
      
 122 
     | 
    
         
            +
                            when "Layout?"
         
     | 
| 
      
 123 
     | 
    
         
            +
                              "torch::optional<Layout>"
         
     | 
| 
      
 124 
     | 
    
         
            +
                            when "Device?"
         
     | 
| 
      
 125 
     | 
    
         
            +
                              "torch::optional<Device>"
         
     | 
| 
      
 126 
     | 
    
         
            +
                            when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage"
         
     | 
| 
       123 
127 
     | 
    
         
             
                              a[:type]
         
     | 
| 
      
 128 
     | 
    
         
            +
                            else
         
     | 
| 
      
 129 
     | 
    
         
            +
                              raise "Unknown type: #{a[:type]}"
         
     | 
| 
       124 
130 
     | 
    
         
             
                            end
         
     | 
| 
       125 
131 
     | 
    
         | 
| 
       126 
     | 
    
         
            -
                          t = "MyReduction" if a[:name] ==  
     | 
| 
      
 132 
     | 
    
         
            +
                          t = "MyReduction" if a[:name] == :reduction && t == "int64_t"
         
     | 
| 
       127 
133 
     | 
    
         
             
                          cpp_args << [t, a[:name]].join(" ").sub("& ", "&")
         
     | 
| 
       128 
134 
     | 
    
         
             
                        end
         
     | 
| 
       129 
135 
     | 
    
         | 
| 
       130 
136 
     | 
    
         
             
                        dispatch = func.out? ? "#{func.base_name}_out" : func.base_name
         
     | 
| 
       131 
137 
     | 
    
         
             
                        args = fargs.map { |a| a[:name] }
         
     | 
| 
       132 
138 
     | 
    
         
             
                        args.unshift(*args.pop(func.out_size)) if func.out?
         
     | 
| 
       133 
     | 
    
         
            -
                        args.delete( 
     | 
| 
      
 139 
     | 
    
         
            +
                        args.delete(:self) if def_method == :define_method
         
     | 
| 
       134 
140 
     | 
    
         | 
| 
       135 
141 
     | 
    
         
             
                        prefix = def_method == :define_method ? "self." : "torch::"
         
     | 
| 
       136 
142 
     | 
    
         | 
| 
       137 
143 
     | 
    
         
             
                        body = "#{prefix}#{dispatch}(#{args.join(", ")})"
         
     | 
| 
       138 
144 
     | 
    
         | 
| 
       139 
     | 
    
         
            -
                        if func. 
     | 
| 
      
 145 
     | 
    
         
            +
                        if func.cpp_name == "_fill_diagonal_"
         
     | 
| 
      
 146 
     | 
    
         
            +
                          body = "to_ruby<torch::Tensor>(#{body})"
         
     | 
| 
      
 147 
     | 
    
         
            +
                        elsif !func.ret_void?
         
     | 
| 
       140 
148 
     | 
    
         
             
                          body = "wrap(#{body})"
         
     | 
| 
       141 
149 
     | 
    
         
             
                        end
         
     | 
| 
       142 
150 
     | 
    
         | 
| 
       143 
     | 
    
         
            -
                        cpp_defs << " 
     | 
| 
       144 
     | 
    
         
            -
             
     | 
| 
       145 
     | 
    
         
            -
             
     | 
| 
       146 
     | 
    
         
            -
             
     | 
| 
       147 
     | 
    
         
            -
             
     | 
| 
      
 151 
     | 
    
         
            +
                        cpp_defs << "// #{func.func}
         
     | 
| 
      
 152 
     | 
    
         
            +
            static #{func.ret_void? ? "void" : "Object"} #{type}#{func.cpp_name}(#{cpp_args.join(", ")})
         
     | 
| 
      
 153 
     | 
    
         
            +
            {
         
     | 
| 
      
 154 
     | 
    
         
            +
              return #{body};
         
     | 
| 
      
 155 
     | 
    
         
            +
            }"
         
     | 
| 
      
 156 
     | 
    
         
            +
             
     | 
| 
      
 157 
     | 
    
         
            +
                        add_defs << "m.#{def_method}(\"#{func.cpp_name}\", #{type}#{func.cpp_name});"
         
     | 
| 
       148 
158 
     | 
    
         
             
                      end
         
     | 
| 
       149 
159 
     | 
    
         | 
| 
       150 
160 
     | 
    
         
             
                      hpp_contents = hpp_template % {type: type}
         
     | 
| 
       151 
     | 
    
         
            -
                      cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n  ")}
         
     | 
| 
      
 161 
     | 
    
         
            +
                      cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n\n"), add_functions: add_defs.join("\n  ")}
         
     | 
| 
       152 
162 
     | 
    
         | 
| 
       153 
163 
     | 
    
         
             
                      path = File.expand_path("../../../ext/torch", __dir__)
         
     | 
| 
       154 
164 
     | 
    
         
             
                      File.write("#{path}/#{type}_functions.hpp", hpp_contents)
         
     | 
    
        data/lib/torch/native/parser.rb
    CHANGED
    
    | 
         @@ -6,14 +6,24 @@ module Torch 
     | 
|
| 
       6 
6 
     | 
    
         
             
                    @name = @functions.first.ruby_name
         
     | 
| 
       7 
7 
     | 
    
         
             
                    @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
         
     | 
| 
       8 
8 
     | 
    
         
             
                    @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
         
     | 
| 
      
 9 
     | 
    
         
            +
                    @int_array_first = @functions.all? { |c| c.args.first && c.args.first[:type] == "int[]" }
         
     | 
| 
       9 
10 
     | 
    
         
             
                  end
         
     | 
| 
       10 
11 
     | 
    
         | 
| 
      
 12 
     | 
    
         
            +
                  # TODO improve performance
         
     | 
| 
      
 13 
     | 
    
         
            +
                  # possibly move to C++ (see python_arg_parser.cpp)
         
     | 
| 
       11 
14 
     | 
    
         
             
                  def parse(args, options)
         
     | 
| 
       12 
15 
     | 
    
         
             
                    candidates = @functions.dup
         
     | 
| 
       13 
16 
     | 
    
         | 
| 
       14 
     | 
    
         
            -
                    #  
     | 
| 
       15 
     | 
    
         
            -
                     
     | 
| 
       16 
     | 
    
         
            -
                       
     | 
| 
      
 17 
     | 
    
         
            +
                    # TODO check candidates individually to see if they match
         
     | 
| 
      
 18 
     | 
    
         
            +
                    if @int_array_first
         
     | 
| 
      
 19 
     | 
    
         
            +
                      int_args = []
         
     | 
| 
      
 20 
     | 
    
         
            +
                      while args.first.is_a?(Integer)
         
     | 
| 
      
 21 
     | 
    
         
            +
                        int_args << args.shift
         
     | 
| 
      
 22 
     | 
    
         
            +
                      end
         
     | 
| 
      
 23 
     | 
    
         
            +
                      if int_args.any?
         
     | 
| 
      
 24 
     | 
    
         
            +
                        raise ArgumentError, "argument '#{candidates.first.args.first[:name]}' must be array of ints, but found element of type #{args.first.class.name} at pos #{int_args.size + 1}" if args.any?
         
     | 
| 
      
 25 
     | 
    
         
            +
                        args.unshift(int_args)
         
     | 
| 
      
 26 
     | 
    
         
            +
                      end
         
     | 
| 
       17 
27 
     | 
    
         
             
                    end
         
     | 
| 
       18 
28 
     | 
    
         | 
| 
       19 
29 
     | 
    
         
             
                    # TODO account for args passed as options here
         
     | 
| 
         @@ -25,99 +35,60 @@ module Torch 
     | 
|
| 
       25 
35 
     | 
    
         | 
| 
       26 
36 
     | 
    
         
             
                    candidates.reject! { |f| args.size > f.args.size }
         
     | 
| 
       27 
37 
     | 
    
         | 
| 
       28 
     | 
    
         
            -
                    # exclude functions missing required options
         
     | 
| 
       29 
     | 
    
         
            -
                    candidates.reject! do |func|
         
     | 
| 
       30 
     | 
    
         
            -
                      # TODO make more generic
         
     | 
| 
       31 
     | 
    
         
            -
                      func.out? && !options[:out]
         
     | 
| 
       32 
     | 
    
         
            -
                    end
         
     | 
| 
       33 
     | 
    
         
            -
             
     | 
| 
       34 
38 
     | 
    
         
             
                    # handle out with multiple
         
     | 
| 
       35 
39 
     | 
    
         
             
                    # there should only be one match, so safe to modify all
         
     | 
| 
       36 
     | 
    
         
            -
                     
     | 
| 
       37 
     | 
    
         
            -
             
     | 
| 
       38 
     | 
    
         
            -
             
     | 
| 
       39 
     | 
    
         
            -
             
     | 
| 
       40 
     | 
    
         
            -
             
     | 
| 
       41 
     | 
    
         
            -
             
     | 
| 
       42 
     | 
    
         
            -
             
     | 
| 
       43 
     | 
    
         
            -
                    end
         
     | 
| 
       44 
     | 
    
         
            -
             
     | 
| 
       45 
     | 
    
         
            -
                    # exclude functions where options don't match
         
     | 
| 
       46 
     | 
    
         
            -
                    options.each do |k, v|
         
     | 
| 
       47 
     | 
    
         
            -
                      candidates.select! do |func|
         
     | 
| 
       48 
     | 
    
         
            -
                        func.args.any? { |a| a[:name] == k.to_s }
         
     | 
| 
      
 40 
     | 
    
         
            +
                    if options[:out]
         
     | 
| 
      
 41 
     | 
    
         
            +
                      if (out_func = candidates.find { |f| f.out? }) && out_func.out_size > 1
         
     | 
| 
      
 42 
     | 
    
         
            +
                        out_args = out_func.args.last(2).map { |a| a[:name] }
         
     | 
| 
      
 43 
     | 
    
         
            +
                        out_args.zip(options.delete(:out)).each do |k, v|
         
     | 
| 
      
 44 
     | 
    
         
            +
                          options[k] = v
         
     | 
| 
      
 45 
     | 
    
         
            +
                        end
         
     | 
| 
      
 46 
     | 
    
         
            +
                        candidates = [out_func]
         
     | 
| 
       49 
47 
     | 
    
         
             
                      end
         
     | 
| 
       50 
     | 
    
         
            -
             
     | 
| 
       51 
     | 
    
         
            -
                       
     | 
| 
      
 48 
     | 
    
         
            +
                    else
         
     | 
| 
      
 49 
     | 
    
         
            +
                      # exclude functions missing required options
         
     | 
| 
      
 50 
     | 
    
         
            +
                      candidates.reject!(&:out?)
         
     | 
| 
       52 
51 
     | 
    
         
             
                    end
         
     | 
| 
       53 
52 
     | 
    
         | 
| 
       54 
     | 
    
         
            -
                    final_values =  
     | 
| 
      
 53 
     | 
    
         
            +
                    final_values = nil
         
     | 
| 
       55 
54 
     | 
    
         | 
| 
       56 
55 
     | 
    
         
             
                    # check args
         
     | 
| 
       57 
     | 
    
         
            -
                    candidates. 
     | 
| 
      
 56 
     | 
    
         
            +
                    while (func = candidates.shift)
         
     | 
| 
       58 
57 
     | 
    
         
             
                      good = true
         
     | 
| 
       59 
58 
     | 
    
         | 
| 
       60 
     | 
    
         
            -
                       
     | 
| 
       61 
     | 
    
         
            -
                       
     | 
| 
       62 
     | 
    
         
            -
                       
     | 
| 
       63 
     | 
    
         
            -
             
     | 
| 
      
 59 
     | 
    
         
            +
                      # set values
         
     | 
| 
      
 60 
     | 
    
         
            +
                      # TODO use array instead of hash?
         
     | 
| 
      
 61 
     | 
    
         
            +
                      values = {}
         
     | 
| 
      
 62 
     | 
    
         
            +
                      args.each_with_index do |a, i|
         
     | 
| 
      
 63 
     | 
    
         
            +
                        values[func.arg_names[i]] = a
         
     | 
| 
      
 64 
     | 
    
         
            +
                      end
         
     | 
| 
      
 65 
     | 
    
         
            +
                      options.each do |k, v|
         
     | 
| 
      
 66 
     | 
    
         
            +
                        values[k] = v
         
     | 
| 
      
 67 
     | 
    
         
            +
                      end
         
     | 
| 
      
 68 
     | 
    
         
            +
                      func.arg_defaults.each do |k, v|
         
     | 
| 
      
 69 
     | 
    
         
            +
                        values[k] = v unless values.key?(k)
         
     | 
| 
      
 70 
     | 
    
         
            +
                      end
         
     | 
| 
      
 71 
     | 
    
         
            +
                      func.int_array_lengths.each do |k, len|
         
     | 
| 
      
 72 
     | 
    
         
            +
                        values[k] = [values[k]] * len if values[k].is_a?(Integer)
         
     | 
| 
       64 
73 
     | 
    
         
             
                      end
         
     | 
| 
       65 
74 
     | 
    
         | 
| 
       66 
     | 
    
         
            -
                       
     | 
| 
      
 75 
     | 
    
         
            +
                      arg_checkers = func.arg_checkers
         
     | 
| 
       67 
76 
     | 
    
         | 
| 
       68 
77 
     | 
    
         
             
                      values.each_key do |k|
         
     | 
| 
       69 
     | 
    
         
            -
                         
     | 
| 
       70 
     | 
    
         
            -
             
     | 
| 
       71 
     | 
    
         
            -
             
     | 
| 
       72 
     | 
    
         
            -
             
     | 
| 
       73 
     | 
    
         
            -
             
     | 
| 
       74 
     | 
    
         
            -
                          when "Tensor"
         
     | 
| 
       75 
     | 
    
         
            -
                            v.is_a?(Tensor)
         
     | 
| 
       76 
     | 
    
         
            -
                          when "Tensor?"
         
     | 
| 
       77 
     | 
    
         
            -
                            v.nil? || v.is_a?(Tensor)
         
     | 
| 
       78 
     | 
    
         
            -
                          when "Tensor[]", "Tensor?[]"
         
     | 
| 
       79 
     | 
    
         
            -
                            v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
         
     | 
| 
       80 
     | 
    
         
            -
                          when "int"
         
     | 
| 
       81 
     | 
    
         
            -
                            if k == "reduction"
         
     | 
| 
       82 
     | 
    
         
            -
                              v.is_a?(String)
         
     | 
| 
       83 
     | 
    
         
            -
                            else
         
     | 
| 
       84 
     | 
    
         
            -
                              v.is_a?(Integer)
         
     | 
| 
       85 
     | 
    
         
            -
                            end
         
     | 
| 
       86 
     | 
    
         
            -
                          when "int?"
         
     | 
| 
       87 
     | 
    
         
            -
                            v.is_a?(Integer) || v.nil?
         
     | 
| 
       88 
     | 
    
         
            -
                          when "float?"
         
     | 
| 
       89 
     | 
    
         
            -
                            v.is_a?(Numeric) || v.nil?
         
     | 
| 
       90 
     | 
    
         
            -
                          when "bool?"
         
     | 
| 
       91 
     | 
    
         
            -
                            v == true || v == false || v.nil?
         
     | 
| 
       92 
     | 
    
         
            -
                          when "float"
         
     | 
| 
       93 
     | 
    
         
            -
                            v.is_a?(Numeric)
         
     | 
| 
       94 
     | 
    
         
            -
                          when /int\[.*\]/
         
     | 
| 
       95 
     | 
    
         
            -
                            if v.is_a?(Integer)
         
     | 
| 
       96 
     | 
    
         
            -
                              size = t[4..-2]
         
     | 
| 
       97 
     | 
    
         
            -
                              raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
         
     | 
| 
       98 
     | 
    
         
            -
                              v = [v] * size.to_i
         
     | 
| 
       99 
     | 
    
         
            -
                              values[k] = v
         
     | 
| 
       100 
     | 
    
         
            -
                            end
         
     | 
| 
       101 
     | 
    
         
            -
                            v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
         
     | 
| 
       102 
     | 
    
         
            -
                          when "Scalar"
         
     | 
| 
       103 
     | 
    
         
            -
                            v.is_a?(Numeric)
         
     | 
| 
       104 
     | 
    
         
            -
                          when "Scalar?"
         
     | 
| 
       105 
     | 
    
         
            -
                            v.is_a?(Numeric) || v.nil?
         
     | 
| 
       106 
     | 
    
         
            -
                          when "ScalarType"
         
     | 
| 
       107 
     | 
    
         
            -
                            false # not supported yet
         
     | 
| 
       108 
     | 
    
         
            -
                          when "ScalarType?"
         
     | 
| 
       109 
     | 
    
         
            -
                            v.nil?
         
     | 
| 
       110 
     | 
    
         
            -
                          when "bool"
         
     | 
| 
       111 
     | 
    
         
            -
                            v == true || v == false
         
     | 
| 
       112 
     | 
    
         
            -
                          when "str"
         
     | 
| 
       113 
     | 
    
         
            -
                            v.is_a?(String)
         
     | 
| 
       114 
     | 
    
         
            -
                          else
         
     | 
| 
       115 
     | 
    
         
            -
                            raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}."
         
     | 
| 
      
 78 
     | 
    
         
            +
                        unless arg_checkers.key?(k)
         
     | 
| 
      
 79 
     | 
    
         
            +
                          good = false
         
     | 
| 
      
 80 
     | 
    
         
            +
                          if candidates.empty?
         
     | 
| 
      
 81 
     | 
    
         
            +
                            # TODO show all bad keywords at once like Ruby?
         
     | 
| 
      
 82 
     | 
    
         
            +
                            return {error: "unknown keyword: #{k}"}
         
     | 
| 
       116 
83 
     | 
    
         
             
                          end
         
     | 
| 
      
 84 
     | 
    
         
            +
                          break
         
     | 
| 
      
 85 
     | 
    
         
            +
                        end
         
     | 
| 
       117 
86 
     | 
    
         | 
| 
       118 
     | 
    
         
            -
                         
     | 
| 
       119 
     | 
    
         
            -
                           
     | 
| 
       120 
     | 
    
         
            -
             
     | 
| 
      
 87 
     | 
    
         
            +
                        unless arg_checkers[k].call(values[k])
         
     | 
| 
      
 88 
     | 
    
         
            +
                          good = false
         
     | 
| 
      
 89 
     | 
    
         
            +
                          if candidates.empty?
         
     | 
| 
      
 90 
     | 
    
         
            +
                            t = func.arg_types[k]
         
     | 
| 
      
 91 
     | 
    
         
            +
                            k = :input if k == :self
         
     | 
| 
       121 
92 
     | 
    
         
             
                            return {error: "#{@name}(): argument '#{k}' must be #{t}"}
         
     | 
| 
       122 
93 
     | 
    
         
             
                          end
         
     | 
| 
       123 
94 
     | 
    
         
             
                          break
         
     | 
| 
         @@ -126,17 +97,15 @@ module Torch 
     | 
|
| 
       126 
97 
     | 
    
         | 
| 
       127 
98 
     | 
    
         
             
                      if good
         
     | 
| 
       128 
99 
     | 
    
         
             
                        final_values = values
         
     | 
| 
      
 100 
     | 
    
         
            +
                        break
         
     | 
| 
       129 
101 
     | 
    
         
             
                      end
         
     | 
| 
       130 
     | 
    
         
            -
             
     | 
| 
       131 
     | 
    
         
            -
                      good
         
     | 
| 
       132 
102 
     | 
    
         
             
                    end
         
     | 
| 
       133 
103 
     | 
    
         | 
| 
       134 
     | 
    
         
            -
                     
     | 
| 
      
 104 
     | 
    
         
            +
                    unless final_values
         
     | 
| 
       135 
105 
     | 
    
         
             
                      raise Error, "This should never happen. Please report a bug with #{@name}."
         
     | 
| 
       136 
106 
     | 
    
         
             
                    end
         
     | 
| 
       137 
107 
     | 
    
         | 
| 
       138 
     | 
    
         
            -
                     
     | 
| 
       139 
     | 
    
         
            -
                    args = func.args.map { |a| final_values[a[:name]] }
         
     | 
| 
      
 108 
     | 
    
         
            +
                    args = func.arg_names.map { |k| final_values[k] }
         
     | 
| 
       140 
109 
     | 
    
         
             
                    args << TensorOptions.new.dtype(6) if func.tensor_options
         
     | 
| 
       141 
110 
     | 
    
         
             
                    {
         
     | 
| 
       142 
111 
     | 
    
         
             
                      name: func.cpp_name,
         
     | 
    
        data/lib/torch/nn/functional.rb
    CHANGED
    
    | 
         @@ -469,6 +469,77 @@ module Torch 
     | 
|
| 
       469 
469 
     | 
    
         
             
                      Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
         
     | 
| 
       470 
470 
     | 
    
         
             
                    end
         
     | 
| 
       471 
471 
     | 
    
         | 
| 
      
 472 
     | 
    
         
            +
                    # vision
         
     | 
| 
      
 473 
     | 
    
         
            +
             
     | 
| 
      
 474 
     | 
    
         
            +
                    def interpolate(input, size: nil, scale_factor: nil, mode: "nearest", align_corners: nil, recompute_scale_factor: nil)
         
     | 
| 
      
 475 
     | 
    
         
            +
                      if ["nearest", "area"].include?(mode)
         
     | 
| 
      
 476 
     | 
    
         
            +
                        unless align_corners.nil?
         
     | 
| 
      
 477 
     | 
    
         
            +
                          raise ArgumentError, "align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear"
         
     | 
| 
      
 478 
     | 
    
         
            +
                        end
         
     | 
| 
      
 479 
     | 
    
         
            +
                      else
         
     | 
| 
      
 480 
     | 
    
         
            +
                        if align_corners.nil?
         
     | 
| 
      
 481 
     | 
    
         
            +
                          align_corners = false
         
     | 
| 
      
 482 
     | 
    
         
            +
                        end
         
     | 
| 
      
 483 
     | 
    
         
            +
                      end
         
     | 
| 
      
 484 
     | 
    
         
            +
             
     | 
| 
      
 485 
     | 
    
         
            +
                      scale_factor_len = input.dim - 2
         
     | 
| 
      
 486 
     | 
    
         
            +
                      scale_factor_list = [nil] * scale_factor_len
         
     | 
| 
      
 487 
     | 
    
         
            +
                      # default value of recompute_scale_factor is False
         
     | 
| 
      
 488 
     | 
    
         
            +
                      if !scale_factor.nil? && (recompute_scale_factor == false || recompute_scale_factor.nil?)
         
     | 
| 
      
 489 
     | 
    
         
            +
                        if scale_factor.is_a?(Array)
         
     | 
| 
      
 490 
     | 
    
         
            +
                          _scale_factor_repeated = scale_factor
         
     | 
| 
      
 491 
     | 
    
         
            +
                        else
         
     | 
| 
      
 492 
     | 
    
         
            +
                          _scale_factor_repeated = [scale_factor] * scale_factor_len
         
     | 
| 
      
 493 
     | 
    
         
            +
                        end
         
     | 
| 
      
 494 
     | 
    
         
            +
                        scale_factor_list = _scale_factor_repeated
         
     | 
| 
      
 495 
     | 
    
         
            +
                      end
         
     | 
| 
      
 496 
     | 
    
         
            +
             
     | 
| 
      
 497 
     | 
    
         
            +
                      # Give this variable a short name because it has to be repeated multiple times below.
         
     | 
| 
      
 498 
     | 
    
         
            +
                      sfl = scale_factor_list
         
     | 
| 
      
 499 
     | 
    
         
            +
             
     | 
| 
      
 500 
     | 
    
         
            +
                      closed_over_args = [input, size, scale_factor, recompute_scale_factor]
         
     | 
| 
      
 501 
     | 
    
         
            +
                      output_size = _interp_output_size(closed_over_args)
         
     | 
| 
      
 502 
     | 
    
         
            +
                      if input.dim == 3 && mode == "nearest"
         
     | 
| 
      
 503 
     | 
    
         
            +
                        NN.upsample_nearest1d(input, output_size, sfl[0])
         
     | 
| 
      
 504 
     | 
    
         
            +
                      elsif input.dim == 4 && mode == "nearest"
         
     | 
| 
      
 505 
     | 
    
         
            +
                        NN.upsample_nearest2d(input, output_size, sfl[0], sfl[1])
         
     | 
| 
      
 506 
     | 
    
         
            +
                      elsif input.dim == 5 && mode == "nearest"
         
     | 
| 
      
 507 
     | 
    
         
            +
                        NN.upsample_nearest3d(input, output_size, sfl[0], sfl[1], sfl[2])
         
     | 
| 
      
 508 
     | 
    
         
            +
                      elsif input.dim == 3 && mode == "area"
         
     | 
| 
      
 509 
     | 
    
         
            +
                        adaptive_avg_pool1d(input, output_size)
         
     | 
| 
      
 510 
     | 
    
         
            +
                      elsif input.dim == 4 && mode == "area"
         
     | 
| 
      
 511 
     | 
    
         
            +
                        adaptive_avg_pool2d(input, output_size)
         
     | 
| 
      
 512 
     | 
    
         
            +
                      elsif input.dim == 5 && mode == "area"
         
     | 
| 
      
 513 
     | 
    
         
            +
                        adaptive_avg_pool3d(input, output_size)
         
     | 
| 
      
 514 
     | 
    
         
            +
                      elsif input.dim == 3 && mode == "linear"
         
     | 
| 
      
 515 
     | 
    
         
            +
                        # assert align_corners is not None
         
     | 
| 
      
 516 
     | 
    
         
            +
                        NN.upsample_linear1d(input, output_size, align_corners, sfl[0])
         
     | 
| 
      
 517 
     | 
    
         
            +
                      elsif input.dim == 3 && mode == "bilinear"
         
     | 
| 
      
 518 
     | 
    
         
            +
                        raise ArgumentError, "Got 3D input, but bilinear mode needs 4D input"
         
     | 
| 
      
 519 
     | 
    
         
            +
                      elsif input.dim == 3 && mode == "trilinear"
         
     | 
| 
      
 520 
     | 
    
         
            +
                        raise ArgumentError, "Got 3D input, but trilinear mode needs 5D input"
         
     | 
| 
      
 521 
     | 
    
         
            +
                      elsif input.dim == 4 && mode == "linear"
         
     | 
| 
      
 522 
     | 
    
         
            +
                        raise ArgumentError, "Got 4D input, but linear mode needs 3D input"
         
     | 
| 
      
 523 
     | 
    
         
            +
                      elsif input.dim == 4 && mode == "bilinear"
         
     | 
| 
      
 524 
     | 
    
         
            +
                        # assert align_corners is not None
         
     | 
| 
      
 525 
     | 
    
         
            +
                        NN.upsample_bilinear2d(input, output_size, align_corners, sfl[0], sfl[1])
         
     | 
| 
      
 526 
     | 
    
         
            +
                      elsif input.dim == 4 && mode == "trilinear"
         
     | 
| 
      
 527 
     | 
    
         
            +
                        raise ArgumentError, "Got 4D input, but trilinear mode needs 5D input"
         
     | 
| 
      
 528 
     | 
    
         
            +
                      elsif input.dim == 5 && mode == "linear"
         
     | 
| 
      
 529 
     | 
    
         
            +
                        raise ArgumentError, "Got 5D input, but linear mode needs 3D input"
         
     | 
| 
      
 530 
     | 
    
         
            +
                      elsif input.dim == 5 && mode == "bilinear"
         
     | 
| 
      
 531 
     | 
    
         
            +
                        raise ArgumentError, "Got 5D input, but bilinear mode needs 4D input"
         
     | 
| 
      
 532 
     | 
    
         
            +
                      elsif input.dim == 5 && mode == "trilinear"
         
     | 
| 
      
 533 
     | 
    
         
            +
                        # assert align_corners is not None
         
     | 
| 
      
 534 
     | 
    
         
            +
                        NN.upsample_trilinear3d(input, output_size, align_corners, sfl[0], sfl[1], sfl[2])
         
     | 
| 
      
 535 
     | 
    
         
            +
                      elsif input.dim == 4 && mode == "bicubic"
         
     | 
| 
      
 536 
     | 
    
         
            +
                        # assert align_corners is not None
         
     | 
| 
      
 537 
     | 
    
         
            +
                        NN.upsample_bicubic2d(input, output_size, align_corners, sfl[0], sfl[1])
         
     | 
| 
      
 538 
     | 
    
         
            +
                      else
         
     | 
| 
      
 539 
     | 
    
         
            +
                        raise ArgumentError, "Input Error: Only 3D, 4D and 5D input Tensors supported (got #{input.dim}D) for the modes: nearest | linear | bilinear | bicubic | trilinear (got #{mode})"
         
     | 
| 
      
 540 
     | 
    
         
            +
                      end
         
     | 
| 
      
 541 
     | 
    
         
            +
                    end
         
     | 
| 
      
 542 
     | 
    
         
            +
             
     | 
| 
       472 
543 
     | 
    
         
             
                    private
         
     | 
| 
       473 
544 
     | 
    
         | 
| 
       474 
545 
     | 
    
         
             
                    def softmax_dim(ndim)
         
     | 
| 
         @@ -484,6 +555,41 @@ module Torch 
     | 
|
| 
       484 
555 
     | 
    
         
             
                        out_size.zip(defaults.last(out_size.length)).map { |v, d| v || d }
         
     | 
| 
       485 
556 
     | 
    
         
             
                      end
         
     | 
| 
       486 
557 
     | 
    
         
             
                    end
         
     | 
| 
      
 558 
     | 
    
         
            +
             
     | 
| 
      
 559 
     | 
    
         
            +
                    def _interp_output_size(closed_over_args)
         
     | 
| 
      
 560 
     | 
    
         
            +
                      input, size, scale_factor, recompute_scale_factor = closed_over_args
         
     | 
| 
      
 561 
     | 
    
         
            +
                      dim = input.dim - 2
         
     | 
| 
      
 562 
     | 
    
         
            +
                      if size.nil? && scale_factor.nil?
         
     | 
| 
      
 563 
     | 
    
         
            +
                        raise ArgumentError, "either size or scale_factor should be defined"
         
     | 
| 
      
 564 
     | 
    
         
            +
                      end
         
     | 
| 
      
 565 
     | 
    
         
            +
                      if !size.nil? && !scale_factor.nil?
         
     | 
| 
      
 566 
     | 
    
         
            +
                        raise ArgumentError, "only one of size or scale_factor should be defined"
         
     | 
| 
      
 567 
     | 
    
         
            +
                      end
         
     | 
| 
      
 568 
     | 
    
         
            +
                      if !scale_factor.nil?
         
     | 
| 
      
 569 
     | 
    
         
            +
                        if scale_factor.is_a?(Array)
         
     | 
| 
      
 570 
     | 
    
         
            +
                          if scale_factor.length != dim
         
     | 
| 
      
 571 
     | 
    
         
            +
                            raise ArgumentError, "scale_factor shape must match input shape. Input is #{dim}D, scale_factor size is #{scale_factor.length}"
         
     | 
| 
      
 572 
     | 
    
         
            +
                          end
         
     | 
| 
      
 573 
     | 
    
         
            +
                        end
         
     | 
| 
      
 574 
     | 
    
         
            +
                      end
         
     | 
| 
      
 575 
     | 
    
         
            +
             
     | 
| 
      
 576 
     | 
    
         
            +
                      if !size.nil?
         
     | 
| 
      
 577 
     | 
    
         
            +
                        if size.is_a?(Array)
         
     | 
| 
      
 578 
     | 
    
         
            +
                          return size
         
     | 
| 
      
 579 
     | 
    
         
            +
                        else
         
     | 
| 
      
 580 
     | 
    
         
            +
                          return [size] * dim
         
     | 
| 
      
 581 
     | 
    
         
            +
                        end
         
     | 
| 
      
 582 
     | 
    
         
            +
                      end
         
     | 
| 
      
 583 
     | 
    
         
            +
             
     | 
| 
      
 584 
     | 
    
         
            +
                      raise "Failed assertion" if scale_factor.nil?
         
     | 
| 
      
 585 
     | 
    
         
            +
                      if scale_factor.is_a?(Array)
         
     | 
| 
      
 586 
     | 
    
         
            +
                        scale_factors = scale_factor
         
     | 
| 
      
 587 
     | 
    
         
            +
                      else
         
     | 
| 
      
 588 
     | 
    
         
            +
                        scale_factors = [scale_factor] * dim
         
     | 
| 
      
 589 
     | 
    
         
            +
                      end
         
     | 
| 
      
 590 
     | 
    
         
            +
             
     | 
| 
      
 591 
     | 
    
         
            +
                      dim.times.map { |i| (input.size(i + 2) * scale_factors[i]).floor }
         
     | 
| 
      
 592 
     | 
    
         
            +
                    end
         
     | 
| 
       487 
593 
     | 
    
         
             
                  end
         
     | 
| 
       488 
594 
     | 
    
         
             
                end
         
     | 
| 
       489 
595 
     | 
    
         | 
    
        data/lib/torch/nn/module.rb
    CHANGED
    
    
| 
         @@ -0,0 +1,31 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            module Torch
         
     | 
| 
      
 2 
     | 
    
         
            +
              module NN
         
     | 
| 
      
 3 
     | 
    
         
            +
                class Upsample < Module
         
     | 
| 
      
 4 
     | 
    
         
            +
                  def initialize(size: nil, scale_factor: nil, mode: "nearest", align_corners: nil)
         
     | 
| 
      
 5 
     | 
    
         
            +
                    super()
         
     | 
| 
      
 6 
     | 
    
         
            +
                    @size = size
         
     | 
| 
      
 7 
     | 
    
         
            +
                    if scale_factor.is_a?(Array)
         
     | 
| 
      
 8 
     | 
    
         
            +
                      @scale_factor = scale_factor.map(&:to_f)
         
     | 
| 
      
 9 
     | 
    
         
            +
                    else
         
     | 
| 
      
 10 
     | 
    
         
            +
                      @scale_factor = scale_factor ? scale_factor.to_f : nil
         
     | 
| 
      
 11 
     | 
    
         
            +
                    end
         
     | 
| 
      
 12 
     | 
    
         
            +
                    @mode = mode
         
     | 
| 
      
 13 
     | 
    
         
            +
                    @align_corners = align_corners
         
     | 
| 
      
 14 
     | 
    
         
            +
                  end
         
     | 
| 
      
 15 
     | 
    
         
            +
             
     | 
| 
      
 16 
     | 
    
         
            +
                  def forward(input)
         
     | 
| 
      
 17 
     | 
    
         
            +
                    F.interpolate(input, size: @size, scale_factor: @scale_factor, mode: @mode, align_corners: @align_corners)
         
     | 
| 
      
 18 
     | 
    
         
            +
                  end
         
     | 
| 
      
 19 
     | 
    
         
            +
             
     | 
| 
      
 20 
     | 
    
         
            +
                  def extra_inspect
         
     | 
| 
      
 21 
     | 
    
         
            +
                    if !@scale_factor.nil?
         
     | 
| 
      
 22 
     | 
    
         
            +
                      info = "scale_factor: #{@scale_factor.inspect}"
         
     | 
| 
      
 23 
     | 
    
         
            +
                    else
         
     | 
| 
      
 24 
     | 
    
         
            +
                      info = "size: #{@size.inspect}"
         
     | 
| 
      
 25 
     | 
    
         
            +
                    end
         
     | 
| 
      
 26 
     | 
    
         
            +
                    info += ", mode: #{@mode.inspect}"
         
     | 
| 
      
 27 
     | 
    
         
            +
                    info
         
     | 
| 
      
 28 
     | 
    
         
            +
                  end
         
     | 
| 
      
 29 
     | 
    
         
            +
                end
         
     | 
| 
      
 30 
     | 
    
         
            +
              end
         
     | 
| 
      
 31 
     | 
    
         
            +
            end
         
     | 
    
        data/lib/torch/tensor.rb
    CHANGED
    
    | 
         @@ -48,6 +48,11 @@ module Torch 
     | 
|
| 
       48 
48 
     | 
    
         
             
                end
         
     | 
| 
       49 
49 
     | 
    
         | 
| 
       50 
50 
     | 
    
         
             
                def to(device = nil, dtype: nil, non_blocking: false, copy: false)
         
     | 
| 
      
 51 
     | 
    
         
            +
                  if device.is_a?(Symbol) && !dtype
         
     | 
| 
      
 52 
     | 
    
         
            +
                    dtype = device
         
     | 
| 
      
 53 
     | 
    
         
            +
                    device = nil
         
     | 
| 
      
 54 
     | 
    
         
            +
                  end
         
     | 
| 
      
 55 
     | 
    
         
            +
             
     | 
| 
       51 
56 
     | 
    
         
             
                  device ||= self.device
         
     | 
| 
       52 
57 
     | 
    
         
             
                  device = Device.new(device) if device.is_a?(String)
         
     | 
| 
       53 
58 
     | 
    
         | 
| 
         @@ -74,10 +79,6 @@ module Torch 
     | 
|
| 
       74 
79 
     | 
    
         
             
                  end
         
     | 
| 
       75 
80 
     | 
    
         
             
                end
         
     | 
| 
       76 
81 
     | 
    
         | 
| 
       77 
     | 
    
         
            -
                def shape
         
     | 
| 
       78 
     | 
    
         
            -
                  dim.times.map { |i| size(i) }
         
     | 
| 
       79 
     | 
    
         
            -
                end
         
     | 
| 
       80 
     | 
    
         
            -
             
     | 
| 
       81 
82 
     | 
    
         
             
                # mirror Python len()
         
     | 
| 
       82 
83 
     | 
    
         
             
                def length
         
     | 
| 
       83 
84 
     | 
    
         
             
                  size(0)
         
     | 
| 
         @@ -119,9 +120,14 @@ module Torch 
     | 
|
| 
       119 
120 
     | 
    
         
             
                end
         
     | 
| 
       120 
121 
     | 
    
         | 
| 
       121 
122 
     | 
    
         
             
                def type(dtype)
         
     | 
| 
       122 
     | 
    
         
            -
                   
     | 
| 
       123 
     | 
    
         
            -
             
     | 
| 
       124 
     | 
    
         
            -
             
     | 
| 
      
 123 
     | 
    
         
            +
                  if dtype.is_a?(Class)
         
     | 
| 
      
 124 
     | 
    
         
            +
                    raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)
         
     | 
| 
      
 125 
     | 
    
         
            +
                    dtype.new(self)
         
     | 
| 
      
 126 
     | 
    
         
            +
                  else
         
     | 
| 
      
 127 
     | 
    
         
            +
                    enum = DTYPE_TO_ENUM[dtype]
         
     | 
| 
      
 128 
     | 
    
         
            +
                    raise Error, "Invalid type: #{dtype}" unless enum
         
     | 
| 
      
 129 
     | 
    
         
            +
                    _type(enum)
         
     | 
| 
      
 130 
     | 
    
         
            +
                  end
         
     | 
| 
       125 
131 
     | 
    
         
             
                end
         
     | 
| 
       126 
132 
     | 
    
         | 
| 
       127 
133 
     | 
    
         
             
                def reshape(*size)
         
     | 
    
        data/lib/torch/version.rb
    CHANGED
    
    
    
        metadata
    CHANGED
    
    | 
         @@ -1,14 +1,14 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            --- !ruby/object:Gem::Specification
         
     | 
| 
       2 
2 
     | 
    
         
             
            name: torch-rb
         
     | 
| 
       3 
3 
     | 
    
         
             
            version: !ruby/object:Gem::Version
         
     | 
| 
       4 
     | 
    
         
            -
              version: 0.3. 
     | 
| 
      
 4 
     | 
    
         
            +
              version: 0.3.7
         
     | 
| 
       5 
5 
     | 
    
         
             
            platform: ruby
         
     | 
| 
       6 
6 
     | 
    
         
             
            authors:
         
     | 
| 
       7 
7 
     | 
    
         
             
            - Andrew Kane
         
     | 
| 
       8 
8 
     | 
    
         
             
            autorequire: 
         
     | 
| 
       9 
9 
     | 
    
         
             
            bindir: bin
         
     | 
| 
       10 
10 
     | 
    
         
             
            cert_chain: []
         
     | 
| 
       11 
     | 
    
         
            -
            date: 2020-09- 
     | 
| 
      
 11 
     | 
    
         
            +
            date: 2020-09-23 00:00:00.000000000 Z
         
     | 
| 
       12 
12 
     | 
    
         
             
            dependencies:
         
     | 
| 
       13 
13 
     | 
    
         
             
            - !ruby/object:Gem::Dependency
         
     | 
| 
       14 
14 
     | 
    
         
             
              name: rice
         
     | 
| 
         @@ -238,6 +238,7 @@ files: 
     | 
|
| 
       238 
238 
     | 
    
         
             
            - lib/torch/nn/tanhshrink.rb
         
     | 
| 
       239 
239 
     | 
    
         
             
            - lib/torch/nn/triplet_margin_loss.rb
         
     | 
| 
       240 
240 
     | 
    
         
             
            - lib/torch/nn/unfold.rb
         
     | 
| 
      
 241 
     | 
    
         
            +
            - lib/torch/nn/upsample.rb
         
     | 
| 
       241 
242 
     | 
    
         
             
            - lib/torch/nn/utils.rb
         
     | 
| 
       242 
243 
     | 
    
         
             
            - lib/torch/nn/weighted_loss.rb
         
     | 
| 
       243 
244 
     | 
    
         
             
            - lib/torch/nn/zero_pad2d.rb
         
     |