torch-rb 0.1.2 → 0.1.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/LICENSE.txt +46 -22
- data/README.md +14 -5
- data/ext/torch/ext.cpp +248 -31
- data/lib/torch.rb +80 -9
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +4 -3
- data/lib/torch/nn/alpha_dropout.rb +9 -0
- data/lib/torch/nn/conv2d.rb +12 -24
- data/lib/torch/nn/convnd.rb +41 -0
- data/lib/torch/nn/dropout.rb +9 -0
- data/lib/torch/nn/dropout2d.rb +9 -0
- data/lib/torch/nn/dropout3d.rb +9 -0
- data/lib/torch/nn/dropoutnd.rb +15 -0
- data/lib/torch/nn/embedding.rb +52 -0
- data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
- data/lib/torch/nn/functional.rb +54 -12
- data/lib/torch/nn/linear.rb +2 -2
- data/lib/torch/nn/module.rb +30 -0
- data/lib/torch/optim/adadelta.rb +57 -0
- data/lib/torch/optim/adagrad.rb +71 -0
- data/lib/torch/optim/adam.rb +81 -0
- data/lib/torch/optim/adamax.rb +68 -0
- data/lib/torch/optim/adamw.rb +82 -0
- data/lib/torch/optim/asgd.rb +65 -0
- data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
- data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
- data/lib/torch/optim/optimizer.rb +56 -0
- data/lib/torch/optim/rmsprop.rb +76 -0
- data/lib/torch/optim/rprop.rb +68 -0
- data/lib/torch/optim/sgd.rb +48 -16
- data/lib/torch/tensor.rb +38 -4
- data/lib/torch/utils/data/data_loader.rb +10 -4
- data/lib/torch/utils/data/tensor_dataset.rb +3 -0
- data/lib/torch/version.rb +1 -1
- metadata +21 -3
data/lib/torch/nn/linear.rb
CHANGED
@@ -20,11 +20,11 @@ module Torch
|
|
20
20
|
end
|
21
21
|
|
22
22
|
# Re-initialises the layer's learnable tensors through Init's bang helpers.
# The weight gets Kaiming-uniform init with a = sqrt(5). When a bias exists,
# it is drawn uniformly from [-1/sqrt(fan_in), 1/sqrt(fan_in)].
# Returns nil when there is no bias, otherwise the result of Init.uniform!.
def reset_parameters
  Init.kaiming_uniform!(@weight, Math.sqrt(5))
  return unless @bias

  fan_in, _fan_out = Init.calculate_fan_in_and_fan_out(@weight)
  bound = 1 / Math.sqrt(fan_in)
  Init.uniform!(@bias, -bound, bound)
end
|
30
30
|
|
data/lib/torch/nn/module.rb
CHANGED
@@ -1,6 +1,10 @@
|
|
1
1
|
module Torch
|
2
2
|
module NN
|
3
3
|
class Module
|
4
|
+
# A freshly built module is in training mode until #train(false) or #eval
# flips the flag.
def initialize
  @training = true
end
|
7
|
+
|
4
8
|
def inspect
|
5
9
|
str = String.new
|
6
10
|
str << "#{self.class.name}(\n"
|
@@ -10,10 +14,36 @@ module Torch
|
|
10
14
|
str << ")"
|
11
15
|
end
|
12
16
|
|
17
|
+
# Sets this module's training flag to +mode+ (true by default) and recurses
# into every child yielded by #modules, so the whole tree shares one mode.
def train(mode = true)
  @training = mode
  modules.each { |_name, child| child.train(mode) }
end

# Puts this module tree into evaluation mode; equivalent to train(false).
def eval
  train(false)
end
|
28
|
+
|
13
29
|
# Makes a module callable like a function: every positional argument is
# handed straight to #forward.
def call(*input)
  forward(*input)
end

# Moves this module's parameters, and recursively those of every child from
# #modules, to +device+. The module object itself is updated in place and
# returned, so calls can be chained.
#
# NOTE(review): each Parameter instance variable is rebound to a new
# Parameter that wraps the moved tensor; the old object is not mutated.
# References taken before #to (e.g. an optimizer's param list) therefore
# presumably still point at the old tensors. Confirm before relying on
# "in-place" semantics.
def to(device)
  instance_variables.each do |ivar|
    value = instance_variable_get(ivar)
    next unless value.is_a?(Parameter)

    instance_variable_set(ivar, Parameter.new(value.to(device)))
  end
  modules.each { |_name, child| child.to(device) }
  self
end
|
46
|
+
|
17
47
|
def parameters
|
18
48
|
params = []
|
19
49
|
instance_variables.each do |name|
|
@@ -0,0 +1,57 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/adadelta.py
module Torch
  module Optim
    # Adadelta optimizer. It keeps a running average of squared gradients
    # (state[:square_avg]) and of squared updates (state[:acc_delta]), and
    # scales each step by the ratio of their square roots.
    class Adadelta < Optimizer
      # @param params parameters (or parameter groups) to optimize
      # @param lr [Numeric] coefficient applied to the computed delta
      # @param rho [Numeric] decay factor of both running averages, in [0, 1]
      # @param eps [Numeric] added inside the square roots for numerical stability
      # @param weight_decay [Numeric] L2 penalty coefficient
      # @raise [ArgumentError] if a hyperparameter is out of range
      def initialize(params, lr: 1.0, rho: 0.9, eps: 1e-6, weight_decay: 0)
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
        raise ArgumentError, "Invalid rho value: #{rho}" if rho < 0 || rho > 1
        raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0
        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0

        group_defaults = {lr: lr, rho: rho, eps: eps, weight_decay: weight_decay}
        super(params, group_defaults)
      end

      # Performs a single optimization step over every parameter with a gradient.
      #
      # @param closure [#call, nil] optional callable that re-evaluates the model and returns the loss
      # @return the closure's result, or nil when no closure is given
      # @raise [Error] if a gradient is sparse
      def step(closure = nil)
        loss = closure ? closure.call : nil

        @param_groups.each do |group|
          rho = group[:rho]
          eps = group[:eps]

          group[:params].each do |p|
            next unless p.grad

            grad = p.grad.data
            raise Error, "Adadelta does not support sparse gradients" if grad.sparse?

            state = @state[p]
            # State is created lazily on the first step for each parameter.
            if state.size.zero?
              state[:step] = 0
              state[:square_avg] = Torch.zeros_like(p.data)
              state[:acc_delta] = Torch.zeros_like(p.data)
            end
            state[:step] += 1

            # Non-mutating add: p.grad itself is left untouched.
            grad = grad.add(group[:weight_decay], p.data) unless group[:weight_decay] == 0

            square_avg = state[:square_avg]
            acc_delta = state[:acc_delta]

            # E[g^2] <- rho * E[g^2] + (1 - rho) * g^2
            square_avg.mul!(rho).addcmul!(1 - rho, grad, grad)
            # delta = sqrt(E[dx^2] + eps) / sqrt(E[g^2] + eps) * g
            std = square_avg.add(eps).sqrt!
            delta = acc_delta.add(eps).sqrt!.div!(std).mul!(grad)
            p.data.add!(-group[:lr], delta)
            # E[dx^2] <- rho * E[dx^2] + (1 - rho) * delta^2
            acc_delta.mul!(rho).addcmul!(1 - rho, delta, delta)
          end
        end

        loss
      end
    end
  end
end
|
@@ -0,0 +1,71 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/adagrad.py
module Torch
  module Optim
    # Adagrad optimizer. It divides each step by the square root of the
    # per-element sum of squared gradients seen so far (state[:sum]).
    class Adagrad < Optimizer
      # @param params parameters (or parameter groups) to optimize
      # @param lr [Numeric] base learning rate
      # @param lr_decay [Numeric] learning-rate decay; the effective rate is lr / (1 + (step - 1) * lr_decay)
      # @param weight_decay [Numeric] L2 penalty coefficient
      # @param initial_accumulator_value [Numeric] starting value of state[:sum]
      # @param eps [Numeric] added to the denominator for numerical stability
      # @raise [ArgumentError] if any hyperparameter is negative
      def initialize(params, lr: 1e-2, lr_decay: 0, weight_decay: 0, initial_accumulator_value: 0, eps: 1e-10)
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
        raise ArgumentError, "Invalid lr_decay value: #{lr_decay}" if lr_decay < 0
        raise ArgumentError, "Invalid initial_accumulator_value value: #{initial_accumulator_value}" if initial_accumulator_value < 0
        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0
        raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0

        defaults = {lr: lr, lr_decay: lr_decay, eps: eps, weight_decay: weight_decay, initial_accumulator_value: initial_accumulator_value}
        super(params, defaults)

        # State is created eagerly here instead of lazily in #step: both #step
        # (state[:step] += 1) and #share_memory (state[:sum]) assume it exists.
        @param_groups.each do |group|
          group[:params].each do |p|
            state = @state[p]
            state[:step] = 0
            state[:sum] = Torch.full_like(p.data, initial_accumulator_value)
          end
        end
      end

      # Calls share_memory! on every parameter's squared-gradient accumulator.
      def share_memory
        @param_groups.each do |group|
          group[:params].each do |p|
            state = @state[p]
            state[:sum].share_memory!
          end
        end
      end

      # Performs a single optimization step over every parameter with a gradient.
      #
      # @param closure [#call, nil] optional callable that re-evaluates the model and returns the loss
      # @return the closure's result, or nil when no closure is given
      # @raise [Error] if weight_decay is set and a gradient is sparse
      # @raise [NotImplementedYet] for sparse gradients (not ported yet)
      def step(closure = nil)
        loss = nil
        loss = closure.call if closure

        @param_groups.each do |group|
          group[:params].each do |p|
            next unless p.grad

            grad = p.grad.data
            state = @state[p]
            state[:step] += 1

            if group[:weight_decay] != 0
              raise Error, "weight_decay option is not compatible with sparse gradients" if grad.sparse?

              grad = grad.add(group[:weight_decay], p.data)
            end

            # to_f forces float division. Python's `/` is true division, but
            # Ruby's Integer#/ truncates, so Integer lr and lr_decay (e.g.
            # 1 / 2 == 0) would silently zero the learning rate.
            clr = group[:lr].to_f / (1 + (state[:step] - 1) * group[:lr_decay])

            raise NotImplementedYet if grad.sparse?

            state[:sum].addcmul!(1, grad, grad)
            std = state[:sum].sqrt.add!(group[:eps])
            p.data.addcdiv!(-clr, grad, std)
          end
        end

        loss
      end
    end
  end
end
|
@@ -0,0 +1,81 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/adam.py
module Torch
  module Optim
    # Adam optimizer. It keeps exponential moving averages of the gradient
    # (state[:exp_avg]) and of its square (state[:exp_avg_sq]), with optional
    # AMSGrad tracking of the running maximum of the latter.
    class Adam < Optimizer
      # @param params parameters (or parameter groups) to optimize
      # @param lr [Numeric] learning rate
      # @param betas [Array(Numeric, Numeric)] decay rates for the first and second moment, each in [0, 1)
      # @param eps [Numeric] added to the denominator for numerical stability
      # @param weight_decay [Numeric] L2 penalty coefficient (added to the gradient)
      # @param amsgrad [Boolean] use the AMSGrad variant
      # @raise [ArgumentError] if a hyperparameter is out of range
      def initialize(params, lr: 1e-3, betas: [0.9, 0.999], eps: 1e-8, weight_decay: 0, amsgrad: false)
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
        raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0
        raise ArgumentError, "Invalid beta parameter at index 0: #{betas[0]}" if betas[0] < 0 || betas[0] >= 1
        raise ArgumentError, "Invalid beta parameter at index 1: #{betas[1]}" if betas[1] < 0 || betas[1] >= 1
        # Same validation the sibling optimizers (Adamax, Adadelta, ...) perform.
        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0

        defaults = {lr: lr, betas: betas, eps: eps, weight_decay: weight_decay, amsgrad: amsgrad}
        super(params, defaults)
      end

      # Performs a single optimization step over every parameter with a gradient.
      #
      # @param closure [#call, nil] optional callable that re-evaluates the model and returns the loss
      # @return the closure's result, or nil when no closure is given
      # @raise [Error] if a gradient is sparse
      def step(closure = nil)
        loss = nil
        loss = closure.call if closure

        @param_groups.each do |group|
          group[:params].each do |p|
            next unless p.grad

            grad = p.grad.data
            if grad.sparse?
              raise Error, "Adam does not support sparse gradients, please consider SparseAdam instead"
            end
            amsgrad = group[:amsgrad]

            state = @state[p]

            # State initialization
            if state.size == 0
              state[:step] = 0
              # Exponential moving average of gradient values
              state[:exp_avg] = Torch.zeros_like(p.data)
              # Exponential moving average of squared gradient values
              state[:exp_avg_sq] = Torch.zeros_like(p.data)
              if amsgrad
                # Maintains max of all exp. moving avg. of sq. grad. values
                state[:max_exp_avg_sq] = Torch.zeros_like(p.data)
              end
            end

            exp_avg, exp_avg_sq = state[:exp_avg], state[:exp_avg_sq]
            beta1, beta2 = group[:betas]

            state[:step] += 1
            bias_correction1 = 1 - beta1 ** state[:step]
            bias_correction2 = 1 - beta2 ** state[:step]

            # L2 weight decay goes into a *copy* of the gradient. The previous
            # in-place grad.add! rewrote p.grad.data (presumably sharing storage
            # with p.grad), so a repeated step or any later reader of p.grad saw
            # an already-decayed gradient.
            grad = grad.add(group[:weight_decay], p.data) if group[:weight_decay] != 0

            # Decay the first and second moment running average coefficient
            exp_avg.mul!(beta1).add!(1 - beta1, grad)
            exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)

            second_moment =
              if amsgrad
                max_exp_avg_sq = state[:max_exp_avg_sq]
                # Maintains the maximum of all 2nd moment running avg. till now
                Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
                # Use the max. for normalizing running avg. of gradient
                max_exp_avg_sq
              else
                exp_avg_sq
              end
            denom = (second_moment.sqrt / Math.sqrt(bias_correction2)).add!(group[:eps])

            step_size = group[:lr] / bias_correction1

            p.data.addcdiv!(-step_size, exp_avg, denom)
          end
        end

        loss
      end
    end
  end
end
|
@@ -0,0 +1,68 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/adamax.py
module Torch
  module Optim
    # Adamax optimizer: an Adam variant that normalizes the first-moment
    # estimate (state[:exp_avg]) by an exponentially weighted infinity norm
    # of past gradients (state[:exp_inf]).
    class Adamax < Optimizer
      # @param params parameters (or parameter groups) to optimize
      # @param lr [Numeric] learning rate
      # @param betas [Array(Numeric, Numeric)] decay rates for the first moment and the infinity norm, each in [0, 1)
      # @param eps [Numeric] added to |grad| before taking the running max
      # @param weight_decay [Numeric] L2 penalty coefficient
      # @raise [ArgumentError] if a hyperparameter is out of range
      def initialize(params, lr: 2e-3, betas: [0.9, 0.999], eps: 1e-8, weight_decay: 0)
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
        raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0
        raise ArgumentError, "Invalid beta parameter at index 0: #{betas[0]}" if betas[0] < 0 || betas[0] >= 1
        raise ArgumentError, "Invalid beta parameter at index 1: #{betas[1]}" if betas[1] < 0 || betas[1] >= 1
        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0

        group_defaults = {lr: lr, betas: betas, eps: eps, weight_decay: weight_decay}
        super(params, group_defaults)
      end

      # Performs a single optimization step over every parameter with a gradient.
      #
      # @param closure [#call, nil] optional callable that re-evaluates the model and returns the loss
      # @return the closure's result, or nil when no closure is given
      # @raise [Error] if a gradient is sparse
      def step(closure = nil)
        loss = closure ? closure.call : nil

        @param_groups.each do |group|
          beta1, beta2 = group[:betas]
          eps = group[:eps]

          group[:params].each do |p|
            next unless p.grad

            grad = p.grad.data
            raise Error, "Adamax does not support sparse gradients, please consider SparseAdam instead" if grad.sparse?

            state = @state[p]
            # State is created lazily on the first step for each parameter.
            if state.size.zero?
              state[:step] = 0
              state[:exp_avg] = Torch.zeros_like(p.data)
              state[:exp_inf] = Torch.zeros_like(p.data)
            end
            state[:step] += 1
            exp_avg = state[:exp_avg]
            exp_inf = state[:exp_inf]

            grad = grad.add(group[:weight_decay], p.data) unless group[:weight_decay] == 0

            # m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
            exp_avg.mul!(beta1).add!(1 - beta1, grad)
            # u_t = max(beta2 * u_{t-1}, |g_t| + eps). Both candidates are
            # stacked along a new leading dim, then max-reduced into exp_inf.
            candidates = Torch.cat([
              exp_inf.mul!(beta2).unsqueeze(0),
              grad.abs.add!(eps).unsqueeze!(0)
            ], 0)
            Torch.max(candidates, 0, keepdim: false, out: [exp_inf, exp_inf.new.long])

            # Only the first moment needs bias correction.
            clr = group[:lr] / (1 - beta1 ** state[:step])
            p.data.addcdiv!(-clr, exp_avg, exp_inf)
          end
        end

        loss
      end
    end
  end
end
|
@@ -0,0 +1,82 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/adamw.py
module Torch
  module Optim
    # AdamW optimizer: Adam with decoupled weight decay. Parameters are shrunk
    # by (1 - lr * weight_decay) directly instead of adding an L2 term to the
    # gradient.
    class AdamW < Optimizer
      # @param params parameters (or parameter groups) to optimize
      # @param lr [Numeric] learning rate
      # @param betas [Array(Numeric, Numeric)] decay rates for the first and second moment, each in [0, 1)
      # @param eps [Numeric] added to the denominator for numerical stability
      # @param weight_decay [Numeric] decoupled weight-decay coefficient
      # @param amsgrad [Boolean] use the AMSGrad variant
      # @raise [ArgumentError] if a hyperparameter is out of range
      def initialize(params, lr: 1e-3, betas: [0.9, 0.999], eps: 1e-8, weight_decay: 1e-2, amsgrad: false)
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
        raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0
        raise ArgumentError, "Invalid beta parameter at index 0: #{betas[0]}" if betas[0] < 0 || betas[0] >= 1
        raise ArgumentError, "Invalid beta parameter at index 1: #{betas[1]}" if betas[1] < 0 || betas[1] >= 1
        # Same validation the sibling optimizers (Adamax, Adadelta, ...) perform.
        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0

        defaults = {lr: lr, betas: betas, eps: eps, weight_decay: weight_decay, amsgrad: amsgrad}
        super(params, defaults)
      end

      # Performs a single optimization step over every parameter with a gradient.
      #
      # @param closure [#call, nil] optional callable that re-evaluates the model and returns the loss
      # @return the closure's result, or nil when no closure is given
      # @raise [Error] if a gradient is sparse (raised before the parameter is modified)
      def step(closure = nil)
        loss = nil
        loss = closure.call if closure

        @param_groups.each do |group|
          group[:params].each do |p|
            next unless p.grad

            # Validate before touching p.data. Previously the weight decay
            # below ran first, so a sparse gradient raised only after the
            # parameter had already been shrunk.
            grad = p.grad.data
            if grad.sparse?
              raise Error, "AdamW does not support sparse gradients, please consider SparseAdam instead"
            end

            # Perform stepweight decay
            p.data.mul!(1 - group[:lr] * group[:weight_decay])

            # Perform optimization step
            amsgrad = group[:amsgrad]
            state = @state[p]

            # State initialization
            if state.size == 0
              state[:step] = 0
              # Exponential moving average of gradient values
              state[:exp_avg] = Torch.zeros_like(p.data)
              # Exponential moving average of squared gradient values
              state[:exp_avg_sq] = Torch.zeros_like(p.data)
              if amsgrad
                # Maintains max of all exp. moving avg. of sq. grad. values
                state[:max_exp_avg_sq] = Torch.zeros_like(p.data)
              end
            end

            exp_avg, exp_avg_sq = state[:exp_avg], state[:exp_avg_sq]
            beta1, beta2 = group[:betas]

            state[:step] += 1
            bias_correction1 = 1 - beta1 ** state[:step]
            bias_correction2 = 1 - beta2 ** state[:step]

            # Decay the first and second moment running average coefficient
            exp_avg.mul!(beta1).add!(1 - beta1, grad)
            exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)

            second_moment =
              if amsgrad
                max_exp_avg_sq = state[:max_exp_avg_sq]
                # Maintains the maximum of all 2nd moment running avg. till now
                Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
                # Use the max. for normalizing running avg. of gradient
                max_exp_avg_sq
              else
                exp_avg_sq
              end
            denom = (second_moment.sqrt / Math.sqrt(bias_correction2)).add!(group[:eps])

            step_size = group[:lr] / bias_correction1

            p.data.addcdiv!(-step_size, exp_avg, denom)
          end
        end

        loss
      end
    end
  end
end
|
@@ -0,0 +1,65 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/asgd.py
module Torch
  module Optim
    # Averaged Stochastic Gradient Descent. It runs SGD with a decaying step
    # size (state[:eta]) and keeps a running average of the parameter values
    # in state[:ax]. Averaging starts once the step count exceeds t0 + 1;
    # until then state[:ax] simply copies the parameter.
    class ASGD < Optimizer
      # @param params parameters (or parameter groups) to optimize
      # @param lr [Numeric] initial learning rate (initial eta)
      # @param lambd [Numeric] decay term; also drives the eta schedule
      # @param alpha [Numeric] exponent of the eta schedule
      # @param t0 [Numeric] step offset after which averaging kicks in
      # @param weight_decay [Numeric] L2 penalty coefficient
      # @raise [ArgumentError] if lr or weight_decay is negative
      def initialize(params, lr: 1e-2, lambd: 1e-4, alpha: 0.75, t0: 1e6, weight_decay: 0)
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0

        defaults = {lr: lr, lambd: lambd, alpha: alpha, t0: t0, weight_decay: weight_decay}
        super(params, defaults)
      end

      # Performs a single optimization step over every parameter with a gradient.
      #
      # @param closure [#call, nil] optional callable that re-evaluates the model and returns the loss
      # @return the closure's result, or nil when no closure is given
      # @raise [Error] if a gradient is sparse
      def step(closure = nil)
        loss = nil
        loss = closure.call if closure

        @param_groups.each do |group|
          group[:params].each do |p|
            next unless p.grad

            grad = p.grad.data
            raise Error, "ASGD does not support sparse gradients" if grad.sparse?

            state = @state[p]

            # State initialization
            if state.size == 0
              state[:step] = 0
              state[:eta] = group[:lr]
              state[:mu] = 1
              state[:ax] = Torch.zeros_like(p.data)
            end

            state[:step] += 1

            grad = grad.add(group[:weight_decay], p.data) if group[:weight_decay] != 0

            # decay term
            p.data.mul!(1 - group[:lambd] * state[:eta])

            # update parameter
            p.data.add!(-state[:eta], grad)

            # averaging (mu == 1 means "not averaging yet": ax tracks p exactly)
            if state[:mu] != 1
              state[:ax].add!(p.data.sub(state[:ax]).mul(state[:mu]))
            else
              state[:ax].copy!(p.data)
            end

            # update eta and mu. Both divisions are forced to float. Python's
            # `/` is true division, but Ruby's Integer#/ truncates, so Integer
            # hyperparameters (e.g. t0: 100 gives mu = 1 / 2 == 0) would
            # freeze the average or zero the step size.
            state[:eta] = group[:lr].to_f / ((1 + group[:lambd] * group[:lr] * state[:step]) ** group[:alpha])
            state[:mu] = 1.0 / [1, state[:step] - group[:t0]].max
          end
        end

        loss
      end
    end
  end
end
|