torch-rb 0.1.3 → 0.1.8
- checksums.yaml +4 -4
- data/CHANGELOG.md +30 -0
- data/README.md +5 -2
- data/ext/torch/ext.cpp +130 -555
- data/ext/torch/extconf.rb +9 -0
- data/ext/torch/templates.cpp +55 -0
- data/ext/torch/templates.hpp +244 -0
- data/lib/torch.rb +209 -171
- data/lib/torch/inspector.rb +23 -19
- data/lib/torch/native/dispatcher.rb +48 -0
- data/lib/torch/native/function.rb +110 -0
- data/lib/torch/native/generator.rb +168 -0
- data/lib/torch/native/native_functions.yaml +6491 -0
- data/lib/torch/native/parser.rb +134 -0
- data/lib/torch/nn/avg_pool1d.rb +18 -0
- data/lib/torch/nn/avg_pool2d.rb +19 -0
- data/lib/torch/nn/avg_pool3d.rb +19 -0
- data/lib/torch/nn/avg_poolnd.rb +9 -0
- data/lib/torch/nn/batch_norm.rb +75 -0
- data/lib/torch/nn/batch_norm1d.rb +11 -0
- data/lib/torch/nn/batch_norm2d.rb +11 -0
- data/lib/torch/nn/batch_norm3d.rb +11 -0
- data/lib/torch/nn/bce_loss.rb +13 -0
- data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
- data/lib/torch/nn/bilinear.rb +38 -0
- data/lib/torch/nn/constant_pad1d.rb +10 -0
- data/lib/torch/nn/constant_pad2d.rb +10 -0
- data/lib/torch/nn/constant_pad3d.rb +10 -0
- data/lib/torch/nn/constant_padnd.rb +18 -0
- data/lib/torch/nn/conv1d.rb +22 -0
- data/lib/torch/nn/conv2d.rb +10 -20
- data/lib/torch/nn/conv3d.rb +22 -0
- data/lib/torch/nn/convnd.rb +3 -3
- data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
- data/lib/torch/nn/cosine_similarity.rb +15 -0
- data/lib/torch/nn/cross_entropy_loss.rb +14 -0
- data/lib/torch/nn/ctc_loss.rb +15 -0
- data/lib/torch/nn/dropoutnd.rb +2 -2
- data/lib/torch/nn/embedding_bag.rb +34 -0
- data/lib/torch/nn/fold.rb +20 -0
- data/lib/torch/nn/functional.rb +379 -32
- data/lib/torch/nn/group_norm.rb +36 -0
- data/lib/torch/nn/gru.rb +49 -0
- data/lib/torch/nn/hardshrink.rb +18 -0
- data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
- data/lib/torch/nn/identity.rb +14 -0
- data/lib/torch/nn/init.rb +58 -1
- data/lib/torch/nn/instance_norm.rb +20 -0
- data/lib/torch/nn/instance_norm1d.rb +18 -0
- data/lib/torch/nn/instance_norm2d.rb +11 -0
- data/lib/torch/nn/instance_norm3d.rb +11 -0
- data/lib/torch/nn/kl_div_loss.rb +13 -0
- data/lib/torch/nn/l1_loss.rb +13 -0
- data/lib/torch/nn/layer_norm.rb +35 -0
- data/lib/torch/nn/leaky_relu.rb +20 -0
- data/lib/torch/nn/linear.rb +12 -11
- data/lib/torch/nn/local_response_norm.rb +21 -0
- data/lib/torch/nn/log_sigmoid.rb +9 -0
- data/lib/torch/nn/log_softmax.rb +14 -0
- data/lib/torch/nn/loss.rb +10 -0
- data/lib/torch/nn/lp_pool1d.rb +9 -0
- data/lib/torch/nn/lp_pool2d.rb +9 -0
- data/lib/torch/nn/lp_poolnd.rb +22 -0
- data/lib/torch/nn/lstm.rb +66 -0
- data/lib/torch/nn/margin_ranking_loss.rb +14 -0
- data/lib/torch/nn/max_pool1d.rb +9 -0
- data/lib/torch/nn/max_pool2d.rb +9 -0
- data/lib/torch/nn/max_pool3d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +19 -0
- data/lib/torch/nn/max_unpool1d.rb +16 -0
- data/lib/torch/nn/max_unpool2d.rb +16 -0
- data/lib/torch/nn/max_unpool3d.rb +16 -0
- data/lib/torch/nn/max_unpoolnd.rb +9 -0
- data/lib/torch/nn/module.rb +186 -35
- data/lib/torch/nn/mse_loss.rb +2 -2
- data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_margin_loss.rb +17 -0
- data/lib/torch/nn/nll_loss.rb +14 -0
- data/lib/torch/nn/pairwise_distance.rb +16 -0
- data/lib/torch/nn/parameter.rb +2 -2
- data/lib/torch/nn/poisson_nll_loss.rb +16 -0
- data/lib/torch/nn/prelu.rb +19 -0
- data/lib/torch/nn/reflection_pad1d.rb +10 -0
- data/lib/torch/nn/reflection_pad2d.rb +10 -0
- data/lib/torch/nn/reflection_padnd.rb +13 -0
- data/lib/torch/nn/relu.rb +8 -3
- data/lib/torch/nn/replication_pad1d.rb +10 -0
- data/lib/torch/nn/replication_pad2d.rb +10 -0
- data/lib/torch/nn/replication_pad3d.rb +10 -0
- data/lib/torch/nn/replication_padnd.rb +13 -0
- data/lib/torch/nn/rnn.rb +22 -0
- data/lib/torch/nn/rnn_base.rb +198 -0
- data/lib/torch/nn/sequential.rb +1 -10
- data/lib/torch/nn/sigmoid.rb +9 -0
- data/lib/torch/nn/smooth_l1_loss.rb +13 -0
- data/lib/torch/nn/soft_margin_loss.rb +13 -0
- data/lib/torch/nn/softmax.rb +18 -0
- data/lib/torch/nn/softmax2d.rb +10 -0
- data/lib/torch/nn/softmin.rb +14 -0
- data/lib/torch/nn/softplus.rb +19 -0
- data/lib/torch/nn/softshrink.rb +18 -0
- data/lib/torch/nn/softsign.rb +9 -0
- data/lib/torch/nn/tanh.rb +9 -0
- data/lib/torch/nn/tanhshrink.rb +9 -0
- data/lib/torch/nn/triplet_margin_loss.rb +18 -0
- data/lib/torch/nn/unfold.rb +19 -0
- data/lib/torch/nn/utils.rb +25 -0
- data/lib/torch/nn/weighted_loss.rb +10 -0
- data/lib/torch/nn/zero_pad2d.rb +9 -0
- data/lib/torch/random.rb +10 -0
- data/lib/torch/tensor.rb +51 -44
- data/lib/torch/version.rb +1 -1
- metadata +98 -6
- data/lib/torch/ext.bundle +0 -0
data/lib/torch/nn/lp_poolnd.rb
ADDED
@@ -0,0 +1,22 @@
+module Torch
+  module NN
+    class LPPoolNd < Module
+      def initialize(norm_type, kernel_size, stride: nil, ceil_mode: false)
+        super()
+        @norm_type = norm_type
+        @kernel_size = kernel_size
+        @stride = stride
+        @ceil_mode = ceil_mode
+      end
+
+      def extra_inspect
+        format("norm_type: %{norm_type}, kernel_size: %{kernel_size}, stride: %{stride}, ceil_mode: %{ceil_mode}",
+          norm_type: @norm_type,
+          kernel_size: @kernel_size,
+          stride: @stride,
+          ceil_mode: @ceil_mode
+        )
+      end
+    end
+  end
+end
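LPPoolNd only stores the pooling configuration; per the file list above, the concrete LPPool1d and LPPool2d layers subclass it. A minimal usage sketch, assuming LPPool2d takes PyTorch-style positional arguments (norm_type, then kernel_size) and that Torch.randn and Module#call behave as they do elsewhere in this release:

require "torch"

# Hypothetical example: L2 ("power-average") pooling over 3x3 windows.
# With stride omitted, the underlying functional call is assumed to fall
# back to the kernel size, as in PyTorch.
pool = Torch::NN::LPPool2d.new(2, 3)

x = Torch.randn(1, 16, 32, 32)  # batch, channels, height, width
y = pool.call(x)                # Module#call dispatches to #forward
# y is expected to be roughly [1, 16, 10, 10] (floor((32 - 3) / 3) + 1 = 10)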
data/lib/torch/nn/lstm.rb
ADDED
@@ -0,0 +1,66 @@
+module Torch
+  module NN
+    class LSTM < RNNBase
+      def initialize(*args, **options)
+        super("LSTM", *args, **options)
+      end
+
+      def check_forward_args(input, hidden, batch_sizes)
+        check_input(input, batch_sizes)
+        expected_hidden_size = get_expected_hidden_size(input, batch_sizes)
+
+        # TODO pass message
+        check_hidden_size(hidden[0], expected_hidden_size)
+        check_hidden_size(hidden[1], expected_hidden_size)
+      end
+
+      def permute_hidden(hx, permutation)
+        if permutation.nil?
+          return hx
+        end
+        raise NotImplementedYet
+      end
+
+      def forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
+        if hx.nil?
+          num_directions = @bidirectional ? 2 : 1
+          zeros = Torch.zeros(@num_layers * num_directions, max_batch_size, @hidden_size, dtype: input.dtype, device: input.device)
+          hx = [zeros, zeros]
+        else
+          # Each batch of the hidden state should match the input sequence that
+          # the user believes he/she is passing in.
+          hx = permute_hidden(hx, sorted_indices)
+        end
+
+        check_forward_args(input, hx, batch_sizes)
+        if batch_sizes.nil?
+          result = Torch.lstm(input, hx, _get_flat_weights, @bias, @num_layers,
+            @dropout, @training, @bidirectional, @batch_first)
+        else
+          result = Torch.lstm(input, batch_sizes, hx, _get_flat_weights, @bias,
+            @num_layers, @dropout, @training, @bidirectional)
+        end
+        output = result[0]
+        hidden = result[1..-1]
+
+        [output, hidden]
+      end
+
+      def forward_tensor(input, hx: nil)
+        batch_sizes = nil
+        max_batch_size = @batch_first ? input.size(0) : input.size(1)
+        sorted_indices = nil
+        unsorted_indices = nil
+
+        output, hidden = forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
+
+        [output, permute_hidden(hidden, unsorted_indices)]
+      end
+
+      def forward(input, hx: nil)
+        # TODO PackedSequence
+        forward_tensor(input, hx: hx)
+      end
+    end
+  end
+end
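For context, a sketch of driving the new LSTM module, assuming the RNNBase constructor mirrors PyTorch's (input_size, hidden_size, plus keyword options) and that input arrives in [seq_len, batch, features] order since batch_first defaults to false:

require "torch"

# Hypothetical constructor arguments: input_size 10, hidden_size 20.
lstm = Torch::NN::LSTM.new(10, 20, num_layers: 2)

input = Torch.randn(5, 3, 10)  # seq_len 5, batch 3, input_size 10

# With hx nil, forward_impl above allocates zero (h0, c0) states of shape
# [num_layers * num_directions, batch, hidden_size] before calling Torch.lstm.
output, (hn, cn) = lstm.call(input)
# output holds every time step; hn/cn are the final hidden and cell states.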
data/lib/torch/nn/margin_ranking_loss.rb
ADDED
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class MarginRankingLoss < Loss
+      def initialize(margin: 1.0, reduction: "mean")
+        super(reduction)
+        @margin = margin
+      end
+
+      def forward(input1, input2, target)
+        F.margin_ranking_loss(input1, input2, target, margin: @margin, reduction: @reduction)
+      end
+    end
+  end
+end
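A short sketch of the new loss in use. The target values follow the usual margin-ranking convention (1 where input1 should rank above input2, -1 otherwise); the tensor helpers are assumed to behave as in the rest of the gem:

require "torch"

loss_fn = Torch::NN::MarginRankingLoss.new(margin: 0.5)

input1 = Torch.randn(4)
input2 = Torch.randn(4)
target = Torch.tensor([1.0, -1.0, 1.0, 1.0])  # 1.0 means input1 should score higher

# Module#call passes the three tensors straight to #forward above.
loss = loss_fn.call(input1, input2, target)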
data/lib/torch/nn/max_poolnd.rb
ADDED
@@ -0,0 +1,19 @@
+module Torch
+  module NN
+    class MaxPoolNd < Module
+      def initialize(kernel_size, stride: nil, padding: 0, dilation: 1, return_indices: false, ceil_mode: false)
+        super()
+        @kernel_size = kernel_size
+        @stride = stride || kernel_size
+        @padding = padding
+        @dilation = dilation
+        @return_indices = return_indices
+        @ceil_mode = ceil_mode
+      end
+
+      def extra_inspect
+        format("kernel_size: %s", @kernel_size)
+      end
+    end
+  end
+end
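Note that, unlike LPPoolNd above, MaxPoolNd defaults the stride to the kernel size when none is given, so pooling windows tile without overlap. A hedged sketch, assuming MaxPool2d subclasses this with a PyTorch-style forward:

require "torch"

# stride: is omitted, so @stride falls back to kernel_size (2) in the
# initializer above, i.e. non-overlapping 2x2 windows.
pool = Torch::NN::MaxPool2d.new(2)

x = Torch.randn(1, 8, 28, 28)
y = pool.call(x)  # expected spatial size 14x14 (28 / 2)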
data/lib/torch/nn/max_unpool1d.rb
ADDED
@@ -0,0 +1,16 @@
+module Torch
+  module NN
+    class MaxUnpool1d < MaxUnpoolNd
+      def initialize(kernel_size, stride: nil, padding: 0)
+        super()
+        @kernel_size = _single(kernel_size)
+        @stride = _single(stride || kernel_size)
+        @padding = _single(padding)
+      end
+
+      def forward(input, indices, output_size: nil)
+        F.max_unpool1d(input, indices, @kernel_size, stride: @stride, padding: @padding, output_size: output_size)
+      end
+    end
+  end
+end
data/lib/torch/nn/max_unpool2d.rb
ADDED
@@ -0,0 +1,16 @@
+module Torch
+  module NN
+    class MaxUnpool2d < MaxUnpoolNd
+      def initialize(kernel_size, stride: nil, padding: 0)
+        super()
+        @kernel_size = _pair(kernel_size)
+        @stride = _pair(stride || kernel_size)
+        @padding = _pair(padding)
+      end
+
+      def forward(input, indices, output_size: nil)
+        F.max_unpool2d(input, indices, @kernel_size, @stride, @padding, output_size)
+      end
+    end
+  end
+end
data/lib/torch/nn/max_unpool3d.rb
ADDED
@@ -0,0 +1,16 @@
+module Torch
+  module NN
+    class MaxUnpool3d < MaxUnpoolNd
+      def initialize(kernel_size, stride: nil, padding: 0)
+        super()
+        @kernel_size = _triple(kernel_size)
+        @stride = _triple(stride || kernel_size)
+        @padding = _triple(padding)
+      end
+
+      def forward(input, indices, output_size: nil)
+        F.max_unpool3d(input, indices, @kernel_size, @stride, @padding, output_size)
+      end
+    end
+  end
+end
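The three MaxUnpool modules are the partial inverse of max pooling: they scatter pooled values back to the argmax positions recorded by a pooling layer run with return_indices. A hedged round-trip sketch, assuming MaxPool2d honors the return_indices flag stored by MaxPoolNd above and returns the indices alongside the output:

require "torch"

pool   = Torch::NN::MaxPool2d.new(2, return_indices: true)
unpool = Torch::NN::MaxUnpool2d.new(2)

x = Torch.randn(1, 4, 8, 8)

# Assumed to return both the pooled values and the argmax locations
# needed to invert the operation.
pooled, indices = pool.call(x)

# Non-maximal positions come back as zeros; the shape returns to 1x4x8x8.
restored = unpool.call(pooled, indices)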
data/lib/torch/nn/module.rb
CHANGED
@@ -1,56 +1,170 @@
 module Torch
   module NN
     class Module
+      include Utils
+
       def initialize
         @training = true
+        @parameters = {}
+        @buffers = {}
+        @modules = {}
       end
 
-      def
-
-        str << "#{self.class.name}(\n"
-        modules.each do |name, mod|
-          str << "  (#{name}): #{mod.inspect}\n"
-        end
-        str << ")"
+      def forward
+        raise NotImplementedError
       end
 
-      def
-
+      def register_buffer(name, tensor)
+        # TODO add checks
+        @buffers[name] = tensor
+        instance_variable_set("@#{name}", tensor)
+      end
 
-
-
+      def register_parameter(name, param)
+        # TODO add checks
+        @parameters[name] = param
+      end
+
+      def add_module(name, mod)
+        # TODO add checks
+        @modules[name] = mod
+      end
+
+      def _apply(fn)
+        children.each do |mod|
+          mod._apply(fn)
         end
+        # TODO apply to more objects
+        self
       end
 
-      def
-
+      def apply(fn)
+        children.each do |mod|
+          mod.apply(fn)
+        end
+        fn.call(self)
+        self
+      end
+
+      def cuda(device: nil)
+        _apply ->(t) { t.cuda(device) }
+      end
+
+      def cpu
+        _apply ->(t) { t.cpu }
+      end
+
+      def type(dst_type)
+        _apply ->(t) { t.type(dst_type) }
       end
 
-      def
-
+      def float
+        _apply ->(t) { t.floating_point? ? t.float : t }
+      end
+
+      def double
+        _apply ->(t) { t.floating_point? ? t.double : t }
+      end
+
+      def half
+        _apply ->(t) { t.floating_point? ? t.half : t }
       end
 
       # modifies in-place
       def to(device)
-
-
-          if param.is_a?(Parameter)
-            instance_variable_set(name, Parameter.new(param.to(device)))
-          end
+        convert = lambda do |t|
+          t.to(device)
         end
-
-
+
+        _apply(convert)
+      end
+
+      def call(*input, **kwargs)
+        forward(*input, **kwargs)
+      end
+
+      def state_dict(destination: nil)
+        destination ||= {}
+        named_parameters.each do |k, v|
+          destination[k] = v
         end
-
+        destination
+      end
+
+      def load_state_dict(state_dict)
+        raise NotImplementedYet
       end
 
       def parameters
-
+        named_parameters.values
+      end
+
+      def named_parameters(prefix: "", recurse: true)
+        params = {}
+        if recurse
+          named_children.each do |name, mod|
+            params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
+          end
+        end
         instance_variables.each do |name|
           param = instance_variable_get(name)
-          params
+          params[[prefix, name[1..-1]].join] = param if param.is_a?(Parameter)
+        end
+        @parameters.each do |name, param|
+          params[[prefix, name].join] = param
         end
-        params
+        params
+      end
+
+      def buffers
+        named_buffers.values
+      end
+
+      def named_buffers
+        @buffers || {}
+      end
+
+      def children
+        named_children.values
+      end
+
+      def named_children
+        modules = {}
+        instance_variables.each do |name|
+          mod = instance_variable_get(name)
+          modules[name[1..-1]] = mod if mod.is_a?(Module)
+        end
+        @modules.each do |name, mod|
+          modules[name] = mod
+        end
+        modules
+      end
+
+      def modules
+        named_modules.values
+      end
+
+      def named_modules
+        {"" => self}.merge(named_children)
+      end
+
+      def train(mode = true)
+        @training = mode
+        children.each do |mod|
+          mod.train(mode)
+        end
+        self
+      end
+
+      def eval
+        train(false)
+      end
+
+      def requires_grad!(requires_grad: true)
+        parameters.each do |p|
+          p.requires_grad!(requires_grad)
+        end
+        self
       end
 
       def zero_grad
@@ -62,23 +176,60 @@
         end
       end
 
+      def share_memory
+        _apply ->(t) { t.share_memory! }
+      end
+
+      def inspect
+        name = self.class.name.split("::").last
+        if children.empty?
+          "#{name}(#{extra_inspect})"
+        else
+          str = String.new
+          str << "#{name}(\n"
+          children.each do |name, mod|
+            str << "  (#{name}): #{mod.inspect}\n"
+          end
+          str << ")"
+        end
+      end
+
       def method_missing(method, *args, &block)
-
+        name = method.to_s
+        if named_parameters.key?(name)
+          named_parameters[name]
+        elsif named_buffers.key?(name)
+          named_buffers[name]
+        elsif named_modules.key?(name)
+          named_modules[name]
+        else
+          super
+        end
       end
 
       def respond_to?(method, include_private = false)
-
+        name = method.to_s
+        named_parameters.key?(name) || named_buffers.key?(name) || named_modules.key?(name) || super
       end
 
       private
 
-      def
-
-
-
-
-
-
+      def extra_inspect
+        nil
+      end
+
+      def format(str, *vars, **options)
+        vars =
+          if vars.any?
+            vars.map(&:inspect)
+          else
+            options.map { |k, v| [k, v.inspect] }.to_h
+          end
+        str % vars
+      end
+
+      def dict
+        instance_variables.map { |k| [k[1..-1].to_sym, instance_variable_get(k)] }.to_h
       end
     end
   end
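To tie the new Module plumbing together, a minimal custom-module sketch. Sub-modules and parameters are discovered through instance variables (see named_children and named_parameters above), call dispatches to forward, and train/eval toggle @training recursively. The layer sizes and the relu call below are illustrative assumptions, not part of this diff:

require "torch"

class TinyNet < Torch::NN::Module
  def initialize
    super()
    # Instance variables holding Modules are picked up by #named_children,
    # so their parameters appear in #named_parameters under "fc1."/"fc2." prefixes.
    @fc1 = Torch::NN::Linear.new(4, 8)
    @fc2 = Torch::NN::Linear.new(8, 1)
  end

  def forward(x)
    @fc2.call(@fc1.call(x).relu)
  end
end

net = TinyNet.new
net.named_parameters.keys  # => ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]
net.eval                   # recursively sets @training = false
out = net.call(Torch.randn(2, 4))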