torch-rb 0.3.4 → 0.4.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +2 -1
- data/codegen/function.rb +134 -0
- data/codegen/generate_functions.rb +549 -0
- data/{lib/torch/native → codegen}/native_functions.yaml +0 -0
- data/ext/torch/ext.cpp +76 -87
- data/ext/torch/extconf.rb +5 -2
- data/ext/torch/nn_functions.h +6 -0
- data/ext/torch/ruby_arg_parser.cpp +593 -0
- data/ext/torch/ruby_arg_parser.h +373 -0
- data/ext/torch/{templates.hpp → templates.h} +87 -97
- data/ext/torch/tensor_functions.h +6 -0
- data/ext/torch/torch_functions.h +6 -0
- data/ext/torch/utils.h +42 -0
- data/ext/torch/{templates.cpp → wrap_outputs.h} +44 -7
- data/lib/torch.rb +51 -77
- data/lib/torch/nn/functional.rb +142 -18
- data/lib/torch/nn/init.rb +5 -19
- data/lib/torch/nn/leaky_relu.rb +3 -3
- data/lib/torch/nn/module.rb +9 -1
- data/lib/torch/nn/upsample.rb +31 -0
- data/lib/torch/optim/adadelta.rb +1 -1
- data/lib/torch/optim/adam.rb +2 -2
- data/lib/torch/optim/adamax.rb +1 -1
- data/lib/torch/optim/adamw.rb +1 -1
- data/lib/torch/optim/asgd.rb +1 -1
- data/lib/torch/optim/sgd.rb +3 -3
- data/lib/torch/tensor.rb +36 -115
- data/lib/torch/utils/data/data_loader.rb +2 -0
- data/lib/torch/utils/data/tensor_dataset.rb +2 -0
- data/lib/torch/version.rb +1 -1
- metadata +19 -14
- data/lib/torch/native/dispatcher.rb +0 -48
- data/lib/torch/native/function.rb +0 -115
- data/lib/torch/native/generator.rb +0 -163
- data/lib/torch/native/parser.rb +0 -140
data/lib/torch/nn/init.rb
CHANGED
@@ -14,25 +14,11 @@ module Torch
|
|
14
14
|
_normal!(tensor, mean, std)
|
15
15
|
end
|
16
16
|
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
_ones!(tensor)
|
23
|
-
end
|
24
|
-
|
25
|
-
def zeros!(tensor)
|
26
|
-
_zeros!(tensor)
|
27
|
-
end
|
28
|
-
|
29
|
-
def eye!(tensor)
|
30
|
-
_eye!(tensor)
|
31
|
-
end
|
32
|
-
|
33
|
-
def dirac!(tensor)
|
34
|
-
_dirac!(tensor)
|
35
|
-
end
|
17
|
+
alias_method :constant!, :_constant!
|
18
|
+
alias_method :ones!, :_ones!
|
19
|
+
alias_method :zeros!, :_zeros!
|
20
|
+
alias_method :eye!, :_eye!
|
21
|
+
alias_method :dirac!, :_dirac!
|
36
22
|
|
37
23
|
def xavier_uniform!(tensor, gain: 1.0)
|
38
24
|
_xavier_uniform!(tensor, gain)
|
data/lib/torch/nn/leaky_relu.rb
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
module Torch
|
2
2
|
module NN
|
3
3
|
class LeakyReLU < Module
|
4
|
-
def initialize(negative_slope: 1e-2
|
4
|
+
def initialize(negative_slope: 1e-2, inplace: false)
|
5
5
|
super()
|
6
6
|
@negative_slope = negative_slope
|
7
|
-
|
7
|
+
@inplace = inplace
|
8
8
|
end
|
9
9
|
|
10
10
|
def forward(input)
|
11
|
-
F.leaky_relu(input, @negative_slope
|
11
|
+
F.leaky_relu(input, @negative_slope, inplace: @inplace)
|
12
12
|
end
|
13
13
|
|
14
14
|
def extra_inspect
|
data/lib/torch/nn/module.rb
CHANGED
@@ -55,7 +55,15 @@ module Torch
|
|
55
55
|
end
|
56
56
|
end
|
57
57
|
end
|
58
|
-
|
58
|
+
|
59
|
+
@buffers.each_key do |k|
|
60
|
+
buf = @buffers[k]
|
61
|
+
unless buf.nil?
|
62
|
+
@buffers[k] = fn.call(buf)
|
63
|
+
instance_variable_set("@#{k}", @buffers[k])
|
64
|
+
end
|
65
|
+
end
|
66
|
+
|
59
67
|
self
|
60
68
|
end
|
61
69
|
|
@@ -0,0 +1,31 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
class Upsample < Module
|
4
|
+
def initialize(size: nil, scale_factor: nil, mode: "nearest", align_corners: nil)
|
5
|
+
super()
|
6
|
+
@size = size
|
7
|
+
if scale_factor.is_a?(Array)
|
8
|
+
@scale_factor = scale_factor.map(&:to_f)
|
9
|
+
else
|
10
|
+
@scale_factor = scale_factor ? scale_factor.to_f : nil
|
11
|
+
end
|
12
|
+
@mode = mode
|
13
|
+
@align_corners = align_corners
|
14
|
+
end
|
15
|
+
|
16
|
+
def forward(input)
|
17
|
+
F.interpolate(input, size: @size, scale_factor: @scale_factor, mode: @mode, align_corners: @align_corners)
|
18
|
+
end
|
19
|
+
|
20
|
+
def extra_inspect
|
21
|
+
if !@scale_factor.nil?
|
22
|
+
info = "scale_factor: #{@scale_factor.inspect}"
|
23
|
+
else
|
24
|
+
info = "size: #{@size.inspect}"
|
25
|
+
end
|
26
|
+
info += ", mode: #{@mode.inspect}"
|
27
|
+
info
|
28
|
+
end
|
29
|
+
end
|
30
|
+
end
|
31
|
+
end
|
data/lib/torch/optim/adadelta.rb
CHANGED
@@ -45,7 +45,7 @@ module Torch
|
|
45
45
|
square_avg.mul!(rho).addcmul!(1 - rho, grad, grad)
|
46
46
|
std = square_avg.add(eps).sqrt!
|
47
47
|
delta = acc_delta.add(eps).sqrt!.div!(std).mul!(grad)
|
48
|
-
p.data.add!(-group[:lr]
|
48
|
+
p.data.add!(delta, alpha: -group[:lr])
|
49
49
|
acc_delta.mul!(rho).addcmul!(1 - rho, delta, delta)
|
50
50
|
end
|
51
51
|
end
|
data/lib/torch/optim/adam.rb
CHANGED
@@ -53,11 +53,11 @@ module Torch
|
|
53
53
|
bias_correction2 = 1 - beta2 ** state[:step]
|
54
54
|
|
55
55
|
if group[:weight_decay] != 0
|
56
|
-
grad.add!(group[:weight_decay]
|
56
|
+
grad.add!(p.data, alpha: group[:weight_decay])
|
57
57
|
end
|
58
58
|
|
59
59
|
# Decay the first and second moment running average coefficient
|
60
|
-
exp_avg.mul!(beta1).add!(1 - beta1
|
60
|
+
exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
|
61
61
|
exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
|
62
62
|
if amsgrad
|
63
63
|
# Maintains the maximum of all 2nd moment running avg. till now
|
data/lib/torch/optim/adamax.rb
CHANGED
@@ -46,7 +46,7 @@ module Torch
|
|
46
46
|
end
|
47
47
|
|
48
48
|
# Update biased first moment estimate.
|
49
|
-
exp_avg.mul!(beta1).add!(1 - beta1
|
49
|
+
exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
|
50
50
|
# Update the exponentially weighted infinity norm.
|
51
51
|
norm_buf = Torch.cat([
|
52
52
|
exp_inf.mul!(beta2).unsqueeze(0),
|
data/lib/torch/optim/adamw.rb
CHANGED
@@ -58,7 +58,7 @@ module Torch
|
|
58
58
|
bias_correction2 = 1 - beta2 ** state[:step]
|
59
59
|
|
60
60
|
# Decay the first and second moment running average coefficient
|
61
|
-
exp_avg.mul!(beta1).add!(1 - beta1
|
61
|
+
exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
|
62
62
|
exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
|
63
63
|
if amsgrad
|
64
64
|
# Maintains the maximum of all 2nd moment running avg. till now
|
data/lib/torch/optim/asgd.rb
CHANGED
data/lib/torch/optim/sgd.rb
CHANGED
@@ -32,7 +32,7 @@ module Torch
|
|
32
32
|
next unless p.grad
|
33
33
|
d_p = p.grad.data
|
34
34
|
if weight_decay != 0
|
35
|
-
d_p.add!(
|
35
|
+
d_p.add!(p.data, alpha: weight_decay)
|
36
36
|
end
|
37
37
|
if momentum != 0
|
38
38
|
param_state = @state[p]
|
@@ -40,7 +40,7 @@ module Torch
|
|
40
40
|
buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
|
41
41
|
else
|
42
42
|
buf = param_state[:momentum_buffer]
|
43
|
-
buf.mul!(momentum).add!(1 - dampening
|
43
|
+
buf.mul!(momentum).add!(d_p, alpha: 1 - dampening)
|
44
44
|
end
|
45
45
|
if nesterov
|
46
46
|
d_p = d_p.add(momentum, buf)
|
@@ -49,7 +49,7 @@ module Torch
|
|
49
49
|
end
|
50
50
|
end
|
51
51
|
|
52
|
-
p.data.add!(-group[:lr]
|
52
|
+
p.data.add!(d_p, alpha: -group[:lr])
|
53
53
|
end
|
54
54
|
end
|
55
55
|
|
data/lib/torch/tensor.rb
CHANGED
@@ -8,6 +8,18 @@ module Torch
|
|
8
8
|
alias_method :ndim, :dim
|
9
9
|
alias_method :ndimension, :dim
|
10
10
|
|
11
|
+
# use alias_method for performance
|
12
|
+
alias_method :+, :add
|
13
|
+
alias_method :-, :sub
|
14
|
+
alias_method :*, :mul
|
15
|
+
alias_method :/, :div
|
16
|
+
alias_method :%, :remainder
|
17
|
+
alias_method :**, :pow
|
18
|
+
alias_method :-@, :neg
|
19
|
+
alias_method :&, :logical_and
|
20
|
+
alias_method :|, :logical_or
|
21
|
+
alias_method :^, :logical_xor
|
22
|
+
|
11
23
|
def self.new(*args)
|
12
24
|
FloatTensor.new(*args)
|
13
25
|
end
|
@@ -48,6 +60,11 @@ module Torch
|
|
48
60
|
end
|
49
61
|
|
50
62
|
def to(device = nil, dtype: nil, non_blocking: false, copy: false)
|
63
|
+
if device.is_a?(Symbol) && !dtype
|
64
|
+
dtype = device
|
65
|
+
device = nil
|
66
|
+
end
|
67
|
+
|
51
68
|
device ||= self.device
|
52
69
|
device = Device.new(device) if device.is_a?(String)
|
53
70
|
|
@@ -68,14 +85,18 @@ module Torch
|
|
68
85
|
|
69
86
|
def size(dim = nil)
|
70
87
|
if dim
|
71
|
-
|
88
|
+
_size(dim)
|
72
89
|
else
|
73
90
|
shape
|
74
91
|
end
|
75
92
|
end
|
76
93
|
|
77
|
-
def
|
78
|
-
dim
|
94
|
+
def stride(dim = nil)
|
95
|
+
if dim
|
96
|
+
_stride(dim)
|
97
|
+
else
|
98
|
+
_strides
|
99
|
+
end
|
79
100
|
end
|
80
101
|
|
81
102
|
# mirror Python len()
|
@@ -103,11 +124,6 @@ module Torch
|
|
103
124
|
Torch.empty(0, dtype: dtype)
|
104
125
|
end
|
105
126
|
|
106
|
-
def backward(gradient = nil, retain_graph: nil, create_graph: false)
|
107
|
-
retain_graph = create_graph if retain_graph.nil?
|
108
|
-
_backward(gradient, retain_graph, create_graph)
|
109
|
-
end
|
110
|
-
|
111
127
|
# TODO read directly from memory
|
112
128
|
def numo
|
113
129
|
cls = Torch._dtype_to_numo[dtype]
|
@@ -124,60 +140,14 @@ module Torch
|
|
124
140
|
end
|
125
141
|
|
126
142
|
def type(dtype)
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
_reshape(size)
|
136
|
-
end
|
137
|
-
|
138
|
-
def view(*size)
|
139
|
-
size = size.first if size.size == 1 && size.first.is_a?(Array)
|
140
|
-
_view(size)
|
141
|
-
end
|
142
|
-
|
143
|
-
def +(other)
|
144
|
-
add(other)
|
145
|
-
end
|
146
|
-
|
147
|
-
def -(other)
|
148
|
-
sub(other)
|
149
|
-
end
|
150
|
-
|
151
|
-
def *(other)
|
152
|
-
mul(other)
|
153
|
-
end
|
154
|
-
|
155
|
-
def /(other)
|
156
|
-
div(other)
|
157
|
-
end
|
158
|
-
|
159
|
-
def %(other)
|
160
|
-
remainder(other)
|
161
|
-
end
|
162
|
-
|
163
|
-
def **(other)
|
164
|
-
pow(other)
|
165
|
-
end
|
166
|
-
|
167
|
-
def -@
|
168
|
-
neg
|
169
|
-
end
|
170
|
-
|
171
|
-
def &(other)
|
172
|
-
logical_and(other)
|
173
|
-
end
|
174
|
-
|
175
|
-
def |(other)
|
176
|
-
logical_or(other)
|
177
|
-
end
|
178
|
-
|
179
|
-
def ^(other)
|
180
|
-
logical_xor(other)
|
143
|
+
if dtype.is_a?(Class)
|
144
|
+
raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)
|
145
|
+
dtype.new(self)
|
146
|
+
else
|
147
|
+
enum = DTYPE_TO_ENUM[dtype]
|
148
|
+
raise Error, "Invalid type: #{dtype}" unless enum
|
149
|
+
_type(enum)
|
150
|
+
end
|
181
151
|
end
|
182
152
|
|
183
153
|
# TODO better compare?
|
@@ -188,7 +158,7 @@ module Torch
|
|
188
158
|
# based on python_variable_indexing.cpp and
|
189
159
|
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
190
160
|
def [](*indexes)
|
191
|
-
_index(
|
161
|
+
_index(indexes)
|
192
162
|
end
|
193
163
|
|
194
164
|
# based on python_variable_indexing.cpp and
|
@@ -196,62 +166,13 @@ module Torch
|
|
196
166
|
def []=(*indexes, value)
|
197
167
|
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
|
198
168
|
value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
|
199
|
-
_index_put_custom(
|
200
|
-
end
|
201
|
-
|
202
|
-
# native functions that need manually defined
|
203
|
-
|
204
|
-
# value and other are swapped for some methods
|
205
|
-
def add!(value = 1, other)
|
206
|
-
if other.is_a?(Numeric)
|
207
|
-
_add__scalar(other, value)
|
208
|
-
else
|
209
|
-
_add__tensor(other, value)
|
210
|
-
end
|
169
|
+
_index_put_custom(indexes, value)
|
211
170
|
end
|
212
171
|
|
213
172
|
# parser can't handle overlap, so need to handle manually
|
214
173
|
def random!(*args)
|
215
|
-
|
216
|
-
|
217
|
-
_random__to(*args)
|
218
|
-
when 2
|
219
|
-
_random__from(*args)
|
220
|
-
else
|
221
|
-
_random_(*args)
|
222
|
-
end
|
223
|
-
end
|
224
|
-
|
225
|
-
def clamp!(min, max)
|
226
|
-
_clamp_min_(min)
|
227
|
-
_clamp_max_(max)
|
228
|
-
end
|
229
|
-
|
230
|
-
private
|
231
|
-
|
232
|
-
def tensor_indexes(indexes)
|
233
|
-
indexes.map do |index|
|
234
|
-
case index
|
235
|
-
when Integer
|
236
|
-
TensorIndex.integer(index)
|
237
|
-
when Range
|
238
|
-
finish = index.end
|
239
|
-
if finish == -1 && !index.exclude_end?
|
240
|
-
finish = nil
|
241
|
-
else
|
242
|
-
finish += 1 unless index.exclude_end?
|
243
|
-
end
|
244
|
-
TensorIndex.slice(index.begin, finish)
|
245
|
-
when Tensor
|
246
|
-
TensorIndex.tensor(index)
|
247
|
-
when nil
|
248
|
-
TensorIndex.none
|
249
|
-
when true, false
|
250
|
-
TensorIndex.boolean(index)
|
251
|
-
else
|
252
|
-
raise Error, "Unsupported index type: #{index.class.name}"
|
253
|
-
end
|
254
|
-
end
|
174
|
+
return _random!(0, *args) if args.size == 1
|
175
|
+
_random!(*args)
|
255
176
|
end
|
256
177
|
end
|
257
178
|
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.4.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
|
-
autorequire:
|
8
|
+
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-
|
11
|
+
date: 2020-10-13 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -108,7 +108,7 @@ dependencies:
|
|
108
108
|
- - ">="
|
109
109
|
- !ruby/object:Gem::Version
|
110
110
|
version: 0.1.1
|
111
|
-
description:
|
111
|
+
description:
|
112
112
|
email: andrew@chartkick.com
|
113
113
|
executables: []
|
114
114
|
extensions:
|
@@ -118,19 +118,23 @@ files:
|
|
118
118
|
- CHANGELOG.md
|
119
119
|
- LICENSE.txt
|
120
120
|
- README.md
|
121
|
+
- codegen/function.rb
|
122
|
+
- codegen/generate_functions.rb
|
123
|
+
- codegen/native_functions.yaml
|
121
124
|
- ext/torch/ext.cpp
|
122
125
|
- ext/torch/extconf.rb
|
123
|
-
- ext/torch/
|
124
|
-
- ext/torch/
|
126
|
+
- ext/torch/nn_functions.h
|
127
|
+
- ext/torch/ruby_arg_parser.cpp
|
128
|
+
- ext/torch/ruby_arg_parser.h
|
129
|
+
- ext/torch/templates.h
|
130
|
+
- ext/torch/tensor_functions.h
|
131
|
+
- ext/torch/torch_functions.h
|
132
|
+
- ext/torch/utils.h
|
133
|
+
- ext/torch/wrap_outputs.h
|
125
134
|
- lib/torch-rb.rb
|
126
135
|
- lib/torch.rb
|
127
136
|
- lib/torch/hub.rb
|
128
137
|
- lib/torch/inspector.rb
|
129
|
-
- lib/torch/native/dispatcher.rb
|
130
|
-
- lib/torch/native/function.rb
|
131
|
-
- lib/torch/native/generator.rb
|
132
|
-
- lib/torch/native/native_functions.yaml
|
133
|
-
- lib/torch/native/parser.rb
|
134
138
|
- lib/torch/nn/adaptive_avg_pool1d.rb
|
135
139
|
- lib/torch/nn/adaptive_avg_pool2d.rb
|
136
140
|
- lib/torch/nn/adaptive_avg_pool3d.rb
|
@@ -238,6 +242,7 @@ files:
|
|
238
242
|
- lib/torch/nn/tanhshrink.rb
|
239
243
|
- lib/torch/nn/triplet_margin_loss.rb
|
240
244
|
- lib/torch/nn/unfold.rb
|
245
|
+
- lib/torch/nn/upsample.rb
|
241
246
|
- lib/torch/nn/utils.rb
|
242
247
|
- lib/torch/nn/weighted_loss.rb
|
243
248
|
- lib/torch/nn/zero_pad2d.rb
|
@@ -269,7 +274,7 @@ homepage: https://github.com/ankane/torch.rb
|
|
269
274
|
licenses:
|
270
275
|
- BSD-3-Clause
|
271
276
|
metadata: {}
|
272
|
-
post_install_message:
|
277
|
+
post_install_message:
|
273
278
|
rdoc_options: []
|
274
279
|
require_paths:
|
275
280
|
- lib
|
@@ -284,8 +289,8 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
284
289
|
- !ruby/object:Gem::Version
|
285
290
|
version: '0'
|
286
291
|
requirements: []
|
287
|
-
rubygems_version: 3.
|
288
|
-
signing_key:
|
292
|
+
rubygems_version: 3.0.3
|
293
|
+
signing_key:
|
289
294
|
specification_version: 4
|
290
295
|
summary: Deep learning for Ruby, powered by LibTorch
|
291
296
|
test_files: []
|