torch-rb 0.21.0 → 0.22.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
 - data/CHANGELOG.md +9 -0
 - data/README.md +2 -3
 - data/codegen/generate_functions.rb +5 -1
 - data/codegen/native_functions.yaml +239 -152
 - data/ext/torch/ext.cpp +4 -0
 - data/ext/torch/ivalue.cpp +1 -1
 - data/ext/torch/templates.h +36 -1
 - data/ext/torch/tensor.cpp +3 -3
 - data/ext/torch/utils.h +4 -2
 - data/lib/torch/version.rb +1 -1
 - data/lib/torch.rb +0 -1
 - metadata +3 -3
 
    
        data/ext/torch/ext.cpp
    CHANGED
    
    | 
         @@ -16,10 +16,14 @@ void init_generator(Rice::Module& m, Rice::Class& rb_cGenerator); 
     | 
|
| 
       16 
16 
     | 
    
         
             
            void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
         
     | 
| 
       17 
17 
     | 
    
         
             
            void init_random(Rice::Module& m);
         
     | 
| 
       18 
18 
     | 
    
         | 
| 
      
 19 
     | 
    
         
            +
            VALUE rb_eTorchError = Qnil;
         
     | 
| 
      
 20 
     | 
    
         
            +
             
     | 
| 
       19 
21 
     | 
    
         
             
            extern "C"
         
     | 
| 
       20 
22 
     | 
    
         
             
            void Init_ext() {
         
     | 
| 
       21 
23 
     | 
    
         
             
              auto m = Rice::define_module("Torch");
         
     | 
| 
       22 
24 
     | 
    
         | 
| 
      
 25 
     | 
    
         
            +
              rb_eTorchError = Rice::define_class_under(m, "Error", rb_eStandardError);
         
     | 
| 
      
 26 
     | 
    
         
            +
             
     | 
| 
       23 
27 
     | 
    
         
             
              // need to define certain classes up front to keep Rice happy
         
     | 
| 
       24 
28 
     | 
    
         
             
              auto rb_cIValue = Rice::define_class_under<torch::IValue>(m, "IValue")
         
     | 
| 
       25 
29 
     | 
    
         
             
                .define_constructor(Rice::Constructor<torch::IValue>());
         
     | 
    
        data/ext/torch/ivalue.cpp
    CHANGED
    
    | 
         @@ -55,7 +55,7 @@ void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue) { 
     | 
|
| 
       55 
55 
     | 
    
         
             
                    Rice::Array obj;
         
     | 
| 
       56 
56 
     | 
    
         
             
                    for (auto& elem : list) {
         
     | 
| 
       57 
57 
     | 
    
         
             
                      auto v = torch::IValue{elem};
         
     | 
| 
       58 
     | 
    
         
            -
                      obj.push(Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v)));
         
     | 
| 
      
 58 
     | 
    
         
            +
                      obj.push(Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v)), false);
         
     | 
| 
       59 
59 
     | 
    
         
             
                    }
         
     | 
| 
       60 
60 
     | 
    
         
             
                    return obj;
         
     | 
| 
       61 
61 
     | 
    
         
             
                  })
         
     | 
    
        data/ext/torch/templates.h
    CHANGED
    
    | 
         @@ -31,7 +31,7 @@ using torch::nn::init::NonlinearityType; 
     | 
|
| 
       31 
31 
     | 
    
         | 
| 
       32 
32 
     | 
    
         
             
            #define END_HANDLE_TH_ERRORS                                         \
         
     | 
| 
       33 
33 
     | 
    
         
             
              } catch (const torch::Error& ex) {                                 \
         
     | 
| 
       34 
     | 
    
         
            -
                rb_raise( 
     | 
| 
      
 34 
     | 
    
         
            +
                rb_raise(rb_eTorchError, "%s", ex.what_without_backtrace());   \
         
     | 
| 
       35 
35 
     | 
    
         
             
              } catch (const Rice::Exception& ex) {                              \
         
     | 
| 
       36 
36 
     | 
    
         
             
                rb_raise(ex.class_of(), "%s", ex.what());                        \
         
     | 
| 
       37 
37 
     | 
    
         
             
              } catch (const std::exception& ex) {                               \
         
     | 
| 
         @@ -50,14 +50,25 @@ namespace Rice::detail { 
     | 
|
| 
       50 
50 
     | 
    
         
             
              template<typename T>
         
     | 
| 
       51 
51 
     | 
    
         
             
              class To_Ruby<c10::complex<T>> {
         
     | 
| 
       52 
52 
     | 
    
         
             
              public:
         
     | 
| 
      
 53 
     | 
    
         
            +
                To_Ruby() = default;
         
     | 
| 
      
 54 
     | 
    
         
            +
             
     | 
| 
      
 55 
     | 
    
         
            +
                explicit To_Ruby(Arg* arg) : arg_(arg) { }
         
     | 
| 
      
 56 
     | 
    
         
            +
             
     | 
| 
       53 
57 
     | 
    
         
             
                VALUE convert(c10::complex<T> const& x) {
         
     | 
| 
       54 
58 
     | 
    
         
             
                  return rb_dbl_complex_new(x.real(), x.imag());
         
     | 
| 
       55 
59 
     | 
    
         
             
                }
         
     | 
| 
      
 60 
     | 
    
         
            +
             
     | 
| 
      
 61 
     | 
    
         
            +
              private:
         
     | 
| 
      
 62 
     | 
    
         
            +
                Arg* arg_ = nullptr;
         
     | 
| 
       56 
63 
     | 
    
         
             
              };
         
     | 
| 
       57 
64 
     | 
    
         | 
| 
       58 
65 
     | 
    
         
             
              template<typename T>
         
     | 
| 
       59 
66 
     | 
    
         
             
              class From_Ruby<c10::complex<T>> {
         
     | 
| 
       60 
67 
     | 
    
         
             
              public:
         
     | 
| 
      
 68 
     | 
    
         
            +
                From_Ruby() = default;
         
     | 
| 
      
 69 
     | 
    
         
            +
             
     | 
| 
      
 70 
     | 
    
         
            +
                explicit From_Ruby(Arg* arg) : arg_(arg) { }
         
     | 
| 
      
 71 
     | 
    
         
            +
             
     | 
| 
       61 
72 
     | 
    
         
             
                Convertible is_convertible(VALUE value) { return Convertible::Cast; }
         
     | 
| 
       62 
73 
     | 
    
         | 
| 
       63 
74 
     | 
    
         
             
                c10::complex<T> convert(VALUE x) {
         
     | 
| 
         @@ -65,6 +76,9 @@ namespace Rice::detail { 
     | 
|
| 
       65 
76 
     | 
    
         
             
                  VALUE imag = rb_funcall(x, rb_intern("imag"), 0);
         
     | 
| 
       66 
77 
     | 
    
         
             
                  return c10::complex<T>(From_Ruby<T>().convert(real), From_Ruby<T>().convert(imag));
         
     | 
| 
       67 
78 
     | 
    
         
             
                }
         
     | 
| 
      
 79 
     | 
    
         
            +
             
     | 
| 
      
 80 
     | 
    
         
            +
              private:
         
     | 
| 
      
 81 
     | 
    
         
            +
                Arg* arg_ = nullptr;
         
     | 
| 
       68 
82 
     | 
    
         
             
              };
         
     | 
| 
       69 
83 
     | 
    
         | 
| 
       70 
84 
     | 
    
         
             
              template<>
         
     | 
| 
         @@ -75,6 +89,10 @@ namespace Rice::detail { 
     | 
|
| 
       75 
89 
     | 
    
         
             
              template<>
         
     | 
| 
       76 
90 
     | 
    
         
             
              class From_Ruby<FanModeType> {
         
     | 
| 
       77 
91 
     | 
    
         
             
              public:
         
     | 
| 
      
 92 
     | 
    
         
            +
                From_Ruby() = default;
         
     | 
| 
      
 93 
     | 
    
         
            +
             
     | 
| 
      
 94 
     | 
    
         
            +
                explicit From_Ruby(Arg* arg) : arg_(arg) { }
         
     | 
| 
      
 95 
     | 
    
         
            +
             
     | 
| 
       78 
96 
     | 
    
         
             
                Convertible is_convertible(VALUE value) { return Convertible::Cast; }
         
     | 
| 
       79 
97 
     | 
    
         | 
| 
       80 
98 
     | 
    
         
             
                FanModeType convert(VALUE x) {
         
     | 
| 
         @@ -87,6 +105,9 @@ namespace Rice::detail { 
     | 
|
| 
       87 
105 
     | 
    
         
             
                    throw std::runtime_error("Unsupported nonlinearity type: " + s);
         
     | 
| 
       88 
106 
     | 
    
         
             
                  }
         
     | 
| 
       89 
107 
     | 
    
         
             
                }
         
     | 
| 
      
 108 
     | 
    
         
            +
             
     | 
| 
      
 109 
     | 
    
         
            +
              private:
         
     | 
| 
      
 110 
     | 
    
         
            +
                Arg* arg_ = nullptr;
         
     | 
| 
       90 
111 
     | 
    
         
             
              };
         
     | 
| 
       91 
112 
     | 
    
         | 
| 
       92 
113 
     | 
    
         
             
              template<>
         
     | 
| 
         @@ -97,6 +118,10 @@ namespace Rice::detail { 
     | 
|
| 
       97 
118 
     | 
    
         
             
              template<>
         
     | 
| 
       98 
119 
     | 
    
         
             
              class From_Ruby<NonlinearityType> {
         
     | 
| 
       99 
120 
     | 
    
         
             
              public:
         
     | 
| 
      
 121 
     | 
    
         
            +
                From_Ruby() = default;
         
     | 
| 
      
 122 
     | 
    
         
            +
             
     | 
| 
      
 123 
     | 
    
         
            +
                explicit From_Ruby(Arg* arg) : arg_(arg) { }
         
     | 
| 
      
 124 
     | 
    
         
            +
             
     | 
| 
       100 
125 
     | 
    
         
             
                Convertible is_convertible(VALUE value) { return Convertible::Cast; }
         
     | 
| 
       101 
126 
     | 
    
         | 
| 
       102 
127 
     | 
    
         
             
                NonlinearityType convert(VALUE x) {
         
     | 
| 
         @@ -127,6 +152,9 @@ namespace Rice::detail { 
     | 
|
| 
       127 
152 
     | 
    
         
             
                    throw std::runtime_error("Unsupported nonlinearity type: " + s);
         
     | 
| 
       128 
153 
     | 
    
         
             
                  }
         
     | 
| 
       129 
154 
     | 
    
         
             
                }
         
     | 
| 
      
 155 
     | 
    
         
            +
             
     | 
| 
      
 156 
     | 
    
         
            +
              private:
         
     | 
| 
      
 157 
     | 
    
         
            +
                Arg* arg_ = nullptr;
         
     | 
| 
       130 
158 
     | 
    
         
             
              };
         
     | 
| 
       131 
159 
     | 
    
         | 
| 
       132 
160 
     | 
    
         
             
              template<>
         
     | 
| 
         @@ -137,6 +165,10 @@ namespace Rice::detail { 
     | 
|
| 
       137 
165 
     | 
    
         
             
              template<>
         
     | 
| 
       138 
166 
     | 
    
         
             
              class From_Ruby<Scalar> {
         
     | 
| 
       139 
167 
     | 
    
         
             
              public:
         
     | 
| 
      
 168 
     | 
    
         
            +
                From_Ruby() = default;
         
     | 
| 
      
 169 
     | 
    
         
            +
             
     | 
| 
      
 170 
     | 
    
         
            +
                explicit From_Ruby(Arg* arg) : arg_(arg) { }
         
     | 
| 
      
 171 
     | 
    
         
            +
             
     | 
| 
       140 
172 
     | 
    
         
             
                Convertible is_convertible(VALUE value) { return Convertible::Cast; }
         
     | 
| 
       141 
173 
     | 
    
         | 
| 
       142 
174 
     | 
    
         
             
                Scalar convert(VALUE x) {
         
     | 
| 
         @@ -146,5 +178,8 @@ namespace Rice::detail { 
     | 
|
| 
       146 
178 
     | 
    
         
             
                    return torch::Scalar(From_Ruby<double>().convert(x));
         
     | 
| 
       147 
179 
     | 
    
         
             
                  }
         
     | 
| 
       148 
180 
     | 
    
         
             
                }
         
     | 
| 
      
 181 
     | 
    
         
            +
             
     | 
| 
      
 182 
     | 
    
         
            +
              private:
         
     | 
| 
      
 183 
     | 
    
         
            +
                Arg* arg_ = nullptr;
         
     | 
| 
       149 
184 
     | 
    
         
             
              };
         
     | 
| 
       150 
185 
     | 
    
         
             
            } // namespace Rice::detail
         
     | 
    
        data/ext/torch/tensor.cpp
    CHANGED
    
    | 
         @@ -20,7 +20,7 @@ Array flat_data(Tensor& tensor) { 
     | 
|
| 
       20 
20 
     | 
    
         | 
| 
       21 
21 
     | 
    
         
             
              Array a;
         
     | 
| 
       22 
22 
     | 
    
         
             
              for (int i = 0; i < tensor.numel(); i++) {
         
     | 
| 
       23 
     | 
    
         
            -
                a.push(view[i].item().to<T>());
         
     | 
| 
      
 23 
     | 
    
         
            +
                a.push(view[i].item().to<T>(), false);
         
     | 
| 
       24 
24 
     | 
    
         
             
              }
         
     | 
| 
       25 
25 
     | 
    
         
             
              return a;
         
     | 
| 
       26 
26 
     | 
    
         
             
            }
         
     | 
| 
         @@ -129,7 +129,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions 
     | 
|
| 
       129 
129 
     | 
    
         
             
                  [](Tensor& self) {
         
     | 
| 
       130 
130 
     | 
    
         
             
                    Array a;
         
     | 
| 
       131 
131 
     | 
    
         
             
                    for (auto &size : self.sizes()) {
         
     | 
| 
       132 
     | 
    
         
            -
                      a.push(size);
         
     | 
| 
      
 132 
     | 
    
         
            +
                      a.push(size, false);
         
     | 
| 
       133 
133 
     | 
    
         
             
                    }
         
     | 
| 
       134 
134 
     | 
    
         
             
                    return a;
         
     | 
| 
       135 
135 
     | 
    
         
             
                  })
         
     | 
| 
         @@ -138,7 +138,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions 
     | 
|
| 
       138 
138 
     | 
    
         
             
                  [](Tensor& self) {
         
     | 
| 
       139 
139 
     | 
    
         
             
                    Array a;
         
     | 
| 
       140 
140 
     | 
    
         
             
                    for (auto &stride : self.strides()) {
         
     | 
| 
       141 
     | 
    
         
            -
                      a.push(stride);
         
     | 
| 
      
 141 
     | 
    
         
            +
                      a.push(stride, false);
         
     | 
| 
       142 
142 
     | 
    
         
             
                    }
         
     | 
| 
       143 
143 
     | 
    
         
             
                    return a;
         
     | 
| 
       144 
144 
     | 
    
         
             
                  })
         
     | 
    
        data/ext/torch/utils.h
    CHANGED
    
    | 
         @@ -8,12 +8,14 @@ 
     | 
|
| 
       8 
8 
     | 
    
         
             
            #include <rice/stl.hpp>
         
     | 
| 
       9 
9 
     | 
    
         | 
| 
       10 
10 
     | 
    
         
             
            static_assert(
         
     | 
| 
       11 
     | 
    
         
            -
              TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR ==  
     | 
| 
      
 11 
     | 
    
         
            +
              TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 9,
         
     | 
| 
       12 
12 
     | 
    
         
             
              "Incompatible LibTorch version"
         
     | 
| 
       13 
13 
     | 
    
         
             
            );
         
     | 
| 
       14 
14 
     | 
    
         | 
| 
      
 15 
     | 
    
         
            +
            extern VALUE rb_eTorchError;
         
     | 
| 
      
 16 
     | 
    
         
            +
             
     | 
| 
       15 
17 
     | 
    
         
             
            inline void handle_global_error(const torch::Error& ex) {
         
     | 
| 
       16 
     | 
    
         
            -
              throw Rice::Exception( 
     | 
| 
      
 18 
     | 
    
         
            +
              throw Rice::Exception(rb_eTorchError, ex.what_without_backtrace());
         
     | 
| 
       17 
19 
     | 
    
         
             
            }
         
     | 
| 
       18 
20 
     | 
    
         | 
| 
       19 
21 
     | 
    
         
             
            // keep THP prefix for now to make it easier to compare code
         
     | 
    
        data/lib/torch/version.rb
    CHANGED
    
    
    
        data/lib/torch.rb
    CHANGED
    
    | 
         @@ -210,7 +210,6 @@ require_relative "torch/utils/data/tensor_dataset" 
     | 
|
| 
       210 
210 
     | 
    
         
             
            require_relative "torch/hub"
         
     | 
| 
       211 
211 
     | 
    
         | 
| 
       212 
212 
     | 
    
         
             
            module Torch
         
     | 
| 
       213 
     | 
    
         
            -
              class Error < StandardError; end
         
     | 
| 
       214 
213 
     | 
    
         
             
              class NotImplementedYet < StandardError
         
     | 
| 
       215 
214 
     | 
    
         
             
                def message
         
     | 
| 
       216 
215 
     | 
    
         
             
                  "This feature has not been implemented yet. Consider submitting a PR."
         
     | 
    
        metadata
    CHANGED
    
    | 
         @@ -1,7 +1,7 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            --- !ruby/object:Gem::Specification
         
     | 
| 
       2 
2 
     | 
    
         
             
            name: torch-rb
         
     | 
| 
       3 
3 
     | 
    
         
             
            version: !ruby/object:Gem::Version
         
     | 
| 
       4 
     | 
    
         
            -
              version: 0.21.0
     | 
| 
      
 4 
     | 
    
         
            +
              version: 0.22.1
         
     | 
| 
       5 
5 
     | 
    
         
             
            platform: ruby
         
     | 
| 
       6 
6 
     | 
    
         
             
            authors:
         
     | 
| 
       7 
7 
     | 
    
         
             
            - Andrew Kane
         
     | 
| 
         @@ -15,14 +15,14 @@ dependencies: 
     | 
|
| 
       15 
15 
     | 
    
         
             
                requirements:
         
     | 
| 
       16 
16 
     | 
    
         
             
                - - ">="
         
     | 
| 
       17 
17 
     | 
    
         
             
                  - !ruby/object:Gem::Version
         
     | 
| 
       18 
     | 
    
         
            -
                    version: '4. 
     | 
| 
      
 18 
     | 
    
         
            +
                    version: '4.7'
         
     | 
| 
       19 
19 
     | 
    
         
             
              type: :runtime
         
     | 
| 
       20 
20 
     | 
    
         
             
              prerelease: false
         
     | 
| 
       21 
21 
     | 
    
         
             
              version_requirements: !ruby/object:Gem::Requirement
         
     | 
| 
       22 
22 
     | 
    
         
             
                requirements:
         
     | 
| 
       23 
23 
     | 
    
         
             
                - - ">="
         
     | 
| 
       24 
24 
     | 
    
         
             
                  - !ruby/object:Gem::Version
         
     | 
| 
       25 
     | 
    
         
            -
                    version: '4. 
     | 
| 
      
 25 
     | 
    
         
            +
                    version: '4.7'
         
     | 
| 
       26 
26 
     | 
    
         
             
            email: andrew@ankane.org
         
     | 
| 
       27 
27 
     | 
    
         
             
            executables: []
         
     | 
| 
       28 
28 
     | 
    
         
             
            extensions:
         
     |