torch-rb 0.4.2 → 0.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
 - data/CHANGELOG.md +26 -0
 - data/README.md +13 -3
 - data/codegen/generate_functions.rb +20 -13
 - data/codegen/native_functions.yaml +4129 -1521
 - data/ext/torch/cuda.cpp +14 -0
 - data/ext/torch/device.cpp +21 -0
 - data/ext/torch/ext.cpp +17 -623
 - data/ext/torch/extconf.rb +0 -1
 - data/ext/torch/ivalue.cpp +134 -0
 - data/ext/torch/nn.cpp +114 -0
 - data/ext/torch/nn_functions.h +1 -1
 - data/ext/torch/random.cpp +22 -0
 - data/ext/torch/ruby_arg_parser.cpp +1 -1
 - data/ext/torch/ruby_arg_parser.h +47 -7
 - data/ext/torch/templates.h +3 -2
 - data/ext/torch/tensor.cpp +307 -0
 - data/ext/torch/tensor_functions.h +1 -1
 - data/ext/torch/torch.cpp +86 -0
 - data/ext/torch/torch_functions.h +1 -1
 - data/ext/torch/utils.h +8 -1
 - data/ext/torch/wrap_outputs.h +7 -0
 - data/lib/torch.rb +14 -17
 - data/lib/torch/nn/linear.rb +2 -0
 - data/lib/torch/nn/module.rb +107 -21
 - data/lib/torch/nn/parameter.rb +1 -1
 - data/lib/torch/optim/adadelta.rb +2 -2
 - data/lib/torch/optim/adagrad.rb +2 -2
 - data/lib/torch/optim/adam.rb +2 -2
 - data/lib/torch/optim/adamax.rb +1 -1
 - data/lib/torch/optim/adamw.rb +2 -2
 - data/lib/torch/optim/rmsprop.rb +3 -3
 - data/lib/torch/optim/rprop.rb +1 -1
 - data/lib/torch/tensor.rb +9 -0
 - data/lib/torch/utils/data/data_loader.rb +1 -1
 - data/lib/torch/version.rb +1 -1
 - metadata +12 -89
 
    
        data/ext/torch/extconf.rb
    CHANGED
    
    | 
         @@ -11,7 +11,6 @@ apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i 
     | 
|
| 
       11 
11 
     | 
    
         | 
| 
       12 
12 
     | 
    
         
             
            # check omp first
         
     | 
| 
       13 
13 
     | 
    
         
             
            if have_library("omp") || have_library("gomp")
         
     | 
| 
       14 
     | 
    
         
            -
              $CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
         
     | 
| 
       15 
14 
     | 
    
         
             
              $CXXFLAGS += " -Xclang" if apple_clang
         
     | 
| 
       16 
15 
     | 
    
         
             
              $CXXFLAGS += " -fopenmp"
         
     | 
| 
       17 
16 
     | 
    
         
             
            end
         
     | 
| 
         @@ -0,0 +1,134 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            #include <torch/torch.h>
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            #include <rice/Array.hpp>
         
     | 
| 
      
 4 
     | 
    
         
            +
            #include <rice/Constructor.hpp>
         
     | 
| 
      
 5 
     | 
    
         
            +
            #include <rice/Hash.hpp>
         
     | 
| 
      
 6 
     | 
    
         
            +
            #include <rice/Module.hpp>
         
     | 
| 
      
 7 
     | 
    
         
            +
            #include <rice/String.hpp>
         
     | 
| 
      
 8 
     | 
    
         
            +
             
     | 
| 
      
 9 
     | 
    
         
            +
            #include "utils.h"
         
     | 
| 
      
 10 
     | 
    
         
            +
             
     | 
| 
      
 11 
     | 
    
         
            +
            // Registers the IValue class under `m`, a Ruby wrapper around torch::IValue
            // (https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html).
            // It exposes is*() predicates, to_* accessors and from_* factory methods.
            void init_ivalue(Rice::Module& m) {
              auto rb_cIValue = Rice::define_class_under<torch::IValue>(m, "IValue");

              // torch::Error exceptions raised inside the bound methods go to handle_error.
              rb_cIValue
                .add_handler<torch::Error>(handle_error)
                .define_constructor(Rice::Constructor<torch::IValue>());

              // Type predicates: each maps directly onto the matching torch::IValue::is*().
              rb_cIValue
                .define_method("bool?", &torch::IValue::isBool)
                .define_method("bool_list?", &torch::IValue::isBoolList)
                .define_method("capsule?", &torch::IValue::isCapsule)
                .define_method("custom_class?", &torch::IValue::isCustomClass)
                .define_method("device?", &torch::IValue::isDevice)
                .define_method("double?", &torch::IValue::isDouble)
                .define_method("double_list?", &torch::IValue::isDoubleList)
                .define_method("future?", &torch::IValue::isFuture)
                // .define_method("generator?", &torch::IValue::isGenerator)
                .define_method("generic_dict?", &torch::IValue::isGenericDict)
                .define_method("list?", &torch::IValue::isList)
                .define_method("int?", &torch::IValue::isInt)
                .define_method("int_list?", &torch::IValue::isIntList)
                .define_method("module?", &torch::IValue::isModule)
                .define_method("none?", &torch::IValue::isNone)
                .define_method("object?", &torch::IValue::isObject)
                .define_method("ptr_type?", &torch::IValue::isPtrType)
                .define_method("py_object?", &torch::IValue::isPyObject)
                .define_method("r_ref?", &torch::IValue::isRRef)
                .define_method("scalar?", &torch::IValue::isScalar)
                .define_method("string?", &torch::IValue::isString)
                .define_method("tensor?", &torch::IValue::isTensor)
                .define_method("tensor_list?", &torch::IValue::isTensorList)
                .define_method("tuple?", &torch::IValue::isTuple);

              // Accessors: unwrap the IValue via the corresponding to*() call.
              rb_cIValue
                .define_method(
                  "to_bool",
                  *[](torch::IValue& self) {
                    return self.toBool();
                  })
                .define_method(
                  "to_double",
                  *[](torch::IValue& self) {
                    return self.toDouble();
                  })
                .define_method(
                  "to_int",
                  *[](torch::IValue& self) {
                    return self.toInt();
                  })
                .define_method(
                  "to_list",
                  *[](torch::IValue& self) {
                    // Each list element is re-wrapped in its own IValue before
                    // conversion to a Ruby object.
                    auto elements = self.toListRef();
                    Rice::Array result;
                    for (auto& element : elements) {
                      result.push(to_ruby<torch::IValue>(torch::IValue{element}));
                    }
                    return result;
                  })
                .define_method(
                  "to_string_ref",
                  *[](torch::IValue& self) {
                    return self.toStringRef();
                  })
                .define_method(
                  "to_tensor",
                  *[](torch::IValue& self) {
                    return self.toTensor();
                  })
                .define_method(
                  "to_generic_dict",
                  *[](torch::IValue& self) {
                    auto dict = self.toGenericDict();
                    Rice::Hash result;
                    for (auto& entry : dict) {
                      // Value is converted before key, matching the C++17 sequencing
                      // of `hash[key] = value` in a single expression.
                      auto rb_value = to_ruby<torch::IValue>(torch::IValue{entry.value()});
                      auto rb_key = to_ruby<torch::IValue>(torch::IValue{entry.key()});
                      result[rb_key] = rb_value;
                    }
                    return result;
                  });

              // Factories: build an IValue from a Ruby-side value.
              rb_cIValue
                .define_singleton_method(
                  "from_tensor",
                  *[](torch::Tensor& tensor) {
                    return torch::IValue(tensor);
                  })
                // TODO create specialized list types?
                .define_singleton_method(
                  "from_list",
                  *[](Rice::Array arr) {
                    // Elements are stored in an untyped (AnyType) generic list.
                    c10::impl::GenericList list(c10::AnyType::get());
                    for (auto item : arr) {
                      list.push_back(from_ruby<torch::IValue>(item));
                    }
                    return torch::IValue(list);
                  })
                .define_singleton_method(
                  "from_string",
                  *[](Rice::String str) {
                    return torch::IValue(str.str());
                  })
                .define_singleton_method(
                  "from_int",
                  *[](int64_t value) {
                    return torch::IValue(value);
                  })
                .define_singleton_method(
                  "from_double",
                  *[](double value) {
                    return torch::IValue(value);
                  })
                .define_singleton_method(
                  "from_bool",
                  *[](bool value) {
                    return torch::IValue(value);
                  })
                // see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/pybind_utils.h
                // createGenericDict and toIValue
                .define_singleton_method(
                  "from_dict",
                  *[](Rice::Hash hash) {
                    // Keys and values are both untyped (AnyType).
                    const auto any_type = c10::AnyType::get();
                    c10::impl::GenericDict elems(any_type, any_type);
                    elems.reserve(hash.size());
                    for (auto entry : hash) {
                      auto key = from_ruby<torch::IValue>(entry.first);
                      auto value = from_ruby<torch::IValue>(static_cast<Rice::Object>(entry.second));
                      elems.insert(std::move(key), std::move(value));
                    }
                    return torch::IValue(std::move(elems));
                  });
            }
         
     | 
    
        data/ext/torch/nn.cpp
    ADDED
    
    | 
         @@ -0,0 +1,114 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            #include <torch/torch.h>
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            #include <rice/Module.hpp>
         
     | 
| 
      
 4 
     | 
    
         
            +
             
     | 
| 
      
 5 
     | 
    
         
            +
            #include "nn_functions.h"
         
     | 
| 
      
 6 
     | 
    
         
            +
            #include "templates.h"
         
     | 
| 
      
 7 
     | 
    
         
            +
            #include "utils.h"
         
     | 
| 
      
 8 
     | 
    
         
            +
             
     | 
| 
      
 9 
     | 
    
         
            +
            // need to make a distinction between parameters and tensors
         
     | 
| 
      
 10 
     | 
    
         
            +
            class Parameter: public torch::autograd::Variable {
         
     | 
| 
      
 11 
     | 
    
         
            +
              public:
         
     | 
| 
      
 12 
     | 
    
         
            +
                Parameter(Tensor&& t) : torch::autograd::Variable(t) { }
         
     | 
| 
      
 13 
     | 
    
         
            +
            };
         
     | 
| 
      
 14 
     | 
    
         
            +
             
     | 
| 
      
 15 
     | 
    
         
            +
            // Registers the NN module under `m`. It holds the generated NN functions,
            // an Init submodule wrapping torch::nn::init, and the Parameter class
            // (bound as a subclass of the Tensor binding).
            void init_nn(Rice::Module& m) {
              auto rb_mNN = Rice::define_module_under(m, "NN");
              rb_mNN.add_handler<torch::Error>(handle_error);
              add_nn_functions(rb_mNN);

              // NN::Init: each "_name!" singleton forwards its arguments to
              // torch::nn::init::name_ (or calculate_gain) and returns the result.
              auto rb_mInit = Rice::define_module_under(rb_mNN, "Init");
              rb_mInit.add_handler<torch::Error>(handle_error);
              rb_mInit
                .define_singleton_method(
                  "_calculate_gain",
                  *[](NonlinearityType nonlinearity, double param) {
                    return torch::nn::init::calculate_gain(nonlinearity, param);
                  })
                .define_singleton_method(
                  "_uniform!",
                  *[](Tensor t, double lower, double upper) {
                    return torch::nn::init::uniform_(t, lower, upper);
                  })
                .define_singleton_method(
                  "_normal!",
                  *[](Tensor t, double mean, double stddev) {
                    return torch::nn::init::normal_(t, mean, stddev);
                  })
                .define_singleton_method(
                  "_constant!",
                  *[](Tensor t, Scalar fill) {
                    return torch::nn::init::constant_(t, fill);
                  })
                .define_singleton_method(
                  "_ones!",
                  *[](Tensor t) {
                    return torch::nn::init::ones_(t);
                  })
                .define_singleton_method(
                  "_zeros!",
                  *[](Tensor t) {
                    return torch::nn::init::zeros_(t);
                  })
                .define_singleton_method(
                  "_eye!",
                  *[](Tensor t) {
                    return torch::nn::init::eye_(t);
                  })
                .define_singleton_method(
                  "_dirac!",
                  *[](Tensor t) {
                    return torch::nn::init::dirac_(t);
                  })
                .define_singleton_method(
                  "_xavier_uniform!",
                  *[](Tensor t, double gain) {
                    return torch::nn::init::xavier_uniform_(t, gain);
                  })
                .define_singleton_method(
                  "_xavier_normal!",
                  *[](Tensor t, double gain) {
                    return torch::nn::init::xavier_normal_(t, gain);
                  })
                .define_singleton_method(
                  "_kaiming_uniform!",
                  *[](Tensor t, double a, FanModeType mode, NonlinearityType nonlinearity) {
                    return torch::nn::init::kaiming_uniform_(t, a, mode, nonlinearity);
                  })
                .define_singleton_method(
                  "_kaiming_normal!",
                  *[](Tensor t, double a, FanModeType mode, NonlinearityType nonlinearity) {
                    return torch::nn::init::kaiming_normal_(t, a, mode, nonlinearity);
                  })
                .define_singleton_method(
                  "_orthogonal!",
                  *[](Tensor t, double gain) {
                    return torch::nn::init::orthogonal_(t, gain);
                  })
                .define_singleton_method(
                  "_sparse!",
                  *[](Tensor t, double sparsity, double stddev) {
                    return torch::nn::init::sparse_(t, sparsity, stddev);
                  });

              auto rb_cParameter = Rice::define_class_under<Parameter, torch::Tensor>(rb_mNN, "Parameter");
              rb_cParameter.add_handler<torch::Error>(handle_error);
              rb_cParameter
                .define_method(
                  "grad",
                  *[](Parameter& self) -> Rice::Object {
                    // An undefined gradient is reported to Ruby as nil.
                    auto g = self.grad();
                    if (g.defined()) {
                      return to_ruby<torch::Tensor>(g);
                    }
                    return Nil;
                  })
                .define_method(
                  "grad=",
                  *[](Parameter& self, torch::Tensor& new_grad) {
                    self.mutable_grad() = new_grad;
                  })
                .define_singleton_method(
                  "_make_subclass",
                  *[](Tensor& rd, bool requires_grad) {
                    // Work on a detached tensor, allow metadata changes on its impl,
                    // set requires_grad, then wrap the result as a Parameter.
                    auto detached = rd.detach();
                    detached.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
                    auto var = detached.set_requires_grad(requires_grad);
                    return Parameter(std::move(var));
                  });
            }
         
     | 
    
        data/ext/torch/nn_functions.h
    CHANGED
    
    
| 
         @@ -0,0 +1,22 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            #include <torch/torch.h>
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            #include <rice/Module.hpp>
         
     | 
| 
      
 4 
     | 
    
         
            +
             
     | 
| 
      
 5 
     | 
    
         
            +
            #include "utils.h"
         
     | 
| 
      
 6 
     | 
    
         
            +
             
     | 
| 
      
 7 
     | 
    
         
            +
            // Registers the Random module under `m`, with seeding helpers that operate
            // on libtorch's default CPU generator only.
            void init_random(Rice::Module& m) {
              auto rb_mRandom = Rice::define_module_under(m, "Random");
              rb_mRandom.add_handler<torch::Error>(handle_error);

              rb_mRandom
                .define_singleton_method(
                  "initial_seed",
                  *[]() {
                    // Reports current_seed() of the default CPU generator.
                    return at::detail::getDefaultCPUGenerator().current_seed();
                  })
                .define_singleton_method(
                  "seed",
                  *[]() {
                    // TODO set for CUDA when available
                    // Calls Generator::seed() on the default CPU generator and returns
                    // its result to Ruby.
                    auto cpu_generator = at::detail::getDefaultCPUGenerator();
                    return cpu_generator.seed();
                  });
            }
         
     | 
| 
         @@ -487,7 +487,7 @@ static void extra_kwargs(FunctionSignature& signature, VALUE kwargs, ssize_t num 
     | 
|
| 
       487 
487 
     | 
    
         | 
| 
       488 
488 
     | 
    
         
             
            VALUE missing = Qundef;
         
     | 
| 
       489 
489 
     | 
    
         | 
| 
       490 
     | 
    
         
            -
            bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs,  
     | 
| 
      
 490 
     | 
    
         
            +
            bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[],  // NOLINT
         
     | 
| 
       491 
491 
     | 
    
         
             
                                          bool raise_exception) {
         
     | 
| 
       492 
492 
     | 
    
         
             
              auto nargs = NIL_P(args) ? 0 : RARRAY_LEN(args);
         
     | 
| 
       493 
493 
     | 
    
         
             
              ssize_t remaining_kwargs = NIL_P(kwargs) ? 0 :  RHASH_SIZE(kwargs);
         
     | 
    
        data/ext/torch/ruby_arg_parser.h
    CHANGED
    
    | 
         @@ -2,6 +2,8 @@ 
     | 
|
| 
       2 
2 
     | 
    
         | 
| 
       3 
3 
     | 
    
         
             
            #pragma once
         
     | 
| 
       4 
4 
     | 
    
         | 
| 
      
 5 
     | 
    
         
            +
            #include <sstream>
         
     | 
| 
      
 6 
     | 
    
         
            +
             
     | 
| 
       5 
7 
     | 
    
         
             
            #include <torch/torch.h>
         
     | 
| 
       6 
8 
     | 
    
         
             
            #include <rice/Exception.hpp>
         
     | 
| 
       7 
9 
     | 
    
         | 
| 
         @@ -46,7 +48,7 @@ struct FunctionParameter { 
     | 
|
| 
       46 
48 
     | 
    
         
             
            struct FunctionSignature {
         
     | 
| 
       47 
49 
     | 
    
         
             
              explicit FunctionSignature(const std::string& fmt, int index);
         
     | 
| 
       48 
50 
     | 
    
         | 
| 
       49 
     | 
    
         
            -
              bool parse(VALUE self, VALUE args, VALUE kwargs,  
     | 
| 
      
 51 
     | 
    
         
            +
              bool parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[], bool raise_exception);
         
     | 
| 
       50 
52 
     | 
    
         | 
| 
       51 
53 
     | 
    
         
             
              std::string toString() const;
         
     | 
| 
       52 
54 
     | 
    
         | 
| 
         @@ -63,13 +65,13 @@ struct FunctionSignature { 
     | 
|
| 
       63 
65 
     | 
    
         
             
            };
         
     | 
| 
       64 
66 
     | 
    
         | 
| 
       65 
67 
     | 
    
         
             
            struct RubyArgs {
         
     | 
| 
       66 
     | 
    
         
            -
              RubyArgs(const FunctionSignature& signature,  
     | 
| 
      
 68 
     | 
    
         
            +
              RubyArgs(const FunctionSignature& signature, VALUE* args)
         
     | 
| 
       67 
69 
     | 
    
         
             
                : signature(signature)
         
     | 
| 
       68 
70 
     | 
    
         
             
                , args(args)
         
     | 
| 
       69 
71 
     | 
    
         
             
                , idx(signature.index) {}
         
     | 
| 
       70 
72 
     | 
    
         | 
| 
       71 
73 
     | 
    
         
             
              const FunctionSignature& signature;
         
     | 
| 
       72 
     | 
    
         
            -
               
     | 
| 
      
 74 
     | 
    
         
            +
              VALUE* args;
         
     | 
| 
       73 
75 
     | 
    
         
             
              int idx;
         
     | 
| 
       74 
76 
     | 
    
         | 
| 
       75 
77 
     | 
    
         
             
              inline at::Tensor tensor(int i);
         
     | 
| 
         @@ -91,7 +93,7 @@ struct RubyArgs { 
     | 
|
| 
       91 
93 
     | 
    
         
             
              inline c10::optional<int64_t> toInt64Optional(int i);
         
     | 
| 
       92 
94 
     | 
    
         
             
              inline c10::optional<bool> toBoolOptional(int i);
         
     | 
| 
       93 
95 
     | 
    
         
             
              inline c10::optional<double> toDoubleOptional(int i);
         
     | 
| 
       94 
     | 
    
         
            -
               
     | 
| 
      
 96 
     | 
    
         
            +
              inline c10::OptionalArray<double> doublelistOptional(int i);
         
     | 
| 
       95 
97 
     | 
    
         
             
              // inline at::Layout layout(int i);
         
     | 
| 
       96 
98 
     | 
    
         
             
              // inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
         
     | 
| 
       97 
99 
     | 
    
         
             
              inline c10::optional<at::Layout> layoutOptional(int i);
         
     | 
| 
         @@ -105,7 +107,7 @@ struct RubyArgs { 
     | 
|
| 
       105 
107 
     | 
    
         
             
              inline c10::optional<at::MemoryFormat> memoryformatOptional(int i);
         
     | 
| 
       106 
108 
     | 
    
         
             
              // inline at::QScheme toQScheme(int i);
         
     | 
| 
       107 
109 
     | 
    
         
             
              inline std::string string(int i);
         
     | 
| 
       108 
     | 
    
         
            -
               
     | 
| 
      
 110 
     | 
    
         
            +
              inline c10::optional<std::string> stringOptional(int i);
         
     | 
| 
       109 
111 
     | 
    
         
             
              // inline PyObject* pyobject(int i);
         
     | 
| 
       110 
112 
     | 
    
         
             
              inline int64_t toInt64(int i);
         
     | 
| 
       111 
113 
     | 
    
         
             
              // inline int64_t toInt64WithDefault(int i, int64_t default_int);
         
     | 
| 
         @@ -249,6 +251,25 @@ inline c10::optional<double> RubyArgs::toDoubleOptional(int i) { 
     | 
|
| 
       249 
251 
     | 
    
         
             
              return toDouble(i);
         
     | 
| 
       250 
252 
     | 
    
         
             
            }
         
     | 
| 
       251 
253 
     | 
    
         | 
| 
      
 254 
     | 
    
         
            +
            // Reads parameter `i` as an optional list of doubles.
            // - nil: returns an OptionalArray holding no value.
            // - otherwise: every element must be a Fixnum or Float; any other element raises
            //   ArgumentError naming the parameter and the element's 1-based position.
            // NOTE(review): assumes args[i] is an Array when non-nil (RARRAY_LEN is not
            // type-checked here) — presumably guaranteed by the signature parser; confirm.
            inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
              if (NIL_P(args[i])) return {};

              VALUE arg = args[i];
              long size = RARRAY_LEN(arg);
              std::vector<double> res(size);
              // Use a local counter: the member `idx` holds the matched signature's index
              // and must not be overwritten while converting an argument.
              for (long pos = 0; pos < size; pos++) {
                VALUE obj = rb_ary_entry(arg, pos);
                if (FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj)) {
                  res[pos] = from_ruby<double>(obj);
                } else {
                  // %ld matches the `long` position argument.
                  rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %ld",
                      signature.name.c_str(), signature.params[i].name.c_str(),
                      signature.params[i].type_name().c_str(), rb_obj_classname(obj), pos + 1);
                }
              }
              return res;
            }
         
     | 
| 
      
 272 
     | 
    
         
            +
             
     | 
| 
       252 
273 
     | 
    
         
             
            inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
         
     | 
| 
       253 
274 
     | 
    
         
             
              if (NIL_P(args[i])) return c10::nullopt;
         
     | 
| 
       254 
275 
     | 
    
         | 
| 
         @@ -285,6 +306,11 @@ inline std::string RubyArgs::string(int i) { 
     | 
|
| 
       285 
306 
     | 
    
         
             
              return from_ruby<std::string>(args[i]);
         
     | 
| 
       286 
307 
     | 
    
         
             
            }
         
     | 
| 
       287 
308 
     | 
    
         | 
| 
      
 309 
     | 
    
         
            +
            // Reads parameter `i` as an optional string.
            // nil (and a zero/Qfalse slot) yields c10::nullopt; any other value is converted
            // with from_ruby<std::string>.
            inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
              // `!args[i]` alone only matches Qfalse (the zero VALUE). Ruby nil is non-zero, so
              // test truthiness instead; this treats nil as absent, consistent with the other
              // *Optional accessors, while still covering the original Qfalse case.
              if (!RTEST(args[i])) return c10::nullopt;
              return from_ruby<std::string>(args[i]);
            }
         
     | 
| 
      
 313 
     | 
    
         
            +
             
     | 
| 
       288 
314 
     | 
    
         
             
            inline int64_t RubyArgs::toInt64(int i) {
         
     | 
| 
       289 
315 
     | 
    
         
             
              if (NIL_P(args[i])) return signature.params[i].default_int;
         
     | 
| 
       290 
316 
     | 
    
         
             
              return from_ruby<int64_t>(args[i]);
         
     | 
| 
         @@ -304,6 +330,12 @@ inline bool RubyArgs::isNone(int i) { 
     | 
|
| 
       304 
330 
     | 
    
         
             
              return NIL_P(args[i]);
         
     | 
| 
       305 
331 
     | 
    
         
             
            }
         
     | 
| 
       306 
332 
     | 
    
         | 
| 
      
 333 
     | 
    
         
            +
            template<int N>
         
     | 
| 
      
 334 
     | 
    
         
            +
            struct ParsedArgs {
         
     | 
| 
      
 335 
     | 
    
         
            +
              ParsedArgs() : args() { }
         
     | 
| 
      
 336 
     | 
    
         
            +
              VALUE args[N];
         
     | 
| 
      
 337 
     | 
    
         
            +
            };
         
     | 
| 
      
 338 
     | 
    
         
            +
             
     | 
| 
       307 
339 
     | 
    
         
             
            struct RubyArgParser {
         
     | 
| 
       308 
340 
     | 
    
         
             
              std::vector<FunctionSignature> signatures_;
         
     | 
| 
       309 
341 
     | 
    
         
             
              std::string function_name;
         
     | 
| 
         @@ -332,7 +364,15 @@ struct RubyArgParser { 
     | 
|
| 
       332 
364 
     | 
    
         
             
                    });
         
     | 
| 
       333 
365 
     | 
    
         
             
                }
         
     | 
| 
       334 
366 
     | 
    
         | 
| 
       335 
     | 
    
         
            -
                 
     | 
| 
      
 367 
     | 
    
         
            +
                template<int N>
         
     | 
| 
      
 368 
     | 
    
         
            +
                inline RubyArgs parse(VALUE self, int argc, VALUE* argv, ParsedArgs<N> &dst) {
         
     | 
| 
      
 369 
     | 
    
         
            +
                  if (N < max_args) {
         
     | 
| 
      
 370 
     | 
    
         
            +
                    rb_raise(rb_eArgError, "RubyArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)", (int)max_args, N);
         
     | 
| 
      
 371 
     | 
    
         
            +
                  }
         
     | 
| 
      
 372 
     | 
    
         
            +
                  return raw_parse(self, argc, argv, dst.args);
         
     | 
| 
      
 373 
     | 
    
         
            +
                }
         
     | 
| 
      
 374 
     | 
    
         
            +
             
     | 
| 
      
 375 
     | 
    
         
            +
                inline RubyArgs raw_parse(VALUE self, int argc, VALUE* argv, VALUE parsed_args[]) {
         
     | 
| 
       336 
376 
     | 
    
         
             
                  VALUE args, kwargs;
         
     | 
| 
       337 
377 
     | 
    
         
             
                  rb_scan_args(argc, argv, "*:", &args, &kwargs);
         
     | 
| 
       338 
378 
     | 
    
         | 
| 
         @@ -354,7 +394,7 @@ struct RubyArgParser { 
     | 
|
| 
       354 
394 
     | 
    
         
             
                  rb_raise(rb_eArgError, "No matching signatures");
         
     | 
| 
       355 
395 
     | 
    
         
             
                }
         
     | 
| 
       356 
396 
     | 
    
         | 
| 
       357 
     | 
    
         
            -
                void print_error(VALUE self, VALUE args, VALUE kwargs,  
     | 
| 
      
 397 
     | 
    
         
            +
                void print_error(VALUE self, VALUE args, VALUE kwargs, VALUE parsed_args[]) {
         
     | 
| 
       358 
398 
     | 
    
         
             
                  ssize_t num_args = (NIL_P(args) ? 0 : RARRAY_LEN(args)) + (NIL_P(kwargs) ? 0 : RHASH_SIZE(kwargs));
         
     | 
| 
       359 
399 
     | 
    
         
             
                  std::vector<int> plausible_idxs;
         
     | 
| 
       360 
400 
     | 
    
         
             
                  ssize_t i = 0;
         
     | 
    
        data/ext/torch/templates.h
    CHANGED
    
    | 
         @@ -19,6 +19,7 @@ using torch::TensorOptions; 
     | 
|
| 
       19 
19 
     | 
    
         
             
            using torch::Layout;
         
     | 
| 
       20 
20 
     | 
    
         
             
            using torch::MemoryFormat;
         
     | 
| 
       21 
21 
     | 
    
         
             
            using torch::IntArrayRef;
         
     | 
| 
      
 22 
     | 
    
         
            +
            using torch::ArrayRef;
         
     | 
| 
       22 
23 
     | 
    
         
             
            using torch::TensorList;
         
     | 
| 
       23 
24 
     | 
    
         
             
            using torch::Storage;
         
     | 
| 
       24 
25 
     | 
    
         | 
| 
         @@ -43,7 +44,7 @@ std::vector<int64_t> from_ruby<std::vector<int64_t>>(Object x) 
     | 
|
| 
       43 
44 
     | 
    
         
             
            {
         
     | 
| 
       44 
45 
     | 
    
         
             
              Array a = Array(x);
         
     | 
| 
       45 
46 
     | 
    
         
             
              std::vector<int64_t> vec(a.size());
         
     | 
| 
       46 
     | 
    
         
            -
              for ( 
     | 
| 
      
 47 
     | 
    
         
            +
              for (long i = 0; i < a.size(); i++) {
         
     | 
| 
       47 
48 
     | 
    
         
             
                vec[i] = from_ruby<int64_t>(a[i]);
         
     | 
| 
       48 
49 
     | 
    
         
             
              }
         
     | 
| 
       49 
50 
     | 
    
         
             
              return vec;
         
     | 
| 
         @@ -55,7 +56,7 @@ std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x) 
     | 
|
| 
       55 
56 
     | 
    
         
             
            {
         
     | 
| 
       56 
57 
     | 
    
         
             
              Array a = Array(x);
         
     | 
| 
       57 
58 
     | 
    
         
             
              std::vector<Tensor> vec(a.size());
         
     | 
| 
       58 
     | 
    
         
            -
              for ( 
     | 
| 
      
 59 
     | 
    
         
            +
              for (long i = 0; i < a.size(); i++) {
         
     | 
| 
       59 
60 
     | 
    
         
             
                vec[i] = from_ruby<Tensor>(a[i]);
         
     | 
| 
       60 
61 
     | 
    
         
             
              }
         
     | 
| 
       61 
62 
     | 
    
         
             
              return vec;
         
     |