torch-rb 0.1.3 → 0.1.8
- checksums.yaml +4 -4
- data/CHANGELOG.md +30 -0
- data/README.md +5 -2
- data/ext/torch/ext.cpp +130 -555
- data/ext/torch/extconf.rb +9 -0
- data/ext/torch/templates.cpp +55 -0
- data/ext/torch/templates.hpp +244 -0
- data/lib/torch.rb +209 -171
- data/lib/torch/inspector.rb +23 -19
- data/lib/torch/native/dispatcher.rb +48 -0
- data/lib/torch/native/function.rb +110 -0
- data/lib/torch/native/generator.rb +168 -0
- data/lib/torch/native/native_functions.yaml +6491 -0
- data/lib/torch/native/parser.rb +134 -0
- data/lib/torch/nn/avg_pool1d.rb +18 -0
- data/lib/torch/nn/avg_pool2d.rb +19 -0
- data/lib/torch/nn/avg_pool3d.rb +19 -0
- data/lib/torch/nn/avg_poolnd.rb +9 -0
- data/lib/torch/nn/batch_norm.rb +75 -0
- data/lib/torch/nn/batch_norm1d.rb +11 -0
- data/lib/torch/nn/batch_norm2d.rb +11 -0
- data/lib/torch/nn/batch_norm3d.rb +11 -0
- data/lib/torch/nn/bce_loss.rb +13 -0
- data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
- data/lib/torch/nn/bilinear.rb +38 -0
- data/lib/torch/nn/constant_pad1d.rb +10 -0
- data/lib/torch/nn/constant_pad2d.rb +10 -0
- data/lib/torch/nn/constant_pad3d.rb +10 -0
- data/lib/torch/nn/constant_padnd.rb +18 -0
- data/lib/torch/nn/conv1d.rb +22 -0
- data/lib/torch/nn/conv2d.rb +10 -20
- data/lib/torch/nn/conv3d.rb +22 -0
- data/lib/torch/nn/convnd.rb +3 -3
- data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
- data/lib/torch/nn/cosine_similarity.rb +15 -0
- data/lib/torch/nn/cross_entropy_loss.rb +14 -0
- data/lib/torch/nn/ctc_loss.rb +15 -0
- data/lib/torch/nn/dropoutnd.rb +2 -2
- data/lib/torch/nn/embedding_bag.rb +34 -0
- data/lib/torch/nn/fold.rb +20 -0
- data/lib/torch/nn/functional.rb +379 -32
- data/lib/torch/nn/group_norm.rb +36 -0
- data/lib/torch/nn/gru.rb +49 -0
- data/lib/torch/nn/hardshrink.rb +18 -0
- data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
- data/lib/torch/nn/identity.rb +14 -0
- data/lib/torch/nn/init.rb +58 -1
- data/lib/torch/nn/instance_norm.rb +20 -0
- data/lib/torch/nn/instance_norm1d.rb +18 -0
- data/lib/torch/nn/instance_norm2d.rb +11 -0
- data/lib/torch/nn/instance_norm3d.rb +11 -0
- data/lib/torch/nn/kl_div_loss.rb +13 -0
- data/lib/torch/nn/l1_loss.rb +13 -0
- data/lib/torch/nn/layer_norm.rb +35 -0
- data/lib/torch/nn/leaky_relu.rb +20 -0
- data/lib/torch/nn/linear.rb +12 -11
- data/lib/torch/nn/local_response_norm.rb +21 -0
- data/lib/torch/nn/log_sigmoid.rb +9 -0
- data/lib/torch/nn/log_softmax.rb +14 -0
- data/lib/torch/nn/loss.rb +10 -0
- data/lib/torch/nn/lp_pool1d.rb +9 -0
- data/lib/torch/nn/lp_pool2d.rb +9 -0
- data/lib/torch/nn/lp_poolnd.rb +22 -0
- data/lib/torch/nn/lstm.rb +66 -0
- data/lib/torch/nn/margin_ranking_loss.rb +14 -0
- data/lib/torch/nn/max_pool1d.rb +9 -0
- data/lib/torch/nn/max_pool2d.rb +9 -0
- data/lib/torch/nn/max_pool3d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +19 -0
- data/lib/torch/nn/max_unpool1d.rb +16 -0
- data/lib/torch/nn/max_unpool2d.rb +16 -0
- data/lib/torch/nn/max_unpool3d.rb +16 -0
- data/lib/torch/nn/max_unpoolnd.rb +9 -0
- data/lib/torch/nn/module.rb +186 -35
- data/lib/torch/nn/mse_loss.rb +2 -2
- data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_margin_loss.rb +17 -0
- data/lib/torch/nn/nll_loss.rb +14 -0
- data/lib/torch/nn/pairwise_distance.rb +16 -0
- data/lib/torch/nn/parameter.rb +2 -2
- data/lib/torch/nn/poisson_nll_loss.rb +16 -0
- data/lib/torch/nn/prelu.rb +19 -0
- data/lib/torch/nn/reflection_pad1d.rb +10 -0
- data/lib/torch/nn/reflection_pad2d.rb +10 -0
- data/lib/torch/nn/reflection_padnd.rb +13 -0
- data/lib/torch/nn/relu.rb +8 -3
- data/lib/torch/nn/replication_pad1d.rb +10 -0
- data/lib/torch/nn/replication_pad2d.rb +10 -0
- data/lib/torch/nn/replication_pad3d.rb +10 -0
- data/lib/torch/nn/replication_padnd.rb +13 -0
- data/lib/torch/nn/rnn.rb +22 -0
- data/lib/torch/nn/rnn_base.rb +198 -0
- data/lib/torch/nn/sequential.rb +1 -10
- data/lib/torch/nn/sigmoid.rb +9 -0
- data/lib/torch/nn/smooth_l1_loss.rb +13 -0
- data/lib/torch/nn/soft_margin_loss.rb +13 -0
- data/lib/torch/nn/softmax.rb +18 -0
- data/lib/torch/nn/softmax2d.rb +10 -0
- data/lib/torch/nn/softmin.rb +14 -0
- data/lib/torch/nn/softplus.rb +19 -0
- data/lib/torch/nn/softshrink.rb +18 -0
- data/lib/torch/nn/softsign.rb +9 -0
- data/lib/torch/nn/tanh.rb +9 -0
- data/lib/torch/nn/tanhshrink.rb +9 -0
- data/lib/torch/nn/triplet_margin_loss.rb +18 -0
- data/lib/torch/nn/unfold.rb +19 -0
- data/lib/torch/nn/utils.rb +25 -0
- data/lib/torch/nn/weighted_loss.rb +10 -0
- data/lib/torch/nn/zero_pad2d.rb +9 -0
- data/lib/torch/random.rb +10 -0
- data/lib/torch/tensor.rb +51 -44
- data/lib/torch/version.rb +1 -1
- metadata +98 -6
- data/lib/torch/ext.bundle +0 -0
data/lib/torch/nn/sequential.rb
CHANGED
@@ -2,28 +2,19 @@ module Torch
   module NN
     class Sequential < Module
       def initialize(*args)
-
+        super()
         # TODO support hash arg (named modules)
         args.each_with_index do |mod, idx|
           add_module(idx.to_s, mod)
         end
       end

-      def add_module(name, mod)
-        # TODO add checks
-        @modules[name] = mod
-      end
-
       def forward(input)
         @modules.values.each do |mod|
           input = mod.call(input)
         end
         input
       end
-
-      def parameters
-        @modules.flat_map { |_, mod| mod.parameters }
-      end
     end
   end
 end
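add_module and parameters now come from Module, so Sequential is reduced to wiring its arguments together and folding the input through them in order. A minimal usage sketch (layer sizes invented for illustration):

    net = Torch::NN::Sequential.new(
      Torch::NN::Linear.new(10, 5),
      Torch::NN::ReLU.new,
      Torch::NN::Linear.new(5, 1)
    )
    output = net.call(Torch.randn(2, 10))  # modules run in insertion order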
data/lib/torch/nn/softmax.rb
ADDED
@@ -0,0 +1,18 @@
+module Torch
+  module NN
+    class Softmax < Module
+      def initialize(dim: nil)
+        super()
+        @dim = dim
+      end
+
+      def forward(input)
+        F.softmax(input, dim: @dim)
+      end
+
+      def extra_inspect
+        format("dim: %s", @dim)
+      end
+    end
+  end
+end
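The new module just defers to F.softmax along the configured dimension. A quick sketch (input values arbitrary):

    m = Torch::NN::Softmax.new(dim: 1)
    out = m.call(Torch.randn(2, 3))  # each row of out should sum to 1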
data/lib/torch/nn/softplus.rb
ADDED
@@ -0,0 +1,19 @@
+module Torch
+  module NN
+    class Softplus < Module
+      def initialize(beta: 1, threshold: 20)
+        super()
+        @beta = beta
+        @threshold = threshold
+      end
+
+      def forward(input)
+        F.softplus(input, beta: @beta, threshold: @threshold)
+      end
+
+      def extra_inspect
+        format("beta: %s, threshold: %s", @beta, @threshold)
+      end
+    end
+  end
+end
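Softplus computes (1 / beta) * log(1 + exp(beta * x)), a smooth approximation of ReLU; above the threshold the op reverts to the identity for numerical stability. Sketch (values arbitrary):

    m = Torch::NN::Softplus.new(beta: 1, threshold: 20)
    out = m.call(Torch.tensor([-1.0, 0.0, 1.0]))  # approx [0.31, 0.69, 1.31]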
data/lib/torch/nn/triplet_margin_loss.rb
ADDED
@@ -0,0 +1,18 @@
+module Torch
+  module NN
+    class TripletMarginLoss < Loss
+      def initialize(margin: 1.0, p: 2.0, eps: 1e-6, swap: false, reduction: "mean")
+        super(reduction)
+        @margin = margin
+        @p = p
+        @eps = eps
+        @swap = swap
+      end
+
+      def forward(anchor, positive, negative)
+        F.triplet_margin_loss(anchor, positive, negative, margin: @margin, p: @p,
+          eps: @eps, swap: @swap, reduction: @reduction)
+      end
+    end
+  end
+end
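TripletMarginLoss penalizes an anchor for being closer to the negative than to the positive by more than margin, under the p-norm distance. A hedged sketch (shapes invented, and assuming Module#call forwards all three positional arguments to forward):

    loss_fn  = Torch::NN::TripletMarginLoss.new(margin: 1.0)
    anchor   = Torch.randn(8, 128)
    positive = Torch.randn(8, 128)
    negative = Torch.randn(8, 128)
    loss = loss_fn.call(anchor, positive, negative)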
data/lib/torch/nn/unfold.rb
ADDED
@@ -0,0 +1,19 @@
+module Torch
+  module NN
+    class Unfold < Module
+      def initialize(kernel_size, dilation: 1, padding: 0, stride: 1)
+        super()
+        @kernel_size = kernel_size
+        @dilation = dilation
+        @padding = padding
+        @stride = stride
+      end
+
+      def forward(input)
+        F.unfold(input, @kernel_size, dilation: @dilation, padding: @padding, stride: @stride)
+      end
+
+      # TODO add extra_inspect
+    end
+  end
+end
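Unfold extracts sliding local blocks, so a (N, C, H, W) input with a kh x kw kernel yields (N, C*kh*kw, L), where L is the number of block positions. A hedged sketch (shapes invented):

    unfold = Torch::NN::Unfold.new([2, 2])
    blocks = unfold.call(Torch.randn(1, 3, 4, 4))  # expected shape: [1, 12, 9]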
data/lib/torch/nn/utils.rb
ADDED
@@ -0,0 +1,25 @@
+module Torch
+  module NN
+    module Utils
+      def _single(value)
+        _ntuple(1, value)
+      end
+
+      def _pair(value)
+        _ntuple(2, value)
+      end
+
+      def _triple(value)
+        _ntuple(3, value)
+      end
+
+      def _quadrupal(value)
+        _ntuple(4, value)
+      end
+
+      def _ntuple(n, value)
+        value.is_a?(Array) ? value : [value] * n
+      end
+    end
+  end
+end
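These helpers mirror PyTorch's internal _ntuple: a scalar is broadcast to an n-element array, while an array passes through unchanged (note the pass-through does not validate length, and _quadrupal keeps the upstream spelling). For example:

    include Torch::NN::Utils

    _pair(3)       # => [3, 3]
    _pair([1, 2])  # => [1, 2]
    _triple(5)     # => [5, 5, 5]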
data/lib/torch/random.rb
ADDED
data/lib/torch/tensor.rb
CHANGED
@@ -5,12 +5,8 @@ module Torch

     alias_method :requires_grad?, :requires_grad

-    def self.new(*size)
-
-        size.first
-      else
-        Torch.empty(*size)
-      end
+    def self.new(*args)
+      FloatTensor.new(*args)
     end

     def dtype
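Tensor.new now just delegates to FloatTensor.new, mirroring PyTorch, where torch.Tensor defaults to 32-bit floats. A hedged sketch (assuming the dtype-specific tensor classes accept array data):

    t = Torch::Tensor.new([1, 2, 3])
    t.dtype  # expected: :float32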
@@ -28,7 +24,7 @@ module Torch
     end

     def to_a
-      reshape_arr(
+      reshape_arr(_flat_data, shape)
     end

     # TODO support dtype
@@ -39,7 +35,7 @@ module Torch

     def size(dim = nil)
       if dim
-
+        _size_int(dim)
       else
         shape
       end
@@ -49,15 +45,16 @@ module Torch
       dim.times.map { |i| size(i) }
     end

-
-
+    # mirror Python len()
+    def length
+      size(0)
     end

     def item
       if numel != 1
         raise Error, "only one element tensors can be converted to Ruby scalars"
       end
-
+      _flat_data.first
     end

     # unsure if this is correct
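length mirrors Python's len(), i.e. the size of the first dimension, and item extracts the sole element of a single-element tensor. For example:

    t = Torch.tensor([[1, 2, 3], [4, 5, 6]])
    t.length                 # => 2 (size of dim 0)
    Torch.tensor([7]).item   # => 7; raises Error if numel != 1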
@@ -66,19 +63,14 @@ module Torch
     end

     def backward(gradient = nil)
-      if gradient
-        _backward_gradient(gradient)
-      else
-        _backward
-      end
+      _backward(gradient)
     end

     # TODO read directly from memory
     def numo
-      raise Error, "Numo not found" unless defined?(Numo::NArray)
       cls = Torch._dtype_to_numo[dtype]
       raise Error, "Cannot convert #{dtype} to Numo" unless cls
-      cls.cast(
+      cls.cast(_flat_data).reshape(*shape)
     end

     def new_ones(*size, **options)
|
|
95
87
|
_type(enum)
|
96
88
|
end
|
97
89
|
|
98
|
-
def
|
99
|
-
if
|
100
|
-
|
101
|
-
|
102
|
-
# need to use alpha for sparse tensors instead of multiplying
|
103
|
-
_add_alpha!(other, value)
|
104
|
-
end
|
90
|
+
def reshape(*size)
|
91
|
+
# Python doesn't check if size == 1, just ignores later arguments
|
92
|
+
size = size.first if size.size == 1 && size.first.is_a?(Array)
|
93
|
+
_reshape(size)
|
105
94
|
end
|
106
95
|
|
107
|
-
def
|
108
|
-
if
|
109
|
-
|
110
|
-
else
|
111
|
-
_mul!(other)
|
112
|
-
end
|
96
|
+
def view(*size)
|
97
|
+
size = size.first if size.size == 1 && size.first.is_a?(Array)
|
98
|
+
_view(size)
|
113
99
|
end
|
114
100
|
|
115
|
-
#
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
Torch.send(op, self, *args, &block)
|
122
|
-
end
|
101
|
+
# value and other are swapped for some methods
|
102
|
+
def add!(value = 1, other)
|
103
|
+
if other.is_a?(Numeric)
|
104
|
+
_add__scalar(other, value)
|
105
|
+
else
|
106
|
+
_add__tensor(other, value)
|
123
107
|
end
|
124
108
|
end
|
125
109
|
|
@@ -161,14 +145,20 @@ module Torch
|
|
161
145
|
dim = 0
|
162
146
|
indexes.each do |index|
|
163
147
|
if index.is_a?(Numeric)
|
164
|
-
result = result.
|
148
|
+
result = result._select_int(dim, index)
|
165
149
|
elsif index.is_a?(Range)
|
166
150
|
finish = index.end
|
167
151
|
finish += 1 unless index.exclude_end?
|
168
|
-
result = result.
|
152
|
+
result = result._slice_tensor(dim, index.begin, finish, 1)
|
153
|
+
dim += 1
|
154
|
+
elsif index.nil?
|
155
|
+
result = result.unsqueeze(dim)
|
169
156
|
dim += 1
|
157
|
+
elsif index == true
|
158
|
+
result = result.unsqueeze(dim)
|
159
|
+
# TODO handle false
|
170
160
|
else
|
171
|
-
raise Error, "Unsupported index type"
|
161
|
+
raise Error, "Unsupported index type: #{index.class.name}"
|
172
162
|
end
|
173
163
|
end
|
174
164
|
result
|
@@ -176,11 +166,28 @@ module Torch
|
|
176
166
|
|
177
167
|
# TODO
|
178
168
|
# based on python_variable_indexing.cpp
|
179
|
-
|
180
|
-
|
169
|
+
def []=(index, value)
|
170
|
+
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
|
171
|
+
|
172
|
+
value = Torch.tensor(value) unless value.is_a?(Tensor)
|
173
|
+
|
174
|
+
if index.is_a?(Numeric)
|
175
|
+
copy_to(_select_int(0, index), value)
|
176
|
+
elsif index.is_a?(Range)
|
177
|
+
finish = index.end
|
178
|
+
finish += 1 unless index.exclude_end?
|
179
|
+
copy_to(_slice_tensor(0, index.begin, finish, 1), value)
|
180
|
+
else
|
181
|
+
raise Error, "Unsupported index type: #{index.class.name}"
|
182
|
+
end
|
183
|
+
end
|
181
184
|
|
182
185
|
private
|
183
186
|
|
187
|
+
def copy_to(dst, src)
|
188
|
+
dst.copy!(src)
|
189
|
+
end
|
190
|
+
|
184
191
|
def reshape_arr(arr, dims)
|
185
192
|
if dims.empty?
|
186
193
|
arr
|