torch-rb 0.4.1 → 0.5.3
This diff shows the publicly released content of the two package versions as published to their registries. It is provided for informational purposes only.
- checksums.yaml +4 -4
- data/CHANGELOG.md +24 -0
- data/README.md +9 -1
- data/codegen/generate_functions.rb +13 -8
- data/codegen/native_functions.yaml +2363 -714
- data/ext/torch/ext.cpp +15 -16
- data/ext/torch/ruby_arg_parser.h +26 -2
- data/ext/torch/templates.h +1 -0
- data/ext/torch/wrap_outputs.h +7 -0
- data/lib/torch.rb +5 -7
- data/lib/torch/nn/module.rb +101 -21
- data/lib/torch/optim/adadelta.rb +3 -3
- data/lib/torch/optim/adagrad.rb +3 -3
- data/lib/torch/optim/adam.rb +2 -2
- data/lib/torch/optim/adamax.rb +2 -2
- data/lib/torch/optim/adamw.rb +2 -2
- data/lib/torch/optim/asgd.rb +1 -1
- data/lib/torch/optim/rmsprop.rb +7 -7
- data/lib/torch/optim/rprop.rb +1 -1
- data/lib/torch/optim/sgd.rb +2 -2
- data/lib/torch/tensor.rb +5 -0
- data/lib/torch/version.rb +1 -1
- metadata +3 -3
data/ext/torch/ext.cpp
CHANGED
@@ -44,8 +44,13 @@ std::vector<TensorIndex> index_vector(Array a) {
     if (obj.is_instance_of(rb_cInteger)) {
       indices.push_back(from_ruby<int64_t>(obj));
     } else if (obj.is_instance_of(rb_cRange)) {
-      torch::optional<int64_t> start_index =
-      torch::optional<int64_t> stop_index =
+      torch::optional<int64_t> start_index = torch::nullopt;
+      torch::optional<int64_t> stop_index = torch::nullopt;
+
+      Object begin = obj.call("begin");
+      if (!begin.is_nil()) {
+        start_index = from_ruby<int64_t>(begin);
+      }
 
       Object end = obj.call("end");
       if (!end.is_nil()) {
@@ -53,12 +58,14 @@ std::vector<TensorIndex> index_vector(Array a) {
       }
 
       Object exclude_end = obj.call("exclude_end?");
-      if (!exclude_end) {
+      if (stop_index.has_value() && !exclude_end) {
         if (stop_index.value() == -1) {
           stop_index = torch::nullopt;
         } else {
           stop_index = stop_index.value() + 1;
         }
+      } else if (!stop_index.has_value() && exclude_end) {
+        stop_index = -1;
       }
 
       indices.push_back(torch::indexing::Slice(start_index, stop_index));
@@ -348,16 +355,6 @@ void Init_ext()
       *[](Tensor& self) {
         return self.is_contiguous();
       })
-    .define_method(
-      "addcmul!",
-      *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
-        return self.addcmul_(tensor1, tensor2, value);
-      })
-    .define_method(
-      "addcdiv!",
-      *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
-        return self.addcdiv_(tensor1, tensor2, value);
-      })
     .define_method(
       "_requires_grad!",
       *[](Tensor& self, bool requires_grad) {
@@ -372,7 +369,7 @@ void Init_ext()
     .define_method(
       "grad=",
      *[](Tensor& self, torch::Tensor& grad) {
-        self.
+        self.mutable_grad() = grad;
       })
     .define_method(
       "_dtype",
@@ -609,7 +606,7 @@ void Init_ext()
     .define_method(
       "grad=",
       *[](Parameter& self, torch::Tensor& grad) {
-        self.
+        self.mutable_grad() = grad;
       });
 
   Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
@@ -628,5 +625,7 @@ void Init_ext()
   Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
    .add_handler<torch::Error>(handle_error)
    .define_singleton_method("available?", &torch::cuda::is_available)
-    .define_singleton_method("device_count", &torch::cuda::device_count);
+    .define_singleton_method("device_count", &torch::cuda::device_count)
+    .define_singleton_method("manual_seed", &torch::cuda::manual_seed)
+    .define_singleton_method("manual_seed_all", &torch::cuda::manual_seed_all);
 }
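The indexing rework means Ruby ranges with a nil begin or end (beginless and endless ranges) now map onto torch::indexing::Slice bounds, and the CUDA module gains seeding methods. A rough sketch of what this enables from Ruby (illustrative only; the indexing behaviour is inferred from the C++ above):

  x = Torch.arange(10)
  x[2..]     # endless range: everything from index 2 on
  x[..4]     # beginless range: indices 0 through 4 (the inclusive end is bumped internally)
  x[1...4]   # exclusive end: indices 1 through 3

  Torch::CUDA.manual_seed(42)      # seed the current CUDA device
  Torch::CUDA.manual_seed_all(42)  # seed every visible CUDA device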
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -91,7 +91,7 @@ struct RubyArgs {
   inline c10::optional<int64_t> toInt64Optional(int i);
   inline c10::optional<bool> toBoolOptional(int i);
   inline c10::optional<double> toDoubleOptional(int i);
-
+  inline c10::OptionalArray<double> doublelistOptional(int i);
   // inline at::Layout layout(int i);
   // inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
   inline c10::optional<at::Layout> layoutOptional(int i);
@@ -105,7 +105,7 @@ struct RubyArgs {
   inline c10::optional<at::MemoryFormat> memoryformatOptional(int i);
   // inline at::QScheme toQScheme(int i);
   inline std::string string(int i);
-
+  inline c10::optional<std::string> stringOptional(int i);
   // inline PyObject* pyobject(int i);
   inline int64_t toInt64(int i);
   // inline int64_t toInt64WithDefault(int i, int64_t default_int);
@@ -249,6 +249,25 @@ inline c10::optional<double> RubyArgs::toDoubleOptional(int i) {
   return toDouble(i);
 }
 
+inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
+  if (NIL_P(args[i])) return {};
+
+  VALUE arg = args[i];
+  auto size = RARRAY_LEN(arg);
+  std::vector<double> res(size);
+  for (idx = 0; idx < size; idx++) {
+    VALUE obj = rb_ary_entry(arg, idx);
+    if (FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj)) {
+      res[idx] = from_ruby<double>(obj);
+    } else {
+      rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
+        signature.name.c_str(), signature.params[i].name.c_str(),
+        signature.params[i].type_name().c_str(), rb_obj_classname(obj), idx + 1);
+    }
+  }
+  return res;
+}
+
 inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
   if (NIL_P(args[i])) return c10::nullopt;
 
@@ -285,6 +304,11 @@ inline std::string RubyArgs::string(int i) {
   return from_ruby<std::string>(args[i]);
 }
 
+inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
+  if (!args[i]) return c10::nullopt;
+  return from_ruby<std::string>(args[i]);
+}
+
 inline int64_t RubyArgs::toInt64(int i) {
   if (NIL_P(args[i])) return signature.params[i].default_int;
   return from_ruby<int64_t>(args[i]);
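The new doublelistOptional and stringOptional parsers let the generated bindings accept nil, an Array of numerics, or a String for optional native arguments. A purely hypothetical sketch of the calling convention they support (some_op, scales and rounding are made-up names used only for illustration, not real torch-rb methods):

  Torch.some_op(x, scales: nil)            # nil parses as "absent" (c10::nullopt / empty OptionalArray)
  Torch.some_op(x, scales: [2.0, 2.0])     # Integers and Floats are converted to a std::vector<double>
  Torch.some_op(x, scales: [2.0, "two"])   # a non-numeric element raises ArgumentError via rb_raise
  Torch.some_op(x, rounding: "floor")      # optional strings are parsed the same way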
data/ext/torch/templates.h
CHANGED
data/ext/torch/wrap_outputs.h
CHANGED
data/lib/torch.rb
CHANGED
@@ -261,6 +261,8 @@ module Torch
         Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
       elsif args.size == 1 && args.first.is_a?(Array)
         Torch.tensor(args.first, dtype: dtype, device: device)
+      elsif args.size == 0
+        Torch.empty(0, dtype: dtype, device: device)
       else
         Torch.empty(*args, dtype: dtype, device: device)
       end
@@ -434,7 +436,8 @@ module Torch
       zeros(input.size, **like_options(input, options))
     end
 
-
+    # center option
+    def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true, return_complex: nil)
       if center
         signal_dim = input.dim
         extended_shape = [1] * (3 - signal_dim) + input.size
@@ -442,12 +445,7 @@ module Torch
         input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
         input = input.view(input.shape[-signal_dim..-1])
       end
-      _stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
-    end
-
-    def clamp(tensor, min, max)
-      tensor = _clamp_min(tensor, min)
-      _clamp_max(tensor, max)
+      _stft(input, n_fft, hop_length, win_length, window, normalized, onesided, return_complex)
     end
 
     private
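Three user-visible effects of this hunk: constructing a typed tensor class with no arguments now yields an empty tensor, stft gains a return_complex option that is forwarded to the native _stft, and the hand-written clamp wrapper is dropped (clamp is expected to come from the generated bindings instead). A hedged sketch, assuming Torch::FloatTensor.new is a constructor that hits the changed branch:

  Torch::FloatTensor.new          # => tensor of shape [0]
  Torch::FloatTensor.new(2, 3)    # => uninitialized 2x3 tensor, as before

  signal = Torch.randn(400)
  spec = Torch.stft(signal, 64, return_complex: false)

  # clamp should still work, now via the generated binding rather than the removed Ruby method
  Torch.clamp(Torch.tensor([-2.0, 0.5, 3.0]), -1, 1)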
data/lib/torch/nn/module.rb
CHANGED
@@ -113,35 +113,53 @@ module Torch
         forward(*input, **kwargs)
       end
 
-      def state_dict(destination: nil)
+      def state_dict(destination: nil, prefix: "")
         destination ||= {}
-
-
+        save_to_state_dict(destination, prefix: prefix)
+
+        named_children.each do |name, mod|
+          next unless mod
+          mod.state_dict(destination: destination, prefix: prefix + name + ".")
         end
         destination
       end
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            raise Error, "Unknown module: #{k1}"
+      def load_state_dict(state_dict, strict: true)
+        # TODO support strict: false
+        raise "strict: false not implemented yet" unless strict
+
+        missing_keys = []
+        unexpected_keys = []
+        error_msgs = []
+
+        # TODO handle metadata
+
+        _load = lambda do |mod, prefix = ""|
+          # TODO handle metadata
+          local_metadata = {}
+          mod.send(:load_from_state_dict, state_dict, prefix, local_metadata, true, missing_keys, unexpected_keys, error_msgs)
+          mod.named_children.each do |name, child|
+            _load.call(child, prefix + name + ".") unless child.nil?
           end
         end
 
-
+        _load.call(self)
+
+        if strict
+          if unexpected_keys.any?
+            error_msgs << "Unexpected key(s) in state_dict: #{unexpected_keys.join(", ")}"
+          end
+
+          if missing_keys.any?
+            error_msgs << "Missing key(s) in state_dict: #{missing_keys.join(", ")}"
+          end
+        end
+
+        if error_msgs.any?
+          # just show first error
+          raise Error, error_msgs[0]
+        end
+
         nil
       end
 
@@ -300,6 +318,68 @@ module Torch
       def dict
         instance_variables.reject { |k| instance_variable_get(k).is_a?(Tensor) }.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
       end
+
+      def load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+        # TODO add hooks
+
+        # TODO handle non-persistent buffers
+        persistent_buffers = named_buffers
+        local_name_params = named_parameters(recurse: false).merge(persistent_buffers)
+        local_state = local_name_params.select { |_, v| !v.nil? }
+
+        local_state.each do |name, param|
+          key = prefix + name
+          if state_dict.key?(key)
+            input_param = state_dict[key]
+
+            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+            if param.shape.length == 0 && input_param.shape.length == 1
+              input_param = input_param[0]
+            end
+
+            if input_param.shape != param.shape
+              # local shape should match the one in checkpoint
+              error_msgs << "size mismatch for #{key}: copying a param with shape #{input_param.shape} from checkpoint, " +
+                "the shape in current model is #{param.shape}."
+              next
+            end
+
+            begin
+              Torch.no_grad do
+                param.copy!(input_param)
+              end
+            rescue => e
+              error_msgs << "While copying the parameter named #{key.inspect}, " +
+                "whose dimensions in the model are #{param.size} and " +
+                "whose dimensions in the checkpoint are #{input_param.size}, " +
+                "an exception occurred: #{e.inspect}"
+            end
+          elsif strict
+            missing_keys << key
+          end
+        end
+
+        if strict
+          state_dict.each_key do |key|
+            if key.start_with?(prefix)
+              input_name = key[prefix.length..-1]
+              input_name = input_name.split(".", 2)[0]
+              if !named_children.key?(input_name) && !local_state.key?(input_name)
+                unexpected_keys << key
+              end
+            end
+          end
+        end
+      end
+
+      def save_to_state_dict(destination, prefix: "")
+        named_parameters(recurse: false).each do |k, v|
+          destination[prefix + k] = v
+        end
+        named_buffers.each do |k, v|
+          destination[prefix + k] = v
+        end
+      end
     end
   end
 end
data/lib/torch/optim/adadelta.rb
CHANGED
@@ -39,14 +39,14 @@ module Torch
             state[:step] += 1
 
             if group[:weight_decay] != 0
-              grad = grad.add(group[:weight_decay]
+              grad = grad.add(p.data, alpha: group[:weight_decay])
             end
 
-            square_avg.mul!(rho).addcmul!(1 - rho
+            square_avg.mul!(rho).addcmul!(grad, grad, value: 1 - rho)
             std = square_avg.add(eps).sqrt!
             delta = acc_delta.add(eps).sqrt!.div!(std).mul!(grad)
             p.data.add!(delta, alpha: -group[:lr])
-            acc_delta.mul!(rho).addcmul!(1 - rho
+            acc_delta.mul!(rho).addcmul!(delta, delta, value: 1 - rho)
           end
         end
 
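The optimizer hunks here and below are all the same mechanical migration: the old hand-bound addcmul!/addcdiv! methods that took the scalar first (removed from ext.cpp above) are replaced by the generated bindings that take the two tensors first and the scalar as a value: keyword, and add/add! take alpha:, matching PyTorch. The new call forms, exactly as used in these diffs (example values are illustrative):

  a  = Torch.zeros(3)
  t1 = Torch.tensor([1.0, 2.0, 3.0])
  t2 = Torch.tensor([4.0, 5.0, 6.0])

  a.addcmul!(t1, t2, value: 0.1)   # in place: a += 0.1 * t1 * t2
  a.addcdiv!(t1, t2, value: 0.1)   # in place: a += 0.1 * t1 / t2
  a.add!(t1, alpha: 0.9)           # in place: a += 0.9 * t1

The Adagrad, Adam, AdamW, Adamax, RMSprop and remaining optimizer files below apply exactly this renaming.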
data/lib/torch/optim/adagrad.rb
CHANGED
@@ -49,7 +49,7 @@ module Torch
              if p.grad.data.sparse?
                raise Error, "weight_decay option is not compatible with sparse gradients"
              end
-              grad = grad.add(group[:weight_decay]
+              grad = grad.add(p.data, alpha: group[:weight_decay])
             end
 
             clr = group[:lr] / (1 + (state[:step] - 1) * group[:lr_decay])
@@ -57,9 +57,9 @@ module Torch
             if grad.sparse?
               raise NotImplementedYet
             else
-              state[:sum].addcmul!(
+              state[:sum].addcmul!(grad, grad, value: 1)
               std = state[:sum].sqrt.add!(group[:eps])
-              p.data.addcdiv!(
+              p.data.addcdiv!(grad, std, value: -clr)
             end
           end
         end
data/lib/torch/optim/adam.rb
CHANGED
@@ -58,7 +58,7 @@ module Torch
 
             # Decay the first and second moment running average coefficient
             exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
-            exp_avg_sq.mul!(beta2).addcmul!(1 - beta2
+            exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
             if amsgrad
               # Maintains the maximum of all 2nd moment running avg. till now
               Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
@@ -70,7 +70,7 @@ module Torch
 
             step_size = group[:lr] / bias_correction1
 
-            p.data.addcdiv!(
+            p.data.addcdiv!(exp_avg, denom, value: -step_size)
           end
         end
 
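For readers less used to the fused tensor ops, the Adam hunk's in-place calls compute the usual exponential moving averages and parameter update. A small out-of-place sketch of the same arithmetic (omitting amsgrad, weight decay and bias correction; values are made up):

  grad       = Torch.tensor([0.1, -0.2, 0.3])
  exp_avg    = Torch.zeros(3)
  exp_avg_sq = Torch.zeros(3)
  beta1, beta2, lr = 0.9, 0.999, 1e-3
  eps = 1e-8

  exp_avg    = exp_avg * beta1 + grad * (1 - beta1)            # exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
  exp_avg_sq = exp_avg_sq * beta2 + grad * grad * (1 - beta2)  # exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
  denom      = exp_avg_sq.sqrt + eps
  update     = exp_avg / denom * lr                            # p.data.addcdiv!(exp_avg, denom, value: -step_size)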
data/lib/torch/optim/adamax.rb
CHANGED
@@ -42,7 +42,7 @@ module Torch
             state[:step] += 1
 
             if group[:weight_decay] != 0
-              grad = grad.add(group[:weight_decay]
+              grad = grad.add(p.data, alpha: group[:weight_decay])
             end
 
             # Update biased first moment estimate.
@@ -57,7 +57,7 @@ module Torch
             bias_correction = 1 - beta1 ** state[:step]
             clr = group[:lr] / bias_correction
 
-            p.data.addcdiv!(
+            p.data.addcdiv!(exp_avg, exp_inf, value: -clr)
           end
         end
 
data/lib/torch/optim/adamw.rb
CHANGED
@@ -59,7 +59,7 @@ module Torch
 
             # Decay the first and second moment running average coefficient
             exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
-            exp_avg_sq.mul!(beta2).addcmul!(1 - beta2
+            exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
             if amsgrad
               # Maintains the maximum of all 2nd moment running avg. till now
               Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
@@ -71,7 +71,7 @@ module Torch
 
             step_size = group[:lr] / bias_correction1
 
-            p.data.addcdiv!(
+            p.data.addcdiv!(exp_avg, denom, value: -step_size)
           end
         end
 
data/lib/torch/optim/asgd.rb
CHANGED
data/lib/torch/optim/rmsprop.rb
CHANGED
@@ -46,25 +46,25 @@ module Torch
             state[:step] += 1
 
             if group[:weight_decay] != 0
-              grad = grad.add(group[:weight_decay]
+              grad = grad.add(p.data, alpha: group[:weight_decay])
             end
 
-            square_avg.mul!(alpha).addcmul!(1 - alpha
+            square_avg.mul!(alpha).addcmul!(grad, grad, value: 1 - alpha)
 
             if group[:centered]
               grad_avg = state[:grad_avg]
-              grad_avg.mul!(alpha).add!(1 - alpha
-              avg = square_avg.addcmul(
+              grad_avg.mul!(alpha).add!(grad, alpha: 1 - alpha)
+              avg = square_avg.addcmul(grad_avg, grad_avg, value: -1).sqrt!.add!(group[:eps])
             else
               avg = square_avg.sqrt.add!(group[:eps])
             end
 
             if group[:momentum] > 0
               buf = state[:momentum_buffer]
-              buf.mul!(group[:momentum]).addcdiv!(grad, avg)
-              p.data.add!(-group[:lr]
+              buf.mul!(group[:momentum]).addcdiv!(grad, avg, value: 1)
+              p.data.add!(buf, alpha: -group[:lr])
             else
-              p.data.addcdiv!(-group[:lr]
+              p.data.addcdiv!(grad, avg, value: -group[:lr])
             end
           end
         end
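In the centered branch, the new chained expression computes the same quantity as before but through the updated addcmul signature: avg = sqrt(E[g^2] - E[g]^2) + eps, an estimate of the gradient's standard deviation rather than its raw second moment. Written out-of-place, with made-up values:

  grad_avg   = Torch.tensor([0.05, -0.10, 0.20])   # running mean of gradients
  square_avg = Torch.tensor([0.04, 0.05, 0.06])    # running mean of squared gradients
  eps        = 1e-8

  # equivalent of square_avg.addcmul(grad_avg, grad_avg, value: -1).sqrt!.add!(group[:eps])
  avg = (square_avg - grad_avg * grad_avg).sqrt + eps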