torch-rb 0.4.1 → 0.5.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +24 -0
- data/README.md +9 -1
- data/codegen/generate_functions.rb +13 -8
- data/codegen/native_functions.yaml +2363 -714
- data/ext/torch/ext.cpp +15 -16
- data/ext/torch/ruby_arg_parser.h +26 -2
- data/ext/torch/templates.h +1 -0
- data/ext/torch/wrap_outputs.h +7 -0
- data/lib/torch.rb +5 -7
- data/lib/torch/nn/module.rb +101 -21
- data/lib/torch/optim/adadelta.rb +3 -3
- data/lib/torch/optim/adagrad.rb +3 -3
- data/lib/torch/optim/adam.rb +2 -2
- data/lib/torch/optim/adamax.rb +2 -2
- data/lib/torch/optim/adamw.rb +2 -2
- data/lib/torch/optim/asgd.rb +1 -1
- data/lib/torch/optim/rmsprop.rb +7 -7
- data/lib/torch/optim/rprop.rb +1 -1
- data/lib/torch/optim/sgd.rb +2 -2
- data/lib/torch/tensor.rb +5 -0
- data/lib/torch/version.rb +1 -1
- metadata +3 -3
data/lib/torch/optim/rprop.rb
CHANGED
data/lib/torch/optim/sgd.rb
CHANGED
@@ -36,14 +36,14 @@ module Torch
|
|
36
36
|
end
|
37
37
|
if momentum != 0
|
38
38
|
param_state = @state[p]
|
39
|
-
if !param_state.key(:momentum_buffer)
|
39
|
+
if !param_state.key?(:momentum_buffer)
|
40
40
|
buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
|
41
41
|
else
|
42
42
|
buf = param_state[:momentum_buffer]
|
43
43
|
buf.mul!(momentum).add!(d_p, alpha: 1 - dampening)
|
44
44
|
end
|
45
45
|
if nesterov
|
46
|
-
d_p = d_p.add(
|
46
|
+
d_p = d_p.add(buf, alpha: momentum)
|
47
47
|
else
|
48
48
|
d_p = buf
|
49
49
|
end
|
data/lib/torch/tensor.rb
CHANGED
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.5.3
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date:
|
11
|
+
date: 2021-01-14 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -289,7 +289,7 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
289
289
|
- !ruby/object:Gem::Version
|
290
290
|
version: '0'
|
291
291
|
requirements: []
|
292
|
-
rubygems_version: 3.
|
292
|
+
rubygems_version: 3.2.3
|
293
293
|
signing_key:
|
294
294
|
specification_version: 4
|
295
295
|
summary: Deep learning for Ruby, powered by LibTorch
|