torch-rb 0.1.2 → 0.1.7
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +35 -0
- data/LICENSE.txt +46 -22
- data/README.md +18 -6
- data/ext/torch/ext.cpp +148 -369
- data/ext/torch/extconf.rb +6 -0
- data/ext/torch/nn_functions.cpp +615 -0
- data/ext/torch/nn_functions.hpp +6 -0
- data/ext/torch/templates.cpp +55 -0
- data/ext/torch/templates.hpp +242 -0
- data/ext/torch/tensor_functions.cpp +1920 -0
- data/ext/torch/tensor_functions.hpp +6 -0
- data/ext/torch/torch_functions.cpp +2975 -0
- data/ext/torch/torch_functions.hpp +6 -0
- data/lib/torch.rb +240 -131
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +27 -22
- data/lib/torch/native/dispatcher.rb +48 -0
- data/lib/torch/native/function.rb +109 -0
- data/lib/torch/native/generator.rb +168 -0
- data/lib/torch/native/native_functions.yaml +6837 -0
- data/lib/torch/native/parser.rb +134 -0
- data/lib/torch/nn/alpha_dropout.rb +9 -0
- data/lib/torch/nn/avg_pool1d.rb +18 -0
- data/lib/torch/nn/avg_pool2d.rb +19 -0
- data/lib/torch/nn/avg_pool3d.rb +19 -0
- data/lib/torch/nn/avg_poolnd.rb +9 -0
- data/lib/torch/nn/batch_norm.rb +75 -0
- data/lib/torch/nn/batch_norm1d.rb +11 -0
- data/lib/torch/nn/batch_norm2d.rb +11 -0
- data/lib/torch/nn/batch_norm3d.rb +11 -0
- data/lib/torch/nn/bce_loss.rb +13 -0
- data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
- data/lib/torch/nn/bilinear.rb +38 -0
- data/lib/torch/nn/constant_pad1d.rb +10 -0
- data/lib/torch/nn/constant_pad2d.rb +10 -0
- data/lib/torch/nn/constant_pad3d.rb +10 -0
- data/lib/torch/nn/constant_padnd.rb +18 -0
- data/lib/torch/nn/conv1d.rb +22 -0
- data/lib/torch/nn/conv2d.rb +16 -38
- data/lib/torch/nn/conv3d.rb +22 -0
- data/lib/torch/nn/convnd.rb +41 -0
- data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
- data/lib/torch/nn/cosine_similarity.rb +15 -0
- data/lib/torch/nn/cross_entropy_loss.rb +14 -0
- data/lib/torch/nn/ctc_loss.rb +15 -0
- data/lib/torch/nn/dropout.rb +9 -0
- data/lib/torch/nn/dropout2d.rb +9 -0
- data/lib/torch/nn/dropout3d.rb +9 -0
- data/lib/torch/nn/dropoutnd.rb +15 -0
- data/lib/torch/nn/embedding.rb +52 -0
- data/lib/torch/nn/embedding_bag.rb +34 -0
- data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
- data/lib/torch/nn/fold.rb +20 -0
- data/lib/torch/nn/functional.rb +411 -22
- data/lib/torch/nn/group_norm.rb +36 -0
- data/lib/torch/nn/gru.rb +49 -0
- data/lib/torch/nn/hardshrink.rb +18 -0
- data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
- data/lib/torch/nn/identity.rb +14 -0
- data/lib/torch/nn/init.rb +58 -1
- data/lib/torch/nn/instance_norm.rb +20 -0
- data/lib/torch/nn/instance_norm1d.rb +18 -0
- data/lib/torch/nn/instance_norm2d.rb +11 -0
- data/lib/torch/nn/instance_norm3d.rb +11 -0
- data/lib/torch/nn/kl_div_loss.rb +13 -0
- data/lib/torch/nn/l1_loss.rb +13 -0
- data/lib/torch/nn/layer_norm.rb +35 -0
- data/lib/torch/nn/leaky_relu.rb +20 -0
- data/lib/torch/nn/linear.rb +12 -11
- data/lib/torch/nn/local_response_norm.rb +21 -0
- data/lib/torch/nn/log_sigmoid.rb +9 -0
- data/lib/torch/nn/log_softmax.rb +14 -0
- data/lib/torch/nn/loss.rb +10 -0
- data/lib/torch/nn/lp_pool1d.rb +9 -0
- data/lib/torch/nn/lp_pool2d.rb +9 -0
- data/lib/torch/nn/lp_poolnd.rb +22 -0
- data/lib/torch/nn/lstm.rb +66 -0
- data/lib/torch/nn/margin_ranking_loss.rb +14 -0
- data/lib/torch/nn/max_pool1d.rb +9 -0
- data/lib/torch/nn/max_pool2d.rb +9 -0
- data/lib/torch/nn/max_pool3d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +19 -0
- data/lib/torch/nn/max_unpool1d.rb +16 -0
- data/lib/torch/nn/max_unpool2d.rb +16 -0
- data/lib/torch/nn/max_unpool3d.rb +16 -0
- data/lib/torch/nn/max_unpoolnd.rb +9 -0
- data/lib/torch/nn/module.rb +201 -20
- data/lib/torch/nn/mse_loss.rb +2 -2
- data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_margin_loss.rb +17 -0
- data/lib/torch/nn/nll_loss.rb +14 -0
- data/lib/torch/nn/pairwise_distance.rb +16 -0
- data/lib/torch/nn/parameter.rb +2 -2
- data/lib/torch/nn/poisson_nll_loss.rb +16 -0
- data/lib/torch/nn/prelu.rb +19 -0
- data/lib/torch/nn/reflection_pad1d.rb +10 -0
- data/lib/torch/nn/reflection_pad2d.rb +10 -0
- data/lib/torch/nn/reflection_padnd.rb +13 -0
- data/lib/torch/nn/relu.rb +8 -3
- data/lib/torch/nn/replication_pad1d.rb +10 -0
- data/lib/torch/nn/replication_pad2d.rb +10 -0
- data/lib/torch/nn/replication_pad3d.rb +10 -0
- data/lib/torch/nn/replication_padnd.rb +13 -0
- data/lib/torch/nn/rnn.rb +22 -0
- data/lib/torch/nn/rnn_base.rb +198 -0
- data/lib/torch/nn/sequential.rb +1 -10
- data/lib/torch/nn/sigmoid.rb +9 -0
- data/lib/torch/nn/smooth_l1_loss.rb +13 -0
- data/lib/torch/nn/soft_margin_loss.rb +13 -0
- data/lib/torch/nn/softmax.rb +18 -0
- data/lib/torch/nn/softmax2d.rb +10 -0
- data/lib/torch/nn/softmin.rb +14 -0
- data/lib/torch/nn/softplus.rb +19 -0
- data/lib/torch/nn/softshrink.rb +18 -0
- data/lib/torch/nn/softsign.rb +9 -0
- data/lib/torch/nn/tanh.rb +9 -0
- data/lib/torch/nn/tanhshrink.rb +9 -0
- data/lib/torch/nn/triplet_margin_loss.rb +18 -0
- data/lib/torch/nn/unfold.rb +19 -0
- data/lib/torch/nn/utils.rb +25 -0
- data/lib/torch/nn/weighted_loss.rb +10 -0
- data/lib/torch/nn/zero_pad2d.rb +9 -0
- data/lib/torch/optim/adadelta.rb +57 -0
- data/lib/torch/optim/adagrad.rb +71 -0
- data/lib/torch/optim/adam.rb +81 -0
- data/lib/torch/optim/adamax.rb +68 -0
- data/lib/torch/optim/adamw.rb +82 -0
- data/lib/torch/optim/asgd.rb +65 -0
- data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
- data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
- data/lib/torch/optim/optimizer.rb +56 -0
- data/lib/torch/optim/rmsprop.rb +76 -0
- data/lib/torch/optim/rprop.rb +68 -0
- data/lib/torch/optim/sgd.rb +48 -16
- data/lib/torch/random.rb +10 -0
- data/lib/torch/tensor.rb +71 -30
- data/lib/torch/utils/data/data_loader.rb +10 -4
- data/lib/torch/utils/data/tensor_dataset.rb +3 -0
- data/lib/torch/version.rb +1 -1
- metadata +123 -6
@@ -0,0 +1,68 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/adamax.py
|
2
|
+
module Torch
  module Optim
    # Adamax optimizer (ported from PyTorch's torch/optim/adamax.py).
    #
    # For every parameter it keeps a step counter, a running average of the
    # gradient (:exp_avg) and an exponentially weighted infinity norm
    # (:exp_inf); the parameter is moved along exp_avg / exp_inf.
    class Adamax < Optimizer
      # @param params parameters (or parameter-group hashes) handed to Optimizer
      # @param lr [Numeric] learning rate, must not be negative
      # @param betas [Array<Numeric>] two decay coefficients, each in [0, 1)
      # @param eps [Numeric] value added to the gradient magnitude, must not be negative
      # @param weight_decay [Numeric] factor applied to the parameter and added to the gradient, must not be negative
      # @raise [ArgumentError] if a hyperparameter is out of range
      def initialize(params, lr: 2e-3, betas: [0.9, 0.999], eps: 1e-8, weight_decay: 0)
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
        raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0
        2.times do |index|
          beta = betas[index]
          raise ArgumentError, "Invalid beta parameter at index #{index}: #{beta}" if beta < 0 || beta >= 1
        end
        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0

        super(params, {lr: lr, betas: betas, eps: eps, weight_decay: weight_decay})
      end

      # Performs one optimization step over every parameter that has a gradient.
      #
      # @param closure [#call, nil] optional callable evaluated once before the update
      # @return the closure's result, or nil when no closure is given
      # @raise [Error] if a gradient is sparse
      def step(closure = nil)
        loss = closure ? closure.call : nil

        @param_groups.each do |group|
          group[:params].each do |param|
            next unless param.grad

            gradient = param.grad.data
            if gradient.sparse?
              raise Error, "Adamax does not support sparse gradients, please consider SparseAdam instead"
            end

            param_state = @state[param]

            # Lazily create the per-parameter buffers on first use.
            if param_state.empty?
              param_state[:step] = 0
              param_state[:exp_avg] = Torch.zeros_like(param.data)
              param_state[:exp_inf] = Torch.zeros_like(param.data)
            end

            first_moment = param_state[:exp_avg]
            inf_norm = param_state[:exp_inf]
            beta1, beta2 = group[:betas]
            epsilon = group[:eps]

            param_state[:step] += 1

            decay = group[:weight_decay]
            gradient = gradient.add(decay, param.data) unless decay == 0

            # Biased first moment estimate, updated in place.
            first_moment.mul!(beta1).add!(1 - beta1, gradient)

            # Infinity norm: element-wise max of the decayed previous norm and
            # |grad| + eps, written back into the :exp_inf buffer via `out:`.
            decayed_norm = inf_norm.mul!(beta2).unsqueeze(0)
            grad_magnitude = gradient.abs.add!(epsilon).unsqueeze!(0)
            stacked = Torch.cat([decayed_norm, grad_magnitude], 0)
            Torch.max(stacked, 0, keepdim: false, out: [inf_norm, inf_norm.new.long])

            bias_correction = 1 - beta1 ** param_state[:step]
            corrected_lr = group[:lr] / bias_correction

            param.data.addcdiv!(-corrected_lr, first_moment, inf_norm)
          end
        end

        loss
      end
    end
  end
end
|
@@ -0,0 +1,82 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/adamw.py
|
2
|
+
module Torch
  module Optim
    # AdamW optimizer (ported from PyTorch's torch/optim/adamw.py).
    #
    # Each step first shrinks the parameter by the factor
    # (1 - lr * weight_decay) and then applies an Adam-style update built from
    # running averages of the gradient and of the squared gradient.
    class AdamW < Optimizer
      # @param params parameters (or parameter-group hashes) handed to Optimizer
      # @param lr [Numeric] learning rate, must not be negative
      # @param betas [Array<Numeric>] two decay coefficients, each in [0, 1)
      # @param eps [Numeric] value added to the denominator, must not be negative
      # @param weight_decay [Numeric] decay coefficient, must not be negative
      # @param amsgrad [Boolean] when truthy, normalize by the running maximum
      #   of the squared-gradient average instead of its current value
      # @raise [ArgumentError] if a hyperparameter is out of range
      def initialize(params, lr: 1e-3, betas: [0.9, 0.999], eps: 1e-8, weight_decay: 1e-2, amsgrad: false)
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
        raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0
        raise ArgumentError, "Invalid beta parameter at index 0: #{betas[0]}" if betas[0] < 0 || betas[0] >= 1
        raise ArgumentError, "Invalid beta parameter at index 1: #{betas[1]}" if betas[1] < 0 || betas[1] >= 1
        # A negative value would make the factor (1 - lr * weight_decay) used
        # in #step larger than 1, growing the parameters instead of decaying them.
        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0

        defaults = {lr: lr, betas: betas, eps: eps, weight_decay: weight_decay, amsgrad: amsgrad}
        super(params, defaults)
      end

      # Performs one optimization step over every parameter that has a gradient.
      #
      # @param closure [#call, nil] optional callable evaluated once before the update
      # @return the closure's result, or nil when no closure is given
      # @raise [Error] if a gradient is sparse
      def step(closure = nil)
        loss = nil
        if closure
          loss = closure.call
        end

        @param_groups.each do |group|
          group[:params].each do |p|
            next unless p.grad

            # Perform step weight decay
            p.data.mul!(1 - group[:lr] * group[:weight_decay])

            # Perform optimization step
            grad = p.grad.data
            if grad.sparse?
              raise Error, "AdamW does not support sparse gradients, please consider SparseAdam instead"
            end
            amsgrad = group[:amsgrad]

            state = @state[p]

            # State initialization
            if state.size == 0
              state[:step] = 0
              # Exponential moving average of gradient values
              state[:exp_avg] = Torch.zeros_like(p.data)
              # Exponential moving average of squared gradient values
              state[:exp_avg_sq] = Torch.zeros_like(p.data)
              if amsgrad
                # Maintains max of all exp. moving avg. of sq. grad. values
                state[:max_exp_avg_sq] = Torch.zeros_like(p.data)
              end
            end

            exp_avg, exp_avg_sq = state[:exp_avg], state[:exp_avg_sq]
            if amsgrad
              max_exp_avg_sq = state[:max_exp_avg_sq]
            end
            beta1, beta2 = group[:betas]

            state[:step] += 1
            bias_correction1 = 1 - beta1 ** state[:step]
            bias_correction2 = 1 - beta2 ** state[:step]

            # Decay the first and second moment running average coefficient
            exp_avg.mul!(beta1).add!(1 - beta1, grad)
            exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
            if amsgrad
              # Maintains the maximum of all 2nd moment running avg. till now
              Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
              # Use the max. for normalizing running avg. of gradient
              denom = (max_exp_avg_sq.sqrt / Math.sqrt(bias_correction2)).add!(group[:eps])
            else
              denom = (exp_avg_sq.sqrt / Math.sqrt(bias_correction2)).add!(group[:eps])
            end

            step_size = group[:lr] / bias_correction1

            p.data.addcdiv!(-step_size, exp_avg, denom)
          end
        end

        loss
      end
    end
  end
end
|
@@ -0,0 +1,65 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/asgd.py
|
2
|
+
module Torch
  module Optim
    # Averaged stochastic gradient descent
    # (ported from PyTorch's torch/optim/asgd.py).
    #
    # Per parameter it tracks a step counter, the current step size (:eta),
    # the averaging weight (:mu) and an averaged copy of the parameter (:ax).
    class ASGD < Optimizer
      # @param params parameters (or parameter-group hashes) handed to Optimizer
      # @param lr [Numeric] initial learning rate, must not be negative
      # @param lambd [Numeric] decay term used both to shrink the parameter and to anneal eta
      # @param alpha [Numeric] exponent used when annealing eta
      # @param t0 [Numeric] step offset subtracted from the step count when computing mu
      # @param weight_decay [Numeric] factor applied to the parameter and added to the gradient, must not be negative
      # @raise [ArgumentError] if lr or weight_decay is negative
      def initialize(params, lr: 1e-2, lambd: 1e-4, alpha: 0.75, t0: 1e6, weight_decay: 0)
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0

        defaults = {lr: lr, lambd: lambd, alpha: alpha, t0: t0, weight_decay: weight_decay}
        super(params, defaults)
      end

      # Performs one optimization step over every parameter that has a gradient.
      #
      # @param closure [#call, nil] optional callable evaluated once before the update
      # @return the closure's result, or nil when no closure is given
      # @raise [Error] if a gradient is sparse
      def step(closure = nil)
        loss = nil
        if closure
          loss = closure.call
        end

        @param_groups.each do |group|
          group[:params].each do |p|
            next unless p.grad
            grad = p.grad.data
            if grad.sparse?
              raise Error, "ASGD does not support sparse gradients"
            end
            state = @state[p]

            # State initialization
            if state.size == 0
              state[:step] = 0
              state[:eta] = group[:lr]
              state[:mu] = 1
              state[:ax] = Torch.zeros_like(p.data)
            end

            state[:step] += 1

            if group[:weight_decay] != 0
              grad = grad.add(group[:weight_decay], p.data)
            end

            # decay term
            p.data.mul!(1 - group[:lambd] * state[:eta])

            # update parameter
            p.data.add!(-state[:eta], grad)

            # averaging
            if state[:mu] != 1
              state[:ax].add!(p.data.sub(state[:ax]).mul(state[:mu]))
            else
              state[:ax].copy!(p.data)
            end

            # update eta and mu
            # Both quotients are forced to float division: with Integer
            # hyperparameters (e.g. t0: 100) Ruby's Integer#/ would truncate
            # mu to 0 and the averaged copy would stop being updated.
            state[:eta] = (group[:lr].to_f / ((1 + group[:lambd] * group[:lr] * state[:step]) ** group[:alpha]))
            state[:mu] = 1.0 / [1, state[:step] - group[:t0]].max
          end
        end

        loss
      end
    end
  end
end
|
@@ -0,0 +1,33 @@
|
|
1
|
+
module Torch
  module Optim
    module LRScheduler
      # Base class for learning-rate schedulers.
      #
      # Subclasses provide #get_lr, which must return one learning rate per
      # optimizer parameter group; #step writes those values into the groups.
      class LRScheduler
        # Records each group's starting learning rate as :initial_lr (unless
        # already present) and performs the first #step at epoch 0.
        #
        # @param optimizer object exposing #param_groups (an array of hashes)
        # @param last_epoch [Integer] only -1 is supported
        # @raise [NotImplementedYet] when last_epoch is anything other than -1
        def initialize(optimizer, last_epoch)
          @optimizer = optimizer
          raise NotImplementedYet unless last_epoch == -1

          optimizer.param_groups.each { |group| group[:initial_lr] ||= group[:lr] }
          @base_lrs = optimizer.param_groups.map { |group| group[:initial_lr] }
          @last_epoch = 0

          @step_count = 0
          step(0)
        end

        # Advances the schedule and stores the new learning rates in the
        # optimizer's parameter groups.
        #
        # @param epoch [Integer, nil] epoch to jump to; defaults to the previous epoch plus one
        def step(epoch = nil)
          @step_count += 1
          @last_epoch = epoch || @last_epoch + 1

          groups_with_lr = @optimizer.param_groups.zip(get_lr)
          groups_with_lr.each { |(group, new_lr)| group[:lr] = new_lr }
        end
      end
    end
  end
end
|
@@ -0,0 +1,17 @@
|
|
1
|
+
module Torch
  module Optim
    module LRScheduler
      # Scheduler that multiplies every base learning rate by gamma raised to
      # the number of completed step_size-epoch intervals.
      class StepLR < LRScheduler
        # @param optimizer optimizer whose parameter groups are rescheduled
        # @param step_size [Numeric] number of epochs per decay interval
        # @param gamma [Numeric] multiplicative factor applied per interval
        # @param last_epoch [Integer] forwarded to the base scheduler
        def initialize(optimizer, step_size:, gamma: 0.1, last_epoch: -1)
          @gamma = gamma
          @step_size = step_size
          super(optimizer, last_epoch)
        end

        # @return [Array] one learning rate per entry of @base_lrs
        def get_lr
          completed_intervals = (@last_epoch / @step_size).floor
          decay = @gamma ** completed_intervals
          @base_lrs.map { |base_lr| base_lr * decay }
        end
      end
    end
  end
end
|
@@ -1,6 +1,62 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/optimizer.py
|
1
2
|
module Torch
  module Optim
    # Base class for optimizers: stores the parameter groups (each merged with
    # the default hyperparameters) and a lazily populated per-parameter state.
    class Optimizer
      attr_reader :param_groups

      # @param params [Array] either a flat list of parameters or a list of
      #   hashes, each with a :params key plus optional per-group overrides
      # @param defaults [Hash] hyperparameters applied to every group
      # @raise [ArgumentError] if params is empty
      def initialize(params, defaults)
        @defaults = defaults
        # Unknown keys get a fresh empty hash that is stored on first access.
        @state = Hash.new { |store, param| store[param] = {} }
        @param_groups = []

        raise ArgumentError, "optimizer got an empty parameter list" if params.empty?

        # A flat parameter list is treated as a single group.
        groups = params[0].is_a?(Hash) ? params : [{params: params}]
        groups.each { |group| add_param_group(group) }
      end

      # Appends a group, filling in any hyperparameter it does not set itself.
      #
      # @param param_group [Hash] group hash containing :params
      def add_param_group(param_group)
        # TODO more advanced logic
        @param_groups.push(@defaults.merge(param_group))
      end

      # @raise [NotImplementedYet] always
      def load_state_dict(state_dict)
        raise NotImplementedYet
      end

      # Builds a hash describing the optimizer in which parameters are
      # replaced by their object ids.
      #
      # @return [Hash] with keys :state and :param_groups
      def state_dict
        packed_groups = @param_groups.map do |group|
          settings = group.reject { |key, _| key == :params }
          settings[:params] = group[:params].map(&:object_id)
          settings
        end

        packed_state = {}
        @state.each do |key, value|
          packed_state[key.is_a?(Tensor) ? key.object_id : key] = value
        end

        {state: packed_state, param_groups: packed_groups}
      end

      # Detaches and zeroes the gradient of every parameter that has one.
      def zero_grad
        @param_groups.each do |group|
          group[:params].each do |param|
            next unless param.grad

            param.grad.detach!
            param.grad.zero!
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,76 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/rmsprop.py
|
2
|
+
module Torch
  module Optim
    # RMSprop optimizer (ported from PyTorch's torch/optim/rmsprop.py).
    #
    # Per parameter it keeps a running average of the squared gradient
    # (:square_avg) and, depending on the group options, a momentum buffer
    # and a running average of the gradient (:grad_avg, for the centered form).
    class RMSprop < Optimizer
      # @param params parameters (or parameter-group hashes) handed to Optimizer
      # @param lr [Numeric] learning rate, must not be negative
      # @param alpha [Numeric] smoothing constant of the running averages, must not be negative
      # @param eps [Numeric] value added to the denominator, must not be negative
      # @param weight_decay [Numeric] factor applied to the parameter and added to the gradient, must not be negative
      # @param momentum [Numeric] momentum factor, must not be negative
      # @param centered [Boolean] when truthy, subtract the squared gradient average before taking the root
      # @raise [ArgumentError] if a hyperparameter is negative
      def initialize(params, lr: 1e-2, alpha: 0.99, eps: 1e-8, weight_decay: 0, momentum: 0, centered: false)
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
        raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0
        raise ArgumentError, "Invalid momentum value: #{momentum}" if momentum < 0
        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0
        # alpha is the smoothing constant, not a momentum term.
        raise ArgumentError, "Invalid alpha value: #{alpha}" if alpha < 0

        defaults = {lr: lr, momentum: momentum, alpha: alpha, eps: eps, centered: centered, weight_decay: weight_decay}
        super(params, defaults)
      end

      # Performs one optimization step over every parameter that has a gradient.
      #
      # @param closure [#call, nil] optional callable evaluated once before the update
      # @return the closure's result, or nil when no closure is given
      # @raise [Error] if a gradient is sparse
      def step(closure = nil)
        loss = nil
        if closure
          loss = closure.call
        end

        @param_groups.each do |group|
          group[:params].each do |p|
            next unless p.grad
            grad = p.grad.data
            if grad.sparse?
              raise Error, "RMSprop does not support sparse gradients"
            end
            state = @state[p]

            # State initialization
            if state.size == 0
              state[:step] = 0
              state[:square_avg] = Torch.zeros_like(p.data)
              if group[:momentum] > 0
                state[:momentum_buffer] = Torch.zeros_like(p.data)
              end
              if group[:centered]
                state[:grad_avg] = Torch.zeros_like(p.data)
              end
            end

            square_avg = state[:square_avg]
            alpha = group[:alpha]

            state[:step] += 1

            if group[:weight_decay] != 0
              grad = grad.add(group[:weight_decay], p.data)
            end

            # Running average of the squared gradient, updated in place.
            square_avg.mul!(alpha).addcmul!(1 - alpha, grad, grad)

            if group[:centered]
              grad_avg = state[:grad_avg]
              grad_avg.mul!(alpha).add!(1 - alpha, grad)
              # square_avg - grad_avg^2, computed out of place so the stored
              # square_avg buffer is left untouched.
              avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt!.add!(group[:eps])
            else
              avg = square_avg.sqrt.add!(group[:eps])
            end

            if group[:momentum] > 0
              buf = state[:momentum_buffer]
              buf.mul!(group[:momentum]).addcdiv!(grad, avg)
              p.data.add!(-group[:lr], buf)
            else
              p.data.addcdiv!(-group[:lr], grad, avg)
            end
          end
        end

        loss
      end
    end
  end
end
|
@@ -0,0 +1,68 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/rprop.py
|
2
|
+
module Torch
  module Optim
    # Rprop optimizer (ported from PyTorch's torch/optim/rprop.py).
    #
    # NOTE: #step currently raises NotImplementedYet before doing any work;
    # the statements after the raise are kept as the intended algorithm.
    class Rprop < Optimizer
      # @param params parameters (or parameter-group hashes) handed to Optimizer
      # @param lr [Numeric] initial per-element step size, must not be negative
      # @param etas [Array<Numeric>] pair (etaminus, etaplus): step-size factors
      #   used when the gradient sign flips / stays the same
      # @param step_sizes [Array<Numeric>] pair (minimum, maximum) the step size is clamped to
      # @raise [ArgumentError] if lr is negative, etas[0] is outside [0, 1) or etas[1] is below 1
      def initialize(params, lr: 1e-2, etas: [0.5, 1.2], step_sizes: [1e-6, 50])
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
        if etas[0] < 0 || etas[0] >= 1 || etas[1] < 1
          raise ArgumentError, "Invalid eta values: #{etas[0]}, #{etas[1]}"
        end

        super(params, {lr: lr, etas: etas, step_sizes: step_sizes})
      end

      # Intended to perform one optimization step; at present it always raises.
      #
      # @param closure [#call, nil] optional callable (not reached at present)
      # @raise [NotImplementedYet] always
      def step(closure = nil)
        # TODO implement []=
        raise NotImplementedYet

        loss = closure ? closure.call : nil

        @param_groups.each do |group|
          group[:params].each do |param|
            next unless param.grad

            gradient = param.grad.data
            raise Error, "Rprop does not support sparse gradients" if gradient.sparse?

            param_state = @state[param]

            # Lazily create the per-parameter buffers on first use.
            if param_state.empty?
              param_state[:step] = 0
              param_state[:prev] = Torch.zeros_like(param.data)
              param_state[:step_size] = gradient.new.resize_as!(gradient).fill!(group[:lr])
            end

            etaminus, etaplus = group[:etas]
            step_size_min, step_size_max = group[:step_sizes]
            step_size = param_state[:step_size]

            param_state[:step] += 1

            # Map the sign of grad * previous grad to a step-size factor.
            factor = gradient.mul(param_state[:prev]).sign
            factor[factor.gt(0)] = etaplus
            factor[factor.lt(0)] = etaminus
            factor[factor.eq(0)] = 1

            # update stepsizes with step size updates
            step_size.mul!(factor).clamp!(step_size_min, step_size_max)

            # for dir<0, dfdx=0
            # for dir>=0 dfdx=dfdx
            gradient = gradient.clone
            gradient[factor.eq(etaminus)] = 0

            # update parameters
            param.data.addcmul!(-1, gradient.sign, step_size)

            param_state[:prev].copy!(gradient)
          end
        end

        loss
      end
    end
  end
end
|