gsMap3D 0.1.0a1 (gsmap3d-0.1.0a1-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gsMap/__init__.py +13 -0
- gsMap/__main__.py +4 -0
- gsMap/cauchy_combination_test.py +342 -0
- gsMap/cli.py +355 -0
- gsMap/config/__init__.py +72 -0
- gsMap/config/base.py +296 -0
- gsMap/config/cauchy_config.py +79 -0
- gsMap/config/dataclasses.py +235 -0
- gsMap/config/decorators.py +302 -0
- gsMap/config/find_latent_config.py +276 -0
- gsMap/config/format_sumstats_config.py +54 -0
- gsMap/config/latent2gene_config.py +461 -0
- gsMap/config/ldscore_config.py +261 -0
- gsMap/config/quick_mode_config.py +242 -0
- gsMap/config/report_config.py +81 -0
- gsMap/config/spatial_ldsc_config.py +334 -0
- gsMap/config/utils.py +286 -0
- gsMap/find_latent/__init__.py +3 -0
- gsMap/find_latent/find_latent_representation.py +312 -0
- gsMap/find_latent/gnn/distribution.py +498 -0
- gsMap/find_latent/gnn/encoder_decoder.py +186 -0
- gsMap/find_latent/gnn/gcn.py +85 -0
- gsMap/find_latent/gnn/gene_former.py +164 -0
- gsMap/find_latent/gnn/loss.py +18 -0
- gsMap/find_latent/gnn/st_model.py +125 -0
- gsMap/find_latent/gnn/train_step.py +177 -0
- gsMap/find_latent/st_process.py +781 -0
- gsMap/format_sumstats.py +446 -0
- gsMap/generate_ldscore.py +1018 -0
- gsMap/latent2gene/__init__.py +18 -0
- gsMap/latent2gene/connectivity.py +781 -0
- gsMap/latent2gene/entry_point.py +141 -0
- gsMap/latent2gene/marker_scores.py +1265 -0
- gsMap/latent2gene/memmap_io.py +766 -0
- gsMap/latent2gene/rank_calculator.py +590 -0
- gsMap/latent2gene/row_ordering.py +182 -0
- gsMap/latent2gene/row_ordering_jax.py +159 -0
- gsMap/ldscore/__init__.py +1 -0
- gsMap/ldscore/batch_construction.py +163 -0
- gsMap/ldscore/compute.py +126 -0
- gsMap/ldscore/constants.py +70 -0
- gsMap/ldscore/io.py +262 -0
- gsMap/ldscore/mapping.py +262 -0
- gsMap/ldscore/pipeline.py +615 -0
- gsMap/pipeline/quick_mode.py +134 -0
- gsMap/report/__init__.py +2 -0
- gsMap/report/diagnosis.py +375 -0
- gsMap/report/report.py +100 -0
- gsMap/report/report_data.py +1832 -0
- gsMap/report/static/js_lib/alpine.min.js +5 -0
- gsMap/report/static/js_lib/tailwindcss.js +83 -0
- gsMap/report/static/template.html +2242 -0
- gsMap/report/three_d_combine.py +312 -0
- gsMap/report/three_d_plot/three_d_plot_decorate.py +246 -0
- gsMap/report/three_d_plot/three_d_plot_prepare.py +202 -0
- gsMap/report/three_d_plot/three_d_plots.py +425 -0
- gsMap/report/visualize.py +1409 -0
- gsMap/setup.py +5 -0
- gsMap/spatial_ldsc/__init__.py +0 -0
- gsMap/spatial_ldsc/io.py +656 -0
- gsMap/spatial_ldsc/ldscore_quick_mode.py +912 -0
- gsMap/spatial_ldsc/spatial_ldsc_jax.py +382 -0
- gsMap/spatial_ldsc/spatial_ldsc_multiple_sumstats.py +439 -0
- gsMap/utils/__init__.py +0 -0
- gsMap/utils/generate_r2_matrix.py +610 -0
- gsMap/utils/jackknife.py +518 -0
- gsMap/utils/manhattan_plot.py +643 -0
- gsMap/utils/regression_read.py +177 -0
- gsMap/utils/torch_utils.py +23 -0
- gsmap3d-0.1.0a1.dist-info/METADATA +168 -0
- gsmap3d-0.1.0a1.dist-info/RECORD +74 -0
- gsmap3d-0.1.0a1.dist-info/WHEEL +4 -0
- gsmap3d-0.1.0a1.dist-info/entry_points.txt +2 -0
- gsmap3d-0.1.0a1.dist-info/licenses/LICENSE +21 -0
gsMap/find_latent/st_process.py

@@ -0,0 +1,781 @@
import logging
import re
from pathlib import Path

import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
import torch
from rich.progress import BarColumn, Progress, TaskProgressColumn, TimeRemainingColumn
from torch.utils.data import DataLoader, TensorDataset

from gsMap.config import FindLatentRepresentationsConfig

from .gnn.gcn import GCN, build_spatial_graph

logger = logging.getLogger(__name__)

def convert_to_human_genes(adata, gene_homolog_dict: dict, species: str = None):
    """
    Convert gene names in adata to human gene symbols using a pre-computed mapping dictionary.

    Args:
        adata: AnnData object to convert
        gene_homolog_dict: Dictionary mapping current gene names to human gene symbols
        species: Optional species name to use for the original gene column in var

    Returns:
        adata: Processed AnnData object with human symbols as var_names
    """
    # Identify common genes present in mapping
    common_genes = [g for g in gene_homolog_dict.keys() if g in adata.var_names]

    if len(common_genes) < 300:
        logger.warning(f"Only {len(common_genes)} genes found in mapping dictionary.")

    # Filter to these genes
    adata = adata[:, common_genes].copy()

    human_symbols = [gene_homolog_dict[g] for g in adata.var_names]

    # Store original gene names if species is provided
    if species:
        species_col_name = f"{species}_GENE_SYM"
        adata.var[species_col_name] = adata.var_names.values

    # Update var_names to human symbols
    adata.var_names = human_symbols
    adata.var.index.name = "HUMAN_GENE_SYM"

    # Remove duplicated genes (keep first occurrence)
    if adata.var_names.duplicated().any():
        n_before = adata.n_vars
        adata = adata[:, ~adata.var_names.duplicated()].copy()
        logger.info(f"Removed {n_before - adata.n_vars} duplicate human gene symbols.")

    return adata

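# --- Illustrative usage sketch (editor's addition, not part of the package) ---
# A minimal example of convert_to_human_genes on a tiny in-memory AnnData; the
# mapping dictionary would normally come from find_common_hvg below. With only
# two genes, the "<300 genes" warning will fire, which is expected here.
#
#   import anndata as ad
#
#   toy = ad.AnnData(X=np.ones((2, 2)))
#   toy.var_names = ["Sox2", "Pax6"]
#   mapping = {"Sox2": "SOX2", "Pax6": "PAX6"}
#   toy = convert_to_human_genes(toy, mapping, species="MOUSE")
#   # toy.var_names is now ["SOX2", "PAX6"]; the original symbols are kept
#   # in toy.var["MOUSE_GENE_SYM"].
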
def find_common_hvg(sample_h5ad_dict, params: FindLatentRepresentationsConfig):
    """
    Identifies common highly variable genes (HVGs) across multiple ST datasets and calculates
    the number of cells to sample from each dataset.

    Args:
        sample_h5ad_dict (dict): Dictionary mapping sample names to file paths of ST datasets.
        params (FindLatentRepresentationsConfig): Parameter object (uses feat_cell,
            n_cell_training, data_layer, species, and homolog_file).

    Returns:
        tuple: (hvg, n_cell_used, gene_homolog_dict) — the selected HVG list, the
        per-sample training-cell counts, and the mapping from common genes to
        human gene symbols.
    """

    variances_list = []
    cell_number = []

    logger.info("Finding highly variable genes (HVGs)...")

    with Progress(
        BarColumn(),
        TaskProgressColumn(),
        TimeRemainingColumn(),
    ) as progress:
        task = progress.add_task("Finding HVGs", total=len(sample_h5ad_dict))

        for sample_name, st_file in sample_h5ad_dict.items():
            adata_temp = sc.read_h5ad(st_file)
            # sc.pp.filter_genes(adata_temp, min_counts=1)

            # Filter out mitochondrial and hemoglobin genes
            gene_keep = ~adata_temp.var_names.str.match(re.compile(r'^(HB.-|MT-)', re.IGNORECASE))
            if removed_genes := adata_temp.n_vars - gene_keep.sum():
                progress.console.log(f"Removed {removed_genes} mitochondrial and hemoglobin genes in {sample_name}.")

            adata_temp = adata_temp[:, gene_keep].copy()

            # Make gene names unique to avoid issues with HVG calculation
            adata_temp.var_names_make_unique()

            is_count_data, actual_data_layer = setup_data_layer(adata_temp, params.data_layer, verbose=False)
            cell_number.append(adata_temp.n_obs)

            try:
                # Identify highly variable genes
                if is_count_data:
                    sc.experimental.pp.highly_variable_genes(
                        adata_temp, n_top_genes=params.feat_cell, subset=False
                    )
                else:
                    sc.pp.highly_variable_genes(
                        adata_temp, n_top_genes=params.feat_cell, subset=False, flavor='seurat'
                    )

                var_df = adata_temp.var.copy()
                var_df["gene"] = var_df.index.tolist()
                variances_list.append(var_df)
            except Exception as e:
                logger.warning(f"[HVG skipped] {e}")
                continue

            progress.update(task, advance=1)

    # Check if we have any valid variance results
    if len(variances_list) == 0:
        raise ValueError("No valid HVG results obtained from any sample. Please check your data and parameters.")

    # Find the common genes across all samples
    common_genes = np.array(
        list(set.intersection(
            *map(set, [st.index.to_list() for st in variances_list])))
    )

    # Error when gene number is too small
    if len(common_genes) < 300:
        raise ValueError(f"Only {len(common_genes)} common genes between all samples.")

    # Aggregate variances and identify HVGs
    df = pd.concat(variances_list, axis=0)
    df["highly_variable"] = df["highly_variable"].astype(int)

    if is_count_data:
        df = df.groupby("gene", observed=True).agg(
            dict(
                residual_variances="median",
                highly_variable="sum",
            )
        )
        df = df.loc[common_genes]
        df["highly_variable_nbatches"] = df["highly_variable"]
        df.sort_values(
            ["highly_variable_nbatches", "residual_variances"],
            ascending=False,
            na_position="last",
            inplace=True,
        )
    else:
        df = df.groupby("gene", observed=True).agg(
            dict(
                dispersions_norm="median",
                highly_variable="sum",
            )
        )
        df = df.loc[common_genes]
        df["highly_variable_nbatches"] = df["highly_variable"]
        df.sort_values(
            ["highly_variable_nbatches", "dispersions_norm"],
            ascending=False,
            na_position="last",
            inplace=True,
        )

    hvg = df.iloc[: params.feat_cell].index.tolist()

    # Find the number of sampling cells for each batch
    total_cell = np.sum(cell_number)
    total_cell_training = np.minimum(total_cell, params.n_cell_training)
    cell_proportion = cell_number / total_cell
    n_cell_used = [
        int(cell) for cell in (total_cell_training * cell_proportion).tolist()
    ]

    # Only use the common genes that can be transformed to human genes
    if params.species is not None:
        homologs = pd.read_csv(params.homolog_file, sep='\t')
        if homologs.shape[1] < 2:
            raise ValueError("Homologs file must have at least two columns: one for the species and one for the human gene symbol.")
        homologs.columns = [params.species, 'HUMAN_GENE_SYM']
        homologs.set_index(params.species, inplace=True)
        common_genes = np.intersect1d(common_genes, homologs.index)
        gene_homolog_dict = dict(zip(common_genes, homologs.loc[common_genes].HUMAN_GENE_SYM.values, strict=False))
    else:
        gene_homolog_dict = dict(zip(common_genes, common_genes, strict=False))

    if len(gene_homolog_dict) < 300:
        raise ValueError(f"Only {len(gene_homolog_dict)} genes could be mapped to human symbols. Please check the homolog file.")

    return hvg, n_cell_used, gene_homolog_dict

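# --- Illustrative usage sketch (editor's addition, not part of the package) ---
# find_common_hvg expects on-disk .h5ad files plus a populated config. Rough
# shape of the call, with hypothetical paths:
#
#   sample_h5ad_dict = {
#       "sample_A": "/data/sample_A.h5ad",
#       "sample_B": "/data/sample_B.h5ad",
#   }
#   hvg, n_cell_used, gene_homolog_dict = find_common_hvg(sample_h5ad_dict, params)
#   # hvg: top params.feat_cell genes shared by all samples, ranked by how many
#   # samples flagged them as highly variable, then by median variance.
#   # n_cell_used: per-sample training-cell budget, proportional to sample size.
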
def create_subsampled_adata(sample_h5ad_dict, n_cell_used, params: FindLatentRepresentationsConfig):
    """
    Create subsampled adata for each sample with sample-specific stratified sampling,
    add batch and label information, and return concatenated adata.

    Args:
        sample_h5ad_dict (dict): Dictionary mapping sample names to file paths of ST datasets.
        n_cell_used (list): Number of cells to sample from each dataset.
        params (FindLatentRepresentationsConfig): Parameter object (uses data_layer,
            annotation, and do_sampling).

    Returns:
        adata: Concatenated adata object with batch and label information in obs.
    """
    subsampled_adatas = []

    logger.info("Creating subsampled adata with stratified sampling...")

    for st_id, (sample_name, st_file) in enumerate(sample_h5ad_dict.items()):
        logger.info(f"Processing {sample_name}...")

        # Load the data
        adata = sc.read_h5ad(st_file)

        # Filter out mitochondrial and hemoglobin genes
        gene_keep = ~adata.var_names.str.match(re.compile(r'^(HB.-|MT-)', re.IGNORECASE))
        adata = adata[:, gene_keep].copy()

        # Make gene names unique to avoid issues with concatenation
        adata.var_names_make_unique()

        # Setup data layer consistently
        is_count_data, actual_data_layer = setup_data_layer(adata, params.data_layer)

        # Filter cells based on annotation if provided
        if params.annotation is not None:
            adata = adata[~adata.obs[params.annotation].isnull()]

        # Perform stratified sampling within this sample
        if params.do_sampling and adata.n_obs > n_cell_used[st_id]:
            if params.annotation is None:
                # Simple random sampling
                num_cell = n_cell_used[st_id]
                logger.info(f"Downsampling {sample_name} to {num_cell} cells...")
                random_indices = np.random.choice(adata.n_obs, num_cell, replace=False)
                adata = adata[random_indices].copy()
            else:
                # Stratified sampling based on sample-specific annotation distribution
                sample_annotation_counts = adata.obs[params.annotation].value_counts()
                sample_total_cells = len(adata)
                target_total_cells = n_cell_used[st_id]

                # Calculate sample-specific annotation proportions
                sample_annotation_proportions = sample_annotation_counts / sample_total_cells

                # Calculate target cells for each annotation in this sample
                target_cells_per_annotation = (sample_annotation_proportions * target_total_cells).astype(int)

                logger.info(f"Downsampling {sample_name} to {target_total_cells} cells...")
                logger.debug("---Sample-specific annotation distribution-----")
                for ann, count in target_cells_per_annotation.items():
                    logger.debug(f"{ann}: {count} cells")

                # Perform stratified sampling
                sampled_cells = (
                    adata.obs.groupby(params.annotation, group_keys=False, observed=True)
                    .apply(
                        lambda x: x.sample(
                            max(min(target_cells_per_annotation.get(x.name, 0), len(x)), 1),
                            replace=False,
                        ),
                        include_groups=False
                    )
                    .index
                )

                # Filter adata to sampled cells
                adata = adata[sampled_cells].copy()

        # Add batch information to obs
        adata.obs['batch_id'] = f"S{st_id}"
        adata.obs['sample_name'] = sample_name

        # Add label information to obs (ensure it's properly set)
        if params.annotation is not None:
            # The annotation column already exists, just ensure it's called 'label'
            if 'label' not in adata.obs.columns:
                adata.obs['label'] = adata.obs[params.annotation]
        else:
            # Create dummy labels
            adata.obs['label'] = 'unknown'

        subsampled_adatas.append(adata)
        logger.info(f"Processed {sample_name}: {adata.n_obs} cells, {adata.n_vars} genes")

    # Concatenate all samples
    logger.info("Concatenating all processed data...")
    concatenated_adata = sc.concat(subsampled_adatas, axis=0, join='inner',
                                   index_unique='_', fill_value=0)

    logger.info(f"Final concatenated adata: {concatenated_adata.n_obs} cells, {concatenated_adata.n_vars} genes")

    return concatenated_adata

def _looks_like_count_matrix(X, max_check=100, tol=1e-8):
    """
    Heuristically check whether X contains integer values.
    """
    if X is None:
        return False

    if sp.issparse(X):
        data = X.data
        if data.size == 0:
            return False
        if data.size > max_check:
            data = data[:max_check]
    else:
        flat = np.asarray(X).ravel()
        if flat.size == 0:
            return False
        if flat.size > max_check:
            flat = flat[:max_check]
        data = flat

    nz = data[data != 0]
    if nz.size == 0:
        return True  # all zero, good enough
    return np.allclose(nz, np.round(nz), atol=tol)

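# --- Illustrative usage sketch (editor's addition, not part of the package) ---
# The heuristic only inspects up to max_check non-zero values:
#
#   _looks_like_count_matrix(np.array([[0, 1], [2, 3]]))          # True
#   _looks_like_count_matrix(np.array([[0.0, 1.5], [2.0, 3.0]]))  # False
#   _looks_like_count_matrix(sp.csr_matrix(np.eye(3)))            # True
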
def setup_data_layer(adata, data_layer, verbose=True):
    """
    Setup and validate data layer for adata object.

    Args:
        adata: AnnData object
        data_layer: Requested data layer name
        verbose: Whether to log information messages (default: True)

    Returns:
        tuple: (is_count_data, actual_data_layer)
            - is_count_data: Boolean indicating if data is count data
            - actual_data_layer: The actual layer name to use ("X" if using adata.X)
    """
    # Check if the requested data layer exists in layers
    if data_layer in adata.layers:
        # Use the specified layer
        adata.X = adata.layers[data_layer]
        actual_data_layer = data_layer
        if verbose:
            logger.info(f"Using data layer: {data_layer}")
    elif data_layer == "X":
        actual_data_layer = "X"
    else:
        raise ValueError(f"Data layer '{data_layer}' not found in adata.layers.")

    # Determine if this is count data
    is_count_data = actual_data_layer in ["count", "counts", "raw_counts", "impute_count"] or \
        (adata.X is not None and np.issubdtype(adata.X.dtype, np.integer)) or \
        _looks_like_count_matrix(adata.X)

    if verbose:
        if is_count_data:
            logger.info("Detected count data")
        else:
            logger.info("Data appears to be normalized/log-transformed")

    return is_count_data, actual_data_layer

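# --- Illustrative usage sketch (editor's addition, not part of the package) ---
# setup_data_layer copies the requested layer into adata.X in place, then
# reports whether it looks like raw counts. A toy example, assuming anndata
# is available:
#
#   import anndata as ad
#
#   toy = ad.AnnData(X=np.log1p(np.ones((2, 2))))
#   toy.layers["counts"] = np.ones((2, 2), dtype=int)
#   is_count, layer = setup_data_layer(toy, "counts")  # (True, "counts")
#   is_count, layer = setup_data_layer(toy, "X")       # inspects adata.X directly
#   setup_data_layer(toy, "missing")                   # raises ValueError
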
def normalize_for_analysis(adata, is_count_data, preserve_raw=True):
    """
    Normalize data for DEG analysis or module scoring if needed.

    Args:
        adata: AnnData object
        is_count_data: Boolean indicating if data is count data
        preserve_raw: Whether to preserve raw data in adata.raw

    Returns:
        adata: AnnData object with normalized data
    """
    if is_count_data:
        logger.info("Normalizing and log-transforming count data...")

        # Store raw counts if requested and not already stored
        if preserve_raw and adata.raw is None:
            adata.raw = adata

        # Normalize to 10,000 reads per cell
        sc.pp.normalize_total(adata, target_sum=1e4)

        # Log-transform (log1p)
        sc.pp.log1p(adata)

        logger.info("Data normalized and log-transformed")
    else:
        logger.info("Using data as-is (already normalized/log-transformed)")

    return adata

def filter_significant_degs(deg_results, annotation, adata=None, pval_threshold=0.05, lfc_threshold=0.25, max_genes=100):
    """
    Filter DEGs based on statistical significance and fold change criteria.

    Args:
        deg_results: DEG results from scanpy rank_genes_groups
        annotation: Annotation label to get DEGs for
        adata: Optional adata to check gene existence
        pval_threshold: P-value threshold (default: 0.05)
        lfc_threshold: Log fold change threshold (default: 0.25)
        max_genes: Maximum number of genes to return (default: 100)

    Returns:
        list: Filtered list of significant DEG gene names
    """
    gene_names = deg_results['names'][annotation]
    pvals_adj = deg_results['pvals_adj'][annotation]
    logfoldchanges = deg_results['logfoldchanges'][annotation]

    # Filter genes based on significance and fold change
    annotation_genes = []
    for gene, pval, lfc in zip(gene_names, pvals_adj, logfoldchanges, strict=False):
        if gene is not None and pval < pval_threshold and abs(lfc) > lfc_threshold:
            annotation_genes.append(gene)

    # Limit to max genes if we have more
    if len(annotation_genes) > max_genes:
        annotation_genes = annotation_genes[:max_genes]

    # Ensure genes exist in adata if provided
    if adata is not None:
        annotation_genes = [gene for gene in annotation_genes if gene in adata.var_names]

    return annotation_genes

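# --- Illustrative usage sketch (editor's addition, not part of the package) ---
# Any mapping with the same ['names'/'pvals_adj'/'logfoldchanges'][group]
# layout as scanpy's rank_genes_groups output works here; a hypothetical
# hand-built example:
#
#   fake_degs = {
#       "names": {"Tcell": ["CD3D", "CD3E", "GAPDH"]},
#       "pvals_adj": {"Tcell": [1e-10, 1e-8, 0.9]},
#       "logfoldchanges": {"Tcell": [2.1, 1.7, 0.01]},
#   }
#   filter_significant_degs(fake_degs, "Tcell")  # -> ["CD3D", "CD3E"]
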
def calculate_module_scores_from_degs(adata, deg_results, annotation_key):
    """
    Calculate module scores using existing DEG results.

    Args:
        adata: AnnData object to calculate scores for
        deg_results: DEG results from scanpy rank_genes_groups
        annotation_key: Column name in obs containing annotation labels

    Returns:
        adata: Updated adata with module score columns
    """

    logger.info("Calculating module scores using existing DEG results...")

    # Detect count data and normalize if needed
    is_count_data = _looks_like_count_matrix(adata.X)
    adata = normalize_for_analysis(adata, is_count_data, preserve_raw=False)

    # Ensure annotation is categorical if it exists
    if annotation_key in adata.obs.columns:
        adata.obs[annotation_key] = adata.obs[annotation_key].astype('category')
        available_annotations = adata.obs[annotation_key].cat.categories
    else:
        # If annotation doesn't exist, use all annotations from DEG results
        available_annotations = list(deg_results['names'].dtype.names)

    # Calculate module score for each annotation
    for annotation in available_annotations:
        if annotation in deg_results['names'].dtype.names:
            # Get significant DEGs for this annotation
            annotation_genes = filter_significant_degs(deg_results, annotation, adata)

            if len(annotation_genes) > 0:
                logger.info(f"Calculating module score for {annotation} using {len(annotation_genes)} genes")

                # Calculate module score
                sc.tl.score_genes(
                    adata,
                    gene_list=annotation_genes,
                    score_name=f"{annotation}_module_score",
                    use_raw=False
                )
            else:
                logger.warning(f"No valid DEGs found for {annotation}")
                adata.obs[f"{annotation}_module_score"] = 0.0
        else:
            logger.warning(f"Annotation {annotation} not found in DEG results")
            adata.obs[f"{annotation}_module_score"] = 0.0

    return adata

def calculate_module_score(training_adata, annotation_key):
    """
    Perform DEG analysis for each annotation and calculate module scores.

    Args:
        training_adata: Concatenated training adata with annotation information
        annotation_key: Column name in obs containing annotation labels

    Returns:
        training_adata: Updated adata with module scores for each annotation
    """

    logger.info("Performing DEG analysis for each annotation...")

    # Make a copy to avoid modifying the original
    adata = training_adata.copy()

    # Ensure annotation is categorical
    adata.obs[annotation_key] = adata.obs[annotation_key].astype('category')

    # Detect count data and normalize if needed
    is_count_data = _looks_like_count_matrix(adata.X)
    adata = normalize_for_analysis(adata, is_count_data, preserve_raw=True)

    # Perform DEG analysis with DataFrame fragmentation warnings suppressed
    # These warnings are harmless and come from Scanpy's internal implementation
    import warnings
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', category=pd.errors.PerformanceWarning,
                                message='DataFrame is highly fragmented')
        sc.tl.rank_genes_groups(
            adata,
            groupby=annotation_key,
            method='wilcoxon',
            use_raw=False,
            n_genes=100
        )

    logger.info("Calculating module scores for each annotation...")

    # Get DEG results
    deg_results = adata.uns['rank_genes_groups']

    # Calculate module score for each annotation
    for annotation in adata.obs[annotation_key].cat.categories:

        # Get significant DEGs for this annotation
        annotation_genes = filter_significant_degs(deg_results, annotation, adata)

        if len(annotation_genes) > 0:
            logger.info(f"Calculating module score for {annotation} using {len(annotation_genes)} genes")

            # Calculate module score
            sc.tl.score_genes(
                adata,
                gene_list=annotation_genes,
                score_name=f"{annotation}_module_score",
                use_raw=False
            )
        else:
            logger.warning(f"No valid DEGs found for {annotation}")
            adata.obs[f"{annotation}_module_score"] = 0.0

    logger.info("Module score calculation completed")
    return adata

def apply_module_score_qc(adata, annotation_key, module_score_threshold_dict):
    """
    Apply quality control based on module scores.

    Args:
        adata: AnnData object to apply QC to
        annotation_key: Column name in obs containing annotation labels
        module_score_threshold_dict: Dictionary mapping annotation to threshold values

    Returns:
        adata: Updated adata with QC information
    """
    logger.info("Applying module score-based quality control...")

    # Initialize High_quality column as boolean (True = high quality)
    adata.obs['High_quality'] = True

    # Check if we have the annotation key
    if annotation_key not in adata.obs.columns:
        logger.warning(f"Annotation key '{annotation_key}' not found in adata.obs. Skipping module score QC.")
        return adata

    # Apply QC for each annotation
    for annotation, threshold in module_score_threshold_dict.items():
        module_score_col = f"{annotation}_module_score"

        if module_score_col in adata.obs.columns:
            # Find cells of this annotation with low module scores
            annotation_mask = adata.obs[annotation_key] == annotation
            low_score_mask = adata.obs[module_score_col] < threshold

            # Set High_quality to False for cells that match both conditions (low quality)
            low_quality_mask = annotation_mask & low_score_mask
            adata.obs.loc[low_quality_mask, 'High_quality'] = False

            n_low_quality = low_quality_mask.sum()
            n_annotation_cells = annotation_mask.sum()

            logger.info(f"{annotation}: {n_low_quality}/{n_annotation_cells} cells marked as low quality "
                        f"(threshold: {threshold:.3f})")
        else:
            logger.warning(f"Module score column '{module_score_col}' not found in adata.obs")

    total_low_quality = (~adata.obs['High_quality']).sum()
    logger.info(f"Total low quality cells: {total_low_quality}/{adata.n_obs}")

    return adata

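# --- Illustrative usage sketch (editor's addition, not part of the package) ---
# Thresholds are given per annotation; cells of that annotation whose module
# score falls below the threshold are flagged. Hypothetical values:
#
#   thresholds = {"Tcell": 0.10, "Bcell": 0.05}
#   adata = apply_module_score_qc(adata, "cell_type", thresholds)
#   adata = adata[adata.obs["High_quality"]].copy()  # keep only high-quality cells
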
# Prepare the training data
class TrainingData:
    """
    Manages and processes training data for graph-based models.

    Attributes:
        params: Parameter object used for data processing and training.
    """

    def __init__(self, params):
        self.params = params
        self.gcov = GCN(self.params.K)
        self.expression_merge = None
        self.expression_gcn_merge = None
        self.label_merge = None
        self.batch_merge = None
        self.batch_size = None
        self.label_name = None

    def prepare(self, concatenated_adata, hvg):
        logger.info("Processing concatenated subsampled data...")

        # Get labels from obs
        if self.params.annotation is not None:
            label = concatenated_adata.obs['label'].values
        else:
            label = np.zeros(concatenated_adata.n_obs)

        # Get batch information from obs
        batch_labels = concatenated_adata.obs['batch_id'].values

        # Get expression array for HVG genes
        expression_array = torch.Tensor(concatenated_adata[:, hvg].X.toarray())
        logger.info(f"Expression array shape: {expression_array.shape}")

        # Process each batch separately for GCN (since spatial graphs are sample-specific)
        expression_array_gcn_list = []

        for batch_id in concatenated_adata.obs['batch_id'].unique():
            batch_mask = concatenated_adata.obs['batch_id'] == batch_id
            batch_adata = concatenated_adata[batch_mask]
            batch_expression = expression_array[batch_mask.values]

            logger.info(f"Processing batch {batch_id} with {batch_adata.n_obs} cells...")

            # Build spatial graph for this batch
            edge = build_spatial_graph(
                coords=np.array(batch_adata.obsm[self.params.spatial_key]),
                n_neighbors=self.params.n_neighbors,
            )
            edge = torch.from_numpy(edge.T).long()

            # Apply GCN to this batch
            batch_expression_gcn = self.gcov(batch_expression, edge)
            expression_array_gcn_list.append(batch_expression_gcn)

            logger.info(f"Graph for {batch_id} has {edge.size(1)} edges, {batch_adata.n_obs} cells.")

        # Concatenate GCN results in the same order as the original data
        expression_array_gcn = torch.cat(expression_array_gcn_list, dim=0)

        # Convert batch labels to numeric codes
        batch_codes = pd.Categorical(batch_labels).codes

        # Convert labels to categorical codes
        cat_labels = pd.Categorical(label)
        label_codes = cat_labels.codes

        # Store results
        self.expression_merge = expression_array
        self.expression_gcn_merge = expression_array_gcn
        self.batch_merge = torch.Tensor(batch_codes)
        self.label_merge = torch.Tensor(label_codes).long()

        if self.params.annotation is not None:
            self.label_name = cat_labels.categories.take(np.unique(cat_labels.codes)).to_list()
        else:
            self.label_name = None

        # Set batch size
        self.batch_size = len(torch.unique(self.batch_merge))

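# --- Illustrative usage sketch (editor's addition, not part of the package) ---
# Rough shape of how TrainingData consumes the outputs of the functions above:
#
#   data = TrainingData(params)
#   data.prepare(concatenated_adata, hvg)
#   # data.expression_merge / data.expression_gcn_merge: (n_cells, n_hvg) tensors,
#   # the latter smoothed over each sample's own k-NN spatial graph.
#   # data.batch_merge / data.label_merge: per-cell batch and label codes.
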
# Inference for each ST dataset
class InferenceData:
    """
    Infer cell embeddings for each spatial transcriptomics (ST) dataset.

    Attributes:
        hvg: List of highly variable genes.
        batch_size: Integer defining the batch size for inference.
        model: Model to be used for inference.
        params: Parameter object containing additional settings for inference.
    """

    def __init__(self, hvg, batch_size, model, label_name, params):
        self.params = params
        self.gcov = GCN(self.params.K)
        self.hvg = hvg
        self.batch_size = batch_size
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = model.to(self.device)
        self.label_name = label_name
        self.processed_list_path = self.params.latent_dir / 'processed.list'

    def infer_embedding_single(self, st_id, st_file):
        st_name = Path(st_file).name.split(".h5ad")[0]
        logger.info(f"Inferring cell embeddings for {st_name}...")

        # Load the ST data
        adata = sc.read_h5ad(st_file)
        # sc.pp.filter_genes(adata, min_counts=1)

        # Make gene names unique to avoid reindexing issues
        adata.var_names_make_unique()

        # Setup data layer consistently
        is_count_data, actual_data_layer = setup_data_layer(adata, self.params.data_layer)

        # Convert expression data to torch.Tensor
        expression_array = torch.Tensor(adata[:, self.hvg].X.toarray())

        # Graph convolution of expression array
        edge = build_spatial_graph(
            coords=np.array(adata.obsm[self.params.spatial_key]),
            n_neighbors=self.params.n_neighbors,
        )
        edge = torch.from_numpy(edge.T).long()
        expression_array_gcn = self.gcov(expression_array, edge)

        # Build the batch index vector (one integer index per cell)
        n_cell = adata.n_obs
        batch_indices = torch.full((n_cell,), st_id, dtype=torch.long)

        # Prepare the evaluation DataLoader
        dataset = TensorDataset(expression_array_gcn, expression_array, batch_indices)
        inference_loader = DataLoader(dataset=dataset, batch_size=512, shuffle=False)

        # Inference process
        emb_cell, emb_niche, class_prob = [], [], []

        for (
            expression_gcn_focal,
            expression_focal,
            batch_indices_focal,
        ) in inference_loader:
            expression_gcn_focal = expression_gcn_focal.to(self.device)
            expression_focal = expression_focal.to(self.device)
            batch_indices_focal = batch_indices_focal.to(self.device)

            self.model.eval()
            with torch.no_grad():
                mu_focal = self.model.encode(
                    [expression_focal, expression_gcn_focal], batch_indices_focal
                )
                _, x_class, _, _ = self.model(
                    [expression_focal, expression_gcn_focal], batch_indices_focal
                )

            class_prob.append(x_class.cpu().numpy())
            emb_cell.append(mu_focal[0].cpu().numpy())
            emb_niche.append(mu_focal[1].cpu().numpy())

        # Concatenate results and store embeddings in adata
        emb_cell = np.concatenate(emb_cell, axis=0)
        emb_niche = np.concatenate(emb_niche, axis=0)
        class_prob = np.concatenate(class_prob, axis=0)

        # if self.label_name is not None:
        #     class_prob = pd.DataFrame(softmax(class_prob, axis=1), columns=self.label_name, index=adata.obs_names)
        #     adata.obsm["class_prob"] = class_prob

        adata.obsm[self.params.latent_representation_cell] = emb_cell
        adata.obsm[self.params.latent_representation_niche] = emb_niche

        return adata