SpaWeaver 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- SpaWeaver/ParallelTrain.py +234 -0
- SpaWeaver/__init__.py +7 -0
- SpaWeaver/args.py +48 -0
- SpaWeaver/model.py +438 -0
- SpaWeaver/preprocess.py +343 -0
- SpaWeaver/utils.py +606 -0
- spaweaver-0.0.1.dist-info/METADATA +22 -0
- spaweaver-0.0.1.dist-info/RECORD +11 -0
- spaweaver-0.0.1.dist-info/WHEEL +5 -0
- spaweaver-0.0.1.dist-info/licenses/LICENSE +662 -0
- spaweaver-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import torch
|
|
3
|
+
import itertools
|
|
4
|
+
import numpy as np
|
|
5
|
+
import pandas as pd
|
|
6
|
+
import scanpy as sc
|
|
7
|
+
import SpaWeaver as sw
|
|
8
|
+
import matplotlib.pyplot as plt
|
|
9
|
+
import torch.multiprocessing as mp
|
|
10
|
+
|
|
11
|
+
from tqdm import tqdm
|
|
12
|
+
import torch.nn.functional as F
|
|
13
|
+
import torch.multiprocessing as mp
|
|
14
|
+
from torch.utils.data import DataLoader, TensorDataset, DistributedSampler
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def train(rank, world_size, args, adata1, adata2, models):
    """Train the shared H&E encoder and two omics decoders on one DDP rank.

    Each rank draws mini-batches from both datasets through its own
    ``DistributedSampler`` shard. H&E features are mapped by ``mlp_1`` and then
    ``model_1`` into a shared embedding. ``reg_1`` / ``reg_2`` reconstruct the
    expression matrix of dataset 1 / dataset 2 from that embedding.

    The loss is::

        mse(rec_omics1, X1) + mse(rec_omics2, X2) + args.mmd_weight * MMD(he1_map, he2_map)

    After training, rank 0 writes the four submodule state dicts to
    ``{args.output_folder}model/{args.save_tag}/models.pt``. The directory is
    created if needed.

    Args:
        rank: Process rank; also used as the device index for ``.to(rank)``
            and ``device_ids``.
        world_size: Total number of processes.
        args: Namespace providing ``optimizer``, ``lr``, ``weight_decay``,
            ``batch_size``, ``epoch``, ``mmd_weight``, ``output_folder`` and
            ``save_tag``.
        adata1, adata2: AnnData-like objects with the expression matrix in
            ``.X`` and H&E features in ``.obsm['he_sp']``. Both are converted
            with ``torch.Tensor(...)``; presumably ``.X`` must be dense.
            TODO confirm.
        models: Container of exactly four modules, unpacked as
            ``(model_1, mlp_1, reg_1, reg_2)``. Presumably an
            ``nn.ModuleList``; verify against the caller.

    Raises:
        ValueError: If either dataset yields no full batch on this rank.
            ``drop_last=True`` discards a shard smaller than
            ``args.batch_size``, which would otherwise make every epoch a
            silent no-op.
    """
    sw.utils.setup(rank, world_size)

    models = models.to(rank)
    model_1, mlp_1, reg_1, reg_2 = models
    # NOTE(review): the forward passes below call the submodules directly
    # instead of going through the DDP wrapper. DDP's gradient all-reduce is
    # therefore probably never triggered, and each rank may train on its own
    # shard independently. Confirm before relying on multi-GPU results.
    models = torch.nn.parallel.DistributedDataParallel(models, device_ids=[rank])

    optimizer = sw.utils.create_optimizer(args.optimizer, models, args.lr, args.weight_decay)

    node_HE_fea1, node_HE_fea2 = torch.Tensor(adata1.obsm['he_sp']), torch.Tensor(adata2.obsm['he_sp'])
    X1_real, X2_real = torch.Tensor(adata1.X), torch.Tensor(adata2.X)
    dataset1 = TensorDataset(X1_real, node_HE_fea1)
    dataset2 = TensorDataset(X2_real, node_HE_fea2)
    sampler1 = DistributedSampler(dataset1, num_replicas=world_size, rank=rank)
    sampler2 = DistributedSampler(dataset2, num_replicas=world_size, rank=rank)
    dataloader1 = DataLoader(dataset1, batch_size=args.batch_size, sampler=sampler1, drop_last=True)
    dataloader2 = DataLoader(dataset2, batch_size=args.batch_size, sampler=sampler2, drop_last=True)

    rbf = sw.model.RBF().to(rank)
    MMD = sw.model.MMD_loss(kernel=rbf).to(rank)

    len1 = len(dataloader1)
    len2 = len(dataloader2)

    if rank == 0:
        print(f"====== Dataset Comprison: Data 1={len1} batches, Data 2={len2} batches ======")
        if len2 > len1:
            print(f"Data 2 data more,Automatically iterate over the Data 1 dataset...")

    # With drop_last=True, a shard smaller than batch_size yields zero batches.
    # zip()/cycle() over an empty loader then stops immediately, and the model
    # would be saved untrained without any warning. Every rank computes the
    # same lengths (DistributedSampler pads shards to equal size), so all
    # ranks take this branch together.
    if len1 == 0 or len2 == 0:
        sw.utils.cleanup()
        raise ValueError(
            f"No full batch available per rank (Data 1={len1}, Data 2={len2} "
            f"batches of size {args.batch_size}); reduce --batch_size."
        )

    epoch_iter = tqdm(range(args.epoch), desc="🧠 Training", disable=(rank != 0))
    for epoch in epoch_iter:
        # Reseed the samplers so each epoch uses a different shuffle.
        sampler1.set_epoch(epoch)
        sampler2.set_epoch(epoch)

        # Walk the larger loader once and recycle the smaller one, so every
        # batch of the larger dataset is paired with a batch of the other.
        if len1 >= len2:
            iterator = zip(dataloader1, itertools.cycle(dataloader2))
        else:
            iterator = zip(itertools.cycle(dataloader1), dataloader2)

        for data1, data2 in iterator:
            X1_real_batch, node_HE_fea1_batch = data1
            X2_real_batch, node_HE_fea2_batch = data2
            X1_real_batch, node_HE_fea1_batch = X1_real_batch.to(rank), node_HE_fea1_batch.to(rank)
            X2_real_batch, node_HE_fea2_batch = X2_real_batch.to(rank), node_HE_fea2_batch.to(rank)

            # Both datasets share the same H&E encoder (mlp_1 -> model_1).
            he1_map, he2_map = model_1(mlp_1(node_HE_fea1_batch)), model_1(mlp_1(node_HE_fea2_batch))
            rec_omics1 = reg_1(he1_map)
            rec_omics2 = reg_2(he2_map)

            # loss: per-dataset reconstruction + MMD alignment of the embeddings
            mmd = MMD(he1_map, he2_map)
            loss1 = F.mse_loss(rec_omics1, X1_real_batch)
            loss2 = F.mse_loss(rec_omics2, X2_real_batch)
            loss = loss1 + loss2 + args.mmd_weight * mmd

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    sw.utils.cleanup()

    if rank == 0:
        print("💾 Saving checkpoint...")
        model_1_save = model_1.module if hasattr(model_1, "module") else model_1
        mlp_1_save = mlp_1.module if hasattr(mlp_1, "module") else mlp_1
        reg_1_save = reg_1.module if hasattr(reg_1, "module") else reg_1
        reg_2_save = reg_2.module if hasattr(reg_2, "module") else reg_2

        save_path = f'{args.output_folder}model/{args.save_tag}/models.pt'
        # torch.save does not create parent directories. Without this, a
        # missing folder fails only after the whole training run.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save({
            'model_1': model_1_save.state_dict(),
            'mlp_1': mlp_1_save.state_dict(),
            'reg_1': reg_1_save.state_dict(),
            'reg_2': reg_2_save.state_dict()
        }, save_path)
        print(f"✅ Checkpoint saved to {save_path}.")
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def train_Multi_Omics(rank, world_size, args, adata1, adata2, models, anno1, anno2):
    """Train the shared H&E encoder with annotation-conditioned omics decoders.

    Same scheme as ``train``, with one addition. Each spot's integer
    annotation is embedded by ``anno_emb``, passed through dropout
    (p=0.1), and concatenated to the H&E embedding before decoding with
    ``reg_1`` / ``reg_2``. The MMD term is computed on the H&E embeddings
    only.

    After training, rank 0 writes the five submodule state dicts to
    ``{args.output_folder}model/{args.save_tag}/models.pt``. The directory is
    created if needed.

    Args:
        rank: Process rank; also used as the device index.
        world_size: Total number of processes.
        args: Namespace providing ``optimizer``, ``lr``, ``weight_decay``,
            ``batch_size``, ``epoch``, ``mmd_weight``, ``output_folder`` and
            ``save_tag``.
        adata1, adata2: AnnData-like objects with expression in ``.X`` and
            H&E features in ``.obsm['he_sp']``.
        models: Container of five modules, unpacked as
            ``(model_1, mlp_1, reg_1, reg_2, anno_emb)``.
        anno1, anno2: Per-spot integer labels, converted with
            ``torch.LongTensor``. Presumably ``anno_emb`` is an
            ``nn.Embedding``; TODO confirm.
    """
    sw.utils.setup(rank, world_size)

    models = models.to(rank)
    model_1, mlp_1, reg_1, reg_2, anno_emb = models
    # NOTE(review): submodules are called directly rather than through the DDP
    # wrapper, so DDP gradient synchronisation is probably not triggered.
    # Confirm.
    models = torch.nn.parallel.DistributedDataParallel(models, device_ids=[rank])

    optimizer = sw.utils.create_optimizer(args.optimizer, models, args.lr, args.weight_decay)

    node_HE_fea1, node_HE_fea2 = torch.Tensor(adata1.obsm['he_sp']), torch.Tensor(adata2.obsm['he_sp'])
    X1_real, X2_real = torch.Tensor(adata1.X), torch.Tensor(adata2.X)
    anno1, anno2 = torch.LongTensor(anno1), torch.LongTensor(anno2)

    dataset1 = TensorDataset(X1_real, node_HE_fea1, anno1)
    dataset2 = TensorDataset(X2_real, node_HE_fea2, anno2)
    sampler1 = DistributedSampler(dataset1, num_replicas=world_size, rank=rank)
    sampler2 = DistributedSampler(dataset2, num_replicas=world_size, rank=rank)
    dataloader1 = DataLoader(dataset1, batch_size=args.batch_size, sampler=sampler1, drop_last=False)
    dataloader2 = DataLoader(dataset2, batch_size=args.batch_size, sampler=sampler2, drop_last=False)

    rbf = sw.model.RBF().to(rank)
    MMD = sw.model.MMD_loss(kernel=rbf).to(rank)

    epoch_iter = tqdm(range(args.epoch), desc="🧠 Training", disable=(rank != 0))
    for epoch in epoch_iter:
        sampler1.set_epoch(epoch)
        sampler2.set_epoch(epoch)

        # NOTE(review): zip() stops at the shorter loader, so the surplus
        # batches of the larger dataset are skipped each epoch. ``train``
        # instead cycles the smaller loader. Confirm which behaviour is
        # intended.
        for data1, data2 in zip(dataloader1, dataloader2):
            X1_real_batch, node_HE_fea1_batch, anno1_batch = data1
            X2_real_batch, node_HE_fea2_batch, anno2_batch = data2
            X1_real_batch, node_HE_fea1_batch, anno1_batch = X1_real_batch.to(rank), node_HE_fea1_batch.to(rank), anno1_batch.to(rank)
            X2_real_batch, node_HE_fea2_batch, anno2_batch = X2_real_batch.to(rank), node_HE_fea2_batch.to(rank), anno2_batch.to(rank)

            # F.dropout defaults to training=True, so this dropout is always
            # applied here.
            anno_emb1, anno_emb2 = F.dropout(anno_emb(anno1_batch), p=0.1), F.dropout(anno_emb(anno2_batch), p=0.1)
            he1_map, he2_map = model_1(mlp_1(node_HE_fea1_batch)), model_1(mlp_1(node_HE_fea2_batch))
            # Decoders see the H&E embedding concatenated with the annotation embedding.
            emb1, emb2 = torch.concat([he1_map, anno_emb1], dim=1), torch.concat([he2_map, anno_emb2], dim=1)
            rec_omics1, rec_omics2 = reg_1(emb1), reg_2(emb2)

            loss1 = F.mse_loss(rec_omics1, X1_real_batch)
            loss2 = F.mse_loss(rec_omics2, X2_real_batch)
            mmd = MMD(he1_map, he2_map)
            loss = loss1 + loss2 + args.mmd_weight * mmd

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    sw.utils.cleanup()

    if rank == 0:
        print("💾 Saving checkpoint...")
        model_1_save = model_1.module if hasattr(model_1, "module") else model_1
        mlp_1_save = mlp_1.module if hasattr(mlp_1, "module") else mlp_1
        reg_1_save = reg_1.module if hasattr(reg_1, "module") else reg_1
        reg_2_save = reg_2.module if hasattr(reg_2, "module") else reg_2
        anno_emb_save = anno_emb.module if hasattr(anno_emb, "module") else anno_emb

        save_path = f'{args.output_folder}model/{args.save_tag}/models.pt'
        # torch.save does not create parent directories.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save({
            'model_1': model_1_save.state_dict(),
            'mlp_1': mlp_1_save.state_dict(),
            'reg_1': reg_1_save.state_dict(),
            'reg_2': reg_2_save.state_dict(),
            'anno_emb': anno_emb_save.state_dict(),
        }, save_path)
        print(f"✅ Checkpoint saved to {save_path}.")
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def train_Cross_Resolution(rank, world_size, args, adata1, adata2, models, agg_mtx2, anno1, anno2):
    """Train across resolutions: full dataset 1 vs. spot-batched dataset 2.

    Data flow:

    * Dataset 1 (features ``adata1.obsm['sp_he']``, targets ``adata1.X``) is
      used in full at every step. Its embedding is ``model_1(mlp_1(.))``.
    * Dataset 2 is batched by row index of ``agg_mtx2`` (spots). For each
      batch, the columns with stored entries in those rows select the cells of
      ``adata2`` to encode with ``model_1(mlp_2(.))``.
    * The cell embeddings are aggregated to spots as ``agg_mtx2_batch @
      he2_map`` and aligned to dataset 1's embedding by MMD.
    * Annotation embeddings are concatenated before decoding. Dataset 1 uses
      ``sw.utils.mean_annotation_embedding``; dataset 2 uses a direct
      ``anno_emb`` lookup.

    After training, rank 0 writes the six submodule state dicts to
    ``{args.output_folder}model/{args.save_tag}/models.pt``. The directory is
    created if needed.

    Args:
        rank: Process rank; also used as the device index.
        world_size: Total number of processes.
        args: Namespace providing ``optimizer``, ``lr``, ``weight_decay``,
            ``batch_size``, ``epoch``, ``mmd_weight``, ``output_folder`` and
            ``save_tag``.
        adata1, adata2: AnnData-like objects with expression in ``.X`` and
            H&E features in ``.obsm['sp_he']``.
        models: Container of six modules, unpacked as
            ``(model_1, mlp_1, mlp_2, reg_1, reg_2, anno_emb)``.
        agg_mtx2: Sparse matrix (``.tocoo()`` is called on it); rows are
            batched as spots, columns index rows of ``adata2``. Presumably a
            SciPy sparse matrix; TODO confirm.
        anno1: Converted with ``torch.FloatTensor``. Presumably per-spot
            annotation weights; verify against
            ``sw.utils.mean_annotation_embedding``.
        anno2: Per-cell integer labels (``torch.LongTensor``) fed to ``anno_emb``.
    """
    sw.utils.setup(rank, world_size)

    models = models.to(rank)
    model_1, mlp_1, mlp_2, reg_1, reg_2, anno_emb = models
    # NOTE(review): submodules are called directly rather than through the DDP
    # wrapper, so DDP gradient synchronisation is probably not triggered.
    # Confirm.
    models = torch.nn.parallel.DistributedDataParallel(models, device_ids=[rank])
    optimizer = sw.utils.create_optimizer(args.optimizer, models, args.lr, args.weight_decay)

    he1, he2 = torch.Tensor(adata1.obsm['sp_he']), torch.Tensor(adata2.obsm['sp_he'])
    panelA1, panelB2 = torch.Tensor(adata1.X), torch.Tensor(adata2.X)
    anno1, anno2 = torch.FloatTensor(anno1).to(rank), torch.LongTensor(anno2).to(rank)

    # Dataset 2 is sharded and batched by spot (row) index of agg_mtx2.
    dataset2 = TensorDataset(torch.arange(agg_mtx2.shape[0]))
    batch_size2 = args.batch_size
    sampler2 = DistributedSampler(dataset2, num_replicas=world_size, rank=rank)
    dataloader2 = DataLoader(dataset2, batch_size=batch_size2, sampler=sampler2, drop_last=False)
    # Dataset 1 is moved to the device once and used in full every step.
    he1_batch = he1.to(rank)
    panelA1_batch = panelA1.to(rank)

    rbf = sw.model.RBF().to(rank)
    MMD = sw.model.MMD_loss(kernel=rbf).to(rank)

    # Only rank 0 reports progress, matching the other training functions;
    # otherwise every process prints its own banner and progress bar.
    if rank == 0:
        print('================================ Train ================================\n')
    for epoch in range(args.epoch):
        sampler2.set_epoch(epoch)
        batch_iter = tqdm(
            dataloader2,
            desc=f'🚀 Epoch {epoch + 1}/{args.epoch}',
            leave=False,
            disable=(rank != 0),
        )
        for spot_idx2 in batch_iter:
            spot_ids = spot_idx2[0].numpy()
            spot_rows = agg_mtx2[spot_ids]
            # Cells referenced by any spot in this batch. np.unique removes
            # cells shared by several spots. Duplicates would appear as
            # repeated columns of agg_mtx2_batch and be counted twice in
            # agg_mtx2_batch @ he2_map. The (sorted) order is used consistently
            # for the matrix columns and for the cell rows gathered below.
            cell_idx2 = np.unique(spot_rows.tocoo().col)
            agg_mtx2_batch = sw.utils.sparse_mx_to_torch_sparse_tensor(spot_rows[:, cell_idx2], device=rank)
            panelB2_batch, he2_batch = panelB2[cell_idx2].to(rank), he2[cell_idx2].to(rank)

            # p=0 makes these dropouts no-ops; kept as hooks for tuning.
            anno_emb1 = F.dropout(sw.utils.mean_annotation_embedding(anno1, anno_emb), p=0)
            anno_emb2 = F.dropout(anno_emb(anno2[cell_idx2]), p=0)
            he1_map, he2_map = model_1(mlp_1(he1_batch)), model_1(mlp_2(he2_batch))
            emb1, emb2 = torch.concat([he1_map, anno_emb1], dim=1), torch.concat([he2_map, anno_emb2], dim=1)
            rec_omics1 = reg_1(emb1)
            rec_omics2 = reg_2(emb2)

            # Aggregate cell embeddings to spot level before aligning with dataset 1.
            mmd = MMD(he1_map, agg_mtx2_batch @ he2_map)
            loss1 = F.mse_loss(rec_omics1, panelA1_batch)
            loss2 = F.mse_loss(rec_omics2, panelB2_batch)
            loss = loss1 + loss2 + args.mmd_weight * mmd

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if rank == 0:
                batch_iter.set_postfix({
                    'loss1': f'{loss1.item():.4f}',
                    'loss2': f'{loss2.item():.4f}',
                    'mmd': f'{mmd.item():.4f}'
                })
    sw.utils.cleanup()

    if rank == 0:
        model_1_save = model_1.module if hasattr(model_1, "module") else model_1
        mlp_1_save = mlp_1.module if hasattr(mlp_1, "module") else mlp_1
        mlp_2_save = mlp_2.module if hasattr(mlp_2, "module") else mlp_2
        reg_1_save = reg_1.module if hasattr(reg_1, "module") else reg_1
        reg_2_save = reg_2.module if hasattr(reg_2, "module") else reg_2
        anno_emb_save = anno_emb.module if hasattr(anno_emb, "module") else anno_emb

        save_path = f'{args.output_folder}model/{args.save_tag}/models.pt'
        # torch.save does not create parent directories.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save({
            'model_1': model_1_save.state_dict(),
            'mlp_1': mlp_1_save.state_dict(),
            'mlp_2': mlp_2_save.state_dict(),
            'reg_1': reg_1_save.state_dict(),
            'reg_2': reg_2_save.state_dict(),
            'anno_emb': anno_emb_save.state_dict(),
        }, save_path)
|
SpaWeaver/__init__.py
ADDED
SpaWeaver/args.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def build_args():
|
|
5
|
+
parser = argparse.ArgumentParser(description="Model")
|
|
6
|
+
parser.add_argument("--seed", type=int, default=0)
|
|
7
|
+
parser.add_argument("--device", type=int, default=0)
|
|
8
|
+
|
|
9
|
+
parser.add_argument("--optimizer", type=str, default="adam")
|
|
10
|
+
parser.add_argument("--load_model", action="store_true")
|
|
11
|
+
|
|
12
|
+
# graph transformer
|
|
13
|
+
parser.add_argument('--hops', type=int, default=3, help='Hop of neighbors to be calculated')
|
|
14
|
+
parser.add_argument('--pe_dim', type=int, default=128, help='position embedding size')
|
|
15
|
+
parser.add_argument('--hidden_dim', type=int, default=128, help='Hidden layer size')
|
|
16
|
+
parser.add_argument('--n_layers', type=int, default=1, help='Number of Transformer layers')
|
|
17
|
+
parser.add_argument('--n_heads', type=int, default=2, help='Number of Transformer heads')
|
|
18
|
+
parser.add_argument('--dropout', type=float, default=0.1, help='Dropout')
|
|
19
|
+
parser.add_argument('--attention_dropout', type=float, default=0.1, help='Dropout in the attention layer')
|
|
20
|
+
parser.add_argument("--activation", type=str, default="elu")
|
|
21
|
+
|
|
22
|
+
# adjustable parameters
|
|
23
|
+
parser.add_argument("--epoch", type=int, default=500, help="number of training epochs")
|
|
24
|
+
parser.add_argument("--batch_size", type=int, default=4096)
|
|
25
|
+
parser.add_argument("--loss_fn", type=str, default="mse")
|
|
26
|
+
parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
|
|
27
|
+
parser.add_argument("--weight_decay", type=float, default=0, help="weight decay")
|
|
28
|
+
|
|
29
|
+
# File parameter
|
|
30
|
+
parser.add_argument("--sample_name1", type=str, default="Human_Breast_Cancer_Rep1")
|
|
31
|
+
parser.add_argument("--root_path1", type=str, default='./datasets/Human_Breast_Cancer_Rep1/')
|
|
32
|
+
parser.add_argument("--sample_name2", type=str, default="Human_Breast_Cancer_Rep2")
|
|
33
|
+
parser.add_argument("--root_path2", type=str, default='./datasets/Human_Breast_Cancer_Rep2/')
|
|
34
|
+
parser.add_argument("--save", type=bool, default=True)
|
|
35
|
+
parser.add_argument("--save_tag", type=str, default='fig2')
|
|
36
|
+
parser.add_argument("--output_folder", type=str, default="./outputs/")
|
|
37
|
+
|
|
38
|
+
parser.add_argument("--image_encoder", type=str, default="uni")
|
|
39
|
+
parser.add_argument("--img_batch_size", type=int, default=64)
|
|
40
|
+
parser.add_argument("--num_neighbors", type=int, default=7)
|
|
41
|
+
parser.add_argument("--scale", type=float, default=0.363788)
|
|
42
|
+
parser.add_argument("--cell_diameter", type=float, default=-1, help="By physical size (um)")
|
|
43
|
+
parser.add_argument("--resolution", type=float, default=64, help="By pixels")
|
|
44
|
+
parser.add_argument("--mmd_weight", type=float, default=0.01, help="weight for MMD")
|
|
45
|
+
|
|
46
|
+
# read parameters
|
|
47
|
+
args = parser.parse_args()
|
|
48
|
+
return args
|