gsmap3d-0.1.0a1-py3-none-any.whl

Files changed (74)
  1. gsMap/__init__.py +13 -0
  2. gsMap/__main__.py +4 -0
  3. gsMap/cauchy_combination_test.py +342 -0
  4. gsMap/cli.py +355 -0
  5. gsMap/config/__init__.py +72 -0
  6. gsMap/config/base.py +296 -0
  7. gsMap/config/cauchy_config.py +79 -0
  8. gsMap/config/dataclasses.py +235 -0
  9. gsMap/config/decorators.py +302 -0
  10. gsMap/config/find_latent_config.py +276 -0
  11. gsMap/config/format_sumstats_config.py +54 -0
  12. gsMap/config/latent2gene_config.py +461 -0
  13. gsMap/config/ldscore_config.py +261 -0
  14. gsMap/config/quick_mode_config.py +242 -0
  15. gsMap/config/report_config.py +81 -0
  16. gsMap/config/spatial_ldsc_config.py +334 -0
  17. gsMap/config/utils.py +286 -0
  18. gsMap/find_latent/__init__.py +3 -0
  19. gsMap/find_latent/find_latent_representation.py +312 -0
  20. gsMap/find_latent/gnn/distribution.py +498 -0
  21. gsMap/find_latent/gnn/encoder_decoder.py +186 -0
  22. gsMap/find_latent/gnn/gcn.py +85 -0
  23. gsMap/find_latent/gnn/gene_former.py +164 -0
  24. gsMap/find_latent/gnn/loss.py +18 -0
  25. gsMap/find_latent/gnn/st_model.py +125 -0
  26. gsMap/find_latent/gnn/train_step.py +177 -0
  27. gsMap/find_latent/st_process.py +781 -0
  28. gsMap/format_sumstats.py +446 -0
  29. gsMap/generate_ldscore.py +1018 -0
  30. gsMap/latent2gene/__init__.py +18 -0
  31. gsMap/latent2gene/connectivity.py +781 -0
  32. gsMap/latent2gene/entry_point.py +141 -0
  33. gsMap/latent2gene/marker_scores.py +1265 -0
  34. gsMap/latent2gene/memmap_io.py +766 -0
  35. gsMap/latent2gene/rank_calculator.py +590 -0
  36. gsMap/latent2gene/row_ordering.py +182 -0
  37. gsMap/latent2gene/row_ordering_jax.py +159 -0
  38. gsMap/ldscore/__init__.py +1 -0
  39. gsMap/ldscore/batch_construction.py +163 -0
  40. gsMap/ldscore/compute.py +126 -0
  41. gsMap/ldscore/constants.py +70 -0
  42. gsMap/ldscore/io.py +262 -0
  43. gsMap/ldscore/mapping.py +262 -0
  44. gsMap/ldscore/pipeline.py +615 -0
  45. gsMap/pipeline/quick_mode.py +134 -0
  46. gsMap/report/__init__.py +2 -0
  47. gsMap/report/diagnosis.py +375 -0
  48. gsMap/report/report.py +100 -0
  49. gsMap/report/report_data.py +1832 -0
  50. gsMap/report/static/js_lib/alpine.min.js +5 -0
  51. gsMap/report/static/js_lib/tailwindcss.js +83 -0
  52. gsMap/report/static/template.html +2242 -0
  53. gsMap/report/three_d_combine.py +312 -0
  54. gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
  55. gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
  56. gsMap/report/three_d_plot/three_d_plots.py +425 -0
  57. gsMap/report/visualize.py +1409 -0
  58. gsMap/setup.py +5 -0
  59. gsMap/spatial_ldsc/__init__.py +0 -0
  60. gsMap/spatial_ldsc/io.py +656 -0
  61. gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
  62. gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
  63. gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
  64. gsMap/utils/__init__.py +0 -0
  65. gsMap/utils/generate_r2_matrix.py +610 -0
  66. gsMap/utils/jackknife.py +518 -0
  67. gsMap/utils/manhattan_plot.py +643 -0
  68. gsMap/utils/regression_read.py +177 -0
  69. gsMap/utils/torch_utils.py +23 -0
  70. gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
  71. gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
  72. gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
  73. gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
  74. gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
gsMap/find_latent/st_process.py
@@ -0,0 +1,781 @@
+ import logging
+ import re
+ import warnings
+ from pathlib import Path
+
+ import numpy as np
+ import pandas as pd
+ import scanpy as sc
+ import scipy.sparse as sp
+ import torch
+ from rich.progress import BarColumn, Progress, TaskProgressColumn, TimeRemainingColumn
+ from torch.utils.data import DataLoader, TensorDataset
+
+ from gsMap.config import FindLatentRepresentationsConfig
+
+ from .gnn.gcn import GCN, build_spatial_graph
+
+ logger = logging.getLogger(__name__)
+
+
+ def convert_to_human_genes(adata, gene_homolog_dict: dict, species: str = None):
+     """
+     Convert gene names in adata to human gene symbols using a pre-computed mapping dictionary.
+
+     Args:
+         adata: AnnData object to convert
+         gene_homolog_dict: Dictionary mapping current gene names to human gene symbols
+         species: Optional species name to use for the original gene column in var
+
+     Returns:
+         adata: Processed AnnData object with human symbols as var_names
+     """
+     # Identify common genes present in the mapping
+     common_genes = [g for g in gene_homolog_dict.keys() if g in adata.var_names]
+
+     if len(common_genes) < 300:
+         logger.warning(f"Only {len(common_genes)} genes found in mapping dictionary.")
+
+     # Filter to these genes
+     adata = adata[:, common_genes].copy()
+
+     human_symbols = [gene_homolog_dict[g] for g in adata.var_names]
+
+     # Store original gene names if species is provided
+     if species:
+         species_col_name = f"{species}_GENE_SYM"
+         adata.var[species_col_name] = adata.var_names.values
+
+     # Update var_names to human symbols
+     adata.var_names = human_symbols
+     adata.var.index.name = "HUMAN_GENE_SYM"
+
+     # Remove duplicated genes (keep first occurrence)
+     if adata.var_names.duplicated().any():
+         n_before = adata.n_vars
+         adata = adata[:, ~adata.var_names.duplicated()].copy()
+         logger.info(f"Removed {n_before - adata.n_vars} duplicate human gene symbols.")
+
+     return adata
+
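
For orientation, a minimal usage sketch of this helper; the input file and the two-entry homolog mapping are illustrative, not shipped with the package (a real mapping has thousands of entries, so the <300-gene warning would fire here):

```python
import scanpy as sc

adata = sc.read_h5ad("mouse_sample.h5ad")              # assumed input file
mouse_to_human = {"Sox2": "SOX2", "Pecam1": "PECAM1"}  # toy homolog mapping
adata = convert_to_human_genes(adata, mouse_to_human, species="MOUSE")
# adata.var is now indexed by HUMAN_GENE_SYM, with the original symbols
# preserved in adata.var["MOUSE_GENE_SYM"]
```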
+ def find_common_hvg(sample_h5ad_dict, params: FindLatentRepresentationsConfig):
+     """
+     Identify common highly variable genes (HVGs) across multiple ST datasets and calculate
+     the number of cells to sample from each dataset.
+
+     Args:
+         sample_h5ad_dict (dict): Dictionary mapping sample names to file paths of ST datasets.
+         params: FindLatentRepresentationsConfig with HVG and sampling attributes.
+
+     Returns:
+         tuple: (hvg, n_cell_used, gene_homolog_dict)
+             - hvg: List of the top `params.feat_cell` common HVGs.
+             - n_cell_used: Number of cells to sample from each dataset.
+             - gene_homolog_dict: Mapping from the common genes to human gene symbols.
+     """
+     variances_list = []
+     cell_number = []
+
+     logger.info("Finding highly variable genes (HVGs)...")
+
+     with Progress(
+         BarColumn(),
+         TaskProgressColumn(),
+         TimeRemainingColumn(),
+     ) as progress:
+         task = progress.add_task("Finding HVGs", total=len(sample_h5ad_dict))
+
+         for sample_name, st_file in sample_h5ad_dict.items():
+             adata_temp = sc.read_h5ad(st_file)
+             # sc.pp.filter_genes(adata_temp, min_counts=1)
+
+             # Filter out mitochondrial and hemoglobin genes
+             gene_keep = ~adata_temp.var_names.str.match(re.compile(r'^(HB.-|MT-)', re.IGNORECASE))
+             if removed_genes := adata_temp.n_vars - gene_keep.sum():
+                 progress.console.log(f"Removed {removed_genes} mitochondrial and hemoglobin genes in {sample_name}.")
+
+             adata_temp = adata_temp[:, gene_keep].copy()
+
+             # Make gene names unique to avoid issues with HVG calculation
+             adata_temp.var_names_make_unique()
+
+             is_count_data, actual_data_layer = setup_data_layer(adata_temp, params.data_layer, verbose=False)
+             cell_number.append(adata_temp.n_obs)
+
+             try:
+                 # Identify highly variable genes
+                 if is_count_data:
+                     sc.experimental.pp.highly_variable_genes(
+                         adata_temp, n_top_genes=params.feat_cell, subset=False
+                     )
+                 else:
+                     sc.pp.highly_variable_genes(
+                         adata_temp, n_top_genes=params.feat_cell, subset=False, flavor='seurat'
+                     )
+
+                 var_df = adata_temp.var.copy()
+                 var_df["gene"] = var_df.index.tolist()
+                 variances_list.append(var_df)
+             except Exception as e:
+                 logger.warning(f"[HVG skipped] {e}")
+             finally:
+                 # Advance the bar even when HVG detection fails for a sample
+                 progress.update(task, advance=1)
+
+     # Check that we have at least one valid variance result
+     if len(variances_list) == 0:
+         raise ValueError("No valid HVG results obtained from any sample. Please check your data and parameters.")
+
+     # Find the common genes across all samples
+     common_genes = np.array(
+         list(set.intersection(*map(set, [st.index.to_list() for st in variances_list])))
+     )
+
+     # Error when the number of common genes is too small
+     if len(common_genes) < 300:
+         raise ValueError(f"Only {len(common_genes)} common genes between all samples.")
+
+     # Aggregate per-sample variances and rank candidate HVGs.
+     # Note: `is_count_data` reflects the last sample processed; all samples
+     # are assumed to share the same data layer type.
+     df = pd.concat(variances_list, axis=0)
+     df["highly_variable"] = df["highly_variable"].astype(int)
+
+     if is_count_data:
+         df = df.groupby("gene", observed=True).agg(
+             dict(
+                 residual_variances="median",
+                 highly_variable="sum",
+             )
+         )
+         df = df.loc[common_genes]
+         df["highly_variable_nbatches"] = df["highly_variable"]
+         df.sort_values(
+             ["highly_variable_nbatches", "residual_variances"],
+             ascending=False,
+             na_position="last",
+             inplace=True,
+         )
+     else:
+         df = df.groupby("gene", observed=True).agg(
+             dict(
+                 dispersions_norm="median",
+                 highly_variable="sum",
+             )
+         )
+         df = df.loc[common_genes]
+         df["highly_variable_nbatches"] = df["highly_variable"]
+         df.sort_values(
+             ["highly_variable_nbatches", "dispersions_norm"],
+             ascending=False,
+             na_position="last",
+             inplace=True,
+         )
+
+     hvg = df.iloc[: params.feat_cell].index.tolist()
+
+     # Allocate the training-cell budget proportionally to each sample's size
+     total_cell = np.sum(cell_number)
+     total_cell_training = np.minimum(total_cell, params.n_cell_training)
+     cell_proportion = np.array(cell_number) / total_cell
+     n_cell_used = [
+         int(cell) for cell in (total_cell_training * cell_proportion).tolist()
+     ]
+
+     # Only use the common genes that can be mapped to human genes
+     if params.species is not None:
+         homologs = pd.read_csv(params.homolog_file, sep='\t')
+         if homologs.shape[1] < 2:
+             raise ValueError("Homologs file must have at least two columns: one for the species and one for the human gene symbol.")
+         # Use the first two columns: species gene symbol -> human gene symbol
+         homologs = homologs.iloc[:, :2]
+         homologs.columns = [params.species, 'HUMAN_GENE_SYM']
+         homologs.set_index(params.species, inplace=True)
+         common_genes = np.intersect1d(common_genes, homologs.index)
+         gene_homolog_dict = dict(zip(common_genes, homologs.loc[common_genes].HUMAN_GENE_SYM.values, strict=False))
+     else:
+         gene_homolog_dict = dict(zip(common_genes, common_genes, strict=False))
+
+     if len(gene_homolog_dict) < 300:
+         raise ValueError(f"Only {len(gene_homolog_dict)} genes could be mapped to human symbols. Please check the homolog file.")
+
+     return hvg, n_cell_used, gene_homolog_dict
+
+
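
The training budget is split across samples in proportion to their size; a worked example of that allocation with assumed counts:

```python
import numpy as np

cell_number = [50_000, 30_000, 20_000]                # assumed per-sample cell counts
total_cell = np.sum(cell_number)                      # 100_000
total_cell_training = np.minimum(total_cell, 60_000)  # capped by params.n_cell_training
cell_proportion = np.array(cell_number) / total_cell  # [0.5, 0.3, 0.2]
n_cell_used = [int(c) for c in total_cell_training * cell_proportion]
print(n_cell_used)                                    # [30000, 18000, 12000]
```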
+ def create_subsampled_adata(sample_h5ad_dict, n_cell_used, params: FindLatentRepresentationsConfig):
+     """
+     Create a subsampled adata for each sample with sample-specific stratified sampling,
+     add batch and label information, and return the concatenated adata.
+
+     Args:
+         sample_h5ad_dict (dict): Dictionary mapping sample names to file paths of ST datasets.
+         n_cell_used (list): Number of cells to sample from each dataset.
+         params: FindLatentRepresentationsConfig with sampling and annotation attributes.
+
+     Returns:
+         adata: Concatenated adata object with batch and label information in obs.
+     """
+     subsampled_adatas = []
+
+     logger.info("Creating subsampled adata with stratified sampling...")
+
+     for st_id, (sample_name, st_file) in enumerate(sample_h5ad_dict.items()):
+         logger.info(f"Processing {sample_name}...")
+
+         # Load the data
+         adata = sc.read_h5ad(st_file)
+
+         # Filter out mitochondrial and hemoglobin genes
+         gene_keep = ~adata.var_names.str.match(re.compile(r'^(HB.-|MT-)', re.IGNORECASE))
+         adata = adata[:, gene_keep].copy()
+
+         # Make gene names unique to avoid issues with concatenation
+         adata.var_names_make_unique()
+
+         # Set up the data layer consistently
+         is_count_data, actual_data_layer = setup_data_layer(adata, params.data_layer)
+
+         # Filter cells based on annotation if provided
+         if params.annotation is not None:
+             adata = adata[~adata.obs[params.annotation].isnull()]
+
+         # Perform stratified sampling within this sample
+         if params.do_sampling and adata.n_obs > n_cell_used[st_id]:
+             if params.annotation is None:
+                 # Simple random sampling
+                 num_cell = n_cell_used[st_id]
+                 logger.info(f"Downsampling {sample_name} to {num_cell} cells...")
+                 random_indices = np.random.choice(adata.n_obs, num_cell, replace=False)
+                 adata = adata[random_indices].copy()
+             else:
+                 # Stratified sampling based on the sample-specific annotation distribution
+                 sample_annotation_counts = adata.obs[params.annotation].value_counts()
+                 sample_total_cells = len(adata)
+                 target_total_cells = n_cell_used[st_id]
+
+                 # Calculate sample-specific annotation proportions
+                 sample_annotation_proportions = sample_annotation_counts / sample_total_cells
+
+                 # Calculate target cells for each annotation in this sample
+                 target_cells_per_annotation = (sample_annotation_proportions * target_total_cells).astype(int)
+
+                 logger.info(f"Downsampling {sample_name} to {target_total_cells} cells...")
+                 logger.debug("---Sample-specific annotation distribution-----")
+                 for ann, count in target_cells_per_annotation.items():
+                     logger.debug(f"{ann}: {count} cells")
+
+                 # Perform stratified sampling (keeping at least one cell per annotation)
+                 sampled_cells = (
+                     adata.obs.groupby(params.annotation, group_keys=False, observed=True)
+                     .apply(
+                         lambda x: x.sample(
+                             max(min(target_cells_per_annotation.get(x.name, 0), len(x)), 1),
+                             replace=False,
+                         ),
+                         include_groups=False,
+                     )
+                     .index
+                 )
+
+                 # Filter adata to the sampled cells
+                 adata = adata[sampled_cells].copy()
+
+         # Add batch information to obs
+         adata.obs['batch_id'] = f"S{st_id}"
+         adata.obs['sample_name'] = sample_name
+
+         # Add label information to obs (ensure it is properly set)
+         if params.annotation is not None:
+             # The annotation column already exists; just make sure it is also exposed as 'label'
+             if 'label' not in adata.obs.columns:
+                 adata.obs['label'] = adata.obs[params.annotation]
+         else:
+             # Create dummy labels
+             adata.obs['label'] = 'unknown'
+
+         subsampled_adatas.append(adata)
+         logger.info(f"Processed {sample_name}: {adata.n_obs} cells, {adata.n_vars} genes")
+
+     # Concatenate all samples
+     logger.info("Concatenating all processed data...")
+     concatenated_adata = sc.concat(subsampled_adatas, axis=0, join='inner',
+                                    index_unique='_', fill_value=0)
+
+     logger.info(f"Final concatenated adata: {concatenated_adata.n_obs} cells, {concatenated_adata.n_vars} genes")
+
+     return concatenated_adata
+
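
The stratified-sampling idiom above is easiest to see on a toy table; a self-contained sketch with illustrative numbers (pandas >= 2.2 for `include_groups`):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
obs = pd.DataFrame({
    "annotation": rng.choice(["Neuron", "Glia", "Endo"], size=1_000, p=[0.6, 0.3, 0.1]),
    "depth": rng.integers(500, 5_000, size=1_000),
})

target_total = 200
proportions = obs["annotation"].value_counts() / len(obs)
targets = (proportions * target_total).astype(int)

sampled_index = (
    obs.groupby("annotation", group_keys=False, observed=True)
    .apply(
        lambda x: x.sample(max(min(targets.get(x.name, 0), len(x)), 1), replace=False),
        include_groups=False,
    )
    .index
)
# Roughly 120/60/20 rows, mirroring each annotation's share of the table
```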
+ def _looks_like_count_matrix(X, max_check=100, tol=1e-8):
+     """
+     Heuristically check whether X contains integer values.
+     """
+     if X is None:
+         return False
+
+     if sp.issparse(X):
+         data = X.data
+         if data.size == 0:
+             return False
+         if data.size > max_check:
+             data = data[:max_check]
+     else:
+         flat = np.asarray(X).ravel()
+         if flat.size == 0:
+             return False
+         if flat.size > max_check:
+             flat = flat[:max_check]
+         data = flat
+
+     nz = data[data != 0]
+     if nz.size == 0:
+         return True  # all zeros, good enough
+     return np.allclose(nz, np.round(nz), atol=tol)
+
+
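
A quick sanity check of the heuristic's behavior on toy matrices (illustrative only); note that it inspects at most `max_check` leading non-zero entries:

```python
import numpy as np
import scipy.sparse as sp

counts = sp.csr_matrix(np.array([[0, 1, 3], [2, 0, 5]], dtype=float))
logged = np.log1p(counts.toarray())

print(_looks_like_count_matrix(counts))  # True: all non-zeros are whole numbers
print(_looks_like_count_matrix(logged))  # False: log1p values are fractional
```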
+ def setup_data_layer(adata, data_layer, verbose=True):
+     """
+     Set up and validate the data layer for an adata object.
+
+     Args:
+         adata: AnnData object
+         data_layer: Requested data layer name ("X" to use adata.X directly)
+         verbose: Whether to log information messages (default: True)
+
+     Returns:
+         tuple: (is_count_data, actual_data_layer)
+             - is_count_data: Boolean indicating whether the data is count data
+             - actual_data_layer: The actual layer name to use ("X" if using adata.X)
+     """
+     # Check if the requested data layer exists in layers
+     if data_layer in adata.layers:
+         # Use the specified layer
+         adata.X = adata.layers[data_layer]
+         actual_data_layer = data_layer
+         if verbose:
+             logger.info(f"Using data layer: {data_layer}")
+     elif data_layer == "X":
+         actual_data_layer = "X"
+     else:
+         raise ValueError(f"Data layer '{data_layer}' not found in adata.layers.")
+
+     # Determine whether this is count data
+     is_count_data = (
+         actual_data_layer in ["count", "counts", "raw_counts", "impute_count"]
+         or (adata.X is not None and np.issubdtype(adata.X.dtype, np.integer))
+         or _looks_like_count_matrix(adata.X)
+     )
+
+     if verbose:
+         if is_count_data:
+             logger.info("Detected count data")
+         else:
+             logger.info("Data appears to be normalized/log-transformed")
+
+     return is_count_data, actual_data_layer
+
+
+ def normalize_for_analysis(adata, is_count_data, preserve_raw=True):
+     """
+     Normalize data for DEG analysis or module scoring if needed.
+
+     Args:
+         adata: AnnData object
+         is_count_data: Boolean indicating whether the data is count data
+         preserve_raw: Whether to preserve the raw data in adata.raw
+
+     Returns:
+         adata: AnnData object with normalized data
+     """
+     if is_count_data:
+         logger.info("Normalizing and log-transforming count data...")
+
+         # Store raw counts if requested and not already stored
+         if preserve_raw and adata.raw is None:
+             adata.raw = adata
+
+         # Normalize to 10,000 counts per cell
+         sc.pp.normalize_total(adata, target_sum=1e4)
+
+         # Log-transform (log1p)
+         sc.pp.log1p(adata)
+
+         logger.info("Data normalized and log-transformed")
+     else:
+         logger.info("Using data as-is (already normalized/log-transformed)")
+
+     return adata
+
+
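
A hedged sketch of how the two helpers above are typically chained; the file name and the "counts" layer are assumptions about the input:

```python
import scanpy as sc

adata = sc.read_h5ad("sample.h5ad")                       # assumed input file
is_count_data, layer = setup_data_layer(adata, "counts")  # copies the layer into adata.X
adata = normalize_for_analysis(adata, is_count_data)      # normalize_total(1e4) + log1p if counts
```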
+ def filter_significant_degs(deg_results, annotation, adata=None, pval_threshold=0.05, lfc_threshold=0.25, max_genes=100):
+     """
+     Filter DEGs based on statistical significance and fold-change criteria.
+
+     Args:
+         deg_results: DEG results from scanpy rank_genes_groups
+         annotation: Annotation label to get DEGs for
+         adata: Optional adata used to check gene existence
+         pval_threshold: Adjusted p-value threshold (default: 0.05)
+         lfc_threshold: Absolute log fold-change threshold (default: 0.25)
+         max_genes: Maximum number of genes to return (default: 100)
+
+     Returns:
+         list: Filtered list of significant DEG gene names
+     """
+     gene_names = deg_results['names'][annotation]
+     pvals_adj = deg_results['pvals_adj'][annotation]
+     logfoldchanges = deg_results['logfoldchanges'][annotation]
+
+     # Filter genes based on significance and fold change
+     annotation_genes = []
+     for gene, pval, lfc in zip(gene_names, pvals_adj, logfoldchanges, strict=False):
+         if gene is not None and pval < pval_threshold and abs(lfc) > lfc_threshold:
+             annotation_genes.append(gene)
+
+     # Limit to max_genes if we have more
+     if len(annotation_genes) > max_genes:
+         annotation_genes = annotation_genes[:max_genes]
+
+     # Ensure genes exist in adata if provided
+     if adata is not None:
+         annotation_genes = [gene for gene in annotation_genes if gene in adata.var_names]
+
+     return annotation_genes
+
+
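
The filter can be exercised without running scanpy by mocking the result mapping; plain dicts stand in for the structured arrays in `adata.uns['rank_genes_groups']`, since the field access is identical:

```python
deg_results = {
    "names":          {"Neuron": ["SNAP25", "GFAP", "MALAT1"]},
    "pvals_adj":      {"Neuron": [1e-10, 0.20, 1e-5]},
    "logfoldchanges": {"Neuron": [2.1, 1.5, 0.1]},
}
# SNAP25 passes both thresholds; GFAP fails on pvals_adj; MALAT1 fails on |lfc|
print(filter_significant_degs(deg_results, "Neuron"))  # ['SNAP25']
```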
+ def calculate_module_scores_from_degs(adata, deg_results, annotation_key):
+     """
+     Calculate module scores using existing DEG results.
+
+     Args:
+         adata: AnnData object to calculate scores for
+         deg_results: DEG results from scanpy rank_genes_groups
+         annotation_key: Column name in obs containing annotation labels
+
+     Returns:
+         adata: Updated adata with module score columns
+     """
+     logger.info("Calculating module scores using existing DEG results...")
+
+     # Detect count data and normalize if needed
+     is_count_data = _looks_like_count_matrix(adata.X)
+     adata = normalize_for_analysis(adata, is_count_data, preserve_raw=False)
+
+     # Ensure the annotation is categorical if it exists
+     if annotation_key in adata.obs.columns:
+         adata.obs[annotation_key] = adata.obs[annotation_key].astype('category')
+         available_annotations = adata.obs[annotation_key].cat.categories
+     else:
+         # If the annotation doesn't exist, use all annotations from the DEG results
+         available_annotations = list(deg_results['names'].dtype.names)
+
+     # Calculate a module score for each annotation
+     for annotation in available_annotations:
+         if annotation in deg_results['names'].dtype.names:
+             # Get significant DEGs for this annotation
+             annotation_genes = filter_significant_degs(deg_results, annotation, adata)
+
+             if len(annotation_genes) > 0:
+                 logger.info(f"Calculating module score for {annotation} using {len(annotation_genes)} genes")
+
+                 # Calculate the module score
+                 sc.tl.score_genes(
+                     adata,
+                     gene_list=annotation_genes,
+                     score_name=f"{annotation}_module_score",
+                     use_raw=False
+                 )
+             else:
+                 logger.warning(f"No valid DEGs found for {annotation}")
+                 adata.obs[f"{annotation}_module_score"] = 0.0
+         else:
+             logger.warning(f"Annotation {annotation} not found in DEG results")
+             adata.obs[f"{annotation}_module_score"] = 0.0
+
+     return adata
+
+
+ def calculate_module_score(training_adata, annotation_key):
+     """
+     Perform DEG analysis for each annotation and calculate module scores.
+
+     Args:
+         training_adata: Concatenated training adata with annotation information
+         annotation_key: Column name in obs containing annotation labels
+
+     Returns:
+         training_adata: Updated adata with module scores for each annotation
+     """
+     logger.info("Performing DEG analysis for each annotation...")
+
+     # Make a copy to avoid modifying the original
+     adata = training_adata.copy()
+
+     # Ensure the annotation is categorical
+     adata.obs[annotation_key] = adata.obs[annotation_key].astype('category')
+
+     # Detect count data and normalize if needed
+     is_count_data = _looks_like_count_matrix(adata.X)
+     adata = normalize_for_analysis(adata, is_count_data, preserve_raw=True)
+
+     # Perform DEG analysis with DataFrame fragmentation warnings suppressed;
+     # these warnings are harmless and come from Scanpy's internal implementation
+     with warnings.catch_warnings():
+         warnings.filterwarnings('ignore', category=pd.errors.PerformanceWarning,
+                                 message='DataFrame is highly fragmented')
+         sc.tl.rank_genes_groups(
+             adata,
+             groupby=annotation_key,
+             method='wilcoxon',
+             use_raw=False,
+             n_genes=100
+         )
+
+     logger.info("Calculating module scores for each annotation...")
+
+     # Get the DEG results
+     deg_results = adata.uns['rank_genes_groups']
+
+     # Calculate a module score for each annotation
+     for annotation in adata.obs[annotation_key].cat.categories:
+         # Get significant DEGs for this annotation
+         annotation_genes = filter_significant_degs(deg_results, annotation, adata)
+
+         if len(annotation_genes) > 0:
+             logger.info(f"Calculating module score for {annotation} using {len(annotation_genes)} genes")
+
+             # Calculate the module score
+             sc.tl.score_genes(
+                 adata,
+                 gene_list=annotation_genes,
+                 score_name=f"{annotation}_module_score",
+                 use_raw=False
+             )
+         else:
+             logger.warning(f"No valid DEGs found for {annotation}")
+             adata.obs[f"{annotation}_module_score"] = 0.0
+
+     logger.info("Module score calculation completed")
+     return adata
+
+
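
A short usage sketch of the DEG-to-module-score flow above; the file name and the annotation column are assumptions:

```python
import scanpy as sc

adata = sc.read_h5ad("training_data.h5ad")  # assumed concatenated training data
adata = calculate_module_score(adata, annotation_key="annotation")
# One score column per category, e.g. adata.obs["Neuron_module_score"]
```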
+ def apply_module_score_qc(adata, annotation_key, module_score_threshold_dict):
+     """
+     Apply quality control based on module scores.
+
+     Args:
+         adata: AnnData object to apply QC to
+         annotation_key: Column name in obs containing annotation labels
+         module_score_threshold_dict: Dictionary mapping annotation to threshold values
+
+     Returns:
+         adata: Updated adata with QC information
+     """
+     logger.info("Applying module score-based quality control...")
+
+     # Initialize the High_quality column as boolean (True = high quality)
+     adata.obs['High_quality'] = True
+
+     # Check that we have the annotation key
+     if annotation_key not in adata.obs.columns:
+         logger.warning(f"Annotation key '{annotation_key}' not found in adata.obs. Skipping module score QC.")
+         return adata
+
+     # Apply QC for each annotation
+     for annotation, threshold in module_score_threshold_dict.items():
+         module_score_col = f"{annotation}_module_score"
+
+         if module_score_col in adata.obs.columns:
+             # Find cells of this annotation with low module scores
+             annotation_mask = adata.obs[annotation_key] == annotation
+             low_score_mask = adata.obs[module_score_col] < threshold
+
+             # Set High_quality to False for cells that match both conditions (low quality)
+             low_quality_mask = annotation_mask & low_score_mask
+             adata.obs.loc[low_quality_mask, 'High_quality'] = False
+
+             n_low_quality = low_quality_mask.sum()
+             n_annotation_cells = annotation_mask.sum()
+
+             logger.info(f"{annotation}: {n_low_quality}/{n_annotation_cells} cells marked as low quality "
+                         f"(threshold: {threshold:.3f})")
+         else:
+             logger.warning(f"Module score column '{module_score_col}' not found in adata.obs")
+
+     total_low_quality = (~adata.obs['High_quality']).sum()
+     logger.info(f"Total low quality cells: {total_low_quality}/{adata.n_obs}")
+
+     return adata
+
+
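
Continuing the sketch above, QC thresholds are supplied per annotation; the values here are placeholders that would be tuned in practice:

```python
thresholds = {"Neuron": 0.10, "Glia": 0.05}      # illustrative thresholds
adata = apply_module_score_qc(adata, "annotation", thresholds)
low_quality = adata[~adata.obs["High_quality"]]  # cells flagged for inspection or removal
```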
+ # Prepare the training data
+ class TrainingData:
+     """
+     Manage and process training data for graph-based models.
+
+     Attributes:
+         params: FindLatentRepresentationsConfig used for data processing and training.
+     """
+
+     def __init__(self, params):
+         self.params = params
+         self.gcov = GCN(self.params.K)
+         self.expression_merge = None
+         self.expression_gcn_merge = None
+         self.label_merge = None
+         self.batch_merge = None
+         self.batch_size = None
+         self.label_name = None
+
+     def prepare(self, concatenated_adata, hvg):
+         logger.info("Processing concatenated subsampled data...")
+
+         # Get labels from obs
+         if self.params.annotation is not None:
+             label = concatenated_adata.obs['label'].values
+         else:
+             label = np.zeros(concatenated_adata.n_obs)
+
+         # Get batch information from obs
+         batch_labels = concatenated_adata.obs['batch_id'].values
+
+         # Get the expression array for the HVG genes
+         expression_array = torch.Tensor(concatenated_adata[:, hvg].X.toarray())
+         logger.info(f"Expression array shape: {expression_array.shape}")
+
+         # Process each batch separately for the GCN (spatial graphs are sample-specific)
+         expression_array_gcn_list = []
+
+         for batch_id in concatenated_adata.obs['batch_id'].unique():
+             batch_mask = concatenated_adata.obs['batch_id'] == batch_id
+             batch_adata = concatenated_adata[batch_mask]
+             batch_expression = expression_array[batch_mask.values]
+
+             logger.info(f"Processing batch {batch_id} with {batch_adata.n_obs} cells...")
+
+             # Build the spatial graph for this batch
+             edge = build_spatial_graph(
+                 coords=np.array(batch_adata.obsm[self.params.spatial_key]),
+                 n_neighbors=self.params.n_neighbors,
+             )
+             edge = torch.from_numpy(edge.T).long()
+
+             # Apply the GCN to this batch
+             batch_expression_gcn = self.gcov(batch_expression, edge)
+             expression_array_gcn_list.append(batch_expression_gcn)
+
+             logger.info(f"Graph for {batch_id} has {edge.size(1)} edges, {batch_adata.n_obs} cells.")
+
+         # Concatenate the GCN results in the same order as the original data
+         expression_array_gcn = torch.cat(expression_array_gcn_list, dim=0)
+
+         # Convert batch labels to numeric codes
+         batch_codes = pd.Categorical(batch_labels).codes
+
+         # Convert labels to categorical codes
+         cat_labels = pd.Categorical(label)
+         label_codes = cat_labels.codes
+
+         # Store the results
+         self.expression_merge = expression_array
+         self.expression_gcn_merge = expression_array_gcn
+         self.batch_merge = torch.Tensor(batch_codes)
+         self.label_merge = torch.Tensor(label_codes).long()
+
+         if self.params.annotation is not None:
+             self.label_name = cat_labels.categories.take(np.unique(cat_labels.codes)).to_list()
+         else:
+             self.label_name = None
+
+         # Record the number of batches (samples)
+         self.batch_size = len(torch.unique(self.batch_merge))
+
+
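
A hypothetical driver wiring together the helpers above; `sample_h5ad_dict` and `params` (a FindLatentRepresentationsConfig) are assumed to come from the caller:

```python
hvg, n_cell_used, gene_homolog_dict = find_common_hvg(sample_h5ad_dict, params)
concatenated_adata = create_subsampled_adata(sample_h5ad_dict, n_cell_used, params)

data = TrainingData(params)
data.prepare(concatenated_adata, hvg)

# The tensors stay aligned row-for-row with concatenated_adata
assert data.expression_merge.shape[0] == concatenated_adata.n_obs
assert data.expression_gcn_merge.shape == data.expression_merge.shape
```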
+ # Inference for each ST dataset
+ class InferenceData:
+     """
+     Infer cell embeddings for each spatial transcriptomics (ST) dataset.
+
+     Attributes:
+         hvg: List of highly variable genes.
+         batch_size: Number of batches (samples) seen during training.
+         model: Model to be used for inference.
+         params: FindLatentRepresentationsConfig with additional parameters for inference.
+     """
+
+     def __init__(self, hvg, batch_size, model, label_name, params):
+         self.params = params
+         self.gcov = GCN(self.params.K)
+         self.hvg = hvg
+         self.batch_size = batch_size
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.model = model.to(self.device)
+         self.label_name = label_name
+         self.processed_list_path = self.params.latent_dir / 'processed.list'
+
+     def infer_embedding_single(self, st_id, st_file):
+         st_name = Path(st_file).name.split(".h5ad")[0]
+         logger.info(f"Inferring cell embeddings for {st_name}...")
+
+         # Load the ST data
+         adata = sc.read_h5ad(st_file)
+         # sc.pp.filter_genes(adata, min_counts=1)
+
+         # Make gene names unique to avoid reindexing issues
+         adata.var_names_make_unique()
+
+         # Set up the data layer consistently
+         is_count_data, actual_data_layer = setup_data_layer(adata, self.params.data_layer)
+
+         # Convert the expression data to a torch.Tensor
+         expression_array = torch.Tensor(adata[:, self.hvg].X.toarray())
+
+         # Graph convolution of the expression array
+         edge = build_spatial_graph(
+             coords=np.array(adata.obsm[self.params.spatial_key]),
+             n_neighbors=self.params.n_neighbors,
+         )
+         edge = torch.from_numpy(edge.T).long()
+         expression_array_gcn = self.gcov(expression_array, edge)
+
+         # Build the batch index vector (every cell belongs to batch st_id)
+         n_cell = adata.n_obs
+         batch_indices = torch.full((n_cell,), st_id, dtype=torch.long)
+
+         # Prepare the evaluation DataLoader
+         dataset = TensorDataset(expression_array_gcn, expression_array, batch_indices)
+         inference_loader = DataLoader(dataset=dataset, batch_size=512, shuffle=False)
+
+         # Inference process
+         emb_cell, emb_niche, class_prob = [], [], []
+
+         self.model.eval()
+         for (
+             expression_gcn_focal,
+             expression_focal,
+             batch_indices_focal,
+         ) in inference_loader:
+             expression_gcn_focal = expression_gcn_focal.to(self.device)
+             expression_focal = expression_focal.to(self.device)
+             batch_indices_focal = batch_indices_focal.to(self.device)
+
+             with torch.no_grad():
+                 mu_focal = self.model.encode(
+                     [expression_focal, expression_gcn_focal], batch_indices_focal
+                 )
+                 _, x_class, _, _ = self.model(
+                     [expression_focal, expression_gcn_focal], batch_indices_focal
+                 )
+
+             class_prob.append(x_class.cpu().numpy())
+             emb_cell.append(mu_focal[0].cpu().numpy())
+             emb_niche.append(mu_focal[1].cpu().numpy())
+
+         # Concatenate the results and store the embeddings in adata
+         emb_cell = np.concatenate(emb_cell, axis=0)
+         emb_niche = np.concatenate(emb_niche, axis=0)
+         class_prob = np.concatenate(class_prob, axis=0)
+
+         # if self.label_name is not None:
+         #     class_prob = pd.DataFrame(softmax(class_prob, axis=1), columns=self.label_name, index=adata.obs_names)
+         #     adata.obsm["class_prob"] = class_prob
+
+         adata.obsm[self.params.latent_representation_cell] = emb_cell
+         adata.obsm[self.params.latent_representation_niche] = emb_niche
+
+         return adata
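
Finally, a hypothetical inference driver mirroring the constructor signature above; `trained_model`, `params`, and the file name are assumptions:

```python
infer = InferenceData(hvg, data.batch_size, trained_model, data.label_name, params)
adata_embedded = infer.infer_embedding_single(st_id=0, st_file="sample_0.h5ad")

# Embeddings are stored in obsm under the configured keys
cell_emb = adata_embedded.obsm[params.latent_representation_cell]
niche_emb = adata_embedded.obsm[params.latent_representation_niche]
```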