torch-rb 0.12.0 → 0.12.1

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 24617a6191c4d1e42ff0d5885c0762db75d034ee52817d70004ddc52025890dd
4
- data.tar.gz: 7036f3c7fac8a1aac22914ec2811b8771464baeb847052e417b069f3897b3977
3
+ metadata.gz: ce853372191c85509a65417abeaa05c484976f24681246babf9bd00f8db16df1
4
+ data.tar.gz: 3327004180566a194c7de8288c260b8fe9487d80c25493d82e041cf4fc0062e2
5
5
  SHA512:
6
- metadata.gz: 044166f526bea5fea0c4314abc55dfaf2410cdd5d7bce764ccb78329754fe284675dc0ecb55cf5a4546587221cf2361c8e215f1e9e067668ea5833c445bea961
7
- data.tar.gz: e086d0dcc731fb3e810f8fb91d04d95bf3c5efd9cd35f053238be60479d8a47f2126ddad7ab2705d18a26175c6c4126f12be3f19962c2a69f24ad3ee551252b1
6
+ metadata.gz: d8b6b0f7bd8b79963931b6b28b6a6cee59be18b8f185e2acadc22487a27b5793edabf0fb8b80857c5dfc0eb036a27b67a55750b8dd3c8eb1d199f06323e2b919
7
+ data.tar.gz: ffcafd2e9e99d6654f9689dd021c874cf53847833e26207853b0051056ae7cfe0b5ff202036e9a38264198facfa2c66f83b6ee51815131b63aa350415139b614
data/CHANGELOG.md CHANGED
@@ -1,3 +1,7 @@
1
+ ## 0.12.1 (2023-01-29)
2
+
3
+ - Added `Generator` class
4
+
1
5
  ## 0.12.0 (2022-11-05)
2
6
 
3
7
  - Updated LibTorch to 1.13.0
data/ext/torch/ext.cpp CHANGED
@@ -12,6 +12,7 @@ void init_torch(Rice::Module& m);
12
12
  void init_backends(Rice::Module& m);
13
13
  void init_cuda(Rice::Module& m);
14
14
  void init_device(Rice::Module& m);
15
+ void init_generator(Rice::Module& m, Rice::Class& rb_cGenerator);
15
16
  void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
16
17
  void init_random(Rice::Module& m);
17
18
 
@@ -23,6 +24,7 @@ void Init_ext()
23
24
  // need to define certain classes up front to keep Rice happy
24
25
  auto rb_cIValue = Rice::define_class_under<torch::IValue>(m, "IValue")
25
26
  .define_constructor(Rice::Constructor<torch::IValue>());
27
+ auto rb_cGenerator = Rice::define_class_under<torch::Generator>(m, "Generator");
26
28
  auto rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
27
29
  auto rb_cTensorOptions = Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
28
30
  .define_constructor(Rice::Constructor<torch::TensorOptions>());
@@ -38,6 +40,7 @@ void Init_ext()
38
40
  init_backends(m);
39
41
  init_cuda(m);
40
42
  init_device(m);
43
+ init_generator(m, rb_cGenerator);
41
44
  init_ivalue(m, rb_cIValue);
42
45
  init_random(m);
43
46
  }
@@ -0,0 +1,50 @@
1
+ #include <torch/torch.h>
2
+
3
+ #include <rice/rice.hpp>
4
+
5
+ #include "utils.h"
6
+
7
+ void init_generator(Rice::Module& m, Rice::Class& rb_cGenerator) { // Attaches the Ruby-visible API to the pre-defined Generator class; `m` is not referenced in this body.
8
+ // https://github.com/pytorch/pytorch/blob/master/torch/csrc/Generator.cpp
9
+ rb_cGenerator
10
+ .add_handler<torch::Error>(handle_error) // hands torch::Error exceptions to handle_error (declared in utils.h)
11
+ .define_singleton_function(
12
+ "new", // class-level constructor; takes no arguments
13
+ []() {
14
+ // TODO support more devices
15
+ return torch::make_generator<torch::CPUGeneratorImpl>(); // always a CPU generator for now
16
+ })
17
+ .define_method(
18
+ "device",
19
+ [](torch::Generator& self) {
20
+ return self.device();
21
+ })
22
+ .define_method(
23
+ "initial_seed", // Ruby name differs from the C++ call below
24
+ [](torch::Generator& self) {
25
+ return self.current_seed();
26
+ })
27
+ .define_method(
28
+ "manual_seed",
29
+ [](torch::Generator& self, uint64_t seed) {
30
+ self.set_current_seed(seed);
31
+ return self; // no explicit return type on the lambda, so this is returned by value, not by reference
32
+ })
33
+ .define_method(
34
+ "seed",
35
+ [](torch::Generator& self) {
36
+ return self.seed(); // NOTE(review): presumably reseeds and returns the new seed - confirm against at::Generator
37
+ })
38
+ .define_method(
39
+ "state",
40
+ [](torch::Generator& self) {
41
+ return self.get_state();
42
+ })
43
+ .define_method(
44
+ "state=", // setter counterpart of "state"; the lambda returns nothing
45
+ [](torch::Generator& self, const torch::Tensor& state) {
46
+ self.set_state(state);
47
+ });
48
+
49
+ THPGeneratorClass = rb_cGenerator.value(); // publish the class VALUE in the global that THPGenerator_Check (utils.h) compares against
50
+ }
@@ -2,6 +2,7 @@
2
2
 
3
3
  #include "ruby_arg_parser.h"
4
4
 
5
+ VALUE THPGeneratorClass = Qnil;
5
6
  VALUE THPVariableClass = Qnil;
6
7
 
7
8
  static std::unordered_map<std::string, ParameterType> type_map = {
@@ -244,7 +245,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
244
245
  return size > 0 && FIXNUM_P(obj);
245
246
  }
246
247
  case ParameterType::FLOAT_LIST: return (RB_TYPE_P(obj, T_ARRAY));
247
- case ParameterType::GENERATOR: return false; // return THPGenerator_Check(obj);
248
+ case ParameterType::GENERATOR: return THPGenerator_Check(obj);
248
249
  case ParameterType::BOOL: return obj == Qtrue || obj == Qfalse;
249
250
  case ParameterType::STORAGE: return false; // return isStorage(obj);
250
251
  // case ParameterType::PYOBJECT: return true;
@@ -223,7 +223,7 @@ inline c10::OptionalArray<c10::SymInt> RubyArgs::symintlistOptional(int i) {
223
223
 
224
224
  inline c10::optional<at::Generator> RubyArgs::generator(int i) { // Reads parsed argument i as an optional generator.
225
225
  if (NIL_P(args[i])) return c10::nullopt; // Ruby nil maps to an empty optional
226
- throw std::runtime_error("generator not supported yet");
226
+ return Rice::detail::From_Ruby<torch::Generator>().convert(args[i]); // non-nil values are converted through Rice's From_Ruby for torch::Generator
227
227
  }
228
228
 
229
229
  inline at::Storage RubyArgs::storage(int i) {
data/ext/torch/utils.h CHANGED
@@ -17,6 +17,7 @@ inline void handle_error(torch::Error const & ex) {
17
17
 
18
18
  // keep THP prefix for now to make it easier to compare code
19
19
 
20
+ extern VALUE THPGeneratorClass;
20
21
  extern VALUE THPVariableClass;
21
22
 
22
23
  inline VALUE THPUtils_internSymbol(const std::string& str) {
@@ -44,6 +45,10 @@ inline bool THPUtils_checkScalar(VALUE obj) {
44
45
  return FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj) || RB_TYPE_P(obj, T_COMPLEX);
45
46
  }
46
47
 
48
+ inline bool THPGenerator_Check(VALUE obj) { // Reports whether obj is a kind of the class stored in THPGeneratorClass.
49
+ return rb_obj_is_kind_of(obj, THPGeneratorClass); // the VALUE result is implicitly converted to bool, same pattern as THPVariable_Check below
50
+ }
51
+
47
52
  inline bool THPVariable_Check(VALUE obj) {
48
53
  return rb_obj_is_kind_of(obj, THPVariableClass);
49
54
  }
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.12.0"
2
+ VERSION = "0.12.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.12.0
4
+ version: 0.12.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2022-11-05 00:00:00.000000000 Z
11
+ date: 2023-01-30 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -44,6 +44,7 @@ files:
44
44
  - ext/torch/extconf.rb
45
45
  - ext/torch/fft.cpp
46
46
  - ext/torch/fft_functions.h
47
+ - ext/torch/generator.cpp
47
48
  - ext/torch/ivalue.cpp
48
49
  - ext/torch/linalg.cpp
49
50
  - ext/torch/linalg_functions.h
@@ -229,7 +230,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
229
230
  - !ruby/object:Gem::Version
230
231
  version: '0'
231
232
  requirements: []
232
- rubygems_version: 3.3.7
233
+ rubygems_version: 3.4.1
233
234
  signing_key:
234
235
  specification_version: 4
235
236
  summary: Deep learning for Ruby, powered by LibTorch