SM2ST 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
SM2ST/SMLED.py ADDED
@@ -0,0 +1,332 @@
1
+ import numpy as np
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.backends.cudnn as cudnn
6
+ import torch.nn.functional as F
7
+ import random
8
+ # from .gatv2_conv_or import GATv2Conv as GATConv
9
+ from torch.nn.utils import spectral_norm
10
+
11
class encoding_mask_noise(torch.nn.Module):
    """Randomly mask rows of a feature matrix before encoding.

    A fraction ``mask_rate`` of the rows of the input is selected at random.
    Of the selected rows, a fraction ``replace_rate`` is overwritten with rows
    copied from random positions of the input ("noise" rows); the remaining
    selected rows are replaced by the learnable ``enc_mask_token``.

    Args:
        hidden_dims: Sequence ``[in_dim, num_hidden, out_dim]``. Only
            ``in_dim`` is used here; it sets the width of the mask token.
    """

    def __init__(self, hidden_dims):
        super(encoding_mask_noise, self).__init__()
        in_dim, _num_hidden, _out_dim = hidden_dims
        self.enc_mask_token = nn.Parameter(torch.zeros(size=(1, in_dim)))
        self.reset_parameters_for_token()

    def reset_parameters_for_token(self):
        """Re-initialise the mask token with Xavier-normal values."""
        nn.init.xavier_normal_(self.enc_mask_token.data, gain=1.414)

    def forward(self, x, mask_rate=0.5, replace_rate=0.05):
        """Return a masked copy of ``x`` together with the index split.

        Args:
            x: Tensor whose first dimension indexes the rows to mask.
            mask_rate: Fraction of rows to mask.
            replace_rate: Fraction of the masked rows that receive a random
                input row instead of the mask token.

        Returns:
            Tuple ``(out_x, mask_nodes, keep_nodes)``: the masked copy of
            ``x``, the indices of the masked rows and the indices of the
            untouched rows.
        """
        num_nodes = x.size()[0]
        perm = torch.randperm(num_nodes, device=x.device)

        # Random masking: the first `num_mask_nodes` shuffled indices are masked.
        num_mask_nodes = int(mask_rate * num_nodes)
        mask_nodes = perm[:num_mask_nodes]
        keep_nodes = perm[num_mask_nodes:]

        out_x = x.clone()
        if replace_rate > 0.0:
            # Clamp so the split below stays valid even for replace_rate > 1.
            num_noise_nodes = min(int(replace_rate * num_mask_nodes), num_mask_nodes)
            # Split with a positive index: slicing with `-num_noise_nodes`
            # breaks when it is 0, because `[:-0]` is empty and `[-0:]` is the
            # whole tensor, which made every masked row a noise row with no
            # source rows to copy from.
            num_token_nodes = num_mask_nodes - num_noise_nodes
            perm_mask = torch.randperm(num_mask_nodes, device=x.device)
            token_nodes = mask_nodes[perm_mask[:num_token_nodes]]
            noise_nodes = mask_nodes[perm_mask[num_token_nodes:]]
            noise_to_be_chosen = torch.randperm(num_nodes, device=x.device)[:num_noise_nodes]

            out_x[token_nodes] = 0.0
            out_x[noise_nodes] = x[noise_to_be_chosen]
        else:
            token_nodes = mask_nodes
            out_x[mask_nodes] = 0.0

        # Token rows were zeroed above, so adding the token equals assigning it.
        out_x[token_nodes] += self.enc_mask_token
        return out_x, mask_nodes, keep_nodes
52
+
53
class random_remask(torch.nn.Module):
    """Replace a random subset of representation rows with a learnable token.

    Args:
        hidden_dims: Sequence ``[in_dim, num_hidden, out_dim]``. Only
            ``out_dim`` is used; it sets the width of ``dec_mask_token``.
    """

    def __init__(self, hidden_dims):
        super().__init__()
        _, _, token_width = hidden_dims
        self.dec_mask_token = nn.Parameter(torch.zeros(size=(1, token_width)))
        self.reset_parameters_for_token()

    def reset_parameters_for_token(self):
        """Re-initialise the re-mask token with Xavier-normal values."""
        nn.init.xavier_normal_(self.dec_mask_token.data, gain=1.414)

    def forward(self, rep, remask_rate=0.5):
        """Return ``(out_rep, remask_nodes, rekeep_nodes)``.

        ``out_rep`` is a copy of ``rep`` in which the rows listed in
        ``remask_nodes`` hold the re-mask token; ``rekeep_nodes`` lists the
        rows that were left unchanged.
        """
        row_count = rep.size()[0]
        shuffled = torch.randperm(row_count, device=rep.device)
        split_at = int(remask_rate * row_count)
        remask_nodes, rekeep_nodes = shuffled[:split_at], shuffled[split_at:]

        out_rep = rep.clone()
        # Zero first, then add the token, so the token receives the gradient.
        out_rep[remask_nodes] = 0.0
        out_rep[remask_nodes] += self.dec_mask_token
        return out_rep, remask_nodes, rekeep_nodes
75
+
76
+
77
+ # class Encoder(nn.Module):
78
+ # def __init__(self, mz_number, X_dim):
79
+ # super(Encoder, self).__init__()
80
+ # # self.encoding_mask_noise = encoding_mask_noise(hidden_dims)
81
+ # # self.random_remask = random_remask(hidden_dims)
82
+ # self.fc1 = nn.Linear(mz_number, 1024)
83
+ # self.fc1_bn = nn.BatchNorm1d(1024)
84
+ # self.fc2 = nn.Linear(1024, 256)
85
+ # self.fc2_bn = nn.BatchNorm1d(256)
86
+ # self.fc3 = nn.Linear(256, 64)
87
+ # self.fc3_bn = nn.BatchNorm1d(64)
88
+ # self.fc4 = nn.Linear(64, 8)
89
+ # self.fc4_bn = nn.BatchNorm1d(8)
90
+ # self.fc5 = nn.Linear(8, X_dim)
91
+ # # Initialize parameters
92
+ # self.init_weights()
93
+
94
+ # def init_weights(self):
95
+ # gain = nn.init.calculate_gain('relu')
96
+ # # Initialize weights and biases for all linear layers
97
+ # for module in self.modules():
98
+ # if isinstance(module, nn.Linear):
99
+ # # Use the Xavier initialization method to specify the gain value
100
+ # nn.init.xavier_uniform_(module.weight, gain=gain)
101
+ # if module.bias is not None:
102
+ # # Initialize the bias to 0
103
+ # nn.init.zeros_(module.bias)
104
+
105
+ # def forward(self, features, relu=False, mask = 0.0):
106
+ # if mask:
107
+ # mask_tensor = torch.bernoulli(torch.full_like(features, mask)).to(features.device) # Random mask with 50% probability
108
+ # features = features * mask_tensor # Apply mask
109
+ # h1 = F.relu(self.fc1_bn(self.fc1(features)))
110
+ # h2 = F.relu(self.fc2_bn(self.fc2(h1)))
111
+ # h3 = F.relu(self.fc3_bn(self.fc3(h2)))
112
+ # h4 = F.relu(self.fc4_bn(self.fc4(h3)))
113
+ # if relu:
114
+ # return F.relu(self.fc5(h4))
115
+ # else:
116
+ # return self.fc5(h4)
117
+
118
class Encoder(nn.Module):
    """MLP mapping ``mz_number`` input features to an ``X_dim`` output.

    The hidden widths are 1024, 256, 64 and 16. Each hidden stage is
    Linear -> BatchNorm1d -> ReLU -> Dropout, followed by a final Linear
    projection to ``X_dim``.

    Args:
        mz_number: Width of the input feature vector.
        X_dim: Width of the output.
        down_ratio: Dropout probability used after every hidden stage.
    """

    def __init__(self, mz_number, X_dim, down_ratio):
        super().__init__()
        self.dropout_rate = down_ratio

        # Stage 1: mz_number -> 1024
        self.fc1 = nn.Linear(mz_number, 1024)
        self.fc1_bn = nn.BatchNorm1d(1024)
        self.dropout1 = nn.Dropout(self.dropout_rate)

        # Stage 2: 1024 -> 256
        self.fc2 = nn.Linear(1024, 256)
        self.fc2_bn = nn.BatchNorm1d(256)
        self.dropout2 = nn.Dropout(self.dropout_rate)

        # Stage 3: 256 -> 64
        self.fc3 = nn.Linear(256, 64)
        self.fc3_bn = nn.BatchNorm1d(64)
        self.dropout3 = nn.Dropout(self.dropout_rate)

        # Stage 4: 64 -> 16
        self.fc4 = nn.Linear(64, 16)
        self.fc4_bn = nn.BatchNorm1d(16)
        self.dropout4 = nn.Dropout(self.dropout_rate)

        # Output projection: 16 -> X_dim
        self.fc5 = nn.Linear(16, X_dim)

        self.init_weights()

    def init_weights(self):
        """Xavier-uniform (ReLU gain) weights and zero biases for every Linear."""
        relu_gain = nn.init.calculate_gain('relu')
        for layer in self.modules():
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.xavier_uniform_(layer.weight, gain=relu_gain)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, features, relu=False):
        """Encode ``features``; apply a final ReLU when ``relu`` is truthy."""
        hidden_stages = (
            (self.fc1, self.fc1_bn, self.dropout1),
            (self.fc2, self.fc2_bn, self.dropout2),
            (self.fc3, self.fc3_bn, self.dropout3),
            (self.fc4, self.fc4_bn, self.dropout4),
        )
        hidden = features
        for linear, batch_norm, dropout in hidden_stages:
            hidden = dropout(F.relu(batch_norm(linear(hidden))))

        projected = self.fc5(hidden)
        return F.relu(projected) if relu else projected
174
+
175
+
176
+ # class Decoder(nn.Module):
177
+ # def __init__(self, mz_number, X_dim):
178
+ # super(Decoder, self).__init__()
179
+ # self.fc6 = nn.Linear(X_dim, 8)
180
+ # self.fc6_bn = nn.BatchNorm1d(8)
181
+ # self.fc7 = nn.Linear(8, 64)
182
+ # self.fc7_bn = nn.BatchNorm1d(64)
183
+ # self.fc8 = nn.Linear(64, 256)
184
+ # self.fc8_bn = nn.BatchNorm1d(256)
185
+ # self.fc9 = nn.Linear(256, 1024)
186
+ # self.fc9_bn = nn.BatchNorm1d(1024)
187
+ # self.fc10 = nn.Linear(1024, mz_number)
188
+ # # Initialize parameters
189
+ # self.init_weights()
190
+
191
+ # def init_weights(self):
192
+ # # Initialize weights and biases for all linear layers
193
+ # gain = nn.init.calculate_gain('relu')
194
+ # for module in self.modules():
195
+ # if isinstance(module, nn.Linear):
196
+ # # Use the Xavier initialization method to specify the gain value
197
+ # nn.init.xavier_uniform_(module.weight, gain=gain)
198
+ # if module.bias is not None:
199
+ # # Initialize the bias to 0
200
+ # nn.init.zeros_(module.bias)
201
+
202
+ # def forward(self, z, relu=False):
203
+ # h6 = F.relu(self.fc6_bn(self.fc6(z)))
204
+ # h7 = F.relu(self.fc7_bn(self.fc7(h6)))
205
+ # h8 = F.relu(self.fc8_bn(self.fc8(h7)))
206
+ # h9 = F.relu(self.fc9_bn(self.fc9(h8)))
207
+ # if relu:
208
+ # return F.relu(self.fc10(h9))
209
+ # else:
210
+ # return self.fc10(h9)
211
+
212
class Decoder(nn.Module):
    """MLP mapping an ``X_dim`` input back to ``mz_number`` features.

    The hidden widths are 16, 64, 256 and 1024. Each hidden stage is
    Linear -> BatchNorm1d -> ReLU -> Dropout, followed by a final Linear
    projection to ``mz_number``.

    Args:
        mz_number: Width of the reconstructed feature vector.
        X_dim: Width of the input.
        down_ratio: Dropout probability used after every hidden stage.
    """

    def __init__(self, mz_number, X_dim, down_ratio):
        super().__init__()
        self.dropout_rate = down_ratio

        # Stage 1: X_dim -> 16
        self.fc6 = nn.Linear(X_dim, 16)
        self.fc6_bn = nn.BatchNorm1d(16)
        self.dropout6 = nn.Dropout(self.dropout_rate)

        # Stage 2: 16 -> 64
        self.fc7 = nn.Linear(16, 64)
        self.fc7_bn = nn.BatchNorm1d(64)
        self.dropout7 = nn.Dropout(self.dropout_rate)

        # Stage 3: 64 -> 256
        self.fc8 = nn.Linear(64, 256)
        self.fc8_bn = nn.BatchNorm1d(256)
        self.dropout8 = nn.Dropout(self.dropout_rate)

        # Stage 4: 256 -> 1024
        self.fc9 = nn.Linear(256, 1024)
        self.fc9_bn = nn.BatchNorm1d(1024)
        self.dropout9 = nn.Dropout(self.dropout_rate)

        # Output projection: 1024 -> mz_number
        self.fc10 = nn.Linear(1024, mz_number)

        self.init_weights()

    def init_weights(self):
        """Xavier-uniform (ReLU gain) weights and zero biases for every Linear."""
        relu_gain = nn.init.calculate_gain('relu')
        for layer in self.modules():
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.xavier_uniform_(layer.weight, gain=relu_gain)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, z, relu=False):
        """Decode ``z``; apply a final ReLU when ``relu`` is truthy."""
        hidden_stages = (
            (self.fc6, self.fc6_bn, self.dropout6),
            (self.fc7, self.fc7_bn, self.dropout7),
            (self.fc8, self.fc8_bn, self.dropout8),
            (self.fc9, self.fc9_bn, self.dropout9),
        )
        hidden = z
        for linear, batch_norm, dropout in hidden_stages:
            hidden = dropout(F.relu(batch_norm(linear(hidden))))

        reconstructed = self.fc10(hidden)
        return F.relu(reconstructed) if relu else reconstructed
266
+
267
class Discriminator_A(torch.nn.Module):
    """Spectrally-normalised MLP scoring ``X_dim``-wide inputs in ``[0, 1]``.

    Layer widths are ``X_dim -> 128 -> 32 -> 8 -> 1`` with LeakyReLU(0.2)
    between the Linear layers and a Sigmoid on the output.

    Args:
        X_dim: Width of the input vectors.
    """

    def __init__(self, X_dim):
        super().__init__()
        widths = (X_dim, 128, 32, 8, 1)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(spectral_norm(nn.Linear(fan_in, fan_out)))
            stack.append(nn.LeakyReLU(0.2))
        # The output layer is followed by a Sigmoid rather than a LeakyReLU.
        stack[-1] = nn.Sigmoid()
        self.fc = torch.nn.Sequential(*stack)
        self.init_weights()

    def init_weights(self):
        """Xavier-uniform (LeakyReLU gain) weights and zero biases for Linears."""
        # NOTE(review): with torch.nn.utils.spectral_norm the trainable tensor
        # is `weight_orig` and `weight` appears to be recomputed from it on
        # every forward pass, so this Xavier init of `weight` likely has no
        # lasting effect (the bias init does) - confirm against the torch docs
        # before changing, since a fix would alter seeded results.
        leaky_gain = nn.init.calculate_gain('leaky_relu', 0.2)
        for layer in self.modules():
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.xavier_uniform_(layer.weight, gain=leaky_gain)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return the discriminator score for ``x``."""
        return self.fc(x)
300
+
301
class Discriminator_B(torch.nn.Module):
    """Plain MLP critic producing one unbounded score per input row.

    Layer widths are ``X_dim -> 512 -> 128 -> 32 -> 1`` with LeakyReLU(0.2)
    between the Linear layers and no activation on the output.

    Args:
        X_dim: Width of the input vectors.
    """

    def __init__(self, X_dim):
        super().__init__()
        widths = (X_dim, 512, 128, 32, 1)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.LeakyReLU(0.2))
        # No activation after the final Linear layer.
        stack.pop()
        self.fc = torch.nn.Sequential(*stack)
        self.init_weights()

    def init_weights(self):
        """Xavier-uniform (LeakyReLU gain) weights and zero biases for Linears."""
        leaky_gain = nn.init.calculate_gain('leaky_relu', 0.2)
        for layer in self.modules():
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.xavier_uniform_(layer.weight, gain=leaky_gain)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return the critic score for ``x``."""
        return self.fc(x)
SM2ST/Train_SMLED.py ADDED
@@ -0,0 +1,363 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ from tqdm import tqdm
4
+ import scipy.sparse as sp
5
+ import os
6
+ from .SMLED import Encoder, Decoder, Discriminator_A,Discriminator_B
7
+ from .utils import Transfer_pytorch_Data, positional_pixel_step, recovery_coord, generation_coord, Cal_Spatial_Net
8
+ from .dataset import *
9
+ import random
10
+ import torch
11
+ import torch.backends.cudnn as cudnn
12
+ from torch.autograd import Variable
13
+ import os
14
+ import torch.nn.functional as F
15
+ from scipy.sparse import csr_matrix, csc_matrix, coo_matrix
16
+ from torch_sparse import SparseTensor
17
+
18
+
19
def sce_loss(x, y, alpha=1.0):
    """Scaled cosine error between the rows of ``x`` and ``y``.

    Both inputs are L2-normalised along the last dimension; the loss is the
    mean of ``(1 - cosine_similarity) ** alpha``.
    """
    x_unit = F.normalize(x, p=2, dim=-1)
    y_unit = F.normalize(y, p=2, dim=-1)
    cosine = torch.sum(x_unit * y_unit, dim=-1)
    return torch.pow(1 - cosine, alpha).mean()
26
+
27
class WeightedMSELoss(torch.nn.Module):
    """Mean of ``weights * (y_pred - y_true) ** 2``.

    Args:
        weights: Tensor of weights; it must broadcast against the inputs
            passed to ``forward``.
    """

    def __init__(self, weights):
        super().__init__()
        self.weights = weights

    def forward(self, y_pred, y_true):
        """Return the weighted mean squared error as a scalar tensor."""
        squared_error = (y_pred - y_true) ** 2
        return torch.mean(self.weights * squared_error)
35
+
36
class WeightedMAELoss(torch.nn.Module):
    """Mean of ``weights * |y_pred - y_true|``.

    Args:
        weights: Tensor of weights; it must broadcast against the inputs
            passed to ``forward``.
    """

    def __init__(self, weights):
        super().__init__()
        self.weights = weights

    def forward(self, y_pred, y_true):
        """Return the weighted mean absolute error as a scalar tensor."""
        absolute_error = torch.abs(y_pred - y_true)
        return torch.mean(self.weights * absolute_error)
44
+
45
+
46
def rand_projections(
        embedding_dim,
        num_samples=50,
        device='cpu'
):
    """Draw random directions on the unit sphere of the latent space.

    Args:
        embedding_dim (int): embedding dimensionality
        num_samples (int): number of random projection samples
        device: device on which the returned tensor is placed

    Return:
        torch.Tensor: float32 tensor of size (num_samples, embedding_dim)
        whose rows have unit L2 norm
    """
    gaussian_rows = np.random.normal(size=(num_samples, embedding_dim))
    unit_rows = []
    for row in gaussian_rows:
        # L2-normalise each Gaussian draw so it lies on the unit sphere.
        unit_rows.append(row / np.sqrt((row ** 2).sum()))
    stacked = np.asarray(unit_rows)
    return torch.from_numpy(stacked).type(torch.FloatTensor).to(device)
64
+
65
+
66
def wasserstein_loss(disc_real, disc_fake):
    """Return ``mean(disc_fake) - mean(disc_real)`` (the critic's Wasserstein loss)."""
    real_mean = torch.mean(disc_real)
    fake_mean = torch.mean(disc_fake)
    return -real_mean + fake_mean
68
+
69
def gradient_penalty(discriminator, real_data, fake_data, device, lambda_gp=10):
    """Gradient penalty on random interpolates between real and fake rows.

    Args:
        discriminator: Callable mapping a batch to per-row scores.
        real_data: Batch of real samples.
        fake_data: Batch of generated samples, same shape as ``real_data``.
        device: Device on which the mixing coefficients and the
            ``grad_outputs`` tensor are placed.
        lambda_gp: Multiplier applied to the penalty.

    Returns:
        Scalar tensor ``lambda_gp * mean((||grad||_2 - 1) ** 2)``, where the
        gradient of the scores is taken with respect to the interpolates.
    """
    batch = real_data.size(0)
    # One mixing coefficient per row, broadcast over the feature dimension.
    mix = torch.rand(batch, 1).to(device)
    interpolated = (mix * real_data + ((1 - mix) * fake_data)).requires_grad_(True)
    mixed_scores = discriminator(interpolated)

    (gradients,) = torch.autograd.grad(
        outputs=mixed_scores,
        inputs=interpolated,
        grad_outputs=torch.ones(mixed_scores.size()).to(device),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )
    per_row_norm = torch.norm(gradients.view(gradients.size(0), -1), dim=1)
    penalty = lambda_gp * ((per_row_norm - 1) ** 2).mean()
    return penalty
85
+
86
+
87
def train_SMLED(adata=None, X_dim=2, delta=1.0, train_epoch=15000, lr=0.001, mask_ratio=0.5, alpha=1.0,
                key_added='SMLED', step_size=10000, gamma=1.0,
                relu=True, gradient_clipping=5., experiment='generation', weight_decay=0.0001, verbose=True,
                batch_size=1000, lambda_gp=1.0,
                random_seed=2025, save_path='./SMLED_pyG_result', down_ratio=0., coord_sf=1.0,
                WMMSE=0.0, res=2.0, device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')):
    """\
    Train the adversarial auto-encoder and decode profiles at all coordinates.

    An encoder maps profiles to coordinate labels, a decoder maps coordinate
    labels back to profiles, and a discriminator separates the true labels
    from the encoder output. After training, the decoder is evaluated on the
    full coordinate table and the result is written to ``save_path``.

    Parameters
    ----------
    adata
        AnnData object of scanpy package.
    X_dim
        Width of the encoder output / decoder input (and discriminator input).
    delta, coord_sf
        Passed to ``positional_pixel_step`` when building the coordinate labels.
    train_epoch
        Number of total epochs in training.
    lr
        Learning rate of the Adam optimizers.
    mask_ratio
        Passed to ``recovery_coord`` when ``experiment == 'recovery'``.
    alpha, key_added, gradient_clipping, lambda_gp
        Accepted for backward compatibility; not used by the current
        implementation (the clipping and gradient-penalty calls are disabled).
    step_size, gamma
        StepLR settings. The schedulers are stepped once per optimizer step,
        so ``step_size`` counts optimizer steps (scaled by the number of
        critic/generator iterations per batch).
    relu
        Whether encoder/decoder apply a final ReLU. If False, the decoded
        profiles are clipped at 0 instead.
    experiment
        One of ``'generation'``, ``'higher_res'`` or ``'recovery'``.
    weight_decay
        Weight decay of the Adam optimizers.
    verbose
        Print the input size and the coordinate tables.
    batch_size
        Batch size of the training loader.
    random_seed
        Seed passed to ``fix_seed``.
    save_path
        Output directory; created (including parents) if missing.
    down_ratio
        Dropout probability of the encoder. The decoder is built without dropout.
    WMMSE
        If > 0, the reconstruction losses are weighted per feature with
        ``WMMSE * (1 - column_sum / total_sum)`` computed from ``adata.X``.
    res
        Passed to ``generation_coord`` when ``experiment == 'higher_res'``.
    device
        See torch.device.

    Returns
    -------
    AnnData
        For ``'generation'`` and ``'higher_res'``: the generated AnnData.
        For ``'recovery'``: the tuple ``(adata_sample, adata_SMLED)``.

    Raises
    ------
    ValueError
        If ``experiment`` is not one of the supported modes.
    """
    # NOTE(review): `get_data`, `MyDataset`, `ToTensor`, `transforms`,
    # `DataLoader` and `sc` are not imported explicitly in this module;
    # presumably they come from `from .dataset import *` - confirm.
    supported_experiments = ('recovery', 'higher_res', 'generation')
    if experiment not in supported_experiments:
        # Fail early instead of hitting a NameError after seeding and I/O.
        raise ValueError(
            "experiment must be one of %s, got %r" % (supported_experiments, experiment))

    fix_seed(random_seed)
    # makedirs also handles nested output paths, which os.mkdir does not.
    os.makedirs(save_path, exist_ok=True)
    if verbose:
        print('Size of Input: ', adata.X.shape)

    # ------------------------------------------------------------------
    # Data preparation: observed coordinates, full coordinates, profiles.
    # ------------------------------------------------------------------
    adata_sample = None
    if experiment == 'recovery':
        coor, full_coor, sample_index, sample_barcode = recovery_coord(
            adata, name='spatial', mask_ratio=mask_ratio)
        used_gene, normed_data, adata_sample = get_data(
            adata, experiment=experiment, sample_index=sample_index, sample_barcode=sample_barcode)
    elif experiment == 'higher_res':
        coor, full_coor = generation_coord(adata, name='spatial', res=res)
        used_gene, normed_data = get_data(adata, experiment=experiment)
    else:  # 'generation'
        coor = adata.obsm['spatial']
        full_coor = adata.uns['coord']
        used_gene, normed_data = get_data(adata, experiment=experiment)

    xlabel_df, full_xlabel_df = positional_pixel_step(coor, full_coor, delta, coord_sf)
    if verbose:
        print(xlabel_df, full_xlabel_df)
    transformed_dataset = MyDataset(
        normed_data=normed_data, coor_df=xlabel_df, transform=transforms.Compose([ToTensor()]))
    train_loader = DataLoader(
        transformed_dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=False)

    # ------------------------------------------------------------------
    # Models, optimizers and schedulers.
    # ------------------------------------------------------------------
    gene_number = len(used_gene)
    encoder = Encoder(gene_number, X_dim, down_ratio)
    decoder = Decoder(gene_number, X_dim, down_ratio=0.)  # decoder never uses dropout
    discriminator_AB = Discriminator_A(X_dim)

    encoder, decoder = encoder.to(device), decoder.to(device)
    discriminator_AB = discriminator_AB.to(device)

    adam_kwargs = dict(lr=lr, weight_decay=weight_decay, eps=1e-8, betas=(0.5, 0.999))
    enc_optim = torch.optim.Adam(encoder.parameters(), **adam_kwargs)
    dec_optim = torch.optim.Adam(decoder.parameters(), **adam_kwargs)
    disc_optim_AB = torch.optim.Adam(discriminator_AB.parameters(), **adam_kwargs)

    n_gen = 1   # generator (encoder/decoder) updates per batch
    n_crit = 2  # discriminator updates per batch
    # Schedulers are stepped after every optimizer step, hence the scaling.
    enc_sche = torch.optim.lr_scheduler.StepLR(enc_optim, step_size=n_gen * step_size, gamma=gamma)
    dec_sche = torch.optim.lr_scheduler.StepLR(dec_optim, step_size=n_gen * step_size, gamma=gamma)
    disc_sche_AB = torch.optim.lr_scheduler.StepLR(disc_optim_AB, step_size=n_crit * step_size, gamma=gamma)

    # ------------------------------------------------------------------
    # Loss functions.
    # ------------------------------------------------------------------
    criterion = torch.nn.BCELoss()
    if WMMSE > 0:
        # `.toarray()` replaces the `.A` shorthand removed from recent SciPy.
        matrix = adata.X.toarray() if sp.issparse(adata.X) else adata.X
        column_sums = matrix.sum(axis=0)
        normalized = column_sums * (WMMSE / column_sums.sum())
        # Features with a larger share of the total signal get a smaller weight.
        # NOTE(review): assumes adata.X has one column per entry of
        # `used_gene` so the weights broadcast against a batch - confirm.
        weights = torch.tensor(WMMSE - normalized, dtype=torch.float32, device=device)
        loss2 = WeightedMSELoss(weights)
        loss1 = WeightedMAELoss(weights)
    else:
        loss2 = torch.nn.MSELoss()
        loss1 = torch.nn.L1Loss()
    MAE = torch.nn.L1Loss()

    # ------------------------------------------------------------------
    # Training loop.
    # ------------------------------------------------------------------
    with tqdm(range(train_epoch), total=train_epoch, desc='Epochs') as epoch:
        for _ in epoch:
            train_reloss = []
            train_GAloss = []
            train_latloss = []
            train_loss = []
            train_DAloss = []

            for xdata, xlabel in train_loader:
                xdata = xdata.to(torch.float32).to(device)
                xlabel = xlabel.to(torch.float32).to(device)

                # --- Discriminator: true labels vs. encoder output (BCE) ---
                for _ in range(n_crit):
                    discriminator_AB.train()
                    disc_optim_AB.zero_grad()
                    # Encoder gradients from this step were always discarded
                    # (zeroed before the generator update), so skip building
                    # the encoder graph altogether.
                    with torch.no_grad():
                        fake_xlabel = encoder(xdata, relu)
                    disc_real = discriminator_AB(xlabel).view(-1)
                    disc_fake = discriminator_AB(fake_xlabel).view(-1)
                    loss_dis_real = criterion(disc_real, torch.ones_like(disc_real))
                    loss_dis_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
                    d_total_loss = loss_dis_real + loss_dis_fake
                    train_DAloss.append(d_total_loss.item())
                    d_total_loss.backward()
                    disc_optim_AB.step()
                    disc_sche_AB.step()
                    discriminator_AB.eval()

                # --- Generator: reconstruction + latent + adversarial ---
                for _ in range(n_gen):
                    encoder.train()
                    decoder.train()
                    enc_optim.zero_grad()
                    dec_optim.zero_grad()
                    fake_xlabel = encoder(xdata, relu)
                    fake_xdata = decoder(fake_xlabel, relu)

                    disc_realA = discriminator_AB(xlabel)
                    disc_fakeA = discriminator_AB(fake_xlabel)
                    # Match the mean discriminator score of real and fake labels.
                    gA_loss = torch.abs(wasserstein_loss(disc_realA, disc_fakeA))

                    latent_loss = MAE(fake_xlabel, xlabel)
                    recon_loss = loss2(fake_xdata, xdata) + 0.1 * loss1(fake_xdata, xdata)
                    loss = recon_loss + 0.3 * latent_loss + gA_loss

                    train_latloss.append(latent_loss.item())
                    train_GAloss.append(gA_loss.item())
                    train_reloss.append(recon_loss.item())
                    train_loss.append(loss.item())
                    loss.backward()
                    enc_optim.step()
                    dec_optim.step()
                    enc_sche.step()
                    dec_sche.step()
                    encoder.eval()
                    decoder.eval()

            epoch_info = 'loss_re: %.5f, loss_lat: %.5f, loss_GA: %.5f, loss: %.5f, loss_DA: %.5f' % \
                         (torch.mean(torch.FloatTensor(train_reloss)),
                          torch.mean(torch.FloatTensor(train_latloss)),
                          torch.mean(torch.FloatTensor(train_GAloss)),
                          torch.mean(torch.FloatTensor(train_loss)),
                          torch.mean(torch.FloatTensor(train_DAloss)))
            epoch.set_postfix_str(epoch_info)

    torch.save(encoder, os.path.join(save_path, 'encoder.pth'))
    torch.save(decoder, os.path.join(save_path, 'decoder.pth'))

    # ------------------------------------------------------------------
    # Decode profiles for the full coordinate table, in batches.
    # ------------------------------------------------------------------
    encoder.eval()
    decoder.eval()
    full_coor_t = torch.from_numpy(np.array(full_xlabel_df)).to(torch.float32).to(device)
    dataloader_t = DataLoader(full_coor_t, batch_size=1000, shuffle=False)
    generate_profile_list = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for batch_coor_t in dataloader_t:
            batch_generate_profile = decoder(batch_coor_t, relu)
            generate_profile_list.append(batch_generate_profile.cpu().numpy())
    generate_profile = np.concatenate(generate_profile_list, axis=0)
    if not relu:
        # Without the final ReLU the decoder can emit negative values.
        generate_profile = np.clip(generate_profile, a_min=0, a_max=None)

    if experiment == 'recovery':
        np.savetxt(os.path.join(save_path, 'fill_data.txt'), generate_profile)

    st_intensity = csr_matrix(generate_profile, dtype=np.float32)
    adata_SMLED = sc.AnnData(st_intensity)
    adata_SMLED.obsm["spatial"] = full_coor
    adata_SMLED.var.index = used_gene

    adata.write(os.path.join(save_path, 'original_data.h5ad'))

    if experiment == 'recovery':
        adata_sample.write(os.path.join(save_path, 'sampled_data.h5ad'))
        adata_SMLED.obs = adata.obs
        adata_SMLED.write(os.path.join(save_path, 'recovered_data.h5ad'))
        return adata_sample, adata_SMLED

    # 'generation' or 'higher_res'
    adata_SMLED.write(os.path.join(save_path, 'generated_data.h5ad'))
    return adata_SMLED
351
+
352
+
353
def fix_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Also exports ``PYTHONHASHSEED`` to the process environment.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    cudnn.benchmark = False
    cudnn.deterministic = True
363
+ # os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'