torch-rb 0.3.7 → 0.5.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_tensor_functions(Module m);
@@ -0,0 +1,6 @@
1
+ // generated by rake generate:functions
2
+ // do not edit by hand
3
+
4
+ #pragma once
5
+
6
+ void add_torch_functions(Module m);
@@ -0,0 +1,42 @@
1
+ #pragma once
2
+
3
+ #include <rice/Symbol.hpp>
4
+
5
+ // keep THP prefix for now to make it easier to compare code
6
+
7
+ extern VALUE THPVariableClass;
8
+
9
+ inline VALUE THPUtils_internSymbol(const std::string& str) {
10
+ return Symbol(str);
11
+ }
12
+
13
+ inline std::string THPUtils_unpackSymbol(VALUE obj) {
14
+ Check_Type(obj, T_SYMBOL);
15
+ obj = rb_funcall(obj, rb_intern("to_s"), 0);
16
+ return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
17
+ }
18
+
19
+ inline std::string THPUtils_unpackString(VALUE obj) {
20
+ Check_Type(obj, T_STRING);
21
+ return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
22
+ }
23
+
24
+ inline bool THPUtils_checkSymbol(VALUE obj) {
25
+ return SYMBOL_P(obj);
26
+ }
27
+
28
+ inline bool THPUtils_checkIndex(VALUE obj) {
29
+ return FIXNUM_P(obj);
30
+ }
31
+
32
+ inline bool THPUtils_checkScalar(VALUE obj) {
33
+ return FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj) || RB_TYPE_P(obj, T_COMPLEX);
34
+ }
35
+
36
+ inline bool THPVariable_Check(VALUE obj) {
37
+ return rb_obj_is_kind_of(obj, THPVariableClass);
38
+ }
39
+
40
+ inline bool THPVariable_CheckExact(VALUE obj) {
41
+ return rb_obj_is_instance_of(obj, THPVariableClass);
42
+ }
@@ -1,43 +1,44 @@
1
+ #pragma once
2
+
1
3
  #include <torch/torch.h>
2
4
  #include <rice/Object.hpp>
3
- #include "templates.hpp"
4
5
 
5
- Object wrap(bool x) {
6
+ inline Object wrap(bool x) {
6
7
  return to_ruby<bool>(x);
7
8
  }
8
9
 
9
- Object wrap(int64_t x) {
10
+ inline Object wrap(int64_t x) {
10
11
  return to_ruby<int64_t>(x);
11
12
  }
12
13
 
13
- Object wrap(double x) {
14
+ inline Object wrap(double x) {
14
15
  return to_ruby<double>(x);
15
16
  }
16
17
 
17
- Object wrap(torch::Tensor x) {
18
+ inline Object wrap(torch::Tensor x) {
18
19
  return to_ruby<torch::Tensor>(x);
19
20
  }
20
21
 
21
- Object wrap(torch::Scalar x) {
22
+ inline Object wrap(torch::Scalar x) {
22
23
  return to_ruby<torch::Scalar>(x);
23
24
  }
24
25
 
25
- Object wrap(torch::ScalarType x) {
26
+ inline Object wrap(torch::ScalarType x) {
26
27
  return to_ruby<torch::ScalarType>(x);
27
28
  }
28
29
 
29
- Object wrap(torch::QScheme x) {
30
+ inline Object wrap(torch::QScheme x) {
30
31
  return to_ruby<torch::QScheme>(x);
31
32
  }
32
33
 
33
- Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
34
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
34
35
  Array a;
35
36
  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
36
37
  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
37
38
  return Object(a);
38
39
  }
39
40
 
40
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
41
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
41
42
  Array a;
42
43
  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
43
44
  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -45,7 +46,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
45
46
  return Object(a);
46
47
  }
47
48
 
48
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
49
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
49
50
  Array a;
50
51
  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
51
52
  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -54,7 +55,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tenso
54
55
  return Object(a);
55
56
  }
56
57
 
57
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
58
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
58
59
  Array a;
59
60
  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
60
61
  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -64,7 +65,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tenso
64
65
  return Object(a);
65
66
  }
66
67
 
67
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
68
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
68
69
  Array a;
69
70
  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
70
71
  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -73,7 +74,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x)
73
74
  return Object(a);
74
75
  }
75
76
 
76
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
77
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
77
78
  Array a;
78
79
  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
79
80
  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -82,10 +83,17 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
82
83
  return Object(a);
83
84
  }
84
85
 
85
- Object wrap(std::vector<torch::Tensor> x) {
86
+ inline Object wrap(torch::TensorList x) {
86
87
  Array a;
87
88
  for (auto& t : x) {
88
89
  a.push(to_ruby<torch::Tensor>(t));
89
90
  }
90
91
  return Object(a);
91
92
  }
93
+
94
+ inline Object wrap(std::tuple<double, double> x) {
95
+ Array a;
96
+ a.push(to_ruby<double>(std::get<0>(x)));
97
+ a.push(to_ruby<double>(std::get<1>(x)));
98
+ return Object(a);
99
+ }
@@ -7,11 +7,6 @@ require "net/http"
7
7
  require "set"
8
8
  require "tmpdir"
9
9
 
10
- # native functions
11
- require "torch/native/generator"
12
- require "torch/native/parser"
13
- require "torch/native/dispatcher"
14
-
15
10
  # modules
16
11
  require "torch/inspector"
17
12
  require "torch/tensor"
@@ -266,6 +261,8 @@ module Torch
266
261
  Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
267
262
  elsif args.size == 1 && args.first.is_a?(Array)
268
263
  Torch.tensor(args.first, dtype: dtype, device: device)
264
+ elsif args.size == 0
265
+ Torch.empty(0, dtype: dtype, device: device)
269
266
  else
270
267
  Torch.empty(*args, dtype: dtype, device: device)
271
268
  end
@@ -374,63 +371,6 @@ module Torch
374
371
 
375
372
  # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---
376
373
 
377
- def arange(start, finish = nil, step = 1, **options)
378
- # ruby doesn't support start = 0, finish, step = 1, ...
379
- if finish.nil?
380
- finish = start
381
- start = 0
382
- end
383
- _arange(start, finish, step, tensor_options(**options))
384
- end
385
-
386
- def empty(*size, **options)
387
- _empty(tensor_size(size), tensor_options(**options))
388
- end
389
-
390
- def eye(n, m = nil, **options)
391
- _eye(n, m || n, tensor_options(**options))
392
- end
393
-
394
- def full(size, fill_value, **options)
395
- _full(size, fill_value, tensor_options(**options))
396
- end
397
-
398
- def linspace(start, finish, steps = 100, **options)
399
- _linspace(start, finish, steps, tensor_options(**options))
400
- end
401
-
402
- def logspace(start, finish, steps = 100, base = 10.0, **options)
403
- _logspace(start, finish, steps, base, tensor_options(**options))
404
- end
405
-
406
- def ones(*size, **options)
407
- _ones(tensor_size(size), tensor_options(**options))
408
- end
409
-
410
- def rand(*size, **options)
411
- _rand(tensor_size(size), tensor_options(**options))
412
- end
413
-
414
- def randint(low = 0, high, size, **options)
415
- _randint(low, high, size, tensor_options(**options))
416
- end
417
-
418
- def randn(*size, **options)
419
- _randn(tensor_size(size), tensor_options(**options))
420
- end
421
-
422
- def randperm(n, **options)
423
- # dtype hack in Python
424
- # https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
425
- options[:dtype] ||= :int64
426
-
427
- _randperm(n, tensor_options(**options))
428
- end
429
-
430
- def zeros(*size, **options)
431
- _zeros(tensor_size(size), tensor_options(**options))
432
- end
433
-
434
374
  def tensor(data, **options)
435
375
  if options[:dtype].nil? && defined?(Numo::NArray) && data.is_a?(Numo::NArray)
436
376
  numo_to_dtype = _dtype_to_numo.map(&:reverse).to_h
@@ -496,7 +436,8 @@ module Torch
496
436
  zeros(input.size, **like_options(input, options))
497
437
  end
498
438
 
499
- def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
439
+ # center option
440
+ def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true, return_complex: nil)
500
441
  if center
501
442
  signal_dim = input.dim
502
443
  extended_shape = [1] * (3 - signal_dim) + input.size
@@ -504,12 +445,7 @@ module Torch
504
445
  input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
505
446
  input = input.view(input.shape[-signal_dim..-1])
506
447
  end
507
- _stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
508
- end
509
-
510
- def clamp(tensor, min, max)
511
- tensor = _clamp_min(tensor, min)
512
- _clamp_max(tensor, max)
448
+ _stft(input, n_fft, hop_length, win_length, window, normalized, onesided, return_complex)
513
449
  end
514
450
 
515
451
  private
@@ -394,15 +394,15 @@ module Torch
394
394
  # loss functions
395
395
 
396
396
  def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
397
- NN.binary_cross_entropy(input, target, weight, reduction)
397
+ NN.binary_cross_entropy(input, target, weight, to_reduction(reduction))
398
398
  end
399
399
 
400
400
  def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
401
- Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
401
+ Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, to_reduction(reduction))
402
402
  end
403
403
 
404
404
  def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
405
- Torch.cosine_embedding_loss(input1, input2, target, margin, reduction)
405
+ Torch.cosine_embedding_loss(input1, input2, target, margin, to_reduction(reduction))
406
406
  end
407
407
 
408
408
  def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
@@ -411,34 +411,34 @@ module Torch
411
411
 
412
412
  def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
413
413
  # call to_a on input_lengths and target_lengths for C++
414
- Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
414
+ Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, to_reduction(reduction), zero_infinity)
415
415
  end
416
416
 
417
417
  def hinge_embedding_loss(input, target, margin: 1.0, reduction: "mean")
418
- Torch.hinge_embedding_loss(input, target, margin, reduction)
418
+ Torch.hinge_embedding_loss(input, target, margin, to_reduction(reduction))
419
419
  end
420
420
 
421
421
  def kl_div(input, target, reduction: "mean")
422
- Torch.kl_div(input, target, reduction)
422
+ Torch.kl_div(input, target, to_reduction(reduction))
423
423
  end
424
424
 
425
425
  def l1_loss(input, target, reduction: "mean")
426
- NN.l1_loss(input, target, reduction)
426
+ NN.l1_loss(input, target, to_reduction(reduction))
427
427
  end
428
428
 
429
429
  def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
430
- Torch.margin_ranking_loss(input1, input2, target, margin, reduction)
430
+ Torch.margin_ranking_loss(input1, input2, target, margin, to_reduction(reduction))
431
431
  end
432
432
 
433
433
  def mse_loss(input, target, reduction: "mean")
434
434
  if target.size != input.size
435
435
  warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
436
436
  end
437
- NN.mse_loss(input, target, reduction)
437
+ NN.mse_loss(input, target, to_reduction(reduction))
438
438
  end
439
439
 
440
440
  def multilabel_margin_loss(input, target, reduction: "mean")
441
- NN.multilabel_margin_loss(input, target, reduction)
441
+ NN.multilabel_margin_loss(input, target, to_reduction(reduction))
442
442
  end
443
443
 
444
444
  def multilabel_soft_margin_loss(input, target, weight: nil)
@@ -446,27 +446,27 @@ module Torch
446
446
  end
447
447
 
448
448
  def multi_margin_loss(input, target, p: 1, margin: 1.0, weight: nil, reduction: "mean")
449
- NN.multi_margin_loss(input, target, p, margin, weight, reduction)
449
+ NN.multi_margin_loss(input, target, p, margin, weight, to_reduction(reduction))
450
450
  end
451
451
 
452
452
  def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
453
- NN.nll_loss(input, target, weight, reduction, ignore_index)
453
+ NN.nll_loss(input, target, weight, to_reduction(reduction), ignore_index)
454
454
  end
455
455
 
456
456
  def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
457
- Torch.poisson_nll_loss(input, target, log_input, full, eps, reduction)
457
+ Torch.poisson_nll_loss(input, target, log_input, full, eps, to_reduction(reduction))
458
458
  end
459
459
 
460
460
  def soft_margin_loss(input, target, reduction: "mean")
461
- NN.soft_margin_loss(input, target, reduction)
461
+ NN.soft_margin_loss(input, target, to_reduction(reduction))
462
462
  end
463
463
 
464
464
  def smooth_l1_loss(input, target, reduction: "mean")
465
- NN.smooth_l1_loss(input, target, reduction)
465
+ NN.smooth_l1_loss(input, target, to_reduction(reduction))
466
466
  end
467
467
 
468
468
  def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean")
469
- Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
469
+ Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, to_reduction(reduction))
470
470
  end
471
471
 
472
472
  # vision
@@ -542,6 +542,20 @@ module Torch
542
542
 
543
543
  private
544
544
 
545
+ # see _reduction.py
546
+ def to_reduction(v)
547
+ case v.to_s
548
+ when "none"
549
+ 0
550
+ when "mean"
551
+ 1
552
+ when "sum"
553
+ 2
554
+ else
555
+ raise ArgumentError, "#{v} is not a valid value for reduction"
556
+ end
557
+ end
558
+
545
559
  def softmax_dim(ndim)
546
560
  ndim == 0 || ndim == 1 || ndim == 3 ? 0 : 1
547
561
  end
@@ -14,25 +14,11 @@ module Torch
14
14
  _normal!(tensor, mean, std)
15
15
  end
16
16
 
17
- def constant!(tensor, val)
18
- _constant!(tensor, val)
19
- end
20
-
21
- def ones!(tensor)
22
- _ones!(tensor)
23
- end
24
-
25
- def zeros!(tensor)
26
- _zeros!(tensor)
27
- end
28
-
29
- def eye!(tensor)
30
- _eye!(tensor)
31
- end
32
-
33
- def dirac!(tensor)
34
- _dirac!(tensor)
35
- end
17
+ alias_method :constant!, :_constant!
18
+ alias_method :ones!, :_ones!
19
+ alias_method :zeros!, :_zeros!
20
+ alias_method :eye!, :_eye!
21
+ alias_method :dirac!, :_dirac!
36
22
 
37
23
  def xavier_uniform!(tensor, gain: 1.0)
38
24
  _xavier_uniform!(tensor, gain)
@@ -39,14 +39,14 @@ module Torch
39
39
  state[:step] += 1
40
40
 
41
41
  if group[:weight_decay] != 0
42
- grad = grad.add(group[:weight_decay], p.data)
42
+ grad = grad.add(p.data, alpha: group[:weight_decay])
43
43
  end
44
44
 
45
- square_avg.mul!(rho).addcmul!(1 - rho, grad, grad)
45
+ square_avg.mul!(rho).addcmul!(grad, grad, value: 1 - rho)
46
46
  std = square_avg.add(eps).sqrt!
47
47
  delta = acc_delta.add(eps).sqrt!.div!(std).mul!(grad)
48
- p.data.add!(-group[:lr], delta)
49
- acc_delta.mul!(rho).addcmul!(1 - rho, delta, delta)
48
+ p.data.add!(delta, alpha: -group[:lr])
49
+ acc_delta.mul!(rho).addcmul!(delta, delta, value: 1 - rho)
50
50
  end
51
51
  end
52
52
 
@@ -49,7 +49,7 @@ module Torch
49
49
  if p.grad.data.sparse?
50
50
  raise Error, "weight_decay option is not compatible with sparse gradients"
51
51
  end
52
- grad = grad.add(group[:weight_decay], p.data)
52
+ grad = grad.add(p.data, alpha: group[:weight_decay])
53
53
  end
54
54
 
55
55
  clr = group[:lr] / (1 + (state[:step] - 1) * group[:lr_decay])
@@ -57,9 +57,9 @@ module Torch
57
57
  if grad.sparse?
58
58
  raise NotImplementedYet
59
59
  else
60
- state[:sum].addcmul!(1, grad, grad)
60
+ state[:sum].addcmul!(grad, grad, value: 1)
61
61
  std = state[:sum].sqrt.add!(group[:eps])
62
- p.data.addcdiv!(-clr, grad, std)
62
+ p.data.addcdiv!(grad, std, value: -clr)
63
63
  end
64
64
  end
65
65
  end