SM2ST 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- SM2ST/SMLED.py +332 -0
- SM2ST/Train_SMLED.py +363 -0
- SM2ST/__init__.py +15 -0
- SM2ST/dataset.py +85 -0
- SM2ST/gatv2_conv.py +213 -0
- SM2ST/rectification.py +204 -0
- SM2ST/utils.py +447 -0
- sm2st-0.0.1.dist-info/LICENSE.txt +21 -0
- sm2st-0.0.1.dist-info/METADATA +17 -0
- sm2st-0.0.1.dist-info/RECORD +12 -0
- sm2st-0.0.1.dist-info/WHEEL +5 -0
- sm2st-0.0.1.dist-info/top_level.txt +1 -0
SM2ST/SMLED.py
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
import torch.backends.cudnn as cudnn
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
import random
|
|
8
|
+
# from .gatv2_conv_or import GATv2Conv as GATConv
|
|
9
|
+
from torch.nn.utils import spectral_norm
|
|
10
|
+
|
|
11
|
+
class encoding_mask_noise(torch.nn.Module):
    """Randomly mask rows of a feature matrix with a learnable token.

    A ``mask_rate`` fraction of the rows of ``x`` is selected at random.
    Most selected rows are zeroed and then have ``enc_mask_token`` added.
    When ``replace_rate > 0``, a ``replace_rate`` fraction of the selected
    rows is instead overwritten with the features of randomly chosen rows
    of ``x``.

    Parameters
    ----------
    hidden_dims
        Sequence ``[in_dim, num_hidden, out_dim]``. Only ``in_dim``, the
        width of the token, is used here.
    """

    def __init__(self, hidden_dims):
        super(encoding_mask_noise, self).__init__()
        [in_dim, num_hidden, out_dim] = hidden_dims
        self.enc_mask_token = nn.Parameter(torch.zeros(size=(1, in_dim)))
        self.reset_parameters_for_token()

    def reset_parameters_for_token(self):
        """Re-initialise the mask token with Xavier-normal values (gain 1.414)."""
        nn.init.xavier_normal_(self.enc_mask_token.data, gain=1.414)

    def forward(self, x, mask_rate=0.5, replace_rate=0.05):
        """Mask rows of ``x``.

        Parameters
        ----------
        x
            Tensor whose first dimension indexes rows (nodes / spots).
        mask_rate
            Fraction of rows to mask.
        replace_rate
            Fraction of the masked rows that receive another row's
            features instead of the mask token.

        Returns
        -------
        (out_x, mask_nodes, keep_nodes)
            The masked copy of ``x``, the indices of masked rows, and the
            indices of untouched rows.
        """
        num_nodes = x.size()[0]
        perm = torch.randperm(num_nodes, device=x.device)
        num_mask_nodes = int(mask_rate * num_nodes)
        mask_nodes = perm[:num_mask_nodes]
        keep_nodes = perm[num_mask_nodes:]

        out_x = x.clone()
        if replace_rate > 0.0:
            num_noise_nodes = int(replace_rate * num_mask_nodes)
            # Split at a positive boundary. Slicing with ``-num_noise_nodes``
            # breaks when it is 0: ``[:-0]`` is empty and ``[-0:]`` is
            # everything, which turned every masked row into a noise row and
            # then failed assigning the empty ``x[noise_to_be_chosen]``.
            num_token_nodes = num_mask_nodes - num_noise_nodes
            perm_mask = torch.randperm(num_mask_nodes, device=x.device)
            token_nodes = mask_nodes[perm_mask[:num_token_nodes]]
            noise_nodes = mask_nodes[perm_mask[num_token_nodes:]]
            noise_to_be_chosen = torch.randperm(num_nodes, device=x.device)[:num_noise_nodes]

            out_x[token_nodes] = 0.0
            # Noise rows copy features from the *unmasked* input.
            out_x[noise_nodes] = x[noise_to_be_chosen]
        else:
            token_nodes = mask_nodes
            out_x[mask_nodes] = 0.0

        out_x[token_nodes] += self.enc_mask_token
        return out_x, mask_nodes, keep_nodes
|
|
52
|
+
|
|
53
|
+
class random_remask(torch.nn.Module):
    """Replace a random subset of representation rows with a learnable token.

    Parameters
    ----------
    hidden_dims
        Sequence ``[in_dim, num_hidden, out_dim]``. Only ``out_dim``, the
        width of the decoder mask token, is used.
    """

    def __init__(self, hidden_dims):
        super(random_remask, self).__init__()
        in_dim, num_hidden, out_dim = hidden_dims
        self.dec_mask_token = nn.Parameter(torch.zeros(size=(1, out_dim)))
        self.reset_parameters_for_token()

    def reset_parameters_for_token(self):
        """Draw the decoder mask token from a Xavier-normal (gain 1.414)."""
        nn.init.xavier_normal_(self.dec_mask_token.data, gain=1.414)

    def forward(self, rep, remask_rate=0.5):
        """Return ``(out_rep, remask_nodes, rekeep_nodes)``.

        ``int(remask_rate * rep.size(0))`` randomly chosen rows of a copy of
        ``rep`` are zeroed and have ``dec_mask_token`` added. The remaining
        row indices are returned as ``rekeep_nodes``.
        """
        n_rows = rep.size()[0]
        order = torch.randperm(n_rows, device=rep.device)
        cut = int(remask_rate * n_rows)
        remask_nodes, rekeep_nodes = order[:cut], order[cut:]

        out_rep = rep.clone()
        out_rep[remask_nodes] = 0.0
        out_rep[remask_nodes] += self.dec_mask_token
        return out_rep, remask_nodes, rekeep_nodes
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
# class Encoder(nn.Module):
|
|
78
|
+
# def __init__(self, mz_number, X_dim):
|
|
79
|
+
# super(Encoder, self).__init__()
|
|
80
|
+
# # self.encoding_mask_noise = encoding_mask_noise(hidden_dims)
|
|
81
|
+
# # self.random_remask = random_remask(hidden_dims)
|
|
82
|
+
# self.fc1 = nn.Linear(mz_number, 1024)
|
|
83
|
+
# self.fc1_bn = nn.BatchNorm1d(1024)
|
|
84
|
+
# self.fc2 = nn.Linear(1024, 256)
|
|
85
|
+
# self.fc2_bn = nn.BatchNorm1d(256)
|
|
86
|
+
# self.fc3 = nn.Linear(256, 64)
|
|
87
|
+
# self.fc3_bn = nn.BatchNorm1d(64)
|
|
88
|
+
# self.fc4 = nn.Linear(64, 8)
|
|
89
|
+
# self.fc4_bn = nn.BatchNorm1d(8)
|
|
90
|
+
# self.fc5 = nn.Linear(8, X_dim)
|
|
91
|
+
# # Initialize parameters
|
|
92
|
+
# self.init_weights()
|
|
93
|
+
|
|
94
|
+
# def init_weights(self):
|
|
95
|
+
# gain = nn.init.calculate_gain('relu')
|
|
96
|
+
# # Initialize weights and biases for all linear layers
|
|
97
|
+
# for module in self.modules():
|
|
98
|
+
# if isinstance(module, nn.Linear):
|
|
99
|
+
# # Use the Xavier initialization method to specify the gain value
|
|
100
|
+
# nn.init.xavier_uniform_(module.weight, gain=gain)
|
|
101
|
+
# if module.bias is not None:
|
|
102
|
+
# # Initialize the bias to 0
|
|
103
|
+
# nn.init.zeros_(module.bias)
|
|
104
|
+
|
|
105
|
+
# def forward(self, features, relu=False, mask = 0.0):
|
|
106
|
+
# if mask:
|
|
107
|
+
# mask_tensor = torch.bernoulli(torch.full_like(features, mask)).to(features.device) # Random mask with 50% probability
|
|
108
|
+
# features = features * mask_tensor # Apply mask
|
|
109
|
+
# h1 = F.relu(self.fc1_bn(self.fc1(features)))
|
|
110
|
+
# h2 = F.relu(self.fc2_bn(self.fc2(h1)))
|
|
111
|
+
# h3 = F.relu(self.fc3_bn(self.fc3(h2)))
|
|
112
|
+
# h4 = F.relu(self.fc4_bn(self.fc4(h3)))
|
|
113
|
+
# if relu:
|
|
114
|
+
# return F.relu(self.fc5(h4))
|
|
115
|
+
# else:
|
|
116
|
+
# return self.fc5(h4)
|
|
117
|
+
|
|
118
|
+
class Encoder(nn.Module):
    """MLP encoder from an ``mz_number``-wide profile to ``X_dim`` outputs.

    Four hidden stages of widths 1024, 256, 64 and 16. Each stage is
    ``Linear -> BatchNorm1d -> ReLU -> Dropout(down_ratio)``, stored as
    ``fc{i}``, ``fc{i}_bn`` and ``dropout{i}`` for i = 1..4. A final
    projection ``fc5`` follows. Linear weights are Xavier-uniform
    initialised with the ReLU gain, and biases are zeroed.
    """

    def __init__(self, mz_number, X_dim, down_ratio):
        super(Encoder, self).__init__()
        self.dropout_rate = down_ratio

        widths = [mz_number, 1024, 256, 64, 16]
        # Register layers in the same order as an explicit listing, so that
        # parameter names and random-number consumption are unchanged.
        for stage, (width_in, width_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f'fc{stage}', nn.Linear(width_in, width_out))
            setattr(self, f'fc{stage}_bn', nn.BatchNorm1d(width_out))
            setattr(self, f'dropout{stage}', nn.Dropout(self.dropout_rate))

        self.fc5 = nn.Linear(widths[-1], X_dim)

        self.init_weights()

    def init_weights(self):
        """Xavier-uniform (ReLU gain) weights and zero biases for every Linear."""
        relu_gain = nn.init.calculate_gain('relu')
        linear_layers = [m for m in self.modules() if isinstance(m, nn.Linear)]
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight, gain=relu_gain)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, features, relu=False):
        """Encode ``features``; apply a final ReLU when ``relu`` is truthy."""
        hidden = features
        for stage in range(1, 5):
            linear = getattr(self, f'fc{stage}')
            norm = getattr(self, f'fc{stage}_bn')
            drop = getattr(self, f'dropout{stage}')
            hidden = drop(F.relu(norm(linear(hidden))))
        output = self.fc5(hidden)
        return F.relu(output) if relu else output
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
# class Decoder(nn.Module):
|
|
177
|
+
# def __init__(self, mz_number, X_dim):
|
|
178
|
+
# super(Decoder, self).__init__()
|
|
179
|
+
# self.fc6 = nn.Linear(X_dim, 8)
|
|
180
|
+
# self.fc6_bn = nn.BatchNorm1d(8)
|
|
181
|
+
# self.fc7 = nn.Linear(8, 64)
|
|
182
|
+
# self.fc7_bn = nn.BatchNorm1d(64)
|
|
183
|
+
# self.fc8 = nn.Linear(64, 256)
|
|
184
|
+
# self.fc8_bn = nn.BatchNorm1d(256)
|
|
185
|
+
# self.fc9 = nn.Linear(256, 1024)
|
|
186
|
+
# self.fc9_bn = nn.BatchNorm1d(1024)
|
|
187
|
+
# self.fc10 = nn.Linear(1024, mz_number)
|
|
188
|
+
# # Initialize parameters
|
|
189
|
+
# self.init_weights()
|
|
190
|
+
|
|
191
|
+
# def init_weights(self):
|
|
192
|
+
# # Initialize weights and biases for all linear layers
|
|
193
|
+
# gain = nn.init.calculate_gain('relu')
|
|
194
|
+
# for module in self.modules():
|
|
195
|
+
# if isinstance(module, nn.Linear):
|
|
196
|
+
# # Use the Xavier initialization method to specify the gain value
|
|
197
|
+
# nn.init.xavier_uniform_(module.weight, gain=gain)
|
|
198
|
+
# if module.bias is not None:
|
|
199
|
+
# # Initialize the bias to 0
|
|
200
|
+
# nn.init.zeros_(module.bias)
|
|
201
|
+
|
|
202
|
+
# def forward(self, z, relu=False):
|
|
203
|
+
# h6 = F.relu(self.fc6_bn(self.fc6(z)))
|
|
204
|
+
# h7 = F.relu(self.fc7_bn(self.fc7(h6)))
|
|
205
|
+
# h8 = F.relu(self.fc8_bn(self.fc8(h7)))
|
|
206
|
+
# h9 = F.relu(self.fc9_bn(self.fc9(h8)))
|
|
207
|
+
# if relu:
|
|
208
|
+
# return F.relu(self.fc10(h9))
|
|
209
|
+
# else:
|
|
210
|
+
# return self.fc10(h9)
|
|
211
|
+
|
|
212
|
+
class Decoder(nn.Module):
    """MLP decoder from ``X_dim`` inputs back to an ``mz_number``-wide profile.

    Four hidden stages of widths 16, 64, 256 and 1024. Each stage is
    ``Linear -> BatchNorm1d -> ReLU -> Dropout(down_ratio)``, stored as
    ``fc{i}``, ``fc{i}_bn`` and ``dropout{i}`` for i = 6..9. A final
    projection ``fc10`` follows. Linear weights are Xavier-uniform
    initialised with the ReLU gain, and biases are zeroed.
    """

    def __init__(self, mz_number, X_dim, down_ratio):
        super(Decoder, self).__init__()
        self.dropout_rate = down_ratio

        widths = [X_dim, 16, 64, 256, 1024]
        # Same registration order as an explicit listing (fc6, fc6_bn,
        # dropout6, fc7, ...), keeping parameter names and RNG use unchanged.
        for stage, (width_in, width_out) in enumerate(zip(widths, widths[1:]), start=6):
            setattr(self, f'fc{stage}', nn.Linear(width_in, width_out))
            setattr(self, f'fc{stage}_bn', nn.BatchNorm1d(width_out))
            setattr(self, f'dropout{stage}', nn.Dropout(self.dropout_rate))

        self.fc10 = nn.Linear(widths[-1], mz_number)

        self.init_weights()

    def init_weights(self):
        """Xavier-uniform (ReLU gain) weights and zero biases for every Linear."""
        relu_gain = nn.init.calculate_gain('relu')
        linear_layers = [m for m in self.modules() if isinstance(m, nn.Linear)]
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight, gain=relu_gain)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, z, relu=False):
        """Decode ``z``; apply a final ReLU when ``relu`` is truthy."""
        hidden = z
        for stage in range(6, 10):
            linear = getattr(self, f'fc{stage}')
            norm = getattr(self, f'fc{stage}_bn')
            drop = getattr(self, f'dropout{stage}')
            hidden = drop(F.relu(norm(linear(hidden))))
        output = self.fc10(hidden)
        return F.relu(output) if relu else output
|
|
266
|
+
|
|
267
|
+
class Discriminator_A(torch.nn.Module):
    """Spectrally-normalised MLP scoring ``X_dim``-wide inputs with a sigmoid.

    Layer widths are X_dim -> 128 -> 32 -> 8 -> 1, with LeakyReLU(0.2)
    between layers and a final Sigmoid.
    """

    def __init__(self, X_dim):
        super(Discriminator_A, self).__init__()
        self.fc = torch.nn.Sequential(
            spectral_norm(nn.Linear(X_dim, 128)),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Linear(128, 32)),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Linear(32, 8)),
            nn.LeakyReLU(0.2),
            spectral_norm(nn.Linear(8, 1)),
            nn.Sigmoid()
        )
        self.init_weights()

    def init_weights(self):
        """Xavier-uniform (LeakyReLU 0.2 gain) weights and zero biases."""
        gain = nn.init.calculate_gain('leaky_relu', 0.2)
        for module in self.modules():
            if isinstance(module, nn.Linear):
                # spectral_norm re-registers the trainable tensor as
                # ``weight_orig`` and recomputes ``weight`` from it before every
                # forward pass, so initialising ``weight`` directly would be
                # silently discarded.
                weight = getattr(module, 'weight_orig', module.weight)
                nn.init.xavier_uniform_(weight, gain=gain)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x):
        """Return per-row scores in (0, 1) with shape ``(batch, 1)``."""
        return self.fc(x)
|
|
300
|
+
|
|
301
|
+
class Discriminator_B(torch.nn.Module):
    """Plain MLP critic: X_dim -> 512 -> 128 -> 32 -> 1, unbounded output.

    LeakyReLU(0.2) sits between consecutive Linear layers. There is no
    activation after the last one.
    """

    def __init__(self, X_dim):
        super(Discriminator_B, self).__init__()
        widths = [X_dim, 512, 128, 32, 1]
        stack = []
        for width_in, width_out in zip(widths, widths[1:]):
            stack.append(nn.Linear(width_in, width_out))
            stack.append(nn.LeakyReLU(0.2))
        stack.pop()  # no activation after the final 32 -> 1 projection
        self.fc = torch.nn.Sequential(*stack)
        self.init_weights()

    def init_weights(self):
        """Xavier-uniform (LeakyReLU 0.2 gain) weights and zero biases."""
        leaky_gain = nn.init.calculate_gain('leaky_relu', 0.2)
        for layer in filter(lambda m: isinstance(m, nn.Linear), self.modules()):
            nn.init.xavier_uniform_(layer.weight, gain=leaky_gain)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return raw critic scores of shape ``(batch, 1)``."""
        return self.fc(x)
|
SM2ST/Train_SMLED.py
ADDED
|
@@ -0,0 +1,363 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
from tqdm import tqdm
|
|
4
|
+
import scipy.sparse as sp
|
|
5
|
+
import os
|
|
6
|
+
from .SMLED import Encoder, Decoder, Discriminator_A,Discriminator_B
|
|
7
|
+
from .utils import Transfer_pytorch_Data, positional_pixel_step, recovery_coord, generation_coord, Cal_Spatial_Net
|
|
8
|
+
from .dataset import *
|
|
9
|
+
import random
|
|
10
|
+
import torch
|
|
11
|
+
import torch.backends.cudnn as cudnn
|
|
12
|
+
from torch.autograd import Variable
|
|
13
|
+
import os
|
|
14
|
+
import torch.nn.functional as F
|
|
15
|
+
from scipy.sparse import csr_matrix, csc_matrix, coo_matrix
|
|
16
|
+
from torch_sparse import SparseTensor
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def sce_loss(x, y, alpha=1.0):
    """Scaled cosine error: mean over rows of ``(1 - cos(x_i, y_i)) ** alpha``.

    Rows of ``x`` and ``y`` are L2-normalised along the last dimension first.
    """
    x_unit = F.normalize(x, p=2, dim=-1)
    y_unit = F.normalize(y, p=2, dim=-1)
    cosine = (x_unit * y_unit).sum(dim=-1)
    return torch.pow(1 - cosine, alpha).mean()
|
|
26
|
+
|
|
27
|
+
class WeightedMSELoss(torch.nn.Module):
    """Mean of ``weights * (y_pred - y_true) ** 2``.

    ``weights`` must broadcast against the prediction tensor (for example
    one weight per column).
    """

    def __init__(self, weights):
        super(WeightedMSELoss, self).__init__()
        self.weights = weights

    def forward(self, y_pred, y_true):
        squared_error = (y_pred - y_true) ** 2
        return (self.weights * squared_error).mean()
|
|
35
|
+
|
|
36
|
+
class WeightedMAELoss(torch.nn.Module):
    """Mean of ``weights * |y_pred - y_true|``.

    ``weights`` must broadcast against the prediction tensor (for example
    one weight per column).
    """

    def __init__(self, weights):
        super(WeightedMAELoss, self).__init__()
        self.weights = weights

    def forward(self, y_pred, y_true):
        absolute_error = (y_pred - y_true).abs()
        return (self.weights * absolute_error).mean()
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def rand_projections(
    embedding_dim,
    num_samples=50,
    device='cpu'
):
    """Draw ``num_samples`` random unit vectors in ``embedding_dim`` dimensions.

    Each vector is a standard-normal draw from NumPy's global RNG, divided
    by its L2 norm.

    Args:
        embedding_dim (int): embedding dimensionality
        num_samples (int): number of random projection samples
        device: target device of the returned tensor

    Return:
        torch.Tensor: float32 tensor of size (num_samples, embedding_dim)
    """
    def _to_unit(vec):
        return vec / np.sqrt((vec ** 2).sum())

    draws = np.random.normal(size=(num_samples, embedding_dim))
    unit_rows = np.asarray(list(map(_to_unit, draws)))
    return torch.from_numpy(unit_rows).float().to(device)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def wasserstein_loss(disc_real, disc_fake):
    """Return ``mean(disc_fake) - mean(disc_real)``, the WGAN critic objective."""
    return torch.mean(disc_fake) - torch.mean(disc_real)
|
|
68
|
+
|
|
69
|
+
def gradient_penalty(discriminator, real_data, fake_data, device, lambda_gp=10):
    """WGAN-GP penalty ``lambda_gp * mean((||grad D(x_hat)||_2 - 1) ** 2)``.

    ``x_hat`` is a per-sample random convex combination of ``real_data`` and
    ``fake_data``, which must share a shape with the batch on dim 0. The
    gradient norm is taken over all non-batch dimensions. The graph is kept
    (``create_graph=True``) so the penalty can itself be backpropagated.

    Parameters
    ----------
    discriminator
        Callable mapping a batch to scores.
    real_data, fake_data
        Tensors of identical shape ``(B, ...)``.
    device
        Device the mixing coefficients are moved to.
    lambda_gp
        Penalty weight.
    """
    batch_size = real_data.size(0)
    # One coefficient per sample, shaped (B, 1, ..., 1) so it broadcasts over
    # inputs of any rank (a fixed (B, 1) only works for 2-D inputs).
    alpha_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
    alpha = torch.rand(alpha_shape).to(device)
    interpolated = alpha * real_data + (1 - alpha) * fake_data
    interpolated = interpolated.requires_grad_(True)
    mixed_scores = discriminator(interpolated)
    gradients = torch.autograd.grad(
        inputs=interpolated,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradients_norm = torch.norm(gradients.reshape(gradients.size(0), -1), dim=1)
    return lambda_gp * ((gradients_norm - 1) ** 2).mean()
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
_SMLED_EXPERIMENTS = ('recovery', 'higher_res', 'generation')


def _build_train_loader(adata, experiment, mask_ratio, res, delta, coord_sf, batch_size, verbose):
    """Prepare coordinates, normalised data and the shuffled training loader.

    Returns ``(used_gene, full_coor, full_xlabel_df, adata_sample, train_loader)``.
    ``adata_sample`` is only produced for ``experiment == 'recovery'`` and is
    ``None`` otherwise.
    """
    adata_sample = None
    if experiment == 'recovery':
        coor, full_coor, sample_index, sample_barcode = recovery_coord(adata, name='spatial', mask_ratio=mask_ratio)
        used_gene, normed_data, adata_sample = get_data(adata, experiment=experiment, sample_index=sample_index, sample_barcode=sample_barcode)
    elif experiment == 'higher_res':
        coor, full_coor = generation_coord(adata, name='spatial', res=res)
        used_gene, normed_data = get_data(adata, experiment=experiment)
    else:  # 'generation'
        coor = adata.obsm['spatial']
        full_coor = adata.uns['coord']
        used_gene, normed_data = get_data(adata, experiment=experiment)

    xlabel_df, full_xlabel_df = positional_pixel_step(coor, full_coor, delta, coord_sf)
    if verbose:
        print(xlabel_df, full_xlabel_df)

    transformed_dataset = MyDataset(normed_data=normed_data, coor_df=xlabel_df, transform=transforms.Compose([ToTensor()]))
    # BatchNorm1d cannot normalise a training batch of a single sample, so a
    # trailing batch of size 1 would crash the epoch; drop it in that case.
    n_samples = len(transformed_dataset)
    drop_last = n_samples > batch_size and n_samples % batch_size == 1
    train_loader = DataLoader(transformed_dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=drop_last)
    return used_gene, full_coor, full_xlabel_df, adata_sample, train_loader


def _build_reconstruction_losses(adata, WMMSE, device):
    """Return ``(loss2, loss1)``: the (weighted) MSE and (weighted) L1 criteria.

    With ``WMMSE > 0`` each gene (column of ``adata.X``) gets the weight
    ``WMMSE * (1 - column_sum / total_sum)``, so genes with larger totals
    are weighted less. Otherwise plain unweighted losses are returned.
    """
    if WMMSE > 0:
        # ``.toarray()`` instead of ``.A``: the ``.A`` shortcut was removed
        # from SciPy sparse matrices.
        matrix = adata.X.toarray() if sp.issparse(adata.X) else adata.X
        column_sums = np.asarray(matrix.sum(axis=0)).ravel()
        normalized = column_sums * (WMMSE / column_sums.sum())
        weights = torch.tensor(WMMSE - normalized, dtype=torch.float32, device=device)
        return WeightedMSELoss(weights), WeightedMAELoss(weights)
    return torch.nn.MSELoss(), torch.nn.L1Loss()


def _decode_full_grid(decoder, full_xlabel_df, relu, device, batch_size=1000):
    """Decode every row of ``full_xlabel_df`` into a numpy profile matrix.

    Decoding runs in batches under ``torch.no_grad()``. When ``relu`` is
    false the result is clipped at 0 so profiles are non-negative.
    """
    full_coor_t = torch.from_numpy(np.array(full_xlabel_df)).to(torch.float32).to(device)
    dataloader_t = DataLoader(full_coor_t, batch_size=batch_size, shuffle=False)
    generate_profile_list = []
    with torch.no_grad():
        for batch_coor_t in dataloader_t:
            batch_coor_t = batch_coor_t.to(torch.float32).to(device)
            generate_profile_list.append(decoder(batch_coor_t, relu).cpu().numpy())
    generate_profile = np.concatenate(generate_profile_list, axis=0)
    if not relu:
        generate_profile = np.clip(generate_profile, a_min=0, a_max=None)
    return generate_profile


def train_SMLED(adata=None, X_dim = 2, delta = 1.0, train_epoch=15000,lr=0.001,mask_ratio=0.5,alpha=1.0,key_added='SMLED',step_size=10000,gamma=1.0,
                relu=True, gradient_clipping=5., experiment='generation', weight_decay=0.0001, verbose=True, batch_size = 1000,lambda_gp = 1.0,
                random_seed=2025, save_path = './SMLED_pyG_result',down_ratio = 0., coord_sf=1.0,
                WMMSE=0.0, res = 2.0, device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')):
    """\
    Training GAN auto-encoder.

    An encoder maps expression profiles to ``X_dim`` coordinates and a
    decoder maps coordinates back to profiles. A discriminator is trained
    with BCE to separate real coordinates from encoded ones. The
    encoder/decoder loss is ``recon + 0.3 * latent_L1 + |W(real, fake)|``.
    After training, the decoder is evaluated on the full coordinate grid.

    Parameters
    ----------
    adata
        AnnData object of scanpy package (required).
    X_dim
        Width of the latent / coordinate space.
    delta
        Coordinate scaling.
    train_epoch
        Number of total epochs in training.
    lr
        Learning rate for AdamOptimizer.
    mask_ratio
        Random masking ratio (``'recovery'`` experiment).
    alpha, key_added, gradient_clipping, lambda_gp
        Accepted for compatibility; not used by the current training loop.
    step_size, gamma
        StepLR schedule (step sizes are scaled by the per-batch step counts).
    relu
        Apply a ReLU on encoder/decoder outputs; otherwise the generated
        profiles are clipped at 0.
    experiment
        One of ``'generation'``, ``'recovery'`` or ``'higher_res'``.
    weight_decay
        Weight decay for AdamOptimizer.
    verbose
        Print input size and coordinate tables.
    batch_size
        Training batch size.
    random_seed
        Seed passed to ``fix_seed``.
    save_path
        Output directory (created if missing, including parents).
    down_ratio
        Dropout rate of the encoder.
    coord_sf
        Passed through to ``positional_pixel_step``.
    WMMSE
        The weight distribution of wmse (0 disables gene weighting).
    res
        Resolution factor for ``'higher_res'``.
    device
        See torch.device.

    Returns
    -------
    AnnData
        The generated data for ``'generation'`` / ``'higher_res'``, or a
        tuple ``(adata_sample, adata_SMLED)`` for ``'recovery'``.

    Raises
    ------
    ValueError
        If ``adata`` is missing or ``experiment`` is not recognised.
    """
    if adata is None:
        raise ValueError('adata must be provided')
    if experiment not in _SMLED_EXPERIMENTS:
        raise ValueError("experiment must be one of %s, got %r" % (_SMLED_EXPERIMENTS, experiment))

    fix_seed(random_seed)
    os.makedirs(save_path, exist_ok=True)
    if verbose:
        print('Size of Input: ', adata.X.shape)

    used_gene, full_coor, full_xlabel_df, adata_sample, train_loader = _build_train_loader(
        adata, experiment, mask_ratio, res, delta, coord_sf, batch_size, verbose)

    # Construction order (encoder, decoder, discriminator) fixes RNG use.
    gene_number = len(used_gene)
    encoder = Encoder(gene_number, X_dim, down_ratio).to(device)
    decoder = Decoder(gene_number, X_dim, down_ratio=0.).to(device)
    discriminator_AB = Discriminator_A(X_dim).to(device)

    adam_kwargs = dict(lr=lr, weight_decay=weight_decay, eps=1e-8, betas=(0.5, 0.999))
    enc_optim = torch.optim.Adam(encoder.parameters(), **adam_kwargs)
    dec_optim = torch.optim.Adam(decoder.parameters(), **adam_kwargs)
    disc_optim_AB = torch.optim.Adam(discriminator_AB.parameters(), **adam_kwargs)

    n_gen = 1   # generator updates per batch
    n_crit = 2  # discriminator updates per batch
    # Schedulers step once per update, hence the scaled step sizes.
    enc_sche = torch.optim.lr_scheduler.StepLR(enc_optim, step_size=n_gen*step_size, gamma=gamma)
    dec_sche = torch.optim.lr_scheduler.StepLR(dec_optim, step_size=n_gen*step_size, gamma=gamma)
    disc_sche_AB = torch.optim.lr_scheduler.StepLR(disc_optim_AB, step_size=n_crit*step_size, gamma=gamma)

    criterion = torch.nn.BCELoss()
    loss2, loss1 = _build_reconstruction_losses(adata, WMMSE, device)
    MAE = torch.nn.L1Loss()

    with tqdm(range(train_epoch), total=train_epoch, desc='Epochs') as epoch:
        for _ in epoch:
            train_reloss, train_GAloss, train_latloss, train_loss, train_DAloss = [], [], [], [], []

            for xdata, xlabel in train_loader:
                xdata = xdata.to(torch.float32).to(device)
                xlabel = xlabel.to(torch.float32).to(device)

                # Discriminator updates: real coordinates vs encoded ones.
                for _ in range(n_crit):
                    discriminator_AB.train()
                    disc_optim_AB.zero_grad()
                    # Detached: only the discriminator is updated here, and the
                    # encoder gradients this would produce were always cleared
                    # by enc_optim.zero_grad() before the generator step.
                    fake_xlabel = encoder(xdata, relu).detach()
                    disc_real = discriminator_AB(xlabel).view(-1)
                    disc_fake = discriminator_AB(fake_xlabel).view(-1)
                    loss_dis_real = criterion(disc_real, torch.ones_like(disc_real))
                    loss_dis_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
                    d_total_loss = loss_dis_real + loss_dis_fake
                    train_DAloss.append(d_total_loss.item())
                    d_total_loss.backward()
                    disc_optim_AB.step()
                    disc_sche_AB.step()
                discriminator_AB.eval()

                # Encoder/decoder updates.
                for _ in range(n_gen):
                    encoder.train()
                    decoder.train()
                    enc_optim.zero_grad()
                    dec_optim.zero_grad()
                    fake_xlabel = encoder(xdata, relu)
                    fake_xdata = decoder(fake_xlabel, relu)
                    disc_realA = discriminator_AB(xlabel)
                    disc_fakeA = discriminator_AB(fake_xlabel)
                    gA_loss = torch.abs(wasserstein_loss(disc_realA, disc_fakeA))
                    latent_loss = MAE(fake_xlabel, xlabel)
                    recon_loss = loss2(fake_xdata, xdata) + 0.1*loss1(fake_xdata, xdata)
                    loss = recon_loss + 0.3*latent_loss + gA_loss

                    train_latloss.append(latent_loss.item())
                    train_GAloss.append(gA_loss.item())
                    train_reloss.append(recon_loss.item())
                    train_loss.append(loss.item())

                    loss.backward()
                    enc_optim.step()
                    dec_optim.step()
                    enc_sche.step()
                    dec_sche.step()
                encoder.eval()
                decoder.eval()

            epoch_info = 'loss_re: %.5f, loss_lat: %.5f, loss_GA: %.5f, loss: %.5f, loss_DA: %.5f' % \
                (torch.mean(torch.FloatTensor(train_reloss)),
                 torch.mean(torch.FloatTensor(train_latloss)),
                 torch.mean(torch.FloatTensor(train_GAloss)),
                 torch.mean(torch.FloatTensor(train_loss)),
                 torch.mean(torch.FloatTensor(train_DAloss)))
            epoch.set_postfix_str(epoch_info)

    torch.save(encoder, os.path.join(save_path, 'encoder.pth'))
    torch.save(decoder, os.path.join(save_path, 'decoder.pth'))

    # Get generated or recovered data on the full coordinate grid.
    encoder.eval()
    decoder.eval()
    generate_profile = _decode_full_grid(decoder, full_xlabel_df, relu, device)

    if experiment == 'recovery':
        np.savetxt(os.path.join(save_path, "fill_data.txt"), generate_profile)

    st_intensity = csr_matrix(generate_profile, dtype=np.float32)
    adata_SMLED = sc.AnnData(st_intensity)
    adata_SMLED.obsm["spatial"] = full_coor
    adata_SMLED.var.index = used_gene

    adata.write(os.path.join(save_path, 'original_data.h5ad'))

    if experiment == 'recovery':
        adata_sample.write(os.path.join(save_path, 'sampled_data.h5ad'))
        adata_SMLED.obs = adata.obs
        adata_SMLED.write(os.path.join(save_path, 'recovered_data.h5ad'))
        return adata_sample, adata_SMLED

    adata_SMLED.write(os.path.join(save_path, 'generated_data.h5ad'))
    return adata_SMLED
|
|
351
|
+
|
|
352
|
+
|
|
353
|
+
def fix_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs; make cuDNN deterministic.

    ``PYTHONHASHSEED`` is also exported. It only influences interpreters
    started afterwards, not the hash seed of the running process.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    cudnn.deterministic = True
    cudnn.benchmark = False
    # os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
|