torch-rb 0.3.4 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_tensor_functions(Module m);
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_torch_functions(Module m);
@@ -0,0 +1,42 @@
1
+ #pragma once
2
+
3
+ #include <rice/Symbol.hpp>
4
+
5
+ // keep THP prefix for now to make it easier to compare code
6
+
7
+ extern VALUE THPVariableClass;
8
+ // Builds a Symbol from `str` and returns it as a raw Ruby VALUE.
9
+ inline VALUE THPUtils_internSymbol(const std::string& str) {
10
+ return Symbol(str);
11
+ }
12
+ // Returns the name of a Ruby Symbol as std::string; Check_Type rejects any non-Symbol argument.
13
+ inline std::string THPUtils_unpackSymbol(VALUE obj) {
14
+ Check_Type(obj, T_SYMBOL);
15
+ obj = rb_funcall(obj, rb_intern("to_s"), 0); // obj now refers to the String produced by Symbol#to_s
16
+ return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj)); // pointer + length copy, so embedded NUL bytes are kept
17
+ }
18
+ // Copies the bytes of a Ruby String into std::string; Check_Type rejects any non-String argument.
19
+ inline std::string THPUtils_unpackString(VALUE obj) {
20
+ Check_Type(obj, T_STRING);
21
+ return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj)); // pointer + length copy, so embedded NUL bytes are kept
22
+ }
23
+ // True when obj is a Ruby Symbol (SYMBOL_P).
24
+ inline bool THPUtils_checkSymbol(VALUE obj) {
25
+ return SYMBOL_P(obj);
26
+ }
27
+ // True only when obj is a Fixnum (FIXNUM_P); no other integer representation is accepted.
28
+ inline bool THPUtils_checkIndex(VALUE obj) {
29
+ return FIXNUM_P(obj);
30
+ }
31
+ // True when obj is a Fixnum, a Float, or a Complex; every other type yields false.
32
+ inline bool THPUtils_checkScalar(VALUE obj) {
33
+ return FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj) || RB_TYPE_P(obj, T_COMPLEX);
34
+ }
35
+ // True when obj is a kind of THPVariableClass (rb_obj_is_kind_of, so subclasses also match).
36
+ inline bool THPVariable_Check(VALUE obj) {
37
+ return rb_obj_is_kind_of(obj, THPVariableClass); // VALUE result converts to bool: Qfalse is 0, Qtrue is non-zero
38
+ }
39
+ // True only when obj is a direct instance of THPVariableClass (rb_obj_is_instance_of; subclasses do not match).
40
+ inline bool THPVariable_CheckExact(VALUE obj) {
41
+ return rb_obj_is_instance_of(obj, THPVariableClass); // VALUE result converts to bool: Qfalse is 0, Qtrue is non-zero
42
+ }
@@ -1,15 +1,44 @@
1
+ #pragma once
2
+
1
3
  #include <torch/torch.h>
2
4
  #include <rice/Object.hpp>
3
- #include "templates.hpp"
4
5
 
5
- Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
6
+ inline Object wrap(bool x) { // converts a C++ bool to a Ruby object via to_ruby<bool>
7
+ return to_ruby<bool>(x);
8
+ }
9
+ // Converts an int64_t to a Ruby object via to_ruby<int64_t>.
10
+ inline Object wrap(int64_t x) {
11
+ return to_ruby<int64_t>(x);
12
+ }
13
+ // Converts a double to a Ruby object via to_ruby<double>.
14
+ inline Object wrap(double x) {
15
+ return to_ruby<double>(x);
16
+ }
17
+ // Converts a torch::Tensor to a Ruby object via to_ruby<torch::Tensor>.
18
+ inline Object wrap(torch::Tensor x) {
19
+ return to_ruby<torch::Tensor>(x);
20
+ }
21
+ // Converts a torch::Scalar to a Ruby object via to_ruby<torch::Scalar>.
22
+ inline Object wrap(torch::Scalar x) {
23
+ return to_ruby<torch::Scalar>(x);
24
+ }
25
+ // Converts a torch::ScalarType to a Ruby object via to_ruby<torch::ScalarType>.
26
+ inline Object wrap(torch::ScalarType x) {
27
+ return to_ruby<torch::ScalarType>(x);
28
+ }
29
+ // Converts a torch::QScheme to a Ruby object via to_ruby<torch::QScheme>.
30
+ inline Object wrap(torch::QScheme x) {
31
+ return to_ruby<torch::QScheme>(x);
32
+ }
33
+ // Converts a pair of tensors to a two-element Ruby Array, keeping the tuple order.
34
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
6
35
  Array a;
7
36
  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
8
37
  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
9
38
  return Object(a);
10
39
  }
11
40
 
12
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
41
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
13
42
  Array a;
14
43
  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
15
44
  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -17,7 +46,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
17
46
  return Object(a);
18
47
  }
19
48
 
20
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
49
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
21
50
  Array a;
22
51
  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
23
52
  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -26,7 +55,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tenso
26
55
  return Object(a);
27
56
  }
28
57
 
29
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
58
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
30
59
  Array a;
31
60
  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
32
61
  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -36,7 +65,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tenso
36
65
  return Object(a);
37
66
  }
38
67
 
39
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
68
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
40
69
  Array a;
41
70
  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
42
71
  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -45,7 +74,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x)
45
74
  return Object(a);
46
75
  }
47
76
 
48
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
77
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
49
78
  Array a;
50
79
  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
51
80
  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -53,3 +82,11 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
53
82
  a.push(to_ruby<int64_t>(std::get<3>(x)));
54
83
  return Object(a);
55
84
  }
85
+ // Converts a TensorList to a Ruby Array holding one converted tensor per element, in order.
86
+ inline Object wrap(torch::TensorList x) {
87
+ Array a;
88
+ for (auto& t : x) {
89
+ a.push(to_ruby<torch::Tensor>(t));
90
+ }
91
+ return Object(a); // an empty list yields an empty Array
92
+ }
@@ -7,11 +7,6 @@ require "net/http"
7
7
  require "set"
8
8
  require "tmpdir"
9
9
 
10
- # native functions
11
- require "torch/native/generator"
12
- require "torch/native/parser"
13
- require "torch/native/dispatcher"
14
-
15
10
  # modules
16
11
  require "torch/inspector"
17
12
  require "torch/tensor"
@@ -174,6 +169,9 @@ require "torch/nn/smooth_l1_loss"
174
169
  require "torch/nn/soft_margin_loss"
175
170
  require "torch/nn/triplet_margin_loss"
176
171
 
172
+ # nn vision
173
+ require "torch/nn/upsample"
174
+
177
175
  # nn other
178
176
  require "torch/nn/functional"
179
177
  require "torch/nn/init"
@@ -196,6 +194,32 @@ module Torch
196
194
  end
197
195
  end
198
196
 
197
+ # legacy
198
+ # but may make it easier to port tutorials
199
+ module Autograd
200
+ class Variable # shim class: .new is overridden below to return its argument
201
+ def self.new(x) # validates that x is a Tensor, emits a deprecation warning, and returns x itself
202
+ raise ArgumentError, "Variable data has to be a tensor, but got #{x.class.name}" unless x.is_a?(Tensor)
203
+ warn "[torch] The Variable API is deprecated. Use tensors with requires_grad: true instead."
204
+ x # the same tensor object is handed back; no Variable instance is created
205
+ end
206
+ end
207
+ end
208
+
209
+ # TODO move to C++
210
+ class ByteStorage
211
+ # intended as private; note attr_reader still exposes a publicly callable #bytes
212
+ attr_reader :bytes
213
+ # Keeps a reference to bytes as given (no copy, no validation).
214
+ def initialize(bytes)
215
+ @bytes = bytes
216
+ end
217
+ # Alternate constructor; identical to ByteStorage.new(bytes).
218
+ def self.from_buffer(bytes)
219
+ new(bytes)
220
+ end
221
+ end
222
+
199
223
  # keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
200
224
  # values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
201
225
  DTYPE_TO_ENUM = {
@@ -224,40 +248,43 @@ module Torch
224
248
  }
225
249
  ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
226
250
 
251
+ TENSOR_TYPE_CLASSES = [] # appended to by _make_tensor_class, one entry per generated class
252
+
227
253
  def self._make_tensor_class(dtype, cuda = false)
228
254
  cls = Class.new
229
255
  device = cuda ? "cuda" : "cpu"
230
256
  cls.define_singleton_method("new") do |*args|
231
257
  if args.size == 1 && args.first.is_a?(Tensor)
232
258
  args.first.send(dtype).to(device)
259
+ elsif args.size == 1 && args.first.is_a?(ByteStorage) && dtype == :uint8 # ByteStorage input is only honored by the uint8 class
260
+ bytes = args.first.bytes
261
+ Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype])) # NOTE(review): device is not applied in this branch, unlike the others - confirm that is intended for the CUDA class
233
262
  elsif args.size == 1 && args.first.is_a?(Array)
234
263
  Torch.tensor(args.first, dtype: dtype, device: device)
235
264
  else
236
265
  Torch.empty(*args, dtype: dtype, device: device)
237
266
  end
238
267
  end
268
+ TENSOR_TYPE_CLASSES << cls # every generated class is recorded, CPU and CUDA variants alike
239
269
  cls
240
270
  end
241
271
 
242
- FloatTensor = _make_tensor_class(:float32)
243
- DoubleTensor = _make_tensor_class(:float64)
244
- HalfTensor = _make_tensor_class(:float16)
245
- ByteTensor = _make_tensor_class(:uint8)
246
- CharTensor = _make_tensor_class(:int8)
247
- ShortTensor = _make_tensor_class(:int16)
248
- IntTensor = _make_tensor_class(:int32)
249
- LongTensor = _make_tensor_class(:int64)
250
- BoolTensor = _make_tensor_class(:bool)
251
-
252
- CUDA::FloatTensor = _make_tensor_class(:float32, true)
253
- CUDA::DoubleTensor = _make_tensor_class(:float64, true)
254
- CUDA::HalfTensor = _make_tensor_class(:float16, true)
255
- CUDA::ByteTensor = _make_tensor_class(:uint8, true)
256
- CUDA::CharTensor = _make_tensor_class(:int8, true)
257
- CUDA::ShortTensor = _make_tensor_class(:int16, true)
258
- CUDA::IntTensor = _make_tensor_class(:int32, true)
259
- CUDA::LongTensor = _make_tensor_class(:int64, true)
260
- CUDA::BoolTensor = _make_tensor_class(:bool, true)
272
+ DTYPE_TO_CLASS = { # dtype symbol => name of the typed tensor class constant
273
+ float32: "FloatTensor",
274
+ float64: "DoubleTensor",
275
+ float16: "HalfTensor",
276
+ uint8: "ByteTensor",
277
+ int8: "CharTensor",
278
+ int16: "ShortTensor",
279
+ int32: "IntTensor",
280
+ int64: "LongTensor",
281
+ bool: "BoolTensor"
282
+ }
283
+ # Defines each class twice: as a constant here and under CUDA (e.g. FloatTensor and CUDA::FloatTensor).
284
+ DTYPE_TO_CLASS.each do |dtype, class_name|
285
+ const_set(class_name, _make_tensor_class(dtype))
286
+ CUDA.const_set(class_name, _make_tensor_class(dtype, true))
287
+ end
261
288
 
262
289
  class << self
263
290
  # Torch.float, Torch.long, etc
@@ -342,59 +369,6 @@ module Torch
342
369
 
343
370
  # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
344
371
 
345
- def arange(start, finish = nil, step = 1, **options)
346
- # ruby doesn't support start = 0, finish, step = 1, ...
347
- if finish.nil?
348
- finish = start
349
- start = 0
350
- end
351
- _arange(start, finish, step, tensor_options(**options))
352
- end
353
-
354
- def empty(*size, **options)
355
- _empty(tensor_size(size), tensor_options(**options))
356
- end
357
-
358
- def eye(n, m = nil, **options)
359
- _eye(n, m || n, tensor_options(**options))
360
- end
361
-
362
- def full(size, fill_value, **options)
363
- _full(size, fill_value, tensor_options(**options))
364
- end
365
-
366
- def linspace(start, finish, steps = 100, **options)
367
- _linspace(start, finish, steps, tensor_options(**options))
368
- end
369
-
370
- def logspace(start, finish, steps = 100, base = 10.0, **options)
371
- _logspace(start, finish, steps, base, tensor_options(**options))
372
- end
373
-
374
- def ones(*size, **options)
375
- _ones(tensor_size(size), tensor_options(**options))
376
- end
377
-
378
- def rand(*size, **options)
379
- _rand(tensor_size(size), tensor_options(**options))
380
- end
381
-
382
- def randint(low = 0, high, size, **options)
383
- _randint(low, high, size, tensor_options(**options))
384
- end
385
-
386
- def randn(*size, **options)
387
- _randn(tensor_size(size), tensor_options(**options))
388
- end
389
-
390
- def randperm(n, **options)
391
- _randperm(n, tensor_options(**options))
392
- end
393
-
394
- def zeros(*size, **options)
395
- _zeros(tensor_size(size), tensor_options(**options))
396
- end
397
-
398
372
  def tensor(data, **options)
399
373
  if options[:dtype].nil? && defined?(Numo::NArray) && data.is_a?(Numo::NArray)
400
374
  numo_to_dtype = _dtype_to_numo.map(&:reverse).to_h
@@ -178,8 +178,12 @@ module Torch
178
178
  Torch.hardshrink(input, lambd)
179
179
  end
180
180
 
181
- def leaky_relu(input, negative_slope = 0.01)
182
- NN.leaky_relu(input, negative_slope)
181
+ def leaky_relu(input, negative_slope = 0.01, inplace: false) # inplace: true dispatches to NN.leaky_relu!, otherwise to NN.leaky_relu
182
+ if inplace
183
+ NN.leaky_relu!(input, negative_slope)
184
+ else
185
+ NN.leaky_relu(input, negative_slope)
186
+ end
183
187
  end
184
188
 
185
189
  def log_sigmoid(input)
@@ -390,15 +394,15 @@ module Torch
390
394
  # loss functions
391
395
 
392
396
  def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
393
- NN.binary_cross_entropy(input, target, weight, reduction)
397
+ NN.binary_cross_entropy(input, target, weight, to_reduction(reduction))
394
398
  end
395
399
 
396
400
  def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
397
- Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
401
+ Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, to_reduction(reduction))
398
402
  end
399
403
 
400
404
  def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
401
- Torch.cosine_embedding_loss(input1, input2, target, margin, reduction)
405
+ Torch.cosine_embedding_loss(input1, input2, target, margin, to_reduction(reduction))
402
406
  end
403
407
 
404
408
  def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
@@ -407,34 +411,34 @@ module Torch
407
411
 
408
412
  def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
409
413
  # call to_a on input_lengths and target_lengths for C++
410
- Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
414
+ Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, to_reduction(reduction), zero_infinity)
411
415
  end
412
416
 
413
417
  def hinge_embedding_loss(input, target, margin: 1.0, reduction: "mean")
414
- Torch.hinge_embedding_loss(input, target, margin, reduction)
418
+ Torch.hinge_embedding_loss(input, target, margin, to_reduction(reduction))
415
419
  end
416
420
 
417
421
  def kl_div(input, target, reduction: "mean")
418
- Torch.kl_div(input, target, reduction)
422
+ Torch.kl_div(input, target, to_reduction(reduction))
419
423
  end
420
424
 
421
425
  def l1_loss(input, target, reduction: "mean")
422
- NN.l1_loss(input, target, reduction)
426
+ NN.l1_loss(input, target, to_reduction(reduction))
423
427
  end
424
428
 
425
429
  def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
426
- Torch.margin_ranking_loss(input1, input2, target, margin, reduction)
430
+ Torch.margin_ranking_loss(input1, input2, target, margin, to_reduction(reduction))
427
431
  end
428
432
 
429
433
  def mse_loss(input, target, reduction: "mean")
430
434
  if target.size != input.size
431
435
  warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
432
436
  end
433
- NN.mse_loss(input, target, reduction)
437
+ NN.mse_loss(input, target, to_reduction(reduction))
434
438
  end
435
439
 
436
440
  def multilabel_margin_loss(input, target, reduction: "mean")
437
- NN.multilabel_margin_loss(input, target, reduction)
441
+ NN.multilabel_margin_loss(input, target, to_reduction(reduction))
438
442
  end
439
443
 
440
444
  def multilabel_soft_margin_loss(input, target, weight: nil)
@@ -442,31 +446,116 @@ module Torch
442
446
  end
443
447
 
444
448
  def multi_margin_loss(input, target, p: 1, margin: 1.0, weight: nil, reduction: "mean")
445
- NN.multi_margin_loss(input, target, p, margin, weight, reduction)
449
+ NN.multi_margin_loss(input, target, p, margin, weight, to_reduction(reduction))
446
450
  end
447
451
 
448
452
  def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
449
- NN.nll_loss(input, target, weight, reduction, ignore_index)
453
+ NN.nll_loss(input, target, weight, to_reduction(reduction), ignore_index)
450
454
  end
451
455
 
452
456
  def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
453
- Torch.poisson_nll_loss(input, target, log_input, full, eps, reduction)
457
+ Torch.poisson_nll_loss(input, target, log_input, full, eps, to_reduction(reduction))
454
458
  end
455
459
 
456
460
  def soft_margin_loss(input, target, reduction: "mean")
457
- NN.soft_margin_loss(input, target, reduction)
461
+ NN.soft_margin_loss(input, target, to_reduction(reduction))
458
462
  end
459
463
 
460
464
  def smooth_l1_loss(input, target, reduction: "mean")
461
- NN.smooth_l1_loss(input, target, reduction)
465
+ NN.smooth_l1_loss(input, target, to_reduction(reduction))
462
466
  end
463
467
 
464
468
  def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean")
465
- Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
469
+ Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, to_reduction(reduction))
470
+ end
471
+
472
+ # vision
473
+ # Resizes input to size, or by scale_factor, using mode; dispatches on input.dim (3/4/5) and raises ArgumentError for unsupported combinations.
474
+ def interpolate(input, size: nil, scale_factor: nil, mode: "nearest", align_corners: nil, recompute_scale_factor: nil)
475
+ if ["nearest", "area"].include?(mode)
476
+ unless align_corners.nil?
477
+ raise ArgumentError, "align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear"
478
+ end
479
+ else
480
+ if align_corners.nil?
481
+ align_corners = false
482
+ end
483
+ end
484
+ # Per-dimension scale factors passed on to the NN.upsample_* calls; they stay nil unless scale_factor is given and recompute_scale_factor is false or nil.
485
+ scale_factor_len = input.dim - 2
486
+ scale_factor_list = [nil] * scale_factor_len
487
+ # a nil recompute_scale_factor is treated the same as false
488
+ if !scale_factor.nil? && (recompute_scale_factor == false || recompute_scale_factor.nil?)
489
+ if scale_factor.is_a?(Array)
490
+ _scale_factor_repeated = scale_factor
491
+ else
492
+ _scale_factor_repeated = [scale_factor] * scale_factor_len
493
+ end
494
+ scale_factor_list = _scale_factor_repeated
495
+ end
496
+ # sfl is indexed per spatial dimension (input.dim - 2 of them) in the dispatch below.
497
+ # Give this variable a short name because it has to be repeated multiple times below.
498
+ sfl = scale_factor_list
499
+ # _interp_output_size validates size/scale_factor and returns the target size per spatial dimension.
500
+ closed_over_args = [input, size, scale_factor, recompute_scale_factor]
501
+ output_size = _interp_output_size(closed_over_args)
502
+ if input.dim == 3 && mode == "nearest"
503
+ NN.upsample_nearest1d(input, output_size, sfl[0])
504
+ elsif input.dim == 4 && mode == "nearest"
505
+ NN.upsample_nearest2d(input, output_size, sfl[0], sfl[1])
506
+ elsif input.dim == 5 && mode == "nearest"
507
+ NN.upsample_nearest3d(input, output_size, sfl[0], sfl[1], sfl[2])
508
+ elsif input.dim == 3 && mode == "area"
509
+ adaptive_avg_pool1d(input, output_size)
510
+ elsif input.dim == 4 && mode == "area"
511
+ adaptive_avg_pool2d(input, output_size)
512
+ elsif input.dim == 5 && mode == "area"
513
+ adaptive_avg_pool3d(input, output_size)
514
+ elsif input.dim == 3 && mode == "linear"
515
+ # align_corners is never nil here: it was defaulted to false above for this mode
516
+ NN.upsample_linear1d(input, output_size, align_corners, sfl[0])
517
+ elsif input.dim == 3 && mode == "bilinear"
518
+ raise ArgumentError, "Got 3D input, but bilinear mode needs 4D input"
519
+ elsif input.dim == 3 && mode == "trilinear"
520
+ raise ArgumentError, "Got 3D input, but trilinear mode needs 5D input"
521
+ elsif input.dim == 4 && mode == "linear"
522
+ raise ArgumentError, "Got 4D input, but linear mode needs 3D input"
523
+ elsif input.dim == 4 && mode == "bilinear"
524
+ # align_corners is never nil here: it was defaulted to false above for this mode
525
+ NN.upsample_bilinear2d(input, output_size, align_corners, sfl[0], sfl[1])
526
+ elsif input.dim == 4 && mode == "trilinear"
527
+ raise ArgumentError, "Got 4D input, but trilinear mode needs 5D input"
528
+ elsif input.dim == 5 && mode == "linear"
529
+ raise ArgumentError, "Got 5D input, but linear mode needs 3D input"
530
+ elsif input.dim == 5 && mode == "bilinear"
531
+ raise ArgumentError, "Got 5D input, but bilinear mode needs 4D input"
532
+ elsif input.dim == 5 && mode == "trilinear"
533
+ # align_corners is never nil here: it was defaulted to false above for this mode
534
+ NN.upsample_trilinear3d(input, output_size, align_corners, sfl[0], sfl[1], sfl[2])
535
+ elsif input.dim == 4 && mode == "bicubic"
536
+ # align_corners is never nil here: it was defaulted to false above for this mode
537
+ NN.upsample_bicubic2d(input, output_size, align_corners, sfl[0], sfl[1])
538
+ else
539
+ raise ArgumentError, "Input Error: Only 3D, 4D and 5D input Tensors supported (got #{input.dim}D) for the modes: nearest | linear | bilinear | bicubic | trilinear (got #{mode})"
540
+ end
466
541
  end
467
542
 
468
543
  private
469
544
 
545
+ # Maps a reduction name to its integer code: "none" => 0, "mean" => 1, "sum" => 2 (see _reduction.py - presumably PyTorch's; confirm)
546
+ def to_reduction(v)
547
+ case v.to_s # compared via to_s, so :mean and "mean" are treated alike
548
+ when "none"
549
+ 0
550
+ when "mean"
551
+ 1
552
+ when "sum"
553
+ 2
554
+ else
555
+ raise ArgumentError, "#{v} is not a valid value for reduction"
556
+ end
557
+ end
558
+
470
559
  def softmax_dim(ndim)
471
560
  ndim == 0 || ndim == 1 || ndim == 3 ? 0 : 1
472
561
  end
@@ -480,6 +569,41 @@ module Torch
480
569
  out_size.zip(defaults.last(out_size.length)).map { |v, d| v || d }
481
570
  end
482
571
  end
572
+ # Returns the output size for the spatial dimensions of input, taken from size or computed from scale_factor; raises ArgumentError unless exactly one of the two is given.
573
+ def _interp_output_size(closed_over_args)
574
+ input, size, scale_factor, recompute_scale_factor = closed_over_args # recompute_scale_factor is unpacked but not used in this method
575
+ dim = input.dim - 2
576
+ if size.nil? && scale_factor.nil?
577
+ raise ArgumentError, "either size or scale_factor should be defined"
578
+ end
579
+ if !size.nil? && !scale_factor.nil?
580
+ raise ArgumentError, "only one of size or scale_factor should be defined"
581
+ end
582
+ if !scale_factor.nil?
583
+ if scale_factor.is_a?(Array)
584
+ if scale_factor.length != dim
585
+ raise ArgumentError, "scale_factor shape must match input shape. Input is #{dim}D, scale_factor size is #{scale_factor.length}"
586
+ end
587
+ end
588
+ end
589
+ # An explicit size wins: an Array is returned as-is (its length is not checked here), any other value is repeated once per spatial dimension.
590
+ if !size.nil?
591
+ if size.is_a?(Array)
592
+ return size
593
+ else
594
+ return [size] * dim
595
+ end
596
+ end
597
+ # Only scale_factor remains; the guard below cannot fire because the both-nil case already raised above.
598
+ raise "Failed assertion" if scale_factor.nil?
599
+ if scale_factor.is_a?(Array)
600
+ scale_factors = scale_factor
601
+ else
602
+ scale_factors = [scale_factor] * dim
603
+ end
604
+ # Scale each extent from dimension 2 upward by its factor and round down.
605
+ dim.times.map { |i| (input.size(i + 2) * scale_factors[i]).floor }
606
+ end
483
607
  end
484
608
  end
485
609