torch-rb 0.23.0 → 0.23.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: f09e8876f931e0d2d2ed1620cfac6019b2d3ac9a42e708bc2d2bc5d9e0f84e51
4
- data.tar.gz: 55712f92837b23fb952c78c30413203be47d09f736f3641a7a46cd34a6c085df
3
+ metadata.gz: af0e575eccaa4ab574b53df94fb6b4f4c50cc26cf76d96d6564d97d71c08dbb6
4
+ data.tar.gz: f54788198cac0ce2a970f92cf965aa426872e3a4b16f36683da2aa5fb1b82580
5
5
  SHA512:
6
- metadata.gz: d2c140c59a26644f953dc7996b16cf83cee57d3c573e18275169fde76d85528c480ce905e57e97f0dc3d211a1c5987406785eb469ee75a64990482029ad594e2
7
- data.tar.gz: 4460ebd50c579464a6272f463421ae1b6783e5123b79e9278477439ab3a4b0a582cad1951288610c19d323bc25ce770c72968e0cbedc509376c0de0778843be5
6
+ metadata.gz: c14d63c032f08b4d1f2a147a8a12f8337101206f9f78e0ad35e21f01bb960c6f91031bb66d364001ee9dbe1a2eb9b7c91792e81ed4f9bd89a2d07ff6b53fefc6
7
+ data.tar.gz: 50aeeca2ce811c594c436aa1b55cb72808ea001fc1c78aaf989040491ad942020930b378179bb245bcb7d4c3a32bb9919d9d114f121926b5968cf64461bc210b
data/CHANGELOG.md CHANGED
@@ -1,3 +1,8 @@
1
+ ## 0.23.1 (2026-02-19)
2
+
3
+ - Fixed memory leaks with exceptions
4
+ - Fixed error with Rice 4.11
5
+
1
6
  ## 0.23.0 (2026-01-21)
2
7
 
3
8
  - Updated LibTorch to 2.10.0
data/README.md CHANGED
@@ -22,7 +22,7 @@ As well as:
22
22
  First, [download LibTorch](https://pytorch.org/get-started/locally/). For Mac arm64, use:
23
23
 
24
24
  ```sh
25
- curl -L https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-2.9.1.zip > libtorch.zip
25
+ curl -L https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-2.10.0.zip > libtorch.zip
26
26
  unzip -q libtorch.zip
27
27
  ```
28
28
 
data/ext/torch/nn.cpp CHANGED
@@ -1,3 +1,4 @@
1
+ #include <optional>
1
2
  #include <utility>
2
3
 
3
4
  #include <torch/torch.h>
@@ -95,7 +96,7 @@ void init_nn(Rice::Module& m) {
95
96
  "grad",
96
97
  [](Parameter& self) {
97
98
  auto grad = self.grad();
98
- return grad.defined() ? Rice::Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Rice::Nil;
99
+ return grad.defined() ? std::optional<torch::Tensor>{grad} : std::nullopt;
99
100
  })
100
101
  // can't use grad=
101
102
  // assignment methods fail with Ruby 3.0
@@ -120,7 +120,7 @@ bool is_tensor_list(VALUE obj, int argnum, bool throw_error) {
120
120
  VALUE iobj = rb_ary_entry(obj, idx);
121
121
  if (!THPVariable_Check(iobj)) {
122
122
  if (throw_error) {
123
- rb_raise(rb_eArgError, "expected Tensor as element %d in argument %d, but got %s",
123
+ throw Rice::Exception(rb_eArgError, "expected Tensor as element %d in argument %d, but got %s",
124
124
  static_cast<int>(idx), argnum, rb_obj_classname(obj));
125
125
  }
126
126
  return false;
@@ -583,10 +583,10 @@ static void extra_args(const FunctionSignature& signature, ssize_t nargs) {
583
583
  const long min_args = signature.min_args;
584
584
  const long nargs_ = nargs;
585
585
  if (min_args != max_pos_args) {
586
- rb_raise(rb_eArgError, "%s() takes from %ld to %ld positional arguments but %ld were given",
586
+ throw Rice::Exception(rb_eArgError, "%s() takes from %ld to %ld positional arguments but %ld were given",
587
587
  signature.name.c_str(), min_args, max_pos_args, nargs_);
588
588
  }
589
- rb_raise(rb_eArgError, "%s() takes %ld positional argument%s but %ld %s given",
589
+ throw Rice::Exception(rb_eArgError, "%s() takes %ld positional argument%s but %ld %s given",
590
590
  signature.name.c_str(),
591
591
  max_pos_args, max_pos_args == 1 ? "" : "s",
592
592
  nargs_, nargs == 1 ? "was" : "were");
@@ -608,7 +608,7 @@ static void missing_args(const FunctionSignature& signature, int idx) {
608
608
  }
609
609
  }
610
610
 
611
- rb_raise(rb_eArgError, "%s() missing %d required positional argument%s: %s",
611
+ throw Rice::Exception(rb_eArgError, "%s() missing %d required positional argument%s: %s",
612
612
  signature.name.c_str(),
613
613
  num_missing,
614
614
  num_missing == 1 ? "s" : "",
@@ -636,23 +636,23 @@ static void extra_kwargs(FunctionSignature& signature, VALUE kwargs, ssize_t num
636
636
  key = rb_ary_entry(keys, 0);
637
637
 
638
638
  if (!THPUtils_checkSymbol(key)) {
639
- rb_raise(rb_eArgError, "keywords must be symbols, not %s", rb_obj_classname(key));
639
+ throw Rice::Exception(rb_eArgError, "keywords must be symbols, not %s", rb_obj_classname(key));
640
640
  }
641
641
 
642
642
  auto param_idx = find_param(signature, key);
643
643
  if (param_idx < 0) {
644
- rb_raise(rb_eArgError, "%s() got an unexpected keyword argument '%s'",
644
+ throw Rice::Exception(rb_eArgError, "%s() got an unexpected keyword argument '%s'",
645
645
  signature.name.c_str(), rb_id2name(rb_to_id(key)));
646
646
  }
647
647
 
648
648
  if (param_idx < num_pos_args) {
649
- rb_raise(rb_eArgError, "%s() got multiple values for argument '%s'",
649
+ throw Rice::Exception(rb_eArgError, "%s() got multiple values for argument '%s'",
650
650
  signature.name.c_str(), rb_id2name(rb_to_id(key)));
651
651
  }
652
652
  }
653
653
 
654
654
  // this should never be hit
655
- rb_raise(rb_eArgError, "invalid keyword arguments");
655
+ throw Rice::Exception(rb_eArgError, "invalid keyword arguments");
656
656
  }
657
657
 
658
658
  VALUE missing = Qundef;
@@ -740,12 +740,12 @@ bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[],
740
740
  } else if (raise_exception) {
741
741
  if (is_kwd) {
742
742
  // foo(): argument 'other' must be str, not int
743
- rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, not %s",
743
+ throw Rice::Exception(rb_eArgError, "%s(): argument '%s' must be %s, not %s",
744
744
  name.c_str(), param.name.c_str(), param.type_name().c_str(),
745
745
  rb_obj_classname(obj));
746
746
  } else {
747
747
  // foo(): argument 'other' (position 2) must be str, not int
748
- rb_raise(rb_eArgError, "%s(): argument '%s' (position %ld) must be %s, not %s",
748
+ throw Rice::Exception(rb_eArgError, "%s(): argument '%s' (position %ld) must be %s, not %s",
749
749
  name.c_str(), param.name.c_str(), static_cast<long>(arg_pos + 1),
750
750
  param.type_name().c_str(), rb_obj_classname(obj));
751
751
  }
@@ -165,7 +165,7 @@ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
165
165
  Check_Type(arg, T_ARRAY);
166
166
  auto size = RARRAY_LEN(arg);
167
167
  if (size != N) {
168
- rb_raise(rb_eArgError, "expected array of %d elements but got %d", N, static_cast<int>(size));
168
+ throw Rice::Exception(rb_eArgError, "expected array of %d elements but got %d", N, static_cast<int>(size));
169
169
  }
170
170
  for (int idx = 0; idx < size; idx++) {
171
171
  VALUE obj = rb_ary_entry(arg, idx);
@@ -206,7 +206,7 @@ inline std::vector<int64_t> RubyArgs::intlistWithDefault(int i, std::vector<int6
206
206
  if (FIXNUM_P(obj)) {
207
207
  res[idx] = Rice::detail::From_Ruby<int64_t>().convert(obj);
208
208
  } else {
209
- rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
209
+ throw Rice::Exception(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
210
210
  signature.name.c_str(), signature.params[i].name.c_str(),
211
211
  signature.params[i].type_name().c_str(), rb_obj_classname(obj), idx + 1);
212
212
  }
@@ -270,7 +270,7 @@ inline ScalarType RubyArgs::scalartype(int i) {
270
270
 
271
271
  auto it = dtype_map.find(args[i]);
272
272
  if (it == dtype_map.end()) {
273
- rb_raise(rb_eArgError, "invalid dtype: %s", rb_id2name(rb_to_id(args[i])));
273
+ throw Rice::Exception(rb_eArgError, "invalid dtype: %s", rb_id2name(rb_to_id(args[i])));
274
274
  }
275
275
  return it->second;
276
276
  }
@@ -321,7 +321,7 @@ inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
321
321
  if (FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj)) {
322
322
  res[idx] = Rice::detail::From_Ruby<double>().convert(obj);
323
323
  } else {
324
- rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
324
+ throw Rice::Exception(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
325
325
  signature.name.c_str(), signature.params[i].name.c_str(),
326
326
  signature.params[i].type_name().c_str(), rb_obj_classname(obj), idx + 1);
327
327
  }
@@ -338,7 +338,7 @@ inline at::Layout RubyArgs::layout(int i) {
338
338
 
339
339
  auto it = layout_map.find(args[i]);
340
340
  if (it == layout_map.end()) {
341
- rb_raise(rb_eArgError, "invalid layout: %s", rb_id2name(rb_to_id(args[i])));
341
+ throw Rice::Exception(rb_eArgError, "invalid layout: %s", rb_id2name(rb_to_id(args[i])));
342
342
  }
343
343
  return it->second;
344
344
  }
@@ -469,7 +469,7 @@ struct RubyArgParser {
469
469
  template<int N>
470
470
  inline RubyArgs parse(VALUE self, int argc, VALUE* argv, ParsedArgs<N> &dst) {
471
471
  if (N < max_args) {
472
- rb_raise(rb_eArgError, "RubyArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)", static_cast<int>(max_args), N);
472
+ throw Rice::Exception(rb_eArgError, "RubyArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)", static_cast<int>(max_args), N);
473
473
  }
474
474
  return raw_parse(self, argc, argv, dst.args);
475
475
  }
@@ -493,7 +493,7 @@ struct RubyArgParser {
493
493
  print_error(self, args, kwargs, parsed_args);
494
494
 
495
495
  // TODO better message
496
- rb_raise(rb_eArgError, "No matching signatures");
496
+ throw Rice::Exception(rb_eArgError, "No matching signatures");
497
497
  }
498
498
 
499
499
  void print_error(VALUE self, VALUE args, VALUE kwargs, VALUE parsed_args[]) {
@@ -55,7 +55,7 @@ namespace Rice::detail {
55
55
  explicit To_Ruby(Arg* arg) : arg_(arg) { }
56
56
 
57
57
  VALUE convert(c10::complex<T> const& x) {
58
- return rb_dbl_complex_new(x.real(), x.imag());
58
+ return protect(rb_dbl_complex_new, x.real(), x.imag());
59
59
  }
60
60
 
61
61
  private:
@@ -72,8 +72,8 @@ namespace Rice::detail {
72
72
  double is_convertible(VALUE value) { return Convertible::Exact; }
73
73
 
74
74
  c10::complex<T> convert(VALUE x) {
75
- VALUE real = rb_funcall(x, rb_intern("real"), 0);
76
- VALUE imag = rb_funcall(x, rb_intern("imag"), 0);
75
+ VALUE real = protect(rb_funcall, x, rb_intern("real"), 0);
76
+ VALUE imag = protect(rb_funcall, x, rb_intern("imag"), 0);
77
77
  return c10::complex<T>(From_Ruby<T>().convert(real), From_Ruby<T>().convert(imag));
78
78
  }
79
79
 
data/ext/torch/tensor.cpp CHANGED
@@ -1,3 +1,4 @@
1
+ #include <optional>
1
2
  #include <string>
2
3
  #include <vector>
3
4
 
@@ -28,15 +29,15 @@ Array flat_data(Tensor& tensor) {
28
29
  Rice::Class rb_cTensor;
29
30
 
30
31
  std::vector<TensorIndex> index_vector(Array a) {
31
- Object obj;
32
-
33
32
  std::vector<TensorIndex> indices;
34
33
  indices.reserve(a.size());
35
34
 
36
35
  for (long i = 0; i < a.size(); i++) {
37
- obj = a[i];
36
+ Object obj(a[i]);
38
37
 
39
- if (obj.is_instance_of(rb_cInteger)) {
38
+ if (obj.is_nil()) {
39
+ indices.push_back(torch::indexing::None);
40
+ } else if (obj.is_instance_of(rb_cInteger)) {
40
41
  indices.push_back(Rice::detail::From_Ruby<int64_t>().convert(obj.value()));
41
42
  } else if (obj.is_instance_of(rb_cRange)) {
42
43
  torch::optional<c10::SymInt> start_index = torch::nullopt;
@@ -64,12 +65,10 @@ std::vector<TensorIndex> index_vector(Array a) {
64
65
  indices.push_back(torch::indexing::Slice(start_index, stop_index));
65
66
  } else if (obj.is_instance_of(rb_cTensor)) {
66
67
  indices.push_back(Rice::detail::From_Ruby<Tensor>().convert(obj.value()));
67
- } else if (obj.is_nil()) {
68
- indices.push_back(torch::indexing::None);
69
- } else if (obj == Rice::True || obj == Rice::False) {
68
+ } else if (obj.value() == Qtrue || obj.value() == Qfalse) {
70
69
  indices.push_back(Rice::detail::From_Ruby<bool>().convert(obj.value()));
71
70
  } else {
72
- throw Rice::Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
71
+ throw Rice::Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj.value()));
73
72
  }
74
73
  }
75
74
  return indices;
@@ -102,7 +101,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
102
101
  add_tensor_functions(rb_cTensor);
103
102
  THPVariableClass = rb_cTensor.value();
104
103
 
105
- rb_define_method(rb_cTensor, "backward", (VALUE (*)(...)) tensor__backward, -1);
104
+ rb_define_method(rb_cTensor, "backward", tensor__backward, -1);
106
105
 
107
106
  rb_cTensor
108
107
  .define_method("cuda?", [](Tensor& self) { return self.is_cuda(); })
@@ -168,7 +167,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
168
167
  "grad",
169
168
  [](Tensor& self) {
170
169
  auto grad = self.grad();
171
- return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Rice::Nil;
170
+ return grad.defined() ? std::optional<torch::Tensor>{grad} : std::nullopt;
172
171
  })
173
172
  // can't use grad=
174
173
  // assignment methods fail with Ruby 3.0
@@ -184,15 +183,15 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
184
183
 
185
184
  // TODO support sparse grad
186
185
  if (!grad.options().type_equal(self.options())) {
187
- rb_raise(rb_eArgError, "assigned grad has data of a different type");
186
+ throw Rice::Exception(rb_eArgError, "assigned grad has data of a different type");
188
187
  }
189
188
 
190
189
  if (self.is_cuda() && grad.get_device() != self.get_device()) {
191
- rb_raise(rb_eArgError, "assigned grad has data located on a different device");
190
+ throw Rice::Exception(rb_eArgError, "assigned grad has data located on a different device");
192
191
  }
193
192
 
194
193
  if (!self.sizes().equals(grad.sizes())) {
195
- rb_raise(rb_eArgError, "assigned grad has data of a different size");
194
+ throw Rice::Exception(rb_eArgError, "assigned grad has data of a different size");
196
195
  }
197
196
 
198
197
  self.mutable_grad() = grad;
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.23.0"
2
+ VERSION = "0.23.1"
3
3
  end
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.23.0
4
+ version: 0.23.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane