torch-rb 0.1.3 → 0.1.8
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +30 -0
- data/README.md +5 -2
- data/ext/torch/ext.cpp +130 -555
- data/ext/torch/extconf.rb +9 -0
- data/ext/torch/templates.cpp +55 -0
- data/ext/torch/templates.hpp +244 -0
- data/lib/torch.rb +209 -171
- data/lib/torch/inspector.rb +23 -19
- data/lib/torch/native/dispatcher.rb +48 -0
- data/lib/torch/native/function.rb +110 -0
- data/lib/torch/native/generator.rb +168 -0
- data/lib/torch/native/native_functions.yaml +6491 -0
- data/lib/torch/native/parser.rb +134 -0
- data/lib/torch/nn/avg_pool1d.rb +18 -0
- data/lib/torch/nn/avg_pool2d.rb +19 -0
- data/lib/torch/nn/avg_pool3d.rb +19 -0
- data/lib/torch/nn/avg_poolnd.rb +9 -0
- data/lib/torch/nn/batch_norm.rb +75 -0
- data/lib/torch/nn/batch_norm1d.rb +11 -0
- data/lib/torch/nn/batch_norm2d.rb +11 -0
- data/lib/torch/nn/batch_norm3d.rb +11 -0
- data/lib/torch/nn/bce_loss.rb +13 -0
- data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
- data/lib/torch/nn/bilinear.rb +38 -0
- data/lib/torch/nn/constant_pad1d.rb +10 -0
- data/lib/torch/nn/constant_pad2d.rb +10 -0
- data/lib/torch/nn/constant_pad3d.rb +10 -0
- data/lib/torch/nn/constant_padnd.rb +18 -0
- data/lib/torch/nn/conv1d.rb +22 -0
- data/lib/torch/nn/conv2d.rb +10 -20
- data/lib/torch/nn/conv3d.rb +22 -0
- data/lib/torch/nn/convnd.rb +3 -3
- data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
- data/lib/torch/nn/cosine_similarity.rb +15 -0
- data/lib/torch/nn/cross_entropy_loss.rb +14 -0
- data/lib/torch/nn/ctc_loss.rb +15 -0
- data/lib/torch/nn/dropoutnd.rb +2 -2
- data/lib/torch/nn/embedding_bag.rb +34 -0
- data/lib/torch/nn/fold.rb +20 -0
- data/lib/torch/nn/functional.rb +379 -32
- data/lib/torch/nn/group_norm.rb +36 -0
- data/lib/torch/nn/gru.rb +49 -0
- data/lib/torch/nn/hardshrink.rb +18 -0
- data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
- data/lib/torch/nn/identity.rb +14 -0
- data/lib/torch/nn/init.rb +58 -1
- data/lib/torch/nn/instance_norm.rb +20 -0
- data/lib/torch/nn/instance_norm1d.rb +18 -0
- data/lib/torch/nn/instance_norm2d.rb +11 -0
- data/lib/torch/nn/instance_norm3d.rb +11 -0
- data/lib/torch/nn/kl_div_loss.rb +13 -0
- data/lib/torch/nn/l1_loss.rb +13 -0
- data/lib/torch/nn/layer_norm.rb +35 -0
- data/lib/torch/nn/leaky_relu.rb +20 -0
- data/lib/torch/nn/linear.rb +12 -11
- data/lib/torch/nn/local_response_norm.rb +21 -0
- data/lib/torch/nn/log_sigmoid.rb +9 -0
- data/lib/torch/nn/log_softmax.rb +14 -0
- data/lib/torch/nn/loss.rb +10 -0
- data/lib/torch/nn/lp_pool1d.rb +9 -0
- data/lib/torch/nn/lp_pool2d.rb +9 -0
- data/lib/torch/nn/lp_poolnd.rb +22 -0
- data/lib/torch/nn/lstm.rb +66 -0
- data/lib/torch/nn/margin_ranking_loss.rb +14 -0
- data/lib/torch/nn/max_pool1d.rb +9 -0
- data/lib/torch/nn/max_pool2d.rb +9 -0
- data/lib/torch/nn/max_pool3d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +19 -0
- data/lib/torch/nn/max_unpool1d.rb +16 -0
- data/lib/torch/nn/max_unpool2d.rb +16 -0
- data/lib/torch/nn/max_unpool3d.rb +16 -0
- data/lib/torch/nn/max_unpoolnd.rb +9 -0
- data/lib/torch/nn/module.rb +186 -35
- data/lib/torch/nn/mse_loss.rb +2 -2
- data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_margin_loss.rb +17 -0
- data/lib/torch/nn/nll_loss.rb +14 -0
- data/lib/torch/nn/pairwise_distance.rb +16 -0
- data/lib/torch/nn/parameter.rb +2 -2
- data/lib/torch/nn/poisson_nll_loss.rb +16 -0
- data/lib/torch/nn/prelu.rb +19 -0
- data/lib/torch/nn/reflection_pad1d.rb +10 -0
- data/lib/torch/nn/reflection_pad2d.rb +10 -0
- data/lib/torch/nn/reflection_padnd.rb +13 -0
- data/lib/torch/nn/relu.rb +8 -3
- data/lib/torch/nn/replication_pad1d.rb +10 -0
- data/lib/torch/nn/replication_pad2d.rb +10 -0
- data/lib/torch/nn/replication_pad3d.rb +10 -0
- data/lib/torch/nn/replication_padnd.rb +13 -0
- data/lib/torch/nn/rnn.rb +22 -0
- data/lib/torch/nn/rnn_base.rb +198 -0
- data/lib/torch/nn/sequential.rb +1 -10
- data/lib/torch/nn/sigmoid.rb +9 -0
- data/lib/torch/nn/smooth_l1_loss.rb +13 -0
- data/lib/torch/nn/soft_margin_loss.rb +13 -0
- data/lib/torch/nn/softmax.rb +18 -0
- data/lib/torch/nn/softmax2d.rb +10 -0
- data/lib/torch/nn/softmin.rb +14 -0
- data/lib/torch/nn/softplus.rb +19 -0
- data/lib/torch/nn/softshrink.rb +18 -0
- data/lib/torch/nn/softsign.rb +9 -0
- data/lib/torch/nn/tanh.rb +9 -0
- data/lib/torch/nn/tanhshrink.rb +9 -0
- data/lib/torch/nn/triplet_margin_loss.rb +18 -0
- data/lib/torch/nn/unfold.rb +19 -0
- data/lib/torch/nn/utils.rb +25 -0
- data/lib/torch/nn/weighted_loss.rb +10 -0
- data/lib/torch/nn/zero_pad2d.rb +9 -0
- data/lib/torch/random.rb +10 -0
- data/lib/torch/tensor.rb +51 -44
- data/lib/torch/version.rb +1 -1
- metadata +98 -6
- data/lib/torch/ext.bundle +0 -0
data/lib/torch/nn/sequential.rb
CHANGED
@@ -2,28 +2,19 @@ module Torch
|
|
2
2
|
module NN
|
3
3
|
class Sequential < Module
|
4
4
|
def initialize(*args)
|
5
|
-
|
5
|
+
super()
|
6
6
|
# TODO support hash arg (named modules)
|
7
7
|
args.each_with_index do |mod, idx|
|
8
8
|
add_module(idx.to_s, mod)
|
9
9
|
end
|
10
10
|
end
|
11
11
|
|
12
|
-
def add_module(name, mod)
|
13
|
-
# TODO add checks
|
14
|
-
@modules[name] = mod
|
15
|
-
end
|
16
|
-
|
17
12
|
def forward(input)
|
18
13
|
@modules.values.each do |mod|
|
19
14
|
input = mod.call(input)
|
20
15
|
end
|
21
16
|
input
|
22
17
|
end
|
23
|
-
|
24
|
-
def parameters
|
25
|
-
@modules.flat_map { |_, mod| mod.parameters }
|
26
|
-
end
|
27
18
|
end
|
28
19
|
end
|
29
20
|
end
|
@@ -0,0 +1,18 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
class Softmax < Module
|
4
|
+
def initialize(dim: nil)
|
5
|
+
super()
|
6
|
+
@dim = dim
|
7
|
+
end
|
8
|
+
|
9
|
+
def forward(input)
|
10
|
+
F.softmax(input, dim: @dim)
|
11
|
+
end
|
12
|
+
|
13
|
+
def extra_inspect
|
14
|
+
format("dim: %s", @dim)
|
15
|
+
end
|
16
|
+
end
|
17
|
+
end
|
18
|
+
end
|
@@ -0,0 +1,19 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
class Softplus < Module
|
4
|
+
def initialize(beta: 1, threshold: 20)
|
5
|
+
super()
|
6
|
+
@beta = beta
|
7
|
+
@threshold = threshold
|
8
|
+
end
|
9
|
+
|
10
|
+
def forward(input)
|
11
|
+
F.softplus(input, beta: @beta, threshold: @threshold)
|
12
|
+
end
|
13
|
+
|
14
|
+
def extra_inspect
|
15
|
+
format("beta: %s, threshold: %s", @beta, @threshold)
|
16
|
+
end
|
17
|
+
end
|
18
|
+
end
|
19
|
+
end
|
@@ -0,0 +1,18 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
class TripletMarginLoss < Loss
|
4
|
+
def initialize(margin: 1.0, p: 2.0, eps: 1e-6, swap: false, reduction: "mean")
|
5
|
+
super(reduction)
|
6
|
+
@margin = margin
|
7
|
+
@p = p
|
8
|
+
@eps = eps
|
9
|
+
@swap = swap
|
10
|
+
end
|
11
|
+
|
12
|
+
def forward(anchor, positive, negative)
|
13
|
+
F.triplet_margin_loss(anchor, positive, negative, margin: @margin, p: @p,
|
14
|
+
eps: @eps, swap: @swap, reduction: @reduction)
|
15
|
+
end
|
16
|
+
end
|
17
|
+
end
|
18
|
+
end
|
@@ -0,0 +1,19 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
class Unfold < Module
|
4
|
+
def initialize(kernel_size, dilation: 1, padding: 0, stride: 1)
|
5
|
+
super()
|
6
|
+
@kernel_size = kernel_size
|
7
|
+
@dilation = dilation
|
8
|
+
@padding = padding
|
9
|
+
@stride = stride
|
10
|
+
end
|
11
|
+
|
12
|
+
def forward(input)
|
13
|
+
F.unfold(input, @kernel_size, dilation: @dilation, padding: @padding, stride: @stride)
|
14
|
+
end
|
15
|
+
|
16
|
+
# TODO add extra_inspect
|
17
|
+
end
|
18
|
+
end
|
19
|
+
end
|
@@ -0,0 +1,25 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
module Utils
|
4
|
+
def _single(value)
|
5
|
+
_ntuple(1, value)
|
6
|
+
end
|
7
|
+
|
8
|
+
def _pair(value)
|
9
|
+
_ntuple(2, value)
|
10
|
+
end
|
11
|
+
|
12
|
+
def _triple(value)
|
13
|
+
_ntuple(3, value)
|
14
|
+
end
|
15
|
+
|
16
|
+
def _quadrupal(value)
|
17
|
+
_ntuple(4, value)
|
18
|
+
end
|
19
|
+
|
20
|
+
def _ntuple(n, value)
|
21
|
+
value.is_a?(Array) ? value : [value] * n
|
22
|
+
end
|
23
|
+
end
|
24
|
+
end
|
25
|
+
end
|
data/lib/torch/random.rb
ADDED
data/lib/torch/tensor.rb
CHANGED
@@ -5,12 +5,8 @@ module Torch
|
|
5
5
|
|
6
6
|
alias_method :requires_grad?, :requires_grad
|
7
7
|
|
8
|
-
def self.new(*
|
9
|
-
|
10
|
-
size.first
|
11
|
-
else
|
12
|
-
Torch.empty(*size)
|
13
|
-
end
|
8
|
+
def self.new(*args)
|
9
|
+
FloatTensor.new(*args)
|
14
10
|
end
|
15
11
|
|
16
12
|
def dtype
|
@@ -28,7 +24,7 @@ module Torch
|
|
28
24
|
end
|
29
25
|
|
30
26
|
def to_a
|
31
|
-
reshape_arr(
|
27
|
+
reshape_arr(_flat_data, shape)
|
32
28
|
end
|
33
29
|
|
34
30
|
# TODO support dtype
|
@@ -39,7 +35,7 @@ module Torch
|
|
39
35
|
|
40
36
|
def size(dim = nil)
|
41
37
|
if dim
|
42
|
-
|
38
|
+
_size_int(dim)
|
43
39
|
else
|
44
40
|
shape
|
45
41
|
end
|
@@ -49,15 +45,16 @@ module Torch
|
|
49
45
|
dim.times.map { |i| size(i) }
|
50
46
|
end
|
51
47
|
|
52
|
-
|
53
|
-
|
48
|
+
# mirror Python len()
|
49
|
+
def length
|
50
|
+
size(0)
|
54
51
|
end
|
55
52
|
|
56
53
|
def item
|
57
54
|
if numel != 1
|
58
55
|
raise Error, "only one element tensors can be converted to Ruby scalars"
|
59
56
|
end
|
60
|
-
|
57
|
+
_flat_data.first
|
61
58
|
end
|
62
59
|
|
63
60
|
# unsure if this is correct
|
@@ -66,19 +63,14 @@ module Torch
|
|
66
63
|
end
|
67
64
|
|
68
65
|
def backward(gradient = nil)
|
69
|
-
|
70
|
-
_backward_gradient(gradient)
|
71
|
-
else
|
72
|
-
_backward
|
73
|
-
end
|
66
|
+
_backward(gradient)
|
74
67
|
end
|
75
68
|
|
76
69
|
# TODO read directly from memory
|
77
70
|
def numo
|
78
|
-
raise Error, "Numo not found" unless defined?(Numo::NArray)
|
79
71
|
cls = Torch._dtype_to_numo[dtype]
|
80
72
|
raise Error, "Cannot convert #{dtype} to Numo" unless cls
|
81
|
-
cls.cast(
|
73
|
+
cls.cast(_flat_data).reshape(*shape)
|
82
74
|
end
|
83
75
|
|
84
76
|
def new_ones(*size, **options)
|
@@ -95,31 +87,23 @@ module Torch
|
|
95
87
|
_type(enum)
|
96
88
|
end
|
97
89
|
|
98
|
-
def
|
99
|
-
if
|
100
|
-
|
101
|
-
|
102
|
-
# need to use alpha for sparse tensors instead of multiplying
|
103
|
-
_add_alpha!(other, value)
|
104
|
-
end
|
90
|
+
def reshape(*size)
|
91
|
+
# Python doesn't check if size == 1, just ignores later arguments
|
92
|
+
size = size.first if size.size == 1 && size.first.is_a?(Array)
|
93
|
+
_reshape(size)
|
105
94
|
end
|
106
95
|
|
107
|
-
def
|
108
|
-
if
|
109
|
-
|
110
|
-
else
|
111
|
-
_mul!(other)
|
112
|
-
end
|
96
|
+
def view(*size)
|
97
|
+
size = size.first if size.size == 1 && size.first.is_a?(Array)
|
98
|
+
_view(size)
|
113
99
|
end
|
114
100
|
|
115
|
-
#
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
Torch.send(op, self, *args, &block)
|
122
|
-
end
|
101
|
+
# value and other are swapped for some methods
|
102
|
+
def add!(value = 1, other)
|
103
|
+
if other.is_a?(Numeric)
|
104
|
+
_add__scalar(other, value)
|
105
|
+
else
|
106
|
+
_add__tensor(other, value)
|
123
107
|
end
|
124
108
|
end
|
125
109
|
|
@@ -161,14 +145,20 @@ module Torch
|
|
161
145
|
dim = 0
|
162
146
|
indexes.each do |index|
|
163
147
|
if index.is_a?(Numeric)
|
164
|
-
result = result.
|
148
|
+
result = result._select_int(dim, index)
|
165
149
|
elsif index.is_a?(Range)
|
166
150
|
finish = index.end
|
167
151
|
finish += 1 unless index.exclude_end?
|
168
|
-
result = result.
|
152
|
+
result = result._slice_tensor(dim, index.begin, finish, 1)
|
153
|
+
dim += 1
|
154
|
+
elsif index.nil?
|
155
|
+
result = result.unsqueeze(dim)
|
169
156
|
dim += 1
|
157
|
+
elsif index == true
|
158
|
+
result = result.unsqueeze(dim)
|
159
|
+
# TODO handle false
|
170
160
|
else
|
171
|
-
raise Error, "Unsupported index type"
|
161
|
+
raise Error, "Unsupported index type: #{index.class.name}"
|
172
162
|
end
|
173
163
|
end
|
174
164
|
result
|
@@ -176,11 +166,28 @@ module Torch
|
|
176
166
|
|
177
167
|
# TODO
|
178
168
|
# based on python_variable_indexing.cpp
|
179
|
-
|
180
|
-
|
169
|
+
def []=(index, value)
|
170
|
+
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
|
171
|
+
|
172
|
+
value = Torch.tensor(value) unless value.is_a?(Tensor)
|
173
|
+
|
174
|
+
if index.is_a?(Numeric)
|
175
|
+
copy_to(_select_int(0, index), value)
|
176
|
+
elsif index.is_a?(Range)
|
177
|
+
finish = index.end
|
178
|
+
finish += 1 unless index.exclude_end?
|
179
|
+
copy_to(_slice_tensor(0, index.begin, finish, 1), value)
|
180
|
+
else
|
181
|
+
raise Error, "Unsupported index type: #{index.class.name}"
|
182
|
+
end
|
183
|
+
end
|
181
184
|
|
182
185
|
private
|
183
186
|
|
187
|
+
def copy_to(dst, src)
|
188
|
+
dst.copy!(src)
|
189
|
+
end
|
190
|
+
|
184
191
|
def reshape_arr(arr, dims)
|
185
192
|
if dims.empty?
|
186
193
|
arr
|