torch-rb 0.3.3 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +2 -1
- data/codegen/function.rb +134 -0
- data/codegen/generate_functions.rb +546 -0
- data/{lib/torch/native → codegen}/native_functions.yaml +0 -0
- data/ext/torch/ext.cpp +76 -87
- data/ext/torch/extconf.rb +5 -2
- data/ext/torch/nn_functions.h +6 -0
- data/ext/torch/ruby_arg_parser.cpp +593 -0
- data/ext/torch/ruby_arg_parser.h +373 -0
- data/ext/torch/{templates.hpp → templates.h} +87 -97
- data/ext/torch/tensor_functions.h +6 -0
- data/ext/torch/torch_functions.h +6 -0
- data/ext/torch/utils.h +42 -0
- data/ext/torch/{templates.cpp → wrap_outputs.h} +44 -7
- data/lib/torch.rb +56 -77
- data/lib/torch/nn/functional.rb +142 -18
- data/lib/torch/nn/init.rb +5 -19
- data/lib/torch/nn/leaky_relu.rb +3 -3
- data/lib/torch/nn/module.rb +9 -1
- data/lib/torch/nn/upsample.rb +31 -0
- data/lib/torch/optim/adadelta.rb +1 -1
- data/lib/torch/optim/adam.rb +2 -2
- data/lib/torch/optim/adamax.rb +1 -1
- data/lib/torch/optim/adamw.rb +1 -1
- data/lib/torch/optim/asgd.rb +1 -1
- data/lib/torch/optim/sgd.rb +3 -3
- data/lib/torch/tensor.rb +36 -115
- data/lib/torch/utils/data/data_loader.rb +2 -0
- data/lib/torch/utils/data/tensor_dataset.rb +2 -0
- data/lib/torch/version.rb +1 -1
- metadata +28 -9
- data/lib/torch/native/dispatcher.rb +0 -48
- data/lib/torch/native/function.rb +0 -115
- data/lib/torch/native/generator.rb +0 -163
- data/lib/torch/native/parser.rb +0 -140
data/lib/torch/nn/init.rb
CHANGED
@@ -14,25 +14,11 @@ module Torch
|
|
14
14
|
_normal!(tensor, mean, std)
|
15
15
|
end
|
16
16
|
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
_ones!(tensor)
|
23
|
-
end
|
24
|
-
|
25
|
-
def zeros!(tensor)
|
26
|
-
_zeros!(tensor)
|
27
|
-
end
|
28
|
-
|
29
|
-
def eye!(tensor)
|
30
|
-
_eye!(tensor)
|
31
|
-
end
|
32
|
-
|
33
|
-
def dirac!(tensor)
|
34
|
-
_dirac!(tensor)
|
35
|
-
end
|
17
|
+
alias_method :constant!, :_constant!
|
18
|
+
alias_method :ones!, :_ones!
|
19
|
+
alias_method :zeros!, :_zeros!
|
20
|
+
alias_method :eye!, :_eye!
|
21
|
+
alias_method :dirac!, :_dirac!
|
36
22
|
|
37
23
|
def xavier_uniform!(tensor, gain: 1.0)
|
38
24
|
_xavier_uniform!(tensor, gain)
|
data/lib/torch/nn/leaky_relu.rb
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
module Torch
|
2
2
|
module NN
|
3
3
|
class LeakyReLU < Module
|
4
|
-
def initialize(negative_slope: 1e-2
|
4
|
+
def initialize(negative_slope: 1e-2, inplace: false)
|
5
5
|
super()
|
6
6
|
@negative_slope = negative_slope
|
7
|
-
|
7
|
+
@inplace = inplace
|
8
8
|
end
|
9
9
|
|
10
10
|
def forward(input)
|
11
|
-
F.leaky_relu(input, @negative_slope
|
11
|
+
F.leaky_relu(input, @negative_slope, inplace: @inplace)
|
12
12
|
end
|
13
13
|
|
14
14
|
def extra_inspect
|
data/lib/torch/nn/module.rb
CHANGED
@@ -55,7 +55,15 @@ module Torch
|
|
55
55
|
end
|
56
56
|
end
|
57
57
|
end
|
58
|
-
|
58
|
+
|
59
|
+
@buffers.each_key do |k|
|
60
|
+
buf = @buffers[k]
|
61
|
+
unless buf.nil?
|
62
|
+
@buffers[k] = fn.call(buf)
|
63
|
+
instance_variable_set("@#{k}", @buffers[k])
|
64
|
+
end
|
65
|
+
end
|
66
|
+
|
59
67
|
self
|
60
68
|
end
|
61
69
|
|
@@ -0,0 +1,31 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
class Upsample < Module
|
4
|
+
def initialize(size: nil, scale_factor: nil, mode: "nearest", align_corners: nil)
|
5
|
+
super()
|
6
|
+
@size = size
|
7
|
+
if scale_factor.is_a?(Array)
|
8
|
+
@scale_factor = scale_factor.map(&:to_f)
|
9
|
+
else
|
10
|
+
@scale_factor = scale_factor ? scale_factor.to_f : nil
|
11
|
+
end
|
12
|
+
@mode = mode
|
13
|
+
@align_corners = align_corners
|
14
|
+
end
|
15
|
+
|
16
|
+
def forward(input)
|
17
|
+
F.interpolate(input, size: @size, scale_factor: @scale_factor, mode: @mode, align_corners: @align_corners)
|
18
|
+
end
|
19
|
+
|
20
|
+
def extra_inspect
|
21
|
+
if !@scale_factor.nil?
|
22
|
+
info = "scale_factor: #{@scale_factor.inspect}"
|
23
|
+
else
|
24
|
+
info = "size: #{@size.inspect}"
|
25
|
+
end
|
26
|
+
info += ", mode: #{@mode.inspect}"
|
27
|
+
info
|
28
|
+
end
|
29
|
+
end
|
30
|
+
end
|
31
|
+
end
|
data/lib/torch/optim/adadelta.rb
CHANGED
@@ -45,7 +45,7 @@ module Torch
|
|
45
45
|
square_avg.mul!(rho).addcmul!(1 - rho, grad, grad)
|
46
46
|
std = square_avg.add(eps).sqrt!
|
47
47
|
delta = acc_delta.add(eps).sqrt!.div!(std).mul!(grad)
|
48
|
-
p.data.add!(-group[:lr]
|
48
|
+
p.data.add!(delta, alpha: -group[:lr])
|
49
49
|
acc_delta.mul!(rho).addcmul!(1 - rho, delta, delta)
|
50
50
|
end
|
51
51
|
end
|
data/lib/torch/optim/adam.rb
CHANGED
@@ -53,11 +53,11 @@ module Torch
|
|
53
53
|
bias_correction2 = 1 - beta2 ** state[:step]
|
54
54
|
|
55
55
|
if group[:weight_decay] != 0
|
56
|
-
grad.add!(group[:weight_decay]
|
56
|
+
grad.add!(p.data, alpha: group[:weight_decay])
|
57
57
|
end
|
58
58
|
|
59
59
|
# Decay the first and second moment running average coefficient
|
60
|
-
exp_avg.mul!(beta1).add!(1 - beta1
|
60
|
+
exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
|
61
61
|
exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
|
62
62
|
if amsgrad
|
63
63
|
# Maintains the maximum of all 2nd moment running avg. till now
|
data/lib/torch/optim/adamax.rb
CHANGED
@@ -46,7 +46,7 @@ module Torch
|
|
46
46
|
end
|
47
47
|
|
48
48
|
# Update biased first moment estimate.
|
49
|
-
exp_avg.mul!(beta1).add!(1 - beta1
|
49
|
+
exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
|
50
50
|
# Update the exponentially weighted infinity norm.
|
51
51
|
norm_buf = Torch.cat([
|
52
52
|
exp_inf.mul!(beta2).unsqueeze(0),
|
data/lib/torch/optim/adamw.rb
CHANGED
@@ -58,7 +58,7 @@ module Torch
|
|
58
58
|
bias_correction2 = 1 - beta2 ** state[:step]
|
59
59
|
|
60
60
|
# Decay the first and second moment running average coefficient
|
61
|
-
exp_avg.mul!(beta1).add!(1 - beta1
|
61
|
+
exp_avg.mul!(beta1).add!(grad, alpha: 1 - beta1)
|
62
62
|
exp_avg_sq.mul!(beta2).addcmul!(1 - beta2, grad, grad)
|
63
63
|
if amsgrad
|
64
64
|
# Maintains the maximum of all 2nd moment running avg. till now
|
data/lib/torch/optim/asgd.rb
CHANGED
data/lib/torch/optim/sgd.rb
CHANGED
@@ -32,7 +32,7 @@ module Torch
|
|
32
32
|
next unless p.grad
|
33
33
|
d_p = p.grad.data
|
34
34
|
if weight_decay != 0
|
35
|
-
d_p.add!(
|
35
|
+
d_p.add!(p.data, alpha: weight_decay)
|
36
36
|
end
|
37
37
|
if momentum != 0
|
38
38
|
param_state = @state[p]
|
@@ -40,7 +40,7 @@ module Torch
|
|
40
40
|
buf = param_state[:momentum_buffer] = Torch.clone(d_p).detach
|
41
41
|
else
|
42
42
|
buf = param_state[:momentum_buffer]
|
43
|
-
buf.mul!(momentum).add!(1 - dampening
|
43
|
+
buf.mul!(momentum).add!(d_p, alpha: 1 - dampening)
|
44
44
|
end
|
45
45
|
if nesterov
|
46
46
|
d_p = d_p.add(momentum, buf)
|
@@ -49,7 +49,7 @@ module Torch
|
|
49
49
|
end
|
50
50
|
end
|
51
51
|
|
52
|
-
p.data.add!(-group[:lr]
|
52
|
+
p.data.add!(d_p, alpha: -group[:lr])
|
53
53
|
end
|
54
54
|
end
|
55
55
|
|
data/lib/torch/tensor.rb
CHANGED
@@ -8,6 +8,18 @@ module Torch
|
|
8
8
|
alias_method :ndim, :dim
|
9
9
|
alias_method :ndimension, :dim
|
10
10
|
|
11
|
+
# use alias_method for performance
|
12
|
+
alias_method :+, :add
|
13
|
+
alias_method :-, :sub
|
14
|
+
alias_method :*, :mul
|
15
|
+
alias_method :/, :div
|
16
|
+
alias_method :%, :remainder
|
17
|
+
alias_method :**, :pow
|
18
|
+
alias_method :-@, :neg
|
19
|
+
alias_method :&, :logical_and
|
20
|
+
alias_method :|, :logical_or
|
21
|
+
alias_method :^, :logical_xor
|
22
|
+
|
11
23
|
def self.new(*args)
|
12
24
|
FloatTensor.new(*args)
|
13
25
|
end
|
@@ -48,6 +60,11 @@ module Torch
|
|
48
60
|
end
|
49
61
|
|
50
62
|
def to(device = nil, dtype: nil, non_blocking: false, copy: false)
|
63
|
+
if device.is_a?(Symbol) && !dtype
|
64
|
+
dtype = device
|
65
|
+
device = nil
|
66
|
+
end
|
67
|
+
|
51
68
|
device ||= self.device
|
52
69
|
device = Device.new(device) if device.is_a?(String)
|
53
70
|
|
@@ -68,14 +85,18 @@ module Torch
|
|
68
85
|
|
69
86
|
def size(dim = nil)
|
70
87
|
if dim
|
71
|
-
|
88
|
+
_size(dim)
|
72
89
|
else
|
73
90
|
shape
|
74
91
|
end
|
75
92
|
end
|
76
93
|
|
77
|
-
def
|
78
|
-
dim
|
94
|
+
def stride(dim = nil)
|
95
|
+
if dim
|
96
|
+
_stride(dim)
|
97
|
+
else
|
98
|
+
_strides
|
99
|
+
end
|
79
100
|
end
|
80
101
|
|
81
102
|
# mirror Python len()
|
@@ -103,11 +124,6 @@ module Torch
|
|
103
124
|
Torch.empty(0, dtype: dtype)
|
104
125
|
end
|
105
126
|
|
106
|
-
def backward(gradient = nil, retain_graph: nil, create_graph: false)
|
107
|
-
retain_graph = create_graph if retain_graph.nil?
|
108
|
-
_backward(gradient, retain_graph, create_graph)
|
109
|
-
end
|
110
|
-
|
111
127
|
# TODO read directly from memory
|
112
128
|
def numo
|
113
129
|
cls = Torch._dtype_to_numo[dtype]
|
@@ -124,60 +140,14 @@ module Torch
|
|
124
140
|
end
|
125
141
|
|
126
142
|
def type(dtype)
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
_reshape(size)
|
136
|
-
end
|
137
|
-
|
138
|
-
def view(*size)
|
139
|
-
size = size.first if size.size == 1 && size.first.is_a?(Array)
|
140
|
-
_view(size)
|
141
|
-
end
|
142
|
-
|
143
|
-
def +(other)
|
144
|
-
add(other)
|
145
|
-
end
|
146
|
-
|
147
|
-
def -(other)
|
148
|
-
sub(other)
|
149
|
-
end
|
150
|
-
|
151
|
-
def *(other)
|
152
|
-
mul(other)
|
153
|
-
end
|
154
|
-
|
155
|
-
def /(other)
|
156
|
-
div(other)
|
157
|
-
end
|
158
|
-
|
159
|
-
def %(other)
|
160
|
-
remainder(other)
|
161
|
-
end
|
162
|
-
|
163
|
-
def **(other)
|
164
|
-
pow(other)
|
165
|
-
end
|
166
|
-
|
167
|
-
def -@
|
168
|
-
neg
|
169
|
-
end
|
170
|
-
|
171
|
-
def &(other)
|
172
|
-
logical_and(other)
|
173
|
-
end
|
174
|
-
|
175
|
-
def |(other)
|
176
|
-
logical_or(other)
|
177
|
-
end
|
178
|
-
|
179
|
-
def ^(other)
|
180
|
-
logical_xor(other)
|
143
|
+
if dtype.is_a?(Class)
|
144
|
+
raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)
|
145
|
+
dtype.new(self)
|
146
|
+
else
|
147
|
+
enum = DTYPE_TO_ENUM[dtype]
|
148
|
+
raise Error, "Invalid type: #{dtype}" unless enum
|
149
|
+
_type(enum)
|
150
|
+
end
|
181
151
|
end
|
182
152
|
|
183
153
|
# TODO better compare?
|
@@ -188,7 +158,7 @@ module Torch
|
|
188
158
|
# based on python_variable_indexing.cpp and
|
189
159
|
# https://pytorch.org/cppdocs/notes/tensor_indexing.html
|
190
160
|
def [](*indexes)
|
191
|
-
_index(
|
161
|
+
_index(indexes)
|
192
162
|
end
|
193
163
|
|
194
164
|
# based on python_variable_indexing.cpp and
|
@@ -196,62 +166,13 @@ module Torch
|
|
196
166
|
def []=(*indexes, value)
|
197
167
|
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
|
198
168
|
value = Torch.tensor(value, dtype: dtype) unless value.is_a?(Tensor)
|
199
|
-
_index_put_custom(
|
200
|
-
end
|
201
|
-
|
202
|
-
# native functions that need manually defined
|
203
|
-
|
204
|
-
# value and other are swapped for some methods
|
205
|
-
def add!(value = 1, other)
|
206
|
-
if other.is_a?(Numeric)
|
207
|
-
_add__scalar(other, value)
|
208
|
-
else
|
209
|
-
_add__tensor(other, value)
|
210
|
-
end
|
169
|
+
_index_put_custom(indexes, value)
|
211
170
|
end
|
212
171
|
|
213
172
|
# parser can't handle overlap, so need to handle manually
|
214
173
|
def random!(*args)
|
215
|
-
|
216
|
-
|
217
|
-
_random__to(*args)
|
218
|
-
when 2
|
219
|
-
_random__from(*args)
|
220
|
-
else
|
221
|
-
_random_(*args)
|
222
|
-
end
|
223
|
-
end
|
224
|
-
|
225
|
-
def clamp!(min, max)
|
226
|
-
_clamp_min_(min)
|
227
|
-
_clamp_max_(max)
|
228
|
-
end
|
229
|
-
|
230
|
-
private
|
231
|
-
|
232
|
-
def tensor_indexes(indexes)
|
233
|
-
indexes.map do |index|
|
234
|
-
case index
|
235
|
-
when Integer
|
236
|
-
TensorIndex.integer(index)
|
237
|
-
when Range
|
238
|
-
finish = index.end
|
239
|
-
if finish == -1 && !index.exclude_end?
|
240
|
-
finish = nil
|
241
|
-
else
|
242
|
-
finish += 1 unless index.exclude_end?
|
243
|
-
end
|
244
|
-
TensorIndex.slice(index.begin, finish)
|
245
|
-
when Tensor
|
246
|
-
TensorIndex.tensor(index)
|
247
|
-
when nil
|
248
|
-
TensorIndex.none
|
249
|
-
when true, false
|
250
|
-
TensorIndex.boolean(index)
|
251
|
-
else
|
252
|
-
raise Error, "Unsupported index type: #{index.class.name}"
|
253
|
-
end
|
254
|
-
end
|
174
|
+
return _random!(0, *args) if args.size == 1
|
175
|
+
_random!(*args)
|
255
176
|
end
|
256
177
|
end
|
257
178
|
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.4.0
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-
|
11
|
+
date: 2020-09-27 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -108,6 +108,20 @@ dependencies:
|
|
108
108
|
- - ">="
|
109
109
|
- !ruby/object:Gem::Version
|
110
110
|
version: 0.1.1
|
111
|
+
- !ruby/object:Gem::Dependency
|
112
|
+
name: magro
|
113
|
+
requirement: !ruby/object:Gem::Requirement
|
114
|
+
requirements:
|
115
|
+
- - ">="
|
116
|
+
- !ruby/object:Gem::Version
|
117
|
+
version: '0'
|
118
|
+
type: :development
|
119
|
+
prerelease: false
|
120
|
+
version_requirements: !ruby/object:Gem::Requirement
|
121
|
+
requirements:
|
122
|
+
- - ">="
|
123
|
+
- !ruby/object:Gem::Version
|
124
|
+
version: '0'
|
111
125
|
description:
|
112
126
|
email: andrew@chartkick.com
|
113
127
|
executables: []
|
@@ -118,19 +132,23 @@ files:
|
|
118
132
|
- CHANGELOG.md
|
119
133
|
- LICENSE.txt
|
120
134
|
- README.md
|
135
|
+
- codegen/function.rb
|
136
|
+
- codegen/generate_functions.rb
|
137
|
+
- codegen/native_functions.yaml
|
121
138
|
- ext/torch/ext.cpp
|
122
139
|
- ext/torch/extconf.rb
|
123
|
-
- ext/torch/
|
124
|
-
- ext/torch/
|
140
|
+
- ext/torch/nn_functions.h
|
141
|
+
- ext/torch/ruby_arg_parser.cpp
|
142
|
+
- ext/torch/ruby_arg_parser.h
|
143
|
+
- ext/torch/templates.h
|
144
|
+
- ext/torch/tensor_functions.h
|
145
|
+
- ext/torch/torch_functions.h
|
146
|
+
- ext/torch/utils.h
|
147
|
+
- ext/torch/wrap_outputs.h
|
125
148
|
- lib/torch-rb.rb
|
126
149
|
- lib/torch.rb
|
127
150
|
- lib/torch/hub.rb
|
128
151
|
- lib/torch/inspector.rb
|
129
|
-
- lib/torch/native/dispatcher.rb
|
130
|
-
- lib/torch/native/function.rb
|
131
|
-
- lib/torch/native/generator.rb
|
132
|
-
- lib/torch/native/native_functions.yaml
|
133
|
-
- lib/torch/native/parser.rb
|
134
152
|
- lib/torch/nn/adaptive_avg_pool1d.rb
|
135
153
|
- lib/torch/nn/adaptive_avg_pool2d.rb
|
136
154
|
- lib/torch/nn/adaptive_avg_pool3d.rb
|
@@ -238,6 +256,7 @@ files:
|
|
238
256
|
- lib/torch/nn/tanhshrink.rb
|
239
257
|
- lib/torch/nn/triplet_margin_loss.rb
|
240
258
|
- lib/torch/nn/unfold.rb
|
259
|
+
- lib/torch/nn/upsample.rb
|
241
260
|
- lib/torch/nn/utils.rb
|
242
261
|
- lib/torch/nn/weighted_loss.rb
|
243
262
|
- lib/torch/nn/zero_pad2d.rb
|