torch-rb 0.3.5 → 0.3.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
 - data/CHANGELOG.md +6 -0
 - data/ext/torch/ext.cpp +0 -10
 - data/ext/torch/templates.cpp +8 -0
 - data/ext/torch/templates.hpp +46 -24
 - data/lib/torch.rb +16 -19
 - data/lib/torch/native/function.rb +4 -0
 - data/lib/torch/native/generator.rb +10 -5
 - data/lib/torch/native/parser.rb +8 -0
 - data/lib/torch/nn/functional.rb +6 -2
 - data/lib/torch/nn/leaky_relu.rb +3 -3
 - data/lib/torch/nn/module.rb +6 -1
 - data/lib/torch/tensor.rb +1 -6
 - data/lib/torch/utils/data/data_loader.rb +2 -0
 - data/lib/torch/utils/data/tensor_dataset.rb +2 -0
 - data/lib/torch/version.rb +1 -1
 - metadata +2 -2
 
    
        checksums.yaml
    CHANGED
    
    | 
         @@ -1,7 +1,7 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            ---
         
     | 
| 
       2 
2 
     | 
    
         
             
            SHA256:
         
     | 
| 
       3 
     | 
    
         
            -
              metadata.gz:  
     | 
| 
       4 
     | 
    
         
            -
              data.tar.gz:  
     | 
| 
      
 3 
     | 
    
         
            +
              metadata.gz: f7b85027dfbb5a3d8de3741d00f4256fd13a6b5496a123564472a48d8a084c1b
         
     | 
| 
      
 4 
     | 
    
         
            +
              data.tar.gz: 5c684e45ec115ce3b9cc5a3e223ee73cac3dce5fbacae1bd4d4faa7cf49adc5f
         
     | 
| 
       5 
5 
     | 
    
         
             
            SHA512:
         
     | 
| 
       6 
     | 
    
         
            -
              metadata.gz:  
     | 
| 
       7 
     | 
    
         
            -
              data.tar.gz:  
     | 
| 
      
 6 
     | 
    
         
            +
              metadata.gz: acb727d9836709e5db4df21aeb6eec401e10f3c1910f95877493a9b1920cef4a0bd4914b906dab9b2ec18071fe95bf50de91a8a00a0914f3876ecb851e1c19c7
         
     | 
| 
      
 7 
     | 
    
         
            +
              data.tar.gz: 7e69bde091825d7dcda81cfcfebd220c8072442322071a73eafd2849d9a899229bbcc2ce2b80f78e4b44e468e7a263ec83fd9a3fb3c1bf3073573596f40ec143
         
     | 
    
        data/CHANGELOG.md
    CHANGED
    
    | 
         @@ -1,3 +1,9 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            ## 0.3.6 (2020-09-17)
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            - Added `inplace` option for leaky ReLU
         
     | 
| 
      
 4 
     | 
    
         
            +
            - Fixed error with methods that return a tensor list (`chunk`, `split`, and `unbind`)
         
     | 
| 
      
 5 
     | 
    
         
            +
            - Fixed error with buffers on GPU
         
     | 
| 
      
 6 
     | 
    
         
            +
             
     | 
| 
       1 
7 
     | 
    
         
             
            ## 0.3.5 (2020-09-04)
         
     | 
| 
       2 
8 
     | 
    
         | 
| 
       3 
9 
     | 
    
         
             
            - Fixed error with data loader (due to `dtype` of `randperm`)
         
     | 
    
        data/ext/torch/ext.cpp
    CHANGED
    
    | 
         @@ -301,11 +301,6 @@ void Init_ext() 
     | 
|
| 
       301 
301 
     | 
    
         
             
                    // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
         
     | 
| 
       302 
302 
     | 
    
         
             
                    return torch::pickle_load(v);
         
     | 
| 
       303 
303 
     | 
    
         
             
                  })
         
     | 
| 
       304 
     | 
    
         
            -
                .define_singleton_method(
         
     | 
| 
       305 
     | 
    
         
            -
                  "_binary_cross_entropy_with_logits",
         
     | 
| 
       306 
     | 
    
         
            -
                  *[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
         
     | 
| 
       307 
     | 
    
         
            -
                    return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction);
         
     | 
| 
       308 
     | 
    
         
            -
                  })
         
     | 
| 
       309 
304 
     | 
    
         
             
                .define_singleton_method(
         
     | 
| 
       310 
305 
     | 
    
         
             
                  "_from_blob",
         
     | 
| 
       311 
306 
     | 
    
         
             
                  *[](String s, IntArrayRef size, const torch::TensorOptions &options) {
         
     | 
| 
         @@ -379,11 +374,6 @@ void Init_ext() 
     | 
|
| 
       379 
374 
     | 
    
         
             
                  *[](Tensor& self, bool requires_grad) {
         
     | 
| 
       380 
375 
     | 
    
         
             
                    return self.set_requires_grad(requires_grad);
         
     | 
| 
       381 
376 
     | 
    
         
             
                  })
         
     | 
| 
       382 
     | 
    
         
            -
                .define_method(
         
     | 
| 
       383 
     | 
    
         
            -
                  "_backward",
         
     | 
| 
       384 
     | 
    
         
            -
                  *[](Tensor& self, OptionalTensor gradient, bool create_graph, bool retain_graph) {
         
     | 
| 
       385 
     | 
    
         
            -
                    return self.backward(gradient, create_graph, retain_graph);
         
     | 
| 
       386 
     | 
    
         
            -
                  })
         
     | 
| 
       387 
377 
     | 
    
         
             
                .define_method(
         
     | 
| 
       388 
378 
     | 
    
         
             
                  "grad",
         
     | 
| 
       389 
379 
     | 
    
         
             
                  *[](Tensor& self) {
         
     | 
    
        data/ext/torch/templates.cpp
    CHANGED
    
    | 
         @@ -53,3 +53,11 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) { 
     | 
|
| 
       53 
53 
     | 
    
         
             
              a.push(to_ruby<int64_t>(std::get<3>(x)));
         
     | 
| 
       54 
54 
     | 
    
         
             
              return Object(a);
         
     | 
| 
       55 
55 
     | 
    
         
             
            }
         
     | 
| 
      
 56 
     | 
    
         
            +
             
     | 
| 
      
 57 
     | 
    
         
            +
            Object wrap(std::vector<torch::Tensor> x) {
         
     | 
| 
      
 58 
     | 
    
         
            +
              Array a;
         
     | 
| 
      
 59 
     | 
    
         
            +
              for (auto& t : x) {
         
     | 
| 
      
 60 
     | 
    
         
            +
                a.push(to_ruby<torch::Tensor>(t));
         
     | 
| 
      
 61 
     | 
    
         
            +
              }
         
     | 
| 
      
 62 
     | 
    
         
            +
              return Object(a);
         
     | 
| 
      
 63 
     | 
    
         
            +
            }
         
     | 
    
        data/ext/torch/templates.hpp
    CHANGED
    
    | 
         @@ -10,6 +10,7 @@ 
     | 
|
| 
       10 
10 
     | 
    
         
             
            using namespace Rice;
         
     | 
| 
       11 
11 
     | 
    
         | 
| 
       12 
12 
     | 
    
         
             
            using torch::Device;
         
     | 
| 
      
 13 
     | 
    
         
            +
            using torch::Scalar;
         
     | 
| 
       13 
14 
     | 
    
         
             
            using torch::ScalarType;
         
     | 
| 
       14 
15 
     | 
    
         
             
            using torch::Tensor;
         
     | 
| 
       15 
16 
     | 
    
         | 
| 
         @@ -36,30 +37,6 @@ IntArrayRef from_ruby<IntArrayRef>(Object x) 
     | 
|
| 
       36 
37 
     | 
    
         
             
              return IntArrayRef(x);
         
     | 
| 
       37 
38 
     | 
    
         
             
            }
         
     | 
| 
       38 
39 
     | 
    
         | 
| 
       39 
     | 
    
         
            -
            // for now
         
     | 
| 
       40 
     | 
    
         
            -
            class Scalar {
         
     | 
| 
       41 
     | 
    
         
            -
              torch::Scalar value;
         
     | 
| 
       42 
     | 
    
         
            -
              public:
         
     | 
| 
       43 
     | 
    
         
            -
                Scalar(Object o) {
         
     | 
| 
       44 
     | 
    
         
            -
                  // TODO cast based on Ruby type
         
     | 
| 
       45 
     | 
    
         
            -
                  if (o.rb_type() == T_FIXNUM) {
         
     | 
| 
       46 
     | 
    
         
            -
                    value = torch::Scalar(from_ruby<int64_t>(o));
         
     | 
| 
       47 
     | 
    
         
            -
                  } else {
         
     | 
| 
       48 
     | 
    
         
            -
                    value = torch::Scalar(from_ruby<float>(o));
         
     | 
| 
       49 
     | 
    
         
            -
                  }
         
     | 
| 
       50 
     | 
    
         
            -
                }
         
     | 
| 
       51 
     | 
    
         
            -
                operator torch::Scalar() {
         
     | 
| 
       52 
     | 
    
         
            -
                  return value;
         
     | 
| 
       53 
     | 
    
         
            -
                }
         
     | 
| 
       54 
     | 
    
         
            -
            };
         
     | 
| 
       55 
     | 
    
         
            -
             
     | 
| 
       56 
     | 
    
         
            -
            template<>
         
     | 
| 
       57 
     | 
    
         
            -
            inline
         
     | 
| 
       58 
     | 
    
         
            -
            Scalar from_ruby<Scalar>(Object x)
         
     | 
| 
       59 
     | 
    
         
            -
            {
         
     | 
| 
       60 
     | 
    
         
            -
              return Scalar(x);
         
     | 
| 
       61 
     | 
    
         
            -
            }
         
     | 
| 
       62 
     | 
    
         
            -
             
     | 
| 
       63 
40 
     | 
    
         
             
            class TensorList {
         
     | 
| 
       64 
41 
     | 
    
         
             
              std::vector<torch::Tensor> vec;
         
     | 
| 
       65 
42 
     | 
    
         
             
              public:
         
     | 
| 
         @@ -192,6 +169,17 @@ class OptionalTensor { 
     | 
|
| 
       192 
169 
     | 
    
         
             
                }
         
     | 
| 
       193 
170 
     | 
    
         
             
            };
         
     | 
| 
       194 
171 
     | 
    
         | 
| 
      
 172 
     | 
    
         
            +
            template<>
         
     | 
| 
      
 173 
     | 
    
         
            +
            inline
         
     | 
| 
      
 174 
     | 
    
         
            +
            Scalar from_ruby<Scalar>(Object x)
         
     | 
| 
      
 175 
     | 
    
         
            +
            {
         
     | 
| 
      
 176 
     | 
    
         
            +
              if (x.rb_type() == T_FIXNUM) {
         
     | 
| 
      
 177 
     | 
    
         
            +
                return torch::Scalar(from_ruby<int64_t>(x));
         
     | 
| 
      
 178 
     | 
    
         
            +
              } else {
         
     | 
| 
      
 179 
     | 
    
         
            +
                return torch::Scalar(from_ruby<double>(x));
         
     | 
| 
      
 180 
     | 
    
         
            +
              }
         
     | 
| 
      
 181 
     | 
    
         
            +
            }
         
     | 
| 
      
 182 
     | 
    
         
            +
             
     | 
| 
       195 
183 
     | 
    
         
             
            template<>
         
     | 
| 
       196 
184 
     | 
    
         
             
            inline
         
     | 
| 
       197 
185 
     | 
    
         
             
            OptionalTensor from_ruby<OptionalTensor>(Object x)
         
     | 
| 
         @@ -221,9 +209,43 @@ torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x) 
     | 
|
| 
       221 
209 
     | 
    
         
             
              }
         
     | 
| 
       222 
210 
     | 
    
         
             
            }
         
     | 
| 
       223 
211 
     | 
    
         | 
| 
      
 212 
     | 
    
         
            +
            template<>
         
     | 
| 
      
 213 
     | 
    
         
            +
            inline
         
     | 
| 
      
 214 
     | 
    
         
            +
            torch::optional<double> from_ruby<torch::optional<double>>(Object x)
         
     | 
| 
      
 215 
     | 
    
         
            +
            {
         
     | 
| 
      
 216 
     | 
    
         
            +
              if (x.is_nil()) {
         
     | 
| 
      
 217 
     | 
    
         
            +
                return torch::nullopt;
         
     | 
| 
      
 218 
     | 
    
         
            +
              } else {
         
     | 
| 
      
 219 
     | 
    
         
            +
                return torch::optional<double>{from_ruby<double>(x)};
         
     | 
| 
      
 220 
     | 
    
         
            +
              }
         
     | 
| 
      
 221 
     | 
    
         
            +
            }
         
     | 
| 
      
 222 
     | 
    
         
            +
             
     | 
| 
      
 223 
     | 
    
         
            +
            template<>
         
     | 
| 
      
 224 
     | 
    
         
            +
            inline
         
     | 
| 
      
 225 
     | 
    
         
            +
            torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
         
     | 
| 
      
 226 
     | 
    
         
            +
            {
         
     | 
| 
      
 227 
     | 
    
         
            +
              if (x.is_nil()) {
         
     | 
| 
      
 228 
     | 
    
         
            +
                return torch::nullopt;
         
     | 
| 
      
 229 
     | 
    
         
            +
              } else {
         
     | 
| 
      
 230 
     | 
    
         
            +
                return torch::optional<bool>{from_ruby<bool>(x)};
         
     | 
| 
      
 231 
     | 
    
         
            +
              }
         
     | 
| 
      
 232 
     | 
    
         
            +
            }
         
     | 
| 
      
 233 
     | 
    
         
            +
             
     | 
| 
      
 234 
     | 
    
         
            +
            template<>
         
     | 
| 
      
 235 
     | 
    
         
            +
            inline
         
     | 
| 
      
 236 
     | 
    
         
            +
            torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
         
     | 
| 
      
 237 
     | 
    
         
            +
            {
         
     | 
| 
      
 238 
     | 
    
         
            +
              if (x.is_nil()) {
         
     | 
| 
      
 239 
     | 
    
         
            +
                return torch::nullopt;
         
     | 
| 
      
 240 
     | 
    
         
            +
              } else {
         
     | 
| 
      
 241 
     | 
    
         
            +
                return torch::optional<Scalar>{from_ruby<Scalar>(x)};
         
     | 
| 
      
 242 
     | 
    
         
            +
              }
         
     | 
| 
      
 243 
     | 
    
         
            +
            }
         
     | 
| 
      
 244 
     | 
    
         
            +
             
     | 
| 
       224 
245 
     | 
    
         
             
            Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
         
     | 
| 
       225 
246 
     | 
    
         
             
            Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
         
     | 
| 
       226 
247 
     | 
    
         
             
            Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
         
     | 
| 
       227 
248 
     | 
    
         
             
            Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
         
     | 
| 
       228 
249 
     | 
    
         
             
            Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x);
         
     | 
| 
       229 
250 
     | 
    
         
             
            Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x);
         
     | 
| 
      
 251 
     | 
    
         
            +
            Object wrap(std::vector<torch::Tensor> x);
         
     | 
    
        data/lib/torch.rb
    CHANGED
    
    | 
         @@ -239,25 +239,22 @@ module Torch 
     | 
|
| 
       239 
239 
     | 
    
         
             
                cls
         
     | 
| 
       240 
240 
     | 
    
         
             
              end
         
     | 
| 
       241 
241 
     | 
    
         | 
| 
       242 
     | 
    
         
            -
               
     | 
| 
       243 
     | 
    
         
            -
             
     | 
| 
       244 
     | 
    
         
            -
             
     | 
| 
       245 
     | 
    
         
            -
             
     | 
| 
       246 
     | 
    
         
            -
             
     | 
| 
       247 
     | 
    
         
            -
             
     | 
| 
       248 
     | 
    
         
            -
             
     | 
| 
       249 
     | 
    
         
            -
             
     | 
| 
       250 
     | 
    
         
            -
             
     | 
| 
       251 
     | 
    
         
            -
             
     | 
| 
       252 
     | 
    
         
            -
               
     | 
| 
       253 
     | 
    
         
            -
             
     | 
| 
       254 
     | 
    
         
            -
               
     | 
| 
       255 
     | 
    
         
            -
             
     | 
| 
       256 
     | 
    
         
            -
             
     | 
| 
       257 
     | 
    
         
            -
               
     | 
| 
       258 
     | 
    
         
            -
              CUDA::IntTensor = _make_tensor_class(:int32, true)
         
     | 
| 
       259 
     | 
    
         
            -
              CUDA::LongTensor = _make_tensor_class(:int64, true)
         
     | 
| 
       260 
     | 
    
         
            -
              CUDA::BoolTensor = _make_tensor_class(:bool, true)
         
     | 
| 
      
 242 
     | 
    
         
            +
              DTYPE_TO_CLASS = {
         
     | 
| 
      
 243 
     | 
    
         
            +
                float32: "FloatTensor",
         
     | 
| 
      
 244 
     | 
    
         
            +
                float64: "DoubleTensor",
         
     | 
| 
      
 245 
     | 
    
         
            +
                float16: "HalfTensor",
         
     | 
| 
      
 246 
     | 
    
         
            +
                uint8: "ByteTensor",
         
     | 
| 
      
 247 
     | 
    
         
            +
                int8: "CharTensor",
         
     | 
| 
      
 248 
     | 
    
         
            +
                int16: "ShortTensor",
         
     | 
| 
      
 249 
     | 
    
         
            +
                int32: "IntTensor",
         
     | 
| 
      
 250 
     | 
    
         
            +
                int64: "LongTensor",
         
     | 
| 
      
 251 
     | 
    
         
            +
                bool: "BoolTensor"
         
     | 
| 
      
 252 
     | 
    
         
            +
              }
         
     | 
| 
      
 253 
     | 
    
         
            +
             
     | 
| 
      
 254 
     | 
    
         
            +
              DTYPE_TO_CLASS.each do |dtype, class_name|
         
     | 
| 
      
 255 
     | 
    
         
            +
                const_set(class_name, _make_tensor_class(dtype))
         
     | 
| 
      
 256 
     | 
    
         
            +
                CUDA.const_set(class_name, _make_tensor_class(dtype, true))
         
     | 
| 
      
 257 
     | 
    
         
            +
              end
         
     | 
| 
       261 
258 
     | 
    
         | 
| 
       262 
259 
     | 
    
         
             
              class << self
         
     | 
| 
       263 
260 
     | 
    
         
             
                # Torch.float, Torch.long, etc
         
     | 
    
        data/lib/torch/native/generator.rb
    CHANGED
    
    | 
         @@ -18,12 +18,12 @@ module Torch 
     | 
|
| 
       18 
18 
     | 
    
         
             
                      functions = functions()
         
     | 
| 
       19 
19 
     | 
    
         | 
| 
       20 
20 
     | 
    
         
             
                      # skip functions
         
     | 
| 
       21 
     | 
    
         
            -
                      skip_args = [" 
     | 
| 
      
 21 
     | 
    
         
            +
                      skip_args = ["Layout", "Storage", "ConstQuantizerPtr"]
         
     | 
| 
       22 
22 
     | 
    
         | 
| 
       23 
23 
     | 
    
         
             
                      # remove functions
         
     | 
| 
       24 
24 
     | 
    
         
             
                      functions.reject! do |f|
         
     | 
| 
       25 
25 
     | 
    
         
             
                        f.ruby_name.start_with?("_") ||
         
     | 
| 
       26 
     | 
    
         
            -
                        f.ruby_name. 
     | 
| 
      
 26 
     | 
    
         
            +
                        f.ruby_name.include?("_backward") ||
         
     | 
| 
       27 
27 
     | 
    
         
             
                        f.args.any? { |a| a[:type].include?("Dimname") }
         
     | 
| 
       28 
28 
     | 
    
         
             
                      end
         
     | 
| 
       29 
29 
     | 
    
         | 
| 
         @@ -31,7 +31,6 @@ module Torch 
     | 
|
| 
       31 
31 
     | 
    
         
             
                      todo_functions, functions =
         
     | 
| 
       32 
32 
     | 
    
         
             
                        functions.partition do |f|
         
     | 
| 
       33 
33 
     | 
    
         
             
                          f.args.any? do |a|
         
     | 
| 
       34 
     | 
    
         
            -
                            a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
         
     | 
| 
       35 
34 
     | 
    
         
             
                            skip_args.any? { |sa| a[:type].include?(sa) } ||
         
     | 
| 
       36 
35 
     | 
    
         
             
                            # call to 'range' is ambiguous
         
     | 
| 
       37 
36 
     | 
    
         
             
                            f.cpp_name == "_range" ||
         
     | 
| 
         @@ -104,6 +103,12 @@ void add_%{type}_functions(Module m) { 
     | 
|
| 
       104 
103 
     | 
    
         
             
                              "int64_t"
         
     | 
| 
       105 
104 
     | 
    
         
             
                            when "int?"
         
     | 
| 
       106 
105 
     | 
    
         
             
                              "torch::optional<int64_t>"
         
     | 
| 
      
 106 
     | 
    
         
            +
                            when "float?"
         
     | 
| 
      
 107 
     | 
    
         
            +
                              "torch::optional<double>"
         
     | 
| 
      
 108 
     | 
    
         
            +
                            when "bool?"
         
     | 
| 
      
 109 
     | 
    
         
            +
                              "torch::optional<bool>"
         
     | 
| 
      
 110 
     | 
    
         
            +
                            when "Scalar?"
         
     | 
| 
      
 111 
     | 
    
         
            +
                              "torch::optional<torch::Scalar>"
         
     | 
| 
       107 
112 
     | 
    
         
             
                            when "float"
         
     | 
| 
       108 
113 
     | 
    
         
             
                              "double"
         
     | 
| 
       109 
114 
     | 
    
         
             
                            when /\Aint\[/
         
     | 
| 
         @@ -130,8 +135,8 @@ void add_%{type}_functions(Module m) { 
     | 
|
| 
       130 
135 
     | 
    
         
             
                        prefix = def_method == :define_method ? "self." : "torch::"
         
     | 
| 
       131 
136 
     | 
    
         | 
| 
       132 
137 
     | 
    
         
             
                        body = "#{prefix}#{dispatch}(#{args.join(", ")})"
         
     | 
| 
       133 
     | 
    
         
            -
             
     | 
| 
       134 
     | 
    
         
            -
                        if func.ret_size > 1
         
     | 
| 
      
 138 
     | 
    
         
            +
             
     | 
| 
      
 139 
     | 
    
         
            +
                        if func.ret_size > 1 || func.ret_array?
         
     | 
| 
       135 
140 
     | 
    
         
             
                          body = "wrap(#{body})"
         
     | 
| 
       136 
141 
     | 
    
         
             
                        end
         
     | 
| 
       137 
142 
     | 
    
         | 
    
        data/lib/torch/native/parser.rb
    CHANGED
    
    | 
         @@ -85,6 +85,10 @@ module Torch 
     | 
|
| 
       85 
85 
     | 
    
         
             
                            end
         
     | 
| 
       86 
86 
     | 
    
         
             
                          when "int?"
         
     | 
| 
       87 
87 
     | 
    
         
             
                            v.is_a?(Integer) || v.nil?
         
     | 
| 
      
 88 
     | 
    
         
            +
                          when "float?"
         
     | 
| 
      
 89 
     | 
    
         
            +
                            v.is_a?(Numeric) || v.nil?
         
     | 
| 
      
 90 
     | 
    
         
            +
                          when "bool?"
         
     | 
| 
      
 91 
     | 
    
         
            +
                            v == true || v == false || v.nil?
         
     | 
| 
       88 
92 
     | 
    
         
             
                          when "float"
         
     | 
| 
       89 
93 
     | 
    
         
             
                            v.is_a?(Numeric)
         
     | 
| 
       90 
94 
     | 
    
         
             
                          when /int\[.*\]/
         
     | 
| 
         @@ -97,6 +101,10 @@ module Torch 
     | 
|
| 
       97 
101 
     | 
    
         
             
                            v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
         
     | 
| 
       98 
102 
     | 
    
         
             
                          when "Scalar"
         
     | 
| 
       99 
103 
     | 
    
         
             
                            v.is_a?(Numeric)
         
     | 
| 
      
 104 
     | 
    
         
            +
                          when "Scalar?"
         
     | 
| 
      
 105 
     | 
    
         
            +
                            v.is_a?(Numeric) || v.nil?
         
     | 
| 
      
 106 
     | 
    
         
            +
                          when "ScalarType"
         
     | 
| 
      
 107 
     | 
    
         
            +
                            false # not supported yet
         
     | 
| 
       100 
108 
     | 
    
         
             
                          when "ScalarType?"
         
     | 
| 
       101 
109 
     | 
    
         
             
                            v.nil?
         
     | 
| 
       102 
110 
     | 
    
         
             
                          when "bool"
         
     | 
    
        data/lib/torch/nn/functional.rb
    CHANGED
    
    | 
         @@ -178,8 +178,12 @@ module Torch 
     | 
|
| 
       178 
178 
     | 
    
         
             
                      Torch.hardshrink(input, lambd)
         
     | 
| 
       179 
179 
     | 
    
         
             
                    end
         
     | 
| 
       180 
180 
     | 
    
         | 
| 
       181 
     | 
    
         
            -
                    def leaky_relu(input, negative_slope = 0.01)
         
     | 
| 
       182 
     | 
    
         
            -
                       
     | 
| 
      
 181 
     | 
    
         
            +
                    def leaky_relu(input, negative_slope = 0.01, inplace: false)
         
     | 
| 
      
 182 
     | 
    
         
            +
                      if inplace
         
     | 
| 
      
 183 
     | 
    
         
            +
                        NN.leaky_relu!(input, negative_slope)
         
     | 
| 
      
 184 
     | 
    
         
            +
                      else
         
     | 
| 
      
 185 
     | 
    
         
            +
                        NN.leaky_relu(input, negative_slope)
         
     | 
| 
      
 186 
     | 
    
         
            +
                      end
         
     | 
| 
       183 
187 
     | 
    
         
             
                    end
         
     | 
| 
       184 
188 
     | 
    
         | 
| 
       185 
189 
     | 
    
         
             
                    def log_sigmoid(input)
         
     | 
    
        data/lib/torch/nn/leaky_relu.rb
    CHANGED
    
    | 
         @@ -1,14 +1,14 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            module Torch
         
     | 
| 
       2 
2 
     | 
    
         
             
              module NN
         
     | 
| 
       3 
3 
     | 
    
         
             
                class LeakyReLU < Module
         
     | 
| 
       4 
     | 
    
         
            -
                  def initialize(negative_slope: 1e-2)
     | 
| 
      
 4 
     | 
    
         
            +
                  def initialize(negative_slope: 1e-2, inplace: false)
         
     | 
| 
       5 
5 
     | 
    
         
             
                    super()
         
     | 
| 
       6 
6 
     | 
    
         
             
                    @negative_slope = negative_slope
         
     | 
| 
       7 
     | 
    
         
            -
                     
     | 
| 
      
 7 
     | 
    
         
            +
                    @inplace = inplace
         
     | 
| 
       8 
8 
     | 
    
         
             
                  end
         
     | 
| 
       9 
9 
     | 
    
         | 
| 
       10 
10 
     | 
    
         
             
                  def forward(input)
         
     | 
| 
       11 
     | 
    
         
            -
                    F.leaky_relu(input, @negative_slope)
     | 
| 
      
 11 
     | 
    
         
            +
                    F.leaky_relu(input, @negative_slope, inplace: @inplace)
         
     | 
| 
       12 
12 
     | 
    
         
             
                  end
         
     | 
| 
       13 
13 
     | 
    
         | 
| 
       14 
14 
     | 
    
         
             
                  def extra_inspect
         
     | 
    
        data/lib/torch/nn/module.rb
    CHANGED
    
    
    
        data/lib/torch/tensor.rb
    CHANGED
    
    | 
         @@ -103,11 +103,6 @@ module Torch 
     | 
|
| 
       103 
103 
     | 
    
         
             
                  Torch.empty(0, dtype: dtype)
         
     | 
| 
       104 
104 
     | 
    
         
             
                end
         
     | 
| 
       105 
105 
     | 
    
         | 
| 
       106 
     | 
    
         
            -
                def backward(gradient = nil, retain_graph: nil, create_graph: false)
         
     | 
| 
       107 
     | 
    
         
            -
                  retain_graph = create_graph if retain_graph.nil?
         
     | 
| 
       108 
     | 
    
         
            -
                  _backward(gradient, retain_graph, create_graph)
         
     | 
| 
       109 
     | 
    
         
            -
                end
         
     | 
| 
       110 
     | 
    
         
            -
             
     | 
| 
       111 
106 
     | 
    
         
             
                # TODO read directly from memory
         
     | 
| 
       112 
107 
     | 
    
         
             
                def numo
         
     | 
| 
       113 
108 
     | 
    
         
             
                  cls = Torch._dtype_to_numo[dtype]
         
     | 
| 
         @@ -235,7 +230,7 @@ module Torch 
     | 
|
| 
       235 
230 
     | 
    
         
             
                    when Integer
         
     | 
| 
       236 
231 
     | 
    
         
             
                      TensorIndex.integer(index)
         
     | 
| 
       237 
232 
     | 
    
         
             
                    when Range
         
     | 
| 
       238 
     | 
    
         
            -
                      finish = index.end
         
     | 
| 
      
 233 
     | 
    
         
            +
                      finish = index.end || -1
         
     | 
| 
       239 
234 
     | 
    
         
             
                      if finish == -1 && !index.exclude_end?
         
     | 
| 
       240 
235 
     | 
    
         
             
                        finish = nil
         
     | 
| 
       241 
236 
     | 
    
         
             
                      else
         
     | 
    
        data/lib/torch/version.rb
    CHANGED
    
    
    
        metadata
    CHANGED
    
    | 
         @@ -1,14 +1,14 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            --- !ruby/object:Gem::Specification
         
     | 
| 
       2 
2 
     | 
    
         
             
            name: torch-rb
         
     | 
| 
       3 
3 
     | 
    
         
             
            version: !ruby/object:Gem::Version
         
     | 
| 
       4 
     | 
    
         
            -
              version: 0.3.5
     | 
| 
      
 4 
     | 
    
         
            +
              version: 0.3.6
         
     | 
| 
       5 
5 
     | 
    
         
             
            platform: ruby
         
     | 
| 
       6 
6 
     | 
    
         
             
            authors:
         
     | 
| 
       7 
7 
     | 
    
         
             
            - Andrew Kane
         
     | 
| 
       8 
8 
     | 
    
         
             
            autorequire: 
         
     | 
| 
       9 
9 
     | 
    
         
             
            bindir: bin
         
     | 
| 
       10 
10 
     | 
    
         
             
            cert_chain: []
         
     | 
| 
       11 
     | 
    
         
            -
            date: 2020-09- 
     | 
| 
      
 11 
     | 
    
         
            +
            date: 2020-09-18 00:00:00.000000000 Z
         
     | 
| 
       12 
12 
     | 
    
         
             
            dependencies:
         
     | 
| 
       13 
13 
     | 
    
         
             
            - !ruby/object:Gem::Dependency
         
     | 
| 
       14 
14 
     | 
    
         
             
              name: rice
         
     |