torch-rb 0.1.0 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +40 -0
- data/LICENSE.txt +46 -22
- data/README.md +85 -19
- data/ext/torch/ext.cpp +274 -256
- data/ext/torch/extconf.rb +9 -0
- data/ext/torch/nn_functions.cpp +595 -0
- data/ext/torch/nn_functions.hpp +6 -0
- data/ext/torch/templates.hpp +250 -0
- data/ext/torch/tensor_functions.cpp +1860 -0
- data/ext/torch/tensor_functions.hpp +6 -0
- data/ext/torch/torch_functions.cpp +2875 -0
- data/ext/torch/torch_functions.hpp +6 -0
- data/lib/torch.rb +199 -84
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +52 -25
- data/lib/torch/native/dispatcher.rb +48 -0
- data/lib/torch/native/function.rb +78 -0
- data/lib/torch/native/generator.rb +149 -0
- data/lib/torch/native/native_functions.yaml +6837 -0
- data/lib/torch/native/parser.rb +97 -0
- data/lib/torch/nn/alpha_dropout.rb +9 -0
- data/lib/torch/nn/avg_pool2d.rb +14 -0
- data/lib/torch/nn/avg_poolnd.rb +9 -0
- data/lib/torch/nn/bce_loss.rb +13 -0
- data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
- data/lib/torch/nn/bilinear.rb +38 -0
- data/lib/torch/nn/conv2d.rb +14 -29
- data/lib/torch/nn/convnd.rb +41 -0
- data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
- data/lib/torch/nn/cosine_similarity.rb +15 -0
- data/lib/torch/nn/cross_entropy_loss.rb +14 -0
- data/lib/torch/nn/ctc_loss.rb +15 -0
- data/lib/torch/nn/dropout.rb +9 -0
- data/lib/torch/nn/dropout2d.rb +9 -0
- data/lib/torch/nn/dropout3d.rb +9 -0
- data/lib/torch/nn/dropoutnd.rb +15 -0
- data/lib/torch/nn/embedding.rb +52 -0
- data/lib/torch/nn/embedding_bag.rb +34 -0
- data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
- data/lib/torch/nn/functional.rb +194 -11
- data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
- data/lib/torch/nn/identity.rb +14 -0
- data/lib/torch/nn/init.rb +58 -1
- data/lib/torch/nn/kl_div_loss.rb +13 -0
- data/lib/torch/nn/l1_loss.rb +13 -0
- data/lib/torch/nn/leaky_relu.rb +20 -0
- data/lib/torch/nn/linear.rb +12 -11
- data/lib/torch/nn/log_softmax.rb +14 -0
- data/lib/torch/nn/loss.rb +10 -0
- data/lib/torch/nn/margin_ranking_loss.rb +14 -0
- data/lib/torch/nn/max_pool2d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +19 -0
- data/lib/torch/nn/module.rb +184 -19
- data/lib/torch/nn/mse_loss.rb +2 -2
- data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_margin_loss.rb +17 -0
- data/lib/torch/nn/nll_loss.rb +14 -0
- data/lib/torch/nn/pairwise_distance.rb +16 -0
- data/lib/torch/nn/parameter.rb +4 -0
- data/lib/torch/nn/poisson_nll_loss.rb +16 -0
- data/lib/torch/nn/prelu.rb +19 -0
- data/lib/torch/nn/relu.rb +8 -3
- data/lib/torch/nn/rnn.rb +22 -0
- data/lib/torch/nn/rnn_base.rb +154 -0
- data/lib/torch/nn/sequential.rb +1 -10
- data/lib/torch/nn/sigmoid.rb +9 -0
- data/lib/torch/nn/smooth_l1_loss.rb +13 -0
- data/lib/torch/nn/soft_margin_loss.rb +13 -0
- data/lib/torch/nn/softmax.rb +18 -0
- data/lib/torch/nn/softmax2d.rb +10 -0
- data/lib/torch/nn/softmin.rb +14 -0
- data/lib/torch/nn/softplus.rb +19 -0
- data/lib/torch/nn/triplet_margin_loss.rb +18 -0
- data/lib/torch/nn/weighted_loss.rb +10 -0
- data/lib/torch/optim/adadelta.rb +57 -0
- data/lib/torch/optim/adagrad.rb +71 -0
- data/lib/torch/optim/adam.rb +81 -0
- data/lib/torch/optim/adamax.rb +68 -0
- data/lib/torch/optim/adamw.rb +82 -0
- data/lib/torch/optim/asgd.rb +65 -0
- data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
- data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
- data/lib/torch/optim/optimizer.rb +62 -0
- data/lib/torch/optim/rmsprop.rb +76 -0
- data/lib/torch/optim/rprop.rb +68 -0
- data/lib/torch/optim/sgd.rb +60 -0
- data/lib/torch/random.rb +10 -0
- data/lib/torch/tensor.rb +92 -21
- data/lib/torch/utils/data/data_loader.rb +15 -0
- data/lib/torch/utils/data/tensor_dataset.rb +8 -1
- data/lib/torch/version.rb +1 -1
- metadata +74 -3
data/lib/torch/nn/embedding_bag.rb
ADDED
@@ -0,0 +1,34 @@
+# ported from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/sparse.py
+module Torch
+  module NN
+    class EmbeddingBag < Module
+      def initialize(num_embeddings, embedding_dim, max_norm: nil, norm_type: 2.0,
+                     scale_grad_by_freq: false, mode: "mean", sparse: false, _weight: nil)
+
+        super()
+        @num_embeddings = num_embeddings
+        @embedding_dim = embedding_dim
+        @max_norm = max_norm
+        @norm_type = norm_type
+        @scale_grad_by_freq = scale_grad_by_freq
+        if _weight.nil?
+          @weight = Parameter.new(Tensor.new(num_embeddings, embedding_dim))
+          reset_parameters
+        else
+          raise ArgumentError, "Shape of weight does not match num_embeddings and embedding_dim" unless _weight.shape == [num_embeddings, embedding_dim]
+          @weight = Parameter.new(_weight)
+        end
+        @mode = mode
+        @sparse = sparse
+      end
+
+      def reset_parameters
+        Init.normal!(@weight)
+      end
+
+      def forward(input, offsets: nil, per_sample_weights: nil)
+        F.embedding_bag(input, @weight, offsets: offsets, max_norm: @max_norm, norm_type: @norm_type, scale_grad_by_freq: @scale_grad_by_freq, mode: @mode, sparse: @sparse, per_sample_weights: per_sample_weights)
+      end
+    end
+  end
+end
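For orientation, a small construction sketch based on the signatures above. Note that F.embedding_bag still raises NotImplementedYet in 0.1.5 (see the functional.rb diff below), so only building the layer works end to end; Torch.tensor usage follows the gem's README and the sizes are illustrative.

require "torch"

# 10 embeddings of dimension 3, averaged per bag; weight is initialized by Init.normal!
bag = Torch::NN::EmbeddingBag.new(10, 3, mode: "mean")

# forward delegates to F.embedding_bag, which raises NotImplementedYet in this release:
# bag.forward(Torch.tensor([1, 2, 4, 5]), offsets: Torch.tensor([0, 2]))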
data/lib/torch/nn/functional.rb
CHANGED
@@ -2,12 +2,25 @@ module Torch
   module NN
     class Functional
       class << self
-        def relu(input)
-
+        def relu(input, inplace: false)
+          if inplace
+            input.relu!
+          else
+            input.relu
+          end
         end
 
-        def conv2d(input, weight, bias)
-
+        def conv2d(input, weight, bias, stride: 1, padding: 0, dilation: 1, groups: 1)
+          # TODO pair stride and padding when needed
+          Torch.conv2d(input, weight, bias, stride, padding, dilation, groups)
+        end
+
+        def prelu(input, weight)
+          Torch.prelu(input, weight)
+        end
+
+        def leaky_relu(input, negative_slope = 0.01)
+          Torch.leaky_relu(input, negative_slope)
         end
 
         def max_pool2d(input, kernel_size)
@@ -15,26 +28,196 @@ module Torch
           Torch.max_pool2d(input, kernel_size)
         end
 
+        def avg_pool2d(input, kernel_size)
+          kernel_size = [kernel_size, kernel_size] if kernel_size.is_a?(Integer)
+          Torch.avg_pool2d(input, kernel_size)
+        end
+
+        # linear layers
+
+        def bilinear(input1, input2, weight, bias)
+          Torch.bilinear(input1, input2, weight, bias)
+        end
+
         def linear(input, weight, bias)
           Torch.linear(input, weight, bias)
         end
 
+        # sparse layers
+
+        def embedding(input, weight, padding_idx: nil, max_norm: nil, norm_type: 2.0, scale_grad_by_freq: false, sparse: false)
+          # TODO handle max_norm and norm_type
+          raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0
+
+          padding_idx ||= -1
+          # weight and indices are swapped from Python interface
+          Torch._embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
+        end
+
+        def embedding_bag(input, weight, offsets: nil, max_norm: nil, norm_type: 2, scale_grad_by_freq: false, mode: "mean", sparse: false, per_sample_weights: nil)
+          # need to handle nils
+          raise NotImplementedYet
+
+          # TODO handle max_norm and norm_type
+          raise NotImplementedYet unless max_norm.nil? && norm_type == 2.0
+
+          Torch._embedding_bag(input, weight, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights)
+        end
+
+        # distance functions
+
+        def cosine_similarity(x1, x2, dim: 1, eps: 1e-8)
+          Torch._cosine_similarity(x1, x2, dim, eps)
+        end
+
+        def pairwise_distance(x1, x2, p: 2.0, eps: 1e-6, keepdim: false)
+          Torch._pairwise_distance(x1, x2, p, eps, keepdim)
+        end
+
+        # loss functions
+
+        def binary_cross_entropy(input, target, weight: nil, reduction: "mean")
+          NN._binary_cross_entropy(input, target, weight, reduction)
+        end
+
+        def binary_cross_entropy_with_logits(input, target, weight: nil, reduction: "mean", pos_weight: nil)
+          Torch._binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction)
+        end
+
+        def cosine_embedding_loss(input1, input2, target, margin: 0, reduction: "mean")
+          raise NotImplementedYet
+        end
+
+        def cross_entropy(input, target, weight: nil, ignore_index: -100, reduction: "mean")
+          nll_loss(log_softmax(input, 1), target, weight: weight, ignore_index: ignore_index, reduction: reduction)
+        end
+
+        def ctc_loss(log_probs, targets, input_lengths, target_lengths, blank: 0, reduction: "mean", zero_infinity: false)
+          # call to_a on input_lengths and target_lengths for C++
+          Torch._ctc_loss_intlist(log_probs, targets, input_lengths.to_a, target_lengths.to_a, blank, reduction, zero_infinity)
+        end
+
+        def hinge_embedding_loss(input, target, margin: 1.0, reduction: "mean")
+          Torch._hinge_embedding_loss(input, target, margin, reduction)
+        end
+
+        def kl_div(input, target, reduction: "mean")
+          Torch._kl_div(input, target, reduction)
+        end
+
+        def l1_loss(input, target, reduction: "mean")
+          NN._l1_loss(input, target, reduction)
+        end
+
+        def margin_ranking_loss(input1, input2, target, margin: 0, reduction: "mean")
+          raise NotImplementedYet
+        end
+
         def mse_loss(input, target, reduction: "mean")
-
+          NN._mse_loss(input, target, reduction)
+        end
+
+        def multilabel_margin_loss(input, target, reduction: "mean")
+          NN._multilabel_margin_loss(input, target, reduction)
+        end
+
+        def multilabel_soft_margin_loss(input, target, weight: nil)
+          raise NotImplementedYet
+        end
+
+        def multi_margin_loss(input, target, p: 1, margin: 1.0, weight: nil, reduction: "mean")
+          NN._multi_margin_loss(input, target, p, margin, weight, reduction)
         end
 
-        def
-
+        def nll_loss(input, target, weight: nil, ignore_index: -100, reduction: "mean")
+          NN._nll_loss(input, target, weight, reduction, ignore_index)
         end
 
-        def
-
-
+        def poisson_nll_loss(input, target, log_input: true, full: false, eps: 1e-8, reduction: "mean")
+          Torch._poisson_nll_loss(input, target, log_input, full, eps, reduction)
+        end
+
+        def soft_margin_loss(input, target, reduction: "mean")
+          NN._soft_margin_loss(input, target, reduction)
+        end
+
+        def smooth_l1_loss(input, target, reduction: "mean")
+          NN._smooth_l1_loss(input, target, reduction)
+        end
+
+        def triplet_margin_loss(anchor, positive, negative, margin: 1.0, p: 2, eps: 1e-06, swap: false, reduction: "mean")
+          Torch._triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
+        end
+
+        # end loss
+
+        def softmax(input, dim: nil)
+          dim ||= softmax_dim(input.dim)
+          input.softmax(dim: dim)
         end
 
-        def
+        def softmin(input, dim: nil)
+          dim ||= softmax_dim(input.dim)
+          (-input).softmax(dim: dim)
+        end
+
+        def softplus(input, beta: 1, threshold: 20)
+          NN._softplus(input, beta, threshold)
+        end
+
+        # TODO make dim keyword argument and update examples
+        def log_softmax(input, dim = nil)
+          dim ||= softmax_dim(input.dim)
           input.log_softmax(dim)
         end
+
+        def dropout(input, p: 0.5, training: true, inplace: false)
+          if inplace
+            Torch._dropout_(input, p, training)
+          else
+            Torch._dropout(input, p, training)
+          end
+        end
+
+        def dropout2d(input, p: 0.5, training: true, inplace: false)
+          raise ArgumentError, "dropout probability has to be between 0 and 1, but got #{p}" if p < 0 || p > 1
+
+          if inplace
+            Torch._feature_dropout_(input, p, training)
+          else
+            Torch._feature_dropout(input, p, training)
+          end
+        end
+
+        def dropout3d(input, p: 0.5, training: true, inplace: false)
+          if inplace
+            Torch._feature_dropout_(input, p, training)
+          else
+            Torch._feature_dropout(input, p, training)
+          end
+        end
+
+        def alpha_dropout(input, p: 0.5, training: true, inplace: false)
+          if inplace
+            Torch._alpha_dropout_(input, p, training)
+          else
+            Torch._alpha_dropout(input, p, training)
+          end
+        end
+
+        def feature_alpha_dropout(input, p: 0.5, training: true, inplace: false)
+          if inplace
+            Torch._feature_alpha_dropout_(input, p, training)
+          else
+            Torch._feature_alpha_dropout(input, p, training)
+          end
+        end
+
+        private
+
+        def softmax_dim(ndim)
+          ndim == 0 || ndim == 1 || ndim == 3 ? 0 : 1
+        end
       end
     end
 
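A usage sketch for the expanded Functional API above, assuming Torch.randn and Torch.tensor behave as in the gem's README (inside the library F is an alias for Torch::NN::Functional; here it is assigned explicitly):

require "torch"

F = Torch::NN::Functional

x = Torch.randn(3, 5)

# activations
F.relu(x)                 # new inplace: option defaults to false
F.leaky_relu(x, 0.2)      # negative_slope is positional, per the signature above
F.softmax(x, dim: 1)

# dropout family; inplace: true would mutate x via the underscore-suffixed bindings
F.dropout(x, p: 0.5, training: true)

# cross_entropy composes log_softmax and nll_loss as shown above
target = Torch.tensor([0, 2, 4])
F.cross_entropy(x, target)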
data/lib/torch/nn/hinge_embedding_loss.rb
ADDED
@@ -0,0 +1,14 @@
+module Torch
+  module NN
+    class HingeEmbeddingLoss < Loss
+      def initialize(margin: 1.0, reduction: "mean")
+        super(reduction)
+        @margin = margin
+      end
+
+      def forward(input, target)
+        F.hinge_embedding_loss(input, target, margin: @margin, reduction: @reduction)
+      end
+    end
+  end
+end
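A sketch of the new loss module in use; it simply forwards to F.hinge_embedding_loss as shown above. The tensor values are illustrative, and target entries are expected to be 1 or -1:

require "torch"

criterion = Torch::NN::HingeEmbeddingLoss.new(margin: 1.0)

input  = Torch.randn(4)
target = Torch.tensor([1.0, -1.0, 1.0, -1.0])

loss = criterion.forward(input, target)  # scalar tensor, "mean" reduction by default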
data/lib/torch/nn/init.rb
CHANGED
@@ -2,7 +2,64 @@ module Torch
   module NN
     module Init
       class << self
-        def
+        def calculate_gain(nonlinearity, param: 0.01)
+          _calculate_gain(nonlinearity, param)
+        end
+
+        def uniform!(tensor, a: 0.0, b: 1.0)
+          _uniform!(tensor, a, b)
+        end
+
+        def normal!(tensor, mean: 0.0, std: 1.0)
+          _normal!(tensor, mean, std)
+        end
+
+        def constant!(tensor, val)
+          _constant!(tensor, val)
+        end
+
+        def ones!(tensor)
+          _ones!(tensor)
+        end
+
+        def zeros!(tensor)
+          _zeros!(tensor)
+        end
+
+        def eye!(tensor)
+          _eye!(tensor)
+        end
+
+        def dirac!(tensor)
+          _dirac!(tensor)
+        end
+
+        def xavier_uniform!(tensor, gain: 1.0)
+          _xavier_uniform!(tensor, gain)
+        end
+
+        def xavier_normal!(tensor, gain: 1.0)
+          _xavier_normal!(tensor, gain)
+        end
+
+        def kaiming_uniform!(tensor, a: 0, mode: "fan_in", nonlinearity: "leaky_relu")
+          _kaiming_uniform!(tensor, a, mode, nonlinearity)
+        end
+
+        def kaiming_normal!(tensor, a: 0, mode: "fan_in", nonlinearity: "leaky_relu")
+          _kaiming_normal!(tensor, a, mode, nonlinearity)
+        end
+
+        def orthogonal!(tensor, gain: 1)
+          _orthogonal!(tensor, gain)
+        end
+
+        def sparse!(tensor, sparsity, std: 0.01)
+          _sparse!(tensor, sparsity, std)
+        end
+
+        # TODO move to C++ when released
+        def _calculate_fan_in_and_fan_out(tensor)
           dimensions = tensor.dim
           if dimensions < 2
             raise Error, "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
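These wrappers delegate to the underscore-prefixed C++ bindings; a sketch that mirrors what Linear#reset_parameters does further down. Torch.zeros as a stand-in for allocating a weight tensor, and the fan_in/fan_out ordering, are assumptions not shown in this diff:

require "torch"

init = Torch::NN::Init

w = Torch.zeros(5, 3)                       # stand-in for an uninitialized weight
init.kaiming_uniform!(w, a: Math.sqrt(5))

fan_in, _fan_out = init._calculate_fan_in_and_fan_out(w)  # presumably 3 and 5 for a 5x3 weight
bound = 1 / Math.sqrt(fan_in)

b = Torch.zeros(5)
init.uniform!(b, a: -bound, b: bound)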
data/lib/torch/nn/leaky_relu.rb
ADDED
@@ -0,0 +1,20 @@
+module Torch
+  module NN
+    class LeakyReLU < Module
+      def initialize(negative_slope: 1e-2) #, inplace: false)
+        super()
+        @negative_slope = negative_slope
+        # @inplace = inplace
+      end
+
+      def forward(input)
+        F.leaky_relu(input, @negative_slope) #, inplace: @inplace)
+      end
+
+      def extra_inspect
+        inplace_str = @inplace ? ", inplace: true" : ""
+        format("negative_slope: %s%s", @negative_slope, inplace_str)
+      end
+    end
+  end
+end
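A usage sketch for the new module: forward goes through F.leaky_relu, and inplace support is stubbed out in this release as the commented lines show (Torch.tensor per the README):

require "torch"

layer = Torch::NN::LeakyReLU.new(negative_slope: 0.2)

out = layer.forward(Torch.tensor([-1.0, 0.0, 2.0]))
# negative values are scaled by 0.2, so roughly [-0.2, 0.0, 2.0]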
data/lib/torch/nn/linear.rb
CHANGED
@@ -1,35 +1,36 @@
 module Torch
   module NN
     class Linear < Module
-      attr_reader :bias, :weight
-
       def initialize(in_features, out_features, bias: true)
+        super()
         @in_features = in_features
         @out_features = out_features
 
         @weight = Parameter.new(Tensor.new(out_features, in_features))
         if bias
           @bias = Parameter.new(Tensor.new(out_features))
+        else
+          register_parameter("bias", nil)
         end
 
         reset_parameters
       end
 
-      def call(input)
-        F.linear(input, @weight, @bias)
-      end
-
      def reset_parameters
-        Init.
+        Init.kaiming_uniform!(@weight, a: Math.sqrt(5))
         if @bias
-          fan_in, _ = Init.
+          fan_in, _ = Init._calculate_fan_in_and_fan_out(@weight)
           bound = 1 / Math.sqrt(fan_in)
-          Init.
+          Init.uniform!(@bias, a: -bound, b: bound)
         end
       end
 
-      def
-
+      def forward(input)
+        F.linear(input, @weight, @bias)
+      end
+
+      def extra_inspect
+        format("in_features: %s, out_features: %s, bias: %s", @in_features, @out_features, !@bias.nil?)
       end
     end
   end
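With call replaced by forward and Kaiming-uniform initialization in reset_parameters, Linear now lines up with its PyTorch counterpart. A sketch assuming Torch.randn as in the README; whether Module#call dispatches to forward in 0.1.5 is not shown in this diff, so forward is called directly:

require "torch"

layer = Torch::NN::Linear.new(10, 2)
# extra_inspect presumably feeds the module's inspect output,
# e.g. "in_features: 10, out_features: 2, bias: true"

x = Torch.randn(4, 10)
out = layer.forward(x)          # shape [4, 2]

no_bias = Torch::NN::Linear.new(10, 2, bias: false)  # registers the bias parameter as nil instead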