torch-rb 0.4.1 → 0.5.3
- checksums.yaml +4 -4
- data/CHANGELOG.md +24 -0
- data/README.md +9 -1
- data/codegen/generate_functions.rb +13 -8
- data/codegen/native_functions.yaml +2363 -714
- data/ext/torch/ext.cpp +15 -16
- data/ext/torch/ruby_arg_parser.h +26 -2
- data/ext/torch/templates.h +1 -0
- data/ext/torch/wrap_outputs.h +7 -0
- data/lib/torch.rb +5 -7
- data/lib/torch/nn/module.rb +101 -21
- data/lib/torch/optim/adadelta.rb +3 -3
- data/lib/torch/optim/adagrad.rb +3 -3
- data/lib/torch/optim/adam.rb +2 -2
- data/lib/torch/optim/adamax.rb +2 -2
- data/lib/torch/optim/adamw.rb +2 -2
- data/lib/torch/optim/asgd.rb +1 -1
- data/lib/torch/optim/rmsprop.rb +7 -7
- data/lib/torch/optim/rprop.rb +1 -1
- data/lib/torch/optim/sgd.rb +2 -2
- data/lib/torch/tensor.rb +5 -0
- data/lib/torch/version.rb +1 -1
- metadata +3 -3
data/ext/torch/ext.cpp
CHANGED
@@ -44,8 +44,13 @@ std::vector<TensorIndex> index_vector(Array a) {
     if (obj.is_instance_of(rb_cInteger)) {
       indices.push_back(from_ruby<int64_t>(obj));
     } else if (obj.is_instance_of(rb_cRange)) {
-      torch::optional<int64_t> start_index =
-      torch::optional<int64_t> stop_index =
+      torch::optional<int64_t> start_index = torch::nullopt;
+      torch::optional<int64_t> stop_index = torch::nullopt;
+
+      Object begin = obj.call("begin");
+      if (!begin.is_nil()) {
+        start_index = from_ruby<int64_t>(begin);
+      }
 
       Object end = obj.call("end");
       if (!end.is_nil()) {
@@ -53,12 +58,14 @@ std::vector<TensorIndex> index_vector(Array a) {
       }
 
       Object exclude_end = obj.call("exclude_end?");
-      if (!exclude_end) {
+      if (stop_index.has_value() && !exclude_end) {
         if (stop_index.value() == -1) {
           stop_index = torch::nullopt;
         } else {
           stop_index = stop_index.value() + 1;
         }
+      } else if (!stop_index.has_value() && exclude_end) {
+        stop_index = -1;
       }
 
       indices.push_back(torch::indexing::Slice(start_index, stop_index));
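The index_vector change above lets beginless and endless Ruby Ranges act as tensor indices. A minimal usage sketch of what this enables (illustrative only; values and shapes assumed):

    require "torch"

    x = Torch.arange(10)

    x[2..5]   # inclusive range  -> elements 2, 3, 4, 5
    x[2...5]  # exclusive range  -> elements 2, 3, 4
    x[3..]    # endless range    -> elements 3 through the end
    x[..-2]   # beginless range  -> everything except the last element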
@@ -348,16 +355,6 @@ void Init_ext()
       *[](Tensor& self) {
         return self.is_contiguous();
       })
-    .define_method(
-      "addcmul!",
-      *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
-        return self.addcmul_(tensor1, tensor2, value);
-      })
-    .define_method(
-      "addcdiv!",
-      *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
-        return self.addcdiv_(tensor1, tensor2, value);
-      })
     .define_method(
       "_requires_grad!",
       *[](Tensor& self, bool requires_grad) {
@@ -372,7 +369,7 @@ void Init_ext()
     .define_method(
       "grad=",
       *[](Tensor& self, torch::Tensor& grad) {
-        self.
+        self.mutable_grad() = grad;
       })
     .define_method(
       "_dtype",
@@ -609,7 +606,7 @@ void Init_ext()
     .define_method(
       "grad=",
       *[](Parameter& self, torch::Tensor& grad) {
-        self.
+        self.mutable_grad() = grad;
       });
 
   Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
@@ -628,5 +625,7 @@ void Init_ext()
   Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
     .add_handler<torch::Error>(handle_error)
     .define_singleton_method("available?", &torch::cuda::is_available)
-    .define_singleton_method("device_count", &torch::cuda::device_count);
+    .define_singleton_method("device_count", &torch::cuda::device_count)
+    .define_singleton_method("manual_seed", &torch::cuda::manual_seed)
+    .define_singleton_method("manual_seed_all", &torch::cuda::manual_seed_all);
 }
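The CUDA module now exposes the seeding functions defined above. A small sketch of how the new singleton methods would be called (guarded, since they only matter on CUDA-enabled builds):

    require "torch"

    if Torch::CUDA.available?
      Torch::CUDA.manual_seed(42)      # seed the current CUDA device
      Torch::CUDA.manual_seed_all(42)  # seed every visible CUDA device
    end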
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -91,7 +91,7 @@ struct RubyArgs {
   inline c10::optional<int64_t> toInt64Optional(int i);
   inline c10::optional<bool> toBoolOptional(int i);
   inline c10::optional<double> toDoubleOptional(int i);
-
+  inline c10::OptionalArray<double> doublelistOptional(int i);
   // inline at::Layout layout(int i);
   // inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
   inline c10::optional<at::Layout> layoutOptional(int i);
@@ -105,7 +105,7 @@ struct RubyArgs {
   inline c10::optional<at::MemoryFormat> memoryformatOptional(int i);
   // inline at::QScheme toQScheme(int i);
   inline std::string string(int i);
-
+  inline c10::optional<std::string> stringOptional(int i);
   // inline PyObject* pyobject(int i);
   inline int64_t toInt64(int i);
   // inline int64_t toInt64WithDefault(int i, int64_t default_int);
@@ -249,6 +249,25 @@ inline c10::optional<double> RubyArgs::toDoubleOptional(int i) {
   return toDouble(i);
 }
 
+inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
+  if (NIL_P(args[i])) return {};
+
+  VALUE arg = args[i];
+  auto size = RARRAY_LEN(arg);
+  std::vector<double> res(size);
+  for (idx = 0; idx < size; idx++) {
+    VALUE obj = rb_ary_entry(arg, idx);
+    if (FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj)) {
+      res[idx] = from_ruby<double>(obj);
+    } else {
+      rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
+        signature.name.c_str(), signature.params[i].name.c_str(),
+        signature.params[i].type_name().c_str(), rb_obj_classname(obj), idx + 1);
+    }
+  }
+  return res;
+}
+
 inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
   if (NIL_P(args[i])) return c10::nullopt;
 
@@ -285,6 +304,11 @@ inline std::string RubyArgs::string(int i) {
   return from_ruby<std::string>(args[i]);
 }
 
+inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
+  if (!args[i]) return c10::nullopt;
+  return from_ruby<std::string>(args[i]);
+}
+
 inline int64_t RubyArgs::toInt64(int i) {
   if (NIL_P(args[i])) return signature.params[i].default_int;
   return from_ruby<int64_t>(args[i]);
data/ext/torch/templates.h
CHANGED
data/ext/torch/wrap_outputs.h
CHANGED
data/lib/torch.rb
CHANGED
@@ -261,6 +261,8 @@ module Torch
         Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
       elsif args.size == 1 && args.first.is_a?(Array)
         Torch.tensor(args.first, dtype: dtype, device: device)
+      elsif args.size == 0
+        Torch.empty(0, dtype: dtype, device: device)
       else
         Torch.empty(*args, dtype: dtype, device: device)
       end
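This branch sits in the dtype-specific tensor constructors, so calling one with no arguments now returns an empty zero-element tensor. A rough sketch, assuming Torch::FloatTensor is one of those generated classes:

    require "torch"

    t = Torch::FloatTensor.new       # no arguments -> empty tensor
    t.shape                          # => [0]

    u = Torch::FloatTensor.new(2, 3) # unchanged: sizes create an uninitialized 2x3 tensor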
@@ -434,7 +436,8 @@ module Torch
       zeros(input.size, **like_options(input, options))
     end
 
-
+    # center option
+    def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true, return_complex: nil)
       if center
         signal_dim = input.dim
         extended_shape = [1] * (3 - signal_dim) + input.size
@@ -442,12 +445,7 @@ module Torch
         input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
         input = input.view(input.shape[-signal_dim..-1])
       end
-      _stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
-    end
-
-    def clamp(tensor, min, max)
-      tensor = _clamp_min(tensor, min)
-      _clamp_max(tensor, max)
+      _stft(input, n_fft, hop_length, win_length, window, normalized, onesided, return_complex)
     end
 
     private
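stft now forwards the new return_complex option to the native _stft, and the hand-written clamp wrapper is dropped (clamp presumably comes from the expanded generated bindings instead). A hedged usage sketch, with an assumed input signal:

    require "torch"

    signal = Torch.randn(400)

    # return_complex: false keeps the packed real/imaginary layout
    spec = Torch.stft(signal, 100, hop_length: 50, return_complex: false)

    # clamp keeps working through the native bindings
    Torch.clamp(Torch.tensor([-2.0, 0.5, 3.0]), -1, 1)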
data/lib/torch/nn/module.rb
CHANGED
@@ -113,35 +113,53 @@ module Torch
         forward(*input, **kwargs)
       end
 
-      def state_dict(destination: nil)
+      def state_dict(destination: nil, prefix: "")
         destination ||= {}
-
-
+        save_to_state_dict(destination, prefix: prefix)
+
+        named_children.each do |name, mod|
+          next unless mod
+          mod.state_dict(destination: destination, prefix: prefix + name + ".")
         end
         destination
       end
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        raise Error, "Unknown module: #{k1}"
+      def load_state_dict(state_dict, strict: true)
+        # TODO support strict: false
+        raise "strict: false not implemented yet" unless strict
+
+        missing_keys = []
+        unexpected_keys = []
+        error_msgs = []
+
+        # TODO handle metadata
+
+        _load = lambda do |mod, prefix = ""|
+          # TODO handle metadata
+          local_metadata = {}
+          mod.send(:load_from_state_dict, state_dict, prefix, local_metadata, true, missing_keys, unexpected_keys, error_msgs)
+          mod.named_children.each do |name, child|
+            _load.call(child, prefix + name + ".") unless child.nil?
           end
         end
 
-
+        _load.call(self)
+
+        if strict
+          if unexpected_keys.any?
+            error_msgs << "Unexpected key(s) in state_dict: #{unexpected_keys.join(", ")}"
+          end
+
+          if missing_keys.any?
+            error_msgs << "Missing key(s) in state_dict: #{missing_keys.join(", ")}"
+          end
+        end
+
+        if error_msgs.any?
+          # just show first error
+          raise Error, error_msgs[0]
+        end
+
         nil
       end
 
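state_dict now recurses through named_children with dotted key prefixes, and load_state_dict walks the same structure to restore parameters and buffers. A brief usage sketch, assuming Net is a user-defined Torch::NN::Module subclass:

    net = Net.new

    # persist learned parameters and buffers
    Torch.save(net.state_dict, "net.pth")

    # restore them into a freshly constructed model
    net2 = Net.new
    net2.load_state_dict(Torch.load("net.pth"))
    net2.eval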
@@ -300,6 +318,68 @@ module Torch
       def dict
         instance_variables.reject { |k| instance_variable_get(k).is_a?(Tensor) }.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
       end
+
+      def load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+        # TODO add hooks
+
+        # TODO handle non-persistent buffers
+        persistent_buffers = named_buffers
+        local_name_params = named_parameters(recurse: false).merge(persistent_buffers)
+        local_state = local_name_params.select { |_, v| !v.nil? }
+
+        local_state.each do |name, param|
+          key = prefix + name
+          if state_dict.key?(key)
+            input_param = state_dict[key]
+
+            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+            if param.shape.length == 0 && input_param.shape.length == 1
+              input_param = input_param[0]
+            end
+
+            if input_param.shape != param.shape
+              # local shape should match the one in checkpoint
+              error_msgs << "size mismatch for #{key}: copying a param with shape #{input_param.shape} from checkpoint, " +
+                "the shape in current model is #{param.shape}."
+              next
+            end
+
+            begin
+              Torch.no_grad do
+                param.copy!(input_param)
+              end
+            rescue => e
+              error_msgs << "While copying the parameter named #{key.inspect}, " +
+                "whose dimensions in the model are #{param.size} and " +
+                "whose dimensions in the checkpoint are #{input_param.size}, " +
+                "an exception occurred: #{e.inspect}"
+            end
+          elsif strict
+            missing_keys << key
+          end
+        end
+
+        if strict
+          state_dict.each_key do |key|
+            if key.start_with?(prefix)
+              input_name = key[prefix.length..-1]
+              input_name = input_name.split(".", 2)[0]
+              if !named_children.key?(input_name) && !local_state.key?(input_name)
+                unexpected_keys << key
+              end
+            end
+          end
+        end
+      end
+
+      def save_to_state_dict(destination, prefix: "")
+        named_parameters(recurse: false).each do |k, v|
+          destination[prefix + k] = v
+        end
+        named_buffers.each do |k, v|
+          destination[prefix + k] = v
+        end
+      end
     end
   end
 end
data/lib/torch/optim/adadelta.rb
CHANGED
@@ -39,14 +39,14 @@ module Torch
             state[:step] += 1
 
             if group[:weight_decay] != 0
-              grad = grad.add(group[:weight_decay]
+              grad = grad.add(p.data, alpha: group[:weight_decay])
             end
 
-            square_avg.mul!(rho).addcmul!(1 - rho
+            square_avg.mul!(rho).addcmul!(grad, grad, value: 1 - rho)
             std = square_avg.add(eps).sqrt!
             delta = acc_delta.add(eps).sqrt!.div!(std).mul!(grad)
             p.data.add!(delta, alpha: -group[:lr])
-            acc_delta.mul!(rho).addcmul!(1 - rho
+            acc_delta.mul!(rho).addcmul!(delta, delta, value: 1 - rho)
           end
         end
 
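The optimizers below all make the same switch, from the removed hand-written positional forms to the keyword forms now generated from native_functions.yaml: addcmul!(tensor1, tensor2, value:) and addcdiv!(tensor1, tensor2, value:). A small sketch of the new call style (tensors assumed):

    require "torch"

    acc = Torch.zeros(3)
    t1 = Torch.tensor([1.0, 2.0, 3.0])
    t2 = Torch.tensor([4.0, 5.0, 6.0])

    acc.addcmul!(t1, t2, value: 0.1)  # in place: acc += 0.1 * t1 * t2
    acc.addcdiv!(t1, t2, value: 0.1)  # in place: acc += 0.1 * t1 / t2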
data/lib/torch/optim/adagrad.rb
CHANGED
@@ -49,7 +49,7 @@ module Torch
             if p.grad.data.sparse?
               raise Error, "weight_decay option is not compatible with sparse gradients"
             end
-            grad = grad.add(group[:weight_decay]
+            grad = grad.add(p.data, alpha: group[:weight_decay])
           end
 
           clr = group[:lr] / (1 + (state[:step] - 1) * group[:lr_decay])
@@ -57,9 +57,9 @@ module Torch
           if grad.sparse?
             raise NotImplementedYet
           else
-            state[:sum].addcmul!(
+            state[:sum].addcmul!(grad, grad, value: 1)
             std = state[:sum].sqrt.add!(group[:eps])
-            p.data.addcdiv!(
+            p.data.addcdiv!(grad, std, value: -clr)
           end
         end
       end
data/lib/torch/optim/adam.rb
CHANGED
@@ -58,7 +58,7 @@ module Torch
 
           # Decay the first and second moment running average coefficient
           exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
-          exp_avg_sq.mul!(beta2).addcmul!(1 - beta2
+          exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
           if amsgrad
             # Maintains the maximum of all 2nd moment running avg. till now
             Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
@@ -70,7 +70,7 @@ module Torch
 
           step_size = group[:lr] / bias_correction1
 
-          p.data.addcdiv!(
+          p.data.addcdiv!(exp_avg, denom, value: -step_size)
         end
       end
 
data/lib/torch/optim/adamax.rb
CHANGED
@@ -42,7 +42,7 @@ module Torch
           state[:step] += 1
 
           if group[:weight_decay] != 0
-            grad = grad.add(group[:weight_decay]
+            grad = grad.add(p.data, alpha: group[:weight_decay])
           end
 
           # Update biased first moment estimate.
@@ -57,7 +57,7 @@ module Torch
           bias_correction = 1 - beta1 ** state[:step]
           clr = group[:lr] / bias_correction
 
-          p.data.addcdiv!(
+          p.data.addcdiv!(exp_avg, exp_inf, value: -clr)
         end
       end
 
data/lib/torch/optim/adamw.rb
CHANGED
@@ -59,7 +59,7 @@ module Torch
 
           # Decay the first and second moment running average coefficient
           exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
-          exp_avg_sq.mul!(beta2).addcmul!(1 - beta2
+          exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
           if amsgrad
             # Maintains the maximum of all 2nd moment running avg. till now
             Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
@@ -71,7 +71,7 @@ module Torch
 
           step_size = group[:lr] / bias_correction1
 
-          p.data.addcdiv!(
+          p.data.addcdiv!(exp_avg, denom, value: -step_size)
         end
       end
 
data/lib/torch/optim/asgd.rb
CHANGED
data/lib/torch/optim/rmsprop.rb
CHANGED
@@ -46,25 +46,25 @@ module Torch
           state[:step] += 1
 
           if group[:weight_decay] != 0
-            grad = grad.add(group[:weight_decay]
+            grad = grad.add(p.data, alpha: group[:weight_decay])
           end
 
-          square_avg.mul!(alpha).addcmul!(1 - alpha
+          square_avg.mul!(alpha).addcmul!(grad, grad, value: 1 - alpha)
 
           if group[:centered]
             grad_avg = state[:grad_avg]
-            grad_avg.mul!(alpha).add!(1 - alpha
-            avg = square_avg.addcmul(
+            grad_avg.mul!(alpha).add!(grad, alpha: 1 - alpha)
+            avg = square_avg.addcmul(grad_avg, grad_avg, value: -1).sqrt!.add!(group[:eps])
           else
             avg = square_avg.sqrt.add!(group[:eps])
           end
 
           if group[:momentum] > 0
             buf = state[:momentum_buffer]
-            buf.mul!(group[:momentum]).addcdiv!(grad, avg)
-            p.data.add!(-group[:lr]
+            buf.mul!(group[:momentum]).addcdiv!(grad, avg, value: 1)
+            p.data.add!(buf, alpha: -group[:lr])
           else
-            p.data.addcdiv!(-group[:lr]
+            p.data.addcdiv!(grad, avg, value: -group[:lr])
           end
         end
       end