torch-rb 0.1.0 → 0.1.5
- checksums.yaml +4 -4
- data/CHANGELOG.md +40 -0
- data/LICENSE.txt +46 -22
- data/README.md +85 -19
- data/ext/torch/ext.cpp +274 -256
- data/ext/torch/extconf.rb +9 -0
- data/ext/torch/nn_functions.cpp +595 -0
- data/ext/torch/nn_functions.hpp +6 -0
- data/ext/torch/templates.hpp +250 -0
- data/ext/torch/tensor_functions.cpp +1860 -0
- data/ext/torch/tensor_functions.hpp +6 -0
- data/ext/torch/torch_functions.cpp +2875 -0
- data/ext/torch/torch_functions.hpp +6 -0
- data/lib/torch.rb +199 -84
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +52 -25
- data/lib/torch/native/dispatcher.rb +48 -0
- data/lib/torch/native/function.rb +78 -0
- data/lib/torch/native/generator.rb +149 -0
- data/lib/torch/native/native_functions.yaml +6837 -0
- data/lib/torch/native/parser.rb +97 -0
- data/lib/torch/nn/alpha_dropout.rb +9 -0
- data/lib/torch/nn/avg_pool2d.rb +14 -0
- data/lib/torch/nn/avg_poolnd.rb +9 -0
- data/lib/torch/nn/bce_loss.rb +13 -0
- data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
- data/lib/torch/nn/bilinear.rb +38 -0
- data/lib/torch/nn/conv2d.rb +14 -29
- data/lib/torch/nn/convnd.rb +41 -0
- data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
- data/lib/torch/nn/cosine_similarity.rb +15 -0
- data/lib/torch/nn/cross_entropy_loss.rb +14 -0
- data/lib/torch/nn/ctc_loss.rb +15 -0
- data/lib/torch/nn/dropout.rb +9 -0
- data/lib/torch/nn/dropout2d.rb +9 -0
- data/lib/torch/nn/dropout3d.rb +9 -0
- data/lib/torch/nn/dropoutnd.rb +15 -0
- data/lib/torch/nn/embedding.rb +52 -0
- data/lib/torch/nn/embedding_bag.rb +34 -0
- data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
- data/lib/torch/nn/functional.rb +194 -11
- data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
- data/lib/torch/nn/identity.rb +14 -0
- data/lib/torch/nn/init.rb +58 -1
- data/lib/torch/nn/kl_div_loss.rb +13 -0
- data/lib/torch/nn/l1_loss.rb +13 -0
- data/lib/torch/nn/leaky_relu.rb +20 -0
- data/lib/torch/nn/linear.rb +12 -11
- data/lib/torch/nn/log_softmax.rb +14 -0
- data/lib/torch/nn/loss.rb +10 -0
- data/lib/torch/nn/margin_ranking_loss.rb +14 -0
- data/lib/torch/nn/max_pool2d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +19 -0
- data/lib/torch/nn/module.rb +184 -19
- data/lib/torch/nn/mse_loss.rb +2 -2
- data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_margin_loss.rb +17 -0
- data/lib/torch/nn/nll_loss.rb +14 -0
- data/lib/torch/nn/pairwise_distance.rb +16 -0
- data/lib/torch/nn/parameter.rb +4 -0
- data/lib/torch/nn/poisson_nll_loss.rb +16 -0
- data/lib/torch/nn/prelu.rb +19 -0
- data/lib/torch/nn/relu.rb +8 -3
- data/lib/torch/nn/rnn.rb +22 -0
- data/lib/torch/nn/rnn_base.rb +154 -0
- data/lib/torch/nn/sequential.rb +1 -10
- data/lib/torch/nn/sigmoid.rb +9 -0
- data/lib/torch/nn/smooth_l1_loss.rb +13 -0
- data/lib/torch/nn/soft_margin_loss.rb +13 -0
- data/lib/torch/nn/softmax.rb +18 -0
- data/lib/torch/nn/softmax2d.rb +10 -0
- data/lib/torch/nn/softmin.rb +14 -0
- data/lib/torch/nn/softplus.rb +19 -0
- data/lib/torch/nn/triplet_margin_loss.rb +18 -0
- data/lib/torch/nn/weighted_loss.rb +10 -0
- data/lib/torch/optim/adadelta.rb +57 -0
- data/lib/torch/optim/adagrad.rb +71 -0
- data/lib/torch/optim/adam.rb +81 -0
- data/lib/torch/optim/adamax.rb +68 -0
- data/lib/torch/optim/adamw.rb +82 -0
- data/lib/torch/optim/asgd.rb +65 -0
- data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
- data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
- data/lib/torch/optim/optimizer.rb +62 -0
- data/lib/torch/optim/rmsprop.rb +76 -0
- data/lib/torch/optim/rprop.rb +68 -0
- data/lib/torch/optim/sgd.rb +60 -0
- data/lib/torch/random.rb +10 -0
- data/lib/torch/tensor.rb +92 -21
- data/lib/torch/utils/data/data_loader.rb +15 -0
- data/lib/torch/utils/data/tensor_dataset.rb +8 -1
- data/lib/torch/version.rb +1 -1
- metadata +74 -3
data/lib/torch/nn/embedding_bag.rb
ADDED
```diff
@@ -0,0 +1,34 @@
+# ported from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/sparse.py
+module Torch
+  module NN
+    class EmbeddingBag < Module
+      def initialize(num_embeddings, embedding_dim, max_norm: nil, norm_type: 2.0,
+                     scale_grad_by_freq: false, mode: "mean", sparse: false, _weight: nil)
+
+        super()
+        @num_embeddings = num_embeddings
+        @embedding_dim = embedding_dim
+        @max_norm = max_norm
+        @norm_type = norm_type
+        @scale_grad_by_freq = scale_grad_by_freq
+        if _weight.nil?
+          @weight = Parameter.new(Tensor.new(num_embeddings, embedding_dim))
+          reset_parameters
+        else
+          raise ArgumentError, "Shape of weight does not match num_embeddings and embedding_dim" unless _weight.shape == [num_embeddings, embedding_dim]
+          @weight = Parameter.new(_weight)
+        end
+        @mode = mode
+        @sparse = sparse
+      end
+
+      def reset_parameters
+        Init.normal!(@weight)
+      end
+
+      def forward(input, offsets: nil, per_sample_weights: nil)
+        F.embedding_bag(input, @weight, offsets: offsets, max_norm: @max_norm, norm_type: @norm_type, scale_grad_by_freq: @scale_grad_by_freq, mode: @mode, sparse: @sparse, per_sample_weights: per_sample_weights)
+      end
+    end
+  end
+end
```
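A minimal usage sketch (not part of the diff; `Torch.tensor` and `Module#call` dispatching to `forward` are assumed to behave as in the 0.1.5 README):

```ruby
require "torch"

# Construct the new layer; the weight is initialized via Init.normal!
bag = Torch::NN::EmbeddingBag.new(10, 3, mode: "mean")

input = Torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = Torch.tensor([0, 4])

# Note: forward delegates to F.embedding_bag, which still raises
# NotImplementedYet in this release (see functional.rb below).
# bag.call(input, offsets: offsets)
```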
data/lib/torch/nn/functional.rb
CHANGED
```diff
@@ -2,12 +2,25 @@ module Torch
   module NN
     class Functional
       class << self
-        def relu(input)
-
+        def relu(input, inplace: false)
+          if inplace
+            input.relu!
+          else
+            input.relu
+          end
         end

-        def conv2d(input, weight, bias)
-
+        def conv2d(input, weight, bias, stride: 1, padding: 0, dilation: 1, groups: 1)
+          # TODO pair stride and padding when needed
+          Torch.conv2d(input, weight, bias, stride, padding, dilation, groups)
+        end
+
+        def prelu(input, weight)
+          Torch.prelu(input, weight)
+        end
+
+        def leaky_relu(input, negative_slope = 0.01)
+          Torch.leaky_relu(input, negative_slope)
         end

         def max_pool2d(input, kernel_size)
@@ -15,26 +28,196 @@ module Torch
           Torch.max_pool2d(input, kernel_size)
         end

+        def avg_pool2d(input, kernel_size)
+          kernel_size = [kernel_size, kernel_size] if kernel_size.is_a?(Integer)
+          Torch.avg_pool2d(input, kernel_size)
+        end
+
+        # linear layers
+
+        def bilinear(input1, input2, weight, bias)
+          Torch.bilinear(input1, input2, weight, bias)
+        end
+
         def linear(input, weight, bias)
           Torch.linear(input, weight, bias)
         end

+        # sparse layers
+
+        def embedding(input, weight, padding_idx: nil, max_norm: nil, norm_type: 2.0, scale_grad_by_freq: false, sparse: false)
+          # TODO handle max_norm and norm_type
+          raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0
+
+          padding_idx ||= -1
+          # weight and indices are swapped from Python interface
+          Torch._embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
+        end
+
+        def embedding_bag(input, weight, offsets: nil, max_norm: nil, norm_type: 2, scale_grad_by_freq: false, mode: "mean", sparse: false, per_sample_weights: nil)
+          # need to handle nils
+          raise NotImplementedYet
+
+          # TODO handle max_norm and norm_type
+          raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0
+
+          Torch._embedding_bag(input, weight, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights)
+        end
+
+        # distance functions
+
+        def cosine_similarity(x1, x2, dim: 1, eps: 1e-8)
+          Torch._cosine_similarity(x1, x2, dim, eps)
+        end
+
+        def pairwise_distance(x1, x2, p: 2.0, eps: 1e-6, keepdim: false)
+          Torch._pairwise_distance(x1, x2, p, eps, keepdim)
+        end
+
+        # loss functions
+
+        def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
+          NN._binary_cross_entropy(input, target, weight, reduction)
+        end
+
+        def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
+          Torch._binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
+        end
+
+        def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
+          raise NotImplementedYet
+        end
+
+        def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
+          nll_loss(log_softmax(input, 1), target, weight: weight, ignore_index: ignore_index, reduction: reduction)
+        end
+
+        def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
+          # call to_a on input_lengths and target_lengths for C++
+          Torch._ctc_loss_intlist(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
+        end
+
+        def hinge_embedding_loss(input, target, margin: 1.0, reduction: "mean")
+          Torch._hinge_embedding_loss(input, target, margin, reduction)
+        end
+
+        def kl_div(input, target, reduction: "mean")
+          Torch._kl_div(input, target, reduction)
+        end
+
+        def l1_loss(input, target, reduction: "mean")
+          NN._l1_loss(input, target, reduction)
+        end
+
+        def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
+          raise NotImplementedYet
+        end
+
         def mse_loss(input, target, reduction: "mean")
-
+          NN._mse_loss(input, target, reduction)
+        end
+
+        def multilabel_margin_loss(input, target, reduction: "mean")
+          NN._multilabel_margin_loss(input, target, reduction)
+        end
+
+        def multilabel_soft_margin_loss(input, target, weight: nil)
+          raise NotImplementedYet
+        end
+
+        def multi_margin_loss(input, target, p: 1, margin: 1.0, weight: nil, reduction: "mean")
+          NN._multi_margin_loss(input, target, p, margin, weight, reduction)
         end

-        def
-
+        def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
+          NN._nll_loss(input, target, weight, reduction, ignore_index)
         end

-        def
-
-
+        def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
+          Torch._poisson_nll_loss(input, target, log_input, full, eps, reduction)
+        end
+
+        def soft_margin_loss(input, target, reduction: "mean")
+          NN._soft_margin_loss(input, target, reduction)
+        end
+
+        def smooth_l1_loss(input, target, reduction: "mean")
+          NN._smooth_l1_loss(input, target, reduction)
+        end
+
+        def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean")
+          Torch._triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
+        end
+
+        # end loss
+
+        def softmax(input, dim: nil)
+          dim ||= softmax_dim(input.dim)
+          input.softmax(dim: dim)
         end

-        def
+        def softmin(input, dim: nil)
+          dim ||= softmax_dim(input.dim)
+          (-input).softmax(dim: dim)
+        end
+
+        def softplus(input, beta: 1, threshold: 20)
+          NN._softplus(input, beta, threshold)
+        end
+
+        # TODO make dim keyword argument and update examples
+        def log_softmax(input, dim = nil)
+          dim ||= softmax_dim(input.dim)
           input.log_softmax(dim)
         end
+
+        def dropout(input, p: 0.5, training: true, inplace: false)
+          if inplace
+            Torch._dropout_(input, p, training)
+          else
+            Torch._dropout(input, p, training)
+          end
+        end
+
+        def dropout2d(input, p: 0.5, training: true, inplace: false)
+          raise ArgumentError, "dropout probability has to be between 0 and 1, but got #{p}" if p < 0 || p > 1
+
+          if inplace
+            Torch._feature_dropout_(input, p, training)
+          else
+            Torch._feature_dropout(input, p, training)
+          end
+        end
+
+        def dropout3d(input, p: 0.5, training: true, inplace: false)
+          if inplace
+            Torch._feature_dropout_(input, p, training)
+          else
+            Torch._feature_dropout(input, p, training)
+          end
+        end
+
+        def alpha_dropout(input, p: 0.5, training: true, inplace: false)
+          if inplace
+            Torch._alpha_dropout_(input, p, training)
+          else
+            Torch._alpha_dropout(input, p, training)
+          end
+        end
+
+        def feature_alpha_dropout(input, p: 0.5, training: true, inplace: false)
+          if inplace
+            Torch._feature_alpha_dropout_(input, p, training)
+          else
+            Torch._feature_alpha_dropout(input, p, training)
+          end
+        end
+
+        private
+
+        def softmax_dim(ndim)
+          ndim == 0 || ndim == 1 || ndim == 3 ? 0 : 1
+        end
       end
     end
   end
```
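A quick sketch of the expanded functional API (the `F` methods are those added above; `Torch.tensor` is assumed from the 0.1.5 README):

```ruby
require "torch"

F = Torch::NN::Functional

x = Torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])

F.softmax(x)          # dim inferred by softmax_dim: 1 for a 2-D input
F.softmin(x)          # implemented as softmax of -x
F.log_softmax(x, 1)   # dim is still positional (see the TODO above)
F.dropout(x, p: 0.5)  # returns a new tensor; inplace: true mutates x

target = Torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
F.mse_loss(x, target) # => scalar tensor, 0.0 here
```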
data/lib/torch/nn/hinge_embedding_loss.rb
ADDED
```diff
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class HingeEmbeddingLoss < Loss
+      def initialize(margin: 1.0, reduction: "mean")
+        super(reduction)
+        @margin = margin
+      end
+
+      def forward(input, target)
+        F.hinge_embedding_loss(input, target, margin: @margin, reduction: @reduction)
+      end
+    end
+  end
+end
```
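Usage sketch for the new loss module (hypothetical values; `Module#call` dispatching to `forward` is assumed):

```ruby
require "torch"

loss_fn = Torch::NN::HingeEmbeddingLoss.new(margin: 1.0)

input  = Torch.tensor([[0.1, -0.4], [0.7, 0.2]])
target = Torch.tensor([[1.0, -1.0], [-1.0, 1.0]]) # labels must be 1 or -1

loss_fn.call(input, target) # mean-reduced scalar tensor
```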
data/lib/torch/nn/init.rb
CHANGED
```diff
@@ -2,7 +2,64 @@ module Torch
   module NN
     module Init
       class << self
-        def
+        def calculate_gain(nonlinearity, param: 0.01)
+          _calculate_gain(nonlinearity, param)
+        end
+
+        def uniform!(tensor, a: 0.0, b: 1.0)
+          _uniform!(tensor, a, b)
+        end
+
+        def normal!(tensor, mean: 0.0, std: 1.0)
+          _normal!(tensor, mean, std)
+        end
+
+        def constant!(tensor, val)
+          _constant!(tensor, val)
+        end
+
+        def ones!(tensor)
+          _ones!(tensor)
+        end
+
+        def zeros!(tensor)
+          _zeros!(tensor)
+        end
+
+        def eye!(tensor)
+          _eye!(tensor)
+        end
+
+        def dirac!(tensor)
+          _dirac!(tensor)
+        end
+
+        def xavier_uniform!(tensor, gain: 1.0)
+          _xavier_uniform!(tensor, gain)
+        end
+
+        def xavier_normal!(tensor, gain: 1.0)
+          _xavier_normal!(tensor, gain)
+        end
+
+        def kaiming_uniform!(tensor, a: 0, mode: "fan_in", nonlinearity: "leaky_relu")
+          _kaiming_uniform!(tensor, a, mode, nonlinearity)
+        end
+
+        def kaiming_normal!(tensor, a: 0, mode: "fan_in", nonlinearity: "leaky_relu")
+          _kaiming_normal!(tensor, a, mode, nonlinearity)
+        end
+
+        def orthogonal!(tensor, gain: 1)
+          _orthogonal!(tensor, gain)
+        end
+
+        def sparse!(tensor, sparsity, std: 0.01)
+          _sparse!(tensor, sparsity, std)
+        end
+
+        # TODO move to C++ when released
+        def _calculate_fan_in_and_fan_out(tensor)
           dimensions = tensor.dim
           if dimensions < 2
             raise Error, "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
```
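The new Init wrappers in action (a sketch; `Torch.empty` is assumed from the core tensor API):

```ruby
require "torch"

w = Torch.empty(3, 5)

Torch::NN::Init.kaiming_uniform!(w, a: Math.sqrt(5))  # mutates w in place
Torch::NN::Init.uniform!(w, a: -0.1, b: 0.1)
Torch::NN::Init.calculate_gain("leaky_relu", param: 0.2)

# Still implemented in Ruby until the C++ version is released (see TODO)
fan_in, fan_out = Torch::NN::Init._calculate_fan_in_and_fan_out(w)
```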
data/lib/torch/nn/leaky_relu.rb
ADDED
```diff
@@ -0,0 +1,20 @@
+module Torch
+  module NN
+    class LeakyReLU < Module
+      def initialize(negative_slope: 1e-2) #, inplace: false)
+        super()
+        @negative_slope = negative_slope
+        # @inplace = inplace
+      end
+
+      def forward(input)
+        F.leaky_relu(input, @negative_slope) #, inplace: @inplace)
+      end
+
+      def extra_inspect
+        inplace_str = @inplace ? ", inplace: true" : ""
+        format("negative_slope: %s%s", @negative_slope, inplace_str)
+      end
+    end
+  end
+end
```
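Note that `extra_inspect` reads `@inplace`, which is never assigned while the `inplace` option stays commented out, so the suffix is always empty. A usage sketch:

```ruby
require "torch"

act = Torch::NN::LeakyReLU.new(negative_slope: 0.2)
x = Torch.tensor([-1.0, 0.0, 2.0])
act.call(x) # negatives scaled by 0.2 => [-0.2, 0.0, 2.0]
```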
data/lib/torch/nn/linear.rb
CHANGED
```diff
@@ -1,35 +1,36 @@
 module Torch
   module NN
     class Linear < Module
-      attr_reader :bias, :weight
-
       def initialize(in_features, out_features, bias: true)
+        super()
         @in_features = in_features
         @out_features = out_features

         @weight = Parameter.new(Tensor.new(out_features, in_features))
         if bias
           @bias = Parameter.new(Tensor.new(out_features))
+        else
+          register_parameter("bias", nil)
         end

         reset_parameters
       end

-      def call(input)
-        F.linear(input, @weight, @bias)
-      end
-
       def reset_parameters
-        Init.
+        Init.kaiming_uniform!(@weight, a: Math.sqrt(5))
         if @bias
-          fan_in, _ = Init.
+          fan_in, _ = Init._calculate_fan_in_and_fan_out(@weight)
           bound = 1 / Math.sqrt(fan_in)
-          Init.
+          Init.uniform!(@bias, a: -bound, b: bound)
         end
       end

-      def
-
+      def forward(input)
+        F.linear(input, @weight, @bias)
+      end
+
+      def extra_inspect
+        format("in_features: %s, out_features: %s, bias: %s", @in_features, @out_features, !@bias.nil?)
       end
     end
   end
```
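The layer's public entry point moves from `call` to `forward`; the expanded `Module#call` (module.rb above) is assumed to dispatch to `forward`, so existing call sites keep working. Sketch:

```ruby
require "torch"

layer = Torch::NN::Linear.new(10, 2)
x = Torch.rand(4, 10)

layer.call(x)  # dispatches to the new forward => 4x2 tensor
layer.inspect  # includes extra_inspect: in_features, out_features, bias
```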