torch-rb 0.3.4 → 0.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +2 -1
- data/codegen/function.rb +134 -0
- data/codegen/generate_functions.rb +549 -0
- data/{lib/torch/native → codegen}/native_functions.yaml +0 -0
- data/ext/torch/ext.cpp +76 -87
- data/ext/torch/extconf.rb +5 -2
- data/ext/torch/nn_functions.h +6 -0
- data/ext/torch/ruby_arg_parser.cpp +593 -0
- data/ext/torch/ruby_arg_parser.h +373 -0
- data/ext/torch/{templates.hpp → templates.h} +87 -97
- data/ext/torch/tensor_functions.h +6 -0
- data/ext/torch/torch_functions.h +6 -0
- data/ext/torch/utils.h +42 -0
- data/ext/torch/{templates.cpp → wrap_outputs.h} +44 -7
- data/lib/torch.rb +51 -77
- data/lib/torch/nn/functional.rb +142 -18
- data/lib/torch/nn/init.rb +5 -19
- data/lib/torch/nn/leaky_relu.rb +3 -3
- data/lib/torch/nn/module.rb +9 -1
- data/lib/torch/nn/upsample.rb +31 -0
- data/lib/torch/optim/adadelta.rb +1 -1
- data/lib/torch/optim/adam.rb +2 -2
- data/lib/torch/optim/adamax.rb +1 -1
- data/lib/torch/optim/adamw.rb +1 -1
- data/lib/torch/optim/asgd.rb +1 -1
- data/lib/torch/optim/sgd.rb +3 -3
- data/lib/torch/tensor.rb +36 -115
- data/lib/torch/utils/data/data_loader.rb +2 -0
- data/lib/torch/utils/data/tensor_dataset.rb +2 -0
- data/lib/torch/version.rb +1 -1
- metadata +19 -14
- data/lib/torch/native/dispatcher.rb +0 -48
- data/lib/torch/native/function.rb +0 -115
- data/lib/torch/native/generator.rb +0 -163
- data/lib/torch/native/parser.rb +0 -140
data/ext/torch/utils.h
ADDED
@@ -0,0 +1,42 @@
|
|
1
|
+
#pragma once
|
2
|
+
|
3
|
+
#include <rice/Symbol.hpp>
|
4
|
+
|
5
|
+
// keep THP prefix for now to make it easier to compare code
|
6
|
+
|
7
|
+
extern VALUE THPVariableClass;
|
8
|
+
|
9
|
+
inline VALUE THPUtils_internSymbol(const std::string& str) {
|
10
|
+
return Symbol(str);
|
11
|
+
}
|
12
|
+
|
13
|
+
inline std::string THPUtils_unpackSymbol(VALUE obj) {
|
14
|
+
Check_Type(obj, T_SYMBOL);
|
15
|
+
obj = rb_funcall(obj, rb_intern("to_s"), 0);
|
16
|
+
return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
|
17
|
+
}
|
18
|
+
|
19
|
+
inline std::string THPUtils_unpackString(VALUE obj) {
|
20
|
+
Check_Type(obj, T_STRING);
|
21
|
+
return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
|
22
|
+
}
|
23
|
+
|
24
|
+
inline bool THPUtils_checkSymbol(VALUE obj) {
|
25
|
+
return SYMBOL_P(obj);
|
26
|
+
}
|
27
|
+
|
28
|
+
inline bool THPUtils_checkIndex(VALUE obj) {
|
29
|
+
return FIXNUM_P(obj);
|
30
|
+
}
|
31
|
+
|
32
|
+
inline bool THPUtils_checkScalar(VALUE obj) {
|
33
|
+
return FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj) || RB_TYPE_P(obj, T_COMPLEX);
|
34
|
+
}
|
35
|
+
|
36
|
+
inline bool THPVariable_Check(VALUE obj) {
|
37
|
+
return rb_obj_is_kind_of(obj, THPVariableClass);
|
38
|
+
}
|
39
|
+
|
40
|
+
inline bool THPVariable_CheckExact(VALUE obj) {
|
41
|
+
return rb_obj_is_instance_of(obj, THPVariableClass);
|
42
|
+
}
|
@@ -1,15 +1,44 @@
|
|
1
|
+
#pragma once
|
2
|
+
|
1
3
|
#include <torch/torch.h>
|
2
4
|
#include <rice/Object.hpp>
|
3
|
-
#include "templates.hpp"
|
4
5
|
|
5
|
-
Object wrap(
|
6
|
+
inline Object wrap(bool x) {
|
7
|
+
return to_ruby<bool>(x);
|
8
|
+
}
|
9
|
+
|
10
|
+
inline Object wrap(int64_t x) {
|
11
|
+
return to_ruby<int64_t>(x);
|
12
|
+
}
|
13
|
+
|
14
|
+
inline Object wrap(double x) {
|
15
|
+
return to_ruby<double>(x);
|
16
|
+
}
|
17
|
+
|
18
|
+
inline Object wrap(torch::Tensor x) {
|
19
|
+
return to_ruby<torch::Tensor>(x);
|
20
|
+
}
|
21
|
+
|
22
|
+
inline Object wrap(torch::Scalar x) {
|
23
|
+
return to_ruby<torch::Scalar>(x);
|
24
|
+
}
|
25
|
+
|
26
|
+
inline Object wrap(torch::ScalarType x) {
|
27
|
+
return to_ruby<torch::ScalarType>(x);
|
28
|
+
}
|
29
|
+
|
30
|
+
inline Object wrap(torch::QScheme x) {
|
31
|
+
return to_ruby<torch::QScheme>(x);
|
32
|
+
}
|
33
|
+
|
34
|
+
inline Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
|
6
35
|
Array a;
|
7
36
|
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
8
37
|
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
9
38
|
return Object(a);
|
10
39
|
}
|
11
40
|
|
12
|
-
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
41
|
+
inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
13
42
|
Array a;
|
14
43
|
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
15
44
|
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
@@ -17,7 +46,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
|
17
46
|
return Object(a);
|
18
47
|
}
|
19
48
|
|
20
|
-
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
49
|
+
inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
21
50
|
Array a;
|
22
51
|
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
23
52
|
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
@@ -26,7 +55,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tenso
|
|
26
55
|
return Object(a);
|
27
56
|
}
|
28
57
|
|
29
|
-
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
58
|
+
inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
|
30
59
|
Array a;
|
31
60
|
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
32
61
|
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
@@ -36,7 +65,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tenso
|
|
36
65
|
return Object(a);
|
37
66
|
}
|
38
67
|
|
39
|
-
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
|
68
|
+
inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
|
40
69
|
Array a;
|
41
70
|
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
42
71
|
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
@@ -45,7 +74,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x)
|
|
45
74
|
return Object(a);
|
46
75
|
}
|
47
76
|
|
48
|
-
Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
|
77
|
+
inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
|
49
78
|
Array a;
|
50
79
|
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
51
80
|
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
@@ -53,3 +82,11 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
|
|
53
82
|
a.push(to_ruby<int64_t>(std::get<3>(x)));
|
54
83
|
return Object(a);
|
55
84
|
}
|
85
|
+
|
86
|
+
inline Object wrap(torch::TensorList x) {
|
87
|
+
Array a;
|
88
|
+
for (auto& t : x) {
|
89
|
+
a.push(to_ruby<torch::Tensor>(t));
|
90
|
+
}
|
91
|
+
return Object(a);
|
92
|
+
}
|
data/lib/torch.rb
CHANGED
@@ -7,11 +7,6 @@ require "net/http"
|
|
7
7
|
require "set"
|
8
8
|
require "tmpdir"
|
9
9
|
|
10
|
-
# native functions
|
11
|
-
require "torch/native/generator"
|
12
|
-
require "torch/native/parser"
|
13
|
-
require "torch/native/dispatcher"
|
14
|
-
|
15
10
|
# modules
|
16
11
|
require "torch/inspector"
|
17
12
|
require "torch/tensor"
|
@@ -174,6 +169,9 @@ require "torch/nn/smooth_l1_loss"
|
|
174
169
|
require "torch/nn/soft_margin_loss"
|
175
170
|
require "torch/nn/triplet_margin_loss"
|
176
171
|
|
172
|
+
# nn vision
|
173
|
+
require "torch/nn/upsample"
|
174
|
+
|
177
175
|
# nn other
|
178
176
|
require "torch/nn/functional"
|
179
177
|
require "torch/nn/init"
|
@@ -196,6 +194,32 @@ module Torch
|
|
196
194
|
end
|
197
195
|
end
|
198
196
|
|
197
|
+
# legacy
|
198
|
+
# but may make it easier to port tutorials
|
199
|
+
module Autograd
|
200
|
+
class Variable
|
201
|
+
def self.new(x)
|
202
|
+
raise ArgumentError, "Variable data has to be a tensor, but got #{x.class.name}" unless x.is_a?(Tensor)
|
203
|
+
warn "[torch] The Variable API is deprecated. Use tensors with requires_grad: true instead."
|
204
|
+
x
|
205
|
+
end
|
206
|
+
end
|
207
|
+
end
|
208
|
+
|
209
|
+
# TODO move to C++
|
210
|
+
class ByteStorage
|
211
|
+
# private
|
212
|
+
attr_reader :bytes
|
213
|
+
|
214
|
+
def initialize(bytes)
|
215
|
+
@bytes = bytes
|
216
|
+
end
|
217
|
+
|
218
|
+
def self.from_buffer(bytes)
|
219
|
+
new(bytes)
|
220
|
+
end
|
221
|
+
end
|
222
|
+
|
199
223
|
# keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
|
200
224
|
# values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
|
201
225
|
DTYPE_TO_ENUM = {
|
@@ -224,40 +248,43 @@ module Torch
|
|
224
248
|
}
|
225
249
|
ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
|
226
250
|
|
251
|
+
TENSOR_TYPE_CLASSES = []
|
252
|
+
|
227
253
|
def self._make_tensor_class(dtype, cuda = false)
|
228
254
|
cls = Class.new
|
229
255
|
device = cuda ? "cuda" : "cpu"
|
230
256
|
cls.define_singleton_method("new") do |*args|
|
231
257
|
if args.size == 1 && args.first.is_a?(Tensor)
|
232
258
|
args.first.send(dtype).to(device)
|
259
|
+
elsif args.size == 1 && args.first.is_a?(ByteStorage) && dtype == :uint8
|
260
|
+
bytes = args.first.bytes
|
261
|
+
Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
|
233
262
|
elsif args.size == 1 && args.first.is_a?(Array)
|
234
263
|
Torch.tensor(args.first, dtype: dtype, device: device)
|
235
264
|
else
|
236
265
|
Torch.empty(*args, dtype: dtype, device: device)
|
237
266
|
end
|
238
267
|
end
|
268
|
+
TENSOR_TYPE_CLASSES << cls
|
239
269
|
cls
|
240
270
|
end
|
241
271
|
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
CUDA::IntTensor = _make_tensor_class(:int32, true)
|
259
|
-
CUDA::LongTensor = _make_tensor_class(:int64, true)
|
260
|
-
CUDA::BoolTensor = _make_tensor_class(:bool, true)
|
272
|
+
DTYPE_TO_CLASS = {
|
273
|
+
float32: "FloatTensor",
|
274
|
+
float64: "DoubleTensor",
|
275
|
+
float16: "HalfTensor",
|
276
|
+
uint8: "ByteTensor",
|
277
|
+
int8: "CharTensor",
|
278
|
+
int16: "ShortTensor",
|
279
|
+
int32: "IntTensor",
|
280
|
+
int64: "LongTensor",
|
281
|
+
bool: "BoolTensor"
|
282
|
+
}
|
283
|
+
|
284
|
+
DTYPE_TO_CLASS.each do |dtype, class_name|
|
285
|
+
const_set(class_name, _make_tensor_class(dtype))
|
286
|
+
CUDA.const_set(class_name, _make_tensor_class(dtype, true))
|
287
|
+
end
|
261
288
|
|
262
289
|
class << self
|
263
290
|
# Torch.float, Torch.long, etc
|
@@ -342,59 +369,6 @@ module Torch
|
|
342
369
|
|
343
370
|
# --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
|
344
371
|
|
345
|
-
def arange(start, finish = nil, step = 1, **options)
|
346
|
-
# ruby doesn't support start = 0, finish, step = 1, ...
|
347
|
-
if finish.nil?
|
348
|
-
finish = start
|
349
|
-
start = 0
|
350
|
-
end
|
351
|
-
_arange(start, finish, step, tensor_options(**options))
|
352
|
-
end
|
353
|
-
|
354
|
-
def empty(*size, **options)
|
355
|
-
_empty(tensor_size(size), tensor_options(**options))
|
356
|
-
end
|
357
|
-
|
358
|
-
def eye(n, m = nil, **options)
|
359
|
-
_eye(n, m || n, tensor_options(**options))
|
360
|
-
end
|
361
|
-
|
362
|
-
def full(size, fill_value, **options)
|
363
|
-
_full(size, fill_value, tensor_options(**options))
|
364
|
-
end
|
365
|
-
|
366
|
-
def linspace(start, finish, steps = 100, **options)
|
367
|
-
_linspace(start, finish, steps, tensor_options(**options))
|
368
|
-
end
|
369
|
-
|
370
|
-
def logspace(start, finish, steps = 100, base = 10.0, **options)
|
371
|
-
_logspace(start, finish, steps, base, tensor_options(**options))
|
372
|
-
end
|
373
|
-
|
374
|
-
def ones(*size, **options)
|
375
|
-
_ones(tensor_size(size), tensor_options(**options))
|
376
|
-
end
|
377
|
-
|
378
|
-
def rand(*size, **options)
|
379
|
-
_rand(tensor_size(size), tensor_options(**options))
|
380
|
-
end
|
381
|
-
|
382
|
-
def randint(low = 0, high, size, **options)
|
383
|
-
_randint(low, high, size, tensor_options(**options))
|
384
|
-
end
|
385
|
-
|
386
|
-
def randn(*size, **options)
|
387
|
-
_randn(tensor_size(size), tensor_options(**options))
|
388
|
-
end
|
389
|
-
|
390
|
-
def randperm(n, **options)
|
391
|
-
_randperm(n, tensor_options(**options))
|
392
|
-
end
|
393
|
-
|
394
|
-
def zeros(*size, **options)
|
395
|
-
_zeros(tensor_size(size), tensor_options(**options))
|
396
|
-
end
|
397
|
-
|
398
372
|
def tensor(data, **options)
|
399
373
|
if options[:dtype].nil? && defined?(Numo::NArray) && data.is_a?(Numo::NArray)
|
400
374
|
numo_to_dtype = _dtype_to_numo.map(&:reverse).to_h
|
data/lib/torch/nn/functional.rb
CHANGED
@@ -178,8 +178,12 @@ module Torch
|
|
178
178
|
Torch.hardshrink(input, lambd)
|
179
179
|
end
|
180
180
|
|
181
|
-
def leaky_relu(input, negative_slope = 0.01)
|
182
|
-
|
181
|
+
def leaky_relu(input, negative_slope = 0.01, inplace: false)
|
182
|
+
if inplace
|
183
|
+
NN.leaky_relu!(input, negative_slope)
|
184
|
+
else
|
185
|
+
NN.leaky_relu(input, negative_slope)
|
186
|
+
end
|
183
187
|
end
|
184
188
|
|
185
189
|
def log_sigmoid(input)
|
@@ -390,15 +394,15 @@ module Torch
|
|
390
394
|
# loss functions
|
391
395
|
|
392
396
|
def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
|
393
|
-
NN.binary_cross_entropy(input, target, weight, reduction)
|
397
|
+
NN.binary_cross_entropy(input, target, weight, to_reduction(reduction))
|
394
398
|
end
|
395
399
|
|
396
400
|
def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
|
397
|
-
Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
|
401
|
+
Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, to_reduction(reduction))
|
398
402
|
end
|
399
403
|
|
400
404
|
def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
|
401
|
-
Torch.cosine_embedding_loss(input1, input2, target, margin, reduction)
|
405
|
+
Torch.cosine_embedding_loss(input1, input2, target, margin, to_reduction(reduction))
|
402
406
|
end
|
403
407
|
|
404
408
|
def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
|
@@ -407,34 +411,34 @@ module Torch
|
|
407
411
|
|
408
412
|
def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
|
409
413
|
# call to_a on input_lengths and target_lengths for C++
|
410
|
-
Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
|
414
|
+
Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, to_reduction(reduction), zero_infinity)
|
411
415
|
end
|
412
416
|
|
413
417
|
def hinge_embedding_loss(input, target, margin: 1.0, reduction: "mean")
|
414
|
-
Torch.hinge_embedding_loss(input, target, margin, reduction)
|
418
|
+
Torch.hinge_embedding_loss(input, target, margin, to_reduction(reduction))
|
415
419
|
end
|
416
420
|
|
417
421
|
def kl_div(input, target, reduction: "mean")
|
418
|
-
Torch.kl_div(input, target, reduction)
|
422
|
+
Torch.kl_div(input, target, to_reduction(reduction))
|
419
423
|
end
|
420
424
|
|
421
425
|
def l1_loss(input, target, reduction: "mean")
|
422
|
-
NN.l1_loss(input, target, reduction)
|
426
|
+
NN.l1_loss(input, target, to_reduction(reduction))
|
423
427
|
end
|
424
428
|
|
425
429
|
def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
|
426
|
-
Torch.margin_ranking_loss(input1, input2, target, margin, reduction)
|
430
|
+
Torch.margin_ranking_loss(input1, input2, target, margin, to_reduction(reduction))
|
427
431
|
end
|
428
432
|
|
429
433
|
def mse_loss(input, target, reduction: "mean")
|
430
434
|
if target.size != input.size
|
431
435
|
warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
|
432
436
|
end
|
433
|
-
NN.mse_loss(input, target, reduction)
|
437
|
+
NN.mse_loss(input, target, to_reduction(reduction))
|
434
438
|
end
|
435
439
|
|
436
440
|
def multilabel_margin_loss(input, target, reduction: "mean")
|
437
|
-
NN.multilabel_margin_loss(input, target, reduction)
|
441
|
+
NN.multilabel_margin_loss(input, target, to_reduction(reduction))
|
438
442
|
end
|
439
443
|
|
440
444
|
def multilabel_soft_margin_loss(input, target, weight: nil)
|
@@ -442,31 +446,116 @@ module Torch
|
|
442
446
|
end
|
443
447
|
|
444
448
|
def multi_margin_loss(input, target, p: 1, margin: 1.0, weight: nil, reduction: "mean")
|
445
|
-
NN.multi_margin_loss(input, target, p, margin, weight, reduction)
|
449
|
+
NN.multi_margin_loss(input, target, p, margin, weight, to_reduction(reduction))
|
446
450
|
end
|
447
451
|
|
448
452
|
def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
|
449
|
-
NN.nll_loss(input, target, weight, reduction, ignore_index)
|
453
|
+
NN.nll_loss(input, target, weight, to_reduction(reduction), ignore_index)
|
450
454
|
end
|
451
455
|
|
452
456
|
def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
|
453
|
-
Torch.poisson_nll_loss(input, target, log_input, full, eps, reduction)
|
457
|
+
Torch.poisson_nll_loss(input, target, log_input, full, eps, to_reduction(reduction))
|
454
458
|
end
|
455
459
|
|
456
460
|
def soft_margin_loss(input, target, reduction: "mean")
|
457
|
-
NN.soft_margin_loss(input, target, reduction)
|
461
|
+
NN.soft_margin_loss(input, target, to_reduction(reduction))
|
458
462
|
end
|
459
463
|
|
460
464
|
def smooth_l1_loss(input, target, reduction: "mean")
|
461
|
-
NN.smooth_l1_loss(input, target, reduction)
|
465
|
+
NN.smooth_l1_loss(input, target, to_reduction(reduction))
|
462
466
|
end
|
463
467
|
|
464
468
|
def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean")
|
465
|
-
Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
|
469
|
+
Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, to_reduction(reduction))
|
470
|
+
end
|
471
|
+
|
472
|
+
# vision
|
473
|
+
|
474
|
+
def interpolate(input, size: nil, scale_factor: nil, mode: "nearest", align_corners: nil, recompute_scale_factor: nil)
|
475
|
+
if ["nearest", "area"].include?(mode)
|
476
|
+
unless align_corners.nil?
|
477
|
+
raise ArgumentError, "align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear"
|
478
|
+
end
|
479
|
+
else
|
480
|
+
if align_corners.nil?
|
481
|
+
align_corners = false
|
482
|
+
end
|
483
|
+
end
|
484
|
+
|
485
|
+
scale_factor_len = input.dim - 2
|
486
|
+
scale_factor_list = [nil] * scale_factor_len
|
487
|
+
# default value of recompute_scale_factor is False
|
488
|
+
if !scale_factor.nil? && (recompute_scale_factor == false || recompute_scale_factor.nil?)
|
489
|
+
if scale_factor.is_a?(Array)
|
490
|
+
_scale_factor_repeated = scale_factor
|
491
|
+
else
|
492
|
+
_scale_factor_repeated = [scale_factor] * scale_factor_len
|
493
|
+
end
|
494
|
+
scale_factor_list = _scale_factor_repeated
|
495
|
+
end
|
496
|
+
|
497
|
+
# Give this variable a short name because it has to be repeated multiple times below.
|
498
|
+
sfl = scale_factor_list
|
499
|
+
|
500
|
+
closed_over_args = [input, size, scale_factor, recompute_scale_factor]
|
501
|
+
output_size = _interp_output_size(closed_over_args)
|
502
|
+
if input.dim == 3 && mode == "nearest"
|
503
|
+
NN.upsample_nearest1d(input, output_size, sfl[0])
|
504
|
+
elsif input.dim == 4 && mode == "nearest"
|
505
|
+
NN.upsample_nearest2d(input, output_size, sfl[0], sfl[1])
|
506
|
+
elsif input.dim == 5 && mode == "nearest"
|
507
|
+
NN.upsample_nearest3d(input, output_size, sfl[0], sfl[1], sfl[2])
|
508
|
+
elsif input.dim == 3 && mode == "area"
|
509
|
+
adaptive_avg_pool1d(input, output_size)
|
510
|
+
elsif input.dim == 4 && mode == "area"
|
511
|
+
adaptive_avg_pool2d(input, output_size)
|
512
|
+
elsif input.dim == 5 && mode == "area"
|
513
|
+
adaptive_avg_pool3d(input, output_size)
|
514
|
+
elsif input.dim == 3 && mode == "linear"
|
515
|
+
# assert align_corners is not None
|
516
|
+
NN.upsample_linear1d(input, output_size, align_corners, sfl[0])
|
517
|
+
elsif input.dim == 3 && mode == "bilinear"
|
518
|
+
raise ArgumentError, "Got 3D input, but bilinear mode needs 4D input"
|
519
|
+
elsif input.dim == 3 && mode == "trilinear"
|
520
|
+
raise ArgumentError, "Got 3D input, but trilinear mode needs 5D input"
|
521
|
+
elsif input.dim == 4 && mode == "linear"
|
522
|
+
raise ArgumentError, "Got 4D input, but linear mode needs 3D input"
|
523
|
+
elsif input.dim == 4 && mode == "bilinear"
|
524
|
+
# assert align_corners is not None
|
525
|
+
NN.upsample_bilinear2d(input, output_size, align_corners, sfl[0], sfl[1])
|
526
|
+
elsif input.dim == 4 && mode == "trilinear"
|
527
|
+
raise ArgumentError, "Got 4D input, but trilinear mode needs 5D input"
|
528
|
+
elsif input.dim == 5 && mode == "linear"
|
529
|
+
raise ArgumentError, "Got 5D input, but linear mode needs 3D input"
|
530
|
+
elsif input.dim == 5 && mode == "bilinear"
|
531
|
+
raise ArgumentError, "Got 5D input, but bilinear mode needs 4D input"
|
532
|
+
elsif input.dim == 5 && mode == "trilinear"
|
533
|
+
# assert align_corners is not None
|
534
|
+
NN.upsample_trilinear3d(input, output_size, align_corners, sfl[0], sfl[1], sfl[2])
|
535
|
+
elsif input.dim == 4 && mode == "bicubic"
|
536
|
+
# assert align_corners is not None
|
537
|
+
NN.upsample_bicubic2d(input, output_size, align_corners, sfl[0], sfl[1])
|
538
|
+
else
|
539
|
+
raise ArgumentError, "Input Error: Only 3D, 4D and 5D input Tensors supported (got #{input.dim}D) for the modes: nearest | linear | bilinear | bicubic | trilinear (got #{mode})"
|
540
|
+
end
|
466
541
|
end
|
467
542
|
|
468
543
|
private
|
469
544
|
|
545
|
+
# see _reduction.py
|
546
|
+
def to_reduction(v)
|
547
|
+
case v.to_s
|
548
|
+
when "none"
|
549
|
+
0
|
550
|
+
when "mean"
|
551
|
+
1
|
552
|
+
when "sum"
|
553
|
+
2
|
554
|
+
else
|
555
|
+
raise ArgumentError, "#{v} is not a valid value for reduction"
|
556
|
+
end
|
557
|
+
end
|
558
|
+
|
470
559
|
def softmax_dim(ndim)
|
471
560
|
ndim == 0 || ndim == 1 || ndim == 3 ? 0 : 1
|
472
561
|
end
|
@@ -480,6 +569,41 @@ module Torch
|
|
480
569
|
out_size.zip(defaults.last(out_size.length)).map { |v, d| v || d }
|
481
570
|
end
|
482
571
|
end
|
572
|
+
|
573
|
+
def _interp_output_size(closed_over_args)
|
574
|
+
input, size, scale_factor, recompute_scale_factor = closed_over_args
|
575
|
+
dim = input.dim - 2
|
576
|
+
if size.nil? && scale_factor.nil?
|
577
|
+
raise ArgumentError, "either size or scale_factor should be defined"
|
578
|
+
end
|
579
|
+
if !size.nil? && !scale_factor.nil?
|
580
|
+
raise ArgumentError, "only one of size or scale_factor should be defined"
|
581
|
+
end
|
582
|
+
if !scale_factor.nil?
|
583
|
+
if scale_factor.is_a?(Array)
|
584
|
+
if scale_factor.length != dim
|
585
|
+
raise ArgumentError, "scale_factor shape must match input shape. Input is #{dim}D, scale_factor size is #{scale_factor.length}"
|
586
|
+
end
|
587
|
+
end
|
588
|
+
end
|
589
|
+
|
590
|
+
if !size.nil?
|
591
|
+
if size.is_a?(Array)
|
592
|
+
return size
|
593
|
+
else
|
594
|
+
return [size] * dim
|
595
|
+
end
|
596
|
+
end
|
597
|
+
|
598
|
+
raise "Failed assertion" if scale_factor.nil?
|
599
|
+
if scale_factor.is_a?(Array)
|
600
|
+
scale_factors = scale_factor
|
601
|
+
else
|
602
|
+
scale_factors = [scale_factor] * dim
|
603
|
+
end
|
604
|
+
|
605
|
+
dim.times.map { |i| (input.size(i + 2) * scale_factors[i]).floor }
|
606
|
+
end
|
483
607
|
end
|
484
608
|
end
|
485
609
|
|