gsMap3D 0.1.0a1 (gsmap3d-0.1.0a1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. gsMap/__init__.py +13 -0
  2. gsMap/__main__.py +4 -0
  3. gsMap/cauchy_combination_test.py +342 -0
  4. gsMap/cli.py +355 -0
  5. gsMap/config/__init__.py +72 -0
  6. gsMap/config/base.py +296 -0
  7. gsMap/config/cauchy_config.py +79 -0
  8. gsMap/config/dataclasses.py +235 -0
  9. gsMap/config/decorators.py +302 -0
  10. gsMap/config/find_latent_config.py +276 -0
  11. gsMap/config/format_sumstats_config.py +54 -0
  12. gsMap/config/latent2gene_config.py +461 -0
  13. gsMap/config/ldscore_config.py +261 -0
  14. gsMap/config/quick_mode_config.py +242 -0
  15. gsMap/config/report_config.py +81 -0
  16. gsMap/config/spatial_ldsc_config.py +334 -0
  17. gsMap/config/utils.py +286 -0
  18. gsMap/find_latent/__init__.py +3 -0
  19. gsMap/find_latent/find_latent_representation.py +312 -0
  20. gsMap/find_latent/gnn/distribution.py +498 -0
  21. gsMap/find_latent/gnn/encoder_decoder.py +186 -0
  22. gsMap/find_latent/gnn/gcn.py +85 -0
  23. gsMap/find_latent/gnn/gene_former.py +164 -0
  24. gsMap/find_latent/gnn/loss.py +18 -0
  25. gsMap/find_latent/gnn/st_model.py +125 -0
  26. gsMap/find_latent/gnn/train_step.py +177 -0
  27. gsMap/find_latent/st_process.py +781 -0
  28. gsMap/format_sumstats.py +446 -0
  29. gsMap/generate_ldscore.py +1018 -0
  30. gsMap/latent2gene/__init__.py +18 -0
  31. gsMap/latent2gene/connectivity.py +781 -0
  32. gsMap/latent2gene/entry_point.py +141 -0
  33. gsMap/latent2gene/marker_scores.py +1265 -0
  34. gsMap/latent2gene/memmap_io.py +766 -0
  35. gsMap/latent2gene/rank_calculator.py +590 -0
  36. gsMap/latent2gene/row_ordering.py +182 -0
  37. gsMap/latent2gene/row_ordering_jax.py +159 -0
  38. gsMap/ldscore/__init__.py +1 -0
  39. gsMap/ldscore/batch_construction.py +163 -0
  40. gsMap/ldscore/compute.py +126 -0
  41. gsMap/ldscore/constants.py +70 -0
  42. gsMap/ldscore/io.py +262 -0
  43. gsMap/ldscore/mapping.py +262 -0
  44. gsMap/ldscore/pipeline.py +615 -0
  45. gsMap/pipeline/quick_mode.py +134 -0
  46. gsMap/report/__init__.py +2 -0
  47. gsMap/report/diagnosis.py +375 -0
  48. gsMap/report/report.py +100 -0
  49. gsMap/report/report_data.py +1832 -0
  50. gsMap/report/static/js_lib/alpine.min.js +5 -0
  51. gsMap/report/static/js_lib/tailwindcss.js +83 -0
  52. gsMap/report/static/template.html +2242 -0
  53. gsMap/report/three_d_combine.py +312 -0
  54. gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
  55. gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
  56. gsMap/report/three_d_plot/three_d_plots.py +425 -0
  57. gsMap/report/visualize.py +1409 -0
  58. gsMap/setup.py +5 -0
  59. gsMap/spatial_ldsc/__init__.py +0 -0
  60. gsMap/spatial_ldsc/io.py +656 -0
  61. gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
  62. gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
  63. gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
  64. gsMap/utils/__init__.py +0 -0
  65. gsMap/utils/generate_r2_matrix.py +610 -0
  66. gsMap/utils/jackknife.py +518 -0
  67. gsMap/utils/manhattan_plot.py +643 -0
  68. gsMap/utils/regression_read.py +177 -0
  69. gsMap/utils/torch_utils.py +23 -0
  70. gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
  71. gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
  72. gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
  73. gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
  74. gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
gsMap/find_latent/find_latent_representation.py
@@ -0,0 +1,312 @@
+import logging
+import os
+import random
+from collections import OrderedDict
+from typing import Any
+
+import numpy as np
+import torch
+import yaml
+from torch.utils.data import (
+    DataLoader,
+    SubsetRandomSampler,
+    TensorDataset,
+    random_split,
+)
+
+from gsMap.config import FindLatentRepresentationsConfig
+
+from .gnn.st_model import StEmbeding
+from .gnn.train_step import ModelTrain
+from .st_process import (
+    InferenceData,
+    TrainingData,
+    apply_module_score_qc,
+    calculate_module_score,
+    calculate_module_scores_from_degs,
+    convert_to_human_genes,
+    create_subsampled_adata,
+    find_common_hvg,
+)
+
+logger = logging.getLogger(__name__)
+
+
+
+def set_seed(seed_value):
+    """
+    Set seed for reproducibility in PyTorch and other libraries.
+    """
+    torch.manual_seed(seed_value)
+    np.random.seed(seed_value)
+    random.seed(seed_value)
+    if torch.cuda.is_available():
+        logger.info("Using GPU for computations.")
+        torch.cuda.manual_seed(seed_value)
+        torch.cuda.manual_seed_all(seed_value)
+    else:
+        logger.info("Using CPU for computations.")
+
+
+def index_splitter(n, splits):
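+    # Scale the requested proportions to integer lengths that sum to n.
+    # Example: n=1000, splits=[80, 20] -> lengths [800, 200]; any rounding
+    # shortfall is added back to the first split, and random_split then
+    # shuffles the indices into disjoint subsets of those lengths.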
+    idx = torch.arange(n)
+    splits_tensor = torch.as_tensor(splits)
+    multiplier = n / splits_tensor.sum()
+    splits_tensor = (multiplier * splits_tensor).long()
+    diff = n - splits_tensor.sum()
+    splits_tensor[0] += diff
+    return random_split(idx, splits_tensor)
+
+
+def run_find_latent_representation(config: FindLatentRepresentationsConfig) -> dict[str, Any]:
+    """
+    Run the find-latent-representation pipeline.
+
+    Args:
+        config: FindLatentRepresentationsConfig object with all necessary parameters
+
+    Returns:
+        Dictionary containing metadata about the run, including config, model info,
+        training info, outputs, and annotation info
+    """
+    logger.info(f"Project dir: {config.project_dir}")
+    set_seed(2024)
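+    # Seeding torch/numpy/random up front makes the stratified subsampling,
+    # the 80/20 split, and the model weight initialization reproducible.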
+
+    # Find the common HVGs
+    hvg, n_cell_used, gene_homolog_dict = find_common_hvg(config.sample_h5ad_dict, config)
+    common_genes = np.array(list(gene_homolog_dict.keys()))
+
+    # Create a subsampled, concatenated AnnData with sample-specific stratified sampling
+    training_adata = create_subsampled_adata(config.sample_h5ad_dict, n_cell_used, config)
+
+    # Prepare the training data
+    get_trainning_data = TrainingData(config)
+    get_trainning_data.prepare(training_adata, hvg)
+
+    # Configure the output distribution
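+    # Raw counts keep the configured count likelihood and train a variational
+    # model; any other (already normalized) data layer falls back to a
+    # Gaussian likelihood with a plain, non-variational autoencoder.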
+    if config.data_layer in ["count", "counts"]:
+        distribution = config.distribution
+        variational = True
+        use_tf = config.use_tf
+    else:
+        distribution = "gaussian"
+        variational = False
+        use_tf = False
+
+    # Instantiate the LGCN VAE
+    input_size = [
+        get_trainning_data.expression_merge.size(1),
+        get_trainning_data.expression_gcn_merge.size(1),
+    ]
+    class_size = len(torch.unique(get_trainning_data.label_merge))
+    batch_size = get_trainning_data.batch_size
+    cell_size, out_size = get_trainning_data.expression_merge.shape
+    label_name = get_trainning_data.label_name
+
+    # Configure the batch-embedding dimension
+    batch_embedding_size = 64
+
+    # Configure the model
+    gsmap_lgcn_model = StEmbeding(
+        # VAE parameters
+        input_size=input_size,
+        hidden_size=config.hidden_size,
+        embedding_size=config.embedding_size,
+        batch_embedding_size=batch_embedding_size,
+        out_put_size=out_size,
+        batch_size=batch_size,
+        class_size=class_size,
+        # transformer parameters
+        module_dim=config.module_dim,
+        hidden_gmf=config.hidden_gmf,
+        n_modules=config.n_modules,
+        nhead=config.nhead,
+        n_enc_layer=config.n_enc_layer,
+        # model-structure parameters
+        distribution=distribution,
+        use_tf=use_tf,
+        variational=variational,
+    )
+
+    # Configure the optimizer
+    optimizer = torch.optim.Adam(gsmap_lgcn_model.parameters(), lr=1e-3)
+    logger.info(
+        f"gsMap-LGCN parameters: {sum(p.numel() for p in gsmap_lgcn_model.parameters())}."
+    )
+    logger.info(f"Number of cells used in training: {cell_size}.")
+
+    # Split the data into training (80%) and validation (20%) sets.
+    train_idx, val_idx = index_splitter(
+        get_trainning_data.expression_gcn_merge.size(0), [80, 20]
+    )
+    train_sampler = SubsetRandomSampler(train_idx)
+    val_sampler = SubsetRandomSampler(val_idx)
+
+    # Configure the data loaders
+    dataset = TensorDataset(
+        get_trainning_data.expression_gcn_merge,
+        get_trainning_data.batch_merge,
+        get_trainning_data.expression_merge,
+        get_trainning_data.label_merge,
+    )
+    train_loader = DataLoader(
+        dataset=dataset, batch_size=config.batch_size, sampler=train_sampler
+    )
+    val_loader = DataLoader(
+        dataset=dataset, batch_size=config.batch_size, sampler=val_sampler
+    )
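+    # Both loaders share a single TensorDataset; the two samplers draw
+    # disjoint index subsets, so no cell appears in both training and
+    # validation batches.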
+
+    # Model training
+    gsMap_embedding_finder = ModelTrain(
+        gsmap_lgcn_model,
+        optimizer,
+        distribution,
+        mode="reconstruction",
+        lr=1e-3,
+        model_path=config.model_path,
+    )
+    gsMap_embedding_finder.set_loaders(train_loader, val_loader)
+    print(gsMap_embedding_finder.model)
+
+    if not os.path.exists(config.model_path):
+        # Stage 1: reconstruction
+        gsMap_embedding_finder.train(config.itermax, patience=config.patience)
+
+        # Stage 2: classification
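+        # (when two-stage training is enabled, reload the best reconstruction
+        # checkpoint and fine-tune it against the annotation labels)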
+        if config.two_stage and config.annotation is not None:
+            gsMap_embedding_finder.model.load_state_dict(torch.load(config.model_path))
+            gsMap_embedding_finder.mode = "classification"
+            gsMap_embedding_finder.train(config.itermax, patience=config.patience)
+    else:
+        logger.info(f"Model found at {config.model_path}. Skipping training.")
+
+    # Load the best model
+    gsMap_embedding_finder.model.load_state_dict(torch.load(config.model_path))
+    gsmap_embedding_model = gsMap_embedding_finder.model
+
+    # Configure the inference
+    infer = InferenceData(hvg, batch_size, gsmap_embedding_model, label_name, config)
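+    # InferenceData wraps the frozen best-checkpoint model; embeddings are
+    # produced one section at a time in the loop below, including for cells
+    # that were excluded from the training subsample.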
+
+
+    output_h5ad_path_dict = OrderedDict(
+        {sample_name: config.latent_dir / f"{sample_name}_add_latent.h5ad"
+         for sample_name in config.sample_h5ad_dict.keys()}
+    )
+    # Run the DEG analysis on the training AnnData if an annotation is provided and high-quality-cell QC is enabled
+    module_score_threshold_dict = {}
+    if config.high_quality_cell_qc and config.annotation is not None:
+        # Calculate module scores for the training data
+        training_adata = calculate_module_score(training_adata, config.annotation)
+
+        # Calculate a threshold for each annotation label
+        for label in training_adata.obs[config.annotation].cat.categories:
+            scores = training_adata.obs.loc[training_adata.obs[config.annotation] == label, f"{label}_module_score"]
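+            # Robust lower fence: median minus one IQR, floored at zero.
+            # Example: Q1=0.2, Q2=0.5, Q3=0.8 -> IQR=0.6 -> max(0, 0.5-0.6) = 0.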
+            Q1 = np.percentile(scores, 25)
+            Q2 = np.median(scores)
+            Q3 = np.percentile(scores, 75)
+            IQR = Q3 - Q1
+            threshold = max(0, Q2 - 1 * IQR)  # Ensure the threshold is not negative
+            module_score_threshold_dict[label] = threshold
+            logger.info(f"High quality module score threshold for {label}: {threshold:.3f}")
+
+        training_adata.uns['module_score_thresholds'] = module_score_threshold_dict
+
+        # Apply QC to the training data as well
+        training_adata = apply_module_score_qc(training_adata, config.annotation, module_score_threshold_dict)
+
+        # Save the training AnnData
+        training_adata_path = config.find_latent_metadata_path.parent / "training_adata.h5ad"
+        training_adata.write_h5ad(training_adata_path)
+        logger.info(f"Saved training adata to {training_adata_path}")
+
+    for st_id, (sample_name, st_file) in enumerate(config.sample_h5ad_dict.items()):
+
+        output_path = output_h5ad_path_dict[sample_name]
+
+        # Infer the embedding
+        adata = infer.infer_embedding_single(st_id, st_file)
+
+        # Calculate module scores and apply QC for each annotation
+        if config.high_quality_cell_qc and config.annotation is not None:
+            # Calculate module scores for this sample using the same DEGs from training
+            logger.info(f"Calculating module scores for {sample_name}...")
+
+            # Get the DEG results from the training data
+            deg_results = training_adata.uns['rank_genes_groups']
+
+            # Keep the same gene list as the training data, because the module score is based on the training data.
+            # This is critical for proper normalization: without this filter, normalize_total()
+            # would compute different scaling factors than it did for the training data.
+            adata = adata[:, training_adata.var_names].copy()
+
+            # Calculate module scores using the existing DEG results
+            adata = calculate_module_scores_from_degs(adata, deg_results, config.annotation)
+
+            # Apply QC based on the module-score thresholds
+            if config.annotation in adata.obs.columns:
+                adata = apply_module_score_qc(adata, config.annotation, module_score_threshold_dict)
+            else:
+                logger.warning(f"Annotation '{config.annotation}' not found in {sample_name}, skipping QC")
+
+        # Convert to human gene names
+        adata = convert_to_human_genes(adata, gene_homolog_dict, species=config.species)
+
+
+        # Compute the per-cell sequencing depth (total counts)
+        if config.data_layer in ["count", "counts"]:
+            adata.obs['depth'] = np.array(adata.layers[config.data_layer].sum(axis=1)).flatten()
+
+        # Save the ST data with embeddings
+        adata.write_h5ad(output_path)
+
+        logger.info(f"Saved latent representation to {output_path}")
+
+    # Convert the config to a dict with all Path objects rendered as strings
+    config_dict = config.to_dict_with_paths_as_strings()
+
+    # Convert output_h5ad_path_dict values to strings
+    output_h5ad_path_dict_str = {k: str(v) for k, v in output_h5ad_path_dict.items()}
+
+    # Save metadata
+    metadata = {
+        "config": config_dict,
+        "model_info": {
+            "model_path": str(config.model_path),
+            "n_parameters": int(sum(p.numel() for p in gsmap_embedding_model.parameters())),
+            "input_size": [int(x) for x in input_size],
+            "hidden_size": int(config.hidden_size),
+            "embedding_size": int(config.embedding_size),
+            "batch_embedding_size": int(batch_embedding_size),
+            "class_size": int(class_size),
+            "distribution": distribution,
+            "variational": variational,
+            "use_tf": use_tf
+        },
+        "training_info": {
+            "n_cells_used": int(cell_size),
+            "n_genes_used": int(len(hvg)),
+            "n_common_genes": int(len(common_genes)),
+            "batch_size": int(config.batch_size),
+            "n_epochs": int(config.itermax),
+            "patience": int(config.patience),
+            "two_stage": config.two_stage
+        },
+        "outputs": {
+            "latent_files": output_h5ad_path_dict_str,
+            "n_sections": len(config.sample_h5ad_dict)
+        },
+        "annotation_info": {
+            "annotation_key": config.annotation,
+            "n_classes": int(class_size),
+            "label_names": label_name if isinstance(label_name, list) else label_name.tolist() if hasattr(label_name, 'tolist') else list(label_name),
+        }
+    }
+
+    # Save the metadata to a YAML file
+    metadata_path = config.find_latent_metadata_path
+    with open(metadata_path, 'w') as f:
+        yaml.dump(metadata, f, default_flow_style=False, sort_keys=False)
+
+    logger.info(f"Saved metadata to {metadata_path}")
+
+    return metadata