torch-rb 0.23.0 → 0.24.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/README.md +3 -1
- data/codegen/generate_functions.rb +1 -6
- data/codegen/native_functions.yaml +229 -58
- data/ext/torch/ivalue.cpp +4 -4
- data/ext/torch/nn.cpp +2 -1
- data/ext/torch/ruby_arg_parser.cpp +23 -23
- data/ext/torch/ruby_arg_parser.h +16 -16
- data/ext/torch/templates.h +4 -4
- data/ext/torch/tensor.cpp +17 -24
- data/ext/torch/torch.cpp +6 -6
- data/ext/torch/utils.h +5 -5
- data/ext/torch/wrap_outputs.h +29 -22
- data/lib/torch/hub.rb +8 -28
- data/lib/torch/nn/module.rb +1 -1
- data/lib/torch/nn/rnn_base.rb +1 -1
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +4 -4
- metadata +3 -3
data/ext/torch/ivalue.cpp
CHANGED
|
@@ -53,7 +53,7 @@ void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue) {
|
|
|
53
53
|
[](torch::IValue& self) {
|
|
54
54
|
auto list = self.toListRef();
|
|
55
55
|
Rice::Array obj;
|
|
56
|
-
for (auto& elem : list) {
|
|
56
|
+
for (const auto& elem : list) {
|
|
57
57
|
auto v = torch::IValue{elem};
|
|
58
58
|
obj.push(Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v)), false);
|
|
59
59
|
}
|
|
@@ -74,7 +74,7 @@ void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue) {
|
|
|
74
74
|
[](torch::IValue& self) {
|
|
75
75
|
auto dict = self.toGenericDict();
|
|
76
76
|
Rice::Hash obj;
|
|
77
|
-
for (auto& pair : dict) {
|
|
77
|
+
for (const auto& pair : dict) {
|
|
78
78
|
auto k = torch::IValue{pair.key()};
|
|
79
79
|
auto v = torch::IValue{pair.value()};
|
|
80
80
|
obj[Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(k))] = Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v));
|
|
@@ -91,7 +91,7 @@ void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue) {
|
|
|
91
91
|
"from_list",
|
|
92
92
|
[](Rice::Array obj) {
|
|
93
93
|
c10::impl::GenericList list(c10::AnyType::get());
|
|
94
|
-
for (auto entry : obj) {
|
|
94
|
+
for (const auto& entry : obj) {
|
|
95
95
|
list.push_back(Rice::detail::From_Ruby<torch::IValue>().convert(entry.value()));
|
|
96
96
|
}
|
|
97
97
|
return torch::IValue(list);
|
|
@@ -125,7 +125,7 @@ void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue) {
|
|
|
125
125
|
auto value_type = c10::AnyType::get();
|
|
126
126
|
c10::impl::GenericDict elems(key_type, value_type);
|
|
127
127
|
elems.reserve(obj.size());
|
|
128
|
-
for (auto entry : obj) {
|
|
128
|
+
for (const auto& entry : obj) {
|
|
129
129
|
elems.insert(Rice::detail::From_Ruby<torch::IValue>().convert(entry.first), Rice::detail::From_Ruby<torch::IValue>().convert((Rice::Object) entry.second));
|
|
130
130
|
}
|
|
131
131
|
return torch::IValue(std::move(elems));
|
data/ext/torch/nn.cpp
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
#include <optional>
|
|
1
2
|
#include <utility>
|
|
2
3
|
|
|
3
4
|
#include <torch/torch.h>
|
|
@@ -95,7 +96,7 @@ void init_nn(Rice::Module& m) {
|
|
|
95
96
|
"grad",
|
|
96
97
|
[](Parameter& self) {
|
|
97
98
|
auto grad = self.grad();
|
|
98
|
-
return grad.defined() ?
|
|
99
|
+
return grad.defined() ? std::optional<torch::Tensor>{grad} : std::nullopt;
|
|
99
100
|
})
|
|
100
101
|
// can't use grad=
|
|
101
102
|
// assignment methods fail with Ruby 3.0
|
|
@@ -117,11 +117,11 @@ bool is_tensor_list(VALUE obj, int argnum, bool throw_error) {
|
|
|
117
117
|
}
|
|
118
118
|
auto size = RARRAY_LEN(obj);
|
|
119
119
|
for (int idx = 0; idx < size; idx++) {
|
|
120
|
-
VALUE iobj = rb_ary_entry
|
|
120
|
+
VALUE iobj = Rice::detail::protect(rb_ary_entry, obj, idx);
|
|
121
121
|
if (!THPVariable_Check(iobj)) {
|
|
122
122
|
if (throw_error) {
|
|
123
|
-
|
|
124
|
-
static_cast<int>(idx), argnum, rb_obj_classname
|
|
123
|
+
throw Rice::Exception(rb_eArgError, "expected Tensor as element %d in argument %d, but got %s",
|
|
124
|
+
static_cast<int>(idx), argnum, Rice::detail::protect(rb_obj_classname, obj));
|
|
125
125
|
}
|
|
126
126
|
return false;
|
|
127
127
|
}
|
|
@@ -136,7 +136,7 @@ static bool is_int_list(VALUE obj, int broadcast_size) {
|
|
|
136
136
|
return true;
|
|
137
137
|
}
|
|
138
138
|
|
|
139
|
-
auto item = rb_ary_entry
|
|
139
|
+
auto item = Rice::detail::protect(rb_ary_entry, obj, 0);
|
|
140
140
|
bool int_first = false;
|
|
141
141
|
if (THPUtils_checkIndex(item)) {
|
|
142
142
|
// we still have to check that the rest of items are NOT symint nodes
|
|
@@ -178,7 +178,7 @@ static bool is_int_or_symint_list(VALUE obj, int broadcast_size) {
|
|
|
178
178
|
if (RARRAY_LEN(obj) == 0) {
|
|
179
179
|
return true;
|
|
180
180
|
}
|
|
181
|
-
auto item = rb_ary_entry
|
|
181
|
+
auto item = Rice::detail::protect(rb_ary_entry, obj, 0);
|
|
182
182
|
|
|
183
183
|
if (is_int_or_symint(item)) {
|
|
184
184
|
return true;
|
|
@@ -583,10 +583,10 @@ static void extra_args(const FunctionSignature& signature, ssize_t nargs) {
|
|
|
583
583
|
const long min_args = signature.min_args;
|
|
584
584
|
const long nargs_ = nargs;
|
|
585
585
|
if (min_args != max_pos_args) {
|
|
586
|
-
|
|
586
|
+
throw Rice::Exception(rb_eArgError, "%s() takes from %ld to %ld positional arguments but %ld were given",
|
|
587
587
|
signature.name.c_str(), min_args, max_pos_args, nargs_);
|
|
588
588
|
}
|
|
589
|
-
|
|
589
|
+
throw Rice::Exception(rb_eArgError, "%s() takes %ld positional argument%s but %ld %s given",
|
|
590
590
|
signature.name.c_str(),
|
|
591
591
|
max_pos_args, max_pos_args == 1 ? "" : "s",
|
|
592
592
|
nargs_, nargs == 1 ? "was" : "were");
|
|
@@ -608,7 +608,7 @@ static void missing_args(const FunctionSignature& signature, int idx) {
|
|
|
608
608
|
}
|
|
609
609
|
}
|
|
610
610
|
|
|
611
|
-
|
|
611
|
+
throw Rice::Exception(rb_eArgError, "%s() missing %d required positional argument%s: %s",
|
|
612
612
|
signature.name.c_str(),
|
|
613
613
|
num_missing,
|
|
614
614
|
num_missing == 1 ? "s" : "",
|
|
@@ -631,28 +631,28 @@ static ssize_t find_param(FunctionSignature& signature, VALUE name) {
|
|
|
631
631
|
static void extra_kwargs(FunctionSignature& signature, VALUE kwargs, ssize_t num_pos_args) {
|
|
632
632
|
VALUE key;
|
|
633
633
|
|
|
634
|
-
VALUE keys = rb_funcall
|
|
634
|
+
VALUE keys = Rice::detail::protect(rb_funcall, kwargs, rb_intern("keys"), 0);
|
|
635
635
|
if (RARRAY_LEN(keys) > 0) {
|
|
636
|
-
key = rb_ary_entry
|
|
636
|
+
key = Rice::detail::protect(rb_ary_entry, keys, 0);
|
|
637
637
|
|
|
638
638
|
if (!THPUtils_checkSymbol(key)) {
|
|
639
|
-
|
|
639
|
+
throw Rice::Exception(rb_eArgError, "keywords must be symbols, not %s", Rice::detail::protect(rb_obj_classname, key));
|
|
640
640
|
}
|
|
641
641
|
|
|
642
642
|
auto param_idx = find_param(signature, key);
|
|
643
643
|
if (param_idx < 0) {
|
|
644
|
-
|
|
645
|
-
signature.name.c_str(), rb_id2name(rb_to_id
|
|
644
|
+
throw Rice::Exception(rb_eArgError, "%s() got an unexpected keyword argument '%s'",
|
|
645
|
+
signature.name.c_str(), Rice::detail::protect(rb_id2name, Rice::detail::protect(rb_to_id, key)));
|
|
646
646
|
}
|
|
647
647
|
|
|
648
648
|
if (param_idx < num_pos_args) {
|
|
649
|
-
|
|
650
|
-
signature.name.c_str(), rb_id2name(rb_to_id
|
|
649
|
+
throw Rice::Exception(rb_eArgError, "%s() got multiple values for argument '%s'",
|
|
650
|
+
signature.name.c_str(), Rice::detail::protect(rb_id2name, Rice::detail::protect(rb_to_id, key)));
|
|
651
651
|
}
|
|
652
652
|
}
|
|
653
653
|
|
|
654
654
|
// this should never be hit
|
|
655
|
-
|
|
655
|
+
throw Rice::Exception(rb_eArgError, "invalid keyword arguments");
|
|
656
656
|
}
|
|
657
657
|
|
|
658
658
|
VALUE missing = Qundef;
|
|
@@ -703,14 +703,14 @@ bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[],
|
|
|
703
703
|
}
|
|
704
704
|
return false;
|
|
705
705
|
}
|
|
706
|
-
obj = rb_ary_entry
|
|
706
|
+
obj = Rice::detail::protect(rb_ary_entry, args, arg_pos);
|
|
707
707
|
} else if (!NIL_P(kwargs)) {
|
|
708
|
-
obj = rb_hash_lookup2
|
|
708
|
+
obj = Rice::detail::protect(rb_hash_lookup2, kwargs, param.ruby_name, missing);
|
|
709
709
|
// for (VALUE numpy_name: param.numpy_python_names) {
|
|
710
710
|
// if (obj) {
|
|
711
711
|
// break;
|
|
712
712
|
// }
|
|
713
|
-
// obj = rb_hash_aref
|
|
713
|
+
// obj = Rice::detail::protect(rb_hash_aref, kwargs, numpy_name);
|
|
714
714
|
// }
|
|
715
715
|
is_kwd = true;
|
|
716
716
|
}
|
|
@@ -740,14 +740,14 @@ bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[],
|
|
|
740
740
|
} else if (raise_exception) {
|
|
741
741
|
if (is_kwd) {
|
|
742
742
|
// foo(): argument 'other' must be str, not int
|
|
743
|
-
|
|
743
|
+
throw Rice::Exception(rb_eArgError, "%s(): argument '%s' must be %s, not %s",
|
|
744
744
|
name.c_str(), param.name.c_str(), param.type_name().c_str(),
|
|
745
|
-
rb_obj_classname
|
|
745
|
+
Rice::detail::protect(rb_obj_classname, obj));
|
|
746
746
|
} else {
|
|
747
747
|
// foo(): argument 'other' (position 2) must be str, not int
|
|
748
|
-
|
|
748
|
+
throw Rice::Exception(rb_eArgError, "%s(): argument '%s' (position %ld) must be %s, not %s",
|
|
749
749
|
name.c_str(), param.name.c_str(), static_cast<long>(arg_pos + 1),
|
|
750
|
-
param.type_name().c_str(), rb_obj_classname
|
|
750
|
+
param.type_name().c_str(), Rice::detail::protect(rb_obj_classname, obj));
|
|
751
751
|
}
|
|
752
752
|
} else {
|
|
753
753
|
return false;
|
data/ext/torch/ruby_arg_parser.h
CHANGED
|
@@ -165,10 +165,10 @@ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
|
|
|
165
165
|
Check_Type(arg, T_ARRAY);
|
|
166
166
|
auto size = RARRAY_LEN(arg);
|
|
167
167
|
if (size != N) {
|
|
168
|
-
|
|
168
|
+
throw Rice::Exception(rb_eArgError, "expected array of %d elements but got %d", N, static_cast<int>(size));
|
|
169
169
|
}
|
|
170
170
|
for (int idx = 0; idx < size; idx++) {
|
|
171
|
-
VALUE obj = rb_ary_entry
|
|
171
|
+
VALUE obj = Rice::detail::protect(rb_ary_entry, arg, idx);
|
|
172
172
|
res[idx] = Rice::detail::From_Ruby<Tensor>().convert(obj);
|
|
173
173
|
}
|
|
174
174
|
return res;
|
|
@@ -202,13 +202,13 @@ inline std::vector<int64_t> RubyArgs::intlistWithDefault(int i, std::vector<int6
|
|
|
202
202
|
size = RARRAY_LEN(arg);
|
|
203
203
|
std::vector<int64_t> res(size);
|
|
204
204
|
for (idx = 0; idx < size; idx++) {
|
|
205
|
-
VALUE obj = rb_ary_entry
|
|
205
|
+
VALUE obj = Rice::detail::protect(rb_ary_entry, arg, idx);
|
|
206
206
|
if (FIXNUM_P(obj)) {
|
|
207
207
|
res[idx] = Rice::detail::From_Ruby<int64_t>().convert(obj);
|
|
208
208
|
} else {
|
|
209
|
-
|
|
209
|
+
throw Rice::Exception(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
|
|
210
210
|
signature.name.c_str(), signature.params[i].name.c_str(),
|
|
211
|
-
signature.params[i].type_name().c_str(), rb_obj_classname
|
|
211
|
+
signature.params[i].type_name().c_str(), Rice::detail::protect(rb_obj_classname, obj), idx + 1);
|
|
212
212
|
}
|
|
213
213
|
}
|
|
214
214
|
return res;
|
|
@@ -270,7 +270,7 @@ inline ScalarType RubyArgs::scalartype(int i) {
|
|
|
270
270
|
|
|
271
271
|
auto it = dtype_map.find(args[i]);
|
|
272
272
|
if (it == dtype_map.end()) {
|
|
273
|
-
|
|
273
|
+
throw Rice::Exception(rb_eArgError, "invalid dtype: %s", Rice::detail::protect(rb_id2name, Rice::detail::protect(rb_to_id, args[i])));
|
|
274
274
|
}
|
|
275
275
|
return it->second;
|
|
276
276
|
}
|
|
@@ -317,13 +317,13 @@ inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
|
|
|
317
317
|
auto size = RARRAY_LEN(arg);
|
|
318
318
|
std::vector<double> res(size);
|
|
319
319
|
for (idx = 0; idx < size; idx++) {
|
|
320
|
-
VALUE obj = rb_ary_entry
|
|
320
|
+
VALUE obj = Rice::detail::protect(rb_ary_entry, arg, idx);
|
|
321
321
|
if (FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj)) {
|
|
322
322
|
res[idx] = Rice::detail::From_Ruby<double>().convert(obj);
|
|
323
323
|
} else {
|
|
324
|
-
|
|
324
|
+
throw Rice::Exception(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
|
|
325
325
|
signature.name.c_str(), signature.params[i].name.c_str(),
|
|
326
|
-
signature.params[i].type_name().c_str(), rb_obj_classname
|
|
326
|
+
signature.params[i].type_name().c_str(), Rice::detail::protect(rb_obj_classname, obj), idx + 1);
|
|
327
327
|
}
|
|
328
328
|
}
|
|
329
329
|
return res;
|
|
@@ -338,7 +338,7 @@ inline at::Layout RubyArgs::layout(int i) {
|
|
|
338
338
|
|
|
339
339
|
auto it = layout_map.find(args[i]);
|
|
340
340
|
if (it == layout_map.end()) {
|
|
341
|
-
|
|
341
|
+
throw Rice::Exception(rb_eArgError, "invalid layout: %s", Rice::detail::protect(rb_id2name, Rice::detail::protect(rb_to_id, args[i])));
|
|
342
342
|
}
|
|
343
343
|
return it->second;
|
|
344
344
|
}
|
|
@@ -358,7 +358,7 @@ inline at::Device RubyArgs::device(int i) {
|
|
|
358
358
|
return at::Device("cpu");
|
|
359
359
|
}
|
|
360
360
|
if (RB_TYPE_P(args[i], T_STRING)) {
|
|
361
|
-
const std::string
|
|
361
|
+
const std::string& device_str = THPUtils_unpackString(args[i]);
|
|
362
362
|
return at::Device(device_str);
|
|
363
363
|
}
|
|
364
364
|
return Rice::detail::From_Ruby<at::Device>().convert(args[i]);
|
|
@@ -461,22 +461,22 @@ struct RubyArgParser {
|
|
|
461
461
|
|
|
462
462
|
// Check deprecated signatures last
|
|
463
463
|
std::stable_partition(signatures_.begin(), signatures_.end(),
|
|
464
|
-
[](const FunctionSignature
|
|
464
|
+
[](const FunctionSignature& sig) {
|
|
465
465
|
return !sig.deprecated;
|
|
466
466
|
});
|
|
467
467
|
}
|
|
468
468
|
|
|
469
469
|
template<int N>
|
|
470
|
-
inline RubyArgs parse(VALUE self, int argc, VALUE* argv, ParsedArgs<N
|
|
470
|
+
inline RubyArgs parse(VALUE self, int argc, VALUE* argv, ParsedArgs<N>& dst) {
|
|
471
471
|
if (N < max_args) {
|
|
472
|
-
|
|
472
|
+
throw Rice::Exception(rb_eArgError, "RubyArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)", static_cast<int>(max_args), N);
|
|
473
473
|
}
|
|
474
474
|
return raw_parse(self, argc, argv, dst.args);
|
|
475
475
|
}
|
|
476
476
|
|
|
477
477
|
inline RubyArgs raw_parse(VALUE self, int argc, VALUE* argv, VALUE parsed_args[]) {
|
|
478
478
|
VALUE args, kwargs;
|
|
479
|
-
rb_scan_args
|
|
479
|
+
Rice::detail::protect(rb_scan_args, argc, argv, "*:", &args, &kwargs);
|
|
480
480
|
|
|
481
481
|
if (signatures_.size() == 1) {
|
|
482
482
|
auto& signature = signatures_[0];
|
|
@@ -493,7 +493,7 @@ struct RubyArgParser {
|
|
|
493
493
|
print_error(self, args, kwargs, parsed_args);
|
|
494
494
|
|
|
495
495
|
// TODO better message
|
|
496
|
-
|
|
496
|
+
throw Rice::Exception(rb_eArgError, "No matching signatures");
|
|
497
497
|
}
|
|
498
498
|
|
|
499
499
|
void print_error(VALUE self, VALUE args, VALUE kwargs, VALUE parsed_args[]) {
|
data/ext/torch/templates.h
CHANGED
|
@@ -54,8 +54,8 @@ namespace Rice::detail {
|
|
|
54
54
|
|
|
55
55
|
explicit To_Ruby(Arg* arg) : arg_(arg) { }
|
|
56
56
|
|
|
57
|
-
VALUE convert(c10::complex<T
|
|
58
|
-
return rb_dbl_complex_new
|
|
57
|
+
VALUE convert(const c10::complex<T>& x) {
|
|
58
|
+
return protect(rb_dbl_complex_new, x.real(), x.imag());
|
|
59
59
|
}
|
|
60
60
|
|
|
61
61
|
private:
|
|
@@ -72,8 +72,8 @@ namespace Rice::detail {
|
|
|
72
72
|
double is_convertible(VALUE value) { return Convertible::Exact; }
|
|
73
73
|
|
|
74
74
|
c10::complex<T> convert(VALUE x) {
|
|
75
|
-
VALUE real = rb_funcall
|
|
76
|
-
VALUE imag = rb_funcall
|
|
75
|
+
VALUE real = protect(rb_funcall, x, rb_intern("real"), 0);
|
|
76
|
+
VALUE imag = protect(rb_funcall, x, rb_intern("imag"), 0);
|
|
77
77
|
return c10::complex<T>(From_Ruby<T>().convert(real), From_Ruby<T>().convert(imag));
|
|
78
78
|
}
|
|
79
79
|
|
data/ext/torch/tensor.cpp
CHANGED
|
@@ -1,4 +1,6 @@
|
|
|
1
|
+
#include <optional>
|
|
1
2
|
#include <string>
|
|
3
|
+
#include <string_view>
|
|
2
4
|
#include <vector>
|
|
3
5
|
|
|
4
6
|
#include <torch/torch.h>
|
|
@@ -28,15 +30,15 @@ Array flat_data(Tensor& tensor) {
|
|
|
28
30
|
Rice::Class rb_cTensor;
|
|
29
31
|
|
|
30
32
|
std::vector<TensorIndex> index_vector(Array a) {
|
|
31
|
-
Object obj;
|
|
32
|
-
|
|
33
33
|
std::vector<TensorIndex> indices;
|
|
34
34
|
indices.reserve(a.size());
|
|
35
35
|
|
|
36
36
|
for (long i = 0; i < a.size(); i++) {
|
|
37
|
-
obj
|
|
37
|
+
Object obj(a[i]);
|
|
38
38
|
|
|
39
|
-
if (obj.
|
|
39
|
+
if (obj.is_nil()) {
|
|
40
|
+
indices.push_back(torch::indexing::None);
|
|
41
|
+
} else if (obj.is_instance_of(rb_cInteger)) {
|
|
40
42
|
indices.push_back(Rice::detail::From_Ruby<int64_t>().convert(obj.value()));
|
|
41
43
|
} else if (obj.is_instance_of(rb_cRange)) {
|
|
42
44
|
torch::optional<c10::SymInt> start_index = torch::nullopt;
|
|
@@ -64,12 +66,10 @@ std::vector<TensorIndex> index_vector(Array a) {
|
|
|
64
66
|
indices.push_back(torch::indexing::Slice(start_index, stop_index));
|
|
65
67
|
} else if (obj.is_instance_of(rb_cTensor)) {
|
|
66
68
|
indices.push_back(Rice::detail::From_Ruby<Tensor>().convert(obj.value()));
|
|
67
|
-
} else if (obj.
|
|
68
|
-
indices.push_back(torch::indexing::None);
|
|
69
|
-
} else if (obj == Rice::True || obj == Rice::False) {
|
|
69
|
+
} else if (obj.value() == Qtrue || obj.value() == Qfalse) {
|
|
70
70
|
indices.push_back(Rice::detail::From_Ruby<bool>().convert(obj.value()));
|
|
71
71
|
} else {
|
|
72
|
-
throw Rice::Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(
|
|
72
|
+
throw Rice::Exception(rb_eArgError, "Unsupported index type: %s", Rice::detail::protect(rb_obj_classname, obj.value()));
|
|
73
73
|
}
|
|
74
74
|
}
|
|
75
75
|
return indices;
|
|
@@ -102,15 +102,13 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
|
102
102
|
add_tensor_functions(rb_cTensor);
|
|
103
103
|
THPVariableClass = rb_cTensor.value();
|
|
104
104
|
|
|
105
|
-
rb_define_method(rb_cTensor, "backward",
|
|
105
|
+
rb_define_method(rb_cTensor, "backward", tensor__backward, -1);
|
|
106
106
|
|
|
107
107
|
rb_cTensor
|
|
108
108
|
.define_method("cuda?", [](Tensor& self) { return self.is_cuda(); })
|
|
109
109
|
.define_method("mps?", [](Tensor& self) { return self.is_mps(); })
|
|
110
110
|
.define_method("sparse?", [](Tensor& self) { return self.is_sparse(); })
|
|
111
111
|
.define_method("quantized?", [](Tensor& self) { return self.is_quantized(); })
|
|
112
|
-
.define_method("dim", [](Tensor& self) { return self.dim(); })
|
|
113
|
-
.define_method("numel", [](Tensor& self) { return self.numel(); })
|
|
114
112
|
.define_method("element_size", [](Tensor& self) { return self.element_size(); })
|
|
115
113
|
.define_method("requires_grad", [](Tensor& self) { return self.requires_grad(); })
|
|
116
114
|
.define_method(
|
|
@@ -128,7 +126,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
|
128
126
|
"shape",
|
|
129
127
|
[](Tensor& self) {
|
|
130
128
|
Array a;
|
|
131
|
-
for (auto
|
|
129
|
+
for (const auto& size : self.sizes()) {
|
|
132
130
|
a.push(size, false);
|
|
133
131
|
}
|
|
134
132
|
return a;
|
|
@@ -137,7 +135,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
|
137
135
|
"_strides",
|
|
138
136
|
[](Tensor& self) {
|
|
139
137
|
Array a;
|
|
140
|
-
for (auto
|
|
138
|
+
for (const auto& stride : self.strides()) {
|
|
141
139
|
a.push(stride, false);
|
|
142
140
|
}
|
|
143
141
|
return a;
|
|
@@ -154,11 +152,6 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
|
154
152
|
auto vec = index_vector(indices);
|
|
155
153
|
return self.index_put_(vec, value);
|
|
156
154
|
})
|
|
157
|
-
.define_method(
|
|
158
|
-
"contiguous?",
|
|
159
|
-
[](Tensor& self) {
|
|
160
|
-
return self.is_contiguous();
|
|
161
|
-
})
|
|
162
155
|
.define_method(
|
|
163
156
|
"_requires_grad!",
|
|
164
157
|
[](Tensor& self, bool requires_grad) {
|
|
@@ -168,7 +161,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
|
168
161
|
"grad",
|
|
169
162
|
[](Tensor& self) {
|
|
170
163
|
auto grad = self.grad();
|
|
171
|
-
return grad.defined() ?
|
|
164
|
+
return grad.defined() ? std::optional<torch::Tensor>{grad} : std::nullopt;
|
|
172
165
|
})
|
|
173
166
|
// can't use grad=
|
|
174
167
|
// assignment methods fail with Ruby 3.0
|
|
@@ -184,15 +177,15 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
|
184
177
|
|
|
185
178
|
// TODO support sparse grad
|
|
186
179
|
if (!grad.options().type_equal(self.options())) {
|
|
187
|
-
|
|
180
|
+
throw Rice::Exception(rb_eArgError, "assigned grad has data of a different type");
|
|
188
181
|
}
|
|
189
182
|
|
|
190
183
|
if (self.is_cuda() && grad.get_device() != self.get_device()) {
|
|
191
|
-
|
|
184
|
+
throw Rice::Exception(rb_eArgError, "assigned grad has data located on a different device");
|
|
192
185
|
}
|
|
193
186
|
|
|
194
187
|
if (!self.sizes().equals(grad.sizes())) {
|
|
195
|
-
|
|
188
|
+
throw Rice::Exception(rb_eArgError, "assigned grad has data of a different size");
|
|
196
189
|
}
|
|
197
190
|
|
|
198
191
|
self.mutable_grad() = grad;
|
|
@@ -234,8 +227,8 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
|
234
227
|
tensor = tensor.contiguous();
|
|
235
228
|
}
|
|
236
229
|
|
|
237
|
-
auto data_ptr =
|
|
238
|
-
return std::
|
|
230
|
+
auto data_ptr = static_cast<const char *>(tensor.data_ptr());
|
|
231
|
+
return Rice::String(std::string_view(data_ptr, tensor.numel() * tensor.element_size()));
|
|
239
232
|
})
|
|
240
233
|
// for TorchVision
|
|
241
234
|
.define_method(
|
data/ext/torch/torch.cpp
CHANGED
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
#include "utils.h"
|
|
13
13
|
|
|
14
14
|
template<typename T>
|
|
15
|
-
torch::Tensor make_tensor(Rice::Array a, const std::vector<int64_t
|
|
15
|
+
torch::Tensor make_tensor(Rice::Array a, const std::vector<int64_t>& size, const torch::TensorOptions& options) {
|
|
16
16
|
std::vector<T> vec;
|
|
17
17
|
vec.reserve(a.size());
|
|
18
18
|
for (long i = 0; i < a.size(); i++) {
|
|
@@ -61,13 +61,13 @@ void init_torch(Rice::Module& m) {
|
|
|
61
61
|
// begin operations
|
|
62
62
|
.define_singleton_function(
|
|
63
63
|
"_save",
|
|
64
|
-
[](const torch::IValue
|
|
64
|
+
[](const torch::IValue& value) {
|
|
65
65
|
auto v = torch::pickle_save(value);
|
|
66
|
-
return Rice::Object(rb_str_new
|
|
66
|
+
return Rice::Object(Rice::detail::protect(rb_str_new, v.data(), v.size()));
|
|
67
67
|
})
|
|
68
68
|
.define_singleton_function(
|
|
69
69
|
"_load",
|
|
70
|
-
[](const std::string
|
|
70
|
+
[](const std::string& filename) {
|
|
71
71
|
// https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
|
|
72
72
|
std::ifstream input(filename, std::ios::binary);
|
|
73
73
|
std::vector<char> bytes(
|
|
@@ -78,13 +78,13 @@ void init_torch(Rice::Module& m) {
|
|
|
78
78
|
})
|
|
79
79
|
.define_singleton_function(
|
|
80
80
|
"_from_blob",
|
|
81
|
-
[](Rice::String s, const std::vector<int64_t
|
|
81
|
+
[](Rice::String s, const std::vector<int64_t>& size, const torch::TensorOptions& options) {
|
|
82
82
|
void *data = const_cast<char *>(s.c_str());
|
|
83
83
|
return torch::from_blob(data, size, options);
|
|
84
84
|
})
|
|
85
85
|
.define_singleton_function(
|
|
86
86
|
"_tensor",
|
|
87
|
-
[](Rice::Array a, const std::vector<int64_t
|
|
87
|
+
[](Rice::Array a, const std::vector<int64_t>& size, const torch::TensorOptions& options) {
|
|
88
88
|
auto dtype = options.dtype();
|
|
89
89
|
if (dtype == torch::kByte) {
|
|
90
90
|
return make_tensor<uint8_t>(a, size, options);
|
data/ext/torch/utils.h
CHANGED
|
@@ -8,7 +8,7 @@
|
|
|
8
8
|
#include <rice/stl.hpp>
|
|
9
9
|
|
|
10
10
|
static_assert(
|
|
11
|
-
TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR ==
|
|
11
|
+
TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 12,
|
|
12
12
|
"Incompatible LibTorch version"
|
|
13
13
|
);
|
|
14
14
|
|
|
@@ -50,17 +50,17 @@ inline bool THPUtils_checkScalar(VALUE obj) {
|
|
|
50
50
|
}
|
|
51
51
|
|
|
52
52
|
inline bool THPDevice_Check(VALUE obj) {
|
|
53
|
-
return rb_obj_is_kind_of
|
|
53
|
+
return Rice::detail::protect(rb_obj_is_kind_of, obj, THPDeviceClass);
|
|
54
54
|
}
|
|
55
55
|
|
|
56
56
|
inline bool THPGenerator_Check(VALUE obj) {
|
|
57
|
-
return rb_obj_is_kind_of
|
|
57
|
+
return Rice::detail::protect(rb_obj_is_kind_of, obj, THPGeneratorClass);
|
|
58
58
|
}
|
|
59
59
|
|
|
60
60
|
inline bool THPVariable_Check(VALUE obj) {
|
|
61
|
-
return rb_obj_is_kind_of
|
|
61
|
+
return Rice::detail::protect(rb_obj_is_kind_of, obj, THPVariableClass);
|
|
62
62
|
}
|
|
63
63
|
|
|
64
64
|
inline bool THPVariable_CheckExact(VALUE obj) {
|
|
65
|
-
return rb_obj_is_instance_of
|
|
65
|
+
return Rice::detail::protect(rb_obj_is_instance_of, obj, THPVariableClass);
|
|
66
66
|
}
|
data/ext/torch/wrap_outputs.h
CHANGED
|
@@ -15,32 +15,34 @@ inline VALUE wrap(double x) {
|
|
|
15
15
|
return Rice::detail::To_Ruby<double>().convert(x);
|
|
16
16
|
}
|
|
17
17
|
|
|
18
|
-
inline VALUE wrap(torch::Tensor x) {
|
|
18
|
+
inline VALUE wrap(const torch::Tensor& x) {
|
|
19
19
|
return Rice::detail::To_Ruby<torch::Tensor>().convert(x);
|
|
20
20
|
}
|
|
21
21
|
|
|
22
|
-
inline VALUE wrap(torch::Scalar x) {
|
|
22
|
+
inline VALUE wrap(const torch::Scalar& x) {
|
|
23
23
|
return Rice::detail::To_Ruby<torch::Scalar>().convert(x);
|
|
24
24
|
}
|
|
25
25
|
|
|
26
|
-
inline VALUE wrap(torch::ScalarType x) {
|
|
26
|
+
inline VALUE wrap(const torch::ScalarType& x) {
|
|
27
27
|
return Rice::detail::To_Ruby<torch::ScalarType>().convert(x);
|
|
28
28
|
}
|
|
29
29
|
|
|
30
|
-
inline VALUE wrap(torch::QScheme x) {
|
|
30
|
+
inline VALUE wrap(const torch::QScheme& x) {
|
|
31
31
|
return Rice::detail::To_Ruby<torch::QScheme>().convert(x);
|
|
32
32
|
}
|
|
33
33
|
|
|
34
|
-
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor
|
|
35
|
-
return
|
|
34
|
+
inline VALUE wrap(const std::tuple<torch::Tensor, torch::Tensor>& x) {
|
|
35
|
+
return Rice::detail::protect(
|
|
36
|
+
rb_ary_new3,
|
|
36
37
|
2,
|
|
37
38
|
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
|
38
39
|
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x))
|
|
39
40
|
);
|
|
40
41
|
}
|
|
41
42
|
|
|
42
|
-
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor
|
|
43
|
-
return
|
|
43
|
+
inline VALUE wrap(const std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>& x) {
|
|
44
|
+
return Rice::detail::protect(
|
|
45
|
+
rb_ary_new3,
|
|
44
46
|
3,
|
|
45
47
|
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
|
46
48
|
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
|
|
@@ -48,8 +50,9 @@ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
|
|
48
50
|
);
|
|
49
51
|
}
|
|
50
52
|
|
|
51
|
-
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor
|
|
52
|
-
return
|
|
53
|
+
inline VALUE wrap(const std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>& x) {
|
|
54
|
+
return Rice::detail::protect(
|
|
55
|
+
rb_ary_new3,
|
|
53
56
|
4,
|
|
54
57
|
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
|
55
58
|
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
|
|
@@ -58,8 +61,9 @@ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch:
|
|
|
58
61
|
);
|
|
59
62
|
}
|
|
60
63
|
|
|
61
|
-
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor
|
|
62
|
-
return
|
|
64
|
+
inline VALUE wrap(const std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>& x) {
|
|
65
|
+
return Rice::detail::protect(
|
|
66
|
+
rb_ary_new3,
|
|
63
67
|
5,
|
|
64
68
|
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
|
65
69
|
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
|
|
@@ -69,8 +73,9 @@ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch:
|
|
|
69
73
|
);
|
|
70
74
|
}
|
|
71
75
|
|
|
72
|
-
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t
|
|
73
|
-
return
|
|
76
|
+
inline VALUE wrap(const std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t>& x) {
|
|
77
|
+
return Rice::detail::protect(
|
|
78
|
+
rb_ary_new3,
|
|
74
79
|
4,
|
|
75
80
|
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
|
76
81
|
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
|
|
@@ -79,8 +84,9 @@ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_
|
|
|
79
84
|
);
|
|
80
85
|
}
|
|
81
86
|
|
|
82
|
-
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t
|
|
83
|
-
return
|
|
87
|
+
inline VALUE wrap(const std::tuple<torch::Tensor, torch::Tensor, double, int64_t>& x) {
|
|
88
|
+
return Rice::detail::protect(
|
|
89
|
+
rb_ary_new3,
|
|
84
90
|
4,
|
|
85
91
|
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
|
86
92
|
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
|
|
@@ -89,16 +95,17 @@ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
|
|
|
89
95
|
);
|
|
90
96
|
}
|
|
91
97
|
|
|
92
|
-
inline VALUE wrap(torch::TensorList x) {
|
|
93
|
-
auto a = rb_ary_new2
|
|
94
|
-
for (auto t : x) {
|
|
95
|
-
rb_ary_push
|
|
98
|
+
inline VALUE wrap(const torch::TensorList& x) {
|
|
99
|
+
auto a = Rice::detail::protect(rb_ary_new2, x.size());
|
|
100
|
+
for (const auto& t : x) {
|
|
101
|
+
Rice::detail::protect(rb_ary_push, a, Rice::detail::To_Ruby<torch::Tensor>().convert(t));
|
|
96
102
|
}
|
|
97
103
|
return a;
|
|
98
104
|
}
|
|
99
105
|
|
|
100
|
-
inline VALUE wrap(std::tuple<double, double
|
|
101
|
-
return
|
|
106
|
+
inline VALUE wrap(const std::tuple<double, double>& x) {
|
|
107
|
+
return Rice::detail::protect(
|
|
108
|
+
rb_ary_new3,
|
|
102
109
|
2,
|
|
103
110
|
Rice::detail::To_Ruby<double>().convert(std::get<0>(x)),
|
|
104
111
|
Rice::detail::To_Ruby<double>().convert(std::get<1>(x))
|