SpaWeaver-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,234 @@
1
+ import os
2
+ import torch
3
+ import itertools
4
+ import numpy as np
5
+ import pandas as pd
6
+ import scanpy as sc
7
+ import SpaWeaver as sw
8
+ import matplotlib.pyplot as plt
9
+ import torch.multiprocessing as mp
10
+
11
+ from tqdm import tqdm
12
+ import torch.nn.functional as F
13
+ import torch.multiprocessing as mp
14
+ from torch.utils.data import DataLoader, TensorDataset, DistributedSampler
15
+
16
+
17
def train(rank, world_size, args, adata1, adata2, models):
    """Jointly train the H&E encoder and two omics regressors across processes.

    Every step takes one batch from each dataset and does the following:
    - encodes both H&E feature batches with the shared ``mlp_1`` -> ``model_1`` stack;
    - reconstructs each dataset's ``.X`` with its own regressor (``reg_1`` / ``reg_2``);
    - adds ``args.mmd_weight`` times an MMD penalty between the two encodings.

    The loader with fewer batches is cycled, so every batch of the longer
    loader is used once per epoch.

    Args:
        rank: Rank of this process; also used as the target device for all
            tensors and as ``device_ids`` of the DDP wrapper.
        world_size: Number of processes in the group.
        args: Parsed options. Reads ``optimizer``, ``lr``, ``weight_decay``,
            ``batch_size``, ``epoch``, ``mmd_weight``, ``output_folder`` and
            ``save_tag``.
        adata1, adata2: Objects exposing ``.X`` (reconstruction targets) and
            ``.obsm['he_sp']`` (H&E features); presumably AnnData.
        models: Module container that supports ``.to`` and unpacks to
            ``(model_1, mlp_1, reg_1, reg_2)``.

    Raises:
        ValueError: If either loader yields no batch on this rank. With
            ``drop_last=True`` this happens when a rank's shard is smaller
            than ``batch_size``; previously training then silently did nothing.

    On rank 0 the four state dicts are written to
    ``f"{args.output_folder}model/{args.save_tag}/models.pt"``. The directory
    is created if it does not exist.
    """
    import torch.distributed as dist

    def _average_gradients(module):
        # Sum every gradient over all ranks, then divide, so each rank applies
        # the same globally averaged update and the replicas stay identical.
        if world_size <= 1 or not dist.is_initialized():
            return
        for param in module.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
                param.grad.div_(world_size)

    sw.utils.setup(rank, world_size)
    try:
        models = models.to(rank)
        model_1, mlp_1, reg_1, reg_2 = models
        # DDP broadcasts rank 0's initial weights. The forward passes below call
        # the sub-modules directly (bypassing DDP.forward), so DDP never reduces
        # gradients itself; _average_gradients does that after backward().
        models = torch.nn.parallel.DistributedDataParallel(models, device_ids=[rank])

        optimizer = sw.utils.create_optimizer(args.optimizer, models, args.lr, args.weight_decay)

        node_HE_fea1, node_HE_fea2 = torch.Tensor(adata1.obsm['he_sp']), torch.Tensor(adata2.obsm['he_sp'])
        X1_real, X2_real = torch.Tensor(adata1.X), torch.Tensor(adata2.X)
        dataset1 = TensorDataset(X1_real, node_HE_fea1)
        dataset2 = TensorDataset(X2_real, node_HE_fea2)
        sampler1 = DistributedSampler(dataset1, num_replicas=world_size, rank=rank)
        sampler2 = DistributedSampler(dataset2, num_replicas=world_size, rank=rank)
        dataloader1 = DataLoader(dataset1, batch_size=args.batch_size, sampler=sampler1, drop_last=True)
        dataloader2 = DataLoader(dataset2, batch_size=args.batch_size, sampler=sampler2, drop_last=True)

        rbf = sw.model.RBF().to(rank)
        MMD = sw.model.MMD_loss(kernel=rbf).to(rank)

        len1 = len(dataloader1)
        len2 = len(dataloader2)
        # DistributedSampler gives every rank an equally sized shard, so this
        # check passes or fails on all ranks together.
        if len1 == 0 or len2 == 0:
            raise ValueError(
                f"No full batch of size {args.batch_size} per rank "
                f"(Data 1={len1} batches, Data 2={len2} batches); "
                f"reduce batch_size or the number of processes."
            )

        if rank == 0:
            print(f"====== Dataset Comprison: Data 1={len1} batches, Data 2={len2} batches ======")
            if len2 > len1:
                print(f"Data 2 data more,Automatically iterate over the Data 1 dataset...")

        epoch_iter = tqdm(range(args.epoch), desc="🧠 Training", disable=(rank != 0))
        for epoch in epoch_iter:
            # Reseed both samplers so each epoch gets a fresh shuffle.
            sampler1.set_epoch(epoch)
            sampler2.set_epoch(epoch)

            # Walk the longer loader once and repeat the shorter one as needed.
            # itertools.cycle replays the batches from its first pass, so the
            # shorter loader keeps one order within an epoch.
            if len1 >= len2:
                iterator = zip(dataloader1, itertools.cycle(dataloader2))
            else:
                iterator = zip(itertools.cycle(dataloader1), dataloader2)

            for (X1_real_batch, node_HE_fea1_batch), (X2_real_batch, node_HE_fea2_batch) in iterator:
                X1_real_batch, node_HE_fea1_batch = X1_real_batch.to(rank), node_HE_fea1_batch.to(rank)
                X2_real_batch, node_HE_fea2_batch = X2_real_batch.to(rank), node_HE_fea2_batch.to(rank)

                he1_map = model_1(mlp_1(node_HE_fea1_batch))
                he2_map = model_1(mlp_1(node_HE_fea2_batch))
                rec_omics1 = reg_1(he1_map)
                rec_omics2 = reg_2(he2_map)

                mmd = MMD(he1_map, he2_map)
                loss1 = F.mse_loss(rec_omics1, X1_real_batch)
                loss2 = F.mse_loss(rec_omics2, X2_real_batch)
                loss = loss1 + loss2 + args.mmd_weight * mmd

                optimizer.zero_grad()
                loss.backward()
                _average_gradients(models)
                optimizer.step()
    finally:
        # Tear down the process group even when training raises.
        sw.utils.cleanup()

    if rank == 0:
        print("💾 Saving checkpoint...")
        ckpt_dir = f'{args.output_folder}model/{args.save_tag}/'
        # Create the folder up front so a finished run is not lost at save time.
        os.makedirs(ckpt_dir, exist_ok=True)
        ckpt_path = f'{ckpt_dir}models.pt'
        # getattr(m, "module", m) unwraps a wrapper exposing `.module`, if any.
        torch.save({
            'model_1': getattr(model_1, "module", model_1).state_dict(),
            'mlp_1': getattr(mlp_1, "module", mlp_1).state_dict(),
            'reg_1': getattr(reg_1, "module", reg_1).state_dict(),
            'reg_2': getattr(reg_2, "module", reg_2).state_dict()
        }, ckpt_path)
        print(f"✅ Checkpoint saved to {ckpt_path}.")
94
+
95
+
96
def train_Multi_Omics(rank, world_size, args, adata1, adata2, models, anno1, anno2):
    """Train the annotation-conditioned H&E -> omics model across processes.

    Each observation carries an integer annotation label. Every step does the following:
    - embeds the labels with ``anno_emb`` and applies dropout p=0.1;
    - concatenates them with the shared ``mlp_1`` -> ``model_1`` H&E encoding;
    - reconstructs each dataset's ``.X`` with its own regressor (``reg_1`` / ``reg_2``).

    The MMD penalty, weighted by ``args.mmd_weight``, is computed on the H&E
    encodings only.

    Args:
        rank: Rank of this process; also used as the target device for all
            tensors and as ``device_ids`` of the DDP wrapper.
        world_size: Number of processes in the group.
        args: Parsed options. Reads ``optimizer``, ``lr``, ``weight_decay``,
            ``batch_size``, ``epoch``, ``mmd_weight``, ``output_folder`` and
            ``save_tag``.
        adata1, adata2: Objects exposing ``.X`` (reconstruction targets) and
            ``.obsm['he_sp']`` (H&E features); presumably AnnData.
        models: Module container that supports ``.to`` and unpacks to
            ``(model_1, mlp_1, reg_1, reg_2, anno_emb)``.
        anno1, anno2: Per-observation annotation labels, converted to int64
            tensors and batched alongside the matching ``adata`` rows.

    On rank 0 the five state dicts are written to
    ``f"{args.output_folder}model/{args.save_tag}/models.pt"``. The directory
    is created if it does not exist.
    """
    import torch.distributed as dist

    def _average_gradients(module):
        # Sum every gradient over all ranks, then divide, so each rank applies
        # the same globally averaged update and the replicas stay identical.
        if world_size <= 1 or not dist.is_initialized():
            return
        for param in module.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
                param.grad.div_(world_size)

    sw.utils.setup(rank, world_size)
    try:
        models = models.to(rank)
        model_1, mlp_1, reg_1, reg_2, anno_emb = models
        # DDP broadcasts rank 0's initial weights. The forward passes below call
        # the sub-modules directly (bypassing DDP.forward), so DDP never reduces
        # gradients itself; _average_gradients does that after backward().
        models = torch.nn.parallel.DistributedDataParallel(models, device_ids=[rank])

        optimizer = sw.utils.create_optimizer(args.optimizer, models, args.lr, args.weight_decay)

        node_HE_fea1, node_HE_fea2 = torch.Tensor(adata1.obsm['he_sp']), torch.Tensor(adata2.obsm['he_sp'])
        X1_real, X2_real = torch.Tensor(adata1.X), torch.Tensor(adata2.X)
        anno1, anno2 = torch.LongTensor(anno1), torch.LongTensor(anno2)

        dataset1 = TensorDataset(X1_real, node_HE_fea1, anno1)
        dataset2 = TensorDataset(X2_real, node_HE_fea2, anno2)
        sampler1 = DistributedSampler(dataset1, num_replicas=world_size, rank=rank)
        sampler2 = DistributedSampler(dataset2, num_replicas=world_size, rank=rank)
        dataloader1 = DataLoader(dataset1, batch_size=args.batch_size, sampler=sampler1, drop_last=False)
        dataloader2 = DataLoader(dataset2, batch_size=args.batch_size, sampler=sampler2, drop_last=False)

        rbf = sw.model.RBF().to(rank)
        MMD = sw.model.MMD_loss(kernel=rbf).to(rank)

        epoch_iter = tqdm(range(args.epoch), desc="🧠 Training", disable=(rank != 0))
        for epoch in epoch_iter:
            # Reseed both samplers so each epoch gets a fresh shuffle.
            sampler1.set_epoch(epoch)
            sampler2.set_epoch(epoch)

            # NOTE(review): zip stops at the shorter loader, so each epoch only
            # a (reshuffled) subset of the longer dataset is used. This differs
            # from cycling the shorter loader; confirm it is intended.
            for data1, data2 in zip(dataloader1, dataloader2):
                X1_real_batch, node_HE_fea1_batch, anno1_batch = (t.to(rank) for t in data1)
                X2_real_batch, node_HE_fea2_batch, anno2_batch = (t.to(rank) for t in data2)

                # F.dropout defaults to training=True, so dropout (p=0.1) is
                # always applied to the annotation embeddings here.
                anno_emb1 = F.dropout(anno_emb(anno1_batch), p=0.1)
                anno_emb2 = F.dropout(anno_emb(anno2_batch), p=0.1)
                he1_map = model_1(mlp_1(node_HE_fea1_batch))
                he2_map = model_1(mlp_1(node_HE_fea2_batch))
                emb1 = torch.concat([he1_map, anno_emb1], dim=1)
                emb2 = torch.concat([he2_map, anno_emb2], dim=1)
                rec_omics1, rec_omics2 = reg_1(emb1), reg_2(emb2)

                loss1 = F.mse_loss(rec_omics1, X1_real_batch)
                loss2 = F.mse_loss(rec_omics2, X2_real_batch)
                mmd = MMD(he1_map, he2_map)
                loss = loss1 + loss2 + args.mmd_weight * mmd

                optimizer.zero_grad()
                loss.backward()
                _average_gradients(models)
                optimizer.step()
    finally:
        # Tear down the process group even when training raises.
        sw.utils.cleanup()

    if rank == 0:
        print("💾 Saving checkpoint...")
        ckpt_dir = f'{args.output_folder}model/{args.save_tag}/'
        # Create the folder up front so a finished run is not lost at save time.
        os.makedirs(ckpt_dir, exist_ok=True)
        ckpt_path = f'{ckpt_dir}models.pt'
        # getattr(m, "module", m) unwraps a wrapper exposing `.module`, if any.
        torch.save({
            'model_1': getattr(model_1, "module", model_1).state_dict(),
            'mlp_1': getattr(mlp_1, "module", mlp_1).state_dict(),
            'reg_1': getattr(reg_1, "module", reg_1).state_dict(),
            'reg_2': getattr(reg_2, "module", reg_2).state_dict(),
            'anno_emb': getattr(anno_emb, "module", anno_emb).state_dict(),
        }, ckpt_path)
        print(f"✅ Checkpoint saved to {ckpt_path}.")
159
+
160
+
161
def train_Cross_Resolution(rank, world_size, args, adata1, adata2, models, agg_mtx2, anno1, anno2):
    """Train the cross-resolution model: full dataset 1 against mini-batches of dataset 2.

    Dataset 1 is used in full at every step. Dataset 2 is mini-batched over
    the rows of ``agg_mtx2``. For each batch:
    - the ``adata2`` observations (columns of ``agg_mtx2``) with a non-zero
      entry in those rows are gathered;
    - they are encoded with ``mlp_2`` -> ``model_1``;
    - the encodings are aggregated through the batch's slice of ``agg_mtx2``.

    The aggregated encodings are compared with dataset 1's ``mlp_1`` ->
    ``model_1`` encodings by an MMD penalty weighted by ``args.mmd_weight``.
    Each dataset's ``.X`` is reconstructed from its encoding concatenated with
    an annotation embedding.

    Args:
        rank: Rank of this process; also used as the target device for all
            tensors and as ``device_ids`` of the DDP wrapper.
        world_size: Number of processes in the group.
        args: Parsed options. Reads ``optimizer``, ``lr``, ``weight_decay``,
            ``batch_size``, ``epoch``, ``mmd_weight``, ``output_folder`` and
            ``save_tag``.
        adata1, adata2: Objects exposing ``.X`` (reconstruction targets) and
            ``.obsm['sp_he']`` (H&E features); presumably AnnData.
        models: Module container that supports ``.to`` and unpacks to
            ``(model_1, mlp_1, mlp_2, reg_1, reg_2, anno_emb)``.
        agg_mtx2: Row-indexable scipy-style sparse matrix (it is row-sliced and
            converted with ``.tocoo()``) whose column indices select rows of
            ``adata2``. Presumably a spot-by-cell aggregation matrix.
        anno1: Converted to a float tensor and passed, with ``anno_emb``, to
            ``sw.utils.mean_annotation_embedding``.
        anno2: Integer labels indexed by the same row indices as ``adata2``
            and fed to ``anno_emb``.

    On rank 0 the six state dicts are written to
    ``f"{args.output_folder}model/{args.save_tag}/models.pt"``. The directory
    is created if it does not exist.
    """
    import torch.distributed as dist

    def _average_gradients(module):
        # Sum every gradient over all ranks, then divide, so each rank applies
        # the same globally averaged update and the replicas stay identical.
        if world_size <= 1 or not dist.is_initialized():
            return
        for param in module.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
                param.grad.div_(world_size)

    sw.utils.setup(rank, world_size)
    try:
        models = models.to(rank)
        model_1, mlp_1, mlp_2, reg_1, reg_2, anno_emb = models
        # DDP broadcasts rank 0's initial weights. The forward passes below call
        # the sub-modules directly (bypassing DDP.forward), so DDP never reduces
        # gradients itself; _average_gradients does that after backward().
        models = torch.nn.parallel.DistributedDataParallel(models, device_ids=[rank])
        optimizer = sw.utils.create_optimizer(args.optimizer, models, args.lr, args.weight_decay)

        # NOTE(review): the other trainers read obsm['he_sp']; confirm that
        # 'sp_he' is the key the cross-resolution preprocessing writes.
        he1, he2 = torch.Tensor(adata1.obsm['sp_he']), torch.Tensor(adata2.obsm['sp_he'])
        panelA1, panelB2 = torch.Tensor(adata1.X), torch.Tensor(adata2.X)
        anno1, anno2 = torch.FloatTensor(anno1).to(rank), torch.LongTensor(anno2).to(rank)

        # Dataset 2 is batched by row index of agg_mtx2.
        dataset2 = TensorDataset(torch.arange(agg_mtx2.shape[0]))
        sampler2 = DistributedSampler(dataset2, num_replicas=world_size, rank=rank)
        dataloader2 = DataLoader(dataset2, batch_size=args.batch_size, sampler=sampler2, drop_last=False)
        # Dataset 1 is used in full at every step, so move it to the device once.
        he1_all = he1.to(rank)
        panelA1_all = panelA1.to(rank)

        rbf = sw.model.RBF().to(rank)
        MMD = sw.model.MMD_loss(kernel=rbf).to(rank)

        if rank == 0:
            print('================================ Train ================================\n')
        for epoch in range(args.epoch):
            sampler2.set_epoch(epoch)
            batch_iter = tqdm(
                dataloader2,
                desc=f'🚀 Epoch {epoch + 1}/{args.epoch}',
                leave=False,
                disable=(rank != 0),
            )
            for (spot_idx2,) in batch_iter:
                spot_rows = agg_mtx2[spot_idx2.numpy()]
                # Columns touched by this batch's rows, deduplicated. An index that
                # appears in several rows would otherwise duplicate its column in
                # the batch aggregation matrix and its row in he2_map, which
                # double-counts it in the aggregated MMD and in loss2.
                cell_idx2 = np.unique(spot_rows.tocoo().col)
                agg_mtx2_batch = sw.utils.sparse_mx_to_torch_sparse_tensor(spot_rows[:, cell_idx2], device=rank)
                panelB2_batch, he2_batch = panelB2[cell_idx2].to(rank), he2[cell_idx2].to(rank)

                anno_emb1 = sw.utils.mean_annotation_embedding(anno1, anno_emb)
                anno_emb2 = anno_emb(anno2[cell_idx2])
                he1_map = model_1(mlp_1(he1_all))
                he2_map = model_1(mlp_2(he2_batch))
                emb1 = torch.concat([he1_map, anno_emb1], dim=1)
                emb2 = torch.concat([he2_map, anno_emb2], dim=1)
                rec_omics1 = reg_1(emb1)
                rec_omics2 = reg_2(emb2)

                # Compare dataset 1's encodings with dataset 2's encodings
                # aggregated through the batch's rows of agg_mtx2.
                mmd = MMD(he1_map, agg_mtx2_batch @ he2_map)
                loss1 = F.mse_loss(rec_omics1, panelA1_all)
                loss2 = F.mse_loss(rec_omics2, panelB2_batch)
                loss = loss1 + loss2 + args.mmd_weight * mmd

                optimizer.zero_grad()
                loss.backward()
                _average_gradients(models)
                optimizer.step()

                if rank == 0:
                    batch_iter.set_postfix({
                        'loss1': f'{loss1.item():.4f}',
                        'loss2': f'{loss2.item():.4f}',
                        'mmd': f'{mmd.item():.4f}'
                    })
    finally:
        # Tear down the process group even when training raises.
        sw.utils.cleanup()

    if rank == 0:
        print("💾 Saving checkpoint...")
        ckpt_dir = f'{args.output_folder}model/{args.save_tag}/'
        # Create the folder up front so a finished run is not lost at save time.
        os.makedirs(ckpt_dir, exist_ok=True)
        ckpt_path = f'{ckpt_dir}models.pt'
        # getattr(m, "module", m) unwraps a wrapper exposing `.module`, if any.
        torch.save({
            'model_1': getattr(model_1, "module", model_1).state_dict(),
            'mlp_1': getattr(mlp_1, "module", mlp_1).state_dict(),
            'mlp_2': getattr(mlp_2, "module", mlp_2).state_dict(),
            'reg_1': getattr(reg_1, "module", reg_1).state_dict(),
            'reg_2': getattr(reg_2, "module", reg_2).state_dict(),
            'anno_emb': getattr(anno_emb, "module", anno_emb).state_dict(),
        }, ckpt_path)
        print(f"✅ Checkpoint saved to {ckpt_path}.")
SpaWeaver/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ """SpaWeaver package."""
2
+
3
# Package version string.
# NOTE(review): the wheel this file ships in is labelled 0.0.1, but the package
# reports "0.1.0" here; confirm which value is intended and keep them in sync.
__version__ = "0.1.0"

# Re-export the model classes so they are reachable as SpaWeaver.RBF / SpaWeaver.transformerModel.
from .model import RBF, transformerModel

# Public API for `from SpaWeaver import *`.
# NOTE(review): the training code in this wheel calls `sw.utils.*`, but this file
# does not import `.utils`; that attribute only exists if another module imports
# SpaWeaver.utils first. Confirm.
__all__ = ["RBF", "transformerModel"]
SpaWeaver/args.py ADDED
@@ -0,0 +1,48 @@
1
+ import argparse
2
+
3
+
4
def _str2bool(value):
    """Convert a command-line string such as ``true``/``false``/``1``/``0`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value, including ``"False"``, becomes ``True``. This
    converter parses the text instead.

    Raises:
        argparse.ArgumentTypeError: For text that is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_args(argv=None):
    """Build the command-line parser for the SpaWeaver scripts and parse options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]``. Pass an explicit list (e.g. ``[]``) when calling
            from a notebook or from tests.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    parser = argparse.ArgumentParser(description="Model")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device", type=int, default=0)

    parser.add_argument("--optimizer", type=str, default="adam")
    parser.add_argument("--load_model", action="store_true")

    # graph transformer
    parser.add_argument('--hops', type=int, default=3, help='Hop of neighbors to be calculated')
    parser.add_argument('--pe_dim', type=int, default=128, help='position embedding size')
    parser.add_argument('--hidden_dim', type=int, default=128, help='Hidden layer size')
    parser.add_argument('--n_layers', type=int, default=1, help='Number of Transformer layers')
    parser.add_argument('--n_heads', type=int, default=2, help='Number of Transformer heads')
    parser.add_argument('--dropout', type=float, default=0.1, help='Dropout')
    parser.add_argument('--attention_dropout', type=float, default=0.1, help='Dropout in the attention layer')
    parser.add_argument("--activation", type=str, default="elu")

    # adjustable parameters
    parser.add_argument("--epoch", type=int, default=500, help="number of training epochs")
    parser.add_argument("--batch_size", type=int, default=4096)
    parser.add_argument("--loss_fn", type=str, default="mse")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    parser.add_argument("--weight_decay", type=float, default=0, help="weight decay")

    # File parameter
    parser.add_argument("--sample_name1", type=str, default="Human_Breast_Cancer_Rep1")
    parser.add_argument("--root_path1", type=str, default='./datasets/Human_Breast_Cancer_Rep1/')
    parser.add_argument("--sample_name2", type=str, default="Human_Breast_Cancer_Rep2")
    parser.add_argument("--root_path2", type=str, default='./datasets/Human_Breast_Cancer_Rep2/')
    # type=bool would turn "--save False" into True, so parse the text explicitly.
    parser.add_argument("--save", type=_str2bool, default=True, help="whether to save outputs (true/false)")
    parser.add_argument("--save_tag", type=str, default='fig2')
    parser.add_argument("--output_folder", type=str, default="./outputs/")

    parser.add_argument("--image_encoder", type=str, default="uni")
    parser.add_argument("--img_batch_size", type=int, default=64)
    parser.add_argument("--num_neighbors", type=int, default=7)
    parser.add_argument("--scale", type=float, default=0.363788)
    parser.add_argument("--cell_diameter", type=float, default=-1, help="By physical size (um)")
    parser.add_argument("--resolution", type=float, default=64, help="By pixels")
    parser.add_argument("--mmd_weight", type=float, default=0.01, help="weight for MMD")

    # read parameters
    return parser.parse_args(argv)