torch-rb 0.8.2 → 0.9.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +22 -0
- data/README.md +5 -4
- data/codegen/generate_functions.rb +11 -4
- data/codegen/native_functions.yaml +1103 -373
- data/ext/torch/backends.cpp +2 -2
- data/ext/torch/nn.cpp +4 -1
- data/ext/torch/ruby_arg_parser.cpp +2 -2
- data/ext/torch/ruby_arg_parser.h +19 -5
- data/ext/torch/templates.h +30 -33
- data/ext/torch/tensor.cpp +33 -33
- data/ext/torch/torch.cpp +38 -28
- data/ext/torch/utils.h +0 -6
- data/lib/torch/inspector.rb +1 -1
- data/lib/torch/nn/functional.rb +1 -1
- data/lib/torch/nn/functional_attention.rb +1 -1
- data/lib/torch/nn/module.rb +28 -0
- data/lib/torch/nn/parameter.rb +9 -0
- data/lib/torch/nn/transformer_decoder_layer.rb +1 -1
- data/lib/torch/nn/utils.rb +1 -5
- data/lib/torch/tensor.rb +22 -8
- data/lib/torch/utils/data/data_loader.rb +1 -1
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +14 -48
- metadata +3 -3
data/ext/torch/backends.cpp
CHANGED
@@ -7,11 +7,11 @@
|
|
7
7
|
// Defines the Torch::Backends module with two submodules, OpenMP and MKL.
// Each submodule gets a torch::Error handler (handle_error) and a singleton
// method `available?` bound to torch::hasOpenMP / torch::hasMKL respectively.
void init_backends(Rice::Module& m) {
  auto backends = Rice::define_module_under(m, "Backends");

  // Torch::Backends::OpenMP.available?
  auto openmp = Rice::define_module_under(backends, "OpenMP");
  openmp.add_handler<torch::Error>(handle_error);
  openmp.define_singleton_function("available?", &torch::hasOpenMP);

  // Torch::Backends::MKL.available?
  auto mkl = Rice::define_module_under(backends, "MKL");
  mkl.add_handler<torch::Error>(handle_error);
  mkl.define_singleton_function("available?", &torch::hasMKL);
}
|
data/ext/torch/nn.cpp
CHANGED
@@ -98,8 +98,11 @@ void init_nn(Rice::Module& m) {
|
|
98
98
|
auto grad = self.grad();
|
99
99
|
return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
|
100
100
|
})
|
101
|
+
// can't use grad=
|
102
|
+
// assignment methods fail with Ruby 3.0
|
103
|
+
// TODO add checks like Tensor
|
101
104
|
.define_method(
|
102
|
-
"
|
105
|
+
"_set_grad",
|
103
106
|
[](Parameter& self, torch::Tensor& grad) {
|
104
107
|
self.mutable_grad() = grad;
|
105
108
|
})
|
@@ -472,12 +472,12 @@ static void extra_kwargs(FunctionSignature& signature, VALUE kwargs, ssize_t num
|
|
472
472
|
auto param_idx = find_param(signature, key);
|
473
473
|
if (param_idx < 0) {
|
474
474
|
rb_raise(rb_eArgError, "%s() got an unexpected keyword argument '%s'",
|
475
|
-
signature.name.c_str(),
|
475
|
+
signature.name.c_str(), rb_id2name(rb_to_id(key)));
|
476
476
|
}
|
477
477
|
|
478
478
|
if (param_idx < num_pos_args) {
|
479
479
|
rb_raise(rb_eArgError, "%s() got multiple values for argument '%s'",
|
480
|
-
signature.name.c_str(),
|
480
|
+
signature.name.c_str(), rb_id2name(rb_to_id(key)));
|
481
481
|
}
|
482
482
|
}
|
483
483
|
|
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -75,7 +75,7 @@ struct RubyArgs {
|
|
75
75
|
int idx;
|
76
76
|
|
77
77
|
inline at::Tensor tensor(int i);
|
78
|
-
inline
|
78
|
+
inline c10::optional<at::Tensor> optionalTensor(int i);
|
79
79
|
inline at::Scalar scalar(int i);
|
80
80
|
// inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
|
81
81
|
inline std::vector<at::Scalar> scalarlist(int i);
|
@@ -109,6 +109,9 @@ struct RubyArgs {
|
|
109
109
|
// inline at::QScheme toQScheme(int i);
|
110
110
|
inline std::string string(int i);
|
111
111
|
inline c10::optional<std::string> stringOptional(int i);
|
112
|
+
inline c10::string_view stringView(int i);
|
113
|
+
// inline c10::string_view stringViewWithDefault(int i, const c10::string_view default_str);
|
114
|
+
inline c10::optional<c10::string_view> stringViewOptional(int i);
|
112
115
|
// inline PyObject* pyobject(int i);
|
113
116
|
inline int64_t toInt64(int i);
|
114
117
|
// inline int64_t toInt64WithDefault(int i, int64_t default_int);
|
@@ -125,8 +128,8 @@ inline at::Tensor RubyArgs::tensor(int i) {
|
|
125
128
|
return Rice::detail::From_Ruby<torch::Tensor>().convert(args[i]);
|
126
129
|
}
|
127
130
|
|
128
|
-
inline
|
129
|
-
if (NIL_P(args[i])) return
|
131
|
+
// Converts args[i] to a tensor, mapping a Ruby nil to c10::nullopt.
inline c10::optional<at::Tensor> RubyArgs::optionalTensor(int i) {
  if (!NIL_P(args[i])) {
    return tensor(i);
  }
  return c10::nullopt;
}
|
132
135
|
|
@@ -232,7 +235,7 @@ inline ScalarType RubyArgs::scalartype(int i) {
|
|
232
235
|
|
233
236
|
auto it = dtype_map.find(args[i]);
|
234
237
|
if (it == dtype_map.end()) {
|
235
|
-
rb_raise(rb_eArgError, "invalid dtype: %s",
|
238
|
+
rb_raise(rb_eArgError, "invalid dtype: %s", rb_id2name(rb_to_id(args[i])));
|
236
239
|
}
|
237
240
|
return it->second;
|
238
241
|
}
|
@@ -290,7 +293,7 @@ inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
|
|
290
293
|
|
291
294
|
auto it = layout_map.find(args[i]);
|
292
295
|
if (it == layout_map.end()) {
|
293
|
-
rb_raise(rb_eArgError, "invalid layout: %s",
|
296
|
+
rb_raise(rb_eArgError, "invalid layout: %s", rb_id2name(rb_to_id(args[i])));
|
294
297
|
}
|
295
298
|
return it->second;
|
296
299
|
}
|
@@ -322,6 +325,17 @@ inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
|
|
322
325
|
return Rice::detail::From_Ruby<std::string>().convert(args[i]);
|
323
326
|
}
|
324
327
|
|
328
|
+
// string_view does not own data
|
329
|
+
inline c10::string_view RubyArgs::stringView(int i) {
|
330
|
+
return c10::string_view(RSTRING_PTR(args[i]), RSTRING_LEN(args[i]));
|
331
|
+
}
|
332
|
+
|
333
|
+
// string_view does not own data
|
334
|
+
inline c10::optional<c10::string_view> RubyArgs::stringViewOptional(int i) {
|
335
|
+
if (NIL_P(args[i])) return c10::nullopt;
|
336
|
+
return c10::string_view(RSTRING_PTR(args[i]), RSTRING_LEN(args[i]));
|
337
|
+
}
|
338
|
+
|
325
339
|
inline int64_t RubyArgs::toInt64(int i) {
|
326
340
|
if (NIL_P(args[i])) return signature.params[i].default_int;
|
327
341
|
return Rice::detail::From_Ruby<int64_t>().convert(args[i]);
|
data/ext/torch/templates.h
CHANGED
@@ -41,23 +41,39 @@ using torch::nn::init::NonlinearityType;
|
|
41
41
|
#define RETURN_NIL \
|
42
42
|
return Qnil;
|
43
43
|
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
}
|
44
|
+
namespace Rice::detail
|
45
|
+
{
|
46
|
+
template<typename T>
|
47
|
+
struct Type<c10::complex<T>>
|
48
|
+
{
|
49
|
+
static bool verify()
|
50
|
+
{
|
51
|
+
return true;
|
53
52
|
}
|
54
|
-
|
55
|
-
|
53
|
+
};
|
54
|
+
|
55
|
+
template<typename T>
|
56
|
+
class To_Ruby<c10::complex<T>>
|
57
|
+
{
|
58
|
+
public:
|
59
|
+
VALUE convert(c10::complex<T> const& x)
|
60
|
+
{
|
61
|
+
return rb_dbl_complex_new(x.real(), x.imag());
|
56
62
|
}
|
57
|
-
|
58
|
-
|
63
|
+
};
|
64
|
+
|
65
|
+
template<typename T>
|
66
|
+
class From_Ruby<c10::complex<T>>
|
67
|
+
{
|
68
|
+
public:
|
69
|
+
c10::complex<T> convert(VALUE x)
|
70
|
+
{
|
71
|
+
VALUE real = rb_funcall(x, rb_intern("real"), 0);
|
72
|
+
VALUE imag = rb_funcall(x, rb_intern("imag"), 0);
|
73
|
+
return c10::complex<T>(From_Ruby<T>().convert(real), From_Ruby<T>().convert(imag));
|
59
74
|
}
|
60
|
-
};
|
75
|
+
};
|
76
|
+
}
|
61
77
|
|
62
78
|
namespace Rice::detail
|
63
79
|
{
|
@@ -131,25 +147,6 @@ namespace Rice::detail
|
|
131
147
|
}
|
132
148
|
};
|
133
149
|
|
134
|
-
template<>
|
135
|
-
struct Type<OptionalTensor>
|
136
|
-
{
|
137
|
-
static bool verify()
|
138
|
-
{
|
139
|
-
return true;
|
140
|
-
}
|
141
|
-
};
|
142
|
-
|
143
|
-
template<>
|
144
|
-
class From_Ruby<OptionalTensor>
|
145
|
-
{
|
146
|
-
public:
|
147
|
-
OptionalTensor convert(VALUE x)
|
148
|
-
{
|
149
|
-
return OptionalTensor(x);
|
150
|
-
}
|
151
|
-
};
|
152
|
-
|
153
150
|
template<>
|
154
151
|
struct Type<Scalar>
|
155
152
|
{
|
data/ext/torch/tensor.cpp
CHANGED
@@ -10,28 +10,6 @@
|
|
10
10
|
using namespace Rice;
|
11
11
|
using torch::indexing::TensorIndex;
|
12
12
|
|
13
|
-
namespace Rice::detail
|
14
|
-
{
|
15
|
-
template<typename T>
|
16
|
-
struct Type<c10::complex<T>>
|
17
|
-
{
|
18
|
-
static bool verify()
|
19
|
-
{
|
20
|
-
return true;
|
21
|
-
}
|
22
|
-
};
|
23
|
-
|
24
|
-
template<typename T>
|
25
|
-
class To_Ruby<c10::complex<T>>
|
26
|
-
{
|
27
|
-
public:
|
28
|
-
VALUE convert(c10::complex<T> const& x)
|
29
|
-
{
|
30
|
-
return rb_dbl_complex_new(x.real(), x.imag());
|
31
|
-
}
|
32
|
-
};
|
33
|
-
}
|
34
|
-
|
35
13
|
template<typename T>
|
36
14
|
Array flat_data(Tensor& tensor) {
|
37
15
|
Tensor view = tensor.reshape({tensor.numel()});
|
@@ -107,7 +85,7 @@ static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
|
|
107
85
|
ParsedArgs<4> parsed_args;
|
108
86
|
auto _r = parser.parse(self_, argc, argv, parsed_args);
|
109
87
|
// _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
|
110
|
-
auto dispatch__backward = [](const Tensor & self, TensorList inputs, const
|
88
|
+
auto dispatch__backward = [](const Tensor & self, TensorList inputs, const c10::optional<at::Tensor> & gradient, c10::optional<bool> retain_graph, bool create_graph) -> void {
|
111
89
|
// in future, release GVL
|
112
90
|
self._backward(inputs, gradient, retain_graph, create_graph);
|
113
91
|
};
|
@@ -125,13 +103,13 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
125
103
|
rb_define_method(rb_cTensor, "backward", (VALUE (*)(...)) tensor__backward, -1);
|
126
104
|
|
127
105
|
rb_cTensor
|
128
|
-
.define_method("cuda?", &
|
129
|
-
.define_method("sparse?", &
|
130
|
-
.define_method("quantized?", &
|
131
|
-
.define_method("dim", &
|
132
|
-
.define_method("numel", &
|
133
|
-
.define_method("element_size", &
|
134
|
-
.define_method("requires_grad", &
|
106
|
+
.define_method("cuda?", [](Tensor& self) { return self.is_cuda(); })
|
107
|
+
.define_method("sparse?", [](Tensor& self) { return self.is_sparse(); })
|
108
|
+
.define_method("quantized?", [](Tensor& self) { return self.is_quantized(); })
|
109
|
+
.define_method("dim", [](Tensor& self) { return self.dim(); })
|
110
|
+
.define_method("numel", [](Tensor& self) { return self.numel(); })
|
111
|
+
.define_method("element_size", [](Tensor& self) { return self.element_size(); })
|
112
|
+
.define_method("requires_grad", [](Tensor& self) { return self.requires_grad(); })
|
135
113
|
.define_method(
|
136
114
|
"_size",
|
137
115
|
[](Tensor& self, int64_t dim) {
|
@@ -189,9 +167,31 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
189
167
|
auto grad = self.grad();
|
190
168
|
return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
|
191
169
|
})
|
170
|
+
// can't use grad=
|
171
|
+
// assignment methods fail with Ruby 3.0
|
192
172
|
.define_method(
|
193
|
-
"
|
194
|
-
[](Tensor& self,
|
173
|
+
"_set_grad",
|
174
|
+
[](Tensor& self, Rice::Object value) {
|
175
|
+
if (value.is_nil()) {
|
176
|
+
self.mutable_grad().reset();
|
177
|
+
return;
|
178
|
+
}
|
179
|
+
|
180
|
+
const auto& grad = Rice::detail::From_Ruby<torch::Tensor>().convert(value.value());
|
181
|
+
|
182
|
+
// TODO support sparse grad
|
183
|
+
if (!grad.options().type_equal(self.options())) {
|
184
|
+
rb_raise(rb_eArgError, "assigned grad has data of a different type");
|
185
|
+
}
|
186
|
+
|
187
|
+
if (self.is_cuda() && grad.get_device() != self.get_device()) {
|
188
|
+
rb_raise(rb_eArgError, "assigned grad has data located on a different device");
|
189
|
+
}
|
190
|
+
|
191
|
+
if (!self.sizes().equals(grad.sizes())) {
|
192
|
+
rb_raise(rb_eArgError, "assigned grad has data of a different size");
|
193
|
+
}
|
194
|
+
|
195
195
|
self.mutable_grad() = grad;
|
196
196
|
})
|
197
197
|
.define_method(
|
@@ -281,7 +281,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
281
281
|
})
|
282
282
|
.define_method(
|
283
283
|
"_to",
|
284
|
-
[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
|
284
|
+
[](Tensor& self, torch::Device& device, int dtype, bool non_blocking, bool copy) {
|
285
285
|
return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
|
286
286
|
});
|
287
287
|
|
data/ext/torch/torch.cpp
CHANGED
@@ -6,6 +6,23 @@
|
|
6
6
|
#include "templates.h"
|
7
7
|
#include "utils.h"
|
8
8
|
|
9
|
+
template<typename T>
|
10
|
+
torch::Tensor make_tensor(Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
11
|
+
std::vector<T> vec;
|
12
|
+
for (long i = 0; i < a.size(); i++) {
|
13
|
+
vec.push_back(Rice::detail::From_Ruby<T>().convert(a[i].value()));
|
14
|
+
}
|
15
|
+
|
16
|
+
// hack for requires_grad error
|
17
|
+
auto requires_grad = options.requires_grad();
|
18
|
+
torch::Tensor t = torch::tensor(vec, options.requires_grad(c10::nullopt));
|
19
|
+
if (requires_grad) {
|
20
|
+
t.set_requires_grad(true);
|
21
|
+
}
|
22
|
+
|
23
|
+
return t.reshape(size);
|
24
|
+
}
|
25
|
+
|
9
26
|
void init_torch(Rice::Module& m) {
|
10
27
|
m.add_handler<torch::Error>(handle_error);
|
11
28
|
add_torch_functions(m);
|
@@ -61,35 +78,28 @@ void init_torch(Rice::Module& m) {
|
|
61
78
|
"_tensor",
|
62
79
|
[](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
63
80
|
auto dtype = options.dtype();
|
64
|
-
torch::
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
81
|
+
if (dtype == torch::kByte) {
|
82
|
+
return make_tensor<uint8_t>(a, size, options);
|
83
|
+
} else if (dtype == torch::kChar) {
|
84
|
+
return make_tensor<int8_t>(a, size, options);
|
85
|
+
} else if (dtype == torch::kShort) {
|
86
|
+
return make_tensor<int16_t>(a, size, options);
|
87
|
+
} else if (dtype == torch::kInt) {
|
88
|
+
return make_tensor<int32_t>(a, size, options);
|
89
|
+
} else if (dtype == torch::kLong) {
|
90
|
+
return make_tensor<int64_t>(a, size, options);
|
91
|
+
} else if (dtype == torch::kFloat) {
|
92
|
+
return make_tensor<float>(a, size, options);
|
93
|
+
} else if (dtype == torch::kDouble) {
|
94
|
+
return make_tensor<double>(a, size, options);
|
95
|
+
} else if (dtype == torch::kBool) {
|
96
|
+
return make_tensor<uint8_t>(a, size, options);
|
97
|
+
} else if (dtype == torch::kComplexFloat) {
|
98
|
+
return make_tensor<c10::complex<float>>(a, size, options);
|
99
|
+
} else if (dtype == torch::kComplexDouble) {
|
100
|
+
return make_tensor<c10::complex<double>>(a, size, options);
|
80
101
|
} else {
|
81
|
-
std::
|
82
|
-
for (long i = 0; i < a.size(); i++) {
|
83
|
-
vec.push_back(Rice::detail::From_Ruby<float>().convert(a[i].value()));
|
84
|
-
}
|
85
|
-
// hack for requires_grad error
|
86
|
-
if (options.requires_grad()) {
|
87
|
-
t = torch::tensor(vec, options.requires_grad(c10::nullopt));
|
88
|
-
t.set_requires_grad(true);
|
89
|
-
} else {
|
90
|
-
t = torch::tensor(vec, options);
|
91
|
-
}
|
102
|
+
throw std::runtime_error("Unsupported type");
|
92
103
|
}
|
93
|
-
return t.reshape(size);
|
94
104
|
});
|
95
105
|
}
|
data/ext/torch/utils.h
CHANGED
@@ -16,12 +16,6 @@ inline VALUE THPUtils_internSymbol(const std::string& str) {
|
|
16
16
|
return Rice::Symbol(str);
|
17
17
|
}
|
18
18
|
|
19
|
-
inline std::string THPUtils_unpackSymbol(VALUE obj) {
|
20
|
-
Check_Type(obj, T_SYMBOL);
|
21
|
-
obj = rb_funcall(obj, rb_intern("to_s"), 0);
|
22
|
-
return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
|
23
|
-
}
|
24
|
-
|
25
19
|
inline std::string THPUtils_unpackString(VALUE obj) {
|
26
20
|
Check_Type(obj, T_STRING);
|
27
21
|
return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
|
data/lib/torch/inspector.rb
CHANGED
@@ -247,7 +247,7 @@ module Torch
|
|
247
247
|
# length includes spaces and comma between elements
|
248
248
|
element_length = formatter.width + 2
|
249
249
|
elements_per_line = [1, ((PRINT_OPTS[:linewidth] - indent) / element_length.to_f).floor.to_i].max
|
250
|
-
|
250
|
+
_char_per_line = element_length * elements_per_line
|
251
251
|
|
252
252
|
if summarize && slf.size(0) > 2 * PRINT_OPTS[:edgeitems]
|
253
253
|
data = (
|
data/lib/torch/nn/functional.rb
CHANGED
@@ -571,7 +571,7 @@ module Torch
|
|
571
571
|
end
|
572
572
|
|
573
573
|
def _interp_output_size(closed_over_args)
|
574
|
-
input, size, scale_factor,
|
574
|
+
input, size, scale_factor, _recompute_scale_factor = closed_over_args
|
575
575
|
dim = input.dim - 2
|
576
576
|
if size.nil? && scale_factor.nil?
|
577
577
|
raise ArgumentError, "either size or scale_factor should be defined"
|
data/lib/torch/nn/module.rb
CHANGED
@@ -280,6 +280,11 @@ module Torch
|
|
280
280
|
end
|
281
281
|
end
|
282
282
|
|
283
|
+
# Returns a deep copy of this module by handing it to dup_value together with
# a fresh memo hash shared across the whole traversal.
def deep_dup
  dup_value(self, {})
end
|
287
|
+
|
283
288
|
def method_missing(method, *args, &block)
|
284
289
|
name = method.to_s
|
285
290
|
if named_parameters.key?(name)
|
@@ -388,6 +393,29 @@ module Torch
|
|
388
393
|
destination[prefix + k] = v
|
389
394
|
end
|
390
395
|
end
|
396
|
+
|
397
|
+
# Deep-copies v for deep_dup, recording each copy in memo keyed by the source
# object's object_id, like the memo dict of Python's copy.deepcopy.
# https://docs.python.org/3/library/copy.html
#
# Each container copy (Hash, Array, Module) is registered in memo *before*
# its contents are copied. Self-referencing or mutually-referencing
# structures therefore resolve to the in-progress copy instead of recursing
# forever, and objects referenced more than once map to a single copy.
# memo.key? is used so a memoized nil/false result is reused, not recomputed.
def dup_value(v, memo)
  id = v.object_id
  return memo[id] if memo.key?(id)

  case v
  when Method, UnboundMethod
    # methods are shared, not copied
    memo[id] = v
  when Hash
    copy = memo[id] = {}
    v.each { |k, v2| copy[dup_value(k, memo)] = dup_value(v2, memo) }
    copy
  when Array
    copy = memo[id] = []
    v.each { |v2| copy << dup_value(v2, memo) }
    copy
  when Torch::NN::Module
    # shallow dup first, then replace every instance variable with its deep copy
    copy = memo[id] = v.dup
    v.instance_variables.each do |var|
      copy.instance_variable_set(var, dup_value(v.instance_variable_get(var), memo))
    end
    copy
  else
    memo[id] = v.dup
  end
end
|
391
419
|
end
|
392
420
|
end
|
393
421
|
end
|
data/lib/torch/nn/parameter.rb
CHANGED
@@ -1,6 +1,9 @@
|
|
1
1
|
module Torch
|
2
2
|
module NN
|
3
3
|
class Parameter < Tensor
|
4
|
+
# fix for issue w/ assignment methods
|
5
|
+
alias_method :grad=, :_set_grad
|
6
|
+
|
4
7
|
def self.new(data = nil, requires_grad: true)
|
5
8
|
data = Tensor.new unless data
|
6
9
|
_make_subclass(data, requires_grad)
|
@@ -9,6 +12,12 @@ module Torch
|
|
9
12
|
# Prefixes the tensor's own inspect output with a "Parameter containing:" header line.
def inspect
  format("Parameter containing:\n%s", super)
end
|
15
|
+
|
16
|
+
# Returns a new Parameter wrapping a clone of this tensor's data, with the same
# requires_grad flag. The clone is taken inside Torch.no_grad.
def dup
  Torch.no_grad { Parameter.new(clone, requires_grad: requires_grad) }
end
|
12
21
|
end
|
13
22
|
end
|
14
23
|
end
|
data/lib/torch/nn/utils.rb
CHANGED
@@ -22,11 +22,7 @@ module Torch
|
|
22
22
|
end
|
23
23
|
|
24
24
|
# Returns a ModuleList holding n independent deep copies (deep_dup) of mod.
def _clones(mod, n)
  copies = n.times.collect { mod.deep_dup }
  ModuleList.new(copies)
end
|
31
27
|
|
32
28
|
def _activation_fn(activation)
|
data/lib/torch/tensor.rb
CHANGED
@@ -8,6 +8,9 @@ module Torch
|
|
8
8
|
alias_method :ndim, :dim
|
9
9
|
alias_method :ndimension, :dim
|
10
10
|
|
11
|
+
# fix for issue w/ assignment methods
|
12
|
+
alias_method :grad=, :_set_grad
|
13
|
+
|
11
14
|
# use alias_method for performance
|
12
15
|
alias_method :+, :add
|
13
16
|
alias_method :-, :sub
|
@@ -106,6 +109,7 @@ module Torch
|
|
106
109
|
size(0)
|
107
110
|
end
|
108
111
|
|
112
|
+
remove_method :item
|
109
113
|
def item
|
110
114
|
if numel != 1
|
111
115
|
raise Error, "only one element tensors can be converted to Ruby scalars"
|
@@ -133,18 +137,10 @@ module Torch
|
|
133
137
|
cls.from_string(_data_str).reshape(*shape)
|
134
138
|
end
|
135
139
|
|
136
|
-
def new_ones(*size, **options)
|
137
|
-
Torch.ones_like(Torch.empty(*size), **options)
|
138
|
-
end
|
139
|
-
|
140
140
|
# Assignment form (`tensor.requires_grad = flag`) that delegates to _requires_grad!.
def requires_grad=(flag)
  _requires_grad!(flag)
end
|
143
143
|
|
144
|
-
def requires_grad!(requires_grad = true)
|
145
|
-
_requires_grad!(requires_grad)
|
146
|
-
end
|
147
|
-
|
148
144
|
def type(dtype)
|
149
145
|
if dtype.is_a?(Class)
|
150
146
|
raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)
|
@@ -185,5 +181,23 @@ module Torch
|
|
185
181
|
# Forwards all positional arguments unchanged to Torch.stft.
def stft(*stft_args)
  Torch.stft(*stft_args)
end
|
184
|
+
|
185
|
+
# Returns a clone of this tensor, taken inside Torch.no_grad.
def dup
  Torch.no_grad { clone }
end
|
190
|
+
|
191
|
+
# `imag` is an attribute in Python rather than a method, so it is not
# generated from native_functions.yaml; forward to Torch.imag instead.
def imag; Torch.imag(self); end
|
196
|
+
|
197
|
+
# `real` is an attribute in Python rather than a method, so it is not
# generated from native_functions.yaml; forward to Torch.real instead.
def real; Torch.real(self); end
|
188
202
|
end
|
189
203
|
end
|
@@ -29,7 +29,7 @@ module Torch
|
|
29
29
|
|
30
30
|
# try to keep the random number generator in sync with Python
|
31
31
|
# this makes it easy to compare results
|
32
|
-
|
32
|
+
_base_seed = Torch.empty([], dtype: :int64).random!.item
|
33
33
|
|
34
34
|
indexes =
|
35
35
|
if @shuffle
|
data/lib/torch/version.rb
CHANGED