torch-rb 0.3.7 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
 - data/CHANGELOG.md +5 -0
 - data/README.md +1 -1
 - data/codegen/function.rb +134 -0
 - data/codegen/generate_functions.rb +546 -0
 - data/{lib/torch/native → codegen}/native_functions.yaml +0 -0
 - data/ext/torch/ext.cpp +54 -75
 - data/ext/torch/extconf.rb +2 -2
 - data/ext/torch/nn_functions.h +6 -0
 - data/ext/torch/ruby_arg_parser.cpp +593 -0
 - data/ext/torch/ruby_arg_parser.h +373 -0
 - data/ext/torch/{templates.hpp → templates.h} +30 -51
 - data/ext/torch/tensor_functions.h +6 -0
 - data/ext/torch/torch_functions.h +6 -0
 - data/ext/torch/utils.h +42 -0
 - data/ext/torch/{templates.cpp → wrap_outputs.h} +16 -15
 - data/lib/torch.rb +0 -62
 - data/lib/torch/nn/functional.rb +30 -16
 - data/lib/torch/nn/init.rb +5 -19
 - data/lib/torch/optim/adadelta.rb +1 -1
 - data/lib/torch/optim/adam.rb +2 -2
 - data/lib/torch/optim/adamax.rb +1 -1
 - data/lib/torch/optim/adamw.rb +1 -1
 - data/lib/torch/optim/asgd.rb +1 -1
 - data/lib/torch/optim/sgd.rb +3 -3
 - data/lib/torch/tensor.rb +25 -105
 - data/lib/torch/version.rb +1 -1
 - metadata +27 -9
 - data/lib/torch/native/dispatcher.rb +0 -70
 - data/lib/torch/native/function.rb +0 -200
 - data/lib/torch/native/generator.rb +0 -178
 - data/lib/torch/native/parser.rb +0 -117
 
    
        data/ext/torch/utils.h
    ADDED
    
    | 
         @@ -0,0 +1,42 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            #pragma once
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            #include <rice/Symbol.hpp>
         
     | 
| 
      
 4 
     | 
    
         
            +
             
     | 
| 
      
 5 
     | 
    
         
            +
            // keep THP prefix for now to make it easier to compare code
         
     | 
| 
      
 6 
     | 
    
         
            +
             
     | 
| 
      
 7 
     | 
    
         
            +
            extern VALUE THPVariableClass;
         
     | 
| 
      
 8 
     | 
    
         
            +
             
     | 
| 
      
 9 
     | 
    
         
            +
            inline VALUE THPUtils_internSymbol(const std::string& str) {
         
     | 
| 
      
 10 
     | 
    
         
            +
              return Symbol(str);
         
     | 
| 
      
 11 
     | 
    
         
            +
            }
         
     | 
| 
      
 12 
     | 
    
         
            +
             
     | 
| 
      
 13 
     | 
    
         
            +
            inline std::string THPUtils_unpackSymbol(VALUE obj) {
         
     | 
| 
      
 14 
     | 
    
         
            +
              Check_Type(obj, T_SYMBOL);
         
     | 
| 
      
 15 
     | 
    
         
            +
              obj = rb_funcall(obj, rb_intern("to_s"), 0);
         
     | 
| 
      
 16 
     | 
    
         
            +
              return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
         
     | 
| 
      
 17 
     | 
    
         
            +
            }
         
     | 
| 
      
 18 
     | 
    
         
            +
             
     | 
| 
      
 19 
     | 
    
         
            +
            inline std::string THPUtils_unpackString(VALUE obj) {
         
     | 
| 
      
 20 
     | 
    
         
            +
              Check_Type(obj, T_STRING);
         
     | 
| 
      
 21 
     | 
    
         
            +
              return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
         
     | 
| 
      
 22 
     | 
    
         
            +
            }
         
     | 
| 
      
 23 
     | 
    
         
            +
             
     | 
| 
      
 24 
     | 
    
         
            +
            inline bool THPUtils_checkSymbol(VALUE obj) {
         
     | 
| 
      
 25 
     | 
    
         
            +
              return SYMBOL_P(obj);
         
     | 
| 
      
 26 
     | 
    
         
            +
            }
         
     | 
| 
      
 27 
     | 
    
         
            +
             
     | 
| 
      
 28 
     | 
    
         
            +
            inline bool THPUtils_checkIndex(VALUE obj) {
         
     | 
| 
      
 29 
     | 
    
         
            +
              return FIXNUM_P(obj);
         
     | 
| 
      
 30 
     | 
    
         
            +
            }
         
     | 
| 
      
 31 
     | 
    
         
            +
             
     | 
| 
      
 32 
     | 
    
         
            +
            inline bool THPUtils_checkScalar(VALUE obj) {
         
     | 
| 
      
 33 
     | 
    
         
            +
              return FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj) || RB_TYPE_P(obj, T_COMPLEX);
         
     | 
| 
      
 34 
     | 
    
         
            +
            }
         
     | 
| 
      
 35 
     | 
    
         
            +
             
     | 
| 
      
 36 
     | 
    
         
            +
            inline bool THPVariable_Check(VALUE obj) {
         
     | 
| 
      
 37 
     | 
    
         
            +
              return rb_obj_is_kind_of(obj, THPVariableClass);
         
     | 
| 
      
 38 
     | 
    
         
            +
            }
         
     | 
| 
      
 39 
     | 
    
         
            +
             
     | 
| 
      
 40 
     | 
    
         
            +
            inline bool THPVariable_CheckExact(VALUE obj) {
         
     | 
| 
      
 41 
     | 
    
         
            +
              return rb_obj_is_instance_of(obj, THPVariableClass);
         
     | 
| 
      
 42 
     | 
    
         
            +
            }
         
     | 
| 
         @@ -1,43 +1,44 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            #pragma once
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
       1 
3 
     | 
    
         
             
            #include <torch/torch.h>
         
     | 
| 
       2 
4 
     | 
    
         
             
            #include <rice/Object.hpp>
         
     | 
| 
       3 
     | 
    
         
            -
            #include "templates.hpp"
         
     | 
| 
       4 
5 
     | 
    
         | 
| 
       5 
     | 
    
         
            -
            Object wrap(bool x) {
         
     | 
| 
      
 6 
     | 
    
         
            +
            inline Object wrap(bool x) {
         
     | 
| 
       6 
7 
     | 
    
         
             
              return to_ruby<bool>(x);
         
     | 
| 
       7 
8 
     | 
    
         
             
            }
         
     | 
| 
       8 
9 
     | 
    
         | 
| 
       9 
     | 
    
         
            -
            Object wrap(int64_t x) {
         
     | 
| 
      
 10 
     | 
    
         
            +
            inline Object wrap(int64_t x) {
         
     | 
| 
       10 
11 
     | 
    
         
             
              return to_ruby<int64_t>(x);
         
     | 
| 
       11 
12 
     | 
    
         
             
            }
         
     | 
| 
       12 
13 
     | 
    
         | 
| 
       13 
     | 
    
         
            -
            Object wrap(double x) {
         
     | 
| 
      
 14 
     | 
    
         
            +
            inline Object wrap(double x) {
         
     | 
| 
       14 
15 
     | 
    
         
             
              return to_ruby<double>(x);
         
     | 
| 
       15 
16 
     | 
    
         
             
            }
         
     | 
| 
       16 
17 
     | 
    
         | 
| 
       17 
     | 
    
         
            -
            Object wrap(torch::Tensor x) {
         
     | 
| 
      
 18 
     | 
    
         
            +
            inline Object wrap(torch::Tensor x) {
         
     | 
| 
       18 
19 
     | 
    
         
             
              return to_ruby<torch::Tensor>(x);
         
     | 
| 
       19 
20 
     | 
    
         
             
            }
         
     | 
| 
       20 
21 
     | 
    
         | 
| 
       21 
     | 
    
         
            -
            Object wrap(torch::Scalar x) {
         
     | 
| 
      
 22 
     | 
    
         
            +
            inline Object wrap(torch::Scalar x) {
         
     | 
| 
       22 
23 
     | 
    
         
             
              return to_ruby<torch::Scalar>(x);
         
     | 
| 
       23 
24 
     | 
    
         
             
            }
         
     | 
| 
       24 
25 
     | 
    
         | 
| 
       25 
     | 
    
         
            -
            Object wrap(torch::ScalarType x) {
         
     | 
| 
      
 26 
     | 
    
         
            +
            inline Object wrap(torch::ScalarType x) {
         
     | 
| 
       26 
27 
     | 
    
         
             
              return to_ruby<torch::ScalarType>(x);
         
     | 
| 
       27 
28 
     | 
    
         
             
            }
         
     | 
| 
       28 
29 
     | 
    
         | 
| 
       29 
     | 
    
         
            -
            Object wrap(torch::QScheme x) {
         
     | 
| 
      
 30 
     | 
    
         
            +
            inline Object wrap(torch::QScheme x) {
         
     | 
| 
       30 
31 
     | 
    
         
             
              return to_ruby<torch::QScheme>(x);
         
     | 
| 
       31 
32 
     | 
    
         
             
            }
         
     | 
| 
       32 
33 
     | 
    
         | 
| 
       33 
     | 
    
         
            -
            Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
         
     | 
| 
      
 34 
     | 
    
         
            +
            inline Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
         
     | 
| 
       34 
35 
     | 
    
         
             
              Array a;
         
     | 
| 
       35 
36 
     | 
    
         
             
              a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
         
     | 
| 
       36 
37 
     | 
    
         
             
              a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
         
     | 
| 
       37 
38 
     | 
    
         
             
              return Object(a);
         
     | 
| 
       38 
39 
     | 
    
         
             
            }
         
     | 
| 
       39 
40 
     | 
    
         | 
| 
       40 
     | 
    
         
            -
            Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
         
     | 
| 
      
 41 
     | 
    
         
            +
            inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
         
     | 
| 
       41 
42 
     | 
    
         
             
              Array a;
         
     | 
| 
       42 
43 
     | 
    
         
             
              a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
         
     | 
| 
       43 
44 
     | 
    
         
             
              a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
         
     | 
| 
         @@ -45,7 +46,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) { 
     | 
|
| 
       45 
46 
     | 
    
         
             
              return Object(a);
         
     | 
| 
       46 
47 
     | 
    
         
             
            }
         
     | 
| 
       47 
48 
     | 
    
         | 
| 
       48 
     | 
    
         
            -
            Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
         
     | 
| 
      
 49 
     | 
    
         
            +
            inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
         
     | 
| 
       49 
50 
     | 
    
         
             
              Array a;
         
     | 
| 
       50 
51 
     | 
    
         
             
              a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
         
     | 
| 
       51 
52 
     | 
    
         
             
              a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
         
     | 
| 
         @@ -54,7 +55,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tenso 
     | 
|
| 
       54 
55 
     | 
    
         
             
              return Object(a);
         
     | 
| 
       55 
56 
     | 
    
         
             
            }
         
     | 
| 
       56 
57 
     | 
    
         | 
| 
       57 
     | 
    
         
            -
            Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
         
     | 
| 
      
 58 
     | 
    
         
            +
            inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
         
     | 
| 
       58 
59 
     | 
    
         
             
              Array a;
         
     | 
| 
       59 
60 
     | 
    
         
             
              a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
         
     | 
| 
       60 
61 
     | 
    
         
             
              a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
         
     | 
| 
         @@ -64,7 +65,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tenso 
     | 
|
| 
       64 
65 
     | 
    
         
             
              return Object(a);
         
     | 
| 
       65 
66 
     | 
    
         
             
            }
         
     | 
| 
       66 
67 
     | 
    
         | 
| 
       67 
     | 
    
         
            -
            Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
         
     | 
| 
      
 68 
     | 
    
         
            +
            inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
         
     | 
| 
       68 
69 
     | 
    
         
             
              Array a;
         
     | 
| 
       69 
70 
     | 
    
         
             
              a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
         
     | 
| 
       70 
71 
     | 
    
         
             
              a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
         
     | 
| 
         @@ -73,7 +74,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) 
     | 
|
| 
       73 
74 
     | 
    
         
             
              return Object(a);
         
     | 
| 
       74 
75 
     | 
    
         
             
            }
         
     | 
| 
       75 
76 
     | 
    
         | 
| 
       76 
     | 
    
         
            -
            Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
         
     | 
| 
      
 77 
     | 
    
         
            +
            inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
         
     | 
| 
       77 
78 
     | 
    
         
             
              Array a;
         
     | 
| 
       78 
79 
     | 
    
         
             
              a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
         
     | 
| 
       79 
80 
     | 
    
         
             
              a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
         
     | 
| 
         @@ -82,7 +83,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) { 
     | 
|
| 
       82 
83 
     | 
    
         
             
              return Object(a);
         
     | 
| 
       83 
84 
     | 
    
         
             
            }
         
     | 
| 
       84 
85 
     | 
    
         | 
| 
       85 
     | 
    
         
            -
            Object wrap( 
     | 
| 
      
 86 
     | 
    
         
            +
            inline Object wrap(torch::TensorList x) {
         
     | 
| 
       86 
87 
     | 
    
         
             
              Array a;
         
     | 
| 
       87 
88 
     | 
    
         
             
              for (auto& t : x) {
         
     | 
| 
       88 
89 
     | 
    
         
             
                a.push(to_ruby<torch::Tensor>(t));
         
     | 
    
        data/lib/torch.rb
    CHANGED
    
    | 
         @@ -7,11 +7,6 @@ require "net/http" 
     | 
|
| 
       7 
7 
     | 
    
         
             
            require "set"
         
     | 
| 
       8 
8 
     | 
    
         
             
            require "tmpdir"
         
     | 
| 
       9 
9 
     | 
    
         | 
| 
       10 
     | 
    
         
            -
            # native functions
         
     | 
| 
       11 
     | 
    
         
            -
            require "torch/native/generator"
         
     | 
| 
       12 
     | 
    
         
            -
            require "torch/native/parser"
         
     | 
| 
       13 
     | 
    
         
            -
            require "torch/native/dispatcher"
         
     | 
| 
       14 
     | 
    
         
            -
             
     | 
| 
       15 
10 
     | 
    
         
             
            # modules
         
     | 
| 
       16 
11 
     | 
    
         
             
            require "torch/inspector"
         
     | 
| 
       17 
12 
     | 
    
         
             
            require "torch/tensor"
         
     | 
| 
         @@ -374,63 +369,6 @@ module Torch 
     | 
|
| 
       374 
369 
     | 
    
         | 
| 
       375 
370 
     | 
    
         
             
                # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
         
     | 
| 
       376 
371 
     | 
    
         | 
| 
       377 
     | 
    
         
            -
                def arange(start, finish = nil, step = 1, **options)
         
     | 
| 
       378 
     | 
    
         
            -
                  # ruby doesn't support start = 0, finish, step = 1, ...
         
     | 
| 
       379 
     | 
    
         
            -
                  if finish.nil?
         
     | 
| 
       380 
     | 
    
         
            -
                    finish = start
         
     | 
| 
       381 
     | 
    
         
            -
                    start = 0
         
     | 
| 
       382 
     | 
    
         
            -
                  end
         
     | 
| 
       383 
     | 
    
         
            -
                  _arange(start, finish, step, tensor_options(**options))
         
     | 
| 
       384 
     | 
    
         
            -
                end
         
     | 
| 
       385 
     | 
    
         
            -
             
     | 
| 
       386 
     | 
    
         
            -
                def empty(*size, **options)
         
     | 
| 
       387 
     | 
    
         
            -
                  _empty(tensor_size(size), tensor_options(**options))
         
     | 
| 
       388 
     | 
    
         
            -
                end
         
     | 
| 
       389 
     | 
    
         
            -
             
     | 
| 
       390 
     | 
    
         
            -
                def eye(n, m = nil, **options)
         
     | 
| 
       391 
     | 
    
         
            -
                  _eye(n, m || n, tensor_options(**options))
         
     | 
| 
       392 
     | 
    
         
            -
                end
         
     | 
| 
       393 
     | 
    
         
            -
             
     | 
| 
       394 
     | 
    
         
            -
                def full(size, fill_value, **options)
         
     | 
| 
       395 
     | 
    
         
            -
                  _full(size, fill_value, tensor_options(**options))
         
     | 
| 
       396 
     | 
    
         
            -
                end
         
     | 
| 
       397 
     | 
    
         
            -
             
     | 
| 
       398 
     | 
    
         
            -
                def linspace(start, finish, steps = 100, **options)
         
     | 
| 
       399 
     | 
    
         
            -
                  _linspace(start, finish, steps, tensor_options(**options))
         
     | 
| 
       400 
     | 
    
         
            -
                end
         
     | 
| 
       401 
     | 
    
         
            -
             
     | 
| 
       402 
     | 
    
         
            -
                def logspace(start, finish, steps = 100, base = 10.0, **options)
         
     | 
| 
       403 
     | 
    
         
            -
                  _logspace(start, finish, steps, base, tensor_options(**options))
         
     | 
| 
       404 
     | 
    
         
            -
                end
         
     | 
| 
       405 
     | 
    
         
            -
             
     | 
| 
       406 
     | 
    
         
            -
                def ones(*size, **options)
         
     | 
| 
       407 
     | 
    
         
            -
                  _ones(tensor_size(size), tensor_options(**options))
         
     | 
| 
       408 
     | 
    
         
            -
                end
         
     | 
| 
       409 
     | 
    
         
            -
             
     | 
| 
       410 
     | 
    
         
            -
                def rand(*size, **options)
         
     | 
| 
       411 
     | 
    
         
            -
                  _rand(tensor_size(size), tensor_options(**options))
         
     | 
| 
       412 
     | 
    
         
            -
                end
         
     | 
| 
       413 
     | 
    
         
            -
             
     | 
| 
       414 
     | 
    
         
            -
                def randint(low = 0, high, size, **options)
         
     | 
| 
       415 
     | 
    
         
            -
                  _randint(low, high, size, tensor_options(**options))
         
     | 
| 
       416 
     | 
    
         
            -
                end
         
     | 
| 
       417 
     | 
    
         
            -
             
     | 
| 
       418 
     | 
    
         
            -
                 def randn(*size, **options)
         
     | 
| 
       419 
     | 
    
         
            -
                  _randn(tensor_size(size), tensor_options(**options))
         
     | 
| 
       420 
     | 
    
         
            -
                end
         
     | 
| 
       421 
     | 
    
         
            -
             
     | 
| 
       422 
     | 
    
         
            -
                def randperm(n, **options)
         
     | 
| 
       423 
     | 
    
         
            -
                  # dtype hack in Python
         
     | 
| 
       424 
     | 
    
         
            -
                  # https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
         
     | 
| 
       425 
     | 
    
         
            -
                  options[:dtype] ||= :int64
         
     | 
| 
       426 
     | 
    
         
            -
             
     | 
| 
       427 
     | 
    
         
            -
                  _randperm(n, tensor_options(**options))
         
     | 
| 
       428 
     | 
    
         
            -
                end
         
     | 
| 
       429 
     | 
    
         
            -
             
     | 
| 
       430 
     | 
    
         
            -
                def zeros(*size, **options)
         
     | 
| 
       431 
     | 
    
         
            -
                  _zeros(tensor_size(size), tensor_options(**options))
         
     | 
| 
       432 
     | 
    
         
            -
                end
         
     | 
| 
       433 
     | 
    
         
            -
             
     | 
| 
       434 
372 
     | 
    
         
             
                def tensor(data, **options)
         
     | 
| 
       435 
373 
     | 
    
         
             
                  if options[:dtype].nil? && defined?(Numo::NArray) && data.is_a?(Numo::NArray)
         
     | 
| 
       436 
374 
     | 
    
         
             
                    numo_to_dtype = _dtype_to_numo.map(&:reverse).to_h
         
     | 
    
        data/lib/torch/nn/functional.rb
    CHANGED
    
    | 
         @@ -394,15 +394,15 @@ module Torch 
     | 
|
| 
       394 
394 
     | 
    
         
             
                    # loss functions
         
     | 
| 
       395 
395 
     | 
    
         | 
| 
       396 
396 
     | 
    
         
             
                    def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
         
     | 
| 
       397 
     | 
    
         
            -
                      NN.binary_cross_entropy(input, target, weight, reduction)
         
     | 
| 
      
 397 
     | 
    
         
            +
                      NN.binary_cross_entropy(input, target, weight, to_reduction(reduction))
         
     | 
| 
       398 
398 
     | 
    
         
             
                    end
         
     | 
| 
       399 
399 
     | 
    
         | 
| 
       400 
400 
     | 
    
         
             
                    def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
         
     | 
| 
       401 
     | 
    
         
            -
                      Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
         
     | 
| 
      
 401 
     | 
    
         
            +
                      Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, to_reduction(reduction))
         
     | 
| 
       402 
402 
     | 
    
         
             
                    end
         
     | 
| 
       403 
403 
     | 
    
         | 
| 
       404 
404 
     | 
    
         
             
                    def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
         
     | 
| 
       405 
     | 
    
         
            -
                      Torch.cosine_embedding_loss(input1, input2, target, margin, reduction)
         
     | 
| 
      
 405 
     | 
    
         
            +
                      Torch.cosine_embedding_loss(input1, input2, target, margin, to_reduction(reduction))
         
     | 
| 
       406 
406 
     | 
    
         
             
                    end
         
     | 
| 
       407 
407 
     | 
    
         | 
| 
       408 
408 
     | 
    
         
             
                    def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
         
     | 
| 
         @@ -411,34 +411,34 @@ module Torch 
     | 
|
| 
       411 
411 
     | 
    
         | 
| 
       412 
412 
     | 
    
         
             
                    def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
         
     | 
| 
       413 
413 
     | 
    
         
             
                      # call to_a on input_lengths and target_lengths for C++
         
     | 
| 
       414 
     | 
    
         
            -
                      Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
         
     | 
| 
      
 414 
     | 
    
         
            +
                      Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, to_reduction(reduction), zero_infinity)
         
     | 
| 
       415 
415 
     | 
    
         
             
                    end
         
     | 
| 
       416 
416 
     | 
    
         | 
| 
       417 
417 
     | 
    
         
             
                    def hinge_embedding_loss(input, target, margin: 1.0, reduction: "mean")
         
     | 
| 
       418 
     | 
    
         
            -
                      Torch.hinge_embedding_loss(input, target, margin, reduction)
         
     | 
| 
      
 418 
     | 
    
         
            +
                      Torch.hinge_embedding_loss(input, target, margin, to_reduction(reduction))
         
     | 
| 
       419 
419 
     | 
    
         
             
                    end
         
     | 
| 
       420 
420 
     | 
    
         | 
| 
       421 
421 
     | 
    
         
             
                    def kl_div(input, target, reduction: "mean")
         
     | 
| 
       422 
     | 
    
         
            -
                      Torch.kl_div(input, target, reduction)
         
     | 
| 
      
 422 
     | 
    
         
            +
                      Torch.kl_div(input, target, to_reduction(reduction))
         
     | 
| 
       423 
423 
     | 
    
         
             
                    end
         
     | 
| 
       424 
424 
     | 
    
         | 
| 
       425 
425 
     | 
    
         
             
                    def l1_loss(input, target, reduction: "mean")
         
     | 
| 
       426 
     | 
    
         
            -
                      NN.l1_loss(input, target, reduction)
         
     | 
| 
      
 426 
     | 
    
         
            +
                      NN.l1_loss(input, target, to_reduction(reduction))
         
     | 
| 
       427 
427 
     | 
    
         
             
                    end
         
     | 
| 
       428 
428 
     | 
    
         | 
| 
       429 
429 
     | 
    
         
             
                    def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
         
     | 
| 
       430 
     | 
    
         
            -
                      Torch.margin_ranking_loss(input1, input2, target, margin, reduction)
         
     | 
| 
      
 430 
     | 
    
         
            +
                      Torch.margin_ranking_loss(input1, input2, target, margin, to_reduction(reduction))
         
     | 
| 
       431 
431 
     | 
    
         
             
                    end
         
     | 
| 
       432 
432 
     | 
    
         | 
| 
       433 
433 
     | 
    
         
             
                    def mse_loss(input, target, reduction: "mean")
         
     | 
| 
       434 
434 
     | 
    
         
             
                      if target.size != input.size
         
     | 
| 
       435 
435 
     | 
    
         
             
                        warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
         
     | 
| 
       436 
436 
     | 
    
         
             
                      end
         
     | 
| 
       437 
     | 
    
         
            -
                      NN.mse_loss(input, target, reduction)
         
     | 
| 
      
 437 
     | 
    
         
            +
                      NN.mse_loss(input, target, to_reduction(reduction))
         
     | 
| 
       438 
438 
     | 
    
         
             
                    end
         
     | 
| 
       439 
439 
     | 
    
         | 
| 
       440 
440 
     | 
    
         
             
                    def multilabel_margin_loss(input, target, reduction: "mean")
         
     | 
| 
       441 
     | 
    
         
            -
                      NN.multilabel_margin_loss(input, target, reduction)
         
     | 
| 
      
 441 
     | 
    
         
            +
                      NN.multilabel_margin_loss(input, target, to_reduction(reduction))
         
     | 
| 
       442 
442 
     | 
    
         
             
                    end
         
     | 
| 
       443 
443 
     | 
    
         | 
| 
       444 
444 
     | 
    
         
             
                    def multilabel_soft_margin_loss(input, target, weight: nil)
         
     | 
| 
         @@ -446,27 +446,27 @@ module Torch 
     | 
|
| 
       446 
446 
     | 
    
         
             
                    end
         
     | 
| 
       447 
447 
     | 
    
         | 
| 
       448 
448 
     | 
    
         
             
                    def multi_margin_loss(input, target, p: 1, margin: 1.0, weight: nil, reduction: "mean")
         
     | 
| 
       449 
     | 
    
         
            -
                      NN.multi_margin_loss(input, target, p, margin, weight, reduction)
         
     | 
| 
      
 449 
     | 
    
         
            +
                      NN.multi_margin_loss(input, target, p, margin, weight, to_reduction(reduction))
         
     | 
| 
       450 
450 
     | 
    
         
             
                    end
         
     | 
| 
       451 
451 
     | 
    
         | 
| 
       452 
452 
     | 
    
         
             
                    def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
         
     | 
| 
       453 
     | 
    
         
            -
                      NN.nll_loss(input, target, weight, reduction, ignore_index)
         
     | 
| 
      
 453 
     | 
    
         
            +
                      NN.nll_loss(input, target, weight, to_reduction(reduction), ignore_index)
         
     | 
| 
       454 
454 
     | 
    
         
             
                    end
         
     | 
| 
       455 
455 
     | 
    
         | 
| 
       456 
456 
     | 
    
         
             
                    def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
         
     | 
| 
       457 
     | 
    
         
            -
                      Torch.poisson_nll_loss(input, target, log_input, full, eps, reduction)
         
     | 
| 
      
 457 
     | 
    
         
            +
                      Torch.poisson_nll_loss(input, target, log_input, full, eps, to_reduction(reduction))
         
     | 
| 
       458 
458 
     | 
    
         
             
                    end
         
     | 
| 
       459 
459 
     | 
    
         | 
| 
       460 
460 
     | 
    
         
             
                    def soft_margin_loss(input, target, reduction: "mean")
         
     | 
| 
       461 
     | 
    
         
            -
                      NN.soft_margin_loss(input, target, reduction)
         
     | 
| 
      
 461 
     | 
    
         
            +
                      NN.soft_margin_loss(input, target, to_reduction(reduction))
         
     | 
| 
       462 
462 
     | 
    
         
             
                    end
         
     | 
| 
       463 
463 
     | 
    
         | 
| 
       464 
464 
     | 
    
         
             
                    def smooth_l1_loss(input, target, reduction: "mean")
         
     | 
| 
       465 
     | 
    
         
            -
                      NN.smooth_l1_loss(input, target, reduction)
         
     | 
| 
      
 465 
     | 
    
         
            +
                      NN.smooth_l1_loss(input, target, to_reduction(reduction))
         
     | 
| 
       466 
466 
     | 
    
         
             
                    end
         
     | 
| 
       467 
467 
     | 
    
         | 
| 
       468 
468 
     | 
    
         
             
                    def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean")
         
     | 
| 
       469 
     | 
    
         
            -
                      Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
         
     | 
| 
      
 469 
     | 
    
         
            +
                      Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, to_reduction(reduction))
         
     | 
| 
       470 
470 
     | 
    
         
             
                    end
         
     | 
| 
       471 
471 
     | 
    
         | 
| 
       472 
472 
     | 
    
         
             
                    # vision
         
     | 
| 
         @@ -542,6 +542,20 @@ module Torch 
     | 
|
| 
       542 
542 
     | 
    
         | 
| 
       543 
543 
     | 
    
         
             
                    private
         
     | 
| 
       544 
544 
     | 
    
         | 
| 
      
 545 
     | 
    
         
            +
                    # see _reduction.py
         
     | 
| 
      
 546 
     | 
    
         
            +
                    def to_reduction(v)
         
     | 
| 
      
 547 
     | 
    
         
            +
                      case v.to_s
         
     | 
| 
      
 548 
     | 
    
         
            +
                      when "none"
         
     | 
| 
      
 549 
     | 
    
         
            +
                        0
         
     | 
| 
      
 550 
     | 
    
         
            +
                      when "mean"
         
     | 
| 
      
 551 
     | 
    
         
            +
                        1
         
     | 
| 
      
 552 
     | 
    
         
            +
                      when "sum"
         
     | 
| 
      
 553 
     | 
    
         
            +
                        2
         
     | 
| 
      
 554 
     | 
    
         
            +
                      else
         
     | 
| 
      
 555 
     | 
    
         
            +
                        raise ArgumentError, "#{v} is not a valid value for reduction"
         
     | 
| 
      
 556 
     | 
    
         
            +
                      end
         
     | 
| 
      
 557 
     | 
    
         
            +
                    end
         
     | 
| 
      
 558 
     | 
    
         
            +
             
     | 
| 
       545 
559 
     | 
    
         
             
                    def softmax_dim(ndim)
         
     | 
| 
       546 
560 
     | 
    
         
             
                      ndim == 0 || ndim == 1 || ndim == 3 ? 0 : 1
         
     | 
| 
       547 
561 
     | 
    
         
             
                    end
         
     | 
    
        data/lib/torch/nn/init.rb
    CHANGED
    
    | 
         @@ -14,25 +14,11 @@ module Torch 
     | 
|
| 
       14 
14 
     | 
    
         
             
                      _normal!(tensor, mean, std)
         
     | 
| 
       15 
15 
     | 
    
         
             
                    end
         
     | 
| 
       16 
16 
     | 
    
         | 
| 
       17 
     | 
    
         
            -
                     
     | 
| 
       18 
     | 
    
         
            -
             
     | 
| 
       19 
     | 
    
         
            -
                     
     | 
| 
       20 
     | 
    
         
            -
             
     | 
| 
       21 
     | 
    
         
            -
                     
     | 
| 
       22 
     | 
    
         
            -
                      _ones!(tensor)
         
     | 
| 
       23 
     | 
    
         
            -
                    end
         
     | 
| 
       24 
     | 
    
         
            -
             
     | 
| 
       25 
     | 
    
         
            -
                    def zeros!(tensor)
         
     | 
| 
       26 
     | 
    
         
            -
                      _zeros!(tensor)
         
     | 
| 
       27 
     | 
    
         
            -
                    end
         
     | 
| 
       28 
     | 
    
         
            -
             
     | 
| 
       29 
     | 
    
         
            -
                    def eye!(tensor)
         
     | 
| 
       30 
     | 
    
         
            -
                      _eye!(tensor)
         
     | 
| 
       31 
     | 
    
         
            -
                    end
         
     | 
| 
       32 
     | 
    
         
            -
             
     | 
| 
       33 
     | 
    
         
            -
                    def dirac!(tensor)
         
     | 
| 
       34 
     | 
    
         
            -
                      _dirac!(tensor)
         
     | 
| 
       35 
     | 
    
         
            -
                    end
         
     | 
| 
      
 17 
     | 
    
         
            +
                    alias_method :constant!, :_constant!
         
     | 
| 
      
 18 
     | 
    
         
            +
                    alias_method :ones!, :_ones!
         
     | 
| 
      
 19 
     | 
    
         
            +
                    alias_method :zeros!, :_zeros!
         
     | 
| 
      
 20 
     | 
    
         
            +
                    alias_method :eye!, :_eye!
         
     | 
| 
      
 21 
     | 
    
         
            +
                    alias_method :dirac!, :_dirac!
         
     | 
| 
       36 
22 
     | 
    
         | 
| 
       37 
23 
     | 
    
         
             
                    def xavier_uniform!(tensor, gain: 1.0)
         
     | 
| 
       38 
24 
     | 
    
         
             
                      _xavier_uniform!(tensor, gain)
         
     | 
    
        data/lib/torch/optim/adadelta.rb
    CHANGED
    
    | 
         @@ -45,7 +45,7 @@ module Torch 
     | 
|
| 
       45 
45 
     | 
    
         
             
                        square_avg.mul!(rho).addcmul!(1 - rho, grad, grad)
         
     | 
| 
       46 
46 
     | 
    
         
             
                        std = square_avg.add(eps).sqrt!
         
     | 
| 
       47 
47 
     | 
    
         
             
                        delta = acc_delta.add(eps).sqrt!.div!(std).mul!(grad)
         
     | 
| 
       48 
     | 
    
         
            -
                        p.data.add!(-group[:lr] 
     | 
| 
      
 48 
     | 
    
         
            +
                        p.data.add!(delta, alpha: -group[:lr])
         
     | 
| 
       49 
49 
     | 
    
         
             
                        acc_delta.mul!(rho).addcmul!(1 - rho, delta, delta)
         
     | 
| 
       50 
50 
     | 
    
         
             
                      end
         
     | 
| 
       51 
51 
     | 
    
         
             
                    end
         
     | 
    
        data/lib/torch/optim/adam.rb
    CHANGED
    
    | 
         @@ -53,11 +53,11 @@ module Torch 
     | 
|
| 
       53 
53 
     | 
    
         
             
                        bias_correction2 = 1 - beta2 ** state[:step]
         
     | 
| 
       54 
54 
     | 
    
         | 
| 
       55 
55 
     | 
    
         
             
                        if group[:weight_decay] != 0
         
     | 
| 
       56 
     | 
    
         
            -
                          grad.add!(group[:weight_decay] 
     | 
| 
      
 56 
     | 
    
         
            +
                          grad.add!(p.data, alpha: group[:weight_decay])
         
     | 
| 
       57 
57 
     | 
    
         
             
                        end
         
     | 
| 
       58 
58 
     | 
    
         | 
| 
       59 
59 
     | 
    
         
             
                        # Decay the first and second moment running average coefficient
         
     | 
| 
       60 
     | 
    
         
            -
                        exp_avg.mul!(beta1).add!(1 - beta1 
     | 
| 
      
 60 
     | 
    
         
            +
                        exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
         
     | 
| 
       61 
61 
     | 
    
         
             
                        exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
         
     | 
| 
       62 
62 
     | 
    
         
             
                        if amsgrad
         
     | 
| 
       63 
63 
     | 
    
         
             
                          # Maintains the maximum of all 2nd moment running avg. till now
         
     | 
    
        data/lib/torch/optim/adamax.rb
    CHANGED
    
    | 
         @@ -46,7 +46,7 @@ module Torch 
     | 
|
| 
       46 
46 
     | 
    
         
             
                        end
         
     | 
| 
       47 
47 
     | 
    
         | 
| 
       48 
48 
     | 
    
         
             
                        # Update biased first moment estimate.
         
     | 
| 
       49 
     | 
    
         
            -
                        exp_avg.mul!(beta1).add!(1 - beta1 
     | 
| 
      
 49 
     | 
    
         
            +
                        exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
         
     | 
| 
       50 
50 
     | 
    
         
             
                        # Update the exponentially weighted infinity norm.
         
     | 
| 
       51 
51 
     | 
    
         
             
                        norm_buf = Torch.cat([
         
     | 
| 
       52 
52 
     | 
    
         
             
                            exp_inf.mul!(beta2).unsqueeze(0),
         
     | 
    
        data/lib/torch/optim/adamw.rb
    CHANGED
    
    | 
         @@ -58,7 +58,7 @@ module Torch 
     | 
|
| 
       58 
58 
     | 
    
         
             
                        bias_correction2 = 1 - beta2 ** state[:step]
         
     | 
| 
       59 
59 
     | 
    
         | 
| 
       60 
60 
     | 
    
         
             
                        # Decay the first and second moment running average coefficient
         
     | 
| 
       61 
     | 
    
         
            -
                        exp_avg.mul!(beta1).add!(1 - beta1 
     | 
| 
      
 61 
     | 
    
         
            +
                        exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
         
     | 
| 
       62 
62 
     | 
    
         
             
                        exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
         
     | 
| 
       63 
63 
     | 
    
         
             
                        if amsgrad
         
     | 
| 
       64 
64 
     | 
    
         
             
                          # Maintains the maximum of all 2nd moment running avg. till now
         
     | 
    
        data/lib/torch/optim/asgd.rb
    CHANGED
    
    
    
        data/lib/torch/optim/sgd.rb
    CHANGED
    
    | 
         @@ -32,7 +32,7 @@ module Torch 
     | 
|
| 
       32 
32 
     | 
    
         
             
                        next unless p.grad
         
     | 
| 
       33 
33 
     | 
    
         
             
                        d_p = p.grad.data
         
     | 
| 
       34 
34 
     | 
    
         
             
                        if weight_decay != 0
         
     | 
| 
       35 
     | 
    
         
            -
                          d_p.add!( 
     | 
| 
      
 35 
     | 
    
         
            +
                          d_p.add!(p.data, alpha: weight_decay)
         
     | 
| 
       36 
36 
     | 
    
         
             
                        end
         
     | 
| 
       37 
37 
     | 
    
         
             
                        if momentum != 0
         
     | 
| 
       38 
38 
     | 
    
         
             
                          param_state = @state[p]
         
     | 
| 
         @@ -40,7 +40,7 @@ module Torch 
     | 
|
| 
       40 
40 
     | 
    
         
             
                            buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
         
     | 
| 
       41 
41 
     | 
    
         
             
                          else
         
     | 
| 
       42 
42 
     | 
    
         
             
                            buf = param_state[:momentum_buffer]
         
     | 
| 
       43 
     | 
    
         
            -
                            buf.mul!(momentum).add!(1 - dampening 
     | 
| 
      
 43 
     | 
    
         
            +
                            buf.mul!(momentum).add!(d_p, alpha: 1 - dampening)
         
     | 
| 
       44 
44 
     | 
    
         
             
                          end
         
     | 
| 
       45 
45 
     | 
    
         
             
                          if nesterov
         
     | 
| 
       46 
46 
     | 
    
         
             
                            d_p = d_p.add(momentum, buf)
         
     | 
| 
         @@ -49,7 +49,7 @@ module Torch 
     | 
|
| 
       49 
49 
     | 
    
         
             
                          end
         
     | 
| 
       50 
50 
     | 
    
         
             
                        end
         
     | 
| 
       51 
51 
     | 
    
         | 
| 
       52 
     | 
    
         
            -
                        p.data.add!(-group[:lr] 
     | 
| 
      
 52 
     | 
    
         
            +
                        p.data.add!(d_p, alpha: -group[:lr])
         
     | 
| 
       53 
53 
     | 
    
         
             
                      end
         
     | 
| 
       54 
54 
     | 
    
         
             
                    end
         
     | 
| 
       55 
55 
     | 
    
         |