torch-rb 0.22.2 → 0.23.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/ext/torch/device.cpp CHANGED
@@ -8,7 +8,8 @@
8
8
  #include "utils.h"
9
9
 
10
10
  void init_device(Rice::Module& m) {
11
- Rice::define_class_under<torch::Device>(m, "Device")
11
+ auto rb_cDevice = Rice::define_class_under<torch::Device>(m, "Device");
12
+ rb_cDevice
12
13
  .define_constructor(Rice::Constructor<torch::Device, const std::string&>())
13
14
  .define_method(
14
15
  "_index",
@@ -28,8 +29,10 @@ void init_device(Rice::Module& m) {
28
29
  return s.str();
29
30
  })
30
31
  .define_method(
31
- "_str",
32
+ "to_s",
32
33
  [](torch::Device& self) {
33
34
  return self.str();
34
35
  });
36
+
37
+ THPDeviceClass = rb_cDevice.value();
35
38
  }
data/ext/torch/nn.cpp CHANGED
@@ -1,3 +1,4 @@
1
+ #include <optional>
1
2
  #include <utility>
2
3
 
3
4
  #include <torch/torch.h>
@@ -95,7 +96,7 @@ void init_nn(Rice::Module& m) {
95
96
  "grad",
96
97
  [](Parameter& self) {
97
98
  auto grad = self.grad();
98
- return grad.defined() ? Rice::Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Rice::Nil;
99
+ return grad.defined() ? std::optional<torch::Tensor>{grad} : std::nullopt;
99
100
  })
100
101
  // can't use grad=
101
102
  // assignment methods fail with Ruby 3.0
@@ -7,6 +7,7 @@
7
7
 
8
8
  #include "ruby_arg_parser.h"
9
9
 
10
+ VALUE THPDeviceClass = Qnil;
10
11
  VALUE THPGeneratorClass = Qnil;
11
12
  VALUE THPVariableClass = Qnil;
12
13
 
@@ -119,7 +120,7 @@ bool is_tensor_list(VALUE obj, int argnum, bool throw_error) {
119
120
  VALUE iobj = rb_ary_entry(obj, idx);
120
121
  if (!THPVariable_Check(iobj)) {
121
122
  if (throw_error) {
122
- rb_raise(rb_eArgError, "expected Tensor as element %d in argument %d, but got %s",
123
+ throw Rice::Exception(rb_eArgError, "expected Tensor as element %d in argument %d, but got %s",
123
124
  static_cast<int>(idx), argnum, rb_obj_classname(obj));
124
125
  }
125
126
  return false;
@@ -257,7 +258,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool {
257
258
  case ParameterType::LAYOUT: return SYMBOL_P(obj);
258
259
  case ParameterType::MEMORY_FORMAT: return false; // return THPMemoryFormat_Check(obj);
259
260
  case ParameterType::QSCHEME: return false; // return THPQScheme_Check(obj);
260
- case ParameterType::DEVICE: return RB_TYPE_P(obj, T_STRING); // TODO check device
261
+ case ParameterType::DEVICE: return RB_TYPE_P(obj, T_STRING) || THPDevice_Check(obj);
261
262
  case ParameterType::STRING: return RB_TYPE_P(obj, T_STRING);
262
263
  case ParameterType::SYM_INT: return is_int_or_symint(obj);
263
264
  case ParameterType::SYM_INT_LIST: return is_int_or_symint_list(obj, size);
@@ -582,10 +583,10 @@ static void extra_args(const FunctionSignature& signature, ssize_t nargs) {
582
583
  const long min_args = signature.min_args;
583
584
  const long nargs_ = nargs;
584
585
  if (min_args != max_pos_args) {
585
- rb_raise(rb_eArgError, "%s() takes from %ld to %ld positional arguments but %ld were given",
586
+ throw Rice::Exception(rb_eArgError, "%s() takes from %ld to %ld positional arguments but %ld were given",
586
587
  signature.name.c_str(), min_args, max_pos_args, nargs_);
587
588
  }
588
- rb_raise(rb_eArgError, "%s() takes %ld positional argument%s but %ld %s given",
589
+ throw Rice::Exception(rb_eArgError, "%s() takes %ld positional argument%s but %ld %s given",
589
590
  signature.name.c_str(),
590
591
  max_pos_args, max_pos_args == 1 ? "" : "s",
591
592
  nargs_, nargs == 1 ? "was" : "were");
@@ -607,7 +608,7 @@ static void missing_args(const FunctionSignature& signature, int idx) {
607
608
  }
608
609
  }
609
610
 
610
- rb_raise(rb_eArgError, "%s() missing %d required positional argument%s: %s",
611
+ throw Rice::Exception(rb_eArgError, "%s() missing %d required positional argument%s: %s",
611
612
  signature.name.c_str(),
612
613
  num_missing,
613
614
  num_missing == 1 ? "s" : "",
@@ -635,23 +636,23 @@ static void extra_kwargs(FunctionSignature& signature, VALUE kwargs, ssize_t num
635
636
  key = rb_ary_entry(keys, 0);
636
637
 
637
638
  if (!THPUtils_checkSymbol(key)) {
638
- rb_raise(rb_eArgError, "keywords must be symbols, not %s", rb_obj_classname(key));
639
+ throw Rice::Exception(rb_eArgError, "keywords must be symbols, not %s", rb_obj_classname(key));
639
640
  }
640
641
 
641
642
  auto param_idx = find_param(signature, key);
642
643
  if (param_idx < 0) {
643
- rb_raise(rb_eArgError, "%s() got an unexpected keyword argument '%s'",
644
+ throw Rice::Exception(rb_eArgError, "%s() got an unexpected keyword argument '%s'",
644
645
  signature.name.c_str(), rb_id2name(rb_to_id(key)));
645
646
  }
646
647
 
647
648
  if (param_idx < num_pos_args) {
648
- rb_raise(rb_eArgError, "%s() got multiple values for argument '%s'",
649
+ throw Rice::Exception(rb_eArgError, "%s() got multiple values for argument '%s'",
649
650
  signature.name.c_str(), rb_id2name(rb_to_id(key)));
650
651
  }
651
652
  }
652
653
 
653
654
  // this should never be hit
654
- rb_raise(rb_eArgError, "invalid keyword arguments");
655
+ throw Rice::Exception(rb_eArgError, "invalid keyword arguments");
655
656
  }
656
657
 
657
658
  VALUE missing = Qundef;
@@ -739,12 +740,12 @@ bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[],
739
740
  } else if (raise_exception) {
740
741
  if (is_kwd) {
741
742
  // foo(): argument 'other' must be str, not int
742
- rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, not %s",
743
+ throw Rice::Exception(rb_eArgError, "%s(): argument '%s' must be %s, not %s",
743
744
  name.c_str(), param.name.c_str(), param.type_name().c_str(),
744
745
  rb_obj_classname(obj));
745
746
  } else {
746
747
  // foo(): argument 'other' (position 2) must be str, not int
747
- rb_raise(rb_eArgError, "%s(): argument '%s' (position %ld) must be %s, not %s",
748
+ throw Rice::Exception(rb_eArgError, "%s(): argument '%s' (position %ld) must be %s, not %s",
748
749
  name.c_str(), param.name.c_str(), static_cast<long>(arg_pos + 1),
749
750
  param.type_name().c_str(), rb_obj_classname(obj));
750
751
  }
@@ -165,7 +165,7 @@ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
165
165
  Check_Type(arg, T_ARRAY);
166
166
  auto size = RARRAY_LEN(arg);
167
167
  if (size != N) {
168
- rb_raise(rb_eArgError, "expected array of %d elements but got %d", N, static_cast<int>(size));
168
+ throw Rice::Exception(rb_eArgError, "expected array of %d elements but got %d", N, static_cast<int>(size));
169
169
  }
170
170
  for (int idx = 0; idx < size; idx++) {
171
171
  VALUE obj = rb_ary_entry(arg, idx);
@@ -206,7 +206,7 @@ inline std::vector<int64_t> RubyArgs::intlistWithDefault(int i, std::vector<int6
206
206
  if (FIXNUM_P(obj)) {
207
207
  res[idx] = Rice::detail::From_Ruby<int64_t>().convert(obj);
208
208
  } else {
209
- rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
209
+ throw Rice::Exception(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
210
210
  signature.name.c_str(), signature.params[i].name.c_str(),
211
211
  signature.params[i].type_name().c_str(), rb_obj_classname(obj), idx + 1);
212
212
  }
@@ -270,7 +270,7 @@ inline ScalarType RubyArgs::scalartype(int i) {
270
270
 
271
271
  auto it = dtype_map.find(args[i]);
272
272
  if (it == dtype_map.end()) {
273
- rb_raise(rb_eArgError, "invalid dtype: %s", rb_id2name(rb_to_id(args[i])));
273
+ throw Rice::Exception(rb_eArgError, "invalid dtype: %s", rb_id2name(rb_to_id(args[i])));
274
274
  }
275
275
  return it->second;
276
276
  }
@@ -321,7 +321,7 @@ inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
321
321
  if (FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj)) {
322
322
  res[idx] = Rice::detail::From_Ruby<double>().convert(obj);
323
323
  } else {
324
- rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
324
+ throw Rice::Exception(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
325
325
  signature.name.c_str(), signature.params[i].name.c_str(),
326
326
  signature.params[i].type_name().c_str(), rb_obj_classname(obj), idx + 1);
327
327
  }
@@ -338,7 +338,7 @@ inline at::Layout RubyArgs::layout(int i) {
338
338
 
339
339
  auto it = layout_map.find(args[i]);
340
340
  if (it == layout_map.end()) {
341
- rb_raise(rb_eArgError, "invalid layout: %s", rb_id2name(rb_to_id(args[i])));
341
+ throw Rice::Exception(rb_eArgError, "invalid layout: %s", rb_id2name(rb_to_id(args[i])));
342
342
  }
343
343
  return it->second;
344
344
  }
@@ -357,8 +357,11 @@ inline at::Device RubyArgs::device(int i) {
357
357
  if (NIL_P(args[i])) {
358
358
  return at::Device("cpu");
359
359
  }
360
- const std::string &device_str = THPUtils_unpackString(args[i]);
361
- return at::Device(device_str);
360
+ if (RB_TYPE_P(args[i], T_STRING)) {
361
+ const std::string &device_str = THPUtils_unpackString(args[i]);
362
+ return at::Device(device_str);
363
+ }
364
+ return Rice::detail::From_Ruby<at::Device>().convert(args[i]);
362
365
  }
363
366
 
364
367
  inline at::Device RubyArgs::deviceWithDefault(int i, const at::Device& default_device) {
@@ -466,7 +469,7 @@ struct RubyArgParser {
466
469
  template<int N>
467
470
  inline RubyArgs parse(VALUE self, int argc, VALUE* argv, ParsedArgs<N> &dst) {
468
471
  if (N < max_args) {
469
- rb_raise(rb_eArgError, "RubyArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)", static_cast<int>(max_args), N);
472
+ throw Rice::Exception(rb_eArgError, "RubyArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)", static_cast<int>(max_args), N);
470
473
  }
471
474
  return raw_parse(self, argc, argv, dst.args);
472
475
  }
@@ -490,7 +493,7 @@ struct RubyArgParser {
490
493
  print_error(self, args, kwargs, parsed_args);
491
494
 
492
495
  // TODO better message
493
- rb_raise(rb_eArgError, "No matching signatures");
496
+ throw Rice::Exception(rb_eArgError, "No matching signatures");
494
497
  }
495
498
 
496
499
  void print_error(VALUE self, VALUE args, VALUE kwargs, VALUE parsed_args[]) {
@@ -55,7 +55,7 @@ namespace Rice::detail {
55
55
  explicit To_Ruby(Arg* arg) : arg_(arg) { }
56
56
 
57
57
  VALUE convert(c10::complex<T> const& x) {
58
- return rb_dbl_complex_new(x.real(), x.imag());
58
+ return protect(rb_dbl_complex_new, x.real(), x.imag());
59
59
  }
60
60
 
61
61
  private:
@@ -69,11 +69,11 @@ namespace Rice::detail {
69
69
 
70
70
  explicit From_Ruby(Arg* arg) : arg_(arg) { }
71
71
 
72
- Convertible is_convertible(VALUE value) { return Convertible::Cast; }
72
+ double is_convertible(VALUE value) { return Convertible::Exact; }
73
73
 
74
74
  c10::complex<T> convert(VALUE x) {
75
- VALUE real = rb_funcall(x, rb_intern("real"), 0);
76
- VALUE imag = rb_funcall(x, rb_intern("imag"), 0);
75
+ VALUE real = protect(rb_funcall, x, rb_intern("real"), 0);
76
+ VALUE imag = protect(rb_funcall, x, rb_intern("imag"), 0);
77
77
  return c10::complex<T>(From_Ruby<T>().convert(real), From_Ruby<T>().convert(imag));
78
78
  }
79
79
 
@@ -93,7 +93,7 @@ namespace Rice::detail {
93
93
 
94
94
  explicit From_Ruby(Arg* arg) : arg_(arg) { }
95
95
 
96
- Convertible is_convertible(VALUE value) { return Convertible::Cast; }
96
+ double is_convertible(VALUE value) { return Convertible::Exact; }
97
97
 
98
98
  FanModeType convert(VALUE x) {
99
99
  auto s = String(x).str();
@@ -122,7 +122,7 @@ namespace Rice::detail {
122
122
 
123
123
  explicit From_Ruby(Arg* arg) : arg_(arg) { }
124
124
 
125
- Convertible is_convertible(VALUE value) { return Convertible::Cast; }
125
+ double is_convertible(VALUE value) { return Convertible::Exact; }
126
126
 
127
127
  NonlinearityType convert(VALUE x) {
128
128
  auto s = String(x).str();
@@ -169,7 +169,7 @@ namespace Rice::detail {
169
169
 
170
170
  explicit From_Ruby(Arg* arg) : arg_(arg) { }
171
171
 
172
- Convertible is_convertible(VALUE value) { return Convertible::Cast; }
172
+ double is_convertible(VALUE value) { return Convertible::Exact; }
173
173
 
174
174
  Scalar convert(VALUE x) {
175
175
  if (FIXNUM_P(x)) {
data/ext/torch/tensor.cpp CHANGED
@@ -1,3 +1,4 @@
1
+ #include <optional>
1
2
  #include <string>
2
3
  #include <vector>
3
4
 
@@ -28,15 +29,15 @@ Array flat_data(Tensor& tensor) {
28
29
  Rice::Class rb_cTensor;
29
30
 
30
31
  std::vector<TensorIndex> index_vector(Array a) {
31
- Object obj;
32
-
33
32
  std::vector<TensorIndex> indices;
34
33
  indices.reserve(a.size());
35
34
 
36
35
  for (long i = 0; i < a.size(); i++) {
37
- obj = a[i];
36
+ Object obj(a[i]);
38
37
 
39
- if (obj.is_instance_of(rb_cInteger)) {
38
+ if (obj.is_nil()) {
39
+ indices.push_back(torch::indexing::None);
40
+ } else if (obj.is_instance_of(rb_cInteger)) {
40
41
  indices.push_back(Rice::detail::From_Ruby<int64_t>().convert(obj.value()));
41
42
  } else if (obj.is_instance_of(rb_cRange)) {
42
43
  torch::optional<c10::SymInt> start_index = torch::nullopt;
@@ -64,12 +65,10 @@ std::vector<TensorIndex> index_vector(Array a) {
64
65
  indices.push_back(torch::indexing::Slice(start_index, stop_index));
65
66
  } else if (obj.is_instance_of(rb_cTensor)) {
66
67
  indices.push_back(Rice::detail::From_Ruby<Tensor>().convert(obj.value()));
67
- } else if (obj.is_nil()) {
68
- indices.push_back(torch::indexing::None);
69
- } else if (obj == Rice::True || obj == Rice::False) {
68
+ } else if (obj.value() == Qtrue || obj.value() == Qfalse) {
70
69
  indices.push_back(Rice::detail::From_Ruby<bool>().convert(obj.value()));
71
70
  } else {
72
- throw Rice::Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
71
+ throw Rice::Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj.value()));
73
72
  }
74
73
  }
75
74
  return indices;
@@ -102,7 +101,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
102
101
  add_tensor_functions(rb_cTensor);
103
102
  THPVariableClass = rb_cTensor.value();
104
103
 
105
- rb_define_method(rb_cTensor, "backward", (VALUE (*)(...)) tensor__backward, -1);
104
+ rb_define_method(rb_cTensor, "backward", tensor__backward, -1);
106
105
 
107
106
  rb_cTensor
108
107
  .define_method("cuda?", [](Tensor& self) { return self.is_cuda(); })
@@ -168,7 +167,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
168
167
  "grad",
169
168
  [](Tensor& self) {
170
169
  auto grad = self.grad();
171
- return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Rice::Nil;
170
+ return grad.defined() ? std::optional<torch::Tensor>{grad} : std::nullopt;
172
171
  })
173
172
  // can't use grad=
174
173
  // assignment methods fail with Ruby 3.0
@@ -184,15 +183,15 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
184
183
 
185
184
  // TODO support sparse grad
186
185
  if (!grad.options().type_equal(self.options())) {
187
- rb_raise(rb_eArgError, "assigned grad has data of a different type");
186
+ throw Rice::Exception(rb_eArgError, "assigned grad has data of a different type");
188
187
  }
189
188
 
190
189
  if (self.is_cuda() && grad.get_device() != self.get_device()) {
191
- rb_raise(rb_eArgError, "assigned grad has data located on a different device");
190
+ throw Rice::Exception(rb_eArgError, "assigned grad has data located on a different device");
192
191
  }
193
192
 
194
193
  if (!self.sizes().equals(grad.sizes())) {
195
- rb_raise(rb_eArgError, "assigned grad has data of a different size");
194
+ throw Rice::Exception(rb_eArgError, "assigned grad has data of a different size");
196
195
  }
197
196
 
198
197
  self.mutable_grad() = grad;
@@ -215,7 +214,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
215
214
  return s.str();
216
215
  })
217
216
  .define_method(
218
- "_device",
217
+ "device",
219
218
  [](Tensor& self) {
220
219
  return self.device();
221
220
  })
data/ext/torch/utils.h CHANGED
@@ -8,7 +8,7 @@
8
8
  #include <rice/stl.hpp>
9
9
 
10
10
  static_assert(
11
- TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 9,
11
+ TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 10,
12
12
  "Incompatible LibTorch version"
13
13
  );
14
14
 
@@ -20,6 +20,7 @@ inline void handle_global_error(const torch::Error& ex) {
20
20
 
21
21
  // keep THP prefix for now to make it easier to compare code
22
22
 
23
+ extern VALUE THPDeviceClass;
23
24
  extern VALUE THPGeneratorClass;
24
25
  extern VALUE THPVariableClass;
25
26
 
@@ -48,6 +49,10 @@ inline bool THPUtils_checkScalar(VALUE obj) {
48
49
  return FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj) || RB_TYPE_P(obj, T_COMPLEX);
49
50
  }
50
51
 
52
+ inline bool THPDevice_Check(VALUE obj) {
53
+ return rb_obj_is_kind_of(obj, THPDeviceClass);
54
+ }
55
+
51
56
  inline bool THPGenerator_Check(VALUE obj) {
52
57
  return rb_obj_is_kind_of(obj, THPGeneratorClass);
53
58
  }
data/lib/torch/device.rb CHANGED
@@ -8,7 +8,6 @@ module Torch
8
8
  extra = ", index: #{index.inspect}" if index?
9
9
  "device(type: #{type.inspect}#{extra})"
10
10
  end
11
- alias_method :to_s, :inspect
12
11
 
13
12
  def ==(other)
14
13
  eql?(other)
data/lib/torch/tensor.rb CHANGED
@@ -115,7 +115,7 @@ module Torch
115
115
  if numel != 1
116
116
  raise Error, "only one element tensors can be converted to Ruby scalars"
117
117
  end
118
- to_a.first
118
+ to_a.flatten.first
119
119
  end
120
120
 
121
121
  def to_i
@@ -210,10 +210,5 @@ module Torch
210
210
  raise TypeError, "#{self.class} can't be coerced into #{other.class}"
211
211
  end
212
212
  end
213
-
214
- # TODO return Device instead of String in 0.19.0
215
- def device
216
- _device._str
217
- end
218
213
  end
219
214
  end
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.22.2"
2
+ VERSION = "0.23.1"
3
3
  end
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.22.2
4
+ version: 0.23.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
@@ -15,14 +15,14 @@ dependencies:
15
15
  requirements:
16
16
  - - ">="
17
17
  - !ruby/object:Gem::Version
18
- version: '4.7'
18
+ version: '4.8'
19
19
  type: :runtime
20
20
  prerelease: false
21
21
  version_requirements: !ruby/object:Gem::Requirement
22
22
  requirements:
23
23
  - - ">="
24
24
  - !ruby/object:Gem::Version
25
- version: '4.7'
25
+ version: '4.8'
26
26
  email: andrew@ankane.org
27
27
  executables: []
28
28
  extensions:
@@ -241,7 +241,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
241
241
  - !ruby/object:Gem::Version
242
242
  version: '0'
243
243
  requirements: []
244
- rubygems_version: 3.6.9
244
+ rubygems_version: 4.0.3
245
245
  specification_version: 4
246
246
  summary: Deep learning for Ruby, powered by LibTorch
247
247
  test_files: []