torch-rb 0.1.2 → 0.1.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/LICENSE.txt +46 -22
- data/README.md +14 -5
- data/ext/torch/ext.cpp +248 -31
- data/lib/torch.rb +80 -9
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +4 -3
- data/lib/torch/nn/alpha_dropout.rb +9 -0
- data/lib/torch/nn/conv2d.rb +12 -24
- data/lib/torch/nn/convnd.rb +41 -0
- data/lib/torch/nn/dropout.rb +9 -0
- data/lib/torch/nn/dropout2d.rb +9 -0
- data/lib/torch/nn/dropout3d.rb +9 -0
- data/lib/torch/nn/dropoutnd.rb +15 -0
- data/lib/torch/nn/embedding.rb +52 -0
- data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
- data/lib/torch/nn/functional.rb +54 -12
- data/lib/torch/nn/linear.rb +2 -2
- data/lib/torch/nn/module.rb +30 -0
- data/lib/torch/optim/adadelta.rb +57 -0
- data/lib/torch/optim/adagrad.rb +71 -0
- data/lib/torch/optim/adam.rb +81 -0
- data/lib/torch/optim/adamax.rb +68 -0
- data/lib/torch/optim/adamw.rb +82 -0
- data/lib/torch/optim/asgd.rb +65 -0
- data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
- data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
- data/lib/torch/optim/optimizer.rb +56 -0
- data/lib/torch/optim/rmsprop.rb +76 -0
- data/lib/torch/optim/rprop.rb +68 -0
- data/lib/torch/optim/sgd.rb +48 -16
- data/lib/torch/tensor.rb +38 -4
- data/lib/torch/utils/data/data_loader.rb +10 -4
- data/lib/torch/utils/data/tensor_dataset.rb +3 -0
- data/lib/torch/version.rb +1 -1
- metadata +21 -3
data/lib/torch/nn/linear.rb
CHANGED
@@ -20,11 +20,11 @@ module Torch
|
|
20
20
|
end
|
21
21
|
|
22
22
|
# Re-initializes the layer's parameters in place: Kaiming-uniform for the
# weight (a = sqrt(5)) and, when a bias exists, uniform in [-bound, bound]
# with bound = 1 / sqrt(fan_in).
#
# A fan_in of 0 (e.g. zero input features) would make the bound Infinity, so
# the bias is initialized to the degenerate range [0, 0] instead, mirroring
# upstream PyTorch's guard.
def reset_parameters
  Init.kaiming_uniform!(@weight, Math.sqrt(5))
  return unless @bias

  fan_in, _ = Init.calculate_fan_in_and_fan_out(@weight)
  bound = fan_in > 0 ? 1 / Math.sqrt(fan_in) : 0
  Init.uniform!(@bias, -bound, bound)
end
|
30
30
|
|
data/lib/torch/nn/module.rb
CHANGED
@@ -1,6 +1,10 @@
|
|
1
1
|
module Torch
|
2
2
|
module NN
|
3
3
|
class Module
|
4
|
+
# Base constructor shared by all modules: a freshly built module is in
# training mode until #eval or #train(false) is called.
def initialize
  @training = true
end
|
7
|
+
|
4
8
|
def inspect
|
5
9
|
str = String.new
|
6
10
|
str << "#{self.class.name}(\n"
|
@@ -10,10 +14,36 @@ module Torch
|
|
10
14
|
str << ")"
|
11
15
|
end
|
12
16
|
|
17
|
+
# Puts this module, and recursively every child module returned by #modules,
# into training mode (+mode+ = true) or evaluation mode (+mode+ = false).
#
# @param mode [Boolean] true for training mode, false for evaluation mode
# @return [self] the receiver, so calls can be chained (e.g. `model.train.call(x)`),
#   consistent with #to
def train(mode = true)
  @training = mode
  modules.each do |_, mod|
    mod.train(mode)
  end
  self
end
|
24
|
+
|
25
|
+
# Switches this module (and, via #train, its children) to evaluation mode;
# equivalent to `train(false)`.
#
# @return [self] the receiver, so calls can be chained (e.g. `model.eval.call(x)`)
def eval
  train(false)
  self
end
|
28
|
+
|
13
29
|
# Makes a module invocable like a function: all positional arguments are
# handed to #forward, whose result is returned unchanged.
def call(*args)
  forward(*args)
end
|
16
32
|
|
33
|
+
# Moves the module to +device+, mutating the receiver: every instance
# variable holding a Parameter is replaced by a new Parameter wrapping the
# moved tensor, then each child module from #modules is moved recursively.
# Instance variables that are not Parameters are left untouched.
#
# @return [self]
def to(device)
  instance_variables.each do |ivar|
    value = instance_variable_get(ivar)
    next unless value.is_a?(Parameter)

    instance_variable_set(ivar, Parameter.new(value.to(device)))
  end
  modules.each { |_, child| child.to(device) }
  self
end
|
46
|
+
|
17
47
|
def parameters
|
18
48
|
params = []
|
19
49
|
instance_variables.each do |name|
|
@@ -0,0 +1,57 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/adadelta.py
module Torch
  module Optim
    # Adadelta optimizer: keeps per-parameter running averages of squared
    # gradients and squared updates, and scales each update by the ratio of
    # their square roots.
    class Adadelta < Optimizer
      # @param params parameters (or parameter groups) to optimize
      # @param lr [Numeric] coefficient applied to the computed delta
      # @param rho [Numeric] decay factor for both running averages, in [0, 1]
      # @param eps [Numeric] added inside the square roots for numerical stability
      # @param weight_decay [Numeric] L2 penalty coefficient added to the gradient
      # @raise [ArgumentError] if a hyperparameter is out of range
      def initialize(params, lr: 1.0, rho: 0.9, eps: 1e-6, weight_decay: 0)
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
        raise ArgumentError, "Invalid rho value: #{rho}" if rho < 0 || rho > 1
        raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0
        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0

        super(params, {lr: lr, rho: rho, eps: eps, weight_decay: weight_decay})
      end

      # Performs one optimization step over every parameter that has a gradient.
      #
      # @param closure [#call, nil] optional callable that re-evaluates the model
      #   and returns the loss
      # @return the closure's result, or nil when no closure is given
      # @raise [Error] if a gradient is sparse
      def step(closure = nil)
        loss = closure ? closure.call : nil

        @param_groups.each do |group|
          rho = group[:rho]
          eps = group[:eps]

          group[:params].each do |p|
            next unless p.grad

            grad = p.grad.data
            raise Error, "Adadelta does not support sparse gradients" if grad.sparse?

            state = @state[p]
            # lazily create per-parameter state on its first step
            if state.size.zero?
              state[:step] = 0
              state[:square_avg] = Torch.zeros_like(p.data)
              state[:acc_delta] = Torch.zeros_like(p.data)
            end
            state[:step] += 1

            # out-of-place add, so p.grad itself is left untouched
            grad = grad.add(group[:weight_decay], p.data) unless group[:weight_decay] == 0

            square_avg = state[:square_avg]
            acc_delta = state[:acc_delta]

            # square_avg <- rho * square_avg + (1 - rho) * grad^2
            square_avg.mul!(rho).addcmul!(1 - rho, grad, grad)
            std = square_avg.add(eps).sqrt!
            # delta = sqrt(acc_delta + eps) / sqrt(square_avg + eps) * grad
            delta = acc_delta.add(eps).sqrt!.div!(std).mul!(grad)
            p.data.add!(-group[:lr], delta)
            # acc_delta <- rho * acc_delta + (1 - rho) * delta^2
            acc_delta.mul!(rho).addcmul!(1 - rho, delta, delta)
          end
        end

        loss
      end
    end
  end
end
|
@@ -0,0 +1,71 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/adagrad.py
module Torch
  module Optim
    # Adagrad optimizer: divides each parameter's step by the square root of
    # the running sum of its squared gradients, with an optional decaying
    # learning rate.
    class Adagrad < Optimizer
      # @param params parameters (or parameter groups) to optimize
      # @param lr [Numeric] base learning rate
      # @param lr_decay [Numeric] learning rate at step t is lr / (1 + (t - 1) * lr_decay)
      # @param weight_decay [Numeric] L2 penalty coefficient added to the gradient
      # @param initial_accumulator_value [Numeric] starting value of every squared-gradient sum
      # @param eps [Numeric] added to the denominator for numerical stability
      # @raise [ArgumentError] if a hyperparameter is out of range
      def initialize(params, lr: 1e-2, lr_decay: 0, weight_decay: 0, initial_accumulator_value: 0, eps: 1e-10)
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
        raise ArgumentError, "Invalid lr_decay value: #{lr_decay}" if lr_decay < 0
        raise ArgumentError, "Invalid initial_accumulator_value value: #{initial_accumulator_value}" if initial_accumulator_value < 0
        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0
        raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0

        defaults = {lr: lr, lr_decay: lr_decay, eps: eps, weight_decay: weight_decay, initial_accumulator_value: initial_accumulator_value}
        super(params, defaults)

        # State is created eagerly (not lazily in #step): both #step
        # (state[:step] += 1) and #share_memory (state[:sum]) assume it exists.
        @param_groups.each do |group|
          group[:params].each do |p|
            state = @state[p]
            state[:step] = 0
            state[:sum] = Torch.full_like(p.data, initial_accumulator_value)
          end
        end
      end

      # Calls share_memory! on every parameter's squared-gradient accumulator.
      def share_memory
        @param_groups.each do |group|
          group[:params].each do |p|
            state = @state[p]
            state[:sum].share_memory!
          end
        end
      end

      # Performs one optimization step over every parameter that has a gradient.
      #
      # @param closure [#call, nil] optional callable that re-evaluates the model
      #   and returns the loss
      # @return the closure's result, or nil when no closure is given
      # @raise [Error] if weight_decay is used with a sparse gradient
      # @raise [NotImplementedYet] for sparse gradients
      def step(closure = nil)
        loss = nil
        loss = closure.call if closure

        @param_groups.each do |group|
          group[:params].each do |p|
            next unless p.grad

            grad = p.grad.data
            state = @state[p]

            state[:step] += 1

            if group[:weight_decay] != 0
              if p.grad.data.sparse?
                raise Error, "weight_decay option is not compatible with sparse gradients"
              end
              grad = grad.add(group[:weight_decay], p.data)
            end

            # Force float division: with Integer lr and lr_decay, Ruby's `/`
            # truncates (e.g. 1 / 2 == 0) and would silently zero the step size.
            clr = group[:lr].to_f / (1 + (state[:step] - 1) * group[:lr_decay])

            if grad.sparse?
              raise NotImplementedYet
            else
              # sum <- sum + grad^2; p <- p - clr * grad / (sqrt(sum) + eps)
              state[:sum].addcmul!(1, grad, grad)
              std = state[:sum].sqrt.add!(group[:eps])
              p.data.addcdiv!(-clr, grad, std)
            end
          end
        end

        loss
      end
    end
  end
end
|
@@ -0,0 +1,81 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/adam.py
module Torch
  module Optim
    # Adam optimizer: bias-corrected exponential moving averages of the
    # gradient and squared gradient, with optional AMSGrad (running max of the
    # second moment).
    class Adam < Optimizer
      # @param params parameters (or parameter groups) to optimize
      # @param lr [Numeric] learning rate
      # @param betas [Array(Numeric, Numeric)] decay rates of the first and second
      #   moment averages, each in [0, 1)
      # @param eps [Numeric] added to the denominator for numerical stability
      # @param weight_decay [Numeric] L2 penalty coefficient added to the gradient
      # @param amsgrad [Boolean] use the AMSGrad variant
      # @raise [ArgumentError] if a hyperparameter is out of range
      def initialize(params, lr: 1e-3, betas: [0.9, 0.999], eps: 1e-8, weight_decay: 0, amsgrad: false)
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
        raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0
        raise ArgumentError, "Invalid beta parameter at index 0: #{betas[0]}" if betas[0] < 0 || betas[0] >= 1
        raise ArgumentError, "Invalid beta parameter at index 1: #{betas[1]}" if betas[1] < 0 || betas[1] >= 1
        # same check the sibling optimizers (Adadelta, Adagrad, Adamax, ASGD) perform
        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0

        defaults = {lr: lr, betas: betas, eps: eps, weight_decay: weight_decay, amsgrad: amsgrad}
        super(params, defaults)
      end

      # Performs one optimization step over every parameter that has a gradient.
      #
      # @param closure [#call, nil] optional callable that re-evaluates the model
      #   and returns the loss
      # @return the closure's result, or nil when no closure is given
      # @raise [Error] if a gradient is sparse
      def step(closure = nil)
        loss = nil
        loss = closure.call if closure

        @param_groups.each do |group|
          group[:params].each do |p|
            next unless p.grad

            grad = p.grad.data
            if grad.sparse?
              raise Error, "Adam does not support sparse gradients, please consider SparseAdam instead"
            end
            amsgrad = group[:amsgrad]

            state = @state[p]

            # State initialization
            if state.size == 0
              state[:step] = 0
              # Exponential moving average of gradient values
              state[:exp_avg] = Torch.zeros_like(p.data)
              # Exponential moving average of squared gradient values
              state[:exp_avg_sq] = Torch.zeros_like(p.data)
              if amsgrad
                # Maintains max of all exp. moving avg. of sq. grad. values
                state[:max_exp_avg_sq] = Torch.zeros_like(p.data)
              end
            end

            exp_avg, exp_avg_sq = state[:exp_avg], state[:exp_avg_sq]
            max_exp_avg_sq = state[:max_exp_avg_sq] if amsgrad
            beta1, beta2 = group[:betas]

            state[:step] += 1
            bias_correction1 = 1 - beta1 ** state[:step]
            bias_correction2 = 1 - beta2 ** state[:step]

            if group[:weight_decay] != 0
              # Out-of-place: grad is p.grad.data, so an in-place add! would
              # overwrite the stored gradient with grad + weight_decay * p.
              grad = grad.add(group[:weight_decay], p.data)
            end

            # Decay the first and second moment running average coefficient
            exp_avg.mul!(beta1).add!(1 - beta1, grad)
            exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
            if amsgrad
              # Maintains the maximum of all 2nd moment running avg. till now
              Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
              # Use the max. for normalizing running avg. of gradient
              denom = (max_exp_avg_sq.sqrt / Math.sqrt(bias_correction2)).add!(group[:eps])
            else
              denom = (exp_avg_sq.sqrt / Math.sqrt(bias_correction2)).add!(group[:eps])
            end

            step_size = group[:lr] / bias_correction1

            p.data.addcdiv!(-step_size, exp_avg, denom)
          end
        end

        loss
      end
    end
  end
end
|
@@ -0,0 +1,68 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/adamax.py
module Torch
  module Optim
    # Adamax optimizer: an Adam variant whose second-moment term is an
    # exponentially weighted infinity norm (element-wise running max of |grad|).
    class Adamax < Optimizer
      # @param params parameters (or parameter groups) to optimize
      # @param lr [Numeric] learning rate
      # @param betas [Array(Numeric, Numeric)] decay rates of the first moment and of
      #   the infinity norm, each in [0, 1)
      # @param eps [Numeric] added to |grad| for numerical stability
      # @param weight_decay [Numeric] L2 penalty coefficient added to the gradient
      # @raise [ArgumentError] if a hyperparameter is out of range
      def initialize(params, lr: 2e-3, betas: [0.9, 0.999], eps: 1e-8, weight_decay: 0)
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
        raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0
        raise ArgumentError, "Invalid beta parameter at index 0: #{betas[0]}" if betas[0] < 0 || betas[0] >= 1
        raise ArgumentError, "Invalid beta parameter at index 1: #{betas[1]}" if betas[1] < 0 || betas[1] >= 1
        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0

        super(params, {lr: lr, betas: betas, eps: eps, weight_decay: weight_decay})
      end

      # Performs one optimization step over every parameter that has a gradient.
      #
      # @param closure [#call, nil] optional callable that re-evaluates the model
      #   and returns the loss
      # @return the closure's result, or nil when no closure is given
      # @raise [Error] if a gradient is sparse
      def step(closure = nil)
        loss = closure ? closure.call : nil

        @param_groups.each do |group|
          beta1, beta2 = group[:betas]
          eps = group[:eps]

          group[:params].each do |p|
            next unless p.grad

            grad = p.grad.data
            raise Error, "Adamax does not support sparse gradients, please consider SparseAdam instead" if grad.sparse?

            state = @state[p]
            # lazily create per-parameter state on its first step
            if state.size.zero?
              state[:step] = 0
              state[:exp_avg] = Torch.zeros_like(p.data)
              state[:exp_inf] = Torch.zeros_like(p.data)
            end
            exp_avg = state[:exp_avg]
            exp_inf = state[:exp_inf]

            state[:step] += 1

            grad = grad.add(group[:weight_decay], p.data) unless group[:weight_decay] == 0

            # biased first moment: exp_avg <- beta1 * exp_avg + (1 - beta1) * grad
            exp_avg.mul!(beta1).add!(1 - beta1, grad)

            # infinity norm: exp_inf <- max(beta2 * exp_inf, |grad| + eps),
            # computed by stacking both candidates along a new dim 0 and
            # reducing with max into exp_inf
            decayed_inf = exp_inf.mul!(beta2).unsqueeze(0)
            abs_grad = grad.abs.add!(eps).unsqueeze!(0)
            norm_buf = Torch.cat([decayed_inf, abs_grad], 0)
            Torch.max(norm_buf, 0, keepdim: false, out: [exp_inf, exp_inf.new.long])

            # only the first moment needs bias correction
            clr = group[:lr] / (1 - beta1 ** state[:step])
            p.data.addcdiv!(-clr, exp_avg, exp_inf)
          end
        end

        loss
      end
    end
  end
end
|
@@ -0,0 +1,82 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/adamw.py
module Torch
  module Optim
    # AdamW optimizer: Adam with decoupled weight decay. Parameters are shrunk
    # directly by a factor (1 - lr * weight_decay) instead of adding an L2 term
    # to the gradient.
    class AdamW < Optimizer
      # @param params parameters (or parameter groups) to optimize
      # @param lr [Numeric] learning rate
      # @param betas [Array(Numeric, Numeric)] decay rates of the first and second
      #   moment averages, each in [0, 1)
      # @param eps [Numeric] added to the denominator for numerical stability
      # @param weight_decay [Numeric] decoupled weight decay coefficient
      # @param amsgrad [Boolean] use the AMSGrad variant
      # @raise [ArgumentError] if a hyperparameter is out of range
      def initialize(params, lr: 1e-3, betas: [0.9, 0.999], eps: 1e-8, weight_decay: 1e-2, amsgrad: false)
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
        raise ArgumentError, "Invalid epsilon value: #{eps}" if eps < 0
        raise ArgumentError, "Invalid beta parameter at index 0: #{betas[0]}" if betas[0] < 0 || betas[0] >= 1
        raise ArgumentError, "Invalid beta parameter at index 1: #{betas[1]}" if betas[1] < 0 || betas[1] >= 1
        # same check the sibling optimizers (Adadelta, Adagrad, Adamax, ASGD) perform
        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0

        defaults = {lr: lr, betas: betas, eps: eps, weight_decay: weight_decay, amsgrad: amsgrad}
        super(params, defaults)
      end

      # Performs one optimization step over every parameter that has a gradient.
      #
      # @param closure [#call, nil] optional callable that re-evaluates the model
      #   and returns the loss
      # @return the closure's result, or nil when no closure is given
      # @raise [Error] if a gradient is sparse
      def step(closure = nil)
        loss = nil
        loss = closure.call if closure

        @param_groups.each do |group|
          group[:params].each do |p|
            next unless p.grad

            grad = p.grad.data
            # Validate before touching p.data, so an unsupported gradient does
            # not leave the parameter already decayed when the error is raised.
            if grad.sparse?
              raise Error, "AdamW does not support sparse gradients, please consider SparseAdam instead"
            end

            # Perform stepweight decay
            p.data.mul!(1 - group[:lr] * group[:weight_decay])

            # Perform optimization step
            amsgrad = group[:amsgrad]

            state = @state[p]

            # State initialization
            if state.size == 0
              state[:step] = 0
              # Exponential moving average of gradient values
              state[:exp_avg] = Torch.zeros_like(p.data)
              # Exponential moving average of squared gradient values
              state[:exp_avg_sq] = Torch.zeros_like(p.data)
              if amsgrad
                # Maintains max of all exp. moving avg. of sq. grad. values
                state[:max_exp_avg_sq] = Torch.zeros_like(p.data)
              end
            end

            exp_avg, exp_avg_sq = state[:exp_avg], state[:exp_avg_sq]
            max_exp_avg_sq = state[:max_exp_avg_sq] if amsgrad
            beta1, beta2 = group[:betas]

            state[:step] += 1
            bias_correction1 = 1 - beta1 ** state[:step]
            bias_correction2 = 1 - beta2 ** state[:step]

            # Decay the first and second moment running average coefficient
            exp_avg.mul!(beta1).add!(1 - beta1, grad)
            exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
            if amsgrad
              # Maintains the maximum of all 2nd moment running avg. till now
              Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
              # Use the max. for normalizing running avg. of gradient
              denom = (max_exp_avg_sq.sqrt / Math.sqrt(bias_correction2)).add!(group[:eps])
            else
              denom = (exp_avg_sq.sqrt / Math.sqrt(bias_correction2)).add!(group[:eps])
            end

            step_size = group[:lr] / bias_correction1

            p.data.addcdiv!(-step_size, exp_avg, denom)
          end
        end

        loss
      end
    end
  end
end
|
@@ -0,0 +1,65 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/optim/asgd.py
module Torch
  module Optim
    # Averaged SGD: SGD with a decaying step size eta, plus a running average
    # of the parameters (state[:ax]) that starts once the step count passes t0.
    class ASGD < Optimizer
      # @param params parameters (or parameter groups) to optimize
      # @param lr [Numeric] initial step size
      # @param lambd [Numeric] decay term (used for both parameter shrinkage and eta decay)
      # @param alpha [Numeric] power for the eta update
      # @param t0 [Numeric] step count after which averaging weights start to decay
      # @param weight_decay [Numeric] L2 penalty coefficient added to the gradient
      # @raise [ArgumentError] if lr or weight_decay is negative
      def initialize(params, lr: 1e-2, lambd: 1e-4, alpha: 0.75, t0: 1e6, weight_decay: 0)
        raise ArgumentError, "Invalid learning rate: #{lr}" if lr < 0
        raise ArgumentError, "Invalid weight_decay value: #{weight_decay}" if weight_decay < 0

        defaults = {lr: lr, lambd: lambd, alpha: alpha, t0: t0, weight_decay: weight_decay}
        super(params, defaults)
      end

      # Performs one optimization step over every parameter that has a gradient.
      #
      # @param closure [#call, nil] optional callable that re-evaluates the model
      #   and returns the loss
      # @return the closure's result, or nil when no closure is given
      # @raise [Error] if a gradient is sparse
      def step(closure = nil)
        loss = nil
        loss = closure.call if closure

        @param_groups.each do |group|
          group[:params].each do |p|
            next unless p.grad

            grad = p.grad.data
            if grad.sparse?
              raise Error, "ASGD does not support sparse gradients"
            end
            state = @state[p]

            # State initialization
            if state.size == 0
              state[:step] = 0
              state[:eta] = group[:lr]
              state[:mu] = 1
              state[:ax] = Torch.zeros_like(p.data)
            end

            state[:step] += 1

            if group[:weight_decay] != 0
              grad = grad.add(group[:weight_decay], p.data)
            end

            # decay term
            p.data.mul!(1 - group[:lambd] * state[:eta])

            # update parameter
            p.data.add!(-state[:eta], grad)

            # averaging: ax <- ax + mu * (p - ax); with mu == 1 that is a plain copy
            if state[:mu] != 1
              state[:ax].add!(p.data.sub(state[:ax]).mul(state[:mu]))
            else
              state[:ax].copy!(p.data)
            end

            # update eta and mu. Both use float division: with Integer
            # hyperparameters (e.g. t0: 10) Ruby's `/` would truncate, making
            # mu 0 (averaging silently stops) or eta 0 (updates stop).
            state[:eta] = group[:lr].to_f / ((1 + group[:lambd] * group[:lr] * state[:step]) ** group[:alpha])
            state[:mu] = 1.0 / [1, state[:step] - group[:t0]].max
          end
        end

        loss
      end
    end
  end
end
|