torch-rb 0.1.3 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +1 -0
- data/ext/torch/ext.cpp +375 -124
- data/lib/torch.rb +101 -20
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +23 -19
- data/lib/torch/nn/avg_pool2d.rb +14 -0
- data/lib/torch/nn/avg_poolnd.rb +9 -0
- data/lib/torch/nn/bce_loss.rb +13 -0
- data/lib/torch/nn/bilinear.rb +38 -0
- data/lib/torch/nn/conv2d.rb +2 -2
- data/lib/torch/nn/convnd.rb +3 -3
- data/lib/torch/nn/cosine_similarity.rb +15 -0
- data/lib/torch/nn/cross_entropy_loss.rb +14 -0
- data/lib/torch/nn/ctc_loss.rb +15 -0
- data/lib/torch/nn/dropoutnd.rb +2 -2
- data/lib/torch/nn/embedding_bag.rb +34 -0
- data/lib/torch/nn/functional.rb +101 -13
- data/lib/torch/nn/identity.rb +13 -0
- data/lib/torch/nn/init.rb +58 -1
- data/lib/torch/nn/kl_div_loss.rb +13 -0
- data/lib/torch/nn/l1_loss.rb +13 -0
- data/lib/torch/nn/leaky_relu.rb +20 -0
- data/lib/torch/nn/linear.rb +12 -11
- data/lib/torch/nn/log_softmax.rb +14 -0
- data/lib/torch/nn/loss.rb +10 -0
- data/lib/torch/nn/max_pool2d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +19 -0
- data/lib/torch/nn/module.rb +120 -31
- data/lib/torch/nn/mse_loss.rb +2 -2
- data/lib/torch/nn/nll_loss.rb +14 -0
- data/lib/torch/nn/pairwise_distance.rb +16 -0
- data/lib/torch/nn/parameter.rb +0 -4
- data/lib/torch/nn/poisson_nll_loss.rb +16 -0
- data/lib/torch/nn/prelu.rb +19 -0
- data/lib/torch/nn/relu.rb +8 -3
- data/lib/torch/nn/sequential.rb +1 -10
- data/lib/torch/nn/sigmoid.rb +9 -0
- data/lib/torch/nn/softmax.rb +18 -0
- data/lib/torch/nn/softmax2d.rb +10 -0
- data/lib/torch/nn/softmin.rb +14 -0
- data/lib/torch/nn/softplus.rb +19 -0
- data/lib/torch/nn/weighted_loss.rb +10 -0
- data/lib/torch/random.rb +10 -0
- data/lib/torch/tensor.rb +28 -10
- data/lib/torch/version.rb +1 -1
- metadata +29 -2
data/lib/torch.rb
CHANGED
@@ -22,31 +22,90 @@ require "torch/optim/sgd"
|
|
22
22
|
require "torch/optim/lr_scheduler/lr_scheduler"
|
23
23
|
require "torch/optim/lr_scheduler/step_lr"
|
24
24
|
|
25
|
-
# nn
|
25
|
+
# nn parameters
|
26
|
+
require "torch/nn/parameter"
|
27
|
+
|
28
|
+
# nn containers
|
26
29
|
require "torch/nn/module"
|
30
|
+
require "torch/nn/sequential"
|
31
|
+
|
32
|
+
# nn convolution layers
|
27
33
|
require "torch/nn/convnd"
|
28
|
-
require "torch/nn/
|
34
|
+
require "torch/nn/conv2d"
|
35
|
+
|
36
|
+
# nn pooling layers
|
37
|
+
require "torch/nn/max_poolnd"
|
38
|
+
require "torch/nn/max_pool2d"
|
39
|
+
require "torch/nn/avg_poolnd"
|
40
|
+
require "torch/nn/avg_pool2d"
|
41
|
+
|
42
|
+
# nn linear layers
|
43
|
+
require "torch/nn/bilinear"
|
44
|
+
require "torch/nn/identity"
|
45
|
+
require "torch/nn/linear"
|
29
46
|
|
30
|
-
# nn
|
47
|
+
# nn dropout layers
|
48
|
+
require "torch/nn/dropoutnd"
|
31
49
|
require "torch/nn/alpha_dropout"
|
32
|
-
require "torch/nn/conv2d"
|
33
50
|
require "torch/nn/dropout"
|
34
51
|
require "torch/nn/dropout2d"
|
35
52
|
require "torch/nn/dropout3d"
|
36
|
-
require "torch/nn/embedding"
|
37
53
|
require "torch/nn/feature_alpha_dropout"
|
54
|
+
|
55
|
+
# nn activations
|
56
|
+
require "torch/nn/leaky_relu"
|
57
|
+
require "torch/nn/prelu"
|
58
|
+
require "torch/nn/relu"
|
59
|
+
require "torch/nn/sigmoid"
|
60
|
+
require "torch/nn/softplus"
|
61
|
+
|
62
|
+
# nn activations other
|
63
|
+
require "torch/nn/log_softmax"
|
64
|
+
require "torch/nn/softmax"
|
65
|
+
require "torch/nn/softmax2d"
|
66
|
+
require "torch/nn/softmin"
|
67
|
+
|
68
|
+
# nn sparse layers
|
69
|
+
require "torch/nn/embedding"
|
70
|
+
require "torch/nn/embedding_bag"
|
71
|
+
|
72
|
+
# nn distance functions
|
73
|
+
require "torch/nn/cosine_similarity"
|
74
|
+
require "torch/nn/pairwise_distance"
|
75
|
+
|
76
|
+
# nn loss functions
|
77
|
+
require "torch/nn/loss"
|
78
|
+
require "torch/nn/weighted_loss"
|
79
|
+
require "torch/nn/bce_loss"
|
80
|
+
# require "torch/nn/bce_with_logits_loss"
|
81
|
+
# require "torch/nn/cosine_embedding_loss"
|
82
|
+
require "torch/nn/cross_entropy_loss"
|
83
|
+
require "torch/nn/ctc_loss"
|
84
|
+
# require "torch/nn/hinge_embedding_loss"
|
85
|
+
require "torch/nn/kl_div_loss"
|
86
|
+
require "torch/nn/l1_loss"
|
87
|
+
# require "torch/nn/margin_ranking_loss"
|
88
|
+
require "torch/nn/mse_loss"
|
89
|
+
# require "torch/nn/multi_label_margin_loss"
|
90
|
+
# require "torch/nn/multi_label_soft_margin_loss"
|
91
|
+
# require "torch/nn/multi_margin_loss"
|
92
|
+
require "torch/nn/nll_loss"
|
93
|
+
require "torch/nn/poisson_nll_loss"
|
94
|
+
# require "torch/nn/smooth_l1_loss"
|
95
|
+
# require "torch/nn/soft_margin_loss"
|
96
|
+
# require "torch/nn/triplet_margin_loss"
|
97
|
+
|
98
|
+
# nn other
|
38
99
|
require "torch/nn/functional"
|
39
100
|
require "torch/nn/init"
|
40
|
-
require "torch/nn/linear"
|
41
|
-
require "torch/nn/mse_loss"
|
42
|
-
require "torch/nn/parameter"
|
43
|
-
require "torch/nn/relu"
|
44
|
-
require "torch/nn/sequential"
|
45
101
|
|
46
102
|
# utils
|
47
103
|
require "torch/utils/data/data_loader"
|
48
104
|
require "torch/utils/data/tensor_dataset"
|
49
105
|
|
106
|
+
# random
|
107
|
+
require "torch/random"
|
108
|
+
|
50
109
|
module Torch
|
51
110
|
# Generic library error (raised, for example, by _dtype_to_numo when
# Numo::NArray is not loaded).
class Error < StandardError
end
|
52
111
|
class NotImplementedYet < StandardError
|
@@ -57,7 +116,6 @@ module Torch
|
|
57
116
|
|
58
117
|
# keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
|
59
118
|
# values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
|
60
|
-
# complex and quantized types not supported by PyTorch yet
|
61
119
|
DTYPE_TO_ENUM = {
|
62
120
|
uint8: 0,
|
63
121
|
int8: 1,
|
@@ -73,14 +131,14 @@ module Torch
|
|
73
131
|
float32: 6,
|
74
132
|
double: 7,
|
75
133
|
float64: 7,
|
76
|
-
|
77
|
-
|
78
|
-
|
134
|
+
complex_half: 8,
|
135
|
+
complex_float: 9,
|
136
|
+
complex_double: 10,
|
79
137
|
bool: 11,
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
138
|
+
qint8: 12,
|
139
|
+
quint8: 13,
|
140
|
+
qint32: 14,
|
141
|
+
bfloat16: 15
|
84
142
|
}
|
85
143
|
ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
|
86
144
|
|
@@ -120,6 +178,8 @@ module Torch
|
|
120
178
|
# use method for cases when Numo not available
|
121
179
|
# or available after Torch loaded
|
122
180
|
def _dtype_to_numo
|
181
|
+
raise Error, "Numo not found" unless defined?(Numo::NArray)
|
182
|
+
|
123
183
|
{
|
124
184
|
uint8: Numo::UInt8,
|
125
185
|
int8: Numo::Int8,
|
@@ -200,8 +260,12 @@ module Torch
|
|
200
260
|
data = [data].compact
|
201
261
|
end
|
202
262
|
|
203
|
-
if options[:dtype].nil?
|
204
|
-
|
263
|
+
if options[:dtype].nil?
|
264
|
+
if data.all? { |v| v.is_a?(Integer) }
|
265
|
+
options[:dtype] = :int64
|
266
|
+
elsif data.all? { |v| v == true || v == false }
|
267
|
+
options[:dtype] = :bool
|
268
|
+
end
|
205
269
|
end
|
206
270
|
|
207
271
|
_tensor(data, size, tensor_options(**options))
|
@@ -302,6 +366,10 @@ module Torch
|
|
302
366
|
_pow(input, exponent)
|
303
367
|
end
|
304
368
|
|
369
|
+
# Thin wrapper over the native +_topk+ binding.
# Presumably returns the +k+ largest entries of +input+ (values and
# indices, as in PyTorch) — TODO confirm against ext.cpp.
#
# @param input [Torch::Tensor]
# @param k [Integer] number of elements to select
def topk(input, k)
  _topk(input, k)
end
|
372
|
+
|
305
373
|
# Delegates to the native +_min+ binding.
# Presumably the minimum over all elements of +input+ — TODO confirm
# against ext.cpp.
#
# @param input [Torch::Tensor]
def min(input)
  _min(input)
end
|
@@ -327,6 +395,10 @@ module Torch
|
|
327
395
|
_sign(input)
|
328
396
|
end
|
329
397
|
|
398
|
+
# Applies the logistic sigmoid via the native +_sigmoid+ binding
# (presumably element-wise, as in PyTorch — TODO confirm).
#
# @param input [Torch::Tensor]
def sigmoid(input)
  _sigmoid(input)
end
|
401
|
+
|
330
402
|
# Greater-than comparison of +input+ against +other+, delegated to the
# native +_gt+ binding (presumably element-wise — TODO confirm).
#
# @param input [Torch::Tensor]
# @param other [Torch::Tensor, Numeric]
def gt(input, other)
  _gt(input, other)
end
|
@@ -363,6 +435,15 @@ module Torch
|
|
363
435
|
_sqrt(input)
|
364
436
|
end
|
365
437
|
|
438
|
+
# Log-softmax of +input+ over dimension +dim+, computed by the native
# +_log_softmax+ binding.
#
# TODO make dim keyword argument (softmax already takes dim: as a keyword)
#
# @param input [Torch::Tensor]
# @param dim [Integer] dimension to normalize over
def log_softmax(input, dim)
  _log_softmax(input, dim)
end
|
442
|
+
|
443
|
+
# Softmax of +input+ over dimension +dim+, computed by the native
# +_softmax+ binding.
#
# +dim+ is optional. A nil dim would otherwise be handed to the native
# binding, which presumably requires an integer. So when it is omitted, an
# implicit dimension is chosen using PyTorch's legacy rule
# (torch.nn.functional._get_softmax_dim): 0 for 0-, 1- and 3-dimensional
# inputs, 1 for everything else.
#
# @param input [Torch::Tensor]
# @param dim [Integer, nil] dimension to normalize over
def softmax(input, dim: nil)
  if dim.nil?
    ndim = input.dim
    dim = [0, 1, 3].include?(ndim) ? 0 : 1
  end
  _softmax(input, dim)
end
|
446
|
+
|
366
447
|
# Absolute value of +input+, delegated to the native +_abs+ binding.
#
# @param input [Torch::Tensor]
def abs(input)
  _abs(input)
end
|
data/lib/torch/ext.bundle
CHANGED
Binary file
|
data/lib/torch/inspector.rb
CHANGED
@@ -11,32 +11,36 @@ module Torch
|
|
11
11
|
else
|
12
12
|
summarize = numel > 1000
|
13
13
|
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
14
|
+
if dtype == :bool
|
15
|
+
fmt = "%s"
|
16
|
+
else
|
17
|
+
values = to_a.flatten
|
18
|
+
abs = values.select { |v| v != 0 }.map(&:abs)
|
19
|
+
max = abs.max || 1
|
20
|
+
min = abs.min || 1
|
18
21
|
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
22
|
+
total = 0
|
23
|
+
if values.any? { |v| v < 0 }
|
24
|
+
total += 1
|
25
|
+
end
|
23
26
|
|
24
|
-
|
25
|
-
|
27
|
+
if floating_point?
|
28
|
+
sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4
|
26
29
|
|
27
|
-
|
28
|
-
|
30
|
+
all_int = values.all? { |v| v.finite? && v == v.to_i }
|
31
|
+
decimal = all_int ? 1 : 4
|
29
32
|
|
30
|
-
|
33
|
+
total += sci ? 10 : decimal + 1 + max.to_i.to_s.size
|
31
34
|
|
32
|
-
|
33
|
-
|
35
|
+
if sci
|
36
|
+
fmt = "%#{total}.4e"
|
37
|
+
else
|
38
|
+
fmt = "%#{total}.#{decimal}f"
|
39
|
+
end
|
34
40
|
else
|
35
|
-
|
41
|
+
total += max.to_s.size
|
42
|
+
fmt = "%#{total}d"
|
36
43
|
end
|
37
|
-
else
|
38
|
-
total += max.to_s.size
|
39
|
-
fmt = "%#{total}d"
|
40
44
|
end
|
41
45
|
|
42
46
|
inspect_level(to_a, fmt, dim - 1, 0, summarize)
|
@@ -0,0 +1,13 @@
|
|
1
|
+
module Torch
  module NN
    # Binary cross-entropy loss module. The computation is delegated to
    # F.binary_cross_entropy; weight and reduction are stored by WeightedLoss.
    class BCELoss < WeightedLoss
      # @param weight optional rescaling weight, forwarded to WeightedLoss
      # @param reduction [String] reduction mode, "mean" by default
      def initialize(weight: nil, reduction: "mean")
        super(weight, reduction)
      end

      # @param input predictions (presumably probabilities in [0, 1] — TODO confirm)
      # @param target target values with the same shape as +input+ (assumed)
      def forward(input, target)
        options = { weight: @weight, reduction: @reduction }
        F.binary_cross_entropy(input, target, **options)
      end
    end
  end
end
|
@@ -0,0 +1,38 @@
|
|
1
|
+
module Torch
  module NN
    # Bilinear layer over two inputs. It holds a weight of shape
    # [out_features, in1_features, in2_features] plus a bias of length
    # out_features, and delegates the forward computation to F.bilinear.
    class Bilinear < Module
      # @param in1_features [Integer] size of each sample in the first input
      # @param in2_features [Integer] size of each sample in the second input
      # @param out_features [Integer] size of each output sample
      # @param bias [Boolean] must be true for now; bias-free layers raise
      #   NotImplementedYet
      def initialize(in1_features, in2_features, out_features, bias: true)
        super()

        @in1_features = in1_features
        @in2_features = in2_features
        @out_features = out_features

        @weight = Parameter.new(Tensor.new(out_features, in1_features, in2_features))
        # a layer without bias is not supported yet
        raise NotImplementedYet unless bias

        @bias = Parameter.new(Tensor.new(out_features))
        reset_parameters
      end

      # Re-initializes weight and bias uniformly in [-bound, bound], where
      # bound = 1 / sqrt(in1_features) (taken from weight dimension 1).
      def reset_parameters
        limit = 1 / Math.sqrt(@weight.size(1))
        Init.uniform!(@weight, a: -limit, b: limit)
        Init.uniform!(@bias, a: -limit, b: limit) if @bias
      end

      def forward(input1, input2)
        F.bilinear(input1, input2, @weight, @bias)
      end

      def extra_inspect
        has_bias = !@bias.nil?
        format("in1_features: %s, in2_features: %s, out_features: %s, bias: %s", @in1_features, @in2_features, @out_features, has_bias)
      end
    end
  end
end
|
data/lib/torch/nn/conv2d.rb
CHANGED
@@ -19,8 +19,8 @@ module Torch
|
|
19
19
|
end
|
20
20
|
|
21
21
|
# TODO add more parameters
|
22
|
-
def
|
23
|
-
"
|
22
|
+
# Summary shown inside the module's inspect output: the input/output
# channel counts followed by the kernel size and stride.
def extra_inspect
  fields = [@in_channels, @out_channels, @kernel_size, @stride]
  format("%s, %s, kernel_size: %s, stride: %s", *fields)
end
|
25
25
|
|
26
26
|
private
|
data/lib/torch/nn/convnd.rb
CHANGED
@@ -29,11 +29,11 @@ module Torch
|
|
29
29
|
end
|
30
30
|
|
31
31
|
# Kaiming-uniform initialization of the convolution weight (a = sqrt(5)).
# If a bias exists, it is then drawn uniformly from [-bound, bound] with
# bound = 1 / sqrt(fan_in). Returns nil when there is no bias.
def reset_parameters
  Init.kaiming_uniform!(@weight, a: Math.sqrt(5))
  return unless @bias

  fan_in, = Init._calculate_fan_in_and_fan_out(@weight)
  limit = 1 / Math.sqrt(fan_in)
  Init.uniform!(@bias, a: -limit, b: limit)
end
|
39
39
|
end
|
@@ -0,0 +1,14 @@
|
|
1
|
+
module Torch
  module NN
    # Cross-entropy loss module. The computation is delegated to
    # F.cross_entropy; weight and reduction are stored by WeightedLoss.
    class CrossEntropyLoss < WeightedLoss
      # @param weight optional class weights, forwarded to WeightedLoss
      # @param ignore_index [Integer] forwarded to F.cross_entropy (-100 by
      #   default; in PyTorch, targets equal to it are skipped)
      # @param reduction [String] reduction mode, "mean" by default
      def initialize(weight: nil, ignore_index: -100, reduction: "mean")
        super(weight, reduction)
        @ignore_index = ignore_index
      end

      def forward(input, target)
        options = { weight: @weight, ignore_index: @ignore_index, reduction: @reduction }
        F.cross_entropy(input, target, **options)
      end
    end
  end
end
|
@@ -0,0 +1,15 @@
|
|
1
|
+
module Torch
  module NN
    # Connectionist Temporal Classification loss module. The computation is
    # delegated to F.ctc_loss; reduction is stored by Loss.
    class CTCLoss < Loss
      # @param blank [Integer] blank label index, 0 by default
      # @param reduction [String] reduction mode, "mean" by default
      # @param zero_infinity [Boolean] forwarded to F.ctc_loss (false by default)
      def initialize(blank: 0, reduction: "mean", zero_infinity: false)
        super(reduction)
        @blank = blank
        @zero_infinity = zero_infinity
      end

      def forward(log_probs, targets, input_lengths, target_lengths)
        options = { blank: @blank, reduction: @reduction, zero_infinity: @zero_infinity }
        F.ctc_loss(log_probs, targets, input_lengths, target_lengths, **options)
      end
    end
  end
end
|
data/lib/torch/nn/embedding_bag.rb
CHANGED
@@ -0,0 +1,34 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/sparse.py
module Torch
  module NN
    # Embedding table of shape [num_embeddings, embedding_dim]. Its forward
    # pass reduces "bags" of lookups through F.embedding_bag according to
    # +mode+ ("mean" by default).
    class EmbeddingBag < Module
      # @param num_embeddings [Integer] number of rows in the table
      # @param embedding_dim [Integer] size of each embedding vector
      # @param _weight [Torch::Tensor, nil] optional pre-built table; it must
      #   have shape [num_embeddings, embedding_dim] or ArgumentError is raised.
      #   When omitted, a fresh table is created and reset_parameters is run.
      # The remaining options are stored and forwarded to F.embedding_bag.
      def initialize(num_embeddings, embedding_dim, max_norm: nil, norm_type: 2.0,
                     scale_grad_by_freq: false, mode: "mean", sparse: false, _weight: nil)

        super()
        @num_embeddings = num_embeddings
        @embedding_dim = embedding_dim
        @max_norm = max_norm
        @norm_type = norm_type
        @scale_grad_by_freq = scale_grad_by_freq

        if _weight.nil?
          @weight = Parameter.new(Tensor.new(num_embeddings, embedding_dim))
          reset_parameters
        else
          expected_shape = [num_embeddings, embedding_dim]
          unless _weight.shape == expected_shape
            raise ArgumentError, "Shape of weight does not match num_embeddings and embedding_dim"
          end
          @weight = Parameter.new(_weight)
        end

        # set after the weight on purpose, matching the original assignment order
        @mode = mode
        @sparse = sparse
      end

      # Fills the table from a normal distribution via Init.normal!.
      def reset_parameters
        Init.normal!(@weight)
      end

      # @param input indices to look up
      # @param offsets optional bag start offsets (forwarded as-is)
      # @param per_sample_weights optional weights (forwarded as-is)
      def forward(input, offsets: nil, per_sample_weights: nil)
        F.embedding_bag(
          input,
          @weight,
          offsets: offsets,
          max_norm: @max_norm,
          norm_type: @norm_type,
          scale_grad_by_freq: @scale_grad_by_freq,
          mode: @mode,
          sparse: @sparse,
          per_sample_weights: per_sample_weights
        )
      end
    end
  end
end
|