torch-rb 0.4.1 → 0.4.2

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: a40eda909da15ec34573f3fa447519c3cb5553092d6eb35189d02c88154f1669
- data.tar.gz: cb795ca4c53189534f306c874f90db531c3b14fd17a14d83a592d2e73dd255ff
+ metadata.gz: cab9683b82c74698dc3f4c60bf56719ee4d17f6c994ddf913502e62e73795b35
+ data.tar.gz: bc4b74a0786f0dbdb571fadcce37812c2d5a38b8bd7d75f5644d9159a61e385d
  SHA512:
- metadata.gz: 750e2103e8ab7b029f7f3acd439088ba41890d1aabab5761d1ea768f556ae7550d2cef91545e83eebb9b09ef9b76d9dac939f9fa80a7c1478dd7d4668a2a6c0e
- data.tar.gz: d22f3569dc06cc653bfc0d2786eb952ffad858dd2b235f736ebbabbc2f26b73f4c5b92c3776f2004b8076818730de1f1012b7a9ef615a84909c806abf7e96f52
+ metadata.gz: f237f8e517647b08069e712af9662418cbcbfb3fb7f67965a6a8d55ef9096e3d141f27b6f5740f77873ea87281f008e2eece145f4c7639ef2b85c18c60b6cd8d
+ data.tar.gz: 5ce048f61657af95cd7084d9b115c9ba6b82b97449603eaef9217786b8cf5a0459d1116c8ef3bccc7cc21bbdd9b600ba3a6918bfae70004634ae46bc797532c9
@@ -1,3 +1,7 @@
+ ## 0.4.2 (2020-10-27)
+
+ - Fixed errors with optimizer options
+
  ## 0.4.1 (2020-10-12)

  - Fixed installation error with Ruby < 2.7
@@ -39,7 +39,7 @@ module Torch
  state[:step] += 1

  if group[:weight_decay] != 0
- grad = grad.add(group[:weight_decay], p.data)
+ grad = grad.add(p.data, alpha: group[:weight_decay])
  end

  square_avg.mul!(rho).addcmul!(1 - rho, grad, grad)
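This hunk, and the matching weight_decay hunks below, replace the positional-scalar call to Tensor#add with the keyword form, where a.add(b, alpha: s) computes a + s * b. A minimal sketch of the two call styles, with illustrative tensors and weight_decay value that are not taken from the gem:

    require "torch"

    grad = Torch.ones(3)
    p_data = Torch.tensor([2.0, 2.0, 2.0])
    weight_decay = 0.01

    # old form: grad.add(weight_decay, p_data)
    # new form: grad + weight_decay * p_data
    decayed = grad.add(p_data, alpha: weight_decay)

The same keyword style is applied to the weight_decay handling in each optimizer hunk that follows.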
@@ -49,7 +49,7 @@ module Torch
  if p.grad.data.sparse?
  raise Error, "weight_decay option is not compatible with sparse gradients"
  end
- grad = grad.add(group[:weight_decay], p.data)
+ grad = grad.add(p.data, alpha: group[:weight_decay])
  end

  clr = group[:lr] / (1 + (state[:step] - 1) * group[:lr_decay])
@@ -42,7 +42,7 @@ module Torch
  state[:step] += 1

  if group[:weight_decay] != 0
- grad = grad.add(group[:weight_decay], p.data)
+ grad = grad.add(p.data, alpha: group[:weight_decay])
  end

  # Update biased first moment estimate.
@@ -36,7 +36,7 @@ module Torch
  state[:step] += 1

  if group[:weight_decay] != 0
- grad = grad.add(group[:weight_decay], p.data)
+ grad = grad.add(p.data, alpha: group[:weight_decay])
  end

  # decay term
@@ -46,23 +46,23 @@ module Torch
  state[:step] += 1

  if group[:weight_decay] != 0
- grad = grad.add(group[:weight_decay], p.data)
+ grad = grad.add(p.data, alpha: group[:weight_decay])
  end

  square_avg.mul!(alpha).addcmul!(1 - alpha, grad, grad)

  if group[:centered]
  grad_avg = state[:grad_avg]
- grad_avg.mul!(alpha).add!(1 - alpha, grad)
- avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt!.add!(group[:eps])
+ grad_avg.mul!(alpha).add!(grad, alpha: 1 - alpha)
+ avg = square_avg.addcmul(grad_avg, grad_avg, value: -1).sqrt!.add!(group[:eps])
  else
  avg = square_avg.sqrt.add!(group[:eps])
  end

  if group[:momentum] > 0
  buf = state[:momentum_buffer]
- buf.mul!(group[:momentum]).addcdiv!(grad, avg)
- p.data.add!(-group[:lr], buf)
+ buf.mul!(group[:momentum]).addcdiv!(1, grad, avg)
+ p.data.add!(buf, alpha: -group[:lr])
  else
  p.data.addcdiv!(-group[:lr], grad, avg)
  end
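The centered and momentum branches of this hunk get the same treatment: addcmul moves its scalar to the value: keyword (t.addcmul(a, b, value: v) computes t + v * a * b), and the in-place add! takes alpha:. A small sketch of the keyword addcmul call, with made-up tensor values:

    require "torch"

    square_avg = Torch.tensor([4.0, 9.0])
    grad_avg = Torch.tensor([1.0, 2.0])

    # square_avg + (-1) * grad_avg * grad_avg, i.e. the variance-style
    # term built in the centered branch above
    avg = square_avg.addcmul(grad_avg, grad_avg, value: -1).sqrt!

Note that the updated addcdiv! call above still passes its scalar positionally, addcdiv!(1, grad, avg), as shown in the hunk.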
@@ -36,14 +36,14 @@ module Torch
  end
  if momentum != 0
  param_state = @state[p]
- if !param_state.key(:momentum_buffer)
+ if !param_state.key?(:momentum_buffer)
  buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
  else
  buf = param_state[:momentum_buffer]
  buf.mul!(momentum).add!(d_p, alpha: 1 - dampening)
  end
  if nesterov
- d_p = d_p.add(momentum, buf)
+ d_p = d_p.add(buf, alpha: momentum)
  else
  d_p = buf
  end
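Besides the alpha: change, this hunk fixes a genuine bug: Hash#key looks a key up by value (returning nil when no value matches), while Hash#key? tests whether the key is present, so the old check would treat the momentum buffer as missing and re-create it on every step. A plain-Ruby illustration with a made-up hash:

    param_state = { momentum_buffer: "buf" }

    param_state.key(:momentum_buffer)   # => nil  (no value equals :momentum_buffer)
    param_state.key?(:momentum_buffer)  # => true (the key exists)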
@@ -1,3 +1,3 @@
  module Torch
- VERSION = "0.4.1"
+ VERSION = "0.4.2"
  end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: torch-rb
  version: !ruby/object:Gem::Version
- version: 0.4.1
+ version: 0.4.2
  platform: ruby
  authors:
  - Andrew Kane
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2020-10-13 00:00:00.000000000 Z
+ date: 2020-10-28 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
  name: rice
@@ -289,7 +289,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
  - !ruby/object:Gem::Version
  version: '0'
  requirements: []
- rubygems_version: 3.0.3
+ rubygems_version: 3.1.4
  signing_key:
  specification_version: 4
  summary: Deep learning for Ruby, powered by LibTorch