gsMap3D 0.1.0a1 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/__init__.py +13 -0
- gsMap/__main__.py +4 -0
- gsMap/cauchy_combination_test.py +342 -0
- gsMap/cli.py +355 -0
- gsMap/config/__init__.py +72 -0
- gsMap/config/base.py +296 -0
- gsMap/config/cauchy_config.py +79 -0
- gsMap/config/dataclasses.py +235 -0
- gsMap/config/decorators.py +302 -0
- gsMap/config/find_latent_config.py +276 -0
- gsMap/config/format_sumstats_config.py +54 -0
- gsMap/config/latent2gene_config.py +461 -0
- gsMap/config/ldscore_config.py +261 -0
- gsMap/config/quick_mode_config.py +242 -0
- gsMap/config/report_config.py +81 -0
- gsMap/config/spatial_ldsc_config.py +334 -0
- gsMap/config/utils.py +286 -0
- gsMap/find_latent/__init__.py +3 -0
- gsMap/find_latent/find_latent_representation.py +312 -0
- gsMap/find_latent/gnn/distribution.py +498 -0
- gsMap/find_latent/gnn/encoder_decoder.py +186 -0
- gsMap/find_latent/gnn/gcn.py +85 -0
- gsMap/find_latent/gnn/gene_former.py +164 -0
- gsMap/find_latent/gnn/loss.py +18 -0
- gsMap/find_latent/gnn/st_model.py +125 -0
- gsMap/find_latent/gnn/train_step.py +177 -0
- gsMap/find_latent/st_process.py +781 -0
- gsMap/format_sumstats.py +446 -0
- gsMap/generate_ldscore.py +1018 -0
- gsMap/latent2gene/__init__.py +18 -0
- gsMap/latent2gene/connectivity.py +781 -0
- gsMap/latent2gene/entry_point.py +141 -0
- gsMap/latent2gene/marker_scores.py +1265 -0
- gsMap/latent2gene/memmap_io.py +766 -0
- gsMap/latent2gene/rank_calculator.py +590 -0
- gsMap/latent2gene/row_ordering.py +182 -0
- gsMap/latent2gene/row_ordering_jax.py +159 -0
- gsMap/ldscore/__init__.py +1 -0
- gsMap/ldscore/batch_construction.py +163 -0
- gsMap/ldscore/compute.py +126 -0
- gsMap/ldscore/constants.py +70 -0
- gsMap/ldscore/io.py +262 -0
- gsMap/ldscore/mapping.py +262 -0
- gsMap/ldscore/pipeline.py +615 -0
- gsMap/pipeline/quick_mode.py +134 -0
- gsMap/report/__init__.py +2 -0
- gsMap/report/diagnosis.py +375 -0
- gsMap/report/report.py +100 -0
- gsMap/report/report_data.py +1832 -0
- gsMap/report/static/js_lib/alpine.min.js +5 -0
- gsMap/report/static/js_lib/tailwindcss.js +83 -0
- gsMap/report/static/template.html +2242 -0
- gsMap/report/three_d_combine.py +312 -0
- gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
- gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
- gsMap/report/three_d_plot/three_d_plots.py +425 -0
- gsMap/report/visualize.py +1409 -0
- gsMap/setup.py +5 -0
- gsMap/spatial_ldsc/__init__.py +0 -0
- gsMap/spatial_ldsc/io.py +656 -0
- gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
- gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
- gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
- gsMap/utils/__init__.py +0 -0
- gsMap/utils/generate_r2_matrix.py +610 -0
- gsMap/utils/jackknife.py +518 -0
- gsMap/utils/manhattan_plot.py +643 -0
- gsMap/utils/regression_read.py +177 -0
- gsMap/utils/torch_utils.py +23 -0
- gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
- gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
- gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
- gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
- gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0

gsMap/find_latent/gnn/gcn.py
@@ -0,0 +1,85 @@
import numpy as np
import torch
from scipy.spatial import cKDTree
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import add_remaining_self_loops, degree


def build_spatial_graph(
    coords: np.ndarray,
    n_neighbors: int,
    undirected: bool = True
) -> np.ndarray:
    """
    Build a k-nearest-neighbor spatial graph from cell coordinates.

    Parameters:
    -----------
    coords : np.ndarray
        Spatial coordinates of shape (n_cells, n_dims)
    n_neighbors : int
        Number of nearest neighbors
    undirected : bool, default=True
        Whether to make graph undirected

    Returns:
    --------
    edge_array : np.ndarray
        Edge array of shape (n_edges, 2)
    """

    coords = np.ascontiguousarray(coords, dtype=np.float32)

    # Query k-NN
    tree = cKDTree(coords, balanced_tree=True, compact_nodes=True)
    _, indices = tree.query(coords, k=n_neighbors, workers=-1)

    n_nodes = coords.shape[0]

    if undirected:
        # Create bidirectional edges
        source = np.repeat(np.arange(n_nodes), n_neighbors)
        target = indices.flatten()

        # Combine forward and reverse edges
        all_edges = np.column_stack([
            np.concatenate([source, target]),
            np.concatenate([target, source])
        ])

        # Remove duplicates using set
        edge_set = {tuple(sorted([i, j])) for i, j in all_edges}
        return np.array(list(edge_set), dtype=np.int32)
    else:
        # Directed graph - just flatten the indices
        source = np.repeat(np.arange(n_nodes), n_neighbors)
        target = indices.flatten()
        return np.column_stack([source, target]).astype(np.int32)


class GCN(MessagePassing):
    """
    GCN for unweighted graphs.
    """

    def __init__(self, K=1):
        super().__init__(aggr="add")
        self.K = K

    def forward(self, x, edge_index):
        # Add self-loops
        edge_index, _ = add_remaining_self_loops(edge_index, num_nodes=x.size(0))

        # Compute normalization: 1/sqrt(deg_i * deg_j)
        row, col = edge_index
        deg = degree(row, x.size(0), dtype=x.dtype)
        norm = (deg[row] * deg[col]).pow(-0.5)
        norm[norm == float("inf")] = 0

        # K-hop propagation
        xs = [x]
        for _ in range(self.K):
            xs.append(self.propagate(edge_index, x=xs[-1], norm=norm))

        return torch.cat(xs[1:], dim=1)

    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j
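Illustrative usage of the two pieces above (a standalone sketch, not part of the package): build a k-NN graph over random 3D coordinates, then apply two rounds of GCN smoothing to random per-cell features. It assumes build_spatial_graph and GCN are in scope and torch_geometric is installed; shapes follow the docstrings above.

import numpy as np
import torch

coords = np.random.rand(100, 3).astype(np.float32)       # 100 cells in 3D space
edges = build_spatial_graph(coords, n_neighbors=6)        # (n_edges, 2), undirected by default
edge_index = torch.as_tensor(edges.T, dtype=torch.long)   # PyG expects shape (2, n_edges)

x = torch.randn(100, 32)          # per-cell features
smoother = GCN(K=2)               # two propagation hops; hop outputs are concatenated
out = smoother(x, edge_index)
print(out.shape)                  # torch.Size([100, 64]) -> K * feature_dim columns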

gsMap/find_latent/gnn/gene_former.py
@@ -0,0 +1,164 @@
import numpy as np
import torch
import torch.nn.functional as F
from einops import repeat
from torch import nn


class Linear2D(nn.Module):
    """Linear2D consists of a linear layer with a 3D weight matrix.

    Args:
        input_dim (int): The input dimension of the Linear2D module.
        hidden_dim (int): The hidden dimension of the Linear2D module.
        n_modules (int): The number of modules of the Linear2D module.
        bias (bool, optional): Whether to use bias. Defaults to False.
    """

    def __init__(self,
                 input_dim,
                 hidden_dim,
                 n_modules,
                 bias=False):

        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n_modules = n_modules

        self.weights = torch.randn(input_dim, hidden_dim, n_modules)
        self.weights = nn.Parameter(
            nn.init.xavier_normal_(self.weights))
        self.bias = None
        if bias:
            self.bias = torch.randn(1, hidden_dim, n_modules)
            self.bias = nn.Parameter(
                nn.init.xavier_normal_(self.bias))

    def forward(self, x):
        # (batch, input_dim) x (input_dim, hidden_dim, n_modules) -> (batch, hidden_dim, n_modules)
        affine_out = torch.einsum("bi,ijk->bjk", [x, self.weights])
        if self.bias is not None:
            affine_out = affine_out + self.bias
        return affine_out


class GeneModuler(nn.Module):
    """GeneModuler takes gene expression as input and outputs gene modules.

    Attributes:
        input_dim (int): The input dimension of the GeneModuler model.
        hidden_dim (int): The hidden dimension of the GeneModuler model.
        n_modules (int): The number of modules of the GeneModuler model.
        layernorm (nn.LayerNorm): The layer normalization layer.
        extractor (Linear2D): The Linear2D object.
    """

    def __init__(self,
                 input_dim=2000,
                 hidden_dim=8,
                 n_modules=16):

        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n_modules = n_modules

        self.layernorm = nn.LayerNorm(input_dim)
        self.extractor = Linear2D(
            input_dim=input_dim, hidden_dim=hidden_dim, n_modules=n_modules
        )

    def forward(self, x, batch=None):
        # Normalize the expression profile, then fold it into per-module representations.
        module = self.layernorm(x)
        module = self.extractor(module).transpose(2, 1)
        return F.relu(module)


class PositionalEncoding(nn.Module):
    """
    Positional Encoding

    Attributes:
        d_model (int): The dimensionality of the model. This should match the dimension of the input embeddings.
        max_len (int): The maximum length of the sequence for which positional encoding is computed.
    """
    def __init__(self,
                 d_model,
                 max_len=500):

        super().__init__()

        self.d_model = d_model
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).float().unsqueeze(1)
        angular_speed = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * angular_speed)
        pe[:, 1::2] = torch.cos(position * angular_speed)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # x is (N, L, D); pe is (1, max_len, D)
        scaled_x = x * np.sqrt(self.d_model)
        encoded = scaled_x + self.pe[:, :x.size(1), :]
        return encoded


class GeneModuleFormer(nn.Module):
    """GeneModuleFormer is a gene expression model based on the Transformer architecture.

    Attributes:
        input_dim (int): The dimensionality of the input gene expression data.
        module_dim (int): The dimensionality of each module in the model.
        hidden_dim (int): The hidden layer dimension used within the model.
        n_modules (int): The number of gene modules extracted by the GeneModuler.
        nhead (int): The number of attention heads in each transformer layer.
        n_enc_layer (int): The number of encoding layers in the transformer model.
    """

    def __init__(
        self,
        input_dim=2000,
        module_dim=30,
        hidden_dim=256,
        n_modules=16,
        nhead=8,
        n_enc_layer=3,
    ):

        super().__init__()

        self.moduler = GeneModuler(
            input_dim=input_dim, hidden_dim=module_dim, n_modules=n_modules
        )

        self.expand = (
            nn.Linear(module_dim, hidden_dim)
            if module_dim != hidden_dim
            else nn.Identity()
        )

        self.module = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(d_model=hidden_dim,
                                                     nhead=nhead,
                                                     dim_feedforward=4 * hidden_dim,
                                                     batch_first=True),
            num_layers=n_enc_layer
        )

        self.pe = PositionalEncoding(d_model=module_dim)

        self.cls_token = nn.Parameter(torch.randn(1, 1, module_dim))

    def forward(self, x):
        auto_fold = self.moduler(x)
        b, _, _ = auto_fold.shape
        auto_fold = self.pe(auto_fold)
        cls_tokens = repeat(self.cls_token, "() n d -> b n d", b=b)
        auto_fold = torch.cat([cls_tokens, auto_fold], dim=1)
        auto_fold = self.expand(auto_fold)
        rep = self.module(auto_fold)
        cls_rep = rep[:, 0, :]
        return cls_rep
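A small shape-check sketch for GeneModuleFormer (illustrative only, assuming the classes above are in scope): a batch of expression vectors is folded into gene modules, positionally encoded, prefixed with a CLS token, and reduced to a single CLS representation per cell.

import torch

model = GeneModuleFormer(input_dim=2000, module_dim=30, hidden_dim=256,
                         n_modules=16, nhead=8, n_enc_layer=3)
x = torch.randn(4, 2000)      # 4 cells x 2000 genes (hypothetical, already normalized)
cls_rep = model(x)
print(cls_rep.shape)          # torch.Size([4, 256]) -> one hidden_dim vector per cell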

gsMap/find_latent/gnn/loss.py
@@ -0,0 +1,18 @@
## Define the loss function
import torch.nn.functional as F

from .distribution import NegativeBinomial, ZeroInflatedNegativeBinomial


def rec_loss(x_hat, x, logtheta, zi_logit, distribution):
    if distribution == 'nb':
        loss = -NegativeBinomial(mu=x_hat, theta=logtheta.exp()).log_prob(x).sum(-1).mean()
    elif distribution == 'zinb':
        loss = -ZeroInflatedNegativeBinomial(mu=x_hat, theta=logtheta.exp(), zi_logits=zi_logit).log_prob(x).sum(-1).mean()
    else:
        loss = F.mse_loss(x_hat, x)

    return loss


def ce_loss(pred_label, true_label):
    return F.cross_entropy(pred_label, true_label)
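A tiny sketch of how rec_loss is used (assuming the package is installed so that gsMap.find_latent.gnn.loss is importable; the 'nb'/'zinb' branches additionally rely on the sibling distribution module shown in the file list). Any other distribution string falls back to a plain mean-squared error between reconstruction and input.

import torch
from gsMap.find_latent.gnn.loss import rec_loss

x = torch.rand(8, 2000)          # observed expression
x_hat = torch.rand(8, 2000)      # decoder output
logtheta = torch.zeros(2000)     # per-gene log-dispersion (unused by the MSE branch)
zi_logit = torch.zeros(8, 2000)  # zero-inflation logits (unused by the MSE branch)

# Default branch: plain mean-squared error
loss = rec_loss(x_hat, x, logtheta, zi_logit, distribution='gaussian')
print(float(loss))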

gsMap/find_latent/gnn/st_model.py
@@ -0,0 +1,125 @@
import torch
import torch.nn as nn
from torch.nn import functional as F

from .encoder_decoder import Decoder, Encoder


class StEmbeding(nn.Module):
    """
    Learn graph-smoothed and expression embeddings for each cell, with optional batch correction.

    Args:
        input_size (list): List of input feature sizes for each encoder.
        hidden_size (int): Hidden layer size in encoder/decoder.
        embedding_size (int): Latent embedding size.
        batch_embedding_size (int): Size of batch embedding vector.
        out_put_size (int): Output gene size.
        batch_size (int): Number of batches (not sample count).
        class_size (int): Number of classes for classification.
        distribution (str): Output distribution type ('nb', 'zinb', 'gaussian', etc.).
        Other GNN-related args passed to Encoder.
    """
    def __init__(self,
                 input_size,
                 hidden_size,
                 embedding_size,
                 batch_embedding_size,
                 out_put_size,
                 batch_size,
                 class_size,
                 distribution,
                 module_dim,
                 hidden_gmf,
                 n_modules,
                 nhead,
                 n_enc_layer,
                 use_tf=True,
                 variational=True,
                 batch_representation='embedding',
                 dispersion='gene'):
        super().__init__()

        self.input_size = input_size
        self.z_num = len(self.input_size)
        self.distribution = distribution
        self.batch_representation = batch_representation
        self.num_batches = batch_size

        self.logtheta = nn.Parameter(torch.randn(batch_size, out_put_size))

        # Handle batch embedding
        if batch_representation == 'embedding':
            self.batch_embedding = nn.Embedding(batch_size, batch_embedding_size)
            self.batch_embedding_size = batch_embedding_size
        else:
            self.batch_embedding_size = batch_size  # one-hot case

        # Build encoders for each modality
        self.encoder = nn.ModuleList()
        for eid in range(self.z_num):
            self.encoder.append(
                Encoder(self.input_size[eid],
                        hidden_size,
                        embedding_size,
                        self.batch_embedding_size,
                        module_dim,
                        hidden_gmf,
                        n_modules,
                        nhead,
                        n_enc_layer,
                        use_tf,
                        variational)
            )

        # Build decoders for reconstruction and classification
        self.decoder = nn.ModuleDict()
        for decoder_type in ['reconstruction', 'classification']:
            self.decoder[decoder_type] = Decoder(out_put_size,
                                                 hidden_size,
                                                 embedding_size,
                                                 self.batch_embedding_size,
                                                 class_size,
                                                 decoder_type,
                                                 self.distribution)

    def _process_batch(self, batch):
        if self.batch_representation == 'embedding':
            return self.batch_embedding(batch)
        else:
            return F.one_hot(batch, num_classes=self.num_batches).float()

    def forward(self, x_list, batch):
        batch = self._process_batch(batch)

        if self.distribution in ['nb', 'zinb']:
            library_size = x_list[0].sum(-1, keepdim=True)
        else:
            n = x_list[0].shape[0]
            device = x_list[0].device
            library_size = torch.ones(n, 1, device=device)

        x_rec_list, zi_logit_list, z_list = [], [], []
        for eid in range(self.z_num):
            z = self.encoder[eid](x_list[eid], batch)
            x_rec, zi_logit = self.decoder['reconstruction'](z, batch)
            x_rec = x_rec * library_size

            x_rec_list.append(x_rec)
            zi_logit_list.append(zi_logit)
            z_list.append(z)

        x_class = self._classification(z_list, batch)
        return x_rec_list, x_class, zi_logit_list, z_list

    def _classification(self, z_list, batch):
        z = torch.cat(z_list, dim=1)
        return self.decoder['classification'](z, batch)

    def encode(self, x_list, batch):
        batch = self._process_batch(batch)
        z_list = []
        for eid in range(self.z_num):
            z = self.encoder[eid](x_list[eid], batch)
            z_list.append(z)
        return z_list
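A minimal illustration of the two batch-representation options consumed by StEmbeding._process_batch above (a standalone sketch, not package code): batch labels become either learned embedding vectors or fixed one-hot indicator vectors before being fed to the encoders and decoders.

import torch
import torch.nn as nn
import torch.nn.functional as F

num_batches, batch_embedding_size = 3, 8
batch = torch.tensor([0, 2, 1, 0])               # batch index per cell

# 'embedding': learned, low-dimensional batch vectors
emb = nn.Embedding(num_batches, batch_embedding_size)
print(emb(batch).shape)                          # torch.Size([4, 8])

# any other setting: fixed one-hot vectors, width = number of batches
print(F.one_hot(batch, num_classes=num_batches).float().shape)  # torch.Size([4, 3])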

gsMap/find_latent/gnn/train_step.py
@@ -0,0 +1,177 @@
import numpy as np
import torch
from tqdm import tqdm

from .loss import ce_loss, rec_loss


class EarlyStopping:
    """Stops training early if the validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, delta=0, path=None):
        self.patience = patience
        self.counter = 0
        self.best_score = -np.inf
        self.early_stop = False
        self.delta = delta
        self.path = path

    def __call__(self, val_loss):
        score = -val_loss

        if self.best_score == -np.inf:
            self.best_score = score
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.counter = 0
            self.early_stop = False


class ModelTrain:
    """Training loop wrapper for the ST embedding model (mode: 'reconstruction' or 'classification')."""

    def __init__(self,
                 model,
                 optimizer,
                 distribution,
                 mode,
                 lr,
                 model_path):
        self.model = model
        self.optimizer = optimizer
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model.to(self.device)
        self.distribution = distribution
        self.mode = mode
        self.lr = lr
        self.model_path = model_path

        self.train_loader = None
        self.val_loader = None
        self.losses = []
        self.val_losses = []

        self.train_step_fn = self._make_train_step_fn()
        self.val_step_fn = self._make_val_step_fn()

        if self.lr is not None:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.lr

    def set_loaders(self, train_loader, val_loader=None):
        self.train_loader = train_loader
        self.val_loader = val_loader

    def compute_loss(self, x_hat, x, log_theata, zi_logit, x_class, labels):
        if self.mode == 'reconstruction':
            loss = 0
            for id in range(len(x_hat)):
                loss_rec = rec_loss(x_hat[id], x, log_theata, zi_logit[id], self.distribution)
                loss_kld = self.model.encoder[id].kl_loss()
                loss = loss + loss_rec + loss_kld
        elif self.mode == 'classification':
            loss = ce_loss(x_class, labels)

        return loss

    def _make_train_step_fn(self):
        # Builds function that performs a step in the train loop
        def perform_train_step_fn(x_gcn, ST_batches, x, labels):
            self.model.train()

            x_hat, x_class, zi_logit, _ = self.model([x, x_gcn], ST_batches)
            log_theata = self.model.logtheta[ST_batches]
            loss = self.compute_loss(x_hat, x, log_theata, zi_logit, x_class, labels)

            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

            return loss.item()

        return perform_train_step_fn

    def _make_val_step_fn(self):
        # Builds function that performs a step in the validation loop
        def perform_val_step_fn(x_gcn, ST_batches, x, labels):
            self.model.eval()

            x_hat, x_class, zi_logit, _ = self.model([x, x_gcn], ST_batches)
            log_theata = self.model.logtheta[ST_batches]
            loss = self.compute_loss(x_hat, x, log_theata, zi_logit, x_class, labels)

            return loss.item()

        return perform_val_step_fn

    def _mini_batch(self, epoch_idx, n_epochs, validation=False):
        # The mini-batch can be used with both loaders
        if validation:
            data_loader = self.val_loader
            step_fn = self.val_step_fn
        else:
            data_loader = self.train_loader
            step_fn = self.train_step_fn

        # mini-batch loop
        mini_batch_losses = []

        for batch_idx, (x_gcn, ST_batches, x, labels) in enumerate(data_loader):
            x_gcn = x_gcn.to(self.device)
            ST_batches = ST_batches.long().to(self.device)
            x = x.to(self.device)
            labels = labels.to(self.device)

            mini_batch_loss = step_fn(x_gcn, ST_batches, x, labels)
            mini_batch_losses.append(mini_batch_loss)

        return np.mean(mini_batch_losses)

    def _set_requires_grad(self, module_group, mode):
        # Train only the submodule whose name matches the current mode; freeze the rest.
        for name, param_group in module_group.items():
            requires_grad = (mode == name)
            for param in param_group.parameters():
                param.requires_grad = requires_grad

    def train(self, n_epochs, patience):
        loss_track = EarlyStopping(patience)

        self._set_requires_grad(self.model.decoder, self.mode)

        pbar = tqdm(range(n_epochs), desc=f'LGCN train ({self.mode})', total=n_epochs)
        for epoch in pbar:
            # Performs training
            train_loss = self._mini_batch(epoch, n_epochs, validation=False)

            # Performs evaluation
            with torch.no_grad():
                val_loss = self._mini_batch(epoch, n_epochs, validation=True)

            # Save the best model
            if loss_track.best_score < -val_loss:
                torch.save(self.model.state_dict(), self.model_path)

            # Update validation loss
            loss_track(val_loss)
            if loss_track.early_stop:
                print(f'Stop training, as {self.mode} validation loss has not decreased for {patience} consecutive steps.')
                break

            pbar.set_postfix({'train loss': f'{train_loss.item():.4f}',
                              'validation loss': f'{val_loss.item():.4f}'})