torch-rb 0.1.3 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +1 -0
- data/ext/torch/ext.cpp +375 -124
- data/lib/torch.rb +101 -20
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +23 -19
- data/lib/torch/nn/avg_pool2d.rb +14 -0
- data/lib/torch/nn/avg_poolnd.rb +9 -0
- data/lib/torch/nn/bce_loss.rb +13 -0
- data/lib/torch/nn/bilinear.rb +38 -0
- data/lib/torch/nn/conv2d.rb +2 -2
- data/lib/torch/nn/convnd.rb +3 -3
- data/lib/torch/nn/cosine_similarity.rb +15 -0
- data/lib/torch/nn/cross_entropy_loss.rb +14 -0
- data/lib/torch/nn/ctc_loss.rb +15 -0
- data/lib/torch/nn/dropoutnd.rb +2 -2
- data/lib/torch/nn/embedding_bag.rb +34 -0
- data/lib/torch/nn/functional.rb +101 -13
- data/lib/torch/nn/identity.rb +13 -0
- data/lib/torch/nn/init.rb +58 -1
- data/lib/torch/nn/kl_div_loss.rb +13 -0
- data/lib/torch/nn/l1_loss.rb +13 -0
- data/lib/torch/nn/leaky_relu.rb +20 -0
- data/lib/torch/nn/linear.rb +12 -11
- data/lib/torch/nn/log_softmax.rb +14 -0
- data/lib/torch/nn/loss.rb +10 -0
- data/lib/torch/nn/max_pool2d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +19 -0
- data/lib/torch/nn/module.rb +120 -31
- data/lib/torch/nn/mse_loss.rb +2 -2
- data/lib/torch/nn/nll_loss.rb +14 -0
- data/lib/torch/nn/pairwise_distance.rb +16 -0
- data/lib/torch/nn/parameter.rb +0 -4
- data/lib/torch/nn/poisson_nll_loss.rb +16 -0
- data/lib/torch/nn/prelu.rb +19 -0
- data/lib/torch/nn/relu.rb +8 -3
- data/lib/torch/nn/sequential.rb +1 -10
- data/lib/torch/nn/sigmoid.rb +9 -0
- data/lib/torch/nn/softmax.rb +18 -0
- data/lib/torch/nn/softmax2d.rb +10 -0
- data/lib/torch/nn/softmin.rb +14 -0
- data/lib/torch/nn/softplus.rb +19 -0
- data/lib/torch/nn/weighted_loss.rb +10 -0
- data/lib/torch/random.rb +10 -0
- data/lib/torch/tensor.rb +28 -10
- data/lib/torch/version.rb +1 -1
- metadata +29 -2
data/lib/torch/nn/mse_loss.rb
CHANGED
data/lib/torch/nn/nll_loss.rb
ADDED
@@ -0,0 +1,14 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
class NLLLoss < WeightedLoss
|
4
|
+
def initialize(weight: nil, ignore_index: -100, reduction: "mean")
|
5
|
+
super(weight, reduction)
|
6
|
+
@ignore_index = ignore_index
|
7
|
+
end
|
8
|
+
|
9
|
+
def forward(input, target)
|
10
|
+
F.nll_loss(input, target, weight: @weight, ignore_index: @ignore_index, reduction: @reduction)
|
11
|
+
end
|
12
|
+
end
|
13
|
+
end
|
14
|
+
end
|
data/lib/torch/nn/pairwise_distance.rb
ADDED
@@ -0,0 +1,16 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
class PairwiseDistance < Module
|
4
|
+
def initialize(p: 2.0, eps: 1e-6, keepdim: false)
|
5
|
+
super()
|
6
|
+
@norm = p
|
7
|
+
@eps = eps
|
8
|
+
@keepdim = keepdim
|
9
|
+
end
|
10
|
+
|
11
|
+
def forward(x1, x2)
|
12
|
+
F.pairwise_distance(x1, x2, p: @norm, eps: @eps, keepdim: @keepdim)
|
13
|
+
end
|
14
|
+
end
|
15
|
+
end
|
16
|
+
end
|
data/lib/torch/nn/parameter.rb
CHANGED
data/lib/torch/nn/poisson_nll_loss.rb
ADDED
@@ -0,0 +1,16 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
class PoissonNLLLoss < Loss
|
4
|
+
def initialize(log_input: true, full: false, eps: 1e-8, reduction: "mean")
|
5
|
+
super(reduction)
|
6
|
+
@log_input = log_input
|
7
|
+
@full = full
|
8
|
+
@eps = eps
|
9
|
+
end
|
10
|
+
|
11
|
+
def forward(log_input, target)
|
12
|
+
F.poisson_nll_loss(log_input, target, log_input: @log_input, full: @full, eps: @eps, reduction: @reduction)
|
13
|
+
end
|
14
|
+
end
|
15
|
+
end
|
16
|
+
end
|
data/lib/torch/nn/prelu.rb
ADDED
@@ -0,0 +1,19 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
class PReLU < Module
|
4
|
+
def initialize(num_parameters: 1, init: 0.25)
|
5
|
+
@num_parameters = num_parameters
|
6
|
+
super()
|
7
|
+
@weight = Parameter.new(Tensor.new(num_parameters).fill!(init))
|
8
|
+
end
|
9
|
+
|
10
|
+
def forward(input)
|
11
|
+
F.prelu(input, @weight)
|
12
|
+
end
|
13
|
+
|
14
|
+
def extra_inspect
|
15
|
+
format("num_parameters: %s", @num_parameters)
|
16
|
+
end
|
17
|
+
end
|
18
|
+
end
|
19
|
+
end
|
data/lib/torch/nn/relu.rb
CHANGED
@@ -1,12 +1,17 @@
|
|
1
1
|
module Torch
|
2
2
|
module NN
|
3
3
|
class ReLU < Module
|
4
|
-
def initialize
|
5
|
-
|
4
|
+
def initialize(inplace: false)
|
5
|
+
super()
|
6
|
+
@inplace = inplace
|
6
7
|
end
|
7
8
|
|
8
9
|
def forward(input)
|
9
|
-
F.relu(input)
|
10
|
+
F.relu(input, inplace: @inplace)
|
11
|
+
end
|
12
|
+
|
13
|
+
def extra_inspect
|
14
|
+
@inplace ? "inplace: true" : ""
|
10
15
|
end
|
11
16
|
end
|
12
17
|
end
|
data/lib/torch/nn/sequential.rb
CHANGED
@@ -2,28 +2,19 @@ module Torch
|
|
2
2
|
module NN
|
3
3
|
class Sequential < Module
|
4
4
|
def initialize(*args)
|
5
|
-
|
5
|
+
super()
|
6
6
|
# TODO support hash arg (named modules)
|
7
7
|
args.each_with_index do |mod, idx|
|
8
8
|
add_module(idx.to_s, mod)
|
9
9
|
end
|
10
10
|
end
|
11
11
|
|
12
|
-
def add_module(name, mod)
|
13
|
-
# TODO add checks
|
14
|
-
@modules[name] = mod
|
15
|
-
end
|
16
|
-
|
17
12
|
def forward(input)
|
18
13
|
@modules.values.each do |mod|
|
19
14
|
input = mod.call(input)
|
20
15
|
end
|
21
16
|
input
|
22
17
|
end
|
23
|
-
|
24
|
-
def parameters
|
25
|
-
@modules.flat_map { |_, mod| mod.parameters }
|
26
|
-
end
|
27
18
|
end
|
28
19
|
end
|
29
20
|
end
|
data/lib/torch/nn/softmax.rb
ADDED
@@ -0,0 +1,18 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
class Softmax < Module
|
4
|
+
def initialize(dim: nil)
|
5
|
+
super()
|
6
|
+
@dim = dim
|
7
|
+
end
|
8
|
+
|
9
|
+
def forward(input)
|
10
|
+
F.softmax(input, dim: @dim)
|
11
|
+
end
|
12
|
+
|
13
|
+
def extra_inspect
|
14
|
+
format("dim: %s", @dim)
|
15
|
+
end
|
16
|
+
end
|
17
|
+
end
|
18
|
+
end
|
data/lib/torch/nn/softplus.rb
ADDED
@@ -0,0 +1,19 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
class Softplus < Module
|
4
|
+
def initialize(beta: 1, threshold: 20)
|
5
|
+
super()
|
6
|
+
@beta = beta
|
7
|
+
@threshold = threshold
|
8
|
+
end
|
9
|
+
|
10
|
+
def forward(input)
|
11
|
+
F.softplus(input, beta: @beta, threshold: @threshold)
|
12
|
+
end
|
13
|
+
|
14
|
+
def extra_inspect
|
15
|
+
format("beta: %s, threshold: %s", @beta, @threshold)
|
16
|
+
end
|
17
|
+
end
|
18
|
+
end
|
19
|
+
end
|
data/lib/torch/random.rb
ADDED
data/lib/torch/tensor.rb
CHANGED
@@ -66,16 +66,11 @@ module Torch
|
|
66
66
|
end
|
67
67
|
|
68
68
|
def backward(gradient = nil)
|
69
|
-
|
70
|
-
_backward_gradient(gradient)
|
71
|
-
else
|
72
|
-
_backward
|
73
|
-
end
|
69
|
+
_backward(gradient)
|
74
70
|
end
|
75
71
|
|
76
72
|
# TODO read directly from memory
|
77
73
|
def numo
|
78
|
-
raise Error, "Numo not found" unless defined?(Numo::NArray)
|
79
74
|
cls = Torch._dtype_to_numo[dtype]
|
80
75
|
raise Error, "Cannot convert #{dtype} to Numo" unless cls
|
81
76
|
cls.cast(_data).reshape(*shape)
|
@@ -113,7 +108,7 @@ module Torch
|
|
113
108
|
end
|
114
109
|
|
115
110
|
# operations
|
116
|
-
%w(abs add argmax div dot eq exp gt log lt matmul max mean min mul neg norm num numel pow remainder reshape sign sqrt sub sum unsqueeze).each do |op|
|
111
|
+
%w(abs add argmax div dot eq exp gt log log_softmax lt matmul max mean min mul neg norm num numel pow relu remainder reshape sign softmax sqrt sub sum unsqueeze topk).each do |op|
|
117
112
|
define_method(op) do |*args, **options, &block|
|
118
113
|
if options.any?
|
119
114
|
Torch.send(op, self, *args, **options, &block)
|
@@ -167,8 +162,14 @@ module Torch
|
|
167
162
|
finish += 1 unless index.exclude_end?
|
168
163
|
result = result._slice(dim, index.begin, finish, 1)
|
169
164
|
dim += 1
|
165
|
+
elsif index.nil?
|
166
|
+
result = result.unsqueeze(dim)
|
167
|
+
dim += 1
|
168
|
+
elsif index == true
|
169
|
+
result = result.unsqueeze(dim)
|
170
|
+
# TODO handle false
|
170
171
|
else
|
171
|
-
raise Error, "Unsupported index type"
|
172
|
+
raise Error, "Unsupported index type: #{index.class.name}"
|
172
173
|
end
|
173
174
|
end
|
174
175
|
result
|
@@ -176,11 +177,28 @@ module Torch
|
|
176
177
|
|
177
178
|
# TODO
|
178
179
|
# based on python_variable_indexing.cpp
|
179
|
-
|
180
|
-
|
180
|
+
def []=(index, value)
|
181
|
+
raise ArgumentError, "Tensor does not support deleting items" if value.nil?
|
182
|
+
|
183
|
+
value = Torch.tensor(value) unless value.is_a?(Tensor)
|
184
|
+
|
185
|
+
if index.is_a?(Numeric)
|
186
|
+
copy_to(_select(0, index), value)
|
187
|
+
elsif index.is_a?(Range)
|
188
|
+
finish = index.end
|
189
|
+
finish += 1 unless index.exclude_end?
|
190
|
+
copy_to(_slice(0, index.begin, finish, 1), value)
|
191
|
+
else
|
192
|
+
raise Error, "Unsupported index type: #{index.class.name}"
|
193
|
+
end
|
194
|
+
end
|
181
195
|
|
182
196
|
private
|
183
197
|
|
198
|
+
def copy_to(dst, src)
|
199
|
+
dst.copy!(src)
|
200
|
+
end
|
201
|
+
|
184
202
|
def reshape_arr(arr, dims)
|
185
203
|
if dims.empty?
|
186
204
|
arr
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.3
|
4
|
+
version: 0.1.4
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-
|
11
|
+
date: 2019-12-01 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -111,22 +111,48 @@ files:
|
|
111
111
|
- lib/torch/ext.bundle
|
112
112
|
- lib/torch/inspector.rb
|
113
113
|
- lib/torch/nn/alpha_dropout.rb
|
114
|
+
- lib/torch/nn/avg_pool2d.rb
|
115
|
+
- lib/torch/nn/avg_poolnd.rb
|
116
|
+
- lib/torch/nn/bce_loss.rb
|
117
|
+
- lib/torch/nn/bilinear.rb
|
114
118
|
- lib/torch/nn/conv2d.rb
|
115
119
|
- lib/torch/nn/convnd.rb
|
120
|
+
- lib/torch/nn/cosine_similarity.rb
|
121
|
+
- lib/torch/nn/cross_entropy_loss.rb
|
122
|
+
- lib/torch/nn/ctc_loss.rb
|
116
123
|
- lib/torch/nn/dropout.rb
|
117
124
|
- lib/torch/nn/dropout2d.rb
|
118
125
|
- lib/torch/nn/dropout3d.rb
|
119
126
|
- lib/torch/nn/dropoutnd.rb
|
120
127
|
- lib/torch/nn/embedding.rb
|
128
|
+
- lib/torch/nn/embedding_bag.rb
|
121
129
|
- lib/torch/nn/feature_alpha_dropout.rb
|
122
130
|
- lib/torch/nn/functional.rb
|
131
|
+
- lib/torch/nn/identity.rb
|
123
132
|
- lib/torch/nn/init.rb
|
133
|
+
- lib/torch/nn/kl_div_loss.rb
|
134
|
+
- lib/torch/nn/l1_loss.rb
|
135
|
+
- lib/torch/nn/leaky_relu.rb
|
124
136
|
- lib/torch/nn/linear.rb
|
137
|
+
- lib/torch/nn/log_softmax.rb
|
138
|
+
- lib/torch/nn/loss.rb
|
139
|
+
- lib/torch/nn/max_pool2d.rb
|
140
|
+
- lib/torch/nn/max_poolnd.rb
|
125
141
|
- lib/torch/nn/module.rb
|
126
142
|
- lib/torch/nn/mse_loss.rb
|
143
|
+
- lib/torch/nn/nll_loss.rb
|
144
|
+
- lib/torch/nn/pairwise_distance.rb
|
127
145
|
- lib/torch/nn/parameter.rb
|
146
|
+
- lib/torch/nn/poisson_nll_loss.rb
|
147
|
+
- lib/torch/nn/prelu.rb
|
128
148
|
- lib/torch/nn/relu.rb
|
129
149
|
- lib/torch/nn/sequential.rb
|
150
|
+
- lib/torch/nn/sigmoid.rb
|
151
|
+
- lib/torch/nn/softmax.rb
|
152
|
+
- lib/torch/nn/softmax2d.rb
|
153
|
+
- lib/torch/nn/softmin.rb
|
154
|
+
- lib/torch/nn/softplus.rb
|
155
|
+
- lib/torch/nn/weighted_loss.rb
|
130
156
|
- lib/torch/optim/adadelta.rb
|
131
157
|
- lib/torch/optim/adagrad.rb
|
132
158
|
- lib/torch/optim/adam.rb
|
@@ -139,6 +165,7 @@ files:
|
|
139
165
|
- lib/torch/optim/rmsprop.rb
|
140
166
|
- lib/torch/optim/rprop.rb
|
141
167
|
- lib/torch/optim/sgd.rb
|
168
|
+
- lib/torch/random.rb
|
142
169
|
- lib/torch/tensor.rb
|
143
170
|
- lib/torch/utils/data/data_loader.rb
|
144
171
|
- lib/torch/utils/data/tensor_dataset.rb
|