gsMap3D 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74) hide show
  1. gsMap/__init__.py +13 -0
  2. gsMap/__main__.py +4 -0
  3. gsMap/cauchy_combination_test.py +342 -0
  4. gsMap/cli.py +355 -0
  5. gsMap/config/__init__.py +72 -0
  6. gsMap/config/base.py +296 -0
  7. gsMap/config/cauchy_config.py +79 -0
  8. gsMap/config/dataclasses.py +235 -0
  9. gsMap/config/decorators.py +302 -0
  10. gsMap/config/find_latent_config.py +276 -0
  11. gsMap/config/format_sumstats_config.py +54 -0
  12. gsMap/config/latent2gene_config.py +461 -0
  13. gsMap/config/ldscore_config.py +261 -0
  14. gsMap/config/quick_mode_config.py +242 -0
  15. gsMap/config/report_config.py +81 -0
  16. gsMap/config/spatial_ldsc_config.py +334 -0
  17. gsMap/config/utils.py +286 -0
  18. gsMap/find_latent/__init__.py +3 -0
  19. gsMap/find_latent/find_latent_representation.py +312 -0
  20. gsMap/find_latent/gnn/distribution.py +498 -0
  21. gsMap/find_latent/gnn/encoder_decoder.py +186 -0
  22. gsMap/find_latent/gnn/gcn.py +85 -0
  23. gsMap/find_latent/gnn/gene_former.py +164 -0
  24. gsMap/find_latent/gnn/loss.py +18 -0
  25. gsMap/find_latent/gnn/st_model.py +125 -0
  26. gsMap/find_latent/gnn/train_step.py +177 -0
  27. gsMap/find_latent/st_process.py +781 -0
  28. gsMap/format_sumstats.py +446 -0
  29. gsMap/generate_ldscore.py +1018 -0
  30. gsMap/latent2gene/__init__.py +18 -0
  31. gsMap/latent2gene/connectivity.py +781 -0
  32. gsMap/latent2gene/entry_point.py +141 -0
  33. gsMap/latent2gene/marker_scores.py +1265 -0
  34. gsMap/latent2gene/memmap_io.py +766 -0
  35. gsMap/latent2gene/rank_calculator.py +590 -0
  36. gsMap/latent2gene/row_ordering.py +182 -0
  37. gsMap/latent2gene/row_ordering_jax.py +159 -0
  38. gsMap/ldscore/__init__.py +1 -0
  39. gsMap/ldscore/batch_construction.py +163 -0
  40. gsMap/ldscore/compute.py +126 -0
  41. gsMap/ldscore/constants.py +70 -0
  42. gsMap/ldscore/io.py +262 -0
  43. gsMap/ldscore/mapping.py +262 -0
  44. gsMap/ldscore/pipeline.py +615 -0
  45. gsMap/pipeline/quick_mode.py +134 -0
  46. gsMap/report/__init__.py +2 -0
  47. gsMap/report/diagnosis.py +375 -0
  48. gsMap/report/report.py +100 -0
  49. gsMap/report/report_data.py +1832 -0
  50. gsMap/report/static/js_lib/alpine.min.js +5 -0
  51. gsMap/report/static/js_lib/tailwindcss.js +83 -0
  52. gsMap/report/static/template.html +2242 -0
  53. gsMap/report/three_d_combine.py +312 -0
  54. gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
  55. gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
  56. gsMap/report/three_d_plot/three_d_plots.py +425 -0
  57. gsMap/report/visualize.py +1409 -0
  58. gsMap/setup.py +5 -0
  59. gsMap/spatial_ldsc/__init__.py +0 -0
  60. gsMap/spatial_ldsc/io.py +656 -0
  61. gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
  62. gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
  63. gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
  64. gsMap/utils/__init__.py +0 -0
  65. gsMap/utils/generate_r2_matrix.py +610 -0
  66. gsMap/utils/jackknife.py +518 -0
  67. gsMap/utils/manhattan_plot.py +643 -0
  68. gsMap/utils/regression_read.py +177 -0
  69. gsMap/utils/torch_utils.py +23 -0
  70. gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
  71. gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
  72. gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
  73. gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
  74. gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,85 @@
1
+ import numpy as np
2
+ import torch
3
+ from scipy.spatial import cKDTree
4
+ from torch_geometric.nn.conv import MessagePassing
5
+ from torch_geometric.utils import add_remaining_self_loops, degree
6
+
7
+
8
def build_spatial_graph(
    coords: np.ndarray,
    n_neighbors: int,
    undirected: bool = True
) -> np.ndarray:
    """Build a k-nearest-neighbour edge list from spatial coordinates.

    Every cell is linked to its ``n_neighbors`` nearest cells (Euclidean
    distance). The k-NN query returns the query point itself, so each node
    also gets a self-edge ``(i, i)``.

    Parameters:
    -----------
    coords : np.ndarray
        Spatial coordinates of shape (n_cells, n_dims)
    n_neighbors : int
        Number of nearest neighbors, counting the cell itself. Values larger
        than the number of cells are clipped to the number of cells.
    undirected : bool, default=True
        If True, each unordered pair is returned once as ``(min, max)`` with
        duplicates removed, rows sorted lexicographically. If False, the raw
        directed edges ``(cell, neighbor)`` are returned in query order.

    Returns:
    --------
    edge_array : np.ndarray
        int32 edge array of shape (n_edges, 2)
    """
    coords = np.ascontiguousarray(coords, dtype=np.float32)
    n_nodes = coords.shape[0]
    if n_nodes == 0:
        return np.empty((0, 2), dtype=np.int32)

    # When k exceeds the number of points, cKDTree pads the result with the
    # out-of-range index n_nodes; clip k so every returned index is valid.
    k = min(n_neighbors, n_nodes)

    # Query k-NN
    tree = cKDTree(coords, balanced_tree=True, compact_nodes=True)
    _, indices = tree.query(coords, k=k, workers=-1)

    source = np.repeat(np.arange(n_nodes), k)
    target = np.asarray(indices).reshape(-1)
    edges = np.column_stack([source, target])

    if undirected:
        # Canonicalise every pair as (min, max) and drop duplicates in a
        # vectorised way (a per-edge Python set is prohibitively slow for
        # large n_cells * n_neighbors).
        edges = np.unique(np.sort(edges, axis=1), axis=0)

    return edges.astype(np.int32)
57
+
58
class GCN(MessagePassing):
    """Parameter-free K-hop propagation over an unweighted graph.

    Before propagating, every node is guaranteed a self-loop. Each hop then
    sums neighbour features scaled by ``1 / sqrt(deg_i * deg_j)`` ("add"
    aggregation). The result concatenates hops ``1..K`` along dim 1; the
    un-propagated input itself is not included.

    Args:
        K (int): Number of propagation hops.
    """

    def __init__(self, K=1):
        super().__init__(aggr="add")
        self.K = K

    def forward(self, x, edge_index):
        """Propagate ``x`` (nodes along dim 0) ``K`` times over ``edge_index``.

        Assumes 2-D node features (n_nodes, n_features) — TODO confirm — in
        which case the result has ``K * n_features`` columns.
        """
        num_nodes = x.size(0)
        loop_index, _ = add_remaining_self_loops(edge_index, num_nodes=num_nodes)

        src, dst = loop_index
        # NOTE(review): degree is counted over the source row; this equals the
        # in-degree only when the edge set is symmetric.
        deg = degree(src, num_nodes, dtype=x.dtype)
        edge_norm = (deg[src] * deg[dst]).pow(-0.5)
        # Zero-degree pairs would give inf; neutralise them.
        edge_norm = edge_norm.masked_fill(edge_norm == float("inf"), 0)

        hops = []
        h = x
        for _ in range(self.K):
            h = self.propagate(loop_index, x=h, norm=edge_norm)
            hops.append(h)
        return torch.cat(hops, dim=1)

    def message(self, x_j, norm):
        # Scale each neighbour message by its edge's normalisation coefficient.
        return x_j * norm.unsqueeze(-1)
@@ -0,0 +1,164 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from einops import repeat
5
+ from torch import nn
6
+
7
+
8
class Linear2D(nn.Module):
    """A bank of ``n_modules`` independent linear maps stored as one 3-D tensor.

    For input ``x`` of shape ``(batch, input_dim)`` the output has shape
    ``(batch, hidden_dim, n_modules)``; module ``k`` uses the weight slice
    ``weights[:, :, k]``.

    Args:
        input_dim (int): Size of the last dimension of the input.
        hidden_dim (int): Output size of each linear map.
        n_modules (int): Number of independent linear maps.
        bias (bool, optional): Whether to add a learnable bias of shape
            ``(1, hidden_dim, n_modules)``. Defaults to False.
    """

    def __init__(self, input_dim, hidden_dim, n_modules, bias=False):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n_modules = n_modules

        weight_init = torch.randn(input_dim, hidden_dim, n_modules)
        self.weights = nn.Parameter(nn.init.xavier_normal_(weight_init))

        if bias:
            bias_init = torch.randn(1, hidden_dim, n_modules)
            self.bias = nn.Parameter(nn.init.xavier_normal_(bias_init))
        else:
            self.bias = None

    def forward(self, x):
        out = torch.einsum("bi,ijk->bjk", x, self.weights)
        return out if self.bias is None else out + self.bias
43
+
44
+
45
class GeneModuler(nn.Module):
    """Project a gene-expression vector onto a set of gene-module tokens.

    The input is layer-normalised over the gene dimension, mapped by a
    ``Linear2D`` into ``n_modules`` tokens of size ``hidden_dim``, and passed
    through a ReLU.

    Attributes:
        input_dim (int): Number of input genes (last dimension of ``x``).
        hidden_dim (int): Size of each module token.
        n_modules (int): Number of module tokens produced.
        layernorm (nn.LayerNorm): Normalisation applied to the input.
        extractor (Linear2D): Projection from genes to module tokens.
    """

    def __init__(self,
                 input_dim=2000,
                 hidden_dim=8,
                 n_modules=16):

        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n_modules = n_modules

        self.layernorm = nn.LayerNorm(input_dim)
        self.extractor = Linear2D(
            input_dim=input_dim, hidden_dim=hidden_dim, n_modules=n_modules
        )

    def forward(self, x, batch=None):
        """Return module tokens of shape ``(batch_size, n_modules, hidden_dim)``.

        Args:
            x: Expression matrix of shape ``(batch_size, input_dim)``.
            batch: Accepted for backward compatibility and ignored.
                ``nn.LayerNorm`` takes a single input, so forwarding ``batch``
                to it raised a TypeError.
        """
        module = self.layernorm(x)
        # Project the *normalised* tensor; the LayerNorm output must not be
        # discarded in favour of the raw input.
        module = self.extractor(module).transpose(2, 1)
        return F.relu(module)
78
+
79
+
80
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to a scaled token sequence.

    Args:
        d_model (int): Token embedding size; must match the last dimension of
            the input. Odd sizes are supported.
        max_len (int): Number of positions for which the encoding table is
            precomputed (positions ``0 .. max_len - 1``).
    """
    def __init__(self,
                 d_model,
                 max_len=500):

        super().__init__()

        self.d_model = d_model
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).float().unsqueeze(1)
        angular_speed = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * angular_speed)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * angular_speed[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x * sqrt(d_model) + PE[0:L]`` for ``x`` of shape ``(N, L, d_model)``.

        Raises:
            ValueError: If the sequence length ``L`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_len={max_len} of the positional encoding."
            )
        scaled_x = x * np.sqrt(self.d_model)
        # Slice positions 0..L-1 so each token gets its own encoding (scalar
        # indexing would add the single encoding of position L to every token).
        return scaled_x + self.pe[:, :seq_len, :]
107
+
108
+
109
class GeneModuleFormer(nn.Module):
    """Transformer encoder over gene-module tokens that returns a CLS embedding.

    ``GeneModuler`` turns the expression vector into ``n_modules`` tokens of
    size ``module_dim``. Positional encodings are then added and a learnable
    CLS token is prepended. The tokens are projected to ``hidden_dim`` (only
    when it differs from ``module_dim``) and run through a
    ``TransformerEncoder``. The encoder output at the CLS position is returned.

    Attributes:
        input_dim (int): Number of input features (genes) per cell.
        module_dim (int): Size of each gene-module token.
        hidden_dim (int): Transformer width (``d_model``); PyTorch requires it
            to be divisible by ``nhead``.
        n_modules (int): Number of gene-module tokens (sequence length
            before the CLS token is prepended).
        nhead (int): Number of attention heads per encoder layer.
        n_enc_layer (int): Number of stacked encoder layers.
    """

    def __init__(
        self,
        input_dim=2000,
        module_dim=30,
        hidden_dim=256,
        n_modules=16,
        nhead=8,
        n_enc_layer=3,
    ):
        super().__init__()

        self.moduler = GeneModuler(
            input_dim=input_dim, hidden_dim=module_dim, n_modules=n_modules
        )

        # Only add a projection when the token size differs from the model width.
        if module_dim == hidden_dim:
            self.expand = nn.Identity()
        else:
            self.expand = nn.Linear(module_dim, hidden_dim)

        layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=4 * hidden_dim,
            batch_first=True,
        )
        self.module = nn.TransformerEncoder(encoder_layer=layer, num_layers=n_enc_layer)

        self.pe = PositionalEncoding(d_model=module_dim)

        self.cls_token = nn.Parameter(torch.randn(1, 1, module_dim))

    def forward(self, x):
        tokens = self.moduler(x)
        n_cells, _, _ = tokens.shape
        tokens = self.pe(tokens)
        cls = self.cls_token.expand(n_cells, -1, -1)
        tokens = self.expand(torch.cat([cls, tokens], dim=1))
        encoded = self.module(tokens)
        return encoded[:, 0, :]
@@ -0,0 +1,18 @@
1
+ ## Define the loss function
2
+ import torch.nn.functional as F
3
+
4
+ from .distribution import NegativeBinomial, ZeroInflatedNegativeBinomial
5
+
6
+
7
def rec_loss(x_hat, x, logtheta, zi_logit, distribution):
    """Reconstruction loss between the prediction ``x_hat`` and the target ``x``.

    For ``'nb'`` / ``'zinb'`` this is the negative log-likelihood of ``x``
    under a (zero-inflated) negative binomial with mean ``x_hat`` and
    dispersion ``exp(logtheta)``. It is summed over the last dimension and
    averaged over the rest. ``zi_logit`` is only used for ``'zinb'``. Any other
    ``distribution`` value falls back to mean squared error.
    """
    if distribution not in ('nb', 'zinb'):
        return F.mse_loss(x_hat, x)

    theta = logtheta.exp()
    if distribution == 'nb':
        likelihood = NegativeBinomial(mu=x_hat, theta=theta)
    else:
        likelihood = ZeroInflatedNegativeBinomial(mu=x_hat, theta=theta, zi_logits=zi_logit)
    return -likelihood.log_prob(x).sum(-1).mean()
16
+
17
def ce_loss(pred_label, true_label):
    """Mean cross-entropy between class logits ``pred_label`` and targets ``true_label``."""
    loss = F.cross_entropy(input=pred_label, target=true_label)
    return loss
@@ -0,0 +1,125 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+
5
+ from .encoder_decoder import Decoder, Encoder
6
+
7
+
8
class StEmbeding(nn.Module):
    """Encode each input view of a cell into a latent embedding, conditioned on batch.

    One ``Encoder`` is built per entry of ``input_size``. Each latent is decoded
    by the shared ``'reconstruction'`` decoder, and the concatenated latents
    are passed to the ``'classification'`` decoder. Batch ids are represented
    either with a learned ``nn.Embedding`` or as one-hot vectors.

    Args:
        input_size (list): Input feature size for each encoder (one encoder per entry).
        hidden_size (int): Hidden layer size passed to Encoder/Decoder.
        embedding_size (int): Latent embedding size.
        batch_embedding_size (int): Size of the learned batch embedding; only
            used when ``batch_representation == 'embedding'``.
        out_put_size (int): Number of output genes (width of ``logtheta``).
        batch_size (int): Number of distinct batches (not the sample count).
        class_size (int): Number of classes for the classification decoder.
        distribution (str): Reconstruction distribution. For 'nb' and 'zinb'
            reconstructions are scaled by each cell's library size
            (sum of the first input).
        module_dim, hidden_gmf, n_modules, nhead, n_enc_layer, use_tf, variational:
            Forwarded positionally to every ``Encoder``.
        batch_representation (str): 'embedding' selects ``nn.Embedding``; any
            other value selects one-hot encoding.
        dispersion (str): Currently unused.
    """
    def __init__(self,
                 input_size,
                 hidden_size,
                 embedding_size,
                 batch_embedding_size,
                 out_put_size,
                 batch_size,
                 class_size,
                 distribution,
                 module_dim,
                 hidden_gmf,
                 n_modules,
                 nhead,
                 n_enc_layer,
                 use_tf=True,
                 variational=True,
                 batch_representation='embedding',
                 dispersion='gene'):
        super().__init__()

        self.input_size = input_size
        self.z_num = len(input_size)
        self.distribution = distribution
        self.batch_representation = batch_representation
        self.num_batches = batch_size

        # Log dispersion with one row per batch and one column per output gene.
        self.logtheta = nn.Parameter(torch.randn(batch_size, out_put_size))

        uses_embedding = batch_representation == 'embedding'
        if uses_embedding:
            self.batch_embedding = nn.Embedding(batch_size, batch_embedding_size)
        # One-hot vectors have one entry per batch.
        self.batch_embedding_size = batch_embedding_size if uses_embedding else batch_size

        # One encoder per input view.
        self.encoder = nn.ModuleList(
            Encoder(self.input_size[view],
                    hidden_size,
                    embedding_size,
                    self.batch_embedding_size,
                    module_dim,
                    hidden_gmf,
                    n_modules,
                    nhead,
                    n_enc_layer,
                    use_tf,
                    variational)
            for view in range(self.z_num)
        )

        # Decoders keyed by task; construction order is reconstruction first.
        self.decoder = nn.ModuleDict({
            task: Decoder(out_put_size,
                          hidden_size,
                          embedding_size,
                          self.batch_embedding_size,
                          class_size,
                          task,
                          self.distribution)
            for task in ('reconstruction', 'classification')
        })

    def _process_batch(self, batch):
        """Map integer batch ids to their embedding or one-hot representation."""
        if self.batch_representation != 'embedding':
            return F.one_hot(batch, num_classes=self.num_batches).float()
        return self.batch_embedding(batch)

    def forward(self, x_list, batch):
        """Encode every view, reconstruct it, and classify from all latents.

        Returns:
            tuple: ``(x_rec_list, x_class, zi_logit_list, z_list)`` with one
            reconstruction, zero-inflation logit and latent per encoder.
        """
        batch = self._process_batch(batch)

        reference = x_list[0]
        if self.distribution in ['nb', 'zinb']:
            library_size = reference.sum(-1, keepdim=True)
        else:
            library_size = torch.ones(reference.shape[0], 1, device=reference.device)

        x_rec_list, zi_logit_list, z_list = [], [], []
        # Encode and decode view by view, keeping the per-view call order.
        for view, encoder in enumerate(self.encoder):
            z = encoder(x_list[view], batch)
            x_rec, zi_logit = self.decoder['reconstruction'](z, batch)
            x_rec_list.append(x_rec * library_size)
            zi_logit_list.append(zi_logit)
            z_list.append(z)

        x_class = self._classification(z_list, batch)
        return x_rec_list, x_class, zi_logit_list, z_list

    def _classification(self, z_list, batch):
        """Classify from the latents concatenated along dim 1."""
        return self.decoder['classification'](torch.cat(z_list, dim=1), batch)

    def encode(self, x_list, batch):
        """Return the list of latents, one per encoder."""
        batch = self._process_batch(batch)
        return [encoder(x_list[view], batch) for view, encoder in enumerate(self.encoder)]
@@ -0,0 +1,177 @@
1
+ import numpy as np
2
+ import torch
3
+ from tqdm import tqdm
4
+
5
+ from .loss import ce_loss, rec_loss
6
+
7
+
8
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    The tracked score is ``-val_loss``. A new score counts as an improvement
    only if it is at least ``best_score + delta``. After ``patience``
    consecutive non-improving calls, ``early_stop`` is set to True. A NaN loss
    always counts as non-improving.

    Args:
        patience (int): Number of consecutive non-improving calls tolerated.
        delta (float): Minimum score increase required to count as improvement.
        path: Stored on the instance but not used by this class.
    """

    def __init__(self, patience=7, delta=0, path=None):

        self.patience = patience
        self.counter = 0
        self.best_score = -np.inf
        self.early_stop = False
        self.delta = delta
        self.path = path

    def _record_no_improvement(self):
        """Count one non-improving call and trip ``early_stop`` at ``patience``."""
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

    def __call__(self, val_loss):
        score = -val_loss

        if np.isnan(score):
            # NaN compares False with everything; without this guard it would
            # be taken as an improvement, poison best_score and disable
            # early stopping for the rest of training.
            self._record_no_improvement()
        elif self.best_score == -np.inf:
            self.best_score = score
        elif score < self.best_score + self.delta:
            self._record_no_improvement()
        else:
            self.best_score = score
            self.counter = 0
            self.early_stop = False
34
+
35
+
36
class ModelTrain:
    """Train a model with early stopping and best-checkpoint saving.

    The model is moved to CUDA when available, otherwise to CPU. In
    ``'reconstruction'`` mode the loss is the sum over encoders of
    reconstruction loss plus the encoder's KL term. In ``'classification'``
    mode it is cross-entropy on the class predictions.

    Args:
        model: Module called as ``model([x, x_gcn], batch_ids)`` that returns
            ``(x_hat_list, x_class, zi_logit_list, z_list)``. It must expose
            ``logtheta`` (indexed by batch id), ``encoder`` (indexable; each
            item has ``kl_loss()``) and ``decoder`` (mapping name -> module).
        optimizer: Optimizer stepped after every training mini-batch.
        distribution: Reconstruction likelihood passed to ``rec_loss``.
        mode: ``'reconstruction'`` or ``'classification'``.
        lr: If not None, overrides the learning rate of every param group.
        model_path: File path the best ``state_dict`` is saved to.
    """

    def __init__(self,
                 model,
                 optimizer,
                 distribution,
                 mode,
                 lr,
                 model_path):

        self.model = model
        self.optimizer = optimizer
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model.to(self.device)
        self.distribution = distribution
        self.mode = mode
        self.lr = lr
        self.model_path = model_path

        self.train_loader = None
        self.val_loader = None
        # Not populated by this class.
        self.losses = []
        self.val_losses = []

        self.train_step_fn = self._make_train_step_fn()
        self.val_step_fn = self._make_val_step_fn()

        # Optionally override the optimizer's learning rate for this stage.
        if self.lr is not None:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.lr

    def set_loaders(self, train_loader, val_loader=None):
        """Set the loaders; each yields ``(x_gcn, batch_ids, x, labels)`` tuples."""
        self.train_loader = train_loader
        self.val_loader = val_loader

    def compute_loss(self, x_hat, x, log_theata, zi_logit, x_class, labels):
        """Compute the loss for the current ``mode``.

        Args:
            x_hat: Sequence of reconstructions, one per encoder.
            x: Target expression used for every reconstruction.
            log_theata: Per-sample log dispersion rows.
            zi_logit: Sequence of zero-inflation logits, one per encoder.
            x_class: Class predictions.
            labels: Class targets.

        Raises:
            ValueError: If ``self.mode`` is not a supported mode.
        """
        if self.mode == 'reconstruction':
            loss = 0
            for idx, x_rec in enumerate(x_hat):
                loss_rec = rec_loss(x_rec, x, log_theata, zi_logit[idx], self.distribution)
                loss_kld = self.model.encoder[idx].kl_loss()
                loss = loss + loss_rec + loss_kld
            return loss

        if self.mode == 'classification':
            return ce_loss(x_class, labels)

        raise ValueError(
            f"Unknown training mode {self.mode!r}; expected 'reconstruction' or 'classification'."
        )

    def _make_train_step_fn(self):
        """Build the closure that runs one optimisation step and returns the loss."""
        def perform_train_step_fn(x_gcn, ST_batches, x, labels):

            self.model.train()

            x_hat, x_class, zi_logit, _ = self.model([x, x_gcn], ST_batches)
            log_theata = self.model.logtheta[ST_batches]
            loss = self.compute_loss(x_hat, x, log_theata, zi_logit, x_class, labels)

            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

            return loss.item()

        return perform_train_step_fn

    def _make_val_step_fn(self):
        """Build the closure that evaluates one batch in eval mode and returns the loss."""
        def perform_val_step_fn(x_gcn, ST_batches, x, labels):

            self.model.eval()

            x_hat, x_class, zi_logit, _ = self.model([x, x_gcn], ST_batches)
            log_theata = self.model.logtheta[ST_batches]
            loss = self.compute_loss(x_hat, x, log_theata, zi_logit, x_class, labels)

            return loss.item()

        return perform_val_step_fn

    def _mini_batch(self, epoch_idx, n_epochs, validation=False):
        """Run one pass over the train or validation loader and return the mean batch loss.

        ``epoch_idx`` and ``n_epochs`` are accepted but currently unused.
        """
        if validation:
            data_loader = self.val_loader
            step_fn = self.val_step_fn
        else:
            data_loader = self.train_loader
            step_fn = self.train_step_fn

        mini_batch_losses = []
        for x_gcn, ST_batches, x, labels in data_loader:
            x_gcn = x_gcn.to(self.device)
            ST_batches = ST_batches.long().to(self.device)
            x = x.to(self.device)
            labels = labels.to(self.device)

            mini_batch_losses.append(step_fn(x_gcn, ST_batches, x, labels))

        return np.mean(mini_batch_losses)

    def _set_requires_grad(self, module_group, mode):
        """Enable gradients only for the sub-module whose key equals ``mode``."""
        for name, param_group in module_group.items():
            requires_grad = (mode == name)
            for param in param_group.parameters():
                param.requires_grad = requires_grad

    def train(self, n_epochs, patience):
        """Train for up to ``n_epochs`` with early stopping on the validation loss.

        Only the decoder whose key equals ``self.mode`` keeps ``requires_grad``;
        the other decoder(s) are frozen. The model ``state_dict`` is saved to
        ``self.model_path`` whenever the validation loss beats the best so far.

        Raises:
            ValueError: If the train or validation loader has not been set.
        """
        if self.train_loader is None or self.val_loader is None:
            raise ValueError(
                "Both train_loader and val_loader must be set via set_loaders() before calling train()."
            )

        loss_track = EarlyStopping(patience)

        self._set_requires_grad(self.model.decoder, self.mode)

        pbar = tqdm(range(n_epochs), desc=f'LGCN train ({self.mode})', total=n_epochs)
        for epoch in pbar:

            train_loss = self._mini_batch(epoch, n_epochs, validation=False)

            with torch.no_grad():
                val_loss = self._mini_batch(epoch, n_epochs, validation=True)

            # Checkpoint before updating the tracker so the comparison is
            # against the previous best score.
            if loss_track.best_score < -val_loss:
                torch.save(self.model.state_dict(), self.model_path)

            loss_track(val_loss)
            if loss_track.early_stop:
                print(f'Stop training, as {self.mode} validation loss has not decreased for {patience} consecutive steps.')
                break

            pbar.set_postfix({'train loss': f'{float(train_loss):.4f}',
                              'validation loss': f'{float(val_loss):.4f}'})