torch-rb 0.3.4 → 0.4.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_tensor_functions(Module m);
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_torch_functions(Module m);
@@ -0,0 +1,42 @@
1
+ #pragma once
2
+
3
+ #include <rice/Symbol.hpp>
4
+
5
+ // keep THP prefix for now to make it easier to compare code
6
+
7
+ extern VALUE THPVariableClass;
8
+
9
// Builds a Symbol from `str` and returns it through the VALUE return type.
// NOTE(review): Symbol presumably comes from <rice/Symbol.hpp> (included above) — confirm.
+ inline VALUE THPUtils_internSymbol(const std::string& str) {
10
+ return Symbol(str);
11
+ }
12
+
13
// Converts a Ruby symbol VALUE into a std::string.
// Check_Type is invoked with T_SYMBOL first; `obj` is then replaced by the
// result of calling `to_s` on it, and the returned std::string is built from
// that result's RSTRING_PTR / RSTRING_LEN pair (pointer + explicit length).
+ inline std::string THPUtils_unpackSymbol(VALUE obj) {
14
+ Check_Type(obj, T_SYMBOL);
15
+ obj = rb_funcall(obj, rb_intern("to_s"), 0);
16
+ return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
17
+ }
18
+
19
// Converts a Ruby string VALUE into a std::string.
// Check_Type is invoked with T_STRING, then the std::string is built from
// RSTRING_PTR / RSTRING_LEN (pointer + explicit length).
+ inline std::string THPUtils_unpackString(VALUE obj) {
20
+ Check_Type(obj, T_STRING);
21
+ return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
22
+ }
23
+
24
// Returns the result of SYMBOL_P(obj) as a bool.
+ inline bool THPUtils_checkSymbol(VALUE obj) {
25
+ return SYMBOL_P(obj);
26
+ }
27
+
28
// Returns the result of FIXNUM_P(obj) as a bool; no other integer
// representation is tested here.
+ inline bool THPUtils_checkIndex(VALUE obj) {
29
+ return FIXNUM_P(obj);
30
+ }
31
+
32
// Returns true when any of the three tests passes: FIXNUM_P(obj),
// RB_FLOAT_TYPE_P(obj), or RB_TYPE_P(obj, T_COMPLEX).
+ inline bool THPUtils_checkScalar(VALUE obj) {
33
+ return FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj) || RB_TYPE_P(obj, T_COMPLEX);
34
+ }
35
+
36
// Returns rb_obj_is_kind_of(obj, THPVariableClass) converted to bool.
// THPVariableClass is the extern VALUE declared earlier in this header.
+ inline bool THPVariable_Check(VALUE obj) {
37
+ return rb_obj_is_kind_of(obj, THPVariableClass);
38
+ }
39
+
40
// Returns rb_obj_is_instance_of(obj, THPVariableClass) converted to bool.
// Differs from THPVariable_Check only in using rb_obj_is_instance_of
// instead of rb_obj_is_kind_of.
+ inline bool THPVariable_CheckExact(VALUE obj) {
41
+ return rb_obj_is_instance_of(obj, THPVariableClass);
42
+ }
@@ -1,15 +1,44 @@
1
+ #pragma once
2
+
1
3
  #include <torch/torch.h>
2
4
  #include <rice/Object.hpp>
3
- #include "templates.hpp"
4
5
 
5
- Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
6
// Single-value `wrap` overloads. Each one forwards its argument to the
// to_ruby<T> specialisation for the same type and returns the result as an
// Object. All are declared `inline` (this header now carries `#pragma once`).
// NOTE(review): to_ruby<T> is not declared in the visible part of this
// header — confirm which include provides it.

// bool
+ inline Object wrap(bool x) {
7
+ return to_ruby<bool>(x);
8
+ }
9
+
10
// int64_t
+ inline Object wrap(int64_t x) {
11
+ return to_ruby<int64_t>(x);
12
+ }
13
+
14
// double
+ inline Object wrap(double x) {
15
+ return to_ruby<double>(x);
16
+ }
17
+
18
// torch::Tensor
+ inline Object wrap(torch::Tensor x) {
19
+ return to_ruby<torch::Tensor>(x);
20
+ }
21
+
22
// torch::Scalar
+ inline Object wrap(torch::Scalar x) {
23
+ return to_ruby<torch::Scalar>(x);
24
+ }
25
+
26
// torch::ScalarType
+ inline Object wrap(torch::ScalarType x) {
27
+ return to_ruby<torch::ScalarType>(x);
28
+ }
29
+
30
// torch::QScheme
+ inline Object wrap(torch::QScheme x) {
31
+ return to_ruby<torch::QScheme>(x);
32
+ }
33
+
34
// Converts a 2-tuple of tensors into an Array: each element is passed through
// to_ruby<torch::Tensor> and pushed in tuple order (index 0, then index 1).
// The Array is returned wrapped in an Object.
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
6
35
  Array a;
7
36
  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
8
37
  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
9
38
  return Object(a);
10
39
  }
11
40
 
12
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
41
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
13
42
  Array a;
14
43
  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
15
44
  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -17,7 +46,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
17
46
  return Object(a);
18
47
  }
19
48
 
20
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
49
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
21
50
  Array a;
22
51
  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
23
52
  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -26,7 +55,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tenso
26
55
  return Object(a);
27
56
  }
28
57
 
29
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
58
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
30
59
  Array a;
31
60
  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
32
61
  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -36,7 +65,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tenso
36
65
  return Object(a);
37
66
  }
38
67
 
39
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
68
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
40
69
  Array a;
41
70
  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
42
71
  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -45,7 +74,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x)
45
74
  return Object(a);
46
75
  }
47
76
 
48
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
77
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
49
78
  Array a;
50
79
  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
51
80
  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -53,3 +82,11 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
53
82
  a.push(to_ruby<int64_t>(std::get<3>(x)));
54
83
  return Object(a);
55
84
  }
85
+
86
// Converts a torch::TensorList into an Array by iterating it in order and
// pushing to_ruby<torch::Tensor>(t) for every element; an empty list yields
// an Array with nothing pushed. The Array is returned wrapped in an Object.
+ inline Object wrap(torch::TensorList x) {
87
+ Array a;
88
+ for (auto& t : x) {
89
+ a.push(to_ruby<torch::Tensor>(t));
90
+ }
91
+ return Object(a);
92
+ }
@@ -7,11 +7,6 @@ require "net/http"
7
7
  require "set"
8
8
  require "tmpdir"
9
9
 
10
- # native functions
11
- require "torch/native/generator"
12
- require "torch/native/parser"
13
- require "torch/native/dispatcher"
14
-
15
10
  # modules
16
11
  require "torch/inspector"
17
12
  require "torch/tensor"
@@ -174,6 +169,9 @@ require "torch/nn/smooth_l1_loss"
174
169
  require "torch/nn/soft_margin_loss"
175
170
  require "torch/nn/triplet_margin_loss"
176
171
 
172
+ # nn vision
173
+ require "torch/nn/upsample"
174
+
177
175
  # nn other
178
176
  require "torch/nn/functional"
179
177
  require "torch/nn/init"
@@ -196,6 +194,32 @@ module Torch
196
194
  end
197
195
  end
198
196
 
197
+ # legacy
198
+ # but may make it easier to port tutorials
199
# Compatibility shim: Autograd::Variable.new does not build a Variable object.
# It raises ArgumentError unless `x` is a Tensor, emits a deprecation warning
# via `warn`, and then returns `x` itself unchanged.
+ module Autograd
200
+ class Variable
201
+ def self.new(x)
202
+ raise ArgumentError, "Variable data has to be a tensor, but got #{x.class.name}" unless x.is_a?(Tensor)
203
+ warn "[torch] The Variable API is deprecated. Use tensors with requires_grad: true instead."
204
+ x
205
+ end
206
+ end
207
+ end
208
+
209
+ # TODO move to C++
210
# Thin holder around a `bytes` object. The constructor stores the argument
# as-is (no copy, no validation) and exposes it through the `bytes` reader.
# `ByteStorage.from_buffer(bytes)` is equivalent to `ByteStorage.new(bytes)`.
# NOTE(review): `bytes` is presumably a binary String (a caller elsewhere in
# this diff calls `bytesize` on it) — confirm.
+ class ByteStorage
211
+ # private
212
+ attr_reader :bytes
213
+
214
+ def initialize(bytes)
215
+ @bytes = bytes
216
+ end
217
+
218
+ def self.from_buffer(bytes)
219
+ new(bytes)
220
+ end
221
+ end
222
+
199
223
  # keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
200
224
  # values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
201
225
  DTYPE_TO_ENUM = {
@@ -224,40 +248,43 @@ module Torch
224
248
  }
225
249
  ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
226
250
 
251
+ TENSOR_TYPE_CLASSES = []
252
+
227
253
  # Builds an anonymous class whose singleton `new` creates tensors of `dtype`
  # on "cuda" (when `cuda` is truthy) or "cpu". `new` dispatches on its args:
  #   * one Tensor       -> args.first.send(dtype).to(device)
  #   * one ByteStorage  -> Torch._from_blob over its bytes (only when dtype == :uint8)
  #   * one Array        -> Torch.tensor(array, dtype:, device:)
  #   * anything else    -> Torch.empty(*args, dtype:, device:)
  # The class is appended to TENSOR_TYPE_CLASSES before being returned.
  # NOTE(review): the ByteStorage branch does not pass `device`, unlike the
  # other three branches — confirm this is intended for the cuda variant.
  def self._make_tensor_class(dtype, cuda = false)
228
254
  cls = Class.new
229
255
  device = cuda ? "cuda" : "cpu"
230
256
  cls.define_singleton_method("new") do |*args|
231
257
  if args.size == 1 && args.first.is_a?(Tensor)
232
258
  args.first.send(dtype).to(device)
259
+ elsif args.size == 1 && args.first.is_a?(ByteStorage) && dtype == :uint8
260
+ bytes = args.first.bytes
261
+ Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
233
262
  elsif args.size == 1 && args.first.is_a?(Array)
234
263
  Torch.tensor(args.first, dtype: dtype, device: device)
235
264
  else
236
265
  Torch.empty(*args, dtype: dtype, device: device)
237
266
  end
238
267
  end
268
+ TENSOR_TYPE_CLASSES << cls
239
269
  cls
240
270
  end
241
271
 
242
- FloatTensor = _make_tensor_class(:float32)
243
- DoubleTensor = _make_tensor_class(:float64)
244
- HalfTensor = _make_tensor_class(:float16)
245
- ByteTensor = _make_tensor_class(:uint8)
246
- CharTensor = _make_tensor_class(:int8)
247
- ShortTensor = _make_tensor_class(:int16)
248
- IntTensor = _make_tensor_class(:int32)
249
- LongTensor = _make_tensor_class(:int64)
250
- BoolTensor = _make_tensor_class(:bool)
251
-
252
- CUDA::FloatTensor = _make_tensor_class(:float32, true)
253
- CUDA::DoubleTensor = _make_tensor_class(:float64, true)
254
- CUDA::HalfTensor = _make_tensor_class(:float16, true)
255
- CUDA::ByteTensor = _make_tensor_class(:uint8, true)
256
- CUDA::CharTensor = _make_tensor_class(:int8, true)
257
- CUDA::ShortTensor = _make_tensor_class(:int16, true)
258
- CUDA::IntTensor = _make_tensor_class(:int32, true)
259
- CUDA::LongTensor = _make_tensor_class(:int64, true)
260
- CUDA::BoolTensor = _make_tensor_class(:bool, true)
272
# Maps each dtype symbol to the name of its typed tensor class. The loop below
# replaces the previous hand-written constant assignments: for every entry it
# defines the constant on the enclosing module and again on CUDA, each time
# with a freshly built class from _make_tensor_class (cuda = true for CUDA).
+ DTYPE_TO_CLASS = {
273
+ float32: "FloatTensor",
274
+ float64: "DoubleTensor",
275
+ float16: "HalfTensor",
276
+ uint8: "ByteTensor",
277
+ int8: "CharTensor",
278
+ int16: "ShortTensor",
279
+ int32: "IntTensor",
280
+ int64: "LongTensor",
281
+ bool: "BoolTensor"
282
+ }
283
+
284
+ DTYPE_TO_CLASS.each do |dtype, class_name|
285
+ const_set(class_name, _make_tensor_class(dtype))
286
+ CUDA.const_set(class_name, _make_tensor_class(dtype, true))
287
+ end
261
288
 
262
289
  class << self
263
290
  # Torch.float, Torch.long, etc
@@ -342,59 +369,6 @@ module Torch
342
369
 
343
370
  # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
344
371
 
345
- def arange(start, finish = nil, step = 1, **options)
346
- # ruby doesn't support start = 0, finish, step = 1, ...
347
- if finish.nil?
348
- finish = start
349
- start = 0
350
- end
351
- _arange(start, finish, step, tensor_options(**options))
352
- end
353
-
354
- def empty(*size, **options)
355
- _empty(tensor_size(size), tensor_options(**options))
356
- end
357
-
358
- def eye(n, m = nil, **options)
359
- _eye(n, m || n, tensor_options(**options))
360
- end
361
-
362
- def full(size, fill_value, **options)
363
- _full(size, fill_value, tensor_options(**options))
364
- end
365
-
366
- def linspace(start, finish, steps = 100, **options)
367
- _linspace(start, finish, steps, tensor_options(**options))
368
- end
369
-
370
- def logspace(start, finish, steps = 100, base = 10.0, **options)
371
- _logspace(start, finish, steps, base, tensor_options(**options))
372
- end
373
-
374
- def ones(*size, **options)
375
- _ones(tensor_size(size), tensor_options(**options))
376
- end
377
-
378
- def rand(*size, **options)
379
- _rand(tensor_size(size), tensor_options(**options))
380
- end
381
-
382
- def randint(low = 0, high, size, **options)
383
- _randint(low, high, size, tensor_options(**options))
384
- end
385
-
386
- def randn(*size, **options)
387
- _randn(tensor_size(size), tensor_options(**options))
388
- end
389
-
390
- def randperm(n, **options)
391
- _randperm(n, tensor_options(**options))
392
- end
393
-
394
- def zeros(*size, **options)
395
- _zeros(tensor_size(size), tensor_options(**options))
396
- end
397
-
398
372
  def tensor(data, **options)
399
373
  if options[:dtype].nil? && defined?(Numo::NArray) && data.is_a?(Numo::NArray)
400
374
  numo_to_dtype = _dtype_to_numo.map(&:reverse).to_h
@@ -178,8 +178,12 @@ module Torch
178
178
  Torch.hardshrink(input, lambd)
179
179
  end
180
180
 
181
- def leaky_relu(input, negative_slope = 0.01)
182
- NN.leaky_relu(input, negative_slope)
181
# Leaky ReLU wrapper. With `inplace: true` it delegates to NN.leaky_relu!,
# otherwise to NN.leaky_relu; both receive (input, negative_slope).
# `inplace` defaults to false, so existing two-argument callers keep the
# previous NN.leaky_relu path.
+ def leaky_relu(input, negative_slope = 0.01, inplace: false)
182
+ if inplace
183
+ NN.leaky_relu!(input, negative_slope)
184
+ else
185
+ NN.leaky_relu(input, negative_slope)
186
+ end
183
187
  end
184
188
 
185
189
  def log_sigmoid(input)
@@ -390,15 +394,15 @@ module Torch
390
394
  # loss functions
391
395
 
392
396
  def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
393
- NN.binary_cross_entropy(input, target, weight, reduction)
397
+ NN.binary_cross_entropy(input, target, weight, to_reduction(reduction))
394
398
  end
395
399
 
396
400
  def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
397
- Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
401
+ Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, to_reduction(reduction))
398
402
  end
399
403
 
400
404
  def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
401
- Torch.cosine_embedding_loss(input1, input2, target, margin, reduction)
405
+ Torch.cosine_embedding_loss(input1, input2, target, margin, to_reduction(reduction))
402
406
  end
403
407
 
404
408
  def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
@@ -407,34 +411,34 @@ module Torch
407
411
 
408
412
  def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
409
413
  # call to_a on input_lengths and target_lengths for C++
410
- Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
414
+ Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, to_reduction(reduction), zero_infinity)
411
415
  end
412
416
 
413
417
  def hinge_embedding_loss(input, target, margin: 1.0, reduction: "mean")
414
- Torch.hinge_embedding_loss(input, target, margin, reduction)
418
+ Torch.hinge_embedding_loss(input, target, margin, to_reduction(reduction))
415
419
  end
416
420
 
417
421
  def kl_div(input, target, reduction: "mean")
418
- Torch.kl_div(input, target, reduction)
422
+ Torch.kl_div(input, target, to_reduction(reduction))
419
423
  end
420
424
 
421
425
  def l1_loss(input, target, reduction: "mean")
422
- NN.l1_loss(input, target, reduction)
426
+ NN.l1_loss(input, target, to_reduction(reduction))
423
427
  end
424
428
 
425
429
  def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
426
- Torch.margin_ranking_loss(input1, input2, target, margin, reduction)
430
+ Torch.margin_ranking_loss(input1, input2, target, margin, to_reduction(reduction))
427
431
  end
428
432
 
429
433
  def mse_loss(input, target, reduction: "mean")
430
434
  if target.size != input.size
431
435
  warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
432
436
  end
433
- NN.mse_loss(input, target, reduction)
437
+ NN.mse_loss(input, target, to_reduction(reduction))
434
438
  end
435
439
 
436
440
  def multilabel_margin_loss(input, target, reduction: "mean")
437
- NN.multilabel_margin_loss(input, target, reduction)
441
+ NN.multilabel_margin_loss(input, target, to_reduction(reduction))
438
442
  end
439
443
 
440
444
  def multilabel_soft_margin_loss(input, target, weight: nil)
@@ -442,31 +446,116 @@ module Torch
442
446
  end
443
447
 
444
448
  def multi_margin_loss(input, target, p: 1, margin: 1.0, weight: nil, reduction: "mean")
445
- NN.multi_margin_loss(input, target, p, margin, weight, reduction)
449
+ NN.multi_margin_loss(input, target, p, margin, weight, to_reduction(reduction))
446
450
  end
447
451
 
448
452
  def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
449
- NN.nll_loss(input, target, weight, reduction, ignore_index)
453
+ NN.nll_loss(input, target, weight, to_reduction(reduction), ignore_index)
450
454
  end
451
455
 
452
456
  def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
453
- Torch.poisson_nll_loss(input, target, log_input, full, eps, reduction)
457
+ Torch.poisson_nll_loss(input, target, log_input, full, eps, to_reduction(reduction))
454
458
  end
455
459
 
456
460
  def soft_margin_loss(input, target, reduction: "mean")
457
- NN.soft_margin_loss(input, target, reduction)
461
+ NN.soft_margin_loss(input, target, to_reduction(reduction))
458
462
  end
459
463
 
460
464
  def smooth_l1_loss(input, target, reduction: "mean")
461
- NN.smooth_l1_loss(input, target, reduction)
465
+ NN.smooth_l1_loss(input, target, to_reduction(reduction))
462
466
  end
463
467
 
464
468
  def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean")
465
- Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
469
+ Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, to_reduction(reduction))
470
+ end
471
+
472
+ # vision
473
+
474
# Resizes `input` by dispatching on (input.dim, mode) to an NN.upsample_*
# function or an adaptive_avg_pool* helper.
#
# Steps, in order:
#   1. align_corners handling: for "nearest"/"area" a non-nil align_corners
#      raises ArgumentError; for every other mode nil is replaced by false.
#   2. scale_factor_list starts as input.dim - 2 nils and is filled from
#      scale_factor only when scale_factor is given AND recompute_scale_factor
#      is false or nil; otherwise the nils are what the NN.* calls receive.
#   3. output_size comes from _interp_output_size, which also performs the
#      size / scale_factor argument validation.
#   4. Dispatch. Mismatched dim/mode pairs for linear | bilinear | trilinear
#      raise a specific ArgumentError; anything unmatched falls to the final
#      generic ArgumentError.
# NOTE(review): the inline comments below ("False", "assert ... is not None")
# use Python wording — presumably carried over from the reference
# implementation this was ported from.
+ def interpolate(input, size: nil, scale_factor: nil, mode: "nearest", align_corners: nil, recompute_scale_factor: nil)
475
+ if ["nearest", "area"].include?(mode)
476
+ unless align_corners.nil?
477
+ raise ArgumentError, "align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear"
478
+ end
479
+ else
480
+ if align_corners.nil?
481
+ align_corners = false
482
+ end
483
+ end
484
+
485
+ scale_factor_len = input.dim - 2
486
+ scale_factor_list = [nil] * scale_factor_len
487
+ # default value of recompute_scale_factor is False
488
+ if !scale_factor.nil? && (recompute_scale_factor == false || recompute_scale_factor.nil?)
489
+ if scale_factor.is_a?(Array)
490
+ _scale_factor_repeated = scale_factor
491
+ else
492
+ _scale_factor_repeated = [scale_factor] * scale_factor_len
493
+ end
494
+ scale_factor_list = _scale_factor_repeated
495
+ end
496
+
497
+ # Give this variable a short name because it has to be repeated multiple times below.
498
+ sfl = scale_factor_list
499
+
500
+ closed_over_args = [input, size, scale_factor, recompute_scale_factor]
501
+ output_size = _interp_output_size(closed_over_args)
502
+ if input.dim == 3 && mode == "nearest"
503
+ NN.upsample_nearest1d(input, output_size, sfl[0])
504
+ elsif input.dim == 4 && mode == "nearest"
505
+ NN.upsample_nearest2d(input, output_size, sfl[0], sfl[1])
506
+ elsif input.dim == 5 && mode == "nearest"
507
+ NN.upsample_nearest3d(input, output_size, sfl[0], sfl[1], sfl[2])
508
+ elsif input.dim == 3 && mode == "area"
509
+ adaptive_avg_pool1d(input, output_size)
510
+ elsif input.dim == 4 && mode == "area"
511
+ adaptive_avg_pool2d(input, output_size)
512
+ elsif input.dim == 5 && mode == "area"
513
+ adaptive_avg_pool3d(input, output_size)
514
+ elsif input.dim == 3 && mode == "linear"
515
+ # assert align_corners is not None
516
+ NN.upsample_linear1d(input, output_size, align_corners, sfl[0])
517
+ elsif input.dim == 3 && mode == "bilinear"
518
+ raise ArgumentError, "Got 3D input, but bilinear mode needs 4D input"
519
+ elsif input.dim == 3 && mode == "trilinear"
520
+ raise ArgumentError, "Got 3D input, but trilinear mode needs 5D input"
521
+ elsif input.dim == 4 && mode == "linear"
522
+ raise ArgumentError, "Got 4D input, but linear mode needs 3D input"
523
+ elsif input.dim == 4 && mode == "bilinear"
524
+ # assert align_corners is not None
525
+ NN.upsample_bilinear2d(input, output_size, align_corners, sfl[0], sfl[1])
526
+ elsif input.dim == 4 && mode == "trilinear"
527
+ raise ArgumentError, "Got 4D input, but trilinear mode needs 5D input"
528
+ elsif input.dim == 5 && mode == "linear"
529
+ raise ArgumentError, "Got 5D input, but linear mode needs 3D input"
530
+ elsif input.dim == 5 && mode == "bilinear"
531
+ raise ArgumentError, "Got 5D input, but bilinear mode needs 4D input"
532
+ elsif input.dim == 5 && mode == "trilinear"
533
+ # assert align_corners is not None
534
+ NN.upsample_trilinear3d(input, output_size, align_corners, sfl[0], sfl[1], sfl[2])
535
+ elsif input.dim == 4 && mode == "bicubic"
536
+ # assert align_corners is not None
537
+ NN.upsample_bicubic2d(input, output_size, align_corners, sfl[0], sfl[1])
538
+ else
539
+ raise ArgumentError, "Input Error: Only 3D, 4D and 5D input Tensors supported (got #{input.dim}D) for the modes: nearest | linear | bilinear | bicubic | trilinear (got #{mode})"
540
+ end
466
541
  end
467
542
 
468
543
  private
469
544
 
545
+ # see _reduction.py
546
# Translates a reduction name into the integer passed to the native loss
# functions: "none" -> 0, "mean" -> 1, "sum" -> 2. The argument is compared
# via `to_s`, so both strings and symbols are accepted. Any other value
# raises ArgumentError.
+ def to_reduction(v)
547
+ case v.to_s
548
+ when "none"
549
+ 0
550
+ when "mean"
551
+ 1
552
+ when "sum"
553
+ 2
554
+ else
555
+ raise ArgumentError, "#{v} is not a valid value for reduction"
556
+ end
557
+ end
558
+
470
559
  def softmax_dim(ndim)
471
560
  ndim == 0 || ndim == 1 || ndim == 3 ? 0 : 1
472
561
  end
@@ -480,6 +569,41 @@ module Torch
480
569
  out_size.zip(defaults.last(out_size.length)).map { |v, d| v || d }
481
570
  end
482
571
  end
572
+
573
# Computes the output size array used by `interpolate`.
#
# `closed_over_args` is the 4-element array
# [input, size, scale_factor, recompute_scale_factor]; `dim` is input.dim - 2.
#
# Raises ArgumentError when neither or both of size / scale_factor are given,
# or when an Array scale_factor's length differs from `dim`.
#
# Returns:
#   * `size` itself when it is an Array, or [size] * dim when it is not;
#   * otherwise, for each of the `dim` trailing dimensions i, the floor of
#     input.size(i + 2) * scale_factors[i] (a non-Array scale_factor is
#     repeated `dim` times).
# NOTE(review): recompute_scale_factor is unpacked but never read in this
# method — confirm that is intended.
+ def _interp_output_size(closed_over_args)
574
+ input, size, scale_factor, recompute_scale_factor = closed_over_args
575
+ dim = input.dim - 2
576
+ if size.nil? && scale_factor.nil?
577
+ raise ArgumentError, "either size or scale_factor should be defined"
578
+ end
579
+ if !size.nil? && !scale_factor.nil?
580
+ raise ArgumentError, "only one of size or scale_factor should be defined"
581
+ end
582
+ if !scale_factor.nil?
583
+ if scale_factor.is_a?(Array)
584
+ if scale_factor.length != dim
585
+ raise ArgumentError, "scale_factor shape must match input shape. Input is #{dim}D, scale_factor size is #{scale_factor.length}"
586
+ end
587
+ end
588
+ end
589
+
590
+ if !size.nil?
591
+ if size.is_a?(Array)
592
+ return size
593
+ else
594
+ return [size] * dim
595
+ end
596
+ end
597
+
# Unreachable given the checks above (size is nil here, and both-nil already
# raised); kept as an explicit guard.
598
+ raise "Failed assertion" if scale_factor.nil?
599
+ if scale_factor.is_a?(Array)
600
+ scale_factors = scale_factor
601
+ else
602
+ scale_factors = [scale_factor] * dim
603
+ end
604
+
605
+ dim.times.map { |i| (input.size(i + 2) * scale_factors[i]).floor }
606
+ end
483
607
  end
484
608
  end
485
609