torch-rb 0.1.2 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (142)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +35 -0
  3. data/LICENSE.txt +46 -22
  4. data/README.md +18 -6
  5. data/ext/torch/ext.cpp +148 -369
  6. data/ext/torch/extconf.rb +6 -0
  7. data/ext/torch/nn_functions.cpp +615 -0
  8. data/ext/torch/nn_functions.hpp +6 -0
  9. data/ext/torch/templates.cpp +55 -0
  10. data/ext/torch/templates.hpp +242 -0
  11. data/ext/torch/tensor_functions.cpp +1920 -0
  12. data/ext/torch/tensor_functions.hpp +6 -0
  13. data/ext/torch/torch_functions.cpp +2975 -0
  14. data/ext/torch/torch_functions.hpp +6 -0
  15. data/lib/torch.rb +240 -131
  16. data/lib/torch/ext.bundle +0 -0
  17. data/lib/torch/inspector.rb +27 -22
  18. data/lib/torch/native/dispatcher.rb +48 -0
  19. data/lib/torch/native/function.rb +109 -0
  20. data/lib/torch/native/generator.rb +168 -0
  21. data/lib/torch/native/native_functions.yaml +6837 -0
  22. data/lib/torch/native/parser.rb +134 -0
  23. data/lib/torch/nn/alpha_dropout.rb +9 -0
  24. data/lib/torch/nn/avg_pool1d.rb +18 -0
  25. data/lib/torch/nn/avg_pool2d.rb +19 -0
  26. data/lib/torch/nn/avg_pool3d.rb +19 -0
  27. data/lib/torch/nn/avg_poolnd.rb +9 -0
  28. data/lib/torch/nn/batch_norm.rb +75 -0
  29. data/lib/torch/nn/batch_norm1d.rb +11 -0
  30. data/lib/torch/nn/batch_norm2d.rb +11 -0
  31. data/lib/torch/nn/batch_norm3d.rb +11 -0
  32. data/lib/torch/nn/bce_loss.rb +13 -0
  33. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  34. data/lib/torch/nn/bilinear.rb +38 -0
  35. data/lib/torch/nn/constant_pad1d.rb +10 -0
  36. data/lib/torch/nn/constant_pad2d.rb +10 -0
  37. data/lib/torch/nn/constant_pad3d.rb +10 -0
  38. data/lib/torch/nn/constant_padnd.rb +18 -0
  39. data/lib/torch/nn/conv1d.rb +22 -0
  40. data/lib/torch/nn/conv2d.rb +16 -38
  41. data/lib/torch/nn/conv3d.rb +22 -0
  42. data/lib/torch/nn/convnd.rb +41 -0
  43. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  44. data/lib/torch/nn/cosine_similarity.rb +15 -0
  45. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  46. data/lib/torch/nn/ctc_loss.rb +15 -0
  47. data/lib/torch/nn/dropout.rb +9 -0
  48. data/lib/torch/nn/dropout2d.rb +9 -0
  49. data/lib/torch/nn/dropout3d.rb +9 -0
  50. data/lib/torch/nn/dropoutnd.rb +15 -0
  51. data/lib/torch/nn/embedding.rb +52 -0
  52. data/lib/torch/nn/embedding_bag.rb +34 -0
  53. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  54. data/lib/torch/nn/fold.rb +20 -0
  55. data/lib/torch/nn/functional.rb +411 -22
  56. data/lib/torch/nn/group_norm.rb +36 -0
  57. data/lib/torch/nn/gru.rb +49 -0
  58. data/lib/torch/nn/hardshrink.rb +18 -0
  59. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  60. data/lib/torch/nn/identity.rb +14 -0
  61. data/lib/torch/nn/init.rb +58 -1
  62. data/lib/torch/nn/instance_norm.rb +20 -0
  63. data/lib/torch/nn/instance_norm1d.rb +18 -0
  64. data/lib/torch/nn/instance_norm2d.rb +11 -0
  65. data/lib/torch/nn/instance_norm3d.rb +11 -0
  66. data/lib/torch/nn/kl_div_loss.rb +13 -0
  67. data/lib/torch/nn/l1_loss.rb +13 -0
  68. data/lib/torch/nn/layer_norm.rb +35 -0
  69. data/lib/torch/nn/leaky_relu.rb +20 -0
  70. data/lib/torch/nn/linear.rb +12 -11
  71. data/lib/torch/nn/local_response_norm.rb +21 -0
  72. data/lib/torch/nn/log_sigmoid.rb +9 -0
  73. data/lib/torch/nn/log_softmax.rb +14 -0
  74. data/lib/torch/nn/loss.rb +10 -0
  75. data/lib/torch/nn/lp_pool1d.rb +9 -0
  76. data/lib/torch/nn/lp_pool2d.rb +9 -0
  77. data/lib/torch/nn/lp_poolnd.rb +22 -0
  78. data/lib/torch/nn/lstm.rb +66 -0
  79. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  80. data/lib/torch/nn/max_pool1d.rb +9 -0
  81. data/lib/torch/nn/max_pool2d.rb +9 -0
  82. data/lib/torch/nn/max_pool3d.rb +9 -0
  83. data/lib/torch/nn/max_poolnd.rb +19 -0
  84. data/lib/torch/nn/max_unpool1d.rb +16 -0
  85. data/lib/torch/nn/max_unpool2d.rb +16 -0
  86. data/lib/torch/nn/max_unpool3d.rb +16 -0
  87. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  88. data/lib/torch/nn/module.rb +201 -20
  89. data/lib/torch/nn/mse_loss.rb +2 -2
  90. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  91. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  92. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  93. data/lib/torch/nn/nll_loss.rb +14 -0
  94. data/lib/torch/nn/pairwise_distance.rb +16 -0
  95. data/lib/torch/nn/parameter.rb +2 -2
  96. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  97. data/lib/torch/nn/prelu.rb +19 -0
  98. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  99. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  100. data/lib/torch/nn/reflection_padnd.rb +13 -0
  101. data/lib/torch/nn/relu.rb +8 -3
  102. data/lib/torch/nn/replication_pad1d.rb +10 -0
  103. data/lib/torch/nn/replication_pad2d.rb +10 -0
  104. data/lib/torch/nn/replication_pad3d.rb +10 -0
  105. data/lib/torch/nn/replication_padnd.rb +13 -0
  106. data/lib/torch/nn/rnn.rb +22 -0
  107. data/lib/torch/nn/rnn_base.rb +198 -0
  108. data/lib/torch/nn/sequential.rb +1 -10
  109. data/lib/torch/nn/sigmoid.rb +9 -0
  110. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  111. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  112. data/lib/torch/nn/softmax.rb +18 -0
  113. data/lib/torch/nn/softmax2d.rb +10 -0
  114. data/lib/torch/nn/softmin.rb +14 -0
  115. data/lib/torch/nn/softplus.rb +19 -0
  116. data/lib/torch/nn/softshrink.rb +18 -0
  117. data/lib/torch/nn/softsign.rb +9 -0
  118. data/lib/torch/nn/tanh.rb +9 -0
  119. data/lib/torch/nn/tanhshrink.rb +9 -0
  120. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  121. data/lib/torch/nn/unfold.rb +19 -0
  122. data/lib/torch/nn/utils.rb +25 -0
  123. data/lib/torch/nn/weighted_loss.rb +10 -0
  124. data/lib/torch/nn/zero_pad2d.rb +9 -0
  125. data/lib/torch/optim/adadelta.rb +57 -0
  126. data/lib/torch/optim/adagrad.rb +71 -0
  127. data/lib/torch/optim/adam.rb +81 -0
  128. data/lib/torch/optim/adamax.rb +68 -0
  129. data/lib/torch/optim/adamw.rb +82 -0
  130. data/lib/torch/optim/asgd.rb +65 -0
  131. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  132. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  133. data/lib/torch/optim/optimizer.rb +56 -0
  134. data/lib/torch/optim/rmsprop.rb +76 -0
  135. data/lib/torch/optim/rprop.rb +68 -0
  136. data/lib/torch/optim/sgd.rb +48 -16
  137. data/lib/torch/random.rb +10 -0
  138. data/lib/torch/tensor.rb +71 -30
  139. data/lib/torch/utils/data/data_loader.rb +10 -4
  140. data/lib/torch/utils/data/tensor_dataset.rb +3 -0
  141. data/lib/torch/version.rb +1 -1
  142. metadata +123 -6
data/lib/torch/optim/adamax.rb
@@ -0,0 +1,68 @@
+ # ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/adamax.py
+ module Torch
+   module Optim
+     class Adamax < Optimizer
+       def initialize(params, lr: 2e-3, betas: [0.9, 0.999], eps: 1e-8, weight_decay: 0)
+         raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
+         raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0
+         raise ArgumentError, "Invalid beta parameter at index 0: #{betas[0]}" if betas[0] < 0 || betas[0] >= 1
+         raise ArgumentError, "Invalid beta parameter at index 1: #{betas[1]}" if betas[1] < 0 || betas[1] >= 1
+         raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0
+
+         defaults = {lr: lr, betas: betas, eps: eps, weight_decay: weight_decay}
+         super(params, defaults)
+       end
+
+       def step(closure = nil)
+         loss = nil
+         if closure
+           loss = closure.call
+         end
+
+         @param_groups.each do |group|
+           group[:params].each do |p|
+             next unless p.grad
+             grad = p.grad.data
+             if grad.sparse?
+               raise Error, "Adamax does not support sparse gradients, please consider SparseAdam instead"
+             end
+             state = @state[p]
+
+             # State initialization
+             if state.size == 0
+               state[:step] = 0
+               state[:exp_avg] = Torch.zeros_like(p.data)
+               state[:exp_inf] = Torch.zeros_like(p.data)
+             end
+
+             exp_avg, exp_inf = state[:exp_avg], state[:exp_inf]
+             beta1, beta2 = group[:betas]
+             eps = group[:eps]
+
+             state[:step] += 1
+
+             if group[:weight_decay] != 0
+               grad = grad.add(group[:weight_decay], p.data)
+             end
+
+             # Update biased first moment estimate.
+             exp_avg.mul!(beta1).add!(1 - beta1, grad)
+             # Update the exponentially weighted infinity norm.
+             norm_buf = Torch.cat([
+               exp_inf.mul!(beta2).unsqueeze(0),
+               grad.abs.add!(eps).unsqueeze!(0)
+             ], 0)
+             Torch.max(norm_buf, 0, keepdim: false, out: [exp_inf, exp_inf.new.long])
+
+             bias_correction = 1 - beta1 ** state[:step]
+             clr = group[:lr] / bias_correction
+
+             p.data.addcdiv!(-clr, exp_avg, exp_inf)
+           end
+         end
+
+         loss
+       end
+     end
+   end
+ end
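For orientation, a minimal usage sketch of the new Adamax class, following the constructor signature above and the usual torch-rb training-loop pattern; `model` and `loss` are illustrative placeholders, not part of the diff:

    require "torch"

    # `model` is assumed to be a Torch::NN::Module subclass; `loss` a scalar tensor from a forward pass.
    optimizer = Torch::Optim::Adamax.new(model.parameters, lr: 2e-3, betas: [0.9, 0.999])

    optimizer.zero_grad   # provided by the Optimizer base class (see optimizer.rb further down)
    loss.backward
    optimizer.step        # applies the infinity-norm (Adamax) update shown above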
data/lib/torch/optim/adamw.rb
@@ -0,0 +1,82 @@
+ # ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/adamw.py
+ module Torch
+   module Optim
+     class AdamW < Optimizer
+       def initialize(params, lr: 1e-3, betas: [0.9, 0.999], eps: 1e-8, weight_decay: 1e-2, amsgrad: false)
+         raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
+         raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0
+         raise ArgumentError, "Invalid beta parameter at index 0: #{betas[0]}" if betas[0] < 0 || betas[0] >= 1
+         raise ArgumentError, "Invalid beta parameter at index 1: #{betas[1]}" if betas[1] < 0 || betas[1] >= 1
+
+         defaults = {lr: lr, betas: betas, eps: eps, weight_decay: weight_decay, amsgrad: amsgrad}
+         super(params, defaults)
+       end
+
+       def step(closure = nil)
+         loss = nil
+         if closure
+           loss = closure.call
+         end
+
+         @param_groups.each do |group|
+           group[:params].each do |p|
+             next unless p.grad
+
+             # Perform stepweight decay
+             p.data.mul!(1 - group[:lr] * group[:weight_decay])
+
+             # Perform optimization step
+             grad = p.grad.data
+             if grad.sparse?
+               raise Error, "AdamW does not support sparse gradients, please consider SparseAdam instead"
+             end
+             amsgrad = group[:amsgrad]
+
+             state = @state[p]
+
+             # State initialization
+             if state.size == 0
+               state[:step] = 0
+               # Exponential moving average of gradient values
+               state[:exp_avg] = Torch.zeros_like(p.data)
+               # Exponential moving average of squared gradient values
+               state[:exp_avg_sq] = Torch.zeros_like(p.data)
+               if amsgrad
+                 # Maintains max of all exp. moving avg. of sq. grad. values
+                 state[:max_exp_avg_sq] = Torch.zeros_like(p.data)
+               end
+             end
+
+             exp_avg, exp_avg_sq = state[:exp_avg], state[:exp_avg_sq]
+             if amsgrad
+               max_exp_avg_sq = state[:max_exp_avg_sq]
+             end
+             beta1, beta2 = group[:betas]
+
+             state[:step] += 1
+             bias_correction1 = 1 - beta1 ** state[:step]
+             bias_correction2 = 1 - beta2 ** state[:step]
+
+             # Decay the first and second moment running average coefficient
+             exp_avg.mul!(beta1).add!(1 - beta1, grad)
+             exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
+             if amsgrad
+               # Maintains the maximum of all 2nd moment running avg. till now
+               Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
+               # Use the max. for normalizing running avg. of gradient
+               denom = (max_exp_avg_sq.sqrt / Math.sqrt(bias_correction2)).add!(group[:eps])
+             else
+               denom = (exp_avg_sq.sqrt / Math.sqrt(bias_correction2)).add!(group[:eps])
+             end
+
+             step_size = group[:lr] / bias_correction1
+
+             p.data.addcdiv!(-step_size, exp_avg, denom)
+           end
+         end
+
+         loss
+       end
+     end
+   end
+ end
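The distinguishing step in AdamW is the decoupled weight decay at the top of the parameter loop: the parameter is shrunk directly rather than having the decay folded into the gradient, as the other optimizers in this diff do (Adamax above, ASGD and RMSprop below). Side by side, using the local names from the code above:

    # L2-style decay (Adamax, ASGD, RMSprop): decay enters through the gradient
    grad = grad.add(group[:weight_decay], p.data)

    # Decoupled decay (AdamW): the parameter itself is scaled before the Adam-style update
    p.data.mul!(1 - group[:lr] * group[:weight_decay])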
data/lib/torch/optim/asgd.rb
@@ -0,0 +1,65 @@
+ # ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/asgd.py
+ module Torch
+   module Optim
+     class ASGD < Optimizer
+       def initialize(params, lr: 1e-2, lambd: 1e-4, alpha: 0.75, t0: 1e6, weight_decay: 0)
+         raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
+         raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0
+
+         defaults = {lr: lr, lambd: lambd, alpha: alpha, t0: t0, weight_decay: weight_decay}
+         super(params, defaults)
+       end
+
+       def step(closure = nil)
+         loss = nil
+         if closure
+           loss = closure.call
+         end
+
+         @param_groups.each do |group|
+           group[:params].each do |p|
+             next unless p.grad
+             grad = p.grad.data
+             if grad.sparse?
+               raise Error, "ASGD does not support sparse gradients"
+             end
+             state = @state[p]
+
+             # State initialization
+             if state.size == 0
+               state[:step] = 0
+               state[:eta] = group[:lr]
+               state[:mu] = 1
+               state[:ax] = Torch.zeros_like(p.data)
+             end
+
+             state[:step] += 1
+
+             if group[:weight_decay] != 0
+               grad = grad.add(group[:weight_decay], p.data)
+             end
+
+             # decay term
+             p.data.mul!(1 - group[:lambd] * state[:eta])
+
+             # update parameter
+             p.data.add!(-state[:eta], grad)
+
+             # averaging
+             if state[:mu] != 1
+               state[:ax].add!(p.data.sub(state[:ax]).mul(state[:mu]))
+             else
+               state[:ax].copy!(p.data)
+             end
+
+             # update eta and mu
+             state[:eta] = (group[:lr] / ((1 + group[:lambd] * group[:lr] * state[:step]) ** group[:alpha]))
+             state[:mu] = 1 / [1, state[:step] - group[:t0]].max
+           end
+         end
+
+         loss
+       end
+     end
+   end
+ end
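With the defaults above, the averaging phase of ASGD only becomes active after `t0` steps: `state[:eta]` decays polynomially from `lr`, while `state[:mu]` stays at 1 (so `ax` simply copies the parameter) until `step` exceeds `t0`. A small illustrative calculation of that schedule, using the default hyperparameters:

    lr, lambd, alpha, t0 = 1e-2, 1e-4, 0.75, 1e6

    # eta_t = lr / (1 + lambd * lr * t) ** alpha
    lr / ((1 + lambd * lr * 1) ** alpha)     # => ~0.00999999
    lr / ((1 + lambd * lr * 1000) ** alpha)  # => ~0.0099925

    # mu_t = 1 / max(1, t - t0)
    1 / [1, 10 - t0].max                     # => 1    (still in the SGD-like phase)
    1 / [1, 1_000_010 - t0].max              # => 0.1  (averaging has kicked in)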
data/lib/torch/optim/lr_scheduler/lr_scheduler.rb
@@ -0,0 +1,33 @@
+ module Torch
+   module Optim
+     module LRScheduler
+       class LRScheduler
+         def initialize(optimizer, last_epoch)
+           @optimizer = optimizer
+           if last_epoch == -1
+             optimizer.param_groups.each do |group|
+               group[:initial_lr] ||= group[:lr]
+             end
+             last_epoch = 0
+           else
+             raise NotImplementedYet
+           end
+           @base_lrs = optimizer.param_groups.map { |group| group[:initial_lr] }
+           @last_epoch = last_epoch
+
+           @step_count = 0
+           step(last_epoch)
+         end
+
+         def step(epoch = nil)
+           @step_count += 1
+           epoch ||= @last_epoch + 1
+           @last_epoch = epoch
+           @optimizer.param_groups.zip(get_lr).each do |param_group, lr|
+             param_group[:lr] = lr
+           end
+         end
+       end
+     end
+   end
+ end
data/lib/torch/optim/lr_scheduler/step_lr.rb
@@ -0,0 +1,17 @@
+ module Torch
+   module Optim
+     module LRScheduler
+       class StepLR < LRScheduler
+         def initialize(optimizer, step_size:, gamma: 0.1, last_epoch: -1)
+           @step_size = step_size
+           @gamma = gamma
+           super(optimizer, last_epoch)
+         end
+
+         def get_lr
+           @base_lrs.map { |base_lr| base_lr * @gamma ** (@last_epoch / @step_size).floor }
+         end
+       end
+     end
+   end
+ end
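StepLR is the only concrete scheduler in this range; it multiplies each parameter group's learning rate by `gamma` every `step_size` epochs via the base class's `step`. A minimal sketch of wiring it to an optimizer (the training call is a hypothetical placeholder):

    optimizer = Torch::Optim::SGD.new(model.parameters, lr: 0.1)
    scheduler = Torch::Optim::LRScheduler::StepLR.new(optimizer, step_size: 30, gamma: 0.1)

    90.times do
      # train_one_epoch(optimizer)   # hypothetical training step
      scheduler.step                 # lr: 0.1 for epochs 0-29, 0.01 for 30-59, 0.001 for 60-89
    end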
data/lib/torch/optim/optimizer.rb
@@ -1,6 +1,62 @@
+ # ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/optimizer.py
  module Torch
    module Optim
      class Optimizer
+       attr_reader :param_groups
+
+       def initialize(params, defaults)
+         @defaults = defaults
+         @state = Hash.new { |hash, key| hash[key] = {} }
+         @param_groups = []
+
+         param_groups = params
+         if param_groups.empty?
+           raise ArgumentError, "optimizer got an empty parameter list"
+         end
+         if !param_groups[0].is_a?(Hash)
+           param_groups = [{params: param_groups}]
+         end
+
+         param_groups.each do |param_group|
+           add_param_group(param_group)
+         end
+       end
+
+       def add_param_group(param_group)
+         # TODO more advanced logic
+         @param_groups << @defaults.merge(param_group)
+       end
+
+       def load_state_dict(state_dict)
+         raise NotImplementedYet
+       end
+
+       def state_dict
+         pack_group = lambda do |group|
+           packed = group.select { |k, _| k != :params }.to_h
+           packed[:params] = group[:params].map { |p| p.object_id }
+           packed
+         end
+
+         param_groups = @param_groups.map { |g| pack_group.call(g) }
+         packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h
+
+         {
+           state: packed_state,
+           param_groups: param_groups
+         }
+       end
+
+       def zero_grad
+         @param_groups.each do |group|
+           group[:params].each do |p|
+             if p.grad
+               p.grad.detach!
+               p.grad.zero!
+             end
+           end
+         end
+       end
      end
    end
  end
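The expanded base class now holds the bookkeeping every optimizer in this release shares: parameter groups merged with per-optimizer defaults, a lazily initialized per-parameter `@state` hash, `zero_grad`, and a `state_dict` that packs parameters as `object_id`s (`load_state_dict` still raises `NotImplementedYet`). A rough sketch of the surface a training loop touches, with illustrative names:

    optimizer = Torch::Optim::Adam.new(model.parameters, lr: 1e-3)

    optimizer.param_groups.first[:lr]  # => 1e-3, defaults merged with the passed options
    optimizer.zero_grad                # detaches and zeroes every parameter gradient
    # loss.backward; optimizer.step

    snapshot = optimizer.state_dict    # { state: {...}, param_groups: [...] }, params as object_ids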
data/lib/torch/optim/rmsprop.rb
@@ -0,0 +1,76 @@
+ # ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/rmsprop.py
+ module Torch
+   module Optim
+     class RMSprop < Optimizer
+       def initialize(params, lr: 1e-2, alpha: 0.99, eps: 1e-8, weight_decay: 0, momentum: 0, centered: false)
+         raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
+         raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0
+         raise ArgumentError, "Invalid momentum value: #{momentum}" if momentum < 0
+         raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0
+         raise ArgumentError, "Invalid momentum alpha: #{alpha}" if alpha < 0
+
+         defaults = {lr: lr, momentum: momentum, alpha: alpha, eps: eps, centered: centered, weight_decay: weight_decay}
+         super(params, defaults)
+       end
+
+       def step(closure = nil)
+         loss = nil
+         if closure
+           loss = closure.call
+         end
+
+         @param_groups.each do |group|
+           group[:params].each do |p|
+             next unless p.grad
+             grad = p.grad.data
+             if grad.sparse?
+               raise Error, "RMSprop does not support sparse gradients"
+             end
+             state = @state[p]
+
+             # State initialization
+             if state.size == 0
+               state[:step] = 0
+               state[:square_avg] = Torch.zeros_like(p.data)
+               if group[:momentum] > 0
+                 state[:momentum_buffer] = Torch.zeros_like(p.data)
+               end
+               if group[:centered]
+                 state[:grad_avg] = Torch.zeros_like(p.data)
+               end
+             end
+
+             square_avg = state[:square_avg]
+             alpha = group[:alpha]
+
+             state[:step] += 1
+
+             if group[:weight_decay] != 0
+               grad = grad.add(group[:weight_decay], p.data)
+             end
+
+             square_avg.mul!(alpha).addcmul!(1 - alpha, grad, grad)
+
+             if group[:centered]
+               grad_avg = state[:grad_avg]
+               grad_avg.mul!(alpha).add!(1 - alpha, grad)
+               avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt!.add!(group[:eps])
+             else
+               avg = square_avg.sqrt.add!(group[:eps])
+             end
+
+             if group[:momentum] > 0
+               buf = state[:momentum_buffer]
+               buf.mul!(group[:momentum]).addcdiv!(grad, avg)
+               p.data.add!(-group[:lr], buf)
+             else
+               p.data.addcdiv!(-group[:lr], grad, avg)
+             end
+           end
+         end
+
+         loss
+       end
+     end
+   end
+ end
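Two of the keyword options change the update path: `centered: true` subtracts the square of a running gradient mean before the square root, and `momentum > 0` accumulates the scaled gradient into `state[:momentum_buffer]` instead of applying it directly. A sketch of both configurations, assuming `model` as in the earlier examples:

    # plain RMSprop: avg = sqrt(E[g^2]) + eps
    Torch::Optim::RMSprop.new(model.parameters, lr: 1e-2, alpha: 0.99)

    # centered variant with momentum: avg = sqrt(E[g^2] - E[g]^2) + eps, update via the momentum buffer
    Torch::Optim::RMSprop.new(model.parameters, lr: 1e-2, alpha: 0.99, momentum: 0.9, centered: true)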
data/lib/torch/optim/rprop.rb
@@ -0,0 +1,68 @@
+ # ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/rprop.py
+ module Torch
+   module Optim
+     class Rprop < Optimizer
+       def initialize(params, lr: 1e-2, etas: [0.5, 1.2], step_sizes: [1e-6, 50])
+         raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
+         raise ArgumentError, "Invalid eta values: #{etas[0]}, #{etas[1]}" if etas[0] < 0 || etas[0] >= 1 || etas[1] < 1
+
+         defaults = {lr: lr, etas: etas, step_sizes: step_sizes}
+         super(params, defaults)
+       end
+
+       def step(closure = nil)
+         # TODO implement []=
+         raise NotImplementedYet
+
+         loss = nil
+         if closure
+           loss = closure.call
+         end
+
+         @param_groups.each do |group|
+           group[:params].each do |p|
+             next unless p.grad
+             grad = p.grad.data
+             if grad.sparse?
+               raise Error, "Rprop does not support sparse gradients"
+             end
+             state = @state[p]
+
+             # State initialization
+             if state.size == 0
+               state[:step] = 0
+               state[:prev] = Torch.zeros_like(p.data)
+               state[:step_size] = grad.new.resize_as!(grad).fill!(group[:lr])
+             end
+
+             etaminus, etaplus = group[:etas]
+             step_size_min, step_size_max = group[:step_sizes]
+             step_size = state[:step_size]
+
+             state[:step] += 1
+
+             sign = grad.mul(state[:prev]).sign
+             sign[sign.gt(0)] = etaplus
+             sign[sign.lt(0)] = etaminus
+             sign[sign.eq(0)] = 1
+
+             # update stepsizes with step size updates
+             step_size.mul!(sign).clamp!(step_size_min, step_size_max)
+
+             # for dir<0, dfdx=0
+             # for dir>=0 dfdx=dfdx
+             grad = grad.clone
+             grad[sign.eq(etaminus)] = 0
+
+             # update parameters
+             p.data.addcmul!(-1, grad.sign, step_size)
+
+             state[:prev].copy!(grad)
+           end
+         end
+
+         loss
+       end
+     end
+   end
+ end
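Note that Rprop#step raises `NotImplementedYet` before doing any work in this release (the `# TODO implement []=` comment refers to the tensor index assignment used for the sign masking below it), so the class can be constructed but not yet used:

    opt = Torch::Optim::Rprop.new(model.parameters, lr: 1e-2, etas: [0.5, 1.2])
    opt.step   # raises NotImplementedYet as of 0.1.7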