torch-rb 0.23.0 → 0.24.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/ext/torch/ivalue.cpp CHANGED
@@ -53,7 +53,7 @@ void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue) {
53
53
  [](torch::IValue& self) {
54
54
  auto list = self.toListRef();
55
55
  Rice::Array obj;
56
- for (auto& elem : list) {
56
+ for (const auto& elem : list) {
57
57
  auto v = torch::IValue{elem};
58
58
  obj.push(Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v)), false);
59
59
  }
@@ -74,7 +74,7 @@ void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue) {
74
74
  [](torch::IValue& self) {
75
75
  auto dict = self.toGenericDict();
76
76
  Rice::Hash obj;
77
- for (auto& pair : dict) {
77
+ for (const auto& pair : dict) {
78
78
  auto k = torch::IValue{pair.key()};
79
79
  auto v = torch::IValue{pair.value()};
80
80
  obj[Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(k))] = Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v));
@@ -91,7 +91,7 @@ void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue) {
91
91
  "from_list",
92
92
  [](Rice::Array obj) {
93
93
  c10::impl::GenericList list(c10::AnyType::get());
94
- for (auto entry : obj) {
94
+ for (const auto& entry : obj) {
95
95
  list.push_back(Rice::detail::From_Ruby<torch::IValue>().convert(entry.value()));
96
96
  }
97
97
  return torch::IValue(list);
@@ -125,7 +125,7 @@ void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue) {
125
125
  auto value_type = c10::AnyType::get();
126
126
  c10::impl::GenericDict elems(key_type, value_type);
127
127
  elems.reserve(obj.size());
128
- for (auto entry : obj) {
128
+ for (const auto& entry : obj) {
129
129
  elems.insert(Rice::detail::From_Ruby<torch::IValue>().convert(entry.first), Rice::detail::From_Ruby<torch::IValue>().convert((Rice::Object) entry.second));
130
130
  }
131
131
  return torch::IValue(std::move(elems));
data/ext/torch/nn.cpp CHANGED
@@ -1,3 +1,4 @@
1
+ #include <optional>
1
2
  #include <utility>
2
3
 
3
4
  #include <torch/torch.h>
@@ -95,7 +96,7 @@ void init_nn(Rice::Module& m) {
95
96
  "grad",
96
97
  [](Parameter& self) {
97
98
  auto grad = self.grad();
98
- return grad.defined() ? Rice::Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Rice::Nil;
99
+ return grad.defined() ? std::optional<torch::Tensor>{grad} : std::nullopt;
99
100
  })
100
101
  // can't use grad=
101
102
  // assignment methods fail with Ruby 3.0
@@ -117,11 +117,11 @@ bool is_tensor_list(VALUE obj, int argnum, bool throw_error) {
117
117
  }
118
118
  auto size = RARRAY_LEN(obj);
119
119
  for (int idx = 0; idx < size; idx++) {
120
- VALUE iobj = rb_ary_entry(obj, idx);
120
+ VALUE iobj = Rice::detail::protect(rb_ary_entry, obj, idx);
121
121
  if (!THPVariable_Check(iobj)) {
122
122
  if (throw_error) {
123
- rb_raise(rb_eArgError, "expected Tensor as element %d in argument %d, but got %s",
124
- static_cast<int>(idx), argnum, rb_obj_classname(obj));
123
+ throw Rice::Exception(rb_eArgError, "expected Tensor as element %d in argument %d, but got %s",
124
+ static_cast<int>(idx), argnum, Rice::detail::protect(rb_obj_classname, obj));
125
125
  }
126
126
  return false;
127
127
  }
@@ -136,7 +136,7 @@ static bool is_int_list(VALUE obj, int broadcast_size) {
136
136
  return true;
137
137
  }
138
138
 
139
- auto item = rb_ary_entry(obj, 0);
139
+ auto item = Rice::detail::protect(rb_ary_entry, obj, 0);
140
140
  bool int_first = false;
141
141
  if (THPUtils_checkIndex(item)) {
142
142
  // we still have to check that the rest of items are NOT symint nodes
@@ -178,7 +178,7 @@ static bool is_int_or_symint_list(VALUE obj, int broadcast_size) {
178
178
  if (RARRAY_LEN(obj) == 0) {
179
179
  return true;
180
180
  }
181
- auto item = rb_ary_entry(obj, 0);
181
+ auto item = Rice::detail::protect(rb_ary_entry, obj, 0);
182
182
 
183
183
  if (is_int_or_symint(item)) {
184
184
  return true;
@@ -583,10 +583,10 @@ static void extra_args(const FunctionSignature& signature, ssize_t nargs) {
583
583
  const long min_args = signature.min_args;
584
584
  const long nargs_ = nargs;
585
585
  if (min_args != max_pos_args) {
586
- rb_raise(rb_eArgError, "%s() takes from %ld to %ld positional arguments but %ld were given",
586
+ throw Rice::Exception(rb_eArgError, "%s() takes from %ld to %ld positional arguments but %ld were given",
587
587
  signature.name.c_str(), min_args, max_pos_args, nargs_);
588
588
  }
589
- rb_raise(rb_eArgError, "%s() takes %ld positional argument%s but %ld %s given",
589
+ throw Rice::Exception(rb_eArgError, "%s() takes %ld positional argument%s but %ld %s given",
590
590
  signature.name.c_str(),
591
591
  max_pos_args, max_pos_args == 1 ? "" : "s",
592
592
  nargs_, nargs == 1 ? "was" : "were");
@@ -608,7 +608,7 @@ static void missing_args(const FunctionSignature& signature, int idx) {
608
608
  }
609
609
  }
610
610
 
611
- rb_raise(rb_eArgError, "%s() missing %d required positional argument%s: %s",
611
+ throw Rice::Exception(rb_eArgError, "%s() missing %d required positional argument%s: %s",
612
612
  signature.name.c_str(),
613
613
  num_missing,
614
614
  num_missing == 1 ? "s" : "",
@@ -631,28 +631,28 @@ static ssize_t find_param(FunctionSignature& signature, VALUE name) {
631
631
  static void extra_kwargs(FunctionSignature& signature, VALUE kwargs, ssize_t num_pos_args) {
632
632
  VALUE key;
633
633
 
634
- VALUE keys = rb_funcall(kwargs, rb_intern("keys"), 0);
634
+ VALUE keys = Rice::detail::protect(rb_funcall, kwargs, rb_intern("keys"), 0);
635
635
  if (RARRAY_LEN(keys) > 0) {
636
- key = rb_ary_entry(keys, 0);
636
+ key = Rice::detail::protect(rb_ary_entry, keys, 0);
637
637
 
638
638
  if (!THPUtils_checkSymbol(key)) {
639
- rb_raise(rb_eArgError, "keywords must be symbols, not %s", rb_obj_classname(key));
639
+ throw Rice::Exception(rb_eArgError, "keywords must be symbols, not %s", Rice::detail::protect(rb_obj_classname, key));
640
640
  }
641
641
 
642
642
  auto param_idx = find_param(signature, key);
643
643
  if (param_idx < 0) {
644
- rb_raise(rb_eArgError, "%s() got an unexpected keyword argument '%s'",
645
- signature.name.c_str(), rb_id2name(rb_to_id(key)));
644
+ throw Rice::Exception(rb_eArgError, "%s() got an unexpected keyword argument '%s'",
645
+ signature.name.c_str(), Rice::detail::protect(rb_id2name, Rice::detail::protect(rb_to_id, key)));
646
646
  }
647
647
 
648
648
  if (param_idx < num_pos_args) {
649
- rb_raise(rb_eArgError, "%s() got multiple values for argument '%s'",
650
- signature.name.c_str(), rb_id2name(rb_to_id(key)));
649
+ throw Rice::Exception(rb_eArgError, "%s() got multiple values for argument '%s'",
650
+ signature.name.c_str(), Rice::detail::protect(rb_id2name, Rice::detail::protect(rb_to_id, key)));
651
651
  }
652
652
  }
653
653
 
654
654
  // this should never be hit
655
- rb_raise(rb_eArgError, "invalid keyword arguments");
655
+ throw Rice::Exception(rb_eArgError, "invalid keyword arguments");
656
656
  }
657
657
 
658
658
  VALUE missing = Qundef;
@@ -703,14 +703,14 @@ bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[],
703
703
  }
704
704
  return false;
705
705
  }
706
- obj = rb_ary_entry(args, arg_pos);
706
+ obj = Rice::detail::protect(rb_ary_entry, args, arg_pos);
707
707
  } else if (!NIL_P(kwargs)) {
708
- obj = rb_hash_lookup2(kwargs, param.ruby_name, missing);
708
+ obj = Rice::detail::protect(rb_hash_lookup2, kwargs, param.ruby_name, missing);
709
709
  // for (VALUE numpy_name: param.numpy_python_names) {
710
710
  // if (obj) {
711
711
  // break;
712
712
  // }
713
- // obj = rb_hash_aref(kwargs, numpy_name);
713
+ // obj = Rice::detail::protect(rb_hash_aref, kwargs, numpy_name);
714
714
  // }
715
715
  is_kwd = true;
716
716
  }
@@ -740,14 +740,14 @@ bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[],
740
740
  } else if (raise_exception) {
741
741
  if (is_kwd) {
742
742
  // foo(): argument 'other' must be str, not int
743
- rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, not %s",
743
+ throw Rice::Exception(rb_eArgError, "%s(): argument '%s' must be %s, not %s",
744
744
  name.c_str(), param.name.c_str(), param.type_name().c_str(),
745
- rb_obj_classname(obj));
745
+ Rice::detail::protect(rb_obj_classname, obj));
746
746
  } else {
747
747
  // foo(): argument 'other' (position 2) must be str, not int
748
- rb_raise(rb_eArgError, "%s(): argument '%s' (position %ld) must be %s, not %s",
748
+ throw Rice::Exception(rb_eArgError, "%s(): argument '%s' (position %ld) must be %s, not %s",
749
749
  name.c_str(), param.name.c_str(), static_cast<long>(arg_pos + 1),
750
- param.type_name().c_str(), rb_obj_classname(obj));
750
+ param.type_name().c_str(), Rice::detail::protect(rb_obj_classname, obj));
751
751
  }
752
752
  } else {
753
753
  return false;
@@ -165,10 +165,10 @@ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
165
165
  Check_Type(arg, T_ARRAY);
166
166
  auto size = RARRAY_LEN(arg);
167
167
  if (size != N) {
168
- rb_raise(rb_eArgError, "expected array of %d elements but got %d", N, static_cast<int>(size));
168
+ throw Rice::Exception(rb_eArgError, "expected array of %d elements but got %d", N, static_cast<int>(size));
169
169
  }
170
170
  for (int idx = 0; idx < size; idx++) {
171
- VALUE obj = rb_ary_entry(arg, idx);
171
+ VALUE obj = Rice::detail::protect(rb_ary_entry, arg, idx);
172
172
  res[idx] = Rice::detail::From_Ruby<Tensor>().convert(obj);
173
173
  }
174
174
  return res;
@@ -202,13 +202,13 @@ inline std::vector<int64_t> RubyArgs::intlistWithDefault(int i, std::vector<int6
202
202
  size = RARRAY_LEN(arg);
203
203
  std::vector<int64_t> res(size);
204
204
  for (idx = 0; idx < size; idx++) {
205
- VALUE obj = rb_ary_entry(arg, idx);
205
+ VALUE obj = Rice::detail::protect(rb_ary_entry, arg, idx);
206
206
  if (FIXNUM_P(obj)) {
207
207
  res[idx] = Rice::detail::From_Ruby<int64_t>().convert(obj);
208
208
  } else {
209
- rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
209
+ throw Rice::Exception(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
210
210
  signature.name.c_str(), signature.params[i].name.c_str(),
211
- signature.params[i].type_name().c_str(), rb_obj_classname(obj), idx + 1);
211
+ signature.params[i].type_name().c_str(), Rice::detail::protect(rb_obj_classname, obj), idx + 1);
212
212
  }
213
213
  }
214
214
  return res;
@@ -270,7 +270,7 @@ inline ScalarType RubyArgs::scalartype(int i) {
270
270
 
271
271
  auto it = dtype_map.find(args[i]);
272
272
  if (it == dtype_map.end()) {
273
- rb_raise(rb_eArgError, "invalid dtype: %s", rb_id2name(rb_to_id(args[i])));
273
+ throw Rice::Exception(rb_eArgError, "invalid dtype: %s", Rice::detail::protect(rb_id2name, Rice::detail::protect(rb_to_id, args[i])));
274
274
  }
275
275
  return it->second;
276
276
  }
@@ -317,13 +317,13 @@ inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
317
317
  auto size = RARRAY_LEN(arg);
318
318
  std::vector<double> res(size);
319
319
  for (idx = 0; idx < size; idx++) {
320
- VALUE obj = rb_ary_entry(arg, idx);
320
+ VALUE obj = Rice::detail::protect(rb_ary_entry, arg, idx);
321
321
  if (FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj)) {
322
322
  res[idx] = Rice::detail::From_Ruby<double>().convert(obj);
323
323
  } else {
324
- rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
324
+ throw Rice::Exception(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
325
325
  signature.name.c_str(), signature.params[i].name.c_str(),
326
- signature.params[i].type_name().c_str(), rb_obj_classname(obj), idx + 1);
326
+ signature.params[i].type_name().c_str(), Rice::detail::protect(rb_obj_classname, obj), idx + 1);
327
327
  }
328
328
  }
329
329
  return res;
@@ -338,7 +338,7 @@ inline at::Layout RubyArgs::layout(int i) {
338
338
 
339
339
  auto it = layout_map.find(args[i]);
340
340
  if (it == layout_map.end()) {
341
- rb_raise(rb_eArgError, "invalid layout: %s", rb_id2name(rb_to_id(args[i])));
341
+ throw Rice::Exception(rb_eArgError, "invalid layout: %s", Rice::detail::protect(rb_id2name, Rice::detail::protect(rb_to_id, args[i])));
342
342
  }
343
343
  return it->second;
344
344
  }
@@ -358,7 +358,7 @@ inline at::Device RubyArgs::device(int i) {
358
358
  return at::Device("cpu");
359
359
  }
360
360
  if (RB_TYPE_P(args[i], T_STRING)) {
361
- const std::string &device_str = THPUtils_unpackString(args[i]);
361
+ const std::string& device_str = THPUtils_unpackString(args[i]);
362
362
  return at::Device(device_str);
363
363
  }
364
364
  return Rice::detail::From_Ruby<at::Device>().convert(args[i]);
@@ -461,22 +461,22 @@ struct RubyArgParser {
461
461
 
462
462
  // Check deprecated signatures last
463
463
  std::stable_partition(signatures_.begin(), signatures_.end(),
464
- [](const FunctionSignature & sig) {
464
+ [](const FunctionSignature& sig) {
465
465
  return !sig.deprecated;
466
466
  });
467
467
  }
468
468
 
469
469
  template<int N>
470
- inline RubyArgs parse(VALUE self, int argc, VALUE* argv, ParsedArgs<N> &dst) {
470
+ inline RubyArgs parse(VALUE self, int argc, VALUE* argv, ParsedArgs<N>& dst) {
471
471
  if (N < max_args) {
472
- rb_raise(rb_eArgError, "RubyArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)", static_cast<int>(max_args), N);
472
+ throw Rice::Exception(rb_eArgError, "RubyArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)", static_cast<int>(max_args), N);
473
473
  }
474
474
  return raw_parse(self, argc, argv, dst.args);
475
475
  }
476
476
 
477
477
  inline RubyArgs raw_parse(VALUE self, int argc, VALUE* argv, VALUE parsed_args[]) {
478
478
  VALUE args, kwargs;
479
- rb_scan_args(argc, argv, "*:", &args, &kwargs);
479
+ Rice::detail::protect(rb_scan_args, argc, argv, "*:", &args, &kwargs);
480
480
 
481
481
  if (signatures_.size() == 1) {
482
482
  auto& signature = signatures_[0];
@@ -493,7 +493,7 @@ struct RubyArgParser {
493
493
  print_error(self, args, kwargs, parsed_args);
494
494
 
495
495
  // TODO better message
496
- rb_raise(rb_eArgError, "No matching signatures");
496
+ throw Rice::Exception(rb_eArgError, "No matching signatures");
497
497
  }
498
498
 
499
499
  void print_error(VALUE self, VALUE args, VALUE kwargs, VALUE parsed_args[]) {
@@ -54,8 +54,8 @@ namespace Rice::detail {
54
54
 
55
55
  explicit To_Ruby(Arg* arg) : arg_(arg) { }
56
56
 
57
- VALUE convert(c10::complex<T> const& x) {
58
- return rb_dbl_complex_new(x.real(), x.imag());
57
+ VALUE convert(const c10::complex<T>& x) {
58
+ return protect(rb_dbl_complex_new, x.real(), x.imag());
59
59
  }
60
60
 
61
61
  private:
@@ -72,8 +72,8 @@ namespace Rice::detail {
72
72
  double is_convertible(VALUE value) { return Convertible::Exact; }
73
73
 
74
74
  c10::complex<T> convert(VALUE x) {
75
- VALUE real = rb_funcall(x, rb_intern("real"), 0);
76
- VALUE imag = rb_funcall(x, rb_intern("imag"), 0);
75
+ VALUE real = protect(rb_funcall, x, rb_intern("real"), 0);
76
+ VALUE imag = protect(rb_funcall, x, rb_intern("imag"), 0);
77
77
  return c10::complex<T>(From_Ruby<T>().convert(real), From_Ruby<T>().convert(imag));
78
78
  }
79
79
 
data/ext/torch/tensor.cpp CHANGED
@@ -1,4 +1,6 @@
1
+ #include <optional>
1
2
  #include <string>
3
+ #include <string_view>
2
4
  #include <vector>
3
5
 
4
6
  #include <torch/torch.h>
@@ -28,15 +30,15 @@ Array flat_data(Tensor& tensor) {
28
30
  Rice::Class rb_cTensor;
29
31
 
30
32
  std::vector<TensorIndex> index_vector(Array a) {
31
- Object obj;
32
-
33
33
  std::vector<TensorIndex> indices;
34
34
  indices.reserve(a.size());
35
35
 
36
36
  for (long i = 0; i < a.size(); i++) {
37
- obj = a[i];
37
+ Object obj(a[i]);
38
38
 
39
- if (obj.is_instance_of(rb_cInteger)) {
39
+ if (obj.is_nil()) {
40
+ indices.push_back(torch::indexing::None);
41
+ } else if (obj.is_instance_of(rb_cInteger)) {
40
42
  indices.push_back(Rice::detail::From_Ruby<int64_t>().convert(obj.value()));
41
43
  } else if (obj.is_instance_of(rb_cRange)) {
42
44
  torch::optional<c10::SymInt> start_index = torch::nullopt;
@@ -64,12 +66,10 @@ std::vector<TensorIndex> index_vector(Array a) {
64
66
  indices.push_back(torch::indexing::Slice(start_index, stop_index));
65
67
  } else if (obj.is_instance_of(rb_cTensor)) {
66
68
  indices.push_back(Rice::detail::From_Ruby<Tensor>().convert(obj.value()));
67
- } else if (obj.is_nil()) {
68
- indices.push_back(torch::indexing::None);
69
- } else if (obj == Rice::True || obj == Rice::False) {
69
+ } else if (obj.value() == Qtrue || obj.value() == Qfalse) {
70
70
  indices.push_back(Rice::detail::From_Ruby<bool>().convert(obj.value()));
71
71
  } else {
72
- throw Rice::Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
72
+ throw Rice::Exception(rb_eArgError, "Unsupported index type: %s", Rice::detail::protect(rb_obj_classname, obj.value()));
73
73
  }
74
74
  }
75
75
  return indices;
@@ -102,15 +102,13 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
102
102
  add_tensor_functions(rb_cTensor);
103
103
  THPVariableClass = rb_cTensor.value();
104
104
 
105
- rb_define_method(rb_cTensor, "backward", (VALUE (*)(...)) tensor__backward, -1);
105
+ rb_define_method(rb_cTensor, "backward", tensor__backward, -1);
106
106
 
107
107
  rb_cTensor
108
108
  .define_method("cuda?", [](Tensor& self) { return self.is_cuda(); })
109
109
  .define_method("mps?", [](Tensor& self) { return self.is_mps(); })
110
110
  .define_method("sparse?", [](Tensor& self) { return self.is_sparse(); })
111
111
  .define_method("quantized?", [](Tensor& self) { return self.is_quantized(); })
112
- .define_method("dim", [](Tensor& self) { return self.dim(); })
113
- .define_method("numel", [](Tensor& self) { return self.numel(); })
114
112
  .define_method("element_size", [](Tensor& self) { return self.element_size(); })
115
113
  .define_method("requires_grad", [](Tensor& self) { return self.requires_grad(); })
116
114
  .define_method(
@@ -128,7 +126,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
128
126
  "shape",
129
127
  [](Tensor& self) {
130
128
  Array a;
131
- for (auto &size : self.sizes()) {
129
+ for (const auto& size : self.sizes()) {
132
130
  a.push(size, false);
133
131
  }
134
132
  return a;
@@ -137,7 +135,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
137
135
  "_strides",
138
136
  [](Tensor& self) {
139
137
  Array a;
140
- for (auto &stride : self.strides()) {
138
+ for (const auto& stride : self.strides()) {
141
139
  a.push(stride, false);
142
140
  }
143
141
  return a;
@@ -154,11 +152,6 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
154
152
  auto vec = index_vector(indices);
155
153
  return self.index_put_(vec, value);
156
154
  })
157
- .define_method(
158
- "contiguous?",
159
- [](Tensor& self) {
160
- return self.is_contiguous();
161
- })
162
155
  .define_method(
163
156
  "_requires_grad!",
164
157
  [](Tensor& self, bool requires_grad) {
@@ -168,7 +161,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
168
161
  "grad",
169
162
  [](Tensor& self) {
170
163
  auto grad = self.grad();
171
- return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Rice::Nil;
164
+ return grad.defined() ? std::optional<torch::Tensor>{grad} : std::nullopt;
172
165
  })
173
166
  // can't use grad=
174
167
  // assignment methods fail with Ruby 3.0
@@ -184,15 +177,15 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
184
177
 
185
178
  // TODO support sparse grad
186
179
  if (!grad.options().type_equal(self.options())) {
187
- rb_raise(rb_eArgError, "assigned grad has data of a different type");
180
+ throw Rice::Exception(rb_eArgError, "assigned grad has data of a different type");
188
181
  }
189
182
 
190
183
  if (self.is_cuda() && grad.get_device() != self.get_device()) {
191
- rb_raise(rb_eArgError, "assigned grad has data located on a different device");
184
+ throw Rice::Exception(rb_eArgError, "assigned grad has data located on a different device");
192
185
  }
193
186
 
194
187
  if (!self.sizes().equals(grad.sizes())) {
195
- rb_raise(rb_eArgError, "assigned grad has data of a different size");
188
+ throw Rice::Exception(rb_eArgError, "assigned grad has data of a different size");
196
189
  }
197
190
 
198
191
  self.mutable_grad() = grad;
@@ -234,8 +227,8 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
234
227
  tensor = tensor.contiguous();
235
228
  }
236
229
 
237
- auto data_ptr = (const char *) tensor.data_ptr();
238
- return std::string(data_ptr, tensor.numel() * tensor.element_size());
230
+ auto data_ptr = static_cast<const char *>(tensor.data_ptr());
231
+ return Rice::String(std::string_view(data_ptr, tensor.numel() * tensor.element_size()));
239
232
  })
240
233
  // for TorchVision
241
234
  .define_method(
data/ext/torch/torch.cpp CHANGED
@@ -12,7 +12,7 @@
12
12
  #include "utils.h"
13
13
 
14
14
  template<typename T>
15
- torch::Tensor make_tensor(Rice::Array a, const std::vector<int64_t> &size, const torch::TensorOptions &options) {
15
+ torch::Tensor make_tensor(Rice::Array a, const std::vector<int64_t>& size, const torch::TensorOptions& options) {
16
16
  std::vector<T> vec;
17
17
  vec.reserve(a.size());
18
18
  for (long i = 0; i < a.size(); i++) {
@@ -61,13 +61,13 @@ void init_torch(Rice::Module& m) {
61
61
  // begin operations
62
62
  .define_singleton_function(
63
63
  "_save",
64
- [](const torch::IValue &value) {
64
+ [](const torch::IValue& value) {
65
65
  auto v = torch::pickle_save(value);
66
- return Rice::Object(rb_str_new(v.data(), v.size()));
66
+ return Rice::Object(Rice::detail::protect(rb_str_new, v.data(), v.size()));
67
67
  })
68
68
  .define_singleton_function(
69
69
  "_load",
70
- [](const std::string &filename) {
70
+ [](const std::string& filename) {
71
71
  // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
72
72
  std::ifstream input(filename, std::ios::binary);
73
73
  std::vector<char> bytes(
@@ -78,13 +78,13 @@ void init_torch(Rice::Module& m) {
78
78
  })
79
79
  .define_singleton_function(
80
80
  "_from_blob",
81
- [](Rice::String s, const std::vector<int64_t> &size, const torch::TensorOptions &options) {
81
+ [](Rice::String s, const std::vector<int64_t>& size, const torch::TensorOptions& options) {
82
82
  void *data = const_cast<char *>(s.c_str());
83
83
  return torch::from_blob(data, size, options);
84
84
  })
85
85
  .define_singleton_function(
86
86
  "_tensor",
87
- [](Rice::Array a, const std::vector<int64_t> &size, const torch::TensorOptions &options) {
87
+ [](Rice::Array a, const std::vector<int64_t>& size, const torch::TensorOptions& options) {
88
88
  auto dtype = options.dtype();
89
89
  if (dtype == torch::kByte) {
90
90
  return make_tensor<uint8_t>(a, size, options);
data/ext/torch/utils.h CHANGED
@@ -8,7 +8,7 @@
8
8
  #include <rice/stl.hpp>
9
9
 
10
10
  static_assert(
11
- TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 10,
11
+ TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 12,
12
12
  "Incompatible LibTorch version"
13
13
  );
14
14
 
@@ -50,17 +50,17 @@ inline bool THPUtils_checkScalar(VALUE obj) {
50
50
  }
51
51
 
52
52
  inline bool THPDevice_Check(VALUE obj) {
53
- return rb_obj_is_kind_of(obj, THPDeviceClass);
53
+ return Rice::detail::protect(rb_obj_is_kind_of, obj, THPDeviceClass);
54
54
  }
55
55
 
56
56
  inline bool THPGenerator_Check(VALUE obj) {
57
- return rb_obj_is_kind_of(obj, THPGeneratorClass);
57
+ return Rice::detail::protect(rb_obj_is_kind_of, obj, THPGeneratorClass);
58
58
  }
59
59
 
60
60
  inline bool THPVariable_Check(VALUE obj) {
61
- return rb_obj_is_kind_of(obj, THPVariableClass);
61
+ return Rice::detail::protect(rb_obj_is_kind_of, obj, THPVariableClass);
62
62
  }
63
63
 
64
64
  inline bool THPVariable_CheckExact(VALUE obj) {
65
- return rb_obj_is_instance_of(obj, THPVariableClass);
65
+ return Rice::detail::protect(rb_obj_is_instance_of, obj, THPVariableClass);
66
66
  }
@@ -15,32 +15,34 @@ inline VALUE wrap(double x) {
15
15
  return Rice::detail::To_Ruby<double>().convert(x);
16
16
  }
17
17
 
18
- inline VALUE wrap(torch::Tensor x) {
18
+ inline VALUE wrap(const torch::Tensor& x) {
19
19
  return Rice::detail::To_Ruby<torch::Tensor>().convert(x);
20
20
  }
21
21
 
22
- inline VALUE wrap(torch::Scalar x) {
22
+ inline VALUE wrap(const torch::Scalar& x) {
23
23
  return Rice::detail::To_Ruby<torch::Scalar>().convert(x);
24
24
  }
25
25
 
26
- inline VALUE wrap(torch::ScalarType x) {
26
+ inline VALUE wrap(const torch::ScalarType& x) {
27
27
  return Rice::detail::To_Ruby<torch::ScalarType>().convert(x);
28
28
  }
29
29
 
30
- inline VALUE wrap(torch::QScheme x) {
30
+ inline VALUE wrap(const torch::QScheme& x) {
31
31
  return Rice::detail::To_Ruby<torch::QScheme>().convert(x);
32
32
  }
33
33
 
34
- inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
35
- return rb_ary_new3(
34
+ inline VALUE wrap(const std::tuple<torch::Tensor, torch::Tensor>& x) {
35
+ return Rice::detail::protect(
36
+ rb_ary_new3,
36
37
  2,
37
38
  Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
38
39
  Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x))
39
40
  );
40
41
  }
41
42
 
42
- inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
43
- return rb_ary_new3(
43
+ inline VALUE wrap(const std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>& x) {
44
+ return Rice::detail::protect(
45
+ rb_ary_new3,
44
46
  3,
45
47
  Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
46
48
  Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
@@ -48,8 +50,9 @@ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
48
50
  );
49
51
  }
50
52
 
51
- inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
52
- return rb_ary_new3(
53
+ inline VALUE wrap(const std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>& x) {
54
+ return Rice::detail::protect(
55
+ rb_ary_new3,
53
56
  4,
54
57
  Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
55
58
  Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
@@ -58,8 +61,9 @@ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch:
58
61
  );
59
62
  }
60
63
 
61
- inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
62
- return rb_ary_new3(
64
+ inline VALUE wrap(const std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>& x) {
65
+ return Rice::detail::protect(
66
+ rb_ary_new3,
63
67
  5,
64
68
  Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
65
69
  Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
@@ -69,8 +73,9 @@ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch:
69
73
  );
70
74
  }
71
75
 
72
- inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
73
- return rb_ary_new3(
76
+ inline VALUE wrap(const std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t>& x) {
77
+ return Rice::detail::protect(
78
+ rb_ary_new3,
74
79
  4,
75
80
  Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
76
81
  Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
@@ -79,8 +84,9 @@ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_
79
84
  );
80
85
  }
81
86
 
82
- inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
83
- return rb_ary_new3(
87
+ inline VALUE wrap(const std::tuple<torch::Tensor, torch::Tensor, double, int64_t>& x) {
88
+ return Rice::detail::protect(
89
+ rb_ary_new3,
84
90
  4,
85
91
  Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
86
92
  Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
@@ -89,16 +95,17 @@ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
89
95
  );
90
96
  }
91
97
 
92
- inline VALUE wrap(torch::TensorList x) {
93
- auto a = rb_ary_new2(x.size());
94
- for (auto t : x) {
95
- rb_ary_push(a, Rice::detail::To_Ruby<torch::Tensor>().convert(t));
98
+ inline VALUE wrap(const torch::TensorList& x) {
99
+ auto a = Rice::detail::protect(rb_ary_new2, x.size());
100
+ for (const auto& t : x) {
101
+ Rice::detail::protect(rb_ary_push, a, Rice::detail::To_Ruby<torch::Tensor>().convert(t));
96
102
  }
97
103
  return a;
98
104
  }
99
105
 
100
- inline VALUE wrap(std::tuple<double, double> x) {
101
- return rb_ary_new3(
106
+ inline VALUE wrap(const std::tuple<double, double>& x) {
107
+ return Rice::detail::protect(
108
+ rb_ary_new3,
102
109
  2,
103
110
  Rice::detail::To_Ruby<double>().convert(std::get<0>(x)),
104
111
  Rice::detail::To_Ruby<double>().convert(std::get<1>(x))