torch-rb 0.3.7 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_tensor_functions(Module m);
@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_torch_functions(Module m);
@@ -0,0 +1,42 @@
+ #pragma once
+
+ #include <rice/Symbol.hpp>
+
+ // keep THP prefix for now to make it easier to compare code
+
+ extern VALUE THPVariableClass;
+
+ inline VALUE THPUtils_internSymbol(const std::string& str) {
+   return Symbol(str);
+ }
+
+ inline std::string THPUtils_unpackSymbol(VALUE obj) {
+   Check_Type(obj, T_SYMBOL);
+   obj = rb_funcall(obj, rb_intern("to_s"), 0);
+   return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
+ }
+
+ inline std::string THPUtils_unpackString(VALUE obj) {
+   Check_Type(obj, T_STRING);
+   return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
+ }
+
+ inline bool THPUtils_checkSymbol(VALUE obj) {
+   return SYMBOL_P(obj);
+ }
+
+ inline bool THPUtils_checkIndex(VALUE obj) {
+   return FIXNUM_P(obj);
+ }
+
+ inline bool THPUtils_checkScalar(VALUE obj) {
+   return FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj) || RB_TYPE_P(obj, T_COMPLEX);
+ }
+
+ inline bool THPVariable_Check(VALUE obj) {
+   return rb_obj_is_kind_of(obj, THPVariableClass);
+ }
+
+ inline bool THPVariable_CheckExact(VALUE obj) {
+   return rb_obj_is_instance_of(obj, THPVariableClass);
+ }
@@ -1,43 +1,44 @@
+ #pragma once
+
  #include <torch/torch.h>
  #include <rice/Object.hpp>
- #include "templates.hpp"

- Object wrap(bool x) {
+ inline Object wrap(bool x) {
    return to_ruby<bool>(x);
  }

- Object wrap(int64_t x) {
+ inline Object wrap(int64_t x) {
    return to_ruby<int64_t>(x);
  }

- Object wrap(double x) {
+ inline Object wrap(double x) {
    return to_ruby<double>(x);
  }

- Object wrap(torch::Tensor x) {
+ inline Object wrap(torch::Tensor x) {
    return to_ruby<torch::Tensor>(x);
  }

- Object wrap(torch::Scalar x) {
+ inline Object wrap(torch::Scalar x) {
    return to_ruby<torch::Scalar>(x);
  }

- Object wrap(torch::ScalarType x) {
+ inline Object wrap(torch::ScalarType x) {
    return to_ruby<torch::ScalarType>(x);
  }

- Object wrap(torch::QScheme x) {
+ inline Object wrap(torch::QScheme x) {
    return to_ruby<torch::QScheme>(x);
  }

- Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
    Array a;
    a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
    a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
    return Object(a);
  }

- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
    Array a;
    a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
    a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -45,7 +46,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
    return Object(a);
  }

- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
    Array a;
    a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
    a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -54,7 +55,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tenso
    return Object(a);
  }

- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
    Array a;
    a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
    a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -64,7 +65,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tenso
    return Object(a);
  }

- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
    Array a;
    a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
    a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -73,7 +74,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x)
    return Object(a);
  }

- Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
    Array a;
    a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
    a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -82,10 +83,17 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
    return Object(a);
  }

- Object wrap(std::vector<torch::Tensor> x) {
+ inline Object wrap(torch::TensorList x) {
    Array a;
    for (auto& t : x) {
      a.push(to_ruby<torch::Tensor>(t));
    }
    return Object(a);
  }
+
+ inline Object wrap(std::tuple<double, double> x) {
+   Array a;
+   a.push(to_ruby<double>(std::get<0>(x)));
+   a.push(to_ruby<double>(std::get<1>(x)));
+   return Object(a);
+ }
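
Note: these wrap overloads are the glue that converts native return values into Ruby objects, so a native call that returns a tuple of tensors surfaces in Ruby as a plain array. A minimal Ruby-side sketch (assuming the generated Torch.sort binding; values are illustrative):

    x = Torch.tensor([3.0, 1.0, 2.0])
    values, indices = Torch.sort(x)  # a tuple<Tensor, Tensor> arrives as a two-element array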
@@ -7,11 +7,6 @@ require "net/http"
  require "set"
  require "tmpdir"

- # native functions
- require "torch/native/generator"
- require "torch/native/parser"
- require "torch/native/dispatcher"
-
  # modules
  require "torch/inspector"
  require "torch/tensor"
@@ -266,6 +261,8 @@ module Torch
    Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
  elsif args.size == 1 && args.first.is_a?(Array)
    Torch.tensor(args.first, dtype: dtype, device: device)
+ elsif args.size == 0
+   Torch.empty(0, dtype: dtype, device: device)
  else
    Torch.empty(*args, dtype: dtype, device: device)
  end
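
Note: the new args.size == 0 branch means constructing a tensor class with no arguments now returns an empty tensor rather than falling through to Torch.empty with no size. A hedged sketch, assuming this constructor backs the dtype-specific classes such as Torch::FloatTensor:

    t = Torch::FloatTensor.new
    t.shape  # => [0]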
@@ -374,63 +371,6 @@ module Torch

  # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---

- def arange(start, finish = nil, step = 1, **options)
-   # ruby doesn't support start = 0, finish, step = 1, ...
-   if finish.nil?
-     finish = start
-     start = 0
-   end
-   _arange(start, finish, step, tensor_options(**options))
- end
-
- def empty(*size, **options)
-   _empty(tensor_size(size), tensor_options(**options))
- end
-
- def eye(n, m = nil, **options)
-   _eye(n, m || n, tensor_options(**options))
- end
-
- def full(size, fill_value, **options)
-   _full(size, fill_value, tensor_options(**options))
- end
-
- def linspace(start, finish, steps = 100, **options)
-   _linspace(start, finish, steps, tensor_options(**options))
- end
-
- def logspace(start, finish, steps = 100, base = 10.0, **options)
-   _logspace(start, finish, steps, base, tensor_options(**options))
- end
-
- def ones(*size, **options)
-   _ones(tensor_size(size), tensor_options(**options))
- end
-
- def rand(*size, **options)
-   _rand(tensor_size(size), tensor_options(**options))
- end
-
- def randint(low = 0, high, size, **options)
-   _randint(low, high, size, tensor_options(**options))
- end
-
- def randn(*size, **options)
-   _randn(tensor_size(size), tensor_options(**options))
- end
-
- def randperm(n, **options)
-   # dtype hack in Python
-   # https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
-   options[:dtype] ||= :int64
-
-   _randperm(n, tensor_options(**options))
- end
-
- def zeros(*size, **options)
-   _zeros(tensor_size(size), tensor_options(**options))
- end
-
  def tensor(data, **options)
    if options[:dtype].nil? && defined?(Numo::NArray) && data.is_a?(Numo::NArray)
      numo_to_dtype = _dtype_to_numo.map(&:reverse).to_h
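
Note: the hand-written creation helpers (arange, empty, eye, full, linspace, logspace, ones, rand, randint, randn, randperm, zeros) are removed here; in 0.5.x they appear to be supplied by the generated native functions instead, so calls like these are expected to keep working unchanged (a hedged sketch):

    Torch.arange(10)                    # 0..9
    Torch.ones(2, 3)
    Torch.randn(2, 3, dtype: :float64)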
@@ -496,7 +436,8 @@ module Torch
    zeros(input.size, **like_options(input, options))
  end

- def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
+ # center option
+ def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true, return_complex: nil)
    if center
      signal_dim = input.dim
      extended_shape = [1] * (3 - signal_dim) + input.size
@@ -504,12 +445,7 @@ module Torch
      input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
      input = input.view(input.shape[-signal_dim..-1])
    end
-   _stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
- end
-
- def clamp(tensor, min, max)
-   tensor = _clamp_min(tensor, min)
-   _clamp_max(tensor, max)
+   _stft(input, n_fft, hop_length, win_length, window, normalized, onesided, return_complex)
  end

  private
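
Note: stft now simply forwards the new return_complex argument to the native call, and the hand-written clamp wrapper is dropped (presumably superseded by a generated Torch.clamp). A hedged usage sketch, values illustrative:

    signal = Torch.randn(1024)
    spec = Torch.stft(signal, 256, hop_length: 128, return_complex: false)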
@@ -394,15 +394,15 @@ module Torch
  # loss functions

  def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
-   NN.binary_cross_entropy(input, target, weight, reduction)
+   NN.binary_cross_entropy(input, target, weight, to_reduction(reduction))
  end

  def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
-   Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
+   Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, to_reduction(reduction))
  end

  def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
-   Torch.cosine_embedding_loss(input1, input2, target, margin, reduction)
+   Torch.cosine_embedding_loss(input1, input2, target, margin, to_reduction(reduction))
  end

  def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
@@ -411,34 +411,34 @@ module Torch

  def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
    # call to_a on input_lengths and target_lengths for C++
-   Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
+   Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, to_reduction(reduction), zero_infinity)
  end

  def hinge_embedding_loss(input, target, margin: 1.0, reduction: "mean")
-   Torch.hinge_embedding_loss(input, target, margin, reduction)
+   Torch.hinge_embedding_loss(input, target, margin, to_reduction(reduction))
  end

  def kl_div(input, target, reduction: "mean")
-   Torch.kl_div(input, target, reduction)
+   Torch.kl_div(input, target, to_reduction(reduction))
  end

  def l1_loss(input, target, reduction: "mean")
-   NN.l1_loss(input, target, reduction)
+   NN.l1_loss(input, target, to_reduction(reduction))
  end

  def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
-   Torch.margin_ranking_loss(input1, input2, target, margin, reduction)
+   Torch.margin_ranking_loss(input1, input2, target, margin, to_reduction(reduction))
  end

  def mse_loss(input, target, reduction: "mean")
    if target.size != input.size
      warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
    end
-   NN.mse_loss(input, target, reduction)
+   NN.mse_loss(input, target, to_reduction(reduction))
  end

  def multilabel_margin_loss(input, target, reduction: "mean")
-   NN.multilabel_margin_loss(input, target, reduction)
+   NN.multilabel_margin_loss(input, target, to_reduction(reduction))
  end

  def multilabel_soft_margin_loss(input, target, weight: nil)
@@ -446,27 +446,27 @@ module Torch
  end

  def multi_margin_loss(input, target, p: 1, margin: 1.0, weight: nil, reduction: "mean")
-   NN.multi_margin_loss(input, target, p, margin, weight, reduction)
+   NN.multi_margin_loss(input, target, p, margin, weight, to_reduction(reduction))
  end

  def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
-   NN.nll_loss(input, target, weight, reduction, ignore_index)
+   NN.nll_loss(input, target, weight, to_reduction(reduction), ignore_index)
  end

  def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
-   Torch.poisson_nll_loss(input, target, log_input, full, eps, reduction)
+   Torch.poisson_nll_loss(input, target, log_input, full, eps, to_reduction(reduction))
  end

  def soft_margin_loss(input, target, reduction: "mean")
-   NN.soft_margin_loss(input, target, reduction)
+   NN.soft_margin_loss(input, target, to_reduction(reduction))
  end

  def smooth_l1_loss(input, target, reduction: "mean")
-   NN.smooth_l1_loss(input, target, reduction)
+   NN.smooth_l1_loss(input, target, to_reduction(reduction))
  end

  def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean")
-   Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
+   Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, to_reduction(reduction))
  end

  # vision
@@ -542,6 +542,20 @@ module Torch

  private

+ # see _reduction.py
+ def to_reduction(v)
+   case v.to_s
+   when "none"
+     0
+   when "mean"
+     1
+   when "sum"
+     2
+   else
+     raise ArgumentError, "#{v} is not a valid value for reduction"
+   end
+ end
+
  def softmax_dim(ndim)
    ndim == 0 || ndim == 1 || ndim == 3 ? 0 : 1
  end
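
Note: to_reduction mirrors PyTorch's legacy reduction enum (none => 0, mean => 1, sum => 2), so the Ruby API keeps accepting reduction strings while the native loss kernels now receive the integer they expect. A small usage sketch, values illustrative:

    input  = Torch.randn(3, 5)
    target = Torch.randn(3, 5)
    Torch::NN::F.mse_loss(input, target, reduction: "sum")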
@@ -14,25 +14,11 @@ module Torch
    _normal!(tensor, mean, std)
  end

- def constant!(tensor, val)
-   _constant!(tensor, val)
- end
-
- def ones!(tensor)
-   _ones!(tensor)
- end
-
- def zeros!(tensor)
-   _zeros!(tensor)
- end
-
- def eye!(tensor)
-   _eye!(tensor)
- end
-
- def dirac!(tensor)
-   _dirac!(tensor)
- end
+ alias_method :constant!, :_constant!
+ alias_method :ones!, :_ones!
+ alias_method :zeros!, :_zeros!
+ alias_method :eye!, :_eye!
+ alias_method :dirac!, :_dirac!

  def xavier_uniform!(tensor, gain: 1.0)
    _xavier_uniform!(tensor, gain)
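
Note: the public initializer names are now plain aliases of the native _-prefixed bindings, so existing calls should behave exactly as before (a hedged sketch, assuming the usual Torch::NN::Init namespace):

    w = Torch.empty(3, 5)
    Torch::NN::Init.zeros!(w)
    Torch::NN::Init.constant!(w, 0.5)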
@@ -39,14 +39,14 @@ module Torch
      state[:step] += 1

      if group[:weight_decay] != 0
-       grad = grad.add(group[:weight_decay], p.data)
+       grad = grad.add(p.data, alpha: group[:weight_decay])
      end

-     square_avg.mul!(rho).addcmul!(1 - rho, grad, grad)
+     square_avg.mul!(rho).addcmul!(grad, grad, value: 1 - rho)
      std = square_avg.add(eps).sqrt!
      delta = acc_delta.add(eps).sqrt!.div!(std).mul!(grad)
-     p.data.add!(-group[:lr], delta)
-     acc_delta.mul!(rho).addcmul!(1 - rho, delta, delta)
+     p.data.add!(delta, alpha: -group[:lr])
+     acc_delta.mul!(rho).addcmul!(delta, delta, value: 1 - rho)
    end
  end

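Note: these optimizer edits track the newer in-place op signatures, where the scalar multiplier moves from a leading positional argument to an alpha:/value: keyword. A before/after sketch on plain tensors:

    a = Torch.ones(3)
    b = Torch.ones(3)
    c = Torch.ones(3)
    # 0.3.x style: a.add!(2, b); a.addcmul!(0.5, b, c)
    a.add!(b, alpha: 2)           # a += 2 * b
    a.addcmul!(b, c, value: 0.5)  # a += 0.5 * (b * c)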
@@ -49,7 +49,7 @@ module Torch
      if p.grad.data.sparse?
        raise Error, "weight_decay option is not compatible with sparse gradients"
      end
-     grad = grad.add(group[:weight_decay], p.data)
+     grad = grad.add(p.data, alpha: group[:weight_decay])
    end

    clr = group[:lr] / (1 + (state[:step] - 1) * group[:lr_decay])
@@ -57,9 +57,9 @@ module Torch
    if grad.sparse?
      raise NotImplementedYet
    else
-     state[:sum].addcmul!(1, grad, grad)
+     state[:sum].addcmul!(grad, grad, value: 1)
      std = state[:sum].sqrt.add!(group[:eps])
-     p.data.addcdiv!(-clr, grad, std)
+     p.data.addcdiv!(grad, std, value: -clr)
    end
  end
  end
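
Note: addcdiv! follows the same keyword convention as add! and addcmul! above; a short sketch, values illustrative:

    x = Torch.ones(3)
    num = Torch.ones(3)
    den = Torch.ones(3).mul!(2)
    x.addcdiv!(num, den, value: -0.1)  # x += -0.1 * (num / den)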