torch-rb 0.1.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/CHANGELOG.md +28 -0
- data/LICENSE.txt +46 -0
- data/README.md +426 -0
- data/ext/torch/ext.cpp +839 -0
- data/ext/torch/extconf.rb +25 -0
- data/lib/torch-rb.rb +1 -0
- data/lib/torch.rb +422 -0
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +85 -0
- data/lib/torch/nn/alpha_dropout.rb +9 -0
- data/lib/torch/nn/conv2d.rb +37 -0
- data/lib/torch/nn/convnd.rb +41 -0
- data/lib/torch/nn/dropout.rb +9 -0
- data/lib/torch/nn/dropout2d.rb +9 -0
- data/lib/torch/nn/dropout3d.rb +9 -0
- data/lib/torch/nn/dropoutnd.rb +15 -0
- data/lib/torch/nn/embedding.rb +52 -0
- data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
- data/lib/torch/nn/functional.rb +100 -0
- data/lib/torch/nn/init.rb +30 -0
- data/lib/torch/nn/linear.rb +36 -0
- data/lib/torch/nn/module.rb +85 -0
- data/lib/torch/nn/mse_loss.rb +13 -0
- data/lib/torch/nn/parameter.rb +14 -0
- data/lib/torch/nn/relu.rb +13 -0
- data/lib/torch/nn/sequential.rb +29 -0
- data/lib/torch/optim/adadelta.rb +57 -0
- data/lib/torch/optim/adagrad.rb +71 -0
- data/lib/torch/optim/adam.rb +81 -0
- data/lib/torch/optim/adamax.rb +68 -0
- data/lib/torch/optim/adamw.rb +82 -0
- data/lib/torch/optim/asgd.rb +65 -0
- data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
- data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
- data/lib/torch/optim/optimizer.rb +62 -0
- data/lib/torch/optim/rmsprop.rb +76 -0
- data/lib/torch/optim/rprop.rb +68 -0
- data/lib/torch/optim/sgd.rb +60 -0
- data/lib/torch/tensor.rb +196 -0
- data/lib/torch/utils/data/data_loader.rb +27 -0
- data/lib/torch/utils/data/tensor_dataset.rb +22 -0
- data/lib/torch/version.rb +3 -0
- metadata +169 -0
@@ -0,0 +1,85 @@
|
|
1
|
+
module Torch
  # Mixin that renders a tensor in a PyTorch-like form, e.g.
  # "tensor([ 1.0000, -2.5000], requires_grad: true)".
  #
  # The including class is expected to provide numel, dim, item, to_a,
  # floating_point?, requires_grad and dtype.
  module Inspector
    # TODO make more performant, especially when summarizing
    # (idea: only read the data that will actually be displayed)

    # Returns the printable representation of the tensor.
    #
    # Empty tensors print as "tensor([])" and 0-dim tensors print their item.
    # Tensors with more than 1000 elements are summarized: any level with more
    # than 7 entries shows only its first 3 and last 3, separated by "...".
    # "requires_grad: true" and any dtype other than :float32, :int64 or :bool
    # are appended as attributes.
    def inspect
      data =
        if numel == 0
          "[]"
        elsif dim == 0
          item
        else
          # fetch the nested array once; it is needed for sizing and for printing
          arr = to_a
          summarize = numel > 1000
          fmt = inspect_format(arr.flatten)
          inspect_level(arr, fmt, dim - 1, 0, summarize)
        end

      attributes = []
      attributes << "requires_grad: true" if requires_grad
      attributes << "dtype: #{dtype.inspect}" unless [:float32, :int64, :bool].include?(dtype)

      "tensor(#{data}#{attributes.map { |a| ", #{a}" }.join})"
    end

    private

    # Chooses one printf-style format string used for every element, so all
    # columns share the same width.
    #
    # @param values [Array] the flattened tensor values (non-empty)
    # @return [String] a format such as "%7.4f", "%10.4e", "%3d" or "%5s"
    def inspect_format(values)
      # booleans are not numeric: right-align them to the widest text
      return "%#{values.map { |v| v.to_s.size }.max}s" if dtype == :bool

      # NaN / Infinity are kept out of the magnitude range: NaN is unordered
      # (Array#max / #min raise on it) and an infinite magnitude would force
      # scientific notation for otherwise ordinary data
      abs = values.reject { |v| v == 0 || (v.is_a?(Float) && !v.finite?) }.map(&:abs)
      max = abs.max || 1
      min = abs.min || 1

      # reserve one column for a leading minus sign
      sign = values.any? { |v| v < 0 } ? 1 : 0

      if floating_point?
        sci = max / min.to_f > 1000 || max > 1e8 || min < 1e-4
        all_int = values.all? { |v| v.finite? && v == v.to_i }
        decimal = all_int ? 1 : 4

        if sci
          "%#{sign + 10}.4e"
        else
          "%#{sign + decimal + 1 + max.to_i.to_s.size}.#{decimal}f"
        end
      else
        "%#{sign + max.to_s.size}d"
      end
    end

    # Recursively formats one nesting level of +arr+.
    #
    # @param total [Integer] index of the innermost level (dim - 1)
    # @param level [Integer] current depth, starting at 0
    def inspect_level(arr, fmt, total, level, summarize)
      if level == total
        cols = inspect_items(arr, summarize) { |v| fmt % v }
        "[#{cols.join(", ")}]"
      else
        rows = inspect_items(arr, summarize) { |row| inspect_level(row, fmt, total, level + 1, summarize) }
        # outer levels are separated by extra newlines; rows are indented by
        # level + 8 so they line up under the "tensor([" prefix (8 characters)
        "[#{rows.join(",#{"\n" * (total - level)}#{" " * (level + 8)}")}]"
      end
    end

    # Maps the block over +arr+. When summarizing and +arr+ has more than 7
    # entries, only the first 3 and last 3 are mapped, with "..." between them.
    def inspect_items(arr, summarize, &block)
      if summarize && arr.size > 7
        arr[0..2].map(&block) + ["..."] + arr[-3..-1].map(&block)
      else
        arr.map(&block)
      end
    end
  end
end
|
@@ -0,0 +1,37 @@
|
|
1
|
+
module Torch
  module NN
    # 2-D convolution layer. Integer size arguments are expanded to
    # [value, value] before being passed to ConvNd, which stores them and
    # allocates the weight and bias parameters.
    class Conv2d < ConvNd
      attr_reader :bias, :weight

      # @param in_channels [Integer] channels in the input
      # @param out_channels [Integer] channels produced by the convolution
      # @param kernel_size [Integer, Array<Integer>] scalar or per-dimension pair
      # @param stride [Integer, Array<Integer>] scalar or per-dimension pair
      # @param padding [Integer, Array<Integer>] scalar or per-dimension pair
      # @param dilation [Integer, Array<Integer>] scalar or per-dimension pair
      # @param groups [Integer] must divide both channel counts (checked by ConvNd)
      # @param bias [Boolean] false currently raises NotImplementedError in ConvNd
      # @param padding_mode [String] "circular" is rejected by forward; the mode
      #   is not otherwise passed on to F.conv2d
      def initialize(in_channels, out_channels, kernel_size, stride: 1, padding: 0, dilation: 1, groups: 1, bias: true, padding_mode: "zeros")
        kernel_size = pair(kernel_size)
        stride = pair(stride)
        padding = pair(padding)
        dilation = pair(dilation)
        super(in_channels, out_channels, kernel_size, stride, padding, dilation, false, pair(0), groups, bias, padding_mode)
      end

      # Applies the convolution to +input+ via F.conv2d.
      #
      # @raise [NotImplementedError] when padding_mode is "circular"
      def forward(input)
        raise NotImplementedError, "padding_mode \"circular\" is not supported yet" if @padding_mode == "circular"

        F.conv2d(input, @weight, @bias, stride: @stride, padding: @padding, dilation: @dilation, groups: @groups)
      end

      # Like PyTorch's repr, padding, dilation, groups and padding_mode are only
      # listed when they differ from their defaults, so a default-configured
      # layer prints exactly as before.
      def inspect
        extras = []
        extras << "padding: #{@padding.inspect}" if @padding.any? { |v| v != 0 }
        extras << "dilation: #{@dilation.inspect}" if @dilation.any? { |v| v != 1 }
        extras << "groups: #{@groups.inspect}" if @groups != 1
        extras << "padding_mode: #{@padding_mode.inspect}" if @padding_mode != "zeros"

        "Conv2d(#{@in_channels}, #{@out_channels}, kernel_size: #{@kernel_size.inspect}, stride: #{@stride.inspect}#{extras.map { |e| ", #{e}" }.join})"
      end

      private

      # Returns arrays unchanged and turns any other value into [value, value].
      def pair(value)
        if value.is_a?(Array)
          value
        else
          [value] * 2
        end
      end
    end
  end
end
|
@@ -0,0 +1,41 @@
|
|
1
|
+
module Torch
  module NN
    # Common base for convolution layers: validates the channel/group sizes,
    # stores the configuration and allocates the weight and bias parameters.
    class ConvNd < Module
      # All arguments are positional. Subclasses (e.g. Conv2d) pass
      # kernel_size, stride, padding, dilation and output_padding as
      # per-dimension arrays.
      #
      # @raise [ArgumentError] when in_channels or out_channels is not divisible by groups
      # @raise [NotImplementedError] when bias is false
      def initialize(in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode)
        super()
        [["in_channels", in_channels], ["out_channels", out_channels]].each do |label, channels|
          raise ArgumentError, "#{label} must be divisible by groups" unless (channels % groups).zero?
        end

        @in_channels, @out_channels = in_channels, out_channels
        @kernel_size, @stride, @padding, @dilation = kernel_size, stride, padding, dilation
        @transposed, @output_padding = transposed, output_padding
        @groups, @padding_mode = groups, padding_mode

        # transposed layers lay the weight out as (in, out / groups, *kernel)
        # instead of (out, in / groups, *kernel)
        leading_dims = transposed ? [in_channels, out_channels / groups] : [out_channels, in_channels / groups]
        @weight = Parameter.new(Tensor.new(*leading_dims, *kernel_size))

        # a bias-free layer is not supported yet
        raise NotImplementedError unless bias

        @bias = Parameter.new(Tensor.new(out_channels))

        reset_parameters
      end

      # Re-initializes the weight with Kaiming-uniform (a = sqrt(5)) and, when
      # a bias exists, draws it uniformly from [-1/sqrt(fan_in), 1/sqrt(fan_in)].
      def reset_parameters
        Init.kaiming_uniform!(@weight, Math.sqrt(5))
        return unless @bias

        fan_in = Init.calculate_fan_in_and_fan_out(@weight).first
        bound = 1 / Math.sqrt(fan_in)
        Init.uniform!(@bias, -bound, bound)
      end
    end
  end
end
|
@@ -0,0 +1,15 @@
|
|
1
|
+
module Torch
  module NN
    # Base class for the dropout modules: stores the drop probability and the
    # in-place flag, and provides a shared inspect.
    class DropoutNd < Module
      # @param p [Numeric] drop probability, must lie in [0, 1]
      # @param inplace [Boolean] in-place flag, shown by inspect
      # @raise [ArgumentError] if p is outside [0, 1]
      def initialize(p: 0.5, inplace: false)
        super()
        # validated at construction, as PyTorch does, with the same message
        # that F.dropout2d uses
        raise ArgumentError, "dropout probability has to be between 0 and 1, but got #{p}" if p < 0 || p > 1

        @p = p
        @inplace = inplace
      end

      # e.g. "Dropout(p: 0.5, inplace: false)", using the unqualified subclass name
      def inspect
        "#{self.class.name.split("::").last}(p: #{@p.inspect}, inplace: #{@inplace.inspect})"
      end
    end
  end
end
|
@@ -0,0 +1,52 @@
|
|
1
|
+
# ported from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/sparse.py
module Torch
  module NN
    # Lookup table whose learnable weight has shape [num_embeddings, embedding_dim];
    # forward delegates to F.embedding.
    class Embedding < Module
      # @param num_embeddings [Integer] number of rows in the table
      # @param embedding_dim [Integer] width of each row
      # @param padding_idx [Integer, nil] row that is zero-filled when the weight is
      #   initialized here; a negative value counts back from num_embeddings
      # @param max_norm [Numeric, nil] forwarded to F.embedding
      # @param norm_type [Float] forwarded to F.embedding
      # @param scale_grad_by_freq [Boolean] forwarded to F.embedding
      # @param sparse [Boolean] forwarded to F.embedding
      # @param _weight [Tensor, nil] pre-built weight of shape [num_embeddings, embedding_dim]
      # @raise [ArgumentError] if padding_idx is out of range or _weight has the wrong shape
      def initialize(num_embeddings, embedding_dim, padding_idx: nil, max_norm: nil,
                     norm_type: 2.0, scale_grad_by_freq: false, sparse: false, _weight: nil)
        super()
        @num_embeddings = num_embeddings
        @embedding_dim = embedding_dim
        @padding_idx = normalize_padding_idx(padding_idx)
        @max_norm = max_norm
        @norm_type = norm_type
        @scale_grad_by_freq = scale_grad_by_freq

        if _weight.nil?
          @weight = Parameter.new(Tensor.new(num_embeddings, embedding_dim))
          reset_parameters
        else
          expected_shape = [num_embeddings, embedding_dim]
          unless _weight.shape == expected_shape
            raise ArgumentError, "Shape of weight does not match num_embeddings and embedding_dim"
          end
          @weight = Parameter.new(_weight)
        end

        @sparse = sparse
      end

      # Fills the weight via Init.normal! and zeroes the padding row (if any)
      # without recording gradients.
      def reset_parameters
        Init.normal!(@weight)
        return unless @padding_idx

        Torch.no_grad { @weight[@padding_idx].fill!(0) }
      end

      def forward(input)
        F.embedding(
          input, @weight,
          padding_idx: @padding_idx,
          max_norm: @max_norm,
          norm_type: @norm_type,
          scale_grad_by_freq: @scale_grad_by_freq,
          sparse: @sparse
        )
      end

      def inspect
        "Embedding(#{@num_embeddings}, #{@embedding_dim})"
      end

      private

      # Checks that +idx+ addresses a row of the table and maps a negative
      # index onto its non-negative equivalent; nil is passed through as-is.
      def normalize_padding_idx(idx)
        return idx unless idx

        in_range =
          if idx > 0
            idx < @num_embeddings
          elsif idx < 0
            idx >= -@num_embeddings
          else
            true
          end
        raise ArgumentError, "Padding_idx must be within num_embeddings" unless in_range

        idx < 0 ? idx + @num_embeddings : idx
      end
    end
  end
end
|
@@ -0,0 +1,100 @@
|
|
1
|
+
module Torch
  module NN
    # Stateless neural-network operations. Most methods adapt keyword
    # arguments to the positional Torch.* bindings.
    class Functional
      class << self
        def relu(input)
          Torch.relu(input)
        end

        # Integer stride / padding / dilation values are expanded to [v, v]
        # (as max_pool2d does for kernel_size); arrays are passed through.
        def conv2d(input, weight, bias, stride: 1, padding: 0, dilation: 1, groups: 1)
          Torch.conv2d(input, weight, bias, pair(stride), pair(padding), pair(dilation), groups)
        end

        def max_pool2d(input, kernel_size)
          Torch.max_pool2d(input, pair(kernel_size))
        end

        def avg_pool2d(input, kernel_size)
          Torch.avg_pool2d(input, pair(kernel_size))
        end

        def linear(input, weight, bias)
          Torch.linear(input, weight, bias)
        end

        def mse_loss(input, target, reduction: "mean")
          Torch.mse_loss(input, target, reduction)
        end

        # log_softmax over dim 1 followed by nll_loss (default "mean" reduction)
        def cross_entropy(input, target)
          nll_loss(log_softmax(input, 1), target)
        end

        def nll_loss(input, target, reduction: "mean")
          # TODO fix for non-1d
          Torch.nll_loss(input, target, reduction)
        end

        def log_softmax(input, dim)
          input.log_softmax(dim)
        end

        # @raise [ArgumentError] unless 0 <= p <= 1
        def dropout(input, p: 0.5, training: true, inplace: false)
          check_dropout_probability(p)

          if inplace
            Torch._dropout!(input, p, training)
          else
            Torch._dropout(input, p, training)
          end
        end

        # @raise [ArgumentError] unless 0 <= p <= 1
        def dropout2d(input, p: 0.5, training: true, inplace: false)
          check_dropout_probability(p)

          if inplace
            Torch._feature_dropout!(input, p, training)
          else
            Torch._feature_dropout(input, p, training)
          end
        end

        # @raise [ArgumentError] unless 0 <= p <= 1
        def dropout3d(input, p: 0.5, training: true, inplace: false)
          check_dropout_probability(p)

          if inplace
            Torch._feature_dropout!(input, p, training)
          else
            Torch._feature_dropout(input, p, training)
          end
        end

        # @raise [ArgumentError] unless 0 <= p <= 1
        def alpha_dropout(input, p: 0.5, training: true, inplace: false)
          check_dropout_probability(p)

          if inplace
            Torch._alpha_dropout!(input, p, training)
          else
            Torch._alpha_dropout(input, p, training)
          end
        end

        # @raise [ArgumentError] unless 0 <= p <= 1
        def feature_alpha_dropout(input, p: 0.5, training: true, inplace: false)
          check_dropout_probability(p)

          if inplace
            Torch._feature_alpha_dropout!(input, p, training)
          else
            Torch._feature_alpha_dropout(input, p, training)
          end
        end

        def embedding(input, weight, padding_idx: nil, max_norm: nil, norm_type: 2.0, scale_grad_by_freq: false, sparse: false)
          # TODO handle max_norm and norm_type
          # NOTE(review): NotImplementedYet is assumed to be defined elsewhere in Torch
          raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0

          # -1 is passed when no padding row is set (presumably the binding's "none" value)
          padding_idx ||= -1
          Torch._embedding(input, weight, padding_idx, scale_grad_by_freq, sparse)
        end

        private

        # Expands an Integer to [value, value]; any other value is returned as-is.
        def pair(value)
          value.is_a?(Integer) ? [value, value] : value
        end

        # Shared validation for the dropout family (same message as PyTorch).
        def check_dropout_probability(p)
          return unless p < 0 || p > 1

          raise ArgumentError, "dropout probability has to be between 0 and 1, but got #{p}"
        end
      end
    end

    # shortcut
    F = Functional
  end
end
|
@@ -0,0 +1,30 @@
|
|
1
|
+
module Torch
  module NN
    module Init
      class << self
        # Computes [fan_in, fan_out] for a weight tensor.
        #
        # For a 2-D tensor (e.g. a Linear weight created as [out, in]),
        # fan_in = size(1) and fan_out = size(0). For higher-rank tensors both
        # are multiplied by the receptive field size, i.e. the product of every
        # dimension after the first two.
        #
        # @param tensor [Tensor] must have at least 2 dimensions
        # @return [Array(Integer, Integer)] [fan_in, fan_out]
        # @raise [Error] if the tensor has fewer than 2 dimensions
        def calculate_fan_in_and_fan_out(tensor)
          dimensions = tensor.dim
          if dimensions < 2
            raise Error, "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
          end

          num_input_fmaps = tensor.size(1)
          num_output_fmaps = tensor.size(0)
          # taken from the shape rather than tensor[0][0].numel: no sub-tensors
          # are created, and it still works when a leading dimension is 0
          receptive_field_size = dimensions > 2 ? tensor.shape.drop(2).reduce(1, :*) : 1

          [num_input_fmaps * receptive_field_size, num_output_fmaps * receptive_field_size]
        end
      end
    end
  end
end
|
@@ -0,0 +1,36 @@
|
|
1
|
+
module Torch
  module NN
    # Fully connected layer: applies F.linear with a weight of shape
    # [out_features, in_features] and an optional bias of size out_features.
    class Linear < Module
      attr_reader :bias, :weight

      # @param in_features [Integer] size of each input sample
      # @param out_features [Integer] size of each output sample
      # @param bias [Boolean] when false no bias parameter is created (bias stays nil)
      def initialize(in_features, out_features, bias: true)
        # run Module's initializer, as ConvNd, DropoutNd and Embedding do
        super()
        @in_features = in_features
        @out_features = out_features

        @weight = Parameter.new(Tensor.new(out_features, in_features))
        @bias = Parameter.new(Tensor.new(out_features)) if bias

        reset_parameters
      end

      # Defined as forward, like Conv2d and Embedding, which only define
      # forward and rely on Module#call dispatching to it.
      def forward(input)
        F.linear(input, @weight, @bias)
      end

      # Re-initializes the weight with Kaiming-uniform (a = sqrt(5)) and, when
      # a bias exists, draws it uniformly from [-1/sqrt(fan_in), 1/sqrt(fan_in)].
      def reset_parameters
        Init.kaiming_uniform!(@weight, Math.sqrt(5))
        if @bias
          fan_in, _ = Init.calculate_fan_in_and_fan_out(@weight)
          bound = 1 / Math.sqrt(fan_in)
          Init.uniform!(@bias, -bound, bound)
        end
      end

      def inspect
        "Linear(in_features: #{@in_features.inspect}, out_features: #{@out_features.inspect}, bias: #{(!@bias.nil?).inspect})"
      end
    end
  end
end
|