gsMap3D 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/__init__.py +13 -0
- gsMap/__main__.py +4 -0
- gsMap/cauchy_combination_test.py +342 -0
- gsMap/cli.py +355 -0
- gsMap/config/__init__.py +72 -0
- gsMap/config/base.py +296 -0
- gsMap/config/cauchy_config.py +79 -0
- gsMap/config/dataclasses.py +235 -0
- gsMap/config/decorators.py +302 -0
- gsMap/config/find_latent_config.py +276 -0
- gsMap/config/format_sumstats_config.py +54 -0
- gsMap/config/latent2gene_config.py +461 -0
- gsMap/config/ldscore_config.py +261 -0
- gsMap/config/quick_mode_config.py +242 -0
- gsMap/config/report_config.py +81 -0
- gsMap/config/spatial_ldsc_config.py +334 -0
- gsMap/config/utils.py +286 -0
- gsMap/find_latent/__init__.py +3 -0
- gsMap/find_latent/find_latent_representation.py +312 -0
- gsMap/find_latent/gnn/distribution.py +498 -0
- gsMap/find_latent/gnn/encoder_decoder.py +186 -0
- gsMap/find_latent/gnn/gcn.py +85 -0
- gsMap/find_latent/gnn/gene_former.py +164 -0
- gsMap/find_latent/gnn/loss.py +18 -0
- gsMap/find_latent/gnn/st_model.py +125 -0
- gsMap/find_latent/gnn/train_step.py +177 -0
- gsMap/find_latent/st_process.py +781 -0
- gsMap/format_sumstats.py +446 -0
- gsMap/generate_ldscore.py +1018 -0
- gsMap/latent2gene/__init__.py +18 -0
- gsMap/latent2gene/connectivity.py +781 -0
- gsMap/latent2gene/entry_point.py +141 -0
- gsMap/latent2gene/marker_scores.py +1265 -0
- gsMap/latent2gene/memmap_io.py +766 -0
- gsMap/latent2gene/rank_calculator.py +590 -0
- gsMap/latent2gene/row_ordering.py +182 -0
- gsMap/latent2gene/row_ordering_jax.py +159 -0
- gsMap/ldscore/__init__.py +1 -0
- gsMap/ldscore/batch_construction.py +163 -0
- gsMap/ldscore/compute.py +126 -0
- gsMap/ldscore/constants.py +70 -0
- gsMap/ldscore/io.py +262 -0
- gsMap/ldscore/mapping.py +262 -0
- gsMap/ldscore/pipeline.py +615 -0
- gsMap/pipeline/quick_mode.py +134 -0
- gsMap/report/__init__.py +2 -0
- gsMap/report/diagnosis.py +375 -0
- gsMap/report/report.py +100 -0
- gsMap/report/report_data.py +1832 -0
- gsMap/report/static/js_lib/alpine.min.js +5 -0
- gsMap/report/static/js_lib/tailwindcss.js +83 -0
- gsMap/report/static/template.html +2242 -0
- gsMap/report/three_d_combine.py +312 -0
- gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
- gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
- gsMap/report/three_d_plot/three_d_plots.py +425 -0
- gsMap/report/visualize.py +1409 -0
- gsMap/setup.py +5 -0
- gsMap/spatial_ldsc/__init__.py +0 -0
- gsMap/spatial_ldsc/io.py +656 -0
- gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
- gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
- gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
- gsMap/utils/__init__.py +0 -0
- gsMap/utils/generate_r2_matrix.py +610 -0
- gsMap/utils/jackknife.py +518 -0
- gsMap/utils/manhattan_plot.py +643 -0
- gsMap/utils/regression_read.py +177 -0
- gsMap/utils/torch_utils.py +23 -0
- gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
- gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
- gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
- gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
- gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
gsMap/find_latent/find_latent_representation.py
@@ -0,0 +1,312 @@
+import logging
+import os
+import random
+from collections import OrderedDict
+from typing import Any
+
+import numpy as np
+import torch
+import yaml
+from torch.utils.data import (
+    DataLoader,
+    SubsetRandomSampler,
+    TensorDataset,
+    random_split,
+)
+
+from gsMap.config import FindLatentRepresentationsConfig
+
+from .gnn.st_model import StEmbeding
+from .gnn.train_step import ModelTrain
+from .st_process import (
+    InferenceData,
+    TrainingData,
+    apply_module_score_qc,
+    calculate_module_score,
+    calculate_module_scores_from_degs,
+    convert_to_human_genes,
+    create_subsampled_adata,
+    find_common_hvg,
+)
+
+logger = logging.getLogger(__name__)
+
+
+
+def set_seed(seed_value):
+    """
+    Set seed for reproducibility in PyTorch and other libraries.
+    """
+    torch.manual_seed(seed_value)
+    np.random.seed(seed_value)
+    random.seed(seed_value)
+    if torch.cuda.is_available():
+        logger.info("Using GPU for computations.")
+        torch.cuda.manual_seed(seed_value)
+        torch.cuda.manual_seed_all(seed_value)
+    else:
+        logger.info("Using CPU for computations.")
+
+
+def index_splitter(n, splits):
+    idx = torch.arange(n)
+    splits_tensor = torch.as_tensor(splits)
+    multiplier = n / splits_tensor.sum()
+    splits_tensor = (multiplier * splits_tensor).long()
+    diff = n - splits_tensor.sum()
+    splits_tensor[0] += diff
+    return random_split(idx, splits_tensor)
+
+
+def run_find_latent_representation(config: FindLatentRepresentationsConfig) -> dict[str, Any]:
+    """
+    Run the find latent representation pipeline.
+
+    Args:
+        config: FindLatentRepresentationsConfig object with all necessary parameters
+
+    Returns:
+        Dictionary containing metadata about the run, including config, model info,
+        training info, outputs, and annotation info
+    """
+    logger.info(f'Project dir: {config.project_dir}')
+    set_seed(2024)
+
+    # Find the common highly variable genes (HVGs)
+    hvg, n_cell_used, gene_homolog_dict = find_common_hvg(config.sample_h5ad_dict, config)
+    common_genes = np.array(list(gene_homolog_dict.keys()))
+
+    # Create a subsampled concatenated adata with sample-specific stratified sampling
+    training_adata = create_subsampled_adata(config.sample_h5ad_dict, n_cell_used, config)
+
+    # Prepare the training data
+    get_trainning_data = TrainingData(config)
+    get_trainning_data.prepare(training_adata, hvg)
+
+    # Configure the distribution
+    if config.data_layer in ["count", "counts"]:
+        distribution = config.distribution
+        variational = True
+        use_tf = config.use_tf
+    else:
+        distribution = "gaussian"
+        variational = False
+        use_tf = False
+
+    # Instantiate the LGCN VAE
+    input_size = [
+        get_trainning_data.expression_merge.size(1),
+        get_trainning_data.expression_gcn_merge.size(1),
+    ]
+    class_size = len(torch.unique(get_trainning_data.label_merge))
+    batch_size = get_trainning_data.batch_size
+    cell_size, out_size = get_trainning_data.expression_merge.shape
+    label_name = get_trainning_data.label_name
+
+    # Configure the batch embedding dimension
+    batch_embedding_size = 64
+
+    # Configure the model
+    gsmap_lgcn_model = StEmbeding(
+        # parameters of the VAE
+        input_size=input_size,
+        hidden_size=config.hidden_size,
+        embedding_size=config.embedding_size,
+        batch_embedding_size=batch_embedding_size,
+        out_put_size=out_size,
+        batch_size=batch_size,
+        class_size=class_size,
+        # parameters of the transformer
+        module_dim=config.module_dim,
+        hidden_gmf=config.hidden_gmf,
+        n_modules=config.n_modules,
+        nhead=config.nhead,
+        n_enc_layer=config.n_enc_layer,
+        # parameters of the model structure
+        distribution=distribution,
+        use_tf=use_tf,
+        variational=variational,
+    )
+
+    # Configure the optimizer
+    optimizer = torch.optim.Adam(gsmap_lgcn_model.parameters(), lr=1e-3)
+    logger.info(
+        f"gsMap-LGCN parameters: {sum(p.numel() for p in gsmap_lgcn_model.parameters())}."
+    )
+    logger.info(f"Number of cells used in training: {cell_size}.")
+
+    # Split the data into training (80%) and validation (20%) sets.
+    train_idx, val_idx = index_splitter(
+        get_trainning_data.expression_gcn_merge.size(0), [80, 20]
+    )
+    train_sampler = SubsetRandomSampler(train_idx)
+    val_sampler = SubsetRandomSampler(val_idx)
+
+    # Configure the data loaders
+    dataset = TensorDataset(
+        get_trainning_data.expression_gcn_merge,
+        get_trainning_data.batch_merge,
+        get_trainning_data.expression_merge,
+        get_trainning_data.label_merge,
+    )
+    train_loader = DataLoader(
+        dataset=dataset, batch_size=config.batch_size, sampler=train_sampler
+    )
+    val_loader = DataLoader(
+        dataset=dataset, batch_size=config.batch_size, sampler=val_sampler
+    )
+
+    # Model training
+    gsMap_embedding_finder = ModelTrain(
+        gsmap_lgcn_model,
+        optimizer,
+        distribution,
+        mode="reconstruction",
+        lr=1e-3,
+        model_path=config.model_path,
+    )
+    gsMap_embedding_finder.set_loaders(train_loader, val_loader)
+    print(gsMap_embedding_finder.model)
+
+    if not os.path.exists(config.model_path):
+        # reconstruction
+        gsMap_embedding_finder.train(config.itermax, patience=config.patience)
+
+        # classification
+        if config.two_stage and config.annotation is not None:
+            gsMap_embedding_finder.model.load_state_dict(torch.load(config.model_path))
+            gsMap_embedding_finder.mode = "classification"
+            gsMap_embedding_finder.train(config.itermax, patience=config.patience)
+    else:
+        logger.info(f"Model found at {config.model_path}. Skipping training.")
+
+    # Load the best model
+    gsMap_embedding_finder.model.load_state_dict(torch.load(config.model_path))
+    gsmap_embedding_model = gsMap_embedding_finder.model
+
+    # Configure the inference
+    infer = InferenceData(hvg, batch_size, gsmap_embedding_model, label_name, config)
+
+
+    output_h5ad_path_dict = OrderedDict(
+        {sample_name: config.latent_dir / f"{sample_name}_add_latent.h5ad"
+         for sample_name in config.sample_h5ad_dict.keys()}
+    )
+    # Run DEG analysis on the training adata if an annotation is provided and high-quality cell QC is enabled
+    module_score_threshold_dict = {}
+    if config.high_quality_cell_qc and config.annotation is not None:
+        # Calculate module scores for the training data
+        training_adata = calculate_module_score(training_adata, config.annotation)
+
+        # Calculate thresholds for each annotation
+        for label in training_adata.obs[config.annotation].cat.categories:
+            scores = training_adata.obs.loc[training_adata.obs[config.annotation] == label, f"{label}_module_score"]
+            Q1 = np.percentile(scores, 25)
+            Q2 = np.median(scores)
+            Q3 = np.percentile(scores, 75)
+            IQR = Q3 - Q1
+            threshold = max(0, Q2 - 1 * IQR)  # Ensure the threshold is not negative
+            module_score_threshold_dict[label] = threshold
+            logger.info(f"High quality module score threshold for {label}: {threshold:.3f}")
+
+        training_adata.uns['module_score_thresholds'] = module_score_threshold_dict
+
+        # Apply QC to the training data as well
+        training_adata = apply_module_score_qc(training_adata, config.annotation, module_score_threshold_dict)
+
+        # Save the training adata
+        training_adata_path = config.find_latent_metadata_path.parent / "training_adata.h5ad"
+        training_adata.write_h5ad(training_adata_path)
+        logger.info(f"Saved training adata to {training_adata_path}")
+
+    for st_id, (sample_name, st_file) in enumerate(config.sample_h5ad_dict.items()):
+
+        output_path = output_h5ad_path_dict[sample_name]
+
+        # Infer the embedding
+        adata = infer.infer_embedding_single(st_id, st_file)
+
+        # Calculate module scores and apply QC for each annotation
+        if config.high_quality_cell_qc and config.annotation is not None:
+            # Calculate module scores for this sample using the same DEGs as the training data
+            logger.info(f"Calculating module scores for {sample_name}...")
+
+            # Get the DEG results from the training data
+            deg_results = training_adata.uns['rank_genes_groups']
+
+            # Keep the same gene list as the training data, because the module score is based on the training data.
+            # This is critical for proper normalization: without this filter, normalize_total()
+            # would compute different scaling factors compared to the training data.
+            adata = adata[:, training_adata.var_names].copy()
+
+            # Calculate module scores using the existing DEG results
+            adata = calculate_module_scores_from_degs(adata, deg_results, config.annotation)
+
+            # Apply QC based on the module score thresholds
+            if config.annotation in adata.obs.columns:
+                adata = apply_module_score_qc(adata, config.annotation, module_score_threshold_dict)
+            else:
+                logger.warning(f"Annotation '{config.annotation}' not found in {sample_name}, skipping QC")
+
+        # Convert to human gene names
+        adata = convert_to_human_genes(adata, gene_homolog_dict, species=config.species)
+
+
+        # Compute the per-cell depth (total counts)
+        if config.data_layer in ["count", "counts"]:
+            adata.obs['depth'] = np.array(adata.layers[config.data_layer].sum(axis=1)).flatten()
+
+        # Save the ST data with embeddings
+        adata.write_h5ad(output_path)
+
+        logger.info(f"Saved latent representation to {output_path}")
+
+    # Convert the config to a dict with all Path objects as strings
+    config_dict = config.to_dict_with_paths_as_strings()
+
+    # Convert output_h5ad_path_dict values to strings
+    output_h5ad_path_dict_str = {k: str(v) for k, v in output_h5ad_path_dict.items()}
+
+    # Save metadata
+    metadata = {
+        "config": config_dict,
+        "model_info": {
+            "model_path": str(config.model_path),
+            "n_parameters": int(sum(p.numel() for p in gsmap_embedding_model.parameters())),
+            "input_size": [int(x) for x in input_size],
+            "hidden_size": int(config.hidden_size),
+            "embedding_size": int(config.embedding_size),
+            "batch_embedding_size": int(batch_embedding_size),
+            "class_size": int(class_size),
+            "distribution": distribution,
+            "variational": variational,
+            "use_tf": use_tf
+        },
+        "training_info": {
+            "n_cells_used": int(cell_size),
+            "n_genes_used": int(len(hvg)),
+            "n_common_genes": int(len(common_genes)),
+            "batch_size": int(config.batch_size),
+            "n_epochs": int(config.itermax),
+            "patience": int(config.patience),
+            "two_stage": config.two_stage
+        },
+        "outputs": {
+            "latent_files": output_h5ad_path_dict_str,
+            "n_sections": len(config.sample_h5ad_dict)
+        },
+        "annotation_info": {
+            "annotation_key": config.annotation,
+            "n_classes": int(class_size),
+            "label_names": label_name if isinstance(label_name, list) else label_name.tolist() if hasattr(label_name, 'tolist') else list(label_name),
+        }
+    }
+
+    # Save metadata to a YAML file
+    metadata_path = config.find_latent_metadata_path
+    with open(metadata_path, 'w') as f:
+        yaml.dump(metadata, f, default_flow_style=False, sort_keys=False)
+
+    logger.info(f"Saved metadata to {metadata_path}")
+
+    return metadata
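
Note on the train/validation split: the 80/20 split above comes from index_splitter, which rescales the requested proportions to absolute counts and assigns any truncation remainder to the first (training) split before passing the lengths to torch.utils.data.random_split. A minimal, self-contained sketch of that arithmetic, with an illustrative cell count that is not taken from the package:

import torch

n = 10_000                          # illustrative number of cells
splits = torch.as_tensor([80, 20])  # same proportions as used above

# Rescale the proportions so they sum to n, truncating to integers.
lengths = (n / splits.sum() * splits).long()
# Assign any truncation remainder to the first (training) split.
lengths[0] += n - lengths.sum()

print(lengths.tolist())  # [8000, 2000]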
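Note on the high-quality-cell QC: for each annotation label the code above sets the module-score threshold to max(0, median - IQR) of that label's scores. A minimal NumPy sketch of that rule on toy values follows; the actual filtering is performed by apply_module_score_qc in gsMap/find_latent/st_process.py (not shown in this hunk), so the keep comparison below is only an assumed illustration of its effect:

import numpy as np

# Toy module scores for one annotation label (illustrative values only).
scores = np.array([0.05, 0.12, 0.30, 0.45, 0.50, 0.55, 0.60, 0.80, 0.95])

q1, q2, q3 = np.percentile(scores, [25, 50, 75])
iqr = q3 - q1
threshold = max(0.0, q2 - iqr)  # median minus one IQR, floored at zero

# Assumed filtering step: keep cells whose score reaches the threshold.
keep = scores >= threshold
print(f"threshold = {threshold:.3f}, kept {keep.sum()} of {scores.size} cells")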