torch-rb 0.1.0 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +40 -0
- data/LICENSE.txt +46 -22
- data/README.md +85 -19
- data/ext/torch/ext.cpp +274 -256
- data/ext/torch/extconf.rb +9 -0
- data/ext/torch/nn_functions.cpp +595 -0
- data/ext/torch/nn_functions.hpp +6 -0
- data/ext/torch/templates.hpp +250 -0
- data/ext/torch/tensor_functions.cpp +1860 -0
- data/ext/torch/tensor_functions.hpp +6 -0
- data/ext/torch/torch_functions.cpp +2875 -0
- data/ext/torch/torch_functions.hpp +6 -0
- data/lib/torch.rb +199 -84
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +52 -25
- data/lib/torch/native/dispatcher.rb +48 -0
- data/lib/torch/native/function.rb +78 -0
- data/lib/torch/native/generator.rb +149 -0
- data/lib/torch/native/native_functions.yaml +6837 -0
- data/lib/torch/native/parser.rb +97 -0
- data/lib/torch/nn/alpha_dropout.rb +9 -0
- data/lib/torch/nn/avg_pool2d.rb +14 -0
- data/lib/torch/nn/avg_poolnd.rb +9 -0
- data/lib/torch/nn/bce_loss.rb +13 -0
- data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
- data/lib/torch/nn/bilinear.rb +38 -0
- data/lib/torch/nn/conv2d.rb +14 -29
- data/lib/torch/nn/convnd.rb +41 -0
- data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
- data/lib/torch/nn/cosine_similarity.rb +15 -0
- data/lib/torch/nn/cross_entropy_loss.rb +14 -0
- data/lib/torch/nn/ctc_loss.rb +15 -0
- data/lib/torch/nn/dropout.rb +9 -0
- data/lib/torch/nn/dropout2d.rb +9 -0
- data/lib/torch/nn/dropout3d.rb +9 -0
- data/lib/torch/nn/dropoutnd.rb +15 -0
- data/lib/torch/nn/embedding.rb +52 -0
- data/lib/torch/nn/embedding_bag.rb +34 -0
- data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
- data/lib/torch/nn/functional.rb +194 -11
- data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
- data/lib/torch/nn/identity.rb +14 -0
- data/lib/torch/nn/init.rb +58 -1
- data/lib/torch/nn/kl_div_loss.rb +13 -0
- data/lib/torch/nn/l1_loss.rb +13 -0
- data/lib/torch/nn/leaky_relu.rb +20 -0
- data/lib/torch/nn/linear.rb +12 -11
- data/lib/torch/nn/log_softmax.rb +14 -0
- data/lib/torch/nn/loss.rb +10 -0
- data/lib/torch/nn/margin_ranking_loss.rb +14 -0
- data/lib/torch/nn/max_pool2d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +19 -0
- data/lib/torch/nn/module.rb +184 -19
- data/lib/torch/nn/mse_loss.rb +2 -2
- data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_margin_loss.rb +17 -0
- data/lib/torch/nn/nll_loss.rb +14 -0
- data/lib/torch/nn/pairwise_distance.rb +16 -0
- data/lib/torch/nn/parameter.rb +4 -0
- data/lib/torch/nn/poisson_nll_loss.rb +16 -0
- data/lib/torch/nn/prelu.rb +19 -0
- data/lib/torch/nn/relu.rb +8 -3
- data/lib/torch/nn/rnn.rb +22 -0
- data/lib/torch/nn/rnn_base.rb +154 -0
- data/lib/torch/nn/sequential.rb +1 -10
- data/lib/torch/nn/sigmoid.rb +9 -0
- data/lib/torch/nn/smooth_l1_loss.rb +13 -0
- data/lib/torch/nn/soft_margin_loss.rb +13 -0
- data/lib/torch/nn/softmax.rb +18 -0
- data/lib/torch/nn/softmax2d.rb +10 -0
- data/lib/torch/nn/softmin.rb +14 -0
- data/lib/torch/nn/softplus.rb +19 -0
- data/lib/torch/nn/triplet_margin_loss.rb +18 -0
- data/lib/torch/nn/weighted_loss.rb +10 -0
- data/lib/torch/optim/adadelta.rb +57 -0
- data/lib/torch/optim/adagrad.rb +71 -0
- data/lib/torch/optim/adam.rb +81 -0
- data/lib/torch/optim/adamax.rb +68 -0
- data/lib/torch/optim/adamw.rb +82 -0
- data/lib/torch/optim/asgd.rb +65 -0
- data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
- data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
- data/lib/torch/optim/optimizer.rb +62 -0
- data/lib/torch/optim/rmsprop.rb +76 -0
- data/lib/torch/optim/rprop.rb +68 -0
- data/lib/torch/optim/sgd.rb +60 -0
- data/lib/torch/random.rb +10 -0
- data/lib/torch/tensor.rb +92 -21
- data/lib/torch/utils/data/data_loader.rb +15 -0
- data/lib/torch/utils/data/tensor_dataset.rb +8 -1
- data/lib/torch/version.rb +1 -1
- metadata +74 -3
@@ -0,0 +1,14 @@
|
|
1
|
+
module Torch
  module NN
    # Module wrapper around F.margin_ranking_loss.
    #
    # The margin and reduction are captured once at construction time and
    # forwarded on every call.
    class MarginRankingLoss < Loss
      # @param margin [Numeric] forwarded as the +margin:+ keyword (default 1.0)
      # @param reduction [String] handed to Loss#initialize (default "mean")
      def initialize(margin: 1.0, reduction: "mean")
        super(reduction)
        @margin = margin
      end

      # Delegates to F.margin_ranking_loss with the stored options.
      # NOTE(review): @reduction is presumably set by Loss#initialize — confirm.
      def forward(input1, input2, target)
        F.margin_ranking_loss(
          input1, input2, target,
          margin: @margin,
          reduction: @reduction
        )
      end
    end
  end
end
|
@@ -0,0 +1,19 @@
|
|
1
|
+
module Torch
  module NN
    # Shared base for the max-pooling layers; it only records the kernel size.
    #
    # TODO: stride (defaulting to kernel_size), padding, dilation,
    # return_indices and ceil_mode are not supported yet.
    class MaxPoolNd < Module
      # @param kernel_size the pooling window, stored as given
      def initialize(kernel_size)
        super()
        @kernel_size = kernel_size
      end

      # Detail shown between the parentheses of #inspect.
      def extra_inspect
        format("kernel_size: %s", @kernel_size)
      end
    end
  end
end
|
data/lib/torch/nn/module.rb
CHANGED
@@ -1,55 +1,220 @@
|
|
1
1
|
module Torch
  module NN
    # Base class for neural-network layers.
    #
    # Children and parameters are discovered two ways:
    # - by scanning instance variables (any ivar holding a Module or Parameter);
    # - through the registries filled by #add_module, #register_parameter and
    #   #register_buffer.
    class Module
      def initialize
        @training = true
        @parameters = {}
        @buffers = {}
        @modules = {}
      end

      # Computes the module's output. Subclasses must override this.
      #
      # Any arguments are accepted and ignored. As a result, #call on a module
      # that never implemented forward raises NotImplementedError, not an
      # ArgumentError about arity.
      #
      # @raise [NotImplementedError] always
      def forward(*_input)
        raise NotImplementedError
      end

      # Stores +tensor+ in the buffer registry under +name+ (no validation yet).
      def register_buffer(name, tensor)
        # TODO add checks
        @buffers[name] = tensor
      end

      # Stores +param+ in the parameter registry under +name+ (no validation yet).
      def register_parameter(name, param)
        # TODO add checks
        @parameters[name] = param
      end

      # Stores +mod+ in the child-module registry under +name+ (no validation yet).
      def add_module(name, mod)
        # TODO add checks
        @modules[name] = mod
      end

      # Recursively calls #_apply on every child module and returns self.
      #
      # NOTE: +fn+ is not yet invoked on this module's own parameters or
      # buffers (see TODO). #to, #cuda, #float, etc. therefore currently only
      # recurse through the module tree.
      def _apply(fn)
        children.each do |mod|
          mod._apply(fn)
        end
        # TODO apply to more objects
        self
      end

      # Calls +fn+ on every descendant module (children first), then on self.
      # @return [self]
      def apply(fn)
        children.each do |mod|
          mod.apply(fn)
        end
        fn.call(self)
        self
      end

      def cuda(device: nil)
        _apply ->(t) { t.cuda(device) }
      end

      def cpu
        _apply ->(t) { t.cpu }
      end

      def type(dst_type)
        _apply ->(t) { t.type(dst_type) }
      end

      # The float/double/half converters leave non-floating-point tensors untouched.
      def float
        _apply ->(t) { t.floating_point? ? t.float : t }
      end

      def double
        _apply ->(t) { t.floating_point? ? t.double : t }
      end

      def half
        _apply ->(t) { t.floating_point? ? t.half : t }
      end

      # modifies in-place
      def to(device)
        convert = lambda do |t|
          t.to(device)
        end

        _apply(convert)
      end

      # Runs #forward with the given inputs.
      def call(*input)
        forward(*input)
      end

      # NOTE(review): NotImplementedYet is assumed to be defined elsewhere in the gem.
      def state_dict
        raise NotImplementedYet
      end

      # @return [Array] the values of #named_parameters
      def parameters
        named_parameters.values
      end

      # Returns a Hash mapping each parameter's name to the parameter.
      #
      # Parameters of nested modules are keyed by their full dot-separated
      # path from this module, e.g. "encoder.layer.weight".
      #
      # @param prefix [String] prepended to every key
      # @param recurse [Boolean] whether to include descendant modules' parameters
      # @return [Hash{String => Parameter}]
      def named_parameters(prefix: "", recurse: true)
        params = {}
        if recurse
          named_children.each do |name, mod|
            # Carry the accumulated prefix down so grandchildren keep the full path.
            params.merge!(mod.named_parameters(prefix: "#{prefix}#{name}.", recurse: recurse))
          end
        end
        instance_variables.each do |name|
          param = instance_variable_get(name)
          params[[prefix, name[1..-1]].join] = param if param.is_a?(Parameter)
        end
        # Guard against subclasses that skipped super() (mirrors #named_buffers).
        (@parameters || {}).each do |name, param|
          params[[prefix, name].join] = param
        end
        params
      end

      # @return [Array] the values of #named_buffers
      def buffers
        named_buffers.values
      end

      # @return [Hash] registered buffers (empty if the registry was never set up)
      def named_buffers
        @buffers || {}
      end

      # @return [Array<Module>] the direct child modules
      def children
        named_children.values
      end

      # Returns the direct children: ivars holding a Module, then registered modules.
      # @return [Hash{String => Module}]
      def named_children
        modules = {}
        instance_variables.each do |name|
          mod = instance_variable_get(name)
          modules[name[1..-1]] = mod if mod.is_a?(Module)
        end
        (@modules || {}).each do |name, mod|
          modules[name] = mod
        end
        modules
      end

      # @return [Array<Module>] the values of #named_modules
      def modules
        named_modules.values
      end

      # Returns self (keyed "") plus its direct children. It does not recurse further.
      def named_modules
        {"" => self}.merge(named_children)
      end

      # Sets the training flag on this module and, recursively, its children.
      # @return [self]
      def train(mode = true)
        @training = mode
        children.each do |mod|
          mod.train(mode)
        end
        self
      end

      # Shorthand for train(false).
      def eval
        train(false)
      end

      def requires_grad!(requires_grad: true)
        parameters.each do |p|
          p.requires_grad!(requires_grad)
        end
        self
      end

      # Detaches and zeroes the gradient of every parameter that has one.
      def zero_grad
        parameters.each do |param|
          if param.grad
            param.grad.detach!
            param.grad.zero!
          end
        end
      end

      def share_memory
        _apply ->(t) { t.share_memory! }
      end

      # Renders the module tree.
      #
      # A leaf renders as "Name(extra)". A container renders one
      # "  (child): ..." line per child, and nested lines are indented one
      # extra level.
      def inspect
        name = self.class.name.split("::").last
        kids = named_children
        if kids.empty?
          "#{name}(#{extra_inspect})"
        else
          body = kids.map do |child_name, mod|
            # Indent the continuation lines of a multi-line child.
            "  (#{child_name}): #{mod.inspect.lines.join("  ")}\n"
          end
          "#{name}(\n#{body.join})"
        end
      end

      # Exposes parameters, buffers and modules as reader methods, looked up
      # in that order. For example, mod.weight returns the parameter "weight".
      def method_missing(method, *args, &block)
        name = method.to_s
        if (params = named_parameters).key?(name)
          params[name]
        elsif (bufs = named_buffers).key?(name)
          bufs[name]
        elsif (mods = named_modules).key?(name)
          mods[name]
        else
          super
        end
      end

      # Pairs with #method_missing: respond_to? reports the dynamic readers,
      # and #method returns a usable Method for them.
      def respond_to_missing?(method, include_private = false)
        name = method.to_s
        named_parameters.key?(name) || named_buffers.key?(name) || named_modules.key?(name) || super
      end

      private

      # Subclasses override this to add detail inside #inspect's parentheses.
      def extra_inspect
        nil
      end

      # Like Kernel#format, but each value is #inspect-ed first.
      #
      # Positional values fill %s-style directives. When no positional values
      # are given, keyword values fill %{name}-style directives.
      def format(str, *vars, **options)
        # empty? rather than any?: any? is false for [nil] / [false], which
        # would wrongly send nil-valued fields down the keyword branch.
        vars =
          if vars.empty?
            options.map { |k, v| [k, v.inspect] }.to_h
          else
            vars.map(&:inspect)
          end
        str % vars
      end
    end
  end
end
|
data/lib/torch/nn/mse_loss.rb
CHANGED
@@ -0,0 +1,13 @@
|
|
1
|
+
module Torch
  module NN
    # Module wrapper around F.multilabel_soft_margin_loss.
    class MultiLabelSoftMarginLoss < WeightedLoss
      # @param weight optional weight, handed to WeightedLoss#initialize
      # @param reduction [String] handed to WeightedLoss#initialize (default "mean")
      def initialize(weight: nil, reduction: "mean")
        super(weight, reduction)
      end

      # Delegates to F.multilabel_soft_margin_loss with the stored options.
      # NOTE(review): @weight/@reduction are presumably set by WeightedLoss — confirm.
      def forward(input, target)
        F.multilabel_soft_margin_loss(
          input, target,
          weight: @weight,
          reduction: @reduction
        )
      end
    end
  end
end
|
@@ -0,0 +1,17 @@
|
|
1
|
+
module Torch
  module NN
    # Module wrapper around F.multi_margin_loss.
    class MultiMarginLoss < WeightedLoss
      # @param p [Integer] norm degree; only 1 and 2 are accepted
      # @param margin [Numeric] forwarded as the +margin:+ keyword
      # @param weight optional weight; must be nil or one-dimensional
      # @param reduction [String] handed to WeightedLoss#initialize
      # @raise [ArgumentError] for an unsupported +p+ or a multi-dimensional +weight+
      def initialize(p: 1, margin: 1.0, weight: nil, reduction: "mean")
        super(weight, reduction)
        raise ArgumentError, "only p == 1 and p == 2 supported" unless p == 1 || p == 2
        raise ArgumentError, "weight must be nil or have one dimension" unless weight.nil? || weight.dim == 1

        @p = p
        @margin = margin
      end

      # Delegates to F.multi_margin_loss with the stored options.
      def forward(input, target)
        F.multi_margin_loss(
          input, target,
          p: @p,
          margin: @margin,
          weight: @weight,
          reduction: @reduction
        )
      end
    end
  end
end
|
@@ -0,0 +1,14 @@
|
|
1
|
+
module Torch
  module NN
    # Module wrapper around F.nll_loss.
    class NLLLoss < WeightedLoss
      # @param weight optional weight, handed to WeightedLoss#initialize
      # @param ignore_index [Integer] forwarded as +ignore_index:+ (default -100)
      # @param reduction [String] handed to WeightedLoss#initialize (default "mean")
      def initialize(weight: nil, ignore_index: -100, reduction: "mean")
        super(weight, reduction)
        @ignore_index = ignore_index
      end

      # Delegates to F.nll_loss with the stored options.
      def forward(input, target)
        F.nll_loss(
          input, target,
          weight: @weight,
          ignore_index: @ignore_index,
          reduction: @reduction
        )
      end
    end
  end
end
|
@@ -0,0 +1,16 @@
|
|
1
|
+
module Torch
  module NN
    # Module wrapper around F.pairwise_distance.
    class PairwiseDistance < Module
      # @param p [Numeric] norm degree, kept as @norm and forwarded as +p:+
      # @param eps [Numeric] forwarded as +eps:+
      # @param keepdim [Boolean] forwarded as +keepdim:+
      def initialize(p: 2.0, eps: 1e-6, keepdim: false)
        super()
        @norm, @eps, @keepdim = p, eps, keepdim
      end

      # Delegates to F.pairwise_distance with the stored options.
      def forward(x1, x2)
        F.pairwise_distance(
          x1, x2,
          p: @norm,
          eps: @eps,
          keepdim: @keepdim
        )
      end
    end
  end
end
|
data/lib/torch/nn/parameter.rb
CHANGED
@@ -0,0 +1,16 @@
|
|
1
|
+
module Torch
  module NN
    # Module wrapper around F.poisson_nll_loss.
    class PoissonNLLLoss < Loss
      # @param log_input [Boolean] forwarded as +log_input:+ (default true)
      # @param full [Boolean] forwarded as +full:+ (default false)
      # @param eps [Numeric] forwarded as +eps:+ (default 1e-8)
      # @param reduction [String] handed to Loss#initialize (default "mean")
      def initialize(log_input: true, full: false, eps: 1e-8, reduction: "mean")
        super(reduction)
        @log_input, @full, @eps = log_input, full, eps
      end

      # Delegates to F.poisson_nll_loss with the stored options.
      def forward(log_input, target)
        F.poisson_nll_loss(
          log_input, target,
          log_input: @log_input,
          full: @full,
          eps: @eps,
          reduction: @reduction
        )
      end
    end
  end
end
|
@@ -0,0 +1,19 @@
|
|
1
|
+
module Torch
  module NN
    # Module wrapper around F.prelu with a learnable weight Parameter.
    class PReLU < Module
      # @param num_parameters [Integer] size of the weight tensor
      # @param init [Numeric] value every weight entry is filled with
      def initialize(num_parameters: 1, init: 0.25)
        super()
        @num_parameters = num_parameters
        initial = Tensor.new(num_parameters).fill!(init)
        @weight = Parameter.new(initial)
      end

      # Delegates to F.prelu with the stored weight.
      def forward(input)
        F.prelu(input, @weight)
      end

      # Detail shown between the parentheses of #inspect.
      def extra_inspect
        format("num_parameters: %s", @num_parameters)
      end
    end
  end
end
|