torch-rb 0.22.2 → 0.23.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +12 -0
- data/README.md +3 -2
- data/codegen/native_functions.yaml +259 -103
- data/ext/torch/device.cpp +5 -2
- data/ext/torch/nn.cpp +2 -1
- data/ext/torch/ruby_arg_parser.cpp +12 -11
- data/ext/torch/ruby_arg_parser.h +12 -9
- data/ext/torch/templates.h +7 -7
- data/ext/torch/tensor.cpp +13 -14
- data/ext/torch/utils.h +6 -1
- data/lib/torch/device.rb +0 -1
- data/lib/torch/tensor.rb +1 -6
- data/lib/torch/version.rb +1 -1
- metadata +4 -4
data/ext/torch/device.cpp
CHANGED
|
@@ -8,7 +8,8 @@
|
|
|
8
8
|
#include "utils.h"
|
|
9
9
|
|
|
10
10
|
void init_device(Rice::Module& m) {
|
|
11
|
-
Rice::define_class_under<torch::Device>(m, "Device")
|
|
11
|
+
auto rb_cDevice = Rice::define_class_under<torch::Device>(m, "Device");
|
|
12
|
+
rb_cDevice
|
|
12
13
|
.define_constructor(Rice::Constructor<torch::Device, const std::string&>())
|
|
13
14
|
.define_method(
|
|
14
15
|
"_index",
|
|
@@ -28,8 +29,10 @@ void init_device(Rice::Module& m) {
|
|
|
28
29
|
return s.str();
|
|
29
30
|
})
|
|
30
31
|
.define_method(
|
|
31
|
-
"
|
|
32
|
+
"to_s",
|
|
32
33
|
[](torch::Device& self) {
|
|
33
34
|
return self.str();
|
|
34
35
|
});
|
|
36
|
+
|
|
37
|
+
THPDeviceClass = rb_cDevice.value();
|
|
35
38
|
}
|
data/ext/torch/nn.cpp
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
#include <optional>
|
|
1
2
|
#include <utility>
|
|
2
3
|
|
|
3
4
|
#include <torch/torch.h>
|
|
@@ -95,7 +96,7 @@ void init_nn(Rice::Module& m) {
|
|
|
95
96
|
"grad",
|
|
96
97
|
[](Parameter& self) {
|
|
97
98
|
auto grad = self.grad();
|
|
98
|
-
return grad.defined() ?
|
|
99
|
+
return grad.defined() ? std::optional<torch::Tensor>{grad} : std::nullopt;
|
|
99
100
|
})
|
|
100
101
|
// can't use grad=
|
|
101
102
|
// assignment methods fail with Ruby 3.0
|
|
@@ -7,6 +7,7 @@
|
|
|
7
7
|
|
|
8
8
|
#include "ruby_arg_parser.h"
|
|
9
9
|
|
|
10
|
+
VALUE THPDeviceClass = Qnil;
|
|
10
11
|
VALUE THPGeneratorClass = Qnil;
|
|
11
12
|
VALUE THPVariableClass = Qnil;
|
|
12
13
|
|
|
@@ -119,7 +120,7 @@ bool is_tensor_list(VALUE obj, int argnum, bool throw_error) {
|
|
|
119
120
|
VALUE iobj = rb_ary_entry(obj, idx);
|
|
120
121
|
if (!THPVariable_Check(iobj)) {
|
|
121
122
|
if (throw_error) {
|
|
122
|
-
|
|
123
|
+
throw Rice::Exception(rb_eArgError, "expected Tensor as element %d in argument %d, but got %s",
|
|
123
124
|
static_cast<int>(idx), argnum, rb_obj_classname(obj));
|
|
124
125
|
}
|
|
125
126
|
return false;
|
|
@@ -257,7 +258,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool {
|
|
|
257
258
|
case ParameterType::LAYOUT: return SYMBOL_P(obj);
|
|
258
259
|
case ParameterType::MEMORY_FORMAT: return false; // return THPMemoryFormat_Check(obj);
|
|
259
260
|
case ParameterType::QSCHEME: return false; // return THPQScheme_Check(obj);
|
|
260
|
-
case ParameterType::DEVICE: return RB_TYPE_P(obj, T_STRING)
|
|
261
|
+
case ParameterType::DEVICE: return RB_TYPE_P(obj, T_STRING) || THPDevice_Check(obj);
|
|
261
262
|
case ParameterType::STRING: return RB_TYPE_P(obj, T_STRING);
|
|
262
263
|
case ParameterType::SYM_INT: return is_int_or_symint(obj);
|
|
263
264
|
case ParameterType::SYM_INT_LIST: return is_int_or_symint_list(obj, size);
|
|
@@ -582,10 +583,10 @@ static void extra_args(const FunctionSignature& signature, ssize_t nargs) {
|
|
|
582
583
|
const long min_args = signature.min_args;
|
|
583
584
|
const long nargs_ = nargs;
|
|
584
585
|
if (min_args != max_pos_args) {
|
|
585
|
-
|
|
586
|
+
throw Rice::Exception(rb_eArgError, "%s() takes from %ld to %ld positional arguments but %ld were given",
|
|
586
587
|
signature.name.c_str(), min_args, max_pos_args, nargs_);
|
|
587
588
|
}
|
|
588
|
-
|
|
589
|
+
throw Rice::Exception(rb_eArgError, "%s() takes %ld positional argument%s but %ld %s given",
|
|
589
590
|
signature.name.c_str(),
|
|
590
591
|
max_pos_args, max_pos_args == 1 ? "" : "s",
|
|
591
592
|
nargs_, nargs == 1 ? "was" : "were");
|
|
@@ -607,7 +608,7 @@ static void missing_args(const FunctionSignature& signature, int idx) {
|
|
|
607
608
|
}
|
|
608
609
|
}
|
|
609
610
|
|
|
610
|
-
|
|
611
|
+
throw Rice::Exception(rb_eArgError, "%s() missing %d required positional argument%s: %s",
|
|
611
612
|
signature.name.c_str(),
|
|
612
613
|
num_missing,
|
|
613
614
|
num_missing == 1 ? "s" : "",
|
|
@@ -635,23 +636,23 @@ static void extra_kwargs(FunctionSignature& signature, VALUE kwargs, ssize_t num
|
|
|
635
636
|
key = rb_ary_entry(keys, 0);
|
|
636
637
|
|
|
637
638
|
if (!THPUtils_checkSymbol(key)) {
|
|
638
|
-
|
|
639
|
+
throw Rice::Exception(rb_eArgError, "keywords must be symbols, not %s", rb_obj_classname(key));
|
|
639
640
|
}
|
|
640
641
|
|
|
641
642
|
auto param_idx = find_param(signature, key);
|
|
642
643
|
if (param_idx < 0) {
|
|
643
|
-
|
|
644
|
+
throw Rice::Exception(rb_eArgError, "%s() got an unexpected keyword argument '%s'",
|
|
644
645
|
signature.name.c_str(), rb_id2name(rb_to_id(key)));
|
|
645
646
|
}
|
|
646
647
|
|
|
647
648
|
if (param_idx < num_pos_args) {
|
|
648
|
-
|
|
649
|
+
throw Rice::Exception(rb_eArgError, "%s() got multiple values for argument '%s'",
|
|
649
650
|
signature.name.c_str(), rb_id2name(rb_to_id(key)));
|
|
650
651
|
}
|
|
651
652
|
}
|
|
652
653
|
|
|
653
654
|
// this should never be hit
|
|
654
|
-
|
|
655
|
+
throw Rice::Exception(rb_eArgError, "invalid keyword arguments");
|
|
655
656
|
}
|
|
656
657
|
|
|
657
658
|
VALUE missing = Qundef;
|
|
@@ -739,12 +740,12 @@ bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[],
|
|
|
739
740
|
} else if (raise_exception) {
|
|
740
741
|
if (is_kwd) {
|
|
741
742
|
// foo(): argument 'other' must be str, not int
|
|
742
|
-
|
|
743
|
+
throw Rice::Exception(rb_eArgError, "%s(): argument '%s' must be %s, not %s",
|
|
743
744
|
name.c_str(), param.name.c_str(), param.type_name().c_str(),
|
|
744
745
|
rb_obj_classname(obj));
|
|
745
746
|
} else {
|
|
746
747
|
// foo(): argument 'other' (position 2) must be str, not int
|
|
747
|
-
|
|
748
|
+
throw Rice::Exception(rb_eArgError, "%s(): argument '%s' (position %ld) must be %s, not %s",
|
|
748
749
|
name.c_str(), param.name.c_str(), static_cast<long>(arg_pos + 1),
|
|
749
750
|
param.type_name().c_str(), rb_obj_classname(obj));
|
|
750
751
|
}
|
data/ext/torch/ruby_arg_parser.h
CHANGED
|
@@ -165,7 +165,7 @@ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
|
|
|
165
165
|
Check_Type(arg, T_ARRAY);
|
|
166
166
|
auto size = RARRAY_LEN(arg);
|
|
167
167
|
if (size != N) {
|
|
168
|
-
|
|
168
|
+
throw Rice::Exception(rb_eArgError, "expected array of %d elements but got %d", N, static_cast<int>(size));
|
|
169
169
|
}
|
|
170
170
|
for (int idx = 0; idx < size; idx++) {
|
|
171
171
|
VALUE obj = rb_ary_entry(arg, idx);
|
|
@@ -206,7 +206,7 @@ inline std::vector<int64_t> RubyArgs::intlistWithDefault(int i, std::vector<int6
|
|
|
206
206
|
if (FIXNUM_P(obj)) {
|
|
207
207
|
res[idx] = Rice::detail::From_Ruby<int64_t>().convert(obj);
|
|
208
208
|
} else {
|
|
209
|
-
|
|
209
|
+
throw Rice::Exception(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
|
|
210
210
|
signature.name.c_str(), signature.params[i].name.c_str(),
|
|
211
211
|
signature.params[i].type_name().c_str(), rb_obj_classname(obj), idx + 1);
|
|
212
212
|
}
|
|
@@ -270,7 +270,7 @@ inline ScalarType RubyArgs::scalartype(int i) {
|
|
|
270
270
|
|
|
271
271
|
auto it = dtype_map.find(args[i]);
|
|
272
272
|
if (it == dtype_map.end()) {
|
|
273
|
-
|
|
273
|
+
throw Rice::Exception(rb_eArgError, "invalid dtype: %s", rb_id2name(rb_to_id(args[i])));
|
|
274
274
|
}
|
|
275
275
|
return it->second;
|
|
276
276
|
}
|
|
@@ -321,7 +321,7 @@ inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
|
|
|
321
321
|
if (FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj)) {
|
|
322
322
|
res[idx] = Rice::detail::From_Ruby<double>().convert(obj);
|
|
323
323
|
} else {
|
|
324
|
-
|
|
324
|
+
throw Rice::Exception(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
|
|
325
325
|
signature.name.c_str(), signature.params[i].name.c_str(),
|
|
326
326
|
signature.params[i].type_name().c_str(), rb_obj_classname(obj), idx + 1);
|
|
327
327
|
}
|
|
@@ -338,7 +338,7 @@ inline at::Layout RubyArgs::layout(int i) {
|
|
|
338
338
|
|
|
339
339
|
auto it = layout_map.find(args[i]);
|
|
340
340
|
if (it == layout_map.end()) {
|
|
341
|
-
|
|
341
|
+
throw Rice::Exception(rb_eArgError, "invalid layout: %s", rb_id2name(rb_to_id(args[i])));
|
|
342
342
|
}
|
|
343
343
|
return it->second;
|
|
344
344
|
}
|
|
@@ -357,8 +357,11 @@ inline at::Device RubyArgs::device(int i) {
|
|
|
357
357
|
if (NIL_P(args[i])) {
|
|
358
358
|
return at::Device("cpu");
|
|
359
359
|
}
|
|
360
|
-
|
|
361
|
-
|
|
360
|
+
if (RB_TYPE_P(args[i], T_STRING)) {
|
|
361
|
+
const std::string &device_str = THPUtils_unpackString(args[i]);
|
|
362
|
+
return at::Device(device_str);
|
|
363
|
+
}
|
|
364
|
+
return Rice::detail::From_Ruby<at::Device>().convert(args[i]);
|
|
362
365
|
}
|
|
363
366
|
|
|
364
367
|
inline at::Device RubyArgs::deviceWithDefault(int i, const at::Device& default_device) {
|
|
@@ -466,7 +469,7 @@ struct RubyArgParser {
|
|
|
466
469
|
template<int N>
|
|
467
470
|
inline RubyArgs parse(VALUE self, int argc, VALUE* argv, ParsedArgs<N> &dst) {
|
|
468
471
|
if (N < max_args) {
|
|
469
|
-
|
|
472
|
+
throw Rice::Exception(rb_eArgError, "RubyArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)", static_cast<int>(max_args), N);
|
|
470
473
|
}
|
|
471
474
|
return raw_parse(self, argc, argv, dst.args);
|
|
472
475
|
}
|
|
@@ -490,7 +493,7 @@ struct RubyArgParser {
|
|
|
490
493
|
print_error(self, args, kwargs, parsed_args);
|
|
491
494
|
|
|
492
495
|
// TODO better message
|
|
493
|
-
|
|
496
|
+
throw Rice::Exception(rb_eArgError, "No matching signatures");
|
|
494
497
|
}
|
|
495
498
|
|
|
496
499
|
void print_error(VALUE self, VALUE args, VALUE kwargs, VALUE parsed_args[]) {
|
data/ext/torch/templates.h
CHANGED
|
@@ -55,7 +55,7 @@ namespace Rice::detail {
|
|
|
55
55
|
explicit To_Ruby(Arg* arg) : arg_(arg) { }
|
|
56
56
|
|
|
57
57
|
VALUE convert(c10::complex<T> const& x) {
|
|
58
|
-
return rb_dbl_complex_new
|
|
58
|
+
return protect(rb_dbl_complex_new, x.real(), x.imag());
|
|
59
59
|
}
|
|
60
60
|
|
|
61
61
|
private:
|
|
@@ -69,11 +69,11 @@ namespace Rice::detail {
|
|
|
69
69
|
|
|
70
70
|
explicit From_Ruby(Arg* arg) : arg_(arg) { }
|
|
71
71
|
|
|
72
|
-
|
|
72
|
+
double is_convertible(VALUE value) { return Convertible::Exact; }
|
|
73
73
|
|
|
74
74
|
c10::complex<T> convert(VALUE x) {
|
|
75
|
-
VALUE real = rb_funcall
|
|
76
|
-
VALUE imag = rb_funcall
|
|
75
|
+
VALUE real = protect(rb_funcall, x, rb_intern("real"), 0);
|
|
76
|
+
VALUE imag = protect(rb_funcall, x, rb_intern("imag"), 0);
|
|
77
77
|
return c10::complex<T>(From_Ruby<T>().convert(real), From_Ruby<T>().convert(imag));
|
|
78
78
|
}
|
|
79
79
|
|
|
@@ -93,7 +93,7 @@ namespace Rice::detail {
|
|
|
93
93
|
|
|
94
94
|
explicit From_Ruby(Arg* arg) : arg_(arg) { }
|
|
95
95
|
|
|
96
|
-
|
|
96
|
+
double is_convertible(VALUE value) { return Convertible::Exact; }
|
|
97
97
|
|
|
98
98
|
FanModeType convert(VALUE x) {
|
|
99
99
|
auto s = String(x).str();
|
|
@@ -122,7 +122,7 @@ namespace Rice::detail {
|
|
|
122
122
|
|
|
123
123
|
explicit From_Ruby(Arg* arg) : arg_(arg) { }
|
|
124
124
|
|
|
125
|
-
|
|
125
|
+
double is_convertible(VALUE value) { return Convertible::Exact; }
|
|
126
126
|
|
|
127
127
|
NonlinearityType convert(VALUE x) {
|
|
128
128
|
auto s = String(x).str();
|
|
@@ -169,7 +169,7 @@ namespace Rice::detail {
|
|
|
169
169
|
|
|
170
170
|
explicit From_Ruby(Arg* arg) : arg_(arg) { }
|
|
171
171
|
|
|
172
|
-
|
|
172
|
+
double is_convertible(VALUE value) { return Convertible::Exact; }
|
|
173
173
|
|
|
174
174
|
Scalar convert(VALUE x) {
|
|
175
175
|
if (FIXNUM_P(x)) {
|
data/ext/torch/tensor.cpp
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
#include <optional>
|
|
1
2
|
#include <string>
|
|
2
3
|
#include <vector>
|
|
3
4
|
|
|
@@ -28,15 +29,15 @@ Array flat_data(Tensor& tensor) {
|
|
|
28
29
|
Rice::Class rb_cTensor;
|
|
29
30
|
|
|
30
31
|
std::vector<TensorIndex> index_vector(Array a) {
|
|
31
|
-
Object obj;
|
|
32
|
-
|
|
33
32
|
std::vector<TensorIndex> indices;
|
|
34
33
|
indices.reserve(a.size());
|
|
35
34
|
|
|
36
35
|
for (long i = 0; i < a.size(); i++) {
|
|
37
|
-
obj
|
|
36
|
+
Object obj(a[i]);
|
|
38
37
|
|
|
39
|
-
if (obj.
|
|
38
|
+
if (obj.is_nil()) {
|
|
39
|
+
indices.push_back(torch::indexing::None);
|
|
40
|
+
} else if (obj.is_instance_of(rb_cInteger)) {
|
|
40
41
|
indices.push_back(Rice::detail::From_Ruby<int64_t>().convert(obj.value()));
|
|
41
42
|
} else if (obj.is_instance_of(rb_cRange)) {
|
|
42
43
|
torch::optional<c10::SymInt> start_index = torch::nullopt;
|
|
@@ -64,12 +65,10 @@ std::vector<TensorIndex> index_vector(Array a) {
|
|
|
64
65
|
indices.push_back(torch::indexing::Slice(start_index, stop_index));
|
|
65
66
|
} else if (obj.is_instance_of(rb_cTensor)) {
|
|
66
67
|
indices.push_back(Rice::detail::From_Ruby<Tensor>().convert(obj.value()));
|
|
67
|
-
} else if (obj.
|
|
68
|
-
indices.push_back(torch::indexing::None);
|
|
69
|
-
} else if (obj == Rice::True || obj == Rice::False) {
|
|
68
|
+
} else if (obj.value() == Qtrue || obj.value() == Qfalse) {
|
|
70
69
|
indices.push_back(Rice::detail::From_Ruby<bool>().convert(obj.value()));
|
|
71
70
|
} else {
|
|
72
|
-
throw Rice::Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
|
|
71
|
+
throw Rice::Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj.value()));
|
|
73
72
|
}
|
|
74
73
|
}
|
|
75
74
|
return indices;
|
|
@@ -102,7 +101,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
|
102
101
|
add_tensor_functions(rb_cTensor);
|
|
103
102
|
THPVariableClass = rb_cTensor.value();
|
|
104
103
|
|
|
105
|
-
rb_define_method(rb_cTensor, "backward",
|
|
104
|
+
rb_define_method(rb_cTensor, "backward", tensor__backward, -1);
|
|
106
105
|
|
|
107
106
|
rb_cTensor
|
|
108
107
|
.define_method("cuda?", [](Tensor& self) { return self.is_cuda(); })
|
|
@@ -168,7 +167,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
|
168
167
|
"grad",
|
|
169
168
|
[](Tensor& self) {
|
|
170
169
|
auto grad = self.grad();
|
|
171
|
-
return grad.defined() ?
|
|
170
|
+
return grad.defined() ? std::optional<torch::Tensor>{grad} : std::nullopt;
|
|
172
171
|
})
|
|
173
172
|
// can't use grad=
|
|
174
173
|
// assignment methods fail with Ruby 3.0
|
|
@@ -184,15 +183,15 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
|
184
183
|
|
|
185
184
|
// TODO support sparse grad
|
|
186
185
|
if (!grad.options().type_equal(self.options())) {
|
|
187
|
-
|
|
186
|
+
throw Rice::Exception(rb_eArgError, "assigned grad has data of a different type");
|
|
188
187
|
}
|
|
189
188
|
|
|
190
189
|
if (self.is_cuda() && grad.get_device() != self.get_device()) {
|
|
191
|
-
|
|
190
|
+
throw Rice::Exception(rb_eArgError, "assigned grad has data located on a different device");
|
|
192
191
|
}
|
|
193
192
|
|
|
194
193
|
if (!self.sizes().equals(grad.sizes())) {
|
|
195
|
-
|
|
194
|
+
throw Rice::Exception(rb_eArgError, "assigned grad has data of a different size");
|
|
196
195
|
}
|
|
197
196
|
|
|
198
197
|
self.mutable_grad() = grad;
|
|
@@ -215,7 +214,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
|
215
214
|
return s.str();
|
|
216
215
|
})
|
|
217
216
|
.define_method(
|
|
218
|
-
"
|
|
217
|
+
"device",
|
|
219
218
|
[](Tensor& self) {
|
|
220
219
|
return self.device();
|
|
221
220
|
})
|
data/ext/torch/utils.h
CHANGED
|
@@ -8,7 +8,7 @@
|
|
|
8
8
|
#include <rice/stl.hpp>
|
|
9
9
|
|
|
10
10
|
static_assert(
|
|
11
|
-
TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR ==
|
|
11
|
+
TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 10,
|
|
12
12
|
"Incompatible LibTorch version"
|
|
13
13
|
);
|
|
14
14
|
|
|
@@ -20,6 +20,7 @@ inline void handle_global_error(const torch::Error& ex) {
|
|
|
20
20
|
|
|
21
21
|
// keep THP prefix for now to make it easier to compare code
|
|
22
22
|
|
|
23
|
+
extern VALUE THPDeviceClass;
|
|
23
24
|
extern VALUE THPGeneratorClass;
|
|
24
25
|
extern VALUE THPVariableClass;
|
|
25
26
|
|
|
@@ -48,6 +49,10 @@ inline bool THPUtils_checkScalar(VALUE obj) {
|
|
|
48
49
|
return FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj) || RB_TYPE_P(obj, T_COMPLEX);
|
|
49
50
|
}
|
|
50
51
|
|
|
52
|
+
inline bool THPDevice_Check(VALUE obj) {
|
|
53
|
+
return rb_obj_is_kind_of(obj, THPDeviceClass);
|
|
54
|
+
}
|
|
55
|
+
|
|
51
56
|
inline bool THPGenerator_Check(VALUE obj) {
|
|
52
57
|
return rb_obj_is_kind_of(obj, THPGeneratorClass);
|
|
53
58
|
}
|
data/lib/torch/device.rb
CHANGED
data/lib/torch/tensor.rb
CHANGED
|
@@ -115,7 +115,7 @@ module Torch
|
|
|
115
115
|
if numel != 1
|
|
116
116
|
raise Error, "only one element tensors can be converted to Ruby scalars"
|
|
117
117
|
end
|
|
118
|
-
to_a.first
|
|
118
|
+
to_a.flatten.first
|
|
119
119
|
end
|
|
120
120
|
|
|
121
121
|
def to_i
|
|
@@ -210,10 +210,5 @@ module Torch
|
|
|
210
210
|
raise TypeError, "#{self.class} can't be coerced into #{other.class}"
|
|
211
211
|
end
|
|
212
212
|
end
|
|
213
|
-
|
|
214
|
-
# TODO return Device instead of String in 0.19.0
|
|
215
|
-
def device
|
|
216
|
-
_device._str
|
|
217
|
-
end
|
|
218
213
|
end
|
|
219
214
|
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
|
2
2
|
name: torch-rb
|
|
3
3
|
version: !ruby/object:Gem::Version
|
|
4
|
-
version: 0.
|
|
4
|
+
version: 0.23.1
|
|
5
5
|
platform: ruby
|
|
6
6
|
authors:
|
|
7
7
|
- Andrew Kane
|
|
@@ -15,14 +15,14 @@ dependencies:
|
|
|
15
15
|
requirements:
|
|
16
16
|
- - ">="
|
|
17
17
|
- !ruby/object:Gem::Version
|
|
18
|
-
version: '4.
|
|
18
|
+
version: '4.8'
|
|
19
19
|
type: :runtime
|
|
20
20
|
prerelease: false
|
|
21
21
|
version_requirements: !ruby/object:Gem::Requirement
|
|
22
22
|
requirements:
|
|
23
23
|
- - ">="
|
|
24
24
|
- !ruby/object:Gem::Version
|
|
25
|
-
version: '4.
|
|
25
|
+
version: '4.8'
|
|
26
26
|
email: andrew@ankane.org
|
|
27
27
|
executables: []
|
|
28
28
|
extensions:
|
|
@@ -241,7 +241,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
|
241
241
|
- !ruby/object:Gem::Version
|
|
242
242
|
version: '0'
|
|
243
243
|
requirements: []
|
|
244
|
-
rubygems_version:
|
|
244
|
+
rubygems_version: 4.0.3
|
|
245
245
|
specification_version: 4
|
|
246
246
|
summary: Deep learning for Ruby, powered by LibTorch
|
|
247
247
|
test_files: []
|