torch-rb 0.8.2 → 0.9.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +22 -0
- data/README.md +5 -4
- data/codegen/generate_functions.rb +11 -4
- data/codegen/native_functions.yaml +1103 -373
- data/ext/torch/backends.cpp +2 -2
- data/ext/torch/nn.cpp +4 -1
- data/ext/torch/ruby_arg_parser.cpp +2 -2
- data/ext/torch/ruby_arg_parser.h +19 -5
- data/ext/torch/templates.h +30 -33
- data/ext/torch/tensor.cpp +33 -33
- data/ext/torch/torch.cpp +38 -28
- data/ext/torch/utils.h +0 -6
- data/lib/torch/inspector.rb +1 -1
- data/lib/torch/nn/functional.rb +1 -1
- data/lib/torch/nn/functional_attention.rb +1 -1
- data/lib/torch/nn/module.rb +28 -0
- data/lib/torch/nn/parameter.rb +9 -0
- data/lib/torch/nn/transformer_decoder_layer.rb +1 -1
- data/lib/torch/nn/utils.rb +1 -5
- data/lib/torch/tensor.rb +22 -8
- data/lib/torch/utils/data/data_loader.rb +1 -1
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +14 -48
- metadata +3 -3
data/ext/torch/backends.cpp
CHANGED
@@ -7,11 +7,11 @@
|
|
7
7
|
void init_backends(Rice::Module& m) {
|
8
8
|
auto rb_mBackends = Rice::define_module_under(m, "Backends");
|
9
9
|
|
10
|
-
|
10
|
+
Rice::define_module_under(rb_mBackends, "OpenMP")
|
11
11
|
.add_handler<torch::Error>(handle_error)
|
12
12
|
.define_singleton_function("available?", &torch::hasOpenMP);
|
13
13
|
|
14
|
-
|
14
|
+
Rice::define_module_under(rb_mBackends, "MKL")
|
15
15
|
.add_handler<torch::Error>(handle_error)
|
16
16
|
.define_singleton_function("available?", &torch::hasMKL);
|
17
17
|
}
|
data/ext/torch/nn.cpp
CHANGED
@@ -98,8 +98,11 @@ void init_nn(Rice::Module& m) {
|
|
98
98
|
auto grad = self.grad();
|
99
99
|
return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
|
100
100
|
})
|
101
|
+
// can't use grad=
|
102
|
+
// assignment methods fail with Ruby 3.0
|
103
|
+
// TODO add checks like Tensor
|
101
104
|
.define_method(
|
102
|
-
"
|
105
|
+
"_set_grad",
|
103
106
|
[](Parameter& self, torch::Tensor& grad) {
|
104
107
|
self.mutable_grad() = grad;
|
105
108
|
})
|
@@ -472,12 +472,12 @@ static void extra_kwargs(FunctionSignature& signature, VALUE kwargs, ssize_t num
|
|
472
472
|
auto param_idx = find_param(signature, key);
|
473
473
|
if (param_idx < 0) {
|
474
474
|
rb_raise(rb_eArgError, "%s() got an unexpected keyword argument '%s'",
|
475
|
-
signature.name.c_str(),
|
475
|
+
signature.name.c_str(), rb_id2name(rb_to_id(key)));
|
476
476
|
}
|
477
477
|
|
478
478
|
if (param_idx < num_pos_args) {
|
479
479
|
rb_raise(rb_eArgError, "%s() got multiple values for argument '%s'",
|
480
|
-
signature.name.c_str(),
|
480
|
+
signature.name.c_str(), rb_id2name(rb_to_id(key)));
|
481
481
|
}
|
482
482
|
}
|
483
483
|
|
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -75,7 +75,7 @@ struct RubyArgs {
|
|
75
75
|
int idx;
|
76
76
|
|
77
77
|
inline at::Tensor tensor(int i);
|
78
|
-
inline
|
78
|
+
inline c10::optional<at::Tensor> optionalTensor(int i);
|
79
79
|
inline at::Scalar scalar(int i);
|
80
80
|
// inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
|
81
81
|
inline std::vector<at::Scalar> scalarlist(int i);
|
@@ -109,6 +109,9 @@ struct RubyArgs {
|
|
109
109
|
// inline at::QScheme toQScheme(int i);
|
110
110
|
inline std::string string(int i);
|
111
111
|
inline c10::optional<std::string> stringOptional(int i);
|
112
|
+
inline c10::string_view stringView(int i);
|
113
|
+
// inline c10::string_view stringViewWithDefault(int i, const c10::string_view default_str);
|
114
|
+
inline c10::optional<c10::string_view> stringViewOptional(int i);
|
112
115
|
// inline PyObject* pyobject(int i);
|
113
116
|
inline int64_t toInt64(int i);
|
114
117
|
// inline int64_t toInt64WithDefault(int i, int64_t default_int);
|
@@ -125,8 +128,8 @@ inline at::Tensor RubyArgs::tensor(int i) {
|
|
125
128
|
return Rice::detail::From_Ruby<torch::Tensor>().convert(args[i]);
|
126
129
|
}
|
127
130
|
|
128
|
-
inline
|
129
|
-
if (NIL_P(args[i])) return
|
131
|
+
inline c10::optional<at::Tensor> RubyArgs::optionalTensor(int i) {
|
132
|
+
if (NIL_P(args[i])) return c10::nullopt;
|
130
133
|
return tensor(i);
|
131
134
|
}
|
132
135
|
|
@@ -232,7 +235,7 @@ inline ScalarType RubyArgs::scalartype(int i) {
|
|
232
235
|
|
233
236
|
auto it = dtype_map.find(args[i]);
|
234
237
|
if (it == dtype_map.end()) {
|
235
|
-
rb_raise(rb_eArgError, "invalid dtype: %s",
|
238
|
+
rb_raise(rb_eArgError, "invalid dtype: %s", rb_id2name(rb_to_id(args[i])));
|
236
239
|
}
|
237
240
|
return it->second;
|
238
241
|
}
|
@@ -290,7 +293,7 @@ inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
|
|
290
293
|
|
291
294
|
auto it = layout_map.find(args[i]);
|
292
295
|
if (it == layout_map.end()) {
|
293
|
-
rb_raise(rb_eArgError, "invalid layout: %s",
|
296
|
+
rb_raise(rb_eArgError, "invalid layout: %s", rb_id2name(rb_to_id(args[i])));
|
294
297
|
}
|
295
298
|
return it->second;
|
296
299
|
}
|
@@ -322,6 +325,17 @@ inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
|
|
322
325
|
return Rice::detail::From_Ruby<std::string>().convert(args[i]);
|
323
326
|
}
|
324
327
|
|
328
|
+
// string_view does not own data
|
329
|
+
inline c10::string_view RubyArgs::stringView(int i) {
|
330
|
+
return c10::string_view(RSTRING_PTR(args[i]), RSTRING_LEN(args[i]));
|
331
|
+
}
|
332
|
+
|
333
|
+
// string_view does not own data
|
334
|
+
inline c10::optional<c10::string_view> RubyArgs::stringViewOptional(int i) {
|
335
|
+
if (NIL_P(args[i])) return c10::nullopt;
|
336
|
+
return c10::string_view(RSTRING_PTR(args[i]), RSTRING_LEN(args[i]));
|
337
|
+
}
|
338
|
+
|
325
339
|
inline int64_t RubyArgs::toInt64(int i) {
|
326
340
|
if (NIL_P(args[i])) return signature.params[i].default_int;
|
327
341
|
return Rice::detail::From_Ruby<int64_t>().convert(args[i]);
|
data/ext/torch/templates.h
CHANGED
@@ -41,23 +41,39 @@ using torch::nn::init::NonlinearityType;
|
|
41
41
|
#define RETURN_NIL \
|
42
42
|
return Qnil;
|
43
43
|
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
}
|
44
|
+
namespace Rice::detail
|
45
|
+
{
|
46
|
+
template<typename T>
|
47
|
+
struct Type<c10::complex<T>>
|
48
|
+
{
|
49
|
+
static bool verify()
|
50
|
+
{
|
51
|
+
return true;
|
53
52
|
}
|
54
|
-
|
55
|
-
|
53
|
+
};
|
54
|
+
|
55
|
+
template<typename T>
|
56
|
+
class To_Ruby<c10::complex<T>>
|
57
|
+
{
|
58
|
+
public:
|
59
|
+
VALUE convert(c10::complex<T> const& x)
|
60
|
+
{
|
61
|
+
return rb_dbl_complex_new(x.real(), x.imag());
|
56
62
|
}
|
57
|
-
|
58
|
-
|
63
|
+
};
|
64
|
+
|
65
|
+
template<typename T>
|
66
|
+
class From_Ruby<c10::complex<T>>
|
67
|
+
{
|
68
|
+
public:
|
69
|
+
c10::complex<T> convert(VALUE x)
|
70
|
+
{
|
71
|
+
VALUE real = rb_funcall(x, rb_intern("real"), 0);
|
72
|
+
VALUE imag = rb_funcall(x, rb_intern("imag"), 0);
|
73
|
+
return c10::complex<T>(From_Ruby<T>().convert(real), From_Ruby<T>().convert(imag));
|
59
74
|
}
|
60
|
-
};
|
75
|
+
};
|
76
|
+
}
|
61
77
|
|
62
78
|
namespace Rice::detail
|
63
79
|
{
|
@@ -131,25 +147,6 @@ namespace Rice::detail
|
|
131
147
|
}
|
132
148
|
};
|
133
149
|
|
134
|
-
template<>
|
135
|
-
struct Type<OptionalTensor>
|
136
|
-
{
|
137
|
-
static bool verify()
|
138
|
-
{
|
139
|
-
return true;
|
140
|
-
}
|
141
|
-
};
|
142
|
-
|
143
|
-
template<>
|
144
|
-
class From_Ruby<OptionalTensor>
|
145
|
-
{
|
146
|
-
public:
|
147
|
-
OptionalTensor convert(VALUE x)
|
148
|
-
{
|
149
|
-
return OptionalTensor(x);
|
150
|
-
}
|
151
|
-
};
|
152
|
-
|
153
150
|
template<>
|
154
151
|
struct Type<Scalar>
|
155
152
|
{
|
data/ext/torch/tensor.cpp
CHANGED
@@ -10,28 +10,6 @@
|
|
10
10
|
using namespace Rice;
|
11
11
|
using torch::indexing::TensorIndex;
|
12
12
|
|
13
|
-
namespace Rice::detail
|
14
|
-
{
|
15
|
-
template<typename T>
|
16
|
-
struct Type<c10::complex<T>>
|
17
|
-
{
|
18
|
-
static bool verify()
|
19
|
-
{
|
20
|
-
return true;
|
21
|
-
}
|
22
|
-
};
|
23
|
-
|
24
|
-
template<typename T>
|
25
|
-
class To_Ruby<c10::complex<T>>
|
26
|
-
{
|
27
|
-
public:
|
28
|
-
VALUE convert(c10::complex<T> const& x)
|
29
|
-
{
|
30
|
-
return rb_dbl_complex_new(x.real(), x.imag());
|
31
|
-
}
|
32
|
-
};
|
33
|
-
}
|
34
|
-
|
35
13
|
template<typename T>
|
36
14
|
Array flat_data(Tensor& tensor) {
|
37
15
|
Tensor view = tensor.reshape({tensor.numel()});
|
@@ -107,7 +85,7 @@ static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
|
|
107
85
|
ParsedArgs<4> parsed_args;
|
108
86
|
auto _r = parser.parse(self_, argc, argv, parsed_args);
|
109
87
|
// _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
|
110
|
-
auto dispatch__backward = [](const Tensor & self, TensorList inputs, const
|
88
|
+
auto dispatch__backward = [](const Tensor & self, TensorList inputs, const c10::optional<at::Tensor> & gradient, c10::optional<bool> retain_graph, bool create_graph) -> void {
|
111
89
|
// in future, release GVL
|
112
90
|
self._backward(inputs, gradient, retain_graph, create_graph);
|
113
91
|
};
|
@@ -125,13 +103,13 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
125
103
|
rb_define_method(rb_cTensor, "backward", (VALUE (*)(...)) tensor__backward, -1);
|
126
104
|
|
127
105
|
rb_cTensor
|
128
|
-
.define_method("cuda?", &
|
129
|
-
.define_method("sparse?", &
|
130
|
-
.define_method("quantized?", &
|
131
|
-
.define_method("dim", &
|
132
|
-
.define_method("numel", &
|
133
|
-
.define_method("element_size", &
|
134
|
-
.define_method("requires_grad", &
|
106
|
+
.define_method("cuda?", [](Tensor& self) { return self.is_cuda(); })
|
107
|
+
.define_method("sparse?", [](Tensor& self) { return self.is_sparse(); })
|
108
|
+
.define_method("quantized?", [](Tensor& self) { return self.is_quantized(); })
|
109
|
+
.define_method("dim", [](Tensor& self) { return self.dim(); })
|
110
|
+
.define_method("numel", [](Tensor& self) { return self.numel(); })
|
111
|
+
.define_method("element_size", [](Tensor& self) { return self.element_size(); })
|
112
|
+
.define_method("requires_grad", [](Tensor& self) { return self.requires_grad(); })
|
135
113
|
.define_method(
|
136
114
|
"_size",
|
137
115
|
[](Tensor& self, int64_t dim) {
|
@@ -189,9 +167,31 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
189
167
|
auto grad = self.grad();
|
190
168
|
return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
|
191
169
|
})
|
170
|
+
// can't use grad=
|
171
|
+
// assignment methods fail with Ruby 3.0
|
192
172
|
.define_method(
|
193
|
-
"
|
194
|
-
[](Tensor& self,
|
173
|
+
"_set_grad",
|
174
|
+
[](Tensor& self, Rice::Object value) {
|
175
|
+
if (value.is_nil()) {
|
176
|
+
self.mutable_grad().reset();
|
177
|
+
return;
|
178
|
+
}
|
179
|
+
|
180
|
+
const auto& grad = Rice::detail::From_Ruby<torch::Tensor>().convert(value.value());
|
181
|
+
|
182
|
+
// TODO support sparse grad
|
183
|
+
if (!grad.options().type_equal(self.options())) {
|
184
|
+
rb_raise(rb_eArgError, "assigned grad has data of a different type");
|
185
|
+
}
|
186
|
+
|
187
|
+
if (self.is_cuda() && grad.get_device() != self.get_device()) {
|
188
|
+
rb_raise(rb_eArgError, "assigned grad has data located on a different device");
|
189
|
+
}
|
190
|
+
|
191
|
+
if (!self.sizes().equals(grad.sizes())) {
|
192
|
+
rb_raise(rb_eArgError, "assigned grad has data of a different size");
|
193
|
+
}
|
194
|
+
|
195
195
|
self.mutable_grad() = grad;
|
196
196
|
})
|
197
197
|
.define_method(
|
@@ -281,7 +281,7 @@ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions
|
|
281
281
|
})
|
282
282
|
.define_method(
|
283
283
|
"_to",
|
284
|
-
[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
|
284
|
+
[](Tensor& self, torch::Device& device, int dtype, bool non_blocking, bool copy) {
|
285
285
|
return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
|
286
286
|
});
|
287
287
|
|
data/ext/torch/torch.cpp
CHANGED
@@ -6,6 +6,23 @@
|
|
6
6
|
#include "templates.h"
|
7
7
|
#include "utils.h"
|
8
8
|
|
9
|
+
template<typename T>
|
10
|
+
torch::Tensor make_tensor(Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
11
|
+
std::vector<T> vec;
|
12
|
+
for (long i = 0; i < a.size(); i++) {
|
13
|
+
vec.push_back(Rice::detail::From_Ruby<T>().convert(a[i].value()));
|
14
|
+
}
|
15
|
+
|
16
|
+
// hack for requires_grad error
|
17
|
+
auto requires_grad = options.requires_grad();
|
18
|
+
torch::Tensor t = torch::tensor(vec, options.requires_grad(c10::nullopt));
|
19
|
+
if (requires_grad) {
|
20
|
+
t.set_requires_grad(true);
|
21
|
+
}
|
22
|
+
|
23
|
+
return t.reshape(size);
|
24
|
+
}
|
25
|
+
|
9
26
|
void init_torch(Rice::Module& m) {
|
10
27
|
m.add_handler<torch::Error>(handle_error);
|
11
28
|
add_torch_functions(m);
|
@@ -61,35 +78,28 @@ void init_torch(Rice::Module& m) {
|
|
61
78
|
"_tensor",
|
62
79
|
[](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
63
80
|
auto dtype = options.dtype();
|
64
|
-
torch::
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
81
|
+
if (dtype == torch::kByte) {
|
82
|
+
return make_tensor<uint8_t>(a, size, options);
|
83
|
+
} else if (dtype == torch::kChar) {
|
84
|
+
return make_tensor<int8_t>(a, size, options);
|
85
|
+
} else if (dtype == torch::kShort) {
|
86
|
+
return make_tensor<int16_t>(a, size, options);
|
87
|
+
} else if (dtype == torch::kInt) {
|
88
|
+
return make_tensor<int32_t>(a, size, options);
|
89
|
+
} else if (dtype == torch::kLong) {
|
90
|
+
return make_tensor<int64_t>(a, size, options);
|
91
|
+
} else if (dtype == torch::kFloat) {
|
92
|
+
return make_tensor<float>(a, size, options);
|
93
|
+
} else if (dtype == torch::kDouble) {
|
94
|
+
return make_tensor<double>(a, size, options);
|
95
|
+
} else if (dtype == torch::kBool) {
|
96
|
+
return make_tensor<uint8_t>(a, size, options);
|
97
|
+
} else if (dtype == torch::kComplexFloat) {
|
98
|
+
return make_tensor<c10::complex<float>>(a, size, options);
|
99
|
+
} else if (dtype == torch::kComplexDouble) {
|
100
|
+
return make_tensor<c10::complex<double>>(a, size, options);
|
80
101
|
} else {
|
81
|
-
std::
|
82
|
-
for (long i = 0; i < a.size(); i++) {
|
83
|
-
vec.push_back(Rice::detail::From_Ruby<float>().convert(a[i].value()));
|
84
|
-
}
|
85
|
-
// hack for requires_grad error
|
86
|
-
if (options.requires_grad()) {
|
87
|
-
t = torch::tensor(vec, options.requires_grad(c10::nullopt));
|
88
|
-
t.set_requires_grad(true);
|
89
|
-
} else {
|
90
|
-
t = torch::tensor(vec, options);
|
91
|
-
}
|
102
|
+
throw std::runtime_error("Unsupported type");
|
92
103
|
}
|
93
|
-
return t.reshape(size);
|
94
104
|
});
|
95
105
|
}
|
data/ext/torch/utils.h
CHANGED
@@ -16,12 +16,6 @@ inline VALUE THPUtils_internSymbol(const std::string& str) {
|
|
16
16
|
return Rice::Symbol(str);
|
17
17
|
}
|
18
18
|
|
19
|
-
inline std::string THPUtils_unpackSymbol(VALUE obj) {
|
20
|
-
Check_Type(obj, T_SYMBOL);
|
21
|
-
obj = rb_funcall(obj, rb_intern("to_s"), 0);
|
22
|
-
return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
|
23
|
-
}
|
24
|
-
|
25
19
|
inline std::string THPUtils_unpackString(VALUE obj) {
|
26
20
|
Check_Type(obj, T_STRING);
|
27
21
|
return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
|
data/lib/torch/inspector.rb
CHANGED
@@ -247,7 +247,7 @@ module Torch
|
|
247
247
|
# length includes spaces and comma between elements
|
248
248
|
element_length = formatter.width + 2
|
249
249
|
elements_per_line = [1, ((PRINT_OPTS[:linewidth] - indent) / element_length.to_f).floor.to_i].max
|
250
|
-
|
250
|
+
_char_per_line = element_length * elements_per_line
|
251
251
|
|
252
252
|
if summarize && slf.size(0) > 2 * PRINT_OPTS[:edgeitems]
|
253
253
|
data = (
|
data/lib/torch/nn/functional.rb
CHANGED
@@ -571,7 +571,7 @@ module Torch
|
|
571
571
|
end
|
572
572
|
|
573
573
|
def _interp_output_size(closed_over_args)
|
574
|
-
input, size, scale_factor,
|
574
|
+
input, size, scale_factor, _recompute_scale_factor = closed_over_args
|
575
575
|
dim = input.dim - 2
|
576
576
|
if size.nil? && scale_factor.nil?
|
577
577
|
raise ArgumentError, "either size or scale_factor should be defined"
|
data/lib/torch/nn/module.rb
CHANGED
@@ -280,6 +280,11 @@ module Torch
|
|
280
280
|
end
|
281
281
|
end
|
282
282
|
|
283
|
+
def deep_dup
|
284
|
+
memo = {}
|
285
|
+
dup_value(self, memo)
|
286
|
+
end
|
287
|
+
|
283
288
|
def method_missing(method, *args, &block)
|
284
289
|
name = method.to_s
|
285
290
|
if named_parameters.key?(name)
|
@@ -388,6 +393,29 @@ module Torch
|
|
388
393
|
destination[prefix + k] = v
|
389
394
|
end
|
390
395
|
end
|
396
|
+
|
397
|
+
# keep memo hash like Python deepcopy
|
398
|
+
# https://docs.python.org/3/library/copy.html
|
399
|
+
def dup_value(v, memo)
|
400
|
+
memo[v.object_id] ||= begin
|
401
|
+
case v
|
402
|
+
when Method, UnboundMethod
|
403
|
+
v
|
404
|
+
when Hash
|
405
|
+
v.to_h { |k, v2| [dup_value(k, memo), dup_value(v2, memo)] }
|
406
|
+
when Array
|
407
|
+
v.map { |v2| dup_value(v2, memo) }
|
408
|
+
when Torch::NN::Module
|
409
|
+
copy = v.dup
|
410
|
+
v.instance_variables.each do |var|
|
411
|
+
copy.instance_variable_set(var, dup_value(v.instance_variable_get(var), memo))
|
412
|
+
end
|
413
|
+
copy
|
414
|
+
else
|
415
|
+
v.dup
|
416
|
+
end
|
417
|
+
end
|
418
|
+
end
|
391
419
|
end
|
392
420
|
end
|
393
421
|
end
|
data/lib/torch/nn/parameter.rb
CHANGED
@@ -1,6 +1,9 @@
|
|
1
1
|
module Torch
|
2
2
|
module NN
|
3
3
|
class Parameter < Tensor
|
4
|
+
# fix for issue w/ assignment methods
|
5
|
+
alias_method :grad=, :_set_grad
|
6
|
+
|
4
7
|
def self.new(data = nil, requires_grad: true)
|
5
8
|
data = Tensor.new unless data
|
6
9
|
_make_subclass(data, requires_grad)
|
@@ -9,6 +12,12 @@ module Torch
|
|
9
12
|
def inspect
|
10
13
|
"Parameter containing:\n#{super}"
|
11
14
|
end
|
15
|
+
|
16
|
+
def dup
|
17
|
+
Torch.no_grad do
|
18
|
+
Parameter.new(clone, requires_grad: requires_grad)
|
19
|
+
end
|
20
|
+
end
|
12
21
|
end
|
13
22
|
end
|
14
23
|
end
|
data/lib/torch/nn/utils.rb
CHANGED
@@ -22,11 +22,7 @@ module Torch
|
|
22
22
|
end
|
23
23
|
|
24
24
|
def _clones(mod, n)
|
25
|
-
|
26
|
-
layers = n.times.map do |i|
|
27
|
-
mod.clone.tap { |l| l.load_state_dict(state) }
|
28
|
-
end
|
29
|
-
ModuleList.new(layers)
|
25
|
+
ModuleList.new(n.times.map { mod.deep_dup })
|
30
26
|
end
|
31
27
|
|
32
28
|
def _activation_fn(activation)
|
data/lib/torch/tensor.rb
CHANGED
@@ -8,6 +8,9 @@ module Torch
|
|
8
8
|
alias_method :ndim, :dim
|
9
9
|
alias_method :ndimension, :dim
|
10
10
|
|
11
|
+
# fix for issue w/ assignment methods
|
12
|
+
alias_method :grad=, :_set_grad
|
13
|
+
|
11
14
|
# use alias_method for performance
|
12
15
|
alias_method :+, :add
|
13
16
|
alias_method :-, :sub
|
@@ -106,6 +109,7 @@ module Torch
|
|
106
109
|
size(0)
|
107
110
|
end
|
108
111
|
|
112
|
+
remove_method :item
|
109
113
|
def item
|
110
114
|
if numel != 1
|
111
115
|
raise Error, "only one element tensors can be converted to Ruby scalars"
|
@@ -133,18 +137,10 @@ module Torch
|
|
133
137
|
cls.from_string(_data_str).reshape(*shape)
|
134
138
|
end
|
135
139
|
|
136
|
-
def new_ones(*size, **options)
|
137
|
-
Torch.ones_like(Torch.empty(*size), **options)
|
138
|
-
end
|
139
|
-
|
140
140
|
def requires_grad=(requires_grad)
|
141
141
|
_requires_grad!(requires_grad)
|
142
142
|
end
|
143
143
|
|
144
|
-
def requires_grad!(requires_grad = true)
|
145
|
-
_requires_grad!(requires_grad)
|
146
|
-
end
|
147
|
-
|
148
144
|
def type(dtype)
|
149
145
|
if dtype.is_a?(Class)
|
150
146
|
raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)
|
@@ -185,5 +181,23 @@ module Torch
|
|
185
181
|
def stft(*args)
|
186
182
|
Torch.stft(*args)
|
187
183
|
end
|
184
|
+
|
185
|
+
def dup
|
186
|
+
Torch.no_grad do
|
187
|
+
clone
|
188
|
+
end
|
189
|
+
end
|
190
|
+
|
191
|
+
# not a method in native_functions.yaml
|
192
|
+
# attribute in Python rather than method
|
193
|
+
def imag
|
194
|
+
Torch.imag(self)
|
195
|
+
end
|
196
|
+
|
197
|
+
# not a method in native_functions.yaml
|
198
|
+
# attribute in Python rather than method
|
199
|
+
def real
|
200
|
+
Torch.real(self)
|
201
|
+
end
|
188
202
|
end
|
189
203
|
end
|
@@ -29,7 +29,7 @@ module Torch
|
|
29
29
|
|
30
30
|
# try to keep the random number generator in sync with Python
|
31
31
|
# this makes it easy to compare results
|
32
|
-
|
32
|
+
_base_seed = Torch.empty([], dtype: :int64).random!.item
|
33
33
|
|
34
34
|
indexes =
|
35
35
|
if @shuffle
|
data/lib/torch/version.rb
CHANGED