torch-rb 0.4.1 → 0.4.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/lib/torch/optim/adadelta.rb +1 -1
- data/lib/torch/optim/adagrad.rb +1 -1
- data/lib/torch/optim/adamax.rb +1 -1
- data/lib/torch/optim/asgd.rb +1 -1
- data/lib/torch/optim/rmsprop.rb +5 -5
- data/lib/torch/optim/sgd.rb +2 -2
- data/lib/torch/version.rb +1 -1
- metadata +3 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: cab9683b82c74698dc3f4c60bf56719ee4d17f6c994ddf913502e62e73795b35
|
4
|
+
data.tar.gz: bc4b74a0786f0dbdb571fadcce37812c2d5a38b8bd7d75f5644d9159a61e385d
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: f237f8e517647b08069e712af9662418cbcbfb3fb7f67965a6a8d55ef9096e3d141f27b6f5740f77873ea87281f008e2eece145f4c7639ef2b85c18c60b6cd8d
|
7
|
+
data.tar.gz: 5ce048f61657af95cd7084d9b115c9ba6b82b97449603eaef9217786b8cf5a0459d1116c8ef3bccc7cc21bbdd9b600ba3a6918bfae70004634ae46bc797532c9
|
data/CHANGELOG.md
CHANGED
data/lib/torch/optim/adadelta.rb
CHANGED
data/lib/torch/optim/adagrad.rb
CHANGED
@@ -49,7 +49,7 @@ module Torch
|
|
49
49
|
if p.grad.data.sparse?
|
50
50
|
raise Error, "weight_decay option is not compatible with sparse gradients"
|
51
51
|
end
|
52
|
-
grad = grad.add(group[:weight_decay]
|
52
|
+
grad = grad.add(p.data, alpha: group[:weight_decay])
|
53
53
|
end
|
54
54
|
|
55
55
|
clr = group[:lr] / (1 + (state[:step] - 1) * group[:lr_decay])
|
data/lib/torch/optim/adamax.rb
CHANGED
data/lib/torch/optim/asgd.rb
CHANGED
data/lib/torch/optim/rmsprop.rb
CHANGED
@@ -46,23 +46,23 @@ module Torch
|
|
46
46
|
state[:step] += 1
|
47
47
|
|
48
48
|
if group[:weight_decay] != 0
|
49
|
-
grad = grad.add(group[:weight_decay]
|
49
|
+
grad = grad.add(p.data, alpha: group[:weight_decay])
|
50
50
|
end
|
51
51
|
|
52
52
|
square_avg.mul!(alpha).addcmul!(1 - alpha, grad, grad)
|
53
53
|
|
54
54
|
if group[:centered]
|
55
55
|
grad_avg = state[:grad_avg]
|
56
|
-
grad_avg.mul!(alpha).add!(1 - alpha
|
57
|
-
avg = square_avg.addcmul(
|
56
|
+
grad_avg.mul!(alpha).add!(grad, alpha: 1 - alpha)
|
57
|
+
avg = square_avg.addcmul(grad_avg, grad_avg, value: -1).sqrt!.add!(group[:eps])
|
58
58
|
else
|
59
59
|
avg = square_avg.sqrt.add!(group[:eps])
|
60
60
|
end
|
61
61
|
|
62
62
|
if group[:momentum] > 0
|
63
63
|
buf = state[:momentum_buffer]
|
64
|
-
buf.mul!(group[:momentum]).addcdiv!(grad, avg)
|
65
|
-
p.data.add!(-group[:lr]
|
64
|
+
buf.mul!(group[:momentum]).addcdiv!(1, grad, avg)
|
65
|
+
p.data.add!(buf, alpha: -group[:lr])
|
66
66
|
else
|
67
67
|
p.data.addcdiv!(-group[:lr], grad, avg)
|
68
68
|
end
|
data/lib/torch/optim/sgd.rb
CHANGED
@@ -36,14 +36,14 @@ module Torch
|
|
36
36
|
end
|
37
37
|
if momentum != 0
|
38
38
|
param_state = @state[p]
|
39
|
-
if !param_state.key(:momentum_buffer)
|
39
|
+
if !param_state.key?(:momentum_buffer)
|
40
40
|
buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
|
41
41
|
else
|
42
42
|
buf = param_state[:momentum_buffer]
|
43
43
|
buf.mul!(momentum).add!(d_p, alpha: 1 - dampening)
|
44
44
|
end
|
45
45
|
if nesterov
|
46
|
-
d_p = d_p.add(
|
46
|
+
d_p = d_p.add(buf, alpha: momentum)
|
47
47
|
else
|
48
48
|
d_p = buf
|
49
49
|
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.4.
|
4
|
+
version: 0.4.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-10-
|
11
|
+
date: 2020-10-28 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -289,7 +289,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
289
289
|
- !ruby/object:Gem::Version
|
290
290
|
version: '0'
|
291
291
|
requirements: []
|
292
|
-
rubygems_version: 3.
|
292
|
+
rubygems_version: 3.1.4
|
293
293
|
signing_key:
|
294
294
|
specification_version: 4
|
295
295
|
summary: Deep learning for Ruby, powered by LibTorch
|