SM2ST 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sm2st-0.0.1/LICENSE.txt +21 -0
- sm2st-0.0.1/PKG-INFO +17 -0
- sm2st-0.0.1/README.md +3 -0
- sm2st-0.0.1/SM2ST/SMLED.py +332 -0
- sm2st-0.0.1/SM2ST/Train_SMLED.py +363 -0
- sm2st-0.0.1/SM2ST/__init__.py +15 -0
- sm2st-0.0.1/SM2ST/dataset.py +85 -0
- sm2st-0.0.1/SM2ST/gatv2_conv.py +213 -0
- sm2st-0.0.1/SM2ST/rectification.py +204 -0
- sm2st-0.0.1/SM2ST/utils.py +447 -0
- sm2st-0.0.1/SM2ST.egg-info/PKG-INFO +17 -0
- sm2st-0.0.1/SM2ST.egg-info/SOURCES.txt +14 -0
- sm2st-0.0.1/SM2ST.egg-info/dependency_links.txt +1 -0
- sm2st-0.0.1/SM2ST.egg-info/top_level.txt +1 -0
- sm2st-0.0.1/setup.cfg +4 -0
- sm2st-0.0.1/setup.py +22 -0
sm2st-0.0.1/LICENSE.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2024 Lixian Lin
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
sm2st-0.0.1/PKG-INFO
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: SM2ST
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: SM2ST: Automatic registration of spatial metabolome and spatial transcriptome via adversarial autoencoders
|
|
5
|
+
Home-page: https://github.com/binbin-coder/SM2ST
|
|
6
|
+
Author: LLX
|
|
7
|
+
Author-email: llx_1910@163.com
|
|
8
|
+
Classifier: Programming Language :: Python :: 3
|
|
9
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
10
|
+
Classifier: Operating System :: OS Independent
|
|
11
|
+
Requires-Python: >=3.8
|
|
12
|
+
Description-Content-Type: text/markdown
|
|
13
|
+
License-File: LICENSE.txt
|
|
14
|
+
|
|
15
|
+
SM2ST package. See the project repository
|
|
16
|
+
(https://github.com/binbin-coder/SM2ST)
|
|
17
|
+
for usage instructions.
|
sm2st-0.0.1/SM2ST/SMLED.py
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
import torch.backends.cudnn as cudnn
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
import random
|
|
8
|
+
# from .gatv2_conv_or import GATv2Conv as GATConv
|
|
9
|
+
from torch.nn.utils import spectral_norm
|
|
10
|
+
|
|
11
|
+
class encoding_mask_noise(torch.nn.Module):
    """Randomly mask rows (nodes) of a feature matrix with a learnable token.

    A fraction ``mask_rate`` of the rows of ``x`` is selected for masking.
    Of those, a fraction ``replace_rate`` is overwritten with the features of
    randomly chosen rows of ``x`` ("noise" rows). The remaining masked rows
    are zeroed and then filled with the learnable ``enc_mask_token``.

    Args:
        hidden_dims: ``[in_dim, num_hidden, out_dim]``; only ``in_dim`` (the
            width of the token, i.e. the feature width of ``x``) is used here.
    """

    def __init__(self, hidden_dims):
        super(encoding_mask_noise, self).__init__()
        [in_dim, num_hidden, out_dim] = hidden_dims
        # Learnable replacement vector, broadcast over every token row.
        self.enc_mask_token = nn.Parameter(torch.zeros(size=(1, in_dim)))
        self.reset_parameters_for_token()

    def reset_parameters_for_token(self):
        """Re-initialise the mask token with Xavier-normal (gain 1.414)."""
        nn.init.xavier_normal_(self.enc_mask_token.data, gain=1.414)

    def forward(self, x, mask_rate=0.5, replace_rate=0.05):
        """Mask a random subset of rows of ``x``.

        Args:
            x: tensor indexed by row along dim 0 (presumably
                ``(num_nodes, in_dim)`` -- TODO confirm with callers).
            mask_rate: fraction of rows to mask.
            replace_rate: fraction of the masked rows that receive the
                features of random rows of ``x`` instead of the mask token.

        Returns:
            Tuple ``(out_x, mask_nodes, keep_nodes)``:
            ``out_x`` is a masked copy of ``x`` (``x`` itself is not
            modified), ``mask_nodes`` holds the indices of all masked rows
            (token and noise), and ``keep_nodes`` holds the untouched rows.
        """
        num_nodes = x.size(0)
        perm = torch.randperm(num_nodes, device=x.device)

        # random masking
        num_mask_nodes = int(mask_rate * num_nodes)
        mask_nodes = perm[:num_mask_nodes]
        keep_nodes = perm[num_mask_nodes:]

        out_x = x.clone()
        if replace_rate > 0.0:
            num_noise_nodes = int(replace_rate * num_mask_nodes)
            # Split on an explicit token count. Slicing with -num_noise_nodes
            # breaks when that count is 0: [:-0] is empty and [-0:] is
            # everything. Every masked row would then become a "noise" row,
            # and the assignment below would fail with a shape mismatch.
            num_token_nodes = num_mask_nodes - num_noise_nodes
            perm_mask = torch.randperm(num_mask_nodes, device=x.device)
            token_nodes = mask_nodes[perm_mask[:num_token_nodes]]
            noise_nodes = mask_nodes[perm_mask[num_token_nodes:]]
            noise_to_be_chosen = torch.randperm(num_nodes, device=x.device)[:num_noise_nodes]

            out_x[token_nodes] = 0.0
            out_x[noise_nodes] = x[noise_to_be_chosen]
        else:
            token_nodes = mask_nodes
            out_x[mask_nodes] = 0.0

        # Zeroed token rows become exactly the (broadcast) learnable token.
        out_x[token_nodes] += self.enc_mask_token
        return out_x, mask_nodes, keep_nodes
|
|
52
|
+
|
|
53
|
+
class random_remask(torch.nn.Module):
    """Replace a random subset of representation rows with a learnable token.

    Args:
        hidden_dims: three-element sequence ``[in_dim, num_hidden, out_dim]``;
            only ``out_dim`` (the width of the token) is used.
    """

    def __init__(self, hidden_dims):
        super().__init__()
        _in_dim, _num_hidden, out_dim = hidden_dims
        # Learnable row vector written into every re-masked row.
        self.dec_mask_token = nn.Parameter(torch.zeros(size=(1, out_dim)))
        self.reset_parameters_for_token()

    def reset_parameters_for_token(self):
        """Draw ``dec_mask_token`` from a Xavier-normal init with gain 1.414."""
        nn.init.xavier_normal_(self.dec_mask_token.data, gain=1.414)

    def forward(self, rep, remask_rate=0.5):
        """Re-mask ``int(remask_rate * len(rep))`` randomly chosen rows.

        Returns:
            ``(out_rep, remask_nodes, rekeep_nodes)``: a copy of ``rep`` whose
            selected rows equal the token, the selected row indices, and the
            indices of rows left unchanged.
        """
        total = rep.size()[0]
        shuffled = torch.randperm(total, device=rep.device)
        cut = int(remask_rate * total)
        remask_nodes, rekeep_nodes = shuffled[:cut], shuffled[cut:]

        masked = rep.clone()
        masked[remask_nodes] = 0.0
        masked[remask_nodes] += self.dec_mask_token
        return masked, remask_nodes, rekeep_nodes
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# class Encoder(nn.Module):
|
|
78
|
+
# def __init__(self, mz_number, X_dim):
|
|
79
|
+
# super(Encoder, self).__init__()
|
|
80
|
+
# # self.encoding_mask_noise = encoding_mask_noise(hidden_dims)
|
|
81
|
+
# # self.random_remask = random_remask(hidden_dims)
|
|
82
|
+
# self.fc1 = nn.Linear(mz_number, 1024)
|
|
83
|
+
# self.fc1_bn = nn.BatchNorm1d(1024)
|
|
84
|
+
# self.fc2 = nn.Linear(1024, 256)
|
|
85
|
+
# self.fc2_bn = nn.BatchNorm1d(256)
|
|
86
|
+
# self.fc3 = nn.Linear(256, 64)
|
|
87
|
+
# self.fc3_bn = nn.BatchNorm1d(64)
|
|
88
|
+
# self.fc4 = nn.Linear(64, 8)
|
|
89
|
+
# self.fc4_bn = nn.BatchNorm1d(8)
|
|
90
|
+
# self.fc5 = nn.Linear(8, X_dim)
|
|
91
|
+
# # Initialize parameters
|
|
92
|
+
# self.init_weights()
|
|
93
|
+
|
|
94
|
+
# def init_weights(self):
|
|
95
|
+
# gain = nn.init.calculate_gain('relu')
|
|
96
|
+
# # Initialize weights and biases for all linear layers
|
|
97
|
+
# for module in self.modules():
|
|
98
|
+
# if isinstance(module, nn.Linear):
|
|
99
|
+
# # Use the Xavier initialization method to specify the gain value
|
|
100
|
+
# nn.init.xavier_uniform_(module.weight, gain=gain)
|
|
101
|
+
# if module.bias is not None:
|
|
102
|
+
# # Initialize the bias to 0
|
|
103
|
+
# nn.init.zeros_(module.bias)
|
|
104
|
+
|
|
105
|
+
# def forward(self, features, relu=False, mask = 0.0):
|
|
106
|
+
# if mask:
|
|
107
|
+
# mask_tensor = torch.bernoulli(torch.full_like(features, mask)).to(features.device) # Random mask with 50% probability
|
|
108
|
+
# features = features * mask_tensor # Apply mask
|
|
109
|
+
# h1 = F.relu(self.fc1_bn(self.fc1(features)))
|
|
110
|
+
# h2 = F.relu(self.fc2_bn(self.fc2(h1)))
|
|
111
|
+
# h3 = F.relu(self.fc3_bn(self.fc3(h2)))
|
|
112
|
+
# h4 = F.relu(self.fc4_bn(self.fc4(h3)))
|
|
113
|
+
# if relu:
|
|
114
|
+
# return F.relu(self.fc5(h4))
|
|
115
|
+
# else:
|
|
116
|
+
# return self.fc5(h4)
|
|
117
|
+
|
|
118
|
+
class Encoder(nn.Module):
    """Fully connected encoder from ``mz_number`` features to ``X_dim`` outputs.

    Four hidden stages of width 1024, 256, 64 and 16, each computed as
    ``Dropout(ReLU(BatchNorm1d(Linear(h))))`` with drop probability
    ``down_ratio``, followed by a final ``Linear(16, X_dim)``.
    """

    def __init__(self, mz_number, X_dim, down_ratio):
        super().__init__()
        self.dropout_rate = down_ratio

        # Registers fc{i}, fc{i}_bn, dropout{i} for i = 1..4, stage by stage.
        # The names and registration order fix the state_dict keys and the
        # order in which init_weights visits the Linear layers.
        widths = (mz_number, 1024, 256, 64, 16)
        for i, (fan_in, fan_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"fc{i}", nn.Linear(fan_in, fan_out))
            setattr(self, f"fc{i}_bn", nn.BatchNorm1d(fan_out))
            setattr(self, f"dropout{i}", nn.Dropout(self.dropout_rate))
        self.fc5 = nn.Linear(widths[-1], X_dim)

        self.init_weights()

    def init_weights(self):
        """Give every Linear Xavier-uniform weights (ReLU gain) and zero bias."""
        relu_gain = nn.init.calculate_gain('relu')
        for layer in filter(lambda mod: isinstance(mod, nn.Linear), self.modules()):
            nn.init.xavier_uniform_(layer.weight, gain=relu_gain)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, features, relu=False):
        """Encode ``features``; apply a final ReLU to the output when ``relu``."""
        hidden = features
        for i in range(1, 5):
            linear = getattr(self, f"fc{i}")
            norm = getattr(self, f"fc{i}_bn")
            drop = getattr(self, f"dropout{i}")
            hidden = drop(F.relu(norm(linear(hidden))))
        out = self.fc5(hidden)
        return F.relu(out) if relu else out
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
# class Decoder(nn.Module):
|
|
177
|
+
# def __init__(self, mz_number, X_dim):
|
|
178
|
+
# super(Decoder, self).__init__()
|
|
179
|
+
# self.fc6 = nn.Linear(X_dim, 8)
|
|
180
|
+
# self.fc6_bn = nn.BatchNorm1d(8)
|
|
181
|
+
# self.fc7 = nn.Linear(8, 64)
|
|
182
|
+
# self.fc7_bn = nn.BatchNorm1d(64)
|
|
183
|
+
# self.fc8 = nn.Linear(64, 256)
|
|
184
|
+
# self.fc8_bn = nn.BatchNorm1d(256)
|
|
185
|
+
# self.fc9 = nn.Linear(256, 1024)
|
|
186
|
+
# self.fc9_bn = nn.BatchNorm1d(1024)
|
|
187
|
+
# self.fc10 = nn.Linear(1024, mz_number)
|
|
188
|
+
# # Initialize parameters
|
|
189
|
+
# self.init_weights()
|
|
190
|
+
|
|
191
|
+
# def init_weights(self):
|
|
192
|
+
# # Initialize weights and biases for all linear layers
|
|
193
|
+
# gain = nn.init.calculate_gain('relu')
|
|
194
|
+
# for module in self.modules():
|
|
195
|
+
# if isinstance(module, nn.Linear):
|
|
196
|
+
# # Use the Xavier initialization method to specify the gain value
|
|
197
|
+
# nn.init.xavier_uniform_(module.weight, gain=gain)
|
|
198
|
+
# if module.bias is not None:
|
|
199
|
+
# # Initialize the bias to 0
|
|
200
|
+
# nn.init.zeros_(module.bias)
|
|
201
|
+
|
|
202
|
+
# def forward(self, z, relu=False):
|
|
203
|
+
# h6 = F.relu(self.fc6_bn(self.fc6(z)))
|
|
204
|
+
# h7 = F.relu(self.fc7_bn(self.fc7(h6)))
|
|
205
|
+
# h8 = F.relu(self.fc8_bn(self.fc8(h7)))
|
|
206
|
+
# h9 = F.relu(self.fc9_bn(self.fc9(h8)))
|
|
207
|
+
# if relu:
|
|
208
|
+
# return F.relu(self.fc10(h9))
|
|
209
|
+
# else:
|
|
210
|
+
# return self.fc10(h9)
|
|
211
|
+
|
|
212
|
+
class Decoder(nn.Module):
    """Fully connected decoder from ``X_dim`` inputs back to ``mz_number`` features.

    Four hidden stages of width 16, 64, 256 and 1024, each computed as
    ``Dropout(ReLU(BatchNorm1d(Linear(h))))`` with drop probability
    ``down_ratio``, followed by a final ``Linear(1024, mz_number)``.
    """

    def __init__(self, mz_number, X_dim, down_ratio):
        super().__init__()
        self.dropout_rate = down_ratio

        # Registers fc{i}, fc{i}_bn, dropout{i} for i = 6..9, stage by stage.
        # The names and registration order fix the state_dict keys and the
        # order in which init_weights visits the Linear layers.
        widths = (X_dim, 16, 64, 256, 1024)
        for i, (fan_in, fan_out) in enumerate(zip(widths, widths[1:]), start=6):
            setattr(self, f"fc{i}", nn.Linear(fan_in, fan_out))
            setattr(self, f"fc{i}_bn", nn.BatchNorm1d(fan_out))
            setattr(self, f"dropout{i}", nn.Dropout(self.dropout_rate))
        self.fc10 = nn.Linear(widths[-1], mz_number)

        self.init_weights()

    def init_weights(self):
        """Give every Linear Xavier-uniform weights (ReLU gain) and zero bias."""
        relu_gain = nn.init.calculate_gain('relu')
        for layer in filter(lambda mod: isinstance(mod, nn.Linear), self.modules()):
            nn.init.xavier_uniform_(layer.weight, gain=relu_gain)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, z, relu=False):
        """Decode ``z``; apply a final ReLU to the output when ``relu``."""
        hidden = z
        for i in range(6, 10):
            linear = getattr(self, f"fc{i}")
            norm = getattr(self, f"fc{i}_bn")
            drop = getattr(self, f"dropout{i}")
            hidden = drop(F.relu(norm(linear(hidden))))
        out = self.fc10(hidden)
        return F.relu(out) if relu else out
|
|
266
|
+
|
|
267
|
+
class Discriminator_A(torch.nn.Module):
|
|
268
|
+
def __init__(self, X_dim):
|
|
269
|
+
super(Discriminator_A, self).__init__()
|
|
270
|
+
self.fc = torch.nn.Sequential(
|
|
271
|
+
spectral_norm(nn.Linear(X_dim, 128)),# last best
|
|
272
|
+
nn.LeakyReLU(0.2),
|
|
273
|
+
spectral_norm(nn.Linear(128, 32)),
|
|
274
|
+
nn.LeakyReLU(0.2),
|
|
275
|
+
spectral_norm(nn.Linear(32, 8)),
|
|
276
|
+
nn.LeakyReLU(0.2),
|
|
277
|
+
spectral_norm(nn.Linear(8, 1)),
|
|
278
|
+
nn.Sigmoid()
|
|
279
|
+
# nn.Linear(X_dim, 64),
|
|
280
|
+
# nn.LeakyReLU(0.2),
|
|
281
|
+
# nn.Linear(64, 8),
|
|
282
|
+
# nn.LeakyReLU(0.2),
|
|
283
|
+
# nn.Linear(8, 1),
|
|
284
|
+
# nn.Sigmoid()
|
|
285
|
+
)
|
|
286
|
+
self.init_weights()
|
|
287
|
+
|
|
288
|
+
def init_weights(self):
|
|
289
|
+
gain = nn.init.calculate_gain('leaky_relu', 0.2)
|
|
290
|
+
# Initialize weights and biases for all linear layers
|
|
291
|
+
for module in self.modules():
|
|
292
|
+
if isinstance(module, nn.Linear):
|
|
293
|
+
# Use the Xavier initialization method to specify the gain value
|
|
294
|
+
nn.init.xavier_uniform_(module.weight, gain=gain)
|
|
295
|
+
if module.bias is not None:
|
|
296
|
+
# Initialize the bias to 0
|
|
297
|
+
nn.init.zeros_(module.bias)
|
|
298
|
+
def forward(self, x):
|
|
299
|
+
return self.fc(x)
|
|
300
|
+
|
|
301
|
+
class Discriminator_B(torch.nn.Module):
    """Plain MLP returning one unbounded score per input row.

    Layout of ``self.fc``: Linear(X_dim, 512), LeakyReLU(0.2),
    Linear(512, 128), LeakyReLU(0.2), Linear(128, 32), LeakyReLU(0.2),
    Linear(32, 1). There is no output activation.
    """

    def __init__(self, X_dim):
        super().__init__()
        widths = (X_dim, 512, 128, 32, 1)
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            # Separate consecutive Linear layers with a LeakyReLU.
            if layers:
                layers.append(nn.LeakyReLU(0.2))
            layers.append(nn.Linear(fan_in, fan_out))
        self.fc = torch.nn.Sequential(*layers)
        self.init_weights()

    def init_weights(self):
        """Xavier-uniform weights (LeakyReLU(0.2) gain) and zero bias on each Linear."""
        leaky_gain = nn.init.calculate_gain('leaky_relu', 0.2)
        for layer in filter(lambda mod: isinstance(mod, nn.Linear), self.modules()):
            nn.init.xavier_uniform_(layer.weight, gain=leaky_gain)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return the raw score for each row of ``x``."""
        return self.fc(x)
|