torch-rb 0.23.1 → 0.24.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +3 -1
- data/codegen/generate_functions.rb +1 -6
- data/codegen/native_functions.yaml +229 -58
- data/ext/torch/ivalue.cpp +4 -4
- data/ext/torch/ruby_arg_parser.cpp +14 -14
- data/ext/torch/ruby_arg_parser.h +11 -11
- data/ext/torch/templates.h +1 -1
- data/ext/torch/tensor.cpp +6 -12
- data/ext/torch/torch.cpp +6 -6
- data/ext/torch/utils.h +5 -5
- data/ext/torch/wrap_outputs.h +29 -22
- data/lib/torch/hub.rb +8 -28
- data/lib/torch/nn/module.rb +1 -1
- data/lib/torch/nn/rnn_base.rb +1 -1
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +4 -4
- metadata +3 -3
data/ext/torch/ivalue.cpp
CHANGED
|
@@ -53,7 +53,7 @@ void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue) {
|
|
|
53
53
|
[](torch::IValue& self) {
|
|
54
54
|
auto list = self.toListRef();
|
|
55
55
|
Rice::Array obj;
|
|
56
|
-
for (auto& elem : list) {
|
|
56
|
+
for (const auto& elem : list) {
|
|
57
57
|
auto v = torch::IValue{elem};
|
|
58
58
|
obj.push(Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v)), false);
|
|
59
59
|
}
|
|
@@ -74,7 +74,7 @@ void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue) {
|
|
|
74
74
|
[](torch::IValue& self) {
|
|
75
75
|
auto dict = self.toGenericDict();
|
|
76
76
|
Rice::Hash obj;
|
|
77
|
-
for (auto& pair : dict) {
|
|
77
|
+
for (const auto& pair : dict) {
|
|
78
78
|
auto k = torch::IValue{pair.key()};
|
|
79
79
|
auto v = torch::IValue{pair.value()};
|
|
80
80
|
obj[Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(k))] = Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v));
|
|
@@ -91,7 +91,7 @@ void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue) {
|
|
|
91
91
|
"from_list",
|
|
92
92
|
[](Rice::Array obj) {
|
|
93
93
|
c10::impl::GenericList list(c10::AnyType::get());
|
|
94
|
-
for (auto entry : obj) {
|
|
94
|
+
for (const auto& entry : obj) {
|
|
95
95
|
list.push_back(Rice::detail::From_Ruby<torch::IValue>().convert(entry.value()));
|
|
96
96
|
}
|
|
97
97
|
return torch::IValue(list);
|
|
@@ -125,7 +125,7 @@ void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue) {
|
|
|
125
125
|
auto value_type = c10::AnyType::get();
|
|
126
126
|
c10::impl::GenericDict elems(key_type, value_type);
|
|
127
127
|
elems.reserve(obj.size());
|
|
128
|
-
for (auto entry : obj) {
|
|
128
|
+
for (const auto& entry : obj) {
|
|
129
129
|
elems.insert(Rice::detail::From_Ruby<torch::IValue>().convert(entry.first), Rice::detail::From_Ruby<torch::IValue>().convert((Rice::Object) entry.second));
|
|
130
130
|
}
|
|
131
131
|
return torch::IValue(std::move(elems));
|
|
@@ -117,11 +117,11 @@ bool is_tensor_list(VALUE obj, int argnum, bool throw_error) {
|
|
|
117
117
|
}
|
|
118
118
|
auto size = RARRAY_LEN(obj);
|
|
119
119
|
for (int idx = 0; idx < size; idx++) {
|
|
120
|
-
VALUE iobj = rb_ary_entry
|
|
120
|
+
VALUE iobj = Rice::detail::protect(rb_ary_entry, obj, idx);
|
|
121
121
|
if (!THPVariable_Check(iobj)) {
|
|
122
122
|
if (throw_error) {
|
|
123
123
|
throw Rice::Exception(rb_eArgError, "expected Tensor as element %d in argument %d, but got %s",
|
|
124
|
-
static_cast<int>(idx), argnum, rb_obj_classname
|
|
124
|
+
static_cast<int>(idx), argnum, Rice::detail::protect(rb_obj_classname, obj));
|
|
125
125
|
}
|
|
126
126
|
return false;
|
|
127
127
|
}
|
|
@@ -136,7 +136,7 @@ static bool is_int_list(VALUE obj, int broadcast_size) {
|
|
|
136
136
|
return true;
|
|
137
137
|
}
|
|
138
138
|
|
|
139
|
-
auto item = rb_ary_entry
|
|
139
|
+
auto item = Rice::detail::protect(rb_ary_entry, obj, 0);
|
|
140
140
|
bool int_first = false;
|
|
141
141
|
if (THPUtils_checkIndex(item)) {
|
|
142
142
|
// we still have to check that the rest of items are NOT symint nodes
|
|
@@ -178,7 +178,7 @@ static bool is_int_or_symint_list(VALUE obj, int broadcast_size) {
|
|
|
178
178
|
if (RARRAY_LEN(obj) == 0) {
|
|
179
179
|
return true;
|
|
180
180
|
}
|
|
181
|
-
auto item = rb_ary_entry
|
|
181
|
+
auto item = Rice::detail::protect(rb_ary_entry, obj, 0);
|
|
182
182
|
|
|
183
183
|
if (is_int_or_symint(item)) {
|
|
184
184
|
return true;
|
|
@@ -631,23 +631,23 @@ static ssize_t find_param(FunctionSignature& signature, VALUE name) {
|
|
|
631
631
|
static void extra_kwargs(FunctionSignature& signature, VALUE kwargs, ssize_t num_pos_args) {
|
|
632
632
|
VALUE key;
|
|
633
633
|
|
|
634
|
-
VALUE keys = rb_funcall
|
|
634
|
+
VALUE keys = Rice::detail::protect(rb_funcall, kwargs, rb_intern("keys"), 0);
|
|
635
635
|
if (RARRAY_LEN(keys) > 0) {
|
|
636
|
-
key = rb_ary_entry
|
|
636
|
+
key = Rice::detail::protect(rb_ary_entry, keys, 0);
|
|
637
637
|
|
|
638
638
|
if (!THPUtils_checkSymbol(key)) {
|
|
639
|
-
throw Rice::Exception(rb_eArgError, "keywords must be symbols, not %s", rb_obj_classname
|
|
639
|
+
throw Rice::Exception(rb_eArgError, "keywords must be symbols, not %s", Rice::detail::protect(rb_obj_classname, key));
|
|
640
640
|
}
|
|
641
641
|
|
|
642
642
|
auto param_idx = find_param(signature, key);
|
|
643
643
|
if (param_idx < 0) {
|
|
644
644
|
throw Rice::Exception(rb_eArgError, "%s() got an unexpected keyword argument '%s'",
|
|
645
|
-
signature.name.c_str(), rb_id2name(rb_to_id
|
|
645
|
+
signature.name.c_str(), Rice::detail::protect(rb_id2name, Rice::detail::protect(rb_to_id, key)));
|
|
646
646
|
}
|
|
647
647
|
|
|
648
648
|
if (param_idx < num_pos_args) {
|
|
649
649
|
throw Rice::Exception(rb_eArgError, "%s() got multiple values for argument '%s'",
|
|
650
|
-
signature.name.c_str(), rb_id2name(rb_to_id
|
|
650
|
+
signature.name.c_str(), Rice::detail::protect(rb_id2name, Rice::detail::protect(rb_to_id, key)));
|
|
651
651
|
}
|
|
652
652
|
}
|
|
653
653
|
|
|
@@ -703,14 +703,14 @@ bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[],
|
|
|
703
703
|
}
|
|
704
704
|
return false;
|
|
705
705
|
}
|
|
706
|
-
obj = rb_ary_entry
|
|
706
|
+
obj = Rice::detail::protect(rb_ary_entry, args, arg_pos);
|
|
707
707
|
} else if (!NIL_P(kwargs)) {
|
|
708
|
-
obj = rb_hash_lookup2
|
|
708
|
+
obj = Rice::detail::protect(rb_hash_lookup2, kwargs, param.ruby_name, missing);
|
|
709
709
|
// for (VALUE numpy_name: param.numpy_python_names) {
|
|
710
710
|
// if (obj) {
|
|
711
711
|
// break;
|
|
712
712
|
// }
|
|
713
|
-
// obj = rb_hash_aref
|
|
713
|
+
// obj = Rice::detail::protect(rb_hash_aref, kwargs, numpy_name);
|
|
714
714
|
// }
|
|
715
715
|
is_kwd = true;
|
|
716
716
|
}
|
|
@@ -742,12 +742,12 @@ bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[],
|
|
|
742
742
|
// foo(): argument 'other' must be str, not int
|
|
743
743
|
throw Rice::Exception(rb_eArgError, "%s(): argument '%s' must be %s, not %s",
|
|
744
744
|
name.c_str(), param.name.c_str(), param.type_name().c_str(),
|
|
745
|
-
rb_obj_classname
|
|
745
|
+
Rice::detail::protect(rb_obj_classname, obj));
|
|
746
746
|
} else {
|
|
747
747
|
// foo(): argument 'other' (position 2) must be str, not int
|
|
748
748
|
throw Rice::Exception(rb_eArgError, "%s(): argument '%s' (position %ld) must be %s, not %s",
|
|
749
749
|
name.c_str(), param.name.c_str(), static_cast<long>(arg_pos + 1),
|
|
750
|
-
param.type_name().c_str(), rb_obj_classname
|
|
750
|
+
param.type_name().c_str(), Rice::detail::protect(rb_obj_classname, obj));
|
|
751
751
|
}
|
|
752
752
|
} else {
|
|
753
753
|
return false;
|
data/ext/torch/ruby_arg_parser.h
CHANGED
|
@@ -168,7 +168,7 @@ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
|
|
|
168
168
|
throw Rice::Exception(rb_eArgError, "expected array of %d elements but got %d", N, static_cast<int>(size));
|
|
169
169
|
}
|
|
170
170
|
for (int idx = 0; idx < size; idx++) {
|
|
171
|
-
VALUE obj = rb_ary_entry
|
|
171
|
+
VALUE obj = Rice::detail::protect(rb_ary_entry, arg, idx);
|
|
172
172
|
res[idx] = Rice::detail::From_Ruby<Tensor>().convert(obj);
|
|
173
173
|
}
|
|
174
174
|
return res;
|
|
@@ -202,13 +202,13 @@ inline std::vector<int64_t> RubyArgs::intlistWithDefault(int i, std::vector<int6
|
|
|
202
202
|
size = RARRAY_LEN(arg);
|
|
203
203
|
std::vector<int64_t> res(size);
|
|
204
204
|
for (idx = 0; idx < size; idx++) {
|
|
205
|
-
VALUE obj = rb_ary_entry
|
|
205
|
+
VALUE obj = Rice::detail::protect(rb_ary_entry, arg, idx);
|
|
206
206
|
if (FIXNUM_P(obj)) {
|
|
207
207
|
res[idx] = Rice::detail::From_Ruby<int64_t>().convert(obj);
|
|
208
208
|
} else {
|
|
209
209
|
throw Rice::Exception(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
|
|
210
210
|
signature.name.c_str(), signature.params[i].name.c_str(),
|
|
211
|
-
signature.params[i].type_name().c_str(), rb_obj_classname
|
|
211
|
+
signature.params[i].type_name().c_str(), Rice::detail::protect(rb_obj_classname, obj), idx + 1);
|
|
212
212
|
}
|
|
213
213
|
}
|
|
214
214
|
return res;
|
|
@@ -270,7 +270,7 @@ inline ScalarType RubyArgs::scalartype(int i) {
|
|
|
270
270
|
|
|
271
271
|
auto it = dtype_map.find(args[i]);
|
|
272
272
|
if (it == dtype_map.end()) {
|
|
273
|
-
throw Rice::Exception(rb_eArgError, "invalid dtype: %s", rb_id2name(rb_to_id
|
|
273
|
+
throw Rice::Exception(rb_eArgError, "invalid dtype: %s", Rice::detail::protect(rb_id2name, Rice::detail::protect(rb_to_id, args[i])));
|
|
274
274
|
}
|
|
275
275
|
return it->second;
|
|
276
276
|
}
|
|
@@ -317,13 +317,13 @@ inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
|
|
|
317
317
|
auto size = RARRAY_LEN(arg);
|
|
318
318
|
std::vector<double> res(size);
|
|
319
319
|
for (idx = 0; idx < size; idx++) {
|
|
320
|
-
VALUE obj = rb_ary_entry
|
|
320
|
+
VALUE obj = Rice::detail::protect(rb_ary_entry, arg, idx);
|
|
321
321
|
if (FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj)) {
|
|
322
322
|
res[idx] = Rice::detail::From_Ruby<double>().convert(obj);
|
|
323
323
|
} else {
|
|
324
324
|
throw Rice::Exception(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
|
|
325
325
|
signature.name.c_str(), signature.params[i].name.c_str(),
|
|
326
|
-
signature.params[i].type_name().c_str(), rb_obj_classname
|
|
326
|
+
signature.params[i].type_name().c_str(), Rice::detail::protect(rb_obj_classname, obj), idx + 1);
|
|
327
327
|
}
|
|
328
328
|
}
|
|
329
329
|
return res;
|
|
@@ -338,7 +338,7 @@ inline at::Layout RubyArgs::layout(int i) {
|
|
|
338
338
|
|
|
339
339
|
auto it = layout_map.find(args[i]);
|
|
340
340
|
if (it == layout_map.end()) {
|
|
341
|
-
throw Rice::Exception(rb_eArgError, "invalid layout: %s", rb_id2name(rb_to_id
|
|
341
|
+
throw Rice::Exception(rb_eArgError, "invalid layout: %s", Rice::detail::protect(rb_id2name, Rice::detail::protect(rb_to_id, args[i])));
|
|
342
342
|
}
|
|
343
343
|
return it->second;
|
|
344
344
|
}
|
|
@@ -358,7 +358,7 @@ inline at::Device RubyArgs::device(int i) {
|
|
|
358
358
|
return at::Device("cpu");
|
|
359
359
|
}
|
|
360
360
|
if (RB_TYPE_P(args[i], T_STRING)) {
|
|
361
|
-
const std::string
|
|
361
|
+
const std::string& device_str = THPUtils_unpackString(args[i]);
|
|
362
362
|
return at::Device(device_str);
|
|
363
363
|
}
|
|
364
364
|
return Rice::detail::From_Ruby<at::Device>().convert(args[i]);
|
|
@@ -461,13 +461,13 @@ struct RubyArgParser {
|
|
|
461
461
|
|
|
462
462
|
// Check deprecated signatures last
|
|
463
463
|
std::stable_partition(signatures_.begin(), signatures_.end(),
|
|
464
|
-
[](const FunctionSignature
|
|
464
|
+
[](const FunctionSignature& sig) {
|
|
465
465
|
return !sig.deprecated;
|
|
466
466
|
});
|
|
467
467
|
}
|
|
468
468
|
|
|
469
469
|
template<int N>
|
|
470
|
-
inline RubyArgs parse(VALUE self, int argc, VALUE* argv, ParsedArgs<N
|
|
470
|
+
inline RubyArgs parse(VALUE self, int argc, VALUE* argv, ParsedArgs<N>& dst) {
|
|
471
471
|
if (N < max_args) {
|
|
472
472
|
throw Rice::Exception(rb_eArgError, "RubyArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)", static_cast<int>(max_args), N);
|
|
473
473
|
}
|
|
@@ -476,7 +476,7 @@ struct RubyArgParser {
|
|
|
476
476
|
|
|
477
477
|
inline RubyArgs raw_parse(VALUE self, int argc, VALUE* argv, VALUE parsed_args[]) {
|
|
478
478
|
VALUE args, kwargs;
|
|
479
|
-
rb_scan_args
|
|
479
|
+
Rice::detail::protect(rb_scan_args, argc, argv, "*:", &args, &kwargs);
|
|
480
480
|
|
|
481
481
|
if (signatures_.size() == 1) {
|
|
482
482
|
auto& signature = signatures_[0];
|
data/ext/torch/templates.h
CHANGED
data/ext/torch/tensor.cpp
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
#include <optional>
|
|
2
2
|
#include <string>
|
|
3
|
+
#include <string_view>
|
|
3
4
|
#include <vector>
|
|
4
5
|
|
|
5
6
|
#include <torch/torch.h>
|
|
@@ -68,7 +69,7 @@ std::vector<TensorIndex> index_vector(Array a) {
|
|
|
68
69
|
} else if (obj.value() == Qtrue || obj.value() == Qfalse) {
|
|
69
70
|
indices.push_back(Rice::detail::From_Ruby<bool>().convert(obj.value()));
|
|
70
71
|
} else {
|
|
71
|
-
throw Rice::Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname
|
|
72
|
+
throw Rice::Exception(rb_eArgError, "Unsupported index type: %s", Rice::detail::protect(rb_obj_classname, obj.value()));
|
|
72
73
|
}
|
|
73
74
|
}
|
|
74
75
|
return indices;
|
|
@@ -108,8 +109,6 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
|
108
109
|
.define_method("mps?", [](Tensor& self) { return self.is_mps(); })
|
|
109
110
|
.define_method("sparse?", [](Tensor& self) { return self.is_sparse(); })
|
|
110
111
|
.define_method("quantized?", [](Tensor& self) { return self.is_quantized(); })
|
|
111
|
-
.define_method("dim", [](Tensor& self) { return self.dim(); })
|
|
112
|
-
.define_method("numel", [](Tensor& self) { return self.numel(); })
|
|
113
112
|
.define_method("element_size", [](Tensor& self) { return self.element_size(); })
|
|
114
113
|
.define_method("requires_grad", [](Tensor& self) { return self.requires_grad(); })
|
|
115
114
|
.define_method(
|
|
@@ -127,7 +126,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
|
127
126
|
"shape",
|
|
128
127
|
[](Tensor& self) {
|
|
129
128
|
Array a;
|
|
130
|
-
for (auto
|
|
129
|
+
for (const auto& size : self.sizes()) {
|
|
131
130
|
a.push(size, false);
|
|
132
131
|
}
|
|
133
132
|
return a;
|
|
@@ -136,7 +135,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
|
136
135
|
"_strides",
|
|
137
136
|
[](Tensor& self) {
|
|
138
137
|
Array a;
|
|
139
|
-
for (auto
|
|
138
|
+
for (const auto& stride : self.strides()) {
|
|
140
139
|
a.push(stride, false);
|
|
141
140
|
}
|
|
142
141
|
return a;
|
|
@@ -153,11 +152,6 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
|
153
152
|
auto vec = index_vector(indices);
|
|
154
153
|
return self.index_put_(vec, value);
|
|
155
154
|
})
|
|
156
|
-
.define_method(
|
|
157
|
-
"contiguous?",
|
|
158
|
-
[](Tensor& self) {
|
|
159
|
-
return self.is_contiguous();
|
|
160
|
-
})
|
|
161
155
|
.define_method(
|
|
162
156
|
"_requires_grad!",
|
|
163
157
|
[](Tensor& self, bool requires_grad) {
|
|
@@ -233,8 +227,8 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
|
233
227
|
tensor = tensor.contiguous();
|
|
234
228
|
}
|
|
235
229
|
|
|
236
|
-
auto data_ptr =
|
|
237
|
-
return std::
|
|
230
|
+
auto data_ptr = static_cast<const char *>(tensor.data_ptr());
|
|
231
|
+
return Rice::String(std::string_view(data_ptr, tensor.numel() * tensor.element_size()));
|
|
238
232
|
})
|
|
239
233
|
// for TorchVision
|
|
240
234
|
.define_method(
|
data/ext/torch/torch.cpp
CHANGED
|
@@ -12,7 +12,7 @@
|
|
|
12
12
|
#include "utils.h"
|
|
13
13
|
|
|
14
14
|
template<typename T>
|
|
15
|
-
torch::Tensor make_tensor(Rice::Array a, const std::vector<int64_t
|
|
15
|
+
torch::Tensor make_tensor(Rice::Array a, const std::vector<int64_t>& size, const torch::TensorOptions& options) {
|
|
16
16
|
std::vector<T> vec;
|
|
17
17
|
vec.reserve(a.size());
|
|
18
18
|
for (long i = 0; i < a.size(); i++) {
|
|
@@ -61,13 +61,13 @@ void init_torch(Rice::Module& m) {
|
|
|
61
61
|
// begin operations
|
|
62
62
|
.define_singleton_function(
|
|
63
63
|
"_save",
|
|
64
|
-
[](const torch::IValue
|
|
64
|
+
[](const torch::IValue& value) {
|
|
65
65
|
auto v = torch::pickle_save(value);
|
|
66
|
-
return Rice::Object(rb_str_new
|
|
66
|
+
return Rice::Object(Rice::detail::protect(rb_str_new, v.data(), v.size()));
|
|
67
67
|
})
|
|
68
68
|
.define_singleton_function(
|
|
69
69
|
"_load",
|
|
70
|
-
[](const std::string
|
|
70
|
+
[](const std::string& filename) {
|
|
71
71
|
// https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
|
|
72
72
|
std::ifstream input(filename, std::ios::binary);
|
|
73
73
|
std::vector<char> bytes(
|
|
@@ -78,13 +78,13 @@ void init_torch(Rice::Module& m) {
|
|
|
78
78
|
})
|
|
79
79
|
.define_singleton_function(
|
|
80
80
|
"_from_blob",
|
|
81
|
-
[](Rice::String s, const std::vector<int64_t
|
|
81
|
+
[](Rice::String s, const std::vector<int64_t>& size, const torch::TensorOptions& options) {
|
|
82
82
|
void *data = const_cast<char *>(s.c_str());
|
|
83
83
|
return torch::from_blob(data, size, options);
|
|
84
84
|
})
|
|
85
85
|
.define_singleton_function(
|
|
86
86
|
"_tensor",
|
|
87
|
-
[](Rice::Array a, const std::vector<int64_t
|
|
87
|
+
[](Rice::Array a, const std::vector<int64_t>& size, const torch::TensorOptions& options) {
|
|
88
88
|
auto dtype = options.dtype();
|
|
89
89
|
if (dtype == torch::kByte) {
|
|
90
90
|
return make_tensor<uint8_t>(a, size, options);
|
data/ext/torch/utils.h
CHANGED
|
@@ -8,7 +8,7 @@
|
|
|
8
8
|
#include <rice/stl.hpp>
|
|
9
9
|
|
|
10
10
|
static_assert(
|
|
11
|
-
TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR ==
|
|
11
|
+
TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 12,
|
|
12
12
|
"Incompatible LibTorch version"
|
|
13
13
|
);
|
|
14
14
|
|
|
@@ -50,17 +50,17 @@ inline bool THPUtils_checkScalar(VALUE obj) {
|
|
|
50
50
|
}
|
|
51
51
|
|
|
52
52
|
inline bool THPDevice_Check(VALUE obj) {
|
|
53
|
-
return rb_obj_is_kind_of
|
|
53
|
+
return Rice::detail::protect(rb_obj_is_kind_of, obj, THPDeviceClass);
|
|
54
54
|
}
|
|
55
55
|
|
|
56
56
|
inline bool THPGenerator_Check(VALUE obj) {
|
|
57
|
-
return rb_obj_is_kind_of
|
|
57
|
+
return Rice::detail::protect(rb_obj_is_kind_of, obj, THPGeneratorClass);
|
|
58
58
|
}
|
|
59
59
|
|
|
60
60
|
inline bool THPVariable_Check(VALUE obj) {
|
|
61
|
-
return rb_obj_is_kind_of
|
|
61
|
+
return Rice::detail::protect(rb_obj_is_kind_of, obj, THPVariableClass);
|
|
62
62
|
}
|
|
63
63
|
|
|
64
64
|
inline bool THPVariable_CheckExact(VALUE obj) {
|
|
65
|
-
return rb_obj_is_instance_of
|
|
65
|
+
return Rice::detail::protect(rb_obj_is_instance_of, obj, THPVariableClass);
|
|
66
66
|
}
|
data/ext/torch/wrap_outputs.h
CHANGED
|
@@ -15,32 +15,34 @@ inline VALUE wrap(double x) {
|
|
|
15
15
|
return Rice::detail::To_Ruby<double>().convert(x);
|
|
16
16
|
}
|
|
17
17
|
|
|
18
|
-
inline VALUE wrap(torch::Tensor x) {
|
|
18
|
+
inline VALUE wrap(const torch::Tensor& x) {
|
|
19
19
|
return Rice::detail::To_Ruby<torch::Tensor>().convert(x);
|
|
20
20
|
}
|
|
21
21
|
|
|
22
|
-
inline VALUE wrap(torch::Scalar x) {
|
|
22
|
+
inline VALUE wrap(const torch::Scalar& x) {
|
|
23
23
|
return Rice::detail::To_Ruby<torch::Scalar>().convert(x);
|
|
24
24
|
}
|
|
25
25
|
|
|
26
|
-
inline VALUE wrap(torch::ScalarType x) {
|
|
26
|
+
inline VALUE wrap(const torch::ScalarType& x) {
|
|
27
27
|
return Rice::detail::To_Ruby<torch::ScalarType>().convert(x);
|
|
28
28
|
}
|
|
29
29
|
|
|
30
|
-
inline VALUE wrap(torch::QScheme x) {
|
|
30
|
+
inline VALUE wrap(const torch::QScheme& x) {
|
|
31
31
|
return Rice::detail::To_Ruby<torch::QScheme>().convert(x);
|
|
32
32
|
}
|
|
33
33
|
|
|
34
|
-
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor
|
|
35
|
-
return
|
|
34
|
+
inline VALUE wrap(const std::tuple<torch::Tensor, torch::Tensor>& x) {
|
|
35
|
+
return Rice::detail::protect(
|
|
36
|
+
rb_ary_new3,
|
|
36
37
|
2,
|
|
37
38
|
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
|
38
39
|
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x))
|
|
39
40
|
);
|
|
40
41
|
}
|
|
41
42
|
|
|
42
|
-
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor
|
|
43
|
-
return
|
|
43
|
+
inline VALUE wrap(const std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>& x) {
|
|
44
|
+
return Rice::detail::protect(
|
|
45
|
+
rb_ary_new3,
|
|
44
46
|
3,
|
|
45
47
|
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
|
46
48
|
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
|
|
@@ -48,8 +50,9 @@ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
|
|
48
50
|
);
|
|
49
51
|
}
|
|
50
52
|
|
|
51
|
-
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor
|
|
52
|
-
return
|
|
53
|
+
inline VALUE wrap(const std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>& x) {
|
|
54
|
+
return Rice::detail::protect(
|
|
55
|
+
rb_ary_new3,
|
|
53
56
|
4,
|
|
54
57
|
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
|
55
58
|
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
|
|
@@ -58,8 +61,9 @@ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch:
|
|
|
58
61
|
);
|
|
59
62
|
}
|
|
60
63
|
|
|
61
|
-
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor
|
|
62
|
-
return
|
|
64
|
+
inline VALUE wrap(const std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>& x) {
|
|
65
|
+
return Rice::detail::protect(
|
|
66
|
+
rb_ary_new3,
|
|
63
67
|
5,
|
|
64
68
|
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
|
65
69
|
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
|
|
@@ -69,8 +73,9 @@ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch:
|
|
|
69
73
|
);
|
|
70
74
|
}
|
|
71
75
|
|
|
72
|
-
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t
|
|
73
|
-
return
|
|
76
|
+
inline VALUE wrap(const std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t>& x) {
|
|
77
|
+
return Rice::detail::protect(
|
|
78
|
+
rb_ary_new3,
|
|
74
79
|
4,
|
|
75
80
|
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
|
76
81
|
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
|
|
@@ -79,8 +84,9 @@ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_
|
|
|
79
84
|
);
|
|
80
85
|
}
|
|
81
86
|
|
|
82
|
-
inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t
|
|
83
|
-
return
|
|
87
|
+
inline VALUE wrap(const std::tuple<torch::Tensor, torch::Tensor, double, int64_t>& x) {
|
|
88
|
+
return Rice::detail::protect(
|
|
89
|
+
rb_ary_new3,
|
|
84
90
|
4,
|
|
85
91
|
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<0>(x)),
|
|
86
92
|
Rice::detail::To_Ruby<torch::Tensor>().convert(std::get<1>(x)),
|
|
@@ -89,16 +95,17 @@ inline VALUE wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
|
|
|
89
95
|
);
|
|
90
96
|
}
|
|
91
97
|
|
|
92
|
-
inline VALUE wrap(torch::TensorList x) {
|
|
93
|
-
auto a = rb_ary_new2
|
|
94
|
-
for (auto t : x) {
|
|
95
|
-
rb_ary_push
|
|
98
|
+
inline VALUE wrap(const torch::TensorList& x) {
|
|
99
|
+
auto a = Rice::detail::protect(rb_ary_new2, x.size());
|
|
100
|
+
for (const auto& t : x) {
|
|
101
|
+
Rice::detail::protect(rb_ary_push, a, Rice::detail::To_Ruby<torch::Tensor>().convert(t));
|
|
96
102
|
}
|
|
97
103
|
return a;
|
|
98
104
|
}
|
|
99
105
|
|
|
100
|
-
inline VALUE wrap(std::tuple<double, double
|
|
101
|
-
return
|
|
106
|
+
inline VALUE wrap(const std::tuple<double, double>& x) {
|
|
107
|
+
return Rice::detail::protect(
|
|
108
|
+
rb_ary_new3,
|
|
102
109
|
2,
|
|
103
110
|
Rice::detail::To_Ruby<double>().convert(std::get<0>(x)),
|
|
104
111
|
Rice::detail::To_Ruby<double>().convert(std::get<1>(x))
|
data/lib/torch/hub.rb
CHANGED
|
@@ -6,37 +6,17 @@ module Torch
|
|
|
6
6
|
end
|
|
7
7
|
|
|
8
8
|
def download_url_to_file(url, dst)
|
|
9
|
-
uri
|
|
10
|
-
tmp = nil
|
|
11
|
-
location = nil
|
|
9
|
+
require "open-uri"
|
|
12
10
|
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
request = Net::HTTP::Get.new(uri)
|
|
16
|
-
|
|
17
|
-
http.request(request) do |response|
|
|
18
|
-
case response
|
|
19
|
-
when Net::HTTPRedirection
|
|
20
|
-
location = response["location"]
|
|
21
|
-
when Net::HTTPSuccess
|
|
22
|
-
tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
|
|
23
|
-
File.open(tmp, "wb") do |f|
|
|
24
|
-
response.read_body do |chunk|
|
|
25
|
-
f.write(chunk)
|
|
26
|
-
end
|
|
27
|
-
end
|
|
28
|
-
else
|
|
29
|
-
raise Error, "Bad response"
|
|
30
|
-
end
|
|
31
|
-
end
|
|
32
|
-
end
|
|
11
|
+
uri = URI.parse(url)
|
|
12
|
+
raise "Invalid URL" unless uri.is_a?(URI::HTTP) # includes https
|
|
33
13
|
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
nil
|
|
14
|
+
puts "Downloading #{url}..."
|
|
15
|
+
uri.open(max_redirects: 10) do |download|
|
|
16
|
+
# TODO move file when possible
|
|
17
|
+
IO.copy_stream(download, dst.to_str)
|
|
39
18
|
end
|
|
19
|
+
nil
|
|
40
20
|
end
|
|
41
21
|
|
|
42
22
|
def load_state_dict_from_url(url, model_dir: nil)
|
data/lib/torch/nn/module.rb
CHANGED
data/lib/torch/nn/rnn_base.rb
CHANGED
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
|
@@ -279,7 +279,7 @@ module Torch
|
|
|
279
279
|
def self._make_tensor_class(dtype, cuda = false)
|
|
280
280
|
cls = Class.new
|
|
281
281
|
device = cuda ? "cuda" : "cpu"
|
|
282
|
-
cls.define_singleton_method(
|
|
282
|
+
cls.define_singleton_method(:new) do |*args|
|
|
283
283
|
if args.size == 1 && args.first.is_a?(Tensor)
|
|
284
284
|
args.first.send(dtype).to(device)
|
|
285
285
|
elsif args.size == 1 && args.first.is_a?(ByteStorage) && dtype == :uint8
|
|
@@ -347,7 +347,7 @@ module Torch
|
|
|
347
347
|
# from_blob does not own the data, so we need to keep
|
|
348
348
|
# a reference to it for duration of tensor
|
|
349
349
|
# can remove when passing pointer directly
|
|
350
|
-
tensor.instance_variable_set(
|
|
350
|
+
tensor.instance_variable_set(:@_numo_data, data)
|
|
351
351
|
tensor
|
|
352
352
|
end
|
|
353
353
|
|
|
@@ -426,11 +426,11 @@ module Torch
|
|
|
426
426
|
end
|
|
427
427
|
|
|
428
428
|
if options[:dtype].nil?
|
|
429
|
-
if data.all?
|
|
429
|
+
if data.all?(Integer)
|
|
430
430
|
options[:dtype] = :int64
|
|
431
431
|
elsif data.all? { |v| v == true || v == false }
|
|
432
432
|
options[:dtype] = :bool
|
|
433
|
-
elsif data.any?
|
|
433
|
+
elsif data.any?(Complex)
|
|
434
434
|
options[:dtype] = :complex64
|
|
435
435
|
end
|
|
436
436
|
end
|
metadata
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
|
2
2
|
name: torch-rb
|
|
3
3
|
version: !ruby/object:Gem::Version
|
|
4
|
-
version: 0.
|
|
4
|
+
version: 0.24.0
|
|
5
5
|
platform: ruby
|
|
6
6
|
authors:
|
|
7
7
|
- Andrew Kane
|
|
@@ -234,14 +234,14 @@ required_ruby_version: !ruby/object:Gem::Requirement
|
|
|
234
234
|
requirements:
|
|
235
235
|
- - ">="
|
|
236
236
|
- !ruby/object:Gem::Version
|
|
237
|
-
version: '3.
|
|
237
|
+
version: '3.3'
|
|
238
238
|
required_rubygems_version: !ruby/object:Gem::Requirement
|
|
239
239
|
requirements:
|
|
240
240
|
- - ">="
|
|
241
241
|
- !ruby/object:Gem::Version
|
|
242
242
|
version: '0'
|
|
243
243
|
requirements: []
|
|
244
|
-
rubygems_version: 4.0.
|
|
244
|
+
rubygems_version: 4.0.6
|
|
245
245
|
specification_version: 4
|
|
246
246
|
summary: Deep learning for Ruby, powered by LibTorch
|
|
247
247
|
test_files: []
|