torch-rb 0.3.4 → 0.4.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +2 -1
- data/codegen/function.rb +134 -0
- data/codegen/generate_functions.rb +549 -0
- data/{lib/torch/native → codegen}/native_functions.yaml +0 -0
- data/ext/torch/ext.cpp +76 -87
- data/ext/torch/extconf.rb +5 -2
- data/ext/torch/nn_functions.h +6 -0
- data/ext/torch/ruby_arg_parser.cpp +593 -0
- data/ext/torch/ruby_arg_parser.h +373 -0
- data/ext/torch/{templates.hpp → templates.h} +87 -97
- data/ext/torch/tensor_functions.h +6 -0
- data/ext/torch/torch_functions.h +6 -0
- data/ext/torch/utils.h +42 -0
- data/ext/torch/{templates.cpp → wrap_outputs.h} +44 -7
- data/lib/torch.rb +51 -77
- data/lib/torch/nn/functional.rb +142 -18
- data/lib/torch/nn/init.rb +5 -19
- data/lib/torch/nn/leaky_relu.rb +3 -3
- data/lib/torch/nn/module.rb +9 -1
- data/lib/torch/nn/upsample.rb +31 -0
- data/lib/torch/optim/adadelta.rb +1 -1
- data/lib/torch/optim/adam.rb +2 -2
- data/lib/torch/optim/adamax.rb +1 -1
- data/lib/torch/optim/adamw.rb +1 -1
- data/lib/torch/optim/asgd.rb +1 -1
- data/lib/torch/optim/sgd.rb +3 -3
- data/lib/torch/tensor.rb +36 -115
- data/lib/torch/utils/data/data_loader.rb +2 -0
- data/lib/torch/utils/data/tensor_dataset.rb +2 -0
- data/lib/torch/version.rb +1 -1
- metadata +19 -14
- data/lib/torch/native/dispatcher.rb +0 -48
- data/lib/torch/native/function.rb +0 -115
- data/lib/torch/native/generator.rb +0 -163
- data/lib/torch/native/parser.rb +0 -140
data/ext/torch/utils.h
ADDED
@@ -0,0 +1,42 @@
|
|
1
|
+
#pragma once
|
2
|
+
|
3
|
+
#include <rice/Symbol.hpp>
|
4
|
+
|
5
|
+
// Keep the THP prefix (PyTorch's Python-binding naming convention) for now to make it easier to compare this code with PyTorch's
|
6
|
+
|
7
|
+
extern VALUE THPVariableClass;
|
8
|
+
|
9
|
+
// Interns `str` as a Ruby Symbol and returns the raw VALUE.
// The Rice Symbol wrapper converts implicitly to its underlying VALUE.
inline VALUE THPUtils_internSymbol(const std::string& str) {
  Symbol interned(str);
  return interned;
}
|
12
|
+
|
13
|
+
// Returns the name of a Ruby Symbol as a std::string (bytes copied verbatim).
// Raises TypeError (via Check_Type) if `obj` is not a Symbol.
inline std::string THPUtils_unpackSymbol(VALUE obj) {
  Check_Type(obj, T_SYMBOL);
  // rb_sym2str reads the symbol's name directly instead of dynamically
  // dispatching Symbol#to_s. A user-redefined #to_s therefore can't hand a
  // non-String to RSTRING_PTR/RSTRING_LEN, and no temporary String is allocated.
  VALUE name = rb_sym2str(obj);
  return std::string(RSTRING_PTR(name), RSTRING_LEN(name));
}
|
18
|
+
|
19
|
+
// Copies every byte of a Ruby String (embedded NULs included) into a std::string.
// Raises TypeError (via Check_Type) if `obj` is not a String.
inline std::string THPUtils_unpackString(VALUE obj) {
  Check_Type(obj, T_STRING);
  const char* first = RSTRING_PTR(obj);
  const char* last = first + RSTRING_LEN(obj);
  return std::string(first, last);
}
|
23
|
+
|
24
|
+
// True when `obj` is a Ruby Symbol.
inline bool THPUtils_checkSymbol(VALUE obj) {
  if (SYMBOL_P(obj)) {
    return true;
  }
  return false;
}
|
27
|
+
|
28
|
+
// True only for Integers in Fixnum range; Bignum values are not accepted as indices.
inline bool THPUtils_checkIndex(VALUE obj) {
  const bool is_fixnum = FIXNUM_P(obj);
  return is_fixnum;
}
|
31
|
+
|
32
|
+
// True for Ruby values accepted as numeric scalars: Fixnum-range Integers,
// Floats and Complex numbers.
// NOTE(review): Bignum Integers are rejected here, even those that fit in
// int64_t (e.g. 2**62 on 64-bit builds). PyTorch's THPUtils_checkScalar accepts
// any Python int. Confirm the scalar converter handles T_BIGNUM before widening.
inline bool THPUtils_checkScalar(VALUE obj) {
  if (FIXNUM_P(obj)) {
    return true;
  }
  if (RB_FLOAT_TYPE_P(obj)) {
    return true;
  }
  return RB_TYPE_P(obj, T_COMPLEX);
}
|
35
|
+
|
36
|
+
// True when `obj` is an instance of THPVariableClass or of a subclass of it.
// rb_obj_is_kind_of returns Qtrue/Qfalse; RTEST converts that to a C++ bool explicitly.
inline bool THPVariable_Check(VALUE obj) {
  return RTEST(rb_obj_is_kind_of(obj, THPVariableClass));
}
|
39
|
+
|
40
|
+
// True only when `obj`'s class is exactly THPVariableClass (subclasses excluded).
// rb_obj_is_instance_of returns Qtrue/Qfalse; RTEST converts that to a C++ bool explicitly.
inline bool THPVariable_CheckExact(VALUE obj) {
  return RTEST(rb_obj_is_instance_of(obj, THPVariableClass));
}
|
@@ -1,15 +1,44 @@
|
|
1
|
+
#pragma once
|
2
|
+
|
1
3
|
#include <torch/torch.h>
|
2
4
|
#include <rice/Object.hpp>
|
3
|
-
#include "templates.hpp"
|
4
5
|
|
5
|
-
Object wrap(
|
6
|
+
inline Object wrap(bool x) {
|
7
|
+
return to_ruby<bool>(x);
|
8
|
+
}
|
9
|
+
|
10
|
+
inline Object wrap(int64_t x) {
|
11
|
+
return to_ruby<int64_t>(x);
|
12
|
+
}
|
13
|
+
|
14
|
+
inline Object wrap(double x) {
|
15
|
+
return to_ruby<double>(x);
|
16
|
+
}
|
17
|
+
|
18
|
+
inline Object wrap(torch::Tensor x) {
|
19
|
+
return to_ruby<torch::Tensor>(x);
|
20
|
+
}
|
21
|
+
|
22
|
+
inline Object wrap(torch::Scalar x) {
|
23
|
+
return to_ruby<torch::Scalar>(x);
|
24
|
+
}
|
25
|
+
|
26
|
+
inline Object wrap(torch::ScalarType x) {
|
27
|
+
return to_ruby<torch::ScalarType>(x);
|
28
|
+
}
|
29
|
+
|
30
|
+
inline Object wrap(torch::QScheme x) {
|
31
|
+
return to_ruby<torch::QScheme>(x);
|
32
|
+
}
|
33
|
+
|
34
|
+
// Converts a pair of tensors into a two-element Ruby Array [first, second].
inline Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
  Array pair;
  pair.push(to_ruby<torch::Tensor>(std::get<0>(x)));
  pair.push(to_ruby<torch::Tensor>(std::get<1>(x)));
  return Object(pair);
}
|
11
40
|
|
12
|
-
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
41
|
+
inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
13
42
|
Array a;
|
14
43
|
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
15
44
|
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
@@ -17,7 +46,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
|
17
46
|
return Object(a);
|
18
47
|
}
|
19
48
|
|
20
|
-
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
49
|
+
inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
21
50
|
Array a;
|
22
51
|
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
23
52
|
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
@@ -26,7 +55,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tenso
|
|
26
55
|
return Object(a);
|
27
56
|
}
|
28
57
|
|
29
|
-
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
58
|
+
inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
30
59
|
Array a;
|
31
60
|
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
32
61
|
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
@@ -36,7 +65,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tenso
|
|
36
65
|
return Object(a);
|
37
66
|
}
|
38
67
|
|
39
|
-
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
|
68
|
+
inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
|
40
69
|
Array a;
|
41
70
|
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
42
71
|
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
@@ -45,7 +74,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x)
|
|
45
74
|
return Object(a);
|
46
75
|
}
|
47
76
|
|
48
|
-
Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
|
77
|
+
inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
|
49
78
|
Array a;
|
50
79
|
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
51
80
|
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
@@ -53,3 +82,11 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
|
|
53
82
|
a.push(to_ruby<int64_t>(std::get<3>(x)));
|
54
83
|
return Object(a);
|
55
84
|
}
|
85
|
+
|
86
|
+
// Converts a list of tensors into a Ruby Array of tensors, preserving order.
inline Object wrap(torch::TensorList x) {
  Array tensors;
  for (size_t i = 0; i < x.size(); ++i) {
    tensors.push(to_ruby<torch::Tensor>(x[i]));
  }
  return Object(tensors);
}
|
data/lib/torch.rb
CHANGED
@@ -7,11 +7,6 @@ require "net/http"
|
|
7
7
|
require "set"
|
8
8
|
require "tmpdir"
|
9
9
|
|
10
|
-
# native functions
|
11
|
-
require "torch/native/generator"
|
12
|
-
require "torch/native/parser"
|
13
|
-
require "torch/native/dispatcher"
|
14
|
-
|
15
10
|
# modules
|
16
11
|
require "torch/inspector"
|
17
12
|
require "torch/tensor"
|
@@ -174,6 +169,9 @@ require "torch/nn/smooth_l1_loss"
|
|
174
169
|
require "torch/nn/soft_margin_loss"
|
175
170
|
require "torch/nn/triplet_margin_loss"
|
176
171
|
|
172
|
+
# nn vision
|
173
|
+
require "torch/nn/upsample"
|
174
|
+
|
177
175
|
# nn other
|
178
176
|
require "torch/nn/functional"
|
179
177
|
require "torch/nn/init"
|
@@ -196,6 +194,32 @@ module Torch
|
|
196
194
|
end
|
197
195
|
end
|
198
196
|
|
197
|
+
# Legacy API, kept because it may make it easier to port tutorials that still
# wrap tensors in Variable.
module Autograd
  class Variable
    # Returns the given tensor itself after printing a deprecation warning.
    # Raises ArgumentError when x is not a Tensor.
    def self.new(x)
      unless x.is_a?(Tensor)
        raise ArgumentError, "Variable data has to be a tensor, but got #{x.class.name}"
      end

      warn "[torch] The Variable API is deprecated. Use tensors with requires_grad: true instead."
      x
    end
  end
end
|
208
|
+
|
209
|
+
# TODO move to C++
# Minimal holder for a raw byte string; the uint8 tensor class's constructor
# reads #bytes from it.
class ByteStorage
  # Builds a storage wrapping the given byte string (no copy is made).
  def self.from_buffer(bytes)
    new(bytes)
  end

  def initialize(bytes)
    @bytes = bytes
  end

  # intended as private: only read when building a tensor from this storage
  attr_reader :bytes
end
|
222
|
+
|
199
223
|
# keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
|
200
224
|
# values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
|
201
225
|
DTYPE_TO_ENUM = {
|
@@ -224,40 +248,43 @@ module Torch
|
|
224
248
|
}
|
225
249
|
ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
|
226
250
|
|
251
|
+
TENSOR_TYPE_CLASSES = []
|
252
|
+
|
227
253
|
# Builds a legacy tensor class (e.g. FloatTensor) for dtype whose .new creates
# tensors of that dtype on "cuda" (when cuda is true) or "cpu". .new accepts:
# - a Tensor: converted to dtype and moved to the device
# - a ByteStorage (uint8 classes only): tensor built from its raw bytes
# - an Array: passed to Torch.tensor
# - anything else: treated as a size for Torch.empty
# Each class is recorded in TENSOR_TYPE_CLASSES.
def self._make_tensor_class(dtype, cuda = false)
  cls = Class.new
  device = cuda ? "cuda" : "cpu"
  cls.define_singleton_method("new") do |*args|
    if args.size == 1 && args.first.is_a?(Tensor)
      args.first.send(dtype).to(device)
    elsif args.size == 1 && args.first.is_a?(ByteStorage) && dtype == :uint8
      bytes = args.first.bytes
      # Move to the class's device like every other branch; without this,
      # CUDA::ByteTensor.new(storage) returned a CPU tensor.
      # NOTE(review): _from_blob is native code not visible here. If it does
      # not copy, the CPU result may alias the Ruby string's buffer. Confirm.
      Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype])).to(device)
    elsif args.size == 1 && args.first.is_a?(Array)
      Torch.tensor(args.first, dtype: dtype, device: device)
    else
      Torch.empty(*args, dtype: dtype, device: device)
    end
  end
  TENSOR_TYPE_CLASSES << cls
  cls
end
|
241
271
|
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
CUDA::IntTensor = _make_tensor_class(:int32, true)
|
259
|
-
CUDA::LongTensor = _make_tensor_class(:int64, true)
|
260
|
-
CUDA::BoolTensor = _make_tensor_class(:bool, true)
|
272
|
+
# Legacy tensor class name for each dtype. For every entry, a CPU class is
# defined as a constant on the enclosing module and a CUDA class on CUDA.
DTYPE_TO_CLASS = {
  float32: "FloatTensor",
  float64: "DoubleTensor",
  float16: "HalfTensor",
  uint8: "ByteTensor",
  int8: "CharTensor",
  int16: "ShortTensor",
  int32: "IntTensor",
  int64: "LongTensor",
  bool: "BoolTensor"
}

DTYPE_TO_CLASS.each_pair do |dtype, name|
  cpu_class = _make_tensor_class(dtype)
  const_set(name, cpu_class)
  cuda_class = _make_tensor_class(dtype, true)
  CUDA.const_set(name, cuda_class)
end
|
261
288
|
|
262
289
|
class << self
|
263
290
|
# Torch.float, Torch.long, etc
|
@@ -342,59 +369,6 @@ module Torch
|
|
342
369
|
|
343
370
|
# --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
|
344
371
|
|
345
|
-
def arange(start, finish = nil, step = 1, **options)
|
346
|
-
# ruby doesn't support start = 0, finish, step = 1, ...
|
347
|
-
if finish.nil?
|
348
|
-
finish = start
|
349
|
-
start = 0
|
350
|
-
end
|
351
|
-
_arange(start, finish, step, tensor_options(**options))
|
352
|
-
end
|
353
|
-
|
354
|
-
def empty(*size, **options)
|
355
|
-
_empty(tensor_size(size), tensor_options(**options))
|
356
|
-
end
|
357
|
-
|
358
|
-
def eye(n, m = nil, **options)
|
359
|
-
_eye(n, m || n, tensor_options(**options))
|
360
|
-
end
|
361
|
-
|
362
|
-
def full(size, fill_value, **options)
|
363
|
-
_full(size, fill_value, tensor_options(**options))
|
364
|
-
end
|
365
|
-
|
366
|
-
def linspace(start, finish, steps = 100, **options)
|
367
|
-
_linspace(start, finish, steps, tensor_options(**options))
|
368
|
-
end
|
369
|
-
|
370
|
-
def logspace(start, finish, steps = 100, base = 10.0, **options)
|
371
|
-
_logspace(start, finish, steps, base, tensor_options(**options))
|
372
|
-
end
|
373
|
-
|
374
|
-
def ones(*size, **options)
|
375
|
-
_ones(tensor_size(size), tensor_options(**options))
|
376
|
-
end
|
377
|
-
|
378
|
-
def rand(*size, **options)
|
379
|
-
_rand(tensor_size(size), tensor_options(**options))
|
380
|
-
end
|
381
|
-
|
382
|
-
def randint(low = 0, high, size, **options)
|
383
|
-
_randint(low, high, size, tensor_options(**options))
|
384
|
-
end
|
385
|
-
|
386
|
-
def randn(*size, **options)
|
387
|
-
_randn(tensor_size(size), tensor_options(**options))
|
388
|
-
end
|
389
|
-
|
390
|
-
def randperm(n, **options)
|
391
|
-
_randperm(n, tensor_options(**options))
|
392
|
-
end
|
393
|
-
|
394
|
-
def zeros(*size, **options)
|
395
|
-
_zeros(tensor_size(size), tensor_options(**options))
|
396
|
-
end
|
397
|
-
|
398
372
|
def tensor(data, **options)
|
399
373
|
if options[:dtype].nil? && defined?(Numo::NArray) && data.is_a?(Numo::NArray)
|
400
374
|
numo_to_dtype = _dtype_to_numo.map(&:reverse).to_h
|
data/lib/torch/nn/functional.rb
CHANGED
@@ -178,8 +178,12 @@ module Torch
|
|
178
178
|
Torch.hardshrink(input, lambd)
|
179
179
|
end
|
180
180
|
|
181
|
-
def leaky_relu(input, negative_slope = 0.01)
|
182
|
-
|
181
|
+
# LeakyReLU with the given negative_slope. With inplace: true, the in-place
# native variant (NN.leaky_relu!) is called instead of NN.leaky_relu.
def leaky_relu(input, negative_slope = 0.01, inplace: false)
  return NN.leaky_relu!(input, negative_slope) if inplace

  NN.leaky_relu(input, negative_slope)
end
|
184
188
|
|
185
189
|
def log_sigmoid(input)
|
@@ -390,15 +394,15 @@ module Torch
|
|
390
394
|
# loss functions
|
391
395
|
|
392
396
|
def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
|
393
|
-
NN.binary_cross_entropy(input, target, weight, reduction)
|
397
|
+
NN.binary_cross_entropy(input, target, weight, to_reduction(reduction))
|
394
398
|
end
|
395
399
|
|
396
400
|
def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
|
397
|
-
Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
|
401
|
+
Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, to_reduction(reduction))
|
398
402
|
end
|
399
403
|
|
400
404
|
def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
|
401
|
-
Torch.cosine_embedding_loss(input1, input2, target, margin, reduction)
|
405
|
+
Torch.cosine_embedding_loss(input1, input2, target, margin, to_reduction(reduction))
|
402
406
|
end
|
403
407
|
|
404
408
|
def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
|
@@ -407,34 +411,34 @@ module Torch
|
|
407
411
|
|
408
412
|
def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
|
409
413
|
# call to_a on input_lengths and target_lengths for C++
|
410
|
-
Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
|
414
|
+
Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, to_reduction(reduction), zero_infinity)
|
411
415
|
end
|
412
416
|
|
413
417
|
def hinge_embedding_loss(input, target, margin: 1.0, reduction: "mean")
|
414
|
-
Torch.hinge_embedding_loss(input, target, margin, reduction)
|
418
|
+
Torch.hinge_embedding_loss(input, target, margin, to_reduction(reduction))
|
415
419
|
end
|
416
420
|
|
417
421
|
def kl_div(input, target, reduction: "mean")
|
418
|
-
Torch.kl_div(input, target, reduction)
|
422
|
+
Torch.kl_div(input, target, to_reduction(reduction))
|
419
423
|
end
|
420
424
|
|
421
425
|
def l1_loss(input, target, reduction: "mean")
|
422
|
-
NN.l1_loss(input, target, reduction)
|
426
|
+
NN.l1_loss(input, target, to_reduction(reduction))
|
423
427
|
end
|
424
428
|
|
425
429
|
def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
|
426
|
-
Torch.margin_ranking_loss(input1, input2, target, margin, reduction)
|
430
|
+
Torch.margin_ranking_loss(input1, input2, target, margin, to_reduction(reduction))
|
427
431
|
end
|
428
432
|
|
429
433
|
def mse_loss(input, target, reduction: "mean")
|
430
434
|
if target.size != input.size
|
431
435
|
warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
|
432
436
|
end
|
433
|
-
NN.mse_loss(input, target, reduction)
|
437
|
+
NN.mse_loss(input, target, to_reduction(reduction))
|
434
438
|
end
|
435
439
|
|
436
440
|
def multilabel_margin_loss(input, target, reduction: "mean")
|
437
|
-
NN.multilabel_margin_loss(input, target, reduction)
|
441
|
+
NN.multilabel_margin_loss(input, target, to_reduction(reduction))
|
438
442
|
end
|
439
443
|
|
440
444
|
def multilabel_soft_margin_loss(input, target, weight: nil)
|
@@ -442,31 +446,116 @@ module Torch
|
|
442
446
|
end
|
443
447
|
|
444
448
|
def multi_margin_loss(input, target, p: 1, margin: 1.0, weight: nil, reduction: "mean")
|
445
|
-
NN.multi_margin_loss(input, target, p, margin, weight, reduction)
|
449
|
+
NN.multi_margin_loss(input, target, p, margin, weight, to_reduction(reduction))
|
446
450
|
end
|
447
451
|
|
448
452
|
def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
|
449
|
-
NN.nll_loss(input, target, weight, reduction, ignore_index)
|
453
|
+
NN.nll_loss(input, target, weight, to_reduction(reduction), ignore_index)
|
450
454
|
end
|
451
455
|
|
452
456
|
def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
|
453
|
-
Torch.poisson_nll_loss(input, target, log_input, full, eps, reduction)
|
457
|
+
Torch.poisson_nll_loss(input, target, log_input, full, eps, to_reduction(reduction))
|
454
458
|
end
|
455
459
|
|
456
460
|
def soft_margin_loss(input, target, reduction: "mean")
|
457
|
-
NN.soft_margin_loss(input, target, reduction)
|
461
|
+
NN.soft_margin_loss(input, target, to_reduction(reduction))
|
458
462
|
end
|
459
463
|
|
460
464
|
def smooth_l1_loss(input, target, reduction: "mean")
|
461
|
-
NN.smooth_l1_loss(input, target, reduction)
|
465
|
+
NN.smooth_l1_loss(input, target, to_reduction(reduction))
|
462
466
|
end
|
463
467
|
|
464
468
|
def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean")
|
465
|
-
Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
|
469
|
+
Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, to_reduction(reduction))
|
470
|
+
end
|
471
|
+
|
472
|
+
# vision
|
473
|
+
|
474
|
+
# Resizes input to `size` or by `scale_factor` (exactly one must be given) using
# `mode`: "nearest", "area", "linear", "bilinear", "trilinear" or "bicubic".
# align_corners may only be set for the interpolating modes and defaults to
# false for them. Raises ArgumentError for invalid option/dimension combinations.
def interpolate(input, size: nil, scale_factor: nil, mode: "nearest", align_corners: nil, recompute_scale_factor: nil)
  if ["nearest", "area"].include?(mode)
    unless align_corners.nil?
      raise ArgumentError, "align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear"
    end
  elsif align_corners.nil?
    align_corners = false
  end

  ndim = input.dim
  spatial_dims = ndim - 2
  # Per-dimension scale factors forwarded to the native kernels. They stay nil
  # unless scale_factor is given and recompute_scale_factor is false or nil
  # (nil behaves like the false default).
  scales = [nil] * spatial_dims
  if !scale_factor.nil? && (recompute_scale_factor == false || recompute_scale_factor.nil?)
    scales = scale_factor.is_a?(Array) ? scale_factor : [scale_factor] * spatial_dims
  end

  output_size = _interp_output_size([input, size, scale_factor, recompute_scale_factor])

  # Every (dimensionality, mode) pair is distinct, so branch order is irrelevant.
  case [ndim, mode]
  when [3, "nearest"]
    NN.upsample_nearest1d(input, output_size, scales[0])
  when [4, "nearest"]
    NN.upsample_nearest2d(input, output_size, scales[0], scales[1])
  when [5, "nearest"]
    NN.upsample_nearest3d(input, output_size, scales[0], scales[1], scales[2])
  when [3, "area"]
    adaptive_avg_pool1d(input, output_size)
  when [4, "area"]
    adaptive_avg_pool2d(input, output_size)
  when [5, "area"]
    adaptive_avg_pool3d(input, output_size)
  when [3, "linear"]
    NN.upsample_linear1d(input, output_size, align_corners, scales[0])
  when [4, "bilinear"]
    NN.upsample_bilinear2d(input, output_size, align_corners, scales[0], scales[1])
  when [5, "trilinear"]
    NN.upsample_trilinear3d(input, output_size, align_corners, scales[0], scales[1], scales[2])
  when [4, "bicubic"]
    NN.upsample_bicubic2d(input, output_size, align_corners, scales[0], scales[1])
  when [3, "bilinear"]
    raise ArgumentError, "Got 3D input, but bilinear mode needs 4D input"
  when [3, "trilinear"]
    raise ArgumentError, "Got 3D input, but trilinear mode needs 5D input"
  when [4, "linear"]
    raise ArgumentError, "Got 4D input, but linear mode needs 3D input"
  when [4, "trilinear"]
    raise ArgumentError, "Got 4D input, but trilinear mode needs 5D input"
  when [5, "linear"]
    raise ArgumentError, "Got 5D input, but linear mode needs 3D input"
  when [5, "bilinear"]
    raise ArgumentError, "Got 5D input, but bilinear mode needs 4D input"
  else
    raise ArgumentError, "Input Error: Only 3D, 4D and 5D input Tensors supported (got #{ndim}D) for the modes: nearest | linear | bilinear | bicubic | trilinear (got #{mode})"
  end
end
|
467
542
|
|
468
543
|
private
|
469
544
|
|
545
|
+
# see _reduction.py
# Maps a reduction name ("none", "mean" or "sum", as a String or Symbol) to the
# integer code passed to the native loss functions: none => 0, mean => 1, sum => 2.
# Raises ArgumentError for any other value.
def to_reduction(v)
  { "none" => 0, "mean" => 1, "sum" => 2 }.fetch(v.to_s) do
    raise ArgumentError, "#{v} is not a valid value for reduction"
  end
end
|
558
|
+
|
470
559
|
# Implicit softmax dimension for an input with ndim dimensions:
# 0 for 0-, 1- and 3-dimensional inputs, 1 otherwise.
def softmax_dim(ndim)
  [0, 1, 3].include?(ndim) ? 0 : 1
end
|
@@ -480,6 +569,41 @@ module Torch
|
|
480
569
|
out_size.zip(defaults.last(out_size.length)).map { |v, d| v || d }
|
481
570
|
end
|
482
571
|
end
|
572
|
+
|
573
|
+
# Computes the spatial output size used by interpolate.
# closed_over_args is [input, size, scale_factor, recompute_scale_factor];
# recompute_scale_factor is unpacked but not used here.
# Returns `size` (an Integer is repeated per spatial dim) when given; otherwise
# floor(input.size(i + 2) * factor_i) for each spatial dim i.
# Raises ArgumentError unless exactly one of size/scale_factor is given, or
# when an Array scale_factor's length does not match the spatial dims.
def _interp_output_size(closed_over_args)
  input, size, scale_factor, _recompute_scale_factor = closed_over_args
  spatial_dims = input.dim - 2

  raise ArgumentError, "either size or scale_factor should be defined" if size.nil? && scale_factor.nil?
  raise ArgumentError, "only one of size or scale_factor should be defined" unless size.nil? || scale_factor.nil?

  if scale_factor.is_a?(Array) && scale_factor.length != spatial_dims
    raise ArgumentError, "scale_factor shape must match input shape. Input is #{spatial_dims}D, scale_factor size is #{scale_factor.length}"
  end

  unless size.nil?
    return size.is_a?(Array) ? size : [size] * spatial_dims
  end

  raise "Failed assertion" if scale_factor.nil?
  factors = scale_factor.is_a?(Array) ? scale_factor : [scale_factor] * spatial_dims
  (0...spatial_dims).map { |i| (input.size(i + 2) * factors[i]).floor }
end
|
483
607
|
end
|
484
608
|
end
|
485
609
|
|