torch-rb 0.1.2 → 0.1.7
This diff covers the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as published in the registry.
- checksums.yaml +4 -4
- data/CHANGELOG.md +35 -0
- data/LICENSE.txt +46 -22
- data/README.md +18 -6
- data/ext/torch/ext.cpp +148 -369
- data/ext/torch/extconf.rb +6 -0
- data/ext/torch/nn_functions.cpp +615 -0
- data/ext/torch/nn_functions.hpp +6 -0
- data/ext/torch/templates.cpp +55 -0
- data/ext/torch/templates.hpp +242 -0
- data/ext/torch/tensor_functions.cpp +1920 -0
- data/ext/torch/tensor_functions.hpp +6 -0
- data/ext/torch/torch_functions.cpp +2975 -0
- data/ext/torch/torch_functions.hpp +6 -0
- data/lib/torch.rb +240 -131
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +27 -22
- data/lib/torch/native/dispatcher.rb +48 -0
- data/lib/torch/native/function.rb +109 -0
- data/lib/torch/native/generator.rb +168 -0
- data/lib/torch/native/native_functions.yaml +6837 -0
- data/lib/torch/native/parser.rb +134 -0
- data/lib/torch/nn/alpha_dropout.rb +9 -0
- data/lib/torch/nn/avg_pool1d.rb +18 -0
- data/lib/torch/nn/avg_pool2d.rb +19 -0
- data/lib/torch/nn/avg_pool3d.rb +19 -0
- data/lib/torch/nn/avg_poolnd.rb +9 -0
- data/lib/torch/nn/batch_norm.rb +75 -0
- data/lib/torch/nn/batch_norm1d.rb +11 -0
- data/lib/torch/nn/batch_norm2d.rb +11 -0
- data/lib/torch/nn/batch_norm3d.rb +11 -0
- data/lib/torch/nn/bce_loss.rb +13 -0
- data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
- data/lib/torch/nn/bilinear.rb +38 -0
- data/lib/torch/nn/constant_pad1d.rb +10 -0
- data/lib/torch/nn/constant_pad2d.rb +10 -0
- data/lib/torch/nn/constant_pad3d.rb +10 -0
- data/lib/torch/nn/constant_padnd.rb +18 -0
- data/lib/torch/nn/conv1d.rb +22 -0
- data/lib/torch/nn/conv2d.rb +16 -38
- data/lib/torch/nn/conv3d.rb +22 -0
- data/lib/torch/nn/convnd.rb +41 -0
- data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
- data/lib/torch/nn/cosine_similarity.rb +15 -0
- data/lib/torch/nn/cross_entropy_loss.rb +14 -0
- data/lib/torch/nn/ctc_loss.rb +15 -0
- data/lib/torch/nn/dropout.rb +9 -0
- data/lib/torch/nn/dropout2d.rb +9 -0
- data/lib/torch/nn/dropout3d.rb +9 -0
- data/lib/torch/nn/dropoutnd.rb +15 -0
- data/lib/torch/nn/embedding.rb +52 -0
- data/lib/torch/nn/embedding_bag.rb +34 -0
- data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
- data/lib/torch/nn/fold.rb +20 -0
- data/lib/torch/nn/functional.rb +411 -22
- data/lib/torch/nn/group_norm.rb +36 -0
- data/lib/torch/nn/gru.rb +49 -0
- data/lib/torch/nn/hardshrink.rb +18 -0
- data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
- data/lib/torch/nn/identity.rb +14 -0
- data/lib/torch/nn/init.rb +58 -1
- data/lib/torch/nn/instance_norm.rb +20 -0
- data/lib/torch/nn/instance_norm1d.rb +18 -0
- data/lib/torch/nn/instance_norm2d.rb +11 -0
- data/lib/torch/nn/instance_norm3d.rb +11 -0
- data/lib/torch/nn/kl_div_loss.rb +13 -0
- data/lib/torch/nn/l1_loss.rb +13 -0
- data/lib/torch/nn/layer_norm.rb +35 -0
- data/lib/torch/nn/leaky_relu.rb +20 -0
- data/lib/torch/nn/linear.rb +12 -11
- data/lib/torch/nn/local_response_norm.rb +21 -0
- data/lib/torch/nn/log_sigmoid.rb +9 -0
- data/lib/torch/nn/log_softmax.rb +14 -0
- data/lib/torch/nn/loss.rb +10 -0
- data/lib/torch/nn/lp_pool1d.rb +9 -0
- data/lib/torch/nn/lp_pool2d.rb +9 -0
- data/lib/torch/nn/lp_poolnd.rb +22 -0
- data/lib/torch/nn/lstm.rb +66 -0
- data/lib/torch/nn/margin_ranking_loss.rb +14 -0
- data/lib/torch/nn/max_pool1d.rb +9 -0
- data/lib/torch/nn/max_pool2d.rb +9 -0
- data/lib/torch/nn/max_pool3d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +19 -0
- data/lib/torch/nn/max_unpool1d.rb +16 -0
- data/lib/torch/nn/max_unpool2d.rb +16 -0
- data/lib/torch/nn/max_unpool3d.rb +16 -0
- data/lib/torch/nn/max_unpoolnd.rb +9 -0
- data/lib/torch/nn/module.rb +201 -20
- data/lib/torch/nn/mse_loss.rb +2 -2
- data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_margin_loss.rb +17 -0
- data/lib/torch/nn/nll_loss.rb +14 -0
- data/lib/torch/nn/pairwise_distance.rb +16 -0
- data/lib/torch/nn/parameter.rb +2 -2
- data/lib/torch/nn/poisson_nll_loss.rb +16 -0
- data/lib/torch/nn/prelu.rb +19 -0
- data/lib/torch/nn/reflection_pad1d.rb +10 -0
- data/lib/torch/nn/reflection_pad2d.rb +10 -0
- data/lib/torch/nn/reflection_padnd.rb +13 -0
- data/lib/torch/nn/relu.rb +8 -3
- data/lib/torch/nn/replication_pad1d.rb +10 -0
- data/lib/torch/nn/replication_pad2d.rb +10 -0
- data/lib/torch/nn/replication_pad3d.rb +10 -0
- data/lib/torch/nn/replication_padnd.rb +13 -0
- data/lib/torch/nn/rnn.rb +22 -0
- data/lib/torch/nn/rnn_base.rb +198 -0
- data/lib/torch/nn/sequential.rb +1 -10
- data/lib/torch/nn/sigmoid.rb +9 -0
- data/lib/torch/nn/smooth_l1_loss.rb +13 -0
- data/lib/torch/nn/soft_margin_loss.rb +13 -0
- data/lib/torch/nn/softmax.rb +18 -0
- data/lib/torch/nn/softmax2d.rb +10 -0
- data/lib/torch/nn/softmin.rb +14 -0
- data/lib/torch/nn/softplus.rb +19 -0
- data/lib/torch/nn/softshrink.rb +18 -0
- data/lib/torch/nn/softsign.rb +9 -0
- data/lib/torch/nn/tanh.rb +9 -0
- data/lib/torch/nn/tanhshrink.rb +9 -0
- data/lib/torch/nn/triplet_margin_loss.rb +18 -0
- data/lib/torch/nn/unfold.rb +19 -0
- data/lib/torch/nn/utils.rb +25 -0
- data/lib/torch/nn/weighted_loss.rb +10 -0
- data/lib/torch/nn/zero_pad2d.rb +9 -0
- data/lib/torch/optim/adadelta.rb +57 -0
- data/lib/torch/optim/adagrad.rb +71 -0
- data/lib/torch/optim/adam.rb +81 -0
- data/lib/torch/optim/adamax.rb +68 -0
- data/lib/torch/optim/adamw.rb +82 -0
- data/lib/torch/optim/asgd.rb +65 -0
- data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
- data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
- data/lib/torch/optim/optimizer.rb +56 -0
- data/lib/torch/optim/rmsprop.rb +76 -0
- data/lib/torch/optim/rprop.rb +68 -0
- data/lib/torch/optim/sgd.rb +48 -16
- data/lib/torch/random.rb +10 -0
- data/lib/torch/tensor.rb +71 -30
- data/lib/torch/utils/data/data_loader.rb +10 -4
- data/lib/torch/utils/data/tensor_dataset.rb +3 -0
- data/lib/torch/version.rb +1 -1
- metadata +123 -6
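
The bulk of the new files flesh out the `Torch::NN` and `Torch::Optim` namespaces. As a rough illustration of the kind of API they add (a sketch following the project README conventions, not code taken from the diff — `TinyNet`, the layer sizes, and the input shape are made up):

```ruby
require "torch"

# Hypothetical two-layer classifier built from newly added NN modules.
class TinyNet < Torch::NN::Module
  def initialize
    super()
    @fc1 = Torch::NN::Linear.new(784, 128)
    @relu = Torch::NN::ReLU.new
    @drop = Torch::NN::Dropout.new(p: 0.2)
    @fc2 = Torch::NN::Linear.new(128, 10)
  end

  def forward(x)
    # Sub-layers are invoked with #call, which dispatches to their forward methods.
    @fc2.call(@drop.call(@relu.call(@fc1.call(x))))
  end
end

net = TinyNet.new
out = net.call(Torch.rand(32, 784)) # => tensor of shape [32, 10]
```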
data/lib/torch/optim/adamax.rb
@@ -0,0 +1,68 @@
+# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/adamax.py
+module Torch
+  module Optim
+    class Adamax < Optimizer
+      def initialize(params, lr: 2e-3, betas: [0.9, 0.999], eps: 1e-8, weight_decay: 0)
+        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
+        raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0
+        raise ArgumentError, "Invalid beta parameter at index 0: #{betas[0]}" if betas[0] < 0 || betas[0] >= 1
+        raise ArgumentError, "Invalid beta parameter at index 1: #{betas[1]}" if betas[1] < 0 || betas[1] >= 1
+        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0
+
+        defaults = {lr: lr, betas: betas, eps: eps, weight_decay: weight_decay}
+        super(params, defaults)
+      end
+
+      def step(closure = nil)
+        loss = nil
+        if closure
+          loss = closure.call
+        end
+
+        @param_groups.each do |group|
+          group[:params].each do |p|
+            next unless p.grad
+            grad = p.grad.data
+            if grad.sparse?
+              raise Error, "Adamax does not support sparse gradients, please consider SparseAdam instead"
+            end
+            state = @state[p]
+
+            # State initialization
+            if state.size == 0
+              state[:step] = 0
+              state[:exp_avg] = Torch.zeros_like(p.data)
+              state[:exp_inf] = Torch.zeros_like(p.data)
+            end
+
+            exp_avg, exp_inf = state[:exp_avg], state[:exp_inf]
+            beta1, beta2 = group[:betas]
+            eps = group[:eps]
+
+            state[:step] += 1
+
+            if group[:weight_decay] != 0
+              grad = grad.add(group[:weight_decay], p.data)
+            end
+
+            # Update biased first moment estimate.
+            exp_avg.mul!(beta1).add!(1 - beta1, grad)
+            # Update the exponentially weighted infinity norm.
+            norm_buf = Torch.cat([
+              exp_inf.mul!(beta2).unsqueeze(0),
+              grad.abs.add!(eps).unsqueeze!(0)
+            ], 0)
+            Torch.max(norm_buf, 0, keepdim: false, out: [exp_inf, exp_inf.new.long])
+
+            bias_correction = 1 - beta1 ** state[:step]
+            clr = group[:lr] / bias_correction
+
+            p.data.addcdiv!(-clr, exp_avg, exp_inf)
+          end
+        end
+
+        loss
+      end
+    end
+  end
+end
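
All of the new optimizers share the `Optimizer` interface (`zero_grad`/`step`), so wiring Adamax into a training loop might look like the sketch below. This is illustrative only: `net`, `x`, and `y` are stand-ins, not part of the diff.

```ruby
# Hypothetical training loop; net, x, and y are placeholders.
optimizer = Torch::Optim::Adamax.new(net.parameters, lr: 2e-3)
criterion = Torch::NN::MSELoss.new

100.times do
  optimizer.zero_grad                       # clear gradients on all parameters
  loss = criterion.call(net.call(x), y)
  loss.backward                             # populate p.grad for each parameter
  optimizer.step                            # apply the Adamax update shown above
end
```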
data/lib/torch/optim/adamw.rb
@@ -0,0 +1,82 @@
+# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/adamw.py
+module Torch
+  module Optim
+    class AdamW < Optimizer
+      def initialize(params, lr: 1e-3, betas: [0.9, 0.999], eps: 1e-8, weight_decay: 1e-2, amsgrad: false)
+        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
+        raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0
+        raise ArgumentError, "Invalid beta parameter at index 0: #{betas[0]}" if betas[0] < 0 || betas[0] >= 1
+        raise ArgumentError, "Invalid beta parameter at index 1: #{betas[1]}" if betas[1] < 0 || betas[1] >= 1
+
+        defaults = {lr: lr, betas: betas, eps: eps, weight_decay: weight_decay, amsgrad: amsgrad}
+        super(params, defaults)
+      end
+
+      def step(closure = nil)
+        loss = nil
+        if closure
+          loss = closure.call
+        end
+
+        @param_groups.each do |group|
+          group[:params].each do |p|
+            next unless p.grad
+
+            # Perform stepweight decay
+            p.data.mul!(1 - group[:lr] * group[:weight_decay])
+
+            # Perform optimization step
+            grad = p.grad.data
+            if grad.sparse?
+              raise Error, "AdamW does not support sparse gradients, please consider SparseAdam instead"
+            end
+            amsgrad = group[:amsgrad]
+
+            state = @state[p]
+
+            # State initialization
+            if state.size == 0
+              state[:step] = 0
+              # Exponential moving average of gradient values
+              state[:exp_avg] = Torch.zeros_like(p.data)
+              # Exponential moving average of squared gradient values
+              state[:exp_avg_sq] = Torch.zeros_like(p.data)
+              if amsgrad
+                # Maintains max of all exp. moving avg. of sq. grad. values
+                state[:max_exp_avg_sq] = Torch.zeros_like(p.data)
+              end
+            end
+
+            exp_avg, exp_avg_sq = state[:exp_avg], state[:exp_avg_sq]
+            if amsgrad
+              max_exp_avg_sq = state[:max_exp_avg_sq]
+            end
+            beta1, beta2 = group[:betas]
+
+            state[:step] += 1
+            bias_correction1 = 1 - beta1 ** state[:step]
+            bias_correction2 = 1 - beta2 ** state[:step]
+
+            # Decay the first and second moment running average coefficient
+            exp_avg.mul!(beta1).add!(1 - beta1, grad)
+            exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
+            if amsgrad
+              # Maintains the maximum of all 2nd moment running avg. till now
+              Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
+              # Use the max. for normalizing running avg. of gradient
+              denom = (max_exp_avg_sq.sqrt / Math.sqrt(bias_correction2)).add!(group[:eps])
+            else
+              denom = (exp_avg_sq.sqrt / Math.sqrt(bias_correction2)).add!(group[:eps])
+            end
+
+            step_size = group[:lr] / bias_correction1
+
+            p.data.addcdiv!(-step_size, exp_avg, denom)
+          end
+        end
+
+        loss
+      end
+    end
+  end
+end
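
Unlike Adam, AdamW applies weight decay directly to the weights (`p.data.mul!(1 - lr * weight_decay)`) rather than folding it into the gradient, and `amsgrad: true` normalizes by the running maximum of the second moment. A hedged construction example (`model` is a placeholder):

```ruby
# Hypothetical usage; model is a stand-in.
optimizer = Torch::Optim::AdamW.new(
  model.parameters,
  lr: 1e-3,
  weight_decay: 1e-2,  # decoupled decay, applied before the Adam-style update
  amsgrad: true        # keep max_exp_avg_sq and divide by its square root
)
```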
data/lib/torch/optim/asgd.rb
@@ -0,0 +1,65 @@
+# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/asgd.py
+module Torch
+  module Optim
+    class ASGD < Optimizer
+      def initialize(params, lr: 1e-2, lambd: 1e-4, alpha: 0.75, t0: 1e6, weight_decay: 0)
+        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
+        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0
+
+        defaults = {lr: lr, lambd: lambd, alpha: alpha, t0: t0, weight_decay: weight_decay}
+        super(params, defaults)
+      end
+
+      def step(closure = nil)
+        loss = nil
+        if closure
+          loss = closure.call
+        end
+
+        @param_groups.each do |group|
+          group[:params].each do |p|
+            next unless p.grad
+            grad = p.grad.data
+            if grad.sparse?
+              raise Error, "ASGD does not support sparse gradients"
+            end
+            state = @state[p]
+
+            # State initialization
+            if state.size == 0
+              state[:step] = 0
+              state[:eta] = group[:lr]
+              state[:mu] = 1
+              state[:ax] = Torch.zeros_like(p.data)
+            end
+
+            state[:step] += 1
+
+            if group[:weight_decay] != 0
+              grad = grad.add(group[:weight_decay], p.data)
+            end
+
+            # decay term
+            p.data.mul!(1 - group[:lambd] * state[:eta])
+
+            # update parameter
+            p.data.add!(-state[:eta], grad)
+
+            # averaging
+            if state[:mu] != 1
+              state[:ax].add!(p.data.sub(state[:ax]).mul(state[:mu]))
+            else
+              state[:ax].copy!(p.data)
+            end
+
+            # update eta and mu
+            state[:eta] = (group[:lr] / ((1 + group[:lambd] * group[:lr] * state[:step]) ** group[:alpha]))
+            state[:mu] = 1 / [1, state[:step] - group[:t0]].max
+          end
+        end
+
+        loss
+      end
+    end
+  end
+end
data/lib/torch/optim/lr_scheduler/lr_scheduler.rb
@@ -0,0 +1,33 @@
+module Torch
+  module Optim
+    module LRScheduler
+      class LRScheduler
+        def initialize(optimizer, last_epoch)
+          @optimizer = optimizer
+          if last_epoch == -1
+            optimizer.param_groups.each do |group|
+              group[:initial_lr] ||= group[:lr]
+            end
+            last_epoch = 0
+          else
+            raise NotImplementedYet
+          end
+          @base_lrs = optimizer.param_groups.map { |group| group[:initial_lr] }
+          @last_epoch = last_epoch
+
+          @step_count = 0
+          step(last_epoch)
+        end
+
+        def step(epoch = nil)
+          @step_count += 1
+          epoch ||= @last_epoch + 1
+          @last_epoch = epoch
+          @optimizer.param_groups.zip(get_lr).each do |param_group, lr|
+            param_group[:lr] = lr
+          end
+        end
+      end
+    end
+  end
+end
data/lib/torch/optim/lr_scheduler/step_lr.rb
@@ -0,0 +1,17 @@
+module Torch
+  module Optim
+    module LRScheduler
+      class StepLR < LRScheduler
+        def initialize(optimizer, step_size:, gamma: 0.1, last_epoch: -1)
+          @step_size = step_size
+          @gamma = gamma
+          super(optimizer, last_epoch)
+        end
+
+        def get_lr
+          @base_lrs.map { |base_lr| base_lr * @gamma ** (@last_epoch / @step_size).floor }
+        end
+      end
+    end
+  end
+end
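
`StepLR` scales each group's base learning rate by `gamma ** (epoch / step_size)`, i.e. it drops the rate by a factor of `gamma` every `step_size` epochs. A sketch of pairing it with one of the optimizers in this release; `train_one_epoch` and the epoch count are hypothetical:

```ruby
optimizer = Torch::Optim::SGD.new(net.parameters, lr: 0.1)
scheduler = Torch::Optim::LRScheduler::StepLR.new(optimizer, step_size: 30, gamma: 0.1)

90.times do |epoch|
  train_one_epoch(net, optimizer)  # hypothetical helper
  scheduler.step                   # roughly: epochs 0-29 -> 0.1, 30-59 -> 0.01, 60-89 -> 0.001
end
```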
data/lib/torch/optim/optimizer.rb
@@ -1,6 +1,62 @@
+# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/optimizer.py
 module Torch
   module Optim
     class Optimizer
+      attr_reader :param_groups
+
+      def initialize(params, defaults)
+        @defaults = defaults
+        @state = Hash.new { |hash, key| hash[key] = {} }
+        @param_groups = []
+
+        param_groups = params
+        if param_groups.empty?
+          raise ArgumentError, "optimizer got an empty parameter list"
+        end
+        if !param_groups[0].is_a?(Hash)
+          param_groups = [{params: param_groups}]
+        end
+
+        param_groups.each do |param_group|
+          add_param_group(param_group)
+        end
+      end
+
+      def add_param_group(param_group)
+        # TODO more advanced logic
+        @param_groups << @defaults.merge(param_group)
+      end
+
+      def load_state_dict(state_dict)
+        raise NotImplementedYet
+      end
+
+      def state_dict
+        pack_group = lambda do |group|
+          packed = group.select { |k, _| k != :params }.to_h
+          packed[:params] = group[:params].map { |p| p.object_id }
+          packed
+        end
+
+        param_groups = @param_groups.map { |g| pack_group.call(g) }
+        packed_state = @state.map { |k, v| [k.is_a?(Tensor) ? k.object_id : k, v] }.to_h
+
+        {
+          state: packed_state,
+          param_groups: param_groups
+        }
+      end
+
+      def zero_grad
+        @param_groups.each do |group|
+          group[:params].each do |p|
+            if p.grad
+              p.grad.detach!
+              p.grad.zero!
+            end
+          end
+        end
+      end
     end
   end
 end
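
Because `add_param_group` merges each group over the defaults hash, `params` can also be an array of hashes rather than a flat parameter list, giving each group its own hyperparameters. A sketch under that assumption (`model.base` and `model.head` are made-up attributes):

```ruby
# Two parameter groups: the classifier head gets a larger learning rate.
optimizer = Torch::Optim::Adam.new([
  {params: model.base.parameters},           # uses the default lr: 1e-3
  {params: model.head.parameters, lr: 1e-2}  # overrides lr for this group only
])

optimizer.param_groups.map { |g| g[:lr] }  # => [0.001, 0.01]
```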
data/lib/torch/optim/rmsprop.rb
@@ -0,0 +1,76 @@
+# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/rmsprop.py
+module Torch
+  module Optim
+    class RMSprop < Optimizer
+      def initialize(params, lr: 1e-2, alpha: 0.99, eps: 1e-8, weight_decay: 0, momentum: 0, centered: false)
+        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
+        raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0
+        raise ArgumentError, "Invalid momentum value: #{momentum}" if momentum < 0
+        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0
+        raise ArgumentError, "Invalid momentum alpha: #{alpha}" if alpha < 0
+
+        defaults = {lr: lr, momentum: momentum, alpha: alpha, eps: eps, centered: centered, weight_decay: weight_decay}
+        super(params, defaults)
+      end
+
+      def step(closure = nil)
+        loss = nil
+        if closure
+          loss = closure.call
+        end
+
+        @param_groups.each do |group|
+          group[:params].each do |p|
+            next unless p.grad
+            grad = p.grad.data
+            if grad.sparse?
+              raise Error, "RMSprop does not support sparse gradients"
+            end
+            state = @state[p]
+
+            # State initialization
+            if state.size == 0
+              state[:step] = 0
+              state[:square_avg] = Torch.zeros_like(p.data)
+              if group[:momentum] > 0
+                state[:momentum_buffer] = Torch.zeros_like(p.data)
+              end
+              if group[:centered]
+                state[:grad_avg] = Torch.zeros_like(p.data)
+              end
+            end
+
+            square_avg = state[:square_avg]
+            alpha = group[:alpha]
+
+            state[:step] += 1
+
+            if group[:weight_decay] != 0
+              grad = grad.add(group[:weight_decay], p.data)
+            end
+
+            square_avg.mul!(alpha).addcmul!(1 - alpha, grad, grad)
+
+            if group[:centered]
+              grad_avg = state[:grad_avg]
+              grad_avg.mul!(alpha).add!(1 - alpha, grad)
+              avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt!.add!(group[:eps])
+            else
+              avg = square_avg.sqrt.add!(group[:eps])
+            end
+
+            if group[:momentum] > 0
+              buf = state[:momentum_buffer]
+              buf.mul!(group[:momentum]).addcdiv!(grad, avg)
+              p.data.add!(-group[:lr], buf)
+            else
+              p.data.addcdiv!(-group[:lr], grad, avg)
+            end
+          end
+        end
+
+        loss
+      end
+    end
+  end
+end
data/lib/torch/optim/rprop.rb
@@ -0,0 +1,68 @@
+# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/rprop.py
+module Torch
+  module Optim
+    class Rprop < Optimizer
+      def initialize(params, lr: 1e-2, etas: [0.5, 1.2], step_sizes: [1e-6, 50])
+        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
+        raise ArgumentError, "Invalid eta values: #{etas[0]}, #{etas[1]}" if etas[0] < 0 || etas[0] >= 1 || etas[1] < 1
+
+        defaults = {lr: lr, etas: etas, step_sizes: step_sizes}
+        super(params, defaults)
+      end
+
+      def step(closure = nil)
+        # TODO implement []=
+        raise NotImplementedYet
+
+        loss = nil
+        if closure
+          loss = closure.call
+        end
+
+        @param_groups.each do |group|
+          group[:params].each do |p|
+            next unless p.grad
+            grad = p.grad.data
+            if grad.sparse?
+              raise Error, "Rprop does not support sparse gradients"
+            end
+            state = @state[p]
+
+            # State initialization
+            if state.size == 0
+              state[:step] = 0
+              state[:prev] = Torch.zeros_like(p.data)
+              state[:step_size] = grad.new.resize_as!(grad).fill!(group[:lr])
+            end
+
+            etaminus, etaplus = group[:etas]
+            step_size_min, step_size_max = group[:step_sizes]
+            step_size = state[:step_size]
+
+            state[:step] += 1
+
+            sign = grad.mul(state[:prev]).sign
+            sign[sign.gt(0)] = etaplus
+            sign[sign.lt(0)] = etaminus
+            sign[sign.eq(0)] = 1
+
+            # update stepsizes with step size updates
+            step_size.mul!(sign).clamp!(step_size_min, step_size_max)
+
+            # for dir<0, dfdx=0
+            # for dir>=0 dfdx=dfdx
+            grad = grad.clone
+            grad[sign.eq(etaminus)] = 0
+
+            # update parameters
+            p.data.addcmul!(-1, grad.sign, step_size)
+
+            state[:prev].copy!(grad)
+          end
+        end
+
+        loss
+      end
+    end
+  end
+end