torch-rb 0.3.6 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +27 -0
- data/README.md +3 -1
- data/codegen/function.rb +134 -0
- data/codegen/generate_functions.rb +557 -0
- data/{lib/torch/native → codegen}/native_functions.yaml +2363 -714
- data/ext/torch/ext.cpp +78 -89
- data/ext/torch/extconf.rb +5 -2
- data/ext/torch/nn_functions.h +6 -0
- data/ext/torch/ruby_arg_parser.cpp +593 -0
- data/ext/torch/ruby_arg_parser.h +397 -0
- data/ext/torch/{templates.hpp → templates.h} +46 -77
- data/ext/torch/tensor_functions.h +6 -0
- data/ext/torch/torch_functions.h +6 -0
- data/ext/torch/utils.h +42 -0
- data/ext/torch/{templates.cpp → wrap_outputs.h} +44 -8
- data/lib/torch.rb +35 -62
- data/lib/torch/nn/functional.rb +136 -16
- data/lib/torch/nn/init.rb +5 -19
- data/lib/torch/nn/module.rb +4 -1
- data/lib/torch/nn/upsample.rb +31 -0
- data/lib/torch/optim/adadelta.rb +4 -4
- data/lib/torch/optim/adagrad.rb +3 -3
- data/lib/torch/optim/adam.rb +4 -4
- data/lib/torch/optim/adamax.rb +3 -3
- data/lib/torch/optim/adamw.rb +3 -3
- data/lib/torch/optim/asgd.rb +2 -2
- data/lib/torch/optim/rmsprop.rb +7 -7
- data/lib/torch/optim/rprop.rb +1 -1
- data/lib/torch/optim/sgd.rb +5 -5
- data/lib/torch/tensor.rb +36 -110
- data/lib/torch/version.rb +1 -1
- metadata +19 -14
- data/lib/torch/native/dispatcher.rb +0 -48
- data/lib/torch/native/function.rb +0 -119
- data/lib/torch/native/generator.rb +0 -168
- data/lib/torch/native/parser.rb +0 -148
data/ext/torch/utils.h
ADDED
@@ -0,0 +1,42 @@
|
|
1
|
+
#pragma once
|
2
|
+
|
3
|
+
#include <rice/Symbol.hpp>
|
4
|
+
|
5
|
+
// keep THP prefix for now to make it easier to compare code
|
6
|
+
|
7
|
+
extern VALUE THPVariableClass;
|
8
|
+
|
9
|
+
inline VALUE THPUtils_internSymbol(const std::string& str) {
|
10
|
+
return Symbol(str);
|
11
|
+
}
|
12
|
+
|
13
|
+
inline std::string THPUtils_unpackSymbol(VALUE obj) {
|
14
|
+
Check_Type(obj, T_SYMBOL);
|
15
|
+
obj = rb_funcall(obj, rb_intern("to_s"), 0);
|
16
|
+
return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
|
17
|
+
}
|
18
|
+
|
19
|
+
inline std::string THPUtils_unpackString(VALUE obj) {
|
20
|
+
Check_Type(obj, T_STRING);
|
21
|
+
return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
|
22
|
+
}
|
23
|
+
|
24
|
+
inline bool THPUtils_checkSymbol(VALUE obj) {
|
25
|
+
return SYMBOL_P(obj);
|
26
|
+
}
|
27
|
+
|
28
|
+
inline bool THPUtils_checkIndex(VALUE obj) {
|
29
|
+
return FIXNUM_P(obj);
|
30
|
+
}
|
31
|
+
|
32
|
+
inline bool THPUtils_checkScalar(VALUE obj) {
|
33
|
+
return FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj) || RB_TYPE_P(obj, T_COMPLEX);
|
34
|
+
}
|
35
|
+
|
36
|
+
inline bool THPVariable_Check(VALUE obj) {
|
37
|
+
return rb_obj_is_kind_of(obj, THPVariableClass);
|
38
|
+
}
|
39
|
+
|
40
|
+
inline bool THPVariable_CheckExact(VALUE obj) {
|
41
|
+
return rb_obj_is_instance_of(obj, THPVariableClass);
|
42
|
+
}
|
@@ -1,15 +1,44 @@
|
|
1
|
+
#pragma once
|
2
|
+
|
1
3
|
#include <torch/torch.h>
|
2
4
|
#include <rice/Object.hpp>
|
3
|
-
#include "templates.hpp"
|
4
5
|
|
5
|
-
Object wrap(
|
6
|
+
inline Object wrap(bool x) {
|
7
|
+
return to_ruby<bool>(x);
|
8
|
+
}
|
9
|
+
|
10
|
+
inline Object wrap(int64_t x) {
|
11
|
+
return to_ruby<int64_t>(x);
|
12
|
+
}
|
13
|
+
|
14
|
+
inline Object wrap(double x) {
|
15
|
+
return to_ruby<double>(x);
|
16
|
+
}
|
17
|
+
|
18
|
+
inline Object wrap(torch::Tensor x) {
|
19
|
+
return to_ruby<torch::Tensor>(x);
|
20
|
+
}
|
21
|
+
|
22
|
+
inline Object wrap(torch::Scalar x) {
|
23
|
+
return to_ruby<torch::Scalar>(x);
|
24
|
+
}
|
25
|
+
|
26
|
+
inline Object wrap(torch::ScalarType x) {
|
27
|
+
return to_ruby<torch::ScalarType>(x);
|
28
|
+
}
|
29
|
+
|
30
|
+
inline Object wrap(torch::QScheme x) {
|
31
|
+
return to_ruby<torch::QScheme>(x);
|
32
|
+
}
|
33
|
+
|
34
|
+
inline Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
|
6
35
|
Array a;
|
7
36
|
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
8
37
|
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
9
38
|
return Object(a);
|
10
39
|
}
|
11
40
|
|
12
|
-
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
41
|
+
inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
13
42
|
Array a;
|
14
43
|
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
15
44
|
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
@@ -17,7 +46,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
|
17
46
|
return Object(a);
|
18
47
|
}
|
19
48
|
|
20
|
-
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
49
|
+
inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
21
50
|
Array a;
|
22
51
|
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
23
52
|
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
@@ -26,7 +55,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tenso
|
|
26
55
|
return Object(a);
|
27
56
|
}
|
28
57
|
|
29
|
-
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
58
|
+
inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
30
59
|
Array a;
|
31
60
|
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
32
61
|
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
@@ -36,7 +65,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tenso
|
|
36
65
|
return Object(a);
|
37
66
|
}
|
38
67
|
|
39
|
-
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
|
68
|
+
inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
|
40
69
|
Array a;
|
41
70
|
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
42
71
|
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
@@ -45,7 +74,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x)
|
|
45
74
|
return Object(a);
|
46
75
|
}
|
47
76
|
|
48
|
-
Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
|
77
|
+
inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
|
49
78
|
Array a;
|
50
79
|
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
51
80
|
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
@@ -54,10 +83,17 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
|
|
54
83
|
return Object(a);
|
55
84
|
}
|
56
85
|
|
57
|
-
Object wrap(
|
86
|
+
inline Object wrap(torch::TensorList x) {
|
58
87
|
Array a;
|
59
88
|
for (auto& t : x) {
|
60
89
|
a.push(to_ruby<torch::Tensor>(t));
|
61
90
|
}
|
62
91
|
return Object(a);
|
63
92
|
}
|
93
|
+
|
94
|
+
inline Object wrap(std::tuple<double, double> x) {
|
95
|
+
Array a;
|
96
|
+
a.push(to_ruby<double>(std::get<0>(x)));
|
97
|
+
a.push(to_ruby<double>(std::get<1>(x)));
|
98
|
+
return Object(a);
|
99
|
+
}
|
data/lib/torch.rb
CHANGED
@@ -7,11 +7,6 @@ require "net/http"
|
|
7
7
|
require "set"
|
8
8
|
require "tmpdir"
|
9
9
|
|
10
|
-
# native functions
|
11
|
-
require "torch/native/generator"
|
12
|
-
require "torch/native/parser"
|
13
|
-
require "torch/native/dispatcher"
|
14
|
-
|
15
10
|
# modules
|
16
11
|
require "torch/inspector"
|
17
12
|
require "torch/tensor"
|
@@ -174,6 +169,9 @@ require "torch/nn/smooth_l1_loss"
|
|
174
169
|
require "torch/nn/soft_margin_loss"
|
175
170
|
require "torch/nn/triplet_margin_loss"
|
176
171
|
|
172
|
+
# nn vision
|
173
|
+
require "torch/nn/upsample"
|
174
|
+
|
177
175
|
# nn other
|
178
176
|
require "torch/nn/functional"
|
179
177
|
require "torch/nn/init"
|
@@ -196,6 +194,32 @@ module Torch
|
|
196
194
|
end
|
197
195
|
end
|
198
196
|
|
197
|
+
# legacy
|
198
|
+
# but may make it easier to port tutorials
|
199
|
+
module Autograd
|
200
|
+
class Variable
|
201
|
+
def self.new(x)
|
202
|
+
raise ArgumentError, "Variable data has to be a tensor, but got #{x.class.name}" unless x.is_a?(Tensor)
|
203
|
+
warn "[torch] The Variable API is deprecated. Use tensors with requires_grad: true instead."
|
204
|
+
x
|
205
|
+
end
|
206
|
+
end
|
207
|
+
end
|
208
|
+
|
209
|
+
# TODO move to C++
|
210
|
+
class ByteStorage
|
211
|
+
# private
|
212
|
+
attr_reader :bytes
|
213
|
+
|
214
|
+
def initialize(bytes)
|
215
|
+
@bytes = bytes
|
216
|
+
end
|
217
|
+
|
218
|
+
def self.from_buffer(bytes)
|
219
|
+
new(bytes)
|
220
|
+
end
|
221
|
+
end
|
222
|
+
|
199
223
|
# keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
|
200
224
|
# values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
|
201
225
|
DTYPE_TO_ENUM = {
|
@@ -224,18 +248,24 @@ module Torch
|
|
224
248
|
}
|
225
249
|
ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
|
226
250
|
|
251
|
+
TENSOR_TYPE_CLASSES = []
|
252
|
+
|
227
253
|
def self._make_tensor_class(dtype, cuda = false)
|
228
254
|
cls = Class.new
|
229
255
|
device = cuda ? "cuda" : "cpu"
|
230
256
|
cls.define_singleton_method("new") do |*args|
|
231
257
|
if args.size == 1 && args.first.is_a?(Tensor)
|
232
258
|
args.first.send(dtype).to(device)
|
259
|
+
elsif args.size == 1 && args.first.is_a?(ByteStorage) && dtype == :uint8
|
260
|
+
bytes = args.first.bytes
|
261
|
+
Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
|
233
262
|
elsif args.size == 1 && args.first.is_a?(Array)
|
234
263
|
Torch.tensor(args.first, dtype: dtype, device: device)
|
235
264
|
else
|
236
265
|
Torch.empty(*args, dtype: dtype, device: device)
|
237
266
|
end
|
238
267
|
end
|
268
|
+
TENSOR_TYPE_CLASSES << cls
|
239
269
|
cls
|
240
270
|
end
|
241
271
|
|
@@ -339,63 +369,6 @@ module Torch
|
|
339
369
|
|
340
370
|
# --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
|
341
371
|
|
342
|
-
def arange(start, finish = nil, step = 1, **options)
|
343
|
-
# ruby doesn't support start = 0, finish, step = 1, ...
|
344
|
-
if finish.nil?
|
345
|
-
finish = start
|
346
|
-
start = 0
|
347
|
-
end
|
348
|
-
_arange(start, finish, step, tensor_options(**options))
|
349
|
-
end
|
350
|
-
|
351
|
-
def empty(*size, **options)
|
352
|
-
_empty(tensor_size(size), tensor_options(**options))
|
353
|
-
end
|
354
|
-
|
355
|
-
def eye(n, m = nil, **options)
|
356
|
-
_eye(n, m || n, tensor_options(**options))
|
357
|
-
end
|
358
|
-
|
359
|
-
def full(size, fill_value, **options)
|
360
|
-
_full(size, fill_value, tensor_options(**options))
|
361
|
-
end
|
362
|
-
|
363
|
-
def linspace(start, finish, steps = 100, **options)
|
364
|
-
_linspace(start, finish, steps, tensor_options(**options))
|
365
|
-
end
|
366
|
-
|
367
|
-
def logspace(start, finish, steps = 100, base = 10.0, **options)
|
368
|
-
_logspace(start, finish, steps, base, tensor_options(**options))
|
369
|
-
end
|
370
|
-
|
371
|
-
def ones(*size, **options)
|
372
|
-
_ones(tensor_size(size), tensor_options(**options))
|
373
|
-
end
|
374
|
-
|
375
|
-
def rand(*size, **options)
|
376
|
-
_rand(tensor_size(size), tensor_options(**options))
|
377
|
-
end
|
378
|
-
|
379
|
-
def randint(low = 0, high, size, **options)
|
380
|
-
_randint(low, high, size, tensor_options(**options))
|
381
|
-
end
|
382
|
-
|
383
|
-
def randn(*size, **options)
|
384
|
-
_randn(tensor_size(size), tensor_options(**options))
|
385
|
-
end
|
386
|
-
|
387
|
-
def randperm(n, **options)
|
388
|
-
# dtype hack in Python
|
389
|
-
# https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
|
390
|
-
options[:dtype] ||= :int64
|
391
|
-
|
392
|
-
_randperm(n, tensor_options(**options))
|
393
|
-
end
|
394
|
-
|
395
|
-
def zeros(*size, **options)
|
396
|
-
_zeros(tensor_size(size), tensor_options(**options))
|
397
|
-
end
|
398
|
-
|
399
372
|
def tensor(data, **options)
|
400
373
|
if options[:dtype].nil? && defined?(Numo::NArray) && data.is_a?(Numo::NArray)
|
401
374
|
numo_to_dtype = _dtype_to_numo.map(&:reverse).to_h
|
data/lib/torch/nn/functional.rb
CHANGED
@@ -394,15 +394,15 @@ module Torch
|
|
394
394
|
# loss functions
|
395
395
|
|
396
396
|
def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
|
397
|
-
NN.binary_cross_entropy(input, target, weight, reduction)
|
397
|
+
NN.binary_cross_entropy(input, target, weight, to_reduction(reduction))
|
398
398
|
end
|
399
399
|
|
400
400
|
def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
|
401
|
-
Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
|
401
|
+
Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, to_reduction(reduction))
|
402
402
|
end
|
403
403
|
|
404
404
|
def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
|
405
|
-
Torch.cosine_embedding_loss(input1, input2, target, margin, reduction)
|
405
|
+
Torch.cosine_embedding_loss(input1, input2, target, margin, to_reduction(reduction))
|
406
406
|
end
|
407
407
|
|
408
408
|
def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
|
@@ -411,34 +411,34 @@ module Torch
|
|
411
411
|
|
412
412
|
def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
|
413
413
|
# call to_a on input_lengths and target_lengths for C++
|
414
|
-
Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
|
414
|
+
Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, to_reduction(reduction), zero_infinity)
|
415
415
|
end
|
416
416
|
|
417
417
|
def hinge_embedding_loss(input, target, margin: 1.0, reduction: "mean")
|
418
|
-
Torch.hinge_embedding_loss(input, target, margin, reduction)
|
418
|
+
Torch.hinge_embedding_loss(input, target, margin, to_reduction(reduction))
|
419
419
|
end
|
420
420
|
|
421
421
|
def kl_div(input, target, reduction: "mean")
|
422
|
-
Torch.kl_div(input, target, reduction)
|
422
|
+
Torch.kl_div(input, target, to_reduction(reduction))
|
423
423
|
end
|
424
424
|
|
425
425
|
def l1_loss(input, target, reduction: "mean")
|
426
|
-
NN.l1_loss(input, target, reduction)
|
426
|
+
NN.l1_loss(input, target, to_reduction(reduction))
|
427
427
|
end
|
428
428
|
|
429
429
|
def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
|
430
|
-
Torch.margin_ranking_loss(input1, input2, target, margin, reduction)
|
430
|
+
Torch.margin_ranking_loss(input1, input2, target, margin, to_reduction(reduction))
|
431
431
|
end
|
432
432
|
|
433
433
|
def mse_loss(input, target, reduction: "mean")
|
434
434
|
if target.size != input.size
|
435
435
|
warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
|
436
436
|
end
|
437
|
-
NN.mse_loss(input, target, reduction)
|
437
|
+
NN.mse_loss(input, target, to_reduction(reduction))
|
438
438
|
end
|
439
439
|
|
440
440
|
def multilabel_margin_loss(input, target, reduction: "mean")
|
441
|
-
NN.multilabel_margin_loss(input, target, reduction)
|
441
|
+
NN.multilabel_margin_loss(input, target, to_reduction(reduction))
|
442
442
|
end
|
443
443
|
|
444
444
|
def multilabel_soft_margin_loss(input, target, weight: nil)
|
@@ -446,31 +446,116 @@ module Torch
|
|
446
446
|
end
|
447
447
|
|
448
448
|
def multi_margin_loss(input, target, p: 1, margin: 1.0, weight: nil, reduction: "mean")
|
449
|
-
NN.multi_margin_loss(input, target, p, margin, weight, reduction)
|
449
|
+
NN.multi_margin_loss(input, target, p, margin, weight, to_reduction(reduction))
|
450
450
|
end
|
451
451
|
|
452
452
|
def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
|
453
|
-
NN.nll_loss(input, target, weight, reduction, ignore_index)
|
453
|
+
NN.nll_loss(input, target, weight, to_reduction(reduction), ignore_index)
|
454
454
|
end
|
455
455
|
|
456
456
|
def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
|
457
|
-
Torch.poisson_nll_loss(input, target, log_input, full, eps, reduction)
|
457
|
+
Torch.poisson_nll_loss(input, target, log_input, full, eps, to_reduction(reduction))
|
458
458
|
end
|
459
459
|
|
460
460
|
def soft_margin_loss(input, target, reduction: "mean")
|
461
|
-
NN.soft_margin_loss(input, target, reduction)
|
461
|
+
NN.soft_margin_loss(input, target, to_reduction(reduction))
|
462
462
|
end
|
463
463
|
|
464
464
|
def smooth_l1_loss(input, target, reduction: "mean")
|
465
|
-
NN.smooth_l1_loss(input, target, reduction)
|
465
|
+
NN.smooth_l1_loss(input, target, to_reduction(reduction))
|
466
466
|
end
|
467
467
|
|
468
468
|
def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean")
|
469
|
-
Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
|
469
|
+
Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, to_reduction(reduction))
|
470
|
+
end
|
471
|
+
|
472
|
+
# vision
|
473
|
+
|
474
|
+
def interpolate(input, size: nil, scale_factor: nil, mode: "nearest", align_corners: nil, recompute_scale_factor: nil)
|
475
|
+
if ["nearest", "area"].include?(mode)
|
476
|
+
unless align_corners.nil?
|
477
|
+
raise ArgumentError, "align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear"
|
478
|
+
end
|
479
|
+
else
|
480
|
+
if align_corners.nil?
|
481
|
+
align_corners = false
|
482
|
+
end
|
483
|
+
end
|
484
|
+
|
485
|
+
scale_factor_len = input.dim - 2
|
486
|
+
scale_factor_list = [nil] * scale_factor_len
|
487
|
+
# default value of recompute_scale_factor is False
|
488
|
+
if !scale_factor.nil? && (recompute_scale_factor == false || recompute_scale_factor.nil?)
|
489
|
+
if scale_factor.is_a?(Array)
|
490
|
+
_scale_factor_repeated = scale_factor
|
491
|
+
else
|
492
|
+
_scale_factor_repeated = [scale_factor] * scale_factor_len
|
493
|
+
end
|
494
|
+
scale_factor_list = _scale_factor_repeated
|
495
|
+
end
|
496
|
+
|
497
|
+
# Give this variable a short name because it has to be repeated multiple times below.
|
498
|
+
sfl = scale_factor_list
|
499
|
+
|
500
|
+
closed_over_args = [input, size, scale_factor, recompute_scale_factor]
|
501
|
+
output_size = _interp_output_size(closed_over_args)
|
502
|
+
if input.dim == 3 && mode == "nearest"
|
503
|
+
NN.upsample_nearest1d(input, output_size, sfl[0])
|
504
|
+
elsif input.dim == 4 && mode == "nearest"
|
505
|
+
NN.upsample_nearest2d(input, output_size, sfl[0], sfl[1])
|
506
|
+
elsif input.dim == 5 && mode == "nearest"
|
507
|
+
NN.upsample_nearest3d(input, output_size, sfl[0], sfl[1], sfl[2])
|
508
|
+
elsif input.dim == 3 && mode == "area"
|
509
|
+
adaptive_avg_pool1d(input, output_size)
|
510
|
+
elsif input.dim == 4 && mode == "area"
|
511
|
+
adaptive_avg_pool2d(input, output_size)
|
512
|
+
elsif input.dim == 5 && mode == "area"
|
513
|
+
adaptive_avg_pool3d(input, output_size)
|
514
|
+
elsif input.dim == 3 && mode == "linear"
|
515
|
+
# assert align_corners is not None
|
516
|
+
NN.upsample_linear1d(input, output_size, align_corners, sfl[0])
|
517
|
+
elsif input.dim == 3 && mode == "bilinear"
|
518
|
+
raise ArgumentError, "Got 3D input, but bilinear mode needs 4D input"
|
519
|
+
elsif input.dim == 3 && mode == "trilinear"
|
520
|
+
raise ArgumentError, "Got 3D input, but trilinear mode needs 5D input"
|
521
|
+
elsif input.dim == 4 && mode == "linear"
|
522
|
+
raise ArgumentError, "Got 4D input, but linear mode needs 3D input"
|
523
|
+
elsif input.dim == 4 && mode == "bilinear"
|
524
|
+
# assert align_corners is not None
|
525
|
+
NN.upsample_bilinear2d(input, output_size, align_corners, sfl[0], sfl[1])
|
526
|
+
elsif input.dim == 4 && mode == "trilinear"
|
527
|
+
raise ArgumentError, "Got 4D input, but trilinear mode needs 5D input"
|
528
|
+
elsif input.dim == 5 && mode == "linear"
|
529
|
+
raise ArgumentError, "Got 5D input, but linear mode needs 3D input"
|
530
|
+
elsif input.dim == 5 && mode == "bilinear"
|
531
|
+
raise ArgumentError, "Got 5D input, but bilinear mode needs 4D input"
|
532
|
+
elsif input.dim == 5 && mode == "trilinear"
|
533
|
+
# assert align_corners is not None
|
534
|
+
NN.upsample_trilinear3d(input, output_size, align_corners, sfl[0], sfl[1], sfl[2])
|
535
|
+
elsif input.dim == 4 && mode == "bicubic"
|
536
|
+
# assert align_corners is not None
|
537
|
+
NN.upsample_bicubic2d(input, output_size, align_corners, sfl[0], sfl[1])
|
538
|
+
else
|
539
|
+
raise ArgumentError, "Input Error: Only 3D, 4D and 5D input Tensors supported (got #{input.dim}D) for the modes: nearest | linear | bilinear | bicubic | trilinear (got #{mode})"
|
540
|
+
end
|
470
541
|
end
|
471
542
|
|
472
543
|
private
|
473
544
|
|
545
|
+
# see _reduction.py
|
546
|
+
def to_reduction(v)
|
547
|
+
case v.to_s
|
548
|
+
when "none"
|
549
|
+
0
|
550
|
+
when "mean"
|
551
|
+
1
|
552
|
+
when "sum"
|
553
|
+
2
|
554
|
+
else
|
555
|
+
raise ArgumentError, "#{v} is not a valid value for reduction"
|
556
|
+
end
|
557
|
+
end
|
558
|
+
|
474
559
|
def softmax_dim(ndim)
|
475
560
|
ndim == 0 || ndim == 1 || ndim == 3 ? 0 : 1
|
476
561
|
end
|
@@ -484,6 +569,41 @@ module Torch
|
|
484
569
|
out_size.zip(defaults.last(out_size.length)).map { |v, d| v || d }
|
485
570
|
end
|
486
571
|
end
|
572
|
+
|
573
|
+
def _interp_output_size(closed_over_args)
|
574
|
+
input, size, scale_factor, recompute_scale_factor = closed_over_args
|
575
|
+
dim = input.dim - 2
|
576
|
+
if size.nil? && scale_factor.nil?
|
577
|
+
raise ArgumentError, "either size or scale_factor should be defined"
|
578
|
+
end
|
579
|
+
if !size.nil? && !scale_factor.nil?
|
580
|
+
raise ArgumentError, "only one of size or scale_factor should be defined"
|
581
|
+
end
|
582
|
+
if !scale_factor.nil?
|
583
|
+
if scale_factor.is_a?(Array)
|
584
|
+
if scale_factor.length != dim
|
585
|
+
raise ArgumentError, "scale_factor shape must match input shape. Input is #{dim}D, scale_factor size is #{scale_factor.length}"
|
586
|
+
end
|
587
|
+
end
|
588
|
+
end
|
589
|
+
|
590
|
+
if !size.nil?
|
591
|
+
if size.is_a?(Array)
|
592
|
+
return size
|
593
|
+
else
|
594
|
+
return [size] * dim
|
595
|
+
end
|
596
|
+
end
|
597
|
+
|
598
|
+
raise "Failed assertion" if scale_factor.nil?
|
599
|
+
if scale_factor.is_a?(Array)
|
600
|
+
scale_factors = scale_factor
|
601
|
+
else
|
602
|
+
scale_factors = [scale_factor] * dim
|
603
|
+
end
|
604
|
+
|
605
|
+
dim.times.map { |i| (input.size(i + 2) * scale_factors[i]).floor }
|
606
|
+
end
|
487
607
|
end
|
488
608
|
end
|
489
609
|
|