torch-rb 0.12.0 → 0.12.1

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions exactly as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 24617a6191c4d1e42ff0d5885c0762db75d034ee52817d70004ddc52025890dd
4
- data.tar.gz: 7036f3c7fac8a1aac22914ec2811b8771464baeb847052e417b069f3897b3977
3
+ metadata.gz: ce853372191c85509a65417abeaa05c484976f24681246babf9bd00f8db16df1
4
+ data.tar.gz: 3327004180566a194c7de8288c260b8fe9487d80c25493d82e041cf4fc0062e2
5
5
  SHA512:
6
- metadata.gz: 044166f526bea5fea0c4314abc55dfaf2410cdd5d7bce764ccb78329754fe284675dc0ecb55cf5a4546587221cf2361c8e215f1e9e067668ea5833c445bea961
7
- data.tar.gz: e086d0dcc731fb3e810f8fb91d04d95bf3c5efd9cd35f053238be60479d8a47f2126ddad7ab2705d18a26175c6c4126f12be3f19962c2a69f24ad3ee551252b1
6
+ metadata.gz: d8b6b0f7bd8b79963931b6b28b6a6cee59be18b8f185e2acadc22487a27b5793edabf0fb8b80857c5dfc0eb036a27b67a55750b8dd3c8eb1d199f06323e2b919
7
+ data.tar.gz: ffcafd2e9e99d6654f9689dd021c874cf53847833e26207853b0051056ae7cfe0b5ff202036e9a38264198facfa2c66f83b6ee51815131b63aa350415139b614
data/CHANGELOG.md CHANGED
@@ -1,3 +1,7 @@
1
+ ## 0.12.1 (2023-01-29)
2
+
3
+ - Added `Generator` class
4
+
1
5
  ## 0.12.0 (2022-11-05)
2
6
 
3
7
  - Updated LibTorch to 1.13.0
data/ext/torch/ext.cpp CHANGED
@@ -12,6 +12,7 @@ void init_torch(Rice::Module& m);
12
12
  void init_backends(Rice::Module& m);
13
13
  void init_cuda(Rice::Module& m);
14
14
  void init_device(Rice::Module& m);
15
+ void init_generator(Rice::Module& m, Rice::Class& rb_cGenerator);
15
16
  void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
16
17
  void init_random(Rice::Module& m);
17
18
 
@@ -23,6 +24,7 @@ void Init_ext()
23
24
  // need to define certain classes up front to keep Rice happy
24
25
  auto rb_cIValue = Rice::define_class_under<torch::IValue>(m, "IValue")
25
26
  .define_constructor(Rice::Constructor<torch::IValue>());
27
+ auto rb_cGenerator = Rice::define_class_under<torch::Generator>(m, "Generator");
26
28
  auto rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
27
29
  auto rb_cTensorOptions = Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
28
30
  .define_constructor(Rice::Constructor<torch::TensorOptions>());
@@ -38,6 +40,7 @@ void Init_ext()
38
40
  init_backends(m);
39
41
  init_cuda(m);
40
42
  init_device(m);
43
+ init_generator(m, rb_cGenerator);
41
44
  init_ivalue(m, rb_cIValue);
42
45
  init_random(m);
43
46
  }
@@ -0,0 +1,50 @@
1
+ #include <torch/torch.h>
2
+
3
+ #include <rice/rice.hpp>
4
+
5
+ #include "utils.h"
6
+
7
+ void init_generator(Rice::Module& m, Rice::Class& rb_cGenerator) {
8
+ // https://github.com/pytorch/pytorch/blob/master/torch/csrc/Generator.cpp
9
+ rb_cGenerator
10
+ .add_handler<torch::Error>(handle_error)
11
+ .define_singleton_function(
12
+ "new",
13
+ []() {
14
+ // TODO support more devices
15
+ return torch::make_generator<torch::CPUGeneratorImpl>();
16
+ })
17
+ .define_method(
18
+ "device",
19
+ [](torch::Generator& self) {
20
+ return self.device();
21
+ })
22
+ .define_method(
23
+ "initial_seed",
24
+ [](torch::Generator& self) {
25
+ return self.current_seed();
26
+ })
27
+ .define_method(
28
+ "manual_seed",
29
+ [](torch::Generator& self, uint64_t seed) {
30
+ self.set_current_seed(seed);
31
+ return self;
32
+ })
33
+ .define_method(
34
+ "seed",
35
+ [](torch::Generator& self) {
36
+ return self.seed();
37
+ })
38
+ .define_method(
39
+ "state",
40
+ [](torch::Generator& self) {
41
+ return self.get_state();
42
+ })
43
+ .define_method(
44
+ "state=",
45
+ [](torch::Generator& self, const torch::Tensor& state) {
46
+ self.set_state(state);
47
+ });
48
+
49
+ THPGeneratorClass = rb_cGenerator.value();
50
+ }
@@ -2,6 +2,7 @@
2
2
 
3
3
  #include "ruby_arg_parser.h"
4
4
 
5
+ VALUE THPGeneratorClass = Qnil;
5
6
  VALUE THPVariableClass = Qnil;
6
7
 
7
8
  static std::unordered_map<std::string, ParameterType> type_map = {
@@ -244,7 +245,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
244
245
  return size > 0 && FIXNUM_P(obj);
245
246
  }
246
247
  case ParameterType::FLOAT_LIST: return (RB_TYPE_P(obj, T_ARRAY));
247
- case ParameterType::GENERATOR: return false; // return THPGenerator_Check(obj);
248
+ case ParameterType::GENERATOR: return THPGenerator_Check(obj);
248
249
  case ParameterType::BOOL: return obj == Qtrue || obj == Qfalse;
249
250
  case ParameterType::STORAGE: return false; // return isStorage(obj);
250
251
  // case ParameterType::PYOBJECT: return true;
@@ -223,7 +223,7 @@ inline c10::OptionalArray<c10::SymInt> RubyArgs::symintlistOptional(int i) {
223
223
 
224
224
  inline c10::optional<at::Generator> RubyArgs::generator(int i) {
225
225
  if (NIL_P(args[i])) return c10::nullopt;
226
- throw std::runtime_error("generator not supported yet");
226
+ return Rice::detail::From_Ruby<torch::Generator>().convert(args[i]);
227
227
  }
228
228
 
229
229
  inline at::Storage RubyArgs::storage(int i) {
data/ext/torch/utils.h CHANGED
@@ -17,6 +17,7 @@ inline void handle_error(torch::Error const & ex) {
17
17
 
18
18
  // keep THP prefix for now to make it easier to compare code
19
19
 
20
+ extern VALUE THPGeneratorClass;
20
21
  extern VALUE THPVariableClass;
21
22
 
22
23
  inline VALUE THPUtils_internSymbol(const std::string& str) {
@@ -44,6 +45,10 @@ inline bool THPUtils_checkScalar(VALUE obj) {
44
45
  return FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj) || RB_TYPE_P(obj, T_COMPLEX);
45
46
  }
46
47
 
48
+ inline bool THPGenerator_Check(VALUE obj) {
49
+ return rb_obj_is_kind_of(obj, THPGeneratorClass);
50
+ }
51
+
47
52
  inline bool THPVariable_Check(VALUE obj) {
48
53
  return rb_obj_is_kind_of(obj, THPVariableClass);
49
54
  }
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.12.0"
2
+ VERSION = "0.12.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.12.0
4
+ version: 0.12.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2022-11-05 00:00:00.000000000 Z
11
+ date: 2023-01-30 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -44,6 +44,7 @@ files:
44
44
  - ext/torch/extconf.rb
45
45
  - ext/torch/fft.cpp
46
46
  - ext/torch/fft_functions.h
47
+ - ext/torch/generator.cpp
47
48
  - ext/torch/ivalue.cpp
48
49
  - ext/torch/linalg.cpp
49
50
  - ext/torch/linalg_functions.h
@@ -229,7 +230,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
229
230
  - !ruby/object:Gem::Version
230
231
  version: '0'
231
232
  requirements: []
232
- rubygems_version: 3.3.7
233
+ rubygems_version: 3.4.1
233
234
  signing_key:
234
235
  specification_version: 4
235
236
  summary: Deep learning for Ruby, powered by LibTorch