torch-rb 0.3.7 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_tensor_functions(Module m);
@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_torch_functions(Module m);
@@ -0,0 +1,42 @@
+ #pragma once
+
+ #include <rice/Symbol.hpp>
+
+ // keep THP prefix for now to make it easier to compare code
+
+ extern VALUE THPVariableClass;
+
+ inline VALUE THPUtils_internSymbol(const std::string& str) {
+ return Symbol(str);
+ }
+
+ inline std::string THPUtils_unpackSymbol(VALUE obj) {
+ Check_Type(obj, T_SYMBOL);
+ obj = rb_funcall(obj, rb_intern("to_s"), 0);
+ return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
+ }
+
+ inline std::string THPUtils_unpackString(VALUE obj) {
+ Check_Type(obj, T_STRING);
+ return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
+ }
+
+ inline bool THPUtils_checkSymbol(VALUE obj) {
+ return SYMBOL_P(obj);
+ }
+
+ inline bool THPUtils_checkIndex(VALUE obj) {
+ return FIXNUM_P(obj);
+ }
+
+ inline bool THPUtils_checkScalar(VALUE obj) {
+ return FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj) || RB_TYPE_P(obj, T_COMPLEX);
+ }
+
+ inline bool THPVariable_Check(VALUE obj) {
+ return rb_obj_is_kind_of(obj, THPVariableClass);
+ }
+
+ inline bool THPVariable_CheckExact(VALUE obj) {
+ return rb_obj_is_instance_of(obj, THPVariableClass);
+ }
@@ -1,43 +1,44 @@
+ #pragma once
+
  #include <torch/torch.h>
  #include <rice/Object.hpp>
- #include "templates.hpp"

- Object wrap(bool x) {
+ inline Object wrap(bool x) {
  return to_ruby<bool>(x);
  }

- Object wrap(int64_t x) {
+ inline Object wrap(int64_t x) {
  return to_ruby<int64_t>(x);
  }

- Object wrap(double x) {
+ inline Object wrap(double x) {
  return to_ruby<double>(x);
  }

- Object wrap(torch::Tensor x) {
+ inline Object wrap(torch::Tensor x) {
  return to_ruby<torch::Tensor>(x);
  }

- Object wrap(torch::Scalar x) {
+ inline Object wrap(torch::Scalar x) {
  return to_ruby<torch::Scalar>(x);
  }

- Object wrap(torch::ScalarType x) {
+ inline Object wrap(torch::ScalarType x) {
  return to_ruby<torch::ScalarType>(x);
  }

- Object wrap(torch::QScheme x) {
+ inline Object wrap(torch::QScheme x) {
  return to_ruby<torch::QScheme>(x);
  }

- Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
  Array a;
  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
  return Object(a);
  }

- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
  Array a;
  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -45,7 +46,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
  return Object(a);
  }

- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
  Array a;
  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -54,7 +55,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tenso
  return Object(a);
  }

- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
  Array a;
  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -64,7 +65,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tenso
  return Object(a);
  }

- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
  Array a;
  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -73,7 +74,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x)
  return Object(a);
  }

- Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
  Array a;
  a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
  a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -82,7 +83,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
  return Object(a);
  }

- Object wrap(std::vector<torch::Tensor> x) {
+ inline Object wrap(torch::TensorList x) {
  Array a;
  for (auto& t : x) {
  a.push(to_ruby<torch::Tensor>(t));
@@ -7,11 +7,6 @@ require "net/http"
  require "set"
  require "tmpdir"

- # native functions
- require "torch/native/generator"
- require "torch/native/parser"
- require "torch/native/dispatcher"
-
  # modules
  require "torch/inspector"
  require "torch/tensor"
@@ -374,63 +369,6 @@ module Torch

  # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---

- def arange(start, finish = nil, step = 1, **options)
- # ruby doesn't support start = 0, finish, step = 1, ...
- if finish.nil?
- finish = start
- start = 0
- end
- _arange(start, finish, step, tensor_options(**options))
- end
-
- def empty(*size, **options)
- _empty(tensor_size(size), tensor_options(**options))
- end
-
- def eye(n, m = nil, **options)
- _eye(n, m || n, tensor_options(**options))
- end
-
- def full(size, fill_value, **options)
- _full(size, fill_value, tensor_options(**options))
- end
-
- def linspace(start, finish, steps = 100, **options)
- _linspace(start, finish, steps, tensor_options(**options))
- end
-
- def logspace(start, finish, steps = 100, base = 10.0, **options)
- _logspace(start, finish, steps, base, tensor_options(**options))
- end
-
- def ones(*size, **options)
- _ones(tensor_size(size), tensor_options(**options))
- end
-
- def rand(*size, **options)
- _rand(tensor_size(size), tensor_options(**options))
- end
-
- def randint(low = 0, high, size, **options)
- _randint(low, high, size, tensor_options(**options))
- end
-
- def randn(*size, **options)
- _randn(tensor_size(size), tensor_options(**options))
- end
-
- def randperm(n, **options)
- # dtype hack in Python
- # https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
- options[:dtype] ||= :int64
-
- _randperm(n, tensor_options(**options))
- end
-
- def zeros(*size, **options)
- _zeros(tensor_size(size), tensor_options(**options))
- end
-
  def tensor(data, **options)
  if options[:dtype].nil? && defined?(Numo::NArray) && data.is_a?(Numo::NArray)
  numo_to_dtype = _dtype_to_numo.map(&:reverse).to_h
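Note: the hand-written creation wrappers removed above (arange, empty, eye, full, linspace, logspace, ones, rand, randint, randn, randperm, zeros) are superseded by bindings produced by rake generate:functions (see the generated headers at the top of this diff). A hypothetical usage sketch, assuming the generated bindings keep the same public names and argument handling as the removed wrappers:

require "torch"

# Hypothetical call sites -- assumes the generated functions preserve the
# public API of the removed Ruby wrappers.
x = Torch.arange(10)                   # finish only; start is assumed to default to 0
y = Torch.ones(2, 3)
z = Torch.zeros(2, 3, dtype: :float64)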
@@ -394,15 +394,15 @@ module Torch
  # loss functions

  def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
- NN.binary_cross_entropy(input, target, weight, reduction)
+ NN.binary_cross_entropy(input, target, weight, to_reduction(reduction))
  end

  def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
- Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
+ Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, to_reduction(reduction))
  end

  def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
- Torch.cosine_embedding_loss(input1, input2, target, margin, reduction)
+ Torch.cosine_embedding_loss(input1, input2, target, margin, to_reduction(reduction))
  end

  def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
@@ -411,34 +411,34 @@ module Torch

  def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
  # call to_a on input_lengths and target_lengths for C++
- Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
+ Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, to_reduction(reduction), zero_infinity)
  end

  def hinge_embedding_loss(input, target, margin: 1.0, reduction: "mean")
- Torch.hinge_embedding_loss(input, target, margin, reduction)
+ Torch.hinge_embedding_loss(input, target, margin, to_reduction(reduction))
  end

  def kl_div(input, target, reduction: "mean")
- Torch.kl_div(input, target, reduction)
+ Torch.kl_div(input, target, to_reduction(reduction))
  end

  def l1_loss(input, target, reduction: "mean")
- NN.l1_loss(input, target, reduction)
+ NN.l1_loss(input, target, to_reduction(reduction))
  end

  def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
- Torch.margin_ranking_loss(input1, input2, target, margin, reduction)
+ Torch.margin_ranking_loss(input1, input2, target, margin, to_reduction(reduction))
  end

  def mse_loss(input, target, reduction: "mean")
  if target.size != input.size
  warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
  end
- NN.mse_loss(input, target, reduction)
+ NN.mse_loss(input, target, to_reduction(reduction))
  end

  def multilabel_margin_loss(input, target, reduction: "mean")
- NN.multilabel_margin_loss(input, target, reduction)
+ NN.multilabel_margin_loss(input, target, to_reduction(reduction))
  end

  def multilabel_soft_margin_loss(input, target, weight: nil)
@@ -446,27 +446,27 @@ module Torch
  end

  def multi_margin_loss(input, target, p: 1, margin: 1.0, weight: nil, reduction: "mean")
- NN.multi_margin_loss(input, target, p, margin, weight, reduction)
+ NN.multi_margin_loss(input, target, p, margin, weight, to_reduction(reduction))
  end

  def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
- NN.nll_loss(input, target, weight, reduction, ignore_index)
+ NN.nll_loss(input, target, weight, to_reduction(reduction), ignore_index)
  end

  def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
- Torch.poisson_nll_loss(input, target, log_input, full, eps, reduction)
+ Torch.poisson_nll_loss(input, target, log_input, full, eps, to_reduction(reduction))
  end

  def soft_margin_loss(input, target, reduction: "mean")
- NN.soft_margin_loss(input, target, reduction)
+ NN.soft_margin_loss(input, target, to_reduction(reduction))
  end

  def smooth_l1_loss(input, target, reduction: "mean")
- NN.smooth_l1_loss(input, target, reduction)
+ NN.smooth_l1_loss(input, target, to_reduction(reduction))
  end

  def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean")
- Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
+ Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, to_reduction(reduction))
  end

  # vision
@@ -542,6 +542,20 @@ module Torch

  private

+ # see _reduction.py
+ def to_reduction(v)
+ case v.to_s
+ when "none"
+ 0
+ when "mean"
+ 1
+ when "sum"
+ 2
+ else
+ raise ArgumentError, "#{v} is not a valid value for reduction"
+ end
+ end
+
  def softmax_dim(ndim)
  ndim == 0 || ndim == 1 || ndim == 3 ? 0 : 1
  end
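Note: the new private to_reduction helper converts the user-facing reduction strings into the integer enum the native functions now expect; the values mirror PyTorch's torch/nn/_reduction.py. Caller code keeps passing strings, only the internal call changes. A hedged sketch with placeholder tensors, assuming the loss wrappers above stay exposed through Torch::NN::Functional as in earlier releases:

input  = Torch.randn(3, 5)
target = Torch.randn(3, 5)

# Public API unchanged: reduction is still given as a string (or symbol).
# Internally the wrapper now forwards to_reduction("sum") == 2, so the
# native call becomes NN.mse_loss(input, target, 2).
Torch::NN::Functional.mse_loss(input, target, reduction: "sum")

# Unrecognized values now raise ArgumentError from to_reduction.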
@@ -14,25 +14,11 @@ module Torch
  _normal!(tensor, mean, std)
  end

- def constant!(tensor, val)
- _constant!(tensor, val)
- end
-
- def ones!(tensor)
- _ones!(tensor)
- end
-
- def zeros!(tensor)
- _zeros!(tensor)
- end
-
- def eye!(tensor)
- _eye!(tensor)
- end
-
- def dirac!(tensor)
- _dirac!(tensor)
- end
+ alias_method :constant!, :_constant!
+ alias_method :ones!, :_ones!
+ alias_method :zeros!, :_zeros!
+ alias_method :eye!, :_eye!
+ alias_method :dirac!, :_dirac!

  def xavier_uniform!(tensor, gain: 1.0)
  _xavier_uniform!(tensor, gain)
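Note: replacing the one-line wrapper methods with alias_method keeps the public init API the same while dropping a Ruby stack frame per call. A hedged sketch of the unchanged call sites, assuming these remain module methods on Torch::NN::Init:

w = Torch.empty(3, 5)

# Both calls resolve to the native bindings (_ones!, _constant!); the
# aliases only remove the intermediate def wrappers.
Torch::NN::Init.ones!(w)
Torch::NN::Init.constant!(w, 0.5)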
@@ -45,7 +45,7 @@ module Torch
  square_avg.mul!(rho).addcmul!(1 - rho, grad, grad)
  std = square_avg.add(eps).sqrt!
  delta = acc_delta.add(eps).sqrt!.div!(std).mul!(grad)
- p.data.add!(-group[:lr], delta)
+ p.data.add!(delta, alpha: -group[:lr])
  acc_delta.mul!(rho).addcmul!(1 - rho, delta, delta)
  end
  end
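Note: this and the following optimizer hunks all make the same change: the positional form tensor.add!(scalar, other) becomes the keyword form tensor.add!(other, alpha: scalar), i.e. self += alpha * other, matching the argument order PyTorch standardized on. A minimal sketch on plain tensors:

p_data = Torch.zeros(3)
delta  = Torch.ones(3)
lr     = 0.1

# Old form (removed): p_data.add!(-lr, delta)
# New form: the scale factor is an explicit alpha keyword.
p_data.add!(delta, alpha: -lr)   # p_data is now [-0.1, -0.1, -0.1]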
@@ -53,11 +53,11 @@ module Torch
  bias_correction2 = 1 - beta2 ** state[:step]

  if group[:weight_decay] != 0
- grad.add!(group[:weight_decay], p.data)
+ grad.add!(p.data, alpha: group[:weight_decay])
  end

  # Decay the first and second moment running average coefficient
- exp_avg.mul!(beta1).add!(1 - beta1, grad)
+ exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
  exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
  if amsgrad
  # Maintains the maximum of all 2nd moment running avg. till now
@@ -46,7 +46,7 @@ module Torch
  end

  # Update biased first moment estimate.
- exp_avg.mul!(beta1).add!(1 - beta1, grad)
+ exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
  # Update the exponentially weighted infinity norm.
  norm_buf = Torch.cat([
  exp_inf.mul!(beta2).unsqueeze(0),
@@ -58,7 +58,7 @@ module Torch
  bias_correction2 = 1 - beta2 ** state[:step]

  # Decay the first and second moment running average coefficient
- exp_avg.mul!(beta1).add!(1 - beta1, grad)
+ exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
  exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
  if amsgrad
  # Maintains the maximum of all 2nd moment running avg. till now
@@ -43,7 +43,7 @@ module Torch
  p.data.mul!(1 - group[:lambd] * state[:eta])

  # update parameter
- p.data.add!(-state[:eta], grad)
+ p.data.add!(grad, alpha: -state[:eta])

  # averaging
  if state[:mu] != 1
@@ -32,7 +32,7 @@ module Torch
  next unless p.grad
  d_p = p.grad.data
  if weight_decay != 0
- d_p.add!(weight_decay, p.data)
+ d_p.add!(p.data, alpha: weight_decay)
  end
  if momentum != 0
  param_state = @state[p]
@@ -40,7 +40,7 @@ module Torch
  buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
  else
  buf = param_state[:momentum_buffer]
- buf.mul!(momentum).add!(1 - dampening, d_p)
+ buf.mul!(momentum).add!(d_p, alpha: 1 - dampening)
  end
  if nesterov
  d_p = d_p.add(momentum, buf)
@@ -49,7 +49,7 @@ module Torch
  end
  end

- p.data.add!(-group[:lr], d_p)
+ p.data.add!(d_p, alpha: -group[:lr])
  end
  end