torch-rb 0.3.7 → 0.4.0

@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_tensor_functions(Module m);
@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_torch_functions(Module m);
@@ -0,0 +1,42 @@
+ #pragma once
+
+ #include <rice/Symbol.hpp>
+
+ // keep THP prefix for now to make it easier to compare code
+
+ extern VALUE THPVariableClass;
+
+ inline VALUE THPUtils_internSymbol(const std::string& str) {
+   return Symbol(str);
+ }
+
+ inline std::string THPUtils_unpackSymbol(VALUE obj) {
+   Check_Type(obj, T_SYMBOL);
+   obj = rb_funcall(obj, rb_intern("to_s"), 0);
+   return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
+ }
+
+ inline std::string THPUtils_unpackString(VALUE obj) {
+   Check_Type(obj, T_STRING);
+   return std::string(RSTRING_PTR(obj), RSTRING_LEN(obj));
+ }
+
+ inline bool THPUtils_checkSymbol(VALUE obj) {
+   return SYMBOL_P(obj);
+ }
+
+ inline bool THPUtils_checkIndex(VALUE obj) {
+   return FIXNUM_P(obj);
+ }
+
+ inline bool THPUtils_checkScalar(VALUE obj) {
+   return FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj) || RB_TYPE_P(obj, T_COMPLEX);
+ }
+
+ inline bool THPVariable_Check(VALUE obj) {
+   return rb_obj_is_kind_of(obj, THPVariableClass);
+ }
+
+ inline bool THPVariable_CheckExact(VALUE obj) {
+   return rb_obj_is_instance_of(obj, THPVariableClass);
+ }
@@ -1,43 +1,44 @@
+ #pragma once
+
  #include <torch/torch.h>
  #include <rice/Object.hpp>
- #include "templates.hpp"

- Object wrap(bool x) {
+ inline Object wrap(bool x) {
    return to_ruby<bool>(x);
  }

- Object wrap(int64_t x) {
+ inline Object wrap(int64_t x) {
    return to_ruby<int64_t>(x);
  }

- Object wrap(double x) {
+ inline Object wrap(double x) {
    return to_ruby<double>(x);
  }

- Object wrap(torch::Tensor x) {
+ inline Object wrap(torch::Tensor x) {
    return to_ruby<torch::Tensor>(x);
  }

- Object wrap(torch::Scalar x) {
+ inline Object wrap(torch::Scalar x) {
    return to_ruby<torch::Scalar>(x);
  }

- Object wrap(torch::ScalarType x) {
+ inline Object wrap(torch::ScalarType x) {
    return to_ruby<torch::ScalarType>(x);
  }

- Object wrap(torch::QScheme x) {
+ inline Object wrap(torch::QScheme x) {
    return to_ruby<torch::QScheme>(x);
  }

- Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
    Array a;
    a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
    a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
    return Object(a);
  }

- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
    Array a;
    a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
    a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -45,7 +46,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x) {
    return Object(a);
  }

- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
    Array a;
    a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
    a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -54,7 +55,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
    return Object(a);
  }

- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
    Array a;
    a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
    a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -64,7 +65,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x) {
    return Object(a);
  }

- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
    Array a;
    a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
    a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -73,7 +74,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x) {
    return Object(a);
  }

- Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
+ inline Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
    Array a;
    a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
    a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
@@ -82,7 +83,7 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
    return Object(a);
  }

- Object wrap(std::vector<torch::Tensor> x) {
+ inline Object wrap(torch::TensorList x) {
    Array a;
    for (auto& t : x) {
      a.push(to_ruby<torch::Tensor>(t));
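
These wrap overloads are what turn LibTorch's tuple and tensor-list return values into plain Ruby arrays. A minimal Ruby sketch of how that surfaces to users (Torch.topk and Tensor#chunk assumed available, as in earlier releases; illustrative only):

    require "torch"

    x = Torch.tensor([1.0, 3.0, 2.0])
    values, indices = Torch.topk(x, 2)  # tuple return arrives as a Ruby array
    pieces = x.chunk(3)                 # tensor-list return arrives as an array of tensors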
@@ -7,11 +7,6 @@ require "net/http"
  require "set"
  require "tmpdir"

- # native functions
- require "torch/native/generator"
- require "torch/native/parser"
- require "torch/native/dispatcher"
-
  # modules
  require "torch/inspector"
  require "torch/tensor"
@@ -374,63 +369,6 @@ module Torch

  # --- begin tensor creation: https://pytorch.org/cppdocs/notes/tensor_creation.html ---

- def arange(start, finish = nil, step = 1, **options)
-   # ruby doesn't support start = 0, finish, step = 1, ...
-   if finish.nil?
-     finish = start
-     start = 0
-   end
-   _arange(start, finish, step, tensor_options(**options))
- end
-
- def empty(*size, **options)
-   _empty(tensor_size(size), tensor_options(**options))
- end
-
- def eye(n, m = nil, **options)
-   _eye(n, m || n, tensor_options(**options))
- end
-
- def full(size, fill_value, **options)
-   _full(size, fill_value, tensor_options(**options))
- end
-
- def linspace(start, finish, steps = 100, **options)
-   _linspace(start, finish, steps, tensor_options(**options))
- end
-
- def logspace(start, finish, steps = 100, base = 10.0, **options)
-   _logspace(start, finish, steps, base, tensor_options(**options))
- end
-
- def ones(*size, **options)
-   _ones(tensor_size(size), tensor_options(**options))
- end
-
- def rand(*size, **options)
-   _rand(tensor_size(size), tensor_options(**options))
- end
-
- def randint(low = 0, high, size, **options)
-   _randint(low, high, size, tensor_options(**options))
- end
-
- def randn(*size, **options)
-   _randn(tensor_size(size), tensor_options(**options))
- end
-
- def randperm(n, **options)
-   # dtype hack in Python
-   # https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
-   options[:dtype] ||= :int64
-
-   _randperm(n, tensor_options(**options))
- end
-
- def zeros(*size, **options)
-   _zeros(tensor_size(size), tensor_options(**options))
- end
-
  def tensor(data, **options)
    if options[:dtype].nil? && defined?(Numo::NArray) && data.is_a?(Numo::NArray)
      numo_to_dtype = _dtype_to_numo.map(&:reverse).to_h
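
The hand-written creation helpers removed here (arange, empty, eye, full, linspace, logspace, ones, rand, randint, randn, randperm, zeros) are presumably served by the generated native bindings from 0.4.0 on, rather than by lib/torch.rb. A rough usage sketch under that assumption, using the 0.3.x call shapes:

    require "torch"

    Torch.arange(10)                    # finish only; start defaults to 0
    Torch.arange(0, 10, 2)              # start, finish, step
    Torch.eye(3)
    Torch.full([2, 2], 7.0)
    Torch.ones(2, 3, dtype: :float64)
    Torch.randperm(5)                   # integer result, as with the old :int64 default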
@@ -394,15 +394,15 @@ module Torch
  # loss functions

  def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
-   NN.binary_cross_entropy(input, target, weight, reduction)
+   NN.binary_cross_entropy(input, target, weight, to_reduction(reduction))
  end

  def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
-   Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
+   Torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, to_reduction(reduction))
  end

  def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
-   Torch.cosine_embedding_loss(input1, input2, target, margin, reduction)
+   Torch.cosine_embedding_loss(input1, input2, target, margin, to_reduction(reduction))
  end

  def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
@@ -411,34 +411,34 @@ module Torch

  def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
    # call to_a on input_lengths and target_lengths for C++
-   Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
+   Torch.ctc_loss(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, to_reduction(reduction), zero_infinity)
  end

  def hinge_embedding_loss(input, target, margin: 1.0, reduction: "mean")
-   Torch.hinge_embedding_loss(input, target, margin, reduction)
+   Torch.hinge_embedding_loss(input, target, margin, to_reduction(reduction))
  end

  def kl_div(input, target, reduction: "mean")
-   Torch.kl_div(input, target, reduction)
+   Torch.kl_div(input, target, to_reduction(reduction))
  end

  def l1_loss(input, target, reduction: "mean")
-   NN.l1_loss(input, target, reduction)
+   NN.l1_loss(input, target, to_reduction(reduction))
  end

  def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
-   Torch.margin_ranking_loss(input1, input2, target, margin, reduction)
+   Torch.margin_ranking_loss(input1, input2, target, margin, to_reduction(reduction))
  end

  def mse_loss(input, target, reduction: "mean")
    if target.size != input.size
      warn "Using a target size (#{target.size}) that is different to the input size (#{input.size}). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size."
    end
-   NN.mse_loss(input, target, reduction)
+   NN.mse_loss(input, target, to_reduction(reduction))
  end

  def multilabel_margin_loss(input, target, reduction: "mean")
-   NN.multilabel_margin_loss(input, target, reduction)
+   NN.multilabel_margin_loss(input, target, to_reduction(reduction))
  end

  def multilabel_soft_margin_loss(input, target, weight: nil)
@@ -446,27 +446,27 @@ module Torch
  end

  def multi_margin_loss(input, target, p: 1, margin: 1.0, weight: nil, reduction: "mean")
-   NN.multi_margin_loss(input, target, p, margin, weight, reduction)
+   NN.multi_margin_loss(input, target, p, margin, weight, to_reduction(reduction))
  end

  def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
-   NN.nll_loss(input, target, weight, reduction, ignore_index)
+   NN.nll_loss(input, target, weight, to_reduction(reduction), ignore_index)
  end

  def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
-   Torch.poisson_nll_loss(input, target, log_input, full, eps, reduction)
+   Torch.poisson_nll_loss(input, target, log_input, full, eps, to_reduction(reduction))
  end

  def soft_margin_loss(input, target, reduction: "mean")
-   NN.soft_margin_loss(input, target, reduction)
+   NN.soft_margin_loss(input, target, to_reduction(reduction))
  end

  def smooth_l1_loss(input, target, reduction: "mean")
-   NN.smooth_l1_loss(input, target, reduction)
+   NN.smooth_l1_loss(input, target, to_reduction(reduction))
  end

  def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean")
-   Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
+   Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, to_reduction(reduction))
  end

  # vision
@@ -542,6 +542,20 @@ module Torch

  private

+ # see _reduction.py
+ def to_reduction(v)
+   case v.to_s
+   when "none"
+     0
+   when "mean"
+     1
+   when "sum"
+     2
+   else
+     raise ArgumentError, "#{v} is not a valid value for reduction"
+   end
+ end
+
  def softmax_dim(ndim)
    ndim == 0 || ndim == 1 || ndim == 3 ? 0 : 1
  end
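
The new private to_reduction helper mirrors PyTorch's _reduction.py: the loss functions keep accepting "none" / "mean" / "sum" (strings or symbols, since it calls v.to_s) and translate them to LibTorch's integer enum just before the native call. A small sketch of the caller-facing behavior (Torch::NN::Functional assumed as the module path, as in earlier releases; illustrative only):

    require "torch"

    input = Torch.randn(3, 5)
    target = Torch.randn(3, 5)

    f = Torch::NN::Functional
    f.mse_loss(input, target)                    # "mean" -> 1 internally
    f.mse_loss(input, target, reduction: "sum")  # "sum"  -> 2 internally
    f.mse_loss(input, target, reduction: :none)  # symbols work too, via to_s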
@@ -14,25 +14,11 @@ module Torch
    _normal!(tensor, mean, std)
  end

- def constant!(tensor, val)
-   _constant!(tensor, val)
- end
-
- def ones!(tensor)
-   _ones!(tensor)
- end
-
- def zeros!(tensor)
-   _zeros!(tensor)
- end
-
- def eye!(tensor)
-   _eye!(tensor)
- end
-
- def dirac!(tensor)
-   _dirac!(tensor)
- end
+ alias_method :constant!, :_constant!
+ alias_method :ones!, :_ones!
+ alias_method :zeros!, :_zeros!
+ alias_method :eye!, :_eye!
+ alias_method :dirac!, :_dirac!

  def xavier_uniform!(tensor, gain: 1.0)
    _xavier_uniform!(tensor, gain)
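
Switching these wrappers to alias_method should not change the public API: the bang initializers still take a tensor and fill it in place, they just dispatch straight to the native methods now. Illustrative sketch (Torch::NN::Init assumed as the module path; argument shapes as in 0.3.x):

    require "torch"

    w = Torch.empty(3, 3)
    Torch::NN::Init.constant!(w, 0.5)  # now an alias for the native _constant!
    Torch::NN::Init.eye!(w)
    Torch::NN::Init.zeros!(w)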
@@ -45,7 +45,7 @@ module Torch
  square_avg.mul!(rho).addcmul!(1 - rho, grad, grad)
  std = square_avg.add(eps).sqrt!
  delta = acc_delta.add(eps).sqrt!.div!(std).mul!(grad)
- p.data.add!(-group[:lr], delta)
+ p.data.add!(delta, alpha: -group[:lr])
  acc_delta.mul!(rho).addcmul!(1 - rho, delta, delta)
  end
  end
@@ -53,11 +53,11 @@ module Torch
  bias_correction2 = 1 - beta2 ** state[:step]

  if group[:weight_decay] != 0
-   grad.add!(group[:weight_decay], p.data)
+   grad.add!(p.data, alpha: group[:weight_decay])
  end

  # Decay the first and second moment running average coefficient
- exp_avg.mul!(beta1).add!(1 - beta1, grad)
+ exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
  exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
  if amsgrad
    # Maintains the maximum of all 2nd moment running avg. till now
@@ -46,7 +46,7 @@ module Torch
  end

  # Update biased first moment estimate.
- exp_avg.mul!(beta1).add!(1 - beta1, grad)
+ exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
  # Update the exponentially weighted infinity norm.
  norm_buf = Torch.cat([
    exp_inf.mul!(beta2).unsqueeze(0),
@@ -58,7 +58,7 @@ module Torch
  bias_correction2 = 1 - beta2 ** state[:step]

  # Decay the first and second moment running average coefficient
- exp_avg.mul!(beta1).add!(1 - beta1, grad)
+ exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
  exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
  if amsgrad
    # Maintains the maximum of all 2nd moment running avg. till now
@@ -43,7 +43,7 @@ module Torch
  p.data.mul!(1 - group[:lambd] * state[:eta])

  # update parameter
- p.data.add!(-state[:eta], grad)
+ p.data.add!(grad, alpha: -state[:eta])

  # averaging
  if state[:mu] != 1
@@ -32,7 +32,7 @@ module Torch
  next unless p.grad
  d_p = p.grad.data
  if weight_decay != 0
-   d_p.add!(weight_decay, p.data)
+   d_p.add!(p.data, alpha: weight_decay)
  end
  if momentum != 0
    param_state = @state[p]
@@ -40,7 +40,7 @@ module Torch
    buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
  else
    buf = param_state[:momentum_buffer]
-   buf.mul!(momentum).add!(1 - dampening, d_p)
+   buf.mul!(momentum).add!(d_p, alpha: 1 - dampening)
  end
  if nesterov
    d_p = d_p.add(momentum, buf)
@@ -49,7 +49,7 @@ module Torch
  end
  end

- p.data.add!(-group[:lr], d_p)
+ p.data.add!(d_p, alpha: -group[:lr])
  end
  end
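
All of the optimizer edits above are the same mechanical change: Tensor#add! is now called with the tensor first and the scale factor as an alpha: keyword, matching the non-deprecated LibTorch signature, instead of the old (scalar, tensor) positional form. A minimal before/after sketch:

    require "torch"

    x = Torch.ones(3)
    y = Torch.ones(3)
    lr = 0.1

    # 0.3.x style (deprecated positional scalar):
    #   x.add!(-lr, y)
    # 0.4.0 style, as used in the optimizers above (in place: x += alpha * y):
    x.add!(y, alpha: -lr)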