torch-rb 0.12.0 → 0.12.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
 - data/CHANGELOG.md +4 -0
 - data/ext/torch/ext.cpp +3 -0
 - data/ext/torch/generator.cpp +50 -0
 - data/ext/torch/ruby_arg_parser.cpp +2 -1
 - data/ext/torch/ruby_arg_parser.h +1 -1
 - data/ext/torch/utils.h +5 -0
 - data/lib/torch/version.rb +1 -1
 - metadata +4 -3
 
    
        checksums.yaml
    CHANGED
    
    | 
         @@ -1,7 +1,7 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            ---
         
     | 
| 
       2 
2 
     | 
    
         
             
            SHA256:
         
     | 
| 
       3 
     | 
    
         
            -
              metadata.gz:  
     | 
| 
       4 
     | 
    
         
            -
              data.tar.gz:  
     | 
| 
      
 3 
     | 
    
         
            +
              metadata.gz: ce853372191c85509a65417abeaa05c484976f24681246babf9bd00f8db16df1
         
     | 
| 
      
 4 
     | 
    
         
            +
              data.tar.gz: 3327004180566a194c7de8288c260b8fe9487d80c25493d82e041cf4fc0062e2
         
     | 
| 
       5 
5 
     | 
    
         
             
            SHA512:
         
     | 
| 
       6 
     | 
    
         
            -
              metadata.gz:  
     | 
| 
       7 
     | 
    
         
            -
              data.tar.gz:  
     | 
| 
      
 6 
     | 
    
         
            +
              metadata.gz: d8b6b0f7bd8b79963931b6b28b6a6cee59be18b8f185e2acadc22487a27b5793edabf0fb8b80857c5dfc0eb036a27b67a55750b8dd3c8eb1d199f06323e2b919
         
     | 
| 
      
 7 
     | 
    
         
            +
              data.tar.gz: ffcafd2e9e99d6654f9689dd021c874cf53847833e26207853b0051056ae7cfe0b5ff202036e9a38264198facfa2c66f83b6ee51815131b63aa350415139b614
         
     | 
    
        data/CHANGELOG.md
    CHANGED
    
    
    
        data/ext/torch/ext.cpp
    CHANGED
    
    | 
         @@ -12,6 +12,7 @@ void init_torch(Rice::Module& m); 
     | 
|
| 
       12 
12 
     | 
    
         
             
            void init_backends(Rice::Module& m);
         
     | 
| 
       13 
13 
     | 
    
         
             
            void init_cuda(Rice::Module& m);
         
     | 
| 
       14 
14 
     | 
    
         
             
            void init_device(Rice::Module& m);
         
     | 
| 
      
 15 
     | 
    
         
            +
            void init_generator(Rice::Module& m, Rice::Class& rb_cGenerator);
         
     | 
| 
       15 
16 
     | 
    
         
             
            void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
         
     | 
| 
       16 
17 
     | 
    
         
             
            void init_random(Rice::Module& m);
         
     | 
| 
       17 
18 
     | 
    
         | 
| 
         @@ -23,6 +24,7 @@ void Init_ext() 
     | 
|
| 
       23 
24 
     | 
    
         
             
              // need to define certain classes up front to keep Rice happy
         
     | 
| 
       24 
25 
     | 
    
         
             
              auto rb_cIValue = Rice::define_class_under<torch::IValue>(m, "IValue")
         
     | 
| 
       25 
26 
     | 
    
         
             
                .define_constructor(Rice::Constructor<torch::IValue>());
         
     | 
| 
      
 27 
     | 
    
         
            +
              auto rb_cGenerator = Rice::define_class_under<torch::Generator>(m, "Generator");
         
     | 
| 
       26 
28 
     | 
    
         
             
              auto rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
         
     | 
| 
       27 
29 
     | 
    
         
             
              auto rb_cTensorOptions = Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
         
     | 
| 
       28 
30 
     | 
    
         
             
                .define_constructor(Rice::Constructor<torch::TensorOptions>());
         
     | 
| 
         @@ -38,6 +40,7 @@ void Init_ext() 
     | 
|
| 
       38 
40 
     | 
    
         
             
              init_backends(m);
         
     | 
| 
       39 
41 
     | 
    
         
             
              init_cuda(m);
         
     | 
| 
       40 
42 
     | 
    
         
             
              init_device(m);
         
     | 
| 
      
 43 
     | 
    
         
            +
              init_generator(m, rb_cGenerator);
         
     | 
| 
       41 
44 
     | 
    
         
             
              init_ivalue(m, rb_cIValue);
         
     | 
| 
       42 
45 
     | 
    
         
             
              init_random(m);
         
     | 
| 
       43 
46 
     | 
    
         
             
            }
         
     | 
| 
         @@ -0,0 +1,50 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            #include <torch/torch.h>
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            #include <rice/rice.hpp>
         
     | 
| 
      
 4 
     | 
    
         
            +
             
     | 
| 
      
 5 
     | 
    
         
            +
            #include "utils.h"
         
     | 
| 
      
 6 
     | 
    
         
            +
             
     | 
| 
      
 7 
     | 
    
         
            +
            void init_generator(Rice::Module& m, Rice::Class& rb_cGenerator) {
         
     | 
| 
      
 8 
     | 
    
         
            +
              // https://github.com/pytorch/pytorch/blob/master/torch/csrc/Generator.cpp
         
     | 
| 
      
 9 
     | 
    
         
            +
              rb_cGenerator
         
     | 
| 
      
 10 
     | 
    
         
            +
                .add_handler<torch::Error>(handle_error)
         
     | 
| 
      
 11 
     | 
    
         
            +
                .define_singleton_function(
         
     | 
| 
      
 12 
     | 
    
         
            +
                  "new",
         
     | 
| 
      
 13 
     | 
    
         
            +
                  []() {
         
     | 
| 
      
 14 
     | 
    
         
            +
                    // TODO support more devices
         
     | 
| 
      
 15 
     | 
    
         
            +
                    return torch::make_generator<torch::CPUGeneratorImpl>();
         
     | 
| 
      
 16 
     | 
    
         
            +
                  })
         
     | 
| 
      
 17 
     | 
    
         
            +
                .define_method(
         
     | 
| 
      
 18 
     | 
    
         
            +
                  "device",
         
     | 
| 
      
 19 
     | 
    
         
            +
                  [](torch::Generator& self) {
         
     | 
| 
      
 20 
     | 
    
         
            +
                    return self.device();
         
     | 
| 
      
 21 
     | 
    
         
            +
                  })
         
     | 
| 
      
 22 
     | 
    
         
            +
                .define_method(
         
     | 
| 
      
 23 
     | 
    
         
            +
                  "initial_seed",
         
     | 
| 
      
 24 
     | 
    
         
            +
                  [](torch::Generator& self) {
         
     | 
| 
      
 25 
     | 
    
         
            +
                    return self.current_seed();
         
     | 
| 
      
 26 
     | 
    
         
            +
                  })
         
     | 
| 
      
 27 
     | 
    
         
            +
                .define_method(
         
     | 
| 
      
 28 
     | 
    
         
            +
                  "manual_seed",
         
     | 
| 
      
 29 
     | 
    
         
            +
                  [](torch::Generator& self, uint64_t seed) {
         
     | 
| 
      
 30 
     | 
    
         
            +
                    self.set_current_seed(seed);
         
     | 
| 
      
 31 
     | 
    
         
            +
                    return self;
         
     | 
| 
      
 32 
     | 
    
         
            +
                  })
         
     | 
| 
      
 33 
     | 
    
         
            +
                .define_method(
         
     | 
| 
      
 34 
     | 
    
         
            +
                  "seed",
         
     | 
| 
      
 35 
     | 
    
         
            +
                  [](torch::Generator& self) {
         
     | 
| 
      
 36 
     | 
    
         
            +
                    return self.seed();
         
     | 
| 
      
 37 
     | 
    
         
            +
                  })
         
     | 
| 
      
 38 
     | 
    
         
            +
                .define_method(
         
     | 
| 
      
 39 
     | 
    
         
            +
                  "state",
         
     | 
| 
      
 40 
     | 
    
         
            +
                  [](torch::Generator& self) {
         
     | 
| 
      
 41 
     | 
    
         
            +
                    return self.get_state();
         
     | 
| 
      
 42 
     | 
    
         
            +
                  })
         
     | 
| 
      
 43 
     | 
    
         
            +
                .define_method(
         
     | 
| 
      
 44 
     | 
    
         
            +
                  "state=",
         
     | 
| 
      
 45 
     | 
    
         
            +
                  [](torch::Generator& self, const torch::Tensor& state) {
         
     | 
| 
      
 46 
     | 
    
         
            +
                    self.set_state(state);
         
     | 
| 
      
 47 
     | 
    
         
            +
                  });
         
     | 
| 
      
 48 
     | 
    
         
            +
             
     | 
| 
      
 49 
     | 
    
         
            +
              THPGeneratorClass = rb_cGenerator.value();
         
     | 
| 
      
 50 
     | 
    
         
            +
            }
         
     | 
| 
         @@ -2,6 +2,7 @@ 
     | 
|
| 
       2 
2 
     | 
    
         | 
| 
       3 
3 
     | 
    
         
             
            #include "ruby_arg_parser.h"
         
     | 
| 
       4 
4 
     | 
    
         | 
| 
      
 5 
     | 
    
         
            +
            VALUE THPGeneratorClass = Qnil;
         
     | 
| 
       5 
6 
     | 
    
         
             
            VALUE THPVariableClass = Qnil;
         
     | 
| 
       6 
7 
     | 
    
         | 
| 
       7 
8 
     | 
    
         
             
            static std::unordered_map<std::string, ParameterType> type_map = {
         
     | 
| 
         @@ -244,7 +245,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool 
     | 
|
| 
       244 
245 
     | 
    
         
             
                  return size > 0 && FIXNUM_P(obj);
         
     | 
| 
       245 
246 
     | 
    
         
             
                }
         
     | 
| 
       246 
247 
     | 
    
         
             
                case ParameterType::FLOAT_LIST: return (RB_TYPE_P(obj, T_ARRAY));
         
     | 
| 
       247 
     | 
    
         
            -
                case ParameterType::GENERATOR: return  
     | 
| 
      
 248 
     | 
    
         
            +
                case ParameterType::GENERATOR: return THPGenerator_Check(obj);
         
     | 
| 
       248 
249 
     | 
    
         
             
                case ParameterType::BOOL: return obj == Qtrue || obj == Qfalse;
         
     | 
| 
       249 
250 
     | 
    
         
             
                case ParameterType::STORAGE: return false; // return isStorage(obj);
         
     | 
| 
       250 
251 
     | 
    
         
             
                // case ParameterType::PYOBJECT: return true;
         
     | 
    
        data/ext/torch/ruby_arg_parser.h
    CHANGED
    
    | 
         @@ -223,7 +223,7 @@ inline c10::OptionalArray<c10::SymInt> RubyArgs::symintlistOptional(int i) { 
     | 
|
| 
       223 
223 
     | 
    
         | 
| 
       224 
224 
     | 
    
         
             
            inline c10::optional<at::Generator> RubyArgs::generator(int i) {
         
     | 
| 
       225 
225 
     | 
    
         
             
              if (NIL_P(args[i])) return c10::nullopt;
         
     | 
| 
       226 
     | 
    
         
            -
               
     | 
| 
      
 226 
     | 
    
         
            +
              return Rice::detail::From_Ruby<torch::Generator>().convert(args[i]);
         
     | 
| 
       227 
227 
     | 
    
         
             
            }
         
     | 
| 
       228 
228 
     | 
    
         | 
| 
       229 
229 
     | 
    
         
             
            inline at::Storage RubyArgs::storage(int i) {
         
     | 
    
        data/ext/torch/utils.h
    CHANGED
    
    | 
         @@ -17,6 +17,7 @@ inline void handle_error(torch::Error const & ex) { 
     | 
|
| 
       17 
17 
     | 
    
         | 
| 
       18 
18 
     | 
    
         
             
            // keep THP prefix for now to make it easier to compare code
         
     | 
| 
       19 
19 
     | 
    
         | 
| 
      
 20 
     | 
    
         
            +
            extern VALUE THPGeneratorClass;
         
     | 
| 
       20 
21 
     | 
    
         
             
            extern VALUE THPVariableClass;
         
     | 
| 
       21 
22 
     | 
    
         | 
| 
       22 
23 
     | 
    
         
             
            inline VALUE THPUtils_internSymbol(const std::string& str) {
         
     | 
| 
         @@ -44,6 +45,10 @@ inline bool THPUtils_checkScalar(VALUE obj) { 
     | 
|
| 
       44 
45 
     | 
    
         
             
              return FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj) || RB_TYPE_P(obj, T_COMPLEX);
         
     | 
| 
       45 
46 
     | 
    
         
             
            }
         
     | 
| 
       46 
47 
     | 
    
         | 
| 
      
 48 
     | 
    
         
            +
            inline bool THPGenerator_Check(VALUE obj) {
         
     | 
| 
      
 49 
     | 
    
         
            +
              return rb_obj_is_kind_of(obj, THPGeneratorClass);
         
     | 
| 
      
 50 
     | 
    
         
            +
            }
         
     | 
| 
      
 51 
     | 
    
         
            +
             
     | 
| 
       47 
52 
     | 
    
         
             
            inline bool THPVariable_Check(VALUE obj) {
         
     | 
| 
       48 
53 
     | 
    
         
             
              return rb_obj_is_kind_of(obj, THPVariableClass);
         
     | 
| 
       49 
54 
     | 
    
         
             
            }
         
     | 
    
        data/lib/torch/version.rb
    CHANGED
    
    
    
        metadata
    CHANGED
    
    | 
         @@ -1,14 +1,14 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            --- !ruby/object:Gem::Specification
         
     | 
| 
       2 
2 
     | 
    
         
             
            name: torch-rb
         
     | 
| 
       3 
3 
     | 
    
         
             
            version: !ruby/object:Gem::Version
         
     | 
| 
       4 
     | 
    
         
            -
              version: 0.12.0
     | 
| 
      
 4 
     | 
    
         
            +
              version: 0.12.1
         
     | 
| 
       5 
5 
     | 
    
         
             
            platform: ruby
         
     | 
| 
       6 
6 
     | 
    
         
             
            authors:
         
     | 
| 
       7 
7 
     | 
    
         
             
            - Andrew Kane
         
     | 
| 
       8 
8 
     | 
    
         
             
            autorequire:
         
     | 
| 
       9 
9 
     | 
    
         
             
            bindir: bin
         
     | 
| 
       10 
10 
     | 
    
         
             
            cert_chain: []
         
     | 
| 
       11 
     | 
    
         
            -
            date:  
     | 
| 
      
 11 
     | 
    
         
            +
            date: 2023-01-30 00:00:00.000000000 Z
         
     | 
| 
       12 
12 
     | 
    
         
             
            dependencies:
         
     | 
| 
       13 
13 
     | 
    
         
             
            - !ruby/object:Gem::Dependency
         
     | 
| 
       14 
14 
     | 
    
         
             
              name: rice
         
     | 
| 
         @@ -44,6 +44,7 @@ files: 
     | 
|
| 
       44 
44 
     | 
    
         
             
            - ext/torch/extconf.rb
         
     | 
| 
       45 
45 
     | 
    
         
             
            - ext/torch/fft.cpp
         
     | 
| 
       46 
46 
     | 
    
         
             
            - ext/torch/fft_functions.h
         
     | 
| 
      
 47 
     | 
    
         
            +
            - ext/torch/generator.cpp
         
     | 
| 
       47 
48 
     | 
    
         
             
            - ext/torch/ivalue.cpp
         
     | 
| 
       48 
49 
     | 
    
         
             
            - ext/torch/linalg.cpp
         
     | 
| 
       49 
50 
     | 
    
         
             
            - ext/torch/linalg_functions.h
         
     | 
| 
         @@ -229,7 +230,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement 
     | 
|
| 
       229 
230 
     | 
    
         
             
                - !ruby/object:Gem::Version
         
     | 
| 
       230 
231 
     | 
    
         
             
                  version: '0'
         
     | 
| 
       231 
232 
     | 
    
         
             
            requirements: []
         
     | 
| 
       232 
     | 
    
         
            -
            rubygems_version: 3. 
     | 
| 
      
 233 
     | 
    
         
            +
            rubygems_version: 3.4.1
         
     | 
| 
       233 
234 
     | 
    
         
             
            signing_key:
         
     | 
| 
       234 
235 
     | 
    
         
             
            specification_version: 4
         
     | 
| 
       235 
236 
     | 
    
         
             
            summary: Deep learning for Ruby, powered by LibTorch
         
     |