torch-rb 0.23.0 → 0.23.1
This diff shows the content of publicly available package versions as released to one of the supported registries, and is provided for informational purposes only.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +1 -1
- data/ext/torch/nn.cpp +2 -1
- data/ext/torch/ruby_arg_parser.cpp +10 -10
- data/ext/torch/ruby_arg_parser.h +7 -7
- data/ext/torch/templates.h +3 -3
- data/ext/torch/tensor.cpp +12 -13
- data/lib/torch/version.rb +1 -1
- metadata +1 -1
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: …
-  data.tar.gz: …
+  metadata.gz: af0e575eccaa4ab574b53df94fb6b4f4c50cc26cf76d96d6564d97d71c08dbb6
+  data.tar.gz: f54788198cac0ce2a970f92cf965aa426872e3a4b16f36683da2aa5fb1b82580
 SHA512:
-  metadata.gz: …
-  data.tar.gz: …
+  metadata.gz: c14d63c032f08b4d1f2a147a8a12f8337101206f9f78e0ad35e21f01bb960c6f91031bb66d364001ee9dbe1a2eb9b7c91792e81ed4f9bd89a2d07ff6b53fefc6
+  data.tar.gz: 50aeeca2ce811c594c436aa1b55cb72808ea001fc1c78aaf989040491ad942020930b378179bb245bcb7d4c3a32bb9919d9d114f121926b5968cf64461bc210b
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -22,7 +22,7 @@ As well as:
 First, [download LibTorch](https://pytorch.org/get-started/locally/). For Mac arm64, use:
 
 ```sh
-curl -L https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-2.…
+curl -L https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-2.10.0.zip > libtorch.zip
 unzip -q libtorch.zip
 ```
data/ext/torch/nn.cpp
CHANGED
@@ -1,3 +1,4 @@
+#include <optional>
 #include <utility>
 
 #include <torch/torch.h>
@@ -95,7 +96,7 @@ void init_nn(Rice::Module& m) {
       "grad",
       [](Parameter& self) {
         auto grad = self.grad();
-        return grad.defined() ? …
+        return grad.defined() ? std::optional<torch::Tensor>{grad} : std::nullopt;
       })
     // can't use grad=
     // assignment methods fail with Ruby 3.0
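On the Ruby side, the effect of switching to `std::optional` is that an undefined gradient surfaces as `nil` rather than an invalid tensor object. A minimal sketch of the behavior this hunk implies (the matching `Tensor#grad` hunk in tensor.cpp below works the same way), assuming the `torch` gem and LibTorch are installed:

```ruby
require "torch"

x = Torch.tensor([1.0, 2.0, 3.0], requires_grad: true)

# No gradient accumulated yet: grad.defined() is false in C++,
# so the optional converts to nil in Ruby.
p x.grad # => nil

(x * x).sum.backward

# After backward, the optional holds a defined tensor.
p x.grad # => tensor([2., 4., 6.])
```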
data/ext/torch/ruby_arg_parser.cpp
CHANGED
@@ -120,7 +120,7 @@ bool is_tensor_list(VALUE obj, int argnum, bool throw_error) {
     VALUE iobj = rb_ary_entry(obj, idx);
     if (!THPVariable_Check(iobj)) {
       if (throw_error) {
-        …
+        throw Rice::Exception(rb_eArgError, "expected Tensor as element %d in argument %d, but got %s",
           static_cast<int>(idx), argnum, rb_obj_classname(obj));
       }
       return false;
@@ -583,10 +583,10 @@ static void extra_args(const FunctionSignature& signature, ssize_t nargs) {
   const long min_args = signature.min_args;
   const long nargs_ = nargs;
   if (min_args != max_pos_args) {
-    …
+    throw Rice::Exception(rb_eArgError, "%s() takes from %ld to %ld positional arguments but %ld were given",
       signature.name.c_str(), min_args, max_pos_args, nargs_);
   }
-  …
+  throw Rice::Exception(rb_eArgError, "%s() takes %ld positional argument%s but %ld %s given",
     signature.name.c_str(),
     max_pos_args, max_pos_args == 1 ? "" : "s",
     nargs_, nargs == 1 ? "was" : "were");
@@ -608,7 +608,7 @@ static void missing_args(const FunctionSignature& signature, int idx) {
     }
   }
 
-  …
+  throw Rice::Exception(rb_eArgError, "%s() missing %d required positional argument%s: %s",
     signature.name.c_str(),
     num_missing,
     num_missing == 1 ? "s" : "",
@@ -636,23 +636,23 @@ static void extra_kwargs(FunctionSignature& signature, VALUE kwargs, ssize_t num
     key = rb_ary_entry(keys, 0);
 
     if (!THPUtils_checkSymbol(key)) {
-      …
+      throw Rice::Exception(rb_eArgError, "keywords must be symbols, not %s", rb_obj_classname(key));
     }
 
     auto param_idx = find_param(signature, key);
     if (param_idx < 0) {
-      …
+      throw Rice::Exception(rb_eArgError, "%s() got an unexpected keyword argument '%s'",
         signature.name.c_str(), rb_id2name(rb_to_id(key)));
     }
 
     if (param_idx < num_pos_args) {
-      …
+      throw Rice::Exception(rb_eArgError, "%s() got multiple values for argument '%s'",
        signature.name.c_str(), rb_id2name(rb_to_id(key)));
     }
   }
 
   // this should never be hit
-  …
+  throw Rice::Exception(rb_eArgError, "invalid keyword arguments");
 }
 
 VALUE missing = Qundef;
@@ -740,12 +740,12 @@ bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[],
     } else if (raise_exception) {
       if (is_kwd) {
         // foo(): argument 'other' must be str, not int
-        …
+        throw Rice::Exception(rb_eArgError, "%s(): argument '%s' must be %s, not %s",
           name.c_str(), param.name.c_str(), param.type_name().c_str(),
           rb_obj_classname(obj));
       } else {
         // foo(): argument 'other' (position 2) must be str, not int
-        …
+        throw Rice::Exception(rb_eArgError, "%s(): argument '%s' (position %ld) must be %s, not %s",
          name.c_str(), param.name.c_str(), static_cast<long>(arg_pos + 1),
          param.type_name().c_str(), rb_obj_classname(obj));
      }
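Each `Rice::Exception` raised with `rb_eArgError` surfaces in Ruby as an `ArgumentError` carrying a Python-style message. A hedged illustration of the failure modes this parser reports; the exact wording depends on how many signatures are registered for a function (with several overloads the parser falls back to "No matching signatures"):

```ruby
require "torch"

x = Torch.ones(2, 2)

begin
  Torch.sum(x, bad_kwarg: 1) # unexpected keyword
rescue ArgumentError => e
  puts e.message # e.g. "sum() got an unexpected keyword argument 'bad_kwarg'"
end

begin
  Torch.topk(x) # required positional argument left out
rescue ArgumentError => e
  puts e.message # e.g. "topk() missing 1 required positional argument: ..."
end
```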
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -165,7 +165,7 @@ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
   Check_Type(arg, T_ARRAY);
   auto size = RARRAY_LEN(arg);
   if (size != N) {
-    …
+    throw Rice::Exception(rb_eArgError, "expected array of %d elements but got %d", N, static_cast<int>(size));
   }
   for (int idx = 0; idx < size; idx++) {
     VALUE obj = rb_ary_entry(arg, idx);
@@ -206,7 +206,7 @@ inline std::vector<int64_t> RubyArgs::intlistWithDefault(int i, std::vector<int6
   if (FIXNUM_P(obj)) {
     res[idx] = Rice::detail::From_Ruby<int64_t>().convert(obj);
   } else {
-    …
+    throw Rice::Exception(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
       signature.name.c_str(), signature.params[i].name.c_str(),
       signature.params[i].type_name().c_str(), rb_obj_classname(obj), idx + 1);
   }
@@ -270,7 +270,7 @@ inline ScalarType RubyArgs::scalartype(int i) {
 
   auto it = dtype_map.find(args[i]);
   if (it == dtype_map.end()) {
-    …
+    throw Rice::Exception(rb_eArgError, "invalid dtype: %s", rb_id2name(rb_to_id(args[i])));
   }
   return it->second;
 }
@@ -321,7 +321,7 @@ inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
   if (FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj)) {
     res[idx] = Rice::detail::From_Ruby<double>().convert(obj);
   } else {
-    …
+    throw Rice::Exception(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
       signature.name.c_str(), signature.params[i].name.c_str(),
       signature.params[i].type_name().c_str(), rb_obj_classname(obj), idx + 1);
   }
@@ -338,7 +338,7 @@ inline at::Layout RubyArgs::layout(int i) {
 
   auto it = layout_map.find(args[i]);
   if (it == layout_map.end()) {
-    …
+    throw Rice::Exception(rb_eArgError, "invalid layout: %s", rb_id2name(rb_to_id(args[i])));
   }
   return it->second;
 }
@@ -469,7 +469,7 @@ struct RubyArgParser {
   template<int N>
   inline RubyArgs parse(VALUE self, int argc, VALUE* argv, ParsedArgs<N> &dst) {
     if (N < max_args) {
-      …
+      throw Rice::Exception(rb_eArgError, "RubyArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)", static_cast<int>(max_args), N);
     }
     return raw_parse(self, argc, argv, dst.args);
   }
@@ -493,7 +493,7 @@ struct RubyArgParser {
     print_error(self, args, kwargs, parsed_args);
 
     // TODO better message
-    …
+    throw Rice::Exception(rb_eArgError, "No matching signatures");
   }
 
   void print_error(VALUE self, VALUE args, VALUE kwargs, VALUE parsed_args[]) {
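`scalartype` and `layout` both validate a Ruby symbol against a lookup map before use, so a mistyped dtype fails fast with a clear message. A small sketch:

```ruby
require "torch"

begin
  Torch.ones(2, dtype: :float17) # not in dtype_map
rescue ArgumentError => e
  puts e.message # => "invalid dtype: float17"
end

Torch.ones(2, dtype: :float64) # valid symbols pass straight through
```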
data/ext/torch/templates.h
CHANGED
@@ -55,7 +55,7 @@ namespace Rice::detail {
   explicit To_Ruby(Arg* arg) : arg_(arg) { }
 
   VALUE convert(c10::complex<T> const& x) {
-    return rb_dbl_complex_new…
+    return protect(rb_dbl_complex_new, x.real(), x.imag());
   }
 
   private:
@@ -72,8 +72,8 @@ namespace Rice::detail {
   double is_convertible(VALUE value) { return Convertible::Exact; }
 
   c10::complex<T> convert(VALUE x) {
-    VALUE real = rb_funcall…
-    VALUE imag = rb_funcall…
+    VALUE real = protect(rb_funcall, x, rb_intern("real"), 0);
+    VALUE imag = protect(rb_funcall, x, rb_intern("imag"), 0);
     return c10::complex<T>(From_Ruby<T>().convert(real), From_Ruby<T>().convert(imag));
   }
 
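Wrapping `rb_dbl_complex_new` and `rb_funcall` in Rice's `protect` keeps a Ruby exception raised inside those calls from unwinding through C++ stack frames. The conversion pair is exercised whenever complex scalars cross the boundary; a sketch, assuming complex dtypes are available in your LibTorch build:

```ruby
require "torch"

# From_Ruby<c10::complex<T>> calls #real and #imag on the Ruby Complex;
# To_Ruby<c10::complex<T>> builds one back via rb_dbl_complex_new.
x = Torch.tensor([Complex(1.0, 2.0)], dtype: :complex64)
p x[0].item # => (1.0+2.0i)
```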
data/ext/torch/tensor.cpp
CHANGED
@@ -1,3 +1,4 @@
+#include <optional>
 #include <string>
 #include <vector>
 
@@ -28,15 +29,15 @@ Array flat_data(Tensor& tensor) {
 Rice::Class rb_cTensor;
 
 std::vector<TensorIndex> index_vector(Array a) {
-  Object obj;
-
   std::vector<TensorIndex> indices;
   indices.reserve(a.size());
 
   for (long i = 0; i < a.size(); i++) {
-    obj…
+    Object obj(a[i]);
 
-    if (obj.…
+    if (obj.is_nil()) {
+      indices.push_back(torch::indexing::None);
+    } else if (obj.is_instance_of(rb_cInteger)) {
       indices.push_back(Rice::detail::From_Ruby<int64_t>().convert(obj.value()));
     } else if (obj.is_instance_of(rb_cRange)) {
       torch::optional<c10::SymInt> start_index = torch::nullopt;
@@ -64,12 +65,10 @@ std::vector<TensorIndex> index_vector(Array a) {
       indices.push_back(torch::indexing::Slice(start_index, stop_index));
     } else if (obj.is_instance_of(rb_cTensor)) {
       indices.push_back(Rice::detail::From_Ruby<Tensor>().convert(obj.value()));
-    } else if (obj.…
-      indices.push_back(torch::indexing::None);
-    } else if (obj == Rice::True || obj == Rice::False) {
+    } else if (obj.value() == Qtrue || obj.value() == Qfalse) {
       indices.push_back(Rice::detail::From_Ruby<bool>().convert(obj.value()));
     } else {
-      throw Rice::Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
+      throw Rice::Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj.value()));
     }
   }
   return indices;
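The rewritten `index_vector` maps each Ruby index object onto a LibTorch `TensorIndex`: `nil` becomes `torch::indexing::None`, Integers become point indices, Ranges become `Slice`s, and booleans and tensors convert directly. A quick sketch of what each branch accepts:

```ruby
require "torch"

t = Torch.arange(6).reshape(2, 3)

t[0]       # Integer branch -> first row
t[0, 1..2] # Range branch   -> Slice(1, 3)
t[nil, 0]  # nil branch     -> None, inserts a new axis
t[t > 2]   # Tensor branch  -> boolean mask indexing
```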
@@ -102,7 +101,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
   add_tensor_functions(rb_cTensor);
   THPVariableClass = rb_cTensor.value();
 
-  rb_define_method(rb_cTensor, "backward", …
+  rb_define_method(rb_cTensor, "backward", tensor__backward, -1);
 
   rb_cTensor
     .define_method("cuda?", [](Tensor& self) { return self.is_cuda(); })
@@ -168,7 +167,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
       "grad",
       [](Tensor& self) {
         auto grad = self.grad();
-        return grad.defined() ? …
+        return grad.defined() ? std::optional<torch::Tensor>{grad} : std::nullopt;
       })
     // can't use grad=
     // assignment methods fail with Ruby 3.0
@@ -184,15 +183,15 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
 
       // TODO support sparse grad
       if (!grad.options().type_equal(self.options())) {
-        …
+        throw Rice::Exception(rb_eArgError, "assigned grad has data of a different type");
       }
 
       if (self.is_cuda() && grad.get_device() != self.get_device()) {
-        …
+        throw Rice::Exception(rb_eArgError, "assigned grad has data located on a different device");
       }
 
       if (!self.sizes().equals(grad.sizes())) {
-        …
+        throw Rice::Exception(rb_eArgError, "assigned grad has data of a different size");
       }
 
       self.mutable_grad() = grad;
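These three checks mirror PyTorch's validation when a gradient is assigned by hand: the new grad must match the tensor's dtype, device, and shape. A sketch, assuming the C++ setter is reachable through a Ruby-side `grad=` wrapper (the comment in the grad hunk above notes the Rice-level method itself can't be named `grad=`):

```ruby
require "torch"

x = Torch.ones(3, requires_grad: true)

begin
  x.grad = Torch.ones(4) # shape mismatch
rescue ArgumentError => e
  puts e.message # => "assigned grad has data of a different size"
end

x.grad = Torch.ones(3) # matching type, device, and size is accepted
```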
data/lib/torch/version.rb
CHANGED
@@ -1,3 +1,3 @@
 module Torch
-  VERSION = "0.23.0"
+  VERSION = "0.23.1"
 end