torch-rb 0.3.6 → 0.5.0
- checksums.yaml +4 -4
- data/CHANGELOG.md +27 -0
- data/README.md +3 -1
- data/codegen/function.rb +134 -0
- data/codegen/generate_functions.rb +557 -0
- data/{lib/torch/native → codegen}/native_functions.yaml +2363 -714
- data/ext/torch/ext.cpp +78 -89
- data/ext/torch/extconf.rb +5 -2
- data/ext/torch/nn_functions.h +6 -0
- data/ext/torch/ruby_arg_parser.cpp +593 -0
- data/ext/torch/ruby_arg_parser.h +397 -0
- data/ext/torch/{templates.hpp → templates.h} +46 -77
- data/ext/torch/tensor_functions.h +6 -0
- data/ext/torch/torch_functions.h +6 -0
- data/ext/torch/utils.h +42 -0
- data/ext/torch/{templates.cpp → wrap_outputs.h} +44 -8
- data/lib/torch.rb +35 -62
- data/lib/torch/nn/functional.rb +136 -16
- data/lib/torch/nn/init.rb +5 -19
- data/lib/torch/nn/module.rb +4 -1
- data/lib/torch/nn/upsample.rb +31 -0
- data/lib/torch/optim/adadelta.rb +4 -4
- data/lib/torch/optim/adagrad.rb +3 -3
- data/lib/torch/optim/adam.rb +4 -4
- data/lib/torch/optim/adamax.rb +3 -3
- data/lib/torch/optim/adamw.rb +3 -3
- data/lib/torch/optim/asgd.rb +2 -2
- data/lib/torch/optim/rmsprop.rb +7 -7
- data/lib/torch/optim/rprop.rb +1 -1
- data/lib/torch/optim/sgd.rb +5 -5
- data/lib/torch/tensor.rb +36 -110
- data/lib/torch/version.rb +1 -1
- metadata +19 -14
- data/lib/torch/native/dispatcher.rb +0 -48
- data/lib/torch/native/function.rb +0 -119
- data/lib/torch/native/generator.rb +0 -168
- data/lib/torch/native/parser.rb +0 -148
data/lib/torch/nn/init.rb
CHANGED
@@ -14,25 +14,11 @@ module Torch
        _normal!(tensor, mean, std)
      end

-      def constant!(tensor, val)
-        _constant!(tensor, val)
-      end
-
-      def ones!(tensor)
-        _ones!(tensor)
-      end
-
-      def zeros!(tensor)
-        _zeros!(tensor)
-      end
-
-      def eye!(tensor)
-        _eye!(tensor)
-      end
-
-      def dirac!(tensor)
-        _dirac!(tensor)
-      end
+      alias_method :constant!, :_constant!
+      alias_method :ones!, :_ones!
+      alias_method :zeros!, :_zeros!
+      alias_method :eye!, :_eye!
+      alias_method :dirac!, :_dirac!

      def xavier_uniform!(tensor, gain: 1.0)
        _xavier_uniform!(tensor, gain)
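The alias_method swap above removes a layer of Ruby dispatch by binding the public names straight to the generated natives; callers are unaffected. A minimal usage sketch (tensor shape illustrative):

  require "torch"

  w = Torch.empty(3, 3)
  Torch::NN::Init.ones!(w)  # now delegates directly to the native _ones!
  Torch::NN::Init.eye!(w)   # identity initialization, in place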
data/lib/torch/nn/upsample.rb
ADDED
@@ -0,0 +1,31 @@
+module Torch
+  module NN
+    class Upsample < Module
+      def initialize(size: nil, scale_factor: nil, mode: "nearest", align_corners: nil)
+        super()
+        @size = size
+        if scale_factor.is_a?(Array)
+          @scale_factor = scale_factor.map(&:to_f)
+        else
+          @scale_factor = scale_factor ? scale_factor.to_f : nil
+        end
+        @mode = mode
+        @align_corners = align_corners
+      end
+
+      def forward(input)
+        F.interpolate(input, size: @size, scale_factor: @scale_factor, mode: @mode, align_corners: @align_corners)
+      end
+
+      def extra_inspect
+        if !@scale_factor.nil?
+          info = "scale_factor: #{@scale_factor.inspect}"
+        else
+          info = "size: #{@size.inspect}"
+        end
+        info += ", mode: #{@mode.inspect}"
+        info
+      end
+    end
+  end
+end
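For context, a usage sketch of the new Upsample module (values illustrative):

  require "torch"

  up = Torch::NN::Upsample.new(scale_factor: 2, mode: "nearest")
  x = Torch.rand(1, 1, 4, 4)   # NCHW input
  y = up.call(x)               # forward delegates to F.interpolate
  y.shape                      # => [1, 1, 8, 8]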
data/lib/torch/optim/adadelta.rb
CHANGED
@@ -39,14 +39,14 @@ module Torch
          state[:step] += 1

          if group[:weight_decay] != 0
-            grad = grad.add(group[:weight_decay], p.data)
+            grad = grad.add(p.data, alpha: group[:weight_decay])
          end

-          square_avg.mul!(rho).addcmul!(1 - rho, grad, grad)
+          square_avg.mul!(rho).addcmul!(grad, grad, value: 1 - rho)
          std = square_avg.add(eps).sqrt!
          delta = acc_delta.add(eps).sqrt!.div!(std).mul!(grad)
-          p.data.add!(-group[:lr], delta)
-          acc_delta.mul!(rho).addcmul!(1 - rho, delta, delta)
+          p.data.add!(delta, alpha: -group[:lr])
+          acc_delta.mul!(rho).addcmul!(delta, delta, value: 1 - rho)
        end
      end

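The change repeated across all of the optimizer files is the argument convention for scaled in-place ops: 0.3.x used the old PyTorch-style positional form add!(scalar, tensor), while 0.5.0 takes the tensor first and the multiplier as an alpha: keyword, matching PyTorch 1.5+. Both spellings compute tensor += alpha * other; a quick check (values illustrative):

  t = Torch.ones(2)
  d = Torch.ones(2)
  t.add!(d, alpha: -0.5)   # t += -0.5 * d => [0.5, 0.5]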
data/lib/torch/optim/adagrad.rb
CHANGED
@@ -49,7 +49,7 @@ module Torch
            if p.grad.data.sparse?
              raise Error, "weight_decay option is not compatible with sparse gradients"
            end
-            grad = grad.add(group[:weight_decay], p.data)
+            grad = grad.add(p.data, alpha: group[:weight_decay])
          end

          clr = group[:lr] / (1 + (state[:step] - 1) * group[:lr_decay])
@@ -57,9 +57,9 @@ module Torch
          if grad.sparse?
            raise NotImplementedYet
          else
-            state[:sum].addcmul!(1, grad, grad)
+            state[:sum].addcmul!(grad, grad, value: 1)
            std = state[:sum].sqrt.add!(group[:eps])
-            p.data.addcdiv!(-clr, grad, std)
+            p.data.addcdiv!(grad, std, value: -clr)
          end
        end
      end
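addcmul! and addcdiv! migrate the same way: the multiplier moves from the first positional slot into a value: keyword. Per the PyTorch semantics these compute input += value * t1 * t2 and input += value * t1 / t2. A small sketch with illustrative numbers:

  sum = Torch.zeros(2)
  g = Torch.tensor([2.0, 3.0])
  sum.addcmul!(g, g, value: 1)         # sum += g * g            => [4.0, 9.0]

  param = Torch.zeros(2)
  std = sum.sqrt                       # [2.0, 3.0]
  param.addcdiv!(g, std, value: -0.1)  # param += -0.1 * g / std => [-0.1, -0.1]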
data/lib/torch/optim/adam.rb
CHANGED
@@ -53,12 +53,12 @@ module Torch
          bias_correction2 = 1 - beta2 ** state[:step]

          if group[:weight_decay] != 0
-            grad.add!(group[:weight_decay], p.data)
+            grad.add!(p.data, alpha: group[:weight_decay])
          end

          # Decay the first and second moment running average coefficient
-          exp_avg.mul!(beta1).add!(1 - beta1, grad)
-          exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
+          exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
+          exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
          if amsgrad
            # Maintains the maximum of all 2nd moment running avg. till now
            Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
@@ -70,7 +70,7 @@ module Torch

          step_size = group[:lr] / bias_correction1

-          p.data.addcdiv!(-step_size, exp_avg, denom)
+          p.data.addcdiv!(exp_avg, denom, value: -step_size)
        end
      end

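Unchanged but worth noting for the lines above: Adam divides the learning rate by the first-moment bias correction, so early steps are scaled up. With the common defaults (assumed here), at step 1:

  lr = 0.001
  beta1 = 0.9
  step = 1
  bias_correction1 = 1 - beta1**step  # 0.1 at the first step
  step_size = lr / bias_correction1   # 0.01, i.e. 10x the raw lr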
data/lib/torch/optim/adamax.rb
CHANGED
@@ -42,11 +42,11 @@ module Torch
          state[:step] += 1

          if group[:weight_decay] != 0
-            grad = grad.add(group[:weight_decay], p.data)
+            grad = grad.add(p.data, alpha: group[:weight_decay])
          end

          # Update biased first moment estimate.
-          exp_avg.mul!(beta1).add!(1 - beta1, grad)
+          exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
          # Update the exponentially weighted infinity norm.
          norm_buf = Torch.cat([
            exp_inf.mul!(beta2).unsqueeze(0),
@@ -57,7 +57,7 @@ module Torch
          bias_correction = 1 - beta1 ** state[:step]
          clr = group[:lr] / bias_correction

-          p.data.addcdiv!(-clr, exp_avg, exp_inf)
+          p.data.addcdiv!(exp_avg, exp_inf, value: -clr)
        end
      end

data/lib/torch/optim/adamw.rb
CHANGED
@@ -58,8 +58,8 @@ module Torch
          bias_correction2 = 1 - beta2 ** state[:step]

          # Decay the first and second moment running average coefficient
-          exp_avg.mul!(beta1).add!(1 - beta1, grad)
-          exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
+          exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
+          exp_avg_sq.mul!(beta2).addcmul!(grad, grad, value: 1 - beta2)
          if amsgrad
            # Maintains the maximum of all 2nd moment running avg. till now
            Torch.max(max_exp_avg_sq, exp_avg_sq, out: max_exp_avg_sq)
@@ -71,7 +71,7 @@ module Torch

          step_size = group[:lr] / bias_correction1

-          p.data.addcdiv!(-step_size, exp_avg, denom)
+          p.data.addcdiv!(exp_avg, denom, value: -step_size)
        end
      end

data/lib/torch/optim/asgd.rb
CHANGED
@@ -36,14 +36,14 @@ module Torch
          state[:step] += 1

          if group[:weight_decay] != 0
-            grad = grad.add(group[:weight_decay], p.data)
+            grad = grad.add(p.data, alpha: group[:weight_decay])
          end

          # decay term
          p.data.mul!(1 - group[:lambd] * state[:eta])

          # update parameter
-          p.data.add!(-state[:eta], grad)
+          p.data.add!(grad, alpha: -state[:eta])

          # averaging
          if state[:mu] != 1
data/lib/torch/optim/rmsprop.rb
CHANGED
@@ -46,25 +46,25 @@ module Torch
          state[:step] += 1

          if group[:weight_decay] != 0
-            grad = grad.add(group[:weight_decay], p.data)
+            grad = grad.add(p.data, alpha: group[:weight_decay])
          end

-          square_avg.mul!(alpha).addcmul!(1 - alpha, grad, grad)
+          square_avg.mul!(alpha).addcmul!(grad, grad, value: 1 - alpha)

          if group[:centered]
            grad_avg = state[:grad_avg]
-            grad_avg.mul!(alpha).add!(1 - alpha, grad)
-            avg = square_avg.addcmul(-1, grad_avg, grad_avg).sqrt!.add!(group[:eps])
+            grad_avg.mul!(alpha).add!(grad, alpha: 1 - alpha)
+            avg = square_avg.addcmul(grad_avg, grad_avg, value: -1).sqrt!.add!(group[:eps])
          else
            avg = square_avg.sqrt.add!(group[:eps])
          end

          if group[:momentum] > 0
            buf = state[:momentum_buffer]
-            buf.mul!(group[:momentum]).addcdiv!(grad, avg)
-            p.data.add!(-group[:lr], buf)
+            buf.mul!(group[:momentum]).addcdiv!(grad, avg, value: 1)
+            p.data.add!(buf, alpha: -group[:lr])
          else
-            p.data.addcdiv!(-group[:lr], grad, avg)
+            p.data.addcdiv!(grad, avg, value: -group[:lr])
          end
        end
      end
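In the centered branch, the rewritten line computes the running variance estimate E[g^2] - E[g]^2 before taking the square root, so avg approximates the standard deviation of the gradient rather than its RMS. Illustrative values:

  square_avg = Torch.tensor([4.0])  # running E[g^2]
  grad_avg = Torch.tensor([1.0])    # running E[g]
  # E[g^2] - E[g]^2 = 4 - 1 = 3, so the denominator is about sqrt(3)
  avg = square_avg.addcmul(grad_avg, grad_avg, value: -1).sqrt!.add!(1e-8)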
data/lib/torch/optim/rprop.rb
CHANGED
data/lib/torch/optim/sgd.rb
CHANGED
@@ -32,24 +32,24 @@ module Torch
          next unless p.grad
          d_p = p.grad.data
          if weight_decay != 0
-            d_p.add!(weight_decay, p.data)
+            d_p.add!(p.data, alpha: weight_decay)
          end
          if momentum != 0
            param_state = @state[p]
-            if !param_state.key(:momentum_buffer)
+            if !param_state.key?(:momentum_buffer)
              buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
            else
              buf = param_state[:momentum_buffer]
-              buf.mul!(momentum).add!(1 - dampening, d_p)
+              buf.mul!(momentum).add!(d_p, alpha: 1 - dampening)
            end
            if nesterov
-              d_p = d_p.add(momentum, buf)
+              d_p = d_p.add(buf, alpha: momentum)
            else
              d_p = buf
            end
          end

-          p.data.add!(-group[:lr], d_p)
+          p.data.add!(d_p, alpha: -group[:lr])
        end
      end

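The key? change above is a real bug fix, not just a rename: Hash#key does a reverse lookup (it returns the first key whose value equals the argument), so the old check never found an existing momentum buffer. In plain Ruby:

  state = {}
  state.key(:momentum_buffer)   # => nil: asks which key maps TO this value
  state.key?(:momentum_buffer)  # => false: asks whether the key exists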
data/lib/torch/tensor.rb
CHANGED
@@ -8,6 +8,18 @@ module Torch
    alias_method :ndim, :dim
    alias_method :ndimension, :dim

+    # use alias_method for performance
+    alias_method :+, :add
+    alias_method :-, :sub
+    alias_method :*, :mul
+    alias_method :/, :div
+    alias_method :%, :remainder
+    alias_method :**, :pow
+    alias_method :-@, :neg
+    alias_method :&, :logical_and
+    alias_method :|, :logical_or
+    alias_method :^, :logical_xor
+
    def self.new(*args)
      FloatTensor.new(*args)
    end
@@ -48,6 +60,11 @@ module Torch
    end

    def to(device = nil, dtype: nil, non_blocking: false, copy: false)
+      if device.is_a?(Symbol) && !dtype
+        dtype = device
+        device = nil
+      end
+
      device ||= self.device
      device = Device.new(device) if device.is_a?(String)

@@ -68,14 +85,18 @@ module Torch

    def size(dim = nil)
      if dim
-        _size_int(dim)
+        _size(dim)
      else
        shape
      end
    end

-    def stride(dim)
-      _stride(dim)
+    def stride(dim = nil)
+      if dim
+        _stride(dim)
+      else
+        _strides
+      end
    end

    # mirror Python len()
@@ -119,60 +140,14 @@ module Torch
    end

    def type(dtype)
-      enum = DTYPE_TO_ENUM[dtype]
-      raise Error, "Invalid type: #{dtype}" unless enum
-      _type(enum)
-    end
-
-    def reshape(*size)
-      # Python doesn't check if size == 1
-      size = size.first if size.size == 1 && size.first.is_a?(Array)
-      _reshape(size)
-    end
-
-    def view(*size)
-      size = size.first if size.size == 1 && size.first.is_a?(Array)
-      _view(size)
-    end
-
-    def +(other)
-      add(other)
-    end
-
-    def -(other)
-      sub(other)
-    end
-
-    def *(other)
-      mul(other)
-    end
-
-    def /(other)
-      div(other)
-    end
-
-    def %(other)
-      remainder(other)
-    end
-
-    def **(other)
-      pow(other)
-    end
-
-    def -@
-      neg
-    end
-
-    def &(other)
-      logical_and(other)
-    end
-
-    def |(other)
-      logical_or(other)
-    end
-
-    def ^(other)
-      logical_xor(other)
-    end
+      if dtype.is_a?(Class)
+        raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)
+        dtype.new(self)
+      else
+        enum = DTYPE_TO_ENUM[dtype]
+        raise Error, "Invalid type: #{dtype}" unless enum
+        _type(enum)
+      end
    end

    # TODO better compare?
@@ -183,7 +158,7 @@ module Torch
    # based on python_variable_indexing.cpp and
    # https://pytorch.org/cppdocs/notes/tensor_indexing.html
    def [](*indexes)
-      _index(tensor_indexes(indexes))
+      _index(indexes)
    end

    # based on python_variable_indexing.cpp and
@@ -191,62 +166,13 @@ module Torch
    def []=(*indexes, value)
      raise ArgumentError, "Tensor does not support deleting items" if value.nil?
      value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
-      _index_put_custom(tensor_indexes(indexes), value)
-    end
-
-    # native functions that need manually defined
-
-    # value and other are swapped for some methods
-    def add!(value = 1, other)
-      if other.is_a?(Numeric)
-        _add__scalar(other, value)
-      else
-        _add__tensor(other, value)
-      end
+      _index_put_custom(indexes, value)
    end

    # parser can't handle overlap, so need to handle manually
    def random!(*args)
-      case args.size
-      when 1
-        _random__to(*args)
-      when 2
-        _random__from(*args)
-      else
-        _random_(*args)
-      end
-    end
-
-    def clamp!(min, max)
-      _clamp_min_(min)
-      _clamp_max_(max)
-    end
-
-    private
-
-    def tensor_indexes(indexes)
-      indexes.map do |index|
-        case index
-        when Integer
-          TensorIndex.integer(index)
-        when Range
-          finish = index.end || -1
-          if finish == -1 && !index.exclude_end?
-            finish = nil
-          else
-            finish += 1 unless index.exclude_end?
-          end
-          TensorIndex.slice(index.begin, finish)
-        when Tensor
-          TensorIndex.tensor(index)
-        when nil
-          TensorIndex.none
-        when true, false
-          TensorIndex.boolean(index)
-        else
-          raise Error, "Unsupported index type: #{index.class.name}"
-        end
-      end
+      return _random!(0, *args) if args.size == 1
+      _random!(*args)
    end
  end
end