torch-rb 0.1.5 → 0.1.6
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +1 -1
- data/ext/torch/ext.cpp +0 -170
- data/ext/torch/nn_functions.cpp +44 -24
- data/ext/torch/templates.cpp +55 -0
- data/ext/torch/templates.hpp +48 -0
- data/ext/torch/tensor_functions.cpp +76 -16
- data/ext/torch/torch_functions.cpp +165 -65
- data/lib/torch.rb +51 -42
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/native/dispatcher.rb +1 -1
- data/lib/torch/native/function.rb +36 -5
- data/lib/torch/native/generator.rb +26 -7
- data/lib/torch/native/parser.rb +51 -14
- data/lib/torch/nn/avg_pool1d.rb +18 -0
- data/lib/torch/nn/avg_pool2d.rb +7 -2
- data/lib/torch/nn/avg_pool3d.rb +19 -0
- data/lib/torch/nn/avg_poolnd.rb +1 -1
- data/lib/torch/nn/batch_norm.rb +75 -0
- data/lib/torch/nn/batch_norm1d.rb +11 -0
- data/lib/torch/nn/batch_norm2d.rb +11 -0
- data/lib/torch/nn/batch_norm3d.rb +11 -0
- data/lib/torch/nn/constant_pad1d.rb +10 -0
- data/lib/torch/nn/constant_pad2d.rb +10 -0
- data/lib/torch/nn/constant_pad3d.rb +10 -0
- data/lib/torch/nn/constant_padnd.rb +18 -0
- data/lib/torch/nn/conv1d.rb +22 -0
- data/lib/torch/nn/conv2d.rb +9 -17
- data/lib/torch/nn/conv3d.rb +22 -0
- data/lib/torch/nn/fold.rb +20 -0
- data/lib/torch/nn/functional.rb +320 -100
- data/lib/torch/nn/group_norm.rb +36 -0
- data/lib/torch/nn/gru.rb +49 -0
- data/lib/torch/nn/hardshrink.rb +18 -0
- data/lib/torch/nn/instance_norm.rb +20 -0
- data/lib/torch/nn/instance_norm1d.rb +18 -0
- data/lib/torch/nn/instance_norm2d.rb +11 -0
- data/lib/torch/nn/instance_norm3d.rb +11 -0
- data/lib/torch/nn/layer_norm.rb +35 -0
- data/lib/torch/nn/local_response_norm.rb +21 -0
- data/lib/torch/nn/log_sigmoid.rb +9 -0
- data/lib/torch/nn/lp_pool1d.rb +9 -0
- data/lib/torch/nn/lp_pool2d.rb +9 -0
- data/lib/torch/nn/lp_poolnd.rb +22 -0
- data/lib/torch/nn/lstm.rb +66 -0
- data/lib/torch/nn/max_pool1d.rb +9 -0
- data/lib/torch/nn/max_pool2d.rb +1 -1
- data/lib/torch/nn/max_pool3d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +6 -6
- data/lib/torch/nn/max_unpool1d.rb +16 -0
- data/lib/torch/nn/max_unpool2d.rb +16 -0
- data/lib/torch/nn/max_unpool3d.rb +16 -0
- data/lib/torch/nn/max_unpoolnd.rb +9 -0
- data/lib/torch/nn/module.rb +7 -0
- data/lib/torch/nn/reflection_pad1d.rb +10 -0
- data/lib/torch/nn/reflection_pad2d.rb +10 -0
- data/lib/torch/nn/reflection_padnd.rb +13 -0
- data/lib/torch/nn/replication_pad1d.rb +10 -0
- data/lib/torch/nn/replication_pad2d.rb +10 -0
- data/lib/torch/nn/replication_pad3d.rb +10 -0
- data/lib/torch/nn/replication_padnd.rb +13 -0
- data/lib/torch/nn/rnn_base.rb +48 -4
- data/lib/torch/nn/softshrink.rb +18 -0
- data/lib/torch/nn/softsign.rb +9 -0
- data/lib/torch/nn/tanh.rb +9 -0
- data/lib/torch/nn/tanhshrink.rb +9 -0
- data/lib/torch/nn/unfold.rb +19 -0
- data/lib/torch/nn/utils.rb +25 -0
- data/lib/torch/nn/zero_pad2d.rb +9 -0
- data/lib/torch/tensor.rb +14 -25
- data/lib/torch/version.rb +1 -1
- metadata +50 -2
@@ -0,0 +1,16 @@
|
|
1
|
+
module Torch
  module NN
    # Module wrapper around F.max_unpool3d.
    class MaxUnpool3d < MaxUnpoolNd
      # kernel_size / stride / padding may each be a scalar (expanded to a
      # 3-element Array by _triple) or an Array, which is used as-is.
      # When stride is not given, it falls back to kernel_size.
      def initialize(kernel_size, stride: nil, padding: 0)
        super()
        @kernel_size, @stride, @padding =
          [kernel_size, stride || kernel_size, padding].map { |arg| _triple(arg) }
      end

      # Forwards to F.max_unpool3d with the normalized hyperparameters;
      # output_size is passed through unchanged.
      def forward(input, indices, output_size: nil)
        F.max_unpool3d(input, indices, @kernel_size, @stride, @padding, output_size)
      end
    end
  end
end
|
data/lib/torch/nn/module.rb
CHANGED
@@ -1,6 +1,8 @@
|
|
1
1
|
module Torch
|
2
2
|
module NN
|
3
3
|
class Module
|
4
|
+
include Utils
|
5
|
+
|
4
6
|
def initialize
|
5
7
|
@training = true
|
6
8
|
@parameters = {}
|
@@ -15,6 +17,7 @@ module Torch
|
|
15
17
|
# Registers tensor as a buffer: records it in @buffers under name and
# exposes it as the instance variable @<name>. Returns tensor.
def register_buffer(name, tensor)
  # TODO add checks
  # Set the ivar first: instance_variable_set raises NameError for names that
  # are not valid ivar names (e.g. "a.b"), and doing it before touching
  # @buffers keeps a rejected name from leaving a stale entry there.
  instance_variable_set("@#{name}", tensor)
  @buffers[name] = tensor
end
|
19
22
|
|
20
23
|
def register_parameter(name, param)
|
@@ -216,6 +219,10 @@ module Torch
|
|
216
219
|
end
|
217
220
|
str % vars
|
218
221
|
end
|
222
|
+
|
223
|
+
# Snapshot of the object's instance variables as a Hash whose keys are the
# variable names without the leading "@", as Symbols.
def dict
  instance_variables.each_with_object({}) do |ivar, vars|
    vars[ivar.to_s[1..-1].to_sym] = instance_variable_get(ivar)
  end
end
|
219
226
|
end
|
220
227
|
end
|
221
228
|
end
|
data/lib/torch/nn/rnn_base.rb
CHANGED
@@ -90,12 +90,13 @@ module Torch
|
|
90
90
|
end
|
91
91
|
|
92
92
|
# Reorders hidden state hx by permutation. Only the no-permutation case is
# supported: hx is returned untouched; anything else raises NotImplementedYet.
def permute_hidden(hx, permutation)
  return hx if permutation.nil?

  raise NotImplementedYet
end
|
95
98
|
|
96
99
|
def forward(input, hx: nil)
|
97
|
-
raise NotImplementedYet
|
98
|
-
|
99
100
|
is_packed = false # TODO isinstance(input, PackedSequence)
|
100
101
|
if is_packed
|
101
102
|
input, batch_sizes, sorted_indices, unsorted_indices = input
|
@@ -120,8 +121,8 @@ module Torch
|
|
120
121
|
|
121
122
|
check_forward_args(input, hx, batch_sizes)
|
122
123
|
_rnn_impls = {
|
123
|
-
"RNN_TANH" => Torch.method(:
|
124
|
-
"RNN_RELU" => Torch.method(:
|
124
|
+
"RNN_TANH" => Torch.method(:rnn_tanh),
|
125
|
+
"RNN_RELU" => Torch.method(:rnn_relu)
|
125
126
|
}
|
126
127
|
_impl = _rnn_impls[@mode]
|
127
128
|
if batch_sizes.nil?
|
@@ -149,6 +150,49 @@ module Torch
|
|
149
150
|
end
|
150
151
|
format(s, input_size: @input_size, hidden_size: @hidden_size, num_layers: @num_layers)
|
151
152
|
end
|
153
|
+
|
154
|
+
private

# Tensors named by @all_weights (a possibly nested Array of instance-variable
# names), flattened into one Array with nil entries (unset ivars) dropped.
def _flat_weights
  weight_names = @all_weights.flatten
  weight_names.map { |name| instance_variable_get("@#{name}") }.reject(&:nil?)
end
|
159
|
+
|
160
|
+
# Accessor that simply delegates to _flat_weights.
def _get_flat_weights
  _flat_weights
end
|
163
|
+
|
164
|
+
# Validates the input passed to forward.
# input must respond to #dim and #size(dim). It must be 3-D when
# batch_sizes is nil and 2-D otherwise. Its last dimension must equal
# @input_size.
# Raises ArgumentError on either mismatch; returns nil otherwise.
def check_input(input, batch_sizes)
  expected_input_dim = batch_sizes.nil? ? 3 : 2
  actual_dim = input.dim
  unless actual_dim == expected_input_dim
    raise ArgumentError, "input must have #{expected_input_dim} dimensions, got #{actual_dim}"
  end

  feature_size = input.size(-1)
  return if @input_size == feature_size

  raise ArgumentError, "input.size(-1) must be equal to input_size. Expected #{@input_size}, got #{feature_size}"
end
|
173
|
+
|
174
|
+
# Shape [num_layers * num_directions, mini_batch, hidden_size] that the
# hidden state must have. mini_batch is batch_sizes[0] (as Integer) when
# batch_sizes is given. Otherwise it is input.size(0) with @batch_first, or
# input.size(1) without it.
def get_expected_hidden_size(input, batch_sizes)
  mini_batch =
    if batch_sizes.nil?
      input.size(@batch_first ? 0 : 1)
    else
      batch_sizes[0].to_i
    end
  num_directions = @bidirectional ? 2 : 1
  [@num_layers * num_directions, mini_batch, @hidden_size]
end
|
184
|
+
|
185
|
+
# Raises ArgumentError unless hx.size equals expected_hidden_size;
# returns nil when they match.
def check_hidden_size(hx, expected_hidden_size)
  actual_size = hx.size
  return if actual_size == expected_hidden_size

  raise ArgumentError, "Expected hidden size #{expected_hidden_size.inspect}, got #{actual_size.inspect}"
end
|
190
|
+
|
191
|
+
# Runs the forward-pass validations in order: check_input on the input,
# then check_hidden_size of hidden against get_expected_hidden_size.
def check_forward_args(input, hidden, batch_sizes)
  check_input(input, batch_sizes)
  check_hidden_size(hidden, get_expected_hidden_size(input, batch_sizes))
end
|
152
196
|
end
|
153
197
|
end
|
154
198
|
end
|
@@ -0,0 +1,19 @@
|
|
1
|
+
module Torch
  module NN
    # Module wrapper around F.unfold; stores the hyperparameters as given
    # (no scalar-to-tuple normalization happens here).
    class Unfold < Module
      def initialize(kernel_size, dilation: 1, padding: 0, stride: 1)
        super()
        @kernel_size, @dilation, @padding, @stride = kernel_size, dilation, padding, stride
      end

      # Applies F.unfold to input using the stored hyperparameters.
      def forward(input)
        unfold_options = { dilation: @dilation, padding: @padding, stride: @stride }
        F.unfold(input, @kernel_size, **unfold_options)
      end

      # TODO add extra_inspect
    end
  end
end
|
@@ -0,0 +1,25 @@
|
|
1
|
+
module Torch
  module NN
    # Helpers that normalize size-like hyperparameters (kernel_size, stride,
    # padding, ...): a scalar is repeated n times, an Array is used as-is.
    module Utils
      # value as a 1-element Array (or value itself if already an Array).
      def _single(value)
        _ntuple(1, value)
      end

      # value as a 2-element Array (or value itself if already an Array).
      def _pair(value)
        _ntuple(2, value)
      end

      # value as a 3-element Array (or value itself if already an Array).
      def _triple(value)
        _ntuple(3, value)
      end

      # value as a 4-element Array (or value itself if already an Array).
      def _quadruple(value)
        _ntuple(4, value)
      end
      # Original (misspelled) name, kept so existing callers keep working.
      alias_method :_quadrupal, :_quadruple

      # Returns value unchanged when it is already an Array; otherwise an
      # Array holding value repeated n times. Array lengths are not checked
      # against n.
      def _ntuple(n, value)
        value.is_a?(Array) ? value : [value] * n
      end
    end
  end
end
|
data/lib/torch/tensor.rb
CHANGED
@@ -45,8 +45,9 @@ module Torch
|
|
45
45
|
dim.times.map { |i| size(i) }
|
46
46
|
end
|
47
47
|
|
48
|
-
|
49
|
-
|
48
|
+
# Extent of the first dimension, i.e. size(0) — the counterpart of
# Python's len() on a tensor.
def length
  size(0)
end
|
51
52
|
|
52
53
|
def item
|
@@ -86,38 +87,26 @@ module Torch
|
|
86
87
|
_type(enum)
|
87
88
|
end
|
88
89
|
|
89
|
-
|
90
|
+
# Accepts the target shape either as separate arguments (reshape(2, 3)) or
# as a single Array (reshape([2, 3])), then delegates to _reshape.
def reshape(*size)
  # Python doesn't check if size == 1, just ignores later arguments
  wrapped = size.length == 1 && size[0].is_a?(Array)
  _reshape(wrapped ? size[0] : size)
end
|
90
95
|
|
96
|
+
# Accepts the target shape either as separate arguments (view(2, 3)) or as a
# single Array (view([2, 3])), then delegates to _view.
def view(*size)
  wrapped = size.length == 1 && size[0].is_a?(Array)
  _view(wrapped ? size[0] : size)
end
|
100
|
+
|
101
|
+
# value and other are swapped for some methods
# In-place add, callable as add!(other) or add!(value, other); value
# defaults to 1 and is passed as the second argument to the native op
# (presumably the alpha multiplier — confirm against the binding).
# Numeric other dispatches to _add__scalar, anything else to _add__tensor.
def add!(value = 1, other)
  scalar_other = other.is_a?(Numeric)
  scalar_other ? _add__scalar(other, value) : _add__tensor(other, value)
end
|
99
109
|
|
100
|
-
def mul!(other)
|
101
|
-
if other.is_a?(Numeric)
|
102
|
-
_mul__scalar(other)
|
103
|
-
else
|
104
|
-
_mul__tensor(other)
|
105
|
-
end
|
106
|
-
end
|
107
|
-
|
108
|
-
# operations
|
109
|
-
%w(log_softmax mean softmax sum topk).each do |op|
|
110
|
-
define_method(op) do |*args, **options, &block|
|
111
|
-
if options.any?
|
112
|
-
Torch.send(op, self, *args, **options, &block)
|
113
|
-
else
|
114
|
-
Torch.send(op, self, *args, &block)
|
115
|
-
end
|
116
|
-
end
|
117
|
-
end
|
118
|
-
|
119
|
-
# end temp operations
|
120
|
-
|
121
110
|
# Operator form of addition: a + b is just a.add(b).
def +(other)
  add(other)
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.5
|
4
|
+
version: 0.1.6
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-12-
|
11
|
+
date: 2019-12-10 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -108,6 +108,7 @@ files:
|
|
108
108
|
- ext/torch/extconf.rb
|
109
109
|
- ext/torch/nn_functions.cpp
|
110
110
|
- ext/torch/nn_functions.hpp
|
111
|
+
- ext/torch/templates.cpp
|
111
112
|
- ext/torch/templates.hpp
|
112
113
|
- ext/torch/tensor_functions.cpp
|
113
114
|
- ext/torch/tensor_functions.hpp
|
@@ -123,12 +124,24 @@ files:
|
|
123
124
|
- lib/torch/native/native_functions.yaml
|
124
125
|
- lib/torch/native/parser.rb
|
125
126
|
- lib/torch/nn/alpha_dropout.rb
|
127
|
+
- lib/torch/nn/avg_pool1d.rb
|
126
128
|
- lib/torch/nn/avg_pool2d.rb
|
129
|
+
- lib/torch/nn/avg_pool3d.rb
|
127
130
|
- lib/torch/nn/avg_poolnd.rb
|
131
|
+
- lib/torch/nn/batch_norm.rb
|
132
|
+
- lib/torch/nn/batch_norm1d.rb
|
133
|
+
- lib/torch/nn/batch_norm2d.rb
|
134
|
+
- lib/torch/nn/batch_norm3d.rb
|
128
135
|
- lib/torch/nn/bce_loss.rb
|
129
136
|
- lib/torch/nn/bce_with_logits_loss.rb
|
130
137
|
- lib/torch/nn/bilinear.rb
|
138
|
+
- lib/torch/nn/constant_pad1d.rb
|
139
|
+
- lib/torch/nn/constant_pad2d.rb
|
140
|
+
- lib/torch/nn/constant_pad3d.rb
|
141
|
+
- lib/torch/nn/constant_padnd.rb
|
142
|
+
- lib/torch/nn/conv1d.rb
|
131
143
|
- lib/torch/nn/conv2d.rb
|
144
|
+
- lib/torch/nn/conv3d.rb
|
132
145
|
- lib/torch/nn/convnd.rb
|
133
146
|
- lib/torch/nn/cosine_embedding_loss.rb
|
134
147
|
- lib/torch/nn/cosine_similarity.rb
|
@@ -141,19 +154,40 @@ files:
|
|
141
154
|
- lib/torch/nn/embedding.rb
|
142
155
|
- lib/torch/nn/embedding_bag.rb
|
143
156
|
- lib/torch/nn/feature_alpha_dropout.rb
|
157
|
+
- lib/torch/nn/fold.rb
|
144
158
|
- lib/torch/nn/functional.rb
|
159
|
+
- lib/torch/nn/group_norm.rb
|
160
|
+
- lib/torch/nn/gru.rb
|
161
|
+
- lib/torch/nn/hardshrink.rb
|
145
162
|
- lib/torch/nn/hinge_embedding_loss.rb
|
146
163
|
- lib/torch/nn/identity.rb
|
147
164
|
- lib/torch/nn/init.rb
|
165
|
+
- lib/torch/nn/instance_norm.rb
|
166
|
+
- lib/torch/nn/instance_norm1d.rb
|
167
|
+
- lib/torch/nn/instance_norm2d.rb
|
168
|
+
- lib/torch/nn/instance_norm3d.rb
|
148
169
|
- lib/torch/nn/kl_div_loss.rb
|
149
170
|
- lib/torch/nn/l1_loss.rb
|
171
|
+
- lib/torch/nn/layer_norm.rb
|
150
172
|
- lib/torch/nn/leaky_relu.rb
|
151
173
|
- lib/torch/nn/linear.rb
|
174
|
+
- lib/torch/nn/local_response_norm.rb
|
175
|
+
- lib/torch/nn/log_sigmoid.rb
|
152
176
|
- lib/torch/nn/log_softmax.rb
|
153
177
|
- lib/torch/nn/loss.rb
|
178
|
+
- lib/torch/nn/lp_pool1d.rb
|
179
|
+
- lib/torch/nn/lp_pool2d.rb
|
180
|
+
- lib/torch/nn/lp_poolnd.rb
|
181
|
+
- lib/torch/nn/lstm.rb
|
154
182
|
- lib/torch/nn/margin_ranking_loss.rb
|
183
|
+
- lib/torch/nn/max_pool1d.rb
|
155
184
|
- lib/torch/nn/max_pool2d.rb
|
185
|
+
- lib/torch/nn/max_pool3d.rb
|
156
186
|
- lib/torch/nn/max_poolnd.rb
|
187
|
+
- lib/torch/nn/max_unpool1d.rb
|
188
|
+
- lib/torch/nn/max_unpool2d.rb
|
189
|
+
- lib/torch/nn/max_unpool3d.rb
|
190
|
+
- lib/torch/nn/max_unpoolnd.rb
|
157
191
|
- lib/torch/nn/module.rb
|
158
192
|
- lib/torch/nn/mse_loss.rb
|
159
193
|
- lib/torch/nn/multi_label_margin_loss.rb
|
@@ -164,7 +198,14 @@ files:
|
|
164
198
|
- lib/torch/nn/parameter.rb
|
165
199
|
- lib/torch/nn/poisson_nll_loss.rb
|
166
200
|
- lib/torch/nn/prelu.rb
|
201
|
+
- lib/torch/nn/reflection_pad1d.rb
|
202
|
+
- lib/torch/nn/reflection_pad2d.rb
|
203
|
+
- lib/torch/nn/reflection_padnd.rb
|
167
204
|
- lib/torch/nn/relu.rb
|
205
|
+
- lib/torch/nn/replication_pad1d.rb
|
206
|
+
- lib/torch/nn/replication_pad2d.rb
|
207
|
+
- lib/torch/nn/replication_pad3d.rb
|
208
|
+
- lib/torch/nn/replication_padnd.rb
|
168
209
|
- lib/torch/nn/rnn.rb
|
169
210
|
- lib/torch/nn/rnn_base.rb
|
170
211
|
- lib/torch/nn/sequential.rb
|
@@ -175,8 +216,15 @@ files:
|
|
175
216
|
- lib/torch/nn/softmax2d.rb
|
176
217
|
- lib/torch/nn/softmin.rb
|
177
218
|
- lib/torch/nn/softplus.rb
|
219
|
+
- lib/torch/nn/softshrink.rb
|
220
|
+
- lib/torch/nn/softsign.rb
|
221
|
+
- lib/torch/nn/tanh.rb
|
222
|
+
- lib/torch/nn/tanhshrink.rb
|
178
223
|
- lib/torch/nn/triplet_margin_loss.rb
|
224
|
+
- lib/torch/nn/unfold.rb
|
225
|
+
- lib/torch/nn/utils.rb
|
179
226
|
- lib/torch/nn/weighted_loss.rb
|
227
|
+
- lib/torch/nn/zero_pad2d.rb
|
180
228
|
- lib/torch/optim/adadelta.rb
|
181
229
|
- lib/torch/optim/adagrad.rb
|
182
230
|
- lib/torch/optim/adam.rb
|