torch-rb 0.1.5 → 0.1.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +1 -1
- data/ext/torch/ext.cpp +0 -170
- data/ext/torch/nn_functions.cpp +44 -24
- data/ext/torch/templates.cpp +55 -0
- data/ext/torch/templates.hpp +48 -0
- data/ext/torch/tensor_functions.cpp +76 -16
- data/ext/torch/torch_functions.cpp +165 -65
- data/lib/torch.rb +51 -42
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/native/dispatcher.rb +1 -1
- data/lib/torch/native/function.rb +36 -5
- data/lib/torch/native/generator.rb +26 -7
- data/lib/torch/native/parser.rb +51 -14
- data/lib/torch/nn/avg_pool1d.rb +18 -0
- data/lib/torch/nn/avg_pool2d.rb +7 -2
- data/lib/torch/nn/avg_pool3d.rb +19 -0
- data/lib/torch/nn/avg_poolnd.rb +1 -1
- data/lib/torch/nn/batch_norm.rb +75 -0
- data/lib/torch/nn/batch_norm1d.rb +11 -0
- data/lib/torch/nn/batch_norm2d.rb +11 -0
- data/lib/torch/nn/batch_norm3d.rb +11 -0
- data/lib/torch/nn/constant_pad1d.rb +10 -0
- data/lib/torch/nn/constant_pad2d.rb +10 -0
- data/lib/torch/nn/constant_pad3d.rb +10 -0
- data/lib/torch/nn/constant_padnd.rb +18 -0
- data/lib/torch/nn/conv1d.rb +22 -0
- data/lib/torch/nn/conv2d.rb +9 -17
- data/lib/torch/nn/conv3d.rb +22 -0
- data/lib/torch/nn/fold.rb +20 -0
- data/lib/torch/nn/functional.rb +320 -100
- data/lib/torch/nn/group_norm.rb +36 -0
- data/lib/torch/nn/gru.rb +49 -0
- data/lib/torch/nn/hardshrink.rb +18 -0
- data/lib/torch/nn/instance_norm.rb +20 -0
- data/lib/torch/nn/instance_norm1d.rb +18 -0
- data/lib/torch/nn/instance_norm2d.rb +11 -0
- data/lib/torch/nn/instance_norm3d.rb +11 -0
- data/lib/torch/nn/layer_norm.rb +35 -0
- data/lib/torch/nn/local_response_norm.rb +21 -0
- data/lib/torch/nn/log_sigmoid.rb +9 -0
- data/lib/torch/nn/lp_pool1d.rb +9 -0
- data/lib/torch/nn/lp_pool2d.rb +9 -0
- data/lib/torch/nn/lp_poolnd.rb +22 -0
- data/lib/torch/nn/lstm.rb +66 -0
- data/lib/torch/nn/max_pool1d.rb +9 -0
- data/lib/torch/nn/max_pool2d.rb +1 -1
- data/lib/torch/nn/max_pool3d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +6 -6
- data/lib/torch/nn/max_unpool1d.rb +16 -0
- data/lib/torch/nn/max_unpool2d.rb +16 -0
- data/lib/torch/nn/max_unpool3d.rb +16 -0
- data/lib/torch/nn/max_unpoolnd.rb +9 -0
- data/lib/torch/nn/module.rb +7 -0
- data/lib/torch/nn/reflection_pad1d.rb +10 -0
- data/lib/torch/nn/reflection_pad2d.rb +10 -0
- data/lib/torch/nn/reflection_padnd.rb +13 -0
- data/lib/torch/nn/replication_pad1d.rb +10 -0
- data/lib/torch/nn/replication_pad2d.rb +10 -0
- data/lib/torch/nn/replication_pad3d.rb +10 -0
- data/lib/torch/nn/replication_padnd.rb +13 -0
- data/lib/torch/nn/rnn_base.rb +48 -4
- data/lib/torch/nn/softshrink.rb +18 -0
- data/lib/torch/nn/softsign.rb +9 -0
- data/lib/torch/nn/tanh.rb +9 -0
- data/lib/torch/nn/tanhshrink.rb +9 -0
- data/lib/torch/nn/unfold.rb +19 -0
- data/lib/torch/nn/utils.rb +25 -0
- data/lib/torch/nn/zero_pad2d.rb +9 -0
- data/lib/torch/tensor.rb +14 -25
- data/lib/torch/version.rb +1 -1
- metadata +50 -2
@@ -0,0 +1,16 @@
|
|
1
|
+
module Torch
  module NN
    # Module wrapper around F.max_unpool3d that stores its hyperparameters
    # as 3-element arrays.
    class MaxUnpool3d < MaxUnpoolNd
      # kernel_size, stride and padding may each be a scalar or an Array;
      # scalars are expanded with _triple. When stride is not given, the
      # kernel size is used as the stride.
      def initialize(kernel_size, stride: nil, padding: 0)
        super()
        @kernel_size, @stride, @padding =
          [kernel_size, stride || kernel_size, padding].map { |arg| _triple(arg) }
      end

      # Delegates to F.max_unpool3d with the stored kernel/stride/padding.
      # NOTE(review): +indices+ is presumably the index tensor produced by a
      # matching max-pool call — confirm against F.max_unpool3d.
      def forward(input, indices, output_size: nil)
        F.max_unpool3d(input, indices, @kernel_size, @stride, @padding, output_size)
      end
    end
  end
end
|
data/lib/torch/nn/module.rb
CHANGED
@@ -1,6 +1,8 @@
|
|
1
1
|
module Torch
|
2
2
|
module NN
|
3
3
|
class Module
|
4
|
+
include Utils
|
5
|
+
|
4
6
|
def initialize
|
5
7
|
@training = true
|
6
8
|
@parameters = {}
|
@@ -15,6 +17,7 @@ module Torch
|
|
15
17
|
# Record +tensor+ in @buffers under +name+ and also mirror it as the
# instance variable @<name> so it reads like a normal attribute.
# Returns +tensor+ (the value of instance_variable_set).
def register_buffer(name, tensor)
  # TODO add checks
  @buffers[name] = tensor
  instance_variable_set(:"@#{name}", tensor)
end
|
19
22
|
|
20
23
|
def register_parameter(name, param)
|
@@ -216,6 +219,10 @@ module Torch
|
|
216
219
|
end
|
217
220
|
str % vars
|
218
221
|
end
|
222
|
+
|
223
|
+
# Snapshot of all instance variables as a Hash whose keys are the variable
# names without the leading "@", as symbols (e.g. @training => :training).
def dict
  instance_variables.each_with_object({}) do |ivar, attrs|
    attrs[ivar[1..-1].to_sym] = instance_variable_get(ivar)
  end
end
|
219
226
|
end
|
220
227
|
end
|
221
228
|
end
|
data/lib/torch/nn/rnn_base.rb
CHANGED
@@ -90,12 +90,13 @@ module Torch
|
|
90
90
|
end
|
91
91
|
|
92
92
|
# Return +hx+ untouched when no permutation is given; applying an actual
# permutation to the hidden state is not implemented yet.
def permute_hidden(hx, permutation)
  raise NotImplementedYet unless permutation.nil?
  hx
end
|
95
98
|
|
96
99
|
def forward(input, hx: nil)
|
97
|
-
raise NotImplementedYet
|
98
|
-
|
99
100
|
is_packed = false # TODO isinstance(input, PackedSequence)
|
100
101
|
if is_packed
|
101
102
|
input, batch_sizes, sorted_indices, unsorted_indices = input
|
@@ -120,8 +121,8 @@ module Torch
|
|
120
121
|
|
121
122
|
check_forward_args(input, hx, batch_sizes)
|
122
123
|
_rnn_impls = {
|
123
|
-
"RNN_TANH" => Torch.method(:
|
124
|
-
"RNN_RELU" => Torch.method(:
|
124
|
+
"RNN_TANH" => Torch.method(:rnn_tanh),
|
125
|
+
"RNN_RELU" => Torch.method(:rnn_relu)
|
125
126
|
}
|
126
127
|
_impl = _rnn_impls[@mode]
|
127
128
|
if batch_sizes.nil?
|
@@ -149,6 +150,49 @@ module Torch
|
|
149
150
|
end
|
150
151
|
format(s, input_size: @input_size, hidden_size: @hidden_size, num_layers: @num_layers)
|
151
152
|
end
|
153
|
+
|
154
|
+
private

# Resolve every weight name listed (possibly nested) in @all_weights to the
# matching instance variable, dropping names whose variable is nil/unset.
def _flat_weights
  weight_names = @all_weights.flatten
  weight_names.map { |weight_name| instance_variable_get("@#{weight_name}") }.reject(&:nil?)
end

# Delegates to _flat_weights.
def _get_flat_weights
  _flat_weights
end
|
163
|
+
|
164
|
+
# Validate +input+ before running the recurrence. Raises ArgumentError if
# the rank is wrong or the last dimension differs from @input_size.
# The expected rank is 2 when batch_sizes is given and 3 otherwise.
def check_input(input, batch_sizes)
  expected_input_dim = batch_sizes.nil? ? 3 : 2
  actual_dim = input.dim
  unless actual_dim == expected_input_dim
    raise ArgumentError, "input must have #{expected_input_dim} dimensions, got #{actual_dim}"
  end

  feature_size = input.size(-1)
  unless @input_size == feature_size
    raise ArgumentError, "input.size(-1) must be equal to input_size. Expected #{@input_size}, got #{feature_size}"
  end
end
|
173
|
+
|
174
|
+
# Shape the hidden state must have: [layers * directions, mini_batch, hidden].
# mini_batch is batch_sizes[0] (as an Integer) when batch_sizes is given;
# otherwise it is dimension 0 (batch_first) or dimension 1 of +input+.
def get_expected_hidden_size(input, batch_sizes)
  mini_batch =
    if batch_sizes.nil?
      input.size(@batch_first ? 0 : 1)
    else
      batch_sizes[0].to_i
    end
  num_directions = @bidirectional ? 2 : 1
  [@num_layers * num_directions, mini_batch, @hidden_size]
end
|
184
|
+
|
185
|
+
# Raise ArgumentError unless hx.size equals +expected_hidden_size+
# (both compared as shape arrays).
def check_hidden_size(hx, expected_hidden_size)
  actual_size = hx.size
  return if actual_size == expected_hidden_size

  raise ArgumentError, "Expected hidden size #{expected_hidden_size.inspect}, got #{actual_size.inspect}"
end
|
190
|
+
|
191
|
+
# Run the forward-time validations in order: input rank/feature size first,
# then the hidden state's shape against the size derived from the input.
def check_forward_args(input, hidden, batch_sizes)
  check_input(input, batch_sizes)
  check_hidden_size(hidden, get_expected_hidden_size(input, batch_sizes))
end
|
152
196
|
end
|
153
197
|
end
|
154
198
|
end
|
@@ -0,0 +1,19 @@
|
|
1
|
+
module Torch
  module NN
    # Module wrapper around F.unfold. Hyperparameters are stored exactly as
    # given (no _pair expansion) and forwarded on every call.
    class Unfold < Module
      def initialize(kernel_size, dilation: 1, padding: 0, stride: 1)
        super()
        @kernel_size, @dilation, @padding, @stride = kernel_size, dilation, padding, stride
      end

      def forward(input)
        options = { dilation: @dilation, padding: @padding, stride: @stride }
        F.unfold(input, @kernel_size, **options)
      end

      # TODO add extra_inspect
    end
  end
end
|
@@ -0,0 +1,25 @@
|
|
1
|
+
module Torch
  module NN
    # Helpers that normalize a hyperparameter given either as a scalar or as
    # an Array (kernel size, stride, padding, ...) into an Array.
    module Utils
      # @return [Array] +value+ itself if it is an Array, else [value]
      def _single(value)
        _ntuple(1, value)
      end

      # @return [Array] +value+ itself if it is an Array, else [value, value]
      def _pair(value)
        _ntuple(2, value)
      end

      # @return [Array] +value+ itself if it is an Array, else +value+ repeated 3 times
      def _triple(value)
        _ntuple(3, value)
      end

      # @return [Array] +value+ itself if it is an Array, else +value+ repeated 4 times
      def _quadruple(value)
        _ntuple(4, value)
      end

      # Original, misspelled name; kept so existing callers keep working.
      def _quadrupal(value)
        _quadruple(value)
      end

      # Repeat +value+ +n+ times unless it is already an Array. An Array is
      # returned as-is, and its length is NOT checked against +n+.
      #
      # @param n [Integer] number of repetitions for a scalar
      # @param value [Object, Array] scalar or already-expanded array
      # @return [Array]
      def _ntuple(n, value)
        value.is_a?(Array) ? value : [value] * n
      end
    end
  end
end
|
data/lib/torch/tensor.rb
CHANGED
@@ -45,8 +45,9 @@ module Torch
|
|
45
45
|
dim.times.map { |i| size(i) }
|
46
46
|
end
|
47
47
|
|
48
|
-
|
49
|
-
|
48
|
+
# mirror Python len()
# Extent of the first dimension, i.e. size(0).
def length
  leading_dim = 0
  size(leading_dim)
end
|
51
52
|
|
52
53
|
def item
|
@@ -86,38 +87,26 @@ module Torch
|
|
86
87
|
_type(enum)
|
87
88
|
end
|
88
89
|
|
89
|
-
|
90
|
+
# Accept dimensions either splatted (reshape(2, 3)) or as a single Array
# (reshape([2, 3])).
def reshape(*size)
  # Python doesn't check if size == 1, just ignores later arguments
  dims = size.size == 1 && size.first.is_a?(Array) ? size.first : size
  _reshape(dims)
end

# Same argument convention as reshape: view(4, -1) or view([4, -1]).
def view(*size)
  dims = size.size == 1 && size.first.is_a?(Array) ? size.first : size
  _view(dims)
end
|
100
|
+
|
101
|
+
# value and other are swapped for some methods
# In-place add, callable as add!(other) or add!(value, other); value
# defaults to 1 and is forwarded with other to the native op.
# NOTE(review): value appears to act as the alpha scale factor — confirm
# against the native _add__scalar/_add__tensor signatures.
def add!(value = 1, other)
  case other
  when Numeric then _add__scalar(other, value)
  else _add__tensor(other, value)
  end
end
|
99
109
|
|
100
|
-
def mul!(other)
|
101
|
-
if other.is_a?(Numeric)
|
102
|
-
_mul__scalar(other)
|
103
|
-
else
|
104
|
-
_mul__tensor(other)
|
105
|
-
end
|
106
|
-
end
|
107
|
-
|
108
|
-
# operations
|
109
|
-
%w(log_softmax mean softmax sum topk).each do |op|
|
110
|
-
define_method(op) do |*args, **options, &block|
|
111
|
-
if options.any?
|
112
|
-
Torch.send(op, self, *args, **options, &block)
|
113
|
-
else
|
114
|
-
Torch.send(op, self, *args, &block)
|
115
|
-
end
|
116
|
-
end
|
117
|
-
end
|
118
|
-
|
119
|
-
# end temp operations
|
120
|
-
|
121
110
|
# Binary addition operator; simply delegates to add.
def +(other)
  add(other)
end
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.
|
4
|
+
version: 0.1.6
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-12-
|
11
|
+
date: 2019-12-10 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -108,6 +108,7 @@ files:
|
|
108
108
|
- ext/torch/extconf.rb
|
109
109
|
- ext/torch/nn_functions.cpp
|
110
110
|
- ext/torch/nn_functions.hpp
|
111
|
+
- ext/torch/templates.cpp
|
111
112
|
- ext/torch/templates.hpp
|
112
113
|
- ext/torch/tensor_functions.cpp
|
113
114
|
- ext/torch/tensor_functions.hpp
|
@@ -123,12 +124,24 @@ files:
|
|
123
124
|
- lib/torch/native/native_functions.yaml
|
124
125
|
- lib/torch/native/parser.rb
|
125
126
|
- lib/torch/nn/alpha_dropout.rb
|
127
|
+
- lib/torch/nn/avg_pool1d.rb
|
126
128
|
- lib/torch/nn/avg_pool2d.rb
|
129
|
+
- lib/torch/nn/avg_pool3d.rb
|
127
130
|
- lib/torch/nn/avg_poolnd.rb
|
131
|
+
- lib/torch/nn/batch_norm.rb
|
132
|
+
- lib/torch/nn/batch_norm1d.rb
|
133
|
+
- lib/torch/nn/batch_norm2d.rb
|
134
|
+
- lib/torch/nn/batch_norm3d.rb
|
128
135
|
- lib/torch/nn/bce_loss.rb
|
129
136
|
- lib/torch/nn/bce_with_logits_loss.rb
|
130
137
|
- lib/torch/nn/bilinear.rb
|
138
|
+
- lib/torch/nn/constant_pad1d.rb
|
139
|
+
- lib/torch/nn/constant_pad2d.rb
|
140
|
+
- lib/torch/nn/constant_pad3d.rb
|
141
|
+
- lib/torch/nn/constant_padnd.rb
|
142
|
+
- lib/torch/nn/conv1d.rb
|
131
143
|
- lib/torch/nn/conv2d.rb
|
144
|
+
- lib/torch/nn/conv3d.rb
|
132
145
|
- lib/torch/nn/convnd.rb
|
133
146
|
- lib/torch/nn/cosine_embedding_loss.rb
|
134
147
|
- lib/torch/nn/cosine_similarity.rb
|
@@ -141,19 +154,40 @@ files:
|
|
141
154
|
- lib/torch/nn/embedding.rb
|
142
155
|
- lib/torch/nn/embedding_bag.rb
|
143
156
|
- lib/torch/nn/feature_alpha_dropout.rb
|
157
|
+
- lib/torch/nn/fold.rb
|
144
158
|
- lib/torch/nn/functional.rb
|
159
|
+
- lib/torch/nn/group_norm.rb
|
160
|
+
- lib/torch/nn/gru.rb
|
161
|
+
- lib/torch/nn/hardshrink.rb
|
145
162
|
- lib/torch/nn/hinge_embedding_loss.rb
|
146
163
|
- lib/torch/nn/identity.rb
|
147
164
|
- lib/torch/nn/init.rb
|
165
|
+
- lib/torch/nn/instance_norm.rb
|
166
|
+
- lib/torch/nn/instance_norm1d.rb
|
167
|
+
- lib/torch/nn/instance_norm2d.rb
|
168
|
+
- lib/torch/nn/instance_norm3d.rb
|
148
169
|
- lib/torch/nn/kl_div_loss.rb
|
149
170
|
- lib/torch/nn/l1_loss.rb
|
171
|
+
- lib/torch/nn/layer_norm.rb
|
150
172
|
- lib/torch/nn/leaky_relu.rb
|
151
173
|
- lib/torch/nn/linear.rb
|
174
|
+
- lib/torch/nn/local_response_norm.rb
|
175
|
+
- lib/torch/nn/log_sigmoid.rb
|
152
176
|
- lib/torch/nn/log_softmax.rb
|
153
177
|
- lib/torch/nn/loss.rb
|
178
|
+
- lib/torch/nn/lp_pool1d.rb
|
179
|
+
- lib/torch/nn/lp_pool2d.rb
|
180
|
+
- lib/torch/nn/lp_poolnd.rb
|
181
|
+
- lib/torch/nn/lstm.rb
|
154
182
|
- lib/torch/nn/margin_ranking_loss.rb
|
183
|
+
- lib/torch/nn/max_pool1d.rb
|
155
184
|
- lib/torch/nn/max_pool2d.rb
|
185
|
+
- lib/torch/nn/max_pool3d.rb
|
156
186
|
- lib/torch/nn/max_poolnd.rb
|
187
|
+
- lib/torch/nn/max_unpool1d.rb
|
188
|
+
- lib/torch/nn/max_unpool2d.rb
|
189
|
+
- lib/torch/nn/max_unpool3d.rb
|
190
|
+
- lib/torch/nn/max_unpoolnd.rb
|
157
191
|
- lib/torch/nn/module.rb
|
158
192
|
- lib/torch/nn/mse_loss.rb
|
159
193
|
- lib/torch/nn/multi_label_margin_loss.rb
|
@@ -164,7 +198,14 @@ files:
|
|
164
198
|
- lib/torch/nn/parameter.rb
|
165
199
|
- lib/torch/nn/poisson_nll_loss.rb
|
166
200
|
- lib/torch/nn/prelu.rb
|
201
|
+
- lib/torch/nn/reflection_pad1d.rb
|
202
|
+
- lib/torch/nn/reflection_pad2d.rb
|
203
|
+
- lib/torch/nn/reflection_padnd.rb
|
167
204
|
- lib/torch/nn/relu.rb
|
205
|
+
- lib/torch/nn/replication_pad1d.rb
|
206
|
+
- lib/torch/nn/replication_pad2d.rb
|
207
|
+
- lib/torch/nn/replication_pad3d.rb
|
208
|
+
- lib/torch/nn/replication_padnd.rb
|
168
209
|
- lib/torch/nn/rnn.rb
|
169
210
|
- lib/torch/nn/rnn_base.rb
|
170
211
|
- lib/torch/nn/sequential.rb
|
@@ -175,8 +216,15 @@ files:
|
|
175
216
|
- lib/torch/nn/softmax2d.rb
|
176
217
|
- lib/torch/nn/softmin.rb
|
177
218
|
- lib/torch/nn/softplus.rb
|
219
|
+
- lib/torch/nn/softshrink.rb
|
220
|
+
- lib/torch/nn/softsign.rb
|
221
|
+
- lib/torch/nn/tanh.rb
|
222
|
+
- lib/torch/nn/tanhshrink.rb
|
178
223
|
- lib/torch/nn/triplet_margin_loss.rb
|
224
|
+
- lib/torch/nn/unfold.rb
|
225
|
+
- lib/torch/nn/utils.rb
|
179
226
|
- lib/torch/nn/weighted_loss.rb
|
227
|
+
- lib/torch/nn/zero_pad2d.rb
|
180
228
|
- lib/torch/optim/adadelta.rb
|
181
229
|
- lib/torch/optim/adagrad.rb
|
182
230
|
- lib/torch/optim/adam.rb
|