torch-rb 0.3.7 → 0.5.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +23 -0
- data/README.md +2 -1
- data/codegen/function.rb +134 -0
- data/codegen/generate_functions.rb +557 -0
- data/{lib/torch/native → codegen}/native_functions.yaml +2363 -714
- data/ext/torch/ext.cpp +56 -87
- data/ext/torch/extconf.rb +2 -2
- data/ext/torch/nn_functions.h +6 -0
- data/ext/torch/ruby_arg_parser.cpp +593 -0
- data/ext/torch/ruby_arg_parser.h +397 -0
- data/ext/torch/{templates.hpp → templates.h} +31 -51
- data/ext/torch/tensor_functions.h +6 -0
- data/ext/torch/torch_functions.h +6 -0
- data/ext/torch/utils.h +42 -0
- data/ext/torch/{templates.cpp → wrap_outputs.h} +23 -15
- data/lib/torch.rb +5 -69
- data/lib/torch/nn/functional.rb +30 -16
- data/lib/torch/nn/init.rb +5 -19
- data/lib/torch/optim/adadelta.rb +4 -4
- data/lib/torch/optim/adagrad.rb +3 -3
- data/lib/torch/optim/adam.rb +4 -4
- data/lib/torch/optim/adamax.rb +3 -3
- data/lib/torch/optim/adamw.rb +3 -3
- data/lib/torch/optim/asgd.rb +2 -2
- data/lib/torch/optim/rmsprop.rb +7 -7
- data/lib/torch/optim/rprop.rb +1 -1
- data/lib/torch/optim/sgd.rb +5 -5
- data/lib/torch/tensor.rb +28 -103
- data/lib/torch/version.rb +1 -1
- metadata +18 -14
- data/lib/torch/native/dispatcher.rb +0 -70
- data/lib/torch/native/function.rb +0 -200
- data/lib/torch/native/generator.rb +0 -178
- data/lib/torch/native/parser.rb +0 -117
data/lib/torch/optim/adam.rb
CHANGED
@@ -53,12 +53,12 @@ module Torch
|
|
53
53
|
bias_correction2 = 1 - beta2 ** state[:step]
|
54
54
|
|
55
55
|
if group[:weight_decay] != 0
|
56
|
-
grad.add!(group[:weight_decay]
|
56
|
+
grad.add!(p.data, alpha: group[:weight_decay])
|
57
57
|
end
|
58
58
|
|
59
59
|
# Decay the first and second moment running average coefficient
|
60
|
-
exp_avg.mul!(beta1).add!(1 - beta1
|
61
|
-
exp_avg_sq.mul!(beta2).addcmul!(1 - beta2
|
60
|
+
exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
|
61
|
+
exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
|
62
62
|
if amsgrad
|
63
63
|
# Maintains the maximum of all 2nd moment running avg. till now
|
64
64
|
Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
|
@@ -70,7 +70,7 @@ module Torch
|
|
70
70
|
|
71
71
|
step_size = group[:lr] / bias_correction1
|
72
72
|
|
73
|
-
p.data.addcdiv!(
|
73
|
+
p.data.addcdiv!(exp_avg, denom, value: -step_size)
|
74
74
|
end
|
75
75
|
end
|
76
76
|
|
data/lib/torch/optim/adamax.rb
CHANGED
@@ -42,11 +42,11 @@ module Torch
|
|
42
42
|
state[:step] += 1
|
43
43
|
|
44
44
|
if group[:weight_decay] != 0
|
45
|
-
grad = grad.add(group[:weight_decay]
|
45
|
+
grad = grad.add(p.data, alpha: group[:weight_decay])
|
46
46
|
end
|
47
47
|
|
48
48
|
# Update biased first moment estimate.
|
49
|
-
exp_avg.mul!(beta1).add!(1 - beta1
|
49
|
+
exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
|
50
50
|
# Update the exponentially weighted infinity norm.
|
51
51
|
norm_buf = Torch.cat([
|
52
52
|
exp_inf.mul!(beta2).unsqueeze(0),
|
@@ -57,7 +57,7 @@ module Torch
|
|
57
57
|
bias_correction = 1 - beta1 ** state[:step]
|
58
58
|
clr = group[:lr] / bias_correction
|
59
59
|
|
60
|
-
p.data.addcdiv!(
|
60
|
+
p.data.addcdiv!(exp_avg, exp_inf, value: -clr)
|
61
61
|
end
|
62
62
|
end
|
63
63
|
|
data/lib/torch/optim/adamw.rb
CHANGED
@@ -58,8 +58,8 @@ module Torch
|
|
58
58
|
bias_correction2 = 1 - beta2 ** state[:step]
|
59
59
|
|
60
60
|
# Decay the first and second moment running average coefficient
|
61
|
-
exp_avg.mul!(beta1).add!(1 - beta1
|
62
|
-
exp_avg_sq.mul!(beta2).addcmul!(1 - beta2
|
61
|
+
exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
|
62
|
+
exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
|
63
63
|
if amsgrad
|
64
64
|
# Maintains the maximum of all 2nd moment running avg. till now
|
65
65
|
Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
|
@@ -71,7 +71,7 @@ module Torch
|
|
71
71
|
|
72
72
|
step_size = group[:lr] / bias_correction1
|
73
73
|
|
74
|
-
p.data.addcdiv!(
|
74
|
+
p.data.addcdiv!(exp_avg, denom, value: -step_size)
|
75
75
|
end
|
76
76
|
end
|
77
77
|
|
data/lib/torch/optim/asgd.rb
CHANGED
@@ -36,14 +36,14 @@ module Torch
|
|
36
36
|
state[:step] += 1
|
37
37
|
|
38
38
|
if group[:weight_decay] != 0
|
39
|
-
grad = grad.add(group[:weight_decay]
|
39
|
+
grad = grad.add(p.data, alpha: group[:weight_decay])
|
40
40
|
end
|
41
41
|
|
42
42
|
# decay term
|
43
43
|
p.data.mul!(1 - group[:lambd] * state[:eta])
|
44
44
|
|
45
45
|
# update parameter
|
46
|
-
p.data.add!(-state[:eta]
|
46
|
+
p.data.add!(grad, alpha: -state[:eta])
|
47
47
|
|
48
48
|
# averaging
|
49
49
|
if state[:mu] != 1
|
data/lib/torch/optim/rmsprop.rb
CHANGED
@@ -46,25 +46,25 @@ module Torch
|
|
46
46
|
state[:step] += 1
|
47
47
|
|
48
48
|
if group[:weight_decay] != 0
|
49
|
-
grad = grad.add(group[:weight_decay]
|
49
|
+
grad = grad.add(p.data, alpha: group[:weight_decay])
|
50
50
|
end
|
51
51
|
|
52
|
-
square_avg.mul!(alpha).addcmul!(1 - alpha
|
52
|
+
square_avg.mul!(alpha).addcmul!(grad, grad, value: 1 - alpha)
|
53
53
|
|
54
54
|
if group[:centered]
|
55
55
|
grad_avg = state[:grad_avg]
|
56
|
-
grad_avg.mul!(alpha).add!(1 - alpha
|
57
|
-
avg = square_avg.addcmul(
|
56
|
+
grad_avg.mul!(alpha).add!(grad, alpha: 1 - alpha)
|
57
|
+
avg = square_avg.addcmul(grad_avg, grad_avg, value: -1).sqrt!.add!(group[:eps])
|
58
58
|
else
|
59
59
|
avg = square_avg.sqrt.add!(group[:eps])
|
60
60
|
end
|
61
61
|
|
62
62
|
if group[:momentum] > 0
|
63
63
|
buf = state[:momentum_buffer]
|
64
|
-
buf.mul!(group[:momentum]).addcdiv!(grad, avg)
|
65
|
-
p.data.add!(-group[:lr]
|
64
|
+
buf.mul!(group[:momentum]).addcdiv!(grad, avg, value: 1)
|
65
|
+
p.data.add!(buf, alpha: -group[:lr])
|
66
66
|
else
|
67
|
-
p.data.addcdiv!(-group[:lr]
|
67
|
+
p.data.addcdiv!(grad, avg, value: -group[:lr])
|
68
68
|
end
|
69
69
|
end
|
70
70
|
end
|
data/lib/torch/optim/rprop.rb
CHANGED
data/lib/torch/optim/sgd.rb
CHANGED
@@ -32,24 +32,24 @@ module Torch
|
|
32
32
|
next unless p.grad
|
33
33
|
d_p = p.grad.data
|
34
34
|
if weight_decay != 0
|
35
|
-
d_p.add!(
|
35
|
+
d_p.add!(p.data, alpha: weight_decay)
|
36
36
|
end
|
37
37
|
if momentum != 0
|
38
38
|
param_state = @state[p]
|
39
|
-
if !param_state.key(:momentum_buffer)
|
39
|
+
if !param_state.key?(:momentum_buffer)
|
40
40
|
buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
|
41
41
|
else
|
42
42
|
buf = param_state[:momentum_buffer]
|
43
|
-
buf.mul!(momentum).add!(1 - dampening
|
43
|
+
buf.mul!(momentum).add!(d_p, alpha: 1 - dampening)
|
44
44
|
end
|
45
45
|
if nesterov
|
46
|
-
d_p = d_p.add(
|
46
|
+
d_p = d_p.add(buf, alpha: momentum)
|
47
47
|
else
|
48
48
|
d_p = buf
|
49
49
|
end
|
50
50
|
end
|
51
51
|
|
52
|
-
p.data.add!(-group[:lr]
|
52
|
+
p.data.add!(d_p, alpha: -group[:lr])
|
53
53
|
end
|
54
54
|
end
|
55
55
|
|
data/lib/torch/tensor.rb
CHANGED
@@ -8,6 +8,18 @@ module Torch
|
|
8
8
|
alias_method :ndim, :dim
|
9
9
|
alias_method :ndimension, :dim
|
10
10
|
|
11
|
+
# use alias_method for performance
|
12
|
+
alias_method :+, :add
|
13
|
+
alias_method :-, :sub
|
14
|
+
alias_method :*, :mul
|
15
|
+
alias_method :/, :div
|
16
|
+
alias_method :%, :remainder
|
17
|
+
alias_method :**, :pow
|
18
|
+
alias_method :-@, :neg
|
19
|
+
alias_method :&, :logical_and
|
20
|
+
alias_method :|, :logical_or
|
21
|
+
alias_method :^, :logical_xor
|
22
|
+
|
11
23
|
def self.new(*args)
|
12
24
|
FloatTensor.new(*args)
|
13
25
|
end
|
@@ -73,12 +85,20 @@ module Torch
|
|
73
85
|
|
74
86
|
def size(dim = nil)
|
75
87
|
if dim
|
76
|
-
|
88
|
+
_size(dim)
|
77
89
|
else
|
78
90
|
shape
|
79
91
|
end
|
80
92
|
end
|
81
93
|
|
94
|
+
def stride(dim = nil)
|
95
|
+
if dim
|
96
|
+
_stride(dim)
|
97
|
+
else
|
98
|
+
_strides
|
99
|
+
end
|
100
|
+
end
|
101
|
+
|
82
102
|
# mirror Python len()
|
83
103
|
def length
|
84
104
|
size(0)
|
@@ -130,57 +150,6 @@ module Torch
|
|
130
150
|
end
|
131
151
|
end
|
132
152
|
|
133
|
-
def reshape(*size)
|
134
|
-
# Python doesn't check if size == 1, just ignores later arguments
|
135
|
-
size = size.first if size.size == 1 && size.first.is_a?(Array)
|
136
|
-
_reshape(size)
|
137
|
-
end
|
138
|
-
|
139
|
-
def view(*size)
|
140
|
-
size = size.first if size.size == 1 && size.first.is_a?(Array)
|
141
|
-
_view(size)
|
142
|
-
end
|
143
|
-
|
144
|
-
def +(other)
|
145
|
-
add(other)
|
146
|
-
end
|
147
|
-
|
148
|
-
def -(other)
|
149
|
-
sub(other)
|
150
|
-
end
|
151
|
-
|
152
|
-
def *(other)
|
153
|
-
mul(other)
|
154
|
-
end
|
155
|
-
|
156
|
-
def /(other)
|
157
|
-
div(other)
|
158
|
-
end
|
159
|
-
|
160
|
-
def %(other)
|
161
|
-
remainder(other)
|
162
|
-
end
|
163
|
-
|
164
|
-
def **(other)
|
165
|
-
pow(other)
|
166
|
-
end
|
167
|
-
|
168
|
-
def -@
|
169
|
-
neg
|
170
|
-
end
|
171
|
-
|
172
|
-
def &(other)
|
173
|
-
logical_and(other)
|
174
|
-
end
|
175
|
-
|
176
|
-
def |(other)
|
177
|
-
logical_or(other)
|
178
|
-
end
|
179
|
-
|
180
|
-
def ^(other)
|
181
|
-
logical_xor(other)
|
182
|
-
end
|
183
|
-
|
184
153
|
# TODO better compare?
|
185
154
|
def <=>(other)
|
186
155
|
item <=> other
|
@@ -189,7 +158,7 @@ module Torch
|
|
189
158
|
# based on python_variable_indexing.cpp and
|
190
159
|
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
191
160
|
def [](*indexes)
|
192
|
-
_index(
|
161
|
+
_index(indexes)
|
193
162
|
end
|
194
163
|
|
195
164
|
# based on python_variable_indexing.cpp and
|
@@ -197,62 +166,18 @@ module Torch
|
|
197
166
|
def []=(*indexes, value)
|
198
167
|
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
|
199
168
|
value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
|
200
|
-
_index_put_custom(
|
201
|
-
end
|
202
|
-
|
203
|
-
# native functions that need manually defined
|
204
|
-
|
205
|
-
# value and other are swapped for some methods
|
206
|
-
def add!(value = 1, other)
|
207
|
-
if other.is_a?(Numeric)
|
208
|
-
_add__scalar(other, value)
|
209
|
-
else
|
210
|
-
_add__tensor(other, value)
|
211
|
-
end
|
169
|
+
_index_put_custom(indexes, value)
|
212
170
|
end
|
213
171
|
|
214
172
|
# parser can't handle overlap, so need to handle manually
|
215
173
|
def random!(*args)
|
216
|
-
|
217
|
-
|
218
|
-
_random__to(*args)
|
219
|
-
when 2
|
220
|
-
_random__from(*args)
|
221
|
-
else
|
222
|
-
_random_(*args)
|
223
|
-
end
|
174
|
+
return _random!(0, *args) if args.size == 1
|
175
|
+
_random!(*args)
|
224
176
|
end
|
225
177
|
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
end
|
230
|
-
|
231
|
-
private
|
232
|
-
|
233
|
-
def tensor_indexes(indexes)
|
234
|
-
indexes.map do |index|
|
235
|
-
case index
|
236
|
-
when Integer
|
237
|
-
TensorIndex.integer(index)
|
238
|
-
when Range
|
239
|
-
finish = index.end || -1
|
240
|
-
if finish == -1 && !index.exclude_end?
|
241
|
-
finish = nil
|
242
|
-
else
|
243
|
-
finish += 1 unless index.exclude_end?
|
244
|
-
end
|
245
|
-
TensorIndex.slice(index.begin, finish)
|
246
|
-
when Tensor
|
247
|
-
TensorIndex.tensor(index)
|
248
|
-
when nil
|
249
|
-
TensorIndex.none
|
250
|
-
when true, false
|
251
|
-
TensorIndex.boolean(index)
|
252
|
-
else
|
253
|
-
raise Error, "Unsupported index type: #{index.class.name}"
|
254
|
-
end
|
255
|
-
end
|
178
|
+
# center option
|
179
|
+
def stft(*args)
|
180
|
+
Torch.stft(*args)
|
256
181
|
end
|
257
182
|
end
|
258
183
|
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.5.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
|
-
autorequire:
|
8
|
+
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-
|
11
|
+
date: 2020-10-29 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -108,7 +108,7 @@ dependencies:
|
|
108
108
|
- - ">="
|
109
109
|
- !ruby/object:Gem::Version
|
110
110
|
version: 0.1.1
|
111
|
-
description:
|
111
|
+
description:
|
112
112
|
email: andrew@chartkick.com
|
113
113
|
executables: []
|
114
114
|
extensions:
|
@@ -118,19 +118,23 @@ files:
|
|
118
118
|
- CHANGELOG.md
|
119
119
|
- LICENSE.txt
|
120
120
|
- README.md
|
121
|
+
- codegen/function.rb
|
122
|
+
- codegen/generate_functions.rb
|
123
|
+
- codegen/native_functions.yaml
|
121
124
|
- ext/torch/ext.cpp
|
122
125
|
- ext/torch/extconf.rb
|
123
|
-
- ext/torch/
|
124
|
-
- ext/torch/
|
126
|
+
- ext/torch/nn_functions.h
|
127
|
+
- ext/torch/ruby_arg_parser.cpp
|
128
|
+
- ext/torch/ruby_arg_parser.h
|
129
|
+
- ext/torch/templates.h
|
130
|
+
- ext/torch/tensor_functions.h
|
131
|
+
- ext/torch/torch_functions.h
|
132
|
+
- ext/torch/utils.h
|
133
|
+
- ext/torch/wrap_outputs.h
|
125
134
|
- lib/torch-rb.rb
|
126
135
|
- lib/torch.rb
|
127
136
|
- lib/torch/hub.rb
|
128
137
|
- lib/torch/inspector.rb
|
129
|
-
- lib/torch/native/dispatcher.rb
|
130
|
-
- lib/torch/native/function.rb
|
131
|
-
- lib/torch/native/generator.rb
|
132
|
-
- lib/torch/native/native_functions.yaml
|
133
|
-
- lib/torch/native/parser.rb
|
134
138
|
- lib/torch/nn/adaptive_avg_pool1d.rb
|
135
139
|
- lib/torch/nn/adaptive_avg_pool2d.rb
|
136
140
|
- lib/torch/nn/adaptive_avg_pool3d.rb
|
@@ -270,7 +274,7 @@ homepage: https://github.com/ankane/torch.rb
|
|
270
274
|
licenses:
|
271
275
|
- BSD-3-Clause
|
272
276
|
metadata: {}
|
273
|
-
post_install_message:
|
277
|
+
post_install_message:
|
274
278
|
rdoc_options: []
|
275
279
|
require_paths:
|
276
280
|
- lib
|
@@ -285,8 +289,8 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
285
289
|
- !ruby/object:Gem::Version
|
286
290
|
version: '0'
|
287
291
|
requirements: []
|
288
|
-
rubygems_version: 3.1.
|
289
|
-
signing_key:
|
292
|
+
rubygems_version: 3.1.4
|
293
|
+
signing_key:
|
290
294
|
specification_version: 4
|
291
295
|
summary: Deep learning for Ruby, powered by LibTorch
|
292
296
|
test_files: []
|
@@ -1,70 +0,0 @@
|
|
1
|
-
# We use a generic interface for methods (*args, **options)
|
2
|
-
# and this class to determine the C++ method to call
|
3
|
-
#
|
4
|
-
# This is needed since LibTorch uses function overloading,
|
5
|
-
# which isn't available in Ruby or Python
|
6
|
-
#
|
7
|
-
# PyTorch uses this approach, but the parser/dispatcher is written in C++
|
8
|
-
#
|
9
|
-
# We could generate Ruby methods directly, but an advantage of this approach is
|
10
|
-
# arguments and keyword arguments can be used interchangably like in Python,
|
11
|
-
# making it easier to port code
|
12
|
-
|
13
|
-
module Torch
|
14
|
-
module Native
|
15
|
-
module Dispatcher
|
16
|
-
class << self
|
17
|
-
def bind
|
18
|
-
functions = Generator.grouped_functions
|
19
|
-
bind_functions(::Torch, :define_singleton_method, functions[:torch])
|
20
|
-
bind_functions(::Torch::Tensor, :define_method, functions[:tensor])
|
21
|
-
bind_functions(::Torch::NN, :define_singleton_method, functions[:nn])
|
22
|
-
end
|
23
|
-
|
24
|
-
def bind_functions(context, def_method, functions)
|
25
|
-
instance_method = def_method == :define_method
|
26
|
-
functions.group_by(&:ruby_name).sort_by { |g, _| g }.each do |name, funcs|
|
27
|
-
if instance_method
|
28
|
-
funcs.map! { |f| Function.new(f.function) }
|
29
|
-
funcs.each { |f| f.args.reject! { |a| a[:name] == :self } }
|
30
|
-
end
|
31
|
-
|
32
|
-
defined = instance_method ? context.method_defined?(name) : context.respond_to?(name)
|
33
|
-
next if defined && name != "clone"
|
34
|
-
|
35
|
-
# skip parser when possible for performance
|
36
|
-
if funcs.size == 1 && funcs.first.args.size == 0
|
37
|
-
# functions with no arguments
|
38
|
-
if instance_method
|
39
|
-
context.send(:alias_method, name, funcs.first.cpp_name)
|
40
|
-
else
|
41
|
-
context.singleton_class.send(:alias_method, name, funcs.first.cpp_name)
|
42
|
-
end
|
43
|
-
elsif funcs.size == 2 && funcs.map { |f| f.arg_types.values }.sort == [["Scalar"], ["Tensor"]]
|
44
|
-
# functions that take a tensor or scalar
|
45
|
-
scalar_name, tensor_name = funcs.sort_by { |f| f.arg_types.values }.map(&:cpp_name)
|
46
|
-
context.send(def_method, name) do |other|
|
47
|
-
case other
|
48
|
-
when Tensor
|
49
|
-
send(tensor_name, other)
|
50
|
-
else
|
51
|
-
send(scalar_name, other)
|
52
|
-
end
|
53
|
-
end
|
54
|
-
else
|
55
|
-
parser = Parser.new(funcs)
|
56
|
-
|
57
|
-
context.send(def_method, name) do |*args, **options|
|
58
|
-
result = parser.parse(args, options)
|
59
|
-
raise ArgumentError, result[:error] if result[:error]
|
60
|
-
send(result[:name], *result[:args])
|
61
|
-
end
|
62
|
-
end
|
63
|
-
end
|
64
|
-
end
|
65
|
-
end
|
66
|
-
end
|
67
|
-
end
|
68
|
-
end
|
69
|
-
|
70
|
-
Torch::Native::Dispatcher.bind
|