torch-rb 0.1.3 → 0.1.4
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +1 -0
- data/ext/torch/ext.cpp +375 -124
- data/lib/torch.rb +101 -20
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +23 -19
- data/lib/torch/nn/avg_pool2d.rb +14 -0
- data/lib/torch/nn/avg_poolnd.rb +9 -0
- data/lib/torch/nn/bce_loss.rb +13 -0
- data/lib/torch/nn/bilinear.rb +38 -0
- data/lib/torch/nn/conv2d.rb +2 -2
- data/lib/torch/nn/convnd.rb +3 -3
- data/lib/torch/nn/cosine_similarity.rb +15 -0
- data/lib/torch/nn/cross_entropy_loss.rb +14 -0
- data/lib/torch/nn/ctc_loss.rb +15 -0
- data/lib/torch/nn/dropoutnd.rb +2 -2
- data/lib/torch/nn/embedding_bag.rb +34 -0
- data/lib/torch/nn/functional.rb +101 -13
- data/lib/torch/nn/identity.rb +13 -0
- data/lib/torch/nn/init.rb +58 -1
- data/lib/torch/nn/kl_div_loss.rb +13 -0
- data/lib/torch/nn/l1_loss.rb +13 -0
- data/lib/torch/nn/leaky_relu.rb +20 -0
- data/lib/torch/nn/linear.rb +12 -11
- data/lib/torch/nn/log_softmax.rb +14 -0
- data/lib/torch/nn/loss.rb +10 -0
- data/lib/torch/nn/max_pool2d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +19 -0
- data/lib/torch/nn/module.rb +120 -31
- data/lib/torch/nn/mse_loss.rb +2 -2
- data/lib/torch/nn/nll_loss.rb +14 -0
- data/lib/torch/nn/pairwise_distance.rb +16 -0
- data/lib/torch/nn/parameter.rb +0 -4
- data/lib/torch/nn/poisson_nll_loss.rb +16 -0
- data/lib/torch/nn/prelu.rb +19 -0
- data/lib/torch/nn/relu.rb +8 -3
- data/lib/torch/nn/sequential.rb +1 -10
- data/lib/torch/nn/sigmoid.rb +9 -0
- data/lib/torch/nn/softmax.rb +18 -0
- data/lib/torch/nn/softmax2d.rb +10 -0
- data/lib/torch/nn/softmin.rb +14 -0
- data/lib/torch/nn/softplus.rb +19 -0
- data/lib/torch/nn/weighted_loss.rb +10 -0
- data/lib/torch/random.rb +10 -0
- data/lib/torch/tensor.rb +28 -10
- data/lib/torch/version.rb +1 -1
- metadata +29 -2
data/lib/torch/nn/mse_loss.rb
CHANGED
data/lib/torch/nn/nll_loss.rb
ADDED
@@ -0,0 +1,14 @@
|
|
1
|
+
module Torch
  module NN
    # Negative log-likelihood loss module: a thin wrapper that forwards the
    # settings captured at construction to F.nll_loss on every call.
    # NOTE(review): assumes WeightedLoss#initialize stores its two arguments in
    # @weight and @reduction (both are read in #forward) — confirm.
    class NLLLoss < WeightedLoss
      # @param weight [Object, nil] handed to WeightedLoss as its first argument
      # @param ignore_index [Integer] forwarded to F.nll_loss as ignore_index
      # @param reduction [String] handed to WeightedLoss as its second argument
      def initialize(weight: nil, ignore_index: -100, reduction: "mean")
        super(weight, reduction)
        @ignore_index = ignore_index
      end

      # @return whatever F.nll_loss returns for (input, target) and the stored options
      def forward(input, target)
        options = {weight: @weight, ignore_index: @ignore_index, reduction: @reduction}
        F.nll_loss(input, target, **options)
      end
    end
  end
end
|
data/lib/torch/nn/pairwise_distance.rb
ADDED
@@ -0,0 +1,16 @@
|
|
1
|
+
module Torch
  module NN
    # Module wrapper around F.pairwise_distance: remembers the norm degree,
    # eps and keepdim flag given at construction and passes them on each call.
    class PairwiseDistance < Module
      # @param p [Numeric] forwarded to F.pairwise_distance as p (stored as @norm)
      # @param eps [Float] forwarded as eps
      # @param keepdim [Boolean] forwarded as keepdim
      def initialize(p: 2.0, eps: 1e-6, keepdim: false)
        super()
        @norm, @eps, @keepdim = p, eps, keepdim
      end

      # @return whatever F.pairwise_distance returns for the two inputs
      def forward(input1, input2)
        F.pairwise_distance(input1, input2, p: @norm, eps: @eps, keepdim: @keepdim)
      end
    end
  end
end
|
data/lib/torch/nn/parameter.rb
CHANGED
data/lib/torch/nn/poisson_nll_loss.rb
ADDED
@@ -0,0 +1,16 @@
|
|
1
|
+
module Torch
  module NN
    # Poisson negative log-likelihood loss module: forwards the options captured
    # at construction to F.poisson_nll_loss on every call.
    # NOTE(review): assumes Loss#initialize stores its argument in @reduction
    # (read in #forward) — confirm.
    class PoissonNLLLoss < Loss
      # @param log_input [Boolean] forwarded to F.poisson_nll_loss as log_input
      # @param full [Boolean] forwarded as full
      # @param eps [Float] forwarded as eps
      # @param reduction [String] handed to Loss
      def initialize(log_input: true, full: false, eps: 1e-8, reduction: "mean")
        super(reduction)
        @log_input = log_input
        @full = full
        @eps = eps
      end

      # The first argument is named +input+ here so it is not confused with the
      # @log_input flag; it is passed positionally, exactly as before.
      def forward(input, target)
        settings = {log_input: @log_input, full: @full, eps: @eps, reduction: @reduction}
        F.poisson_nll_loss(input, target, **settings)
      end
    end
  end
end
|
data/lib/torch/nn/prelu.rb
ADDED
@@ -0,0 +1,19 @@
|
|
1
|
+
module Torch
  module NN
    # Parametric ReLU module: owns a weight Parameter built from
    # Tensor.new(num_parameters) filled with +init+, and applies F.prelu with it.
    class PReLU < Module
      # @param num_parameters [Integer] size argument given to Tensor.new for the weight
      # @param init [Numeric] value the weight tensor is filled with
      def initialize(num_parameters: 1, init: 0.25)
        # assigned before super(), preserving the original statement order
        @num_parameters = num_parameters
        super()
        initial_weight = Tensor.new(num_parameters).fill!(init)
        @weight = Parameter.new(initial_weight)
      end

      # @return whatever F.prelu returns for input and the weight Parameter
      def forward(input)
        F.prelu(input, @weight)
      end

      # @return [String] e.g. "num_parameters: 1"
      def extra_inspect
        "num_parameters: #{@num_parameters}"
      end
    end
  end
end
|
data/lib/torch/nn/relu.rb
CHANGED
@@ -1,12 +1,17 @@
|
|
1
1
|
module Torch
|
2
2
|
module NN
|
3
3
|
class ReLU < Module
|
4
|
-
def initialize
|
5
|
-
|
4
|
+
def initialize(inplace: false)
|
5
|
+
super()
|
6
|
+
@inplace = inplace
|
6
7
|
end
|
7
8
|
|
8
9
|
def forward(input)
|
9
|
-
F.relu(input)
|
10
|
+
F.relu(input, inplace: @inplace)
|
11
|
+
end
|
12
|
+
|
13
|
+
def extra_inspect
|
14
|
+
@inplace ? "inplace: true" : ""
|
10
15
|
end
|
11
16
|
end
|
12
17
|
end
|
data/lib/torch/nn/sequential.rb
CHANGED
@@ -2,28 +2,19 @@ module Torch
|
|
2
2
|
module NN
|
3
3
|
class Sequential < Module
|
4
4
|
def initialize(*args)
|
5
|
-
|
5
|
+
super()
|
6
6
|
# TODO support hash arg (named modules)
|
7
7
|
args.each_with_index do |mod, idx|
|
8
8
|
add_module(idx.to_s, mod)
|
9
9
|
end
|
10
10
|
end
|
11
11
|
|
12
|
-
def add_module(name, mod)
|
13
|
-
# TODO add checks
|
14
|
-
@modules[name] = mod
|
15
|
-
end
|
16
|
-
|
17
12
|
def forward(input)
|
18
13
|
@modules.values.each do |mod|
|
19
14
|
input = mod.call(input)
|
20
15
|
end
|
21
16
|
input
|
22
17
|
end
|
23
|
-
|
24
|
-
def parameters
|
25
|
-
@modules.flat_map { |_, mod| mod.parameters }
|
26
|
-
end
|
27
18
|
end
|
28
19
|
end
|
29
20
|
end
|
data/lib/torch/nn/softmax.rb
ADDED
@@ -0,0 +1,18 @@
|
|
1
|
+
module Torch
  module NN
    # Applies F.softmax along the dimension chosen at construction.
    class Softmax < Module
      # @param dim [Integer, nil] forwarded to F.softmax as dim
      def initialize(dim: nil)
        super()
        @dim = dim
      end

      # @return whatever F.softmax returns for input along @dim
      def forward(input)
        F.softmax(input, dim: @dim)
      end

      # @return [String] e.g. "dim: 1" (just "dim: " when dim is nil)
      def extra_inspect
        "dim: #{@dim}"
      end
    end
  end
end
|
data/lib/torch/nn/softplus.rb
ADDED
@@ -0,0 +1,19 @@
|
|
1
|
+
module Torch
  module NN
    # Applies F.softplus using the beta and threshold chosen at construction.
    class Softplus < Module
      # @param beta [Numeric] forwarded to F.softplus as beta
      # @param threshold [Numeric] forwarded to F.softplus as threshold
      def initialize(beta: 1, threshold: 20)
        super()
        @beta = beta
        @threshold = threshold
      end

      # @return whatever F.softplus returns for input and the stored settings
      def forward(input)
        F.softplus(input, beta: @beta, threshold: @threshold)
      end

      # @return [String] e.g. "beta: 1, threshold: 20"
      def extra_inspect
        "beta: #{@beta}, threshold: #{@threshold}"
      end
    end
  end
end
|
data/lib/torch/random.rb
ADDED
data/lib/torch/tensor.rb
CHANGED
@@ -66,16 +66,11 @@ module Torch
|
|
66
66
|
end
|
67
67
|
|
68
68
|
def backward(gradient = nil)
|
69
|
-
|
70
|
-
_backward_gradient(gradient)
|
71
|
-
else
|
72
|
-
_backward
|
73
|
-
end
|
69
|
+
_backward(gradient)
|
74
70
|
end
|
75
71
|
|
76
72
|
# TODO read directly from memory
|
77
73
|
def numo
|
78
|
-
raise Error, "Numo not found" unless defined?(Numo::NArray)
|
79
74
|
cls = Torch._dtype_to_numo[dtype]
|
80
75
|
raise Error, "Cannot convert #{dtype} to Numo" unless cls
|
81
76
|
cls.cast(_data).reshape(*shape)
|
@@ -113,7 +108,7 @@ module Torch
|
|
113
108
|
end
|
114
109
|
|
115
110
|
# operations
|
116
|
-
%w(abs add argmax div dot eq exp gt log lt matmul max mean min mul neg norm num numel pow remainder reshape sign sqrt sub sum unsqueeze).each do |op|
|
111
|
+
%w(abs add argmax div dot eq exp gt log log_softmax lt matmul max mean min mul neg norm num numel pow relu remainder reshape sign softmax sqrt sub sum unsqueeze topk).each do |op|
|
117
112
|
define_method(op) do |*args, **options, &block|
|
118
113
|
if options.any?
|
119
114
|
Torch.send(op, self, *args, **options, &block)
|
@@ -167,8 +162,14 @@ module Torch
|
|
167
162
|
finish += 1 unless index.exclude_end?
|
168
163
|
result = result._slice(dim, index.begin, finish, 1)
|
169
164
|
dim += 1
|
165
|
+
elsif index.nil?
|
166
|
+
result = result.unsqueeze(dim)
|
167
|
+
dim += 1
|
168
|
+
elsif index == true
|
169
|
+
result = result.unsqueeze(dim)
|
170
|
+
# TODO handle false
|
170
171
|
else
|
171
|
-
raise Error, "Unsupported index type"
|
172
|
+
raise Error, "Unsupported index type: #{index.class.name}"
|
172
173
|
end
|
173
174
|
end
|
174
175
|
result
|
@@ -176,11 +177,28 @@ module Torch
|
|
176
177
|
|
177
178
|
# TODO
|
178
179
|
# based on python_variable_indexing.cpp
|
179
|
-
|
180
|
-
|
180
|
+
# Writes +value+ into the part of this tensor selected by +index+ along
# dimension 0 (based on PyTorch's python_variable_indexing.cpp).
#
# Supported indexes:
# - Numeric: copies into the sub-tensor returned by _select(0, index)
# - Range:   copies into _slice(0, start, finish, 1); inclusive, exclusive,
#            negative, endless (1..) and beginless (..2) ranges are accepted
#
# @param index [Numeric, Range] position or range along dimension 0
# @param value [Object] source data; anything that is not a Tensor is first
#   converted with Torch.tensor
# @raise [ArgumentError] if value is nil (item deletion is not supported)
# @raise [Error] for any other index type
def []=(index, value)
  raise ArgumentError, "Tensor does not support deleting items" if value.nil?

  value = Torch.tensor(value) unless value.is_a?(Tensor)

  if index.is_a?(Numeric)
    copy_to(_select(0, index), value)
  elsif index.is_a?(Range)
    # beginless range: start at the first element
    start = index.begin || 0
    finish = index.end
    if finish.nil?
      # endless range: run to the end of dimension 0
      finish = shape[0]
    elsif !index.exclude_end?
      # _slice takes an exclusive end. Adding 1 to an inclusive -1 would give 0
      # (an empty slice), so map it to the dimension size instead.
      finish = finish == -1 ? shape[0] : finish + 1
    end
    copy_to(_slice(0, start, finish, 1), value)
  else
    raise Error, "Unsupported index type: #{index.class.name}"
  end
end
|
181
195
|
|
182
196
|
private
|
183
197
|
|
198
|
+
# Copies +source+ into +target+ in place by delegating to target.copy!.
#
# @param target destination tensor (receives the data)
# @param source tensor whose contents are copied
# @return whatever target.copy! returns
def copy_to(target, source)
  target.copy!(source)
end
|
201
|
+
|
184
202
|
def reshape_arr(arr, dims)
|
185
203
|
if dims.empty?
|
186
204
|
arr
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.3
|
4
|
+
version: 0.1.4
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-
|
11
|
+
date: 2019-12-01 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -111,22 +111,48 @@ files:
|
|
111
111
|
- lib/torch/ext.bundle
|
112
112
|
- lib/torch/inspector.rb
|
113
113
|
- lib/torch/nn/alpha_dropout.rb
|
114
|
+
- lib/torch/nn/avg_pool2d.rb
|
115
|
+
- lib/torch/nn/avg_poolnd.rb
|
116
|
+
- lib/torch/nn/bce_loss.rb
|
117
|
+
- lib/torch/nn/bilinear.rb
|
114
118
|
- lib/torch/nn/conv2d.rb
|
115
119
|
- lib/torch/nn/convnd.rb
|
120
|
+
- lib/torch/nn/cosine_similarity.rb
|
121
|
+
- lib/torch/nn/cross_entropy_loss.rb
|
122
|
+
- lib/torch/nn/ctc_loss.rb
|
116
123
|
- lib/torch/nn/dropout.rb
|
117
124
|
- lib/torch/nn/dropout2d.rb
|
118
125
|
- lib/torch/nn/dropout3d.rb
|
119
126
|
- lib/torch/nn/dropoutnd.rb
|
120
127
|
- lib/torch/nn/embedding.rb
|
128
|
+
- lib/torch/nn/embedding_bag.rb
|
121
129
|
- lib/torch/nn/feature_alpha_dropout.rb
|
122
130
|
- lib/torch/nn/functional.rb
|
131
|
+
- lib/torch/nn/identity.rb
|
123
132
|
- lib/torch/nn/init.rb
|
133
|
+
- lib/torch/nn/kl_div_loss.rb
|
134
|
+
- lib/torch/nn/l1_loss.rb
|
135
|
+
- lib/torch/nn/leaky_relu.rb
|
124
136
|
- lib/torch/nn/linear.rb
|
137
|
+
- lib/torch/nn/log_softmax.rb
|
138
|
+
- lib/torch/nn/loss.rb
|
139
|
+
- lib/torch/nn/max_pool2d.rb
|
140
|
+
- lib/torch/nn/max_poolnd.rb
|
125
141
|
- lib/torch/nn/module.rb
|
126
142
|
- lib/torch/nn/mse_loss.rb
|
143
|
+
- lib/torch/nn/nll_loss.rb
|
144
|
+
- lib/torch/nn/pairwise_distance.rb
|
127
145
|
- lib/torch/nn/parameter.rb
|
146
|
+
- lib/torch/nn/poisson_nll_loss.rb
|
147
|
+
- lib/torch/nn/prelu.rb
|
128
148
|
- lib/torch/nn/relu.rb
|
129
149
|
- lib/torch/nn/sequential.rb
|
150
|
+
- lib/torch/nn/sigmoid.rb
|
151
|
+
- lib/torch/nn/softmax.rb
|
152
|
+
- lib/torch/nn/softmax2d.rb
|
153
|
+
- lib/torch/nn/softmin.rb
|
154
|
+
- lib/torch/nn/softplus.rb
|
155
|
+
- lib/torch/nn/weighted_loss.rb
|
130
156
|
- lib/torch/optim/adadelta.rb
|
131
157
|
- lib/torch/optim/adagrad.rb
|
132
158
|
- lib/torch/optim/adam.rb
|
@@ -139,6 +165,7 @@ files:
|
|
139
165
|
- lib/torch/optim/rmsprop.rb
|
140
166
|
- lib/torch/optim/rprop.rb
|
141
167
|
- lib/torch/optim/sgd.rb
|
168
|
+
- lib/torch/random.rb
|
142
169
|
- lib/torch/tensor.rb
|
143
170
|
- lib/torch/utils/data/data_loader.rb
|
144
171
|
- lib/torch/utils/data/tensor_dataset.rb
|